[
  {
    "path": ".clang-format",
    "content": "---\nLanguage:        Cpp\n# BasedOnStyle:  Google\nAccessModifierOffset: -1\nAlignAfterOpenBracket: AlwaysBreak\nAlignArrayOfStructures: None\nAlignConsecutiveAssignments:\n  Enabled:         false\n  AcrossEmptyLines: false\n  AcrossComments:  false\n  AlignCompound:   false\n  PadOperators:    true\nAlignConsecutiveBitFields:\n  Enabled:         false\n  AcrossEmptyLines: false\n  AcrossComments:  false\n  AlignCompound:   false\n  PadOperators:    false\nAlignConsecutiveDeclarations:\n  Enabled:         false\n  AcrossEmptyLines: false\n  AcrossComments:  false\n  AlignCompound:   false\n  PadOperators:    false\nAlignConsecutiveMacros:\n  Enabled:         false\n  AcrossEmptyLines: false\n  AcrossComments:  false\n  AlignCompound:   false\n  PadOperators:    false\nAlignEscapedNewlines: Left\nAlignOperands:   Align\nAllowAllArgumentsOnNextLine: true\nAllowAllParametersOfDeclarationOnNextLine: true\nAllowShortBlocksOnASingleLine: Never\nAllowShortCaseLabelsOnASingleLine: false\nAllowShortEnumsOnASingleLine: true\nAllowShortFunctionsOnASingleLine: All\nAllowShortIfStatementsOnASingleLine: WithoutElse\nAllowShortLambdasOnASingleLine: All\nAllowShortLoopsOnASingleLine: true\nAlwaysBreakAfterDefinitionReturnType: None\nAlwaysBreakAfterReturnType: None\nAlwaysBreakBeforeMultilineStrings: true\nAlwaysBreakTemplateDeclarations: Yes\nAttributeMacros:\n  - __capability\nBinPackArguments: true\nBinPackParameters: true\nBitFieldColonSpacing: Both\nBraceWrapping:\n  AfterCaseLabel:  false\n  AfterClass:      false\n  AfterControlStatement: Never\n  AfterEnum:       false\n  AfterExternBlock: false\n  AfterFunction:   false\n  AfterNamespace:  false\n  AfterObjCDeclaration: false\n  AfterStruct:     false\n  AfterUnion:      false\n  BeforeCatch:     false\n  BeforeElse:      false\n  BeforeLambdaBody: false\n  BeforeWhile:     false\n  IndentBraces:    false\n  SplitEmptyFunction: true\n  SplitEmptyRecord: true\n  SplitEmptyNamespace: 
true\nBreakAfterJavaFieldAnnotations: false\nBreakBeforeBinaryOperators: None\nBreakBeforeConceptDeclarations: Always\nBreakBeforeBraces: Attach\nBreakBeforeTernaryOperators: true\nBreakConstructorInitializers: BeforeColon\nBreakInheritanceList: BeforeColon\nBreakStringLiterals: true\nColumnLimit:     80\nCommentPragmas:  '^ IWYU pragma:'\nCompactNamespaces: false\nConstructorInitializerIndentWidth: 4\nContinuationIndentWidth: 4\nCpp11BracedListStyle: true\nDerivePointerAlignment: true\nDisableFormat:   false\nEmptyLineAfterAccessModifier: Never\nEmptyLineBeforeAccessModifier: LogicalBlock\nExperimentalAutoDetectBinPacking: false\nFixNamespaceComments: true\nForEachMacros:\n  - foreach\n  - Q_FOREACH\n  - BOOST_FOREACH\nIfMacros:\n  - KJ_IF_MAYBE\nIncludeBlocks:   Regroup\nIncludeCategories:\n  - Regex:           '^<ext/.*\\.h>'\n    Priority:        2\n    SortPriority:    0\n    CaseSensitive:   false\n  - Regex:           '^<.*\\.h>'\n    Priority:        1\n    SortPriority:    0\n    CaseSensitive:   false\n  - Regex:           '^<.*'\n    Priority:        2\n    SortPriority:    0\n    CaseSensitive:   false\n  - Regex:           '.*'\n    Priority:        3\n    SortPriority:    0\n    CaseSensitive:   false\nIncludeIsMainRegex: '([-_](test|unittest))?$'\nIncludeIsMainSourceRegex: ''\nIndentAccessModifiers: false\nIndentCaseBlocks: false\nIndentCaseLabels: true\nIndentExternBlock: AfterExternBlock\nIndentGotoLabels: true\nIndentPPDirectives: None\nIndentRequiresClause: true\nIndentWidth:     2\nIndentWrappedFunctionNames: false\nInsertBraces:    false\nInsertTrailingCommas: None\nJavaScriptQuotes: Leave\nJavaScriptWrapImports: true\nKeepEmptyLinesAtTheStartOfBlocks: false\nLambdaBodyIndentation: Signature\nMacroBlockBegin: ''\nMacroBlockEnd:   ''\nMaxEmptyLinesToKeep: 1\nNamespaceIndentation: None\nObjCBinPackProtocolList: Never\nObjCBlockIndentWidth: 2\nObjCBreakBeforeNestedBlockParam: true\nObjCSpaceAfterProperty: false\nObjCSpaceBeforeProtocolList: 
true\nPackConstructorInitializers: NextLine\nPenaltyBreakAssignment: 2\nPenaltyBreakBeforeFirstCallParameter: 1\nPenaltyBreakComment: 300\nPenaltyBreakFirstLessLess: 120\nPenaltyBreakOpenParenthesis: 0\nPenaltyBreakString: 1000\nPenaltyBreakTemplateDeclaration: 10\nPenaltyExcessCharacter: 1000000\nPenaltyIndentedWhitespace: 0\nPenaltyReturnTypeOnItsOwnLine: 200\nPointerAlignment: Left\nPPIndentWidth:   -1\nQualifierAlignment: Leave\nRawStringFormats:\n  - Language:        Cpp\n    Delimiters:\n      - cc\n      - CC\n      - cpp\n      - Cpp\n      - CPP\n      - 'c++'\n      - 'C++'\n    CanonicalDelimiter: ''\n    BasedOnStyle:    google\n  - Language:        TextProto\n    Delimiters:\n      - pb\n      - PB\n      - proto\n      - PROTO\n    EnclosingFunctions:\n      - EqualsProto\n      - EquivToProto\n      - PARSE_PARTIAL_TEXT_PROTO\n      - PARSE_TEST_PROTO\n      - PARSE_TEXT_PROTO\n      - ParseTextOrDie\n      - ParseTextProtoOrDie\n      - ParseTestProto\n      - ParsePartialTestProto\n    CanonicalDelimiter: pb\n    BasedOnStyle:    google\nReferenceAlignment: Pointer\nReflowComments:  true\nRemoveBracesLLVM: false\nRequiresClausePosition: OwnLine\nSeparateDefinitionBlocks: Leave\nShortNamespaceLines: 1\nSortIncludes:    CaseSensitive\nSortJavaStaticImport: Before\nSpaceAfterCStyleCast: false\nSpaceAfterLogicalNot: false\nSpaceAfterTemplateKeyword: true\nSpaceAroundPointerQualifiers: Default\nSpaceBeforeAssignmentOperators: true\nSpaceBeforeCaseColon: false\nSpaceBeforeCpp11BracedList: false\nSpaceBeforeCtorInitializerColon: true\nSpaceBeforeInheritanceColon: true\nSpaceBeforeParens: ControlStatements\nSpaceBeforeParensOptions:\n  AfterControlStatements: true\n  AfterForeachMacros: true\n  AfterFunctionDefinitionName: false\n  AfterFunctionDeclarationName: false\n  AfterIfMacros:   true\n  AfterOverloadedOperator: false\n  AfterRequiresInClause: false\n  AfterRequiresInExpression: false\n  BeforeNonEmptyParentheses: 
false\nSpaceBeforeRangeBasedForLoopColon: true\nSpaceBeforeSquareBrackets: false\nSpaceInEmptyBlock: false\nSpacesBeforeTrailingComments: 2\nSpacesInAngles:  Never\nSpacesInContainerLiterals: true\nSpacesInLineCommentPrefix:\n  Minimum:         1\n  Maximum:         -1\nSpacesInSquareBrackets: false\nStandard:        Auto\nStatementAttributeLikeMacros:\n  - Q_EMIT\nStatementMacros:\n  - Q_UNUSED\n  - QT_REQUIRE_VERSION\nTabWidth:        8\nUseTab:          Never\nWhitespaceSensitiveMacros:\n  - BOOST_PP_STRINGIZE\n  - CF_SWIFT_NAME\n  - NS_SWIFT_NAME\n  - PP_STRINGIZE\n  - STRINGIZE\n...\n\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/--work-item--dev-only-.md",
    "content": "---\nname: \"\\U0001F528Work Item (DEV ONLY)\"\nabout: Work item issue for tracking progress. Dev team only.\ntitle: ''\nlabels: Work Item\nassignees: ''\n\n---\n\n## 🔨Work Item\n\n**IMPORTANT:**\n* This template is only for the dev team to track project progress. For feature requests or bug reports, please use the corresponding issue templates.\n* DO NOT create a new work item if the purpose is to fix an existing issue or feature request. We will directly use the issue in the project tracker.\n\nProject tracker: https://github.com/orgs/dmlc/projects/2\n\n## Description\n\n<!-- short description of the work item -->\n\n## Depending work items or issues\n\n<!-- what must be done before this -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug-report.md",
    "content": "---\nname: \"\\U0001F41B Bug Report\"\nabout: Submit a bug report to help us improve DGL\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n## 🐛 Bug\n\n<!-- A clear and concise description of what the bug is. -->\n\n## To Reproduce\n\nSteps to reproduce the behavior:\n\n1.\n1.\n1.\n\n<!-- If you have a code sample, error messages, or stack traces, please provide them here as well. -->\n\n## Expected behavior\n\n<!-- A clear and concise description of what you expected to happen. -->\n\n## Environment\n\n - DGL Version (e.g., 1.0):\n - Backend Library & Version (e.g., PyTorch 0.4.1, MXNet/Gluon 1.3):\n - OS (e.g., Linux):\n - How you installed DGL (`conda`, `pip`, source):\n - Build command you used (if compiling from source):\n - Python version:\n - CUDA/cuDNN version (if applicable):\n - GPU models and configuration (e.g. V100):\n - Any other relevant information:\n\n## Additional context\n\n<!-- Add any other context about the problem here. -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/documentation.md",
    "content": "---\nname: \"\\U0001F4DA Documentation\"\nabout: Report an issue related to docs.dgl.ai\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n## 📚 Documentation\n\n<!-- Please specify whether the issue is in the tutorials or in the API reference. -->\n<!-- Describe the issue. -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature-request.md",
    "content": "---\nname: \"\\U0001F680Feature Request\"\nabout: Submit a proposal/request for a new DGL feature\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n## 🚀 Feature\n<!-- A brief description of the feature proposal -->\n\n## Motivation\n\n<!-- Please outline the motivation for the proposal. Is your feature request\nrelated to a problem? e.g., I'm always frustrated when [...]. If this is\nrelated to another GitHub issue, please link here too -->\n\n## Alternatives\n\n<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->\n\n## Pitch\n\n<!-- A clear and concise description of what you want to happen. -->\n\n## Additional context\n\n<!-- Add any other context or screenshots about the feature request here. -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/questions-help-support.md",
    "content": "---\nname: \"❓Questions/Help/Support\"\nabout: Do you need support? We have resources.\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n## ❓ Questions and Help\n\nBefore proceeding, please note that we recommend\nusing our discussion forum (https://discuss.dgl.ai) for\ngeneral questions. As a result, this issue will\nlikely be CLOSED shortly.\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "## Description\r\n<!-- Brief description. Refer to the related issues if any exist.\r\nIt'll be great if relevant reviewers can be assigned as well.-->\r\n\r\n## Checklist\r\nPlease feel free to remove inapplicable items for your PR.\r\n- [ ] The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature])\r\n- [ ] I've leveraged the [tools](https://docs.google.com/document/d/1iHyj7zlmygKSk5gBPsqIqL5ASPzJSPREaNT_QdsiYA4/edit) to beautify the Python and C++ code.\r\n- [ ] The PR is complete and small; read the [Google eng practice (CL equals to PR)](https://google.github.io/eng-practices/review/developer/small-cls.html) to understand more about small PRs. In DGL, we consider PRs with fewer than 200 lines of core code changes to be small (examples, tests and documentation could be exempted).\r\n- [ ] All changes have test coverage\r\n- [ ] Code is well-documented\r\n- [ ] To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change\r\n- [ ] Related issue is referenced in this PR\r\n- [ ] If the PR is for a new model/paper, I've updated the example index [here](../examples/README.md).\r\n\r\n## Changes\r\n<!-- You could use the following template\r\n- [ ] Feature1, tests, (and when applicable, API doc)\r\n- [ ] Feature2, tests, (and when applicable, API doc)\r\n-->\r\n"
  },
  {
    "path": ".github/workflows/lint.yml",
    "content": "name: Lint\n\non: [pull_request]\n\njobs:\n  lintrunner:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Pull DGL\n        uses: actions/checkout@v3\n        with:\n          fetch-depth: 0\n\n      - name: Checkout master and HEAD\n        run: |\n          git checkout -t origin/master\n          git checkout ${{ github.event.pull_request.head.sha }}\n\n      - name: Setup Python\n        uses: actions/setup-python@v4\n        with:\n          python-version: '3.8'\n\n      - name: Install requirements\n        run: |\n          python -m pip install --upgrade pip\n          pip install lintrunner --user\n\n      - name: Initialize lint dependencies\n        run: lintrunner init\n\n      - name: Run lintrunner on all changed files\n        run: |\n          set +e\n          if ! lintrunner --force-color -m master --tee-json=lint.json; then\n              echo \"\"\n              echo -e \"\\e[1m\\e[36mYou can reproduce these results locally by using \\`lintrunner\\`.\\e[0m\"\n              echo -e \"\\e[1m\\e[36mSee https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions.\\e[0m\"\n              exit 1\n          fi\n\n      - name: Store annotations\n        if: always() && github.event_name == 'pull_request'\n        # Don't show this as an error; the above step will have already failed.\n        continue-on-error: true\n        run: |\n          # Use jq to massage the JSON lint output into GitHub Actions workflow commands.\n          jq --raw-output \\\n            '\"::\\(if .severity == \"advice\" or .severity == \"disabled\" then \"warning\" else .severity end) file=\\(.path),line=\\(.line),col=\\(.char),title=\\(.code) \\(.name)::\" + (.description | gsub(\"\\\\n\"; \"%0A\"))' \\\n            lint.json\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}\n  cancel-in-progress: true\n"
  },
  {
    "path": ".github/workflows/stale.yml",
    "content": "# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.\n#\n# You can adjust the behavior by modifying this file.\n# For more information, see:\n# https://github.com/actions/stale\nname: Mark stale issues and pull requests\n\non:\n  schedule:\n  - cron: '0 1 * * *'\n\njobs:\n  stale:\n\n    runs-on: ubuntu-latest\n    permissions:\n      issues: write\n      pull-requests: write\n\n    steps:\n    - uses: actions/stale@v4.1.0\n      with:\n        repo-token: ${{ secrets.GITHUB_TOKEN }}\n        days-before-issue-stale: 30\n        days-before-issue-close: -1 # disable issue close\n        days-before-pr-stale: -1 # disable stale bot on pr\n        days-before-pr-close: -1 # disable stale bot on pr\n        stale-issue-message: 'This issue has been automatically marked as stale due to lack of activity. It will be closed if no further activity occurs. Thank you'\n        close-issue-message: 'This issue is closed due to lack of activity. Feel free to reopen it if you still have questions.'\n        stale-issue-label: 'stale-issue'\n        exempt-issue-labels: 'bug:confirmed,feature request,help wanted,Work Item'\n        exempt-all-issue-milestones: true\n"
  },
  {
    "path": ".gitignore",
    "content": "# IDE\n.idea\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndataset/\ndatasets/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# Whitelist some distribution / package non-related directories\n!tests/dist\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# 
mypy\n.mypy_cache/\n\nexamples/pytorch/data/ind.pubmed.y\nexamples/pytorch/data/ind.pubmed.x\nexamples/pytorch/data/ind.pubmed.ty\nexamples/pytorch/data/ind.pubmed.tx\nexamples/pytorch/data/ind.pubmed.test.index\nexamples/pytorch/data/ind.pubmed.graph\nexamples/pytorch/data/ind.pubmed.ally\nexamples/pytorch/data/ind.pubmed.allx\nexamples/pytorch/data/ind.cora.y\nexamples/pytorch/data/ind.cora.x\nexamples/pytorch/data/ind.cora.ty\nexamples/pytorch/data/ind.cora.tx\nexamples/pytorch/data/ind.cora.test.index\nexamples/pytorch/data/ind.cora.graph\nexamples/pytorch/data/ind.cora.ally\nexamples/pytorch/data/ind.cora.allx\nexamples/pytorch/data/ind.citeseer.y\nexamples/pytorch/data/ind.citeseer.x\nexamples/pytorch/data/ind.citeseer.ty\nexamples/pytorch/data/ind.citeseer.tx\nexamples/pytorch/data/ind.citeseer.test.index\nexamples/pytorch/data/ind.citeseer.graph\nexamples/pytorch/data/ind.citeseer.ally\nexamples/pytorch/data/ind.citeseer.allx\nexamples/pytorch/.DS_Store\nexamples/.DS_Store\nexamples/pytorch/generative_graph/*.p\n.DS_Store\n\n# data directory\n_download\n\n# CTags & CScope\ntags\ncscope.*\n\n# Vim\n*.swp\n*.swo\n*.un~\n*~\n\n# parameters\n*.params\n\n# vscode\n.clangd\n.vscode\n\n# asv\n.asv\n\n.ycm_extra_conf.py\n**.png\n\n# model file\n*.pth\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"third_party/dmlc-core\"]\n\tpath = third_party/dmlc-core\n\turl = https://github.com/dmlc/dmlc-core.git\n[submodule \"third_party/dlpack\"]\n\tpath = third_party/dlpack\n\turl = https://github.com/dmlc/dlpack.git\n[submodule \"third_party/googletest\"]\n\tpath = third_party/googletest\n\turl = https://github.com/google/googletest.git\n[submodule \"third_party/METIS\"]\n\tpath = third_party/METIS\n\turl = https://github.com/KarypisLab/METIS.git\n[submodule \"third_party/nanoflann\"]\n\tpath = third_party/nanoflann\n\turl = https://github.com/jlblancoc/nanoflann\n[submodule \"third_party/libxsmm\"]\n\tpath = third_party/libxsmm\n\turl = https://github.com/hfp/libxsmm.git\n[submodule \"third_party/pcg\"]\n\tpath = third_party/pcg\n\turl = https://github.com/imneme/pcg-cpp.git\n[submodule \"third_party/cccl\"]\n\tpath = third_party/cccl\n\turl = https://github.com/NVIDIA/cccl.git\n[submodule \"third_party/liburing\"]\n\tpath = third_party/liburing\n\turl = https://github.com/axboe/liburing.git\n[submodule \"third_party/cuco\"]\n\tpath = third_party/cuco\n\turl = https://github.com/NVIDIA/cuCollections.git\n[submodule \"third_party/GKlib\"]\n\tpath = third_party/GKlib\n\turl = https://github.com/KarypisLab/GKlib.git\n[submodule \"third_party/taskflow\"]\n\tpath = third_party/taskflow\n\turl = https://github.com/taskflow/taskflow.git\n[submodule \"third_party/tsl_robin_map\"]\n\tpath = third_party/tsl_robin_map\n\turl = https://github.com/Tessil/robin-map.git\n"
  },
  {
    "path": ".lintrunner.toml",
    "content": "# Black + usort\n[[linter]]\ncode = 'UFMT'\ninclude_patterns = [\n    '**/*.py',\n]\ncommand = [\n    'python3',\n    'tests/lint/ufmt_linter.py',\n    '--',\n    '@{{PATHSFILE}}'\n]\nexclude_patterns = [\n    '.github/*',\n    'build/*',\n    'cmake/*',\n    'conda/*',\n    'docker/*',\n    'third_party/*',\n]\ninit_command = [\n    'python3',\n    'tests/lint/pip_init.py',\n    '--dry-run={{DRYRUN}}',\n    'black==22.10.0',\n    'ufmt==2.0.1',\n    'usort==1.0.5',\n]\nis_formatter = true\n\n[[linter]]\ncode = 'CLANGFORMAT'\ninclude_patterns = [\n    '**/*.h',\n    '**/*.c',\n    '**/*.cc',\n    '**/*.cpp',\n    '**/*.cuh',\n    '**/*.cu',\n]\nexclude_patterns = [\n    'third_party/**',\n]\ninit_command = [\n    'python3',\n    'tests/lint/pip_init.py',\n    '--dry-run={{DRYRUN}}',\n    'clang-format==15.0.4',\n]\ncommand = [\n    'python3',\n    'tests/lint/clangformat_linter.py',\n    '--binary=clang-format',\n    '--',\n    '@{{PATHSFILE}}'\n]\nis_formatter = true\n"
  },
  {
    "path": "CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.18)\n########################################\n# Borrowed and adapted from TVM project\n########################################\nproject(dgl C CXX)\nmessage(STATUS \"Start configuring project ${PROJECT_NAME}\")\n\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CXX_STANDARD_REQUIRED ON)\n\n# cmake utils\ninclude(cmake/util/Util.cmake)\ninclude(cmake/util/MshadowUtil.cmake)\ninclude(cmake/util/FindCUDA.cmake)\n\n# Options for building DGL.\n# NOTE: Please avoid editing this file to change build type. Instead, using\n# bash script/build_dgl.sh -e -t release to overwrite the value.\ndgl_option(BUILD_TYPE \"Type of the build: dev, dogfood or release\" \"dev\")\nmessage(STATUS \"Build for ${BUILD_TYPE}\")\n\ndgl_option(USE_CUDA \"Build with CUDA\" OFF)\ndgl_option(TORCH_PYTHON_INTERPS \"Python interpreter for building sub-components\" python3)\n\n# Conda build related options.\ndgl_option(EXTERNAL_DLPACK_PATH \"Path to external dlpack\" OFF)\ndgl_option(EXTERNAL_DMLC_PATH \"Path to external dmlc-core\" OFF)\ndgl_option(EXTERNAL_DMLC_LIB_PATH \"Path to external dmlc-core library\" OFF)\ndgl_option(EXTERNAL_PHMAP_PATH \"Path to external parallel-hashmap\" OFF)\ndgl_option(EXTERNAL_NANOFLANN_PATH \"Path to use external nanoflann\" OFF)\ndgl_option(EXTERNAL_METIS_PATH \"Path to external metis\" OFF)\ndgl_option(EXTERNAL_METIS_LIB_PATH \"Path to external metis library\" OFF)\ndgl_option(EXTERNAL_GKLIB_PATH \"Path to external gklib\" OFF)\n\n# Options for building DGL features: \"none,\" \"dev,\" \"dogfood,\" \"release,\" and\n# \"all.\"\n#    \"none\"  - The feature is OFF for all build types. This is used when\n#              disabling a feature.\n#    \"dev\"   - The feature is ON for dev build. The default build from source\n#              and the build for unit tests are using this build type.\n#  \"dogfood\" - The major function of this feature is done. 
The regression and\n#              benchmark framework are using this build type.\n#  \"release\" - The feature will be build for release.\n#    \"all\"   - The feature is ON for all build types. Equivalent to set [\"dev\"\n#              \"dogfood\" \"release\"].\n# NOTE: Please avoid editing this file to change feature options for a local\n# build. Instead, using bash script/build_dgl.sh -e '-DFEATURE_NAME=ON/OFF' to\n# overwrite the value.\ndgl_feature_option(\n    BUILD_SPARSE\n    \"Build DGL sparse library\"\n    \"all\"\n)\ndgl_feature_option(\n    BUILD_TORCH\n    \"Build the PyTorch plugin\"\n    \"all\"\n)\ndgl_feature_option(\n    USE_EPOLL\n    \"Build with epoll for socket communicator\"\n    \"all\"\n)\ndgl_feature_option(\n    USE_LIBXSMM\n    \"Build with LIBXSMM library optimization\"\n    \"all\"\n)\ndgl_feature_option(\n    USE_OPENMP\n    \"Build with OpenMP\"\n    \"all\"\n)\n\ndgl_feature_option(\n    BUILD_GRAPHBOLT\n    \"Build Graphbolt library\"\n    \"all\"\n)\n\ndgl_feature_option(\n    LIBCXX_ENABLE_PARALLEL_ALGORITHMS\n    \"Enable the parallel algorithms library. 
This requires the PSTL to be available.\"\n    \"none\"\n)\ndgl_feature_option(\n    REBUILD_LIBXSMM\n    \"Clean LIBXSMM build cache at every build\"\n    \"none\"\n)\ndgl_feature_option(\n    USE_HDFS\n    \"Build with HDFS support\"\n    \"none\"\n) # Set env HADOOP_HDFS_HOME if needed\ndgl_feature_option(\n    USE_S3\n    \"Build with S3 support\"\n    \"none\"\n)\n\n# Only build C++ tests for unit testing purposes in dev build.\ndgl_feature_option(\n    BUILD_CPP_TEST\n    \"Build cpp unittest executables\"\n    \"dev\"\n)\n\nif (EXTERNAL_DLPACK_PATH OR EXTERNAL_DMLC_PATH OR EXTERNAL_NANOFLANN_PATH OR EXTERNAL_NANOFLANN_PATH OR EXTERNAL_METIS_PATH OR EXTERNAL_GKLIB_PATH)\n  message(STATUS \"Using at least one external library\")\n  set(USE_EXTERNAL_LIBS ON)\n  \n  if (BUILD_CPP_TEST)\n    message(FATAL_ERROR \"Cannot build cpp unittests with external libraries\")\n  endif(BUILD_CPP_TEST)\nendif()\n\nset(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules)\n\n# Set optimization options for different build types.\nif (${BUILD_TYPE} STREQUAL \"dev\")\n  if (MSVC)\n    set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} /Od\")\n    set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} /Od\")\n  else()\n    set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -O0 -g3 -ggdb\")\n    set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -O0 -g3 -ggdb\")\n  endif()\nelse()\n  if (MSVC)\n    set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} /O2 /DNDEBUG\")\n    set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} /O2 /DNDEBUG\")\n  else()\n    set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -O2 -DNDEBUG\")\n    set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -O2 -DNDEBUG\")\n  endif()\nendif()\n\nif(USE_CUDA)\n  message(STATUS \"Build with CUDA support\")\n  project(dgl C CXX)\n  include(cmake/modules/CUDA.cmake)\n  message(STATUS \"Use external CCCL library for a consistent API and performance.\")\n  cuda_include_directories(BEFORE \"${CMAKE_SOURCE_DIR}/third_party/cccl/thrust\")\n  cuda_include_directories(BEFORE 
\"${CMAKE_SOURCE_DIR}/third_party/cccl/cub\")\n  cuda_include_directories(BEFORE \"${CMAKE_SOURCE_DIR}/third_party/cccl/libcudacxx/include\")\nendif(USE_CUDA)\n\n# initial variables\nif(NOT MSVC)\nset(DGL_LINKER_LIBS \"dl\")\nendif(NOT MSVC)\n\nif(MSVC OR CMAKE_SYSTEM_NAME STREQUAL \"Darwin\")\nset(DGL_RUNTIME_LINKER_LIBS \"\")\nelse(MSVC OR CMAKE_SYSTEM_NAME STREQUAL \"Darwin\")\nset(DGL_RUNTIME_LINKER_LIBS \"rt\")\nendif(MSVC OR CMAKE_SYSTEM_NAME STREQUAL \"Darwin\")\n\n# Generic compilation options\nif(MSVC)\n  add_definitions(-DWIN32_LEAN_AND_MEAN)\n  add_definitions(-D_CRT_SECURE_NO_WARNINGS)\n  add_definitions(-D_SCL_SECURE_NO_WARNINGS)\n  add_definitions(-DNOMINMAX)\n  set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS 1)\n  set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} /EHsc\")\n  set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} /MP\")\n  set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} /bigobj\")\n  if(USE_MSVC_MT)\n    foreach(flag_var\n        CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE\n        CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)\n      if(${flag_var} MATCHES \"/MD\")\n        string(REGEX REPLACE \"/MD\" \"/MT\" ${flag_var} \"${${flag_var}}\")\n      endif(${flag_var} MATCHES \"/MD\")\n    endforeach(flag_var)\n  endif()\nelse(MSVC)\n  include(CheckCXXCompilerFlag)\n  set(CMAKE_C_FLAGS \"-Wall -fPIC ${CMAKE_C_FLAGS}\")\n  set(CMAKE_CXX_FLAGS \"-Wall -fPIC ${CMAKE_CXX_FLAGS}\")\n  if(NOT APPLE)\n    set(CMAKE_SHARED_LINKER_FLAGS \"-Wl,--warn-common ${CMAKE_SHARED_LINKER_FLAGS}\")\n  endif(NOT APPLE)\nendif(MSVC)\n\nif(NOT CMAKE_SYSTEM_PROCESSOR MATCHES \"(x86)|(X86)|(amd64)|(AMD64)\")\n  message(STATUS \"Disabling LIBXSMM on ${CMAKE_SYSTEM_PROCESSOR}.\")\n  set(USE_LIBXSMM OFF)\nendif()\n\n# Source file lists\nfile(GLOB DGL_SRC\n  src/*.cc\n  src/array/*.cc\n  src/array/cpu/*.cc\n  src/random/*.cc\n  src/random/cpu/*.cc\n  src/runtime/*.cc\n  src/geometry/*.cc\n  src/geometry/cpu/*.cc\n  src/partition/*.cc\n)\n\nfile(GLOB_RECURSE DGL_SRC_1\n  
src/api/*.cc\n  src/graph/*.cc\n  src/scheduler/*.cc\n)\n\nlist(APPEND DGL_SRC ${DGL_SRC_1})\n\nif (NOT MSVC)\n  file(GLOB_RECURSE DGL_RPC_SRC src/rpc/*.cc)\nelse()\n  file(GLOB_RECURSE DGL_RPC_SRC src/rpc/network/*.cc)\nendif()\nlist(APPEND DGL_SRC ${DGL_RPC_SRC})\n\nif(USE_OPENMP)\n  find_package(OpenMP REQUIRED)\n  list(APPEND DGL_LINKER_LIBS OpenMP::OpenMP_CXX)\n  message(STATUS \"Build with OpenMP.\")\nendif(USE_OPENMP)\n\n# Configure cuda\nif(USE_CUDA)\n  file(GLOB_RECURSE DGL_CUDA_SRC\n    src/array/cuda/*.cc\n    src/array/cuda/*.cu\n    src/array/cuda/uvm/*.cc\n    src/array/cuda/uvm/*.cu\n    src/kernel/cuda/*.cc\n    src/kernel/cuda/*.cu\n    src/partition/cuda/*.cu\n    src/runtime/cuda/*.cc\n    src/runtime/cuda/*.cu\n    src/geometry/cuda/*.cu\n    src/graph/transform/cuda/*.cu\n    src/graph/sampling/randomwalks/*.cu\n  )\n  list(APPEND DGL_SRC ${DGL_CUDA_SRC})\n  dgl_config_cuda(DGL_LINKER_LIBS)\n  cuda_add_library(dgl SHARED ${DGL_SRC})\nelse(USE_CUDA)\n  add_library(dgl SHARED ${DGL_SRC})\nendif(USE_CUDA)\n\nif ((NOT MSVC) AND USE_EPOLL)\n  INCLUDE(CheckIncludeFile)\n  check_include_file(\"sys/epoll.h\" EPOLL_AVAILABLE)\n  if (EPOLL_AVAILABLE)\n    target_compile_definitions(dgl PRIVATE USE_EPOLL)\n  else()\n    message(WARNING \"EPOLL is not available on this platform...\")\n  endif()\nendif ()\n\n# include directories\ntarget_include_directories(dgl PRIVATE \"include\")\n# check for conda includes\nif(\"$ENV{CONDA_BUILD}\" STREQUAL \"1\")\n  set(in_conda_build TRUE)\n  message(STATUS \"Conda build environment detected\")\nelseif(DEFINED ENV{CONDA_PREFIX})\n  set(in_conda_prefix TRUE)\n  message(STATUS \"Conda environment detected: $ENV{CONDA_PREFIX}\")\nendif()\n\nif (USE_CONDA_INCLUDES)\n  if(in_conda_build)\n    message(STATUS \"Using Conda build environment includes: $ENV{PREFIX}\")\n    target_include_directories(dgl PRIVATE \"$ENV{PREFIX}/include\" \"$ENV{BUILD_PREFIX}/include\")\n  elseif(in_conda_prefix)\n    message(STATUS \"Using Conda 
environment includes: $ENV{CONDA_PREFIX}\")\n    target_include_directories(dgl PRIVATE \"$ENV{CONDA_PREFIX}/include\")\n  else()\n    message(FATAL_ERROR \"Conda environment not detected\")\n  endif()\nendif()\n\nif(EXTERNAL_DLPACK_PATH)\n  message(STATUS \"looking for dlpack headers in ${EXTERNAL_DLPACK_PATH}\")\n  include_directories(SYSTEM ${EXTERNAL_DLPACK_PATH})\nelse(EXTERNAL_DLPACK_PATH)\n  target_include_directories(dgl PRIVATE \"third_party/dlpack/include\")\nendif(EXTERNAL_DLPACK_PATH)\n\nif(EXTERNAL_DMLC_PATH)\n  if (USE_HDFS)\n    message(FATAL_ERROR \"Cannot use HDFS and external dmlc-core at the same time\")\n  endif()\n  message(STATUS \"looking for dmlc headers in ${EXTERNAL_DMLC_PATH}\")\n  include_directories(SYSTEM ${EXTERNAL_DMLC_PATH})\n  \n  if (NOT EXTERNAL_DMLC_LIB_PATH)\n    message(FATAL_ERROR \"EXTERNAL_DMLC_LIB_PATH must be set if EXTERNAL_DMLC_PATH is set\")\n  endif()\n  message(STATUS \"looking for dmlc library in ${EXTERNAL_DMLC_LIB_PATH}\")\n  find_package(dmlc\n    REQUIRED\n    HINTS ${EXTERNAL_DMLC_LIB_PATH}\n  )\n  if(NOT dmlc_FOUND)\n      message(FATAL_ERROR \"Failed to find DMLC library\")\n  endif()\n  list(APPEND DGL_LINKER_LIBS dmlc::dmlc)\n\nelse(EXTERNAL_DMLC_PATH)\n  target_include_directories(dgl PRIVATE \"third_party/dmlc-core/include\")\n  # For serialization\n  if (USE_HDFS)\n    option(DMLC_HDFS_SHARED \"dgl has to build with dynamic hdfs library\" ON)\n  endif()\n  add_subdirectory(\"third_party/dmlc-core\")\n  list(APPEND DGL_LINKER_LIBS dmlc)\n  set(GOOGLE_TEST 0) # Turn off dmlc-core test\nendif(EXTERNAL_DMLC_PATH)\n\ntarget_include_directories(dgl PRIVATE \"tensoradapter/include\")\ntarget_include_directories(dgl PRIVATE \"third_party/pcg/include\")\ntarget_include_directories(dgl PRIVATE \"third_party/tsl_robin_map/include\")\n\nif(EXTERNAL_NANOFLANN_PATH)\n  include_directories(SYSTEM ${EXTERNAL_NANOFLANN_PATH})\nelse(EXTERNAL_NANOFLANN_PATH)\n  target_include_directories(dgl PRIVATE 
\"third_party/nanoflann/include\")\nendif(EXTERNAL_NANOFLANN_PATH)\n\nif (USE_LIBXSMM)\n  target_compile_definitions(dgl PRIVATE USE_LIBXSMM DGL_CPU_LLC_SIZE=40000000 __BLAS=0)\n  target_include_directories(dgl PRIVATE \"third_party/libxsmm/include\")\n  message(STATUS \"Build with LIBXSMM optimization.\")\nendif()\n\n# To compile METIS correct for DGL.\nadd_compile_definitions(IDXTYPEWIDTH=64 REALTYPEWIDTH=32)\nif (EXTERNAL_METIS_PATH)\n  # To compile METIS correct for DGL.\n  if(MSVC)\n    set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} /DIDXTYPEWIDTH=64 /DREALTYPEWIDTH=32\")\n    set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} /DIDXTYPEWIDTH=64 /DREALTYPEWIDTH=32\")\n  else(MSVC)\n    set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -DIDXTYPEWIDTH=64 -DREALTYPEWIDTH=32\")\n    set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -DIDXTYPEWIDTH=64 -DREALTYPEWIDTH=32\")\n  endif(MSVC)\n  find_package(METIS REQUIRED)\n  message(STATUS \"Found METIS library\")\n  target_include_directories(dgl SYSTEM PUBLIC ${METIS_INCLUDE_DIR})\n  list(APPEND DGL_LINKER_LIBS ${METIS_LIBRARIES})\nelse(EXTERNAL_METIS_PATH)\n  target_include_directories(dgl PRIVATE \"third_party/METIS/include\")\n  # Compile METIS\n  if(NOT MSVC)\n    set(GKLIB_PATH \"${CMAKE_CURRENT_SOURCE_DIR}/third_party/GKlib\")\n    include(${GKLIB_PATH}/GKlibSystem.cmake)\n    include_directories(${GKLIB_PATH})\n    add_library(GKlib ${GKlib_sources})\n    include_directories(\"third_party/METIS/include/\")\n    add_subdirectory(\"third_party/METIS/libmetis/\")\n    # When building on ubi7, it fails with the following error:\n    # /usr/include/signal.h:156:29: error: unknown type name 'siginfo_t'.\n    # So I(Rui) define _POSIX_C_SOURCE to 200809L for GKlib and metis to avoid the error.\n    target_compile_definitions(GKlib PRIVATE _POSIX_C_SOURCE=200809L)\n    target_compile_definitions(metis PRIVATE _POSIX_C_SOURCE=200809L)\n    list(APPEND DGL_LINKER_LIBS metis GKlib)\n  endif(NOT MSVC)\nendif(EXTERNAL_METIS_PATH)\n\n\n# Avoid exposing third-party 
symbols when using DGL as a library.\nif((NOT MSVC) AND (NOT ${CMAKE_SYSTEM_NAME} MATCHES \"Darwin\"))\n  set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -Wl,--exclude-libs,ALL\")\n  set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wl,--exclude-libs,ALL\")\nendif()\n\n# Compile gpu_cache\nif(USE_CUDA)\n  # Manually build gpu_cache because CMake always builds it as shared\n  file(GLOB gpu_cache_src\n    third_party/HugeCTR/gpu_cache/src/nv_gpu_cache.cu\n  )\n  cuda_add_library(gpu_cache STATIC ${gpu_cache_src})\n  target_include_directories(gpu_cache PRIVATE \"third_party/HugeCTR/gpu_cache/include\")\n  target_include_directories(dgl PRIVATE \"third_party/HugeCTR/gpu_cache/include\")\n  list(APPEND DGL_LINKER_LIBS gpu_cache)\n  message(STATUS \"Build with HugeCTR GPU embedding cache.\")\nendif(USE_CUDA)\n\n# support PARALLEL_ALGORITHMS\nif (LIBCXX_ENABLE_PARALLEL_ALGORITHMS)\n  target_compile_definitions(dgl PRIVATE PARALLEL_ALGORITHMS)\nendif(LIBCXX_ENABLE_PARALLEL_ALGORITHMS)\n\ntarget_link_libraries(dgl ${DGL_LINKER_LIBS} ${DGL_RUNTIME_LINKER_LIBS})\nif(MSVC)\n  add_custom_command(\n    TARGET dgl POST_BUILD COMMAND\n    ${CMAKE_COMMAND} -E copy \"$<TARGET_FILE:dgl>\" \"$<TARGET_FILE_DIR:dgl>/..\")\nendif(MSVC)\n\n# Tensor adapter libraries\n# Linking against LibTorch involves linking against a bunch of other libraries\n# returned by PyTorch's CMake (e.g. C10 or NVTools).  Because CMake caches\n# the found libraries in find_library(), often times CMake will look into the libraries\n# of the wrong version when I build everything in the same CMake process.  
As\n# a result, I (BarclayII) am launching an individual CMake build for every PyTorch version.\nif(BUILD_TORCH)\n  file(TO_NATIVE_PATH ${CMAKE_CURRENT_BINARY_DIR} BINDIR)\n  file(TO_NATIVE_PATH ${CMAKE_COMMAND} CMAKE_CMD)\n  if(MSVC)\n    file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensoradapter/pytorch/build.bat BUILD_SCRIPT)\n    add_custom_target(\n      tensoradapter_pytorch\n      ${CMAKE_COMMAND} -E env\n      CMAKE_COMMAND=${CMAKE_CMD}\n      CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}\n      USE_CUDA=${USE_CUDA}\n      EXTERNAL_DMLC_LIB_PATH=${EXTERNAL_DMLC_LIB_PATH}\n      BINDIR=${BINDIR}\n      cmd /e:on /c ${BUILD_SCRIPT} ${TORCH_PYTHON_INTERPS}\n      DEPENDS ${BUILD_SCRIPT}\n      WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/tensoradapter/pytorch)\n  else(MSVC)\n    file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensoradapter/pytorch/build.sh BUILD_SCRIPT)\n    add_custom_target(\n      tensoradapter_pytorch\n      ${CMAKE_COMMAND} -E env\n      CMAKE_COMMAND=${CMAKE_CMD}\n      CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}\n      USE_CUDA=${USE_CUDA}\n      EXTERNAL_DMLC_LIB_PATH=${EXTERNAL_DMLC_LIB_PATH}\n      BINDIR=${CMAKE_CURRENT_BINARY_DIR}\n      bash ${BUILD_SCRIPT} ${TORCH_PYTHON_INTERPS}\n      DEPENDS ${BUILD_SCRIPT}\n      WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/tensoradapter/pytorch)\n  endif(MSVC)\n  add_dependencies(dgl tensoradapter_pytorch)\nendif(BUILD_TORCH)\n\n# Installation rules\ninstall(TARGETS dgl DESTINATION lib${LIB_SUFFIX})\n\n# Testing\nif(BUILD_CPP_TEST)\n  message(STATUS \"Build with unittest\")\n  add_subdirectory(./third_party/googletest)\n  enable_testing()\n  include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})\n  include_directories(\"include\")\n  include_directories(\"third_party/dlpack/include\")\n  include_directories(\"third_party/dmlc-core/include\")\n  include_directories(\"third_party/tsl_robin_map/include\")\n  include_directories(\"third_party/libxsmm/include\")\n  
include_directories(\"third_party/pcg/include\")\n  file(GLOB_RECURSE TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/cpp/*.cc)\n  add_executable(runUnitTests ${TEST_SRC_FILES})\n  target_link_libraries(runUnitTests gtest gtest_main)\n  target_link_libraries(runUnitTests dgl)\n  add_test(UnitTests runUnitTests)\nendif(BUILD_CPP_TEST)\n\nif(BUILD_SPARSE)\n  message(STATUS \"Configuring DGL sparse library\")\n  file(TO_NATIVE_PATH ${CMAKE_CURRENT_BINARY_DIR} BINDIR)\n  file(TO_NATIVE_PATH ${CMAKE_COMMAND} CMAKE_CMD)\n  get_target_property(DGL_INCLUDE_DIRS dgl INCLUDE_DIRECTORIES)\n  message(STATUS \"DGL include directories: ${DGL_INCLUDE_DIRS}\")\n  message(STATUS \"DGL link directories: ${DGL_INCLUDE_DIRS}\")\n  if(MSVC)\n    file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/dgl_sparse/build.bat BUILD_SCRIPT)\n    add_custom_target(\n      dgl_sparse\n      ALL\n      ${CMAKE_COMMAND} -E env\n      CMAKE_COMMAND=${CMAKE_CMD}\n      CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}\n      USE_CUDA=${USE_CUDA}\n      BINDIR=${BINDIR}\n      INCLUDEDIR=\"${DGL_INCLUDE_DIRS}\"\n      CFLAGS=${CMAKE_C_FLAGS}\n      CXXFLAGS=${CMAKE_CXX_FLAGS}\n      LDFLAGS=${CMAKE_SHARED_LINKER_FLAGS}\n      cmd /e:on /c ${BUILD_SCRIPT} ${TORCH_PYTHON_INTERPS}\n      DEPENDS ${BUILD_SCRIPT}\n      WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/dgl_sparse)\n  else(MSVC)\n    file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/dgl_sparse/build.sh BUILD_SCRIPT)\n    add_custom_target(\n      dgl_sparse\n      ALL\n      ${CMAKE_COMMAND} -E env\n      CMAKE_COMMAND=${CMAKE_CMD}\n      CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}\n      USE_CUDA=${USE_CUDA}\n      BINDIR=${CMAKE_CURRENT_BINARY_DIR}\n      INCLUDEDIR=\"${DGL_INCLUDE_DIRS}\"\n      CFLAGS=${CMAKE_C_FLAGS}\n      CXXFLAGS=${CMAKE_CXX_FLAGS}\n      LDFLAGS=${CMAKE_SHARED_LINKER_FLAGS}\n      bash ${BUILD_SCRIPT} ${TORCH_PYTHON_INTERPS}\n      DEPENDS ${BUILD_SCRIPT}\n      WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/dgl_sparse)\n  endif(MSVC)\n  
add_dependencies(dgl_sparse dgl)\nendif(BUILD_SPARSE)\n\nif(BUILD_GRAPHBOLT)\n  message(STATUS \"Configuring graphbolt library\")\n  string(REPLACE \";\" \"\\\\;\" CUDA_ARCHITECTURES_ESCAPED \"${CUDA_ARCHITECTURES}\")\n  file(TO_NATIVE_PATH ${CMAKE_CURRENT_BINARY_DIR} BINDIR)\n  file(TO_NATIVE_PATH ${CMAKE_COMMAND} CMAKE_CMD)\n  if(MSVC)\n    file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/graphbolt/build.bat BUILD_SCRIPT)\n    add_custom_target(\n      graphbolt\n      ALL\n      ${CMAKE_COMMAND} -E env\n      CMAKE_COMMAND=${CMAKE_CMD}\n      CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}\n      USE_CUDA=${USE_CUDA}\n      BINDIR=${BINDIR}\n      CFLAGS=${CMAKE_C_FLAGS}\n      CXXFLAGS=${CMAKE_CXX_FLAGS}\n      CUDAARCHS=\"${CUDA_ARCHITECTURES_ESCAPED}\"\n      LDFLAGS=${CMAKE_SHARED_LINKER_FLAGS}\n      cmd /e:on /c ${BUILD_SCRIPT} ${TORCH_PYTHON_INTERPS}\n      DEPENDS ${BUILD_SCRIPT}\n      WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/graphbolt)\n  else(MSVC)\n    file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/graphbolt/build.sh BUILD_SCRIPT)\n    add_custom_target(\n      graphbolt\n      ALL\n      ${CMAKE_COMMAND} -E env\n      CMAKE_COMMAND=${CMAKE_CMD}\n      CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}\n      USE_CUDA=${USE_CUDA}\n      USE_LIBURING=${USE_LIBURING}\n      BINDIR=${CMAKE_CURRENT_BINARY_DIR}\n      CFLAGS=${CMAKE_C_FLAGS}\n      CXXFLAGS=${CMAKE_CXX_FLAGS}\n      CUDAARCHS=\"${CUDA_ARCHITECTURES_ESCAPED}\"\n      LDFLAGS=${CMAKE_SHARED_LINKER_FLAGS}\n      bash ${BUILD_SCRIPT} ${TORCH_PYTHON_INTERPS}\n      DEPENDS ${BUILD_SCRIPT}\n      WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/graphbolt)\n  endif(MSVC)\nendif(BUILD_GRAPHBOLT)\n"
  },
  {
    "path": "CONTRIBUTORS.md",
    "content": "## Contributing to DGL\r\n\r\nContributions are always welcome. A good starting place is the roadmap issue, where\r\nyou can find our current milestones. All contributions must go through pull requests\r\nand be reviewed by the committers. See our [contribution\r\nguide](https://docs.dgl.ai/contribute.html) for more details.\r\n\r\nOnce your contribution is accepted and merged, congratulations, you are now a\r\ncontributor to the DGL project.  We will put your name in the list below.\r\n\r\nContributors\r\n------------\r\n\r\n* [Minjie Wang](https://github.com/jermainewang) from AWS\r\n* [Da Zheng](https://github.com/zheng-da) from AWS\r\n* [Quan Gan](https://github.com/BarclayII) from AWS\r\n* [Mufei Li](https://github.com/mufeili) from AWS\r\n* [Jinjing Zhou](https://github.com/VoVAllen) from AWS\r\n* [Xiang Song](https://github.com/classicsong) from AWS\r\n* [Tianjun Xiao](https://github.com/sneakerkg) from AWS\r\n* [Tong He](https://github.com/hetong007) from AWS\r\n* [Jian Zhang](https://github.com/zhjwy9343) from AWS\r\n* [Qipeng Guo](https://github.com/QipengGuo) from AWS\r\n* [Xiangkun Hu](https://github.com/HuXiangkun) from AWS\r\n* [Ying Rui](https://github.com/Rhett-Ying) from AWS\r\n* [Israt Nisa](https://github.com/isratnisa) from AWS\r\n* [Zheng Zhang](https://github.com/zzhang-cn) from AWS\r\n* [Zihao Ye](https://github.com/yzh119) from University of Washington\r\n* [Chao Ma](https://github.com/aksnzhy)\r\n* [Qidong](https://github.com/soodoshll)\r\n* [Lingfan Yu](https://github.com/lingfanyu) from New York University\r\n* [Yu Gai](https://github.com/GaiYu0) from University of California, Berkeley\r\n* [Qi Huang]() from New York University\r\n* [Dominique LaSalle](https://github.com/nv-dlasalle) from Nvidia\r\n* [Pawel Piotrowicz](https://github.com/pawelpiotrowicz) from Intel\r\n* [Michal Szarmach](https://github.com/mszarma) from Intel\r\n* [Izabela Mazur](https://github.com/IzabelaMazur) from Intel\r\n* [Sanchit 
Misra](https://github.com/sanchit-misra) from Intel\r\n* [Andrzej Kotlowski](https://github.com/anko-intel) from Intel\r\n* [Sheng Zha](https://github.com/szha) from AWS\r\n* [Yifei Ma](https://github.com/yifeim) from AWS\r\n* [Yizhi Liu](https://github.com/yzhliu) from AWS\r\n* [Kay Liu](https://github.com/kayzliu) from UIC\r\n* [Tianqi Zhang](https://github.com/lygztq) from SJTU\r\n* [Hengrui Zhang](https://github.com/hengruizhang98)\r\n* [Seung Won Min](https://github.com/davidmin7) from UIUC\r\n* [@hbsun2113](https://github.com/hbsun2113): GraphSAGE in PyTorch\r\n* [Tianyi Zhang](https://github.com/Tiiiger): SGC in PyTorch\r\n* [Jun Chen](https://github.com/kitaev-chen): GIN in PyTorch\r\n* [Aymen Waheb](https://github.com/aymenwah): APPNP in PyTorch\r\n* [Chengqiang Lu](https://github.com/geekinglcq): MGCN, SchNet and MPNN in PyTorch\r\n* [Gongze Cao](https://github.com/Zardinality): Cluster GCN\r\n* [Yicheng Wu](https://github.com/MilkshakeForReal): RotatE in PyTorch\r\n* [Hao Xiong](https://github.com/ShawXh): DeepWalk in PyTorch\r\n* [Zhi Lin](https://github.com/kira-lin): Integrate FeatGraph into DGL\r\n* [Andrew Tsesis](https://github.com/noncomputable): Framework-Agnostic Graph Ops\r\n* [Brett Koonce](https://github.com/brettkoonce)\r\n* [@giuseppefutia](https://github.com/giuseppefutia)\r\n* [@mori97](https://github.com/mori97)\r\n* [@xnuohz](https://github.com/xnuohz)\r\n* [Hao Jin](https://github.com/haojin2) from Amazon\r\n* [Xin Yao](https://github.com/yaox12) from Nvidia\r\n* [Abdurrahman Yasar](https://github.com/ayasar70) from Nvidia\r\n* [Shaked Brody](https://github.com/shakedbr) from Technion\r\n* [Jiahui Liu](https://github.com/paoxiaode) from Nvidia\r\n* [Neil Dickson](https://github.com/ndickson-nvidia) from Nvidia\r\n* [Chang Liu](https://github.com/chang-l) from Nvidia\r\n* [Muhammed Fatih Balin](https://github.com/mfbalin) from Nvidia and Georgia Tech\r\n"
  },
  {
    "path": "Jenkinsfile",
    "content": "#!/usr/bin/env groovy\n\n// CI tests are executed within Docker containers as the 'root' user. However,\n// communications between Jenkins nodes are done with the 'ubuntu' user(login\n// via root is disallowed on AWS EC2 instances). Therefore, we need to change\n// the file permission to allow 'ubuntu' user to access the files created by\n// the 'root' user. This is achieved by running 'chmod -R 777 .'.\n\n// Summary of Jenkins nodes:\n// - linux-benchmark-node: Linux CPU node for authentication and lint check.\n//      number of nodes: 1\n//      instance type: m5.2xlarge(8 vCPUs, 32 GB memory)\n//      number of executors per node: 6\n//      number of jobs running on this node per CI run: 3\n// - dgl-ci-linux-cpu: Linux CPU node for building and testing.\n//      number of nodes: 4\n//      instance type: m6i.24xlarge(96 vCPUs, 384 GB memory)\n//      number of executors per node: 6\n//      number of jobs running on this node per CI run: 8\n// - dgl-ci-linux-gpu: Linux GPU node for building and testing.\n//      number of nodes: 4\n//      instance type: g4dn.4xlarge(16 vCPUs, 64 GB memory, 1 GPU)\n//      number of executors per node: 1\n//      number of jobs running on this node per CI run: 4\n// - dgl-ci-windows-cpu: Windows CPU node for building and testing.\n//      number of nodes: 4\n//      instance type: m6i.8xlarge(32 vCPUs, 128 GB memory)\n//      number of executors per node: 2\n//      number of jobs running on this node per CI run: 3\n\ndgl_linux_libs = 'build/libdgl.so, build/runUnitTests, python/dgl/_ffi/_cy3/core.cpython-*-x86_64-linux-gnu.so, build/tensoradapter/pytorch/*.so, build/dgl_sparse/*.so, build/graphbolt/*.so'\n// Currently DGL on Windows is not working with Cython yet\ndgl_win64_libs = \"build\\\\dgl.dll, build\\\\runUnitTests.exe, build\\\\tensoradapter\\\\pytorch\\\\*.dll, build\\\\dgl_sparse\\\\*.dll, build\\\\graphbolt\\\\*.dll\"\n\ndef init_git() {\n  sh \"chmod -R 777 .\" // Fix permission issue\n  sh 'rm -rf 
*'\n  sh \"git config --global --add safe.directory '*'\"\n  checkout scm\n  sh 'git submodule update --recursive --init'\n}\n\ndef init_git_win64() {\n  checkout scm\n  bat 'git submodule update --recursive --init'\n}\n\n// pack libraries for later use\ndef pack_lib(name, libs) {\n  echo \"Packing ${libs} into ${name}\"\n  stash includes: libs, name: name\n}\n\n// unpack libraries saved before\ndef unpack_lib(name, libs) {\n  unstash name\n  echo \"Unpacked ${libs} from ${name}\"\n}\n\ndef build_dgl_linux(dev) {\n  init_git()\n  sh \"bash tests/scripts/build_dgl.sh ${dev}\"\n  sh 'ls -lh /usr/lib/x86_64-linux-gnu/'\n  pack_lib(\"dgl-${dev}-linux\", dgl_linux_libs)\n}\n\ndef build_dgl_win64(dev) {\n  /* Assuming that Windows slaves are already configured with MSBuild VS2017,\n   * CMake and Python/pip/setuptools etc. */\n  init_git_win64()\n  bat \"CALL tests\\\\scripts\\\\build_dgl.bat\"\n  pack_lib(\"dgl-${dev}-win64\", dgl_win64_libs)\n}\n\ndef cpp_unit_test_linux(dev) {\n  init_git()\n  unpack_lib(\"dgl-${dev}-linux\", dgl_linux_libs)\n  sh 'bash tests/scripts/task_cpp_unit_test.sh'\n}\n\ndef cpp_unit_test_win64() {\n  init_git_win64()\n  unpack_lib('dgl-cpu-win64', dgl_win64_libs)\n  bat \"CALL tests\\\\scripts\\\\task_cpp_unit_test.bat\"\n}\n\ndef unit_test_linux(backend, dev) {\n  init_git()\n  unpack_lib(\"dgl-${dev}-linux\", dgl_linux_libs)\n  timeout(time: 40, unit: 'MINUTES') {\n    sh \"bash tests/scripts/task_unit_test.sh ${backend} ${dev}\"\n  }\n}\n\ndef unit_distributed_linux(backend, dev) {\n  init_git()\n  unpack_lib(\"dgl-${dev}-linux\", dgl_linux_libs)\n  timeout(time: 40, unit: 'MINUTES') {\n    sh \"bash tests/scripts/task_distributed_test.sh ${backend} ${dev}\"\n  }\n}\n\ndef unit_test_cugraph(backend, dev) {\n  init_git()\n  unpack_lib(\"dgl-${dev}-linux\", dgl_linux_libs)\n  timeout(time: 15, unit: 'MINUTES') {\n    sh \"bash tests/scripts/cugraph_unit_test.sh ${backend}\"\n  }\n}\n\ndef unit_test_win64(backend, dev) {\n  init_git_win64()\n 
 unpack_lib(\"dgl-${dev}-win64\", dgl_win64_libs)\n  timeout(time: 50, unit: 'MINUTES') {\n    bat \"CALL tests\\\\scripts\\\\task_unit_test.bat ${backend}\"\n  }\n}\n\ndef example_test_linux(backend, dev) {\n  init_git()\n  unpack_lib(\"dgl-${dev}-linux\", dgl_linux_libs)\n  timeout(time: 20, unit: 'MINUTES') {\n    sh \"bash tests/scripts/task_example_test.sh ${dev}\"\n  }\n}\n\ndef example_test_win64(backend, dev) {\n  init_git_win64()\n  unpack_lib(\"dgl-${dev}-win64\", dgl_win64_libs)\n  timeout(time: 20, unit: 'MINUTES') {\n    bat \"CALL tests\\\\scripts\\\\task_example_test.bat ${dev}\"\n  }\n}\n\ndef tutorial_test_linux(backend) {\n  init_git()\n  unpack_lib('dgl-cpu-linux', dgl_linux_libs)\n  timeout(time: 20, unit: 'MINUTES') {\n    sh \"bash tests/scripts/task_${backend}_tutorial_test.sh\"\n  }\n}\n\ndef go_test_linux() {\n  init_git()\n  unpack_lib('dgl-cpu-linux', dgl_linux_libs)\n  timeout(time: 20, unit: 'MINUTES') {\n    sh \"bash tests/scripts/task_go_test.sh\"\n  }\n}\n\ndef is_authorized(name) {\n  def devs = [\n    // System:\n    'dgl-bot', 'noreply',\n    // Core:\n    'Rhett-Ying', 'BarclayII', 'jermainewang', 'mufeili', 'isratnisa',\n    'rudongyu', 'classicsong', 'HuXiangkun', 'hetong007', 'kylasa',\n    'frozenbugs', 'peizhou001', 'zheng-da', 'czkkkkkk', 'thvasilo',\n    // Intern:\n    'pyynb', 'az15240', 'BowenYao18', 'kec020', 'Liu-rj',\n    // Friends:\n    'nv-dlasalle', 'yaox12', 'chang-l', 'Kh4L', 'VibhuJawa', 'kkranen',\n    'TristonC', 'mfbalin',\n    'bgawrych', 'itaraban', 'daniil-sizov', 'anko-intel', 'Kacper-Pietkun',\n    'hankaj', 'agrabows', 'DominikaJedynak', 'RafLit', 'CfromBU',\n    // Emeritus:\n    'VoVAllen',\n  ]\n  return (name in devs)\n}\n\ndef is_admin(name) {\n  def admins = ['dgl-bot', 'Rhett-Ying', 'BarclayII', 'jermainewang']\n  return (name in admins)\n}\n\ndef regression_test_done = false\n\npipeline {\n  agent any\n  triggers {\n        issueCommentTrigger('@dgl-bot.*')\n  }\n  stages {\n    // Below 2 
stages are to authenticate the change/comment author.\n    // Only core developers are allowed to trigger CI.\n    // Such authentication protects CI from malicious code which may bring CI instances down.\n    stage('Authentication') {\n      agent {\n        docker {\n            label 'linux-benchmark-node'\n            image 'dgllib/dgl-ci-lint'\n            alwaysPull true\n        }\n      }\n      when { not { triggeredBy 'IssueCommentCause' } }\n      steps {\n        script {\n          def author = env.CHANGE_AUTHOR\n          def prOpenTriggerCause = currentBuild.getBuildCauses('jenkins.branch.BranchEventCause')\n          def first_run = prOpenTriggerCause && env.BUILD_ID == '1'\n          if (author && !is_authorized(author)) {\n            pullRequest.comment(\"Not authorized to trigger CI. Please ask core developer to help trigger via issuing comment: \\n - `@dgl-bot`\")\n            error(\"Authentication failed.\")\n          }\n          if (first_run) {\n            pullRequest.comment('To trigger regression tests: \\n - `@dgl-bot run [instance-type] [which tests] [compare-with-branch]`; \\n For example: `@dgl-bot run g4dn.4xlarge all dmlc/master` or `@dgl-bot run c5.9xlarge kernel,api dmlc/master`')\n          }\n        }\n      }\n    }\n    stage('AuthenticationComment') {\n      agent {\n        docker {\n            label 'linux-benchmark-node'\n            image 'dgllib/dgl-ci-lint'\n            alwaysPull true\n        }\n      }\n      when { triggeredBy 'IssueCommentCause' }\n      steps {\n        script {\n          def author = env.GITHUB_COMMENT_AUTHOR\n          if (!is_authorized(author)) {\n            pullRequest.comment(\"Not authorized to trigger CI via issuing comment.\")\n            error(\"Authentication failed.\")\n          }\n        }\n      }\n    }\n    stage('Regression Test') {\n      agent {\n        docker {\n            label 'linux-benchmark-node'\n            image 'dgllib/dgl-ci-lint'\n            alwaysPull 
true\n        }\n      }\n      when { triggeredBy 'IssueCommentCause' }\n      steps {\n          checkout scm\n          script {\n              def comment = env.GITHUB_COMMENT\n              def command_lists = comment.split(' ')\n              if (command_lists.size() == 1) {\n                // CI command, not for regression\n                return\n              }\n              if (command_lists.size() != 5) {\n                pullRequest.comment('Cannot run the regression test due to unknown command')\n                error('Unknown command')\n              }\n              def author = env.GITHUB_COMMENT_AUTHOR\n              echo(\"${env.GIT_URL}\")\n              echo(\"${env}\")\n              if (!is_admin(author)) {\n                error('Not authorized to launch regression tests')\n              }\n              dir('benchmark_scripts_repo') {\n                checkout([$class: 'GitSCM', branches: [[name: '*/master']],\n                        userRemoteConfigs: [[credentialsId: 'github', url: 'https://github.com/dglai/DGL_scripts.git']]])\n              }\n              sh('cp benchmark_scripts_repo/benchmark/* benchmarks/scripts/')\n              def instance_type = command_lists[2].replace('.', '')\n              pullRequest.comment(\"Start the Regression test. View at ${RUN_DISPLAY_URL}\")\n              def prNumber = env.BRANCH_NAME.replace('PR-', '')\n              dir('benchmarks/scripts') {\n                sh('python3 -m pip install boto3')\n                sh(\"PYTHONUNBUFFERED=1 GIT_PR_ID=${prNumber} GIT_URL=${env.GIT_URL} GIT_BRANCH=${env.CHANGE_BRANCH} python3 run_reg_test.py --data-folder ${env.GIT_COMMIT}_${instance_type} --run-cmd '${comment}'\")\n              }\n              pullRequest.comment(\"Finished the Regression test. Result table is at https://dgl-asv-data.s3-us-west-2.amazonaws.com/${env.GIT_COMMIT}_${instance_type}/results/result.csv. Jenkins job link is ${RUN_DISPLAY_URL}. 
\")\n              currentBuild.result = 'SUCCESS'\n              regression_test_done = true\n          }\n      }\n    }\n    stage('CI') {\n      when { expression { !regression_test_done } }\n      stages {\n        stage('Abort Previous CI') {\n          steps {\n            script {\n              if (env.BRANCH_NAME != \"master\") {\n                // Jenkins will abort an older build if a newer build already\n                // passed a higher milestone.\n                // https://www.jenkins.io/doc/pipeline/steps/pipeline-milestone-step/\n                def buildNumber = env.BUILD_NUMBER as int\n                for (int i = 1; i <= buildNumber; i++) {\n                  milestone(i)\n                }\n              }\n            }\n          }\n        }\n\n        stage('Lint Check') {\n          agent {\n            docker {\n              label \"linux-benchmark-node\"\n              image \"dgllib/dgl-ci-lint\"\n              alwaysPull true\n            }\n          }\n          steps {\n            init_git()\n            sh 'bash tests/scripts/task_lint.sh'\n          }\n          post {\n            always {\n              cleanWs disableDeferredWipeout: true, deleteDirs: true\n            }\n          }\n        }\n\n        stage('Build') {\n          parallel {\n            stage('CPU Build') {\n              agent {\n                docker {\n                  label \"dgl-ci-linux-cpu\"\n                  image \"dgllib/dgl-ci-cpu:v240511_1440\"\n                  args \"-u root\"\n                  alwaysPull true\n                }\n              }\n              steps {\n                build_dgl_linux('cpu')\n              }\n              post {\n                always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('GPU Build') {\n              agent {\n                
docker {\n                  label \"dgl-ci-linux-cpu\"\n                  image \"dgllib/dgl-ci-gpu:cu121_v240511_1440\"\n                  args \"-u root\"\n                  alwaysPull true\n                }\n              }\n              steps {\n                // sh \"nvidia-smi\"\n                build_dgl_linux('gpu')\n              }\n              post {\n                always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('PyTorch Cugraph GPU Build') {\n              agent {\n                docker {\n                  label \"dgl-ci-linux-cpu\"\n                  image \"rapidsai/cugraph_stable_torch-cuda:11.8-base-ubuntu20.04-py3.10-pytorch2.0.0-rapids23.04\"\n                  args \"-u root\"\n                  alwaysPull true\n                }\n              }\n              steps {\n                build_dgl_linux('cugraph')\n              }\n              post {\n                always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('CPU Build (Win64)') {\n              agent { label 'dgl-ci-windows-cpu' }\n              steps {\n                build_dgl_win64('cpu')\n              }\n              post {\n                always {\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n          // Currently we don't have Windows GPU build machines\n          }\n        }\n        stage('Test') {\n          parallel {\n            stage('C++ CPU') {\n              agent {\n                docker {\n                  label \"dgl-ci-linux-cpu\"\n                  image \"dgllib/dgl-ci-cpu:v240511_1440\"\n                  args \"-u root\"\n                 
 alwaysPull true\n                }\n              }\n              steps {\n                cpp_unit_test_linux('cpu')\n              }\n              post {\n                always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('C++ GPU') {\n              agent {\n                docker {\n                  label \"dgl-ci-linux-gpu\"\n                  image \"dgllib/dgl-ci-gpu:cu121_v240511_1440\"\n                  args \"-u root --runtime nvidia\"\n                  alwaysPull true\n                }\n              }\n              steps {\n                cpp_unit_test_linux('gpu')\n              }\n              post {\n                always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('C++ CPU (Win64)') {\n              agent { label 'dgl-ci-windows-cpu' }\n              steps {\n                cpp_unit_test_win64()\n              }\n              post {\n                always {\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('Tensorflow CPU') {\n              agent {\n                docker {\n                  label \"dgl-ci-linux-cpu\"\n                  image \"dgllib/dgl-ci-cpu:v230810\"\n                  args \"-u root\"\n                  alwaysPull true\n                }\n              }\n              stages {\n                stage('Tensorflow CPU Unit test') {\n                  steps {\n                    unit_test_linux('tensorflow', 'cpu')\n                  }\n                  // Tensorflow is deprecated.\n                  when { expression { false } }\n                }\n              }\n              post {\n       
         always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('Tensorflow GPU') {\n              agent {\n                docker {\n                  label \"dgl-ci-linux-gpu\"\n                  image \"dgllib/dgl-ci-gpu:cu121_v240511_1440\"\n                  args \"-u root --runtime nvidia\"\n                  alwaysPull true\n                }\n              }\n              stages {\n                stage('Tensorflow GPU Unit test') {\n                  steps {\n                    unit_test_linux('tensorflow', 'gpu')\n                  }\n                  // Tensorflow does not support cuda 11.6 yet.\n                  when { expression { false } }\n                }\n              }\n              post {\n                always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('Torch CPU') {\n              agent {\n                docker {\n                  label \"dgl-ci-linux-cpu\"\n                  image \"dgllib/dgl-ci-cpu:v240511_1440\"\n                  args \"-u root --shm-size=4gb\"\n                  alwaysPull true\n                }\n              }\n              stages {\n                stage('Torch CPU Unit test') {\n                  steps {\n                    unit_test_linux('pytorch', 'cpu')\n                  }\n                }\n                stage('Torch CPU Example test') {\n                  steps {\n                    example_test_linux('pytorch', 'cpu')\n                  }\n                }\n                stage('Torch CPU Tutorial test') {\n                  steps {\n                    tutorial_test_linux('pytorch')\n                  }\n                }\n              }\n              post {\n     
           always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('Torch CPU (Win64)') {\n              agent { label 'dgl-ci-windows-cpu' }\n              stages {\n                stage('Torch CPU (Win64) Unit test') {\n                  steps {\n                    unit_test_win64('pytorch', 'cpu')\n                  }\n                }\n                stage('Torch CPU (Win64) Example test') {\n                  steps {\n                    example_test_win64('pytorch', 'cpu')\n                  }\n                }\n              }\n              post {\n                always {\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('Torch GPU') {\n              agent {\n                docker {\n                  label \"dgl-ci-linux-gpu\"\n                  image \"dgllib/dgl-ci-gpu:cu121_v240511_1440\"\n                  args \"-u root --runtime nvidia --shm-size=8gb\"\n                  alwaysPull true\n                }\n              }\n              stages {\n                stage('Torch GPU Unit test') {\n                  steps {\n                    sh 'nvidia-smi'\n                    unit_test_linux('pytorch', 'gpu')\n                  }\n                }\n                stage('Torch GPU Example test') {\n                  steps {\n                    example_test_linux('pytorch', 'gpu')\n                  }\n                }\n              }\n              post {\n                always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('Distributed') {\n              agent {\n                docker {\n                  label 
\"dgl-ci-linux-cpu\"\n                  image \"dgllib/dgl-ci-cpu:v240511_1440\"\n                  args \"-u root --shm-size=8gb\"\n                  alwaysPull true\n                }\n              }\n              stages {\n                stage('Distributed Torch CPU Unit test') {\n                  steps {\n                    unit_distributed_linux('pytorch', 'cpu')\n                  }\n                }\n              }\n              post {\n                always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('PyTorch Cugraph GPU') {\n              agent {\n                docker {\n                  label \"dgl-ci-linux-gpu\"\n                  image \"rapidsai/cugraph_stable_torch-cuda:11.8-base-ubuntu20.04-py3.10-pytorch2.0.0-rapids23.04\"\n                  args \"-u root --runtime nvidia --shm-size=8gb\"\n                  alwaysPull true\n                }\n              }\n              stages {\n                stage('PyTorch Cugraph GPU Unit test') {\n                  steps {\n                    sh 'nvidia-smi'\n                    unit_test_cugraph('pytorch', 'cugraph')\n                  }\n                  // Cugraph is under refactoring. 
Skip the test for now.\n                  when { expression { false } }\n                }\n              }\n              post {\n                always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n            stage('DGL-Go') {\n              agent {\n                docker {\n                  label \"dgl-ci-linux-cpu\"\n                  image \"dgllib/dgl-ci-cpu:v240511_1440\"\n                  args \"-u root\"\n                  alwaysPull true\n                }\n              }\n              stages {\n                stage('DGL-Go CPU test') {\n                  steps {\n                    go_test_linux()\n                  }\n                }\n              }\n              post {\n                always {\n                  sh \"chmod -R 777 .\" // Fix permission issue\n                  cleanWs disableDeferredWipeout: true, deleteDirs: true\n                }\n              }\n            }\n          }\n        }\n      }\n    }\n  }\n  post {\n    always {\n      script {\n        node(\"dglci-post-linux\") {\n          docker.image('dgllib/dgl-ci-awscli:v220418').inside(\"--pull always --entrypoint=''\") {\n            sh(\"rm -rf ci_tmp\")\n            dir('ci_tmp') {\n              sh(\"curl -k -o cireport.log ${BUILD_URL}consoleText\")\n              sh(\"curl -o report.py https://raw.githubusercontent.com/dmlc/dgl/master/tests/scripts/ci_report/report.py\")\n              sh(\"curl -o status.py https://raw.githubusercontent.com/dmlc/dgl/master/tests/scripts/ci_report/status.py\")\n              sh(\"curl -k -L ${BUILD_URL}wfapi\")\n              sh(\"cat status.py\")\n              sh(\"pytest --html=report.html --self-contained-html report.py || true\")\n              sh(\"aws s3 sync ./ s3://dgl-ci-result/${JOB_NAME}/${BUILD_NUMBER}/${BUILD_ID}/logs/  --exclude '*' --include '*.log' --acl 
public-read --content-type text/plain\")\n              sh(\"aws s3 sync ./ s3://dgl-ci-result/${JOB_NAME}/${BUILD_NUMBER}/${BUILD_ID}/logs/  --exclude '*.log' --acl public-read\")\n\n              def comment = sh(returnStdout: true, script: \"python3 status.py --result ${currentBuild.currentResult}\").trim()\n              echo(comment)\n              if ((env.BRANCH_NAME).startsWith('PR-')) {\n                pullRequest.comment(comment)\n              }\n            }\n          }\n        }\n        node('dgl-ci-windows-cpu') {\n            bat(script: \"rmvirtualenv ${BUILD_TAG}\", returnStatus: true)\n        }\n      }\n    }\n  }\n}\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "NEWS.md",
    "content": "DGL release and change logs\n==========\n\nRefer to the roadmap issue for the on-going versions and features.\n\n0.2\n---\nMajor release that includes many features, bugfix and performance improvement.\nSpeed of GCN model on Pubmed dataset has been improved by **4.19x**! Speed of\nRGCN model on Mutag dataset has been improved by **7.35x**! Important new\nfeature: **graph sampling APIs**.\n\nUpdate details:\n\n# Model examples\n- [x] TreeLSTM w/ MXNet (PR #279 by @szha )\n- [x] GraphSage (@ZiyueHuang )\n- [x] Improve GAT model speed (PR #348 by @jermainewang )\n\n# Core system improvement\n- [x] Immutable CSR graph structure (PR #342 by @zheng-da )\n  - [x] Finish remaining functionality (Issue #369, PR #404 by @yzh119)\n- [x] Nodeflow data structure (PR #361 by @zheng-da )\n- [x] Neighbor sampler (PR #322 )\n- [x] Layer-wise sampler (PR #362 by @GaiYu0 )\n- [x] Multi-GPU support by data parallelism (PR #356 #338 by @ylfdq1118 )\n- [x] More dataset:\n  - [x] Reddit dataset loader (PR #372 by @ZiyueHuang )\n  - [x] PPI dataset loader (PR #395 by @sufeidechabei )\n  - [x] Mini graph classification dataset (PR #364 by @mufeili )\n- [x] NN modules (PR #406 by @jermainewang @mufeili)\n  - [x] GraphConv layer\n  - [x] Edge softmax layer\n- [x] Edge group apply API (PR #358 by @VoVAllen )\n- [x] Reversed graph and transform.py module (PR #331 by @mufeili )\n- [x] Max readout (PR #341 by @mufeili )\n- [x] Random walk APIs (PR #392 by @BarclayII )\n\n# Tutorial/Blog\n- [x] Batched graph classification in DGL (PR #360 by @mufeili )\n- [x] Understanding GAT (@sufeidechabei )\n\n# Project improvement\n- [x] Python lint check (PR #330 by @jermainewang )\n- [x] Win CI (PR #324 by @BarclayII )\n- [x] Auto doc build (by @VoVAllen )\n- [x] Unify tests for different backends (PR #333 by @BarclayII )\n\n0.1.3\n-----\nBug fix\n* Compatible with Pytorch v1.0\n* Bug fix in networkx graph conversion.\n\n0.1.2\n-----\nFirst open release.\n* Basic graph APIs.\n* Basic 
message passing APIs.\n* Pytorch backend.\n* MXNet backend.\n* Optimization using SPMV.\n* Model examples w/ Pytorch:\n  - GCN\n  - GAT\n  - JTNN\n  - DGMG\n  - Capsule\n  - LGNN\n  - RGCN\n  - Transformer\n  - TreeLSTM\n* Model examples w/ MXNet:\n  - GCN\n  - GAT\n  - RGCN\n  - SSE\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n  <img src=\"http://data.dgl.ai/asset/logo.jpg\" height=\"200\">\n</p>\n\n[![Latest Release](https://img.shields.io/github/v/release/dmlc/dgl)](https://github.com/dmlc/dgl/releases)\n[![Conda Latest Release](https://anaconda.org/dglteam/dgl/badges/version.svg)](https://anaconda.org/dglteam/dgl)\n[![Build Status](https://ci.dgl.ai/buildStatus/icon?job=DGL/master)](https://ci.dgl.ai/job/DGL/job/master/)\n[![Benchmark by ASV](http://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat)](https://asv.dgl.ai/)\n[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](./LICENSE)\n[![Twitter](https://img.shields.io/twitter/follow/DGLGraph?style=social)](https://twitter.com/GraphDeep)\n\n[Website](https://www.dgl.ai) | [A Blitz Introduction to DGL](https://docs.dgl.ai/tutorials/blitz/index.html) | Documentation ([Latest](https://www.dgl.ai/dgl_docs/) | [Official Examples](examples/README.md) | [Discussion Forum](https://discuss.dgl.ai) | [Slack Channel](https://join.slack.com/t/deep-graph-library/shared_invite/zt-eb4ict1g-xcg3PhZAFAB8p6dtKuP6xQ)\n\nDGL is an easy-to-use, high performance and scalable Python package for deep learning on graphs. DGL is framework agnostic, meaning if a deep graph model is a component of an end-to-end application, the rest of the logics can be implemented in any major frameworks, such as PyTorch, Apache MXNet or TensorFlow.\n\n<p align=\"center\">\n  <img src=\"http://data.dgl.ai/asset/image/DGL-Arch.png\" alt=\"DGL v0.4 architecture\" width=\"600\">\n  <br>\n  <b>Figure</b>: DGL Overall Architecture\n</p>\n\n## Highlighted Features\n\n### A GPU-ready graph library\n\nDGL provides a powerful graph object that can reside on either CPU or GPU. It bundles structural data as well as features for better control. 
We provide a variety of functions for computing with graph objects including efficient and customizable message passing primitives for Graph Neural Networks.\n\n### A versatile tool for GNN researchers and practitioners\n\nThe field of graph deep learning is still rapidly evolving and many research ideas emerge by standing on the shoulders of giants. To ease the process, [DGL-Go](https://github.com/dmlc/dgl/tree/master/dglgo) is a command-line interface to get started with training, using and studying state-of-the-art GNNs.\nDGL collects a rich set of [example implementations](https://github.com/dmlc/dgl/tree/master/examples) of popular GNN models covering a wide range of topics. Researchers can [search](https://www.dgl.ai/) for related models to draw new ideas from or use them as baselines for experiments. Moreover, DGL provides many state-of-the-art [GNN layers and modules](https://docs.dgl.ai/api/python/nn.html) for users to build new model architectures. DGL is one of the preferred platforms for many standard graph deep learning benchmarks including [OGB](https://ogb.stanford.edu/) and [GNNBenchmarks](https://github.com/graphdeeplearning/benchmarking-gnns).\n\n### Easy to learn and use\n\nDGL provides plenty of learning materials for all kinds of users from ML researchers to domain experts. The [Blitz Introduction to DGL](https://docs.dgl.ai/tutorials/blitz/index.html) is a 120-minute tour of the basics of graph machine learning. The [User Guide](https://docs.dgl.ai/guide/index.html) explains the concepts of graphs and the training methodology in more detail. All of them include code snippets in DGL that are runnable and ready to be plugged into one’s own pipeline.\n\n### Scalable and efficient\n\nIt is convenient to train models using DGL on large-scale graphs across **multiple GPUs** or **multiple machines**. DGL extensively optimizes the whole stack to reduce the overhead in communication, memory consumption and synchronization. 
As a result, DGL can easily scale to billion-sized graphs. Get started with the [tutorials](https://docs.dgl.ai/en/tutorials/dist/index.html) and [user guide](https://docs.dgl.ai/en/latest/guide/distributed.html) for distributed training. See the [system performance note](https://docs.dgl.ai/performance.html) for the comparison with other tools.\n\n## Get Started\n\nUsers can install DGL from [pip and conda](https://www.dgl.ai/pages/start.html). You can also download GPU-enabled DGL Docker [containers](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/dgl) (using the PyTorch backend) from NVIDIA NGC for both x86 and ARM-based Linux systems. Advanced users can follow the [instructions](https://docs.dgl.ai/install/index.html#install-from-source) to install from source.\n\nFor absolute beginners, start with [the Blitz Introduction to DGL](https://docs.dgl.ai/tutorials/blitz/index.html). It covers the basic concepts of common graph machine learning tasks and a step-by-step guide to building Graph Neural Networks (GNNs) to solve them.\n\nFor users already acquainted with DGL who wish to learn more,\n\n* Experience state-of-the-art GNN models in only two command lines using [DGL-Go](https://github.com/dmlc/dgl/tree/master/dglgo).\n* Learn DGL by [example implementations](https://www.dgl.ai/) of popular GNN models.\n* Read the [User Guide](https://docs.dgl.ai/guide/index.html) ([中文版链接](https://docs.dgl.ai/guide_cn/index.html)), which explains the concepts and usage of DGL in much more detail.\n* Go through the tutorials for advanced features like [stochastic training of GNNs](https://docs.dgl.ai/tutorials/large/index.html), training on [multi-GPU](https://docs.dgl.ai/tutorials/multi/index.html) or [multi-machine](https://docs.dgl.ai/tutorials/dist/index.html).\n* [Study classical papers](https://docs.dgl.ai/tutorials/models/index.html) on graph machine learning alongside DGL.\n* Search for the usage of a specific API in the [API reference manual](https://docs.dgl.ai/api/python/index.html), 
which organizes all DGL APIs by their namespace.\n\nAll the learning materials are available at our [documentation site](https://docs.dgl.ai/). If you are new to deep learning in general,\ncheck out the open source book [Dive into Deep Learning](https://d2l.ai/).\n\n\n## Community\n\n### Get connected\n\nWe provide multiple channels to connect you to the community of the DGL developers, users, and the general GNN academic researchers:\n\n* Our Slack channel, [click to join](https://join.slack.com/t/deep-graph-library/shared_invite/zt-eb4ict1g-xcg3PhZAFAB8p6dtKuP6xQ)\n* Our discussion forum: https://discuss.dgl.ai/\n* Our [Zhihu blog (in Chinese)](https://www.zhihu.com/column/c_1070749881013936128)\n* Monthly GNN User Group online seminar ([event link](https://www.eventbrite.com/e/graph-neural-networks-user-group-tickets-137512275919?utm-medium=discovery&utm-campaign=social&utm-content=attendeeshare&aff=escb&utm-source=cp&utm-term=listing) | [past videos](https://www.youtube.com/channel/UCnmuSDY1pTlaFH1WRQElfTg))\n\nTake the survey [here](https://forms.gle/Ej3jHCocACmb49Gp8) and leave any feedback to make DGL better fit for your needs. Thanks!\n\n### DGL-powered projects\n\n* DGL-LifeSci: a DGL-based package for various applications in life science with graph neural networks. https://github.com/awslabs/dgl-lifesci\n* DGL-KE: a high performance, easy-to-use, and scalable package for learning large-scale knowledge graph embeddings. https://github.com/awslabs/dgl-ke\n* Benchmarking GNN: https://github.com/graphdeeplearning/benchmarking-gnns\n* OGB: a collection of realistic, large-scale, and diverse benchmark datasets for machine learning on graphs. https://ogb.stanford.edu/\n* Graph4NLP: an easy-to-use library for R&D at the intersection of Deep Learning on Graphs and Natural Language Processing. 
https://github.com/graph4ai/graph4nlp\n* GNN-RecSys: https://github.com/je-dbl/GNN-RecSys\n* Amazon Neptune ML: a new capability of Neptune that uses Graph Neural Networks (GNNs), a machine learning technique purpose-built for graphs, to make easy, fast, and more accurate predictions using graph data. https://aws.amazon.com/cn/neptune/machine-learning/\n* GNNLens2: Visualization tool for Graph Neural Networks. https://github.com/dmlc/GNNLens2\n* RNAGlib: A package to facilitate construction, analysis, visualization and machine learning on RNA 2.5D Graphs. Includes a pre-built dataset: https://rnaglib.cs.mcgill.ca\n* OpenHGNN: Model zoo and benchmarks for Heterogeneous Graph Neural Networks. https://github.com/BUPT-GAMMA/OpenHGNN\n* TGL: A graph learning framework for large-scale temporal graphs. https://github.com/amazon-research/tgl\n* gtrick: Bag of Tricks for Graph Neural Networks. https://github.com/sangyx/gtrick\n* ArangoDB-DGL Adapter: Import [ArangoDB](https://github.com/arangodb/arangodb) graphs into DGL and vice-versa. https://github.com/arangoml/dgl-adapter\n* DGLD: [DGLD](https://github.com/EagleLab-ZJU/DGLD) is an open-source library for Deep Graph Anomaly Detection based on PyTorch and DGL.\n\n### Awesome Papers Using DGL\n\n1. [**Benchmarking Graph Neural Networks**](https://arxiv.org/pdf/2003.00982.pdf), *Vijay Prakash Dwivedi, Chaitanya K. Joshi, Thomas Laurent, Yoshua Bengio, Xavier Bresson*\n\n1. [**Open Graph Benchmarks: Datasets for Machine Learning on Graphs**](https://arxiv.org/pdf/2005.00687.pdf), NeurIPS'20, *Weihua Hu, Matthias Fey, Marinka Zitnik, Yuxiao Dong, Hongyu Ren, Bowen Liu, Michele Catasta, Jure Leskovec*\n\n1. [**DropEdge: Towards Deep Graph Convolutional Networks on Node Classification**](https://openreview.net/pdf?id=Hkx1qkrKPr), ICLR'20, *Yu Rong, Wenbing Huang, Tingyang Xu, Junzhou Huang*\n\n1. 
[**Discourse-Aware Neural Extractive Text Summarization**](https://www.aclweb.org/anthology/2020.acl-main.451/), ACL'20, *Jiacheng Xu, Zhe Gan, Yu Cheng, Jingjing Liu*\n\n1. [**GCC: Graph Contrastive Coding for Graph Neural Network Pre-Training**](https://dl.acm.org/doi/pdf/10.1145/3394486.3403168?casa_token=EClsH2Vc4DcAAAAA:LIB8cbtr6yTDbYuv4cTLwTIYeDq5Y2dhj_ktcWdKpzdPLGeiuL0o8GlcN4QIOnpsAnmGeGVZ), KDD'20, *Jiezhong Qiu, Qibin Chen, Yuxiao Dong, Jing Zhang, Hongxia Yang, Ming Ding, Kuansan Wang, Jie Tang*\n\n1. [**DGL-KE: Training Knowledge Graph Embeddings at Scale**](https://arxiv.org/pdf/2004.08532), SIGIR'20, *Da Zheng, Xiang Song, Chao Ma, Zeyuan Tan, Zihao Ye, Jin Dong, Hao Xiong, Zheng Zhang, George Karypis*\n\n1. [**Improving Graph Neural Network Expressivity via Subgraph Isomorphism Counting**](https://arxiv.org/pdf/2006.09252.pdf), *Giorgos Bouritsas, Fabrizio Frasca, Stefanos Zafeiriou, Michael M. Bronstein*\n\n1. [**INT: An Inequality Benchmark for Evaluating Generalization in Theorem Proving**](https://arxiv.org/pdf/2007.02924.pdf), *Yuhuai Wu, Albert Q. Jiang, Jimmy Ba, Roger Grosse*\n\n1. [**Finding Patient Zero: Learning Contagion Source with Graph Neural Networks**](https://arxiv.org/pdf/2006.11913.pdf), *Chintan Shah, Nima Dehmamy, Nicola Perra, Matteo Chinazzi, Albert-László Barabási, Alessandro Vespignani, Rose Yu*\n\n1. [**FeatGraph: A Flexible and Efficient Backend for Graph Neural Network Systems**](https://arxiv.org/pdf/2008.11359.pdf), SC'20, *Yuwei Hu, Zihao Ye, Minjie Wang, Jiali Yu, Da Zheng, Mu Li, Zheng Zhang, Zhiru Zhang, Yida Wang*\n\n\n<details><summary>more</summary>\n\n11. [**BP-Transformer: Modelling Long-Range Context via Binary Partitioning.**](https://arxiv.org/pdf/1911.04070.pdf), *Zihao Ye, Qipeng Guo, Quan Gan, Xipeng Qiu, Zheng Zhang*\n\n12. 
[**OptiMol: Optimization of Binding Affinities in Chemical Space for Drug Discovery**](https://www.biorxiv.org/content/biorxiv/early/2020/06/16/2020.05.23.112201.full.pdf), *Jacques Boitreaud, Vincent Mallet, Carlos Oliver, Jérôme Waldispühl*\n\n1. [**JAKET: Joint Pre-training of Knowledge Graph and Language Understanding**](https://arxiv.org/pdf/2010.00796.pdf), *Donghan Yu, Chenguang Zhu, Yiming Yang, Michael Zeng*\n\n1. [**Architectural Implications of Graph Neural Networks**](https://arxiv.org/pdf/2009.00804.pdf), *Zhihui Zhang, Jingwen Leng, Lingxiao Ma, Youshan Miao, Chao Li, Minyi Guo*\n\n1. [**Combining Reinforcement Learning and Constraint Programming for Combinatorial Optimization**](https://arxiv.org/pdf/2006.01610.pdf), *Quentin Cappart, Thierry Moisan, Louis-Martin Rousseau, Isabeau Prémont-Schwarz, and Andre Cire*\n\n1. [**Therapeutics Data Commons: Machine Learning Datasets and Tasks for Therapeutics**](https://arxiv.org/abs/2102.09548) ([code repo](https://github.com/mims-harvard/TDC)), *Kexin Huang, Tianfan Fu, Wenhao Gao, Yue Zhao, Yusuf Roohani, Jure Leskovec, Connor W. Coley, Cao Xiao, Jimeng Sun, Marinka Zitnik*\n\n1. [**Sparse Graph Attention Networks**](https://arxiv.org/abs/1912.00552), *Yang Ye, Shihao Ji*\n\n1. [**On Self-Distilling Graph Neural Network**](https://arxiv.org/pdf/2011.02255.pdf), *Yuzhao Chen, Yatao Bian, Xi Xiao, Yu Rong, Tingyang Xu, Junzhou Huang*\n\n1. [**Learning Robust Node Representations on Graphs**](https://arxiv.org/pdf/2008.11416.pdf), *Xu Chen, Ya Zhang, Ivor Tsang, and Yuangang Pan*\n\n1. [**Recurrent Event Network: Autoregressive Structure Inference over Temporal Knowledge Graphs**](https://arxiv.org/abs/1904.05530), *Woojeong Jin, Meng Qu, Xisen Jin, Xiang Ren*\n\n1. [**Graph Neural Ordinary Differential Equations**](https://arxiv.org/abs/1911.07532), *Michael Poli, Stefano Massaroli, Junyoung Park, Atsushi Yamashita, Hajime Asama, Jinkyoo Park*\n\n1. 
[**FusedMM: A Unified SDDMM-SpMM Kernel for Graph Embedding and Graph Neural Networks**](https://arxiv.org/pdf/2011.06391.pdf), *Md. Khaledur Rahman, Majedul Haque Sujon, Ariful Azad*\n\n1. [**An Efficient Neighborhood-based Interaction Model for Recommendation on Heterogeneous Graph**](https://arxiv.org/pdf/2007.00216.pdf), KDD'20, *Jiarui Jin, Jiarui Qin, Yuchen Fang, Kounianhua Du, Weinan Zhang, Yong Yu, Zheng Zhang, Alexander J. Smola*\n\n1. [**Learning Interaction Models of Structured Neighborhood on Heterogeneous Information Network**](https://arxiv.org/pdf/2011.12683.pdf), *Jiarui Jin, Kounianhua Du, Weinan Zhang, Jiarui Qin, Yuchen Fang, Yong Yu, Zheng Zhang, Alexander J. Smola*\n\n1. [**Graphein - a Python Library for Geometric Deep Learning and Network Analysis on Protein Structures**](https://www.biorxiv.org/content/10.1101/2020.07.15.204701v1), *Arian R. Jamasb, Pietro Liò, Tom L. Blundell*\n\n1. [**Graph Policy Gradients for Large Scale Robot Control**](https://arxiv.org/abs/1907.03822), *Arbaaz Khan, Ekaterina Tolstaya, Alejandro Ribeiro, Vijay Kumar*\n\n1. [**Heterogeneous Molecular Graph Neural Networks for Predicting Molecule Properties**](https://arxiv.org/abs/2009.12710), *Zeren Shui, George Karypis*\n\n1. [**Could Graph Neural Networks Learn Better Molecular Representation for Drug Discovery? A Comparison Study of Descriptor-based and Graph-based Models**](https://assets.researchsquare.com/files/rs-81439/v1_stamped.pdf), *Dejun Jiang, Zhenxing Wu, Chang-Yu Hsieh, Guangyong Chen, Ben Liao, Zhe Wang, Chao Shen, Dongsheng Cao, Jian Wu, Tingjun Hou*\n\n1. [**Principal Neighbourhood Aggregation for Graph Nets**](https://arxiv.org/abs/2004.05718), *Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Liò, Petar Veličković*\n\n1. [**Collective Multi-type Entity Alignment Between Knowledge Graphs**](https://dl.acm.org/doi/abs/10.1145/3366423.3380289), *Qi Zhu, Hao Wei, Bunyamin Sisman, Da Zheng, Christos Faloutsos, Xin Luna Dong, Jiawei Han*\n\n1. 
[**Graph Representation Forecasting of Patient's Medical Conditions: towards A Digital Twin**](https://arxiv.org/abs/2009.08299), *Pietro Barbiero, Ramon Viñas Torné, Pietro Lió*\n\n1. [**Relational Graph Learning on Visual and Kinematics Embeddings for Accurate Gesture Recognition in Robotic Surgery**](https://arxiv.org/abs/2011.01619), *Yong-Hao Long, Jie-Ying Wu, Bo Lu, Yue-Ming Jin, Mathias Unberath, Yun-Hui Liu, Pheng-Ann Heng and Qi Dou*\n\n1. [**Dark Reciprocal-Rank: Boosting Graph-Convolutional Self-Localization Network via Teacher-to-student Knowledge Transfer**](https://arxiv.org/abs/2011.00402), *Takeda Koji, Tanaka Kanji*\n\n1. [**Graph InfoClust: Leveraging Cluster-Level Node Information For Unsupervised Graph Representation Learning**](https://arxiv.org/abs/2009.06946), *Costas Mavromatis, George Karypis*\n\n1. [**GraphSeam: Supervised Graph Learning Framework for Semantic UV Mapping**](https://arxiv.org/abs/2011.13748), *Fatemeh Teimury, Bruno Roy, Juan Sebastian Casallas, David macdonald, Mark Coates*\n\n1. [**Comprehensive Study on Molecular Supervised Learning with Graph Neural Networks**](https://pubs.acs.org/doi/10.1021/acs.jcim.0c00416), *Doyeong Hwang, Soojung Yang, Yongchan Kwon, Kyung Hoon Lee, Grace Lee, Hanseok Jo, Seyeol Yoon, and Seongok Ryu*\n\n1. [**A graph auto-encoder model for miRNA-disease associations prediction**](https://academic.oup.com/bib/advance-article-abstract/doi/10.1093/bib/bbaa240/5929824?redirectedFrom=fulltext), *Zhengwei Li, Jiashu Li, Ru Nie, Zhu-Hong You, Wenzheng Bao*\n\n1. [**Graph convolutional regression of cardiac depolarization from sparse endocardial maps**](https://arxiv.org/abs/2009.14068), STACOM 2020 workshop, *Felix Meister, Tiziano Passerini, Chloé Audigier, Èric Lluch, Viorel Mihalef, Hiroshi Ashikaga, Andreas Maier, Henry Halperin, Tommaso Mansi*\n\n1. 
[**AttnIO: Knowledge Graph Exploration with In-and-Out Attention Flow for Knowledge-Grounded Dialogue**](https://www.aclweb.org/anthology/2020.emnlp-main.280/), EMNLP'20, *Jaehun Jung, Bokyung Son, Sungwon Lyu*\n\n1. [**Learning from Non-Binary Constituency Trees via Tensor Decomposition**](https://github.com/danielecastellana22/tensor-tree-nn), COLING'20, *Daniele Castellana, Davide Bacciu*\n\n1. [**Inducing Alignment Structure with Gated Graph Attention Networks for Sentence Matching**](https://arxiv.org/abs/2010.07668), *Peng Cui, Le Hu, Yuanchao Liu*\n\n1. [**Enhancing Extractive Text Summarization with Topic-Aware Graph Neural Networks**](https://arxiv.org/abs/2010.06253), COLING'20, *Peng Cui, Le Hu, Yuanchao Liu*\n\n1. [**Double Graph Based Reasoning for Document-level Relation Extraction**](https://arxiv.org/abs/2009.13752), EMNLP'20, *Shuang Zeng, Runxin Xu, Baobao Chang, Lei Li*\n\n1. [**Systematic Generalization on gSCAN with Language Conditioned Embedding**](https://arxiv.org/abs/2009.05552), AACL-IJCNLP'20, *Tong Gao, Qi Huang, Raymond J. Mooney*\n\n1. [**Automatic selection of clustering algorithms using supervised graph embedding**](https://arxiv.org/pdf/2011.08225.pdf), *Noy Cohen-Shapira, Lior Rokach*\n\n1. [**Improving Learning to Branch via Reinforcement Learning**](https://openreview.net/forum?id=z4D7-PTxTb), *Haoran Sun, Wenbo Chen, Hui Li, Le Song*\n\n1. [**A Practical Guide to Graph Neural Networks**](https://arxiv.org/pdf/2010.05234.pdf), *Isaac Ronald Ward, Jack Joyner, Casey Lickfold, Stash Rowe, Yulan Guo, Mohammed Bennamoun*, [code](https://github.com/isolabs/gnn-tutorial)\n\n1. [**APAN: Asynchronous Propagation Attention Network for Real-time Temporal Graph Embedding**](https://arxiv.org/pdf/2011.11545.pdf), SIGMOD'21, *Xuhong Wang, Ding Lyu, Mengjian Li, Yang Xia, Qi Yang, Xinwen Wang, Xinguang Wang, Ping Cui, Yupu Yang, Bowen Sun, Zhenyu Guo, Junkui Li*\n\n1. 
[**Uncertainty-Matching Graph Neural Networks to Defend Against Poisoning Attacks**](https://arxiv.org/pdf/2009.14455.pdf), *Uday Shankar Shanthamallu, Jayaraman J. Thiagarajan, Andreas Spanias*\n\n1. [**Computing Graph Neural Networks: A Survey from Algorithms to Accelerators**](https://arxiv.org/pdf/2010.00130.pdf), *Sergi Abadal, Akshay Jain, Robert Guirado, Jorge López-Alonso, Eduard Alarcón*\n\n1. [**NHK_STRL at WNUT-2020 Task 2: GATs with Syntactic Dependencies as Edges and CTC-based Loss for Text Classification**](https://www.aclweb.org/anthology/2020.wnut-1.43.pdf), *Yuki Yasuda, Taichi Ishiwatari, Taro Miyazaki, Jun Goto*\n\n1. [**Relation-aware Graph Attention Networks with Relational Position Encodings for Emotion Recognition in Conversations**](https://www.aclweb.org/anthology/2020.emnlp-main.597.pdf), *Taichi Ishiwatari, Yuki Yasuda, Taro Miyazaki, Jun Goto*\n\n1. [**PGM-Explainer: Probabilistic Graphical Model Explanations for Graph Neural Networks**](https://proceedings.neurips.cc/paper/2020/file/8fb134f258b1f7865a6ab2d935a897c9-Paper.pdf), *Minh N. Vu, My T. Thai*\n\n1. [**A Generalization of Transformer Networks to Graphs**](https://arxiv.org/pdf/2012.09699.pdf), *Vijay Prakash Dwivedi, Xavier Bresson*\n\n1. [**Discourse-Aware Neural Extractive Text Summarization**](https://www.aclweb.org/anthology/2020.acl-main.451.pdf), ACL'20, *Jiacheng Xu, Zhe Gan, Yu Cheng, Jingjing Liu*\n\n1. [**Learning Robust Node Representations on Graphs**](https://arxiv.org/abs/2008.11416), *Xu Chen, Ya Zhang, Ivor Tsang, Yuangang Pan*\n\n1. [**Adaptive Graph Diffusion Networks with Hop-wise Attention**](https://arxiv.org/abs/2012.15024), *Chuxiong Sun, Guoshi Wu*\n\n1. [**The Photoswitch Dataset: A Molecular Machine Learning Benchmark for the Advancement of Synthetic Chemistry**](https://arxiv.org/abs/2008.03226), *Aditya R. Thawani, Ryan-Rhys Griffiths, Arian Jamasb, Anthony Bourached, Penelope Jones, William McCorkindale, Alexander A. Aldrick, Alpha A. Lee*\n\n1. 
[**A community-powered search of machine learning strategy space to find NMR property prediction models**](https://arxiv.org/abs/2008.05994), *Lars A. Bratholm, Will Gerrard, Brandon Anderson, Shaojie Bai, Sunghwan Choi, Lam Dang, Pavel Hanchar, Addison Howard, Guillaume Huard, Sanghoon Kim, Zico Kolter, Risi Kondor, Mordechai Kornbluth, Youhan Lee, Youngsoo Lee, Jonathan P. Mailoa, Thanh Tu Nguyen, Milos Popovic, Goran Rakocevic, Walter Reade, Wonho Song, Luka Stojanovic, Erik H. Thiede, Nebojsa Tijanic, Andres Torrubia, Devin Willmott, Craig P. Butts, David R. Glowacki, Kaggle participants*\n\n1. [**Adaptive Layout Decomposition with Graph Embedding Neural Networks**](http://www.cse.cuhk.edu.hk/~byu/papers/C98-DAC2020-MPL-Selector.pdf), *Wei Li, Jialu Xia, Yuzhe Ma, Jialu Li, Yibo Lin, Bei Yu*, DAC'20\n\n1. [**Transfer Learning with Graph Neural Networks for Optoelectronic Properties of Conjugated Oligomers**](https://aip.scitation.org/doi/10.1063/5.0037863), J. Chem. Phys. 154, *Chee-Kong Lee, Chengqiang Lu, Yue Yu, Qiming Sun, Chang-Yu Hsieh, Shengyu Zhang, Qi Liu, and  Liang Shi*\n\n1. [**Jet tagging in the Lund plane with graph networks**](https://link.springer.com/article/10.1007/JHEP03(2021)052), Journal of High Energy Physics 2021, *Frédéric A. Dreyer and Huilin Qu*\n\n1. [**Global Attention Improves Graph Networks Generalization**](https://arxiv.org/abs/2006.07846), *Omri Puny, Heli Ben-Hamu, and Yaron Lipman*\n\n1. [**Learning over Families of Sets -- Hypergraph Representation Learning for Higher Order Tasks**](https://arxiv.org/abs/2101.07773), SDM 2021, *Balasubramaniam Srinivasan, Da Zheng, and George Karypis*\n\n1. [**SSFG: Stochastically Scaling Features and Gradients for Regularizing Graph Convolution Networks**](https://arxiv.org/abs/2102.10338), *Haimin Zhang, Min Xu*\n\n1. 
[**Application and evaluation of knowledge graph embeddings in biomedical data**](https://peerj.com/articles/cs-341/), PeerJ Computer Science 7:e341, *Mona Alshahrani, Maha A. Thafar, Magbubah Essack*\n\n1. [**MoTSE: an interpretable task similarity estimator for small molecular property prediction tasks**](https://www.biorxiv.org/content/10.1101/2021.01.13.426608v2), bioRxiv 2021.01.13.426608, *Han Li, Xinyi Zhao, Shuya Li, Fangping Wan, Dan Zhao, Jianyang Zeng*\n\n1. [**Reinforcement Learning For Data Poisoning on Graph Neural Networks**](https://arxiv.org/abs/2102.06800), *Jacob Dineen, A S M Ahsan-Ul Haque, Matthew Bielskas*\n\n1. [**Generalising Recursive Neural Models by Tensor Decomposition**](https://github.com/danielecastellana22/tensor-tree-nn), IJCNN'20, *Daniele Castellana, Davide Bacciu*\n\n1. [**Tensor Decompositions in Recursive Neural Networks for Tree-Structured Data**](https://github.com/danielecastellana22/tensor-tree-nn), ESANN'20, *Daniele Castellana, Davide Bacciu*\n\n1. [**Combining Self-Organizing and Graph Neural Networks for Modeling Deformable Objects in Robotic Manipulation**](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7806087/), Frontiers in Robotics and AI, *Valencia, Angel J., and Pierre Payeur*\n\n1. [**Joint stroke classification and text line grouping in online handwritten documents with edge pooling attention networks**](https://www.sciencedirect.com/science/article/abs/pii/S0031320321000467), Pattern Recognition, *Jun-Yu Ye, Yan-Ming Zhang, Qing Yang, Cheng-Lin Liu*\n\n1. [**Toward Accurate Predictions of Atomic Properties via Quantum Mechanics Descriptors Augmented Graph Convolutional Neural Network: Application of This Novel Approach in NMR Chemical Shifts Predictions**](https://pubs.acs.org/doi/full/10.1021/acs.jpclett.0c02654), The Journal of Physical Chemistry Letters, *Peng Gao, Jie Zhang, Yuzhu Sun, and Jianguo Yu*\n\n1. 
[**A Graph Neural Network to Model User Comfort in Robot Navigation**](https://arxiv.org/abs/2102.08863), *Pilar Bachiller, Daniel Rodriguez-Criado, Ronit R. Jorvekar, Pablo Bustos, Diego R. Faria, Luis J. Manso*\n\n1. [**Medical Entity Disambiguation Using Graph Neural Networks**](https://arxiv.org/abs/2104.01488), *Alina Vretinaris, Chuan Lei, Vasilis Efthymiou, Xiao Qin, Fatma Özcan*\n\n1. [**Chemistry-informed Macromolecule Graph Representation for Similarity Computation and Supervised Learning**](https://arxiv.org/abs/2103.02565), *Somesh Mohapatra, Joyce An, Rafael Gómez-Bombarelli*\n\n1. [**Characterizing and Forecasting User Engagement with In-app Action Graph: A Case Study of Snapchat**](https://arxiv.org/pdf/1906.00355.pdf), *Yozen Liu, Xiaolin Shi, Lucas Pierce, Xiang Ren*\n\n1. [**GIPA: General Information Propagation Algorithm for Graph Learning**](https://arxiv.org/abs/2105.06035), *Qinkai Zheng, Houyi Li, Peng Zhang, Zhixiong Yang, Guowei Zhang, Xintan Zeng, Yongchao Liu*\n\n1. [**Graph Ensemble Learning over Multiple Dependency Trees for Aspect-level Sentiment Classification**](https://arxiv.org/abs/2103.11794), NAACL'21, *Xiaochen Hou, Peng Qi, Guangtao Wang, Rex Ying, Jing Huang, Xiaodong He, Bowen Zhou*\n\n1. [**Enhancing Scientific Papers Summarization with Citation Graph**](https://arxiv.org/abs/2104.03057), AAAI'21, *Chenxin An, Ming Zhong, Yiran Chen, Danqing Wang, Xipeng Qiu, Xuanjing Huang*\n\n1. [**Improving Graph Representation Learning by Contrastive Regularization**](https://arxiv.org/pdf/2101.11525.pdf), *Kaili Ma, Haochen Yang, Han Yang, Tatiana Jin, Pengfei Chen, Yongqiang Chen, Barakeel Fanseu Kamhoua, James Cheng*\n\n1. [**Extract the Knowledge of Graph Neural Networks and Go Beyond it: An Effective Knowledge Distillation Framework**](https://arxiv.org/pdf/2103.02885.pdf), WWW'21, *Cheng Yang, Jiawei Liu, Chuan Shi*\n\n1. 
[**VIKING: Adversarial Attack on Network Embeddings via Supervised Network Poisoning**](https://arxiv.org/pdf/2102.07164.pdf), PAKDD'21, *Viresh Gupta, Tanmoy Chakraborty*\n\n1. [**Knowledge Graph Embedding using Graph Convolutional Networks with Relation-Aware Attention**](https://arxiv.org/pdf/2102.07200.pdf), *Nasrullah Sheikh, Xiao Qin, Berthold Reinwald, Christoph Miksovic, Thomas Gschwind, Paolo Scotton*\n\n1. [**SLAPS: Self-Supervision Improves Structure Learning for Graph Neural Networks**](https://arxiv.org/pdf/2102.05034.pdf), *Bahare Fatemi, Layla El Asri, Seyed Mehran Kazemi*\n\n1. [**Finding Needles in Heterogeneous Haystacks**](https://homepage.divms.uiowa.edu/~badhikari/assets/doc/papers/CONGCNIAAI2021.pdf), AAAI'21, *Bijaya Adhikari, Liangyue Li, Nikhil Rao, Karthik Subbian*\n\n1. [**RetCL: A Selection-based Approach for Retrosynthesis via Contrastive Learning**](https://arxiv.org/abs/2105.00795), IJCAI 2021, *Hankook Lee, Sungsoo Ahn, Seung-Woo Seo, You Young Song, Eunho Yang, Sung-Ju Hwang, Jinwoo Shin*\n\n1. [**Accurate Prediction of Free Solvation Energy of Organic Molecules via Graph Attention Network and Message Passing Neural Network from Pairwise Atomistic Interactions**](https://arxiv.org/abs/2105.02048), *Ramin Ansari, Amirata Ghorbani*\n\n1. [**DIPS-Plus: The Enhanced Database of Interacting Protein Structures for Interface Prediction**](https://arxiv.org/abs/2106.04362), *Alex Morehead, Chen Chen, Ada Sedova, Jianlin Cheng*\n\n1. [**Coreference-Aware Dialogue Summarization**](https://arxiv.org/abs/2106.08556), SIGDIAL'21, *Zhengyuan Liu, Ke Shi, Nancy F. Chen*\n\n1. [**Document Structure aware Relational Graph Convolutional Networks for Ontology Population**](https://arxiv.org/abs/2104.12950), arXiv, *Abhay M Shalghar, Ayush Kumar, Balaji Ganesan, Aswin Kannan, Shobha G*\n\n1. 
[**Covid-19 Detection from Chest X-ray and Patient Metadata using Graph Convolutional Neural Networks**](https://arxiv.org/abs/2105.09720), *Thosini Bamunu Mudiyanselage, Nipuna Senanayake, Chunyan Ji, Yi Pan, Yanqing Zhang*\n\n1. [**Rossmann-toolbox: a deep learning-based protocol for the prediction and design of cofactor specificity in Rossmann fold proteins**](https://academic.oup.com/bib/advance-article/doi/10.1093/bib/bbab371/6375059), Briefings in Bioinformatics, *Kamil Kaminski, Jan Ludwiczak, Maciej Jasinski, Adriana Bukala, Rafal Madaj, Krzysztof Szczepaniak, Stanislaw Dunin-Horkawicz*\n\n1. [**LGESQL: Line Graph Enhanced Text-to-SQL Model with Mixed Local and Non-Local Relations**](https://arxiv.org/pdf/2106.01093.pdf), ACL'21, *Ruisheng Cao, Lu Chen, Zhi Chen, Yanbin Zhao, Su Zhu, Kai Yu*\n\n1. [**Enhancing Graph Neural Networks via auxiliary training for semi-supervised node classification**](https://www.sciencedirect.com/science/article/pii/S0950705121001477), Knowledge-Based System'21, *Yao Wu, Yu Song, Hong Huang, Fanghua Ye, Xing Xie, Hai Jin*\n\n1. [**Modeling Graph Node Correlations with Neighbor Mixture Models**](https://arxiv.org/pdf/2103.15966.pdf), *Linfeng Liu, Michael C. Hughes, Li-Ping Liu*\n\n1. [**COMBINING PHYSICS AND MACHINE LEARNING FOR NETWORK FLOW ESTIMATION**](https://openreview.net/pdf/9dc2744a465941220de07cf308acf822ec8aaa64.pdf), ICLR'21, *Arlei Silva, Furkan Kocayusufoglu, Saber Jafarpour, Francesco Bullo, Ananthram Swami, Ambuj Singh*\n\n1. [**A Classification Method for Academic Resources Based on a Graph Attention Network**](https://www.mdpi.com/1999-5903/13/3/64/htm), Future Internet'21, *Jie Yu, Yaliu Li, Chenle Pan and Junwei Wang*\n\n1. [**Large Graph Convolutional Network Training with GPU-Oriented Data Communication Architecture**](https://arxiv.org/abs/2103.03330), *Seung Won Min, Kun Wu, Sitao Huang, Mert Hidayetoğlu, Jinjun Xiong, Eiman Ebrahimi, Deming Chen, Wen-mei Hwu*\n\n1. 
[**Graph Attention Multi-Layer Perceptron**](https://github.com/PKU-DAIR/GAMLP/blob/main/GAMLP.pdf), *Wentao Zhang, Ziqi Yin, Zeang Sheng, Wen Ouyang, Xiaosen Li, Yangyu Tao, Zhi Yang, Bin Cui*\n\n1. [**GNNLens: A Visual Analytics Approach for Prediction Error Diagnosis of Graph Neural Networks**](https://arxiv.org/abs/2011.11048v5), *Zhihua Jin, Yong Wang, Qianwen Wang, Yao Ming, Tengfei Ma, Huamin Qu*\n\n1. [**How Attentive are Graph Attention Networks?**](https://arxiv.org/pdf/2105.14491.pdf), *Shaked Brody, Uri Alon, Eran Yahav*, [code](https://github.com/tech-srl/how_attentive_are_gats)\n\n1. [**SCENE: Reasoning about Traffic Scenes using Heterogeneous Graph Neural Networks**](https://arxiv.org/pdf/2301.03512.pdf), *Thomas Monninger\\*, Julian Schmidt\\*, Jan Rupprecht, David Raba, Julian Jordan, Daniel Frank, Steffen Staab, Klaus Dietmayer*, [code](https://github.com/schmidt-ju/scene), \\*co-first authors\n\n</details>\n\n## Contributing\n\nPlease let us know if you encounter a bug or have any suggestions by [filing an issue](https://github.com/dmlc/dgl/issues).\n\nWe welcome all contributions from bug fixes to new features and extensions.\n\nWe expect all contributions to be discussed in the issue tracker and to go through PRs.  
Please refer to our [contribution guide](https://docs.dgl.ai/contribute.html).\n\n## Cite\n\nIf you use DGL in a scientific publication, we would appreciate citations to the following paper:\n```\n@article{wang2019dgl,\n    title={Deep Graph Library: A Graph-Centric, Highly-Performant Package for Graph Neural Networks},\n    author={Minjie Wang and Da Zheng and Zihao Ye and Quan Gan and Mufei Li and Xiang Song and Jinjing Zhou and Chao Ma and Lingfan Yu and Yu Gai and Tianjun Xiao and Tong He and George Karypis and Jinyang Li and Zheng Zhang},\n    year={2019},\n    journal={arXiv preprint arXiv:1909.01315}\n}\n```\n\n## The Team\n\nDGL is developed and maintained by [NYU, NYU Shanghai, AWS Shanghai AI Lab, and AWS MXNet Science Team](https://www.dgl.ai/pages/about.html).\n\n## License\n\nDGL uses Apache License 2.0.\n"
  },
  {
    "path": "apps/life_sci/README.md",
    "content": "# DGL-LifeSci\n\nDGL-LifeSci is moved [here](https://github.com/awslabs/dgl-lifesci).\n"
  },
  {
    "path": "benchmarks/.gitignore",
    "content": "html\nresults\n"
  },
  {
    "path": "benchmarks/Jenkinsfile",
    "content": "pipeline {\n    triggers {\n        issueCommentTrigger('@dgl-bot .*')\n    }\n    agent {\n        docker {\n            label 'linux-benchmark-node'\n            image 'dgllib/dgl-ci-lint'\n            alwaysPull true\n        }\n    }\n    stages {\n        stage('Regression Test') {\n            steps {\n                checkout scm\n                script {\n                    def commentTriggerCause = currentBuild.getBuildCauses('org.jenkinsci.plugins.pipeline.github.trigger.IssueCommentCause')\n                    def prOpenTriggerCause = currentBuild.getBuildCauses('jenkins.branch.BranchEventCause')\n                    def realTriggerCause = currentBuild.getBuildCauses()\n                    echo(\"BUILD CAUSE: ${realTriggerCause.toString()}\")\n\n                    if (commentTriggerCause) {\n                        dir('benchmark_scripts_repo') {\n                            checkout([$class: 'GitSCM', branches: [[name: '*/master']],\n                                userRemoteConfigs: [[credentialsId: 'github', url: 'https://github.com/dglai/DGL_scripts.git']]])\n                        }\n                        sh('cp benchmark_scripts_repo/benchmark/* benchmarks/scripts/')\n                        def comment = env.GITHUB_COMMENT\n                        def author = env.GITHUB_COMMENT_AUTHOR\n                        def authorized_user = ['VoVAllen', 'BarclayII', 'jermainewang', 'zheng-da', 'mufeili']\n                        def isauthorized = author in authorized_user\n                        def command_lists = comment.split(' ')\n                        def instance_type = command_lists[2].replace('.', \"\")\n                        if (!isauthorized) {\n                            error(\"Not authorized to launch regression tests\")\n                        }\n                        if (command_lists.size() != 5) {\n                            pullRequest.comment('Cannot run the regression test due to unknown command')\n       
                     error('Unknown command')\n                        } else {\n                            pullRequest.comment(\"Start the Regression test. View at ${RUN_DISPLAY_URL}\")\n                        }\n                        dir('benchmarks/scripts') {\n                            sh('python3 -m pip install boto3')\n                            sh(\"PYTHONUNBUFFERED=1 GIT_URL=${env.GIT_URL} GIT_BRANCH=${env.CHANGE_BRANCH} python3 run_reg_test.py --data-folder ${env.GIT_COMMIT}_${instance_type} --run-cmd '${comment}'\")\n                        }\n                        pullRequest.comment(\"Finished the Regression test. Result table is at https://dgl-asv-data.s3-us-west-2.amazonaws.com/${env.GIT_COMMIT}_${instance_type}/results/result.csv. Jenkins job link is ${RUN_DISPLAY_URL}. \")\n                    } else {\n                        // if (prOpenTriggerCause) {\n                        //     if (env.BUILD_ID == \"1\") {\n                        //         pullRequest.comment('To trigger regression tests: \\n - `@dgl-bot run [instance-type] [which tests] [compare-with-branch]`; \\n For example: `@dgl-bot run g4dn.4xlarge all dmlc/master` or `@dgl-bot run c5.9xlarge kernel,api dmlc/master`')\n                        //     }\n                        // }\n                        echo('Build was not started by a trigger')\n                    }\n                // echo(\"Comment: ${commentTriggerCause.getComment()}\")\n                }\n            }\n            post {\n                failure {\n                    echo '========Regression execution failed========'\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "benchmarks/README.md",
    "content": "DGL Benchmarks\n====\n\nBenchmarking DGL with Airspeed Velocity.\n\nUsage\n---\n\nBefore beginning, ensure that airspeed velocity is installed:\n\n```bash\npip install asv\n```\n\nTo run all benchmarks locally, build the project first and then run:\n\n```bash\nasv run -n -e --python=same --verbose\n```\n\n**Due to ASV's restriction, `--python=same` will not write any benchmark results\nto disk. It does not support specifying branches and commits either. They are only\navailable under ASV's managed environment.**\n\nTo change the device for benchmarking, set the `DGL_BENCH_DEVICE` environment variable.\nAllowed values are `\"cpu\"` or `\"gpu\"`.\n\n```bash\nexport DGL_BENCH_DEVICE=gpu\n```\n\nTo select which benchmark to run, use the `--bench` flag. For example,\n\n```bash\nasv run -n -e --python=same --verbose --bench model_acc.bench_gat\n```\n\nNote that OGB dataset need to be download manually to `/tmp/dataset` folder (i.e. `/tmp/dataset/ogbn-products/`) beforehand. \nYou can do it by runnnig the code below in this folder\n```python\nfrom benchmarks.utils import get_ogb_graph\nget_ogb_graph(\"ogbn-product\")\n```\n\nRun in docker locally\n---\n\nDGL runs all benchmarks automatically in docker container. To run bencmarks in docker locally,\n\n* Git commit your locally changes. No need to push to remote repository.\n* To compare commits from different branches. Change the `\"branches\"` list in `asv.conf.json`.\n  The default is `\"HEAD\"` which is the last commit of the current branch. For example, to\n  compare your proposed changes with the master branch, set it to be `[\"HEAD\", \"master\"]`.\n  If your workspace is a forked repository, make sure your local master has synced with\n  the upstream.\n* Use the `publish.sh` script. It accepts two arguments, a name specifying the identity of\n  the test machine and a device name. 
For example,\n  ```bash\n  bash publish.sh dev-machine gpu\n  ```\n\nThe script will output two folders `results` and `html`. The `html` folder contains the\ngenerated static web pages. View it by:\n\n```bash\nasv preview\n```\n\nPlease see `publish.sh` for more information on how it works and how to modify it according\nto your need.\n\nAdding a new benchmark suite\n---\n\nThe benchmark folder is organized as follows:\n\n```\n|-- benchmarks/\n  |-- model_acc/           # benchmarks for model accuracy\n    |-- bench_gcn.py\n    |-- bench_gat.py\n    |-- bench_sage.py\n    ...\n  |-- model_speed/         # benchmarks for model training speed\n    |-- bench_gat.py\n    |-- bench_sage.py\n    ...\n  ...                      # other types of benchmarks\n|-- html/                  # generated html files\n|-- results/               # generated result files\n|-- asv.conf.json          # asv config file\n|-- build_dgl_asv.sh       # script for building dgl in asv\n|-- install_dgl_asv.sh     # script for installing dgl in asv\n|-- publish.sh             # script for running benchmarks in docker\n|-- README.md              # this readme\n|-- run.sh                 # script for calling asv in docker\n|-- ...                    # other aux files\n```\n\nTo add a new benchmark, pick a suitable benchmark type and create a python script under\nit. We prefer to have the prefix `bench_` in the name. Here is a toy example:\n\n```python\n# bench_range.py\n\nimport time\nfrom .. import utils\n\n@utils.benchmark('time')\n@utils.parametrize('l', [10, 100, 1000])\n@utils.parametrize('u', [10, 100, 1000])\ndef track_time(l, u):\n    t0 = time.time()\n    for i in range(l, u):\n        pass\n    return time.time() - t0\n```\n\n* The main entry point of each benchmark script is a `track_*` function. 
The function\n  can have arbitrary arguments and must return the benchmark result.\n* There are two useful decorators: `utils.benchmark` and `utils.parametrize`.\n* `utils.benchmark` indicates the type of this benchmark. Currently supported types are:\n  `'time'` and `'acc'`. The decorator will perform some necessary setup and finalize\n  steps such as fixing the random seed for the `'acc'` type.\n* `utils.parametrize` specifies the parameters to test.\n  Multiple parametrize decorators mean benchmarking the combination.\n* Check out `model_acc/bench_gcn.py` and `model_speed/bench_sage.py`.\n* ASV's [official guide on writing benchmarks](https://asv.readthedocs.io/en/stable/writing_benchmarks.html)\n  is also very helpful.\n\n\nTips\n----\n* Feed flags `-e --verbose` to `asv run` to print out stderr and more information.\n* When running benchmarks locally (e.g., with `--python=same`), ASV will not write results to disk\n  so `asv publish` will not generate plots.\n* Try to make your benchmarks compatible with all the versions being tested.\n* For ogbn datasets, put the dataset into `/tmp/dataset/`.\n"
  },
  {
    "path": "benchmarks/asv.conf.json",
    "content": "{\n    // The version of the config file format.  Do not change, unless\n    // you know what you are doing.\n    \"version\": 1,\n    // The name of the project being benchmarked\n    \"project\": \"dgl\",\n    // The project's homepage\n    \"project_url\": \"https://www.dgl.ai\",\n    // The URL or local path of the source code repository for the\n    // project being benchmarked\n    \"repo\": \"..\",\n    // The Python project's subdirectory in your repo.  If missing or\n    // the empty string, the project is assumed to be located at the root\n    // of the repository.\n    // \"repo_subdir\": \"python\",\n    // Customizable commands for building, installing, and\n    // uninstalling the project. See asv.conf.json documentation.\n    //\n    \"build_command\": [\n        \"/bin/bash {conf_dir}/scripts/build_dgl_asv.sh\"\n    ],\n    \"install_command\": [\n        \"/bin/bash {conf_dir}/scripts/install_dgl_asv.sh\"\n    ],\n    \"uninstall_command\": [\n        \"return-code=any python -m pip uninstall -y dgl\"\n    ],\n    // List of branches to benchmark. If not provided, defaults to \"master\"\n    // (for git) or \"default\" (for mercurial).\n    \"branches\": [\n        \"HEAD\"\n    ], // for git\n    // The DVCS being used.  If not set, it will be automatically\n    // determined from \"repo\" by looking at the protocol in the URL\n    // (if remote), or by looking for special directories, such as\n    // \".git\" (if local).\n    \"dvcs\": \"git\",\n    // The tool to use to create environments.  
May be \"conda\",\n    // \"virtualenv\" or other value depending on the plugins in use.\n    // If missing or the empty string, the tool will be automatically\n    // determined by looking for tools on the PATH environment\n    // variable.\n    \"environment_type\": \"conda\",\n    // timeout in seconds for installing any dependencies in environment\n    // defaults to 10 min\n    \"install_timeout\": 600,\n    // the base URL to show a commit for the project.\n    // \"show_commit_url\": \"http://github.com/owner/project/commit/\",\n    // The Pythons you'd like to test against.  If not provided, defaults\n    // to the current version of Python used to run `asv`.\n    // \"pythons\": [\"2.7\", \"3.6\"],\n    // The list of conda channel names to be searched for benchmark\n    // dependency packages in the specified order\n    // \"conda_channels\": [\"conda-forge\", \"defaults\"],\n    // The matrix of dependencies to test.  Each key is the name of a\n    // package (in PyPI) and the values are version numbers.  An empty\n    // list or empty string indicates to just test against the default\n    // (latest) version. null indicates that the package is to not be\n    // installed. If the package to be tested is only available from\n    // PyPi, and the 'environment_type' is conda, then you can preface\n    // the package name by 'pip+', and the package will be installed via\n    // pip (with all the conda available packages installed first,\n    // followed by the pip installed packages).\n    //\n    // \"matrix\": {\n    //     \"numpy\": [\"1.6\", \"1.7\"],\n    //     \"six\": [\"\", null],        // test with and without six installed\n    //     \"pip+emcee\": [\"\"],   // emcee is only available for install with pip.\n    // },\n    // Combinations of libraries/python versions can be excluded/included\n    // from the set to test. 
Each entry is a dictionary containing additional\n    // key-value pairs to include/exclude.\n    //\n    // An exclude entry excludes entries where all values match. The\n    // values are regexps that should match the whole string.\n    //\n    // An include entry adds an environment. Only the packages listed\n    // are installed. The 'python' key is required. The exclude rules\n    // do not apply to includes.\n    //\n    // In addition to package names, the following keys are available:\n    //\n    // - python\n    //     Python version, as in the *pythons* variable above.\n    // - environment_type\n    //     Environment type, as above.\n    // - sys_platform\n    //     Platform, as in sys.platform. Possible values for the common\n    //     cases: 'linux2', 'win32', 'cygwin', 'darwin'.\n    //\n    // \"exclude\": [\n    //     {\"python\": \"3.2\", \"sys_platform\": \"win32\"}, // skip py3.2 on windows\n    //     {\"environment_type\": \"conda\", \"six\": null}, // don't run without six on conda\n    // ],\n    //\n    // \"include\": [\n    //     // additional env for python2.7\n    //     {\"python\": \"2.7\", \"numpy\": \"1.8\"},\n    //     // additional env if run on windows+conda\n    //     {\"platform\": \"win32\", \"environment_type\": \"conda\", \"python\": \"2.7\", \"libpython\": \"\"},\n    // ],\n    // The directory (relative to the current directory) that benchmarks are\n    // stored in.  If not provided, defaults to \"benchmarks\"\n    // \"benchmark_dir\": \"benchmarks\",\n    // The directory (relative to the current directory) to cache the Python\n    // environments in.  If not provided, defaults to \"env\"\n    \"env_dir\": \"env\",\n    // The directory (relative to the current directory) that raw benchmark\n    // results are stored in.  If not provided, defaults to \"results\".\n    \"results_dir\": \"results\",\n    // The directory (relative to the current directory) that the html tree\n    // should be written to.  
If not provided, defaults to \"html\".\n    \"html_dir\": \"html\",\n    // The number of characters to retain in the commit hashes.\n    // \"hash_length\": 8,\n    // `asv` will cache results of the recent builds in each\n    // environment, making them faster to install next time.  This is\n    // the number of builds to keep, per environment.\n    // \"build_cache_size\": 2,\n    // The commits after which the regression search in `asv publish`\n    // should start looking for regressions. Dictionary whose keys are\n    // regexps matching to benchmark names, and values corresponding to\n    // the commit (exclusive) after which to start looking for\n    // regressions.  The default is to start from the first commit\n    // with results. If the commit is `null`, regression detection is\n    // skipped for the matching benchmark.\n    //\n    // \"regressions_first_commits\": {\n    //    \"some_benchmark\": \"352cdf\",  // Consider regressions only after this commit\n    //    \"another_benchmark\": null,   // Skip regression detection altogether\n    // },\n    // The thresholds for relative change in results, after which `asv\n    // publish` starts reporting regressions. Dictionary of the same\n    // form as in ``regressions_first_commits``, with values\n    // indicating the thresholds.  If multiple entries match, the\n    // maximum is taken. If no entry matches, the default is 5%.\n    //\n    // \"regressions_thresholds\": {\n    //    \"some_benchmark\": 0.01,     // Threshold of 1%\n    //    \"another_benchmark\": 0.5,   // Threshold of 50%\n    // },\n}\n"
  },
  {
    "path": "benchmarks/benchmarks/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/benchmarks/api/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/benchmarks/api/bench_add_self_loop.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"graph_name\", [\"cora\", \"livejournal\"])\n@utils.parametrize(\"format\", [\"coo\"])\ndef track_time(graph_name, format):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n\n    # dry run\n    for i in range(3):\n        g = graph.add_self_loop()\n\n    # timing\n\n    with utils.Timer() as t:\n        for i in range(3):\n            edges = graph.add_self_loop()\n\n    return t.elapsed_secs / 3\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_batch.py",
    "content": "import time\n\nimport dgl\n\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"batch_size\", [4, 32, 256, 1024])\ndef track_time(batch_size):\n    device = utils.get_bench_device()\n    ds = dgl.data.QM7bDataset()\n    # prepare graph\n    graphs = []\n    for graph in ds[0:batch_size][0]:\n        g = graph.to(device)\n        graphs.append(g)\n\n    # dry run\n    for i in range(10):\n        g = dgl.batch(graphs)\n\n    # timing\n\n    with utils.Timer() as t:\n        for i in range(100):\n            g = dgl.batch(graphs)\n\n    return t.elapsed_secs / 100\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_builtin_apply_edges.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=600)\n@utils.parametrize(\"graph_name\", [\"cora\", \"ogbn-arxiv\"])\n@utils.parametrize(\"format\", [\"coo\", \"csr\"])\n@utils.parametrize(\"feat_size\", [8, 128, 512])\n@utils.parametrize(\"reduce_type\", [\"u->e\", \"u+v\"])\ndef track_time(graph_name, format, feat_size, reduce_type):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n    graph.ndata[\"h\"] = torch.randn(\n        (graph.num_nodes(), feat_size), device=device\n    )\n\n    reduce_builtin_dict = {\n        \"u->e\": fn.copy_u(\"h\", \"x\"),\n        \"u+v\": fn.u_add_v(\"h\", \"h\", \"x\"),\n    }\n\n    # dry run\n    for i in range(3):\n        graph.apply_edges(reduce_builtin_dict[reduce_type])\n\n    # timing\n\n    with utils.Timer() as t:\n        for i in range(10):\n            graph.apply_edges(reduce_builtin_dict[reduce_type])\n\n    return t.elapsed_secs / 10\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_builtin_apply_edges_hetero.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=600)\n@utils.parametrize(\"num_relations\", [5, 50, 500])\n@utils.parametrize(\"format\", [\"coo\", \"csr\"])\n@utils.parametrize(\"feat_size\", [8, 128, 512])\n@utils.parametrize(\"reduce_type\", [\"u->e\"])  # , 'e->u'])\ndef track_time(num_relations, format, feat_size, reduce_type):\n    device = utils.get_bench_device()\n    dd = {}\n    candidate_edges = [\n        dgl.data.CoraGraphDataset(verbose=False)[0].edges(),\n        dgl.data.PubmedGraphDataset(verbose=False)[0].edges(),\n        dgl.data.CiteseerGraphDataset(verbose=False)[0].edges(),\n    ]\n    for i in range(num_relations):\n        dd[(\"n1\", \"e_{}\".format(i), \"n2\")] = candidate_edges[\n            i % len(candidate_edges)\n        ]\n    graph = dgl.heterograph(dd)\n\n    graph = graph.to(device)\n    graph.nodes[\"n1\"].data[\"h\"] = torch.randn(\n        (graph.num_nodes(\"n1\"), feat_size), device=device\n    )\n    graph.nodes[\"n2\"].data[\"h\"] = torch.randn(\n        (graph.num_nodes(\"n2\"), feat_size), device=device\n    )\n\n    reduce_builtin_dict = {\n        \"u->e\": fn.copy_u(\"h\", \"x\"),\n        # 'e->u': fn.copy_e('h', 'x'),\n    }\n\n    # dry run\n    for i in range(3):\n        graph.apply_edges(reduce_builtin_dict[reduce_type])\n\n    # timing\n\n    with utils.Timer() as t:\n        for i in range(10):\n            graph.apply_edges(reduce_builtin_dict[reduce_type])\n\n    return t.elapsed_secs / 10\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_builtin_multi_update_all.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=600)\n@utils.parametrize(\"feat_size\", [32, 128, 512])\n@utils.parametrize(\"num_relations\", [5, 50, 500])\n@utils.parametrize(\"multi_reduce_type\", [\"sum\", \"stack\"])\ndef track_time(feat_size, num_relations, multi_reduce_type):\n    device = utils.get_bench_device()\n    dd = {}\n    candidate_edges = [\n        dgl.data.CoraGraphDataset(verbose=False)[0].edges(),\n        dgl.data.PubmedGraphDataset(verbose=False)[0].edges(),\n        dgl.data.CiteseerGraphDataset(verbose=False)[0].edges(),\n    ]\n    for i in range(num_relations):\n        dd[(\"n1\", \"e_{}\".format(i), \"n2\")] = candidate_edges[\n            i % len(candidate_edges)\n        ]\n    graph = dgl.heterograph(dd)\n\n    graph = graph.to(device)\n    graph.nodes[\"n1\"].data[\"h\"] = torch.randn(\n        (graph.num_nodes(\"n1\"), feat_size), device=device\n    )\n    graph.nodes[\"n2\"].data[\"h\"] = torch.randn(\n        (graph.num_nodes(\"n2\"), feat_size), device=device\n    )\n\n    # dry run\n    update_dict = {}\n    for i in range(num_relations):\n        update_dict[\"e_{}\".format(i)] = (fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n    graph.multi_update_all(update_dict, multi_reduce_type)\n\n    # timing\n\n    with utils.Timer() as t:\n        for i in range(3):\n            graph.multi_update_all(update_dict, multi_reduce_type)\n\n    return t.elapsed_secs / 3\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_builtin_update_all_coo.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=600)\n@utils.parametrize(\"graph_name\", [\"ogbn-arxiv\"])\n@utils.parametrize(\"format\", [\"coo\"])\n@utils.parametrize(\"feat_size\", [4, 32, 256])\n@utils.parametrize(\"msg_type\", [\"copy_u\", \"u_mul_e\"])\n@utils.parametrize(\"reduce_type\", [\"sum\", \"mean\", \"max\"])\ndef track_time(graph_name, format, feat_size, msg_type, reduce_type):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n    graph.ndata[\"h\"] = torch.randn(\n        (graph.num_nodes(), feat_size), device=device\n    )\n    graph.edata[\"e\"] = torch.randn((graph.num_edges(), 1), device=device)\n\n    msg_builtin_dict = {\n        \"copy_u\": fn.copy_u(\"h\", \"x\"),\n        \"u_mul_e\": fn.u_mul_e(\"h\", \"e\", \"x\"),\n    }\n\n    reduce_builtin_dict = {\n        \"sum\": fn.sum(\"x\", \"h_new\"),\n        \"mean\": fn.mean(\"x\", \"h_new\"),\n        \"max\": fn.max(\"x\", \"h_new\"),\n    }\n\n    # dry run\n    graph.update_all(\n        msg_builtin_dict[msg_type], reduce_builtin_dict[reduce_type]\n    )\n\n    # timing\n\n    with utils.Timer() as t:\n        for i in range(3):\n            graph.update_all(\n                msg_builtin_dict[msg_type], reduce_builtin_dict[reduce_type]\n            )\n\n    return t.elapsed_secs / 3\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_builtin_update_all_csc.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=600)\n@utils.parametrize(\"graph_name\", [\"ogbn-arxiv\", \"reddit\", \"ogbn-proteins\"])\n@utils.parametrize(\"format\", [\"csc\"])\n@utils.parametrize(\"feat_size\", [4, 32, 256])\n@utils.parametrize(\"msg_type\", [\"copy_u\", \"u_mul_e\"])\n@utils.parametrize(\"reduce_type\", [\"sum\", \"mean\", \"max\"])\ndef track_time(graph_name, format, feat_size, msg_type, reduce_type):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n    graph.ndata[\"h\"] = torch.randn(\n        (graph.num_nodes(), feat_size), device=device\n    )\n    graph.edata[\"e\"] = torch.randn((graph.num_edges(), 1), device=device)\n\n    msg_builtin_dict = {\n        \"copy_u\": fn.copy_u(\"h\", \"x\"),\n        \"u_mul_e\": fn.u_mul_e(\"h\", \"e\", \"x\"),\n    }\n\n    reduce_builtin_dict = {\n        \"sum\": fn.sum(\"x\", \"h_new\"),\n        \"mean\": fn.mean(\"x\", \"h_new\"),\n        \"max\": fn.max(\"x\", \"h_new\"),\n    }\n\n    # dry run\n\n    for i in range(3):\n        graph.update_all(\n            msg_builtin_dict[msg_type], reduce_builtin_dict[reduce_type]\n        )\n\n    # timing\n\n    with utils.Timer() as t:\n        for i in range(10):\n            graph.update_all(\n                msg_builtin_dict[msg_type], reduce_builtin_dict[reduce_type]\n            )\n\n    return t.elapsed_secs / 10\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_edge_ids.py",
    "content": "import time\n\nimport dgl\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n# edge_ids is not supported on cuda\n# @utils.skip_if_gpu()\n@utils.benchmark(\"time\", timeout=1200)\n@utils.parametrize_cpu(\"graph_name\", [\"cora\", \"livejournal\", \"friendster\"])\n@utils.parametrize_gpu(\"graph_name\", [\"cora\", \"livejournal\"])\n@utils.parametrize(\"format\", [\"coo\", \"csr\", \"csc\"])\n@utils.parametrize(\"fraction\", [0.01, 0.1])\n@utils.parametrize(\"return_uv\", [True, False])\ndef track_time(graph_name, format, fraction, return_uv):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    coo_graph = utils.get_graph(graph_name, \"coo\")\n    graph = graph.to(device)\n    eids = np.random.choice(\n        np.arange(graph.num_edges(), dtype=np.int64),\n        int(graph.num_edges() * fraction),\n    )\n    eids = torch.tensor(eids, device=\"cpu\", dtype=torch.int64)\n    u, v = coo_graph.find_edges(eids)\n    del coo_graph, eids\n    u = u.to(device)\n    v = v.to(device)\n    # dry run\n    for i in range(10):\n        out = graph.edge_ids(u[0], v[0])\n\n    # timing\n\n    with utils.Timer() as t:\n        for i in range(3):\n            edges = graph.edge_ids(u, v, return_uv=return_uv)\n\n    return t.elapsed_secs / 3\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_edge_subgraph.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"graph_name\", [\"livejournal\", \"reddit\"])\n@utils.parametrize(\"format\", [\"coo\"])\n@utils.parametrize(\"seed_egdes_num\", [500, 5000, 50000])\ndef track_time(graph_name, format, seed_egdes_num):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n\n    seed_edges = np.random.randint(0, graph.num_edges(), seed_egdes_num)\n    seed_edges = torch.from_numpy(seed_edges).to(device)\n\n    # dry run\n    for i in range(3):\n        dgl.edge_subgraph(graph, seed_edges)\n\n    # timing\n    num_iters = 50\n    with utils.Timer() as t:\n        for i in range(num_iters):\n            dgl.edge_subgraph(graph, seed_edges)\n\n    return t.elapsed_secs / num_iters\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_find_edges.py",
    "content": "import time\n\nimport dgl\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=600)\n@utils.parametrize_cpu(\"graph_name\", [\"cora\", \"livejournal\", \"friendster\"])\n@utils.parametrize_gpu(\"graph_name\", [\"cora\", \"livejournal\"])\n@utils.parametrize(\"format\", [\"coo\"])  # csc is not supported\n@utils.parametrize(\"fraction\", [0.01, 0.1])\ndef track_time(graph_name, format, fraction):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n    eids = np.random.choice(\n        np.arange(graph.num_edges(), dtype=np.int64),\n        int(graph.num_edges() * fraction),\n    )\n    eids = torch.tensor(eids, device=device, dtype=torch.int64)\n    # dry run\n    for i in range(10):\n        out = graph.find_edges(i)\n        out = graph.find_edges(\n            torch.arange(i * 10, dtype=torch.int64, device=device)\n        )\n\n    # timing\n\n    with utils.Timer() as t:\n        for i in range(10):\n            edges = graph.find_edges(eids)\n\n    return t.elapsed_secs / 10\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_format_conversion.py",
    "content": "import time\n\nimport dgl\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=600)\n@utils.parametrize_cpu(\n    \"graph_name\", [\"cora\", \"pubmed\", \"ogbn-arxiv\", \"livejournal\", \"friendster\"]\n)\n@utils.parametrize_gpu(\"graph_name\", [\"cora\", \"livejournal\"])\n@utils.parametrize(\n    \"format\",\n    [\n        (\"coo\", \"csc\"),\n        (\"csc\", \"coo\"),\n        (\"coo\", \"csr\"),\n        (\"csr\", \"coo\"),\n        (\"csr\", \"csc\"),\n        (\"csc\", \"csr\"),\n    ],\n)\ndef track_time(graph_name, format):\n    from_format, to_format = format\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, from_format)\n    graph = graph.to(device)\n    if format == (\"coo\", \"csr\") and graph_name == \"friendster\":\n        # Mark graph as sorted to check performance for COO matrix marked as\n        # sorted. Note that friendster dataset is already sorted.\n        graph = dgl.graph(graph.edges(), row_sorted=True)\n    graph = graph.formats([from_format])\n    # dry run\n    graph.formats([to_format])\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(10):\n            gg = graph.formats([to_format])\n\n    return t.elapsed_secs / 10\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_fused_sample_neighbors.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize_cpu(\"graph_name\", [\"livejournal\", \"reddit\"])\n@utils.parametrize_gpu(\"graph_name\", [\"ogbn-arxiv\", \"reddit\"])\n@utils.parametrize(\"format\", [\"csr\", \"csc\"])\n@utils.parametrize(\"seed_nodes_num\", [200, 5000, 20000])\n@utils.parametrize(\"fanout\", [5, 20, 40])\ndef track_time(graph_name, format, seed_nodes_num, fanout):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format).to(device)\n\n    edge_dir = \"in\" if format == \"csc\" else \"out\"\n    seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)\n    seed_nodes = torch.from_numpy(seed_nodes).to(device)\n\n    # dry run\n    for i in range(3):\n        dgl.sampling.sample_neighbors_fused(\n            graph, seed_nodes, fanout, edge_dir=edge_dir\n        )\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(50):\n            dgl.sampling.sample_neighbors_fused(\n                graph, seed_nodes, fanout, edge_dir=edge_dir\n            )\n\n    return t.elapsed_secs / 50\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_heterograph_construction.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"num_relations\", [5, 50, 500])\ndef track_time(num_relations):\n    dd = {}\n    candidate_edges = [\n        dgl.data.CoraGraphDataset(verbose=False)[0].edges(),\n        dgl.data.PubmedGraphDataset(verbose=False)[0].edges(),\n        dgl.data.CiteseerGraphDataset(verbose=False)[0].edges(),\n    ]\n    for i in range(num_relations):\n        dd[(\"n1\", \"e_{}\".format(i), \"n2\")] = candidate_edges[\n            i % len(candidate_edges)\n        ]\n\n    # dry run\n    graph = dgl.heterograph(dd)\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(3):\n            graph = dgl.heterograph(dd)\n\n    return t.elapsed_secs / 3\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_homograph_edge_construction.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.skip_if_gpu()\n@utils.benchmark(\"time\")\n@utils.parametrize(\"size\", [\"small\", \"large\"])\ndef track_time(size):\n    edge_list = {\n        \"small\": dgl.data.CiteseerGraphDataset(verbose=False)[0].edges(),\n        \"large\": utils.get_livejournal().edges(),\n    }\n\n    # dry run\n    dgl.graph(edge_list[size])\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(10):\n            g = dgl.graph(edge_list[size])\n\n    return t.elapsed_secs / 10\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_homograph_scipy_construction.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.skip_if_gpu()\n@utils.benchmark(\"time\")\n@utils.parametrize(\"size\", [\"small\", \"large\"])\n@utils.parametrize(\"scipy_format\", [\"coo\", \"csr\"])\ndef track_time(size, scipy_format):\n    matrix_dict = {\n        \"small\": dgl.data.CiteseerGraphDataset(verbose=False)[0].adj_external(\n            scipy_fmt=scipy_format\n        ),\n        \"large\": utils.get_livejournal().adj_external(scipy_fmt=scipy_format),\n    }\n\n    # dry run\n    dgl.from_scipy(matrix_dict[size])\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(3):\n            dgl.from_scipy(matrix_dict[size])\n\n    return t.elapsed_secs / 3\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_in_degrees.py",
    "content": "import time\n\nimport dgl\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=1200)\n@utils.parametrize_cpu(\"graph_name\", [\"cora\", \"livejournal\", \"friendster\"])\n@utils.parametrize_gpu(\"graph_name\", [\"cora\", \"livejournal\"])\n# in_degrees on coo is not supported on cuda\n@utils.parametrize_cpu(\"format\", [\"coo\", \"csc\"])\n@utils.parametrize_gpu(\"format\", [\"csc\"])\n@utils.parametrize(\"fraction\", [0.01, 0.1])\ndef track_time(graph_name, format, fraction):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n    nids = np.random.choice(\n        np.arange(graph.num_nodes(), dtype=np.int64),\n        int(graph.num_nodes() * fraction),\n    )\n    nids = torch.tensor(nids, device=device, dtype=torch.int64)\n\n    # dry run\n    for i in range(10):\n        out = graph.in_degrees(i)\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(10):\n            edges = graph.in_degrees(nids)\n\n    return t.elapsed_secs / 10\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_in_edges.py",
    "content": "import time\n\nimport dgl\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=1200)\n@utils.parametrize_cpu(\"graph_name\", [\"cora\", \"livejournal\", \"friendster\"])\n@utils.parametrize_gpu(\"graph_name\", [\"cora\", \"livejournal\"])\n# in_edges on coo is not supported on cuda\n@utils.parametrize_cpu(\"format\", [\"coo\", \"csc\"])\n@utils.parametrize_gpu(\"format\", [\"csc\"])\n@utils.parametrize(\"fraction\", [0.01, 0.1])\ndef track_time(graph_name, format, fraction):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n\n    graph = graph.to(device)\n    nids = np.random.choice(\n        np.arange(graph.num_nodes(), dtype=np.int64),\n        int(graph.num_nodes() * fraction),\n    )\n    nids = torch.tensor(nids, device=device, dtype=torch.int64)\n\n    # dry run\n    for i in range(10):\n        out = graph.in_edges(i)\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(10):\n            edges = graph.in_edges(nids)\n\n    return t.elapsed_secs / 10\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_in_subgraph.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"graph_name\", [\"livejournal\", \"reddit\"])\n@utils.parametrize(\"format\", [\"csc\"])  # coo is not supported\n@utils.parametrize(\"seed_nodes_num\", [200, 5000, 20000])\ndef track_time(graph_name, format, seed_nodes_num):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n\n    seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)\n    seed_nodes = torch.from_numpy(seed_nodes).to(device)\n\n    # dry run\n    for i in range(3):\n        dgl.in_subgraph(graph, seed_nodes)\n\n    # timing\n    num_iters = 50\n    with utils.Timer() as t:\n        for i in range(num_iters):\n            dgl.in_subgraph(graph, seed_nodes)\n\n    return t.elapsed_secs / num_iters\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_khop.py",
    "content": "import time\n\nimport dgl\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=60)\n@utils.parametrize(\"graph_name\", [\"cora\"])\n@utils.parametrize(\"format\", [\"coo\", \"csr\"])\n@utils.parametrize(\"k\", [1, 3, 5])\ndef track_time(graph_name, format, k):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n    graph = graph.formats([format])\n    # dry run\n    dgl.khop_graph(graph, k)\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(10):\n            gg = dgl.khop_graph(graph, k)\n\n    return t.elapsed_secs / 10\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_knn_graph.py",
    "content": "import time\n\nimport dgl\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=60)\n@utils.parametrize(\"k\", [8, 64])\n@utils.parametrize(\"size\", [1000, 10000])\n@utils.parametrize(\"dim\", [4, 32, 256])\n@utils.parametrize_cpu(\n    \"algorithm\", [\"bruteforce-blas\", \"bruteforce\", \"kd-tree\", \"nn-descent\"]\n)\n@utils.parametrize_gpu(\n    \"algorithm\",\n    [\"bruteforce-blas\", \"bruteforce\", \"bruteforce-sharemem\", \"nn-descent\"],\n)\ndef track_time(size, dim, k, algorithm):\n    device = utils.get_bench_device()\n    features = np.random.RandomState(42).randn(size, dim)\n    feat = torch.tensor(features, dtype=torch.float, device=device)\n    # dry run\n    for i in range(1):\n        dgl.knn_graph(feat, k, algorithm=algorithm)\n    # timing\n    with utils.Timer() as t:\n        for i in range(5):\n            dgl.knn_graph(feat, k, algorithm=algorithm)\n\n    return t.elapsed_secs / 5\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_metis_partition.py",
    "content": "import time\n\nimport dgl\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.skip_if_gpu()\n@utils.benchmark(\"time\", timeout=1200)\n@utils.parametrize(\"graph_name\", [\"reddit\"])\n@utils.parametrize(\"k\", [2, 4, 8])\ndef track_time(graph_name, k):\n    device = utils.get_bench_device()\n    data = utils.process_data(graph_name)\n    graph = data[0]\n    # dry run\n    gg = dgl.transforms.metis_partition(graph, k)\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(3):\n            gg = dgl.transforms.metis_partition(graph, k)\n\n    return t.elapsed_secs / 3\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_nn_graphconv.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import SAGEConv\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"graph_name\", [\"pubmed\", \"ogbn-arxiv\"])\n@utils.parametrize(\"feat_dim\", [4, 32, 256])\n@utils.parametrize(\"aggr_type\", [\"mean\", \"gcn\", \"pool\"])\ndef track_time(graph_name, feat_dim, aggr_type):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name).to(device)\n\n    feat = torch.randn((graph.num_nodes(), feat_dim), device=device)\n    model = SAGEConv(\n        feat_dim, feat_dim, aggr_type, activation=F.relu, bias=False\n    ).to(device)\n\n    # dry run\n    for i in range(3):\n        model(graph, feat)\n    # timing\n    with utils.Timer() as t:\n        for i in range(50):\n            model(graph, feat)\n\n    return t.elapsed_secs / 50\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_nn_heterographconv.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import HeteroGraphConv, SAGEConv\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"feat_dim\", [4, 32, 256])\n@utils.parametrize(\"num_relations\", [5, 50, 200])\ndef track_time(feat_dim, num_relations):\n    device = utils.get_bench_device()\n    dd = {}\n    nn_dict = {}\n    candidate_edges = [\n        dgl.data.CoraGraphDataset(verbose=False)[0].edges(),\n        dgl.data.PubmedGraphDataset(verbose=False)[0].edges(),\n        dgl.data.CiteseerGraphDataset(verbose=False)[0].edges(),\n    ]\n    for i in range(num_relations):\n        dd[(\"n1\", \"e_{}\".format(i), \"n2\")] = candidate_edges[\n            i % len(candidate_edges)\n        ]\n        nn_dict[\"e_{}\".format(i)] = SAGEConv(\n            feat_dim, feat_dim, \"mean\", activation=F.relu\n        )\n\n    # build the graph and its input features\n    feat_dict = {}\n    graph = dgl.heterograph(dd)\n    for i in range(num_relations):\n        etype = \"e_{}\".format(i)\n        feat_dict[etype] = torch.randn(\n            (graph[etype].num_nodes(), feat_dim), device=device\n        )\n\n    conv = HeteroGraphConv(nn_dict).to(device)\n\n    # dry run\n    for i in range(3):\n        conv(graph, feat_dict)\n    # timing\n    with utils.Timer() as t:\n        for i in range(50):\n            conv(graph, feat_dict)\n\n    return t.elapsed_secs / 50\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_node_subgraph.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"graph_name\", [\"livejournal\", \"reddit\"])\n@utils.parametrize(\"format\", [\"coo\", \"csc\"])\n@utils.parametrize(\"seed_nodes_num\", [200, 5000, 20000])\ndef track_time(graph_name, format, seed_nodes_num):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n\n    seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)\n    seed_nodes = torch.from_numpy(seed_nodes).to(device)\n\n    # dry run\n    for i in range(3):\n        dgl.node_subgraph(graph, seed_nodes)\n\n    # timing\n    num_iters = 50\n    with utils.Timer() as t:\n        for i in range(num_iters):\n            dgl.node_subgraph(graph, seed_nodes)\n\n    return t.elapsed_secs / num_iters\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_random_walk.py",
    "content": "import time\n\nimport dgl\n\nimport torch\n\nfrom .. import utils\n\n\ndef _random_walk(g, seeds, length):\n    return dgl.sampling.random_walk(g, seeds, length=length)\n\n\ndef _node2vec(g, seeds, length):\n    return dgl.sampling.node2vec_random_walk(g, seeds, 1, 1, length)\n\n\n@utils.skip_if_gpu()\n@utils.benchmark(\"time\")\n@utils.parametrize(\"graph_name\", [\"cora\", \"livejournal\", \"friendster\"])\n@utils.parametrize(\"num_seeds\", [10, 100, 1000])\n@utils.parametrize(\"length\", [2, 5, 10, 20])\n@utils.parametrize(\"algorithm\", [\"_random_walk\", \"_node2vec\"])\ndef track_time(graph_name, num_seeds, length, algorithm):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, \"csr\")\n    seeds = torch.randint(0, graph.num_nodes(), (num_seeds,))\n    print(graph_name, num_seeds, length)\n    alg = globals()[algorithm]\n    # dry run\n    for i in range(5):\n        _ = alg(graph, seeds, length=length)\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(50):\n            _ = alg(graph, seeds, length=length)\n\n    return t.elapsed_secs / 50\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_readout.py",
    "content": "import time\n\nimport dgl\n\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"batch_size\", [4, 256, 1024])\n@utils.parametrize(\"feat_size\", [16, 128, 512])\n@utils.parametrize(\"readout_op\", [\"sum\", \"max\", \"min\", \"mean\"])\n@utils.parametrize(\"type\", [\"edge\", \"node\"])\ndef track_time(batch_size, feat_size, readout_op, type):\n    device = utils.get_bench_device()\n    ds = dgl.data.QM7bDataset()\n    # prepare graph\n    graphs = ds[0:batch_size][0]\n\n    g = dgl.batch(graphs).to(device)\n    if type == \"node\":\n        g.ndata[\"h\"] = torch.randn((g.num_nodes(), feat_size), device=device)\n        for i in range(10):\n            out = dgl.readout_nodes(g, \"h\", op=readout_op)\n        with utils.Timer() as t:\n            for i in range(50):\n                out = dgl.readout_nodes(g, \"h\", op=readout_op)\n    elif type == \"edge\":\n        g.edata[\"h\"] = torch.randn((g.num_edges(), feat_size), device=device)\n        for i in range(10):\n            out = dgl.readout_edges(g, \"h\", op=readout_op)\n        with utils.Timer() as t:\n            for i in range(50):\n                out = dgl.readout_edges(g, \"h\", op=readout_op)\n    else:\n        raise Exception(\"Unknown type\")\n\n    return t.elapsed_secs / 50\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_reverse.py",
    "content": "import time\n\nimport dgl\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=1200)\n@utils.parametrize_cpu(\"graph_name\", [\"cora\", \"livejournal\", \"friendster\"])\n@utils.parametrize_gpu(\"graph_name\", [\"cora\", \"livejournal\"])\n@utils.parametrize(\"format\", [\"coo\", \"csc\", \"csr\"])\ndef track_time(graph_name, format):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n    graph = graph.formats([format])\n    # dry run\n    dgl.reverse(graph)\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(100):\n            gg = dgl.reverse(graph)\n\n    return t.elapsed_secs / 100\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_sample_neighbors.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize_cpu(\"graph_name\", [\"livejournal\", \"reddit\"])\n@utils.parametrize_gpu(\"graph_name\", [\"ogbn-arxiv\", \"reddit\"])\n@utils.parametrize(\"format\", [\"coo\", \"csc\"])\n@utils.parametrize(\"seed_nodes_num\", [200, 5000, 20000])\n@utils.parametrize(\"fanout\", [5, 20, 40])\ndef track_time(graph_name, format, seed_nodes_num, fanout):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format).to(device)\n\n    edge_dir = \"in\"\n    seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)\n    seed_nodes = torch.from_numpy(seed_nodes).to(device)\n\n    # dry run\n    for i in range(3):\n        dgl.sampling.sample_neighbors(\n            graph, seed_nodes, fanout, edge_dir=edge_dir\n        )\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(50):\n            dgl.sampling.sample_neighbors(\n                graph, seed_nodes, fanout, edge_dir=edge_dir\n            )\n\n    return t.elapsed_secs / 50\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_to_block.py",
    "content": "import time\n\nimport dgl\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.skip_if_gpu()\n@utils.benchmark(\"time\", timeout=1200)\n@utils.parametrize(\"graph_name\", [\"reddit\", \"ogbn-products\"])\n@utils.parametrize(\"num_seed_nodes\", [32, 256, 1024, 2048])\n@utils.parametrize(\"fanout\", [5, 10, 20])\ndef track_time(graph_name, num_seed_nodes, fanout):\n    device = utils.get_bench_device()\n    data = utils.process_data(graph_name)\n    graph = data[0]\n\n    # dry run\n    dgl.sampling.sample_neighbors(graph, [1, 2, 3], fanout)\n\n    subg_list = []\n    for i in range(10):\n        seed_nodes = np.random.randint(\n            0, graph.num_nodes(), size=num_seed_nodes\n        )\n        subg = dgl.sampling.sample_neighbors(graph, seed_nodes, fanout)\n        subg_list.append(subg)\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(10):\n            gg = dgl.to_block(subg_list[i])\n\n    return t.elapsed_secs / 10\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_udf_apply_edges.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=7200)\n@utils.parametrize(\"graph_name\", [\"ogbn-arxiv\", \"pubmed\"])\n@utils.parametrize(\"format\", [\"coo\"])  # only coo supports udf\n@utils.parametrize(\"feat_size\", [8, 32, 128, 512])\n@utils.parametrize(\"reduce_type\", [\"u->e\", \"u+v\"])\ndef track_time(graph_name, format, feat_size, reduce_type):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n    graph.ndata[\"h\"] = torch.randn(\n        (graph.num_nodes(), feat_size), device=device\n    )\n\n    reduce_udf_dict = {\n        \"u->e\": lambda edges: {\"x\": edges.src[\"h\"]},\n        \"u+v\": lambda edges: {\"x\": edges.src[\"h\"] + edges.dst[\"h\"]},\n    }\n\n    # dry run\n    graph.apply_edges(reduce_udf_dict[reduce_type])\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(3):\n            graph.apply_edges(reduce_udf_dict[reduce_type])\n\n    return t.elapsed_secs / 3\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_udf_multi_update_all.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=600)\n@utils.parametrize(\"feat_size\", [32, 128, 512])\n@utils.parametrize(\"num_relations\", [5, 50, 500])\n@utils.parametrize(\"multi_reduce_type\", [\"sum\", \"stack\"])\ndef track_time(feat_size, num_relations, multi_reduce_type):\n    device = utils.get_bench_device()\n    dd = {}\n    candidate_edges = [\n        dgl.data.CoraGraphDataset(verbose=False)[0].edges(),\n        dgl.data.PubmedGraphDataset(verbose=False)[0].edges(),\n        dgl.data.CiteseerGraphDataset(verbose=False)[0].edges(),\n    ]\n    for i in range(num_relations):\n        dd[(\"n1\", \"e_{}\".format(i), \"n2\")] = candidate_edges[\n            i % len(candidate_edges)\n        ]\n    graph = dgl.heterograph(dd)\n\n    graph = graph.to(device)\n    graph.nodes[\"n1\"].data[\"h\"] = torch.randn(\n        (graph.num_nodes(\"n1\"), feat_size), device=device\n    )\n    graph.nodes[\"n2\"].data[\"h\"] = torch.randn(\n        (graph.num_nodes(\"n2\"), feat_size), device=device\n    )\n\n    # dry run\n    update_dict = {}\n    for i in range(num_relations):\n        update_dict[\"e_{}\".format(i)] = (\n            lambda edges: {\"x\": edges.src[\"h\"]},\n            lambda nodes: {\"h_new\": torch.sum(nodes.mailbox[\"x\"], dim=1)},\n        )\n    graph.multi_update_all(update_dict, multi_reduce_type)\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(3):\n            graph.multi_update_all(update_dict, multi_reduce_type)\n\n    return t.elapsed_secs / 3\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_udf_update_all.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\", timeout=600)\n@utils.parametrize(\"graph_name\", [\"pubmed\", \"ogbn-arxiv\"])\n@utils.parametrize(\"format\", [\"coo\"])  # only coo supports udf\n@utils.parametrize(\"feat_size\", [8, 64, 512])\n@utils.parametrize(\"msg_type\", [\"copy_u\", \"u_mul_e\"])\n@utils.parametrize(\"reduce_type\", [\"sum\", \"mean\", \"max\"])\ndef track_time(graph_name, format, feat_size, msg_type, reduce_type):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph_name, format)\n    graph = graph.to(device)\n    graph.ndata[\"h\"] = torch.randn(\n        (graph.num_nodes(), feat_size), device=device\n    )\n    graph.edata[\"e\"] = torch.randn((graph.num_edges(), 1), device=device)\n\n    msg_udf_dict = {\n        \"copy_u\": lambda edges: {\"x\": edges.src[\"h\"]},\n        \"u_mul_e\": lambda edges: {\"x\": edges.src[\"h\"] * edges.data[\"e\"]},\n    }\n\n    reduct_udf_dict = {\n        \"sum\": lambda nodes: {\"h_new\": torch.sum(nodes.mailbox[\"x\"], dim=1)},\n        \"mean\": lambda nodes: {\"h_new\": torch.mean(nodes.mailbox[\"x\"], dim=1)},\n        \"max\": lambda nodes: {\"h_new\": torch.max(nodes.mailbox[\"x\"], dim=1)[0]},\n    }\n\n    # dry run\n    graph.update_all(msg_udf_dict[msg_type], reduct_udf_dict[reduce_type])\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(3):\n            graph.update_all(\n                msg_udf_dict[msg_type], reduct_udf_dict[reduce_type]\n            )\n\n    return t.elapsed_secs / 3\n"
  },
  {
    "path": "benchmarks/benchmarks/api/bench_unbatch.py",
    "content": "import time\n\nimport dgl\n\nimport torch\n\nfrom .. import utils\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"batch_size\", [4, 32, 256, 1024])\ndef track_time(batch_size):\n    device = utils.get_bench_device()\n    ds = dgl.data.QM7bDataset()\n    # prepare graph\n    graphs = ds[0:batch_size][0]\n    bg = dgl.batch(graphs).to(device)\n\n    # dry run\n    for i in range(10):\n        glist = dgl.unbatch(bg)\n\n    # timing\n    with utils.Timer() as t:\n        for i in range(100):\n            glist = dgl.unbatch(bg)\n\n    return t.elapsed_secs / 100\n"
  },
  {
    "path": "benchmarks/benchmarks/kernel/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/benchmarks/kernel/bench_edgesoftmax.py",
    "content": "import time\n\nimport dgl\n\nimport torch\n\nfrom .. import utils\n\n\n# The benchmarks for ops edge_softmax\n@utils.benchmark(\"time\", timeout=600)\n@utils.parametrize(\"graph\", [\"ogbn-arxiv\", \"reddit\", \"cora\", \"pubmed\"])\n@utils.parametrize(\"num_heads\", [1, 4, 8])\ndef track_time(graph, num_heads):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph).to(device)\n    score = (\n        torch.randn((graph.num_edges(), num_heads))\n        .requires_grad_(True)\n        .float()\n        .to(device)\n    )\n\n    # dry run\n    for i in range(3):\n        y = dgl.ops.edge_softmax(graph, score)\n\n    # timing\n    with utils.Timer(device) as t:\n        for i in range(100):\n            y = dgl.ops.edge_softmax(graph, score)\n\n    return t.elapsed_secs / 100\n"
  },
  {
    "path": "benchmarks/benchmarks/kernel/bench_gsddmm_u_dot_v.py",
    "content": "import time\n\nimport dgl\n\nimport torch\n\nfrom .. import utils\n\n\ndef calc_gflops(graph, feat_size, num_heads, time):\n    return round(\n        2 * graph.num_edges() * feat_size / 1000000000 / time, 2\n    )  # count both mul and add\n\n\n# The benchmarks include broadcasting cases.\n# Given feat_size = D, num_heads = H, the node feature shape will be (H, D // H)\n#   while the edge feature shape will be (H, ), so tested operations will broadcast\n#   along the last dimension. The total FLOP is controlled by the feat_size no\n#   matter how many heads are there.\n# If num_heads = 0, it falls back to the normal element-wise operation without\n#   broadcasting.\n@utils.benchmark(\"flops\", timeout=600)\n@utils.parametrize(\"graph\", [\"ogbn-arxiv\", \"reddit\", \"ogbn-proteins\"])\n@utils.parametrize(\"feat_size\", [4, 32, 256])\n@utils.parametrize(\"num_heads\", [0, 1, 4])\ndef track_flops(graph, feat_size, num_heads):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph, format=\"coo\").to(device)\n    if num_heads == 0:\n        x = torch.randn(graph.num_nodes(), feat_size, device=device)\n    else:\n        x = torch.randn(\n            graph.num_nodes(), num_heads, feat_size // num_heads, device=device\n        )\n\n    # dry run\n    for i in range(3):\n        y = dgl.ops.u_dot_v(graph, x, x)\n\n    # timing\n    with utils.Timer(device) as t:\n        for i in range(10):\n            y = dgl.ops.u_dot_v(graph, x, x)\n\n    return calc_gflops(graph, feat_size, num_heads, t.elapsed_secs / 10)\n"
  },
  {
    "path": "benchmarks/benchmarks/kernel/bench_gspmm_copy_u.py",
    "content": "import time\n\nimport dgl\n\nimport torch\n\nfrom .. import utils\n\n\ndef calc_gflops(graph, feat_size, time):\n    return round(graph.num_edges() * feat_size / 1000000000 / time, 2)\n\n\n@utils.benchmark(\"flops\", timeout=600)\n@utils.parametrize(\"graph\", [\"ogbn-arxiv\", \"reddit\", \"ogbn-proteins\"])\n@utils.parametrize(\"feat_size\", [4, 32, 256])\n@utils.parametrize(\"reducer\", [\"sum\", \"max\"])\ndef track_flops(graph, feat_size, reducer):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph, format=\"csc\").to(device)\n    x = torch.randn(graph.num_nodes(), feat_size, device=device)\n\n    if reducer == \"sum\":\n        op = dgl.ops.copy_u_sum\n    elif reducer == \"max\":\n        op = dgl.ops.copy_u_max\n    else:\n        raise ValueError(\"Invalid reducer\", reducer)\n\n    # dry run\n    for i in range(3):\n        y = op(graph, x)\n\n    # timing\n    with utils.Timer(device) as t:\n        for i in range(10):\n            y = op(graph, x)\n\n    return calc_gflops(graph, feat_size, t.elapsed_secs / 10)\n"
  },
  {
    "path": "benchmarks/benchmarks/kernel/bench_gspmm_u_mul_e_sum.py",
    "content": "import time\n\nimport dgl\n\nimport torch\n\nfrom .. import utils\n\n\ndef calc_gflops(graph, feat_size, num_heads, time):\n    return round(\n        2 * graph.num_edges() * feat_size / 1000000000 / time, 2\n    )  # count both mul and add\n\n\n# The benchmarks include broadcasting cases.\n# Given feat_size = D, num_heads = H, the node feature shape will be (H, D // H)\n#   while the edge feature shape will be (H, ), so tested operations will broadcast\n#   along the last dimension. The total FLOP is controlled by the feat_size no\n#   matter how many heads are there.\n# If num_heads = 0, it falls back to the normal element-wise operation without\n#   broadcasting.\n@utils.benchmark(\"flops\", timeout=600)\n@utils.parametrize(\"graph\", [\"ogbn-arxiv\", \"reddit\", \"ogbn-proteins\"])\n@utils.parametrize(\"feat_size\", [4, 32, 256])\n@utils.parametrize(\"num_heads\", [0, 1, 4])\ndef track_flops(graph, feat_size, num_heads):\n    device = utils.get_bench_device()\n    graph = utils.get_graph(graph, format=\"csc\").to(device)\n    if num_heads == 0:\n        x = torch.randn(graph.num_nodes(), feat_size, device=device)\n        w = torch.randn(graph.num_edges(), feat_size, device=device)\n    else:\n        x = torch.randn(\n            graph.num_nodes(), num_heads, feat_size // num_heads, device=device\n        )\n        w = torch.randn(graph.num_edges(), num_heads, 1, device=device)\n\n    # dry run\n    for i in range(3):\n        y = dgl.ops.u_mul_e_sum(graph, x, w)\n\n    # timing\n    with utils.Timer(device) as t:\n        for i in range(10):\n            y = dgl.ops.u_mul_e_sum(graph, x, w)\n\n    return calc_gflops(graph, feat_size, num_heads, t.elapsed_secs / 10)\n"
  },
  {
    "path": "benchmarks/benchmarks/model_acc/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/benchmarks/model_acc/bench_gat.py",
    "content": "import dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import GATConv\n\nfrom .. import utils\n\n\nclass GAT(nn.Module):\n    def __init__(\n        self,\n        num_layers,\n        in_dim,\n        num_hidden,\n        num_classes,\n        heads,\n        activation,\n        feat_drop,\n        attn_drop,\n        negative_slope,\n        residual,\n    ):\n        super(GAT, self).__init__()\n        self.num_layers = num_layers\n        self.gat_layers = nn.ModuleList()\n        self.activation = activation\n        # input projection (no residual)\n        self.gat_layers.append(\n            GATConv(\n                in_dim,\n                num_hidden,\n                heads[0],\n                feat_drop,\n                attn_drop,\n                negative_slope,\n                False,\n                self.activation,\n            )\n        )\n        # hidden layers\n        for l in range(1, num_layers):\n            # due to multi-head, the in_dim = num_hidden * num_heads\n            self.gat_layers.append(\n                GATConv(\n                    num_hidden * heads[l - 1],\n                    num_hidden,\n                    heads[l],\n                    feat_drop,\n                    attn_drop,\n                    negative_slope,\n                    residual,\n                    self.activation,\n                )\n            )\n        # output projection\n        self.gat_layers.append(\n            GATConv(\n                num_hidden * heads[-2],\n                num_classes,\n                heads[-1],\n                feat_drop,\n                attn_drop,\n                negative_slope,\n                residual,\n                None,\n            )\n        )\n\n    def forward(self, g, inputs):\n        h = inputs\n        for l in range(self.num_layers):\n            h = self.gat_layers[l](g, h).flatten(1)\n        # output projection\n        logits = 
self.gat_layers[-1](g, h).mean(1)\n        return logits\n\n\ndef evaluate(model, g, features, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels) * 100\n\n\n@utils.benchmark(\"acc\")\n@utils.parametrize(\"data\", [\"cora\", \"pubmed\"])\ndef track_acc(data):\n    data = utils.process_data(data)\n    device = utils.get_bench_device()\n\n    g = data[0].to(device)\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # create model\n    model = GAT(1, in_feats, 8, n_classes, [8, 1], F.elu, 0.6, 0.6, 0.2, False)\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    model = model.to(device)\n    model.train()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n    for epoch in range(200):\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n    acc = evaluate(model, g, features, labels, test_mask)\n    return acc\n"
  },
  {
    "path": "benchmarks/benchmarks/model_acc/bench_gcn.py",
    "content": "import dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import GraphConv\n\nfrom .. import utils\n\n\nclass GCN(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(GCN, self).__init__()\n        self.layers = nn.ModuleList()\n        # input layer\n        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(\n                GraphConv(n_hidden, n_hidden, activation=activation)\n            )\n        # output layer\n        self.layers.append(GraphConv(n_hidden, n_classes))\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, g, features):\n        h = features\n        for i, layer in enumerate(self.layers):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(g, h)\n        return h\n\n\ndef evaluate(model, g, features, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels) * 100\n\n\n@utils.benchmark(\"acc\")\n@utils.parametrize(\"data\", [\"cora\", \"pubmed\"])\ndef track_acc(data):\n    data = utils.process_data(data)\n    device = utils.get_bench_device()\n\n    g = data[0].to(device).int()\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # normalization\n    degs = g.in_degrees().float()\n    norm = torch.pow(degs, -0.5)\n    
norm[torch.isinf(norm)] = 0\n    g.ndata[\"norm\"] = norm.unsqueeze(1)\n\n    # create GCN model\n    model = GCN(in_feats, 16, n_classes, 1, F.relu, 0.5)\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    model = model.to(device)\n    model.train()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n    for epoch in range(200):\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n    acc = evaluate(model, g, features, labels, test_mask)\n    return acc\n"
  },
  {
    "path": "benchmarks/benchmarks/model_acc/bench_gcn_udf.py",
    "content": "import dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .. import utils\n\n\nclass GraphConv(nn.Module):\n    def __init__(self, in_dim, out_dim, activation=None):\n        super(GraphConv, self).__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.activation = activation\n        self.weight = nn.Parameter(torch.Tensor(in_dim, out_dim))\n        self.bias = nn.Parameter(torch.Tensor(out_dim))\n        nn.init.xavier_normal_(self.weight)\n        nn.init.zeros_(self.bias)\n\n    def forward(self, graph, feat):\n        with graph.local_scope():\n            graph.ndata[\"ci\"] = torch.pow(\n                graph.out_degrees().float().clamp(min=1), -0.5\n            )\n            graph.ndata[\"cj\"] = torch.pow(\n                graph.in_degrees().float().clamp(min=1), -0.5\n            )\n            graph.ndata[\"h\"] = feat\n            graph.update_all(self.mfunc, self.rfunc)\n            h = graph.ndata[\"h\"]\n            h = torch.matmul(h, self.weight) + self.bias\n            if self.activation is not None:\n                h = self.activation(h)\n            return h\n\n    def mfunc(self, edges):\n        return {\"m\": edges.src[\"h\"], \"ci\": edges.src[\"ci\"]}\n\n    def rfunc(self, nodes):\n        ci = nodes.mailbox[\"ci\"].unsqueeze(2)\n        newh = (nodes.mailbox[\"m\"] * ci).sum(1) * nodes.data[\"cj\"].unsqueeze(1)\n        return {\"h\": newh}\n\n\nclass GCN(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(GCN, self).__init__()\n        self.layers = nn.ModuleList()\n        # input layer\n        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(\n                GraphConv(n_hidden, n_hidden, activation=activation)\n            )\n        # output layer\n        
self.layers.append(GraphConv(n_hidden, n_classes))\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, g, features):\n        h = features\n        for i, layer in enumerate(self.layers):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(g, h)\n        return h\n\n\ndef evaluate(model, g, features, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels) * 100\n\n\n@utils.benchmark(\"acc\", timeout=300)\n@utils.parametrize(\"data\", [\"cora\", \"pubmed\"])\ndef track_acc(data):\n    data = utils.process_data(data)\n    device = utils.get_bench_device()\n\n    g = data[0].to(device).int()\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # normalization\n    degs = g.in_degrees().float()\n    norm = torch.pow(degs, -0.5)\n    norm[torch.isinf(norm)] = 0\n    g.ndata[\"norm\"] = norm.unsqueeze(1)\n\n    # create GCN model\n    model = GCN(in_feats, 16, n_classes, 1, F.relu, 0.5)\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    model = model.to(device)\n    model.train()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n    for epoch in range(200):\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n    acc = evaluate(model, g, features, labels, test_mask)\n    return acc\n"
  },
  {
    "path": "benchmarks/benchmarks/model_acc/bench_rgcn_base.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchmetrics.functional import accuracy\n\nfrom .. import rgcn, utils\n\n\n@utils.benchmark(\"acc\", timeout=1200)\n@utils.parametrize(\"dataset\", [\"aifb\", \"mutag\"])\n@utils.parametrize(\"ns_mode\", [False])\ndef track_acc(dataset, ns_mode):\n    (\n        g,\n        num_rels,\n        num_classes,\n        labels,\n        train_idx,\n        test_idx,\n        target_idx,\n    ) = rgcn.load_data(dataset, get_norm=True)\n    num_hidden = 16\n    if dataset == \"aifb\":\n        num_bases = -1\n        l2norm = 0.0\n    elif dataset == \"mutag\":\n        num_bases = 30\n        l2norm = 5e-4\n    elif dataset == \"am\":\n        num_bases = 40\n        l2norm = 5e-4\n    else:\n        raise ValueError()\n    model = rgcn.RGCN(\n        g.num_nodes(),\n        num_hidden,\n        num_classes,\n        num_rels,\n        num_bases=num_bases,\n        ns_mode=ns_mode,\n    )\n    device = utils.get_bench_device()\n    labels = labels.to(device)\n    model = model.to(device)\n    g = g.int().to(device)\n\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=1e-2, weight_decay=l2norm\n    )\n\n    model.train()\n    for epoch in range(30):\n        logits = model(g)\n        logits = logits[target_idx]\n        loss = F.cross_entropy(logits[train_idx], labels[train_idx])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n    model.eval()\n    with torch.no_grad():\n        logits = model(g)\n    logits = logits[target_idx]\n    test_acc = accuracy(\n        logits[test_idx].argmax(dim=1),\n        labels[test_idx],\n        task=\"multiclass\",\n        num_classes=num_classes,\n    ).item()\n\n    return test_acc\n"
  },
  {
    "path": "benchmarks/benchmarks/model_acc/bench_rgcn_ns.py",
    "content": "import itertools\nimport time\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dgl.nn import RelGraphConv\nfrom torch.utils.data import DataLoader\n\nfrom .. import utils\n\n\nclass EntityClassify(nn.Module):\n    \"\"\"Entity classification class for RGCN\n    Parameters\n    ----------\n    device : int\n        Device to run the layer.\n    num_nodes : int\n        Number of nodes.\n    h_dim : int\n        Hidden dim size.\n    out_dim : int\n        Output dim size.\n    num_rels : int\n        Numer of relation types.\n    num_bases : int\n        Number of bases. If is none, use number of relations.\n    num_hidden_layers : int\n        Number of hidden RelGraphConv Layer\n    dropout : float\n        Dropout\n    use_self_loop : bool\n        Use self loop if True, default False.\n    \"\"\"\n\n    def __init__(\n        self,\n        device,\n        num_nodes,\n        h_dim,\n        out_dim,\n        num_rels,\n        num_bases=None,\n        num_hidden_layers=1,\n        dropout=0,\n        use_self_loop=False,\n        layer_norm=False,\n    ):\n        super(EntityClassify, self).__init__()\n        self.device = device\n        self.num_nodes = num_nodes\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.num_rels = num_rels\n        self.num_bases = None if num_bases < 0 else num_bases\n        self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.use_self_loop = use_self_loop\n        self.layer_norm = layer_norm\n\n        self.layers = nn.ModuleList()\n        # i2h\n        self.layers.append(\n            RelGraphConv(\n                self.h_dim,\n                self.h_dim,\n                self.num_rels,\n                \"basis\",\n                self.num_bases,\n                activation=F.relu,\n                
self_loop=self.use_self_loop,\n                dropout=self.dropout,\n                layer_norm=layer_norm,\n            )\n        )\n        # h2h\n        for idx in range(self.num_hidden_layers):\n            self.layers.append(\n                RelGraphConv(\n                    self.h_dim,\n                    self.h_dim,\n                    self.num_rels,\n                    \"basis\",\n                    self.num_bases,\n                    activation=F.relu,\n                    self_loop=self.use_self_loop,\n                    dropout=self.dropout,\n                    layer_norm=layer_norm,\n                )\n            )\n        # h2o\n        self.layers.append(\n            RelGraphConv(\n                self.h_dim,\n                self.out_dim,\n                self.num_rels,\n                \"basis\",\n                self.num_bases,\n                activation=None,\n                self_loop=self.use_self_loop,\n                layer_norm=layer_norm,\n            )\n        )\n\n    def forward(self, blocks, feats, norm=None):\n        if blocks is None:\n            # full graph training\n            blocks = [self.g] * len(self.layers)\n        h = feats\n        for layer, block in zip(self.layers, blocks):\n            block = block.to(self.device)\n            h = layer(block, h, block.edata[\"etype\"], block.edata[\"norm\"])\n        return h\n\n\nclass RelGraphEmbedLayer(nn.Module):\n    r\"\"\"Embedding layer for featureless heterograph.\n    Parameters\n    ----------\n    device : int\n        Device to run the layer.\n    num_nodes : int\n        Number of nodes.\n    node_tides : tensor\n        Storing the node type id for each node starting from 0\n    num_of_ntype : int\n        Number of node types\n    input_size : list of int\n        A list of input feature size for each node type. 
If None, we then\n        treat certain input feature as an one-hot encoding feature.\n    embed_size : int\n        Output embed size\n    embed_name : str, optional\n        Embed name\n    \"\"\"\n\n    def __init__(\n        self,\n        device,\n        num_nodes,\n        node_tids,\n        num_of_ntype,\n        input_size,\n        embed_size,\n        sparse_emb=False,\n        embed_name=\"embed\",\n    ):\n        super(RelGraphEmbedLayer, self).__init__()\n        self.device = device\n        self.embed_size = embed_size\n        self.embed_name = embed_name\n        self.num_nodes = num_nodes\n        self.sparse_emb = sparse_emb\n\n        # create weight embeddings for each node for each relation\n        self.embeds = nn.ParameterDict()\n        self.num_of_ntype = num_of_ntype\n        self.idmap = th.empty(num_nodes).long()\n\n        for ntype in range(num_of_ntype):\n            if input_size[ntype] is not None:\n                input_emb_size = input_size[ntype].shape[1]\n                embed = nn.Parameter(th.Tensor(input_emb_size, self.embed_size))\n                nn.init.xavier_uniform_(embed)\n                self.embeds[str(ntype)] = embed\n\n        self.node_embeds = th.nn.Embedding(\n            node_tids.shape[0], self.embed_size, sparse=self.sparse_emb\n        )\n        nn.init.uniform_(self.node_embeds.weight, -1.0, 1.0)\n\n    def forward(self, node_ids, node_tids, type_ids, features):\n        \"\"\"Forward computation\n        Parameters\n        ----------\n        node_ids : tensor\n            node ids to generate embedding for.\n        node_tids : tensor\n            node type ids\n        features : list of features\n            list of initial features for nodes belong to different node type.\n            If None, the corresponding features is an one-hot encoding feature,\n            else use the features directly as input feature and matmul a\n            projection matrix.\n        Returns\n        -------\n      
  tensor\n            embeddings as the input of the next layer\n        \"\"\"\n        tsd_ids = node_ids.to(self.node_embeds.weight.device)\n        embeds = th.empty(\n            node_ids.shape[0], self.embed_size, device=self.device\n        )\n        for ntype in range(self.num_of_ntype):\n            if features[ntype] is not None:\n                loc = node_tids == ntype\n                embeds[loc] = features[ntype][type_ids[loc]].to(\n                    self.device\n                ) @ self.embeds[str(ntype)].to(self.device)\n            else:\n                loc = node_tids == ntype\n                embeds[loc] = self.node_embeds(tsd_ids[loc]).to(self.device)\n\n        return embeds\n\n\ndef evaluate(model, embed_layer, eval_loader, node_feats):\n    model.eval()\n    embed_layer.eval()\n    eval_logits = []\n    eval_seeds = []\n\n    with th.no_grad():\n        for sample_data in eval_loader:\n            th.cuda.empty_cache()\n            _, _, blocks = sample_data\n            feats = embed_layer(\n                blocks[0].srcdata[dgl.NID],\n                blocks[0].srcdata[dgl.NTYPE],\n                blocks[0].srcdata[\"type_id\"],\n                node_feats,\n            )\n            logits = model(blocks, feats)\n            eval_logits.append(logits.cpu().detach())\n            eval_seeds.append(blocks[-1].dstdata[\"type_id\"].cpu().detach())\n    eval_logits = th.cat(eval_logits)\n    eval_seeds = th.cat(eval_seeds)\n\n    return eval_logits, eval_seeds\n\n\n@utils.benchmark(\"acc\", timeout=3600)  # ogbn-mag takes ~1 hour to train\n@utils.parametrize(\"data\", [\"am\", \"ogbn-mag\"])\ndef track_acc(data):\n    dataset = utils.process_data(data)\n    device = utils.get_bench_device()\n\n    if data == \"am\":\n        n_bases = 40\n        l2norm = 5e-4\n        n_epochs = 20\n    elif data == \"ogbn-mag\":\n        n_bases = 2\n        l2norm = 0\n        n_epochs = 20\n    else:\n        raise ValueError()\n\n    fanouts = [25, 
15]\n    n_layers = 2\n    batch_size = 1024\n    n_hidden = 64\n    dropout = 0.5\n    use_self_loop = True\n    lr = 0.01\n    num_workers = 4\n\n    hg = dataset[0]\n    category = dataset.predict_category\n    num_classes = dataset.num_classes\n    train_mask = hg.nodes[category].data.pop(\"train_mask\")\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()\n    test_mask = hg.nodes[category].data.pop(\"test_mask\")\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()\n    labels = hg.nodes[category].data.pop(\"labels\").to(device)\n    num_of_ntype = len(hg.ntypes)\n    num_rels = len(hg.canonical_etypes)\n\n    node_feats = []\n    for ntype in hg.ntypes:\n        if len(hg.nodes[ntype].data) == 0 or \"feat\" not in hg.nodes[ntype].data:\n            node_feats.append(None)\n        else:\n            feat = hg.nodes[ntype].data.pop(\"feat\")\n            node_feats.append(feat.share_memory_())\n\n    # get target category id\n    category_id = len(hg.ntypes)\n    for i, ntype in enumerate(hg.ntypes):\n        if ntype == category:\n            category_id = i\n    g = dgl.to_homogeneous(hg)\n    u, v, eid = g.all_edges(form=\"all\")\n\n    # global norm\n    _, inverse_index, count = th.unique(\n        v, return_inverse=True, return_counts=True\n    )\n    degrees = count[inverse_index]\n    norm = th.ones(eid.shape[0]) / degrees\n    norm = norm.unsqueeze(1)\n    g.edata[\"norm\"] = norm\n    g.edata[\"etype\"] = g.edata[dgl.ETYPE]\n    g.ndata[\"type_id\"] = g.ndata[dgl.NID]\n    g.ndata[\"ntype\"] = g.ndata[dgl.NTYPE]\n\n    node_ids = th.arange(g.num_nodes())\n    # find out the target node ids\n    node_tids = g.ndata[dgl.NTYPE]\n    loc = node_tids == category_id\n    target_nids = node_ids[loc]\n\n    g = g.formats(\"csc\")\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)\n    train_loader = dgl.dataloading.DataLoader(\n        g,\n        target_nids[train_idx],\n        sampler,\n        
batch_size=batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=num_workers,\n    )\n    test_loader = dgl.dataloading.DataLoader(\n        g,\n        target_nids[test_idx],\n        sampler,\n        batch_size=batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=num_workers,\n    )\n\n    # node features\n    # None for one-hot feature, if not none, it should be the feature tensor.\n    embed_layer = RelGraphEmbedLayer(\n        device,\n        g.num_nodes(),\n        node_tids,\n        num_of_ntype,\n        node_feats,\n        n_hidden,\n        sparse_emb=True,\n    )\n\n    # create model\n    # all model params are in device.\n    model = EntityClassify(\n        device,\n        g.num_nodes(),\n        n_hidden,\n        num_classes,\n        num_rels,\n        num_bases=n_bases,\n        num_hidden_layers=n_layers - 2,\n        dropout=dropout,\n        use_self_loop=use_self_loop,\n        layer_norm=False,\n    )\n\n    embed_layer = embed_layer.to(device)\n    model = model.to(device)\n\n    all_params = itertools.chain(\n        model.parameters(), embed_layer.embeds.parameters()\n    )\n    optimizer = th.optim.Adam(all_params, lr=lr, weight_decay=l2norm)\n    emb_optimizer = th.optim.SparseAdam(\n        list(embed_layer.node_embeds.parameters()), lr=lr\n    )\n\n    print(\"start training...\")\n    for epoch in range(n_epochs):\n        model.train()\n        embed_layer.train()\n\n        for i, sample_data in enumerate(train_loader):\n            input_nodes, output_nodes, blocks = sample_data\n            feats = embed_layer(\n                input_nodes,\n                blocks[0].srcdata[\"ntype\"],\n                blocks[0].srcdata[\"type_id\"],\n                node_feats,\n            )\n            logits = model(blocks, feats)\n            seed_idx = blocks[-1].dstdata[\"type_id\"]\n            loss = F.cross_entropy(logits, labels[seed_idx])\n            
optimizer.zero_grad()\n            emb_optimizer.zero_grad()\n\n            loss.backward()\n            optimizer.step()\n            emb_optimizer.step()\n\n    print(\"start testing...\")\n\n    test_logits, test_seeds = evaluate(\n        model, embed_layer, test_loader, node_feats\n    )\n    test_loss = F.cross_entropy(test_logits, labels[test_seeds].cpu()).item()\n    test_acc = th.sum(\n        test_logits.argmax(dim=1) == labels[test_seeds].cpu()\n    ).item() / len(test_seeds)\n\n    return test_acc\n"
  },
  {
    "path": "benchmarks/benchmarks/model_acc/bench_sage.py",
    "content": "import dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import SAGEConv\n\nfrom .. import utils\n\n\nclass GraphSAGE(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        n_hidden,\n        n_classes,\n        n_layers,\n        activation,\n        dropout,\n        aggregator_type,\n    ):\n        super(GraphSAGE, self).__init__()\n        self.layers = nn.ModuleList()\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n        # input layer\n        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))\n        # output layer\n        self.layers.append(\n            SAGEConv(n_hidden, n_classes, aggregator_type)\n        )  # activation None\n\n    def forward(self, graph, inputs):\n        h = self.dropout(inputs)\n        for l, layer in enumerate(self.layers):\n            h = layer(graph, h)\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n\ndef evaluate(model, g, features, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels) * 100\n\n\n@utils.benchmark(\"acc\")\n@utils.parametrize(\"data\", [\"cora\", \"pubmed\"])\ndef track_acc(data):\n    data = utils.process_data(data)\n    device = utils.get_bench_device()\n\n    g = data[0].to(device)\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    in_feats = 
features.shape[1]\n    n_classes = data.num_classes\n\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # create model\n    model = GraphSAGE(in_feats, 16, n_classes, 1, F.relu, 0.5, \"gcn\")\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    model = model.to(device)\n    model.train()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n    for epoch in range(200):\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n    acc = evaluate(model, g, features, labels, test_mask)\n    return acc\n"
  },
  {
    "path": "benchmarks/benchmarks/model_acc/bench_sage_ns.py",
    "content": "import time\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\n\nfrom .. import utils\n\n\nclass SAGE(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        for i in range(1, n_layers - 1):\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, x, batch_size, device):\n        \"\"\"\n        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).\n        g : the entire graph.\n        x : the input of entire node set.\n\n        The inference code is written in a fashion that it could handle any number of nodes and\n        layers.\n        \"\"\"\n        # During inference with sampling, multi-layer blocks are very inefficient because\n        # lots of computations in the first few layers are repeated.\n        # Therefore, we compute the representation of all nodes layer by layer.  
The nodes\n        # on each layer are of course splitted in batches.\n        # TODO: can we standardize this?\n        for l, layer in enumerate(self.layers):\n            y = th.zeros(\n                g.num_nodes(),\n                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,\n            )\n\n            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n            dataloader = dgl.dataloading.DataLoader(\n                g,\n                th.arange(g.num_nodes()),\n                sampler,\n                batch_size=batch_size,\n                shuffle=True,\n                drop_last=False,\n                num_workers=4,\n            )\n\n            for input_nodes, output_nodes, blocks in dataloader:\n                block = blocks[0]\n\n                block = block.int().to(device)\n                h = x[input_nodes].to(device)\n                h = layer(block, h)\n                if l != len(self.layers) - 1:\n                    h = self.activation(h)\n                    h = self.dropout(h)\n\n                y[output_nodes] = h.cpu()\n\n            x = y\n        return y\n\n\ndef compute_acc(pred, labels):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    labels = labels.long()\n    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)\n\n\ndef evaluate(model, g, inputs, labels, val_nid, batch_size, device):\n    \"\"\"\n    Evaluate the model on the validation set specified by ``val_nid``.\n    g : The entire graph.\n    inputs : The features of all the nodes.\n    labels : The labels of all the nodes.\n    val_nid : the node Ids for validation.\n    batch_size : Number of nodes to compute at the same time.\n    device : The GPU device to evaluate on.\n    \"\"\"\n    model.eval()\n    with th.no_grad():\n        pred = model.inference(g, inputs, batch_size, device)\n    model.train()\n    return compute_acc(pred[val_nid], labels[val_nid])\n\n\ndef 
load_subtensor(g, seeds, input_nodes, device):\n    \"\"\"\n    Copys features and labels of a set of nodes onto GPU.\n    \"\"\"\n    batch_inputs = g.ndata[\"features\"][input_nodes].to(device)\n    batch_labels = g.ndata[\"labels\"][seeds].to(device)\n    return batch_inputs, batch_labels\n\n\n@utils.benchmark(\"acc\", 600)\n@utils.parametrize(\"data\", [\"ogbn-products\", \"reddit\"])\ndef track_acc(data):\n    data = utils.process_data(data)\n    device = utils.get_bench_device()\n    g = data[0]\n    g.ndata[\"features\"] = g.ndata[\"feat\"]\n    g.ndata[\"labels\"] = g.ndata[\"label\"]\n    in_feats = g.ndata[\"features\"].shape[1]\n    n_classes = data.num_classes\n\n    # Create csr/coo/csc formats before launching training processes with multi-gpu.\n    # This avoids creating certain formats in each sub-process, which saves momory and CPU.\n    g.create_formats_()\n\n    num_epochs = 20\n    num_hidden = 16\n    num_layers = 2\n    fan_out = \"5,10\"\n    batch_size = 1024\n    lr = 0.003\n    dropout = 0.5\n    num_workers = 4\n\n    train_nid = th.nonzero(g.ndata[\"train_mask\"], as_tuple=True)[0]\n\n    # Create PyTorch DataLoader for constructing blocks\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [int(fanout) for fanout in fan_out.split(\",\")]\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        train_nid,\n        sampler,\n        batch_size=batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=num_workers,\n    )\n\n    # Define model and optimizer\n    model = SAGE(in_feats, num_hidden, n_classes, num_layers, F.relu, dropout)\n    model = model.to(device)\n    loss_fcn = nn.CrossEntropyLoss()\n    loss_fcn = loss_fcn.to(device)\n    optimizer = optim.Adam(model.parameters(), lr=lr)\n\n    # dry run one epoch\n    for step, (input_nodes, seeds, blocks) in enumerate(dataloader):\n        # Load the input features as well as output labels\n        # batch_inputs, batch_labels = 
load_subtensor(g, seeds, input_nodes, device)\n        blocks = [block.int().to(device) for block in blocks]\n        batch_inputs = blocks[0].srcdata[\"features\"]\n        batch_labels = blocks[-1].dstdata[\"labels\"]\n\n        # Compute loss and prediction\n        batch_pred = model(blocks, batch_inputs)\n        loss = loss_fcn(batch_pred, batch_labels)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n    # Training loop\n    for epoch in range(num_epochs):\n        # Loop over the dataloader to sample the computation dependency graph as a list of\n        # blocks.\n        for step, (input_nodes, seeds, blocks) in enumerate(dataloader):\n            # Load the input features as well as output labels\n            # batch_inputs, batch_labels = load_subtensor(g, seeds, input_nodes, device)\n            blocks = [block.int().to(device) for block in blocks]\n            batch_inputs = blocks[0].srcdata[\"features\"]\n            batch_labels = blocks[-1].dstdata[\"labels\"]\n\n            # Compute loss and prediction\n            batch_pred = model(blocks, batch_inputs)\n            loss = loss_fcn(batch_pred, batch_labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n    test_g = g\n    test_nid = th.nonzero(\n        ~(test_g.ndata[\"train_mask\"] | test_g.ndata[\"val_mask\"]), as_tuple=True\n    )[0]\n    test_acc = evaluate(\n        model,\n        test_g,\n        test_g.ndata[\"features\"],\n        test_g.ndata[\"labels\"],\n        test_nid,\n        batch_size,\n        device,\n    )\n\n    return test_acc.item()\n"
  },
  {
    "path": "benchmarks/benchmarks/model_speed/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/benchmarks/model_speed/bench_gat.py",
    "content": "import time\n\nimport dgl\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import GATConv\n\nfrom .. import utils\n\n\nclass GAT(nn.Module):\n    def __init__(\n        self,\n        num_layers,\n        in_dim,\n        num_hidden,\n        num_classes,\n        heads,\n        activation,\n        feat_drop,\n        attn_drop,\n        negative_slope,\n        residual,\n    ):\n        super(GAT, self).__init__()\n        self.num_layers = num_layers\n        self.gat_layers = nn.ModuleList()\n        self.activation = activation\n        # input projection (no residual)\n        self.gat_layers.append(\n            GATConv(\n                in_dim,\n                num_hidden,\n                heads[0],\n                feat_drop,\n                attn_drop,\n                negative_slope,\n                False,\n                self.activation,\n            )\n        )\n        # hidden layers\n        for l in range(1, num_layers):\n            # due to multi-head, the in_dim = num_hidden * num_heads\n            self.gat_layers.append(\n                GATConv(\n                    num_hidden * heads[l - 1],\n                    num_hidden,\n                    heads[l],\n                    feat_drop,\n                    attn_drop,\n                    negative_slope,\n                    residual,\n                    self.activation,\n                )\n            )\n        # output projection\n        self.gat_layers.append(\n            GATConv(\n                num_hidden * heads[-2],\n                num_classes,\n                heads[-1],\n                feat_drop,\n                attn_drop,\n                negative_slope,\n                residual,\n                None,\n            )\n        )\n\n    def forward(self, g, inputs):\n        h = inputs\n        for l in range(self.num_layers):\n            h = self.gat_layers[l](g, h).flatten(1)\n        # output projection\n    
    logits = self.gat_layers[-1](g, h).mean(1)\n        return logits\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"data\", [\"cora\", \"pubmed\"])\ndef track_time(data):\n    data = utils.process_data(data)\n    device = utils.get_bench_device()\n    num_epochs = 200\n\n    g = data[0].to(device)\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # create model\n    model = GAT(1, in_feats, 8, n_classes, [8, 1], F.elu, 0.6, 0.6, 0.2, False)\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    model = model.to(device)\n    model.train()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n\n    # dry run\n    for epoch in range(10):\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n    # timing\n    t0 = time.time()\n    for epoch in range(num_epochs):\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n    t1 = time.time()\n\n    return (t1 - t0) / num_epochs\n"
  },
  {
    "path": "benchmarks/benchmarks/model_speed/bench_gat_ns.py",
    "content": "import time\nimport traceback\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\n\nfrom .. import utils\n\n\nclass GAT(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        num_heads,\n        n_hidden,\n        n_classes,\n        n_layers,\n        activation,\n        dropout=0.0,\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.num_heads = num_heads\n        self.layers.append(\n            dglnn.GATConv(\n                in_feats,\n                n_hidden,\n                num_heads=num_heads,\n                feat_drop=dropout,\n                attn_drop=dropout,\n                activation=activation,\n                negative_slope=0.2,\n            )\n        )\n        for i in range(1, n_layers - 1):\n            self.layers.append(\n                dglnn.GATConv(\n                    n_hidden * num_heads,\n                    n_hidden,\n                    num_heads=num_heads,\n                    feat_drop=dropout,\n                    attn_drop=dropout,\n                    activation=activation,\n                    negative_slope=0.2,\n                )\n            )\n        self.layers.append(\n            dglnn.GATConv(\n                n_hidden * num_heads,\n                n_classes,\n                num_heads=num_heads,\n                feat_drop=dropout,\n                attn_drop=dropout,\n                activation=None,\n                negative_slope=0.2,\n            )\n        )\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l < len(self.layers) - 1:\n          
      h = h.flatten(1)\n        h = h.mean(1)\n        return h.log_softmax(dim=-1)\n\n\ndef load_subtensor(g, seeds, input_nodes, device):\n    \"\"\"\n    Copys features and labels of a set of nodes onto GPU.\n    \"\"\"\n    batch_inputs = g.ndata[\"features\"][input_nodes].to(device)\n    batch_labels = g.ndata[\"labels\"][seeds].to(device)\n    return batch_inputs, batch_labels\n\n\n@utils.benchmark(\"time\", 600)\n@utils.parametrize(\"data\", [\"reddit\", \"ogbn-products\"])\ndef track_time(data):\n    data = utils.process_data(data)\n    device = utils.get_bench_device()\n    g = data[0]\n    g.ndata[\"features\"] = g.ndata[\"feat\"]\n    g.ndata[\"labels\"] = g.ndata[\"label\"]\n    g = g.remove_self_loop().add_self_loop()\n    in_feats = g.ndata[\"features\"].shape[1]\n    n_classes = data.num_classes\n\n    # Create csr/coo/csc formats before launching training processes with multi-gpu.\n    # This avoids creating certain formats in each sub-process, which saves momory and CPU.\n    g.create_formats_()\n\n    num_hidden = 16\n    num_heads = 8\n    num_layers = 2\n    fan_out = \"10,25\"\n    batch_size = 1024\n    lr = 0.003\n    dropout = 0.5\n    num_workers = 4\n    iter_start = 3\n    iter_count = 10\n\n    train_nid = th.nonzero(g.ndata[\"train_mask\"], as_tuple=True)[0]\n\n    # Create PyTorch DataLoader for constructing blocks\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [int(fanout) for fanout in fan_out.split(\",\")]\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        train_nid,\n        sampler,\n        batch_size=batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=num_workers,\n    )\n\n    # Define model and optimizer\n    model = GAT(\n        in_feats, num_heads, num_hidden, n_classes, num_layers, F.relu, dropout\n    )\n    model = model.to(device)\n    loss_fcn = nn.CrossEntropyLoss()\n    loss_fcn = loss_fcn.to(device)\n    optimizer = 
optim.Adam(model.parameters(), lr=lr)\n\n    # Enable dataloader cpu affinitization for cpu devices (no effect on gpu)\n    with dataloader.enable_cpu_affinity():\n        # Loop over the dataloader to sample the computation dependency graph as a list of\n        # blocks.\n\n        # Training loop\n        avg = 0\n        iter_tput = []\n        for step, (input_nodes, seeds, blocks) in enumerate(dataloader):\n            # Load the input features as well as output labels\n            blocks = [block.int().to(device) for block in blocks]\n            batch_inputs = blocks[0].srcdata[\"features\"]\n            batch_labels = blocks[-1].dstdata[\"labels\"]\n\n            # Compute loss and prediction\n            batch_pred = model(blocks, batch_inputs)\n            loss = loss_fcn(batch_pred, batch_labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            # start timer at before iter_start\n            if step == iter_start - 1:\n                t0 = time.time()\n            elif (\n                step == iter_count + iter_start - 1\n            ):  # time iter_count iterations\n                break\n\n    t1 = time.time()\n\n    return (t1 - t0) / iter_count\n"
  },
  {
    "path": "benchmarks/benchmarks/model_speed/bench_gcn_udf.py",
    "content": "import time\n\nimport dgl\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .. import utils\n\n\nclass GraphConv(nn.Module):\n    def __init__(self, in_dim, out_dim, activation=None):\n        super(GraphConv, self).__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.activation = activation\n        self.weight = nn.Parameter(torch.Tensor(in_dim, out_dim))\n        self.bias = nn.Parameter(torch.Tensor(out_dim))\n        nn.init.xavier_normal_(self.weight)\n        nn.init.zeros_(self.bias)\n\n    def forward(self, graph, feat):\n        with graph.local_scope():\n            graph.ndata[\"ci\"] = torch.pow(\n                graph.out_degrees().float().clamp(min=1), -0.5\n            )\n            graph.ndata[\"cj\"] = torch.pow(\n                graph.in_degrees().float().clamp(min=1), -0.5\n            )\n            graph.ndata[\"h\"] = feat\n            graph.update_all(self.mfunc, self.rfunc)\n            h = graph.ndata[\"h\"]\n            h = torch.matmul(h, self.weight) + self.bias\n            if self.activation is not None:\n                h = self.activation(h)\n            return h\n\n    def mfunc(self, edges):\n        return {\"m\": edges.src[\"h\"], \"ci\": edges.src[\"ci\"]}\n\n    def rfunc(self, nodes):\n        ci = nodes.mailbox[\"ci\"].unsqueeze(2)\n        newh = (nodes.mailbox[\"m\"] * ci).sum(1) * nodes.data[\"cj\"].unsqueeze(1)\n        return {\"h\": newh}\n\n\nclass GCN(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(GCN, self).__init__()\n        self.layers = nn.ModuleList()\n        # input layer\n        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(\n                GraphConv(n_hidden, n_hidden, activation=activation)\n            )\n        # output 
layer\n        self.layers.append(GraphConv(n_hidden, n_classes))\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, g, features):\n        h = features\n        for i, layer in enumerate(self.layers):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(g, h)\n        return h\n\n\n@utils.benchmark(\"time\", timeout=300)\n@utils.parametrize(\"data\", [\"cora\", \"pubmed\"])\ndef track_time(data):\n    data = utils.process_data(data)\n    device = utils.get_bench_device()\n\n    g = data[0].to(device).int()\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # normalization\n    degs = g.in_degrees().float()\n    norm = torch.pow(degs, -0.5)\n    norm[torch.isinf(norm)] = 0\n    g.ndata[\"norm\"] = norm.unsqueeze(1)\n\n    # create GCN model\n    model = GCN(in_feats, 16, n_classes, 1, F.relu, 0.5)\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    model = model.to(device)\n    model.train()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n    # dry run\n    for epoch in range(5):\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n    with utils.Timer(device) as t:\n        for epoch in range(200):\n            logits = model(g, features)\n            loss = loss_fcn(logits[train_mask], labels[train_mask])\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n    return t.elapsed_secs / 200\n"
  },
  {
    "path": "benchmarks/benchmarks/model_speed/bench_pinsage.py",
    "content": "import argparse\nimport pickle\nimport time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader, IterableDataset\n\nfrom .. import utils\n\n\ndef _init_input_modules(g, ntype, textset, hidden_dims):\n    # We initialize the linear projections of each input feature ``x`` as\n    # follows:\n    # * If ``x`` is a scalar integral feature, we assume that ``x`` is a categorical\n    #   feature, and assume the range of ``x`` is 0..max(x).\n    # * If ``x`` is a float one-dimensional feature, we assume that ``x`` is a\n    #   numeric vector.\n    # * If ``x`` is a field of a textset, we process it as bag of words.\n    module_dict = nn.ModuleDict()\n\n    for column, data in g.nodes[ntype].data.items():\n        if column == dgl.NID:\n            continue\n        if data.dtype == torch.float32:\n            assert data.ndim == 2\n            m = nn.Linear(data.shape[1], hidden_dims)\n            nn.init.xavier_uniform_(m.weight)\n            nn.init.constant_(m.bias, 0)\n            module_dict[column] = m\n        elif data.dtype == torch.int64:\n            assert data.ndim == 1\n            m = nn.Embedding(data.max() + 2, hidden_dims, padding_idx=-1)\n            nn.init.xavier_uniform_(m.weight)\n            module_dict[column] = m\n\n    if textset is not None:\n        for column, field in textset.fields.items():\n            if field.vocab.vectors:\n                module_dict[column] = BagOfWordsPretrained(field, hidden_dims)\n            else:\n                module_dict[column] = BagOfWords(field, hidden_dims)\n\n    return module_dict\n\n\nclass BagOfWordsPretrained(nn.Module):\n    def __init__(self, field, hidden_dims):\n        super().__init__()\n\n        input_dims = field.vocab.vectors.shape[1]\n        self.emb = nn.Embedding(\n            len(field.vocab.itos),\n            input_dims,\n            
padding_idx=field.vocab.stoi[field.pad_token],\n        )\n        self.emb.weight[:] = field.vocab.vectors\n        self.proj = nn.Linear(input_dims, hidden_dims)\n        nn.init.xavier_uniform_(self.proj.weight)\n        nn.init.constant_(self.proj.bias, 0)\n\n        disable_grad(self.emb)\n\n    def forward(self, x, length):\n        \"\"\"\n        x: (batch_size, max_length) LongTensor\n        length: (batch_size,) LongTensor\n        \"\"\"\n        x = self.emb(x).sum(1) / length.unsqueeze(1).float()\n        return self.proj(x)\n\n\nclass BagOfWords(nn.Module):\n    def __init__(self, field, hidden_dims):\n        super().__init__()\n\n        self.emb = nn.Embedding(\n            len(field.vocab.itos),\n            hidden_dims,\n            padding_idx=field.vocab.stoi[field.pad_token],\n        )\n        nn.init.xavier_uniform_(self.emb.weight)\n\n    def forward(self, x, length):\n        return self.emb(x).sum(1) / length.unsqueeze(1).float()\n\n\nclass WeightedSAGEConv(nn.Module):\n    def __init__(self, input_dims, hidden_dims, output_dims, act=F.relu):\n        super().__init__()\n\n        self.act = act\n        self.Q = nn.Linear(input_dims, hidden_dims)\n        self.W = nn.Linear(input_dims + hidden_dims, output_dims)\n        self.reset_parameters()\n        self.dropout = nn.Dropout(0.5)\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_uniform_(self.Q.weight, gain=gain)\n        nn.init.xavier_uniform_(self.W.weight, gain=gain)\n        nn.init.constant_(self.Q.bias, 0)\n        nn.init.constant_(self.W.bias, 0)\n\n    def forward(self, g, h, weights):\n        \"\"\"\n        g : graph\n        h : node features\n        weights : scalar edge weights\n        \"\"\"\n        h_src, h_dst = h\n        with g.local_scope():\n            g.srcdata[\"n\"] = self.act(self.Q(self.dropout(h_src)))\n            g.edata[\"w\"] = weights.float()\n            g.update_all(fn.u_mul_e(\"n\", 
\"w\", \"m\"), fn.sum(\"m\", \"n\"))\n            g.update_all(fn.copy_e(\"w\", \"m\"), fn.sum(\"m\", \"ws\"))\n            n = g.dstdata[\"n\"]\n            ws = g.dstdata[\"ws\"].unsqueeze(1).clamp(min=1)\n            z = self.act(self.W(self.dropout(torch.cat([n / ws, h_dst], 1))))\n            z_norm = z.norm(2, 1, keepdim=True)\n            z_norm = torch.where(\n                z_norm == 0, torch.tensor(1.0).to(z_norm), z_norm\n            )\n            z = z / z_norm\n            return z\n\n\nclass SAGENet(nn.Module):\n    def __init__(self, hidden_dims, n_layers):\n        \"\"\"\n        g : DGLGraph\n            The user-item interaction graph.\n            This is only for finding the range of categorical variables.\n        item_textsets : torchtext.data.Dataset\n            The textual features of each item node.\n        \"\"\"\n        super().__init__()\n\n        self.convs = nn.ModuleList()\n        for _ in range(n_layers):\n            self.convs.append(\n                WeightedSAGEConv(hidden_dims, hidden_dims, hidden_dims)\n            )\n\n    def forward(self, blocks, h):\n        for layer, block in zip(self.convs, blocks):\n            h_dst = h[: block.num_nodes(\"DST/\" + block.ntypes[0])]\n            h = layer(block, (h, h_dst), block.edata[\"weights\"])\n        return h\n\n\nclass LinearProjector(nn.Module):\n    \"\"\"\n    Projects each input feature of the graph linearly and sums them up\n    \"\"\"\n\n    def __init__(self, full_graph, ntype, textset, hidden_dims):\n        super().__init__()\n\n        self.ntype = ntype\n        self.inputs = _init_input_modules(\n            full_graph, ntype, textset, hidden_dims\n        )\n\n    def forward(self, ndata):\n        projections = []\n        for feature, data in ndata.items():\n            if feature == dgl.NID or feature.endswith(\"__len\"):\n                # This is an additional feature indicating the length of the ``feature``\n                # column; we shouldn't 
process this.\n                continue\n\n            module = self.inputs[feature]\n            if isinstance(module, (BagOfWords, BagOfWordsPretrained)):\n                # Textual feature; find the length and pass it to the textual module.\n                length = ndata[feature + \"__len\"]\n                result = module(data, length)\n            else:\n                result = module(data)\n            projections.append(result)\n\n        return torch.stack(projections, 1).sum(1)\n\n\nclass ItemToItemScorer(nn.Module):\n    def __init__(self, full_graph, ntype):\n        super().__init__()\n\n        n_nodes = full_graph.num_nodes(ntype)\n        self.bias = nn.Parameter(torch.zeros(n_nodes))\n\n    def _add_bias(self, edges):\n        bias_src = self.bias[edges.src[dgl.NID]]\n        bias_dst = self.bias[edges.dst[dgl.NID]]\n        return {\"s\": edges.data[\"s\"] + bias_src + bias_dst}\n\n    def forward(self, item_item_graph, h):\n        \"\"\"\n        item_item_graph : graph consists of edges connecting the pairs\n        h : hidden state of every node\n        \"\"\"\n        with item_item_graph.local_scope():\n            item_item_graph.ndata[\"h\"] = h\n            item_item_graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"s\"))\n            item_item_graph.apply_edges(self._add_bias)\n            pair_score = item_item_graph.edata[\"s\"]\n        return pair_score\n\n\nclass PinSAGEModel(nn.Module):\n    def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):\n        super().__init__()\n\n        self.proj = LinearProjector(full_graph, ntype, textsets, hidden_dims)\n        self.sage = SAGENet(hidden_dims, n_layers)\n        self.scorer = ItemToItemScorer(full_graph, ntype)\n\n    def forward(self, pos_graph, neg_graph, blocks):\n        h_item = self.get_repr(blocks)\n        pos_score = self.scorer(pos_graph, h_item)\n        neg_score = self.scorer(neg_graph, h_item)\n        return (neg_score - pos_score + 1).clamp(min=0)\n\n 
   def get_repr(self, blocks):\n        h_item = self.proj(blocks[0].srcdata)\n        h_item_dst = self.proj(blocks[-1].dstdata)\n        return h_item_dst + self.sage(blocks, h_item)\n\n\ndef compact_and_copy(frontier, seeds):\n    block = dgl.to_block(frontier, seeds)\n    for col, data in frontier.edata.items():\n        if col == dgl.EID:\n            continue\n        block.edata[col] = data[block.edata[dgl.EID]]\n    return block\n\n\nclass ItemToItemBatchSampler(IterableDataset):\n    def __init__(self, g, user_type, item_type, batch_size):\n        self.g = g\n        self.user_type = user_type\n        self.item_type = item_type\n        self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]\n        self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]\n        self.batch_size = batch_size\n\n    def __iter__(self):\n        while True:\n            heads = torch.randint(\n                0, self.g.num_nodes(self.item_type), (self.batch_size,)\n            )\n            tails = dgl.sampling.random_walk(\n                self.g,\n                heads,\n                metapath=[self.item_to_user_etype, self.user_to_item_etype],\n            )[0][:, 2]\n            neg_tails = torch.randint(\n                0, self.g.num_nodes(self.item_type), (self.batch_size,)\n            )\n\n            mask = tails != -1\n            yield heads[mask], tails[mask], neg_tails[mask]\n\n\nclass NeighborSampler(object):\n    def __init__(\n        self,\n        g,\n        user_type,\n        item_type,\n        random_walk_length,\n        random_walk_restart_prob,\n        num_random_walks,\n        num_neighbors,\n        num_layers,\n    ):\n        self.g = g\n        self.user_type = user_type\n        self.item_type = item_type\n        self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]\n        self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]\n        self.samplers = [\n            
dgl.sampling.PinSAGESampler(\n                g,\n                item_type,\n                user_type,\n                random_walk_length,\n                random_walk_restart_prob,\n                num_random_walks,\n                num_neighbors,\n            )\n            for _ in range(num_layers)\n        ]\n\n    def sample_blocks(self, seeds, heads=None, tails=None, neg_tails=None):\n        blocks = []\n        for sampler in self.samplers:\n            frontier = sampler(seeds)\n            if heads is not None:\n                eids = frontier.edge_ids(\n                    torch.cat([heads, heads]),\n                    torch.cat([tails, neg_tails]),\n                    return_uv=True,\n                )[2]\n                if len(eids) > 0:\n                    old_frontier = frontier\n                    frontier = dgl.remove_edges(old_frontier, eids)\n                    # print(old_frontier)\n                    # print(frontier)\n                    # print(frontier.edata['weights'])\n                    # frontier.edata['weights'] = old_frontier.edata['weights'][frontier.edata[dgl.EID]]\n            block = compact_and_copy(frontier, seeds)\n            seeds = block.srcdata[dgl.NID]\n            blocks.insert(0, block)\n        return blocks\n\n    def sample_from_item_pairs(self, heads, tails, neg_tails):\n        # Create a graph with positive connections only and another graph with negative\n        # connections only.\n        pos_graph = dgl.graph(\n            (heads, tails), num_nodes=self.g.num_nodes(self.item_type)\n        )\n        neg_graph = dgl.graph(\n            (heads, neg_tails), num_nodes=self.g.num_nodes(self.item_type)\n        )\n        pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])\n        seeds = pos_graph.ndata[dgl.NID]\n\n        blocks = self.sample_blocks(seeds, heads, tails, neg_tails)\n        return pos_graph, neg_graph, blocks\n\n\ndef assign_simple_node_features(ndata, g, ntype, 
assign_id=False):\n    \"\"\"\n    Copies data to the given block from the corresponding nodes in the original graph.\n    \"\"\"\n    for col in g.nodes[ntype].data.keys():\n        if not assign_id and col == dgl.NID:\n            continue\n        induced_nodes = ndata[dgl.NID]\n        ndata[col] = g.nodes[ntype].data[col][induced_nodes]\n\n\ndef assign_textual_node_features(ndata, textset, ntype):\n    \"\"\"\n    Assigns numericalized tokens from a torchtext dataset to given block.\n\n    The numericalized tokens would be stored in the block as node features\n    with the same name as ``field_name``.\n\n    The length would be stored as another node feature with name\n    ``field_name + '__len'``.\n\n    block : DGLGraph\n        First element of the compacted blocks, with \"dgl.NID\" as the\n        corresponding node ID in the original graph, hence the index to the\n        text dataset.\n\n        The numericalized tokens (and lengths if available) would be stored\n        onto the blocks as new node features.\n    textset : torchtext.data.Dataset\n        A torchtext dataset whose number of examples is the same as that\n        of nodes in the original graph.\n    \"\"\"\n    node_ids = ndata[dgl.NID].numpy()\n\n    for field_name, field in textset.fields.items():\n        examples = [getattr(textset[i], field_name) for i in node_ids]\n\n        tokens, lengths = field.process(examples)\n\n        if not field.batch_first:\n            tokens = tokens.t()\n\n        ndata[field_name] = tokens\n        ndata[field_name + \"__len\"] = lengths\n\n\ndef assign_features_to_blocks(blocks, g, textset, ntype):\n    # For the first block (which is closest to the input), copy the features from\n    # the original graph as well as the texts.\n    assign_simple_node_features(blocks[0].srcdata, g, ntype)\n    assign_textual_node_features(blocks[0].srcdata, textset, ntype)\n    assign_simple_node_features(blocks[-1].dstdata, g, ntype)\n    
assign_textual_node_features(blocks[-1].dstdata, textset, ntype)\n\n\nclass PinSAGECollator(object):\n    def __init__(self, sampler, g, ntype, textset):\n        self.sampler = sampler\n        self.ntype = ntype\n        self.g = g\n        self.textset = textset\n\n    def collate_train(self, batches):\n        heads, tails, neg_tails = batches[0]\n        # Construct multilayer neighborhood via PinSAGE...\n        pos_graph, neg_graph, blocks = self.sampler.sample_from_item_pairs(\n            heads, tails, neg_tails\n        )\n        assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)\n\n        return pos_graph, neg_graph, blocks\n\n    def collate_test(self, samples):\n        batch = torch.LongTensor(samples)\n        blocks = self.sampler.sample_blocks(batch)\n        assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)\n        return blocks\n\n\n@utils.benchmark(\"time\", 600)\n@utils.parametrize(\"data\", [\"nowplaying_rs\"])\ndef track_time(data):\n    dataset = utils.process_data(data)\n    device = utils.get_bench_device()\n\n    user_ntype = dataset.user_ntype\n    item_ntype = dataset.item_ntype\n    textset = dataset.textset\n\n    batch_size = 32\n    random_walk_length = 2\n    random_walk_restart_prob = 0.5\n    num_random_walks = 10\n    num_neighbors = 3\n    num_layers = 2\n    num_workers = 0\n    hidden_dims = 16\n    lr = 3e-5\n    iter_start = 3\n    iter_count = 10\n\n    g = dataset[0]\n    # Sampler\n    batch_sampler = ItemToItemBatchSampler(\n        g, user_ntype, item_ntype, batch_size\n    )\n    neighbor_sampler = NeighborSampler(\n        g,\n        user_ntype,\n        item_ntype,\n        random_walk_length,\n        random_walk_restart_prob,\n        num_random_walks,\n        num_neighbors,\n        num_layers,\n    )\n    collator = PinSAGECollator(neighbor_sampler, g, item_ntype, textset)\n    dataloader = DataLoader(\n        batch_sampler,\n        collate_fn=collator.collate_train,\n   
     num_workers=num_workers,\n    )\n    dataloader_test = DataLoader(\n        torch.arange(g.num_nodes(item_ntype)),\n        batch_size=batch_size,\n        collate_fn=collator.collate_test,\n        num_workers=num_workers,\n    )\n\n    # Model\n    model = PinSAGEModel(g, item_ntype, textset, hidden_dims, num_layers).to(\n        device\n    )\n    # Optimizer\n    opt = torch.optim.Adam(model.parameters(), lr=lr)\n\n    model.train()\n\n    print(\"start training...\")\n    # For each batch of head-tail-negative triplets...\n    for batch_id, (pos_graph, neg_graph, blocks) in enumerate(dataloader):\n        # Copy to GPU\n        for i in range(len(blocks)):\n            blocks[i] = blocks[i].to(device)\n        pos_graph = pos_graph.to(device)\n        neg_graph = neg_graph.to(device)\n\n        loss = model(pos_graph, neg_graph, blocks).mean()\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n        # start the timer after the first iter_start (warm-up) iterations\n        if batch_id == iter_start - 1:\n            t0 = time.time()\n        elif (\n            batch_id == iter_count + iter_start - 1\n        ):  # time iter_count iterations\n            break\n\n    t1 = time.time()\n\n    return (t1 - t0) / iter_count\n"
  },
  {
    "path": "benchmarks/benchmarks/model_speed/bench_rgcn_base.py",
    "content": "import time\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .. import rgcn, utils\n\n\n@utils.benchmark(\"time\", 1200)\n@utils.parametrize(\"data\", [\"aifb\", \"am\"])\ndef track_time(data):\n    # args\n    if data == \"aifb\":\n        num_bases = -1\n        l2norm = 0.0\n    elif data == \"am\":\n        num_bases = 40\n        l2norm = 5e-4\n    else:\n        raise ValueError()\n\n    (\n        g,\n        num_rels,\n        num_classes,\n        labels,\n        train_idx,\n        test_idx,\n        target_idx,\n    ) = rgcn.load_data(data, get_norm=True)\n    num_hidden = 16\n\n    model = rgcn.RGCN(\n        g.num_nodes(), num_hidden, num_classes, num_rels, num_bases=num_bases\n    )\n    device = utils.get_bench_device()\n    labels = labels.to(device)\n    model = model.to(device)\n    g = g.int().to(device)\n\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=1e-2, weight_decay=l2norm\n    )\n\n    model.train()\n    num_epochs = 30\n    t0 = time.time()\n    for epoch in range(num_epochs):\n        logits = model(g)\n        logits = logits[target_idx]\n        loss = F.cross_entropy(logits[train_idx], labels[train_idx])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n    t1 = time.time()\n\n    return (t1 - t0) / num_epochs\n"
  },
  {
    "path": "benchmarks/benchmarks/model_speed/bench_rgcn_hetero_ns.py",
    "content": "import itertools\nimport time\nimport traceback\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\n\nfrom .. import utils\n\n\nclass RelGraphConvLayer(nn.Module):\n    r\"\"\"Relational graph convolution layer.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size.\n    out_feat : int\n        Output feature size.\n    rel_names : list[str]\n        Relation names.\n    num_bases : int, optional\n        Number of bases. If is none, use number of relations. Default: None.\n    weight : bool, optional\n        True if a linear layer is applied after message passing. Default: True\n    bias : bool, optional\n        True if bias is added. Default: True\n    activation : callable, optional\n        Activation function. Default: None\n    self_loop : bool, optional\n        True to include self loop message. Default: False\n    dropout : float, optional\n        Dropout rate. 
Default: 0.0\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feat,\n        out_feat,\n        rel_names,\n        num_bases,\n        *,\n        weight=True,\n        bias=True,\n        activation=None,\n        self_loop=False,\n        dropout=0.0\n    ):\n        super(RelGraphConvLayer, self).__init__()\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.rel_names = rel_names\n        self.num_bases = num_bases\n        self.bias = bias\n        self.activation = activation\n        self.self_loop = self_loop\n\n        self.conv = dglnn.HeteroGraphConv(\n            {\n                rel: dglnn.GraphConv(\n                    in_feat, out_feat, norm=\"right\", weight=False, bias=False\n                )\n                for rel in rel_names\n            }\n        )\n\n        self.use_weight = weight\n        self.use_basis = num_bases < len(self.rel_names) and weight\n        if self.use_weight:\n            if self.use_basis:\n                self.basis = dglnn.WeightBasis(\n                    (in_feat, out_feat), num_bases, len(self.rel_names)\n                )\n            else:\n                self.weight = nn.Parameter(\n                    th.Tensor(len(self.rel_names), in_feat, out_feat)\n                )\n                nn.init.xavier_uniform_(\n                    self.weight, gain=nn.init.calculate_gain(\"relu\")\n                )\n\n        # bias\n        if bias:\n            self.h_bias = nn.Parameter(th.Tensor(out_feat))\n            nn.init.zeros_(self.h_bias)\n\n        # weight for self loop\n        if self.self_loop:\n            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))\n            nn.init.xavier_uniform_(\n                self.loop_weight, gain=nn.init.calculate_gain(\"relu\")\n            )\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, g, inputs):\n        \"\"\"Forward computation\n\n        Parameters\n        ----------\n        g : 
DGLGraph\n            Input graph.\n        inputs : dict[str, torch.Tensor]\n            Node feature for each node type.\n\n        Returns\n        -------\n        dict[str, torch.Tensor]\n            New node features for each node type.\n        \"\"\"\n        g = g.local_var()\n        if self.use_weight:\n            weight = self.basis() if self.use_basis else self.weight\n            wdict = {\n                self.rel_names[i]: {\"weight\": w.squeeze(0)}\n                for i, w in enumerate(th.split(weight, 1, dim=0))\n            }\n        else:\n            wdict = {}\n\n        if g.is_block:\n            inputs_src = inputs\n            inputs_dst = {\n                k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()\n            }\n        else:\n            inputs_src = inputs_dst = inputs\n\n        hs = self.conv(g, inputs, mod_kwargs=wdict)\n\n        def _apply(ntype, h):\n            if self.self_loop:\n                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)\n            if self.bias:\n                h = h + self.h_bias\n            if self.activation:\n                h = self.activation(h)\n            return self.dropout(h)\n\n        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}\n\n\nclass RelGraphEmbed(nn.Module):\n    r\"\"\"Embedding layer for featureless heterograph.\"\"\"\n\n    def __init__(\n        self,\n        g,\n        device,\n        embed_size,\n        num_nodes,\n        node_feats,\n        embed_name=\"embed\",\n        activation=None,\n        dropout=0.0,\n    ):\n        super(RelGraphEmbed, self).__init__()\n        self.g = g\n        self.device = device\n        self.embed_size = embed_size\n        self.embed_name = embed_name\n        self.activation = activation\n        self.dropout = nn.Dropout(dropout)\n        self.node_feats = node_feats\n\n        # create weight embeddings for each node for each relation\n        self.embeds = nn.ParameterDict()\n        
self.node_embeds = nn.ModuleDict()\n        for ntype in g.ntypes:\n            if node_feats[ntype] is None:\n                sparse_emb = th.nn.Embedding(\n                    num_nodes[ntype], embed_size, sparse=True\n                )\n                nn.init.uniform_(sparse_emb.weight, -1.0, 1.0)\n                self.node_embeds[ntype] = sparse_emb\n            else:\n                input_emb_size = node_feats[ntype].shape[1]\n                embed = nn.Parameter(th.Tensor(input_emb_size, embed_size))\n                nn.init.xavier_uniform_(embed)\n                self.embeds[ntype] = embed\n\n    def forward(self, block=None):\n        \"\"\"Forward computation\n\n        Parameters\n        ----------\n        block : DGLGraph\n            Block whose node types and ``block.nodes(ntype)`` indices select the\n            embeddings to compute. Required despite the ``None`` default: the body\n            reads ``block.ntypes`` unconditionally.\n\n        Returns\n        -------\n        dict[str, torch.Tensor]\n            Embeddings keyed by node type: looked up from the learned embedding table\n            when the type has no input features, otherwise the input features\n            multiplied by that type's projection matrix.\n        \"\"\"\n        embeds = {}\n        for ntype in block.ntypes:\n            if self.node_feats[ntype] is None:\n                embeds[ntype] = self.node_embeds[ntype](block.nodes(ntype)).to(\n                    self.device\n                )\n            else:\n                embeds[ntype] = (\n                    self.node_feats[ntype][block.nodes(ntype)].to(self.device)\n                    @ self.embeds[ntype]\n                )\n        return embeds\n\n\nclass EntityClassify(nn.Module):\n    def __init__(\n        self,\n        g,\n        h_dim,\n        out_dim,\n        num_bases,\n        num_hidden_layers=1,\n        dropout=0,\n        use_self_loop=False,\n    ):\n        super(EntityClassify, self).__init__()\n        self.g = g\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.rel_names = list(set(g.etypes))\n        self.rel_names.sort()\n        if num_bases < 0 
or num_bases > len(self.rel_names):\n            self.num_bases = len(self.rel_names)\n        else:\n            self.num_bases = num_bases\n        self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.use_self_loop = use_self_loop\n\n        self.layers = nn.ModuleList()\n        # i2h\n        self.layers.append(\n            RelGraphConvLayer(\n                self.h_dim,\n                self.h_dim,\n                self.rel_names,\n                self.num_bases,\n                activation=F.relu,\n                self_loop=self.use_self_loop,\n                dropout=self.dropout,\n                weight=False,\n            )\n        )\n        # h2h\n        for i in range(self.num_hidden_layers):\n            self.layers.append(\n                RelGraphConvLayer(\n                    self.h_dim,\n                    self.h_dim,\n                    self.rel_names,\n                    self.num_bases,\n                    activation=F.relu,\n                    self_loop=self.use_self_loop,\n                    dropout=self.dropout,\n                )\n            )\n        # h2o\n        self.layers.append(\n            RelGraphConvLayer(\n                self.h_dim,\n                self.out_dim,\n                self.rel_names,\n                self.num_bases,\n                activation=None,\n                self_loop=self.use_self_loop,\n            )\n        )\n\n    def forward(self, h, blocks):\n        for layer, block in zip(self.layers, blocks):\n            h = layer(block, h)\n        return h\n\n\n@utils.benchmark(\"time\", 600)\n@utils.parametrize(\"data\", [\"ogbn-mag\"])\ndef track_time(data):\n    dataset = utils.process_data(data)\n    device = utils.get_bench_device()\n\n    if data == \"ogbn-mag\":\n        n_bases = 2\n        l2norm = 0\n    else:\n        raise ValueError()\n\n    fanout = 4\n    n_layers = 2\n    batch_size = 1024\n    n_hidden = 64\n    dropout = 0.5\n    use_self_loop = 
True\n    lr = 0.01\n    iter_start = 3\n    iter_count = 10\n\n    hg = dataset[0]\n    category = dataset.predict_category\n    num_classes = dataset.num_classes\n    train_mask = hg.nodes[category].data.pop(\"train_mask\")\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()\n    labels = hg.nodes[category].data.pop(\"labels\")\n\n    node_feats = {}\n    num_nodes = {}\n    for ntype in hg.ntypes:\n        node_feats[ntype] = (\n            hg.nodes[ntype].data[\"feat\"]\n            if \"feat\" in hg.nodes[ntype].data\n            else None\n        )\n        num_nodes[ntype] = hg.num_nodes(ntype)\n\n    embed_layer = RelGraphEmbed(hg, device, n_hidden, num_nodes, node_feats)\n    model = EntityClassify(\n        hg,\n        n_hidden,\n        num_classes,\n        num_bases=n_bases,\n        num_hidden_layers=n_layers - 2,\n        dropout=dropout,\n        use_self_loop=use_self_loop,\n    )\n    embed_layer = embed_layer.to(device)\n    model = model.to(device)\n\n    all_params = itertools.chain(\n        model.parameters(), embed_layer.embeds.parameters()\n    )\n    optimizer = th.optim.Adam(all_params, lr=lr, weight_decay=l2norm)\n    sparse_optimizer = th.optim.SparseAdam(\n        list(embed_layer.node_embeds.parameters()), lr=lr\n    )\n\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([fanout] * n_layers)\n    loader = dgl.dataloading.DataLoader(\n        hg,\n        {category: train_idx},\n        sampler,\n        batch_size=batch_size,\n        shuffle=True,\n        num_workers=4,\n    )\n\n    print(\"start training...\")\n    model.train()\n    embed_layer.train()\n    optimizer.zero_grad()\n    sparse_optimizer.zero_grad()\n\n    # Enable dataloader cpu affinitization for cpu devices (no effect on gpu)\n    with loader.enable_cpu_affinity():\n        for step, (input_nodes, seeds, blocks) in enumerate(loader):\n            blocks = [blk.to(device) for blk in blocks]\n            seeds = seeds[\n                
category\n            ]  # we only predict the nodes with type \"category\"\n            batch_tic = time.time()\n            emb = embed_layer(blocks[0])\n            lbl = labels[seeds].to(device)\n            emb = {k: e.to(device) for k, e in emb.items()}\n            logits = model(emb, blocks)[category]\n            loss = F.cross_entropy(logits, lbl)\n            loss.backward()\n            optimizer.step()\n            sparse_optimizer.step()\n\n            # start the timer after the first iter_start (warm-up) iterations\n            if step == iter_start - 1:\n                t0 = time.time()\n            elif (\n                step == iter_count + iter_start - 1\n            ):  # time iter_count iterations\n                break\n\n    t1 = time.time()\n\n    return (t1 - t0) / iter_count\n"
  },
  {
    "path": "benchmarks/benchmarks/model_speed/bench_rgcn_homogeneous_ns.py",
    "content": "import itertools\nimport time\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dgl.nn import RelGraphConv\nfrom torch.utils.data import DataLoader\n\nfrom .. import utils\n\n\nclass EntityClassify(nn.Module):\n    \"\"\"Entity classification class for RGCN\n    Parameters\n    ----------\n    device : int\n        Device to run the layer.\n    num_nodes : int\n        Number of nodes.\n    h_dim : int\n        Hidden dim size.\n    out_dim : int\n        Output dim size.\n    num_rels : int\n        Numer of relation types.\n    num_bases : int\n        Number of bases. If is none, use number of relations.\n    num_hidden_layers : int\n        Number of hidden RelGraphConv Layer\n    dropout : float\n        Dropout\n    use_self_loop : bool\n        Use self loop if True, default False.\n    \"\"\"\n\n    def __init__(\n        self,\n        device,\n        num_nodes,\n        h_dim,\n        out_dim,\n        num_rels,\n        num_bases=None,\n        num_hidden_layers=1,\n        dropout=0,\n        use_self_loop=False,\n        layer_norm=False,\n    ):\n        super(EntityClassify, self).__init__()\n        self.device = device\n        self.num_nodes = num_nodes\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.num_rels = num_rels\n        self.num_bases = None if num_bases < 0 else num_bases\n        self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.use_self_loop = use_self_loop\n        self.layer_norm = layer_norm\n\n        self.layers = nn.ModuleList()\n        # i2h\n        self.layers.append(\n            RelGraphConv(\n                self.h_dim,\n                self.h_dim,\n                self.num_rels,\n                \"basis\",\n                self.num_bases,\n                activation=F.relu,\n                
self_loop=self.use_self_loop,\n                dropout=self.dropout,\n                layer_norm=layer_norm,\n            )\n        )\n        # h2h\n        for idx in range(self.num_hidden_layers):\n            self.layers.append(\n                RelGraphConv(\n                    self.h_dim,\n                    self.h_dim,\n                    self.num_rels,\n                    \"basis\",\n                    self.num_bases,\n                    activation=F.relu,\n                    self_loop=self.use_self_loop,\n                    dropout=self.dropout,\n                    layer_norm=layer_norm,\n                )\n            )\n        # h2o\n        self.layers.append(\n            RelGraphConv(\n                self.h_dim,\n                self.out_dim,\n                self.num_rels,\n                \"basis\",\n                self.num_bases,\n                activation=None,\n                self_loop=self.use_self_loop,\n                layer_norm=layer_norm,\n            )\n        )\n\n    def forward(self, blocks, feats, norm=None):\n        if blocks is None:\n            # full graph training\n            blocks = [self.g] * len(self.layers)\n        h = feats\n        for layer, block in zip(self.layers, blocks):\n            block = block.to(self.device)\n            h = layer(block, h, block.edata[\"etype\"], block.edata[\"norm\"])\n        return h\n\n\nclass RelGraphEmbedLayer(nn.Module):\n    r\"\"\"Embedding layer for featureless heterograph.\n    Parameters\n    ----------\n    device : int\n        Device to run the layer.\n    num_nodes : int\n        Number of nodes.\n    node_tids : tensor\n        Storing the node type id for each node starting from 0\n    num_of_ntype : int\n        Number of node types\n    input_size : list of int\n        A list of input feature size for each node type. 
If None, we then\n        treat certain input feature as an one-hot encoding feature.\n    embed_size : int\n        Output embed size\n    embed_name : str, optional\n        Embed name\n    \"\"\"\n\n    def __init__(\n        self,\n        device,\n        num_nodes,\n        node_tids,\n        num_of_ntype,\n        input_size,\n        embed_size,\n        sparse_emb=False,\n        embed_name=\"embed\",\n    ):\n        super(RelGraphEmbedLayer, self).__init__()\n        self.device = device\n        self.embed_size = embed_size\n        self.embed_name = embed_name\n        self.num_nodes = num_nodes\n        self.sparse_emb = sparse_emb\n\n        # create weight embeddings for each node for each relation\n        self.embeds = nn.ParameterDict()\n        self.num_of_ntype = num_of_ntype\n        self.idmap = th.empty(num_nodes).long()\n\n        for ntype in range(num_of_ntype):\n            if input_size[ntype] is not None:\n                input_emb_size = input_size[ntype].shape[1]\n                embed = nn.Parameter(th.Tensor(input_emb_size, self.embed_size))\n                nn.init.xavier_uniform_(embed)\n                self.embeds[str(ntype)] = embed\n\n        self.node_embeds = th.nn.Embedding(\n            node_tids.shape[0], self.embed_size, sparse=self.sparse_emb\n        )\n        nn.init.uniform_(self.node_embeds.weight, -1.0, 1.0)\n\n    def forward(self, node_ids, node_tids, type_ids, features):\n        \"\"\"Forward computation\n        Parameters\n        ----------\n        node_ids : tensor\n            node ids to generate embedding for.\n        node_tids : tensor\n            node type ids\n        features : list of features\n            list of initial features for nodes belong to different node type.\n            If None, the corresponding features is an one-hot encoding feature,\n            else use the features directly as input feature and matmul a\n            projection matrix.\n        Returns\n        -------\n      
  tensor\n            embeddings as the input of the next layer\n        \"\"\"\n        tsd_ids = node_ids.to(self.node_embeds.weight.device)\n        embeds = th.empty(\n            node_ids.shape[0], self.embed_size, device=self.device\n        )\n        for ntype in range(self.num_of_ntype):\n            if features[ntype] is not None:\n                loc = node_tids == ntype\n                embeds[loc] = features[ntype][type_ids[loc]].to(\n                    self.device\n                ) @ self.embeds[str(ntype)].to(self.device)\n            else:\n                loc = node_tids == ntype\n                embeds[loc] = self.node_embeds(tsd_ids[loc]).to(self.device)\n\n        return embeds\n\n\n@utils.benchmark(\"time\", 600)\n@utils.parametrize(\"data\", [\"am\", \"ogbn-mag\"])\ndef track_time(data):\n    dataset = utils.process_data(data)\n    device = utils.get_bench_device()\n\n    if data == \"am\":\n        batch_size = 64\n        n_bases = 40\n        l2norm = 5e-4\n    elif data == \"ogbn-mag\":\n        batch_size = 1024\n        n_bases = 2\n        l2norm = 0\n    else:\n        raise ValueError()\n\n    fanouts = [25, 15]\n    n_layers = 2\n    n_hidden = 64\n    dropout = 0.5\n    use_self_loop = True\n    lr = 0.01\n    num_workers = 4\n    iter_start = 3\n    iter_count = 10\n\n    hg = dataset[0]\n    category = dataset.predict_category\n    num_classes = dataset.num_classes\n    train_mask = hg.nodes[category].data.pop(\"train_mask\")\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()\n    labels = hg.nodes[category].data.pop(\"labels\").to(device)\n    num_of_ntype = len(hg.ntypes)\n    num_rels = len(hg.canonical_etypes)\n\n    node_feats = []\n    for ntype in hg.ntypes:\n        if len(hg.nodes[ntype].data) == 0 or \"feat\" not in hg.nodes[ntype].data:\n            node_feats.append(None)\n        else:\n            feat = hg.nodes[ntype].data.pop(\"feat\")\n            node_feats.append(feat.share_memory_())\n\n    # 
get target category id\n    category_id = len(hg.ntypes)\n    for i, ntype in enumerate(hg.ntypes):\n        if ntype == category:\n            category_id = i\n    g = dgl.to_homogeneous(hg)\n    u, v, eid = g.all_edges(form=\"all\")\n\n    # global norm\n    _, inverse_index, count = th.unique(\n        v, return_inverse=True, return_counts=True\n    )\n    degrees = count[inverse_index]\n    norm = th.ones(eid.shape[0]) / degrees\n    norm = norm.unsqueeze(1)\n    g.edata[\"norm\"] = norm\n    g.edata[\"etype\"] = g.edata[dgl.ETYPE]\n    g.ndata[\"type_id\"] = g.ndata[dgl.NID]\n    g.ndata[\"ntype\"] = g.ndata[dgl.NTYPE]\n\n    node_ids = th.arange(g.num_nodes())\n    # find out the target node ids\n    node_tids = g.ndata[dgl.NTYPE]\n    loc = node_tids == category_id\n    target_nids = node_ids[loc]\n    train_nids = target_nids[train_idx]\n\n    g = g.formats(\"csc\")\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)\n    loader = dgl.dataloading.DataLoader(\n        g,\n        target_nids[train_idx],\n        sampler,\n        batch_size=batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=num_workers,\n    )\n\n    # node features\n    # None for one-hot feature, if not none, it should be the feature tensor.\n    #\n    embed_layer = RelGraphEmbedLayer(\n        device,\n        g.num_nodes(),\n        node_tids,\n        num_of_ntype,\n        node_feats,\n        n_hidden,\n        sparse_emb=True,\n    )\n\n    # create model\n    # all model params are in device.\n    model = EntityClassify(\n        device,\n        g.num_nodes(),\n        n_hidden,\n        num_classes,\n        num_rels,\n        num_bases=n_bases,\n        num_hidden_layers=n_layers - 2,\n        dropout=dropout,\n        use_self_loop=use_self_loop,\n        layer_norm=False,\n    )\n\n    embed_layer = embed_layer.to(device)\n    model = model.to(device)\n\n    all_params = itertools.chain(\n        model.parameters(), 
embed_layer.embeds.parameters()\n    )\n    optimizer = th.optim.Adam(all_params, lr=lr, weight_decay=l2norm)\n    emb_optimizer = th.optim.SparseAdam(\n        list(embed_layer.node_embeds.parameters()), lr=lr\n    )\n\n    print(\"start training...\")\n    model.train()\n    embed_layer.train()\n\n    # Enable dataloader cpu affinitization for cpu devices (no effect on gpu)\n    with loader.enable_cpu_affinity():\n        for step, sample_data in enumerate(loader):\n            input_nodes, output_nodes, blocks = sample_data\n            feats = embed_layer(\n                input_nodes,\n                blocks[0].srcdata[\"ntype\"],\n                blocks[0].srcdata[\"type_id\"],\n                node_feats,\n            )\n            logits = model(blocks, feats)\n            seed_idx = blocks[-1].dstdata[\"type_id\"]\n            loss = F.cross_entropy(logits, labels[seed_idx])\n            optimizer.zero_grad()\n            emb_optimizer.zero_grad()\n\n            loss.backward()\n            optimizer.step()\n            emb_optimizer.step()\n\n            # start timer at before iter_start\n            if step == iter_start - 1:\n                t0 = time.time()\n            elif (\n                step == iter_count + iter_start - 1\n            ):  # time iter_count iterations\n                break\n\n    t1 = time.time()\n\n    return (t1 - t0) / iter_count\n"
  },
  {
    "path": "benchmarks/benchmarks/model_speed/bench_sage.py",
    "content": "import time\n\nimport dgl\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import SAGEConv\n\nfrom .. import utils\n\n\nclass GraphSAGE(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        n_hidden,\n        n_classes,\n        n_layers,\n        activation,\n        dropout,\n        aggregator_type,\n    ):\n        super(GraphSAGE, self).__init__()\n        self.layers = nn.ModuleList()\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n        # input layer\n        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))\n        # output layer\n        self.layers.append(\n            SAGEConv(n_hidden, n_classes, aggregator_type)\n        )  # activation None\n\n    def forward(self, graph, inputs):\n        h = self.dropout(inputs)\n        for l, layer in enumerate(self.layers):\n            h = layer(graph, h)\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n\n@utils.benchmark(\"time\")\n@utils.parametrize(\"data\", [\"cora\", \"pubmed\"])\ndef track_time(data):\n    data = utils.process_data(data)\n    device = utils.get_bench_device()\n    num_epochs = 200\n\n    g = data[0].to(device)\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # create model\n    model = GraphSAGE(in_feats, 16, n_classes, 1, F.relu, 0.5, \"gcn\")\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    model = model.to(device)\n    model.train()\n\n    # 
optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n\n    # dry run\n    for i in range(10):\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n    # timing\n    t0 = time.time()\n    for epoch in range(num_epochs):\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n    t1 = time.time()\n\n    return (t1 - t0) / num_epochs\n"
  },
  {
    "path": "benchmarks/benchmarks/model_speed/bench_sage_ns.py",
    "content": "import time\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\n\nfrom .. import utils\n\n\nclass SAGE(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        for i in range(1, n_layers - 1):\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n\ndef load_subtensor(g, seeds, input_nodes, device):\n    \"\"\"\n    Copys features and labels of a set of nodes onto GPU.\n    \"\"\"\n    batch_inputs = g.ndata[\"features\"][input_nodes].to(device)\n    batch_labels = g.ndata[\"labels\"][seeds].to(device)\n    return batch_inputs, batch_labels\n\n\n@utils.benchmark(\"time\", 600)\n@utils.parametrize(\"data\", [\"reddit\", \"ogbn-products\"])\ndef track_time(data):\n    data = utils.process_data(data)\n    device = utils.get_bench_device()\n    g = data[0]\n    g.ndata[\"features\"] = g.ndata[\"feat\"]\n    g.ndata[\"labels\"] = g.ndata[\"label\"]\n    in_feats = g.ndata[\"features\"].shape[1]\n    n_classes = data.num_classes\n\n    # Create csr/coo/csc formats before launching training processes with 
multi-gpu.\n    # This avoids creating certain formats in each sub-process, which saves momory and CPU.\n    g.create_formats_()\n\n    num_epochs = 20\n    num_hidden = 16\n    num_layers = 2\n    fan_out = \"10,25\"\n    batch_size = 1024\n    lr = 0.003\n    dropout = 0.5\n    num_workers = 4\n    iter_start = 3\n    iter_count = 10\n\n    train_nid = th.nonzero(g.ndata[\"train_mask\"], as_tuple=True)[0]\n\n    # Create PyTorch DataLoader for constructing blocks\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [int(fanout) for fanout in fan_out.split(\",\")]\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        train_nid,\n        sampler,\n        batch_size=batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=num_workers,\n    )\n\n    # Define model and optimizer\n    model = SAGE(in_feats, num_hidden, n_classes, num_layers, F.relu, dropout)\n    model = model.to(device)\n    loss_fcn = nn.CrossEntropyLoss()\n    loss_fcn = loss_fcn.to(device)\n    optimizer = optim.Adam(model.parameters(), lr=lr)\n\n    # Enable dataloader cpu affinitization for cpu devices (no effect on gpu)\n    with dataloader.enable_cpu_affinity():\n        # Training loop\n        avg = 0\n        iter_tput = []\n\n        for step, (input_nodes, seeds, blocks) in enumerate(dataloader):\n            # Load the input features as well as output labels\n            # batch_inputs, batch_labels = load_subtensor(g, seeds, input_nodes, device)\n            blocks = [block.int().to(device) for block in blocks]\n            batch_inputs = blocks[0].srcdata[\"features\"]\n            batch_labels = blocks[-1].dstdata[\"labels\"]\n\n            # Compute loss and prediction\n            batch_pred = model(blocks, batch_inputs)\n            loss = loss_fcn(batch_pred, batch_labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            # start timer at before iter_start\n        
    if step == iter_start - 1:\n                t0 = time.time()\n            elif (\n                step == iter_count + iter_start - 1\n            ):  # time iter_count iterations\n                break\n\n    t1 = time.time()\n\n    return (t1 - t0) / iter_count\n"
  },
  {
    "path": "benchmarks/benchmarks/model_speed/bench_sage_unsupervised_ns.py",
    "content": "import time\n\nimport dgl\nimport dgl.function as fn\nimport dgl.nn.pytorch as dglnn\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\n\nfrom .. import utils\n\n\nclass NegativeSampler(object):\n    def __init__(self, g, k, neg_share=False):\n        self.weights = g.in_degrees().float() ** 0.75\n        self.k = k\n        self.neg_share = neg_share\n\n    def __call__(self, g, eids):\n        src, _ = g.find_edges(eids)\n        n = len(src)\n        if self.neg_share and n % self.k == 0:\n            dst = self.weights.multinomial(n, replacement=True)\n            dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()\n        else:\n            dst = self.weights.multinomial(n * self.k, replacement=True)\n        src = src.repeat_interleave(self.k)\n        return src, dst\n\n\ndef load_subtensor(g, input_nodes, device):\n    \"\"\"\n    Copys features and labels of a set of nodes onto GPU.\n    \"\"\"\n    batch_inputs = g.ndata[\"features\"][input_nodes].to(device)\n    return batch_inputs\n\n\nclass SAGE(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        for i in range(1, n_layers - 1):\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = 
self.activation(h)\n                h = self.dropout(h)\n        return h\n\n\ndef load_subtensor(g, input_nodes, device):\n    \"\"\"\n    Copies the features of a set of nodes onto the target device.\n    \"\"\"\n    batch_inputs = g.ndata[\"features\"][input_nodes].to(device)\n    return batch_inputs\n\n\nclass CrossEntropyLoss(nn.Module):\n    def forward(self, block_outputs, pos_graph, neg_graph):\n        with pos_graph.local_scope():\n            pos_graph.ndata[\"h\"] = block_outputs\n            pos_graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"score\"))\n            pos_score = pos_graph.edata[\"score\"]\n        with neg_graph.local_scope():\n            neg_graph.ndata[\"h\"] = block_outputs\n            neg_graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"score\"))\n            neg_score = neg_graph.edata[\"score\"]\n\n        score = th.cat([pos_score, neg_score])\n        label = th.cat(\n            [th.ones_like(pos_score), th.zeros_like(neg_score)]\n        ).long()\n        loss = F.binary_cross_entropy_with_logits(score, label.float())\n        return loss\n\n\n@utils.benchmark(\"time\", 600)\n@utils.parametrize(\"data\", [\"reddit\"])\n@utils.parametrize(\"num_negs\", [2, 8, 32])\n@utils.parametrize(\"batch_size\", [1024, 2048, 8192])\ndef track_time(data, num_negs, batch_size):\n    data = utils.process_data(data)\n    device = utils.get_bench_device()\n    g = data[0]\n    g.ndata[\"features\"] = g.ndata[\"feat\"]\n    g.ndata[\"labels\"] = g.ndata[\"label\"]\n    in_feats = g.ndata[\"features\"].shape[1]\n    n_classes = data.num_classes\n\n    # Create csr/coo/csc formats before launching training processes with multi-gpu.\n    # This avoids creating certain formats in each sub-process, which saves memory and CPU.\n    g.create_formats_()\n\n    num_epochs = 2\n    num_hidden = 16\n    num_layers = 2\n    fan_out = \"10,25\"\n    lr = 0.003\n    dropout = 0.5\n    num_workers = 4\n    num_negs = 2\n    iter_start = 3\n    iter_count = 10\n\n    n_edges 
= g.num_edges()\n    train_seeds = np.arange(n_edges)\n\n    # Create PyTorch DataLoader for constructing blocks\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [int(fanout) for fanout in fan_out.split(\",\")]\n    )\n    sampler = dgl.dataloading.as_edge_prediction_sampler(\n        sampler,\n        exclude=\"reverse_id\",\n        # For each edge with ID e in Reddit dataset, the reverse edge is e ± |E|/2.\n        reverse_eids=th.cat(\n            [th.arange(n_edges // 2, n_edges), th.arange(0, n_edges // 2)]\n        ),\n        negative_sampler=NegativeSampler(g, num_negs),\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        train_seeds,\n        sampler,\n        batch_size=batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=num_workers,\n    )\n\n    # Define model and optimizer\n    model = SAGE(in_feats, num_hidden, n_classes, num_layers, F.relu, dropout)\n    model = model.to(device)\n    loss_fcn = CrossEntropyLoss()\n    loss_fcn = loss_fcn.to(device)\n    optimizer = optim.Adam(model.parameters(), lr=lr)\n\n    # Training loop\n    avg = 0\n    iter_tput = []\n    for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(\n        dataloader\n    ):\n        # Load the input features as well as output labels\n        batch_inputs = load_subtensor(g, input_nodes, device)\n\n        pos_graph = pos_graph.to(device)\n        neg_graph = neg_graph.to(device)\n        blocks = [block.int().to(device) for block in blocks]\n        # Compute loss and prediction\n        batch_pred = model(blocks, batch_inputs)\n        loss = loss_fcn(batch_pred, pos_graph, neg_graph)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # start timer at before iter_start\n        if step == iter_start - 1:\n            t0 = time.time()\n        elif step == iter_count + iter_start - 1:  # time iter_count iterations\n            break\n\n    t1 = 
time.time()\n\n    return (t1 - t0) / iter_count\n"
  },
  {
    "path": "benchmarks/benchmarks/multigpu/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/benchmarks/multigpu/bench_multigpu_rgcn.py",
    "content": "\"\"\"\nModeling Relational Data with Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1703.06103\nCode: https://github.com/tkipf/relational-gcn\nDifference compared to tkipf/relation-gcn\n* l2norm applied to all weights\n* remove nodes that won't be touched\n\"\"\"\nimport argparse\nimport gc\nimport logging\nimport time\nfrom pathlib import Path\nfrom types import SimpleNamespace\n\nimport dgl\n\nimport numpy as np\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn import RelGraphConv\nfrom torch.multiprocessing import Queue\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.utils.data import DataLoader\n\nfrom .. import utils\n\n\nclass EntityClassify(nn.Module):\n    def __init__(\n        self,\n        device,\n        num_nodes,\n        h_dim,\n        out_dim,\n        num_rels,\n        num_bases=None,\n        num_hidden_layers=1,\n        dropout=0,\n        use_self_loop=False,\n        layer_norm=False,\n    ):\n        super(EntityClassify, self).__init__()\n        self.device = th.device(device if device >= 0 else \"cpu\")\n        self.num_nodes = num_nodes\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.num_rels = num_rels\n        self.num_bases = None if num_bases < 0 else num_bases\n        self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.use_self_loop = use_self_loop\n        self.layer_norm = layer_norm\n\n        self.layers = nn.ModuleList()\n        # i2h\n        self.layers.append(\n            RelGraphConv(\n                self.h_dim,\n                self.h_dim,\n                self.num_rels,\n                \"basis\",\n                self.num_bases,\n                activation=F.relu,\n                self_loop=self.use_self_loop,\n                dropout=self.dropout,\n                layer_norm=layer_norm,\n            )\n        )\n        # h2h\n  
      for idx in range(self.num_hidden_layers):\n            self.layers.append(\n                RelGraphConv(\n                    self.h_dim,\n                    self.h_dim,\n                    self.num_rels,\n                    \"basis\",\n                    self.num_bases,\n                    activation=F.relu,\n                    self_loop=self.use_self_loop,\n                    dropout=self.dropout,\n                    layer_norm=layer_norm,\n                )\n            )\n        # h2o\n        self.layers.append(\n            RelGraphConv(\n                self.h_dim,\n                self.out_dim,\n                self.num_rels,\n                \"basis\",\n                self.num_bases,\n                activation=None,\n                self_loop=self.use_self_loop,\n                layer_norm=layer_norm,\n            )\n        )\n\n    def forward(self, blocks, feats, norm=None):\n        if blocks is None:\n            # full graph training\n            blocks = [self.g] * len(self.layers)\n        h = feats\n        for layer, block in zip(self.layers, blocks):\n            block = block.to(self.device)\n            h = layer(block, h, block.edata[\"etype\"], block.edata[\"norm\"])\n        return h\n\n\ndef gen_norm(g):\n    _, v, eid = g.all_edges(form=\"all\")\n    _, inverse_index, count = th.unique(\n        v, return_inverse=True, return_counts=True\n    )\n    degrees = count[inverse_index]\n    norm = th.ones(eid.shape[0], device=eid.device) / degrees\n    norm = norm.unsqueeze(1)\n    g.edata[\"norm\"] = norm\n\n\nclass NeighborSampler:\n    def __init__(self, g, target_idx, fanouts):\n        self.g = g\n        self.target_idx = target_idx\n        self.fanouts = fanouts\n\n    def sample_blocks(self, seeds):\n        blocks = []\n        etypes = []\n        norms = []\n        ntypes = []\n        seeds = th.tensor(seeds).long()\n        cur = self.target_idx[seeds]\n        for fanout in self.fanouts:\n            if fanout 
is None or fanout == -1:\n                frontier = dgl.in_subgraph(self.g, cur)\n            else:\n                frontier = dgl.sampling.sample_neighbors(self.g, cur, fanout)\n            block = dgl.to_block(frontier, cur)\n            gen_norm(block)\n            cur = block.srcdata[dgl.NID]\n            blocks.insert(0, block)\n        return seeds, blocks\n\n\n@utils.thread_wrapped_func\ndef run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None):\n    from .rgcn_model import RelGraphEmbedLayer\n\n    dev_id = devices[proc_id]\n    (\n        g,\n        node_feats,\n        num_of_ntype,\n        num_classes,\n        num_rels,\n        target_idx,\n        train_idx,\n        val_idx,\n        test_idx,\n        labels,\n    ) = dataset\n    labels = labels.cuda(dev_id)\n    if split is not None:\n        train_seed, val_seed, test_seed = split\n        train_idx = train_idx[train_seed]\n        # val_idx = val_idx[val_seed]\n        # test_idx = test_idx[test_seed]\n\n    fanouts = args.fanout\n    node_tids = g.ndata[dgl.NTYPE]\n    sampler = NeighborSampler(g, target_idx, fanouts)\n    loader = DataLoader(\n        dataset=train_idx.numpy(),\n        batch_size=args.batch_size,\n        collate_fn=sampler.sample_blocks,\n        shuffle=True,\n        num_workers=args.num_workers,\n    )\n\n    world_size = n_gpus\n\n    dist_init_method = \"tcp://{master_ip}:{master_port}\".format(\n        master_ip=\"127.0.0.1\", master_port=\"12345\"\n    )\n    backend = \"nccl\"\n\n    # using sparse embedding or using mix_cpu_gpu model (embedding model can not be stored in GPU)\n    if args.dgl_sparse is False:\n        backend = \"gloo\"\n    print(\"backend using {}\".format(backend))\n    th.distributed.init_process_group(\n        backend=backend,\n        init_method=dist_init_method,\n        world_size=world_size,\n        rank=dev_id,\n    )\n\n    # node features\n    # None for one-hot feature, if not none, it should be the feature 
tensor.\n    #\n    embed_layer = RelGraphEmbedLayer(\n        dev_id,\n        g.num_nodes(),\n        node_tids,\n        num_of_ntype,\n        node_feats,\n        args.n_hidden,\n        dgl_sparse=args.dgl_sparse,\n    )\n\n    # create model\n    # all model params are in device.\n    model = EntityClassify(\n        dev_id,\n        g.num_nodes(),\n        args.n_hidden,\n        num_classes,\n        num_rels,\n        num_bases=args.n_bases,\n        num_hidden_layers=args.n_layers - 2,\n        dropout=args.dropout,\n        use_self_loop=args.use_self_loop,\n        layer_norm=args.layer_norm,\n    )\n\n    model.cuda(dev_id)\n    model = DistributedDataParallel(\n        model, device_ids=[dev_id], output_device=dev_id\n    )\n    if args.dgl_sparse:\n        embed_layer.cuda(dev_id)\n        if len(list(embed_layer.parameters())) > 0:\n            embed_layer = DistributedDataParallel(\n                embed_layer, device_ids=[dev_id], output_device=dev_id\n            )\n    else:\n        if len(list(embed_layer.parameters())) > 0:\n            embed_layer = DistributedDataParallel(\n                embed_layer, device_ids=None, output_device=None\n            )\n\n    # optimizer\n    dense_params = list(model.parameters())\n    if args.node_feats:\n        if n_gpus > 1:\n            dense_params += list(embed_layer.module.embeds.parameters())\n        else:\n            dense_params += list(embed_layer.embeds.parameters())\n    optimizer = th.optim.Adam(\n        dense_params, lr=args.lr, weight_decay=args.l2norm\n    )\n\n    if args.dgl_sparse:\n        all_params = list(model.parameters()) + list(embed_layer.parameters())\n        optimizer = th.optim.Adam(\n            all_params, lr=args.lr, weight_decay=args.l2norm\n        )\n        if n_gpus > 1 and isinstance(embed_layer, DistributedDataParallel):\n            dgl_emb = embed_layer.module.dgl_emb\n        else:\n            dgl_emb = embed_layer.dgl_emb\n        emb_optimizer = (\n      
      dgl.optim.SparseAdam(params=dgl_emb, lr=args.sparse_lr, eps=1e-8)\n            if len(dgl_emb) > 0\n            else None\n        )\n    else:\n        if n_gpus > 1:\n            embs = list(embed_layer.module.node_embeds.parameters())\n        else:\n            embs = list(embed_layer.node_embeds.parameters())\n        emb_optimizer = (\n            th.optim.SparseAdam(embs, lr=args.sparse_lr)\n            if len(embs) > 0\n            else None\n        )\n\n    # training loop\n    print(\"start training...\")\n    forward_time = []\n    backward_time = []\n\n    train_time = 0\n    validation_time = 0\n    test_time = 0\n    last_val_acc = 0.0\n    do_test = False\n    if n_gpus > 1 and n_cpus - args.num_workers > 0:\n        th.set_num_threads(n_cpus - args.num_workers)\n    steps = 0\n    time_records = []\n    model.train()\n    embed_layer.train()\n\n    # Warm up\n    for i, sample_data in enumerate(loader):\n        seeds, blocks = sample_data\n        t0 = time.time()\n        feats = embed_layer(\n            blocks[0].srcdata[dgl.NID],\n            blocks[0].srcdata[\"ntype\"],\n            blocks[0].srcdata[\"type_id\"],\n            node_feats,\n        )\n        logits = model(blocks, feats)\n        loss = F.cross_entropy(logits, labels[seeds])\n        t1 = time.time()\n        optimizer.zero_grad()\n        if emb_optimizer is not None:\n            emb_optimizer.zero_grad()\n\n        loss.backward()\n        if emb_optimizer is not None:\n            emb_optimizer.step()\n        optimizer.step()\n        gc.collect()\n        if i >= 3:\n            break\n\n    # real time\n    for i, sample_data in enumerate(loader):\n        seeds, blocks = sample_data\n        t0 = time.time()\n        feats = embed_layer(\n            blocks[0].srcdata[dgl.NID],\n            blocks[0].srcdata[\"ntype\"],\n            blocks[0].srcdata[\"type_id\"],\n            node_feats,\n        )\n        logits = model(blocks, feats)\n        loss = 
F.cross_entropy(logits, labels[seeds])\n        t1 = time.time()\n        optimizer.zero_grad()\n        if emb_optimizer is not None:\n            emb_optimizer.zero_grad()\n\n        loss.backward()\n        if emb_optimizer is not None:\n            emb_optimizer.step()\n        optimizer.step()\n        th.distributed.barrier()\n        t2 = time.time()\n\n        forward_time.append(t1 - t0)\n        backward_time.append(t2 - t1)\n        time_records.append(t2 - t0)\n\n        gc.collect()\n        if i >= 10:\n            break\n\n    if proc_id == 0:\n        queue.put(np.array(time_records))\n\n\n@utils.skip_if_not_4gpu()\n@utils.benchmark(\"time\", timeout=600)\n@utils.parametrize(\"data\", [\"am\", \"ogbn-mag\"])\n@utils.parametrize(\"dgl_sparse\", [True, False])\ndef track_time(data, dgl_sparse):\n    # load graph data\n    dataset = utils.process_data(data)\n    args = config()\n    devices = [0, 1, 2, 3]\n    args.dgl_sparse = dgl_sparse\n    args.dataset = dataset\n    ogb_dataset = False\n\n    if data == \"am\":\n        args.n_bases = 40\n        args.l2norm = 5e-4\n    elif data == \"ogbn-mag\":\n        args.n_bases = 2\n        args.l2norm = 0\n    else:\n        raise ValueError()\n\n    if ogb_dataset is True:\n        split_idx = dataset.get_idx_split()\n        train_idx = split_idx[\"train\"][\"paper\"]\n        val_idx = split_idx[\"valid\"][\"paper\"]\n        test_idx = split_idx[\"test\"][\"paper\"]\n        hg_orig, labels = dataset[0]\n        subgs = {}\n        for etype in hg_orig.canonical_etypes:\n            u, v = hg_orig.all_edges(etype=etype)\n            subgs[etype] = (u, v)\n            subgs[(etype[2], \"rev-\" + etype[1], etype[0])] = (v, u)\n        hg = dgl.heterograph(subgs)\n        hg.nodes[\"paper\"].data[\"feat\"] = hg_orig.nodes[\"paper\"].data[\"feat\"]\n        labels = labels[\"paper\"].squeeze()\n\n        num_rels = len(hg.canonical_etypes)\n        num_of_ntype = len(hg.ntypes)\n        num_classes = 
dataset.num_classes\n        if args.dataset == \"ogbn-mag\":\n            category = \"paper\"\n        print(\"Number of relations: {}\".format(num_rels))\n        print(\"Number of class: {}\".format(num_classes))\n        print(\"Number of train: {}\".format(len(train_idx)))\n        print(\"Number of valid: {}\".format(len(val_idx)))\n        print(\"Number of test: {}\".format(len(test_idx)))\n\n    else:\n        # Load from hetero-graph\n        hg = dataset[0]\n\n        num_rels = len(hg.canonical_etypes)\n        num_of_ntype = len(hg.ntypes)\n        category = dataset.predict_category\n        num_classes = dataset.num_classes\n        train_mask = hg.nodes[category].data.pop(\"train_mask\")\n        test_mask = hg.nodes[category].data.pop(\"test_mask\")\n        labels = hg.nodes[category].data.pop(\"labels\")\n        train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()\n        test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()\n\n        # AIFB, MUTAG, BGS and AM datasets do not provide validation set split.\n        # Split train set into train and validation if args.validation is set\n        # otherwise use train set as the validation set.\n        if args.validation:\n            val_idx = train_idx[: len(train_idx) // 5]\n            train_idx = train_idx[len(train_idx) // 5 :]\n        else:\n            val_idx = train_idx\n\n    node_feats = []\n    for ntype in hg.ntypes:\n        if len(hg.nodes[ntype].data) == 0 or args.node_feats is False:\n            node_feats.append(hg.num_nodes(ntype))\n        else:\n            assert len(hg.nodes[ntype].data) == 1\n            feat = hg.nodes[ntype].data.pop(\"feat\")\n            node_feats.append(feat.share_memory_())\n\n    # get target category id\n    category_id = len(hg.ntypes)\n    for i, ntype in enumerate(hg.ntypes):\n        if ntype == category:\n            category_id = i\n        print(\"{}:{}\".format(i, ntype))\n\n    g = dgl.to_homogeneous(hg)\n    
g.ndata[\"ntype\"] = g.ndata[dgl.NTYPE]\n    g.ndata[\"ntype\"].share_memory_()\n    g.edata[\"etype\"] = g.edata[dgl.ETYPE]\n    g.edata[\"etype\"].share_memory_()\n    g.ndata[\"type_id\"] = g.ndata[dgl.NID]\n    g.ndata[\"type_id\"].share_memory_()\n    node_ids = th.arange(g.num_nodes())\n\n    # find out the target node ids\n    node_tids = g.ndata[dgl.NTYPE]\n    loc = node_tids == category_id\n    target_idx = node_ids[loc]\n    target_idx.share_memory_()\n    train_idx.share_memory_()\n    val_idx.share_memory_()\n    test_idx.share_memory_()\n    # Create csr/coo/csc formats before launching training processes with multi-gpu.\n    # This avoids creating certain formats in each sub-process, which saves memory and CPU.\n    g.create_formats_()\n\n    n_gpus = len(devices)\n    n_cpus = mp.cpu_count()\n\n    ctx = mp.get_context(\"spawn\")\n    queue = ctx.Queue()\n    procs = []\n    num_train_seeds = train_idx.shape[0]\n    num_valid_seeds = val_idx.shape[0]\n    num_test_seeds = test_idx.shape[0]\n    train_seeds = th.randperm(num_train_seeds)\n    valid_seeds = th.randperm(num_valid_seeds)\n    test_seeds = th.randperm(num_test_seeds)\n    tseeds_per_proc = num_train_seeds // n_gpus\n    vseeds_per_proc = num_valid_seeds // n_gpus\n    tstseeds_per_proc = num_test_seeds // n_gpus\n\n    for proc_id in range(n_gpus):\n        # we have multi-gpu for training, evaluation and testing\n        # so split train set, valid set and test set into num-of-gpu parts.\n        proc_train_seeds = train_seeds[\n            proc_id * tseeds_per_proc : (proc_id + 1) * tseeds_per_proc\n            if (proc_id + 1) * tseeds_per_proc < num_train_seeds\n            else num_train_seeds\n        ]\n        proc_valid_seeds = valid_seeds[\n            proc_id * vseeds_per_proc : (proc_id + 1) * vseeds_per_proc\n            if (proc_id + 1) * vseeds_per_proc < num_valid_seeds\n            else num_valid_seeds\n        ]\n        proc_test_seeds = test_seeds[\n            
proc_id * tstseeds_per_proc : (proc_id + 1) * tstseeds_per_proc\n            if (proc_id + 1) * tstseeds_per_proc < num_test_seeds\n            else num_test_seeds\n        ]\n\n        p = ctx.Process(\n            target=run,\n            args=(\n                proc_id,\n                n_gpus,\n                n_cpus // n_gpus,\n                args,\n                devices,\n                (\n                    g,\n                    node_feats,\n                    num_of_ntype,\n                    num_classes,\n                    num_rels,\n                    target_idx,\n                    train_idx,\n                    val_idx,\n                    test_idx,\n                    labels,\n                ),\n                (proc_train_seeds, proc_valid_seeds, proc_test_seeds),\n                queue,\n            ),\n        )\n        p.start()\n\n        procs.append(p)\n    for p in procs:\n        p.join()\n    time_records = queue.get(block=False)\n    num_exclude = 10  # exclude first 10 iterations\n    if len(time_records) < 15:\n        # exclude less if less records\n        num_exclude = int(len(time_records) * 0.3)\n    return np.mean(time_records[num_exclude:])\n\n\ndef config():\n    # parser = argparse.ArgumentParser(description='RGCN')\n    args = SimpleNamespace(\n        dropout=0,\n        n_hidden=16,\n        gpu=\"0,1,2,3\",\n        lr=1e-2,\n        sparse_lr=2e-2,\n        n_bases=-1,\n        n_layers=2,\n        dataset=None,\n        l2norm=0,\n        fanout=[10, 25],\n        use_self_loop=True,\n        batch_size=100,\n        layer_norm=False,\n        validation=False,\n        node_feats=False,\n        num_workers=0,\n        dgl_sparse=False,\n    )\n\n    return args\n\n\nif __name__ == \"__main__\":\n    track_time(\"am\")\n"
  },
  {
    "path": "benchmarks/benchmarks/multigpu/bench_multigpu_sage.py",
    "content": "import argparse\nimport math\nimport time\nfrom types import SimpleNamespace\nfrom typing import NamedTuple\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\nimport numpy as np\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.nn.parallel import DistributedDataParallel\n\nfrom .. import utils\n\n\nclass SAGE(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        for i in range(1, n_layers - 1):\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n\ndef load_subtensor(nfeat, labels, seeds, input_nodes, dev_id):\n    \"\"\"\n    Extracts features and labels for a subset of nodes.\n    \"\"\"\n    batch_inputs = nfeat[input_nodes].to(dev_id)\n    batch_labels = labels[seeds].to(dev_id)\n    return batch_inputs, batch_labels\n\n\n# Entry point\n@utils.thread_wrapped_func\ndef run(result_queue, proc_id, n_gpus, args, devices, data):\n    dev_id = devices[proc_id]\n    timing_records = []\n    if n_gpus > 1:\n        dist_init_method = \"tcp://{master_ip}:{master_port}\".format(\n            master_ip=\"127.0.0.1\", master_port=\"12345\"\n        )\n        world_size = n_gpus\n  
      th.distributed.init_process_group(\n            backend=\"nccl\",\n            init_method=dist_init_method,\n            world_size=world_size,\n            rank=proc_id,\n        )\n    th.cuda.set_device(dev_id)\n\n    n_classes, train_g, _, _ = data\n\n    train_nfeat = train_g.ndata.pop(\"feat\")\n    train_labels = train_g.ndata.pop(\"label\")\n\n    train_nfeat = train_nfeat.to(dev_id)\n    train_labels = train_labels.to(dev_id)\n\n    in_feats = train_nfeat.shape[1]\n\n    train_mask = train_g.ndata[\"train_mask\"]\n    train_nid = train_mask.nonzero().squeeze()\n\n    # Split train_nid\n    train_nid = th.split(train_nid, math.ceil(len(train_nid) / n_gpus))[proc_id]\n\n    # Create PyTorch DataLoader for constructing blocks\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [int(fanout) for fanout in args.fan_out.split(\",\")]\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        train_g,\n        train_nid,\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\n\n    # Define model and optimizer\n    model = SAGE(\n        in_feats,\n        args.num_hidden,\n        n_classes,\n        args.num_layers,\n        F.relu,\n        args.dropout,\n    )\n    model = model.to(dev_id)\n    if n_gpus > 1:\n        model = DistributedDataParallel(\n            model, device_ids=[dev_id], output_device=dev_id\n        )\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n    # Training loop\n    for step, (input_nodes, seeds, blocks) in enumerate(dataloader):\n        if proc_id == 0:\n            tic_step = time.time()\n\n        batch_inputs, batch_labels = load_subtensor(\n            train_nfeat, train_labels, seeds, input_nodes, dev_id\n        )\n        blocks = [block.int().to(dev_id) for block in blocks]\n        batch_pred = model(blocks, batch_inputs)\n        loss = 
loss_fcn(batch_pred, batch_labels)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if proc_id == 0:\n            timing_records.append(time.time() - tic_step)\n\n        if step >= 50:\n            break\n\n    if n_gpus > 1:\n        th.distributed.barrier()\n    if proc_id == 0:\n        result_queue.put(np.array(timing_records))\n\n\n@utils.benchmark(\"time\", timeout=600)\n@utils.skip_if_not_4gpu()\n@utils.parametrize(\"data\", [\"reddit\", \"ogbn-products\"])\ndef track_time(data):\n    args = SimpleNamespace(\n        num_hidden=16,\n        fan_out=\"10,25\",\n        batch_size=1000,\n        lr=0.003,\n        dropout=0.5,\n        num_layers=2,\n        num_workers=4,\n    )\n\n    devices = [0, 1, 2, 3]\n    n_gpus = len(devices)\n    data = utils.process_data(data)\n    g = data[0]\n    n_classes = data.num_classes\n    train_g = val_g = test_g = g\n\n    # Create csr/coo/csc formats before launching training processes with multi-gpu.\n    # This avoids creating certain formats in each sub-process, which saves memory and CPU.\n    train_g.create_formats_()\n    val_g.create_formats_()\n    test_g.create_formats_()\n    # Pack data\n    data = n_classes, train_g, val_g, test_g\n\n    ctx = mp.get_context(\"spawn\")\n    result_queue = ctx.Queue()\n    procs = []\n    for proc_id in range(n_gpus):\n        p = ctx.Process(\n            target=run,\n            args=(result_queue, proc_id, n_gpus, args, devices, data),\n        )\n        p.start()\n        procs.append(p)\n    for p in procs:\n        p.join()\n    time_records = result_queue.get(block=False)\n    num_exclude = 10  # exclude first 10 iterations\n    if len(time_records) < 15:\n        # exclude less if less records\n        num_exclude = int(len(time_records) * 0.3)\n    return np.mean(time_records[num_exclude:])\n"
  },
  {
    "path": "benchmarks/benchmarks/multigpu/rgcn_model.py",
    "content": "import dgl\nimport torch as th\nimport torch.nn as nn\n\n\nclass BaseRGCN(nn.Module):\n    def __init__(\n        self,\n        num_nodes,\n        h_dim,\n        out_dim,\n        num_rels,\n        num_bases,\n        num_hidden_layers=1,\n        dropout=0,\n        use_self_loop=False,\n        use_cuda=False,\n    ):\n        super(BaseRGCN, self).__init__()\n        self.num_nodes = num_nodes\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.num_rels = num_rels\n        self.num_bases = None if num_bases < 0 else num_bases\n        self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.use_self_loop = use_self_loop\n        self.use_cuda = use_cuda\n\n        # create rgcn layers\n        self.build_model()\n\n    def build_model(self):\n        self.layers = nn.ModuleList()\n        # i2h\n        i2h = self.build_input_layer()\n        if i2h is not None:\n            self.layers.append(i2h)\n        # h2h\n        for idx in range(self.num_hidden_layers):\n            h2h = self.build_hidden_layer(idx)\n            self.layers.append(h2h)\n        # h2o\n        h2o = self.build_output_layer()\n        if h2o is not None:\n            self.layers.append(h2o)\n\n    def build_input_layer(self):\n        return None\n\n    def build_hidden_layer(self, idx):\n        raise NotImplementedError\n\n    def build_output_layer(self):\n        return None\n\n    def forward(self, g, h, r, norm):\n        for layer in self.layers:\n            h = layer(g, h, r, norm)\n        return h\n\n\ndef initializer(emb):\n    emb.uniform_(-1.0, 1.0)\n    return emb\n\n\nclass RelGraphEmbedLayer(nn.Module):\n    r\"\"\"Embedding layer for featureless heterograph.\n    Parameters\n    ----------\n    dev_id : int\n        Device to run the layer.\n    num_nodes : int\n        Number of nodes.\n    node_tides : tensor\n        Storing the node type id for each node starting from 0\n    num_of_ntype : 
int\n        Number of node types\n    input_size : list of int\n        A list of input feature size for each node type. If None, we then\n        treat certain input feature as an one-hot encoding feature.\n    embed_size : int\n        Output embed size\n    dgl_sparse : bool, optional\n        If true, use dgl.nn.NodeEmbedding otherwise use torch.nn.Embedding\n    \"\"\"\n\n    def __init__(\n        self,\n        dev_id,\n        num_nodes,\n        node_tids,\n        num_of_ntype,\n        input_size,\n        embed_size,\n        dgl_sparse=False,\n    ):\n        super(RelGraphEmbedLayer, self).__init__()\n        self.dev_id = th.device(dev_id if dev_id >= 0 else \"cpu\")\n        self.embed_size = embed_size\n        self.num_nodes = num_nodes\n        self.dgl_sparse = dgl_sparse\n\n        # create weight embeddings for each node for each relation\n        self.embeds = nn.ParameterDict()\n        self.node_embeds = {} if dgl_sparse else nn.ModuleDict()\n        self.num_of_ntype = num_of_ntype\n\n        for ntype in range(num_of_ntype):\n            if isinstance(input_size[ntype], int):\n                if dgl_sparse:\n                    self.node_embeds[str(ntype)] = dgl.nn.NodeEmbedding(\n                        input_size[ntype],\n                        embed_size,\n                        name=str(ntype),\n                        init_func=initializer,\n                    )\n                else:\n                    sparse_emb = th.nn.Embedding(\n                        input_size[ntype], embed_size, sparse=True\n                    )\n                    nn.init.uniform_(sparse_emb.weight, -1.0, 1.0)\n                    self.node_embeds[str(ntype)] = sparse_emb\n            else:\n                input_emb_size = input_size[ntype].shape[1]\n                embed = nn.Parameter(th.Tensor(input_emb_size, self.embed_size))\n                nn.init.xavier_uniform_(embed)\n                self.embeds[str(ntype)] = embed\n\n    @property\n    
def dgl_emb(self):\n        \"\"\" \"\"\"\n        if self.dgl_sparse:\n            embs = [emb for emb in self.node_embeds.values()]\n            return embs\n        else:\n            return []\n\n    def forward(self, node_ids, node_tids, type_ids, features):\n        \"\"\"Forward computation\n        Parameters\n        ----------\n        node_ids : tensor\n            node ids to generate embedding for.\n        node_tids : tensor\n            node type ids\n        features : list of features\n            list of initial features for nodes belonging to different node types.\n            If None, the corresponding feature is a one-hot encoding feature,\n            else use the features directly as input feature and matmul a\n            projection matrix.\n        Returns\n        -------\n        tensor\n            embeddings as the input of the next layer\n        \"\"\"\n        tsd_ids = node_ids.to(self.dev_id)\n        embeds = th.empty(\n            node_ids.shape[0], self.embed_size, device=self.dev_id\n        )\n        for ntype in range(self.num_of_ntype):\n            loc = node_tids == ntype\n            if isinstance(features[ntype], int):\n                if self.dgl_sparse:\n                    embeds[loc] = self.node_embeds[str(ntype)](\n                        type_ids[loc], self.dev_id\n                    )\n                else:\n                    embeds[loc] = self.node_embeds[str(ntype)](\n                        type_ids[loc]\n                    ).to(self.dev_id)\n            else:\n                embeds[loc] = features[ntype][type_ids[loc]].to(\n                    self.dev_id\n                ) @ self.embeds[str(ntype)].to(self.dev_id)\n\n        return embeds\n"
  },
  {
    "path": "benchmarks/benchmarks/rgcn.py",
    "content": "import dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import RelGraphConv\n\nfrom . import utils\n\n\nclass RGCN(nn.Module):\n    def __init__(\n        self,\n        num_nodes,\n        h_dim,\n        out_dim,\n        num_rels,\n        regularizer=\"basis\",\n        num_bases=-1,\n        dropout=0.0,\n        self_loop=False,\n        ns_mode=False,\n    ):\n        super(RGCN, self).__init__()\n\n        if num_bases == -1:\n            num_bases = num_rels\n        self.emb = nn.Embedding(num_nodes, h_dim)\n        self.conv1 = RelGraphConv(\n            h_dim, h_dim, num_rels, regularizer, num_bases, self_loop=self_loop\n        )\n        self.conv2 = RelGraphConv(\n            h_dim,\n            out_dim,\n            num_rels,\n            regularizer,\n            num_bases,\n            self_loop=self_loop,\n        )\n        self.dropout = nn.Dropout(dropout)\n        self.ns_mode = ns_mode\n\n    def forward(self, g, nids=None):\n        if self.ns_mode:\n            # forward for neighbor sampling\n            x = self.emb(g[0].srcdata[dgl.NID])\n            h = self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata[\"norm\"])\n            h = self.dropout(F.relu(h))\n            h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata[\"norm\"])\n            return h\n        else:\n            x = self.emb.weight if nids is None else self.emb(nids)\n            h = self.conv1(g, x, g.edata[dgl.ETYPE], g.edata[\"norm\"])\n            h = self.dropout(F.relu(h))\n            h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata[\"norm\"])\n            return h\n\n\ndef load_data(data_name, get_norm=False, inv_target=False):\n    dataset = utils.process_data(data_name)\n\n    # Load hetero-graph\n    hg = dataset[0]\n\n    num_rels = len(hg.canonical_etypes)\n    category = dataset.predict_category\n    num_classes = dataset.num_classes\n    labels = 
hg.nodes[category].data.pop(\"labels\")\n    train_mask = hg.nodes[category].data.pop(\"train_mask\")\n    test_mask = hg.nodes[category].data.pop(\"test_mask\")\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()\n    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()\n\n    if get_norm:\n        # Calculate normalization weight for each edge,\n        # 1. / d, d is the degree of the destination node\n        for cetype in hg.canonical_etypes:\n            hg.edges[cetype].data[\"norm\"] = dgl.norm_by_dst(\n                hg, cetype\n            ).unsqueeze(1)\n        edata = [\"norm\"]\n    else:\n        edata = None\n\n    # get target category id\n    category_id = hg.ntypes.index(category)\n\n    g = dgl.to_homogeneous(hg, edata=edata)\n    # Rename the fields as they can be changed by for example DataLoader\n    g.ndata[\"ntype\"] = g.ndata.pop(dgl.NTYPE)\n    g.ndata[\"type_id\"] = g.ndata.pop(dgl.NID)\n    node_ids = torch.arange(g.num_nodes())\n\n    # find out the target node ids in g\n    loc = g.ndata[\"ntype\"] == category_id\n    target_idx = node_ids[loc]\n\n    if inv_target:\n        # Map global node IDs to type-specific node IDs. This is required for\n        # looking up type-specific labels in a minibatch\n        inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64)\n        inv_target[target_idx] = torch.arange(\n            0, target_idx.shape[0], dtype=inv_target.dtype\n        )\n        return (\n            g,\n            num_rels,\n            num_classes,\n            labels,\n            train_idx,\n            test_idx,\n            target_idx,\n            inv_target,\n        )\n    else:\n        return g, num_rels, num_classes, labels, train_idx, test_idx, target_idx\n"
  },
  {
    "path": "benchmarks/benchmarks/utils.py",
    "content": "import inspect\nimport json\nimport os\nimport pickle\nimport shutil\nimport time\nimport zipfile\nfrom functools import partial, reduce, wraps\nfrom timeit import default_timer\n\nimport dgl\n\nimport numpy as np\nimport pandas\nimport requests\nimport torch\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\ndef _download(url, path, filename):\n    fn = os.path.join(path, filename)\n    if os.path.exists(fn):\n        return\n\n    os.makedirs(path, exist_ok=True)\n    f_remote = requests.get(url, stream=True)\n    sz = f_remote.headers.get(\"content-length\")\n    assert f_remote.status_code == 200, \"fail to open {}\".format(url)\n    with open(fn, \"wb\") as writer:\n        for chunk in f_remote.iter_content(chunk_size=1024 * 1024):\n            writer.write(chunk)\n    print(\"Download finished.\")\n\n\nimport traceback\nfrom _thread import start_new_thread\n\n# GRAPH_CACHE = {}\nimport torch.multiprocessing as mp\n\n\ndef thread_wrapped_func(func):\n    \"\"\"\n    Wraps a process entry point to make it work with OpenMP.\n    \"\"\"\n\n    @wraps(func)\n    def decorated_function(*args, **kwargs):\n        queue = mp.Queue()\n\n        def _queue_result():\n            exception, trace, res = None, None, None\n            try:\n                res = func(*args, **kwargs)\n            except Exception as e:\n                exception = e\n                trace = traceback.format_exc()\n            queue.put((res, exception, trace))\n\n        start_new_thread(_queue_result, ())\n        result, exception, trace = queue.get()\n        if exception is None:\n            return result\n        else:\n            assert isinstance(exception, Exception)\n            raise exception.__class__(trace)\n\n    return decorated_function\n\n\ndef get_graph(name, format=None):\n    # global GRAPH_CACHE\n    # if name in GRAPH_CACHE:\n    #     return GRAPH_CACHE[name].to(format)\n    if isinstance(format, str):\n        format = [format]  # didn't 
specify format\n    if format is None:\n        format = [\"csc\", \"csr\", \"coo\"]\n    g = None\n    if name == \"cora\":\n        g = dgl.data.CoraGraphDataset(verbose=False)[0]\n    elif name == \"pubmed\":\n        g = dgl.data.PubmedGraphDataset(verbose=False)[0]\n    elif name == \"livejournal\":\n        bin_path = \"/tmp/dataset/livejournal/livejournal_{}.bin\".format(format)\n        if os.path.exists(bin_path):\n            g_list, _ = dgl.load_graphs(bin_path)\n            g = g_list[0]\n        else:\n            g = get_livejournal().formats(format)\n            dgl.save_graphs(bin_path, [g])\n    elif name == \"friendster\":\n        bin_path = \"/tmp/dataset/friendster/friendster_{}.bin\".format(format)\n        if os.path.exists(bin_path):\n            g_list, _ = dgl.load_graphs(bin_path)\n            g = g_list[0]\n        else:\n            # the original node IDs of friendster are not consecutive, so we compact it\n            g = dgl.compact_graphs(get_friendster()).formats(format)\n            dgl.save_graphs(bin_path, [g])\n    elif name == \"reddit\":\n        bin_path = \"/tmp/dataset/reddit/reddit_{}.bin\".format(format)\n        if os.path.exists(bin_path):\n            g_list, _ = dgl.load_graphs(bin_path)\n            g = g_list[0]\n        else:\n            g = dgl.data.RedditDataset(self_loop=True)[0].formats(format)\n            dgl.save_graphs(bin_path, [g])\n    elif name.startswith(\"ogb\"):\n        g = get_ogb_graph(name)\n    else:\n        raise Exception(\"Unknown dataset\")\n    # GRAPH_CACHE[name] = g\n    g = g.formats(format)\n    return g\n\n\ndef get_ogb_graph(name):\n    os.symlink(\"/tmp/dataset/\", os.path.join(os.getcwd(), \"dataset\"))\n    data = DglNodePropPredDataset(name=name)\n    return data[0][0]\n\n\ndef get_livejournal():\n    # Same as https://snap.stanford.edu/data/soc-LiveJournal1.txt.gz\n    _download(\n        
\"https://dgl-asv-data.s3-us-west-2.amazonaws.com/dataset/livejournal/soc-LiveJournal1.txt.gz\",\n        \"/tmp/dataset/livejournal\",\n        \"soc-LiveJournal1.txt.gz\",\n    )\n    df = pandas.read_csv(\n        \"/tmp/dataset/livejournal/soc-LiveJournal1.txt.gz\",\n        sep=\"\\t\",\n        skiprows=4,\n        header=None,\n        names=[\"src\", \"dst\"],\n        compression=\"gzip\",\n    )\n    src = df[\"src\"].values\n    dst = df[\"dst\"].values\n    print(\"construct the graph\")\n    return dgl.graph((src, dst))\n\n\ndef get_friendster():\n    # Same as https://snap.stanford.edu/data/bigdata/communities/com-friendster.ungraph.txt.gz\n    _download(\n        \"https://dgl-asv-data.s3-us-west-2.amazonaws.com/dataset/friendster/com-friendster.ungraph.txt.gz\",\n        \"/tmp/dataset/friendster\",\n        \"com-friendster.ungraph.txt.gz\",\n    )\n    df = pandas.read_csv(\n        \"/tmp/dataset/friendster/com-friendster.ungraph.txt.gz\",\n        sep=\"\\t\",\n        skiprows=4,\n        header=None,\n        names=[\"src\", \"dst\"],\n        compression=\"gzip\",\n    )\n    src = df[\"src\"].values\n    dst = df[\"dst\"].values\n    print(\"construct the graph\")\n    return dgl.graph((src, dst))\n\n\nclass OGBDataset(object):\n    def __init__(self, g, num_labels, predict_category=None):\n        self._g = g\n        self._num_labels = num_labels\n        self._predict_category = predict_category\n\n    @property\n    def num_labels(self):\n        return self._num_labels\n\n    @property\n    def num_classes(self):\n        return self._num_labels\n\n    @property\n    def predict_category(self):\n        return self._predict_category\n\n    def __getitem__(self, idx):\n        return self._g\n\n\ndef load_ogb_product():\n    name = \"ogbn-products\"\n    os.symlink(\"/tmp/dataset/\", os.path.join(os.getcwd(), \"dataset\"))\n\n    print(\"load\", name)\n    data = DglNodePropPredDataset(name=name)\n    print(\"finish loading\", name)\n    
splitted_idx = data.get_idx_split()\n    graph, labels = data[0]\n    labels = labels[:, 0]\n\n    graph.ndata[\"label\"] = labels\n    in_feats = graph.ndata[\"feat\"].shape[1]\n    num_labels = len(\n        torch.unique(labels[torch.logical_not(torch.isnan(labels))])\n    )\n\n    # Find the node IDs in the training, validation, and test set.\n    train_nid, val_nid, test_nid = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    train_mask = torch.zeros((graph.num_nodes(),), dtype=torch.bool)\n    train_mask[train_nid] = True\n    val_mask = torch.zeros((graph.num_nodes(),), dtype=torch.bool)\n    val_mask[val_nid] = True\n    test_mask = torch.zeros((graph.num_nodes(),), dtype=torch.bool)\n    test_mask[test_nid] = True\n    graph.ndata[\"train_mask\"] = train_mask\n    graph.ndata[\"val_mask\"] = val_mask\n    graph.ndata[\"test_mask\"] = test_mask\n\n    return OGBDataset(graph, num_labels)\n\n\ndef load_ogb_mag():\n    name = \"ogbn-mag\"\n    os.symlink(\"/tmp/dataset/\", os.path.join(os.getcwd(), \"dataset\"))\n\n    print(\"load\", name)\n    dataset = DglNodePropPredDataset(name=name)\n    print(\"finish loading\", name)\n    split_idx = dataset.get_idx_split()\n    train_idx = split_idx[\"train\"][\"paper\"]\n    val_idx = split_idx[\"valid\"][\"paper\"]\n    test_idx = split_idx[\"test\"][\"paper\"]\n    hg_orig, labels = dataset[0]\n    subgs = {}\n    for etype in hg_orig.canonical_etypes:\n        u, v = hg_orig.all_edges(etype=etype)\n        subgs[etype] = (u, v)\n        subgs[(etype[2], \"rev-\" + etype[1], etype[0])] = (v, u)\n    hg = dgl.heterograph(subgs)\n    hg.nodes[\"paper\"].data[\"feat\"] = hg_orig.nodes[\"paper\"].data[\"feat\"]\n    hg.nodes[\"paper\"].data[\"labels\"] = labels[\"paper\"].squeeze()\n    train_mask = torch.zeros((hg.num_nodes(\"paper\"),), dtype=torch.bool)\n    train_mask[train_idx] = True\n    val_mask = torch.zeros((hg.num_nodes(\"paper\"),), 
dtype=torch.bool)\n    val_mask[val_idx] = True\n    test_mask = torch.zeros((hg.num_nodes(\"paper\"),), dtype=torch.bool)\n    test_mask[test_idx] = True\n    hg.nodes[\"paper\"].data[\"train_mask\"] = train_mask\n    hg.nodes[\"paper\"].data[\"val_mask\"] = val_mask\n    hg.nodes[\"paper\"].data[\"test_mask\"] = test_mask\n\n    num_classes = dataset.num_classes\n    return OGBDataset(hg, num_classes, \"paper\")\n\n\nclass PinsageDataset:\n    def __init__(self, g, user_ntype, item_ntype, textset):\n        self._g = g\n        self._user_ntype = user_ntype\n        self._item_ntype = item_ntype\n        self._textset = textset\n\n    @property\n    def user_ntype(self):\n        return self._user_ntype\n\n    @property\n    def item_ntype(self):\n        return self._item_ntype\n\n    @property\n    def textset(self):\n        return self._textset\n\n    def __getitem__(self, idx):\n        return self._g\n\n\ndef load_nowplaying_rs():\n    import torchtext.legacy as torchtext\n\n    # follow examples/pytorch/pinsage/README to create train_g.bin\n    name = \"train_g.bin\"\n    dataset_dir = os.path.join(os.getcwd(), \"dataset\")\n    os.symlink(\"/tmp/dataset/\", dataset_dir)\n\n    dataset_path = os.path.join(dataset_dir, \"nowplaying_rs\", name)\n    g_list, _ = dgl.load_graphs(dataset_path)\n    g = g_list[0]\n    user_ntype = \"user\"\n    item_ntype = \"track\"\n\n    # Assign user and movie IDs and use them as features (to learn an individual trainable\n    # embedding for each entity)\n    g.nodes[user_ntype].data[\"id\"] = torch.arange(g.num_nodes(user_ntype))\n    g.nodes[item_ntype].data[\"id\"] = torch.arange(g.num_nodes(item_ntype))\n\n    # Prepare torchtext dataset and vocabulary\n    fields = {}\n    examples = []\n    for i in range(g.num_nodes(item_ntype)):\n        example = torchtext.data.Example.fromlist([], [])\n        examples.append(example)\n    textset = torchtext.data.Dataset(examples, fields)\n\n    return PinsageDataset(g, 
user_ntype, item_ntype, textset)\n\n\ndef process_data(name):\n    if name == \"cora\":\n        return dgl.data.CoraGraphDataset()\n    elif name == \"pubmed\":\n        return dgl.data.PubmedGraphDataset()\n    elif name == \"aifb\":\n        return dgl.data.AIFBDataset()\n    elif name == \"mutag\":\n        return dgl.data.MUTAGDataset()\n    elif name == \"bgs\":\n        return dgl.data.BGSDataset()\n    elif name == \"am\":\n        return dgl.data.AMDataset()\n    elif name == \"reddit\":\n        return dgl.data.RedditDataset(self_loop=True)\n    elif name == \"ogbn-products\":\n        return load_ogb_product()\n    elif name == \"ogbn-mag\":\n        return load_ogb_mag()\n    elif name == \"nowplaying_rs\":\n        return load_nowplaying_rs()\n    else:\n        raise ValueError(\"Invalid dataset name:\", name)\n\n\ndef get_bench_device():\n    device = os.environ.get(\"DGL_BENCH_DEVICE\", \"cpu\")\n    if device.lower() == \"gpu\":\n        return \"cuda:0\"\n    else:\n        return device\n\n\ndef setup_track_time(*args, **kwargs):\n    # fix random seed\n    np.random.seed(42)\n    torch.random.manual_seed(42)\n\n\ndef setup_track_acc(*args, **kwargs):\n    # fix random seed\n    np.random.seed(42)\n    torch.random.manual_seed(42)\n\n\ndef setup_track_flops(*args, **kwargs):\n    # fix random seed\n    np.random.seed(42)\n    torch.random.manual_seed(42)\n\n\nTRACK_UNITS = {\n    \"time\": \"s\",\n    \"acc\": \"%\",\n    \"flops\": \"GFLOPS\",\n}\n\nTRACK_SETUP = {\n    \"time\": setup_track_time,\n    \"acc\": setup_track_acc,\n    \"flops\": setup_track_flops,\n}\n\n\ndef parametrize(param_name, params):\n    \"\"\"Decorator for benchmarking over a set of parameters.\n\n    Parameters\n    ----------\n    param_name : str\n        Parameter name. Must be one of the arguments of the decorated function.\n    params : list[any]\n        List of values to benchmark for the given parameter name. 
Recommend\n        to use Python's native object type (e.g., int, str, list[int]) because\n        ASV will display them on the plot.\n\n    Examples\n    --------\n\n    Benchmark function `foo` when argument `x` is equal to 10 or 20.\n\n    .. code::\n        @benchmark('time')\n        @parametrize('x', [10, 20]):\n        def foo(x):\n            pass\n\n    Benchmark function with multiple parametrizations. It will run the function\n    with all possible combinations. The example below generates 6 benchmarks.\n\n    .. code::\n        @benchmark('time')\n        @parametrize('x', [10, 20]):\n        @parametrize('y', [-1, -2, -3]):\n        def foo(x, y):\n            pass\n\n    When using multiple parametrizations, it can have arbitrary order. The example\n    below is the same as the above one.\n\n    .. code::\n        @benchmark('time')\n        @parametrize('y', [-1, -2, -3]):\n        @parametrize('x', [10, 20]):\n        def foo(x, y):\n            pass\n    \"\"\"\n\n    def _wrapper(func):\n        sig_params = inspect.signature(func).parameters.keys()\n        num_params = len(sig_params)\n        if getattr(func, \"params\", None) is None:\n            func.params = [None] * num_params\n        if getattr(func, \"param_names\", None) is None:\n            func.param_names = [None] * num_params\n        found_param = False\n        for i, sig_param in enumerate(sig_params):\n            if sig_param == param_name:\n                func.params[i] = params\n                func.param_names[i] = param_name\n                found_param = True\n                break\n        if not found_param:\n            raise ValueError(\"Invalid parameter name:\", param_name)\n        return func\n\n    return _wrapper\n\n\ndef noop_decorator(param_name, params):\n    \"\"\"noop decorator\"\"\"\n\n    def _wrapper(func):\n        return func\n\n    return _wrapper\n\n\nclass TestFilter:\n    def __init__(self):\n        self.conf = None\n        if \"DGL_REG_CONF\" 
in os.environ:\n            current_dir = os.path.dirname(os.path.abspath(__file__))\n            path = os.path.join(\n                current_dir, \"../../\", os.environ[\"DGL_REG_CONF\"]\n            )\n            with open(path, \"r\") as f:\n                self.conf = json.load(f)\n            if \"INSTANCE_TYPE\" in os.environ:\n                instance_type = os.environ[\"INSTANCE_TYPE\"]\n            else:\n                raise Exception(\n                    \"Must set both DGL_REG_CONF and INSTANCE_TYPE as env\"\n                )\n            self.enabled_tests = self.conf[instance_type][\"tests\"]\n        else:\n            import logging\n\n            logging.warning(\"No regression test conf file specified\")\n\n    def check(self, func):\n        funcfullname = inspect.getmodule(func).__name__ + \".\" + func.__name__\n        if self.conf is None:\n            return True\n        else:\n            for enabled_testname in self.enabled_tests:\n                if enabled_testname in funcfullname:\n                    return True\n            return False\n\n\nfilter = TestFilter()\n\n\ndevice = os.environ.get(\"DGL_BENCH_DEVICE\", \"cpu\")\n\nif device == \"cpu\":\n    parametrize_cpu = parametrize\n    parametrize_gpu = noop_decorator\nelif device == \"gpu\":\n    parametrize_cpu = noop_decorator\n    parametrize_gpu = parametrize\nelse:\n    raise Exception(\n        \"Unknown device. 
Must be one of ['cpu', 'gpu'], but got {}\".format(\n            device\n        )\n    )\n\n\ndef skip_if_gpu():\n    \"\"\"skip if DGL_BENCH_DEVICE is gpu\"\"\"\n    device = os.environ.get(\"DGL_BENCH_DEVICE\", \"cpu\")\n\n    def _wrapper(func):\n        if device == \"gpu\":\n            # skip if not enabled\n            func.benchmark_name = \"skip_\" + func.__name__\n        return func\n\n    return _wrapper\n\n\ndef _cuda_device_count(q):\n    import torch\n\n    q.put(torch.cuda.device_count())\n\n\ndef get_num_gpu():\n    import multiprocessing as mp\n\n    q = mp.Queue()\n    p = mp.Process(target=_cuda_device_count, args=(q,))\n    p.start()\n    p.join()\n    return q.get(block=False)\n\n\nGPU_COUNT = get_num_gpu()\n\n\ndef skip_if_not_4gpu():\n    \"\"\"skip if DGL_BENCH_DEVICE is gpu\"\"\"\n\n    def _wrapper(func):\n        if GPU_COUNT < 4:\n            # skip if not enabled\n            print(\"Skip {}\".format(func.__name__))\n            func.benchmark_name = \"skip_\" + func.__name__\n        return func\n\n    return _wrapper\n\n\ndef benchmark(track_type, timeout=60):\n    \"\"\"Decorator for indicating the benchmark type.\n\n    Parameters\n    ----------\n    track_type : str\n        Type. Must be either:\n\n            - 'time' : For timing. Unit: second.\n            - 'acc' : For accuracy. Unit: percentage, value between 0 and 100.\n            - 'flops' : Unit: GFlops, number of floating point operations per second.\n    timeout : int\n        Timeout threshold in second.\n\n    Examples\n    --------\n\n    .. 
code::\n        @benchmark('time')\n        def foo():\n            pass\n    \"\"\"\n    assert track_type in [\"time\", \"acc\", \"flops\"]\n\n    def _wrapper(func):\n        func.unit = TRACK_UNITS[track_type]\n        func.setup = TRACK_SETUP[track_type]\n        func.timeout = timeout\n        if not filter.check(func):\n            # skip if not enabled\n            func.benchmark_name = \"skip_\" + func.__name__\n        return func\n\n    return _wrapper\n\n\n#####################################\n# Timer\n#####################################\n\n\nclass Timer:\n    def __init__(self, device=None):\n        self.timer = default_timer\n        if device is None:\n            self.device = get_bench_device()\n        else:\n            self.device = device\n\n    def __enter__(self):\n        if self.device == \"cuda:0\":\n            self.start_event = torch.cuda.Event(enable_timing=True)\n            self.end_event = torch.cuda.Event(enable_timing=True)\n            self.start_event.record()\n        else:\n            self.tic = self.timer()\n        return self\n\n    def __exit__(self, type, value, traceback):\n        if self.device == \"cuda:0\":\n            self.end_event.record()\n            torch.cuda.synchronize()  # Wait for the events to be recorded!\n            self.elapsed_secs = (\n                self.start_event.elapsed_time(self.end_event) / 1e3\n            )\n        else:\n            self.elapsed_secs = self.timer() - self.tic\n"
  },
  {
    "path": "benchmarks/run.sh",
    "content": "#!/bin/bash\n\nset -e\n\nDEVICE=$1\nROOT=/asv/dgl\n\n. /opt/conda/etc/profile.d/conda.sh\n\nconda activate base\npip install --upgrade pip\n# Newer asv version like 0.5.1 has different result format,\n# so we fix the version here. Or `generate_excel.py` has to be changed.\npip install asv==0.4.2\npip uninstall -y dgl\n\nexport DGL_BENCH_DEVICE=$DEVICE\necho \"DGL_BENCH_DEVICE=$DGL_BENCH_DEVICE\"\npushd $ROOT/benchmarks\ncat asv.conf.json\nasv machine --yes\n# If --launch-method is specified as 'spawn', multigpu tests will crash with\n# \"No module named 'benchmarks' is found\".\nasv run -e -v\nasv publish\npopd\n"
  },
  {
    "path": "benchmarks/scripts/README.md",
    "content": "Regression Test Suite\n========================\n\n### Spec of task.json\n```json\n# Note the test will be run if the name specified below is a substring of the full test name.\n# The fullname of \"benchmarks/model_acc/bench_sage_ns.track_acc\" will be \"model_acc.bench_sage_ns.track_acc\". Test will be run if it contains any keyword.\n# For example, \"model_acc\" will run all the tests under \"model_acc\" folder\n# \"bench_sage\" will run both \"bench_sage\" and \"bench_sage_ns\"\n# \"bench_sage.\" will only run \"bench_sage\"\n# \"ns\" will run any tests name contains \"ms\"\n# \"\" will run all tests\n{\n    \"c5.9xlarge\": { # The instance type to run the test\n        \"tests\": [\n            \"bench_sage\" # The test to be run on this instance\n        ],\n        \"env\": {\n            \"DEVICE\": \"cpu\" # The environment variable passed to publish.sh\n        }\n    },\n    \"g4dn.2xlarge\": {\n        ...\n    }\n}\n```\n\n\n### Environment variable\n- `MOUNT_PATH` specify the directory in the host to be mapped into docker, if exists will map the `MOUNT_PATH`(in host) to `/tmp/dataset`(in docker)\n- `INSTANCE_TYPE` specify the current instance type\n- `DGL_REG_CONF` specify the path to `task.json`, which is relative to the repo root. If specified, must specify `INSTANCE_TYPE` also"
  },
  {
    "path": "benchmarks/scripts/build_dgl_asv.sh",
    "content": "#!/bin/bash\n\nset -e\n\n# Default building only with cpu\nDEVICE=${DGL_BENCH_DEVICE:-cpu}\n\npip install -r /asv/torch_gpu_pip.txt\n\n# build\n# 'CUDA_TOOLKIT_ROOT_DIR' is always required for sparse build as torch1.13.1+cu116 is installed.\nCMAKE_VARS=\"-DUSE_OPENMP=ON -DBUILD_TORCH=ON -DCUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda\"\nif [[ $DEVICE == \"gpu\" ]]; then\n    CMAKE_VARS=\"-DUSE_CUDA=ON $CMAKE_VARS\"\nfi\nmkdir -p build\npushd build\ncmake $CMAKE_VARS ..\nmake -j8\npopd\n"
  },
  {
    "path": "benchmarks/scripts/fix_ram_info.py",
    "content": "import json\nfrom pathlib import Path\n\n\ndef main():\n    result_dir = Path(__file__).parent / \"..\" / Path(\"results/\")\n    for per_machine_dir in result_dir.iterdir():\n        if per_machine_dir.is_dir():\n            try:\n                machine_json = json.loads(\n                    (per_machine_dir / \"machine.json\").read_text()\n                )\n                ram = machine_json[\"ram\"]\n                for f in per_machine_dir.glob(\"*.json\"):\n                    if f.stem != \"machine\":\n                        result = json.loads(f.read_text())\n                        result_ram = result[\"params\"][\"ram\"]\n                        if result_ram != ram:\n                            result[\"params\"][\"ram\"] = ram\n                            print(f\"Fix ram in {f}\")\n                            f.write_text(json.dumps(result))\n                        else:\n                            print(f\"Skip {f}\")\n            except Exception as e:\n                print(e)\n\n\nmain()\n"
  },
  {
    "path": "benchmarks/scripts/generate_excel.py",
    "content": "import json\nfrom itertools import product\nfrom pathlib import Path\n\nimport pandas as pd\n\n\ndef get_branch_name_from_hash(hash):\n    import subprocess\n\n    process = subprocess.Popen(\n        [\"git\", \"name-rev\", \"--name-only\", hash],\n        stdout=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n    )\n    stdout, stderr = process.communicate()\n    if len(stderr) > 0:\n        return hash[:10]\n    else:\n        return stdout.decode(\"utf-8\").strip(\"\\n\")\n\n\ndef main():\n    results_path = Path(\"../results\")\n    results_path.is_dir()\n    machines = [f for f in results_path.glob(\"*\") if f.is_dir()]\n    output_results_dict = {}\n    for machine in machines:\n        per_machine_result = {}\n        commit_results_json_paths = [\n            f for f in machine.glob(\"*\") if f.name != \"machine.json\"\n        ]\n        for commit in commit_results_json_paths:\n            with commit.open() as f:\n                commit_result = json.load(f)\n            commit_hash = commit_result[\"commit_hash\"]\n            per_commit_result = {}\n            for test_name, result in commit_result[\"results\"].items():\n                per_commit_result[test_name] = []\n                if result[\"result\"] is None:\n                    for test_args in product(*result[\"params\"]):\n                        per_commit_result[test_name].append(\n                            {\"params\": \", \".join(test_args), \"result\": None}\n                        )\n                else:\n                    for test_args, performance_number in zip(\n                        product(*result[\"params\"]), result[\"result\"]\n                    ):\n                        per_commit_result[test_name].append(\n                            {\n                                \"params\": \", \".join(test_args),\n                                \"result\": performance_number,\n                            }\n                        )\n            
per_machine_result[commit_hash] = per_commit_result\n        output_results_dict[machine.name] = per_machine_result\n    return output_results_dict\n\n\ndef dict_to_csv(output_results_dict):\n    with open(\"../results/benchmarks.json\") as f:\n        benchmark_conf = json.load(f)\n    unit_dict = {}\n    for k, v in benchmark_conf.items():\n        if k != \"version\":\n            unit_dict[k] = v[\"unit\"]\n    result_list = []\n    for machine, per_machine_result in output_results_dict.items():\n        for commit, test_cases in per_machine_result.items():\n            branch_name = get_branch_name_from_hash(commit)\n            result_column_name = \"number_{}\".format(branch_name)\n            # per_commit_result_list = []\n            for test_case_name, results in test_cases.items():\n                for result in results:\n                    result_list.append(\n                        {\n                            \"test_name\": test_case_name,\n                            \"params\": result[\"params\"],\n                            \"unit\": unit_dict[test_case_name],\n                            \"number\": result[\"result\"],\n                            \"commit\": branch_name,\n                            \"machine\": machine,\n                        }\n                    )\n    df = pd.DataFrame(result_list)\n    return df\n\n\ndef side_by_side_view(df):\n    commits = df[\"commit\"].unique().tolist()\n    full_df = df.loc[df[\"commit\"] == commits[0]]\n    for commit in commits[1:]:\n        per_commit_df = df.loc[df[\"commit\"] == commit]\n        full_df: pd.DataFrame = full_df.merge(\n            per_commit_df,\n            on=[\"test_name\", \"params\", \"machine\", \"unit\"],\n            how=\"outer\",\n            suffixes=(\n                \"_{}\".format(full_df.iloc[0][\"commit\"]),\n                \"_{}\".format(per_commit_df.iloc[0][\"commit\"]),\n            ),\n        )\n    full_df = full_df.loc[:, 
~full_df.columns.str.startswith(\"commit\")]\n    return full_df\n\n\noutput_results_dict = main()\ndf = dict_to_csv(output_results_dict)\nsbs_df = side_by_side_view(df)\nsbs_df.to_csv(\"result.csv\")\n"
  },
  {
    "path": "benchmarks/scripts/install_dgl_asv.sh",
    "content": "#!/bin/bash\n\nset -e\n\n# install\npushd python\nrm -rf build *.egg-info dist\npip uninstall -y dgl\npython3 setup.py install\npopd\n"
  },
  {
    "path": "benchmarks/scripts/publish.sh",
    "content": "#!/bin/bash\n\n# The script launches a docker container to run ASV benchmarks. We use the same docker\n# image as our CI (i.e., dgllib/dgl-ci-gpu:conda). It performs the following steps:\n#\n#   1. Start a docker container of the given machine name. The machine name will be\n#      displayed on the generated website.\n#   2. Copy `.git` into the container. It allows ASV to determine the repository information\n#      such as commit hash, branches, etc.\n#   3. Copy this folder into the container including the ASV configuration file `asv.conf.json`.\n#      This means any changes to the files in this folder do not\n#      require a git commit. By contrast, to correctly benchmark your changes to the core\n#      library (e.g., \"python/dgl\"), you must call git commit first.\n#   4. It then calls the `run.sh` script inside the container. It will invoke `asv run`.\n#      You can change the command such as specifying the benchmarks to run or adding some flags.\n#   5. After benchmarking, it copies the generated `results` and `html` folders back to\n#      the host machine.\n#\n\nif [ $# -eq 2 ]; then\n    MACHINE=$1\n    DEVICE=$2\nelse\n    echo \"publish.sh <machine_name> <device>\"\n    exit 1\nfi\n\nWS_ROOT=/asv/dgl\ndocker pull public.ecr.aws/s1o7b3d9/benchmark_test:cu116_v230110\nif [ -z \"$DGL_REG_CONF\" ]; then\n    DOCKER_ENV_OPT=\"$DOCKER_ENV_OPT\"\nelse\n    DOCKER_ENV_OPT=\" -e DGL_REG_CONF=$DGL_REG_CONF $DOCKER_ENV_OPT\"\nfi\n\nif [ -z \"$INSTANCE_TYPE\" ]; then\n    DOCKER_ENV_OPT=\"$DOCKER_ENV_OPT\"\nelse\n    DOCKER_ENV_OPT=\" -e INSTANCE_TYPE=$INSTANCE_TYPE $DOCKER_ENV_OPT\"\nfi\n\nif [ -z \"$MOUNT_PATH\" ]; then\n    DOCKER_MOUNT_OPT=\"\"\nelse\n    DOCKER_MOUNT_OPT=\"-v ${MOUNT_PATH}:/tmp/dataset -v ${MOUNT_PATH}/dgl_home/:/root/.dgl/\"\nfi\n\necho $HOME\necho \"Mount Point: ${DOCKER_MOUNT_OPT}\"\necho \"Env opt: ${DOCKER_ENV_OPT}\"\necho \"DEVICE: ${DEVICE}\"\n\nif [[ $DEVICE == \"cpu\" ]]; then\n    docker run --name dgl-reg \\\n   
     --rm \\\n        $DOCKER_MOUNT_OPT \\\n        $DOCKER_ENV_OPT \\\n        --shm-size=\"16g\" \\\n        --hostname=$MACHINE -dit public.ecr.aws/s1o7b3d9/benchmark_test:cu116_v230110 /bin/bash\nelse\n    docker run --name dgl-reg \\\n        --rm --gpus all \\\n        $DOCKER_MOUNT_OPT \\\n        $DOCKER_ENV_OPT \\\n        --shm-size=\"16g\" \\\n        --hostname=$MACHINE -dit public.ecr.aws/s1o7b3d9/benchmark_test:cu116_v230110 /bin/bash\nfi\n\npwd\n\ndocker exec dgl-reg mkdir -p $WS_ROOT\ndocker cp ../../.git dgl-reg:$WS_ROOT\ndocker cp ../ dgl-reg:$WS_ROOT/benchmarks/\ndocker cp torch_gpu_pip.txt dgl-reg:/asv\ndocker exec $DOCKER_ENV_OPT dgl-reg bash $WS_ROOT/benchmarks/run.sh $DEVICE\ndocker cp dgl-reg:$WS_ROOT/benchmarks/results ../\ndocker cp dgl-reg:$WS_ROOT/benchmarks/html ../\ndocker stop dgl-reg\n"
  },
  {
    "path": "benchmarks/scripts/replace_branch.py",
    "content": "import argparse\nimport json\nimport os\nimport re\n\n\ndef json_minify(string, strip_space=True):\n    \"\"\"\n    Based on JSON.minify.js:\n    https://github.com/getify/JSON.minify\n    Contributers:\n    - Pradyun S. Gedam (conditions and variable names changed)\n    \"\"\"\n    tokenizer = re.compile(r'\"|(/\\*)|(\\*/)|(//)|\\n|\\r')\n    in_string = False\n    in_multi = False\n    in_single = False\n\n    new_str = []\n    index = 0\n\n    for match in re.finditer(tokenizer, string):\n        if not (in_multi or in_single):\n            tmp = string[index : match.start()]\n            if not in_string and strip_space:\n                # replace white space as defined in standard\n                tmp = re.sub(\"[ \\t\\n\\r]+\", \"\", tmp)\n            new_str.append(tmp)\n\n        index = match.end()\n        val = match.group()\n\n        if val == '\"' and not (in_multi or in_single):\n            escaped = re.search(r\"(\\\\)*$\", string[: match.start()])\n\n            # start of string or unescaped quote character to end string\n            if not in_string or (\n                escaped is None or len(escaped.group()) % 2 == 0\n            ):\n                in_string = not in_string\n            index -= 1  # include \" character in next catch\n        elif not (in_string or in_multi or in_single):\n            if val == \"/*\":\n                in_multi = True\n            elif val == \"//\":\n                in_single = True\n        elif val == \"*/\" and in_multi and not (in_string or in_single):\n            in_multi = False\n        elif val in \"\\r\\n\" and not (in_multi or in_string) and in_single:\n            in_single = False\n        elif not (\n            (in_multi or in_single) or (val in \" \\r\\n\\t\" and strip_space)\n        ):\n            new_str.append(val)\n\n    new_str.append(string[index:])\n    content = \"\".join(new_str)\n    content = content.replace(\",]\", \"]\")\n    content = content.replace(\",}\", 
\"}\")\n    return content\n\n\ndef add_prefix(branch_name):\n    if \"/\" not in branch_name:\n        return \"origin/\" + branch_name\n    else:\n        return branch_name\n\n\ndef change_branch(branch_str: str):\n    branches = [add_prefix(b) for b in branch_str.split(\",\")]\n    with open(\"../asv.conf.json\", \"r\") as f:\n        ss = f.read()\n        config_json = json.loads(json_minify(ss))\n        config_json[\"branches\"] = branches\n    with open(\"../asv.conf.json\", \"w\") as f:\n        json.dump(config_json, f)\n\n\nif __name__ == \"__main__\":\n    if \"BRANCH_STR\" in os.environ:\n        change_branch(os.environ[\"BRANCH_STR\"])\n"
  },
  {
    "path": "benchmarks/scripts/torch_gpu_pip.txt",
    "content": "--find-links https://download.pytorch.org/whl/torch_stable.html\ntorch==1.13.1+cu116\ntorchvision==0.14.1+cu116\ntorchmetrics\npytest\nnose\nnumpy\ncython\nscipy\nnetworkx\nmatplotlib\nnltk\nrequests[security]\ntqdm\nawscli\ntorchtext\npandas\nrdflib\nogb\n"
  },
  {
    "path": "benchmarks/task.json",
    "content": "{\n    \"r5.16xlarge\": {\n        \"tests\": [\n            \"api.\", \"kernel.\", \"model_acc.\", \"model_speed.\"\n        ],\n        \"env\": {\n            \"DEVICE\": \"cpu\"\n        }\n    },\n    \"g4dn.2xlarge\": {\n        \"tests\": [\n            \"api.\", \"kernel.\", \"model_acc.\", \"model_speed.\"\n        ],\n        \"env\": {\n            \"DEVICE\": \"gpu\"\n        }\n    },\n    \"g4dn.12xlarge\": {\n        \"tests\": [\n            \"multigpu.\"\n        ],\n        \"env\": {\n            \"DEVICE\": \"gpu\"\n        }\n    }\n}"
  },
  {
    "path": "cmake/modules/CUDA.cmake",
    "content": "# CUDA Module\nif(USE_CUDA)\n  find_cuda(${USE_CUDA} REQUIRED)\nelse(USE_CUDA)\n  return()\nendif()\n\n###### Borrowed from MSHADOW project\n\ninclude(CheckCXXCompilerFlag)\ncheck_cxx_compiler_flag(\"-std=c++17\"   SUPPORT_CXX17)\n\nset(dgl_known_gpu_archs \"35\" \"50\" \"60\" \"70\" \"75\")\nset(dgl_cuda_arch_ptx \"70\")\nif (CUDA_VERSION_MAJOR GREATER_EQUAL \"11\")\n  list(APPEND dgl_known_gpu_archs \"80\" \"86\")\n  set(dgl_cuda_arch_ptx \"80\" \"86\")\nendif()\nif (CUDA_VERSION VERSION_GREATER_EQUAL \"11.8\")\n  list(APPEND dgl_known_gpu_archs \"89\" \"90\")\n  set(dgl_cuda_arch_ptx \"90\")\nendif()\nif (CUDA_VERSION VERSION_GREATER_EQUAL \"12.0\")\n  list(REMOVE_ITEM dgl_known_gpu_archs \"35\")\nendif()\n\n################################################################################################\n# A function for automatic detection of GPUs installed  (if autodetection is enabled)\n# Usage:\n#   dgl_detect_installed_gpus(out_variable)\nfunction(dgl_detect_installed_gpus out_variable)\nset(CUDA_gpu_detect_output \"\")\n  if(NOT CUDA_gpu_detect_output)\n    message(STATUS \"Running GPU architecture autodetection\")\n    set(__cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu)\n\n    file(WRITE ${__cufile} \"\"\n      \"#include <cstdio>\\n\"\n      \"#include <iostream>\\n\"\n      \"using namespace std;\\n\"\n      \"int main()\\n\"\n      \"{\\n\"\n      \"  int count = 0;\\n\"\n      \"  if (cudaSuccess != cudaGetDeviceCount(&count)) { return -1; }\\n\"\n      \"  if (count == 0) { cerr << \\\"No cuda devices detected\\\" << endl; return -1; }\\n\"\n      \"  for (int device = 0; device < count; ++device)\\n\"\n      \"  {\\n\"\n      \"    cudaDeviceProp prop;\\n\"\n      \"    if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\\n\"\n      \"      std::printf(\\\"%d.%d \\\", prop.major, prop.minor);\\n\"\n      \"  }\\n\"\n      \"  return 0;\\n\"\n      \"}\\n\")\n    if(MSVC)\n      #find vcvarsall.bat and run it building msvc 
environment\n      get_filename_component(MY_COMPILER_DIR ${CMAKE_CXX_COMPILER} DIRECTORY)\n      find_file(MY_VCVARSALL_BAT vcvarsall.bat \"${MY_COMPILER_DIR}/..\" \"${MY_COMPILER_DIR}/../..\")\n      execute_process(COMMAND ${MY_VCVARSALL_BAT} && ${CUDA_NVCC_EXECUTABLE} -arch native --run  ${__cufile}\n                      WORKING_DIRECTORY \"${PROJECT_BINARY_DIR}/CMakeFiles/\"\n                      RESULT_VARIABLE __nvcc_res OUTPUT_VARIABLE __nvcc_out\n                      OUTPUT_STRIP_TRAILING_WHITESPACE)\n    else()\n      if(CUDA_LIBRARY_PATH)\n        set(CUDA_LINK_LIBRARY_PATH \"-L${CUDA_LIBRARY_PATH}\")\n      endif()\n      execute_process(COMMAND ${CUDA_NVCC_EXECUTABLE} -arch native --run ${__cufile} ${CUDA_LINK_LIBRARY_PATH}\n                      WORKING_DIRECTORY \"${PROJECT_BINARY_DIR}/CMakeFiles/\"\n                      RESULT_VARIABLE __nvcc_res OUTPUT_VARIABLE __nvcc_out\n                      OUTPUT_STRIP_TRAILING_WHITESPACE)\n    endif()\n    if(__nvcc_res EQUAL 0)\n      # nvcc outputs text containing line breaks when building with MSVC.\n      # The line below prevents CMake from inserting a variable with line\n      # breaks in the cache\n      message(STATUS \"Found GPU arch ${__nvcc_out}\")\n      string(REGEX MATCH \"([1-9].[0-9])\" __nvcc_out \"${__nvcc_out}\")\n      if(__nvcc_out VERSION_LESS \"3.5\")\n        # drop support for cc < 3.5 and build for all known archs.\n        message(WARNING \"GPU arch less than 3.5 is not supported.\")\n      else()\n        set(CUDA_gpu_detect_output ${__nvcc_out} CACHE INTERNAL \"Returned GPU architetures from mshadow_detect_gpus tool\" FORCE)\n      endif()\n    else()\n      message(WARNING \"Running GPU detection script with nvcc failed: ${__nvcc_out}\")\n    endif()\n  endif()\n\n  if(NOT CUDA_gpu_detect_output)\n    message(WARNING \"Automatic GPU detection failed. 
Building for all known architectures (${dgl_known_gpu_archs}).\")\n    set(${out_variable} ${dgl_known_gpu_archs} PARENT_SCOPE)\n  else()\n    set(${out_variable} ${CUDA_gpu_detect_output} PARENT_SCOPE)\n  endif()\nendfunction()\n\n\n################################################################################################\n# Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME\n# Usage:\n#   dgl_select_nvcc_arch_flags(out_variable)\nfunction(dgl_select_nvcc_arch_flags out_variable)\n  # List of arch names. Turing and Ada don't have a new major version, so they are not added to default build.\n  set(__archs_names \"Kepler\" \"Maxwell\" \"Pascal\" \"Volta\" \"Turing\" \"Ampere\" \"Ada\" \"Hopper\" \"All\" \"Manual\")\n  if (NOT CUDA_VERSION VERSION_LESS \"12.0\")\n    list(REMOVE_ITEM __archs_names \"Kepler\")\n  endif()\n  set(__archs_name_default \"All\")\n  if(NOT CMAKE_CROSSCOMPILING)\n    list(APPEND __archs_names \"Auto\")\n    set(__archs_name_default \"Auto\")\n  endif()\n\n  # set CUDA_ARCH_NAME strings (so it will be seen as dropbox in CMake-Gui)\n  set(CUDA_ARCH_NAME ${__archs_name_default} CACHE STRING \"Select target NVIDIA GPU achitecture.\")\n  set_property( CACHE CUDA_ARCH_NAME PROPERTY STRINGS \"\" ${__archs_names} )\n  mark_as_advanced(CUDA_ARCH_NAME)\n\n  # verify CUDA_ARCH_NAME value\n  if(NOT \";${__archs_names};\" MATCHES \";${CUDA_ARCH_NAME};\")\n    string(REPLACE \";\" \", \" __archs_names \"${__archs_names}\")\n    message(FATAL_ERROR \"Only ${__archs_names} architeture names are supported.\")\n  endif()\n\n  if(${CUDA_ARCH_NAME} STREQUAL \"Manual\")\n    set(CUDA_ARCH_BIN ${dgl_known_gpu_archs} CACHE STRING \"Specify 'real' GPU architectures to build binaries for, BIN(PTX) format is supported\")\n    set(CUDA_ARCH_PTX ${dgl_cuda_arch_ptx} CACHE STRING \"Specify 'virtual' PTX architectures to build PTX intermediate code for\")\n    mark_as_advanced(CUDA_ARCH_BIN CUDA_ARCH_PTX)\n  else()\n    unset(CUDA_ARCH_BIN 
CACHE)\n    unset(CUDA_ARCH_PTX CACHE)\n  endif()\n\n  if(${CUDA_ARCH_NAME} STREQUAL \"Kepler\")\n    set(__cuda_arch_bin \"35\")\n    set(__cuda_arch_ptx \"35\")\n  elseif(${CUDA_ARCH_NAME} STREQUAL \"Maxwell\")\n    set(__cuda_arch_bin \"50\")\n    set(__cuda_arch_ptx \"50\")\n  elseif(${CUDA_ARCH_NAME} STREQUAL \"Pascal\")\n    set(__cuda_arch_bin \"60\")\n    set(__cuda_arch_ptx \"60\")\n  elseif(${CUDA_ARCH_NAME} STREQUAL \"Volta\")\n    set(__cuda_arch_bin \"70\")\n    set(__cuda_arch_ptx \"70\")\n  elseif(${CUDA_ARCH_NAME} STREQUAL \"Turing\")\n    set(__cuda_arch_bin \"75\")\n    set(__cuda_arch_ptx \"75\")\n  elseif(${CUDA_ARCH_NAME} STREQUAL \"Ampere\")\n    set(__cuda_arch_bin \"80\")\n    set(__cuda_arch_ptx \"80\")\n  elseif(${CUDA_ARCH_NAME} STREQUAL \"Ada\")\n    set(__cuda_arch_bin \"89\")\n    set(__cuda_arch_ptx \"89\")\n  elseif(${CUDA_ARCH_NAME} STREQUAL \"Hopper\")\n    set(__cuda_arch_bin \"90\")\n    set(__cuda_arch_ptx \"90\")\n  elseif(${CUDA_ARCH_NAME} STREQUAL \"All\")\n    set(__cuda_arch_bin ${dgl_known_gpu_archs})\n    set(__cuda_arch_ptx ${dgl_cuda_arch_ptx})\n  elseif(${CUDA_ARCH_NAME} STREQUAL \"Auto\")\n    dgl_detect_installed_gpus(__cuda_arch_bin)\n    # if detect successes, __cuda_arch_ptx = __cuda_arch_bin\n    # if detect fails, __cuda_arch_ptx is the latest arch in __cuda_arch_bin\n    list(GET __cuda_arch_bin -1 __cuda_arch_ptx)\n  else()  # (${CUDA_ARCH_NAME} STREQUAL \"Manual\")\n    set(__cuda_arch_bin ${CUDA_ARCH_BIN})\n    set(__cuda_arch_ptx ${CUDA_ARCH_PTX})\n  endif()\n\n  # remove dots and convert to lists\n  string(REGEX REPLACE \"\\\\.\" \"\" __cuda_arch_bin \"${__cuda_arch_bin}\")\n  string(REGEX REPLACE \"\\\\.\" \"\" __cuda_arch_ptx \"${__cuda_arch_ptx}\")\n  string(REGEX MATCHALL \"[0-9()]+\" __cuda_arch_bin \"${__cuda_arch_bin}\")\n  string(REGEX MATCHALL \"[0-9]+\"   __cuda_arch_ptx \"${__cuda_arch_ptx}\")\n  mshadow_list_unique(__cuda_arch_bin __cuda_arch_ptx)\n\n  set(__nvcc_flags 
\"--expt-relaxed-constexpr\")\n  set(__nvcc_archs_readable \"\")\n  set(__archs \"\")\n\n  # Tell NVCC to add binaries for the specified GPUs\n  foreach(__arch ${__cuda_arch_bin})\n    if(__arch MATCHES \"([0-9]+)\\\\(([0-9]+)\\\\)\")\n      # User explicitly specified PTX for the concrete BIN\n      list(APPEND __nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})\n      list(APPEND __nvcc_archs_readable sm_${CMAKE_MATCH_1})\n      list(APPEND __archs ${CMAKE_MATCH_1})\n    else()\n      # User didn't explicitly specify PTX for the concrete BIN, we assume PTX=BIN\n      list(APPEND __nvcc_flags -gencode arch=compute_${__arch},code=sm_${__arch})\n      list(APPEND __nvcc_archs_readable sm_${__arch})\n      list(APPEND __archs ${__arch})\n    endif()\n  endforeach()\n\n  # Tell NVCC to add PTX intermediate code for the specified architectures\n  foreach(__arch ${__cuda_arch_ptx})\n    list(APPEND __nvcc_flags -gencode arch=compute_${__arch},code=compute_${__arch})\n    list(APPEND __nvcc_archs_readable compute_${__arch})\n  endforeach()\n\n  string(REPLACE \";\" \" \" __nvcc_archs_readable \"${__nvcc_archs_readable}\")\n  set(${out_variable}          ${__nvcc_flags}          PARENT_SCOPE)\n  set(${out_variable}_readable ${__nvcc_archs_readable} PARENT_SCOPE)\n  set(CUDA_ARCHITECTURES       ${__archs}               PARENT_SCOPE)\nendfunction()\n\n################################################################################################\n# Config cuda compilation and append CUDA libraries to linker_libs\n# Usage:\n#  dgl_config_cuda(linker_libs)\nmacro(dgl_config_cuda linker_libs)\n  if(NOT CUDA_FOUND)\n    message(FATAL_ERROR \"Cannot find CUDA.\")\n  endif()\n  # always set the includedir when cuda is available\n  # avoid global retrigger of cmake\n\tinclude_directories(${CUDA_INCLUDE_DIRS})\n\n  add_definitions(-DDGL_USE_CUDA)\n\n  # NVCC flags\n  # Manually set everything\n  set(CUDA_PROPAGATE_HOST_FLAGS OFF)\n\n  # 0. 
Add host flags\n  message(STATUS \"CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}\")\n  string(REGEX REPLACE \"[ \\t\\n\\r]\" \",\" CXX_HOST_FLAGS \"${CMAKE_CXX_FLAGS}\")\n  if(MSVC AND NOT USE_MSVC_MT)\n    string(CONCAT CXX_HOST_FLAGS ${CXX_HOST_FLAGS} \",/MD\")\n  endif()\n  list(APPEND CUDA_NVCC_FLAGS \"-Xcompiler\" \"${CXX_HOST_FLAGS}\")\n  if(USE_OPENMP)\n    # Needed by CUDA disjoint union source file.\n    list(APPEND CUDA_NVCC_FLAGS \"-Xcompiler\" \"${OpenMP_CXX_FLAGS}\")\n  endif(USE_OPENMP)\n\n  # 1. Add arch flags\n  dgl_select_nvcc_arch_flags(NVCC_FLAGS_ARCH)\n  list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_ARCH})\n\n  # 2. flags in third_party/moderngpu\n  list(APPEND CUDA_NVCC_FLAGS \"--expt-extended-lambda;-Wno-deprecated-declarations;-std=c++17\")\n\n  message(STATUS \"CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}\")\n\n  list(APPEND ${linker_libs} \n    ${CUDA_CUDART_LIBRARY}\n    ${CUDA_CUBLAS_LIBRARIES}\n    ${CUDA_cusparse_LIBRARY})\nendmacro()\n"
  },
  {
    "path": "cmake/modules/FindMETIS.cmake",
    "content": "# Find the METIS includes and library\n#\n# This module defines\n#  METIS_INCLUDE_DIR        -    where to find metis.h\n#  METIS_LIBRARIES          -    libraries to link against to use METIS.\n#  METIS_FOUND              -    METIS library was found\n\nINCLUDE(FindPackageHandleStandardArgs)\n\nFIND_PATH(METIS_INCLUDE_DIR\n    NAMES\n    \"metis.h\"\n    PATHS\n    ${EXTERNAL_METIS_PATH}\n    )\n\n\nFIND_LIBRARY(METIS_LIBRARIES\n    NAMES\n    libmetis metis\n    PATHS\n    ${EXTERNAL_METIS_LIB_PATH}\n    )\n\n\nFIND_PACKAGE_HANDLE_STANDARD_ARGS(METIS DEFAULT_MSG METIS_INCLUDE_DIR METIS_LIBRARIES)\nMARK_AS_ADVANCED(METIS_LIBRARIES METIS_INCLUDE_DIR)\n"
  },
  {
    "path": "cmake/util/FindCUDA.cmake",
    "content": "#######################################################\n# Enhanced version of find CUDA.\n#\n# Usage:\n#   find_cuda(${USE_CUDA})\n#\n# - When USE_CUDA=ON, use auto search\n#\n# Please use the CMAKE variable CUDA_TOOLKIT_ROOT_DIR to set CUDA directory\n#\n# Provide variables:\n#\n# - CUDA_FOUND\n# - CUDA_INCLUDE_DIRS\n# - CUDA_TOOLKIT_ROOT_DIR\n# - CUDA_CUDA_LIBRARY\n# - CUDA_CUDART_LIBRARY\n# - CUDA_NVRTC_LIBRARY\n# - CUDA_CUDNN_LIBRARY\n# - CUDA_CUBLAS_LIBRARY\n#\nmacro(find_cuda use_cuda)\n  set(__use_cuda ${use_cuda})\n  if(__use_cuda STREQUAL \"ON\")\n    include(FindCUDA)\n  endif()\n\n  # additional libraries\n  if(CUDA_FOUND)\n    if(MSVC)\n      find_library(CUDA_CUDA_LIBRARY cuda\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)\n      find_library(CUDA_NVRTC_LIBRARY nvrtc\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)\n      find_library(CUDA_CUDNN_LIBRARY cudnn\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)\n      find_library(CUDA_CUBLAS_LIBRARY cublas\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)\n      find_library(CUDA_CURAND_LIBRARY curand\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)\n    else(MSVC)\n      #find_library(CUDA_CUDA_LIBRARY cuda\n      #  PATHS ${CUDA_TOOLKIT_ROOT_DIR}\n      #  PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib\n      #  NO_DEFAULT_PATH)\n      find_library(CUDA_CUBLAS_LIBRARY cublas\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib64\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib)\n      find_library(CUDA_CURAND_LIBRARY curand\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib64\n        ${CUDA_TOOLKIT_ROOT_DIR}/lib)\n    endif(MSVC)\n    message(STATUS \"Found CUDA_TOOLKIT_ROOT_DIR=\" ${CUDA_TOOLKIT_ROOT_DIR})\n    #message(STATUS \"Found CUDA_CUDA_LIBRARY=\" ${CUDA_CUDA_LIBRARY})\n    message(STATUS \"Found CUDA_CUDART_LIBRARY=\" 
${CUDA_CUDART_LIBRARY})\n    #message(STATUS \"Found CUDA_NVRTC_LIBRARY=\" ${CUDA_NVRTC_LIBRARY})\n    #message(STATUS \"Found CUDA_CUDNN_LIBRARY=\" ${CUDA_CUDNN_LIBRARY})\n    message(STATUS \"Found CUDA_CUBLAS_LIBRARY=\" ${CUDA_CUBLAS_LIBRARY})\n    message(STATUS \"Found CUDA_CURAND_LIBRARY=\" ${CUDA_CURAND_LIBRARY})\n  endif(CUDA_FOUND)\nendmacro(find_cuda)\n"
  },
  {
    "path": "cmake/util/MshadowUtil.cmake",
    "content": "################################################################################################\n# Command alias for debugging messages\n# Usage:\n#   dmsg(<message>)\nfunction(dmsg)\n  message(STATUS ${ARGN})\nendfunction()\n\n################################################################################################\n# Removes duplicates from list(s)\n# Usage:\n#   mshadow_list_unique(<list_variable> [<list_variable>] [...])\nmacro(mshadow_list_unique)\n  foreach(__lst ${ARGN})\n    if(${__lst})\n      list(REMOVE_DUPLICATES ${__lst})\n    endif()\n  endforeach()\nendmacro()\n\n################################################################################################\n# Clears variables from list\n# Usage:\n#   mshadow_clear_vars(<variables_list>)\nmacro(mshadow_clear_vars)\n  foreach(_var ${ARGN})\n    unset(${_var})\n  endforeach()\nendmacro()\n\n################################################################################################\n# Removes duplicates from string\n# Usage:\n#   mshadow_string_unique(<string_variable>)\nfunction(mshadow_string_unique __string)\n  if(${__string})\n    set(__list ${${__string}})\n    separate_arguments(__list)\n    list(REMOVE_DUPLICATES __list)\n    foreach(__e ${__list})\n      set(__str \"${__str} ${__e}\")\n    endforeach()\n    set(${__string} ${__str} PARENT_SCOPE)\n  endif()\nendfunction()\n\n################################################################################################\n# Prints list element per line\n# Usage:\n#   mshadow_print_list(<list>)\nfunction(mshadow_print_list)\n  foreach(e ${ARGN})\n    message(STATUS ${e})\n  endforeach()\nendfunction()\n\n################################################################################################\n# Function merging lists of compiler flags to single string.\n# Usage:\n#   mshadow_merge_flag_lists(out_variable <list1> [<list2>] [<list3>] ...)\nfunction(mshadow_merge_flag_lists out_var)\n  set(__result \"\")\n  
foreach(__list ${ARGN})\n    foreach(__flag ${${__list}})\n      string(STRIP ${__flag} __flag)\n      set(__result \"${__result} ${__flag}\")\n    endforeach()\n  endforeach()\n  string(STRIP ${__result} __result)\n  set(${out_var} ${__result} PARENT_SCOPE)\nendfunction()\n\n################################################################################################\n# Converts all paths in list to absolute\n# Usage:\n#   mshadow_convert_absolute_paths(<list_variable>)\nfunction(mshadow_convert_absolute_paths variable)\n  set(__dlist \"\")\n  foreach(__s ${${variable}})\n    get_filename_component(__abspath ${__s} ABSOLUTE)\n    list(APPEND __list ${__abspath})\n  endforeach()\n  set(${variable} ${__list} PARENT_SCOPE)\nendfunction()\n\n################################################################################################\n# Reads set of version defines from the header file\n# Usage:\n#   mshadow_parse_header(<file> <define1> <define2> <define3> ..)\nmacro(mshadow_parse_header FILENAME FILE_VAR)\n  set(vars_regex \"\")\n  set(__parnet_scope OFF)\n  set(__add_cache OFF)\n  foreach(name ${ARGN})\n    if(\"${name}\" STREQUAL \"PARENT_SCOPE\")\n      set(__parnet_scope ON)\n    elseif(\"${name}\" STREQUAL \"CACHE\")\n      set(__add_cache ON)\n    elseif(vars_regex)\n      set(vars_regex \"${vars_regex}|${name}\")\n    else()\n      set(vars_regex \"${name}\")\n    endif()\n  endforeach()\n  if(EXISTS \"${FILENAME}\")\n    file(STRINGS \"${FILENAME}\" ${FILE_VAR} REGEX \"#define[ \\t]+(${vars_regex})[ \\t]+[0-9]+\" )\n  else()\n    unset(${FILE_VAR})\n  endif()\n  foreach(name ${ARGN})\n    if(NOT \"${name}\" STREQUAL \"PARENT_SCOPE\" AND NOT \"${name}\" STREQUAL \"CACHE\")\n      if(${FILE_VAR})\n        if(${FILE_VAR} MATCHES \".+[ \\t]${name}[ \\t]+([0-9]+).*\")\n          string(REGEX REPLACE \".+[ \\t]${name}[ \\t]+([0-9]+).*\" \"\\\\1\" ${name} \"${${FILE_VAR}}\")\n        else()\n          set(${name} \"\")\n        endif()\n        
if(__add_cache)\n          set(${name} ${${name}} CACHE INTERNAL \"${name} parsed from ${FILENAME}\" FORCE)\n        elseif(__parnet_scope)\n          set(${name} \"${${name}}\" PARENT_SCOPE)\n        endif()\n      else()\n        unset(${name} CACHE)\n      endif()\n    endif()\n  endforeach()\nendmacro()\n\n################################################################################################\n# Reads single version define from the header file and parses it\n# Usage:\n#   mshadow_parse_header_single_define(<library_name> <file> <define_name>)\nfunction(mshadow_parse_header_single_define LIBNAME HDR_PATH VARNAME)\n  set(${LIBNAME}_H \"\")\n  if(EXISTS \"${HDR_PATH}\")\n    file(STRINGS \"${HDR_PATH}\" ${LIBNAME}_H REGEX \"^#define[ \\t]+${VARNAME}[ \\t]+\\\"[^\\\"]*\\\".*$\" LIMIT_COUNT 1)\n  endif()\n\n  if(${LIBNAME}_H)\n    string(REGEX REPLACE \"^.*[ \\t]${VARNAME}[ \\t]+\\\"([0-9]+).*$\" \"\\\\1\" ${LIBNAME}_VERSION_MAJOR \"${${LIBNAME}_H}\")\n    string(REGEX REPLACE \"^.*[ \\t]${VARNAME}[ \\t]+\\\"[0-9]+\\\\.([0-9]+).*$\" \"\\\\1\" ${LIBNAME}_VERSION_MINOR  \"${${LIBNAME}_H}\")\n    string(REGEX REPLACE \"^.*[ \\t]${VARNAME}[ \\t]+\\\"[0-9]+\\\\.[0-9]+\\\\.([0-9]+).*$\" \"\\\\1\" ${LIBNAME}_VERSION_PATCH \"${${LIBNAME}_H}\")\n    set(${LIBNAME}_VERSION_MAJOR ${${LIBNAME}_VERSION_MAJOR} ${ARGN} PARENT_SCOPE)\n    set(${LIBNAME}_VERSION_MINOR ${${LIBNAME}_VERSION_MINOR} ${ARGN} PARENT_SCOPE)\n    set(${LIBNAME}_VERSION_PATCH ${${LIBNAME}_VERSION_PATCH} ${ARGN} PARENT_SCOPE)\n    set(${LIBNAME}_VERSION_STRING \"${${LIBNAME}_VERSION_MAJOR}.${${LIBNAME}_VERSION_MINOR}.${${LIBNAME}_VERSION_PATCH}\" PARENT_SCOPE)\n\n    # append a TWEAK version if it exists:\n    set(${LIBNAME}_VERSION_TWEAK \"\")\n    if(\"${${LIBNAME}_H}\" MATCHES \"^.*[ \\t]${VARNAME}[ \\t]+\\\"[0-9]+\\\\.[0-9]+\\\\.[0-9]+\\\\.([0-9]+).*$\")\n      set(${LIBNAME}_VERSION_TWEAK \"${CMAKE_MATCH_1}\" ${ARGN} PARENT_SCOPE)\n    endif()\n    if(${LIBNAME}_VERSION_TWEAK)\n      
set(${LIBNAME}_VERSION_STRING \"${${LIBNAME}_VERSION_STRING}.${${LIBNAME}_VERSION_TWEAK}\" ${ARGN} PARENT_SCOPE)\n    else()\n      set(${LIBNAME}_VERSION_STRING \"${${LIBNAME}_VERSION_STRING}\" ${ARGN} PARENT_SCOPE)\n    endif()\n  endif()\nendfunction()\n\n########################################################################################################\n# An option that the user can select. Can accept condition to control when option is available for user.\n# Usage:\n#   mshadow_option(<option_variable> \"doc string\" <initial value or boolean expression> [IF <condition>])\nfunction(mshadow_option variable description value)\n  set(__value ${value})\n  set(__condition \"\")\n  set(__varname \"__value\")\n  foreach(arg ${ARGN})\n    if(arg STREQUAL \"IF\" OR arg STREQUAL \"if\")\n      set(__varname \"__condition\")\n    else()\n      list(APPEND ${__varname} ${arg})\n    endif()\n  endforeach()\n  unset(__varname)\n  if(\"${__condition}\" STREQUAL \"\")\n    set(__condition 2 GREATER 1)\n  endif()\n\n  if(${__condition})\n    if(\"${__value}\" MATCHES \";\")\n      if(${__value})\n        option(${variable} \"${description}\" ON)\n      else()\n        option(${variable} \"${description}\" OFF)\n      endif()\n    elseif(DEFINED ${__value})\n      if(${__value})\n        option(${variable} \"${description}\" ON)\n      else()\n        option(${variable} \"${description}\" OFF)\n      endif()\n    else()\n      option(${variable} \"${description}\" ${__value})\n    endif()\n  else()\n    unset(${variable} CACHE)\n  endif()\nendfunction()\n\n################################################################################################\n# Utility macro for comparing two lists. 
Used for CMake debugging purposes\n# Usage:\n#   mshadow_compare_lists(<list_variable> <list2_variable> [description])\nfunction(mshadow_compare_lists list1 list2 desc)\n  set(__list1 ${${list1}})\n  set(__list2 ${${list2}})\n  list(SORT __list1)\n  list(SORT __list2)\n  list(LENGTH __list1 __len1)\n  list(LENGTH __list2 __len2)\n\n  if(NOT ${__len1} EQUAL ${__len2})\n    message(FATAL_ERROR \"Lists are not equal. ${__len1} != ${__len2}. ${desc}\")\n  endif()\n\n  foreach(__i RANGE 1 ${__len1})\n    math(EXPR __index \"${__i}- 1\")\n    list(GET __list1 ${__index} __item1)\n    list(GET __list2 ${__index} __item2)\n    if(NOT ${__item1} STREQUAL ${__item2})\n      message(FATAL_ERROR \"Lists are not equal. Differ at element ${__index}. ${desc}\")\n    endif()\n  endforeach()\nendfunction()\n\n################################################################################################\n# Command for disabling warnings for different platforms (see below for gcc and VisualStudio)\n# Usage:\n#   mshadow_warnings_disable(<CMAKE_[C|CXX]_FLAGS[_CONFIGURATION]> -Wshadow /wd4996 ..,)\nmacro(mshadow_warnings_disable)\n  set(_flag_vars \"\")\n  set(_msvc_warnings \"\")\n  set(_gxx_warnings \"\")\n\n  foreach(arg ${ARGN})\n    if(arg MATCHES \"^CMAKE_\")\n      list(APPEND _flag_vars ${arg})\n    elseif(arg MATCHES \"^/wd\")\n      list(APPEND _msvc_warnings ${arg})\n    elseif(arg MATCHES \"^-W\")\n      list(APPEND _gxx_warnings ${arg})\n    endif()\n  endforeach()\n\n  if(NOT _flag_vars)\n    set(_flag_vars CMAKE_C_FLAGS CMAKE_CXX_FLAGS)\n  endif()\n\n  if(MSVC AND _msvc_warnings)\n    foreach(var ${_flag_vars})\n      foreach(warning ${_msvc_warnings})\n        set(${var} \"${${var}} ${warning}\")\n      endforeach()\n    endforeach()\n  elseif((CMAKE_COMPILER_IS_GNUCXX OR CMAKE_COMPILER_IS_CLANGXX) AND _gxx_warnings)\n    foreach(var ${_flag_vars})\n      foreach(warning ${_gxx_warnings})\n        if(NOT warning MATCHES \"^-Wno-\")\n          string(REPLACE 
\"${warning}\" \"\" ${var} \"${${var}}\")\n          string(REPLACE \"-W\" \"-Wno-\" warning \"${warning}\")\n        endif()\n        set(${var} \"${${var}} ${warning}\")\n      endforeach()\n    endforeach()\n  endif()\n  mshadow_clear_vars(_flag_vars _msvc_warnings _gxx_warnings)\nendmacro()\n\n################################################################################################\n# Helper function get current definitions\n# Usage:\n#   mshadow_get_current_definitions(<definitions_variable>)\nfunction(mshadow_get_current_definitions definitions_var)\n  get_property(current_definitions DIRECTORY PROPERTY COMPILE_DEFINITIONS)\n  set(result \"\")\n\n  foreach(d ${current_definitions})\n    list(APPEND result -D${d})\n  endforeach()\n\n  mshadow_list_unique(result)\n  set(${definitions_var} ${result} PARENT_SCOPE)\nendfunction()\n\n################################################################################################\n# Helper function get current includes/definitions\n# Usage:\n#   mshadow_get_current_cflags(<cflagslist_variable>)\nfunction(mshadow_get_current_cflags cflags_var)\n  get_property(current_includes DIRECTORY PROPERTY INCLUDE_DIRECTORIES)\n  mshadow_convert_absolute_paths(current_includes)\n  mshadow_get_current_definitions(cflags)\n\n  foreach(i ${current_includes})\n    list(APPEND cflags \"-I${i}\")\n  endforeach()\n\n  mshadow_list_unique(cflags)\n  set(${cflags_var} ${cflags} PARENT_SCOPE)\nendfunction()\n\n################################################################################################\n# Helper function to parse current linker libs into link directories, libflags and osx frameworks\n# Usage:\n#   mshadow_parse_linker_libs(<mshadow_LINKER_LIBS_var> <directories_var> <libflags_var> <frameworks_var>)\nfunction(mshadow_parse_linker_libs mshadow_LINKER_LIBS_variable folders_var flags_var frameworks_var)\n\n  set(__unspec \"\")\n  set(__debug \"\")\n  set(__optimized \"\")\n  set(__framework \"\")\n  set(__varname 
\"__unspec\")\n\n  # split libs into debug, optimized, unspecified and frameworks\n  foreach(list_elem ${${mshadow_LINKER_LIBS_variable}})\n    if(list_elem STREQUAL \"debug\")\n      set(__varname \"__debug\")\n    elseif(list_elem STREQUAL \"optimized\")\n      set(__varname \"__optimized\")\n    elseif(list_elem MATCHES \"^-framework[ \\t]+([^ \\t].*)\")\n      list(APPEND __framework -framework ${CMAKE_MATCH_1})\n    else()\n      list(APPEND ${__varname} ${list_elem})\n      set(__varname \"__unspec\")\n    endif()\n  endforeach()\n\n  # attach debug or optimized libs to unspecified according to current configuration\n  if(CMAKE_BUILD_TYPE MATCHES \"Debug\")\n    set(__libs ${__unspec} ${__debug})\n  else()\n    set(__libs ${__unspec} ${__optimized})\n  endif()\n\n  set(libflags \"\")\n  set(folders \"\")\n\n  # convert linker libraries list to link flags\n  foreach(lib ${__libs})\n    if(TARGET ${lib})\n      list(APPEND folders $<TARGET_LINKER_FILE_DIR:${lib}>)\n      list(APPEND libflags -l${lib})\n    elseif(lib MATCHES \"^-l.*\")\n      list(APPEND libflags ${lib})\n    elseif(IS_ABSOLUTE ${lib})\n      get_filename_component(name_we ${lib} NAME_WE)\n      get_filename_component(folder  ${lib} PATH)\n\n      string(REGEX MATCH \"^lib(.*)\" __match ${name_we})\n      list(APPEND libflags -l${CMAKE_MATCH_1})\n      list(APPEND folders    ${folder})\n    else()\n      message(FATAL_ERROR \"Logic error. Need to update cmake script\")\n    endif()\n  endforeach()\n\n  mshadow_list_unique(libflags folders)\n\n  set(${folders_var} ${folders} PARENT_SCOPE)\n  set(${flags_var} ${libflags} PARENT_SCOPE)\n  set(${frameworks_var} ${__framework} PARENT_SCOPE)\nendfunction()\n\n################################################################################################\n# Helper function to detect Darwin version, i.e. 
10.8, 10.9, 10.10, ....\n# Usage:\n#   mshadow_detect_darwin_version(<version_variable>)\nfunction(mshadow_detect_darwin_version output_var)\n  if(APPLE)\n    execute_process(COMMAND /usr/bin/sw_vers -productVersion\n                    RESULT_VARIABLE __sw_vers OUTPUT_VARIABLE __sw_vers_out\n                    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)\n\n    set(${output_var} ${__sw_vers_out} PARENT_SCOPE)\n  else()\n    set(${output_var} \"\" PARENT_SCOPE)\n  endif()\nendfunction()\n\n################################################################################################\n# Convenient command to setup source group for IDEs that support this feature (VS, XCode)\n# Usage:\n#   mshadow_source_group(<group> GLOB[_RECURSE] <globbing_expression>)\nfunction(mshadow_source_group group)\n  cmake_parse_arguments(CAFFE_SOURCE_GROUP \"\" \"\" \"GLOB;GLOB_RECURSE\" ${ARGN})\n  if(CAFFE_SOURCE_GROUP_GLOB)\n    file(GLOB srcs1 ${CAFFE_SOURCE_GROUP_GLOB})\n    source_group(${group} FILES ${srcs1})\n  endif()\n\n  if(CAFFE_SOURCE_GROUP_GLOB_RECURSE)\n    file(GLOB_RECURSE srcs2 ${CAFFE_SOURCE_GROUP_GLOB_RECURSE})\n    source_group(${group} FILES ${srcs2})\n  endif()\nendfunction()"
  },
  {
    "path": "cmake/util/Util.cmake",
    "content": "# NOTE: __dgl_option will not reset existing variables.\nmacro(__dgl_option variable description value)\n  if(NOT DEFINED ${variable})\n    set(${variable} ${value} CACHE STRING ${description})\n  endif()\nendmacro()\n\n#######################################################\n# An option to specify the build type for a feature.\n# Usage:\n#   dgl_feature_option(<option_variable> \"doc string\" \"dev\" \"release\")\nmacro(dgl_feature_option variable description)\n  set(__value \"\")\n  foreach(arg ${ARGN})\n    if(arg STREQUAL \"all\")\n      __dgl_option(${variable} \"${description}\" ON)\n    elseif(arg STREQUAL \"dev\" OR arg STREQUAL \"dogfood\" OR arg STREQUAL \"release\")\n      list(APPEND __value ${arg})\n    endif()\n  endforeach()\n\n  if(${BUILD_TYPE} IN_LIST __value)\n    __dgl_option(${variable} \"${description}\" ON)\n  else()\n    # NOTE: __dgl_option will not reset existing variables.\n    __dgl_option(${variable} \"${description}\" OFF)\n  endif()\nendmacro()\n\n#######################################################\n# An option that the user can select. 
Can accept condition to control when option is available for user.\n# Usage:\n#   dgl_option(<option_variable> \"doc string\" <initial value or boolean expression> [IF <condition>])\nmacro(dgl_option variable description value)\n  set(__value ${value})\n  set(__condition \"\")\n  set(__varname \"__value\")\n  foreach(arg ${ARGN})\n    if(arg STREQUAL \"IF\" OR arg STREQUAL \"if\")\n      set(__varname \"__condition\")\n    else()\n      list(APPEND ${__varname} ${arg})\n    endif()\n  endforeach()\n  unset(__varname)\n  if(\"${__condition}\" STREQUAL \"\")\n    set(__condition 2 GREATER 1)\n  endif()\n\n  if(${__condition})\n    if(\"${__value}\" MATCHES \";\")\n      if(${__value})\n        __dgl_option(${variable} \"${description}\" ON)\n      else()\n        __dgl_option(${variable} \"${description}\" OFF)\n      endif()\n    elseif(DEFINED ${__value})\n      if(${__value})\n        __dgl_option(${variable} \"${description}\" ON)\n      else()\n        __dgl_option(${variable} \"${description}\" OFF)\n      endif()\n    else()\n      __dgl_option(${variable} \"${description}\" \"${__value}\")\n    endif()\n  else()\n    unset(${variable} CACHE)\n  endif()\nendmacro()\n"
  },
  {
    "path": "conda/dgl/README.md",
    "content": "conda recipe\n===\n\nBuild the package with `conda build .`\n"
  },
  {
    "path": "conda/dgl/bld.bat",
    "content": "REM Needs vcvars64.bat to be called\ngit submodule init\ngit submodule update --recursive\nmd build\ncd build\nCOPY %TEMP%\\dgl.dll .\ncd ..\\python\n\"%PYTHON%\" setup.py install --single-version-externally-managed --record=record.txt || EXIT /B 1\nEXIT /B\n"
  },
  {
    "path": "conda/dgl/build.sh",
    "content": "git submodule init\ngit submodule update --recursive\nmkdir build\ncd build\ncmake -DUSE_CUDA=$USE_CUDA -DUSE_OPENMP=ON -DCUDA_ARCH_NAME=All ..\nmake\ncd ../python\n$PYTHON setup.py install --single-version-externally-managed --record=record.txt\n"
  },
  {
    "path": "conda/dgl/conda_build_config.yaml",
    "content": "python:\n  - 3.8\n  - 3.9\n  - 3.10\n  - 3.11\n  - 3.12\n"
  },
  {
    "path": "conda/dgl/meta.yaml",
    "content": "package:\n  name: dgl{{ environ.get('DGL_PACKAGE_SUFFIX', '') }}\n  version: 2.5{{ environ.get('DGL_VERSION_SUFFIX', '') }}\n\nsource:\n  git_rev: {{ environ.get('DGL_RELEASE_BRANCH', 'master') }}\n  git_url: https://github.com/dmlc/dgl.git\n\nrequirements:\n  build:\n    - python {{ python }}\n    - setuptools\n    - cmake\n    - git\n    - cython\n  run:\n    - python\n    - numpy\n    - scipy\n    - networkx\n    - requests\n    - tqdm\n    - psutil\n\nbuild:\n  script_env:\n    - USE_CUDA\n    - CUDA_VER\n    - CACHEDIR\n    - DGL_VERSION_SUFFIX\n\nabout:\n  home: https://github.com/dmlc/dgl.git\n  license_file: {{ environ.get('SRC_DIR') }}/LICENSE\n  license: Apache\n"
  },
  {
    "path": "conda/dgl/run_test.bat",
    "content": "set DGLBACKEND=numpy\n%PYTHON% -c \"import dgl\"\n"
  },
  {
    "path": "conda/dgl/run_test.sh",
    "content": "DGLBACKEND=numpy $PYTHON -c 'import dgl'\n"
  },
  {
    "path": "dgl_sparse/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.8)\nproject(dgl_sparse C CXX)\n\n# Find PyTorch cmake files and PyTorch versions with the python interpreter $PYTHON_INTERP\n# (\"python3\" or \"python\" if empty)\nif(NOT PYTHON_INTERP)\n  find_program(PYTHON_INTERP NAMES python3 python)\nendif()\nmessage(STATUS \"Using Python interpreter: ${PYTHON_INTERP}\")\nfile(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/find_cmake.py FIND_CMAKE_PY)\nexecute_process(\n  COMMAND ${PYTHON_INTERP} ${FIND_CMAKE_PY}\n  OUTPUT_VARIABLE TORCH_PREFIX_VER\n  OUTPUT_STRIP_TRAILING_WHITESPACE)\nmessage(STATUS \"find_cmake.py output: ${TORCH_PREFIX_VER}\")\nlist(GET TORCH_PREFIX_VER 0 TORCH_PREFIX)\nlist(GET TORCH_PREFIX_VER 1 TORCH_VER)\nmessage(STATUS \"Configuring for PyTorch ${TORCH_VER}\")\nstring(REPLACE \".\" \";\" TORCH_VERSION_LIST ${TORCH_VER})\nlist(GET TORCH_VERSION_LIST 0 TORCH_VERSION_MAJOR)\nlist(GET TORCH_VERSION_LIST 1 TORCH_VERSION_MINOR)\n\nset(SPARSE_LINKER_LIBS \"\")\n\nif(USE_CUDA)\n  add_definitions(-DDGL_USE_CUDA)\n  enable_language(CUDA)\nendif()\n\n# For windows, define NOMINMAX to avoid conflict with std::min/max\nif(MSVC)\n  add_definitions(-DNOMINMAX)\nendif()\n\nset(Torch_DIR \"${TORCH_PREFIX}/Torch\")\nmessage(STATUS \"Setting directory to ${Torch_DIR}\")\nfind_package(Torch REQUIRED)\nset(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} ${TORCH_C_FLAGS}\")\nset(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}\")\nset(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb\")\n\nset(LIB_DGL_SPARSE_NAME \"dgl_sparse_pytorch_${TORCH_VER}\")\nlist(APPEND SPARSE_LINKER_LIBS ${TORCH_LIBRARIES})\n\nset(SPARSE_DIR \"${CMAKE_CURRENT_SOURCE_DIR}/src\")\nset(SPARSE_INCLUDE \"${CMAKE_CURRENT_SOURCE_DIR}/include\")\nfile(GLOB SPARSE_HEADERS ${SPARSE_INCLUDE})\nfile(GLOB SPARSE_SRC\n  ${SPARSE_DIR}/*.cc\n  ${SPARSE_DIR}/cpu/*.cc\n)\nif(USE_CUDA)\n  file(GLOB SPARSE_CUDA_SRC\n    ${SPARSE_DIR}/cuda/*.cu\n  )\n  list(APPEND SPARSE_SRC 
${SPARSE_CUDA_SRC})\nendif()\n\nadd_library(${LIB_DGL_SPARSE_NAME} SHARED ${SPARSE_SRC} ${SPARSE_HEADERS})\ntarget_include_directories(\n  ${LIB_DGL_SPARSE_NAME} PRIVATE ${SPARSE_DIR} ${SPARSE_HEADERS})\ntarget_link_libraries(${LIB_DGL_SPARSE_NAME} ${SPARSE_LINKER_LIBS})\ntarget_compile_definitions(${LIB_DGL_SPARSE_NAME} PRIVATE TORCH_VERSION_MAJOR=${TORCH_VERSION_MAJOR})\ntarget_compile_definitions(${LIB_DGL_SPARSE_NAME} PRIVATE TORCH_VERSION_MINOR=${TORCH_VERSION_MINOR})\n\ntarget_include_directories(${LIB_DGL_SPARSE_NAME} PRIVATE \"${CMAKE_SOURCE_DIR}/third_party/dmlc-core/include\")\nmessage(STATUS \"DGL include directories: ${DGL_INCLUDE_DIRS}\")\ntarget_include_directories(${LIB_DGL_SPARSE_NAME} PRIVATE ${DGL_INCLUDE_DIRS})\ntarget_link_directories(${LIB_DGL_SPARSE_NAME} PRIVATE ${DGL_BUILD_DIR} \"${DGL_BUILD_DIR}/third_party/dmlc-core\")\n\n# The Torch CMake configuration only sets up the path for the MKL library when\n# using the conda distribution. The following is a workaround to address this\n# when using a standalone installation of MKL.\nif(DEFINED MKL_LIBRARIES)\n  target_link_directories(${LIB_DGL_SPARSE_NAME} PRIVATE ${MKL_ROOT}/lib/${MKL_ARCH})\nendif()\nif (EXTERNAL_DMLC_LIB_PATH)\n   # external dmlc requires OpenMP link\n   include(FindOpenMP)\n   if(OPENMP_FOUND)\n        set(CMAKE_C_FLAGS \"${OpenMP_C_FLAGS} ${CMAKE_C_FLAGS}\")\n        set(CMAKE_CXX_FLAGS \"${OpenMP_CXX_FLAGS} ${CMAKE_CXX_FLAGS}\")\n   endif(OPENMP_FOUND)\t\n   message(STATUS \"looking for dmlc library in ${EXTERNAL_DMLC_LIB_PATH}\")\n   find_package(dmlc REQUIRED HINTS ${EXTERNAL_DMLC_LIB_PATH})\n   target_link_libraries(${LIB_DGL_SPARSE_NAME} dmlc::dmlc dgl)\nelse (EXTERNAL_DMLC_LIB_PATH)\n   target_link_libraries(${LIB_DGL_SPARSE_NAME} dmlc dgl)\nendif()\nset(GOOGLE_TEST 0) # Turn off dmlc-core test\n\n# Configure dgl_sparse library to use C++17 standard for compatibility with PyTorch\nset_property(TARGET ${LIB_DGL_SPARSE_NAME} PROPERTY CXX_STANDARD 17)\n"
  },
  {
    "path": "dgl_sparse/build.bat",
    "content": "REM Helper script to build DGL sparse libraries for PyTorch\n@ECHO OFF\nSETLOCAL EnableDelayedExpansion\n\nMD \"%BINDIR%\\dgl_sparse\"\nDEL /S /Q build\nMD build\nPUSHD build\n\nIF x%1x == xx GOTO single\nCOPY %BINDIR%\\third_party\\dmlc-core\\Release\\dmlc.lib %BINDIR%\nCOPY %BINDIR%\\Release\\dgl.lib %BINDIR%\n\nFOR %%X IN (%*) DO (\n\tDEL /S /Q *\n\t\"%CMAKE_COMMAND%\" -DDGL_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DCUDA_TOOLKIT_ROOT_DIR=\"%CUDA_TOOLKIT_ROOT_DIR%\" -DTORCH_CUDA_ARCH_LIST=%TORCH_CUDA_ARCH_LIST% -DDGL_INCLUDE_DIRS=%INCLUDEDIR: =;% -DUSE_CUDA=%USE_CUDA% -DPYTHON_INTERP=%%X .. -G \"Visual Studio 16 2019\" || EXIT /B 1\n\tmsbuild dgl_sparse.sln /m /nr:false || EXIT /B 1\n\tCOPY /Y Release\\*.dll \"%BINDIR%\\dgl_sparse\" || EXIT /B 1\n)\n\nGOTO end\n\n:single\n\nDEL /S /Q *\n\"%CMAKE_COMMAND%\" -DDGL_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DCUDA_TOOLKIT_ROOT_DIR=\"%CUDA_TOOLKIT_ROOT_DIR%\" -DTORCH_CUDA_ARCH_LIST=%TORCH_CUDA_ARCH_LIST% -DUSE_CUDA=%USE_CUDA% -DDGL_INCLUDE_DIRS=%INCLUDEDIR: =;% .. -G \"Visual Studio 16 2019\" || EXIT /B 1\nmsbuild dgl_sparse.sln /m /nr:false || EXIT /B 1\nCOPY /Y Release\\*.dll \"%BINDIR%\\dgl_sparse\" || EXIT /B 1\n\n:end\nPOPD\n\nENDLOCAL\n"
  },
  {
    "path": "dgl_sparse/build.sh",
    "content": "#!/bin/bash\n# Helper script to build dgl sparse libraries for PyTorch\nset -e\n\nmkdir -p build\nmkdir -p $BINDIR/dgl_sparse\ncd build\n\nif [ $(uname) = 'Darwin' ]; then\n\tCPSOURCE=*.dylib\nelse\n\tCPSOURCE=*.so\nfi\n\nCMAKE_FLAGS=\"-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST -DUSE_CUDA=$USE_CUDA -DEXTERNAL_DMLC_LIB_PATH=$EXTERNAL_DMLC_LIB_PATH\"\n# CMake passes in the list of directories separated by spaces.  Here we replace them with semicolons.\nCMAKE_FLAGS=\"$CMAKE_FLAGS -DDGL_INCLUDE_DIRS=${INCLUDEDIR// /;} -DDGL_BUILD_DIR=$BINDIR\"\necho $CMAKE_FLAGS\n\nif [ $# -eq 0 ]; then\n\t$CMAKE_COMMAND $CMAKE_FLAGS ..\n\tmake -j\n\tcp -v $CPSOURCE $BINDIR/dgl_sparse\nelse\n\tfor PYTHON_INTERP in $@; do\n\t\tTORCH_VER=$($PYTHON_INTERP -c 'import torch; print(torch.__version__.split(\"+\")[0])')\n\t\tmkdir -p $TORCH_VER\n\t\tcd $TORCH_VER\n\t\t$CMAKE_COMMAND $CMAKE_FLAGS -DPYTHON_INTERP=$PYTHON_INTERP ../..\n\t\tmake -j\n\t\tcp -v $CPSOURCE $BINDIR/dgl_sparse\n\t\tcd ..\n\tdone\nfi\n"
  },
  {
    "path": "dgl_sparse/find_cmake.py",
    "content": "import os\n\nimport torch\n\ncmake_prefix_path = getattr(\n    torch.utils,\n    \"cmake_prefix_path\",\n    os.path.join(os.path.dirname(torch.__file__), \"share\", \"cmake\"),\n)\nversion = torch.__version__.split(\"+\")[0]\nprint(\";\".join([cmake_prefix_path, version]))\n"
  },
  {
    "path": "dgl_sparse/include/sparse/dgl_headers.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse/dgl_headers.h\n * @brief DGL headers used in the sparse library. This is a workaround to\n * avoid the macro naming conflict between dmlc/logging.h and the torch logger.\n * This file includes all the DGL headers used in the sparse library and\n * undefines the logging macros defined in dmlc/logging.h. There are two rules\n * for using this file: (1) All DGL headers used in the sparse library should be\n * included in this file, and only in this file. (2) This file should be\n * included before any PyTorch headers.\n */\n#ifndef SPARSE_DGL_HEADERS_H_\n#define SPARSE_DGL_HEADERS_H_\n\n#include <dgl/aten/coo.h>\n#include <dgl/aten/csr.h>\n#include <dgl/kernel.h>\n#include <dgl/runtime/dlpack_convert.h>\n#include <dmlc/logging.h>\n\n#undef CHECK\n#undef CHECK_OP\n#undef CHECK_EQ\n#undef CHECK_NE\n#undef CHECK_LE\n#undef CHECK_LT\n#undef CHECK_GE\n#undef CHECK_GT\n#undef CHECK_NOTNULL\n#undef DCHECK\n#undef DCHECK_EQ\n#undef DCHECK_NE\n#undef DCHECK_LE\n#undef DCHECK_LT\n#undef DCHECK_GE\n#undef DCHECK_GT\n#undef DCHECK_NOTNULL\n#undef VLOG\n#undef LOG\n#undef DLOG\n#undef LOG_IF\n\n#endif  // SPARSE_DGL_HEADERS_H_\n"
  },
  {
    "path": "dgl_sparse/include/sparse/elementwise_op.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse/elementwise_op.h\n * @brief DGL C++ sparse elementwise operators.\n */\n#ifndef SPARSE_ELEMENTWISE_OP_H_\n#define SPARSE_ELEMENTWISE_OP_H_\n\n#include <sparse/sparse_matrix.h>\n\nnamespace dgl {\nnamespace sparse {\n\n/**\n * @brief Adds two sparse matrices possibly with different sparsities.\n *\n * @param lhs_mat SparseMatrix\n * @param rhs_mat SparseMatrix\n *\n * @return SparseMatrix\n */\nc10::intrusive_ptr<SparseMatrix> SpSpAdd(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat);\n\n/**\n * @brief Multiplies two sparse matrices possibly with different sparsities.\n *\n * @param lhs_mat SparseMatrix\n * @param rhs_mat SparseMatrix\n *\n * @return SparseMatrix\n */\nc10::intrusive_ptr<SparseMatrix> SpSpMul(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat);\n\n/**\n * @brief Divides two sparse matrices with the same sparsity.\n *\n * @param lhs_mat SparseMatrix\n * @param rhs_mat SparseMatrix\n *\n * @return SparseMatrix\n */\nc10::intrusive_ptr<SparseMatrix> SpSpDiv(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat);\n\n}  // namespace sparse\n}  // namespace dgl\n\n#endif  // SPARSE_ELEMENTWISE_OP_H_\n"
  },
  {
    "path": "dgl_sparse/include/sparse/matrix_ops.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file sparse/matrix_ops.h\n * @brief DGL C++ sparse matrix operators.\n */\n#ifndef SPARSE_MATRIX_OPS_H_\n#define SPARSE_MATRIX_OPS_H_\n\n#include <sparse/sparse_matrix.h>\n\n#include <tuple>\n\nnamespace dgl {\nnamespace sparse {\n\n/**\n * @brief Compute the intersection of two COO matrices. Return the intersection\n * matrix, and the indices of the intersection in the left-hand-side and\n * right-hand-side matrices.\n *\n * @param lhs The left-hand-side COO matrix.\n * @param rhs The right-hand-side COO matrix.\n *\n * @return A tuple of COO matrix, lhs indices, and rhs indices.\n */\nstd::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(\n    const std::shared_ptr<COO>& lhs, const std::shared_ptr<COO>& rhs);\n\n/**\n * @brief Compact sparse matrix by removing rows or columns without non-zero\n * elements in the sparse matrix and relabeling indices of the dimension.\n *\n * This function serves a dual purpose: it allows you to reorganize the\n * indices within a specific dimension (rows or columns) of the sparse matrix\n * and, if needed, place certain 'leading_indices' at the beginning of the\n * compact dimension.\n *\n * @param mat The sparse matrix to be compacted.\n * @param dim The dimension to compact. Should be 0 or 1. 
Use 0 for row-wise\n *        compaction and 1 for column-wise compaction.\n * @param leading_indices An optional tensor containing row or column ids that\n *        should be placed at the beginning of the compact dimension.\n *\n * @return A tuple containing the compacted sparse matrix and the index mapping\n *         of the compact dimension from the new index to the original index.\n */\nstd::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> Compact(\n    const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,\n    const torch::optional<torch::Tensor>& leading_indices);\n\n}  // namespace sparse\n}  // namespace dgl\n\n#endif  // SPARSE_MATRIX_OPS_H_\n"
  },
  {
    "path": "dgl_sparse/include/sparse/reduction.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse/reduction.h\n * @brief DGL C++ sparse matrix reduction operators.\n */\n#ifndef SPARSE_REDUCTION_H_\n#define SPARSE_REDUCTION_H_\n\n#include <sparse/sparse_matrix.h>\n\n#include <string>\n\nnamespace dgl {\nnamespace sparse {\n\n/**\n * @brief Reduces a sparse matrix along the specified sparse dimension.\n *\n * @param A The sparse matrix.\n * @param dim The sparse dimension to reduce along.  Must be either 0 (rows) or\n * 1 (columns).\n * @param reduce The reduce operator.  Must be either \"sum\", \"smin\", \"smax\",\n * \"smean\", or \"sprod\".\n *\n * @return Tensor\n */\ntorch::Tensor Reduce(\n    const c10::intrusive_ptr<SparseMatrix>& A, const std::string& reduce,\n    const torch::optional<int64_t>& dim = torch::nullopt);\n\ninline torch::Tensor ReduceSum(\n    const c10::intrusive_ptr<SparseMatrix>& A,\n    const torch::optional<int64_t>& dim = torch::nullopt) {\n  return Reduce(A, \"sum\", dim);\n}\n\ninline torch::Tensor ReduceMin(\n    const c10::intrusive_ptr<SparseMatrix>& A,\n    const torch::optional<int64_t>& dim = torch::nullopt) {\n  return Reduce(A, \"smin\", dim);\n}\n\ninline torch::Tensor ReduceMax(\n    const c10::intrusive_ptr<SparseMatrix>& A,\n    const torch::optional<int64_t>& dim = torch::nullopt) {\n  return Reduce(A, \"smax\", dim);\n}\n\ninline torch::Tensor ReduceMean(\n    const c10::intrusive_ptr<SparseMatrix>& A,\n    const torch::optional<int64_t>& dim = torch::nullopt) {\n  return Reduce(A, \"smean\", dim);\n}\n\ninline torch::Tensor ReduceProd(\n    const c10::intrusive_ptr<SparseMatrix>& A,\n    const torch::optional<int64_t>& dim = torch::nullopt) {\n  return Reduce(A, \"sprod\", dim);\n}\n\n}  // namespace sparse\n}  // namespace dgl\n\n#endif  // SPARSE_REDUCTION_H_\n"
  },
  {
    "path": "dgl_sparse/include/sparse/sddmm.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse/sddmm.h\n * @brief DGL C++ SDDMM operator.\n */\n#ifndef SPARSE_SDDMM_H_\n#define SPARSE_SDDMM_H_\n\n#include <sparse/sparse_matrix.h>\n#include <torch/script.h>\n\nnamespace dgl {\nnamespace sparse {\n\n/**\n * @brief Perform a sampled matrix multiplication of a sparse matrix and two\n * dense matrices. It calculates `sparse_mat * (mat1 @ mat2)`. The SDDMM can be\n * batched, where the batch dimension is the last dimension for all input\n * matrices.\n *\n * There are four cases for the input and output matrix shapes:\n *   (1) (n, m), (n, k), (k, m), and (n, m);\n *   (2) (n, m), (n,), (m,), and (n, m);\n *   (3) (n, m, b), (n, k, b), (k, m, b), and (n, m, b);\n *   (4) (n, m), (n, k, b), (k, m, b), and (n, m, b).\n *\n * This function supports autograd for `mat1` and `mat2` but does not support\n * higher order gradient.\n *\n *\n * @param sparse_mat The sparse matrix.\n * @param mat1 The first dense matrix.\n * @param mat2 The second dense matrix.\n *\n * @return SparseMatrix\n */\nc10::intrusive_ptr<SparseMatrix> SDDMM(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,\n    torch::Tensor mat2);\n\n}  // namespace sparse\n}  // namespace dgl\n\n#endif  // SPARSE_SDDMM_H_\n"
  },
  {
    "path": "dgl_sparse/include/sparse/softmax.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse/softmax.h\n * @brief DGL C++ Softmax operator\n */\n#ifndef SPARSE_SOFTMAX_H_\n#define SPARSE_SOFTMAX_H_\n\n#include <sparse/sparse_matrix.h>\n\nnamespace dgl {\nnamespace sparse {\n\n/**\n * @brief Apply softmax to the non-zero entries of the sparse matrix on the\n * dimension dim. dim = 0 or 1 indicates column-wise or row-wise softmax\n * respectively.\n *\n * This function supports autograd for the sparse matrix, but it does not\n * support higher order gradient.\n *\n * @param sparse_mat The sparse matrix\n * @param dim The dimension to apply softmax\n *\n * @return Sparse matrix\n */\nc10::intrusive_ptr<SparseMatrix> Softmax(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, int64_t dim);\n\n}  // namespace sparse\n}  // namespace dgl\n\n#endif  // SPARSE_SOFTMAX_H_\n"
  },
  {
    "path": "dgl_sparse/include/sparse/sparse_format.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse/sparse_format.h\n * @brief DGL C++ sparse format header.\n */\n#ifndef SPARSE_SPARSE_FORMAT_H_\n#define SPARSE_SPARSE_FORMAT_H_\n\n// clang-format off\n#include <sparse/dgl_headers.h>\n// clang-format on\n\n#include <torch/custom_class.h>\n#include <torch/script.h>\n\n#include <memory>\n#include <utility>\n\nnamespace dgl {\nnamespace sparse {\n\n/** @brief SparseFormat enumeration. */\nenum SparseFormat { kCOO, kCSR, kCSC, kDiag };\n\n/** @brief COO sparse structure. */\nstruct COO {\n  /** @brief The shape of the matrix. */\n  int64_t num_rows = 0, num_cols = 0;\n  /**\n   * @brief COO tensor of shape (2, nnz), stacking the row and column indices.\n   */\n  torch::Tensor indices;\n  /** @brief Whether the row indices are sorted. */\n  bool row_sorted = false;\n  /** @brief Whether the column indices per row are sorted. */\n  bool col_sorted = false;\n};\n\n/** @brief CSR sparse structure. */\nstruct CSR {\n  /** @brief The dense shape of the matrix. */\n  int64_t num_rows = 0, num_cols = 0;\n  /** @brief CSR format index pointer array of the matrix. */\n  torch::Tensor indptr;\n  /** @brief CSR format index array of the matrix. */\n  torch::Tensor indices;\n  /** @brief Data index tensor. When it is null, assume it is from 0 to NNZ - 1.\n   */\n  torch::optional<torch::Tensor> value_indices;\n  /** @brief Whether the column indices per row are sorted. */\n  bool sorted = false;\n};\n\nstruct Diag {\n  /** @brief The dense shape of the matrix. */\n  int64_t num_rows = 0, num_cols = 0;\n};\n\n/** @brief Convert an old DGL COO format to a COO in the sparse library. */\nstd::shared_ptr<COO> COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo);\n\n/** @brief Convert a COO in the sparse library to an old DGL COO matrix. */\naten::COOMatrix COOToOldDGLCOO(const std::shared_ptr<COO>& coo);\n\n/** @brief Convert an old DGL CSR format to a CSR in the sparse library. 
*/\nstd::shared_ptr<CSR> CSRFromOldDGLCSR(const aten::CSRMatrix& dgl_csr);\n\n/** @brief Convert a CSR in the sparse library to an old DGL CSR matrix. */\naten::CSRMatrix CSRToOldDGLCSR(const std::shared_ptr<CSR>& csr);\n\n/**\n *  @brief Convert a COO and its nonzero values to a Torch COO matrix.\n *  @param coo The COO format in the sparse library\n *  @param value Values of the sparse matrix\n *\n *  @return Torch Sparse Tensor in COO format\n */\ntorch::Tensor COOToTorchCOO(\n    const std::shared_ptr<COO>& coo, torch::Tensor value);\n\n/** @brief Convert a CSR format to COO format. */\nstd::shared_ptr<COO> CSRToCOO(const std::shared_ptr<CSR>& csr);\n\n/** @brief Convert a CSC format to COO format. */\nstd::shared_ptr<COO> CSCToCOO(const std::shared_ptr<CSR>& csc);\n\n/** @brief Convert a COO format to CSR format. */\nstd::shared_ptr<CSR> COOToCSR(const std::shared_ptr<COO>& coo);\n\n/** @brief Convert a CSC format to CSR format. */\nstd::shared_ptr<CSR> CSCToCSR(const std::shared_ptr<CSR>& csc);\n\n/** @brief Convert a COO format to CSC format. */\nstd::shared_ptr<CSR> COOToCSC(const std::shared_ptr<COO>& coo);\n\n/** @brief Convert a CSR format to CSC format. */\nstd::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr);\n\n/** @brief Convert a Diag format to COO format. */\nstd::shared_ptr<COO> DiagToCOO(\n    const std::shared_ptr<Diag>& diag,\n    const c10::TensorOptions& indices_options);\n\n/** @brief Convert a Diag format to CSR format. */\nstd::shared_ptr<CSR> DiagToCSR(\n    const std::shared_ptr<Diag>& diag,\n    const c10::TensorOptions& indices_options);\n\n/** @brief Convert a Diag format to CSC format. */\nstd::shared_ptr<CSR> DiagToCSC(\n    const std::shared_ptr<Diag>& diag,\n    const c10::TensorOptions& indices_options);\n\n/** @brief COO transposition. 
*/\nstd::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo);\n\n/**\n * @brief Sort the COO matrix by row and column indices.\n * @return A pair of the sorted COO matrix and the permutation indices.\n */\nstd::pair<std::shared_ptr<COO>, torch::Tensor> COOSort(\n    const std::shared_ptr<COO>& coo);\n\n}  // namespace sparse\n}  // namespace dgl\n\n#endif  // SPARSE_SPARSE_FORMAT_H_\n"
  },
  {
    "path": "dgl_sparse/include/sparse/sparse_matrix.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse/sparse_matrix.h\n * @brief DGL C++ sparse matrix header.\n */\n#ifndef SPARSE_SPARSE_MATRIX_H_\n#define SPARSE_SPARSE_MATRIX_H_\n\n// clang-format off\n#include <sparse/dgl_headers.h>\n// clang-format on\n\n#include <sparse/sparse_format.h>\n#include <torch/custom_class.h>\n#include <torch/script.h>\n\n#include <memory>\n#include <tuple>\n#include <utility>\n#include <vector>\n\nnamespace dgl {\nnamespace sparse {\n\n/** @brief SparseMatrix bound to Python.  */\nclass SparseMatrix : public torch::CustomClassHolder {\n public:\n  /**\n   * @brief General constructor to construct a sparse matrix for different\n   * sparse formats. At least one of the sparse formats should be provided,\n   * while others could be nullptrs.\n   *\n   * @param coo The COO format.\n   * @param csr The CSR format.\n   * @param csc The CSC format.\n   * @param value Value of the sparse matrix.\n   * @param shape Shape of the sparse matrix.\n   */\n  SparseMatrix(\n      const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,\n      const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,\n      torch::Tensor value, const std::vector<int64_t>& shape);\n\n  /**\n   * @brief Construct a SparseMatrix from a COO format.\n   * @param coo The COO format\n   * @param value Values of the sparse matrix\n   * @param shape Shape of the sparse matrix\n   *\n   * @return SparseMatrix\n   */\n  static c10::intrusive_ptr<SparseMatrix> FromCOOPointer(\n      const std::shared_ptr<COO>& coo, torch::Tensor value,\n      const std::vector<int64_t>& shape);\n\n  /**\n   * @brief Construct a SparseMatrix from a CSR format.\n   * @param csr The CSR format\n   * @param value Values of the sparse matrix\n   * @param shape Shape of the sparse matrix\n   *\n   * @return SparseMatrix\n   */\n  static c10::intrusive_ptr<SparseMatrix> FromCSRPointer(\n      const std::shared_ptr<CSR>& csr, torch::Tensor value,\n      const 
std::vector<int64_t>& shape);\n\n  /**\n   * @brief Construct a SparseMatrix from a CSC format.\n   * @param csc The CSC format\n   * @param value Values of the sparse matrix\n   * @param shape Shape of the sparse matrix\n   *\n   * @return SparseMatrix\n   */\n  static c10::intrusive_ptr<SparseMatrix> FromCSCPointer(\n      const std::shared_ptr<CSR>& csc, torch::Tensor value,\n      const std::vector<int64_t>& shape);\n\n  /**\n   * @brief Construct a SparseMatrix from a Diag format.\n   * @param diag The Diag format\n   * @param value Values of the sparse matrix\n   * @param shape Shape of the sparse matrix\n   *\n   * @return SparseMatrix\n   */\n  static c10::intrusive_ptr<SparseMatrix> FromDiagPointer(\n      const std::shared_ptr<Diag>& diag, torch::Tensor value,\n      const std::vector<int64_t>& shape);\n\n  /**\n   * @brief Create a SparseMatrix from tensors in COO format.\n   * @param indices COO coordinates with shape (2, nnz).\n   * @param value Values of the sparse matrix.\n   * @param shape Shape of the sparse matrix.\n   *\n   * @return SparseMatrix\n   */\n  static c10::intrusive_ptr<SparseMatrix> FromCOO(\n      torch::Tensor indices, torch::Tensor value,\n      const std::vector<int64_t>& shape);\n\n  /**\n   * @brief Create a SparseMatrix from tensors in CSR format.\n   * @param indptr Index pointer array of the CSR\n   * @param indices Indices array of the CSR\n   * @param value Values of the sparse matrix\n   * @param shape Shape of the sparse matrix\n   *\n   * @return SparseMatrix\n   */\n  static c10::intrusive_ptr<SparseMatrix> FromCSR(\n      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,\n      const std::vector<int64_t>& shape);\n\n  /**\n   * @brief Create a SparseMatrix from tensors in CSC format.\n   * @param indptr Index pointer array of the CSC\n   * @param indices Indices array of the CSC\n   * @param value Values of the sparse matrix\n   * @param shape Shape of the sparse matrix\n   *\n   * @return 
SparseMatrix\n   */\n  static c10::intrusive_ptr<SparseMatrix> FromCSC(\n      torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,\n      const std::vector<int64_t>& shape);\n\n  /**\n   * @brief Create a SparseMatrix with Diag format.\n   * @param value Values of the sparse matrix\n   * @param shape Shape of the sparse matrix\n   *\n   * @return SparseMatrix\n   */\n  static c10::intrusive_ptr<SparseMatrix> FromDiag(\n      torch::Tensor value, const std::vector<int64_t>& shape);\n\n  /**\n   * @brief Create a SparseMatrix by selecting rows or columns based on provided\n   * indices.\n   *\n   * This function allows you to create a new SparseMatrix by selecting specific\n   * rows or columns from the original SparseMatrix based on the provided\n   * indices. The selection can be performed either row-wise or column-wise,\n   * determined by the 'dim' parameter.\n   *\n   * @param dim Select rows (dim=0) or columns (dim=1).\n   * @param ids A tensor containing the indices of the selected rows or columns.\n   *\n   * @return A new SparseMatrix containing the selected rows or columns.\n   *\n   * @note The 'dim' parameter should be either 0 (for row-wise selection) or 1\n   * (for column-wise selection).\n   * @note The 'ids' tensor should contain valid indices within the range of the\n   * original SparseMatrix's dimensions.\n   */\n  c10::intrusive_ptr<SparseMatrix> IndexSelect(int64_t dim, torch::Tensor ids);\n\n  /**\n   * @brief Create a SparseMatrix by selecting a range of rows or columns based\n   * on provided indices.\n   *\n   * This function allows you to create a new SparseMatrix by selecting a range\n   * of specific rows or columns from the original SparseMatrix based on the\n   * provided indices. 
The selection can be performed either row-wise or\n   * column-wise, determined by the 'dim' parameter.\n   *\n   * @param dim Select rows (dim=0) or columns (dim=1).\n   * @param start The starting index (inclusive) of the range.\n   * @param end The ending index (exclusive) of the range.\n   *\n   * @return A new SparseMatrix containing the selected range of rows or\n   * columns.\n   *\n   * @note The 'dim' parameter should be either 0 (for row-wise selection) or 1\n   * (for column-wise selection).\n   * @note The 'start' and 'end' indices should be valid indices within\n   * the valid range of the original SparseMatrix's dimensions.\n   */\n  c10::intrusive_ptr<SparseMatrix> RangeSelect(\n      int64_t dim, int64_t start, int64_t end);\n\n  /**\n   * @brief Create a SparseMatrix by sampling elements based on the specified\n   * dimension and sample count.\n   *\n   * If `ids` is provided, this function samples elements from the specified\n   * set of row or column IDs, resulting in a sparse matrix containing only\n   * the sampled rows or columns.\n   *\n   * @param dim Select rows (dim=0) or columns (dim=1) for sampling.\n   * @param fanout The number of elements to randomly sample from each row or\n   * column.\n   * @param ids An optional tensor containing row or column IDs from which to\n   * sample elements.\n   * @param replace Indicates whether repeated sampling of the same element\n   * is allowed. If True, repeated sampling is allowed; otherwise, it is not\n   * allowed.\n   * @param bias An optional boolean flag indicating whether to enable biasing\n   * during sampling. If True, the values of the sparse matrix will be used as\n   * bias weights, meaning that elements with higher values will be more likely\n   * to be sampled. 
Otherwise, all elements will be sampled uniformly,\n   * regardless of their value.\n   *\n   * @return A new SparseMatrix with the same shape as the original matrix\n   * containing the sampled elements.\n   *\n   * @note If 'replace = false' and there are fewer elements than 'fanout',\n   * all non-zero elements will be sampled.\n   * @note If 'ids' is not provided, the function will sample from\n   * all rows or columns.\n   */\n  c10::intrusive_ptr<SparseMatrix> Sample(\n      int64_t dim, int64_t fanout, torch::Tensor ids, bool replace, bool bias);\n\n  /**\n   * @brief Create a SparseMatrix from a SparseMatrix using new values.\n   * @param mat An existing sparse matrix\n   * @param value New values of the sparse matrix\n   *\n   * @return SparseMatrix\n   */\n  static c10::intrusive_ptr<SparseMatrix> ValLike(\n      const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value);\n\n  /** @return Value of the sparse matrix. */\n  inline torch::Tensor value() const { return value_; }\n  /** @return Shape of the sparse matrix. */\n  inline const std::vector<int64_t>& shape() const { return shape_; }\n  /** @return Number of non-zero values */\n  inline int64_t nnz() const { return value_.size(0); }\n  /** @return Non-zero value data type */\n  inline caffe2::TypeMeta dtype() const { return value_.dtype(); }\n  /** @return Device of the sparse matrix */\n  inline torch::Device device() const { return value_.device(); }\n\n  /** @return COO of the sparse matrix. The COO is created if not exists. */\n  std::shared_ptr<COO> COOPtr();\n  /** @return CSR of the sparse matrix. The CSR is created if not exists. */\n  std::shared_ptr<CSR> CSRPtr();\n  /** @return CSC of the sparse matrix. The CSC is created if not exists. */\n  std::shared_ptr<CSR> CSCPtr();\n  /**\n   * @return Diagonal format of the sparse matrix. 
An error will be raised if\n   * it does not have a diagonal format.\n   */\n  std::shared_ptr<Diag> DiagPtr();\n\n  /** @brief Check whether this sparse matrix has COO format. */\n  inline bool HasCOO() const { return coo_ != nullptr; }\n  /** @brief Check whether this sparse matrix has CSR format. */\n  inline bool HasCSR() const { return csr_ != nullptr; }\n  /** @brief Check whether this sparse matrix has CSC format. */\n  inline bool HasCSC() const { return csc_ != nullptr; }\n  /** @brief Check whether this sparse matrix has Diag format. */\n  inline bool HasDiag() const { return diag_ != nullptr; }\n\n  /** @return {row, col} tensors in the COO format. */\n  std::tuple<torch::Tensor, torch::Tensor> COOTensors();\n  /** @return Stacked row and col tensors in the COO format. */\n  torch::Tensor Indices();\n  /** @return {row, col, value_indices} tensors in the CSR format. */\n  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>\n  CSRTensors();\n  /** @return {row, col, value_indices} tensors in the CSC format. */\n  std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>\n  CSCTensors();\n\n  /** @brief Return the transposition of the sparse matrix. 
It transposes the\n   * first existing sparse format by checking COO, CSR, and CSC.\n   */\n  c10::intrusive_ptr<SparseMatrix> Transpose() const;\n\n  /**\n   * @brief Return a new coalesced matrix.\n   *\n   * A coalesced sparse matrix satisfies the following properties:\n   *   - the indices of the non-zero elements are unique,\n   *   - the indices are sorted in lexicographical order.\n   *\n   * @return A coalesced sparse matrix.\n   */\n  c10::intrusive_ptr<SparseMatrix> Coalesce();\n\n  /**\n   * @brief Return true if this sparse matrix contains duplicate indices.\n   * @return A bool flag.\n   */\n  bool HasDuplicate();\n\n private:\n  /** @brief Create the COO format for the sparse matrix internally */\n  void _CreateCOO();\n  /** @brief Create the CSR format for the sparse matrix internally */\n  void _CreateCSR();\n  /** @brief Create the CSC format for the sparse matrix internally */\n  void _CreateCSC();\n\n  // COO/CSC/CSR/Diag pointers. Nullptr indicates non-existence.\n  std::shared_ptr<COO> coo_;\n  std::shared_ptr<CSR> csr_, csc_;\n  std::shared_ptr<Diag> diag_;\n  // Value of the SparseMatrix\n  torch::Tensor value_;\n  // Shape of the SparseMatrix\n  const std::vector<int64_t> shape_;\n};\n}  // namespace sparse\n}  // namespace dgl\n\n#endif  //  SPARSE_SPARSE_MATRIX_H_\n"
  },
  {
    "path": "dgl_sparse/include/sparse/spmm.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse/spmm.h\n * @brief DGL C++ SpMM operator.\n */\n#ifndef SPARSE_SPMM_H_\n#define SPARSE_SPMM_H_\n\n#include <sparse/sparse_matrix.h>\n#include <torch/script.h>\n\nnamespace dgl {\nnamespace sparse {\n\n/**\n * @brief Perform a matrix multiplication of the sparse matrix and dense\n * matrix. The SpMM can be batched, where the batch dimension is the last\n * dimension for both sparse and dense matrices.\n *\n * There are three cases for sparse, dense, and output matrix shapes:\n *   (1) (n, m), (m, k), and (n, k);\n *   (2) (n, m), (m,), and (n,);\n *   (3) (n, m, b), (m, k, b), and (n, k, b).\n *\n * This function supports autograd for both the sparse and dense matrix but does\n * not support higher order gradient.\n *\n * @param sparse_mat The sparse matrix.\n * @param dense_mat The dense matrix.\n *\n * @return Dense matrix.\n */\ntorch::Tensor SpMM(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,\n    torch::Tensor dense_mat);\n\n}  // namespace sparse\n}  // namespace dgl\n\n#endif  // SPARSE_SPMM_H_\n"
  },
  {
    "path": "dgl_sparse/include/sparse/spspmm.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse/spspmm.h\n * @brief DGL C++ SpSpMM operator.\n */\n#ifndef SPARSE_SPSPMM_H_\n#define SPARSE_SPSPMM_H_\n\n#include <sparse/sparse_matrix.h>\n#include <torch/script.h>\n\nnamespace dgl {\nnamespace sparse {\n\n/**\n * @brief Perform a sparse-sparse matrix multiplication on matrices with\n * possibly different sparsities. The two sparse matrices must have\n * 1-D values. If the first sparse matrix has shape (n, m), the second\n * sparse matrix must have shape (m, k), and the returned sparse matrix has\n * shape (n, k).\n *\n * This function supports autograd for both sparse matrices but does\n * not support higher order gradient.\n *\n * @param lhs_mat The first sparse matrix of shape (n, m).\n * @param rhs_mat The second sparse matrix of shape (m, k).\n *\n * @return Sparse matrix of shape (n, k).\n */\nc10::intrusive_ptr<SparseMatrix> SpSpMM(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat);\n\n}  // namespace sparse\n}  // namespace dgl\n\n#endif  // SPARSE_SPSPMM_H_\n"
  },
  {
    "path": "dgl_sparse/src/cpu/matrix_ops_impl.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file cpu/matrix_ops_impl.cc\n * @brief DGL C++ matrix operators.\n */\n#include \"./matrix_ops_impl.h\"\n\nnamespace dgl {\nnamespace sparse {}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/elemenwise_op.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file elementwise_op.cc\n * @brief DGL C++ sparse elementwise operator implementation.\n */\n\n#include <sparse/elementwise_op.h>\n#include <sparse/matrix_ops.h>\n#include <sparse/sparse_matrix.h>\n#include <torch/script.h>\n\n#include <memory>\n\n#include \"./utils.h\"\n\nnamespace dgl {\nnamespace sparse {\n\nusing namespace torch::autograd;\n\nc10::intrusive_ptr<SparseMatrix> SpSpAdd(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {\n  ElementwiseOpSanityCheck(lhs_mat, rhs_mat);\n  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {\n    return SparseMatrix::FromDiagPointer(\n        lhs_mat->DiagPtr(), lhs_mat->value() + rhs_mat->value(),\n        lhs_mat->shape());\n  }\n  auto torch_lhs = COOToTorchCOO(lhs_mat->COOPtr(), lhs_mat->value());\n  auto torch_rhs = COOToTorchCOO(rhs_mat->COOPtr(), rhs_mat->value());\n  auto sum = (torch_lhs + torch_rhs).coalesce();\n  return SparseMatrix::FromCOO(sum.indices(), sum.values(), lhs_mat->shape());\n}\n\nclass SpSpMulAutoGrad : public Function<SpSpMulAutoGrad> {\n public:\n  static variable_list forward(\n      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,\n      torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,\n      torch::Tensor rhs_val);\n\n  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);\n};\n\nvariable_list SpSpMulAutoGrad::forward(\n    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,\n    torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,\n    torch::Tensor rhs_val) {\n  std::shared_ptr<COO> intersection;\n  torch::Tensor lhs_indices, rhs_indices;\n  std::tie(intersection, lhs_indices, rhs_indices) =\n      COOIntersection(lhs_mat->COOPtr(), rhs_mat->COOPtr());\n  auto lhs_intersect_val = lhs_val.index_select(0, lhs_indices);\n  auto rhs_intersect_val = rhs_val.index_select(0, rhs_indices);\n  auto 
ret_val = lhs_intersect_val * rhs_intersect_val;\n  auto ret_mat =\n      SparseMatrix::FromCOOPointer(intersection, ret_val, lhs_mat->shape());\n\n  ctx->saved_data[\"lhs_require_grad\"] = lhs_val.requires_grad();\n  ctx->saved_data[\"rhs_require_grad\"] = rhs_val.requires_grad();\n  if (lhs_val.requires_grad()) {\n    ctx->saved_data[\"lhs_val_shape\"] = lhs_val.sizes().vec();\n    ctx->saved_data[\"rhs_intersect_lhs\"] =\n        SparseMatrix::ValLike(ret_mat, rhs_intersect_val);\n    ctx->saved_data[\"lhs_indices\"] = lhs_indices;\n  }\n  if (rhs_val.requires_grad()) {\n    ctx->saved_data[\"rhs_val_shape\"] = rhs_val.sizes().vec();\n    ctx->saved_data[\"lhs_intersect_rhs\"] =\n        SparseMatrix::ValLike(ret_mat, lhs_intersect_val);\n    ctx->saved_data[\"rhs_indices\"] = rhs_indices;\n  }\n  return {intersection->indices, ret_val};\n}\n\ntensor_list SpSpMulAutoGrad::backward(\n    AutogradContext* ctx, tensor_list grad_outputs) {\n  torch::Tensor lhs_val_grad, rhs_val_grad;\n  auto output_grad = grad_outputs[1];\n  if (ctx->saved_data[\"lhs_require_grad\"].toBool()) {\n    auto rhs_intersect_lhs =\n        ctx->saved_data[\"rhs_intersect_lhs\"].toCustomClass<SparseMatrix>();\n    const auto& lhs_val_shape = ctx->saved_data[\"lhs_val_shape\"].toIntVector();\n    auto lhs_indices = ctx->saved_data[\"lhs_indices\"].toTensor();\n    lhs_val_grad = torch::zeros(lhs_val_shape, output_grad.options());\n    auto intersect_grad = rhs_intersect_lhs->value() * output_grad;\n    lhs_val_grad.index_put_({lhs_indices}, intersect_grad);\n  }\n  if (ctx->saved_data[\"rhs_require_grad\"].toBool()) {\n    auto lhs_intersect_rhs =\n        ctx->saved_data[\"lhs_intersect_rhs\"].toCustomClass<SparseMatrix>();\n    const auto& rhs_val_shape = ctx->saved_data[\"rhs_val_shape\"].toIntVector();\n    auto rhs_indices = ctx->saved_data[\"rhs_indices\"].toTensor();\n    rhs_val_grad = torch::zeros(rhs_val_shape, output_grad.options());\n    auto intersect_grad = 
lhs_intersect_rhs->value() * output_grad;\n    rhs_val_grad.index_put_({rhs_indices}, intersect_grad);\n  }\n  return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad};\n}\n\nc10::intrusive_ptr<SparseMatrix> SpSpMul(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {\n  ElementwiseOpSanityCheck(lhs_mat, rhs_mat);\n  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {\n    return SparseMatrix::FromDiagPointer(\n        lhs_mat->DiagPtr(), lhs_mat->value() * rhs_mat->value(),\n        lhs_mat->shape());\n  }\n  TORCH_CHECK(\n      !lhs_mat->HasDuplicate() && !rhs_mat->HasDuplicate(),\n      \"Only support SpSpMul on sparse matrices without duplicate values\")\n  auto results = SpSpMulAutoGrad::apply(\n      lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());\n  const auto& indices = results[0];\n  const auto& val = results[1];\n  return SparseMatrix::FromCOO(indices, val, lhs_mat->shape());\n}\n\nc10::intrusive_ptr<SparseMatrix> SpSpDiv(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {\n  ElementwiseOpSanityCheck(lhs_mat, rhs_mat);\n  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {\n    return SparseMatrix::FromDiagPointer(\n        lhs_mat->DiagPtr(), lhs_mat->value() / rhs_mat->value(),\n        lhs_mat->shape());\n  }\n  std::shared_ptr<COO> sorted_lhs, sorted_rhs;\n  torch::Tensor lhs_sorted_perm, rhs_sorted_perm;\n  std::tie(sorted_lhs, lhs_sorted_perm) = COOSort(lhs_mat->COOPtr());\n  std::tie(sorted_rhs, rhs_sorted_perm) = COOSort(rhs_mat->COOPtr());\n  TORCH_CHECK(\n      !lhs_mat->HasDuplicate() && !rhs_mat->HasDuplicate(),\n      \"Only support SpSpDiv on sparse matrices without duplicate values\")\n  TORCH_CHECK(\n      torch::equal(sorted_lhs->indices, sorted_rhs->indices),\n      \"Cannot divide two COO matrices with different sparsities.\");\n  // This is to make sure the return matrix is in the same order as the lhs_mat\n  
auto lhs_sorted_rperm = lhs_sorted_perm.argsort();\n  auto rhs_perm_on_lhs = rhs_sorted_perm.index_select(0, lhs_sorted_rperm);\n  auto lhs_value = lhs_mat->value();\n  auto rhs_value = rhs_mat->value().index_select(0, rhs_perm_on_lhs);\n  auto ret_val = lhs_value / rhs_value;\n  return SparseMatrix::FromCOOPointer(\n      lhs_mat->COOPtr(), ret_val, lhs_mat->shape());\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/matmul.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file matmul.cc\n * @brief DGL sparse matrix multiplication functions.\n */\n#include \"./matmul.h\"\n\n// clang-format off\n#include <sparse/dgl_headers.h>\n// clang-format on\n\n#include <sparse/sparse_matrix.h>\n#include <torch/script.h>\n\n#include \"./utils.h\"\n\nnamespace dgl {\nnamespace sparse {\n\ntorch::Tensor SpMMNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,\n    torch::Tensor sparse_val, torch::Tensor dense_mat, bool transpose_sparse) {\n  const std::string op = \"mul\";\n  const std::string reduce = \"sum\";\n  const int64_t out_row =\n      transpose_sparse ? sparse_mat->shape()[1] : sparse_mat->shape()[0];\n  std::vector<int64_t> shape = {out_row, dense_mat.size(1)};\n  // Batched SpMM\n  if (sparse_val.dim() >= 2) {\n    shape = {out_row, dense_mat.size(1), sparse_val.size(1)};\n  }\n\n  auto ret = torch::zeros(shape, dense_mat.options());\n  auto dgl_sparse_val = TorchTensorToDGLArray(sparse_val);\n  auto dgl_dense_mat = TorchTensorToDGLArray(dense_mat);\n  auto dgl_ret = TorchTensorToDGLArray(ret);\n  if (!transpose_sparse) {\n    // The format for calculation will be chosen in the following order: CSR,\n    // COO. 
CSR is created if the sparse matrix only has CSC format.\n    if (sparse_mat->HasCSR() || !sparse_mat->HasCOO()) {\n      // sparse_mat->CSRPtr() will implicitly convert CSC to CSR format if CSR\n      // does not exist.\n      auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());\n      aten::CSRSpMM(\n          op.c_str(), reduce.c_str(), csr, dgl_dense_mat, dgl_sparse_val,\n          dgl_ret, {});\n    } else {  // COO\n      // Use the reverse order of aten::COOSpMM because it calculates A^T @ X.\n      auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());\n      coo = aten::COOTranspose(coo);\n      aten::COOSpMM(\n          op.c_str(), reduce.c_str(), coo, dgl_dense_mat, dgl_sparse_val,\n          dgl_ret, {});\n    }\n  } else {  // transpose_sparse\n    // The format for calculation will be chosen in the following order: CSC,\n    // COO. CSC is created if the sparse matrix only has CSR format.\n    if (sparse_mat->HasCSC() || !sparse_mat->HasCOO()) {\n      // sparse_mat->CSCPtr() will implicitly convert CSR to CSC format if CSC\n      // does not exist.\n      // Using CSC in DGL's CSRSpMM is equivalent to computing A^T @ X.\n      auto csc = CSRToOldDGLCSR(sparse_mat->CSCPtr());\n      aten::CSRSpMM(\n          op.c_str(), reduce.c_str(), csc, dgl_dense_mat, dgl_sparse_val,\n          dgl_ret, {});\n    } else {  // COO\n      // aten::COOSpMM calculates A^T @ X, so no transposition is needed.\n      auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());\n      aten::COOSpMM(\n          op.c_str(), reduce.c_str(), coo, dgl_dense_mat, dgl_sparse_val,\n          dgl_ret, {});\n    }\n  }\n  return ret;\n}\n\ntorch::Tensor SDDMMNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,\n    torch::Tensor mat2_tr) {\n  const int64_t out_row = sparse_mat->nnz();\n  std::vector<int64_t> shape({out_row});\n  // Batched SDDMM\n  if (mat1.dim() >= 3) {\n    shape.push_back(mat1.size(2));\n    // (N, K, B) -> (N, B, K)\n    mat1 = 
mat1.transpose(1, 2);\n    // (M, K, B) -> (M, B, K)\n    mat2_tr = mat2_tr.transpose(1, 2);\n  }\n  auto ret = torch::zeros(shape, mat1.options());\n  const std::string op = \"dot\";\n  auto dgl_mat1 = TorchTensorToDGLArray(mat1);\n  auto dgl_mat2_tr = TorchTensorToDGLArray(mat2_tr);\n  auto dgl_ret = TorchTensorToDGLArray(ret);\n  // The format for calculation will be chosen in the following order: CSR,\n  // COO. CSR is created if the sparse matrix only has CSC format.\n  if (sparse_mat->HasCSR() || !sparse_mat->HasCOO()) {\n    // sparse_mat->CSRPtr() will implicitly convert CSC to CSR format if CSR\n    // does not exist.\n    auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());\n    aten::CSRSDDMM(\n        op.c_str(), csr, dgl_mat1, dgl_mat2_tr, dgl_ret, 0 /* Lhs target: u */,\n        2 /* rhs target: v */);\n  } else {  // COO\n    auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());\n    aten::COOSDDMM(\n        op.c_str(), coo, dgl_mat1, dgl_mat2_tr, dgl_ret, 0 /* Lhs target: u */,\n        2 /* rhs target: v */);\n  }\n  return ret;\n}\n\ntorch::Tensor BroadcastOpNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,\n    const std::string& op, int64_t dim) {\n  auto sparse_val = sparse_mat->value();\n  const int64_t out_row = sparse_mat->nnz();\n  const std::vector<int64_t> shape({out_row, sparse_val.size(1)});\n  auto ret = torch::zeros(shape, sparse_val.options());\n\n  auto dgl_sparse_val = TorchTensorToDGLArray(sparse_val);\n  auto dgl_dense_mat = TorchTensorToDGLArray(dense_mat);\n  auto dgl_ret = TorchTensorToDGLArray(ret);\n  // Setting dgl_rhs_target to 0 or 2 means using row or column coordinates\n  // to access dgl_dense_mat for each edge, respectively.\n  auto dgl_rhs_target = dim == 0 ? 2 : 0;\n\n  // The format for calculation will be chosen in the following order: COO, CSR\n  // . 
COO is created if the sparse matrix only has CSC format.\n  if (sparse_mat->HasCOO() || !sparse_mat->HasCSR()) {\n    // sparse_mat->COOPtr() will implicitly convert CSC to COO format if COO\n    // does not exist.\n    auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());\n    aten::COOSDDMM(\n        op.c_str(), coo, dgl_sparse_val, dgl_dense_mat, dgl_ret,\n        1 /* Lhs target: e */, dgl_rhs_target);\n  } else {\n    auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());\n    aten::CSRSDDMM(\n        op.c_str(), csr, dgl_sparse_val, dgl_dense_mat, dgl_ret,\n        1 /* Lhs target: e */, dgl_rhs_target);\n  }\n  return ret;\n}\n\ntorch::Tensor BroadcastSubNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,\n    int64_t dim) {\n  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, \"sub\", dim);\n}\n\ntorch::Tensor BroadcastDivNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,\n    int64_t dim) {\n  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, \"div\", dim);\n}\n\ntorch::Tensor BroadcastMulNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,\n    int64_t dim) {\n  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, \"mul\", dim);\n}\n\nc10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat, torch::Tensor lhs_val,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat, torch::Tensor rhs_val,\n    bool lhs_transpose, bool rhs_transpose) {\n  aten::CSRMatrix lhs_dgl_csr, rhs_dgl_csr;\n  if (!lhs_transpose) {\n    lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSRPtr());\n  } else {\n    lhs_dgl_csr = CSRToOldDGLCSR(lhs_mat->CSCPtr());\n  }\n  if (!rhs_transpose) {\n    rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSRPtr());\n  } else {\n    rhs_dgl_csr = CSRToOldDGLCSR(rhs_mat->CSCPtr());\n  }\n  auto lhs_dgl_val = TorchTensorToDGLArray(lhs_val);\n  auto rhs_dgl_val = TorchTensorToDGLArray(rhs_val);\n  const 
int64_t ret_row =\n      lhs_transpose ? lhs_mat->shape()[1] : lhs_mat->shape()[0];\n  const int64_t ret_col =\n      rhs_transpose ? rhs_mat->shape()[0] : rhs_mat->shape()[1];\n  std::vector<int64_t> ret_shape({ret_row, ret_col});\n  aten::CSRMatrix ret_dgl_csr;\n  runtime::NDArray ret_val;\n  std::tie(ret_dgl_csr, ret_val) =\n      aten::CSRMM(lhs_dgl_csr, lhs_dgl_val, rhs_dgl_csr, rhs_dgl_val);\n  return SparseMatrix::FromCSRPointer(\n      CSRFromOldDGLCSR(ret_dgl_csr), DGLArrayToTorchTensor(ret_val), ret_shape);\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/matmul.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file matmul.h\n * @brief DGL sparse matrix multiplication functions.\n */\n#ifndef DGL_SPARSE_MATMUL_H_\n#define DGL_SPARSE_MATMUL_H_\n\n#include <sparse/sparse_matrix.h>\n#include <torch/script.h>\n\n#include <string>\n\nnamespace dgl {\nnamespace sparse {\n\n/**\n * @brief Perform a matrix multiplication of the sparse matrix and dense\n * matrix. It uses the sparse formats of `sparse_mat` and non-zero values of\n * `sparse_val` for SpMM. The `sparse_val` must be 1-dimensional. If the sparse\n * matrix has shape (n, m), the dense matrix must have shape (m, k). And\n * the returned dense matrix has shape (n, k).\n *\n * This function does not take care of autograd.\n *\n * @param sparse_mat The sparse matrix.\n * @param sparse_val Non-zero values of the sparse matrix.\n * @param dense_mat The dense matrix.\n * @param transpose_sparse Whether the sparse_mat is transposed.\n *\n * @return Dense tensor.\n */\ntorch::Tensor SpMMNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,\n    torch::Tensor sparse_val, torch::Tensor dense_mat, bool transpose_sparse);\n\n/**\n * @brief Perform a sampled matrix multiplication of a sparse matrix and two\n * dense matrices. It calculates `(mat1 @ mat2_tr^T) * spy(A)` and does not\n * consider the values of the sparse matrix. For efficiency, `mat2_tr` is the\n * transposition of the matrix to be multiplied. If the sparse matrix has shape\n * (n, m), `mat1` and `mat2_tr` must have shapes of `(n, k)` and `(m,\n * k)` respectively. 
And the returned tensor has shape\n * `(sparse_matrix->nnz(),)`.\n *\n * This function does not take care of autograd.\n *\n * @param sparse_mat The sparse matrix.\n * @param mat1 The first dense matrix.\n * @param mat2_tr Transposition of the second matrix.\n *\n * @return Dense tensor.\n */\ntorch::Tensor SDDMMNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,\n    torch::Tensor mat2_tr);\n\n/**\n * @brief Broadcast the dense feature to the nonzero entries and then compute\n * x_e = \\phi(x_e, x_v) on the dimension dim, where x_e is the nonzero value,\n * x_v is the dense feature, and \\phi is add, sub, mul, or div. dim = 0 or 1\n * means column-wise or row-wise broadcast respectively.\n *\n * This function does not take care of autograd.\n *\n * @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values\n * @param dense_mat Dense feature of shape (N, D)\n * @param op Operator, can be add, sub, mul, or div\n * @param dim The dimension to broadcast.\n *\n * @return Dense tensor of shape (nnz, D)\n */\ntorch::Tensor BroadcastOpNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,\n    const std::string& op, int64_t dim);\n\n/**\n * @brief Broadcast the dense feature to the nonzero entries and then compute\n * x_e = x_e - x_v on the dimension dim, where x_e is the nonzero value, x_v is\n * the dense feature. 
dim = 0 or 1 means column-wise or row-wise broadcast\n * respectively.\n *\n * This function does not take care of autograd.\n *\n * @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values\n * @param dense_mat Dense feature of shape (N, D)\n * @param dim The dimension to broadcast.\n *\n * @return Dense tensor of shape (nnz, D)\n */\ntorch::Tensor BroadcastSubNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,\n    int64_t dim);\n\n/**\n * @brief Broadcast the dense feature to the nonzero entries and then compute\n * x_e = x_e / x_v on the dimension dim, where x_e is the nonzero value, x_v is\n * the dense feature. dim = 0 or 1 means column-wise or row-wise broadcast\n * respectively.\n *\n * This function does not take care of autograd.\n *\n * @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values\n * @param dense_mat Dense feature of shape (N, D)\n * @param dim The dimension to broadcast.\n *\n * @return Dense tensor of shape (nnz, D)\n */\ntorch::Tensor BroadcastDivNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,\n    int64_t dim);\n\n/**\n * @brief Broadcast the dense feature to the nonzero entries and then compute\n * x_e = x_e * x_v on the dimension dim, where x_e is the nonzero value, x_v is\n * the dense feature. dim = 0 or 1 means column-wise or row-wise broadcast\n * respectively.\n *\n * This function does not take care of autograd.\n *\n * @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values\n * @param dense_mat Dense feature of shape (N, D)\n * @param dim The dimension to broadcast.\n *\n * @return Dense tensor of shape (nnz, D)\n */\ntorch::Tensor BroadcastMulNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,\n    int64_t dim);\n\n/**\n * @brief Perform a sparse-sparse matrix multiplication with possibly different\n * sparsities. 
The two sparse matrices must have 1-dimensional values. If the\n * first sparse matrix has shape (n, m), the second sparse matrix must have\n * shape (m, k), and the returned sparse matrix has shape (n, k).\n *\n * This function does not take care of autograd.\n *\n * @param lhs_mat The first sparse matrix of shape (n, m).\n * @param lhs_val Sparse value for the first sparse matrix.\n * @param rhs_mat The second sparse matrix of shape (m, k).\n * @param rhs_val Sparse value for the second sparse matrix.\n * @param lhs_transpose Whether the first matrix is transposed.\n * @param rhs_transpose Whether the second matrix is transposed.\n *\n * @return Sparse matrix of shape (n, k).\n */\nc10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat, torch::Tensor lhs_val,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat, torch::Tensor rhs_val,\n    bool lhs_transpose, bool rhs_transpose);\n\n}  // namespace sparse\n}  // namespace dgl\n\n#endif  // DGL_SPARSE_MATMUL_H_\n"
  },
  {
    "path": "dgl_sparse/src/matrix_ops.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file matrix_ops.cc\n * @brief DGL C++ matrix operators.\n */\n#include <sparse/matrix_ops.h>\n#include <torch/script.h>\n\nnamespace dgl {\nnamespace sparse {\n\n/**\n * @brief Compute the intersection of two COO matrices. Return the intersection\n * COO matrix, and the indices of the intersection in the left-hand-side and\n * right-hand-side COO matrices.\n *\n * @param lhs The left-hand-side COO matrix.\n * @param rhs The right-hand-side COO matrix.\n *\n * @return A tuple of COO matrix, lhs indices, and rhs indices.\n */\nstd::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(\n    const std::shared_ptr<COO>& lhs, const std::shared_ptr<COO>& rhs) {\n  // 1. Encode the two COO matrices into arrays of integers.\n  auto lhs_arr =\n      lhs->indices.index({0}) * lhs->num_cols + lhs->indices.index({1});\n  auto rhs_arr =\n      rhs->indices.index({0}) * rhs->num_cols + rhs->indices.index({1});\n  // 2. Concatenate the two arrays.\n  auto arr = torch::cat({lhs_arr, rhs_arr});\n  // 3. Unique the concatenated array.\n  torch::Tensor unique, inverse, counts;\n  std::tie(unique, inverse, counts) =\n      torch::unique_dim(arr, 0, false, true, true);\n  // 4. Find the indices of the counts greater than 1 in the unique array.\n  auto mask = counts > 1;\n  // 5. Map the inverse array to the original array to generate indices.\n  auto lhs_inverse = inverse.slice(0, 0, lhs_arr.numel());\n  auto rhs_inverse = inverse.slice(0, lhs_arr.numel(), arr.numel());\n  auto map_to_original = torch::empty_like(unique);\n  map_to_original.index_put_(\n      {lhs_inverse},\n      torch::arange(lhs_inverse.numel(), map_to_original.options()));\n  auto lhs_indices = map_to_original.index({mask});\n  map_to_original.index_put_(\n      {rhs_inverse},\n      torch::arange(rhs_inverse.numel(), map_to_original.options()));\n  auto rhs_indices = map_to_original.index({mask});\n  // 6. 
Decode the indices to get the intersection COO matrix.\n  auto ret_arr = unique.index({mask});\n  auto ret_indices = torch::stack(\n      {ret_arr.floor_divide(lhs->num_cols), ret_arr % lhs->num_cols}, 0);\n  auto ret_coo = std::make_shared<COO>(\n      COO{lhs->num_rows, lhs->num_cols, ret_indices, false, false});\n  return {ret_coo, lhs_indices, rhs_indices};\n}\n\n/** @brief Return the reverted mapping of a permutation. */\nstatic torch::Tensor RevertPermutation(const torch::Tensor& perm) {\n  auto rev_tensor = torch::empty_like(perm);\n  rev_tensor.index_put_(\n      {perm}, torch::arange(0, perm.numel(), rev_tensor.options()));\n  return rev_tensor;\n}\n\n/**\n * @brief Compute the compact indices of row indices and leading indices. Return\n * the compacted indices and the original row indices of compacted indices.\n *\n * @param row The row indices.\n * @param leading_indices The leading indices.\n *\n * @return A tuple of compact indices, original indices.\n */\nstatic std::tuple<torch::Tensor, torch::Tensor> CompactIndices(\n    const torch::Tensor& row,\n    const torch::optional<torch::Tensor>& leading_indices) {\n  torch::Tensor sorted, sort_indices, uniqued, unique_reverse_indices, counts;\n  // 1. Sort leading indices and row indices in ascending order.\n  int64_t n_leading_indices = 0;\n  if (leading_indices.has_value()) {\n    n_leading_indices = leading_indices.value().numel();\n    std::tie(sorted, sort_indices) =\n        torch::cat({leading_indices.value(), row}).sort();\n  } else {\n    std::tie(sorted, sort_indices) = row.sort();\n  }\n  // 2. Reverse sort indices.\n  auto sort_rev_indices = RevertPermutation(sort_indices);\n  // 3. Unique the sorted array.\n  std::tie(uniqued, unique_reverse_indices, counts) =\n      torch::unique_consecutive(sorted, true);\n  auto reverse_indices = unique_reverse_indices.index({sort_rev_indices});\n  auto n_uniqued = uniqued.numel();\n\n  // 4. 
Relabel the indices and map the inverse array to the original array.\n  auto split_indices = torch::full({n_uniqued}, -1, reverse_indices.options());\n\n  split_indices.index_put_(\n      {reverse_indices.slice(0, 0, n_leading_indices)},\n      torch::arange(0, n_leading_indices, split_indices.options()));\n\n  split_indices.index_put_(\n      {(split_indices == -1).nonzero().view(-1)},\n      torch::arange(n_leading_indices, n_uniqued, split_indices.options()));\n  // 5. Decode the indices to get the compact indices.\n  auto new_row = split_indices.index({reverse_indices.slice(\n      0, n_leading_indices, n_leading_indices + row.numel())});\n  return {new_row, uniqued.index({RevertPermutation(split_indices)})};\n}\n\nstatic std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> CompactCOO(\n    const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,\n    const torch::optional<torch::Tensor>& leading_indices) {\n  torch::Tensor row, col;\n  auto coo = mat->COOTensors();\n  if (dim == 0)\n    std::tie(row, col) = coo;\n  else\n    std::tie(col, row) = coo;\n\n  torch::Tensor new_row, uniqued;\n  std::tie(new_row, uniqued) = CompactIndices(row, leading_indices);\n\n  if (dim == 0) {\n    auto ret = SparseMatrix::FromCOO(\n        torch::stack({new_row, col}, 0), mat->value(),\n        std::vector<int64_t>{uniqued.numel(), mat->shape()[1]});\n    return {ret, uniqued};\n  } else {\n    auto ret = SparseMatrix::FromCOO(\n        torch::stack({col, new_row}, 0), mat->value(),\n        std::vector<int64_t>{mat->shape()[0], uniqued.numel()});\n    return {ret, uniqued};\n  }\n}\n\nstatic std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> CompactCSR(\n    const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,\n    const torch::optional<torch::Tensor>& leading_indices) {\n  std::shared_ptr<CSR> csr;\n  if (dim == 0)\n    csr = mat->CSCPtr();\n  else\n    csr = mat->CSRPtr();\n\n  torch::Tensor new_indices, uniqued;\n  std::tie(new_indices, uniqued) =\n   
   CompactIndices(csr->indices, leading_indices);\n\n  auto ret_value = mat->value();\n  if (csr->value_indices.has_value())\n    ret_value = mat->value().index_select(0, csr->value_indices.value());\n  if (dim == 0) {\n    auto ret = SparseMatrix::FromCSC(\n        csr->indptr, new_indices, ret_value,\n        std::vector<int64_t>{uniqued.numel(), mat->shape()[1]});\n    return {ret, uniqued};\n  } else {\n    auto ret = SparseMatrix::FromCSR(\n        csr->indptr, new_indices, ret_value,\n        std::vector<int64_t>{mat->shape()[0], uniqued.numel()});\n    return {ret, uniqued};\n  }\n}\n\nstd::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> Compact(\n    const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,\n    const torch::optional<torch::Tensor>& leading_indices) {\n  if (mat->HasCOO()) {\n    return CompactCOO(mat, dim, leading_indices);\n  }\n  return CompactCSR(mat, dim, leading_indices);\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/matrix_ops_impl.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file matrix_ops_impl.h\n * @brief DGL C++ sparse matrix operator implementations.\n */\n#ifndef DGL_SPARSE_MATRIX_OPS_IMPL_H_\n#define DGL_SPARSE_MATRIX_OPS_IMPL_H_\n\n#include <sparse/sparse_format.h>\n#include <sparse/sparse_matrix.h>\n\n#include <tuple>\n#include <vector>\n\n#include \"./utils.h\"\n\nnamespace dgl {\nnamespace sparse {}  // namespace sparse\n}  // namespace dgl\n\n#endif  // DGL_SPARSE_MATRIX_OPS_IMPL_H_\n"
  },
  {
    "path": "dgl_sparse/src/python_binding.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file python_binding.cc\n * @brief DGL sparse library Python binding.\n */\n// clang-format off\n#include <sparse/dgl_headers.h>\n// clang-format on\n\n#include <sparse/elementwise_op.h>\n#include <sparse/matrix_ops.h>\n#include <sparse/reduction.h>\n#include <sparse/sddmm.h>\n#include <sparse/softmax.h>\n#include <sparse/sparse_matrix.h>\n#include <sparse/spmm.h>\n#include <sparse/spspmm.h>\n#include <torch/custom_class.h>\n#include <torch/script.h>\n\nnamespace dgl {\nnamespace sparse {\n\nTORCH_LIBRARY(dgl_sparse, m) {\n  m.class_<SparseMatrix>(\"SparseMatrix\")\n      .def(\"val\", &SparseMatrix::value)\n      .def(\"nnz\", &SparseMatrix::nnz)\n      .def(\"device\", &SparseMatrix::device)\n      .def(\"shape\", &SparseMatrix::shape)\n      .def(\"coo\", &SparseMatrix::COOTensors)\n      .def(\"indices\", &SparseMatrix::Indices)\n      .def(\"csr\", &SparseMatrix::CSRTensors)\n      .def(\"csc\", &SparseMatrix::CSCTensors)\n      .def(\"transpose\", &SparseMatrix::Transpose)\n      .def(\"coalesce\", &SparseMatrix::Coalesce)\n      .def(\"has_duplicate\", &SparseMatrix::HasDuplicate)\n      .def(\"is_diag\", &SparseMatrix::HasDiag)\n      .def(\"index_select\", &SparseMatrix::IndexSelect)\n      .def(\"range_select\", &SparseMatrix::RangeSelect)\n      .def(\"sample\", &SparseMatrix::Sample);\n  m.def(\"from_coo\", &SparseMatrix::FromCOO)\n      .def(\"from_csr\", &SparseMatrix::FromCSR)\n      .def(\"from_csc\", &SparseMatrix::FromCSC)\n      .def(\"from_diag\", &SparseMatrix::FromDiag)\n      .def(\"spsp_add\", &SpSpAdd)\n      .def(\"spsp_mul\", &SpSpMul)\n      .def(\"spsp_div\", &SpSpDiv)\n      .def(\"reduce\", &Reduce)\n      .def(\"sum\", &ReduceSum)\n      .def(\"smean\", &ReduceMean)\n      .def(\"smin\", &ReduceMin)\n      .def(\"smax\", &ReduceMax)\n      .def(\"sprod\", &ReduceProd)\n      .def(\"val_like\", &SparseMatrix::ValLike)\n      .def(\"spmm\", &SpMM)\n      .def(\"sddmm\", 
&SDDMM)\n      .def(\"softmax\", &Softmax)\n      .def(\"spspmm\", &SpSpMM)\n      .def(\"compact\", &Compact);\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/reduction.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file reduction.cc\n * @brief DGL C++ sparse matrix reduction operator implementation.\n */\n// clang-format off\n#include <sparse/dgl_headers.h>\n// clang-format on\n\n#include <sparse/elementwise_op.h>\n#include <sparse/reduction.h>\n#include <sparse/sparse_matrix.h>\n#include <torch/script.h>\n\n#include <string>\n#include <vector>\n\nnamespace dgl {\nnamespace sparse {\n\nnamespace {\n\ntorch::Tensor ReduceAlong(\n    const c10::intrusive_ptr<SparseMatrix>& A, const std::string& reduce,\n    int64_t dim) {\n  auto value = A->value();\n  auto coo = A->COOPtr();\n\n  std::string reduce_op;\n  if (reduce == \"sum\") {\n    reduce_op = \"sum\";\n  } else if (reduce == \"smin\") {\n    reduce_op = \"amin\";\n  } else if (reduce == \"smax\") {\n    reduce_op = \"amax\";\n  } else if (reduce == \"smean\") {\n    reduce_op = \"mean\";\n  } else if (reduce == \"sprod\") {\n    reduce_op = \"prod\";\n  } else {\n    TORCH_CHECK(false, \"unknown reduce function \", reduce);\n    return torch::Tensor();\n  }\n\n  // Create the output tensor with shape\n  //\n  //   [A.num_rows if dim == 1 else A.num_cols] + A.val.shape[1:]\n  std::vector<int64_t> output_shape = value.sizes().vec();\n  std::vector<int64_t> view_dims(output_shape.size(), 1);\n  view_dims[0] = -1;\n  torch::Tensor idx;\n  if (dim == 0) {\n    output_shape[0] = coo->num_cols;\n    idx = coo->indices.index({1}).view(view_dims).expand_as(value);\n  } else if (dim == 1) {\n    output_shape[0] = coo->num_rows;\n    idx = coo->indices.index({0}).view(view_dims).expand_as(value);\n  }\n  torch::Tensor out = torch::zeros(output_shape, value.options());\n\n  if (dim == 0) {\n    out.scatter_reduce_(0, idx, value, reduce_op, false);\n  } else if (dim == 1) {\n    out.scatter_reduce_(0, idx, value, reduce_op, false);\n  }\n\n  return out;\n}\n\ntorch::Tensor ReduceAll(\n    const c10::intrusive_ptr<SparseMatrix>& A, const std::string& reduce) {\n  if 
(reduce == \"sum\") {\n    return A->value().sum(0);\n  } else if (reduce == \"smin\") {\n    return A->value().amin(0);\n  } else if (reduce == \"smax\") {\n    return A->value().amax(0);\n  } else if (reduce == \"smean\") {\n    return A->value().mean(0);\n  } else if (reduce == \"sprod\") {\n    return A->value().prod(0);\n  }\n\n  TORCH_CHECK(false, \"unknown reduce function \", reduce);\n  return torch::Tensor();\n}\n\n}  // namespace\n\ntorch::Tensor Reduce(\n    const c10::intrusive_ptr<SparseMatrix>& A, const std::string& reduce,\n    const torch::optional<int64_t>& dim) {\n  return dim.has_value() ? ReduceAlong(A, reduce, dim.value())\n                         : ReduceAll(A, reduce);\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/sddmm.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sddmm.cc\n * @brief DGL C++ sparse SDDMM operator implementation.\n */\n#include <sparse/sparse_matrix.h>\n#include <sparse/spmm.h>\n#include <torch/script.h>\n\n#include <sstream>\n\n#include \"./matmul.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\nnamespace sparse {\n\nusing namespace torch::autograd;\n\nclass SDDMMAutoGrad : public Function<SDDMMAutoGrad> {\n public:\n  static torch::Tensor forward(\n      AutogradContext* ctx, const c10::intrusive_ptr<SparseMatrix>& sparse_mat,\n      torch::Tensor mat1, torch::Tensor mat2_tr);\n\n  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);\n};\n\nvoid _SDDMMSanityCheck(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,\n    torch::Tensor mat2) {\n  bool shape_check = true;\n  shape_check &= mat1.dim() == mat2.dim();\n  shape_check &= mat1.dim() <= 3;\n  shape_check &= sparse_mat->shape()[0] == mat1.size(0);\n  if (mat1.dim() == 3) {\n    shape_check &= sparse_mat->shape()[1] == mat2.size(1);\n    shape_check &= mat1.size(2) == mat2.size(2);\n    if (sparse_mat->value().dim() > 1) {\n      shape_check &= sparse_mat->value().size(1) == mat1.size(2);\n    }\n  } else {\n    shape_check &= sparse_mat->shape()[1] == mat2.size(mat2.dim() - 1);\n  }\n  if (mat1.dim() >= 2) {\n    shape_check &= mat1.size(1) == mat2.size(0);\n  }\n  if (!shape_check) {\n    std::stringstream error;\n    error << \"SDDMM: Invalid input shapes. sparse_mat: \"\n          << c10::IntArrayRef(sparse_mat->shape())\n          << \", sparse_val: \" << sparse_mat->value().sizes()\n          << \", mat1: \" << mat1.sizes() << \", mat2: \" << mat2.sizes()\n          << \". 
Valid input shapes (sparse_mat, mat1, mat2) are: (1) (n, m), \"\n             \"(n, k), and (k, m); (2) (n, m), (n,), and (m,); (3) (n, m, b), \"\n             \"(n, k, b) and (k, m, b); (4) \"\n             \"(n, m), (n, k, b), and (k, m, b).\";\n    TORCH_CHECK(false, error.str());\n  }\n  TORCH_CHECK(\n      mat1.dtype() == mat2.dtype(),\n      \"SDDMM: the two dense matrices should have the same dtype.\");\n  TORCH_CHECK(\n      mat1.device() == mat2.device() && sparse_mat->device() == mat2.device(),\n      \"SDDMM: the two dense matrices and sparse matrix should on the same \"\n      \"device.\");\n}\n\ntorch::Tensor SDDMMAutoGrad::forward(\n    AutogradContext* ctx, const c10::intrusive_ptr<SparseMatrix>& sparse_mat,\n    torch::Tensor mat1, torch::Tensor mat2) {\n  auto mat2_tr = mat2.transpose(0, 1);\n  auto ret = SDDMMNoAutoGrad(sparse_mat, mat1, mat2_tr);\n  torch::Tensor cache_mat1, cache_mat2;\n  if (mat1.requires_grad()) {\n    cache_mat2 = mat2;\n  }\n  if (mat2.requires_grad()) {\n    cache_mat1 = mat1;\n  }\n  ctx->save_for_backward({cache_mat1, cache_mat2});\n  ctx->saved_data[\"mat1_requires_grad\"] = mat1.requires_grad();\n  ctx->saved_data[\"mat2_requires_grad\"] = mat2.requires_grad();\n  ctx->saved_data[\"sparse_mat\"] = sparse_mat;\n  return ret;\n}\n\ntensor_list SDDMMAutoGrad::backward(\n    AutogradContext* ctx, tensor_list grad_outputs) {\n  auto saved = ctx->get_saved_variables();\n  auto mat1 = saved[0];\n  auto mat2 = saved[1];\n  auto sparse_mat = ctx->saved_data[\"sparse_mat\"].toCustomClass<SparseMatrix>();\n  auto grad = grad_outputs[0];\n  torch::Tensor mat1_grad, mat2_grad;\n  if (ctx->saved_data[\"mat1_requires_grad\"].toBool()) {\n    // SDDMM(M, A, B) = C. dA = SpMM(dC, B^T)\n    mat1_grad = SpMMNoAutoGrad(sparse_mat, grad, mat2.transpose(0, 1), false);\n  }\n  if (ctx->saved_data[\"mat2_requires_grad\"].toBool()) {\n    // SDDMM(M, A, B) = C. 
dB = SpMM(dC^T, A)^T\n    auto mat2_tr_grad = SpMMNoAutoGrad(sparse_mat, grad, mat1, true);\n    mat2_grad = mat2_tr_grad.transpose(0, 1);\n  }\n  return {torch::Tensor(), mat1_grad, mat2_grad};\n}\n\nc10::intrusive_ptr<SparseMatrix> SDDMM(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,\n    torch::Tensor mat2) {\n  if (mat1.dim() == 1) {\n    mat1 = mat1.view({mat1.size(0), 1});\n  }\n  if (mat2.dim() == 1) {\n    mat2 = mat2.view({1, mat2.size(0)});\n  }\n  _SDDMMSanityCheck(sparse_mat, mat1, mat2);\n  auto val = SDDMMAutoGrad::apply(sparse_mat, mat1, mat2);\n  auto sparse_val = sparse_mat->value();\n  // Broadcast the sparse value in batched SDDMM.\n  if (sparse_val.dim() < val.dim()) {\n    sparse_val = sparse_val.unsqueeze(-1);\n  }\n  val = val * sparse_val;\n  return SparseMatrix::ValLike(sparse_mat, val);\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/softmax.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file softmax.cc\n * @brief DGL C++ Softmax operator implementation\n */\n\n#include <sparse/reduction.h>\n#include <sparse/sparse_matrix.h>\n#include <torch/script.h>\n\n#include \"./matmul.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\nnamespace sparse {\n\nusing namespace torch::autograd;\n\nclass SoftmaxAutoGrad : public Function<SoftmaxAutoGrad> {\n public:\n  static torch::Tensor forward(\n      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,\n      torch::Tensor sparse_val, int64_t dim);\n\n  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);\n};\n\ntorch::Tensor SoftmaxAutoGrad::forward(\n    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,\n    torch::Tensor sparse_val, int64_t dim) {\n  // Reduce by columns with dim 1.\n  auto sparse_val_max = ReduceMax(sparse_mat, dim);\n  auto sparse_val_exp =\n      BroadcastSubNoAutoGrad(sparse_mat, sparse_val_max, dim).exp();\n  auto sparse_val_sum =\n      ReduceSum(SparseMatrix::ValLike(sparse_mat, sparse_val_exp), dim);\n  auto sparse_score = BroadcastDivNoAutoGrad(\n      SparseMatrix::ValLike(sparse_mat, sparse_val_exp), sparse_val_sum, dim);\n\n  const bool sparse_requires_grad = sparse_val.requires_grad();\n  torch::Tensor cache_sparse_score;\n  if (sparse_requires_grad) {\n    cache_sparse_score = sparse_score;\n  }\n  ctx->saved_data[\"sparse_matrix\"] = sparse_mat;\n  ctx->saved_data[\"sparse_requires_grad\"] = sparse_requires_grad;\n  ctx->saved_data[\"dim\"] = dim;\n  ctx->save_for_backward({cache_sparse_score});\n  return sparse_score;\n}\n\ntensor_list SoftmaxAutoGrad::backward(\n    AutogradContext* ctx, tensor_list grad_outputs) {\n  auto saved = ctx->get_saved_variables();\n  auto sparse_score = saved[0];\n  auto output_grad = grad_outputs[0];\n\n  auto sparse_mat =\n      ctx->saved_data[\"sparse_matrix\"].toCustomClass<SparseMatrix>();\n  const bool sparse_requires_grad 
=\n      ctx->saved_data[\"sparse_requires_grad\"].toBool();\n  const int64_t dim = ctx->saved_data[\"dim\"].toInt();\n\n  torch::Tensor sparse_val_grad;\n  if (sparse_requires_grad) {\n    auto sds = sparse_score * output_grad;\n    auto accum = ReduceSum(SparseMatrix::ValLike(sparse_mat, sds), dim);\n    sparse_val_grad =\n        sds - BroadcastMulNoAutoGrad(\n                  SparseMatrix::ValLike(sparse_mat, sparse_score), accum, dim);\n  }\n\n  return {torch::Tensor(), sparse_val_grad, torch::Tensor()};\n}\n\nc10::intrusive_ptr<SparseMatrix> Softmax(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, int64_t dim) {\n  auto sparse_val = sparse_mat->value();\n  bool expand_dim = false;\n  auto new_sparse_mat = sparse_mat;\n  if (sparse_val.dim() == 1) {\n    sparse_val = sparse_val.view({-1, 1});\n    expand_dim = true;\n    new_sparse_mat = SparseMatrix::ValLike(sparse_mat, sparse_val);\n  }\n\n  auto new_sparse_val = SoftmaxAutoGrad::apply(new_sparse_mat, sparse_val, dim);\n\n  if (expand_dim) {\n    new_sparse_val = new_sparse_val.view(-1);\n  }\n  return SparseMatrix::ValLike(sparse_mat, new_sparse_val);\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/sparse_format.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse_format.cc\n * @brief DGL C++ sparse format implementations.\n */\n// clang-format off\n#include <sparse/dgl_headers.h>\n// clang-format on\n\n#include <sparse/sparse_format.h>\n\n#include \"./utils.h\"\n\nnamespace dgl {\nnamespace sparse {\n\nstd::shared_ptr<COO> COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo) {\n  auto row = DGLArrayToTorchTensor(dgl_coo.row);\n  auto col = DGLArrayToTorchTensor(dgl_coo.col);\n  TORCH_CHECK(aten::IsNullArray(dgl_coo.data));\n  auto indices = torch::stack({row, col});\n  return std::make_shared<COO>(\n      COO{dgl_coo.num_rows, dgl_coo.num_cols, indices, dgl_coo.row_sorted,\n          dgl_coo.col_sorted});\n}\n\naten::COOMatrix COOToOldDGLCOO(const std::shared_ptr<COO>& coo) {\n  auto row = TorchTensorToDGLArray(coo->indices.index({0}));\n  auto col = TorchTensorToDGLArray(coo->indices.index({1}));\n  return aten::COOMatrix(\n      coo->num_rows, coo->num_cols, row, col, aten::NullArray(),\n      coo->row_sorted, coo->col_sorted);\n}\n\nstd::shared_ptr<CSR> CSRFromOldDGLCSR(const aten::CSRMatrix& dgl_csr) {\n  auto indptr = DGLArrayToTorchTensor(dgl_csr.indptr);\n  auto indices = DGLArrayToTorchTensor(dgl_csr.indices);\n  auto value_indices = DGLArrayToOptionalTorchTensor(dgl_csr.data);\n  return std::make_shared<CSR>(\n      CSR{dgl_csr.num_rows, dgl_csr.num_cols, indptr, indices, value_indices,\n          dgl_csr.sorted});\n}\n\naten::CSRMatrix CSRToOldDGLCSR(const std::shared_ptr<CSR>& csr) {\n  auto indptr = TorchTensorToDGLArray(csr->indptr);\n  auto indices = TorchTensorToDGLArray(csr->indices);\n  auto data = OptionalTorchTensorToDGLArray(csr->value_indices);\n  return aten::CSRMatrix(\n      csr->num_rows, csr->num_cols, indptr, indices, data, csr->sorted);\n}\n\ntorch::Tensor COOToTorchCOO(\n    const std::shared_ptr<COO>& coo, torch::Tensor value) {\n  torch::Tensor indices = coo->indices;\n  if (value.ndimension() == 2) {\n    return 
torch::sparse_coo_tensor(\n        indices, value, {coo->num_rows, coo->num_cols, value.size(1)});\n  } else {\n    return torch::sparse_coo_tensor(\n        indices, value, {coo->num_rows, coo->num_cols});\n  }\n}\n\nstd::shared_ptr<COO> CSRToCOO(const std::shared_ptr<CSR>& csr) {\n  auto dgl_csr = CSRToOldDGLCSR(csr);\n  auto dgl_coo = aten::CSRToCOO(dgl_csr, csr->value_indices.has_value());\n  return COOFromOldDGLCOO(dgl_coo);\n}\n\nstd::shared_ptr<COO> CSCToCOO(const std::shared_ptr<CSR>& csc) {\n  auto dgl_csc = CSRToOldDGLCSR(csc);\n  auto dgl_coo = aten::CSRToCOO(dgl_csc, csc->value_indices.has_value());\n  dgl_coo = aten::COOTranspose(dgl_coo);\n  return COOFromOldDGLCOO(dgl_coo);\n}\n\nstd::shared_ptr<CSR> COOToCSR(const std::shared_ptr<COO>& coo) {\n  auto dgl_coo = COOToOldDGLCOO(coo);\n  auto dgl_csr = aten::COOToCSR(dgl_coo);\n  return CSRFromOldDGLCSR(dgl_csr);\n}\n\nstd::shared_ptr<CSR> CSCToCSR(const std::shared_ptr<CSR>& csc) {\n  auto dgl_csc = CSRToOldDGLCSR(csc);\n  auto dgl_csr = aten::CSRTranspose(dgl_csc);\n  return CSRFromOldDGLCSR(dgl_csr);\n}\n\nstd::shared_ptr<CSR> COOToCSC(const std::shared_ptr<COO>& coo) {\n  auto dgl_coo = COOToOldDGLCOO(coo);\n  auto dgl_coo_transpose = aten::COOTranspose(dgl_coo);\n  auto dgl_csc = aten::COOToCSR(dgl_coo_transpose);\n  return CSRFromOldDGLCSR(dgl_csc);\n}\n\nstd::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr) {\n  auto dgl_csr = CSRToOldDGLCSR(csr);\n  auto dgl_csc = aten::CSRTranspose(dgl_csr);\n  return CSRFromOldDGLCSR(dgl_csc);\n}\n\nstd::shared_ptr<COO> DiagToCOO(\n    const std::shared_ptr<Diag>& diag,\n    const c10::TensorOptions& indices_options) {\n  int64_t nnz = std::min(diag->num_rows, diag->num_cols);\n  auto indices = torch::arange(nnz, indices_options).repeat({2, 1});\n  return std::make_shared<COO>(\n      COO{diag->num_rows, diag->num_cols, indices, true, true});\n}\n\nstd::shared_ptr<CSR> DiagToCSR(\n    const std::shared_ptr<Diag>& diag,\n    const c10::TensorOptions& 
indices_options) {\n  int64_t nnz = std::min(diag->num_rows, diag->num_cols);\n  auto indptr = torch::full(diag->num_rows + 1, nnz, indices_options);\n  auto nnz_range = torch::arange(nnz + 1, indices_options);\n  indptr.index_put_({nnz_range}, nnz_range);\n  auto indices = torch::arange(nnz, indices_options);\n  return std::make_shared<CSR>(\n      CSR{diag->num_rows, diag->num_cols, indptr, indices,\n          torch::optional<torch::Tensor>(), true});\n}\n\nstd::shared_ptr<CSR> DiagToCSC(\n    const std::shared_ptr<Diag>& diag,\n    const c10::TensorOptions& indices_options) {\n  int64_t nnz = std::min(diag->num_rows, diag->num_cols);\n  auto indptr = torch::full(diag->num_cols + 1, nnz, indices_options);\n  auto nnz_range = torch::arange(nnz + 1, indices_options);\n  indptr.index_put_({nnz_range}, nnz_range);\n  auto indices = torch::arange(nnz, indices_options);\n  return std::make_shared<CSR>(\n      CSR{diag->num_cols, diag->num_rows, indptr, indices,\n          torch::optional<torch::Tensor>(), true});\n}\n\nstd::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo) {\n  auto dgl_coo = COOToOldDGLCOO(coo);\n  auto dgl_coo_tr = aten::COOTranspose(dgl_coo);\n  return COOFromOldDGLCOO(dgl_coo_tr);\n}\n\nstd::pair<std::shared_ptr<COO>, torch::Tensor> COOSort(\n    const std::shared_ptr<COO>& coo) {\n  auto encoded_coo =\n      coo->indices.index({0}) * coo->num_cols + coo->indices.index({1});\n  torch::Tensor sorted, perm;\n  std::tie(sorted, perm) = encoded_coo.sort();\n  auto sorted_coo = std::make_shared<COO>(\n      COO{coo->num_rows, coo->num_cols, coo->indices.index_select(1, perm),\n          true, true});\n  return {sorted_coo, perm};\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/sparse_matrix.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse_matrix.cc\n * @brief DGL C++ sparse matrix implementations.\n */\n// clang-format off\n#include <sparse/dgl_headers.h>\n// clang-format on\n\n#include <c10/util/Logging.h>\n#include <sparse/elementwise_op.h>\n#include <sparse/sparse_matrix.h>\n#include <torch/script.h>\n\n#include \"./utils.h\"\n\nnamespace dgl {\nnamespace sparse {\n\nSparseMatrix::SparseMatrix(\n    const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,\n    const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,\n    torch::Tensor value, const std::vector<int64_t>& shape)\n    : coo_(coo),\n      csr_(csr),\n      csc_(csc),\n      diag_(diag),\n      value_(value),\n      shape_(shape) {\n  TORCH_CHECK(\n      coo != nullptr || csr != nullptr || csc != nullptr || diag != nullptr,\n      \"At least one of CSR/COO/CSC/Diag is required to construct a \"\n      \"SparseMatrix.\")\n  TORCH_CHECK(\n      shape.size() == 2, \"The shape of a sparse matrix should be \",\n      \"2-dimensional.\");\n  // NOTE: Currently all the tensors of a SparseMatrix should on the same\n  // device. 
Do we allow the graph structure and values are on different\n  // devices?\n  if (coo != nullptr) {\n    TORCH_CHECK(coo->indices.dim() == 2);\n    TORCH_CHECK(coo->indices.size(0) == 2);\n    TORCH_CHECK(coo->indices.size(1) == value.size(0));\n    TORCH_CHECK(coo->indices.device() == value.device());\n  }\n  if (csr != nullptr) {\n    TORCH_CHECK(csr->indptr.dim() == 1);\n    TORCH_CHECK(csr->indices.dim() == 1);\n    TORCH_CHECK(csr->indptr.size(0) == shape[0] + 1);\n    TORCH_CHECK(csr->indices.size(0) == value.size(0));\n    TORCH_CHECK(csr->indptr.device() == value.device());\n    TORCH_CHECK(csr->indices.device() == value.device());\n  }\n  if (csc != nullptr) {\n    TORCH_CHECK(csc->indptr.dim() == 1);\n    TORCH_CHECK(csc->indices.dim() == 1);\n    TORCH_CHECK(csc->indptr.size(0) == shape[1] + 1);\n    TORCH_CHECK(csc->indices.size(0) == value.size(0));\n    TORCH_CHECK(csc->indptr.device() == value.device());\n    TORCH_CHECK(csc->indices.device() == value.device());\n  }\n  if (diag != nullptr) {\n    TORCH_CHECK(value.size(0) == std::min(diag->num_rows, diag->num_cols));\n  }\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOOPointer(\n    const std::shared_ptr<COO>& coo, torch::Tensor value,\n    const std::vector<int64_t>& shape) {\n  return c10::make_intrusive<SparseMatrix>(\n      coo, nullptr, nullptr, nullptr, value, shape);\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSRPointer(\n    const std::shared_ptr<CSR>& csr, torch::Tensor value,\n    const std::vector<int64_t>& shape) {\n  return c10::make_intrusive<SparseMatrix>(\n      nullptr, csr, nullptr, nullptr, value, shape);\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(\n    const std::shared_ptr<CSR>& csc, torch::Tensor value,\n    const std::vector<int64_t>& shape) {\n  return c10::make_intrusive<SparseMatrix>(\n      nullptr, nullptr, csc, nullptr, value, shape);\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiagPointer(\n    const 
std::shared_ptr<Diag>& diag, torch::Tensor value,\n    const std::vector<int64_t>& shape) {\n  return c10::make_intrusive<SparseMatrix>(\n      nullptr, nullptr, nullptr, diag, value, shape);\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(\n    torch::Tensor indices, torch::Tensor value,\n    const std::vector<int64_t>& shape) {\n  auto coo =\n      std::make_shared<COO>(COO{shape[0], shape[1], indices, false, false});\n  return SparseMatrix::FromCOOPointer(coo, value, shape);\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSR(\n    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,\n    const std::vector<int64_t>& shape) {\n  auto csr = std::make_shared<CSR>(\n      CSR{shape[0], shape[1], indptr, indices, torch::optional<torch::Tensor>(),\n          false});\n  return SparseMatrix::FromCSRPointer(csr, value, shape);\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(\n    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,\n    const std::vector<int64_t>& shape) {\n  auto csc = std::make_shared<CSR>(\n      CSR{shape[1], shape[0], indptr, indices, torch::optional<torch::Tensor>(),\n          false});\n  return SparseMatrix::FromCSCPointer(csc, value, shape);\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiag(\n    torch::Tensor value, const std::vector<int64_t>& shape) {\n  auto diag = std::make_shared<Diag>(Diag{shape[0], shape[1]});\n  return SparseMatrix::FromDiagPointer(diag, value, shape);\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::IndexSelect(\n    int64_t dim, torch::Tensor ids) {\n  auto id_array = TorchTensorToDGLArray(ids);\n  bool rowwise = dim == 0;\n  auto csr = rowwise ? 
this->CSRPtr() : this->CSCPtr();\n  auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), id_array);\n  auto slice_value =\n      this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));\n  // To prevent potential errors in future conversions to the COO format,\n  // where this array might be used as an initialization array for\n  // constructing COO representations, it is necessary to clear this array.\n  slice_csr.data = dgl::aten::NullArray();\n  auto ret = CSRFromOldDGLCSR(slice_csr);\n  if (rowwise) {\n    return SparseMatrix::FromCSRPointer(\n        ret, slice_value, {ret->num_rows, ret->num_cols});\n  } else {\n    return SparseMatrix::FromCSCPointer(\n        ret, slice_value, {ret->num_cols, ret->num_rows});\n  }\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::RangeSelect(\n    int64_t dim, int64_t start, int64_t end) {\n  bool rowwise = dim == 0;\n  auto csr = rowwise ? this->CSRPtr() : this->CSCPtr();\n  auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), start, end);\n  auto slice_value =\n      this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));\n  // To prevent potential errors in future conversions to the COO format,\n  // where this array might be used as an initialization array for\n  // constructing COO representations, it is necessary to clear this array.\n  slice_csr.data = dgl::aten::NullArray();\n  auto ret = CSRFromOldDGLCSR(slice_csr);\n  if (rowwise) {\n    return SparseMatrix::FromCSRPointer(\n        ret, slice_value, {ret->num_rows, ret->num_cols});\n  } else {\n    return SparseMatrix::FromCSCPointer(\n        ret, slice_value, {ret->num_cols, ret->num_rows});\n  }\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::Sample(\n    int64_t dim, int64_t fanout, torch::Tensor ids, bool replace, bool bias) {\n  bool rowwise = dim == 0;\n  auto id_array = TorchTensorToDGLArray(ids);\n  auto csr = rowwise ? 
this->CSRPtr() : this->CSCPtr();\n  // Slicing matrix.\n  auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), id_array);\n  auto slice_value =\n      this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));\n  // Reset value indices.\n  slice_csr.data = dgl::aten::NullArray();\n\n  auto prob =\n      bias ? TorchTensorToDGLArray(slice_value) : dgl::aten::NullArray();\n  auto slice_id =\n      dgl::aten::Range(0, id_array.NumElements(), 64, id_array->ctx);\n  // Sampling all rows on sliced matrix.\n  auto sample_coo =\n      dgl::aten::CSRRowWiseSampling(slice_csr, slice_id, fanout, prob, replace);\n  auto sample_value =\n      slice_value.index_select(0, DGLArrayToTorchTensor(sample_coo.data));\n  sample_coo.data = dgl::aten::NullArray();\n  auto ret = COOFromOldDGLCOO(sample_coo);\n  if (!rowwise) ret = COOTranspose(ret);\n  return SparseMatrix::FromCOOPointer(\n      ret, sample_value, {ret->num_rows, ret->num_cols});\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(\n    const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {\n  TORCH_CHECK(\n      mat->value().size(0) == value.size(0), \"The first dimension of \",\n      \"the old values and the new values must be the same.\");\n  TORCH_CHECK(\n      mat->value().device() == value.device(), \"The device of the \",\n      \"old values and the new values must be the same.\");\n  const auto& shape = mat->shape();\n  if (mat->HasDiag()) {\n    return SparseMatrix::FromDiagPointer(mat->DiagPtr(), value, shape);\n  }\n  if (mat->HasCOO()) {\n    return SparseMatrix::FromCOOPointer(mat->COOPtr(), value, shape);\n  }\n  if (mat->HasCSR()) {\n    return SparseMatrix::FromCSRPointer(mat->CSRPtr(), value, shape);\n  }\n  TORCH_CHECK(mat->HasCSC(), \"Invalid sparse format for ValLike.\")\n  return SparseMatrix::FromCSCPointer(mat->CSCPtr(), value, shape);\n}\n\nstd::shared_ptr<COO> SparseMatrix::COOPtr() {\n  if (coo_ == nullptr) {\n    _CreateCOO();\n  }\n  return 
coo_;\n}\n\nstd::shared_ptr<CSR> SparseMatrix::CSRPtr() {\n  if (csr_ == nullptr) {\n    _CreateCSR();\n  }\n  return csr_;\n}\n\nstd::shared_ptr<CSR> SparseMatrix::CSCPtr() {\n  if (csc_ == nullptr) {\n    _CreateCSC();\n  }\n  return csc_;\n}\n\nstd::shared_ptr<Diag> SparseMatrix::DiagPtr() {\n  TORCH_CHECK(\n      diag_ != nullptr,\n      \"Cannot get Diag sparse format from a non-diagonal sparse matrix\");\n  return diag_;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {\n  auto coo = COOPtr();\n  return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));\n}\n\ntorch::Tensor SparseMatrix::Indices() {\n  auto coo = COOPtr();\n  return coo->indices;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>\nSparseMatrix::CSRTensors() {\n  auto csr = CSRPtr();\n  auto val = value();\n  return std::make_tuple(csr->indptr, csr->indices, csr->value_indices);\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>\nSparseMatrix::CSCTensors() {\n  auto csc = CSCPtr();\n  return std::make_tuple(csc->indptr, csc->indices, csc->value_indices);\n}\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {\n  auto shape = shape_;\n  std::swap(shape[0], shape[1]);\n  auto value = value_;\n  if (HasDiag()) {\n    return SparseMatrix::FromDiag(value, shape);\n  } else if (HasCOO()) {\n    auto coo = COOTranspose(coo_);\n    return SparseMatrix::FromCOOPointer(coo, value, shape);\n  } else if (HasCSR()) {\n    return SparseMatrix::FromCSCPointer(csr_, value, shape);\n  } else {\n    return SparseMatrix::FromCSRPointer(csc_, value, shape);\n  }\n}\n\nvoid SparseMatrix::_CreateCOO() {\n  if (HasCOO()) return;\n  if (HasDiag()) {\n    auto indices_options = torch::TensorOptions()\n                               .dtype(torch::kInt64)\n                               .layout(torch::kStrided)\n                               .device(this->device());\n    coo_ = DiagToCOO(diag_, 
indices_options);\n  } else if (HasCSR()) {\n    coo_ = CSRToCOO(csr_);\n  } else if (HasCSC()) {\n    coo_ = CSCToCOO(csc_);\n  } else {\n    LOG(FATAL) << \"SparseMatrix does not have any sparse format\";\n  }\n}\n\nvoid SparseMatrix::_CreateCSR() {\n  if (HasCSR()) return;\n  if (HasDiag()) {\n    auto indices_options = torch::TensorOptions()\n                               .dtype(torch::kInt64)\n                               .layout(torch::kStrided)\n                               .device(this->device());\n    csr_ = DiagToCSR(diag_, indices_options);\n  } else if (HasCOO()) {\n    csr_ = COOToCSR(coo_);\n  } else if (HasCSC()) {\n    csr_ = CSCToCSR(csc_);\n  } else {\n    LOG(FATAL) << \"SparseMatrix does not have any sparse format\";\n  }\n}\n\nvoid SparseMatrix::_CreateCSC() {\n  if (HasCSC()) return;\n  if (HasDiag()) {\n    auto indices_options = torch::TensorOptions()\n                               .dtype(torch::kInt64)\n                               .layout(torch::kStrided)\n                               .device(this->device());\n    csc_ = DiagToCSC(diag_, indices_options);\n  } else if (HasCOO()) {\n    csc_ = COOToCSC(coo_);\n  } else if (HasCSR()) {\n    csc_ = CSRToCSC(csr_);\n  } else {\n    LOG(FATAL) << \"SparseMatrix does not have any sparse format\";\n  }\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/sparse_matrix_coalesce.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file sparse_matrix_coalesce.cc\n * @brief Operators related to sparse matrix coalescing.\n */\n// clang-format off\n#include <sparse/dgl_headers.h>\n// clang-format on\n\n#include <sparse/sparse_matrix.h>\n\n#include \"./utils.h\"\n\nnamespace dgl {\nnamespace sparse {\n\nc10::intrusive_ptr<SparseMatrix> SparseMatrix::Coalesce() {\n  auto torch_coo = COOToTorchCOO(this->COOPtr(), this->value());\n  auto coalesced_coo = torch_coo.coalesce();\n  return SparseMatrix::FromCOO(\n      coalesced_coo.indices(), coalesced_coo.values(), this->shape());\n}\n\nbool SparseMatrix::HasDuplicate() {\n  aten::CSRMatrix dgl_csr;\n  if (HasDiag()) {\n    return false;\n  }\n  // The format for calculation will be chosen in the following order: CSR,\n  // CSC. CSR is created if the sparse matrix has neither CSR nor CSC format.\n  if (HasCSR() || !HasCSC()) {\n    dgl_csr = CSRToOldDGLCSR(CSRPtr());\n  } else {\n    dgl_csr = CSRToOldDGLCSR(CSCPtr());\n  }\n  return aten::CSRHasDuplicate(dgl_csr);\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/spmm.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file spmm.cc\n * @brief DGL C++ sparse SpMM operator implementation.\n */\n\n#include <sparse/sddmm.h>\n#include <sparse/sparse_matrix.h>\n#include <sparse/spmm.h>\n#include <torch/script.h>\n\n#include <sstream>\n\n#include \"./matmul.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\nnamespace sparse {\n\nusing namespace torch::autograd;\n\nclass SpMMAutoGrad : public Function<SpMMAutoGrad> {\n public:\n  static torch::Tensor forward(\n      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,\n      torch::Tensor sparse_val, torch::Tensor dense_mat);\n\n  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);\n};\n\nvoid _SpMMSanityCheck(\n    c10::intrusive_ptr<SparseMatrix> sparse_mat, torch::Tensor sparse_val,\n    torch::Tensor dense_mat) {\n  const auto& sparse_mat_shape = sparse_mat->shape();\n  auto val_shape = sparse_val.sizes();\n  auto dense_shape = dense_mat.sizes();\n  bool shape_check = true;\n  shape_check &= sparse_mat_shape[1] == dense_shape[0];\n  shape_check &= val_shape.size() <= 2;\n  shape_check &= val_shape[0] == sparse_mat->nnz();\n  shape_check &= dense_shape.size() <= 3;\n  if (dense_shape.size() == 3 || val_shape.size() == 2) {\n    shape_check &= dense_shape.size() == val_shape.size() + 1;\n    shape_check &= dense_shape[2] == val_shape[1];\n  }\n  if (!shape_check) {\n    std::stringstream error;\n    error << \"SpMM: Invalid input shapes. sparse_mat: \"\n          << c10::IntArrayRef(sparse_mat->shape())\n          << \", sparse_val: \" << sparse_mat->value().sizes()\n          << \", dense_mat: \" << dense_mat.sizes()\n          << \". 
Valid input shapes (sparse_mat, dense_mat) are: (1) (n, m) and \"\n             \"(m, k); (2) (n, m) and (m,); (3) (n, m, b) and (m, k, b).\";\n    TORCH_CHECK(false, error.str());\n  }\n  TORCH_CHECK(\n      sparse_val.dtype() == dense_mat.dtype(),\n      \"SpMM: the non-zero values does not have the same dtype as the dense \"\n      \"matrix.\");\n  TORCH_CHECK(\n      sparse_val.device() == sparse_mat->device() &&\n          sparse_val.device() == dense_mat.device(),\n      \"SpMM: sparse matrix, non-zero values and the dense matrix should be \"\n      \"on the same device.\");\n}\n\ntorch::Tensor SpMMAutoGrad::forward(\n    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,\n    torch::Tensor sparse_val, torch::Tensor dense_mat) {\n  auto ret = SpMMNoAutoGrad(sparse_mat, sparse_val, dense_mat, false);\n\n  const bool sparse_requires_grad = sparse_val.requires_grad();\n  const bool dense_requires_grad = dense_mat.requires_grad();\n  torch::Tensor cache_sparse_val, cache_dense_mat;\n  if (dense_requires_grad) {\n    cache_sparse_val = sparse_val;\n  }\n  if (sparse_requires_grad) {\n    cache_dense_mat = dense_mat;\n  }\n  ctx->saved_data[\"sparse_matrix\"] = sparse_mat;\n  ctx->saved_data[\"sparse_requires_grad\"] = sparse_requires_grad;\n  ctx->saved_data[\"dense_requires_grad\"] = dense_requires_grad;\n  ctx->save_for_backward({cache_sparse_val, cache_dense_mat});\n  return ret;\n}\n\ntensor_list SpMMAutoGrad::backward(\n    AutogradContext* ctx, tensor_list grad_outputs) {\n  auto saved = ctx->get_saved_variables();\n  auto sparse_val = saved[0];\n  auto dense_mat = saved[1];\n  auto output_grad = grad_outputs[0];\n\n  auto sparse_mat =\n      ctx->saved_data[\"sparse_matrix\"].toCustomClass<SparseMatrix>();\n  const bool sparse_requires_grad =\n      ctx->saved_data[\"sparse_requires_grad\"].toBool();\n  const bool dense_requires_grad =\n      ctx->saved_data[\"dense_requires_grad\"].toBool();\n\n  torch::Tensor dense_mat_grad, 
sparse_val_grad;\n  if (sparse_requires_grad) {\n    // A @ B = C -> dA = dC @ (B^T)\n    sparse_val_grad = SDDMMNoAutoGrad(sparse_mat, output_grad, dense_mat);\n  }\n  if (dense_requires_grad) {\n    // A @ B = C -> dB = (A^T) @ dC\n    dense_mat_grad = SpMMNoAutoGrad(sparse_mat, sparse_val, output_grad, true);\n  }\n  return {torch::Tensor(), sparse_val_grad, dense_mat_grad};\n}\n\ntorch::Tensor SpMM(\n    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,\n    torch::Tensor dense_mat) {\n  _SpMMSanityCheck(sparse_mat, sparse_mat->value(), dense_mat);\n  bool expand_dim = false;\n  if (dense_mat.dim() == 1) {\n    dense_mat = dense_mat.view({-1, 1});\n    expand_dim = true;\n  }\n  auto ret = SpMMAutoGrad::apply(sparse_mat, sparse_mat->value(), dense_mat);\n  if (expand_dim) {\n    ret = ret.view(-1);\n  }\n  return ret;\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/spspmm.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file spspmm.cc\n * @brief DGL C++ sparse SpSpMM operator implementation.\n */\n\n#include <sparse/sddmm.h>\n#include <sparse/sparse_matrix.h>\n#include <sparse/spspmm.h>\n#include <torch/script.h>\n\n#include \"./matmul.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\nnamespace sparse {\n\nusing namespace torch::autograd;\n\nclass SpSpMMAutoGrad : public Function<SpSpMMAutoGrad> {\n public:\n  static variable_list forward(\n      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,\n      torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,\n      torch::Tensor rhs_val);\n\n  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);\n};\n\nvoid _SpSpMMSanityCheck(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {\n  const auto& lhs_shape = lhs_mat->shape();\n  const auto& rhs_shape = rhs_mat->shape();\n  TORCH_CHECK(\n      lhs_shape[1] == rhs_shape[0],\n      \"SpSpMM: the second dim of lhs_mat should be equal to the first dim \",\n      \"of the second matrix\");\n  TORCH_CHECK(\n      lhs_mat->value().dim() == 1,\n      \"SpSpMM: the value shape of lhs_mat should be 1-D\");\n  TORCH_CHECK(\n      rhs_mat->value().dim() == 1,\n      \"SpSpMM: the value shape of rhs_mat should be 1-D\");\n  TORCH_CHECK(\n      lhs_mat->device() == rhs_mat->device(),\n      \"SpSpMM: lhs_mat and rhs_mat should be on the same device\");\n  TORCH_CHECK(\n      lhs_mat->dtype() == rhs_mat->dtype(),\n      \"SpSpMM: lhs_mat and rhs_mat should have the same dtype\");\n  TORCH_CHECK(\n      !lhs_mat->HasDuplicate(),\n      \"SpSpMM does not support lhs_mat with duplicate indices. \",\n      \"Call A = A.coalesce() to dedup first.\");\n  TORCH_CHECK(\n      !rhs_mat->HasDuplicate(),\n      \"SpSpMM does not support rhs_mat with duplicate indices. 
\",\n      \"Call A = A.coalesce() to dedup first.\");\n}\n\n// Mask select value of `mat` by `sub_mat`.\ntorch::Tensor _CSRMask(\n    const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value,\n    const c10::intrusive_ptr<SparseMatrix>& sub_mat) {\n  auto csr = CSRToOldDGLCSR(mat->CSRPtr());\n  auto val = TorchTensorToDGLArray(value);\n  auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->indices.index({0}));\n  auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->indices.index({1}));\n  runtime::NDArray ret = aten::CSRGetFloatingData(csr, row, col, val, 0.);\n  return DGLArrayToTorchTensor(ret);\n}\n\nvariable_list SpSpMMAutoGrad::forward(\n    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,\n    torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,\n    torch::Tensor rhs_val) {\n  auto ret_mat =\n      SpSpMMNoAutoGrad(lhs_mat, lhs_val, rhs_mat, rhs_val, false, false);\n\n  ctx->saved_data[\"lhs_mat\"] = lhs_mat;\n  ctx->saved_data[\"rhs_mat\"] = rhs_mat;\n  ctx->saved_data[\"ret_mat\"] = ret_mat;\n  ctx->saved_data[\"lhs_require_grad\"] = lhs_val.requires_grad();\n  ctx->saved_data[\"rhs_require_grad\"] = rhs_val.requires_grad();\n  ctx->save_for_backward({lhs_val, rhs_val});\n\n  auto csr = ret_mat->CSRPtr();\n  auto val = ret_mat->value();\n  TORCH_CHECK(!csr->value_indices.has_value());\n  return {csr->indptr, csr->indices, val};\n}\n\ntensor_list SpSpMMAutoGrad::backward(\n    AutogradContext* ctx, tensor_list grad_outputs) {\n  auto saved = ctx->get_saved_variables();\n  auto lhs_val = saved[0];\n  auto rhs_val = saved[1];\n  auto output_grad = grad_outputs[2];\n  auto lhs_mat = ctx->saved_data[\"lhs_mat\"].toCustomClass<SparseMatrix>();\n  auto rhs_mat = ctx->saved_data[\"rhs_mat\"].toCustomClass<SparseMatrix>();\n  auto ret_mat = ctx->saved_data[\"ret_mat\"].toCustomClass<SparseMatrix>();\n  torch::Tensor lhs_val_grad, rhs_val_grad;\n\n  if (ctx->saved_data[\"lhs_require_grad\"].toBool()) {\n    // A @ B = C -> dA = dC @ 
(B^T)\n    auto lhs_mat_grad =\n        SpSpMMNoAutoGrad(ret_mat, output_grad, rhs_mat, rhs_val, false, true);\n    lhs_val_grad = _CSRMask(lhs_mat_grad, lhs_mat_grad->value(), lhs_mat);\n  }\n  if (ctx->saved_data[\"rhs_require_grad\"].toBool()) {\n    // A @ B = C -> dB = (A^T) @ dC\n    auto rhs_mat_grad =\n        SpSpMMNoAutoGrad(lhs_mat, lhs_val, ret_mat, output_grad, true, false);\n    rhs_val_grad = _CSRMask(rhs_mat_grad, rhs_mat_grad->value(), rhs_mat);\n  }\n  return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad};\n}\n\nc10::intrusive_ptr<SparseMatrix> DiagSpSpMM(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {\n  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {\n    // Diag @ Diag\n    const int64_t m = lhs_mat->shape()[0];\n    const int64_t n = lhs_mat->shape()[1];\n    const int64_t p = rhs_mat->shape()[1];\n    const int64_t common_diag_len = std::min({m, n, p});\n    const int64_t new_diag_len = std::min(m, p);\n    auto slice = torch::indexing::Slice(0, common_diag_len);\n    auto new_val =\n        lhs_mat->value().index({slice}) * rhs_mat->value().index({slice});\n    new_val =\n        torch::constant_pad_nd(new_val, {0, new_diag_len - common_diag_len}, 0);\n    return SparseMatrix::FromDiag(new_val, {m, p});\n  }\n  if (lhs_mat->HasDiag() && !rhs_mat->HasDiag()) {\n    // Diag @ Sparse\n    auto row = rhs_mat->Indices().index({0});\n    auto val = lhs_mat->value().index_select(0, row) * rhs_mat->value();\n    return SparseMatrix::ValLike(rhs_mat, val);\n  }\n  if (!lhs_mat->HasDiag() && rhs_mat->HasDiag()) {\n    // Sparse @ Diag\n    auto col = lhs_mat->Indices().index({1});\n    auto val = rhs_mat->value().index_select(0, col) * lhs_mat->value();\n    return SparseMatrix::ValLike(lhs_mat, val);\n  }\n  TORCH_CHECK(\n      false,\n      \"For DiagSpSpMM, at least one of the sparse matries need to have kDiag \"\n      \"format\");\n  return 
c10::intrusive_ptr<SparseMatrix>();\n}\n\nc10::intrusive_ptr<SparseMatrix> SpSpMM(\n    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,\n    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {\n  _SpSpMMSanityCheck(lhs_mat, rhs_mat);\n  if (lhs_mat->HasDiag() || rhs_mat->HasDiag()) {\n    return DiagSpSpMM(lhs_mat, rhs_mat);\n  }\n  auto results = SpSpMMAutoGrad::apply(\n      lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());\n  std::vector<int64_t> ret_shape({lhs_mat->shape()[0], rhs_mat->shape()[1]});\n  auto indptr = results[0];\n  auto indices = results[1];\n  auto value = results[2];\n  return SparseMatrix::FromCSR(indptr, indices, value, ret_shape);\n}\n\n}  // namespace sparse\n}  // namespace dgl\n"
  },
  {
    "path": "dgl_sparse/src/utils.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file utils.h\n * @brief DGL C++ sparse API utilities\n */\n#ifndef DGL_SPARSE_UTILS_H_\n#define DGL_SPARSE_UTILS_H_\n\n// clang-format off\n#include <sparse/dgl_headers.h>\n// clang-format on\n\n#include <ATen/DLConvertor.h>\n#include <sparse/sparse_matrix.h>\n#include <torch/custom_class.h>\n#include <torch/script.h>\n\nnamespace dgl {\nnamespace sparse {\n\n/** @brief Find a proper sparse format for two sparse matrices. It chooses\n * COO if anyone of the sparse matrices has COO format. If none of them has\n * COO, it tries CSR and CSC in the same manner. */\ninline static SparseFormat FindAnyExistingFormat(\n    const c10::intrusive_ptr<SparseMatrix>& A,\n    const c10::intrusive_ptr<SparseMatrix>& B) {\n  SparseFormat fmt;\n  if (A->HasCOO() || B->HasCOO()) {\n    fmt = SparseFormat::kCOO;\n  } else if (A->HasCSR() || B->HasCSR()) {\n    fmt = SparseFormat::kCSR;\n  } else {\n    fmt = SparseFormat::kCSC;\n  }\n  return fmt;\n}\n\n/** @brief Check whether two matrices has the same dtype and shape for\n * elementwise operators. */\ninline static void ElementwiseOpSanityCheck(\n    const c10::intrusive_ptr<SparseMatrix>& A,\n    const c10::intrusive_ptr<SparseMatrix>& B) {\n  TORCH_CHECK(\n      A->value().dtype() == B->value().dtype(),\n      \"Elementwise operators\"\n      \" do not support two sparse matrices with different dtypes.\");\n  TORCH_CHECK(\n      A->shape()[0] == B->shape()[0] && A->shape()[1] == B->shape()[1],\n      \"Elementwise operators do not support two sparse matrices with different\"\n      \" shapes.\");\n}\n\n/** @brief Convert a Torch tensor to a DGL array. */\ninline static runtime::NDArray TorchTensorToDGLArray(torch::Tensor tensor) {\n  return runtime::DLPackConvert::FromDLPack(at::toDLPack(tensor.contiguous()));\n}\n\n/** @brief Convert a DGL array to a Torch tensor. 
*/\ninline static torch::Tensor DGLArrayToTorchTensor(runtime::NDArray array) {\n  return at::fromDLPack(runtime::DLPackConvert::ToDLPack(array));\n}\n\n/** @brief Convert an optional Torch tensor to a DGL array. */\ninline static runtime::NDArray OptionalTorchTensorToDGLArray(\n    torch::optional<torch::Tensor> tensor) {\n  if (!tensor.has_value()) {\n    return aten::NullArray();\n  }\n  return TorchTensorToDGLArray(tensor.value());\n}\n\n/** @brief Convert a DGL array to an optional Torch tensor. */\ninline static torch::optional<torch::Tensor> DGLArrayToOptionalTorchTensor(\n    runtime::NDArray array) {\n  if (aten::IsNullArray(array)) {\n    return torch::optional<torch::Tensor>();\n  }\n  return torch::make_optional<torch::Tensor>(DGLArrayToTorchTensor(array));\n}\n\n}  // namespace sparse\n}  // namespace dgl\n\n#endif  // DGL_SPARSE_UTILS_H_\n"
  },
  {
    "path": "dglgo/README.md",
    "content": "# DGL-Go\n\n\nDGL-Go is a command line tool for users to get started with training, using and\nstudying Graph Neural Networks (GNNs). Data scientists can quickly apply GNNs\nto their problems, whereas researchers will find it useful to customize their\nexperiments.\n\n\n## Installation and get started\n\nDGL-Go requires DGL v0.8+ so please make sure DGL is updated properly.\n\n### Install the latest stable version\n\n```\npip install dglgo\n```\n\n### Install from source for experimental features\n\n```\npython setup.py install\n```\n\n### Get started\n\nType `dgl` in your console:\n\n```\nUsage: dgl [OPTIONS] COMMAND [ARGS]...\n\nOptions:\n  --help  Show this message and exit.\n\nCommands:\n  configure  Generate a configuration file\n  export     Export a runnable python script\n  recipe     Get example recipes\n  train      Launch training\n```\n\n<p align=\"center\">\n  <img src=\"./dglgo.png\" height=\"400\">\n</p>\n\nUsing DGL-Go is as easy as three steps:\n\n1. Use `dgl configure` to pick the task, dataset and model of your interests. It generates\n   a configuration file for later use. You could also use `dgl recipe get` to retrieve\n   a configuration file we provided.\n1. Use `dgl train` to launch training according to the configuration and see the results.\n1. Use `dgl export` to generate a *self-contained, reproducible* Python script for advanced\n   customization, or try the model on custom data stored in CSV format.\n\nNext, we will walk through all these steps one-by-one.\n\n## Training GraphSAGE for node classification on Cora\n\nLet's use one of the most classical setups -- training a GraphSAGE model for node\nclassification on the Cora citation graph dataset as an\nexample.\n\n### Step 1: `dgl configure`\n\nFirst step, use `dgl configure` to generate a YAML configuration file.\n\n```\ndgl configure nodepred --data cora --model sage --cfg cora_sage.yaml\n```\n\nNote that `nodepred` is the name of DGL-Go *pipeline*. 
For now, you can think of\na pipeline as a training task: `nodepred` is for the node multiclass classification task; other\noptions include `linkpred` for the link prediction task and `graphpred` for graph binary classification, etc. The command will\ngenerate a configuration file `cora_sage.yaml`, which includes:\n\n* Options for the selected dataset (i.e., `cora` here).\n* Model hyperparameters (e.g., number of layers, hidden size, etc.).\n* Training hyperparameters (e.g., learning rate, loss function, etc.).\n\nDifferent choices of task, model and datasets may give very different options,\nso DGL-Go also adds a comment per option for explanation.\nAt this point you can also change options to explore optimization potentials.\n\nThe snippet below shows the configuration file generated by the command above.\n\n```yaml\nversion: 0.0.2\npipeline_name: nodepred\npipeline_mode: train\ndevice: cpu\ndata:\n  name: cora\n  split_ratio:                # Ratio to generate split masks, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset\nmodel:\n  name: sage\n  embed_size: -1              # The dimension of created embedding table. 
-1 means using original node embedding\n  hidden_size: 16             # Hidden size.\n  num_layers: 1               # Number of hidden layers.\n  activation: relu            # Activation function name under torch.nn.functional\n  dropout: 0.5                # Dropout rate.\n  aggregator_type: gcn        # Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\ngeneral_pipeline:\n  early_stop:\n    patience: 20              # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.01\n    weight_decay: 0.0005\n  loss: CrossEntropyLoss\n  save_path: results          # Directory to save the experiment results\n  num_runs: 1                 # Number of experiments to run\n```\n\nApart from `dgl configure`, you could also get one of DGL-Go's built-in configuration files\n(called *recipe*) using `dgl recipe`. 
There are two sub-commands:\n\n```\ndgl recipe list\n```\n\nwill list the available recipes:\n\n```\n➜ dgl recipe list\n===============================================================================\n| Filename                       |  Pipeline           | Dataset              |\n===============================================================================\n| graphpred_pcba_gin.yaml        |  graphpred          | ogbg-molpcba         |\n| graphpred_hiv_pna.yaml         |  graphpred          | ogbg-molhiv          |\n| graphpred_hiv_gin.yaml         |  graphpred          | ogbg-molhiv          |\n| linkpred_citation2_sage.yaml   |  linkpred           | ogbl-citation2       |\n| linkpred_collab_sage.yaml      |  linkpred           | ogbl-collab          |\n| nodepred_citeseer_sage.yaml    |  nodepred           | citeseer             |\n| nodepred_citeseer_gcn.yaml     |  nodepred           | citeseer             |\n| nodepred-ns_arxiv_gcn.yaml     |  nodepred-ns        | ogbn-arxiv           |\n| nodepred_cora_gat.yaml         |  nodepred           | cora                 |\n| nodepred_pubmed_sage.yaml      |  nodepred           | pubmed               |\n| linkpred_cora_sage.yaml        |  linkpred           | cora                 |\n| nodepred_pubmed_gcn.yaml       |  nodepred           | pubmed               |\n| nodepred_pubmed_gat.yaml       |  nodepred           | pubmed               |\n| nodepred_cora_gcn.yaml         |  nodepred           | cora                 |\n| nodepred_cora_sage.yaml        |  nodepred           | cora                 |\n| nodepred_citeseer_gat.yaml     |  nodepred           | citeseer             |\n| nodepred-ns_product_sage.yaml  |  nodepred-ns        | ogbn-products        |\n===============================================================================\n```\n\nThen use\n\n```\ndgl recipe get nodepred_cora_sage.yaml\n```\n\nto copy the YAML configuration file to your local folder.\n\n### Step 2: `dgl train`\n\nSimply run `dgl 
train --cfg cora_sage.yaml` will start the training process.\n```log\n...\nEpoch 00190 | Loss 1.5225 | TrainAcc 0.9500 | ValAcc 0.6840\nEpoch 00191 | Loss 1.5416 | TrainAcc 0.9357 | ValAcc 0.6840\nEpoch 00192 | Loss 1.5391 | TrainAcc 0.9357 | ValAcc 0.6840\nEpoch 00193 | Loss 1.5257 | TrainAcc 0.9643 | ValAcc 0.6840\nEpoch 00194 | Loss 1.5196 | TrainAcc 0.9286 | ValAcc 0.6840\nEarlyStopping counter: 12 out of 20\nEpoch 00195 | Loss 1.4862 | TrainAcc 0.9643 | ValAcc 0.6760\nEpoch 00196 | Loss 1.5142 | TrainAcc 0.9714 | ValAcc 0.6760\nEpoch 00197 | Loss 1.5145 | TrainAcc 0.9714 | ValAcc 0.6760\nEpoch 00198 | Loss 1.5174 | TrainAcc 0.9571 | ValAcc 0.6760\nEpoch 00199 | Loss 1.5235 | TrainAcc 0.9714 | ValAcc 0.6760\nTest Accuracy 0.7740\nAccuracy across 1 runs: 0.774 ± 0.0\n```\n\nThat's all! Basically you only need two commands to train a graph neural network.\n\n### Step 3: `dgl export` for more advanced customization\n\nThat's not everything yet. You may want to open the hood and invoke deeper\ncustomization. DGL-Go can export a **self-contained, reproducible** Python\nscript for you to do anything you like.\n\nTry `dgl export --cfg cora_sage.yaml --output script.py`,\nand you'll get the script used to train the model. Here's the code snippet:\n\n```python\n...\n\nclass GraphSAGE(nn.Module):\n    def __init__(self,\n                 data_info: dict,\n                 embed_size: int = -1,\n                 hidden_size: int = 16,\n                 num_layers: int = 1,\n                 activation: str = \"relu\",\n                 dropout: float = 0.5,\n                 aggregator_type: str = \"gcn\"):\n        \"\"\"GraphSAGE model\n\n        Parameters\n        ----------\n        data_info : dict\n            The information about the input dataset.\n        embed_size : int\n            The dimension of created embedding table. 
-1 means using original node embedding\n        hidden_size : int\n            Hidden size.\n        num_layers : int\n            Number of hidden layers.\n        dropout : float\n            Dropout rate.\n        activation : str\n            Activation function name under torch.nn.functional\n        aggregator_type : str\n            Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\n        \"\"\"\n        super(GraphSAGE, self).__init__()\n        self.data_info = data_info\n        self.embed_size = embed_size\n        if embed_size > 0:\n            self.embed = nn.Embedding(data_info[\"num_nodes\"], embed_size)\n            in_size = embed_size\n        else:\n            in_size = data_info[\"in_size\"]\n        self.layers = nn.ModuleList()\n        self.dropout = nn.Dropout(dropout)\n        self.activation = getattr(nn.functional, activation)\n\n        for i in range(num_layers):\n            in_hidden = hidden_size if i > 0 else in_size\n            out_hidden = hidden_size if i < num_layers - \\\n                1 else data_info[\"out_size\"]\n            self.layers.append(\n                dgl.nn.SAGEConv(\n                    in_hidden,\n                    out_hidden,\n                    aggregator_type))\n\n    def forward(self, graph, node_feat, edge_feat=None):\n        if self.embed_size > 0:\n            dgl_warning(\n                \"The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size.\")\n            h = self.embed.weight\n        else:\n            h = node_feat\n        h = self.dropout(h)\n        for l, layer in enumerate(self.layers):\n            h = layer(graph, h, edge_feat)\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n...\n\ndef train(cfg, pipeline_cfg, device, data, model, optimizer, loss_fcn):\n    g = data[0]  # Only train on the first graph\n    g = 
dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n    g = g.to(device)\n\n    node_feat = g.ndata.get('feat', None)\n    edge_feat = g.edata.get('feat', None)\n    label = g.ndata['label']\n    train_mask, val_mask, test_mask = g.ndata['train_mask'].bool(\n    ), g.ndata['val_mask'].bool(), g.ndata['test_mask'].bool()\n\n    stopper = EarlyStopping(**pipeline_cfg['early_stop'])\n\n    val_acc = 0.\n    for epoch in range(pipeline_cfg['num_epochs']):\n        model.train()\n        logits = model(g, node_feat, edge_feat)\n        loss = loss_fcn(logits[train_mask], label[train_mask])\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        train_acc = accuracy(logits[train_mask], label[train_mask])\n        if epoch != 0 and epoch % pipeline_cfg['eval_period'] == 0:\n            val_acc = accuracy(logits[val_mask], label[val_mask])\n\n            if stopper.step(val_acc, model):\n                break\n\n        print(\"Epoch {:05d} | Loss {:.4f} | TrainAcc {:.4f} | ValAcc {:.4f}\".\n              format(epoch, loss.item(), train_acc, val_acc))\n\n    stopper.load_checkpoint(model)\n    stopper.close()\n\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, node_feat, edge_feat)\n        test_acc = accuracy(logits[test_mask], label[test_mask])\n    return test_acc\n\n\ndef main(run, cfg, data):\n    device = cfg['device']\n    pipeline_cfg = cfg['general_pipeline']\n    # create model\n    model = GraphSAGE(**cfg[\"model\"])\n    model = model.to(device)\n    loss = torch.nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(\n        model.parameters(),\n        **pipeline_cfg[\"optimizer\"])\n    # train\n    test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss)\n    torch.save({'cfg': cfg, 'model': model.state_dict()},\n               os.path.join(pipeline_cfg[\"save_path\"], 'run_{}.pth'.format(run)))\n\n    return test_acc\n\nif __name__ == '__main__':\n    ...\n\n    # load 
data\n    data = AsNodePredDataset(CoraGraphDataset())\n\n    model_cfg = cfg[\"model\"]\n    cfg[\"model\"][\"data_info\"] = {\n        \"in_size\": model_cfg['embed_size'] if model_cfg['embed_size'] > 0 else data[0].ndata['feat'].shape[1],\n        \"out_size\": data.num_classes,\n        \"num_nodes\": data[0].num_nodes()\n    }\n\n    os.makedirs(cfg['general_pipeline'][\"save_path\"])\n\n    all_acc = []\n    num_runs = 1\n    for run in range(num_runs):\n        print(f'Run experiment #{run}')\n        test_acc = main(run, cfg, data)\n        print(\"Test Accuracy {:.4f}\".format(test_acc))\n        all_acc.append(test_acc)\n    avg_acc = np.round(np.mean(all_acc), 6)\n    std_acc = np.round(np.std(all_acc), 6)\n    print(f'Accuracy across {num_runs} runs: {avg_acc} ± {std_acc}')\n```\n\nYou can see that everything is collected into one Python script which includes the\nentire `GraphSAGE` model definition, data processing and training loop. Simply running\n`python script.py` will give you the *exact same* result as you've seen by `dgl train`.\nAt this point, you can change any part as you wish such as plugging your own GNN module,\nchanging the loss function and so on.\n\n## Use DGL-Go on your own dataset\n\nDGL-Go supports training a model on a custom dataset via DGL's `CSVDataset`.\n\n### Step 1: Prepare your CSV and metadata file.\n\nFollow the tutorial at [Loading data from CSV\nfiles](https://docs.dgl.ai/en/latest/guide/data-loadcsv.html#guide-data-pipeline-loadcsv)\nto prepare your dataset. Generally, the dataset folder should include:\n* At least one CSV file for node data.\n* At least one CSV file for edge data.\n* A metadata file called `meta.yaml`.\n\n### Step 2: `dgl configure` with `--data csv` option\nRun\n\n```\ndgl configure nodepred --data csv --model sage --cfg csv_sage.yaml\n```\n\nto generate the configuration file. You will see that the file includes a section like\nthe following:\n\n```yaml\n...\ndata:\n  name: csv\n  split_ratio:                # Ratio to generate split masks, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset\n  data_path: ./               # metadata.yaml, nodes.csv, edges.csv should in this folder\n...\n```\n\nFill in the `data_path` option with the path to your dataset folder.\n\nIf your dataset does not have any native split for training, validation and test sets,\nyou can set the split ratio in the `split_ratio` option, which will\ngenerate a random split for you.\n\n### Step 3: `train` the model / `export` the script\nThen you can do the same as the tutorial above, either train the model by\n`dgl train --cfg csv_sage.yaml` or use `dgl export --cfg csv_sage.yaml\n--output script.py` to get the training script.\n\n## FAQ\n\n**Q: What are the available options for each command?**\nA: You can use `--help` for all commands. For example, use `dgl --help` for general\nhelp message; use `dgl configure --help` for the configuration options; use\n`dgl configure nodepred --help` for the configuration options of node prediction pipeline.\n\n**Q: What exactly is nodepred/linkpred? How many are there?**\nA: They are called DGL-Go pipelines. A pipeline represents the training methodology for\na certain task. Therefore, its naming convention is *<task_name>[-<method_name>]*. For example,\n`nodepred` trains the selected GNN model for node classification using full-graph training method;\nwhile `nodepred-ns` trains the model for node classification but using neighbor sampling.\nCurrently DGL-Go provides four training pipelines (`nodepred`, `nodepred-ns`, `linkpred`, and `graphpred`). Use `dgl configure --help` to see\nall the available pipelines.\n\n**Q: How to add my model to the official model recipe zoo?**\nA: Currently not supported. We will enable this feature soon. Please stay tuned!\n\n**Q: After training a model on some dataset, how can I apply it to another one?**\nA: The `save_path` option in the generated configuration file allows you to specify the directory to save the experiment results. After training, `{save_path}/run_{i}.pth` will be the checkpoint for the i-th run, consisting of the training configuration and trained model state dict. You can then use `dgl apply` as follows.\n\n```\ndgl configure-apply X --data Y --cpt {save_path}/run_{i}.pth --cfg Z\ndgl apply --cfg Z\n```\n\n- `X` is the pipeline name as in `dgl configure`.\n- `Y` is the dataset to apply and can be omitted if you are applying the trained model to the training dataset.\n- `Z` is the configuration file and a default value will be used if not specified.\n\nYou can also use `dgl export --cfg Z` to generate a Python script for further modification.\n"
  },
  {
    "path": "dglgo/dglgo/__init__.py",
    "content": ""
  },
  {
    "path": "dglgo/dglgo/apply_pipeline/__init__.py",
    "content": "from .graphpred import ApplyGraphpredPipeline\nfrom .nodepred import ApplyNodepredPipeline\nfrom .nodepred_sample import ApplyNodepredNsPipeline\n"
  },
  {
    "path": "dglgo/dglgo/apply_pipeline/graphpred/__init__.py",
    "content": "from .gen import *\n"
  },
  {
    "path": "dglgo/dglgo/apply_pipeline/graphpred/gen.py",
    "content": "from copy import deepcopy\nfrom pathlib import Path\nfrom typing import Optional\n\nimport ruamel.yaml\nimport torch\nimport typer\nfrom jinja2 import Template\nfrom pydantic import BaseModel, Field\n\nfrom ...utils.factory import (\n    ApplyPipelineFactory,\n    DataFactory,\n    GraphModelFactory,\n    PipelineBase,\n)\nfrom ...utils.yaml_dump import deep_convert_dict, merge_comment\n\npipeline_comments = {\n    \"batch_size\": \"Graph batch size\",\n    \"num_workers\": \"Number of workers for data loading\",\n    \"save_path\": \"Directory to save the inference results\",\n}\n\n\nclass ApplyGraphpredPipelineCfg(BaseModel):\n    batch_size: int = 32\n    num_workers: int = 4\n    save_path: str = \"apply_results\"\n\n\n@ApplyPipelineFactory.register(\"graphpred\")\nclass ApplyGraphpredPipeline(PipelineBase):\n    def __init__(self):\n        self.pipeline = {\"name\": \"graphpred\", \"mode\": \"apply\"}\n\n    @classmethod\n    def setup_user_cfg_cls(cls):\n        from ...utils.enter_config import UserConfig\n\n        class ApplyGraphPredUserConfig(UserConfig):\n            data: DataFactory.filter(\"graphpred\").get_pydantic_config() = Field(\n                ..., discriminator=\"name\"\n            )\n            general_pipeline: ApplyGraphpredPipelineCfg = (\n                ApplyGraphpredPipelineCfg()\n            )\n\n        cls.user_cfg_cls = ApplyGraphPredUserConfig\n\n    @property\n    def user_cfg_cls(self):\n        return self.__class__.user_cfg_cls\n\n    def get_cfg_func(self):\n        def config(\n            data: DataFactory.filter(\n                \"graphpred\"\n            ).get_dataset_enum() = typer.Option(None, help=\"input data name\"),\n            cfg: Optional[str] = typer.Option(\n                None, help=\"output configuration file path\"\n            ),\n            cpt: str = typer.Option(..., help=\"input checkpoint file path\"),\n        ):\n            # Training configuration\n            train_cfg = 
torch.load(cpt, weights_only=False)[\"cfg\"]\n            if data is None:\n                print(\"data is not specified, use the training dataset\")\n                data = train_cfg[\"data_name\"]\n            else:\n                data = data.name\n            if cfg is None:\n                cfg = (\n                    \"_\".join(\n                        [\"apply\", \"graphpred\", data, train_cfg[\"model_name\"]]\n                    )\n                    + \".yaml\"\n                )\n\n            self.__class__.setup_user_cfg_cls()\n            generated_cfg = {\n                \"pipeline_name\": self.pipeline[\"name\"],\n                \"pipeline_mode\": self.pipeline[\"mode\"],\n                \"device\": train_cfg[\"device\"],\n                \"data\": {\"name\": data},\n                \"cpt_path\": cpt,\n                \"general_pipeline\": {\n                    \"batch_size\": train_cfg[\"general_pipeline\"][\n                        \"eval_batch_size\"\n                    ],\n                    \"num_workers\": train_cfg[\"general_pipeline\"][\"num_workers\"],\n                },\n            }\n            output_cfg = self.user_cfg_cls(**generated_cfg).dict()\n            output_cfg = deep_convert_dict(output_cfg)\n            # Not applicable for inference\n            output_cfg[\"data\"].pop(\"split_ratio\")\n            comment_dict = {\n                \"device\": \"Torch device name, e.g., cpu or cuda or cuda:0\",\n                \"cpt_path\": \"Path to the checkpoint file\",\n                \"general_pipeline\": pipeline_comments,\n            }\n            comment_dict = merge_comment(output_cfg, comment_dict)\n\n            yaml = ruamel.yaml.YAML()\n            yaml.dump(comment_dict, Path(cfg).open(\"w\"))\n            print(\n                \"Configuration file is generated at {}\".format(\n                    Path(cfg).absolute()\n                )\n            )\n\n        return config\n\n    @classmethod\n    def 
gen_script(cls, user_cfg_dict):\n        # Check validation\n        cls.setup_user_cfg_cls()\n        cls.user_cfg_cls(**user_cfg_dict)\n\n        # Training configuration\n        train_cfg = torch.load(user_cfg_dict[\"cpt_path\"], weights_only=False)[\n            \"cfg\"\n        ]\n\n        # Dict for code rendering\n        render_cfg = deepcopy(user_cfg_dict)\n        model_name = train_cfg[\"model_name\"]\n        model_code = GraphModelFactory.get_source_code(model_name)\n        render_cfg[\"model_code\"] = model_code\n        render_cfg[\"model_class_name\"] = GraphModelFactory.get_model_class_name(\n            model_name\n        )\n        render_cfg.update(\n            DataFactory.get_generated_code_dict(user_cfg_dict[\"data\"][\"name\"])\n        )\n\n        # Dict for defining cfg in the rendered code\n        generated_user_cfg = deepcopy(user_cfg_dict)\n        generated_user_cfg.pop(\"pipeline_name\")\n        generated_user_cfg.pop(\"pipeline_mode\")\n        # model arch configuration\n        generated_user_cfg[\"model\"] = train_cfg[\"model\"]\n\n        render_cfg[\"user_cfg_str\"] = f\"cfg = {str(generated_user_cfg)}\"\n        render_cfg[\"user_cfg\"] = user_cfg_dict\n\n        file_current_dir = Path(__file__).resolve().parent\n        with open(file_current_dir / \"graphpred.jinja-py\", \"r\") as f:\n            template = Template(f.read())\n\n        return template.render(**render_cfg)\n\n    @staticmethod\n    def get_description() -> str:\n        return \"Graph classification pipeline for inference on binary classification\"\n"
  },
  {
    "path": "dglgo/dglgo/apply_pipeline/graphpred/graphpred.jinja-py",
    "content": "import torch\nimport os\nimport csv\n\nfrom tqdm import tqdm\nfrom dgl.data import AsGraphPredDataset\nfrom dgl.dataloading import GraphDataLoader\n{{ data_import_code }}\n\n{{ model_code }}\n\ndef infer(device, loader, model):\n    model = model.to(device)\n    model.eval()\n    all_pred = []\n\n    with torch.no_grad():\n        for _, (g, labels) in enumerate(tqdm(loader, desc=\"Iteration\")):\n            g = g.to(device)\n            node_feat = g.ndata['feat']\n            edge_feat = g.edata['feat']\n            pred = model(g, node_feat, edge_feat)\n            pred = (pred.sigmoid() >= 0.5).long()\n            all_pred.append(pred)\n\n    return torch.cat(all_pred, dim=0)\n\ndef main():\n    {{ user_cfg_str }}\n\n    device = cfg['device']\n    if not torch.cuda.is_available():\n        device = 'cpu'\n    pipeline_cfg = cfg['general_pipeline']\n\n    # load data\n    data = AsGraphPredDataset({{data_initialize_code}})\n    data_loader = GraphDataLoader(data, batch_size=pipeline_cfg['batch_size'],\n                                  num_workers=pipeline_cfg['num_workers'], shuffle=False)\n\n    # validation\n    train_data_name = cfg['model']['data_info']['name']\n    infer_data_name = cfg['data']['name']\n    if train_data_name.startswith('ogbg-mol'):\n        assert infer_data_name.startswith('ogbg-mol'), 'Expect the inference data name to start \\\n            with ogbg-mol, got {}'.format(infer_data_name)\n    else:\n        assert train_data_name == infer_data_name, 'Expect the training and inference data to \\\n            have the same name, got {} and {}'.format(train_data_name, infer_data_name)\n    model_node_feat_size = cfg['model']['data_info']['node_feat_size']\n    model_edge_feat_size = cfg['model']['data_info']['edge_feat_size']\n    data_node_feat_size = data.node_feat_size\n    data_edge_feat_size = data.edge_feat_size\n    assert model_node_feat_size == data_node_feat_size, 'Expect the training data and inference \\\n      
  data to have the same number of input node features, got {:d} and {:d}'.format(model_node_feat_size, data_node_feat_size)\n    assert model_edge_feat_size == data_edge_feat_size, 'Expect the training data and inference \\\n        data to have the same number of input edge features, got {:d} and {:d}'.format(model_edge_feat_size, data_edge_feat_size)\n\n    model = {{ model_class_name }}(**cfg['model'])\n    model.load_state_dict(torch.load(cfg['cpt_path'], weights_only=False, map_location='cpu')['model'])\n    pred = infer(device, data_loader, model).detach().cpu()\n\n    # Dump the results\n    os.makedirs(cfg['general_pipeline'][\"save_path\"])\n    file_path = os.path.join(cfg['general_pipeline'][\"save_path\"], 'output.csv')\n    header = ['graph id']\n    header.extend(['task_{:d}'.format(i) for i in range(cfg['model']['data_info']['out_size'])])\n    with open(file_path, 'w') as f:\n        writer = csv.writer(f)\n        writer.writerow(header)\n        writer.writerows([\n            [i] + pred[i].tolist() for i in range(len(pred))\n        ])\n    print('Saved inference results to {}'.format(file_path))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "dglgo/dglgo/apply_pipeline/nodepred/__init__.py",
    "content": "from .gen import *\n"
  },
  {
    "path": "dglgo/dglgo/apply_pipeline/nodepred/gen.py",
    "content": "from copy import deepcopy\nfrom pathlib import Path\nfrom typing import Optional\n\nimport ruamel.yaml\nimport torch\nimport typer\nfrom jinja2 import Template\nfrom pydantic import Field\n\nfrom ...utils.factory import (\n    ApplyPipelineFactory,\n    DataFactory,\n    NodeModelFactory,\n    PipelineBase,\n)\nfrom ...utils.yaml_dump import deep_convert_dict, merge_comment\n\n\n@ApplyPipelineFactory.register(\"nodepred\")\nclass ApplyNodepredPipeline(PipelineBase):\n    def __init__(self):\n        self.pipeline = {\"name\": \"nodepred\", \"mode\": \"apply\"}\n\n    @classmethod\n    def setup_user_cfg_cls(cls):\n        from ...utils.enter_config import UserConfig\n\n        class ApplyNodePredUserConfig(UserConfig):\n            data: DataFactory.filter(\"nodepred\").get_pydantic_config() = Field(\n                ..., discriminator=\"name\"\n            )\n\n        cls.user_cfg_cls = ApplyNodePredUserConfig\n\n    @property\n    def user_cfg_cls(self):\n        return self.__class__.user_cfg_cls\n\n    def get_cfg_func(self):\n        def config(\n            data: DataFactory.filter(\n                \"nodepred\"\n            ).get_dataset_enum() = typer.Option(None, help=\"input data name\"),\n            cfg: Optional[str] = typer.Option(\n                None, help=\"output configuration file path\"\n            ),\n            cpt: str = typer.Option(..., help=\"input checkpoint file path\"),\n        ):\n            # Training configuration\n            train_cfg = torch.load(cpt, weights_only=False)[\"cfg\"]\n            if data is None:\n                print(\"data is not specified, use the training dataset\")\n                data = train_cfg[\"data_name\"]\n            else:\n                data = data.name\n            if cfg is None:\n                cfg = (\n                    \"_\".join(\n                        [\"apply\", \"nodepred\", data, train_cfg[\"model_name\"]]\n                    )\n                    + \".yaml\"\n  
              )\n\n            self.__class__.setup_user_cfg_cls()\n            generated_cfg = {\n                \"pipeline_name\": self.pipeline[\"name\"],\n                \"pipeline_mode\": self.pipeline[\"mode\"],\n                \"device\": train_cfg[\"device\"],\n                \"data\": {\"name\": data},\n                \"cpt_path\": cpt,\n                \"general_pipeline\": {\"save_path\": \"apply_results\"},\n            }\n            output_cfg = self.user_cfg_cls(**generated_cfg).dict()\n            output_cfg = deep_convert_dict(output_cfg)\n            # Not applicable for inference\n            output_cfg[\"data\"].pop(\"split_ratio\")\n            comment_dict = {\n                \"device\": \"Torch device name, e.g., cpu or cuda or cuda:0\",\n                \"cpt_path\": \"Path to the checkpoint file\",\n                \"general_pipeline\": {\n                    \"save_path\": \"Directory to save the inference results\"\n                },\n            }\n            comment_dict = merge_comment(output_cfg, comment_dict)\n\n            yaml = ruamel.yaml.YAML()\n            yaml.dump(comment_dict, Path(cfg).open(\"w\"))\n            print(\n                \"Configuration file is generated at {}\".format(\n                    Path(cfg).absolute()\n                )\n            )\n\n        return config\n\n    @classmethod\n    def gen_script(cls, user_cfg_dict):\n        # Check validation\n        cls.setup_user_cfg_cls()\n        cls.user_cfg_cls(**user_cfg_dict)\n\n        # Training configuration\n        train_cfg = torch.load(user_cfg_dict[\"cpt_path\"], weights_only=False)[\n            \"cfg\"\n        ]\n\n        # Dict for code rendering\n        render_cfg = deepcopy(user_cfg_dict)\n        model_name = train_cfg[\"model_name\"]\n        model_code = NodeModelFactory.get_source_code(model_name)\n        render_cfg[\"model_code\"] = model_code\n        render_cfg[\"model_class_name\"] = 
NodeModelFactory.get_model_class_name(\n            model_name\n        )\n        render_cfg.update(\n            DataFactory.get_generated_code_dict(user_cfg_dict[\"data\"][\"name\"])\n        )\n\n        # Dict for defining cfg in the rendered code\n        generated_user_cfg = deepcopy(user_cfg_dict)\n        generated_user_cfg[\"data\"].pop(\"name\")\n        generated_user_cfg.pop(\"pipeline_name\")\n        generated_user_cfg.pop(\"pipeline_mode\")\n        # model arch configuration\n        generated_user_cfg[\"model\"] = train_cfg[\"model\"]\n\n        render_cfg[\"user_cfg_str\"] = f\"cfg = {str(generated_user_cfg)}\"\n        render_cfg[\"user_cfg\"] = user_cfg_dict\n\n        file_current_dir = Path(__file__).resolve().parent\n        with open(file_current_dir / \"nodepred.jinja-py\", \"r\") as f:\n            template = Template(f.read())\n\n        return template.render(**render_cfg)\n\n    @staticmethod\n    def get_description() -> str:\n        return \"Node classification pipeline for inference\"\n"
  },
  {
    "path": "dglgo/dglgo/apply_pipeline/nodepred/nodepred.jinja-py",
    "content": "import torch\nimport dgl\nimport os\nimport csv\n\nfrom dgl.data import AsNodePredDataset\n{{ data_import_code }}\n\n{{ model_code }}\n\ndef infer(device, data, model):\n    g = data[0] # Only infer on the first graph\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n    g = g.to(device)\n\n    node_feat = g.ndata.get('feat', None)\n    edge_feat = g.edata.get('feat', None)\n\n    model = model.to(device)\n    model.eval()\n\n    with torch.no_grad():\n        logits = model(g, node_feat, edge_feat)\n\n    return logits\n\ndef main():\n    {{ user_cfg_str }}\n\n    device = cfg['device']\n    if not torch.cuda.is_available():\n        device = 'cpu'\n\n    # load data\n    data = AsNodePredDataset({{data_initialize_code}})\n    # validation\n    if cfg['model']['embed_size'] > 0:\n        model_num_nodes = cfg['model']['data_info']['num_nodes']\n        data_num_nodes = data[0].num_nodes()\n        assert model_num_nodes == data_num_nodes, \\\n            'Training and inference need to be on the same dataset when node embeddings were learned from scratch'\n    else:\n        model_in_size = cfg['model']['data_info']['in_size']\n        data_in_size = data[0].ndata['feat'].shape[1]\n        assert model_in_size == data_in_size, \\\n            'Expect the training data and inference data to have the same number of input node \\\n                features, got {:d} and {:d}'.format(model_in_size, data_in_size)\n\n    model = {{ model_class_name }}(**cfg['model'])\n    model.load_state_dict(torch.load(cfg['cpt_path'], weights_only=False, map_location='cpu')['model'])\n    logits = infer(device, data, model)\n    pred = logits.argmax(dim=1).cpu()\n\n    # Dump the results\n    os.makedirs(cfg['general_pipeline'][\"save_path\"])\n    file_path = os.path.join(cfg['general_pipeline'][\"save_path\"], 'output.csv')\n    with open(file_path, 'w') as f:\n        writer = csv.writer(f)\n        writer.writerow(['node id', 'predicted label'])\n     
   writer.writerows([\n            [i, pred[i].item()] for i in range(len(pred))\n        ])\n    print('Saved inference results to {}'.format(file_path))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "dglgo/dglgo/apply_pipeline/nodepred_sample/__init__.py",
    "content": "from .gen import *\n"
  },
  {
    "path": "dglgo/dglgo/apply_pipeline/nodepred_sample/gen.py",
    "content": "from copy import deepcopy\nfrom pathlib import Path\nfrom typing import Optional\n\nimport ruamel.yaml\nimport torch\nimport typer\nfrom jinja2 import Template\nfrom pydantic import Field\n\nfrom ...utils.factory import (\n    ApplyPipelineFactory,\n    DataFactory,\n    NodeModelFactory,\n    PipelineBase,\n)\nfrom ...utils.yaml_dump import deep_convert_dict, merge_comment\n\n\n@ApplyPipelineFactory.register(\"nodepred-ns\")\nclass ApplyNodepredNsPipeline(PipelineBase):\n    def __init__(self):\n        self.pipeline = {\"name\": \"nodepred-ns\", \"mode\": \"apply\"}\n\n    @classmethod\n    def setup_user_cfg_cls(cls):\n        from ...utils.enter_config import UserConfig\n\n        class ApplyNodePredUserConfig(UserConfig):\n            data: DataFactory.filter(\n                \"nodepred-ns\"\n            ).get_pydantic_config() = Field(..., discriminator=\"name\")\n\n        cls.user_cfg_cls = ApplyNodePredUserConfig\n\n    @property\n    def user_cfg_cls(self):\n        return self.__class__.user_cfg_cls\n\n    def get_cfg_func(self):\n        def config(\n            data: DataFactory.filter(\n                \"nodepred-ns\"\n            ).get_dataset_enum() = typer.Option(None, help=\"input data name\"),\n            cfg: Optional[str] = typer.Option(\n                None, help=\"output configuration file path\"\n            ),\n            cpt: str = typer.Option(..., help=\"input checkpoint file path\"),\n        ):\n            # Training configuration\n            train_cfg = torch.load(cpt, weights_only=False)[\"cfg\"]\n            if data is None:\n                print(\"data is not specified, use the training dataset\")\n                data = train_cfg[\"data_name\"]\n            else:\n                data = data.name\n            if cfg is None:\n                cfg = (\n                    \"_\".join(\n                        [\"apply\", \"nodepred-ns\", data, train_cfg[\"model_name\"]]\n                    )\n                  
  + \".yaml\"\n                )\n\n            self.__class__.setup_user_cfg_cls()\n            generated_cfg = {\n                \"pipeline_name\": self.pipeline[\"name\"],\n                \"pipeline_mode\": self.pipeline[\"mode\"],\n                \"device\": train_cfg[\"device\"],\n                \"data\": {\"name\": data},\n                \"cpt_path\": cpt,\n                \"general_pipeline\": {\"save_path\": \"apply_results\"},\n            }\n            output_cfg = self.user_cfg_cls(**generated_cfg).dict()\n            output_cfg = deep_convert_dict(output_cfg)\n            # Not applicable for inference\n            output_cfg[\"data\"].pop(\"split_ratio\")\n            comment_dict = {\n                \"device\": \"Torch device name, e.g., cpu or cuda or cuda:0\",\n                \"cpt_path\": \"Path to the checkpoint file\",\n                \"general_pipeline\": {\n                    \"save_path\": \"Directory to save the inference results\"\n                },\n            }\n            comment_dict = merge_comment(output_cfg, comment_dict)\n\n            yaml = ruamel.yaml.YAML()\n            yaml.dump(comment_dict, Path(cfg).open(\"w\"))\n            print(\n                \"Configuration file is generated at {}\".format(\n                    Path(cfg).absolute()\n                )\n            )\n\n        return config\n\n    @classmethod\n    def gen_script(cls, user_cfg_dict):\n        # Check validation\n        cls.setup_user_cfg_cls()\n        cls.user_cfg_cls(**user_cfg_dict)\n\n        # Training configuration\n        train_cfg = torch.load(user_cfg_dict[\"cpt_path\"], weights_only=False)[\n            \"cfg\"\n        ]\n\n        # Dict for code rendering\n        render_cfg = deepcopy(user_cfg_dict)\n        model_name = train_cfg[\"model_name\"]\n        model_code = NodeModelFactory.get_source_code(model_name)\n        render_cfg[\"model_code\"] = model_code\n        render_cfg[\"model_class_name\"] = 
NodeModelFactory.get_model_class_name(\n            model_name\n        )\n        render_cfg.update(\n            DataFactory.get_generated_code_dict(user_cfg_dict[\"data\"][\"name\"])\n        )\n\n        # Dict for defining cfg in the rendered code\n        generated_user_cfg = deepcopy(user_cfg_dict)\n        generated_user_cfg[\"data\"].pop(\"name\")\n        generated_user_cfg.pop(\"pipeline_name\")\n        generated_user_cfg.pop(\"pipeline_mode\")\n        # model arch configuration\n        generated_user_cfg[\"model\"] = train_cfg[\"model\"]\n\n        render_cfg[\"user_cfg_str\"] = f\"cfg = {str(generated_user_cfg)}\"\n        render_cfg[\"user_cfg\"] = user_cfg_dict\n\n        file_current_dir = Path(__file__).resolve().parent\n        with open(file_current_dir / \"nodepred-ns.jinja-py\", \"r\") as f:\n            template = Template(f.read())\n\n        return template.render(**render_cfg)\n\n    @staticmethod\n    def get_description() -> str:\n        return \"Node classification neighbor sampling pipeline for inference\"\n"
  },
  {
    "path": "dglgo/dglgo/apply_pipeline/nodepred_sample/nodepred-ns.jinja-py",
    "content": "import torch\nimport dgl\nimport os\nimport csv\n\nfrom dgl.data import AsNodePredDataset\n{{ data_import_code }}\n\n{{ model_code }}\n\ndef infer(device, data, model):\n    g = data[0] # Only infer on the first graph\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n    g = g.to(device)\n\n    node_feat = g.ndata.get('feat', None)\n    edge_feat = g.edata.get('feat', None)\n\n    model = model.to(device)\n    model.eval()\n\n    with torch.no_grad():\n        logits = model(g, node_feat, edge_feat)\n\n    return logits\n\ndef main():\n    {{ user_cfg_str }}\n\n    device = cfg['device']\n    if not torch.cuda.is_available():\n        device = 'cpu'\n\n    # load data\n    data = AsNodePredDataset({{data_initialize_code}})\n    # validation\n    if cfg['model']['embed_size'] > 0:\n        model_num_nodes = cfg['model']['data_info']['num_nodes']\n        data_num_nodes = data[0].num_nodes()\n        assert model_num_nodes == data_num_nodes, \\\n            'Training and inference need to be on the same dataset when node embeddings were learned from scratch'\n    else:\n        model_in_size = cfg['model']['data_info']['in_size']\n        data_in_size = data[0].ndata['feat'].shape[1]\n        assert model_in_size == data_in_size, \\\n            'Expect the training data and inference data to have the same number of input node \\\n                features, got {:d} and {:d}'.format(model_in_size, data_in_size)\n\n    model = {{ model_class_name }}(**cfg['model'])\n    model.load_state_dict(torch.load(cfg['cpt_path'], weights_only=False, map_location='cpu')['model'])\n    logits = infer(device, data, model)\n    pred = logits.argmax(dim=1).cpu()\n\n    # Dump the results\n    os.makedirs(cfg['general_pipeline'][\"save_path\"])\n    file_path = os.path.join(cfg['general_pipeline'][\"save_path\"], 'output.csv')\n    with open(file_path, 'w') as f:\n        writer = csv.writer(f)\n        writer.writerow(['node id', 'predicted label'])\n     
   writer.writerows([\n            [i, pred[i].item()] for i in range(len(pred))\n        ])\n    print('Saved inference results to {}'.format(file_path))\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "dglgo/dglgo/cli/__init__.py",
    "content": "from .cli import app\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "dglgo/dglgo/cli/apply_cli.py",
    "content": "from pathlib import Path\n\nimport autopep8\nimport isort\nimport typer\nimport yaml\n\nfrom ..utils.factory import ApplyPipelineFactory\n\n\ndef apply(cfg: str = typer.Option(..., help=\"config yaml file name\")):\n    user_cfg = yaml.safe_load(Path(cfg).open(\"r\"))\n    pipeline_name = user_cfg[\"pipeline_name\"]\n    output_file_content = ApplyPipelineFactory.registry[\n        pipeline_name\n    ].gen_script(user_cfg)\n\n    f_code = autopep8.fix_code(output_file_content, options={\"aggressive\": 1})\n    f_code = isort.code(f_code)\n    code = compile(f_code, \"dglgo_tmp.py\", \"exec\")\n    exec(code, {\"__name__\": \"__main__\"})\n"
  },
  {
    "path": "dglgo/dglgo/cli/cli.py",
    "content": "import typer\nfrom ..pipeline import *\nfrom ..model import *\nfrom .apply_cli import apply\nfrom .config_apply_cli import config_apply_app\nfrom .config_cli import config_app\nfrom .export_cli import export\nfrom .recipe_cli import recipe_app\nfrom .train_cli import train\n\nno_args_is_help = False\napp = typer.Typer(no_args_is_help=True, add_completion=False)\napp.add_typer(config_app, name=\"configure\", no_args_is_help=no_args_is_help)\napp.add_typer(recipe_app, name=\"recipe\", no_args_is_help=True)\napp.command(help=\"Launch training\", no_args_is_help=no_args_is_help)(train)\napp.command(\n    help=\"Export a runnable python script\", no_args_is_help=no_args_is_help\n)(export)\napp.add_typer(\n    config_apply_app, name=\"configure-apply\", no_args_is_help=no_args_is_help\n)\napp.command(help=\"Launch inference\", no_args_is_help=no_args_is_help)(apply)\n\n\ndef main():\n    app()\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "dglgo/dglgo/cli/config_apply_cli.py",
    "content": "from ..apply_pipeline import *\nimport typer\n\nfrom ..utils.factory import ApplyPipelineFactory\n\nconfig_apply_app = typer.Typer(\n    help=\"Generate a configuration file for inference\"\n)\nfor key, pipeline in ApplyPipelineFactory.registry.items():\n    config_apply_app.command(key, help=pipeline.get_description())(\n        pipeline.get_cfg_func()\n    )\n"
  },
  {
    "path": "dglgo/dglgo/cli/config_cli.py",
    "content": "from ..pipeline import *\nimport typing\nfrom enum import Enum\nfrom pathlib import Path\n\nimport typer\nimport yaml\n\nfrom ..utils.factory import ModelFactory, PipelineFactory\n\nconfig_app = typer.Typer(help=\"Generate a configuration file\")\nfor key, pipeline in PipelineFactory.registry.items():\n    config_app.command(key, help=pipeline.get_description())(\n        pipeline.get_cfg_func()\n    )\n\nif __name__ == \"__main__\":\n    config_app()\n"
  },
  {
    "path": "dglgo/dglgo/cli/export_cli.py",
    "content": "import typing\nfrom enum import Enum\nfrom pathlib import Path\n\nimport autopep8\n\nimport isort\nimport typer\nimport yaml\n\nfrom ..utils.factory import ApplyPipelineFactory, ModelFactory, PipelineFactory\n\n\ndef export(\n    cfg: str = typer.Option(\"cfg.yaml\", help=\"config yaml file name\"),\n    output: str = typer.Option(\"script.py\", help=\"output python file name\"),\n):\n    user_cfg = yaml.safe_load(Path(cfg).open(\"r\"))\n    pipeline_name = user_cfg[\"pipeline_name\"]\n    pipeline_mode = user_cfg[\"pipeline_mode\"]\n    if pipeline_mode == \"train\":\n        output_file_content = PipelineFactory.registry[\n            pipeline_name\n        ].gen_script(user_cfg)\n    else:\n        output_file_content = ApplyPipelineFactory.registry[\n            pipeline_name\n        ].gen_script(user_cfg)\n\n    f_code = autopep8.fix_code(output_file_content, options={\"aggressive\": 1})\n    f_code = isort.code(f_code)\n    with open(output, \"w\") as f:\n        f.write(f_code)\n    print(\n        \"The python script is generated at {}, based on config file {}\".format(\n            Path(output).absolute(), Path(cfg).absolute()\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    export_app = typer.Typer()\n    export_app.command()(export)\n    export_app()\n"
  },
  {
    "path": "dglgo/dglgo/cli/recipe_cli.py",
    "content": "import os\nimport shutil\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\nimport yaml\n\n\ndef list_recipes():\n    file_current_dir = Path(__file__).resolve().parent\n    recipe_dir = file_current_dir.parent.parent / \"recipes\"\n    file_list = list(recipe_dir.glob(\"*.yaml\"))\n    header = \"| {:<30} |  {:<18} | {:<20} |\".format(\n        \"Filename\", \"Pipeline\", \"Dataset\"\n    )\n    typer.echo(\"=\" * len(header))\n    typer.echo(header)\n    typer.echo(\"=\" * len(header))\n    output_list = []\n    for file in file_list:\n        cfg = yaml.safe_load(Path(file).open(\"r\"))\n        output_list.append(\n            {\n                \"file_name\": file.name,\n                \"pipeline_name\": cfg[\"pipeline_name\"],\n                \"dataset_name\": cfg[\"data\"][\"name\"],\n            }\n        )\n    # sort by pipeline, if same sort by dataset, if same sort by file name\n    output_list.sort(\n        key=lambda f: (f[\"pipeline_name\"], f[\"dataset_name\"], f[\"file_name\"])\n    )\n    for f in output_list:\n        typer.echo(\n            \"| {:<30} |  {:<18} | {:<20} |\".format(\n                f[\"file_name\"], f[\"pipeline_name\"], f[\"dataset_name\"]\n            )\n        )\n    typer.echo(\"=\" * len(header))\n\n\ndef get_recipe(\n    recipe_name: Optional[str] = typer.Argument(\n        None, help=\"The recipe filename to get, e.q. nodepred_citeseer_gcn.yaml\"\n    )\n):\n    if recipe_name is None:\n        typer.echo(\"Usage: dgl recipe get [RECIPE_NAME] \\n\")\n        typer.echo(\" Copy the recipe to current directory \\n\")\n        typer.echo(\" Arguments:\")\n        typer.echo(\n            \"  [RECIPE_NAME]  The recipe filename to get, e.q. 
nodepred_citeseer_gcn.yaml\\n\"\n        )\n        typer.echo(\"Here are all avaliable recipe filename\")\n        list_recipes()\n    else:\n        file_current_dir = Path(__file__).resolve().parent\n        recipe_dir = file_current_dir.parent.parent / \"recipes\"\n        current_dir = Path(os.getcwd())\n        recipe_path = recipe_dir / recipe_name\n        shutil.copy(recipe_path, current_dir)\n        print(\n            \"Recipe {} is copied to {}\".format(\n                recipe_path.absolute(), current_dir.absolute()\n            )\n        )\n\n\nrecipe_app = typer.Typer(help=\"Get example recipes\")\nrecipe_app.command(name=\"list\", help=\"List all available example recipes\")(\n    list_recipes\n)\nrecipe_app.command(name=\"get\", help=\"Copy the recipe to current directory\")(\n    get_recipe\n)\n\nif __name__ == \"__main__\":\n    recipe_app()\n"
  },
  {
    "path": "dglgo/dglgo/cli/train_cli.py",
    "content": "import typing\nfrom enum import Enum\nfrom pathlib import Path\n\nimport autopep8\nimport isort\nimport typer\nimport yaml\n\nfrom ..utils.factory import ModelFactory, PipelineFactory\n\n\ndef train(\n    cfg: str = typer.Option(\"cfg.yaml\", help=\"config yaml file name\"),\n):\n    user_cfg = yaml.safe_load(Path(cfg).open(\"r\"))\n    pipeline_name = user_cfg[\"pipeline_name\"]\n    output_file_content = PipelineFactory.registry[pipeline_name].gen_script(\n        user_cfg\n    )\n\n    f_code = autopep8.fix_code(output_file_content, options={\"aggressive\": 1})\n    f_code = isort.code(f_code)\n    code = compile(f_code, \"dglgo_tmp.py\", \"exec\")\n    exec(code, {\"__name__\": \"__main__\"})\n\n\nif __name__ == \"__main__\":\n    train_app = typer.Typer()\n    train_app.command()(train)\n    train_app()\n"
  },
  {
    "path": "dglgo/dglgo/model/__init__.py",
    "content": "from .node_encoder import *\nfrom .edge_encoder import *\nfrom .graph_encoder import *\n"
  },
  {
    "path": "dglgo/dglgo/model/edge_encoder/__init__.py",
    "content": "from ...utils.factory import EdgeModelFactory\nfrom .bilinear import BilinearPredictor\nfrom .ele import ElementWiseProductPredictor\n\nEdgeModelFactory.register(\"ele\")(ElementWiseProductPredictor)\nEdgeModelFactory.register(\"bilinear\")(BilinearPredictor)\n"
  },
  {
    "path": "dglgo/dglgo/model/edge_encoder/bilinear.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass BilinearPredictor(nn.Module):\n    def __init__(\n        self,\n        data_info: dict,\n        hidden_size: int = 32,\n        num_layers: int = 1,\n        bias: bool = True,\n    ):\n        \"\"\"Bilinear product model for edge scores\n\n        Parameters\n        ----------\n        data_info : dict\n            The information about the input dataset.\n        hidden_size : int\n            Hidden size.\n        num_layers : int\n            Number of hidden layers.\n        bias : bool\n            Whether to use bias in the linaer layer.\n        \"\"\"\n        super(BilinearPredictor, self).__init__()\n        in_size, out_size = data_info[\"in_size\"], data_info[\"out_size\"]\n        self.bilinear = nn.Bilinear(in_size, in_size, hidden_size, bias=bias)\n        lins_list = []\n        for _ in range(num_layers - 2):\n            lins_list.append(nn.Linear(hidden_size, hidden_size, bias=bias))\n            lins_list.append(nn.ReLU())\n        lins_list.append(nn.Linear(hidden_size, out_size, bias=bias))\n        self.linear = nn.Sequential(*lins_list)\n\n    def forward(self, h_src, h_dst):\n        h = self.bilinear(h_src, h_dst)\n        h = self.linear(h)\n        h = torch.sigmoid(h)\n        return h\n"
  },
  {
    "path": "dglgo/dglgo/model/edge_encoder/dot.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass DotPredictor(nn.Module):\n    def __init__(\n        self,\n        in_size: int = -1,\n        out_size: int = 1,\n        hidden_size: int = 256,\n        num_layers: int = 3,\n        bias: bool = False,\n    ):\n        super(DotPredictor, self).__init__()\n        lins_list = []\n        for _ in range(num_layers - 2):\n            lins_list.append(nn.Linear(in_size, hidden_size, bias=bias))\n            lins_list.append(nn.ReLU())\n        lins_list.append(nn.Linear(hidden_size, out_size, bias=bias))\n        self.linear = nn.Sequential(*lins_list)\n\n    def forward(self, h_src, h_dst):\n        h = h_src * h_dst\n        h = self.linear(h)\n        h = torch.sigmoid(h)\n        return h\n"
  },
  {
    "path": "dglgo/dglgo/model/edge_encoder/ele.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ElementWiseProductPredictor(nn.Module):\n    def __init__(\n        self,\n        data_info: dict,\n        hidden_size: int = 64,\n        num_layers: int = 2,\n        bias: bool = True,\n    ):\n        \"\"\"Elementwise product model for edge scores\n\n        Parameters\n        ----------\n        data_info : dict\n            The information about the input dataset.\n        hidden_size : int\n            Hidden size.\n        num_layers : int\n            Number of hidden layers.\n        bias : bool\n            Whether to use bias in the linaer layer.\n        \"\"\"\n        super(ElementWiseProductPredictor, self).__init__()\n        lins_list = []\n        in_size, out_size = data_info[\"in_size\"], data_info[\"out_size\"]\n        for i in range(num_layers):\n            in_hiddnen = in_size if i == 0 else hidden_size\n            out_hidden = hidden_size if i < num_layers - 1 else out_size\n            lins_list.append(nn.Linear(in_hiddnen, out_hidden, bias=bias))\n            if i < num_layers - 1:\n                lins_list.append(nn.ReLU())\n        self.linear = nn.Sequential(*lins_list)\n\n    def forward(self, h_src, h_dst):\n        h = h_src * h_dst\n        h = self.linear(h)\n        h = torch.sigmoid(h)\n        return h\n"
  },
  {
    "path": "dglgo/dglgo/model/graph_encoder/__init__.py",
    "content": "from ...utils.factory import GraphModelFactory\nfrom .gin_ogbg import OGBGGIN\nfrom .pna import PNA\n\nGraphModelFactory.register(\"gin\")(OGBGGIN)\nGraphModelFactory.register(\"pna\")(PNA)\n"
  },
  {
    "path": "dglgo/dglgo/model/graph_encoder/gin_ogbg.py",
    "content": "import dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn import AvgPooling, GINEConv, SumPooling\nfrom ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder\n\n\nclass MLP(nn.Module):\n    def __init__(self, feat_size: int):\n        \"\"\"Multilayer Perceptron (MLP)\"\"\"\n        super(MLP, self).__init__()\n        self.mlp = nn.Sequential(\n            nn.Linear(feat_size, 2 * feat_size),\n            nn.BatchNorm1d(2 * feat_size),\n            nn.ReLU(),\n            nn.Linear(2 * feat_size, feat_size),\n            nn.BatchNorm1d(feat_size),\n        )\n\n    def forward(self, h):\n        return self.mlp(h)\n\n\nclass OGBGGIN(nn.Module):\n    def __init__(\n        self,\n        data_info: dict,\n        embed_size: int = 300,\n        num_layers: int = 5,\n        dropout: float = 0.5,\n        virtual_node: bool = False,\n    ):\n        \"\"\"Graph Isomorphism Network (GIN) variant introduced in baselines\n        for OGB graph property prediction datasets\n\n        Parameters\n        ----------\n        data_info : dict\n            The information about the input dataset.\n        embed_size : int\n            Embedding size.\n        num_layers : int\n            Number of layers.\n        dropout : float\n            Dropout rate.\n        virtual_node : bool\n            Whether to use virtual node.\n        \"\"\"\n        super(OGBGGIN, self).__init__()\n        self.data_info = data_info\n        self.embed_size = embed_size\n        self.num_layers = num_layers\n        self.virtual_node = virtual_node\n\n        if data_info[\"name\"] in [\"ogbg-molhiv\", \"ogbg-molpcba\"]:\n            self.node_encoder = AtomEncoder(embed_size)\n            self.edge_encoders = nn.ModuleList(\n                [BondEncoder(embed_size) for _ in range(num_layers)]\n            )\n        else:\n            # Handle other datasets\n            self.node_encoder = nn.Linear(\n                
data_info[\"node_feat_size\"], embed_size\n            )\n            self.edge_encoders = nn.ModuleList(\n                [\n                    nn.Linear(data_info[\"edge_feat_size\"], embed_size)\n                    for _ in range(num_layers)\n                ]\n            )\n\n        self.conv_layers = nn.ModuleList(\n            [GINEConv(MLP(embed_size)) for _ in range(num_layers)]\n        )\n\n        self.dropout = nn.Dropout(dropout)\n        self.pool = AvgPooling()\n        self.pred = nn.Linear(embed_size, data_info[\"out_size\"])\n\n        if virtual_node:\n            self.virtual_emb = nn.Embedding(1, embed_size)\n            nn.init.constant_(self.virtual_emb.weight.data, 0)\n            self.mlp_virtual = nn.ModuleList()\n            for _ in range(num_layers - 1):\n                self.mlp_virtual.append(MLP(embed_size))\n            self.virtual_pool = SumPooling()\n\n    def forward(self, graph, node_feat, edge_feat):\n        if self.virtual_node:\n            virtual_emb = self.virtual_emb.weight.expand(graph.batch_size, -1)\n\n        hn = self.node_encoder(node_feat)\n\n        for layer in range(self.num_layers):\n\n            if self.virtual_node:\n                # messages from virtual nodes to graph nodes\n                virtual_hn = dgl.broadcast_nodes(graph, virtual_emb)\n                hn = hn + virtual_hn\n\n            he = self.edge_encoders[layer](edge_feat)\n            hn = self.conv_layers[layer](graph, hn, he)\n            if layer != self.num_layers - 1:\n                hn = F.relu(hn)\n            hn = self.dropout(hn)\n\n            if self.virtual_node and layer != self.num_layers - 1:\n                # messages from graph nodes to virtual nodes\n                virtual_emb_tmp = self.virtual_pool(graph, hn) + virtual_emb\n                virtual_emb = self.mlp_virtual[layer](virtual_emb_tmp)\n                virtual_emb = self.dropout(F.relu(virtual_emb))\n\n        hg = self.pool(graph, hn)\n\n        return 
self.pred(hg)\n"
  },
  {
    "path": "dglgo/dglgo/model/graph_encoder/pna.py",
    "content": "from typing import List\n\nimport dgl.function as fn\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn import AvgPooling, SumPooling\nfrom ogb.graphproppred.mol_encoder import AtomEncoder\n\n\ndef aggregate_mean(h):\n    \"\"\"mean aggregation\"\"\"\n    return torch.mean(h, dim=1)\n\n\ndef aggregate_max(h):\n    \"\"\"max aggregation\"\"\"\n    return torch.max(h, dim=1)[0]\n\n\ndef aggregate_min(h):\n    \"\"\"min aggregation\"\"\"\n    return torch.min(h, dim=1)[0]\n\n\ndef aggregate_sum(h):\n    \"\"\"sum aggregation\"\"\"\n    return torch.sum(h, dim=1)\n\n\ndef aggregate_var(h):\n    \"\"\"variance aggregation\"\"\"\n    h_mean_squares = torch.mean(h * h, dim=1)\n    h_mean = torch.mean(h, dim=1)\n    var = torch.relu(h_mean_squares - h_mean * h_mean)\n    return var\n\n\ndef aggregate_std(h):\n    \"\"\"standard deviation aggregation\"\"\"\n    return torch.sqrt(aggregate_var(h) + 1e-5)\n\n\nAGGREGATORS = {\n    \"mean\": aggregate_mean,\n    \"sum\": aggregate_sum,\n    \"max\": aggregate_max,\n    \"min\": aggregate_min,\n    \"std\": aggregate_std,\n    \"var\": aggregate_var,\n}\n\n\ndef scale_identity(h, D, delta):\n    \"\"\"identity scaling (no scaling operation)\"\"\"\n    return h\n\n\ndef scale_amplification(h, D, delta):\n    \"\"\"amplification scaling\"\"\"\n    return h * (np.log(D + 1) / delta)\n\n\ndef scale_attenuation(h, D, delta):\n    \"\"\"attenuation scaling\"\"\"\n    return h * (delta / np.log(D + 1))\n\n\nSCALERS = {\n    \"identity\": scale_identity,\n    \"amplification\": scale_amplification,\n    \"attenuation\": scale_attenuation,\n}\n\n\nclass MLP(nn.Module):\n    def __init__(\n        self,\n        in_feat_size: int,\n        out_feat_size: int,\n        num_layers: int = 3,\n        decreasing_hidden_size=False,\n    ):\n        \"\"\"Multilayer Perceptron (MLP)\"\"\"\n        super(MLP, self).__init__()\n\n        self.layers = nn.ModuleList()\n        if 
decreasing_hidden_size:\n            for i in range(num_layers - 1):\n                self.layers.append(\n                    nn.Linear(\n                        in_feat_size // 2**i, in_feat_size // 2 ** (i + 1)\n                    )\n                )\n            self.layers.append(\n                nn.Linear(in_feat_size // 2 ** (num_layers - 1), out_feat_size)\n            )\n        else:\n            self.layers.append(nn.Linear(in_feat_size, out_feat_size))\n            for _ in range(num_layers - 1):\n                self.layers.append(nn.Linear(out_feat_size, out_feat_size))\n        self.num_layers = num_layers\n\n    def forward(self, h):\n        for i, layer in enumerate(self.layers):\n            h = layer(h)\n            if i != self.num_layers - 1:\n                h = F.relu(h)\n        return h\n\n\nclass SimplePNAConv(nn.Module):\n    r\"\"\"A simplified PNAConv variant used in OGB submissions\"\"\"\n\n    def __init__(\n        self,\n        feat_size: int,\n        aggregators: List[str],\n        scalers: List[str],\n        delta: float,\n        dropout: float,\n        batch_norm: bool,\n        residual: bool,\n        num_mlp_layers: int,\n    ):\n        super(SimplePNAConv, self).__init__()\n\n        self.aggregators = [AGGREGATORS[aggr] for aggr in aggregators]\n        self.scalers = [SCALERS[scale] for scale in scalers]\n        self.delta = delta\n        self.mlp = MLP(\n            in_feat_size=(len(aggregators) * len(scalers)) * feat_size,\n            out_feat_size=feat_size,\n            num_layers=num_mlp_layers,\n        )\n        self.dropout = nn.Dropout(dropout)\n        self.residual = residual\n\n        if batch_norm:\n            self.bn = nn.BatchNorm1d(feat_size)\n        else:\n            self.bn = None\n\n    def reduce(self, nodes):\n        h = nodes.mailbox[\"m\"]\n        D = h.shape[-2]\n        h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1)\n        h = torch.cat(\n            
[scale(h, D=D, delta=self.delta) for scale in self.scalers], dim=1\n        )\n        return {\"h\": h}\n\n    def forward(self, g, h):\n        with g.local_scope():\n            g.ndata[\"h\"] = h\n            g.update_all(fn.copy_u(\"h\", \"m\"), self.reduce)\n            h_new = g.ndata[\"h\"]\n        h_new = self.mlp(h_new)\n\n        if self.bn is not None:\n            h_new = self.bn(h_new)\n        h_new = F.relu(h_new)\n\n        if self.residual:\n            h_new = h_new + h\n        h_new = self.dropout(h_new)\n\n        return h_new\n\n\nclass PNA(nn.Module):\n    def __init__(\n        self,\n        data_info: dict,\n        embed_size: int = 80,\n        aggregators: str = \"mean max min std\",\n        scalers: str = \"identity amplification attenuation\",\n        dropout: float = 0.3,\n        batch_norm: bool = True,\n        residual: bool = True,\n        num_mlp_layers: int = 1,\n        num_layers: int = 4,\n        readout: str = \"mean\",\n    ):\n        \"\"\"Principal Neighbourhood Aggregation\n\n        Parameters\n        ----------\n        data_info : dict\n            The information about the input dataset.\n        embed_size : int\n            Embedding size.\n        aggregators : str\n            Aggregation function names separated by space, can include mean, max, min, std, sum\n        scalers : str\n            Scaler function names separated by space, can include identity, amplification, and attenuation\n        dropout : float\n            Dropout rate.\n        batch_norm : bool\n            Whether to use batch normalization.\n        residual : bool\n            Whether to use residual connection.\n        num_mlp_layers : int\n            Number of MLP layers to use after message aggregation in each PNA layer.\n        num_layers : int\n            Number of PNA layers.\n        readout : str\n            Readout for computing graph-level representations, can be 'sum' or 'mean'.\n        \"\"\"\n        super(PNA, 
self).__init__()\n        self.data_info = data_info\n        self.embed_size = embed_size\n        self.dropout = dropout\n        self.batch_norm = batch_norm\n        self.residual = residual\n        self.num_mlp_layers = num_mlp_layers\n        self.num_layers = num_layers\n        self.readout = readout\n\n        if aggregators is None:\n            aggregators = [\"mean\", \"max\", \"min\", \"std\"]\n        else:\n            aggregators = [agg.strip() for agg in aggregators.split(\" \")]\n            assert set(aggregators).issubset(\n                {\"mean\", \"max\", \"min\", \"std\", \"sum\"}\n            ), \"Expect aggregators to be a subset of ['mean', 'max', 'min', 'std', 'sum'], \\\n                    got {}\".format(\n                aggregators\n            )\n        if scalers is None:\n            scalers = [\"identity\", \"amplification\", \"attenuation\"]\n        else:\n            scalers = [scl.strip() for scl in scalers.split(\" \")]\n            assert set(scalers).issubset(\n                {\"identity\", \"amplification\", \"attenuation\"}\n            ), \"Expect scalers to be a subset of ['identity', 'amplification', 'attenuation'], \\\n                    got {}\".format(\n                scalers\n            )\n        self.aggregators = aggregators\n        self.scalers = scalers\n\n        if data_info[\"name\"] in [\"ogbg-molhiv\", \"ogbg-molpcba\"]:\n            self.node_encoder = AtomEncoder(embed_size)\n        else:\n            # Handle other datasets\n            self.node_encoder = nn.Linear(\n                data_info[\"node_feat_size\"], embed_size\n            )\n        self.conv_layers = nn.ModuleList(\n            [\n                SimplePNAConv(\n                    feat_size=embed_size,\n                    aggregators=aggregators,\n                    scalers=scalers,\n                    delta=data_info[\"delta\"],\n                    dropout=dropout,\n                    batch_norm=batch_norm,\n          
          residual=residual,\n                    num_mlp_layers=num_mlp_layers,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n\n        if readout == \"sum\":\n            self.pool = SumPooling()\n        elif readout == \"mean\":\n            self.pool = AvgPooling()\n        else:\n            raise ValueError(\n                \"Expect readout to be 'sum' or 'mean', got {}\".format(readout)\n            )\n        self.pred = MLP(\n            embed_size, data_info[\"out_size\"], decreasing_hidden_size=True\n        )\n\n    def forward(self, graph, node_feat, edge_feat=None):\n        hn = self.node_encoder(node_feat)\n        for conv in self.conv_layers:\n            hn = conv(graph, hn)\n        hg = self.pool(graph, hn)\n\n        return self.pred(hg)\n"
  },
  {
    "path": "dglgo/dglgo/model/node_encoder/__init__.py",
    "content": "from ...utils.factory import NodeModelFactory\nfrom .gat import GAT\nfrom .gcn import GCN\nfrom .gin import GIN\nfrom .sage import GraphSAGE\nfrom .sgc import SGC\n\nNodeModelFactory.register(\"gcn\")(GCN)\nNodeModelFactory.register(\"gat\")(GAT)\nNodeModelFactory.register(\"sage\")(GraphSAGE)\nNodeModelFactory.register(\"sgc\")(SGC)\nNodeModelFactory.register(\"gin\")(GIN)\n"
  },
  {
    "path": "dglgo/dglgo/model/node_encoder/gat.py",
    "content": "from typing import List\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.base import dgl_warning\nfrom dgl.nn import GATConv\n\n\nclass GAT(nn.Module):\n    def __init__(\n        self,\n        data_info: dict,\n        embed_size: int = -1,\n        num_layers: int = 2,\n        hidden_size: int = 8,\n        heads: List[int] = [8, 8],\n        activation: str = \"elu\",\n        feat_drop: float = 0.6,\n        attn_drop: float = 0.6,\n        negative_slope: float = 0.2,\n        residual: bool = False,\n    ):\n        \"\"\"Graph Attention Networks\n\n        Parameters\n        ----------\n        data_info : dict\n            The information about the input dataset.\n        embed_size : int\n            The dimension of created embedding table. -1 means using original node embedding\n        hidden_size : int\n            Hidden size.\n        num_layers : int\n            Number of layers.\n        norm : str\n            GCN normalization type. 
Can be 'both', 'right', 'left', 'none'.\n        activation : str\n            Activation function.\n        feat_drop : float\n            Dropout rate for features.\n        attn_drop : float\n            Dropout rate for attentions.\n        negative_slope : float\n            Negative slope for leaky relu in GATConv\n        residual : bool\n            If true, the GATConv will use residule connection\n        \"\"\"\n        super(GAT, self).__init__()\n        self.data_info = data_info\n        self.embed_size = embed_size\n        self.num_layers = num_layers\n        self.gat_layers = nn.ModuleList()\n        self.activation = getattr(torch.nn.functional, activation)\n\n        if embed_size > 0:\n            self.embed = nn.Embedding(data_info[\"num_nodes\"], embed_size)\n            in_size = embed_size\n        else:\n            in_size = data_info[\"in_size\"]\n\n        for i in range(num_layers):\n            in_hidden = hidden_size * heads[i - 1] if i > 0 else in_size\n            out_hidden = (\n                hidden_size if i < num_layers - 1 else data_info[\"out_size\"]\n            )\n            activation = None if i == num_layers - 1 else self.activation\n\n            self.gat_layers.append(\n                GATConv(\n                    in_hidden,\n                    out_hidden,\n                    heads[i],\n                    feat_drop,\n                    attn_drop,\n                    negative_slope,\n                    residual,\n                    activation,\n                )\n            )\n\n    def forward(self, graph, node_feat, edge_feat=None):\n        if self.embed_size > 0:\n            dgl_warning(\n                \"The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size.\"\n            )\n            h = self.embed.weight\n        else:\n            h = node_feat\n        for l in range(self.num_layers - 1):\n            h = self.gat_layers[l](graph, h).flatten(1)\n 
       # output projection\n        logits = self.gat_layers[-1](graph, h).mean(1)\n        return logits\n\n    def forward_block(self, blocks, node_feat, edge_feat=None):\n        h = node_feat\n        for l in range(self.num_layers - 1):\n            h = self.gat_layers[l](blocks[l], h).flatten(1)\n        logits = self.gat_layers[-1](blocks[-1], h).mean(1)\n        return logits\n"
  },
  {
    "path": "dglgo/dglgo/model/node_encoder/gcn.py",
    "content": "import dgl\nimport torch\nimport torch.nn as nn\nfrom dgl.base import dgl_warning\n\n\nclass GCN(nn.Module):\n    def __init__(\n        self,\n        data_info: dict,\n        embed_size: int = -1,\n        hidden_size: int = 16,\n        num_layers: int = 1,\n        norm: str = \"both\",\n        activation: str = \"relu\",\n        dropout: float = 0.5,\n        use_edge_weight: bool = False,\n    ):\n        \"\"\"Graph Convolutional Networks\n\n        Parameters\n        ----------\n        data_info : dict\n            The information about the input dataset.\n        embed_size : int\n            The dimension of created embedding table. -1 means using original node embedding\n        hidden_size : int\n            Hidden size.\n        num_layers : int\n            Number of layers.\n        norm : str\n            GCN normalization type. Can be 'both', 'right', 'left', 'none'.\n        activation : str\n            Activation function.\n        dropout : float\n            Dropout rate.\n        use_edge_weight : bool\n            If true, scale the messages by edge weights.\n        \"\"\"\n        super().__init__()\n        self.use_edge_weight = use_edge_weight\n        self.data_info = data_info\n        self.embed_size = embed_size\n        self.layers = nn.ModuleList()\n        if embed_size > 0:\n            self.embed = nn.Embedding(data_info[\"num_nodes\"], embed_size)\n            in_size = embed_size\n        else:\n            in_size = data_info[\"in_size\"]\n\n        for i in range(num_layers):\n            in_hidden = hidden_size if i > 0 else in_size\n            out_hidden = (\n                hidden_size if i < num_layers - 1 else data_info[\"out_size\"]\n            )\n\n            self.layers.append(\n                dgl.nn.GraphConv(\n                    in_hidden, out_hidden, norm=norm, allow_zero_in_degree=True\n                )\n            )\n\n        self.dropout = nn.Dropout(p=dropout)\n        self.act = 
getattr(torch, activation)\n\n    def forward(self, g, node_feat, edge_feat=None):\n        if self.embed_size > 0:\n            dgl_warning(\n                \"The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size.\"\n            )\n            h = self.embed.weight\n        else:\n            h = node_feat\n        edge_weight = edge_feat if self.use_edge_weight else None\n        for l, layer in enumerate(self.layers):\n            h = layer(g, h, edge_weight=edge_weight)\n            if l != len(self.layers) - 1:\n                h = self.act(h)\n                h = self.dropout(h)\n        return h\n\n    def forward_block(self, blocks, node_feat, edge_feat=None):\n        h = node_feat\n        edge_weight = edge_feat if self.use_edge_weight else None\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h, edge_weight=edge_weight)\n            if l != len(self.layers) - 1:\n                h = self.act(h)\n                h = self.dropout(h)\n        return h\n"
  },
  {
    "path": "dglgo/dglgo/model/node_encoder/gin.py",
    "content": "import torch.nn as nn\nfrom dgl.base import dgl_warning\nfrom dgl.nn import GINConv\n\n\nclass GIN(nn.Module):\n    def __init__(\n        self,\n        data_info: dict,\n        embed_size: int = -1,\n        hidden_size=64,\n        num_layers=3,\n        aggregator_type=\"sum\",\n    ):\n        \"\"\"Graph Isomophism Networks\n\n        Edge feature is ignored in this model.\n\n        Parameters\n        ----------\n        data_info : dict\n            The information about the input dataset.\n        embed_size : int\n            The dimension of created embedding table. -1 means using original node embedding\n        hidden_size : int\n            Hidden size.\n        num_layers : int\n            Number of layers.\n        aggregator_type : str\n            Aggregator type to use (``sum``, ``max`` or ``mean``), default: 'sum'.\n        \"\"\"\n        super().__init__()\n        self.data_info = data_info\n        self.embed_size = embed_size\n        self.conv_list = nn.ModuleList()\n        self.num_layers = num_layers\n        if embed_size > 0:\n            self.embed = nn.Embedding(data_info[\"num_nodes\"], embed_size)\n            in_size = embed_size\n        else:\n            in_size = data_info[\"in_size\"]\n        for i in range(num_layers):\n            input_dim = in_size if i == 0 else hidden_size\n            mlp = nn.Sequential(\n                nn.Linear(input_dim, hidden_size),\n                nn.BatchNorm1d(hidden_size),\n                nn.ReLU(),\n                nn.Linear(hidden_size, hidden_size),\n                nn.ReLU(),\n            )\n\n            self.conv_list.append(GINConv(mlp, aggregator_type, 1e-5, True))\n        self.out_mlp = nn.Linear(hidden_size, data_info[\"out_size\"])\n\n    def forward(self, graph, node_feat, edge_feat=None):\n        if self.embed_size > 0:\n            dgl_warning(\n                \"The embedding for node feature is used, and input node_feat is ignored, due to the provided 
embed_size.\"\n            )\n            h = self.embed.weight\n        else:\n            h = node_feat\n        for i in range(self.num_layers):\n            h = self.conv_list[i](graph, h)\n        h = self.out_mlp(h)\n        return h\n"
  },
  {
    "path": "dglgo/dglgo/model/node_encoder/sage.py",
    "content": "import dgl\nimport torch.nn as nn\nfrom dgl.base import dgl_warning\n\n\nclass GraphSAGE(nn.Module):\n    def __init__(\n        self,\n        data_info: dict,\n        embed_size: int = -1,\n        hidden_size: int = 16,\n        num_layers: int = 1,\n        activation: str = \"relu\",\n        dropout: float = 0.5,\n        aggregator_type: str = \"gcn\",\n    ):\n        \"\"\"GraphSAGE model\n\n        Parameters\n        ----------\n        data_info : dict\n            The information about the input dataset.\n        embed_size : int\n            The dimension of created embedding table. -1 means using original node embedding\n        hidden_size : int\n            Hidden size.\n        num_layers : int\n            Number of hidden layers.\n        dropout : float\n            Dropout rate.\n        activation : str\n            Activation function name under torch.nn.functional\n        aggregator_type : str\n            Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\n        \"\"\"\n        super(GraphSAGE, self).__init__()\n        self.data_info = data_info\n        self.embed_size = embed_size\n        if embed_size > 0:\n            self.embed = nn.Embedding(data_info[\"num_nodes\"], embed_size)\n            in_size = embed_size\n        else:\n            in_size = data_info[\"in_size\"]\n        self.layers = nn.ModuleList()\n        self.dropout = nn.Dropout(dropout)\n        self.activation = getattr(nn.functional, activation)\n\n        for i in range(num_layers):\n            in_hidden = hidden_size if i > 0 else in_size\n            out_hidden = (\n                hidden_size if i < num_layers - 1 else data_info[\"out_size\"]\n            )\n            self.layers.append(\n                dgl.nn.SAGEConv(in_hidden, out_hidden, aggregator_type)\n            )\n\n    def forward(self, graph, node_feat, edge_feat=None):\n        if self.embed_size > 0:\n            dgl_warning(\n                \"The embedding 
for node feature is used, and input node_feat is ignored, due to the provided embed_size.\"\n            )\n            h = self.embed.weight\n        else:\n            h = node_feat\n        h = self.dropout(h)\n        for l, layer in enumerate(self.layers):\n            h = layer(graph, h, edge_feat)\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n    def forward_block(self, blocks, node_feat, edge_feat=None):\n        h = node_feat\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h, edge_feat)\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n"
  },
  {
    "path": "dglgo/dglgo/model/node_encoder/sgc.py",
    "content": "import dgl.function as fn\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.base import dgl_warning\nfrom dgl.nn import SGConv\n\n\nclass SGC(nn.Module):\n    def __init__(self, data_info: dict, embed_size: int = -1, bias=True, k=2):\n        \"\"\"Simplifying Graph Convolutional Networks\n\n        Edge feature is ignored in this model.\n\n        Parameters\n        ----------\n        data_info : dict\n            The information about the input dataset.\n        embed_size : int\n            The dimension of created embedding table. -1 means using original node embedding\n        bias : bool\n            If True, adds a learnable bias to the output. Default: ``True``.\n        k : int\n            Number of hops :math:`K`. Defaults:``1``.\n        \"\"\"\n        super().__init__()\n        self.data_info = data_info\n        self.out_size = data_info[\"out_size\"]\n        self.embed_size = embed_size\n        if embed_size > 0:\n            self.embed = nn.Embedding(data_info[\"num_nodes\"], embed_size)\n            in_size = embed_size\n        else:\n            in_size = data_info[\"in_size\"]\n        self.sgc = SGConv(\n            in_size,\n            self.out_size,\n            k=k,\n            cached=True,\n            bias=bias,\n            norm=self.normalize,\n        )\n\n    def forward(self, g, node_feat, edge_feat=None):\n        if self.embed_size > 0:\n            dgl_warning(\n                \"The embedding for node feature is used, and input node_feat is ignored, due to the provided embed_size.\"\n            )\n            h = self.embed.weight\n        else:\n            h = node_feat\n        return self.sgc(g, h)\n\n    @staticmethod\n    def normalize(h):\n        return (h - h.mean(0)) / (h.std(0) + 1e-5)\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/__init__.py",
    "content": "from .graphpred import GraphpredPipeline\nfrom .linkpred import LinkpredPipeline\nfrom .nodepred import NodepredPipeline\nfrom .nodepred_sample import NodepredNsPipeline\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/graphpred/__init__.py",
    "content": "from .gen import *\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/graphpred/gen.py",
    "content": "import copy\nfrom pathlib import Path\nfrom typing import Optional\n\nimport ruamel.yaml\nimport typer\nfrom jinja2 import Template\nfrom pydantic import BaseModel, Field\n\nfrom ...utils.factory import (\n    DataFactory,\n    GraphModelFactory,\n    PipelineBase,\n    PipelineFactory,\n)\nfrom ...utils.yaml_dump import deep_convert_dict, merge_comment\n\npipeline_comments = {\n    \"num_runs\": \"Number of experiments to run\",\n    \"train_batch_size\": \"Graph batch size when training\",\n    \"eval_batch_size\": \"Graph batch size when evaluating\",\n    \"num_workers\": \"Number of workers for data loading\",\n    \"num_epochs\": \"Number of training epochs\",\n    \"save_path\": \"Directory to save the experiment results\",\n}\n\n\nclass GraphpredPipelineCfg(BaseModel):\n    num_runs: int = 1\n    train_batch_size: int = 32\n    eval_batch_size: int = 32\n    num_workers: int = 4\n    optimizer: dict = {\"name\": \"Adam\", \"lr\": 0.001, \"weight_decay\": 0}\n    # Default to no lr decay\n    lr_scheduler: dict = {\"name\": \"StepLR\", \"step_size\": 100, \"gamma\": 1}\n    loss: str = \"BCEWithLogitsLoss\"\n    metric: str = \"roc_auc_score\"\n    num_epochs: int = 100\n    save_path: str = \"results\"\n\n\n@PipelineFactory.register(\"graphpred\")\nclass GraphpredPipeline(PipelineBase):\n    def __init__(self):\n        self.pipeline = {\"name\": \"graphpred\", \"mode\": \"train\"}\n\n    @classmethod\n    def setup_user_cfg_cls(cls):\n        from ...utils.enter_config import UserConfig\n\n        class GraphPredUserConfig(UserConfig):\n            data: DataFactory.filter(\"graphpred\").get_pydantic_config() = Field(\n                ..., discriminator=\"name\"\n            )\n            model: GraphModelFactory.get_pydantic_model_config() = Field(\n                ..., discriminator=\"name\"\n            )\n            general_pipeline: GraphpredPipelineCfg = GraphpredPipelineCfg()\n\n        cls.user_cfg_cls = GraphPredUserConfig\n\n    
@property\n    def user_cfg_cls(self):\n        return self.__class__.user_cfg_cls\n\n    def get_cfg_func(self):\n        def config(\n            data: DataFactory.filter(\n                \"graphpred\"\n            ).get_dataset_enum() = typer.Option(..., help=\"input data name\"),\n            cfg: Optional[str] = typer.Option(\n                None, help=\"output configuration path\"\n            ),\n            model: GraphModelFactory.get_model_enum() = typer.Option(\n                ..., help=\"Model name\"\n            ),\n        ):\n            self.__class__.setup_user_cfg_cls()\n            generated_cfg = {\n                \"pipeline_name\": self.pipeline[\"name\"],\n                \"pipeline_mode\": self.pipeline[\"mode\"],\n                \"device\": \"cpu\",\n                \"data\": {\"name\": data.name},\n                \"model\": {\"name\": model.value},\n                \"general_pipeline\": {},\n            }\n            output_cfg = self.user_cfg_cls(**generated_cfg).dict()\n            output_cfg = deep_convert_dict(output_cfg)\n            comment_dict = {\n                \"device\": \"Torch device name, e.g., cpu or cuda or cuda:0\",\n                \"data\": {\n                    \"split_ratio\": \"Ratio to generate data split, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. 
Leave blank to use builtin split in original dataset\"\n                },\n                \"general_pipeline\": pipeline_comments,\n                \"model\": GraphModelFactory.get_constructor_doc_dict(\n                    model.value\n                ),\n            }\n            comment_dict = merge_comment(output_cfg, comment_dict)\n\n            yaml = ruamel.yaml.YAML()\n            if cfg is None:\n                cfg = \"_\".join([\"graphpred\", data.value, model.value]) + \".yaml\"\n            yaml.dump(comment_dict, Path(cfg).open(\"w\"))\n            print(\n                \"Configuration file is generated at {}\".format(\n                    Path(cfg).absolute()\n                )\n            )\n\n        return config\n\n    @classmethod\n    def gen_script(cls, user_cfg_dict):\n        cls.setup_user_cfg_cls()\n        file_current_dir = Path(__file__).resolve().parent\n        with open(file_current_dir / \"graphpred.jinja-py\", \"r\") as f:\n            template = Template(f.read())\n\n        render_cfg = copy.deepcopy(user_cfg_dict)\n        model_code = GraphModelFactory.get_source_code(\n            user_cfg_dict[\"model\"][\"name\"]\n        )\n        render_cfg[\"model_code\"] = model_code\n        render_cfg[\"model_class_name\"] = GraphModelFactory.get_model_class_name(\n            user_cfg_dict[\"model\"][\"name\"]\n        )\n        render_cfg.update(\n            DataFactory.get_generated_code_dict(\n                user_cfg_dict[\"data\"][\"name\"], '**cfg[\"data\"]'\n            )\n        )\n\n        generated_user_cfg = copy.deepcopy(user_cfg_dict)\n        if \"split_ratio\" in generated_user_cfg[\"data\"]:\n            generated_user_cfg[\"data\"].pop(\"split_ratio\")\n        generated_user_cfg[\"data_name\"] = generated_user_cfg[\"data\"].pop(\"name\")\n        generated_user_cfg.pop(\"pipeline_name\")\n        generated_user_cfg.pop(\"pipeline_mode\")\n        generated_user_cfg[\"model_name\"] = 
generated_user_cfg[\"model\"].pop(\n            \"name\"\n        )\n        generated_user_cfg[\"general_pipeline\"][\"optimizer\"].pop(\"name\")\n        generated_user_cfg[\"general_pipeline\"][\"lr_scheduler\"].pop(\"name\")\n\n        generated_train_cfg = copy.deepcopy(user_cfg_dict[\"general_pipeline\"])\n        generated_train_cfg[\"optimizer\"].pop(\"name\")\n        generated_train_cfg[\"lr_scheduler\"].pop(\"name\")\n\n        if user_cfg_dict[\"data\"].get(\"split_ratio\", None) is not None:\n            render_cfg[\"data_initialize_code\"] = \"{}, split_ratio={}\".format(\n                render_cfg[\"data_initialize_code\"],\n                user_cfg_dict[\"data\"][\"split_ratio\"],\n            )\n        render_cfg[\"user_cfg_str\"] = f\"cfg = {str(generated_user_cfg)}\"\n        render_cfg[\"user_cfg\"] = user_cfg_dict\n        return template.render(**render_cfg)\n\n    @staticmethod\n    def get_description() -> str:\n        return \"Graph property prediction pipeline on binary classification\"\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/graphpred/graphpred.jinja-py",
    "content": "import numpy as np\nimport sklearn\nimport torch\nimport torch.nn as nn\nimport os\n\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\nfrom tqdm import tqdm\nfrom dgl.data import AsGraphPredDataset\nfrom dgl.dataloading import GraphDataLoader\n{{ data_import_code }}\n\n{{ model_code }}\n\ndef train(device, loader, model, criterion, optimizer):\n    model.train()\n\n    for _, (g, labels) in enumerate(tqdm(loader, desc=\"Iteration\")):\n        g = g.to(device)\n        labels = labels.to(device)\n        node_feat = g.ndata['feat']\n        edge_feat = g.edata['feat']\n\n        pred = model(g, node_feat, edge_feat)\n        optimizer.zero_grad()\n        # ignore nan targets (unlabeled) when computing training loss\n        is_labeled = labels == labels\n        loss = criterion(pred.float()[is_labeled], labels.float()[is_labeled])\n        loss.backward()\n        optimizer.step()\n\ndef calc_metric(y_true, y_pred):\n    task_metric_list = []\n    for i in range(y_true.shape[1]):\n        # AUC is only defined when there is at least one positive and negative datapoint.\n        if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0:\n            # ignore nan values\n            is_labeled = y_true[:,i] == y_true[:,i]\n            task_metric = sklearn.metrics.{{ user_cfg.general_pipeline.metric }}(\n                y_true[is_labeled, i], y_pred[is_labeled, i])\n            task_metric_list.append(task_metric)\n\n    return sum(task_metric_list) / len(task_metric_list)\n\ndef evaluate(device, loader, model):\n    model.eval()\n    y_true = []\n    y_pred = []\n\n    for _, (g, labels) in enumerate(tqdm(loader, desc=\"Iteration\")):\n        g = g.to(device)\n        labels = labels.to(device)\n        node_feat = g.ndata['feat']\n        edge_feat = g.edata['feat']\n\n        with torch.no_grad():\n            pred = model(g, node_feat, edge_feat)\n        y_true.append(labels.view(pred.shape).detach().cpu())\n        
y_pred.append(pred.detach().cpu())\n\n    y_true = torch.cat(y_true, dim=0).numpy()\n    y_pred = torch.cat(y_pred, dim=0).numpy()\n\n    return calc_metric(y_true, y_pred)\n\ndef main(run, cfg, data):\n    device = cfg['device']\n    pipeline_cfg = cfg['general_pipeline']\n\n    train_loader = GraphDataLoader(data[data.train_idx], batch_size=pipeline_cfg['train_batch_size'],\n                                   shuffle=True, num_workers=pipeline_cfg['num_workers'])\n    val_loader = GraphDataLoader(data[data.val_idx], batch_size=pipeline_cfg['eval_batch_size'],\n                                 shuffle=False, num_workers=pipeline_cfg['num_workers'])\n    test_loader = GraphDataLoader(data[data.test_idx], batch_size=pipeline_cfg['eval_batch_size'],\n                                  shuffle=False, num_workers=pipeline_cfg['num_workers'])\n\n    # create model\n    model = {{ model_class_name }}(**cfg[\"model\"])\n    model = model.to(device)\n\n    criterion = nn.{{ user_cfg.general_pipeline.loss }}()\n    optimizer = torch.optim.{{ user_cfg.general_pipeline.optimizer.name }}(\n        model.parameters(), **pipeline_cfg[\"optimizer\"])\n    lr_scheduler = torch.optim.lr_scheduler.{{ user_cfg.general_pipeline.lr_scheduler.name }}(\n        optimizer, **pipeline_cfg[\"lr_scheduler\"])\n    best_val_metric = 0.\n\n    tmp_cpt_path = 'checkpoint.pth'\n\n    for epoch in range(pipeline_cfg['num_epochs']):\n        train(device, train_loader, model, criterion, optimizer)\n        val_metric = evaluate(device, val_loader, model)\n        if val_metric >= best_val_metric:\n            best_val_metric = val_metric\n            torch.save(model.state_dict(), tmp_cpt_path)\n        print('Run {:d} | Epoch {:d} | Val Metric {:.4f} | Best Val Metric {:.4f}'.format(\n              run, epoch, val_metric, best_val_metric))\n\n        if isinstance(lr_scheduler, ReduceLROnPlateau):\n            lr_scheduler.step(val_metric)\n        else:\n            lr_scheduler.step()\n\n    
model.load_state_dict(torch.load(tmp_cpt_path, weights_only=False))\n    os.remove(tmp_cpt_path)\n    test_metric = evaluate(device, test_loader, model)\n    print('Test Metric: {:.4f}'.format(test_metric))\n\n    cpt_path = os.path.join(pipeline_cfg[\"save_path\"], 'run_{}.pth'.format(run))\n    torch.save({'cfg': cfg, 'model': model.state_dict()}, cpt_path)\n    print('Saved training checkpoint to {}'.format(cpt_path))\n\n    return test_metric\n\nif __name__ == '__main__':\n    {{ user_cfg_str }}\n    if not torch.cuda.is_available():\n        cfg['device'] = 'cpu'\n\n    # load data\n    data = AsGraphPredDataset({{data_initialize_code}})\n\n    cfg[\"model\"][\"data_info\"] = {\n        \"name\": cfg[\"data_name\"],\n        \"node_feat_size\": data.node_feat_size,\n        \"edge_feat_size\": data.edge_feat_size,\n        \"out_size\": data.num_tasks\n    }\n    if cfg[\"model_name\"] == 'pna':\n        in_deg = torch.cat([g.in_degrees() for (g, _) in data[data.train_idx]])\n        cfg[\"model\"][\"data_info\"][\"delta\"] = torch.mean(torch.log(in_deg + 1)).item()\n\n    os.makedirs(cfg['general_pipeline'][\"save_path\"])\n\n    all_run_metrics = []\n    num_runs = {{ user_cfg.general_pipeline.num_runs }}\n    for run in range(num_runs):\n        print('Run experiment {:d}'.format(run))\n        test_metric = main(run, cfg, data)\n        all_run_metrics.append(test_metric)\n    avg_metric = np.round(np.mean(all_run_metrics), 6)\n    std_metric = np.round(np.std(all_run_metrics), 6)\n    print('Test Metric across {:d} runs: {:.6f} ± {:.6f}'.format(\n        num_runs, avg_metric, std_metric))\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/linkpred/__init__.py",
    "content": "from .gen import *\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/linkpred/gen.py",
    "content": "import copy\nfrom pathlib import Path\nfrom typing import Optional\n\nimport ruamel.yaml\nimport typer\nimport yaml\nfrom jinja2 import Template\nfrom pydantic import BaseModel, Field\nfrom ruamel.yaml.comments import CommentedMap\n\nfrom ...utils.base_model import DeviceEnum, EarlyStopConfig\nfrom ...utils.factory import (\n    DataFactory,\n    EdgeModelFactory,\n    NegativeSamplerFactory,\n    NodeModelFactory,\n    PipelineBase,\n    PipelineFactory,\n)\n\nfrom ...utils.yaml_dump import deep_convert_dict, merge_comment\n\n\nclass LinkpredPipelineCfg(BaseModel):\n    hidden_size: int = 256\n    eval_batch_size: int = 32769\n    train_batch_size: int = 32769\n    num_epochs: int = 200\n    eval_period: int = 5\n    optimizer: dict = {\"name\": \"Adam\", \"lr\": 0.005}\n    loss: str = \"BCELoss\"\n    save_path: str = \"results\"\n    num_runs: int = 1\n\n\npipeline_comments = {\n    \"hidden_size\": \"The intermediate hidden size between node model and edge model\",\n    \"eval_batch_size\": \"Edge batch size when evaluating\",\n    \"train_batch_size\": \"Edge batch size when training\",\n    \"num_epochs\": \"Number of training epochs\",\n    \"eval_period\": \"Interval epochs between evaluations\",\n    \"save_path\": \"Directory to save the experiment results\",\n    \"num_runs\": \"Number of experiments to run\",\n}\n\n\n@PipelineFactory.register(\"linkpred\")\nclass LinkpredPipeline(PipelineBase):\n\n    user_cfg_cls = None\n    pipeline_name = \"linkpred\"\n\n    def __init__(self):\n        self.pipeline = {\"name\": \"linkpred\", \"mode\": \"train\"}\n\n    @classmethod\n    def setup_user_cfg_cls(cls):\n        from ...utils.enter_config import UserConfig\n\n        class LinkPredUserConfig(UserConfig):\n            data: DataFactory.filter(\"linkpred\").get_pydantic_config() = Field(\n                ..., discriminator=\"name\"\n            )\n            node_model: NodeModelFactory.get_pydantic_model_config() = Field(\n              
  ..., discriminator=\"name\"\n            )\n            edge_model: EdgeModelFactory.get_pydantic_model_config() = Field(\n                ..., discriminator=\"name\"\n            )\n            neg_sampler: NegativeSamplerFactory.get_pydantic_model_config() = (\n                Field(..., discriminator=\"name\")\n            )\n            general_pipeline: LinkpredPipelineCfg = LinkpredPipelineCfg()\n\n        cls.user_cfg_cls = LinkPredUserConfig\n\n    @property\n    def user_cfg_cls(self):\n        return self.__class__.user_cfg_cls\n\n    def get_cfg_func(self):\n        def config(\n            data: DataFactory.filter(\n                \"linkpred\"\n            ).get_dataset_enum() = typer.Option(..., help=\"input data name\"),\n            cfg: str = typer.Option(\n                \"cfg.yaml\", help=\"output configuration path\"\n            ),\n            node_model: NodeModelFactory.get_model_enum() = typer.Option(\n                ..., help=\"Model name\"\n            ),\n            edge_model: EdgeModelFactory.get_model_enum() = typer.Option(\n                ..., help=\"Model name\"\n            ),\n            neg_sampler: NegativeSamplerFactory.get_model_enum() = typer.Option(\n                \"persource\", help=\"Negative sampler name\"\n            ),\n        ):\n            self.__class__.setup_user_cfg_cls()\n            generated_cfg = {\n                \"pipeline_name\": self.pipeline[\"name\"],\n                \"pipeline_mode\": self.pipeline[\"mode\"],\n                \"device\": \"cpu\",\n                \"data\": {\"name\": data.name},\n                \"neg_sampler\": {\"name\": neg_sampler.value},\n                \"node_model\": {\"name\": node_model.value},\n                \"edge_model\": {\"name\": edge_model.value},\n            }\n            output_cfg = self.user_cfg_cls(**generated_cfg).dict()\n            output_cfg = deep_convert_dict(output_cfg)\n            comment_dict = {\n                \"device\": \"Torch 
device name, e.g., cpu or cuda or cuda:0\",\n                \"general_pipeline\": pipeline_comments,\n                \"node_model\": NodeModelFactory.get_constructor_doc_dict(\n                    node_model.value\n                ),\n                \"edge_model\": EdgeModelFactory.get_constructor_doc_dict(\n                    edge_model.value\n                ),\n                \"neg_sampler\": NegativeSamplerFactory.get_constructor_doc_dict(\n                    neg_sampler.value\n                ),\n                \"data\": {\n                    \"split_ratio\": \"List of float, e.g. [0.8, 0.1, 0.1]. Split ratios for training, validation and test sets. Must sum to one. Leave blank to use builtin split in original dataset\",\n                    \"neg_ratio\": \"Int, e.g. 2. Indicates how many negative samples to draw per positive sample. Leave blank to use builtin split in original dataset\",\n                },\n            }\n            comment_dict = merge_comment(output_cfg, comment_dict)\n\n            if cfg is None:\n                cfg = (\n                    \"_\".join(\n                        [\n                            \"linkpred\",\n                            data.value,\n                            node_model.value,\n                            edge_model.value,\n                        ]\n                    )\n                    + \".yaml\"\n                )\n            yaml = ruamel.yaml.YAML()\n            yaml.dump(comment_dict, Path(cfg).open(\"w\"))\n            print(\n                \"Configuration file is generated at {}\".format(\n                    Path(cfg).absolute()\n                )\n            )\n\n        return config\n\n    @classmethod\n    def gen_script(cls, user_cfg_dict):\n        cls.setup_user_cfg_cls()\n        # Check validation\n        user_cfg = cls.user_cfg_cls(**user_cfg_dict)\n        file_current_dir = Path(__file__).resolve().parent\n        with open(file_current_dir / 
\"linkpred.jinja-py\", \"r\") as f:\n            template = Template(f.read())\n\n        render_cfg = copy.deepcopy(user_cfg_dict)\n        render_cfg[\"node_model_code\"] = NodeModelFactory.get_source_code(\n            user_cfg_dict[\"node_model\"][\"name\"]\n        )\n        render_cfg[\"edge_model_code\"] = EdgeModelFactory.get_source_code(\n            user_cfg_dict[\"edge_model\"][\"name\"]\n        )\n        render_cfg[\n            \"node_model_class_name\"\n        ] = NodeModelFactory.get_model_class_name(\n            user_cfg_dict[\"node_model\"][\"name\"]\n        )\n        render_cfg[\n            \"edge_model_class_name\"\n        ] = EdgeModelFactory.get_model_class_name(\n            user_cfg_dict[\"edge_model\"][\"name\"]\n        )\n        render_cfg[\n            \"neg_sampler_name\"\n        ] = NegativeSamplerFactory.get_model_class_name(\n            user_cfg_dict[\"neg_sampler\"][\"name\"]\n        )\n        render_cfg[\"loss\"] = user_cfg_dict[\"general_pipeline\"][\"loss\"]\n        # update import and initialization code\n        render_cfg.update(\n            DataFactory.get_generated_code_dict(\n                user_cfg_dict[\"data\"][\"name\"], '**cfg[\"data\"]'\n            )\n        )\n        generated_user_cfg = copy.deepcopy(user_cfg_dict)\n        if len(generated_user_cfg[\"data\"]) == 1:\n            generated_user_cfg.pop(\"data\")\n        else:\n            generated_user_cfg[\"data\"].pop(\"name\")\n        generated_user_cfg.pop(\"pipeline_name\")\n        generated_user_cfg.pop(\"pipeline_mode\")\n        generated_user_cfg[\"node_model\"].pop(\"name\")\n        generated_user_cfg[\"edge_model\"].pop(\"name\")\n        generated_user_cfg[\"neg_sampler\"].pop(\"name\")\n        generated_user_cfg[\"general_pipeline\"][\"optimizer\"].pop(\"name\")\n        generated_user_cfg[\"general_pipeline\"].pop(\"loss\")\n        generated_train_cfg = copy.deepcopy(user_cfg_dict[\"general_pipeline\"])\n        
generated_train_cfg[\"optimizer\"].pop(\"name\")\n\n        if user_cfg_dict[\"data\"].get(\"split_ratio\", None) is not None:\n            assert (\n                user_cfg_dict[\"data\"].get(\"neg_ratio\", None) is not None\n            ), \"Please specify both split_ratio and neg_ratio\"\n            render_cfg[\n                \"data_initialize_code\"\n            ] = \"{}, split_ratio={}, neg_ratio={}\".format(\n                render_cfg[\"data_initialize_code\"],\n                user_cfg_dict[\"data\"][\"split_ratio\"],\n                user_cfg_dict[\"data\"][\"neg_ratio\"],\n            )\n            generated_user_cfg[\"data\"].pop(\"split_ratio\")\n            generated_user_cfg[\"data\"].pop(\"neg_ratio\")\n\n        render_cfg[\"user_cfg_str\"] = f\"cfg = {str(generated_user_cfg)}\"\n        render_cfg[\"user_cfg\"] = user_cfg_dict\n        return template.render(**render_cfg)\n\n    @staticmethod\n    def get_description() -> str:\n        return \"Link prediction pipeline\"\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/linkpred/linkpred.jinja-py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport dgl\nimport os\n\nfrom torch.utils.data import DataLoader\nfrom dgl.data import AsLinkPredDataset\n{{ data_import_code }}\n\n{{ node_model_code}}\n\n{{ edge_model_code }}\n\nclass Model(nn.Module):\n    def __init__(self, node_model, edge_model, neg_sampler, eval_batch_size):\n        super().__init__()\n        self.node_model = node_model\n        self.edge_model = edge_model\n        self.neg_sampler = neg_sampler\n        self.eval_batch_size = eval_batch_size\n\n    def inference(self, g, x, edges):\n        src, dst = edges\n        h = self.node_model(g, x)\n        eid_dataloader = DataLoader(\n            range(\n                src.shape[-1]),\n            batch_size=self.eval_batch_size)\n        score_list = []\n        for eids in eid_dataloader:\n            score = self.edge_model(h[src[eids]], h[dst[eids]])\n            score_list.append(score)\n        return torch.cat(score_list, dim=0)\n\ndef calc_hitsk(y_pred_pos, y_pred_neg, k):\n    kth_score_in_negative_edges = torch.topk(y_pred_neg.flatten(), k)[0][-1]\n    hitsK = (y_pred_pos > kth_score_in_negative_edges).float().mean()\n    return hitsK.item()\n\ndef train(cfg, pipeline_cfg, device, data, model, optimizer, loss_fcn):\n    train_g = data.train_graph\n    train_g = train_g.to(device)\n    node_feat = train_g.ndata['feat']\n    train_src, train_dst = train_g.edges()\n    for epoch in range(pipeline_cfg['num_epochs']):\n        model.train()\n        eid_dataloader = DataLoader(range(train_g.num_edges()), batch_size = pipeline_cfg[\"train_batch_size\"], shuffle=True)\n        for eids in eid_dataloader:\n            h = model.node_model(train_g, node_feat)\n            eids = eids.to(device)\n            src, dst = train_src[eids], train_dst[eids]\n            pos_score = model.edge_model(h[src], h[dst])\n            neg_src, neg_dst = model.neg_sampler(train_g, eids)\n            neg_score = model.edge_model(h[neg_src], h[neg_dst])\n            loss = loss_fcn(torch.cat([pos_score, neg_score]),  torch.cat(\n                [torch.ones_like(pos_score), torch.zeros_like(neg_score)]))\n\n            optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n            optimizer.step()\n        with torch.no_grad():\n            model.eval()\n            val_neg_edges = data.val_edges[1]\n            val_neg_score = model.inference(train_g, node_feat, val_neg_edges)\n        train_hits = calc_hitsk(pos_score, val_neg_score, k=50)\n        print(\"Epoch {:05d} | Loss {:.4f} | Train Hits@50 {:.4f}\".format(epoch, loss, train_hits))\n\n        if epoch != 0 and epoch % pipeline_cfg['eval_period'] == 0:\n            with torch.no_grad():\n                model.eval()\n                val_pos_edge, val_neg_edges = data.val_edges\n                pos_result = model.inference(train_g, node_feat, val_pos_edge)\n                neg_result = model.inference(train_g, node_feat, val_neg_edges)\n                val_hits = calc_hitsk(pos_result, neg_result, k=50)\n            print(\"Epoch {:05d} | Val Hits@50 {:.4f}\".format(epoch, val_hits))\n\n    with torch.no_grad():\n        model.eval()\n        test_pos_edge, test_neg_edges = data.test_edges\n        pos_result = model.inference(train_g, node_feat, test_pos_edge)\n        neg_result = model.inference(train_g, node_feat, test_neg_edges)\n        test_hits = calc_hitsk(pos_result, neg_result, k=50)\n        print(\"Test Hits@50 {:.4f}\".format(test_hits))\n    return test_hits\n\n\ndef main(run, cfg, data):\n    device = cfg['device']\n    pipeline_cfg = cfg['general_pipeline']\n    node_model = {{node_model_class_name}}(**cfg[\"node_model\"])\n    edge_model = {{edge_model_class_name}}(**cfg[\"edge_model\"])\n    neg_sampler = dgl.dataloading.negative_sampler.{{ neg_sampler_name }}(**cfg[\"neg_sampler\"])\n    model = Model(node_model, edge_model, neg_sampler, pipeline_cfg[\"eval_batch_size\"])\n    model = model.to(device)\n    loss = torch.nn.{{ loss }}()\n    optimizer = torch.optim.Adam(model.parameters(), **pipeline_cfg[\"optimizer\"])\n    test_hits = train(cfg, pipeline_cfg, device, data, model, optimizer, loss)\n\n    cpt_path = os.path.join(pipeline_cfg[\"save_path\"], 'run_{}.pth'.format(run))\n    torch.save({'cfg': cfg, 'model': model.state_dict()}, cpt_path)\n    print('Saved training checkpoint to {}'.format(cpt_path))\n\n    return test_hits\n\nif __name__ == '__main__':\n    {{user_cfg_str}}\n    if not torch.cuda.is_available():\n        cfg['device'] = 'cpu'\n\n    # load data\n    data = AsLinkPredDataset({{ data_initialize_code }})\n\n    nmodel_cfg = cfg[\"node_model\"]\n    pipeline_cfg = cfg['general_pipeline']\n    if 'feat' not in data[0].ndata:\n        assert nmodel_cfg[\"embed_size\"] > 0, \"Need to specify embed size if graph doesn't have feat in ndata\"\n    cfg[\"node_model\"][\"data_info\"] = {\n        \"in_size\": nmodel_cfg['embed_size'] if nmodel_cfg['embed_size'] > 0 else data[0].ndata['feat'].shape[1],\n        \"out_size\": pipeline_cfg['hidden_size'],\n        \"num_nodes\": data[0].num_nodes()\n    }\n    cfg[\"edge_model\"][\"data_info\"] = {\n        \"in_size\": pipeline_cfg['hidden_size'],\n        \"out_size\": 1 # output each edge score\n    }\n\n    os.makedirs(pipeline_cfg[\"save_path\"])\n\n    all_acc = []\n    num_runs = {{ user_cfg.general_pipeline.num_runs }}\n    for run in range(num_runs):\n        print(f'Run experiment #{run}')\n        test_acc = main(run, cfg, data)\n        print(\"Test Hits@50 {:.4f}\".format(test_acc))\n        all_acc.append(test_acc)\n    avg_acc = np.round(np.mean(all_acc), 6)\n    std_acc = np.round(np.std(all_acc), 6)\n    print(f'Test Hits@50 across {num_runs} runs: {avg_acc} ± {std_acc}')\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/nodepred/__init__.py",
    "content": "from .gen import *\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/nodepred/gen.py",
    "content": "import copy\nfrom pathlib import Path\nfrom typing import Optional\n\nimport ruamel.yaml\nimport typer\nimport yaml\nfrom jinja2 import Template\nfrom pydantic import BaseModel, Field\nfrom ruamel.yaml.comments import CommentedMap\n\nfrom ...utils.base_model import DeviceEnum, EarlyStopConfig\nfrom ...utils.factory import (\n    DataFactory,\n    NodeModelFactory,\n    PipelineBase,\n    PipelineFactory,\n)\nfrom ...utils.yaml_dump import deep_convert_dict, merge_comment\n\npipeline_comments = {\n    \"num_epochs\": \"Number of training epochs\",\n    \"eval_period\": \"Interval epochs between evaluations\",\n    \"early_stop\": {\n        \"patience\": \"Steps before early stop\",\n        \"checkpoint_path\": \"Early stop checkpoint model file path\",\n    },\n    \"save_path\": \"Directory to save the experiment results\",\n    \"num_runs\": \"Number of experiments to run\",\n}\n\n\nclass NodepredPipelineCfg(BaseModel):\n    early_stop: Optional[EarlyStopConfig] = EarlyStopConfig()\n    num_epochs: int = 200\n    eval_period: int = 5\n    optimizer: dict = {\"name\": \"Adam\", \"lr\": 0.01, \"weight_decay\": 5e-4}\n    loss: str = \"CrossEntropyLoss\"\n    save_path: str = \"results\"\n    num_runs: int = 1\n\n\n@PipelineFactory.register(\"nodepred\")\nclass NodepredPipeline(PipelineBase):\n\n    user_cfg_cls = None\n\n    def __init__(self):\n        self.pipeline = {\"name\": \"nodepred\", \"mode\": \"train\"}\n\n    @classmethod\n    def setup_user_cfg_cls(cls):\n        from ...utils.enter_config import UserConfig\n\n        class NodePredUserConfig(UserConfig):\n            data: DataFactory.filter(\"nodepred\").get_pydantic_config() = Field(\n                ..., discriminator=\"name\"\n            )\n            model: NodeModelFactory.get_pydantic_model_config() = Field(\n                ..., discriminator=\"name\"\n            )\n            general_pipeline: NodepredPipelineCfg = NodepredPipelineCfg()\n\n        cls.user_cfg_cls = 
NodePredUserConfig\n\n    @property\n    def user_cfg_cls(self):\n        return self.__class__.user_cfg_cls\n\n    def get_cfg_func(self):\n        def config(\n            data: DataFactory.filter(\n                \"nodepred\"\n            ).get_dataset_enum() = typer.Option(..., help=\"input data name\"),\n            cfg: Optional[str] = typer.Option(\n                None, help=\"output configuration path\"\n            ),\n            model: NodeModelFactory.get_model_enum() = typer.Option(\n                ..., help=\"Model name\"\n            ),\n        ):\n            self.__class__.setup_user_cfg_cls()\n            generated_cfg = {\n                \"pipeline_name\": self.pipeline[\"name\"],\n                \"pipeline_mode\": self.pipeline[\"mode\"],\n                \"device\": \"cpu\",\n                \"data\": {\"name\": data.name},\n                \"model\": {\"name\": model.value},\n                \"general_pipeline\": {},\n            }\n            output_cfg = self.user_cfg_cls(**generated_cfg).dict()\n            output_cfg = deep_convert_dict(output_cfg)\n            comment_dict = {\n                \"device\": \"Torch device name, e.g., cpu or cuda or cuda:0\",\n                \"data\": {\n                    \"split_ratio\": \"Ratio to generate split masks, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. 
Leave blank to use builtin split in original dataset\"\n                },\n                \"general_pipeline\": pipeline_comments,\n                \"model\": NodeModelFactory.get_constructor_doc_dict(model.value),\n            }\n            comment_dict = merge_comment(output_cfg, comment_dict)\n\n            yaml = ruamel.yaml.YAML()\n            if cfg is None:\n                cfg = \"_\".join([\"nodepred\", data.value, model.value]) + \".yaml\"\n            yaml.dump(comment_dict, Path(cfg).open(\"w\"))\n            print(\n                \"Configuration file is generated at {}\".format(\n                    Path(cfg).absolute()\n                )\n            )\n\n        return config\n\n    @classmethod\n    def gen_script(cls, user_cfg_dict):\n        # Check validation\n        cls.setup_user_cfg_cls()\n        user_cfg = cls.user_cfg_cls(**user_cfg_dict)\n        file_current_dir = Path(__file__).resolve().parent\n        with open(file_current_dir / \"nodepred.jinja-py\", \"r\") as f:\n            template = Template(f.read())\n\n        render_cfg = copy.deepcopy(user_cfg_dict)\n        model_code = NodeModelFactory.get_source_code(\n            user_cfg_dict[\"model\"][\"name\"]\n        )\n        render_cfg[\"model_code\"] = model_code\n        render_cfg[\"model_class_name\"] = NodeModelFactory.get_model_class_name(\n            user_cfg_dict[\"model\"][\"name\"]\n        )\n        render_cfg.update(\n            DataFactory.get_generated_code_dict(\n                user_cfg_dict[\"data\"][\"name\"], '**cfg[\"data\"]'\n            )\n        )\n\n        generated_user_cfg = copy.deepcopy(user_cfg_dict)\n        if \"split_ratio\" in generated_user_cfg[\"data\"]:\n            generated_user_cfg[\"data\"].pop(\"split_ratio\")\n        generated_user_cfg[\"data_name\"] = generated_user_cfg[\"data\"].pop(\"name\")\n        generated_user_cfg.pop(\"pipeline_name\")\n        generated_user_cfg.pop(\"pipeline_mode\")\n        
generated_user_cfg[\"model_name\"] = generated_user_cfg[\"model\"].pop(\n            \"name\"\n        )\n        generated_user_cfg[\"general_pipeline\"][\"optimizer\"].pop(\"name\")\n\n        generated_train_cfg = copy.deepcopy(user_cfg_dict[\"general_pipeline\"])\n        generated_train_cfg[\"optimizer\"].pop(\"name\")\n\n        if user_cfg_dict[\"data\"].get(\"split_ratio\", None) is not None:\n            render_cfg[\"data_initialize_code\"] = \"{}, split_ratio={}\".format(\n                render_cfg[\"data_initialize_code\"],\n                user_cfg_dict[\"data\"][\"split_ratio\"],\n            )\n        render_cfg[\"user_cfg_str\"] = f\"cfg = {str(generated_user_cfg)}\"\n        render_cfg[\"user_cfg\"] = user_cfg_dict\n        return template.render(**render_cfg)\n\n    @staticmethod\n    def get_description() -> str:\n        return \"Node classification pipeline for training\"\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/nodepred/nodepred.jinja-py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport dgl\nimport os\n\nfrom dgl.data import AsNodePredDataset\n{{ data_import_code }}\n\n{{ model_code }}\n\n{% if user_cfg.general_pipeline.early_stop %}\nclass EarlyStopping:\n    def __init__(self,\n                 patience: int = -1,\n                 checkpoint_path: str = 'checkpoint.pth'):\n        self.patience = patience\n        self.checkpoint_path = checkpoint_path\n        self.counter = 0\n        self.best_score = None\n        self.early_stop = False\n\n    def step(self, acc, model):\n        score = acc\n        if self.best_score is None:\n            self.best_score = score\n            self.save_checkpoint(model)\n        elif score < self.best_score:\n            self.counter += 1\n            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n            if self.counter >= self.patience:\n                self.early_stop = True\n        else:\n            self.best_score = score\n            self.save_checkpoint(model)\n            self.counter = 0\n        return self.early_stop\n\n    def save_checkpoint(self, model):\n        '''Save model when validation loss decreases.'''\n        torch.save(model.state_dict(), self.checkpoint_path)\n\n    def load_checkpoint(self, model):\n        model.load_state_dict(torch.load(self.checkpoint_path, weights_only=False))\n\n    def close(self):\n        os.remove(self.checkpoint_path)\n{% endif %}\n\n\ndef accuracy(logits, labels):\n    _, indices = torch.max(logits, dim=1)\n    correct = torch.sum(indices == labels)\n    return correct.item() * 1.0 / len(labels)\n\ndef train(cfg, pipeline_cfg, device, data, model, optimizer, loss_fcn):\n    g = data[0]  # Only train on the first graph\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n    g = g.to(device)\n\n    node_feat = g.ndata.get('feat', None)\n    edge_feat = g.edata.get('feat', None)\n    label = 
g.ndata['label']\n    train_mask, val_mask, test_mask = g.ndata['train_mask'].bool(), g.ndata['val_mask'].bool(), g.ndata['test_mask'].bool()\n\n    {% if user_cfg.general_pipeline.early_stop %}\n    stopper = EarlyStopping(**pipeline_cfg['early_stop'])\n    {% endif %}\n    val_acc = 0.\n    for epoch in range(pipeline_cfg['num_epochs']):\n        model.train()\n        logits = model(g, node_feat, edge_feat)\n        loss = loss_fcn(logits[train_mask], label[train_mask])\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        train_acc = accuracy(logits[train_mask], label[train_mask])\n        if epoch != 0 and epoch % pipeline_cfg['eval_period'] == 0:\n            val_acc = accuracy(logits[val_mask], label[val_mask])\n\n            {% if user_cfg.general_pipeline.early_stop %}\n            if stopper.step(val_acc, model):\n                break\n            {% endif %}\n        print(\"Epoch {:05d} | Loss {:.4f} | TrainAcc {:.4f} | ValAcc {:.4f}\".\n              format(epoch, loss.item(), train_acc, val_acc))\n\n    {% if user_cfg.general_pipeline.early_stop %}\n    stopper.load_checkpoint(model)\n    stopper.close()\n    {% endif %}\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, node_feat, edge_feat)\n        test_acc = accuracy(logits[test_mask], label[test_mask])\n    return test_acc\n\ndef main(run, cfg, data):\n    device = cfg['device']\n    pipeline_cfg = cfg['general_pipeline']\n    model = {{ model_class_name }}(**cfg[\"model\"])\n    model = model.to(device)\n    loss = torch.nn.{{ user_cfg.general_pipeline.loss }}()\n    optimizer = torch.optim.{{ user_cfg.general_pipeline.optimizer.name }}(model.parameters(), **pipeline_cfg[\"optimizer\"])\n    test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss)\n    cpt_path = os.path.join(pipeline_cfg[\"save_path\"], 'run_{}.pth'.format(run))\n    torch.save({'cfg': cfg, 'model': model.state_dict()}, cpt_path)\n    print('Saved 
training checkpoint to {}'.format(cpt_path))\n\n    return test_acc\n\nif __name__ == '__main__':\n    {{ user_cfg_str }}\n    if not torch.cuda.is_available():\n        cfg['device'] = 'cpu'\n\n    # load data\n    data = AsNodePredDataset({{data_initialize_code}})\n\n    model_cfg = cfg[\"model\"]\n    cfg[\"model\"][\"data_info\"] = {\n        \"in_size\": model_cfg['embed_size'] if model_cfg['embed_size'] > 0 else data[0].ndata['feat'].shape[1],\n        \"out_size\": data.num_classes,\n        \"num_nodes\": data[0].num_nodes()\n    }\n\n    os.makedirs(cfg['general_pipeline'][\"save_path\"])\n\n    all_acc = []\n    num_runs = {{ user_cfg.general_pipeline.num_runs }}\n    for run in range(num_runs):\n        print(f'Run experiment #{run}')\n        test_acc = main(run, cfg, data)\n        print(\"Test Accuracy {:.4f}\".format(test_acc))\n        all_acc.append(test_acc)\n    avg_acc = np.round(np.mean(all_acc), 6)\n    std_acc = np.round(np.std(all_acc), 6)\n    print(f'Accuracy across {num_runs} runs: {avg_acc} ± {std_acc}')\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/nodepred_sample/__init__.py",
    "content": "from .gen import *\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/nodepred_sample/gen.py",
    "content": "import copy\nfrom enum import Enum\nfrom pathlib import Path\nfrom typing import List, Optional, Union\n\nimport ruamel.yaml\n\nimport typer\nimport yaml\nfrom jinja2 import ext, Template\nfrom pydantic import BaseModel, Field\nfrom ruamel.yaml.comments import CommentedMap\nfrom typing_extensions import Literal\n\nfrom ...utils.base_model import DeviceEnum, EarlyStopConfig, extract_name\nfrom ...utils.factory import (\n    DataFactory,\n    NodeModelFactory,\n    PipelineBase,\n    PipelineFactory,\n)\n\nfrom ...utils.yaml_dump import deep_convert_dict, merge_comment\n\n\nclass SamplerConfig(BaseModel):\n    name: Literal[\"neighbor\"]\n    fan_out: List[int] = [5, 10]\n    batch_size: int = Field(64, description=\"Batch size\")\n    num_workers: int = 4\n    eval_batch_size: int = 1024\n    eval_num_workers: int = 4\n\n    class Config:\n        extra = \"forbid\"\n\n\npipeline_comments = {\n    \"num_epochs\": \"Number of training epochs\",\n    \"eval_period\": \"Interval epochs between evaluations\",\n    \"early_stop\": {\n        \"patience\": \"Steps before early stop\",\n        \"checkpoint_path\": \"Early stop checkpoint model file path\",\n    },\n    \"sampler\": {\n        \"fan_out\": \"List of neighbors to sample per edge type for each GNN layer, with the i-th element being the fanout for the i-th GNN layer. 
Length should be the same as num_layers in model setting\",\n        \"batch_size\": \"Batch size of seed nodes in training stage\",\n        \"num_workers\": \"Number of workers to accelerate the graph data processing step\",\n        \"eval_batch_size\": \"Batch size of seed nodes in evaluation stage\",\n        \"eval_num_workers\": \"Number of workers to accelerate the graph data processing step in evaluation stage\",\n    },\n    \"save_path\": \"Directory to save the experiment results\",\n    \"num_runs\": \"Number of experiments to run\",\n}\n\n\nclass NodepredNSPipelineCfg(BaseModel):\n    sampler: SamplerConfig = Field(\"neighbor\")\n    early_stop: Optional[EarlyStopConfig] = EarlyStopConfig()\n    num_epochs: int = 200\n    eval_period: int = 5\n    optimizer: dict = {\"name\": \"Adam\", \"lr\": 0.005, \"weight_decay\": 0.0}\n    loss: str = \"CrossEntropyLoss\"\n    num_runs: int = 1\n    save_path: str = \"results\"\n\n\n@PipelineFactory.register(\"nodepred-ns\")\nclass NodepredNsPipeline(PipelineBase):\n    def __init__(self):\n        self.pipeline = {\"name\": \"nodepred-ns\", \"mode\": \"train\"}\n        self.default_cfg = None\n\n    @classmethod\n    def setup_user_cfg_cls(cls):\n        from ...utils.enter_config import UserConfig\n\n        class NodePredUserConfig(UserConfig):\n            eval_device: DeviceEnum = Field(\"cpu\")\n            data: DataFactory.filter(\n                \"nodepred-ns\"\n            ).get_pydantic_config() = Field(..., discriminator=\"name\")\n            model: NodeModelFactory.filter(\n                lambda cls: hasattr(cls, \"forward_block\")\n            ).get_pydantic_model_config() = Field(..., discriminator=\"name\")\n            general_pipeline: NodepredNSPipelineCfg\n\n        cls.user_cfg_cls = NodePredUserConfig\n\n    @property\n    def user_cfg_cls(self):\n        return self.__class__.user_cfg_cls\n\n    def get_cfg_func(self):\n        def config(\n            data: 
DataFactory.filter(\n                \"nodepred-ns\"\n            ).get_dataset_enum() = typer.Option(..., help=\"input data name\"),\n            cfg: Optional[str] = typer.Option(\n                None, help=\"output configuration path\"\n            ),\n            model: NodeModelFactory.filter(\n                lambda cls: hasattr(cls, \"forward_block\")\n            ).get_model_enum() = typer.Option(..., help=\"Model name\"),\n        ):\n            self.__class__.setup_user_cfg_cls()\n            generated_cfg = {\n                \"pipeline_name\": self.pipeline[\"name\"],\n                \"pipeline_mode\": self.pipeline[\"mode\"],\n                \"device\": \"cpu\",\n                \"data\": {\"name\": data.name},\n                \"model\": {\"name\": model.value},\n                \"general_pipeline\": {\"sampler\": {\"name\": \"neighbor\"}},\n            }\n            output_cfg = self.user_cfg_cls(**generated_cfg).dict()\n            output_cfg = deep_convert_dict(output_cfg)\n            comment_dict = {\n                \"device\": \"Torch device name, e.g., cpu or cuda or cuda:0\",\n                \"data\": {\n                    \"split_ratio\": \"Ratio to generate split masks, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. 
Leave blank to use builtin split in original dataset\"\n                },\n                \"general_pipeline\": pipeline_comments,\n                \"model\": NodeModelFactory.get_constructor_doc_dict(model.value),\n            }\n            comment_dict = merge_comment(output_cfg, comment_dict)\n\n            # truncate length fan_out to be the same as num_layers in model\n            if \"num_layers\" in comment_dict[\"model\"]:\n                comment_dict[\"general_pipeline\"][\"sampler\"][\"fan_out\"] = [\n                    5,\n                    10,\n                    15,\n                    15,\n                    15,\n                ][: int(comment_dict[\"model\"][\"num_layers\"])]\n\n            if cfg is None:\n                cfg = (\n                    \"_\".join([\"nodepred-ns\", data.value, model.value]) + \".yaml\"\n                )\n            yaml = ruamel.yaml.YAML()\n            yaml.dump(comment_dict, Path(cfg).open(\"w\"))\n            print(\n                \"Configuration file is generated at {}\".format(\n                    Path(cfg).absolute()\n                )\n            )\n\n        return config\n\n    @staticmethod\n    def gen_script(user_cfg_dict):\n        file_current_dir = Path(__file__).resolve().parent\n        template_filename = file_current_dir / \"nodepred-ns.jinja-py\"\n        with open(template_filename, \"r\") as f:\n            template = Template(f.read())\n        pipeline_cfg = NodepredNSPipelineCfg(\n            **user_cfg_dict[\"general_pipeline\"]\n        )\n\n        if \"num_layers\" in user_cfg_dict[\"model\"]:\n            assert user_cfg_dict[\"model\"][\"num_layers\"] == len(\n                user_cfg_dict[\"general_pipeline\"][\"sampler\"][\"fan_out\"]\n            ), \"The num_layers in model config should be the same as the length of fan_out in sampler. 
For example, if num_layers is 1, the fan_out cannot be [5, 10]\"\n\n        render_cfg = copy.deepcopy(user_cfg_dict)\n        model_code = NodeModelFactory.get_source_code(\n            user_cfg_dict[\"model\"][\"name\"]\n        )\n        render_cfg[\"model_code\"] = model_code\n        render_cfg[\"model_class_name\"] = NodeModelFactory.get_model_class_name(\n            user_cfg_dict[\"model\"][\"name\"]\n        )\n        render_cfg.update(\n            DataFactory.get_generated_code_dict(\n                user_cfg_dict[\"data\"][\"name\"], '**cfg[\"data\"]'\n            )\n        )\n        generated_user_cfg = copy.deepcopy(user_cfg_dict)\n\n        if \"split_ratio\" in generated_user_cfg[\"data\"]:\n            generated_user_cfg[\"data\"].pop(\"split_ratio\")\n        generated_user_cfg[\"data_name\"] = generated_user_cfg[\"data\"].pop(\"name\")\n        generated_user_cfg.pop(\"pipeline_name\")\n        generated_user_cfg.pop(\"pipeline_mode\")\n        generated_user_cfg[\"model_name\"] = generated_user_cfg[\"model\"].pop(\n            \"name\"\n        )\n        generated_user_cfg[\"general_pipeline\"][\"optimizer\"].pop(\"name\")\n\n        if user_cfg_dict[\"data\"].get(\"split_ratio\", None) is not None:\n            render_cfg[\"data_initialize_code\"] = \"{}, split_ratio={}\".format(\n                render_cfg[\"data_initialize_code\"],\n                user_cfg_dict[\"data\"][\"split_ratio\"],\n            )\n\n        render_cfg[\"user_cfg_str\"] = f\"cfg = {str(generated_user_cfg)}\"\n        render_cfg[\"user_cfg\"] = user_cfg_dict\n        with open(\"output.py\", \"w\") as f:\n            return template.render(**render_cfg)\n\n    @staticmethod\n    def get_description() -> str:\n        return \"Node classification neighbor sampling pipeline for training\"\n"
  },
  {
    "path": "dglgo/dglgo/pipeline/nodepred_sample/nodepred-ns.jinja-py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport dgl\nimport os\n\nfrom dgl.data import AsNodePredDataset\n{{ data_import_code }}\n\n{{ model_code }}\n\n{% if user_cfg.early_stop %}\nclass EarlyStopping:\n    def __init__(self,\n                 patience: int = -1,\n                 checkpoint_path: str = 'checkpoint.pth'):\n        self.patience = patience\n        self.checkpoint_path = checkpoint_path\n        self.counter = 0\n        self.best_score = None\n        self.early_stop = False\n\n    def step(self, acc, model):\n        score = acc\n        if self.best_score is None:\n            self.best_score = score\n            self.save_checkpoint(model)\n        elif score < self.best_score:\n            self.counter += 1\n            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n            if self.counter >= self.patience:\n                self.early_stop = True\n        else:\n            self.best_score = score\n            self.save_checkpoint(model)\n            self.counter = 0\n        return self.early_stop\n\n    def save_checkpoint(self, model):\n        '''Save model when validation loss decreases.'''\n        torch.save(model.state_dict(), self.checkpoint_path)\n\n    def load_checkpoint(self, model):\n        model.load_state_dict(torch.load(self.checkpoint_path, weights_only=False))\n\n    def close(self):\n        os.remove(self.checkpoint_path)\n{% endif %}\n\n\ndef load_subtensor(nfeat, labels, seeds, input_nodes, device):\n    \"\"\"\n    Extracts features and labels for a subset of nodes\n    \"\"\"\n    batch_inputs = nfeat[input_nodes].to(device)\n    batch_labels = labels[seeds].to(device)\n    return batch_inputs, batch_labels\n\ndef evaluate(model, g, nfeat, labels, val_nid, eval_device):\n    \"\"\"\n    Evaluate the model on the validation set specified by ``val_nid``.\n    g : The entire graph.\n    inputs : The features of all the nodes.\n  
  labels : The labels of all the nodes.\n    val_nid : the node Ids for validation.\n    device : The GPU device to evaluate on.\n    \"\"\"\n    model.eval()\n    eval_model = model.to(eval_device)\n    g = g.to(eval_device)\n    nfeat = nfeat.to(eval_device)\n    with torch.no_grad():\n        y = eval_model(g, nfeat)\n    model.train()\n    return accuracy(y[val_nid], labels[val_nid].to(y.device))\n\ndef accuracy(logits, labels):\n    _, indices = torch.max(logits, dim=1)\n    correct = torch.sum(indices == labels)\n    return correct.item() * 1.0 / len(labels)\n\ndef train(cfg, pipeline_cfg, device, data, model, optimizer, loss_fcn):\n    g = data[0]  # Only train on the first graph\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n    train_g = val_g = test_g = g\n\n    train_nfeat = val_nfeat = test_nfeat = train_g.ndata['feat']\n    train_labels = val_labels = test_labels = train_g.ndata['label']\n\n    train_nid = torch.nonzero(train_g.ndata['train_mask'], as_tuple=True)[0]\n    val_nid = torch.nonzero(val_g.ndata['val_mask'], as_tuple=True)[0]\n    test_nid = torch.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]\n\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [int(fanout) for fanout in pipeline_cfg[\"sampler\"][\"fan_out\"]])\n    dataloader = dgl.dataloading.NodeDataLoader(\n        train_g,\n        train_nid,\n        sampler,\n        device=device,\n        batch_size=pipeline_cfg[\"sampler\"][\"batch_size\"],\n        shuffle=True,\n        drop_last=False,\n        num_workers=pipeline_cfg[\"sampler\"][\"num_workers\"])\n\n    {% if user_cfg.early_stop %}\n    stopper = EarlyStopping(pipeline_cfg['patience'], pipeline_cfg['checkpoint_path'])\n    {% endif %}\n    val_acc = 0.\n    for epoch in range(pipeline_cfg['num_epochs']):\n        model.train()\n        model = model.to(device)\n        for step, (input_nodes, seeds, subgs) in enumerate(dataloader):\n            # Load the 
input features as well as output labels\n            batch_inputs, batch_labels = load_subtensor(train_nfeat, train_labels,\n                                                        seeds, input_nodes, device)\n            subgs = [subg.int().to(device) for subg in subgs]\n            batch_pred = model.forward_block(subgs, batch_inputs)\n            loss = loss_fcn(batch_pred, batch_labels)\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            train_acc = accuracy(batch_pred, batch_labels)\n            print(\"Epoch {:05d} | Step {:05d} | Loss {:.4f} | TrainAcc {:.4f}\".\n                format(epoch, step, loss.item(), train_acc))\n\n        if epoch % pipeline_cfg[\"eval_period\"] == 0 and epoch != 0:\n            val_acc = evaluate(model, val_g, val_nfeat, val_labels, val_nid, cfg[\"eval_device\"])\n            print('Eval Acc {:.4f}'.format(val_acc))\n        {% if user_cfg.early_stop %}\n        if stopper.step(val_acc, model):\n            break\n        {% endif %}\n\n    {% if user_cfg.early_stop %}\n    stopper.load_checkpoint(model)\n    stopper.close()\n    {% endif %}\n    model.eval()\n    with torch.no_grad():\n        test_acc = evaluate(model, test_g, test_nfeat, test_labels, test_nid, cfg[\"eval_device\"])\n    return test_acc\n\ndef main(run, cfg, data):\n    device = cfg['device']\n    pipeline_cfg = cfg[\"general_pipeline\"]\n    model = {{ model_class_name }}(**cfg[\"model\"])\n    model = model.to(device)\n    loss = torch.nn.{{ user_cfg.general_pipeline.loss }}()\n    optimizer = torch.optim.{{ user_cfg.general_pipeline.optimizer.name }}(model.parameters(), **pipeline_cfg[\"optimizer\"])\n    test_acc = train(cfg, pipeline_cfg, device, data, model, optimizer, loss)\n\n    cpt_path = os.path.join(pipeline_cfg[\"save_path\"], 'run_{}.pth'.format(run))\n    torch.save({'cfg': cfg, 'model': model.state_dict()}, cpt_path)\n    print('Saved training checkpoint to {}'.format(cpt_path))\n\n    
return test_acc\n\nif __name__ == '__main__':\n    {{ user_cfg_str }}\n    if not torch.cuda.is_available():\n        cfg['device'] = 'cpu'\n\n    # load data\n    data = AsNodePredDataset({{data_initialize_code}})\n\n    model_cfg = cfg[\"model\"]\n    cfg[\"model\"][\"data_info\"] = {\n        \"in_size\": model_cfg['embed_size'] if model_cfg['embed_size'] > 0 else data[0].ndata['feat'].shape[1],\n        \"out_size\": data.num_classes,\n        \"num_nodes\": data[0].num_nodes()\n    }\n\n    os.makedirs(cfg['general_pipeline'][\"save_path\"])\n\n    all_acc = []\n    num_runs = {{ user_cfg.general_pipeline.num_runs }}\n    for run in range(num_runs):\n        print(f'Run experiment #{run}')\n        test_acc = main(run, cfg, data)\n        print(\"Test Accuracy {:.4f}\".format(test_acc))\n        all_acc.append(test_acc)\n    avg_acc = np.round(np.mean(all_acc), 6)\n    std_acc = np.round(np.std(all_acc), 6)\n    print(f'Accuracy across {num_runs} runs: {avg_acc} ± {std_acc}')\n"
  },
  {
    "path": "dglgo/dglgo/utils/__init__.py",
    "content": "from .factory import *\n"
  },
  {
    "path": "dglgo/dglgo/utils/base_model.py",
    "content": "import copy\nimport enum\nfrom enum import Enum, IntEnum\nfrom typing import Optional\n\nfrom jinja2 import Template\nfrom pydantic import (\n    BaseModel as PydanticBaseModel,\n    create_model,\n    create_model,\n    Field,\n)\n\n\nclass DeviceEnum(str, Enum):\n    cpu = \"cpu\"\n    cuda = \"cuda\"\n\n\nclass DGLBaseModel(PydanticBaseModel):\n    class Config:\n        extra = \"allow\"\n        use_enum_values = True\n\n    @classmethod\n    def with_fields(cls, model_name, **field_definitions):\n        return create_model(model_name, __base__=cls, **field_definitions)\n\n\ndef get_literal_value(type_):\n    if hasattr(type_, \"__values__\"):\n        name = type_.__values__[0]\n    elif hasattr(type_, \"__args__\"):\n        name = type_.__args__[0]\n    return name\n\n\ndef extract_name(union_type):\n    name_dict = {}\n    for t in union_type.__args__:\n        type_ = t.__fields__[\"name\"].type_\n        name = get_literal_value(type_)\n        name_dict[name] = name\n    return enum.Enum(\"Choice\", name_dict)\n\n\nclass EarlyStopConfig(DGLBaseModel):\n    patience: int = 20\n    checkpoint_path: str = \"checkpoint.pth\"\n"
  },
  {
    "path": "dglgo/dglgo/utils/early_stop.py",
    "content": "import torch\n\n\nclass EarlyStopping:\n    def __init__(\n        self, patience: int = -1, checkpoint_path: str = \"checkpoint.pth\"\n    ):\n        self.patience = patience\n        self.checkpoint_path = checkpoint_path\n        self.counter = 0\n        self.best_score = None\n        self.early_stop = False\n\n    def step(self, acc, model):\n        score = acc\n        if self.best_score is None:\n            self.best_score = score\n            self.save_checkpoint(model)\n        elif score < self.best_score:\n            self.counter += 1\n            print(\n                f\"EarlyStopping counter: {self.counter} out of {self.patience}\"\n            )\n            if self.counter >= self.patience:\n                self.early_stop = True\n        else:\n            self.best_score = score\n            self.save_checkpoint(model)\n            self.counter = 0\n        return self.early_stop\n\n    def save_checkpoint(self, model):\n        \"\"\"Save the model state; step() calls this when the score is not below the best so far.\"\"\"\n        torch.save(model.state_dict(), self.checkpoint_path)\n\n    def load_checkpoint(self, model):\n        model.load_state_dict(\n            torch.load(self.checkpoint_path, weights_only=False)\n        )\n"
  },
  {
    "path": "dglgo/dglgo/utils/enter_config.py",
    "content": "import copy\nfrom enum import Enum, IntEnum\nfrom typing import Optional\n\nimport jinja2\nimport yaml\nfrom jinja2 import Template\nfrom pydantic import BaseModel as PydanticBaseModel, create_model, Field\n\nfrom .base_model import DGLBaseModel\n\n# from ..pipeline import nodepred, nodepred_sample\nfrom .factory import DataFactory, ModelFactory, PipelineFactory\n\n\nclass PipelineConfig(DGLBaseModel):\n    node_embed_size: Optional[int] = -1\n    early_stop: Optional[dict]\n    num_epochs: int = 200\n    eval_period: int = 5\n    optimizer: dict = {\"name\": \"Adam\", \"lr\": 0.005}\n    loss: str = \"CrossEntropyLoss\"\n\n\nclass UserConfig(DGLBaseModel):\n    version: Optional[str] = \"0.0.2\"\n    pipeline_name: PipelineFactory.get_pipeline_enum()\n    pipeline_mode: str\n    device: str = \"cpu\"\n"
  },
  {
    "path": "dglgo/dglgo/utils/factory.py",
    "content": "import enum\nimport inspect\nimport logging\nfrom abc import ABC, abstractmethod, abstractstaticmethod\nfrom pathlib import Path\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nimport yaml\nfrom dgl.dataloading.negative_sampler import GlobalUniform, PerSourceUniform\nfrom numpydoc import docscrape\nfrom pydantic import create_model, create_model_from_typeddict, Field\nfrom typing_extensions import Literal\n\nfrom .base_model import DGLBaseModel\n\nlogger = logging.getLogger(__name__)\n\nALL_PIPELINE = [\"nodepred\", \"nodepred-ns\", \"linkpred\", \"graphpred\"]\n\n\nclass PipelineBase(ABC):\n    @abstractmethod\n    def __init__(self) -> None:\n        super().__init__()\n\n    @abstractmethod\n    def get_cfg_func(self):\n        pass\n\n    @abstractstaticmethod\n    def gen_script(user_cfg_dict: dict):\n        pass\n\n    @abstractstaticmethod\n    def get_description() -> str:\n        pass\n\n\nclass DataFactoryClass:\n    def __init__(self):\n        self.registry = {}\n        self.pipeline_name = None\n        self.pipeline_allowed = {}\n\n    def register(\n        self,\n        name: str,\n        import_code: str,\n        class_name: str,\n        allowed_pipeline: List[str],\n        extra_args={},\n    ):\n        self.registry[name] = {\n            \"name\": name,\n            \"import_code\": import_code,\n            \"class_name\": class_name,\n            \"extra_args\": extra_args,\n        }\n        for pipeline in allowed_pipeline:\n            if pipeline in self.pipeline_allowed:\n                self.pipeline_allowed[pipeline].append(name)\n            else:\n                self.pipeline_allowed[pipeline] = [name]\n        return self\n\n    def get_dataset_enum(self):\n        enum_class = enum.Enum(\n            \"DatasetName\", {v[\"name\"]: k for k, v in self.registry.items()}\n        )\n        return enum_class\n\n    def get_dataset_classname(self, name):\n        return 
self.registry[name][\"class_name\"]\n\n    def get_constructor_arg_type(self, model_name):\n        sigs = inspect.signature(self.registry[model_name].__init__)\n        type_annotation_dict = {}\n        for k, param in dict(sigs.parameters).items():\n            type_annotation_dict[k] = param.annotation\n        return type_annotation_dict\n\n    def get_pydantic_config(self):\n\n        type_annotation_dict = {}\n        dataset_list = []\n        for k, v in self.registry.items():\n            dataset_name = v[\"name\"]\n            type_annotation_dict = v[\"extra_args\"]\n            if \"name\" in type_annotation_dict:\n                del type_annotation_dict[\"name\"]\n            base = self.get_base_class(dataset_name, self.pipeline_name)\n            dataset_list.append(\n                create_model(\n                    f\"{dataset_name}Config\",\n                    **type_annotation_dict,\n                    __base__=base,\n                )\n            )\n\n        output = dataset_list[0]\n        for d in dataset_list[1:]:\n            output = Union[output, d]\n        return output\n\n    def get_import_code(self, name):\n        return self.registry[name][\"import_code\"]\n\n    def get_import_code(self, name):\n        return self.registry[name][\"import_code\"]\n\n    def get_extra_args(self, name):\n        return self.registry[name][\"extra_args\"]\n\n    def get_class_name(self, name):\n        return self.registry[name][\"class_name\"]\n\n    def get_generated_code_dict(self, name, args='**cfg[\"data\"]'):\n        d = {}\n        d[\"data_import_code\"] = self.registry[name][\"import_code\"]\n        data_initialize_code = self.registry[name][\"class_name\"]\n        extra_args_dict = self.registry[name][\"extra_args\"]\n        if len(extra_args_dict) > 0:\n            data_initialize_code = data_initialize_code.format('**cfg[\"data\"]')\n        d[\"data_initialize_code\"] = data_initialize_code\n        return d\n\n    def 
filter(self, pipeline_name):\n        allowed_name = self.pipeline_allowed[pipeline_name]\n        new_registry = {\n            k: v for k, v in self.registry.items() if k in allowed_name\n        }\n        d = DataFactoryClass()\n        d.registry = new_registry\n        d.pipeline_name = pipeline_name\n        return d\n\n    @staticmethod\n    def get_base_class(dataset_name, pipeline_name):\n        if pipeline_name == \"linkpred\":\n\n            class EdgeBase(DGLBaseModel):\n                name: Literal[dataset_name]\n                split_ratio: Optional[Tuple[float, float, float]] = None\n                neg_ratio: Optional[int] = None\n\n            return EdgeBase\n        else:\n\n            class NodeBase(DGLBaseModel):\n                name: Literal[dataset_name]\n                split_ratio: Optional[Tuple[float, float, float]] = None\n\n            return NodeBase\n\n\nDataFactory = DataFactoryClass()\n\nDataFactory.register(\n    \"cora\",\n    import_code=\"from dgl.data import CoraGraphDataset\",\n    class_name=\"CoraGraphDataset()\",\n    allowed_pipeline=[\"nodepred\", \"nodepred-ns\", \"linkpred\"],\n)\n\nDataFactory.register(\n    \"citeseer\",\n    import_code=\"from dgl.data import CiteseerGraphDataset\",\n    class_name=\"CiteseerGraphDataset()\",\n    allowed_pipeline=[\"nodepred\", \"nodepred-ns\", \"linkpred\"],\n)\n\nDataFactory.register(\n    \"pubmed\",\n    import_code=\"from dgl.data import PubmedGraphDataset\",\n    class_name=\"PubmedGraphDataset()\",\n    allowed_pipeline=[\"nodepred\", \"nodepred-ns\", \"linkpred\"],\n)\n\nDataFactory.register(\n    \"csv\",\n    import_code=\"from dgl.data import CSVDataset\",\n    extra_args={\"data_path\": \"./\"},\n    class_name=\"CSVDataset({})\",\n    allowed_pipeline=[\"nodepred\", \"nodepred-ns\", \"linkpred\", \"graphpred\"],\n)\n\nDataFactory.register(\n    \"reddit\",\n    import_code=\"from dgl.data import RedditDataset\",\n    class_name=\"RedditDataset()\",\n    
allowed_pipeline=[\"nodepred\", \"nodepred-ns\", \"linkpred\"],\n)\n\nDataFactory.register(\n    \"co-buy-computer\",\n    import_code=\"from dgl.data import AmazonCoBuyComputerDataset\",\n    class_name=\"AmazonCoBuyComputerDataset()\",\n    allowed_pipeline=[\"nodepred\", \"nodepred-ns\", \"linkpred\"],\n)\n\nDataFactory.register(\n    \"ogbn-arxiv\",\n    import_code=\"from ogb.nodeproppred import DglNodePropPredDataset\",\n    extra_args={},\n    class_name=\"DglNodePropPredDataset('ogbn-arxiv')\",\n    allowed_pipeline=[\"nodepred\", \"nodepred-ns\", \"linkpred\"],\n)\n\nDataFactory.register(\n    \"ogbn-products\",\n    import_code=\"from ogb.nodeproppred import DglNodePropPredDataset\",\n    extra_args={},\n    class_name=\"DglNodePropPredDataset('ogbn-products')\",\n    allowed_pipeline=[\"nodepred\", \"nodepred-ns\", \"linkpred\"],\n)\n\nDataFactory.register(\n    \"ogbl-collab\",\n    import_code=\"from ogb.linkproppred import DglLinkPropPredDataset\",\n    extra_args={},\n    class_name=\"DglLinkPropPredDataset('ogbl-collab')\",\n    allowed_pipeline=[\"linkpred\"],\n)\n\nDataFactory.register(\n    \"ogbl-citation2\",\n    import_code=\"from ogb.linkproppred import DglLinkPropPredDataset\",\n    extra_args={},\n    class_name=\"DglLinkPropPredDataset('ogbl-citation2')\",\n    allowed_pipeline=[\"linkpred\"],\n)\n\nDataFactory.register(\n    \"ogbg-molhiv\",\n    import_code=\"from ogb.graphproppred import DglGraphPropPredDataset\",\n    extra_args={},\n    class_name=\"DglGraphPropPredDataset(name='ogbg-molhiv')\",\n    allowed_pipeline=[\"graphpred\"],\n)\n\nDataFactory.register(\n    \"ogbg-molpcba\",\n    import_code=\"from ogb.graphproppred import DglGraphPropPredDataset\",\n    extra_args={},\n    class_name=\"DglGraphPropPredDataset(name='ogbg-molpcba')\",\n    allowed_pipeline=[\"graphpred\"],\n)\n\n\nclass PipelineFactory:\n    \"\"\"The factory class for creating executors\"\"\"\n\n    registry: Dict[str, PipelineBase] = {}\n    
default_config_registry = {}\n    \"\"\" Internal registry for available executors \"\"\"\n\n    @classmethod\n    def register(cls, name: str) -> Callable:\n        def inner_wrapper(wrapped_class) -> Callable:\n            if name in cls.registry:\n                logger.warning(\n                    \"Executor %s already exists. Will replace it\", name\n                )\n            cls.registry[name] = wrapped_class()\n            return wrapped_class\n\n        return inner_wrapper\n\n    @classmethod\n    def register_default_config_generator(cls, name: str) -> Callable:\n        def inner_wrapper(wrapped_class) -> Callable:\n            if name in cls.registry:\n                logger.warning(\n                    \"Executor %s already exists. Will replace it\", name\n                )\n            cls.default_config_registry[name] = wrapped_class\n            return wrapped_class\n\n        return inner_wrapper\n\n    @classmethod\n    def call_default_config_generator(\n        cls, generator_name, model_name, dataset_name\n    ):\n        return cls.default_config_registry[generator_name](\n            model_name, dataset_name\n        )\n\n    @classmethod\n    def call_generator(cls, generator_name, cfg):\n        return cls.registry[generator_name](cfg)\n\n    @classmethod\n    def get_pipeline_enum(cls):\n        enum_class = enum.Enum(\n            \"PipelineName\", {k: k for k, v in cls.registry.items()}\n        )\n        return enum_class\n\n\nclass ApplyPipelineFactory:\n    \"\"\"The factory class for creating executors for inference\"\"\"\n\n    registry: Dict[str, PipelineBase] = {}\n    \"\"\" Internal registry for available executors \"\"\"\n\n    @classmethod\n    def register(cls, name: str) -> Callable:\n        def inner_wrapper(wrapped_class) -> Callable:\n            if name in cls.registry:\n                logger.warning(\n                    \"Executor %s already exists. 
Will replace it\", name\n                )\n            cls.registry[name] = wrapped_class()\n            return wrapped_class\n\n        return inner_wrapper\n\n\nmodel_dir = Path(__file__).parent.parent / \"model\"\n\n\nclass ModelFactory:\n    \"\"\"The factory class for creating executors\"\"\"\n\n    def __init__(self):\n        self.registry = {}\n        self.code_registry = {}\n\n    \"\"\" Internal registry for available executors \"\"\"\n\n    def get_model_enum(self):\n        enum_class = enum.Enum(\n            \"ModelName\", {k: k for k, v in self.registry.items()}\n        )\n        return enum_class\n\n    def register(self, model_name: str) -> Callable:\n        def inner_wrapper(wrapped_class) -> Callable:\n            if model_name in self.registry:\n                logger.warning(\n                    \"Executor %s already exists. Will replace it\", model_name\n                )\n            self.registry[model_name] = wrapped_class\n            # code_filename = model_dir / filename\n            code_filename = Path(inspect.getfile(wrapped_class))\n            self.code_registry[model_name] = code_filename.read_text()\n            return wrapped_class\n\n        return inner_wrapper\n\n    def get_source_code(self, model_name):\n        return self.code_registry[model_name]\n\n    def get_constructor_default_args(self, model_name):\n        sigs = inspect.signature(self.registry[model_name].__init__)\n        default_map = {}\n        for k, param in dict(sigs.parameters).items():\n            default_map[k] = param.default\n        return default_map\n\n    def get_pydantic_constructor_arg_type(self, model_name: str):\n        model_enum = self.get_model_enum()\n        arg_dict = self.get_constructor_default_args(model_name)\n        type_annotation_dict = {}\n        # type_annotation_dict[\"name\"] = Literal[\"\"]\n        exempt_keys = [\"self\", \"in_size\", \"out_size\", \"data_info\"]\n        for k, param in arg_dict.items():\n        
    if k not in exempt_keys:\n                type_annotation_dict[k] = arg_dict[k]\n\n        class Base(DGLBaseModel):\n            name: Literal[model_name]\n\n        return create_model(\n            f\"{model_name.upper()}ModelConfig\",\n            **type_annotation_dict,\n            __base__=Base,\n        )\n\n    def get_constructor_doc_dict(self, name):\n        model_class = self.registry[name]\n        docs = inspect.getdoc(model_class.__init__)\n        param_docs = docscrape.NumpyDocString(docs)\n        param_docs_dict = {}\n        for param in param_docs[\"Parameters\"]:\n            param_docs_dict[param.name] = param.desc[0]\n        return param_docs_dict\n\n    def get_pydantic_model_config(self):\n        model_list = []\n        for k in self.registry:\n            model_list.append(self.get_pydantic_constructor_arg_type(k))\n        output = model_list[0]\n        for m in model_list[1:]:\n            output = Union[output, m]\n        return output\n\n    def get_model_class_name(self, model_name):\n        return self.registry[model_name].__name__\n\n    def get_constructor_arg_type(self, model_name):\n        sigs = inspect.signature(self.registry[model_name].__init__)\n        type_annotation_dict = {}\n        for k, param in dict(sigs.parameters).items():\n            type_annotation_dict[k] = param.annotation\n        return type_annotation_dict\n\n    def filter(self, filter_func):\n        new_fac = ModelFactory()\n        for name in self.registry:\n            if filter_func(self.registry[name]):\n                new_fac.registry[name] = self.registry[name]\n                new_fac.code_registry[name] = self.code_registry[name]\n        return new_fac\n\n\nclass SamplerFactory:\n    \"\"\"The factory class for creating executors\"\"\"\n\n    def __init__(self):\n        self.registry = {}\n\n    def get_model_enum(self):\n        enum_class = enum.Enum(\n            \"NegativeSamplerName\", {k: k for k, v in 
self.registry.items()}\n        )\n        return enum_class\n\n    def register(self, sampler_name: str) -> Callable:\n        def inner_wrapper(wrapped_class) -> Callable:\n            if sampler_name in self.registry:\n                logger.warning(\n                    \"Sampler %s already exists. Will replace it\", sampler_name\n                )\n            self.registry[sampler_name] = wrapped_class\n            return wrapped_class\n\n        return inner_wrapper\n\n    def get_constructor_default_args(self, sampler_name):\n        sigs = inspect.signature(self.registry[sampler_name].__init__)\n        default_map = {}\n        for k, param in dict(sigs.parameters).items():\n            default_map[k] = param.default\n        return default_map\n\n    def get_pydantic_constructor_arg_type(self, sampler_name: str):\n        model_enum = self.get_model_enum()\n        arg_dict = self.get_constructor_default_args(sampler_name)\n        type_annotation_dict = {}\n        # type_annotation_dict[\"name\"] = Literal[\"\"]\n        exempt_keys = [\"self\", \"in_size\", \"out_size\", \"redundancy\"]\n        for k, param in arg_dict.items():\n            if k not in exempt_keys or param is None:\n                if k == \"k\" or k == \"redundancy\":\n                    type_annotation_dict[k] = 3\n                else:\n                    type_annotation_dict[k] = arg_dict[k]\n\n        class Base(DGLBaseModel):\n            name: Literal[sampler_name]\n\n        return create_model(\n            f\"{sampler_name.upper()}SamplerConfig\",\n            **type_annotation_dict,\n            __base__=Base,\n        )\n\n    def get_pydantic_model_config(self):\n        model_list = []\n        for k in self.registry:\n            model_list.append(self.get_pydantic_constructor_arg_type(k))\n        output = model_list[0]\n        for m in model_list[1:]:\n            output = Union[output, m]\n        return output\n\n    def get_model_class_name(self, model_name):\n 
       return self.registry[model_name].__name__\n\n    def get_constructor_arg_type(self, model_name):\n        sigs = inspect.signature(self.registry[model_name].__init__)\n        type_annotation_dict = {}\n        for k, param in dict(sigs.parameters).items():\n            type_annotation_dict[k] = param.annotation\n        return type_annotation_dict\n\n    def get_constructor_doc_dict(self, name):\n        model_class = self.registry[name]\n        docs = inspect.getdoc(model_class)\n        param_docs = docscrape.NumpyDocString(docs)\n        param_docs_dict = {}\n        for param in param_docs[\"Parameters\"]:\n            param_docs_dict[param.name] = param.desc[0]\n        return param_docs_dict\n\n\nNegativeSamplerFactory = SamplerFactory()\nNegativeSamplerFactory.register(\"global\")(GlobalUniform)\nNegativeSamplerFactory.register(\"persource\")(PerSourceUniform)\n\nNodeModelFactory = ModelFactory()\nEdgeModelFactory = ModelFactory()\nGraphModelFactory = ModelFactory()\n"
  },
  {
    "path": "dglgo/dglgo/utils/yaml_dump.py",
    "content": "from ruamel.yaml.comments import CommentedMap\n\n\ndef deep_convert_dict(layer):\n    to_ret = layer\n    if isinstance(layer, dict):\n        to_ret = CommentedMap(layer)\n    try:\n        for key, value in to_ret.items():\n            to_ret[key] = deep_convert_dict(value)\n    except AttributeError:\n        pass\n\n    return to_ret\n\n\nimport collections.abc\n\n\ndef merge_comment(d, comment_dict, column=30):\n    for k, v in comment_dict.items():\n        if isinstance(v, collections.abc.Mapping):\n            d[k] = merge_comment(d.get(k, CommentedMap()), v)\n        else:\n            d.yaml_add_eol_comment(v, key=k, column=column)\n    return d\n"
  },
  {
    "path": "dglgo/recipes/__init__.py",
    "content": ""
  },
  {
    "path": "dglgo/recipes/graphpred_hiv_gin.yaml",
    "content": "version: 0.0.2\npipeline_name: graphpred\npipeline_mode: train\ndevice: cuda:0                # Torch device name, e.g. cpu or cuda or cuda:0\ndata:\n  name: ogbg-molhiv\n  split_ratio:                # Ratio to generate data split, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset\nmodel:\n  name: gin\n  embed_size: 300             # Embedding size.\n  num_layers: 5               # Number of layers.\n  dropout: 0.5                # Dropout rate.\n  virtual_node: true          # Whether to use virtual node.\ngeneral_pipeline:\n  num_runs: 10                # Number of experiments to run\n  train_batch_size: 32        # Graph batch size when training\n  eval_batch_size: 32         # Graph batch size when evaluating\n  num_workers: 4              # Number of workers for data loading\n  optimizer:\n    name: Adam\n    lr: 0.001\n    weight_decay: 0\n  lr_scheduler:\n    name: StepLR\n    step_size: 100\n    gamma: 1\n  loss: BCEWithLogitsLoss\n  metric: roc_auc_score\n  num_epochs: 100             # Number of training epochs\n  save_path: \"results\"        # Directory to save the experiment results\n"
  },
  {
    "path": "dglgo/recipes/graphpred_hiv_pna.yaml",
    "content": "version: 0.0.2\npipeline_name: graphpred\npipeline_mode: train\ndevice: cuda:0                # Torch device name, e.g. cpu or cuda or cuda:0\ndata:\n  name: ogbg-molhiv\n  split_ratio:                # Ratio to generate data split, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset\nmodel:\n  name: pna\n  embed_size: 80              # Embedding size.\n  aggregators: mean max min std # Aggregation function names separated by space, can include mean, max, min, std, sum\n  scalers: identity amplification attenuation # Scaler function names separated by space, can include identity, amplification, and attenuation\n  dropout: 0.3                # Dropout rate.\n  batch_norm: true            # Whether to use batch normalization.\n  residual: true              # Whether to use residual connection.\n  num_mlp_layers: 1           # Number of MLP layers to use after message aggregation in each PNA layer.\n  num_layers: 4               # Number of PNA layers.\n  readout: mean               # Readout for computing graph-level representations, can be 'sum' or 'mean'.\ngeneral_pipeline:\n  num_runs: 10                # Number of experiments to run\n  train_batch_size: 128       # Graph batch size when training\n  eval_batch_size: 128        # Graph batch size when evaluating\n  num_workers: 4              # Number of workers for data loading\n  optimizer:\n    name: Adam\n    lr: 0.01\n    weight_decay: 0.000003\n  lr_scheduler:\n    name: ReduceLROnPlateau\n    mode: max\n    factor: 0.5\n    patience: 20\n    verbose: true\n  loss: BCEWithLogitsLoss\n  metric: roc_auc_score\n  num_epochs: 200             # Number of training epochs\n  save_path: \"results\"        # Directory to save the experiment results\n"
  },
  {
    "path": "dglgo/recipes/graphpred_pcba_gin.yaml",
    "content": "version: 0.0.2\npipeline_name: graphpred\npipeline_mode: train\ndevice: cuda:0                # Torch device name, e.g. cpu or cuda or cuda:0\ndata:\n  name: ogbg-molpcba\n  split_ratio:                # Ratio to generate data split, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset\nmodel:\n  name: gin\n  embed_size: 300             # Embedding size.\n  num_layers: 5               # Number of layers.\n  dropout: 0.5                # Dropout rate.\n  virtual_node: true          # Whether to use virtual node.\ngeneral_pipeline:\n  num_runs: 10                # Number of experiments to run\n  train_batch_size: 32        # Graph batch size when training\n  eval_batch_size: 32         # Graph batch size when evaluating\n  num_workers: 4              # Number of workers for data loading\n  optimizer:\n    name: Adam\n    lr: 0.001\n    weight_decay: 0\n  lr_scheduler:\n    name: StepLR\n    step_size: 100\n    gamma: 1\n  loss: BCEWithLogitsLoss\n  metric: average_precision_score\n  num_epochs: 100             # Number of training epochs\n  save_path: \"results\"        # Directory to save the experiment results\n"
  },
  {
    "path": "dglgo/recipes/linkpred_citation2_sage.yaml",
    "content": "version: 0.0.2\npipeline_name: linkpred\npipeline_mode: train\ndevice: cpu\ndata:\n  name: ogbl-citation2\n  split_ratio:                # List of float, e.g. [0.8, 0.1, 0.1]. Split ratios for training, validation and test sets. Must sum to one. Leave blank to use builtin split in original dataset\n  neg_ratio:                  # Int, e.g. 2. Indicate how many negative samples to be sampled per positive sample. Leave blank to use builtin split in original dataset\nnode_model:\n  name: sage\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 16             # Hidden size.\n  num_layers: 1               # Number of hidden layers.\n  activation: relu\n  dropout: 0.5                # Dropout rate.\n  aggregator_type: gcn        # Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\nedge_model:\n  name: ele\n  hidden_size: 64             # Hidden size.\n  num_layers: 2               # Number of hidden layers.\n  bias: true                  # Whether to use bias in the linear layer.\nneg_sampler:\n  name: persource\n  k: 3                        # The number of negative samples per edge.\ngeneral_pipeline:\n  hidden_size: 256            # The intermediate hidden size between node model and edge model\n  eval_batch_size: 32769      # Edge batch size when evaluating\n  train_batch_size: 32769     # Edge batch size when training\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.005\n  loss: BCELoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 1                 # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/linkpred_collab_sage.yaml",
    "content": "version: 0.0.2\npipeline_name: linkpred\npipeline_mode: train\ndevice: cpu\ndata:\n  name: ogbl-collab\n  split_ratio:                # List of float, e.g. [0.8, 0.1, 0.1]. Split ratios for training, validation and test sets. Must sum to one. Leave blank to use builtin split in original dataset\n  neg_ratio:                  # Int, e.g. 2. Indicate how many negative samples to be sampled per positive sample. Leave blank to use builtin split in original dataset\nnode_model:\n  name: sage\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 16             # Hidden size.\n  num_layers: 1               # Number of hidden layers.\n  activation: relu\n  dropout: 0.5                # Dropout rate.\n  aggregator_type: gcn        # Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\nedge_model:\n  name: ele\n  hidden_size: 64             # Hidden size.\n  num_layers: 2               # Number of hidden layers.\n  bias: true                  # Whether to use bias in the linear layer.\nneg_sampler:\n  name: persource\n  k: 3                        # The number of negative samples per edge.\ngeneral_pipeline:\n  hidden_size: 256            # The intermediate hidden size between node model and edge model\n  eval_batch_size: 32769      # Edge batch size when evaluating\n  train_batch_size: 32769     # Edge batch size when training\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.005\n  loss: BCELoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 1                 # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/linkpred_cora_sage.yaml",
    "content": "version: 0.0.2\npipeline_name: linkpred\npipeline_mode: train\ndevice: cuda\ndata:\n  name: cora\n  split_ratio: [0.8, 0.1, 0.1]               # List of float, e.g. [0.8, 0.1, 0.1]. Split ratios for training, validation and test sets. Must sum to one. Leave blank to use builtin split in original dataset\n  neg_ratio: 3                 # Int, e.g. 2. Indicate how many negative samples to be sampled per positive sample. Leave blank to use builtin split in original dataset\nnode_model:\n  name: sage\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 32             # Hidden size.\n  num_layers: 2               # Number of hidden layers.\n  activation: relu\n  dropout: 0.5                # Dropout rate.\n  aggregator_type: gcn        # Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\nedge_model:\n  name: ele\n  hidden_size: 64             # Hidden size.\n  num_layers: 2               # Number of hidden layers.\n  bias: true                  # Whether to use bias in the linear layer.\nneg_sampler:\n  name: persource\n  k: 3                        # The number of negative samples per edge.\ngeneral_pipeline:\n  hidden_size: 256            # The intermediate hidden size between node model and edge model\n  eval_batch_size: 32769      # Edge batch size when evaluating\n  train_batch_size: 32769     # Edge batch size when training\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.005\n  loss: BCELoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 1                 # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/nodepred-ns_arxiv_gcn.yaml",
    "content": "# Accuracy across 5 runs: 0.593288 ± 0.006103\nversion: 0.0.2\npipeline_name: nodepred-ns\npipeline_mode: train\ndevice: 'cuda:0'\neval_device: 'cpu'\ndata:\n  name: ogbn-arxiv\nmodel:\n  name: gcn\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 256             # Hidden size.\n  num_layers: 2               # Number of layers.\n  norm: both                  # GCN normalization type. Can be 'both', 'right', 'left', 'none'.\n  activation: relu            # Activation function.\n  dropout: 0.5                # Dropout rate.\n  use_edge_weight: false      # If true, scale the messages by edge weights.\ngeneral_pipeline:\n  sampler:\n    name: neighbor\n    fan_out:\n    - 5\n    - 10\n    batch_size: 1024\n    num_workers: 4\n    eval_batch_size: 10240\n    eval_num_workers: 4\n  num_epochs: 20              # Number of training epochs\n  eval_period: 1              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.005\n    weight_decay: 0.0\n  loss: CrossEntropyLoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 5\n"
  },
  {
    "path": "dglgo/recipes/nodepred-ns_product_sage.yaml",
    "content": "# Accuracy across 1 runs: 0.796911\nversion: 0.0.2\npipeline_name: nodepred-ns\npipeline_mode: train\ndevice: cuda\neval_device: cpu\ndata:\n  name: ogbn-products\n  split_ratio:                # Ratio to generate split masks, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset\nmodel:\n  name: sage\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 256             # Hidden size.\n  num_layers: 3               # Number of hidden layers.\n  activation: relu\n  dropout: 0.5                # Dropout rate.\n  aggregator_type: gcn        # Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\ngeneral_pipeline:\n  sampler:\n    name: neighbor\n    fan_out:\n    - 5\n    - 10\n    - 15\n    batch_size: 1000\n    num_workers: 4\n    eval_batch_size: 10000\n    eval_num_workers: 4\n  early_stop:\n    patience: 20              # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 20             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.005\n    weight_decay: 0.0\n  loss: CrossEntropyLoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 5                 # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/nodepred_citeseer_gat.yaml",
    "content": "# Accuracy across 10 runs: 0.7097 ± 0.006914\nversion: 0.0.2\npipeline_name: nodepred\npipeline_mode: train\ndevice: cuda:0\ndata:\n  name: citeseer\nmodel:\n  name: gat\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  num_layers: 2               # Number of layers.\n  hidden_size: 8              # Hidden size.\n  heads:\n  - 8\n  - 1\n  activation: elu             # Activation function.\n  feat_drop: 0.6              # Dropout rate for features.\n  attn_drop: 0.6              # Dropout rate for attentions.\n  negative_slope: 0.2\n  residual: false             # If true, the GATConv will use residual connection\ngeneral_pipeline:\n  early_stop:\n    patience: 100             # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.005\n    weight_decay: 0.0005\n  loss: CrossEntropyLoss\n  save_path: \"results\"       # Directory to save the experiment results\n  num_runs: 10               # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/nodepred_citeseer_gcn.yaml",
    "content": "# Accuracy across 10 runs: 0.6852 ± 0.008875\nversion: 0.0.2\npipeline_name: nodepred\npipeline_mode: train\ndevice: cuda:0\ndata:\n  name: citeseer\nmodel:\n  name: gcn\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 16             # Hidden size.\n  num_layers: 2               # Number of layers.\n  norm: both                  # GCN normalization type. Can be 'both', 'right', 'left', 'none'.\n  activation: relu            # Activation function.\n  dropout: 0.5                # Dropout rate.\n  use_edge_weight: false      # If true, scale the messages by edge weights.\ngeneral_pipeline:\n  early_stop:\n    patience: 100             # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.01\n    weight_decay: 0.0005\n  loss: CrossEntropyLoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 10                # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/nodepred_citeseer_sage.yaml",
    "content": "# Accuracy across 10 runs: 0.6994 ± 0.004005\nversion: 0.0.2\npipeline_name: nodepred\npipeline_mode: train\ndevice: cuda:0\ndata:\n  name: citeseer\nmodel:\n  name: sage\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 16             # Hidden size.\n  num_layers: 2               # Number of layers.\n  activation: relu\n  dropout: 0.5                # Dropout rate.\n  aggregator_type: gcn        # Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\ngeneral_pipeline:\n  early_stop:\n    patience: 100             # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.01\n    weight_decay: 0.0005\n  loss: CrossEntropyLoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 10                # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/nodepred_cora_gat.yaml",
    "content": "# Accuracy across 10 runs: 0.8208 ± 0.00663\nversion: 0.0.2\npipeline_name: nodepred\npipeline_mode: train\ndevice: cuda:0\ndata:\n  name: cora\nmodel:\n  name: gat\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  num_layers: 2               # Number of layers.\n  hidden_size: 8              # Hidden size.\n  heads:\n  - 8\n  - 1\n  activation: elu             # Activation function.\n  feat_drop: 0.6              # Dropout rate for features.\n  attn_drop: 0.6              # Dropout rate for attentions.\n  negative_slope: 0.2\n  residual: false             # If true, the GATConv will use residule connection\ngeneral_pipeline:\n  early_stop:\n    patience: 100             # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.005\n    weight_decay: 0.0005\n  loss: CrossEntropyLoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 10                # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/nodepred_cora_gcn.yaml",
    "content": "# Accuracy across 10 runs: 0.802 ± 0.005329\nversion: 0.0.2\npipeline_name: nodepred\npipeline_mode: train\ndevice: cuda:0\ndata:\n  name: cora\nmodel:\n  name: gcn\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 16             # Hidden size.\n  num_layers: 2               # Number of layers.\n  norm: both                  # GCN normalization type. Can be 'both', 'right', 'left', 'none'.\n  activation: relu            # Activation function.\n  dropout: 0.5                # Dropout rate.\n  use_edge_weight: false      # If true, scale the messages by edge weights.\ngeneral_pipeline:\n  early_stop:\n    patience: 100             # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.01\n    weight_decay: 0.0005\n  loss: CrossEntropyLoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 10                # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/nodepred_cora_sage.yaml",
    "content": "# Accuracy across 10 runs: 0.8163 ± 0.006856\nversion: 0.0.2\npipeline_name: nodepred\npipeline_mode: train\ndevice: cuda:0\ndata:\n  name: cora\nmodel:\n  name: sage\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 16             # Hidden size.\n  num_layers: 2               # Number of layers.\n  activation: relu\n  dropout: 0.5                # Dropout rate.\n  aggregator_type: gcn        # Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\ngeneral_pipeline:\n  early_stop:\n    patience: 100             # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.01\n    weight_decay: 0.0005\n  loss: CrossEntropyLoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 10                # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/nodepred_pubmed_gat.yaml",
    "content": "# Accuracy across 10 runs: 0.7788 ± 0.002227\nversion: 0.0.2\npipeline_name: nodepred\npipeline_mode: train\ndevice: cuda:0\ndata:\n  name: pubmed\nmodel:\n  name: gat\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  num_layers: 2               # Number of layers.\n  hidden_size: 8              # Hidden size.\n  heads:\n  - 8\n  - 8\n  activation: elu             # Activation function.\n  feat_drop: 0.6              # Dropout rate for features.\n  attn_drop: 0.6              # Dropout rate for attentions.\n  negative_slope: 0.2\n  residual: false             # If true, the GATConv will use residule connection\ngeneral_pipeline:\n  early_stop:\n    patience: 100             # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.005\n    weight_decay: 0.001\n  loss: CrossEntropyLoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 10                # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/nodepred_pubmed_gcn.yaml",
    "content": "# Accuracy across 10 runs: 0.7826 ± 0.004317\nversion: 0.0.2\npipeline_name: nodepred\npipeline_mode: train\ndevice: cuda:0\ndata:\n  name: pubmed\nmodel:\n  name: gcn\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 16             # Hidden size.\n  num_layers: 2               # Number of layers.\n  norm: both                  # GCN normalization type. Can be 'both', 'right', 'left', 'none'.\n  activation: relu            # Activation function.\n  dropout: 0.5                # Dropout rate.\n  use_edge_weight: false      # If true, scale the messages by edge weights.\ngeneral_pipeline:\n  early_stop:\n    patience: 100             # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.01\n    weight_decay: 0.0005\n  loss: CrossEntropyLoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 10                # Number of experiments to run\n"
  },
  {
    "path": "dglgo/recipes/nodepred_pubmed_sage.yaml",
    "content": "# Accuracy across 10 runs: 0.7819 ± 0.003176\nversion: 0.0.2\npipeline_name: nodepred\npipeline_mode: train\ndevice: cuda:0\ndata:\n  name: pubmed\nmodel:\n  name: sage\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 16             # Hidden size.\n  num_layers: 2               # Number of layers.\n  activation: relu\n  dropout: 0.5                # Dropout rate.\n  aggregator_type: gcn        # Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\ngeneral_pipeline:\n  early_stop:\n    patience: 100             # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.01\n    weight_decay: 0.0005\n  loss: CrossEntropyLoss\n  save_path: \"results\"        # Directory to save the experiment results\n  num_runs: 10                # Number of experiments to run\n"
  },
  {
    "path": "dglgo/setup.py",
    "content": "#!/usr/bin/env python\n\nfrom setuptools import find_packages, setup\n\nsetup(\n    name=\"dglgo\",\n    version=\"0.0.2\",\n    description=\"DGL\",\n    author=\"DGL Team\",\n    author_email=\"wmjlyjemaine@gmail.com\",\n    packages=find_packages(),\n    install_requires=[\n        \"typer>=0.4.0\",\n        \"isort>=5.10.1\",\n        \"autopep8>=1.6.0\",\n        \"numpydoc>=1.1.0\",\n        \"pydantic>=1.9.0\",\n        \"ruamel.yaml>=0.17.20\",\n        \"PyYAML>=5.1\",\n        \"ogb>=1.3.3\",\n        \"rdkit-pypi\",\n        \"scikit-learn>=0.20.0\",\n    ],\n    package_data={\"\": [\"./*\"]},\n    include_package_data=True,\n    license=\"APACHE\",\n    entry_points={\"console_scripts\": [\"dgl = dglgo.cli.cli:main\"]},\n    url=\"https://github.com/dmlc/dgl\",\n)\n"
  },
  {
    "path": "dglgo/tests/cfg.yml",
    "content": "version: 0.0.2\npipeline_name: nodepred\npipeline_mode: train\ndevice: cpu\ndata:\n  name: cora\n  split_ratio:                # Ratio to generate split masks, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset\nmodel:\n  name: sage\n  embed_size: -1              # The dimension of created embedding table. -1 means using original node embedding\n  hidden_size: 16             # Hidden size.\n  num_layers: 1               # Number of hidden layers.\n  activation: relu            # Activation function name under torch.nn.functional\n  dropout: 0.5                # Dropout rate.\n  aggregator_type: gcn        # Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\ngeneral_pipeline:\n  early_stop:\n    patience: 20              # Steps before early stop\n    checkpoint_path: checkpoint.pth # Early stop checkpoint model file path\n  num_epochs: 200             # Number of training epochs\n  eval_period: 5              # Interval epochs between evaluations\n  optimizer:\n    name: Adam\n    lr: 0.01\n    weight_decay: 0.0005\n  loss: CrossEntropyLoss\n  num_runs: 1                 # Number of experiments to run\n"
  },
  {
    "path": "dglgo/tests/run_test.sh",
    "content": "python -m pytest --pdb -vv --capture=tee-sys test_pipeline.py::test_recipe"
  },
  {
    "path": "dglgo/tests/test_pipeline.py",
    "content": "import subprocess\nfrom pathlib import Path\nfrom typing import NamedTuple\n\nimport pytest\n\n# class DatasetSpec:\n\ndataset_spec = {\"cora\": {\"timeout\": 30}}\n\n\nclass ExperimentSpec(NamedTuple):\n    pipeline: str\n    dataset: str\n    model: str\n    timeout: int\n    extra_cfg: dict = {}\n\n\nexps = [\n    ExperimentSpec(\n        pipeline=\"nodepred\", dataset=\"cora\", model=\"sage\", timeout=0.5\n    )\n]\n\n\n@pytest.mark.parametrize(\"spec\", exps)\ndef test_train(spec):\n    cfg_path = \"/tmp/test.yaml\"\n    run = subprocess.run(\n        [\n            \"dgl\",\n            \"config\",\n            spec.pipeline,\n            \"--data\",\n            spec.dataset,\n            \"--model\",\n            spec.model,\n            \"--cfg\",\n            cfg_path,\n        ],\n        timeout=spec.timeout,\n        capture_output=True,\n    )\n    assert (\n        run.stderr is None or len(run.stderr) == 0\n    ), \"Found error message: {}\".format(run.stderr)\n    output = run.stdout.decode(\"utf-8\")\n    print(output)\n\n    run = subprocess.run(\n        [\"dgl\", \"train\", \"--cfg\", cfg_path],\n        timeout=spec.timeout,\n        capture_output=True,\n    )\n    assert (\n        run.stderr is None or len(run.stderr) == 0\n    ), \"Found error message: {}\".format(run.stderr)\n    output = run.stdout.decode(\"utf-8\")\n    print(output)\n\n\nTEST_RECIPE_FOLDER = \"my_recipes\"\n\n\n@pytest.fixture\ndef setup_recipe_folder():\n    run = subprocess.run(\n        [\"dgl\", \"recipe\", \"copy\", \"--dir\", TEST_RECIPE_FOLDER],\n        timeout=15,\n        capture_output=True,\n    )\n\n\n@pytest.mark.parametrize(\n    \"file\", [str(f) for f in Path(TEST_RECIPE_FOLDER).glob(\"*.yaml\")]\n)\ndef test_recipe(file, setup_recipe_folder):\n    print(\"DGL enter train {}\".format(file))\n    try:\n        run = subprocess.run(\n            [\"dgl\", \"train\", \"--cfg\", file], timeout=5, capture_output=True\n        )\n        
sh_stdout, sh_stderr = run.stdout, run.stderr\n    except subprocess.TimeoutExpired as e:\n        sh_stdout = e.stdout\n        sh_stderr = e.stderr\n    if sh_stderr is not None and len(sh_stderr) != 0:\n        error_str = sh_stderr.decode(\"utf-8\")\n        lines = error_str.split(\"\\n\")\n        for line in lines:\n            line = line.strip()\n            if (\n                line.startswith(\"WARNING\")\n                or line.startswith(\"Aborted\")\n                or line.startswith(\"0%\")\n            ):\n                continue\n            else:\n                assert len(line) == 0, error_str\n    print(\"{} stdout: {}\".format(file, sh_stdout))\n    print(\"{} stderr: {}\".format(file, sh_stderr))\n\n\n# test_recipe( , None)\n"
  },
  {
    "path": "docker/Dockerfile.awscli",
    "content": "# Using the Ubuntu image (our OS)\nFROM ubuntu:latest\n# Update package manager (apt-get) \n# and install (with the yes flag `-y`)\n# Python and Pip\nRUN apt-get update && apt-get install -y \\\n    python3.8 \\\n    python3-pip\n\nRUN apt-get install -y \\\n        unzip \\\n        curl \\\n    && apt-get clean \\\n    && curl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\" \\\n    && unzip awscliv2.zip \\\n    && ./aws/install \\\n    && rm -rf \\\n        awscliv2.zip\n\nRUN pip install pytest pytest-html requests\n"
  },
  {
    "path": "docker/Dockerfile.ci_benchmark",
    "content": "# CI docker GPU env\nFROM nvidia/cuda:11.6.0-cudnn8-devel-ubuntu20.04\n\nENV TZ=US\nRUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone\n\nRUN apt-get update --fix-missing\n\nCOPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh\nRUN bash /install/ubuntu_install_core.sh\n\nCOPY install/ubuntu_install_build.sh /install/ubuntu_install_build.sh\nRUN bash /install/ubuntu_install_build.sh\n\n# python\nCOPY install/ubuntu_install_conda.sh /install/ubuntu_install_conda.sh\nRUN bash /install/ubuntu_install_conda.sh\n\nENV CONDA_ALWAYS_YES=\"true\"\n\nENV CONDA_ALWAYS_YES=\n\n# Environment variables\nENV PATH=/usr/local/nvidia/bin:${PATH}\nENV PATH=/usr/local/cuda/bin:${PATH}\nENV CPLUS_INCLUDE_PATH=/usr/local/cuda/include:${CPLUS_INCLUDE_PATH}\nENV C_INCLUDE_PATH=/usr/local/cuda/include:${C_INCLUDE_PATH}\nENV LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/nvidia/lib64:${LIBRARY_PATH}\nENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/nvidia/lib64:${LD_LIBRARY_PATH}\nENV TF_FORCE_GPU_ALLOW_GROWTH=true"
  },
  {
    "path": "docker/Dockerfile.ci_cpu",
    "content": "# CI docker CPU env\n# Adapted from github.com/dmlc/tvm/docker/Dockerfile.ci_cpu\nFROM ubuntu:20.04\n\nENV TZ=US\nRUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone\n\nRUN apt-get update --fix-missing\n\nCOPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh\nRUN bash /install/ubuntu_install_core.sh\n\nCOPY install/ubuntu_install_build.sh /install/ubuntu_install_build.sh\nRUN bash /install/ubuntu_install_build.sh\n\n# tcmalloc\nRUN apt-get install -y libgoogle-perftools4\nENV LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4:$LD_PRELOAD\n\n# python\nCOPY install/ubuntu_install_conda.sh /install/ubuntu_install_conda.sh\nRUN bash /install/ubuntu_install_conda.sh\n\nENV CONDA_ALWAYS_YES=\"true\"\n\nCOPY install/conda_env/torch_cpu.yml /install/conda_env/torch_cpu.yml\nCOPY install/conda_env/torch_cpu_pip.txt /install/conda_env/torch_cpu_pip.txt\nRUN [\"/bin/bash\", \"-i\", \"-c\", \"conda env create -f /install/conda_env/torch_cpu.yml\"]\n\nCOPY install/conda_env/tensorflow_cpu.yml /install/conda_env/tensorflow_cpu.yml\nRUN [\"/bin/bash\", \"-i\", \"-c\", \"conda env create -f /install/conda_env/tensorflow_cpu.yml\"]\n\nCOPY install/conda_env/mxnet_cpu.yml /install/conda_env/mxnet_cpu.yml\nRUN [\"/bin/bash\", \"-i\", \"-c\", \"conda env create -f /install/conda_env/mxnet_cpu.yml\"]\n\nENV CONDA_ALWAYS_YES=\n\n# SSH\nRUN [\"/bin/bash\", \"-i\", \"-c\", \"ssh-keygen -f ~/.ssh/id_rsa -N ''\"]\nRUN [\"/bin/bash\", \"-i\", \"-c\", \"cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys\"]\nENTRYPOINT service ssh restart && bash\n"
  },
  {
    "path": "docker/Dockerfile.ci_gpu",
    "content": "# CI docker GPU env\nFROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04\n\nENV TZ=US\nRUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone\n\nCOPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh\nRUN bash /install/ubuntu_install_core.sh\n\nCOPY install/ubuntu_install_build.sh /install/ubuntu_install_build.sh\nRUN bash /install/ubuntu_install_build.sh\n\n# python\nCOPY install/ubuntu_install_conda.sh /install/ubuntu_install_conda.sh\nRUN bash /install/ubuntu_install_conda.sh\n\nENV CONDA_ALWAYS_YES=\"true\"\n\nCOPY install/conda_env/torch_gpu.yml /install/conda_env/torch_gpu.yml\nCOPY install/conda_env/torch_gpu_pip.txt /install/conda_env/torch_gpu_pip.txt\nRUN [\"/bin/bash\", \"-i\", \"-c\", \"conda env create -f /install/conda_env/torch_gpu.yml\"]\n\nCOPY install/conda_env/tensorflow_gpu.yml /install/conda_env/tensorflow_gpu.yml\nRUN [\"/bin/bash\", \"-i\", \"-c\", \"conda env create -f /install/conda_env/tensorflow_gpu.yml\"]\n\nCOPY install/conda_env/mxnet_gpu.yml /install/conda_env/mxnet_gpu.yml\nRUN [\"/bin/bash\", \"-i\", \"-c\", \"conda env create -f /install/conda_env/mxnet_gpu.yml\"]\n\nENV CONDA_ALWAYS_YES=\n\n# Environment variables\nENV PATH=/usr/local/nvidia/bin:${PATH}\nENV PATH=/usr/local/cuda/bin:${PATH}\nENV CPLUS_INCLUDE_PATH=/usr/local/cuda/include:${CPLUS_INCLUDE_PATH}\nENV C_INCLUDE_PATH=/usr/local/cuda/include:${C_INCLUDE_PATH}\nENV LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/nvidia/lib64:${LIBRARY_PATH}\nENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/nvidia/lib64:${LD_LIBRARY_PATH}\nENV TF_FORCE_GPU_ALLOW_GROWTH=true\n"
  },
  {
    "path": "docker/Dockerfile.ci_lint",
    "content": "# CI docker for lint\n# Adapted from github.com/dmlc/tvm/docker/Dockerfile.ci_lint\n\nFROM ubuntu:18.04\n\nENV DEBIAN_FRONTEND=noninteractive\n\nCOPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh\nRUN bash /install/ubuntu_install_core.sh\n\nCOPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh\nRUN bash /install/ubuntu_install_python.sh\n\nRUN apt-get install -y doxygen graphviz\n\nRUN pip3 install cpplint==1.3.0 pylint==2.7.0 mypy\n"
  },
  {
    "path": "docker/README.md",
    "content": "## Build docker image for CI\n\n### CPU image\n```bash\ndocker build -t dgl-cpu -f Dockerfile.ci_cpu .\n```\n\n### GPU image\n```bash\ndocker build -t dgl-gpu -f Dockerfile.ci_gpu .\n```\n\n### Lint image\n```bash\ndocker build -t dgl-lint -f Dockerfile.ci_lint .\n```\n\n### CPU image for kg\n```bash\nwget https://data.dgl.ai/dataset/FB15k.zip -P install/\ndocker build -t dgl-cpu:torch-1.2.0 -f Dockerfile.ci_cpu_torch_1.2.0 .\n```\n\n### GPU image for kg\n```bash\nwget https://data.dgl.ai/dataset/FB15k.zip -P install/\ndocker build -t dgl-gpu:torch-1.2.0 -f Dockerfile.ci_gpu_torch_1.2.0 .\n```\n"
  },
  {
    "path": "docker/install/conda_env/kg_cpu.yml",
    "content": "name: kg-ci\ndependencies:\n  - python=3.6.9\n  - pip\n  - pip:\n    - torch\n    - torchvision\n    - mxnet\n    - pytest\n  - nose\n  - numpy\n  - cython\n  - scipy\n  - networkx\n  - matplotlib\n  - nltk\n  - requests[security]\n  - tqdm"
  },
  {
    "path": "docker/install/conda_env/kg_gpu.yml",
    "content": "name: kg-ci\ndependencies:\n  - python=3.6.9\n  - pip\n  - pip:\n    - torch\n    - torchvision\n    - mxnet-cu101\n    - pytest\n  - nose\n  - numpy\n  - cython\n  - scipy\n  - networkx\n  - matplotlib\n  - nltk\n  - requests[security]\n  - tqdm"
  },
  {
    "path": "docker/install/conda_env/mxnet_cpu.yml",
    "content": "name: mxnet-ci\ndependencies:\n  - python=3.7.0\n  - pip\n  - pip:\n    - mxnet==1.6.0\n    - pytest\n    - nose\n    - numpy\n    - cython==0.29\n    - scipy\n    - networkx\n    - matplotlib\n    - nltk\n    - requests[security]\n    - tqdm\n    - psutil\n    - pyyaml\n    - pydantic\n    - pandas\n    - rdflib\n    - ogb\n"
  },
  {
    "path": "docker/install/conda_env/mxnet_gpu.yml",
    "content": "name: mxnet-ci\ndependencies:\n  - python=3.7.0\n  - pip\n  - pip:\n    - mxnet-cu101==1.7.0\n    - pytest\n    - nose\n    - numpy\n    - cython==0.29\n    - scipy\n    - networkx\n    - matplotlib\n    - nltk\n    - requests[security]\n    - tqdm\n    - psutil\n    - pyyaml\n    - pydantic\n    - pandas\n    - rdflib\n    - ogb\n"
  },
  {
    "path": "docker/install/conda_env/tensorflow_cpu.yml",
    "content": "name: tensorflow-ci\ndependencies:\n  - python=3.7\n  - pip\n  - pip:\n    - tensorflow==2.3.0\n    - pytest\n    - nose\n    - numpy\n    - cython==0.29\n    - scipy\n    - networkx\n    - matplotlib\n    - nltk\n    - requests[security]\n    - tqdm\n    - psutil\n    - pyyaml\n    - pydantic\n    - pandas\n    - rdflib\n    - ogb\n"
  },
  {
    "path": "docker/install/conda_env/tensorflow_gpu.yml",
    "content": "name: tensorflow-ci\ndependencies:\n  - python=3.7.0\n  - pip\n  - pip:\n    - tensorflow==2.3.0\n    - pytest\n    - nose\n    - numpy\n    - cython==0.29\n    - scipy\n    - networkx\n    - matplotlib\n    - nltk\n    - requests[security]\n    - tqdm\n    - psutil\n    - pyyaml\n    - pydantic\n    - pandas\n    - rdflib\n    - ogb\n"
  },
  {
    "path": "docker/install/conda_env/torch_cpu.yml",
    "content": "name: pytorch-ci\ndependencies:\n  - python=3.10\n  - pip\n  - pip:\n    - --find-links https://download.pytorch.org/whl/torch_stable.html\n    - --requirement torch_cpu_pip.txt"
  },
  {
    "path": "docker/install/conda_env/torch_cpu_pip.txt",
    "content": "--find-links https://download.pytorch.org/whl/torch_stable.html\ncython\nfilelock\nmatplotlib\nnetworkx\nnltk\nnose\nnumpy\nogb\npandas\npsutil\npyarrow\npydantic\npytest\npyyaml\nrdflib\nrequests[security]==2.28\nscikit-learn\nscipy\ntorch==2.3.0+cpu\ntorcheval\ntorchmetrics\ntorch_geometric\ntqdm\n"
  },
  {
    "path": "docker/install/conda_env/torch_gpu.yml",
    "content": "name: pytorch-ci\ndependencies:\n  - python=3.10\n  - pip\n  - pip:\n    - --find-links https://download.pytorch.org/whl/torch_stable.html\n    - --requirement torch_gpu_pip.txt"
  },
  {
    "path": "docker/install/conda_env/torch_gpu_pip.txt",
    "content": "--find-links https://download.pytorch.org/whl/torch_stable.html\ncython\nmatplotlib\nnetworkx\nnltk\nnose\nnumpy\nogb\npandas\npsutil\npydantic\npytest\npyyaml\nrdflib\nrequests[security]==2.28\nscikit-learn\nscipy\ntorch==2.3.0+cu121\ntorcheval\ntorchmetrics\ntorch_geometric\ntqdm\n"
  },
  {
    "path": "docker/install/ubuntu_install_antlr.sh",
    "content": "#!/bin/bash\n\nset -e\nset -u\nset -o pipefail\n\ncd /usr/local/lib\nwget -q https://www.antlr.org/download/antlr-4.7.1-complete.jar\ncd -\n"
  },
  {
    "path": "docker/install/ubuntu_install_build.sh",
    "content": "# Install cmake with minimum required version.\nversion=3.18\nbuild=0\nmkdir ~/temp\ncd ~/temp\nwget https://cmake.org/files/v$version/cmake-$version.$build-Linux-x86_64.sh \nsudo mkdir /opt/cmake\nsudo sh cmake-$version.$build-Linux-x86_64.sh --prefix=/opt/cmake --skip-license\nsudo ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake\ncd ~\nrm -rf ~/temp"
  },
  {
    "path": "docker/install/ubuntu_install_conda.sh",
    "content": "#!/bin/sh\nexport LANG=C.UTF-8 LC_ALL=C.UTF-8\nexport PATH=/opt/conda/bin:$PATH\n\napt-get update --fix-missing && \\\n    apt-get install -y wget bzip2 ca-certificates curl git && \\\n    apt-get clean && \\\n    rm -rf /var/lib/apt/lists/*\n\nwget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \\\n    /bin/bash ~/miniconda.sh -b -p /opt/conda && \\\n    rm ~/miniconda.sh && \\\n    /opt/conda/bin/conda clean -tipy && \\\n    ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \\\n    echo \". /opt/conda/etc/profile.d/conda.sh\" >> ~/.bashrc && \\\n    echo \"conda activate base\" >> ~/.bashrc\n\nexport TINI_VERSION=v0.16.1\nsource ~/.bashrc\n"
  },
  {
    "path": "docker/install/ubuntu_install_core.sh",
    "content": "# install libraries for building c++ core on ubuntu\napt update && apt install -y --no-install-recommends --force-yes \\\n        apt-utils git build-essential make wget unzip sudo \\\n        libz-dev libxml2-dev libopenblas-dev libopencv-dev \\\n        graphviz graphviz-dev libgraphviz-dev ca-certificates \\\n        systemd vim openssh-client openssh-server\n"
  },
  {
    "path": "docker/install/ubuntu_install_java.sh",
    "content": "#!/bin/bash\n\nset -o errexit -o nounset\nset -o pipefail\n\napt-get update && apt-get install -y openjdk-8-jdk maven\ntest -d \"/usr/lib/jvm/java-8-openjdk-amd64/jre\"\necho \"export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre\" >> /etc/profile\n"
  },
  {
    "path": "docker/install/ubuntu_install_mxnet_cpu.sh",
    "content": "pip3 install mxnet\n"
  },
  {
    "path": "docker/install/ubuntu_install_mxnet_gpu.sh",
    "content": "pip3 install mxnet-cu90\n"
  },
  {
    "path": "docker/install/ubuntu_install_python.sh",
    "content": "# install python and pip, don't modify this, modify install_python_package.sh\napt-get update\napt-get install -y python-dev python3-dev\n\n# install pip\ncd /tmp && wget https://bootstrap.pypa.io/get-pip.py\npython2 get-pip.py\npython3 get-pip.py\n\n# santiy check\npython2 --version\npython3 --version\npip2 --version\npip3 --version\n"
  },
  {
    "path": "docker/install/ubuntu_install_python_package.sh",
    "content": "# install libraries for python package on ubuntu\n#pip2 install nose numpy cython scipy networkx matplotlib nltk requests[security] tqdm\npip3 install nose numpy cython scipy networkx matplotlib nltk requests[security] tqdm\n"
  },
  {
    "path": "docker/install/ubuntu_install_torch.sh",
    "content": "#!/bin/bash\n# install torch\npip2 install torch==1.0.1 torchvision==0.2.2\npip3 install torch==1.0.1 torchvision==0.2.2\n"
  },
  {
    "path": "docker/install/ubuntu_install_torch_1.2.0.sh",
    "content": "#!/bin/bash\n# install torch\npip3 install torch==1.2.0+cu92 torchvision==0.4.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html\n"
  },
  {
    "path": "docker/pods/ci-compile-cpu.yaml",
    "content": "apiVersion: v1\nkind: Pod\nspec:\n  securityContext:\n    runAsUser: 0\n  containers:\n  - name: dgl-ci-cpu-compile\n    image: dgllib/dgl-ci-cpu:cu101_v220123\n    imagePullPolicy: Always\n    tty: true\n    resources:\n      requests:\n        cpu: 16\n  # affinity:\n  #   nodeAffinity:\n  #     requiredDuringSchedulingIgnoredDuringExecution:\n  #       nodeSelectorTerms:\n  #       - matchExpressions:\n  #         - key: beta.kubernetes.io/instance-type\n  #           operator: In\n  #           values:\n  #           - c5.9xlarge"
  },
  {
    "path": "docker/pods/ci-compile-gpu.yaml",
    "content": "apiVersion: v1\nkind: Pod\nspec:\n  securityContext:\n    runAsUser: 0\n  containers:\n  - name: dgl-ci-gpu-compile\n    image: dgllib/dgl-ci-gpu:cu101_v220123\n    imagePullPolicy: Always\n    tty: true\n    resources:\n      requests:\n        cpu: 32\n  # affinity:\n  #   nodeAffinity:\n  #     requiredDuringSchedulingIgnoredDuringExecution:\n  #       nodeSelectorTerms:\n  #       - matchExpressions:\n  #         - key: beta.kubernetes.io/instance-type\n  #           operator: In\n  #           values:\n  #           - c5.9xlarge"
  },
  {
    "path": "docker/pods/ci-cpu.yaml",
    "content": "apiVersion: v1\nkind: Pod\nspec:\n  securityContext:\n    runAsUser: 0\n  containers:\n  - name: dgl-ci-cpu\n    image: dgllib/dgl-ci-cpu:cu101_v220217\n    imagePullPolicy: Always\n    tty: true\n    resources:\n      requests:\n        cpu: 16\n    volumeMounts:\n      # - name: persistent-storage\n      #   mountPath: /tmp/dataset\n      - name: dshm\n        mountPath: /dev/shm    \n  volumes:\n  # - name: persistent-storage\n  #   persistentVolumeClaim:\n  #     claimName: ogb-efs-claim\n  - name: dshm\n    emptyDir:\n      medium: Memory\n"
  },
  {
    "path": "docker/pods/ci-gpu.yaml",
    "content": "apiVersion: v1\nkind: Pod\nspec:\n  securityContext:\n    runAsUser: 0\n  containers:\n    - name: dgl-ci-gpu\n      image: dgllib/dgl-ci-gpu:cu101_v220217\n      imagePullPolicy: Always\n      tty: true\n      resources:\n        limits:\n          nvidia.com/gpu: 1 # requesting 1 GPU\n      volumeMounts:\n        - name: dshm\n          mountPath: /dev/shm\n  volumes:\n  - name: dshm\n    emptyDir:\n      medium: Memory\n  affinity:\n    nodeAffinity:\n      requiredDuringSchedulingIgnoredDuringExecution:\n        nodeSelectorTerms:\n        - matchExpressions:\n          - key: beta.kubernetes.io/instance-type\n            operator: In\n            values:\n            - g4dn.2xlarge"
  },
  {
    "path": "docker/pods/ci-lint.yaml",
    "content": "apiVersion: v1\nkind: Pod\nspec:\n  securityContext:\n    runAsUser: 0\n  containers:\n  - name: dgl-ci-lint\n    image: dgllib/dgl-ci-lint\n    imagePullPolicy: Always\n    tty: true\n    resources:\n      requests:\n        cpu: 1\n  serviceAccountName: dglciuser"
  },
  {
    "path": "docs/.gitignore",
    "content": "build\n\n# tutorials are auto-generated\nsource/tutorials\nsource/new-tutorial\nsource/generated\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\nmxnet:\n\t@echo \"##################################################################\"\n\t@echo \"#                                                                #\"\n\t@echo \"#                Step 1: Building MXNet tutorials                #\"\n\t@echo \"#                                                                #\"\n\t@echo \"##################################################################\"\n\t@DGLBACKEND=mxnet $(SPHINXBUILD) -M html \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\npytorch:\n\t@echo \"##################################################################\"\n\t@echo \"#                                                                #\"\n\t@echo \"#                Step 2: Building PyTorch tutorials              #\"\n\t@echo \"#                                                                #\"\n\t@echo \"##################################################################\"\n\t@DGLBACKEND=pytorch $(SPHINXBUILD) -M html \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O) \n\ntensorflow:\n\t@echo \"##################################################################\"\n\t@echo \"#                                                                #\"\n\t@echo \"#                Step 3: Building Tensorflow tutorials           #\"\n\t@echo \"#                                                                #\"\n\t@echo \"##################################################################\"\n\t@DGLBACKEND=tensorflow $(SPHINXBUILD) -M html \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O) \n\nhtml-noexec:\n\t$(SPHINXBUILD) -D plot_gallery=0 -b html 
\"$(SOURCEDIR)\" \"$(BUILDDIR)/html\" \n\t@echo\n\t@echo \"Build finished. The HTML pages are in $(BUILDDIR)/html.\"\n\nhtml: Makefile mxnet pytorch tensorflow\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O) \n"
  },
  {
    "path": "docs/README.md",
    "content": "DGL document and tutorial folder\n================================\n\n\nTo build the doc:\n\n- Create the developer conda environment using the script [here](../script/create_dev_conda_env.sh).\n- Activate the developer conda environment.\n- Build DGL from source using the script [here](../script/build_dgl.sh).\n- Build the doc using the script [here](../script/build_doc.sh).\n\nTo render locally:\n```\ncd build/html\npython3 -m http.server 8000\n```\n"
  },
  {
    "path": "docs/clean.sh",
    "content": "#!/bin/sh\n\nmake clean\nrm -rf build\nrm -rf source/tutorials\nrm -rf source/generated\n"
  },
  {
    "path": "docs/migrate-guide-0.5.md",
    "content": "# Migration Guide for DGL 0.5\n\n## Breaking changes\n\nThe following changes may break existing code if the related APIs are used. Note that **most of the removed APIs have quite rare use cases** and have quite easy replacements.\n\n1. DGLGraph now requires the graph structure and feature data to have the same device placement. If the given node/edge feature tensors are on a different device from the graph’s, dgl.ndata and dgl.edata will raise an error as follows:\n    ```bash\n    dgl._ffi.base.DGLError: Cannot assign node feature \"x\" on device cpu to a graph on device cuda:0.\n    Call DGLGraph.to() to copy the graph to the same device.\n    ```\n    To fix it, copy either the graph (using the `DGLGraph.to` API) or the feature tensors to the same device.\n\n1. Changes to `dgl.graph`:\n    * No longer accept SciPy matrix/NetworkX graph as the input data. Use `dgl.from_scipy`/`dgl.from_networkx` instead.\n    * `ntype` and `etype` are removed from the arguments. To construct graphs with named node/edge types, use `dgl.heterograph`.\n        ```python\n        g = dgl.heterograph({('user', 'follows', 'user') : ...})\n        ```\n    * `validate` is removed from the arguments. DGL now always checks whether the num_nodes is greater than the largest node ID if specified.\n1. `dgl.bipartite` is removed.\n    * To create a uni-directional bipartite graph, use `dgl.heterograph`. E.g.,\n        ```python\n        g = dgl.heterograph({('user', 'rates', 'movie'): ...})\n        ```\n    * To create a uni-directional bipartite graph from a SciPy matrix, use the new API `dgl.bipartite_from_scipy`.\n    * To create a uni-directional bipartite graph from a NetworkX graph, use the new API `dgl.bipartite_from_networkx`.\n1. Changes to `dgl.heterograph`:\n    * No longer accept SciPy matrix/NetworkX graph as the input data. Use the `from_*` APIs to create graphs first and then pass their edges to the `dgl.heterograph` API. E.g.,\n        ```python\n        nx_g = ...  
# some networkx graph\n        spmat = ...  # some scipy matrix\n        g1 = dgl.from_networkx(nx_g)\n        g2 = dgl.bipartite_from_scipy(spmat)\n        g = dgl.heterograph({('user', 'follows', 'user') : g1.edges(),\n                             ('user', 'rates', 'movie') : g2.edges()})\n        ```\n1. `dgl.hetero_from_relations` is removed. Use `dgl.heterograph` instead.\n1. From 0.5, subgraphs extracted via DGL APIs automatically inherits node and edge features from the parent graph. DGL also saves the original nodes/edge IDs in `subg.ndata[dgl.NID]` and `subg.edata[dgl.EID]` if nodes/edges are relabeled. This new behavior makes the following `DGLGraph` methods useless and we thus remove them:\n    * `DGLGraph.parent`, `DGLGraph.parent_nid`, `DGLGraph.parent_eid`, `DGLGraph.map_to_subgraph_nid`, `DGLGraph.copy_from_parent`, `DGLGraph.copy_to_parent` and `DGLGraph.detach_parent`.\n1. Other removed DGLGraph APIs:\n    * `DGLGraph.from_networkx`. Use `dgl.from_networkx` to construct a DGLGraph from a NetworkX graph.\n    * `DGLGraph.from_scipy_sparse_matrix`. Use `dgl.from_scipy` to construct a DGLGraph from a SciPy matrix.\n    * `DGLGraph.register_apply_node_func` , `DGLGraph.register_apply_edge_func`, `DGLGraph.register_message_func` and `DGLGraph.register_reduce_func`. Please specify them directly as the arguments of the message passing APIs.\n        ```python\n        g = ...  # some graph\n        # before 0.5\n        g.register_message_func(mfunc)\n        g.register_reduce_func(rfunc)\n        g.update_all()\n        \n        # starting from 0.5\n        g.update_all(mfunc, rfunc)\n        ```\n    * `DGLGraph.group_apply_edges`. To normalize edge weights within the neighborhood of each destination node, use `dgl.nn.edge_softmax`. To normalize edge weights within the neighborhood of each source node, use `dgl.reverse` first before the edge softmax.\n    * `DGLGraph.send` and `DGLGraph.recv`. 
There are rarely any cases where send and recv must be invoked separately. Use `DGLGraph.send_and_recv` or `DGLGraph.update_all` for message passing.\n    * `DGLGraph.multi_recv`, `DGLGraph.multi_pull`, `DGLGraph.multi_send_and_recv`. To perform message passing on a part of the nodes and edges, use `dgl.node_subgraph` or `dgl.edge_subgraph` to extract the subset first and then call `DGLGraph.multi_update_all`.\n    * `DGLGraph.clear`. Use `dgl.graph(([], []))` to create a new empty graph.\n    * `DGLGraph.subgraphs`. Use `DGLGraph.subgraph`.\n    * `DGLGraph.batch_num_nodes` and `DGLGraph.batch_num_edges` are now functions that accept node/edge type as the only argument for getting batching information of a heterograph.\n    * `DGLGraph.flatten`. To create a new graph without batching information, use `new_g = dgl.graph(old_g.edges())`.\n1. The reduce function `dgl.function.prod` is removed.\n1. `dgl.add_self_loop` will NOT remove existing self loops automatically. It is recommended to call `dgl.remove_self_loop` before invoking `dgl.add_self_loop`.\n\n\n\n## Deprecations\n\nThe following changes will not break old code but will raise deprecation warnings.\n\n### Core APIs\n\n1. Creating a graph using `dgl.DGLGraph(data)` is deprecated. Use `dgl.graph(data)`.\n1. Deprecated `DGLGraph` methods:\n    - `DGLGraph.to_networkx` -> `dgl.to_networkx`\n    - `DGLGraph.readonly` and `DGLGraph.is_readonly`. Before 0.5, this flag is a hint for more efficient implementation. From 0.5, the efficiency issue has been resolved so they become useless. 
\n    - `DGLGraph.__len__` -> `DGLGraph.number_of_nodes`\n    - `dgl.DGLGraph.__contains__` -> `DGLGraph.has_nodes`\n    - `DGLGraph.add_node` -> `DGLGraph.add_nodes`\n    - `DGLGraph.add_edge` -> `DGLGraph.add_edges`\n    - `DGLGraph.has_node` -> `DGLGraph.has_nodes`\n    - `DGLGraph.has_edge_between` -> `DGLGraph.has_edges_between` \n    - `DGLGraph.edge_id` -> `dgl.DGLGraph.edge_ids`.\n    - `DGLGraph.in_degree` -> `dgl.DGLGraph.in_degrees`.\n    - `DGLGraph.out_degree` -> `dgl.DGLGraph.out_degrees`.\n1. `dgl.to_simple_graph` -> `dgl.to_simple`.\n1. `dgl.to_homo` -> `dgl.to_homogeneous`.\n1. `dgl.to_hetero` -> `dgl.to_heterogeneous`.\n1. `dgl.as_heterograph` and `dgl.as_immutable_graph` are deprecated as `dgl.DGLGraph` and `dgl.DGLHeteroGraph` are now merged.\n1. `dgl.batch_hetero` -> `dgl.batch`\n1. `dgl.unbatch_hetero` -> `dgl.unbatch`\n1. The `node_attrs` / `edge_attrs` arguments of `dgl.batch` are renamed to `ndata` / `edata`.\n1. The arguments `share_ndata` and `share_edata` of `dgl.reverse` are renamed to `copy_ndata` and `copy_edata`.\n\n### Dataset APIs\n\nFor all the current datasets, their class attributes such as `graph`, `feat`, etc. are deprecated. The recommended usage is to get them from each sample:\n```python\n# Before 0.5\ndataset = dgl.data.CoraFull()\ng = dataset.graph\nfeat = dataset.feat\n...\n\n# From 0.5\ndataset = dgl.data.CoraFullDataset()  # in 0.5, all the classes have a \"Dataset\" in the name.\ng = dataset[0]  # is directly a DGLGraph object\nfeat = g.ndata['feat']\n...\n```\n\n**Other changes**\n* ``dgl.data.SST`` is deprecated and replaced by ``dgl.data.SSTDataset``. The attribute ``trees`` is deprecated and replaced by ``__getitem__``. The attribute ``num_vocabs`` is deprecated and replaced by ``vocab_size``.\n"
  },
  {
    "path": "docs/source/_static/css/custom.css",
    "content": ".wy-table-responsive table td,\n.wy-table-responsive table th {\n  white-space: normal;\n}\n\n.wy-table-bordered-all,\n.rst-content table.docutils {\n  border: none;\n}\n\n.wy-table-bordered-all td,\n.rst-content table.docutils td {\n  border: none;\n}\n\n.wy-table td,\n.rst-content table.docutils td,\n.rst-content table.field-list td,\n.wy-table th,\n.rst-content table.docutils th,\n.rst-content table.field-list th {\n  padding: 14px;\n}"
  },
  {
    "path": "docs/source/_templates/classtemplate.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n.. currentmodule:: {{ module }}\n\n\n{{ name | underline}}\n\n.. autoclass:: {{ name }}\n    :show-inheritance:\n    :members: __getitem__, __len__, collate_fn, forward, reset_parameters, rel_emb, rel_project, explain_node, explain_graph, train_step, train_step_node\n"
  },
  {
    "path": "docs/source/_templates/graphbolt_classtemplate.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n.. currentmodule:: {{ module }}\n\n\n{{ name | underline}}\n\n.. autoclass:: {{ name }}\n    :show-inheritance:\n    :members:\n    :member-order: groupwise\n"
  },
  {
    "path": "docs/source/api/python/dgl.DGLGraph.rst",
    "content": ".. _apigraph:\n\ndgl.DGLGraph\n=====================================================\n\n.. currentmodule:: dgl\n.. class:: DGLGraph\n\n    Class for storing graph structure and node/edge feature data.\n\n    There are a few ways to create a DGLGraph:\n\n    * To create a homogeneous graph from Tensor data, use :func:`dgl.graph`.\n    * To create a heterogeneous graph from Tensor data, use :func:`dgl.heterograph`.\n    * To create a graph from other data sources, use ``dgl.*`` create ops. See\n      :ref:`api-graph-create-ops`.\n\n    Read the user guide chapter :ref:`guide-graph` for an in-depth explanation about its\n    usage.\n\nQuerying metagraph structure\n----------------------------\n\nMethods for getting information about the node and edge types. They are typically useful\nwhen the graph is heterogeneous.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.ntypes\n    DGLGraph.etypes\n    DGLGraph.srctypes\n    DGLGraph.dsttypes\n    DGLGraph.canonical_etypes\n    DGLGraph.metagraph\n    DGLGraph.to_canonical_etype\n\n.. _apigraph-querying-graph-structure:\n\nQuerying graph structure\n------------------------\n\nMethods for getting information about the graph structure such as capacity, connectivity,\nneighborhood, etc.\n\n.. 
autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.num_nodes\n    DGLGraph.number_of_nodes\n    DGLGraph.num_edges\n    DGLGraph.number_of_edges\n    DGLGraph.num_src_nodes\n    DGLGraph.number_of_src_nodes\n    DGLGraph.num_dst_nodes\n    DGLGraph.number_of_dst_nodes\n    DGLGraph.is_unibipartite\n    DGLGraph.is_multigraph\n    DGLGraph.is_homogeneous\n    DGLGraph.has_nodes\n    DGLGraph.has_edges_between\n    DGLGraph.predecessors\n    DGLGraph.successors\n    DGLGraph.edge_ids\n    DGLGraph.find_edges\n    DGLGraph.in_edges\n    DGLGraph.out_edges\n    DGLGraph.in_degrees\n    DGLGraph.out_degrees\n\nQuerying and manipulating sparse format\n---------------------------------------\n\nMethods for getting or manipulating the internal storage formats of a ``DGLGraph``.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.formats\n    DGLGraph.create_formats_\n\nQuerying and manipulating node/edge ID type\n-------------------------------------------\n\nMethods for getting or manipulating the data type for storing structure-related\ndata such as node and edge IDs.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.idtype\n    DGLGraph.long\n    DGLGraph.int\n\nUsing Node/edge features\n------------------------\n\nMethods for getting or setting node and edge feature data and for querying\nthe schemes of the stored features.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.nodes\n    DGLGraph.ndata\n    DGLGraph.edges\n    DGLGraph.edata\n    DGLGraph.node_attr_schemes\n    DGLGraph.edge_attr_schemes\n    DGLGraph.srcnodes\n    DGLGraph.dstnodes\n    DGLGraph.srcdata\n    DGLGraph.dstdata\n\nTransforming graph\n------------------\n\nMethods for generating a new graph by transforming the current one. Most of them\nare aliases of the :ref:`api-subgraph-extraction` and :ref:`api-transform`\nunder the ``dgl`` namespace.\n\n.. 
autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.subgraph\n    DGLGraph.edge_subgraph\n    DGLGraph.node_type_subgraph\n    DGLGraph.edge_type_subgraph\n    DGLGraph.__getitem__\n    DGLGraph.line_graph\n    DGLGraph.reverse\n    DGLGraph.add_self_loop\n    DGLGraph.remove_self_loop\n    DGLGraph.to_simple\n    DGLGraph.to_cugraph\n    DGLGraph.reorder_graph\n\nAdjacency and incidence matrix\n---------------------------------\n\nMethods for getting the adjacency and the incidence matrix of the graph.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.adj\n    DGLGraph.adjacency_matrix\n    DGLGraph.adj_tensors\n    DGLGraph.adj_external\n    DGLGraph.inc\n    DGLGraph.incidence_matrix\n\nComputing with DGLGraph\n-----------------------------\n\nMethods for performing message passing, applying functions on node/edge features, etc.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.apply_nodes\n    DGLGraph.apply_edges\n    DGLGraph.send_and_recv\n    DGLGraph.pull\n    DGLGraph.push\n    DGLGraph.update_all\n    DGLGraph.multi_update_all\n    DGLGraph.prop_nodes\n    DGLGraph.prop_edges\n    DGLGraph.filter_nodes\n    DGLGraph.filter_edges\n\nQuerying and manipulating batch information\n----------------------------------------------\n\nMethods for getting/setting the batching information if the current graph is a batched\ngraph generated from :func:`dgl.batch`. They are also widely used in the\n:ref:`api-batch`.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.batch_size\n    DGLGraph.batch_num_nodes\n    DGLGraph.batch_num_edges\n    DGLGraph.set_batch_num_nodes\n    DGLGraph.set_batch_num_edges\n\n\nMutating topology\n-----------------\n\nMethods for mutating the graph structure *in-place*.\n\n.. 
autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.add_nodes\n    DGLGraph.add_edges\n    DGLGraph.remove_nodes\n    DGLGraph.remove_edges\n\nDevice Control\n--------------\n\nMethods for getting or changing the device on which the graph is hosted.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.to\n    DGLGraph.device\n    DGLGraph.cpu\n    DGLGraph.pin_memory_\n    DGLGraph.unpin_memory_\n    DGLGraph.is_pinned\n\nMisc\n----\n\nOther utility methods.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    DGLGraph.local_scope\n"
  },
  {
    "path": "docs/source/api/python/dgl.data.rst",
    "content": ".. _apidata:\n\ndgl.data\n=========\n\n.. currentmodule:: dgl.data\n.. automodule:: dgl.data\n\nBase Class\n---------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    DGLDataset\n    CSVDataset\n\nNode Prediction Datasets\n---------------------------------------\n\nDatasets for node classification/regression tasks\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    SSTDataset\n    KarateClubDataset\n    CoraGraphDataset\n    CiteseerGraphDataset\n    PubmedGraphDataset\n    CoraFullDataset\n    AIFBDataset\n    MUTAGDataset\n    BGSDataset\n    AMDataset\n    AmazonCoBuyComputerDataset\n    AmazonCoBuyPhotoDataset\n    CoauthorCSDataset\n    CoauthorPhysicsDataset\n    PPIDataset\n    RedditDataset\n    SBMMixtureDataset\n    FraudDataset\n    FraudYelpDataset\n    FraudAmazonDataset\n    BAShapeDataset\n    BACommunityDataset\n    TreeCycleDataset\n    TreeGridDataset\n    WikiCSDataset\n    FlickrDataset\n    YelpDataset\n    PATTERNDataset\n    CLUSTERDataset\n    ChameleonDataset\n    SquirrelDataset\n    ActorDataset\n    CornellDataset\n    TexasDataset\n    WisconsinDataset\n    RomanEmpireDataset\n    AmazonRatingsDataset\n    MinesweeperDataset\n    TolokersDataset\n    QuestionsDataset\n    MovieLensDataset\n    VOCSuperpixelsDataset\n    COCOSuperpixelsDataset\n\n\nEdge Prediction Datasets\n---------------------------------------\n\nDatasets for edge classification/regression and link prediction\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    FB15k237Dataset\n    FB15kDataset\n    WN18Dataset\n    BitcoinOTCDataset\n    ICEWS18Dataset\n    GDELTDataset\n\nGraph Prediction Datasets\n---------------------------------------\n\nDatasets for graph classification/regression tasks\n\n.. 
autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    QM7bDataset\n    QM9Dataset\n    QM9EdgeDataset\n    MiniGCDataset\n    TUDataset\n    LegacyTUDataset\n    GINDataset\n    FakeNewsDataset\n    BA2MotifDataset\n    ZINCDataset\n    PeptidesStructuralDataset\n    PeptidesFunctionalDataset\n    MNISTSuperPixelDataset\n    CIFAR10SuperPixelDataset\n\nDataset adapters\n-------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    AsNodePredDataset\n    AsLinkPredDataset\n    AsGraphPredDataset\n\nUtilities\n-----------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    utils.get_download_dir\n    utils.download\n    utils.check_sha1\n    utils.extract_archive\n    utils.split_dataset\n    utils.load_labels\n    utils.save_info\n    utils.load_info\n    utils.add_nodepred_split\n    utils.mask_nodes_by_property\n    utils.add_node_property_split\n    utils.Subset\n"
  },
  {
    "path": "docs/source/api/python/dgl.dataloading.rst",
    "content": ".. _api-dataloading:\n\ndgl.dataloading\n=================================\n\n.. currentmodule:: dgl.dataloading\n\nThe ``dgl.dataloading`` package provides two primitives to compose a data pipeline\nfor loading from graph data. ``Sampler`` represents algorithms\nto generate subgraph samples from the original graph, and ``DataLoader``\nrepresents the iterable over these samples.\n\nDGL provides a number of built-in samplers that subclass :class:`~dgl.dataloading.Sampler`.\nCreating new samplers follow the same paradigm. Read our user guide chapter\n:ref:`guide-minibatch` for more examples and explanations.\n\nThe entire package only works for PyTorch backend.\n\nDataLoaders\n-----------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    DataLoader\n    GraphDataLoader\n\n.. _api-dataloading-neighbor-sampling:\n\nSamplers\n--------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    Sampler\n    NeighborSampler\n    LaborSampler\n    MultiLayerFullNeighborSampler\n    ClusterGCNSampler\n    ShaDowKHopSampler\n    SAINTSampler\n\nSampler Transformations\n-----------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    as_edge_prediction_sampler\n    BlockSampler\n\n.. _api-dataloading-negative-sampling:\n\nNegative Samplers for Link Prediction\n-------------------------------------\n.. currentmodule:: dgl.dataloading.negative_sampler\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    Uniform\n    PerSourceUniform\n    GlobalUniform\n\nUtility Class and Functions for Feature Prefetching\n---------------------------------------------------\n.. currentmodule:: dgl.dataloading.base\n\n.. 
autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    set_node_lazy_features\n    set_edge_lazy_features\n    set_src_lazy_features\n    set_dst_lazy_features\n    LazyFeature\n"
  },
  {
    "path": "docs/source/api/python/dgl.distributed.rst",
    "content": ".. _api-distributed:\n\ndgl.distributed\n=================================\n\n.. currentmodule:: dgl.distributed\n\nDGL distributed module contains classes and functions to support\ndistributed Graph Neural Network training and inference on a cluster of\nmachines.\n\nThis includes a few submodules:\n\n* distributed data structures including distributed graph, distributed tensor\n  and distributed embeddings.\n* distributed sampling.\n* distributed workload split at runtime.\n* graph partition.\n\n\nInitialization\n---------------\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    initialize\n\nDistributed Graph\n-----------------\n\n.. autoclass:: DistGraph\n    :members: ndata, edata, idtype, device, ntypes, etypes, number_of_nodes, number_of_edges, node_attr_schemes, edge_attr_schemes, rank, find_edges, get_partition_book, barrier, local_partition, num_nodes, num_edges, get_node_partition_policy, get_edge_partition_policy, get_etype_id, get_ntype_id, nodes, edges, out_degrees, in_degrees\n\nDistributed Tensor\n------------------\n\n.. autoclass:: DistTensor\n    :members: part_policy, shape, dtype, name\n\nDistributed Node Embedding\n--------------------------\n\n.. autoclass:: DistEmbedding\n\n\nDistributed embedding optimizer\n-------------------------------\n\n.. autoclass:: dgl.distributed.optim.SparseAdagrad\n    :members: step, save, load\n\n.. autoclass:: dgl.distributed.optim.SparseAdam\n    :members: step, save, load\n\nDistributed workload split\n--------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    node_split\n    edge_split\n\nDistributed Sampling\n--------------------\n\nDistributed DataLoader\n``````````````````````\n\n.. autoclass:: NodeCollator\n\n.. autoclass:: EdgeCollator\n\n.. autoclass:: DistDataLoader\n\n.. autoclass:: DistNodeDataLoader\n\n.. autoclass:: DistEdgeDataLoader\n\n.. _api-distributed-sampling-ops:\n\nDistributed Graph Sampling Operators\n```````````````````````````````````````\n\n.. 
autosummary::\n    :toctree: ../../generated/\n\n    sample_neighbors\n    sample_etype_neighbors\n    find_edges\n    in_subgraph\n\nPartition\n---------\n\nGraph partition book\n````````````````````\n\n.. autoclass:: GraphPartitionBook\n    :members: shared_memory, num_partitions, metadata, nid2partid, eid2partid, partid2nids, partid2eids, nid2localnid, eid2localeid, partid, map_to_per_ntype, map_to_per_etype, map_to_homo_nid, map_to_homo_eid, canonical_etypes\n\n.. autoclass:: PartitionPolicy\n    :members: policy_str, part_id, partition_book, to_local, to_partid, get_part_size, get_size\n\nSplit and Load Partitions\n````````````````````````````\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    load_partition\n    load_partition_feats\n    load_partition_book\n    partition_graph\n    dgl_partition_to_graphbolt\n"
  },
  {
    "path": "docs/source/api/python/dgl.function.rst",
    "content": ".. _apifunction:\n\n.. currentmodule:: dgl.function\n\ndgl.function\n==================================\n\nThis subpackage hosts all the **built-in functions** provided by DGL. Built-in functions\nare DGL's recommended way to express different types of :ref:`guide-message-passing` computation\n(i.e., via :func:`~dgl.DGLGraph.update_all`) or computing edge-wise features from\nnode-wise features (i.e., via :func:`~dgl.DGLGraph.apply_edges`). Built-in functions\ndescribe the node-wise and edge-wise computation in a symbolic way without any\nactual computation, so DGL can analyze and map them to efficient low-level kernels.\nHere are some examples:\n\n.. code:: python\n\n   import dgl\n   import dgl.function as fn\n   import torch as th\n   g = ... # create a DGLGraph\n   g.ndata['h'] = th.randn((g.num_nodes(), 10)) # each node has feature size 10\n   g.edata['w'] = th.randn((g.num_edges(), 1))  # each edge has feature size 1\n   # collect features from source nodes and aggregate them in destination nodes\n   g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))\n   # multiply source node features with edge weights and aggregate them in destination nodes\n   g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'h_max'))\n   # compute edge embedding by multiplying source and destination node embeddings\n   g.apply_edges(fn.u_mul_v('h', 'h', 'w_new'))\n\n``fn.copy_u``, ``fn.u_mul_e``, ``fn.u_mul_v`` are built-in message functions, while ``fn.sum``\nand ``fn.max`` are built-in reduce functions. DGL's convention is to use ``u``, ``v``\nand ``e`` to represent source nodes, destination nodes, and edges, respectively.\nFor example, ``copy_u`` tells DGL to copy the source node data as the messages;\n``u_mul_e`` tells DGL to multiply source node features with edge features.\n\nTo define a unary message function (e.g. ``copy_u``), specify one input feature name and one output\nmessage name. To define a binary message function (e.g. 
``u_mul_e``), specify\ntwo input feature names and one output message name. During the computation,\nthe message function will read the data under the given names, perform computation, and return\nthe output using the output name. For example, the above ``fn.u_mul_e('h', 'w', 'm')`` is\nthe same as the following user-defined function:\n\n.. code:: python\n\n   def udf_u_mul_e(edges):\n      return {'m' : edges.src['h'] * edges.data['w']}\n\nTo define a reduce function, one input message name and one output node feature name\nneed to be specified. For example, the above ``fn.max('m', 'h_max')`` is the same as the\nfollowing user-defined function:\n\n.. code:: python\n\n   def udf_max(nodes):\n      return {'h_max' : th.max(nodes.mailbox['m'], 1)[0]}\n\nAll binary message functions support **broadcasting**, a mechanism for extending element-wise\noperations to tensor inputs with different shapes. DGL generally follows the standard\nbroadcasting semantics of `NumPy <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_\nand `PyTorch <https://pytorch.org/docs/stable/notes/broadcasting.html>`_. Below are some\nexamples:\n\n.. code:: python\n\n   import dgl\n   import dgl.function as fn\n   import torch as th\n   g = ... 
# create a DGLGraph\n\n   # case 1\n   g.ndata['h'] = th.randn((g.num_nodes(), 10))\n   g.edata['w'] = th.randn((g.num_edges(), 1))\n   # OK, valid broadcasting between feature shapes (10,) and (1,)\n   g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))\n   g.ndata['h_new']  # shape: (g.num_nodes(), 10)\n\n   # case 2\n   g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))\n   g.edata['w'] = th.randn((g.num_edges(), 10))\n   # OK, valid broadcasting between feature shapes (5, 10) and (10,)\n   g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))\n   g.ndata['h_new']  # shape: (g.num_nodes(), 5, 10)\n\n   # case 3\n   g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))\n   g.edata['w'] = th.randn((g.num_edges(), 5))\n   # NOT OK, invalid broadcasting between feature shapes (5, 10) and (5,)\n   # shapes are aligned from the right\n   g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))\n\n   # case 4\n   g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10))\n   g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1))\n   # OK, valid broadcasting between feature shapes (1, 10) and (5, 1)\n   g.apply_edges(fn.u_add_v('h1', 'h2', 'x'))  # apply_edges also supports broadcasting\n   g.edata['x']  # shape: (g.num_edges(), 5, 10)\n\n   # case 5\n   g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10, 128))\n   g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1, 128))\n   # OK, u_dot_v supports broadcasting but requires the last dimension to match\n   g.apply_edges(fn.u_dot_v('h1', 'h2', 'x'))\n   g.edata['x']  # shape: (g.num_edges(), 5, 10, 1)\n\n\n.. 
_api-built-in:\n\nDGL Built-in Function\n-------------------------\n\nHere is a cheatsheet of all the DGL built-in functions.\n\n+-------------------------+-----------------------------------------------------------------+-----------------------+\n| Category                | Functions                                                       | Memo                  |\n+=========================+=================================================================+=======================+\n| Unary message function  | ``copy_u``                                                      |                       |\n|                         +-----------------------------------------------------------------+-----------------------+\n|                         | ``copy_e``                                                      |                       |\n+-------------------------+-----------------------------------------------------------------+-----------------------+\n| Binary message function | ``u_add_v``, ``u_sub_v``, ``u_mul_v``, ``u_div_v``, ``u_dot_v`` |                       |\n|                         +-----------------------------------------------------------------+-----------------------+\n|                         | ``u_add_e``, ``u_sub_e``, ``u_mul_e``, ``u_div_e``, ``u_dot_e`` |                       |\n|                         +-----------------------------------------------------------------+-----------------------+\n|                         | ``v_add_u``, ``v_sub_u``, ``v_mul_u``, ``v_div_u``, ``v_dot_u`` |                       |\n|                         +-----------------------------------------------------------------+-----------------------+\n|                         | ``v_add_e``, ``v_sub_e``, ``v_mul_e``, ``v_div_e``, ``v_dot_e`` |                       |\n|                         +-----------------------------------------------------------------+-----------------------+\n|                         | ``e_add_u``, ``e_sub_u``, ``e_mul_u``, ``e_div_u``, 
``e_dot_u`` |                       |\n|                         +-----------------------------------------------------------------+-----------------------+\n|                         | ``e_add_v``, ``e_sub_v``, ``e_mul_v``, ``e_div_v``, ``e_dot_v`` |                       |\n+-------------------------+-----------------------------------------------------------------+-----------------------+\n| Reduce function         | ``max``                                                         |                       |\n|                         +-----------------------------------------------------------------+-----------------------+\n|                         | ``min``                                                         |                       |\n|                         +-----------------------------------------------------------------+-----------------------+\n|                         | ``sum``                                                         |                       |\n|                         +-----------------------------------------------------------------+-----------------------+\n|                         | ``mean``                                                        |                       |\n+-------------------------+-----------------------------------------------------------------+-----------------------+\n\nMessage functions\n-----------------\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    copy_u\n    copy_e\n    u_add_v\n    u_sub_v\n    u_mul_v\n    u_div_v\n    u_add_e\n    u_sub_e\n    u_mul_e\n    u_div_e\n    v_add_u\n    v_sub_u\n    v_mul_u\n    v_div_u\n    v_add_e\n    v_sub_e\n    v_mul_e\n    v_div_e\n    e_add_u\n    e_sub_u\n    e_mul_u\n    e_div_u\n    e_add_v\n    e_sub_v\n    e_mul_v\n    e_div_v\n    u_dot_v\n    u_dot_e\n    v_dot_e\n    v_dot_u\n    e_dot_u\n    e_dot_v\n\nReduce functions\n----------------\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    sum\n    max\n    min\n    mean\n"
  },
  {
    "path": "docs/source/api/python/dgl.geometry.rst",
    "content": ".. _api-geometry:\n\ndgl.geometry\n=================================\n\n.. automodule:: dgl.geometry\n\n.. _api-geometry-farthest-point-sampler:\n\nFarthest Point Sampler\n-----------\n\nFarthest point sampling is a greedy algorithm that samples from a point cloud\ndata iteratively. It starts from a random single sample of point. In each iteration,\nit samples from the rest points that is the farthest from the set of sampled points.\n\n.. autoclass:: farthest_point_sampler\n\n.. _api-geometry-neighbor-matching:\n\nNeighbor Matching\n-----------------------------\n\nNeighbor matching is an important module in the Graclus clustering algorithm.\n\n.. autoclass:: neighbor_matching\n"
  },
  {
    "path": "docs/source/api/python/dgl.graphbolt.rst",
    "content": ".. _apibackend:\n\n🆕 dgl.graphbolt\n=================================\n\n.. currentmodule:: dgl.graphbolt\n\n**dgl.graphbolt** is a dataloading framework for GNNs that provides well-defined\nAPIs for each stage of the data pipeline and multiple standard implementations.\n\nDataset\n-------\n\nA dataset is a collection of graph structure data, feature data and tasks.\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: graphbolt_classtemplate.rst\n\n    Dataset\n    OnDiskDataset\n    BuiltinDataset\n    LegacyDataset\n    Task\n\nGraph\n-----\n\nA graph is a collection of nodes and edges. It can be a homogeneous graph or a\nheterogeneous graph.\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: graphbolt_classtemplate.rst\n\n    SamplingGraph\n    FusedCSCSamplingGraph\n\n\nFeature and FeatureStore\n------------------------\n\nA feature is a collection of data(tensor, array). A feature store is a\ncollection of features.\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: graphbolt_classtemplate.rst\n\n    Feature\n    FeatureStore\n    BasicFeatureStore\n    TorchBasedFeature\n    TorchBasedFeatureStore\n    DiskBasedFeature\n    CPUCachedFeature\n    GPUCachedFeature\n\n\nDataLoader\n----------\n\nA dataloader is for iterating over a dataset and generate mini-batches.\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: graphbolt_classtemplate.rst\n\n    DataLoader\n\n\nItemSet\n-------\n\nAn item set is an iterable collection of items.\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: graphbolt_classtemplate.rst\n\n    ItemSet\n    HeteroItemSet\n\n\nItemSampler\n-----------\n\nAn item sampler is for sampling items from an item set.\n\n.. 
autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: graphbolt_classtemplate.rst\n\n    ItemSampler\n    DistributedItemSampler\n\n\nMiniBatch\n---------\n\nA mini-batch is a collection of sampled subgraphs and their corresponding\nfeatures. It is the basic unit for training a GNN model.\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: graphbolt_classtemplate.rst\n\n    MiniBatch\n    MiniBatchTransformer\n\n\nNegativeSampler\n---------------\n\nA negative sampler is for sampling negative items from mini-batches.\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: graphbolt_classtemplate.rst\n\n    NegativeSampler\n    UniformNegativeSampler\n\n\nSubgraphSampler\n---------------\n\nA subgraph sampler is for sampling subgraphs from a graph.\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: graphbolt_classtemplate.rst\n\n    SubgraphSampler\n    SampledSubgraph\n    NeighborSampler\n    LayerNeighborSampler\n    TemporalNeighborSampler\n    TemporalLayerNeighborSampler\n    SampledSubgraphImpl\n    FusedSampledSubgraphImpl\n    InSubgraphSampler\n\n\nFeatureFetcher\n--------------\n\nA feature fetcher is for fetching features from a feature store.\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: graphbolt_classtemplate.rst\n\n    FeatureFetcher\n\n\nCopyTo\n------\n\nThis datapipe is for copying data to a device.\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: graphbolt_classtemplate.rst\n\n    CopyTo\n\n\nUtilities\n---------\n\n.. 
autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n\n    cpu_cached_feature\n    gpu_cached_feature\n    fused_csc_sampling_graph\n    load_from_shared_memory\n    from_dglgraph\n    etype_str_to_tuple\n    etype_tuple_to_str\n    isin\n    seed\n    index_select\n    expand_indptr\n    indptr_edge_ids\n    add_reverse_edges\n    exclude_seed_edges\n    compact_csc_format\n    unique_and_compact\n    unique_and_compact_csc_formats\n    numpy_save_aligned\n"
  },
  {
    "path": "docs/source/api/python/dgl.multiprocessing.rst",
    "content": ".. _apimultiprocessing:\n\ndgl.multiprocessing\n===================\n\nThis is a minimal wrapper of Python's native :mod:`multiprocessing` module.\nIt modifies the :class:`multiprocessing.Process` class to make forking\nwork with OpenMP in the DGL core library.\n\nThe API usage is exactly the same as the native module, so DGL does not provide\nadditional documentation.\n\nIn addition, if your backend is PyTorch, this module will also be compatible with\n:mod:`torch.multiprocessing` module.\n\n.. currentmodule:: dgl.multiprocessing.pytorch\n.. autosummary::\n    :toctree: ../../generated/\n\n    call_once_and_share\n    shared_tensor\n"
  },
  {
    "path": "docs/source/api/python/dgl.ops.rst",
    "content": ".. _apibackend:\n\n.. currentmodule:: dgl.ops\n\ndgl.ops\n==================================\n\nFrame-agnostic operators for message passing on graphs.\n\nGSpMM functions\n---------------\n\nGeneralized Sparse-Matrix Dense-Matrix Multiplication functions.\nIt *fuses* two steps into one kernel.\n\n1. Computes messages by add/sub/mul/div source node and edge features,\n   or copy node features to edges.\n2. Aggregate the messages by sum/max/min/mean as the features on destination nodes.\n\nOur implementation supports tensors on CPU/GPU in PyTorch/MXNet/Tensorflow\nas input. All operators are equipped with autograd (computing the input gradients\ngiven output gradient) and broadcasting (if the feature shape of operands do not\nmatch, we first broadcast them to the same shape, then applies the binary\noperators). Our broadcast semantics follows NumPy, please see\nhttps://docs.scipy.org/doc/numpy/user/basics.broadcasting.html\nfor more details.\n\nWhat do we mean by *fuses* is that the messages are not materialized on edges,\ninstead we compute the result on destination nodes directly, thus saving memory\ncost. 
The space complexity of GSpMM operators is :math:`O(|N|D)` where :math:`|N|`\nrefers to the number of nodes in the graph, and :math:`D` refers to the feature\nsize (:math:`D=\\prod_{i=1}^{N}D_i` if your feature is a multi-dimensional tensor).\n\nThe following is an example showing how GSpMM works (we use PyTorch as the backend\nhere, you can enjoy the same convenience on other frameworks by similar usage):\n\n   >>> import dgl\n   >>> import torch as th\n   >>> import dgl.ops as F\n   >>> g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))  # 3 nodes, 6 edges\n   >>> x = th.ones(3, 2, requires_grad=True)\n   >>> x\n   tensor([[1., 1.],\n           [1., 1.],\n           [1., 1.]], requires_grad=True)\n   >>> y = th.arange(1, 13).float().view(6, 2).requires_grad_()\n   tensor([[ 1.,  2.],\n           [ 3.,  4.],\n           [ 5.,  6.],\n           [ 7.,  8.],\n           [ 9., 10.],\n           [11., 12.]], requires_grad=True)\n   >>> out_1 = F.u_mul_e_sum(g, x, y)\n   >>> out_1  # (10, 12) = ((1, 1) * (3, 4)) + ((1, 1) * (7, 8))\n   tensor([[ 1.,  2.],\n           [10., 12.],\n           [25., 28.]], grad_fn=<GSpMMBackward>)\n   >>> out_1.sum().backward()\n   >>> x.grad\n   tensor([[12., 15.],\n           [18., 20.],\n           [12., 13.]])\n   >>> y.grad\n   tensor([[1., 1.],\n           [1., 1.],\n           [1., 1.],\n           [1., 1.],\n           [1., 1.],\n           [1., 1.]])\n   >>> out_2 = F.copy_u_sum(g, x)\n   >>> out_2\n   tensor([[1., 1.],\n           [2., 2.],\n           [3., 3.]], grad_fn=<GSpMMBackward>)\n   >>> out_3 = F.u_add_e_max(g, x, y)\n   >>> out_3\n   tensor([[ 2.,  3.],\n           [ 8.,  9.],\n           [12., 13.]], grad_fn=<GSpMMBackward>)\n   >>> y1 = th.rand(6, 4, 2, requires_grad=True)  # test broadcast\n   >>> F.u_mul_e_sum(g, x, y1).shape  # (2,), (4, 2) -> (4, 2)\n   torch.Size([3, 4, 2])\n\nFor all operators, the input graph could either be a homogeneous or a bipartite\ngraph.\n\n.. 
autosummary::\n    :toctree: ../../generated/\n\n    gspmm\n    u_add_e_sum\n    u_sub_e_sum\n    u_mul_e_sum\n    u_div_e_sum\n    u_add_e_max\n    u_sub_e_max\n    u_mul_e_max\n    u_div_e_max\n    u_add_e_min\n    u_sub_e_min\n    u_mul_e_min\n    u_div_e_min\n    u_add_e_mean\n    u_sub_e_mean\n    u_mul_e_mean\n    u_div_e_mean\n    copy_u_sum\n    copy_e_sum\n    copy_u_max\n    copy_e_max\n    copy_u_min\n    copy_e_min \n    copy_u_mean\n    copy_e_mean\n\nGSDDMM functions\n----------------\n\nGeneralized Sampled Dense-Dense Matrix Multiplication.\nIt computes edge features by add/sub/mul/div/dot features on source/destination\nnodes or edges.\n\nLike GSpMM, our implementation supports tensors on CPU/GPU in\nPyTorch/MXNet/Tensorflow as input. All operators are equipped with autograd and\nbroadcasting.\n\nThe memory cost of GSDDMM is :math:`O(|E|D)` where :math:`|E|` refers to the number\nof edges in the graph while :math:`D` refers to the feature size.\n\nNote that we support ``dot`` operator, which semantically is the same as reduce\nthe last dimension by sum to the result of ``mul`` operator. However, the ``dot``\nis more memory efficient because it *fuses* ``mul`` and sum reduction, which is\ncritical in the cases while the feature size on last dimension is non-trivial\n(e.g. 
multi-head attention in Transformer-like models).\n\nThe following is an example showing how GSDDMM works:\n\n   >>> import dgl\n   >>> import torch as th\n   >>> import dgl.ops as F\n   >>> g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))  # 3 nodes, 6 edges\n   >>> x = th.ones(3, 2, requires_grad=True)\n   >>> x\n   tensor([[1., 1.],\n           [1., 1.],\n           [1., 1.]], requires_grad=True)\n   >>> y = th.arange(1, 7).float().view(3, 2).requires_grad_()\n   >>> y\n   tensor([[1., 2.],\n           [3., 4.],\n           [5., 6.]], requires_grad=True)\n   >>> e = th.ones(6, 1, 2, requires_grad=True) * 2\n   tensor([[[2., 2.]],\n           [[2., 2.]],\n           [[2., 2.]],\n           [[2., 2.]],\n           [[2., 2.]],\n           [[2., 2.]]], grad_fn=<MulBackward0>)\n   >>> out1 = F.u_div_v(g, x, y)\n   tensor([[1.0000, 0.5000],\n           [0.3333, 0.2500],\n           [0.2000, 0.1667],\n           [0.3333, 0.2500],\n           [0.2000, 0.1667],\n           [0.2000, 0.1667]], grad_fn=<GSDDMMBackward>)\n   >>> out1.sum().backward()\n   >>> x.grad\n   tensor([[1.5333, 0.9167],\n           [0.5333, 0.4167],\n           [0.2000, 0.1667]])\n   >>> y.grad\n   tensor([[-1.0000, -0.2500],\n           [-0.2222, -0.1250],\n           [-0.1200, -0.0833]])\n   >>> out2 = F.e_sub_v(g, e, y)\n   >>> out2\n   tensor([[[ 1.,  0.]],\n           [[-1., -2.]],\n           [[-3., -4.]],\n           [[-1., -2.]],\n           [[-3., -4.]],\n           [[-3., -4.]]], grad_fn=<GSDDMMBackward>)\n   >>> out3 = F.copy_v(g, y)\n   >>> out3\n   tensor([[1., 2.],\n           [3., 4.],\n           [5., 6.],\n           [3., 4.],\n           [5., 6.],\n           [5., 6.]], grad_fn=<GSDDMMBackward>)\n   >>> out4 = F.u_dot_v(g, x, y)\n   >>> out4  # the last dimension was reduced to size 1.\n   tensor([[ 3.],\n           [ 7.],\n           [11.],\n           [ 7.],\n           [11.],\n           [11.]], grad_fn=<GSDDMMBackward>)\n\n.. 
autosummary::\n    :toctree: ../../generated/\n\n    gsddmm\n    u_add_v\n    u_sub_v\n    u_mul_v\n    u_dot_v\n    u_div_v\n    u_add_e\n    u_sub_e\n    u_mul_e\n    u_dot_e\n    u_div_e\n    e_add_v\n    e_sub_v\n    e_mul_v\n    e_dot_v\n    e_div_v\n    v_add_u\n    v_sub_u\n    v_mul_u\n    v_dot_u\n    v_div_u\n    e_add_u\n    e_sub_u\n    e_mul_u\n    e_dot_u\n    e_div_u\n    v_add_e\n    v_sub_e\n    v_mul_e\n    v_dot_e\n    v_div_e\n    copy_u\n    copy_v\n\nLike GSpMM, GSDDMM operators support both homogeneous and bipartite graphs.\n\nSegment Reduce Module\n---------------------\n\nDGL provides operators to reduce a value tensor along the first dimension by segments.\n\n.. autosummary::\n   :toctree: ../../generated/\n\n   segment_reduce\n\nGatherMM and SegmentMM Module\n-----------------------------\n\nSegmentMM: DGL provides operators to perform matrix multiplication according to segments.\n\nGatherMM: DGL provides operators to gather data according to the given indices and perform matrix multiplication.\n\n.. autosummary::\n   :toctree: ../../generated/\n\n   gather_mm\n   segment_mm\n\nSupported Data types\n--------------------\nOperators defined in ``dgl.ops`` support floating point data types, i.e. 
the operands\nmust be ``half`` (``float16``) /``float``/``double`` tensors.\nThe input tensors must have the same data type (if one input tensor has type float16\nand the other input tensor has data type float32, the user must convert one of them to\nalign with the other one).\n\n``float16`` data type support is disabled by default as it has a minimum GPU\ncompute capacity requirement of ``sm_53`` (Pascal, Volta, Turing and Ampere\narchitectures).\n\nUsers can enable float16 for mixed precision training by compiling DGL from source\n(see :doc:`Mixed Precision Training </guide/mixed_precision>` tutorial for details).\n\nRelation with Message Passing APIs\n----------------------------------\n\n``dgl.update_all`` and ``dgl.apply_edges`` calls with built-in message/reduce functions\nwould be dispatched into function calls of operators defined in ``dgl.ops``:\n\n    >>> import dgl\n    >>> import torch as th\n    >>> import dgl.ops as F\n    >>> import dgl.function as fn\n    >>> g = dgl.rand_graph(100, 1000)   # create a DGLGraph with 100 nodes and 1000 edges.\n    >>> x = th.rand(100, 20)            # node features.\n    >>> e = th.rand(1000, 20)\n    >>>\n    >>> # dgl.update_all + builtin functions\n    >>> g.srcdata['x'] = x              # srcdata is the same as ndata for graphs with one node type.\n    >>> g.edata['e'] = e\n    >>> g.update_all(fn.u_mul_e('x', 'e', 'm'), fn.sum('m', 'y'))\n    >>> y = g.dstdata['y']              # dstdata is the same as ndata for graphs with one node type.\n    >>>\n    >>> # use GSpMM operators defined in dgl.ops directly\n    >>> y = F.u_mul_e_sum(g, x, e)\n    >>>\n    >>> # dgl.apply_edges + builtin functions\n    >>> g.srcdata['x'] = x\n    >>> g.dstdata['y'] = y\n    >>> g.apply_edges(fn.u_dot_v('x', 'y', 'z'))\n    >>> z = g.edata['z']\n    >>>\n    >>> # use GSDDMM operators defined in dgl.ops directly\n    >>> z = F.u_dot_v(g, x, y)\n\nIt is up to the user to decide whether to use message-passing APIs or GSpMM/GSDDMM operators, and 
both\nof them have the same efficiency. Programs written in message-passing APIs look more like DGL-style\nbut in some cases calling GSpMM/GSDDMM operators is more concise.\n\nNote that on PyTorch all operators defined in ``dgl.ops`` support higher-order gradients, as do the\nmessage passing APIs, because they entirely depend on these operators.\n\n\n"
  },
  {
    "path": "docs/source/api/python/dgl.optim.rst",
    "content": ".. _apioptim:\n\ndgl.optim\n=========\n\n.. automodule:: dgl.optim\n\nNode embedding optimizer\n-------------------------\n.. currentmodule:: dgl.optim.pytorch\n\n.. autoclass:: SparseAdagrad\n.. autoclass:: SparseAdam"
  },
  {
    "path": "docs/source/api/python/dgl.rst",
    "content": ".. _apidgl:\n\ndgl\n=============================\n\n.. currentmodule:: dgl\n.. automodule:: dgl\n\n.. _api-graph-create-ops:\n\nGraph Create Ops\n-------------------------\n\nOperators for constructing :class:`DGLGraph` from raw data formats.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    graph\n    heterograph\n    from_cugraph\n    from_scipy\n    from_networkx\n    bipartite_from_scipy\n    bipartite_from_networkx\n    rand_graph\n    rand_bipartite\n    knn_graph\n    segmented_knn_graph\n    radius_graph\n    create_block\n    block_to_graph\n    merge\n\n.. _api-subgraph-extraction:\n\nSubgraph Extraction Ops\n-------------------------------------\n\nOperators for extracting and returning subgraphs.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    node_subgraph\n    edge_subgraph\n    node_type_subgraph\n    edge_type_subgraph\n    in_subgraph\n    out_subgraph\n    khop_in_subgraph\n    khop_out_subgraph\n\n.. _api-transform:\n\nGraph Transform Ops\n----------------------------------\n\nOperators for generating new graphs by manipulating the structure of the existing ones.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    add_edges\n    add_nodes\n    add_reverse_edges\n    add_self_loop\n    adj_product_graph\n    adj_sum_graph\n    compact_graphs\n    khop_adj\n    khop_graph\n    knn_graph\n    laplacian_lambda_max\n    line_graph\n    metapath_reachable_graph\n    metis_partition\n    metis_partition_assignment\n    norm_by_dst\n    partition_graph_with_halo\n    radius_graph\n    remove_edges\n    remove_nodes\n    remove_self_loop\n    reorder_graph\n    reverse\n    segmented_knn_graph\n    sort_csr_by_tag\n    sort_csc_by_tag\n    to_bidirected\n    to_bidirected_stale\n    to_block\n    to_cugraph\n    to_double\n    to_float\n    to_half\n    to_heterogeneous\n    to_homogeneous\n    to_networkx\n    to_simple\n    to_simple_graph\n\n.. 
_api-positional-encoding:\n\nGraph Positional Encoding Ops:\n-----------------------------------------\n\nOperators for generating positional encodings of each node.\n\n.. autosummary::\n    :toctree: ../../generated\n\n    random_walk_pe\n    lap_pe\n    double_radius_node_labeling\n    shortest_dist\n    svd_pe\n\n.. _api-partition:\n\nGraph Partition Utilities\n-------------------------\n.. autosummary::\n    :toctree: ../../generated/\n\n    metis_partition\n    metis_partition_assignment\n    partition_graph_with_halo\n\n.. _api-batch:\n\nBatching and Reading Out Ops\n-------------------------------\n\nOperators for batching multiple graphs into one for batch processing and\noperators for computing graph-level representations for both single and batched graphs.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    batch\n    unbatch\n    slice_batch\n    readout_nodes\n    readout_edges\n    sum_nodes\n    sum_edges\n    mean_nodes\n    mean_edges\n    max_nodes\n    max_edges\n    softmax_nodes\n    softmax_edges\n    broadcast_nodes\n    broadcast_edges\n    topk_nodes\n    topk_edges\n\nAdjacency Related Utilities\n-------------------------------\n\nUtilities for computing the adjacency matrix and the Laplacian matrix.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    khop_adj\n    laplacian_lambda_max\n\nGraph Traversal & Message Propagation\n------------------------------------------\n\nDGL implements graph traversal algorithms as Python generators,\nwhich return the visited set of nodes or edges (as an ID tensor) at each iteration.\nThe naming convention is ``<algorithm>_[nodes|edges]_generator``.\nAn example usage is as follows.\n\n.. code:: python\n\n    g = ...  # some DGLGraph\n    for nodes in dgl.bfs_nodes_generator(g, 0):\n        do_something(nodes)\n\n.. 
autosummary::\n    :toctree: ../../generated/\n\n    bfs_nodes_generator\n    bfs_edges_generator\n    topological_nodes_generator\n    dfs_edges_generator\n    dfs_labeled_edges_generator\n\nDGL provides APIs to perform message passing following graph traversal order. ``prop_nodes_XXX``\ncalls traversal algorithm ``XXX`` and triggers :func:`~DGLGraph.pull()` on the visited node\nset at each iteration. ``prop_edges_YYY`` applies traversal algorithm ``YYY`` and triggers\n:func:`~DGLGraph.send_and_recv()` on the visited edge set at each iteration.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    prop_nodes\n    prop_nodes_bfs\n    prop_nodes_topo\n    prop_edges\n    prop_edges_dfs\n\nHomophily Measures\n-------------------------\n\nUtilities for measuring homophily of a graph\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    edge_homophily\n    node_homophily\n    linkx_homophily\n    adjusted_homophily\n\nLabel Informativeness Measures\n------------------------------\n\nUtilities for measuring label informativeness of a graph\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    edge_label_informativeness\n    node_label_informativeness\n\nUtilities\n-----------------------------------------------\n\nOther utilities for controlling randomness, saving and loading graphs, setting and getting runtime configurations, applying\nthe same function to every element in a container, etc.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    seed\n    save_graphs\n    load_graphs\n    apply_each\n    use_libxsmm\n    is_libxsmm_enabled\n"
  },
  {
    "path": "docs/source/api/python/dgl.sampling.rst",
    "content": ".. _api-sampling:\n\ndgl.sampling\n=================================\n\n.. automodule:: dgl.sampling\n\nRandom walk\n------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    random_walk\n    node2vec_random_walk\n    pack_traces\n\nNeighbor sampling\n---------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    sample_neighbors\n    sample_labors\n    sample_neighbors_biased\n    select_topk\n    PinSAGESampler\n\nNegative sampling\n-----------------\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    global_uniform_negative_sampling\n"
  },
  {
    "path": "docs/source/api/python/dgl.sparse_v0.rst",
    "content": ".. _apibackend:\n\ndgl.sparse\n=================================\n\n`dgl.sparse` is a library for sparse operators that are commonly used in GNN models.\n\nSparse matrix class\n-------------------------\n.. currentmodule:: dgl.sparse\n\n.. class:: SparseMatrix\n\n    A SparseMatrix can be created from Coordinate format indices using the\n    :func:`spmatrix` constructor:\n\n        >>> indices = torch.tensor([[1, 1, 2],\n        >>>                         [2, 4, 3]])\n        >>> A = dglsp.spmatrix(indices)\n        SparseMatrix(indices=tensor([[1, 1, 2],\n                                     [2, 4, 3]]),\n                     values=tensor([1., 1., 1.]),\n                     shape=(3, 5), nnz=3)\n\nCreation Ops\n````````\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    spmatrix\n    val_like\n    from_coo\n    from_csr\n    from_csc\n    diag\n    identity\n\nAttributes and methods\n``````````````````````\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    SparseMatrix.shape\n    SparseMatrix.nnz\n    SparseMatrix.dtype\n    SparseMatrix.device\n    SparseMatrix.val\n    SparseMatrix.row\n    SparseMatrix.col\n    SparseMatrix.indices\n    SparseMatrix.coo\n    SparseMatrix.csr\n    SparseMatrix.csc\n    SparseMatrix.coalesce\n    SparseMatrix.has_duplicate\n    SparseMatrix.to_dense\n    SparseMatrix.to\n    SparseMatrix.cuda\n    SparseMatrix.cpu\n    SparseMatrix.float\n    SparseMatrix.double\n    SparseMatrix.int\n    SparseMatrix.long\n    SparseMatrix.transpose\n    SparseMatrix.t\n    SparseMatrix.T\n    SparseMatrix.neg\n    SparseMatrix.reduce\n    SparseMatrix.sum\n    SparseMatrix.smax\n    SparseMatrix.smin\n    SparseMatrix.smean\n    SparseMatrix.softmax\n\nOperators\n---------\n.. currentmodule:: dgl.sparse\n\nElementwise Operators\n````````\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    add\n    sub\n    mul\n    div\n    power\n\nMatrix Multiplication\n````````\n\n.. 
autosummary::\n    :toctree: ../../generated/\n\n    matmul\n    spmm\n    bspmm\n    spspmm\n    sddmm\n    bsddmm\n\nNon-linear activation functions\n````````\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    softmax\n\nBroadcast operators\n````````\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    sp_broadcast_v\n    sp_add_v\n    sp_sub_v\n    sp_mul_v\n    sp_div_v"
  },
  {
    "path": "docs/source/api/python/index.rst",
    "content": "API Reference\n=============\n\n.. toctree::\n   :maxdepth: 2\n\n   dgl\n   dgl.data\n   dgl.dataloading\n   dgl.DGLGraph\n   dgl.distributed\n   dgl.function\n   nn-pytorch\n   nn-tensorflow\n   nn-mxnet\n   dgl.ops\n   dgl.sampling\n   udf\n   transforms\n"
  },
  {
    "path": "docs/source/api/python/knn_benchmark.rst",
    "content": ".. _knn_benchmark:\n\nBenchmark the performance of KNN algorithms\n===========================================\n\nIn this doc, we benchmark the performance on multiple K-Nearest Neighbor algorithms implemented by :func:`dgl.knn_graph`.\n\nGiven a dataset of ``N`` samples with ``D`` dimensions, the common use case of KNN algorithms in graph learning is to build a KNN graph by finding the ``K`` nearest neighbors for each of the ``N`` samples among the dataset.\n\nEmpirically, the three parameters, ``N``, ``D``, and ``K``, all have impact on the computation cost. To benchmark the algorithms, we pick a few represensitive datasets to cover most common scenarios:\n\n* A synthetic dataset with mixed gaussian samples: ``N = 1000``, ``D = 3``.\n* A point cloud sample from ModelNet: ``N = 10000``, ``D = 3``.\n* Subsets of MNIST\n    - A small subset: ``N = 1000``, ``D = 784``\n    - A medium subset: ``N = 10000``, ``D = 784``\n    - A large subset: ``N = 50000``, ``D = 784``\n\nSome notes:\n\n* ``bruteforce-sharemem`` is an optimized implementation of ``bruteforce`` on GPU.\n* ``kd-tree`` is currently only implemented on CPU.\n* ``bruteforce-blas`` conducts matrix multiplication, thus is memory inefficient.\n* ``nn-descent`` is an approximate algorithm, and we also report the recall rate of its result.\n\nResults\n-------\n\nIn this section, we show the runtime and recall rate (where applicable) for the algorithms under various scenarios.\n\nThe experiments are run on an Amazon EC2 P3.2xlarge instance. This instance has 8 vCPUs with 61GB RAM, and one Tesla V100 GPU with 16GB RAM. 
In terms of the environment, we obtain the numbers with DGL==0.7.0(`64d0f3f <https://github.com/dmlc/dgl/commit/64d0f3f3554911ec06d015f1c9659180796adf9a>`_), PyTorch==1.8.1, CUDA==11.1 on Ubuntu 18.04.5 LTS.\n\n* **Mixed Gaussian:**\n\n+---------------------+------------------+-------------------+------------------+------------------+\n| Model               | CPU                                  | GPU                                 |\n|                     +------------------+-------------------+------------------+------------------+\n|                     | K = 8            | K = 64            | K = 8            | K = 64           |\n+=====================+==================+===================+==================+==================+\n| bruteforce-blas     | 0.010            | 0.011             | 0.002            | 0.003            |\n+---------------------+------------------+-------------------+------------------+------------------+\n| kd-tree             | 0.004            | 0.006             | n/a              | n/a              |\n+---------------------+------------------+-------------------+------------------+------------------+\n| bruteforce          | 0.004            | 0.006             | 0.126            | 0.009            |\n+---------------------+------------------+-------------------+------------------+------------------+\n| bruteforce-sharemem | n/a              | n/a               | 0.002            | 0.003            |\n+---------------------+------------------+-------------------+------------------+------------------+\n| nn-descent          | 0.014 (R: 0.985) | 0.148 (R: 1.000)  | 0.016 (R: 0.973) | 0.077 (R: 1.000) |\n+---------------------+------------------+-------------------+------------------+------------------+\n\n* **Point Cloud**\n\n+---------------------+------------------+-------------------+------------------+------------------+\n| Model               | CPU                                  | GPU                                 |\n|      
               +------------------+-------------------+------------------+------------------+\n|                     | K = 8            | K = 64            | K = 8            | K = 64           |\n+=====================+==================+===================+==================+==================+\n| bruteforce-blas     | 0.359            | 0.432             | 0.010            | 0.010            |\n+---------------------+------------------+-------------------+------------------+------------------+\n| kd-tree             | 0.007            | 0.026             | n/a              | n/a              |\n+---------------------+------------------+-------------------+------------------+------------------+\n| bruteforce          | 0.074            | 0.167             | 0.008            | 0.039            |\n+---------------------+------------------+-------------------+------------------+------------------+\n| bruteforce-sharemem | n/a              | n/a               | 0.004            | 0.017            |\n+---------------------+------------------+-------------------+------------------+------------------+\n| nn-descent          | 0.161 (R: 0.977) | 1.345 (R: 0.999)  | 0.086 (R: 0.966) | 0.445 (R: 0.999) |\n+---------------------+------------------+-------------------+------------------+------------------+\n\n* **Small MNIST**\n\n+---------------------+------------------+-------------------+------------------+------------------+\n| Model               | CPU                                  | GPU                                 |\n|                     +------------------+-------------------+------------------+------------------+\n|                     | K = 8            | K = 64            | K = 8            | K = 64           |\n+=====================+==================+===================+==================+==================+\n| bruteforce-blas     | 0.014            | 0.015             | 0.002            | 0.002            
|\n+---------------------+------------------+-------------------+------------------+------------------+\n| kd-tree             | 0.179            | 0.182             | n/a              | n/a              |\n+---------------------+------------------+-------------------+------------------+------------------+\n| bruteforce          | 0.173            | 0.228             | 0.123            | 0.170            |\n+---------------------+------------------+-------------------+------------------+------------------+\n| bruteforce-sharemem | n/a              | n/a               | 0.045            | 0.054            |\n+---------------------+------------------+-------------------+------------------+------------------+\n| nn-descent          | 0.060 (R: 0.878) | 1.077 (R: 0.999)  | 0.030 (R: 0.952) | 0.457 (R: 0.999) |\n+---------------------+------------------+-------------------+------------------+------------------+\n\n* **Medium MNIST**\n\n+---------------------+------------------+-------------------+------------------+------------------+\n| Model               | CPU                                  | GPU                                 |\n|                     +------------------+-------------------+------------------+------------------+\n|                     | K = 8            | K = 64            | K = 8            | K = 64           |\n+=====================+==================+===================+==================+==================+\n| bruteforce-blas     | 0.897            | 0.970             | 0.019            | 0.023            |\n+---------------------+------------------+-------------------+------------------+------------------+\n| kd-tree             | 18.902           | 18.928            | n/a              | n/a              |\n+---------------------+------------------+-------------------+------------------+------------------+\n| bruteforce          | 14.495           | 17.652            | 2.058            | 2.588            
|\n+---------------------+------------------+-------------------+------------------+------------------+\n| bruteforce-sharemem | n/a              | n/a               | 2.257            | 2.524            |\n+---------------------+------------------+-------------------+------------------+------------------+\n| nn-descent          | 0.804 (R: 0.755) | 14.108 (R: 0.999) | 0.158 (R: 0.900) | 1.794 (R: 0.999) |\n+---------------------+------------------+-------------------+------------------+------------------+\n\n* **Large MNIST**\n\n+---------------------+------------------+-------------------+------------------+------------------+\n| Model               | CPU                                  | GPU                                 |\n|                     +------------------+-------------------+------------------+------------------+\n|                     | K = 8            | K = 64            | K = 8            | K = 64           |\n+=====================+==================+===================+==================+==================+\n| bruteforce-blas     | 21.829           | 22.135            | Out of Memory    | Out of Memory    |\n+---------------------+------------------+-------------------+------------------+------------------+\n| kd-tree             | 542.688          | 573.379           | n/a              | n/a              |\n+---------------------+------------------+-------------------+------------------+------------------+\n| bruteforce          | 373.823          | 432.963           | 10.317           | 12.639           |\n+---------------------+------------------+-------------------+------------------+------------------+\n| bruteforce-sharemem | n/a              | n/a               | 53.133           | 58.419           |\n+---------------------+------------------+-------------------+------------------+------------------+\n| nn-descent          | 4.995 (R: 0.658) | 75.487 (R: 0.999) | 1.478 (R: 0.860) | 15.698 (R: 0.999)| 
\n+---------------------+------------------+-------------------+------------------+------------------+\n\nConclusion\n----------\n\n- As long as you have enough memory, ``bruteforce-blas`` is the default algorithm to go with.\n- Specifically, when ``D`` is small and the data is on CPU, ``kd-tree`` is the best algorithm.\n\n"
  },
  {
    "path": "docs/source/api/python/nn-mxnet.rst",
    "content": ".. _apinn-mxnet:\n\ndgl.nn (MXNet)\n================\n\nConv Layers\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.mxnet.conv.GraphConv\n    ~dgl.nn.mxnet.conv.RelGraphConv\n    ~dgl.nn.mxnet.conv.TAGConv\n    ~dgl.nn.mxnet.conv.GATConv\n    ~dgl.nn.mxnet.conv.EdgeConv\n    ~dgl.nn.mxnet.conv.SAGEConv\n    ~dgl.nn.mxnet.conv.SGConv\n    ~dgl.nn.mxnet.conv.APPNPConv\n    ~dgl.nn.mxnet.conv.GINConv\n    ~dgl.nn.mxnet.conv.GatedGraphConv\n    ~dgl.nn.mxnet.conv.GMMConv\n    ~dgl.nn.mxnet.conv.ChebConv\n    ~dgl.nn.mxnet.conv.AGNNConv\n    ~dgl.nn.mxnet.conv.NNConv\n\nDense Conv Layers\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.mxnet.conv.DenseGraphConv\n    ~dgl.nn.mxnet.conv.DenseSAGEConv\n    ~dgl.nn.mxnet.conv.DenseChebConv\n\nGlobal Pooling Layers\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.mxnet.glob.SumPooling\n    ~dgl.nn.mxnet.glob.AvgPooling\n    ~dgl.nn.mxnet.glob.MaxPooling\n    ~dgl.nn.mxnet.glob.SortPooling\n    ~dgl.nn.mxnet.glob.GlobalAttentionPooling\n    ~dgl.nn.mxnet.glob.Set2Set\n\nHeterogeneous Learning Modules\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.mxnet.HeteroGraphConv\n\nUtility Modules\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.mxnet.utils.Sequential\n"
  },
  {
    "path": "docs/source/api/python/nn-pytorch.rst",
    "content": ".. _apinn-pytorch:\n\ndgl.nn (PyTorch)\n================\n\nConv Layers\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.pytorch.conv.GraphConv\n    ~dgl.nn.pytorch.conv.EdgeWeightNorm\n    ~dgl.nn.pytorch.conv.RelGraphConv\n    ~dgl.nn.pytorch.conv.TAGConv\n    ~dgl.nn.pytorch.conv.GATConv\n    ~dgl.nn.pytorch.conv.GATv2Conv\n    ~dgl.nn.pytorch.conv.EGATConv\n    ~dgl.nn.pytorch.conv.EdgeGATConv\n    ~dgl.nn.pytorch.conv.EdgeConv\n    ~dgl.nn.pytorch.conv.SAGEConv\n    ~dgl.nn.pytorch.conv.SGConv\n    ~dgl.nn.pytorch.conv.APPNPConv\n    ~dgl.nn.pytorch.conv.GINConv\n    ~dgl.nn.pytorch.conv.GINEConv\n    ~dgl.nn.pytorch.conv.GatedGraphConv\n    ~dgl.nn.pytorch.conv.GatedGCNConv\n    ~dgl.nn.pytorch.conv.GMMConv\n    ~dgl.nn.pytorch.conv.ChebConv\n    ~dgl.nn.pytorch.conv.AGNNConv\n    ~dgl.nn.pytorch.conv.NNConv\n    ~dgl.nn.pytorch.conv.AtomicConv\n    ~dgl.nn.pytorch.conv.CFConv\n    ~dgl.nn.pytorch.conv.DotGatConv\n    ~dgl.nn.pytorch.conv.TWIRLSConv\n    ~dgl.nn.pytorch.conv.TWIRLSUnfoldingAndAttention\n    ~dgl.nn.pytorch.conv.GCN2Conv\n    ~dgl.nn.pytorch.conv.HGTConv\n    ~dgl.nn.pytorch.conv.GroupRevRes\n    ~dgl.nn.pytorch.conv.EGNNConv\n    ~dgl.nn.pytorch.conv.PNAConv\n    ~dgl.nn.pytorch.conv.DGNConv\n\nCuGraph Conv Layers\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.pytorch.conv.CuGraphRelGraphConv\n    ~dgl.nn.pytorch.conv.CuGraphGATConv\n    ~dgl.nn.pytorch.conv.CuGraphSAGEConv\n\nDense Conv Layers\n----------------------------------------\n\n.. 
autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.pytorch.conv.DenseGraphConv\n    ~dgl.nn.pytorch.conv.DenseSAGEConv\n    ~dgl.nn.pytorch.conv.DenseChebConv\n\nGlobal Pooling Layers\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.pytorch.glob.SumPooling\n    ~dgl.nn.pytorch.glob.AvgPooling\n    ~dgl.nn.pytorch.glob.MaxPooling\n    ~dgl.nn.pytorch.glob.SortPooling\n    ~dgl.nn.pytorch.glob.WeightAndSum\n    ~dgl.nn.pytorch.glob.GlobalAttentionPooling\n    ~dgl.nn.pytorch.glob.Set2Set\n    ~dgl.nn.pytorch.glob.SetTransformerEncoder\n    ~dgl.nn.pytorch.glob.SetTransformerDecoder\n\nScore Modules for Link Prediction and Knowledge Graph Completion\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.pytorch.link.EdgePredictor\n    ~dgl.nn.pytorch.link.TransE\n    ~dgl.nn.pytorch.link.TransR\n\nHeterogeneous Learning Modules\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.pytorch.HeteroGraphConv\n    ~dgl.nn.pytorch.HeteroLinear\n    ~dgl.nn.pytorch.HeteroEmbedding\n    ~dgl.nn.pytorch.TypedLinear\n\nUtility Modules\n----------------------------------------\n\n.. 
autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.pytorch.utils.Sequential\n    ~dgl.nn.pytorch.utils.WeightBasis\n    ~dgl.nn.pytorch.factory.KNNGraph\n    ~dgl.nn.pytorch.factory.SegmentedKNNGraph\n    ~dgl.nn.pytorch.factory.RadiusGraph\n    ~dgl.nn.pytorch.utils.JumpingKnowledge\n    ~dgl.nn.pytorch.sparse_emb.NodeEmbedding\n    ~dgl.nn.pytorch.explain.GNNExplainer\n    ~dgl.nn.pytorch.explain.HeteroGNNExplainer\n    ~dgl.nn.pytorch.explain.SubgraphX\n    ~dgl.nn.pytorch.explain.HeteroSubgraphX\n    ~dgl.nn.pytorch.explain.PGExplainer\n    ~dgl.nn.pytorch.explain.HeteroPGExplainer\n    ~dgl.nn.pytorch.utils.LabelPropagation\n    ~dgl.nn.pytorch.utils.LaplacianPosEnc\n\nNetwork Embedding Modules\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.pytorch.DeepWalk\n    ~dgl.nn.pytorch.MetaPath2Vec\n\nUtility Modules for Graph Transformer\n----------------------------------------\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.pytorch.gt.DegreeEncoder\n    ~dgl.nn.pytorch.gt.LapPosEncoder\n    ~dgl.nn.pytorch.gt.PathEncoder\n    ~dgl.nn.pytorch.gt.SpatialEncoder\n    ~dgl.nn.pytorch.gt.SpatialEncoder3d\n    ~dgl.nn.pytorch.gt.BiasedMHA\n    ~dgl.nn.pytorch.gt.GraphormerLayer\n    ~dgl.nn.pytorch.gt.EGTLayer\n"
  },
  {
    "path": "docs/source/api/python/nn-tensorflow.rst",
    "content": ".. _apinn-tensorflow:\n\ndgl.nn (TensorFlow)\n================\n\nConv Layers\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.tensorflow.conv.GraphConv\n    ~dgl.nn.tensorflow.conv.RelGraphConv\n    ~dgl.nn.tensorflow.conv.GATConv\n    ~dgl.nn.tensorflow.conv.SAGEConv\n    ~dgl.nn.tensorflow.conv.ChebConv\n    ~dgl.nn.tensorflow.conv.SGConv\n    ~dgl.nn.tensorflow.conv.APPNPConv\n    ~dgl.nn.tensorflow.conv.GINConv\n\nGlobal Pooling Layers\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.tensorflow.glob.SumPooling\n    ~dgl.nn.tensorflow.glob.AvgPooling\n    ~dgl.nn.tensorflow.glob.MaxPooling\n    ~dgl.nn.tensorflow.glob.SortPooling\n    ~dgl.nn.tensorflow.glob.GlobalAttentionPooling\n\nHeterogeneous Learning Modules\n----------------------------------------\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    ~dgl.nn.tensorflow.glob.HeteroGraphConv\n"
  },
  {
    "path": "docs/source/api/python/nn.functional.rst",
    "content": ".. _apinn-functional:\n\ndgl.nn.functional\n=================\n\n.. automodule:: dgl.nn.functional\n\n.. autosummary::\n    :toctree: ../../generated/\n\n   edge_softmax\n"
  },
  {
    "path": "docs/source/api/python/transforms.rst",
    "content": ".. _apitransform-namespace:\n\ndgl.transforms\n==============\n\n.. currentmodule:: dgl.transforms\n.. automodule:: dgl.transforms\n\n.. autosummary::\n    :toctree: ../../generated/\n    :nosignatures:\n    :template: classtemplate.rst\n\n    BaseTransform\n    Compose\n    AddSelfLoop\n    RemoveSelfLoop\n    AddReverse\n    ToSimple\n    LineGraph\n    KHopGraph\n    AddMetaPaths\n    GCNNorm\n    PPR\n    HeatKernel\n    GDC\n    NodeShuffle\n    DropNode\n    DropEdge\n    AddEdge\n    RandomWalkPE\n    LapPE\n    FeatMask\n    RowFeatNormalizer\n    SIGNDiffusion\n    ToLevi\n    SVDPE\n"
  },
  {
    "path": "docs/source/api/python/udf.rst",
    "content": ".. _apiudf:\n\nUser-defined Functions\n==================================================\n\n.. currentmodule:: dgl.udf\n\nUser-defined functions (UDFs) allow arbitrary computation in message passing\n(see :ref:`guide-message-passing`) and edge feature update with\n:func:`~dgl.DGLGraph.apply_edges`. They bring more flexibility when :ref:`apifunction`\ncannot realize a desired computation.\n\nEdge-wise User-defined Function\n-------------------------------\n\nOne can use an edge-wise user defined function for a message function in message passing or\na function to apply in :func:`~dgl.DGLGraph.apply_edges`. It takes a batch of edges as input\nand returns messages (in message passing) or features (in :func:`~dgl.DGLGraph.apply_edges`)\nfor each edge. The function may combine the features of the edges and their end nodes in\ncomputation.\n\nFormally, it takes the following form\n\n.. code::\n\n    def edge_udf(edges):\n        \"\"\"\n        Parameters\n        ----------\n        edges : EdgeBatch\n            A batch of edges.\n\n        Returns\n        -------\n        dict[str, tensor]\n            The messages or edge features generated. It maps a message/feature name to the\n            corresponding messages/features of all edges in the batch. The order of the\n            messages/features is the same as the order of the edges in the input argument.\n        \"\"\"\n\nDGL generates :class:`~dgl.udf.EdgeBatch` instances internally, which expose the following\ninterface for defining ``edge_udf``.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    EdgeBatch.src\n    EdgeBatch.dst\n    EdgeBatch.data\n    EdgeBatch.edges\n    EdgeBatch.batch_size\n\nNode-wise User-defined Function\n-------------------------------\n\nOne can use a node-wise user defined function for a reduce function in message passing. It takes\na batch of nodes as input and returns the updated features for each node. 
It may combine the\ncurrent node features and the messages nodes received. Formally, it takes the following form\n\n.. code::\n\n    def node_udf(nodes):\n        \"\"\"\n        Parameters\n        ----------\n        nodes : NodeBatch\n            A batch of nodes.\n\n        Returns\n        -------\n        dict[str, tensor]\n            The updated node features. It maps a feature name to the corresponding features of\n            all nodes in the batch. The order of the nodes is the same as the order of the nodes\n            in the input argument.\n        \"\"\"\n\nDGL generates :class:`~dgl.udf.NodeBatch` instances internally, which expose the following\ninterface for defining ``node_udf``.\n\n.. autosummary::\n    :toctree: ../../generated/\n\n    NodeBatch.data\n    NodeBatch.mailbox\n    NodeBatch.nodes\n    NodeBatch.batch_size\n\nDegree Bucketing for Message Passing with User Defined Functions\n----------------------------------------------------------------\n\nDGL employs a degree-bucketing mechanism for message passing with UDFs. It groups nodes with\na same in-degree and invokes message passing for each group of nodes. As a result, one shall\nnot make any assumptions about the batch size of :class:`~dgl.udf.NodeBatch` instances.\n\nFor a batch of nodes, DGL stacks the incoming messages of each node along the second dimension,\nordered by edge ID.  An example goes as follows:\n\n.. code:: python\n\n    >>> import dgl\n    >>> import torch\n    >>> import dgl.function as fn\n    >>> g = dgl.graph(([1, 3, 5, 0, 4, 2, 3, 3, 4, 5], [1, 1, 0, 0, 1, 2, 2, 0, 3, 3]))\n    >>> g.edata['eid'] = torch.arange(10)\n    >>> def reducer(nodes):\n    ...     print(nodes.mailbox['eid'])\n    ...     
return {'n': nodes.mailbox['eid'].sum(1)}\n    >>> g.update_all(fn.copy_e('eid', 'eid'), reducer)\n    tensor([[5, 6],\n            [8, 9]])\n    tensor([[3, 7, 2],\n            [0, 1, 4]])\n\nEssentially, node #2 and node #3 are grouped into one bucket with in-degree of 2, and node\n#0 and node #1 are grouped into one bucket with in-degree of 3.  Within each bucket, the\nedges are ordered by the edge IDs for each node.\n"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Configuration file for the Sphinx documentation builder.\n#\n# This file does only contain a selection of the most common options. For a\n# full list see the documentation:\n# http://www.sphinx-doc.org/en/master/config\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\n\nsys.path.insert(0, os.path.abspath(\"../../python\"))\n\n\n# -- Project information -----------------------------------------------------\n\nproject = \"DGL\"\ncopyright = \"2018, DGL Team\"\nauthor = \"DGL Team\"\n\nimport dgl\n\nversion = dgl.__version__\nrelease = dgl.__version__\ndglbackend = os.environ.get(\"DGLBACKEND\", \"pytorch\")\n\n\n# -- General configuration ---------------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\n#\n# needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx.ext.coverage\",\n    \"sphinx.ext.mathjax\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx.ext.viewcode\",\n    \"sphinx.ext.intersphinx\",\n    \"sphinx.ext.graphviz\",\n    \"sphinxemoji.sphinxemoji\",\n    \"sphinx_gallery.gen_gallery\",\n    \"sphinx_copybutton\",\n    \"nbsphinx\",\n    \"nbsphinx_link\",\n]\n\n# Do not run notebooks on non-pytorch backends\nif dglbackend != \"pytorch\":\n    nbsphinx_execute = \"never\"\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\n#\nsource_suffix = [\".rst\", \".md\"]\n\n# The master toctree document.\nmaster_doc = \"index\"\n\n# The language for content autogenerated by Sphinx. Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = None\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = [\n    \"tutorials/**/*.ipynb\",\n    \"tutorials/**/*.py\",\n]\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = None\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"sphinx_rtd_theme\"\n\n# Theme options are theme-specific and customize the look and feel of a theme\n# further.  
For a list of options available for each theme, see the\n# documentation.\n#\n# html_theme_options = {}\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = [\"_static\"]\nhtml_css_files = [\"css/custom.css\"]\n\n# Custom sidebar templates, must be a dictionary that maps document names\n# to template names.\n#\n# The default sidebars (for documents that don't match any pattern) are\n# defined by theme itself.  Builtin themes are using these templates by\n# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',\n# 'searchbox.html']``.\n#\n# html_sidebars = {}\n\n\n# -- Options for HTMLHelp output ---------------------------------------------\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = \"dgldoc\"\n\n\n# -- Options for LaTeX output ------------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #\n    # 'papersize': 'letterpaper',\n    # The font size ('10pt', '11pt' or '12pt').\n    #\n    # 'pointsize': '10pt',\n    # Additional stuff for the LaTeX preamble.\n    #\n    # 'preamble': '',\n    # Latex figure (float) alignment\n    #\n    # 'figure_align': 'htbp',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title,\n#  author, documentclass [howto, manual, or own class]).\nlatex_documents = [\n    (master_doc, \"dgl.tex\", \"DGL Documentation\", \"DGL Team\", \"manual\"),\n]\n\n\n# -- Options for manual page output ------------------------------------------\n\n# One entry per manual page. 
List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [(master_doc, \"dgl\", \"DGL Documentation\", [author], 1)]\n\n\n# -- Options for Texinfo output ----------------------------------------------\n\n# Grouping the document tree into Texinfo files. List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    (\n        master_doc,\n        \"dgl\",\n        \"DGL Documentation\",\n        author,\n        \"dgl\",\n        \"Library for deep learning on graphs.\",\n        \"Miscellaneous\",\n    ),\n]\n\n\n# -- Options for Epub output -------------------------------------------------\n\n# Bibliographic Dublin Core info.\nepub_title = project\n\n# The unique identifier of the text. This can be a ISBN number\n# or the project homepage.\n#\n# epub_identifier = ''\n\n# A unique identification for the text.\n#\n# epub_uid = ''\n\n# A list of files that should not be packed into the epub file.\nepub_exclude_files = [\"search.html\"]\n\n\n# -- Extension configuration -------------------------------------------------\nautosummary_generate = True\nautodoc_member_order = \"alphabetical\"\n# Skip the following members.\nautodoc_mock_imports = [\"dgl.nn.mxnet\", \"dgl.nn.tensorflow\"]\n\nintersphinx_mapping = {\n    \"python\": (\n        \"https://docs.python.org/{.major}\".format(sys.version_info),\n        None,\n    ),\n    \"numpy\": (\"http://docs.scipy.org/doc/numpy/\", None),\n    \"scipy\": (\"http://docs.scipy.org/doc/scipy/reference\", None),\n    \"matplotlib\": (\"http://matplotlib.org/\", None),\n    \"networkx\": (\"https://networkx.github.io/documentation/stable\", None),\n}\n\n# sphinx gallery configurations\nfrom sphinx_gallery.sorting import FileNameSortKey\n\nexamples_dirs = [\n    \"../../tutorials/blitz\",\n    \"../../tutorials/dist\",\n    \"../../tutorials/models\",\n    \"../../tutorials/multi\",\n    \"../../tutorials/cpu\",\n]  # 
path to find sources\ngallery_dirs = [\n    \"tutorials/blitz/\",\n    \"tutorials/dist/\",\n    \"tutorials/models/\",\n    \"tutorials/multi/\",\n    \"tutorials/cpu\",\n]  # path to generate docs\nif dglbackend != \"pytorch\":\n    examples_dirs = []\n    gallery_dirs = []\n\nreference_url = {\n    \"dgl\": None,\n    \"numpy\": \"http://docs.scipy.org/doc/numpy/\",\n    \"scipy\": \"http://docs.scipy.org/doc/scipy/reference\",\n    \"matplotlib\": \"http://matplotlib.org/\",\n    \"networkx\": \"https://networkx.github.io/documentation/stable\",\n}\n\nsphinx_gallery_conf = {\n    \"backreferences_dir\": \"generated/backreferences\",\n    \"doc_module\": (\"dgl\", \"numpy\"),\n    \"examples_dirs\": examples_dirs,\n    \"gallery_dirs\": gallery_dirs,\n    \"within_subsection_order\": FileNameSortKey,\n    \"filename_pattern\": \".py\",\n    \"download_all_examples\": False,\n}\n\n# Compatibility for different backend when builds tutorials\nif dglbackend == \"mxnet\":\n    sphinx_gallery_conf[\"filename_pattern\"] = \"/*(?<=mx)\\.py\"\nif dglbackend == \"pytorch\":\n    sphinx_gallery_conf[\"filename_pattern\"] = \"/*(?<!mx)\\.py\"\n\n# sphinx-copybutton tool\ncopybutton_prompt_text = r\">>> |\\.\\.\\. \"\ncopybutton_prompt_is_regexp = True\n"
  },
  {
    "path": "docs/source/contribute.rst",
    "content": "Contribute to DGL\n=================\n\nAny contribution to DGL is welcome. This guide covers everything\nabout how to contribute to DGL.\n\nGeneral development process\n---------------------------\n\nA non-inclusive list of types of contribution is as follows:\n\n* New features and enhancements (`example <https://github.com/dmlc/dgl/pull/331>`__).\n* New NN Modules (`example <https://github.com/dmlc/dgl/pull/788>`__).\n* Bugfix (`example <https://github.com/dmlc/dgl/pull/247>`__).\n* Document improvement (`example <https://github.com/dmlc/dgl/pull/263>`__).\n* New models and examples (`example <https://github.com/dmlc/dgl/pull/279>`__).\n\nFor features and bugfix, we recommend first raise an `issue <https://github.com/dmlc/dgl/issues>`__\nusing the corresponding issue template, so that the change could be fully discussed with\nthe community before implementation. For document improvement and new models, we suggest\npost a thread in our `discussion forum <https://discuss.dgl.ai>`__.\n\nBefore development, please first read the following sections about coding styles and testing.\nAll the changes need to be reviewed in the form of `pull request <https://github.com/dmlc/dgl/pulls>`__.\nOur `committors <https://github.com/orgs/dmlc/teams/dgl-team/members>`__\n(who have write permission on the repository) will review the codes and suggest the necessary\nchanges. The PR could be merged once the reviewers approve the changes.\n\nGit setup (for developers)\n--------------------------\n\nFirst, fork the DGL github repository. Suppose the forked repo is ``https://github.com/username/dgl``.\n\nClone your forked repository locally:\n\n.. code-block:: bash\n\n   git clone --recursive https://github.com/username/dgl.git\n\n\nSetup the upstream to the DGL official repository:\n\n.. code-block:: bash\n\n   git remote add upstream https://github.com/dmlc/dgl.git\n\nYou could verify the remote setting by typing ``git remote -v``:\n\n.. 
code-block:: bash\n\n   origin  https://github.com/username/dgl.git (fetch)\n   origin  https://github.com/username/dgl.git (push)\n   upstream        https://github.com/dmlc/dgl.git (fetch)\n   upstream        https://github.com/dmlc/dgl.git (push)\n\nDuring developing, we suggest work on another branch than the master.\n\n.. code-block:: bash\n\n   git branch working-branch\n   git checkout working-branch\n\nOnce the changes are done, `create a pull request <https://help.github.com/articles/creating-a-pull-request/>`__\nso we could review your codes.\n\nOnce the pull request is merged, update your forked repository and delete your working branch:\n\n.. code-block:: bash\n\n   git checkout master\n   git pull upstream master\n   git push origin master  # update your forked repo\n   git branch -D working-branch  # the local branch could be deleted\n\nCoding styles\n-------------\n\nFor python codes, we generally follow the `PEP8 style guide <https://www.python.org/dev/peps/pep-0008/>`__.\nThe python comments follow `NumPy style python docstrings <https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_numpy.html>`__.\n\nFor C++ codes, we generally follow the `Google C++ style guide <https://google.github.io/styleguide/cppguide.html>`__.\nThe C++ comments should be `Doxygen compatible <http://www.doxygen.nl/manual/docblocks.html#cppblock>`__.\n\nCoding styles check is mandatory for every pull requests. To ease the development, please check it\nlocally first (require cpplint and pylint to be installed first):\n\n.. code-block:: bash\n\n   bash tests/scripts/task_lint.sh\n\nThe python code style configure file is ``tests/lint/pylintrc``. We tweak it a little bit from\nthe standard. 
For example, following variable names are accepted:\n\n* ``i,j,k``: for loop variables\n* ``u,v``: for representing nodes\n* ``e``: for representing edges\n* ``g``: for representing graph\n* ``fn``: for representing functions\n* ``n,m``: for representing sizes\n* ``w,x,y``: for representing weight, input, output tensors\n* ``_``: for unused variables\n\nContributing New Models as Examples\n-----------------------------------\n\nTo contribute a new model within a specific supported tensor framework (e.g. PyTorch, or MXNet), simply\n\n1. Make a directory with the name of your model (say ``awesome-gnn``) within the directory\n   ``examples/${DGLBACKEND}`` where ``${DGLBACKEND}`` refers to the framework name.\n   \n2. Populate it with your work, along with a README.  Make a pull request once you are done.  Your README should contain at least these:\n\n   * Instructions for running your program.\n   \n   * The performance results, such as speed or accuracy or any metric, along with comparisons against some alternative implementations (if available).\n   \n     * Your performance metric does not have to beat others' implementation; they are just a signal of your code being *likely* correct.\n     \n     * Your speed also does not have to surpass others'.\n     \n     * However, better numbers are always welcomed.\n   \n3. The committers will review it, suggesting or making changes as necessary.\n\n4. Resolve the suggestions and reviews, and go back to step 3 until approved.\n\n5. Merge it and enjoy your day.\n\nData hosting\n````````````\n\nOne often wishes to upload a dataset when contributing a new runnable model example, especially when covering\na new field not in our existing examples.\n\nUploading data file into the Git repository directly is a **bad idea** because we do not want the cloners to\nalways download the dataset no matter what.  Instead, we strongly suggest the data files be hosted on a\npermanent cloud storage service (e.g. 
DropBox, Amazon S3, Baidu, Google Drive, etc.).\n\nOne can either\n\n* Make your scripts automatically download your data if possible (e.g. when using Amazon S3), or\n* Clearly state the instructions of downloading your dataset (e.g. when using Baidu, where auto-downloading\n  is hard).\n  \nIf you have trouble doing so (e.g. you cannot find a permanent cloud storage), feel free to post in our\n`discussion forum <https://discuss.dgl.ai>`__.\n\nDepending on the commonality of the contributed task, model, or dataset, we (the DGL team) would migrate\nyour dataset to the official DGL Dataset Repository on Amazon S3.  If you wish to host a particular dataset,\nyou can either\n\n* DIY: make changes in the ``dgl.data`` module; see our :ref:`dataset APIs <apidata>` for more details, or,\n* Post in our `discussion forum <https://discuss.dgl.ai>`__ (again).\n\nCurrently, all the datasets of DGL model examples are hosted on Amazon S3.\n\nContributing Core Features\n--------------------------\n\nWe call a feature that goes into the Python ``dgl`` package a *core feature*.\n\nSince DGL supports multiple tensor frameworks, contributing a core feature is no easy job.  However, we do\n**NOT** require knowledge of all tensor frameworks.  Instead,\n\n1. Before making a pull request, please make sure your code is covered with unit tests on **at least one**\n   supported frameworks; see the `Building and Testing`_ section for details.\n2. Once you have done that, make a pull request and summarize your changes, and wait for the CI to finish.\n3. If the CI fails on a tensor platform that you are unfamiliar with (which is well often the case), please\n   refer to `Supporting Multiple Platforms`_ section.\n4. The committers will review it, suggesting or making changes as necessary.\n5. Resolve the suggestions and reviews, and go back to step 3 until approved.\n6. 
Merge it and enjoy your day.\n\nSupporting Multiple Platforms\n`````````````````````````````\n\nThis is the hard one, but you don't have to know PyTorch AND MXNet (maybe AND Tensorflow, AND Chainer, etc.,\nin the future) to do so.  The rule of thumb in supporting Multiple Platforms is simple:\n\n* In the ``dgl`` Python package, **always** avoid using framework-specific operators (*including array indexing!*)\n  directly.  Use the wrappers in ``dgl.backend`` or ``numpy`` arrays instead.\n* If you have trouble doing so (either because ``dgl.backend`` does not cover the necessary operator, or you don't\n  have a GPU, or for whatever reason), please label your PR with the ``backend support`` tag, and one or more DGL\n  team member who understand CPU AND GPU AND PyTorch AND MXNet (AND Tensorflow AND Chainer AND etc.) will\n  look into it.\n\nBuilding and Testing\n````````````````````\n\nTo build DGL locally, follow the steps described in :ref:`Install from source <install-from-source>`.\nHowever, to ease the development, we suggest NOT install DGL but directly working in the source tree.\nTo achieve this, export following environment variables:\n\n.. code-block:: bash\n\n   export DGL_HOME=/path/to/your/dgl/clone\n   export DGL_LIBRARY_PATH=$DGL_HOME/build\n   export PYTHONPATH=$PYTHONPATH:$DGL_HOME/python\n\nIf you are working on performance critical part, you may want to turn on Cython build:\n\n.. code-block:: bash\n\n   cd python\n   python setup.py build_ext --inplace\n\nYou could test the build by running the following command and see the path of your local clone.\n\n.. code-block:: bash\n\n   python -c 'import dgl; print(dgl.__path__)'\n\nUnit tests\n~~~~~~~~~~\n\nCurrently, we use ``nose`` for unit tests.  The organization goes as follows:\n\n* ``backend``: Additional unified tensor interface for supported frameworks.\n  The functions there are only used in unit tests, not DGL itself.  Note that\n  the code there are not unit tests by themselves.  
The additional backend can\n  be imported with\n  \n  .. code-block:: python\n\n     import backend\n\n  The additional backend contains the following files:\n\n  - ``backend/backend_unittest.py``: stub file for all additional tensor\n    functions.\n  - ``backend/${DGLBACKEND}/__init__.py``: implementations of the stubs\n    for the backend ``${DGLBACKEND}``.\n  - ``backend/__init__.py``: when imported, it replaces the stub implementations\n    with the framework-specific code, depending on the selected backend.  It\n    also changes the signature of some existing backend functions to automatically\n    select dtypes and contexts.\n\n* ``compute``: All framework-agnostic computation-related unit tests go there.\n  Anything inside should not depend on a specific tensor library.  Tensor\n  functions not provided in DGL unified tensor interface (i.e. ``dgl.backend``)\n  should go into ``backend`` directory.\n* ``${DGLBACKEND}`` (e.g. ``pytorch`` and ``mxnet``): All framework-specific\n  computation-related unit tests go there.\n* ``graph_index``: All unit tests for C++ graph structure implementation go\n  there.  The Python API being tested in this directory, if any, should be\n  as minimal as possible (usually simple wrappers of corresponding C++\n  functions).\n* ``lint``: Pylint-related files.\n* ``scripts``: Automated test scripts for CI.\n\nTo run unit tests, run\n\n.. code-block:: bash\n\n   sh tests/scripts/task_unit_test.sh <your-backend>\n\nwhere ``<your-backend>`` can be any supported backends (i.e. ``pytorch`` or ``mxnet``).\n\nContributing Documentations\n---------------------------\n\nIf the change is about document improvement, we suggest (and strongly suggest if you change the runnable code\nthere) building the document and render it locally before making a pull request.\n\nBuilding Docs Locally\n`````````````````````\n\nIn general building the docs locally involves the following:\n\n1. 
Install ``sphinx``, ``sphinx-gallery``, and ``sphinx_rtd_theme``.\n\n2. You need both PyTorch and MXNet because our tutorial contains code from both frameworks.  This does *not*\n   require knowledge of coding with both frameworks, though.\n   \n3. Run the following:\n\n   .. code-block:: bash\n   \n      cd docs\n      ./clean.sh\n      make html\n      cd build/html\n      python3 -m http.server 8080\n      \n4. Open ``http://localhost:8080`` and enjoy your work.\n\nSee `here <https://github.com/dmlc/dgl/tree/master/docs>`__ for more details.\n\nContributing Editorial Changes via GitHub Web Interface\n```````````````````````````````````````````````````````\n\nIf one is only changing the wording (i.e. not touching the runnable code at all), one can simply do\nwithout the usage of Git CLI:\n\n1. Make your fork by clicking on the **Fork** button in the DGL main repository web page.\n2. Make whatever changes in the web interface *within your own fork*.  You can usually tell\n   if you are inside your own fork or in the main repository by checking whether you can commit\n   to the ``master`` branch: if you cannot, you are in the wrong place.\n3. Once done, make a pull request (on the web interface).\n4. The committers will review it, suggesting or making changes as necessary.\n5. Resolve the suggestions and reviews, and go back to step 4 until approved.\n6. Merge it and enjoy your day.\n\nContributing Code Changes\n`````````````````````````\n\nWhen changing code, please make sure to build it locally and see if it fails.\n"
  },
  {
    "path": "docs/source/developer/ffi.rst",
    "content": ".. currentmodule:: dgl\n\nDGL Foreign Function Interface (FFI)\n====================================\n\nWe all like Python because it is easy to manipulate. We all like C because it\nis fast, reliable and typed. To have the merits of both ends, DGL is mostly in\npython, for quick prototyping, while lowers the performance-critical part to C.\nThus, DGL developers frequently face the scenario to write a C routine and has\nit exposed to python, via a mechanism called *Foreign Function Interface (FFI)*.\n\nThere are many FFI solutions out there. In DGL, we want to keep it simple,\nintuitive and efficient for critical use cases. That's why when we came across the\nFFI solution in the TVM project, we immediately fell for it. It exploits the idea of\nfunctional programming so that it exposes only a dozens of C APIs and new APIs\ncan be built upon it.\n\nWe decided to borrow the idea (shamelessly). For example, to define a C\nAPI that is exposed to python is only a few lines of codes:\n\n.. code:: c++\n\n   // file: calculator.cc (put it in dgl/src folder)\n   #include <dgl/runtime/packed_func.h>\n   #include <dgl/runtime/registry.h>\n\n   using namespace dgl::runtime;\n\n   DGL_REGISTER_GLOBAL(\"calculator.MyAdd\")\n   .set_body([] (DGLArgs args, DGLRetValue* rv) {\n       int a = args[0];\n       int b = args[1];\n       *rv = a + b;\n     });\n\nCompile and build the library. On the python side, create a\n``calculator.py`` file under ``dgl/python/dgl/``\n\n.. code:: python\n\n   # file: calculator.py\n   from ._ffi.function import _init_api\n\n   def add(a, b):\n     # MyAdd has been registered via `_ini_api` call below\n     return MyAdd(a, b)\n\n   _init_api(\"dgl.calculator\")\n\nThe trick is that the FFI system first masks the type information of the\nfunction arguments, so all the C function calls can go through one C API\n(``DGLFuncCall``). 
The type information is retrieved in the function body by\nstatic conversion, and we will do runtime type check to make sure that the type\nconversion is correct. The overhead of such back-and-forth is negligible as\nlong as the function call is not too light (the above example is actually a bad\none). TVM's `PackedFunc\ndocument <https://docs.tvm.ai/dev/runtime.html#packedfunc>`_ has more details.\n\nDefining new types\n------------------\n\n``DGLArgs`` and ``DGLRetValue`` only support a limited number of types:\n\n* Numerical values: int, float, double, ...\n* string\n* Function (in the form of PackedFunc)\n* NDArray\n\nThough limited, the above type system is very powerful because it supports\nfunction as a first-class citizen. For example, if you want to return multiple\nvalues, you can return a PackedFunc which returns each value given an integer\nindex. However, in many cases, new types are still desired to ease the\ndevelopment process:\n\n* The argument/return value is a composition of collections (e.g. dictionary of\n  dictionary of list).\n* Sometimes we just want to have a notion of \"structure\" (e.g. given an apple,\n  get its color by ``apple.color``).\n\nTo achieve this, we introduce the Object type system. For example, to define a\nnew type ``Calculator``:\n\n.. 
code:: c++\n\n   // file: calculator.cc\n   #include <dgl/packed_func_ext.h>\n   using namespace runtime;\n   class CalculatorObject : public Object {\n    public:\n     std::string brand;\n     int price;\n     \n     void VisitAttrs(AttrVisitor *v) final {\n       v->Visit(\"brand\", &brand);\n       v->Visit(\"price\", &price);\n     }\n\n     static constexpr const char* _type_key = \"Calculator\";\n     DGL_DECLARE_OBJECT_TYPE_INFO(CalculatorObject, Object);\n   };\n\n   // This is to define a reference class (the wrapper of an object shared pointer).\n   // A minimal implementation is as follows, but you could define extra methods.\n   class Calculator : public ObjectRef {\n    public:\n     const CalculatorObject* operator->() const {\n       return static_cast<const CalculatorObject*>(obj_.get());\n     }\n     using ContainerType = CalculatorObject;\n   };\n\n   DGL_REGISTER_GLOBAL(\"calculator.CreateCalculator\")\n   .set_body([] (DGLArgs args, DGLRetValue* rv) {\n     std::string brand = args[0];\n     int price = args[1];\n     auto o = std::make_shared<CalculatorObject>();\n     o->brand = brand;\n     o->price = price;\n     *rv = o;\n   });\n\nOn the python side:\n\n.. code:: python\n\n   # file: calculator.py\n   from dgl._ffi.object import register_object, ObjectBase\n   from ._ffi.function import _init_api\n\n   @register_object\n   class Calculator(ObjectBase):\n     @staticmethod\n     def create(brand, price):\n       # invoke a C API, the return value is of `Calculator` type\n       return CreateCalculator(brand, price)\n\n   _init_api(\"dgl.calculator\")\n\nWe can then simply create a ``Calculator`` object by:\n\n.. code:: python\n\n   calc = Calculator.create(\"casio\", 100)\n\nWhat is nice about this object is that it defines a visitor pattern that is\nessentially a reflection mechanism to get its internal attributes. For example,\nyou can print the calculator's brand and price by simply accessing its attributes.\n\n.. 
code:: python\n\n   print(calc.brand)\n   print(calc.price)\n\nThe reflection is indeed a little bit slow due to the string key lookup. To\nspeed it up, you could define an attribute access API:\n\n.. code:: c++\n\n   // file: calculator.cc\n   DGL_REGISTER_GLOBAL(\"calculator.CalculatorGetBrand\")\n   .set_body([] (DGLArgs args, DGLRetValue* rv) {\n     Calculator calc = args[0];\n     *rv = calc->brand;\n   });\n\nContainers\n----------\n\nContainers are also objects. For example, the C API below accepts a list of\nintegers and returns their sum:\n\n.. code:: c++\n\n   // in file: calculator.cc\n   #include <dgl/runtime/container.h>\n   using namespace runtime;\n   DGL_REGISTER_GLOBAL(\"calculator.Sum\")\n   .set_body([] (DGLArgs args, DGLRetValue* rv) {\n     // All the DGL supported values are represented as a ValueObject, which\n     //   contains a data field.\n     List<Value> values = args[0];\n     int sum = 0;\n     for (int i = 0; i < values.size(); ++i) {\n       sum += static_cast<int>(values[i]->data);\n     }\n     *rv = sum;\n   });\n\nInvoking this API is simple -- just pass a python list of integers. DGL FFI will\nautomatically convert python list/tuple/dictionary to the corresponding object\ntype.\n\n.. code:: python\n\n   # in file: calculator.py\n   from ._ffi.function import _init_api\n\n   Sum([0, 1, 2, 3, 4, 5])\n\n   _init_api(\"dgl.calculator\")\n\nThe elements in the containers can be any objects, which allows the containers\nto be composed. Below is an API that accepts a list of calculators and prints\nout their prices:\n\n.. 
code:: c++\n\n   // in file: calculator.cc\n   #include <iostream>\n   #include <dgl/runtime/container.h>\n   using namespace runtime;\n   DGL_REGISTER_GLOBAL(\"calculator.PrintCalculators\")\n   .set_body([] (DGLArgs args, DGLRetValue* rv) {\n     List<Calculator> calcs = args[0];\n     for (int i = 0; i < calcs.size(); ++i) {\n       std::cout << calcs[i]->price << std::endl;\n     }\n   });\n\nPlease note that containers are NOT meant for passing a large collection of\nitems from/to C APIs. It will be quite slow in these cases. It is recommended\nto benchmark first. As an alternative, use NDArray for a large collection of\nnumerical values and use ``dgl.batch`` to batch many ``DGLGraph`` objects into\na single ``DGLGraph``.\n"
  },
  {
    "path": "docs/source/env_var.rst",
    "content": "Environment Variables\n=====================\n\nGlobal Configurations\n---------------------\n* ``DGLDEFAULTDIR``:\n    * Values: String (default=``\"${HOME}/.dgl\"``)\n    * The directory to save the DGL configuration files.\n\n* ``DGL_LOG_DEBUG``:\n    * Values: Set to ``\"1\"`` to enable debug level logging for DGL\n    * Enable debug level logging for DGL\n\nBackend Options\n---------------\n* ``DGLBACKEND``:\n    * Values: String (default='pytorch')\n    * The backend deep learning framework for DGL.\n    * Choices:\n        * 'pytorch': use PyTorch as the backend implementation.        \n        * 'tensorflow': use Apache TensorFlow as the backend implementation.\n        * 'mxnet': use Apache MXNet as the backend implementation.\n\nData Repository\n---------------\n* ``DGL_REPO``:\n    * Values: String (default='https://data.dgl.ai/')\n    * The repository url to be used for DGL datasets and pre-trained models.\n    * Suggested values:\n        * 'https://data.dgl.ai/': DGL repo for Global Region.\n        * 'https://dgl-data.s3.cn-north-1.amazonaws.com.cn/': DGL repo for Mainland China\n* ``DGL_DOWNLOAD_DIR``:\n    * Values: String (default=``\"${HOME}/.dgl\"``)\n    * The local directory to cache the downloaded data.\n"
  },
  {
    "path": "docs/source/faq.rst",
    "content": "Frequently Asked Questions (FAQ)\n================================\n\nFor frequently asked questions, refer to `this post <https://discuss.dgl.ai/t/frequently-asked-questions-faq/1681>`__.\n"
  },
  {
    "path": "docs/source/features/dataset.rst",
    "content": "Dataset (Temporary)\n\n\n.. table:: \n\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |  Datset Name   |                          Usage                           |# of graphs|Avg. # of nodes|Avg. # of edges|                 Node field                 |Edge field |Temporal|\n    +================+==========================================================+===========+===============+===============+============================================+===========+========+\n    |BitcoinOTC      |BitcoinOTC()                                              |        136|        6005.00|       21209.98|                                            |h          |True    |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |Cora            |CitationGraphDataset('cora')                              |          1|        2708.00|       10556.00|train_mask, val_mask, test_mask, label, feat|           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |Citeseer        |CitationGraphDataset('citeseer')                          |          1|        3327.00|        9228.00|train_mask, val_mask, test_mask, label, feat|           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |PubMed          |CitationGraphDataset('pubmed')                            |          1|       19717.00|       88651.00|train_mask, val_mask, test_mask, label, feat|           |False   |\n    
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |QM7b            |QM7b()                                                    |       7211|          15.42|         244.95|                                            |h          |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |Reddit          |RedditDataset()                                           |          1|      232965.00|   114615892.00|train_mask, val_mask, test_mask, feat, label|           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |ENZYMES         |TUDataset('ENZYMES')                                      |        600|          32.63|         124.27|node_labels, node_attr                      |           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |DD              |TUDataset('DD')                                           |       1178|         284.32|        1431.32|node_labels                                 |           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |COLLAB          |TUDataset('COLLAB')                                       |       5000|          74.49|        9830.00|                                            |           |False   |\n    
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |MUTAG           |TUDataset('MUTAG')                                        |        188|          17.93|          39.59|node_labels                                 |edge_labels|False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |PROTEINS        |TUDataset('PROTEINS')                                     |       1113|          39.06|         145.63|node_labels, node_attr                      |           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |PPI             |PPIDataset('train')/PPIDataset('valid')/PPIDataset('test')|         20|        2245.30|       63563.70|feat                                        |           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |KarateClub      |KarateClub()                                              |          1|          34.00|         156.00|label                                       |           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |Amazon computer |AmazonCoBuy('computers')                                  |          1|       13752.00|      574418.00|feat, label                                 |           |False   |\n    
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |Amazon photo    |AmazonCoBuy('photo')                                      |          1|        7650.00|      287326.00|feat, label                                 |           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |Coauthor cs     |Coauthor('cs')                                            |          1|       18333.00|      327576.00|feat, label                                 |           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |Coauthor physics|Coauthor('physics')                                       |          1|       34493.00|      991848.00|feat, label                                 |           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |GDELT           |GDELT('train')/GDELT('valid')/GDELT('test')               |       2304|       23033.00|      811333.15|                                            |rel_type   |True    |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |ICEWS18         |ICEWS18('train')/ICEWS18('valid')/ICEWS18('test')         |        240|       23033.00|      192640.22|                                            |rel_type   |True    |\n    
+----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+\n    |CoraFull        |CoraFull()                                                |          1|       19793.00|      130622.00|feat, label                                 |           |False   |\n    +----------------+----------------------------------------------------------+-----------+---------------+---------------+--------------------------------------------+-----------+--------+"
  },
  {
    "path": "docs/source/gen_dataset_stat.py",
    "content": "import numpy as np\nimport pandas as pd\nfrom dgl import DGLGraph\n\n# from dgl.data.qm9 import QM9\nfrom dgl.data import CitationGraphDataset, PPIDataset, RedditDataset, TUDataset\nfrom dgl.data.bitcoinotc import BitcoinOTC\nfrom dgl.data.gdelt import GDELT\nfrom dgl.data.gindt import GINDataset\nfrom dgl.data.gnn_benchmark import AmazonCoBuy, Coauthor, CoraFull\nfrom dgl.data.icews18 import ICEWS18\nfrom dgl.data.karate import KarateClub\nfrom dgl.data.qm7b import QM7b\nfrom pytablewriter import MarkdownTableWriter, RstGridTableWriter\n\nds_list = {\n    \"BitcoinOTC\": \"BitcoinOTC()\",\n    \"Cora\": \"CitationGraphDataset('cora')\",\n    \"Citeseer\": \"CitationGraphDataset('citeseer')\",\n    \"PubMed\": \"CitationGraphDataset('pubmed')\",\n    \"QM7b\": \"QM7b()\",\n    \"Reddit\": \"RedditDataset()\",\n    \"ENZYMES\": \"TUDataset('ENZYMES')\",\n    \"DD\": \"TUDataset('DD')\",\n    \"COLLAB\": \"TUDataset('COLLAB')\",\n    \"MUTAG\": \"TUDataset('MUTAG')\",\n    \"PROTEINS\": \"TUDataset('PROTEINS')\",\n    \"PPI\": \"PPIDataset('train')/PPIDataset('valid')/PPIDataset('test')\",\n    # \"Cora Binary\": \"CitationGraphDataset('cora_binary')\",\n    \"KarateClub\": \"KarateClub()\",\n    \"Amazon computer\": \"AmazonCoBuy('computers')\",\n    \"Amazon photo\": \"AmazonCoBuy('photo')\",\n    \"Coauthor cs\": \"Coauthor('cs')\",\n    \"Coauthor physics\": \"Coauthor('physics')\",\n    \"GDELT\": \"GDELT('train')/GDELT('valid')/GDELT('test')\",\n    \"ICEWS18\": \"ICEWS18('train')/ICEWS18('valid')/ICEWS18('test')\",\n    \"CoraFull\": \"CoraFull()\",\n}\n\nwriter = RstGridTableWriter()\n# writer = MarkdownTableWriter()\n\nextract_graph = lambda g: g if isinstance(g, DGLGraph) else g[0]\nstat_list = []\nfor k, v in ds_list.items():\n    print(k, \" \", v)\n    ds = eval(v.split(\"/\")[0])\n    num_nodes = []\n    num_edges = []\n    for i in range(len(ds)):\n        g = extract_graph(ds[i])\n        num_nodes.append(g.num_nodes())\n        
num_edges.append(g.num_edges())\n\n    gg = extract_graph(ds[0])\n    dd = {\n        \"Datset Name\": k,\n        \"Usage\": v,\n        \"# of graphs\": len(ds),\n        \"Avg. # of nodes\": np.mean(num_nodes),\n        \"Avg. # of edges\": np.mean(num_edges),\n        \"Node field\": \", \".join(list(gg.ndata.keys())),\n        \"Edge field\": \", \".join(list(gg.edata.keys())),\n        # \"Graph field\": ', '.join(ds[0][0].gdata.keys()) if hasattr(ds[0][0], \"gdata\") else \"\",\n        \"Temporal\": hasattr(ds, \"is_temporal\"),\n    }\n    stat_list.append(dd)\n\nprint(dd.keys())\ndf = pd.DataFrame(stat_list)\ndf = df.reindex(columns=dd.keys())\nwriter.from_dataframe(df)\n\nwriter.write_table()\n"
  },
  {
    "path": "docs/source/graphtransformer/data.rst",
    "content": "Prepare Data\n============\n\nIn this section, we will prepare the data for the Graphormer model introduced before. We can use any dataset containing :class:`~dgl.DGLGraph` objects and standard PyTorch dataloader to feed the data to the model. The key is to define a collate function to group features of multiple graphs into batches. We show an example of the collate function as follows:\n\n\n.. code:: python\n\n    def collate(graphs):\n        # compute shortest path features, can be done in advance\n        for g in graphs:\n            spd, path = dgl.shortest_dist(g, root=None, return_paths=True)\n            g.ndata[\"spd\"] = spd\n            g.ndata[\"path\"] = path\n\n        num_graphs = len(graphs)\n        num_nodes = [g.num_nodes() for g in graphs]\n        max_num_nodes = max(num_nodes)\n\n        attn_mask = th.zeros(num_graphs, max_num_nodes, max_num_nodes)\n        node_feat = []\n        in_degree, out_degree = [], []\n        path_data = []\n        # Since shortest_dist returns -1 for unreachable node pairs and padded\n        # nodes are unreachable to others, distance relevant to padded nodes\n        # use -1 padding as well.\n        dist = -th.ones(\n            (num_graphs, max_num_nodes, max_num_nodes), dtype=th.long\n        )\n\n        for i in range(num_graphs):\n            # A binary mask where invalid positions are indicated by True.\n            # Avoid the case where all positions are invalid.\n            attn_mask[i, :, num_nodes[i] + 1 :] = 1\n\n            # +1 to distinguish padded non-existing nodes from real nodes\n            node_feat.append(graphs[i].ndata[\"feat\"] + 1)\n\n            # 0 for padding\n            in_degree.append(\n                th.clamp(graphs[i].in_degrees() + 1, min=0, max=512)\n            )\n            out_degree.append(\n                th.clamp(graphs[i].out_degrees() + 1, min=0, max=512)\n            )\n\n            # Path padding to make all paths to the same length 
\"max_len\".\n            path = graphs[i].ndata[\"path\"]\n            path_len = path.size(dim=2)\n            # shape of shortest_path: [n, n, max_len]\n            max_len = 5\n            if path_len >= max_len:\n                shortest_path = path[:, :, :max_len]\n            else:\n                p1d = (0, max_len - path_len)\n                # Use the same -1 padding as shortest_dist for\n                # invalid edge IDs.\n                shortest_path = th.nn.functional.pad(path, p1d, \"constant\", -1)\n            pad_num_nodes = max_num_nodes - num_nodes[i]\n            p3d = (0, 0, 0, pad_num_nodes, 0, pad_num_nodes)\n            shortest_path = th.nn.functional.pad(shortest_path, p3d, \"constant\", -1)\n            # +1 to distinguish padded non-existing edges from real edges\n            edata = graphs[i].edata[\"feat\"] + 1\n\n            # shortest_dist pads non-existing edges (at the end of shortest\n            # paths) with edge IDs -1, and th.zeros(1, edata.shape[1]) stands\n            # for all padded edge features.\n            edata = th.cat(\n                (edata, th.zeros(1, edata.shape[1]).to(edata.device)), dim=0\n            )\n            path_data.append(edata[shortest_path])\n\n            dist[i, : num_nodes[i], : num_nodes[i]] = graphs[i].ndata[\"spd\"]\n\n        # node feat padding\n        node_feat = th.nn.utils.rnn.pad_sequence(node_feat, batch_first=True)\n\n        # degree padding\n        in_degree = th.nn.utils.rnn.pad_sequence(in_degree, batch_first=True)\n        out_degree = th.nn.utils.rnn.pad_sequence(out_degree, batch_first=True)\n\n        return (\n            node_feat,\n            in_degree,\n            out_degree,\n            attn_mask,\n            th.stack(path_data),\n            dist,\n        )\n\nIn this example, we also omit details like the addition of a virtual node. 
For more details, please refer to the `Graphormer example <https://github.com/dmlc/dgl/tree/master/examples/core/Graphormer>`_.\n"
  },
  {
    "path": "docs/source/graphtransformer/index.rst",
    "content": "🆕 Tutorial: Graph Transformer\n==========\n\nThis tutorial introduces the **graph transformer** (:mod:`~dgl.nn.gt`) module,\nwhich is a set of utility modules for building and training graph transformer models.\n\n.. toctree::\n  :maxdepth: 2\n  :titlesonly:\n\n  model\n  data\n"
  },
  {
    "path": "docs/source/graphtransformer/model.rst",
    "content": "Build Model\n===========\n\n**GraphTransformer** is a graph neural network that uses multi-head self-attention (sparse or dense) to encode the graph structure and node features. It is a generalization of the `Transformer <https://arxiv.org/abs/1706.03762>`_ architecture to arbitrary graphs. \n\nIn this tutorial, we will show how to build a graph transformer model with DGL using the `Graphormer <https://arxiv.org/abs/2106.05234>`_ model as an example.\n\nGraphormer is a Transformer model designed for graph-structured data, which encodes the structural information of a graph into the standard Transformer. Specifically, Graphormer utilizes degree encoding to measure the importance of nodes, spatial and path Encoding to measure the relation between node pairs. The degree encoding and the node features serve as input to Graphormer, while the spatial and path encoding act as bias terms in the self-attention module.\n\nDegree Encoding\n-------------------\nThe degree encoder is a learnable embedding layer that encodes the degree of each node into a vector. It takes as input the batched input and output degrees of graph nodes, and outputs the degree embeddings of the nodes.\n\n.. code:: python\n\n    degree_encoder = dgl.nn.DegreeEncoder(\n        max_degree=8,  # the maximum degree to cut off\n        embedding_dim=512  # the dimension of the degree embedding\n    )\n\nPath Encoding\n-------------\nThe path encoder encodes the edge features on the shortest path between two nodes to get attention bias for the self-attention module. It takes as input the batched edge features in shape  and outputs the attention bias based on path encoding.\n\n.. 
code:: python\n\n    path_encoder = PathEncoder(\n        max_len=5,  # the maximum length of the shortest path\n        feat_dim=512,  # the dimension of the edge feature\n        num_heads=8,  # the number of attention heads\n    )\n\nSpatial Encoding\n----------------\nThe spatial encoder encodes the shortest distance between two nodes to get attention bias for the self-attention module. It takes as input the shortest distance between two nodes and outputs the attention bias based on spatial encoding.\n\n.. code:: python\n\n    spatial_encoder = SpatialEncoder(\n        max_dist=5,  # the maximum distance between two nodes\n        num_heads=8,  # the number of attention heads\n    )\n\n\nGraphormer Layer\n----------------\nThe Graphormer layer is like a Transformer encoder layer with the Multi-head Attention part replaced with :class:`~dgl.nn.BiasedMHA`. It takes in not only the input node features, but also the attention bias computed above, and outputs the updated node features.\n\nWe can stack multiple Graphormer layers as a list just like implementing a Transformer encoder in PyTorch.\n\n.. code:: python\n\n    layers = th.nn.ModuleList([\n        GraphormerLayer(\n            feat_size=512,  # the dimension of the input node features\n            hidden_size=1024,  # the dimension of the hidden layer\n            num_heads=8,  # the number of attention heads\n            dropout=0.1,  # the dropout rate\n            activation=th.nn.ReLU(),  # the activation function\n            norm_first=False,  # whether to put the normalization before attention and feedforward\n        )\n        for _ in range(6)\n    ])\n\nModel Forward\n-------------\nGrouping the modules above defines the primary components of the Graphormer model. We can then define the forward process as follows:\n\n.. 
code:: python\n\n    node_feat, in_degree, out_degree, attn_mask, path_data, dist = \\\n        next(iter(dataloader))  #  we will use the first batch as an example\n    num_graphs, max_num_nodes, _ = node_feat.shape\n    deg_emb = degree_encoder(th.stack((in_degree, out_degree)))\n\n    # node feature + degree encoding as input\n    node_feat = node_feat + deg_emb\n\n    # spatial encoding and path encoding serve as attention bias\n    path_encoding = path_encoder(dist, path_data)\n    spatial_encoding = spatial_encoder(dist)\n    attn_bias[:, 1:, 1:, :] = path_encoding + spatial_encoding\n\n    # graphormer layers\n    for layer in layers:\n        x = layer(\n            x,\n            attn_mask=attn_mask,\n            attn_bias=attn_bias,\n        )\n\nFor simplicity, we omit some details in the forward process. For the complete implementation, please refer to the `Graphormer example <https://github.com/dmlc/dgl/tree/master/examples/core/Graphormer>`_.\n\nYou can also explore other `utility modules <https://docs.dgl.ai/api/python/nn-pytorch.html#utility-modules-for-graph-transformer>`_ to customize your own graph transformer model. In the next section, we will show how to prepare the data for training.\n"
  },
  {
    "path": "docs/source/guide/data-dataset.rst",
    "content": ".. _guide-data-pipeline-dataset:\n\n4.1 DGLDataset class\n--------------------\n\n:ref:`(中文版) <guide_cn-data-pipeline-dataset>`\n\n:class:`~dgl.data.DGLDataset` is the base class for processing, loading and saving\ngraph datasets defined in :ref:`apidata`. It implements the basic pipeline\nfor processing graph data. The following flow chart shows how the\npipeline works.\n\nTo process a graph dataset located in a remote server or local disk, one can\ndefine a class, say ``MyDataset``, inheriting from :class:`dgl.data.DGLDataset`. The\ntemplate of ``MyDataset`` is as follows.\n\n.. figure:: https://data.dgl.ai/asset/image/userguide_data_flow.png\n    :align: center\n\n    Flow chart for graph data input pipeline defined in class DGLDataset.\n\n.. code:: \n\n    from dgl.data import DGLDataset\n    \n    class MyDataset(DGLDataset):\n        \"\"\" Template for customizing graph datasets in DGL.\n    \n        Parameters\n        ----------\n        url : str\n            URL to download the raw dataset\n        raw_dir : str\n            Specifying the directory that will store the \n            downloaded data or the directory that\n            already stores the input data.\n            Default: ~/.dgl/\n        save_dir : str\n            Directory to save the processed dataset.\n            Default: the value of `raw_dir`\n        force_reload : bool\n            Whether to reload the dataset. 
Default: False\n        verbose : bool\n            Whether to print out progress information\n        \"\"\"\n        def __init__(self, \n                     url=None, \n                     raw_dir=None, \n                     save_dir=None, \n                     force_reload=False, \n                     verbose=False):\n            super(MyDataset, self).__init__(name='dataset_name',\n                                            url=url,\n                                            raw_dir=raw_dir,\n                                            save_dir=save_dir,\n                                            force_reload=force_reload,\n                                            verbose=verbose)\n    \n        def download(self):\n            # download raw data to local disk\n            pass\n    \n        def process(self):\n            # process raw data to graphs, labels, splitting masks\n            pass\n        \n        def __getitem__(self, idx):\n            # get one example by index\n            pass\n    \n        def __len__(self):\n            # number of data examples\n            pass\n    \n        def save(self):\n            # save processed data to directory `self.save_path`\n            pass\n    \n        def load(self):\n            # load processed data from directory `self.save_path`\n            pass\n    \n        def has_cache(self):\n            # check whether there are processed data in `self.save_path`\n            pass\n\n\n:class:`~dgl.data.DGLDataset` class has abstract functions ``process()``,\n``__getitem__(idx)`` and ``__len__()`` that must be implemented in the\nsubclass. DGL also recommends implementing saving and loading as well,\nsince they can save significant time for processing large datasets, and\nthere are several APIs making it easy (see :ref:`guide-data-pipeline-savenload`).\n\nNote that the purpose of :class:`~dgl.data.DGLDataset` is to provide a standard and\nconvenient way to load graph data. 
One can store graphs, features,\nlabels, masks and basic information about the dataset, such as number of\nclasses, number of labels, etc. Operations such as sampling, partition\nor feature normalization are done outside of the :class:`~dgl.data.DGLDataset`\nsubclass.\n\nThe rest of this chapter shows the best practices to implement the\nfunctions in the pipeline.\n"
  },
  {
    "path": "docs/source/guide/data-download.rst",
    "content": ".. _guide-data-pipeline-download:\n\n4.2 Download raw data (optional)\n--------------------------------\n\n:ref:`(中文版) <guide_cn-data-pipeline-download>`\n\nIf a dataset is already in local disk, make sure it’s in directory\n``raw_dir``. If one wants to run the code anywhere without bothering to\ndownload and move data to the right directory, one can do it\nautomatically by implementing function ``download()``.\n\nIf the dataset is a zip file, make ``MyDataset`` inherit from\n:class:`dgl.data.DGLBuiltinDataset` class, which handles the zip file extraction for us. Otherwise,\none needs to implement ``download()`` like in :class:`~dgl.data.QM7bDataset`:\n\n.. code:: \n\n    import os\n    from dgl.data.utils import download\n    \n    def download(self):\n        # path to store the file\n        file_path = os.path.join(self.raw_dir, self.name + '.mat')\n        # download file\n        download(self.url, path=file_path)\n\nThe above code downloads a .mat file to directory ``self.raw_dir``. If\nthe file is a .gz, .tar, .tar.gz or .tgz file, use :func:`~dgl.data.utils.extract_archive`\nfunction to extract. The following code shows how to download a .gz file\nin :class:`~dgl.data.BitcoinOTCDataset`:\n\n.. code:: \n\n    from dgl.data.utils import download, check_sha1\n    \n    def download(self):\n        # path to store the file\n        # make sure to use the same suffix as the original file name's\n        gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')\n        # download file\n        download(self.url, path=gz_file_path)\n        # check SHA-1\n        if not check_sha1(gz_file_path, self._sha1_str):\n            raise UserWarning('File {} is downloaded but the content hash does not match.'\n                              'The repo may be outdated or download may be incomplete. 
'\n                              'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))\n        # extract file to directory `self.name` under `self.raw_dir`\n        self._extract_gz(gz_file_path, self.raw_path)\n\nThe above code will extract the file into directory ``self.name`` under\n``self.raw_dir``. If the class inherits from :class:`dgl.data.DGLBuiltinDataset`\nto handle zip file, it will extract the file into directory ``self.name`` \nas well.\n\nOptionally, one can check SHA-1 string of the downloaded file as the\nexample above does, in case the author changed the file in the remote\nserver some day.\n"
  },
  {
    "path": "docs/source/guide/data-loadcsv.rst",
    "content": ".. _guide-data-pipeline-loadcsv:\n\n4.6 Loading data from CSV files\n----------------------------------------------\n\nComma Separated Value (CSV) is a widely used data storage format. DGL provides\n:class:`~dgl.data.CSVDataset` for loading and parsing graph data stored in\nCSV format.\n\nTo create a ``CSVDataset`` object:\n\n.. code:: python\n\n    import dgl\n    ds = dgl.data.CSVDataset('/path/to/dataset')\n\nThe returned ``ds`` object is a standard :class:`~dgl.data.DGLDataset`. For\nexample, one can get graph samples using ``__getitem__`` as well as node/edge\nfeatures using ``ndata``/``edata``.\n\n.. code:: python\n\n    # A demonstration of how to use the loaded dataset. The feature names\n    # may vary depending on the CSV contents.\n    g = ds[0] # get the graph\n    label = g.ndata['label']\n    feat = g.ndata['feat']\n\nData folder structure\n~~~~~~~~~~~~~~~~~~~~~\n\n.. code::\n\n    /path/to/dataset/\n    |-- meta.yaml     # metadata of the dataset\n    |-- edges_0.csv   # edge data including src_id, dst_id, feature, label and so on\n    |-- ...           # you can have as many CSVs for edge data as you want\n    |-- nodes_0.csv   # node data including node_id, feature, label and so on\n    |-- ...           # you can have as many CSVs for node data as you want\n    |-- graphs.csv    # graph-level features\n\nNode/edge/graph-level data are stored in CSV files. ``meta.yaml`` is a metadata file specifying\nwhere to read nodes/edges/graphs data and how to parse them to construct the dataset\nobject. 
A minimal data folder contains one ``meta.yaml`` and two CSVs, one for node data and one\nfor edge data, in which case the dataset contains only a single graph with no graph-level data.\n\nDataset of a single feature-less graph\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nWhen the dataset contains only one graph with no node or edge features, there need only three\nfiles in the data folder: ``meta.yaml``, one CSV for node IDs and one CSV for edges:\n\n.. code::\n\n    ./mini_featureless_dataset/\n    |-- meta.yaml\n    |-- nodes.csv\n    |-- edges.csv\n\n``meta.yaml`` contains the following information:\n\n.. code:: yaml\n\n    dataset_name: mini_featureless_dataset\n    edge_data:\n    - file_name: edges.csv\n    node_data:\n    - file_name: nodes.csv\n\n``nodes.csv`` lists the node IDs under the ``node_id`` field:\n\n.. code::\n\n    node_id\n    0\n    1\n    2\n    3\n    4\n\n``edges.csv`` lists all the edges in two columns (``src_id`` and ``dst_id``) specifying the\nsource and destination node ID of each edge:\n\n.. code::\n\n    src_id,dst_id\n    4,4\n    4,1\n    3,0\n    4,1\n    4,0\n    1,2\n    1,3\n    3,3\n    1,1\n    4,1\n\nAfter loaded, the dataset has one graph without any features:\n\n.. code:: python\n\n    >>> import dgl\n    >>> dataset = dgl.data.CSVDataset('./mini_featureless_dataset')\n    >>> g = dataset[0]  # only one graph\n    >>> print(g)\n    Graph(num_nodes=5, num_edges=10,\n          ndata_schemes={}\n          edata_schemes={})\n\n.. note::\n    Non-integer node IDs are allowed. When constructing the graph, ``CSVDataset`` will\n    map each raw ID to an integer ID starting from zero.\n    If the node IDs are already distinct integers from 0 to ``num_nodes-1``, no mapping\n    is applied.\n\n.. note::\n    Edges are always directed. To have both directions, add reversed edges in the edge\n    CSV file or use :class:`~dgl.transforms.AddReverse` to transform the loaded graph.\n\n\nA graph without any feature is often of less interest. 
In the next example, we will show\nhow to load and parse node or edge features.\n\nDataset of a single graph with features and labels\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nWhen the dataset contains a single graph with node or edge features and labels, there still\nneed only three files in the data folder: ``meta.yaml``, one CSV for node IDs and one CSV\nfor edges:\n\n.. code::\n\n    ./mini_feature_dataset/\n    |-- meta.yaml\n    |-- nodes.csv\n    |-- edges.csv\n\n``meta.yaml``:\n\n.. code:: yaml\n\n    dataset_name: mini_feature_dataset\n    edge_data:\n    - file_name: edges.csv\n    node_data:\n    - file_name: nodes.csv\n\n``edges.csv`` with five synthetic edge data (``label``, ``train_mask``, ``val_mask``, ``test_mask``, ``feat``):\n\n.. code::\n\n    src_id,dst_id,label,train_mask,val_mask,test_mask,feat\n    4,0,2,False,True,True,\"0.5477868606453535, 0.4470617033458436, 0.936706701616337\"\n    4,0,0,False,False,True,\"0.9794634290792008, 0.23682038840665198, 0.049629338970987646\"\n    0,3,1,True,True,True,\"0.8586722047523594, 0.5746912787380253, 0.6462162561249654\"\n    0,1,2,True,False,False,\"0.2730008213674695, 0.5937484188166621, 0.765544096939567\"\n    0,2,1,True,True,True,\"0.45441619816038514, 0.1681403185591509, 0.9952376085297715\"\n    0,0,0,False,False,False,\"0.4197669213305396, 0.849983324532477, 0.16974127573016262\"\n    2,2,1,False,True,True,\"0.5495035052928215, 0.21394654203489705, 0.7174910641836348\"\n    1,0,2,False,True,False,\"0.008790817766266334, 0.4216530595907526, 0.529195480661293\"\n    3,0,0,True,True,True,\"0.6598715708878852, 0.1932390907048961, 0.9774471538377553\"\n    4,0,1,False,False,False,\"0.16846068931179736, 0.41516080644186737, 0.002158116134429955\"\n\n\n``nodes.csv`` with five synthetic node data (``label``, ``train_mask``, ``val_mask``, ``test_mask``, ``feat``):\n\n.. 
code::\n\n    node_id,label,train_mask,val_mask,test_mask,feat\n    0,1,False,True,True,\"0.07816474278491703, 0.9137336384979067, 0.4654086994009452\"\n    1,1,True,True,True,\"0.05354099924658973, 0.8753101998792645, 0.33929432608774135\"\n    2,1,True,False,True,\"0.33234211884156384, 0.9370522452510665, 0.6694943496824788\"\n    3,0,False,True,False,\"0.9784264442230887, 0.22131880861864428, 0.3161154827254189\"\n    4,1,True,True,False,\"0.23142237259162102, 0.8715767748481147, 0.19117861103555467\"\n\nAfter loading, the dataset has one graph. Node/edge features are stored in ``ndata`` and ``edata``\nwith the same column names. The example demonstrates how to specify a vector-shaped feature\nusing a comma-separated list enclosed by double quotes ``\"...\"``.\n\n.. code:: python\n\n    >>> import dgl\n    >>> dataset = dgl.data.CSVDataset('./mini_feature_dataset')\n    >>> g = dataset[0]  # only one graph\n    >>> print(g)\n    Graph(num_nodes=5, num_edges=10,\n          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'feat': Scheme(shape=(3,), dtype=torch.float64)}\n          edata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'feat': Scheme(shape=(3,), dtype=torch.float64)})\n\n.. note::\n    By default, ``CSVDataset`` assumes all feature data to be numerical values (e.g., int, float, bool or\n    list) and missing values are not allowed. 
Users could provide custom data parser for these cases.\n    See `Custom Data Parser`_ for more details.\n\nDataset of a single heterogeneous graph\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nOne can specify multiple node and edge CSV files (each for one type) to represent a heterogeneous graph.\nHere is an example data with two node types and two edge types:\n\n.. code::\n\n    ./mini_hetero_dataset/\n    |-- meta.yaml\n    |-- nodes_0.csv\n    |-- nodes_1.csv\n    |-- edges_0.csv\n    |-- edges_1.csv\n\nThe ``meta.yaml`` specifies the node type name (using ``ntype``) and edge type name (using ``etype``)\nof each CSV file. The edge type name is a string triplet containing the source node type name, relation\nname and the destination node type name.\n\n.. code:: yaml\n\n    dataset_name: mini_hetero_dataset\n    edge_data:\n    - file_name: edges_0.csv\n      etype: [user, follow, user]\n    - file_name: edges_1.csv\n      etype: [user, like, item]\n    node_data:\n    - file_name: nodes_0.csv\n      ntype: user\n    - file_name: nodes_1.csv\n      ntype: item\n\nThe node and edge CSV files follow the same format as in homogeneous graphs. Here are some synthetic\ndata for demonstration purposes:\n\n``edges_0.csv`` and ``edges_1.csv``:\n\n.. 
code::\n\n    src_id,dst_id,label,feat\n    4,4,1,\"0.736833152378035,0.10522806046048205,0.9418796835016118\"\n    3,4,2,\"0.5749339182767451,0.20181320245665535,0.490938012147181\"\n    1,4,2,\"0.7697294432580938,0.49397782380750765,0.10864079337442234\"\n    0,4,0,\"0.1364240150959487,0.1393107840629273,0.7901988878812207\"\n    2,3,1,\"0.42988138237505735,0.18389137408509248,0.18431292077750894\"\n    0,4,2,\"0.8613368738351794,0.67985810014162,0.6580438064356824\"\n    2,4,1,\"0.6594951663841697,0.26499036865016423,0.7891429392727503\"\n    4,1,0,\"0.36649684241348557,0.9511783938523962,0.8494919263589972\"\n    1,1,2,\"0.698592283371875,0.038622249776255946,0.5563827995742111\"\n    0,4,1,\"0.5227112950269823,0.3148264185956532,0.47562693094002173\"\n\n``nodes_0.csv`` and ``nodes_1.csv``:\n\n.. code::\n\n    node_id,label,feat\n    0,2,\"0.5400687466285844,0.7588441197954202,0.4268254673041745\"\n    1,1,\"0.08680051341900807,0.11446843700743892,0.7196969604886617\"\n    2,2,\"0.8964389655603473,0.23368113896545695,0.8813472954005022\"\n    3,1,\"0.5454703921677284,0.7819383771535038,0.3027939452162367\"\n    4,1,\"0.5365210052235699,0.8975240205792763,0.7613943085507672\"\n\nAfter loaded, the dataset has one heterograph with features and labels:\n\n.. 
code:: python\n\n    >>> import dgl\n    >>> dataset = dgl.data.CSVDataset('./mini_hetero_dataset')\n    >>> g = dataset[0]  # only one graph\n    >>> print(g)\n    Graph(num_nodes={'item': 5, 'user': 5},\n          num_edges={('user', 'follow', 'user'): 10, ('user', 'like', 'item'): 10},\n          metagraph=[('user', 'user', 'follow'), ('user', 'item', 'like')])\n    >>> g.nodes['user'].data\n    {'label': tensor([2, 1, 2, 1, 1]), 'feat': tensor([[0.5401, 0.7588, 0.4268],\n            [0.0868, 0.1145, 0.7197],\n            [0.8964, 0.2337, 0.8813],\n            [0.5455, 0.7819, 0.3028],\n            [0.5365, 0.8975, 0.7614]], dtype=torch.float64)}\n    >>> g.edges['like'].data\n    {'label': tensor([1, 2, 2, 0, 1, 2, 1, 0, 2, 1]), 'feat': tensor([[0.7368, 0.1052, 0.9419],\n            [0.5749, 0.2018, 0.4909],\n            [0.7697, 0.4940, 0.1086],\n            [0.1364, 0.1393, 0.7902],\n            [0.4299, 0.1839, 0.1843],\n            [0.8613, 0.6799, 0.6580],\n            [0.6595, 0.2650, 0.7891],\n            [0.3665, 0.9512, 0.8495],\n            [0.6986, 0.0386, 0.5564],\n            [0.5227, 0.3148, 0.4756]], dtype=torch.float64)}\n\nDataset of multiple graphs\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nWhen there are multiple graphs, one can include an additional CSV file for storing graph-level features.\nHere is an example:\n\n.. code::\n\n    ./mini_multi_dataset/\n    |-- meta.yaml\n    |-- nodes.csv\n    |-- edges.csv\n    |-- graphs.csv\n\nAccordingly, the ``meta.yaml`` should include an extra ``graph_data`` key to tell which CSV file to\nload graph-level features from.\n\n.. code:: yaml\n\n    dataset_name: mini_multi_dataset\n    edge_data:\n    - file_name: edges.csv\n    node_data:\n    - file_name: nodes.csv\n    graph_data:\n      file_name: graphs.csv\n\nTo distinguish nodes and edges of different graphs, the ``nodes.csv`` and ``edges.csv`` must contain\nan extra column ``graph_id``:\n\n``edges.csv``:\n\n.. 
code::\n\n    graph_id,src_id,dst_id,feat\n    0,0,4,\"0.39534097273254654,0.9422093637539785,0.634899790318452\"\n    0,3,0,\"0.04486384200747007,0.6453746567017163,0.8757520744192612\"\n    0,3,2,\"0.9397636966928355,0.6526403892728874,0.8643238446466464\"\n    0,1,1,\"0.40559906615287566,0.9848072295736628,0.493888090726854\"\n    0,4,1,\"0.253458867276219,0.9168191778828504,0.47224962583565544\"\n    0,0,1,\"0.3219496197945605,0.3439899477636117,0.7051530741717352\"\n    0,2,1,\"0.692873149428549,0.4770019763881086,0.21937428942781778\"\n    0,4,0,\"0.620118223673067,0.08691420300562658,0.86573472329756\"\n    0,2,1,\"0.00743445923710373,0.5251800239734318,0.054016385555202384\"\n    0,4,1,\"0.6776417760682221,0.7291568018841328,0.4523600060547709\"\n    1,1,3,\"0.6375445528248924,0.04878384701995819,0.4081642382536248\"\n    1,0,4,\"0.776002616178397,0.8851294998284638,0.7321742043493028\"\n    1,1,0,\"0.0928555079874982,0.6156748364694707,0.6985674921582508\"\n    1,0,2,\"0.31328748118329997,0.8326121496142408,0.04133991340612775\"\n    1,1,0,\"0.36786902637778773,0.39161865931662243,0.9971749359397111\"\n    1,1,1,\"0.4647410679872376,0.8478810655406659,0.6746269314422184\"\n    1,0,2,\"0.8117650553546695,0.7893727601272978,0.41527155506593394\"\n    1,1,3,\"0.40707309111756307,0.2796588354307046,0.34846782265758314\"\n    1,1,0,\"0.18626464175355095,0.3523777809254057,0.7863421810531344\"\n    1,3,0,\"0.28357022069634585,0.13774964202156292,0.5913335505943637\"\n\n``nodes.csv``:\n\n.. 
code::\n\n    graph_id,node_id,feat\n    0,0,\"0.5725330322207948,0.8451870383322376,0.44412796119211184\"\n    0,1,\"0.6624186423087752,0.6118386331195641,0.7352138669985214\"\n    0,2,\"0.7583372765843964,0.15218126307872892,0.6810484348765842\"\n    0,3,\"0.14627522432017592,0.7457985352827006,0.1037097085190507\"\n    0,4,\"0.49037522512771525,0.8778998699783784,0.0911194482288028\"\n    1,0,\"0.11158102039672668,0.08543289788089736,0.6901745368284345\"\n    1,1,\"0.28367647637469273,0.07502571020414439,0.01217200152200748\"\n    1,2,\"0.2472495901894738,0.24285506608575758,0.6494437360242048\"\n    1,3,\"0.5614197853127827,0.059172654879085296,0.4692371689047904\"\n    1,4,\"0.17583413999295983,0.5191278830882644,0.8453123358491914\"\n\nThe ``graphs.csv`` contains a ``graph_id`` column and an arbitrary number of feature columns.\nThe example dataset here has two graphs, each with graph-level ``feat`` and ``label``\ndata.\n\n.. code::\n\n    graph_id,feat,label\n    0,\"0.7426272601929126,0.5197462471155317,0.8149104951283953\",0\n    1,\"0.534822233529295,0.2863627767733977,0.1154897249106891\",0\n\nAfter loading, the dataset has multiple homogeneous graphs with features and labels:\n\n.. 
code:: python\n\n    >>> import dgl\n    >>> dataset = dgl.data.CSVDataset('./mini_multi_dataset')\n    >>> print(len(dataset))\n    2\n    >>> graph0, data0 = dataset[0]\n    >>> print(graph0)\n    Graph(num_nodes=5, num_edges=10,\n          ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float64)}\n          edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float64)})\n    >>> print(data0)\n    {'feat': tensor([0.7426, 0.5197, 0.8149], dtype=torch.float64), 'label': tensor(0)}\n    >>> graph1, data1 = dataset[1]\n    >>> print(graph1)\n    Graph(num_nodes=5, num_edges=10,\n          ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float64)}\n          edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float64)})\n    >>> print(data1)\n    {'feat': tensor([0.5348, 0.2864, 0.1155], dtype=torch.float64), 'label': tensor(0)}\n\nIf there is a single feature column in ``graphs.csv``, ``data0`` will directly be a tensor for the feature.\n\n\nCustom Data Parser\n~~~~~~~~~~~~~~~~~~\n\nBy default, ``CSVDataset`` assumes that all the stored node-/edge-/graph- level data are numerical\nvalues. Users can provide custom ``DataParser`` to ``CSVDataset`` to handle more complex\ndata type. A ``DataParser`` needs to implement the ``__call__`` method which takes in the\n:class:`pandas.DataFrame` object created from CSV file and should return a dictionary of\nparsed feature data. The parsed feature data will be saved to the ``ndata`` and ``edata`` of\nthe corresponding ``DGLGraph`` object, and thus must be tensors or numpy arrays. Below shows an example\n``DataParser`` which converts string type labels to integers:\n\nGiven a dataset as follows,\n\n.. code::\n\n    ./customized_parser_dataset/\n    |-- meta.yaml\n    |-- nodes.csv\n    |-- edges.csv\n\n``meta.yaml``:\n\n.. code:: yaml\n\n    dataset_name: customized_parser_dataset\n    edge_data:\n    - file_name: edges.csv\n    node_data:\n    - file_name: nodes.csv\n\n``edges.csv``:\n\n.. 
code::\n\n    src_id,dst_id,label\n    4,0,positive\n    4,0,negative\n    0,3,positive\n    0,1,positive\n    0,2,negative\n    0,0,positive\n    2,2,negative\n    1,0,positive\n    3,0,negative\n    4,0,positive\n\n``nodes.csv``:\n\n.. code::\n\n    node_id,label\n    0,positive\n    1,negative\n    2,positive\n    3,negative\n    4,positive\n\nTo parse the string type labels, one can define a ``DataParser`` class as follows:\n\n.. code:: python\n\n    import numpy as np\n    import pandas as pd\n\n    class MyDataParser:\n        def __call__(self, df: pd.DataFrame):\n            parsed = {}\n            for header in df:\n                if 'Unnamed' in header:  # Handle Unnamed column\n                    print(\"Unnamed column is found. Ignored...\")\n                    continue\n                dt = df[header].to_numpy().squeeze()\n                if header == 'label':\n                    dt = np.array([1 if e == 'positive' else 0 for e in dt])\n                parsed[header] = dt\n            return parsed\n\nCreate a ``CSVDataset`` using the defined ``DataParser``:\n\n.. code:: python\n\n    >>> import dgl\n    >>> dataset = dgl.data.CSVDataset('./customized_parser_dataset',\n    ...                               ndata_parser=MyDataParser(),\n    ...                               edata_parser=MyDataParser())\n    >>> print(dataset[0].ndata['label'])\n    tensor([1, 0, 1, 0, 1])\n    >>> print(dataset[0].edata['label'])\n    tensor([1, 0, 1, 1, 0, 1, 0, 1, 0, 1])\n\n.. note::\n\n    To specify different ``DataParser``\\s for different node/edge types, pass a dictionary to\n    ``ndata_parser`` and ``edata_parser``, where the key is type name (a single string for\n    node type; a string triplet for edge type) and the value is the ``DataParser`` to use.\n\n\nFull YAML Specification\n~~~~~~~~~~~~~~~~~~~~~~~\n\n``CSVDataset`` allows more flexible control over the loading and parsing process. For example, one\ncan change the ID column names via ``meta.yaml``. 
The example below lists all the supported keys.\n\n.. code:: yaml\n\n    version: 1.0.0\n    dataset_name: some_complex_data\n    separator: ','                   # CSV separator symbol. Default: ','\n    edge_data:\n    - file_name: edges_0.csv\n      etype: [user, follow, user]\n      src_id_field: src_id           # Column name for source node IDs. Default: src_id\n      dst_id_field: dst_id           # Column name for destination node IDs. Default: dst_id\n    - file_name: edges_1.csv\n      etype: [user, like, item]\n      src_id_field: src_id\n      dst_id_field: dst_id\n    node_data:\n    - file_name: nodes_0.csv\n      ntype: user\n      node_id_field: node_id         # Column name for node IDs. Default: node_id\n    - file_name: nodes_1.csv\n      ntype: item\n      node_id_field: node_id         # Column name for node IDs. Default: node_id\n    graph_data:\n      file_name: graphs.csv\n      graph_id_field: graph_id       # Column name for graph IDs. Default: graph_id\n\nTop-level\n^^^^^^^^^^^^^^\n\nAt the top level, only 6 keys are available:\n\n  - ``version``: Optional. String.\n    It specifies which version of ``meta.yaml`` is used. More features may be added in the future.\n  - ``dataset_name``: Required. String.\n    It specifies the dataset name.\n  - ``separator``: Optional. String.\n    It specifies how to parse data in CSV files. Default: ``','``.\n  - ``edge_data``: Required. List of ``EdgeData``.\n    Meta data for parsing edge CSV files.\n  - ``node_data``: Required. List of ``NodeData``.\n    Meta data for parsing node CSV files.\n  - ``graph_data``: Optional. ``GraphData``.\n    Meta data for parsing the graph CSV file.\n\n``EdgeData``\n^^^^^^^^^^^^^^^^^^^^^^\n\nThere are 4 keys:\n\n  - ``file_name``: Required. String.\n    The CSV file to load data from.\n  - ``etype``: Optional. List of strings.\n    Edge type name in string triplet: [source node type, relation type, destination node type].\n  - ``src_id_field``: Optional. 
String.\n    Which column to read for source node IDs. Default: ``src_id``.\n  - ``dst_id_field``: Optional. String.\n    Which column to read for destination node IDs. Default: ``dst_id``.\n\n``NodeData``\n^^^^^^^^^^^^^^^^^^^^^^\n\nThere are 3 keys:\n\n  - ``file_name``: Required. String.\n    The CSV file to load data from.\n  - ``ntype``: Optional. String.\n    Node type name.\n  - ``node_id_field``: Optional. String.\n    Which column to read for node IDs. Default: ``node_id``.\n\n``GraphData``\n^^^^^^^^^^^^^^^^^^^^^^\n\nThere are 2 keys:\n\n  - ``file_name``: Required. String.\n    The CSV file to load data from.\n  - ``graph_id_field``: Optional. String.\n    Which column to read for graph IDs. Default: ``graph_id``."
  },
  {
    "path": "docs/source/guide/data-loadogb.rst",
    "content": ".. _guide-data-pipeline-loadogb:\n\n4.5 Loading OGB datasets using ``ogb`` package\n----------------------------------------------\n\n:ref:`(中文版) <guide_cn-data-pipeline-loadogb>`\n\n`Open Graph Benchmark (OGB) <https://ogb.stanford.edu/docs/home/>`__ is\na collection of benchmark datasets. The official OGB package\n`ogb <https://github.com/snap-stanford/ogb>`__ provides APIs for\ndownloading and processing OGB datasets into :class:`dgl.DGLGraph` objects. This section\nintroduces their basic usage.\n\nFirst install the ``ogb`` package using pip:\n\n.. code:: \n\n    pip install ogb\n\nThe following code shows how to load datasets for *Graph Property\nPrediction* tasks.\n\n.. code:: \n\n    # Load Graph Property Prediction datasets in OGB\n    import dgl\n    import torch\n    from ogb.graphproppred import DglGraphPropPredDataset\n    from dgl.dataloading import GraphDataLoader\n    \n    \n    def _collate_fn(batch):\n        # batch is a list of tuple (graph, label)\n        graphs = [e[0] for e in batch]\n        g = dgl.batch(graphs)\n        labels = [e[1] for e in batch]\n        labels = torch.stack(labels, 0)\n        return g, labels\n    \n    # load dataset\n    dataset = DglGraphPropPredDataset(name='ogbg-molhiv')\n    split_idx = dataset.get_idx_split()\n    # dataloader\n    train_loader = GraphDataLoader(dataset[split_idx[\"train\"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)\n    valid_loader = GraphDataLoader(dataset[split_idx[\"valid\"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)\n    test_loader = GraphDataLoader(dataset[split_idx[\"test\"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)\n\nLoading *Node Property Prediction* datasets is similar, but note that\nthere is only one graph object in this kind of dataset.\n\n.. 
code:: \n\n    # Load Node Property Prediction datasets in OGB\n    from ogb.nodeproppred import DglNodePropPredDataset\n    \n    dataset = DglNodePropPredDataset(name='ogbn-proteins')\n    split_idx = dataset.get_idx_split()\n    \n    # there is only one graph in Node Property Prediction datasets\n    g, labels = dataset[0]\n    # get split labels\n    train_label = dataset.labels[split_idx['train']]\n    valid_label = dataset.labels[split_idx['valid']]\n    test_label = dataset.labels[split_idx['test']]\n\n*Link Property Prediction* datasets also contain one graph per dataset.\n\n.. code:: \n\n    # Load Link Property Prediction datasets in OGB\n    from ogb.linkproppred import DglLinkPropPredDataset\n    \n    dataset = DglLinkPropPredDataset(name='ogbl-ppa')\n    split_edge = dataset.get_edge_split()\n    \n    graph = dataset[0]\n    print(split_edge['train'].keys())\n    print(split_edge['valid'].keys())\n    print(split_edge['test'].keys())\n"
  },
  {
    "path": "docs/source/guide/data-process.rst",
    "content": ".. _guide-data-pipeline-process:\n\n4.3 Process data\n----------------\n\n:ref:`(中文版) <guide_cn-data-pipeline-process>`\n\nOne can implement the data processing code in function ``process()``, and it\nassumes that the raw data is located in ``self.raw_dir`` already. There\nare typically three types of tasks in machine learning on graphs: graph\nclassification, node classification, and link prediction. This section will show\nhow to process datasets related to these tasks.\n\nThe section focuses on the standard way to process graphs, features and masks.\nIt will use builtin datasets as examples and skip the implementations\nfor building graphs from files, but add links to the detailed\nimplementations. Please refer to :ref:`guide-graph-external` to see a\ncomplete guide on how to build graphs from external sources.\n\nProcessing Graph Classification datasets\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nGraph classification datasets are almost the same as most datasets in\ntypical machine learning tasks, where mini-batch training is used. So one can\nprocess the raw data to a list of :class:`dgl.DGLGraph` objects and a list of\nlabel tensors. In addition, if the raw data has been split into\nseveral files, one can add a parameter ``split`` to load a specific part of\nthe data.\n\nTake :class:`~dgl.data.QM7bDataset` as an example:\n\n.. 
code::\n\n    from dgl.data import DGLDataset\n\n    class QM7bDataset(DGLDataset):\n        _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \\\n               'datasets/qm7b.mat'\n        _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'\n\n        def __init__(self, raw_dir=None, force_reload=False, verbose=False):\n            super(QM7bDataset, self).__init__(name='qm7b',\n                                              url=self._url,\n                                              raw_dir=raw_dir,\n                                              force_reload=force_reload,\n                                              verbose=verbose)\n\n        def process(self):\n            mat_path = self.raw_path + '.mat'\n            # process data to a list of graphs and a list of labels\n            self.graphs, self.label = self._load_graph(mat_path)\n\n        def __getitem__(self, idx):\n            \"\"\" Get graph and label by index\n\n            Parameters\n            ----------\n            idx : int\n                Item index\n\n            Returns\n            -------\n            (dgl.DGLGraph, Tensor)\n            \"\"\"\n            return self.graphs[idx], self.label[idx]\n\n        def __len__(self):\n            \"\"\"Number of graphs in the dataset\"\"\"\n            return len(self.graphs)\n\n\nIn ``process()``, the raw data is processed to a list of graphs and a\nlist of labels. One must implement ``__getitem__(idx)`` and ``__len__()``\nfor iteration. DGL recommends making ``__getitem__(idx)`` return a\ntuple ``(graph, label)`` as above. Please check the `QM7bDataset source\ncode <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/qm7b.html#QM7bDataset>`__\nfor details of ``self._load_graph()`` and ``__getitem__``.\n\nOne can also add properties to the class to indicate some useful\ninformation of the dataset. 
In :class:`~dgl.data.QM7bDataset`, one can add a property\n``num_tasks`` to indicate the total number of prediction tasks in this\nmulti-task dataset:\n\n.. code::\n\n    @property\n    def num_tasks(self):\n        \"\"\"Number of labels for each graph, i.e. number of prediction tasks.\"\"\"\n        return 14\n\nAfter all this coding, one can finally use :class:`~dgl.data.QM7bDataset` as\nfollows:\n\n.. code::\n\n    import dgl\n    import torch\n\n    from dgl.dataloading import GraphDataLoader\n\n    # load data\n    dataset = QM7bDataset()\n    num_tasks = dataset.num_tasks\n\n    # create dataloaders\n    dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)\n\n    # training\n    for epoch in range(100):\n        for g, labels in dataloader:\n            # your training code here\n            pass\n\nA complete guide for training graph classification models can be found\nin :ref:`guide-training-graph-classification`.\n\nFor more examples of graph classification datasets, please refer to DGL's builtin graph classification\ndatasets:\n\n* :ref:`gindataset`\n\n* :ref:`minigcdataset`\n\n* :ref:`qm7bdata`\n\n* :ref:`tudata`\n\nProcessing Node Classification datasets\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nDifferent from graph classification, node classification is typically performed on\na single graph. As such, splits of the dataset are on the nodes of the\ngraph. DGL recommends using node masks to specify the splits. The section uses\nbuiltin dataset `CitationGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/citation_graph.html#CitationGraphDataset>`__ as an example.\n\nIn addition, DGL recommends re-arranging the nodes and edges so that nodes\nclose to each other have IDs in a close range. The procedure could improve\nthe locality to access a node's neighbors, which may benefit follow-up\ncomputation and analysis conducted on the graph. DGL provides an API called\n:func:`dgl.reorder_graph` for this purpose. 
Please refer to ``process()``\npart in below example for more details.\n\n.. code::\n\n    from dgl.data import DGLBuiltinDataset\n    from dgl.data.utils import _get_dgl_url\n\n    class CitationGraphDataset(DGLBuiltinDataset):\n        _urls = {\n            'cora_v2' : 'dataset/cora_v2.zip',\n            'citeseer' : 'dataset/citeseer.zip',\n            'pubmed' : 'dataset/pubmed.zip',\n        }\n\n        def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):\n            assert name.lower() in ['cora', 'citeseer', 'pubmed']\n            if name.lower() == 'cora':\n                name = 'cora_v2'\n            url = _get_dgl_url(self._urls[name])\n            super(CitationGraphDataset, self).__init__(name,\n                                                       url=url,\n                                                       raw_dir=raw_dir,\n                                                       force_reload=force_reload,\n                                                       verbose=verbose)\n\n        def process(self):\n            # Skip some processing code\n            # === data processing skipped ===\n\n            # build graph\n            g = dgl.graph(graph)\n            # splitting masks\n            g.ndata['train_mask'] = train_mask\n            g.ndata['val_mask'] = val_mask\n            g.ndata['test_mask'] = test_mask\n            # node labels\n            g.ndata['label'] = torch.tensor(labels)\n            # node features\n            g.ndata['feat'] = torch.tensor(_preprocess_features(features),\n                                           dtype=F.data_type_dict['float32'])\n            self._num_tasks = onehot_labels.shape[1]\n            self._labels = labels\n            # reorder graph to obtain better locality.\n            self._g = dgl.reorder_graph(g)\n\n        def __getitem__(self, idx):\n            assert idx == 0, \"This dataset has only one graph\"\n            return self._g\n\n        def 
__len__(self):\n            return 1\n\nFor brevity, this section skips some code in ``process()`` to highlight the key\npart for processing a node classification dataset: splitting masks. Node\nfeatures and node labels are stored in ``g.ndata``. For the detailed\nimplementation, please refer to `CitationGraphDataset source\ncode <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/citation_graph.html#CitationGraphDataset>`__.\n\nNote that the implementations of ``__getitem__(idx)`` and\n``__len__()`` are changed as well, since there is often only one graph\nfor node classification tasks. The masks are ``bool tensors`` in PyTorch\nand TensorFlow, and ``float tensors`` in MXNet.\n\nThe section uses a subclass of ``CitationGraphDataset``, :class:`dgl.data.CiteseerGraphDataset`,\nto show the usage of it:\n\n.. code::\n\n    # load data\n    dataset = CiteseerGraphDataset(raw_dir='')\n    graph = dataset[0]\n\n    # get split masks\n    train_mask = graph.ndata['train_mask']\n    val_mask = graph.ndata['val_mask']\n    test_mask = graph.ndata['test_mask']\n\n    # get node features\n    feats = graph.ndata['feat']\n\n    # get labels\n    labels = graph.ndata['label']\n\nA complete guide for training node classification models can be found in\n:ref:`guide-training-node-classification`.\n\nFor more examples of node classification datasets, please refer to DGL's\nbuiltin datasets:\n\n* :ref:`citationdata`\n\n* :ref:`corafulldata`\n\n* :ref:`amazoncobuydata`\n\n* :ref:`coauthordata`\n\n* :ref:`karateclubdata`\n\n* :ref:`ppidata`\n\n* :ref:`redditdata`\n\n* :ref:`sbmdata`\n\n* :ref:`sstdata`\n\n* :ref:`rdfdata`\n\nProcessing Link Prediction datasets\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe processing of link prediction datasets is similar to that of node\nclassification datasets: there is often only one graph in the dataset.\n\nThe section uses builtin dataset\n`KnowledgeGraphDataset 
<https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/knowledge_graph.html#KnowledgeGraphDataset>`__\nas an example, and still skips the detailed data processing code to\nhighlight the key part for processing link prediction datasets:\n\n.. code::\n\n    # Example for creating Link Prediction datasets\n    class KnowledgeGraphDataset(DGLBuiltinDataset):\n        def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):\n            self._name = name\n            self.reverse = reverse\n            url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)\n            super(KnowledgeGraphDataset, self).__init__(name,\n                                                        url=url,\n                                                        raw_dir=raw_dir,\n                                                        force_reload=force_reload,\n                                                        verbose=verbose)\n\n        def process(self):\n            # Skip some processing code\n            # === data processing skipped ===\n\n            # splitting mask\n            g.edata['train_mask'] = train_mask\n            g.edata['val_mask'] = val_mask\n            g.edata['test_mask'] = test_mask\n            # edge type\n            g.edata['etype'] = etype\n            # node type\n            g.ndata['ntype'] = ntype\n            self._g = g\n\n        def __getitem__(self, idx):\n            assert idx == 0, \"This dataset has only one graph\"\n            return self._g\n\n        def __len__(self):\n            return 1\n\nAs shown in the code, it adds splitting masks into ``edata`` field of the\ngraph. Check `KnowledgeGraphDataset source\ncode <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/knowledge_graph.html#KnowledgeGraphDataset>`__\nto see the complete code. The following code uses a subclass of ``KnowledgeGraphDataset``,\n:class:`dgl.data.FB15k237Dataset`, to show the usage of it:\n\n.. 
code::\n\n    from dgl.data import FB15k237Dataset\n\n    # load data\n    dataset = FB15k237Dataset()\n    graph = dataset[0]\n\n    # get training mask\n    train_mask = graph.edata['train_mask']\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()\n    src, dst = graph.edges(train_idx)\n    # get edge types in training set\n    rel = graph.edata['etype'][train_idx]\n\n\nA complete guide for training link prediction models can be found in\n:ref:`guide-training-link-prediction`.\n\nFor more examples of link prediction datasets, please refer to DGL's\nbuiltin datasets:\n\n* :ref:`kgdata`\n\n* :ref:`bitcoinotcdata`\n"
  },
  {
    "path": "docs/source/guide/data-savenload.rst",
    "content": ".. _guide-data-pipeline-savenload:\n\n4.4 Save and load data\n----------------------\n\n:ref:`(中文版) <guide_cn-data-pipeline-savenload>`\n\nDGL recommends implementing saving and loading functions to cache the\nprocessed data in local disk. This saves a lot of data processing time\nin most cases. DGL provides four functions to make things simple:\n\n-  :func:`dgl.save_graphs` and :func:`dgl.load_graphs`: save/load DGLGraph objects and labels to/from local disk.\n-  :func:`dgl.data.utils.save_info` and :func:`dgl.data.utils.load_info`: save/load useful information of the dataset (python ``dict`` object) to/from local disk.\n\nThe following example shows how to save and load a list of graphs and\ndataset information.\n\n.. code:: \n\n    import os\n    from dgl import save_graphs, load_graphs\n    from dgl.data.utils import makedirs, save_info, load_info\n    \n    def save(self):\n        # save graphs and labels\n        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')\n        save_graphs(graph_path, self.graphs, {'labels': self.labels})\n        # save other information in python dict\n        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')\n        save_info(info_path, {'num_classes': self.num_classes})\n    \n    def load(self):\n        # load processed data from directory `self.save_path`\n        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')\n        self.graphs, label_dict = load_graphs(graph_path)\n        self.labels = label_dict['labels']\n        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')\n        self.num_classes = load_info(info_path)['num_classes']\n    \n    def has_cache(self):\n        # check whether there are processed data in `self.save_path`\n        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')\n        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')\n        return os.path.exists(graph_path) and 
os.path.exists(info_path)\n\nNote that there are cases where saving the processed data is not suitable. For\nexample, in the builtin dataset :class:`~dgl.data.GDELTDataset`,\nthe processed data is quite large, so it’s more effective to process\neach data example in ``__getitem__(idx)``."
  },
  {
    "path": "docs/source/guide/data.rst",
    "content": ".. _guide-data-pipeline:\n\nChapter 4: Graph Data Pipeline\n==============================\n\n:ref:`(中文版) <guide_cn-data-pipeline>`\n\nDGL implements many commonly used graph datasets in :ref:`apidata`. They\nfollow a standard pipeline defined in class :class:`dgl.data.DGLDataset`. DGL highly\nrecommends processing graph data into a :class:`dgl.data.DGLDataset` subclass, as the\npipeline provides a simple and clean solution for loading, processing and\nsaving graph data.\n\nRoadmap\n-------\n\nThis chapter introduces how to create a custom DGL-Dataset.\nThe following sections explain how the pipeline works, and\nshow how to implement each component of it.\n\n* :ref:`guide-data-pipeline-dataset`\n* :ref:`guide-data-pipeline-download`\n* :ref:`guide-data-pipeline-process`\n* :ref:`guide-data-pipeline-savenload`\n* :ref:`guide-data-pipeline-loadogb`\n* :ref:`guide-data-pipeline-loadcsv`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    data-dataset\n    data-download\n    data-process\n    data-savenload\n    data-loadogb\n    data-loadcsv"
  },
  {
    "path": "docs/source/guide/distributed-apis.rst",
    "content": ".. _guide-distributed-apis:\n\n7.3 Programming APIs\n-----------------------------------\n\n:ref:`(中文版) <guide_cn-distributed-apis>`\n\nThis section covers the core python components commonly used in a training script. DGL\nprovides three distributed data structures and various APIs for initialization,\ndistributed sampling and workload split.\n\n* :class:`~dgl.distributed.DistGraph` for accessing structure and feature of a distributedly\n  stored graph.\n* :class:`~dgl.distributed.DistTensor` for accessing node/edge feature tensor that\n  is partitioned across machines.\n* :class:`~dgl.distributed.DistEmbedding` for accessing learnable node/edge embedding\n  tensor that is partitioned across machines.\n\nInitialization of the DGL distributed module\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n:func:`dgl.distributed.initialize` initializes the distributed module. If invoked\nby a trainer, this API creates sampler processes and builds connections with graph\nservers; if invoked by graph server, this API starts a service loop to listen to\ntrainer/sampler requests. The API *must* be called before\n:func:`torch.distributed.init_process_group` and any other ``dgl.distributed`` APIs\nas shown in the order below:\n\n.. code:: python\n\n    dgl.distributed.initialize('ip_config.txt')\n    th.distributed.init_process_group(backend='gloo')\n\n.. note::\n\n    If the training script contains user-defined functions (UDFs) that have to be invoked on\n    the servers (see the section of DistTensor and DistEmbedding for more details), these UDFs have to\n    be declared before :func:`~dgl.distributed.initialize`.\n\nDistributed graph\n~~~~~~~~~~~~~~~~~\n\n:class:`~dgl.distributed.DistGraph` is a Python class to access the graph\nstructure and node/edge features in a cluster of machines. Each machine is\nresponsible for one and only one partition. 
It loads the partition data (the\ngraph structure and the node data and edge data in the partition) and makes it\naccessible to all trainers in the cluster. :class:`~dgl.distributed.DistGraph`\nprovides a small subset of :class:`~dgl.DGLGraph` APIs for data access.\n\nDistributed mode vs. standalone mode\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n:class:`~dgl.distributed.DistGraph` can run in two modes: *distributed mode* and *standalone mode*.\nWhen a user executes a training script in a Python command line or Jupyter Notebook, it runs in\na standalone mode. That is, it runs all computation in a single process and does not communicate\nwith any other processes. Thus, the standalone mode requires the input graph to have only one partition.\nThis mode is mainly used for development and testing (e.g., develop and run the code in Jupyter Notebook).\nWhen a user executes a training script with a launch script (see the section of launch script),\n:class:`~dgl.distributed.DistGraph` runs in the distributed mode. The launch tool starts servers\n(node/edge feature access and graph sampling) behind the scene and loads the partition data in\neach machine automatically. :class:`~dgl.distributed.DistGraph` connects with the servers in the cluster\nof machines and access them through the network.\n\nDistGraph creation\n^^^^^^^^^^^^^^^^^^\n\nIn the distributed mode, the creation of :class:`~dgl.distributed.DistGraph`\nrequires the graph name given during graph partitioning. The graph name\nidentifies the graph loaded in the cluster.\n\n.. code:: python\n\n    import dgl\n    g = dgl.distributed.DistGraph('graph_name')\n\nWhen running in the standalone mode, it loads the graph data in the local\nmachine. Therefore, users need to provide the partition configuration file,\nwhich contains all information about the input graph.\n\n.. code:: python\n\n    import dgl\n    g = dgl.distributed.DistGraph('graph_name', part_config='data/graph_name.json')\n\n.. 
note::\n\n    DGL only allows one single ``DistGraph`` object. The behavior\n    of destroying a DistGraph and creating a new one is undefined.\n\nAccessing graph structure\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n:class:`~dgl.distributed.DistGraph` provides a set of APIs to\naccess the graph structure.  Currently, most APIs provide graph information,\nsuch as the number of nodes and edges. The main use case of DistGraph is to run\nsampling APIs to support mini-batch training (see `Distributed sampling`_).\n\n.. code:: python\n\n    print(g.num_nodes())\n\nAccess node/edge data\n^^^^^^^^^^^^^^^^^^^^^\n\nLike :class:`~dgl.DGLGraph`, :class:`~dgl.distributed.DistGraph` provides ``ndata`` and ``edata``\nto access data in nodes and edges.\nThe difference is that ``ndata``/``edata`` in :class:`~dgl.distributed.DistGraph` returns\n:class:`~dgl.distributed.DistTensor`, instead of the tensor of the underlying framework.\nUsers can also assign a new :class:`~dgl.distributed.DistTensor` to\n:class:`~dgl.distributed.DistGraph` as node data or edge data.\n\n.. code:: python\n\n    g.ndata['train_mask']  # <dgl.distributed.dist_graph.DistTensor at 0x7fec820937b8>\n    g.ndata['train_mask'][0]  # tensor([1], dtype=torch.uint8)\n\nDistributed Tensor\n~~~~~~~~~~~~~~~~~~~~~\n\nAs mentioned earlier, DGL shards node/edge features and stores them in a cluster of machines.\nDGL provides distributed tensors with a tensor-like interface to access the partitioned\nnode/edge features in the cluster. In the distributed setting, DGL only supports dense node/edge\nfeatures.\n\n:class:`~dgl.distributed.DistTensor` manages the dense tensors partitioned and stored in\nmultiple machines. Right now, a distributed tensor has to be associated with nodes or edges\nof a graph. In other words, the number of rows in a DistTensor has to be the same as the number\nof nodes or the number of edges in a graph. 
The following code creates a distributed tensor.\nIn addition to the shape and dtype for the tensor, a user can also provide a unique tensor name.\nThis name is useful if a user wants to reference a persistent distributed tensor (the one exists\nin the cluster even if the :class:`~dgl.distributed.DistTensor` object disappears).\n\n.. code:: python\n\n    tensor = dgl.distributed.DistTensor((g.num_nodes(), 10), th.float32, name='test')\n\n.. note::\n\n    :class:`~dgl.distributed.DistTensor` creation is a synchronized operation. All trainers\n    have to invoke the creation and the creation succeeds only when all trainers call it.\n\nA user can add a :class:`~dgl.distributed.DistTensor` to a :class:`~dgl.distributed.DistGraph`\nobject as one of the node data or edge data.\n\n.. code:: python\n\n    g.ndata['feat'] = tensor\n\n.. note::\n\n    The node data name and the tensor name do not have to be the same. The former identifies\n    node data from :class:`~dgl.distributed.DistGraph` (in the trainer process) while the latter\n    identifies a distributed tensor in DGL servers.\n\n:class:`~dgl.distributed.DistTensor` has the same APIs as\nregular tensors to access its metadata, such as the shape and dtype. It also\nsupports indexed reads and writes but does not support\ncomputation operators, such as sum and mean.\n\n.. code:: python\n\n    data = g.ndata['feat'][[1, 2, 3]]\n    print(data)\n    g.ndata['feat'][[3, 4, 5]] = data\n\n\n.. note::\n\n    Currently, DGL does not provide protection for concurrent writes from\n    multiple trainers when a machine runs multiple servers. This may result in\n    data corruption. One way to avoid concurrent writes to the same row of data\n    is to run one server process on a machine.\n\nDistributed DistEmbedding\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nDGL provides :class:`~dgl.distributed.DistEmbedding` to support transductive models that require\nnode embeddings. 
Creating distributed embeddings is very similar to creating distributed tensors.\n\n.. code:: python\n\n    def initializer(shape, dtype):\n        arr = th.zeros(shape, dtype=dtype)\n        arr.uniform_(-1, 1)\n        return arr\n    emb = dgl.distributed.DistEmbedding(g.num_nodes(), 10, init_func=initializer)\n\nInternally, distributed embeddings are built on top of distributed tensors,\nand, thus, have very similar behaviors to distributed tensors. For example, when\nembeddings are created, they are sharded and stored across all machines in the\ncluster. Each embedding can be uniquely identified by a name.\n\n.. note::\n\n    The initializer function is invoked in the server process. Therefore, it has to be\n    declared before :func:`dgl.distributed.initialize`.\n\nBecause the embeddings are part of the model, a user has to attach them to an\noptimizer for mini-batch training. Currently, DGL provides a sparse Adagrad\noptimizer :class:`~dgl.distributed.SparseAdagrad` (DGL will add more optimizers\nfor sparse embeddings later).  Users need to collect all distributed embeddings\nfrom a model and pass them to the sparse optimizer.  If a model has both node\nembeddings and regular dense model parameters and users want to perform sparse\nupdates on the embeddings, they need to create two optimizers, one for node\nembeddings and the other for dense model parameters, as shown in the code\nbelow:\n\n.. code:: python\n\n    sparse_optimizer = dgl.distributed.SparseAdagrad([emb], lr=lr1)\n    optimizer = th.optim.Adam(model.parameters(), lr=lr2)\n    feats = emb(nids)\n    loss = model(feats)\n    loss.backward()\n    optimizer.step()\n    sparse_optimizer.step()\n\n.. 
note::\n\n    :class:`~dgl.distributed.DistEmbedding` does not inherit :class:`torch.nn.Module`,\n    so we recommend using it outside of your own NN module.\n\nDistributed sampling\n~~~~~~~~~~~~~~~~~~~~\n\nDGL provides two levels of APIs for sampling nodes and edges to generate\nmini-batches (see the section of mini-batch training). The low-level APIs\nrequire users to write code to explicitly define how a layer of nodes are\nsampled (e.g., using :func:`dgl.sampling.sample_neighbors` ).  The high-level\nsampling APIs implement a few popular sampling algorithms for node\nclassification and link prediction tasks (e.g.,\n:class:`~dgl.dataloading.NodeDataLoader` and\n:class:`~dgl.dataloading.EdgeDataLoader` ).\n\nThe distributed sampling module follows the same design and provides two levels\nof sampling APIs.  For the lower-level sampling API, it provides\n:func:`~dgl.distributed.sample_neighbors` for distributed neighborhood sampling\non :class:`~dgl.distributed.DistGraph`. In addition, DGL provides a distributed\nDataLoader (:class:`~dgl.distributed.DistDataLoader` ) for distributed\nsampling.  The distributed DataLoader has the same interface as Pytorch\nDataLoader except that users cannot specify the number of worker processes when\ncreating a dataloader. The worker processes are created in\n:func:`dgl.distributed.initialize`.\n\n.. note::\n\n    When running :func:`dgl.distributed.sample_neighbors` on\n    :class:`~dgl.distributed.DistGraph`, the sampler cannot run in Pytorch\n    DataLoader with multiple worker processes. The main reason is that Pytorch\n    DataLoader creates new sampling worker processes in every epoch, which\n    leads to creating and destroying :class:`~dgl.distributed.DistGraph`\n    objects many times.\n\nWhen using the low-level API, the sampling code is similar to single-process sampling. The only\ndifference is that users need to use :func:`dgl.distributed.sample_neighbors` and\n:class:`~dgl.distributed.DistDataLoader`.\n\n.. 
code:: python\n\n    def sample_blocks(seeds):\n        seeds = th.LongTensor(np.asarray(seeds))\n        blocks = []\n        for fanout in [10, 25]:\n            frontier = dgl.distributed.sample_neighbors(g, seeds, fanout, replace=True)\n            block = dgl.to_block(frontier, seeds)\n            seeds = block.srcdata[dgl.NID]\n            blocks.insert(0, block)\n            return blocks\n        dataloader = dgl.distributed.DistDataLoader(dataset=train_nid,\n                                                    batch_size=batch_size,\n                                                    collate_fn=sample_blocks,\n                                                    shuffle=True)\n        for batch in dataloader:\n            ...\n\nThe high-level sampling APIs (:class:`~dgl.dataloading.NodeDataLoader` and\n:class:`~dgl.dataloading.EdgeDataLoader` ) has distributed counterparts\n(:class:`~dgl.distributed.DistNodeDataLoader` and\n:class:`~dgl.distributed.DistEdgeDataLoader`).  The code is exactly the same as\nsingle-process sampling otherwise.\n\n.. code:: python\n\n    sampler = dgl.sampling.MultiLayerNeighborSampler([10, 25])\n    dataloader = dgl.distributed.DistNodeDataLoader(g, train_nid, sampler,\n                                                 batch_size=batch_size, shuffle=True)\n    for batch in dataloader:\n        ...\n\n\nSplit workloads\n~~~~~~~~~~~~~~~~~~\n\nTo train a model, users first need to split the dataset into training,\nvalidation and test sets.  For distributed training, this step is usually done\nbefore we invoke :func:`dgl.distributed.partition_graph` to partition a graph.\nWe recommend to store the data split in boolean arrays as node data or edge\ndata. For node classification tasks, the length of these boolean arrays is the\nnumber of nodes in a graph and each of their elements indicates the existence\nof a node in a training/validation/test set.  Similar boolean arrays should be\nused for link prediction tasks.  
:func:`dgl.distributed.partition_graph` splits\nthese boolean arrays (because they are stored as the node data or edge data of\nthe graph) based on the graph partitioning result and stores them with graph\npartitions.\n\nDuring distributed training, users need to assign training nodes/edges to each\ntrainer. Similarly, we also need to split the validation and test sets in the\nsame way.  DGL provides :func:`~dgl.distributed.node_split` and\n:func:`~dgl.distributed.edge_split` to split the training, validation and test\nsets at runtime for distributed training. The two functions take the boolean\narrays constructed before graph partitioning as input, split them and return a\nportion for the local trainer.  By default, they ensure that all portions have\nthe same number of nodes/edges. This is important for synchronous SGD, which\nassumes each trainer has the same number of mini-batches.\n\nThe example below splits the training set and returns a subset of nodes for the\nlocal process.\n\n.. code:: python\n\n    train_nids = dgl.distributed.node_split(g.ndata['train_mask'])\n"
  },
  {
    "path": "docs/source/guide/distributed-hetero.rst",
    "content": ".. _guide-distributed-hetero:\n\n7.5 Heterogeneous Graph Under The Hood\n--------------------------------------------\n\nThe chapter covers the implementation details of distributed heterogeneous\ngraphs. They are transparent to users in most scenarios but could be useful\nfor advanced customization.\n\nIn DGL, a node or edge in a heterogeneous graph has a unique ID in its own node\ntype or edge type.  Therefore, DGL can identify a node or an edge\nwith a tuple: ``(node/edge type, type-wise ID)``. We call IDs of this form\n**heterogeneous IDs**. To partition a heterogeneous graph for distributed training,\nDGL converts it to a homogeneous graph so that we can reuse the partitioning\nalgorithms designed for homogeneous graphs. Each node/edge is thus uniquely mapped\nto an integer ID in a consecutive ID range (e.g., from 0 to the total number of\nnodes of all types). We call the IDs after conversion **homogeneous IDs**.\n\nBelow is an illustration of the ID conversion process.  Here, the graph has two\ntypes of nodes (:math:`T0` and :math:`T1` ), and four types of edges\n(:math:`R0`, :math:`R1`, :math:`R2`, :math:`R3` ).  There are a total of 400\nnodes in the graph and each type has 200 nodes. Nodes of :math:`T0` have IDs in\n[0, 200), while nodes of :math:`T1` have IDs in [200, 400).  In this example, if\nwe use a tuple to identify the nodes, nodes of :math:`T0` are identified as\n(T0, type-wise ID), where type-wise ID falls in [0, 200); nodes of :math:`T1`\nare identified as (T1, type-wise ID), where type-wise ID also falls in [0,\n200).\n\n.. figure:: https://data.dgl.ai/tutorial/hetero/heterograph_ids.png\n   :alt: Imgur\n\nID Conversion Utilities\n^^^^^^^^^^^^^^^^^^^^^^^^\n\nDuring Preprocessing\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe steps of :ref:`Parallel Processing Pipeline <guide-distributed-preprocessing>`\nall use heterogeneous IDs for their inputs and outputs. 
Nevertheless, some steps such as\nParMETIS partitioning are easier to implement using homogeneous IDs, thus\nrequiring a utility to perform ID conversion.\nThe code below implements a simple ``IDConverter`` using the metadata information\nin the metadata JSON from the chunked graph data format. It starts from some\nnode type :math:`A` as node type 0, then assigns all its nodes IDs\nin the range :math:`[0, |V_A|)`. It then moves to the next node\ntype :math:`B` as node type 1 and assigns all its nodes IDs in the range\n:math:`[|V_A|, |V_A|+|V_B|)`.\n\n.. code:: python\n\n    from bisect import bisect_left\n    import numpy as np\n\n    class IDConverter:\n        def __init__(self, meta):\n            # meta is the JSON object loaded from metadata.json\n            self.node_type = meta['node_type']\n            self.edge_type = meta['edge_type']\n            self.ntype2id_map = {ntype : i for i, ntype in enumerate(self.node_type)}\n            self.etype2id_map = {etype : i for i, etype in enumerate(self.edge_type)}\n            self.num_nodes = [sum(ns) for ns in meta['num_nodes_per_chunk']]\n            self.num_edges = [sum(ns) for ns in meta['num_edges_per_chunk']]\n            self.nid_offset = np.cumsum([0] + self.num_nodes)\n            self.eid_offset = np.cumsum([0] + self.num_edges)\n\n        def ntype2id(self, ntype):\n            \"\"\"From node type name to node type ID\"\"\"\n            return self.ntype2id_map[ntype]\n\n        def etype2id(self, etype):\n            \"\"\"From edge type name to edge type ID\"\"\"\n            return self.etype2id_map[etype]\n\n        def id2ntype(self, id):\n            \"\"\"From node type ID to node type name\"\"\"\n            return self.node_type[id]\n\n        def id2etype(self, id):\n            \"\"\"From edge type ID to edge type name\"\"\"\n            return self.edge_type[id]\n\n        def nid_het2hom(self, ntype, id):\n            \"\"\"From heterogeneous node ID to homogeneous node ID\"\"\"\n    
        tid = self.ntype2id(ntype)\n            if id < 0 or id >= self.num_nodes[tid]:\n                raise ValueError(f'Invalid node ID of type {ntype}. Must be within range [0, {self.num_nodes[tid]})')\n            return self.nid_offset[tid] + id\n\n        def nid_hom2het(self, id):\n            \"\"\"From heterogeneous node ID to homogeneous node ID\"\"\"\n            if id < 0 or id >= self.nid_offset[-1]:\n                raise ValueError(f'Invalid homogeneous node ID. Must be within range [0, self.nid_offset[-1])')\n            tid = bisect_left(self.nid_offset, id) - 1\n            # Return a pair (node_type, type_wise_id)\n            return self.id2ntype(tid), id - self.nid_offset[tid]\n\n        def eid_het2hom(self, etype, id):\n            \"\"\"From heterogeneous edge ID to homogeneous edge ID\"\"\"\n            tid = self.etype2id(etype)\n            if id < 0 or id >= self.num_edges[tid]:\n                raise ValueError(f'Invalid edge ID of type {etype}. Must be within range [0, {self.num_edges[tid]})')\n            return self.eid_offset[tid] + id\n\n        def eid_hom2het(self, id):\n            \"\"\"From heterogeneous edge ID to homogeneous edge ID\"\"\"\n            if id < 0 or id >= self.eid_offset[-1]:\n                raise ValueError(f'Invalid homogeneous edge ID. 
Must be within range [0, self.eid_offset[-1])')\n            tid = bisect_left(self.eid_offset, id) - 1\n            # Return a pair (edge_type, type_wise_id)\n            return self.id2etype(tid), id - self.eid_offset[tid]\n\nAfter Partition Loading\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nAfter the partitions are loaded into trainer or server processes, the loaded\n:class:`~dgl.distributed.GraphPartitionBook` provides utilities for conversion\nbetween homogeneous IDs and heterogeneous IDs.\n\n* :func:`~dgl.distributed.GraphPartitionBook.map_to_per_ntype`: convert a homogeneous node ID to type-wise ID and node type ID.\n* :func:`~dgl.distributed.GraphPartitionBook.map_to_per_etype`: convert a homogeneous edge ID to type-wise ID and edge type ID.\n* :func:`~dgl.distributed.GraphPartitionBook.map_to_homo_nid`: convert type-wise ID and node type to a homogeneous node ID.\n* :func:`~dgl.distributed.GraphPartitionBook.map_to_homo_eid`: convert type-wise ID and edge type to a homogeneous edge ID.\n\nBecause all DGL's low-level :ref:`distributed graph sampling operators\n<api-distributed-sampling-ops>` use homogeneous IDs, DGL internally converts\nthe heterogeneous IDs specified by users to homogeneous IDs before invoking\nsampling operators.  Below shows an example of sampling a subgraph by\n:func:`~dgl.distributed.sample_neighbors` from nodes of type ``\"paper\"``.  It\nfirst performs ID conversion, and after getting the sampled subgraph, converts\nthe homogeneous node/edge IDs back to heterogeneous ones.\n\n.. 
code:: python\n\n        gpb = g.get_partition_book()\n        # We need to map the type-wise node IDs to homogeneous IDs.\n        cur = gpb.map_to_homo_nid(seeds, 'paper')\n        # For a heterogeneous input graph, the returned frontier is stored in\n        # the homogeneous graph format.\n        frontier = dgl.distributed.sample_neighbors(g, cur, fanout, replace=False)\n        block = dgl.to_block(frontier, cur)\n        cur = block.srcdata[dgl.NID]\n\n        block.edata[dgl.EID] = frontier.edata[dgl.EID]\n        # Map the homogeneous edge Ids to their edge type.\n        block.edata[dgl.ETYPE], block.edata[dgl.EID] = gpb.map_to_per_etype(block.edata[dgl.EID])\n        # Map the homogeneous node Ids to their node types and per-type Ids.\n        block.srcdata[dgl.NTYPE], block.srcdata[dgl.NID] = gpb.map_to_per_ntype(block.srcdata[dgl.NID])\n        block.dstdata[dgl.NTYPE], block.dstdata[dgl.NID] = gpb.map_to_per_ntype(block.dstdata[dgl.NID])\n\nNote that getting node/edge types from type IDs is simple -- just getting them\nfrom the ``ntypes`` attributes of a ``DistGraph``, i.e., ``g.ntypes[node_type_id]``.\n\nAccess distributed graph data\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThe :class:`~dgl.distributed.DistGraph` class supports similar interface as\n:class:`~dgl.DGLGraph`.  Below shows an example of getting the feature data of\nnodes 0, 10, 20 of type :math:`T0`. When accessing data in\n:class:`~dgl.distributed.DistGraph`, a user needs to use type-wise IDs and\ncorresponding node types or edge types.\n\n.. code:: python\n\n    import dgl\n    g = dgl.distributed.DistGraph('graph_name', part_config='data/graph_name.json')\n    feat = g.nodes['T0'].data['feat'][[0, 10, 20]]\n\nA user can create distributed tensors and distributed embeddings for a\nparticular node type or edge type. Distributed tensors and embeddings are split\nand stored in multiple machines. 
To create one, a user needs to specify how it\nis partitioned with :class:`~dgl.distributed.PartitionPolicy`.  By default, DGL\nchooses the right partition policy based on the size of the first dimension.\nHowever, if multiple node types or edge types have the same number of nodes or\nedges, DGL cannot determine the partition policy automatically. A user needs to\nexplicitly specify the partition policy.  Below shows an example of creating a\ndistributed tensor for node type :math:`T0` by using the partition policy for :math:`T0`\nand store it as node data of :math:`T0`.\n\n.. code:: python\n\n    g.nodes['T0'].data['feat1'] = dgl.distributed.DistTensor(\n        (g.num_nodes('T0'), 1), th.float32, 'feat1',\n        part_policy=g.get_node_partition_policy('T0'))\n\nThe partition policies used for creating distributed tensors and embeddings are\ninitialized when a heterogeneous graph is loaded into the graph server. A user\ncannot create a new partition policy at runtime. Therefore, a user can only\ncreate distributed tensors or embeddings for a node type or edge type.\nAccessing distributed tensors and embeddings also requires type-wise IDs.\n"
  },
  {
    "path": "docs/source/guide/distributed-partition.rst",
    "content": ".. _guide-distributed-partition:\n\n7.4 Advanced Graph Partitioning\n---------------------------------------\n\nThe chapter covers some of the advanced topics for graph partitioning.\n\nMETIS partition algorithm\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n`METIS <http://glaros.dtc.umn.edu/gkhome/views/metis>`__ is a state-of-the-art\ngraph partitioning algorithm that can generate partitions with minimal number\nof cross-partition edges, making it suitable for distributed message passing\nwhere the amount of network communication is proportional to the number of\ncross-partition edges. DGL has integrated METIS as the default partitioning\nalgorithm in its :func:`dgl.distributed.partition_graph` API.\n\nOutput format\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nRegardless of the partitioning algorithm in use, the partitioned results are stored\nin data files organized as follows:\n\n.. code-block:: none\n\n    data_root_dir/\n      |-- graph_name.json       # partition configuration file in JSON\n      |-- part0/                # data for partition 0\n      |  |-- node_feats.dgl     # node features stored in binary format\n      |  |-- edge_feats.dgl     # edge features stored in binary format\n      |  |-- graph.dgl          # graph structure of this partition stored in binary format\n      |\n      |-- part1/                # data for partition 1\n      |  |-- node_feats.dgl\n      |  |-- edge_feats.dgl\n      |  |-- graph.dgl\n      |\n      |-- ...                   # data for other partitions\n\nWhen distributed to a cluster, the metadata JSON should be copied to all the machines\nwhile the ``partX`` folders should be dispatched accordingly.\n\nDGL provides a :func:`dgl.distributed.load_partition` function to load one partition\nfor inspection.\n\n.. 
code:: python\n\n  >>> import dgl\n  >>> # load partition 0\n  >>> part_data = dgl.distributed.load_partition('data_root_dir/graph_name.json', 0)\n  >>> g, nfeat, efeat, partition_book, graph_name, ntypes, etypes = part_data  # unpack\n  >>> print(g)\n  Graph(num_nodes=966043, num_edges=34270118,\n        ndata_schemes={'orig_id': Scheme(shape=(), dtype=torch.int64),\n                       'part_id': Scheme(shape=(), dtype=torch.int64),\n                       '_ID': Scheme(shape=(), dtype=torch.int64),\n                       'inner_node': Scheme(shape=(), dtype=torch.int32)}\n        edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64),\n                       'inner_edge': Scheme(shape=(), dtype=torch.int8),\n                       'orig_id': Scheme(shape=(), dtype=torch.int64)})\n\nAs mentioned in the `ID mapping`_ section, each partition carries auxiliary information\nsaved as ndata or edata such as original node/edge IDs, partition IDs, etc. Each partition\nnot only saves nodes/edges it owns, but also includes nodes/edges that are adjacent to\nthe partition (called **HALO** nodes/edges). The ``inner_node`` and ``inner_edge``\nindicate whether a node/edge truly belongs to the partition (value is ``True``)\nor is a HALO node/edge (value is ``False``).\n\nThe :func:`~dgl.distributed.load_partition` function loads all data at once. Users can\nload features or the partition book using the :func:`dgl.distributed.load_partition_feats`\nand :func:`dgl.distributed.load_partition_book` APIs respectively.\n\n\nParallel METIS partitioning\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nFor massive graphs where parallel preprocessing is desired, DGL supports\n`ParMETIS <http://glaros.dtc.umn.edu/gkhome/metis/parmetis/overview>`__ as one\nof the choices of partitioning algorithms.\n\n.. 
note::\n\n    Because ParMETIS does not support heterogeneous graphs, users need to\n    conduct ID conversion before and after running ParMETIS.\n    Check out chapter :ref:`guide-distributed-hetero` for explanation.\n\n.. note::\n\n    Please make sure that the input graph to ParMETIS does not have\n    duplicate edges (or parallel edges) and self-loop edges.\n\nParMETIS Installation\n^^^^^^^^^^^^^^^^^^^^^^\nParMETIS requires METIS and GKLib. Please follow the instructions `here\n<https://github.com/KarypisLab/GKlib>`__ to compile and install GKLib. For\ncompiling and installing METIS, please follow the instructions below to clone\nMETIS with GIT and compile it with int64 support.\n\n.. code-block:: bash\n\n    git clone https://github.com/KarypisLab/METIS.git\n    make config shared=1 cc=gcc prefix=~/local i64=1\n    make install\n\n\nFor now, we need to compile and install ParMETIS manually. We clone the DGL branch of ParMETIS as follows:\n\n.. code-block:: bash\n\n    git clone --branch dgl https://github.com/KarypisLab/ParMETIS.git\n\nThen compile and install ParMETIS.\n\n.. code-block:: bash\n\n    make config cc=mpicc prefix=~/local\n    make install\n\nBefore running ParMETIS, we need to set two environment variables: ``PATH`` and ``LD_LIBRARY_PATH``.\n\n.. code-block:: bash\n\n    export PATH=$PATH:$HOME/local/bin\n    export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/local/lib/\n\nInput format\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. note::\n\n    As a prerequisite, read chapter :ref:`guide-distributed-hetero` to understand\n    how DGL organizes heterogeneous graphs for distributed training.\n\nThe input graph for ParMETIS is stored in three files with the following names:\n``xxx_nodes.txt``, ``xxx_edges.txt`` and ``xxx_stats.txt``, where ``xxx`` is a\ngraph name.\n\nEach row in ``xxx_nodes.txt`` stores the information of a node. Row ID is\nalso the *homogeneous* ID of a node, e.g., row 0 is for node 0; row 1 is for\nnode 1, etc. 
Each row has the following format:\n\n.. code-block:: none\n\n    <node_type_id> <node_weight_list> <type_wise_node_id>\n\nAll fields are separated by whitespace:\n\n* ``<node_type_id>`` is an integer starting from 0. Each node type is mapped to\n  an integer. For a homogeneous graph, its value is always 0.\n* ``<node_weight_list>`` are integers (separated by whitespace) that indicate\n  the node weights used by ParMETIS to balance graph partitions. For homogeneous\n  graphs, the list has only one integer while for heterogeneous graphs with\n  :math:`T` node types, the list should have :math:`T` integers. If the node\n  belongs to node type :math:`t`, then all the integers except the :math:`t^{th}`\n  one are zero; the :math:`t^{th}` integer is the weight of that node. ParMETIS\n  will try to balance the total node weight of each partition. For heterogeneous\n  graphs, it will try to distribute nodes of the same type to all partitions.\n  The recommended node weights are 1 for balancing the number of nodes in each\n  partition or node degrees for balancing the number of edges in each partition.\n* ``<type_wise_node_id>`` is an integer representing the node ID in its own type.\n\nBelow shows an example of a node file for a heterogeneous graph with two node\ntypes. Node type 0 has three nodes; node type 1 has four nodes. It uses two\nnode weights to ensure that ParMETIS will generate partitions with roughly the\nsame number of nodes for type 0 and the same number of nodes for type 1.\n\n.. code-block:: none\n\n    0 1 0 0\n    0 1 0 1\n    0 1 0 2\n    1 0 1 0\n    1 0 1 1\n    1 0 1 2\n    1 0 1 3\n\nSimilarly, each row in ``xxx_edges.txt`` stores the information of an edge. Row ID is\nalso the *homogeneous* ID of an edge, e.g., row 0 is for edge 0; row 1 is for\nedge 1, etc. Each row has the following format:\n\n.. 
code-block:: none\n\n    <src_node_id> <dst_node_id> <type_wise_edge_id> <edge_type_id>\n\nAll fields are separated by whitespace:\n\n* ``<src_node_id>`` is the *homogeneous* ID of the source node.\n* ``<dst_node_id>`` is the *homogeneous* ID of the destination node.\n* ``<type_wise_edge_id>`` is the edge ID for the edge type.\n* ``<edge_type_id>`` is an integer starting from 0. Each edge type is mapped to\n  an integer. For a homogeneous graph, its value is always 0.\n\n``xxx_stats.txt`` stores some basic statistics of the graph. It has only one line with three fields\nseparated by whitespace:\n\n.. code-block:: none\n\n    <num_nodes> <num_edges> <total_node_weights>\n\n* ``num_nodes`` stores the total number of nodes regardless of node types.\n* ``num_edges`` stores the total number of edges regardless of edge types.\n* ``total_node_weights`` stores the number of node weights in the node file.\n\nRun ParMETIS and output format\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nParMETIS contains a command called ``pm_dglpart``, which loads the graph stored\nin the three files from the machine where ``pm_dglpart`` is invoked, distributes\ndata to all machines in the cluster and invokes ParMETIS to partition the\ngraph. When it completes, it generates three files for each partition:\n``p<part_id>-xxx_nodes.txt``, ``p<part_id>-xxx_edges.txt``,\n``p<part_id>-xxx_stats.txt``.\n\n.. note::\n\n    ParMETIS reassigns IDs to nodes during the partitioning. After ID reassignment,\n    the nodes in a partition are assigned with contiguous IDs; furthermore, the nodes of\n    the same type are assigned with contiguous IDs.\n\n``p<part_id>-xxx_nodes.txt`` stores the node data of the partition. Each row represents\na node with the following fields:\n\n.. 
code-block:: none\n\n    <node_id> <node_type_id> <node_weight_list> <type_wise_node_id>\n\n* ``<node_id>`` is the *homogeneous* node ID after ID reassignment.\n* ``<node_type_id>`` is the node type ID.\n* ``<node_weight_list>`` is the node weight used by ParMETIS (copied from the input file).\n* ``<type_wise_node_id>`` is an integer representing the node ID in its own type.\n\n``p<part_id>-xxx_edges.txt`` stores the edge data of the partition. Each row represents\nan edge with the following fields:\n\n.. code-block:: none\n\n    <src_id> <dst_id> <orig_src_id> <orig_dst_id> <type_wise_edge_id> <edge_type_id>\n\n* ``<src_id>`` is the *homogeneous* ID of the source node after ID reassignment.\n* ``<dst_id>`` is the *homogeneous* ID of the destination node after ID reassignment.\n* ``<orig_src_id>`` is the *homogeneous* ID of the source node in the input graph.\n* ``<orig_dst_id>`` is the *homogeneous* ID of the destination node in the input graph.\n* ``<type_wise_edge_id>`` is the edge ID in its own type.\n* ``<edge_type_id>`` is the edge type ID.\n\nWhen invoking ``pm_dglpart``, the three input files: ``xxx_nodes.txt``,\n``xxx_edges.txt``, ``xxx_stats.txt`` should be located in the directory where\n``pm_dglpart`` runs. The following command runs four ParMETIS processes to\npartition the graph named ``xxx`` into eight partitions (each process handles\ntwo partitions).\n\n.. code-block:: bash\n\n    mpirun -np 4 pm_dglpart xxx 2\n\nThe output files from ParMETIS then need to be converted to the\n:ref:`partition assignment format <guide-distributed-prep-partition>` in\norder to run subsequent preprocessing steps.\n"
  },
  {
    "path": "docs/source/guide/distributed-preprocessing.rst",
    "content": ".. _guide-distributed-preprocessing:\n\n7.1 Data Preprocessing\n------------------------------------------\n\nBefore launching training jobs, DGL requires the input data to be partitioned\nand distributed to the target machines. In order to handle different scales\nof graphs, DGL provides 2 partitioning approaches:\n\n* A partitioning API for graphs that can fit in a single machine memory.\n* A distributed partition pipeline for graphs beyond a single machine capacity.\n\n7.1.1 Partitioning API\n^^^^^^^^^^^^^^^^^^^^^^\n\nFor relatively small graphs, DGL provides a partitioning API\n:func:`~dgl.distributed.partition_graph` that partitions\nan in-memory :class:`~dgl.DGLGraph` object. It supports\nmultiple partitioning algorithms such as random partitioning and\n`Metis <http://glaros.dtc.umn.edu/gkhome/views/metis>`__.\nThe benefit of Metis partitioning is that it can generate partitions with\nminimal edge cuts to reduce network communication for distributed training and\ninference. DGL uses the latest version of Metis with the options optimized for\nthe real-world graphs with power-law distribution. After partitioning, the API\nconstructs the partitioned results in a format that is easy to load during the\ntraining. For example,\n\n.. code-block:: python\n\n    import dgl\n\n    g = ...  # create or load a DGLGraph object\n    dgl.distributed.partition_graph(g, 'mygraph', 2, 'data_root_dir')\n\nwill outputs the following data file.\n\n.. code-block:: none\n\n    data_root_dir/\n      |-- mygraph.json          # metadata JSON. 
File name is the given graph name.\n      |-- part0/                # data for partition 0\n      |  |-- node_feats.dgl     # node features stored in binary format\n      |  |-- edge_feats.dgl     # edge features stored in binary format\n      |  |-- graph.dgl          # graph structure of this partition stored in binary format\n      |\n      |-- part1/                # data for partition 1\n         |-- node_feats.dgl\n         |-- edge_feats.dgl\n         |-- graph.dgl\n\nChapter :ref:`guide-distributed-partition` covers more details about the\npartition format. To distribute the partitions to a cluster, users can either save\nthe data in some shared folder accessible by all machines, or copy the metadata\nJSON as well as the corresponding partition folder ``partX`` to the X^th machine.\n\nUsing :func:`~dgl.distributed.partition_graph` requires an instance with large enough\nCPU RAM to hold the entire graph structure and features, which may not be viable for\ngraphs with hundreds of billions of edges or large features. We describe how to use\nthe *parallel data preparation pipeline* for such cases next.\n\nLoad balancing\n~~~~~~~~~~~~~~\n\nWhen partitioning a graph, by default, METIS only balances the number of nodes\nin each partition.  This can result in suboptimal configuration, depending on\nthe task at hand. For example, in the case of semi-supervised node\nclassification, a trainer performs computation on a subset of labeled nodes in\na local partition. A partitioning that only balances nodes in a graph (both\nlabeled and unlabeled), may end up with computational load imbalance. To get a\nbalanced workload in each partition, the partition API allows balancing between\npartitions with respect to the number of nodes in each node type, by specifying\n``balance_ntypes`` in :func:`~dgl.distributed.partition_graph`. 
Users can take\nadvantage of this and treat nodes in the training set, validation set and\ntest set as different node types.\n\nThe following example treats nodes inside the training set and outside the\ntraining set as two types of nodes:\n\n.. code:: python\n\n    dgl.distributed.partition_graph(g, 'graph_name', 4, '/tmp/test', balance_ntypes=g.ndata['train_mask'])\n\nIn addition to balancing the node types,\n:func:`dgl.distributed.partition_graph` also allows balancing between\nin-degrees of nodes of different node types by specifying ``balance_edges``.\nThis balances the number of edges incident to the nodes of different types.\n\nID mapping\n~~~~~~~~~~~~~\n\nAfter partitioning, :func:`~dgl.distributed.partition_graph` remaps node\nand edge IDs so that nodes of the same partition are arranged together\n(in a consecutive ID range), making it easier to store partitioned node/edge\nfeatures. The API also automatically shuffles the node/edge features\naccording to the new IDs. However, some downstream tasks may want to\nrecover the original node/edge IDs (such as extracting the computed node\nembeddings for later use). For such cases, pass ``return_mapping=True``\nto :func:`~dgl.distributed.partition_graph`, which makes the API return\nthe ID mappings between the remapped node/edge IDs and their original ones.\nFor a homogeneous graph, it returns two vectors. The first vector maps every new\nnode ID to its original ID; the second vector maps every new edge ID to\nits original ID. For a heterogeneous graph, it returns two dictionaries of\nvectors. The first dictionary contains the mapping for each node type; the\nsecond dictionary contains the mapping for each edge type.\n\n.. 
code:: python\n\n    node_map, edge_map = dgl.distributed.partition_graph(g, 'graph_name', 4, '/tmp/test',\n                                                         balance_ntypes=g.ndata['train_mask'],\n                                                         return_mapping=True)\n    # Let's assume that node_emb is saved from the distributed training.\n    orig_node_emb = th.zeros(node_emb.shape, dtype=node_emb.dtype)\n    orig_node_emb[node_map] = node_emb\n\n\nLoad partitioned graphs\n^^^^^^^^^^^^^^^^^^^^^^^\n\nDGL provides a :func:`dgl.distributed.load_partition` function to load one partition\nfor inspection.\n\n.. code:: python\n\n  >>> import dgl\n  >>> # load partition 0\n  >>> part_data = dgl.distributed.load_partition('data_root_dir/graph_name.json', 0)\n  >>> g, nfeat, efeat, partition_book, graph_name, ntypes, etypes = part_data  # unpack\n  >>> print(g)\n  Graph(num_nodes=966043, num_edges=34270118,\n        ndata_schemes={'orig_id': Scheme(shape=(), dtype=torch.int64),\n                       'part_id': Scheme(shape=(), dtype=torch.int64),\n                       '_ID': Scheme(shape=(), dtype=torch.int64),\n                       'inner_node': Scheme(shape=(), dtype=torch.int32)}\n        edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64),\n                       'inner_edge': Scheme(shape=(), dtype=torch.int8),\n                       'orig_id': Scheme(shape=(), dtype=torch.int64)})\n\nAs mentioned in the `ID mapping`_ section, each partition carries auxiliary information\nsaved as ndata or edata such as original node/edge IDs, partition IDs, etc. Each partition\nnot only saves nodes/edges it owns, but also includes node/edges that are adjacent to\nthe partition (called **HALO** nodes/edges). The ``inner_node`` and ``inner_edge``\nindicate whether a node/edge truely belongs to the partition (value is ``True``)\nor is a HALO node/edge (value is ``False``).\n\nThe :func:`~dgl.distributed.load_partition` function loads all data at once. 
Users can\nload features or the partition book using the :func:`dgl.distributed.load_partition_feats`\nand :func:`dgl.distributed.load_partition_book` APIs respectively.\n\n\n7.1.2 Distributed Graph Partitioning Pipeline\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nTo handle massive graph data that cannot fit in the CPU RAM of a\nsingle machine, DGL utilizes data chunking and parallel processing to reduce\nmemory footprint and running time. The figure below illustrates the\npipeline:\n\n.. figure:: https://data.dgl.ai/asset/image/guide_7_distdataprep.png\n\n* The pipeline takes input data stored in *Chunked Graph Format* and\n  produces and dispatches data partitions to the target machines.\n* **Step.1 Graph Partitioning:** It calculates the ownership of each partition\n  and saves the results as a set of files called *partition assignment*.\n  To speedup the step, some algorithms (e.g., ParMETIS) support parallel computing\n  using multiple machines.\n* **Step.2 Data Dispatching:** Given the partition assignment, the step then\n  physically partitions the graph data and dispatches them to the machines user\n  specified. It also converts the graph data into formats that are suitable for\n  distributed training and evaluation.\n\nThe whole pipeline is modularized so that each step can be invoked\nindividually. For example, users can replace Step.1 with some custom graph partition\nalgorithm as long as it produces partition assignment files\ncorrectly.\n\n.. _guide-distributed-prep-chunk:\nChunked Graph Format\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nTo run the pipeline, DGL requires the input graph to be stored in multiple data\nchunks.  Each data chunk is the unit of data preprocessing and thus should fit\ninto CPU RAM.  
In this section, we use the MAG240M-LSC data from `Open Graph\nBenchmark <https://ogb.stanford.edu/docs/lsc/mag240m/>`__  as an example to\ndescribe the overall design, followed by a formal specification and\ntips for creating data in such format.\n\nExample: MAG240M-LSC\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe MAG240M-LSC graph is a heterogeneous academic graph\nextracted from the Microsoft Academic Graph (MAG), whose schema diagram is\nillustrated below:\n\n.. figure:: https://data.dgl.ai/asset/image/guide_7_mag240m.png\n\nIts raw data files are organized as follows:\n\n.. code-block:: none\n\n    /mydata/MAG240M-LSC/\n      |-- meta.pt   # # A dictionary of the number of nodes for each type saved by torch.save,\n      |             # as well as num_classes\n      |-- processed/\n        |-- author___affiliated_with___institution/\n        |  |-- edge_index.npy            # graph, 713 MB\n        |\n        |-- paper/\n        |  |-- node_feat.npy             # feature, 187 GB, (numpy memmap format)\n        |  |-- node_label.npy            # label, 974 MB\n        |  |-- node_year.npy             # year, 974 MB\n        |\n        |-- paper___cites___paper/\n        |  |-- edge_index.npy            # graph, 21 GB\n        |\n        |-- author___writes___paper/\n           |-- edge_index.npy            # graph, 6GB\n\nThe graph has three node types (``\"paper\"``, ``\"author\"`` and ``\"institution\"``),\nthree edge types/relations (``\"cites\"``, ``\"writes\"`` and ``\"affiliated_with\"``). The\n``\"paper\"`` nodes have three attributes (``\"feat\"``, ``\"label\"``, ``\"year\"'``), while\nother types of nodes and edges are featureless. Below shows the data files when\nit is stored in DGL Chunked Graph Format:\n\n.. 
code-block:: none\n\n    /mydata/MAG240M-LSC_chunked/\n      |-- metadata.json            # metadata json file\n      |-- edges/                   # stores edge ID data\n      |  |-- writes-part1.csv\n      |  |-- writes-part2.csv\n      |  |-- affiliated_with-part1.csv\n      |  |-- affiliated_with-part2.csv\n      |  |-- cites-part1.csv\n      |  |-- cites-part2.csv\n      |\n      |-- node_data/               # stores node feature data\n         |-- paper-feat-part1.npy\n         |-- paper-feat-part2.npy\n         |-- paper-label-part1.npy\n         |-- paper-label-part2.npy\n         |-- paper-year-part1.npy\n         |-- paper-year-part2.npy\n\nAll the data files are chunked into two parts, including the edges of each relation\n(e.g., writes, affiliated_with, cites) and node features. If the graph has edge features,\nthey will be chunked into multiple files too. All ID data are stored in\nCSV (we will illustrate the contents soon) while node features are stored in\nnumpy arrays.\n\nThe ``metadata.json`` stores all the metadata information such as file names\nand chunk sizes (e.g., number of nodes, number of edges).\n\n.. 
code-block:: python\n\n    {\n       \"graph_name\" : \"MAG240M-LSC\",  # given graph name\n       \"node_type\": [\"author\", \"paper\", \"institution\"],\n       \"num_nodes_per_chunk\": [\n           [61191556, 61191556],      # number of author nodes per chunk\n           [61191553, 61191552],      # number of paper nodes per chunk\n           [12861, 12860]             # number of institution nodes per chunk\n       ],\n       # The edge type name is a colon-joined string of source, edge, and destination type.\n       \"edge_type\": [\n           \"author:writes:paper\",\n           \"author:affiliated_with:institution\",\n           \"paper:cites:paper\"\n       ],\n       \"num_edges_per_chunk\": [\n           [193011360, 193011360],    # number of author:writes:paper edges per chunk\n           [22296293, 22296293],      # number of author:affiliated_with:institution edges per chunk\n           [648874463, 648874463]     # number of paper:cites:paper edges per chunk\n       ],\n       \"edges\" : {\n            \"author:writes:paper\" : {  # edge type\n                 \"format\" : {\"name\": \"csv\", \"delimiter\": \" \"},\n                 # The list of paths. 
Can be relative or absolute.\n                 \"data\" : [\"edges/writes-part1.csv\", \"edges/writes-part2.csv\"]\n            },\n            \"author:affiliated_with:institution\" : {\n                 \"format\" : {\"name\": \"csv\", \"delimiter\": \" \"},\n                 \"data\" : [\"edges/affiliated_with-part1.csv\", \"edges/affiliated_with-part2.csv\"]\n            },\n            \"paper:cites:paper\" : {\n                 \"format\" : {\"name\": \"csv\", \"delimiter\": \" \"},\n                 \"data\" : [\"edges/cites-part1.csv\", \"edges/cites-part2.csv\"]\n            }\n       },\n       \"node_data\" : {\n            \"paper\": {       # node type\n                 \"feat\": {   # feature key\n                     \"format\": {\"name\": \"numpy\"},\n                     \"data\": [\"node_data/paper-feat-part1.npy\", \"node_data/paper-feat-part2.npy\"]\n                 },\n                 \"label\": {   # feature key\n                     \"format\": {\"name\": \"numpy\"},\n                     \"data\": [\"node_data/paper-label-part1.npy\", \"node_data/paper-label-part2.npy\"]\n                 },\n                 \"year\": {   # feature key\n                     \"format\": {\"name\": \"numpy\"},\n                     \"data\": [\"node_data/paper-year-part1.npy\", \"node_data/paper-year-part2.npy\"]\n                 }\n            }\n       },\n       \"edge_data\" : {}  # MAG240M-LSC does not have edge features\n    }\n\nThere are three parts in ``metadata.json``:\n\n* Graph schema information and chunk sizes, e.g., ``\"node_type\"`` , ``\"num_nodes_per_chunk\"``, etc.\n* Edge index data under key ``\"edges\"``.\n* Node/edge feature data under keys ``\"node_data\"`` and ``\"edge_data\"``.\n\nThe edge index files contain edges in the form of node ID pairs:\n\n.. 
code-block:: bash\n\n    # writes-part1.csv\n    0 0\n    0 1\n    0 20\n    0 29\n    0 1203\n    ...\n\nSpecification\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nIn general, a chunked graph data folder just needs a ``metadata.json`` and a\nbunch of data files. The folder structure in the MAG240M-LSC example is not a\nstrict requirement as long as ``metadata.json`` contains valid file paths.\n\n``metadata.json`` top-level keys:\n\n* ``graph_name``: String. Unique name used by :class:`dgl.distributed.DistGraph`\n  to load graph.\n* ``node_type``: List of string. Node type names.\n* ``num_nodes_per_chunk``: List of list of integer. For graphs with :math:`T` node\n  types stored in :math:`P` chunks, the value contains :math:`T` integer lists.\n  Each list contains :math:`P` integers, which specify the number of nodes\n  in each chunk.\n* ``edge_type``: List of string. Edge type names in the form of\n  ``<source node type>:<relation>:<destination node type>``.\n* ``num_edges_per_chunk``: List of list of integer. For graphs with :math:`R` edge\n  types stored in :math:`P` chunks, the value contains :math:`R` integer lists.\n  Each list contains :math:`P` integers, which specify the number of edges\n  in each chunk.\n* ``edges``: Dict of ``ChunkFileSpec``. Edge index files.\n  Dictionary keys are edge type names in the form of\n  ``<source node type>:<relation>:<destination node type>``.\n* ``node_data``: Dict of ``ChunkFileSpec``. Data files that store node attributes\n  could have arbitrary number of files regardless of ``num_parts``. Dictionary\n  keys are node type names.\n* ``edge_data``: Dict of ``ChunkFileSpec``. Data files that store edge attributes\n  could have arbitrary number of files regardless of ``num_parts``. Dictionary\n  keys are edge type names in the form of\n  ``<source node type>:<relation>:<destination node type>``.\n\n``ChunkFileSpec`` has two keys:\n\n* ``format``: File format. 
Depending on the format ``name``, users can configure more\n  details about how to parse each data file.\n    - ``\"csv\"``: CSV file. Use the ``delimiter`` key to specify delimiter in use.\n    - ``\"numpy\"``: NumPy array binary file created by :func:`numpy.save`.\n    - ``\"parquet\"``: parquet table binary file created by :func:`pyarrow.parquet.write_table`.\n* ``data``: List of string. File path to each data chunk. Support absolute path.\n\nTips for making chunked graph data\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nDepending on the raw data, the implementation could include:\n\n* Construct graphs out of non-structured data such as texts or tabular data.\n* Augment or transform the input graph structure or features. E.g., adding reverse\n  or self-loop edges, normalizing features, etc.\n* Chunk the input graph structure and features into multiple data files so that\n  each one can fit in CPU RAM for subsequent preprocessing steps.\n\nTo avoid running into out-of-memory error, it is recommended to process graph\nstructures and feature data separately. Processing one chunk at a time can also\nreduce the maximal runtime memory footprint. As an example, DGL provides a\n`tools/chunk_graph.py\n<https://github.com/dmlc/dgl/blob/master/tools/chunk_graph.py>`_ script that\nchunks an in-memory feature-less :class:`~dgl.DGLGraph` and feature tensors\nstored in :class:`numpy.memmap`.\n\n\n.. _guide-distributed-prep-partition:\nStep.1 Graph Partitioning\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThis step reads the chunked graph data and calculates which partition each node\nshould belong to. The results are saved in a set of *partition assignment files*.\nFor example, to randomly partition MAG240M-LSC to two parts, run the\n``partition_algo/random_partition.py`` script in the ``tools`` folder:\n\n.. 
code-block:: bash\n\n    python /my/repo/dgl/tools/partition_algo/random_partition.py \\\n        --in_dir /mydata/MAG240M-LSC_chunked \\\n        --out_dir /mydata/MAG240M-LSC_2parts \\\n        --num_partitions 2\n\n, which outputs files as follows:\n\n.. code-block:: none\n\n    MAG240M-LSC_2parts/\n      |-- paper.txt\n      |-- author.txt\n      |-- institution.txt\n\nEach file stores the partition assignment of the corresponding node type.\nThe contents are the partition ID of each node stored in lines, i.e., line i is\nthe partition ID of node i.\n\n.. code-block:: bash\n\n    # paper.txt\n    0\n    1\n    1\n    0\n    0\n    1\n    0\n    ...\n\nDespite its simplicity, random partitioning may result in frequent\ncross-machine communication.  Check out chapter\n:ref:`guide-distributed-partition` for more advanced options.\n\nStep.2 Data Dispatching\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nDGL provides a ``dispatch_data.py`` script to physically partition the data and\ndispatch partitions to each training machine. It will also convert the data\nonce again to data objects that can be loaded by DGL training processes\nefficiently. The entire step can be further accelerated using multi-processing.\n\n.. code-block:: bash\n\n    python /myrepo/dgl/tools/dispatch_data.py         \\\n       --in-dir /mydata/MAG240M-LSC_chunked/          \\\n       --partitions-dir /mydata/MAG240M-LSC_2parts/   \\\n       --out-dir data/MAG_LSC_partitioned            \\\n       --ip-config ip_config.txt\n\n* ``--in-dir`` specifies the path to the folder of the input chunked graph data.\n* ``--partitions-dir`` specifies the path to the partition assignment folder produced by Step.1.\n* ``--out-dir`` specifies the path to store the data partition on each machine.\n* ``--ip-config`` specifies the IP configuration file of the cluster.\n\nAn example IP configuration file is as follows:\n\n.. 
code-block:: bash\n\n    172.31.19.1\n    172.31.23.205\n\nAs a counterpart of ``return_mapping=True`` in :func:`~dgl.distributed.partition_graph`, the\n:ref:`distributed partitioning pipeline <guide-distributed-preprocessing>`\nprovides two arguments in ``dispatch_data.py`` to save the original node/edge IDs to disk.\n\n* ``--save-orig-nids`` save original node IDs into files.\n* ``--save-orig-eids`` save original edge IDs into files.\n\nSpecifying the two options will create two files ``orig_nids.dgl`` and ``orig_eids.dgl``\nunder each partition folder.\n\n.. code-block:: none\n\n    data_root_dir/\n      |-- graph_name.json       # partition configuration file in JSON\n      |-- part0/                # data for partition 0\n      |  |-- orig_nids.dgl      # original node IDs\n      |  |-- orig_eids.dgl      # original edge IDs\n      |  |-- ...                # other data such as graph and node/edge feats\n      |\n      |-- part1/                # data for partition 1\n      |  |-- orig_nids.dgl\n      |  |-- orig_eids.dgl\n      |  |-- ...\n      |\n      |-- ...                   # data for other partitions\n\nThe two files store the original IDs as a dictionary of tensors, where keys are node/edge\ntype names and values are ID tensors. Users can use the :func:`dgl.data.load_tensors`\nutility to load them:\n\n.. code:: python\n\n    # Load the original IDs for the nodes in partition 0.\n    orig_nids_0 = dgl.data.load_tensors('/path/to/data/part0/orig_nids.dgl')\n    # Get the original node IDs for node type 'user'\n    user_orig_nids_0 = orig_nids_0['user']\n\n    # Load the original IDs for the edges in partition 0.\n    orig_eids_0 = dgl.data.load_tensors('/path/to/data/part0/orig_eids.dgl')\n    # Get the original edge IDs for edge type 'like'\n    like_orig_eids_0 = orig_eids_0['like']\n\nDuring data dispatching, DGL assumes that the combined CPU RAM of the cluster\nis able to hold the entire graph data. 
Node ownership is determined by the result\nof the partitioning algorithm, whereas for edges the owner of the destination node\nalso owns the edge.\n\n"
  },
  {
    "path": "docs/source/guide/distributed-tools.rst",
    "content": ".. _guide-distributed-tools:\n\n7.2 Tools for launching distributed training/inference\n------------------------------------------------------\n\nDGL provides a launching script ``launch.py`` under\n`dgl/tools <https://github.com/dmlc/dgl/tree/master/tools>`__ to launch a distributed\ntraining job in a cluster. This script makes the following assumptions:\n\n* The partitioned data and the training script have been provisioned to the cluster or\n  a shared storage (e.g., NFS) accessible to all the worker machines.\n* The machine that invokes ``launch.py`` has passwordless ssh access\n  to all other machines. The launching machine must be one of the worker machines.\n\nBelow shows an example of launching a distributed training job in a cluster.\n\n.. code:: bash\n\n    python3 tools/launch.py               \\\n      --workspace /my/workspace/          \\\n      --num_trainers 2                    \\\n      --num_samplers 4                    \\\n      --num_servers 1                     \\\n      --part_config data/mygraph.json     \\\n      --ip_config ip_config.txt           \\\n      \"python3 my_train_script.py\"\n\nThe argument specifies the workspace path, where to find the partition metadata JSON\nand machine IP configurations, how many trainer, sampler, and server processes to be launched\non each machine. The last argument is the command to launch which is usually the\nmodel training/evaluation script.\n\nEach line of ``ip_config.txt`` is the IP address of a machine in the cluster.\nOptionally, the IP address can be followed by a network port (default is ``30050``).\nA typical example is as follows:\n\n.. code:: none\n\n    172.31.19.1\n    172.31.23.205\n    172.31.29.175\n    172.31.16.98\n\nThe workspace specified in the launch script is the working directory in the\nmachines, which contains the training script, the IP configuration file, the\npartition configuration file as well as the graph partitions. 
All paths of the\nfiles should be specified as relative paths to the workspace.\n\nThe launch script creates a specified number of training jobs\n(``--num_trainers``) on each machine.  In addition, users need to specify the\nnumber of sampler processes for each trainer (``--num_samplers``).\n"
  },
  {
    "path": "docs/source/guide/distributed.rst",
    "content": ".. _guide-distributed:\n\nChapter 7: Distributed Training\n=====================================\n\n:ref:`(中文版) <guide_cn-distributed>`\n\n.. note::\n\n    Distributed training is only available for PyTorch backend.\n\nDGL adopts a fully distributed approach that distributes both data and computation\nacross a collection of computation resources. In the context of this section, we\nwill assume a cluster setting (i.e., a group of machines). DGL partitions a graph\ninto subgraphs and each machine in a cluster is responsible for one subgraph (partition).\nDGL runs an identical training script on all machines in the cluster to parallelize\nthe computation and runs servers on the same machines to serve partitioned data to the trainers.\n\nFor the training script, DGL provides distributed APIs that are similar to the ones for\nmini-batch training. This makes distributed training require only small code modifications\nfrom mini-batch training on a single machine. Below shows an example of training GraphSage\nin a distributed fashion. The notable code modifications are:\n1) initialization of DGL's distributed module, 2) create a distributed graph object, and\n3) split the training set and calculate the nodes for the local process.\nThe rest of the code, including sampler creation, model definition, training loops\nare the same as :ref:`mini-batch training <guide-minibatch>`.\n\n.. 
code:: python\n\n    import dgl\n    from dgl.dataloading import NeighborSampler\n    from dgl.distributed import DistGraph, DistDataLoader, node_split\n    import torch as th\n\n    # initialize distributed contexts\n    dgl.distributed.initialize('ip_config.txt')\n    th.distributed.init_process_group(backend='gloo')\n    # load distributed graph\n    g = DistGraph('graph_name', 'part_config.json')\n    pb = g.get_partition_book()\n    # get training workload, i.e., training node IDs\n    train_nid = node_split(g.ndata['train_mask'], pb, force_even=True)\n\n\n    # Create sampler\n    sampler = NeighborSampler(g, [10,25],\n                              dgl.distributed.sample_neighbors,\n                              device)\n\n    dataloader = DistDataLoader(\n        dataset=train_nid.numpy(),\n        batch_size=batch_size,\n        collate_fn=sampler.sample_blocks,\n        shuffle=True,\n        drop_last=False)\n\n    # Define model and optimizer\n    model = SAGE(in_feats, num_hidden, n_classes, num_layers, F.relu, dropout)\n    model = th.nn.parallel.DistributedDataParallel(model)\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n    # training loop\n    for epoch in range(args.num_epochs):\n        with model.join():\n            for step, blocks in enumerate(dataloader):\n                batch_inputs, batch_labels = load_subtensor(g, blocks[0].srcdata[dgl.NID],\n                                                            blocks[-1].dstdata[dgl.NID])\n                batch_pred = model(blocks, batch_inputs)\n                loss = loss_fcn(batch_pred, batch_labels)\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n\nDGL implements a few distributed components to support distributed training. The figure below\nshows the components and their interactions.\n\n.. 
figure:: https://data.dgl.ai/asset/image/distributed.png\n   :alt: Imgur\n\nSpecifically, DGL's distributed training has three types of interacting processes:\n*server*, *sampler* and *trainer*.\n\n* **Servers** store graph partitions which includes both structure data and node/edge\n  features. They provide services such as sampling, getting or updating node/edge\n  features. Note that each machine may run multiple server processes simultaneously\n  to increase service throughput. One of them is *main server* in charge of data\n  loading and sharing data via shared memory with *backup servers* that provide\n  services.\n* **Sampler processes** interact with the servers and sample nodes and edges to\n  generate mini-batches for training.\n* **Trainers** are in charge of training networks on mini-batches. They utilize\n  APIs such as :class:`~dgl.distributed.DistGraph` to access partitioned graph data,\n  :class:`~dgl.distributed.DistEmbedding` and :class:`~dgl.distributed.DistTensor` to access\n  node/edge features/embeddings and :class:`~dgl.distributed.DistDataLoader` to interact\n  with samplers to get mini-batches. Trainers communicate gradients among each other\n  using PyTorch's native ``DistributedDataParallel`` paradigm.\n\nBesides Python APIs, DGL also provides `tools <https://github.com/dmlc/dgl/tree/master/tools>`__\nfor provisioning graph data and processes to the entire cluster.\n\nHaving the distributed components in mind, the rest of the section will cover\nthe following distributed components:\n\n* :ref:`guide-distributed-preprocessing`\n* :ref:`guide-distributed-tools`\n* :ref:`guide-distributed-apis`\n\nFor more advanced users who are interested in more details:\n\n* :ref:`guide-distributed-partition`\n* :ref:`guide-distributed-hetero`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    distributed-preprocessing\n    distributed-tools\n    distributed-apis\n    distributed-partition\n    distributed-hetero\n"
  },
  {
    "path": "docs/source/guide/graph-basic.rst",
    "content": ".. _guide-graph-basic:\n\n1.1 Some Basic Definitions about Graphs (Graphs 101)\n----------------------------------------------------\n\n:ref:`(中文版)<guide_cn-graph-basic>`\n\nA graph :math:`G=(V, E)` is a structure used to represent entities and their relations. It consists of\ntwo sets -- the set of nodes :math:`V` (also called vertices) and the set of edges :math:`E` (also called\narcs). An edge :math:`(u, v) \\in E` connecting a pair of nodes :math:`u` and :math:`v` indicates that there is a\nrelation between them. The relation can either be undirected, e.g., capturing symmetric\nrelations between nodes, or directed, capturing asymmetric relations. For example, if a\ngraph is used to model the friendships relations of people in a social network, then the edges\nwill be undirected as friendship is mutual; however, if the graph is used to model how people\nfollow each other on Twitter, then the edges are directed. Depending on the edges'\ndirectionality, a graph can be *directed* or *undirected*.\n\nGraphs can be *weighted* or *unweighted*. In a weighted graph, each edge is associated with a\nscalar weight. For example, such weights might represent lengths or connectivity strengths.\n\nGraphs can also be either *homogeneous* or *heterogeneous*. In a homogeneous graph, all\nthe nodes represent instances of the same type and all the edges represent relations of the\nsame type. For instance, a social network is a graph consisting of people and their\nconnections, representing the same entity type.\n\nIn contrast, in a heterogeneous graph, the nodes and edges can be of different types. For\ninstance, the graph encoding a marketplace will have buyer, seller, and product nodes that\nare connected via wants-to-buy, has-bought, is-customer-of, and is-selling edges. The\nbipartite graph is a special, commonly-used type of heterogeneous graph, where edges\nexist between nodes of two different types. 
For example, in a recommender system, one can\nuse a bipartite graph to represent the interactions between users and items. For working\nwith heterogeneous graphs in DGL, see :ref:`guide-graph-heterogeneous`.\n\nMultigraphs are graphs that can have multiple (directed) edges between the same pair of nodes,\nincluding self loops. For instance, two authors can coauthor a paper in different years,\nresulting in edges with different features.\n"
  },
  {
    "path": "docs/source/guide/graph-external.rst",
    "content": ".. _guide-graph-external:\n\n1.4 Creating Graphs from External Sources\n-----------------------------------------\n\n:ref:`(中文版)<guide_cn-graph-external>`\n\nThe options to construct a :class:`~dgl.DGLGraph` from external sources include:\n\n- Conversion from external python libraries for graphs and sparse matrices (NetworkX and SciPy).\n- Loading graphs from disk.\n\nThe section does not cover functions that generate graphs by transforming from other\ngraphs. See the API reference manual for an overview of them.\n\nCreating Graphs from External Libraries\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThe following code snippet is an example for creating a graph from a SciPy sparse matrix and a NetworkX graph.\n\n.. code::\n\n    >>> import dgl\n    >>> import torch as th\n    >>> import scipy.sparse as sp\n    >>> spmat = sp.rand(100, 100, density=0.05) # 5% nonzero entries\n    >>> dgl.from_scipy(spmat)                   # from SciPy\n    Graph(num_nodes=100, num_edges=500,\n          ndata_schemes={}\n          edata_schemes={})\n\n    >>> import networkx as nx\n    >>> nx_g = nx.path_graph(5) # a chain 0-1-2-3-4\n    >>> dgl.from_networkx(nx_g) # from networkx\n    Graph(num_nodes=5, num_edges=8,\n          ndata_schemes={}\n          edata_schemes={})\n\nNote that when constructing from the `nx.path_graph(5)`, the resulting :class:`~dgl.DGLGraph` has 8\nedges instead of 4. This is because `nx.path_graph(5)` constructs an undirected NetworkX graph\n:class:`networkx.Graph` while a :class:`~dgl.DGLGraph` is always directed. In converting an undirected\nNetworkX graph into a :class:`~dgl.DGLGraph`, DGL internally converts undirected edges to two directed edges.\nUsing directed NetworkX graphs :class:`networkx.DiGraph` can avoid such behavior.\n\n.. code::\n\n    >>> nxg = nx.DiGraph([(2, 1), (1, 2), (2, 3), (0, 0)])\n    >>> dgl.from_networkx(nxg)\n    Graph(num_nodes=4, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={})\n\n.. 
note::\n\n    DGL internally converts SciPy matrices and NetworkX graphs to tensors to construct graphs.\n    Hence, these construction methods are not meant for performance critical parts.\n\nSee APIs: :func:`dgl.from_scipy`, :func:`dgl.from_networkx`.\n\nLoading Graphs from Disk\n^^^^^^^^^^^^^^^^^^^^^^^^\n\nThere are many data formats for storing graphs and it isn't possible to enumerate every option.\nThus, this section only gives some general pointers on certain common ones.\n\nComma Separated Values (CSV)\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nOne very common format is CSV, which stores nodes, edges, and their features in a tabular format:\n\n.. table:: nodes.csv\n\n   +-----------+\n   |age, title |\n   +===========+\n   |43, 1      |\n   +-----------+\n   |23, 3      |\n   +-----------+\n   |...        |\n   +-----------+\n\n.. table:: edges.csv\n\n   +-----------------+\n   |src, dst, weight |\n   +=================+\n   |0, 1, 0.4        |\n   +-----------------+\n   |0, 3, 0.9        |\n   +-----------------+\n   |...              |\n   +-----------------+\n\nThere are known Python libraries (e.g. pandas) for loading this type of data into python\nobjects (e.g., :class:`numpy.ndarray`), which can then be used to construct a DGLGraph. 
If the\nbackend framework also provides utilities to save/load tensors from disk (e.g., :func:`torch.save`,\n:func:`torch.load`), one can follow the same principle to build a graph.\n\nSee also: `Tutorial for loading a Karate Club Network from edge pairs CSV <https://github.com/dglai/WWW20-Hands-on-Tutorial/blob/master/basic_tasks/1_load_data.ipynb>`_.\n\nJSON/GML Format\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nThough not particularly fast, NetworkX provides many utilities to parse\n`a variety of data formats <https://networkx.github.io/documentation/stable/reference/readwrite/index.html>`_\nwhich indirectly allows DGL to create graphs from these sources.\n\nDGL Binary Format\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nDGL provides APIs to save and load graphs from disk stored in binary format. Apart from the\ngraph structure, the APIs also handle feature data and graph-level label data. DGL also\nsupports checkpointing graphs directly to S3 or HDFS. The reference manual provides more\ndetails about the usage.\n\nSee APIs: :func:`dgl.save_graphs`, :func:`dgl.load_graphs`.\n"
  },
  {
    "path": "docs/source/guide/graph-feature.rst",
    "content": ".. _guide-graph-feature:\n\n1.3 Node and Edge Features\n--------------------------\n\n:ref:`(中文版)<guide_cn-graph-feature>`\n\nThe nodes and edges of a :class:`~dgl.DGLGraph` can have several user-defined named features for\nstoring graph-specific properties of the nodes and edges. These features can be accessed\nvia the :py:attr:`~dgl.DGLGraph.ndata` and :py:attr:`~dgl.DGLGraph.edata` interface. For example, the following code creates two node\nfeatures (named ``'x'`` and ``'y'`` in line 8 and 15) and one edge feature (named ``'x'`` in line 9).\n\n.. code-block:: python\n    :linenos:\n\n    >>> import dgl\n    >>> import torch as th\n    >>> g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0])) # 6 nodes, 4 edges\n    >>> g\n    Graph(num_nodes=6, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={})\n    >>> g.ndata['x'] = th.ones(g.num_nodes(), 3)               # node feature of length 3\n    >>> g.edata['x'] = th.ones(g.num_edges(), dtype=th.int32)  # scalar integer feature\n    >>> g\n    Graph(num_nodes=6, num_edges=4,\n          ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}\n          edata_schemes={'x' : Scheme(shape=(,), dtype=torch.int32)})\n    >>> # different names can have different shapes\n    >>> g.ndata['y'] = th.randn(g.num_nodes(), 5)\n    >>> g.ndata['x'][1]                  # get node 1's feature\n    tensor([1., 1., 1.])\n    >>> g.edata['x'][th.tensor([0, 3])]  # get features of edge 0 and 3\n        tensor([1, 1], dtype=torch.int32)\n\nImportant facts about the :py:attr:`~dgl.DGLGraph.ndata`/:py:attr:`~dgl.DGLGraph.edata` interface:\n\n- Only features of numerical types (e.g., float, double, and int) are allowed. They can\n  be scalars, vectors or multi-dimensional tensors.\n- Each node feature has a unique name and each edge feature has a unique name.\n  The features of nodes and edges can have the same name. 
(e.g., 'x' in the above example).\n- A feature is created via tensor assignment, which assigns a feature to each\n  node/edge in the graph. The leading dimension of that tensor must be equal to the\n  number of nodes/edges in the graph. You cannot assign a feature to a subset of the\n  nodes/edges in the graph.\n- Features of the same name must have the same dimensionality and data type.\n- The feature tensor is in row-major layout -- each row-slice stores the feature of one\n  node or edge (e.g., see lines 16 and 18 in the above example).\n\nFor weighted graphs, one can store the weights as an edge feature as below.\n\n.. code-block:: python\n\n    >>> # edges 0->1, 0->2, 0->3, 1->3\n    >>> edges = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])\n    >>> weights = th.tensor([0.1, 0.6, 0.9, 0.7])  # weight of each edge\n    >>> g = dgl.graph(edges)\n    >>> g.edata['w'] = weights  # give it a name 'w'\n    >>> g\n    Graph(num_nodes=4, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={'w' : Scheme(shape=(,), dtype=torch.float32)})\n\n\nSee APIs: :py:attr:`~dgl.DGLGraph.ndata`, :py:attr:`~dgl.DGLGraph.edata`.\n"
  },
  {
    "path": "docs/source/guide/graph-gpu.rst",
    "content": ".. _guide-graph-gpu:\n\n1.6 Using DGLGraph on a GPU\n---------------------------\n\n:ref:`(中文版)<guide_cn-graph-gpu>`\n\nOne can create a :class:`~dgl.DGLGraph` on a GPU by passing two GPU tensors during construction.\nAnother approach is to use the :func:`~dgl.DGLGraph.to` API to copy a :class:`~dgl.DGLGraph` to a GPU, which\ncopies the graph structure as well as the feature data to the given device.\n\n.. code::\n\n    >>> import dgl\n    >>> import torch as th\n    >>> u, v = th.tensor([0, 1, 2]), th.tensor([2, 3, 4])\n    >>> g = dgl.graph((u, v))\n    >>> g.ndata['x'] = th.randn(5, 3)  # original feature is on CPU\n    >>> g.device\n    device(type='cpu')\n    >>> cuda_g = g.to('cuda:0')  # accepts any device objects from backend framework\n    >>> cuda_g.device\n    device(type='cuda', index=0)\n    >>> cuda_g.ndata['x'].device       # feature data is copied to GPU too\n    device(type='cuda', index=0)\n\n    >>> # A graph constructed from GPU tensors is also on GPU\n    >>> u, v = u.to('cuda:0'), v.to('cuda:0')\n    >>> g = dgl.graph((u, v))\n    >>> g.device\n    device(type='cuda', index=0)\n\nAny operations involving a GPU graph are performed on a GPU. Thus, they require all\ntensor arguments to be placed on GPU already and the results (graph or tensor) will be on\nGPU too. Furthermore, a GPU graph only accepts feature data on a GPU.\n\n.. code::\n\n    >>> cuda_g.in_degrees()\n    tensor([0, 0, 1, 1, 1], device='cuda:0')\n    >>> cuda_g.in_edges([2, 3, 4])   # ok for non-tensor type arguments\n    (tensor([0, 1, 2], device='cuda:0'), tensor([2, 3, 4], device='cuda:0'))\n    >>> cuda_g.in_edges(th.tensor([2, 3, 4]).to('cuda:0'))  # tensor type must be on GPU\n    (tensor([0, 1, 2], device='cuda:0'), tensor([2, 3, 4], device='cuda:0'))\n    >>> cuda_g.ndata['h'] = th.randn(5, 4)  # ERROR! feature must be on GPU too!\n    DGLError: Cannot assign node feature \"h\" on device cpu to a graph on device\n    cuda:0. 
Call DGLGraph.to() to copy the graph to the same device.\n"
  },
  {
    "path": "docs/source/guide/graph-graphs-nodes-edges.rst",
    "content": ".. _guide-graph-graphs-nodes-edges:\n\n1.2 Graphs, Nodes, and Edges\n----------------------------\n\n:ref:`(中文版)<guide_cn-graph-graphs-nodes-edges>`\n\nDGL represents each node by a unique integer, called its node ID, and each edge by a pair\nof integers corresponding to the IDs of its end nodes. DGL assigns to each edge a unique\ninteger, called its **edge ID**, based on the order in which it was added to the graph. The\nnumbering of node and edge IDs starts from 0. In DGL, all the edges are directed, and an\nedge :math:`(u, v)` indicates that the direction goes from node :math:`u` to node :math:`v`.\n\nTo specify multiple nodes, DGL uses a 1-D integer tensor (i.e., PyTorch's tensor,\nTensorFlow's Tensor, or MXNet's ndarray) of node IDs. DGL calls this format \"node-tensors\".\nTo specify multiple edges, it uses a tuple of node-tensors :math:`(U, V)`. :math:`(U[i], V[i])`\ndecides an edge from :math:`U[i]` to :math:`V[i]`.\n\nOne way to create a :class:`~dgl.DGLGraph` is to use the :func:`dgl.graph` method, which takes\nas input a set of edges. DGL also supports creating graphs from other data sources, see :ref:`guide-graph-external`.\n\nThe following code snippet uses the :func:`dgl.graph` method to create a :class:`~dgl.DGLGraph`\ncorresponding to the four-node graph shown below and illustrates some of its APIs for\nquerying the graph's structure.\n\n.. figure:: https://data.dgl.ai/asset/image/user_guide_graphch_1.png\n    :height: 200px\n    :width: 300px\n    :align: center\n\n.. 
code::\n\n    >>> import dgl\n    >>> import torch as th\n\n    >>> # edges 0->1, 0->2, 0->3, 1->3\n    >>> u, v = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])\n    >>> g = dgl.graph((u, v))\n    >>> print(g) # number of nodes are inferred from the max node IDs in the given edges\n    Graph(num_nodes=4, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={})\n\n    >>> # Node IDs\n    >>> print(g.nodes())\n    tensor([0, 1, 2, 3])\n    >>> # Edge end nodes\n    >>> print(g.edges())\n    (tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]))\n    >>> # Edge end nodes and edge IDs\n    >>> print(g.edges(form='all'))\n    (tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]), tensor([0, 1, 2, 3]))\n\n    >>> # If the node with the largest ID is isolated (meaning no edges),\n    >>> # then one needs to explicitly set the number of nodes\n    >>> g = dgl.graph((u, v), num_nodes=8)\n\nFor an undirected graph, one needs to create edges for both directions. :func:`dgl.to_bidirected`\ncan be helpful in this case, which converts a graph into a new one with edges for both directions.\n\n.. code::\n\n    >>> bg = dgl.to_bidirected(g)\n    >>> bg.edges()\n    (tensor([0, 0, 0, 1, 1, 2, 3, 3]), tensor([1, 2, 3, 0, 3, 0, 0, 1]))\n\n.. note::\n\n    Tensor types are generally preferred throughout DGL APIs due to their efficient internal\n    storage in C and explicit data type and device context information. However, most DGL APIs\n    do support python iterable (e.g., list) or numpy.ndarray as arguments for quick prototyping.\n\nDGL can use either :math:`32`- or :math:`64`-bit integers to store the node and edge IDs. The data types for\nthe node and edge IDs should be the same. By using :math:`64` bits, DGL can handle graphs with\nup to :math:`2^{63} - 1` nodes or edges. 
However, if a graph contains no more than :math:`2^{31} - 1` nodes or edges,\none should use :math:`32`-bit integers as it leads to better speed and requires less memory.\nDGL provides methods for making such conversions. See below for an example.\n\n.. code::\n\n    >>> edges = th.tensor([2, 5, 3]), th.tensor([3, 5, 0])  # edges 2->3, 5->5, 3->0\n    >>> g64 = dgl.graph(edges)  # DGL uses int64 by default\n    >>> print(g64.idtype)\n    torch.int64\n    >>> g32 = dgl.graph(edges, idtype=th.int32)  # create an int32 graph\n    >>> g32.idtype\n    torch.int32\n    >>> g64_2 = g32.long()  # convert to int64\n    >>> g64_2.idtype\n    torch.int64\n    >>> g32_2 = g64.int()  # convert to int32\n    >>> g32_2.idtype\n    torch.int32\n\nSee APIs: :func:`dgl.graph`, :func:`dgl.DGLGraph.nodes`, :func:`dgl.DGLGraph.edges`, :func:`dgl.to_bidirected`,\n:func:`dgl.DGLGraph.int`, :func:`dgl.DGLGraph.long`, and :py:attr:`dgl.DGLGraph.idtype`.\n"
  },
  {
    "path": "docs/source/guide/graph-heterogeneous.rst",
    "content": ".. _guide-graph-heterogeneous:\n\n1.5 Heterogeneous Graphs\n------------------------\n\n:ref:`(中文版)<guide_cn-graph-heterogeneous>`\n\nA heterogeneous graph can have nodes and edges of different types. Nodes/Edges of\ndifferent types have independent ID space and feature storage. For example in the figure below, the\nuser and game node IDs both start from zero and they have different features.\n\n.. figure:: https://data.dgl.ai/asset/image/user_guide_graphch_2.png\n\n    An example heterogeneous graph with two types of nodes (user and game) and two types of edges (follows and plays).\n\nCreating a Heterogeneous Graph\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nIn DGL, a heterogeneous graph (heterograph for short) is specified with a series of graphs as below, one per\nrelation. Each relation is a string triplet ``(source node type, edge type, destination node type)``.\nSince relations disambiguate the edge types, DGL calls them canonical edge types.\n\nThe following code snippet is an example for creating a heterogeneous graph in DGL.\n\n.. code::\n\n    >>> import dgl\n    >>> import torch as th\n\n    >>> # Create a heterograph with 3 node types and 3 edges types.\n    >>> graph_data = {\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    ('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),\n    ...    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))\n    ... }\n    >>> g = dgl.heterograph(graph_data)\n    >>> g.ntypes\n    ['disease', 'drug', 'gene']\n    >>> g.etypes\n    ['interacts', 'interacts', 'treats']\n    >>> g.canonical_etypes\n    [('drug', 'interacts', 'drug'),\n     ('drug', 'interacts', 'gene'),\n     ('drug', 'treats', 'disease')]\n\nNote that homogeneous and bipartite graphs are just special heterogeneous graphs with one\nrelation.\n\n.. 
code::\n\n    >>> # A homogeneous graph\n    >>> dgl.heterograph({('node_type', 'edge_type', 'node_type'): (u, v)})\n    >>> # A bipartite graph\n    >>> dgl.heterograph({('source_type', 'edge_type', 'destination_type'): (u, v)})\n\nThe *metagraph* associated with a heterogeneous graph is the schema of the graph. It specifies\ntype constraints on the sets of nodes and edges between the nodes. A node :math:`u` in a metagraph\ncorresponds to a node type in the associated heterograph. An edge :math:`(u, v)` in a metagraph indicates that\nthere are edges from nodes of type :math:`u` to nodes of type :math:`v` in the associated heterograph.\n\n.. code::\n\n    >>> g\n    Graph(num_nodes={'disease': 3, 'drug': 3, 'gene': 4},\n          num_edges={('drug', 'interacts', 'drug'): 2,\n                     ('drug', 'interacts', 'gene'): 2,\n                     ('drug', 'treats', 'disease'): 1},\n          metagraph=[('drug', 'drug', 'interacts'),\n                     ('drug', 'gene', 'interacts'),\n                     ('drug', 'disease', 'treats')])\n    >>> g.metagraph().edges()\n    OutMultiEdgeDataView([('drug', 'drug'), ('drug', 'gene'), ('drug', 'disease')])\n\nSee APIs: :func:`dgl.heterograph`, :py:attr:`~dgl.DGLGraph.ntypes`, :py:attr:`~dgl.DGLGraph.etypes`,\n:py:attr:`~dgl.DGLGraph.canonical_etypes`, :py:attr:`~dgl.DGLGraph.metagraph`.\n\nWorking with Multiple Types\n^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nWhen multiple node/edge types are introduced, users need to specify the particular\nnode/edge type when invoking a DGLGraph API for type-specific information. In addition,\nnodes/edges of different types have separate IDs.\n\n.. 
code::\n\n    >>> # Get the number of all nodes in the graph\n    >>> g.num_nodes()\n    10\n    >>> # Get the number of drug nodes\n    >>> g.num_nodes('drug')\n    3\n    >>> # Nodes of different types have separate IDs,\n    >>> # hence not well-defined without a type specified\n    >>> g.nodes()\n    DGLError: Node type name must be specified if there are more than one node types.\n    >>> g.nodes('drug')\n    tensor([0, 1, 2])\n\nTo set/get features for a specific node/edge type, DGL provides two new types of syntax --\n`g.nodes['node_type'].data['feat_name']` and `g.edges['edge_type'].data['feat_name']`.\n\n.. code::\n\n    >>> # Set/get feature 'hv' for nodes of type 'drug'\n    >>> g.nodes['drug'].data['hv'] = th.ones(3, 1)\n    >>> g.nodes['drug'].data['hv']\n    tensor([[1.],\n            [1.],\n            [1.]])\n    >>> # Set/get feature 'he' for edge of type 'treats'\n    >>> g.edges['treats'].data['he'] = th.zeros(1, 1)\n    >>> g.edges['treats'].data['he']\n    tensor([[0.]])\n\nIf the graph only has one node/edge type, there is no need to specify the node/edge type.\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    ('drug', 'is similar', 'drug'): (th.tensor([0, 1]), th.tensor([2, 3]))\n    ... })\n    >>> g.nodes()\n    tensor([0, 1, 2, 3])\n    >>> # To set/get feature with a single type, no need to use the new syntax\n    >>> g.ndata['hv'] = th.ones(4, 1)\n\n.. note::\n\n    When the edge type uniquely determines the types of source and destination nodes, one\n    can just use one string instead of a string triplet to specify the edge type. 
For example, for a\n    heterograph with two relations ``('user', 'plays', 'game')`` and ``('user', 'likes', 'game')``, it\n    is safe to just use ``'plays'`` or ``'likes'`` to refer to the two relations.\n\nLoading Heterographs from Disk\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nComma Separated Values (CSV)\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nA common way to store a heterograph is to store nodes and edges of different types in different CSV files.\nAn example is as follows.\n\n.. code::\n\n    # data folder\n    data/\n    |-- drug.csv        # drug nodes\n    |-- gene.csv        # gene nodes\n    |-- disease.csv     # disease nodes\n    |-- drug-interact-drug.csv  # drug-drug interaction edges\n    |-- drug-interact-gene.csv  # drug-gene interaction edges\n    |-- drug-treat-disease.csv  # drug-treat-disease edges\n\nSimilar to the case of homogeneous graphs, one can use packages like Pandas to parse\nCSV files into numpy arrays or framework tensors, build a relation dictionary and\nconstruct a heterograph from that. The approach also applies to other popular formats like\nGML/JSON.\n\nDGL Binary Format\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nDGL provides :func:`dgl.save_graphs` and :func:`dgl.load_graphs` respectively for saving\nheterogeneous graphs in binary format and loading them from binary format.\n\nEdge Type Subgraph\n^^^^^^^^^^^^^^^^^^\n\nOne can create a subgraph of a heterogeneous graph by specifying the relations to retain, with\nfeatures copied if any.\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    ('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),\n    ...    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))\n    ... 
})\n    >>> g.nodes['drug'].data['hv'] = th.ones(3, 1)\n\n    >>> # Retain relations ('drug', 'interacts', 'drug') and ('drug', 'treats', 'disease')\n    >>> # All nodes for 'drug' and 'disease' will be retained\n    >>> eg = dgl.edge_type_subgraph(g, [('drug', 'interacts', 'drug'),\n    ...                                 ('drug', 'treats', 'disease')])\n    >>> eg\n    Graph(num_nodes={'disease': 3, 'drug': 3},\n          num_edges={('drug', 'interacts', 'drug'): 2, ('drug', 'treats', 'disease'): 1},\n          metagraph=[('drug', 'drug', 'interacts'), ('drug', 'disease', 'treats')])\n    >>> # The associated features will be copied as well\n    >>> eg.nodes['drug'].data['hv']\n    tensor([[1.],\n            [1.],\n            [1.]])\n\nConverting Heterogeneous Graphs to Homogeneous Graphs\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nHeterographs provide a clean interface for managing nodes/edges of different types and\ntheir associated features. This is particularly helpful when:\n\n1. The features for nodes/edges of different types have different data types or sizes.\n2. We want to apply different operations to nodes/edges of different types.\n\nIf the above conditions do not hold and one does not want to distinguish node/edge types in\nmodeling, then DGL allows converting a heterogeneous graph to a homogeneous graph with :func:`dgl.DGLGraph.to_homogeneous` API.\nIt proceeds as follows:\n\n1. Relabels nodes/edges of all types using consecutive integers starting from 0\n2. Merges the features across node/edge types specified by the user.\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    
('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))})\n    >>> g.nodes['drug'].data['hv'] = th.zeros(3, 1)\n    >>> g.nodes['disease'].data['hv'] = th.ones(3, 1)\n    >>> g.edges['interacts'].data['he'] = th.zeros(2, 1)\n    >>> g.edges['treats'].data['he'] = th.zeros(1, 2)\n\n    >>> # By default, it does not merge any features\n    >>> hg = dgl.to_homogeneous(g)\n    >>> 'hv' in hg.ndata\n    False\n\n    >>> # Copy edge features\n    >>> # For feature copy, it expects features to have\n    >>> # the same size and dtype across node/edge types\n    >>> hg = dgl.to_homogeneous(g, edata=['he'])\n    DGLError: Cannot concatenate column ‘he’ with shape Scheme(shape=(2,), dtype=torch.float32) and shape Scheme(shape=(1,), dtype=torch.float32)\n\n    >>> # Copy node features\n    >>> hg = dgl.to_homogeneous(g, ndata=['hv'])\n    >>> hg.ndata['hv']\n    tensor([[1.],\n            [1.],\n            [1.],\n            [0.],\n            [0.],\n            [0.]])\n\nThe original node/edge types and type-specific IDs are stored in :py:attr:`~dgl.DGLGraph.ndata` and :py:attr:`~dgl.DGLGraph.edata`.\n\n.. code::\n\n    >>> # Order of node types in the heterograph\n    >>> g.ntypes\n    ['disease', 'drug']\n    >>> # Original node types\n    >>> hg.ndata[dgl.NTYPE]\n    tensor([0, 0, 0, 1, 1, 1])\n    >>> # Original type-specific node IDs\n    >>> hg.ndata[dgl.NID]\n    tensor([0, 1, 2, 0, 1, 2])\n\n    >>> # Order of edge types in the heterograph\n    >>> g.etypes\n    ['interacts', 'treats']\n    >>> # Original edge types\n    >>> hg.edata[dgl.ETYPE]\n    tensor([0, 0, 1])\n    >>> # Original type-specific edge IDs\n    >>> hg.edata[dgl.EID]\n    tensor([0, 1, 0])\n\nFor modeling purposes, one may want to group some relations together and apply the same\noperation to them. To address this need, one can first take an edge type subgraph of the\nheterograph and then convert the subgraph to a homogeneous graph.\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    
('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    ('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),\n    ...    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))\n    ... })\n    >>> sub_g = dgl.edge_type_subgraph(g, [('drug', 'interacts', 'drug'),\n    ...                                    ('drug', 'interacts', 'gene')])\n    >>> h_sub_g = dgl.to_homogeneous(sub_g)\n    >>> h_sub_g\n    Graph(num_nodes=7, num_edges=4,\n          ...)\n"
  },
  {
    "path": "docs/source/guide/graph.rst",
    "content": ".. _guide-graph:\n\nChapter 1: Graph\n======================\n\n:ref:`(中文版)<guide_cn-graph>`\n\nGraphs express entities (nodes) along with their relations (edges), and both nodes and\nedges can be typed (e.g., ``\"user\"`` and ``\"item\"`` are two different types of nodes). DGL provides a\ngraph-centric programming abstraction with its core data structure -- :class:`~dgl.DGLGraph`. :class:`~dgl.DGLGraph`\nprovides its interface to handle a graph's structure, its node/edge features, and the resulting\ncomputations that can be performed using these components.\n\nRoadmap\n-------\n\nThe chapter starts with a brief introduction to graph definitions in 1.1 and then introduces some core\nconcepts of :class:`~dgl.DGLGraph`:\n\n* :ref:`guide-graph-basic`\n* :ref:`guide-graph-graphs-nodes-edges`\n* :ref:`guide-graph-feature`\n* :ref:`guide-graph-external`\n* :ref:`guide-graph-heterogeneous`\n* :ref:`guide-graph-gpu`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    graph-basic\n    graph-graphs-nodes-edges\n    graph-feature\n    graph-external\n    graph-heterogeneous\n    graph-gpu\n"
  },
  {
    "path": "docs/source/guide/index.rst",
    "content": "User Guide\n==========\n\n.. toctree::\n  :maxdepth: 2\n  :titlesonly:\n\n  graph\n  message\n  nn\n  data\n  training\n  minibatch\n  distributed\n  mixed_precision\n"
  },
  {
    "path": "docs/source/guide/message-api.rst",
    "content": ".. _guide-message-passing-api:\n\n2.1 Built-in Functions and Message Passing APIs\n-----------------------------------------------\n\n:ref:`(中文版) <guide_cn-message-passing-api>`\n\nIn DGL, **message function** takes a single argument ``edges``,\nwhich is an :class:`~dgl.udf.EdgeBatch` instance. During message passing,\nDGL generates it internally to represent a batch of edges. It has three\nmembers ``src``, ``dst`` and ``data`` to access features of source nodes,\ndestination nodes, and edges, respectively.\n\n**reduce function** takes a single argument ``nodes``, which is a\n:class:`~dgl.udf.NodeBatch` instance. During message passing,\nDGL generates it internally to represent a batch of nodes. It has member\n``mailbox`` to access the messages received for the nodes in the batch.\nSome of the most common reduce operations include ``sum``, ``max``, ``min``, etc.\n\n**update function** takes a single argument ``nodes`` as described above.\nThis function operates on the aggregation result from ``reduce function``, typically\ncombining it with a node’s original feature at the the last step and saving the result\nas a node feature.\n\nDGL has implemented commonly used message functions and reduce functions\nas **built-in** in the namespace ``dgl.function``. In general, DGL\nsuggests using built-in functions **whenever possible** since they are\nheavily optimized and automatically handle dimension broadcasting.\n\nIf your message passing functions cannot be implemented with built-ins,\nyou can implement user-defined message/reduce function (aka. **UDF**).\n\nBuilt-in message functions can be unary or binary. DGL supports ``copy``\nfor unary. For binary funcs, DGL supports ``add``, ``sub``, ``mul``, ``div``,\n``dot``. 
The naming convention for message built-in funcs is that ``u``\nrepresents ``src`` nodes, ``v`` represents ``dst`` nodes, and ``e`` represents ``edges``.\nThe parameters for those functions are strings indicating the input and output field names for\nthe corresponding nodes and edges. The list of supported built-in functions\ncan be found in :ref:`api-built-in`. For example, to add the ``hu`` feature from src\nnodes and ``hv`` feature from dst nodes then save the result on the edge\nat ``he`` field, one can use built-in function ``dgl.function.u_add_v('hu', 'hv', 'he')``.\nThis is equivalent to the Message UDF:\n\n.. code::\n\n    def message_func(edges):\n         return {'he': edges.src['hu'] + edges.dst['hv']}\n\nBuilt-in reduce functions support operations ``sum``, ``max``, ``min``,\nand ``mean``. Reduce functions usually have two parameters, one\nfor field name in ``mailbox``, one for field name in node features, both\nare strings. For example, ``dgl.function.sum('m', 'h')`` is equivalent\nto the Reduce UDF that sums up the message ``m``:\n\n.. code::\n\n    import torch\n    def reduce_func(nodes):\n         return {'h': torch.sum(nodes.mailbox['m'], dim=1)}\n\nFor advanced usage of UDF, see :ref:`apiudf`.\n\nIt is also possible to invoke only edge-wise computation by :meth:`~dgl.DGLGraph.apply_edges`\nwithout invoking message passing. :meth:`~dgl.DGLGraph.apply_edges` takes a message function\nfor parameter and by default updates the features of all edges. For example:\n\n.. code::\n\n    import dgl.function as fn\n    graph.apply_edges(fn.u_add_v('el', 'er', 'e'))\n\nFor message passing, :meth:`~dgl.DGLGraph.update_all` is a high-level\nAPI that merges message generation, message aggregation and node update\nin a single call, which leaves room for optimization as a whole.\n\nThe parameters for :meth:`~dgl.DGLGraph.update_all` are a message function, a\nreduce function and an update function. 
One can call update function outside of\n``update_all`` and not specify it in invoking :meth:`~dgl.DGLGraph.update_all`.\nDGL recommends this approach since the update function can usually be\nwritten as pure tensor operations to make the code concise. For\nexample:\n\n.. code::\n\n    def update_all_example(graph):\n        # store the result in graph.ndata['ft']\n        graph.update_all(fn.u_mul_e('ft', 'a', 'm'),\n                         fn.sum('m', 'ft'))\n        # Call update function outside of update_all\n        final_ft = graph.ndata['ft'] * 2\n        return final_ft\n\nThis call will generate the messages ``m`` by multiplying src node features\n``ft`` and edge features ``a``, sum up the messages ``m`` to update node\nfeatures ``ft``, and finally multiply ``ft`` by 2 to get the result\n``final_ft``. After the call, DGL will clean the intermediate messages ``m``.\nThe math formula for the above function is:\n\n.. math::  {final\\_ft}_i = 2 * \\sum_{j\\in\\mathcal{N}(i)} ({ft}_j * a_{ji})\n\nDGL's built-in functions support floating point data types, i.e. the feature must\nbe ``half`` (``float16``) /``float``/``double`` tensors.\n``float16`` data type support is disabled by default as it has a minimum GPU\ncompute capability requirement of ``sm_53`` (Pascal, Volta, Turing and Ampere\narchitectures).\n\nUsers can enable float16 for mixed precision training by compiling DGL from source\n(see :doc:`Mixed Precision Training <mixed_precision>` tutorial for details).\n"
  },
  {
    "path": "docs/source/guide/message-efficient.rst",
    "content": ".. _guide-message-passing-efficient:\n\n2.2 Writing Efficient Message Passing Code\n------------------------------------------\n\n:ref:`(中文版) <guide_cn-message-passing-efficient>`\n\nDGL optimizes memory consumption and computing speed for message\npassing. A common practise to leverage those\noptimizations is to construct one's own message passing functionality as\na combination of :meth:`~dgl.DGLGraph.update_all` calls with built-in\nfunctions as parameters.\n\nBesides that, considering that the number of edges is much larger than the number of nodes for some graphs, avoiding unnecessary memory copy from nodes to edges is beneficial. For some cases like\n:class:`~dgl.nn.pytorch.conv.GATConv`,\nwhere it is necessary to save message on the edges, one needs to call\n:meth:`~dgl.DGLGraph.apply_edges` with built-in functions. Sometimes the\nmessages on the edges can be high dimensional, which is memory consuming.\nDGL recommends keeping the dimension of edge features as low as possible.\n\nHere’s an example on how to achieve this by splitting operations on the\nedges to nodes. The approach does the following: concatenate the ``src``\nfeature and ``dst`` feature, then apply a linear layer, i.e.\n:math:`W\\times (u || v)`. The ``src`` and ``dst`` feature dimension is\nhigh, while the linear layer output dimension is low. A straight forward\nimplementation would be like:\n\n.. 
code::\n\n    import torch\n    import torch.nn as nn\n\n    linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, out_dim)))\n    def concat_message_function(edges):\n         return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=1)}\n    g.apply_edges(concat_message_function)\n    g.edata['out'] = g.edata['cat_feat'] @ linear\n\nThe suggested implementation splits the linear operation into two,\none applies on ``src`` feature, the other applies on ``dst`` feature.\nIt then adds the output of the linear operations on the edges at the final stage,\ni.e. performing :math:`W_l\\times u + W_r \\times v`. This is because\n:math:`W \\times (u||v) = W_l \\times u + W_r \\times v`, where :math:`W_l`\nand :math:`W_r` are the left and the right half of the matrix :math:`W`,\nrespectively:\n\n.. code::\n\n    import dgl.function as fn\n\n    linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))\n    linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))\n    out_src = g.ndata['feat'] @ linear_src\n    out_dst = g.ndata['feat'] @ linear_dst\n    g.srcdata.update({'out_src': out_src})\n    g.dstdata.update({'out_dst': out_dst})\n    g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))\n\nThe above two implementations are mathematically equivalent. The latter\none is more efficient because it does not need to save feat_src and\nfeat_dst on edges, which is not memory-efficient. Plus, addition could\nbe optimized with DGL’s built-in function :func:`~dgl.function.u_add_v`, which further\nspeeds up computation and saves memory footprint.\n"
  },
  {
    "path": "docs/source/guide/message-heterograph.rst",
    "content": ".. _guide-message-passing-heterograph:\n\n2.5 Message Passing on Heterogeneous Graph\n------------------------------------------\n\n:ref:`(中文版) <guide_cn-message-passing-heterograph>`\n\nHeterogeneous graphs (:ref:`guide-graph-heterogeneous`), or\nheterographs for short, are graphs that contain different types of nodes\nand edges. The different types of nodes and edges tend to have different\ntypes of attributes that are designed to capture the characteristics of\neach node and edge type. Within the context of graph neural networks,\ndepending on their complexity, certain node and edge types might need to\nbe modeled with representations that have a different number of\ndimensions.\n\nThe message passing on heterographs can be split into two parts:\n\n1. Message computation and aggregation for each relation r.\n2. Reduction that merges the aggregation results from all relations for each node type.\n\nDGL’s interface to call message passing on heterographs is\n:meth:`~dgl.DGLGraph.multi_update_all`.\n:meth:`~dgl.DGLGraph.multi_update_all` takes a dictionary containing\nthe parameters for :meth:`~dgl.DGLGraph.update_all` within each relation\nusing relation as the key, and a string representing the cross type reducer.\nThe reducer can be one of ``sum``, ``min``, ``max``, ``mean``, ``stack``.\nHere’s an example:\n\n.. 
code::\n\n    import dgl.function as fn\n\n    for c_etype in G.canonical_etypes:\n        srctype, etype, dsttype = c_etype\n        Wh = self.weight[etype](feat_dict[srctype])\n        # Save it in graph for message passing\n        G.nodes[srctype].data['Wh_%s' % etype] = Wh\n        # Specify per-relation message passing functions: (message_func, reduce_func).\n        # Note that the results are saved to the same destination feature 'h', which\n        # hints the type wise reducer for aggregation.\n        funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))\n    # Trigger message passing of multiple types.\n    G.multi_update_all(funcs, 'sum')\n    # return the updated node feature dictionary\n    return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}\n"
  },
  {
    "path": "docs/source/guide/message-part.rst",
    "content": ".. _guide-message-passing-part:\n\n2.3 Apply Message Passing On Part Of The Graph\n----------------------------------------------\n\n:ref:`(中文版) <guide_cn-message-passing-part>`\n\nIf one only wants to update part of the nodes in the graph, the practice\nis to create a subgraph by providing the IDs for the nodes to\ninclude in the update, then call :meth:`~dgl.DGLGraph.update_all` on the\nsubgraph. For example:\n\n.. code::\n\n    nid = [0, 2, 3, 6, 7, 9]\n    sg = g.subgraph(nid)\n    sg.update_all(message_func, reduce_func, apply_node_func)\n\nThis is a common usage in mini-batch training. Check :ref:`guide-minibatch` for more detailed\nusages."
  },
  {
    "path": "docs/source/guide/message.rst",
    "content": ".. _guide-message-passing:\n\nChapter 2: Message Passing\n==========================\n\n:ref:`(中文版) <guide_cn-message-passing>`\n\nMessage Passing Paradigm\n------------------------\n\nLet :math:`x_v\\in\\mathbb{R}^{d_1}` be the feature for node :math:`v`,\nand :math:`w_{e}\\in\\mathbb{R}^{d_2}` be the feature for edge\n:math:`({u}, {v})`. The **message passing paradigm** defines the\nfollowing node-wise and edge-wise computation at step :math:`t+1`:\n\n.. math::  \\text{Edge-wise: } m_{e}^{(t+1)} = \\phi \\left( x_v^{(t)}, x_u^{(t)}, w_{e}^{(t)} \\right) , ({u}, {v},{e}) \\in \\mathcal{E}.\n\n.. math::  \\text{Node-wise: } x_v^{(t+1)} = \\psi \\left(x_v^{(t)}, \\rho\\left(\\left\\lbrace m_{e}^{(t+1)} : ({u}, {v},{e}) \\in \\mathcal{E} \\right\\rbrace \\right) \\right).\n\nIn the above equations, :math:`\\phi` is a **message function**\ndefined on each edge to generate a message by combining the edge feature\nwith the features of its incident nodes; :math:`\\psi` is an\n**update function** defined on each node to update the node feature\nby aggregating its incoming messages using the **reduce function**\n:math:`\\rho`.\n\nRoadmap\n-------\n\nThis chapter introduces DGL's message passing APIs, and how to efficiently use them on both nodes and edges.\nThe last section of it explains how to implement message passing on heterogeneous graphs.\n\n* :ref:`guide-message-passing-api`\n* :ref:`guide-message-passing-efficient`\n* :ref:`guide-message-passing-part`\n* :ref:`guide-message-passing-heterograph`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    message-api\n    message-efficient\n    message-part\n    message-heterograph\n"
  },
  {
    "path": "docs/source/guide/minibatch-custom-sampler.rst",
    "content": ".. _guide-minibatch-customizing-neighborhood-sampler:\n\n6.4 Implementing Custom Graph Samplers\n----------------------------------------------\n\nImplementing custom samplers involves subclassing the\n:class:`dgl.graphbolt.SubgraphSampler` base class and implementing its abstract\n:attr:`sample_subgraphs` method. The :attr:`sample_subgraphs` method should\ntake in seed nodes which are the nodes to sample neighbors from:\n\n.. code:: python\n\n    def sample_subgraphs(self, seed_nodes):\n        return input_nodes, sampled_subgraphs\n\nThe method should return the input node IDs list and a list of subgraphs. Each\nsubgraph is a :class:`~dgl.graphbolt.SampledSubgraph` object.\n\n\nAny other data that are required during sampling such as the graph structure,\nfanout size, etc. should be passed to the sampler via the constructor.\n\nThe code below implements a classical neighbor sampler:\n\n.. code:: python\n\n    @functional_datapipe(\"customized_sample_neighbor\")\n    class CustomizedNeighborSampler(dgl.graphbolt.SubgraphSampler):\n       def __init__(self, datapipe, graph, fanouts):\n           super().__init__(datapipe)\n           self.graph = graph\n           self.fanouts = fanouts\n\n       def sample_subgraphs(self, seed_nodes):\n           subgs = []\n           for fanout in reversed(self.fanouts):\n               # Sample a fixed number of neighbors of the current seed nodes.\n               input_nodes, sg = g.sample_neighbors(seed_nodes, fanout)\n               subgs.insert(0, sg)\n               seed_nodes = input_nodes\n           return input_nodes, subgs\n\nTo use this sampler with :class:`~dgl.graphbolt.DataLoader`:\n\n.. 
code:: python\n\n    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)\n    datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.\n    datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n\n    for data in dataloader:\n        input_features = data.node_features[\"feat\"]\n        output_labels = data.labels\n        output_predictions = model(data.blocks, input_features)\n        loss = compute_loss(output_labels, output_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n\nSampler for Heterogeneous Graphs\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nTo write a sampler for heterogeneous graphs, one needs to be aware that\nthe argument `graph` is a heterogeneous graph while `seeds` could be a\ndictionary of ID tensors. Most of DGL's graph sampling operators (e.g.,\nthe ``sample_neighbors`` and ``to_block`` functions in the above example) can\nwork on heterogeneous graph natively, so many samplers are automatically\nready for heterogeneous graph. For example, the above ``CustomizedNeighborSampler``\ncan be used on heterogeneous graphs:\n\n.. 
code:: python\n\n    import dgl.graphbolt as gb\n    hg = gb.FusedCSCSamplingGraph()\n    train_set = item_set = gb.HeteroItemSet(\n        {\n            \"user\": gb.ItemSet(\n                (torch.arange(0, 5), torch.arange(5, 10)),\n                names=(\"seeds\", \"labels\"),\n            ),\n            \"item\": gb.ItemSet(\n                (torch.arange(5, 10), torch.arange(10, 15)),\n                names=(\"seeds\", \"labels\"),\n            ),\n        }\n    )\n    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)\n    datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.\n    datapipe = datapipe.fetch_feature(\n        feature, node_feature_keys={\"user\": [\"feat\"], \"item\": [\"feat\"]}\n    )\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n\n    for data in dataloader:\n        input_features = {\n            ntype: data.node_features[(ntype, \"feat\")]\n            for ntype in data.blocks[0].srctypes\n        }\n        output_labels = data.labels[\"user\"]\n        output_predictions = model(data.blocks, input_features)[\"user\"]\n        loss = compute_loss(output_labels, output_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n\nExclude Edges After Sampling\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nIn some cases, we may want to exclude seed edges from the sampled subgraph. For\nexample, in link prediction tasks, we want to exclude the edges in the\ntraining set from the sampled subgraph to prevent information leakage. To\ndo so, we need to add an additional datapipe right after sampling as follows:\n\n.. 
code:: python\n\n    datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.\n    datapipe = datapipe.transform(gb.exclude_seed_edges)\n\nPlease check the API page of :func:`~dgl.graphbolt.exclude_seed_edges` for more\ndetails.\n\nThe above API is based on :meth:`~dgl.graphbolt.SampledSubgraph.exclude_edges`.\nIf you want to exclude edges from the sampled subgraph based on some other\ncriteria, you could write your own transform function. Please check the method\nfor reference.\n\nYou could also refer to examples in\n`Link Prediction <https://github.com/dmlc/dgl/blob/master/examples/graphbolt/link_prediction.py>`__.\n"
  },
  {
    "path": "docs/source/guide/minibatch-edge.rst",
    "content": ".. _guide-minibatch-edge-classification-sampler:\n\n6.2 Training GNN for Edge Classification with Neighborhood Sampling\n----------------------------------------------------------------------\n\n:ref:`(中文版) <guide_cn-minibatch-edge-classification-sampler>`\n\nTraining for edge classification/regression is somewhat similar to that\nof node classification/regression with several notable differences.\n\nDefine a neighborhood sampler and data loader\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nYou can use the\n:ref:`same neighborhood samplers as node classification <guide-minibatch-node-classification-sampler>`.\n\n.. code:: python\n\n    datapipe = datapipe.sample_neighbor(g, [10, 10])\n    # Or equivalently\n    datapipe = dgl.graphbolt.NeighborSampler(datapipe, g, [10, 10])\n\nThe code for defining a data loader is also the same as that of node\nclassification. The only difference is that it iterates over the\nedges(namely, node pairs) in the training set instead of the nodes.\n\n.. code:: python\n\n    import dgl.graphbolt as gb\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    g = gb.SamplingGraph()\n    seeds = torch.arange(0, 1000).reshape(-1, 2)\n    labels = torch.randint(0, 2, (5,))\n    train_set = gb.ItemSet((seeds, labels), names=(\"seeds\", \"labels\"))\n    datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)\n    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.\n    # Or equivalently:\n    # datapipe = gb.NeighborSampler(datapipe, g, [10, 10])\n    datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n\nIterating over the DataLoader will yield :class:`~dgl.graphbolt.MiniBatch`\nwhich contains a list of specially created graphs representing the computation\ndependencies on each layer. You can access the *message flow graphs* (MFGs) via\n`mini_batch.blocks`.\n\n.. 
code:: python\n\n    mini_batch = next(iter(dataloader))\n    print(mini_batch.blocks)\n\n.. note::\n\n   See the :doc:`Stochastic Training Tutorial\n   <../notebooks/stochastic_training/neighbor_sampling_overview.nblink>`\n   for the concept of message flow graph.\n\n   If you wish to develop your own neighborhood sampler or you want a more\n   detailed explanation of the concept of MFGs, please refer to\n   :ref:`guide-minibatch-customizing-neighborhood-sampler`.\n\n.. _guide-minibatch-edge-classification-sampler-exclude:\n\nRemoving edges in the minibatch from the original graph for neighbor sampling\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nWhen training edge classification models, sometimes you wish to remove\nthe edges appearing in the training data from the computation dependency\nas if they never existed. Otherwise, the model will “know” the fact that\nan edge exists between the two nodes, and potentially use it for\nadvantage.\n\nTherefore in edge classification you sometimes would like to exclude the\nseed edges as well as their reverse edges from the sampled minibatch.\nYou can use :func:`~dgl.graphbolt.exclude_seed_edges` alongside with\n:class:`~dgl.graphbolt.MiniBatchTransformer` to achieve this.\n\n.. 
code:: python\n\n    import dgl.graphbolt as gb\n    from functools import partial\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    g = gb.SamplingGraph()\n    seeds = torch.arange(0, 1000).reshape(-1, 2)\n    labels = torch.randint(0, 2, (5,))\n    train_set = gb.ItemSet((seeds, labels), names=(\"seeds\", \"labels\"))\n    datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)\n    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.\n    exclude_seed_edges = partial(gb.exclude_seed_edges, include_reverse_edges=True)\n    datapipe = datapipe.transform(exclude_seed_edges)\n    datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n    \n\nAdapt your model for minibatch training\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe edge classification model usually consists of two parts:\n\n-  One part that obtains the representation of incident nodes.\n-  The other part that computes the edge score from the incident node\n   representations.\n\nThe former part is exactly the same as\n:ref:`that from node classification <guide-minibatch-node-classification-model>`\nand we can simply reuse it. The input is still the list of\nMFGs generated from a data loader provided by DGL, as well as the\ninput features.\n\n.. 
code:: python\n\n    class StochasticTwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.conv1 = dglnn.GraphConv(in_features, hidden_features)\n            self.conv2 = dglnn.GraphConv(hidden_features, out_features)\n    \n        def forward(self, blocks, x):\n            x = F.relu(self.conv1(blocks[0], x))\n            x = F.relu(self.conv2(blocks[1], x))\n            return x\n\nThe input to the latter part is usually the output from the\nformer part, as well as the subgraph (node pairs) of the original graph induced\nby the edges in the minibatch. The subgraph is yielded from the same data\nloader.\n\nThe following code shows an example of predicting scores on the edges by\nconcatenating the incident node features and projecting them with a dense layer.\n\n.. code:: python\n\n    class ScorePredictor(nn.Module):\n        def __init__(self, num_classes, in_features):\n            super().__init__()\n            self.W = nn.Linear(2 * in_features, num_classes)\n    \n        def forward(self, seeds, x):\n            src_x = x[seeds[:, 0]]\n            dst_x = x[seeds[:, 1]]\n            data = torch.cat([src_x, dst_x], 1)\n            return self.W(data)\n\n\nThe entire model will take the list of MFGs and the edges generated by the data\nloader, as well as the input node features as follows:\n\n.. 
code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, num_classes):\n            super().__init__()\n            self.gcn = StochasticTwoLayerGCN(\n                in_features, hidden_features, out_features)\n            self.predictor = ScorePredictor(num_classes, out_features)\n\n        def forward(self, blocks, x, seeds):\n            x = self.gcn(blocks, x)\n            return self.predictor(seeds, x)\n\nDGL ensures that the nodes in the edge subgraph are the same as the\noutput nodes of the last MFG in the generated list of MFGs.\n\nTraining Loop\n~~~~~~~~~~~~~\n\nThe training loop is very similar to node classification. You can\niterate over the dataloader and get a subgraph induced by the edges in\nthe minibatch, as well as the list of MFGs necessary for computing\ntheir incident node representations.\n\n.. code:: python\n\n    import torch.nn.functional as F\n    model = Model(in_features, hidden_features, out_features, num_classes)\n    model = model.to(device)\n    opt = torch.optim.Adam(model.parameters())\n\n    for data in dataloader:\n        blocks = data.blocks\n        x = data.edge_features(\"feat\")\n        y_hat = model(data.blocks, x, data.compacted_seeds)\n        loss = F.cross_entropy(data.labels, y_hat)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n\nFor heterogeneous graphs\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe models computing the node representations on heterogeneous graphs\ncan also be used for computing incident node representations for edge\nclassification/regression.\n\n.. 
code:: python\n\n    class StochasticTwoLayerRGCN(nn.Module):\n        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):\n            super().__init__()\n            self.conv1 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')\n                    for rel in rel_names\n                })\n            self.conv2 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')\n                    for rel in rel_names\n                })\n    \n        def forward(self, blocks, x):\n            x = self.conv1(blocks[0], x)\n            x = self.conv2(blocks[1], x)\n            return x\n\nFor score prediction, the only implementation difference between the\nhomogeneous graph and the heterogeneous graph is that we are looping\nover the edge types.\n\n.. code:: python\n\n    class ScorePredictor(nn.Module):\n        def __init__(self, num_classes, in_features):\n            super().__init__()\n            self.W = nn.Linear(2 * in_features, num_classes)\n    \n        def forward(self, seeds, x):\n            scores = {}\n            for etype in seeds.keys():\n                src, dst = seeds[etype].T\n                data = torch.cat([x[etype][src], x[etype][dst]], 1)\n                scores[etype] = self.W(data)\n            return scores\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, num_classes,\n                     etypes):\n            super().__init__()\n            self.rgcn = StochasticTwoLayerRGCN(\n                in_features, hidden_features, out_features, etypes)\n            self.pred = ScorePredictor(num_classes, out_features)\n\n        def forward(self, seeds, blocks, x):\n            x = self.rgcn(blocks, x)\n            return self.pred(seeds, x)\n\nData loader definition is almost identical to that of homogeneous graph. 
The\nonly difference is that the train_set is now an instance of\n:class:`~dgl.graphbolt.HeteroItemSet` instead of :class:`~dgl.graphbolt.ItemSet`.\n\n.. code:: python\n\n    import dgl.graphbolt as gb\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    g = gb.SamplingGraph()\n    seeds = torch.arange(0, 1000).reshape(-1, 2)\n    labels = torch.randint(0, 3, (1000,))\n    seeds_labels = {\n        \"user:like:item\": gb.ItemSet(\n            (seeds, labels), names=(\"seeds\", \"labels\")\n        ),\n        \"user:follow:user\": gb.ItemSet(\n            (seeds, labels), names=(\"seeds\", \"labels\")\n        ),\n    }\n    train_set = gb.HeteroItemSet(seeds_labels)\n    datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)\n    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.\n    datapipe = datapipe.fetch_feature(\n        feature, node_feature_keys={\"item\": [\"feat\"], \"user\": [\"feat\"]}\n    )\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n\nThings become a little different if you wish to exclude the reverse\nedges on heterogeneous graphs. On heterogeneous graphs, reverse edges\nusually have a different edge type from the edges themselves, in order\nto differentiate the “forward” and “backward” relationships (e.g.\n``follow`` and ``followed_by`` are reverse relations of each other,\n``like`` and ``liked_by`` are reverse relations of each other,\netc.).\n\nIf each edge in a type has a reverse edge with the same ID in another\ntype, you can specify the mapping between edge types and their reverse\ntypes. The way to exclude the edges in the minibatch as well as their\nreverse edges then goes as follows.\n\n.. 
code:: python\n\n\n    exclude_seed_edges = partial(\n        gb.exclude_seed_edges,\n        include_reverse_edges=True,\n        reverse_etypes_mapping={\n            \"user:like:item\": \"item:liked_by:user\",\n            \"user:follow:user\": \"user:followed_by:user\",\n        },\n    )\n    datapipe = datapipe.transform(exclude_seed_edges)\n\n\nThe training loop is again almost the same as that on homogeneous graph,\nexcept for the implementation of ``compute_loss`` that will take in two\ndictionaries of node types and predictions here.\n\n.. code:: python\n\n    import torch.nn.functional as F\n    model = Model(in_features, hidden_features, out_features, num_classes, etypes)\n    model = model.to(device)\n    opt = torch.optim.Adam(model.parameters())\n\n    for data in dataloader:\n        blocks = data.blocks\n        x = data.edge_features((\"user:like:item\", \"feat\"))\n        y_hat = model(data.blocks, x, data.compacted_seeds)\n        loss = F.cross_entropy(data.labels, y_hat)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n"
  },
  {
    "path": "docs/source/guide/minibatch-gpu-sampling.rst",
    "content": ".. _guide-minibatch-gpu-sampling:\n\n6.8 Using GPU for Neighborhood Sampling\n---------------------------------------\n\n.. note::\n  GraphBolt does not support GPU-based neighborhood sampling yet. So this guide is\n  utilizing :class:`~dgl.dataloading.DataLoader` for illustration.\n\nDGL since 0.7 has been supporting GPU-based neighborhood sampling, which has a significant\nspeed advantage over CPU-based neighborhood sampling.  If you estimate that your graph \ncan fit onto GPU and your model does not take a lot of GPU memory, then it is best to\nput the graph onto GPU memory and use GPU-based neighbor sampling.\n\nFor example, `OGB Products <https://ogb.stanford.edu/docs/nodeprop/#ogbn-products>`_ has\n2.4M nodes and 61M edges.  The graph takes less than 1GB since the memory consumption of\na graph depends on the number of edges.  Therefore it is entirely possible to fit the\nwhole graph onto GPU.\n\n\nUsing GPU-based neighborhood sampling in DGL data loaders\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nOne can use GPU-based neighborhood sampling with DGL data loaders via:\n\n* Put the graph onto GPU.\n\n* Put the ``train_nid`` onto GPU.\n\n* Set ``device`` argument to a GPU device.\n\n* Set ``num_workers`` argument to 0, because CUDA does not allow multiple processes\n  accessing the same context.\n\nAll the other arguments for the :class:`~dgl.dataloading.DataLoader` can be\nthe same as the other user guides and tutorials.\n\n.. code:: python\n\n   g = g.to('cuda:0')\n   train_nid = train_nid.to('cuda:0')\n   dataloader = dgl.dataloading.DataLoader(\n       g,                                # The graph must be on GPU.\n       train_nid,                        # train_nid must be on GPU.\n       sampler,\n       device=torch.device('cuda:0'),    # The device argument must be GPU.\n       num_workers=0,                    # Number of workers must be 0.\n       batch_size=1000,\n       drop_last=False,\n       shuffle=True)\n\n.. 
note::\n\n  GPU-based neighbor sampling also works for custom neighborhood samplers as long as\n  (1) your sampler is subclassed from :class:`~dgl.dataloading.BlockSampler`, and (2)\n  your sampler entirely works on GPU.\n\n\nUsing CUDA UVA-based neighborhood sampling in DGL data loaders\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. note::\n   New feature introduced in DGL 0.8.\n\nFor the case where the graph is too large to fit onto the GPU memory, we introduce the\nCUDA UVA (Unified Virtual Addressing)-based sampling, in which GPUs perform the sampling\non the graph pinned in CPU memory via zero-copy access.\nYou can enable UVA-based neighborhood sampling in DGL data loaders via:\n\n* Put the ``train_nid`` onto GPU.\n\n* Set ``device`` argument to a GPU device.\n\n* Set ``num_workers`` argument to 0, because CUDA does not allow multiple processes\n  accessing the same context.\n\n* Set ``use_uva=True``.\n\nAll the other arguments for the :class:`~dgl.dataloading.DataLoader` can be\nthe same as the other user guides and tutorials.\n\n.. code:: python\n\n   train_nid = train_nid.to('cuda:0')\n   dataloader = dgl.dataloading.DataLoader(\n       g,\n       train_nid,                        # train_nid must be on GPU.\n       sampler,\n       device=torch.device('cuda:0'),    # The device argument must be GPU.\n       num_workers=0,                    # Number of workers must be 0.\n       batch_size=1000,\n       drop_last=False,\n       shuffle=True,\n       use_uva=True)                     # Set use_uva=True\n\nUVA-based sampling is the recommended solution for mini-batch training on large graphs,\nespecially for multi-GPU training.\n\n.. 
note::\n\n  To use UVA-based sampling in multi-GPU training, you should first materialize all the\n  necessary sparse formats of the graph before spawning training processes.\n  Refer to our `GraphSAGE example <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/multi_gpu_node_classification.py>`_ for more details.\n\n\nUVA and GPU support for PinSAGESampler/RandomWalkNeighborSampler\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nPinSAGESampler and RandomWalkNeighborSampler support UVA and GPU sampling.\nYou can enable them via:\n\n* Pin the graph (for UVA sampling) or put the graph onto GPU (for GPU sampling).\n\n* Put the ``train_nid`` onto GPU.\n\n.. code:: python\n\n  g = dgl.heterograph({\n      ('item', 'bought-by', 'user'): ([0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 0, 1, 2, 3, 2, 3]),\n      ('user', 'bought', 'item'): ([0, 1, 0, 1, 2, 3, 2, 3], [0, 0, 1, 1, 2, 2, 3, 3])})\n\n  # UVA setup\n  # g.create_formats_()\n  # g.pin_memory_()\n\n  # GPU setup\n  device = torch.device('cuda:0')\n  g = g.to(device)\n\n  sampler1 = dgl.sampling.PinSAGESampler(g, 'item', 'user', 4, 0.5, 3, 2)\n  sampler2 = dgl.sampling.RandomWalkNeighborSampler(g, 4, 0.5, 3, 2, ['bought-by', 'bought'])\n\n  train_nid = torch.tensor([0, 2], dtype=g.idtype, device=device)\n  sampler1(train_nid)\n  sampler2(train_nid)\n\n\nUsing GPU-based neighbor sampling with DGL functions\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nYou can build your own GPU sampling pipelines with the following functions that support\noperating on GPU:\n\n* :func:`dgl.sampling.sample_neighbors`\n* :func:`dgl.sampling.random_walk`\n\nSubgraph extraction ops:\n\n* :func:`dgl.node_subgraph`\n* :func:`dgl.edge_subgraph`\n* :func:`dgl.in_subgraph`\n* :func:`dgl.out_subgraph`\n\nGraph transform ops for subgraph construction:\n\n* :func:`dgl.to_block`\n* :func:`dgl.compact_graph`\n"
  },
  {
    "path": "docs/source/guide/minibatch-inference.rst",
    "content": ".. _guide-minibatch-inference:\n\n6.7 Exact Offline Inference on Large Graphs\n------------------------------------------------------\n\n:ref:`(中文版) <guide_cn-minibatch-inference>`\n\nBoth subgraph sampling and neighborhood sampling are to reduce the\nmemory and time consumption for training GNNs with GPUs. When performing\ninference it is usually better to truly aggregate over all neighbors\ninstead to get rid of the randomness introduced by sampling. However,\nfull-graph forward propagation is usually infeasible on GPU due to\nlimited memory, and slow on CPU due to slow computation. This section\nintroduces the methodology of full-graph forward propagation with\nlimited GPU memory via minibatch and neighborhood sampling.\n\nThe inference algorithm is different from the training algorithm, as the\nrepresentations of all nodes should be computed layer by layer, starting\nfrom the first layer. Specifically, for a particular layer, we need to\ncompute the output representations of all nodes from this GNN layer in\nminibatches. The consequence is that the inference algorithm will have\nan outer loop iterating over the layers, and an inner loop iterating\nover the minibatches of nodes. In contrast, the training algorithm has\nan outer loop iterating over the minibatches of nodes, and an inner loop\niterating over the layers for both neighborhood sampling and message\npassing.\n\nThe following animation shows how the computation would look like (note\nthat for every layer only the first three minibatches are drawn).\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_6_0.gif\n   :alt: Imgur\n\n\n\nImplementing Offline Inference\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nConsider the two-layer GCN we have mentioned in Section 6.1\n:ref:`guide-minibatch-node-classification-model`. The way\nto implement offline inference still involves using\n:class:`~dgl.graphbolt.NeighborSampler`, but sampling for\nonly one layer at a time.\n\n.. 
code:: python\n\n    datapipe = gb.ItemSampler(all_nodes_set, batch_size=1024, shuffle=True)\n    datapipe = datapipe.sample_neighbor(g, [-1]) # 1 layers.\n    datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n\n\nNote that offline inference is implemented as a method of the GNN module\nbecause the computation on one layer depends on how messages are aggregated\nand combined as well.\n\n.. code:: python\n\n    class SAGE(nn.Module):\n        def __init__(self, in_size, hidden_size, out_size):\n            super().__init__()\n            self.layers = nn.ModuleList()\n            # Three-layer GraphSAGE-mean.\n            self.layers.append(dglnn.SAGEConv(in_size, hidden_size, \"mean\"))\n            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \"mean\"))\n            self.layers.append(dglnn.SAGEConv(hidden_size, out_size, \"mean\"))\n            self.dropout = nn.Dropout(0.5)\n            self.hidden_size = hidden_size\n            self.out_size = out_size\n\n        def forward(self, blocks, x):\n            hidden_x = x\n            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n                hidden_x = layer(block, hidden_x)\n                is_last_layer = layer_idx == len(self.layers) - 1\n                if not is_last_layer:\n                    hidden_x = F.relu(hidden_x)\n                    hidden_x = self.dropout(hidden_x)\n            return hidden_x\n    \n        def inference(self, graph, features, dataloader, device):\n            \"\"\"\n            Offline inference with this module\n            \"\"\"\n            feature = features.read(\"node\", None, \"feat\")\n\n            # Compute representations layer by layer\n            for layer_idx, layer in enumerate(self.layers):\n                is_last_layer = layer_idx == len(self.layers) - 1\n\n                y = torch.empty(\n                    
graph.total_num_nodes,\n                    self.out_size if is_last_layer else self.hidden_size,\n                    dtype=torch.float32,\n                    device=buffer_device,\n                    pin_memory=pin_memory,\n                )\n                feature = feature.to(device)\n\n                for step, data in tqdm(enumerate(dataloader)):\n                    x = feature[data.input_nodes]\n                    hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1\n                    if not is_last_layer:\n                        hidden_x = F.relu(hidden_x)\n                        hidden_x = self.dropout(hidden_x)\n                    # By design, our output nodes are contiguous.\n                    y[\n                        data.seeds[0] : data.seeds[-1] + 1\n                    ] = hidden_x.to(device)\n                feature = y\n\n            return y\n\n\nNote that for the purpose of computing evaluation metric on the\nvalidation set for model selection we usually don’t have to compute\nexact offline inference. The reason is that we need to compute the\nrepresentation for every single node on every single layer, which is\nusually very costly especially in the semi-supervised regime with a lot\nof unlabeled data. Neighborhood sampling will work fine for model\nselection and validation.\n\nOne can see\n`GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/graphbolt/node_classification.py>`__\nand\n`RGCN <https://github.com/dmlc/dgl/blob/master/examples/graphbolt/rgcn/hetero_rgcn.py>`__\nfor examples of offline inference.\n"
  },
  {
    "path": "docs/source/guide/minibatch-link.rst",
    "content": ".. _guide-minibatch-link-classification-sampler:\n\n6.3 Training GNN for Link Prediction with Neighborhood Sampling\n--------------------------------------------------------------------\n\n:ref:`(中文版) <guide_cn-minibatch-link-classification-sampler>`\n\nDefine a data loader with neighbor and negative sampling\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nYou can still use the same data loader as the one in node/edge classification.\nThe only difference is that you need to add an additional stage\n`negative sampling` before neighbor sampling stage. The following data loader\nwill pick 5 negative destination nodes uniformly for each source node of an\nedge.\n\n.. code:: python\n\n    datapipe = datapipe.sample_uniform_negative(graph, 5)\n\nThe whole data loader pipeline is as follows:\n\n.. code:: python\n\n    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)\n    datapipe = datapipe.sample_uniform_negative(graph, 5)\n    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.\n    datapipe = datapipe.transform(gb.exclude_seed_edges)\n    datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n\n\nFor the details about the builtin uniform negative sampler please see\n:class:`~dgl.graphbolt.UniformNegativeSampler`.\n\nYou can also give your own negative sampler function, as long as it inherits\nfrom :class:`~dgl.graphbolt.NegativeSampler` and overrides the\n:meth:`~dgl.graphbolt.NegativeSampler._sample_with_etype` method which takes in\nthe node pairs in minibatch, and returns the negative node pairs back.\n\nThe following gives an example of custom negative sampler that samples\nnegative destination nodes according to a probability distribution\nproportional to a power of degrees.\n\n.. 
code:: python\n\n    @functional_datapipe(\"customized_sample_negative\")\n    class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):\n        def __init__(self, datapipe, k, node_degrees):\n            super().__init__(datapipe, k)\n            # caches the probability distribution\n            self.weights = node_degrees ** 0.75\n            self.k = k\n    \n        def _sample_with_etype(self, seeds, etype=None):\n            src, _ = seeds.T\n            src = src.repeat_interleave(self.k)\n            dst = self.weights.multinomial(len(src), replacement=True)\n            return src, dst\n\n    datapipe = datapipe.customized_sample_negative(5, node_degrees)\n\n\nDefine a GraphSAGE model for minibatch training\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: python\n\n    class SAGE(nn.Module):\n        def __init__(self, in_size, hidden_size):\n            super().__init__()\n            self.layers = nn.ModuleList()\n            self.layers.append(dglnn.SAGEConv(in_size, hidden_size, \"mean\"))\n            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \"mean\"))\n            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \"mean\"))\n            self.hidden_size = hidden_size\n            self.predictor = nn.Sequential(\n                nn.Linear(hidden_size, hidden_size),\n                nn.ReLU(),\n                nn.Linear(hidden_size, hidden_size),\n                nn.ReLU(),\n                nn.Linear(hidden_size, 1),\n            )\n\n        def forward(self, blocks, x):\n            hidden_x = x\n            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n                hidden_x = layer(block, hidden_x)\n                is_last_layer = layer_idx == len(self.layers) - 1\n                if not is_last_layer:\n                    hidden_x = F.relu(hidden_x)\n            return hidden_x\n\n\nWhen a negative sampler is provided, the data loader will generate positive and\nnegative node 
pairs for each minibatch besides the *Message Flow Graphs* (MFGs).\nUse `compacted_seeds` and `labels` to get compact node pairs and corresponding\nlabels.\n\n\nTraining loop\n~~~~~~~~~~~~~\n\nThe training loop simply involves iterating over the data loader and\nfeeding in the graphs as well as the input features to the model defined\nabove.\n\n.. code:: python\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n    for epoch in tqdm.trange(args.epochs):\n        model.train()\n        total_loss = 0\n        start_epoch_time = time.time()\n        for step, data in enumerate(dataloader):\n            # Unpack MiniBatch.\n            compacted_seeds = data.compacted_seeds.T\n            labels = data.labels\n            node_feature = data.node_features[\"feat\"]\n            # Convert sampled subgraphs to DGL blocks.\n            blocks = data.blocks\n\n            # Get the embeddings of the input nodes.\n            y = model(blocks, node_feature)\n            logits = model.predictor(\n                y[compacted_seeds[0]] * y[compacted_seeds[1]]\n            ).squeeze()\n\n            # Compute loss.\n            loss = F.binary_cross_entropy_with_logits(logits, labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            total_loss += loss.item()\n        end_epoch_time = time.time()\n\n\nDGL provides the\n`unsupervised learning GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/graphbolt/link_prediction.py>`__\nthat shows an example of link prediction on homogeneous graphs.\n\nFor heterogeneous graphs\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe previous model could be easily extended to heterogeneous graphs. The only\ndifference is that you need to use :class:`~dgl.nn.HeteroGraphConv` to wrap\n:class:`~dgl.nn.SAGEConv` according to edge types.\n\n.. 
code:: python\n\n    class SAGE(nn.Module):\n        def __init__(self, in_size, hidden_size):\n            super().__init__()\n            self.layers = nn.ModuleList()\n            self.layers.append(dglnn.HeteroGraphConv({\n                    rel : dglnn.SAGEConv(in_size, hidden_size, \"mean\")\n                    for rel in rel_names\n                }))\n            self.layers.append(dglnn.HeteroGraphConv({\n                    rel : dglnn.SAGEConv(hidden_size, hidden_size, \"mean\")\n                    for rel in rel_names\n                }))\n            self.layers.append(dglnn.HeteroGraphConv({\n                    rel : dglnn.SAGEConv(hidden_size, hidden_size, \"mean\")\n                    for rel in rel_names\n                }))\n            self.hidden_size = hidden_size\n            self.predictor = nn.Sequential(\n                nn.Linear(hidden_size, hidden_size),\n                nn.ReLU(),\n                nn.Linear(hidden_size, hidden_size),\n                nn.ReLU(),\n                nn.Linear(hidden_size, 1),\n            )\n\n        def forward(self, blocks, x):\n            hidden_x = x\n            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n                hidden_x = layer(block, hidden_x)\n                is_last_layer = layer_idx == len(self.layers) - 1\n                if not is_last_layer:\n                    hidden_x = F.relu(hidden_x)\n            return hidden_x\n\n\nData loader definition is also very similar to that for homogeneous graph. The\nonly difference is that you need to give edge types for feature fetching.\n\n.. 
code:: python\n\n    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)\n    datapipe = datapipe.sample_uniform_negative(graph, 5)\n    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.\n    datapipe = datapipe.transform(gb.exclude_seed_edges)\n    datapipe = datapipe.fetch_feature(\n        feature,\n        node_feature_keys={\"user\": [\"feat\"], \"item\": [\"feat\"]}\n    )\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n\nIf you want to give your own negative sampling function, just inherit from the\n:class:`~dgl.graphbolt.NegativeSampler` class and override the\n:meth:`~dgl.graphbolt.NegativeSampler._sample_with_etype` method.\n\n.. code:: python\n\n    @functional_datapipe(\"customized_sample_negative\")\n    class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):\n        def __init__(self, datapipe, k, node_degrees):\n            super().__init__(datapipe, k)\n            # caches the probability distribution\n            self.weights = {\n                etype: node_degrees[etype] ** 0.75 for etype in node_degrees\n            }\n            self.k = k\n    \n        def _sample_with_etype(self, seeds, etype):\n            src, _ = seeds.T\n            src = src.repeat_interleave(self.k)\n            dst = self.weights[etype].multinomial(len(src), replacement=True)\n            return src, dst\n\n    datapipe = datapipe.customized_sample_negative(5, node_degrees)\n\n\nFor heterogeneous graphs, node pairs are grouped by edge types. The training\nloop is again almost the same as that on homogeneous graph, except for computing\nloss on specific edge type.\n\n.. 
code:: python\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n    category = \"user\"\n    for epoch in tqdm.trange(args.epochs):\n        model.train()\n        total_loss = 0\n        start_epoch_time = time.time()\n        for step, data in enumerate(dataloader):\n            # Unpack MiniBatch.\n            compacted_seeds = data.compacted_seeds\n            labels = data.labels\n            node_features = {\n                ntype: data.node_features[(ntype, \"feat\")]\n                for ntype in data.blocks[0].srctypes\n            }\n            # Convert sampled subgraphs to DGL blocks.\n            blocks = data.blocks\n            # Get the embeddings of the input nodes.\n            y = model(blocks, node_features)\n            logits = model.predictor(\n                y[category][compacted_seeds[category][:, 0]]\n                * y[category][compacted_seeds[category][:, 1]]\n            ).squeeze()\n\n            # Compute loss.\n            loss = F.binary_cross_entropy_with_logits(logits, labels[category])\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            total_loss += loss.item()\n        end_epoch_time = time.time()\n\n"
  },
  {
    "path": "docs/source/guide/minibatch-nn.rst",
    "content": ".. _guide-minibatch-custom-gnn-module:\n\n6.6 Implementing Custom GNN Module for Mini-batch Training\n-------------------------------------------------------------\n\n:ref:`(中文版) <guide_cn-minibatch-custom-gnn-module>`\n\n.. note::\n\n   :doc:`This tutorial <tutorials/large/L4_message_passing>` has similar\n   content to this section for the homogeneous graph case.\n\n\nIf you were familiar with how to write a custom GNN module for updating\nthe entire graph for homogeneous or heterogeneous graphs (see\n:ref:`guide-nn`), the code for computing on\nMFGs is similar, with the exception that the nodes are divided into\ninput nodes and output nodes.\n\nFor example, consider the following custom graph convolution module\ncode. Note that it is not necessarily among the most efficient implementations\n- they only serve for an example of how a custom GNN module could look\nlike.\n\n.. code:: python\n\n    class CustomGraphConv(nn.Module):\n        def __init__(self, in_feats, out_feats):\n            super().__init__()\n            self.W = nn.Linear(in_feats * 2, out_feats)\n    \n        def forward(self, g, h):\n            with g.local_scope():\n                g.ndata['h'] = h\n                g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))\n                return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))\n\nIf you have a custom message passing NN module for the full graph, and\nyou would like to make it work for MFGs, you only need to rewrite the\nforward function as follows. Note that the corresponding statements from\nthe full-graph implementation are commented; you can compare the\noriginal statements with the new statements.\n\n.. 
code:: python\n\n    class CustomGraphConv(nn.Module):\n        def __init__(self, in_feats, out_feats):\n            super().__init__()\n            self.W = nn.Linear(in_feats * 2, out_feats)\n    \n        # h is now a pair of feature tensors for input and output nodes, instead of\n        # a single feature tensor.\n        # def forward(self, g, h):\n        def forward(self, block, h):\n            # with g.local_scope():\n            with block.local_scope():\n                # g.ndata['h'] = h\n                h_src = h\n                h_dst = h[:block.number_of_dst_nodes()]\n                block.srcdata['h'] = h_src\n                block.dstdata['h'] = h_dst\n    \n                # g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))\n                block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))\n    \n                # return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))\n                return self.W(torch.cat(\n                    [block.dstdata['h'], block.dstdata['h_neigh']], 1))\n\nIn general, you need to do the following to make your NN module work for\nMFGs.\n\n-  Obtain the features for output nodes from the input features by\n   slicing the first few rows. 
The number of rows can be obtained by\n   :meth:`block.number_of_dst_nodes <dgl.DGLGraph.number_of_dst_nodes>`.\n-  Replace\n   :attr:`g.ndata <dgl.DGLGraph.ndata>` with either\n   :attr:`block.srcdata <dgl.DGLGraph.srcdata>` for features on input nodes or\n   :attr:`block.dstdata <dgl.DGLGraph.dstdata>` for features on output nodes, if\n   the original graph has only one node type.\n-  Replace\n   :attr:`g.nodes <dgl.DGLGraph.nodes>` with either\n   :attr:`block.srcnodes <dgl.DGLGraph.srcnodes>` for features on input nodes or\n   :attr:`block.dstnodes <dgl.DGLGraph.dstnodes>` for features on output nodes,\n   if the original graph has multiple node types.\n-  Replace\n   :meth:`g.num_nodes <dgl.DGLGraph.num_nodes>` with either\n   :meth:`block.number_of_src_nodes <dgl.DGLGraph.number_of_src_nodes>` or\n   :meth:`block.number_of_dst_nodes <dgl.DGLGraph.number_of_dst_nodes>` for the number of\n   input nodes or output nodes respectively.\n\nHeterogeneous graphs\n~~~~~~~~~~~~~~~~~~~~\n\nFor heterogeneous graphs, the way of writing custom GNN modules is\nsimilar. For instance, consider the following module that works on the full\ngraph.\n\n.. 
code:: python\n\n    class CustomHeteroGraphConv(nn.Module):\n        def __init__(self, g, in_feats, out_feats):\n            super().__init__()\n            self.Ws = nn.ModuleDict()\n            for etype in g.canonical_etypes:\n                utype, _, vtype = etype\n                self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])\n            for ntype in g.ntypes:\n                self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])\n    \n        def forward(self, g, h):\n            with g.local_scope():\n                for ntype in g.ntypes:\n                    g.nodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])\n                    g.nodes[ntype].data['h_src'] = h[ntype]\n                for etype in g.canonical_etypes:\n                    utype, _, vtype = etype\n                    g.update_all(\n                        fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),\n                        etype=etype)\n                    g.nodes[vtype].data['h_dst'] = g.nodes[vtype].data['h_dst'] + \\\n                        self.Ws[etype](g.nodes[vtype].data['h_neigh'])\n                return {ntype: g.nodes[ntype].data['h_dst'] for ntype in g.ntypes}\n\nFor ``CustomHeteroGraphConv``, the principle is to replace ``g.nodes``\nwith ``g.srcnodes`` or ``g.dstnodes`` depending on whether the features\nserve as input or output.\n\n.. 
code:: python\n\n    class CustomHeteroGraphConv(nn.Module):\n        def __init__(self, g, in_feats, out_feats):\n            super().__init__()\n            self.Ws = nn.ModuleDict()\n            for etype in g.canonical_etypes:\n                utype, _, vtype = etype\n                self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])\n            for ntype in g.ntypes:\n                self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])\n    \n        def forward(self, g, h):\n            with g.local_scope():\n                for ntype in g.ntypes:\n                    h_src, h_dst = h[ntype]\n                    g.dstnodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])\n                    g.srcnodes[ntype].data['h_src'] = h[ntype]\n                for etype in g.canonical_etypes:\n                    utype, _, vtype = etype\n                    g.update_all(\n                        fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),\n                        etype=etype)\n                    g.dstnodes[vtype].data['h_dst'] = \\\n                        g.dstnodes[vtype].data['h_dst'] + \\\n                        self.Ws[etype](g.dstnodes[vtype].data['h_neigh'])\n                return {ntype: g.dstnodes[ntype].data['h_dst']\n                        for ntype in g.ntypes}\n\nWriting modules that work on homogeneous graphs, bipartite graphs, and MFGs\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nAll message passing modules in DGL work on homogeneous graphs,\nunidirectional bipartite graphs (that have two node types and one edge\ntype), and a MFG with one edge type. 
Essentially, the input graph and\nfeature of a builtin DGL neural network module must satisfy one of\nthe following cases.\n\n-  If the input feature is a pair of tensors, then the input graph must\n   be unidirectional bipartite.\n-  If the input feature is a single tensor and the input graph is a\n   MFG, DGL will automatically set the feature on the output nodes as\n   the first few rows of the input node features.\n-  If the input feature is a single tensor and the input graph is\n   not a MFG, then the input graph must be homogeneous.\n\nFor example, the following is simplified from the PyTorch implementation\nof :class:`dgl.nn.pytorch.SAGEConv` (also available in MXNet and Tensorflow)\n(removing normalization and dealing with only mean aggregation etc.).\n\n.. code:: python\n\n    import dgl.function as fn\n    class SAGEConv(nn.Module):\n        def __init__(self, in_feats, out_feats):\n            super().__init__()\n            self.W = nn.Linear(in_feats * 2, out_feats)\n    \n        def forward(self, g, h):\n            if isinstance(h, tuple):\n                h_src, h_dst = h\n            elif g.is_block:\n                h_src = h\n                h_dst = h[:g.number_of_dst_nodes()]\n            else:\n                h_src = h_dst = h\n                 \n            g.srcdata['h'] = h_src\n            g.dstdata['h'] = h_dst\n            g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_neigh'))\n            return F.relu(\n                self.W(torch.cat([g.dstdata['h'], g.dstdata['h_neigh']], 1)))\n\n:ref:`guide-nn` also provides a walkthrough on :class:`dgl.nn.pytorch.SAGEConv`,\nwhich works on unidirectional bipartite graphs, homogeneous graphs, and MFGs.\n\n\n"
  },
  {
    "path": "docs/source/guide/minibatch-node.rst",
    "content": ".. _guide-minibatch-node-classification-sampler:\n\n6.1 Training GNN for Node Classification with Neighborhood Sampling\n-----------------------------------------------------------------------\n\n:ref:`(中文版) <guide_cn-minibatch-node-classification-sampler>`\n\nTo make your model been trained stochastically, you need to do the\nfollowings:\n\n-  Define a neighborhood sampler.\n-  Adapt your model for minibatch training.\n-  Modify your training loop.\n\nThe following sub-subsections address these steps one by one.\n\nDefine a neighborhood sampler and data loader\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nDGL provides several neighborhood sampler classes that generates the\ncomputation dependencies needed for each layer given the nodes we wish\nto compute on.\n\nThe simplest neighborhood sampler is :class:`~dgl.graphbolt.NeighborSampler`\nor the equivalent function-like interface :func:`~dgl.graphbolt.sample_neighbor`\nwhich makes the node gather messages from its neighbors.\n\nTo use a sampler provided by DGL, one also need to combine it with\n:class:`~dgl.graphbolt.DataLoader`, which iterates\nover a set of indices (nodes in this case) in minibatches.\n\nFor example, the following code creates a DataLoader that\niterates over the training node ID set of ``ogbn-arxiv`` in batches,\nputting the list of generated MFGs onto GPU.\n\n.. 
code:: python\n\n    import dgl\n    import dgl.graphbolt as gb\n    import dgl.nn as dglnn\n    import torch\n    import torch.nn as nn\n    import torch.nn.functional as F\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    dataset = gb.BuiltinDataset(\"ogbn-arxiv\").load()\n    g = dataset.graph\n    feature = dataset.feature\n    train_set = dataset.tasks[0].train_set\n    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)\n    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.\n    # Or equivalently:\n    # datapipe = gb.NeighborSampler(datapipe, g, [10, 10])\n    datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n\n\nIterating over the DataLoader will yield :class:`~dgl.graphbolt.MiniBatch`\nwhich contains a list of specially created graphs representing the computation\ndependencies on each layer. In order to train with DGL, you can access the\n*message flow graphs* (MFGs) by calling `mini_batch.blocks`.\n\n.. code:: python\n\n    mini_batch = next(iter(dataloader))\n    print(mini_batch.blocks)\n\n\n.. note::\n\n   See the `Stochastic Training Tutorial\n   <../notebooks/stochastic_training/neighbor_sampling_overview.nblink>`__\n   for the concept of message flow graph.\n\n   If you wish to develop your own neighborhood sampler or you want a more\n   detailed explanation of the concept of MFGs, please refer to\n   :ref:`guide-minibatch-customizing-neighborhood-sampler`.\n\n\n.. _guide-minibatch-node-classification-model:\n\nAdapt your model for minibatch training\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nIf your message passing modules are all provided by DGL, the changes\nrequired to adapt your model to minibatch training is minimal. Take a\nmulti-layer GCN as an example. If your model on full graph is\nimplemented as follows:\n\n.. 
code:: python\n\n    class TwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.conv1 = dglnn.GraphConv(in_features, hidden_features)\n            self.conv2 = dglnn.GraphConv(hidden_features, out_features)\n    \n        def forward(self, g, x):\n            x = F.relu(self.conv1(g, x))\n            x = F.relu(self.conv2(g, x))\n            return x\n\nThen all you need is to replace ``g`` with ``blocks`` generated above.\n\n.. code:: python\n\n    class StochasticTwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)\n            self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)\n    \n        def forward(self, blocks, x):\n            x = F.relu(self.conv1(blocks[0], x))\n            x = F.relu(self.conv2(blocks[1], x))\n            return x\n\nThe DGL ``GraphConv`` modules above accepts an element in ``blocks``\ngenerated by the data loader as an argument.\n\n:ref:`The API reference of each NN module <apinn>` will tell you\nwhether it supports accepting a MFG as an argument.\n\nIf you wish to use your own message passing module, please refer to\n:ref:`guide-minibatch-custom-gnn-module`.\n\nTraining Loop\n~~~~~~~~~~~~~\n\nThe training loop simply consists of iterating over the dataset with the\ncustomized batching iterator. During each iteration that yields\n:class:`~dgl.graphbolt.MiniBatch`, we:\n\n1. Access the node features corresponding to the input nodes via\n   ``data.node_features[\"feat\"]``. These features are already moved to the\n   target device (CPU or GPU) by the data loader.\n\n2. Access the node labels corresponding to the output nodes via\n   ``data.labels``. These labels are already moved to the target device\n   (CPU or GPU) by the data loader.\n\n3. 
Feed the list of MFGs and the input node features to the multilayer\n   GNN and get the outputs.\n\n4. Compute the loss and backpropagate.\n\n.. code:: python\n\n    model = StochasticTwoLayerGCN(in_features, hidden_features, out_features)\n    model = model.to(device)\n    opt = torch.optim.Adam(model.parameters())\n\n    for data in dataloader:\n        input_features = data.node_features[\"feat\"]\n        output_labels = data.labels\n        output_predictions = model(data.blocks, input_features)\n        loss = compute_loss(output_labels, output_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n\nDGL provides an end-to-end stochastic training example `GraphSAGE\nimplementation <https://github.com/dmlc/dgl/blob/master/examples/graphbolt/node_classification.py>`__.\n\nFor heterogeneous graphs\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nTraining a graph neural network for node classification on heterogeneous\ngraph is similar.\n\nFor instance, we have previously seen\n:ref:`how to train a 2-layer RGCN on full graph <guide-training-rgcn-node-classification>`.\nThe code for RGCN implementation on minibatch training looks very\nsimilar to that (with self-loops, non-linearity and basis decomposition\nremoved for simplicity):\n\n.. 
code:: python\n\n    class StochasticTwoLayerRGCN(nn.Module):\n        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):\n            super().__init__()\n            self.conv1 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')\n                    for rel in rel_names\n                })\n            self.conv2 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')\n                    for rel in rel_names\n                })\n    \n        def forward(self, blocks, x):\n            x = self.conv1(blocks[0], x)\n            x = self.conv2(blocks[1], x)\n            return x\n\nThe samplers provided by DGL also support heterogeneous graphs.\nFor example, one can still use the provided\n:class:`~dgl.graphbolt.NeighborSampler` class and\n:class:`~dgl.graphbolt.DataLoader` class for\nstochastic training. The only difference is that the itemset is now an\ninstance of :class:`~dgl.graphbolt.HeteroItemSet` which is a dictionary\nof node types to node IDs.\n\n.. 
code:: python\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    dataset = gb.BuiltinDataset(\"ogbn-mag\").load()\n    g = dataset.graph\n    feature = dataset.feature\n    train_set = dataset.tasks[0].train_set\n    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)\n    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.\n    # Or equivalently:\n    # datapipe = gb.NeighborSampler(datapipe, g, [10, 10])\n    # For heterogeneous graphs, we need to specify the node feature keys\n    # for each node type.\n    datapipe = datapipe.fetch_feature(\n        feature, node_feature_keys={\"author\": [\"feat\"], \"paper\": [\"feat\"]}\n    )\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n\nThe training loop is almost the same as that of homogeneous graphs,\nexcept for the implementation of ``compute_loss`` that will take in two\ndictionaries of node types and predictions here.\n\n.. code:: python\n\n    model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features, etypes)\n    model = model.to(device)\n    opt = torch.optim.Adam(model.parameters())\n    \n    for data in dataloader:\n        # For heterogeneous graphs, we need to specify the node types and\n        # feature name when accessing the node features. So does the labels.\n        input_features = {\n            \"author\": data.node_features[(\"author\", \"feat\")],\n            \"paper\": data.node_features[(\"paper\", \"feat\")]\n        }\n        output_labels = data.labels[\"paper\"]\n        output_predictions = model(data.blocks, input_features)\n        loss = compute_loss(output_labels, output_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\nDGL provides an end-to-end stochastic training example `RGCN\nimplementation <https://github.com/dmlc/dgl/blob/master/examples/graphbolt/rgcn/hetero_rgcn.py>`__.\n\n\n"
  },
  {
    "path": "docs/source/guide/minibatch-parallelism.rst",
    "content": ".. _guide-minibatch-parallelism:\n\n6.9 Data Loading Parallelism\n-----------------------\n\nIn minibatch training of GNNs, we usually need to cover several stages to\ngenerate a minibatch, including:\n\n* Iterate over item set and generate minibatch seeds in batch size.\n* Sample negative items for each seed from graph.\n* Sample neighbors for each seed from graph.\n* Exclude seed edges from the sampled subgraphs.\n* Fetch node and edge features for the sampled subgraphs.\n* Copy the MiniBatches to the target device.\n\n.. code:: python\n\n    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)\n    datapipe = datapipe.sample_uniform_negative(g, 5)\n    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.\n    datapipe = datapipe.transform(gb.exclude_seed_edges)\n    datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n\nAll these stages are implemented in separate\n`IterableDataPipe <https://pytorch.org/data/0.7/torchdata.datapipes.iter.html>`__\nand stacked together with `PyTorch DataLoader <https://pytorch.org/docs/stable/data\n.html#torch.utils.data.DataLoader>`__.\nThis design allows us to easily customize the data loading process by\nchaining different data pipes together. For example, if we want to sample\nnegative items for each seed from graph, we can simply chain the\n:class:`~dgl.graphbolt.NegativeSampler` after the :class:`~dgl.graphbolt.ItemSampler`.\n\nBut simply chaining data pipes together incurs performance overheads as various\nhardware resources such as CPU, GPU, PCIe, etc. are utilized by different stages.\nAs a result, the data loading mechanism is optimized to minimize the overheads\nand achieve the best performance.\n\nIn specific, GraphBolt wraps the data pipes before ``fetch_feature`` with\nmultiprocessing which enables multiple processes to run in parallel. 
As for\n``fetch_feature`` data pipe, we keep it running in the main process to avoid\ndata movement overheads between processes.\n\nWhat's more, in order to overlap the data movement and model computation, we\nwrap data pipes before ``copy_to`` with\n`torchdata.datapipes.iter.Prefetcher <https://pytorch.org/data/0.7/generated/\ntorchdata.datapipes.iter.Prefetcher.html>`__\nwhich prefetches elements from previous data pipes and puts them into a buffer.\nSuch prefetching is totally transparent to users and requires no extra code. It\nbrings a significant performance boost to minibatch training of GNNs.\n\nPlease refer to the source code of :class:`~dgl.graphbolt.DataLoader`\nfor more details.\n"
  },
  {
    "path": "docs/source/guide/minibatch-sparse.rst",
    "content": ".. _guide-minibatch-sparse:\n\n6.5 Training GNN with DGL sparse\n---------------------------------\n\nThis tutorial demonstrates how to use dgl sparse library to sample on graph and\ntrain model. It trains and tests a GraphSAGE model using the sparse sample and\ncompact operators to sample submatrix from the whole matrix.\n\nTraining GNN with DGL sparse is quite similar to\n:ref:`guide-minibatch-node-classification-sampler`. The major difference is\nthe customized sampler and matrix that represents graph.\n\nWe have cutomized one sampler in\n:ref:`guide-minibatch-customizing-neighborhood-sampler`. In this tutorial, we\nwill customize another sampler with DGL sparse library as shown below.\n\n.. code:: python\n\n    @functional_datapipe(\"sample_sparse_neighbor\")\n    class SparseNeighborSampler(SubgraphSampler):\n        def __init__(self, datapipe, matrix, fanouts):\n            super().__init__(datapipe)\n            self.matrix = matrix\n            # Convert fanouts to a list of tensors.\n            self.fanouts = []\n            for fanout in fanouts:\n                if not isinstance(fanout, torch.Tensor):\n                    fanout = torch.LongTensor([int(fanout)])\n                self.fanouts.insert(0, fanout)\n\n        def sample_subgraphs(self, seeds):\n            sampled_matrices = []\n            src = seeds\n\n            #####################################################################\n            # (HIGHLIGHT) Using the sparse sample operator to preform random\n            # sampling on the neighboring nodes of the seeds nodes. 
The sparse\n            # compact operator is then employed to compact and relabel the sampled\n            # matrix, resulting in the sampled matrix and the relabel index.\n            #####################################################################\n            for fanout in self.fanouts:\n                # Sample neighbors.\n                sampled_matrix = self.matrix.sample(1, fanout, ids=src).coalesce()\n                # Compact the sampled matrix.\n                compacted_mat, row_ids = sampled_matrix.compact(0)\n                sampled_matrices.insert(0, compacted_mat)\n                src = row_ids\n\n            return src, sampled_matrices\n\nAnother major difference is the matrix that represents the graph. Previously we used\n:class:`~dgl.graphbolt.FusedCSCSamplingGraph` for sampling. In this tutorial,\nwe use :class:`~dgl.sparse.SparseMatrix` to represent the graph.\n\n.. code:: python\n\n    dataset = gb.BuiltinDataset(\"ogbn-products\").load()\n    g = dataset.graph\n    # Create sparse.\n    N = g.num_nodes\n    A = dglsp.from_csc(g.csc_indptr, g.indices, shape=(N, N))\n\n\nThe remaining code is almost the same as in the node classification tutorial.\n\nTo use this sampler with :class:`~dgl.graphbolt.DataLoader`:\n\n.. code:: python\n\n    datapipe = gb.ItemSampler(ids, batch_size=1024)\n    # Customize graphbolt sampler by sparse.\n    datapipe = datapipe.sample_sparse_neighbor(A, fanouts)\n    # Use graphbolt to fetch features.\n    datapipe = datapipe.fetch_feature(features, node_feature_keys=[\"feat\"])\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n\nModel definition is shown below:\n\n.. 
code:: python\n\n    class SAGEConv(nn.Module):\n        r\"\"\"GraphSAGE layer from `Inductive Representation Learning on\n        Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__\n        \"\"\"\n\n        def __init__(\n            self,\n            in_feats,\n            out_feats,\n        ):\n            super(SAGEConv, self).__init__()\n            self._in_src_feats, self._in_dst_feats = in_feats, in_feats\n            self._out_feats = out_feats\n\n            self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)\n            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=True)\n            self.reset_parameters()\n\n        def reset_parameters(self):\n            gain = nn.init.calculate_gain(\"relu\")\n            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)\n            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)\n\n        def forward(self, A, feat):\n            feat_src = feat\n            feat_dst = feat[: A.shape[1]]\n\n            # Aggregator type: mean.\n            srcdata = self.fc_neigh(feat_src)\n            # Divided by degree.\n            D_hat = dglsp.diag(A.sum(0)) ** -1\n            A_div = A @ D_hat\n            # Conv neighbors.\n            dstdata = A_div.T @ srcdata\n\n            rst = self.fc_self(feat_dst) + dstdata\n            return rst\n\n\n    class SAGE(nn.Module):\n        def __init__(self, in_size, hid_size, out_size):\n            super().__init__()\n            self.layers = nn.ModuleList()\n            # Three-layer GraphSAGE-gcn.\n            self.layers.append(SAGEConv(in_size, hid_size))\n            self.layers.append(SAGEConv(hid_size, hid_size))\n            self.layers.append(SAGEConv(hid_size, out_size))\n            self.dropout = nn.Dropout(0.5)\n            self.hid_size = hid_size\n            self.out_size = out_size\n\n        def forward(self, sampled_matrices, x):\n            hidden_x = x\n            for layer_idx, (layer, sampled_matrix) 
in enumerate(\n                zip(self.layers, sampled_matrices)\n            ):\n                hidden_x = layer(sampled_matrix, hidden_x)\n                if layer_idx != len(self.layers) - 1:\n                    hidden_x = F.relu(hidden_x)\n                    hidden_x = self.dropout(hidden_x)\n            return hidden_x\n\n\nLaunch training:\n\n.. code:: python\n\n    features = dataset.feature\n    # Create GraphSAGE model.\n    in_size = features.size(\"node\", None, \"feat\")[0]\n    num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n    out_size = num_classes\n    model = SAGE(in_size, 256, out_size).to(device)\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)\n\n    for epoch in range(10):\n        model.train()\n        total_loss = 0\n        for it, data in enumerate(dataloader):\n            node_feature = data.node_features[\"feat\"].float()\n            blocks = data.sampled_subgraphs\n            y = data.labels\n            y_hat = model(blocks, node_feature)\n            loss = F.cross_entropy(y_hat, y)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n\nFor more details, please refer to the `full example\n<https://github.com/dmlc/dgl/blob/master/examples/graphbolt/sparse/graphsage.py>`__.\n"
  },
  {
    "path": "docs/source/guide/minibatch.rst",
    "content": ".. _guide-minibatch:\n\nChapter 6: Stochastic Training on Large Graphs\n=======================================================\n\n:ref:`(中文版) <guide_cn-minibatch>`\n\nIf we have a massive graph with, say, millions or even billions of nodes\nor edges, usually full-graph training as described in\n:ref:`guide-training`\nwould not work. Consider an :math:`L`-layer graph convolutional network\nwith hidden state size :math:`H` running on an :math:`N`-node graph.\nStoring the intermediate hidden states requires :math:`O(NLH)` memory,\neasily exceeding one GPU’s capacity with large :math:`N`.\n\nThis section provides a way to perform stochastic minibatch training,\nwhere we do not have to fit the feature of all the nodes into GPU.\n\nOverview of Neighborhood Sampling Approaches\n--------------------------------------------\n\nNeighborhood sampling methods generally work as the following. For each\ngradient descent step, we select a minibatch of nodes whose final\nrepresentations at the :math:`L`-th layer are to be computed. We then\ntake all or some of their neighbors at the :math:`L-1` layer. This\nprocess continues until we reach the input. This iterative process\nbuilds the dependency graph starting from the output and working\nbackwards to the input, as the figure below shows:\n\n.. 
figure:: https://data.dgl.ai/asset/image/guide_6_0_0.png\n   :alt: Imgur\n\n\n\nWith this, one can save the workload and computation resources for\ntraining a GNN on a large graph.\n\nDGL provides a few neighborhood samplers and a pipeline for training a\nGNN with neighborhood sampling, as well as ways to customize your\nsampling strategies.\n\nRoadmap\n-----------\n\nThe chapter starts with sections for training GNNs stochastically under\ndifferent scenarios.\n\n* :ref:`guide-minibatch-node-classification-sampler`\n* :ref:`guide-minibatch-edge-classification-sampler`\n* :ref:`guide-minibatch-link-classification-sampler`\n\nThe remaining sections cover more advanced topics, suitable for those who\nwish to develop new sampling algorithms, new GNN modules compatible with\nmini-batch training and understand how evaluation and inference can be\nconducted in mini-batches.\n\n* :ref:`guide-minibatch-customizing-neighborhood-sampler`\n* :ref:`guide-minibatch-sparse`\n* :ref:`guide-minibatch-custom-gnn-module`\n* :ref:`guide-minibatch-inference`\n\nThe following are performance tips for implementing and using neighborhood\nsampling:\n\n* :ref:`guide-minibatch-gpu-sampling`\n* :ref:`guide-minibatch-parallelism`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    minibatch-node\n    minibatch-edge\n    minibatch-link\n    minibatch-custom-sampler\n    minibatch-sparse\n    minibatch-nn\n    minibatch-inference\n    minibatch-gpu-sampling\n    minibatch-parallelism\n"
  },
  {
    "path": "docs/source/guide/mixed_precision.rst",
    "content": ".. _guide-mixed_precision:\n\nChapter 8: Mixed Precision Training\n===================================\nDGL is compatible with the `PyTorch Automatic Mixed Precision (AMP) package\n<https://pytorch.org/docs/stable/amp.html>`_\nfor mixed precision training, thus saving both training time and GPU/CPU memory\nconsumption. This feature requires DGL 0.9+ and 1.1+ for CPU bloat16.\n\nMessage-Passing with Half Precision\n-----------------------------------\nDGL allows message-passing on ``float16 (fp16)`` / ``bfloat16 (bf16)``\nfeatures for both UDFs (User Defined Functions) and built-in functions\n(e.g., ``dgl.function.sum``, ``dgl.function.copy_u``).\n\n.. note::\n   Please check bfloat16 support via ``torch.cuda.is_bf16_supported()`` before using it.\n   Typically it requires CUDA >= 11.0 and GPU compute capability >= 8.0.\n\nThe following example shows how to use DGL's message-passing APIs on half-precision\nfeatures:\n\n    >>> import torch\n    >>> import dgl\n    >>> import dgl.function as fn\n    >>> dev = torch.device('cuda')\n    >>> g = dgl.rand_graph(30, 100).to(dev)  # Create a graph on GPU w/ 30 nodes and 100 edges.\n    >>> g.ndata['h'] = torch.rand(30, 16).to(dev).half()  # Create fp16 node features.\n    >>> g.edata['w'] = torch.rand(100, 1).to(dev).half()  # Create fp16 edge features.\n    >>> # Use DGL's built-in functions for message passing on fp16 features.\n    >>> g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'x'))\n    >>> g.ndata['x'].dtype\n    torch.float16\n    >>> g.apply_edges(fn.u_dot_v('h', 'x', 'hx'))\n    >>> g.edata['hx'].dtype\n    torch.float16\n\n    >>> # Use UDFs for message passing on fp16 features.\n    >>> def message(edges):\n    ...     return {'m': edges.src['h'] * edges.data['w']}\n    ...\n    >>> def reduce(nodes):\n    ...     return {'y': torch.sum(nodes.mailbox['m'], 1)}\n    ...\n    >>> def dot(edges):\n    ...     
return {'hy': (edges.src['h'] * edges.dst['y']).sum(-1, keepdims=True)}\n    ...\n    >>> g.update_all(message, reduce)\n    >>> g.ndata['y'].dtype\n    torch.float16\n    >>> g.apply_edges(dot)\n    >>> g.edata['hy'].dtype\n    torch.float16\n\nEnd-to-End Mixed Precision Training\n-----------------------------------\nDGL relies on PyTorch's AMP package for mixed precision training,\nand the user experience is exactly\nthe same as `PyTorch's <https://pytorch.org/docs/stable/notes/amp_examples.html>`_.\n\nBy wrapping the forward pass with ``torch.amp.autocast()``, PyTorch automatically\nselects the appropriate datatype for each op and tensor. Half precision tensors are memory\nefficient, and most operators on half precision tensors are faster as they leverage GPU tensorcores\nand special CPU instruction sets.\n\n.. code::\n\n    import torch.nn.functional as F\n    from torch.amp import autocast\n\n    def forward(device_type, g, feat, label, mask, model, amp_dtype):\n        amp_enabled = amp_dtype in (torch.float16, torch.bfloat16)\n        with autocast(device_type, enabled=amp_enabled, dtype=amp_dtype):\n            logit = model(g, feat)\n            loss = F.cross_entropy(logit[mask], label[mask])\n            return loss\n\nSmall gradients in ``float16`` format have underflow problems (flush to zero).\nPyTorch provides a ``GradScaler`` module to address this issue. It multiplies\nthe loss by a factor and invokes backward pass on the scaled loss to prevent\nthe underflow problem. It then unscales the computed gradients before the optimizer\nupdates the parameters. The scale factor is determined automatically.\nNote that ``bfloat16`` doesn't require a ``GradScaler``.\n\n.. 
code::\n\n    from torch.cuda.amp import GradScaler\n\n    scaler = GradScaler()\n\n    def backward(scaler, loss, optimizer):\n        scaler.scale(loss).backward()\n        scaler.step(optimizer)\n        scaler.update()\n\nThe following example trains a 3-layer GAT on the Reddit dataset (w/ 114 million edges).\nPay attention to the differences in the code when AMP is activated or not.\n\n.. code::\n\n    import torch\n    import torch.nn as nn\n    import dgl\n    from dgl.data import RedditDataset\n    from dgl.nn import GATConv\n    from dgl.transforms import AddSelfLoop\n\n    amp_dtype = torch.bfloat16 # or torch.float16\n\n    class GAT(nn.Module):\n        def __init__(self,\n                     in_feats,\n                     n_hidden,\n                     n_classes,\n                     heads):\n            super().__init__()\n            self.layers = nn.ModuleList()\n            self.layers.append(GATConv(in_feats, n_hidden, heads[0], activation=F.elu))\n            self.layers.append(GATConv(n_hidden * heads[0], n_hidden, heads[1], activation=F.elu))\n            self.layers.append(GATConv(n_hidden * heads[1], n_classes, heads[2], activation=F.elu))\n\n        def forward(self, g, h):\n            for l, layer in enumerate(self.layers):\n                h = layer(g, h)\n                if l != len(self.layers) - 1:\n                    h = h.flatten(1)\n                else:\n                    h = h.mean(1)\n            return h\n\n    # Data loading\n    transform = AddSelfLoop()\n    data = RedditDataset(transform)\n    device_type = 'cuda' # or 'cpu'\n    dev = torch.device(device_type)\n\n    g = data[0]\n    g = g.int().to(dev)\n    train_mask = g.ndata['train_mask']\n    feat = g.ndata['feat']\n    label = g.ndata['label']\n\n    in_feats = feat.shape[1]\n    n_hidden = 256\n    n_classes = data.num_classes\n    heads = [1, 1, 1]\n    model = GAT(in_feats, n_hidden, n_classes, heads)\n    model = model.to(dev)\n    model.train()\n\n    # 
Create optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)\n\n    for epoch in range(100):\n        optimizer.zero_grad()\n        loss = forward(device_type, g, feat, label, train_mask, model, amp_dtype)\n\n        if amp_dtype == torch.float16:\n            # Backprop w/ gradient scaling\n            backward(scaler, loss, optimizer)\n        else:\n            loss.backward()\n            optimizer.step()\n\n        print('Epoch {} | Loss {}'.format(epoch, loss.item()))\n\nOn an NVIDIA V100 (16GB) machine, training this model without fp16 consumes\n15.2GB GPU memory; with fp16 turned on, the training consumes 12.8GB\nGPU memory, and the loss converges to similar values in both settings.\nIf we change the number of heads to ``[2, 2, 2]``, training without fp16\ntriggers a GPU OOM (out-of-memory) issue while training with fp16 consumes\n15.7GB GPU memory.\n\nBFloat16 CPU example\n-----------------------------------\nDGL supports running training in the bfloat16 data type on the CPU.\nThis data type doesn't require any special CPU feature and can improve the performance of a memory-bound model.\nStarting with Intel Xeon 4th Generation, which has the `AMX\n<https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html>`_ instruction set, bfloat16 should significantly improve training and inference performance without huge code changes.\nHere is an example of simple GCN bfloat16 training:\n\n.. 
code::\n\n    import torch\n    import torch.nn as nn\n    import torch.nn.functional as F\n    import dgl\n    from dgl.data import CiteseerGraphDataset\n    from dgl.nn import GraphConv\n    from dgl.transforms import AddSelfLoop\n\n\n    class GCN(nn.Module):\n        def __init__(self, in_size, hid_size, out_size):\n            super().__init__()\n            self.layers = nn.ModuleList()\n            # two-layer GCN\n            self.layers.append(\n                GraphConv(in_size, hid_size, activation=F.relu)\n            )\n            self.layers.append(GraphConv(hid_size, out_size))\n            self.dropout = nn.Dropout(0.5)\n    \n        def forward(self, g, features):\n            h = features\n            for i, layer in enumerate(self.layers):\n                if i != 0:\n                    h = self.dropout(h)\n                h = layer(g, h)\n            return h\n\n\n    # Data loading\n    transform = AddSelfLoop()\n    data = CiteseerGraphDataset(transform=transform)\n\n    g = data[0]\n    g = g.int()\n    train_mask = g.ndata['train_mask']\n    feat = g.ndata['feat']\n    label = g.ndata['label']\n\n    in_size = feat.shape[1]\n    hid_size = 16\n    out_size = data.num_classes\n    model = GCN(in_size, hid_size, out_size)\n    \n    # Convert model and graph to bfloat16\n    g = dgl.to_bfloat16(g)\n    feat = feat.to(dtype=torch.bfloat16)\n    model = model.to(dtype=torch.bfloat16)\n    \n    model.train()\n\n    # Create optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n    loss_fcn = nn.CrossEntropyLoss()\n\n    for epoch in range(100):\n        logits = model(g, feat)\n        loss = loss_fcn(logits[train_mask], label[train_mask])\n        \n        loss.backward()\n        optimizer.step()\n\n        print('Epoch {} | Loss {}'.format(epoch, loss.item()))\n\nThe only difference with common training is model and graph conversion before training/inference.\n\n.. 
code::\n\n    g = dgl.to_bfloat16(g)\n    feat = feat.to(dtype=torch.bfloat16)\n    model = model.to(dtype=torch.bfloat16)\n\n\nDGL is still improving its half-precision support and the compute kernel's\nperformance is far from optimal; please stay tuned to our future updates.\n"
  },
  {
    "path": "docs/source/guide/nn-construction.rst",
    "content": ".. _guide-nn-construction:\n\n3.1 DGL NN Module Construction Function\n---------------------------------------\n\n:ref:`(中文版) <guide_cn-nn-construction>`\n\nThe construction function performs the following steps:\n\n1. Set options.\n2. Register learnable parameters or submodules.\n3. Reset parameters.\n\n.. code::\n\n    import torch.nn as nn\n\n    from dgl.utils import expand_as_pair\n\n    class SAGEConv(nn.Module):\n        def __init__(self,\n                     in_feats,\n                     out_feats,\n                     aggregator_type,\n                     bias=True,\n                     norm=None,\n                     activation=None):\n            super(SAGEConv, self).__init__()\n\n            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n            self._out_feats = out_feats\n            self._aggre_type = aggregator_type\n            self.norm = norm\n            self.activation = activation\n\nIn construction function, one first needs to set the data dimensions. For\ngeneral PyTorch module, the dimensions are usually input dimension,\noutput dimension and hidden dimensions. For graph neural networks, the input\ndimension can be split into source node dimension and destination node\ndimension.\n\nBesides data dimensions, a typical option for graph neural network is\naggregation type (``self._aggre_type``). Aggregation type determines how\nmessages on different edges are aggregated for a certain destination\nnode. Commonly used aggregation types include ``mean``, ``sum``,\n``max``, ``min``. Some modules may apply more complicated aggregation\nlike an ``lstm``.\n\n``norm`` here is a callable function for feature normalization. In the\nSAGEConv paper, such normalization can be l2 normalization:\n:math:`h_v = h_v / \\lVert h_v \\rVert_2`.\n\n.. 
code::\n\n            # aggregator type: mean, pool, lstm, gcn\n            if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:\n                raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))\n            if aggregator_type == 'pool':\n                self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)\n            if aggregator_type == 'lstm':\n                self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)\n            if aggregator_type in ['mean', 'pool', 'lstm']:\n                self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)\n            self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)\n            self.reset_parameters()\n\nRegister parameters and submodules. In SAGEConv, submodules vary\naccording to the aggregation type. Those modules are pure PyTorch nn\nmodules like ``nn.Linear``, ``nn.LSTM``, etc. At the end of construction\nfunction, weight initialization is applied by calling\n``reset_parameters()``.\n\n.. code::\n\n        def reset_parameters(self):\n            \"\"\"Reinitialize learnable parameters.\"\"\"\n            gain = nn.init.calculate_gain('relu')\n            if self._aggre_type == 'pool':\n                nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)\n            if self._aggre_type == 'lstm':\n                self.lstm.reset_parameters()\n            if self._aggre_type != 'gcn':\n                nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)\n            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)\n"
  },
  {
    "path": "docs/source/guide/nn-forward.rst",
    "content": ".. _guide-nn-forward:\n\n3.2 DGL NN Module Forward Function\n----------------------------------\n\n:ref:`(中文版) <guide_cn-nn-forward>`\n\nIn NN module, ``forward()`` function does the actual message passing and\ncomputation. Compared with PyTorch’s NN module which usually takes\ntensors as the parameters, DGL NN module takes an additional parameter\n:class:`dgl.DGLGraph`. The\nworkload for ``forward()`` function can be split into three parts:\n\n-  Graph checking and graph type specification.\n\n-  Message passing.\n\n-  Feature update.\n\nThe rest of the section takes a deep dive into the ``forward()`` function in SAGEConv example.\n\nGraph checking and graph type specification\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code::\n\n        def forward(self, graph, feat):\n            with graph.local_scope():\n                # Specify graph type then expand input feature according to graph type\n                feat_src, feat_dst = expand_as_pair(feat, graph)\n\n``forward()`` needs to handle many corner cases on the input that can\nlead to invalid values in computing and message passing. One typical check in conv modules\nlike :class:`~dgl.nn.pytorch.conv.GraphConv` is to verify that the input graph has no 0-in-degree nodes.\nWhen a node has 0 in-degree, the ``mailbox`` will be empty and the reduce function will produce\nall-zero values. This may cause silent regression in model performance. However, in\n:class:`~dgl.nn.pytorch.conv.SAGEConv` module, the aggregated representation will be concatenated\nwith the original node feature, the output of ``forward()`` will not be all-zero. No such check is\nneeded in this case.\n\nDGL NN module should be reusable across different types of graph input\nincluding: homogeneous graph, heterogeneous\ngraph (:ref:`guide-graph-heterogeneous`), subgraph\nblock (:ref:`guide-minibatch`).\n\nThe math formulas for SAGEConv are:\n\n.. 
math::\n\n\n   h_{\\mathcal{N}(dst)}^{(l+1)}  = \\mathrm{aggregate}\n           \\left(\\{h_{src}^{l}, \\forall src \\in \\mathcal{N}(dst) \\}\\right)\n\n.. math::\n\n    h_{dst}^{(l+1)} = \\sigma \\left(W \\cdot \\mathrm{concat}\n           (h_{dst}^{l}, h_{\\mathcal{N}(dst)}^{l+1}) + b \\right)\n\n.. math::\n\n    h_{dst}^{(l+1)} = \\mathrm{norm}(h_{dst}^{l+1})\n\nOne needs to specify the source node feature ``feat_src`` and destination\nnode feature ``feat_dst`` according to the graph type.\n:meth:`~dgl.utils.expand_as_pair` is a function that specifies the graph\ntype and expand ``feat`` into ``feat_src`` and ``feat_dst``.\nThe detail of this function is shown below.\n\n.. code::\n\n    def expand_as_pair(input_, g=None):\n        if isinstance(input_, tuple):\n            # Bipartite graph case\n            return input_\n        elif g is not None and g.is_block:\n            # Subgraph block case\n            if isinstance(input_, Mapping):\n                input_dst = {\n                    k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))\n                    for k, v in input_.items()}\n            else:\n                input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())\n            return input_, input_dst\n        else:\n            # Homogeneous graph case\n            return input_, input_\n\nFor homogeneous whole graph training, source nodes and destination nodes\nare the same. They are all the nodes in the graph.\n\nFor heterogeneous case, the graph can be split into several bipartite\ngraphs, one for each relation. The relations are represented as\n``(src_type, edge_type, dst_dtype)``. When it identifies that the input feature\n``feat`` is a tuple, it will treat the graph as bipartite. The first\nelement in the tuple will be the source node feature and the second\nelement will be the destination node feature.\n\nIn mini-batch training, the computing is applied on a subgraph sampled\nbased on a bunch of destination nodes. 
The subgraph is called as\n``block`` in DGL. In the block creation phase,\n``dst nodes`` are in the front of the node list. One can find the\n``feat_dst`` by the index ``[0:g.number_of_dst_nodes()]``.\n\nAfter determining ``feat_src`` and ``feat_dst``, the computing for the\nabove three graph types are the same.\n\nMessage passing and reducing\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code::\n\n                import dgl.function as fn\n                import torch.nn.functional as F\n                from dgl.utils import check_eq_shape\n\n                if self._aggre_type == 'mean':\n                    graph.srcdata['h'] = feat_src\n                    graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))\n                    h_neigh = graph.dstdata['neigh']\n                elif self._aggre_type == 'gcn':\n                    check_eq_shape(feat)\n                    graph.srcdata['h'] = feat_src\n                    graph.dstdata['h'] = feat_dst\n                    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))\n                    # divide in_degrees\n                    degs = graph.in_degrees().to(feat_dst)\n                    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)\n                elif self._aggre_type == 'pool':\n                    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))\n                    graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))\n                    h_neigh = graph.dstdata['neigh']\n                else:\n                    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))\n\n                # GraphSAGE GCN does not require fc_self.\n                if self._aggre_type == 'gcn':\n                    rst = self.fc_neigh(h_neigh)\n                else:\n                    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)\n\nThe code actually does message passing and reducing computing. This part\nof code varies module by module. 
Note that all the message passing in\nthe above code is implemented using the :meth:`~dgl.DGLGraph.update_all` API and\n``built-in`` message/reduce functions to fully utilize DGL’s performance\noptimization as described in :ref:`guide-message-passing-efficient`.\n\nUpdate feature after reducing for output\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code::\n\n                # activation\n                if self.activation is not None:\n                    rst = self.activation(rst)\n                # normalization\n                if self.norm is not None:\n                    rst = self.norm(rst)\n                return rst\n\nThe last part of the ``forward()`` function is to update the feature after\nthe ``reduce function``. Common update operations are applying\nan activation function and normalization according to the options set in the\nobject construction phase.\n"
  },
  {
    "path": "docs/source/guide/nn-heterograph.rst",
    "content": ".. _guide-nn-heterograph:\n\n3.3 Heterogeneous GraphConv Module\n------------------------------------\n\n:ref:`(中文版) <guide_cn-nn-heterograph>`\n\n:class:`~dgl.nn.pytorch.HeteroGraphConv`\nis a module-level encapsulation to run DGL NN module on heterogeneous\ngraphs. The implementation logic is the same as message passing level API\n:meth:`~dgl.DGLGraph.multi_update_all`, including:\n\n-  DGL NN module within each relation :math:`r`.\n-  Reduction that merges the results on the same node type from multiple\n   relations.\n\nThis can be formulated as:\n\n.. math::  h_{dst}^{(l+1)} = \\underset{r\\in\\mathcal{R}, r_{dst}=dst}{AGG} (f_r(g_r, h_{r_{src}}^l, h_{r_{dst}}^l))\n\nwhere :math:`f_r` is the NN module for each relation :math:`r`,\n:math:`AGG` is the aggregation function.\n\nHeteroGraphConv implementation logic:\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code::\n\n    import torch.nn as nn\n\n    class HeteroGraphConv(nn.Module):\n        def __init__(self, mods, aggregate='sum'):\n            super(HeteroGraphConv, self).__init__()\n            self.mods = nn.ModuleDict(mods)\n            if isinstance(aggregate, str):\n                # An internal function to get common aggregation functions\n                self.agg_fn = get_aggregate_fn(aggregate)\n            else:\n                self.agg_fn = aggregate\n\nThe heterograph convolution takes a dictionary ``mods`` that maps each\nrelation to an nn module and sets the function that aggregates results on\nthe same node type from multiple relations.\n\n.. 
code::\n\n    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):\n        if mod_args is None:\n            mod_args = {}\n        if mod_kwargs is None:\n            mod_kwargs = {}\n        outputs = {nty : [] for nty in g.dsttypes}\n\nBesides input graph and input tensors, the ``forward()`` function takes\ntwo additional dictionary parameters ``mod_args`` and ``mod_kwargs``.\nThese two dictionaries have the same keys as ``self.mods``. They are\nused as customized parameters when calling their corresponding NN\nmodules in ``self.mods`` for different types of relations.\n\nAn output dictionary is created to hold output tensor for each\ndestination type ``nty`` . Note that the value for each ``nty`` is a\nlist, indicating a single node type may get multiple outputs if more\nthan one relations have ``nty`` as the destination type. ``HeteroGraphConv``\nwill perform a further aggregation on the lists.\n\n.. code::\n\n          if g.is_block:\n              src_inputs = inputs\n              dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}\n          else:\n              src_inputs = dst_inputs = inputs\n\n          for stype, etype, dtype in g.canonical_etypes:\n              rel_graph = g[stype, etype, dtype]\n              if rel_graph.num_edges() == 0:\n                  continue\n              if stype not in src_inputs or dtype not in dst_inputs:\n                  continue\n              dstdata = self.mods[etype](\n                  rel_graph,\n                  (src_inputs[stype], dst_inputs[dtype]),\n                  *mod_args.get(etype, ()),\n                  **mod_kwargs.get(etype, {}))\n              outputs[dtype].append(dstdata)\n\nThe input ``g`` can be a heterogeneous graph or a subgraph block from a\nheterogeneous graph. 
As in an ordinary NN module, the ``forward()``\nfunction needs to handle different input graph types separately.\n\nEach relation is represented as a ``canonical_etype``, which is\n``(stype, etype, dtype)``. Using ``canonical_etype`` as the key, one can\nextract out a bipartite graph ``rel_graph``. For a bipartite graph, the\ninput feature will be organized as a tuple\n``(src_inputs[stype], dst_inputs[dtype])``. The NN module for each\nrelation is called and the output is saved. To avoid unnecessary calls,\nrelations with no edges or no nodes with the src type will be skipped.\n\n.. code::\n\n        rsts = {}\n        for nty, alist in outputs.items():\n            if len(alist) != 0:\n                rsts[nty] = self.agg_fn(alist, nty)\n\nFinally, the results on the same destination node type from multiple\nrelations are aggregated using the ``self.agg_fn`` function. Examples can\nbe found in the API Doc for :class:`~dgl.nn.pytorch.HeteroGraphConv`.\n"
  },
  {
    "path": "docs/source/guide/nn.rst",
    "content": ".. _guide-nn:\n\nChapter 3: Building GNN Modules\n===============================\n\n:ref:`(中文版) <guide_cn-nn>`\n\nDGL NN module consists of building blocks for GNN models. An NN module inherits\nfrom `PyTorch’s NN Module <https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/module.html>`__, `MXNet Gluon’s NN Block <http://mxnet.incubator.apache.org/versions/1.6/api/python/docs/api/gluon/nn/index.html>`__, or `TensorFlow’s Keras\nLayer <https://www.tensorflow.org/api_docs/python/tf/keras/layers>`__, depending on the DNN framework backend in use. In a DGL NN\nmodule, the parameter registration in the construction function and tensor\noperations in the forward function are the same as in the backend framework.\nIn this way, DGL code can be seamlessly integrated into the backend\nframework code. The major difference lies in the message passing\noperations that are unique to DGL.\n\nDGL has integrated many commonly used\n:ref:`apinn-pytorch-conv`, :ref:`apinn-pytorch-dense-conv`, :ref:`apinn-pytorch-pooling`,\nand\n:ref:`apinn-pytorch-util`. We welcome your contribution!\n\nThis chapter takes :class:`~dgl.nn.pytorch.conv.SAGEConv` with the PyTorch backend as an example\nto introduce how to build a custom DGL NN Module.\n\nRoadmap\n-------\n\n* :ref:`guide-nn-construction`\n* :ref:`guide-nn-forward`\n* :ref:`guide-nn-heterograph`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    nn-construction\n    nn-forward\n    nn-heterograph\n"
  },
  {
    "path": "docs/source/guide/training-edge.rst",
    "content": ".. _guide-training-edge-classification:\n\n5.2 Edge Classification/Regression\n---------------------------------------------\n\n:ref:`(中文版) <guide_cn-training-edge-classification>`\n\nSometimes you wish to predict the attributes on the edges of the graph. In that\ncase, you would like to have an *edge classification/regression* model.\n\nHere we generate a random graph for edge prediction as a demonstration.\n\n.. code:: python\n\n    src = np.random.randint(0, 100, 500)\n    dst = np.random.randint(0, 100, 500)\n    # make it symmetric\n    edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])))\n    # synthetic node and edge features, as well as edge labels\n    edge_pred_graph.ndata['feature'] = torch.randn(100, 10)\n    edge_pred_graph.edata['feature'] = torch.randn(1000, 10)\n    edge_pred_graph.edata['label'] = torch.randn(1000)\n    # synthetic train-validation-test splits\n    edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)\n\nOverview\n~~~~~~~~\n\nFrom the previous section you have learned how to do node classification\nwith a multilayer GNN. The same technique can be applied for computing a\nhidden representation of any node. 
The prediction on edges can then be\nderived from the representation of their incident nodes.\n\nThe most common case of computing the prediction on an edge is to\nexpress it as a parameterized function of the representation of its\nincident nodes, and optionally the features on the edge itself.\n\nModel Implementation Difference from Node Classification\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nAssuming that you compute the node representation with the model from\nthe previous section, you only need to write another component that\ncomputes the edge prediction with the\n:meth:`~dgl.DGLGraph.apply_edges` method.\n\nFor instance, if you would like to compute a score for each edge for\nedge regression, the following code computes the dot product of incident\nnode representations on each edge.\n\n.. code:: python\n\n    import dgl.function as fn\n    class DotProductPredictor(nn.Module):\n        def forward(self, graph, h):\n            # h contains the node representations computed from the GNN defined\n            # in the node classification section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))\n                return graph.edata['score']\n\nOne can also write a prediction function that predicts a vector for each\nedge with an MLP. Such vector can be used in further downstream tasks,\ne.g. as logits of a categorical distribution.\n\n.. 
code:: python\n\n    class MLPPredictor(nn.Module):\n        def __init__(self, in_features, out_classes):\n            super().__init__()\n            self.W = nn.Linear(in_features * 2, out_classes)\n\n        def apply_edges(self, edges):\n            h_u = edges.src['h']\n            h_v = edges.dst['h']\n            score = self.W(torch.cat([h_u, h_v], 1))\n            return {'score': score}\n\n        def forward(self, graph, h):\n            # h contains the node representations computed from the GNN defined\n            # in the node classification section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(self.apply_edges)\n                return graph.edata['score']\n\nTraining loop\n~~~~~~~~~~~~~\n\nGiven the node representation computation model and an edge predictor\nmodel, we can easily write a full-graph training loop where we compute\nthe prediction on all edges.\n\nThe following example takes ``SAGE`` in the previous section as the node\nrepresentation computation model and ``DotPredictor`` as an edge\npredictor model.\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.sage = SAGE(in_features, hidden_features, out_features)\n            self.pred = DotProductPredictor()\n        def forward(self, g, x):\n            h = self.sage(g, x)\n            return self.pred(g, h)\n\nIn this example, we also assume that the training/validation/test edge\nsets are identified by boolean masks on edges. This example also does\nnot include early stopping and model saving.\n\n.. 
code:: python\n\n    node_features = edge_pred_graph.ndata['feature']\n    edge_label = edge_pred_graph.edata['label']\n    train_mask = edge_pred_graph.edata['train_mask']\n    model = Model(10, 20, 5)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        pred = model(edge_pred_graph, node_features)\n        loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n.. _guide-training-edge-classification-heterogeneous-graph:\n\nHeterogeneous graph\n~~~~~~~~~~~~~~~~~~~\n\nEdge classification on heterogeneous graphs is not very different from\nthat on homogeneous graphs. If you wish to perform edge classification\non one edge type, you only need to compute the node representation for\nall node types, and predict on that edge type with\n:meth:`~dgl.DGLGraph.apply_edges` method.\n\nFor example, to make ``DotProductPredictor`` work on one edge type of a\nheterogeneous graph, you only need to specify the edge type in\n``apply_edges`` method.\n\n.. code:: python\n\n    class HeteroDotProductPredictor(nn.Module):\n        def forward(self, graph, h, etype):\n            # h contains the node representations for each edge type computed from\n            # the GNN for heterogeneous graphs defined in the node classification\n            # section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h   # assigns 'h' of all node types in one shot\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)\n                return graph.edges[etype].data['score']\n\nYou can similarly write a ``HeteroMLPPredictor``.\n\n.. 
code:: python\n\n    class HeteroMLPPredictor(nn.Module):\n        def __init__(self, in_features, out_classes):\n            super().__init__()\n            self.W = nn.Linear(in_features * 2, out_classes)\n\n        def apply_edges(self, edges):\n            h_u = edges.src['h']\n            h_v = edges.dst['h']\n            score = self.W(torch.cat([h_u, h_v], 1))\n            return {'score': score}\n\n        def forward(self, graph, h, etype):\n            # h contains the node representations for each edge type computed from\n            # the GNN for heterogeneous graphs defined in the node classification\n            # section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h   # assigns 'h' of all node types in one shot\n                graph.apply_edges(self.apply_edges, etype=etype)\n                return graph.edges[etype].data['score']\n\nThe end-to-end model that predicts a score for each edge on a single\nedge type will look like this:\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, rel_names):\n            super().__init__()\n            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)\n            self.pred = HeteroDotProductPredictor()\n        def forward(self, g, x, etype):\n            h = self.sage(g, x)\n            return self.pred(g, h, etype)\n\nUsing the model simply involves feeding the model a dictionary of node\ntypes and features.\n\n.. code:: python\n\n    model = Model(10, 20, 5, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    label = hetero_graph.edges['click'].data['label']\n    train_mask = hetero_graph.edges['click'].data['train_mask']\n    node_features = {'user': user_feats, 'item': item_feats}\n\nThen the training loop looks almost the same as that in homogeneous\ngraph. 
For instance, if you wish to predict the edge labels on edge type\n``click``, then you can simply do\n\n.. code:: python\n\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        pred = model(hetero_graph, node_features, 'click')\n        loss = ((pred[train_mask] - label[train_mask]) ** 2).mean()\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n\nPredicting Edge Type of an Existing Edge on a Heterogeneous Graph\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nSometimes you may want to predict which type an existing edge belongs\nto.\n\nFor instance, given the\n:ref:`heterogeneous graph example <guide-training-heterogeneous-graph-example>`,\nyour task is given an edge connecting a user and an item, to predict whether\nthe user would ``click`` or ``dislike`` an item.\n\nThis is a simplified version of rating prediction, which is common in\nrecommendation literature.\n\nYou can use a heterogeneous graph convolution network to obtain the node\nrepresentations. For instance, you can still use the\n:ref:`RGCN defined previously <guide-training-rgcn-node-classification>`\nfor this purpose.\n\nTo predict the type of an edge, you can simply repurpose the\n``HeteroDotProductPredictor`` above so that it takes in another graph\nwith only one edge type that “merges” all the edge types to be\npredicted, and emits the score of each type for every edge.\n\nIn the example here, you will need a graph that has two node types\n``user`` and ``item``, and one single edge type that “merges” all the\nedge types from ``user`` and ``item``, i.e. ``click`` and ``dislike``.\nThis can be conveniently created using the following syntax:\n\n.. 
code:: python\n\n    dec_graph = hetero_graph['user', :, 'item']\n\nwhich returns a heterogeneous graphs with node type ``user`` and ``item``,\nas well as a single edge type combining all edge types in between, i.e.\n``click`` and ``dislike``.\n\nSince the statement above also returns the original edge types as a\nfeature named ``dgl.ETYPE``, we can use that as labels.\n\n.. code:: python\n\n    edge_label = dec_graph.edata[dgl.ETYPE]\n\nGiven the graph above as input to the edge type predictor module, you\ncan write your predictor module as follows.\n\n.. code:: python\n\n    class HeteroMLPPredictor(nn.Module):\n        def __init__(self, in_dims, n_classes):\n            super().__init__()\n            self.W = nn.Linear(in_dims * 2, n_classes)\n\n        def apply_edges(self, edges):\n            x = torch.cat([edges.src['h'], edges.dst['h']], 1)\n            y = self.W(x)\n            return {'score': y}\n\n        def forward(self, graph, h):\n            # h contains the node representations for each edge type computed from\n            # the GNN for heterogeneous graphs defined in the node classification\n            # section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h   # assigns 'h' of all node types in one shot\n                graph.apply_edges(self.apply_edges)\n                return graph.edata['score']\n\nThe model that combines the node representation module and the edge type\npredictor module is the following:\n\n.. 
code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, rel_names):\n            super().__init__()\n            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)\n            self.pred = HeteroMLPPredictor(out_features, len(rel_names))\n        def forward(self, g, x, dec_graph):\n            h = self.sage(g, x)\n            return self.pred(dec_graph, h)\n\nThe training loop is then simply the following:\n\n.. code:: python\n\n    model = Model(10, 20, 5, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    node_features = {'user': user_feats, 'item': item_feats}\n\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        logits = model(hetero_graph, node_features, dec_graph)\n        loss = F.cross_entropy(logits, edge_label)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n\nDGL provides `Graph Convolutional Matrix\nCompletion <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcmc>`__\nas an example of rating prediction, which is formulated by predicting\nthe type of an existing edge on a heterogeneous graph. The node\nrepresentation module in the `model implementation\nfile <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcmc>`__\nis called ``GCMCLayer``. The edge type predictor module is called\n``BiDecoder``. Both of them are more complicated than the setting\ndescribed here.\n"
  },
  {
    "path": "docs/source/guide/training-eweight.rst",
    "content": ".. _guide-training-eweight:\n\n5.5 Use of Edge Weights\n----------------------------------\n\n:ref:`(中文版) <guide_cn-training-eweight>`\n\nIn a weighted graph, each edge is associated with a semantically meaningful scalar weight. For\nexample, the edge weights can be connectivity strengths or confidence scores. Naturally, one\nmay want to utilize edge weights in model development.\n\nMessage Passing with Edge Weights\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nMost graph neural networks (GNNs) integrate the graph topology information in forward computation\nby and only by the message passing mechanism. A message passing operation can be viewed as\na function that takes an adjacency matrix and additional input features as input arguments. For an\nunweighted graph, the entries in the adjacency matrix can be zero or one, where a one-valued entry\nindicates an edge. If this graph is weighted, the non-zero entries can take arbitrary scalar\nvalues. This is equivalent to multiplying each message by its corresponding edge weight as in\n`GAT <https://arxiv.org/pdf/1710.10903.pdf>`__.\n\nWith DGL, one can achieve this by:\n\n- Saving the edge weights as an edge feature\n- Multplying the original message by the edge feature in the message function\n\nConsider the message passing example with DGL below.\n\n.. code::\n\n    import dgl.function as fn\n\n    # Suppose graph.ndata['ft'] stores the input node features\n    graph.update_all(fn.copy_u('ft', 'm'), fn.sum('m', 'ft'))\n\nOne can modify it for edge weight support as follows.\n\n.. 
code::\n\n    import dgl.function as fn\n\n    # Save edge weights as an edge feature, which is a tensor of shape (E, *)\n    # E is the number of edges\n    graph.edata['w'] = eweight\n\n    # Suppose graph.ndata['ft'] stores the input node features\n    graph.update_all(fn.u_mul_e('ft', 'w', 'm'), fn.sum('m', 'ft'))\n\nUsing NN Modules with Edge Weights\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nOne can modify an NN module for edge weight support by modifying all message passing operations\nin it. The following code snippet is an example of an NN module supporting edge weights.\n\n.. code::\n\n    import dgl.function as fn\n    import torch.nn as nn\n\n    class GNN(nn.Module):\n        def __init__(self, in_feats, out_feats):\n            super().__init__()\n            self.linear = nn.Linear(in_feats, out_feats)\n\n        def forward(self, g, feat, edge_weight=None):\n            with g.local_scope():\n                g.ndata['ft'] = self.linear(feat)\n                if edge_weight is None:\n                    msg_func = fn.copy_u('ft', 'm')\n                else:\n                    g.edata['w'] = edge_weight\n                    msg_func = fn.u_mul_e('ft', 'w', 'm')\n                g.update_all(msg_func, fn.sum('m', 'ft'))\n                return g.ndata['ft']\n\nDGL's built-in NN modules support edge weights if they take an optional :attr:`edge_weight`\nargument in the forward function.\n\nOne may need to normalize raw edge weights. In this regard, DGL provides\n:class:`~dgl.nn.pytorch.conv.EdgeWeightNorm`.\n"
  },
  {
    "path": "docs/source/guide/training-graph.rst",
    "content": ".. _guide-training-graph-classification:\n\n5.4 Graph Classification\n----------------------------------\n\n:ref:`(中文版) <guide_cn-training-graph-classification>`\n\nInstead of a big single graph, sometimes one might have the data in the\nform of multiple graphs, for example a list of different types of\ncommunities of people. By characterizing the friendship among people in\nthe same community by a graph, one can get a list of graphs to classify. In\nthis scenario, a graph classification model could help identify the type\nof the community, i.e. to classify each graph based on the structure and\noverall information.\n\nOverview\n~~~~~~~~\n\nThe major difference between graph classification and node\nclassification or link prediction is that the prediction result\ncharacterizes the property of the entire input graph. One can perform the\nmessage passing over nodes/edges just like the previous tasks, but also\nneeds to retrieve a graph-level representation.\n\nThe graph classification pipeline proceeds as follows:\n\n.. figure:: https://data.dgl.ai/tutorial/batch/graph_classifier.png\n   :alt: Graph Classification Process\n\n   Graph Classification Process\n\nFrom left to right, the common practice is:\n\n-  Prepare a batch of graphs\n-  Perform message passing on the batched graphs to update node/edge features\n-  Aggregate node/edge features into graph-level representations\n-  Classify graphs based on graph-level representations\n\nBatch of Graphs\n^^^^^^^^^^^^^^^\n\nUsually a graph classification task trains on a lot of graphs, and it\nwill be very inefficient to use only one graph at a time when\ntraining the model. Borrowing the idea of mini-batch training from\ncommon deep learning practice, one can build a batch of multiple graphs\nand send them together for one training iteration.\n\nIn DGL, one can build a single batched graph from a list of graphs. 
This\nbatched graph can be simply used as a single large graph, with connected\ncomponents corresponding to the original small graphs.\n\n.. figure:: https://data.dgl.ai/tutorial/batch/batch.png\n   :alt: Batched Graph\n\n   Batched Graph\n\nThe following example calls :func:`dgl.batch` on a list of graphs.\nA batched graph is a single graph, while it also carries information\nabout the list.\n\n.. code:: python\n\n    import dgl\n    import torch as th\n\n    g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))\n    g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))\n\n    bg = dgl.batch([g1, g2])\n    bg\n    # Graph(num_nodes=7, num_edges=7,\n    #       ndata_schemes={}\n    #       edata_schemes={})\n    bg.batch_size\n    # 2\n    bg.batch_num_nodes()\n    # tensor([4, 3])\n    bg.batch_num_edges()\n    # tensor([3, 4])\n    bg.edges()\n    # (tensor([0, 1, 2, 4, 4, 4, 5]), tensor([1, 2, 3, 4, 5, 6, 4]))\n\nPlease note that most dgl transformation functions will discard the batch information.\nIn order to maintain such information, please use :func:`dgl.DGLGraph.set_batch_num_nodes`\nand :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph.\n\nGraph Readout\n^^^^^^^^^^^^^\n\nEvery graph in the data may have its unique structure, as well as its\nnode and edge features. In order to make a single prediction, one usually\naggregates and summarizes over the possibly abundant information. This\ntype of operation is named *readout*. Common readout operations include\nsummation, average, maximum or minimum over all node or edge features.\n\nGiven a graph :math:`g`, one can define the average node feature readout as\n\n.. math:: h_g = \\frac{1}{|\\mathcal{V}|}\\sum_{v\\in \\mathcal{V}}h_v\n\nwhere :math:`h_g` is the representation of :math:`g`, :math:`\\mathcal{V}` is\nthe set of nodes in :math:`g`, :math:`h_v` is the feature of node :math:`v`.\n\nDGL provides built-in support for common readout operations. 
For example,\n:func:`dgl.mean_nodes` implements the above readout operation.\n\nOnce :math:`h_g` is available, one can pass it through an MLP layer for\nclassification output.\n\nWriting Neural Network Model\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe input to the model is the batched graph with node and edge features.\n\nComputation on a Batched Graph\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nFirst, different graphs in a batch are entirely separated, i.e. no edges\nbetween any two graphs. With this nice property, all message passing\nfunctions still have the same results.\n\nSecond, the readout function on a batched graph will be conducted over\neach graph separately. Assuming the batch size is :math:`B` and the\nfeature to be aggregated has dimension :math:`D`, the shape of the\nreadout result will be :math:`(B, D)`.\n\n.. code:: python\n\n    import dgl\n    import torch\n\n    g1 = dgl.graph(([0, 1], [1, 0]))\n    g1.ndata['h'] = torch.tensor([1., 2.])\n    g2 = dgl.graph(([0, 1], [1, 2]))\n    g2.ndata['h'] = torch.tensor([1., 2., 3.])\n\n    dgl.readout_nodes(g1, 'h')\n    # tensor([3.])  # 1 + 2\n\n    bg = dgl.batch([g1, g2])\n    dgl.readout_nodes(bg, 'h')\n    # tensor([3., 6.])  # [1 + 2, 1 + 2 + 3]\n\nFinally, each node/edge feature in a batched graph is obtained by\nconcatenating the corresponding features from all graphs in order.\n\n.. code:: python\n\n    bg.ndata['h']\n    # tensor([1., 2., 1., 2., 3.])\n\nModel Definition\n^^^^^^^^^^^^^^^^\n\nBeing aware of the above computation rules, one can define a model as follows.\n\n.. 
code:: python\n\n    import dgl.nn.pytorch as dglnn\n    import torch.nn as nn\n\n    class Classifier(nn.Module):\n        def __init__(self, in_dim, hidden_dim, n_classes):\n            super(Classifier, self).__init__()\n            self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)\n            self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)\n            self.classify = nn.Linear(hidden_dim, n_classes)\n\n        def forward(self, g, h):\n            # Apply graph convolution and activation.\n            h = F.relu(self.conv1(g, h))\n            h = F.relu(self.conv2(g, h))\n            with g.local_scope():\n                g.ndata['h'] = h\n                # Calculate graph representation by average readout.\n                hg = dgl.mean_nodes(g, 'h')\n                return self.classify(hg)\n\nTraining Loop\n~~~~~~~~~~~~~\n\nData Loading\n^^^^^^^^^^^^\n\nOnce the model is defined, one can start training. Since graph\nclassification deals with lots of relatively small graphs instead of a big\nsingle one, one can train efficiently on stochastic mini-batches\nof graphs, without the need to design sophisticated graph sampling\nalgorithms.\n\nAssuming that one has a graph classification dataset as introduced in\n:ref:`guide-data-pipeline`.\n\n.. code:: python\n\n    import dgl.data\n    dataset = dgl.data.GINDataset('MUTAG', False)\n\nEach item in the graph classification dataset is a pair of a graph and\nits label. One can speed up the data loading process by taking advantage\nof the GraphDataLoader to iterate over the dataset of\ngraphs in mini-batches.\n\n.. code:: python\n\n    from dgl.dataloading import GraphDataLoader\n    dataloader = GraphDataLoader(\n        dataset,\n        batch_size=1024,\n        drop_last=False,\n        shuffle=True)\n\nTraining loop then simply involves iterating over the dataloader and\nupdating the model.\n\n.. 
code:: python\n\n    import torch.nn.functional as F\n\n    # Only an example, 7 is the input feature size\n    model = Classifier(7, 20, 5)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(20):\n        for batched_graph, labels in dataloader:\n            feats = batched_graph.ndata['attr']\n            logits = model(batched_graph, feats)\n            loss = F.cross_entropy(logits, labels)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n\nFor an end-to-end example of graph classification, see\n`DGL's GIN example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin>`__.\nThe training loop is inside the\nfunction ``train`` in\n`main.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/main.py>`__.\nThe model implementation is inside\n`gin.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/gin.py>`__\nwith more components such as using\n:class:`dgl.nn.pytorch.GINConv` (also available in MXNet and Tensorflow)\nas the graph convolution layer, batch normalization, etc.\n\nHeterogeneous graph\n~~~~~~~~~~~~~~~~~~~\n\nGraph classification with heterogeneous graphs is a little different\nfrom that with homogeneous graphs. In addition to graph convolution modules\ncompatible with heterogeneous graphs, one also needs to aggregate over the nodes of\ndifferent types in the readout function.\n\nThe following shows an example of summing up the average of node\nrepresentations for each node type.\n\n.. 
code:: python\n\n    class RGCN(nn.Module):\n        def __init__(self, in_feats, hid_feats, out_feats, rel_names):\n            super().__init__()\n\n            self.conv1 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(in_feats, hid_feats)\n                for rel in rel_names}, aggregate='sum')\n            self.conv2 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(hid_feats, out_feats)\n                for rel in rel_names}, aggregate='sum')\n\n        def forward(self, graph, inputs):\n            # inputs is features of nodes\n            h = self.conv1(graph, inputs)\n            h = {k: F.relu(v) for k, v in h.items()}\n            h = self.conv2(graph, h)\n            return h\n\n    class HeteroClassifier(nn.Module):\n        def __init__(self, in_dim, hidden_dim, n_classes, rel_names):\n            super().__init__()\n\n            self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)\n            self.classify = nn.Linear(hidden_dim, n_classes)\n\n        def forward(self, g):\n            h = g.ndata['feat']\n            h = self.rgcn(g, h)\n            with g.local_scope():\n                g.ndata['h'] = h\n                # Calculate graph representation by average readout.\n                hg = 0\n                for ntype in g.ntypes:\n                    hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)\n                return self.classify(hg)\n\nThe rest of the code is not different from that for homogeneous graphs.\n\n.. code:: python\n\n    # etypes is the list of edge types as strings.\n    model = HeteroClassifier(10, 20, 5, etypes)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(20):\n        for batched_graph, labels in dataloader:\n            logits = model(batched_graph)\n            loss = F.cross_entropy(logits, labels)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n"
  },
  {
    "path": "docs/source/guide/training-link.rst",
    "content": ".. _guide-training-link-prediction:\n\n5.3 Link Prediction\n---------------------------\n\n:ref:`(中文版) <guide_cn-training-link-prediction>`\n\nIn some other settings you may want to predict whether an edge exists\nbetween two given nodes or not. Such task is called a *link prediction*\ntask.\n\nOverview\n~~~~~~~~\n\nA GNN-based link prediction model represents the likelihood of\nconnectivity between two nodes :math:`u` and :math:`v` as a function of\n:math:`\\boldsymbol{h}_u^{(L)}` and :math:`\\boldsymbol{h}_v^{(L)}`, their\nnode representation computed from the multi-layer GNN.\n\n.. math::\n\n\n   y_{u,v} = \\phi(\\boldsymbol{h}_u^{(L)}, \\boldsymbol{h}_v^{(L)})\n\nIn this section we refer to :math:`y_{u,v}` the *score* between node\n:math:`u` and node :math:`v`.\n\nTraining a link prediction model involves comparing the scores between\nnodes connected by an edge against the scores between an arbitrary pair\nof nodes. For example, given an edge connecting :math:`u` and :math:`v`,\nwe encourage the score between node :math:`u` and :math:`v` to be higher\nthan the score between node :math:`u` and a sampled node :math:`v'` from\nan arbitrary *noise* distribution :math:`v' \\sim P_n(v)`. Such\nmethodology is called *negative sampling*.\n\nThere are lots of loss functions that can achieve the behavior above if\nminimized. 
A non-exhaustive list includes:\n\n-  Cross-entropy loss:\n   :math:`\\mathcal{L} = - \\log \\sigma (y_{u,v}) - \\sum_{v_i \\sim P_n(v), i=1,\\dots,k}\\log \\left[ 1 - \\sigma (y_{u,v_i})\\right]`\n-  BPR loss:\n   :math:`\\mathcal{L} = \\sum_{v_i \\sim P_n(v), i=1,\\dots,k} - \\log \\sigma (y_{u,v} - y_{u,v_i})`\n-  Margin loss:\n   :math:`\\mathcal{L} = \\sum_{v_i \\sim P_n(v), i=1,\\dots,k} \\max(0, M - y_{u, v} + y_{u, v_i})`,\n   where :math:`M` is a constant hyperparameter.\n\nYou may find this idea familiar if you know what `implicit\nfeedback <https://arxiv.org/ftp/arxiv/papers/1205/1205.2618.pdf>`__ or\n`noise-contrastive\nestimation <http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf>`__\nis.\n\nThe neural network model to compute the score between :math:`u` and\n:math:`v` is identical to the edge regression model described\n:ref:`above <guide-training-edge-classification>`.\n\nHere is an example of using dot product to compute the scores on edges.\n\n.. code:: python\n\n    class DotProductPredictor(nn.Module):\n        def forward(self, graph, h):\n            # h contains the node representations computed from the GNN defined\n            # in the node classification section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))\n                return graph.edata['score']\n\nTraining loop\n~~~~~~~~~~~~~\n\nBecause our score prediction model operates on graphs, we need to\nexpress the negative examples as another graph. The graph will contain\nall negative node pairs as edges.\n\nThe following shows an example of expressing negative examples as a\ngraph. Each edge :math:`(u,v)` gets :math:`k` negative examples\n:math:`(u,v_i)` where :math:`v_i` is sampled from a uniform\ndistribution.\n\n.. 
code:: python\n\n    def construct_negative_graph(graph, k):\n        src, dst = graph.edges()\n    \n        neg_src = src.repeat_interleave(k)\n        neg_dst = torch.randint(0, graph.num_nodes(), (len(src) * k,))\n        return dgl.graph((neg_src, neg_dst), num_nodes=graph.num_nodes())\n\nThe model that predicts edge scores is the same as that of edge\nclassification/regression.\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.sage = SAGE(in_features, hidden_features, out_features)\n            self.pred = DotProductPredictor()\n        def forward(self, g, neg_g, x):\n            h = self.sage(g, x)\n            return self.pred(g, h), self.pred(neg_g, h)\n\nThe training loop then repeatedly constructs the negative graph and\ncomputes loss.\n\n.. code:: python\n\n    def compute_loss(pos_score, neg_score):\n        # Margin loss\n        n_edges = pos_score.shape[0]\n        return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()\n    \n    node_features = graph.ndata['feat']\n    n_features = node_features.shape[1]\n    k = 5\n    model = Model(n_features, 100, 100)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        negative_graph = construct_negative_graph(graph, k)\n        pos_score, neg_score = model(graph, negative_graph, node_features)\n        loss = compute_loss(pos_score, neg_score)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n\nAfter training, the node representation can be obtained via\n\n.. code:: python\n\n    node_embeddings = model.sage(graph, node_features)\n\nThere are multiple ways of using the node embeddings. 
Examples include\ntraining downstream classifiers, or doing nearest neighbor search or\nmaximum inner product search for relevant entity recommendation.\n\nHeterogeneous graphs\n~~~~~~~~~~~~~~~~~~~~\n\nLink prediction on heterogeneous graphs is not very different from that\non homogeneous graphs. The following assumes that we are predicting on\none edge type, and it is easy to extend it to multiple edge types.\n\nFor example, you can reuse the ``HeteroDotProductPredictor``\n:ref:`above <guide-training-edge-classification-heterogeneous-graph>`\nfor computing the scores of the edges of an edge type for link prediction.\n\n.. code:: python\n\n    class HeteroDotProductPredictor(nn.Module):\n        def forward(self, graph, h, etype):\n            # h contains the node representations for each node type computed from\n            # the GNN defined in the previous section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)\n                return graph.edges[etype].data['score']\n\nTo perform negative sampling, one can construct a negative graph for the\nedge type you are performing link prediction on as well.\n\n.. code:: python\n\n    def construct_negative_graph(graph, k, etype):\n        utype, _, vtype = etype\n        src, dst = graph.edges(etype=etype)\n        neg_src = src.repeat_interleave(k)\n        neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))\n        return dgl.heterograph(\n            {etype: (neg_src, neg_dst)},\n            num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})\n\nThe model is a bit different from that in edge classification on\nheterogeneous graphs since you need to specify edge type where you\nperform link prediction.\n\n.. 
code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, rel_names):\n            super().__init__()\n            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)\n            self.pred = HeteroDotProductPredictor()\n        def forward(self, g, neg_g, x, etype):\n            h = self.sage(g, x)\n            return self.pred(g, h, etype), self.pred(neg_g, h, etype)\n\nThe training loop is similar to that of homogeneous graphs.\n\n.. code:: python\n\n    def compute_loss(pos_score, neg_score):\n        # Margin loss\n        n_edges = pos_score.shape[0]\n        return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()\n    \n    k = 5\n    model = Model(10, 20, 5, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    node_features = {'user': user_feats, 'item': item_feats}\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        negative_graph = construct_negative_graph(hetero_graph, k, ('user', 'click', 'item'))\n        pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ('user', 'click', 'item'))\n        loss = compute_loss(pos_score, neg_score)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n\n\n"
  },
  {
    "path": "docs/source/guide/training-node.rst",
    "content": ".. _guide-training-node-classification:\n\n5.1 Node Classification/Regression\n--------------------------------------------------\n\n:ref:`(中文版) <guide_cn-training-node-classification>`\n\nOne of the most popular and widely adopted tasks for graph neural\nnetworks is node classification, where each node in the\ntraining/validation/test set is assigned a ground truth category from a\nset of predefined categories. Node regression is similar, where each\nnode in the training/validation/test set is assigned a ground truth\nnumber.\n\nOverview\n~~~~~~~~\n\nTo classify nodes, graph neural network performs message passing\ndiscussed in :ref:`guide-message-passing` to utilize the node’s own\nfeatures, but also its neighboring node and edge features. Message\npassing can be repeated multiple rounds to incorporate information from\nlarger range of neighborhood.\n\nWriting neural network model\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nDGL provides a few built-in graph convolution modules that can perform\none round of message passing. In this guide, we choose\n:class:`dgl.nn.pytorch.SAGEConv` (also available in MXNet and Tensorflow),\nthe graph convolution module for GraphSAGE.\n\nUsually for deep learning models on graphs we need a multi-layer graph\nneural network, where we do multiple rounds of message passing. This can\nbe achieved by stacking graph convolution modules as follows.\n\n.. 
code:: python\n\n    # Construct a two-layer GNN model\n    import dgl.nn as dglnn\n    import torch.nn as nn\n    import torch.nn.functional as F\n    class SAGE(nn.Module):\n        def __init__(self, in_feats, hid_feats, out_feats):\n            super().__init__()\n            self.conv1 = dglnn.SAGEConv(\n                in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')\n            self.conv2 = dglnn.SAGEConv(\n                in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')\n      \n        def forward(self, graph, inputs):\n            # inputs are features of nodes\n            h = self.conv1(graph, inputs)\n            h = F.relu(h)\n            h = self.conv2(graph, h)\n            return h\n\nNote that you can use the model above for not only node classification,\nbut also obtaining hidden node representations for other downstream\ntasks such as\n:ref:`guide-training-edge-classification`,\n:ref:`guide-training-link-prediction`, or\n:ref:`guide-training-graph-classification`.\n\nFor a complete list of built-in graph convolution modules, please refer\nto :ref:`apinn`.\n\nFor more details in how DGL\nneural network modules work and how to write a custom neural network\nmodule with message passing please refer to the example in :ref:`guide-nn`.\n\nTraining loop\n~~~~~~~~~~~~~\n\nTraining on the full graph simply involves a forward propagation of the\nmodel defined above, and computing the loss by comparing the prediction\nagainst ground truth labels on the training nodes.\n\nThis section uses a DGL built-in dataset\n:class:`dgl.data.CiteseerGraphDataset` to\nshow a training loop. The node features\nand labels are stored on its graph instance, and the\ntraining-validation-test split are also stored on the graph as boolean\nmasks. This is similar to what you have seen in :ref:`guide-data-pipeline`.\n\n.. 
code:: python\n\n    node_features = graph.ndata['feat']\n    node_labels = graph.ndata['label']\n    train_mask = graph.ndata['train_mask']\n    valid_mask = graph.ndata['val_mask']\n    test_mask = graph.ndata['test_mask']\n    n_features = node_features.shape[1]\n    n_labels = int(node_labels.max().item() + 1)\n\nThe following is an example of evaluating your model by accuracy.\n\n.. code:: python\n\n    def evaluate(model, graph, features, labels, mask):\n        model.eval()\n        with torch.no_grad():\n            logits = model(graph, features)\n            logits = logits[mask]\n            labels = labels[mask]\n            _, indices = torch.max(logits, dim=1)\n            correct = torch.sum(indices == labels)\n            return correct.item() * 1.0 / len(labels)\n\nYou can then write your training loop as follows.\n\n.. code:: python\n\n    model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)\n    opt = torch.optim.Adam(model.parameters())\n    \n    for epoch in range(10):\n        model.train()\n        # forward propagation by using all nodes\n        logits = model(graph, node_features)\n        # compute loss\n        loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])\n        # compute validation accuracy\n        acc = evaluate(model, graph, node_features, node_labels, valid_mask)\n        # backward propagation\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n    \n        # Save model if necessary.  Omitted in this example.\n\n\n`GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_full.py>`__\nprovides an end-to-end homogeneous graph node classification example.\nYou could see the corresponding model implementation is in the\n``GraphSAGE`` class in the example with adjustable number of layers,\ndropout probabilities, and customizable aggregation functions and\nnonlinearities.\n\n.. 
_guide-training-rgcn-node-classification:\n\nHeterogeneous graph\n~~~~~~~~~~~~~~~~~~~\n\nIf your graph is heterogeneous, you may want to gather message from\nneighbors along all edge types. You can use the module\n:class:`dgl.nn.pytorch.HeteroGraphConv` (also available in MXNet and Tensorflow)\nto perform message passing\non all edge types, then combining different graph convolution modules\nfor each edge type.\n\nThe following code will define a heterogeneous graph convolution module\nthat first performs a separate graph convolution on each edge type, then\nsums the message aggregations on each edge type as the final result for\nall node types.\n\n.. code:: python\n\n    # Define a Heterograph Conv model\n\n    class RGCN(nn.Module):\n        def __init__(self, in_feats, hid_feats, out_feats, rel_names):\n            super().__init__()\n            \n            self.conv1 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(in_feats, hid_feats)\n                for rel in rel_names}, aggregate='sum')\n            self.conv2 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(hid_feats, out_feats)\n                for rel in rel_names}, aggregate='sum')\n      \n        def forward(self, graph, inputs):\n            # inputs are features of nodes\n            h = self.conv1(graph, inputs)\n            h = {k: F.relu(v) for k, v in h.items()}\n            h = self.conv2(graph, h)\n            return h\n\n``dgl.nn.HeteroGraphConv`` takes in a dictionary of node types and node\nfeature tensors as input, and returns another dictionary of node types\nand node features.\n\nSo given that we have the user and item features in the\n:ref:`heterogeneous graph example <guide-training-heterogeneous-graph-example>`.\n\n.. 
code:: python\n\n    model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    labels = hetero_graph.nodes['user'].data['label']\n    train_mask = hetero_graph.nodes['user'].data['train_mask']\n\nOne can simply perform a forward propagation as follows:\n\n.. code:: python\n\n    node_features = {'user': user_feats, 'item': item_feats}\n    h_dict = model(hetero_graph, {'user': user_feats, 'item': item_feats})\n    h_user = h_dict['user']\n    h_item = h_dict['item']\n\nTraining loop is the same as the one for homogeneous graph, except that\nnow you have a dictionary of node representations from which you compute\nthe predictions. For instance, if you are only predicting the ``user``\nnodes, you can just extract the ``user`` node embeddings from the\nreturned dictionary:\n\n.. code:: python\n\n    opt = torch.optim.Adam(model.parameters())\n    \n    for epoch in range(5):\n        model.train()\n        # forward propagation by using all nodes and extracting the user embeddings\n        logits = model(hetero_graph, node_features)['user']\n        # compute loss\n        loss = F.cross_entropy(logits[train_mask], labels[train_mask])\n        # Compute validation accuracy.  Omitted in this example.\n        # backward propagation\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n    \n        # Save model if necessary.  Omitted in the example.\n\n\nDGL provides an end-to-end example of\n`RGCN <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify.py>`__\nfor node classification. You can see the definition of heterogeneous\ngraph convolution in ``RelGraphConvLayer`` in the `model implementation\nfile <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/model.py>`__.\n\n\n"
  },
  {
    "path": "docs/source/guide/training.rst",
    "content": ".. _guide-training:\n\nChapter 5: Training Graph Neural Networks\n=====================================================\n\n:ref:`(中文版) <guide_cn-training>`\n\nOverview\n--------\n\nThis chapter discusses how to train a graph neural network for node\nclassification, edge classification, link prediction, and graph\nclassification for small graph(s), by message passing methods introduced\nin :ref:`guide-message-passing` and neural network modules introduced in\n:ref:`guide-nn`.\n\nThis chapter assumes that your graph as well as all of its node and edge\nfeatures can fit into GPU; see :ref:`guide-minibatch` if they cannot.\n\nThe following text assumes that the graph(s) and node/edge features are\nalready prepared. If you plan to use the dataset DGL provides or other\ncompatible ``DGLDataset`` as is described in :ref:`guide-data-pipeline`, you can\nget the graph for a single-graph dataset with something like\n\n.. code:: python\n\n    import dgl\n\n    dataset = dgl.data.CiteseerGraphDataset()\n    graph = dataset[0]\n\n\nNote: In this chapter we will use PyTorch as backend.\n\n.. _guide-training-heterogeneous-graph-example:\n\nHeterogeneous Graphs\n~~~~~~~~~~~~~~~~~~~~\n\nSometimes you would like to work on heterogeneous graphs. Here we take a\nsynthetic heterogeneous graph as an example for demonstrating node\nclassification, edge classification, and link prediction tasks.\n\nThe synthetic heterogeneous graph ``hetero_graph`` has these edge types:\n\n-  ``('user', 'follow', 'user')``\n-  ``('user', 'followed-by', 'user')``\n-  ``('user', 'click', 'item')``\n-  ``('item', 'clicked-by', 'user')``\n-  ``('user', 'dislike', 'item')``\n-  ``('item', 'disliked-by', 'user')``\n\n.. 
code:: python\n\n    import numpy as np\n    import torch\n\n    n_users = 1000\n    n_items = 500\n    n_follows = 3000\n    n_clicks = 5000\n    n_dislikes = 500\n    n_hetero_features = 10\n    n_user_classes = 5\n    n_max_clicks = 10\n\n    follow_src = np.random.randint(0, n_users, n_follows)\n    follow_dst = np.random.randint(0, n_users, n_follows)\n    click_src = np.random.randint(0, n_users, n_clicks)\n    click_dst = np.random.randint(0, n_items, n_clicks)\n    dislike_src = np.random.randint(0, n_users, n_dislikes)\n    dislike_dst = np.random.randint(0, n_items, n_dislikes)\n\n    hetero_graph = dgl.heterograph({\n        ('user', 'follow', 'user'): (follow_src, follow_dst),\n        ('user', 'followed-by', 'user'): (follow_dst, follow_src),\n        ('user', 'click', 'item'): (click_src, click_dst),\n        ('item', 'clicked-by', 'user'): (click_dst, click_src),\n        ('user', 'dislike', 'item'): (dislike_src, dislike_dst),\n        ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})\n\n    hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)\n    hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)\n    hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))\n    hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()\n    # randomly generate training masks on user nodes and click edges\n    hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)\n    hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)\n\n\nRoadmap\n------------\n\nThe chapter has five sections, each for one type of graph learning tasks.\n\n* :ref:`guide-training-node-classification`\n* :ref:`guide-training-edge-classification`\n* :ref:`guide-training-link-prediction`\n* :ref:`guide-training-graph-classification`\n* 
:ref:`guide-training-eweight`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    training-node\n    training-edge\n    training-link\n    training-graph\n    training-eweight\n"
  },
  {
    "path": "docs/source/guide_cn/data-dataset.rst",
    "content": ".. _guide_cn-data-pipeline-dataset:\n\n4.1 DGLDataset类\n--------------------\n\n:ref:`(English Version) <guide-data-pipeline-dataset>`\n\n:class:`~dgl.data.DGLDataset` 是处理、导入和保存 :ref:`apidata` 中定义的图数据集的基类。\n它实现了用于处理图数据的基本模版。下面的流程图展示了这个模版的工作方式。\n\n.. figure:: https://data.dgl.ai/asset/image/userguide_data_flow.png\n    :align: center\n\n    在类DGLDataset中定义的图数据处理模版的流程图。\n\n为了处理位于远程服务器或本地磁盘上的图数据集，下面的例子中定义了一个类，称为 ``MyDataset``,\n它继承自 :class:`dgl.data.DGLDataset`。\n\n.. code::\n\n    from dgl.data import DGLDataset\n    \n    class MyDataset(DGLDataset):\n        \"\"\" 用于在DGL中自定义图数据集的模板：\n    \n        Parameters\n        ----------\n        url : str\n            下载原始数据集的url。\n        raw_dir : str\n            指定下载数据的存储目录或已下载数据的存储目录。默认: ~/.dgl/\n        save_dir : str\n            处理完成的数据集的保存目录。默认：raw_dir指定的值\n        force_reload : bool\n            是否重新导入数据集。默认：False\n        verbose : bool\n            是否打印进度信息。\n        \"\"\"\n        def __init__(self, \n                     url=None, \n                     raw_dir=None, \n                     save_dir=None, \n                     force_reload=False, \n                     verbose=False):\n            super(MyDataset, self).__init__(name='dataset_name',\n                                            url=url,\n                                            raw_dir=raw_dir,\n                                            save_dir=save_dir,\n                                            force_reload=force_reload,\n                                            verbose=verbose)\n    \n        def download(self):\n            # 将原始数据下载到本地磁盘\n            pass\n    \n        def process(self):\n            # 将原始数据处理为图、标签和数据集划分的掩码\n            pass\n        \n        def __getitem__(self, idx):\n            # 通过idx得到与之对应的一个样本\n            pass\n    \n        def __len__(self):\n            # 数据样本的数量\n            pass\n    \n        def save(self):\n            # 将处理后的数据保存至 `self.save_path`\n            pass\n    \n  
      def load(self):\n            # 从 `self.save_path` 导入处理后的数据\n            pass\n    \n        def has_cache(self):\n            # 检查在 `self.save_path` 中是否存有处理后的数据\n            pass\n\n:class:`~dgl.data.DGLDataset` 类有抽象函数 ``process()``，\n``__getitem__(idx)`` 和 ``__len__()``。子类必须实现这些函数。同时DGL也建议实现保存和导入函数，\n因为对于处理后的大型数据集，这么做可以节省大量的时间，\n并且有多个已有的API可以简化此操作(请参阅 :ref:`guide_cn-data-pipeline-savenload`)。\n\n请注意， :class:`~dgl.data.DGLDataset` 的目的是提供一种标准且方便的方式来导入图数据。\n用户可以存储有关数据集的图、特征、标签、掩码，以及诸如类别数、标签数等基本信息。\n诸如采样、划分或特征归一化等操作建议在 :class:`~dgl.data.DGLDataset` 子类之外完成。\n\n本章的后续部分展示了实现这些函数的最佳实践。"
  },
  {
    "path": "docs/source/guide_cn/data-download.rst",
    "content": ".. _guide_cn-data-pipeline-download:\n\n4.2 下载原始数据（可选）\n--------------------------------\n\n:ref:`(English Version) <guide-data-pipeline-download>`\n\n如果用户的数据集已经在本地磁盘中，请确保它被存放在目录 ``raw_dir`` 中。\n如果用户想在任何地方运行代码而又不想自己下载数据并将其移动到正确的目录中，则可以通过实现函数 ``download()`` 来自动完成。\n\n如果数据集是一个zip文件，可以直接继承 :class:`dgl.data.DGLBuiltinDataset` 类。后者支持解压缩zip文件。\n否则用户需要自己实现 ``download()``，具体可以参考 :class:`~dgl.data.QM7bDataset` 类：\n\n.. code:: \n\n    import os\n    from dgl.data.utils import download\n    \n    def download(self):\n        # 存储文件的路径\n        file_path = os.path.join(self.raw_dir, self.name + '.mat')\n        # 下载文件\n        download(self.url, path=file_path)\n\n上面的代码将一个.mat文件下载到目录 ``self.raw_dir``。如果文件是.gz、.tar、.tar.gz或.tgz文件，请使用\n:func:`~dgl.data.utils.extract_archive` 函数进行解压缩。以下代码展示了如何在\n:class:`~dgl.data.BitcoinOTCDataset` 类中下载一个.gz文件：\n\n.. code:: \n\n    from dgl.data.utils import download, check_sha1\n    \n    def download(self):\n        # 存储文件的路径，请确保使用与原始文件名相同的后缀\n        gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')\n        # 下载文件\n        download(self.url, path=gz_file_path)\n        # 检查 SHA-1\n        if not check_sha1(gz_file_path, self._sha1_str):\n            raise UserWarning('File {} is downloaded but the content hash does not match.'\n                              'The repo may be outdated or download may be incomplete. '\n                              'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))\n        # 将文件解压缩到目录self.raw_dir下的self.name目录中\n        self._extract_gz(gz_file_path, self.raw_path)\n\n上面的代码会将文件解压缩到 ``self.raw_dir`` 下的目录 ``self.name`` 中。\n如果该类继承自 :class:`dgl.data.DGLBuiltinDataset` 来处理zip文件，\n则它也会将文件解压缩到目录 ``self.name`` 中。\n\n一个可选项是用户可以按照上面的示例检查下载后文件的SHA-1字符串，以防作者在远程服务器上更改了文件。"
  },
  {
    "path": "docs/source/guide_cn/data-loadogb.rst",
    "content": ".. _guide_cn-data-pipeline-loadogb:\n\n4.5 使用ogb包导入OGB数据集\n----------------------------------------------\n\n:ref:`(English Version) <guide-data-pipeline-loadogb>`\n\n`Open Graph Benchmark (OGB) <https://ogb.stanford.edu/docs/home/>`__ 是一个图深度学习的基准数据集。\n官方的 `ogb <https://github.com/snap-stanford/ogb>`__ 包提供了用于下载OGB数据集并将其处理为\n:class:`dgl.DGLGraph` 对象的API。本节会介绍它们的基本用法。\n\n首先使用pip安装ogb包：\n\n.. code:: \n\n    pip install ogb\n\n\n以下代码显示了如何为 *Graph Property Prediction* 任务加载数据集。\n\n.. code:: \n\n    # 载入OGB的Graph Property Prediction数据集\n    import dgl\n    import torch\n    from ogb.graphproppred import DglGraphPropPredDataset\n    from dgl.dataloading import GraphDataLoader\n    \n    def _collate_fn(batch):\n        # 小批次是一个元组(graph, label)列表\n        graphs = [e[0] for e in batch]\n        g = dgl.batch(graphs)\n        labels = [e[1] for e in batch]\n        labels = torch.stack(labels, 0)\n        return g, labels\n    \n    # 载入数据集\n    dataset = DglGraphPropPredDataset(name='ogbg-molhiv')\n    split_idx = dataset.get_idx_split()\n    # dataloader\n    train_loader = GraphDataLoader(dataset[split_idx[\"train\"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)\n    valid_loader = GraphDataLoader(dataset[split_idx[\"valid\"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)\n    test_loader = GraphDataLoader(dataset[split_idx[\"test\"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)\n\n加载 *Node Property Prediction* 数据集类似，但要注意的是这种数据集只有一个图对象。\n\n.. 
code:: \n\n    # 载入OGB的Node Property Prediction数据集\n    from ogb.nodeproppred import DglNodePropPredDataset\n    \n    dataset = DglNodePropPredDataset(name='ogbn-proteins')\n    split_idx = dataset.get_idx_split()\n    \n    # there is only one graph in Node Property Prediction datasets\n    # 在Node Property Prediction数据集里只有一个图\n    g, labels = dataset[0]\n    # 获取划分的标签\n    train_label = dataset.labels[split_idx['train']]\n    valid_label = dataset.labels[split_idx['valid']]\n    test_label = dataset.labels[split_idx['test']]\n\n每个 *Link Property Prediction* 数据集也只包括一个图。\n\n.. code::\n\n    # 载入OGB的Link Property Prediction数据集\n    from ogb.linkproppred import DglLinkPropPredDataset\n    \n    dataset = DglLinkPropPredDataset(name='ogbl-ppa')\n    split_edge = dataset.get_edge_split()\n    \n    graph = dataset[0]\n    print(split_edge['train'].keys())\n    print(split_edge['valid'].keys())\n    print(split_edge['test'].keys())\n"
  },
  {
    "path": "docs/source/guide_cn/data-process.rst",
    "content": ".. _guide_cn-data-pipeline-process:\n\n4.3 处理数据\n----------------\n\n:ref:`(English Version) <guide-data-pipeline-process>`\n\n用户可以在 ``process()`` 函数中实现数据处理。该函数假定原始数据已经位于 ``self.raw_dir`` 目录中。\n\n图上的机器学习任务通常有三种类型：整图分类、节点分类和链接预测。本节将展示如何处理与这些任务相关的数据集。\n\n本节重点介绍了处理图、特征和划分掩码的标准方法。用户指南将以内置数据集为例，并跳过从文件构建图的实现。\n用户可以参考 :ref:`guide_cn-graph-external` 以查看如何从外部数据源构建图的完整指南。\n\n处理整图分类数据集\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n整图分类数据集与用小批次训练的典型机器学习任务中的大多数数据集类似。\n因此，需要将原始数据处理为 :class:`dgl.DGLGraph` 对象的列表和标签张量的列表。\n此外，如果原始数据已被拆分为多个文件，则可以添加参数 ``split`` 以导入数据的特定部分。\n\n下面以 :class:`~dgl.data.QM7bDataset` 为例：\n\n.. code::\n\n    from dgl.data import DGLDataset\n\n    class QM7bDataset(DGLDataset):\n        _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \\\n               'datasets/qm7b.mat'\n        _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'\n\n        def __init__(self, raw_dir=None, force_reload=False, verbose=False):\n            super(QM7bDataset, self).__init__(name='qm7b',\n                                              url=self._url,\n                                              raw_dir=raw_dir,\n                                              force_reload=force_reload,\n                                              verbose=verbose)\n\n        def process(self):\n            mat_path = self.raw_path + '.mat'\n            # 将数据处理为图列表和标签列表\n            self.graphs, self.label = self._load_graph(mat_path)\n\n        def __getitem__(self, idx):\n            \"\"\" 通过idx获取对应的图和标签\n\n            Parameters\n            ----------\n            idx : int\n                Item index\n\n            Returns\n            -------\n            (dgl.DGLGraph, Tensor)\n            \"\"\"\n            return self.graphs[idx], self.label[idx]\n\n        def __len__(self):\n            \"\"\"数据集中图的数量\"\"\"\n            return len(self.graphs)\n\n函数 ``process()`` 将原始数据处理为图列表和标签列表。用户必须实现 ``__getitem__(idx)`` 和  ``__len__()`` 以进行迭代。\nDGL建议让 
``__getitem__(idx)`` 返回如上面代码所示的元组 ``(图，标签)``。\n用户可以参考 `QM7bDataset源代码  <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/qm7b.html#QM7bDataset>`__\n以获得 ``self._load_graph()`` 和 ``__getitem__`` 的详细信息。\n\n用户还可以向类添加属性以指示一些有用的数据集信息。在 :class:`~dgl.data.QM7bDataset` 中，\n用户可以添加属性 ``num_tasks`` 来指示此多任务数据集中的预测任务总数：\n\n.. code::\n\n    @property\n    def num_tasks(self):\n        \"\"\"每个图的标签数，即预测任务数。\"\"\"\n        return 14\n\n在编写完这些代码之后，用户可以按如下所示的方式来使用 :class:`~dgl.data.QM7bDataset`：\n\n.. code::\n\n    import dgl\n    import torch\n\n    from dgl.dataloading import GraphDataLoader\n\n    # 数据导入\n    dataset = QM7bDataset()\n    num_tasks = dataset.num_tasks\n\n    # 创建 dataloaders\n    dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)\n\n    # 训练\n    for epoch in range(100):\n        for g, labels in dataloader:\n            # 用户自己的训练代码\n            pass\n\n训练整图分类模型的完整指南可以在 :ref:`guide_cn-training-graph-classification` 中找到。\n\n有关整图分类数据集的更多示例，用户可以参考 :ref:`guide_cn-training-graph-classification`：\n\n* :ref:`gindataset`\n\n* :ref:`minigcdataset`\n\n* :ref:`qm7bdata`\n\n* :ref:`tudata`\n\n处理节点分类数据集\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n与整图分类不同，节点分类通常在单个图上进行。因此数据集的划分是在图的节点集上进行。\nDGL建议使用节点掩码来指定数据集的划分。\n本节以内置数据集 `CitationGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/citation_graph.html#CitationGraphDataset>`__ 为例：\n\n此外，DGL推荐重新排列图的节点/边，使得相邻节点/边的ID位于邻近区间内。这个过程\n可以提高节点/边的邻居的局部性，为后续在图上进行的计算与分析的性能改善提供可能。\nDGL提供了名为 :func:`dgl.reorder_graph` 的API用于此优化。更多细节，请参考\n下面例子中的 ``process()`` 的部分。\n\n.. 
code::\n\n    from dgl.data import DGLBuiltinDataset\n    from dgl.data.utils import _get_dgl_url\n\n    class CitationGraphDataset(DGLBuiltinDataset):\n        _urls = {\n            'cora_v2' : 'dataset/cora_v2.zip',\n            'citeseer' : 'dataset/citeseer.zip',\n            'pubmed' : 'dataset/pubmed.zip',\n        }\n\n        def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):\n            assert name.lower() in ['cora', 'citeseer', 'pubmed']\n            if name.lower() == 'cora':\n                name = 'cora_v2'\n            url = _get_dgl_url(self._urls[name])\n            super(CitationGraphDataset, self).__init__(name,\n                                                       url=url,\n                                                       raw_dir=raw_dir,\n                                                       force_reload=force_reload,\n                                                       verbose=verbose)\n\n        def process(self):\n            # 跳过一些处理的代码\n            # === 跳过数据处理 ===\n\n            # 构建图\n            g = dgl.graph(graph)\n\n            # 划分掩码\n            g.ndata['train_mask'] = train_mask\n            g.ndata['val_mask'] = val_mask\n            g.ndata['test_mask'] = test_mask\n\n            # 节点的标签\n            g.ndata['label'] = torch.tensor(labels)\n\n            # 节点的特征\n            g.ndata['feat'] = torch.tensor(_preprocess_features(features),\n                                           dtype=F.data_type_dict['float32'])\n            self._num_tasks = onehot_labels.shape[1]\n            self._labels = labels\n            # 重排图以获得更优的局部性\n            self._g = dgl.reorder_graph(g)\n\n        def __getitem__(self, idx):\n            assert idx == 0, \"这个数据集里只有一个图\"\n            return self._g\n\n        def __len__(self):\n            return 1\n\n为简便起见，这里省略了 ``process()`` 中的一些代码，以突出展示用于处理节点分类数据集的关键部分：划分掩码。\n节点特征和节点的标签被存储在 ``g.ndata`` 中。详细的实现请参考\n`CitationGraphDataset源代码 
<https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/citation_graph.html#CitationGraphDataset>`__ 。\n\n请注意，这里 ``__getitem__(idx)`` 和 ``__len__()`` 的实现也发生了变化，\n这是因为节点分类任务通常只用一个图。掩码在PyTorch和TensorFlow中是bool张量，在MXNet中是float张量。\n\n下面中使用 :class:`dgl.data.CitationGraphDataset` 的子类 :class:`dgl.data.CiteseerGraphDataset`\n来演示如何使用用于节点分类的数据集：\n\n.. code::\n\n    # 导入数据\n    dataset = CiteseerGraphDataset(raw_dir='')\n    graph = dataset[0]\n\n    # 获取划分的掩码\n    train_mask = graph.ndata['train_mask']\n    val_mask = graph.ndata['val_mask']\n    test_mask = graph.ndata['test_mask']\n\n    # 获取节点特征\n    feats = graph.ndata['feat']\n\n    # 获取标签\n    labels = graph.ndata['label']\n\n:ref:`guide_cn-training-node-classification` 提供了训练节点分类模型的完整指南。\n\n有关节点分类数据集的更多示例，用户可以参考以下内置数据集：\n\n* :ref:`citationdata`\n\n* :ref:`corafulldata`\n\n* :ref:`amazoncobuydata`\n\n* :ref:`coauthordata`\n\n* :ref:`karateclubdata`\n\n* :ref:`ppidata`\n\n* :ref:`redditdata`\n\n* :ref:`sbmdata`\n\n* :ref:`sstdata`\n\n* :ref:`rdfdata`\n\n处理链接预测数据集\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n链接预测数据集的处理与节点分类相似，数据集中通常只有一个图。\n\n本节以内置的数据集 `KnowledgeGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/knowledge_graph.html#KnowledgeGraphDataset>`__\n为例，同时省略了详细的数据处理代码以突出展示处理链接预测数据集的关键部分：\n\n.. 
code::\n\n    # 创建链接预测数据集示例\n    class KnowledgeGraphDataset(DGLBuiltinDataset):\n        def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):\n            self._name = name\n            self.reverse = reverse\n            url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)\n            super(KnowledgeGraphDataset, self).__init__(name,\n                                                        url=url,\n                                                        raw_dir=raw_dir,\n                                                        force_reload=force_reload,\n                                                        verbose=verbose)\n\n        def process(self):\n            # 跳过一些处理的代码\n            # === 跳过数据处理 ===\n\n            # 划分掩码\n            g.edata['train_mask'] = train_mask\n            g.edata['val_mask'] = val_mask\n            g.edata['test_mask'] = test_mask\n\n            # 边类型\n            g.edata['etype'] = etype\n\n            # 节点类型\n            g.ndata['ntype'] = ntype\n            self._g = g\n\n        def __getitem__(self, idx):\n            assert idx == 0, \"这个数据集只有一个图\"\n            return self._g\n\n        def __len__(self):\n            return 1\n\n\n如代码所示，图的 ``edata`` 存储了划分掩码。在\n`KnowledgeGraphDataset 源代码 <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/knowledge_graph.html#KnowledgeGraphDataset>`__\n中可以查看完整的代码。下面使用 ``KnowledgeGraphDataset`` 的子类 :class:`dgl.data.FB15k237Dataset` 来演示如何使用用于链接预测的数据集：\n\n.. code::\n\n    from dgl.data import FB15k237Dataset\n\n    # 导入数据\n    dataset = FB15k237Dataset()\n    graph = dataset[0]\n\n    # 获取训练集掩码\n    train_mask = graph.edata['train_mask']\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()\n    src, dst = graph.edges(train_idx)\n\n    # 获取训练集中的边类型\n    rel = graph.edata['etype'][train_idx]\n\n有关训练链接预测模型的完整指南，请参见 :ref:`guide_cn-training-link-prediction`。\n\n有关链接预测数据集的更多示例，请参考DGL的内置数据集：\n\n* :ref:`kgdata`\n\n* :ref:`bitcoinotcdata`\n"
  },
  {
    "path": "docs/source/guide_cn/data-savenload.rst",
    "content": ".. _guide_cn-data-pipeline-savenload:\n\n4.4 保存和加载数据\n----------------------\n\n:ref:`(English Version) <guide-data-pipeline-savenload>`\n\nDGL建议用户实现保存和加载数据的函数，将处理后的数据缓存在本地磁盘中。\n这样在多数情况下可以帮用户节省大量的数据处理时间。DGL提供了4个函数让任务变得简单。\n\n-  :func:`dgl.save_graphs` 和 :func:`dgl.load_graphs`: 保存DGLGraph对象和标签到本地磁盘和从本地磁盘读取它们。\n-  :func:`dgl.data.utils.save_info` 和 :func:`dgl.data.utils.load_info`: 将数据集的有用信息(python dict对象)保存到本地磁盘和从本地磁盘读取它们。\n\n下面的示例显示了如何保存和读取图和数据集信息的列表。\n\n.. code:: \n\n    import os\n    from dgl import save_graphs, load_graphs\n    from dgl.data.utils import makedirs, save_info, load_info\n    \n    def save(self):\n        # 保存图和标签\n        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')\n        save_graphs(graph_path, self.graphs, {'labels': self.labels})\n        # 在Python字典里保存其他信息\n        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')\n        save_info(info_path, {'num_classes': self.num_classes})\n    \n    def load(self):\n        # 从目录 `self.save_path` 里读取处理过的数据\n        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')\n        self.graphs, label_dict = load_graphs(graph_path)\n        self.labels = label_dict['labels']\n        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')\n        self.num_classes = load_info(info_path)['num_classes']\n    \n    def has_cache(self):\n        # 检查在 `self.save_path` 里是否有处理过的数据文件\n        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')\n        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')\n        return os.path.exists(graph_path) and os.path.exists(info_path)\n\n请注意：有些情况下不适合保存处理过的数据。例如，在内置数据集 :class:`~dgl.data.GDELTDataset` 中，\n处理过的数据比较大。所以这个时候，在 ``__getitem__(idx)`` 中处理每个数据实例是更高效的方法。\n"
  },
  {
    "path": "docs/source/guide_cn/data.rst",
    "content": ".. _guide_cn-data-pipeline:\n\n第4章：图数据处理管道\n==============================\n\n:ref:`(English Version) <guide-data-pipeline>`\n\nDGL在 :ref:`apidata` 里实现了很多常用的图数据集。它们遵循了由 :class:`dgl.data.DGLDataset` 类定义的标准的数据处理管道。\nDGL推荐用户将图数据处理为 :class:`dgl.data.DGLDataset` 的子类。该类为导入、处理和保存图数据提供了简单而干净的解决方案。\n\n本章路线图\n-----------\n\n本章介绍了如何为用户自己的图数据创建一个DGL数据集。以下内容说明了管道的工作方式，并展示了如何实现管道的每个组件。\n\n* :ref:`guide_cn-data-pipeline-dataset`\n* :ref:`guide_cn-data-pipeline-download`\n* :ref:`guide_cn-data-pipeline-process`\n* :ref:`guide_cn-data-pipeline-savenload`\n* :ref:`guide_cn-data-pipeline-loadogb`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    data-dataset\n    data-download\n    data-process\n    data-savenload\n    data-loadogb"
  },
  {
    "path": "docs/source/guide_cn/distributed-apis.rst",
    "content": ".. _guide_cn-distributed-apis:\n\n7.2 分布式计算的API\n--------------------\n\n:ref:`(English Version) <guide-distributed-apis>`\n\n本节介绍了在训练脚本中使用的分布式计算API。DGL提供了三种分布式数据结构和多种API，用于初始化、分布式采样和数据分割。\n对于分布式训练/推断，DGL提供了三种分布式数据结构：用于分布式图的 :class:`~dgl.distributed.DistGraph`、\n用于分布式张量的 :class:`~dgl.distributed.DistTensor` 和用于分布式可学习嵌入的\n:class:`~dgl.distributed.DistEmbedding`。\n\nDGL分布式模块的初始化\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n:func:`~dgl.distributed.initialize` 可以用于初始化分布式模块。当训练脚本在训练器模式下运行时，\n这个API会与DGL服务器建立连接并创建采样器进程。当脚本在服务器模式下运行时，这个API将运行服务器代码，\n直到训练任务结束。必须在DGL的任何其他分布式API之前，调用此API。在使用PyTorch时，必须在\n``torch.distributed.init_process_group`` 之前调用 :func:`~dgl.distributed.initialize`。\n通常，初始化API应按以下顺序调用：\n\n.. code:: python\n\n    dgl.distributed.initialize('ip_config.txt')\n    th.distributed.init_process_group(backend='gloo')\n\n**Note**: 如果训练脚本里包含需要在服务器(细节内容可以在下面的DistTensor和DistEmbedding章节里查看)上调用的用户自定义函数(UDF)，\n这些UDF必须在 :func:`~dgl.distributed.initialize` 之前被声明。\n\n分布式图\n~~~~~~~~~~~~~~~~~\n\n:class:`~dgl.distributed.DistGraph` 是一个Python类，用于访问计算机集群中的图结构和节点/边特征。每台计算机负责一个且只负责一个分区。\n它加载分区数据(包括分区中的图结构、节点数据和边数据)，并使集群中的所有训练器均可访问它们。\n:class:`~dgl.distributed.DistGraph` 提供了一小部分 :class:`~dgl.DGLGraph` 的API以方便数据访问。\n\n**Note**: :class:`~dgl.distributed.DistGraph` 当前仅支持一种节点类型和一种边类型的图。\n\n分布式模式与独立模式\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n:class:`~dgl.distributed.DistGraph` 可以在两种模式下运行：分布式模式和独立模式。\n当用户在Python命令行或Jupyter Notebook中执行训练脚本时，它将以独立模式运行。也就是说，它在单个进程中运行所有计算，\n并且不与任何其他进程通信。因此，独立模式要求输入图仅具有一个分区。此模式主要用于开发和测试\n(例如，在Jupyter Notebook中开发和运行代码)。当用户使用启动脚本执行训练脚本时(请参见启动脚本部分)，\n:class:`~dgl.distributed.DistGraph` 将以分布式模式运行。启动脚本在后台启动服务器(包括访问节点/边特征和图采样)，\n并将分区数据自动加载到每台计算机中。:class:`~dgl.distributed.DistGraph` 与集群中的服务器连接并通过网络访问它们。\n\n创建DistGraph\n^^^^^^^^^^^^^^^^^^\n\n在分布式模式下，:class:`~dgl.distributed.DistGraph` 的创建需要(定义)在图划分期间的图名称。\n图名称标识了集群中所需加载的图。\n\n.. 
code:: python\n\n    import dgl\n    g = dgl.distributed.DistGraph('graph_name')\n\n在独立模式下运行时，DistGraph将图数据加载到本地计算机中。因此，用户需要提供分区配置文件，其中包含有关输入图的所有信息。\n\n.. code:: python\n\n    import dgl\n    g = dgl.distributed.DistGraph('graph_name', part_config='data/graph_name.json')\n\n**Note**: 在当前实现中，DGL仅允许创建单个DistGraph对象。销毁DistGraph并创建一个新DistGraph的行为没有被定义。\n\n访问图结构\n^^^^^^^^^^^^^^^^^^^^^^\n\n:class:`~dgl.distributed.DistGraph` 提供了几个API来访问图结构。当前，它们主要被用来提供图信息，例如节点和边的数量。\n主要应用场景是运行采样API以支持小批量训练(请参阅下文里分布式图采样部分)。\n\n.. code:: python\n\n    print(g.num_nodes())\n\n访问节点/边数据\n^^^^^^^^^^^^^^^^^^^^^\n\n与 :class:`~dgl.DGLGraph` 一样， :class:`~dgl.distributed.DistGraph` 也提供了\n``ndata`` 和 ``edata`` 来访问节点和边中的数据。它们的区别在于\n:class:`~dgl.distributed.DistGraph` 中的 ``ndata`` / ``edata`` 返回的是 :class:`~dgl.distributed.DistTensor`，\n而不是底层框架里的张量。用户还可以将新的 :class:`~dgl.distributed.DistTensor` 分配给\n:class:`~dgl.distributed.DistGraph` 作为节点数据或边数据。\n\n.. code:: python\n\n    g.ndata['train_mask']\n    <dgl.distributed.dist_graph.DistTensor at 0x7fec820937b8>\n    g.ndata['train_mask'][0]\n    tensor([1], dtype=torch.uint8)\n\n分布式张量\n~~~~~~~~~~~~~~~~~\n\n如前所述，在分布式模式下，DGL会划分节点和边特征，并将它们存储在计算机集群中。\nDGL为分布式张量提供了类似于单机普通张量的接口，以访问群集中的分区节点和边特征。\n在分布式设置中，DGL仅支持密集节点和边特征，暂不支持稀疏节点和边特征。\n\n:class:`~dgl.distributed.DistTensor` 管理在多个计算机中被划分和存储的密集张量。\n目前，分布式张量必须与图的节点或边相关联。换句话说，DistTensor中的行数必须与图中的节点数或边数相同。\n以下代码创建一个分布式张量。 除了张量的形状和数据类型之外，用户还可以提供唯一的张量名称。\n如果用户要引用一个固定的分布式张量(即使 :class:`~dgl.distributed.DistTensor` 对象消失，该名称仍存在于群集中)，\n则(使用这样的)名称就很有用。\n\n.. code:: python\n\n    tensor = dgl.distributed.DistTensor((g.num_nodes(), 10), th.float32, name='test')\n\n**Note**: :class:`~dgl.distributed.DistTensor` 的创建是一个同步操作。所有训练器都必须调用创建，\n并且只有当所有训练器都调用它时，此创建过程才能成功。\n\n用户可以将 :class:`~dgl.distributed.DistTensor` 作为节点数据或边数据之一添加到\n:class:`~dgl.distributed.DistGraph` 对象。\n\n.. 
code:: python\n\n    g.ndata['feat'] = tensor\n\n**Note**: 节点数据名称和张量名称不必相同。前者在 :class:`~dgl.distributed.DistGraph` 中标识节点数据(在训练器进程中)，\n而后者则标识DGL服务器中的分布式张量。\n\n:class:`~dgl.distributed.DistTensor` 提供了一些功能。它具有与常规张量相同的API，用于访问其元数据，\n例如形状和数据类型。:class:`~dgl.distributed.DistTensor` 支持索引读取和写入，\n但不支持一些计算运算符，例如求和以及求均值。\n\n.. code:: python\n\n    data = g.ndata['feat'][[1, 2, 3]]\n    print(data)\n    g.ndata['feat'][[3, 4, 5]] = data\n\n**Note**: 当前，当一台机器运行多个服务器时，DGL不提供对来自多个训练器的并发写入的保护。\n这可能会导致数据损坏。\n\n分布式嵌入\n~~~~~~~~~~~~~~~~~~~~~\n\nDGL提供 :class:`~dgl.distributed.DistEmbedding` 以支持需要节点嵌入的直推(transductive)模型。\n分布式嵌入的创建与分布式张量的创建非常相似。\n\n.. code:: python\n\n    def initializer(shape, dtype):\n        arr = th.zeros(shape, dtype=dtype)\n        arr.uniform_(-1, 1)\n        return arr\n    emb = dgl.distributed.DistEmbedding(g.num_nodes(), 10, init_func=initializer)\n\n在内部，分布式嵌入建立在分布式张量之上，因此，其行为与分布式张量非常相似。\n例如，创建嵌入时，DGL会将它们分片并存储在集群中的所有计算机上。(分布式嵌入)可以通过名称唯一标识。\n\n**Note**: 服务器进程负责调用初始化函数。因此，必须在调用 :func:`~dgl.distributed.initialize` 之前声明该初始化函数。\n\n因为嵌入是模型的一部分，所以用户必须将其附加到优化器上以进行小批量训练。当前，\nDGL提供了一个稀疏的Adagrad优化器 :class:`~dgl.distributed.SparseAdagrad` (DGL以后将为稀疏嵌入添加更多的优化器)。\n用户需要从模型中收集所有分布式嵌入，并将它们传递给稀疏优化器。如果模型同时具有节点嵌入和常规的密集模型参数，\n并且用户希望对嵌入执行稀疏更新，则需要创建两个优化器，一个用于节点嵌入，另一个用于密集模型参数，如以下代码所示：\n\n.. 
code:: python\n\n    sparse_optimizer = dgl.distributed.SparseAdagrad([emb], lr=lr1)\n    optimizer = th.optim.Adam(model.parameters(), lr=lr2)\n    feats = emb(nids)\n    loss = model(feats)\n    loss.backward()\n    optimizer.step()\n    sparse_optimizer.step()\n\n**Note**: :class:`~dgl.distributed.DistEmbedding` 不是PyTorch的nn模块，因此用户无法从nn模块的参数访问它。\n\n分布式采样\n~~~~~~~~~~~~~~~~~~~~\n\nDGL提供了两个级别的API，用于对节点和边进行采样以生成小批次训练数据(请参阅小批次训练的章节)。\n底层API要求用户编写代码以明确定义如何对节点层进行采样(例如，使用 :func:`dgl.sampling.sample_neighbors` )。\n高层采样API为节点分类和链接预测任务实现了一些流行的采样算法（例如\n:class:`~dgl.dataloading.pytorch.NodeDataLoader`\n和\n:class:`~dgl.dataloading.pytorch.EdgeDataLoader` )。\n\n分布式采样模块遵循相同的设计，也提供两个级别的采样API。对于底层的采样API，它为\n:class:`~dgl.distributed.DistGraph` 上的分布式邻居采样提供了\n:func:`~dgl.distributed.sample_neighbors`。另外，DGL提供了用于分布式采样的分布式数据加载器(\n:class:`~dgl.distributed.DistDataLoader`)。除了用户在创建数据加载器时无法指定工作进程的数量，\n分布式数据加载器具有与PyTorch DataLoader相同的接口。其中的工作进程(worker)在 :func:`dgl.distributed.initialize` 中创建。\n\n**Note**: 在 :class:`~dgl.distributed.DistGraph` 上运行 :func:`dgl.distributed.sample_neighbors` 时，\n采样器无法在具有多个工作进程的PyTorch DataLoader中运行。主要原因是PyTorch DataLoader在每个训练周期都会创建新的采样工作进程，\n从而导致多次创建和删除 :class:`~dgl.distributed.DistGraph` 对象。\n\n使用底层API时，采样代码类似于单进程采样。唯一的区别是用户需要使用\n:func:`dgl.distributed.sample_neighbors`\n和\n:class:`~dgl.distributed.DistDataLoader`。\n\n.. 
code:: python\n\n    def sample_blocks(seeds):\n        seeds = th.LongTensor(np.asarray(seeds))\n        blocks = []\n        for fanout in [10, 25]:\n            frontier = dgl.distributed.sample_neighbors(g, seeds, fanout, replace=True)\n            block = dgl.to_block(frontier, seeds)\n            seeds = block.srcdata[dgl.NID]\n            blocks.insert(0, block)\n            return blocks\n        dataloader = dgl.distributed.DistDataLoader(dataset=train_nid,\n                                                    batch_size=batch_size,\n                                                    collate_fn=sample_blocks,\n                                                    shuffle=True)\n        for batch in dataloader:\n            ...\n\n:class:`~dgl.dataloading.pytorch.NodeDataLoader`\n和\n:class:`~dgl.dataloading.pytorch.EdgeDataLoader` 有分布式的版本\n:class:`~dgl.dataloading.pytorch.DistNodeDataLoader`\n和\n:class:`~dgl.dataloading.pytorch.DistEdgeDataLoader` 。使用\n时分布式采样代码与单进程采样几乎完全相同。\n\n.. code:: python\n\n    sampler = dgl.sampling.MultiLayerNeighborSampler([10, 25])\n    dataloader = dgl.sampling.DistNodeDataLoader(g, train_nid, sampler,\n                                                 batch_size=batch_size, shuffle=True)\n    for batch in dataloader:\n        ...\n\n\n分割数据集\n~~~~~~~~~~~~~~~\n\n用户需要分割训练集，以便每个训练器都可以使用自己的训练集子集。同样，用户还需要以相同的方式分割验证和测试集。\n\n对于分布式训练和评估，推荐的方法是使用布尔数组表示训练、验证和测试集。对于节点分类任务，\n这些布尔数组的长度是图中节点的数量，并且它们的每个元素都表示训练/验证/测试集中是否存在对应节点。\n链接预测任务也应使用类似的布尔数组。\n\nDGL提供了 :func:`~dgl.distributed.node_split` 和 :func:`~dgl.distributed.edge_split`\n函数来在运行时拆分训练、验证和测试集，以进行分布式训练。这两个函数将布尔数组作为输入，对其进行拆分，并向本地训练器返回一部分。\n默认情况下，它们确保所有部分都具有相同数量的节点和边。这对于同步SGD非常重要，\n因为同步SGD会假定每个训练器具有相同数量的小批次。\n\n下面的示例演示了训练集拆分，并向本地进程返回节点的子集。\n\n.. code:: python\n\n    train_nids = dgl.distributed.node_split(g.ndata['train_mask'])\n\n"
  },
  {
    "path": "docs/source/guide_cn/distributed-preprocessing.rst",
    "content": ".. _guide_cn-distributed-preprocessing:\n\n7.1 分布式训练所需的图数据预处理\n------------------------------------------\n\n:ref:`(English Version) <guide-distributed-preprocessing>`\n\nDGL要求预处理图数据以进行分布式训练，这包括两个步骤：1)将一张图划分为多张子图(分区)，2)为节点和边分配新的ID。\nDGL提供了一个API以执行这两个步骤。该API支持随机划分和一个基于\n`Metis <http://glaros.dtc.umn.edu/gkhome/views/metis>`__ 的划分。Metis划分的好处在于，\n它可以用最少的边分割以生成分区，从而减少了用于分布式训练和推理的网络通信。DGL使用最新版本的Metis，\n并针对真实世界中具有幂律分布的图进行了优化。在图划分后，API以易于在训练期间加载的格式构造划分结果。\n\n**Note**: 图划分API当前在一台机器上运行。 因此如果一张图很大，用户将需要一台大内存的机器来对图进行划分。\n未来DGL将支持分布式图划分。\n\n默认情况下，为了在分布式训练/推理期间定位节点/边，API将新ID分配给输入图的节点和边。\n分配ID后，该API会相应地打乱所有节点数据和边数据。在训练期间，用户只需使用新的节点和边的ID。\n与此同时，用户仍然可以通过 ``g.ndata['orig_id']`` 和 ``g.edata['orig_id']`` 获取原始ID。\n其中 ``g`` 是 ``DistGraph`` 对象（详细解释，请参见 :ref:`guide_cn-distributed-apis`）。\n\nDGL将图划分结果存储在输出目录中的多个文件中。输出目录里始终包含一个名为xxx.json的JSON文件，其中xxx是提供给划分API的图的名称。\nJSON文件包含所有划分的配置。如果该API没有为节点和边分配新ID，它将生成两个额外的NumPy文件：`node_map.npy` 和 `edge_map.npy`。\n它们存储节点和边ID与分区ID之间的映射。对于具有十亿级数量节点和边的图，两个文件中的NumPy数组会很大，\n这是因为图中的每个节点和边都对应一个条目。在每个分区的文件夹内，有3个文件以DGL格式存储分区数据。\n`graph.dgl` 存储分区的图结构以及节点和边上的一些元数据。`node_feats.dgl` 和 `edge_feats.dgl` 存储属于该分区的节点和边的所有特征。\n\n.. 
code-block:: none\n\n    data_root_dir/\n        |-- xxx.json             # JSON中的分区配置文件\n        |-- node_map.npy         # 存储在NumPy数组中的每个节点的分区ID（可选）\n        |-- edge_map.npy         # 存储在NumPy数组中的每个边的分区ID（可选）\n        |-- part0/               # 分区0的数据\n            |-- node_feats.dgl   # 以二进制格式存储的节点特征\n            |-- edge_feats.dgl   # 以二进制格式存储的边特征\n            |-- graph.dgl        # 以二进制格式存储的子图结构\n        |-- part1/               # 分区1的数据\n            |-- node_feats.dgl\n            |-- edge_feats.dgl\n            |-- graph.dgl\n\n负载均衡\n~~~~~~~~~~~~~~\n\n在对图进行划分时，默认情况下，Metis仅平衡每个子图中的节点数。根据当前的任务情况，这可能带来非最优的配置。\n例如，在半监督节点分类的场景里，训练器会对局部分区中带标签节点的子集进行计算。\n一个仅平衡图中节点(带标签和未带标签)的划分可能会导致计算负载不平衡。为了在每个分区中获得平衡的工作负载，\n划分API通过在 :func:`dgl.distributed.partition_graph` 中指定 ``balance_ntypes``\n在每个节点类型中的节点数上实现分区间的平衡。用户可以利用这一点将训练集、验证集和测试集中的节点看作不同类型的节点。\n\n以下示例将训练集内和训练集外的节点看作两种类型的节点：\n\n.. code:: python\n\n    dgl.distributed.partition_graph(g, 'graph_name', 4, '/tmp/test', balance_ntypes=g.ndata['train_mask'])\n\n除了平衡节点的类型之外， :func:`dgl.distributed.partition_graph` 还允许通过指定\n``balance_edges`` 来平衡每个类型节点在子图中的入度。这平衡了不同类型节点的连边数量。\n\n**Note**: 传给 :func:`dgl.distributed.partition_graph` 的图名称是一个重要的参数。\n:class:`dgl.distributed.DistGraph` 使用该名称来识别一个分布式的图。一个有效的图名称应该仅包含字母和下划线。\n"
  },
  {
    "path": "docs/source/guide_cn/distributed-tools.rst",
    "content": ".. _guide_cn-distributed-tools:\n\n7.3 运行分布式训练/推断所需的工具\n------------------------------------------------------\n\n:ref:`(English Version) <guide-distributed-tools>`\n\nDGL提供了两个脚本来帮助用户进行分布式训练：\n\n* *tools/copy_files.py* 用于将图分区复制到集群，\n* *tools/launch.py* 用于在机器集群中启动分布式训练任务。\n\n*copy_files.py* 将计算机(对图进行分区的计算机)中的分区数据和相关文件(例如，训练脚本)\n复制到(负责分布式训练的)机器集群上。在这些机器上，分布式训练将需要用到这些分区。该脚本包含五个参数：\n\n* ``--part_config`` 指定分区配置文件，该文件包含本地计算机中分区数据的信息。\n* ``--ip_config`` 指定集群的IP配置文件。\n* ``--workspace`` 指定训练机器中存储与分布式训练有关的所有数据的目录。\n* ``--rel_data_path`` 指定工作空间目录下存储分区数据的相对路径。\n* ``--script_folder`` 指定工作空间目录下存储用户的训练脚本的相对路径。\n\n**Note**: *copy_files.py* 会根据IP配置文件找到对应的计算机来存储图分区。因此，copy_files.py和launch.py应该使用相同的IP配置文件。\n\nDGL提供了用于启动集群中的分布式训练任务的tools/launch.py。该脚本有以下假设：\n\n* 分区数据和训练脚本都已被复制到集群或存在集群中所有计算机均可访问的全局存储空间(例如NFS)。\n* 主计算机(执行启动脚本的计算机)具有对集群内所有其他计算机的无密码ssh访问权限。\n\n**Note**: 必须在集群中的一台计算机上调用启动脚本。\n\n下面展示了在集群中启动分布式训练任务的示例。\n\n.. code:: none\n\n    python3 tools/launch.py \\\n    --workspace ~graphsage/ \\\n    --num_trainers 2 \\\n    --num_samplers 4 \\\n    --num_servers 1 \\\n    --part_config data/ogb-product.json \\\n    --ip_config ip_config.txt \\\n    \"python3 code/train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 5 --batch-size 1000 --lr 0.1 --num_workers 4\"\n\n配置文件 *ip_config.txt* 包含了集群中计算机的IP地址。*ip_config.txt* 的典型示例如下：\n\n.. code:: none\n\n    172.31.19.1\n    172.31.23.205\n    172.31.29.175\n    172.31.16.98\n\n每行是一个计算机的IP地址。IP地址后面还可以有一个端口，用来指定不同训练器之间的网络通信所使用的端口。\n如果未提供具体端口，则默认值为 ``30050``。\n\n启动脚本中指定的工作空间(--workspace)是计算机中的工作目录，里面保存了训练脚本、IP配置文件、分区配置文件以及图分区。\n文件的所有路径都应指定为工作空间的相对路径。\n\n启动脚本会在每台计算机上创建指定数量的训练任务(``--num_trainers``)。另外，\n用户需要为每个训练器指定采样器进程的数量(``--num_samplers``)。\n采样器进程的数量必须匹配 :func:`~dgl.distributed.initialize` 中指定的工作进程的数量。\n"
  },
  {
    "path": "docs/source/guide_cn/distributed.rst",
    "content": ".. _guide_cn-distributed:\n\n第7章：分布式训练\n=====================================\n\n:ref:`(English Version) <guide-distributed>`\n\nDGL采用完全分布式的方法，可将数据和计算同时分布在一组计算资源中。在本节中，\n我们默认使用一个集群的环境设置(即一组机器)。DGL会将一张图划分为多张子图，\n集群中的每台机器各自负责一张子图(分区)。为了并行化计算，DGL在集群所有机器上运行相同的训练脚本，\n并在同样的机器上运行服务器以将分区数据提供给训练器。\n\n对于训练脚本，DGL提供了分布式的API。它们与小批次训练的API相似。用户仅需对单机小批次训练的代码稍作修改就可实现分布式训练。\n以下代码给出了一个用分布式方式训练GraphSage的示例。仅有的代码修改出现在第4-7行：1)初始化DGL的分布式模块，2)创建分布式图对象，以及\n3)拆分训练集，并计算本地进程的节点。其余代码保持不变，与 :ref:`小批次训练 <guide_cn-minibatch>` 类似，\n包括：创建采样器，模型定义，模型训练的循环。\n\n.. code:: python\n\n    import dgl\n    import torch as th\n\n    dgl.distributed.initialize('ip_config.txt')\n    th.distributed.init_process_group(backend='gloo')\n    g = dgl.distributed.DistGraph('graph_name', 'part_config.json')\n    pb = g.get_partition_book()\n    train_nid = dgl.distributed.node_split(g.ndata['train_mask'], pb, force_even=True)\n\n    # 创建采样器\n    sampler = NeighborSampler(g, [10,25],\n                              dgl.distributed.sample_neighbors,\n                              device)\n\n    dataloader = DistDataLoader(\n        dataset=train_nid.numpy(),\n        batch_size=batch_size,\n        collate_fn=sampler.sample_blocks,\n        shuffle=True,\n        drop_last=False)\n\n    # 定义模型和优化器\n    model = SAGE(in_feats, num_hidden, n_classes, num_layers, F.relu, dropout)\n    model = th.nn.parallel.DistributedDataParallel(model)\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n    # 模型训练的循环\n    for epoch in range(args.num_epochs):\n        for step, blocks in enumerate(dataloader):\n            batch_inputs, batch_labels = load_subtensor(g, blocks[0].srcdata[dgl.NID],\n                                                        blocks[-1].dstdata[dgl.NID])\n            batch_pred = model(blocks, batch_inputs)\n            loss = loss_fcn(batch_pred, batch_labels)\n            optimizer.zero_grad()\n            loss.backward()\n         
   optimizer.step()\n\n在一个集群的机器上运行训练脚本时，DGL提供了一些工具，可将数据复制到集群的计算机上，并在所有机器上启动训练任务。\n\n**Note**: 当前版本的分布式训练API仅支持PyTorch后端。\n\n**Note**: 当前版本的实现仅支持具有一种节点类型和一种边类型的图。\n\nDGL实现了一些分布式组件以支持分布式训练，下图显示了这些组件及它们间的相互作用。\n\n.. figure:: https://data.dgl.ai/asset/image/distributed.png\n   :alt: Imgur\n\n具体来说，DGL的分布式训练具有三种类型的交互进程：\n*服务器*，\n*采样器* 和 *训练器*。\n\n* *服务器进程* 在存储图分区数据(这包括图结构和节点/边特征)的每台计算机上运行。\n  这些服务器一起工作以将图数据提供给训练器。请注意，一台机器可能同时运行多个服务器进程，以并行化计算和网络通信。\n* *采样器进程* 与服务器进行交互，并对节点和边采样以生成用于训练的小批次数据。\n* *训练器进程* 包含多个与服务器交互的类。它用 :class:`~dgl.distributed.DistGraph` 来获取被划分的图分区数据，\n  用 :class:`~dgl.distributed.DistEmbedding` 和\n  :class:`~dgl.distributed.DistTensor` 来获取节点/边特征/嵌入，用\n  :class:`~dgl.distributed.dist_dataloader.DistDataLoader` 与采样器进行交互以获得小批次数据。\n\n在初步了解了分布式组件后，本章的剩余部分将介绍以下分布式组件：\n\n* :ref:`guide_cn-distributed-preprocessing`\n* :ref:`guide_cn-distributed-apis`\n* :ref:`guide_cn-distributed-tools`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    distributed-preprocessing\n    distributed-apis\n    distributed-tools\n"
  },
  {
    "path": "docs/source/guide_cn/graph-basic.rst",
    "content": ".. _guide_cn-graph-basic:\n\n1.1 关于图的基本概念\n-----------------\n\n:ref:`(English Version) <guide-graph-basic>`\n\n图是用以表示实体及其关系的结构，记为 :math:`G=(V, E)` 。图由两个集合组成，一是节点的集合 :math:`V` ，一个是边的集合 :math:`E` 。\n在边集 :math:`E` 中，一条边 :math:`(u, v)` 连接一对节点 :math:`u` 和 :math:`v` ，表明两节点间存在关系。关系可以是无向的，\n如描述节点之间的对称关系；也可以是有向的，如描述非对称关系。例如，若用图对社交网络中人们的友谊关系进行建模，因为友谊是相互的，则边是无向的；\n若用图对Twitter用户的关注行为进行建模，则边是有向的。图可以是 *有向的* 或 *无向的* ，这取决于图中边的方向性。\n\n图可以是 *加权的* 或 *未加权的* 。在加权图中，每条边都与一个标量权重值相关联。例如，该权重可以表示长度或连接的强度。\n\n图可以是 *同构的* 或是 *异构的* 。在同构图中，所有节点表示同一类型的实体，所有边表示同一类型的关系。\n例如，社交网络的图由表示同一实体类型的人及其相互之间的社交关系组成。\n\n相对地，在异构图中，节点和边的类型可以是不同的。例如，编码市场的图可以有表示\"顾客\"、\"商家\"和\"商品\"的节点，\n它们通过“想购买”、“已经购买”、“是顾客”和“正在销售”的边互相连接。二分图是一类特殊的、常用的异构图，\n其中的边连接两类不同类型的节点。例如，在推荐系统中，可以使用二分图表示\"用户\"和\"物品\"之间的关系。想了解更多信息，读者可参考 :ref:`guide_cn-graph-heterogeneous`。\n\n在多重图中，同一对节点之间可以有多条（有向）边，包括自循环的边。例如，两名作者可以在不同年份共同署名文章，\n这就带来了具有不同特征的多条边。\n"
  },
  {
    "path": "docs/source/guide_cn/graph-external.rst",
    "content": ".. _guide_cn-graph-external:\n\n1.4 从外部源创建图\n---------------\n\n:ref:`(English Version)<guide-graph-external>`\n\n可以从外部来源构造一个 :class:`~dgl.DGLGraph` 对象，包括：\n\n- 从用于图和稀疏矩阵的外部Python库（NetworkX 和 SciPy）创建而来。\n- 从磁盘加载图数据。\n\n本节不涉及通过转换其他图来生成图的函数，相关概述请阅读API参考手册。\n\n从外部库创建图\n^^^^^^^^^^^\n\n以下代码片段为从SciPy稀疏矩阵和NetworkX图创建DGL图的示例。\n\n.. code::\n\n    >>> import dgl\n    >>> import torch as th\n    >>> import scipy.sparse as sp\n    >>> spmat = sp.rand(100, 100, density=0.05) # 5%非零项\n    >>> dgl.from_scipy(spmat)                   # 来自SciPy\n    Graph(num_nodes=100, num_edges=500,\n          ndata_schemes={}\n          edata_schemes={})\n\n    >>> import networkx as nx\n    >>> nx_g = nx.path_graph(5) # 一条链路0-1-2-3-4\n    >>> dgl.from_networkx(nx_g) # 来自NetworkX\n    Graph(num_nodes=5, num_edges=8,\n          ndata_schemes={}\n          edata_schemes={})\n\n注意，当使用 `nx.path_graph(5)` 进行创建时， :class:`~dgl.DGLGraph` 对象有8条边，而非4条。\n这是由于 `nx.path_graph(5)` 构建了一个无向的NetworkX图 :class:`networkx.Graph` ，而 :class:`~dgl.DGLGraph` 的边总是有向的。\n所以当将无向的NetworkX图转换为 :class:`~dgl.DGLGraph` 对象时，DGL会在内部将1条无向边转换为2条有向边。\n使用有向的NetworkX图 :class:`networkx.DiGraph` 可避免该行为。\n\n.. code::\n\n    >>> nxg = nx.DiGraph([(2, 1), (1, 2), (2, 3), (0, 0)])\n    >>> dgl.from_networkx(nxg)\n    Graph(num_nodes=4, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={})\n\n.. note::\n\n    DGL在内部将SciPy矩阵和NetworkX图转换为张量来创建图。因此，这些构建方法并不适用于重视性能的场景。\n\n相关API： :func:`dgl.from_scipy`、 :func:`dgl.from_networkx`。\n\n从磁盘加载图\n^^^^^^^^^^\n\n有多种文件格式可储存图，所以这里难以枚举所有选项。本节仅给出一些常见格式的一般情况。\n\n逗号分隔值（CSV）\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nCSV是一种常见的格式，以表格格式储存节点、边及其特征：\n\n.. table:: nodes.csv\n\n   +-----------+\n   |age, title |\n   +===========+\n   |43, 1      |\n   +-----------+\n   |23, 3      |\n   +-----------+\n   |...        |\n   +-----------+\n\n.. 
table:: edges.csv\n\n   +-----------------+\n   |src, dst, weight |\n   +=================+\n   |0, 1, 0.4        |\n   +-----------------+\n   |0, 3, 0.9        |\n   +-----------------+\n   |...              |\n   +-----------------+\n\n许多知名Python库(如Pandas)可以将该类型数据加载到python对象(如 :class:`numpy.ndarray`)中，\n进而使用这些对象来构建DGLGraph对象。如果后端框架也提供了从磁盘中保存或加载张量的工具(如 :func:`torch.save`,  :func:`torch.load` )，\n可以遵循相同的原理来构建图。\n\n另见： `从成对的边 CSV 文件中加载 Karate Club Network 的教程 <https://github.com/dglai/WWW20-Hands-on-Tutorial/blob/master/basic_tasks/1_load_data.ipynb>`_。\n\nJSON/GML 格式\n\"\"\"\"\"\"\"\"\"\"\"\"\n\n如果对速度不太关注的话，读者可以使用NetworkX提供的工具来解析 `各种数据格式 <https://networkx.github.io/documentation/stable/reference/readwrite/index.html>`_，\nDGL可以间接地从这些来源创建图。\n\nDGL 二进制格式\n\"\"\"\"\"\"\"\"\"\"\"\"\n\nDGL提供了API以从磁盘中加载或向磁盘里保存二进制格式的图。除了图结构，API也能处理特征数据和图级别的标签数据。\nDGL也支持直接从S3/HDFS中加载或向S3/HDFS保存图。参考手册提供了该用法的更多细节。\n\n相关API： :func:`dgl.save_graphs`、 :func:`dgl.load_graphs`。"
  },
  {
    "path": "docs/source/guide_cn/graph-feature.rst",
    "content": ".. _guide_cn-graph-feature:\n\n1.3 节点和边的特征\n---------------\n\n:ref:`(English Version)<guide-graph-feature>`\n\n:class:`~dgl.DGLGraph` 对象的节点和边可具有多个用户定义的、可命名的特征，以储存图的节点和边的属性。\n通过 :py:attr:`~dgl.DGLGraph.ndata` 和 :py:attr:`~dgl.DGLGraph.edata` 接口可访问这些特征。\n例如，以下代码创建了2个节点特征（分别在第8、15行命名为 ``'x'`` 、 ``'y'`` ）和1个边特征（在第9行命名为 ``'x'`` ）。\n\n.. code-block:: python\n    :linenos:\n\n    >>> import dgl\n    >>> import torch as th\n    >>> g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0])) # 6个节点，4条边\n    >>> g\n    Graph(num_nodes=6, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={})\n    >>> g.ndata['x'] = th.ones(g.num_nodes(), 3)               # 长度为3的节点特征\n    >>> g.edata['x'] = th.ones(g.num_edges(), dtype=th.int32)  # 标量整型特征\n    >>> g\n    Graph(num_nodes=6, num_edges=4,\n          ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}\n          edata_schemes={'x' : Scheme(shape=(,), dtype=torch.int32)})\n    >>> # 不同名称的特征可以具有不同形状\n    >>> g.ndata['y'] = th.randn(g.num_nodes(), 5)\n    >>> g.ndata['x'][1]                  # 获取节点1的特征\n    tensor([1., 1., 1.])\n    >>> g.edata['x'][th.tensor([0, 3])]  # 获取边0和3的特征\n        tensor([1, 1], dtype=torch.int32)\n\n关于 :py:attr:`~dgl.DGLGraph.ndata` 和 :py:attr:`~dgl.DGLGraph.edata` 接口的重要说明：\n\n- 仅允许使用数值类型（如单精度浮点型、双精度浮点型和整型）的特征。这些特征可以是标量、向量或多维张量。\n- 每个节点特征具有唯一名称，每个边特征也具有唯一名称。节点和边的特征可以具有相同的名称（如上述示例代码中的 ``'x'`` ）。\n- 通过张量分配创建特征时，DGL会将特征赋给图中的每个节点和每条边。该张量的第一维必须与图中节点或边的数量一致。\n  不能将特征赋给图中节点或边的子集。\n- 相同名称的特征必须具有相同的维度和数据类型。\n- 特征张量使用\"行优先\"的原则，即每个行切片储存1个节点或1条边的特征（参考上述示例代码的第16和18行）。\n\n对于加权图，用户可以将权重储存为一个边特征，如下。\n\n.. 
code-block:: python\n\n    >>> # 边 0->1, 0->2, 0->3, 1->3\n    >>> edges = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])\n    >>> weights = th.tensor([0.1, 0.6, 0.9, 0.7])  # 每条边的权重\n    >>> g = dgl.graph(edges)\n    >>> g.edata['w'] = weights  # 将其命名为 'w'\n    >>> g\n    Graph(num_nodes=4, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={'w' : Scheme(shape=(,), dtype=torch.float32)})\n\n\n\n相关API： :py:attr:`~dgl.DGLGraph.ndata`、 :py:attr:`~dgl.DGLGraph.edata`。"
  },
  {
    "path": "docs/source/guide_cn/graph-gpu.rst",
    "content": ".. _guide_cn-graph-gpu:\n\n1.6 在GPU上使用DGLGraph\n----------------------\n\n:ref:`(English Version)<guide-graph-gpu>`\n\n用户可以通过在构造过程中传入两个GPU张量来创建GPU上的 :class:`~dgl.DGLGraph` 。\n另一种方法是使用 :func:`~dgl.DGLGraph.to` API将 :class:`~dgl.DGLGraph` 复制到GPU，这会将图结构和特征数据都拷贝到指定的设备。\n\n.. code::\n\n    >>> import dgl\n    >>> import torch as th\n    >>> u, v = th.tensor([0, 1, 2]), th.tensor([2, 3, 4])\n    >>> g = dgl.graph((u, v))\n    >>> g.ndata['x'] = th.randn(5, 3)   # 原始特征在CPU上\n    >>> g.device\n    device(type='cpu')\n    >>> cuda_g = g.to('cuda:0')         # 接受来自后端框架的任何设备对象\n    >>> cuda_g.device\n    device(type='cuda', index=0)\n    >>> cuda_g.ndata['x'].device        # 特征数据也拷贝到了GPU上\n    device(type='cuda', index=0)\n\n    >>> # 由GPU张量构造的图也在GPU上\n    >>> u, v = u.to('cuda:0'), v.to('cuda:0')\n    >>> g = dgl.graph((u, v))\n    >>> g.device\n    device(type='cuda', index=0)\n\n任何涉及GPU图的操作都是在GPU上运行的。因此，这要求所有张量参数都已经放在GPU上，其结果(图或张量)也将在GPU上。\n此外，GPU图只接受GPU上的特征数据。\n\n.. code::\n\n    >>> cuda_g.in_degrees()\n    tensor([0, 0, 1, 1, 1], device='cuda:0')\n    >>> cuda_g.in_edges([2, 3, 4])                          # 可以接受非张量类型的参数\n    (tensor([0, 1, 2], device='cuda:0'), tensor([2, 3, 4], device='cuda:0'))\n    >>> cuda_g.in_edges(th.tensor([2, 3, 4]).to('cuda:0'))  # 张量类型的参数必须在GPU上\n    (tensor([0, 1, 2], device='cuda:0'), tensor([2, 3, 4], device='cuda:0'))\n    >>> cuda_g.ndata['h'] = th.randn(5, 4)                  # ERROR! 特征也必须在GPU上！\n    DGLError: Cannot assign node feature \"h\" on device cpu to a graph on device\n    cuda:0. Call DGLGraph.to() to copy the graph to the same device.\n"
  },
  {
    "path": "docs/source/guide_cn/graph-graphs-nodes-edges.rst",
    "content": ".. _guide_cn-graph-graphs-nodes-edges:\n\n1.2 图、节点和边\n--------------\n\n:ref:`(English Version)<guide-graph-graphs-nodes-edges>`\n\nDGL使用一个唯一的整数来表示一个节点，称为点ID；并用对应的两个端点ID表示一条边。同时，DGL也会根据边被添加的顺序，\n给每条边分配一个唯一的整数编号，称为边ID。节点和边的ID都是从0开始构建的。在DGL的图里，所有的边都是有方向的，\n即边 :math:`(u, v)` 表示它是从节点 :math:`u` 指向节点 :math:`v` 的。\n\n对于多个节点，DGL使用一个一维的整型张量（如，PyTorch的Tensor类，TensorFlow的Tensor类或MXNet的ndarray类）来保存图的点ID，\nDGL称之为\"节点张量\"。为了指代多条边，DGL使用一个包含2个节点张量的元组 :math:`(U, V)` ，其中，用 :math:`(U[i], V[i])` 指代一条\n:math:`U[i]` 到 :math:`V[i]` 的边。\n\n创建一个 :class:`~dgl.DGLGraph` 对象的一种方法是使用 :func:`dgl.graph` 函数。它接受一个边的集合作为输入。DGL也支持从其他的数据源来创建图对象。\n读者可参考 :ref:`guide_cn-graph-external`。\n\n下面的代码段使用了 :func:`dgl.graph` 函数来构建一个 :class:`~dgl.DGLGraph` 对象，对应着下图所示的包含4个节点的图。\n其中一些代码演示了查询图结构的部分API的使用方法。\n\n.. figure:: https://data.dgl.ai/asset/image/user_guide_graphch_1.png\n    :height: 200px\n    :width: 300px\n    :align: center\n\n.. code::\n\n    >>> import dgl\n    >>> import torch as th\n\n    >>> # 边 0->1, 0->2, 0->3, 1->3\n    >>> u, v = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])\n    >>> g = dgl.graph((u, v))\n    >>> print(g) # 图中节点的数量是DGL通过给定的图的边列表中最大的点ID推断所得出的\n    Graph(num_nodes=4, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={})\n\n    >>> # 获取节点的ID\n    >>> print(g.nodes())\n    tensor([0, 1, 2, 3])\n    >>> # 获取边的对应端点\n    >>> print(g.edges())\n    (tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]))\n    >>> # 获取边的对应端点和边ID\n    >>> print(g.edges(form='all'))\n    (tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]), tensor([0, 1, 2, 3]))\n\n    >>> # 如果具有最大ID的节点没有边，在创建图的时候，用户需要明确地指明节点的数量。\n    >>> g = dgl.graph((u, v), num_nodes=8)\n\n对于无向的图，用户需要为每条边都创建两个方向的边。可以使用 :func:`dgl.to_bidirected` 函数来实现这个目的。\n如下面的代码段所示，这个函数可以把原图转换成一个包含反向边的图。\n\n.. code::\n\n    >>> bg = dgl.to_bidirected(g)\n    >>> bg.edges()\n    (tensor([0, 0, 0, 1, 1, 2, 3, 3]), tensor([1, 2, 3, 0, 3, 0, 0, 1]))\n\n.. 
note::\n\n    由于Tensor类内部使用C来存储，且显性定义了数据类型以及存储的设备信息，DGL推荐使用Tensor作为DGL API的输入。\n    不过大部分的DGL API也支持Python的可迭代类型(比如列表)或numpy.ndarray类型作为API的输入，方便用户快速进行开发验证。\n\nDGL支持使用 :math:`32` 位或 :math:`64` 位的整数作为节点ID和边ID。节点和边ID的数据类型必须一致。如果使用 :math:`64` 位整数，\nDGL可以处理最多 :math:`2^{63} - 1` 个节点或边。不过，如果图里的节点或者边的数量小于 :math:`2^{31} - 1` ，用户最好使用 :math:`32` 位整数。\n这样不仅能提升速度，还能减少内存的使用。DGL提供了进行数据类型转换的方法，如下例所示。\n\n.. code::\n\n    >>> edges = th.tensor([2, 5, 3]), th.tensor([3, 5, 0])  # 边：2->3, 5->5, 3->0\n    >>> g64 = dgl.graph(edges)  # DGL默认使用int64\n    >>> print(g64.idtype)\n    torch.int64\n    >>> g32 = dgl.graph(edges, idtype=th.int32)  # 使用int32构建图\n    >>> g32.idtype\n    torch.int32\n    >>> g64_2 = g32.long()  # 转换成int64\n    >>> g64_2.idtype\n    torch.int64\n    >>> g32_2 = g64.int()  # 转换成int32\n    >>> g32_2.idtype\n    torch.int32\n\n相关API：:func:`dgl.graph`、 :func:`dgl.DGLGraph.nodes`、 :func:`dgl.DGLGraph.edges`、 :func:`dgl.to_bidirected`、\n:func:`dgl.DGLGraph.int`、 :func:`dgl.DGLGraph.long` 和 :py:attr:`dgl.DGLGraph.idtype`。\n"
  },
  {
    "path": "docs/source/guide_cn/graph-heterogeneous.rst",
    "content": ".. _guide_cn-graph-heterogeneous:\n\n1.5 异构图\n---------\n\n:ref:`(English Version)<guide-graph-heterogeneous>`\n\n相比同构图，异构图里可以有不同类型的节点和边。这些不同类型的节点和边具有独立的ID空间和特征。\n例如在下图中，\"用户\"和\"游戏\"节点的ID都是从0开始的，而且两种节点具有不同的特征。\n\n.. figure:: https://data.dgl.ai/asset/image/user_guide_graphch_2.png\n\n    一个异构图示例。该图具有两种类型的节点(\"用户\"和\"游戏\")和两种类型的边(\"关注\"和\"玩\")。\n\n创建异构图\n^^^^^^^^\n\n在DGL中，一个异构图由一系列子图构成，一个子图对应一种关系。每个关系由一个字符串三元组\n定义 ``(源节点类型, 边类型, 目标节点类型)`` 。由于这里的关系定义消除了边类型的歧义，DGL称它们为规范边类型。\n\n下面的代码是一个在DGL中创建异构图的示例。\n\n.. code::\n\n    >>> import dgl\n    >>> import torch as th\n\n    >>> # 创建一个具有3种节点类型和3种边类型的异构图\n    >>> graph_data = {\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    ('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),\n    ...    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))\n    ... }\n    >>> g = dgl.heterograph(graph_data)\n    >>> g.ntypes\n    ['disease', 'drug', 'gene']\n    >>> g.etypes\n    ['interacts', 'interacts', 'treats']\n    >>> g.canonical_etypes\n    [('drug', 'interacts', 'drug'),\n     ('drug', 'interacts', 'gene'),\n     ('drug', 'treats', 'disease')]\n\n注意，同构图和二分图只是一种特殊的异构图，它们只包括一种关系。\n\n.. code::\n\n    >>> # 一个同构图\n    >>> dgl.heterograph({('node_type', 'edge_type', 'node_type'): (u, v)})\n    >>> # 一个二分图\n    >>> dgl.heterograph({('source_type', 'edge_type', 'destination_type'): (u, v)})\n\n与异构图相关联的 *metagraph* 就是图的模式。它指定节点集和节点之间的边的类型约束。\n*metagraph* 中的一个节点 :math:`u` 对应于相关异构图中的一个节点类型。\n*metagraph* 中的边 :math:`(u,v)` 表示在相关异构图中存在从 :math:`u` 型节点到 :math:`v` 型节点的边。\n\n.. 
code::\n\n    >>> g\n    Graph(num_nodes={'disease': 3, 'drug': 3, 'gene': 4},\n          num_edges={('drug', 'interacts', 'drug'): 2,\n                     ('drug', 'interacts', 'gene'): 2,\n                     ('drug', 'treats', 'disease'): 1},\n          metagraph=[('drug', 'drug', 'interacts'),\n                     ('drug', 'gene', 'interacts'),\n                     ('drug', 'disease', 'treats')])\n    >>> g.metagraph().edges()\n    OutMultiEdgeDataView([('drug', 'drug'), ('drug', 'gene'), ('drug', 'disease')])\n\n相关API: :func:`dgl.heterograph`、 :py:attr:`~dgl.DGLGraph.ntypes`、 :py:attr:`~dgl.DGLGraph.etypes`、\n:py:attr:`~dgl.DGLGraph.canonical_etypes`、 :py:attr:`~dgl.DGLGraph.metagraph`。\n\n使用多种类型\n^^^^^^^^^^\n\n当引入多种节点和边类型后，用户在调用DGLGraph API以获取特定类型的信息时，需要指定具体的节点和边类型。此外，不同类型的节点和边具有单独的ID。\n\n.. code::\n\n    >>> # 获取图中所有节点的数量\n    >>> g.num_nodes()\n    10\n    >>> # 获取drug节点的数量\n    >>> g.num_nodes('drug')\n    3\n    >>> # 不同类型的节点有单独的ID。因此，没有指定节点类型就没有明确的返回值。\n    >>> g.nodes()\n    DGLError: Node type name must be specified if there are more than one node types.\n    >>> g.nodes('drug')\n    tensor([0, 1, 2])\n\n为了设置/获取特定节点和边类型的特征，DGL提供了两种新类型的语法： `g.nodes['node_type'].data['feat_name']` 和 `g.edges['edge_type'].data['feat_name']` 。\n\n.. code::\n\n    >>> # 设置/获取\"drug\"类型的节点的\"hv\"特征\n    >>> g.nodes['drug'].data['hv'] = th.ones(3, 1)\n    >>> g.nodes['drug'].data['hv']\n    tensor([[1.],\n            [1.],\n            [1.]])\n    >>> # 设置/获取\"treats\"类型的边的\"he\"特征\n    >>> g.edges['treats'].data['he'] = th.zeros(1, 1)\n    >>> g.edges['treats'].data['he']\n    tensor([[0.]])\n\n如果图里只有一种节点或边类型，则不需要指定节点或边的类型。\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    ('drug', 'is similar', 'drug'): (th.tensor([0, 1]), th.tensor([2, 3]))\n    ... })\n    >>> g.nodes()\n    tensor([0, 1, 2, 3])\n    >>> # 设置/获取单一类型的节点或边特征，不必使用新的语法\n    >>> g.ndata['hv'] = th.ones(4, 1)\n\n.. 
note::\n\n    当边类型唯一地确定了源节点和目标节点的类型时，用户可以只使用一个字符串而不是字符串三元组来指定边类型。例如，\n    对于具有两个关系 ``('user', 'plays', 'game')`` 和  ``('user', 'likes', 'game')`` 的异构图，\n    只使用 ``'plays'`` 或 ``'like'`` 来指代这两个关系是可以的。\n\n从磁盘加载异构图\n^^^^^^^^^^^^^\n\n逗号分隔值（CSV）\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n一种存储异构图的常见方法是在不同的CSV文件中存储不同类型的节点和边。下面是一个例子。\n\n.. code::\n\n    # 数据文件夹\n    data/\n    |-- drug.csv        # drug节点\n    |-- gene.csv        # gene节点\n    |-- disease.csv     # disease节点\n    |-- drug-interact-drug.csv  # drug-drug相互作用边\n    |-- drug-interact-gene.csv  # drug-gene相互作用边\n    |-- drug-treat-disease.csv  # drug-disease治疗边\n\n与同构图的情况类似，用户可以使用像Pandas这样的包先将CSV文件解析为numpy数组或框架张量，再构建一个关系字典，并用它构造一个异构图。\n这种方法也适用于其他流行的文件格式，比如GML或JSON。\n\nDGL二进制格式\n\"\"\"\"\"\"\"\"\"\"\"\n\nDGL提供了 :func:`dgl.save_graphs` 和 :func:`dgl.load_graphs` 函数，分别用于以二进制格式保存异构图和加载它们。\n\n边类型子图\n^^^^^^^^\n\n用户可以通过指定要保留的关系来创建异构图的子图，相关的特征也会被拷贝。\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    ('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),\n    ...    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))\n    ... })\n    >>> g.nodes['drug'].data['hv'] = th.ones(3, 1)\n\n    >>> # 保留关系 ('drug', 'interacts', 'drug') 和 ('drug', 'treats', 'disease') 。\n    >>> # 'drug' 和 'disease' 类型的节点也会被保留\n    >>> eg = dgl.edge_type_subgraph(g, [('drug', 'interacts', 'drug'),\n    ...                                 ('drug', 'treats', 'disease')])\n    >>> eg\n    Graph(num_nodes={'disease': 3, 'drug': 3},\n          num_edges={('drug', 'interacts', 'drug'): 2, ('drug', 'treats', 'disease'): 1},\n          metagraph=[('drug', 'drug', 'interacts'), ('drug', 'disease', 'treats')])\n    >>> # 相关的特征也会被拷贝\n    >>> eg.nodes['drug'].data['hv']\n    tensor([[1.],\n            [1.],\n            [1.]])\n\n\n将异构图转化为同构图\n^^^^^^^^^^^^^^^^\n\n异构图为管理不同类型的节点和边及其相关特征提供了一个清晰的接口。这在以下情况下尤其有用:\n\n1. 不同类型的节点和边的特征具有不同的数据类型或大小。\n2. 
用户希望对不同类型的节点和边应用不同的操作。\n\n如果上述情况不适用，并且用户不希望在建模中区分节点和边的类型，则DGL允许使用 :func:`dgl.DGLGraph.to_homogeneous` API将异构图转换为同构图。\n具体行为如下:\n\n1. 用从0开始的连续整数重新标记所有类型的节点和边。\n2. 对所有的节点和边合并用户指定的特征。\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))})\n    >>> g.nodes['drug'].data['hv'] = th.zeros(3, 1)\n    >>> g.nodes['disease'].data['hv'] = th.ones(3, 1)\n    >>> g.edges['interacts'].data['he'] = th.zeros(2, 1)\n    >>> g.edges['treats'].data['he'] = th.zeros(1, 2)\n\n    >>> # 默认情况下不进行特征合并\n    >>> hg = dgl.to_homogeneous(g)\n    >>> 'hv' in hg.ndata\n    False\n\n    >>> # 拷贝边的特征\n    >>> # 对于要拷贝的特征，DGL假定不同类型的节点或边的需要合并的特征具有相同的大小和数据类型\n    >>> hg = dgl.to_homogeneous(g, edata=['he'])\n    DGLError: Cannot concatenate column ‘he’ with shape Scheme(shape=(2,), dtype=torch.float32) and shape Scheme(shape=(1,), dtype=torch.float32)\n\n    >>> # 拷贝节点特征\n    >>> hg = dgl.to_homogeneous(g, ndata=['hv'])\n    >>> hg.ndata['hv']\n    tensor([[1.],\n            [1.],\n            [1.],\n            [0.],\n            [0.],\n            [0.]])\n\n原始的节点或边的类型和对应的ID被存储在 :py:attr:`~dgl.DGLGraph.ndata` 和 :py:attr:`~dgl.DGLGraph.edata` 中。\n\n.. code::\n\n    >>> # 异构图中节点类型的顺序\n    >>> g.ntypes\n    ['disease', 'drug']\n    >>> # 原始节点类型\n    >>> hg.ndata[dgl.NTYPE]\n    tensor([0, 0, 0, 1, 1, 1])\n    >>> # 原始的特定类型节点ID\n    >>> hg.ndata[dgl.NID]\n    tensor([0, 1, 2, 0, 1, 2])\n\n    >>> # 异构图中边类型的顺序\n    >>> g.etypes\n    ['interacts', 'treats']\n    >>> # 原始边类型\n    >>> hg.edata[dgl.ETYPE]\n    tensor([0, 0, 1])\n    >>> # 原始的特定类型边ID\n    >>> hg.edata[dgl.EID]\n    tensor([0, 1, 0])\n\n出于建模的目的，用户可能需要将一些关系合并，并对它们应用相同的操作。为了实现这一目的，可以先抽取异构图的边类型子图，然后将该子图转换为同构图。\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    
('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),\n    ...    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))\n    ... })\n    >>> sub_g = dgl.edge_type_subgraph(g, [('drug', 'interacts', 'drug'),\n    ...                                    ('drug', 'interacts', 'gene')])\n    >>> h_sub_g = dgl.to_homogeneous(sub_g)\n    >>> h_sub_g\n    Graph(num_nodes=7, num_edges=4,\n          ...)\n"
  },
  {
    "path": "docs/source/guide_cn/graph.rst",
    "content": ".. _guide_cn-graph:\n\n第1章：图\n=============\n\n:ref:`(English Version)<guide-graph>`\n\n图表示实体(节点)和它们的关系(边)，其中节点和边可以是有类型的 (例如，``\"用户\"`` 和 ``\"物品\"`` 是两种不同类型的节点)。\nDGL通过其核心数据结构  :class:`~dgl.DGLGraph` 提供了一个以图为中心的编程抽象。 :class:`~dgl.DGLGraph` 提供了接口以处理图的结构、节点/边\n的特征，以及使用这些组件可以执行的计算。\n\n\n本章路线图\n--------------\n\n本章首先简要介绍了图的定义（见1.1节），然后介绍了一些 :class:`~dgl.DGLGraph` 相关的核心概念：\n\n* :ref:`guide_cn-graph-basic`\n* :ref:`guide_cn-graph-graphs-nodes-edges`\n* :ref:`guide_cn-graph-feature`\n* :ref:`guide_cn-graph-external`\n* :ref:`guide_cn-graph-heterogeneous`\n* :ref:`guide_cn-graph-gpu`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    graph-basic\n    graph-graphs-nodes-edges\n    graph-feature\n    graph-external\n    graph-heterogeneous\n    graph-gpu\n"
  },
  {
    "path": "docs/source/guide_cn/index.rst",
    "content": "用户指南【包含过时信息】\n===================\n\n.. toctree::\n  :maxdepth: 2\n  :titlesonly:\n\n  graph\n  message\n  nn\n  data\n  training\n  minibatch\n  distributed\n\n2020年9月，DGL社区的一群热心贡献者把DGL用户指南译成了中文，方便广大中文用户群学习和使用DGL。\n\n特此致谢下述贡献者：\n\n.. list-table::\n   :widths: 20 20 20\n   :header-rows: 1\n\n   * - 章节\n     - 个人姓名/昵称\n     - 个人链接\n   * - :ref:`guide_cn-graph`\n     - 张怀文/Huaiwen Zhang\n     - https://github.com/huaiwen\n   * - :ref:`guide_cn-graph-basic`\n     - 沈成 / mlsoar\n     - https://github.com/mlsoar\n   * - :ref:`guide_cn-graph-graphs-nodes-edges`\n     - 张建 / zhjwy9343\n     - https://github.com/zhjwy9343\n   * - :ref:`guide_cn-graph-feature`\n     - 沈成 / mlsoar\n     - https://github.com/mlsoar\n   * - :ref:`guide_cn-graph-external`\n     - 沈成 / mlsoar\n     - https://github.com/mlsoar\n   * - :ref:`guide_cn-graph-heterogeneous`\n     - 张怀文/Huaiwen Zhang\n     - https://github.com/huaiwen\n   * - :ref:`guide_cn-message-passing`,\n     - 黄崟/Brook Huang\n     - https://github.com/brookhuang16211\n   * - :ref:`guide_cn-message-passing-api`\n     - 黄崟/Brook Huang\n     - https://github.com/brookhuang16211\n   * - :ref:`guide_cn-message-passing-efficient`\n     - 黄崟/Brook Huang\n     - https://github.com/brookhuang16211\n   * - :ref:`guide_cn-message-passing-part`\n     - 陈知雨/Zhiyu Chen\n     - https://www.zhiyuchen.com\n   * - :ref:`guide_cn-message-passing-edge`\n     - 陈知雨/Zhiyu Chen\n     - https://www.zhiyuchen.com\n   * - :ref:`guide_cn-message-passing-heterograph`\n     - 陈知雨/Zhiyu Chen\n     - https://www.zhiyuchen.com\n   * - :ref:`guide_cn-nn`\n     - 陈知雨/Zhiyu Chen\n     - https://www.zhiyuchen.com\n   * - :ref:`guide_cn-nn-construction`\n     - 陈知雨/Zhiyu Chen\n     - https://www.zhiyuchen.com\n   * - :ref:`guide_cn-nn-forward`\n     - 栩栩的夏天\n     -\n   * - :ref:`guide_cn-nn-heterograph`\n     - 栩栩的夏天\n     -\n   * - :ref:`guide_cn-data-pipeline`\n     - 吴紫薇/ Maggie Wu\n     - https://github.com/hhhiddleston\n   * - 
:ref:`guide_cn-data-pipeline-dataset`\n     - 吴紫薇/ Maggie Wu\n     - https://github.com/hhhiddleston\n   * - :ref:`guide_cn-data-pipeline-download`\n     - 吴紫薇/ Maggie Wu\n     - https://github.com/hhhiddleston\n   * - :ref:`guide_cn-data-pipeline-process`\n     - 吴紫薇/ Maggie Wu\n     - https://github.com/hhhiddleston\n   * - :ref:`guide_cn-data-pipeline-savenload`\n     - 王建民/DrugAI\n     - https://github.com/AspirinCode\n   * - :ref:`guide_cn-data-pipeline-loadogb`\n     - 王建民/DrugAI\n     - https://github.com/AspirinCode\n   * - :ref:`guide_cn-training`\n     - 王建民/DrugAI\n     - https://github.com/AspirinCode\n   * - :ref:`guide_cn-training-node-classification`,\n     - 王建民/DrugAI\n     - https://github.com/AspirinCode\n   * - :ref:`guide_cn-training-edge-classification`\n     - 徐东辉/DonghuiXu\n     - https://github.com/rewonderful\n   * - :ref:`guide_cn-training-link-prediction`\n     - 徐东辉/DonghuiXu\n     - https://github.com/rewonderful\n   * - :ref:`guide_cn-training-graph-classification`\n     - 莫佳帅子/Molasses\n     - https://github.com/sleeplessai\n   * - :ref:`guide_cn-minibatch`\n     - 莫佳帅子/Molasses\n     - https://github.com/sleeplessai\n   * - :ref:`guide_cn-minibatch-node-classification-sampler`\n     - 孟凡荣/kevin-meng\n     - https://github.com/kevin-meng\n   * - :ref:`guide_cn-minibatch-edge-classification-sampler`\n     - 莫佳帅子/Molasses\n     - https://github.com/sleeplessai\n   * - :ref:`guide_cn-minibatch-link-classification-sampler`\n     - 孟凡荣/kevin-meng\n     - https://github.com/kevin-meng\n   * - :ref:`guide_cn-minibatch-customizing-neighborhood-sampler`\n     - 孟凡荣/kevin-meng\n     - https://github.com/kevin-meng\n   * - :ref:`guide_cn-minibatch-custom-gnn-module`\n     - 胡骏\n     - https://github.com/CrawlScript\n   * - :ref:`guide_cn-minibatch-inference`\n     - 胡骏\n     - https://github.com/CrawlScript\n   * - :ref:`guide_cn-distributed`\n     - 宋怡然/Yiran Song\n     - https://github.com/rr-Yiran\n   * - 
:ref:`guide_cn-distributed-preprocessing`\n     - 宋怡然/Yiran Song\n     - https://github.com/rr-Yiran\n   * - :ref:`guide_cn-distributed-apis`\n     - 李庆标/Qingbiao Li\n     - https://qingbiaoli.github.io/\n   * - :ref:`guide_cn-distributed-tools`\n     - 李庆标/Qingbiao Li\n     - https://qingbiaoli.github.io/\n"
  },
  {
    "path": "docs/source/guide_cn/message-api.rst",
    "content": ".. _guide_cn-message-passing-api:\n\n2.1 内置函数和消息传递API\n----------------------\n\n:ref:`(English Version) <guide-message-passing-api>`\n\n在DGL中，**消息函数** 接受一个参数 ``edges``，这是一个 :class:`~dgl.udf.EdgeBatch` 的实例，\n在消息传递时，它被DGL在内部生成以表示一批边。 ``edges`` 有 ``src``、 ``dst`` 和 ``data`` 共3个成员属性，\n分别用于访问源节点、目标节点和边的特征。\n\n**聚合函数** 接受一个参数 ``nodes``，这是一个 :class:`~dgl.udf.NodeBatch` 的实例，\n在消息传递时，它被DGL在内部生成以表示一批节点。 ``nodes`` 的成员属性 ``mailbox`` 可以用来访问节点收到的消息。\n一些最常见的聚合操作包括 ``sum``、``max``、``min`` 等。\n\n**更新函数** 接受一个如上所述的参数 ``nodes``。此函数对 ``聚合函数`` 的聚合结果进行操作，\n通常在消息传递的最后一步将其与节点的特征相结合，并将输出作为节点的新特征。\n\nDGL在命名空间 ``dgl.function`` 中实现了常用的消息函数和聚合函数作为 **内置函数**。\n一般来说，DGL建议 **尽可能** 使用内置函数，因为它们经过了大量优化，并且可以自动处理维度广播。\n\n如果用户的消息传递函数无法用内置函数实现，则可以实现自己的消息或聚合函数(也称为 **用户定义函数** )。\n\n内置消息函数可以是一元函数或二元函数。对于一元函数，DGL支持 ``copy`` 函数。对于二元函数，\nDGL现在支持 ``add``、 ``sub``、 ``mul``、 ``div``、 ``dot`` 函数。消息的内置函数的命名约定是 ``u`` 表示 ``源`` 节点，\n``v`` 表示 ``目标`` 节点，``e`` 表示 ``边``。这些函数的参数是字符串，指示相应节点和边的输入和输出特征字段名。\n关于内置函数的列表，请参见 :ref:`api-built-in`。例如，要对源节点的 ``hu`` 特征和目标节点的 ``hv`` 特征求和，\n然后将结果保存在边的 ``he`` 特征上，用户可以使用内置函数 ``dgl.function.u_add_v('hu', 'hv', 'he')``。\n而以下用户定义消息函数与此内置函数等价。\n\n.. code::\n\n    def message_func(edges):\n         return {'he': edges.src['hu'] + edges.dst['hv']}\n\nDGL支持内置的聚合函数 ``sum``、 ``max``、 ``min`` 和 ``mean`` 操作。\n聚合函数通常有两个参数，它们的类型都是字符串。一个用于指定 ``mailbox`` 中的字段名，一个用于指示目标节点特征的字段名，\n例如， ``dgl.function.sum('m', 'h')`` 等价于如下所示的对接收到消息求和的用户定义函数：\n\n.. code::\n\n    import torch\n    def reduce_func(nodes):\n         return {'h': torch.sum(nodes.mailbox['m'], dim=1)}\n\n关于用户定义函数的进阶用法，参见 :ref:`apiudf`。\n\n在DGL中，也可以在不涉及消息传递的情况下，通过 :meth:`~dgl.DGLGraph.apply_edges` 单独调用逐边计算。\n:meth:`~dgl.DGLGraph.apply_edges` 的参数是一个消息函数。并且在默认情况下，这个接口将更新所有的边。例如：\n\n.. 
code::\n\n    import dgl.function as fn\n    graph.apply_edges(fn.u_add_v('el', 'er', 'e'))\n\n对于消息传递， :meth:`~dgl.DGLGraph.update_all` 是一个高级API。它在单个API调用里合并了消息生成、\n消息聚合和节点特征更新，这为从整体上进行系统优化提供了空间。\n\n:meth:`~dgl.DGLGraph.update_all` 的参数是一个消息函数、一个聚合函数和一个更新函数。\n更新函数是一个可选择的参数，用户也可以不使用它，而是在 ``update_all`` 执行完后直接对节点特征进行操作。\n由于更新函数通常可以用纯张量操作实现，所以DGL不推荐在 ``update_all`` 中指定更新函数。例如：\n\n.. code::\n\n    def update_all_example(graph):\n        # 在graph.ndata['ft']中存储结果\n        graph.update_all(fn.u_mul_e('ft', 'a', 'm'),\n                         fn.sum('m', 'ft'))\n        # 在update_all外调用更新函数\n        final_ft = graph.ndata['ft'] * 2\n        return final_ft\n\n此调用通过将源节点特征 ``ft`` 与边特征 ``a`` 相乘生成消息 ``m``，\n然后对所有消息求和来更新节点特征 ``ft``，再将 ``ft`` 乘以2得到最终结果 ``final_ft``。\n\n调用后，中间消息 ``m`` 将被清除。上述函数的数学公式为：\n\n.. math::  {final\\_ft}_i = 2 * \\sum_{j\\in\\mathcal{N}(i)} ({ft}_j * a_{ij})\n"
  },
  {
    "path": "docs/source/guide_cn/message-efficient.rst",
    "content": ".. _guide_cn-message-passing-efficient:\n\n2.2 编写高效的消息传递代码\n----------------------\n\n:ref:`(English Version) <guide-message-passing-efficient>`\n\nDGL优化了消息传递的内存消耗和计算速度。利用这些优化的一个常见实践是通过基于内置函数的 :meth:`~dgl.DGLGraph.update_all` 来开发消息传递功能。\n\n除此之外，考虑到某些图边的数量远远大于节点的数量，DGL建议避免不必要的从点到边的内存拷贝。对于某些情况，比如 :class:`~dgl.nn.pytorch.conv.GATConv`，计算必须在边上保存消息，\n那么用户就需要调用基于内置函数的 :meth:`~dgl.DGLGraph.apply_edges`。有时边上的消息可能是高维的，这会非常消耗内存。\nDGL建议用户尽量减少边的特征维数。\n\n下面是一个如何通过对节点特征降维来减少消息维度的示例。该做法执行以下操作：拼接 ``源`` 节点和 ``目标`` 节点特征，\n然后应用一个线性层，即 :math:`W\\times (u || v)`。 ``源`` 节点和 ``目标`` 节点特征维数较高，而线性层输出维数较低。\n一个直截了当的实现方式如下：\n\n.. code::\n\n    import torch\n    import torch.nn as nn\n\n    linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, out_dim)))\n    def concat_message_function(edges):\n         return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=1)}\n    g.apply_edges(concat_message_function)\n    g.edata['out'] = g.edata['cat_feat'] @ linear\n\n建议的实现是将线性操作分成两部分，一个应用于 ``源`` 节点特征，另一个应用于 ``目标`` 节点特征。\n在最后一个阶段，在边上将以上两部分线性操作的结果相加，即执行 :math:`W_l\\times u + W_r \\times v`，\n因为 :math:`W \\times (u||v) = W_l \\times u + W_r \\times v`，其中 :math:`W_l` 和 :math:`W_r` 分别是矩阵\n:math:`W` 的左半部分和右半部分：\n\n.. code::\n\n    import dgl.function as fn\n\n    linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))\n    linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))\n    out_src = g.ndata['feat'] @ linear_src\n    out_dst = g.ndata['feat'] @ linear_dst\n    g.srcdata.update({'out_src': out_src})\n    g.dstdata.update({'out_dst': out_dst})\n    g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))\n\n以上两个实现在数学上是等价的。后一种方法效率高得多，因为不需要在边上保存feat_src和feat_dst，\n从内存角度来说是高效的。另外，加法可以通过DGL的内置函数 ``u_add_v`` 进行优化，从而进一步加快计算速度并节省内存占用。\n"
  },
  {
    "path": "docs/source/guide_cn/message-heterograph.rst",
    "content": ".. _guide_cn-message-passing-heterograph:\n\n2.5 在异构图上进行消息传递\n----------------------\n\n:ref:`(English Version) <guide-message-passing-heterograph>`\n\n异构图（参考用户指南 :ref:`1.5 异构图 <guide_cn-graph-heterogeneous>` ）是包含不同类型的节点和边的图。\n不同类型的节点和边常常具有不同类型的属性。这些属性旨在刻画每一种节点和边的特征。在使用图神经网络时，根据其复杂性，\n可能需要使用不同维度的表示来对不同类型的节点和边进行建模。\n\n异构图上的消息传递可以分为两个部分：\n\n1. 对每个关系计算和聚合消息。\n2. 对每个节点聚合来自不同关系的消息。\n\n在DGL中，对异构图进行消息传递的接口是 :meth:`~dgl.DGLGraph.multi_update_all`。\n:meth:`~dgl.DGLGraph.multi_update_all` 接受一个字典。这个字典的每一个键值对里，键是一种关系，\n值是这种关系对应 :meth:`~dgl.DGLGraph.update_all` 的参数。\n:meth:`~dgl.DGLGraph.multi_update_all` 还接受一个字符串来表示跨类型整合函数，来指定整合不同关系聚合结果的方式。\n这个整合方式可以是 ``sum``、 ``min``、 ``max``、 ``mean`` 和 ``stack`` 中的一个。以下是一个例子：\n\n.. code::\n\n    import dgl.function as fn\n\n    for c_etype in G.canonical_etypes:\n        srctype, etype, dsttype = c_etype\n        Wh = self.weight[etype](feat_dict[srctype])\n        # 把它存在图中用来做消息传递\n        G.nodes[srctype].data['Wh_%s' % etype] = Wh\n        # 指定每个关系的消息传递函数：(message_func, reduce_func).\n        # 注意结果保存在同一个目标特征“h”，说明聚合是逐类进行的。\n        funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))\n    # 将每个类型消息聚合的结果相加。\n    G.multi_update_all(funcs, 'sum')\n    # 返回更新过的节点特征字典\n    return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}\n"
  },
  {
    "path": "docs/source/guide_cn/message-part.rst",
    "content": ".. _guide_cn-message-passing-part:\n\n2.3 在图的一部分上进行消息传递\n-------------------------\n\n:ref:`(English Version) <guide-message-passing-part>`\n\n如果用户只想更新图中的部分节点，可以先通过想要囊括的节点编号创建一个子图，\n然后在子图上调用 :meth:`~dgl.DGLGraph.update_all` 方法。例如：\n\n.. code::\n\n    nid = [0, 2, 3, 6, 7, 9]\n    sg = g.subgraph(nid)\n    sg.update_all(message_func, reduce_func, apply_node_func)\n\n这是小批量训练中的常见用法。更多详细用法请参考用户指南 :ref:`guide_cn-minibatch`。"
  },
  {
    "path": "docs/source/guide_cn/message.rst",
    "content": ".. _guide_cn-message-passing:\n\n第2章：消息传递范式\n===========================\n\n:ref:`(English Version) <guide-message-passing>`\n\n消息传递是实现GNN的一种通用框架和编程范式。它从聚合与更新的角度归纳总结了多种GNN模型的实现。\n\n消息传递范式\n----------------------\n\n假设节点 :math:`v` 上的特征为 :math:`x_v\\in\\mathbb{R}^{d_1}`，边 :math:`({u}, {v})` 上的特征为 :math:`w_{e}\\in\\mathbb{R}^{d_2}`。\n**消息传递范式** 定义了以下逐节点和边上的计算：\n\n.. math::  \\text{边上计算: } m_{e}^{(t+1)} = \\phi \\left( x_v^{(t)}, x_u^{(t)}, w_{e}^{(t)} \\right) , ({u}, {v},{e}) \\in \\mathcal{E}.\n\n.. math::  \\text{点上计算: } x_v^{(t+1)} = \\psi \\left(x_v^{(t)}, \\rho\\left(\\left\\lbrace m_{e}^{(t+1)} : ({u}, {v},{e}) \\in \\mathcal{E} \\right\\rbrace \\right) \\right).\n\n在上面的等式中， :math:`\\phi` 是定义在每条边上的消息函数，它通过将边上特征与其两端节点的特征相结合来生成消息。\n**聚合函数** :math:`\\rho` 会聚合节点接收到的消息。 **更新函数** :math:`\\psi` 会结合聚合后的消息和节点本身的特征来更新节点的特征。\n\n本章路线图\n--------------------\n\n本章首先介绍了DGL的消息传递API。然后讲解了如何高效地在点和边上使用这些API。本章的最后一节解释了如何在异构图上实现消息传递。\n\n* :ref:`guide_cn-message-passing-api`\n* :ref:`guide_cn-message-passing-efficient`\n* :ref:`guide_cn-message-passing-part`\n* :ref:`guide_cn-message-passing-heterograph`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    message-api\n    message-efficient\n    message-part\n    message-heterograph\n"
  },
  {
    "path": "docs/source/guide_cn/minibatch-custom-sampler.rst",
    "content": ".. _guide_cn-minibatch-customizing-neighborhood-sampler:\n\n6.4 定制用户自己的邻居采样器\n----------------------------------------------\n\n:ref:`(English Version) <guide-minibatch-customizing-neighborhood-sampler>`\n\n虽然DGL提供了一些邻居采样器，但有时用户还是希望编写自己的采样器。\n本节会说明如何编写用户自己的采样器并将其加入到GNN的训练框架中。\n\n回想一下在\n`How Powerful are Graph Neural Networks <https://arxiv.org/pdf/1810.00826.pdf>`__\n的论文中，消息传递的定义是：\n\n.. math::\n\n   \\begin{gathered}\n     \\boldsymbol{a}_v^{(l)} = \\rho^{(l)} \\left(\n       \\left\\lbrace\n         \\boldsymbol{h}_u^{(l-1)} : u \\in \\mathcal{N} \\left( v \\right)\n       \\right\\rbrace\n     \\right)\n   \\\\\n     \\boldsymbol{h}_v^{(l)} = \\phi^{(l)} \\left(\n       \\boldsymbol{h}_v^{(l-1)}, \\boldsymbol{a}_v^{(l)}\n     \\right)\n   \\end{gathered}\n\n其中， :math:`\\rho^{(l)}` 和 :math:`\\phi^{(l)}` 分别是可自定义的聚合函数与更新函数，\n:math:`\\mathcal{N}(v)` 为有向图 :math:`\\mathcal{G}` 上的节点 :math:`v` 的前驱节点(或无向图中的邻居)。\n\n以下图为例，假设红色节点为需要更新的目标节点：\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_4_0.png\n   :alt: Imgur\n\n\n消息传递需要聚集其邻居(绿色节点)的节点特征，如下图所示：\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_4_1.png\n   :alt: Imgur\n\n\n理解邻居采样的工作原理\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n在介绍DGL中邻居采样的用法之前，这里先解释一下邻居采样的工作原理。下文继续使用上述的例子。\n首先定义一个如上图所示的DGLGraph。\n\n.. code:: python\n\n    import torch\n    import dgl\n\n    src = torch.LongTensor(\n        [0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10,\n         1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11])\n    dst = torch.LongTensor(\n        [1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11,\n         0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10])\n    g = dgl.graph((src, dst))\n\n该例子的目标是计算单个节点(节点8)的输出。DGL将需要计算GNN输出的节点称为 *种子节点* 。\n\n找出消息传递的依赖\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n假设要使用2层GNN计算种子节点8(红色点)的输出：\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_4_2.png\n   :alt: Imgur\n\n\n其消息传递的计算公式如下：\n\n.. 
math::\n\n   \\begin{gathered}\n     \\boldsymbol{a}_8^{(2)} = \\rho^{(2)} \\left(\n       \\left\\lbrace\n         \\boldsymbol{h}_u^{(1)} : u \\in \\mathcal{N} \\left( 8 \\right)\n       \\right\\rbrace\n     \\right) = \\rho^{(2)} \\left(\n       \\left\\lbrace\n         \\boldsymbol{h}_4^{(1)}, \\boldsymbol{h}_5^{(1)},\n         \\boldsymbol{h}_7^{(1)}, \\boldsymbol{h}_{11}^{(1)}\n       \\right\\rbrace\n     \\right)\n   \\\\\n     \\boldsymbol{h}_8^{(2)} = \\phi^{(2)} \\left(\n       \\boldsymbol{h}_8^{(1)}, \\boldsymbol{a}_8^{(2)}\n     \\right)\n   \\end{gathered}\n\n从公式中可以看出，要计算 :math:`\\boldsymbol{h}_8^{(2)}`，需要下图中的来自节点4、5、7和11(绿色点)的消息。\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_4_3.png\n   :alt: Imgur\n\n\n上图中隐去了和计算不相关的边，仅仅保留了输出节点所需要收集消息的边。DGL称它们为红色节点8在第二个GNN层的 *边界子图*。\n\nDGL实现了多个可用于生成边界的函数。例如，\n:func:`dgl.in_subgraph()` 是一个生成子图的函数，该子图包括初始图中的所有节点和指定节点的入边。\n用户可以将其用作沿所有入边传递消息的边界。\n\n.. code:: python\n\n    frontier = dgl.in_subgraph(g, [8])\n    print(frontier.all_edges())\n\n想了解更多的相关函数，用户可以参考 :ref:`api-subgraph-extraction` 和 :ref:`api-sampling`。\n\n在DGL中，任何具有与初始图相同的节点的图都可以用作边界。这点在之后的\n:ref:`guide_cn-minibatch-customizing-neighborhood-sampler-impl`\n章节中也会提到。\n\n多层小批量消息传递的二分计算图\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n从上图中可以看到，从 :math:`\\boldsymbol{h}_\\cdot^{(1)}` 计算\n:math:`\\boldsymbol{h}_8^{(2)}` 只需要节点4, 5, 7, 8和11(绿色和红色节点)作为输入。\n原图上的其他节点是不参与计算的，因此直接在边界子图上执行消息传递有很大开销。\n因此，DGL对边界子图做了一个转换，把它的计算依赖关系变成了一个小的二分图。\nDGL称这种仅包含必要的输入节点和输出节点的二分图为一个 *块* (block)。\n下图显示了以节点8为种子节点时第二个GNN层所需的块。\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_4_4.png\n   :alt: Imgur\n\n\n请注意，输出节点也出现在输入节点中。原因是消息传递后的特征组合需要前一层的输出节点表示\n(即 :math:`\\phi^{(2)}`)。\n\nDGL提供了 :func:`dgl.to_block` 以将任何边界转换为块。其中第一个参数指定边界，\n第二个参数指定输出节点。例如，可以使用以下代码将上述边界转换为输出节点为8的块。\n\n.. 
code:: python\n\n    output_nodes = torch.LongTensor([8])\n    block = dgl.to_block(frontier, output_nodes)\n\n要查找给定节点类型的输入节点和输出节点的数量，可以使用\n:meth:`dgl.DGLGraph.number_of_src_nodes`  和\n:meth:`dgl.DGLGraph.number_of_dst_nodes` 方法。\n\n.. code:: python\n\n    num_input_nodes, num_output_nodes = block.number_of_src_nodes(), block.number_of_dst_nodes()\n    print(num_input_nodes, num_output_nodes)\n\n可以通过 :attr:`dgl.DGLGraph.srcdata` 和\n:attr:`dgl.DGLGraph.srcnodes` 访问该块的输入节点特征，\n并且可以通过 :attr:`dgl.DGLGraph.dstdata` 和\n:attr:`dgl.DGLGraph.dstnodes` 访问其输出节点特征。\n``srcdata``/``dstdata`` 和 ``srcnodes``/``dstnodes``\n的语法与常规图中的 :attr:`dgl.DGLGraph.ndata` 和 :attr:`dgl.DGLGraph.nodes` 相同。\n\n.. code:: python\n\n    block.srcdata['h'] = torch.randn(num_input_nodes, 5)\n    block.dstdata['h'] = torch.randn(num_output_nodes, 5)\n\n如果是从图中得到的边界，再由边界转换成块，则可以通过以下方式直接读取块的输入和输出节点的特征。\n\n.. code:: python\n\n    print(block.srcdata['x'])\n    print(block.dstdata['y'])\n\n.. raw:: html\n\n   <div class=\"alert alert-info\">\n\n::\n\n用户可以通过 ``dgl.NID`` 得到块中输入节点和输出节点的初始节点ID，可以通过 ``dgl.EID``\n得到边ID到输入边界中边的初始ID的映射。\n\n.. raw:: html\n\n   </div>\n\n**输出节点**\n\nDGL确保块的输出节点将始终出现在输入节点中。如下代码所演示的，在输入节点中，输出节点的ID位于其它节点之前。\n\n.. code:: python\n\n    input_nodes = block.srcdata[dgl.NID]\n    output_nodes = block.dstdata[dgl.NID]\n    assert torch.equal(input_nodes[:len(output_nodes)], output_nodes)\n\n因此，在用多层图神经网络时，中间某一层对应的边界需要包含该层及所有后续层计算涉及边的目标节点。例如，考虑以下边界\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_4_5.png\n   :alt: Imgur\n\n\n其中红色和绿色节点（即节点4、5、7、8和11）都是后续图神经网络层计算中某条边的目标节点。\n以下代码由于输出节点未覆盖所有这些节点，将会报错。\n\n.. code:: python\n\n    dgl.to_block(frontier2, torch.LongTensor([4, 5]))   # ERROR\n\n但是，输出节点可以比以上节点包含更多节点。下例的输出节点包含了没有入边的孤立节点。\n输入节点和输出节点将同时包含这些孤立节点。\n\n.. 
code:: python\n\n    # 节点3是一个孤立节点，没有任何指向它的边.\n    block3 = dgl.to_block(frontier2, torch.LongTensor([4, 5, 7, 8, 11, 3]))\n    print(block3.srcdata[dgl.NID])\n    print(block3.dstdata[dgl.NID])\n\n异构图上的采样\n^^^^^^^^^^^^^^^^^^^^\n\n块也可用于异构图。假设有如下的边界：\n\n.. code:: python\n\n    hetero_frontier = dgl.heterograph({\n        ('user', 'follow', 'user'): ([1, 3, 7], [3, 6, 8]),\n        ('user', 'play', 'game'): ([5, 5, 4], [6, 6, 2]),\n        ('game', 'played-by', 'user'): ([2], [6])\n    }, num_nodes_dict={'user': 10, 'game': 10})\n\n可以创建一个如下的块，块的输出节点为 ``User`` 节点3、6、8和 ``Game`` 节点2、6。\n\n.. code:: python\n\n    hetero_block = dgl.to_block(hetero_frontier, {'user': [3, 6, 8], 'game': [2, 6]})\n\n对于这个块，用户可以按节点类型来获取输入节点和输出节点：\n\n.. code:: python\n\n    # 输入的User和Game节点\n    print(hetero_block.srcnodes['user'].data[dgl.NID], hetero_block.srcnodes['game'].data[dgl.NID])\n    # 输出的User和Game节点\n    print(hetero_block.dstnodes['user'].data[dgl.NID], hetero_block.dstnodes['game'].data[dgl.NID])\n\n\n.. _guide_cn-minibatch-customizing-neighborhood-sampler-impl:\n\n实现一个自定义邻居采样器\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n前面章节里给出了以下用在节点分类任务的邻居采样器。\n\n.. code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n\n想实现自定义的邻居采样策略，用户可以将采样器对象替换为自定义的采样器对象。\n为此，先来看一下\n:class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler`\n的父类\n:class:`~dgl.dataloading.dataloader.BlockSampler`。\n\n:class:`~dgl.dataloading.dataloader.BlockSampler`\n负责使用\n:meth:`~dgl.dataloading.dataloader.BlockSampler.sample_blocks`\n方法从最后一层开始生成一个块的列表。 ``sample_blocks`` 的默认实现是向后迭代，生成边界，并将其转换为块。\n\n因此，对于邻居采样，**用户仅需要实现**\\ :meth:`~dgl.dataloading.dataloader.BlockSampler.sample_frontier`\\ **方法**。\n给定GNN层、初始图和要计算表示的节点，该方法负责为它们生成边界。\n\n同时，用户还必须将GNN的层数传递给父类。\n\n例如， :class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler` 的实现如下。\n\n.. 
code:: python\n\n    class MultiLayerFullNeighborSampler(dgl.dataloading.BlockSampler):\n        def __init__(self, n_layers):\n            super().__init__(n_layers)\n    \n        def sample_frontier(self, block_id, g, seed_nodes):\n            frontier = dgl.in_subgraph(g, seed_nodes)\n            return frontier\n\n:class:`dgl.dataloading.neighbor.MultiLayerNeighborSampler`\n是一个更复杂的邻居采样器类，它允许用户为每个节点采样部分邻居节点以汇聚信息，如下所示。\n\n.. code:: python\n\n    class MultiLayerNeighborSampler(dgl.dataloading.BlockSampler):\n        def __init__(self, fanouts):\n            super().__init__(len(fanouts))\n    \n            self.fanouts = fanouts\n    \n        def sample_frontier(self, block_id, g, seed_nodes):\n            fanout = self.fanouts[block_id]\n            if fanout is None:\n                frontier = dgl.in_subgraph(g, seed_nodes)\n            else:\n                frontier = dgl.sampling.sample_neighbors(g, seed_nodes, fanout)\n            return frontier\n\n虽然上面的函数可以生成边界，但是任何拥有与初始图相同节点的图都可用作边界。\n\n例如，如果要以某种概率将种子节点的入边随机剔除，则可以按照以下方式简单地定义采样器：\n\n.. code:: python\n\n    class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):\n        def __init__(self, p, num_layers):\n            super().__init__(num_layers)\n\n            self.p = p\n    \n        def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):\n            # 获取 `seed_nodes` 的所有入边\n            src, dst = dgl.in_subgraph(g, seed_nodes).all_edges()\n            # 以概率p随机选择边\n            mask = torch.zeros_like(src).bernoulli_(self.p).bool()\n            src = src[mask]\n            dst = dst[mask]\n            # 返回一个与初始图有相同节点的边界\n            frontier = dgl.graph((src, dst), num_nodes=g.num_nodes())\n            return frontier\n    \n        def __len__(self):\n            return self.num_layers\n\n在实现自定义采样器后，用户可以创建一个数据加载器。这个数据加载器使用用户自定义的采样器，\n并且遍历种子节点生成一系列的块。\n\n.. 
code:: python\n\n    sampler = MultiLayerDropoutSampler(0.5, 2)\n    dataloader = dgl.dataloading.NodeDataLoader(\n        g, train_nids, sampler,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n    \n    model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, output_nodes, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        input_features = blocks[0].srcdata     # 返回一个字典\n        output_labels = blocks[-1].dstdata     # 返回一个字典\n        output_predictions = model(blocks, input_features)\n        loss = compute_loss(output_labels, output_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n异构图上自定义采样器\n^^^^^^^^^^^^^^^^^^^^\n\n为异构图生成边界与为同构图生成边界没有什么不同。只要使返回的图具有与初始图相同的节点，\n就可以正常工作。例如，可以重写上面的 ``MultiLayerDropoutSampler`` 以遍历所有的边类型，\n以便它也可以在异构图上使用。\n\n.. code:: python\n\n    class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):\n        def __init__(self, p, num_layers):\n            super().__init__(num_layers)\n\n            self.p = p\n    \n        def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):\n            # 获取 `seed_nodes` 的所有入边\n            sg = dgl.in_subgraph(g, seed_nodes)\n    \n            new_edges_masks = {}\n            # 遍历所有边的类型\n            for etype in sg.canonical_etypes:\n                edge_mask = torch.zeros(sg.num_edges(etype))\n                edge_mask.bernoulli_(self.p)\n                new_edges_masks[etype] = edge_mask.bool()\n    \n            # 返回一个与初始图有相同节点的图作为边界\n            frontier = dgl.edge_subgraph(sg, new_edges_masks, relabel_nodes=False)\n            return frontier\n    \n        def __len__(self):\n            return self.num_layers\n            "
  },
  {
    "path": "docs/source/guide_cn/minibatch-edge.rst",
    "content": ".. _guide_cn-minibatch-edge-classification-sampler:\n\n6.2 针对边分类任务的邻居采样训练方法\n----------------------------------------------------------------------\n\n:ref:`(English Version) <guide-minibatch-edge-classification-sampler>`\n\n边分类/回归的训练与节点分类/回归的训练类似，但还是有一些明显的区别。\n\n定义邻居采样器和数据加载器\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n用户可以使用\n:ref:`和节点分类一样的邻居采样器 <guide_cn-minibatch-node-classification-sampler>`。\n\n.. code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n\n想要用DGL提供的邻居采样器做边分类，需要将其与\n:class:`~dgl.dataloading.pytorch.EdgeDataLoader` 结合使用。\n:class:`~dgl.dataloading.pytorch.EdgeDataLoader` 以小批次的形式对一组边进行迭代，\n从而产生包含边小批次的子图以及供下文中模块使用的 ``块``。\n\n例如，以下代码创建了一个PyTorch数据加载器，该PyTorch数据加载器以批的形式迭代训练边ID的数组\n``train_eids``，并将生成的块列表放到GPU上。\n\n.. code:: python\n\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n有关DGL的内置采样器的完整列表，用户可以参考\n:ref:`neighborhood sampler API reference <api-dataloading-neighbor-sampling>`。\n\n如果用户希望开发自己的邻居采样器，或者想要对块的概念有更详细的了解，请参考\n:ref:`guide_cn-minibatch-customizing-neighborhood-sampler`。\n\n小批次邻居采样训练时删边\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n用户在训练边分类模型时，有时希望从计算依赖中删除出现在训练数据中的边，就好像这些边根本不存在一样。\n否则，模型将 \"知道\" 两个节点之间存在边的联系，并有可能利用这点 \"作弊\" 。\n\n因此，在基于邻居采样的边分类中，用户有时会希望从采样得到的小批次图中删去部分边及其对应的反向边。\n用户可以在实例化\n:class:`~dgl.dataloading.pytorch.EdgeDataLoader`\n时设置 ``exclude='reverse_id'``，同时将边ID映射到其反向边ID。\n通常这样做会导致采样过程变慢很多，这是因为DGL要定位并删除包含在小批次中的反向边。\n\n.. 
code:: python\n\n    n_edges = g.num_edges()\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n\n        # 下面的两个参数专门用于在邻居采样时删除小批次的一些边和它们的反向边\n        exclude='reverse_id',\n        reverse_eids=torch.cat([\n            torch.arange(n_edges // 2, n_edges), torch.arange(0, n_edges // 2)]),\n    \n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n调整模型以适用小批次训练\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n边分类模型通常由两部分组成：\n\n-  获取边两端节点的表示。\n-  用边两端节点表示为每个类别打分。\n\n第一部分与\n:ref:`随机批次训练节点分类 <guide_cn-minibatch-node-classification-model>`\n完全相同，用户可以简单地复用它。输入仍然是DGL的数据加载器生成的块列表和输入特征。\n\n.. code:: python\n\n    class StochasticTwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.conv1 = dglnn.GraphConv(in_features, hidden_features)\n            self.conv2 = dglnn.GraphConv(hidden_features, out_features)\n    \n        def forward(self, blocks, x):\n            x = F.relu(self.conv1(blocks[0], x))\n            x = F.relu(self.conv2(blocks[1], x))\n            return x\n\n第二部分的输入通常是前一部分的输出，以及由小批次边导出的原始图的子图。\n子图是从相同的数据加载器产生的。用户可以调用 :meth:`dgl.DGLGraph.apply_edges` 计算边子图中边的得分。\n\n以下代码片段实现了通过合并边两端节点的特征并将其映射到全连接层来预测边的得分。\n\n.. code:: python\n\n    class ScorePredictor(nn.Module):\n        def __init__(self, num_classes, in_features):\n            super().__init__()\n            self.W = nn.Linear(2 * in_features, num_classes)\n    \n        def apply_edges(self, edges):\n            data = torch.cat([edges.src['x'], edges.dst['x']], 1)\n            return {'score': self.W(data)}\n    \n        def forward(self, edge_subgraph, x):\n            with edge_subgraph.local_scope():\n                edge_subgraph.ndata['x'] = x\n                edge_subgraph.apply_edges(self.apply_edges)\n                return edge_subgraph.edata['score']\n\n模型接受数据加载器生成的块列表、边子图以及输入节点特征进行前向传播，如下所示：\n\n.. 
code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, num_classes):\n            super().__init__()\n            self.gcn = StochasticTwoLayerGCN(\n                in_features, hidden_features, out_features)\n            self.predictor = ScorePredictor(num_classes, out_features)\n    \n        def forward(self, edge_subgraph, blocks, x):\n            x = self.gcn(blocks, x)\n            return self.predictor(edge_subgraph, x)\n\nDGL保证边子图中的节点与生成的块列表中最后一个块的输出节点相同。\n\n模型的训练\n~~~~~~~~~~~~~\n\n模型的训练与节点分类的随机批次训练的情况非常相似。用户可以遍历数据加载器以获得由小批次边组成的子图，\n以及计算其两端节点表示所需的块列表。\n\n.. code:: python\n\n    model = Model(in_features, hidden_features, out_features, num_classes)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, edge_subgraph, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        edge_subgraph = edge_subgraph.to(torch.device('cuda'))\n        input_features = blocks[0].srcdata['features']\n        edge_labels = edge_subgraph.edata['labels']\n        edge_predictions = model(edge_subgraph, blocks, input_features)\n        loss = compute_loss(edge_labels, edge_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n异构图上的模型训练\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n在异构图上，计算节点表示的模型也可以用于计算边分类/回归所需的两端节点的表示。\n\n.. 
code:: python\n\n    class StochasticTwoLayerRGCN(nn.Module):\n        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):\n            super().__init__()\n            self.conv1 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')\n                    for rel in rel_names\n                })\n            self.conv2 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')\n                    for rel in rel_names\n                })\n    \n        def forward(self, blocks, x):\n            x = self.conv1(blocks[0], x)\n            x = self.conv2(blocks[1], x)\n            return x\n\n在同构图和异构图上做评分预测时，代码实现的唯一不同在于调用\n:meth:`~dgl.DGLGraph.apply_edges`\n时需要在特定类型的边上进行迭代。\n\n.. code:: python\n\n    class ScorePredictor(nn.Module):\n        def __init__(self, num_classes, in_features):\n            super().__init__()\n            self.W = nn.Linear(2 * in_features, num_classes)\n    \n        def apply_edges(self, edges):\n            data = torch.cat([edges.src['x'], edges.dst['x']], 1)\n            return {'score': self.W(data)}\n    \n        def forward(self, edge_subgraph, x):\n            with edge_subgraph.local_scope():\n                edge_subgraph.ndata['x'] = x\n                for etype in edge_subgraph.canonical_etypes:\n                    edge_subgraph.apply_edges(self.apply_edges, etype=etype)\n                return edge_subgraph.edata['score']\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, num_classes,\n                     etypes):\n            super().__init__()\n            self.rgcn = StochasticTwoLayerRGCN(\n                in_features, hidden_features, out_features, etypes)\n            self.pred = ScorePredictor(num_classes, out_features)\n\n        def forward(self, edge_subgraph, blocks, x):\n            x = self.rgcn(blocks, x)\n            return self.pred(edge_subgraph, 
x)\n\n数据加载器的定义也与节点分类的非常相似。唯一的区别是用户需要使用\n:class:`~dgl.dataloading.pytorch.EdgeDataLoader`\n而不是\n:class:`~dgl.dataloading.pytorch.NodeDataLoader`，\n并且提供边类型和边ID张量的字典，而不是节点类型和节点ID张量的字典。\n\n.. code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n如果用户希望删除异构图中的反向边，情况会有所不同。在异构图上，\n反向边通常具有与正向边本身不同的边类型，以便区分 ``向前`` 和 ``向后`` 关系。\n例如，``关注`` 和 ``被关注`` 是一对相反的关系， ``购买`` 和 ``被买下`` 也是一对相反的关系。\n\n如果一个类型中的每个边都有一个与之对应的ID相同、属于另一类型的反向边，\n则用户可以指定边类型及其反向边类型之间的映射。删除小批次中的边及其反向边的方法如下。\n\n.. code:: python\n\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n    \n        # 下面的两个参数专门用于在邻居采样时删除小批次的一些边和它们的反向边\n        exclude='reverse_types',\n        reverse_etypes={'follow': 'followed by', 'followed by': 'follow',\n                        'purchase': 'purchased by', 'purchased by': 'purchase'},\n    \n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n除了 ``compute_loss`` 的代码实现有所不同，异构图的训练循环与同构图中的训练循环几乎相同，\n计算损失函数接受节点类型和预测的两个字典。\n\n.. code:: python\n\n    model = Model(in_features, hidden_features, out_features, num_classes, etypes)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, edge_subgraph, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        edge_subgraph = edge_subgraph.to(torch.device('cuda'))\n        input_features = blocks[0].srcdata['features']\n        edge_labels = edge_subgraph.edata['labels']\n        edge_predictions = model(edge_subgraph, blocks, input_features)\n        loss = compute_loss(edge_labels, edge_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n`GCMC <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcmc>`__\n是一个在二分图上做边分类的代码示例。\n\n"
  },
  {
    "path": "docs/source/guide_cn/minibatch-inference.rst",
    "content": ".. _guide_cn-minibatch-inference:\n\n6.6 超大图上的精准离线推断\n------------------------------------------------------\n\n:ref:`(English Version) <guide-minibatch-inference>`\n\n子图采样和邻居采样都是为了减少用GPU训练GNN模型的内存和时间消耗。在进行推断时，\n通常更好的方法是将所有邻居进行真正的聚合，以避免采样所带来的随机性。\n然而，在GPU上进行全图前向传播通常由于显存大小的限制而不可行，而在CPU上进行则计算速度很慢。\n本节介绍了在GPU显存有限的情况下通过小批次处理和邻居采样实现全图前向传播的方法。\n\n推断算法不同于训练算法，因为需要从第一层开始对节点表示逐层计算。具体来说，对于一个指定的层，\n需要以小批次的方式计算这个GNN层所有节点的输出表示。其结果是，推断算法将包含一个外循环以迭代执行各层，\n和一个内循环以迭代处理各个节点小批次。相比之下，训练算法有一个外循环以迭代处理各个节点小批次，\n和一个内循环以迭代执行各层（包含邻居采样和消息传递）。\n\n下面的动画展示了计算的过程（注意，每层只展示前3个小批次）：\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_6_0.gif\n   :alt: Imgur\n\n\n实现离线推断\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n这里以6.1节中 :ref:`guide_cn-minibatch-node-classification-model`\n提到的两层GCN为例。实现离线推断的方法依然需要使用 ``MultiLayerFullNeighborSampler``，\n但它每次只为一层进行采样。注意，这里的离线推断被实现为GNN模块的一个方法，\n这是因为它对一层的计算依赖于消息的聚合和结合。\n\n.. code:: python\n\n    class StochasticTwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.hidden_features = hidden_features\n            self.out_features = out_features\n            self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)\n            self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)\n            self.n_layers = 2\n    \n        def forward(self, blocks, x):\n            x_dst = x[:blocks[0].number_of_dst_nodes()]\n            x = F.relu(self.conv1(blocks[0], (x, x_dst)))\n            x_dst = x[:blocks[1].number_of_dst_nodes()]\n            x = F.relu(self.conv2(blocks[1], (x, x_dst)))\n            return x\n    \n        def inference(self, g, x, batch_size, device):\n            \"\"\"        用该模块进行离线推断        \"\"\"\n            # 逐层计算表示\n            for l, layer in enumerate([self.conv1, self.conv2]):\n                y = torch.zeros(g.num_nodes(),\n                                self.hidden_features\n                                if l != self.n_layers - 1\n  
                              else self.out_features)\n                sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n                dataloader = dgl.dataloading.NodeDataLoader(\n                    g, torch.arange(g.num_nodes()), sampler,\n                    batch_size=batch_size,\n                    shuffle=True,\n                    drop_last=False)\n\n                # 在一层中，依批次对节点进行迭代\n                for input_nodes, output_nodes, blocks in dataloader:\n                    block = blocks[0]\n\n                    # 将必要输入节点的特征复制到GPU上\n                    h = x[input_nodes].to(device)\n\n                    # 计算输出，注意计算方法是一样的，但只对一层进行计算\n                    h_dst = h[:block.number_of_dst_nodes()]\n                    h = F.relu(layer(block, (h, h_dst)))\n\n                    # 将输出复制回CPU\n                    y[output_nodes] = h.cpu()\n\n                x = y\n    \n            return y\n\n注意，如果以模型选择为目的在验证集上计算评价指标，则通常不需要进行计算精确的离线推断。\n原因是这需要为每一层上的每个节点计算表示，会非常消耗资源，尤其是在包含大量未标记数据的半监督系统中。\n邻居采样在这个时候可以更好地发挥作用。\n\n对于离线推断的示例，用户可以参照\n`GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling.py>`__\n和\n`RGCN <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify_mb.py>`__。\n"
  },
  {
    "path": "docs/source/guide_cn/minibatch-link.rst",
    "content": ".. _guide_cn-minibatch-link-classification-sampler:\n\n6.3 针对链接预测任务的邻居采样训练方法\n--------------------------------------------------------------------\n\n:ref:`(English Version) <guide-minibatch-link-classification-sampler>`\n\n结合负采样来定义邻居采样器和数据加载器\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n用户仍然可以使用与节点/边分类中相同的邻居采样器。\n\n.. code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n\nDGL中的\n:class:`~dgl.dataloading.pytorch.EdgeDataLoader`\n还支持生成用于链接预测的负样本。为此，用户需要定义负采样函数。例如，\n:class:`~dgl.dataloading.negative_sampler.Uniform`\n函数是基于均匀分布的采样函数，它对于每个边的源节点，采样 ``k`` 个负样本的目标节点。\n\n以下数据加载器将为每个边的源节点均匀采样5个负样本的目标节点。\n\n.. code:: python\n\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_seeds, sampler,\n        negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        pin_memory=True,\n        num_workers=args.num_workers)\n\n关于内置的负采样方法，用户可以参考 :ref:`api-dataloading-negative-sampling`。\n\n用户还可以自定义负采样函数，它应当以原图 ``g`` 和小批量的边ID数组 ``eid`` 作为入参，\n并返回源节点ID数组和目标节点ID数组。\n\n下面给出了一个自定义的负采样方法的示例，该采样方法根据与节点的度的幂成正比的概率分布对负样本目标节点进行采样。\n\n.. 
code:: python\n\n    class NegativeSampler(object):\n        def __init__(self, g, k):\n            # 缓存概率分布\n            self.weights = g.in_degrees().float() ** 0.75\n            self.k = k\n    \n        def __call__(self, g, eids):\n            src, _ = g.find_edges(eids)\n            src = src.repeat_interleave(self.k)\n            dst = self.weights.multinomial(len(src), replacement=True)\n            return src, dst\n    \n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_seeds, sampler,\n        negative_sampler=NegativeSampler(g, 5),\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        pin_memory=True,\n        num_workers=args.num_workers)\n\n调整模型以进行小批次训练\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n如 :ref:`guide_cn-training-link-prediction` 中所介绍的，\n用户可以通过比较边(正样本)与不存在的边(负样本)的得分来训练链路模型。用户可以重用在边分类/回归中的节点表示模型，\n来计算边的分数。\n\n.. code:: python\n\n    class StochasticTwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)\n            self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)\n    \n        def forward(self, blocks, x):\n            x = F.relu(self.conv1(blocks[0], x))\n            x = F.relu(self.conv2(blocks[1], x))\n            return x\n\n对于得分的预测，只需要预测每个边的标量分数而不是类别的概率分布，\n因此本示例说明了如何使用边的两个端点的向量的点积来计算分数。\n\n.. code:: python\n\n    class ScorePredictor(nn.Module):\n        def forward(self, edge_subgraph, x):\n            with edge_subgraph.local_scope():\n                edge_subgraph.ndata['x'] = x\n                edge_subgraph.apply_edges(dgl.function.u_dot_v('x', 'x', 'score'))\n                return edge_subgraph.edata['score']\n\n使用负采样方法后，DGL的数据加载器将为每个小批次生成三项：\n\n-  一个正样本图，其中包含采样得到的小批次内所有的边。\n-  一个负样本图，其中包含由负采样方法生成的所有不存在的边。\n-  邻居采样方法生成的块的列表。\n\n因此，可以如下定义链接预测模型，该模型的输入包括上述三项以及输入的特征。\n\n.. 
code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.gcn = StochasticTwoLayerGCN(\n                in_features, hidden_features, out_features)\n            self.predictor = ScorePredictor()\n    \n        def forward(self, positive_graph, negative_graph, blocks, x):\n            x = self.gcn(blocks, x)\n            pos_score = self.predictor(positive_graph, x)\n            neg_score = self.predictor(negative_graph, x)\n            return pos_score, neg_score\n\n模型的训练\n~~~~~~~~~~~~~\n\n训练循环通过数据加载器去遍历数据，将得到的图和输入特征传入上述模型。\n\n.. code:: python\n\n    model = Model(in_features, hidden_features, out_features)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, positive_graph, negative_graph, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        positive_graph = positive_graph.to(torch.device('cuda'))\n        negative_graph = negative_graph.to(torch.device('cuda'))\n        input_features = blocks[0].srcdata['features']\n        pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)\n        loss = compute_loss(pos_score, neg_score)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\nDGL提供了在同构图上做链路预测的一个示例：\n`无监督学习GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling_unsupervised.py>`__。\n\n异构图上的随机批次训练\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n计算异构图上的节点表示的模型也可以用于计算边分类/回归中的边两端节点的表示。\n\n.. 
code:: python\n\n    class StochasticTwoLayerRGCN(nn.Module):\n        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):\n            super().__init__()\n            self.conv1 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')\n                    for rel in rel_names\n                })\n            self.conv2 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')\n                    for rel in rel_names\n                })\n    \n        def forward(self, blocks, x):\n            x = self.conv1(blocks[0], x)\n            x = self.conv2(blocks[1], x)\n            return x\n\n对于得分的预测，同构图和异构图之间唯一的实现差异是后者需要用\n:meth:`dgl.DGLGraph.apply_edges`\n来遍历所有的边类型。\n\n.. code:: python\n\n    class ScorePredictor(nn.Module):\n        def forward(self, edge_subgraph, x):\n            with edge_subgraph.local_scope():\n                edge_subgraph.ndata['x'] = x\n                for etype in edge_subgraph.canonical_etypes:\n                    edge_subgraph.apply_edges(\n                        dgl.function.u_dot_v('x', 'x', 'score'), etype=etype)\n                return edge_subgraph.edata['score']\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, num_classes,\n                     etypes):\n            super().__init__()\n            self.rgcn = StochasticTwoLayerRGCN(\n                in_features, hidden_features, out_features, etypes)\n            self.pred = ScorePredictor()\n\n        def forward(self, positive_graph, negative_graph, blocks, x):\n            x = self.rgcn(blocks, x)\n            pos_score = self.pred(positive_graph, x)\n            neg_score = self.pred(negative_graph, x)\n            return pos_score, neg_score\n\n数据加载器的定义也与边分类/回归里的定义非常相似。唯一的区别是用户需要提供负采样方法，\n并且提供边类型和边ID张量的字典，而不是节点类型和节点ID张量的字典。\n\n.. 
code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n        negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n如果用户想自定义负采样函数，那么该函数应以初始图以及由边类型和边ID张量构成的字典作为输入。\n它返回以边类型为键、源节点-目标节点数组对为值的字典。示例如下所示：\n\n.. code:: python\n\n   class NegativeSampler(object):\n       def __init__(self, g, k):\n           # 缓存概率分布\n           self.weights = {\n               etype: g.in_degrees(etype=etype).float() ** 0.75\n               for _, etype, _ in g.canonical_etypes\n           }\n           self.k = k\n\n       def __call__(self, g, eids_dict):\n           result_dict = {}\n           for etype, eids in eids_dict.items():\n               src, _ = g.find_edges(eids, etype=etype)\n               src = src.repeat_interleave(self.k)\n               dst = self.weights[etype].multinomial(len(src), replacement=True)\n               result_dict[etype] = (src, dst)\n           return result_dict\n\n随后，需要向数据载入器提供边类型和对应边ID的字典，以及负采样器。示例如下所示：\n\n.. code:: python\n\n    train_eid_dict = {\n        etype: g.edges(etype=etype, form='eid')\n        for etype in g.etypes}\n\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n        negative_sampler=NegativeSampler(g, 5),\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n异构图上的随机批次模型训练与同构图中的训练几乎相同，不同之处在于，\n``compute_loss`` 是以边类型字典和预测结果字典作为输入。\n\n.. 
code:: python\n\n    model = Model(in_features, hidden_features, out_features, num_classes, etypes)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, positive_graph, negative_graph, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        positive_graph = positive_graph.to(torch.device('cuda'))\n        negative_graph = negative_graph.to(torch.device('cuda'))\n        input_features = blocks[0].srcdata['features']\n        pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)\n        loss = compute_loss(pos_score, neg_score)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n\n\n"
  },
  {
    "path": "docs/source/guide_cn/minibatch-nn.rst",
    "content": ".. _guide_cn-minibatch-custom-gnn-module:\n\n6.5 为小批次训练实现定制化的GNN模块\n-------------------------------------------------------------\n\n:ref:`(English Version) <guide-minibatch-custom-gnn-module>`\n\n如果用户熟悉如何定制用于更新整个同构图或异构图的GNN模块(参见\n:ref:`guide_cn-nn`)，那么在块上计算的代码也是类似的，区别只在于节点被划分为输入节点和输出节点。\n\n以下面的自定义图卷积模块代码为例。注意，该代码并不一定是最高效的实现，\n此处只是将其作为自定义GNN模块的一个示例。\n\n.. code:: python\n\n    class CustomGraphConv(nn.Module):\n        def __init__(self, in_feats, out_feats):\n            super().__init__()\n            self.W = nn.Linear(in_feats * 2, out_feats)\n    \n        def forward(self, g, h):\n            with g.local_scope():\n                g.ndata['h'] = h\n                g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))\n                return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))\n\n如果用户已有一个用于整个图的自定义消息传递模块，并且想将其用于块，则只需要按照如下的方法重写forward函数。\n注意，以下代码在注释里保留了整图实现的语句，用户可以将用于块的语句和原先用于整图的语句进行比较。\n\n.. code:: python\n\n    class CustomGraphConv(nn.Module):\n        def __init__(self, in_feats, out_feats):\n            super().__init__()\n            self.W = nn.Linear(in_feats * 2, out_feats)\n\n        # h现在是输入和输出节点的特征张量对，而不是一个单独的特征张量\n\n        # def forward(self, g, h):\n        def forward(self, block, h):\n            # with g.local_scope():\n            with block.local_scope():\n                # g.ndata['h'] = h\n                h_src = h\n                h_dst = h[:block.number_of_dst_nodes()]\n                block.srcdata['h'] = h_src\n                block.dstdata['h'] = h_dst\n    \n                # g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))\n                block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))\n    \n                # return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))\n                return self.W(torch.cat(\n                    [block.dstdata['h'], block.dstdata['h_neigh']], 1))\n\n通常，需要对用于整图的GNN模块进行如下调整以将其用于块作为输入的情况：\n\n-  切片取输入特征的前几行，得到输出节点的特征。切片行数可以通过\n   
:meth:`block.number_of_dst_nodes <dgl.DGLGraph.number_of_dst_nodes>` 获得。\n-  如果原图只包含一种节点类型，对输入节点特征，将 :attr:`g.ndata <dgl.DGLGraph.ndata>` 替换为\n   :attr:`block.srcdata <dgl.DGLGraph.srcdata>`；对于输出节点特征，将\n   :attr:`g.ndata <dgl.DGLGraph.ndata>`  替换为\n   :attr:`block.dstdata <dgl.DGLGraph.dstdata>`。\n-  如果原图包含多种节点类型，对于输入节点特征，将\n   :attr:`g.nodes <dgl.DGLGraph.nodes>` 替换为\n   :attr:`block.srcnodes <dgl.DGLGraph.srcnodes>`；对于输出节点特征，将\n   :attr:`g.nodes <dgl.DGLGraph.nodes>` 替换为\n   :attr:`block.dstnodes <dgl.DGLGraph.dstnodes>`。\n-  对于输入节点数量，将 :meth:`g.num_nodes <dgl.DGLGraph.num_nodes>` 替换为\n   :meth:`block.number_of_src_nodes <dgl.DGLGraph.number_of_src_nodes>` ；\n   对于输出节点数量，将 :meth:`g.num_nodes <dgl.DGLGraph.num_nodes>` 替换为\n   :meth:`block.number_of_dst_nodes <dgl.DGLGraph.number_of_dst_nodes>` 。\n\n异构图上的模型定制\n~~~~~~~~~~~~~~~~~~~~\n\n为异构图修改GNN模块的方法是类似的。例如，以下面用于全图的GNN模块为例：\n\n.. code:: python\n\n    class CustomHeteroGraphConv(nn.Module):\n        def __init__(self, g, in_feats, out_feats):\n            super().__init__()\n            self.Ws = nn.ModuleDict()\n            for etype in g.canonical_etypes:\n                utype, _, vtype = etype\n                self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])\n            for ntype in g.ntypes:\n                self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])\n    \n        def forward(self, g, h):\n            with g.local_scope():\n                for ntype in g.ntypes:\n                    g.nodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])\n                    g.nodes[ntype].data['h_src'] = h[ntype]\n                for etype in g.canonical_etypes:\n                    utype, _, vtype = etype\n                    g.update_all(\n                        fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),\n                        etype=etype)\n                    g.nodes[vtype].data['h_dst'] = g.nodes[vtype].data['h_dst'] + \\\n                        
self.Ws[etype](g.nodes[vtype].data['h_neigh'])\n                return {ntype: g.nodes[ntype].data['h_dst'] for ntype in g.ntypes}\n\n对于 ``CustomHeteroGraphConv``，原则是将 ``g.nodes`` 替换为 ``g.srcnodes`` 或\n``g.dstnodes`` (根据需要输入还是输出节点的特征来选择)。\n\n.. code:: python\n\n    class CustomHeteroGraphConv(nn.Module):\n        def __init__(self, g, in_feats, out_feats):\n            super().__init__()\n            self.Ws = nn.ModuleDict()\n            for etype in g.canonical_etypes:\n                utype, _, vtype = etype\n                self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])\n            for ntype in g.ntypes:\n                self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])\n    \n        def forward(self, g, h):\n            with g.local_scope():\n                for ntype in g.ntypes:\n                    h_src, h_dst = h[ntype]\n                    g.dstnodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])\n                    g.srcnodes[ntype].data['h_src'] = h[ntype]\n                for etype in g.canonical_etypes:\n                    utype, _, vtype = etype\n                    g.update_all(\n                        fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),\n                        etype=etype)\n                    g.dstnodes[vtype].data['h_dst'] = \\\n                        g.dstnodes[vtype].data['h_dst'] + \\\n                        self.Ws[etype](g.dstnodes[vtype].data['h_neigh'])\n                return {ntype: g.dstnodes[ntype].data['h_dst']\n                        for ntype in g.ntypes}\n\n实现能够处理同构图、二分图和块的模块\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nDGL中所有的消息传递模块(参见 :ref:`apinn`)都能够处理同构图、\n单向二分图(包含两种节点类型和一种边类型)和包含一种边类型的块。\n本质上，内置的DGL神经网络模块的输入图及特征必须满足下列情况之一：\n\n-  如果输入特征是一个张量对，则输入图必须是一个单向二分图\n-  如果输入特征是一个单独的张量且输入图是一个块，则DGL会自动将输入节点特征前一部分设为输出节点的特征。\n-  如果输入特征是一个单独的张量且输入图不是块，则输入图必须是同构图。\n\n例如，下面的代码是 :class:`dgl.nn.pytorch.SAGEConv` 
的简化版(DGL同样支持它在MXNet和TensorFlow后端里的实现)。\n代码里移除了归一化，且只考虑平均聚合函数的情况。\n\n.. code:: python\n\n    import dgl.function as fn\n    class SAGEConv(nn.Module):\n        def __init__(self, in_feats, out_feats):\n            super().__init__()\n            self.W = nn.Linear(in_feats * 2, out_feats)\n    \n        def forward(self, g, h):\n            if isinstance(h, tuple):\n                h_src, h_dst = h\n            elif g.is_block:\n                h_src = h\n                h_dst = h[:g.number_of_dst_nodes()]\n            else:\n                h_src = h_dst = h\n                 \n            g.srcdata['h'] = h_src\n            g.dstdata['h'] = h_dst\n            g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_neigh'))\n            return F.relu(\n                self.W(torch.cat([g.dstdata['h'], g.dstdata['h_neigh']], 1)))\n\n:ref:`guide_cn-nn` 提供了对 :class:`dgl.nn.pytorch.SAGEConv` 代码的详细解读，\n其适用于单向二分图、同构图和块。\n"
  },
  {
    "path": "docs/source/guide_cn/minibatch-node.rst",
    "content": ".. _guide_cn-minibatch-node-classification-sampler:\n\n6.1 针对节点分类任务的邻居采样训练方法\n-----------------------------------------------------------------------\n\n:ref:`(English Version) <guide-minibatch-node-classification-sampler>`\n\n为了随机(批次)训练模型，需要进行以下操作：\n\n- 定义邻居采样器。\n- 调整模型以进行小批次训练。\n- 修改模型训练循环部分。\n\n以下小节将逐一介绍这些步骤。\n\n定义邻居采样器和数据加载器\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nDGL提供了几个邻居采样类，这些类会生成需计算的节点在每一层计算时所需的依赖图。\n\n最简单的邻居采样器是\n:class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler`，它可获取节点的所有邻居。\n\n要使用DGL提供的采样器，还需要将其与\n:class:`~dgl.dataloading.pytorch.NodeDataLoader`\n结合使用，后者可以以小批次的形式对一个节点的集合进行迭代。\n\n例如，以下代码创建了一个PyTorch的 DataLoader，它分批迭代训练节点ID数组 ``train_nids``，\n并将生成的子图列表放到GPU上。\n\n.. code:: python\n\n    import dgl\n    import dgl.nn as dglnn\n    import torch\n    import torch.nn as nn\n    import torch.nn.functional as F\n    \n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n    dataloader = dgl.dataloading.NodeDataLoader(\n        g, train_nids, sampler,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n对DataLoader进行迭代，将会创建一个特定图的列表，这些图表示每层的计算依赖。在DGL中称之为 *块*。\n\n.. code:: python\n\n    input_nodes, output_nodes, blocks = next(iter(dataloader))\n    print(blocks)\n\n上面的dataloader一次迭代会生成三个输出。 ``input_nodes`` 代表计算 ``output_nodes`` 的表示所需的节点。\n``块`` 包含了每个GNN层要计算哪些节点表示作为输出，要将哪些节点表示作为输入，以及来自输入节点的表示如何传播到输出节点。\n\n完整的内置采样方法清单，用户可以参考\n:ref:`neighborhood sampler API reference <api-dataloading-neighbor-sampling>`。\n\n如果用户希望编写自己的邻居采样器，或者想要关于块的更深入的介绍，读者可以参考\n:ref:`guide_cn-minibatch-customizing-neighborhood-sampler`。\n\n.. _guide_cn-minibatch-node-classification-model:\n\n调整模型以进行小批次训练\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n如果用户的消息传递模块全使用的是DGL内置模块，则模型在进行小批次训练时只需做很小的调整。\n以多层GCN为例。如果用户模型在全图上是按以下方式实现的：\n\n.. 
code:: python\n\n    class TwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.conv1 = dglnn.GraphConv(in_features, hidden_features)\n            self.conv2 = dglnn.GraphConv(hidden_features, out_features)\n    \n        def forward(self, g, x):\n            x = F.relu(self.conv1(g, x))\n            x = F.relu(self.conv2(g, x))\n            return x\n\n然后，用户所需要做的就是用上面生成的块( ``block`` )来替换图( ``g`` )。\n\n.. code:: python\n\n    class StochasticTwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)\n            self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)\n    \n        def forward(self, blocks, x):\n            x = F.relu(self.conv1(blocks[0], x))\n            x = F.relu(self.conv2(blocks[1], x))\n            return x\n\n上面的DGL ``GraphConv`` 模块接受的一个参数是数据加载器生成的 ``块`` 中的一个元素。\n\n用户可以查阅 :ref:`NN模块的API参考 <apinn>` 来查看DGL的内置模型模块是否支持接受 ``块`` 作为参数。\n\n如果希望使用自定义的消息传递模块，用户可以参考\n:ref:`guide_cn-minibatch-custom-gnn-module`。\n\n模型的训练\n~~~~~~~~~~~~~\n\n这里的模型的训练循环仅包含使用定制的批处理迭代器遍历数据集的内容。在每个生成块列表的迭代中：\n\n\n1. 将与输入节点相对应的节点特征加载到GPU上。节点特征可以存储在内存或外部存储中。\n   请注意，用户只需要加载输入节点的特征，而不是像整图训练那样加载所有节点的特征。\n\n   如果特征存储在 ``g.ndata`` 中，则可以通过 ``blocks[0].srcdata`` 来加载第一个块的输入节点的特征，\n   这些节点是计算节点最终表示所需的所有必需的节点。\n\n2. 将块列表和输入节点特征传入多层GNN并获取输出。\n\n3. 将与输出节点相对应的节点标签加载到GPU上。同样，节点标签可以存储在内存或外部存储器中。\n   再次提醒下，用户只需要加载输出节点的标签，而不是像整图训练那样加载所有节点的标签。\n\n   如果特征存储在 ``g.ndata`` 中，则可以通过访问 ``blocks[-1].dstdata`` 中的特征来加载标签，\n   它是最后一个块的输出节点的特征，这些节点与用户希望计算最终表示的节点相同。\n\n4. 计算损失并反向传播。\n\n.. 
code:: python\n\n    model = StochasticTwoLayerGCN(in_features, hidden_features, out_features)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, output_nodes, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        input_features = blocks[0].srcdata['features']\n        output_labels = blocks[-1].dstdata['label']\n        output_predictions = model(blocks, input_features)\n        loss = compute_loss(output_labels, output_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\nDGL提供了一个端到端的随机批次训练示例\n`GraphSAGE的实现 <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/node_classification.py>`__。\n\n\n异构图上模型的训练\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n在异构图上训练图神经网络进行节点分类的方法也是类似的。\n\n例如，在\n:ref:`guide_cn-training-rgcn-node-classification`\n中介绍了如何在整图上训练一个2层的RGCN模型。\nRGCN小批次训练的代码与它非常相似(为简单起见，这里删除了自环、非线性和基分解)：\n\n.. code:: python\n\n    class StochasticTwoLayerRGCN(nn.Module):\n        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):\n            super().__init__()\n            self.conv1 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')\n                    for rel in rel_names\n                })\n            self.conv2 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')\n                    for rel in rel_names\n                })\n    \n        def forward(self, blocks, x):\n            x = self.conv1(blocks[0], x)\n            x = self.conv2(blocks[1], x)\n            return x\n\nDGL提供的一些采样方法也支持异构图。例如，用户仍然可以使用\n:class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler` 类和\n:class:`~dgl.dataloading.pytorch.NodeDataLoader` 类进行随机批次训练。\n对于全邻居采样，唯一的区别是用户需要为训练集指定节点类型和节点ID的字典。\n\n.. 
code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n    dataloader = dgl.dataloading.NodeDataLoader(\n        g, train_nid_dict, sampler,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n模型的训练与同构图几乎相同。不同之处在于， ``compute_loss`` 的实现会包含两个字典：节点类型和预测结果。\n\n.. code:: python\n\n    model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features, etypes)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, output_nodes, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        input_features = blocks[0].srcdata     # returns a dict\n        output_labels = blocks[-1].dstdata     # returns a dict\n        output_predictions = model(blocks, input_features)\n        loss = compute_loss(output_labels, output_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\nDGL提供了端到端随机批次训练的\n`RGCN的实现 <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify_mb.py>`__。\n"
  },
  {
    "path": "docs/source/guide_cn/minibatch.rst",
    "content": ".. _guide_cn-minibatch:\n\n第6章：在大图上的随机（批次）训练\n=======================================================\n\n:ref:`(English Version) <guide-minibatch>`\n\n如果用户有包含数百万甚至数十亿个节点或边的大图，通常无法进行\n:ref:`guide_cn-training`\n中所述的全图训练。考虑在一个有 :math:`N` 个节点的图上运行的、隐层大小为 :math:`H` 的 :math:`L` 层图卷积网络，\n存储隐层表示需要 :math:`O(NLH)` 的内存空间，当 :math:`N` 较大时，这很容易超过一块GPU的显存限制。\n\n本章介绍了一种在大图上进行随机小批次训练的方法，可以让用户不用一次性把所有节点特征拷贝到GPU上。\n\n邻居采样方法概述\n--------------------------------------------\n\n邻居节点采样的工作流程通常如下：每次梯度下降，选择一个小批次的图节点，\n其最终表示将在神经网络的第 :math:`L` 层进行计算，然后在网络的第 :math:`L-1` 层选择该批次节点的全部或部分邻居节点。\n重复这个过程，直到到达输入层。这个迭代过程会构建计算的依赖关系图，从输出开始，一直到输入，如下图所示：\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_0_0.png\n   :alt: Imgur\n\n该方法能节省在大图上训练图神经网络的开销和计算资源。\n\nDGL实现了一些邻居节点采样的方法和使用邻居节点采样训练图神经网络的管道，同时也支持让用户自定义采样策略。\n\n本章路线图\n-----------\n\n本章的前半部分介绍了不同场景下如何进行随机训练的方法。\n\n* :ref:`guide_cn-minibatch-node-classification-sampler`\n* :ref:`guide_cn-minibatch-edge-classification-sampler`\n* :ref:`guide_cn-minibatch-link-classification-sampler`\n\n本章余下的小节介绍了更多的高级主题，面向那些想要开发新的采样算法、\n想要实现与小批次训练兼容的图神经网络模块、以及想要了解如何在小批次数据上进行评估和推理模型的用户。\n\n* :ref:`guide_cn-minibatch-customizing-neighborhood-sampler`\n* :ref:`guide_cn-minibatch-custom-gnn-module`\n* :ref:`guide_cn-minibatch-inference`\n\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    minibatch-node\n    minibatch-edge\n    minibatch-link\n    minibatch-custom-sampler\n    minibatch-nn\n    minibatch-inference\n"
  },
  {
    "path": "docs/source/guide_cn/nn-construction.rst",
    "content": ".. _guide_cn-nn-construction:\n\n3.1 DGL NN模块的构造函数\n-----------------------------\n\n:ref:`(English Version) <guide-nn-construction>`\n\n构造函数完成以下几个任务：\n\n1. 设置选项。\n2. 注册可学习的参数或者子模块。\n3. 初始化参数。\n\n.. code::\n\n    import torch.nn as nn\n\n    from dgl.utils import expand_as_pair\n\n    class SAGEConv(nn.Module):\n        def __init__(self,\n                     in_feats,\n                     out_feats,\n                     aggregator_type,\n                     bias=True,\n                     norm=None,\n                     activation=None):\n            super(SAGEConv, self).__init__()\n\n            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n            self._out_feats = out_feats\n            self._aggre_type = aggregator_type\n            self.norm = norm\n            self.activation = activation\n\n在构造函数中，用户首先需要设置数据的维度。对于一般的PyTorch模块，维度通常包括输入的维度、输出的维度和隐层的维度。\n对于图神经网络，输入维度可被分为源节点特征维度和目标节点特征维度。\n\n除了数据维度，图神经网络的一个典型选项是聚合类型(``self._aggre_type``)。对于特定目标节点，聚合类型决定了如何聚合不同边上的信息。\n常用的聚合类型包括 ``mean``、 ``sum``、 ``max`` 和 ``min``。一些模块可能会使用更加复杂的聚合函数，比如 ``lstm``。\n\n上面代码里的 ``norm`` 是用于特征归一化的可调用函数。在SAGEConv论文里，归一化可以是L2归一化:\n:math:`h_v = h_v / \\lVert h_v \\rVert_2`。\n\n.. 
code::\n\n            # 聚合类型：mean、pool、lstm、gcn\n            if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:\n                raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))\n            if aggregator_type == 'pool':\n                self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)\n            if aggregator_type == 'lstm':\n                self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)\n            if aggregator_type in ['mean', 'pool', 'lstm']:\n                self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)\n            self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)\n            self.reset_parameters()\n\n注册参数和子模块。在SAGEConv中，子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块，例如 ``nn.Linear``、 ``nn.LSTM`` 等。\n构造函数的最后调用了 ``reset_parameters()`` 进行权重初始化。\n\n.. code::\n\n        def reset_parameters(self):\n            \"\"\"重新初始化可学习的参数\"\"\"\n            gain = nn.init.calculate_gain('relu')\n            if self._aggre_type == 'pool':\n                nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)\n            if self._aggre_type == 'lstm':\n                self.lstm.reset_parameters()\n            if self._aggre_type != 'gcn':\n                nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)\n            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)\n"
  },
  {
    "path": "docs/source/guide_cn/nn-forward.rst",
    "content": ".. _guide_cn-nn-forward:\n\n3.2 编写DGL NN模块的forward函数\n---------------------------------\n\n:ref:`(English Version) <guide-nn-forward>`\n\n在NN模块中， ``forward()`` 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比，\nDGL NN模块额外增加了1个参数 :class:`dgl.DGLGraph`。``forward()`` 函数的内容一般可以分为3项操作：\n\n-  检测输入图对象是否符合规范。\n\n-  消息传递和聚合。\n\n-  聚合后，更新特征作为输出。\n\n下文展示了SAGEConv示例中的 ``forward()`` 函数。\n\n输入图对象的规范检测\n~~~~~~~~~~~~~~~~~~~~~\n\n.. code::\n\n        def forward(self, graph, feat):\n            with graph.local_scope():\n                # 指定图类型，然后根据图类型扩展输入特征\n                feat_src, feat_dst = expand_as_pair(feat, graph)\n\n``forward()`` 函数需要处理输入的许多极端情况，这些情况可能导致计算和消息传递中的值无效。\n比如在 :class:`~dgl.nn.pytorch.conv.GraphConv` 等conv模块中，DGL会检查输入图中是否有入度为0的节点。\n当1个节点入度为0时， ``mailbox`` 将为空，并且聚合函数的输出值全为0，\n这可能会导致模型性能不佳。但是，在 :class:`~dgl.nn.pytorch.conv.SAGEConv` 模块中，被聚合的特征将会与节点的初始特征拼接起来，\n``forward()`` 函数的输出不会全为0。在这种情况下，无需进行此类检验。\n\nDGL NN模块可在不同类型的图输入中重复使用，包括：同构图、异构图（:ref:`guide_cn-graph-heterogeneous`）和子图块（:ref:`guide_cn-minibatch`）。\n\nSAGEConv的数学公式如下：\n\n.. math::\n\n\n   h_{\\mathcal{N}(dst)}^{(l+1)}  = \\mathrm{aggregate}\n           \\left(\\{h_{src}^{l}, \\forall src \\in \\mathcal{N}(dst) \\}\\right)\n\n.. math::\n\n    h_{dst}^{(l+1)} = \\sigma \\left(W \\cdot \\mathrm{concat}\n           (h_{dst}^{l}, h_{\\mathcal{N}(dst)}^{l+1}) + b \\right)\n\n.. math::\n\n    h_{dst}^{(l+1)} = \\mathrm{norm}(h_{dst}^{l+1})\n\n源节点特征 ``feat_src`` 和目标节点特征 ``feat_dst`` 需要根据图类型被指定。\n用于指定图类型并将 ``feat`` 扩展为 ``feat_src`` 和 ``feat_dst`` 的函数是 :meth:`~dgl.utils.expand_as_pair`。\n该函数的细节如下所示。\n\n.. 
code::\n\n    def expand_as_pair(input_, g=None):\n        if isinstance(input_, tuple):\n            # 二分图的情况\n            return input_\n        elif g is not None and g.is_block:\n            # 子图块的情况\n            if isinstance(input_, Mapping):\n                input_dst = {\n                    k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))\n                    for k, v in input_.items()}\n            else:\n                input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())\n            return input_, input_dst\n        else:\n            # 同构图的情况\n            return input_, input_\n\n对于同构图上的全图训练，源节点和目标节点相同，它们都是图中的所有节点。\n\n在异构图的情况下，图可以分为几个二分图，每种关系对应一个。关系表示为 ``(src_type, edge_type, dst_type)``。\n当输入特征 ``feat`` 是1个元组时，图将会被视为二分图。元组中的第1个元素为源节点特征，第2个元素为目标节点特征。\n\n在小批次训练中，计算应用于给定的一堆目标节点所采样的子图。子图在DGL中称为区块(``block``)。\n在区块创建的阶段，``dst nodes`` 位于节点列表的最前面。通过索引 ``[0:g.number_of_dst_nodes()]`` 可以找到 ``feat_dst``。\n\n确定 ``feat_src`` 和 ``feat_dst`` 之后，以上3种图类型的计算方法是相同的。\n\n消息传递和聚合\n~~~~~~~~~~~~~~~~~\n\n.. 
code::\n\n                import dgl.function as fn\n                import torch.nn.functional as F\n                from dgl.utils import check_eq_shape\n\n                if self._aggre_type == 'mean':\n                    graph.srcdata['h'] = feat_src\n                    graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))\n                    h_neigh = graph.dstdata['neigh']\n                elif self._aggre_type == 'gcn':\n                    check_eq_shape(feat)\n                    graph.srcdata['h'] = feat_src\n                    graph.dstdata['h'] = feat_dst\n                    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))\n                    # 除以入度\n                    degs = graph.in_degrees().to(feat_dst)\n                    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)\n                elif self._aggre_type == 'pool':\n                    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))\n                    graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))\n                    h_neigh = graph.dstdata['neigh']\n                else:\n                    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))\n\n                # GraphSAGE中gcn聚合不需要fc_self\n                if self._aggre_type == 'gcn':\n                    rst = self.fc_neigh(h_neigh)\n                else:\n                    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)\n\n上面的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意，代码中的所有消息传递均使用  :meth:`~dgl.DGLGraph.update_all` API和\nDGL内置的消息/聚合函数来实现，以充分利用 :ref:`guide_cn-message-passing-efficient` 里所介绍的性能优化。\n\n聚合后，更新特征作为输出\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. 
code::\n\n                # 激活函数\n                if self.activation is not None:\n                    rst = self.activation(rst)\n                # 归一化\n                if self.norm is not None:\n                    rst = self.norm(rst)\n                return rst\n\n``forward()`` 函数的最后一部分是在完成消息聚合后更新节点的特征。\n常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。\n"
  },
  {
    "path": "docs/source/guide_cn/nn-heterograph.rst",
    "content": ".. _guide_cn-nn-heterograph:\n\n3.3 异构图上的GraphConv模块\n--------------------------------\n\n:ref:`(English Version) <guide-nn-heterograph>`\n\nDGL提供了 :class:`~dgl.nn.pytorch.HeteroGraphConv`，用于定义异构图上GNN模块。\n实现逻辑与消息传递级别的API :meth:`~dgl.DGLGraph.multi_update_all` 相同，它包括：\n\n-  每个关系上的DGL NN模块。\n-  聚合来自不同关系上的结果。\n\n其数学定义为：\n\n.. math::  h_{dst}^{(l+1)} = \\underset{r\\in\\mathcal{R}, r_{dst}=dst}{AGG} (f_r(g_r, h_{r_{src}}^l, h_{r_{dst}}^l))\n\n其中 :math:`f_r` 是对应每个关系 :math:`r` 的NN模块，:math:`AGG` 是聚合函数。\n\nHeteroGraphConv的实现逻辑\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code::\n\n    import torch.nn as nn\n\n    class HeteroGraphConv(nn.Module):\n        def __init__(self, mods, aggregate='sum'):\n            super(HeteroGraphConv, self).__init__()\n            self.mods = nn.ModuleDict(mods)\n            if isinstance(aggregate, str):\n                # 获取聚合函数的内部函数\n                self.agg_fn = get_aggregate_fn(aggregate)\n            else:\n                self.agg_fn = aggregate\n\n异构图的卷积操作接受一个字典类型参数 ``mods``。这个字典的键为关系名，值为作用在该关系上NN模块对象。参数 ``aggregate``\n则指定了如何聚合来自不同关系的结果。\n\n.. code::\n\n    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):\n        if mod_args is None:\n            mod_args = {}\n        if mod_kwargs is None:\n            mod_kwargs = {}\n        outputs = {nty : [] for nty in g.dsttypes}\n\n除了输入图和输入张量，``forward()`` 函数还使用2个额外的字典参数 ``mod_args`` 和 ``mod_kwargs``。\n这2个字典与 ``self.mods`` 具有相同的键，值则为对应NN模块的自定义参数。\n\n``forward()`` 函数的输出结果也是一个字典类型的对象。其键为 ``nty``，其值为每个目标节点类型 ``nty`` 的输出张量的列表，\n表示来自不同关系的计算结果。``HeteroGraphConv`` 会对这个列表进一步聚合，并将结果返回给用户。\n\n.. 
code::\n\n          if g.is_block:\n              src_inputs = inputs\n              dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}\n          else:\n              src_inputs = dst_inputs = inputs\n\n          for stype, etype, dtype in g.canonical_etypes:\n              rel_graph = g[stype, etype, dtype]\n              if rel_graph.num_edges() == 0:\n                  continue\n              if stype not in src_inputs or dtype not in dst_inputs:\n                  continue\n              dstdata = self.mods[etype](\n                  rel_graph,\n                  (src_inputs[stype], dst_inputs[dtype]),\n                  *mod_args.get(etype, ()),\n                  **mod_kwargs.get(etype, {}))\n              outputs[dtype].append(dstdata)\n\n输入 ``g`` 可以是异构图或来自异构图的子图区块。和普通的NN模块一样，``forward()`` 函数需要分别处理不同的输入图类型。\n\n上述代码中的for循环为处理异构图计算的主要逻辑。首先我们遍历图中所有的关系(通过调用 ``canonical_etypes``)。\n通过关系名，我们可以使用g[ ``stype, etype, dtype`` ]的语法将只包含该关系的子图( ``rel_graph`` )抽取出来。\n对于二分图，输入特征将被组织为元组 ``(src_inputs[stype], dst_inputs[dtype])``。\n接着调用用户预先注册在该关系上的NN模块，并将结果保存在outputs字典中。\n\n.. code::\n\n        rsts = {}\n        for nty, alist in outputs.items():\n            if len(alist) != 0:\n                rsts[nty] = self.agg_fn(alist, nty)\n\n最后，``HeteroGraphConv`` 会调用用户注册的 ``self.agg_fn`` 函数聚合来自多个关系的结果。\n读者可以在API文档中找到 :class:`~dgl.nn.pytorch.HeteroGraphConv` 的示例。"
  },
  {
    "path": "docs/source/guide_cn/nn.rst",
    "content": ".. _guide_cn-nn:\n\n第3章：构建图神经网络（GNN）模块\n===================================\n\n:ref:`(English Version) <guide-nn>`\n\nDGL NN模块是用户构建GNN模型的基本模块。根据DGL所使用的后端深度神经网络框架，\nDGL NN模块的父类取决于后端所使用的深度神经网络框架。对于PyTorch后端，\n它应该继承 `PyTorch的NN模块 <https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/module.html>`__；对于MXNet后端，它应该继承\n`MXNet Gluon的NN块 <http://mxnet.incubator.apache.org/versions/1.6/api/python/docs/api/gluon/nn/index.html>`__；\n对于TensorFlow后端，它应该继承 `Tensorflow的Keras层 <https://www.tensorflow.org/api_docs/python/tf/keras/layers>`__。\n在DGL NN模块中，构造函数中的参数注册和前向传播函数中使用的张量操作与后端框架一样。这种方式使得DGL的代码可以无缝嵌入到后端框架的代码中。\nDGL和这些深度神经网络框架的主要差异是其独有的消息传递操作。\n\nDGL已经集成了很多常用的 :ref:`apinn-pytorch-conv`、 :ref:`apinn-pytorch-dense-conv`、\n:ref:`apinn-pytorch-pooling` 和 :ref:`apinn-pytorch-util`。欢迎给DGL贡献更多的模块！\n\n本章将使用PyTorch作为后端，用 :class:`~dgl.nn.pytorch.conv.SAGEConv` 作为例子来介绍如何构建用户自己的DGL NN模块。\n\n本章路线图\n------------\n\n* :ref:`guide_cn-nn-construction`\n* :ref:`guide_cn-nn-forward`\n* :ref:`guide_cn-nn-heterograph`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    nn-construction\n    nn-forward\n    nn-heterograph\n"
  },
  {
    "path": "docs/source/guide_cn/training-edge.rst",
    "content": ".. _guide_cn-training-edge-classification:\n\n5.2 边分类/回归\n---------------------------------------------\n\n:ref:`(English Version) <guide-training-edge-classification>`\n\n有时用户希望预测图中边的属性值，这种情况下，用户需要构建一个边分类/回归的模型。\n\n以下代码生成了一个随机图用于演示边分类/回归。\n\n.. code:: python\n\n    src = np.random.randint(0, 100, 500)\n    dst = np.random.randint(0, 100, 500)\n    # 同时建立反向边\n    edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])))\n    # 建立点和边特征，以及边的标签\n    edge_pred_graph.ndata['feature'] = torch.randn(100, 10)\n    edge_pred_graph.edata['feature'] = torch.randn(1000, 10)\n    edge_pred_graph.edata['label'] = torch.randn(1000)\n    # 进行训练、验证和测试集划分\n    edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)\n\n概述\n~~~~~~~~\n\n上一节介绍了如何使用多层GNN进行节点分类。同样的方法也可以被用于计算任何节点的隐藏表示。\n并从边的两个端点的表示，通过计算得出对边属性的预测。\n\n对一条边计算预测值最常见的情况是将预测表示为一个函数，函数的输入为两个端点的表示，\n输入还可以包括边自身的特征。\n\n与节点分类在模型实现上的差别\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n如果用户使用上一节中的模型计算了节点的表示，那么用户只需要再编写一个用\n:meth:`~dgl.DGLGraph.apply_edges` 方法计算边预测的组件即可进行边分类/回归任务。\n\n例如，对于边回归任务，如果用户想为每条边计算一个分数，可按下面的代码对每一条边计算它的两端节点隐藏表示的点积来作为分数。\n\n.. code:: python\n\n    import dgl.function as fn\n    class DotProductPredictor(nn.Module):\n        def forward(self, graph, h):\n            # h是从5.1节的GNN模型中计算出的节点表示\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))\n                return graph.edata['score']\n\n用户也可以使用MLP(多层感知机)对每条边生成一个向量表示(例如，作为一个未经过归一化的类别的分布)，\n并在下游任务中使用。\n\n.. 
code:: python\n\n    class MLPPredictor(nn.Module):\n        def __init__(self, in_features, out_classes):\n            super().__init__()\n            self.W = nn.Linear(in_features * 2, out_classes)\n\n        def apply_edges(self, edges):\n            h_u = edges.src['h']\n            h_v = edges.dst['h']\n            score = self.W(torch.cat([h_u, h_v], 1))\n            return {'score': score}\n\n        def forward(self, graph, h):\n            # h是从5.1节的GNN模型中计算出的节点表示\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(self.apply_edges)\n                return graph.edata['score']\n\n模型的训练\n~~~~~~~~~~~~~\n\n给定计算节点和边上表示的模型后，用户可以轻松地编写在所有边上进行预测的全图训练代码。\n\n以下代码用了 :ref:`guide_cn-message-passing` 中定义的 ``SAGE`` 作为节点表示计算模型以及前一小节中定义的\n``DotProductPredictor`` 作为边预测模型。\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.sage = SAGE(in_features, hidden_features, out_features)\n            self.pred = DotProductPredictor()\n        def forward(self, g, x):\n            h = self.sage(g, x)\n            return self.pred(g, h)\n\n在训练模型时可以使用布尔掩码区分训练、验证和测试数据集。该例子里省略了训练早停和模型保存部分的代码。\n\n.. code:: python\n\n    node_features = edge_pred_graph.ndata['feature']\n    edge_label = edge_pred_graph.edata['label']\n    train_mask = edge_pred_graph.edata['train_mask']\n    model = Model(10, 20, 5)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        pred = model(edge_pred_graph, node_features)\n        loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n.. 
_guide_cn-training-edge-classification-heterogeneous-graph:\n\n异构图上的边预测模型的训练\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n例如想在某一特定类型的边上进行分类任务，用户只需要计算所有节点类型的节点表示，\n然后同样通过调用 :meth:`~dgl.DGLGraph.apply_edges` 方法计算预测值即可。\n唯一的区别是在调用 ``apply_edges`` 时需要指定边的类型。\n\n.. code:: python\n\n    class HeteroDotProductPredictor(nn.Module):\n        def forward(self, graph, h, etype):\n            # h是从5.1节中对每种类型的边所计算的节点表示\n            with graph.local_scope():\n                graph.ndata['h'] = h   #一次性为所有节点类型的 'h'赋值\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)\n                return graph.edges[etype].data['score']\n\n同样地，用户也可以编写一个 ``HeteroMLPPredictor``。\n\n.. code:: python\n\n    class MLPPredictor(nn.Module):\n        def __init__(self, in_features, out_classes):\n            super().__init__()\n            self.W = nn.Linear(in_features * 2, out_classes)\n\n        def apply_edges(self, edges):\n            h_u = edges.src['h']\n            h_v = edges.dst['h']\n            score = self.W(torch.cat([h_u, h_v], 1))\n            return {'score': score}\n\n        def forward(self, graph, h, etype):\n            # h是从5.1节中对异构图的每种类型的边所计算的节点表示\n            with graph.local_scope():\n                graph.ndata['h'] = h   #一次性为所有节点类型的 'h'赋值\n                graph.apply_edges(self.apply_edges, etype=etype)\n                return graph.edges[etype].data['score']\n\n在某种类型的边上为每一条边预测的端到端模型的定义如下所示：\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, rel_names):\n            super().__init__()\n            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)\n            self.pred = HeteroDotProductPredictor()\n        def forward(self, g, x, etype):\n            h = self.sage(g, x)\n            return self.pred(g, h, etype)\n\n使用模型时只需要简单地向模型提供一个包含节点类型和数据特征的字典。\n\n.. 
code:: python\n\n    model = Model(10, 20, 5, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    label = hetero_graph.edges['click'].data['label']\n    train_mask = hetero_graph.edges['click'].data['train_mask']\n    node_features = {'user': user_feats, 'item': item_feats}\n\n\n训练部分和同构图的训练基本一致。例如，如果用户想预测边类型为 ``click`` 的边的标签，只需要按下例编写代码。\n\n.. code:: python\n\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        pred = model(hetero_graph, node_features, 'click')\n        loss = ((pred[train_mask] - label[train_mask]) ** 2).mean()\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n\n在异构图中预测已有边的类型\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n预测图中已经存在的边属于哪个类型是一个非常常见的任务类型。例如，根据\n:ref:`本章的异构图样例数据 <guide_cn-training-heterogeneous-graph-example>`，\n用户的任务是给定一条连接 ``user`` 节点和 ``item`` 节点的边，预测它的类型是 ``click`` 还是 ``dislike``。\n这个例子是评分预测的一个简化版本，在推荐场景中很常见。\n\n边类型预测的第一步仍然是计算节点表示。可以通过类似\n:ref:`节点分类的RGCN模型 <guide_cn-training-rgcn-node-classification>`\n这一章中提到的图卷积网络获得。第二步是计算边上的预测值。\n在这里可以复用上述提到的 ``HeteroDotProductPredictor``。\n这里需要注意的是输入的图数据不能包含边的类型信息，\n因此需要将所要预测的边类型(如 ``click`` 和 ``dislike``)合并成一种边的图，\n并为每条边计算出每种边类型的可能得分。下面的例子使用一个拥有 ``user``\n和 ``item`` 两种节点类型和一种边类型的图。该边类型是通过合并所有从 ``user``\n到 ``item`` 的边类型(如 ``like`` 和 ``dislike``)得到。\n用户可以很方便地用关系切片的方式创建这个图。\n\n.. code:: python\n\n    dec_graph = hetero_graph['user', :, 'item']\n\n这个方法会返回一个异构图，它具有 ``user`` 和 ``item`` 两种节点类型，\n以及把它们之间的所有边的类型进行合并后的单一边类型。\n\n由于上面这行代码将原来的边类型存成边特征 ``dgl.ETYPE``，用户可以将它作为标签使用。\n\n.. code:: python\n\n    edge_label = dec_graph.edata[dgl.ETYPE]\n\n将上述图作为边类型预测模块的输入，用户可以按如下方式编写预测模块：\n\n.. 
code:: python\n\n    class HeteroMLPPredictor(nn.Module):\n        def __init__(self, in_dims, n_classes):\n            super().__init__()\n            self.W = nn.Linear(in_dims * 2, n_classes)\n\n        def apply_edges(self, edges):\n            x = torch.cat([edges.src['h'], edges.dst['h']], 1)\n            y = self.W(x)\n            return {'score': y}\n\n        def forward(self, graph, h):\n            # h是从5.1节中对异构图的每种类型的边所计算的节点表示\n            with graph.local_scope():\n                graph.ndata['h'] = h   #一次性为所有节点类型的 'h'赋值\n                graph.apply_edges(self.apply_edges)\n                return graph.edata['score']\n\n结合了节点表示模块和边类型预测模块的模型如下所示：\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, rel_names):\n            super().__init__()\n            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)\n            self.pred = HeteroMLPPredictor(out_features, len(rel_names))\n        def forward(self, g, x, dec_graph):\n            h = self.sage(g, x)\n            return self.pred(dec_graph, h)\n\n训练部分如下所示：\n\n.. code:: python\n\n    model = Model(10, 20, 5, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    node_features = {'user': user_feats, 'item': item_feats}\n\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        logits = model(hetero_graph, node_features, dec_graph)\n        loss = F.cross_entropy(logits, edge_label)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n读者可以进一步参考\n`Graph Convolutional Matrix\nCompletion <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcmc>`__\n这一示例来了解如何预测异构图中的边类型。\n`模型实现文件中 <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcmc>`__\n的节点表示模块称作 ``GCMCLayer``。边类型预测模块称作 ``BiDecoder``。\n虽然这两个模块都比上述的示例代码要复杂，但其基本思想和本章描述的流程是一致的。\n"
  },
  {
    "path": "docs/source/guide_cn/training-eweight.rst",
    "content": ".. _guide_cn-training-eweight:\n\n5.5 使用边权重\n----------------------------------\n\n:ref:`(English Version) <guide-training-eweight>`\n\n在一个加权图里，每条边都有一个有意义的标量权重。例如，边权重可以是连接强度或者信心指数。\n人们自然会想要在模型开发中使用它们。\n\n使用边权重的消息传递\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n大部分图神经网络在前馈计算中仅通过消息传递引入图结构信息。一个消息传递运算可以视为一个函数。\n这个函数的输入变量是一个邻接矩阵和其他输入特征。对于一个不带权重的图，邻接矩阵里的元素不是零就是一。\n值为一的元素表示一条边。对于一个加权图，非零的元素可以取任意标量值。这等价于把每条消息和对应的边权重相乘，\n即 `图注意力网络 <https://arxiv.org/pdf/1710.10903.pdf>`__ 中的做法。\n\n在DGL里可以通过以下步骤实现这一需求：\n\n- 把边权重保存为一个边特征\n- 在消息函数里，用保存的边特征与对应边的原始消息相乘\n\n考虑以下基于DGL的消息传递示例：\n\n.. code::\n\n    import dgl.function as fn\n\n    # 假定graph.ndata['ft']存储了输入节点特征\n    graph.update_all(fn.copy_u('ft', 'm'), fn.sum('m', 'ft'))\n\n可以将其按以下方式修改以支持边权重：\n\n.. code::\n\n    import dgl.function as fn\n\n    # 将边权重保存为一个边特征。边权重是一个形状为(E, *)的张量。\n    # E是边的数量\n    graph.edata['w'] = eweight\n\n    # 假定graph.ndata['ft']存储了输入节点特征\n    graph.update_all(fn.u_mul_e('ft', 'w', 'm'), fn.sum('m', 'ft'))\n\n在NN模块中使用边权重\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n用户可以通过修改NN模块中所有的消息传递操作来给NN模块增加边权重支持。以下代码块提供了一个例子。\n\n.. code::\n\n    import dgl.function as fn\n    import torch.nn as nn\n\n    class GNN(nn.Module):\n        def __init__(self, in_feats, out_feats):\n            super().__init__()\n            self.linear = nn.Linear(in_feats, out_feats)\n\n        def forward(self, g, feat, edge_weight=None):\n            with g.local_scope():\n                g.ndata['ft'] = self.linear(feat)\n                if edge_weight is None:\n                    msg_func = fn.copy_u('ft', 'm')\n                else:\n                    g.edata['w'] = edge_weight\n                    msg_func = fn.u_mul_e('ft', 'w', 'm')\n                g.update_all(msg_func, fn.sum('m', 'ft'))\n                return g.ndata['ft']\n\nDGL内置的NN模块如果在forward函数中支持一个可选的 :attr:`edge_weight` 变量，那么它们已经支持了边权重。\n\n用户可能会需要标准化原始边权重。DGL提供了一个满足这个功能的函数\n:func:`~dgl.nn.pytorch.conv.EdgeWeightNorm`。\n"
  },
  {
    "path": "docs/source/guide_cn/training-graph.rst",
    "content": ".. _guide_cn-training-graph-classification:\n\n5.4 整图分类\n----------------------------------\n\n:ref:`(English Version) <guide-training-graph-classification>`\n\n许多场景中的图数据是由多个图组成，而不是单个的大图数据。例如不同类型的人群社区。\n通过用图刻画同一社区里人与人间的友谊，可以得到多张用于分类的图。\n在这个场景里，整图分类模型可以识别社区的类型，即根据结构和整体信息对图进行分类。\n\n概述\n~~~~~~~~\n\n整图分类与节点分类或链接预测的主要区别是：预测结果刻画了整个输入图的属性。\n与之前的任务类似，用户还是在节点或边上进行消息传递。但不同的是，整图分类任务还需要得到整个图的表示。\n\n整图分类的处理流程如下图所示：\n\n.. figure:: https://data.dgl.ai/tutorial/batch/graph_classifier.png\n   :alt: Graph Classification Process\n\n   整图分类流程\n\n从左至右，一般流程是：\n\n-  准备一个批次的图；\n-  在这个批次的图上进行消息传递以更新节点或边的特征；\n-  将一张图里的节点或边特征聚合成整张图的图表示；\n-  根据任务设计分类层。\n\n批次的图\n^^^^^^^^^^^^^^^\n\n整图分类任务通常需要在很多图上进行训练。如果用户在训练模型时一次仅使用一张图，训练效率会很低。\n借用深度学习实践中常用的小批次训练方法，用户可将多张图组成一个批次，在整个图批次上进行一次训练迭代。\n\n使用DGL，用户可将一系列的图建立成一个图批次。一个图批次可以被看作是一张大图，图中的每个连通子图对应一张原始小图。\n\n.. figure:: https://data.dgl.ai/tutorial/batch/batch.png\n   :alt: Batched Graph\n\n   批次化的图\n\n需要注意，DGL里对图进行变换的函数会去掉图上的批次信息。用户可以通过 :func:`dgl.DGLGraph.set_batch_num_nodes`\n和 :func:`dgl.DGLGraph.set_batch_num_edges` 两个函数在变换后的图上重新加入批次信息。\n\n图读出\n^^^^^^^^^^^^^\n\n数据集中的每一张图都有它独特的结构和节点与边的特征。为了完成单个图的预测，通常会聚合并汇总单个图尽可能多的信息。\n这类操作叫做“读出”。常见的聚合方法包括：对所有节点或边特征求和、取平均值、逐元素求最大值或最小值。\n\n给定一张图 :math:`g`，对它所有节点特征取平均值的聚合读出公式如下：\n\n.. math:: h_g = \\frac{1}{|\\mathcal{V}|}\\sum_{v\\in \\mathcal{V}}h_v\n\n其中，:math:`h_g` 是图 :math:`g` 的表征， :math:`\\mathcal{V}` 是图 :math:`g` 中节点的集合，\n:math:`h_v` 是节点 :math:`v` 的特征。\n\nDGL内置了常见的图读出函数，例如 :func:`dgl.readout_nodes` 就实现了上述的平均值读出计算。\n\n在得到 :math:`h_g` 后，用户可将其传给一个多层感知机(MLP)来获得分类输出。\n\n编写神经网络模型\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n模型的输入是带节点和边特征的批次化图。需要注意的是批次化图中的节点和边属性没有批次大小对应的维度。\n模型中应特别注意以下几点。\n\n批次化图上的计算\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n首先，一个批次中不同的图是完全分开的，即任意两个图之间没有边连接。\n根据这个良好的性质，所有消息传递函数(的计算)仍然具有相同的结果。\n\n其次，读出函数会分别作用在图批次中的每张图上。假设批次大小为 :math:`B`，要聚合的特征大小为 :math:`D`，\n则图读出的张量形状为 :math:`(B, D)`。\n\n.. 
code:: python\n\n    import dgl\n    import torch\n\n    g1 = dgl.graph(([0, 1], [1, 0]))\n    g1.ndata['h'] = torch.tensor([1., 2.])\n    g2 = dgl.graph(([0, 1], [1, 2]))\n    g2.ndata['h'] = torch.tensor([1., 2., 3.])\n    \n    dgl.readout_nodes(g1, 'h')\n    # tensor([3.])  # 1 + 2\n    \n    bg = dgl.batch([g1, g2])\n    dgl.readout_nodes(bg, 'h')\n    # tensor([3., 6.])  # [1 + 2, 1 + 2 + 3]\n\n最后，批次化图中的每个节点或边特征张量均通过将所有图上的相应特征拼接得到。\n\n.. code:: python\n\n    bg.ndata['h']\n    # tensor([1., 2., 1., 2., 3.])\n\n模型定义\n^^^^^^^^^^^^^^^^\n\n了解了上述计算规则后，用户可以定义一个非常简单的模型。\n\n.. code:: python\n\n    import dgl.nn.pytorch as dglnn\n    import torch.nn as nn\n\n    class Classifier(nn.Module):\n        def __init__(self, in_dim, hidden_dim, n_classes):\n            super(Classifier, self).__init__()\n            self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)\n            self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)\n            self.classify = nn.Linear(hidden_dim, n_classes)\n    \n        def forward(self, g, h):\n            # 应用图卷积和激活函数\n            h = F.relu(self.conv1(g, h))\n            h = F.relu(self.conv2(g, h))\n            with g.local_scope():\n                g.ndata['h'] = h\n                # 使用平均读出计算图表示\n                hg = dgl.mean_nodes(g, 'h')\n                return self.classify(hg)\n\n模型的训练\n~~~~~~~~~~~~~\n\n数据加载\n^^^^^^^^^^^^\n\n\n模型定义完成后，用户就可以开始训练模型。由于整图分类处理的是很多相对较小的图，而不是一个大图，\n因此通常可以在随机抽取的小批次图上进行高效的训练，而无需设计复杂的图采样算法。\n\n以下例子中使用了 :ref:`guide_cn-data-pipeline` 中的整图分类数据集。\n\n.. code:: python\n\n    import dgl.data\n    dataset = dgl.data.GINDataset('MUTAG', False)\n\n整图分类数据集里的每个数据点是一个图和它对应标签的元组。为提升数据加载速度，\n用户可以调用GraphDataLoader，从而以小批次遍历整个图数据集。\n\n.. code:: python\n\n    from dgl.dataloading import GraphDataLoader\n    dataloader = GraphDataLoader(\n        dataset,\n        batch_size=1024,\n        drop_last=False,\n        shuffle=True)\n\n训练过程包括遍历dataloader和更新模型参数的部分。\n\n.. 
code:: python\n\n    import torch.nn.functional as F\n\n    # 这仅是个例子，特征尺寸是7\n    model = Classifier(7, 20, 5)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(20):\n        for batched_graph, labels in dataloader:\n            feats = batched_graph.ndata['attr']\n            logits = model(batched_graph, feats)\n            loss = F.cross_entropy(logits, labels)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n\nDGL实现了一个整图分类的样例：\n`DGL的GIN样例 <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin>`__。\n模型训练的代码请参考位于\n`main.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/main.py>`__ 源文件中的 ``train`` 函数。\n模型实现位于\n`gin.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/gin.py>`__ ，\n其中使用了更多的模块组件，例如使用 :class:`dgl.nn.pytorch.GINConv`\n模块作为图卷积层(DGL同样支持它在MXNet和TensorFlow后端里的实现)、批量归一化等。\n\n异构图上的整图分类模型的训练\n~~~~~~~~~~~~~~~~~~~\n\n在异构图上做整图分类和在同构图上做整图分类略有不同。用户除了需要使用异构图卷积模块，还需要在读出函数中聚合不同类别的节点。\n\n以下代码演示了如何对每种节点类型的节点表示取平均值并求和。\n\n.. 
code:: python\n\n    class RGCN(nn.Module):\n        def __init__(self, in_feats, hid_feats, out_feats, rel_names):\n            super().__init__()\n    \n            self.conv1 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(in_feats, hid_feats)\n                for rel in rel_names}, aggregate='sum')\n            self.conv2 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(hid_feats, out_feats)\n                for rel in rel_names}, aggregate='sum')\n    \n        def forward(self, graph, inputs):\n            # inputs是节点的特征\n            h = self.conv1(graph, inputs)\n            h = {k: F.relu(v) for k, v in h.items()}\n            h = self.conv2(graph, h)\n            return h\n    \n    class HeteroClassifier(nn.Module):\n        def __init__(self, in_dim, hidden_dim, n_classes, rel_names):\n            super().__init__()\n\n            self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)\n            self.classify = nn.Linear(hidden_dim, n_classes)\n    \n        def forward(self, g):\n            h = g.ndata['feat']\n            h = self.rgcn(g, h)\n            with g.local_scope():\n                g.ndata['h'] = h\n                # 通过平均读出值来计算单图的表征\n                hg = 0\n                for ntype in g.ntypes:\n                    hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)\n                return self.classify(hg)\n\n剩余部分的训练代码和同构图代码相同。\n\n.. code:: python\n\n    # etypes是一个列表，元素是字符串类型的边类型\n    model = HeteroClassifier(10, 20, 5, etypes)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(20):\n        for batched_graph, labels in dataloader:\n            logits = model(batched_graph)\n            loss = F.cross_entropy(logits, labels)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n"
  },
  {
    "path": "docs/source/guide_cn/training-link.rst",
    "content": ".. _guide_cn-training-link-prediction:\n\n5.3 链接预测\n---------------------------\n\n:ref:`(English Version) <guide-training-link-prediction>`\n\n在某些场景中，用户可能希望预测给定节点之间是否存在边，这样的任务称作 **链接预测** 任务。\n\n概述\n~~~~~~~~\n\n基于GNN的链接预测模型的基本思想是通过使用所需预测的节点对\n:math:`u`, :math:`v` 的节点表示 :math:`\\boldsymbol{h}_u^{(L)}` 和\n:math:`\\boldsymbol{h}_v^{(L)}`，计算它们之间存在链接可能性的得分 :math:`y_{u,v}`。\n其中  :math:`\\boldsymbol{h}_u^{(L)}` 和  :math:`\\boldsymbol{h}_v^{(L)}` 由多层GNN计算得出。\n\n.. math::\n\n   y_{u,v} = \\phi(\\boldsymbol{h}_u^{(L)}, \\boldsymbol{h}_v^{(L)})\n\n本节把节点 :math:`u` 和 :math:`v` 之间存在连接可能性的 *得分* 记作 :math:`y_{u,v}`。\n\n训练一个链接预测模型涉及到比对两个相连接节点之间的得分与任意一对节点之间的得分的差异。\n例如，给定一条连接 :math:`u` 和 :math:`v` 的边，一个好的模型希望 :math:`u` 和 :math:`v` 之间的得分要高于\n:math:`u` 和从一个任意的噪声分布 :math:`v' \\sim P_n(v)` 中所采样的节点 :math:`v'` 之间的得分。\n这样的方法称作 *负采样*。\n\n许多损失函数都可以实现上述目标，包括但不限于：\n\n-  交叉熵损失:\n   :math:`\\mathcal{L} = - \\log \\sigma (y_{u,v}) - \\sum_{v_i \\sim P_n(v), i=1,\\dots,k}\\log \\left[ 1 - \\sigma (y_{u,v_i})\\right]`\n-  贝叶斯个性化排序损失:\n   :math:`\\mathcal{L} = \\sum_{v_i \\sim P_n(v), i=1,\\dots,k} - \\log \\sigma (y_{u,v} - y_{u,v_i})`\n-  间隔损失:\n   :math:`\\mathcal{L} = \\sum_{v_i \\sim P_n(v), i=1,\\dots,k} \\max(0, M - y_{u, v} + y_{u, v_i})`,\n   其中 :math:`M` 是常数项超参数。\n\n如果用户熟悉 `implicit feedback <https://arxiv.org/ftp/arxiv/papers/1205/1205.2618.pdf>`__ 和\n`noise-contrastive estimation <http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf>`__ ，\n可能会发现这些工作的想法都很类似。\n\n计算 :math:`u` 和 :math:`v` 之间分数的神经网络模型与 :ref:`guide_cn-training-edge-classification`\n中所述的边回归模型相同。\n\n下面是使用点积计算边得分的例子。\n\n.. 
code:: python\n\n    class DotProductPredictor(nn.Module):\n        def forward(self, graph, h):\n            # h是从5.1节的GNN模型中计算出的节点表示\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))\n                return graph.edata['score']\n\n模型的训练\n~~~~~~~~~~~~~\n\n因为上述的得分预测模型在图上进行计算，用户需要将负采样的样本表示为另外一个图，\n其中包含所有负采样的节点对作为边。\n\n下面的例子展示了将负采样的样本表示为一个图。每一条边 :math:`(u,v)` 都有 :math:`k`\n个对应的负采样样本 :math:`(u,v_i)`，其中 :math:`v_i` 是从均匀分布中采样的。\n\n.. code:: python\n\n    def construct_negative_graph(graph, k):\n        src, dst = graph.edges()\n    \n        neg_src = src.repeat_interleave(k)\n        neg_dst = torch.randint(0, graph.num_nodes(), (len(src) * k,))\n        return dgl.graph((neg_src, neg_dst), num_nodes=graph.num_nodes())\n\n预测边得分的模型和边分类/回归模型中的预测边得分模型相同。\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.sage = SAGE(in_features, hidden_features, out_features)\n            self.pred = DotProductPredictor()\n        def forward(self, g, neg_g, x):\n            h = self.sage(g, x)\n            return self.pred(g, h), self.pred(neg_g, h)\n\n训练的循环部分里会重复构建负采样图并计算损失函数值。\n\n.. 
code:: python\n\n    def compute_loss(pos_score, neg_score):\n        # 间隔损失\n        n_edges = pos_score.shape[0]\n        return (1 - pos_score.unsqueeze(1) + neg_score.view(n_edges, -1)).clamp(min=0).mean()\n    \n    node_features = graph.ndata['feat']\n    n_features = node_features.shape[1]\n    k = 5\n    model = Model(n_features, 100, 100)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        negative_graph = construct_negative_graph(graph, k)\n        pos_score, neg_score = model(graph, negative_graph, node_features)\n        loss = compute_loss(pos_score, neg_score)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n训练后，节点表示可以通过以下代码获取。\n\n.. code:: python\n\n    node_embeddings = model.sage(graph, node_features)\n\n(实际应用中)，有着许多使用节点嵌入的方法，例如，训练下游任务的分类器，或为相关实体推荐进行最近邻搜索或最大内积搜索。\n\n异构图上的链接预测模型的训练\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n异构图上的链接预测和同构图上的链接预测没有太大区别。下文是在一种边类型上进行预测，\n用户可以很容易地将其拓展为对多种边类型上进行预测。\n\n例如，为某一种边类型，用户可以重复使用\n:ref:`guide_cn-training-edge-classification-heterogeneous-graph`\n里的 ``HeteroDotProductPredictor`` 来计算节点间存在连接可能性的得分。\n\n.. code:: python\n\n    class HeteroDotProductPredictor(nn.Module):\n        def forward(self, graph, h, etype):\n            # h是从5.1节中对异构图的每种类型的边所计算的节点表示\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)\n                return graph.edges[etype].data['score']\n\n要执行负采样，用户可以对要进行链接预测的边类型构造一个负采样图。\n\n.. 
code:: python\n\n    def construct_negative_graph(graph, k, etype):\n        utype, _, vtype = etype\n        src, dst = graph.edges(etype=etype)\n        neg_src = src.repeat_interleave(k)\n        neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))\n        return dgl.heterograph(\n            {etype: (neg_src, neg_dst)},\n            num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})\n\n该模型与异构图上边分类的模型有些不同，因为用户需要指定在哪种边类型上进行链接预测。\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, rel_names):\n            super().__init__()\n            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)\n            self.pred = HeteroDotProductPredictor()\n        def forward(self, g, neg_g, x, etype):\n            h = self.sage(g, x)\n            return self.pred(g, h, etype), self.pred(neg_g, h, etype)\n\n训练的循环部分和同构图时一致。\n\n.. code:: python\n\n    def compute_loss(pos_score, neg_score):\n        # 间隔损失\n        n_edges = pos_score.shape[0]\n        return (1 - pos_score.unsqueeze(1) + neg_score.view(n_edges, -1)).clamp(min=0).mean()\n    \n    k = 5\n    model = Model(10, 20, 5, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    node_features = {'user': user_feats, 'item': item_feats}\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        negative_graph = construct_negative_graph(hetero_graph, k, ('user', 'click', 'item'))\n        pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ('user', 'click', 'item'))\n        loss = compute_loss(pos_score, neg_score)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n\n\n"
  },
  {
    "path": "docs/source/guide_cn/training-node.rst",
    "content": ".. _guide_cn-training-node-classification:\n\n5.1 节点分类/回归\n--------------------------------------------------\n\n:ref:`(English Version) <guide-training-node-classification>`\n\n对于图神经网络来说，最常见和被广泛使用的任务之一就是节点分类。\n图数据中的训练、验证和测试集中的每个节点都具有从一组预定义的类别中分配的一个类别，即正确的标注。\n节点回归任务也类似，训练、验证和测试集中的每个节点都被标注了一个正确的数字。\n\n概述\n~~~~~~~~\n\n为了对节点进行分类，图神经网络执行了 :ref:`guide_cn-message-passing`\n中介绍的消息传递机制，利用节点自身的特征和其邻节点及边的特征来计算节点的隐藏表示。\n消息传递可以重复多轮，以利用更大范围的邻居信息。\n\n编写神经网络模型\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nDGL提供了一些内置的图卷积模块，可以完成一轮消息传递计算。\n本章中选择 :class:`dgl.nn.pytorch.SAGEConv` 作为演示的样例代码(针对MXNet和PyTorch后端也有对应的模块)，\n它是GraphSAGE模型中使用的图卷积模块。\n\n对于图上的深度学习模型，通常需要一个多层的图神经网络，并在这个网络中要进行多轮的信息传递。\n可以通过堆叠图卷积模块来实现这种网络架构，具体如下所示。\n\n.. code:: python\n\n    # 构建一个2层的GNN模型\n    import dgl.nn as dglnn\n    import torch.nn as nn\n    import torch.nn.functional as F\n    class SAGE(nn.Module):\n        def __init__(self, in_feats, hid_feats, out_feats):\n            super().__init__()\n            # 实例化SAGEConve，in_feats是输入特征的维度，out_feats是输出特征的维度，aggregator_type是聚合函数的类型\n            self.conv1 = dglnn.SAGEConv(\n                in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')\n            self.conv2 = dglnn.SAGEConv(\n                in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')\n      \n        def forward(self, graph, inputs):\n            # 输入是节点的特征\n            h = self.conv1(graph, inputs)\n            h = F.relu(h)\n            h = self.conv2(graph, h)\n            return h\n\n请注意，这个模型不仅可以做节点分类，还可以为其他下游任务获取隐藏节点表示，如：\n:ref:`guide_cn-training-edge-classification`、\n:ref:`guide_cn-training-link-prediction` 和\n:ref:`guide_cn-training-graph-classification`。\n\n关于DGL内置图卷积模块的完整列表，读者可以参考 :ref:`apinn`。\n\n有关DGL神经网络模块如何工作，以及如何编写一个自定义的带有消息传递的GNN模块的更多细节，请参考 :ref:`guide_cn-nn` 中的例子。\n\n模型的训练\n~~~~~~~~~~~~~\n\n全图(使用所有的节点和边的特征)上的训练只需要使用上面定义的模型进行前向传播计算，并通过在训练节点上比较预测和真实标签来计算损失，从而完成后向传播。\n\n本节使用DGL内置的数据集 :class:`dgl.data.CiteseerGraphDataset` 
来展示模型的训练。\n节点特征和标签存储在其图上，训练、验证和测试的分割也以布尔掩码的形式存储在图上。这与在\n:ref:`guide_cn-data-pipeline` 中的做法类似。\n\n.. code:: python\n\n    node_features = graph.ndata['feat']\n    node_labels = graph.ndata['label']\n    train_mask = graph.ndata['train_mask']\n    valid_mask = graph.ndata['val_mask']\n    test_mask = graph.ndata['test_mask']\n    n_features = node_features.shape[1]\n    n_labels = int(node_labels.max().item() + 1)\n\n下面是通过使用准确性来评估模型的一个例子。\n\n.. code:: python\n\n    def evaluate(model, graph, features, labels, mask):\n        model.eval()\n        with torch.no_grad():\n            logits = model(graph, features)\n            logits = logits[mask]\n            labels = labels[mask]\n            _, indices = torch.max(logits, dim=1)\n            correct = torch.sum(indices == labels)\n            return correct.item() * 1.0 / len(labels)\n\n用户可以按如下方式实现模型的训练。\n\n.. code:: python\n\n    model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)\n    opt = torch.optim.Adam(model.parameters())\n    \n    for epoch in range(10):\n        model.train()\n        # 使用所有节点(全图)进行前向传播计算\n        logits = model(graph, node_features)\n        # 计算损失值\n        loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])\n        # 计算验证集的准确度\n        acc = evaluate(model, graph, node_features, node_labels, valid_mask)\n        # 进行反向传播计算\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n    \n        # 如果需要的话，保存训练好的模型。本例中省略。\n\n\n`DGL的GraphSAGE样例 <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_full.py>`__\n提供了一个端到端的同构图节点分类的例子。用户可以在 ``GraphSAGE`` 类中看到模型实现的细节。\n这个模型具有可调节的层数、dropout概率，以及可定制的聚合函数和非线性函数。\n\n.. 
_guide_cn-training-rgcn-node-classification:\n\n异构图上的节点分类模型的训练\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n如果图是异构的，用户可能希望沿着所有边类型从邻居那里收集消息。\n用户可以使用 :class:`dgl.nn.pytorch.HeteroGraphConv`\n模块(针对MXNet和PyTorch后端也有对应的模块)在所有边类型上执行消息传递，\n并为每种边类型使用一种图卷积模块。\n\n下面的代码定义了一个异构图卷积模块。模块首先对每种边类型进行单独的图卷积计算，然后将每种边类型上的消息聚合结果再相加，\n并作为所有节点类型的最终结果。\n\n.. code:: python\n\n    # Define a Heterograph Conv model\n\n    class RGCN(nn.Module):\n        def __init__(self, in_feats, hid_feats, out_feats, rel_names):\n            super().__init__()\n            # 实例化HeteroGraphConv，in_feats是输入特征的维度，out_feats是输出特征的维度，aggregate是聚合函数的类型\n            self.conv1 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(in_feats, hid_feats)\n                for rel in rel_names}, aggregate='sum')\n            self.conv2 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(hid_feats, out_feats)\n                for rel in rel_names}, aggregate='sum')\n      \n        def forward(self, graph, inputs):\n            # 输入是节点的特征字典\n            h = self.conv1(graph, inputs)\n            h = {k: F.relu(v) for k, v in h.items()}\n            h = self.conv2(graph, h)\n            return h\n\n\n``dgl.nn.HeteroGraphConv`` 接收一个节点类型和节点特征张量的字典作为输入，并返回另一个节点类型和节点特征的字典。\n\n本章的 :ref:`guide_cn-training-heterogeneous-graph-example`\n中已经有了 ``user`` 和 ``item`` 的特征，用户可用如下代码获取。\n\n.. code:: python\n\n    model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    labels = hetero_graph.nodes['user'].data['label']\n    train_mask = hetero_graph.nodes['user'].data['train_mask']\n\n然后，用户可以简单地按如下形式进行前向传播计算：\n\n.. 
code:: python\n\n    node_features = {'user': user_feats, 'item': item_feats}\n    h_dict = model(hetero_graph, {'user': user_feats, 'item': item_feats})\n    h_user = h_dict['user']\n    h_item = h_dict['item']\n\n异构图上模型的训练和同构图的模型训练是一样的，只是这里使用了一个包括节点表示的字典来计算预测值。\n例如，如果只预测 ``user`` 节点的类别，用户可以从返回的字典中提取 ``user`` 的节点嵌入。\n\n.. code:: python\n\n    opt = torch.optim.Adam(model.parameters())\n    \n    for epoch in range(5):\n        model.train()\n        # 使用所有节点的特征进行前向传播计算，并提取输出的user节点嵌入\n        logits = model(hetero_graph, node_features)['user']\n        # 计算损失值\n        loss = F.cross_entropy(logits[train_mask], labels[train_mask])\n        # 计算验证集的准确度。在本例中省略。\n        # 进行反向传播计算\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n    \n        # 如果需要的话，保存训练好的模型。本例中省略。\n\nDGL提供了一个用于节点分类的RGCN的端到端的例子\n`RGCN <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify.py>`__\n。用户可以在 `RGCN模型实现文件\n<https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/model.py>`__\n中查看异构图卷积 ``RelGraphConvLayer`` 的具体定义。\n\n\n"
  },
  {
    "path": "docs/source/guide_cn/training.rst",
    "content": ".. _guide_cn-training:\n\n第5章：训练图神经网络\n=====================================================\n\n:ref:`(English Version) <guide-training>`\n\n概述\n--------\n\n本章通过使用 :ref:`guide_cn-message-passing` 中介绍的消息传递方法和 :ref:`guide_cn-nn` 中介绍的图神经网络模块，\n讲解了如何对小规模的图数据进行节点分类、边分类、链接预测和整图分类的图神经网络的训练。\n\n本章假设用户的图以及所有的节点和边特征都能存进GPU。对于无法全部载入的情况，请参考用户指南的 :ref:`guide_cn-minibatch`。\n\n后续章节的内容均假设用户已经准备好了图和节点/边的特征数据。如果用户希望使用DGL提供的数据集或其他兼容\n``DGLDataset`` 的数据(如 :ref:`guide_cn-data-pipeline` 所述)，\n可以使用类似以下代码的方法获取单个图数据集的图数据。\n\n.. code:: python\n\n    import dgl\n\n    dataset = dgl.data.CiteseerGraphDataset()\n    graph = dataset[0]\n\n注意: 本章代码使用PyTorch作为DGL的后端框架。\n\n.. _guide_cn-training-heterogeneous-graph-example:\n\n异构图训练的样例数据\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n有时用户会想在异构图上进行图神经网络的训练。本章会以下面代码所创建的一个异构图为例，来演示如何进行节点分类、边分类和链接预测的训练。\n\n这个 ``hetero_graph`` 异构图有以下这些边的类型：\n\n-  ``('user', 'follow', 'user')``\n-  ``('user', 'followed-by', 'user')``\n-  ``('user', 'click', 'item')``\n-  ``('item', 'clicked-by', 'user')``\n-  ``('user', 'dislike', 'item')``\n-  ``('item', 'disliked-by', 'user')``\n\n.. 
code:: python\n\n    import numpy as np\n    import torch\n\n    n_users = 1000\n    n_items = 500\n    n_follows = 3000\n    n_clicks = 5000\n    n_dislikes = 500\n    n_hetero_features = 10\n    n_user_classes = 5\n    n_max_clicks = 10\n\n    follow_src = np.random.randint(0, n_users, n_follows)\n    follow_dst = np.random.randint(0, n_users, n_follows)\n    click_src = np.random.randint(0, n_users, n_clicks)\n    click_dst = np.random.randint(0, n_items, n_clicks)\n    dislike_src = np.random.randint(0, n_users, n_dislikes)\n    dislike_dst = np.random.randint(0, n_items, n_dislikes)\n\n    hetero_graph = dgl.heterograph({\n        ('user', 'follow', 'user'): (follow_src, follow_dst),\n        ('user', 'followed-by', 'user'): (follow_dst, follow_src),\n        ('user', 'click', 'item'): (click_src, click_dst),\n        ('item', 'clicked-by', 'user'): (click_dst, click_src),\n        ('user', 'dislike', 'item'): (dislike_src, dislike_dst),\n        ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})\n\n    hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)\n    hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)\n    hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))\n    hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()\n    # 在user类型的节点和click类型的边上随机生成训练集的掩码\n    hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)\n    hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)\n\n本章路线图\n------------\n\n本章共有五节，前四节各对应一种图学习任务，最后一节介绍如何使用边权重。\n\n* :ref:`guide_cn-training-node-classification`\n* :ref:`guide_cn-training-edge-classification`\n* :ref:`guide_cn-training-link-prediction`\n* :ref:`guide_cn-training-graph-classification`\n* :ref:`guide_cn-training-eweight`\n\n.. 
toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    training-node\n    training-edge\n    training-link\n    training-graph\n    training-eweight\n"
  },
  {
    "path": "docs/source/guide_ko/data-dataset.rst",
    "content": ".. _guide_ko-data-pipeline-dataset:\n\n4.1 DGLDataset 클래스\n--------------------\n\n:ref:`(English Version) <guide-data-pipeline-dataset>`\n\n:class:`~dgl.data.DGLDataset` 는 :ref:`apidata` 에서 정의된 그래프 데이터셋을 프로세싱하고, 로딩하고 저장하기 위한 기본 클래스이다. 이는 그래프 데이터를 처리하는 기본 파이프라인을 구현한다. 아래 순서도는 파이프라인이 어떻게 동작하는지를 보여준다.\n\n.. figure:: https://data.dgl.ai/asset/image/userguide_data_flow.png\n    :align: center\n\n    DGLDataset 클래스에 정의된 그래프 데이터 입력 파이프라인에 대한 순서도\n\n\n원격 또는 로컬 디스크에 있는 그래프 데이터셋을 처리하기 위해서, :class:`dgl.data.DGLDataset` 를 상속해서 클래스를 정의한다. 예로, ``MyDataset`` 이라고 하자. ``MyDataset`` 템플릿은 다음과 같다.\n\n.. code:: \n\n    from dgl.data import DGLDataset\n    \n    class MyDataset(DGLDataset):\n        \"\"\" Template for customizing graph datasets in DGL.\n    \n        Parameters\n        ----------\n        url : str\n            URL to download the raw dataset\n        raw_dir : str\n            Specifying the directory that will store the \n            downloaded data or the directory that\n            already stores the input data.\n            Default: ~/.dgl/\n        save_dir : str\n            Directory to save the processed dataset.\n            Default: the value of `raw_dir`\n        force_reload : bool\n            Whether to reload the dataset. 
Default: False\n        verbose : bool\n            Whether to print out progress information\n        \"\"\"\n        def __init__(self, \n                     url=None, \n                     raw_dir=None, \n                     save_dir=None, \n                     force_reload=False, \n                     verbose=False):\n            super(MyDataset, self).__init__(name='dataset_name',\n                                            url=url,\n                                            raw_dir=raw_dir,\n                                            save_dir=save_dir,\n                                            force_reload=force_reload,\n                                            verbose=verbose)\n    \n        def download(self):\n            # download raw data to local disk\n            pass\n    \n        def process(self):\n            # process raw data to graphs, labels, splitting masks\n            pass\n        \n        def __getitem__(self, idx):\n            # get one example by index\n            pass\n    \n        def __len__(self):\n            # number of data examples\n            pass\n    \n        def save(self):\n            # save processed data to directory `self.save_path`\n            pass\n    \n        def load(self):\n            # load processed data from directory `self.save_path`\n            pass\n    \n        def has_cache(self):\n            # check whether there are processed data in `self.save_path`\n            pass\n\n:class:`~dgl.data.DGLDataset` 클래스에는 서브클래스에서 꼭 구현되어야 하는 함수들 ``process()`` ,\n``__getitem__(idx)`` 와 ``__len__()`` 이 있다. 또한 DGL은 저장과 로딩을 구현하는 것을 권장하는데, 그 이유는 큰 데이터셋 처리 시간을 많이 줄일 수 있고, 이를 쉽게 구현하는데 필요한 API들이 있기 때문이다. (:ref:`guide_ko-data-pipeline-savenload` 참고)\n\n:class:`~dgl.data.DGLDataset` 의 목적은 그래프 데이터 로드에 필요한 편리하고 표준적인 방법을 제공하는 것이다. 그래프, 피쳐, 레이블, 그리고 데이터셋에 대한 기본적인 정보 (클래스 개수, 레이블 개수 등)을 저장할 수 있다. 
샘플링, 파티셔닝 또는 피쳐 normalization과 같은 작업은 :class:`~dgl.data.DGLDataset` 의 서브클래스 밖에서 수행된다.\n\n이 장의 나머지에서는 파이프라인에서 함수를 구현하는 best practice들을 소개한다.\n"
  },
  {
    "path": "docs/source/guide_ko/data-download.rst",
    "content": ".. _guide_ko-data-pipeline-download:\n\n4.2 Raw 데이터 다운로드하기 (optional)\n---------------------------------\n\n:ref:`(English Version) <guide-data-pipeline-download>`\n\n로컬 디스크에 데이터셋이 이미 존재한다면, ``raw_dir`` 디렉토리에 있어야 한다. 만약 데이터를 다운로드하고 특정 디렉토리에 옮기는 일을 직접 수행하지 않고 코드를 어디서나 실행하고 싶다면, ``download()`` 함수를 구현해서 이를 자동화할 수 있다.\n\n데이터셋이 zip 파일 포맷인 경우, zip 파일 추출을 자동으로 해주는 :class:`dgl.data.DGLBuiltinDataset` 클래스를 상속해서 ``MyDataset`` 클래스를 만들자. 그렇지 않은 경우 :class:`~dgl.data.QM7bDataset` 처럼 ``download()`` 함수를 직접 구현한다:\n\n.. code:: \n\n    import os\n    from dgl.data.utils import download\n    \n    def download(self):\n        # path to store the file\n        file_path = os.path.join(self.raw_dir, self.name + '.mat')\n        # download file\n        download(self.url, path=file_path)\n\n위 코드는 .mat 파일을 ``self.raw_dir`` 디렉토리에 다운로드한다. 만약 파일 포맷이 .gz, .tar, .tar.gz 또는 .tgz 이라면, :func:`~dgl.data.utils.extract_archive` 함수로 파일들을 추출하자. 다음 코드는 :class:`~dgl.data.BitcoinOTCDataset` 에서 .gz 파일을 다운로드하는 예이다:\n\n.. code:: \n\n    from dgl.data.utils import download, check_sha1\n    \n    def download(self):\n        # path to store the file\n        # make sure to use the same suffix as the original file name's\n        gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')\n        # download file\n        download(self.url, path=gz_file_path)\n        # check SHA-1\n        if not check_sha1(gz_file_path, self._sha1_str):\n            raise UserWarning('File {} is downloaded but the content hash does not match.'\n                              'The repo may be outdated or download may be incomplete. '\n                              'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))\n        # extract file to directory `self.name` under `self.raw_dir`\n        self._extract_gz(gz_file_path, self.raw_path)\n\n위 코드는 ``self.raw_dir`` 디렉토리 아래의 ``self.name`` 서브 디렉토리에 파일을 추출한다. 
만약 zip 파일을 다루기 위해서 :class:`dgl.data.DGLBuiltinDataset` 를 상속해서 사용했다면, 파일들은 자동으로 ``self.name`` 디렉토리로 추출될 것이다.\n\n추가적으로, 다운로드한 파일에 대한 SHA-1 값 검증을 수행해서 파일이 변경되었는지 확인하는 것도 위 예제처럼 구현할 수 있다."
  },
  {
    "path": "docs/source/guide_ko/data-loadogb.rst",
    "content": ".. _guide_ko-data-pipeline-loadogb:\n\n4.5 ``ogb`` 패키지를 사용해서 OGB 데이터셋들 로드하기\n-------------------------------------------\n\n:ref:`(English Version) <guide-data-pipeline-loadogb>`\n\n`Open Graph Benchmark (OGB) <https://ogb.stanford.edu/docs/home/>`__ 은 벤치마킹 데이터셋의 모음이다. 공식 OGB 패키지 `ogb <https://github.com/snap-stanford/ogb>`__ 는 OBG 데이터셋들을 다운로드해서 :class:`dgl.data.DGLGraph` 객체로 프로세싱하는 API들을 제공한다. 이 절은 기본적인 사용법을 설명한다.\n\n우선 obg 패키지를 pip 명령으로 설치한다.\n\n.. code:: \n\n    pip install ogb\n\n다음 코드는 *Graph Property Prediction* 테스크를 위한 데이터셋 로딩 방법을 보여준다.\n\n.. code:: \n\n    # Load Graph Property Prediction datasets in OGB\n    import dgl\n    import torch\n    from ogb.graphproppred import DglGraphPropPredDataset\n    from dgl.dataloading import GraphDataLoader\n    \n    \n    def _collate_fn(batch):\n        # batch is a list of tuple (graph, label)\n        graphs = [e[0] for e in batch]\n        g = dgl.batch(graphs)\n        labels = [e[1] for e in batch]\n        labels = torch.stack(labels, 0)\n        return g, labels\n    \n    # load dataset\n    dataset = DglGraphPropPredDataset(name='ogbg-molhiv')\n    split_idx = dataset.get_idx_split()\n    # dataloader\n    train_loader = GraphDataLoader(dataset[split_idx[\"train\"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)\n    valid_loader = GraphDataLoader(dataset[split_idx[\"valid\"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)\n    test_loader = GraphDataLoader(dataset[split_idx[\"test\"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)\n\n*Node Property Prediction* 데이터셋을 로딩하는 것이 비슷하지만, 이런 종류의 데이터셋은 오직 한 개의 그래프 객체만 존재한다는 것이 다름을 유의하자.\n\n.. 
code:: \n\n    # Load Node Property Prediction datasets in OGB\n    from ogb.nodeproppred import DglNodePropPredDataset\n    \n    dataset = DglNodePropPredDataset(name='ogbn-proteins')\n    split_idx = dataset.get_idx_split()\n    \n    # there is only one graph in Node Property Prediction datasets\n    g, labels = dataset[0]\n    # get split labels\n    train_label = dataset.labels[split_idx['train']]\n    valid_label = dataset.labels[split_idx['valid']]\n    test_label = dataset.labels[split_idx['test']]\n\n*Link Property Prediction* 데이터셋 역시 데이터셋에 한개의 그래프를 갖고 있다.\n\n.. code:: \n\n    # Load Link Property Prediction datasets in OGB\n    from ogb.linkproppred import DglLinkPropPredDataset\n    \n    dataset = DglLinkPropPredDataset(name='ogbl-ppa')\n    split_edge = dataset.get_edge_split()\n    \n    graph = dataset[0]\n    print(split_edge['train'].keys())\n    print(split_edge['valid'].keys())\n    print(split_edge['test'].keys())\n"
  },
  {
    "path": "docs/source/guide_ko/data-process.rst",
    "content": ".. _guide_ko-data-pipeline-process:\n\n4.3 데이터 프로세싱\n---------------\n\n:ref:`(English Version) <guide-data-pipeline-process>`\n\n데이터 프로세싱 코드를 ``process()`` 함수에 구현할 수 있으며, 이때 처리되지 않은 데이터는 ``self.raw_dir`` 디렉토리에 있어야 한다. 그래프 머신러닝에는 일반적으로 3가지 종류의 일이 있다: 그래프 분류, 노드 분류, 그리고 링크 예측. 이 절에서는 이 일들에 관련된 데이터셋 처리 방법을 설명한다.\n\n이 절에서 그래프들, 피쳐들, 그리고 마스크들을 처리하는 표준 방법에 집중해서 알아본다. 빌트인 데이터셋을 예제로 사용할 것이고, 파일로 부터 그래프를 만드는 방법은 생략한다. 하지만, 이와 관련된 구현에 대한 링크를 제공할 것이다. 외부 소스들로 부터 그래프를 만드는 방법에 대한 완벽한 가이드는 :ref:`guide_ko-graph-external` 를 참고하자.\n\n그래프 분류 데이터셋 프로세싱\n~~~~~~~~~~~~~~~~~~~~~~\n\n그래프 분류 데이터셋은 미니-배치 학습이 사용되는 전형적인 머신러닝 테스크에서 사용되는 데이터셋과 거의 동일하다. 즉, 처리되지 않은 데이터는 :class:`dgl.DGLGraph` 객체들의 리스트와 레이블 텐서들의 리스트로 변환하면 된다. 또한, 만약 처리되지 않은 데이터가 여러 파일들로 나눠져 있을 경우에는, 데이터의 특정 부분을 로드하기 위해서 ``split``  파라메터를 더할 수 있다.\n\n:class:`~dgl.data.QM7bDataset` 를 예로 살펴보자:\n\n.. code::\n\n    from dgl.data import DGLDataset\n\n    class QM7bDataset(DGLDataset):\n        _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \\\n               'datasets/qm7b.mat'\n        _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'\n\n        def __init__(self, raw_dir=None, force_reload=False, verbose=False):\n            super(QM7bDataset, self).__init__(name='qm7b',\n                                              url=self._url,\n                                              raw_dir=raw_dir,\n                                              force_reload=force_reload,\n                                              verbose=verbose)\n\n        def process(self):\n            mat_path = self.raw_path + '.mat'\n            # process data to a list of graphs and a list of labels\n            self.graphs, self.label = self._load_graph(mat_path)\n\n        def __getitem__(self, idx):\n            \"\"\" Get graph and label by index\n\n            Parameters\n            ----------\n            idx : int\n                Item index\n\n            Returns\n            -------\n            (dgl.DGLGraph, 
Tensor)\n            \"\"\"\n            return self.graphs[idx], self.label[idx]\n\n        def __len__(self):\n            \"\"\"Number of graphs in the dataset\"\"\"\n            return len(self.graphs)\n\n``process()`` 함수에서 처리되지 않은 데이터는 그래프들의 리스트와 레이블들의 리스트로 변환된다. Iteration을 위해서 ``__getitem__(idx)`` 와 ``__len__()`` 를 구현해야 한다. 위의 예제에서와 같이, DGL에서는 ``__getitem__(idx)`` 가 ``(graph, label)`` tuple을 리턴하도록 권장한다. ``self._load_graph()`` 와 ``__getitem__`` 함수의 구체적인 구현은 `QM7bDataset source\ncode <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/qm7b.html#QM7bDataset>`__ 를 확인하자.\n\n데이터셋의 유용한 정보들을 지정하기 위해서 클래스에 프로퍼티들을 추가하는 것이 가능하다. :class:`~dgl.data.QM7bDataset` 에 이 멀티 테스크 데이터셋의 예측 테스트의 총 개숫를 지정하기 위해 ``num_tasks`` 라는 프로퍼티를 추가할 수 있다.\n\n.. code::\n\n    @property\n    def num_tasks(self):\n        \"\"\"Number of labels for each graph, i.e. number of prediction tasks.\"\"\"\n        return 14\n\n구현 코드를 마친 후에, :class:`~dgl.data.QM7bDataset` 를 다음과 같이 사용한다.\n\n.. code::\n\n    import dgl\n    import torch\n\n    from dgl.dataloading import GraphDataLoader\n\n    # load data\n    dataset = QM7bDataset()\n    num_tasks = dataset.num_tasks\n\n    # create dataloaders\n    dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)\n\n    # training\n    for epoch in range(100):\n        for g, labels in dataloader:\n            # your training code here\n            pass\n\n그래프 분류 모델 학습에 대한 전체 가이드는 :ref:`guide_ko-training-graph-classification` 를 참고하자.\n\nDGL의 빌트인 그래프 분류 데이터셋을 참고하면 그래프 분류 데이터셋의 더 많은 예들을 확인할 수 있다.\n\n* :ref:`gindataset`\n* :ref:`minigcdataset`\n* :ref:`qm7bdata`\n* :ref:`tudata`\n\n노드 분류 데이터셋 프로세싱\n~~~~~~~~~~~~~~~~~~~~\n\n그래프 분류와는 다르게 노드 분류는 일번적으로 단일 그래프에서 이뤄진다. 따라서, 데이터셋의 분할(split)은 그래프 노드에서 일어난다. DGL은 노드 마스크를 사용해서 분할을 지정하는 것을 권장한다. 이 절에서는 빌트인 데이터셋 `CitationGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/citation_graph.html#CitationGraphDataset>`__ 을 예로 들겠다.\n\n추가로, DGL은 노드들와 에지들이 서로 가까운 ID값들이 서로 가까운 범위에 있도록 재배열하는 것을 권장한다. 
이 절차는 노드의 neighbor들에 대한 접근성을 향상시켜서, 이 후의 연산 및 그래프에 대한 분석을 빠르게 하기 위함이다. 이를 위해서 DGL은 :func:`dgl.reorder_graph` API를 제공한다. 더 자세한 내용은 다음 예제의 ``process()`` 를 참고하자.\n\n.. code::\n\n    from dgl.data import DGLBuiltinDataset\n    from dgl.data.utils import _get_dgl_url\n\n    class CitationGraphDataset(DGLBuiltinDataset):\n        _urls = {\n            'cora_v2' : 'dataset/cora_v2.zip',\n            'citeseer' : 'dataset/citeseer.zip',\n            'pubmed' : 'dataset/pubmed.zip',\n        }\n\n        def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):\n            assert name.lower() in ['cora', 'citeseer', 'pubmed']\n            if name.lower() == 'cora':\n                name = 'cora_v2'\n            url = _get_dgl_url(self._urls[name])\n            super(CitationGraphDataset, self).__init__(name,\n                                                       url=url,\n                                                       raw_dir=raw_dir,\n                                                       force_reload=force_reload,\n                                                       verbose=verbose)\n\n        def process(self):\n            # Skip some processing code\n            # === data processing skipped ===\n\n            # build graph\n            g = dgl.graph(graph)\n            # splitting masks\n            g.ndata['train_mask'] = train_mask\n            g.ndata['val_mask'] = val_mask\n            g.ndata['test_mask'] = test_mask\n            # node labels\n            g.ndata['label'] = torch.tensor(labels)\n            # node features\n            g.ndata['feat'] = torch.tensor(_preprocess_features(features),\n                                           dtype=F.data_type_dict['float32'])\n            self._num_tasks = onehot_labels.shape[1]\n            self._labels = labels\n            # reorder graph to obtain better locality.\n            self._g = dgl.reorder_graph(g)\n\n        def __getitem__(self, idx):\n            assert idx == 0, 
\"This dataset has only one graph\"\n            return self._g\n\n        def __len__(self):\n            return 1\n\n분류 데이터셋 프로세싱 코드의 중요한 부분(마스크 분할하기)을 강조하기 위해서 ``process()`` 함수의 코드 일부는 생략해서 간략하게 만들었다.\n\n일반적으로 노드 분류 테스크에서 하나의 그래프만 사용되기 때문에, ``__getitem__(idx)`` 와 ``__len__()`` 함수 구현이 바뀐 점을 알아두자. 마스크는 PyTorch와 TensorFlow에서는 ``bool tensors`` 이고 MXNet에서는 ``float tensors`` 이다.\n\n다음 예는  ``CitationGraphDataset`` 의 서브 클래스인 :class:`dgl.data.CiteseerGraphDataset` 를 사용하는 방법이다.\n\n.. code::\n\n    # load data\n    dataset = CiteseerGraphDataset(raw_dir='')\n    graph = dataset[0]\n\n    # get split masks\n    train_mask = graph.ndata['train_mask']\n    val_mask = graph.ndata['val_mask']\n    test_mask = graph.ndata['test_mask']\n\n    # get node features\n    feats = graph.ndata['feat']\n\n    # get labels\n    labels = graph.ndata['label']\n\n노드 분류 모델에 대한 전체 가이드는 :ref:`guide_ko-training-node-classification` 를 참고하자.\n\nDGL의 빌트인 데이터셋들은 노드 분류 데이터셋의 여러 예제들을 포함하고 있다.\n\n* :ref:`citationdata`\n\n* :ref:`corafulldata`\n\n* :ref:`amazoncobuydata`\n\n* :ref:`coauthordata`\n\n* :ref:`karateclubdata`\n\n* :ref:`ppidata`\n\n* :ref:`redditdata`\n\n* :ref:`sbmdata`\n\n* :ref:`sstdata`\n\n* :ref:`rdfdata`\n\n링크 예측 데이터셋 프로세싱\n~~~~~~~~~~~~~~~~~~~~\n\n링크 예측 데이테셋을 프로세싱하는 것은 주로 데이터셋에 하나의 그래프만 있기 때문에, 노드 분류의 경우와 비슷하다.\n\n예제로 `KnowledgeGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/knowledge_graph.html#KnowledgeGraphDataset>`__ 빌트인 데이터셋을 사용하는데, 링크 예측 데이터셋 프로세싱의 주요 부분을 강조하기 위해서 자세한 데이터 프로세싱 코드는 생략했다.\n\n.. 
code::\n\n    # Example for creating Link Prediction datasets\n    class KnowledgeGraphDataset(DGLBuiltinDataset):\n        def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):\n            self._name = name\n            self.reverse = reverse\n            url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)\n            super(KnowledgeGraphDataset, self).__init__(name,\n                                                        url=url,\n                                                        raw_dir=raw_dir,\n                                                        force_reload=force_reload,\n                                                        verbose=verbose)\n\n        def process(self):\n            # Skip some processing code\n            # === data processing skipped ===\n\n            # splitting mask\n            g.edata['train_mask'] = train_mask\n            g.edata['val_mask'] = val_mask\n            g.edata['test_mask'] = test_mask\n            # edge type\n            g.edata['etype'] = etype\n            # node type\n            g.ndata['ntype'] = ntype\n            self._g = g\n\n        def __getitem__(self, idx):\n            assert idx == 0, \"This dataset has only one graph\"\n            return self._g\n\n        def __len__(self):\n            return 1\n\n\n위 코드에서 볼 수 있듯이 분할 마스크들을 그래프의 ``edata`` 필드에 추가한다. 전체 구현은  `KnowledgeGraphDataset 소스 코드 <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/knowledge_graph.html#KnowledgeGraphDataset>`__ 를 참고하자.\n\n.. 
code::\n\n    from dgl.data import FB15k237Dataset\n\n    # load data\n    dataset = FB15k237Dataset()\n    graph = dataset[0]\n\n    # get training mask\n    train_mask = graph.edata['train_mask']\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()\n    src, dst = graph.edges(train_idx)\n    # get edge types in training set\n    rel = graph.edata['etype'][train_idx]\n\n링크 예측 모델에 대한 전체 가이드는 :ref:`guide_ko-training-link-prediction` 에 있다.\n\nDGL의 빌트인 데이터셋들은 링크 예측 데이터셋의 여러 예제들을 포함하고 있다.\n\n* :ref:`kgdata`\n\n* :ref:`bitcoinotcdata`\n"
  },
  {
    "path": "docs/source/guide_ko/data-savenload.rst",
    "content": ".. _guide_ko-data-pipeline-savenload:\n\n4.4 데이터 저장과 로딩\n------------------\n\n:ref:`(English Version) <guide-data-pipeline-savenload>`\n\nDGL에서는 프로세싱된 데이터를 로컬 디스크에 임시로 저장하기 위해 저장 및 로딩 함수를 구현할 것을 권장한다. 이는 대부분의 경우에 데이터 프로세싱 시간을 상당히 절약할 수 있게한다. DGL은 이를 간단하게 구현하기 위한 4가지 함수를 제공한다:\n\n- :func:`dgl.save_graphs` 와 :func:`dgl.load_graphs` : DGLGraph 객체와 레이블을 로컬 디스크로 저장/로딩함\n- :func:`dgl.data.utils.save_info` 와 :func:`dgl.data.utils.load_info` : 데이터셋에 대한 유용한 정보(python의 ``dict`` 객체)를 로컬 디스크로 저장/로딩함\n\n다음 예는 그래프들의 리스트와 데이터셋 정보를 저장하는 것을 보여준다.\n\n.. code:: \n\n    import os\n    from dgl import save_graphs, load_graphs\n    from dgl.data.utils import makedirs, save_info, load_info\n    \n    def save(self):\n        # save graphs and labels\n        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')\n        save_graphs(graph_path, self.graphs, {'labels': self.labels})\n        # save other information in python dict\n        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')\n        save_info(info_path, {'num_classes': self.num_classes})\n    \n    def load(self):\n        # load processed data from directory `self.save_path`\n        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')\n        self.graphs, label_dict = load_graphs(graph_path)\n        self.labels = label_dict['labels']\n        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')\n        self.num_classes = load_info(info_path)['num_classes']\n    \n    def has_cache(self):\n        # check whether there are processed data in `self.save_path`\n        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')\n        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')\n        return os.path.exists(graph_path) and os.path.exists(info_path)\n\n단, 프로세싱된 데이터를 저장하는 것이 적합하지 않은 경우도 있다. 
예를 들어, 빌트인 데이터셋 중 :class:`~dgl.data.GDELTDataset` 의 경우 프로세스된 데이터가 굉장히 크기 때문에 ``__getitem__(idx)`` 에서 각 데이터 예제들을 처리하는 것이 더 효율적이다.\n"
  },
  {
    "path": "docs/source/guide_ko/data.rst",
    "content": ".. _guide_ko-data-pipeline:\n\n4장: 그래프 데이터 파이프라인\n======================\n\n:ref:`(English Version) <guide-data-pipeline>`\n\nDGL은 :ref:`apidata` 에서 일반적으로 많이 사용되는 그래프 데이터셋을 구현하고 있다. 이것들은 :class:`dgl.data.DGLDataset` 클래스에서 정의하고 있는 표준 파이프라인을 따른다. DGL은 :class:`dgl.data.DGLDataset` 의 서브클래스로 그래프 데이터 프로세싱하는 것을 강하게 권장한다. 이는 파이프라인이 그래프 데이터를 로딩하고, 처리하고, 저장하는데 대한 간단하고 깔끔한 방법을 제공하기 때문이다.\n\n로드맵\n----\n\n이 장은 커스텀 DGL-Dataset를 만드는 방법을 소개한다. 이를 위해 다음 절들에서 파이프라인이 어떻게 동작하는지 설명하고, 각 파이프라인의 컴포넌트를 구현하는 방법을 보여준다.\n\n* :ref:`guide_ko-data-pipeline-dataset`\n* :ref:`guide_ko-data-pipeline-download`\n* :ref:`guide_ko-data-pipeline-process`\n* :ref:`guide_ko-data-pipeline-savenload`\n* :ref:`guide_ko-data-pipeline-loadogb`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    data-dataset\n    data-download\n    data-process\n    data-savenload\n    data-loadogb"
  },
  {
    "path": "docs/source/guide_ko/distributed-apis.rst",
    "content": ".. _guide_ko-distributed-apis:\n\n7.2 분산 APIs\n--------------------\n\n:ref:`(English Version) <guide-distributed-apis>`\n\n이 절은 학습 스크립트에 사용할 분산 API들을 다룬다. DGL은 초기화, 분산 샘플링, 그리고 워크로드 분할(split)을 위한 세가지 분산 데이터 구조와 다양한 API들을 제공한다. 분산 학습/추론에 사용되는 세가지 분산 자료 구조는 분산 그래프를 위한 :class:`~dgl.distributed.DistGraph` , 분산 텐서를 위한 :class:`~dgl.distributed.DistTensor` , 그리고 분산 learnable 임베딩을 위한 :class:`~dgl.distributed.DistEmbedding` 이다.\n\nDGL 분산 모듈 초기화\n~~~~~~~~~~~~~~~~\n\n:func:`~dgl.distributed.initialize` 은 분산 모듈을 초기화한다. 학습 스크립트가 학습 모드로 수행되면, 이 API는 DGL 서버들간의 연결을 만들고, 샘플러 프로세스들을 생성한다; 스크립트가 서버 모드로 실행되면, 이 API는 서버 코드를 실행하고 절대로 리턴되지 않는다. 이 API는 어떤 DGL 분산 API들 보다 먼저 호출되어야 한다. PyTorch와 함께 사용될 때, :func:`~dgl.distributed.initialize` 는 ``torch.distributed.init_process_group`` 전에 호출되어야 한다. 일반적으로 초기화 API들은 다음 순서로 실행된다.\n\n.. code:: python\n\n    dgl.distributed.initialize('ip_config.txt')\n    th.distributed.init_process_group(backend='gloo')\n\nDistributed 그래프\n~~~~~~~~~~~~~~~~~\n\n:class:`~dgl.distributed.DistGraph` 는 클러스터에서 그래프 구조와 노드/에지 피쳐들을 접근하기 위한 Python 클래스이다. 각 컴퓨터는 단 하나의 파티션을 담당한다. 이 클래스는 파티션 데이터(그 파티션의 그래프 구조, 노드 데이터와 에지 데이터)를 로드하고, 클러스터의 모든 트레이너들이 접근할 수 있도록 만들어 준다. :class:`~dgl.distributed.DistGraph` 는 데이터 접근을 위한 :class:`~dgl.DGLGraph` API들의 작은 서브셋을 지원한다.\n\n**Note**: :class:`~dgl.distributed.DistGraph` 는 현재 한 개의 노드 타입과 한 개의 에지 타입만을 지원한다.\n\n분산 모드 vs. 단독(standalone) 모드\n^^^^^^^^^^^^^^^^^^\n\n:class:`~dgl.distributed.DistGraph` 는 두가지 모드로 실행된다: 분산 모드와 단독 모드. 사용자가 학습 스크립트를 Python 명령행이나 Jupyter notebook에서 실행하면, 단독 모드로 수행된다. 즉, 모든 계산이 단일 프로세스에서 수행되고, 다른 어떤 프로세스들과의 통신이 없다. 따라서, 단독 모드에서는 입력 그래프가 한 개의 파티션이다. 이 모드는 주로 개발 및 테스트를 위해서 사용된다 (즉, Jupyter notebook에서 코드를 개발하고 수행할 때). 학습 스크립트가 launch 스크립트를 사용해서 실행되면 (launch 스크립트 섹션 참조), :class:`~dgl.distributed.DistGraph` 가 분산 모드로 동작한다. Launch 툴은 자동으로 (노드/에지 피쳐 접근 및 그래프 샘플링을 하는) 서버들을 구동하고, 클러스터의 각 컴퓨터에 파티션 데이터를 자동으로 로드한다. 
:class:`~dgl.distributed.DistGraph` 는 클러스터의 서버들과 네트워크를 통해서 연결한다.\n\nDistGraph 생성\n^^^^^^^^^^^^^\n\n분산 모드에서는, :class:`~dgl.distributed.DistGraph` 를 생성할 때 파티션에서 사용된 그래프 이름이 필요하다. 그래프 이름은 클러스터에서 로드될 그래프를 지정한다.\n\n.. code:: python\n\n    import dgl\n    g = dgl.distributed.DistGraph('graph_name')\n\n단독 모드로 수행될 때, 로컬 머신의 그래프 데이터를 로드한다. 따라서, 사용자는 입력 그래프에 대한 모든 정보를 담고 있는 파티션 설정 파일을 제공해야 한다.\n\n.. code:: python\n\n    import dgl\n    g = dgl.distributed.DistGraph('graph_name', part_config='data/graph_name.json')\n\n**Note**: DGL의 현재 구현은 `DistGraph` 객체를 한 개만 만들 수 있다. `DistGraph` 를 없애고 새로운 것을 다시 만드는 것은 정의되어 있지 않다.\n\n그래프 구조 접근\n^^^^^^^^^^^^\n\n:class:`~dgl.distributed.DistGraph` 는 그래프 구조 접근을 위한 적은 수의 API들을 갖고 있다. 현재 대부분 API들은 노드 및 에지 수와 같은 그래프 정보를 제공한다. DistGraph의 주요 사용 케이스는 미니-배치 학습을 지원하기 위한 샘플링 API를 수행하는 것이다. (분산 그래프 샘플링은 섹션 참조)\n\n.. code:: python\n\n    print(g.num_nodes())\n\n노드/에지 데이터 접근\n^^^^^^^^^^^^^^^^\n\n:class:`~dgl.DGLGraph` 처럼 :class:`~dgl.distributed.DistGraph` 는 노드와 에지의 데이터 접근을 위해서 ``ndata`` 와 ``edata`` 를 제공한다. 차이점은 :class:`~dgl.distributed.DistGraph` 의 ``ndata`` / ``edata`` 는 사용되는 프레임워크의 텐서 대신 :class:`~dgl.distributed.DistTensor` 를 리턴한다는 것이다. 사용자는 새로운 :class:`~dgl.distributed.DistTensor` 를 :class:`~dgl.distributed.DistGraph` 노드 데이터 또는 에지 데이터로서 할당할 수 있다.\n\n.. code:: python\n\n    g.ndata['train_mask']  # <dgl.distributed.dist_graph.DistTensor at 0x7fec820937b8>\n    g.ndata['train_mask'][0]  # tensor([1], dtype=torch.uint8)\n\n분산 텐서(Distributed Tensor)\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n앞에서 언급했듯이, DGL은 노드/에지 피쳐들을 샤드(shard)해서, 머신들의 클러스터에 이것들을 저장한다. DGL은 클러스터에서 파티션된 노드/에지 피쳐들을 접근하기 위해서 tensor-like 인터페이스를 갖는 분산 텐서를 제공한다. 분산 세팅에서 DGL은 덴스 노드/에지 피쳐들만 지원한다.\n\n:class:`~dgl.distributed.DistTensor` 는 파티션되어 여러 머신들에 저장되어 있는 덴스 텐서들을 관리한다. 지금은 분산 텐서는 그래프의 노드 또는 에지와 연결되어 있어야만 한다. 다르게 말하자면, `DistTensor` 의 행 개수는 그래프의 노드 개수 또는 에지의 개수와 같아야만 한다. 아래 코드는 분산 텐서를 생성하고 있다. `shape` 과 `dtype` 뿐만아니라, 유일한 텐서 이름을 지정할 수 있다. 
사용자가 영속적인 분산 텐서를 참고하고자 할 경우 이 이름은 유용하다 (즉, :class:`~dgl.distributed.DistTensor` 객체가 사라져도 클러스터에 존재하는 텐서).\n\n.. code:: python\n\n    tensor = dgl.distributed.DistTensor((g.num_nodes(), 10), th.float32, name='test')\n\n**Note**: :class:`~dgl.distributed.DistTensor` 생성은 동기화 수행이다. 모든 트레이너들은 생성을 실행해야하고, 모든 트레이너가 이를 호출한 경우에만 생성이 완료된다.\n\n사용자는 :class:`~dgl.distributed.DistTensor` 를 노드 데이터 또는 에지 데이터의 하나로서 :class:`~dgl.distributed.DistGraph`  객체에 추가할 수 있다.\n\n.. code:: python\n\n    g.ndata['feat'] = tensor\n\n**Note**: 노드 데이터 이름과 텐서 이름이 같을 필요는 없다. 전자는 :class:`~dgl.distributed.DistGraph` 로부터 노드 데이터를 구별하고(트레이너 프로세스에서), 후자는 DGL 서버들에서 분산 텐서를 구별하는데 사용된다. \n\n:class:`~dgl.distributed.DistTensor` 는 적은 수의 함수들을 제공한다. 이는 일반 텐서가 `shape` 또는 `dtype` 과 같은 메타데이터를 접근하는 것과 같은 API들이다. :class:`~dgl.distributed.DistTensor` 는 인덱스를 사용한 읽기와 쓰기를 지원하지만, `sum` 또는 `mean` 과 같은 연산 오퍼레이터는 지원하지 않는다.\n\n.. code:: python\n\n    data = g.ndata['feat'][[1, 2, 3]]\n    print(data)\n    g.ndata['feat'][[3, 4, 5]] = data\n\n**Note**: 현재 DGL은 한 머신이 여러 서버들을 수행할 때, 다중의 서버들이 동시에 쓰기를 동시에 수행하는 경우에 대한 보호를 지원하지 않는다. 이 경우 데이터 깨짐(data corruption)이 발생할 수 있다. 같은 행의 데이터에 동시 쓰기를 방지하는 방법 중에 하나로 한 머신에서 한 개의 서버 프로세스만 실행하는 것이다.\n\n분산 DistEmbedding\n~~~~~~~~~~~~~~~~~\n\nDGL은 노드 임베딩들을 필요로 하는 변환 모델(transductive models)을 지원하기 위해서 :class:`~dgl.distributed.DistEmbedding` 를 제공한다. 분산 임베딩을 생성하는 것은 분산 텐서를 생성하는 것과 비슷하다.\n\n.. code:: python\n\n    def initializer(shape, dtype):\n        arr = th.zeros(shape, dtype=dtype)\n        arr.uniform_(-1, 1)\n        return arr\n    emb = dgl.distributed.DistEmbedding(g.num_nodes(), 10, init_func=initializer)\n\n내부적으로는 분산 임배딩은 분산 텐서를 사용해서 만들어진다. 따라서, 분산 텐서와 비슷하게 동작한다. 예를 들어, 임베딩이 만들어지면, 그것들은 클러스터의 여러 머신들에 나눠져서(shard) 저장된다. 이는 이름을 통해서 고유하게 식별될 수 있다.\n\n**Note**: 초기화 함수가 서버 프로세스에서 호출된다. 따라서, :class:`~dgl.distributed.initialize` 전에 선언되야 한다.\n\n임배딩은 모델의 일부이기 때문에, 미니배치 학습을 위해서 이를 optimizer에 붙여줘야 한다. 
현재는, DGL은 sparse Adagrad optimizer, :class:`~dgl.distributed.SparseAdagrad` 를 지원한다 (DGL은 sparse 임베딩을 위한 더 많은 optimizer들을 추가할 예정이다). 사용자는 모델로 부터 모든 분산 임베딩을 수집하고, 이를 sparse optimizer에 전달해야 한다. 만약 모델이 노드 임베딩과 정상적인 dense 모델 파라메터들을 갖고, 사용자가 임베딩들에 sparse 업데이트를 수행하고 싶은 경우, optimizer 두 개를 만들어야 한다. 하나는 노드 임베딩을 위한 것이고, 다른 하나는 dense model 파라메터들을 위한 것이다. 다음 코드를 보자.\n\n.. code:: python\n\n    sparse_optimizer = dgl.distributed.SparseAdagrad([emb], lr=lr1)\n    optimizer = th.optim.Adam(model.parameters(), lr=lr2)\n    feats = emb(nids)\n    loss = model(feats)\n    loss.backward()\n    optimizer.step()\n    sparse_optimizer.step()\n\n**Note**: :class:`~dgl.distributed.DistEmbedding` 는 PyTorch nn 모듈이 아니다. 따라서, PyTorch nn 모듈의 파라메터들을 통해서 접근할 수 없다.\n\n분산 샘플링\n~~~~~~~~\n\nDGL은 미니-배치를 생성하기 위해 노드 및 에지 샘플링을 하는 두 수준의 API를 제공한다 (미니-배치 학습 섹션 참조). Low-level API는 노드들의 레이어가 어떻게 샘플링될지를 명시적으로 정의하는 코드를 직접 작성해야한다 (예를 들면, :func:`dgl.sampling.sample_neighbors` 를 사용해서). High-level API는 노드 분류 및 링크 예측(예, :class:`~dgl.dataloading.pytorch.NodeDataLoader` 와\n:class:`~dgl.dataloading.pytorch.EdgeDataLoader`) 에 사용되는 몇 가지 유명한 샘플링 알고리즘을 구현하고 있다.\n\n분산 샘플링 모듈도 같은 디자인을 따르고 있고, 두 level의 샘플링 API를 제공한다. Low-level 샘플링 API의 경우, :class:`~dgl.distributed.DistGraph` 에 대한 분산 이웃 샘플링을 위해 :func:`~dgl.distributed.sample_neighbors` 가 있다. 또한, DGL은 분산 샘플링을 위해 분산 데이터 로더, :class:`~dgl.distributed.DistDataLoader` 를 제공한다. 분산 DataLoader는 PyTorch DataLoader와 같은 인터페이스를 갖는데, 다른 점은 사용자가 데이터 로더를 생성할 때 worker 프로세스의 개수를 지정할 수 없다는 점이다. Worker 프로세스들은 :func:`dgl.distributed.initialize` 에서 만들어진다.\n\n**Note**: :class:`~dgl.distributed.DistGraph` 에 :func:`dgl.distributed.sample_neighbors` 를 실행할 때, 샘플러는 다중의 worker 프로세스를 갖는 PyTorch DataLoader에서 실행될 수 없다. 주요 이유는 PyTorch DataLoader는 매 epoch 마다 새로운 샘플링 worker 프로세스를 생성하는데, 이는 :class:`~dgl.distributed.DistGraph` 객체들을 여러번 생성하고 삭제하게하기 때문이다.\n\nLow-level API를 사용할 때, 샘플링 코드는 단일 프로세스 샘플링과 비슷하다. 
유일한 차이점은 사용자가 :func:`dgl.distributed.sample_neighbors` 와 :class:`~dgl.distributed.DistDataLoader` 를 사용한다는 것이다.\n\n.. code:: python\n\n    def sample_blocks(seeds):\n        seeds = th.LongTensor(np.asarray(seeds))\n        blocks = []\n        for fanout in [10, 25]:\n            frontier = dgl.distributed.sample_neighbors(g, seeds, fanout, replace=True)\n            block = dgl.to_block(frontier, seeds)\n            seeds = block.srcdata[dgl.NID]\n            blocks.insert(0, block)\n        return blocks\n\n    dataloader = dgl.distributed.DistDataLoader(dataset=train_nid,\n                                                batch_size=batch_size,\n                                                collate_fn=sample_blocks,\n                                                shuffle=True)\n    for batch in dataloader:\n        ...\n\n동일한 high-level 샘플링 API들(:class:`~dgl.dataloading.pytorch.NodeDataLoader` 와 :class:`~dgl.dataloading.pytorch.EdgeDataLoader` )이 :class:`~dgl.DGLGraph` 와 :class:`~dgl.distributed.DistGraph` 에 대해서 동작한다. :class:`~dgl.dataloading.pytorch.NodeDataLoader` 과 :class:`~dgl.dataloading.pytorch.EdgeDataLoader` 를 사용할 때, 분산 샘플링 코드는 싱글-프로세스 샘플링 코드와 정확하게 같다.\n\n.. code:: python\n\n    sampler = dgl.sampling.MultiLayerNeighborSampler([10, 25])\n    dataloader = dgl.sampling.DistNodeDataLoader(g, train_nid, sampler,\n                                                 batch_size=batch_size, shuffle=True)\n    for batch in dataloader:\n        ...\n\n\n워크로드 나누기(Split workloads)\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n모델을 학습하기 위해서, 사용자는 우선 데이터를 학습, 검증 그리고 테스트 셋으로 나눠야한다. 분산 학습에서는, 이 단계가 보통은 그래프를 파티션하기 위해 :func:`dgl.distributed.partition_graph` 를 호출하기 전에 일어난다. 우리는 데이터 split를 노드 데이터 또는 에지 데이터로서 boolean array들에 저장하는 것을 권장한다. 노드 분류 테스크의 경우에 이 boolean array들의 길이는 그래프의 노드의 개수와 같고, 각 원소들은 노드가 학습/검증/테스트 셋에 속하는지를 지정한다. 링크 예측 테스크에도 비슷한 boolean array들을 사용해야 한다. 
:func:`dgl.distributed.partition_graph` 는 그래프 파티션 결과에 따라서 이 boolean array들을 나누고, 이를 그래프 파티션과 함께 저장한다.\n\n분산 학습을 수행하는 동안에 사용자는 학습 노드들/에지들을 각 트레이너에게 할당해야 한다. 비슷하게, 검증 및 테스트 셋도 같은 방법으로 나눠야만 한다. DGL은 분산학습이 수행될 때 학습, 검증, 테스트 셋을 나누는 :func:`~dgl.distributed.node_split` 와 :func:`~dgl.distributed.edge_split` 를 제공한다. 이 두 함수는 그래프 파티셔닝 전에 생성된 boolean array들을 입력으로 받고, 그것들을 나누고 나눠진 부분을 로컬 트레이너에게 리턴한다. 기본 설정으로는 모든 부분들이 같은 개수의 노드와 에지를 갖도록 해준다. 이는 각 트레이너가 같은 크기의 미니-배치들을 갖는다고 가정하는 synchronous SGD에서 중요하다.\n\n아래 예제는 학습 셋을 나누고, 노드들의 서브셋을 로컬 프로세스에 리턴한다.\n\n.. code:: python\n\n    train_nids = dgl.distributed.node_split(g.ndata['train_mask'])\n\n"
  },
  {
    "path": "docs/source/guide_ko/distributed-hetero.rst",
    "content": ".. _guide_ko-distributed-hetero:\n\n7.3 분산 heterogeneous 그래프 학습하기\n---------------------------------\n\n:ref:`(English Version) <guide-distributed-hetero>`\n\nDGL v0.6.0은 heterogeneous 그래프들을 위한 분산 학습을 실험적으로 지원한다. DGL에서 heterogeneous 그래프의 노드와 에지는 그 노드 타입 및 에지 타입에서 고유한 ID를 갖는다. DGL은 노드/에지 타입과 타입별 ID의 tuple을 사용해서 노드 및 에지를 지정한다. 분산 학습에서는 노드/에지 타입과 타입별 ID의 tuple과 더불어서 노드 또는 에지는 homogeneous ID를 통해서 지정될 수 있다. Homogeneous ID는 노드 타입이나 에지 타입과 관련없이 고유하다. DGL은 같은 타입의 모든 노드들이 연속된 homogeneous ID값들을 갖도록 노드와 에지를 정렬한다.\n\n아래 그림은 homegeneous ID 할당을 보여주는 heterogeneous 그래프의 adjacency matrix이다. 여기서 그래프틑 두가지 노드 타입( `T0` 와 `T1` )을, 네가지 에지 타입(`R0` , `R1` , `R2` , `R3` )를 갖는다. 그래프는 총 400개의 노드를 갖고, 각 타입은 200개 노드를 갖는다. `T0` 의 노드들은 [0,200)의 ID를 갖고, `T1` 의 노드들은 [200, 400)의 ID 값을 갖는다. 여기서 만약 tuple을 사용해서 노드를 구분한다면, `T0` 의 노드들은 (T0, type-wise ID)로 지정될 수 있다. 여기서 type-wise ID는 [0,200)에 속한다; `T1` 의 노드들은 (T1, type-wise ID)으로 지정되고, type-wise ID는 [0, 200)에 속한다.\n\n.. figure:: https://data.dgl.ai/tutorial/hetero/heterograph_ids.png\n   :alt: Imgur\n\n7.3.1 분산 그래프 데이터 접근하기\n^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n분산 학습을 위해 :class:`~dgl.distributed.DistGraph` 은 :class:`~dgl.DGLGraph` 에서 heterogeneous 그래프 API를 지원한다. 아래 코드는 `T0` 의 노드 데이터를 type-wise 노드 ID를 사용해서 얻는 것을 보여준다. :class:`~dgl.DGLGraph` 의 데이터를 접근할 때, 사용자는 type-wise ID와 연관된 노드 타입 또는 에지 타입을 사용해야 한다.\n\n.. code:: python\n\n    import dgl\n    g = dgl.distributed.DistGraph('graph_name', part_config='data/graph_name.json')\n    feat = g.nodes['T0'].data['feat'][type_wise_ids]\n\n사용자는 특정 노드 타입 또는 에지 타입에 대한 분산 텐서 및 분산 임베딩을 생성할 수 있다. 분산 텐서들과 분산 임베딩들은 여러 머신에 나눠져서 저장된다. 만들 때는 :class:`~dgl.distributed.PartitionPolicy` 로 파티션을 어떻게 할지를 명시해야 한다. 기본 설정으로 DGL은 첫 차원 값의 크기를 기반으로 적절한 파티션 정책을 선택한다. 하지만, 다중 노드 타입 또는 에지 타입이 같은 수의 노드 또는 에지를 갖는 다면, DGL은 파티션 정책을 자동으로 결정할 수 없고, 사용자는 직접 파티션 정책을 지정해야 한다. 아래 코드는 노드 타입 `T0` 의 분산 텐서를 `T0` 를 위한 파티션 정책을 사용해서 생성하고, 이를 `T0` 의 노드 데이터로 저장한다.\n\n.. 
code:: python\n\n    g.nodes['T0'].data['feat1'] = dgl.distributed.DistTensor((g.num_nodes('T0'), 1), th.float32, 'feat1',\n                                                             part_policy=g.get_node_partition_policy('T0'))\n\n분산 텐서 및 분산 임베딩을 만들기 위한 파티션 정책은 heterogeneous 그래프가 그래프 서버에 로드될 때 초기화된다. 사용자는 새로운 파티션 정책을 실행 중에 생성할 수 없다. 따라서, 사용자는 노드 타입 이나 에지 타입에 대한 분산 텐서 또는 분산 임베딩 만을 만들 수 있다.\n\n7.3.2 분산 샘플링\n^^^^^^^^^^^^^^\n\nDGL v0.6은 분산 샘플링에서 homogeneous ID를 사용한다. **Note**: 이는 앞으로 릴리즈에서 바뀔 수도 있다. DGL은 homogeneous ID와 type-wise ID 간에 노드 ID와 에지 ID를 변환하는 네 개의 API를 제공한다.\n\n* :func:`~dgl.distributed.GraphPartitionBook.map_to_per_ntype` : homogeneous 노드 ID를 type-wise ID와 노드 타입 ID로 변환한다.\n* :func:`~dgl.distributed.GraphPartitionBook.map_to_per_etype` : homogeneous 에지 ID를 type-wise ID와 에지 타입 ID로 변환한다.\n* :func:`~dgl.distributed.GraphPartitionBook.map_to_homo_nid` : type-wise ID와 노드 타입을 homogeneous 노드 ID로 변환한다.\n* :func:`~dgl.distributed.GraphPartitionBook.map_to_homo_eid` : type-wise ID와 에지 타입을 homogeneous 에지 ID로 변환한다.\n\n다음 예제는 `paper` 라는 노드 타입을 갖는 heterogeneous 그래프로부터 :func:`~dgl.distributed.sample_neighbors` 를 사용해서 서브 그래프를 샘플링한다. 이는 우선 type-wise 노드 ID들을 homogeneous 노드 ID들로 변환한다. 시드 노드들로 서브 그래프를 샘플링 한 다음, homogeneous 노드 ID들과 에지 ID들을 type-wise ID들로 바꾸고, 타입 ID를 노드 데이터와 에지 데이터에 저장한다.\n\n.. 
code:: python\n\n        gpb = g.get_partition_book()\n        # We need to map the type-wise node IDs to homogeneous IDs.\n        cur = gpb.map_to_homo_nid(seeds, 'paper')\n        # For a heterogeneous input graph, the returned frontier is stored in\n        # the homogeneous graph format.\n        frontier = dgl.distributed.sample_neighbors(g, cur, fanout, replace=False)\n        block = dgl.to_block(frontier, cur)\n        cur = block.srcdata[dgl.NID]\n\n        block.edata[dgl.EID] = frontier.edata[dgl.EID]\n        # Map the homogeneous edge Ids to their edge type.\n        block.edata[dgl.ETYPE], block.edata[dgl.EID] = gpb.map_to_per_etype(block.edata[dgl.EID])\n        # Map the homogeneous node Ids to their node types and per-type Ids.\n        block.srcdata[dgl.NTYPE], block.srcdata[dgl.NID] = gpb.map_to_per_ntype(block.srcdata[dgl.NID])\n        block.dstdata[dgl.NTYPE], block.dstdata[dgl.NID] = gpb.map_to_per_ntype(block.dstdata[dgl.NID])\n\n노드/에지 타입 ID를 위해서, 사용자는 노드/에지 타입을 검색할 수 있다. 예를 들어, `g.ntypes[node_type_id]` . 노드/에지 타입들과 type-wise ID들을 사용해서, 사용자는 미니배치 계산을 위해서 `DistGraph` 로부터 노드/에지 데이터를 검색할 수 있다.\n"
  },
  {
    "path": "docs/source/guide_ko/distributed-preprocessing.rst",
    "content": ".. _guide_ko-distributed-preprocessing:\n\n7.1 분산 학습을 위한 전처리\n---------------------\n\n:ref:`(English Version) <guide-distributed-preprocessing>`\n\nDGL의 분산 학습을 사용하기 위해서는 그래프 데이터에 대한 전처리가 필요하다. 이 전처리는 두 단계로 구성된다: 1) 그래프를 서브 그래프들로 파티션하기, 2) 노드/에지들에 새로운 ID를 부여하기. 상대적으로 작은 그래프들의 경우, DGL이 제공하는 파티셔닝 API :func:`dgl.distributed.partition_graph` 를 사용해서 위 두 단계를 수행할 수 있다. 이 API는 한 컴퓨터에서 수행된다. 따라서, 그래프가 큰 경우, 이 API를 사용하고 싶다면 큰 컴퓨터를 사용해야 한다. 이 API과 더불어, 여기서는 큰 그래프를 컴퓨터들의 클러스터에서 파티션을 하는 솔루션을 소개한다. (7.1.1 절을 보라)\n\n:func:`dgl.distributed.partition_graph` 는 랜덤 파티션과 `Metis <http://glaros.dtc.umn.edu/gkhome/views/metis>`__ 기반의 파티셔닝을 모두 지원한다. Metis 파티셔닝의 장점은 최소의 에지 컷(edge cut)을 갖는 파티션들을 만들 수 있다는 것이다. 이는 분산 학습 및 추론에서 네트워크 통신을 줄여준다. DGL은 최신 버전의 Metis은 실제(real world)에서 거듭 제곱 법칙의 분포를 갖는 그래프에 최적화되어 있다. 파타셔닝 후, API는 학습시 쉽게 로딩될 수 있는 형태로 파티션된 결과를 만든다.\n\n기본 설정으로 파티션 API는 분산 학습/추론이 실행될 때 노드/에지를 구별하는 것을 돕기 위해서 입력 그래프의 노드와 에지에 새로운 ID를 부여한다. ID를 할당한 후, 파티션 API은 모든 노드 데이터와 에지 데이터를 섞는다. 파티션된 서브 그래프를 생선한 후, 각 서브 그래프는 ``DGLGraph`` 객체로 저장된다. 섞기전의 원본 노드/에지 ID들은 서브 그래프들의 노드/에지 데이터에 `orig_id` 필드에 저장된다. 서브 그래프의 노드 데이터 `dgl.NID` 와 에지 데이터 `dgl.EID` 는 노드/에지들이 reshuffle 후의 전체 그래프의 새로운 노드/에지 ID를 저장한다. 학습이 실행되는 동안, 사용자는 새로운 노드/에지 ID만을 사용한다.\n\n파티션된 결과는 출력 디렉토리의 여러 파일로 저장된다. 이는 한개의 JSON 파일을 포함하는데, 파일 이름은 xxx.json 형태이고, xxx는 파티션 API에 사용된 그래프 이름이다. JSON 파일은 모든 파티션 설정들을 갖는다. 먄약 파티션 API가 새로운 ID를 노드와 에지에 할당하지 않은 경우에는, 추가적으로 두 개의 Numpy 파일; `node_map.npy` 와 `edge_map.npy` 를 생성하는데, 이는 노드/에지 ID와 파티션 ID의 매핑을 저장한다. 만약 그래프에 수십억 개의 노드와 에지가 있다면, 두 파일의 Numpy array는 커질 것인다. 그 이유는 그래프의 각 노드 및 에지에 대해서 하나의 엔트리를 갖기 때문이다. 각 파티션에 대한 폴더는 DGL 포멧으로 파티션 데이터를 저장하는 세 개의 파일이 있다. `graph.dgl` 은 파티션의 그래프 구조와 노드 및 에지에 대한 메타 데이터를 저장하고 있고, `node_feats.dgl` 과 `edge_feats.dlg` 은 파티션에 속하는 노드와 에지의 모든 피쳐들을 저장하고 있다.\n\n.. 
code-block:: none\n\n    data_root_dir/\n        |-- xxx.json                  # partition configuration file in JSON\n        |-- node_map.npy              # partition id of each node stored in a numpy array (optional)\n        |-- edge_map.npy              # partition id of each edge stored in a numpy array (optional)\n        |-- part0/                    # data for partition 0\n            |-- node_feats.dgl        # node features stored in binary format\n            |-- edge_feats.dgl        # edge features stored in binary format\n            |-- graph.dgl             # graph structure of this partition stored in binary format\n        |-- part1/\n            |-- node_feats.dgl\n            |-- edge_feats.dgl\n            |-- graph.dgl\n\n로드 밸런싱\n~~~~~~~~\n\n그래프를 파티셔닝할 때, Metis의 기본 설정은 각 파티션의 노드 수에 대해서 균형을 맞춘다. 그 결과 주어진 테스크에 따라서 최적이지 않은 구성(suboptimal configuration)이 될 수 있다. 예를 들어, semi-supervised 노드 분류의 경우, 트레이너는 로컬 파티션의 레이블이 있는 노들의 서브셋에 대해서 계산을 수행한다. 그래프의 노드들(레이블이 있는 것과 없는 모든 노드)에 균형을 맞추는 파티셔닝은 계산적인 로드(computational node)가 불균형하게 될 수 있다. 각 파티션에 균형잡힌 워크로드를 얻기 위해서 파티션 API는 각 노드 타입에 대한 노드 수를 고려해서 파티션들에 대한 균형을 만드는 것을 지원한다. 이는 :func:`dgl.distributed.partition_graph` 에서 ``balance_ntypes`` 를 설정하는 것으로 가능하다. 사용자들은 이 기능을 활용해서, 학습 셋, 검증 셋, 그리고 테스트 셋에 다른 노드 타입들이 포함된 것을 고려하게 할 수 있다.\n\n아래 코드는 학습 셋 내에서 그리고 학습 셋 외에 두 가지 노드 타입이 있다는 것을 고려한 코드 예제이다.\n\n.. code:: python\n\n    dgl.distributed.partition_graph(g, 'graph_name', 4, '/tmp/test', balance_ntypes=g.ndata['train_mask'])\n\n노드 타입 균형을 맞추는 것에 더해서, :func:`dgl.distributed.partition_graph` 는 ``balance_edges`` 설정을 통해서 다른 노드 타입들의 노드들의 in-degree들 사이의 균형을 잡는 것을 지원한다. 이는 다른 타입의 노드들에 부속되는 에지들의 개수에 대한 균형을 만든다.\n\n**Note**: :func:`dgl.distributed.partition_graph` 에 전달되는 그래프 이름은 중요한 인자이다. 그 그래프 이름은 :class:`dgl.distributed.DistGraph` 이 분산 그래프를 지정하는데 사용된다. 그래프 이름은 알파벳 문자들과 밑줄 기호만으로 구성되어야 한다.\n\nID 매핑\n~~~~~~\n\n:func:`dgl.distributed.partition_graph` 는 파티셔닝을 하는 과정에서 노드 ID와 에지 ID를 섞고, 노드 데이터와 에지 데이터도 그에 따라서 섞어준다. 
학습이 끝나면, 다운스트림 과제를 위해서 계산된 노드 임베딩들을 저장할 필요가 있다. 따라서, 저장된 노드 임베딩을 원본 ID에 따라서 다시 섞어야한다.\n\n`return_mapping=True` 인 경우, :func:`dgl.distributed.partition_graph` 는 섞인 노드/에지 ID와 그것들의 원본 ID 사이의 매핑을 리턴한다. Homogeneous 그래프의 경우, 두 벡터를 리턴한다. 첫번째 벡터는 모든 섞인 노드 ID와 그것의 원본 ID 메핑을, 두번째 벡터는 모든 섞인 에지 ID와 그것의 원본 ID 매핑이다. Heterogeneous 그래프의 경우에는 벡터들의 dictionary 두 개가 리턴된다. 첫번째 dictionary는 각 노드 타입에 대한 매핑을, 두번째 dictionary는 각 에지 타입에 대한 매핑이다.\n\n.. code:: python\n\n    node_map, edge_map = dgl.distributed.partition_graph(g, 'graph_name', 4, '/tmp/test',\n                                                         balance_ntypes=g.ndata['train_mask'],\n                                                         return_mapping=True)\n    # Let's assume that node_emb is saved from the distributed training.\n    orig_node_emb = th.zeros(node_emb.shape, dtype=node_emb.dtype)\n    orig_node_emb[node_map] = node_emb\n\n7.1.1 분산 파티셔닝\n^^^^^^^^^^^^^^^^\n\n큰 그래프를 위해서 DGL은 `ParMetis <http://glaros.dtc.umn.edu/gkhome/metis/parmetis/overview>`__ 을 사용해서 컴퓨터들의 클러스터에서 그래프를 파티셔닝한다. 이 솔루션은 사용자가 ParMETIS에 맞도록 데이터를 준비하고, ParMETIS에 의해 만들어질 파티션들을 위한 :class:`dgl.DGLGraph` 를 만들기 위해서 DGL 스크립트 `tools/convert_partition.py` 를 사용해야 한다.\n\n**Note**: `convert_partition.py` 는 `pyarrow` 패키지를 사용해서 csv 파일을 로드안다. `pyarrow` 설치하자.\n\nParMETIS 설치\n~~~~~~~~~~~~\n\nParMETIS는 METIS와 GKLib을 필요로 한다. GKLib 컴파일과 설치는 `here <https://github.com/KarypisLab/GKlib>`__ 에 있는 설명을 참고하자. METIS 컴파일과 설치는 아래 설명을 따라 GIT에서 METIRS를 클론하고 int64 지원을 활성화해서 컴파일한다.\n\n.. code-block:: none\n\n    git clone https://github.com/KarypisLab/METIS.git\n    make config shared=1 cc=gcc prefix=~/local i64=1\n    make install\n\n여기서부터는 PartMETIS를 직접 컴파일하고 설치하는 것이 필요하다. 아래 명령을 사용해서 ParMETIS의 DGL 브랜치를 클론한다.\n\n.. code-block:: none\n\n    git clone --branch dgl https://github.com/KarypisLab/ParMETIS.git\n\n그리고, ParMETIS를 컴파일하고 설치한다.\n\n.. 
code-block:: none\n\n    make config cc=mpicc prefix=~/local\n    make install\n\nParMETIS를 실행하기 전에, 두 환경 변수들, `PATH`와 `LD_LIBRARY_PATH`을 설정해야 한다: \n\n.. code-block:: none\n\n    export PATH=$PATH:$HOME/local/bin\n    export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/local/lib/\n\nParMETIS를 위한 입력 포멧\n~~~~~~~~~~~~~~~~~~~~~\n\nParMETIS의 입력 그래프는 다음 이름들을 사용해서 세 개의 파일들에 저장된다: `xxx_nodes.txt` , `xxx_edges.txt` 와 `xxx_stats.txt`. 여기서 `xxx` 는 그래프 이름이다.\n\n`xxx_nodes.txt` 의 각 행은 다음 형식으로 노드에 대한 정보를 담고 있다.\n\n.. code-block:: none\n\n    <node_type> <weight1> ... <orig_type_node_id> <attributes>\n\n모든 필드들은 공백 문자로 구분된다.\n\n* `<node_type>` 은 정수 값이다. Homogeneous 그래프에서는 항상 0이고, heterogenous 그래프에서는 그 값이 각 노드의 타입을 의미한다.\n* `<weight1>`, `<weight2>`, 등은 정수 값들인데, ParMETIS가 그래프 파티션들의 균형을 맞출 때 노드 가중치로 사용하는 값들이다. 사용자가 노드 가중치를 명시하지 않는 경우, ParMETIS는 각 파티션의 노드 수에 대한 균형을 고려해서 파티션을 나눈다 (좋은 학습 속도를 얻기 위해서는 그래프 파티션들의 균헝을 맞추는 것이 중요하다). 하지만, 이 기본 전략은 많은 use case들에 충분하지 않을 수 있다. 예를 들어, heterogeneous 그래프의 경우, 우리는 모든 파티션들이 각 노드 타입별로 비슷한 개수의 노드들을 갖도록 그래프에 대한 파티션을 나누고 싶다. 아래 토이 예제는 노드 가중치를 사용해서 다른 테입들의 노드 개수의 균형을 맞추것을 어떻게 하는지 보여준다.\n* `<orig_type_node_id>` 은 노드 타입에서의 노드 ID를 표현하는 정수 값이다. DGL에서 각 타입의 노드들은 0부터 시작하는 ID가 부여된다. Homogeneous 그래프에서 이 필드는 노드 ID의 값도 동일하다.\n* `<attributes>` 는 선택적인 필드들이다. 이는 임의의 값을 저장하는데 사용될 수 있으며, ParMETIS는 이 필드들을 사용하지 않는다. 잠재적으로는 homogenous 그래프들의 경우 노드 피쳐들과 에지 피쳐들을 이 필드에 저장할 수 있다.\n* 행(row) ID는 그래프의 *homogeneous* ID를 의미한다 (모든 노드에 고유한 ID가 할당된다). 같은 타입의 모든 노드들에 ID는 연속된 값으로 부여된다. 즉, 같은 타입의 노드들은 `xxx_notes.txt` 파일에 함께 저장되어야 한다.\n\n다음은 두 노드 타입을 갖는 heterogenous 그래프의 노트 파일 예이다. 노드 타입 0은 세 개의 노드를 갖고 있고, 노드 타입 1은 네 개의 노드들을 갖는다. 두 노드 가중치를 사용해서 ParMETIS느 노드 타입 0에 속한 노드 개수와 노드 타입 1에 속한 노드 개수가 대략 같도록 파티션 나눈다.\n\n.. code-block:: none\n\n    0 1 0 0\n    0 1 0 1\n    0 1 0 2\n    1 0 1 0\n    1 0 1 1\n    1 0 1 2\n    1 0 1 3\n\n비슷하게, `xxx_edges.txt` 의 각 행은 아래 형식으로 에지에 대한 정보를 저장한다.\n\n.. 
code-block:: none\n\n    <src_id> <dst_id> <type_edge_id> <edge_type> <attributes>\n\n모든 필드들은 공백 문자로 구분된다.\n\n* `<src_id>` 는 소스 노드의 *homogeneous* ID이다.\n* `<dst_id>` 는 목적지 노드의 *homogeneous* ID이다.\n* `<type_edge_id>` 는 에지 타입에 대한 에지 ID이다.\n* `<attributes>` 는 선택적인 필드들이다. 임의의 값을 저장하는데 사용할 수 있는데, ParMETIS는 이 필드를 사용하지 않는다.\n\n**Note**: 에지 파일에 중복된 에지나 셀프-룹을 갖는 에지가 없어야 한다.\n\n`xxx_stats.txt` 는 그래프에 대한 기본적인 통계들을 저장한다. 이 파일은 공백으로 구분되는 세 필드들로 구성된 단 한 줄만 갖는다.\n\n.. code-block:: none\n\n    <num_nodes> <num_edges> <num_node_weights>\n\n* `num_nodes` 는 노드 타입을 상관하지 않고 전체 노드 수를 저장한다.\n* `num_edges` 는 에지 타입을 상관하지 않고 전체 에지 수를 저장한다.\n* `num_node_weights` 는 노드 파일의 노드 가중치 수를 저장한다.\n\nParMETIS 실행하기 및 결과 포멧들\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nParMETIS는 `pm_dglpart` 명령이 실행된 머신에서 세 파일들에 저장된 그래프를 로드하고, 클러스터의 모든 머신에 데이터를 분산하고, ParMETIS를 실행해서 그래프의 파티션을 나누는 명령 `pm_dglpart` 을 포함하고 있다. 이 명령의 수행이 완료되면, 각 파타션에 대해서 세 개의 파일이 생성된다: `p<part_id>-xxx_nodes.txt`, `p<part_id>-xxx_edges.txt`, `p<part_id>-xxx_stats.txt`\n\n**Note**: ParMETIS는 파티셔닝을 수행하면서 노드들에 ID를 재할당한다. ID 재할당이 끝나면, 한 파티션의 노드들은 연속된 ID값을 갖는다; 더 나아가, 같은 타입의 노드들은 연속된 ID들을 부여 받는다.\n\n`p<part_id>-xxx_nodes.txt` 는 파티션의 노드 데이터를 저장한다. 각 행은 한 노드에 대한 다음 정보들을 담고 있다.\n\n.. code-block:: none\n\n    <node_id> <node_type> <weight1> ... <orig_type_node_id> <attributes>\n\n* `<node_id>` 는 ID 재할당 후의 *homogeneous* 노드 ID이다.\n* `<node_type>` 는 노드 타입이다.\n* `<weight1>` 는 ParMETIS가 사용하는 노드 가중치이다.\n* `<orig_type_node_id>` 는 입력 heterogeneous 그래프의 특정 노드 티입에 대한 원본 노드 ID이다.\n* `<attributes>` 는 선택적인 필드들로 입력 노드 파일에서 임의의 값을 갖는다.\n\n`p<part_id>-xxx_edges.txt` 는 파티션의 에지 데이터를 저장한다. 각 행은 한 에지에 대한 다음 정보를 담고 있다.\n\n.. 
code-block:: none\n\n    <src_id> <dst_id> <orig_src_id> <orig_dst_id> <orig_type_edge_id> <edge_type> <attributes>\n\n* `<src_id>` 는 ID 재할당 후의 소스 노드의 *homogeneous* ID이다.\n* `<dst_id>` 는 ID 재할당 후의 목적지 노드의 *homogeneous* ID이다.\n* `<orig_src_id>` 는 입력 그래프의 소스 노드에 대한 *homogeneous* ID이다.\n* `<orig_dst_id>` 는 입력 그래프의 목적지 노드에 대한 *homogeneous* ID이다.\n* `<orig_type_edge_id>` 는 입력 그래프의 특정 에지 타입에 대한 에지 ID이다.\n* `<edge_type>` 은 에지 타입이다.\n* `<attributes>` 는 선택적인 필드들로 입력 에지 파일에서 임의의 에지 속성 값을 갖는다.\n\n`pm_dglpart` 이 실행된 때, 세 입력 파일들(`xxx_nodes.txt`, `xxx_edges.txt`, `xxx_stats.txt`)은 `pm_dglpart` 명령이 실행된 디렉토리와 같은 곳에 있어야 한다. 다음 명령은 네 개의 ParMETIS 프로세스를 실행해서, `xxx` 라는 이름의 그래프를 8개의 파티션으로 나눈다 (각 프로세스는 2개의 파티션을 담당한다).\n\n.. code-block:: none\n\n    mpirun -np 4 pm_dglpart xxx 2\n\nParMETIS 결과들을 DGLGraph로 변환하기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nDGL은 `convert_partition.py` 라는 스크립트를 제공한다. 이는 `tool` 디렉토리에 있는데, 파티션 파일들에 있는 데이터를 :class:`dgl.DGLGraph` 객체로 변환하고 파일들에 저장하는 역할을 한다. **Note** `convert_partition.py` 는 단일 머신에서 실행된다. 향후, 우리는 이를 확장해서 여러 머신들에 걸쳐서 데이터를 병렬로 변환하도록 만들 것이다.  **Note**: csv 파일로 저장된 데이터를 로딩하기 위해서 `pyarrow` 패키지를 설치하자.\n\n`convert_partition.py` 는 다음 인자들을 받는다:\n\n* `--input-dir INPUT_DIR` 는 ParMETIS가 생성한 파티션 파일들이 있는 디렉토리를 지정한다.\n* `--graph-name GRAPH_NAME` 는 그래프 이름을 지정한다.\n* `--schema SCHEMA` 는 입력 heterogeneous 그래프의 스키마를 명시하는 파일이다. 스키마 파일은 JSON 파일로서, 노드 타입들과 에지 타입들을 나열하고, 또한 각 노드 타입 및 에지 타입에 대한 homogeneous ID의 범위를 포함한다.\n* `--num-parts NUM_PARTS` 는 파티션의 개수를 명시한다.\n* `--num-node-weights NUM_NODE_WEIGHTS` 는 ParMETIS가 파티션들의 균형을 위해서 사용한 노드 가중치의 개수를 지정한다.\n* `[--workspace WORKSPACE]` 는 선택적인 인자로, 중간 결과들을 저장할 workspace 디렉토리를 지정한다.\n* `[--node-attr-dtype NODE_ATTR_DTYPE]` 는 선택적인 인자로, 노드 파일들의 나머지 필드인 `<attributes>` 에 저장된 노드 속성들의 데이터 타입을 명시한다.\n* `[--edge-attr-dtype EDGE_ATTR_DTYPE]` 는 선택적인 인자로, 에지 파일들의 나머지 필드인 `<attributes>` 에 저장된 에지 속성들의 데이터 타입을 명시한다.\n* `--output OUTPUT` 는 파티션 결과들이 저장될 출력 디렉토리를 지정한다.\n\n`convert_partition.py` 의 결과 파일들은 다음과 같다:\n\n.. 
code-block:: none\n\n    data_root_dir/\n        |-- xxx.json                  # partition configuration file in JSON\n        |-- part0/                    # data for partition 0\n            |-- node_feats.dgl        # node features stored in binary format (optional)\n            |-- edge_feats.dgl        # edge features stored in binary format (optional)\n            |-- graph.dgl             # graph structure of this partition stored in binary format\n        |-- part1/\n            |-- node_feats.dgl\n            |-- edge_feats.dgl\n            |-- graph.dgl\n\n**Note**: 노드 속성 또는 에지 속성의 데이터 타입이 명시된다면, `convert_partition.py` 는 모든 타입의 모든 노드들 및 에지들이 꼭 이 속성들을 갖는다고 가정한다. 따라서, 다른 타입의 노드들이나 에지들이 서로 다른 개수의 속성을 갖는다면, 사용자는 이를 직접 만들어야 한다.\n\n다음은 `convert_partition.py` 를 위한 OGBN-MAG의 스키마 예제이다. 이는 두 필드를 갖는다: `nid` 와 `eid`. `nid` 안에는, 모든 노드 타입들이 나열되어 있고, 각 노드 타입에 대한 homogeneous ID 범위도 포함되어 있다; `eid` 안에는, 모든 에지 타입들이 나열되어 있고, 각 에지 타입에 대한 homogeneous ID 범위도 포함되어 있다.\n\n.. code-block:: none\n\n    {\n    \"nid\": {\n        \"author\": [\n            0,\n            1134649\n        ],\n        \"field_of_study\": [\n            1134649,\n            1194614\n        ],\n        \"institution\": [\n            1194614,\n            1203354\n        ],\n        \"paper\": [\n            1203354,\n            1939743\n        ]\n    },\n    \"eid\": {\n        \"affiliated_with\": [\n            0,\n            1043998\n        ],\n        \"writes\": [\n            1043998,\n            8189658\n        ],\n        \"rev-has_topic\": [\n            8189658,\n            15694736\n        ],\n        \"rev-affiliated_with\": [\n            15694736,\n            16738734\n        ],\n        \"cites\": [\n            16738734,\n            22155005\n        ],\n        \"has_topic\": [\n            22155005,\n            29660083\n        ],\n        \"rev-cites\": [\n            29660083,\n            35076354\n        ],\n        \"rev-writes\": [\n            35076354,\n       
     42222014\n        ]\n    }\n    }\n\n아래 코드는 스키마 파일을 만드는 예제이다.\n\n.. code-block:: none\n\n    nid_ranges = {}\n    eid_ranges = {}\n    for ntype in hg.ntypes:\n        ntype_id = hg.get_ntype_id(ntype)\n        nid = th.nonzero(g.ndata[dgl.NTYPE] == ntype_id, as_tuple=True)[0]\n        nid_ranges[ntype] = [int(nid[0]), int(nid[-1] + 1)]\n\n    for etype in hg.etypes:\n        etype_id = hg.get_etype_id(etype)\n        eid = th.nonzero(g.edata[dgl.ETYPE] == etype_id, as_tuple=True)[0]\n        eid_ranges[etype] = [int(eid[0]), int(eid[-1] + 1)]\n    with open('mag.json', 'w') as outfile:\n        json.dump({'nid': nid_ranges, 'eid': eid_ranges}, outfile, indent=4)\n\nHeterogeneous 그래프에 대한 노드/에지 피처들 생성하기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n`convert_partition.py` 이 만든 :class:`dgl.DGLGraph` 아웃풋은 heterogeneous 그래프 파티션들을 homogeneous 그래프로 저장한다. 노드 데이터는 `orig_id` 라는 필드를 갖는데, 이는 원본 heterogeneous 그래프의 특정 노드 타입의 노드 ID들을 저장하고, `NTYPE` 의 필드는 노드 타입을 저장한다. 추가로, 이는 `inner_node` 라는 노드 데이터를 저장하는데, 이는 그래프 파티션의 노드가 파티션이 할당되어 있는지 여부를 알려준다. 만약 어떤 노드가 파티션에 할당되었다면, `inner_node` 는 1을 갖고, 반대의 경우에는 0을 갖는다. **Note**: 그래프 파티션은 몇 개의 HALO 노드들을 포함하는데, 이는 다른 파티션에 할당된 것지만, 이 그래프 파티션의 몇 개의 에지와 연결되어 있는 것들이다. 이 정보를 사용해서, 우리는 별도로 각 노드 타입에 대한 노드 피쳐들을 구성할 수 있으며, 이들을 `<node_type>/<feature_name>` 를 키로 갖고 값은 노드 피쳐 벡터인 dictionary에 저장할 수 있다. 아래 코드는 노드 피쳐 dictionary를 구성하는 방법을 보여준다. 텐서들의 dictionary가 만들어지면, 이는 파일에 저장된다.\n\n.. code-block:: none\n\n    node_data = {}\n    for ntype in hg.ntypes:\n        local_node_idx = th.logical_and(part.ndata['inner_node'].bool(),\n                                        part.ndata[dgl.NTYPE] == hg.get_ntype_id(ntype))\n        local_nodes = part.ndata['orig_id'][local_node_idx].numpy()\n        for name in hg.nodes[ntype].data:\n            node_data[ntype + '/' + name] = hg.nodes[ntype].data[name][local_nodes]\n    dgl.data.utils.save_tensors(metadata['part-{}'.format(part_id)]['node_feats'], node_data)\n\n에지 피쳐도 비슷한 방법으로 구성할 수 있다. 
차이점은 :class:`dgl.DGLGraph` 의 모든 에지들이 파티션에 포함된다는 점이다. 그래서, 구성 방법은 더 간단하다.\n\n.. code-block:: none\n\n    edge_data = {}\n    for etype in hg.etypes:\n        local_edges = subg.edata['orig_id'][subg.edata[dgl.ETYPE] == hg.get_etype_id(etype)]\n        for name in hg.edges[etype].data:\n            edge_data[etype + '/' + name] = hg.edges[etype].data[name][local_edges]\n    dgl.data.utils.save_tensors(metadata['part-{}'.format(part_id)]['edge_feats'], edge_data)\n"
  },
  {
    "path": "docs/source/guide_ko/distributed-tools.rst",
    "content": ".. _guide_ko-distributed-tools:\n\n7.4 분산 학습/추론을 런칭하기 위한 툴들\n-------------------------------\n\n:ref:`(English Version) <guide-distributed-tools>`\n\nDGL은 분산 학습을 돕는 두 스크립트들을 제공한다.\n\n* *tools/copy_files.py* : 그래프 파티션들을 하나의 그래프로 복사\n* *tools/launch.py* : 머신들의 클러스터에서 분산 학습 잡을 시작\n\n*copy_files.py* 는 (그래프가 파티션이 수행된) 한 머신의 파타션된 데이터와 관련 파일들(예, 학습 스크립트)을 (분산 학습이 수행 될) 클러스터에 복사한다. 스크립트는 한 파티션을 해당 파티션을 사용해서 분산 학습 잡이 실행될 머신에 복사한다. 스크립트는 네 개의 인자를 사용한다.\n\n* ``--part_config`` 는 로컬 머신의 파티션된 데이터에 대한 정보를 저장하는 파티션 설정 파일을 지정한다.\n* ``--ip_config`` 는 클러스터의 IP 설정 파일을 지정한다.\n* ``--workspace`` 는 분산 학습에 관련된 모든 데이터가 저장될 학습 머신의 디렉토리를 지정한다.\n* ``--rel_data_path`` 는 파티션된 데이터가 저장될 workspace 디렉토리 아래 상대 경로를 지정한다.\n* ``--script_folder`` 는 사용자의 학습 스크립트가 저장될 workspace 디렉토리 아래 상대 경로를 지정한다.\n\n**Note**: *copy_files.py* 는 IP 설정 파일을 기반으로 파티션을 저장할 머신을 찾는다. 따라서, 같은 IP 설정 파일이 *copy_files.py* 과 *launch.py* 에 사용되어야 한다.\n\nDGL은 클러스터에서 분산 학습 잡을 시작하기 위해서 *tools/launch.py* 를 제공한다. 이 스크립트는 다음을 가정한다.\n\n* 파티션된 데이터와 학습 스크립트는 클러스터 또는 클러스터의 모든 머신이 접근 가능한 클로벌 스토리지(예, NFS)로 복사된다.\n* (런치 스크립트가 실행되는) 마스터 머신은 다른 모든 머신에 패스워드 없이(passwordless) ssh 접근을 할 수 있다.\n\n**Note**: 런치 스크립트는 클러스터의 머신 중에 하나에서 실행되야 한다.\n\n다음은 클러스터에서 분산 학습 잡을 수행하는 예를 보여준다.\n\n.. code:: none\n\n    python3 tools/launch.py \\\n    --workspace ~graphsage/ \\\n    --num_trainers 2 \\\n    --num_samplers 4 \\\n    --num_servers 1 \\\n    --part_config data/ogb-product.json \\\n    --ip_config ip_config.txt \\\n    \"python3 code/train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 5 --batch-size 1000 --lr 0.1 --num_workers 4\"\n\n설정 파일 *ip_config.txt* 은 클러스터의 머신들의 IP 주소들을 저장한다. *ip_config.txt* 의 전형적인 예는 다음과 같다:\n\n.. code:: none\n\n    172.31.19.1\n    172.31.23.205\n    172.31.29.175\n    172.31.16.98\n\n각 줄은 한 머신의 IP 주소이다. 선택적으로 IP 주소 뒤에 트레이너들의 네트워크 통신에 사용될 포트 번호도 지정할 수 있다. 포트 번호가 지정되지 않은 경우 기본 값인 ``30050`` 이 사용된다.\n\n런치 스크립트에서 지정된 workspace는 머신들의 작업 디렉토리로, 학습 스크립트, IP 설정 파일, 파티션 설정 파일 그리고 그래프 파티션들이 저장되는 위치이다. 
파일들의 모든 경로들은 workspace의 상대 경로로 지정되어야 한다.\n\n런치 스크립트는 한 머신에서 지정된 수의 학습 잡(``--num_trainers`` )을 생성한다. 또한, 사용자는 각 트레이너에 대한 샘플러 프로세스의 개수(``--num_samplers``)를 정해야 한다. 샘플러 프로세스의 개수는 :func:`~dgl.distributed.initialize` 에서 명시된 worker 프로세스의 개수과 같아야 한다.\n"
  },
  {
    "path": "docs/source/guide_ko/distributed.rst",
    "content": ".. _guide_ko-distributed:\n\n7장: 분산 학습\n===========\n\n:ref:`(English Version) <guide-distributed>`\n\nDGL은 데이터와 연산을 컴퓨터 리소스들의 집합들에 분산하는 완전한 분산 방식을 채택하고 있다. 이 절에서는 클러스터 설정(컴퓨터들의 그룹)을 가정하고 있다. DGL은 그래프를 서브 그래프들로 나누고, 클러스터의 각 컴퓨터는 한개의 서브 그래프 (또는 파티션)에 대해 책임을 진다. DGL은 클러스터이 모든 컴퓨터에서 동일한 학습 스크립트를 실행해서 계산을 병렬화시키고, trainer에게 파티션된 데이터를 제공하기 위해서 같은 컴퓨터에서 서버들을 실행한다.\n\n학습 스크립트를 위해서 DGL은 미니-배치 학습과 비슷한 분산 API를 제공한다. 이는 단일 컴퓨터에서 미니-배치 학습을 수행하는 코드를 아주 조금만 수정하면 되게 해준다. 아래 코드는 GraphSAGE를 분산 형태로 학습하는 예제이다. 유일한 코드 변경은 4-7 라인이다: 1) DGL의 분산 모듈 초기화하기, 2) 분산 그래프 객체 생성하기, 3) 학습 셋을 나누고 로컬 프로세스를 위해서 노드들을 계산하기. 샘플러 생성, 모델 정의, 학습 룹과 같은 나머지 코드는 :ref:`mini-batch training <guide_ko-minibatch>` 과 같다.\n\n.. code:: python\n\n    import dgl\n    import torch as th\n\n    dgl.distributed.initialize('ip_config.txt')\n    th.distributed.init_process_group(backend='gloo')\n    g = dgl.distributed.DistGraph('graph_name', 'part_config.json')\n    pb = g.get_partition_book()\n    train_nid = dgl.distributed.node_split(g.ndata['train_mask'], pb, force_even=True)\n\n\n    # Create sampler\n    sampler = NeighborSampler(g, [10,25],\n                              dgl.distributed.sample_neighbors,\n                              device)\n\n    dataloader = DistDataLoader(\n        dataset=train_nid.numpy(),\n        batch_size=batch_size,\n        collate_fn=sampler.sample_blocks,\n        shuffle=True,\n        drop_last=False)\n\n    # Define model and optimizer\n    model = SAGE(in_feats, num_hidden, n_classes, num_layers, F.relu, dropout)\n    model = th.nn.parallel.DistributedDataParallel(model)\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n    # training loop\n    for epoch in range(args.num_epochs):\n        for step, blocks in enumerate(dataloader):\n            batch_inputs, batch_labels = load_subtensor(g, blocks[0].srcdata[dgl.NID],\n                                                        blocks[-1].dstdata[dgl.NID])\n            
batch_pred = model(blocks, batch_inputs)\n            loss = loss_fcn(batch_pred, batch_labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n컴퓨터들의 클러스터에서 학습 스크립트를 수행할 때, DGL은 데이터를 클러스터의 컴퓨터들에 복사하고 모든 컴퓨터에서 학습 잡을 실행하는 도구들을 제공한다.\n\n**Note**: 현재 분산 학습 API는 PyTorch 백엔드만 지원한다.\n\nDGL은 분산 학습을 지원하기 위해서 몇 가지 분산 컴포넌트를 구현하고 있다. 아래 그림은 컴포넌트들과 그것들의 인터랙션을 보여준다.\n\n.. figure:: https://data.dgl.ai/asset/image/distributed.png\n   :alt: Imgur\n\n특히, DGL의 분산 학습은 3가지 종류의 프로세스들을 갖는다: *서버*, *샘플러*, 그리고 *트레이너*\n\n* 서버 프로세스는 그래프 파티션(그래프 구조와 노드/에지 피처를 포함)을 저장하고 있는 각 컴퓨터에서 실행된다. 이 서버들은 함께 작동하면서 그래프 데이터를 트레이너에게 제공한다. 한 컴퓨터는 여러 서버 프로세스들을 동시에 수행하면서 연산과 네트워크 통신을 병렬화 한다.\n* 샘플러 프로세스들은 서버들과 상호작용을 하면서, 학습에 사용될 미니-배치를 만들기 위해서 노드와 에지를 샘플링한다.\n* 트레이너들은 서버들과 상호작용을 하기 위한 여러 클래스를 포함하고 있다. 파티션된 그래프 데이터를 접근하기 위한 :class:`~dgl.distributed.DistGraph` , 노드/에지의 피쳐/임베딩을 접근하기 위한 :class:`~dgl.distributed.DistEmbedding` 와 :class:`~dgl.distributed.DistTensor` 를 갖는다. 미니-배치를 얻기 위해서 샘플러와 상호작용을 하는 :class:`~dgl.distributed.dist_dataloader.DistDataLoader` 가 있다.\n\n분산 컴포넌트들을 염두에 두고, 이 절의 나머지에서는 다음과 같은 분산 컴포넌트들을 다룬다.\n\n* :ref:`guide_ko-distributed-preprocessing`\n* :ref:`guide_ko-distributed-apis`\n* :ref:`guide_ko-distributed-hetero`\n* :ref:`guide_ko-distributed-tools`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    distributed-preprocessing\n    distributed-apis\n    distributed-hetero\n    distributed-tools\n"
  },
  {
    "path": "docs/source/guide_ko/graph-basic.rst",
    "content": ".. _guide_ko-graph-basic:\n\n1.1 그래프에 대한 몇가지 기본적인 정의 (그래프 101)\n----------------------------------------------------\n\n:ref:`(English Version)<guide-graph-basic>`\n\n그래프 :math:`G=(V, E)` 는 인티티들과 그것들의 관계를 표현하기 위한 자료 구조이다. 그래프는 노드들의 집합(또는 버틱스들):math:`V` 과 에지들의 집합(또는 아크들) :math:`E` , 두개의 집합으로 구성된다. 두 노드 :math:`u` 와 :math:`v` 의 쌍을 연결하는 에지 :math:`(u, v) \\in E` 는 이들 사이에 관계가 있음을 나타낸다. 이 관계는 노드들간의 대칭적인 관계를 표현하는 것과 같이 방향성이 없거나, 비대칭적인 관계를 표현하기 위해서 방향성을 갖을 수 있다. 예를 들어, 소셜 네트워크에서 사람들 간의 친구 관계 모델링에 그래프를 사용한다면, 친구 관계는 양방향이기 때문에 에지는 방향성이 없을 것이다. 하지만, 그래프가 트위터의 팔로우 관계를 모델링하는데 사용된다면, 에지는 방향성이 있다. 에지의 방향성에 따라서, 그래프는 *방향성(directed)* 또는 *비방향성(undirected)* 이 된다. \n\n그래프는 *가중치를 갖거나(unweight)* , *가중치를 갖지 않는다(unweighted)*. 가중치 그래프에서 각 에지는 스칼라 가중치와 연결된다. 예를 들어, 가중치는 길이 또는 연결 강도를 의미할 수 있다.\n\n그래프는 *동종(homogeneous)* 또는 *이종(heterogeneous)* 일 수 있다. 동종 그래프(homogeneous graph)에서 모든 노드들은 같은 타입의 인스턴스를 표현하고, 모든 에지들도 같은 타입의 관계를 나타낸다. 예를 들어, 소셜 네트워크는 사람들과 그들의 연결로 구성된 그래프이고, 이들은 모두 같은 타입을 갖는다.\n\n그와 반대로 이종 그래프(heterogeneous graph)에서는 노드들과 에지들이 여러 타입을 갖는다. 예들 들어, 메켓플래이스를 인코딩한 그래프는 구매자, 판매자, 그리고 상품 노드들이 구입-원함(want-to-buy), 구입했음(has-bought), ~의-고객(is-coustomer-of), 그리고 ~을-판매함(is-selling) 에지로 연결되어 있다. 이분 그래프(bipartite graph)는 이종 그래프의 특별한 형태로 흔히 사용되는 그래프 타입으로, 에지는 서로 다른 두 타입의 노드를 연결한다. 예를 들어, 추천 시스템에서 이분 그래프를 사용해서 사용자들과 아이템들의 상호관계를 표현할 수 있다. DGL에서 이종 그래프를 어떻게 사용하는지는 :ref:`guide_ko-graph-heterogeneous` 를 참고하자. \n\n다중 그래프(multigraph)는 자체 루프(self loop)를 포함한 노드들의 같은 쌍들 사이에 (방향성이 있는) 여러 에지들을 갖는 그래프이다. 예를 들어, 두 저자가 서로 다른 해에 공동 저작을 했다면, 다른 피처들을 갖는 여러 에지가 만들어진다.\n"
  },
  {
    "path": "docs/source/guide_ko/graph-external.rst",
    "content": ".. _guide_ko-graph-external:\n\n1.4 외부 소스를 사용한 그래프 생성하기\n-----------------------------------------\n\n:ref:`(English Version)<guide-graph-external>`\n\n외부 소스들로부터 :class:`~dgl.DGLGraph` 를 만드는 옵션들:\n\n- 그래프 및 회소 행렬을 위한 python 라이브러리(NetworkX 및 SciPy)로부터 변환하기\n- 디스크에서 그래프를 로딩하기\n\n이 절에서는 다른 그래프를 변환해서 그래프를 생성하는 함수들은 다루지 않겠다. 그 방법들에 대한 소개는 매뉴얼의 API를 참조하자.\n\n외부 라이브러리를 사용해서 그래프 생성하기\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n아래 코드는 SciPy 희소행렬과 NetworkX 그래프로부터 그래프를 생성하는 예제이다.\n\n.. code::\n\n    >>> import dgl\n    >>> import torch as th\n    >>> import scipy.sparse as sp\n    >>> spmat = sp.rand(100, 100, density=0.05) # 5% nonzero entries\n    >>> dgl.from_scipy(spmat)                   # from SciPy\n    Graph(num_nodes=100, num_edges=500,\n          ndata_schemes={}\n          edata_schemes={})\n\n    >>> import networkx as nx\n    >>> nx_g = nx.path_graph(5) # a chain 0-1-2-3-4\n    >>> dgl.from_networkx(nx_g) # from networkx\n    Graph(num_nodes=5, num_edges=8,\n          ndata_schemes={}\n          edata_schemes={})\n\n`nx.path_graph(5)` 로부터 만들면 생성된 :class:`~dgl.DGLGraph` 는 4개가 아니라 8개의 에지를 갖는 점을 유의하자. 이유는 `nx.path_graph(5)` 는 방향이 없는 NetworkX 그래프 :class:`networkx.Graph` 를 만드는데, :class:`~dgl.DGLGraph` 는 항상 방향이 있는 그래프이기 때문이다. 방향이 없는 NetworkX 그래프를 :class:`~dgl.DGLGraph` 로 변환하면, DGL은 내부적으로 방향이 없는 에지를 두개의 방향이 있는 에지로 변환한다. :class:`networkx.DiGraph` 를 사용하면 이런 현상을 피할 수 있다.\n\n.. code::\n\n    >>> nxg = nx.DiGraph([(2, 1), (1, 2), (2, 3), (0, 0)])\n    >>> dgl.from_networkx(nxg)\n    Graph(num_nodes=4, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={})\n\n.. note::\n\n    내부적으로 DGL은 SciPy 행렬과 NetworkX 그래프를 텐서로 변환해서 그래프를 만든다. 따라서, 이 생성 방법은 성능이 중요한 곳에 사용되기 적합하지 않다.\n\n참고할 API들: :func:`dgl.from_scipy` , :func:`dgl.from_networkx` .\n\n디스크에서 그래프 로딩하기\n^^^^^^^^^^^^^^^^^^^\n\n그래프를 저장하기 위한 여러 데이터 포멧들이 있는데, 모든 옵션들을 나열하기는 불가능하다. 
그래서 이 절에서는 공통적인 것들에 대한 일반적인 참조만 소개한다.\n\nComma Separated Values (CSV)\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n아주 일반적인 포멧으로 CSV가 사용된다. 이는 노드, 에지, 그리고 그것들의 피처들을 테이블 형태로 저장한다.\n\n.. table:: nodes.csv\n\n   +-----------+\n   |age, title |\n   +===========+\n   |43, 1      |\n   +-----------+\n   |23, 3      |\n   +-----------+\n   |...        |\n   +-----------+\n\n.. table:: edges.csv\n\n   +-----------------+\n   |src, dst, weight |\n   +=================+\n   |0, 1, 0.4        |\n   +-----------------+\n   |0, 3, 0.9        |\n   +-----------------+\n   |...              |\n   +-----------------+\n\n잘 알려진 Python 라이브러리들(예, pandas)을 사용해서 이 형태의 데이터를 python 객체(예, :class:`numpy.ndarray` )로 로딩하고, 이를 DGLGraph로 변환하는데 사용할 수 있다. 만약 백엔드 프레임워크가 디스크에서 텐서를 저장하고/읽는 기능(예, :func:`torch.save` , :func:`torch.load` )을 제공한다면, 그래프를 만드는데 이용할 수 있다.\n\n함께 참조하기: `Tutorial for loading a Karate Club Network from edge pairs CSV <https://github.com/dglai/WWW20-Hands-on-Tutorial/blob/master/basic_tasks/1_load_data.ipynb>`_.\n\nJSON/GML 포멧\n\"\"\"\"\"\"\"\"\"\"\"\"\n\n특별히 빠르지는 않지만 NetworkX는 `다양한 데이터 포멧 <https://networkx.github.io/documentation/stable/reference/readwrite/index.html>`_ 을 파싱하는 유틸리티들을 제공하는데, 이를 통해서 DGL 그래프를 만들 수 있다.\n\nDGL 바이너리 포멧\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nDGL은 디스크에 그래프를 바이너리 형태로 저장하고 로딩하는 API들을 제공한다. 그래프 구조와 더불어, API들은 피처 데이터와 그래프 수준의 레이블 데이터도 다룰 수 있다. DGL은 그래프를 직접 S3 또는 HDFS에 체크포인트를 할 수 있는 기능을 제공한다. 레퍼런스 매뉴얼에 자세한 내용이 있으니 참고하자.\n\n참고할 API들: :func:`dgl.save_graphs` , :func:`dgl.load_graphs`\n"
  },
  {
    "path": "docs/source/guide_ko/graph-feature.rst",
    "content": ".. _guide_ko-graph-feature:\n\n1.3 노드와 에지의 피처\n--------------------------\n\n:ref:`(English Version)<guide-graph-feature>`\n\n노드들과 에지들의 그래프별 속성을 저장하기 위해서, :class:`~dgl.DGLGraph` 의 노드들과 에지들은 이름을 갖는 사용자 정의 피쳐를 갖을 수 있다. :py:attr:`~dgl.DGLGraph.ndata` 와 :py:attr:`~dgl.DGLGraph.edata` 인터페이스를 이용해서 이 피쳐들을 접근할 수 있다. 예를 들어, 아래 코드는 두 노드에 대한 피쳐를 생성하고(라인 8과 15에서 ``'x'`` 와 ``'y'`` 이름 피처), 한개의 에지 피처(라인 9에서 ``'x'`` 이름 피처)를 생성한다.\n\n.. code-block:: python\n    :linenos:\n\n    >>> import dgl\n    >>> import torch as th\n    >>> g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0])) # 6 nodes, 4 edges\n    >>> g\n    Graph(num_nodes=6, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={})\n    >>> g.ndata['x'] = th.ones(g.num_nodes(), 3)               # node feature of length 3\n    >>> g.edata['x'] = th.ones(g.num_edges(), dtype=th.int32)  # scalar integer feature\n    >>> g\n    Graph(num_nodes=6, num_edges=4,\n          ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}\n          edata_schemes={'x' : Scheme(shape=(,), dtype=torch.int32)})\n    >>> # different names can have different shapes\n    >>> g.ndata['y'] = th.randn(g.num_nodes(), 5)\n    >>> g.ndata['x'][1]                  # get node 1's feature\n    tensor([1., 1., 1.])\n    >>> g.edata['x'][th.tensor([0, 3])]  # get features of edge 0 and 3\n        tensor([1, 1], dtype=torch.int32)\n\n:py:attr:`~dgl.DGLGraph.ndata`/:py:attr:`~dgl.DGLGraph.edata` 인터페이스의 중요한 사실들:\n\n- 숫자 타입(예, float, double, int)의 피처들만 허용된다. 피처는 스칼라, 벡터, 또는 다차원 텐서가 가능하다.\n- 각 노드 피처는 고유한 이름을 갖고, 각 에지 피쳐도 고유한 이름을 갖는다. 노드와 에지의 피쳐는 같은 이름을 갖을 수 있다. (예, 위 예의 'x')\n- 턴서 할당으로 피처가 만들어진다. 즉, 피처를 그래프의 각 노드/에지에 할당하는 것이다. 텐서의 첫번째 차원은 그래프의 노드/에지들의 개수와 같아야 한다. 그래프의 노드/에지의 일부에만 피쳐를 할당하는 것은 불가능하다.\n- 같은 이름의 피처들은 같은 차원 및 같은 타입을 갖아야 한다.\n- 피처 텐서는 행 위주(row-major)의 레이아웃을 따른다. 각 행-슬라이스는 한 노드 또는 이제의 피처를 저장한다. (아래 예제의 16줄 및 18줄을 보자)\n\n가중치 그래프인 경우, 에지 피처로 가중치를 저장할 수 있다.\n\n.. 
code-block:: python\n\n    >>> # edges 0->1, 0->2, 0->3, 1->3\n    >>> edges = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])\n    >>> weights = th.tensor([0.1, 0.6, 0.9, 0.7])  # weight of each edge\n    >>> g = dgl.graph(edges)\n    >>> g.edata['w'] = weights  # give it a name 'w'\n    >>> g\n    Graph(num_nodes=4, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={'w' : Scheme(shape=(,), dtype=torch.float32)})\n\n참고할 API들: :py:attr:`~dgl.DGLGraph.ndata` , :py:attr:`~dgl.DGLGraph.edata`\n"
  },
  {
    "path": "docs/source/guide_ko/graph-gpu.rst",
    "content": ".. _guide_ko-graph-gpu:\n\n1.6 GPU에서 DGLGraph 사용하기\n--------------------------\n\n:ref:`(English Version)<guide-graph-gpu>`\n\n그래프 생성시, 두 GPU 텐서를 전달해서 GPU에 위치한 :class:`~dgl.DGLGraph` 를 만들 수 있다. 다른 방법으로는 :func:`~dgl.DGLGraph.to` API를 사용해서 :class:`~dgl.DGLGraph` 를 GPU로 복사할 수 있다. 이는 그래프 구조와 피처 데이터를 함께 복사한다.\n\n.. code::\n\n    >>> import dgl\n    >>> import torch as th\n    >>> u, v = th.tensor([0, 1, 2]), th.tensor([2, 3, 4])\n    >>> g = dgl.graph((u, v))\n    >>> g.ndata['x'] = th.randn(5, 3)  # original feature is on CPU\n    >>> g.device\n    device(type='cpu')\n    >>> cuda_g = g.to('cuda:0')  # accepts any device objects from backend framework\n    >>> cuda_g.device\n    device(type='cuda', index=0)\n    >>> cuda_g.ndata['x'].device       # feature data is copied to GPU too\n    device(type='cuda', index=0)\n\n    >>> # A graph constructed from GPU tensors is also on GPU\n    >>> u, v = u.to('cuda:0'), v.to('cuda:0')\n    >>> g = dgl.graph((u, v))\n    >>> g.device\n    device(type='cuda', index=0)\n\nGPU 그래프에 대한 모든 연산은 GPU에서 수행된다. 따라서, 모든 텐서 인자들이 GPU에 이미 존재해야하며, 연산 결과(그래프 또는 텐서) 역시 GPU에 저장된다. 더 나아가, GPU 그래프는 GPU에 있는 피쳐 데이터만 받아들인다.\n\n.. code::\n\n    >>> cuda_g.in_degrees()\n    tensor([0, 0, 1, 1, 1], device='cuda:0')\n    >>> cuda_g.in_edges([2, 3, 4])   # ok for non-tensor type arguments\n    (tensor([0, 1, 2], device='cuda:0'), tensor([2, 3, 4], device='cuda:0'))\n    >>> cuda_g.in_edges(th.tensor([2, 3, 4]).to('cuda:0'))  # tensor type must be on GPU\n    (tensor([0, 1, 2], device='cuda:0'), tensor([2, 3, 4], device='cuda:0'))\n    >>> cuda_g.ndata['h'] = th.randn(5, 4)  # ERROR! feature must be on GPU too!\n    DGLError: Cannot assign node feature \"h\" on device cpu to a graph on device\n    cuda:0. Call DGLGraph.to() to copy the graph to the same device.\n"
  },
  {
    "path": "docs/source/guide_ko/graph-graphs-nodes-edges.rst",
    "content": ".. _guide_ko-graph-graphs-nodes-edges:\n\n1.2 그래프, 노드, 그리고 에지\n----------------------------\n\n:ref:`(English Version)<guide-graph-graphs-nodes-edges>`\n\nDGL은 각 노드에 고유한 번호를 부여하는데 이를 노드 ID라고 하고, 각 에지에는 연결된 노드의 ID들에 해당하는 번호 쌍으로 표현된다. DGL은 각 에지에 고유한 번호를 부여하고, 이를 **에지 ID**라고 하며, 그래프에 추가된 순서에 따라 번호가 부여된다. 노드와 에지 ID의 번호는 0부터 시작한다. DGL에서는 모든 에지는 방향을 갖고, 에지 :math:`(u,v)` 는 노드 :math:`u` 에서 노드 :math:`v` 로 이어진 방향을 나타낸다.\n\n여러 노드를 표현하기 위해서 DGL는 노드 ID로 1차원 정수 텐서를 사용한다. (PyTorch의 tensor, TensorFlow의 Tensor, 또는 MXNet의 ndarry) DGL은 이 포멧을 \"노드-텐서\"라고 부른다. DGL에서 에지들은 노드-텐서의 튜플 :math:`(U, V)` 로 표현된다. :math:`(U[i], V[i])`  는 :math:`U[i]` 에서 :math:`V[i]` 로의 에지이다. \n\n:class:`~dgl.DGLGraph` 를 만드는 방법 중의 하나는 :func:`dgl.graph` 메소드를 사용하는 것이다. 이는 에지 집합을 입력으로 받는다. 또한 DGL은 다른 데이터 소스로부터 그래프들을 생성하는 것도 지원한다. :ref:`guide_ko-graph-external` 참고하자.\n\n다음 코드는 아래와 같은 4개의 노드를 갖는 그래프를 :func:`dgl.graph` 를 사용해서 :class:`~dgl.DGLGraph` 만들고, 그래프 구조를 쿼리하는 API들을 보여준다.\n\n.. figure:: https://data.dgl.ai/asset/image/user_guide_graphch_1.png\n    :height: 200px\n    :width: 300px\n    :align: center\n\n.. 
code::\n\n    >>> import dgl\n    >>> import torch as th\n\n    >>> # edges 0->1, 0->2, 0->3, 1->3\n    >>> u, v = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])\n    >>> g = dgl.graph((u, v))\n    >>> print(g) # number of nodes are inferred from the max node IDs in the given edges\n    Graph(num_nodes=4, num_edges=4,\n          ndata_schemes={}\n          edata_schemes={})\n\n    >>> # Node IDs\n    >>> print(g.nodes())\n    tensor([0, 1, 2, 3])\n    >>> # Edge end nodes\n    >>> print(g.edges())\n    (tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]))\n    >>> # Edge end nodes and edge IDs\n    >>> print(g.edges(form='all'))\n    (tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]), tensor([0, 1, 2, 3]))\n\n    >>> # If the node with the largest ID is isolated (meaning no edges),\n    >>> # then one needs to explicitly set the number of nodes\n    >>> g = dgl.graph((u, v), num_nodes=8)\n\n비방향성 그래프를 만들기 위해서는 양방향에 대한 에지들을 만들어야 한다. :func:`dgl.to_bidirected` 함수를 사용하면, 그래프를 양방향의 에지를 갖는 그래프로 변환할 수 있다.\n\n.. code::\n\n    >>> bg = dgl.to_bidirected(g)\n    >>> bg.edges()\n    (tensor([0, 0, 0, 1, 1, 2, 3, 3]), tensor([1, 2, 3, 0, 3, 0, 0, 1]))\n\n.. note::\n\n    DGL API에서는 일반적으로 텐서 타입이 사용된다. 이는 C 언어에서 효율적으로 저장되는 특징과, 명시적인 데이터 타입, 그리고 디바이스 컨택스트 정보 때문이다. 하지만, 빠른 프로토타입 개발을 지원하기 위해서, 대부분 DGL API는 파이선 iterable (예 list) 및 numpy.array를 함수 인자로 지원하고 있다.\n\nDGL은 노드 및 에지 ID를 저장하는데 :math:`32` 비트 또는 :math:`64` 비트 정수를 사용할 수 있다. 노드와 에지 ID의 데이터 타입은 같아야 한다. :math:`64` 비트를 사용하면 DGL은 노드 또는 에지를 :math:`2^{64} - 1` 개까지 다룰 수 있다. 하지만 그래프의 노드 또는 에지가 :math:`2^{31} - 1` 개 이하인 경우에는 :math:`32` 비트 정수를 사용해야한다. 이유는 속도도 빠르고 저장공간도 적게 사용하기 때문이다. DGL은 이 변환을 위한 방법들을 제공한다. 아래 예제를 보자.\n\n.. 
code::\n\n    >>> edges = th.tensor([2, 5, 3]), th.tensor([3, 5, 0])  # edges 2->3, 5->5, 3->0\n    >>> g64 = dgl.graph(edges)  # DGL uses int64 by default\n    >>> print(g64.idtype)\n    torch.int64\n    >>> g32 = dgl.graph(edges, idtype=th.int32)  # create a int32 graph\n    >>> g32.idtype\n    torch.int32\n    >>> g64_2 = g32.long()  # convert to int64\n    >>> g64_2.idtype\n    torch.int64\n    >>> g32_2 = g64.int()  # convert to int32\n    >>> g32_2.idtype\n    torch.int32\n\n참고할 API들: :func:`dgl.graph` , :func:`dgl.DGLGraph.nodes` , :func:`dgl.DGLGraph.edges` , :func:`dgl.to_bidirected` ,\n:func:`dgl.DGLGraph.int` , :func:`dgl.DGLGraph.long` , 그리고 :py:attr:`dgl.DGLGraph.idtype` \n\n"
  },
  {
    "path": "docs/source/guide_ko/graph-heterogeneous.rst",
    "content": ".. _guide_ko-graph-heterogeneous:\n\n1.5 이종 그래프 (Heterogeneous Graph)\n----------------------------------\n\n:ref:`(English Version)<guide-graph-heterogeneous>`\n\n이종 그래프는 다른 타입의 노드와 에지를 갖는다. 다른 타입의 노드/에지는 독립적인 ID 공간과 피처 저장소를 갖는다. 아래 그램의 예를 보면, user와 game 노드 ID는 모두 0부터 시작하고, 서로 다른 피처들을 갖고 있다.\n\n.. figure:: https://data.dgl.ai/asset/image/user_guide_graphch_2.png\n\n    두 타입의 노드(user와 game)와 두 타입의 에지(follows와 plays)를 갖는 이종 그래프 예\n\n이종 그래프 생성하기\n^^^^^^^^^^^^^^^\n\nDGL에서 이종 그래프(짧게 heterograph)는 관계당 하나의 그래프들의 시리즈로 표현된다. 각 관계는 문자열 트리플 ``(source node type, edge type, destination node type)`` 이다. 관계가 에지 타입을 명확하게 하기 때문에, DGL은 이것들을 캐노니컬(canonical) 에지 타입이라고 한다.\n\n아래 코드는 DGL에서 이종 그래프를 만드는 예제이다.\n\n.. code::\n\n    >>> import dgl\n    >>> import torch as th\n\n    >>> # Create a heterograph with 3 node types and 3 edges types.\n    >>> graph_data = {\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    ('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),\n    ...    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))\n    ... }\n    >>> g = dgl.heterograph(graph_data)\n    >>> g.ntypes\n    ['disease', 'drug', 'gene']\n    >>> g.etypes\n    ['interacts', 'interacts', 'treats']\n    >>> g.canonical_etypes\n    [('drug', 'interacts', 'drug'),\n     ('drug', 'interacts', 'gene'),\n     ('drug', 'treats', 'disease')]\n\n동종(homogeneous) 및 이분(bipartite) 그래프는 하나의 관계를 갖는 특별한 이종 그래프일 뿐임을 알아두자.\n\n.. code::\n\n    >>> # A homogeneous graph\n    >>> dgl.heterograph({('node_type', 'edge_type', 'node_type'): (u, v)})\n    >>> # A bipartite graph\n    >>> dgl.heterograph({('source_type', 'edge_type', 'destination_type'): (u, v)})\n\n이종 그래프와 연관된 *메타그래프(metagraph)* 는 그래프의 스키마이다. 이것은 노드들과 노드간의 에지들의 집합에 대한 타입 제약 조건을 지정한다. 메타그래프의 노드 :math:`u` 는 연관된 이종 그래프의 노드 타입에 해당한다. 메타그래프의 에지 :math:`(u,v)` 는 연관된 이종 그래프의 노드 타입 :math:`u` 와 노드 타입 :math:`v` 간에 에지가 있다는 것을 알려준다.\n\n.. 
code::\n\n    >>> g\n    Graph(num_nodes={'disease': 3, 'drug': 3, 'gene': 4},\n          num_edges={('drug', 'interacts', 'drug'): 2,\n                     ('drug', 'interacts', 'gene'): 2,\n                     ('drug', 'treats', 'disease'): 1},\n          metagraph=[('drug', 'drug', 'interacts'),\n                     ('drug', 'gene', 'interacts'),\n                     ('drug', 'disease', 'treats')])\n    >>> g.metagraph().edges()\n    OutMultiEdgeDataView([('drug', 'drug'), ('drug', 'gene'), ('drug', 'disease')])\n\n참고할 API들: :func:`dgl.heterograph` , :py:attr:`~dgl.DGLGraph.ntypes` , :py:attr:`~dgl.DGLGraph.etypes` , :py:attr:`~dgl.DGLGraph.canonical_etypes` , :py:attr:`~dgl.DGLGraph.metagraph`\n\n다양한 타입을 다루기\n^^^^^^^^^^^^^^^\n\n노드와 에지가 여러 타입이 사용되는 경우, 타입 관련된 정보를 위한 DGLGraph API를 호출할 때는 노드/에지의 타입을 명시해야한다. 추가로 다른 타입의 노드/에지는 별도의 ID를 갖는다.\n\n.. code::\n\n    >>> # Get the number of all nodes in the graph\n    >>> g.num_nodes()\n    10\n    >>> # Get the number of drug nodes\n    >>> g.num_nodes('drug')\n    3\n    >>> # Nodes of different types have separate IDs,\n    >>> # hence not well-defined without a type specified\n    >>> g.nodes()\n    DGLError: Node type name must be specified if there are more than one node types.\n    >>> g.nodes('drug')\n    tensor([0, 1, 2])\n\n특정 노드/에지 타입에 대한 피쳐를 설정하고 얻을 때, DGL은 두가지 새로운 형태의 문법을 제공한다 -- `g.nodes['node_type'].data['feat_name']`와 `g.edges['edge_type'].data['feat_name']`.\n\n.. code::\n\n    >>> # Set/get feature 'hv' for nodes of type 'drug'\n    >>> g.nodes['drug'].data['hv'] = th.ones(3, 1)\n    >>> g.nodes['drug'].data['hv']\n    tensor([[1.],\n            [1.],\n            [1.]])\n    >>> # Set/get feature 'he' for edge of type 'treats'\n    >>> g.edges['treats'].data['he'] = th.zeros(1, 1)\n    >>> g.edges['treats'].data['he']\n    tensor([[0.]])\n\n만약 그래프가 오직 한개의 노드/에지 타입을 갖는다면, 노드/에지 타입을 명시할 필요가 없다.\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    
('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    ('drug', 'is similar', 'drug'): (th.tensor([0, 1]), th.tensor([2, 3]))\n    ... })\n    >>> g.nodes()\n    tensor([0, 1, 2, 3])\n    >>> # To set/get feature with a single type, no need to use the new syntax\n    >>> g.ndata['hv'] = th.ones(4, 1)\n\n.. note::\n\n    에지 타입이 목적지와 도착지 노드의 타입을 고유하게 결정할 수 있다면, 에지 타입을 명시할 때 문자 트리플 대신 한 문자만들 사용할 수 있다. 예를 듬녀, 두 관계 ``('user', 'plays', 'game')`` and ``('user', 'likes', 'game')``를 갖는 이종 그래프가 있을 때, 두 관계를 지정하기 위해서 단지 ``'plays'`` 또는 ``'likes'`` 를 사용해도 된다.\n\n디스크에서 이종 그래프 로딩하기\n^^^^^^^^^^^^^^^^^^^^^^^\n\nComma Separated Values (CSV)\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n이종 그래프를 저장하는 일반적인 방법은 다른 타입의 노드와 에지를 서로 다른 CSV 파일에 저장하는 것이다. 예를들면 다음과 같다.\n\n.. code::\n\n    # data folder\n    data/\n    |-- drug.csv        # drug nodes\n    |-- gene.csv        # gene nodes\n    |-- disease.csv     # disease nodes\n    |-- drug-interact-drug.csv  # drug-drug interaction edges\n    |-- drug-interact-gene.csv  # drug-gene interaction edges\n    |-- drug-treat-disease.csv  # drug-treat-disease edges\n\n동종 그래프의 경우와 동일하게, Pandas와 같은 패키지들을 사용해서 CSV 파일들을 파싱하고, 이를 numpy 배열 또는 프레임워크의 텐서들에 저장하고, 관계 사전을 만들고, 이를 이용해서 이종 그래프를 생성할 수 있다. 이 방법은 GML/JSON과 같은 다른 유명한 포멧들에도 동일하게 적용된다.\n\nDGL 바이너리 포멧\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nDGL은 이종 그래프를 바이너리 포멧으로 저장하고 읽기 위한 함수 :func:`dgl.save_graphs` 와 :func:`dgl.load_graphs` 를 제공한다.\n\n에지 타입 서브그래프\n^^^^^^^^^^^^^^^\n\n보존하고 싶은 관계를 명시하고, 피처가 있을 경우는 이를 복사하면서 이종 그래프의 서브그래프를 생성할 수 있다.\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    ('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),\n    ...    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))\n    ... 
})\n    >>> g.nodes['drug'].data['hv'] = th.ones(3, 1)\n\n    >>> # Retain relations ('drug', 'interacts', 'drug') and ('drug', 'treats', 'disease')\n    >>> # All nodes for 'drug' and 'disease' will be retained\n    >>> eg = dgl.edge_type_subgraph(g, [('drug', 'interacts', 'drug'),\n    ...                                 ('drug', 'treats', 'disease')])\n    >>> eg\n    Graph(num_nodes={'disease': 3, 'drug': 3},\n          num_edges={('drug', 'interacts', 'drug'): 2, ('drug', 'treats', 'disease'): 1},\n          metagraph=[('drug', 'drug', 'interacts'), ('drug', 'disease', 'treats')])\n    >>> # The associated features will be copied as well\n    >>> eg.nodes['drug'].data['hv']\n    tensor([[1.],\n            [1.],\n            [1.]])\n\n이종 그래프를 동종 그래프로 변환하기\n^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n이종 그래프는 다른 타입의 노드/에지와 그것들에 연관된 피쳐들을 관리하는데 깔끔한 인터페이스를 제공한다. 이것을 아래의 경우 특히 유용하다.\n\n1. 다른 타입의 노드/에지에 대한 피쳐가 다른 데이터 타입 또는 크기를 갖는다.\n2. 다른 타입의 노드/에지에 다른 연산을 적용하고 싶다.\n\n만약 위 조건을 만족하지 않고 모델링에서 노드/에지 타입의 구별이 필요하지 않는다면, DGL의 :func:`dgl.DGLGraph.to_homogeneous` API를 이용해서 이종 그래프를 동종 그래프로 변환할 수 있다. 이 변환은 다음 절처로 이뤄진다.\n\n1. 모든 타입의 노드/에지를 0부터 시작하는 정수로 레이블을 다시 부여한다.\n2. 사용자가 지정한 노드/에지 타입들에 걸쳐서 피쳐들을 합친다.\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    
('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))})\n    >>> g.nodes['drug'].data['hv'] = th.zeros(3, 1)\n    >>> g.nodes['disease'].data['hv'] = th.ones(3, 1)\n    >>> g.edges['interacts'].data['he'] = th.zeros(2, 1)\n    >>> g.edges['treats'].data['he'] = th.zeros(1, 2)\n\n    >>> # By default, it does not merge any features\n    >>> hg = dgl.to_homogeneous(g)\n    >>> 'hv' in hg.ndata\n    False\n\n    >>> # Copy edge features\n    >>> # For feature copy, it expects features to have\n    >>> # the same size and dtype across node/edge types\n    >>> hg = dgl.to_homogeneous(g, edata=['he'])\n    DGLError: Cannot concatenate column ‘he’ with shape Scheme(shape=(2,), dtype=torch.float32) and shape Scheme(shape=(1,), dtype=torch.float32)\n\n    >>> # Copy node features\n    >>> hg = dgl.to_homogeneous(g, ndata=['hv'])\n    >>> hg.ndata['hv']\n    tensor([[1.],\n            [1.],\n            [1.],\n            [0.],\n            [0.],\n            [0.]])\n\n원래의 노드/에지 타입과 타입별 ID들은 :py:attr:`~dgl.DGLGraph.ndata` 와 :py:attr:`~dgl.DGLGraph.edata` 에 저장된다.\n\n.. code::\n\n    >>> # Order of node types in the heterograph\n    >>> g.ntypes\n    ['disease', 'drug']\n    >>> # Original node types\n    >>> hg.ndata[dgl.NTYPE]\n    tensor([0, 0, 0, 1, 1, 1])\n    >>> # Original type-specific node IDs\n    >>> hg.ndata[dgl.NID]\n    tensor([0, 1, 2, 0, 1, 2])\n\n    >>> # Order of edge types in the heterograph\n    >>> g.etypes\n    ['interacts', 'treats']\n    >>> # Original edge types\n    >>> hg.edata[dgl.ETYPE]\n    tensor([0, 0, 1])\n    >>> # Original type-specific edge IDs\n    >>> hg.edata[dgl.EID]\n    tensor([0, 1, 0])\n\n모델링 목적으로, 특정 관계들을 모아서 그룹으로 만들고, 그것들에 같은 연산을 적용하고 싶은 경우가 있다. 이를 위해서, 우선 이종 그래프의 에지 타입 서브그래프를 추출하고, 그리고 그 서브그래프를 동종 그래프로 변환한다.\n\n.. code::\n\n    >>> g = dgl.heterograph({\n    ...    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),\n    ...    
('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),\n    ...    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))\n    ... })\n    >>> sub_g = dgl.edge_type_subgraph(g, [('drug', 'interacts', 'drug'),\n    ...                                    ('drug', 'interacts', 'gene')])\n    >>> h_sub_g = dgl.to_homogeneous(sub_g)\n    >>> h_sub_g\n    Graph(num_nodes=7, num_edges=4,\n          ...)\n"
  },
  {
    "path": "docs/source/guide_ko/graph.rst",
    "content": ".. _guide_ko-graph:\n\n1장: 그래프\n=========\n\n:ref:`(English version)<guide-graph>`\n\n그래프는 엔티티들(entity 또는 노드들)과 노드들간의 관계(에지)로 표현되며, 노드와 에지들은 타입을 가질 수 있다. (예를 들어, ``\"user\"`` 와 ``\"item\"`` 은 서로 다른 타입의 노드들이다.) DGL은 :class:`~dgl.DGLGraph` 를 핵심 자료 구조로 갖는 그래프-중심의 프로그래밍 추상화를 제공한다. :class:`~dgl.DGLGraph` 는 그래프의 구조, 그 그래프의 노드 및 에지 피처들과 이 컴포넌트들을 사용해서 수행된 연산 결과를 다루는데 필요한 인터페이스를 제공한다.\n\n로드맵\n-------\n\n이 장은 1.1절의 그래프 정의에 대한 간단한 소개를 시작으로 :class:`~dgl.DGLGraph` 의 몇가지 핵심 개념을 소개한다.\n\n* :ref:`guide_ko-graph-basic`\n* :ref:`guide_ko-graph-graphs-nodes-edges`\n* :ref:`guide_ko-graph-feature`\n* :ref:`guide_ko-graph-external`\n* :ref:`guide_ko-graph-heterogeneous`\n* :ref:`guide_ko-graph-gpu`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    graph-basic\n    graph-graphs-nodes-edges\n    graph-feature\n    graph-external\n    graph-heterogeneous\n    graph-gpu\n"
  },
  {
    "path": "docs/source/guide_ko/index.rst",
    "content": "사용자 가이드[시대에 뒤처진]\n=====================\n\n.. toctree::\n  :maxdepth: 2\n  :titlesonly:\n\n  graph\n  message\n  nn\n  data\n  training\n  minibatch\n  distributed\n  mixed_precision\n\n\n이 한글 버전 DGL 사용자 가이드는 2021년 11월 기준의 영문 :ref:`(User Guide) <guide-index>` 을 Amazon Machine Learning Solutions Lab의 김무현 Principal Data Scientist가 번역한 것입니다. 오류 및 질문은 `muhyun@amazon.com` 으로 보내주세요."
  },
  {
    "path": "docs/source/guide_ko/message-api.rst",
    "content": ".. _guide_ko-message-passing-api:\n\n2.1 빌트인 함수 및 메시지 전달 API들\n-----------------------------\n\n:ref:`(English Version) <guide-message-passing-api>`\n\nDGL에서 **메시지 함수** 는 한개의 인자 ``edges`` 를 갖는데, 이는 :class:`~dgl.udf.EdgeBatch` 의 객체이다. 메시지 전달이 실행되는 동안 DGL은 에지 배치를 표현하기 위해서 이 객체를 내부적으로 생성한다. 이것은 3개의 맴버, ``src`` , ``dst`` , 그리고 ``data`` 를 갖고, 이는 각각 소스 노드, 목적지 노드, 그리고 에지의 피쳐를 의미한다.\n\n**축약 함수(reduce function)** 는 한개의 인자 ``nodes`` 를 갖는데, 이는 :class:`~dgl.udf.NodeBatch` 의 객체이다. 메시지 전달이 실행되는 동안 DGL은 노드 배치를 표현하기 위해서 이 객체를 내부적으로 생성한다. 이 객체는 ``mailbox`` 라는 맴버를 갖는데, 이는 배치에 속한 노드들에게 전달된 메시지들을 접근 방법을 제공한다. 가장 흔한 축약 함수로는 ``sum`` , ``max`` , ``min`` 등이 있다.\n\n**업데이트 함수** 는 위에서 언급한 ``nodes`` 를 한개의 인자로 갖는다. 이 함수는 ``축약 함수`` 의 집계 결과에 적용되는데, 보통은 마지막 스탭에서 노드의 원래 피처와 이 결과와 결합하고, 그 결과를 노드의 피처로 저장한다.\n\nDGL은 일반적으로 사용되는 메시지 전달 함수들과 축약 함수들을 ``dgl.function`` 네임스패이스에 **빌트인** 으로 구현하고 있다. 일반적으로, **가능한 경우라면 항상** DLG의 빌트인 함수를 사용하는 것을 권장하는데, 그 이유는 이 함수들은 가장 최적화된 형태로 구현되어 있고, 차원 브로드캐스팅을 자동으로 해주기 때문이다.\n\n만약 여러분의 메시지 전달 함수가 빌트인 함수로 구현이 불가능하다면, 사용자 정의 메시지/축소 함수를 직접 구현할 수 있다. 이를 **UDF** 라고 한다.\n\n빌트인 메시지 함수들은 단항(unary) 또는 이상(binary)이다. 단항의 경우 DGL은 ``copy`` 를 지원한다. 이항 함수로 DGL은 ``add`` , ``sub`` , ``mul`` , ``div`` , 그리고 ``dot`` 를 지원한다. 빌트인 메시지 함수의 이름 규칙은 다음과 같다. ``u`` 는 ``src`` 노드를, ``v`` 는 ``dst`` 노드를 그리고 ``e`` 는 ``edges`` 를 의미한다. 이 함수들에 대한 파라미터들은 관련된 노드와 에지의 입력과 출력 필드 이름을 지칭하는 문자열이다. 지원되는 빌트인 함수의 목록은 :ref:`api-built-in` 을 참고하자. 한가지 예를 들면, 소스 노드의 ``hu`` 피처와 목적지 노드의 ``hv`` 피처를 더해서 그 결과를 에지의 ``he`` 필드에 저장하는 것을 빌트인 함수 ``dgl.function.u_add_v('hu', 'hv', 'he')`` 를 사용해서 구현할 수 있다. 이와 동일한 기능을 하는 메시지 UDF는 다음과 같다.\n\n.. code::\n\n    def message_func(edges):\n         return {'he': edges.src['hu'] + edges.dst['hv']}\n\n빌트인 축약 함수는 ``sum``, ``max``, ``min`` 그리고 ``mean`` 연산을 지원한다. 보통 축약 함수는 두개의 파라메터를 갖는데, 하나는 ``mailbox`` 의 필드 이름이고, 다른 하나는 노드 피처의 필드 이름이다. 이는 모두 문자열이다. 예를 들어, ``dgl.function.sum('m', 'h')`` 는 메시지 ``m`` 을 합하는 아래 축약 UDF와 같다.\n\n.. 
code::\n\n    import torch\n    def reduce_func(nodes):\n         return {'h': torch.sum(nodes.mailbox['m'], dim=1)}\n\nUDF의 고급 사용법을 더 알고 싶으면 :ref:`apiudf` 를 참고하자.\n\n:meth:`~dgl.DGLGraph.apply_edges` 를 사용해서 메시지 전달 함수를 호출하지 않고 에지별 연산만 호출하는 것도 가능하다. :meth:`~dgl.DGLGraph.apply_edges` 는 파라미터로 메시지 함수를 받는데, 기본 설정으로는 모든 에지의 피쳐를 업데이트한다. 다음 예를 살펴보자.\n\n.. code::\n\n    import dgl.function as fn\n    graph.apply_edges(fn.u_add_v('el', 'er', 'e'))\n\n메시지 전달을 위한 :meth:`~dgl.DGLGraph.update_all` 는 하이레벨 API로 메시지 생성, 메시지 병합 그리고 노드 업데이트를 단일 호출로 합쳤는데, 전반적으로 최적화할 여지가 남아있다.\n\n:meth:`~dgl.DGLGraph.update_all` 의 파라메터들은 메시지 함수, 축약 함수, 그리고 업데이트 함수이다. :meth:`~dgl.DGLGraph.update_all` 를 호출할 때 업데이트 함수를 지정하지 않는 경우, 업데이트 함수는 ``update_all`` 밖에서 수행될 수 있다. DGL은 이 방법을 권장하는데, 업데이트 함수는 코드를 간결하게 만들기 위해서 보통은 순수 텐서 연산으로 구현되어 있기 때문이다. 예를 들면, 다음과 같다.\n\n.. code::\n\n    def update_all_example(graph):\n        # store the result in graph.ndata['ft']\n        graph.update_all(fn.u_mul_e('ft', 'a', 'm'),\n                         fn.sum('m', 'ft'))\n        # Call update function outside of update_all\n        final_ft = graph.ndata['ft'] * 2\n        return final_ft\n\n이 함수는 소스 노드의 피처 ``ft`` 와 에지 피처 ``a`` 를 곱해서 메시지 ``m`` 을 생성하고, 메시지 ``m`` 들을 더해서 노드 피처 ``ft`` 를 업데이트하고, 마지막으로 ``final_ft`` 결과를 구하기 위해서 ``ft`` 에 2를 곱하고 있다. 호출이 완료되면 DGL은 중간에 사용된 메시지들 ``m`` 을 제거한다. 위 함수를 수학 공식으로 표현하면 다음과 같다.\n\n.. math::  {final\\_ft}_i = 2 * \\sum_{j\\in\\mathcal{N}(i)} ({ft}_j * a_{ij})\n\nDGL의 빌트인 함수는 부동소수점 데이터 타입을 지원한다. 즉, 피쳐들은 반드시 ``half`` (``float16``), ``float``, 또는 ``double`` 텐서여야만 한다. ``float16`` 데이터 타입에 대한 지원은 기본 설정에서는 비활성화되어 있다. 그 이유는 이를 지원하기 위해서는 ``sm_53`` (Pascal, Volta, Turing, 그리고 Ampere 아키텍타)와 같은 최소한의 GPU 컴퓨팅 능력이 요구되기 때문이다. \n\n사용자는 DGL 소스 컴파일을 통해서 mixed precision training을 위해서 float16을 활성화시킬 수 있다. (자세한 내용은 :doc:`Mixed Precision Training <mixed_precision>` 튜토리얼 참고)\n"
  },
  {
    "path": "docs/source/guide_ko/message-edge.rst",
    "content": ".. _guide_ko-message-passing-edge:\n\n2.4 메시지 전달에 에지 가중치 적용하기\n-----------------------------\n\n:ref:`(English Version) <guide-message-passing-edge>`\n\n`GAT <https://arxiv.org/pdf/1710.10903.pdf>`__ 또는 일부 `GCN 변형 <https://arxiv.org/abs/2004.00445>`__ 에서 사용되는 것처럼 메시지 병합 이전에 에지의 가중치를 적용하는 것은 GNN 모델링에서 흔하게 사용되는 기법이다. DGL은 이를 다음과 같은 방법으로 지원하고 있다.\n\n- 가중치를 에지 피쳐로 저장\n- 메시지 함수에서 에지 피쳐를 소스 노드의 피쳐와 곱하기\n\n예를 들면,\n\n.. code::\n\n    import dgl.function as fn\n\n    # Suppose eweight is a tensor of shape (E, *), where E is the number of edges.\n    graph.edata['a'] = eweight\n    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),\n                     fn.sum('m', 'ft'))\n\n이 예제는 eweight를 에지 가중치로 사용하고 있다. 에지 가중치는 보통은 스칼라 값을 갖는다."
  },
  {
    "path": "docs/source/guide_ko/message-efficient.rst",
    "content": ".. _guide_ko-message-passing-efficient:\n\n2.2 효율적인 메시지 전달 코드 작성 방법\n------------------------------\n\n:ref:`(English Version) <guide-message-passing-efficient>`\n\nDGL은 메시지 전달에 대한 메모리 사용과 연산 속드를 최적화하고 있다. 이 최적화들을 활용하는 일반적으로 사용되는 방법은 직접 메시지 전달 함수를 만들어서 이를 :meth:`~dgl.DGLGraph.update_all` 호출시 빌트인 함수와 함께 파라메터로 사용하는 것이다. \n\n만약 그래프의 에지들의 수가 노드들의 수보다 훨씬 많은 경우에는 노드에서 에지로의 불필요한 메모리 복사를 피하는 것이 도움이 된다. 에지에 메시지를 저장할 필요가 있는 :class:`~dgl.nn.pytorch.conv.GATConv` 와 같은 경우에는 빌트인 함수를 사용해서 :meth:`~dgl.DGLGraph.apply_edges` 를 호출해야 한다. 때로는 에지에 저장할 메시지의 차원이 너무 커서 메모리를 많이 차지하기도 한다. DGL에서는 가능한 에지 피쳐의 차원을 낮추는 것을 권장한다.\n\n에지에 대한 연산을 노드로 분할하여 이를 달성하는 방법에 대한 예제이다. 이 방법은 다음과 같다. ``src`` 피쳐와 ``dst`` 피쳐를 연결하고, 선형 레이어 :math:`W\\times (u || v)`를 적용하는 경우를 들어보자. ``src``와 ``dst`` 피처 차원은 매우 높은 반면에 선형 레이어의 결과 차원은 낮다고 가정하자. 이 예제를 직관적으로 구현하면 다음과 같다.\n\n.. code::\n\n    import torch\n    import torch.nn as nn\n\n    linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, out_dim)))\n    def concat_message_function(edges):\n         return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=1)}\n    g.apply_edges(concat_message_function)\n    g.edata['out'] = g.edata['cat_feat'] @ linear\n\n제안하는 구현은 이 선형 연산을 두개로 나누는 것이다. 하나는 ``src`` 피처에 적용하고, 다른 하나는 ``dst`` 피쳐에 적용한다. 그 후, 에지에 대한 두 선형 연산의 결과를 마지막 단계에서 더한다. 즉, :math:`W_l\\times u + W_r \\times v` 를 실행하는 것이다. :math:`W` 행렬의 왼쪽 반과 오른쪽 반이 각각 :math:`W_l` 와 :math:`W_r` 일 때, :math:`W \\times (u||v) = W_l \\times u + W_r \\times v` 가 성립하기 때문에 가능하다.\n\n.. code::\n\n    import dgl.function as fn\n\n    linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))\n    linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))\n    out_src = g.ndata['feat'] @ linear_src\n    out_dst = g.ndata['feat'] @ linear_dst\n    g.srcdata.update({'out_src': out_src})\n    g.dstdata.update({'out_dst': out_dst})\n    g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))\n\n위 두 구현은 수학적으로 동일하다. 
후자가 더 효율적인데, 그 이유는 메모리 비효율적인 에지에 feat_src와 feat_dst의 저장이 필요가 없기 때문이다. 추가로, 합은 연산속도가 더 빠르고 메모리 사용량을 줄인 DGL의 빌트인 함수 ``u_add_v`` 를 사용하면 최적화될 수 있다. "
  },
  {
    "path": "docs/source/guide_ko/message-heterograph.rst",
    "content": ".. _guide_ko-message-passing-heterograph:\n\n2.5 이종 그래프에서의 메시지 전달\n--------------------------\n\n:ref:`(English Version) <guide-message-passing-heterograph>`\n\n이종 그래프 ( :ref:`guide_ko-graph-heterogeneous` ) 또는 헤테로그래프는 여러 타입의 노드와 에지를 갖는 그래프이다. 타입이 다른 노드와 에지는 각 노드 및 에지 타입의 특징을 표현하기 위해서 서로 다른 타입의 속성을 가질 수 있다. 복잡한 그래프 뉴럴 네트워크들에서 어떤 노드나 에지 타입들은 다른 차원들을 갖게 모델링 되기도 한다.\n\n이종 그래프에서 메시지 전달은 두 파트로 나뉜다:\n\n1. 각 관계(relation) r에 대한, 메시지 연산과 집계(aggregation)\n2. 각 노드 타입에 대한 모든 관계의 집계 결과를 합치는 축약(reduction)\n\n이종 그래프에서 메시지 전달을 담당하는 DGL 인터페이스는 :meth:`~dgl.DGLGraph.multi_update_all` 이다. :meth:`~dgl.DGLGraph.multi_update_all` 는 관계를 키로 하고 그 관계에 적용할 :meth:`~dgl.DGLGraph.update_all` 의 파라메터들을 값으로 갖는 사전(dictionary)과, 크로스 타입 리듀서(cross type reducer)를 나타내는 문자열을 인자로 받는다. Reducer는 ``sum``, ``min``, ``max``, ``mean``, ``stack`` 중에 하나가 된다. 예제는 다음과 같다.\n\n.. code::\n\n    import dgl.function as fn\n\n    for c_etype in G.canonical_etypes:\n        srctype, etype, dsttype = c_etype\n        Wh = self.weight[etype](feat_dict[srctype])\n        # Save it in graph for message passing\n        G.nodes[srctype].data['Wh_%s' % etype] = Wh\n        # Specify per-relation message passing functions: (message_func, reduce_func).\n        # Note that the results are saved to the same destination feature 'h', which\n        # hints the type wise reducer for aggregation.\n        funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))\n    # Trigger message passing of multiple types.\n    G.multi_update_all(funcs, 'sum')\n    # return the updated node feature dictionary\n    return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}\n"
  },
  {
    "path": "docs/source/guide_ko/message-part.rst",
    "content": ".. _guide_ko-message-passing-part:\n\n2.3 그래프 일부에 메시지 전달 적용하기\n------------------------------\n\n:ref:`(English Version) <guide-message-passing-part>`\n\n그래프 노드의 일부만 업데이트를 하기 원하는 경우, 업데이트를 하고 싶은 노드들의 ID를 사용해서 서브그래프를 만든 후, 그 서브그래프에 :meth:`~dgl.DGLGraph.update_all` 를 호출하는 방법으로 가능하다.\n\n.. code::\n\n    nid = [0, 2, 3, 6, 7, 9]\n    sg = g.subgraph(nid)\n    sg.update_all(message_func, reduce_func, apply_node_func)\n\n이는 미니-배치 학습에서 흔히 사용되는 방법이다. 자세한 사용법은 :ref:`guide_ko-minibatch` 를 참고하자."
  },
  {
    "path": "docs/source/guide_ko/message.rst",
    "content": ".. _guide_ko-message-passing:\n\n2장: 메시지 전달(Message Passing)\n=============================\n\n:ref:`(English Version) <guide-message-passing>`\n\n메시지 전달 패러다임(Message Passing Paradigm)\n-----------------------------------------\n\n:math:`x_v\\in\\mathbb{R}^{d_1}` 이 노드 :math:`v` 의 피처이고, :math:`w_{e}\\in\\mathbb{R}^{d_2}` 가 에지 :math:`({u}, {v})` 의 피처라고 하자. **메시지 전달 패러다임** 은 :math:`t+1` 단계에서 노드별(node-wise) 그리고 에지별(edge-wise)의 연산을 다음과 같이 정의한다:\n\n.. math::  \\text{에지별: } m_{e}^{(t+1)} = \\phi \\left( x_v^{(t)}, x_u^{(t)}, w_{e}^{(t)} \\right) , ({u}, {v},{e}) \\in \\mathcal{E}.\n\n.. math::  \\text{노드별: } x_v^{(t+1)} = \\psi \\left(x_v^{(t)}, \\rho\\left(\\left\\lbrace m_{e}^{(t+1)} : ({u}, {v},{e}) \\in \\mathcal{E} \\right\\rbrace \\right) \\right).\n\n위 수식에서 :math:`\\phi` 는 각 에지에 대한 **메시지 함수** 로서 에지의 부속 노드(incident node)들의 피처를 그 에지 피처와 합쳐서 메시지를 만드는 역할을 수행한다. :math:`\\psi` 는 각 노드에 대한 **업데이트 함수** 로, **축약 함수(reduce function)** :math:`\\rho` 를 사용해서 전달된 메시지들을 통합하는 방식으로 노드의 피처를 업데이트한다.\n\n로드맵\n----\n\n이 장은 DGL의 메시지 전달 API들과, 노드와 에지에 효율적으로 적용하는 방법을 소개한다. 마지막 절에서는 이종 그래프에 메시지 전달을 어떻게 구현하는지 설명한다.\n\n* :ref:`guide_ko-message-passing-api`\n* :ref:`guide_ko-message-passing-efficient`\n* :ref:`guide_ko-message-passing-part`\n* :ref:`guide_ko-message-passing-edge`\n* :ref:`guide_ko-message-passing-heterograph`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    message-api\n    message-efficient\n    message-part\n    message-edge\n    message-heterograph\n"
  },
  {
    "path": "docs/source/guide_ko/minibatch-custom-sampler.rst",
    "content": ".. _guide_ko-minibatch-customizing-neighborhood-sampler:\n\n6.4 이웃 샘플러 커스터마이징하기\n-------------------------\n\n:ref:`(English Version) <guide-minibatch-customizing-neighborhood-sampler>`\n\nDGL이 여러 이웃 샘플링 방법들을 제공하지만, 샘플링 방법을 직접 만들어야할 경우도 있다. 이 절에서는 샘플링 방법을 직접 만드는 방법과 stochastic GNN 학습 프레임워크에서 사용하는 방법을 설명한다.\n\n`그래프 뉴럴 네트워크가 얼마나 강력한가(How Powerful are Graph Neural Networks) <https://arxiv.org/pdf/1810.00826.pdf>`__ 에서 설명했듯이, 메시지 전달은 다음과 같이 정의된다.\n\n.. math::\n\n\n   \\begin{gathered}\n     \\boldsymbol{a}_v^{(l)} = \\rho^{(l)} \\left(\n       \\left\\lbrace\n         \\boldsymbol{h}_u^{(l-1)} : u \\in \\mathcal{N} \\left( v \\right)\n       \\right\\rbrace\n     \\right)\n   \\\\\n     \\boldsymbol{h}_v^{(l)} = \\phi^{(l)} \\left(\n       \\boldsymbol{h}_v^{(l-1)}, \\boldsymbol{a}_v^{(l)}\n     \\right)\n   \\end{gathered}\n\n여기서, :math:`\\rho^{(l)}` 와 :math:`\\phi^{(l)}` 는 파라메터를 갖는 함수이고, :math:`\\mathcal{N}(v)`는 그래프 :math:`\\mathcal{G}` 에 속한 노드 :math:`v` 의 선행 노드(predecessor)들 (또는 방향성 그래프의 경우 *이웃 노드들*)의 집합을 의미한다.\n\n아래 그래프의 빨간색 노드를 업데이트하는 메시지 전달을 수행하기 위해서는,\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_4_0.png\n   :alt: Imgur\n\n아래 그림의 녹색으로 표시된 이웃 노드들의 노드 피쳐들을 합쳐야한다(aggregate).\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_4_1.png\n   :alt: Imgur\n\n이웃 샘플링 직접 해보기\n~~~~~~~~~~~~~~~~~~\n\n우선 위 그림의 그래프를 DGL 그래프로 정의한다.\n\n.. code:: python\n\n    import torch\n    import dgl\n\n    src = torch.LongTensor(\n        [0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10,\n         1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11])\n    dst = torch.LongTensor(\n        [1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11,\n         0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10])\n    g = dgl.graph((src, dst))\n\n그리고 노드 한개에 대한 결과를 계산하기 위해서 멀티-레이어 메시지 전달을 어떻게 수행할지를 고려하자. \n\n메시지 전달 의존성 찾기\n^^^^^^^^^^^^^^^^^\n\n아래 그래프에서 2-레이어 GNN을 사용해서 시드 노드 8의 결과를 계산하는 것을 생각해보자.\n\n.. 
figure:: https://data.dgl.ai/asset/image/guide_6_4_2.png\n   :alt: Imgur\n\n공식은 다음과 같다.\n\n.. math::\n\n\n   \\begin{gathered}\n     \\boldsymbol{a}_8^{(2)} = \\rho^{(2)} \\left(\n       \\left\\lbrace\n         \\boldsymbol{h}_u^{(1)} : u \\in \\mathcal{N} \\left( 8 \\right)\n       \\right\\rbrace\n     \\right) = \\rho^{(2)} \\left(\n       \\left\\lbrace\n         \\boldsymbol{h}_4^{(1)}, \\boldsymbol{h}_5^{(1)},\n         \\boldsymbol{h}_7^{(1)}, \\boldsymbol{h}_{11}^{(1)}\n       \\right\\rbrace\n     \\right)\n   \\\\\n     \\boldsymbol{h}_8^{(2)} = \\phi^{(2)} \\left(\n       \\boldsymbol{h}_8^{(1)}, \\boldsymbol{a}_8^{(2)}\n     \\right)\n   \\end{gathered}\n\n이 공식에 따르면, :math:`\\boldsymbol{h}_8^{(2)}` 을 계산하기 위해서는 아래 그림에서와 같이 (녹색으로 표시된) 노드 4,5,7 그리고 11번에서 에지을 따라서 메시지를 수집하는 것이 필요하다.\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_4_3.png\n   :alt: Imgur\n\n이 그래프는 원본 그래프의 모든 노드들을 포함하고 있지만, 특정 출력 노드들에 메시지를 전달할 에지들만을 포함하고 있다. 이런 그래프를 빨간색 노드 8에 대한 두번째 GNN 레이어에 대한 *프론티어(frontier)* 라고 부른다.\n\n프론티어들을 생성하는데 여러 함수들이 사용된다. 예를 들어, :func:`dgl.in_subgraph()` 는 원본 그래프의 모든 노드를 포함하지만, 특정 노드의 진입 에지(incoming edge)들만 포함하는 서브 그래프를 유도하는 함수이다.\n\n.. code:: python\n\n    frontier = dgl.in_subgraph(g, [8])\n    print(frontier.all_edges())\n\n전체 구현은 :ref:`api-subgraph-extraction` 와 :ref:`api-sampling` 를 참고하자.\n\n기술적으로는 원본 그래프와 같은 노들들 집합을 잡는 어떤 그래프도 프로티어가 될 수 있다. 이는 :ref:`guide_ko-minibatch-customizing-neighborhood-sampler-impl` 에 대한 기반이다.\n\n멀티-레이어 미니배치 메시지 전달을 위한 이분 구조(Bipartite Structure)\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n하지만, :math:`\\boldsymbol{h}_\\cdot^{(1)}` 로부터 단순히 :math:`\\boldsymbol{h}_8^{(2)}` 를 계산하는 것은 프론티어에서 메시지 전달을 계산하는 방식으로 할 수 없다. 그 이유는, 여전히 프론티어가 원본 그래프의 모든 노드를 포함하고 있기 때문이다. 이 그래프의 경우, (녹색과 빨간색 노드들) 4, 5, 7, 8, 11 노드들만이 입력으로 필요하고, 출력으로는 (빨간색 노드) 노드 8번이 필요하다. 입력과 출력의 노드 개수가 다르기 때문에, 작은 이분-구조(bipartite-structured) 그래프에서 메시지 전달을 수행할 필요가 있다.\n\n아래 그림은 노드 8에 대해서 2번째 GNN 레이어의 MFG를 보여준다.\n\n.. 
figure:: https://data.dgl.ai/asset/image/guide_6_4_4.png\n   :alt: Imgur\n\n.. note::\n\n   Message Flow Graph에 대한 개념은 :doc:`Stochastic Training Tutorial\n   <tutorials/large/L0_neighbor_sampling_overview>` 참고하자.\n\n목적지 노드들이 소스 노드에도 등장한다는 점을 유의하자. 그 이유는 메시지 전달(예를 들어, :math:`\\phi^{(2)}` )이 수행된 후에 이전 레이어의 목적지 노드들의 representation들이 피처를 합치는데 사용되기 때문이다.\n\nDGL은 임의의 프론티어를 MFG로 변환하는 :func:`dgl.to_block` 함수를 제공한다. 이 함수의 첫번째 인자는 프론티어이고, 두번째 인자는 목적지 노드들이다. 예를 들어, 위 프론티어는 목적지 노드 8에 대한 MFG로 전환하는 코드는 다음과 같다.\n\n.. code:: python\n\n    dst_nodes = torch.LongTensor([8])\n    block = dgl.to_block(frontier, dst_nodes)\n\n:meth:`dgl.DGLGraph.number_of_src_nodes` 와\n:meth:`dgl.DGLGraph.number_of_dst_nodes` 메소스들 사용해서 특정 노트 타입의 소스 노드 및 목적지 노드의 수를 알아낼 수 있다.\n\n.. code:: python\n\n    num_src_nodes, num_dst_nodes = block.number_of_src_nodes(), block.number_of_dst_nodes()\n    print(num_src_nodes, num_dst_nodes)\n\n:attr:`dgl.DGLGraph.srcdata` 와 :attr:`dgl.DGLGraph.srcnodes` 같은 멤머를 통해서 MFG의 소스 노드 피쳐들을 접근할 수 있고, :attr:`dgl.DGLGraph.dstdata` 와 :attr:`dgl.DGLGraph.dstnodes` 를 통해서는 목적지 노드의 피쳐들을 접근할 수 있다. ``srcdata`` / ``dstdata`` 와 ``srcnodes`` / ``dstnodes`` 의 사용법은 일반 그래프에 사용하는 :attr:`dgl.DGLGraph.ndata` 와 :attr:`dgl.DGLGraph.nodes` 와 동일하다.\n\n.. code:: python\n\n    block.srcdata['h'] = torch.randn(num_src_nodes, 5)\n    block.dstdata['h'] = torch.randn(num_dst_nodes, 5)\n\n만약 MFG가 프론티어에서 만들어졌다면, 즉 프래프에서 만들어졌다면, MFG의 소스 및 목적지 노드의 피쳐는 다음과 같이 직접 읽을 수 있다.\n\n.. code:: python\n\n    print(block.srcdata['x'])\n    print(block.dstdata['y'])\n\n.. note::\n\n   MFG에서의 소스 노드와 목적지 노드의 원본의 노드 ID는 ``dgl.NID`` 피쳐에 저장되어 있고, MFG의 에지 ID들와 프론티어의 에지 ID 사이의 매핑은 ``dgl.EID`` 에 있다.\n\nDGL에서는 MFG의 목적지 노드들이 항상 소스 노드에도 있도록 하고 있다. 다음 코드에서 알수 있듯이, 목적지 노드들은 소스 노드들에서 늘 먼저 위치한다.\n\n.. 
code:: python\n\n    src_nodes = block.srcdata[dgl.NID]\n    dst_nodes = block.dstdata[dgl.NID]\n    assert torch.equal(src_nodes[:len(dst_nodes)], dst_nodes)\n\n그 결과, 목적지 노드들은 프론티어의 에지들의 목적지인 모든 노들들을 포함해야 한다.\n\n예를 들어, 아래 프론티어를 생각해 보자.\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_4_5.png\n   :alt: Imgur\n\n여기서 빨간 노드와 녹색 노드들 (즉, 4, 5, 7, 8 그리고 11번 노드)는 에지의 목적지가 되는 노드들이다. 이 경우, 아래 코드는 에러를 발생시키는데, 이유는 목적지 노드 목록이 이들 노드를 모두 포함하지 않기 때문이다.\n\n.. code:: python\n\n    dgl.to_block(frontier2, torch.LongTensor([4, 5]))   # ERROR\n\n하지만, 목적지 노드들은 위 보다 더 많은 노드들을 포함할 수 있다. 이 예제의 경우, 어떤 에지도 연결되지 않은 고립된 노드들(isolated node)이 있고, 이 고립 노드들은 소스 노드와 목적지 노드 모두에 포함될 수 있다.\n\n.. code:: python\n\n    # Node 3 is an isolated node that do not have any edge pointing to it.\n    block3 = dgl.to_block(frontier2, torch.LongTensor([4, 5, 7, 8, 11, 3]))\n    print(block3.srcdata[dgl.NID])\n    print(block3.dstdata[dgl.NID])\n\nHeterogeneous 그래프들\n^^^^^^^^^^^^^^^^^^^^\n\nMFG들은 heterogeneous 그래프에도 적용됩니다. 다음 프론티어를 예로 들어보자.\n\n.. code:: python\n\n    hetero_frontier = dgl.heterograph({\n        ('user', 'follow', 'user'): ([1, 3, 7], [3, 6, 8]),\n        ('user', 'play', 'game'): ([5, 5, 4], [6, 6, 2]),\n        ('game', 'played-by', 'user'): ([2], [6])\n    }, num_nodes_dict={'user': 10, 'game': 10})\n\n목적지 노드들 User #3, #4, #8 그리고 Game #2, #6을 포함한 MFG를 생성한다.\n\n.. code:: python\n\n    hetero_block = dgl.to_block(hetero_frontier, {'user': [3, 6, 8], 'game': [2, 6]})\n\n소스 노드들과 목적지 노드들의 타입별로 얻을 수 있다.\n\n.. code:: python\n\n    # source users and games\n    print(hetero_block.srcnodes['user'].data[dgl.NID], hetero_block.srcnodes['game'].data[dgl.NID])\n    # destination users and games\n    print(hetero_block.dstnodes['user'].data[dgl.NID], hetero_block.dstnodes['game'].data[dgl.NID])\n\n\n.. _guide_ko-minibatch-customizing-neighborhood-sampler-impl:\n\n커스텀 이웃 샘플러 구현하기\n~~~~~~~~~~~~~~~~~~~~\n\n아래 코드는 노드 분류를 위한 이웃 샘플링을 수행한다는 것을 떠올려 보자.\n\n.. 
code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n\n이웃 샘플링 전략을 직접 구현하기 위해서는 ``sampler`` 를 직접 구현한 내용으로 바꾸기만 하면 된다. 이를 살펴보기 위해서, 우선 :class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler` 를 상속한 클래스인 :class:`~dgl.dataloading.dataloader.BlockSampler` 를 살펴보자.\n\n:class:`~dgl.dataloading.dataloader.BlockSampler` 클래스는 :meth:`~dgl.dataloading.dataloader.BlockSampler.sample_blocks` 메소드를 통해서 마지막 레이어로부터 시작하는 MFG들의 리스트를 만들어내는 역할을 한다. ``sample_blocks`` 의 기본 구현은 프론티어들과 그것들을 MFG들로 변환하면서 backwards를 iterate한다.\n\n따라서, 이웃 샘플링을 하기 위해서 단지 :meth:`~dgl.dataloading.dataloader.BlockSampler.sample_frontier` **메소드** 를 **구현하기만 하면된다**. 어떤 레이어를 위한 프론티어를 생성할 것인지, 원본 그래프, representation들을 계산할 노드들이 주어지면, 이 메소드는 그것들을 위한 프론티어를 생성하는것을 담당한다.\n\nGNN 레이어 수를 상위 클래스에 전달해야 한다.\n\n예를 들어, :class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler` 구현은 다음과 같다.\n\n.. code:: python\n\n    class MultiLayerFullNeighborSampler(dgl.dataloading.BlockSampler):\n        def __init__(self, n_layers):\n            super().__init__(n_layers)\n    \n        def sample_frontier(self, block_id, g, seed_nodes):\n            frontier = dgl.in_subgraph(g, seed_nodes)\n            return frontier\n\n:class:`dgl.dataloading.neighbor.MultiLayerNeighborSampler` 는 더 복잡한 이웃 샘플러로, 각 노들에 대해서 메시지를 수집할 적은 수의 이웃 노드들을 샘플하는 기능을 하는데, 구현은 다음과 같다.\n\n.. 
code:: python\n\n    class MultiLayerNeighborSampler(dgl.dataloading.BlockSampler):\n        def __init__(self, fanouts):\n            super().__init__(len(fanouts))\n    \n            self.fanouts = fanouts\n    \n        def sample_frontier(self, block_id, g, seed_nodes):\n            fanout = self.fanouts[block_id]\n            if fanout is None:\n                frontier = dgl.in_subgraph(g, seed_nodes)\n            else:\n                frontier = dgl.sampling.sample_neighbors(g, seed_nodes, fanout)\n            return frontier\n\n위의 함수는 프론티어를 생성하지만, 원본 그래프와 같은 노들을 갖는 어떤 그래프도 프론티어로 사용될 수 있다.\n\n예를 들어, 주어진 확률에 따라서 시드 노드들에 연결되는 인바운드 에지를 임의로 삭제하기를 원한다면, 다음과 같이 샘플러를 정의할 수 있다.\n\n.. code:: python\n\n    class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):\n        def __init__(self, p, num_layers):\n            super().__init__(num_layers)\n    \n            self.p = p\n    \n        def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):\n            # Get all inbound edges to `seed_nodes`\n            src, dst = dgl.in_subgraph(g, seed_nodes).all_edges()\n            # Randomly select edges with a probability of p\n            mask = torch.zeros_like(src).bernoulli_(self.p)\n            src = src[mask]\n            dst = dst[mask]\n            # Return a new graph with the same nodes as the original graph as a\n            # frontier\n            frontier = dgl.graph((src, dst), num_nodes=g.num_nodes())\n            return frontier\n    \n        def __len__(self):\n            return self.num_layers\n\n샘플러를 직접 구현한 다음에는, 그 샘플러를 사용하는 데이터 로더를 생성하고, 예전과 같이 시드 노드들을 iterate하면서 MFG들의 리스트를 만들게 한다.\n\n.. 
code:: python\n\n    sampler = MultiLayerDropoutSampler(0.5, 2)\n    dataloader = dgl.dataloading.NodeDataLoader(\n        g, train_nids, sampler,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n    \n    model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        input_features = blocks[0].srcdata     # returns a dict\n        output_labels = blocks[-1].dstdata     # returns a dict\n        output_predictions = model(blocks, input_features)\n        loss = compute_loss(output_labels, output_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\nHeterogeneous 그래프들\n^^^^^^^^^^^^^^^^^^^^\n\nHeterogeneous 그래프에 대한 프론티어를 생성하는 것은 homogeneous 그래프의 경우와 동일하다. 리턴된 그래프가 원본 그래프와 같은 노드들을 갖도록 하면, 나머지는 그대로 동작할 것이다. 예를 들어, 위 ``MultiLayerDropoutSampler`` 를 재작성해서 모든 에지 타입들을 iterate 해서, heterogeneous 그래프에도 작동하게 만들 수 있다.\n\n.. 
code:: python\n\n    class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):\n        def __init__(self, p, num_layers):\n            super().__init__(num_layers)\n    \n            self.p = p\n    \n        def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):\n            # Get all inbound edges to `seed_nodes`\n            sg = dgl.in_subgraph(g, seed_nodes)\n    \n            new_edges_masks = {}\n            # Iterate over all edge types\n            for etype in sg.canonical_etypes:\n                edge_mask = torch.zeros(sg.num_edges(etype))\n                edge_mask.bernoulli_(self.p)\n                new_edges_masks[etype] = edge_mask.bool()\n    \n            # Return a new graph with the same nodes as the original graph as a\n            # frontier\n            frontier = dgl.edge_subgraph(new_edges_masks, relabel_nodes=False)\n            return frontier\n    \n        def __len__(self):\n            return self.num_layers\n"
  },
  {
    "path": "docs/source/guide_ko/minibatch-edge.rst",
    "content": ".. _guide_ko-minibatch-edge-classification-sampler:\n\n6.2 이웃 샘플링을 사용한 에지 분류 GNN 모델 학습하기\n-----------------------------------------\n\n:ref:`(English Version) <guide-minibatch-edge-classification-sampler>`\n\n에지 분류/리그레션 모델을 학습하는 것은 몇 가지 눈에 띄는 차이점이 있지만 노드 분류/리그레션과 어느정도 비슷하다.\n\n이웃 샘플러 및 데이터 로더 정의하기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n:ref:`노드 분류에서 사용한 것과 같은 이웃 샘플러<guide_ko-minibatch-node-classification-sampler>` 를 사용할 수 있다.\n\n.. code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n\n에지 분류에 DGL이 제공하는 이웃 샘플러를 사용하려면, 미니-배치의 에지들의 집합을 iterate 하는 :class:`~dgl.dataloading.pytorch.EdgeDataLoader` 와 함께 사용해야한다. 이것은 아래 모듈에서 사용될 에지 미니-배치로부터 만들어질 서브 그래프와 *message flow graph* (MFG)들을 리턴한다.\n\n다음 코드 예제는 PyTorch DataLoader를 만든다. 이는 베치들에 있는 학습 에지 ID 배열 :math:`train_eids` 들을 iterate 하고, 생성된 MFG들의 리스트를 GPU로 옮겨놓는다.\n\n.. code:: python\n\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n.. note::\n\n   Message flow graph의 개념은 :doc:`Stochastic Training Tutorial <tutorials/large/L0_neighbor_sampling_overview>` 를 참고하자.\n\n   빌트인으로 지원되는 샘플러들에 대한 전체 목록은 :ref:`neighborhood sampler API reference <api-dataloading-neighbor-sampling>` 에 있다.\n\n   :ref:`guide_ko-minibatch-customizing-neighborhood-sampler` 에는 여러분만의 이웃 샘플러 만드는 방법과 MFG 개념에 대한 보다 상세한 설명을 담고 있다.\n\n이웃 샘플링을 위해서 원본 그래프에서 미니 배치의 에지들 제거하기\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n에지 분류 모델을 학습할 때, 때로는 computation dependency에서 학습 데이터에 있는 에지들을 존재하지 않았던 것처럼 만들기 위해 제거하는 것이 필요하다. 그렇지 않으면, 모델은 두 노드들 사이에 에지가 존재한다는 사실을 *인지* 할 것이고, 이 정보를 학습에 잠재적으로 이용할 수 있기 때문이다.\n\n따라서, 에지 분류의 경우 때로는 이웃 샘플링은 미니-배치안에 샘플된 에지들 및 undirected 그래프인 경우 샘플된 에지의 역방향 에지들도 원본 그래프에서 삭제하기도 한다. :class:`~dgl.dataloading.pytorch.EdgeDataLoader` 객체를 만들 때, ``exclude='reverse_id'`` 를 에지 ID와 그와 연관된 reverse 에지 ID들의 매핑 정보와 함께 지정할 수 있다.\n\n.. 
code:: python\n\n    n_edges = g.num_edges()\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n    \n        # The following two arguments are specifically for excluding the minibatch\n        # edges and their reverse edges from the original graph for neighborhood\n        # sampling.\n        exclude='reverse_id',\n        reverse_eids=torch.cat([\n            torch.arange(n_edges // 2, n_edges), torch.arange(0, n_edges // 2)]),\n    \n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n모델을 미니-배치 학습에 맞게 만들기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n에지 분류 모델은 보통은 다음과 같이 두 부분으로 구성된다:\n\n- 첫번째는 부속 노드(incident node)들의 representation을 얻는 부분\n- 두번째는 부속 노드의 representation들로부터 에지 점수를 계산하는 부분\n\n첫번째 부분은 :ref:`노드 분류<guide_ko-minibatch-node-classification-model>` 와 완전히 동일하기에, 단순하게 이를 재사용할 수 있다. 입력 DGL에서 제공하는 데이터 로더가 만들어 낸 MFG들의 리스트와 입력 피쳐들이 된다.\n\n.. code:: python\n\n    class StochasticTwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.conv1 = dglnn.GraphConv(in_features, hidden_features)\n            self.conv2 = dglnn.GraphConv(hidden_features, out_features)\n    \n        def forward(self, blocks, x):\n            x = F.relu(self.conv1(blocks[0], x))\n            x = F.relu(self.conv2(blocks[1], x))\n            return x\n\n두번째 부분에 대한 입력은 보통은 이전 부분의 출력과 미니배치의 에지들에 의해서 유도된 원본 그래프의 서브 그래프가 된다. 서브 그래프는 같은 데이터 로더에서 리턴된다. :meth:`dgl.DGLGraph.apply_edges` 를 사용해서 에지 서브 그래프를 사용해서 에지들의 점수를 계산한다.\n\n다음 코드는 부속 노드 피처들을 연결하고, 이를 dense 레이어에 입력해서 얻은 결과로 에지들의 점수를 예측하는 예를 보여준다.\n\n.. 
code:: python\n\n    class ScorePredictor(nn.Module):\n        def __init__(self, num_classes, in_features):\n            super().__init__()\n            self.W = nn.Linear(2 * in_features, num_classes)\n    \n        def apply_edges(self, edges):\n            data = torch.cat([edges.src['x'], edges.dst['x']], 1)\n            return {'score': self.W(data)}\n    \n        def forward(self, edge_subgraph, x):\n            with edge_subgraph.local_scope():\n                edge_subgraph.ndata['x'] = x\n                edge_subgraph.apply_edges(self.apply_edges)\n                return edge_subgraph.edata['score']\n\n전체 모델은 아래와 같이 데이터 로더로부터 얻은 MFG들의 리스트와 에지 서브 그래프, 그리고 입력 노드 피쳐들을 사용한다.\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, num_classes):\n            super().__init__()\n            self.gcn = StochasticTwoLayerGCN(\n                in_features, hidden_features, out_features)\n            self.predictor = ScorePredictor(num_classes, out_features)\n    \n        def forward(self, edge_subgraph, blocks, x):\n            x = self.gcn(blocks, x)\n            return self.predictor(edge_subgraph, x)\n\nDGL에서는 에지 서브 그래프의 노드들이 MFG들의 리스트에서 마지막 MFG의 출력 노드들과 동일하도록 확인한다.\n\n학습 룹\n~~~~~\n\n학습 룹은 노드 분류의 학습 룹과 비슷하다. 데이터 로더를 iterate해서, 미니배치의 에지들에 의해서 유도된 서브 그래프와 에지들의 부속 노드(incident node)들의 representation들을 계산하기 위한 MFG들의 목록을 얻는다.\n\n.. 
code:: python\n\n    model = Model(in_features, hidden_features, out_features, num_classes)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, edge_subgraph, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        edge_subgraph = edge_subgraph.to(torch.device('cuda'))\n        input_features = blocks[0].srcdata['features']\n        edge_labels = edge_subgraph.edata['labels']\n        edge_predictions = model(edge_subgraph, blocks, input_features)\n        loss = compute_loss(edge_labels, edge_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\nHeterogeneous 그래프의 경우\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nHeterogeneous 그래프들의 노드 representation들을 계산하는 모델은 에지 분류/리그레션을 위한 부속 노드 representation들을 구하는데 사용될 수 있다.\n\n.. code:: python\n\n    class StochasticTwoLayerRGCN(nn.Module):\n        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):\n            super().__init__()\n            self.conv1 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')\n                    for rel in rel_names\n                })\n            self.conv2 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')\n                    for rel in rel_names\n                })\n    \n        def forward(self, blocks, x):\n            x = self.conv1(blocks[0], x)\n            x = self.conv2(blocks[1], x)\n            return x\n\n점수를 예측하기 위한 homogeneous 그래프와 heterogeneous 그래프간의 유일한 구현상의 차이점은 :meth:`~dgl.DGLGraph.apply_edges` 를 호출할 때 에지 타입들을 사용한다는 점이다.\n\n.. 
code:: python\n\n    class ScorePredictor(nn.Module):\n        def __init__(self, num_classes, in_features):\n            super().__init__()\n            self.W = nn.Linear(2 * in_features, num_classes)\n    \n        def apply_edges(self, edges):\n            data = torch.cat([edges.src['x'], edges.dst['x']], 1)\n            return {'score': self.W(data)}\n    \n        def forward(self, edge_subgraph, x):\n            with edge_subgraph.local_scope():\n                edge_subgraph.ndata['x'] = x\n                for etype in edge_subgraph.canonical_etypes:\n                    edge_subgraph.apply_edges(self.apply_edges, etype=etype)\n                return edge_subgraph.edata['score']\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, num_classes,\n                     etypes):\n            super().__init__()\n            self.rgcn = StochasticTwoLayerRGCN(\n                in_features, hidden_features, out_features, etypes)\n            self.pred = ScorePredictor(num_classes, out_features)\n\n        def forward(self, edge_subgraph, blocks, x):\n            x = self.rgcn(blocks, x)\n            return self.pred(edge_subgraph, x)\n\n데이터 로더 구현도 노드 분류을 위한 것과 아주 비슷하다. 유일한 차이점은 :class:`~dgl.dataloading.pytorch.NodeDataLoader` 대신에 :class:`~dgl.dataloading.pytorch.EdgeDataLoader` 를 사용하고, 노드 타입과 노드 ID 텐서들의 사전 대신에 에지 타입과 에지 ID 텐서들의 사전을 사용한다는 것이다.\n\n.. code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n만약 heterogeneous 그래프에서 역방향의 에지를 배제하고자 한다면 약간 달라진다. Heterogeneous 그래프에서 역방향 에지들은 에지와는 다른 에지 타입을 갖는 것이 보통이다. 이는 “forward”와 “backward” 관계들을 구분직기 위해서이다. 
(즉, ``follow`` 와 ``followed by`` 는 서로 역 관계이고, ``purchase`` 와 ``purchased by`` 는 서로 역 관계인 것처럼)\n\n만약 어떤 타입의 에지들이 다른 타입의 같은 ID를 갖는 역방향 에지를 갖는다면, 에지 타입들과 \n그것들의 반대 타입간의 매핑을 명시할 수 있다. 미니배치에서 에지들과 그것들의 역방향 에지를 배제하는 것은\n다음과 같다.\n\n.. code:: python\n\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n    \n        # The following two arguments are specifically for excluding the minibatch\n        # edges and their reverse edges from the original graph for neighborhood\n        # sampling.\n        exclude='reverse_types',\n        reverse_etypes={'follow': 'followed by', 'followed by': 'follow',\n                        'purchase': 'purchased by', 'purchased by': 'purchase'},\n    \n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n학습 룹은 ``compute_loss`` 의 구현이 노드 타입들과 예측 값에 대한 두 사전들을 인자로 받는다는 점을 제외하면,\nhomogeneous 그래프의 학습 룹 구현과 거의 같다.\n\n.. code:: python\n\n    model = Model(in_features, hidden_features, out_features, num_classes, etypes)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, edge_subgraph, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        edge_subgraph = edge_subgraph.to(torch.device('cuda'))\n        input_features = blocks[0].srcdata['features']\n        edge_labels = edge_subgraph.edata['labels']\n        edge_predictions = model(edge_subgraph, blocks, input_features)\n        loss = compute_loss(edge_labels, edge_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n`GCMC <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcmc>`__ 은 이분 그래프(bipartite graph)에 대한 에지 분류 예제이다.\n\n"
  },
  {
    "path": "docs/source/guide_ko/minibatch-gpu-sampling.rst",
    "content": ".. _guide_ko-minibatch-gpu-sampling:\n\n6.7 이웃 샘플링에 GPU 사용하기\n------------------------\n\n:ref:`(English Version) <guide-minibatch-gpu-sampling>`\n\nDGL 0.7부터 GPU 기반의 이웃 샘플링을 지원하는데, 이는 CPU 기반의 이웃 샘플링에 비해서 상당한 속도 향상을 가져다 준다. 만약 다루는 그래프와 피쳐들이 GPU에 들어갈 수 있는 크기이고, 모델이 너무 많은 GPU 메모리를 차지하지 않는다면, GPU 메모리에 올려서 GPU 기반의 이웃 샘플링을 하는 것이 최선의 방법이다.\n\n예를 들어, `OGB Products <https://ogb.stanford.edu/docs/nodeprop/#ogbn-products>`__ 는 2.4M 노드들과 61M 에지들을 갖고, 각 노드는 100 차원의 피쳐를 갖는다. 노트 피쳐들을 모두 합해서 1GB 미만의 메모리를 차지하고, 그래프는 약 1GB 보다 적은 메모리를 사용한다. 그래프의 메모리 요구량은 에지의 개수에 관련이 있다. 따라서, 전체 그래프를 GPU에 로딩하는 것이 가능하다.\n\n.. note::\n\n   이 기능은 실험적인 것으로 개발이 진행 중이다. 추가 업데이트를 지켜보자.\n\nDGL 데이터 로더에서 GPU 기반의 이웃 샘플링 사용하기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\nDGL 데이터 로더에서 GPU 기반의 이웃 샘플링은 다음 방법으로 사용할 수 있다.\n\n* 그래프를 GPU에 넣기\n* ``num_workers`` 인자를 0으로 설정하기. CUDA는 같은 context를 사용하는 멀티 프로세스를 지원하지 않기 때문이다.\n* ``device`` 인자를 GPU 디바이스로 설정하기\n\n:class:`~dgl.dataloading.pytorch.NodeDataLoader` 의 다른 모든 인자들은 다른 가이드와 튜토리얼에서 사용한 것돠 같다.\n\n.. code:: python\n\n   g = g.to('cuda:0')\n   dataloader = dgl.dataloading.NodeDataLoader(\n       g,                                # The graph must be on GPU.\n       train_nid,\n       sampler,\n       device=torch.device('cuda:0'),    # The device argument must be GPU.\n       num_workers=0,                    # Number of workers must be 0.\n       batch_size=1000,\n       drop_last=False,\n       shuffle=True)\n\nGPU 기반의 이웃 샘플링은 커스텀 이웃 샘플러가 두가지 조건을 충족하면 동작한다. (1) 커스텀 샘플러가 :class:`~dgl.dataloading.BlockSampler` 의 서브 클래스이고, (2) 샘플러가 GPU에서 완전하게 동작한다.\n\n.. note::\n\n   현재는 :class:`~dgl.dataloading.pytorch.EdgeDataLoader` 와 heterogeneous 그래프는 지원하지 않는다.\n\nGPU 기반의 이웃 샘플러를 DGL 함수와 함께 사용하기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n다음 함수들은 GPU에서 작동을 지원한다.\n\n* :func:`dgl.sampling.sample_neighbors`\n\n  * 균일 샘플링(uniform sampling)만 지원함. non-uniform샘플링은 CPU에서만 동작함.\n\n위 함수들 이외의 GPU에서 동작하는 함수들은 :func:`dgl.to_block` 를 참고하자."
  },
  {
    "path": "docs/source/guide_ko/minibatch-inference.rst",
    "content": ".. _guide_ko-minibatch-inference:\n\n6.6 큰 그래프들에 대핸 정확한 오프라인 추론\n---------------------------------\n\n:ref:`(English Version) <guide-minibatch-inference>`\n\nGPU를 사용해서 GNN을 학습하는데 메모리와 걸리는 시간을 줄이기 위해서 서브 샘플링과 이웃 샘플링이 모두 사용된다. 추론을 수행할 때 보통은 샘플링으로 발생할 수 있는 임의성을 제거하기 위해서 전체 이웃들에 대해서 aggretate하는 것이 더 좋다. 하지만, GPU 메모리 제약이나, CPU의 느린 속도 때문에 전체 그래프에 대한 forward propagagtion을 수행하는 것은 쉽지 않다. 이 절은 미니배치와 이웃 샘플링을 통해서 제한적인 GPU를 사용한 전체 그래프 forward propagation의 방법을 소개한다.\n\n추론 알고리즘은 학습 알고리즘과는 다른데, 추론 알고리즘은 첫번째 레이어부터 시작해서 각 레이이별로 모든 노드의 representation들을 계산해야하기 때문이다. 특히, 특정 레이어의 경우에 우리는 미니배치의 모든 노드들에 대해서 이 레이어의 출력 representation을 계산해야한다. 그 결과, 추론 알고리즘은 모든 레이어들 iterate하는 outer 룹과 노들들의 미니배치를 iterate하는 inner 룹을 갖는다. 반면, 학습 알고리즘은 노드들의 미니배치를 iterate하는 outer 룹과, 이웃 샘플링과 메시지 전달을 위한 레이어들을 iterate하는 inner 룹을 갖는다.\n\n아래 애니매이션은 이 연산이 어떻게 일어나는지를 보여주고 있다 (각 레이어에 대해서 첫 3개의 미니배치만 표현되고 있음을 주의하자)\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_6_0.gif\n   :alt: Imgur\n\n오프라인 추론 구현하기\n~~~~~~~~~~~~~~~~\n\n6.1 :ref:`guide_ko-minibatch-node-classification-model` 에서 다룬 2-레이어 GCN을 생각해 보자. 오프라인 추론을 구현하는 방법은 여전히 :class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler` 를 사용하지만, 한번에 하나의 레이어에 대한 샘플링을 수행한다. 하나의 레이어에 대한 계산은 메시지들어 어떻게 aggregate되고 합쳐지는지에 의존하기 때문에 오프라인 추론은 GNN 모듈의 메소드로 구현된다는 점을 주목하자.\n\n.. 
code:: python\n\n    class StochasticTwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.hidden_features = hidden_features\n            self.out_features = out_features\n            self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)\n            self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)\n            self.n_layers = 2\n    \n        def forward(self, blocks, x):\n            x_dst = x[:blocks[0].number_of_dst_nodes()]\n            x = F.relu(self.conv1(blocks[0], (x, x_dst)))\n            x_dst = x[:blocks[1].number_of_dst_nodes()]\n            x = F.relu(self.conv2(blocks[1], (x, x_dst)))\n            return x\n    \n        def inference(self, g, x, batch_size, device):\n            \"\"\"\n            Offline inference with this module\n            \"\"\"\n            # Compute representations layer by layer\n            for l, layer in enumerate([self.conv1, self.conv2]):\n                y = torch.zeros(g.num_nodes(),\n                                self.hidden_features\n                                if l != self.n_layers - 1\n                                else self.out_features)\n                sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n                dataloader = dgl.dataloading.NodeDataLoader(\n                    g, torch.arange(g.num_nodes()), sampler,\n                    batch_size=batch_size,\n                    shuffle=True,\n                    drop_last=False)\n                \n                # Within a layer, iterate over nodes in batches\n                for input_nodes, output_nodes, blocks in dataloader:\n                    block = blocks[0]\n    \n                    # Copy the features of necessary input nodes to GPU\n                    h = x[input_nodes].to(device)\n                    # Compute output.  
Note that this computation is the same\n                    # but only for a single layer.\n                    h_dst = h[:block.number_of_dst_nodes()]\n                    h = F.relu(layer(block, (h, h_dst)))\n                    # Copy to output back to CPU.\n                    y[output_nodes] = h.cpu()\n\n                x = y\n    \n            return y\n\n모델 선택을 위해서 검증 데이터셋에 평가 metric을 계산하는 목적으로 정확한 오프라인 추론을 계산할 필요가 없다는 점을 주목하자. 모든 레이어에 대해서 모든 노드들의 representation을 계산하는 것이 필요한데, 이것은 레이블이 없는 데이터가 많은 semi-supervised 영역에서는 아주 많은 리소스를 필요로하기 때문이다. 이웃 샘플링은 모델 선택 및 평가 목적으로는 충분하다.\n\n오프라인 추론의 예들로 `GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling.py>`__ 및 \n`RGCN <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify_mb.py>`__ 를 참고하자.\n"
  },
  {
    "path": "docs/source/guide_ko/minibatch-link.rst",
    "content": ".. _guide_ko-minibatch-link-classification-sampler:\n\n6.3 이웃 샘플링을 사용한 링크 예측 GNN 모델 학습하기\n-----------------------------------------\n\n:ref:`(English Version) <guide-minibatch-link-classification-sampler>`\n\nNegative 샘플링을 사용한 이웃 샘플러 및 데이터 로더 정의하기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n노드/에지 분류에서 사용한 이웃 샘플러를 그대로 사용하는 것이 가능하다.\n\n.. code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n\nDGL의 :class:`~dgl.dataloading.pytorch.EdgeDataLoader` 는 링크 예측를 위한 negative 샘플 생성을\n지원한다. 이를 사용하기 위해서는, negative 샘플링 함수를 제공해야한다. :class:`~dgl.dataloading.negative_sampler.Uniform` 은 uniform 샘플링을 해주는 함수이다. 에지의 각 소스 노드에 대해서,이 함수는 ``k`` 개의 negative 목적지 노드들을 샘플링한다.\n\n아래 코드는 에지의 각 소스 노드에 대해서 5개의 negative 목적지 노드를 균등하게 선택한다.\n\n.. code:: python\n\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_seeds, sampler,\n        negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        pin_memory=True,\n        num_workers=args.num_workers)\n\n빌드인 negative 샘플러들은 :ref:`api-dataloading-negative-sampling` 에서 확인하자.\n\n직접 만든 negative 샘플러 함수를 사용할 수도 있다. 이 함수는 원본 그래프 ``g`` 와, 미니배치 에지 ID 배열 ``eid`` 를 받아서\n소스 ID 배열과 목적지 ID 배열의 쌍을 리턴해야 한다.\n\n아래 코드 예제는 degree의 거듭제곱에 비례하는 확률 분포에 따라서 negative 목적지 노드들을 샘플링하는 custom negative 샘플러다.\n\n.. 
code:: python\n\n    class NegativeSampler(object):\n        def __init__(self, g, k):\n            # caches the probability distribution\n            self.weights = g.in_degrees().float() ** 0.75\n            self.k = k\n    \n        def __call__(self, g, eids):\n            src, _ = g.find_edges(eids)\n            src = src.repeat_interleave(self.k)\n            dst = self.weights.multinomial(len(src), replacement=True)\n            return src, dst\n    \n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_seeds, sampler,\n        negative_sampler=NegativeSampler(g, 5),\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        pin_memory=True,\n        num_workers=args.num_workers)\n\n모델을 미니-배치 학습에 맞게 만들기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n:ref:`guide_ko-training-link-prediction` 에서 설명한 것처럼, 링크 예측은 (positive 예제인) 에지의 점수와 존재하지 않는 에지(즉, negative 예제)의 점수를 비교하는 것을 통해서 학습될 수 있다. 에지들의 점수를 계산하기 위해서, 에지 분류/리그레션에서 사용했던 노드 representation 계산 모델을 재사용한다.\n\n.. code:: python\n\n    class StochasticTwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)\n            self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)\n    \n        def forward(self, blocks, x):\n            x = F.relu(self.conv1(blocks[0], x))\n            x = F.relu(self.conv2(blocks[1], x))\n            return x\n\n점수 예측을 위해서 확률 분포 대신 각 에지의 scalar 점수를 예측하기만 하면되기 때문에, 이 예제는 부속 노드 representation들의 dot product로 점수를 계산하는 방법을 사용한다.\n\n.. 
code:: python\n\n    class ScorePredictor(nn.Module):\n        def forward(self, edge_subgraph, x):\n            with edge_subgraph.local_scope():\n                edge_subgraph.ndata['x'] = x\n                edge_subgraph.apply_edges(dgl.function.u_dot_v('x', 'x', 'score'))\n                return edge_subgraph.edata['score']\n\nNegative 샘플러가 지정되면, DGL의 데이터 로더는 미니배치 마다 다음 3가지 아이템들을 만들어낸다.\n\n- 샘플된 미니배치에 있는 모든 에지를 포함한 postive 그래프\n- Negative 샘플러가 생성한 존재하지 않는 에지 모두를 포함한 negative 그래프\n- 이웃 샘플러가 생성한 *message flow graph* (MFG)들의 리스트\n\n이제 3가지 아이템와 입력 피쳐들을 받는 링크 예측 모델을 다음과 같이 정의할 수 있다.\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.gcn = StochasticTwoLayerGCN(\n                in_features, hidden_features, out_features)\n    \n        def forward(self, positive_graph, negative_graph, blocks, x):\n            x = self.gcn(blocks, x)\n            pos_score = self.predictor(positive_graph, x)\n            neg_score = self.predictor(negative_graph, x)\n            return pos_score, neg_score\n\n학습 룹\n~~~~~\n\n학습 룹은 데이터 로더를 iterate하고, 그래프들과 입력 피쳐들을 위해서 정의한 모델에 입력하는 것일 뿐이다.\n\n.. 
code:: python\n\n    def compute_loss(pos_score, neg_score):\n        # an example hinge loss\n        n = pos_score.shape[0]\n        return (neg_score.view(n, -1) - pos_score.view(n, -1) + 1).clamp(min=0).mean()\n\n    model = Model(in_features, hidden_features, out_features)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, positive_graph, negative_graph, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        positive_graph = positive_graph.to(torch.device('cuda'))\n        negative_graph = negative_graph.to(torch.device('cuda'))\n        input_features = blocks[0].srcdata['features']\n        pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)\n        loss = compute_loss(pos_score, neg_score)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\nDGL에서는 homogeneous 그래프들에 대한 링크 예측의 예제로 `unsupervised learning GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling_unsupervised.py>`__ 를 제공한다.\n\nHeterogeneous 그래프의 경우\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nHeterogeneous 그래프들의 노드 representation들을 계산하는 모델은 에지 분류/리그레션을 위한 부속 노드\nrepresentation들을 구하는데 사용될 수 있다.\n\n.. 
code:: python\n\n    class StochasticTwoLayerRGCN(nn.Module):\n        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):\n            super().__init__()\n            self.conv1 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')\n                    for rel in rel_names\n                })\n            self.conv2 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')\n                    for rel in rel_names\n                })\n    \n        def forward(self, blocks, x):\n            x = self.conv1(blocks[0], x)\n            x = self.conv2(blocks[1], x)\n            return x\n\n점수를 예측하기 위한 homogeneous 그래프와 heterogeneous 그래프간의 유일한 구현상의 차이점은\n:meth:`dgl.DGLGraph.apply_edges` 를 호출할 때 에지 타입들을 사용한다는 점이다.\n\n.. code:: python\n\n    class ScorePredictor(nn.Module):\n        def forward(self, edge_subgraph, x):\n            with edge_subgraph.local_scope():\n                edge_subgraph.ndata['x'] = x\n                for etype in edge_subgraph.canonical_etypes:\n                    edge_subgraph.apply_edges(\n                        dgl.function.u_dot_v('x', 'x', 'score'), etype=etype)\n                return edge_subgraph.edata['score']\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, num_classes,\n                     etypes):\n            super().__init__()\n            self.rgcn = StochasticTwoLayerRGCN(\n                in_features, hidden_features, out_features, etypes)\n            self.pred = ScorePredictor()\n\n        def forward(self, positive_graph, negative_graph, blocks, x):\n            x = self.rgcn(blocks, x)\n            pos_score = self.pred(positive_graph, x)\n            neg_score = self.pred(negative_graph, x)\n            return pos_score, neg_score\n\n데이터 로더 구현도 노드 분류을 위한 것과 아주 비슷하다. 유일한 차이점은 negative 샘플러를 사용하며, 노드 타입과 노드 ID 텐서들의 사전 대신에 에지 타입과 에지 ID 텐서들의 사전을 사용한다는 것이다.\n\n.. 
code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n        negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n만약 직접 만든 negative 샘플링 함수를 사용하기를 원한다면, 그 함수는 원본 그래프, 에지 타입과 에지 ID 텐서들의 dictionary를 인자로 받아야하고, 에지 타입들과 소스-목적지 배열 쌍의 dictionary를 리턴해야한다. 다음은 예제 함수이다.\n\n.. code:: python\n\n   class NegativeSampler(object):\n       def __init__(self, g, k):\n           # caches the probability distribution\n           self.weights = {\n               etype: g.in_degrees(etype=etype).float() ** 0.75\n               for etype in g.canonical_etypes}\n           self.k = k\n\n       def __call__(self, g, eids_dict):\n           result_dict = {}\n           for etype, eids in eids_dict.items():\n               src, _ = g.find_edges(eids, etype=etype)\n               src = src.repeat_interleave(self.k)\n               dst = self.weights[etype].multinomial(len(src), replacement=True)\n               result_dict[etype] = (src, dst)\n           return result_dict\n\n다음으로는 에지 타입들와 에지 ID들의 dictionary와 negative 샘플러를 데이터 로더에 전달한다. 예를 들면, 아래 코드는 heterogeneous 그래프의 모든 에지들을 iterate하는 예이다.\n\n.. code:: python\n\n    train_eid_dict = {\n        etype: g.edges(etype=etype, form='eid')\n        for etype in g.canonical_etypes}\n\n    dataloader = dgl.dataloading.EdgeDataLoader(\n        g, train_eid_dict, sampler,\n        negative_sampler=NegativeSampler(g, 5),\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n학습 룹은 ``compute_loss`` 의 구현이 노드 타입들과 예측 값에 대한 두 사전들을 인자로 받는다는 점을 제외하면, homogeneous 그래프의 학습 룹 구현과 거의 같다.\n\n.. 
code:: python\n\n    model = Model(in_features, hidden_features, out_features, num_classes, etypes)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, positive_graph, negative_graph, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        positive_graph = positive_graph.to(torch.device('cuda'))\n        negative_graph = negative_graph.to(torch.device('cuda'))\n        input_features = blocks[0].srcdata['features']\n        pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)\n        loss = compute_loss(pos_score, neg_score)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n\n\n"
  },
  {
    "path": "docs/source/guide_ko/minibatch-nn.rst",
    "content": ".. _guide_ko-minibatch-custom-gnn-module:\n\n6.5 미니-배치 학습을 위한 커스텀 GNN 모듈 구현하기\n----------------------------------------\n\n:ref:`(English Version) <guide-minibatch-custom-gnn-module>`\n\nHomogeneous 그래프나 heterogeneous 그래프를 대상으로 전체 그래프를 업데이트하는 커스텀 GNN 모듈을 만드는 것에 익숙하다면, MFG에 대한 연산을 구현하는 코드도 비슷하다는 것을 알 수 있다. 차이점은 노드들이 입력 노드와 출력 노드로 나뉜다는 것 뿐이다.\n\n커스텀 graph convolution 모듈을 예로 들자. 이 코드는 단지 커스텀 GNN 모듈이 어떻게 동작하는지 보여주기 위함이지, 가장 효율적인 구현이 아님을 주의하자. \n\n.. code:: python\n\n    class CustomGraphConv(nn.Module):\n        def __init__(self, in_feats, out_feats):\n            super().__init__()\n            self.W = nn.Linear(in_feats * 2, out_feats)\n    \n        def forward(self, g, h):\n            with g.local_scope():\n                g.ndata['h'] = h\n                g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))\n                return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))\n\n전체 그래프에 대한 커스텀 메시지 전달 NN 모듈이 있고, 이를 MFG에서 작동하도록 만들고 싶다면, 다음과 같이 forward 함수를 다시 작성하는 것만이 필요하다. 전체 그래프에 대한 구현은 주석 처리를 했으니, 새로운 코드들과 비교해 보자.\n\n.. 
code:: python\n\n    class CustomGraphConv(nn.Module):\n        def __init__(self, in_feats, out_feats):\n            super().__init__()\n            self.W = nn.Linear(in_feats * 2, out_feats)\n    \n        # h is now a pair of feature tensors for input and output nodes, instead of\n        # a single feature tensor.\n        # def forward(self, g, h):\n        def forward(self, block, h):\n            # with g.local_scope():\n            with block.local_scope():\n                # g.ndata['h'] = h\n                h_src = h\n                h_dst = h[:block.number_of_dst_nodes()]\n                block.srcdata['h'] = h_src\n                block.dstdata['h'] = h_dst\n    \n                # g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))\n                block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))\n    \n                # return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))\n                return self.W(torch.cat(\n                    [block.dstdata['h'], block.dstdata['h_neigh']], 1))\n\n일반적으로, 직접 구현한 NN 모듈이 MFG에서 동작하게 만들기 위해서는 다음과 같은 것을 해야한다.\n\n- 첫 몇 행들(row)을 잘라서 입력 피쳐들로부터 출력 노드의 피처를 얻는다. 행의 개수는 :meth:`block.number_of_dst_nodes <dgl.DGLGraph.number_of_dst_nodes>` 로 얻는다.\n- 원본 그래프가 한 하나의 노드 타입을 갖는 경우, :attr:`g.ndata <dgl.DGLGraph.ndata>` 를 입력 노드의 피쳐의 경우 :attr:`block.srcdata <dgl.DGLGraph.srcdata>` 로 또는 출력 노드의 피쳐의 경우 :attr:`block.dstdata <dgl.DGLGraph.dstdata>` 로 교체한다.\n- 원본 그래프가 여러 종류의 노드 타입을 갖는 경우, :attr:`g.nodes <dgl.DGLGraph.nodes>` 를 입력 노드의 피쳐의 경우 :attr:`block.srcnodes <dgl.DGLGraph.srcnodes>` 로 또는 출력 노드의 피처의 경우 :attr:`block.dstnodes <dgl.DGLGraph.dstnodes>` 로 교체한다.\n- :meth:`g.num_nodes <dgl.DGLGraph.num_nodes>` 를 입력 노드의 개수는 :meth:`block.number_of_src_nodes <dgl.DGLGraph.number_of_src_nodes>` 로 출력 노드의 개수는 :meth:`block.number_of_dst_nodes <dgl.DGLGraph.number_of_dst_nodes>` 로 각각 교체한다.\n\nHeterogeneous 그래프들\n~~~~~~~~~~~~~~~~~~~~\n\nHeterogeneous 그래프의 경우도 커스텀 GNN 모듈을 만드는 것은 비슷하다. 
예를 들어, 전체 그래프에 적용되는 다음 모듈을 예로 들어보자.\n\n.. code:: python\n\n    class CustomHeteroGraphConv(nn.Module):\n        def __init__(self, g, in_feats, out_feats):\n            super().__init__()\n            self.Ws = nn.ModuleDict()\n            for etype in g.canonical_etypes:\n                utype, _, vtype = etype\n                self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])\n            for ntype in g.ntypes:\n                self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])\n    \n        def forward(self, g, h):\n            with g.local_scope():\n                for ntype in g.ntypes:\n                    g.nodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])\n                    g.nodes[ntype].data['h_src'] = h[ntype]\n                for etype in g.canonical_etypes:\n                    utype, _, vtype = etype\n                    g.update_all(\n                        fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),\n                        etype=etype)\n                    g.nodes[vtype].data['h_dst'] = g.nodes[vtype].data['h_dst'] + \\\n                        self.Ws[etype](g.nodes[vtype].data['h_neigh'])\n                return {ntype: g.nodes[ntype].data['h_dst'] for ntype in g.ntypes}\n\n``CustomHeteroGraphConv`` 에서의 원칙은 ``g.nodes`` 를 대상 피쳐가 입력 노드의 것인지 출력 노드의 것인지에 따라서 ``g.srcnodes`` 또는 ``g.dstnodes`` 바꾸는 것이다.\n\n.. 
code:: python\n\n    class CustomHeteroGraphConv(nn.Module):\n        def __init__(self, g, in_feats, out_feats):\n            super().__init__()\n            self.Ws = nn.ModuleDict()\n            for etype in g.canonical_etypes:\n                utype, _, vtype = etype\n                self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])\n            for ntype in g.ntypes:\n                self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])\n    \n        def forward(self, g, h):\n            with g.local_scope():\n                for ntype in g.ntypes:\n                    h_src, h_dst = h[ntype]\n                    g.dstnodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])\n                    g.srcnodes[ntype].data['h_src'] = h[ntype]\n                for etype in g.canonical_etypes:\n                    utype, _, vtype = etype\n                    g.update_all(\n                        fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),\n                        etype=etype)\n                    g.dstnodes[vtype].data['h_dst'] = \\\n                        g.dstnodes[vtype].data['h_dst'] + \\\n                        self.Ws[etype](g.dstnodes[vtype].data['h_neigh'])\n                return {ntype: g.dstnodes[ntype].data['h_dst']\n                        for ntype in g.ntypes}\n\nHomogeneous 그래프, 이분 그래프(bipartite graph), 그리고 MFG를 위한 모듈 작성하기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nDGL의 모든 메시지 전달 모듈들은 homogeneous 그래프, 단방향 이분 그래프 (unidirectional bipartite graphs, 두개 노드 타입을 갖고, 하나의 에지 타입을 갖음), 그리고 하나의 에지 타입을 갖는 MFG에서 동작한다. 기본적으로 DGL 빌트인 뉴럴 네트워크 모듈의 입력 그래프와 피쳐는 아래 경우들 중에 하나를 만족해야 한다.\n\n- 입력 피쳐가 텐서들의 쌍인 경우, 입력 그래프는 단방향 이분(unidirectional bipartite) 그래프이어야 한다.\n- 입력 피쳐가 단일 텐서이고 입력 그래프가 MFG인 경우, DGL은 자동으로 출력 노드의 피쳐를 입력 노드 피처의 첫 몇개의 행으로 정의한다.\n- 입력 피쳐가 단일 텐서이고 입력 그래프가 MGF가 아닌 경우, 입력 그래프는 반드시 homogeneous여야 한다.\n\n다음 코드는 :class:`dgl.nn.pytorch.SAGEConv` 을 PyTorch로 단순하게 구현한 것이다. (MXNet이나 TensorFlow 버전도 제공함. 
(이 코드는 normalization이 제거되어 있고, mean aggregation만 사용한다.)\n\n.. code:: python\n\n    import dgl.function as fn\n    class SAGEConv(nn.Module):\n        def __init__(self, in_feats, out_feats):\n            super().__init__()\n            self.W = nn.Linear(in_feats * 2, out_feats)\n    \n        def forward(self, g, h):\n            if isinstance(h, tuple):\n                h_src, h_dst = h\n            elif g.is_block:\n                h_src = h\n                h_dst = h[:g.number_of_dst_nodes()]\n            else:\n                h_src = h_dst = h\n                 \n            g.srcdata['h'] = h_src\n            g.dstdata['h'] = h_dst\n            g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_neigh'))\n            return F.relu(\n                self.W(torch.cat([g.dstdata['h'], g.dstdata['h_neigh']], 1)))\n\n:ref:`guide_ko-nn` 은 단방향 이분 그래프, homogeneous 그래프와 MFG에 적용되는 :class:`dgl.nn.pytorch.SAGEConv` 를 자세히 다루고 있다.\n\n\n"
  },
  {
    "path": "docs/source/guide_ko/minibatch-node.rst",
    "content": ".. _guide_ko-minibatch-node-classification-sampler:\n\n6.1 이웃 샘플링을 사용한 노드 분류 GNN 모델 학습하기\n-----------------------------------------\n\n:ref:`(English Version) <guide-minibatch-node-classification-sampler>`\n\nStochastic 학습이 되도록 모델을 만들기 위해서는, 다음과 같은 것이 필요하다.\n\n- 이웃 샘플러 정의하기\n- 미니 배치 학습이 되도록 모델을 변경하기\n- 학습 룹 고치기\n\n이제, 이 단계를 어떻게 구현하는지 하나씩 살펴보자.\n\n이웃 샘플러 및 데이터 로더 정의하기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nDGL은 계산하기를 원하는 노드들에 대해서 각 레이어에서 필요한 computation dependency들을 생성하는 몇 가지 이웃 샘플러 클래스들을 가지고 있다.\n\n가장 단순한 이웃 샘플러는 :class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler` 로, 노드가 그 노드의 모든 이웃들로부터 메시지를 수집하도록 해준다.\n\nDGL의 샘플러를 사용하기 위해서는 이를 미니배치에 있는 노드들의 집합을 iterate하는 :class:`~dgl.dataloading.pytorch.NodeDataLoader` 와 합쳐야한다.\n\n다음 예제 코드는 배치들의 학습 노드 ID 배열 ``train_nids`` 를 iterate하고, 생성된 MFG(Message Flow Graph)들의 목록을 GPU로 옮기는 PyTorch DataLoader를 만든다.\n\n.. code:: python\n\n    import dgl\n    import dgl.nn as dglnn\n    import torch\n    import torch.nn as nn\n    import torch.nn.functional as F\n    \n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n    dataloader = dgl.dataloading.NodeDataLoader(\n        g, train_nids, sampler,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\nDataLoader를 iterate 하면서 각 레이어에 대한 computation dependency들을 대표하도록 특별하게 생성된 그래프들의 리스트를 얻을 수 있다. DGL에서 이것들을 *message flow graph* (MFG) 라고 부른다.\n\n.. code:: python\n\n    input_nodes, output_nodes, blocks = next(iter(dataloader))\n    print(blocks)\n\nIterator는 매번 세개의 아이템을 생성한다. ``input_nodes`` 는 ``output_nodes`` 의 representation을 계산하는데 필요한 노드들을 담고 있다. ``block`` 은 그것의 노드가 출력으로 계산되어야 할 각 GNN 레이어에 대해 어떤 노드 representation들이 입력으로 필요한지, 입력 노드들의 representation들이 출력 노드로 어떻게 전파되어야 하는지를 설명한다.\n\n.. 
note::\n\n   Message flow graph의 개념은 :doc:`Stochastic Training Tutorial <tutorials/large/L0_neighbor_sampling_overview>` 을 참고하자.\n\n   지원되는 빌드인 샘플러들의 전체 목록은 :ref:`neighborhood sampler API reference <api-dataloading-neighbor-sampling>` 에서 찾아볼 수 있다.\n\n   :ref:`guide_ko-minibatch-customizing-neighborhood-sampler` 에는 여러분만의 이웃 샘플러 만드는 방법과 MFG 개념에 대한 보다 상세한 설명을 담고 있다.\n\n\n.. _guide_ko-minibatch-node-classification-model:\n\n모델을 미니-배치 학습에 맞게 만들기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n만약 DGL에서 제공하는 메시지 전달 모듈만을 사용하고 있다면, 모델을 미니-배치 학습에 맞도록 수정할 것은 적다. 멀티-레이어 GCN을 예로 들어보자. 그래프 전체에 대한 모델 구현은 아래와 같다.\n\n.. code:: python\n\n    class TwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.conv1 = dglnn.GraphConv(in_features, hidden_features)\n            self.conv2 = dglnn.GraphConv(hidden_features, out_features)\n    \n        def forward(self, g, x):\n            x = F.relu(self.conv1(g, x))\n            x = F.relu(self.conv2(g, x))\n            return x\n\n이 때, 변경해야할 것은 ``g`` 를 앞에서 생성된 ``block`` 로 교체하는 것이 전부이다.\n\n.. code:: python\n\n    class StochasticTwoLayerGCN(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)\n            self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)\n    \n        def forward(self, blocks, x):\n            x = F.relu(self.conv1(blocks[0], x))\n            x = F.relu(self.conv2(blocks[1], x))\n            return x\n\n위 DGL ``GraphConv`` 모듈들은 데이터 로더가 생성한 ``block`` 의 원소를 argument로 받는다.\n\n:ref:`The API reference of each NN module <apinn>` 는 모듈이 MFG를 argument로 받을 수 있는지 없는지를 알려주고 있다.\n\n만약 여러분 자신의 메시지 전달 모듈을 사용하고 싶다면, :ref:`guide_ko-minibatch-custom-gnn-module` 를 참고하자.\n\n학습 룹\n~~~~~\n\n단순하게 학습 룹은 커스터마이징된 배치 iterator를 사용해서 데이터셋을 iterating하는 것으로 구성된다. MFG들의 리스트를 반환하는 매 iteration마다, 다음과 같은 일을 한다.\n\n1. 
입력 노드들의 노드 피처들을 GPU로 로딩한다. 노드 피쳐들은 메모리나 외부 저장소에 저장되어 있을 수 있다. 그래프 전체 학습에서 모든 노드들의 피처를 로드하는 것과는 다르게, 입력 노드들의 피처만 로드하면 된다는 점을 유의하자.\n   \n\n   만약 피쳐들이 ``g.ndata`` 에 저장되어 있다면, 그 피쳐들은 ``blocks[0].srcdata`` 에 저장된 피쳐들, 즉 첫번째 MFG의 소스 노드들의 피처들을 접근해서 로드될 수 있다. 여기서 노드들은 최종 representation을 계산하는데 필요한 모든 노드들을 의미한다.\n\n2. MFG들의 리스트 및 입력 노드 피쳐들을 멀티-레이어 GNN에 입력해서 결과를 \n얻는다.\n\n3. 출력 노드에 해당하는 노드 레이블을 GPU에 로드한다. 비슷하게, 노드 레이블은 메모리나 외부 저장소에 저장되어 있을 수 있다. 역시, 그래프 전체 학습에서 모든 노드들의 레이블을 로드하는 것과는 다르게, 출력 노드들의 레이블만 로드한다는 점을 알아두자.\n   \n   피처가 ``g.ndata`` 에 저장되어 있다면, 레이블은 ``blocks[-1].dstdata`` 의 피쳐들 즉, 마지막 MFG의 목적지 노드들의 피쳐들을 접근해서 로드될 수 있다. 이것들은 최종 representation을 계산할 노드들과 같다.\n\n4. loss를 계산한 후, backpropagate를 수행한다.\n\n.. code:: python\n\n    model = StochasticTwoLayerGCN(in_features, hidden_features, out_features)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, output_nodes, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        input_features = blocks[0].srcdata['features']\n        output_labels = blocks[-1].dstdata['label']\n        output_predictions = model(blocks, input_features)\n        loss = compute_loss(output_labels, output_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\nDGL에서는 end-to-end stochastic 학습 예제인 `GraphSAGE\nimplementation <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/node_classification.py>`__ 를 제공한다.\n\nHeterogeneous 그래프의 경우\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nHeterogeneous 그래프에 대한 노드 분류 그래프 뉴럴 네트워크를 학습하는 것은 간단하다.\n\n:ref:`how to train a 2-layer RGCN on full graph <guide_ko-training-rgcn-node-classification>` 를 예로 들어보자. 미니-배치 학습을 하는 RGCN 구현 코드는 이 예제와 매우 비슷하다. (간단하게 하기 위해서 self-loop, non-linearity와 기본적인 decomposition은 제거했다.)\n\n.. 
code:: python\n\n    class StochasticTwoLayerRGCN(nn.Module):\n        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):\n            super().__init__()\n            self.conv1 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')\n                    for rel in rel_names\n                })\n            self.conv2 = dglnn.HeteroGraphConv({\n                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')\n                    for rel in rel_names\n                })\n    \n        def forward(self, blocks, x):\n            x = self.conv1(blocks[0], x)\n            x = self.conv2(blocks[1], x)\n            return x\n\n또한, DGL이 제공하는 일부 샘플러들은 heterogeneous 그래프를 지원한다. 예를 들어, 제공되는 :class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler` 클래스 및 :class:`~dgl.dataloading.pytorch.NodeDataLoader` 클래스를 stochastic 학습에도 여전히 사용할 수 있다. 전체 이웃 샘플링에서 다른 점은 학습 셋에 노드 타입들과 노드 ID들의 사전을 명시해야한다는 것 뿐이다.\n\n.. code:: python\n\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n    dataloader = dgl.dataloading.NodeDataLoader(\n        g, train_nid_dict, sampler,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4)\n\n학습 룹은 homogeneous 그래프에 대한 학습 룹이랑 거의 유사하다. 다른 점은 ``compute_loss`` 의 구현에서 노드 타입들와 예측 결과라는 두개의 dictionary들을 인자로 받는다는 것이다.\n\n.. 
code:: python\n\n    model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features, etypes)\n    model = model.cuda()\n    opt = torch.optim.Adam(model.parameters())\n    \n    for input_nodes, output_nodes, blocks in dataloader:\n        blocks = [b.to(torch.device('cuda')) for b in blocks]\n        input_features = blocks[0].srcdata     # returns a dict\n        output_labels = blocks[-1].dstdata     # returns a dict\n        output_predictions = model(blocks, input_features)\n        loss = compute_loss(output_labels, output_predictions)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\nEnd-to-end stochastic 학습 예제는 `RGCN\nimplementation <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify_mb.py>`__ 를 참고하자.\n\n\n"
  },
  {
    "path": "docs/source/guide_ko/minibatch.rst",
    "content": ".. _guide_ko-minibatch:\n\n6장: 큰 그래프에 대한 stochastic 학습\n===============================\n\n:ref:`(English Version) <guide-minibatch>`\n\n만약 수백만, 수십억개의 노드들 또는 에지들을 갖는 큰 그래프인 경우에는 :ref:`guide_ko-training` 에서 소개한 그래프 전체를 사용한 학습을 적용하기 어려울 것이다. Hidden state 크기가 :math:`H` 인 노드가 :math:`N` 개인 그래프에 :math:`L` -레이어의 graph convolutional network를 생각해보자. 중간 hidden 상태를 저장하는데 :math:`O(NLH)` 메모리가 필요하고, :math:`N` 이 큰 경우 GPU 하나의 용량을 훨씬 넘을 것이다.\n\n이 절에서 모든 노드들의 피쳐를 GPU에 올려야할 필요가 없는 stochastic 미니-배치 학습을 수행하는 법을 알아본다.\n\n이웃 샘플링(Neighborhood Sampling) 방법 개요\n---------------------------------------\n\n이웃 샘플링 방법은 일반적으로 다음과 같다. 각 gradient descent 단계마다, :math:`L` 번째 레이어의 최종 representation이 계산되어야 할 노드들의 미니 배치를 선택한다. 그 다음으로 :math:`L-1` 레이어에서 그것들의 이웃 전체 또는 일부를 선택한다. 이 절차는 모델의 입력에 이를 때까지 반복된다. 이 반복 프로세스는 출력에서 시작해서 거꾸로 입력까지의 의존성 그래프(dependency graph)를 생성하며, 이를 시각화하면 다음과 같다:\n\n.. figure:: https://data.dgl.ai/asset/image/guide_6_0_0.png\n   :alt: Imgur\n\n이를 사용하면, 큰 그래프에 대한 GNN 모델을 학습하는데 필요한 워크로드 및 연산 자원을 절약할 수 있다.\n\nDGL은 이웃 샘플링을 사용한 GNN 학습을 위한 몇 가지 이웃 샘플러들과 파이프라인을 제공한다. 또한, 샘플링 전략을 커스터마이징하는 방법도 지원한다.\n\n로드맵\n----\n\n이 장은 GNN을 stochastic하게 학습하는 여러 시나리오들로 시작한다.\n\n* :ref:`guide_ko-minibatch-node-classification-sampler`\n* :ref:`guide_ko-minibatch-edge-classification-sampler`\n* :ref:`guide_ko-minibatch-link-classification-sampler`\n\n이 후 절들에서는 새로운 샘플링 알고리즘들, 미니-배치 학습과 호환되는 새로운 GNN 모듈을 만들고자 하거나, 검증과 추론이 미니-배치에서 어떻게 수행되는지 이해하고 싶은 분들을 위한 보다 고급 토픽들을 다룬다.\n\n* :ref:`guide_ko-minibatch-customizing-neighborhood-sampler`\n* :ref:`guide_ko-minibatch-custom-gnn-module`\n* :ref:`guide_ko-minibatch-inference`\n\n마지막으로 이웃 샘플링을 구현하고 사용하는데 대한 성능 팁을 알아본다.\n\n* :ref:`guide_ko-minibatch-gpu-sampling`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    minibatch-node\n    minibatch-edge\n    minibatch-link\n    minibatch-custom-sampler\n    minibatch-nn\n    minibatch-inference\n    minibatch-gpu-sampling\n"
  },
  {
    "path": "docs/source/guide_ko/mixed_precision.rst",
    "content": ".. _guide_ko-mixed_precision:\n\n8장: Mixed Precision 학습\n=======================\n\n:ref:`(English Version) <guide-mixed_precision>`\n\nDGL은 mixed precision 학습을 위해서 `PyTorch's automatic mixed precision package <https://pytorch.org/docs/stable/amp.html>`_ 와 호환된다. 따라서, 학습 시간 및 GPU 메모리 사용량을 절약할 수 있다. \n\nHalf precision을 사용한 메시지 전달\n------------------------------\n\nfp16을 지원하는 DGL은 UDF(User Defined Function)이나 빌트인 함수(예, ``dgl.function.sum``,\n``dgl.function.copy_u``)를 사용해서 ``float16`` 피쳐에 대한 메시지 전달을 허용한다.\n\n\n다음 예제는 DGL 메시지 전달 API를 half-precision 피쳐들에 사용하는 방법을 보여준다.\n\n    >>> import torch\n    >>> import dgl\n    >>> import dgl.function as fn\n    >>> g = dgl.rand_graph(30, 100).to(0)  # Create a graph on GPU w/ 30 nodes and 100 edges.\n    >>> g.ndata['h'] = torch.rand(30, 16).to(0).half()  # Create fp16 node features.\n    >>> g.edata['w'] = torch.rand(100, 1).to(0).half()  # Create fp16 edge features.\n    >>> # Use DGL's built-in functions for message passing on fp16 features.\n    >>> g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'x'))\n    >>> g.ndata['x'][0]\n    tensor([0.3391, 0.2208, 0.7163, 0.6655, 0.7031, 0.5854, 0.9404, 0.7720, 0.6562,\n            0.4028, 0.6943, 0.5908, 0.9307, 0.5962, 0.7827, 0.5034],\n           device='cuda:0', dtype=torch.float16)\n    >>> g.apply_edges(fn.u_dot_v('h', 'x', 'hx'))\n    >>> g.edata['hx'][0]\n    tensor([5.4570], device='cuda:0', dtype=torch.float16)\n    >>> # Use UDF(User Defined Functions) for message passing on fp16 features.\n    >>> def message(edges):\n    ...     return {'m': edges.src['h'] * edges.data['w']}\n    ...\n    >>> def reduce(nodes):\n    ...     return {'y': torch.sum(nodes.mailbox['m'], 1)}\n    ...\n    >>> def dot(edges):\n    ...     
return {'hy': (edges.src['h'] * edges.dst['y']).sum(-1, keepdims=True)}\n    ...\n    >>> g.update_all(message, reduce)\n    >>> g.ndata['y'][0]\n    tensor([0.3394, 0.2209, 0.7168, 0.6655, 0.7026, 0.5854, 0.9404, 0.7720, 0.6562,\n            0.4028, 0.6943, 0.5908, 0.9307, 0.5967, 0.7827, 0.5039],\n           device='cuda:0', dtype=torch.float16)\n    >>> g.apply_edges(dot)\n    >>> g.edata['hy'][0]\n    tensor([5.4609], device='cuda:0', dtype=torch.float16)\n\n\nEnd-to-End Mixed Precision 학습\n------------------------------\n\nDGL은 PyTorch의 AMP package를 사용해서 mixed precision 학습을 구현하고 있어서, 사용 방법은 `PyTorch의 것 <https://pytorch.org/docs/stable/notes/amp_examples.html>`_ 과 동일하다.\n\nGNN 모델의 forward 패스(loss 계산 포함)를 ``torch.cuda.amp.autocast()`` 로 래핑하면 PyTorch는 각 op 및 텐서에 대해서 적절한 데이터 타입을 자동으로 선택한다. Half precision 텐서는 메모리 효율적이고, half precision 텐서에 대한 대부분 연산들은 GPU tensorcore들을 활용하기 때문에 더 빠르다.\n\n``float16`` 포맷의 작은 gradient들은 언더플로우(underflow) 문제를 갖는데 (0이 되어버림), PyTorch는 이를 해결하기 위해서 ``GradScaler`` 모듈을 제공한다. ``GradScaler`` 는 loss 값에 factor를 곱하고, 이 scaled loss에 backward pass를 수행한다. 그리고 파라메터들을 업데이트하는 optimizer를 수행하기 전에 unscale 한다.\n\n다음은 3-레이어 GAT를 Reddit 데이터셋(1억 1,400만 개의 에지를 갖는)에 학습을 하는 스크립트이다. ``use_fp16`` 가 활성화/비활성화되었을 때의 코드 차이를 살펴보자.\n\n.. 
code::\n\n    import torch \n    import torch.nn as nn\n    import torch.nn.functional as F\n    from torch.cuda.amp import autocast, GradScaler\n    import dgl\n    from dgl.data import RedditDataset\n    from dgl.nn import GATConv\n\n    use_fp16 = True\n\n\n    class GAT(nn.Module):\n        def __init__(self,\n                     in_feats,\n                     n_hidden,\n                     n_classes,\n                     heads):\n            super().__init__()\n            self.layers = nn.ModuleList()\n            self.layers.append(GATConv(in_feats, n_hidden, heads[0], activation=F.elu))\n            self.layers.append(GATConv(n_hidden * heads[0], n_hidden, heads[1], activation=F.elu))\n            self.layers.append(GATConv(n_hidden * heads[1], n_classes, heads[2], activation=F.elu))\n\n        def forward(self, g, h):\n            for l, layer in enumerate(self.layers):\n                h = layer(g, h)\n                if l != len(self.layers) - 1:\n                    h = h.flatten(1)\n                else:\n                    h = h.mean(1)\n            return h\n\n    # Data loading\n    data = RedditDataset()\n    device = torch.device(0)\n    g = data[0]\n    g = dgl.add_self_loop(g)\n    g = g.int().to(device)\n    train_mask = g.ndata['train_mask']\n    features = g.ndata['feat']\n    labels = g.ndata['label']\n    in_feats = features.shape[1]\n    n_hidden = 256\n    n_classes = data.num_classes\n    n_edges = g.num_edges()\n    heads = [1, 1, 1]\n    model = GAT(in_feats, n_hidden, n_classes, heads)\n    model = model.to(device)\n\n    # Create optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)\n    # Create gradient scaler\n    scaler = GradScaler()\n\n    for epoch in range(100):\n        model.train()\n        optimizer.zero_grad()\n\n        # Wrap forward pass with autocast\n        with autocast(enabled=use_fp16):\n            logits = model(g, features)\n            loss = 
F.cross_entropy(logits[train_mask], labels[train_mask])\n        \n        if use_fp16:\n            # Backprop w/ gradient scaling\n            scaler.scale(loss).backward()\n            scaler.step(optimizer)\n            scaler.update()\n        else:\n            loss.backward()\n            optimizer.step()\n\n        print('Epoch {} | Loss {}'.format(epoch, loss.item()))\n\nNVIDIA V100 (16GB) 한개를 갖는 컴퓨터에서, 이 모델을 fp16을 사용하지 않고 학습할 때는 15.2GB GPU 메모리가 사용되는데, fp16을 활성화하면, 학습에 12.8GB GPU 메모리가 사용되며, 두 경우 loss가 비슷한 값으로 수렴한다. 만약 head의 개수를 ``[2, 2, 2]`` 로 바꾸면, fp16을 사용하지 않는 학습은 GPU OOM(out-of-memory) 이슈가 생길 것이지만, fp16을 사용한 학습은 15.7GB GPU 메모리를 사용하면서 수행된다.\n\nDGL은 half-precision 지원을 계속 향상하고 있고, 연산 커널의 성능은 아직 최적은 아니다. 앞으로의 업데이트를 계속 지켜보자."
  },
  {
    "path": "docs/source/guide_ko/nn-construction.rst",
    "content": ".. _guide_ko-nn-construction:\n\n3.1 DGL NN 모듈 생성 함수\n---------------------\n\n:ref:`(English Version) <guide-nn-construction>`\n\n생성 함수는 다음 단계들을 수행한다:\n\n1. 옵션 설정\n2. 학습할 파라메터 또는 서브모듈 등록\n3. 파라메터 리셋\n\n.. code::\n\n    import torch.nn as nn\n\n    from dgl.utils import expand_as_pair\n\n    class SAGEConv(nn.Module):\n        def __init__(self,\n                     in_feats,\n                     out_feats,\n                     aggregator_type,\n                     bias=True,\n                     norm=None,\n                     activation=None):\n            super(SAGEConv, self).__init__()\n\n            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n            self._out_feats = out_feats\n            self._aggre_type = aggregator_type\n            self.norm = norm\n            self.activation = activation\n\n생성 함수를 만들 때 데이터 차원을 지정해야 한다. 일반적인 PyTorch 모듈의 경우에는 차원이란 보통은 입력 차원, 출력 차원, 그리고 은닉(hidden) 차원을 의미하는데, 그래프 뉴럴 네트워크의 경우 입력 차원은 소스 노드의 차원과 목적지 노드의 차원으로 나뉜다.\n\n데이터 차원들 이외의 전형적인 그래프 뉴럴 네트워크의 옵션으로 aggregation 타입( ``self._aggre_type`` )이 있다. Aggregation 타입은 특정 목적지 노드에 대해서 관련된 여러 에지의 메시지들이 어떻게 집계되어야 하는지를 결정한다. 흔히 사용되는 aggregation 타입으로는 ``mean`` , ``sum`` , ``max`` , ``min`` 이 있으며, 어떤 모듈은 ``lstm`` 과 같이 좀더 복잡한 aggregation을 적용하기도 한다.\n\n여기서 ``norm`` 은 피처 normalization을 위해서 호출될 수 있는 함수이다. SAGEConv 페이퍼에서는 l2 normalization, :math:`h_v = h_v / \\lVert h_v \\rVert_2` 이 normalization으로 사용되고 있다.\n\n.. 
code::\n\n            # aggregator type: mean, pool, lstm, gcn\n            if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:\n                raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))\n            if aggregator_type == 'pool':\n                self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)\n            if aggregator_type == 'lstm':\n                self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)\n            if aggregator_type in ['mean', 'pool', 'lstm']:\n                self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)\n            self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)\n            self.reset_parameters()\n\n다음으로는 파라메터들과 서브모듈들을 등록한다. SAGEConv의 경우에는 서브모듈은 aggregation 타입에 따라 달라진다. 그 모듈들은 ``nn.Linear`` , ``nn.LSTM`` 등과 같은 순수한 PyTorch nn 모듈이다. 생성 함수의 마지막에는 ``reset_parameters()`` 호출로 가중치들을 초기화한다.\n\n.. code::\n\n        def reset_parameters(self):\n            \"\"\"Reinitialize learnable parameters.\"\"\"\n            gain = nn.init.calculate_gain('relu')\n            if self._aggre_type == 'pool':\n                nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)\n            if self._aggre_type == 'lstm':\n                self.lstm.reset_parameters()\n            if self._aggre_type != 'gcn':\n                nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)\n            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)\n"
  },
  {
    "path": "docs/source/guide_ko/nn-forward.rst",
    "content": ".. _guide_ko-nn-forward:\n\n3.2 DGL NN 모듈의 Forward 함수\n---------------------------\n\n:ref:`(English Version) <guide-nn-forward>`\n\nNN 모듈에서 ``forward()`` 함수는 실제 메시지 전달과 연산을 수행한다. 일반적으로 텐서들을 파라메터로 받는 PyTorch의 NN 모듈과 비교하면, DGL NN 모듈은 :class:`dgl.DGLGraph` 를 추가 파라메터로 받는다. ``forward()`` 함수는 3단계로 수행된다.\n\n- 그래프 체크 및 그래프 타입 명세화\n- 메시지 전달\n- 피쳐 업데이트\n\n이 절에서는 SAGEConv에서 사용되는 ``forward()`` 함수를 자세하게 살펴보겠다.\n\n그래프 체크와 그래프 타입 명세화(graph type specification)\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code::\n\n        def forward(self, graph, feat):\n            with graph.local_scope():\n                # Specify graph type then expand input feature according to graph type\n                feat_src, feat_dst = expand_as_pair(feat, graph)\n\n``forward()`` 는 계산 및 메시지 전달 과정에서 유효하지 않은 값을 만들 수 있는 여러 특별한 케이스들을 다룰 수 있어야 한다. :class:`~dgl.nn.pytorch.conv.GraphConv` 와 같은 그래프 conv 모듈에서 수행하는 가장 전형적인 점검은 입력 그래프가 in-degree가 0인 노드를 갖지 않는지 확인하는 것이다. in-degree가 0인 경우에, ``mailbox`` 에 아무것도 없게 되고, 축약 함수는 모두 0인 값을 만들어낼 것이다. 이는 잠재적인 모델 성능 문제를 일으킬 수도 있다. 하지만, :class:`~dgl.nn.pytorch.conv.SAGEConv` 모듈의 경우, aggregated representation은 원래의 노드 피쳐와 연결(concatenated)되기 때문에, ``forward()`` 의 결과는 항상 0이 아니기 때문에, 이런 체크가 필요 없다.\n\nDGL NN 모듈은 여러 종류의 그래프, 단종 그래프, 이종 그래프(:ref:`guide_ko-graph-heterogeneous`), 서브그래프 블록(:ref:`guide_ko-minibatch` ), 입력에 걸쳐서 재사용될 수 있다. \n\nSAGEConv의 수학 공식은 다음과 같다:\n\n.. math::\n\n   h_{\\mathcal{N}(dst)}^{(l+1)}  = \\mathrm{aggregate}\n           \\left(\\{h_{src}^{l}, \\forall src \\in \\mathcal{N}(dst) \\}\\right)\n\n.. math::\n\n    h_{dst}^{(l+1)} = \\sigma \\left(W \\cdot \\mathrm{concat}\n           (h_{dst}^{l}, h_{\\mathcal{N}(dst)}^{l+1}) + b \\right)\n\n.. math::\n\n    h_{dst}^{(l+1)} = \\mathrm{norm}(h_{dst}^{l+1})\n\n그래프 타입에 따라서 소스 노드 피쳐(``feat_src``)와 목적지 노드 피쳐(``feat_dst``)를 명시해야 한다. :meth:`~dgl.utils.expand_as_pair` 는 명시된 그래프 타입에 따라 ``feat`` 를 ``feat_src`` 와 ``feat_dst`` 로 확장하는 함수이다. 이 함수의 동작은 다음과 같다.\n\n.. 
code::\n\n    def expand_as_pair(input_, g=None):\n        if isinstance(input_, tuple):\n            # Bipartite graph case\n            return input_\n        elif g is not None and g.is_block:\n            # Subgraph block case\n            if isinstance(input_, Mapping):\n                input_dst = {\n                    k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))\n                    for k, v in input_.items()}\n            else:\n                input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())\n            return input_, input_dst\n        else:\n            # Homogeneous graph case\n            return input_, input_\n\nhomogeneous 그래프 전체를 학습시키는 경우, 소스 노드와 목적지 노드들의 타입이 같다. 이것들은 그래프의 전체 노드들이다.\n\nHeterogeneous 그래프의 경우, 그래프는 여러 이분 그래프로 나뉠 수 있다. 즉, 각 관계당 하나의 그래프로. 관계는 ``(src_type, edge_type, dst_dtype)`` 로 표현된다. 입력 피쳐 ``feat`` 가 tuple 이라고 확인되면, 이 함수는 그 그래프는 이분 그래프로 취급한다. Tuple의 첫번째 요소는 소스 노드 피처이고, 두번째는 목적지 노드의 피처이다.\n\n미니-배치 학습의 경우, 연산이 여러 목적지 노드들을 기반으로 샘플된 서브 그래프에 적용된다. DGL에서 서브 그래프는 ``block`` 이라고 한다. 블록이 생성되는 단계에서, ``dst_nodes`` 가 노드 리스트의 앞에 놓이게 된다. ``[0:g.number_of_dst_nodes()]`` 인덱스를 이용해서 ``feat_dst`` 를 찾아낼 수 있다.\n\n``feat_src`` 와 ``feat_dst`` 가 정해진 후에는, 세가지 그래프 타입들에 대한 연산은 모두 동일하다.\n\n메시지 전달과 축약\n~~~~~~~~~~~~~~\n\n.. 
code::\n\n                import dgl.function as fn\n                import torch.nn.functional as F\n                from dgl.utils import check_eq_shape\n\n                if self._aggre_type == 'mean':\n                    graph.srcdata['h'] = feat_src\n                    graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))\n                    h_neigh = graph.dstdata['neigh']\n                elif self._aggre_type == 'gcn':\n                    check_eq_shape(feat)\n                    graph.srcdata['h'] = feat_src\n                    graph.dstdata['h'] = feat_dst\n                    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))\n                    # divide in_degrees\n                    degs = graph.in_degrees().to(feat_dst)\n                    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)\n                elif self._aggre_type == 'pool':\n                    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))\n                    graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))\n                    h_neigh = graph.dstdata['neigh']\n                else:\n                    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))\n\n                # GraphSAGE GCN does not require fc_self.\n                if self._aggre_type == 'gcn':\n                    rst = self.fc_neigh(h_neigh)\n                else:\n                    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)\n\n이 코드는 실제로 메시지 전달과 축약 연산을 실행하고 있다. 이 부분의 코드는 모듈에 따라 다르게 구현된다. 이 코드의 모든 메시지 전달은 :meth:`~dgl.DGLGraph.update_all` API와 ``built-in``  메시지/축약 함수들로 구현되어 있는데, 이는 :ref:`guide_ko-message-passing-efficient` 에서 설명된 DGL의 성능 최적화를 모두 활용하기 위해서이다.\n\n출력값을 위한 축약 후 피쳐 업데이트\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. 
code::\n\n                # activation\n                if self.activation is not None:\n                    rst = self.activation(rst)\n                # normalization\n                if self.norm is not None:\n                    rst = self.norm(rst)\n                return rst\n\n``forward()`` 함수의 마지막 부분은 ``reduce function`` 다음에 피쳐를 업데이트하는 것이다. 일반적인 업데이트 연산들은 활성화 함수를 적용하고, 객체 생성 단계에서 설정된 옵션에 따라 normalization을 수행한다.\n\n"
  },
  {
    "path": "docs/source/guide_ko/nn-heterograph.rst",
    "content": ".. _guide_ko-nn-heterograph:\n\n3.3 Heterogeneous GraphConv 모듈\n-------------------------------\n\n:ref:`(English Version) <guide-nn-heterograph>`\n\n:class:`~dgl.nn.pytorch.HeteroGraphConv` 는 heterogeneous 그래프들에 DGL NN 모듈을 적용하기 위한 모듈 수준의 인캡슐레이션이다. 메시지 전달 API :meth:`~dgl.DGLGraph.multi_update_all` 와 같은 로직으로 구현되어 있고, 이는 다음을 포함한다.\n\n- :math:`r` 관계에 대한 DGL NN 모듈\n- 한 노드에 연결된 여러 관계로부터 얻은 결과를 통합하는 축약(reduction)\n\n이는 다음과 같이 공식으로 표현된다:\n\n.. math::  h_{dst}^{(l+1)} = \\underset{r\\in\\mathcal{R}, r_{dst}=dst}{AGG} (f_r(g_r, h_{r_{src}}^l, h_{r_{dst}}^l))\n\n, 여기서 :math:`f_r` 는 각 :math:`r` 관계에 대한 NN 모듈이고, :math:`AGG` 는 aggregation 함수이다.\n\nHeteroGraphConv 구현 로직:\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code::\n\n    import torch.nn as nn\n\n    class HeteroGraphConv(nn.Module):\n        def __init__(self, mods, aggregate='sum'):\n            super(HeteroGraphConv, self).__init__()\n            self.mods = nn.ModuleDict(mods)\n            if isinstance(aggregate, str):\n                # An internal function to get common aggregation functions\n                self.agg_fn = get_aggregate_fn(aggregate)\n            else:\n                self.agg_fn = aggregate\n\nHeterograph convolution은 각 관계를 NN 모듈에 매핑하는 ``mods`` 사전을 인자로 받고, 한 노드에 대한 여러 관계들의 결과를 집계하는 함수를 설정한다.\n\n.. code::\n\n    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):\n        if mod_args is None:\n            mod_args = {}\n        if mod_kwargs is None:\n            mod_kwargs = {}\n        outputs = {nty : [] for nty in g.dsttypes}\n\n입력 그래프와 입력 텐서들과 더불어, ``forward()`` 함수는 두가지 추가적인 파라메터들, ``mod_args`` 와 ``mod_kwargs`` 을 받는다. 이것들은 ``self.mods`` 안에서, 다른 종류의 관계에 연관된 NN 모듈을 수행할 때, 커스터마이즈된 파라메터들로써 사용된다.\n\n각 목적지 타입 ``nty`` 에 대한 결과 텐서를 저장하기 위해서 결과 사전(output dictionary)가 생성된다. 각 ``nty`` 에 대한 값은 리스트이다. 이는 ``nty`` 를 목적 타입으로 갖을 관계가 여러개가 있는 경우, 단일 노드 타입이 여러 아웃풋들을 갖을 수 있음을 의미한다. ``HeteroGraphConv`` 는 이 리스트들에 대해서 추가적인 aggregation을 수행할 것이다.\n\n.. 
code::\n\n          if g.is_block:\n              src_inputs = inputs\n              dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}\n          else:\n              src_inputs = dst_inputs = inputs\n\n          for stype, etype, dtype in g.canonical_etypes:\n              rel_graph = g[stype, etype, dtype]\n              if rel_graph.num_edges() == 0:\n                  continue\n              if stype not in src_inputs or dtype not in dst_inputs:\n                  continue\n              dstdata = self.mods[etype](\n                  rel_graph,\n                  (src_inputs[stype], dst_inputs[dtype]),\n                  *mod_args.get(etype, ()),\n                  **mod_kwargs.get(etype, {}))\n              outputs[dtype].append(dstdata)\n\n입력 그래프 ``g`` 는 heterogeneous 그래프 또는 heterogeneous 그래프의 서브그래프 블록일 수 있다. 보통의 NN 모듈처럼, ``forward()`` 함수는 다양한 입력 그래프 타입들을 별도로 다룰 수 있어야 한다.\n\n각 관계는 ``(stype, etype, dtype)`` 인 ``canonical_etype`` 으로 표현된다. ``canonical_etype`` 을 키로 사용해서, 이분 그래프(bipartite graph)인 ``rel_graph`` 를 추출할 수 있다. 이분 그래프에서 입력 피쳐는 ``(src_inputs[stype], dst_inputs[dtype])`` 로 구성된다. 각 관계에 대한 NN 모듈이 호출되고, 결과는 저장된다. \n\n.. code::\n\n        rsts = {}\n        for nty, alist in outputs.items():\n            if len(alist) != 0:\n                rsts[nty] = self.agg_fn(alist, nty)\n\n마지막으로 한 목적 노드 타입에 대해 여러 관계로부터 얻어진 결과들은 ``self.agg_fn`` 를 통해서 집계된다. :class:`~dgl.nn.pytorch.HeteroGraphConv` 의 API DOC에서 관련 예제들이 있다.\n"
  },
  {
    "path": "docs/source/guide_ko/nn.rst",
    "content": ".. _guide_ko-nn:\n\n3장: GNN 모듈 만들기\n=================\n\n:ref:`(English Version) <guide-nn>`\n\nDGL NN 모듈은 GNN 모델을 만드는데 필요한 빌딩 블록들로 구성되어 있다. NN 모듈은 백엔드로 사용되는 DNN 프레임워크에 따라 `PyTorch’s NN Module <https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/module.html>`__ , `MXNet Gluon’s NN Block  <http://mxnet.incubator.apache.org/versions/1.6/api/python/docs/api/gluon/nn/index.html>`__ 그리고 `TensorFlow’s Keras Layer <https://www.tensorflow.org/api_docs/python/tf/keras/layers>`__ 를 상속한다. DGL NN 모듈에서, 생성 함수에서의 파라메터 등록과 forward 함수에서 텐서 연산은 백엔드 프레임워크의 것과 동일하다. 이런 방식의 구현 덕에 DGL 코드는 백엔드 프레임워크 코드와 원활하게 통합될 수 있다. 주요 차이점은 DGL 고유의 메시지 전달 연산에 존재한다.\n\nDGL은 일반적으로 많이 사용되는 :ref:`apinn-pytorch-conv` , :ref:`apinn-pytorch-dense-conv` , :ref:`apinn-pytorch-pooling` 와 :ref:`apinn-pytorch-util` 를 포함하고 있으며, 여러분의 기여를 환영한다.\n\n이 장에서는 PyTorch 백엔드를 사용한 :class:`~dgl.nn.pytorch.conv.SAGEConv` 를 예제로 커스텀 DGL NN 모듈을 만드는 방법을 소개한다.\n\n로드맵\n----\n\n* :ref:`guide_ko-nn-construction`\n* :ref:`guide_ko-nn-forward`\n* :ref:`guide_ko-nn-heterograph`\n\n.. toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    nn-construction\n    nn-forward\n    nn-heterograph\n"
  },
  {
    "path": "docs/source/guide_ko/training-edge.rst",
    "content": ".. _guide_ko-training-edge-classification:\n\n5.2 에지 분류 및 리그레션(Regression)\n--------------------------------\n\n:ref:`(English Version) <guide-training-edge-classification>`\n\n때론 그래프의 에지들의 속성을 예측을 원하는 경우가 있다. 이를 위해서 *에지 분류/리그레션* 모델을 만들고자 한다.\n\n우선, 예제로 사용할 에지 예측을 위한 임의의 그래프를 만든다.\n\n.. code:: python\n\n    src = np.random.randint(0, 100, 500)\n    dst = np.random.randint(0, 100, 500)\n    # make it symmetric\n    edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])))\n    # synthetic node and edge features, as well as edge labels\n    edge_pred_graph.ndata['feature'] = torch.randn(100, 10)\n    edge_pred_graph.edata['feature'] = torch.randn(1000, 10)\n    edge_pred_graph.edata['label'] = torch.randn(1000)\n    # synthetic train-validation-test splits\n    edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)\n\n개요\n~~~~~~~~~\n\n앞 절에서 우리는 멀티 레이어 GNN을 사용해서 노드 분류하는 방법을 알아봤다. 임의의 노드에 대한 hidden representation을 계산하기 위해서 같은 기법을 적용한다. 그러면 에지들에 대한 예측은 그것들의 부속 노드들의 representation들로 부터 도출할 수 있다.\n\n에지에 대한 예측을 계산하는 가장 일반적인 방법은 그 에지의 부속 노드들의 representation들과 부수적으로 그 에지에 대한 피쳐들의 parameterized 함수로 표현하는 것이다.\n\n노드 분류 모델과 구현상의 차이점\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n이전 절에서 만든 모델을 사용해서 노드 representation을 계산한다고 가정하면, :meth:`~dgl.DGLGraph.apply_edges` 메소드로 에지 예측을 계산하는 컴포넌트만 작성하면 된다.\n\n예를 들어, 에지 리그레션을 위해서 각 에지에 대한 점수를 계산하고자 한다면, 아래 코드와 같이 각 에지에 대한 부속 노드의 representation들의 dot product를 계산하면 된다.\n\n.. code:: python\n\n    import dgl.function as fn\n    class DotProductPredictor(nn.Module):\n        def forward(self, graph, h):\n            # h contains the node representations computed from the GNN defined\n            # in the node classification section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))\n                return graph.edata['score']\n\n또한 MLP를 사용해서 각 에지에 대한 벡터 값을 예측하는 예측하는 함수를 작성할 수도 있다. 
이 벡터 값은 미래의 다운스트림 태스크들에 사용될 수 있다. 예를 들면, 범주형 분류의 logit으로 사용할 수 있다.\n\n.. code:: python\n\n    class MLPPredictor(nn.Module):\n        def __init__(self, in_features, out_classes):\n            super().__init__()\n            self.W = nn.Linear(in_features * 2, out_classes)\n\n        def apply_edges(self, edges):\n            h_u = edges.src['h']\n            h_v = edges.dst['h']\n            score = self.W(torch.cat([h_u, h_v], 1))\n            return {'score': score}\n\n        def forward(self, graph, h):\n            # h contains the node representations computed from the GNN defined\n            # in the node classification section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(self.apply_edges)\n                return graph.edata['score']\n\n학습 룹(loop)\n~~~~~~~~~~~\n\n노드 representation 계산 모델과 에지 예측 모델을 만들었다면, 모든 에지들에 대한 예측값을 계산하는 전체 그래프를 이용한 학습 룹을 작성할 수 있다.\n\n노드 representation 계산 모델로 ``SAGE`` 를, 에지 예측 모델로 ``DotProductPredictor`` 를 사용한다.\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.sage = SAGE(in_features, hidden_features, out_features)\n            self.pred = DotProductPredictor()\n        def forward(self, g, x):\n            h = self.sage(g, x)\n            return self.pred(g, h)\n\n이 예제에서 학습/검증/테스트 에지 셋이 에지의 이진 마스크로 구분된다고 가정한다. 또한 early stopping이나 모델 저장은 포함하지 않는다.\n\n.. 
code:: python\n\n    node_features = edge_pred_graph.ndata['feature']\n    edge_label = edge_pred_graph.edata['label']\n    train_mask = edge_pred_graph.edata['train_mask']\n    model = Model(10, 20, 5)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        pred = model(edge_pred_graph, node_features)\n        loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n.. _guide_ko-training-edge-classification-heterogeneous-graph:\n\nHeterogeneous 그래프\n~~~~~~~~~~~~~~~~~~\n\nHeterogeneous 그래프들에 대한 에지 분류는 homogeneous 그래프와 크게 다르지 않다. 하나의 에지 타입에 대해서 에지 분류를 수행하자 한다면, 모든 노드 티압에 대한 노드 representation을 구하고, :meth:`~dgl.DGLGraph.apply_edges` 메소드를 사용해서 에지 타입을 예측하면 된다.\n\n예를 들면, heterogeneous 그래프의 하나의 에지 타입에 대한 동작하는 ``DotProductPredictor`` 를 작성하고자 한다면, ``apply_edges`` 메소드에 해당 에지 타입을 명시하기만 하면 된다.\n\n.. code:: python\n\n    class HeteroDotProductPredictor(nn.Module):\n        def forward(self, graph, h, etype):\n            # h contains the node representations for each edge type computed from\n            # the GNN for heterogeneous graphs defined in the node classification\n            # section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h   # assigns 'h' of all node types in one shot\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)\n                return graph.edges[etype].data['score']\n\n비슷하게 ``HeteroMLPPredictor`` 를 작성할 수 있다.\n\n.. 
code:: python\n\n    class HeteroMLPPredictor(nn.Module):\n        def __init__(self, in_features, out_classes):\n            super().__init__()\n            self.W = nn.Linear(in_features * 2, out_classes)\n\n        def apply_edges(self, edges):\n            h_u = edges.src['h']\n            h_v = edges.dst['h']\n            score = self.W(torch.cat([h_u, h_v], 1))\n            return {'score': score}\n\n        def forward(self, graph, h, etype):\n            # h contains the node representations for each edge type computed from\n            # the GNN for heterogeneous graphs defined in the node classification\n            # section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h   # assigns 'h' of all node types in one shot\n                graph.apply_edges(self.apply_edges, etype=etype)\n                return graph.edges[etype].data['score']\n\n특정 타입의 에지에 대해서, 각 에지의 점수를 예측하는 end-to-end 모델을 다음과 같다:\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, rel_names):\n            super().__init__()\n            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)\n            self.pred = HeteroDotProductPredictor()\n        def forward(self, g, x, etype):\n            h = self.sage(g, x)\n            return self.pred(g, h, etype)\n\n모델을 사용하는 방법은 노드 타입과 피쳐들에 대한 사전을 모델에 간단하게 입력하면 된다.\n\n.. code:: python\n\n    model = Model(10, 20, 5, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    label = hetero_graph.edges['click'].data['label']\n    train_mask = hetero_graph.edges['click'].data['train_mask']\n    node_features = {'user': user_feats, 'item': item_feats}\n\n학습 룹은 homogeneous 그래프의 것과 거의 유사하다. 예를 들어, 에지 타입 ``click`` 에 대한 에지 레이블을 예측하는 것은 다음과 같이 간단히 구현된다.\n\n.. 
code:: python\n\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        pred = model(hetero_graph, node_features, 'click')\n        loss = ((pred[train_mask] - label[train_mask]) ** 2).mean()\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n\nHeterogeneous 그래프의 에지들에 대한 에지 타입 예측하기\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n주어진 에지의 타입을 예측하는 일도 종종 하게된다.\n\n:ref:`heterogeneous 그래프 예제 <guide_ko-training-heterogeneous-graph-example>` 에서는 user와 item을 연결하는 에지가 주어졌을 때, user가 ``click`` 을 선택할지, ``dislike`` 를 선택할지를 예측하고 있다.\n\n이는 추천에서 흔히 쓰이는 평가 예측의 간략한 버전이다.\n\n노드 representation을 얻기 위해서 heterogeneous graph convolution 네트워크를 사용할 수 있다. 이를 위해서 :ref:`이전에 정의한 RGCN <guide_ko-training-rgcn-node-classification>` 를 사용하는 것도 가능하다.\n\n에지 타입을 예측하기 위해서 ``HeteroDotProductPredictor`` 의 용도를 간단히 변경해서 예측할 모든 에지 타입을 “병합“하고 모든 에지들의 각 타입에 대한 점수를 내보내는 하나의 에지 타입만 있는 다른 그래프를 취하게하면 된다.\n\n이 예제에 적용해보면, ``user`` 와 ``item`` 두 노트 타입을 갖으며 ``user`` 와 ``item`` 에 대한 ``click`` 이나 ``dislike`` 같은 모든 에지 타입을 병합하는 단일 에지 타입을 갖는 그래프가 필요하다. 다음 문장으로 간단하게 생성할 수 있다.\n\n.. code:: python\n\n    dec_graph = hetero_graph['user', :, 'item']\n\n이 함수는 ``user`` 와 ``item`` 을 노드 타입으로 갖고, 두 노드 타입을 연결하고 있는 모든 에지 타입(예, ``click`` 와 ``dislike`` )을 합친 단일 에지 타입을 갖는 heterogeneous 그래프를 리턴한다.\n\n위 코드는 원래의 에지 타입을 ``dgl.ETYPE`` 이라는 이름의 피처로 리턴하기 때문에, 이를 레이블로 사용할 수 있다.\n\n.. code:: python\n\n    edge_label = dec_graph.edata[dgl.ETYPE]\n\n에지 타입 예측 모듈의 입력으로 위 그래프를 사용해서 예측 모델을 다음과 같이 작성한다.\n\n.. 
code:: python\n\n    class HeteroMLPPredictor(nn.Module):\n        def __init__(self, in_dims, n_classes):\n            super().__init__()\n            self.W = nn.Linear(in_dims * 2, n_classes)\n\n        def apply_edges(self, edges):\n            x = torch.cat([edges.src['h'], edges.dst['h']], 1)\n            y = self.W(x)\n            return {'score': y}\n\n        def forward(self, graph, h):\n            # h contains the node representations for each edge type computed from\n            # the GNN for heterogeneous graphs defined in the node classification\n            # section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h   # assigns 'h' of all node types in one shot\n                graph.apply_edges(self.apply_edges)\n                return graph.edata['score']\n\n노드 representation 모듈과 에지 타입 예측 모듈을 합친 모델은 다음과 같다.\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, rel_names):\n            super().__init__()\n            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)\n            self.pred = HeteroMLPPredictor(out_features, len(rel_names))\n        def forward(self, g, x, dec_graph):\n            h = self.sage(g, x)\n            return self.pred(dec_graph, h)\n\n학습 룹은 아래와 같이 간단하다.\n\n.. 
code:: python\n\n    model = Model(10, 20, 5, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    node_features = {'user': user_feats, 'item': item_feats}\n\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        logits = model(hetero_graph, node_features, dec_graph)\n        loss = F.cross_entropy(logits, edge_label)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\nDGL은 heterogeneous 그래프의 에지들에 대한 타입을 예측하는 문제인 평가 예측 예제로 `Graph Convolutional Matrix Completion <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcmc>`__ 를 제공한다. `모델 구현 파일 <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcmc>`__ 에 있는 노드 representation 모듈은 ``GCMCLayer`` 라고 불린다. 에지 타입 예측 모듈은 ``BiDecoder`` 라고 불린다. 이 둘은 여기서 설명하기에는 너무 복잡하니 자세한 설명은 생략한다.\n"
  },
  {
    "path": "docs/source/guide_ko/training-graph.rst",
    "content": ".. _guide_ko-training-graph-classification:\n\n5.4 그래프 분류\n------------\n\n:ref:`(English Version) <guide-training-graph-classification>`\n\n데이터가 커다란 하나의 그래프가 아닌 여러 그래프로 구성된 경우도 종종 있다. 예를 들면, 사람들의 커뮤니티의 여러 종류 목록 같은 것을 들 수 있다. 같은 커뮤니티에 있는 사람들의 친목 관계를 그래프로 특징을 지어본다면, 분류할 수 있는 그래프들의 리스트를 만들 수 있다. 이 상황에서 그래프 분류 모델을 이용해서 커뮤니티의 종류를 구별해볼 수 있다.\n\n개요\n~~~~~~~~~\n\n그래프 분류가 노드 분류나 링크 예측 문제와 주요 차이점은 예측 결과가 전체 입력 그래프의 특성을 나타낸다는 것이다. 이전 문제들과 똑같이 노드들이나 에지들에 대해서 메시지 전달을 수행하지만, 그래프 수준의 representation을 찾아내야한다.\n\n그래프 분류 파이프라인은 다음과 같다:\n\n.. figure:: https://data.dgl.ai/tutorial/batch/graph_classifier.png\n   :alt: Graph Classification Process\n\n   그래프 분류 프로세스\n\n\n일반적인 방법은 (왼쪽부터 오른쪽으로 진행):\n\n- 그래프들의 배치를 준비한다\n- 그래프들의 배치에 메시지 전달을 수행해서 노드/에지 피쳐를 업데이트한다\n- 노드/에지 피쳐들을 모두 합쳐서 그래프 수준의 representation들을 만든다\n- 그래프 수준의 representation들을 사용해서 그래프들을 분류한다\n\n그래프들의 배치(batch)\n^^^^^^^^^^^^^^^^^^\n\n보통의 경우 그래프 분류 문제는 많은 수의 그래프를 사용해서 학습하기 때문에, 모델을 학습할 때 그래프를 한개씩 사용하는 것은 굉장히 비효율적이다. 일반적 딥러닝에서 사용되는 미니-배치 학습의 아이디어를 발려와서, 그래프들의 배치를 만들어서 한번의 학습 이터레이션에 사용하는 것이 가능하다.\n\nDGL는 그래프들의 리스트로부터 하나의 배치 그래프(batched graph)를 생성할 수 있다. 단순하게, 이 배치 그래프는 원래의 작은 그래프들을 연결하는 컴포넌트를 가지고 있는 하나의 큰 그래프로 사용된다.\n\n.. figure:: https://data.dgl.ai/tutorial/batch/batch.png\n   :alt: Batched Graph\n\n   배치 그래프(Batched Graph)\n\n다음 코드 예제는 그래프들의 목록에 :func:`dgl.batch` 를 호출한다. 배치 그래프는 하나의 그래프이자, 그 리스트에 대한 정보를 담고 있다.\n\n.. code:: python\n\n    import dgl\n    import torch as th\n\n    g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))\n    g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))\n\n    bg = dgl.batch([g1, g2])\n    bg\n    # Graph(num_nodes=7, num_edges=7,\n    #       ndata_schemes={}\n    #       edata_schemes={})\n    bg.batch_size\n    # 2\n    bg.batch_num_nodes()\n    # tensor([4, 3])\n    bg.batch_num_edges()\n    # tensor([3, 4])\n    bg.edges()\n    # (tensor([0, 1, 2, 4, 4, 4, 5], tensor([1, 2, 3, 4, 5, 6, 4]))\n\n대부분의 DGL 변환 함수들은 배치 정보를 버린다는 점을 주의하자. 
이 정보를 유지하기 위해서, 변환된 그래프에  :func:`dgl.DGLGraph.set_batch_num_nodes` 와 :func:`dgl.DGLGraph.set_batch_num_edges` 를 사용한다.\n\n그래프 리드아웃(readout)\n^^^^^^^^^^^^^^^^^^^^\n\n모든 그래프는 노드와 에지의 피쳐들과 더불어 유일한 구조를 지니고 있다. 하나의 예측을 만들어내기 위해서, 보통은 아마도 풍부한 정보들을 합치고 요약한다. 이런 종류의 연산을 *리드아웃(readout)* 이라고 부른다. 흔히 쓰이는 리드아웃 연산들은 모든 노드 또는 에지 피쳐들에 대한 합(summation), 평균, 최대 또는 최소들이 있다.\n\n그래프 :math:`g` 에 대해서, 평균 노드 피처 리드아웃은 아래와 같이 정의된다.\n\n.. math:: h_g = \\frac{1}{|\\mathcal{V}|}\\sum_{v\\in \\mathcal{V}}h_v\n\n여기서 :math:`h_g` 는 :math:`g` 에 대한 representation이고, :math:`\\mathcal{V}` 는 :math:`g` 의 노드들의 집합, 그리고 :math:`h_v` 는 노드 :math:`v` 의 피쳐이다.\n\nDGL은 많이 쓰이는 리드아웃 연산들을 빌드인 함수로 지원한다. 예를 들어, :func:`dgl.mean_nodes` 는 위의 리드아웃 연산을 구현하고 있다.\n\n:math:`h_g` 가 구해진 후, 이를 MLP 레이어에 전달해서 분류 결과를 얻는다.\n\n뉴럴 네트워크 모델 작성하기\n~~~~~~~~~~~~~~~~~~~~\n\n모델에 대한 입력은 노드와 에지의 피쳐들 갖는 배치 그래프이다.\n\n배치 그래프에 연산하기\n^^^^^^^^^^^^^^^^\n\n첫째로, 배치 그래프에 있는 그래프들을 완전히 분리되어 있다. 즉, 두 그래들 사이에 에지가 존재하지 않는다. 이런 멋진 성질 덕에, 모든 메시지 전달 함수는 같은 결과를 만들어낸다. (즉 그래프 간의 간섭이 없다)\n\n두번째로, 배치 그래프에 대한 리드아웃 함수는 각 그래프에 별도록 수행된다. 배치 크기가 :math:`B` 이고 협쳐진 피쳐(aggregated feature)의 차원이 :math:`D` 인 경우, 리드아웃 결과의 shape은 :math:`(B, D)` 가 된다.\n\n.. code:: python\n\n    import dgl\n    import torch\n\n    g1 = dgl.graph(([0, 1], [1, 0]))\n    g1.ndata['h'] = torch.tensor([1., 2.])\n    g2 = dgl.graph(([0, 1], [1, 2]))\n    g2.ndata['h'] = torch.tensor([1., 2., 3.])\n\n    dgl.readout_nodes(g1, 'h')\n    # tensor([3.])  # 1 + 2\n\n    bg = dgl.batch([g1, g2])\n    dgl.readout_nodes(bg, 'h')\n    # tensor([3., 6.])  # [1 + 2, 1 + 2 + 3]\n\n마지막으로, 배치 그래프의 각 노드/에치 피쳐는 모든 그래프의 노드와 에지 피쳐들을 순서대로 연결해서 얻는다.\n\n.. code:: python\n\n    bg.ndata['h']\n    # tensor([1., 2., 1., 2., 3.])\n\n모델 정의하기\n^^^^^^^^^\n\n위 연산 규칙을 염두해서, 모델을 다음과 같이 정의한다.\n\n.. 
code:: python\n\n    import dgl.nn.pytorch as dglnn\n    import torch.nn as nn\n\n    class Classifier(nn.Module):\n        def __init__(self, in_dim, hidden_dim, n_classes):\n            super(Classifier, self).__init__()\n            self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)\n            self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)\n            self.classify = nn.Linear(hidden_dim, n_classes)\n\n        def forward(self, g, h):\n            # Apply graph convolution and activation.\n            h = F.relu(self.conv1(g, h))\n            h = F.relu(self.conv2(g, h))\n            with g.local_scope():\n                g.ndata['h'] = h\n                # Calculate graph representation by average readout.\n                hg = dgl.mean_nodes(g, 'h')\n                return self.classify(hg)\n\n학습 룹\n~~~~~\n\n데이터 로딩\n^^^^^^^^\n\n모델이 정의되었다면, 학습을 시작할 수 있다. 그래프 분류는 커다란 그래프 한개가 아니라 상대적으로 작은 그래프를 많이 다루기 때문에, 복잡한 그래프 샘플링 알고리즘을 사용하지 않고 그래프들의 stochastic 미니-배치를 사용해서 효과적으로 학습을 수행할 수 있다.\n\n:ref:`guide_ko-data-pipeline` 에서 소개한 그래프 분류 데이터셋을 사용하자.\n\n.. code:: python\n\n    import dgl.data\n    dataset = dgl.data.GINDataset('MUTAG', False)\n\n그래프 분류 데이터셋의 각 아이템은 한개의 그래프와 그 그래프의 레이블 쌍이다. 데이터 로딩 프로세스를 빠르게 하기 위해서 GraphDataLoader의 장점을 사용해 그래프들의 데이터셋을 미니-배치 단위로 iterate한다.\n\n.. code:: python\n\n    from dgl.dataloading import GraphDataLoader\n    dataloader = GraphDataLoader(\n        dataset,\n        batch_size=1024,\n        drop_last=False,\n        shuffle=True)\n\n학습 룹은 데이터로더를 iterate하면서 모델을 업데이트하는 것일 뿐이다.\n\n.. 
code:: python\n\n    import torch.nn.functional as F\n\n    # Only an example, 7 is the input feature size\n    model = Classifier(7, 20, 5)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(20):\n        for batched_graph, labels in dataloader:\n            feats = batched_graph.ndata['attr']\n            logits = model(batched_graph, feats)\n            loss = F.cross_entropy(logits, labels)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n\n`DGL's GIN example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin>`__ 의 end-to-end 그래프 분류 예를 참고하자. 이 학습 룹은 `main.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/main.py>`__ 의 `train` 함수안에 있다. 모델의 구현은 `gin.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/gin.py>`__ 에 있고, :class:`dgl.nn.pytorch.GINConv` (MXNet 및 Tensorflow 버전도 있음)와 같은 컴포넌트들과 graph convolution layer와 배치 normalization 등이 적용되어 있다.\n\nHeterogeneous 그래프\n~~~~~~~~~~~~~~~~~~\n\nHeterogeneous 그래프들에 대한 그래프 분류는 homogeneous 그래프의 경우와는 약간 차이가 있다. Heterogeneous 그래프와 호환되는 graph convolution 모듈에 더해서, 리드아웃 함수에서 다른 종류의 노드들에 대한 aggregate를 해야한다.\n\n다음 코드는 각 노트 타입에 대해서 노드 representation을 평균을 합산하는 예제이다.\n\n.. 
code:: python\n\n    class RGCN(nn.Module):\n        def __init__(self, in_feats, hid_feats, out_feats, rel_names):\n            super().__init__()\n\n            self.conv1 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(in_feats, hid_feats)\n                for rel in rel_names}, aggregate='sum')\n            self.conv2 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(hid_feats, out_feats)\n                for rel in rel_names}, aggregate='sum')\n\n        def forward(self, graph, inputs):\n            # inputs is features of nodes\n            h = self.conv1(graph, inputs)\n            h = {k: F.relu(v) for k, v in h.items()}\n            h = self.conv2(graph, h)\n            return h\n\n    class HeteroClassifier(nn.Module):\n        def __init__(self, in_dim, hidden_dim, n_classes, rel_names):\n            super().__init__()\n\n            self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)\n            self.classify = nn.Linear(hidden_dim, n_classes)\n\n        def forward(self, g):\n            h = g.ndata['feat']\n            h = self.rgcn(g, h)\n            with g.local_scope():\n                g.ndata['h'] = h\n                # Calculate graph representation by average readout.\n                hg = 0\n                for ntype in g.ntypes:\n                    hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)\n                return self.classify(hg)\n\n나머지 코드는 homogeneous 그래프의 경우와 다르지 않다.\n\n.. code:: python\n\n    # etypes is the list of edge types as strings.\n    model = HeteroClassifier(10, 20, 5, etypes)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(20):\n        for batched_graph, labels in dataloader:\n            logits = model(batched_graph)\n            loss = F.cross_entropy(logits, labels)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n"
  },
  {
    "path": "docs/source/guide_ko/training-link.rst",
    "content": ".. _guide_ko-training-link-prediction:\n\n5.3 링크 예측\n-----------\n\n:ref:`(English Version) <guide-training-link-prediction>`\n\n어떤 두 노드들 사이에 에지가 존재하는지 아닌지를 예측하고 싶은 경우가 있고, 이를 *링크 예측 과제* 라고 한다.\n\n개요\n~~~~~~~~~\n\nGNN 기반의 링크 예측 모델은 두 노드 :math:`u` 와 :math:`v` 간의 연결 가능도(likelihood)를 :math:`\\boldsymbol{h}_u^{(L)}` 의 함수로 표현하는데, 여기서 :math:`\\boldsymbol{h}_v^{(L)}` 는 멀티-레이어 GNN을 통해서 계단된 노드 representation이다. \n\n.. math::\n\n\n   y_{u,v} = \\phi(\\boldsymbol{h}_u^{(L)}, \\boldsymbol{h}_v^{(L)})\n\n:math:`y_{u,v}` 는 노드 :math:`u` 와 :math:`v` 사이의 점수를 뜻 한다.\n\n링크 예측 모델을 학습시키는 것은 에지로 연결된 두 노드들에 대한 점수와 임의의 두 노드 쌍에 대한 점수를 비교하면서 이뤄진다. 예를 들어, 노드 :math:`u` 와 :math:`v` 사이에 에지가 존재하는 경우 노드 :math:`u` 와 :math:`v` 사이의 점수가 노드 :math:`u` 와 임의의 *노이즈* 분표 :math:`v' \\sim P_n(v)`에 따라 샘플링된 노드 :math:`v'` 간의 점수보다 높도록 하는 학습이다.\n\n위를 달성하기 위한 다양한 loss 함수가 있다. 몇 가지 예는 다음과 같다:\n\n-  Cross-entropy loss:\n   :math:`\\mathcal{L} = - \\log \\sigma (y_{u,v}) - \\sum_{v_i \\sim P_n(v), i=1,\\dots,k}\\log \\left[ 1 - \\sigma (y_{u,v_i})\\right]`\n-  BPR loss:\n   :math:`\\mathcal{L} = \\sum_{v_i \\sim P_n(v), i=1,\\dots,k} - \\log \\sigma (y_{u,v} - y_{u,v_i})`\n-  Margin loss:\n   :math:`\\mathcal{L} = \\sum_{v_i \\sim P_n(v), i=1,\\dots,k} \\max(0, M - y_{u, v} + y_{u, v_i})`, 여기서 :math:`M` 은 상수 하이퍼-파라메터이다.\n\n`implicit feedback <https://arxiv.org/ftp/arxiv/papers/1205/1205.2618.pdf>`__ 이나 `noise-contrastive estimation <http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf>`__ 를 알고 있다면, 이 아이디어는 친숙할 것이다.\n\n:math:`u` 와 :math:`v` 사이의 점수를 계산하는 뉴럴 네트워크 모델은 :ref:`위에서 설명한 <guide_ko-training-edge-classification>`  에지 리그레션 모델과 동일하다.\n\n다음은 dot product를 사용해서 에지들의 점수를 계산하는 예제이다.\n\n.. 
code:: python\n\n    class DotProductPredictor(nn.Module):\n        def forward(self, graph, h):\n            # h contains the node representations computed from the GNN defined\n            # in the node classification section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))\n                return graph.edata['score']\n\n학습 룹\n~~~~~\n\n점수를 예측하는 모델은 그래프들에 적용되기 때문에, 네가티브 샘들은 별도의 그래프로 표현되어야 한다. 즉, 그것은 에지들이 모두 네가티브 노드들의 쌍들로만 구성된 그래프이다.\n\n아래 코드는 네가티브 샘들로 구성된 그래프를 만드는 예제이다. 각 에지 :math:`(u,v)` 는 :math:`k` 개의 네가티브 셈플들 :math:`(u,v_i)` 을 갖는다. 여기서 :math:`v_i` 는 균등 분포에서 샘플링된다.\n\n.. code:: python\n\n    def construct_negative_graph(graph, k):\n        src, dst = graph.edges()\n    \n        neg_src = src.repeat_interleave(k)\n        neg_dst = torch.randint(0, graph.num_nodes(), (len(src) * k,))\n        return dgl.graph((neg_src, neg_dst), num_nodes=graph.num_nodes())\n\n에지 점수를 예측하는 모델은 에지 분류 또는 에지 리그레션 모델과 같다.\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features):\n            super().__init__()\n            self.sage = SAGE(in_features, hidden_features, out_features)\n            self.pred = DotProductPredictor()\n        def forward(self, g, neg_g, x):\n            h = self.sage(g, x)\n            return self.pred(g, h), self.pred(neg_g, h)\n\n그런 다음, 학습 룹은 반복적으로 네가티브 그래프를 만들고 loss를 계산한다.\n\n.. 
code:: python\n\n    def compute_loss(pos_score, neg_score):\n        # Margin loss\n        n_edges = pos_score.shape[0]\n        return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()\n    \n    node_features = graph.ndata['feat']\n    n_features = node_features.shape[1]\n    k = 5\n    model = Model(n_features, 100, 100)\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        negative_graph = construct_negative_graph(graph, k)\n        pos_score, neg_score = model(graph, negative_graph, node_features)\n        loss = compute_loss(pos_score, neg_score)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n학습이 종료되면, 노드 representation은 다음과 같이 얻을 수 있다:\n\n.. code:: python\n\n    node_embeddings = model.sage(graph, node_features)\n\n노드 임베딩을 사용하는 방법은 여러가지가 있다. 몇가지 예를 들면, 다운스트림 분류기 학습, 관련된 엔터리 추천을 위한 nearest neighbor search 또는 maximum inner product search와 같은 것이 있다.\n\nHeterogeneous 그래프들\n~~~~~~~~~~~~~~~~~~~~\n\nHeterogeneous 그래프에서의 링크 예측은 homogeneous 그래프에서의 링크 예측과 많이 다르지 않다. 다음 예제는 하나의 에지 타입에 대해서 예측을 수행한다고 가정하고 있는데, 이를 여러 에지 타입으로 확장하는 것은 쉽다.\n\n링크 예측을 위해서 :ref:`앞에서 <guide_ko-training-edge-classification-heterogeneous-graph>` 의 ``HeteroDotProductPredictor`` 를 재활용해서 한 에지 타입에 대한 에지의 점수를 계산할 수 있다.\n\n.. code:: python\n\n    class HeteroDotProductPredictor(nn.Module):\n        def forward(self, graph, h, etype):\n            # h contains the node representations for each node type computed from\n            # the GNN defined in the previous section (Section 5.1).\n            with graph.local_scope():\n                graph.ndata['h'] = h\n                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)\n                return graph.edges[etype].data['score']\n\n네가티브 샘플링을 수행하기 위해서, 링크 예측을 수행할 에지 타입에 대한 네가티브 그램프를 생성하면 된다.\n\n.. 
code:: python\n\n    def construct_negative_graph(graph, k, etype):\n        utype, _, vtype = etype\n        src, dst = graph.edges(etype=etype)\n        neg_src = src.repeat_interleave(k)\n        neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))\n        return dgl.heterograph(\n            {etype: (neg_src, neg_dst)},\n            num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})\n\n모델은 heterogeneous 그래프들에서 에지 분류하는 모델과는 약간 다른데, 그 이유는 링크 예측을 할 때 에지 타입을 지정해야하기 때문이다.\n\n.. code:: python\n\n    class Model(nn.Module):\n        def __init__(self, in_features, hidden_features, out_features, rel_names):\n            super().__init__()\n            self.sage = RGCN(in_features, hidden_features, out_features, rel_names)\n            self.pred = HeteroDotProductPredictor()\n        def forward(self, g, neg_g, x, etype):\n            h = self.sage(g, x)\n            return self.pred(g, h, etype), self.pred(neg_g, h, etype)\n\n학습 룹은 homogeneous 그래프에 대한 학습 룹과 비슷하다.\n\n.. code:: python\n\n    def compute_loss(pos_score, neg_score):\n        # Margin loss\n        n_edges = pos_score.shape[0]\n        return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()\n    \n    k = 5\n    model = Model(10, 20, 5, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    node_features = {'user': user_feats, 'item': item_feats}\n    opt = torch.optim.Adam(model.parameters())\n    for epoch in range(10):\n        negative_graph = construct_negative_graph(hetero_graph, k, ('user', 'click', 'item'))\n        pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ('user', 'click', 'item'))\n        loss = compute_loss(pos_score, neg_score)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n\n\n\n"
  },
  {
    "path": "docs/source/guide_ko/training-node.rst",
    "content": ".. _guide_ko-training-node-classification:\n\n5.1 노드 분류/리그래션(Regression)\n--------------------------------------------------\n\n:ref:`(English Version) <guide-training-node-classification>`\n\n가장 유명하고 널리 적용되고 있는 그래프 뉴럴 네트워크 중에 하나가 노드 분류이다. 학습/검증/테스트 셋의 각 노드는 미리 정해진 카테로기들로 중에 하나를 ground truth 카테고리로 분류되어 있다. 노드 regression도 비슷하다. 학습/검증/테스트 셋의 각 노드에 ground truth 수가 할당되어 있다.\n\n개요\n~~~~~~\n\n노드를 분류하기 위해서 그래프 뉴럴 네트워크는 :ref:`guide_ko-message-passing` 에서 소개한 메시지 전달 방법을 수행해서 노드 자신의 피쳐 뿐만 아니라 그 노드의 이웃 노드 및 에지의 피쳐도 함께 활용한다. 메시지 전달은 여러 회 반복해서 더 큰 범위의 이웃들에 대한 정보를 활용할 수 있다.\n\n뉴럴 네트워크 모델 작성하기\n~~~~~~~~~~~~~~~~~~~~\n\nDGL은 한 차례 메시지 전달을 수행하는 몇 가지 빌트인 graph convolution 모듈을 제공한다. 여기서 우리는 GraphSAGE에서 사용되는 graph convolution 모듈인 :class:`dgl.nn.pytorch.SAGEConv` (MXNet과 TensorFlow에서도 사용 가능)를 사용한다.\n\n보통 그래프에 대한 딥러닝 모델에서는 메시지 전달이 여러 번 수행되는 멀티-레이어 그래프 뉴럴 네트워크가 필요하다. 이는 다음 코드처럼 graph convolution 모듈들을 쌓아서 구현할 수 있다.\n\n.. code:: python\n\n    # Contruct a two-layer GNN model\n    import dgl.nn as dglnn\n    import torch.nn as nn\n    import torch.nn.functional as F\n    class SAGE(nn.Module):\n        def __init__(self, in_feats, hid_feats, out_feats):\n            super().__init__()\n            self.conv1 = dglnn.SAGEConv(\n                in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')\n            self.conv2 = dglnn.SAGEConv(\n                in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')\n      \n        def forward(self, graph, inputs):\n            # inputs are features of nodes\n            h = self.conv1(graph, inputs)\n            h = F.relu(h)\n            h = self.conv2(graph, h)\n            return h\n\n위 모델은 노드 분류 뿐만 아니라, :ref:`guide_ko-training-edge-classification` , :ref:`guide_ko-training-link-prediction` , 또는 :ref:`guide_ko-training-graph-classification` 와 같은 다른 다운스트림 테스크들을 위한 히든 노드 표현을 구하기 위해서 사용될 수 있음을 알아두자.\n\n빌트인 graph convolution 모듈의 전체 목록은 :ref:`apinn` 를 참고하자.\n\nDGL 뉴럴 네트워크 모듈이 어떻게 동작하는지 그리고 메시지 전달을 활용한 커스텀 뉴럴 네트워크 
모듈을 작성하는 방법은 :ref:`guide_ko-nn` 에 있는 예제들을 참고하자.\n\n학습 룹(loop)\n~~~~~~~~~~~\n\n전체 그래프를 이용한 학습은 단지 위에서 정의된 모델에 forward propagation 그리고 학습 노드들의 groud truth 레이블과 예측을 비교해서 loss를 계산하는 것으로 구성된다.\n\n이 절은 빌드인 데이터셋 :class:`dgl.data.CiteseerGraphDataset` 을 사용해서 학습 룹을 설명한다. 노드 피처 및 레이블은 각 그래프 인스턴스에 저장되어 있고, 학습-검증-테스트 분할 또한 그래프에 이진 마스크로서 저장되어 있다. 이는 :ref:`guide_ko-data-pipeline` 에서 본것과 비슷하다.\n\n.. code:: python\n\n    node_features = graph.ndata['feat']\n    node_labels = graph.ndata['label']\n    train_mask = graph.ndata['train_mask']\n    valid_mask = graph.ndata['val_mask']\n    test_mask = graph.ndata['test_mask']\n    n_features = node_features.shape[1]\n    n_labels = int(node_labels.max().item() + 1)\n\n다음은 정확도(accuracy)로 모델을 평가하는 예제 코드이다.\n\n.. code:: python\n\n    def evaluate(model, graph, features, labels, mask):\n        model.eval()\n        with torch.no_grad():\n            logits = model(graph, features)\n            logits = logits[mask]\n            labels = labels[mask]\n            _, indices = torch.max(logits, dim=1)\n            correct = torch.sum(indices == labels)\n            return correct.item() * 1.0 / len(labels)\n\n그리고, 학습 룹은 다음과 같이 작성할 수 있다.\n\n.. code:: python\n\n    model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)\n    opt = torch.optim.Adam(model.parameters())\n    \n    for epoch in range(10):\n        model.train()\n        # forward propagation by using all nodes\n        logits = model(graph, node_features)\n        # compute loss\n        loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])\n        # compute validation accuracy\n        acc = evaluate(model, graph, node_features, node_labels, valid_mask)\n        # backward propagation\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n    \n        # Save model if necessary.  
Omitted in this example.\n\n`GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_full.py>`__ 는 end-to-end homogeneous 그래프 노드 분류 예제를 제공한다. 해당 모델은 ``GraphSAGE`` 클래스에 구현되어 있고, 조정가능 한 레이어 수, dropout 확률들, 그리고 커스터마이징이 가능한 aggregation 함수 및 비선형성 등의 예제가 포함되어 있다.\n\n.. _guide_ko-training-rgcn-node-classification:\n\nHeterogeneous 그래프\n~~~~~~~~~~~~~~~~~~\n\n만약 그래프가 heterogeneous(이종)이라면, 여러분은 노드의 모든 에지 타입에 대한 이웃들로부터 메시지를 수집하기를 원할 것이다. 모든 에지 종류에 대해서 각 에지 타입별로 서로 다른 graph convolution 모듈을 사용한 메시지 전달을 수행하는 것은, :class:`dgl.nn.pytorch.HeteroGraphConv` (MXNet과 Tensorflow에서도 제공함) 모듈을 사용해서 가능하다.\n\n아래 코드는 heterogeneous graph convolution을 정의하는데, 이는 각 에지 타입에 따라 별도의 graph convolution을 수행하고, 모든 노드 타입들에 대한 결과로서 각 에지 타입에 대한 메시지 aggregation 값들을 합하는 일을 수행한다.\n\n.. code:: python\n\n    # Define a Heterograph Conv model\n\n    class RGCN(nn.Module):\n        def __init__(self, in_feats, hid_feats, out_feats, rel_names):\n            super().__init__()\n            \n            self.conv1 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(in_feats, hid_feats)\n                for rel in rel_names}, aggregate='sum')\n            self.conv2 = dglnn.HeteroGraphConv({\n                rel: dglnn.GraphConv(hid_feats, out_feats)\n                for rel in rel_names}, aggregate='sum')\n      \n        def forward(self, graph, inputs):\n            # inputs are features of nodes\n            h = self.conv1(graph, inputs)\n            h = {k: F.relu(v) for k, v in h.items()}\n            h = self.conv2(graph, h)\n            return h\n\n``dgl.nn.HeteroGraphConv`` 는 노드 타입들과 노드 피쳐 텐서들의 사전을 입력으로 받고, 노드 타입과 노드 피쳐의 다른 사전을 리턴한다.\n\n여기서 사용되는 데이터셋은 이미 user 및 item 피쳐를 가지고 있고, 이는 :ref:`heterogeneous graph example <guide_ko-training-heterogeneous-graph-example>` 에서 확인할 수 있다.\n\n.. 
code:: python\n\n    model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)\n    user_feats = hetero_graph.nodes['user'].data['feature']\n    item_feats = hetero_graph.nodes['item'].data['feature']\n    labels = hetero_graph.nodes['user'].data['label']\n    train_mask = hetero_graph.nodes['user'].data['train_mask']\n\nForward propagation은 다음과 같이 단순하게 실행된다.\n\n.. code:: python\n\n    node_features = {'user': user_feats, 'item': item_feats}\n    h_dict = model(hetero_graph, {'user': user_feats, 'item': item_feats})\n    h_user = h_dict['user']\n    h_item = h_dict['item']\n\n학습 룹은 예측을 계산할 노드 representation들의 사전을 사용하는 것을 제외하고는 homogeneous graph의 학습 룹과 동일하다. 예를 들어, ``user`` 노드 만을 예측하고 싶다면, 단지 리턴된 사전에서 ``user`` 노드 임베딩을 추출하면 된다.\n\n.. code:: python\n\n    opt = torch.optim.Adam(model.parameters())\n    \n    for epoch in range(5):\n        model.train()\n        # forward propagation by using all nodes and extracting the user embeddings\n        logits = model(hetero_graph, node_features)['user']\n        # compute loss\n        loss = F.cross_entropy(logits[train_mask], labels[train_mask])\n        # Compute validation accuracy.  Omitted in this example.\n        # backward propagation\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        print(loss.item())\n    \n        # Save model if necessary.  Omitted in the example.\n\nDGL은 `RGCN <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify.py>`__ 의 end-to-end 예제를 제공한다. Heterogeneous graph convolution의 정의는 `모델 구현 파일 <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/model.py>`__ ``RelGraphConvLayer`` 에서 확인할 수 있다.\n\n"
  },
  {
    "path": "docs/source/guide_ko/training.rst",
    "content": ".. _guide_ko-training:\n\n5장: 그래프 뉴럴 네트워크 학습하기\n==========================\n\n:ref:`(English Version) <guide-training>`\n\n개요\n----------------\n\n이 장에서는 :ref:`guide_ko-message-passing` 에서 소개한 메시지 전달 방법과 :ref:`guide_ko-nn` 에서 소개한 뉴럴 네트워크 모듈을 사용해서 작은 그래프들에 대한 노드 분류, 에지 분류, 링크 예측, 그리고 그래프 분류를 위한 그래프 뉴럴 네트워크를 학습하는 방법에 대해서 알아본다.\n\n여기서는 그래프 및 노드 및 에지 피쳐들이 GPU 메모리에 들어갈 수 있는 크기라고 가정한다. 만약 그렇지 않다면, :ref:`guide_ko-minibatch` 를 참고하자.\n\n그리고, 그래프와 노드/에지 피쳐들은 이미 프로세싱되어 있다고 가정한다. 만약 DGL에서 제공되는 데이터셋 또는 :ref:`guide_ko-data-pipeline` 에서 소개한 ``DGLDataset`` 과 호환되는 다른 데이터셋을 사용할 계획이라면, 다음과 같이 단일-그래프 데이터셋을 위한 그래프를 얻을 수 있다.\n\n.. code:: python\n\n    import dgl\n    \n    dataset = dgl.data.CiteseerGraphDataset()\n    graph = dataset[0]\n\n주의: 이 장의 예제들은 PyTorch를 백엔드로 사용한다.\n\n.. _guide_ko-training-heterogeneous-graph-example:\n\nHeterogeneous 그래프\n~~~~~~~~~~~~~~~~~~\n\n때로는 heterogeneous 그래프를 사용할 경우도 있다. 노드 분류, 에지 분류, 그리고 링크 예측 과제들의 예제를 위해서 임의로 만든 heterogeneous 그래프를 사용하겠다.\n\n임의로 생성한 heterogeneous 그래프 ``hetero_graph`` 는 다음과 같은 에지 타입을 갖는다:\n\n-  ``('user', 'follow', 'user')``\n-  ``('user', 'followed-by', 'user')``\n-  ``('user', 'click', 'item')``\n-  ``('item', 'clicked-by', 'user')``\n-  ``('user', 'dislike', 'item')``\n-  ``('item', 'disliked-by', 'user')``\n\n.. 
code:: python\n\n    import numpy as np\n    import torch\n    \n    n_users = 1000\n    n_items = 500\n    n_follows = 3000\n    n_clicks = 5000\n    n_dislikes = 500\n    n_hetero_features = 10\n    n_user_classes = 5\n    n_max_clicks = 10\n    \n    follow_src = np.random.randint(0, n_users, n_follows)\n    follow_dst = np.random.randint(0, n_users, n_follows)\n    click_src = np.random.randint(0, n_users, n_clicks)\n    click_dst = np.random.randint(0, n_items, n_clicks)\n    dislike_src = np.random.randint(0, n_users, n_dislikes)\n    dislike_dst = np.random.randint(0, n_items, n_dislikes)\n    \n    hetero_graph = dgl.heterograph({\n        ('user', 'follow', 'user'): (follow_src, follow_dst),\n        ('user', 'followed-by', 'user'): (follow_dst, follow_src),\n        ('user', 'click', 'item'): (click_src, click_dst),\n        ('item', 'clicked-by', 'user'): (click_dst, click_src),\n        ('user', 'dislike', 'item'): (dislike_src, dislike_dst),\n        ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})\n    \n    hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)\n    hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)\n    hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))\n    hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()\n    # randomly generate training masks on user nodes and click edges\n    hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)\n    hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)\n\n\n로드맵\n----\n\n이 장은 그래프 학습 테스크를 설명하기 위해서 4개의 절로 구성되어 있다.\n\n* :ref:`guide_ko-training-node-classification`\n* :ref:`guide_ko-training-edge-classification`\n* :ref:`guide_ko-training-link-prediction`\n* :ref:`guide_ko-training-graph-classification`\n\n.. 
toctree::\n    :maxdepth: 1\n    :hidden:\n    :glob:\n\n    training-node\n    training-edge\n    training-link\n    training-graph\n"
  },
  {
    "path": "docs/source/index.rst",
    "content": ".. DGL documentation master file, created by\n   sphinx-quickstart on Fri Oct  5 14:18:01 2018.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\nWelcome to Deep Graph Library Tutorials and Documentation\n=========================================================\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Get Started\n   :hidden:\n   :glob:\n\n   install/index\n   tutorials/blitz/index\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Advanced Materials\n   :hidden:\n   :titlesonly:\n   :glob:\n\n   stochastic_training/index\n   guide/index\n   guide_cn/index\n   guide_ko/index\n   graphtransformer/index\n   notebooks/sparse/index\n   tutorials/cpu/index\n   tutorials/multi/index\n   tutorials/dist/index\n   tutorials/models/index\n\n.. toctree::\n   :maxdepth: 2\n   :caption: API Reference\n   :hidden:\n   :glob:\n\n   api/python/dgl\n   api/python/dgl.data\n   api/python/dgl.dataloading\n   api/python/dgl.DGLGraph\n   api/python/dgl.distributed\n   api/python/dgl.function\n   api/python/dgl.geometry\n   api/python/dgl.graphbolt\n   api/python/nn-pytorch\n   api/python/nn.functional\n   api/python/dgl.ops\n   api/python/dgl.optim\n   api/python/dgl.sampling\n   api/python/dgl.sparse_v0\n   api/python/dgl.multiprocessing\n   api/python/transforms\n   api/python/udf\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Notes\n   :hidden:\n   :glob:\n\n   contribute\n   developer/ffi\n   performance\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Misc\n   :hidden:\n   :glob:\n\n   faq\n   env_var\n   resources\n\n\nDeep Graph Library (DGL) is a Python package built for easy implementation of\ngraph neural network model family, on top of existing DL frameworks (currently\nsupporting PyTorch, MXNet and TensorFlow). 
It offers versatile control of message passing,\nspeed optimization via auto-batching and highly tuned sparse matrix kernels,\nand multi-GPU/CPU training to scale to graphs of hundreds of millions of\nnodes and edges.\n\nGetting Started\n---------------\n\nFor absolute beginners, start with the :doc:`Blitz Introduction to DGL <tutorials/blitz/index>`.\nIt covers the basic concepts of common graph machine learning tasks and a step-by-step guide\non building Graph Neural Networks (GNNs) to solve them.\n\nFor acquainted users who wish to learn more advanced usage,\n\n* `Learn DGL by examples <https://github.com/dmlc/dgl/tree/master/examples>`_.\n* Read the :doc:`User Guide<guide/index>` (:doc:`中文版链接<guide_cn/index>`), which explains the concepts\n  and usage of DGL in much more detail.\n* Go through the tutorials for :doc:`Stochastic Training of GNNs <stochastic_training/index>`,\n  which cover the basic steps for training GNNs on large graphs in mini-batches.\n* :doc:`Study classical papers <tutorials/models/index>` on graph machine learning alongside DGL.\n* Search for the usage of a specific API in the :doc:`API reference manual <api/python/index>`,\n  which organizes all DGL APIs by their namespace.\n\nContribution\n-------------\nDGL is free software; you can redistribute it and/or modify it under the terms\nof the Apache License 2.0. We welcome contributions.\nJoin us on `GitHub <https://github.com/dmlc/dgl>`_ and check out our\n:doc:`contribution guidelines <contribute>`.\n\nIndex\n-----\n* :ref:`genindex`\n"
  },
  {
    "path": "docs/source/install/index.rst",
    "content": "Install and Setup\n=================\n\nSystem requirements\n-------------------\nDGL works with the following operating systems:\n\n* Ubuntu 20.04+\n* CentOS 8+ (Although gcc 9 is needed)\n* RHEL 8+\n* macOS X\n* Windows 10\n\nDGL requires Python version 3.7, 3.8, 3.9, 3.10, 3.11.\n\nDGL supports multiple tensor libraries as backends, e.g., PyTorch, MXNet. For requirements on backends and how to select one, see :ref:`backends`.\n\nStarting at version 0.3, DGL is separated into CPU and CUDA builds.  The builds share the\nsame Python package name. If you install DGL with a CUDA 9 build after you install the\nCPU build, then the CPU build is overwritten.\n\nInstall from Conda or Pip\n-------------------------\n\nWe recommend installing DGL by ``conda`` or ``pip``.\nCheck out the instructions on the `Get Started page <https://www.dgl.ai/pages/start.html>`_.\n\n.. note::\n\n   For Windows users: you will need to install `Visual C++ 2015 Redistributable <https://www.microsoft.com/en-us/download/details.aspx?id=48145>`_.\n\n.. _install-from-source:\n\nInstall from source\n-------------------\nDownload the source files from GitHub.\n\n.. code:: bash\n\n   git clone --recurse-submodules https://github.com/dmlc/dgl.git\n\n(Optional) Clone the repository first, and then run the following:\n\n.. code:: bash\n\n   git submodule update --init --recursive\n\nLinux\n`````\n\nInstall the system packages for building the shared library. For Debian and Ubuntu\nusers, run:\n\n.. code:: bash\n\n   sudo apt-get update\n   sudo apt-get install -y build-essential python3-dev make cmake\n\nFor Fedora/RHEL/CentOS users, run:\n\n.. code:: bash\n\n   sudo yum install -y gcc-c++ python3-devel make cmake\n\nTo create a Conda environment for CPU development, run:\n\n.. code:: bash\n\n   bash script/create_dev_conda_env.sh -c\n\nTo create a Conda environment for GPU development, run:\n\n.. 
code:: bash\n\n   bash script/create_dev_conda_env.sh -g 11.7\n\n\nTo further configure the conda environment, run the following command for more details:\n\n.. code:: bash\n\n   bash script/create_dev_conda_env.sh -h\n\nTo build the shared library for CPU development, run:\n\n.. code:: bash\n\n   bash script/build_dgl.sh -c\n\nTo build the shared library for GPU development, run:\n\n.. code:: bash\n\n   bash script/build_dgl.sh -g\n\nTo further build the shared library, run the following command for more details:\n\n.. code:: bash\n\n   bash script/build_dgl.sh -h\n\nFinally, install the Python binding.\n\n.. code:: bash\n\n   cd python\n   python setup.py install\n   # Build Cython extension\n   python setup.py build_ext --inplace\n\nmacOS\n`````\n\nInstallation on macOS is similar to Linux. But macOS users need to install build tools like clang, GNU Make, and cmake first. These installation steps were tested on macOS X with clang 10.0.0, GNU Make 3.81, and cmake 3.13.1.\n\nTools like clang and GNU Make are packaged in **Command Line Tools** for macOS. To\ninstall, run the following:\n\n.. code:: bash\n\n   xcode-select --install\n\nTo install other needed packages like cmake, we recommend first installing\n**Homebrew**, which is a popular package manager for macOS. To learn more, see the `Homebrew website <https://brew.sh/>`_.\n\nAfter you install Homebrew, install cmake.\n\n.. code:: bash\n\n   brew install cmake\n\nGo to root directory of the DGL repository, build a shared library, and\ninstall the Python binding for DGL.\n\n.. code:: bash\n\n   mkdir build\n   cd build\n   cmake -DUSE_OPENMP=off -DUSE_LIBXSMM=OFF ..\n   make -j4\n   cd ../python\n   python setup.py install\n   # Build Cython extension\n   python setup.py build_ext --inplace\n\nWindows\n```````\n\nYou can build DGL with MSBuild.  
With `MS Build Tools <https://go.microsoft.com/fwlink/?linkid=840931>`_\nand `CMake on Windows <https://cmake.org/download/>`_ installed, run the following\nin VS2019 x64 Native tools command prompt.\n\n* CPU only build::\n\n     MD build\n     CD build\n     cmake -DCMAKE_CXX_FLAGS=\"/DDGL_EXPORTS\" -DCMAKE_CONFIGURATION_TYPES=\"Release\" -DDMLC_FORCE_SHARED_CRT=ON .. -G \"Visual Studio 16 2019\"\n     msbuild dgl.sln /m\n     CD ..\\python\n     python setup.py install\n\n* CUDA build::\n\n     MD build\n     CD build\n     cmake -DCMAKE_CXX_FLAGS=\"/DDGL_EXPORTS\" -DCMAKE_CONFIGURATION_TYPES=\"Release\" -DDMLC_FORCE_SHARED_CRT=ON -DUSE_CUDA=ON .. -G \"Visual Studio 16 2019\"\n     msbuild dgl.sln /m\n     CD ..\\python\n     python setup.py install\n\n\n.. _backends:\n\nWorking with different backends\n-------------------------------\n\nDGL supports PyTorch, MXNet and Tensorflow backends. \nDGL will choose the backend on the following options (high priority to low priority)\n\n* Use the ``DGLBACKEND`` environment variable:\n\n   - You can use ``DGLBACKEND=[BACKEND] python gcn.py ...`` to specify the backend\n   - Or ``export DGLBACKEND=[BACKEND]`` to set the global environment variable \n\n* Modify the ``config.json`` file under \"~/.dgl\":\n\n   - You can use ``python -m dgl.backend.set_default_backend [BACKEND]`` to set the default backend\n\nCurrently BACKEND can be chosen from mxnet, pytorch, tensorflow.\n\nPyTorch backend\n```````````````\n\nExport ``DGLBACKEND`` as ``pytorch`` to specify PyTorch backend. The required PyTorch\nversion is 1.12.0 or later. See `pytorch.org <https://pytorch.org>`_ for installation instructions.\n\nMXNet backend\n`````````````\n\nExport ``DGLBACKEND`` as ``mxnet`` to specify MXNet backend. The required MXNet version is\n1.6 or later. 
See `mxnet.apache.org <https://mxnet.apache.org/get_started>`_ for installation\ninstructions.\n\nMXNet uses uint32 as the default data type for integer tensors, which only supports graphs of\nsize smaller than 2^32. To enable large graph training, *build* MXNet with the ``USE_INT64_TENSOR_SIZE=1``\nflag. See `this FAQ <https://mxnet.apache.org/api/faq/large_tensor_support>`_ for more information.\n\nMXNet 1.5 and later has an option to enable Numpy shape mode for ``NDArray`` objects; some DGL models\nneed this mode to be enabled to run correctly. However, this mode may not be compatible with pretrained\nmodel parameters saved with this mode disabled, e.g. pretrained models from GluonCV and GluonNLP.\nBy setting ``DGL_MXNET_SET_NP_SHAPE``, users can switch this mode on or off.\n\nTensorflow backend\n``````````````````\n\nExport ``DGLBACKEND`` as ``tensorflow`` to specify Tensorflow backend. The required Tensorflow\nversion is 2.3.0 or later. See `tensorflow.org <https://www.tensorflow.org/install>`_ for installation\ninstructions. In addition, DGL will set ``TF_FORCE_GPU_ALLOW_GROWTH`` to ``true`` to prevent Tensorflow from taking over the whole GPU memory.\n\n"
  },
  {
    "path": "docs/source/notebooks/sparse/gcn.nblink",
    "content": "{\n    \"path\": \"../../../../notebooks/sparse/gcn.ipynb\"\n}\n"
  },
  {
    "path": "docs/source/notebooks/sparse/graph_diffusion.nblink",
    "content": "{\n    \"path\": \"../../../../notebooks/sparse/graph_diffusion.ipynb\"\n}\n"
  },
  {
    "path": "docs/source/notebooks/sparse/graph_transformer.nblink",
    "content": "{\n    \"path\": \"../../../../notebooks/sparse/graph_transformer.ipynb\"\n}\n"
  },
  {
    "path": "docs/source/notebooks/sparse/hgnn.nblink",
    "content": "{\n    \"path\": \"../../../../notebooks/sparse/hgnn.ipynb\"\n}\n"
  },
  {
    "path": "docs/source/notebooks/sparse/index.rst",
    "content": "Tutorials: dgl.sparse\n=========================\n\nThe tutorial set cover the basic usage of DGL's sparse matrix class and operators. You can begin with \"Quickstart\" and \"Building a Graph Convolutional Network Using Sparse Matrices\". The rest of the tutorials demonstrate the usage by end-to-end examples. All the tutorials are written in Jupyter Notebook and can be played on Google Colab.\n\n.. toctree::\n  :maxdepth: 3\n  :titlesonly:\n\n  quickstart.nblink\n  gcn.nblink\n  graph_diffusion.nblink\n  hgnn.nblink\n  graph_transformer.nblink\n"
  },
  {
    "path": "docs/source/notebooks/sparse/quickstart.nblink",
    "content": "{\n    \"path\": \"../../../../notebooks/sparse/quickstart.ipynb\"\n}\n"
  },
  {
    "path": "docs/source/performance.rst",
    "content": "Performance Benchmarks\n======================\n\nIntegrated Benchmarks\n---------------------\n\nDGL continuously evaluates the speed of its core APIs, kernels as well as the training speed\nof the state-of-the-art GNN models. The benchmark code is available at\n`the main repository <https://github.com/dmlc/dgl/tree/master/benchmarks>`_. They are triggered\nfor every nightly-built version and the results are published to\n`https://asv.dgl.ai/ <https://asv.dgl.ai>`_.\n\nv0.6 Benchmarks\n---------------\n\nTo understand the performance gain of DGL v0.6, we re-evaluated it on the v0.5 benchmarks\nplus some new ones for graph classification tasks against the updated baselines. The results\nare available in `a standalone repository <https://github.com/dglai/dgl-0.5-benchmark>`_.\n\nv0.5 Benchmarks\n---------------\n\nCheck out our paper `Deep Graph Library: \nA Graph-Centric, Highly-Performant Package for Graph Neural Networks <https://arxiv.org/abs/1909.01315>`_.\n\nv0.4.3 Benchmarks\n------------------\n\n**Microbenchmark on speed and memory usage**:\nWhile leaving tensor and autograd functions to backend frameworks (e.g.\nPyTorch, MXNet, and TensorFlow), DGL aggressively optimizes storage and\ncomputation with its own kernels. Here's a comparison to another popular\npackage -- PyTorch Geometric (PyG). 
The short story is that raw speed is\nsimilar, but DGL has much better memory management.\n\n+----------+--------------+-----------------+-------------------------+-------------------------+\n| Dataset  |    Model     |   Accuracy      |         Time            |           Memory        |\n|          |              |                 +------------+------------+------------+------------+\n|          |              |                 |  PyG       |  DGL       |  PyG       |  DGL       |\n+==========+==============+=================+============+============+============+============+\n| Cora     | GCN          | 81.31 ± 0.88    | **0.478**  | 0.666      | 1.1        | 1.1        |\n+          +--------------+-----------------+------------+------------+------------+------------+\n|          | GAT          | 83.98 ± 0.52    | 1.608      | **1.399**  | 1.2        | **1.1**    |\n+----------+--------------+-----------------+------------+------------+------------+------------+\n| CiteSeer | GCN          | 70.98 ± 0.68    | **0.490**  | 0.674      | 1.1        | 1.1        |\n+          +--------------+-----------------+------------+------------+------------+------------+\n|          | GAT          | 69.96 ± 0.53    | 1.606      | **1.399**  | 1.3        | **1.1**    |\n+----------+--------------+-----------------+------------+------------+------------+------------+\n| PubMed   | GCN          | 79.00 ± 0.41    | **0.491**  | 0.690      | 1.1        | 1.1        |\n+          +--------------+-----------------+------------+------------+------------+------------+\n|          | GAT          | 77.65 ± 0.32    | 1.946      | **1.393**  | 1.6        | **1.1**    |\n+----------+--------------+-----------------+------------+------------+------------+------------+\n| Reddit   |     GCN      | 93.46 ± 0.06    | OOM        | **28.6**   | OOM        |  **11.7**  |\n+----------+--------------+-----------------+------------+------------+------------+------------+\n| Reddit-S |     GCN      
| N/A             | 29.12      | **9.44**   | 15.7       |  **3.6**   |\n+----------+--------------+-----------------+------------+------------+------------+------------+\n\nTable: Training time (in seconds) for 200 epochs and memory consumption (GB)\n\nHere is another comparison of DGL on TensorFlow backend with other TF-based GNN tools (training time in seconds for one epoch):\n\n+---------+-------+--------+----------+--------------+\n| Dataset | Model | DGL    | GraphNet | tf_geometric |\n+=========+=======+========+==========+==============+\n| Cora    | GCN   | 0.0148 | 0.0152   | 0.0192       |\n+---------+-------+--------+----------+--------------+\n| Reddit  | GCN   | 0.1095 | OOM      | OOM          |\n+---------+-------+--------+----------+--------------+\n| PubMed  | GCN   | 0.0156 | 0.0553   | 0.0185       |\n+---------+-------+--------+----------+--------------+\n| PPI     | GCN   | 0.09   | 0.16     | 0.21         |\n+---------+-------+--------+----------+--------------+\n| Cora    | GAT   | 0.0442 | n/a      | 0.058        |\n+---------+-------+--------+----------+--------------+\n| PPI     | GAT   | 0.398  | n/a      | 0.752        |\n+---------+-------+--------+----------+--------------+\n\nHigh memory utilization allows DGL to push the limit of single-GPU performance, as seen in the images below.\n\n.. image:: http://data.dgl.ai/asset/image/DGLvsPyG-time1.png\n\n.. image:: http://data.dgl.ai/asset/image/DGLvsPyG-time2.png\n\n**Scalability**:\nDGL has fully leveraged multiple GPUs in both one machine and clusters for\nincreasing training speed, and has better performance than alternatives, as\nseen in the images below.\n\n.. image:: http://data.dgl.ai/asset/image/one-four-GPUs.png\n\n.. image:: http://data.dgl.ai/asset/image/one-four-GPUs-DGLvsGraphVite.png\n\n.. image:: http://data.dgl.ai/asset/image/one-fourMachines.png\n\n**Further reading**:\nDetailed comparison of DGL and other alternatives can be found\n`here <https://arxiv.org/abs/1909.01315>`_.\n"
  },
  {
    "path": "docs/source/resources.rst",
    "content": "Resources\n=========\n* If you are new to deep learning, `Dive into Deep Learning <http://diveintodeeplearning.org>`__\n  is a nice book to start with.\n* `Pytorch tutorials <https://pytorch.org/tutorials/>`__\n* Thomas Kipf's `blog on Graph Convolutional Networks <https://tkipf.github.io/graph-convolutional-networks/>`__\n"
  },
  {
    "path": "docs/source/stochastic_training/index.rst",
    "content": "🆕 Stochastic Training of GNNs with GraphBolt\n=============================================\n\nGraphBolt is a data loading framework for GNN with high flexibility and\nscalability. It is built on top of DGL and PyTorch.\n\nThis tutorial introduces how to enable stochastic training of GNNs with\nGraphBolt.\n\nOverview\n^^^^^^^\n\n.. image:: ../_static/graphbolt_overview.jpg\n  :width: 700\n  :alt: Graphbolt Overview\n\nGraphBolt integrates seamlessly with the PyTorch `datapipe <https://pytorch.org/data/beta/torchdata.datapipes.iter.html>`_, relying on the unified \"MiniBatch\" data structure to connect processing stages. It streamlines data loading and preprocessing for GNN training, validation, and testing.\nBy default, GraphBolt provides a collection of built-in datasets and exceptionally efficient implementations of datapipes for common scenarios, which can be summarized as follows:\n\n1. **Item Sampler:** Randomly selects a subset (nodes, edges, graphs) from the entire training set as an initial mini-batch for downstream computation.\n\n2. **Negative Sampler:** Specially designed for link prediction tasks, it generates non-existing edges as negative examples for training.\n\n3. **Subgraph Sampler:** Generates subgraphs based on the input nodes/edges for computation.\n\n4. **Feature Fetcher:** Fetches related node/edge features from the dataset for the given input.\n\nBy exposing the entire data loading process as a pipeline, GraphBolt provides significant flexibility and customization opportunities. Users can easily substitute any stage with their own implementations. Additionally, users can benefit from the optimized scheduling strategy for datapipes, even with customized stages.\n\nIn summary, GraphBolt offers the following benefits:\n\n1. A flexible, pipelined framework for GNN data loading and preprocessing.\n\n2. Highly efficient canonical implementations.\n\n3. Efficient scheduling.\n\nScenarios\n^^^^^^^\n\n.. 
toctree::\n  :maxdepth: 1\n\n  neighbor_sampling_overview.nblink\n  node_classification.nblink\n  link_prediction.nblink\n  multigpu_node_classification.nblink\n  ondisk-dataset.rst\n"
  },
  {
    "path": "docs/source/stochastic_training/link_prediction.nblink",
    "content": "{\n    \"path\": \"../../../notebooks/stochastic_training/link_prediction.ipynb\"\n}\n"
  },
  {
    "path": "docs/source/stochastic_training/multigpu_node_classification.nblink",
    "content": "{\n    \"path\": \"../../../notebooks/stochastic_training/multigpu_node_classification.ipynb\"\n}\n"
  },
  {
    "path": "docs/source/stochastic_training/neighbor_sampling_overview.nblink",
    "content": "{\n    \"path\": \"../../../notebooks/stochastic_training/neighbor_sampling_overview.ipynb\"\n}\n"
  },
  {
    "path": "docs/source/stochastic_training/node_classification.nblink",
    "content": "{\n    \"path\": \"../../../notebooks/stochastic_training/node_classification.ipynb\"\n}\n"
  },
  {
    "path": "docs/source/stochastic_training/ondisk-dataset-specification.rst",
    "content": ".. _stochastic_training-ondisk-dataset-specification:\n\nYAML specification\n==================\n\nThis document describes the YAML specification of ``metadata.yaml`` file for\n``OnDiskDataset``. ``metadata.yaml`` file is used to specify the dataset\ninformation, including the graph structure, feature data and tasks.\n\n.. code:: yaml\n\n    dataset_name: <string>\n    graph:\n      nodes:\n        - type: <string>\n          num: <int>\n        - type: <string>\n          num: <int>\n      edges:\n        - type: <string>\n          format: <string>\n          path: <string>\n        - type: <string>\n          format: <string>\n          path: <string>\n    feature_data:\n      - domain: node\n        type: <string>\n        name: <string>\n        format: <string>\n        in_memory: <bool>\n        path: <string>\n      - domain: node\n        type: <string>\n        name: <string>\n        format: <string>\n        in_memory: <bool>\n        path: <string>\n      - domain: edge\n        type: <string>\n        name: <string>\n        format: <string>\n        in_memory: <bool>\n        path: <string>\n      - domain: edge\n        type: <string>\n        name: <string>\n        format: <string>\n        in_memory: <bool>\n        path: <string>\n    tasks:\n      - name: <string>\n        num_classes: <int>\n        train_set:\n          - type: <string>\n            data:\n              - name: <string>\n                format: <string>\n                in_memory: <bool>\n                path: <string>\n              - name: <string>\n                format: <string>\n                in_memory: <bool>\n                path: <string>\n        validation_set:\n          - type: <string>\n            data:\n              - name: <string>\n                format: <string>\n                in_memory: <bool>\n                path: <string>\n              - name: <string>\n                format: <string>\n                in_memory: <bool>\n          
      path: <string>\n        test_set:\n          - type: <string>\n            data:\n              - name: <string>\n                format: <string>\n                in_memory: <bool>\n                path: <string>\n              - name: <string>\n                format: <string>\n                in_memory: <bool>\n                path: <string>\n\n``dataset_name``\n---------------\n\nThe ``dataset_name`` field is used to specify the name of the dataset. It is\nuser-defined.\n\n``graph``\n---------\n\nThe ``graph`` field is used to specify the graph structure. It has two fields:\n``nodes`` and ``edges``.\n\n - ``nodes``: ``list``\n\n   The ``nodes`` field is used to specify the number of nodes for each node type.\n   It is a list of ``node`` objects. Each ``node`` object has two fields: ``type``\n   and ``num``.\n    - ``type``: ``string``, optional\n\n      The ``type`` field is used to specify the node type. It is ``null`` for\n      homogeneous graphs. For heterogeneous graphs, it is the node type.\n    - ``num``: ``int``\n\n      The ``num`` field is used to specify the number of nodes for the node type.\n      It is mandatory for both homogeneous graphs and heterogeneous graphs.\n\n  - ``edges``: ``list``\n\n    The ``edges`` field is used to specify the edges. It is a list of ``edge``\n    objects. Each ``edge`` object has three fields: ``type``, ``format`` and\n    ``path``.\n    - ``type``: ``string``, optional\n\n      The ``type`` field is used to specify the edge type. It is ``null`` for\n      homogeneous graphs. For heterogeneous graphs, it is the edge type.\n    - ``format``: ``string``\n\n      The ``format`` field is used to specify the format of the edge data. It\n      can be ``csv`` or ``numpy``. If it is ``csv``, no ``index`` and ``header``\n      fields are needed. If it is ``numpy``, the array requires to be in shape\n      of ``(2, num_edges)``. 
``numpy`` format is recommended for large graphs.\n    - ``path``: ``string``\n\n      The ``path`` field is used to specify the path of the edge data. It is\n      relative to the directory of ``metadata.yaml`` file.\n\n\n``feature_data``\n----------------\n\nThe ``feature_data`` field is used to specify the feature data. It is a list of\n``feature`` objects. Each ``feature`` object has five canonical fields: ``domain``,\n``type``, ``name``, ``format`` and ``path``. Any other fields will be passed to\nthe ``Feature.metadata`` object.\n\n - ``domain``: ``string``\n\n   The ``domain`` field is used to specify the domain of the feature data. It can\n   be either ``node`` or ``edge``.\n - ``type``: ``string``, optional\n\n   The ``type`` field is used to specify the type of the feature data. It is\n   ``null`` for homogeneous graphs. For heterogeneous graphs, it is the node or\n   edge type.\n  - ``name``: ``string``\n\n    The ``name`` field is used to specify the name of the feature data. It is\n    user-defined.\n  - ``format``: ``string``\n\n    The ``format`` field is used to specify the format of the feature data. It can\n    be either ``numpy`` or ``torch``.\n  - ``in_memory``: ``bool``, optional\n\n    The ``in_memory`` field is used to specify whether the feature data is loaded\n    into memory. It can be either ``true`` or ``false``. Default is ``true``.\n  - ``path``: ``string``\n\n    The ``path`` field is used to specify the path of the feature data. It is\n    relative to the directory of ``metadata.yaml`` file.\n\n\n``tasks``\n---------\n\nThe ``tasks`` field is used to specify the tasks. It is a list of ``task``\nobjects. Each ``task`` object has at least three fields: ``train_set``,\n``validation_set``, ``test_set``. And you are free to add other fields\nsuch as ``num_classes`` and all these fields will be passed to the\n``Task.metadata`` object.\n\n - ``name``: ``string``, optional\n\n   The ``name`` field is used to specify the name of the task. 
It is user-defined.\n - ``num_classes``: ``int``, optional\n\n    The ``num_classes`` field is used to specify the number of classes of the task.\n - ``train_set``: ``list``\n\n    The ``train_set`` field is used to specify the training set. It is a list of\n    ``set`` objects. Each ``set`` object has two fields: ``type`` and ``data``.\n  - ``type``: ``string``, optional\n\n      The ``type`` field is used to specify the node/edge type of the set. It is\n      ``null`` for homogeneous graphs. For heterogeneous graphs, it is the node\n      or edge type.\n  - ``data``: ``list``\n\n      The ``data`` field is used to load ``train_set``. It is a list of ``data``\n      objects. Each ``data`` object has four fields: ``name``, ``format``,\n      ``in_memory`` and ``path``.\n\n    - ``name``: ``string``\n\n        The ``name`` field is used to specify the name of the data. It is mandatory\n        and used to specify the data fields of ``MiniBatch`` for sampling. It can\n        be either ``seeds``, ``labels`` or ``indexes``. If any other name is used,\n        it will be added into the ``MiniBatch`` data fields.\n    - ``format``: ``string``\n\n        The ``format`` field is used to specify the format of the data. It can be\n        either ``numpy`` or ``torch``.\n    - ``in_memory``: ``bool``, optional\n\n        The ``in_memory`` field is used to specify whether the data is loaded into\n        memory. It can be either ``true`` or ``false``. Default is ``true``.\n    - ``path``: ``string``\n\n        The ``path`` field is used to specify the path of the data. It is relative\n        to the directory of ``metadata.yaml`` file.\n - ``validation_set``: ``list``\n - ``test_set``: ``list``\n\n    The ``validation_set`` and ``test_set`` fields are used to specify the\n    validation set and test set respectively. They are similar to the\n    ``train_set`` field.\n\n"
  },
  {
    "path": "docs/source/stochastic_training/ondisk-dataset.rst",
    "content": ".. _stochastic_training-ondisk-dataset:\n\nComposing OnDiskDataset from raw data\n=====================================\n\nThis tutorial shows how to compose :class:`~dgl.graphbolt.OnDiskDataset` from\nraw data. A full specification of ``metadata.yaml`` is also provided.\n\n**GraphBolt** provides the ``OnDiskDataset`` class to help user organize plain\ndata of graph strucutre, feature data and tasks. ``OnDiskDataset`` is also\ndesigned to efficiently handle large graphs and features that do not fit into\nmemory by storing them on disk.\n\n.. toctree::\n    :maxdepth: 1\n    :glob:\n\n    ondisk_dataset_homograph.nblink\n    ondisk_dataset_heterograph.nblink\n    ondisk-dataset-specification.rst\n"
  },
  {
    "path": "docs/source/stochastic_training/ondisk_dataset_heterograph.nblink",
    "content": "{\n    \"path\": \"../../../notebooks/stochastic_training/ondisk_dataset_heterograph.ipynb\"\n}\n"
  },
  {
    "path": "docs/source/stochastic_training/ondisk_dataset_homograph.nblink",
    "content": "{\n    \"path\": \"../../../notebooks/stochastic_training/ondisk_dataset_homograph.ipynb\"\n}\n"
  },
  {
    "path": "examples/README.md",
    "content": "# Official DGL Examples and Modules\n\nThe folder contains example implementations of selected research papers related to Graph Neural Networks. Note that the examples may not work with incompatible DGL versions.\n* For examples working with the latest master (or the latest [nightly build](https://www.dgl.ai/pages/start.html)), check out https://github.com/dmlc/dgl/tree/master/examples.\n* For examples working with a certain release, check out `https://github.com/dmlc/dgl/tree/<release_version>/examples` (E.g., https://github.com/dmlc/dgl/tree/0.5.x/examples)\n\nTo quickly locate the examples of your interest, search for the tagged keywords or use the search tool on [dgl.ai](https://www.dgl.ai/).\n\n## 2024\n\n- <a name=\"labor\"></a> Lin et al. ARGO: An Auto-Tuning Runtime System for Scalable GNN Training on Multi-Core Processor. [Paper link](https://arxiv.org/abs/2402.03671)\n  - Example code: [PyTorch](https://github.com/dmlc/dgl/tree/master/examples/pytorch/argo)\n\n  - Tags: semi-supervised node classification\n\n## 2023\n\n- <a name=\"labor\"></a> Zheng Wang et al. From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited. [Paper link](https://arxiv.org/abs/2210.13339)\n  - Example code: [PyTorch](../examples/pytorch/ogc)\n\n  - Tags: semi-supervised node classification\n\n## 2022\n- <a name=\"labor\"></a> Balin et al. Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs. [Paper link](https://arxiv.org/abs/2210.13339)\n    - Example code: [PyTorch](../examples/labor/train_lightning.py)\n    - Tags: node classification, weighted graphs, sampling\n## 2021\n- <a name=\"rnaglib\"></a> Mallet et al. Learning Protein and Small Molecule binding sites in RNA molecules with 2.5D graphs. 
[Paper link](https://academic.oup.com/bioinformatics/article/38/5/1458/6462185?login=true)\n    - Example code: [PyTorch](https://jwgitlab.cs.mcgill.ca/cgoliver/rnaglib)\n    - Tags: semi-supervised node classification\n- <a name=\"hilander\"></a> Xing et al. Learning Hierarchical Graph Neural Networks for Image Clustering.\n    - Example code: [PyTorch](../examples/pytorch/hilander)\n    - Tags: clustering\n- <a name=\"bgnn\"></a> Ivanov et al. Boost then Convolve: Gradient Boosting Meets Graph Neural Networks. [Paper link](https://openreview.net/forum?id=ebS5NUfoMKL). \n    - Example code: [PyTorch](../examples/pytorch/bgnn)\n    - Tags: semi-supervised node classification, tabular data, GBDT\n- <a name=\"correct_and_smooth\"></a> Huang et al. Combining Label Propagation and Simple Models Out-performs Graph Neural Networks. [Paper link](https://arxiv.org/abs/2010.13993). \n    - Example code: [PyTorch](../examples/pytorch/correct_and_smooth)\n    - Tags: efficiency, node classification, label propagation\n- <a name=\"point_transformer\"></a> Zhao et al. Point Transformer. [Paper link](http://arxiv.org/abs/2012.09164).\n    - Example code: [PyTorch](../examples/pytorch/pointcloud/point_transformer)\n    - Tags: point cloud classification, point cloud part-segmentation\n- <a name=\"pct\"></a> Guo et al. PCT: Point cloud transformer. [Paper link](http://arxiv.org/abs/2012.09688).\n    - Example code: [PyTorch](../examples/pytorch/pointcloud/pct)\n    - Tags: point cloud classification, point cloud part-segmentation\n- <a name='gatv2'></a> Brody et al. How Attentive are Graph Attention Networks? [Paper link](https://arxiv.org/abs/2105.14491).\n    - Example code: [PyTorch](../examples/pytorch/gatv2)\n    - Tags: graph attention, gat, gatv2, attention\n- <a name='bgrl'></a> Thakoor et al. Large-Scale Representation Learning on Graphs via Bootstrapping. 
[Paper link](https://arxiv.org/abs/2102.06514).\n    - Example code: [PyTorch](../examples/pytorch/bgrl)\n    - Tags: contrastive learning for node classification.\n- <a name='directional_gsn'></a> Bouritsas et al. Improving Graph Neural Network Expressivity via Subgraph Isomorphism Counting. [Paper link](https://arxiv.org/abs/2006.09252).\n    - Example code: [PyTorch](../examples/pytorch/ogb/directional_GSN)\n    - Tags: subgraph isomorphism counting, graph classification.\n- <a name='ngnn'></a> Song et al. Network In Graph Neural Network. [Paper link](https://arxiv.org/abs/2111.11638).\n    - Example code: [PyTorch](../examples/pytorch/ogb/ngnn)\n    - Tags: model-agnostic methodology, link prediction, open graph benchmark.\n- <a name='bipointnet'></a>Qin et al. BiPointNet: Binary Neural Network for Point Clouds. [Paper link](https://openreview.net/forum?id=9QLRCVysdlO)\n    - Example code: [PyTorch](../examples/pytorch/pointcloud/bipointnet)\n    - Tags: point cloud classification, network binarization.\n\n\n## 2020\n- <a name=\"eeg-gcnn\"></a> Wagh et al. EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. [Paper link](http://proceedings.mlr.press/v136/wagh20a.html). \n    - Example code: [PyTorch](../examples/pytorch/eeg-gcnn)\n    - Tags: graph classification, eeg representation learning, brain activity, graph convolution,  neurological disease classification, large dataset, edge weights, node features, fully-connected graph, graph neural network\n- <a name=\"rect\"></a> Wang et al. Network Embedding with Completely-imbalanced Labels. [Paper link](https://ieeexplore.ieee.org/document/8979355). \n    - Example code: [PyTorch](../examples/pytorch/rect)\n    - Tags: node classification, network embedding, completely-imbalanced labels\n- <a name=\"mvgrl\"></a> Hassani and Khasahmadi. Contrastive Multi-View Representation Learning on Graphs. 
[Paper link](https://arxiv.org/abs/2006.05582). \n    - Example code: [PyTorch](../examples/pytorch/mvgrl)\n    - Tags: graph diffusion, self-supervised learning\n- <a name=\"grace\"></a> Zhu et al. Deep Graph Contrastive Representation Learning. [Paper link](https://arxiv.org/abs/2006.04131). \n    - Example code: [PyTorch](../examples/pytorch/grace)\n    - Tags: contrastive learning for node classification.\n- <a name=\"grand\"></a> Feng et al. Graph Random Neural Network for Semi-Supervised Learning on Graphs. [Paper link](https://arxiv.org/abs/2005.11079). \n    - Example code: [PyTorch](../examples/pytorch/grand)\n    - Tags: semi-supervised node classification, simplifying graph convolution, data augmentation\n- <a name=\"hgt\"></a> Hu et al. Heterogeneous Graph Transformer. [Paper link](https://arxiv.org/abs/2003.01332).\n    - Example code: [PyTorch](../examples/pytorch/hgt)\n    - Tags: dynamic heterogeneous graph, large-scale, node classification, link prediction\n- <a name=\"mwe\"></a> Chen. Graph Convolutional Networks for Graphs with Multi-Dimensionally Weighted Edges. [Paper link](https://cims.nyu.edu/~chenzh/files/GCN_with_edge_weights.pdf).\n    - Example code: [PyTorch on ogbn-proteins](../examples/pytorch/ogb/ogbn-proteins)\n    - Tags: node classification, weighted graphs, OGB\n- <a name=\"sign\"></a> Frasca et al. SIGN: Scalable Inception Graph Neural Networks. [Paper link](https://arxiv.org/abs/2004.11198).\n    - Example code: [PyTorch on ogbn-arxiv/products/mag](../examples/pytorch/ogb/sign), [PyTorch](../examples/pytorch/sign)\n    - Tags: node classification, OGB, large-scale, heterogeneous graph\n- <a name=\"prestrategy\"></a> Hu et al. Strategies for Pre-training Graph Neural Networks. 
[Paper link](https://arxiv.org/abs/1905.12265).\n    - Example code: [Molecule embedding](https://github.com/awslabs/dgl-lifesci/tree/master/examples/molecule_embeddings), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)\n    - Tags: molecules, graph classification, unsupervised learning, self-supervised learning, molecular property prediction\n- <a name=\"gnnfilm\"></a> Marc Brockschmidt. GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation. [Paper link](https://arxiv.org/abs/1906.12192).\n    - Example code: [PyTorch](../examples/pytorch/GNN-FiLM)\n    - Tags: multi-relational graphs, hypernetworks, GNN architectures\n- <a name=\"gxn\"></a> Li, Maosen, et al. Graph Cross Networks with Vertex Infomax Pooling. [Paper link](https://arxiv.org/abs/2010.01804).\n    - Example code: [PyTorch](../examples/pytorch/gxn)\n    - Tags: pooling, graph classification\n- <a name=\"dagnn\"></a> Liu et al. Towards Deeper Graph Neural Networks. [Paper link](https://arxiv.org/abs/2007.09296).\n    - Example code: [PyTorch](../examples/pytorch/dagnn)\n    - Tags: over-smoothing, node classification\n- <a name=\"dimenet\"></a> Klicpera et al. Directional Message Passing for Molecular Graphs. [Paper link](https://arxiv.org/abs/2003.03123).\n    - Example code: [PyTorch](../examples/pytorch/dimenet)\n    - Tags: molecules, molecular property prediction, quantum chemistry\n- <a name=\"tgn\"></a> Rossi et al. Temporal Graph Networks For Deep Learning on Dynamic Graphs. [Paper link](https://arxiv.org/abs/2006.10637).\n    - Example code: [Pytorch](../examples/pytorch/tgn)\n    - Tags: temporal, node classification \n- <a name=\"compgcn\"></a> Vashishth, Shikhar, et al. Composition-based Multi-Relational Graph Convolutional Networks. 
[Paper link](https://arxiv.org/abs/1911.03082).\n    - Example code: [PyTorch](../examples/pytorch/compGCN)\n    - Tags: multi-relational graphs, graph neural network\n- <a name=\"deepergcn\"></a> Li et al. DeeperGCN: All You Need to Train Deeper GCNs. [Paper link](https://arxiv.org/abs/2006.07739).\n    - Example code: [PyTorch](../examples/pytorch/deepergcn)\n    - Tags: over-smoothing, deeper gnn, OGB\n\n- <a name=\"tahin\"></a> Bi, Ye, et al. A Heterogeneous Information Network based Cross Domain Insurance Recommendation System for Cold Start Users. [Paper link](https://arxiv.org/abs/2007.15293).\n    - Example code: [PyTorch](../examples/pytorch/TAHIN)\n    - Tags: cross-domain recommendation, graph neural network\n- <a name=\"magnn\"></a> Fu X, Zhang J, Meng Z, et al. MAGNN: metapath aggregated graph neural network for heterogeneous graph embedding. [Paper link](https://dl.acm.org/doi/abs/10.1145/3366423.3380297).\n    - Example code: [OpenHGNN](https://github.com/BUPT-GAMMA/OpenHGNN/tree/main/openhgnn/output/MAGNN)\n    - Tags: Heterogeneous graph, Graph neural network, Graph embedding\n- <a name=\"nshe\"></a> Zhao J, Wang X, et al. Network Schema Preserving Heterogeneous Information Network Embedding. [Paper link](https://www.ijcai.org/Proceedings/2020/0190.pdf).\n    - Example code: [OpenHGNN](https://github.com/BUPT-GAMMA/OpenHGNN/tree/main/openhgnn/output/NSHE)\n    - Tags: Heterogeneous graph, Graph neural network, Graph embedding, Network Schema\n- <a name=\"caregnn\"></a> Dou Y, Liu Z, et al. Enhancing Graph Neural Network-based Fraud Detectors against Camouflaged Fraudsters. [Paper link](https://arxiv.org/abs/2008.08692).\n    - Example code: [PyTorch](../examples/pytorch/caregnn)\n    - Tags: Multi-relational graph, Graph neural network, Fraud detection, Reinforcement learning, Node classification\n- <a name=\"seal_ogbl\"></a>  Zhang et al. Labeling Trick: A Theory of Using Graph Neural Networks for Multi-Node Representation Learning. 
[Paper link](https://arxiv.org/pdf/2010.16103.pdf).\n    - Example code: [PyTorch](../examples/pytorch/ogb/seal_ogbl)\n    - Tags: link prediction, labeling trick, OGB\n\n## 2019\n\n- <a name=\"infograph\"></a> Sun et al. InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization. [Paper link](https://arxiv.org/abs/1908.01000). \n    - Example code: [PyTorch](../examples/pytorch/infograph)\n    - Tags: semi-supervised graph regression, unsupervised graph classification\n- <a name=\"arma\"></a>  Bianchi et al. Graph Neural Networks with Convolutional ARMA Filters. [Paper link](https://arxiv.org/abs/1901.01343).\n    - Example code: [PyTorch](../examples/pytorch/arma)\n    - Tags: node classification\n- <a name=\"appnp\"></a> Klicpera et al. Predict then Propagate: Graph Neural Networks meet Personalized PageRank. [Paper link](https://arxiv.org/abs/1810.05997).\n    - Example code: [PyTorch](../examples/pytorch/appnp), [MXNet](../examples/mxnet/appnp)\n    - Tags: node classification\n- <a name=\"clustergcn\"></a> Chiang et al. Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1905.07953).\n    - Example code: [PyTorch](../examples/pytorch/cluster_gcn), [PyTorch-based GraphSAGE variant on OGB](../examples/pytorch/ogb/cluster-sage), [PyTorch-based GAT variant on OGB](../examples/pytorch/ogb/cluster-gat)\n    - Tags: graph partition, node classification, large-scale, OGB, sampling\n- <a name=\"dgi\"></a> Veličković et al. Deep Graph Infomax. [Paper link](https://arxiv.org/abs/1809.10341).\n    - Example code: [PyTorch](../examples/pytorch/dgi), [TensorFlow](../examples/tensorflow/dgi)\n    - Tags: unsupervised learning, node classification\n- <a name=\"diffpool\"></a> Ying et al. Hierarchical Graph Representation Learning with Differentiable Pooling. 
[Paper link](https://arxiv.org/abs/1806.08804).\n    - Example code: [PyTorch](../examples/pytorch/diffpool)\n    - Tags: pooling, graph classification, graph coarsening\n- <a name=\"gatne-t\"></a> Cen et al. Representation Learning for Attributed Multiplex Heterogeneous Network. [Paper link](https://arxiv.org/abs/1905.01669v2).\n    - Example code: [PyTorch](../examples/pytorch/GATNE-T)\n    - Tags: heterogeneous graph, link prediction, large-scale\n- <a name=\"gin\"></a> Xu et al. How Powerful are Graph Neural Networks? [Paper link](https://arxiv.org/abs/1810.00826).\n    - Example code: [PyTorch on graph classification](../examples/pytorch/gin), [PyTorch on node classification](../examples/pytorch/model_zoo/citation_network), [PyTorch on ogbg-ppa](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/ogbg_ppa), [MXNet](../examples/mxnet/gin)\n    - Tags: graph classification, node classification, OGB\n- <a name=\"graphwriter\"></a> Koncel-Kedziorski et al. Text Generation from Knowledge Graphs with Graph Transformers. [Paper link](https://arxiv.org/abs/1904.02342).\n    - Example code: [PyTorch](../examples/pytorch/graphwriter)\n    - Tags: knowledge graph, text generation\n- <a name=\"han\"></a> Wang et al. Heterogeneous Graph Attention Network. [Paper link](https://arxiv.org/abs/1903.07293).\n    - Example code: [PyTorch](../examples/pytorch/han), [OpenHGNN](https://github.com/BUPT-GAMMA/OpenHGNN/tree/main/openhgnn/output/HAN)\n    - Tags: heterogeneous graph, node classification\n- <a name=\"lgnn\"></a> Chen et al. Supervised Community Detection with Line Graph Neural Networks. [Paper link](https://arxiv.org/abs/1705.08415).\n    - Example code: [PyTorch](../examples/pytorch/line_graph)\n    - Tags: line graph, community detection\n- <a name=\"sgc\"></a> Wu et al. Simplifying Graph Convolutional Networks. 
[Paper link](https://arxiv.org/abs/1902.07153).\n    - Example code: [PyTorch](../examples/pytorch/sgc), [MXNet](../examples/mxnet/sgc)\n    - Tags: node classification\n- <a name=\"dgcnnpoint\"></a> Wang et al. Dynamic Graph CNN for Learning on Point Clouds. [Paper link](https://arxiv.org/abs/1801.07829).\n    - Example code: [PyTorch](../examples/pytorch/pointcloud/edgeconv)\n    - Tags: point cloud classification\n- <a name=\"scenegraph\"></a> Zhang et al. Graphical Contrastive Losses for Scene Graph Parsing. [Paper link](https://arxiv.org/abs/1903.02728).\n    - Example code: [MXNet](../examples/mxnet/scenegraph)\n    - Tags: scene graph extraction\n- <a name=\"settrans\"></a> Lee et al. Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks. [Paper link](https://arxiv.org/abs/1810.00825).\n    - Pooling module: [PyTorch encoder](https://docs.dgl.ai/api/python/nn.pytorch.html#settransformerencoder), [PyTorch decoder](https://docs.dgl.ai/api/python/nn.pytorch.html#settransformerdecoder)\n    - Tags: graph classification\n- <a name=\"wln\"></a> Coley et al. A graph-convolutional neural network model for the prediction of chemical reactivity. [Paper link](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/reaction_prediction/rexgen_direct)\n    - Tags: molecules, reaction prediction\n- <a name=\"mgcn\"></a> Lu et al. Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective. [Paper link](https://arxiv.org/abs/1906.11081).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/alchemy)\n    - Tags: molecules, quantum chemistry\n- <a name=\"attentivefp\"></a> Xiong et al. Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism. 
[Paper link](https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959).\n    - Example code: [PyTorch (with attention visualization)](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/pubchem_aromaticity), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)\n    - Tags: molecules, molecular property prediction\n- <a name=\"rotate\"></a> Sun et al. RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space. [Paper link](https://arxiv.org/pdf/1902.10197.pdf).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-ke/tree/master/examples), [PyTorch for custom data](https://aws-dglke.readthedocs.io/en/latest/commands.html)\n    - Tags: knowledge graph\n- <a name=\"mixhop\"></a> Abu-El-Haija et al. MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing. [Paper link](https://arxiv.org/abs/1905.00067).\n    - Example code: [PyTorch](../examples/pytorch/mixhop)\n    - Tags: node classification\n- <a name=\"sagpool\"></a> Lee, Junhyun, et al. Self-Attention Graph Pooling. [Paper link](https://arxiv.org/abs/1904.08082).\n    - Example code: [PyTorch](../examples/pytorch/sagpool)\n    - Tags: graph classification, pooling\n- <a name=\"hgp-sl\"></a> Zhang, Zhen, et al. Hierarchical Graph Pooling with Structure Learning. [Paper link](https://arxiv.org/abs/1911.05954).\n    - Example code: [PyTorch](../examples/pytorch/hgp_sl)\n    - Tags: graph classification, pooling\n- <a name='hardgat'></a> Gao, Hongyang, et al. Graph Representation Learning via Hard and Channel-Wise Attention Networks [Paper link](https://arxiv.org/abs/1907.04652).\n    - Example code: [PyTorch](../examples/pytorch/hardgat)\n    - Tags: node classification, graph attention\n- <a name='ngcf'></a> Wang, Xiang, et al. Neural Graph Collaborative Filtering. 
[Paper link](https://arxiv.org/abs/1905.08108).\n    - Example code: [PyTorch](../examples/pytorch/NGCF)\n    - Tags: Collaborative Filtering, recommender system, Graph Neural Network \n- <a name='gnnexplainer'></a> Ying, Rex, et al. GNNExplainer: Generating Explanations for Graph Neural Networks. [Paper link](https://arxiv.org/abs/1903.03894).\n    - Example code: [PyTorch](../examples/pytorch/gnn_explainer)\n    - Tags: Graph Neural Network, Explainability\n- <a name='hetgnn'></a> Zhang C, Song D, et al. Heterogeneous graph neural network. [Paper link](https://dl.acm.org/doi/abs/10.1145/3292500.3330961).\n    - Example code: [OpenHGNN](https://github.com/BUPT-GAMMA/OpenHGNN/tree/main/openhgnn/output/HetGNN)\n    - Tags:  Heterogeneous graph, Graph neural network, Graph embedding\n- <a name='gtn'></a> Yun S, Jeong M, et al. Graph transformer networks. [Paper link](https://arxiv.org/abs/1911.06455).\n    - Example code: [OpenHGNN](https://github.com/BUPT-GAMMA/OpenHGNN/tree/main/openhgnn/output/GTN)\n    - Tags:  Heterogeneous graph, Graph neural network, Graph structure\n- <a name='gas'></a> Li A, Qin Z, et al. Spam Review Detection with Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1908.10679).\n    - Example code: [PyTorch](../examples/pytorch/gas)\n    - Tags:  Fraud detection, Heterogeneous graph, Edge classification, Graph attention\n- <a name='geniepath'></a> Liu Z, et al. Geniepath: Graph neural networks with adaptive receptive paths. [Paper link](https://arxiv.org/abs/1802.00910).\n    - Example code: [PyTorch](../examples/pytorch/geniepath)\n    - Tags:  Fraud detection, Node classification, Graph attention, LSTM, Adaptive receptive fields\n- <a name='pgnn'></a> You J, et al. Position-aware graph neural networks. [Paper link](https://arxiv.org/abs/1906.04817).\n    - Example code: [PyTorch](../examples/pytorch/P-GNN)\n    - Tags:  Positional encoding, Link prediction, Link-pair prediction\n\n## 2018\n\n- <a name=\"dgmg\"></a> Li et al. 
Learning Deep Generative Models of Graphs. [Paper link](https://arxiv.org/abs/1803.03324).\n    - Example code: [PyTorch example for cycles](../examples/pytorch/dgmg), [PyTorch example for molecules](https://github.com/awslabs/dgl-lifesci/tree/master/examples/generative_models/dgmg)\n    - Tags: generative models, autoregressive models, molecules\n\n- <a name=\"gat\"></a> Veličković et al. Graph Attention Networks. [Paper link](https://arxiv.org/abs/1710.10903).\n    - Example code: [PyTorch](../examples/pytorch/gat), [PyTorch on ogbn-arxiv](../examples/pytorch/ogb/ogbn-arxiv), [PyTorch on ogbn-products](../examples/pytorch/ogb/ogbn-products), [TensorFlow](../examples/tensorflow/gat), [MXNet](../examples/mxnet/gat)\n    - Tags: node classification, OGB\n\n- <a name=\"jtvae\"></a> Jin et al. Junction Tree Variational Autoencoder for Molecular Graph Generation. [Paper link](https://arxiv.org/abs/1802.04364).\n    - Example code: [PyTorch](../examples/pytorch/jtnn)\n    - Tags: generative models, molecules, VAE\n\n- <a name=\"agnn\"></a> Thekumparampil et al. Attention-based Graph Neural Network for Semi-supervised Learning. [Paper link](https://arxiv.org/abs/1803.03735).\n    - Example code: [PyTorch](../examples/pytorch/model_zoo/citation_network)\n    - Tags: node classification\n    \n- <a name=\"pinsage\"></a> Ying et al. Graph Convolutional Neural Networks for Web-Scale Recommender Systems. [Paper link](https://arxiv.org/abs/1806.01973).\n    - Example code: [PyTorch](../examples/pytorch/pinsage)\n    - Tags: recommender system, large-scale, sampling\n\n- <a name=\"rrn\"></a> Berg Palm et al. Recurrent Relational Networks. [Paper link](https://arxiv.org/abs/1711.08028).\n    - Example code: [PyTorch](../examples/pytorch/rrn)\n    - Tags: sudoku solving\n\n- <a name=\"stgcn\"></a> Yu et al. Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting. 
[Paper link](https://arxiv.org/abs/1709.04875v4).\n    - Example code: [PyTorch](../examples/pytorch/stgcn_wave)\n    - Tags: spatio-temporal, traffic forecasting\n\n- <a name=\"dgcnn\"></a> Zhang et al. An End-to-End Deep Learning Architecture for Graph Classification. [Paper link](https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf).\n    - Pooling module: [PyTorch](https://docs.dgl.ai/api/python/nn.pytorch.html#sortpooling), [TensorFlow](https://docs.dgl.ai/api/python/nn.tensorflow.html#sortpooling), [MXNet](https://docs.dgl.ai/api/python/nn.mxnet.html#sortpooling)\n    - Tags: graph classification\n\n- <a name=\"seal\"></a>  Zhang et al. Link Prediction Based on Graph Neural Networks. [Paper link](https://papers.nips.cc/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf).\n    - Example code: [PyTorch](../examples/pytorch/seal)\n    - Tags: link prediction, sampling\n\n- <a name=\"jknet\"></a>  Xu et al. Representation Learning on Graphs with Jumping Knowledge Networks. [Paper link](https://arxiv.org/abs/1806.03536).\n    - Example code: [PyTorch](../examples/pytorch/jknet)\n    - Tags: message passing, neighborhood\n\n- <a name=\"gaan\"></a> Zhang et al. GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs. [Paper link](https://arxiv.org/abs/1803.07294).\n    - Example code: [pytorch](../examples/pytorch/dtgrnn)\n    - Tags: Static discrete temporal graph, traffic forecasting\n\n- <a name=\"hgnn\"></a> Feng et al. Hypergraph Neural Networks. [Paper link](https://arxiv.org/abs/1809.09401).\n    - Example code: [pytorch](../examples/sparse/hgnn)\n    - Tags: hypergraph\n\n## 2017\n\n- <a name=\"gcn\"></a> Kipf and Welling. Semi-Supervised Classification with Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1609.02907). 
\n    - Example code: [PyTorch](../examples/pytorch/gcn), [PyTorch on ogbn-arxiv](../examples/pytorch/ogb/ogbn-arxiv), [PyTorch on ogbl-ppa](https://github.com/awslabs/dgl-lifesci/tree/master/examples/link_prediction/ogbl-ppa), [PyTorch on ogbg-ppa](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/ogbg_ppa), [TensorFlow](../examples/tensorflow/gcn), [MXNet](../examples/mxnet/gcn)\n    - Tags: node classification, link prediction, graph classification, OGB\n\n- <a name=\"capsule\"></a> Sabour et al. Dynamic Routing Between Capsules. [Paper link](https://arxiv.org/abs/1710.09829).\n    - Example code: [PyTorch](../examples/pytorch/capsule)\n    - Tags: image classification\n  \n- <a name=\"gcmc\"></a> van den Berg et al. Graph Convolutional Matrix Completion. [Paper link](https://arxiv.org/abs/1706.02263).\n    - Example code: [PyTorch](../examples/pytorch/gcmc)\n    - Tags: matrix completion, recommender system, link prediction, bipartite graphs\n\n- <a name=\"graphsage\"></a> Hamilton et al. Inductive Representation Learning on Large Graphs. [Paper link](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf).\n    - Example code: [PyTorch](../examples/pytorch/graphsage), [PyTorch on ogbn-products](../examples/pytorch/ogb/ogbn-products), [PyTorch on ogbn-mag](../examples/pytorch/ogb/ogbn-mag), [PyTorch on ogbl-ppa](https://github.com/awslabs/dgl-lifesci/tree/master/examples/link_prediction/ogbl-ppa), [MXNet](../examples/mxnet/graphsage)\n    - Tags: node classification, sampling, unsupervised learning, link prediction, OGB\n\n- <a name=\"metapath2vec\"></a> Dong et al. metapath2vec: Scalable Representation Learning for Heterogeneous Networks. [Paper link](https://dl.acm.org/doi/10.1145/3097983.3098036).\n    - Example code: [PyTorch](../examples/pytorch/metapath2vec)\n    - Tags: heterogeneous graph, network embedding, large-scale, node classification\n\n- <a name=\"tagcn\"></a> Du et al. 
Topology Adaptive Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1710.10370).\n    - Example code: [PyTorch](../examples/pytorch/tagcn), [MXNet](../examples/mxnet/tagcn)\n    - Tags: node classification\n    \n- <a name=\"pointnet\"></a> Qi et al. PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation. [Paper link](https://arxiv.org/abs/1612.00593).\n    - Example code: [PyTorch](../examples/pytorch/pointcloud/pointnet)\n    - Tags: point cloud classification, point cloud part-segmentation\n\n- <a name=\"pointnet++\"></a> Qi et al. PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space. [Paper link](https://arxiv.org/abs/1706.02413).\n    - Example code: [PyTorch](../examples/pytorch/pointcloud/pointnet)\n    - Tags: point cloud classification\n    \n- <a name=\"rgcn\"></a> Schlichtkrull. Modeling Relational Data with Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1703.06103).\n    - Example code: [PyTorch example using homogeneous DGLGraphs](../examples/pytorch/rgcn), [PyTorch](../examples/pytorch/rgcn-hetero), [TensorFlow](../examples/tensorflow/rgcn), [MXNet](../examples/mxnet/rgcn)\n    - Tags: node classification, link prediction, heterogeneous graph, sampling\n\n- <a name=\"transformer\"></a> Vaswani et al. Attention Is All You Need. [Paper link](https://arxiv.org/abs/1706.03762).\n    - Example code: [PyTorch](../examples/pytorch/transformer)\n    - Tags: machine translation\n\n- <a name=\"mpnn\"></a> Gilmer et al. Neural Message Passing for Quantum Chemistry. [Paper link](https://arxiv.org/abs/1704.01212).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/alchemy), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)\n    - Tags: molecules, quantum chemistry\n\n- <a name=\"acnn\"></a> Gomes et al. 
Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity. [Paper link](https://arxiv.org/abs/1703.10603).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/binding_affinity_prediction)\n    - Tags: binding affinity prediction, molecules, proteins\n\n- <a name=\"schnet\"></a> Schütt et al. SchNet: A continuous-filter convolutional neural network for modeling quantum interactions. [Paper link](https://arxiv.org/abs/1706.08566).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/alchemy)\n    - Tags: molecules, quantum chemistry\n\n- <a name=\"dcrnn\"></a> Li et al. Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting. [Paper link](https://arxiv.org/abs/1707.01926).\n    - Example code: [PyTorch](../examples/pytorch/dtgrnn)\n    - Tags: Static discrete temporal graph, traffic forecasting\n\n## 2016\n\n- <a name=\"ggnn\"></a> Li et al. Gated Graph Sequence Neural Networks. [Paper link](https://arxiv.org/abs/1511.05493).\n    - Example code: [PyTorch](../examples/pytorch/ggnn)\n    - Tags: question answering\n- <a name=\"chebnet\"></a> Defferrard et al. Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering. [Paper link](https://arxiv.org/abs/1606.09375).\n    - Example code: [PyTorch on image classification](../examples/pytorch/model_zoo/geometric), [PyTorch on node classification](../examples/pytorch/model_zoo/citation_network)\n    - Tags: image classification, graph classification, node classification\n- <a name=\"monet\"></a> Monti et al. Geometric deep learning on graphs and manifolds using mixture model CNNs. 
[Paper link](https://arxiv.org/abs/1611.08402).\n    - Example code: [PyTorch on image classification](../examples/pytorch/model_zoo/geometric), [PyTorch on node classification](../examples/pytorch/monet), [MXNet on node classification](../examples/mxnet/monet)\n    - Tags: image classification, graph classification, node classification\n- <a name=\"weave\"></a> Kearnes et al. Molecular Graph Convolutions: Moving Beyond Fingerprints. [Paper link](https://arxiv.org/abs/1603.00856).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/moleculenet), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)\n    - Tags: molecular property prediction\n- <a name=\"complex\"></a> Trouillon et al. Complex Embeddings for Simple Link Prediction. [Paper link](http://proceedings.mlr.press/v48/trouillon16.pdf).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-ke/tree/master/examples), [PyTorch for custom data](https://aws-dglke.readthedocs.io/en/latest/commands.html)\n    - Tags: knowledge graph\n- <a name=\"vgae\"></a> Kipf and Welling. Variational Graph Auto-Encoders. [Paper link](https://arxiv.org/abs/1611.07308).\n    - Example code: [PyTorch](../examples/pytorch/vgae)\n    - Tags: link prediction\n\n## 2015\n\n- <a name=\"line\"></a> Tang et al. LINE: Large-scale Information Network Embedding. [Paper link](https://arxiv.org/abs/1503.03578).\n    - Example code: [PyTorch on OGB](../examples/pytorch/ogb/line)\n    - Tags: network embedding, transductive learning, OGB, link prediction\n\n- <a name=\"treelstm\"></a> Sheng Tai et al. Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks. 
[Paper link](https://arxiv.org/abs/1503.00075).\n    - Example code: [PyTorch](../examples/pytorch/tree_lstm), [MXNet](../examples/mxnet/tree_lstm)\n    - Tags: sentiment classification\n    \n- <a name=\"seq2seq\"></a> Vinyals et al. Order Matters: Sequence to sequence for sets. [Paper link](https://arxiv.org/abs/1511.06391).\n    - Pooling module: [PyTorch](https://docs.dgl.ai/api/python/nn.pytorch.html#set2set), [MXNet](https://docs.dgl.ai/api/python/nn.mxnet.html#set2set)\n    - Tags: graph classification\n    \n- <a name=\"transr\"></a> Lin et al. Learning Entity and Relation Embeddings for Knowledge Graph Completion. [Paper link](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/viewPaper/9571).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-ke/tree/master/examples), [PyTorch for custom data](https://aws-dglke.readthedocs.io/en/latest/commands.html)\n    - Tags: knowledge graph\n\n- <a name=\"distmul\"></a> Yang et al. Embedding Entities and Relations for Learning and Inference in Knowledge Bases. [Paper link](https://arxiv.org/abs/1412.6575).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-ke/tree/master/examples), [PyTorch for custom data](https://aws-dglke.readthedocs.io/en/latest/commands.html)\n    - Tags: knowledge graph\n\n- <a name=\"nf\"></a> Duvenaud et al. Convolutional Networks on Graphs for Learning Molecular Fingerprints. [Paper link](https://arxiv.org/abs/1509.09292).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/moleculenet), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)\n    - Tags: molecules, molecular property prediction\n\n## 2014\n\n- <a name=\"deepwalk\"></a> Perozzi et al. DeepWalk: Online Learning of Social Representations. 
[Paper link](https://arxiv.org/abs/1403.6652).\n    - Example code: [PyTorch on OGB](../examples/pytorch/ogb/deepwalk)\n    - Tags: network embedding, transductive learning, OGB, link prediction\n\n- <a name=\"hausdorff\"></a> Fischer et al. A Hausdorff Heuristic for Efficient Computation of Graph Edit Distance. [Paper link](https://link.springer.com/chapter/10.1007/978-3-662-44415-3_9).\n    - Example code: [PyTorch](../examples/pytorch/graph_matching)\n    - Tags: graph edit distance, graph matching\n\n## 2013\n\n- <a name=\"transe\"></a> Bordes et al. Translating Embeddings for Modeling Multi-relational Data. [Paper link](https://proceedings.neurips.cc/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-ke/tree/master/examples), [PyTorch for custom data](https://aws-dglke.readthedocs.io/en/latest/commands.html)\n    - Tags: knowledge graph\n\n## 2011\n\n- <a name=\"bipartite\"></a> Fankhauser et al. Speeding Up Graph Edit Distance Computation through Fast Bipartite Matching. [Paper link](https://link.springer.com/chapter/10.1007/978-3-642-20844-7_11).\n    - Example code: [PyTorch](../examples/pytorch/graph_matching)\n    - Tags: graph edit distance, graph matching\n\n- <a name=\"rescal\"></a> Nickel et al. A Three-Way Model for Collective Learning on Multi-Relational Data. [Paper link](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.383.2015&rep=rep1&type=pdf).\n    - Example code: [PyTorch](https://github.com/awslabs/dgl-ke/tree/master/examples), [PyTorch for custom data](https://aws-dglke.readthedocs.io/en/latest/commands.html)\n    - Tags: knowledge graph\n\n## 2010\n\n- <a name=\"lda\"></a> Hoffman et al. Online Learning for Latent Dirichlet Allocation. 
[Paper link](https://papers.nips.cc/paper/2010/file/71f6278d140af599e06ad9bf1ba03cb0-Paper.pdf).\n    - Example code: [PyTorch](../examples/pytorch/lda)\n    - Tags: sklearn, decomposition, latent Dirichlet allocation\n\n## 2009\n\n- <a name=\"astar\"></a> Riesen et al. Speeding Up Graph Edit Distance Computation with a Bipartite Heuristic. [Paper link](https://core.ac.uk/download/pdf/33054885.pdf).\n    - Example code: [PyTorch](../examples/pytorch/graph_matching)\n    - Tags: graph edit distance, graph matching\n\n## 2006\n\n- <a name=\"beam\"></a> Neuhaus et al. Fast Suboptimal Algorithms for the Computation of Graph Edit Distance. [Paper link](https://link.springer.com/chapter/10.1007/11815921_17).\n    - Example code: [PyTorch](../examples/pytorch/graph_matching)\n    - Tags: graph edit distance, graph matching\n\n## 2002\n\n- <a name=\"label_propagation\"></a> Zhu & Ghahramani. Learning from Labeled and Unlabeled Data with Label Propagation. [Paper link](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.3864&rep=rep1&type=pdf).\n    - Example code: [PyTorch](../examples/pytorch/label_propagation)\n    - Tags: node classification, label propagation\n\n## 1998\n\n- <a name=\"pagerank\"></a> Page et al. The PageRank Citation Ranking: Bringing Order to the Web. [Paper link](http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.38.5427).\n    - Example code: [PyTorch](../examples/pytorch/pagerank.py)\n    - Tags: PageRank\n"
  },
  {
    "path": "examples/advanced/cugraph/graphsage.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nimport tqdm\nfrom dgl.data import AsNodePredDataset\nfrom dgl.dataloading import (\n    DataLoader,\n    MultiLayerFullNeighborSampler,\n    NeighborSampler,\n)\nfrom dgl.nn import CuGraphSAGEConv\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hid_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # three-layer GraphSAGE-mean\n        self.layers.append(CuGraphSAGEConv(in_size, hid_size, \"mean\"))\n        self.layers.append(CuGraphSAGEConv(hid_size, hid_size, \"mean\"))\n        self.layers.append(CuGraphSAGEConv(hid_size, out_size, \"mean\"))\n        self.dropout = nn.Dropout(0.5)\n        self.hid_size = hid_size\n        self.out_size = out_size\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = F.relu(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, device, batch_size):\n        \"\"\"Conduct layer-wise inference to get all the node embeddings.\"\"\"\n        feat = g.ndata[\"feat\"]\n        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=[\"feat\"])\n        dataloader = DataLoader(\n            g,\n            torch.arange(g.num_nodes()).to(g.device),\n            sampler,\n            device=device,\n            batch_size=batch_size,\n            shuffle=False,\n            drop_last=False,\n            num_workers=0,\n        )\n        buffer_device = torch.device(\"cpu\")\n        pin_memory = buffer_device != device\n\n        for l, layer in enumerate(self.layers):\n            y = torch.empty(\n                g.num_nodes(),\n                self.hid_size if l != len(self.layers) - 
1 else self.out_size,\n                device=buffer_device,\n                pin_memory=pin_memory,\n            )\n            feat = feat.to(device)\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                x = feat[input_nodes]\n                h = layer(blocks[0], x)  # len(blocks) = 1\n                if l != len(self.layers) - 1:\n                    h = F.relu(h)\n                    h = self.dropout(h)\n                # by design, our output nodes are contiguous\n                y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)\n            feat = y\n        return y\n\n\ndef evaluate(model, graph, dataloader):\n    model.eval()\n    ys = []\n    y_hats = []\n    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):\n        with torch.no_grad():\n            x = blocks[0].srcdata[\"feat\"]\n            ys.append(blocks[-1].dstdata[\"label\"])\n            y_hats.append(model(blocks, x))\n    num_classes = y_hats[0].shape[1]\n    return MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(ys),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\ndef layerwise_infer(device, graph, nid, model, batch_size):\n    model.eval()\n    with torch.no_grad():\n        pred = model.inference(\n            graph, device, batch_size\n        )  # pred in buffer_device\n        pred = pred[nid]\n        label = graph.ndata[\"label\"][nid].to(pred.device)\n        num_classes = pred.shape[1]\n        return MF.accuracy(\n            pred, label, task=\"multiclass\", num_classes=num_classes\n        )\n\n\ndef train(args, device, g, dataset, model):\n    # create sampler & dataloader\n    train_idx = dataset.train_idx.to(device)\n    val_idx = dataset.val_idx.to(device)\n    sampler = NeighborSampler(\n        [10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]\n        prefetch_node_feats=[\"feat\"],\n        prefetch_labels=[\"label\"],\n    )\n    use_uva = args.mode == 
\"mixed\"\n    train_dataloader = DataLoader(\n        g,\n        train_idx,\n        sampler,\n        device=device,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=0,\n        use_uva=use_uva,\n    )\n\n    val_dataloader = DataLoader(\n        g,\n        val_idx,\n        sampler,\n        device=device,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=0,\n        use_uva=use_uva,\n    )\n\n    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)\n\n    for epoch in range(10):\n        model.train()\n        total_loss = 0\n        for it, (input_nodes, output_nodes, blocks) in enumerate(\n            train_dataloader\n        ):\n            x = blocks[0].srcdata[\"feat\"]\n            y = blocks[-1].dstdata[\"label\"]\n            y_hat = model(blocks, x)\n            loss = F.cross_entropy(y_hat, y)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n\n            total_loss += loss.item()\n        acc = evaluate(model, g, val_dataloader)\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} \".format(\n                epoch, total_loss / (it + 1), acc.item()\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--mode\",\n        default=\"mixed\",\n        choices=[\"mixed\", \"puregpu\"],\n        help=\"Training mode. 
'mixed' for CPU-GPU mixed training, \"\n        \"'puregpu' for pure-GPU training.\",\n    )\n    args = parser.parse_args()\n    if not torch.cuda.is_available():\n        args.mode = \"cpu\"\n    print(f\"Training in {args.mode} mode.\")\n\n    # load and preprocess dataset\n    print(\"Loading data\")\n    dataset = AsNodePredDataset(DglNodePropPredDataset(\"ogbn-products\"))\n    g = dataset[0]\n    g = g.to(\"cuda\" if args.mode == \"puregpu\" else \"cpu\")\n    device = torch.device(\"cpu\" if args.mode == \"cpu\" else \"cuda\")\n\n    # create GraphSAGE model\n    in_size = g.ndata[\"feat\"].shape[1]\n    out_size = dataset.num_classes\n    model = SAGE(in_size, 256, out_size).to(device)\n\n    # model training\n    print(\"Training...\")\n    train(args, device, g, dataset, model)\n\n    # test the model\n    print(\"Testing...\")\n    acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096)\n    print(\"Test Accuracy {:.4f}\".format(acc.item()))\n"
  },
  {
    "path": "examples/advanced/cugraph/rgcn.py",
    "content": "\"\"\"\n[RGCN: Relational Graph Convolutional Networks]\n(https://arxiv.org/abs/1703.06103)\n\nThis example showcases the usage of `CuGraphRelGraphConv` via the entity\nclassification problem in the RGCN paper with mini-batch training. It offers\na 1.5~2x speed-up over `RelGraphConv` on cuda devices and only requires minimal\ncode changes from the current `entity_sample.py` example.\n\"\"\"\n\nimport argparse\n\nimport dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\nfrom dgl.dataloading import DataLoader, MultiLayerNeighborSampler\nfrom dgl.nn import CuGraphRelGraphConv\nfrom torchmetrics.functional import accuracy\n\n\nclass RGCN(nn.Module):\n    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases):\n        super().__init__()\n        self.emb = nn.Embedding(num_nodes, h_dim)\n        # two-layer RGCN\n        self.conv1 = CuGraphRelGraphConv(\n            h_dim,\n            h_dim,\n            num_rels,\n            regularizer=\"basis\",\n            num_bases=num_bases,\n            self_loop=True,\n            apply_norm=True,\n        )\n        self.conv2 = CuGraphRelGraphConv(\n            h_dim,\n            out_dim,\n            num_rels,\n            regularizer=\"basis\",\n            num_bases=num_bases,\n            self_loop=True,\n            apply_norm=True,\n        )\n\n    def forward(self, g, fanouts=[None, None]):\n        x = self.emb(g[0].srcdata[dgl.NID])\n        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], fanouts[0]))\n        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], fanouts[1])\n        return h\n\n\ndef evaluate(model, labels, dataloader, inv_target):\n    model.eval()\n    eval_logits = []\n    eval_seeds = []\n    with torch.no_grad():\n        for _, output_nodes, blocks in dataloader:\n            output_nodes = inv_target[output_nodes.type(torch.int64)]\n            logits = 
model(blocks)\n            eval_logits.append(logits.cpu().detach())\n            eval_seeds.append(output_nodes.cpu().detach())\n    num_classes = eval_logits[0].shape[1]\n    eval_logits = torch.cat(eval_logits)\n    eval_seeds = torch.cat(eval_seeds)\n    return accuracy(\n        eval_logits.argmax(dim=1),\n        labels[eval_seeds].cpu(),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    ).item()\n\n\ndef train(device, g, target_idx, labels, train_mask, model, fanouts):\n    # Define train idx, loss function and optimizer.\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n    # Construct sampler and dataloader.\n    sampler = MultiLayerNeighborSampler(fanouts)\n    train_loader = DataLoader(\n        g,\n        target_idx[train_idx].type(g.idtype),\n        sampler,\n        device=device,\n        batch_size=100,\n        shuffle=True,\n    )\n    # No separate validation subset, use train index instead for validation.\n    val_loader = DataLoader(\n        g,\n        target_idx[train_idx].type(g.idtype),\n        sampler,\n        device=device,\n        batch_size=100,\n        shuffle=False,\n    )\n    for epoch in range(50):\n        model.train()\n        total_loss = 0\n        for it, (_, output_nodes, blocks) in enumerate(train_loader):\n            output_nodes = inv_target[output_nodes.type(torch.int64)]\n            logits = model(blocks, fanouts=fanouts)\n            loss = loss_fcn(logits, labels[output_nodes])\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n        acc = evaluate(model, labels, val_loader, inv_target)\n        print(\n            f\"Epoch {epoch:05d} | Loss {total_loss / (it+1):.4f} | \"\n            f\"Val. 
Accuracy {acc:.4f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"RGCN for entity classification with sampling\"\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"aifb\",\n        choices=[\"aifb\", \"mutag\", \"bgs\", \"am\"],\n    )\n    args = parser.parse_args()\n    device = torch.device(\"cuda\")\n    print(f\"Training with DGL CuGraphRelGraphConv module with sampling.\")\n\n    # Load and preprocess dataset.\n    if args.dataset == \"aifb\":\n        data = AIFBDataset()\n    elif args.dataset == \"mutag\":\n        data = MUTAGDataset()\n    elif args.dataset == \"bgs\":\n        data = BGSDataset()\n    elif args.dataset == \"am\":\n        data = AMDataset()\n    else:\n        raise ValueError(f\"Unknown dataset: {args.dataset}\")\n    hg = data[0].to(device)\n    num_rels = len(hg.canonical_etypes)\n    category = data.predict_category\n\n    labels = hg.nodes[category].data.pop(\"labels\")\n    train_mask = hg.nodes[category].data.pop(\"train_mask\")\n    test_mask = hg.nodes[category].data.pop(\"test_mask\")\n\n    # Find target category and node id.\n    category_id = hg.ntypes.index(category)\n    g = dgl.to_homogeneous(hg)\n    node_ids = torch.arange(g.num_nodes()).to(device)\n    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]\n    g.ndata[\"ntype\"] = g.ndata.pop(dgl.NTYPE)\n    g.ndata[\"type_id\"] = g.ndata.pop(dgl.NID)\n\n    # Find the mapping from global node IDs to type-specific node IDs.\n    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)\n    inv_target[target_idx] = torch.arange(\n        0, target_idx.shape[0], dtype=inv_target.dtype\n    ).to(device)\n\n    # Create RGCN model.\n    in_size = g.num_nodes()  # featureless with one-hot encoding\n    out_size = data.num_classes\n    num_bases = 20\n    fanouts = [4, 4]\n    model = RGCN(in_size, 16, out_size, num_rels, num_bases).to(device)\n\n    
train(\n        device,\n        g,\n        target_idx,\n        labels,\n        train_mask,\n        model,\n        fanouts,\n    )\n    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()\n    test_sampler = MultiLayerNeighborSampler([-1, -1])\n    test_loader = DataLoader(\n        g,\n        target_idx[test_idx].type(g.idtype),\n        test_sampler,\n        device=device,\n        batch_size=32,\n        shuffle=False,\n    )\n    acc = evaluate(model, labels, test_loader, inv_target)\n    print(f\"Test accuracy {acc:.4f}\")\n"
  },
  {
    "path": "examples/core/Graphormer/README.md",
    "content": "Graphormer\n==============================\n\n## Introduction\n\n* Graphormer is a Transformer model designed for graph-structured data, which encodes the structural information of a graph into the standard Transformer. Specifically, Graphormer utilizes Degree Encoding to measure the importance of nodes, Spatial Encoding and Path Encoding to measure the relation between node pairs. The former plus the node features serve as input to Graphormer, while the latter acts as bias terms in the self-attention module.\n\n* paper link: [https://arxiv.org/abs/2106.05234](https://arxiv.org/abs/2106.05234)\n\n## Requirements\n- accelerate\n- transformers\n- ogb\n\n## Dataset\n\nTask: Graph Property Prediction\n\n|   Dataset   | #Graphs | #Node Feats | #Edge Feats | Metric  |\n| :---------: | :-----: | :---------: | :---------: | :-----: |\n| ogbg-molhiv | 41,127  |      9      |      3      | ROC-AUC |\n\nHow to run\n----------\n\n```bash\naccelerate launch --multi_gpu --mixed_precision=fp16 main.py\n```\n> **_NOTE:_**  The script will automatically download weights pre-trained on PCQM4Mv2. To reproduce the same result, set the total batch size to 64.\n\n## Summary\n\n* ogbg-molhiv (pretrained on PCQM4Mv2): ~0.791\n"
  },
  {
    "path": "examples/core/Graphormer/dataset.py",
    "content": "\"\"\"\nThis file contains the MolHIVDataset class, which handles data preprocessing\n(computing required graph features, converting graphs to tensors) of the\nogbg-molhiv dataset.\n\"\"\"\nimport torch as th\nimport torch.nn.functional as F\nfrom dgl import shortest_dist\nfrom ogb.graphproppred import DglGraphPropPredDataset\nfrom torch.nn.utils.rnn import pad_sequence\n\n\nclass MolHIVDataset(th.utils.data.Dataset):\n    def __init__(self):\n        dataset = DglGraphPropPredDataset(name=\"ogbg-molhiv\")\n        split_idx = dataset.get_idx_split()\n\n        # Compute the shortest path distances and their corresponding paths\n        # of all graphs during preprocessing.\n        for g, label in dataset:\n            spd, path = shortest_dist(g, root=None, return_paths=True)\n            g.ndata[\"spd\"] = spd\n            g.ndata[\"path\"] = path\n\n        self.train, self.val, self.test = (\n            dataset[split_idx[\"train\"]],\n            dataset[split_idx[\"valid\"]],\n            dataset[split_idx[\"test\"]],\n        )\n\n    def collate(self, samples):\n        # To match Graphormer's input style, all graph features should be\n        # padded to the same size. Keep in mind that different graphs may\n        # have varying feature sizes since they have different number of\n        # nodes, so they will be aligned with the graph having the maximum\n        # number of nodes.\n        graphs, labels = map(list, zip(*samples))\n        labels = th.stack(labels)\n\n        num_graphs = len(graphs)\n        num_nodes = [g.num_nodes() for g in graphs]\n        max_num_nodes = max(num_nodes)\n\n        # Graphormer adds a virual node to the graph, which is connected to\n        # all other nodes and supposed to represent the graph embedding. 
So\n        # here +1 is for the virtual node.\n        attn_mask = th.zeros(num_graphs, max_num_nodes + 1, max_num_nodes + 1)\n        node_feat = []\n        in_degree, out_degree = [], []\n        path_data = []\n        # Since shortest_dist returns -1 for unreachable node pairs and padded\n        # nodes are unreachable to others, distance relevant to padded nodes\n        # use -1 padding as well.\n        dist = -th.ones(\n            (num_graphs, max_num_nodes, max_num_nodes), dtype=th.long\n        )\n\n        for i in range(num_graphs):\n            # A binary mask where invalid positions are indicated by True.\n            attn_mask[i, :, num_nodes[i] + 1 :] = 1\n\n            # +1 to distinguish padded non-existing nodes from real nodes\n            node_feat.append(graphs[i].ndata[\"feat\"] + 1)\n\n            in_degree.append(\n                th.clamp(graphs[i].in_degrees() + 1, min=0, max=512)\n            )\n            out_degree.append(\n                th.clamp(graphs[i].out_degrees() + 1, min=0, max=512)\n            )\n\n            # Path padding to make all paths to the same length \"max_len\".\n            path = graphs[i].ndata[\"path\"]\n            path_len = path.size(dim=2)\n            # shape of shortest_path: [n, n, max_len]\n            max_len = 5\n            if path_len >= max_len:\n                shortest_path = path[:, :, :max_len]\n            else:\n                p1d = (0, max_len - path_len)\n                # Use the same -1 padding as shortest_dist for\n                # invalid edge IDs.\n                shortest_path = F.pad(path, p1d, \"constant\", -1)\n            pad_num_nodes = max_num_nodes - num_nodes[i]\n            p3d = (0, 0, 0, pad_num_nodes, 0, pad_num_nodes)\n            shortest_path = F.pad(shortest_path, p3d, \"constant\", -1)\n            # +1 to distinguish padded non-existing edges from real edges\n            edata = graphs[i].edata[\"feat\"] + 1\n            # shortest_dist pads non-existing 
edges (at the end of shortest\n            # paths) with edge IDs -1, and th.zeros(1, edata.shape[1]) stands\n            # for all padded edge features.\n            edata = th.cat(\n                (edata, th.zeros(1, edata.shape[1]).to(edata.device)), dim=0\n            )\n            path_data.append(edata[shortest_path])\n\n            dist[i, : num_nodes[i], : num_nodes[i]] = graphs[i].ndata[\"spd\"]\n\n        # node feat padding\n        node_feat = pad_sequence(node_feat, batch_first=True)\n\n        # degree padding\n        in_degree = pad_sequence(in_degree, batch_first=True)\n        out_degree = pad_sequence(out_degree, batch_first=True)\n\n        return (\n            labels.reshape(num_graphs, -1),\n            attn_mask,\n            node_feat,\n            in_degree,\n            out_degree,\n            th.stack(path_data),\n            dist,\n        )\n"
  },
  {
    "path": "examples/core/Graphormer/main.py",
    "content": "\"\"\"\nThis script finetunes and tests a Graphormer model (pretrained on PCQM4Mv2)\nfor graph classification on ogbg-molhiv dataset.\n\nPaper: [Do Transformers Really Perform Bad for Graph Representation?]\n(https://arxiv.org/abs/2106.05234)\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n└───> train_val_pipeline\n      │\n      ├───> Load and preprocess dataset\n      │\n      ├───> Download pretrained model\n      │\n      ├───> train_epoch\n      │     │\n      │     └───> Graphormer.forward\n      │\n      └───> evaluate_network\n            │\n            └───> Graphormer.inference\n\"\"\"\nimport argparse\nimport random\n\nimport torch as th\nimport torch.nn as nn\nfrom accelerate import Accelerator\nfrom dataset import MolHIVDataset\nfrom dgl.data import download\nfrom dgl.dataloading import GraphDataLoader\nfrom model import Graphormer\nfrom ogb.graphproppred import Evaluator\nfrom transformers.optimization import (\n    AdamW,\n    get_polynomial_decay_schedule_with_warmup,\n)\n\n# Instantiate an accelerator object to support distributed\n# training and inference.\naccelerator = Accelerator()\n\n\ndef train_epoch(model, optimizer, data_loader, lr_scheduler):\n    model.train()\n    epoch_loss = 0\n    list_scores = []\n    list_labels = []\n    loss_fn = nn.BCEWithLogitsLoss()\n    for (\n        batch_labels,\n        attn_mask,\n        node_feat,\n        in_degree,\n        out_degree,\n        path_data,\n        dist,\n    ) in data_loader:\n        optimizer.zero_grad()\n        device = accelerator.device\n\n        batch_scores = model(\n            node_feat.to(device),\n            in_degree.to(device),\n            out_degree.to(device),\n            path_data.to(device),\n            dist.to(device),\n            attn_mask=attn_mask,\n        )\n\n        loss = loss_fn(batch_scores, batch_labels.float())\n\n        accelerator.backward(loss)\n        optimizer.step()\n        
lr_scheduler.step()\n        epoch_loss += loss.item()\n        list_scores.append(batch_scores)\n        list_labels.append(batch_labels)\n\n        # Release GPU memory.\n        del (\n            batch_labels,\n            batch_scores,\n            loss,\n            attn_mask,\n            node_feat,\n            in_degree,\n            out_degree,\n            path_data,\n            dist,\n        )\n        th.cuda.empty_cache()\n\n    epoch_loss /= len(data_loader)\n\n    evaluator = Evaluator(name=\"ogbg-molhiv\")\n    epoch_auc = evaluator.eval(\n        {\"y_pred\": th.cat(list_scores), \"y_true\": th.cat(list_labels)}\n    )[\"rocauc\"]\n\n    return epoch_loss, epoch_auc\n\n\ndef evaluate_network(model, data_loader):\n    model.eval()\n    epoch_loss = 0\n    loss_fn = nn.BCEWithLogitsLoss()\n    with th.no_grad():\n        list_scores = []\n        list_labels = []\n        for (\n            batch_labels,\n            attn_mask,\n            node_feat,\n            in_degree,\n            out_degree,\n            path_data,\n            dist,\n        ) in data_loader:\n            device = accelerator.device\n\n            batch_scores = model(\n                node_feat.to(device),\n                in_degree.to(device),\n                out_degree.to(device),\n                path_data.to(device),\n                dist.to(device),\n                attn_mask=attn_mask,\n            )\n\n            # Gather all predictions and targets.\n            all_predictions, all_targets = accelerator.gather_for_metrics(\n                (batch_scores, batch_labels)\n            )\n            loss = loss_fn(all_predictions, all_targets.float())\n\n            epoch_loss += loss.item()\n            list_scores.append(all_predictions)\n            list_labels.append(all_targets)\n\n        epoch_loss /= len(data_loader)\n\n        evaluator = Evaluator(name=\"ogbg-molhiv\")\n        epoch_auc = evaluator.eval(\n            {\"y_pred\": th.cat(list_scores), 
\"y_true\": th.cat(list_labels)}\n        )[\"rocauc\"]\n\n    return epoch_loss, epoch_auc\n\n\ndef train_val_pipeline(params):\n    dataset = MolHIVDataset()\n\n    accelerator.print(\n        f\"train, test, val sizes: {len(dataset.train)}, \"\n        f\"{len(dataset.test)}, {len(dataset.val)}.\"\n    )\n    accelerator.print(\"Finished loading.\")\n\n    train_loader = GraphDataLoader(\n        dataset.train,\n        batch_size=params.batch_size,\n        shuffle=True,\n        collate_fn=dataset.collate,\n        pin_memory=True,\n        num_workers=16,\n    )\n    val_loader = GraphDataLoader(\n        dataset.val,\n        batch_size=params.batch_size,\n        shuffle=False,\n        collate_fn=dataset.collate,\n        pin_memory=True,\n        num_workers=16,\n    )\n    test_loader = GraphDataLoader(\n        dataset.test,\n        batch_size=params.batch_size,\n        shuffle=False,\n        collate_fn=dataset.collate,\n        pin_memory=True,\n        num_workers=16,\n    )\n\n    # Load pre-trained model.\n    download(url=\"https://data.dgl.ai/pre_trained/graphormer_pcqm.pth\")\n    model = Graphormer()\n    state_dict = th.load(\"graphormer_pcqm.pth\")\n    model.load_state_dict(state_dict)\n\n    model.reset_output_layer_parameters()\n    num_epochs = 16\n    total_updates = 33000 * num_epochs / params.batch_size\n    # Use warmup schedule to avoid overfitting at the very beginning\n    # of training, the ratio 0.16 is the same as the paper.\n    warmup_updates = total_updates * 0.16\n\n    optimizer = AdamW(model.parameters(), lr=1e-4, eps=1e-8, weight_decay=0)\n    lr_scheduler = get_polynomial_decay_schedule_with_warmup(\n        optimizer,\n        num_warmup_steps=warmup_updates,\n        num_training_steps=total_updates,\n        lr_end=1e-9,\n        power=1.0,\n    )\n\n    epoch_train_AUCs, epoch_val_AUCs, epoch_test_AUCs = [], [], []\n\n    # Pass all objects relevant to training to the prepare() method as required\n    # by 
Accelerate.\n    (\n        model,\n        optimizer,\n        train_loader,\n        val_loader,\n        test_loader,\n        lr_scheduler,\n    ) = accelerator.prepare(\n        model, optimizer, train_loader, val_loader, test_loader, lr_scheduler\n    )\n\n    for epoch in range(num_epochs):\n        epoch_train_loss, epoch_train_auc = train_epoch(\n            model, optimizer, train_loader, lr_scheduler\n        )\n        epoch_val_loss, epoch_val_auc = evaluate_network(model, val_loader)\n        epoch_test_loss, epoch_test_auc = evaluate_network(model, test_loader)\n\n        epoch_train_AUCs.append(epoch_train_auc)\n        epoch_val_AUCs.append(epoch_val_auc)\n        epoch_test_AUCs.append(epoch_test_auc)\n\n        accelerator.print(\n            f\"Epoch={epoch + 1} | train_AUC={epoch_train_auc:.3f} | \"\n            f\"val_AUC={epoch_val_auc:.3f} | test_AUC={epoch_test_auc:.3f}\"\n        )\n\n    # Return test and train AUCs with best val AUC.\n    index = epoch_val_AUCs.index(max(epoch_val_AUCs))\n    val_auc = epoch_val_AUCs[index]\n    train_auc = epoch_train_AUCs[index]\n    test_auc = epoch_test_AUCs[index]\n\n    accelerator.print(\"Test ROCAUC: {:.4f}\".format(test_auc))\n    accelerator.print(\"Val ROCAUC: {:.4f}\".format(val_auc))\n    accelerator.print(\"Train ROCAUC: {:.4f}\".format(train_auc))\n    accelerator.print(\"Best epoch index: {:.4f}\".format(index))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--seed\",\n        default=1,\n        type=int,\n        help=\"Please give a value for random seed\",\n    )\n    parser.add_argument(\n        \"--batch_size\",\n        default=16,\n        type=int,\n        help=\"Please give a value for batch_size\",\n    )\n    args = parser.parse_args()\n\n    # Set manual seed to bind the order of training data to the random seed.\n    random.seed(args.seed)\n    th.manual_seed(args.seed)\n    if th.cuda.is_available():\n        
th.cuda.manual_seed(args.seed)\n\n    train_val_pipeline(args)\n"
  },
  {
    "path": "examples/core/Graphormer/model.py",
    "content": "\"\"\"\nThis file defines the Graphormer model, which utilizes DegreeEncoder,\nSpatialEncoder, PathEncoder and GraphormerLayer from DGL build-in modules.\n\"\"\"\nimport torch as th\nimport torch.nn as nn\nfrom dgl.nn import DegreeEncoder, GraphormerLayer, PathEncoder, SpatialEncoder\n\n\nclass Graphormer(nn.Module):\n    def __init__(\n        self,\n        num_classes=1,\n        edge_dim=3,\n        num_atoms=4608,\n        max_degree=512,\n        num_spatial=511,\n        multi_hop_max_dist=5,\n        num_encoder_layers=12,\n        embedding_dim=768,\n        ffn_embedding_dim=768,\n        num_attention_heads=32,\n        dropout=0.1,\n        pre_layernorm=True,\n        activation_fn=nn.GELU(),\n    ):\n        super().__init__()\n        self.dropout = nn.Dropout(p=dropout)\n        self.embedding_dim = embedding_dim\n        self.num_heads = num_attention_heads\n\n        self.atom_encoder = nn.Embedding(\n            num_atoms + 1, embedding_dim, padding_idx=0\n        )\n        self.graph_token = nn.Embedding(1, embedding_dim)\n\n        self.degree_encoder = DegreeEncoder(\n            max_degree=max_degree, embedding_dim=embedding_dim\n        )\n\n        self.path_encoder = PathEncoder(\n            max_len=multi_hop_max_dist,\n            feat_dim=edge_dim,\n            num_heads=num_attention_heads,\n        )\n\n        self.spatial_encoder = SpatialEncoder(\n            max_dist=num_spatial, num_heads=num_attention_heads\n        )\n        self.graph_token_virtual_distance = nn.Embedding(1, num_attention_heads)\n\n        self.emb_layer_norm = nn.LayerNorm(self.embedding_dim)\n\n        self.layers = nn.ModuleList([])\n        self.layers.extend(\n            [\n                GraphormerLayer(\n                    feat_size=self.embedding_dim,\n                    hidden_size=ffn_embedding_dim,\n                    num_heads=num_attention_heads,\n                    dropout=dropout,\n                    
activation=activation_fn,\n                    norm_first=pre_layernorm,\n                )\n                for _ in range(num_encoder_layers)\n            ]\n        )\n\n        # map graph_rep to num_classes\n        self.lm_head_transform_weight = nn.Linear(\n            self.embedding_dim, self.embedding_dim\n        )\n        self.layer_norm = nn.LayerNorm(self.embedding_dim)\n        self.activation_fn = activation_fn\n        self.embed_out = nn.Linear(self.embedding_dim, num_classes, bias=False)\n        self.lm_output_learned_bias = nn.Parameter(th.zeros(num_classes))\n\n    def reset_output_layer_parameters(self):\n        self.lm_output_learned_bias = nn.Parameter(th.zeros(1))\n        self.embed_out.reset_parameters()\n\n    def forward(\n        self,\n        node_feat,\n        in_degree,\n        out_degree,\n        path_data,\n        dist,\n        attn_mask=None,\n    ):\n        num_graphs, max_num_nodes, _ = node_feat.shape\n        deg_emb = self.degree_encoder(th.stack((in_degree, out_degree)))\n\n        # node feature + degree encoding as input\n        node_feat = self.atom_encoder(node_feat.int()).sum(dim=-2)\n        node_feat = node_feat + deg_emb\n        graph_token_feat = self.graph_token.weight.unsqueeze(0).repeat(\n            num_graphs, 1, 1\n        )\n        x = th.cat([graph_token_feat, node_feat], dim=1)\n\n        # spatial encoding and path encoding serve as attention bias\n        attn_bias = th.zeros(\n            num_graphs,\n            max_num_nodes + 1,\n            max_num_nodes + 1,\n            self.num_heads,\n            device=dist.device,\n        )\n        path_encoding = self.path_encoder(dist, path_data)\n        spatial_encoding = self.spatial_encoder(dist)\n        attn_bias[:, 1:, 1:, :] = path_encoding + spatial_encoding\n\n        # spatial encoding of the virtual node\n        t = self.graph_token_virtual_distance.weight.reshape(\n            1, 1, self.num_heads\n        )\n        # Since the 
virtual node comes first, the spatial encodings between it\n        # and other nodes will fill the 1st row and 1st column (omit num_graphs\n        # and num_heads dimensions) of attn_bias matrix by broadcasting.\n        attn_bias[:, 1:, 0, :] = attn_bias[:, 1:, 0, :] + t\n        attn_bias[:, 0, :, :] = attn_bias[:, 0, :, :] + t\n\n        x = self.emb_layer_norm(x)\n\n        for layer in self.layers:\n            x = layer(\n                x,\n                attn_mask=attn_mask,\n                attn_bias=attn_bias,\n            )\n\n        graph_rep = x[:, 0, :]\n        graph_rep = self.layer_norm(\n            self.activation_fn(self.lm_head_transform_weight(graph_rep))\n        )\n        graph_rep = self.embed_out(graph_rep) + self.lm_output_learned_bias\n\n        return graph_rep\n"
  },
  {
    "path": "examples/core/gat/README.md",
    "content": "Graph Attention Networks (GAT)\n============\n\n- Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903)\n- Author's code repo (TensorFlow implementation):\n  [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).\n- Popular PyTorch implementation:\n  [https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).\n\nHow to run\n-------\n\nRun with the following for multiclass node classification (available datasets: \"cora\", \"citeseer\", \"pubmed\")\n```bash\npython3 train.py --dataset cora\n```\n\n> **_NOTE:_**  Users may occasionally run into a low-accuracy issue (e.g., test accuracy < 0.8) due to overfitting. This can be resolved by adding early stopping or reducing the maximum number of training epochs.\n\nSummary\n-------\n* cora: ~0.821\n* citeseer: ~0.710\n* pubmed: ~0.780\n"
  },
  {
    "path": "examples/core/gat/train.py",
    "content": "import argparse\nimport time\n\nimport dgl.nn as dglnn\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import AddSelfLoop\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\n\n\nclass GAT(nn.Module):\n    def __init__(self, in_size, hid_size, out_size, heads):\n        super().__init__()\n        self.gat_layers = nn.ModuleList()\n        # two-layer GAT\n        self.gat_layers.append(\n            dglnn.GATConv(\n                in_size,\n                hid_size,\n                heads[0],\n                feat_drop=0.6,\n                attn_drop=0.6,\n                activation=F.elu,\n            )\n        )\n        self.gat_layers.append(\n            dglnn.GATConv(\n                hid_size * heads[0],\n                out_size,\n                heads[1],\n                feat_drop=0.6,\n                attn_drop=0.6,\n                activation=None,\n            )\n        )\n\n    def forward(self, g, inputs):\n        h = inputs\n        for i, layer in enumerate(self.gat_layers):\n            h = layer(g, h)\n            if i == len(self.gat_layers) - 1:  # last layer\n                h = h.mean(1)\n            else:  # other layer(s)\n                h = h.flatten(1)\n        return h\n\n\ndef evaluate(g, features, labels, mask, model):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef train(g, features, labels, masks, model, num_epochs):\n    # Define train/val samples, loss function and optimizer\n    train_mask = masks[0]\n    val_mask = masks[1]\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)\n\n    for epoch in range(num_epochs):\n        t0 = 
time.time()\n        model.train()\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        acc = evaluate(g, features, labels, val_mask, model)\n        t1 = time.time()\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} | Time {:.4f}\".format(\n                epoch, loss.item(), acc, t1 - t0\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"cora\",\n        help=\"Dataset name ('cora', 'citeseer', 'pubmed').\",\n    )\n    parser.add_argument(\n        \"--num_epochs\",\n        type=int,\n        default=200,\n        help=\"Number of epochs for train.\",\n    )\n    parser.add_argument(\n        \"--num_gpus\",\n        type=int,\n        default=0,\n        help=\"Number of GPUs used for train and evaluation.\",\n    )\n    args = parser.parse_args()\n    print(f\"Training with DGL built-in GATConv module.\")\n\n    # Load and preprocess dataset\n    transform = (\n        AddSelfLoop()\n    )  # by default, it will first remove self-loops to prevent duplication\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset(transform=transform)\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset(transform=transform)\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset(transform=transform)\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n    g = data[0]\n    if args.num_gpus > 0 and torch.cuda.is_available():\n        device = torch.device(\"cuda\")\n    else:\n        device = torch.device(\"cpu\")\n    g = g.int().to(device)\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    masks = g.ndata[\"train_mask\"], g.ndata[\"val_mask\"], g.ndata[\"test_mask\"]\n\n    # 
Create GAT model\n    in_size = features.shape[1]\n    out_size = data.num_classes\n    model = GAT(in_size, 8, out_size, heads=[8, 1]).to(device)\n\n    print(\"Training...\")\n    train(g, features, labels, masks, model, args.num_epochs)\n\n    print(\"Testing...\")\n    acc = evaluate(g, features, labels, masks[2], model)\n    print(\"Test accuracy {:.4f}\".format(acc))\n"
  },
  {
    "path": "examples/core/gated_gcn/README.md",
    "content": "Gated Graph ConvNet (GatedGCN)\n==============================\n\n* paper link: [https://arxiv.org/abs/2003.00982.pdf](https://arxiv.org/abs/2003.00982.pdf)\n\n## Dataset\n\nTask: Graph Property Prediction\n\n|   Dataset   | #Graphs | #Node Feats | #Edge Feats | Metric |\n| :---------: | :-----: | :---------: | :---------: | :-----: |\n| ogbg-molhiv | 41,127 |      9      |      3      | ROC-AUC |\n\nHow to run\n----------\n\n```bash\npython3 train.py --dataset ogbg-molhiv --num_gpus 0 --num_epochs 50\n```\n\n## Summary\n\n* ogbg-molhiv: ~0.781\n"
  },
  {
    "path": "examples/core/gated_gcn/train.py",
    "content": "\"\"\"\nGated Graph Convolutional Network module for graph classification tasks\n\"\"\"\nimport argparse\nimport time\n\nimport torch\nimport torch.nn as nn\n\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dgl.dataloading import GraphDataLoader\nfrom dgl.nn.pytorch import GatedGCNConv\nfrom dgl.nn.pytorch.glob import AvgPooling\nfrom ogb.graphproppred import DglGraphPropPredDataset, Evaluator\nfrom ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder\n\n\nclass GatedGCN(nn.Module):\n    def __init__(\n        self,\n        hid_dim,\n        out_dim,\n        num_layers,\n        dropout=0.2,\n        batch_norm=True,\n        residual=True,\n        activation=F.relu,\n    ):\n        super(GatedGCN, self).__init__()\n\n        self.num_layers = num_layers\n        self.dropout = dropout\n\n        self.node_encoder = AtomEncoder(hid_dim)\n        self.edge_encoder = BondEncoder(hid_dim)\n\n        self.layers = nn.ModuleList()\n        for _ in range(self.num_layers):\n            layer = GatedGCNConv(\n                input_feats=hid_dim,\n                edge_feats=hid_dim,\n                output_feats=hid_dim,\n                dropout=dropout,\n                batch_norm=batch_norm,\n                residual=residual,\n                activation=activation,\n            )\n            self.layers.append(layer)\n\n        self.pooling = AvgPooling()\n        self.output = nn.Linear(hid_dim, out_dim)\n\n    def forward(self, g, node_feat, edge_feat):\n        # Encode node and edge feature.\n        hv = self.node_encoder(node_feat)\n        he = self.edge_encoder(edge_feat)\n\n        # GatedGCNConv layers.\n        for layer in self.layers:\n            hv, he = layer(g, hv, he)\n\n        # Output project.\n        h_g = self.pooling(g, hv)\n\n        return self.output(h_g)\n\n\ndef train(model, device, data_loader, opt, loss_fn):\n    model.train()\n    train_loss = []\n\n    for g, labels in data_loader:\n     
   g = g.to(device)\n        labels = labels.to(torch.float32).to(device)\n        logits = model(g, g.ndata[\"feat\"], g.edata[\"feat\"])\n        loss = loss_fn(logits, labels)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        train_loss.append(loss.item())\n\n    return sum(train_loss) / len(train_loss)\n\n\n@torch.no_grad()\ndef evaluate(model, device, data_loader, evaluator):\n    model.eval()\n    y_true, y_pred = [], []\n\n    for g, labels in data_loader:\n        g = g.to(device)\n        logits = model(g, g.ndata[\"feat\"], g.edata[\"feat\"])\n        y_true.append(labels.detach().cpu())\n        y_pred.append(logits.detach().cpu())\n\n    y_true = torch.cat(y_true, dim=0).numpy()\n    y_pred = torch.cat(y_pred, dim=0).numpy()\n\n    return evaluator.eval({\"y_true\": y_true, \"y_pred\": y_pred})[\"rocauc\"]\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbg-molhiv\",\n        help=\"Dataset name ('ogbg-molhiv', 'ogbg-molbace', 'ogbg-molmuv').\",\n    )\n    parser.add_argument(\n        \"--num_epochs\",\n        type=int,\n        default=200,\n        help=\"Number of epochs for train.\",\n    )\n    parser.add_argument(\n        \"--num_gpus\",\n        type=int,\n        default=0,\n        help=\"Number of GPUs used for train and evaluation.\",\n    )\n    args = parser.parse_args()\n    print(\"Training with DGL built-in GATConv module.\")\n\n    # Load ogb dataset & evaluator.\n    dataset = DglGraphPropPredDataset(name=args.dataset)\n    evaluator = Evaluator(name=args.dataset)\n\n    if args.num_gpus > 0 and torch.cuda.is_available():\n        device = torch.device(\"cuda\")\n    else:\n        device = torch.device(\"cpu\")\n\n    n_classes = dataset.num_tasks\n\n    split_idx = dataset.get_idx_split()\n    train_loader = GraphDataLoader(\n        dataset[split_idx[\"train\"]],\n        
batch_size=32,\n        shuffle=True,\n    )\n    valid_loader = GraphDataLoader(dataset[split_idx[\"valid\"]], batch_size=32)\n    test_loader = GraphDataLoader(dataset[split_idx[\"test\"]], batch_size=32)\n\n    # Load model.\n    model = GatedGCN(hid_dim=256, out_dim=n_classes, num_layers=8).to(device)\n\n    print(model)\n\n    opt = optim.Adam(model.parameters(), lr=0.01)\n    loss_fn = nn.BCEWithLogitsLoss()\n\n    print(\"---------- Training ----------\")\n    for epoch in range(args.num_epochs):\n        # Kick off training.\n        t0 = time.time()\n        loss = train(model, device, train_loader, opt, loss_fn)\n        t1 = time.time()\n        # Evaluate the prediction.\n        val_acc = evaluate(model, device, valid_loader, evaluator)\n        print(\n            f\"Epoch {epoch:05d} | Loss {loss:.4f} | Accuracy {val_acc:.4f} | \"\n            f\"Time {t1 - t0:.4f}\"\n        )\n    acc = evaluate(model, device, test_loader, evaluator)\n    print(f\"Test accuracy {acc:.4f}\")\n"
  },
  {
    "path": "examples/core/graphsage/node_classification.py",
    "content": "\"\"\"\nThis script trains and tests a GraphSAGE model based on the information of \na full graph.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> Load and preprocess full dataset\n│\n├───> Instantiate SAGE model\n│\n├───> train\n│     │\n│     └───> Training loop\n│           │\n│           └───> SAGE.forward\n└───> test\n      │\n      └───> Evaluate the model\n\"\"\"\nimport argparse\nimport time\n\nimport dgl.nn as dglnn\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import AddSelfLoop\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hidden_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # Two-layer GraphSAGE-gcn.\n        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, \"gcn\"))\n        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, \"gcn\"))\n        self.dropout = nn.Dropout(0.5)\n\n    def forward(self, graph, x):\n        hidden_x = x\n        for layer_idx, layer in enumerate(self.layers):\n            hidden_x = layer(graph, hidden_x)\n            is_last_layer = layer_idx == len(self.layers) - 1\n            if not is_last_layer:\n                hidden_x = F.relu(hidden_x)\n                hidden_x = self.dropout(hidden_x)\n        return hidden_x\n\n\ndef evaluate(g, features, labels, mask, model):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef train(g, features, labels, masks, model):\n    # Define train/val samples, loss function and optimizer.\n    train_mask, val_mask = masks\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = 
torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n\n    # Training loop.\n    for epoch in range(200):\n        t0 = time.time()\n        model.train()\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        t1 = time.time()\n        acc = evaluate(g, features, labels, val_mask, model)\n        print(\n            f\"Epoch {epoch:05d} | Loss {loss.item():.4f} | Accuracy {acc:.4f} | \"\n            f\"Time {t1 - t0:.4f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GraphSAGE\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"cora\",\n        help=\"Dataset name ('cora', 'citeseer', 'pubmed')\",\n    )\n    args = parser.parse_args()\n    print(f\"Training with DGL built-in GraphSage module\")\n\n    #####################################################################\n    # (HIGHLIGHT) Node classification task is a supervise learning task\n    # in which the model try to predict the label of a certain node.\n    # In this example, graph sage algorithm is applied to this task.\n    # A good accuracy can be achieved after a few steps of training.\n    #\n    # First, the whole graph is loaded and transformed. Then the training\n    # process is performed on a model which is composed of 2 GraphSAGE-gcn\n    # layer. 
Finally, the performance of the model is evaluated on test set.\n    #####################################################################\n\n    # Load and preprocess dataset.\n    transform = (\n        AddSelfLoop()\n    )  # By default, it will first remove self-loops to prevent duplication.\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset(transform=transform)\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset(transform=transform)\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset(transform=transform)\n    else:\n        raise ValueError(f\"Unknown dataset: {args.dataset}\")\n    g = data[0]\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    g = g.int().to(device)\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    masks = (g.ndata[\"train_mask\"], g.ndata[\"val_mask\"])\n\n    # Create GraphSAGE model.\n    in_size = features.shape[1]\n    out_size = data.num_classes\n    model = SAGE(in_size, 16, out_size).to(device)\n\n    # Model training.\n    print(\"Training...\")\n    train(g, features, labels, masks, model)\n\n    # Test the model.\n    print(\"Testing...\")\n    acc = evaluate(g, features, labels, g.ndata[\"test_mask\"], model)\n    print(f\"Test accuracy {acc:.4f}\")\n"
  },
  {
    "path": "examples/core/rgcn/README.md",
    "content": "# Node classification on heterogeneous graph with RGCN\n\nThis example aims to demonstrate how to run node classification task on heterogeneous graph with **DGL**. Models are not tuned to achieve the best accuracy yet.\n\n## Run on `ogbn-mag` dataset\nIn the preprocess stage, reverse edges are added and duplicate edges are removed. Feature data of `author` and `institution` node types are generated dynamically with embedding layer.\n\n### Sample on CPU and train/infer on CPU\n```\npython3 hetero_rgcn.py --dataset ogbn-mag\n```\n\n### Sample on CPU and train/infer on GPU\n```\npython3 hetero_rgcn.py --dataset ogbn-mag --num_gpus 1\n```\n\n### Resource usage and time cost\nBelow results are roughly collected from an AWS EC2 **g4dn.metal**, 384GB RAM, 96 vCPUs(Cascade Lake P-8259L), 8 NVIDIA T4 GPUs(16GB RAM). CPU RAM usage is the peak value of `used` field of `free` command which is a bit rough. Please refer to `RSS`/`USS`/`PSS` which are more accurate. GPU RAM usage is the peak value recorded by `nvidia-smi` command.\n\n| Dataset Size | CPU RAM Usage | Num of GPUs | GPU RAM Usage | Time Per Epoch(Training) |\n| ------------ | ------------- | ----------- | ------------- | ------------------------ |\n| ~1.1GB       | ~7GB          | 0           |  0GB          | ~233s                    |\n| ~1.1GB       | ~5GB          | 1           |  4.5GB        | ~73.6s                   |\n\n### Accuracies\n```\nEpoch: 01, Loss: 2.3386, Valid: 47.67%, Test: 46.96%\nEpoch: 02, Loss: 1.5563, Valid: 47.66%, Test: 47.02%\nEpoch: 03, Loss: 1.1557, Valid: 46.58%, Test: 45.42%\nTest accuracy 45.3850\n```\n\n## Run on `ogb-lsc-mag240m` dataset\nIn the preprocess stage, reverse edges are added and duplicate edges are removed. What's more, feature data are generated in advance for `author` and `institution` node types via message passing. 
Since such preprocessing will usually take a long time, we also offer the above files for download:\n\n* [`paper-feat.npy`](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/paper-feat.npy)\n* [`author-feat.npy`](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/author-feat.npy)\n* [`inst-feat.npy`](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/inst-feat.npy)\n* [`hetero-graph.dgl`](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/hetero-graph.dgl)\n\n### Sample on CPU and train/infer on CPU\n```\npython3 hetero_rgcn.py --dataset ogb-lsc-mag240m\n```\n\n### Sample on CPU and train/infer on GPU\n```\npython3 hetero_rgcn.py --dataset ogb-lsc-mag240m --num_gpus 1\n```\n\n### Resource usage and time cost\nBelow results are roughly collected from an AWS EC2 **g4dn.metal**, 384GB RAM, 96 vCPUs(Cascade Lake P-8259L), 8 NVIDIA T4 GPUs(16GB RAM). CPU RAM usage is the peak value of `used` field of `free` command which is a bit rough. Please refer to `RSS`/`USS`/`PSS` which are more accurate. GPU RAM usage is the peak value recorded by `nvidia-smi` command.\n\n| Dataset Size | CPU RAM Usage | Num of GPUs | GPU RAM Usage | Time Per Epoch(Training) |\n| ------------ | ------------- | ----------- | ------------- | ------------------------ |\n| ~404GB       | ~72GB         | 0           |  0GB          | ~325s                    |\n| ~404GB       | ~61GB         | 1           |  14GB         | ~178s                    |\n\n### Accuracies\n```\nEpoch: 01, Loss: 2.0798, Valid: 52.04%\nEpoch: 02, Loss: 1.8652, Valid: 54.51%\nEpoch: 03, Loss: 1.8175, Valid: 53.71%\n```\n"
  },
  {
    "path": "examples/core/rgcn/hetero_rgcn.py",
    "content": "\"\"\"\nThis script, `hetero_rgcn.py`, trains and tests a Relational Graph\nConvolutional Network (R-GCN) model for node classification on the\nOpen Graph Benchmark (OGB) dataset \"ogbn-mag\". For more details on\n\"ogbn-mag\", please refer to the OGB website:\n(https://ogb.stanford.edu/docs/linkprop/)\n\nPaper [Modeling Relational Data with Graph Convolutional Networks]\n(https://arxiv.org/abs/1703.06103).\n\nGeneration of graph embeddings is the main difference between homograph\nnode classification and heterograph node classification:\n- Homograph: Since all nodes and edges are of the same type, embeddings\n  can be generated using a unified approach. Type-specific handling is\n  typically not required.\n- Heterograph: Due to the existence of multiple types of nodes and edges,\n  specific embeddings need to be generated for each type. This allows for\n  a more nuanced capture of the complex structure and semantic information\n  within the heterograph.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> prepare_data\n│     │\n│     └───> Load and preprocess dataset\n│\n├───> rel_graph_embed [HIGHLIGHT]\n│     │\n│     └───> Generate graph embeddings\n│\n├───> Instantiate RGCN model\n│     │\n│     ├───> RelGraphConvLayer (input to hidden)\n│     │\n│     └───> RelGraphConvLayer (hidden to output)\n│\n└───> train\n      │\n      │\n      └───> Training loop\n            │\n            ├───> EntityClassify.forward (RGCN model forward pass)\n            │\n            └───> test\n                  │\n                  └───> EntityClassify.evaluate\n\"\"\"\n\nimport argparse\nimport itertools\nimport sys\nimport time\n\nimport dgl\nimport dgl.nn as dglnn\nimport numpy as np\n\nimport psutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import AddReverse, Compose, ToSimple\nfrom dgl.nn import HeteroEmbedding\nfrom ogb.lsc import MAG240MDataset, MAG240MEvaluator\nfrom 
ogb.nodeproppred import DglNodePropPredDataset, Evaluator\nfrom tqdm import tqdm\n\n\ndef prepare_data(args, device):\n    feats = {}\n    if args.dataset == \"ogbn-mag\":\n        dataset = DglNodePropPredDataset(name=\"ogbn-mag\", root=args.rootdir)\n\n        # - graph: dgl graph object.\n        # - label: torch tensor of shape (num_nodes, num_tasks).\n        g, labels = dataset[0]\n\n        # Flatten the labels for \"paper\" type nodes. This step reduces the\n        # dimensionality of the labels. We need to flatten the labels because\n        # the model requires a 1-dimensional label tensor.\n        labels = labels[\"paper\"].flatten().long()\n\n        # Apply transformation to the graph.\n        # - \"ToSimple()\" removes multi-edge between two nodes.\n        # - \"AddReverse()\" adds reverse edges to the graph.\n        print(\"Start to transform graph. This may take a while...\")\n        transform = Compose([ToSimple(), AddReverse()])\n        g = transform(g)\n    else:\n        dataset = MAG240MDataset(root=args.rootdir)\n        (g,), _ = dgl.load_graphs(args.graph_path)\n        g = g.formats([\"csc\"])\n        labels = torch.as_tensor(dataset.paper_label).long()\n        # As feature data is too large to fit in memory, we read it from disk.\n        feats[\"paper\"] = torch.as_tensor(\n            np.load(args.paper_feature_path, mmap_mode=\"r+\")\n        )\n        feats[\"author\"] = torch.as_tensor(\n            np.load(args.author_feature_path, mmap_mode=\"r+\")\n        )\n        feats[\"institution\"] = torch.as_tensor(\n            np.load(args.inst_feature_path, mmap_mode=\"r+\")\n        )\n    print(f\"Loaded graph: {g}\")\n\n    # Get train/valid/test index.\n    split_idx = dataset.get_idx_split()\n    if args.dataset == \"ogb-lsc-mag240m\":\n        split_idx = {\n            split_type: {\"paper\": split_idx[split_type]}\n            for split_type in split_idx\n        }\n\n    # Initialize a train sampler that samples 
neighbors for multi-layer graph\n    # convolution. It samples 25 and 10 neighbors for the first and second\n    # layers respectively.\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 10], fused=False)\n    num_workers = args.num_workers\n    train_loader = dgl.dataloading.DataLoader(\n        g,\n        split_idx[\"train\"],\n        sampler,\n        batch_size=1024,\n        shuffle=True,\n        num_workers=num_workers,\n        device=device,\n    )\n\n    return g, labels, dataset.num_classes, split_idx, train_loader, feats\n\n\ndef extract_embed(node_embed, input_nodes):\n    emb = node_embed(\n        {ntype: input_nodes[ntype] for ntype in input_nodes if ntype != \"paper\"}\n    )\n    return emb\n\n\ndef rel_graph_embed(graph, embed_size):\n    \"\"\"Initialize a heterogenous embedding layer for all node types in the\n    graph, except for the \"paper\" node type.\n\n    The function constructs a dictionary 'node_num', where the keys are node\n    types (ntype) and the values are the number of nodes for each type. 
This\n    dictionary is used to create a HeteroEmbedding instance.\n\n    (HIGHLIGHT)\n    A HeteroEmbedding instance holds separate embedding layers for each node\n    type, each with its own feature space of dimensionality\n    (node_num[ntype], embed_size), where 'node_num[ntype]' is the number of\n    nodes of type 'ntype' and 'embed_size' is the embedding dimension.\n\n    The \"paper\" node type is specifically excluded, possibly because these nodes\n    might already have predefined feature representations, and therefore, do not\n    require an additional embedding layer.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph for which to create the heterogenous embedding layer.\n    embed_size : int\n        The size of the embedding vectors.\n\n    Returns\n    --------\n    HeteroEmbedding\n        A heterogenous embedding layer for all node types in the graph, except\n        for the \"paper\" node type.\n    \"\"\"\n    node_num = {}\n    for ntype in graph.ntypes:\n        # Skip the \"paper\" node type.\n        if ntype == \"paper\":\n            continue\n        node_num[ntype] = graph.num_nodes(ntype)\n    return HeteroEmbedding(node_num, embed_size)\n\n\nclass RelGraphConvLayer(nn.Module):\n    def __init__(\n        self,\n        in_size,\n        out_size,\n        ntypes,\n        relation_names,\n        activation=None,\n        dropout=0.0,\n    ):\n        super(RelGraphConvLayer, self).__init__()\n        self.in_size = in_size\n        self.out_size = out_size\n        self.ntypes = ntypes\n        self.relation_names = relation_names\n        self.activation = activation\n\n        ########################################################################\n        # (HIGHLIGHT) HeteroGraphConv is a graph convolution operator over\n        # heterogeneous graphs. A dictionary is passed where the key is the\n        # relation name and the value is the instance of GraphConv. 
norm=\"right\"\n        # is to divide the aggregated messages by each node’s in-degrees, which\n        # is equivalent to averaging the received messages. weight=False and\n        # bias=False as we will use our own weight matrices defined later.\n        ########################################################################\n        self.conv = dglnn.HeteroGraphConv(\n            {\n                rel: dglnn.GraphConv(\n                    in_size, out_size, norm=\"right\", weight=False, bias=False\n                )\n                for rel in relation_names\n            }\n        )\n\n        # Create a separate Linear layer for each relationship. Each\n        # relationship has its own weights which will be applied to the node\n        # features before performing convolution.\n        self.weight = nn.ModuleDict(\n            {\n                rel_name: nn.Linear(in_size, out_size, bias=False)\n                for rel_name in self.relation_names\n            }\n        )\n\n        # Create a separate Linear layer for each node type.\n        # loop_weights are used to update the output embedding of each target node\n        # based on its own features, thereby allowing the model to refine the node\n        # representations. Note that this does not imply the existence of self-loop\n        # edges in the graph. 
It is similar to residual connection.\n        self.loop_weights = nn.ModuleDict(\n            {\n                ntype: nn.Linear(in_size, out_size, bias=True)\n                for ntype in self.ntypes\n            }\n        )\n\n        self.loop_weights = nn.ModuleDict(\n            {\n                ntype: nn.Linear(in_size, out_size, bias=True)\n                for ntype in self.ntypes\n            }\n        )\n\n        self.dropout = nn.Dropout(dropout)\n        # Initialize parameters of the model.\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for layer in self.weight.values():\n            layer.reset_parameters()\n\n        for layer in self.loop_weights.values():\n            layer.reset_parameters()\n\n    def forward(self, g, inputs):\n        \"\"\"\n        Parameters\n        ----------\n        g : DGLGraph\n            Input graph.\n        inputs : dict[str, torch.Tensor]\n            Node feature for each node type.\n\n        Returns\n        -------\n        dict[str, torch.Tensor]\n            New node features for each node type.\n        \"\"\"\n        # Create a deep copy of the graph g with features saved in local\n        # frames to prevent side effects from modifying the graph.\n        g = g.local_var()\n\n        # Create a dictionary of weights for each relationship. The weights\n        # are retrieved from the Linear layers defined earlier.\n        weight_dict = {\n            rel_name: {\"weight\": self.weight[rel_name].weight.T}\n            for rel_name in self.relation_names\n        }\n\n        # Create a dictionary of node features for the destination nodes in\n        # the graph. We slice the node features according to the number of\n        # destination nodes of each type. This is necessary because when\n        # incorporating the effect of self-loop edges, we perform computations\n        # only on the destination nodes' features. 
By doing so, we ensure the\n        # feature dimensions match and prevent any misuse of incorrect node\n        # features.\n        inputs_dst = {\n            k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()\n        }\n\n        # Apply the convolution operation on the graph. mod_kwargs are\n        # additional arguments for each relation function defined in the\n        # HeteroGraphConv. In this case, it's the weights for each relation.\n        hs = self.conv(g, inputs, mod_kwargs=weight_dict)\n\n        def _apply(ntype, h):\n            # Apply the `loop_weight` to the input node features, effectively\n            # acting as a residual connection. This allows the model to refine\n            # node embeddings based on its current features.\n            h = h + self.loop_weights[ntype](inputs_dst[ntype])\n            if self.activation:\n                h = self.activation(h)\n            return self.dropout(h)\n\n        # Apply the function defined above for each node type. This will update\n        # the node features using the `loop_weights`, apply the activation\n        # function and dropout.\n        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}\n\n\nclass EntityClassify(nn.Module):\n    def __init__(self, g, in_size, out_size):\n        super(EntityClassify, self).__init__()\n        self.in_size = in_size\n        self.hidden_size = 64\n        self.out_size = out_size\n\n        # Generate and sort a list of unique edge types from the input graph.\n        # eg. ['writes', 'cites']\n        self.relation_names = list(set(g.etypes))\n        self.relation_names.sort()\n        self.dropout = 0.5\n\n        self.layers = nn.ModuleList()\n\n        # First layer: transform input features to hidden features. 
Use ReLU\n        # as the activation function and apply dropout for regularization.\n        self.layers.append(\n            RelGraphConvLayer(\n                self.in_size,\n                self.hidden_size,\n                g.ntypes,\n                self.relation_names,\n                activation=F.relu,\n                dropout=self.dropout,\n            )\n        )\n\n        # Second layer: transform hidden features to output features. No\n        # activation function is applied at this stage.\n        self.layers.append(\n            RelGraphConvLayer(\n                self.hidden_size,\n                self.out_size,\n                g.ntypes,\n                self.relation_names,\n                activation=None,\n            )\n        )\n\n    def reset_parameters(self):\n        # Reset the parameters of each layer.\n        for layer in self.layers:\n            layer.reset_parameters()\n\n    def forward(self, h, blocks):\n        for layer, block in zip(self.layers, blocks):\n            h = layer(block, h)\n        return h\n\n\ndef extract_node_features(name, g, input_nodes, node_embed, feats, device):\n    \"\"\"Extract the node features from embedding layer or raw features.\"\"\"\n    if name == \"ogbn-mag\":\n        # Extract node embeddings for the input nodes.\n        node_features = extract_embed(node_embed, input_nodes)\n        # Add the batch's raw \"paper\" features. 
Corresponds to the content\n        # in the function `rel_graph_embed` comment.\n        node_features.update(\n            {\"paper\": g.ndata[\"feat\"][\"paper\"][input_nodes[\"paper\"].cpu()]}\n        )\n        node_features = {k: e.to(device) for k, e in node_features.items()}\n    else:\n        node_features = {\n            ntype: feats[ntype][input_nodes[ntype].cpu()].to(device)\n            for ntype in input_nodes\n        }\n        # Original feature data are stored in float16 while model weights are\n        # float32, so we need to convert the features to float32.\n        # [TODO] Enable mixed precision training on GPU.\n        node_features = {k: v.float() for k, v in node_features.items()}\n    return node_features\n\n\ndef train(\n    dataset,\n    g,\n    feats,\n    model,\n    node_embed,\n    optimizer,\n    train_loader,\n    split_idx,\n    labels,\n    device,\n):\n    print(\"Start training...\")\n    category = \"paper\"\n\n    # Typically, the best Validation performance is obtained after\n    # the 1st or 2nd epoch. 
This is why the max epoch is set to 3.\n    for epoch in range(3):\n        num_train = split_idx[\"train\"][category].shape[0]\n        t0 = time.time()\n        model.train()\n\n        total_loss = 0\n\n        for input_nodes, seeds, blocks in tqdm(\n            train_loader, desc=f\"Epoch {epoch:02d}\"\n        ):\n            # Move the input data onto the device.\n            blocks = [blk.to(device) for blk in blocks]\n            # We only predict the nodes with type \"category\".\n            seeds = seeds[category]\n            batch_size = seeds.shape[0]\n\n            # Extract the node features from embedding layer or raw features.\n            node_features = extract_node_features(\n                dataset, g, input_nodes, node_embed, feats, device\n            )\n            lbl = labels[seeds.cpu()].to(device)\n\n            # Reset gradients.\n            optimizer.zero_grad()\n            # Generate predictions.\n            logits = model(node_features, blocks)[category]\n\n            y_hat = logits.log_softmax(dim=-1)\n            loss = F.nll_loss(y_hat, lbl)\n            loss.backward()\n            optimizer.step()\n\n            total_loss += loss.item() * batch_size\n\n        t1 = time.time()\n        loss = total_loss / num_train\n\n        # Evaluate the model on the val/test set.\n        valid_acc = evaluate(\n            dataset,\n            g,\n            feats,\n            model,\n            node_embed,\n            labels,\n            device,\n            split_idx[\"valid\"],\n        )\n        test_key = \"test\" if dataset == \"ogbn-mag\" else \"test-dev\"\n        test_acc = evaluate(\n            dataset,\n            g,\n            feats,\n            model,\n            node_embed,\n            labels,\n            device,\n            split_idx[test_key],\n            save_test_submission=(dataset == \"ogb-lsc-mag240m\"),\n        )\n        print(\n            f\"Epoch: {epoch +1 :02d}, \"\n            f\"Loss: 
{loss:.4f}, \"\n            f\"Valid: {100 * valid_acc:.2f}%, \"\n            f\"Test: {100 * test_acc:.2f}%, \"\n            f\"Time {t1 - t0:.4f}\"\n        )\n\n\n@torch.no_grad()\ndef evaluate(\n    dataset,\n    g,\n    feats,\n    model,\n    node_embed,\n    labels,\n    device,\n    idx,\n    save_test_submission=False,\n):\n    # Switches the model to evaluation mode.\n    model.eval()\n    category = \"paper\"\n    if dataset == \"ogbn-mag\":\n        evaluator = Evaluator(name=\"ogbn-mag\")\n    else:\n        evaluator = MAG240MEvaluator()\n\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 10], fused=False)\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        idx,\n        sampler,\n        batch_size=4096,\n        shuffle=False,\n        num_workers=0,\n        device=device,\n    )\n\n    # To store the predictions.\n    y_hats = list()\n    y_true = list()\n\n    for input_nodes, seeds, blocks in tqdm(dataloader, desc=\"Inference\"):\n        blocks = [blk.to(device) for blk in blocks]\n        # We only predict the nodes with type \"category\".\n        node_features = extract_node_features(\n            dataset, g, input_nodes, node_embed, feats, device\n        )\n\n        # Generate predictions.\n        logits = model(node_features, blocks)[category]\n        # Apply softmax to the logits and get the prediction by selecting the\n        # argmax.\n        y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True)\n        y_hats.append(y_hat.cpu())\n        y_true.append(labels[seeds[\"paper\"].cpu()])\n\n    y_pred = torch.cat(y_hats, dim=0)\n    y_true = torch.cat(y_true, dim=0)\n    y_true = torch.unsqueeze(y_true, 1)\n\n    if dataset == \"ogb-lsc-mag240m\":\n        y_pred = y_pred.view(-1)\n        y_true = y_true.view(-1)\n\n    if save_test_submission:\n        evaluator.save_test_submission(\n            input_dict={\"y_pred\": y_pred}, dir_path=\".\", mode=\"test-dev\"\n        )\n    return 
evaluator.eval({\"y_true\": y_true, \"y_pred\": y_pred})[\"acc\"]\n\n\ndef main(args):\n    device = (\n        \"cuda:0\" if torch.cuda.is_available() and args.num_gpus > 0 else \"cpu\"\n    )\n\n    # Prepare the data.\n    g, labels, num_classes, split_idx, train_loader, feats = prepare_data(\n        args, device\n    )\n\n    feat_size = 128 if args.dataset == \"ogbn-mag\" else 768\n\n    # Create the embedding layer and move it to the appropriate device.\n    embed_layer = None\n    if args.dataset == \"ogbn-mag\":\n        embed_layer = rel_graph_embed(g, feat_size).to(device)\n        print(\n            \"Number of embedding parameters: \"\n            f\"{sum(p.numel() for p in embed_layer.parameters())}\"\n        )\n\n    # Initialize the entity classification model.\n    model = EntityClassify(g, feat_size, num_classes).to(device)\n\n    print(\n        \"Number of model parameters: \"\n        f\"{sum(p.numel() for p in model.parameters())}\"\n    )\n\n    try:\n        if embed_layer is not None:\n            embed_layer.reset_parameters()\n        model.reset_parameters()\n    except:\n        # Old pytorch version doesn't support reset_parameters() API.\n        ##################################################################\n        # [Why we need to reset the parameters?]\n        # If parameters are not reset, the model will start with the\n        # parameters learned from the last run, potentially resulting\n        # in biased outcomes or sub-optimal performance if the model was\n        # previously stuck in a poor local minimum.\n        ##################################################################\n        pass\n\n    # `itertools.chain()` is a function in Python's itertools module.\n    # It is used to flatten a list of iterables, making them act as\n    # one big iterable.\n    # In this context, the following code is used to create a single\n    # iterable over the parameters of both the model and the embed_layer,\n    # which 
is passed to the optimizer. The optimizer then updates all\n    # these parameters during the training process.\n    all_params = itertools.chain(\n        model.parameters(),\n        [] if embed_layer is None else embed_layer.parameters(),\n    )\n    optimizer = torch.optim.Adam(all_params, lr=0.01)\n\n    # `expected_max`` is the number of physical cores on your machine.\n    # The `logical` parameter, when set to False, ensures that the count\n    # returned is the number of physical cores instead of logical cores\n    # (which could be higher due to technologies like Hyper-Threading).\n    expected_max = int(psutil.cpu_count(logical=False))\n    if args.num_workers >= expected_max:\n        print(\n            \"[ERROR] You specified num_workers are larger than physical\"\n            f\"cores, please set any number less than {expected_max}\",\n            file=sys.stderr,\n        )\n    train(\n        args.dataset,\n        g,\n        feats,\n        model,\n        embed_layer,\n        optimizer,\n        train_loader,\n        split_idx,\n        labels,\n        device,\n    )\n\n    print(\"Testing...\")\n    test_key = \"test\" if args.dataset == \"ogbn-mag\" else \"test-dev\"\n    test_acc = evaluate(\n        args.dataset,\n        g,\n        feats,\n        model,\n        embed_layer,\n        labels,\n        device,\n        split_idx[test_key],\n        save_test_submission=(args.dataset == \"ogb-lsc-mag240m\"),\n    )\n    print(f\"Test accuracy {test_acc*100:.4f}\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"RGCN\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-mag\",\n        help=\"Dataset for train: ogbn-mag, ogb-lsc-mag240m\",\n    )\n    parser.add_argument(\n        \"--num_gpus\",\n        type=int,\n        default=0,\n        help=\"Number of GPUs. 
Use 0 for CPU training.\",\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=0,\n        help=\"Number of worker processes for data loading.\",\n    )\n    parser.add_argument(\n        \"--rootdir\",\n        type=str,\n        default=\"./dataset/\",\n        help=\"Directory to download the OGB dataset.\",\n    )\n    parser.add_argument(\n        \"--graph_path\",\n        type=str,\n        default=\"./graph.dgl\",\n        help=\"Path to the graph file.\",\n    )\n    parser.add_argument(\n        \"--paper_feature_path\",\n        type=str,\n        default=\"./paper-feat.npy\",\n        help=\"Path to the features of paper nodes.\",\n    )\n    parser.add_argument(\n        \"--author_feature_path\",\n        type=str,\n        default=\"./author-feat.npy\",\n        help=\"Path to the features of author nodes.\",\n    )\n    parser.add_argument(\n        \"--inst_feature_path\",\n        type=str,\n        default=\"./inst-feat.npy\",\n        help=\"Path to the features of institution nodes.\",\n    )\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/distributed/graphsage/README.md",
    "content": "## Distributed training\n\nThis is an example of training GraphSage in a distributed fashion. Before training, please install some python libs by pip:\n\n```\npip3 install ogb\n```\n\n**Requires PyTorch 1.12.0+ to work.**\n\nTo train GraphSage, it has five steps:\n\n### Step 0: Setup a Distributed File System\n* You may skip this step if your cluster already has folder(s) synchronized across machines.\n\nTo perform distributed training, files and codes need to be accessed across multiple machines. A distributed file system would perfectly handle the job (i.e., NFS, Ceph).\n\n#### Server side setup\nHere is an example of how to setup NFS. First, install essential libs on the storage server\n\n```\nsudo apt-get install nfs-kernel-server\n```\n\nBelow we assume the user account is `ubuntu` and we create a directory of `workspace` in the home directory.\n\n```\nmkdir -p /home/ubuntu/workspace\n```\n\nWe assume that the all servers are under a subnet with ip range `192.168.0.0` to `192.168.255.255`. The exports configuration needs to be modifed to\n\n```\nsudo vim /etc/exports\n# add the following line\n/home/ubuntu/workspace  192.168.0.0/16(rw,sync,no_subtree_check)\n```\n\nThe server's internal ip can be checked  via `ifconfig` or `ip`. 
If the ip does not begin with `192.168`, then you may use\n\n```\n/home/ubuntu/workspace  10.0.0.0/8(rw,sync,no_subtree_check)\n/home/ubuntu/workspace  172.16.0.0/12(rw,sync,no_subtree_check)\n```\n\nThen restart NFS, the setup on server side is finished.\n\n```\nsudo systemctl restart nfs-kernel-server\n```\n\nFor configuration details, please refer to [NFS ArchWiki](https://wiki.archlinux.org/index.php/NFS).\n\n#### Client side setup\n\nTo use NFS, clients also need to install essential packages\n\n```\nsudo apt-get install nfs-common\n```\n\nYou can either mount the NFS manually\n\n```\nmkdir -p /home/ubuntu/workspace\nsudo mount -t nfs <nfs-server-ip>:/home/ubuntu/workspace /home/ubuntu/workspace\n```\n\nor edit the fstab so the folder will be mounted automatically\n\n```\n# vim /etc/fstab\n## append the following line to the file\n<nfs-server-ip>:/home/ubuntu/workspace   /home/ubuntu/workspace   nfs   defaults\t0 0\n```\n\nThen run `mount -a`.\n\nNow go to `/home/ubuntu/workspace` and clone the DGL Github repository.\n\n### Step 1: set IP configuration file.\n\nUsers need to set their own IP configuration file `ip_config.txt` before training. For example, if we have four machines in current cluster, the IP configuration\ncould look like this:\n\n```\n172.31.19.1\n172.31.23.205\n172.31.29.175\n172.31.16.98\n```\n\nUsers need to make sure that the master node (node-0) has the right permission to ssh to all the other nodes without password authentication.\n[This link](https://linuxize.com/post/how-to-setup-passwordless-ssh-login/) provides instructions of setting passwordless SSH login.\n\n### Step 2: partition the graph.\n\nThe example provides a script to partition some builtin graphs such as Reddit and OGB product graph.\nIf we want to train GraphSage on 4 machines, we need to partition the graph into 4 parts.\n\nIn this example, we partition the ogbn-products graph into 4 parts with Metis on node-0. 
The partitions are balanced with respect to\nthe number of nodes, the number of edges and the number of labelled nodes.\n\n```\npython3 partition_graph.py --dataset ogbn-products --num_parts 4 --balance_train --balance_edges\n```\n\nThis script generates partitioned graphs and stores them in the directory called `data`.\n\n\n### Step 3: Launch distributed jobs\n\nDGL provides a script to launch the training job in the cluster. `part_config` and `ip_config`\nspecify paths relative to the workspace path.\n\nThe command below launches one process per machine for both sampling and training.\n\n```\npython3 ~/workspace/dgl/tools/launch.py \\\n--workspace ~/workspace/dgl/examples/distributed/graphsage/ \\\n--num_trainers 1 \\\n--num_samplers 0 \\\n--num_servers 1 \\\n--part_config data/ogbn-products.json \\\n--ip_config ip_config.txt \\\n\"python3 node_classification.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 30 --batch_size 1000\"\n```\n\nBy default, this code will run on CPU. 
If you have GPU support, you can just add a `--num_gpus` argument to the user command:\n\n```\npython3 ~/workspace/dgl/tools/launch.py \\\n--workspace ~/workspace/dgl/examples/distributed/graphsage/ \\\n--num_trainers 4 \\\n--num_samplers 0 \\\n--num_servers 1 \\\n--part_config data/ogbn-products.json \\\n--ip_config ip_config.txt \\\n\"python3 node_classification.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 30 --batch_size 1000 --num_gpus 4\"\n```\n\nUnsupervised training (train with link prediction dataloader).\n\n```\npython3 ~/workspace/dgl/tools/launch.py \\\n--workspace ~/workspace/dgl/examples/distributed/graphsage/ \\\n--num_trainers 1 \\\n--num_samplers 0 \\\n--num_servers 1 \\\n--part_config data/ogbn-products.json \\\n--ip_config ip_config.txt \\\n\"python3 node_classification_unsupervised.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 30 --batch_size 1000 --remove_edge\"\n```\n\n### Running with GraphBolt\n\nIn order to run with `GraphBolt`, we need to partition the graph into `GraphBolt` data formats. Please note that both `DGL` and `GraphBolt` partitions are saved together.\n\nIf we have already partitioned into `DGL` format, just convert them directly like below:\n\n```\n    python3 -c \"import dgl; dgl.distributed.dgl_partition_to_graphbolt('ogbn-products.json')\"\n```\n\nOr partition from scratch like this:\n\n```\npython3 partition_graph.py --dataset ogbn-products --num_parts 2 --balance_train --balance_edges --use_graphbolt\n```\n\n#### Partition sizes compared to DGL\n\nCompared to `DGL`, `GraphBolt` partitions are much smaller (reduced to **16%** and **19%** for `ogbn-products` and `ogbn-papers100M` respectively).\n\n`ogbn-products`\n\n| Data Formats |         File Name            | Part 0 | Part 1 |\n| ------------ | ---------------------------- | ------ | ------ |\n| DGL          | graph.dgl                    | 1.5GB  | 1.6GB  |\n| GraphBolt    | fused_csc_sampling_graph.pt  | 255MB  | 265MB  
|\n\n`ogbn-papers100M`\n\n| Data Formats |         File Name            | Part 0 | Part 1 |\n| ------------ | ---------------------------- | ------ | ------ |\n| DGL          | graph.dgl                    | 23GB   | 22GB   |\n| GraphBolt    | fused_csc_sampling_graph.pt  | 4.4GB  | 4.1GB  |\n\nThen run example with `--use_graphbolt`.\n\n```\npython3 ~/workspace/dgl/tools/launch.py \\\n--workspace ~/workspace/dgl/examples/distributed/graphsage/ \\\n--num_trainers 4 \\\n--num_samplers 0 \\\n--num_servers 2 \\\n--part_config data/ogbn-products.json \\\n--ip_config ip_config.txt \\\n\"python3 node_classification.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 10 --use_graphbolt\"\n```\n\n#### Performance compared to `DGL`\n\nCompared to `DGL`, `GraphBolt`'s sampler works faster(reduced to **80%** and **77%** for `ogbn-products` and `ogbn-papers100M` respectively). `Min` and `Max` are statistics of all trainers on all nodes(machines).\n\nAs for RAM usage, the shared memory(measured by **shared** field of `free` command) usage is decreased due to smaller graph partitions in `GraphBolt` though the peak memory used by processes(measured by **used** field of `free` command) does not decrease.\n\n`ogbn-products`\n\n| Data Formats | Sample Time Per Epoch (CPU) |      Test Accuracy (10 epochs)   |  shared | used (peak) |\n| ------------ | --------------------------- | -------------------------------- |  -----  | ---- |\n|     DGL      | Min: 1.2884s, Max: 1.4159s  | Min: 64.38%, Max: 70.42%         |  2.4GB  | 7.8GB|\n|   GraphBolt  | Min: 1.0589s, Max: 1.1400s  | Min: 61.68%, Max: 71.23%         |  1.1GB  | 7.8GB|\n\n\n`ogbn-papers100M`\n\n| Data Formats | Sample Time Per Epoch (CPU) |      Test Accuracy (10 epochs)   |  shared | used (peak) |\n| ------------ | --------------------------- | -------------------------------- |  -----  | ---- |\n|     DGL      | Min: 5.5570s, Max: 6.1900s  | Min: 29.12%, Max: 34.33%         |  84GB   | 43GB |\n|   GraphBolt 
 | Min: 4.5046s, Max: 4.7718s  | Min: 29.11%, Max: 33.49%         |  67GB   | 43GB |\n"
  },
  {
    "path": "examples/distributed/graphsage/node_classification.py",
    "content": "import argparse\nimport socket\nimport time\n\nimport dgl\nimport dgl.distributed\nimport dgl.nn.pytorch as dglnn\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\n\n\nclass DistSAGE(nn.Module):\n    \"\"\"\n    SAGE model for distributed train and evaluation.\n\n    Parameters\n    ----------\n    in_feats : int\n        Feature dimension.\n    n_hidden : int\n        Hidden layer dimension.\n    n_classes : int\n        Number of classes.\n    n_layers : int\n        Number of layers.\n    activation : callable\n        Activation function.\n    dropout : float\n        Dropout value.\n    \"\"\"\n\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        for _ in range(1, n_layers - 1):\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, blocks, x):\n        \"\"\"\n        Forward function.\n\n        Parameters\n        ----------\n        blocks : List[DGLBlock]\n            Sampled blocks.\n        x : DistTensor\n            Feature data.\n        \"\"\"\n        h = x\n        for i, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if i != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, x, batch_size, device):\n        \"\"\"\n        Distributed layer-wise inference with the GraphSAGE model on full\n        
neighbors.\n\n        Parameters\n        ----------\n        g : DistGraph\n            Input Graph for inference.\n        x : DistTensor\n            Node feature data of input graph.\n\n        Returns\n        -------\n        DistTensor\n            Inference results.\n        \"\"\"\n        # Split nodes to each trainer.\n        nodes = dgl.distributed.node_split(\n            np.arange(g.num_nodes()),\n            g.get_partition_book(),\n            force_even=True,\n        )\n\n        for i, layer in enumerate(self.layers):\n            # Create DistTensor to save forward results.\n            if i == len(self.layers) - 1:\n                out_dim = self.n_classes\n                name = \"h_last\"\n            else:\n                out_dim = self.n_hidden\n                name = \"h\"\n            y = dgl.distributed.DistTensor(\n                (g.num_nodes(), out_dim),\n                th.float32,\n                name,\n                persistent=True,\n            )\n            print(f\"|V|={g.num_nodes()}, inference batch size: {batch_size}\")\n\n            # `-1` indicates all inbound edges will be inlcuded, namely, full\n            # neighbor sampling.\n            sampler = dgl.dataloading.NeighborSampler([-1])\n            dataloader = dgl.distributed.DistNodeDataLoader(\n                g,\n                nodes,\n                sampler,\n                batch_size=batch_size,\n                shuffle=False,\n                drop_last=False,\n            )\n\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                block = blocks[0].to(device)\n                h = x[input_nodes].to(device)\n                h_dst = h[: block.number_of_dst_nodes()]\n                h = layer(block, (h, h_dst))\n                if i != len(self.layers) - 1:\n                    h = self.activation(h)\n                    h = self.dropout(h)\n                # Copy back to CPU as DistTensor requires data reside on CPU.\n 
               y[output_nodes] = h.cpu()\n\n            x = y\n            # Synchronize trainers.\n            g.barrier()\n        return x\n\n\ndef compute_acc(pred, labels):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n\n    Parameters\n    ----------\n    pred : torch.Tensor\n        Predicted labels.\n    labels : torch.Tensor\n        Ground-truth labels.\n\n    Returns\n    -------\n    float\n        Accuracy.\n    \"\"\"\n    labels = labels.long()\n    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)\n\n\ndef evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):\n    \"\"\"\n    Evaluate the model on the validation and test set.\n\n    Parameters\n    ----------\n    model : DistSAGE\n        The model to be evaluated.\n    g : DistGraph\n        The entire graph.\n    inputs : DistTensor\n        The feature data of all the nodes.\n    labels : DistTensor\n        The labels of all the nodes.\n    val_nid : torch.Tensor\n        The node IDs for validation.\n    test_nid : torch.Tensor\n        The node IDs for test.\n    batch_size : int\n        Batch size for evaluation.\n    device : torch.Device\n        The target device to evaluate on.\n\n    Returns\n    -------\n    float\n        Validation accuracy.\n    float\n        Test accuracy.\n    \"\"\"\n    model.eval()\n    with th.no_grad():\n        pred = model.inference(g, inputs, batch_size, device)\n    model.train()\n    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(\n        pred[test_nid], labels[test_nid]\n    )\n\n\ndef run(args, device, data):\n    \"\"\"\n    Train and evaluate DistSAGE.\n\n    Parameters\n    ----------\n    args : argparse.Args\n        Arguments for train and evaluate.\n    device : torch.Device\n        Target device for train and evaluate.\n    data : Packed Data\n        Packed data includes train/val/test IDs, feature dimension,\n        number of classes, graph.\n    \"\"\"\n    
train_nid, val_nid, test_nid, in_feats, n_classes, g = data\n    sampler = dgl.dataloading.NeighborSampler(\n        [int(fanout) for fanout in args.fan_out.split(\",\")]\n    )\n    dataloader = dgl.distributed.DistNodeDataLoader(\n        g,\n        train_nid,\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n    )\n    model = DistSAGE(\n        in_feats,\n        args.num_hidden,\n        n_classes,\n        args.num_layers,\n        F.relu,\n        args.dropout,\n    )\n    model = model.to(device)\n    if args.num_gpus == 0:\n        model = th.nn.parallel.DistributedDataParallel(model)\n    else:\n        model = th.nn.parallel.DistributedDataParallel(\n            model, device_ids=[device], output_device=device\n        )\n    loss_fcn = nn.CrossEntropyLoss()\n    loss_fcn = loss_fcn.to(device)\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n    # Training loop.\n    iter_tput = []\n    epoch = 0\n    epoch_time = []\n    test_acc = 0.0\n    for _ in range(args.num_epochs):\n        epoch += 1\n        tic = time.time()\n        # Various time statistics.\n        sample_time = 0\n        forward_time = 0\n        backward_time = 0\n        update_time = 0\n        num_seeds = 0\n        num_inputs = 0\n        start = time.time()\n        step_time = []\n\n        with model.join():\n            for step, (input_nodes, seeds, blocks) in enumerate(dataloader):\n                tic_step = time.time()\n                sample_time += tic_step - start\n                # Slice feature and label.\n                batch_inputs = g.ndata[\"features\"][input_nodes]\n                batch_labels = g.ndata[\"labels\"][seeds].long()\n                num_seeds += len(blocks[-1].dstdata[dgl.NID])\n                num_inputs += len(blocks[0].srcdata[dgl.NID])\n                # Move to target device.\n                blocks = [block.to(device) for block in blocks]\n                batch_inputs = 
batch_inputs.to(device)\n                batch_labels = batch_labels.to(device)\n                # Compute loss and prediction.\n                start = time.time()\n                batch_pred = model(blocks, batch_inputs)\n                loss = loss_fcn(batch_pred, batch_labels)\n                forward_end = time.time()\n                optimizer.zero_grad()\n                loss.backward()\n                compute_end = time.time()\n                forward_time += forward_end - start\n                backward_time += compute_end - forward_end\n\n                optimizer.step()\n                update_time += time.time() - compute_end\n\n                step_t = time.time() - tic_step\n                step_time.append(step_t)\n                iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)\n                if (step + 1) % args.log_every == 0:\n                    acc = compute_acc(batch_pred, batch_labels)\n                    gpu_mem_alloc = (\n                        th.cuda.max_memory_allocated() / 1000000\n                        if th.cuda.is_available()\n                        else 0\n                    )\n                    sample_speed = np.mean(iter_tput[-args.log_every :])\n                    mean_step_time = np.mean(step_time[-args.log_every :])\n                    print(\n                        f\"Part {g.rank()} | Epoch {epoch:05d} | Step {step:05d}\"\n                        f\" | Loss {loss.item():.4f} | Train Acc {acc.item():.4f}\"\n                        f\" | Speed (samples/sec) {sample_speed:.4f}\"\n                        f\" | GPU {gpu_mem_alloc:.1f} MB | \"\n                        f\"Mean step time {mean_step_time:.3f} s\"\n                    )\n                start = time.time()\n\n        toc = time.time()\n        print(\n            f\"Part {g.rank()}, Epoch Time(s): {toc - tic:.4f}, \"\n            f\"sample+data_copy: {sample_time:.4f}, forward: {forward_time:.4f},\"\n            f\" backward: 
{backward_time:.4f}, update: {update_time:.4f}, \"\n            f\"#seeds: {num_seeds}, #inputs: {num_inputs}\"\n        )\n        epoch_time.append(toc - tic)\n\n        if epoch % args.eval_every == 0 or epoch == args.num_epochs:\n            start = time.time()\n            val_acc, test_acc = evaluate(\n                model.module,\n                g,\n                g.ndata[\"features\"],\n                g.ndata[\"labels\"],\n                val_nid,\n                test_nid,\n                args.batch_size_eval,\n                device,\n            )\n            print(\n                f\"Part {g.rank()}, Val Acc {val_acc:.4f}, \"\n                f\"Test Acc {test_acc:.4f}, time: {time.time() - start:.4f}\"\n            )\n\n    return np.mean(epoch_time[-int(args.num_epochs * 0.8) :]), test_acc\n\n\ndef main(args):\n    \"\"\"\n    Main function.\n    \"\"\"\n    host_name = socket.gethostname()\n    print(f\"{host_name}: Initializing DistDGL.\")\n    dgl.distributed.initialize(args.ip_config, use_graphbolt=args.use_graphbolt)\n    print(f\"{host_name}: Initializing PyTorch process group.\")\n    th.distributed.init_process_group(backend=args.backend)\n    print(f\"{host_name}: Initializing DistGraph.\")\n    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)\n    print(f\"Rank of {host_name}: {g.rank()}\")\n\n    # Split train/val/test IDs for each trainer.\n    pb = g.get_partition_book()\n    if \"trainer_id\" in g.ndata:\n        train_nid = dgl.distributed.node_split(\n            g.ndata[\"train_mask\"],\n            pb,\n            force_even=True,\n            node_trainer_ids=g.ndata[\"trainer_id\"],\n        )\n        val_nid = dgl.distributed.node_split(\n            g.ndata[\"val_mask\"],\n            pb,\n            force_even=True,\n            node_trainer_ids=g.ndata[\"trainer_id\"],\n        )\n        test_nid = dgl.distributed.node_split(\n            g.ndata[\"test_mask\"],\n            pb,\n          
  force_even=True,\n            node_trainer_ids=g.ndata[\"trainer_id\"],\n        )\n    else:\n        train_nid = dgl.distributed.node_split(\n            g.ndata[\"train_mask\"], pb, force_even=True\n        )\n        val_nid = dgl.distributed.node_split(\n            g.ndata[\"val_mask\"], pb, force_even=True\n        )\n        test_nid = dgl.distributed.node_split(\n            g.ndata[\"test_mask\"], pb, force_even=True\n        )\n    local_nid = pb.partid2nids(pb.partid).detach().numpy()\n    num_train_local = len(np.intersect1d(train_nid.numpy(), local_nid))\n    num_val_local = len(np.intersect1d(val_nid.numpy(), local_nid))\n    num_test_local = len(np.intersect1d(test_nid.numpy(), local_nid))\n    print(\n        f\"part {g.rank()}, train: {len(train_nid)} (local: {num_train_local}), \"\n        f\"val: {len(val_nid)} (local: {num_val_local}), \"\n        f\"test: {len(test_nid)} (local: {num_test_local})\"\n    )\n    del local_nid\n    if args.num_gpus == 0:\n        device = th.device(\"cpu\")\n    else:\n        dev_id = g.rank() % args.num_gpus\n        device = th.device(\"cuda:\" + str(dev_id))\n    n_classes = args.n_classes\n    if n_classes == 0:\n        labels = g.ndata[\"labels\"][np.arange(g.num_nodes())]\n        n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))\n        del labels\n    print(f\"Number of classes: {n_classes}\")\n\n    # Pack data.\n    in_feats = g.ndata[\"features\"].shape[1]\n    data = train_nid, val_nid, test_nid, in_feats, n_classes, g\n\n    # Train and evaluate.\n    epoch_time, test_acc = run(args, device, data)\n    print(\n        f\"Summary of node classification(GraphSAGE): GraphName \"\n        f\"{args.graph_name} | TrainEpochTime(mean) {epoch_time:.4f} \"\n        f\"| TestAccuracy {test_acc:.4f}\"\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Distributed GraphSAGE.\")\n    parser.add_argument(\"--graph_name\", type=str, help=\"graph 
name\")\n    parser.add_argument(\n        \"--ip_config\", type=str, help=\"The file for IP configuration\"\n    )\n    parser.add_argument(\n        \"--part_config\", type=str, help=\"The path to the partition config file\"\n    )\n    parser.add_argument(\n        \"--n_classes\", type=int, default=0, help=\"the number of classes\"\n    )\n    parser.add_argument(\n        \"--backend\",\n        type=str,\n        default=\"gloo\",\n        help=\"pytorch distributed backend\",\n    )\n    parser.add_argument(\n        \"--num_gpus\",\n        type=int,\n        default=0,\n        help=\"the number of GPU device. Use 0 for CPU training\",\n    )\n    parser.add_argument(\"--num_epochs\", type=int, default=20)\n    parser.add_argument(\"--num_hidden\", type=int, default=16)\n    parser.add_argument(\"--num_layers\", type=int, default=2)\n    parser.add_argument(\"--fan_out\", type=str, default=\"10,25\")\n    parser.add_argument(\"--batch_size\", type=int, default=1000)\n    parser.add_argument(\"--batch_size_eval\", type=int, default=100000)\n    parser.add_argument(\"--log_every\", type=int, default=20)\n    parser.add_argument(\"--eval_every\", type=int, default=5)\n    parser.add_argument(\"--lr\", type=float, default=0.003)\n    parser.add_argument(\"--dropout\", type=float, default=0.5)\n    parser.add_argument(\n        \"--local_rank\", type=int, help=\"get rank of the process\"\n    )\n    parser.add_argument(\n        \"--pad-data\",\n        default=False,\n        action=\"store_true\",\n        help=\"Pad train nid to the same length across machine, to ensure num \"\n        \"of batches to be the same.\",\n    )\n    parser.add_argument(\n        \"--use_graphbolt\",\n        action=\"store_true\",\n        help=\"Use GraphBolt for distributed train.\",\n    )\n    args = parser.parse_args()\n    print(f\"Arguments: {args}\")\n    main(args)\n"
  },
  {
    "path": "examples/distributed/graphsage/node_classification_unsupervised.py",
    "content": "import argparse\nimport time\nfrom contextlib import contextmanager\n\nimport dgl\nimport dgl.distributed\nimport dgl.function as fn\nimport dgl.nn.pytorch as dglnn\n\nimport numpy as np\nimport sklearn.linear_model as lm\nimport sklearn.metrics as skm\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\n\n\nclass DistSAGE(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        for i in range(1, n_layers - 1):\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, blocks, x):\n        h = x\n        for i, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if i != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, x, batch_size, device):\n        \"\"\"\n        Inference with the GraphSAGE model on full neighbors (i.e. without\n        neighbor sampling).\n\n        g : the entire graph.\n        x : the input of entire node set.\n\n        The inference code is written in a fashion that it could handle any\n        number of nodes and layers.\n        \"\"\"\n        # During inference with sampling, multi-layer blocks are very\n        # inefficient because lots of computations in the first few layers are\n        # repeated. Therefore, we compute the representation of all nodes layer\n        # by layer.  
The nodes on each layer are of course splitted in batches.\n        # TODO: can we standardize this?\n        nodes = dgl.distributed.node_split(\n            np.arange(g.num_nodes()),\n            g.get_partition_book(),\n            force_even=True,\n        )\n        y = dgl.distributed.DistTensor(\n            (g.num_nodes(), self.n_hidden),\n            th.float32,\n            \"h\",\n            persistent=True,\n        )\n        for i, layer in enumerate(self.layers):\n            if i == len(self.layers) - 1:\n                y = dgl.distributed.DistTensor(\n                    (g.num_nodes(), self.n_classes),\n                    th.float32,\n                    \"h_last\",\n                    persistent=True,\n                )\n            # Create sampler\n            sampler = dgl.dataloading.NeighborSampler([-1])\n            # Create dataloader\n            dataloader = dgl.distributed.DistNodeDataLoader(\n                g,\n                nodes,\n                sampler,\n                batch_size=batch_size,\n                shuffle=False,\n                drop_last=False,\n            )\n\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                block = blocks[0].to(device)\n                h = x[input_nodes].to(device)\n                h_dst = h[: block.number_of_dst_nodes()]\n                h = layer(block, (h, h_dst))\n                if i != len(self.layers) - 1:\n                    h = self.activation(h)\n                    h = self.dropout(h)\n\n                y[output_nodes] = h.cpu()\n\n            x = y\n            g.barrier()\n        return y\n\n    @contextmanager\n    def join(self):\n        \"\"\"dummy join for standalone\"\"\"\n        yield\n\n\ndef load_subtensor(g, input_nodes, device):\n    \"\"\"\n    Copys features and labels of a set of nodes onto GPU.\n    \"\"\"\n    batch_inputs = g.ndata[\"features\"][input_nodes].to(device)\n    return batch_inputs\n\n\nclass 
CrossEntropyLoss(nn.Module):\n    def forward(self, block_outputs, pos_graph, neg_graph):\n        with pos_graph.local_scope():\n            pos_graph.ndata[\"h\"] = block_outputs\n            pos_graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"score\"))\n            pos_score = pos_graph.edata[\"score\"]\n        with neg_graph.local_scope():\n            neg_graph.ndata[\"h\"] = block_outputs\n            neg_graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"score\"))\n            neg_score = neg_graph.edata[\"score\"]\n\n        score = th.cat([pos_score, neg_score])\n        label = th.cat(\n            [th.ones_like(pos_score), th.zeros_like(neg_score)]\n        ).long()\n        loss = F.binary_cross_entropy_with_logits(score, label.float())\n        return loss\n\n\ndef generate_emb(model, g, inputs, batch_size, device):\n    \"\"\"\n    Generate embeddings for each node\n    g : The entire graph.\n    inputs : The features of all the nodes.\n    batch_size : Number of nodes to compute at the same time.\n    device : The GPU device to evaluate on.\n    \"\"\"\n    model.eval()\n    with th.no_grad():\n        pred = model.inference(g, inputs, batch_size, device)\n\n    return pred\n\n\ndef compute_acc(emb, labels, train_nids, val_nids, test_nids):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n\n    We will fist train a LogisticRegression model using the trained embeddings,\n    the training set, validation set and test set is provided as the arguments.\n\n    The final result is predicted by the lr model.\n\n    emb: The pretrained embeddings\n    labels: The ground truth\n    train_nids: The training set node ids\n    val_nids: The validation set node ids\n    test_nids: The test set node ids\n    \"\"\"\n\n    emb = emb[np.arange(labels.shape[0])].cpu().numpy()\n    train_nids = train_nids.cpu().numpy()\n    val_nids = val_nids.cpu().numpy()\n    test_nids = test_nids.cpu().numpy()\n    labels = labels.cpu().numpy()\n\n    emb = (emb - 
emb.mean(0, keepdims=True)) / emb.std(0, keepdims=True)\n    lr = lm.LogisticRegression(multi_class=\"multinomial\", max_iter=10000)\n    lr.fit(emb[train_nids], labels[train_nids])\n\n    pred = lr.predict(emb)\n    eval_acc = skm.accuracy_score(labels[val_nids], pred[val_nids])\n    test_acc = skm.accuracy_score(labels[test_nids], pred[test_nids])\n    return eval_acc, test_acc\n\n\ndef run(args, device, data):\n    # Unpack data\n    (\n        train_eids,\n        train_nids,\n        in_feats,\n        g,\n        global_train_nid,\n        global_valid_nid,\n        global_test_nid,\n        labels,\n    ) = data\n    # Create sampler\n    neg_sampler = dgl.dataloading.negative_sampler.Uniform(args.num_negs)\n    sampler = dgl.dataloading.NeighborSampler(\n        [int(fanout) for fanout in args.fan_out.split(\",\")]\n    )\n    # Create dataloader\n    exclude = \"reverse_id\" if args.remove_edge else None\n    reverse_eids = th.arange(g.num_edges()) if args.remove_edge else None\n    dataloader = dgl.distributed.DistEdgeDataLoader(\n        g,\n        train_eids,\n        sampler,\n        negative_sampler=neg_sampler,\n        exclude=exclude,\n        reverse_eids=reverse_eids,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n    )\n    # Define model and optimizer\n    model = DistSAGE(\n        in_feats,\n        args.num_hidden,\n        args.num_hidden,\n        args.num_layers,\n        F.relu,\n        args.dropout,\n    )\n    model = model.to(device)\n    if not args.standalone:\n        if args.num_gpus == -1:\n            model = th.nn.parallel.DistributedDataParallel(model)\n        else:\n            dev_id = g.rank() % args.num_gpus\n            model = th.nn.parallel.DistributedDataParallel(\n                model, device_ids=[dev_id], output_device=dev_id\n            )\n    loss_fcn = CrossEntropyLoss()\n    loss_fcn = loss_fcn.to(device)\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n 
   # Training loop\n    epoch = 0\n    for epoch in range(args.num_epochs):\n        num_seeds = 0\n        num_inputs = 0\n\n        step_time = []\n        sample_t = []\n        feat_copy_t = []\n        forward_t = []\n        backward_t = []\n        update_t = []\n        iter_tput = []\n\n        start = time.time()\n        with model.join():\n            # Loop over the dataloader to sample the computation dependency\n            # graph as a list of blocks.\n            for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(\n                dataloader\n            ):\n                if args.debug:\n                    # Verify exclude_edges functionality.\n                    for block in blocks:\n                        current_eids = block.edata[dgl.EID]\n                        seed_eids = pos_graph.edata[dgl.EID]\n                        if exclude is None:\n                            assert th.any(th.isin(current_eids, seed_eids))\n                        elif exclude == \"self\":\n                            assert not th.any(th.isin(current_eids, seed_eids))\n                        elif exclude == \"reverse_id\":\n                            assert not th.any(th.isin(current_eids, seed_eids))\n                        else:\n                            raise ValueError(\n                                f\"Unsupported exclude type: {exclude}\"\n                            )\n                tic_step = time.time()\n                sample_t.append(tic_step - start)\n\n                copy_t = time.time()\n                pos_graph = pos_graph.to(device)\n                neg_graph = neg_graph.to(device)\n                blocks = [block.to(device) for block in blocks]\n                batch_inputs = load_subtensor(g, input_nodes, device)\n                copy_time = time.time()\n                feat_copy_t.append(copy_time - copy_t)\n\n                # Compute loss and prediction\n                batch_pred = model(blocks, 
batch_inputs)\n                loss = loss_fcn(batch_pred, pos_graph, neg_graph)\n                forward_end = time.time()\n                optimizer.zero_grad()\n                loss.backward()\n                compute_end = time.time()\n                forward_t.append(forward_end - copy_time)\n                backward_t.append(compute_end - forward_end)\n\n                # Aggregate gradients in multiple nodes.\n                optimizer.step()\n                update_t.append(time.time() - compute_end)\n\n                pos_edges = pos_graph.num_edges()\n\n                step_t = time.time() - start\n                step_time.append(step_t)\n                iter_tput.append(pos_edges / step_t)\n                num_seeds += pos_edges\n                if step % args.log_every == 0:\n                    print(\n                        \"[{}] Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed \"\n                        \"(samples/sec) {:.4f} | time {:.3f}s | sample {:.3f} | \"\n                        \"copy {:.3f} | forward {:.3f} | backward {:.3f} | \"\n                        \"update {:.3f}\".format(\n                            g.rank(),\n                            epoch,\n                            step,\n                            loss.item(),\n                            np.mean(iter_tput[3:]),\n                            np.sum(step_time[-args.log_every :]),\n                            np.sum(sample_t[-args.log_every :]),\n                            np.sum(feat_copy_t[-args.log_every :]),\n                            np.sum(forward_t[-args.log_every :]),\n                            np.sum(backward_t[-args.log_every :]),\n                            np.sum(update_t[-args.log_every :]),\n                        )\n                    )\n                start = time.time()\n\n        print(\n            \"[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, \"\n            \"forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, 
\"\n            \"#inputs: {}\".format(\n                g.rank(),\n                np.sum(step_time),\n                np.sum(sample_t),\n                np.sum(feat_copy_t),\n                np.sum(forward_t),\n                np.sum(backward_t),\n                np.sum(update_t),\n                num_seeds,\n                num_inputs,\n            )\n        )\n        epoch += 1\n\n    # evaluate the embedding using LogisticRegression\n    pred = generate_emb(\n        model if args.standalone else model.module,\n        g,\n        g.ndata[\"features\"],\n        args.batch_size_eval,\n        device,\n    )\n    if g.rank() == 0:\n        eval_acc, test_acc = compute_acc(\n            pred, labels, global_train_nid, global_valid_nid, global_test_nid\n        )\n        print(\"eval acc {:.4f}; test acc {:.4f}\".format(eval_acc, test_acc))\n\n    # sync for eval and test\n    if not args.standalone:\n        th.distributed.barrier()\n\n    if not args.standalone:\n        g._client.barrier()\n\n        # save features into file\n        if g.rank() == 0:\n            th.save(pred, \"emb.pt\")\n    else:\n        th.save(pred, \"emb.pt\")\n\n\ndef main(args):\n    print(\"--- Distributed node classification with GraphSAGE unsuperised ---\")\n    dgl.distributed.initialize(args.ip_config)\n    if not args.standalone:\n        th.distributed.init_process_group(backend=\"gloo\")\n    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)\n    print(\"rank:\", g.rank())\n    print(\"number of edges\", g.num_edges())\n\n    train_eids = dgl.distributed.edge_split(\n        th.ones((g.num_edges(),), dtype=th.bool),\n        g.get_partition_book(),\n        force_even=True,\n    )\n    train_nids = dgl.distributed.node_split(\n        th.ones((g.num_nodes(),), dtype=th.bool), g.get_partition_book()\n    )\n    global_train_nid = th.LongTensor(\n        np.nonzero(g.ndata[\"train_mask\"][np.arange(g.num_nodes())])\n    )\n    global_valid_nid = 
th.LongTensor(\n        np.nonzero(g.ndata[\"val_mask\"][np.arange(g.num_nodes())])\n    )\n    global_test_nid = th.LongTensor(\n        np.nonzero(g.ndata[\"test_mask\"][np.arange(g.num_nodes())])\n    )\n    labels = g.ndata[\"labels\"][np.arange(g.num_nodes())]\n    if args.num_gpus == -1:\n        device = th.device(\"cpu\")\n    else:\n        dev_id = g.rank() % args.num_gpus\n        device = th.device(\"cuda:\" + str(dev_id))\n\n    # Pack data\n    in_feats = g.ndata[\"features\"].shape[1]\n    global_train_nid = global_train_nid.squeeze()\n    global_valid_nid = global_valid_nid.squeeze()\n    global_test_nid = global_test_nid.squeeze()\n    print(\"number of train {}\".format(global_train_nid.shape[0]))\n    print(\"number of valid {}\".format(global_valid_nid.shape[0]))\n    print(\"number of test {}\".format(global_test_nid.shape[0]))\n    data = (\n        train_eids,\n        train_nids,\n        in_feats,\n        g,\n        global_train_nid,\n        global_valid_nid,\n        global_test_nid,\n        labels,\n    )\n    run(args, device, data)\n    print(\"parent ends\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GCN\")\n    parser.add_argument(\"--graph_name\", type=str, help=\"graph name\")\n    parser.add_argument(\"--id\", type=int, help=\"the partition id\")\n    parser.add_argument(\n        \"--ip_config\", type=str, help=\"The file for IP configuration\"\n    )\n    parser.add_argument(\n        \"--part_config\", type=str, help=\"The path to the partition config file\"\n    )\n    parser.add_argument(\"--n_classes\", type=int, help=\"the number of classes\")\n    parser.add_argument(\n        \"--num_gpus\",\n        type=int,\n        default=-1,\n        help=\"the number of GPU device. 
Use -1 for CPU training\",\n    )\n    parser.add_argument(\"--num_epochs\", type=int, default=20)\n    parser.add_argument(\"--num_hidden\", type=int, default=16)\n    parser.add_argument(\"--num-layers\", type=int, default=2)\n    parser.add_argument(\"--fan_out\", type=str, default=\"10,25\")\n    parser.add_argument(\"--batch_size\", type=int, default=1000)\n    parser.add_argument(\"--batch_size_eval\", type=int, default=100000)\n    parser.add_argument(\"--log_every\", type=int, default=20)\n    parser.add_argument(\"--eval_every\", type=int, default=5)\n    parser.add_argument(\"--lr\", type=float, default=0.003)\n    parser.add_argument(\"--dropout\", type=float, default=0.5)\n    parser.add_argument(\n        \"--local_rank\", type=int, help=\"get rank of the process\"\n    )\n    parser.add_argument(\n        \"--standalone\", action=\"store_true\", help=\"run in the standalone mode\"\n    )\n    parser.add_argument(\"--num_negs\", type=int, default=1)\n    parser.add_argument(\n        \"--remove_edge\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether to remove edges during sampling\",\n    )\n    parser.add_argument(\n        \"--debug\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether to verify functionality of remove edges\",\n    )\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/distributed/graphsage/partition_graph.py",
    "content": "import argparse\nimport time\n\nimport dgl\nimport torch as th\nfrom dgl.data import RedditDataset\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\ndef load_reddit(self_loop=True):\n    \"\"\"Load reddit dataset.\"\"\"\n    data = RedditDataset(self_loop=self_loop)\n    g = data[0]\n    g.ndata[\"features\"] = g.ndata.pop(\"feat\")\n    g.ndata[\"labels\"] = g.ndata.pop(\"label\")\n    return g, data.num_classes\n\n\ndef load_ogb(name, root=\"dataset\"):\n    \"\"\"Load ogbn dataset.\"\"\"\n    data = DglNodePropPredDataset(name=name, root=root)\n    splitted_idx = data.get_idx_split()\n    graph, labels = data[0]\n    labels = labels[:, 0]\n\n    graph.ndata[\"features\"] = graph.ndata.pop(\"feat\")\n    graph.ndata[\"labels\"] = labels\n    num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))\n\n    # Find the node IDs in the training, validation, and test set.\n    train_nid, val_nid, test_nid = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    train_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)\n    train_mask[train_nid] = True\n    val_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)\n    val_mask[val_nid] = True\n    test_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)\n    test_mask[test_nid] = True\n    graph.ndata[\"train_mask\"] = train_mask\n    graph.ndata[\"val_mask\"] = val_mask\n    graph.ndata[\"test_mask\"] = test_mask\n    return graph, num_labels\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"Partition graph\")\n    argparser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"reddit\",\n        help=\"datasets: reddit, ogbn-products, ogbn-papers100M\",\n    )\n    argparser.add_argument(\n        \"--num_parts\", type=int, default=4, help=\"number of partitions\"\n    )\n    argparser.add_argument(\n        \"--part_method\", type=str, default=\"metis\", help=\"the 
partition method\"\n    )\n    argparser.add_argument(\n        \"--balance_train\",\n        action=\"store_true\",\n        help=\"balance the training size in each partition.\",\n    )\n    argparser.add_argument(\n        \"--undirected\",\n        action=\"store_true\",\n        help=\"turn the graph into an undirected graph.\",\n    )\n    argparser.add_argument(\n        \"--balance_edges\",\n        action=\"store_true\",\n        help=\"balance the number of edges in each partition.\",\n    )\n    argparser.add_argument(\n        \"--num_trainers_per_machine\",\n        type=int,\n        default=1,\n        help=\"the number of trainers per machine. The trainer ids are stored\\\n                                in the node feature 'trainer_id'\",\n    )\n    argparser.add_argument(\n        \"--output\",\n        type=str,\n        default=\"data\",\n        help=\"Output path of partitioned graph.\",\n    )\n    argparser.add_argument(\n        \"--use_graphbolt\",\n        action=\"store_true\",\n        help=\"Use GraphBolt for distributed train.\",\n    )\n    args = argparser.parse_args()\n\n    start = time.time()\n    if args.dataset == \"reddit\":\n        g, _ = load_reddit()\n    elif args.dataset in [\"ogbn-products\", \"ogbn-papers100M\"]:\n        g, _ = load_ogb(args.dataset)\n    else:\n        raise RuntimeError(f\"Unknown dataset: {args.dataset}\")\n    print(\n        \"Load {} takes {:.3f} seconds\".format(args.dataset, time.time() - start)\n    )\n    print(\"|V|={}, |E|={}\".format(g.num_nodes(), g.num_edges()))\n    print(\n        \"train: {}, valid: {}, test: {}\".format(\n            th.sum(g.ndata[\"train_mask\"]),\n            th.sum(g.ndata[\"val_mask\"]),\n            th.sum(g.ndata[\"test_mask\"]),\n        )\n    )\n    if args.balance_train:\n        balance_ntypes = g.ndata[\"train_mask\"]\n    else:\n        balance_ntypes = None\n\n    if args.undirected:\n        sym_g = dgl.to_bidirected(g, readonly=True)\n        for 
key in g.ndata:\n            sym_g.ndata[key] = g.ndata[key]\n        g = sym_g\n\n    dgl.distributed.partition_graph(\n        g,\n        args.dataset,\n        args.num_parts,\n        args.output,\n        part_method=args.part_method,\n        balance_ntypes=balance_ntypes,\n        balance_edges=args.balance_edges,\n        num_trainers_per_machine=args.num_trainers_per_machine,\n        use_graphbolt=args.use_graphbolt,\n    )\n"
  },
  {
    "path": "examples/distributed/rgcn/README.md",
    "content": "## Distributed training\n\nThis is an example of training RGCN node classification in a distributed fashion. Currently, the example train RGCN graphs with input node features.\n\nBefore training, install python libs by pip:\n\n```bash\npip3 install ogb pyarrow\n```\n\nTo train RGCN, it has four steps:\n\n### Step 0: Setup a Distributed File System\n* You may skip this step if your cluster already has folder(s) synchronized across machines.\n\nTo perform distributed training, files and codes need to be accessed across multiple machines. A distributed file system would perfectly handle the job (i.e., NFS, Ceph).\n\n#### Server side setup\nHere is an example of how to setup NFS. First, install essential libs on the storage server\n```bash\nsudo apt-get install nfs-kernel-server\n```\n\nBelow we assume the user account is `ubuntu` and we create a directory of `workspace` in the home directory.\n```bash\nmkdir -p /home/ubuntu/workspace\n```\n\nWe assume that the all servers are under a subnet with ip range `192.168.0.0` to `192.168.255.255`. The exports configuration needs to be modifed to\n\n```bash\nsudo vim /etc/exports\n# add the following line\n/home/ubuntu/workspace  192.168.0.0/16(rw,sync,no_subtree_check)\n```\n\nThe server's internal ip can be checked  via `ifconfig` or `ip`. 
If the ip does not begin with `192.168`, then you may use\n```bash\n# for ip range 10.0.0.0 - 10.255.255.255\n/home/ubuntu/workspace  10.0.0.0/8(rw,sync,no_subtree_check)\n# for ip range 172.16.0.0 - 172.31.255.255\n/home/ubuntu/workspace  172.16.0.0/12(rw,sync,no_subtree_check)\n```\n\nThen restart NFS, and the setup on the server side is finished.\n\n```\nsudo systemctl restart nfs-kernel-server\n```\n\nFor configuration details, please refer to [NFS ArchWiki](https://wiki.archlinux.org/index.php/NFS).\n\n\n#### Client side setup\n\nTo use NFS, clients also need to install essential packages\n\n```\nsudo apt-get install nfs-common\n```\n\nYou can either mount the NFS manually\n\n```\nmkdir -p /home/ubuntu/workspace\nsudo mount -t nfs <nfs-server-ip>:/home/ubuntu/workspace /home/ubuntu/workspace\n```\n\nor edit the fstab so the folder will be mounted automatically\n\n```\n# vim /etc/fstab\n## append the following line to the file\n<nfs-server-ip>:/home/ubuntu/workspace   /home/ubuntu/workspace   nfs   defaults\t0 0\n```\n\nThen run `mount -a`.\n\nNow go to `/home/ubuntu/workspace` and clone the DGL GitHub repository.\n\n### Step 1: set IP configuration file.\n\nUsers need to set their own IP configuration file `ip_config.txt` before training. For example, if we have two machines in the current cluster, the IP configuration could look like this:\n\n```bash\n172.31.0.1\n172.31.0.2\n```\n\nUsers need to make sure that the master node (node-0) has the right permission to ssh to all the other nodes without password authentication.\n[This link](https://linuxize.com/post/how-to-setup-passwordless-ssh-login/) provides instructions for setting up passwordless SSH login.\n\n### Step 2: partition the graph.\n\nThe example provides a script to partition some builtin graphs such as the ogbn-mag graph.\nIf we want to train RGCN on 2 machines, we need to partition the graph into 2 parts.\n\nIn this example, we partition the ogbn-mag graph into 2 parts with Metis. 
The partitions are balanced with respect to the number of nodes, the number of edges and the number of labelled nodes.\n\n```bash\npython3 partition_graph.py --dataset ogbn-mag --num_parts 2 --balance_train --balance_edges\n```\n\nIf we want to train RGCN with `GraphBolt`, we need to append `--use_graphbolt` to generate partitions in `GraphBolt` format.\n\n```bash\npython3 partition_graph.py --dataset ogbn-mag --num_parts 2 --balance_train --balance_edges --use_graphbolt\n```\n\nIf we have already partitioned into `DGL` format, just convert them directly like below:\n\n```\n    python3 -c \"import dgl; dgl.distributed.dgl_partition_to_graphbolt('ogbn-products.json')\"\n```\n\n\n### Step 3: Launch distributed jobs\n\nDGL provides a script to launch the training job in the cluster. `part_config` and `ip_config`\nspecify relative paths to the path of the workspace.\n\nThe command below launches 4 training processes on each machine as we'd like to utilize 4 GPUs for training.\n\n```bash\npython3 ~/workspace/dgl/tools/launch.py \\\n--workspace ~/workspace/dgl/examples/distributed/rgcn/ \\\n--num_trainers 4 \\\n--num_servers 2 \\\n--num_samplers 0 \\\n--part_config data/ogbn-mag.json \\\n--ip_config ip_config.txt \\\n\"python3 node_classification.py --graph-name ogbn-mag --dataset ogbn-mag --fanout='25,25' --batch-size 1024  --n-hidden 64 --lr 0.01 --eval-batch-size 1024  --low-mem --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --layer-norm --ip-config ip_config.txt --num_gpus 4\"\n```\n\nIf we want to train RGCN with `GraphBolt`, we need to append `--use_graphbolt`.\n\n```bash\npython3 ~/workspace/dgl/tools/launch.py \\\n--workspace ~/workspace/dgl/examples/distributed/rgcn/ \\\n--num_trainers 4 \\\n--num_servers 2 \\\n--num_samplers 0 \\\n--part_config data/ogbn-mag.json \\\n--ip_config ip_config.txt \\\n\"python3 node_classification.py --graph-name ogbn-mag --dataset ogbn-mag --fanout='25,25' --batch-size 1024  --n-hidden 64 --lr 0.01 --eval-batch-size 1024  
--low-mem --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --layer-norm --ip-config ip_config.txt --num_gpus 4 --use_graphbolt\"\n```\n\n**Note:** if you are using conda or other virtual environments on the remote machines, you need to replace `python3` in the command string (i.e. the last argument) with the path to the Python interpreter in that environment.\n\n\n## Comparison between `DGL` and `GraphBolt`\n\n### Partition sizes\n\nCompared to `DGL`, `GraphBolt` partitions are reduced to **19%** for `ogbn-mag`.\n\n`ogbn-mag`\n\n| Data Formats |         File Name            | Part 0 | Part 1 |\n| ------------ | ---------------------------- | ------ | ------ |\n| DGL          | graph.dgl                    | 714MB  | 716MB  |\n| GraphBolt    | fused_csc_sampling_graph.pt  | 137MB  | 136MB  |\n\n\n### Performance\n\nCompared to `DGL`, `GraphBolt`'s sampler works faster (sampling time is reduced to **16%** for `ogbn-mag`). `Min` and `Max` are statistics of all trainers on all nodes (machines).\n\nAs for RAM usage, the shared memory usage (measured by the **shared** field of the `free` command) decreases due to smaller graph partitions in `GraphBolt`. 
The peak memory used by processes(measured by **used** field of `free` command) decreases as well.\n\n`ogbn-mag`\n\n| Data Formats | Sample Time Per Epoch (CPU) |  Test Accuracy (3 epochs) | shared | used (peak) | CPU Util |\n| ------------ | --------------------------- | ------------------------- |  -----  | ---- | ----- |\n|     DGL      | Min: 48.2s, Max: 91.4s      |            42.76%         |  1.3GB  | 9.2GB| 10.4% |\n|   GraphBolt  | Min: 9.2s, Max: 11.9s       |            42.46%         |  742MB  | 5.9GB| 18.1% |\n\n\n## Demonstrate and profile sampling for Link Prediction task\n\n### DGL\n\n```\npython3 ~/workspace/dgl/tools/launch.py \\\n    --workspace ~/workspace/dgl/examples/distributed/rgcn/ \\\n    --num_trainers 4 \\\n    --num_servers 2 \\\n    --num_samplers 0 \\\n    --part_config ~/data/ogbn_mag_lp/ogbn-mag.json \\\n    --ip_config ~/workspace/ip_config.txt \\\n    \"python3 lp_perf.py --fanout='25,25' --batch-size 1024  --n-epochs 1 --graph-name ogbn-mag --ip-config ~/workspace/ip_config.txt --num_gpus 4 --remove_edge\"\n```\n\n### GraphBolt\n\nIn order to sample with `GraphBolt`, we need to convert partitions into `GraphBolt` formats with below command.\n\n```\npython3 -c \"import dgl;dgl.distributed.dgl_partition_to_graphbolt('/home/ubuntu/workspace/data/ogbn_mag_lp/ogbn-mag.json', store_eids=True, graph_formats='coo')\"\n```\n\nThen train with appended `--use_graphbolt`.\n\n```\npython3 ~/workspace/dgl/tools/launch.py \\\n    --workspace ~/workspace/dgl/examples/distributed/rgcn/ \\\n    --num_trainers 4 \\\n    --num_servers 2 \\\n    --num_samplers 0 \\\n    --part_config ~/data/ogbn_mag_lp/ogbn-mag.json \\\n    --ip_config ~/workspace/ip_config.txt \\\n    \"python3 lp_perf.py --fanout='25,25' --batch-size 1024  --n-epochs 1 --graph-name ogbn-mag --ip-config ~/workspace/ip_config.txt --num_gpus 4 --remove_edge --use_graphbolt\"\n```\n\n### Partition sizes\n\nCompared to `DGL`, `GraphBolt` partitions are reduced to **72%** for 
`ogbn-mag`.\n\n#### ogbn-mag\n\n| Data Formats |         File Name            | Part 0 | Part 1 |\n| ------------ | ---------------------------- | ------ | ------ |\n| DGL          | graph.dgl                    | 714MB  | 716MB  |\n| GraphBolt    | fused_csc_sampling_graph.pt  | 512MB  | 514MB  |\n\n### Performance Comparison\n\n#### Main parameters used\n\n1. 2 nodes (g4dn.metal), 4 trainers and 2 servers per node. Sampling runs on the main process.\n2. 2 layers.\n3. fanouts = 25, 25 for all edge types.\n4. batch_size = 1024.\n5. seed edge IDs are all edges of (\"author\", \"writes\", \"paper\"), ~7M in total.\n6. negative sampling ratio = 3.\n7. exclude = \"reverse_types\".\n\n#### ogbn-mag\n\nCompared to `DGL`, the sampling time with `GraphBolt` is reduced to **15%**. As for the overhead of `exclude`, it's about **5%** in this test. This number could be higher if a larger `fanout` or `batch size` is used.\n\nThe time shown below is the mean sampling time per iteration (60 iterations in total, slowest rank). Unit: seconds.\n\n| Data Formats | No Exclude | Exclude |\n| ------------ | ---------- | ------- |\n| DGL          |   6.50     |   6.86  |\n| GraphBolt    |   0.95     |   1.00  |\n"
  },
  {
    "path": "examples/distributed/rgcn/lp_perf.py",
    "content": "\"\"\"\n[For internal use only]\n\nDemonstrate and profile the performance of sampling for link prediction tasks.\n\"\"\"\n\nimport argparse\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch as th\n\n\ndef run(args, g, train_eids):\n    fanouts = [int(fanout) for fanout in args.fanout.split(\",\")]\n\n    neg_sampler = dgl.dataloading.negative_sampler.Uniform(3)\n\n    prob = args.prob_or_mask\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        fanouts,\n        prob=prob,\n    )\n\n    exclude = None\n    reverse_etypes = None\n    if args.remove_edge:\n        exclude = \"reverse_types\"\n        # add reverse edge types mapping.\n        reverse_etypes = {\n            (\"author\", \"affiliated_with\", \"institution\"): (\n                \"institution\",\n                \"rev-affiliated_with\",\n                \"author\",\n            ),\n            (\"author\", \"writes\", \"paper\"): (\"paper\", \"rev-writes\", \"author\"),\n            (\"paper\", \"has_topic\", \"field_of_study\"): (\n                \"field_of_study\",\n                \"rev-has_topic\",\n                \"paper\",\n            ),\n            (\"paper\", \"cites\", \"paper\"): (\"paper\", \"rev-cites\", \"paper\"),\n            (\"institution\", \"rev-affiliated_with\", \"author\"): (\n                \"author\",\n                \"affiliated_with\",\n                \"institution\",\n            ),\n            (\"paper\", \"rev-writes\", \"author\"): (\"author\", \"writes\", \"paper\"),\n            (\"field_of_study\", \"rev-has_topic\", \"paper\"): (\n                \"paper\",\n                \"has_topic\",\n                \"field_of_study\",\n            ),\n            (\"paper\", \"rev-cites\", \"paper\"): (\"paper\", \"cites\", \"paper\"),\n        }\n\n    dataloader = dgl.dataloading.DistEdgeDataLoader(\n        g,\n        train_eids,\n        sampler,\n        negative_sampler=neg_sampler,\n        exclude=exclude,\n        
reverse_etypes=reverse_etypes,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n    )\n\n    for epoch in range(args.n_epochs):\n        sample_times = []\n        tic = time.time()\n        epoch_tic = time.time()\n        for step, sample_data in enumerate(dataloader):\n            input_nodes, pos_graph, neg_graph, blocks = sample_data\n\n            if args.debug:\n                # Verify prob/mask values.\n                for block in blocks:\n                    for c_etype in block.canonical_etypes:\n                        homo_eids = block.edges[c_etype].data[dgl.EID]\n                        assert th.all(\n                            g.edges[c_etype].data[prob][homo_eids] > 0\n                        )\n                # Verify exclude_edges functionality.\n                current_eids = blocks[-1].edata[dgl.EID]\n                seed_eids = pos_graph.edata[dgl.EID]\n                if exclude is None:\n                    assert th.any(th.isin(current_eids, seed_eids))\n                elif exclude == \"self\":\n                    assert not th.any(th.isin(current_eids, seed_eids))\n                elif exclude == \"reverse_id\":\n                    assert not th.any(th.isin(current_eids, seed_eids))\n                elif exclude == \"reverse_types\":\n                    for src_type, etype, dst_type in pos_graph.canonical_etypes:\n                        reverse_etype = reverse_etypes[\n                            (src_type, etype, dst_type)\n                        ]\n                        seed_eids = pos_graph.edges[etype].data[dgl.EID]\n                        if (src_type, etype, dst_type) in blocks[\n                            -1\n                        ].canonical_etypes:\n                            assert not th.any(\n                                th.isin(\n                                    blocks[-1].edges[etype].data[dgl.EID],\n                                    seed_eids,\n                   
             )\n                            )\n                        if reverse_etype in blocks[-1].canonical_etypes:\n                            assert not th.any(\n                                th.isin(\n                                    blocks[-1]\n                                    .edges[reverse_etype]\n                                    .data[dgl.EID],\n                                    seed_eids,\n                                )\n                            )\n                else:\n                    raise ValueError(f\"Unsupported exclude type: {exclude}\")\n            sample_times.append(time.time() - tic)\n            if step % 10 == 0:\n                print(\n                    f\"[{g.rank()}]Epoch {epoch} | Step {step} | Sample Time {np.mean(sample_times[10:]):.4f}\"\n                )\n            tic = time.time()\n        print(\n            f\"[{g.rank()}]Epoch {epoch} | Total time {time.time() - epoch_tic} | Sample Time {np.mean(sample_times[100:]):.4f}\"\n        )\n        g.barrier()\n\n\ndef rand_init_prob(shape, dtype):\n    prob = th.rand(shape)\n    prob[th.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0\n    return prob\n\n\ndef rand_init_mask(shape, dtype):\n    prob = th.rand(shape)\n    prob[th.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0\n    return (prob > 0.2).to(th.float32)\n\n\ndef main(args):\n    dgl.distributed.initialize(args.ip_config, use_graphbolt=args.use_graphbolt)\n\n    backend = \"gloo\" if args.num_gpus == -1 else \"nccl\"\n    th.distributed.init_process_group(backend=backend)\n\n    g = dgl.distributed.DistGraph(args.graph_name)\n    print(\"rank:\", g.rank())\n\n    # Assign prob/masks to edges.\n    for c_etype in g.canonical_etypes:\n        shape = (g.num_edges(etype=c_etype),)\n        g.edges[c_etype].data[\"prob\"] = dgl.distributed.DistTensor(\n            shape,\n            th.float32,\n            init_func=rand_init_prob,\n            
part_policy=g.get_edge_partition_policy(c_etype),\n        )\n        g.edges[c_etype].data[\"mask\"] = dgl.distributed.DistTensor(\n            shape,\n            th.float32,\n            init_func=rand_init_mask,\n            part_policy=g.get_edge_partition_policy(c_etype),\n        )\n\n    pb = g.get_partition_book()\n    c_etype = (\"author\", \"writes\", \"paper\")\n    train_eids = dgl.distributed.edge_split(\n        th.ones((g.num_edges(etype=c_etype),), dtype=th.bool),\n        g.get_partition_book(),\n        etype=c_etype,\n        force_even=True,\n    )\n    train_eids = {c_etype: train_eids}\n    local_eids = pb.partid2eids(pb.partid, c_etype).detach().numpy()\n    print(\n        \"part {}, train: {} (local: {})\".format(\n            g.rank(),\n            len(train_eids[c_etype]),\n            len(np.intersect1d(train_eids[c_etype].numpy(), local_eids)),\n        )\n    )\n\n    run(\n        args,\n        g,\n        train_eids,\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Sampling Performance Profiling For Link Prediction Tasks\"\n    )\n    parser.add_argument(\"--graph-name\", type=str, help=\"graph name\")\n    parser.add_argument(\n        \"--ip-config\", type=str, help=\"The file for IP configuration\"\n    )\n    parser.add_argument(\n        \"--num_gpus\",\n        type=int,\n        default=-1,\n        help=\"the number of GPU device. Use -1 for CPU training\",\n    )\n    parser.add_argument(\n        \"-e\",\n        \"--n-epochs\",\n        type=int,\n        default=5,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"4, 4\",\n        help=\"Fan-out of neighbor sampling.\",\n    )\n    parser.add_argument(\n        \"--batch-size\", type=int, default=100, help=\"Mini-batch size. 
\"\n    )\n    parser.add_argument(\n        \"--use_graphbolt\",\n        default=False,\n        action=\"store_true\",\n        help=\"Use GraphBolt for distributed train.\",\n    )\n    parser.add_argument(\n        \"--remove_edge\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether to remove edges during sampling\",\n    )\n    parser.add_argument(\n        \"--debug\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether to remove edges during sampling\",\n    )\n    parser.add_argument(\n        \"--prob_or_mask\",\n        type=str,\n        default=\"prob\",\n        help=\"whether to use prob or mask during sampling\",\n    )\n    args = parser.parse_args()\n\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/distributed/rgcn/node_classification.py",
    "content": "\"\"\"\nModeling Relational Data with Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1703.06103\nCode: https://github.com/tkipf/relational-gcn\nDifference compared to tkipf/relation-gcn\n* l2norm applied to all weights\n* remove nodes that won't be touched\n\"\"\"\n\nimport argparse\nimport gc, os\nimport itertools\nimport time\n\nimport numpy as np\n\nos.environ[\"DGLBACKEND\"] = \"pytorch\"\n\nfrom functools import partial\n\nimport dgl\nimport dgl.distributed\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport tqdm\nfrom dgl import DGLGraph, nn as dglnn\nfrom dgl.distributed import DistDataLoader\n\nfrom ogb.nodeproppred import DglNodePropPredDataset\nfrom torch.multiprocessing import Queue\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.utils.data import DataLoader\n\n\nclass RelGraphConvLayer(nn.Module):\n    r\"\"\"Relational graph convolution layer.\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size.\n    out_feat : int\n        Output feature size.\n    rel_names : list[str]\n        Relation names.\n    num_bases : int, optional\n        Number of bases. If is none, use number of relations. Default: None.\n    weight : bool, optional\n        True if a linear layer is applied after message passing. Default: True\n    bias : bool, optional\n        True if bias is added. Default: True\n    activation : callable, optional\n        Activation function. Default: None\n    self_loop : bool, optional\n        True to include self loop message. Default: False\n    dropout : float, optional\n        Dropout rate. 
Default: 0.0\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feat,\n        out_feat,\n        rel_names,\n        num_bases,\n        *,\n        weight=True,\n        bias=True,\n        activation=None,\n        self_loop=False,\n        dropout=0.0\n    ):\n        super(RelGraphConvLayer, self).__init__()\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.rel_names = rel_names\n        self.num_bases = num_bases\n        self.bias = bias\n        self.activation = activation\n        self.self_loop = self_loop\n\n        self.conv = dglnn.HeteroGraphConv(\n            {\n                rel: dglnn.GraphConv(\n                    in_feat, out_feat, norm=\"right\", weight=False, bias=False\n                )\n                for rel in rel_names\n            }\n        )\n\n        self.use_weight = weight\n        self.use_basis = num_bases < len(self.rel_names) and weight\n        if self.use_weight:\n            if self.use_basis:\n                self.basis = dglnn.WeightBasis(\n                    (in_feat, out_feat), num_bases, len(self.rel_names)\n                )\n            else:\n                self.weight = nn.Parameter(\n                    th.Tensor(len(self.rel_names), in_feat, out_feat)\n                )\n                nn.init.xavier_uniform_(\n                    self.weight, gain=nn.init.calculate_gain(\"relu\")\n                )\n\n        # bias\n        if bias:\n            self.h_bias = nn.Parameter(th.Tensor(out_feat))\n            nn.init.zeros_(self.h_bias)\n\n        # weight for self loop\n        if self.self_loop:\n            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))\n            nn.init.xavier_uniform_(\n                self.loop_weight, gain=nn.init.calculate_gain(\"relu\")\n            )\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, g, inputs):\n        \"\"\"Forward computation\n        Parameters\n        ----------\n        g : 
DGLGraph\n            Input graph.\n        inputs : dict[str, torch.Tensor]\n            Node feature for each node type.\n        Returns\n        -------\n        dict[str, torch.Tensor]\n            New node features for each node type.\n        \"\"\"\n        g = g.local_var()\n        if self.use_weight:\n            weight = self.basis() if self.use_basis else self.weight\n            wdict = {\n                self.rel_names[i]: {\"weight\": w.squeeze(0)}\n                for i, w in enumerate(th.split(weight, 1, dim=0))\n            }\n        else:\n            wdict = {}\n\n        if g.is_block:\n            inputs_src = inputs\n            inputs_dst = {\n                k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()\n            }\n        else:\n            inputs_src = inputs_dst = inputs\n\n        hs = self.conv(g, inputs, mod_kwargs=wdict)\n\n        def _apply(ntype, h):\n            if self.self_loop:\n                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)\n            if self.bias:\n                h = h + self.h_bias\n            if self.activation:\n                h = self.activation(h)\n            return self.dropout(h)\n\n        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}\n\n\nclass EntityClassify(nn.Module):\n    \"\"\"Entity classification class for RGCN\n    Parameters\n    ----------\n    device : int\n        Device to run the layer.\n    num_nodes : int\n        Number of nodes.\n    h_dim : int\n        Hidden dim size.\n    out_dim : int\n        Output dim size.\n    rel_names : list of str\n        A list of relation names.\n    num_bases : int\n        Number of bases. 
If is none, use number of relations.\n    num_hidden_layers : int\n        Number of hidden RelGraphConv Layer\n    dropout : float\n        Dropout\n    use_self_loop : bool\n        Use self loop if True, default False.\n    \"\"\"\n\n    def __init__(\n        self,\n        device,\n        h_dim,\n        out_dim,\n        rel_names,\n        num_bases=None,\n        num_hidden_layers=1,\n        dropout=0,\n        use_self_loop=False,\n        layer_norm=False,\n    ):\n        super(EntityClassify, self).__init__()\n        self.device = device\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.num_bases = None if num_bases < 0 else num_bases\n        self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.use_self_loop = use_self_loop\n        self.layer_norm = layer_norm\n\n        self.layers = nn.ModuleList()\n        # i2h\n        self.layers.append(\n            RelGraphConvLayer(\n                self.h_dim,\n                self.h_dim,\n                rel_names,\n                self.num_bases,\n                activation=F.relu,\n                self_loop=self.use_self_loop,\n                dropout=self.dropout,\n            )\n        )\n        # h2h\n        for idx in range(self.num_hidden_layers):\n            self.layers.append(\n                RelGraphConvLayer(\n                    self.h_dim,\n                    self.h_dim,\n                    rel_names,\n                    self.num_bases,\n                    activation=F.relu,\n                    self_loop=self.use_self_loop,\n                    dropout=self.dropout,\n                )\n            )\n        # h2o\n        self.layers.append(\n            RelGraphConvLayer(\n                self.h_dim,\n                self.out_dim,\n                rel_names,\n                self.num_bases,\n                activation=None,\n                self_loop=self.use_self_loop,\n            )\n        )\n\n    def forward(self, 
blocks, feats, norm=None):\n        if blocks is None:\n            # full graph training\n            blocks = [self.g] * len(self.layers)\n        h = feats\n        for layer, block in zip(self.layers, blocks):\n            block = block.to(self.device)\n            h = layer(block, h)\n        return h\n\n\ndef init_emb(shape, dtype):\n    arr = th.zeros(shape, dtype=dtype)\n    nn.init.uniform_(arr, -1.0, 1.0)\n    return arr\n\n\nclass DistEmbedLayer(nn.Module):\n    r\"\"\"Embedding layer for featureless heterograph.\n    Parameters\n    ----------\n    dev_id : int\n        Device to run the layer.\n    g : DistGraph\n        training graph\n    embed_size : int\n        Output embed size\n    sparse_emb: bool\n        Whether to use sparse embedding\n        Default: False\n    dgl_sparse_emb: bool\n        Whether to use DGL sparse embedding\n        Default: False\n    embed_name : str, optional\n        Embed name\n    \"\"\"\n\n    def __init__(\n        self,\n        dev_id,\n        g,\n        embed_size,\n        sparse_emb=False,\n        dgl_sparse_emb=False,\n        feat_name=\"feat\",\n        embed_name=\"node_emb\",\n    ):\n        super(DistEmbedLayer, self).__init__()\n        self.dev_id = dev_id\n        self.embed_size = embed_size\n        self.embed_name = embed_name\n        self.feat_name = feat_name\n        self.sparse_emb = sparse_emb\n        self.g = g\n        self.ntype_id_map = {g.get_ntype_id(ntype): ntype for ntype in g.ntypes}\n\n        self.node_projs = nn.ModuleDict()\n        for ntype in g.ntypes:\n            if feat_name in g.nodes[ntype].data:\n                self.node_projs[ntype] = nn.Linear(\n                    g.nodes[ntype].data[feat_name].shape[1], embed_size\n                )\n                nn.init.xavier_uniform_(self.node_projs[ntype].weight)\n                print(\"node {} has data {}\".format(ntype, feat_name))\n        if sparse_emb:\n            if dgl_sparse_emb:\n                
self.node_embeds = {}\n                for ntype in g.ntypes:\n                    # We only create embeddings for nodes without node features.\n                    if feat_name not in g.nodes[ntype].data:\n                        part_policy = g.get_node_partition_policy(ntype)\n                        self.node_embeds[ntype] = dgl.distributed.DistEmbedding(\n                            g.num_nodes(ntype),\n                            self.embed_size,\n                            embed_name + \"_\" + ntype,\n                            init_emb,\n                            part_policy,\n                        )\n            else:\n                self.node_embeds = nn.ModuleDict()\n                for ntype in g.ntypes:\n                    # We only create embeddings for nodes without node features.\n                    if feat_name not in g.nodes[ntype].data:\n                        self.node_embeds[ntype] = th.nn.Embedding(\n                            g.num_nodes(ntype),\n                            self.embed_size,\n                            sparse=self.sparse_emb,\n                        )\n                        nn.init.uniform_(\n                            self.node_embeds[ntype].weight, -1.0, 1.0\n                        )\n        else:\n            self.node_embeds = nn.ModuleDict()\n            for ntype in g.ntypes:\n                # We only create embeddings for nodes without node features.\n                if feat_name not in g.nodes[ntype].data:\n                    self.node_embeds[ntype] = th.nn.Embedding(\n                        g.num_nodes(ntype), self.embed_size\n                    )\n                    nn.init.uniform_(self.node_embeds[ntype].weight, -1.0, 1.0)\n\n    def forward(self, node_ids):\n        \"\"\"Forward computation\n        Parameters\n        ----------\n        node_ids : dict of Tensor\n            node ids to generate embedding for.\n        Returns\n        -------\n        tensor\n            embeddings as 
the input of the next layer\n        \"\"\"\n        embeds = {}\n        for ntype in node_ids:\n            if self.feat_name in self.g.nodes[ntype].data:\n                embeds[ntype] = self.node_projs[ntype](\n                    self.g.nodes[ntype]\n                    .data[self.feat_name][node_ids[ntype]]\n                    .to(self.dev_id)\n                )\n            else:\n                embeds[ntype] = self.node_embeds[ntype](node_ids[ntype]).to(\n                    self.dev_id\n                )\n        return embeds\n\n\ndef compute_acc(results, labels):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    labels = labels.long()\n    return (results == labels).float().sum() / len(results)\n\n\ndef evaluate(\n    g,\n    model,\n    embed_layer,\n    labels,\n    eval_loader,\n    test_loader,\n    all_val_nid,\n    all_test_nid,\n):\n    model.eval()\n    embed_layer.eval()\n    eval_logits = []\n    eval_seeds = []\n\n    global_results = dgl.distributed.DistTensor(\n        labels.shape, th.long, \"results\", persistent=True\n    )\n\n    with th.no_grad():\n        th.cuda.empty_cache()\n        for sample_data in tqdm.tqdm(eval_loader):\n            input_nodes, seeds, blocks = sample_data\n            seeds = seeds[\"paper\"]\n            feats = embed_layer(input_nodes)\n            logits = model(blocks, feats)\n            assert len(logits) == 1\n            logits = logits[\"paper\"]\n            eval_logits.append(logits.cpu().detach())\n            assert np.all(seeds.numpy() < g.num_nodes(\"paper\"))\n            eval_seeds.append(seeds.cpu().detach())\n    eval_logits = th.cat(eval_logits)\n    eval_seeds = th.cat(eval_seeds)\n    global_results[eval_seeds] = eval_logits.argmax(dim=1)\n\n    test_logits = []\n    test_seeds = []\n    with th.no_grad():\n        th.cuda.empty_cache()\n        for sample_data in tqdm.tqdm(test_loader):\n            input_nodes, seeds, blocks = sample_data\n         
   seeds = seeds[\"paper\"]\n            feats = embed_layer(input_nodes)\n            logits = model(blocks, feats)\n            assert len(logits) == 1\n            logits = logits[\"paper\"]\n            test_logits.append(logits.cpu().detach())\n            assert np.all(seeds.numpy() < g.num_nodes(\"paper\"))\n            test_seeds.append(seeds.cpu().detach())\n    test_logits = th.cat(test_logits)\n    test_seeds = th.cat(test_seeds)\n    global_results[test_seeds] = test_logits.argmax(dim=1)\n\n    g.barrier()\n    if g.rank() == 0:\n        return compute_acc(\n            global_results[all_val_nid], labels[all_val_nid]\n        ), compute_acc(global_results[all_test_nid], labels[all_test_nid])\n    else:\n        return -1, -1\n\n\ndef run(args, device, data):\n    (\n        g,\n        num_classes,\n        train_nid,\n        val_nid,\n        test_nid,\n        labels,\n        all_val_nid,\n        all_test_nid,\n    ) = data\n\n    fanouts = [int(fanout) for fanout in args.fanout.split(\",\")]\n    val_fanouts = [int(fanout) for fanout in args.validation_fanout.split(\",\")]\n\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)\n    dataloader = dgl.distributed.DistNodeDataLoader(\n        g,\n        {\"paper\": train_nid},\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n    )\n\n    valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts)\n    valid_dataloader = dgl.distributed.DistNodeDataLoader(\n        g,\n        {\"paper\": val_nid},\n        valid_sampler,\n        batch_size=args.batch_size,\n        shuffle=False,\n        drop_last=False,\n    )\n\n    test_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts)\n    test_dataloader = dgl.distributed.DistNodeDataLoader(\n        g,\n        {\"paper\": test_nid},\n        test_sampler,\n        batch_size=args.eval_batch_size,\n        shuffle=False,\n        drop_last=False,\n    )\n\n    
embed_layer = DistEmbedLayer(\n        device,\n        g,\n        args.n_hidden,\n        sparse_emb=args.sparse_embedding,\n        dgl_sparse_emb=args.dgl_sparse,\n        feat_name=\"feat\",\n    )\n\n    model = EntityClassify(\n        device,\n        args.n_hidden,\n        num_classes,\n        g.etypes,\n        num_bases=args.n_bases,\n        num_hidden_layers=args.n_layers - 2,\n        dropout=args.dropout,\n        use_self_loop=args.use_self_loop,\n        layer_norm=args.layer_norm,\n    )\n    model = model.to(device)\n\n    if not args.standalone:\n        if args.num_gpus == -1:\n            model = DistributedDataParallel(model)\n            # If there are dense parameters in the embedding layer\n            # or we use Pytorch saprse embeddings.\n            if len(embed_layer.node_projs) > 0 or not args.dgl_sparse:\n                embed_layer = DistributedDataParallel(embed_layer)\n        else:\n            dev_id = g.rank() % args.num_gpus\n            model = DistributedDataParallel(\n                model, device_ids=[dev_id], output_device=dev_id\n            )\n            # If there are dense parameters in the embedding layer\n            # or we use Pytorch saprse embeddings.\n            if len(embed_layer.node_projs) > 0 or not args.dgl_sparse:\n                embed_layer = embed_layer.to(device)\n                embed_layer = DistributedDataParallel(\n                    embed_layer, device_ids=[dev_id], output_device=dev_id\n                )\n\n    if args.sparse_embedding:\n        if args.dgl_sparse and args.standalone:\n            emb_optimizer = dgl.distributed.optim.SparseAdam(\n                list(embed_layer.node_embeds.values()), lr=args.sparse_lr\n            )\n            print(\n                \"optimize DGL sparse embedding:\", embed_layer.node_embeds.keys()\n            )\n        elif args.dgl_sparse:\n            emb_optimizer = dgl.distributed.optim.SparseAdam(\n                
list(embed_layer.module.node_embeds.values()), lr=args.sparse_lr\n            )\n            print(\n                \"optimize DGL sparse embedding:\",\n                embed_layer.module.node_embeds.keys(),\n            )\n        elif args.standalone:\n            emb_optimizer = th.optim.SparseAdam(\n                list(embed_layer.node_embeds.parameters()), lr=args.sparse_lr\n            )\n            print(\"optimize Pytorch sparse embedding:\", embed_layer.node_embeds)\n        else:\n            emb_optimizer = th.optim.SparseAdam(\n                list(embed_layer.module.node_embeds.parameters()),\n                lr=args.sparse_lr,\n            )\n            print(\n                \"optimize Pytorch sparse embedding:\",\n                embed_layer.module.node_embeds,\n            )\n\n        dense_params = list(model.parameters())\n        if args.standalone:\n            dense_params += list(embed_layer.node_projs.parameters())\n            print(\"optimize dense projection:\", embed_layer.node_projs)\n        else:\n            dense_params += list(embed_layer.module.node_projs.parameters())\n            print(\"optimize dense projection:\", embed_layer.module.node_projs)\n        optimizer = th.optim.Adam(\n            dense_params, lr=args.lr, weight_decay=args.l2norm\n        )\n    else:\n        all_params = list(model.parameters()) + list(embed_layer.parameters())\n        optimizer = th.optim.Adam(\n            all_params, lr=args.lr, weight_decay=args.l2norm\n        )\n\n    # training loop\n    print(\"start training...\")\n    for epoch in range(args.n_epochs):\n        tic = time.time()\n\n        sample_time = 0\n        copy_time = 0\n        forward_time = 0\n        backward_time = 0\n        update_time = 0\n        number_train = 0\n        number_input = 0\n\n        step_time = []\n        iter_t = []\n        sample_t = []\n        feat_copy_t = []\n        forward_t = []\n        backward_t = []\n        update_t = []\n       
 iter_tput = []\n\n        start = time.time()\n        # Loop over the dataloader to sample the computation dependency graph as a list of\n        # blocks.\n        step_time = []\n        for step, sample_data in enumerate(dataloader):\n            input_nodes, seeds, blocks = sample_data\n            seeds = seeds[\"paper\"]\n            number_train += seeds.shape[0]\n            number_input += np.sum(\n                [blocks[0].num_src_nodes(ntype) for ntype in blocks[0].ntypes]\n            )\n            tic_step = time.time()\n            sample_time += tic_step - start\n            sample_t.append(tic_step - start)\n\n            feats = embed_layer(input_nodes)\n            label = labels[seeds].to(device)\n            copy_time = time.time()\n            feat_copy_t.append(copy_time - tic_step)\n\n            # forward\n            logits = model(blocks, feats)\n            assert len(logits) == 1\n            logits = logits[\"paper\"]\n            loss = F.cross_entropy(logits, label)\n            forward_end = time.time()\n\n            # backward\n            optimizer.zero_grad()\n            if args.sparse_embedding:\n                emb_optimizer.zero_grad()\n            loss.backward()\n            compute_end = time.time()\n            forward_t.append(forward_end - copy_time)\n            backward_t.append(compute_end - forward_end)\n\n            # Update model parameters\n            optimizer.step()\n            if args.sparse_embedding:\n                emb_optimizer.step()\n            update_t.append(time.time() - compute_end)\n            step_t = time.time() - start\n            step_time.append(step_t)\n\n            train_acc = th.sum(logits.argmax(dim=1) == label).item() / len(\n                seeds\n            )\n\n            if step % args.log_every == 0:\n                print(\n                    \"[{}] Epoch {:05d} | Step {:05d} | Train acc {:.4f} | Loss {:.4f} | time {:.3f} s\"\n                    \"| sample {:.3f} | 
copy {:.3f} | forward {:.3f} | backward {:.3f} | update {:.3f}\".format(\n                        g.rank(),\n                        epoch,\n                        step,\n                        train_acc,\n                        loss.item(),\n                        np.sum(step_time[-args.log_every :]),\n                        np.sum(sample_t[-args.log_every :]),\n                        np.sum(feat_copy_t[-args.log_every :]),\n                        np.sum(forward_t[-args.log_every :]),\n                        np.sum(backward_t[-args.log_every :]),\n                        np.sum(update_t[-args.log_every :]),\n                    )\n                )\n            start = time.time()\n\n        gc.collect()\n        print(\n            \"[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #train: {}, #input: {}\".format(\n                g.rank(),\n                np.sum(step_time),\n                np.sum(sample_t),\n                np.sum(feat_copy_t),\n                np.sum(forward_t),\n                np.sum(backward_t),\n                np.sum(update_t),\n                number_train,\n                number_input,\n            )\n        )\n        epoch += 1\n\n        start = time.time()\n        g.barrier()\n        val_acc, test_acc = evaluate(\n            g,\n            model,\n            embed_layer,\n            labels,\n            valid_dataloader,\n            test_dataloader,\n            all_val_nid,\n            all_test_nid,\n        )\n        if val_acc >= 0:\n            print(\n                \"Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}\".format(\n                    val_acc, test_acc, time.time() - start\n                )\n            )\n\n\ndef main(args):\n    dgl.distributed.initialize(args.ip_config, use_graphbolt=args.use_graphbolt)\n    if not args.standalone:\n        backend = \"gloo\" if args.num_gpus == -1 else \"nccl\"\n        if args.sparse_embedding 
and args.dgl_sparse:\n            # `nccl` is not fully supported in DistDGL's sparse optimizer.\n            backend = \"gloo\"\n        th.distributed.init_process_group(backend=backend)\n\n    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.conf_path)\n    print(\"rank:\", g.rank())\n\n    pb = g.get_partition_book()\n    if \"trainer_id\" in g.nodes[\"paper\"].data:\n        train_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"train_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n            node_trainer_ids=g.nodes[\"paper\"].data[\"trainer_id\"],\n        )\n        val_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"val_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n            node_trainer_ids=g.nodes[\"paper\"].data[\"trainer_id\"],\n        )\n        test_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"test_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n            node_trainer_ids=g.nodes[\"paper\"].data[\"trainer_id\"],\n        )\n    else:\n        train_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"train_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n        )\n        val_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"val_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n        )\n        test_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"test_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n        )\n    local_nid = pb.partid2nids(pb.partid, \"paper\").detach().numpy()\n    print(\n        \"part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})\".format(\n            g.rank(),\n            len(train_nid),\n            
len(np.intersect1d(train_nid.numpy(), local_nid)),\n            len(val_nid),\n            len(np.intersect1d(val_nid.numpy(), local_nid)),\n            len(test_nid),\n            len(np.intersect1d(test_nid.numpy(), local_nid)),\n        )\n    )\n    if args.num_gpus == -1:\n        device = th.device(\"cpu\")\n    else:\n        dev_id = g.rank() % args.num_gpus\n        device = th.device(\"cuda:\" + str(dev_id))\n    labels = g.nodes[\"paper\"].data[\"labels\"][np.arange(g.num_nodes(\"paper\"))]\n    all_val_nid = th.LongTensor(\n        np.nonzero(\n            g.nodes[\"paper\"].data[\"val_mask\"][np.arange(g.num_nodes(\"paper\"))]\n        )\n    ).squeeze()\n    all_test_nid = th.LongTensor(\n        np.nonzero(\n            g.nodes[\"paper\"].data[\"test_mask\"][np.arange(g.num_nodes(\"paper\"))]\n        )\n    ).squeeze()\n    n_classes = len(th.unique(labels[labels >= 0]))\n    print(\"#classes:\", n_classes)\n\n    run(\n        args,\n        device,\n        (\n            g,\n            n_classes,\n            train_nid,\n            val_nid,\n            test_nid,\n            labels,\n            all_val_nid,\n            all_test_nid,\n        ),\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"RGCN\")\n    # distributed training related\n    parser.add_argument(\"--graph-name\", type=str, help=\"graph name\")\n    parser.add_argument(\"--id\", type=int, help=\"the partition id\")\n    parser.add_argument(\n        \"--ip-config\", type=str, help=\"The file for IP configuration\"\n    )\n    parser.add_argument(\n        \"--conf-path\", type=str, help=\"The path to the partition config file\"\n    )\n\n    # rgcn related\n    parser.add_argument(\n        \"--num_gpus\",\n        type=int,\n        default=-1,\n        help=\"the number of GPU device. 
Use -1 for CPU training\",\n    )\n    parser.add_argument(\n        \"--dropout\", type=float, default=0, help=\"dropout probability\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden units\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--sparse-lr\", type=float, default=1e-2, help=\"sparse lr rate\"\n    )\n    parser.add_argument(\n        \"--n-bases\",\n        type=int,\n        default=-1,\n        help=\"number of filter weight matrices, default: -1 [use all]\",\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=2, help=\"number of propagation rounds\"\n    )\n    parser.add_argument(\n        \"-e\",\n        \"--n-epochs\",\n        type=int,\n        default=50,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"-d\", \"--dataset\", type=str, required=True, help=\"dataset to use\"\n    )\n    parser.add_argument(\"--l2norm\", type=float, default=0, help=\"l2 norm coef\")\n    parser.add_argument(\n        \"--relabel\",\n        default=False,\n        action=\"store_true\",\n        help=\"remove untouched nodes and relabel\",\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"4, 4\",\n        help=\"Fan-out of neighbor sampling.\",\n    )\n    parser.add_argument(\n        \"--validation-fanout\",\n        type=str,\n        default=None,\n        help=\"Fan-out of neighbor sampling during validation.\",\n    )\n    parser.add_argument(\n        \"--use-self-loop\",\n        default=False,\n        action=\"store_true\",\n        help=\"include self feature as a special relation\",\n    )\n    parser.add_argument(\n        \"--batch-size\", type=int, default=100, help=\"Mini-batch size. \"\n    )\n    parser.add_argument(\n        \"--eval-batch-size\", type=int, default=128, help=\"Mini-batch size. 
\"\n    )\n    parser.add_argument(\"--log-every\", type=int, default=20)\n    parser.add_argument(\n        \"--low-mem\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether use low mem RelGraphCov\",\n    )\n    parser.add_argument(\n        \"--sparse-embedding\",\n        action=\"store_true\",\n        help=\"Use sparse embedding for node embeddings.\",\n    )\n    parser.add_argument(\n        \"--dgl-sparse\",\n        action=\"store_true\",\n        help=\"Whether to use DGL sparse embedding\",\n    )\n    parser.add_argument(\n        \"--layer-norm\",\n        default=False,\n        action=\"store_true\",\n        help=\"Use layer norm\",\n    )\n    parser.add_argument(\n        \"--local_rank\", type=int, help=\"get rank of the process\"\n    )\n    parser.add_argument(\n        \"--standalone\", action=\"store_true\", help=\"run in the standalone mode\"\n    )\n    parser.add_argument(\n        \"--use_graphbolt\",\n        action=\"store_true\",\n        help=\"Use GraphBolt for distributed train.\",\n    )\n    args = parser.parse_args()\n\n    # if validation_fanout is None, set it with args.fanout\n    if args.validation_fanout is None:\n        args.validation_fanout = args.fanout\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/distributed/rgcn/partition_graph.py",
    "content": "import argparse\nimport time\n\nimport dgl\nimport numpy as np\nimport torch as th\n\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\ndef load_ogb(dataset):\n    if dataset == \"ogbn-mag\":\n        dataset = DglNodePropPredDataset(name=dataset)\n        split_idx = dataset.get_idx_split()\n        train_idx = split_idx[\"train\"][\"paper\"]\n        val_idx = split_idx[\"valid\"][\"paper\"]\n        test_idx = split_idx[\"test\"][\"paper\"]\n        hg_orig, labels = dataset[0]\n        subgs = {}\n        for etype in hg_orig.canonical_etypes:\n            u, v = hg_orig.all_edges(etype=etype)\n            subgs[etype] = (u, v)\n            subgs[(etype[2], \"rev-\" + etype[1], etype[0])] = (v, u)\n        hg = dgl.heterograph(subgs)\n        hg.nodes[\"paper\"].data[\"feat\"] = hg_orig.nodes[\"paper\"].data[\"feat\"]\n        paper_labels = labels[\"paper\"].squeeze()\n\n        num_rels = len(hg.canonical_etypes)\n        num_of_ntype = len(hg.ntypes)\n        num_classes = dataset.num_classes\n        category = \"paper\"\n        print(\"Number of relations: {}\".format(num_rels))\n        print(\"Number of class: {}\".format(num_classes))\n        print(\"Number of train: {}\".format(len(train_idx)))\n        print(\"Number of valid: {}\".format(len(val_idx)))\n        print(\"Number of test: {}\".format(len(test_idx)))\n\n        # get target category id\n        category_id = len(hg.ntypes)\n        for i, ntype in enumerate(hg.ntypes):\n            if ntype == category:\n                category_id = i\n\n        train_mask = th.zeros((hg.num_nodes(\"paper\"),), dtype=th.bool)\n        train_mask[train_idx] = True\n        val_mask = th.zeros((hg.num_nodes(\"paper\"),), dtype=th.bool)\n        val_mask[val_idx] = True\n        test_mask = th.zeros((hg.num_nodes(\"paper\"),), dtype=th.bool)\n        test_mask[test_idx] = True\n        hg.nodes[\"paper\"].data[\"train_mask\"] = train_mask\n        
hg.nodes[\"paper\"].data[\"val_mask\"] = val_mask\n        hg.nodes[\"paper\"].data[\"test_mask\"] = test_mask\n\n        hg.nodes[\"paper\"].data[\"labels\"] = paper_labels\n        return hg\n    else:\n        raise (\"Do not support other ogbn datasets.\")\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"Partition builtin graphs\")\n    argparser.add_argument(\n        \"--dataset\", type=str, default=\"ogbn-mag\", help=\"datasets: ogbn-mag\"\n    )\n    argparser.add_argument(\n        \"--num_parts\", type=int, default=4, help=\"number of partitions\"\n    )\n    argparser.add_argument(\n        \"--part_method\", type=str, default=\"metis\", help=\"the partition method\"\n    )\n    argparser.add_argument(\n        \"--balance_train\",\n        action=\"store_true\",\n        help=\"balance the training size in each partition.\",\n    )\n    argparser.add_argument(\n        \"--undirected\",\n        action=\"store_true\",\n        help=\"turn the graph into an undirected graph.\",\n    )\n    argparser.add_argument(\n        \"--balance_edges\",\n        action=\"store_true\",\n        help=\"balance the number of edges in each partition.\",\n    )\n    argparser.add_argument(\n        \"--num_trainers_per_machine\",\n        type=int,\n        default=1,\n        help=\"the number of trainers per machine. 
The trainer ids are stored\\\n                                in the node feature 'trainer_id'\",\n    )\n    argparser.add_argument(\n        \"--output\",\n        type=str,\n        default=\"data\",\n        help=\"Output path of partitioned graph.\",\n    )\n    argparser.add_argument(\n        \"--use_graphbolt\",\n        action=\"store_true\",\n        help=\"Use GraphBolt for distributed train.\",\n    )\n\n    args = argparser.parse_args()\n\n    start = time.time()\n    g = load_ogb(args.dataset)\n\n    print(\n        \"load {} takes {:.3f} seconds\".format(args.dataset, time.time() - start)\n    )\n    print(\"|V|={}, |E|={}\".format(g.num_nodes(), g.num_edges()))\n    print(\n        \"train: {}, valid: {}, test: {}\".format(\n            th.sum(g.nodes[\"paper\"].data[\"train_mask\"]),\n            th.sum(g.nodes[\"paper\"].data[\"val_mask\"]),\n            th.sum(g.nodes[\"paper\"].data[\"test_mask\"]),\n        )\n    )\n\n    if args.balance_train:\n        balance_ntypes = {\"paper\": g.nodes[\"paper\"].data[\"train_mask\"]}\n    else:\n        balance_ntypes = None\n\n    dgl.distributed.partition_graph(\n        g,\n        args.dataset,\n        args.num_parts,\n        args.output,\n        part_method=args.part_method,\n        balance_ntypes=balance_ntypes,\n        balance_edges=args.balance_edges,\n        num_trainers_per_machine=args.num_trainers_per_machine,\n        use_graphbolt=args.use_graphbolt,\n    )\n"
  },
  {
    "path": "examples/graphbolt/README.md",
    "content": "## How to run the code?\n\n```bash\npython link_prediction.py\n```\n\nResults (10 epochs):\n```\nValid MRR 0.7040\nTest MRR 0.7043\n```"
  },
  {
    "path": "examples/graphbolt/disk_based_feature/README.md",
    "content": "## Overview\n\nThis project demonstrates how to use GraphBolt to train and evaluate a GraphSAGE model for node classification task on large graphs, where node features are on-disk and fetched using `DiskBasedFeature`. GraphBolt utilizes various in-house implemented caching policy algorithms such as [SIEVE](https://cachemon.github.io/SIEVE-website/), [S3-FIFO](https://s3fifo.com), LRU and [CLOCK](https://people.csail.mit.edu/saltzer/Multics/MHP-Saltzer-060508/bookcases/M00s/M0104%20074-12%29.PDF) to cache frequently required features and io_uring to fetch cache-missed features from disk. The SIEVE algorithm is the default option.\n\n# Node classification task\n\nThis example demonstrates how to run node classification task with **GraphBolt.DiskBasedFeature**. All results are collected on an AWS EC2 g5.8xlarge instance with 128GB RAM, 32 cores, an 24GB A10G GPU and a instance storage of 250K IOPS.\n\n## Run on `ogbn-papers100M` dataset\n\n|     Dataset     | Graph Size | Feature Size | Feature Dim |\n| :-------------: | :--------: | :----------: | :---------: |\n| ogbn-papers100M |   13 GB   |    53 GB    |     128     |\n\n## Results with various caching policies\n\nThis part trains a three-layer GraphSAGE model for 3 epochs on `ogbn-papers100M` dataset with 10GB CPU cache, using neighbor sampling.\n\n### Run default SIEVE policy\n\nInstruction:\n\n```\npython node_classification.py --gpu-cache-size-in-gigabytes=0 --cpu-cache-size-in-gigabytes=10 --dataset=ogbn-papers100M --epochs=3\n```\n\nResult:\n\n```\nTraining: 1178it [03:00,  6.53it/s, num_nodes=671260, gpu_cache_miss=1, cpu_cache_miss=0.0578]                                             \nEvaluating: 123it [00:16,  7.47it/s, num_nodes=624816, gpu_cache_miss=1, cpu_cache_miss=0.0569]\nEpoch 00, Loss: 1.4173, Approx. Train: 0.5787, Approx. 
Val: 0.6353, Time: 180.33928060531616s                                              \nTraining: 1178it [01:39, 11.79it/s, num_nodes=648380, gpu_cache_miss=1, cpu_cache_miss=0.0451]                                             \nEvaluating: 123it [00:15,  7.90it/s, num_nodes=625373, gpu_cache_miss=1, cpu_cache_miss=0.0451]\nEpoch 01, Loss: 1.1446, Approx. Train: 0.6386, Approx. Val: 0.6382, Time: 99.92613315582275s                                               \nTraining: 1178it [01:36, 12.15it/s, num_nodes=674194, gpu_cache_miss=1, cpu_cache_miss=0.0408]                                             \nEvaluating: 123it [00:15,  8.08it/s, num_nodes=628233, gpu_cache_miss=1, cpu_cache_miss=0.0409]\nEpoch 02, Loss: 1.0975, Approx. Train: 0.6507, Approx. Val: 0.6535, Time: 96.95083212852478s\n```\n\n### Performance Comparison on four caching policies\n\nThe results below show the epoch time with four different caching policies.\n\n| Policy | Epoch 1 (s) | Epoch 2 (s) | Epoch 3 (s) |\n| :-----: | :---------: | :---------: | :---------: |\n|  SIEVE  |   180.339   |   99.926   |   96.951   |\n| S3-FIFO |   181.438   |   110.054   |   108.310   |\n|   LRU   |   194.583   |   138.352   |   138.369   |\n|  CLOCK  |   188.915   |   129.372   |   129.388   |\n\n## Results with Layer-Neighbor Sampling\n\nThis part trains a three-layer GraphSAGE model for 3 epochs on `ogbn-papers100M` dataset with 10GB CPU cache, using Layer-Neighbor Sampling and default SIEVE policy.\n\n### Run default `--batch-dependency=1`\n\nInstruction:\n\n```\npython node_classification.py --gpu-cache-size-in-gigabytes=0 --cpu-cache-size-in-gigabytes=10 --dataset=ogbn-papers100M --sample-mode=sample_layer_neighbor --batch-dependency=1 --epochs=3\n```\n\nResult:\n\n```\nTraining: 1178it [02:51,  6.88it/s, num_nodes=463495, gpu_cache_miss=1, cpu_cache_miss=0.0774]                                             \nEvaluating: 123it [00:15,  7.94it/s, num_nodes=465592, gpu_cache_miss=1, 
cpu_cache_miss=0.0762]\nEpoch 00, Loss: 1.4173, Approx. Train: 0.5774, Approx. Val: 0.6300, Time: 171.11454963684082s                                              \nTraining: 1178it [01:34, 12.43it/s, num_nodes=474446, gpu_cache_miss=1, cpu_cache_miss=0.0604]                                             \nEvaluating: 123it [00:14,  8.45it/s, num_nodes=462042, gpu_cache_miss=1, cpu_cache_miss=0.0603]\nEpoch 01, Loss: 1.1463, Approx. Train: 0.6384, Approx. Val: 0.6395, Time: 94.7821741104126s                                                \nTraining: 1178it [01:31, 12.82it/s, num_nodes=479331, gpu_cache_miss=1, cpu_cache_miss=0.0545]                                             \nEvaluating: 123it [00:14,  8.67it/s, num_nodes=463628, gpu_cache_miss=1, cpu_cache_miss=0.0546]\nEpoch 02, Loss: 1.1000, Approx. Train: 0.6501, Approx. Val: 0.6516, Time: 91.8746063709259s\n```\n\n### Performance Comparison on different `--batch-dependency`\n\n| batch-dependency | Epoch 1 (s) | Epoch 2 (s) | Epoch 3 (s) |\n| :--------------: | :---------: | :---------: | :---------: |\n|        1        |   171.114   |   94.782   |   91.875   |\n|        64        |   144.241   |   78.749   |   75.270   |\n|       4096       |   92.494   |   56.111   |   57.647   |\n\n### Effect of `--layer-dependency`\n\nBelow results demonstrate the effect of enabling `--layer-dependency` on epoch time when setting `--batch-dependency=1`.\n\n| layer-dependency | Epoch 1 (s) | Epoch 2 (s) | Epoch 3 (s) |\n| :--------------: | :---------: | :---------: | :---------: |\n|      False      |   171.114   |   94.782   |   91.875   |\n|       True       |   159.625   |   86.209   |   83.171   |\n\n## Compared to In-mem Performance\n\nThis part trains a three-layer GraphSAGE model for 3 epochs on `ogbn-papers100M` dataset with 20GB CPU cache and 5GB GPU cache, using neighbor sampling. We compare it to the in-mem performance with 5GB GPU cache. 
The following result demonstrates that, with sufficient cache memory, the performance of DiskBasedFeature is not bottlenecked by the cache itself and is comparable with in-memory feature stores. Note that the first epoch of training has to warm up the cache and therefore takes longer.\n\nInstruction:\n\n```\npython node_classification.py --gpu-cache-size-in-gigabytes=5 --cpu-cache-size-in-gigabytes=20 --dataset=ogbn-papers100M --epochs=3\n```\n\nResult:\n\n|  Feature Store  | Epoch 1 (s) | Epoch 2 (s) | Epoch 3 (s) |\n| :--------------: | :---------: | :---------: | :---------: |\n| DiskBasedFeature |   143.761   |   32.018   |   31.889   |\n|    In-memory    |   28.861   |   28.330   |   28.305   |\n"
  },
  {
    "path": "examples/graphbolt/disk_based_feature/node_classification.py",
    "content": "\"\"\"\nThis example references examples/graphbolt/pyg/labor/node_classification.py\n\"\"\"\n\nimport argparse\nimport time\n\nfrom copy import deepcopy\n\nimport dgl.graphbolt as gb\nimport dgl.nn as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\n\ndef accuracy(out, labels):\n    assert out.ndim == 2\n    assert out.size(0) == labels.size(0)\n    assert labels.ndim == 1 or (labels.ndim == 2 and labels.size(1) == 1)\n    labels = labels.flatten()\n    predictions = torch.argmax(out, 1)\n    return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hidden_size, out_size, num_layers, dropout):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # Three-layer GraphSAGE-mean.\n        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, \"mean\"))\n        for _ in range(num_layers - 2):\n            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.hidden_size = hidden_size\n        self.out_size = out_size\n        # Set the dtype for the layers manually.\n        self.set_layer_dtype(torch.float32)\n\n    def set_layer_dtype(self, _dtype):\n        for layer in self.layers:\n            for param in layer.parameters():\n                param.data = param.data.to(_dtype)\n\n    def forward(self, blocks, x):\n        hidden_x = x\n        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n            hidden_x = layer(block, hidden_x)\n            is_last_layer = layer_idx == len(self.layers) - 1\n            if not is_last_layer:\n                hidden_x = F.relu(hidden_x)\n                hidden_x = self.dropout(hidden_x)\n        return hidden_x\n\n    def inference(self, graph, features, dataloader, storage_device):\n    
    \"\"\"Conduct layer-wise inference to get all the node embeddings.\"\"\"\n        pin_memory = storage_device == \"pinned\"\n        buffer_device = torch.device(\"cpu\" if pin_memory else storage_device)\n\n        for layer_idx, layer in enumerate(self.layers):\n            is_last_layer = layer_idx == len(self.layers) - 1\n\n            y = torch.empty(\n                graph.total_num_nodes,\n                self.out_size if is_last_layer else self.hidden_size,\n                dtype=torch.float32,\n                device=buffer_device,\n                pin_memory=pin_memory,\n            )\n            for data in tqdm(dataloader):\n                # len(blocks) = 1\n                hidden_x = layer(data.blocks[0], data.node_features[\"feat\"])\n                if not is_last_layer:\n                    hidden_x = F.relu(hidden_x)\n                    hidden_x = self.dropout(hidden_x)\n                # By design, our output nodes are contiguous.\n                y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(\n                    buffer_device\n                )\n            if not is_last_layer:\n                features.update(\"node\", None, \"feat\", y)\n\n        return y\n\n\ndef create_dataloader(\n    graph, features, itemset, batch_size, fanout, device, job\n):\n\n    # Initialize an ItemSampler to sample mini-batches from the dataset.\n    datapipe = gb.ItemSampler(\n        itemset,\n        batch_size=batch_size,\n        shuffle=(job == \"train\"),\n        drop_last=(job == \"train\"),\n    )\n    # Copy the data to the specified device.\n    if args.graph_device != \"cpu\":\n        datapipe = datapipe.copy_to(device=device)\n    # Sample neighbors for each node in the mini-batch.\n    kwargs = (\n        {\n            # Layer dependency makes it so that the sampled neighborhoods across layers\n            # become correlated, reducing the total number of sampled unique nodes in a\n            # minibatch, thus reducing the amount of 
feature data requested.\n            \"layer_dependency\": args.layer_dependency,\n            # Batch dependency makes it so that the sampled neighborhoods across minibatches\n            # become correlated, reducing the total number of sampled unique nodes across\n            # minibatches, thus increasing temporal locality and reducing cache miss rates.\n            \"batch_dependency\": args.batch_dependency,\n        }\n        if args.sample_mode == \"sample_layer_neighbor\"\n        else {}\n    )\n    datapipe = getattr(datapipe, args.sample_mode)(\n        graph,\n        fanout if job != \"infer\" else [-1],\n        overlap_fetch=args.overlap_graph_fetch,\n        **kwargs,\n    )\n    # Copy the data to the specified device.\n    if args.feature_device != \"cpu\":\n        datapipe = datapipe.copy_to(device=device)\n    # Fetch node features for the sampled subgraph.\n    datapipe = datapipe.fetch_feature(\n        features,\n        node_feature_keys=[\"feat\"],\n        overlap_fetch=args.overlap_feature_fetch,\n    )\n    # Copy the data to the specified device.\n    if args.feature_device == \"cpu\":\n        datapipe = datapipe.copy_to(device=device)\n    # Create and return a DataLoader to handle data loading.\n    return gb.DataLoader(datapipe, num_workers=args.num_workers)\n\n\ndef train_step(minibatch, optimizer, model, loss_fn):\n    node_features = minibatch.node_features[\"feat\"]\n    labels = minibatch.labels\n    optimizer.zero_grad()\n    out = model(minibatch.blocks, node_features)\n    loss = loss_fn(out, labels)\n    num_correct = accuracy(out, labels) * labels.size(0)\n    loss.backward()\n    optimizer.step()\n    return loss.detach(), num_correct, labels.size(0)\n\n\ndef train_helper(\n    dataloader,\n    model,\n    optimizer,\n    loss_fn,\n    gpu_cache_miss_rate_fn,\n    cpu_cache_miss_rate_fn,\n    device,\n):\n    model.train()  # Set the model to training mode\n    total_loss = torch.zeros(1, device=device)  # Accumulator 
for the total loss\n    # Accumulator for the total number of correct predictions\n    total_correct = torch.zeros(1, dtype=torch.float64, device=device)\n    total_samples = 0  # Accumulator for the total number of samples processed\n    num_batches = 0  # Counter for the number of mini-batches processed\n    start = time.time()\n    dataloader = tqdm(dataloader, \"Training\")\n    for step, minibatch in enumerate(dataloader):\n        loss, num_correct, num_samples = train_step(\n            minibatch, optimizer, model, loss_fn\n        )\n        total_loss += loss\n        total_correct += num_correct\n        total_samples += num_samples\n        num_batches += 1\n        if step % 25 == 0:\n            # log every 25 steps for performance.\n            dataloader.set_postfix(\n                {\n                    \"num_nodes\": minibatch.node_ids().size(0),\n                    \"gpu_cache_miss\": gpu_cache_miss_rate_fn(),\n                    \"cpu_cache_miss\": cpu_cache_miss_rate_fn(),\n                }\n            )\n    train_loss = total_loss / num_batches\n    train_acc = total_correct / total_samples\n    end = time.time()\n    return train_loss, train_acc, end - start\n\n\ndef train(\n    train_dataloader,\n    valid_dataloader,\n    model,\n    gpu_cache_miss_rate_fn,\n    cpu_cache_miss_rate_fn,\n    device,\n):\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n    loss_fn = nn.CrossEntropyLoss()\n\n    best_model = None\n    best_model_acc = 0\n    best_model_epoch = -1\n\n    for epoch in range(args.epochs):\n        train_loss, train_acc, duration = train_helper(\n            train_dataloader,\n            model,\n            optimizer,\n            loss_fn,\n            gpu_cache_miss_rate_fn,\n            cpu_cache_miss_rate_fn,\n            device,\n        )\n        val_acc = evaluate(\n            model,\n            valid_dataloader,\n            gpu_cache_miss_rate_fn,\n            cpu_cache_miss_rate_fn,\n           
 device,\n        )\n        if val_acc > best_model_acc:\n            best_model_acc = val_acc\n            best_model = deepcopy(model.state_dict())\n            best_model_epoch = epoch\n        print(\n            f\"Epoch {epoch:02d}, Loss: {train_loss.item():.4f}, \"\n            f\"Approx. Train: {train_acc.item():.4f}, \"\n            f\"Approx. Val: {val_acc.item():.4f}, \"\n            f\"Time: {duration}s\"\n        )\n        if best_model_epoch + args.early_stopping_patience < epoch:\n            break\n    return best_model\n\n\n@torch.no_grad()\ndef layerwise_infer(\n    args,\n    graph,\n    features,\n    itemsets,\n    all_nodes_set,\n    model,\n):\n    model.eval()\n    dataloader = create_dataloader(\n        graph=graph,\n        features=features,\n        itemset=all_nodes_set,\n        batch_size=args.batch_size,\n        fanout=[-1],\n        device=args.device,\n        job=\"infer\",\n    )\n    pred = model.inference(graph, features, dataloader, args.feature_device)\n\n    metrics = {}\n    for split_name, itemset in itemsets.items():\n        nid, labels = itemset[:]\n        acc = accuracy(\n            pred[nid.to(pred.device)],\n            labels.to(pred.device),\n        )\n        metrics[split_name] = acc.item()\n\n    return metrics\n\n\ndef evaluate_step(minibatch, model):\n    node_features = minibatch.node_features[\"feat\"]\n    labels = minibatch.labels\n    out = model(minibatch.blocks, node_features)\n    num_correct = accuracy(out, labels) * labels.size(0)\n    return num_correct, labels.size(0)\n\n\n@torch.no_grad()\ndef evaluate(\n    model,\n    dataloader,\n    gpu_cache_miss_rate_fn,\n    cpu_cache_miss_rate_fn,\n    device,\n):\n    model.eval()\n    total_correct = torch.zeros(1, dtype=torch.float64, device=device)\n    total_samples = 0\n    val_dataloader_tqdm = tqdm(dataloader, \"Evaluating\")\n    for step, minibatch in enumerate(val_dataloader_tqdm):\n        num_correct, num_samples = 
evaluate_step(minibatch, model)\n        total_correct += num_correct\n        total_samples += num_samples\n        if step % 25 == 0:\n            val_dataloader_tqdm.set_postfix(\n                {\n                    \"num_nodes\": minibatch.node_ids().size(0),\n                    \"gpu_cache_miss\": gpu_cache_miss_rate_fn(),\n                    \"cpu_cache_miss\": cpu_cache_miss_rate_fn(),\n                }\n            )\n\n    return total_correct / total_samples\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Which dataset are you going to use?\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=9999999, help=\"Number of training epochs.\"\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=0.001,\n        help=\"Learning rate for optimization.\",\n    )\n    parser.add_argument(\"--num-hidden\", type=int, default=256)\n    parser.add_argument(\"--dropout\", type=float, default=0.2)\n    parser.add_argument(\n        \"--batch-size\", type=int, default=1024, help=\"Batch size for training.\"\n    )\n    parser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=0,\n        help=\"Number of workers for data loading.\",\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-products\",\n        choices=[\n            \"ogbn-arxiv\",\n            \"ogbn-products\",\n            \"ogbn-papers100M\",\n            \"igb-hom-tiny\",\n            \"igb-hom-small\",\n            \"igb-hom-medium\",\n            \"igb-hom-large\",\n            \"igb-hom\",\n        ],\n    )\n    parser.add_argument(\"--root\", type=str, default=\"datasets\")\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"10,10,10\",\n        help=\"Fan-out of neighbor sampling. len(fanout) determines the number of\"\n        \" GNN layers in your model. 
Default: 10,10,10\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"pinned-pinned-cuda\",\n        choices=[\n            \"cpu-cpu-cpu\",\n            \"cpu-cpu-cuda\",\n            \"cpu-pinned-cuda\",\n            \"pinned-pinned-cuda\",\n            \"cuda-pinned-cuda\",\n            \"cuda-cuda-cuda\",\n        ],\n        help=\"Graph storage - feature storage - Train device: 'cpu' for CPU and\"\n        \" RAM, 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.\",\n    )\n    parser.add_argument(\"--layer-dependency\", action=\"store_true\")\n    parser.add_argument(\"--batch-dependency\", type=int, default=1)\n    parser.add_argument(\n        \"--cpu-feature-cache-policy\",\n        type=str,\n        default=None,\n        choices=[\"s3-fifo\", \"sieve\", \"lru\", \"clock\"],\n        help=\"The cache policy for the CPU feature cache.\",\n    )\n    parser.add_argument(\n        \"--cpu-cache-size-in-gigabytes\",\n        type=float,\n        default=0,\n        help=\"The capacity of the CPU cache in GiB.\",\n    )\n    parser.add_argument(\n        \"--gpu-cache-size-in-gigabytes\",\n        type=float,\n        default=0,\n        help=\"The capacity of the GPU cache in GiB.\",\n    )\n    parser.add_argument(\"--early-stopping-patience\", type=int, default=25)\n    parser.add_argument(\n        \"--sample-mode\",\n        default=\"sample_neighbor\",\n        choices=[\"sample_neighbor\", \"sample_layer_neighbor\"],\n        help=\"The sampling function when doing layerwise sampling.\",\n    )\n    parser.add_argument(\"--precision\", type=str, default=\"high\")\n    parser.add_argument(\"--enable-inference\", action=\"store_true\")\n    return parser.parse_args()\n\n\ndef main():\n    torch.set_float32_matmul_precision(args.precision)\n    if not torch.cuda.is_available():\n        args.mode = \"cpu-cpu-cpu\"\n    print(f\"Training in {args.mode} mode.\")\n    args.graph_device, args.feature_device, 
args.device = args.mode.split(\"-\")\n    args.overlap_feature_fetch = args.feature_device == \"pinned\"\n    args.overlap_graph_fetch = args.graph_device == \"pinned\"\n\n    \"\"\"\n    Load and preprocess on-disk dataset.\n    We inspect the in_memory field of the feature_data in the YAML file and modify\n    it to False. This will make sure the feature_data is loaded as DiskBasedFeature.\n    \"\"\"\n    print(\"Loading data...\")\n    disk_based_feature_keys = None\n    if args.cpu_cache_size_in_gigabytes > 0:\n        disk_based_feature_keys = [(\"node\", None, \"feat\")]\n\n    dataset = gb.BuiltinDataset(args.dataset, root=args.root)\n    if disk_based_feature_keys is None:\n        disk_based_feature_keys = set()\n    for feature in dataset.yaml_data[\"feature_data\"]:\n        feature_key = (feature[\"domain\"], feature[\"type\"], feature[\"name\"])\n        # Set the in_memory setting to False without modifying YAML file.\n        if feature_key in disk_based_feature_keys:\n            feature[\"in_memory\"] = False\n    dataset = dataset.load()\n\n    # Move the dataset to the selected storage.\n    graph = (\n        dataset.graph.pin_memory_()\n        if args.graph_device == \"pinned\"\n        else dataset.graph.to(args.graph_device)\n    )\n    features = (\n        dataset.feature.pin_memory_()\n        if args.feature_device == \"pinned\"\n        else dataset.feature.to(args.feature_device)\n    )\n\n    train_set = dataset.tasks[0].train_set\n    valid_set = dataset.tasks[0].validation_set\n    test_set = dataset.tasks[0].test_set\n    all_nodes_set = dataset.all_nodes_set\n    args.fanout = list(map(int, args.fanout.split(\",\")))\n    num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n\n    \"\"\"\n    If the CPU cache size is greater than 0, we wrap the DiskBasedFeature to be\n    a CPUCachedFeature. This internally manages the CPU feature cache by the\n    specified cache replacement policy. 
This will reduce the amount of data\n    transferred during disk read operations for this feature.\n    \n    Note: It is advised to set the CPU cache size to be at least 4 times the number\n    of sampled nodes in a mini-batch, otherwise the feature fetcher might get into\n    a deadlock, causing a hang.\n    \"\"\"\n    if args.cpu_cache_size_in_gigabytes > 0 and isinstance(\n        features[(\"node\", None, \"feat\")], gb.DiskBasedFeature\n    ):\n        features[(\"node\", None, \"feat\")] = gb.cpu_cached_feature(\n            features[(\"node\", None, \"feat\")],\n            int(args.cpu_cache_size_in_gigabytes * 1024 * 1024 * 1024),\n            args.cpu_feature_cache_policy,\n            args.feature_device == \"pinned\",\n        )\n        cpu_cached_feature = features[(\"node\", None, \"feat\")]\n        cpu_cache_miss_rate_fn = lambda: cpu_cached_feature.miss_rate\n    else:\n        cpu_cache_miss_rate_fn = lambda: 1\n\n    \"\"\"\n    If the GPU cache size is greater than 0, we wrap the underlying feature store\n    to be a GPUCachedFeature. 
This will reduce the amount of data transferred during\n    host-to-device copy operations for this feature.\n    \"\"\"\n    if args.gpu_cache_size_in_gigabytes > 0 and args.feature_device != \"cuda\":\n        features[(\"node\", None, \"feat\")] = gb.gpu_cached_feature(\n            features[(\"node\", None, \"feat\")],\n            int(args.gpu_cache_size_in_gigabytes * 1024 * 1024 * 1024),\n        )\n        gpu_cached_feature = features[(\"node\", None, \"feat\")]\n        gpu_cache_miss_rate_fn = lambda: gpu_cached_feature.miss_rate\n    else:\n        gpu_cache_miss_rate_fn = lambda: 1\n\n    train_dataloader, valid_dataloader = (\n        create_dataloader(\n            graph=graph,\n            features=features,\n            itemset=itemset,\n            batch_size=args.batch_size,\n            fanout=args.fanout,\n            device=args.device,\n            job=job,\n        )\n        for itemset, job in zip([train_set, valid_set], [\"train\", \"evaluate\"])\n    )\n\n    in_channels = features.size(\"node\", None, \"feat\")[0]\n    model = SAGE(\n        in_channels,\n        args.num_hidden,\n        num_classes,\n        len(args.fanout),\n        args.dropout,\n    ).to(args.device)\n    assert len(args.fanout) == len(model.layers)\n\n    best_model = train(\n        train_dataloader,\n        valid_dataloader,\n        model,\n        gpu_cache_miss_rate_fn,\n        cpu_cache_miss_rate_fn,\n        args.device,\n    )\n    model.load_state_dict(best_model)\n\n    if args.enable_inference:\n        # Test the model.\n        print(\"Testing...\")\n        itemsets = {\"train\": train_set, \"val\": valid_set, \"test\": test_set}\n        final_acc = layerwise_infer(\n            args,\n            graph,\n            features,\n            itemsets,\n            all_nodes_set,\n            model,\n        )\n        print(\"Final accuracy values:\")\n        print(final_acc)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    
print(args)\n    main()\n"
  },
  {
    "path": "examples/graphbolt/lightning/README.md",
    "content": "# Node classification on homogeneous graph with GraphSAGE\n\n## Run on `ogbn-products` dataset\n\n### Command\n```\npython3 node_classification.py\n```\n\n### Results\n```\nValid Accuracy: 0.907\n```"
  },
  {
    "path": "examples/graphbolt/lightning/node_classification.py",
    "content": "\"\"\"\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> Instantiate DataModule\n│     │\n│     └───> Load dataset\n│     │\n│     └───> Create train and valid dataloader[HIGHLIGHT]\n│           │\n│           └───> ItemSampler (Distribute data to minibatchs)\n│           │\n│           └───> sample_neighbor or sample_layer_neighbor\n                  (Sample a subgraph for a minibatch)\n│           │\n│           └───> fetch_feature (Fetch features for the sampled subgraph)\n│\n├───> Instantiate GraphSAGE model\n│     │\n│     ├───> SAGEConvLayer (input to hidden)\n│     │\n│     └───> SAGEConvLayer (hidden to hidden)\n│     │\n│     └───> SAGEConvLayer (hidden to output)\n│     │\n│     └───> DropoutLayer\n│\n└───> Run\n      │\n      │\n      └───> Trainer[HIGHLIGHT]\n            │\n            ├───> SAGE.forward (GraphSAGE model forward pass)\n            │\n            └───> Validate\n\"\"\"\nimport argparse\n\nimport dgl.graphbolt as gb\nimport dgl.nn.pytorch as dglnn\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom pytorch_lightning import LightningDataModule, LightningModule, Trainer\nfrom pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\nfrom torchmetrics import Accuracy\n\n\nclass SAGE(LightningModule):\n    def __init__(self, in_feats, n_hidden, n_classes):\n        super().__init__()\n        self.save_hyperparameters()\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(0.5)\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.train_acc = Accuracy(task=\"multiclass\", num_classes=n_classes)\n        self.val_acc = Accuracy(task=\"multiclass\", num_classes=n_classes)\n\n    
def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = F.relu(h)\n                h = self.dropout(h)\n        return h\n\n    def log_node_and_edge_counts(self, blocks):\n        node_counts = [block.num_src_nodes() for block in blocks] + [\n            blocks[-1].num_dst_nodes()\n        ]\n        edge_counts = [block.num_edges() for block in blocks]\n        for i, c in enumerate(node_counts):\n            self.log(\n                f\"num_nodes/{i}\",\n                float(c),\n                prog_bar=True,\n                on_step=True,\n                on_epoch=False,\n            )\n            if i < len(edge_counts):\n                self.log(\n                    f\"num_edges/{i}\",\n                    float(edge_counts[i]),\n                    prog_bar=True,\n                    on_step=True,\n                    on_epoch=False,\n                )\n\n    def training_step(self, batch, batch_idx):\n        blocks = [block.to(\"cuda\") for block in batch.blocks]\n        x = batch.node_features[\"feat\"]\n        y = batch.labels.to(\"cuda\")\n        y_hat = self(blocks, x)\n        loss = F.cross_entropy(y_hat, y)\n        self.train_acc(torch.argmax(y_hat, 1), y)\n        self.log(\n            \"train_acc\",\n            self.train_acc,\n            prog_bar=True,\n            on_step=True,\n            on_epoch=False,\n        )\n        self.log_node_and_edge_counts(blocks)\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        blocks = [block.to(\"cuda\") for block in batch.blocks]\n        x = batch.node_features[\"feat\"]\n        y = batch.labels.to(\"cuda\")\n        y_hat = self(blocks, x)\n        self.val_acc(torch.argmax(y_hat, 1), y)\n        self.log(\n            \"val_acc\",\n            self.val_acc,\n            prog_bar=True,\n            
on_step=False,\n            on_epoch=True,\n            sync_dist=True,\n        )\n        self.log_node_and_edge_counts(blocks)\n\n    def configure_optimizers(self):\n        optimizer = torch.optim.Adam(\n            self.parameters(), lr=0.001, weight_decay=5e-4\n        )\n        return optimizer\n\n\nclass DataModule(LightningDataModule):\n    def __init__(self, dataset, fanouts, batch_size, num_workers):\n        super().__init__()\n        self.fanouts = fanouts\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n        self.feature_store = dataset.feature\n        self.graph = dataset.graph\n        self.train_set = dataset.tasks[0].train_set\n        self.valid_set = dataset.tasks[0].validation_set\n        self.num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n\n    def create_dataloader(self, node_set, is_train):\n        datapipe = gb.ItemSampler(\n            node_set,\n            batch_size=self.batch_size,\n            shuffle=True,\n            drop_last=True,\n        )\n        sampler = (\n            datapipe.sample_layer_neighbor\n            if is_train\n            else datapipe.sample_neighbor\n        )\n        datapipe = sampler(self.graph, self.fanouts)\n        datapipe = datapipe.fetch_feature(self.feature_store, [\"feat\"])\n        dataloader = gb.DataLoader(datapipe, num_workers=self.num_workers)\n        return dataloader\n\n    ########################################################################\n    # (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are\n    # essential components of the Lightning framework, defining how data is\n    # loaded during training and validation. 
In this example, we utilize a\n    # specialized 'graphbolt dataloader', which is composed of a series\n    # of datapipes, for these purposes.\n    ########################################################################\n    def train_dataloader(self):\n        return self.create_dataloader(self.train_set, is_train=True)\n\n    def val_dataloader(self):\n        return self.create_dataloader(self.valid_set, is_train=False)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"GNN baselines on ogbn-products data with GraphBolt\"\n    )\n    parser.add_argument(\n        \"--num_gpus\",\n        type=int,\n        default=1,\n        help=\"number of GPUs used for computing (default: 1)\",\n    )\n    parser.add_argument(\n        \"--batch_size\",\n        type=int,\n        default=1024,\n        help=\"input batch size for training (default: 1024)\",\n    )\n    parser.add_argument(\n        \"--epochs\",\n        type=int,\n        default=40,\n        help=\"number of epochs to train (default: 40)\",\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=0,\n        help=\"number of workers (default: 0)\",\n    )\n    args = parser.parse_args()\n\n    dataset = gb.BuiltinDataset(\"ogbn-products\").load()\n    datamodule = DataModule(\n        dataset,\n        [10, 10, 10],\n        args.batch_size,\n        args.num_workers,\n    )\n    in_size = dataset.feature.size(\"node\", None, \"feat\")[0]\n    model = SAGE(in_size, 256, datamodule.num_classes)\n\n    # Train.\n    checkpoint_callback = ModelCheckpoint(monitor=\"val_acc\", mode=\"max\")\n    early_stopping_callback = EarlyStopping(monitor=\"val_acc\", mode=\"max\")\n    ########################################################################\n    # (HIGHLIGHT) The `Trainer` is the key class in Lightning, which automates\n    # everything after defining `LightningModule` and\n    # `LightningDataModule`. 
More details can be found in\n    # https://lightning.ai/docs/pytorch/stable/common/trainer.html.\n    ########################################################################\n    trainer = Trainer(\n        accelerator=\"gpu\",\n        devices=args.num_gpus,\n        max_epochs=args.epochs,\n        callbacks=[checkpoint_callback, early_stopping_callback],\n    )\n    trainer.fit(model, datamodule=datamodule)\n"
  },
  {
    "path": "examples/graphbolt/link_prediction.py",
    "content": "\"\"\"\nThis script trains and tests a GraphSAGE model for link prediction on\nlarge graphs using graphbolt dataloader.\n\nPaper: [Inductive Representation Learning on Large Graphs]\n(https://arxiv.org/abs/1706.02216)\n\nUnlike previous dgl examples, we've utilized the newly defined dataloader\nfrom GraphBolt. This example will help you grasp how to build an end-to-end\ntraining pipeline using GraphBolt.\n\nWhile node classification predicts labels for nodes based on their\nlocal neighborhoods, link prediction assesses the likelihood of an edge\nexisting between two nodes, necessitating different sampling strategies\nthat account for pairs of nodes and their joint neighborhoods.\n\nTODO: Add the link_prediction.py example to core/graphsage.\nBefore reading this example, please familiar yourself with graphsage link\nprediction by reading the example in the\n`examples/core/graphsage/link_prediction.py`\n\nIf you want to train graphsage on a large graph in a distributed fashion, read\nthe example in the `examples/distributed/graphsage/`.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> OnDiskDataset pre-processing\n│\n├───> Instantiate SAGE model\n│\n├───> train\n│     │\n│     ├───> Get graphbolt dataloader (HIGHLIGHT)\n│     │\n│     └───> Training loop\n│           │\n│           ├───> SAGE.forward\n│           │\n│           └───> Validation set evaluation\n│\n└───> Test set evaluation\n\"\"\"\nimport argparse\nimport time\nfrom functools import partial\n\nimport dgl.graphbolt as gb\nimport dgl.nn as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\nfrom torchmetrics.retrieval import RetrievalMRR\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hidden_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, \"mean\"))\n        
self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \"mean\"))\n        self.hidden_size = hidden_size\n        self.predictor = nn.Sequential(\n            nn.Linear(hidden_size, hidden_size),\n            nn.ReLU(),\n            nn.Linear(hidden_size, hidden_size),\n            nn.ReLU(),\n            nn.Linear(hidden_size, 1),\n        )\n\n    def forward(self, blocks, x):\n        hidden_x = x\n        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n            hidden_x = layer(block, hidden_x)\n            is_last_layer = layer_idx == len(self.layers) - 1\n            if not is_last_layer:\n                hidden_x = F.relu(hidden_x)\n        return hidden_x\n\n    def inference(self, graph, features, dataloader, storage_device):\n        \"\"\"Conduct layer-wise inference to get all the node embeddings.\"\"\"\n        pin_memory = storage_device == \"pinned\"\n        buffer_device = torch.device(\"cpu\" if pin_memory else storage_device)\n\n        print(\"Start node embedding inference.\")\n        for layer_idx, layer in enumerate(self.layers):\n            is_last_layer = layer_idx == len(self.layers) - 1\n\n            y = torch.empty(\n                graph.total_num_nodes,\n                self.hidden_size,\n                dtype=torch.float32,\n                device=buffer_device,\n                pin_memory=pin_memory,\n            )\n            for data in tqdm.tqdm(dataloader):\n                # len(blocks) = 1\n                hidden_x = layer(data.blocks[0], data.node_features[\"feat\"])\n                if not is_last_layer:\n                    hidden_x = F.relu(hidden_x)\n                # By design, our seed nodes are contiguous.\n                y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(\n                    buffer_device, non_blocking=True\n                )\n            if not is_last_layer:\n                
features.update(\"node\", None, \"feat\", y)\n\n        return y\n\n\ndef create_dataloader(args, graph, features, itemset, is_train=True):\n    \"\"\"Get a GraphBolt version of a dataloader for link prediction tasks. This\n    function demonstrates how to utilize functional forms of datapipes in\n    GraphBolt. Alternatively, you can create a datapipe using its class\n    constructor. For a more detailed tutorial, please read the examples in\n    `dgl/notebooks/graphbolt/walkthrough.ipynb`.\n    \"\"\"\n\n    ############################################################################\n    # [Input]:\n    # 'itemset': The current dataset.\n    # 'args.train_batch_size' / 'args.eval_batch_size': Specify the number of\n    # samples to be processed together, referred to as a 'mini-batch'. (The term\n    # 'mini-batch' is used here to indicate a subset of the entire dataset that\n    # is processed together. This is in contrast to processing the entire\n    # dataset, known as a 'full batch'.)\n    # 'is_train': Determines whether data should be shuffled. (Shuffling is\n    # generally used only in training to improve model generalization. It's\n    # not used in validation and testing as the focus there is to evaluate\n    # performance rather than to learn from the data.)\n    # [Output]:\n    # An ItemSampler object for handling mini-batch sampling.\n    # [Role]:\n    # Initialize the ItemSampler to sample mini-batches from the dataset.\n    ############################################################################\n    datapipe = gb.ItemSampler(\n        itemset,\n        batch_size=args.train_batch_size if is_train else args.eval_batch_size,\n        shuffle=is_train,\n    )\n\n    ############################################################################\n    # [Input]:\n    # 'device': The device to copy the data to.\n    # [Output]:\n    # A CopyTo object to copy the data to the specified device. 
Copying here\n    # ensures that the rest of the operations run on the GPU.\n    ############################################################################\n    if args.storage_device != \"cpu\":\n        datapipe = datapipe.copy_to(device=args.device)\n\n    ############################################################################\n    # [Input]:\n    # 'args.neg_ratio': Specify the ratio of negative to positive samples.\n    # (E.g., if neg_ratio is 1, for each positive sample there will be 1\n    # negative sample.)\n    # 'graph': The overall network topology for negative sampling.\n    # [Output]:\n    # A UniformNegativeSampler object that will handle the generation of\n    # negative samples for link prediction tasks.\n    # [Role]:\n    # Initialize the UniformNegativeSampler for negative sampling in link\n    # prediction.\n    # [Note]:\n    # If 'is_train' is False, the UniformNegativeSampler will not be used.\n    # Since, in validation and testing, the itemset already contains the\n    # negative edges information.\n    ############################################################################\n    if is_train:\n        datapipe = datapipe.sample_uniform_negative(graph, args.neg_ratio)\n\n    ############################################################################\n    # [Input]:\n    # 'datapipe' is either 'ItemSampler' or 'UniformNegativeSampler' depending\n    # on whether training is needed ('is_train'),\n    # 'graph': The network topology for sampling.\n    # 'args.fanout': Number of neighbors to sample per node.\n    # [Output]:\n    # A NeighborSampler object to sample neighbors.\n    # [Role]:\n    # Initialize a neighbor sampler for sampling the neighborhoods of nodes.\n    ############################################################################\n    datapipe = datapipe.sample_neighbor(\n        graph,\n        args.fanout if is_train else [-1],\n        overlap_fetch=args.storage_device == \"pinned\",\n        
asynchronous=args.storage_device != \"cpu\",\n    )\n\n    ############################################################################\n    # [Input]:\n    # 'gb.exclude_seed_edges': Function to exclude seed edges, optionally\n    # including their reverse edges, from the sampled subgraphs in the\n    # minibatch.\n    # [Output]:\n    # A MiniBatchTransformer object with excluded seed edges.\n    # [Role]:\n    # During the training phase of link prediction, negative edges are\n    # sampled. It's essential to exclude the seed edges from the process\n    # to ensure that positive samples are not inadvertently included within\n    # the negative samples.\n    ############################################################################\n    if is_train and args.exclude_edges:\n        datapipe = datapipe.exclude_seed_edges(\n            include_reverse_edges=True,\n            asynchronous=args.storage_device != \"cpu\",\n        )\n\n    ############################################################################\n    # [Input]:\n    # 'features': The node features.\n    # 'node_feature_keys': The node feature keys (list) to be fetched.\n    # [Output]:\n    # A FeatureFetcher object to fetch node features.\n    # [Role]:\n    # Initialize a feature fetcher for fetching features of the sampled\n    # subgraphs.\n    ############################################################################\n    datapipe = datapipe.fetch_feature(features, node_feature_keys=[\"feat\"])\n\n    ############################################################################\n    # [Input]:\n    # 'device': The device to copy the data to.\n    # [Output]:\n    # A CopyTo object to copy the data to the specified device.\n    ############################################################################\n    if args.storage_device == \"cpu\":\n        datapipe = datapipe.copy_to(device=args.device)\n\n    ############################################################################\n    # 
[Input]:\n    # 'datapipe': The datapipe object to be used for data loading.\n    # 'args.num_workers': The number of processes to be used for data loading.\n    # [Output]:\n    # A DataLoader object to handle data loading.\n    # [Role]:\n    # Initialize a multi-process dataloader to load the data in parallel.\n    ############################################################################\n    dataloader = gb.DataLoader(\n        datapipe,\n        num_workers=args.num_workers,\n    )\n\n    # Return the fully-initialized DataLoader object.\n    return dataloader\n\n\n@torch.no_grad()\ndef compute_mrr(args, model, node_emb, seeds, labels, indexes):\n    \"\"\"Compute the Mean Reciprocal Rank (MRR) for given source and destination\n    nodes.\n\n    This function computes the MRR for a set of node pairs, dividing the task\n    into batches to handle potentially large graphs.\n    \"\"\"\n\n    preds = torch.empty(seeds.shape[0], device=indexes.device)\n    mrr = RetrievalMRR()\n    seeds_src, seeds_dst = seeds.T\n    # The constant is 1001 because the negative ratio in the `ogbl-citation2`\n    # dataset is 1000.\n    eval_size = args.eval_batch_size * 1001\n    # Loop over node pairs in batches.\n    for start in tqdm.trange(0, seeds_src.shape[0], eval_size, desc=\"Evaluate\"):\n        end = min(start + eval_size, seeds_src.shape[0])\n\n        # Fetch embeddings for current batch of source and destination nodes.\n        h_src = node_emb[seeds_src[start:end]].to(args.device)\n        h_dst = node_emb[seeds_dst[start:end]].to(args.device)\n\n        # Compute prediction scores using the model.\n        pred = model.predictor(h_src * h_dst).squeeze()\n        preds[start:end] = pred\n    return mrr(preds, labels, indexes=indexes)\n\n\n@torch.no_grad()\ndef evaluate(args, model, graph, features, all_nodes_set, valid_set, test_set):\n    \"\"\"Evaluate the model on validation and test sets.\"\"\"\n    model.eval()\n\n    dataloader = create_dataloader(\n       
 args, graph, features, all_nodes_set, is_train=False\n    )\n\n    # Compute node embeddings for the entire graph.\n    node_emb = model.inference(graph, features, dataloader, args.storage_device)\n    results = []\n\n    # Loop over both validation and test sets.\n    for split in [valid_set, test_set]:\n        # Unpack the item set.\n        seeds = split._items[0].to(node_emb.device)\n        labels = split._items[1].to(node_emb.device)\n        indexes = split._items[2].to(node_emb.device)\n\n        # Compute MRR values for the current split.\n        results.append(\n            compute_mrr(args, model, node_emb, seeds, labels, indexes)\n        )\n    return results\n\n\ndef train(args, model, graph, features, train_set):\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n    dataloader = create_dataloader(args, graph, features, train_set)\n\n    for epoch in range(args.epochs):\n        model.train()\n        total_loss = 0\n        start_epoch_time = time.time()\n        for step, data in tqdm.tqdm(enumerate(dataloader)):\n            # Get node pairs with labels for loss calculation.\n            compacted_seeds = data.compacted_seeds.T\n            labels = data.labels\n\n            node_feature = data.node_features[\"feat\"]\n            blocks = data.blocks\n\n            # Get the embeddings of the input nodes.\n            y = model(blocks, node_feature)\n            logits = model.predictor(\n                y[compacted_seeds[0]] * y[compacted_seeds[1]]\n            ).squeeze()\n\n            # Compute loss.\n            loss = F.binary_cross_entropy_with_logits(logits, labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            total_loss += loss.item()\n            if step + 1 == args.early_stop:\n                break\n\n        end_epoch_time = time.time()\n        print(\n            f\"Epoch {epoch:05d} | \"\n            f\"Loss {(total_loss) / (step + 1):.4f} | \"\n   
         f\"Time {(end_epoch_time - start_epoch_time):.4f} s\"\n        )\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"OGBL-Citation2 (GraphBolt)\")\n    parser.add_argument(\"--epochs\", type=int, default=10)\n    parser.add_argument(\"--lr\", type=float, default=0.0005)\n    parser.add_argument(\"--neg-ratio\", type=int, default=1)\n    parser.add_argument(\"--train-batch-size\", type=int, default=512)\n    parser.add_argument(\"--eval-batch-size\", type=int, default=1024)\n    parser.add_argument(\"--num-workers\", type=int, default=0)\n    parser.add_argument(\n        \"--early-stop\",\n        type=int,\n        default=0,\n        help=\"0 means no early stop, otherwise stop at the input-th step\",\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"15,10,5\",\n        help=\"Fan-out of neighbor sampling. Default: 15,10,5\",\n    )\n    parser.add_argument(\n        \"--exclude-edges\",\n        type=int,\n        default=1,\n        help=\"Whether to exclude reverse edges during sampling. 
Default: 1\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"pinned-cuda\",\n        choices=[\"cpu-cpu\", \"cpu-cuda\", \"pinned-cuda\", \"cuda-cuda\"],\n        help=\"Dataset storage placement and Train device: 'cpu' for CPU and RAM,\"\n        \" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.\",\n    )\n    return parser.parse_args()\n\n\ndef main(args):\n    if not torch.cuda.is_available():\n        args.mode = \"cpu-cpu\"\n    print(f\"Training in {args.mode} mode.\")\n    args.storage_device, args.device = args.mode.split(\"-\")\n    args.device = torch.device(args.device)\n\n    # Load and preprocess dataset.\n    print(\"Loading data\")\n    dataset = gb.BuiltinDataset(\"ogbl-citation2\").load()\n\n    # Move the dataset to the selected storage.\n    if args.storage_device == \"pinned\":\n        graph = dataset.graph.pin_memory_()\n        features = dataset.feature.pin_memory_()\n    else:\n        graph = dataset.graph.to(args.storage_device)\n        features = dataset.feature.to(args.storage_device)\n\n    train_set = dataset.tasks[0].train_set\n    args.fanout = list(map(int, args.fanout.split(\",\")))\n\n    in_size = features.size(\"node\", None, \"feat\")[0]\n    hidden_channels = 256\n    args.device = torch.device(args.device)\n    model = SAGE(in_size, hidden_channels).to(args.device)\n\n    # Model training.\n    print(\"Training...\")\n    train(args, model, graph, features, train_set)\n\n    # Test the model.\n    print(\"Testing...\")\n    test_set = dataset.tasks[0].test_set\n    valid_set = dataset.tasks[0].validation_set\n    all_nodes_set = dataset.all_nodes_set\n    valid_mrr, test_mrr = evaluate(\n        args, model, graph, features, all_nodes_set, valid_set, test_set\n    )\n    print(\n        f\"Validation MRR {valid_mrr.item():.4f}, \"\n        f\"Test MRR {test_mrr.item():.4f}\"\n    )\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/graphbolt/node_classification.py",
    "content": "\"\"\"\nThis script trains and tests a GraphSAGE model for node classification\non large graphs using GraphBolt dataloader.\n\nPaper: [Inductive Representation Learning on Large Graphs]\n(https://arxiv.org/abs/1706.02216)\n\nUnlike previous dgl examples, we've utilized the newly defined dataloader\nfrom GraphBolt. This example will help you grasp how to build an end-to-end\ntraining pipeline using GraphBolt.\n\nBefore reading this example, please familiar yourself with graphsage node\nclassification by reading the example in the\n`examples/core/graphsage/node_classification.py`. This introduction,\n[A Blitz Introduction to Node Classification with DGL]\n(https://docs.dgl.ai/tutorials/blitz/1_introduction.html), might be helpful.\n\nIf you want to train graphsage on a large graph in a distributed fashion,\nplease read the example in the `examples/distributed/graphsage/`.\n\nThis flowchart describes the main functional sequence of the provided example:\nmain\n│\n├───> OnDiskDataset pre-processing\n│\n├───> Instantiate SAGE model\n│\n├───> train\n│     │\n│     ├───> Get graphbolt dataloader (HIGHLIGHT)\n│     │\n│     └───> Training loop\n│           │\n│           ├───> SAGE.forward\n│           │\n│           └───> Validation set evaluation\n│\n└───> All nodes set inference & Test set evaluation\n\"\"\"\nimport argparse\nimport time\n\nimport dgl.graphbolt as gb\nimport dgl.nn as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nfrom tqdm import tqdm\n\n\ndef create_dataloader(\n    graph, features, itemset, batch_size, fanout, device, num_workers, job\n):\n    \"\"\"\n    [HIGHLIGHT]\n    Get a GraphBolt version of a dataloader for node classification tasks.\n    This function demonstrates how to utilize functional forms of datapipes in\n    GraphBolt. 
For a more detailed tutorial, please read the examples in\n    `dgl/notebooks/graphbolt/walkthrough.ipynb`.\n    Alternatively, you can create a datapipe using its class constructor.\n\n    Parameters\n    ----------\n    job : one of [\"train\", \"evaluate\", \"infer\"]\n        The stage where dataloader is created, with options \"train\", \"evaluate\"\n        and \"infer\".\n    Other parameters are explicated in the comments below.\n    \"\"\"\n\n    ############################################################################\n    # [Step-1]:\n    # gb.ItemSampler()\n    # [Input]:\n    # 'itemset': The current dataset. (e.g. `train_set` or `valid_set`)\n    # 'batch_size': Specify the number of samples to be processed together,\n    # referred to as a 'mini-batch'. (The term 'mini-batch' is used here to\n    # indicate a subset of the entire dataset that is processed together. This\n    # is in contrast to processing the entire dataset, known as a 'full batch'.)\n    # 'job': Determines whether data should be shuffled. (Shuffling is\n    # generally used only in training to improve model generalization. It's\n    # not used in validation and testing as the focus there is to evaluate\n    # performance rather than to learn from the data.)\n    # [Output]:\n    # An ItemSampler object for handling mini-batch sampling.\n    # [Role]:\n    # Initialize the ItemSampler to sample mini-batche from the dataset.\n    ############################################################################\n    datapipe = gb.ItemSampler(\n        itemset, batch_size=batch_size, shuffle=(job == \"train\")\n    )\n\n    ############################################################################\n    # [Step-2]:\n    # self.copy_to()\n    # [Input]:\n    # 'device': The device to copy the data to.\n    # [Output]:\n    # A CopyTo object to copy the data to the specified device. 
Copying here\n    # ensures that the rest of the operations run on the GPU.\n    ############################################################################\n    if args.storage_device != \"cpu\":\n        datapipe = datapipe.copy_to(device=device)\n\n    ############################################################################\n    # [Step-3]:\n    # self.sample_neighbor()\n    # [Input]:\n    # 'graph': The network topology for sampling.\n    # '[-1] or fanout': Number of neighbors to sample per node. In\n    # training or validation, the length of `fanout` should be equal to the\n    # number of layers in the model. In inference, this parameter is set to\n    # [-1], indicating that all neighbors of a node are sampled.\n    # [Output]:\n    # A NeighborSampler object to sample neighbors.\n    # [Role]:\n    # Initialize a neighbor sampler for sampling the neighborhoods of nodes.\n    ############################################################################\n    datapipe = getattr(datapipe, args.sample_mode)(\n        graph,\n        fanout if job != \"infer\" else [-1],\n        overlap_fetch=args.storage_device == \"pinned\",\n        asynchronous=args.storage_device != \"cpu\",\n    )\n\n    ############################################################################\n    # [Step-4]:\n    # self.fetch_feature()\n    # [Input]:\n    # 'features': The node features.\n    # 'node_feature_keys': The keys of the node features to be fetched.\n    # [Output]:\n    # A FeatureFetcher object to fetch node features.\n    # [Role]:\n    # Initialize a feature fetcher for fetching features of the sampled\n    # subgraphs.\n    ############################################################################\n    datapipe = datapipe.fetch_feature(features, node_feature_keys=[\"feat\"])\n\n    ############################################################################\n    # [Step-5]:\n    # self.copy_to()\n    # [Input]:\n    # 'device': The device to copy the data 
to.\n    # [Output]:\n    # A CopyTo object to copy the data to the specified device.\n    ############################################################################\n    if args.storage_device == \"cpu\":\n        datapipe = datapipe.copy_to(device=device)\n\n    ############################################################################\n    # [Step-6]:\n    # gb.DataLoader()\n    # [Input]:\n    # 'datapipe': The datapipe object to be used for data loading.\n    # 'num_workers': The number of processes to be used for data loading.\n    # [Output]:\n    # A DataLoader object to handle data loading.\n    # [Role]:\n    # Initialize a multi-process dataloader to load the data in parallel.\n    ############################################################################\n    dataloader = gb.DataLoader(datapipe, num_workers=num_workers)\n\n    # Return the fully-initialized DataLoader object.\n    return dataloader\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hidden_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # Three-layer GraphSAGE-mean.\n        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, \"mean\"))\n        self.dropout = nn.Dropout(0.5)\n        self.hidden_size = hidden_size\n        self.out_size = out_size\n        # Set the dtype for the layers manually.\n        self.set_layer_dtype(torch.float32)\n\n    def set_layer_dtype(self, _dtype):\n        for layer in self.layers:\n            for param in layer.parameters():\n                param.data = param.data.to(_dtype)\n\n    def forward(self, blocks, x):\n        hidden_x = x\n        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n            hidden_x = layer(block, hidden_x)\n            is_last_layer = layer_idx == len(self.layers) - 1\n            
if not is_last_layer:\n                hidden_x = F.relu(hidden_x)\n                hidden_x = self.dropout(hidden_x)\n        return hidden_x\n\n    def inference(self, graph, features, dataloader, storage_device):\n        \"\"\"Conduct layer-wise inference to get all the node embeddings.\"\"\"\n        pin_memory = storage_device == \"pinned\"\n        buffer_device = torch.device(\"cpu\" if pin_memory else storage_device)\n\n        for layer_idx, layer in enumerate(self.layers):\n            is_last_layer = layer_idx == len(self.layers) - 1\n\n            y = torch.empty(\n                graph.total_num_nodes,\n                self.out_size if is_last_layer else self.hidden_size,\n                dtype=torch.float32,\n                device=buffer_device,\n                pin_memory=pin_memory,\n            )\n            for data in tqdm(dataloader):\n                # len(blocks) = 1\n                hidden_x = layer(data.blocks[0], data.node_features[\"feat\"])\n                if not is_last_layer:\n                    hidden_x = F.relu(hidden_x)\n                    hidden_x = self.dropout(hidden_x)\n                # By design, our output nodes are contiguous.\n                y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(\n                    buffer_device\n                )\n            if not is_last_layer:\n                features.update(\"node\", None, \"feat\", y)\n\n        return y\n\n\n@torch.no_grad()\ndef layerwise_infer(\n    args, graph, features, test_set, all_nodes_set, model, num_classes\n):\n    model.eval()\n    dataloader = create_dataloader(\n        graph=graph,\n        features=features,\n        itemset=all_nodes_set,\n        batch_size=4 * args.batch_size,\n        fanout=[-1],\n        device=args.device,\n        num_workers=args.num_workers,\n        job=\"infer\",\n    )\n    pred = model.inference(graph, features, dataloader, args.storage_device)\n    pred = pred[test_set._items[0]]\n    label = 
test_set._items[1].to(pred.device)\n\n    return MF.accuracy(\n        pred,\n        label,\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\n@torch.no_grad()\ndef evaluate(args, model, graph, features, itemset, num_classes):\n    model.eval()\n    y = []\n    y_hats = []\n    dataloader = create_dataloader(\n        graph=graph,\n        features=features,\n        itemset=itemset,\n        batch_size=args.batch_size,\n        fanout=args.fanout,\n        device=args.device,\n        num_workers=args.num_workers,\n        job=\"evaluate\",\n    )\n\n    for step, data in tqdm(enumerate(dataloader), \"Evaluating\"):\n        x = data.node_features[\"feat\"]\n        y.append(data.labels)\n        y_hats.append(model(data.blocks, x))\n\n    return MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(y),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\ndef train(args, graph, features, train_set, valid_set, num_classes, model):\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=5e-4\n    )\n    dataloader = create_dataloader(\n        graph=graph,\n        features=features,\n        itemset=train_set,\n        batch_size=args.batch_size,\n        fanout=args.fanout,\n        device=args.device,\n        num_workers=args.num_workers,\n        job=\"train\",\n    )\n\n    for epoch in range(args.epochs):\n        t0 = time.time()\n        model.train()\n        total_loss = 0\n        for step, data in tqdm(enumerate(dataloader), \"Training\"):\n            # The input features from the source nodes in the first layer's\n            # computation graph.\n            x = data.node_features[\"feat\"]\n\n            # The ground truth labels from the destination nodes\n            # in the last layer's computation graph.\n            y = data.labels\n\n            y_hat = model(data.blocks, x)\n\n            # Compute loss.\n            loss = F.cross_entropy(y_hat, y)\n\n   
         optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            total_loss += loss.item()\n\n        t1 = time.time()\n        # Evaluate the model.\n        acc = evaluate(args, model, graph, features, valid_set, num_classes)\n        print(\n            f\"Epoch {epoch:05d} | Loss {total_loss / (step + 1):.4f} | \"\n            f\"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}\"\n        )\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"A script trains and tests a GraphSAGE model \"\n        \"for node classification using GraphBolt dataloader.\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=10, help=\"Number of training epochs.\"\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=1e-3,\n        help=\"Learning rate for optimization.\",\n    )\n    parser.add_argument(\n        \"--batch-size\", type=int, default=1024, help=\"Batch size for training.\"\n    )\n    parser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=0,\n        help=\"Number of workers for data loading.\",\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"10,10,10\",\n        help=\"Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)\"\n        \" identical with the number of layers in your model. Default: 10,10,10\",\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-products\",\n        choices=[\n            \"ogbn-arxiv\",\n            \"ogbn-products\",\n            \"ogbn-papers100M\",\n            \"igb-hom-tiny\",\n            \"igb-hom-small\",\n            \"igb-hom-medium\",\n            \"igb-hom-large\",\n            \"igb-hom\",\n        ],\n        help=\"The dataset we can use for node classification example. 
Currently\"\n        \" ogbn-products, ogbn-arxiv, ogbn-papers100M and\"\n        \" igb-hom-[tiny|small|medium|large] and igb-hom datasets are supported.\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"pinned-cuda\",\n        choices=[\"cpu-cpu\", \"cpu-cuda\", \"pinned-cuda\", \"cuda-cuda\"],\n        help=\"Dataset storage placement and Train device: 'cpu' for CPU and RAM,\"\n        \" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.\",\n    )\n    parser.add_argument(\n        \"--sample-mode\",\n        default=\"sample_neighbor\",\n        choices=[\"sample_neighbor\", \"sample_layer_neighbor\"],\n        help=\"The sampling function when doing layerwise sampling.\",\n    )\n    return parser.parse_args()\n\n\ndef main(args):\n    if not torch.cuda.is_available():\n        args.mode = \"cpu-cpu\"\n    print(f\"Training in {args.mode} mode.\")\n    args.storage_device, args.device = args.mode.split(\"-\")\n    args.device = torch.device(args.device)\n\n    # Load and preprocess dataset.\n    print(\"Loading data...\")\n    dataset = gb.BuiltinDataset(args.dataset).load()\n\n    # Move the dataset to the selected storage.\n    if args.storage_device == \"pinned\":\n        graph = dataset.graph.pin_memory_()\n        features = dataset.feature.pin_memory_()\n    else:\n        graph = dataset.graph.to(args.storage_device)\n        features = dataset.feature.to(args.storage_device)\n\n    train_set = dataset.tasks[0].train_set\n    valid_set = dataset.tasks[0].validation_set\n    test_set = dataset.tasks[0].test_set\n    all_nodes_set = dataset.all_nodes_set\n    args.fanout = list(map(int, args.fanout.split(\",\")))\n\n    num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n\n    in_size = features.size(\"node\", None, \"feat\")[0]\n    hidden_size = 256\n    out_size = num_classes\n\n    model = SAGE(in_size, hidden_size, out_size)\n    assert len(args.fanout) == len(model.layers)\n    model = 
model.to(args.device)\n\n    # Model training.\n    print(\"Training...\")\n    train(args, graph, features, train_set, valid_set, num_classes, model)\n\n    # Test the model.\n    print(\"Testing...\")\n    test_acc = layerwise_infer(\n        args,\n        graph,\n        features,\n        test_set,\n        all_nodes_set,\n        model,\n        num_classes,\n    )\n    print(f\"Test accuracy {test_acc.item():.4f}\")\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/graphbolt/pyg/README.md",
    "content": "##  Overview\n\nThis project demonstrates the training and evaluation of a GraphSAGE model for node classification on large graphs. The example utilizes GraphBolt for efficient data handling and PyG for the GNN training.\n\n\n# Node classification on graph\n\nThis example aims to demonstrate how to run node classification task on heterogeneous graph with **GraphBolt**. \n\n##  Model\n\nThe model is a three-layer GraphSAGE network implemented using PyTorch Geometric's SAGEConv layers.\n\n\n## Default Run on `ogbn-arxiv` dataset\n\n```\npython node_classification.py\n```\n\n\n\n\n## Accuracies\n```\nFinal performance(for ogbn-arxiv): \nAll runs:\nHighest Train: 62.26\nHighest Valid: 59.89\nFinal Train: 62.26\nFinal Test: 52.78\n```\n\n\n\n## Run on `ogbn-products` dataset\n\n### Sample on CPU and train/infer on CPU\n\n```\npython node_classification.py --dataset ogbn-products\n```\n\n## Accuracies\n```\nFinal performance(for ogbn-products): \nAll runs:\nHighest Train: 90.79\nHighest Valid: 89.86\nFinal Train: 90.79\nFinal Test: 75.24\n```\n\n\n\n\n\n"
  },
  {
    "path": "examples/graphbolt/pyg/hetero/node_classification.py",
    "content": "\"\"\"\nThis script is a PyG counterpart of ``/examples/graphbolt/rgcn/hetero_rgcn.py``.\n\"\"\"\n\nimport argparse\nimport time\n\nimport dgl.graphbolt as gb\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch_geometric.nn import SimpleConv\nfrom tqdm import tqdm\n\n\ndef accuracy(out, labels):\n    assert out.ndim == 2\n    assert out.size(0) == labels.size(0)\n    assert labels.ndim == 1 or (labels.ndim == 2 and labels.size(1) == 1)\n    labels = labels.flatten()\n    predictions = torch.argmax(out, 1)\n    return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)\n\n\ndef create_dataloader(\n    graph,\n    features,\n    itemset,\n    batch_size,\n    fanout,\n    device,\n    job,\n):\n    \"\"\"Create a GraphBolt dataloader for training, validation or testing.\"\"\"\n    datapipe = gb.ItemSampler(\n        itemset,\n        batch_size=batch_size,\n        shuffle=(job == \"train\"),\n        drop_last=(job == \"train\"),\n    )\n    need_copy = True\n    # Copy the data to the specified device.\n    if args.graph_device != \"cpu\" and need_copy:\n        datapipe = datapipe.copy_to(device=device)\n        need_copy = False\n    # Sample neighbors for each node in the mini-batch.\n    datapipe = getattr(datapipe, args.sample_mode)(\n        graph,\n        fanout if job != \"infer\" else [-1],\n        overlap_fetch=args.overlap_graph_fetch,\n        num_gpu_cached_edges=args.num_gpu_cached_edges,\n        gpu_cache_threshold=args.gpu_graph_caching_threshold,\n        asynchronous=args.graph_device != \"cpu\",\n    )\n    # Copy the data to the specified device.\n    if args.feature_device != \"cpu\" and need_copy:\n        datapipe = datapipe.copy_to(device=device)\n        need_copy = False\n\n    node_feature_keys = {\"paper\": [\"feat\"], \"author\": [\"feat\"]}\n    if args.dataset == \"ogb-lsc-mag240m\":\n        node_feature_keys[\"institution\"] = [\"feat\"]\n    if \"igb-het\" in 
args.dataset:\n        node_feature_keys[\"institute\"] = [\"feat\"]\n        node_feature_keys[\"fos\"] = [\"feat\"]\n    # Fetch node features for the sampled subgraph.\n    datapipe = datapipe.fetch_feature(\n        features,\n        node_feature_keys,\n        overlap_fetch=args.overlap_feature_fetch,\n    )\n\n    # Copy the data to the specified device.\n    if need_copy:\n        datapipe = datapipe.copy_to(device=device)\n    # Create and return a DataLoader to handle data loading.\n    return gb.DataLoader(datapipe, num_workers=args.num_workers)\n\n\nclass RelGraphConvLayer(nn.Module):\n    def __init__(\n        self,\n        in_size,\n        out_size,\n        ntypes,\n        etypes,\n        activation,\n        dropout=0.0,\n    ):\n        super().__init__()\n        self.in_size = in_size\n        self.out_size = out_size\n        self.activation = activation\n\n        # Create a separate convolution layer for each relationship. PyG's\n        # SimpleConv does not have any weights and only performs message passing\n        # and aggregation.\n        self.convs = nn.ModuleDict(\n            {etype: SimpleConv(aggr=\"mean\") for etype in etypes}\n        )\n\n        # Create a separate Linear layer for each relationship. Each\n        # relationship has its own weights which will be applied to the node\n        # features before performing convolution.\n        self.weight = nn.ModuleDict(\n            {\n                etype: nn.Linear(in_size, out_size, bias=False)\n                for etype in etypes\n            }\n        )\n\n        # Create a separate Linear layer for each node type.\n        # loop_weights are used to update the output embedding of each target node\n        # based on its own features, thereby allowing the model to refine the node\n        # representations. Note that this does not imply the existence of self-loop\n        # edges in the graph. 
It is similar to residual connection.\n        self.loop_weights = nn.ModuleDict(\n            {ntype: nn.Linear(in_size, out_size, bias=True) for ntype in ntypes}\n        )\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, subgraph, x):\n        # Create a dictionary of node features for the destination nodes in\n        # the graph. We slice the node features according to the number of\n        # destination nodes of each type. This is necessary because when\n        # incorporating the effect of self-loop edges, we perform computations\n        # only on the destination nodes' features. By doing so, we ensure the\n        # feature dimensions match and prevent any misuse of incorrect node\n        # features.\n        (h, h_dst), edge_index, size = subgraph.to_pyg(x)\n\n        h_out = {}\n        for etype in edge_index:\n            src_ntype, _, dst_ntype = gb.etype_str_to_tuple(etype)\n            # h_dst is unused in SimpleConv.\n            t = self.convs[etype](\n                (h[src_ntype], h_dst[dst_ntype]),\n                edge_index[etype],\n                size=size[etype],\n            )\n            t = self.weight[etype](t)\n            if dst_ntype in h_out:\n                h_out[dst_ntype] += t\n            else:\n                h_out[dst_ntype] = t\n\n        def _apply(ntype, x):\n            # Apply the `loop_weight` to the input node features, effectively\n            # acting as a residual connection. This allows the model to refine\n            # node embeddings based on its current features.\n            x = x + self.loop_weights[ntype](h_dst[ntype])\n            return self.dropout(self.activation(x))\n\n        # Apply the function defined above for each node type. 
This will update\n        # the node features using the `loop_weights`, apply the activation\n        # function and dropout.\n        return {ntype: _apply(ntype, h) for ntype, h in h_out.items()}\n\n\nclass EntityClassify(nn.Module):\n    def __init__(self, graph, in_size, hidden_size, out_size, n_layers):\n        super(EntityClassify, self).__init__()\n        self.layers = nn.ModuleList()\n        sizes = [in_size] + [hidden_size] * (n_layers - 1) + [out_size]\n        for i in range(n_layers):\n            self.layers.append(\n                RelGraphConvLayer(\n                    sizes[i],\n                    sizes[i + 1],\n                    graph.node_type_to_id.keys(),\n                    graph.edge_type_to_id.keys(),\n                    activation=F.relu if i != n_layers - 1 else lambda x: x,\n                    dropout=0.5,\n                )\n            )\n\n    def forward(self, subgraphs, h):\n        for layer, subgraph in zip(self.layers, subgraphs):\n            h = layer(subgraph, h)\n        return h\n\n\n@torch.compile\ndef evaluate_step(minibatch, model):\n    category = \"paper\"\n    node_features = {\n        ntype: feat.float()\n        for (ntype, name), feat in minibatch.node_features.items()\n        if name == \"feat\"\n    }\n    labels = minibatch.labels[category].long()\n    out = model(minibatch.sampled_subgraphs, node_features)[category]\n    num_correct = accuracy(out, labels) * labels.size(0)\n    return num_correct, labels.size(0)\n\n\n@torch.no_grad()\ndef evaluate(\n    model,\n    dataloader,\n    gpu_cache_miss_rate_fn,\n    cpu_cache_miss_rate_fn,\n    device,\n):\n    model.eval()\n    total_correct = torch.zeros(1, dtype=torch.float64, device=device)\n    total_samples = 0\n    dataloader = tqdm(dataloader, desc=\"Evaluating\")\n    for step, minibatch in enumerate(dataloader):\n        num_correct, num_samples = evaluate_step(minibatch, model)\n        total_correct += num_correct\n        total_samples += 
num_samples\n        if step % 15 == 0:\n            num_nodes = sum(id.size(0) for id in minibatch.node_ids().values())\n            dataloader.set_postfix(\n                {\n                    \"num_nodes\": num_nodes,\n                    \"gpu_cache_miss\": gpu_cache_miss_rate_fn(),\n                    \"cpu_cache_miss\": cpu_cache_miss_rate_fn(),\n                }\n            )\n\n    return total_correct / total_samples\n\n\n@torch.compile\ndef train_step(minibatch, optimizer, model, loss_fn):\n    category = \"paper\"\n    node_features = {\n        ntype: feat.float()\n        for (ntype, name), feat in minibatch.node_features.items()\n        if name == \"feat\"\n    }\n    labels = minibatch.labels[category].long()\n    optimizer.zero_grad()\n    out = model(minibatch.sampled_subgraphs, node_features)[category]\n    loss = loss_fn(out, labels)\n    # https://github.com/pytorch/pytorch/issues/133942\n    # num_correct = accuracy(out, labels) * labels.size(0)\n    num_correct = torch.zeros(1, dtype=torch.float64, device=out.device)\n    loss.backward()\n    optimizer.step()\n    return loss.detach(), num_correct, labels.size(0)\n\n\ndef train_helper(\n    dataloader,\n    model,\n    optimizer,\n    loss_fn,\n    gpu_cache_miss_rate_fn,\n    cpu_cache_miss_rate_fn,\n    device,\n):\n    model.train()\n    total_loss = torch.zeros(1, device=device)\n    total_correct = torch.zeros(1, dtype=torch.float64, device=device)\n    total_samples = 0\n    start = time.time()\n    dataloader = tqdm(dataloader, \"Training\")\n    for step, minibatch in enumerate(dataloader):\n        loss, num_correct, num_samples = train_step(\n            minibatch, optimizer, model, loss_fn\n        )\n        total_loss += loss * num_samples\n        total_correct += num_correct\n        total_samples += num_samples\n        if step % 15 == 0:\n            # log every 15 steps for performance.\n            num_nodes = sum(id.size(0) for id in minibatch.node_ids().values())\n  
          dataloader.set_postfix(\n                {\n                    \"num_nodes\": num_nodes,\n                    \"gpu_cache_miss\": gpu_cache_miss_rate_fn(),\n                    \"cpu_cache_miss\": cpu_cache_miss_rate_fn(),\n                }\n            )\n    loss = total_loss / total_samples\n    acc = total_correct / total_samples\n    end = time.time()\n    return loss, acc, end - start\n\n\ndef train(\n    train_dataloader,\n    valid_dataloader,\n    model,\n    gpu_cache_miss_rate_fn,\n    cpu_cache_miss_rate_fn,\n    device,\n):\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n    loss_fn = nn.CrossEntropyLoss()\n\n    for epoch in range(args.epochs):\n        train_loss, train_acc, duration = train_helper(\n            train_dataloader,\n            model,\n            optimizer,\n            loss_fn,\n            gpu_cache_miss_rate_fn,\n            cpu_cache_miss_rate_fn,\n            device,\n        )\n        val_acc = evaluate(\n            model,\n            valid_dataloader,\n            gpu_cache_miss_rate_fn,\n            cpu_cache_miss_rate_fn,\n            device,\n        )\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {train_loss.item():.4f}, \"\n            f\"Approx. Train: {train_acc.item():.4f}, \"\n            f\"Approx. 
Val: {val_acc.item():.4f}, \"\n            f\"Time: {duration}s\"\n        )\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"GraphBolt PyG R-SAGE\")\n    parser.add_argument(\n        \"--epochs\", type=int, default=10, help=\"Number of training epochs.\"\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=0.001,\n        help=\"Learning rate for optimization.\",\n    )\n    parser.add_argument(\"--num-hidden\", type=int, default=1024)\n    parser.add_argument(\n        \"--batch-size\", type=int, default=1024, help=\"Batch size for training.\"\n    )\n    parser.add_argument(\"--num_workers\", type=int, default=0)\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogb-lsc-mag240m\",\n        choices=[\n            \"ogb-lsc-mag240m\",\n            \"igb-het-tiny\",\n            \"igb-het-small\",\n            \"igb-het-medium\",\n        ],\n        help=\"Dataset name. Possible values: ogb-lsc-mag240m, igb-het-[tiny|small|medium].\",\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"25,10\",\n        help=\"Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)\"\n        \" identical with the number of layers in your model. 
Default: 25,10\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"pinned-pinned-cuda\",\n        choices=[\n            \"cpu-cpu-cpu\",\n            \"cpu-cpu-cuda\",\n            \"cpu-pinned-cuda\",\n            \"pinned-pinned-cuda\",\n            \"cuda-pinned-cuda\",\n            \"cuda-cuda-cuda\",\n        ],\n        help=\"Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,\"\n        \" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.\",\n    )\n    parser.add_argument(\n        \"--sample-mode\",\n        default=\"sample_neighbor\",\n        choices=[\"sample_neighbor\", \"sample_layer_neighbor\"],\n        help=\"The sampling function when doing layerwise sampling.\",\n    )\n    parser.add_argument(\n        \"--cpu-feature-cache-policy\",\n        type=str,\n        default=None,\n        choices=[\"s3-fifo\", \"sieve\", \"lru\", \"clock\"],\n        help=\"The cache policy for the CPU feature cache.\",\n    )\n    parser.add_argument(\n        \"--cpu-cache-size\",\n        type=float,\n        default=0,\n        help=\"The capacity of the CPU feature cache in GiB.\",\n    )\n    parser.add_argument(\n        \"--gpu-cache-size\",\n        type=float,\n        default=0,\n        help=\"The capacity of the GPU feature cache in GiB.\",\n    )\n    parser.add_argument(\n        \"--num-gpu-cached-edges\",\n        type=int,\n        default=0,\n        help=\"The number of edges to be cached from the graph on the GPU.\",\n    )\n    parser.add_argument(\n        \"--gpu-graph-caching-threshold\",\n        type=int,\n        default=1,\n        help=\"The number of accesses after which a vertex neighborhood will be cached.\",\n    )\n    parser.add_argument(\"--precision\", type=str, default=\"high\")\n    return parser.parse_args()\n\n\ndef main():\n    torch.set_float32_matmul_precision(args.precision)\n    if not torch.cuda.is_available():\n        args.mode = \"cpu-cpu-cpu\"\n    
print(f\"Training in {args.mode} mode.\")\n    args.graph_device, args.feature_device, args.device = args.mode.split(\"-\")\n    args.overlap_feature_fetch = args.feature_device == \"pinned\"\n    args.overlap_graph_fetch = args.graph_device == \"pinned\"\n\n    # Load dataset.\n    dataset = gb.BuiltinDataset(args.dataset).load()\n    print(\"Dataset loaded\")\n\n    # Move the dataset to the selected storage.\n    graph = (\n        dataset.graph.pin_memory_()\n        if args.graph_device == \"pinned\"\n        else dataset.graph.to(args.graph_device)\n    )\n    features = (\n        dataset.feature.pin_memory_()\n        if args.feature_device == \"pinned\"\n        else dataset.feature.to(args.feature_device)\n    )\n\n    train_set = dataset.tasks[0].train_set\n    valid_set = dataset.tasks[0].validation_set\n    test_set = dataset.tasks[0].test_set\n    args.fanout = list(map(int, args.fanout.split(\",\")))\n\n    num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n    num_etypes = len(graph.num_edges)\n\n    feats_on_disk = {\n        k: features[k]\n        for k in features.keys()\n        if k[2] == \"feat\" and isinstance(features[k], gb.DiskBasedFeature)\n    }\n\n    if args.cpu_cache_size > 0 and len(feats_on_disk) > 0:\n        cached_features = gb.cpu_cached_feature(\n            feats_on_disk,\n            int(args.cpu_cache_size * (2**30)),\n            args.cpu_feature_cache_policy,\n            args.feature_device == \"pinned\",\n        )\n        for k, cpu_cached_feature in cached_features.items():\n            features[k] = cpu_cached_feature\n            cpu_cache_miss_rate_fn = lambda: cpu_cached_feature.miss_rate\n    else:\n        cpu_cache_miss_rate_fn = lambda: 1\n\n    if args.gpu_cache_size > 0 and args.feature_device != \"cuda\":\n        feats = {k: features[k] for k in features.keys() if k[2] == \"feat\"}\n        cached_features = gb.gpu_cached_feature(\n            feats,\n            int(args.gpu_cache_size * 
(2**30)),\n        )\n        for k, gpu_cached_feature in cached_features.items():\n            features[k] = gpu_cached_feature\n            gpu_cache_miss_rate_fn = lambda: gpu_cached_feature.miss_rate\n    else:\n        gpu_cache_miss_rate_fn = lambda: 1\n\n    train_dataloader, valid_dataloader, test_dataloader = (\n        create_dataloader(\n            graph=graph,\n            features=features,\n            itemset=itemset,\n            batch_size=args.batch_size,\n            fanout=[\n                torch.full((num_etypes,), fanout) for fanout in args.fanout\n            ],\n            device=args.device,\n            job=job,\n        )\n        for itemset, job in zip(\n            [train_set, valid_set, test_set], [\"train\", \"evaluate\", \"evaluate\"]\n        )\n    )\n\n    feat_size = features.size(\"node\", \"paper\", \"feat\")[0]\n    hidden_channels = args.num_hidden\n\n    # Initialize the entity classification model.\n    model = EntityClassify(\n        graph, feat_size, hidden_channels, num_classes, len(args.fanout)\n    ).to(args.device)\n\n    print(\n        \"Number of model parameters: \"\n        f\"{sum(p.numel() for p in model.parameters())}\"\n    )\n\n    train(\n        train_dataloader,\n        valid_dataloader,\n        model,\n        gpu_cache_miss_rate_fn,\n        cpu_cache_miss_rate_fn,\n        args.device,\n    )\n\n    # Labels are currently unavailable for mag240M so the test acc will be 0.\n    print(\"Testing...\")\n    test_acc = evaluate(\n        model,\n        test_dataloader,\n        gpu_cache_miss_rate_fn,\n        cpu_cache_miss_rate_fn,\n        args.device,\n    )\n    print(f\"Test accuracy {test_acc.item():.4f}\")\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main()\n"
  },
  {
    "path": "examples/graphbolt/pyg/labor/README.md",
    "content": "Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs\n============\n\n- Paper link: [https://papers.nips.cc/paper_files/paper/2023/hash/51f9036d5e7ae822da8f6d4adda1fb39-Abstract-Conference.html](NeurIPS 2023)\nThis is an official Labor sampling example to showcase the use of [https://docs.dgl.ai/en/latest/generated/dgl.graphbolt.LayerNeighborSampler.html](dgl.graphbolt.LayerNeighborSampler).\n\nThis sampler has 2 parameters, `layer_dependency=[False|True]` and\n`batch_dependency=k`, where k is any nonnegative integer.\n\nWe use early stopping so that the final accuracy numbers are reported with a\nfairly well converged model. Additional contributions to improve the validation\naccuracy are welcome, and hence hopefully also improving the test accuracy.\n\n### layer_dependency\n\nEnabling this parameter by the command line option `--layer-dependency` makes it so\nthat the random variates for sampling are identical across layers. This ensures\nthat the same vertex gets the same neighborhood in each layer.\n\n### batch_dependency\n\nThis method is proposed in Section 3.2 of [https://arxiv.org/pdf/2310.12403](Cooperative Minibatching in Graph Neural Networks), it is denoted as kappa in the paper. It\nmakes the random variates used across minibatches dependent, thus increasing \ntemporal locality. When used with a cache, the increase in the temporal locality\ncan be observed by monitoring the drop in the cache miss rate with higher values\nof the batch dependency parameter, speeding up embedding transfers to the GPU.\n\n### Performance\n\nUse the `--torch-compile` option for best performance. If your GPU has spare\nmemory, consider using `--mode=cuda-cuda-cuda` to move the whole dataset to the\nGPU. If not, consider using `--mode=cuda-pinned-cuda --num-gpu-cached-features=N`\nto keep the graph on the GPU and features in system RAM with `N` of the node\nfeatures cached on the GPU. 
If you can not even fit the graph on the GPU, then\nconsider using `--mode=pinned-pinned-cuda --num-gpu-cached-features=N`. Finally,\nyou can use `--mode=cpu-pinned=cuda --num-gpu-cached-features=N` to perform the\nsampling operation on the CPU.\n\n### Examples\n\nWe use `--num-gpu-cached-features=500000` to cache the 500k of the node\nembeddings for the `ogbn-products` dataset (default). Check the command line\narguments to see which other datasets can be run. When running with the yelp\ndataset, using `--dropout=0` gives better final validation and test accuracy.\n\nExample run with batch_dependency=1, cache miss rate is 62%:\n\n```bash\npython node_classification.py --num-gpu-cached-features=500000 --batch-dependency=1\nTraining in pinned-pinned-cuda mode.\nLoading data...\nThe dataset is already preprocessed.\nTraining: 192it [00:03, 50.95it/s, num_nodes=247243, cache_miss=0.619]\nEvaluating: 39it [00:00, 76.01it/s, num_nodes=137466, cache_miss=0.621]\nEpoch 00, Loss: 1.1161, Approx. Train: 0.7024, Approx. Val: 0.8612, Time: 3.7688188552856445s\n```\n\nExample run with batch_dependency=32, cache miss rate is 22%:\n\n```bash\npython node_classification.py --num-gpu-cached-features=500000 --batch-dependency=32\nTraining in pinned-pinned-cuda mode.\nLoading data...\nThe dataset is already preprocessed.\nTraining: 192it [00:03, 54.34it/s, num_nodes=250479, cache_miss=0.221]\nEvaluating: 39it [00:00, 84.66it/s, num_nodes=135142, cache_miss=0.226]\nEpoch 00, Loss: 1.1288, Approx. Train: 0.6993, Approx. 
Val: 0.8607, Time: 3.5339605808258057s\n```\n\nExample run with layer_dependency=True, # sampled nodes is 190k vs 250k without\nthis option:\n\n```bash\npython node_classification.py --num-gpu-cached-features=500000 --layer-dependency\nTraining in pinned-pinned-cuda mode.\nLoading data...\nThe dataset is already preprocessed.\nTraining: 192it [00:03, 54.03it/s, num_nodes=191259, cache_miss=0.626]\nEvaluating: 39it [00:00, 79.49it/s, num_nodes=108720, cache_miss=0.627]\nEpoch 00, Loss: 1.1495, Approx. Train: 0.6932, Approx. Val: 0.8586, Time: 3.5540308952331543s\n```\n\nExample run with the original GraphSAGE sampler (Neighbor Sampler), # sampled nodes \nis 520k, more than 2x higher than Labor sampler.\n\n```bash\npython node_classification.py --num-gpu-cached-features=500000 --sample-mode=sample_neighbor\nTraining in pinned-pinned-cuda mode.\nLoading data...\nThe dataset is already preprocessed.\nTraining: 192it [00:04, 45.60it/s, num_nodes=517522, cache_miss=0.563]\nEvaluating: 39it [00:00, 77.53it/s, num_nodes=255686, cache_miss=0.565]\nEpoch 00, Loss: 1.1152, Approx. Train: 0.7015, Approx. Val: 0.8652, Time: 4.211000919342041s\n```\n"
  },
  {
    "path": "examples/graphbolt/pyg/labor/load_dataset.py",
    "content": "import dgl.graphbolt as gb\n\n\ndef load_dgl(name):\n    from dgl.data import (\n        CiteseerGraphDataset,\n        CoraGraphDataset,\n        FlickrDataset,\n        PubmedGraphDataset,\n        RedditDataset,\n        YelpDataset,\n    )\n\n    d = {\n        \"cora\": CoraGraphDataset,\n        \"citeseer\": CiteseerGraphDataset,\n        \"pubmed\": PubmedGraphDataset,\n        \"reddit\": RedditDataset,\n        \"yelp\": YelpDataset,\n        \"flickr\": FlickrDataset,\n    }\n\n    dataset = gb.LegacyDataset(d[name]())\n    new_feature = gb.TorchBasedFeatureStore([])\n    new_feature._features = dataset.feature._features\n    dataset._feature = new_feature\n    multilabel = name in [\"yelp\"]\n    return dataset, multilabel\n\n\ndef load_dataset(dataset_name, disk_based_feature_keys=None):\n    multilabel = False\n    if dataset_name in [\n        \"reddit\",\n        \"cora\",\n        \"citeseer\",\n        \"pubmed\",\n        \"yelp\",\n        \"flickr\",\n    ]:\n        dataset, multilabel = load_dgl(dataset_name)\n    else:\n        if \"mag240M\" in dataset_name:\n            dataset_name = \"ogb-lsc-mag240m\"\n        dataset = gb.BuiltinDataset(dataset_name)\n        if disk_based_feature_keys is None:\n            disk_based_feature_keys = set()\n        for feature in dataset.yaml_data[\"feature_data\"]:\n            feature_key = (feature[\"domain\"], feature[\"type\"], feature[\"name\"])\n            # Set the in_memory setting to False without modifying YAML file.\n            if feature_key in disk_based_feature_keys:\n                feature[\"in_memory\"] = False\n        dataset = dataset.load()\n\n    return dataset, multilabel\n"
  },
  {
    "path": "examples/graphbolt/pyg/labor/node_classification.py",
    "content": "import argparse\nimport time\n\nfrom copy import deepcopy\nfrom functools import partial\n\nimport dgl.graphbolt as gb\nimport torch\n\n# For torch.compile until https://github.com/pytorch/pytorch/issues/121197 is\n# resolved.\nimport torch._inductor.codecache\n\ntorch._dynamo.config.cache_size_limit = 32\n\nimport torch.nn as nn\nimport torchmetrics.functional as MF\nfrom load_dataset import load_dataset\nfrom sage_conv import SAGEConv as CustomSAGEConv\nfrom torch_geometric.nn import SAGEConv\nfrom tqdm import tqdm\n\n\ndef accuracy(out, labels):\n    assert out.ndim == 2\n    assert out.size(0) == labels.size(0)\n    assert labels.ndim == 1 or (labels.ndim == 2 and labels.size(1) == 1)\n    labels = labels.flatten()\n    predictions = torch.argmax(out, 1)\n    return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)\n\n\nclass GraphSAGE(torch.nn.Module):\n    def __init__(\n        self, in_size, hidden_size, out_size, n_layers, dropout, variant\n    ):\n        super().__init__()\n        assert variant in [\"original\", \"custom\"]\n        self.layers = torch.nn.ModuleList()\n        if variant == \"custom\":\n            sizes = [in_size] + [hidden_size] * n_layers\n            for i in range(n_layers):\n                self.layers.append(CustomSAGEConv(sizes[i], sizes[i + 1]))\n            self.linear = nn.Linear(hidden_size, out_size)\n            self.activation = nn.GELU()\n        else:\n            sizes = [in_size] + [hidden_size] * (n_layers - 1) + [out_size]\n            for i in range(n_layers):\n                self.layers.append(SAGEConv(sizes[i], sizes[i + 1]))\n            self.activation = nn.ReLU()\n        self.dropout = nn.Dropout(dropout)\n        self.hidden_size = hidden_size\n        self.out_size = out_size\n        self.variant = variant\n\n    def forward(self, subgraphs, x):\n        h = x\n        for i, (layer, subgraph) in enumerate(zip(self.layers, subgraphs)):\n            h, edge_index, size = 
subgraph.to_pyg(h)\n            h = layer(h, edge_index, size=size)\n            if self.variant == \"custom\":\n                h = self.activation(h)\n                h = self.dropout(h)\n            elif i != len(subgraphs) - 1:\n                h = self.activation(h)\n        return self.linear(h) if self.variant == \"custom\" else h\n\n    def inference(self, graph, features, dataloader, storage_device):\n        \"\"\"Conduct layer-wise inference to get all the node embeddings.\"\"\"\n        pin_memory = storage_device == \"pinned\"\n        buffer_device = torch.device(\"cpu\" if pin_memory else storage_device)\n\n        for layer_idx, layer in enumerate(self.layers):\n            is_last_layer = layer_idx == len(self.layers) - 1\n\n            y = torch.empty(\n                graph.total_num_nodes,\n                self.out_size if is_last_layer else self.hidden_size,\n                dtype=torch.float32,\n                device=buffer_device,\n                pin_memory=pin_memory,\n            )\n            for data in tqdm(dataloader, \"Inferencing\"):\n                # len(data.sampled_subgraphs) = 1\n                h, edge_index, size = data.sampled_subgraphs[0].to_pyg(\n                    data.node_features[\"feat\"]\n                )\n                hidden_x = layer(h, edge_index, size=size)\n                if self.variant == \"custom\":\n                    hidden_x = self.activation(hidden_x)\n                    if is_last_layer:\n                        hidden_x = self.linear(hidden_x)\n                elif not is_last_layer:\n                    hidden_x = self.activation(hidden_x)\n                # By design, our output nodes are contiguous.\n                y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(\n                    buffer_device\n                )\n            if not is_last_layer:\n                features.update(\"node\", None, \"feat\", y)\n\n        return y\n\n\ndef create_dataloader(\n    graph, features, 
itemset, batch_size, fanout, device, job\n):\n\n    # Initialize an ItemSampler to sample mini-batches from the dataset.\n    datapipe = gb.ItemSampler(\n        itemset,\n        batch_size=batch_size,\n        shuffle=(job == \"train\"),\n        drop_last=(job == \"train\"),\n    )\n    need_copy = True\n    # Copy the data to the specified device.\n    if args.graph_device != \"cpu\" and need_copy:\n        datapipe = datapipe.copy_to(device=device)\n        need_copy = False\n    # Sample neighbors for each node in the mini-batch.\n    kwargs = (\n        {\n            \"layer_dependency\": args.layer_dependency,\n            \"batch_dependency\": args.batch_dependency,\n        }\n        if args.sample_mode == \"sample_layer_neighbor\"\n        else {}\n    )\n    datapipe = getattr(datapipe, args.sample_mode)(\n        graph,\n        fanout if job != \"infer\" else [-1],\n        overlap_fetch=args.overlap_graph_fetch,\n        asynchronous=args.graph_device != \"cpu\",\n        **kwargs,\n    )\n    # Copy the data to the specified device.\n    if args.feature_device != \"cpu\" and need_copy:\n        datapipe = datapipe.copy_to(device=device)\n        need_copy = False\n    # Fetch node features for the sampled subgraph.\n    datapipe = datapipe.fetch_feature(\n        features,\n        node_feature_keys=[\"feat\"],\n        overlap_fetch=args.overlap_feature_fetch,\n    )\n    # Copy the data to the specified device.\n    if need_copy:\n        datapipe = datapipe.copy_to(device=device)\n    # Create and return a DataLoader to handle data loading.\n    return gb.DataLoader(datapipe, num_workers=args.num_workers)\n\n\n@torch.compile\ndef train_step(minibatch, optimizer, model, loss_fn, multilabel, eval_fn):\n    node_features = minibatch.node_features[\"feat\"]\n    labels = minibatch.labels\n    optimizer.zero_grad()\n    out = model(minibatch.sampled_subgraphs, node_features)\n    label_dtype = out.dtype if multilabel else None\n    loss = 
loss_fn(out, labels.to(label_dtype))\n    num_correct = eval_fn(out, labels) * labels.size(0)\n    loss.backward()\n    optimizer.step()\n    return loss.detach(), num_correct, labels.size(0)\n\n\ndef train_helper(\n    dataloader,\n    model,\n    optimizer,\n    loss_fn,\n    multilabel,\n    eval_fn,\n    gpu_cache_miss_rate_fn,\n    cpu_cache_miss_rate_fn,\n    device,\n):\n    model.train()  # Set the model to training mode\n    total_loss = torch.zeros(1, device=device)  # Accumulator for the total loss\n    # Accumulator for the total number of correct predictions\n    total_correct = torch.zeros(1, dtype=torch.float64, device=device)\n    total_samples = 0  # Accumulator for the total number of samples processed\n    num_batches = 0  # Counter for the number of mini-batches processed\n    start = time.time()\n    dataloader = tqdm(dataloader, \"Training\")\n    for step, minibatch in enumerate(dataloader):\n        loss, num_correct, num_samples = train_step(\n            minibatch, optimizer, model, loss_fn, multilabel, eval_fn\n        )\n        total_loss += loss\n        total_correct += num_correct\n        total_samples += num_samples\n        num_batches += 1\n        if step % 25 == 0:\n            # log every 25 steps for performance.\n            dataloader.set_postfix(\n                {\n                    \"num_nodes\": minibatch.node_ids().size(0),\n                    \"gpu_cache_miss\": gpu_cache_miss_rate_fn(),\n                    \"cpu_cache_miss\": cpu_cache_miss_rate_fn(),\n                }\n            )\n    train_loss = total_loss / num_batches\n    train_acc = total_correct / total_samples\n    end = time.time()\n    return train_loss, train_acc, end - start\n\n\ndef train(\n    train_dataloader,\n    valid_dataloader,\n    model,\n    multilabel,\n    eval_fn,\n    gpu_cache_miss_rate_fn,\n    cpu_cache_miss_rate_fn,\n    device,\n):\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n    loss_fn = 
nn.BCEWithLogitsLoss() if multilabel else nn.CrossEntropyLoss()\n\n    best_model = None\n    best_model_acc = 0\n    best_model_epoch = -1\n\n    for epoch in range(args.epochs):\n        train_loss, train_acc, duration = train_helper(\n            train_dataloader,\n            model,\n            optimizer,\n            loss_fn,\n            multilabel,\n            eval_fn,\n            gpu_cache_miss_rate_fn,\n            cpu_cache_miss_rate_fn,\n            device,\n        )\n        val_acc = evaluate(\n            model,\n            valid_dataloader,\n            eval_fn,\n            gpu_cache_miss_rate_fn,\n            cpu_cache_miss_rate_fn,\n            device,\n        )\n        if val_acc > best_model_acc:\n            best_model_acc = val_acc\n            best_model = deepcopy(model.state_dict())\n            best_model_epoch = epoch\n        print(\n            f\"Epoch {epoch:02d}, Loss: {train_loss.item():.4f}, \"\n            f\"Approx. Train: {train_acc.item():.4f}, \"\n            f\"Approx. 
Val: {val_acc.item():.4f}, \"\n            f\"Time: {duration}s\"\n        )\n        if best_model_epoch + args.early_stopping_patience < epoch:\n            break\n    return best_model\n\n\n@torch.no_grad()\ndef layerwise_infer(\n    args,\n    graph,\n    features,\n    itemsets,\n    all_nodes_set,\n    model,\n    eval_fn,\n):\n    model.eval()\n    dataloader = create_dataloader(\n        graph=graph,\n        features=features,\n        itemset=all_nodes_set,\n        batch_size=args.batch_size,\n        fanout=[-1],\n        device=args.device,\n        job=\"infer\",\n    )\n    pred = model.inference(graph, features, dataloader, args.feature_device)\n\n    metrics = {}\n    for split_name, itemset in itemsets.items():\n        nid, labels = itemset[:]\n        acc = eval_fn(\n            pred[nid.to(pred.device)],\n            labels.to(pred.device),\n        )\n        metrics[split_name] = acc.item()\n\n    return metrics\n\n\n@torch.compile\ndef evaluate_step(minibatch, model, eval_fn):\n    node_features = minibatch.node_features[\"feat\"]\n    labels = minibatch.labels\n    out = model(minibatch.sampled_subgraphs, node_features)\n    num_correct = eval_fn(out, labels) * labels.size(0)\n    return num_correct, labels.size(0)\n\n\n@torch.no_grad()\ndef evaluate(\n    model,\n    dataloader,\n    eval_fn,\n    gpu_cache_miss_rate_fn,\n    cpu_cache_miss_rate_fn,\n    device,\n):\n    model.eval()\n    total_correct = torch.zeros(1, dtype=torch.float64, device=device)\n    total_samples = 0\n    dataloader = tqdm(dataloader, \"Evaluating\")\n    for step, minibatch in enumerate(dataloader):\n        num_correct, num_samples = evaluate_step(minibatch, model, eval_fn)\n        total_correct += num_correct\n        total_samples += num_samples\n        if step % 25 == 0:\n            dataloader.set_postfix(\n                {\n                    \"num_nodes\": minibatch.node_ids().size(0),\n                    \"gpu_cache_miss\": 
gpu_cache_miss_rate_fn(),\n                    \"cpu_cache_miss\": cpu_cache_miss_rate_fn(),\n                }\n            )\n\n    return total_correct / total_samples\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Which dataset are you going to use?\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=9999999, help=\"Number of training epochs.\"\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=0.001,\n        help=\"Learning rate for optimization.\",\n    )\n    parser.add_argument(\"--num-hidden\", type=int, default=256)\n    parser.add_argument(\"--dropout\", type=float, default=0.5)\n    parser.add_argument(\n        \"--batch-size\", type=int, default=1024, help=\"Batch size for training.\"\n    )\n    parser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=0,\n        help=\"Number of workers for data loading.\",\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-products\",\n        choices=[\n            \"ogbn-arxiv\",\n            \"ogbn-products\",\n            \"ogbn-papers100M\",\n            \"igb-hom-tiny\",\n            \"igb-hom-small\",\n            \"igb-hom-medium\",\n            \"igb-hom-large\",\n            \"igb-hom\",\n            \"reddit\",\n            \"yelp\",\n            \"flickr\",\n        ],\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"10,10,10\",\n        help=\"Fan-out of neighbor sampling. len(fanout) determines the number of\"\n        \" GNN layers in your model. 
Default: 10,10,10\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"pinned-pinned-cuda\",\n        choices=[\n            \"cpu-cpu-cpu\",\n            \"cpu-cpu-cuda\",\n            \"cpu-pinned-cuda\",\n            \"pinned-pinned-cuda\",\n            \"cuda-pinned-cuda\",\n            \"cuda-cuda-cuda\",\n        ],\n        help=\"Graph storage - feature storage - Train device: 'cpu' for CPU and\"\n        \" RAM, 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.\",\n    )\n    parser.add_argument(\"--layer-dependency\", action=\"store_true\")\n    parser.add_argument(\"--batch-dependency\", type=int, default=1)\n    parser.add_argument(\n        \"--cpu-feature-cache-policy\",\n        type=str,\n        default=None,\n        choices=[\"s3-fifo\", \"sieve\", \"lru\", \"clock\"],\n        help=\"The cache policy for the CPU feature cache.\",\n    )\n    parser.add_argument(\n        \"--num-cpu-cached-features\",\n        type=int,\n        default=0,\n        help=\"The capacity of the CPU cache, the number of features to store.\",\n    )\n    parser.add_argument(\n        \"--num-gpu-cached-features\",\n        type=int,\n        default=0,\n        help=\"The capacity of the GPU cache, the number of features to store.\",\n    )\n    parser.add_argument(\"--early-stopping-patience\", type=int, default=25)\n    parser.add_argument(\n        \"--sample-mode\",\n        default=\"sample_layer_neighbor\",\n        choices=[\"sample_neighbor\", \"sample_layer_neighbor\"],\n        help=\"The sampling function when doing layerwise sampling.\",\n    )\n    parser.add_argument(\n        \"--sage-model-variant\",\n        default=\"custom\",\n        choices=[\"custom\", \"original\"],\n        help=\"The custom SAGE GNN model provides higher accuracy with lower\"\n        \" runtime performance.\",\n    )\n    parser.add_argument(\"--precision\", type=str, default=\"high\")\n    return parser.parse_args()\n\n\ndef main():\n  
  torch.set_float32_matmul_precision(args.precision)\n    if not torch.cuda.is_available():\n        args.mode = \"cpu-cpu-cpu\"\n    print(f\"Training in {args.mode} mode.\")\n    args.graph_device, args.feature_device, args.device = args.mode.split(\"-\")\n    args.overlap_feature_fetch = args.feature_device == \"pinned\"\n    args.overlap_graph_fetch = args.graph_device == \"pinned\"\n\n    # Load and preprocess dataset.\n    print(\"Loading data...\")\n    disk_based_feature_keys = None\n    if args.num_cpu_cached_features > 0:\n        disk_based_feature_keys = [(\"node\", None, \"feat\")]\n    dataset, multilabel = load_dataset(args.dataset, disk_based_feature_keys)\n\n    # Move the dataset to the selected storage.\n    graph = (\n        dataset.graph.pin_memory_()\n        if args.graph_device == \"pinned\"\n        else dataset.graph.to(args.graph_device)\n    )\n    features = (\n        dataset.feature.pin_memory_()\n        if args.feature_device == \"pinned\"\n        else dataset.feature.to(args.feature_device)\n    )\n\n    train_set = dataset.tasks[0].train_set\n    valid_set = dataset.tasks[0].validation_set\n    test_set = dataset.tasks[0].test_set\n    all_nodes_set = dataset.all_nodes_set\n    args.fanout = list(map(int, args.fanout.split(\",\")))\n\n    num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n\n    feature_index_device = (\n        args.feature_device if args.feature_device != \"pinned\" else None\n    )\n    feature_num_bytes = (\n        features[(\"node\", None, \"feat\")]\n        # Read a single row to query its size in bytes.\n        .read(torch.zeros(1, device=feature_index_device).long()).nbytes\n    )\n    if args.num_cpu_cached_features > 0 and isinstance(\n        features[(\"node\", None, \"feat\")], gb.DiskBasedFeature\n    ):\n        features[(\"node\", None, \"feat\")] = gb.cpu_cached_feature(\n            features[(\"node\", None, \"feat\")],\n            args.num_cpu_cached_features * feature_num_bytes,\n   
         args.cpu_feature_cache_policy,\n            args.feature_device == \"pinned\",\n        )\n        cpu_cached_feature = features[(\"node\", None, \"feat\")]\n        cpu_cache_miss_rate_fn = lambda: cpu_cached_feature.miss_rate\n    else:\n        cpu_cache_miss_rate_fn = lambda: 1\n    if args.num_gpu_cached_features > 0 and args.feature_device != \"cuda\":\n        features[(\"node\", None, \"feat\")] = gb.gpu_cached_feature(\n            features[(\"node\", None, \"feat\")],\n            args.num_gpu_cached_features * feature_num_bytes,\n        )\n        gpu_cached_feature = features[(\"node\", None, \"feat\")]\n        gpu_cache_miss_rate_fn = lambda: gpu_cached_feature.miss_rate\n    else:\n        gpu_cache_miss_rate_fn = lambda: 1\n\n    train_dataloader, valid_dataloader = (\n        create_dataloader(\n            graph=graph,\n            features=features,\n            itemset=itemset,\n            batch_size=args.batch_size,\n            fanout=args.fanout,\n            device=args.device,\n            job=job,\n        )\n        for itemset, job in zip([train_set, valid_set], [\"train\", \"evaluate\"])\n    )\n\n    in_channels = features.size(\"node\", None, \"feat\")[0]\n    model = GraphSAGE(\n        in_channels,\n        args.num_hidden,\n        num_classes,\n        len(args.fanout),\n        args.dropout,\n        args.sage_model_variant,\n    ).to(args.device)\n    assert len(args.fanout) == len(model.layers)\n\n    eval_fn = (\n        partial(\n            # TODO @mfbalin: Find an implementation that does not synchronize.\n            MF.f1_score,\n            task=\"multilabel\",\n            num_labels=num_classes,\n            validate_args=False,\n        )\n        if multilabel\n        else accuracy\n    )\n\n    best_model = train(\n        train_dataloader,\n        valid_dataloader,\n        model,\n        multilabel,\n        eval_fn,\n        gpu_cache_miss_rate_fn,\n        cpu_cache_miss_rate_fn,\n        
args.device,\n    )\n    model.load_state_dict(best_model)\n\n    # Test the model.\n    print(\"Testing...\")\n    itemsets = {\"train\": train_set, \"val\": valid_set, \"test\": test_set}\n    final_acc = layerwise_infer(\n        args,\n        graph,\n        features,\n        itemsets,\n        all_nodes_set,\n        model,\n        eval_fn,\n    )\n    print(\"Final accuracy values:\")\n    print(final_acc)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main()\n"
  },
  {
    "path": "examples/graphbolt/pyg/labor/sage_conv.py",
    "content": "from typing import List, Optional, Tuple, Union\n\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr import Aggregation, MultiAggregation\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import Adj, OptPairTensor, Size, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass SAGEConv(MessagePassing):\n    r\"\"\"A variant of the GraphSAGE operator from the `\"Inductive Representation\n    Learning on Large Graphs\" <https://arxiv.org/abs/1706.02216>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{W}_1 \\mathbf{x}_i + \\mathbf{W}_2 \\cdot\n        \\mathrm{mean}_{j \\in \\mathcal{N(i)}} \\mathbf{x}_j\n\n    If :obj:`project = True`, then :math:`\\mathbf{x}_j` will first get\n    projected via\n\n    .. math::\n        \\mathbf{x}_j \\leftarrow \\sigma ( \\mathbf{W}_3 \\mathbf{x}_j +\n        \\mathbf{b})\n\n    as described in Eq. (3) of the paper.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        aggr (str or Aggregation, optional): The aggregation scheme to use.\n            Any aggregation of :obj:`torch_geometric.nn.aggr` can be used,\n            *e.g.*, :obj:`\"mean\"`, :obj:`\"max\"`, or :obj:`\"lstm\"`.\n            (default: :obj:`\"mean\"`)\n        project (bool, optional): If set to :obj:`True`, the layer will apply a\n            linear transformation followed by an activation function before\n            aggregation (as described in Eq. (3) of the paper).\n            (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **inputs:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **outputs:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V_t}|, F_{out})` if bipartite\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: Union[int, Tuple[int, int]],\n        out_channels: int,\n        aggr: Optional[Union[str, List[str], Aggregation]] = \"mean\",\n        project: bool = True,\n        bias: bool = True,\n        **kwargs,\n    ):\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.project = project\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        if aggr == \"lstm\":\n            kwargs.setdefault(\"aggr_kwargs\", {})\n            kwargs[\"aggr_kwargs\"].setdefault(\"in_channels\", in_channels[0])\n            kwargs[\"aggr_kwargs\"].setdefault(\"out_channels\", in_channels[0])\n\n        super().__init__(aggr, **kwargs)\n\n        if self.project:\n            if in_channels[0] <= 0:\n                raise ValueError(\n                    f\"'{self.__class__.__name__}' does not \"\n                    f\"support lazy initialization with \"\n                    f\"`project=True`\"\n                )\n            self.lin = Linear(in_channels[0], in_channels[0], bias=True)\n\n        if isinstance(self.aggr_module, MultiAggregation):\n            aggr_out_channels = self.aggr_module.get_out_channels(\n                in_channels[0]\n            )\n        else:\n            aggr_out_channels = in_channels[0]\n\n        self.lin_l = Linear(aggr_out_channels, out_channels, bias=bias)\n        
self.lin_r = Linear(in_channels[1], out_channels, bias=False)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        if self.project:\n            self.lin.reset_parameters()\n        self.lin_l.reset_parameters()\n        self.lin_r.reset_parameters()\n\n    def forward(\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: Adj,\n        size: Size = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        if self.project and hasattr(self, \"lin\"):\n            x = (F.gelu(self.lin(x[0])), x[1])\n\n        # propagate_type: (x: OptPairTensor)\n        AX = self.propagate(edge_index, x=x, size=size)\n        out = self.lin_l(AX)\n\n        x_r = x[1]\n        if x_r is not None:\n            out = out + self.lin_r(x_r)\n\n        return out\n\n    def message(self, x_j: Tensor) -> Tensor:\n        return x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:\n        if isinstance(adj_t, SparseTensor):\n            adj_t = adj_t.set_value(None, layout=None)\n        return spmm(adj_t, x[0], reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return (\n            f\"{self.__class__.__name__}({self.in_channels}, \"\n            f\"{self.out_channels}, aggr={self.aggr})\"\n        )\n"
  },
  {
    "path": "examples/graphbolt/pyg/link_prediction.py",
    "content": "\"\"\"\nThis script trains and tests a GraphSAGE model for link prediction on\nlarge graphs using graphbolt dataloader. It is the PyG counterpart of the\nexample in `examples/graphbolt/link_prediction.py`.\n\nPaper: [Inductive Representation Learning on Large Graphs]\n(https://arxiv.org/abs/1706.02216)\n\nWhile node classification predicts labels for nodes based on their\nlocal neighborhoods, link prediction assesses the likelihood of an edge\nexisting between two nodes, necessitating different sampling strategies\nthat account for pairs of nodes and their joint neighborhoods.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> OnDiskDataset pre-processing\n│\n├───> Instantiate SAGE model\n│\n├───> train\n│     │\n│     ├───> Get graphbolt dataloader (HIGHLIGHT)\n|     |\n|     |───> Define a PyG GNN model for link prediction (HIGHLIGHT)\n│     │\n│     └───> Training loop\n│           │\n│           ├───> SAGE.forward\n│\n└───> Validation and test set evaluation\n\"\"\"\nimport argparse\nimport time\nfrom functools import partial\n\nimport dgl.graphbolt as gb\nimport torch\n\n# For torch.compile until https://github.com/pytorch/pytorch/issues/121197 is\n# resolved.\nimport torch._inductor.codecache\n\ntorch._dynamo.config.cache_size_limit = 32\n\nimport torch.nn.functional as F\nfrom torch_geometric.nn import SAGEConv\nfrom torchmetrics.retrieval import RetrievalMRR\nfrom tqdm import tqdm, trange\n\n\nclass GraphSAGE(torch.nn.Module):\n    #####################################################################\n    # (HIGHLIGHT) Define the GraphSAGE model architecture.\n    #\n    # - This class inherits from `torch.nn.Module`.\n    # - Two convolutional layers are created using the SAGEConv class from PyG.\n    # - The forward method defines the computation performed at every call.\n    #####################################################################\n    def __init__(self, in_size, hidden_size, 
n_layers):\n        super(GraphSAGE, self).__init__()\n        self.layers = torch.nn.ModuleList()\n        sizes = [in_size] + [hidden_size] * n_layers\n        for i in range(n_layers):\n            self.layers.append(SAGEConv(sizes[i], sizes[i + 1]))\n        self.hidden_size = hidden_size\n        self.predictor = torch.nn.Sequential(\n            torch.nn.Linear(hidden_size, hidden_size),\n            torch.nn.ReLU(),\n            torch.nn.Linear(hidden_size, hidden_size),\n            torch.nn.ReLU(),\n            torch.nn.Linear(hidden_size, 1),\n        )\n\n    def forward(self, subgraphs, x):\n        h = x\n        for i, (layer, subgraph) in enumerate(zip(self.layers, subgraphs)):\n            #####################################################################\n            # (HIGHLIGHT) Convert given features to be consumed by a PyG layer.\n            #\n            #   PyG layers have two modes, bipartite and normal. We slice the\n            #   given features to get src and dst features to use the PyG layers\n            #   in the more efficient bipartite mode.\n            #####################################################################\n            h, edge_index, size = subgraph.to_pyg(h)\n            h = layer(h, edge_index, size=size)\n            if i != len(subgraphs) - 1:\n                h = F.relu(h)\n        return h\n\n    def inference(self, graph, features, dataloader, storage_device):\n        \"\"\"Conduct layer-wise inference to get all the node embeddings.\"\"\"\n        pin_memory = storage_device == \"pinned\"\n        buffer_device = torch.device(\"cpu\" if pin_memory else storage_device)\n\n        for layer_idx, layer in enumerate(self.layers):\n            is_last_layer = layer_idx == len(self.layers) - 1\n\n            y = torch.empty(\n                graph.total_num_nodes,\n                self.hidden_size,\n                dtype=torch.float32,\n                device=buffer_device,\n                
pin_memory=pin_memory,\n            )\n            for data in tqdm(dataloader, \"Inferencing\"):\n                # len(data.sampled_subgraphs) = 1\n                h, edge_index, size = data.sampled_subgraphs[0].to_pyg(\n                    data.node_features[\"feat\"]\n                )\n                hidden_x = layer(h, edge_index, size=size)\n                if not is_last_layer:\n                    hidden_x = F.relu(hidden_x)\n                # By design, our output nodes are contiguous.\n                y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(\n                    buffer_device\n                )\n            if not is_last_layer:\n                features.update(\"node\", None, \"feat\", y)\n\n        return y\n\n\ndef create_dataloader(\n    graph, features, itemset, batch_size, fanout, device, job\n):\n    #####################################################################\n    # (HIGHLIGHT) Create a data loader for efficiently loading graph data.\n    #\n    # - 'ItemSampler' samples mini-batches of node IDs from the dataset.\n    # - 'CopyTo' copies the fetched data to the specified device.\n    # - 'sample_neighbor' performs neighbor sampling on the graph.\n    # - 'FeatureFetcher' fetches node features based on the sampled subgraph.\n\n    #####################################################################\n    # Create a datapipe for mini-batch sampling with a specific neighbor fanout.\n    # Here, [10, 10, 10] specifies the number of neighbors sampled for each node at each layer.\n    # We're using `sample_neighbor` for consistency with DGL's sampling API.\n    # Note: GraphBolt offers additional sampling methods, such as `sample_layer_neighbor`,\n    # which could provide further optimization and efficiency for GNN training.\n    # Users are encouraged to explore these advanced features for potentially improved performance.\n\n    # Initialize an ItemSampler to sample mini-batches from the dataset.\n    datapipe = 
gb.ItemSampler(\n        itemset,\n        batch_size=batch_size,\n        shuffle=(job == \"train\"),\n        drop_last=(job == \"train\"),\n    )\n    need_copy = True\n    # Copy the data to the specified device.\n    if args.graph_device != \"cpu\" and need_copy:\n        datapipe = datapipe.copy_to(device=device)\n        need_copy = False\n    # Sample negative edges.\n    if job == \"train\":\n        datapipe = datapipe.sample_uniform_negative(graph, args.neg_ratio)\n    # Sample neighbors for each node in the mini-batch.\n    datapipe = getattr(datapipe, args.sample_mode)(\n        graph,\n        fanout if job != \"infer\" else [-1],\n        overlap_fetch=args.overlap_graph_fetch,\n        asynchronous=args.graph_device != \"cpu\",\n    )\n    if job == \"train\" and args.exclude_edges:\n        datapipe = datapipe.exclude_seed_edges(\n            include_reverse_edges=True,\n            asynchronous=args.graph_device != \"cpu\",\n        )\n    # Copy the data to the specified device.\n    if args.feature_device != \"cpu\" and need_copy:\n        datapipe = datapipe.copy_to(device=device)\n        need_copy = False\n    # Fetch node features for the sampled subgraph.\n    datapipe = datapipe.fetch_feature(\n        features,\n        node_feature_keys=[\"feat\"],\n        overlap_fetch=args.overlap_feature_fetch,\n    )\n    # Copy the data to the specified device.\n    if need_copy:\n        datapipe = datapipe.copy_to(device=device)\n    # Create and return a DataLoader to handle data loading.\n    return gb.DataLoader(datapipe, num_workers=args.num_workers)\n\n\n@torch.compile\ndef predictions_step(model, h_src, h_dst):\n    return model.predictor(h_src * h_dst).squeeze()\n\n\ndef compute_predictions(model, node_emb, seeds, device):\n    \"\"\"Compute the predictions for given source and destination nodes.\n\n    This function computes the predictions for a set of node pairs, dividing the\n    task into batches to handle potentially large graphs.\n  
  \"\"\"\n\n    preds = torch.empty(seeds.shape[0], device=device)\n    seeds_src, seeds_dst = seeds.T\n    # The constant number is 1001 because the negative ratio in the\n    # `ogbl-citation2` dataset is 1000.\n    eval_size = args.eval_batch_size * 1001\n    # Loop over node pairs in batches.\n    for start in trange(0, seeds_src.shape[0], eval_size, desc=\"Evaluate\"):\n        end = min(start + eval_size, seeds_src.shape[0])\n\n        # Fetch embeddings for current batch of source and destination nodes.\n        h_src = node_emb[seeds_src[start:end]].to(device, non_blocking=True)\n        h_dst = node_emb[seeds_dst[start:end]].to(device, non_blocking=True)\n\n        # Compute prediction scores using the model.\n        preds[start:end] = predictions_step(model, h_src, h_dst)\n    return preds\n\n\n@torch.no_grad()\ndef evaluate(model, graph, features, all_nodes_set, valid_set, test_set):\n    \"\"\"Evaluate the model on validation and test sets.\"\"\"\n    model.eval()\n\n    dataloader = create_dataloader(\n        graph,\n        features,\n        all_nodes_set,\n        args.eval_batch_size,\n        [-1],\n        args.device,\n        job=\"infer\",\n    )\n\n    # Compute node embeddings for the entire graph.\n    node_emb = model.inference(graph, features, dataloader, args.feature_device)\n    results = []\n\n    # Loop over both validation and test sets.\n    for split in [valid_set, test_set]:\n        # Unpack the item set.\n        seeds = split._items[0].to(node_emb.device)\n        labels = split._items[1].to(node_emb.device)\n        indexes = split._items[2].to(node_emb.device)\n\n        preds = compute_predictions(model, node_emb, seeds, indexes.device)\n        # Compute MRR values for the current split.\n        results.append(RetrievalMRR()(preds, labels, indexes))\n    return results\n\n\n@torch.compile\ndef train_step(minibatch, optimizer, model):\n    node_features = minibatch.node_features[\"feat\"]\n    compacted_seeds = 
minibatch.compacted_seeds.T\n    labels = minibatch.labels\n    optimizer.zero_grad()\n    y = model(minibatch.sampled_subgraphs, node_features)\n    logits = model.predictor(\n        y[compacted_seeds[0]] * y[compacted_seeds[1]]\n    ).squeeze()\n    loss = F.binary_cross_entropy_with_logits(logits, labels)\n    loss.backward()\n    optimizer.step()\n    return loss.detach(), labels.size(0)\n\n\ndef train_helper(dataloader, model, optimizer, device):\n    model.train()  # Set the model to training mode\n    total_loss = torch.zeros(1, device=device)  # Accumulator for the total loss\n    total_samples = 0  # Accumulator for the total number of samples processed\n    start = time.time()\n    for step, minibatch in tqdm(enumerate(dataloader), \"Training\"):\n        loss, num_samples = train_step(minibatch, optimizer, model)\n        total_loss += loss * num_samples\n        total_samples += num_samples\n        if step + 1 == args.early_stop:\n            break\n    train_loss = total_loss / total_samples\n    end = time.time()\n    return train_loss, end - start\n\n\ndef train(dataloader, model, device):\n    #####################################################################\n    # (HIGHLIGHT) Train the model for `args.epochs` epochs.\n    #\n    # - Iterates over the data loader, fetching mini-batches of graph data.\n    # - For each mini-batch, it performs a forward pass, computes loss, and\n    #   updates the model parameters.\n    # - After each epoch, the average loss and the epoch duration are printed;\n    #   the function itself returns nothing.\n    #\n    # Parameters:\n    #   dataloader: DataLoader that provides mini-batches of graph data.\n    #   model: The GraphSAGE model.\n    #   device: The device (CPU/GPU) to run the training on.\n    #####################################################################\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n\n    for epoch in range(args.epochs):\n        train_loss, duration = train_helper(\n            dataloader, model, optimizer, 
device\n        )\n        print(\n            f\"Epoch {epoch:02d}, Loss: {train_loss.item():.4f}, \"\n            f\"Time: {duration}s\"\n        )\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Which dataset are you going to use?\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=10, help=\"Number of training epochs.\"\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=0.003,\n        help=\"Learning rate for optimization.\",\n    )\n    parser.add_argument(\"--neg-ratio\", type=int, default=1)\n    parser.add_argument(\"--train-batch-size\", type=int, default=512)\n    parser.add_argument(\"--eval-batch-size\", type=int, default=1024)\n    parser.add_argument(\n        \"--batch-size\", type=int, default=1024, help=\"Batch size for training.\"\n    )\n    parser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=0,\n        help=\"Number of workers for data loading.\",\n    )\n    parser.add_argument(\n        \"--early-stop\",\n        type=int,\n        default=0,\n        help=\"0 means no early stop, otherwise stop at the input-th step\",\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbl-citation2\",\n        choices=[\"ogbl-citation2\"],\n        help=\"The dataset we can use for link prediction. Currently\"\n        \" only ogbl-citation2 dataset is supported.\",\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"10,10,10\",\n        help=\"Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)\"\n        \" identical with the number of layers in your model. Default: 10,10,10\",\n    )\n    parser.add_argument(\n        \"--exclude-edges\",\n        type=bool,\n        default=True,\n        help=\"Whether to exclude reverse edges during sampling. 
Default: True\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"pinned-pinned-cuda\",\n        choices=[\n            \"cpu-cpu-cpu\",\n            \"cpu-cpu-cuda\",\n            \"cpu-pinned-cuda\",\n            \"pinned-pinned-cuda\",\n            \"cuda-pinned-cuda\",\n            \"cuda-cuda-cuda\",\n        ],\n        help=\"Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,\"\n        \" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.\",\n    )\n    parser.add_argument(\n        \"--gpu-cache-size\",\n        type=int,\n        default=0,\n        help=\"The capacity of the GPU cache in bytes.\",\n    )\n    parser.add_argument(\n        \"--sample-mode\",\n        default=\"sample_neighbor\",\n        choices=[\"sample_neighbor\", \"sample_layer_neighbor\"],\n        help=\"The sampling function when doing layerwise sampling.\",\n    )\n    parser.add_argument(\"--precision\", type=str, default=\"high\")\n    return parser.parse_args()\n\n\ndef main():\n    torch.set_float32_matmul_precision(args.precision)\n    if not torch.cuda.is_available():\n        args.mode = \"cpu-cpu-cpu\"\n    print(f\"Training in {args.mode} mode.\")\n    args.graph_device, args.feature_device, args.device = args.mode.split(\"-\")\n    args.overlap_feature_fetch = args.feature_device == \"pinned\"\n    args.overlap_graph_fetch = args.graph_device == \"pinned\"\n\n    # Load and preprocess dataset.\n    print(\"Loading data...\")\n    dataset = gb.BuiltinDataset(args.dataset).load()\n\n    # Move the dataset to the selected storage.\n    graph = (\n        dataset.graph.pin_memory_()\n        if args.graph_device == \"pinned\"\n        else dataset.graph.to(args.graph_device)\n    )\n    features = (\n        dataset.feature.pin_memory_()\n        if args.feature_device == \"pinned\"\n        else dataset.feature.to(args.feature_device)\n    )\n\n    train_set = dataset.tasks[0].train_set\n    valid_set = 
dataset.tasks[0].validation_set\n    test_set = dataset.tasks[0].test_set\n    all_nodes_set = dataset.all_nodes_set\n    args.fanout = list(map(int, args.fanout.split(\",\")))\n\n    if args.gpu_cache_size > 0 and args.feature_device != \"cuda\":\n        features._features[(\"node\", None, \"feat\")] = gb.gpu_cached_feature(\n            features._features[(\"node\", None, \"feat\")],\n            args.gpu_cache_size,\n        )\n\n    train_dataloader = create_dataloader(\n        graph=graph,\n        features=features,\n        itemset=train_set,\n        batch_size=args.train_batch_size,\n        fanout=args.fanout,\n        device=args.device,\n        job=\"train\",\n    )\n\n    in_channels = features.size(\"node\", None, \"feat\")[0]\n    hidden_channels = 256\n    model = GraphSAGE(in_channels, hidden_channels, len(args.fanout)).to(\n        args.device\n    )\n    assert len(args.fanout) == len(model.layers)\n\n    train(train_dataloader, model, args.device)\n\n    # Test the model.\n    print(\"Testing...\")\n    valid_mrr, test_mrr = evaluate(\n        model,\n        graph,\n        features,\n        all_nodes_set,\n        valid_set,\n        test_set,\n    )\n    print(\n        f\"Validation MRR {valid_mrr.item():.4f}, Test MRR {test_mrr.item():.4f}\"\n    )\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main()\n"
  },
  {
    "path": "examples/graphbolt/pyg/multigpu/node_classification.py",
    "content": "\"\"\"\nThis script demonstrates node classification with GraphSAGE on large graphs, \nmerging GraphBolt (GB) and PyTorch Geometric (PyG). GraphBolt efficiently manages \ndata loading for large datasets, crucial for mini-batch processing. Post data \nloading, PyG's user-friendly framework takes over for training, showcasing seamless \nintegration with GraphBolt. This combination offers an efficient alternative to \ntraditional Deep Graph Library (DGL) methods, highlighting adaptability and \nscalability in handling large-scale graph data for diverse real-world applications.\n\n\n\nKey Features:\n- Implements the GraphSAGE model, a scalable GNN, for node classification on large graphs.\n- Utilizes GraphBolt, an efficient framework for large-scale graph data processing.\n- Integrates with PyTorch Geometric for building and training the GraphSAGE model.\n- The script is well-documented, providing clear explanations at each step.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain: \n\nmain\n│\n├───> Load and preprocess dataset (GraphBolt)\n│     │\n│     └───> Utilize GraphBolt's BuiltinDataset for dataset handling\n│\n├───> Instantiate the SAGE model (PyTorch Geometric)\n│     │\n│     └───> Define the GraphSAGE model architecture\n│\n├───> Train the model\n│     │\n│     ├───> Mini-Batch Processing with GraphBolt\n│     │     │\n│     │     └───> Efficient handling of mini-batches using GraphBolt's utilities\n│     │\n│     └───> Training Loop\n│           │\n│           ├───> Forward and backward passes\n│           │\n│           └───> Parameters optimization\n│\n└───> Evaluate the model\n      │\n      └───> Performance assessment on validation and test datasets\n            │\n            └───> Accuracy and other relevant metrics calculation\n\n\n\"\"\"\n\nimport argparse\nimport os\nimport time\n\nimport dgl.graphbolt as gb\nimport torch\n\n# For torch.compile until 
https://github.com/pytorch/pytorch/issues/121197 is\n# resolved.\nimport torch._inductor.codecache\n\ntorch._dynamo.config.cache_size_limit = 32\n\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn.functional as F\nfrom torch_geometric.nn import SAGEConv\nfrom tqdm import tqdm\n\n\ndef accuracy(out, labels):\n    assert out.ndim == 2\n    assert out.size(0) == labels.size(0)\n    assert labels.ndim == 1 or (labels.ndim == 2 and labels.size(1) == 1)\n    labels = labels.flatten()\n    predictions = torch.argmax(out, 1)\n    return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)\n\n\nclass GraphSAGE(torch.nn.Module):\n    #####################################################################\n    # (HIGHLIGHT) Define the GraphSAGE model architecture.\n    #\n    # - This class inherits from `torch.nn.Module`.\n    # - `n_layers` convolutional layers are created using the SAGEConv class\n    #   from PyG.\n    # - 'in_size', 'hidden_size', 'out_size' are the sizes of\n    #   the input, hidden, and output features, respectively.\n    # - The forward method defines the computation performed at every call.\n    #####################################################################\n    def __init__(self, in_size, hidden_size, out_size, n_layers, cooperative):\n        super(GraphSAGE, self).__init__()\n        self.layers = torch.nn.ModuleList()\n        sizes = [in_size] + [hidden_size] * (n_layers - 1) + [out_size]\n        for i in range(n_layers):\n            self.layers.append(SAGEConv(sizes[i], sizes[i + 1]))\n        self.hidden_size = hidden_size\n        self.out_size = out_size\n        self.cooperative = cooperative\n\n    def forward(self, minibatch, x):\n        subgraphs = minibatch.sampled_subgraphs\n        h = x\n        for i, (layer, subgraph) in enumerate(zip(self.layers, subgraphs)):\n            #####################################################################\n            # (HIGHLIGHT) Convert given 
features to be consumed by a PyG layer.\n            #\n            #   PyG layers have two modes, bipartite and normal. We slice the\n            #   given features to get src and dst features to use the PyG layers\n            #   in the more efficient bipartite mode.\n            #####################################################################\n            if i != 0 and self.cooperative:\n                h = gb.CooperativeConvFunction.apply(subgraph, h)\n            h, edge_index, size = subgraph.to_pyg(h)\n            h = layer(h, edge_index, size=size)\n            if i != len(subgraphs) - 1:\n                h = F.relu(h)\n        if self.cooperative:\n            h = gb.CooperativeConvFunction.apply(minibatch, h)\n            h = h[minibatch.compacted_seeds]\n        return h\n\n\ndef create_dataloader(\n    args, graph, features, itemset, batch_size, fanout, device, job\n):\n    #####################################################################\n    # (HIGHLIGHT) Create a data loader for efficiently loading graph data.\n    #\n    # - 'ItemSampler' samples mini-batches of node IDs from the dataset.\n    # - 'CopyTo' copies the fetched data to the specified device.\n    # - 'sample_neighbor' performs neighbor sampling on the graph.\n    # - 'FeatureFetcher' fetches node features based on the sampled subgraph.\n\n    #####################################################################\n    # Create a datapipe for mini-batch sampling with a specific neighbor fanout.\n    # Here, [10, 10, 10] specifies the number of neighbors sampled for each node at each layer.\n    # We're using `sample_neighbor` for consistency with DGL's sampling API.\n    # Note: GraphBolt offers additional sampling methods, such as `sample_layer_neighbor`,\n    # which could provide further optimization and efficiency for GNN training.\n    # Users are encouraged to explore these advanced features for potentially improved performance.\n\n    # Initialize an ItemSampler to sample 
mini-batches from the dataset.\n    datapipe = gb.DistributedItemSampler(\n        itemset,\n        batch_size=batch_size,\n        shuffle=(job == \"train\"),\n        drop_last=(job == \"train\"),\n        drop_uneven_inputs=True,\n    )\n    need_copy = True\n    # Copy the data to the specified device.\n    if args.graph_device != \"cpu\" and need_copy:\n        datapipe = datapipe.copy_to(device=device)\n        need_copy = False\n    # Sample neighbors for each node in the mini-batch.\n    datapipe = getattr(datapipe, args.sample_mode)(\n        graph,\n        fanout if job != \"infer\" else [-1],\n        overlap_fetch=args.overlap_graph_fetch,\n        num_gpu_cached_edges=args.num_gpu_cached_edges,\n        gpu_cache_threshold=args.gpu_graph_caching_threshold,\n        cooperative=args.cooperative,\n        asynchronous=args.graph_device != \"cpu\",\n    )\n    # Copy the data to the specified device.\n    if args.feature_device != \"cpu\" and need_copy:\n        datapipe = datapipe.copy_to(device=device)\n        need_copy = False\n    # Fetch node features for the sampled subgraph.\n    datapipe = datapipe.fetch_feature(\n        features,\n        node_feature_keys=[\"feat\"],\n        overlap_fetch=args.overlap_feature_fetch,\n        cooperative=args.cooperative,\n    )\n    # Copy the data to the specified device.\n    if need_copy:\n        datapipe = datapipe.copy_to(device=device)\n    # Create and return a DataLoader to handle data loading.\n    return gb.DataLoader(datapipe, num_workers=args.num_workers)\n\n\ndef weighted_reduce(tensor, weight, dst=0):\n    ########################################################################\n    # (HIGHLIGHT) Collect accuracy and loss values from sub-processes and\n    # obtain overall average values.\n    #\n    # `torch.distributed.reduce` is used to reduce tensors from all the\n    # sub-processes to a specified process, ReduceOp.SUM is used by default.\n    #\n    # Because the GPUs may have differing 
numbers of processed items, we\n    # perform a weighted mean to calculate the exact loss and accuracy.\n    ########################################################################\n    dist.reduce(tensor=tensor, dst=dst)\n    weight = torch.tensor(weight, device=tensor.device)\n    dist.reduce(tensor=weight, dst=dst)\n    return tensor / weight\n\n\n@torch.compile\ndef train_step(minibatch, optimizer, model, loss_fn):\n    node_features = minibatch.node_features[\"feat\"]\n    labels = minibatch.labels\n    optimizer.zero_grad()\n    out = model(minibatch, node_features)\n    loss = loss_fn(out, labels)\n    num_correct = accuracy(out, labels) * labels.size(0)\n    loss.backward()\n    optimizer.step()\n    return loss.detach(), num_correct, labels.size(0)\n\n\ndef train_helper(rank, dataloader, model, optimizer, loss_fn, device):\n    model.train()  # Set the model to training mode\n    total_loss = torch.zeros(1, device=device)  # Accumulator for the total loss\n    # Accumulator for the total number of correct predictions\n    total_correct = torch.zeros(1, dtype=torch.float64, device=device)\n    total_samples = 0  # Accumulator for the total number of samples processed\n    num_batches = 0  # Counter for the number of mini-batches processed\n    start = time.time()\n    for minibatch in tqdm(dataloader, \"Training\") if rank == 0 else dataloader:\n        loss, num_correct, num_samples = train_step(\n            minibatch, optimizer, model, loss_fn\n        )\n        total_loss += loss\n        total_correct += num_correct\n        total_samples += num_samples\n        num_batches += 1\n    train_loss = weighted_reduce(total_loss, num_batches)\n    train_acc = weighted_reduce(total_correct, total_samples)\n    end = time.time()\n    return train_loss, train_acc, end - start\n\n\ndef train(args, rank, train_dataloader, valid_dataloader, model, device):\n    #####################################################################\n    # (HIGHLIGHT) Train the 
model for `args.epochs` epochs.\n    #\n    # - Iterates over the data loader, fetching mini-batches of graph data.\n    # - For each mini-batch, it performs a forward pass, computes loss, and\n    #   updates the model parameters.\n    # - After each epoch, rank 0 prints the loss, the approximate train and\n    #   validation accuracy, and the epoch duration; nothing is returned.\n    # - The Adam optimizer and the cross-entropy loss function are created\n    #   inside this function.\n    #\n    # Parameters:\n    #   args: Parsed command-line arguments (`lr` and `epochs` are read here).\n    #   rank: Rank of the current process; only rank 0 prints.\n    #   train_dataloader: DataLoader that provides training mini-batches.\n    #   valid_dataloader: DataLoader that provides validation mini-batches.\n    #   model: The GraphSAGE model.\n    #   device: The device (CPU/GPU) to run the training on.\n    #####################################################################\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n    loss_fn = torch.nn.CrossEntropyLoss()\n\n    for epoch in range(args.epochs):\n        train_loss, train_acc, duration = train_helper(\n            rank,\n            train_dataloader,\n            model,\n            optimizer,\n            loss_fn,\n            device,\n        )\n        val_acc = evaluate(rank, model, valid_dataloader, device)\n        if rank == 0:\n            print(\n                f\"Epoch {epoch:02d}, Loss: {train_loss.item():.4f}, \"\n                f\"Approx. Train: {train_acc.item():.4f}, \"\n                f\"Approx. 
Val: {val_acc.item():.4f}, \"\n                f\"Time: {duration}s\"\n            )\n\n\n@torch.compile\ndef evaluate_step(minibatch, model):\n    node_features = minibatch.node_features[\"feat\"]\n    labels = minibatch.labels\n    out = model(minibatch, node_features)\n    num_correct = accuracy(out, labels) * labels.size(0)\n    return num_correct, labels.size(0)\n\n\n@torch.no_grad()\ndef evaluate(rank, model, dataloader, device):\n    model.eval()\n    total_correct = torch.zeros(1, dtype=torch.float64, device=device)\n    total_samples = 0\n    for minibatch in (\n        tqdm(dataloader, \"Evaluating\") if rank == 0 else dataloader\n    ):\n        num_correct, num_samples = evaluate_step(minibatch, model)\n        total_correct += num_correct\n        total_samples += num_samples\n\n    return weighted_reduce(total_correct, total_samples)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Which dataset are you going to use?\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=10, help=\"Number of training epochs.\"\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=0.003,\n        help=\"Learning rate for optimization.\",\n    )\n    parser.add_argument(\n        \"--batch-size\", type=int, default=1024, help=\"Batch size for training.\"\n    )\n    parser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=0,\n        help=\"Number of workers for data loading.\",\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-products\",\n        choices=[\n            \"ogbn-arxiv\",\n            \"ogbn-products\",\n            \"ogbn-papers100M\",\n            \"igb-hom-tiny\",\n            \"igb-hom-small\",\n            \"igb-hom-medium\",\n            \"igb-hom-large\",\n            \"igb-hom\",\n        ],\n        help=\"The dataset we can use for node classification example. 
Currently\"\n        \" ogbn-products, ogbn-arxiv, ogbn-papers100M and\"\n        \" igb-hom-[tiny|small|medium|large] and igb-hom datasets are supported.\",\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"10,10,10\",\n        help=\"Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)\"\n        \" identical with the number of layers in your model. Default: 10,10,10\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"pinned-pinned-cuda\",\n        choices=[\n            \"pinned-pinned-cuda\",\n            \"cuda-pinned-cuda\",\n            \"cuda-cuda-cuda\",\n        ],\n        help=\"Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,\"\n        \" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.\",\n    )\n    parser.add_argument(\n        \"--gpu-cache-size\",\n        type=int,\n        default=0,\n        help=\"The capacity of the GPU cache in bytes.\",\n    )\n    parser.add_argument(\n        \"--sample-mode\",\n        default=\"sample_neighbor\",\n        choices=[\"sample_neighbor\", \"sample_layer_neighbor\"],\n        help=\"The sampling function when doing layerwise sampling.\",\n    )\n    parser.add_argument(\n        \"--num-gpu-cached-edges\",\n        type=int,\n        default=0,\n        help=\"The number of edges to be cached from the graph on the GPU.\",\n    )\n    parser.add_argument(\n        \"--gpu-graph-caching-threshold\",\n        type=int,\n        default=1,\n        help=\"The number of accesses after which a vertex neighborhood will be cached.\",\n    )\n    parser.add_argument(\"--precision\", type=str, default=\"medium\")\n    parser.add_argument(\n        \"--cooperative\",\n        action=\"store_true\",\n        help=\"Enables Cooperative Minibatching from arXiv:2310.12403.\",\n    )\n    return parser.parse_args()\n\n\ndef run(rank, world_size, args, dataset):\n    # Set up multiprocessing environment.\n 
   torch.cuda.set_device(rank)\n    dist.init_process_group(\n        init_method=\"tcp://127.0.0.1:12345\",\n        rank=rank,\n        world_size=world_size,\n    )\n\n    print(f\"Training in {args.mode} mode.\")\n    args.graph_device, args.feature_device, args.device = args.mode.split(\"-\")\n    args.overlap_feature_fetch = args.feature_device == \"pinned\"\n    args.overlap_graph_fetch = args.graph_device == \"pinned\"\n\n    # Move the dataset to the selected storage.\n    graph = (\n        dataset.graph.pin_memory_()\n        if args.graph_device == \"pinned\"\n        else dataset.graph.to(args.graph_device)\n    )\n    features = (\n        dataset.feature.pin_memory_()\n        if args.feature_device == \"pinned\"\n        else dataset.feature.to(args.feature_device)\n    )\n\n    train_set = dataset.tasks[0].train_set\n    valid_set = dataset.tasks[0].validation_set\n    args.fanout = list(map(int, args.fanout.split(\",\")))\n\n    num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n\n    if args.gpu_cache_size > 0 and args.feature_device != \"cuda\":\n        features._features[(\"node\", None, \"feat\")] = gb.gpu_cached_feature(\n            features._features[(\"node\", None, \"feat\")],\n            args.gpu_cache_size,\n        )\n\n    train_dataloader, valid_dataloader = (\n        create_dataloader(\n            args,\n            graph=graph,\n            features=features,\n            itemset=itemset,\n            batch_size=args.batch_size,\n            fanout=args.fanout,\n            device=args.device,\n            job=job,\n        )\n        for itemset, job in zip([train_set, valid_set], [\"train\", \"evaluate\"])\n    )\n\n    in_channels = features.size(\"node\", None, \"feat\")[0]\n    hidden_channels = 256\n    model = GraphSAGE(\n        in_channels,\n        hidden_channels,\n        num_classes,\n        len(args.fanout),\n        args.cooperative,\n    ).to(args.device)\n    assert len(args.fanout) == 
len(model.layers)\n    model = torch.nn.parallel.DistributedDataParallel(model)\n\n    train(args, rank, train_dataloader, valid_dataloader, model, args.device)\n\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    if not torch.cuda.is_available():\n        print(\"Multi-GPU training requires GPUs.\")\n        exit(0)\n\n    torch.set_float32_matmul_precision(args.precision)\n\n    # Load and preprocess dataset.\n    print(\"Loading data...\")\n    dataset = gb.BuiltinDataset(args.dataset).load()\n\n    world_size = torch.cuda.device_count()\n\n    # Thread limiting to avoid resource competition.\n    os.environ[\"OMP_NUM_THREADS\"] = str(mp.cpu_count() // 2 // world_size)\n\n    mp.set_sharing_strategy(\"file_system\")\n    mp.spawn(\n        run,\n        args=(world_size, args, dataset),\n        nprocs=world_size,\n        join=True,\n    )\n"
  },
  {
    "path": "examples/graphbolt/pyg/node_classification.py",
    "content": "\"\"\"\nThis script demonstrates node classification with GraphSAGE on large graphs, \nmerging GraphBolt (GB) and PyTorch Geometric (PyG). GraphBolt efficiently\nmanages data loading for large datasets, crucial for mini-batch processing.\nPost data loading, PyG's user-friendly framework takes over for training,\nshowcasing seamless integration with GraphBolt. This combination offers an\nefficient alternative to traditional Deep Graph Library (DGL) methods,\nhighlighting adaptability and scalability in handling large-scale graph data\nfor diverse real-world applications.\n\nKey Features:\n- Implements the GraphSAGE model, a scalable GNN, for node classification on\n  large graphs.\n- Utilizes GraphBolt, an efficient framework for large-scale graph data processing.\n- Integrates with PyTorch Geometric for building and training the GraphSAGE model.\n- The script is well-documented, providing clear explanations at each step.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain: \n\nmain\n│\n├───> Load and preprocess dataset (GraphBolt)\n│     │\n│     └───> Utilize GraphBolt's BuiltinDataset for dataset handling\n│\n├───> Instantiate the SAGE model (PyTorch Geometric)\n│     │\n│     └───> Define the GraphSAGE model architecture\n│\n├───> Train the model\n│     │\n│     ├───> Mini-Batch Processing with GraphBolt\n│     │     │\n│     │     └───> Efficient handling of mini-batches using GraphBolt's utilities\n│     │\n│     └───> Training Loop\n│           │\n│           ├───> Forward and backward passes\n│           │\n│           ├───> Convert GraphBolt MiniBatch to PyG Data\n│           │\n│           └───> Parameters optimization\n│\n└───> Evaluate the model\n      │\n      └───> Performance assessment on validation and test datasets\n            │\n            └───> Accuracy and other relevant metrics calculation\n\n\n\"\"\"\n\nimport argparse\n\nimport dgl.graphbolt as gb\nimport torch\nimport torch.nn.functional 
as F\nimport torchmetrics.functional as MF\nfrom torch_geometric.nn import SAGEConv\nfrom tqdm import tqdm\n\n\nclass GraphSAGE(torch.nn.Module):\n    #####################################################################\n    # (HIGHLIGHT) Define the GraphSAGE model architecture.\n    #\n    # - This class inherits from `torch.nn.Module`.\n    # - Two convolutional layers are created using the SAGEConv class from PyG.\n    # - 'in_size', 'hidden_size', 'out_size' are the sizes of\n    #   the input, hidden, and output features, respectively.\n    # - The forward method defines the computation performed at every call.\n    # - It's adopted from the official PyG example which can be found at\n    # https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_products_sage.py\n    #####################################################################\n    def __init__(self, in_size, hidden_size, out_size):\n        super(GraphSAGE, self).__init__()\n        self.layers = torch.nn.ModuleList()\n        self.layers.append(SAGEConv(in_size, hidden_size))\n        self.layers.append(SAGEConv(hidden_size, hidden_size))\n        self.layers.append(SAGEConv(hidden_size, out_size))\n\n    def forward(self, x, edge_index):\n        for i, layer in enumerate(self.layers):\n            x = layer(x, edge_index)\n            if i != len(self.layers) - 1:\n                x = x.relu()\n                x = F.dropout(x, p=0.5, training=self.training)\n        return x\n\n    def inference(self, dataloader, x_all, device):\n        \"\"\"Conduct layer-wise inference to get all the node embeddings.\"\"\"\n        for i, layer in tqdm(enumerate(self.layers), \"inference\"):\n            xs = []\n            for minibatch in dataloader:\n                # Call `to_pyg_data` to convert GB Minibatch to PyG Data.\n                pyg_data = minibatch.to_pyg_data()\n                n_id = pyg_data.n_id.to(\"cpu\")\n                x = x_all[n_id].to(device)\n                
edge_index = pyg_data.edge_index\n                x = layer(x, edge_index)\n                x = x[: pyg_data.batch_size]\n                if i != len(self.layers) - 1:\n                    x = x.relu()\n                xs.append(x.cpu())\n            x_all = torch.cat(xs, dim=0)\n        return x_all\n\n\ndef create_dataloader(\n    dataset_set, graph, feature, batch_size, fanout, device, job\n):\n    # Initialize an ItemSampler to sample mini-batches from the dataset.\n    datapipe = gb.ItemSampler(\n        dataset_set,\n        batch_size=batch_size,\n        shuffle=(job == \"train\"),\n        drop_last=(job == \"train\"),\n    )\n    # Sample neighbors for each node in the mini-batch.\n    datapipe = datapipe.sample_neighbor(\n        graph, fanout if job != \"infer\" else [-1]\n    )\n    # Copy the data to the specified device.\n    datapipe = datapipe.copy_to(device=device)\n    # Fetch node features for the sampled subgraph.\n    datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n    # Create and return a DataLoader to handle data loading.\n    dataloader = gb.DataLoader(datapipe, num_workers=0)\n\n    return dataloader\n\n\ndef train(model, dataloader, optimizer):\n    model.train()  # Set the model to training mode\n    total_loss = 0  # Accumulator for the total loss\n    total_correct = 0  # Accumulator for the total number of correct predictions\n    total_samples = 0  # Accumulator for the total number of samples processed\n    num_batches = 0  # Counter for the number of mini-batches processed\n\n    for _, minibatch in tqdm(enumerate(dataloader), \"training\"):\n        #####################################################################\n        # (HIGHLIGHT) Convert GraphBolt MiniBatch to PyG Data class.\n        #\n        # Call `MiniBatch.to_pyg_data()` and it will return a PyG Data class\n        # with necessary data and information.\n        #####################################################################\n     
   pyg_data = minibatch.to_pyg_data()\n\n        optimizer.zero_grad()\n        out = model(pyg_data.x, pyg_data.edge_index)[: pyg_data.y.shape[0]]\n        y = pyg_data.y\n        loss = F.cross_entropy(out, y)\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss)\n        total_correct += int(out.argmax(dim=-1).eq(y).sum())\n        total_samples += y.shape[0]\n        num_batches += 1\n    avg_loss = total_loss / num_batches\n    avg_accuracy = total_correct / total_samples\n    return avg_loss, avg_accuracy\n\n\n@torch.no_grad()\ndef evaluate(model, dataloader, num_classes):\n    model.eval()\n    y_hats = []\n    ys = []\n    for _, minibatch in tqdm(enumerate(dataloader), \"evaluating\"):\n        pyg_data = minibatch.to_pyg_data()\n        out = model(pyg_data.x, pyg_data.edge_index)[: pyg_data.y.shape[0]]\n        y = pyg_data.y\n        y_hats.append(out)\n        ys.append(y)\n\n    return MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(ys),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\n@torch.no_grad()\ndef layerwise_infer(\n    model, infer_dataloader, test_set, feature, num_classes, device\n):\n    model.eval()\n    features = feature.read(\"node\", None, \"feat\")\n    pred = model.inference(infer_dataloader, features, device)\n    pred = pred[test_set._items[0]]\n    label = test_set._items[1].to(pred.device)\n\n    return MF.accuracy(\n        pred,\n        label,\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Which dataset are you going to use?\"\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-products\",\n        help='Name of the dataset to use (e.g., \"ogbn-products\", \"ogbn-arxiv\")',\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=10, help=\"Number of training epochs.\"\n    )\n    
parser.add_argument(\n        \"--batch-size\", type=int, default=1024, help=\"Batch size for training.\"\n    )\n    args = parser.parse_args()\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    dataset_name = args.dataset\n    dataset = gb.BuiltinDataset(dataset_name).load()\n    graph = dataset.graph\n    feature = dataset.feature.pin_memory_()\n    train_set = dataset.tasks[0].train_set\n    valid_set = dataset.tasks[0].validation_set\n    test_set = dataset.tasks[0].test_set\n    all_nodes_set = dataset.all_nodes_set\n    num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n\n    train_dataloader = create_dataloader(\n        train_set,\n        graph,\n        feature,\n        args.batch_size,\n        [5, 10, 15],\n        device,\n        job=\"train\",\n    )\n    valid_dataloader = create_dataloader(\n        valid_set,\n        graph,\n        feature,\n        args.batch_size,\n        [5, 10, 15],\n        device,\n        job=\"evaluate\",\n    )\n    infer_dataloader = create_dataloader(\n        all_nodes_set,\n        graph,\n        feature,\n        4 * args.batch_size,\n        [-1],\n        device,\n        job=\"infer\",\n    )\n    in_channels = feature.size(\"node\", None, \"feat\")[0]\n    hidden_channels = 256\n    model = GraphSAGE(in_channels, hidden_channels, num_classes).to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)\n    for epoch in range(args.epochs):\n        train_loss, train_accuracy = train(model, train_dataloader, optimizer)\n\n        valid_accuracy = evaluate(model, valid_dataloader, num_classes)\n        print(\n            f\"Epoch {epoch}, Train Loss: {train_loss:.4f}, \"\n            f\"Train Accuracy: {train_accuracy:.4f}, \"\n            f\"Valid Accuracy: {valid_accuracy:.4f}\"\n        )\n    test_accuracy = layerwise_infer(\n        model, infer_dataloader, test_set, feature, num_classes, device\n    )\n    print(f\"Test Accuracy: 
{test_accuracy:.4f}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/graphbolt/pyg/node_classification_advanced.py",
    "content": "\"\"\"\nThis script demonstrates node classification with GraphSAGE on large graphs, \nmerging GraphBolt (GB) and PyTorch Geometric (PyG). GraphBolt efficiently manages \ndata loading for large datasets, crucial for mini-batch processing. Post data \nloading, PyG's user-friendly framework takes over for training, showcasing seamless \nintegration with GraphBolt. This combination offers an efficient alternative to \ntraditional Deep Graph Library (DGL) methods, highlighting adaptability and \nscalability in handling large-scale graph data for diverse real-world applications.\n\n\n\nKey Features:\n- Implements the GraphSAGE model, a scalable GNN, for node classification on large graphs.\n- Utilizes GraphBolt, an efficient framework for large-scale graph data processing.\n- Integrates with PyTorch Geometric for building and training the GraphSAGE model.\n- The script is well-documented, providing clear explanations at each step.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain: \n\nmain\n│\n├───> Load and preprocess dataset (GraphBolt)\n│     │\n│     └───> Utilize GraphBolt's BuiltinDataset for dataset handling\n│\n├───> Instantiate the SAGE model (PyTorch Geometric)\n│     │\n│     └───> Define the GraphSAGE model architecture\n│\n├───> Train the model\n│     │\n│     ├───> Mini-Batch Processing with GraphBolt\n│     │     │\n│     │     └───> Efficient handling of mini-batches using GraphBolt's utilities\n│     │\n│     └───> Training Loop\n│           │\n│           ├───> Forward and backward passes\n│           │\n│           └───> Parameters optimization\n│\n└───> Evaluate the model\n      │\n      └───> Performance assessment on validation and test datasets\n            │\n            └───> Accuracy and other relevant metrics calculation\n\n\n\"\"\"\n\nimport argparse\nimport time\n\nimport dgl.graphbolt as gb\nimport torch\n\n# For torch.compile until https://github.com/pytorch/pytorch/issues/121197 is\n# 
resolved.\nimport torch._inductor.codecache\n\ntorch._dynamo.config.cache_size_limit = 32\n\nimport torch.nn.functional as F\nfrom torch_geometric.nn import SAGEConv\nfrom tqdm import tqdm\n\n\ndef accuracy(out, labels):\n    assert out.ndim == 2\n    assert out.size(0) == labels.size(0)\n    assert labels.ndim == 1 or (labels.ndim == 2 and labels.size(1) == 1)\n    labels = labels.flatten()\n    predictions = torch.argmax(out, 1)\n    return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)\n\n\nclass GraphSAGE(torch.nn.Module):\n    #####################################################################\n    # (HIGHLIGHT) Define the GraphSAGE model architecture.\n    #\n    # - This class inherits from `torch.nn.Module`.\n    # - Two convolutional layers are created using the SAGEConv class from PyG.\n    # - 'in_size', 'hidden_size', 'out_size' are the sizes of\n    #   the input, hidden, and output features, respectively.\n    # - The forward method defines the computation performed at every call.\n    #####################################################################\n    def __init__(self, in_size, hidden_size, out_size, n_layers):\n        super(GraphSAGE, self).__init__()\n        self.layers = torch.nn.ModuleList()\n        sizes = [in_size] + [hidden_size] * (n_layers - 1) + [out_size]\n        for i in range(n_layers):\n            self.layers.append(SAGEConv(sizes[i], sizes[i + 1]))\n        self.hidden_size = hidden_size\n        self.out_size = out_size\n\n    def forward(self, subgraphs, x):\n        h = x\n        for i, (layer, subgraph) in enumerate(zip(self.layers, subgraphs)):\n            #####################################################################\n            # (HIGHLIGHT) Convert given features to be consumed by a PyG layer.\n            #\n            #   PyG layers have two modes, bipartite and normal. 
We slice the\n            #   given features to get src and dst features to use the PyG layers\n            #   in the more efficient bipartite mode.\n            #####################################################################\n            h, edge_index, size = subgraph.to_pyg(h)\n            h = layer(h, edge_index, size=size)\n            if i != len(subgraphs) - 1:\n                h = F.relu(h)\n        return h\n\n    def inference(self, graph, features, dataloader, storage_device):\n        \"\"\"Conduct layer-wise inference to get all the node embeddings.\"\"\"\n        pin_memory = storage_device == \"pinned\"\n        buffer_device = torch.device(\"cpu\" if pin_memory else storage_device)\n\n        for layer_idx, layer in enumerate(self.layers):\n            is_last_layer = layer_idx == len(self.layers) - 1\n\n            y = torch.empty(\n                graph.total_num_nodes,\n                self.out_size if is_last_layer else self.hidden_size,\n                dtype=torch.float32,\n                device=buffer_device,\n                pin_memory=pin_memory,\n            )\n            for data in tqdm(dataloader, \"Inferencing\"):\n                # len(data.sampled_subgraphs) = 1\n                h, edge_index, size = data.sampled_subgraphs[0].to_pyg(\n                    data.node_features[\"feat\"]\n                )\n                hidden_x = layer(h, edge_index, size=size)\n                if not is_last_layer:\n                    hidden_x = F.relu(hidden_x)\n                # By design, our output nodes are contiguous.\n                y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(\n                    buffer_device\n                )\n            if not is_last_layer:\n                features.update(\"node\", None, \"feat\", y)\n\n        return y\n\n\ndef create_dataloader(\n    graph, features, itemset, batch_size, fanout, device, job\n):\n    #####################################################################\n    # 
(HIGHLIGHT) Create a data loader for efficiently loading graph data.\n    #\n    # - 'ItemSampler' samples mini-batches of node IDs from the dataset.\n    # - 'CopyTo' copies the fetched data to the specified device.\n    # - 'sample_neighbor' performs neighbor sampling on the graph.\n    # - 'FeatureFetcher' fetches node features based on the sampled subgraph.\n\n    #####################################################################\n    # Create a datapipe for mini-batch sampling with a specific neighbor fanout.\n    # Here, [10, 10, 10] specifies the number of neighbors sampled for each node at each layer.\n    # We're using `sample_neighbor` for consistency with DGL's sampling API.\n    # Note: GraphBolt offers additional sampling methods, such as `sample_layer_neighbor`,\n    # which could provide further optimization and efficiency for GNN training.\n    # Users are encouraged to explore these advanced features for potentially improved performance.\n\n    # Initialize an ItemSampler to sample mini-batches from the dataset.\n    datapipe = gb.ItemSampler(\n        itemset,\n        batch_size=batch_size,\n        shuffle=(job == \"train\"),\n        drop_last=(job == \"train\"),\n    )\n    need_copy = True\n    # Copy the data to the specified device.\n    if args.graph_device != \"cpu\" and need_copy:\n        datapipe = datapipe.copy_to(device=device)\n        need_copy = False\n    # Sample neighbors for each node in the mini-batch.\n    datapipe = getattr(datapipe, args.sample_mode)(\n        graph,\n        fanout if job != \"infer\" else [-1],\n        overlap_fetch=args.overlap_graph_fetch,\n        num_gpu_cached_edges=args.num_gpu_cached_edges,\n        gpu_cache_threshold=args.gpu_graph_caching_threshold,\n        asynchronous=args.graph_device != \"cpu\",\n    )\n    # Copy the data to the specified device.\n    if args.feature_device != \"cpu\" and need_copy:\n        datapipe = datapipe.copy_to(device=device)\n        need_copy = False\n    # 
Fetch node features for the sampled subgraph.\n    datapipe = datapipe.fetch_feature(\n        features,\n        node_feature_keys=[\"feat\"],\n        overlap_fetch=args.overlap_feature_fetch,\n    )\n    # Copy the data to the specified device.\n    if need_copy:\n        datapipe = datapipe.copy_to(device=device)\n    # Create and return a DataLoader to handle data loading.\n    return gb.DataLoader(datapipe, num_workers=args.num_workers)\n\n\n@torch.compile\ndef train_step(minibatch, optimizer, model, loss_fn):\n    node_features = minibatch.node_features[\"feat\"]\n    labels = minibatch.labels\n    optimizer.zero_grad()\n    out = model(minibatch.sampled_subgraphs, node_features)\n    loss = loss_fn(out, labels)\n    num_correct = accuracy(out, labels) * labels.size(0)\n    loss.backward()\n    optimizer.step()\n    return loss.detach(), num_correct, labels.size(0)\n\n\ndef train_helper(dataloader, model, optimizer, loss_fn, device):\n    model.train()  # Set the model to training mode\n    total_loss = torch.zeros(1, device=device)  # Accumulator for the total loss\n    # Accumulator for the total number of correct predictions\n    total_correct = torch.zeros(1, dtype=torch.float64, device=device)\n    total_samples = 0  # Accumulator for the total number of samples processed\n    num_batches = 0  # Counter for the number of mini-batches processed\n    start = time.time()\n    for minibatch in tqdm(dataloader, \"Training\"):\n        loss, num_correct, num_samples = train_step(\n            minibatch, optimizer, model, loss_fn\n        )\n        total_loss += loss\n        total_correct += num_correct\n        total_samples += num_samples\n        num_batches += 1\n    train_loss = total_loss / num_batches\n    train_acc = total_correct / total_samples\n    end = time.time()\n    return train_loss, train_acc, end - start\n\n\ndef train(train_dataloader, valid_dataloader, model, device):\n    
#####################################################################\n    # (HIGHLIGHT) Train the model for one epoch.\n    #\n    # - Iterates over the data loader, fetching mini-batches of graph data.\n    # - For each mini-batch, it performs a forward pass, computes loss, and\n    #   updates the model parameters.\n    # - The function returns the average loss and accuracy for the epoch.\n    #\n    # Parameters:\n    #   model: The GraphSAGE model.\n    #   dataloader: DataLoader that provides mini-batches of graph data.\n    #   optimizer: Optimizer used for updating model parameters.\n    #   loss_fn: Loss function used for training.\n    #   device: The device (CPU/GPU) to run the training on.\n    #####################################################################\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n    loss_fn = torch.nn.CrossEntropyLoss()\n\n    for epoch in range(args.epochs):\n        train_loss, train_acc, duration = train_helper(\n            train_dataloader, model, optimizer, loss_fn, device\n        )\n        val_acc = evaluate(model, valid_dataloader, device)\n        print(\n            f\"Epoch {epoch:02d}, Loss: {train_loss.item():.4f}, \"\n            f\"Approx. Train: {train_acc.item():.4f}, \"\n            f\"Approx. 
Val: {val_acc.item():.4f}, \"\n            f\"Time: {duration}s\"\n        )\n\n\n@torch.no_grad()\ndef layerwise_infer(args, graph, features, test_set, all_nodes_set, model):\n    model.eval()\n    dataloader = create_dataloader(\n        graph=graph,\n        features=features,\n        itemset=all_nodes_set,\n        batch_size=4 * args.batch_size,\n        fanout=[-1],\n        device=args.device,\n        job=\"infer\",\n    )\n    pred = model.inference(graph, features, dataloader, args.feature_device)\n    pred = pred[test_set._items[0]]\n    label = test_set._items[1].to(pred.device)\n\n    return accuracy(pred, label)\n\n\n@torch.compile\ndef evaluate_step(minibatch, model):\n    node_features = minibatch.node_features[\"feat\"]\n    labels = minibatch.labels\n    out = model(minibatch.sampled_subgraphs, node_features)\n    num_correct = accuracy(out, labels) * labels.size(0)\n    return num_correct, labels.size(0)\n\n\n@torch.no_grad()\ndef evaluate(model, dataloader, device):\n    model.eval()\n    total_correct = torch.zeros(1, dtype=torch.float64, device=device)\n    total_samples = 0\n    for minibatch in tqdm(dataloader, \"Evaluating\"):\n        num_correct, num_samples = evaluate_step(minibatch, model)\n        total_correct += num_correct\n        total_samples += num_samples\n\n    return total_correct / total_samples\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Which dataset are you going to use?\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=10, help=\"Number of training epochs.\"\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=0.003,\n        help=\"Learning rate for optimization.\",\n    )\n    parser.add_argument(\n        \"--batch-size\", type=int, default=1024, help=\"Batch size for training.\"\n    )\n    parser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=0,\n        help=\"Number of workers for 
data loading.\",\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-products\",\n        choices=[\n            \"ogbn-arxiv\",\n            \"ogbn-products\",\n            \"ogbn-papers100M\",\n            \"igb-hom-tiny\",\n            \"igb-hom-small\",\n            \"igb-hom-medium\",\n            \"igb-hom-large\",\n            \"igb-hom\",\n        ],\n        help=\"The dataset we can use for node classification example. Currently\"\n        \" ogbn-products, ogbn-arxiv, ogbn-papers100M and\"\n        \" igb-hom-[tiny|small|medium|large] and igb-hom datasets are supported.\",\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"10,10,10\",\n        help=\"Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)\"\n        \" identical with the number of layers in your model. Default: 10,10,10\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"pinned-pinned-cuda\",\n        choices=[\n            \"cpu-cpu-cpu\",\n            \"cpu-cpu-cuda\",\n            \"cpu-pinned-cuda\",\n            \"pinned-pinned-cuda\",\n            \"cuda-pinned-cuda\",\n            \"cuda-cuda-cuda\",\n        ],\n        help=\"Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,\"\n        \" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.\",\n    )\n    parser.add_argument(\n        \"--gpu-cache-size\",\n        type=int,\n        default=0,\n        help=\"The capacity of the GPU cache in bytes.\",\n    )\n    parser.add_argument(\n        \"--sample-mode\",\n        default=\"sample_neighbor\",\n        choices=[\"sample_neighbor\", \"sample_layer_neighbor\"],\n        help=\"The sampling function when doing layerwise sampling.\",\n    )\n    parser.add_argument(\n        \"--num-gpu-cached-edges\",\n        type=int,\n        default=0,\n        help=\"The number of edges to be cached from the graph on the GPU.\",\n  
  )\n    parser.add_argument(\n        \"--gpu-graph-caching-threshold\",\n        type=int,\n        default=1,\n        help=\"The number of accesses after which a vertex neighborhood will be cached.\",\n    )\n    parser.add_argument(\"--precision\", type=str, default=\"high\")\n    return parser.parse_args()\n\n\ndef main():\n    torch.set_float32_matmul_precision(args.precision)\n    if not torch.cuda.is_available():\n        args.mode = \"cpu-cpu-cpu\"\n    print(f\"Training in {args.mode} mode.\")\n    args.graph_device, args.feature_device, args.device = args.mode.split(\"-\")\n    args.overlap_feature_fetch = args.feature_device == \"pinned\"\n    args.overlap_graph_fetch = args.graph_device == \"pinned\"\n\n    # Load and preprocess dataset.\n    print(\"Loading data...\")\n    dataset = gb.BuiltinDataset(args.dataset).load()\n\n    # Move the dataset to the selected storage.\n    graph = (\n        dataset.graph.pin_memory_()\n        if args.graph_device == \"pinned\"\n        else dataset.graph.to(args.graph_device)\n    )\n    features = (\n        dataset.feature.pin_memory_()\n        if args.feature_device == \"pinned\"\n        else dataset.feature.to(args.feature_device)\n    )\n\n    train_set = dataset.tasks[0].train_set\n    valid_set = dataset.tasks[0].validation_set\n    test_set = dataset.tasks[0].test_set\n    all_nodes_set = dataset.all_nodes_set\n    args.fanout = list(map(int, args.fanout.split(\",\")))\n\n    num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n\n    if args.gpu_cache_size > 0 and args.feature_device != \"cuda\":\n        features._features[(\"node\", None, \"feat\")] = gb.gpu_cached_feature(\n            features._features[(\"node\", None, \"feat\")],\n            args.gpu_cache_size,\n        )\n\n    train_dataloader, valid_dataloader = (\n        create_dataloader(\n            graph=graph,\n            features=features,\n            itemset=itemset,\n            batch_size=args.batch_size,\n            
fanout=args.fanout,\n            device=args.device,\n            job=job,\n        )\n        for itemset, job in zip([train_set, valid_set], [\"train\", \"evaluate\"])\n    )\n\n    in_channels = features.size(\"node\", None, \"feat\")[0]\n    hidden_channels = 256\n    model = GraphSAGE(\n        in_channels, hidden_channels, num_classes, len(args.fanout)\n    ).to(args.device)\n    assert len(args.fanout) == len(model.layers)\n\n    train(train_dataloader, valid_dataloader, model, args.device)\n\n    # Test the model.\n    print(\"Testing...\")\n    test_acc = layerwise_infer(\n        args,\n        graph,\n        features,\n        test_set,\n        all_nodes_set,\n        model,\n    )\n    print(f\"Test accuracy {test_acc.item():.4f}\")\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main()\n"
  },
  {
    "path": "examples/graphbolt/quickstart/README.md",
    "content": "# Graphbolt Quickstart Tutorial\n\nGraphbolt provides all you need to create a dataloader to train a Graph Neural Network.\n\n## Examples\n\n - The [node_classification.py](https://github.com/dmlc/dgl/blob/master/examples/graphbolt/quickstart/node_classification.py)\n   shows how to create a Graphbolt dataloader to train a 2-layer Graph Convolutional Network node\n   classification model.\n - The [link_prediction.py](https://github.com/dmlc/dgl/blob/master/examples/graphbolt/quickstart/link_prediction.py)\n   shows how to create a Graphbolt dataloader to train a 2-layer GraphSAGE link prediction model.\n"
  },
  {
    "path": "examples/graphbolt/quickstart/link_prediction.py",
    "content": "\"\"\"\nThis example shows how to create a GraphBolt dataloader to sample and train a\nlink prediction model with the Cora dataset.\n\nDisclaimer: Please note that the test edges are not excluded from the original\ngraph in the dataset, which could lead to data leakage. We are ignoring this\nissue for this example because we are focused on demonstrating usability.\n\"\"\"\n\nimport dgl.graphbolt as gb\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn import SAGEConv\nfrom torcheval.metrics import BinaryAUROC\n\n\n############################################################################\n# (HIGHLIGHT) Create a single process dataloader with dgl graphbolt package.\n############################################################################\ndef create_dataloader(dataset, device, is_train=True):\n    # The second of two tasks in the dataset is link prediction.\n    task = dataset.tasks[1]\n    itemset = task.train_set if is_train else task.test_set\n\n    # Sample seed edges from the itemset.\n    datapipe = gb.ItemSampler(itemset, batch_size=256)\n\n    # Copy the mini-batch to the designated device for sampling and training.\n    datapipe = datapipe.copy_to(device)\n\n    if is_train:\n        # Sample negative edges for the seed edges.\n        datapipe = datapipe.sample_uniform_negative(\n            dataset.graph, negative_ratio=1\n        )\n\n        # Sample neighbors for the seed nodes.\n        datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[4, 2])\n\n        # Exclude seed edges from the subgraph.\n        datapipe = datapipe.transform(gb.exclude_seed_edges)\n\n    else:\n        # Sample neighbors for the seed nodes.\n        datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[-1, -1])\n\n    # Fetch features for sampled nodes.\n    datapipe = datapipe.fetch_feature(\n        dataset.feature, node_feature_keys=[\"feat\"]\n    )\n\n    # Initiate the dataloader for the datapipe.\n    
return gb.DataLoader(datapipe)\n\n\nclass GraphSAGE(nn.Module):\n    def __init__(self, in_size, hidden_size=16):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        self.layers.append(SAGEConv(in_size, hidden_size, \"mean\"))\n        self.layers.append(SAGEConv(hidden_size, hidden_size, \"mean\"))\n        self.predictor = nn.Sequential(\n            nn.Linear(hidden_size, hidden_size),\n            nn.ReLU(),\n            nn.Linear(hidden_size, 1),\n        )\n\n    def forward(self, blocks, x):\n        hidden_x = x\n        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n            hidden_x = layer(block, hidden_x)\n            is_last_layer = layer_idx == len(self.layers) - 1\n            if not is_last_layer:\n                hidden_x = F.relu(hidden_x)\n        return hidden_x\n\n\n@torch.no_grad()\ndef evaluate(model, dataset, device):\n    model.eval()\n    dataloader = create_dataloader(dataset, device, is_train=False)\n\n    logits = []\n    labels = []\n    for step, data in enumerate(dataloader):\n        # Get node pairs with labels for loss calculation.\n        compacted_seeds = data.compacted_seeds.T\n        label = data.labels\n\n        # The features of sampled nodes.\n        x = data.node_features[\"feat\"]\n\n        # Forward.\n        y = model(data.blocks, x)\n        logit = (\n            model.predictor(\n                y[compacted_seeds[0].long()] * y[compacted_seeds[1].long()]\n            )\n            .squeeze()\n            .detach()\n        )\n\n        logits.append(logit)\n        labels.append(label)\n\n    logits = torch.cat(logits, dim=0)\n    labels = torch.cat(labels, dim=0)\n\n    # Compute the AUROC score.\n    metric = BinaryAUROC()\n    metric.update(logits, labels)\n    score = metric.compute().item()\n    print(f\"AUC: {score:.3f}\")\n\n\ndef train(model, dataset, device):\n    dataloader = create_dataloader(dataset, device)\n    optimizer = 
torch.optim.Adam(model.parameters(), lr=1e-2)\n\n    for epoch in range(10):\n        model.train()\n        total_loss = 0\n        ########################################################################\n        # (HIGHLIGHT) Iterate over the dataloader and train the model with all\n        # mini-batches.\n        ########################################################################\n        for step, data in enumerate(dataloader):\n            # Get node pairs with labels for loss calculation.\n            compacted_seeds = data.compacted_seeds.T\n            labels = data.labels\n\n            # The features of sampled nodes.\n            x = data.node_features[\"feat\"]\n\n            # Forward.\n            y = model(data.blocks, x)\n            logits = model.predictor(\n                y[compacted_seeds[0].long()] * y[compacted_seeds[1].long()]\n            ).squeeze()\n\n            # Compute loss.\n            loss = F.binary_cross_entropy_with_logits(logits, labels.float())\n\n            # Backward.\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            total_loss += loss.item()\n\n        print(f\"Epoch {epoch:03d} | Loss {total_loss / (step + 1):.3f}\")\n\n\nif __name__ == \"__main__\":\n    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n    print(f\"Training in {device} mode.\")\n\n    # Load and preprocess dataset.\n    print(\"Loading data...\")\n    dataset = gb.BuiltinDataset(\"cora\").load()\n\n    # If a CUDA device is selected, we pin the graph and the features so that\n    # the GPU can access them.\n    if device == torch.device(\"cuda:0\"):\n        dataset.graph.pin_memory_()\n        dataset.feature.pin_memory_()\n\n    in_size = dataset.feature.size(\"node\", None, \"feat\")[0]\n    model = GraphSAGE(in_size).to(device)\n\n    # Model training.\n    print(\"Training...\")\n    train(model, dataset, device)\n\n    # Test the model.\n    
print(\"Testing...\")\n    evaluate(model, dataset, device)\n"
  },
  {
    "path": "examples/graphbolt/quickstart/node_classification.py",
    "content": "\"\"\"\nThis example shows how to create a GraphBolt dataloader to sample and train a\nnode classification model with the Cora dataset.\n\"\"\"\nimport dgl.graphbolt as gb\nimport dgl.nn as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\n\n\n############################################################################\n# (HIGHLIGHT) Create a single process dataloader with dgl graphbolt package.\n############################################################################\ndef create_dataloader(dataset, itemset, device):\n    # Sample seed nodes from the itemset.\n    datapipe = gb.ItemSampler(itemset, batch_size=16)\n\n    # Copy the mini-batch to the designated device for sampling and training.\n    datapipe = datapipe.copy_to(device)\n\n    # Sample neighbors for the seed nodes.\n    datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[4, 2])\n\n    # Fetch features for sampled nodes.\n    datapipe = datapipe.fetch_feature(\n        dataset.feature, node_feature_keys=[\"feat\"]\n    )\n\n    # Initiate the dataloader for the datapipe.\n    return gb.DataLoader(datapipe)\n\n\nclass GCN(nn.Module):\n    def __init__(self, in_size, out_size, hidden_size=16):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.GraphConv(in_size, hidden_size))\n        self.layers.append(dglnn.GraphConv(hidden_size, out_size))\n\n    def forward(self, blocks, x):\n        hidden_x = x\n        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n            hidden_x = layer(block, hidden_x)\n            is_last_layer = layer_idx == len(self.layers) - 1\n            if not is_last_layer:\n                hidden_x = F.relu(hidden_x)\n        return hidden_x\n\n\n@torch.no_grad()\ndef evaluate(model, dataset, itemset, device):\n    model.eval()\n    y = []\n    y_hats = []\n    dataloader = create_dataloader(dataset, itemset, 
device)\n\n    for step, data in enumerate(dataloader):\n        x = data.node_features[\"feat\"]\n        y.append(data.labels)\n        y_hats.append(model(data.blocks, x))\n\n    return MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(y),\n        task=\"multiclass\",\n        num_classes=dataset.tasks[0].metadata[\"num_classes\"],\n    )\n\n\ndef train(model, dataset, device):\n    # The first of two tasks in the dataset is node classification.\n    task = dataset.tasks[0]\n    dataloader = create_dataloader(dataset, task.train_set, device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n\n    for epoch in range(10):\n        model.train()\n        total_loss = 0\n        ########################################################################\n        # (HIGHLIGHT) Iterate over the dataloader and train the model with all\n        # mini-batches.\n        ########################################################################\n        for step, data in enumerate(dataloader):\n            # The features of sampled nodes.\n            x = data.node_features[\"feat\"]\n\n            # The ground truth labels of the seed nodes.\n            y = data.labels\n\n            # Forward.\n            y_hat = model(data.blocks, x)\n\n            # Compute loss.\n            loss = F.cross_entropy(y_hat, y)\n\n            # Backward.\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            total_loss += loss.item()\n\n        # Evaluate the model.\n        val_acc = evaluate(model, dataset, task.validation_set, device)\n        test_acc = evaluate(model, dataset, task.test_set, device)\n        print(\n            f\"Epoch {epoch:03d} | Loss {total_loss / (step + 1):.3f} | \"\n            f\"Val Acc {val_acc.item():.3f} | Test Acc {test_acc.item():.3f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n    
print(f\"Training in {device} mode.\")\n\n    # Load and preprocess dataset.\n    print(\"Loading data...\")\n    dataset = gb.BuiltinDataset(\"cora\").load()\n\n    # If a CUDA device is selected, we pin the graph and the features so that\n    # the GPU can access them.\n    if device == torch.device(\"cuda:0\"):\n        dataset.graph.pin_memory_()\n        dataset.feature.pin_memory_()\n\n    in_size = dataset.feature.size(\"node\", None, \"feat\")[0]\n    out_size = dataset.tasks[0].metadata[\"num_classes\"]\n    model = GCN(in_size, out_size).to(device)\n\n    # Model training.\n    print(\"Training...\")\n    train(model, dataset, device)\n"
  },
  {
    "path": "examples/graphbolt/rgcn/README.md",
    "content": "# Node classification on heterogeneous graph with RGCN\n\nThis example aims to demonstrate how to run node classification task on heterogeneous graph with **GraphBolt**. Models are not tuned to achieve the best accuracy yet.\n\n## Run on `ogbn-mag` dataset\n\n### Sample on CPU and train/infer on CPU\n```\npython3 hetero_rgcn.py --dataset ogbn-mag\n```\n\n### Sample on CPU and train/infer on GPU\n```\npython3 hetero_rgcn.py --dataset ogbn-mag --num_gpus 1\n```\n\n### Resource usage and time cost\nBelow results are roughly collected from an AWS EC2 **g4dn.metal**, 384GB RAM, 96 vCPUs(Cascade Lake P-8259L), 8 NVIDIA T4 GPUs(16GB RAM). CPU RAM usage is the peak value of `used` field of `free` command which is a bit rough. Please refer to `RSS`/`USS`/`PSS` which are more accurate. GPU RAM usage is the peak value recorded by `nvidia-smi` command.\n\n| Dataset Size | CPU RAM Usage | Num of GPUs | GPU RAM Usage | Time Per Epoch(Training) |\n| ------------ | ------------- | ----------- | ------------- | ------------------------ |\n| ~1.1GB       | ~5.3GB        | 0           |  0GB          | ~230s                    |\n| ~1.1GB       | ~3GB          | 1           |  3.87GB       | ~64.6s                   |\n\n### Accuracies\n```\nEpoch: 01, Loss: 2.3434, Valid accuracy: 48.23%\nEpoch: 02, Loss: 1.5646, Valid accuracy: 48.49%\nEpoch: 03, Loss: 1.1633, Valid accuracy: 45.79%\nTest accuracy 44.6792\n```\n\n## Run on `ogb-lsc-mag240m` dataset\n\n### Sample on CPU and train/infer on CPU\n```\npython3 hetero_rgcn.py --dataset ogb-lsc-mag240m\n```\n\n### Sample on CPU and train/infer on GPU\n```\npython3 hetero_rgcn.py --dataset ogb-lsc-mag240m --num_gpus 1\n```\n\n### Resource usage and time cost\nBelow results are roughly collected from an AWS EC2 **g4dn.metal**, 384GB RAM, 96 vCPUs(Cascade Lake P-8259L), 8 NVIDIA T4 GPUs(16GB RAM). CPU RAM usage is the peak value of `used` field of `free` command which is a bit rough. 
Please refer to `RSS`/`USS`/`PSS` which are more accurate. GPU RAM usage is the peak value recorded by `nvidia-smi` command.\n\n> **Note:**\n`buffer/cache` is heavily used during training; it takes about 300GB. If more RAM is available, more `buffer/cache` will be consumed, as the graph size is about 55GB and the feature data is about 350GB.\nAlso, the first epoch is quite slow because `buffer/cache` is not warmed up yet. For GPU training, the first epoch takes **1030s**.\nEven in the following epochs, the time per epoch varies.\n\n| Dataset Size | CPU RAM Usage | Num of GPUs | GPU RAM Usage | Time Per Epoch(Training) |\n| ------------ | ------------- | ----------- | ------------- | ------------------------ |\n| ~404GB       | ~67GB         | 0           |  0GB          | ~248s                    |\n| ~404GB       | ~60GB         | 1           |  15GB         | ~166s                    |\n\n### Accuracies\n```\nEpoch: 01, Loss: 2.1432, Valid accuracy: 50.21%\nEpoch: 02, Loss: 1.9267, Valid accuracy: 50.77%\nEpoch: 03, Loss: 1.8797, Valid accuracy: 53.38%\n```\n"
  },
  {
    "path": "examples/graphbolt/rgcn/hetero_rgcn.py",
    "content": "\"\"\"\nThis script is a GraphBolt counterpart of\n``/examples/core/rgcn/hetero_rgcn.py``. It demonstrates how to use GraphBolt\nto train a R-GCN model for node classification on the Open Graph Benchmark\n(OGB) dataset \"ogbn-mag\" and \"ogb-lsc-mag240m\". For more details on \"ogbn-mag\",\nplease refer to the OGB website: (https://ogb.stanford.edu/docs/linkprop/). For\nmore details on \"ogb-lsc-mag240m\", please refer to the OGB website:\n(https://ogb.stanford.edu/docs/lsc/mag240m/).\n\nPaper [Modeling Relational Data with Graph Convolutional Networks]\n(https://arxiv.org/abs/1703.06103).\n\nThis example highlights the user experience of GraphBolt while the model and\ntraining/evaluation procedures are almost identical to the original DGL\nimplementation. Please refer to original DGL implementation for more details.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> load_dataset\n│     │\n│     └───> Load dataset\n│\n├───> rel_graph_embed [HIGHLIGHT]\n│     │\n│     └───> Generate graph embeddings\n│\n├───> Instantiate RGCN model\n│     │\n│     ├───> RelGraphConvLayer (input to hidden)\n│     │\n│     └───> RelGraphConvLayer (hidden to output)\n│\n└───> run\n      │\n      │\n      └───> Training loop\n            │\n            ├───> EntityClassify.forward (RGCN model forward pass)\n            │\n            └───> validate and test\n                  │\n                  └───> EntityClassify.evaluate\n\"\"\"\n\nimport argparse\nimport itertools\nimport sys\nimport time\n\nimport dgl\nimport dgl.graphbolt as gb\nimport dgl.nn as dglnn\n\nimport psutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn import HeteroEmbedding\nfrom ogb.lsc import MAG240MEvaluator\nfrom ogb.nodeproppred import Evaluator\nfrom tqdm import tqdm\n\n\ndef load_dataset(dataset_name):\n    \"\"\"Load the dataset and return the graph, features, train/valid/test sets\n    and the number of 
classes.\n\n    Here, we use `BuiltinDataset` to load the dataset which returns graph,\n    features, train/valid/test sets and the number of classes.\n    \"\"\"\n    dataset = gb.BuiltinDataset(dataset_name).load()\n    print(f\"Loaded dataset: {dataset.tasks[0].metadata['name']}\")\n\n    graph = dataset.graph\n    features = dataset.feature\n    train_set = dataset.tasks[0].train_set\n    valid_set = dataset.tasks[0].validation_set\n    test_set = dataset.tasks[0].test_set\n    num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n\n    return (\n        graph,\n        features,\n        train_set,\n        valid_set,\n        test_set,\n        num_classes,\n    )\n\n\ndef create_dataloader(\n    name,\n    graph,\n    features,\n    item_set,\n    device,\n    batch_size,\n    fanouts,\n    shuffle,\n    num_workers,\n):\n    \"\"\"Create a GraphBolt dataloader for training, validation or testing.\"\"\"\n\n    ###########################################################################\n    # Initialize the ItemSampler to sample mini-batches from the dataset.\n    # `item_set`:\n    #   The set of items to sample from. 
This is typically the\n    #   training, validation or test set.\n    # `batch_size`:\n    #   The number of nodes to sample in each mini-batch.\n    # `shuffle`:\n    #   Whether to shuffle the items in the dataset before sampling.\n    datapipe = gb.ItemSampler(item_set, batch_size=batch_size, shuffle=shuffle)\n\n    # Move the mini-batch to the appropriate device.\n    # `device`:\n    #   The device to move the mini-batch to.\n    datapipe = datapipe.copy_to(device)\n\n    # Sample neighbors for each seed node in the mini-batch.\n    # `graph`:\n    #   The graph(FusedCSCSamplingGraph) from which to sample neighbors.\n    # `fanouts`:\n    #   The number of neighbors to sample for each node in each layer.\n    datapipe = datapipe.sample_neighbor(\n        graph,\n        fanouts=fanouts,\n        overlap_fetch=args.overlap_graph_fetch,\n        asynchronous=args.asynchronous,\n    )\n\n    # Fetch the features for each node in the mini-batch.\n    # `features`:\n    #   The feature store from which to fetch the features.\n    # `node_feature_keys`:\n    #   The node features to fetch. 
This is a dictionary where the keys are\n    #   node types and the values are lists of feature names.\n    node_feature_keys = {\"paper\": [\"feat\"]}\n    if name == \"ogb-lsc-mag240m\":\n        node_feature_keys[\"author\"] = [\"feat\"]\n        node_feature_keys[\"institution\"] = [\"feat\"]\n    datapipe = datapipe.fetch_feature(features, node_feature_keys)\n\n    # Create a DataLoader from the datapipe.\n    # `num_workers`:\n    #   The number of worker processes to use for data loading.\n    return gb.DataLoader(datapipe, num_workers=num_workers)\n\n\ndef extract_embed(node_embed, input_nodes):\n    emb = node_embed(\n        {ntype: input_nodes[ntype] for ntype in input_nodes if ntype != \"paper\"}\n    )\n    return emb\n\n\ndef extract_node_features(name, block, data, node_embed, device):\n    \"\"\"Extract the node features from embedding layer or raw features.\"\"\"\n    if name == \"ogbn-mag\":\n        input_nodes = {\n            k: v.to(device) for k, v in block.srcdata[dgl.NID].items()\n        }\n        # Extract node embeddings for the input nodes.\n        node_features = extract_embed(node_embed, input_nodes)\n        # Add the batch's raw \"paper\" features. 
Corresponds to the content\n        # in the function `rel_graph_embed` comment.\n        node_features.update(\n            {\"paper\": data.node_features[(\"paper\", \"feat\")].to(device)}\n        )\n    else:\n        node_features = {\n            ntype: data.node_features[(ntype, \"feat\")]\n            for ntype in block.srctypes\n        }\n        # Original feature data are stored in float16 while model weights are\n        # float32, so we need to convert the features to float32.\n        node_features = {\n            k: v.to(device).float() for k, v in node_features.items()\n        }\n    return node_features\n\n\ndef rel_graph_embed(graph, embed_size):\n    \"\"\"Initialize a heterogenous embedding layer for all node types in the\n    graph, except for the \"paper\" node type.\n\n    The function constructs a dictionary 'node_num', where the keys are node\n    types (ntype) and the values are the number of nodes for each type. This\n    dictionary is used to create a HeteroEmbedding instance.\n\n    (HIGHLIGHT)\n    A HeteroEmbedding instance holds separate embedding layers for each node\n    type, each with its own feature space of dimensionality\n    (node_num[ntype], embed_size), where 'node_num[ntype]' is the number of\n    nodes of type 'ntype' and 'embed_size' is the embedding dimension.\n\n    The \"paper\" node type is specifically excluded, possibly because these nodes\n    might already have predefined feature representations, and therefore, do not\n    require an additional embedding layer.\n\n    Parameters\n    ----------\n    graph : FusedCSCSamplingGraph\n        The graph for which to create the heterogenous embedding layer.\n    embed_size : int\n        The size of the embedding vectors.\n\n    Returns\n    --------\n    HeteroEmbedding\n        A heterogenous embedding layer for all node types in the graph, except\n        for the \"paper\" node type.\n    \"\"\"\n    node_num = {}\n    node_type_to_id = graph.node_type_to_id\n    
node_type_offset = graph.node_type_offset\n    for ntype, ntype_id in node_type_to_id.items():\n        # Skip the \"paper\" node type.\n        if ntype == \"paper\":\n            continue\n        node_num[ntype] = (\n            node_type_offset[ntype_id + 1] - node_type_offset[ntype_id]\n        )\n    print(f\"node_num for rel_graph_embed: {node_num}\")\n    return HeteroEmbedding(node_num, embed_size)\n\n\nclass RelGraphConvLayer(nn.Module):\n    def __init__(\n        self,\n        in_size,\n        out_size,\n        ntypes,\n        relation_names,\n        activation=None,\n        dropout=0.0,\n    ):\n        super(RelGraphConvLayer, self).__init__()\n        self.in_size = in_size\n        self.out_size = out_size\n        self.ntypes = ntypes\n        self.relation_names = relation_names\n        self.activation = activation\n\n        ########################################################################\n        # (HIGHLIGHT) HeteroGraphConv is a graph convolution operator over\n        # heterogeneous graphs. A dictionary is passed where the key is the\n        # relation name and the value is the instance of GraphConv. norm=\"right\"\n        # is to divide the aggregated messages by each node’s in-degrees, which\n        # is equivalent to averaging the received messages. weight=False and\n        # bias=False as we will use our own weight matrices defined later.\n        ########################################################################\n        self.conv = dglnn.HeteroGraphConv(\n            {\n                rel: dglnn.GraphConv(\n                    in_size, out_size, norm=\"right\", weight=False, bias=False\n                )\n                for rel in relation_names\n            }\n        )\n\n        # Create a separate Linear layer for each relationship. 
Each\n        # relationship has its own weights which will be applied to the node\n        # features before performing convolution.\n        self.weight = nn.ModuleDict(\n            {\n                rel_name: nn.Linear(in_size, out_size, bias=False)\n                for rel_name in self.relation_names\n            }\n        )\n\n        # Create a separate Linear layer for each node type.\n        # loop_weights are used to update the output embedding of each target node\n        # based on its own features, thereby allowing the model to refine the node\n        # representations. Note that this does not imply the existence of self-loop\n        # edges in the graph. It is similar to residual connection.\n        self.loop_weights = nn.ModuleDict(\n            {\n                ntype: nn.Linear(in_size, out_size, bias=True)\n                for ntype in self.ntypes\n            }\n        )\n\n        self.loop_weights = nn.ModuleDict(\n            {\n                ntype: nn.Linear(in_size, out_size, bias=True)\n                for ntype in self.ntypes\n            }\n        )\n\n        self.dropout = nn.Dropout(dropout)\n        # Initialize parameters of the model.\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for layer in self.weight.values():\n            layer.reset_parameters()\n\n        for layer in self.loop_weights.values():\n            layer.reset_parameters()\n\n    def forward(self, g, inputs):\n        \"\"\"\n        Parameters\n        ----------\n        g : DGLGraph\n            Input graph.\n        inputs : dict[str, torch.Tensor]\n            Node feature for each node type.\n\n        Returns\n        -------\n        dict[str, torch.Tensor]\n            New node features for each node type.\n        \"\"\"\n        # Create a deep copy of the graph g with features saved in local\n        # frames to prevent side effects from modifying the graph.\n        g = g.local_var()\n\n        # Create a 
dictionary of weights for each relationship. The weights\n        # are retrieved from the Linear layers defined earlier.\n        weight_dict = {\n            rel_name: {\"weight\": self.weight[rel_name].weight.T}\n            for rel_name in self.relation_names\n        }\n\n        # Create a dictionary of node features for the destination nodes in\n        # the graph. We slice the node features according to the number of\n        # destination nodes of each type. This is necessary because when\n        # incorporating the effect of self-loop edges, we perform computations\n        # only on the destination nodes' features. By doing so, we ensure the\n        # feature dimensions match and prevent any misuse of incorrect node\n        # features.\n        inputs_dst = {\n            k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()\n        }\n        # Apply the convolution operation on the graph. mod_kwargs are\n        # additional arguments for each relation function defined in the\n        # HeteroGraphConv. In this case, it's the weights for each relation.\n        hs = self.conv(g, inputs, mod_kwargs=weight_dict)\n\n        def _apply(ntype, h):\n            # Apply the `loop_weight` to the input node features, effectively\n            # acting as a residual connection. This allows the model to refine\n            # node embeddings based on its current features.\n            h = h + self.loop_weights[ntype](inputs_dst[ntype])\n            if self.activation:\n                h = self.activation(h)\n            return self.dropout(h)\n\n        # Apply the function defined above for each node type. 
This will update\n        # the node features using the `loop_weights`, apply the activation\n        # function and dropout.\n        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}\n\n\nclass EntityClassify(nn.Module):\n    def __init__(self, graph, in_size, out_size):\n        super(EntityClassify, self).__init__()\n        self.in_size = in_size\n        self.hidden_size = 64\n        self.out_size = out_size\n\n        # Generate and sort a list of unique edge types from the input graph.\n        # eg. ['writes', 'cites']\n        etypes = list(graph.edge_type_to_id.keys())\n        etypes = [gb.etype_str_to_tuple(etype)[1] for etype in etypes]\n        self.relation_names = etypes\n        self.relation_names.sort()\n        self.dropout = 0.5\n        ntypes = list(graph.node_type_to_id.keys())\n        self.layers = nn.ModuleList()\n\n        # First layer: transform input features to hidden features. Use ReLU\n        # as the activation function and apply dropout for regularization.\n        self.layers.append(\n            RelGraphConvLayer(\n                self.in_size,\n                self.hidden_size,\n                ntypes,\n                self.relation_names,\n                activation=F.relu,\n                dropout=self.dropout,\n            )\n        )\n\n        # Second layer: transform hidden features to output features. 
No\n        # activation function is applied at this stage.\n        self.layers.append(\n            RelGraphConvLayer(\n                self.hidden_size,\n                self.out_size,\n                ntypes,\n                self.relation_names,\n                activation=None,\n            )\n        )\n\n    def reset_parameters(self):\n        # Reset the parameters of each layer.\n        for layer in self.layers:\n            layer.reset_parameters()\n\n    def forward(self, blocks, h):\n        for layer, block in zip(self.layers, blocks):\n            h = layer(block, h)\n        return h\n\n\n@torch.no_grad()\ndef evaluate(\n    name,\n    g,\n    model,\n    node_embed,\n    device,\n    item_set,\n    features,\n    num_workers,\n):\n    # Switches the model to evaluation mode.\n    model.eval()\n    category = \"paper\"\n    # An evaluator for the dataset.\n    if name == \"ogbn-mag\":\n        evaluator = Evaluator(name=name)\n    else:\n        evaluator = MAG240MEvaluator()\n\n    num_etype = len(g.num_edges)\n    data_loader = create_dataloader(\n        name,\n        g,\n        features,\n        item_set,\n        device,\n        batch_size=4096,\n        fanouts=[torch.full((num_etype,), 25), torch.full((num_etype,), 10)],\n        shuffle=False,\n        num_workers=num_workers,\n    )\n\n    # To store the predictions.\n    y_hats = list()\n    y_true = list()\n\n    for data in tqdm(data_loader, desc=\"Inference\"):\n        # Convert MiniBatch to DGL Blocks and move them to the target device.\n        blocks = [block.to(device) for block in data.blocks]\n        node_features = extract_node_features(\n            name, blocks[0], data, node_embed, device\n        )\n\n        # Generate predictions.\n        logits = model(blocks, node_features)\n\n        logits = logits[category]\n\n        # Apply softmax to the logits and get the prediction by selecting the\n        # argmax.\n        y_hat = 
logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True)\n        y_hats.append(y_hat.cpu())\n        y_true.append(data.labels[category].long())\n\n    y_pred = torch.cat(y_hats, dim=0)\n    y_true = torch.cat(y_true, dim=0)\n    y_true = torch.unsqueeze(y_true, 1)\n\n    if name == \"ogb-lsc-mag240m\":\n        y_pred = y_pred.view(-1)\n        y_true = y_true.view(-1)\n\n    return evaluator.eval({\"y_true\": y_true, \"y_pred\": y_pred})[\"acc\"]\n\n\ndef train(\n    name,\n    g,\n    model,\n    node_embed,\n    optimizer,\n    train_set,\n    valid_set,\n    device,\n    features,\n    num_workers,\n    num_epochs,\n):\n    print(\"Start to train...\")\n    category = \"paper\"\n\n    num_etype = len(g.num_edges)\n    data_loader = create_dataloader(\n        name,\n        g,\n        features,\n        train_set,\n        device,\n        batch_size=1024,\n        fanouts=[torch.full((num_etype,), 25), torch.full((num_etype,), 10)],\n        shuffle=True,\n        num_workers=num_workers,\n    )\n\n    # Typically, the best Validation performance is obtained after\n    # the 1st or 2nd epoch. 
This is why the max epoch is set to 3.\n    for epoch in range(num_epochs):\n        num_train = len(train_set)\n        t0 = time.time()\n        model.train()\n\n        total_loss = 0\n\n        for data in tqdm(data_loader, desc=f\"Training~Epoch {epoch + 1:02d}\"):\n            # Convert MiniBatch to DGL Blocks and move them to the target\n            # device.\n            blocks = [block.to(device) for block in data.blocks]\n\n            # Fetch the number of seed nodes in the batch.\n            num_seeds = blocks[-1].num_dst_nodes(category)\n\n            # Extract the node features from embedding layer or raw features.\n            node_features = extract_node_features(\n                name, blocks[0], data, node_embed, device\n            )\n\n            # Reset gradients.\n            optimizer.zero_grad()\n            # Generate predictions.\n            logits = model(blocks, node_features)[category]\n\n            y_hat = logits.log_softmax(dim=-1)\n            loss = F.nll_loss(y_hat, data.labels[category].long())\n            loss.backward()\n            optimizer.step()\n\n            total_loss += loss.item() * num_seeds\n\n        t1 = time.time()\n        loss = total_loss / num_train\n\n        # Evaluate the model on the val/test set.\n\n        print(\"Evaluating the model on the validation set.\")\n        valid_acc = evaluate(\n            name, g, model, node_embed, device, valid_set, features, num_workers\n        )\n        print(\"Finish evaluating on validation set.\")\n\n        print(\n            f\"Epoch: {epoch + 1:02d}, \"\n            f\"Loss: {loss:.4f}, \"\n            f\"Valid accuracy: {100 * valid_acc:.2f}%, \"\n            f\"Time {t1 - t0:.4f}\"\n        )\n\n\ndef main(args):\n    device = torch.device(\n        \"cuda\" if args.num_gpus > 0 and torch.cuda.is_available() else \"cpu\"\n    )\n\n    # Load dataset.\n    (\n        g,\n        features,\n        train_set,\n        valid_set,\n        test_set,\n        
num_classes,\n    ) = load_dataset(args.dataset)\n\n    # Move the dataset to the pinned memory to enable GPU access.\n    args.overlap_graph_fetch = False\n    args.asynchronous = False\n    if device == torch.device(\"cuda\"):\n        g = g.pin_memory_()\n        features = features.pin_memory_()\n        # Enable optimizations for sampling on the GPU.\n        args.overlap_graph_fetch = True\n        args.asynchronous = True\n\n    feat_size = features.size(\"node\", \"paper\", \"feat\")[0]\n\n    # As `ogb-lsc-mag240m` is a large dataset, features of `author` and\n    # `institution` are generated in advance and stored in the feature store.\n    # For `ogbn-mag`, we generate the features on the fly.\n    embed_layer = None\n    if args.dataset == \"ogbn-mag\":\n        # Create the embedding layer and move it to the appropriate device.\n        embed_layer = rel_graph_embed(g, feat_size).to(device)\n        print(\n            \"Number of embedding parameters: \"\n            f\"{sum(p.numel() for p in embed_layer.parameters())}\"\n        )\n\n    # Initialize the entity classification model.\n    model = EntityClassify(g, feat_size, num_classes).to(device)\n\n    print(\n        \"Number of model parameters: \"\n        f\"{sum(p.numel() for p in model.parameters())}\"\n    )\n\n    if embed_layer is not None:\n        embed_layer.reset_parameters()\n    model.reset_parameters()\n\n    # `itertools.chain()` is a function in Python's itertools module.\n    # It is used to flatten a list of iterables, making them act as\n    # one big iterable.\n    # In this context, the following code is used to create a single\n    # iterable over the parameters of both the model and the embed_layer,\n    # which is passed to the optimizer. 
The optimizer then updates all\n    # these parameters during the training process.\n    all_params = itertools.chain(\n        model.parameters(),\n        [] if embed_layer is None else embed_layer.parameters(),\n    )\n    optimizer = torch.optim.Adam(all_params, lr=0.01)\n\n    expected_max = int(psutil.cpu_count(logical=False))\n    if args.num_workers >= expected_max:\n        print(\n            \"[ERROR] You specified num_workers are larger than physical \"\n            f\"cores, please set any number less than {expected_max}\",\n            file=sys.stderr,\n        )\n\n    train(\n        args.dataset,\n        g,\n        model,\n        embed_layer,\n        optimizer,\n        train_set,\n        valid_set,\n        device,\n        features,\n        args.num_workers,\n        args.num_epochs,\n    )\n\n    print(\"Testing...\")\n    test_acc = evaluate(\n        args.dataset,\n        g,\n        model,\n        embed_layer,\n        device,\n        test_set,\n        features,\n        args.num_workers,\n    )\n    print(f\"Test accuracy {test_acc*100:.4f}\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GraphBolt RGCN\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-mag\",\n        choices=[\"ogbn-mag\", \"ogb-lsc-mag240m\"],\n        help=\"Dataset name. Possible values: ogbn-mag, ogb-lsc-mag240m\",\n    )\n    parser.add_argument(\"--num_epochs\", type=int, default=3)\n    parser.add_argument(\"--num_workers\", type=int, default=0)\n    parser.add_argument(\"--num_gpus\", type=int, default=1)\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/graphbolt/sparse/graphsage.py",
    "content": "\"\"\"\nThis script demonstrate how to use dgl sparse library to sample on graph and \ntrain model. It trains and tests a GraphSAGE model using the sparse sample and \ncompact operators to sample submatrix from the whole matrix.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> Load and preprocess full dataset\n│\n├───> Instantiate SAGE model\n│\n├───> train\n│     │\n│     └───> Training loop\n│           │\n│           ├───> Sample submatrix\n│           │\n│           └───> SAGE.forward\n└───> test\n      │\n      ├───> Sample submatrix\n      │\n      └───> Evaluate the model\n\"\"\"\nimport argparse\nfrom functools import partial\n\nimport dgl.graphbolt as gb\n\nimport dgl.sparse as dglsp\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nfrom dgl.graphbolt.subgraph_sampler import SubgraphSampler\nfrom torch.utils.data import functional_datapipe\nfrom tqdm import tqdm\n\n\nclass SAGEConv(nn.Module):\n    r\"\"\"GraphSAGE layer from `Inductive Representation Learning on\n    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n    ):\n        super(SAGEConv, self).__init__()\n        self._in_src_feats, self._in_dst_feats = in_feats, in_feats\n        self._out_feats = out_feats\n\n        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)\n        self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=True)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)\n        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)\n\n    def forward(self, A, feat):\n        feat_src = feat\n        feat_dst = feat[: A.shape[1]]\n\n        # Aggregator type: mean.\n        srcdata = self.fc_neigh(feat_src)\n 
       # Divided by degree.\n        D_hat = dglsp.diag(A.sum(0)) ** -1\n        A_div = A @ D_hat\n        # Conv neighbors.\n        dstdata = A_div.T @ srcdata\n\n        rst = self.fc_self(feat_dst) + dstdata\n        return rst\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hid_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # Three-layer GraphSAGE-gcn.\n        self.layers.append(SAGEConv(in_size, hid_size))\n        self.layers.append(SAGEConv(hid_size, hid_size))\n        self.layers.append(SAGEConv(hid_size, out_size))\n        self.dropout = nn.Dropout(0.5)\n        self.hid_size = hid_size\n        self.out_size = out_size\n\n    def forward(self, sampled_matrices, x):\n        hidden_x = x\n        for layer_idx, (layer, sampled_matrix) in enumerate(\n            zip(self.layers, sampled_matrices)\n        ):\n            hidden_x = layer(sampled_matrix, hidden_x)\n            if layer_idx != len(self.layers) - 1:\n                hidden_x = F.relu(hidden_x)\n                hidden_x = self.dropout(hidden_x)\n        return hidden_x\n\n\n@functional_datapipe(\"sample_sparse_neighbor\")\nclass SparseNeighborSampler(SubgraphSampler):\n    def __init__(self, datapipe, matrix, fanouts):\n        super().__init__(datapipe)\n        self.matrix = matrix\n        # Convert fanouts to a list of tensors.\n        self.fanouts = []\n        for fanout in fanouts:\n            if not isinstance(fanout, torch.Tensor):\n                fanout = torch.LongTensor([int(fanout)])\n            self.fanouts.insert(0, fanout)\n\n    def sample_subgraphs(self, seeds, seeds_timestamp=None):\n        sampled_matrices = []\n        src = seeds.long()\n\n        #####################################################################\n        # (HIGHLIGHT) Using the sparse sample operator to preform random\n        # sampling on the neighboring nodes of the seeds nodes. 
The sparse\n        # compact operator is then employed to compact and relabel the sampled\n        # matrix, resulting in the sampled matrix and the relabel index.\n        #####################################################################\n        for fanout in self.fanouts:\n            # Sample neighbors.\n            sampled_matrix = self.matrix.sample(1, fanout, ids=src).coalesce()\n            # Compact the sampled matrix.\n            compacted_mat, row_ids = sampled_matrix.compact(0)\n            sampled_matrices.insert(0, compacted_mat)\n            src = row_ids\n\n        return src, sampled_matrices\n\n\n############################################################################\n# (HIGHLIGHT) Create a multi-process dataloader with dgl graphbolt package.\n############################################################################\ndef create_dataloader(A, fanouts, ids, features, device):\n    datapipe = gb.ItemSampler(ids, batch_size=1024)\n    # Customize graphbolt sampler by sparse.\n    datapipe = datapipe.sample_sparse_neighbor(A, fanouts)\n    # Use grapbolt to fetch features.\n    datapipe = datapipe.fetch_feature(features, node_feature_keys=[\"feat\"])\n    datapipe = datapipe.copy_to(device)\n    dataloader = gb.DataLoader(datapipe)\n    return dataloader\n\n\ndef evaluate(model, dataloader, num_classes):\n    model.eval()\n    ys = []\n    y_hats = []\n    for it, data in tqdm(enumerate(dataloader), \"Evaluating\"):\n        with torch.no_grad():\n            node_feature = data.node_features[\"feat\"].float()\n            blocks = data.sampled_subgraphs\n            y = data.labels\n            ys.append(y)\n            y_hats.append(model(blocks, node_feature))\n\n    return MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(ys),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\ndef validate(device, dataset, model, num_classes):\n    test_set = dataset.tasks[0].test_set\n    test_dataloader = 
create_dataloader(\n        A, [10, 10, 10], test_set, features, device\n    )\n    acc = evaluate(model, test_dataloader, num_classes)\n    return acc\n\n\ndef train(device, A, features, dataset, num_classes, model):\n    # Create sampler & dataloader.\n    train_set = dataset.tasks[0].train_set\n    train_dataloader = create_dataloader(\n        A, [10, 10, 10], train_set, features, device\n    )\n\n    valid_set = dataset.tasks[0].validation_set\n    val_dataloader = create_dataloader(\n        A, [10, 10, 10], valid_set, features, device\n    )\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)\n\n    for epoch in range(10):\n        model.train()\n        total_loss = 0\n        for it, data in tqdm(enumerate(train_dataloader), \"Training\"):\n            node_feature = data.node_features[\"feat\"].float()\n            blocks = data.sampled_subgraphs\n            y = data.labels\n            y_hat = model(blocks, node_feature)\n            loss = F.cross_entropy(y_hat, y)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n\n        acc = evaluate(model, val_dataloader, num_classes)\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} \".format(\n                epoch, total_loss / (it + 1), acc.item()\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GraphSAGE\")\n    parser.add_argument(\n        \"--mode\",\n        default=\"gpu\",\n        choices=[\"cpu\", \"gpu\"],\n        help=\"Training mode. 
'cpu' for CPU training, 'gpu' for GPU training.\",\n    )\n    args = parser.parse_args()\n    if not torch.cuda.is_available():\n        args.mode = \"cpu\"\n    print(f\"Training in {args.mode} mode.\")\n\n    #####################################################################\n    # (HIGHLIGHT) This example implements a GraphSAGE algorithm with sparse\n    # operators, which involves sampling a subgraph from a full graph and\n    # conducting training.\n    #\n    # First, the whole graph is loaded onto the CPU or GPU and transformed\n    # into a sparse matrix. To obtain the training subgraph, it samples three\n    # submatrices by seed nodes, which contain their randomly sampled\n    # 1-hop, 2-hop, and 3-hop neighbors. Then, the features of the\n    # subgraph are input to the network for training.\n    #####################################################################\n\n    # Load and preprocess dataset.\n    print(\"Loading data\")\n    device = torch.device(\"cpu\" if args.mode == \"cpu\" else \"cuda\")\n    dataset = gb.BuiltinDataset(\"ogbn-products\").load()\n    g = dataset.graph\n    features = dataset.feature\n\n    # Create GraphSAGE model.\n    in_size = features.size(\"node\", None, \"feat\")[0]\n    num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n    out_size = num_classes\n    model = SAGE(in_size, 256, out_size).to(device)\n\n    # Create the sparse matrix from the graph's CSC representation.\n    N = g.num_nodes\n    A = dglsp.from_csc(g.csc_indptr.long(), g.indices.long(), shape=(N, N))\n\n    # Model training.\n    print(\"Training...\")\n    train(device, A, features, dataset, num_classes, model)\n\n    # Test the model.\n    print(\"Testing...\")\n    acc = validate(device, dataset, model, num_classes)\n    print(f\"Test accuracy {acc:.4f}\")\n"
  },
  {
    "path": "examples/graphbolt/temporal_link_prediction.py",
    "content": "\"\"\"\nThis script trains and tests a Heterogeneous GraphSAGE model for link\nprediction with temporal information using graphbolt dataloader.\n\nWhile node classification predicts labels for nodes based on their\nlocal neighborhoods, link prediction assesses the likelihood of an edge\nexisting between two nodes, necessitating different sampling strategies\nthat account for pairs of nodes and their joint neighborhoods.\n\nAn additional temporal attribute is provided in both graph and TVT sets,\nensuring that during sampling, only neighbors whose timestamps are earlier\nthan the seed timestamp will be sampled.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> OnDiskDataset pre-processing\n│\n├───> Instantiate HeteroSAGE model\n│\n├───> train\n│     │\n│     ├───> Get graphbolt dataloader (HIGHLIGHT)\n│     │\n│     └───> Training loop\n│           │\n│           ├───> HeteroSAGE.forward\n│           │\n│           └───> Validation set evaluation\n│\n└───> Test set evaluation\n\"\"\"\nimport argparse\nimport os\nimport time\n\nimport dgl.graphbolt as gb\nimport dgl.nn as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\nfrom dgl.data.utils import download, extract_archive\n\n\nTIMESTAMP_FEATURE_NAME = \"__timestamp__\"\nNODE_FEATURE_KEYS = {\n    \"Product\": [\"categoryId\"],\n    \"Query\": [\"categoryId\"],\n}\n\nTARGET_TYPE = (\"Query\", \"Click\", \"Product\")\nALL_TYPES = [\n    TARGET_TYPE,\n    (\"Product\", \"reverse_Click\", \"Query\"),\n    (\"Product\", \"reverse_QueryResult\", \"Query\"),\n    (\"Query\", \"QueryResult\", \"Product\"),\n]\n\n\nclass CategoricalEncoder(nn.Module):\n    def __init__(\n        self,\n        num_categories,\n        out_size,\n    ):\n        super().__init__()\n        self.embed = nn.Embedding(num_categories, out_size)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        
nn.init.xavier_uniform_(self.embed.weight)\n\n    def forward(self, input_feat: torch.Tensor):\n        return self.embed(input_feat.view(-1))\n\n\nclass HeteroSAGE(nn.Module):\n    def __init__(self, in_size, hidden_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        sizes = [in_size, hidden_size]\n        for size in sizes:\n            self.layers.append(\n                dglnn.HeteroGraphConv(\n                    {\n                        etype: dglnn.SAGEConv(\n                            size,\n                            hidden_size,\n                            \"mean\",\n                        )\n                        for etype in ALL_TYPES\n                    },\n                    aggregate=\"sum\",\n                )\n            )\n        self.predictor = nn.Sequential(\n            nn.Linear(hidden_size, hidden_size),\n            nn.ReLU(),\n            nn.Linear(hidden_size, hidden_size),\n            nn.ReLU(),\n            nn.Linear(hidden_size, 1),\n        )\n\n    def forward(self, blocks, X_node_dict):\n        H_node_dict = X_node_dict\n        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n            H_node_dict = layer(block, H_node_dict)\n            is_last_layer = layer_idx == len(self.layers) - 1\n            if not is_last_layer:\n                H_node_dict = {\n                    ntype: F.relu(H) for ntype, H in H_node_dict.items()\n                }\n        return H_node_dict\n\n\ndef create_dataloader(args, graph, features, itemset, is_train=True):\n    datapipe = gb.ItemSampler(\n        itemset,\n        batch_size=args.train_batch_size if is_train else args.eval_batch_size,\n        shuffle=is_train,\n    )\n\n    if args.storage_device != \"cpu\":\n        datapipe = datapipe.copy_to(device=args.device)\n\n    ############################################################################\n    # [Input]:\n    # 'datapipe' is either 'ItemSampler' or 
'UniformNegativeSampler' depending\n    # on whether training is needed ('is_train'),\n    # 'graph': The network topology for sampling.\n    # 'args.fanout': Number of neighbors to sample per node.\n    # [Output]:\n    # A NeighborSampler object to sample neighbors.\n    # [Role]:\n    # Initialize a neighbor sampler for sampling the neighborhoods of nodes with\n    # considering of temporal information. Only neighbors that is earlier than\n    # the seed will be sampled.\n    ############################################################################\n    datapipe = getattr(datapipe, args.sample_mode)(\n        graph,\n        args.fanout if is_train else [-1],\n        node_timestamp_attr_name=TIMESTAMP_FEATURE_NAME,\n        edge_timestamp_attr_name=TIMESTAMP_FEATURE_NAME,\n    )\n\n    datapipe = datapipe.fetch_feature(\n        features, node_feature_keys=NODE_FEATURE_KEYS\n    )\n\n    if args.storage_device == \"cpu\":\n        datapipe = datapipe.copy_to(device=args.device)\n\n    dataloader = gb.DataLoader(\n        datapipe,\n        num_workers=args.num_workers,\n    )\n\n    # Return the fully-initialized DataLoader object.\n    return dataloader\n\n\ndef train(args, model, graph, features, train_set, encoders):\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n    dataloader = create_dataloader(args, graph, features, train_set)\n\n    for epoch in range(args.epochs):\n        model.train()\n        total_loss = 0\n        start_epoch_time = time.time()\n        for step, data in tqdm.tqdm(enumerate(dataloader)):\n            # Get node pairs with labels for loss calculation.\n            compacted_seeds = data.compacted_seeds[\n                gb.etype_tuple_to_str(TARGET_TYPE)\n            ].T\n            labels = data.labels\n\n            node_feature = {}\n            for ntype, keys in NODE_FEATURE_KEYS.items():\n                ntype, feat = ntype, keys[0]\n                node_feature[ntype] = data.node_features[\n           
         (ntype, feat)\n                ].squeeze()\n\n            blocks = data.blocks\n\n            # Get the embeddings of the input nodes.\n            X_node_dict = {\n                ntype: encoders[ntype](feat)\n                for ntype, feat in node_feature.items()\n            }\n            X_node_dict = model(blocks, X_node_dict)\n            src_type, _, dst_type = TARGET_TYPE\n            logits = model.predictor(\n                X_node_dict[src_type][compacted_seeds[0]]\n                * X_node_dict[dst_type][compacted_seeds[1]]\n            ).squeeze()\n\n            # Compute loss.\n            loss = F.binary_cross_entropy_with_logits(\n                logits, labels[gb.etype_tuple_to_str(TARGET_TYPE)].float()\n            )\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            total_loss += loss.item()\n            if step + 1 == args.early_stop:\n                # Early stopping requires a new dataloader to reset its state.\n                dataloader = create_dataloader(args, graph, features, train_set)\n                break\n\n        end_epoch_time = time.time()\n        print(\n            f\"Epoch {epoch:05d} | \"\n            f\"Loss {(total_loss) / (step + 1):.4f} | \"\n            f\"Time {(end_epoch_time - start_epoch_time):.4f} s\"\n        )\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"diginetica-r2ne (GraphBolt)\")\n    parser.add_argument(\"--epochs\", type=int, default=10)\n    parser.add_argument(\"--lr\", type=float, default=0.0005)\n    parser.add_argument(\"--neg-ratio\", type=int, default=1)\n    parser.add_argument(\"--train-batch-size\", type=int, default=1024)\n    parser.add_argument(\"--eval-batch-size\", type=int, default=1024)\n    parser.add_argument(\"--num-workers\", type=int, default=0)\n    parser.add_argument(\n        \"--dataset\",\n        default=\"diginetica-r2ne\",\n        choices=[\"diginetica-r2ne\"],\n        
help=\"Dataset.\",\n    )\n    parser.add_argument(\n        \"--early-stop\",\n        type=int,\n        default=0,\n        help=\"0 means no early stop, otherwise stop at the input-th step\",\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"20,20\",\n        help=\"Fan-out of neighbor sampling. Default: 20, 20\",\n    )\n    parser.add_argument(\n        \"--exclude-edges\",\n        type=int,\n        default=1,\n        help=\"Whether to exclude reverse edges during sampling. Default: 1\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"cpu-cuda\",\n        choices=[\"cpu-cpu\", \"cpu-cuda\", \"cuda-cuda\"],\n        help=\"Dataset storage placement and Train device: 'cpu' for CPU and RAM,\"\n        \" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.\",\n    )\n    parser.add_argument(\n        \"--sample-mode\",\n        default=\"temporal_sample_neighbor\",\n        choices=[\"temporal_sample_neighbor\", \"temporal_sample_layer_neighbor\"],\n        help=\"The sampling function when doing layerwise sampling.\",\n    )\n    return parser.parse_args()\n\n\ndef download_datasets(name, root=\"datasets\"):\n    url = \"https://dgl-data.s3-accelerate.amazonaws.com/dataset/\"\n    dataset_dir = os.path.join(root, name)\n    if not os.path.exists(dataset_dir):\n        url += name + \".zip\"\n        os.makedirs(root, exist_ok=True)\n        zip_file_path = os.path.join(root, name + \".zip\")\n        download(url, path=zip_file_path)\n        extract_archive(zip_file_path, root, overwrite=True)\n        os.remove(zip_file_path)\n    return dataset_dir\n\n\ndef main(args):\n    if not torch.cuda.is_available():\n        args.mode = \"cpu-cpu\"\n    print(f\"Training in {args.mode} mode.\")\n    args.storage_device, args.device = args.mode.split(\"-\")\n    args.device = torch.device(args.device)\n\n    # Load and preprocess dataset.\n    print(\"Loading data\")\n    # TODO: Add 
the datasets to built-in.\n    dataset_path = download_datasets(args.dataset)\n    dataset = gb.OnDiskDataset(dataset_path).load()\n\n    # Move the dataset to the selected storage.\n    graph = dataset.graph.to(args.storage_device)\n    features = dataset.feature.to(args.storage_device)\n\n    train_set = dataset.tasks[0].train_set\n    # TODO: Fix the dataset so that this modification is not needed. node_pairs\n    # needs to be cast into graph.indices.dtype, which is int32.\n    train_set._itemsets[\"Query:Click:Product\"]._items = tuple(\n        item.to(graph.indices.dtype if i == 0 else None)\n        for i, item in enumerate(\n            train_set._itemsets[\"Query:Click:Product\"]._items\n        )\n    )\n\n    args.fanout = list(map(int, args.fanout.split(\",\")))\n\n    in_size = 128\n    hidden_channels = 256\n    query_size = features.metadata(\"node\", \"Query\", \"categoryId\")[\n        \"num_categories\"\n    ]\n    product_size = features.metadata(\"node\", \"Product\", \"categoryId\")[\n        \"num_categories\"\n    ]\n    args.device = torch.device(args.device)\n    model = HeteroSAGE(in_size, hidden_channels).to(args.device)\n    encoders = {\n        \"Query\": CategoricalEncoder(query_size, in_size).to(args.device),\n        \"Product\": CategoricalEncoder(product_size, in_size).to(args.device),\n    }\n\n    # Model training.\n    print(\"Training...\")\n    train(args, model, graph, features, train_set, encoders)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/legacy/README.md",
    "content": "# New sampling examples via `dgl.graphbolt`\nConsider taking a look at our new sampling examples in the\n`../graphbolt` folder using `dgl.graphbolt`.\n\n# Sampling Examples Running\n\n## Requirements\n\n```bash\npip install torchmetrics==0.11.4\n```\n\n## How to run\n\n### Node classification\n\nRun with the following command (available modes: \"cpu\", \"mixed\" (default), \"gpu\"):\n\n```bash\npython3 node_classification.py --mode mixed\n```\n"
  },
  {
    "path": "examples/legacy/link_prediction.py",
    "content": "\"\"\"\nThis script trains and tests a GraphSAGE model for link prediction on\nlarge graphs using efficient and tailor-made neighbor sampling.\n\nPaper: [Inductive Representation Learning on Large Graphs]\n(https://arxiv.org/abs/1706.02216)\n\nWhile node classification predicts labels for nodes based on their\nlocal neighborhoods, link prediction assesses the likelihood of an edge\nexisting between two nodes, necessitating different sampling strategies\nthat account for pairs of nodes and their joint neighborhoods.\n\nBefore reading this example, please familiar yourself with graphsage node\nclassification by reading the example in the\n`examples/core/graphsage/node_classification.py`\n\nIf you want to train graphsage on a large graph in a distributed fashion, read\nthe example in the `examples/distributed/graphsage/`.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> Load and preprocess dataset\n│\n├───> Instantiate SAGE model\n│\n├───> train\n│     │\n│     ├───> NeighborSampler (HIGHLIGHT)\n│     │\n│     └───> Training loop\n│           │\n│           └───> SAGE.forward\n│\n└───> evaluate\n      │\n      └───> SAGE.inference\n            │\n            └───> MultiLayerFullNeighborSampler (HIGHLIGHT)\n\"\"\"\n\nimport argparse\nimport time\n\nimport dgl\nimport dgl.nn as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\nfrom dgl.dataloading import (\n    as_edge_prediction_sampler,\n    DataLoader,\n    MultiLayerFullNeighborSampler,\n    negative_sampler,\n    NeighborSampler,\n)\nfrom ogb.linkproppred import DglLinkPropPredDataset, Evaluator\n\n\ndef to_bidirected_with_reverse_mapping(g):\n    \"\"\"Convert the graph to bidirectional and return the reverse mapping.\n\n    This function transforms the input graph into its bidirectional form. It\n    then returns the newly formed bidirectional graph and the mapping that\n    represents the reverse edges. 
The function does not work with graphs that\n    have self-loops.\n\n    Parameters:\n    ----------\n    g : DGLGraph\n        Input graph.\n\n    Returns:\n    -------\n    DGLGraph :\n        Bidirectional graph.\n    Tensor :\n        Mapping to reverse edges.\n    \"\"\"\n    # First, add reverse edges to the graph, effectively making it\n    # bidirectional. Then, simplify the resulting graph by merging any duplicate\n    # edges. The resulting simplified graph is stored in `g_simple`, and\n    # `mapping` provides information on how edges in `g_simple` correspond to\n    # edges in the original graph.\n    g_simple, mapping = dgl.to_simple(\n        dgl.add_reverse_edges(g), return_counts=\"count\", writeback_mapping=True\n    )\n\n    # The `return_counts` option in `dgl.to_simple` returns the count of how\n    # many times each edge in the simplified graph corresponds to an edge in the\n    # original graph. This count is saved in the edge data of the returned\n    # graph with the key \"count\".\n    c = g_simple.edata[\"count\"]\n    num_edges = g.num_edges()\n\n    # `mapping_offset` is an auxiliary tensor used to understand how edges in\n    # the simplified bidirectional graph (g_simple) relate to the edges in the\n    # original graph.\n    mapping_offset = torch.zeros(\n        g_simple.num_edges() + 1, dtype=g_simple.idtype\n    )\n\n    # Calculate the cumulative sum of counts to determine boundaries for each\n    # unique edge.\n    mapping_offset[1:] = c.cumsum(0)\n\n    # Sort the mapping tensor to group the same edge indices.\n    idx = mapping.argsort()\n\n    # Using the previously computed `mapping_offset`, it extracts the first\n    # index of each group, which represents the unique edge indices from the\n    # sorted mapping.\n    idx_uniq = idx[mapping_offset[:-1]]\n\n    # If an edge index is greater than or equal to the number of edges in the\n    # original graph, it indicates that this edge is a reversed edge, and the\n    # original 
edge index for it is (idx_uniq - num_edges). Otherwise, its\n    # reverse edge index is (idx_uniq + num_edges).\n    reverse_idx = torch.where(\n        idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges\n    )\n    reverse_mapping = mapping[reverse_idx]\n\n    # Sanity check to ensure valid mapping.\n    src1, dst1 = g_simple.edges()\n    src2, dst2 = g_simple.find_edges(reverse_mapping)\n    assert torch.equal(src1, dst2)\n    assert torch.equal(src2, dst1)\n    return g_simple, reverse_mapping\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hidden_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # Three-layer GraphSAGE-mean.\n        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \"mean\"))\n        self.hidden_size = hidden_size\n        self.predictor = nn.Sequential(\n            nn.Linear(hidden_size, hidden_size),\n            nn.ReLU(),\n            nn.Linear(hidden_size, hidden_size),\n            nn.ReLU(),\n            nn.Linear(hidden_size, 1),\n        )\n\n    def forward(self, pair_graph, neg_pair_graph, blocks, x):\n        hidden_x = x\n        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n            hidden_x = layer(block, hidden_x)\n            is_last_layer = layer_idx == len(self.layers) - 1\n            if not is_last_layer:\n                hidden_x = F.relu(hidden_x)\n        pos_src, pos_dst = pair_graph.edges()\n        neg_src, neg_dst = neg_pair_graph.edges()\n        hidden_pos = self.predictor(hidden_x[pos_src] * hidden_x[pos_dst])\n        hidden_neg = self.predictor(hidden_x[neg_src] * hidden_x[neg_dst])\n        return hidden_pos, hidden_neg\n\n    def inference(self, g, device, batch_size):\n        \"\"\"Layer-wise inference algorithm to compute GNN node embeddings.\"\"\"\n    
    feat = g.ndata[\"feat\"]\n        #####################################################################\n        # (HIGHLIGHT) Creating a MultiLayerFullNeighborSampler instance.\n        # This sampler is used in the Graph Neural Networks (GNN) training\n        # process to provide neighbor sampling, which is crucial for\n        # efficient training of GNN on large graphs.\n        #\n        # The first argument '1' indicates the number of layers for\n        # the neighbor sampling. In this case, it's set to 1, meaning\n        # only the direct neighbors of each node will be included in the\n        # sampling.\n        #\n        # The 'prefetch_node_feats' parameter specifies the node features\n        # that need to be pre-fetched during sampling. In this case, the\n        # feature named 'feat' will be pre-fetched.\n        #\n        # `prefetch` in DGL initiates data fetching operations in parallel\n        # with model computations. This ensures data is ready when the\n        # computation needs it, thereby eliminating waiting times between\n        # fetching and computing steps and reducing the I/O overhead during\n        # the training process.\n        #\n        # The difference between whether to use prefetch or not is shown:\n        #\n        # Without Prefetch:\n        # Fetch1 ──> Compute1 ──> Fetch2 ──> Compute2 ──> Fetch3 ──> Compute3\n        #\n        # With Prefetch:\n        # Fetch1 ──> Fetch2 ──> Fetch3\n        #    │          │          │\n        #    └─Compute1 └─Compute2 └─Compute3\n        #####################################################################\n        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=[\"feat\"])\n        dataloader = DataLoader(\n            g,\n            torch.arange(g.num_nodes()).to(g.device),\n            sampler,\n            device=device,\n            batch_size=batch_size,\n            shuffle=False,\n            drop_last=False,\n            num_workers=0,\n       
 )\n        buffer_device = torch.device(\"cpu\")\n        # Enable pin_memory for faster CPU to GPU data transfer if the model is\n        # running on a GPU.\n        pin_memory = buffer_device != device\n        for layer_idx, layer in enumerate(self.layers):\n            is_last_layer = layer_idx == len(self.layers) - 1\n            y = torch.empty(\n                g.num_nodes(),\n                self.hidden_size,\n                device=buffer_device,\n                pin_memory=pin_memory,\n            )\n            feat = feat.to(device)\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(\n                dataloader, desc=\"Inference\"\n            ):\n                x = feat[input_nodes]\n                hidden_x = layer(blocks[0], x)\n                if not is_last_layer:\n                    hidden_x = F.relu(hidden_x)\n                y[output_nodes] = hidden_x.to(buffer_device)\n            feat = y\n        return y\n\n\n@torch.no_grad()\ndef compute_mrr(\n    model, evaluator, node_emb, src, dst, neg_dst, device, batch_size=500\n):\n    \"\"\"Compute the Mean Reciprocal Rank (MRR) for given source and destination\n    nodes.\n\n    This function computes the MRR for a set of node pairs, dividing the task\n    into batches to handle potentially large graphs.\n    \"\"\"\n    rr = torch.zeros(src.shape[0])\n    # Loop over node pairs in batches.\n    for start in tqdm.trange(0, src.shape[0], batch_size, desc=\"Evaluate\"):\n        end = min(start + batch_size, src.shape[0])\n\n        # Concatenate positive and negative destination nodes.\n        all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)\n\n        # Fetch embeddings for current batch of source and destination nodes.\n        h_src = node_emb[src[start:end]][:, None, :].to(device)\n        h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)\n\n        # Compute prediction scores using the model.\n        pred = model.predictor(h_src * 
h_dst).squeeze(-1)\n\n        # Evaluate the predictions to obtain MRR values.\n        input_dict = {\"y_pred_pos\": pred[:, 0], \"y_pred_neg\": pred[:, 1:]}\n        rr[start:end] = evaluator.eval(input_dict)[\"mrr_list\"]\n    return rr.mean()\n\n\n@torch.no_grad()\ndef evaluate(device, graph, edge_split, model, batch_size):\n    \"\"\"Evaluate the model on validation and test sets.\"\"\"\n    model.eval()\n    evaluator = Evaluator(name=\"ogbl-citation2\")\n\n    # Compute node embeddings for the entire graph.\n    node_emb = model.inference(graph, device, batch_size)\n    results = []\n\n    # Loop over both validation and test sets.\n    for split in [\"valid\", \"test\"]:\n        src = edge_split[split][\"source_node\"].to(node_emb.device)\n        dst = edge_split[split][\"target_node\"].to(node_emb.device)\n        neg_dst = edge_split[split][\"target_node_neg\"].to(node_emb.device)\n\n        # Compute MRR values for the current split.\n        results.append(\n            compute_mrr(model, evaluator, node_emb, src, dst, neg_dst, device)\n        )\n    return results\n\n\ndef train(\n    args, device, g, reverse_eids, seed_edges, model, use_uva, fused_sampling\n):\n    #####################################################################\n    # (HIGHLIGHT) Instantiate a NeighborSampler object for efficient\n    # training of Graph Neural Networks (GNNs) on large-scale graphs.\n    #\n    # The argument [15, 10, 5] sets the number of neighbors (fanout)\n    # to be sampled at each layer. Here, we have three layers, and\n    # 15/10/5 neighbors will be randomly selected for each node at each\n    # layer.\n    #\n    # The 'prefetch_node_feats' parameter specify the node features that\n    # needs to be pre-fetched during sampling. 
More details about\n    # `prefetch` can be found in the `SAGE.inference` function.\n    #\n    # (HIGHLIGHT) Modify the NeighborSampler for Edge Prediction\n    #\n    # This `as_edge_prediction_sampler` augments the original NeighborSampler\n    # to specifically handle edge prediction tasks, where not only the\n    # structure but also the relationships between nodes (edges) are of\n    # importance.\n    #\n    # - `exclude=\"reverse_id\"` ensures that the edges corresponding to the\n    #   reverse of the original edges are excluded during sampling, given that\n    #   reverse edges can introduce unnecessary redundancy in edge prediction.\n    #\n    # - `reverse_eids=reverse_eids` specifies the IDs of the reverse edges.\n    #   This information is vital so the sampler knows which edges to avoid.\n    #\n    # - The negative sampling strategy is specified using the\n    #   `negative_sampler`. Here, a uniform negative sampling method is\n    #   employed, where a negative sample (an edge that doesn't exist in the\n    #   original graph) is uniformly drawn from the set of all possible edges.\n    #\n    # The modified sampler is tailor-made for scenarios where the goal is\n    # not just to learn node representations, but also to predict the\n    # likelihood of an edge existing between two nodes (link prediction).\n    #####################################################################\n    sampler = NeighborSampler(\n        [15, 10, 5],\n        prefetch_node_feats=[\"feat\"],\n        fused=fused_sampling,\n    )\n    sampler = as_edge_prediction_sampler(\n        sampler,\n        exclude=\"reverse_id\" if args.exclude_edges else None,\n        reverse_eids=reverse_eids if args.exclude_edges else None,\n        negative_sampler=negative_sampler.Uniform(1),\n    )\n\n    dataloader = DataLoader(\n        g,\n        seed_edges,\n        sampler,\n        device=device,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        
drop_last=False,\n        # If `g` is on gpu or `use_uva` is True, `num_workers` must be zero,\n        # otherwise it will cause error.\n        num_workers=0,\n        use_uva=use_uva,\n    )\n    opt = torch.optim.Adam(model.parameters(), lr=args.lr)\n    for epoch in range(args.epochs):\n        model.train()\n        total_loss = 0\n        start_epoch_time = time.time()\n        # A block is a graph consisting of two sets of nodes: the\n        # source nodes and destination nodes. The source and destination\n        # nodes can have multiple node types. All the edges connect from\n        # source nodes to destination nodes.\n        # For more details: https://discuss.dgl.ai/t/what-is-the-block/2932.\n        for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(\n            dataloader\n        ):\n            # The input features from the source nodes in the first layer's\n            # computation graph.\n            x = blocks[0].srcdata[\"feat\"]\n            pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)\n            score = torch.cat([pos_score, neg_score])\n\n            # Create true labels for positive and negative samples.\n            pos_label = torch.ones_like(pos_score)\n            neg_label = torch.zeros_like(neg_score)\n            labels = torch.cat([pos_label, neg_label])\n\n            # Compute the binary cross-entropy loss.\n            loss = F.binary_cross_entropy_with_logits(score, labels)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n            total_loss += loss.item()\n            if (it + 1) == args.early_stop:\n                break\n        end_epoch_time = time.time()\n        print(\n            f\"Epoch {epoch:05d} | \"\n            f\"Loss {total_loss / (it + 1):.4f} | \"\n            f\"Time {(end_epoch_time - start_epoch_time):.4f} s\"\n        )\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    
parser.add_argument(\"--epochs\", type=int, default=10)\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=0.0005,\n        help=\"Learning rate. Default: 0.0005\",\n    )\n    parser.add_argument(\n        \"--train-batch-size\",\n        type=int,\n        default=512,\n        help=\"Batch size for training. Default: 512\",\n    )\n    parser.add_argument(\n        \"--eval-batch-size\",\n        type=int,\n        default=1024,\n        help=\"Batch size during evaluation. Default: 1024\",\n    )\n    parser.add_argument(\n        \"--early-stop\",\n        type=int,\n        default=0,\n        help=\"0 means no early stop, otherwise stop at the input-th step\",\n    )\n    parser.add_argument(\n        \"--exclude-edges\",\n        type=int,\n        default=1,\n        help=\"Whether to exclude reverse edges during sampling. Default: 1\",\n    )\n    parser.add_argument(\n        \"--compare-graphbolt\",\n        action=\"store_true\",\n        help=\"Compare with GraphBolt\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"mixed\",\n        choices=[\"cpu\", \"mixed\", \"puregpu\"],\n        help=\"Training mode. 
'cpu' for CPU training, 'mixed' for CPU-GPU mixed \"\n        \"training, 'puregpu' for pure-GPU training.\",\n    )\n    return parser.parse_args()\n\n\ndef main(args):\n    if not torch.cuda.is_available():\n        args.mode = \"cpu\"\n    print(f\"Training in {args.mode} mode.\")\n\n    # Load and preprocess dataset.\n    print(\"Loading data\")\n    dataset = DglLinkPropPredDataset(\"ogbl-citation2\")\n    g = dataset[0]\n    if args.compare_graphbolt:\n        fused_sampling = False\n    else:\n        fused_sampling = True\n        g = g.to(\"cuda\" if args.mode == \"puregpu\" else \"cpu\")\n\n    # Whether use Unified Virtual Addressing (UVA) for CUDA computation.\n    use_uva = args.mode == \"mixed\"\n    device = torch.device(\"cpu\" if args.mode == \"cpu\" else \"cuda\")\n\n    # Convert the graph to its bidirectional form.\n    g, reverse_eids = to_bidirected_with_reverse_mapping(g)\n    reverse_eids = reverse_eids.to(g.device)\n    seed_edges = torch.arange(g.num_edges()).to(g.device)\n    edge_split = dataset.get_edge_split()\n\n    # Create GraphSAGE model.\n    in_size = g.ndata[\"feat\"].shape[1]\n    model = SAGE(in_size, 256).to(device)\n\n    # Model training.\n    print(\"Training...\")\n    train(\n        args,\n        device,\n        g,\n        reverse_eids,\n        seed_edges,\n        model,\n        use_uva,\n        fused_sampling,\n    )\n\n    # Validate/Test the model.\n    print(\"Validation/Testing...\")\n    valid_mrr, test_mrr = evaluate(\n        device, g, edge_split, model, batch_size=args.eval_batch_size\n    )\n    print(\n        f\"Validation MRR {valid_mrr.item():.4f}, Test MRR {test_mrr.item():.4f}\"\n    )\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/legacy/node_classification.py",
    "content": "\"\"\"\nThis script trains and tests a GraphSAGE model for node classification on\nlarge graphs using efficient neighbor sampling.\n\nPaper: [Inductive Representation Learning on Large Graphs]\n(https://arxiv.org/abs/1706.02216)\n\nBefore reading this example, please familiarize yourself with graphsage node\nclassification by reading the example in the\n`examples/core/graphsage/node_classification.py`\n\nIf you want to train graphsage on a large graph in a distributed fashion, read\nthe example in the `examples/distributed/graphsage/`.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> Load and preprocess dataset\n│\n├───> Instantiate SAGE model\n│\n├───> train\n│     │\n│     ├───> NeighborSampler (HIGHLIGHT)\n│     │\n│     └───> Training loop\n│           │\n│           └───> SAGE.forward\n│\n└───> layerwise_infer\n      │\n      └───> SAGE.inference\n            │\n            └───> MultiLayerFullNeighborSampler (HIGHLIGHT)\n\"\"\"\n\nimport argparse\nimport time\n\nimport dgl\nimport dgl.nn as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nimport tqdm\nfrom dgl.data import AsNodePredDataset\nfrom dgl.dataloading import (\n    DataLoader,\n    MultiLayerFullNeighborSampler,\n    NeighborSampler,\n)\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hidden_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # Three-layer GraphSAGE-mean.\n        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, \"mean\"))\n        self.dropout = nn.Dropout(0.5)\n        self.hidden_size = hidden_size\n        self.out_size = out_size\n\n    def forward(self, blocks, x):\n        hidden_x = 
x\n        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n            hidden_x = layer(block, hidden_x)\n            is_last_layer = layer_idx == len(self.layers) - 1\n            if not is_last_layer:\n                hidden_x = F.relu(hidden_x)\n                hidden_x = self.dropout(hidden_x)\n        return hidden_x\n\n    def inference(self, g, device, batch_size, fused_sampling: bool = True):\n        \"\"\"Conduct layer-wise inference to get all the node embeddings.\"\"\"\n        feat = g.ndata[\"feat\"]\n        #####################################################################\n        # (HIGHLIGHT) Creating a MultiLayerFullNeighborSampler instance.\n        # This sampler is used in the Graph Neural Networks (GNN) training\n        # process to provide neighbor sampling, which is crucial for\n        # efficient training of GNN on large graphs.\n        #\n        # The first argument '1' indicates the number of layers for\n        # the neighbor sampling. In this case, it's set to 1, meaning\n        # only the direct neighbors of each node will be included in the\n        # sampling.\n        #\n        # The 'prefetch_node_feats' parameter specifies the node features\n        # that need to be pre-fetched during sampling. In this case, the\n        # feature named 'feat' will be pre-fetched.\n        #\n        # `prefetch` in DGL initiates data fetching operations in parallel\n        # with model computations. 
This ensures data is ready when the\n        # computation needs it, thereby eliminating waiting times between\n        # fetching and computing steps and reducing the I/O overhead during\n        # the training process.\n        #\n        # The difference between whether to use prefetch or not is shown:\n        #\n        # Without Prefetch:\n        # Fetch1 ──> Compute1 ──> Fetch2 ──> Compute2 ──> Fetch3 ──> Compute3\n        #\n        # With Prefetch:\n        # Fetch1 ──> Fetch2 ──> Fetch3\n        #    │          │          │\n        #    └─Compute1 └─Compute2 └─Compute3\n        #####################################################################\n        sampler = MultiLayerFullNeighborSampler(\n            1, prefetch_node_feats=[\"feat\"], fused=fused_sampling\n        )\n\n        dataloader = DataLoader(\n            g,\n            torch.arange(g.num_nodes()).to(g.device),\n            sampler,\n            device=device,\n            batch_size=batch_size,\n            shuffle=False,\n            drop_last=False,\n            num_workers=0,\n        )\n        buffer_device = torch.device(\"cpu\")\n        # Enable pin_memory for faster CPU to GPU data transfer if the\n        # model is running on a GPU.\n        pin_memory = buffer_device != device\n\n        for layer_idx, layer in enumerate(self.layers):\n            is_last_layer = layer_idx == len(self.layers) - 1\n            y = torch.empty(\n                g.num_nodes(),\n                self.out_size if is_last_layer else self.hidden_size,\n                device=buffer_device,\n                pin_memory=pin_memory,\n            )\n            feat = feat.to(device)\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                x = feat[input_nodes]\n                hidden_x = layer(blocks[0], x)  # len(blocks) = 1\n                if layer_idx != len(self.layers) - 1:\n                    hidden_x = F.relu(hidden_x)\n                    hidden_x = 
self.dropout(hidden_x)\n                # By design, our output nodes are contiguous.\n                y[output_nodes[0] : output_nodes[-1] + 1] = hidden_x.to(\n                    buffer_device\n                )\n            feat = y\n        return y\n\n\n@torch.no_grad()\ndef evaluate(model, graph, dataloader, num_classes):\n    model.eval()\n    ys = []\n    y_hats = []\n    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):\n        x = blocks[0].srcdata[\"feat\"]\n        ys.append(blocks[-1].dstdata[\"label\"])\n        y_hats.append(model(blocks, x))\n    return MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(ys),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\n@torch.no_grad()\ndef layerwise_infer(\n    device, graph, nid, model, num_classes, batch_size, fused_sampling\n):\n    model.eval()\n    pred = model.inference(\n        graph, device, batch_size, fused_sampling\n    )  # pred in buffer_device.\n    pred = pred[nid]\n    label = graph.ndata[\"label\"][nid].to(pred.device)\n    return MF.accuracy(pred, label, task=\"multiclass\", num_classes=num_classes)\n\n\ndef train(device, g, dataset, model, num_classes, use_uva, fused_sampling):\n    # Create sampler & dataloader.\n    train_idx = dataset.train_idx.to(g.device if not use_uva else device)\n    val_idx = dataset.val_idx.to(g.device if not use_uva else device)\n    #####################################################################\n    # (HIGHLIGHT) Instantiate a NeighborSampler object for efficient\n    # training of Graph Neural Networks (GNNs) on large-scale graphs.\n    #\n    # The argument [10, 10, 10] sets the number of neighbors (fanout)\n    # to be sampled at each layer. 
Here, we have three layers, and\n    # 10 neighbors will be randomly selected for each node at each\n    # layer.\n    #\n    # The 'prefetch_node_feats' and 'prefetch_labels' parameters\n    # specify the node features and labels that need to be pre-fetched\n    # during sampling. More details about `prefetch` can be found in the\n    # `SAGE.inference` function.\n    #####################################################################\n    sampler = NeighborSampler(\n        [10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]\n        prefetch_node_feats=[\"feat\"],\n        prefetch_labels=[\"label\"],\n        fused=fused_sampling,\n    )\n\n    train_dataloader = DataLoader(\n        g,\n        train_idx,\n        sampler,\n        device=device,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        # If `g` is on gpu or `use_uva` is True, `num_workers` must be zero,\n        # otherwise it will cause error.\n        num_workers=0,\n        use_uva=use_uva,\n    )\n\n    val_dataloader = DataLoader(\n        g,\n        val_idx,\n        sampler,\n        device=device,\n        batch_size=1024,\n        # No need to shuffle for validation.\n        shuffle=False,\n        drop_last=False,\n        num_workers=0,\n        use_uva=use_uva,\n    )\n\n    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)\n\n    for epoch in range(10):\n        t0 = time.time()\n        model.train()\n        total_loss = 0\n        # A block is a graph consisting of two sets of nodes: the\n        # source nodes and destination nodes. The source and destination\n        # nodes can have multiple node types. 
All the edges connect from\n        # source nodes to destination nodes.\n        # For more details: https://discuss.dgl.ai/t/what-is-the-block/2932.\n        for it, (input_nodes, output_nodes, blocks) in enumerate(\n            train_dataloader\n        ):\n            # The input features from the source nodes in the first layer's\n            # computation graph.\n            x = blocks[0].srcdata[\"feat\"]\n\n            # The ground truth labels from the destination nodes\n            # in the last layer's computation graph.\n            y = blocks[-1].dstdata[\"label\"]\n\n            y_hat = model(blocks, x)\n            loss = F.cross_entropy(y_hat, y)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n            total_loss += loss.item()\n        t1 = time.time()\n        acc = evaluate(model, g, val_dataloader, num_classes)\n        print(\n            f\"Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f} | \"\n            f\"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--mode\",\n        default=\"mixed\",\n        choices=[\"cpu\", \"mixed\", \"gpu\"],\n        help=\"Training mode. 
'cpu' for CPU training, 'mixed' for \"\n        \"CPU-GPU mixed training, 'gpu' for pure-GPU training.\",\n    )\n    parser.add_argument(\n        \"--compare-to-graphbolt\",\n        default=\"false\",\n        choices=[\"false\", \"true\"],\n        help=\"Whether comparing to GraphBolt or not, 'false' by default.\",\n    )\n    args = parser.parse_args()\n    if not torch.cuda.is_available():\n        args.mode = \"cpu\"\n    print(f\"Training in {args.mode} mode.\")\n\n    # Load and preprocess dataset.\n    print(\"Loading data\")\n    dataset = AsNodePredDataset(DglNodePropPredDataset(\"ogbn-products\"))\n\n    g = dataset[0]\n    if args.compare_to_graphbolt == \"false\":\n        g = g.to(\"cuda\" if args.mode == \"gpu\" else \"cpu\")\n    num_classes = dataset.num_classes\n    # Whether use Unified Virtual Addressing (UVA) for CUDA computation.\n    use_uva = args.mode == \"mixed\"\n    device = torch.device(\"cpu\" if args.mode == \"cpu\" else \"cuda\")\n    fused_sampling = args.compare_to_graphbolt == \"false\"\n\n    # Create GraphSAGE model.\n    in_size = g.ndata[\"feat\"].shape[1]\n    out_size = dataset.num_classes\n    model = SAGE(in_size, 256, out_size).to(device)\n\n    # Model training.\n    print(\"Training...\")\n    train(device, g, dataset, model, num_classes, use_uva, fused_sampling)\n\n    # Test the model.\n    print(\"Testing...\")\n    acc = layerwise_infer(\n        device,\n        g,\n        dataset.test_idx,\n        model,\n        num_classes,\n        batch_size=4096,\n        fused_sampling=fused_sampling,\n    )\n    print(f\"Test accuracy {acc.item():.4f}\")\n"
  },
  {
    "path": "examples/multigpu/README.md",
    "content": "# Multiple GPU Training\n\n## Requirements\n\n```bash\npip install torchmetrics==0.11.4\n```\n\n## How to run\n\n### Node classification\n\nRun with the following command (available datasets: \"ogbn-products\", \"ogbn-arxiv\"):\n\n```bash\npython3 node_classification_sage.py --dataset_name ogbn-products\n```\n\n#### __Results__ with default arguments\n```\n* Test Accuracy of \"ogbn-products\": ~0.7716\n* Test Accuracy of \"ogbn-arxiv\": ~0.6994\n```\n"
  },
  {
    "path": "examples/multigpu/graphbolt/README.md",
    "content": "# Multi-gpu training with GraphBolt data loader\n\n## How to run\n\n```bash\npython node_classification.py --gpu=0,1\n```"
  },
  {
    "path": "examples/multigpu/graphbolt/node_classification.py",
    "content": "\"\"\"\nThis script trains and tests a GraphSAGE model for node classification on\nmultiple GPUs using distributed data-parallel training (DDP) and GraphBolt\ndata loader. \n\nBefore reading this example, please familiarize yourself with graphsage node\nclassification using GraphBolt data loader by reading the example in the\n`examples/graphbolt/node_classification.py`.\n\nFor the usage of DDP provided by PyTorch, please read its documentation:\nhttps://pytorch.org/tutorials/beginner/dist_overview.html and\nhttps://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParal\nlel.html\n\nThis flowchart describes the main functional sequence of the provided example:\nmain\n│\n├───> OnDiskDataset pre-processing\n│\n└───> run (multiprocessing) \n      │\n      ├───> Init process group and build distributed SAGE model (HIGHLIGHT)\n      │\n      ├───> train\n      │     │\n      │     ├───> Get GraphBolt dataloader with DistributedItemSampler\n      │     │     (HIGHLIGHT)\n      │     │\n      │     └───> Training loop\n      │           │\n      │           ├───> SAGE.forward\n      │           │\n      │           ├───> Validation set evaluation\n      │           │\n      │           └───> Collect accuracy and loss from all ranks (HIGHLIGHT)\n      │\n      └───> Test set evaluation\n\"\"\"\nimport argparse\nimport os\nimport time\n\nimport dgl.graphbolt as gb\nimport dgl.nn as dglnn\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nimport tqdm\nfrom torch.distributed.algorithms.join import Join\nfrom torch.nn.parallel import DistributedDataParallel as DDP\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hidden_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # Three-layer GraphSAGE-mean.\n        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, 
\"mean\"))\n        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, \"mean\"))\n        self.dropout = nn.Dropout(0.5)\n        self.hidden_size = hidden_size\n        self.out_size = out_size\n        # Set the dtype for the layers manually.\n        self.set_layer_dtype(torch.float32)\n\n    def set_layer_dtype(self, dtype):\n        for layer in self.layers:\n            for param in layer.parameters():\n                param.data = param.data.to(dtype)\n\n    def forward(self, blocks, x):\n        hidden_x = x\n        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\n            hidden_x = layer(block, hidden_x)\n            is_last_layer = layer_idx == len(self.layers) - 1\n            if not is_last_layer:\n                hidden_x = F.relu(hidden_x)\n                hidden_x = self.dropout(hidden_x)\n        return hidden_x\n\n\ndef create_dataloader(\n    args,\n    graph,\n    features,\n    itemset,\n    device,\n    is_train,\n):\n    ############################################################################\n    # [HIGHLIGHT]\n    # Get a GraphBolt dataloader for node classification tasks with multi-gpu\n    # distributed training. DistributedItemSampler instead of ItemSampler should\n    # be used.\n    ############################################################################\n\n    ############################################################################\n    # [Note]:\n    # gb.DistributedItemSampler()\n    # [Input]:\n    # 'item_set': The current dataset. (e.g. `train_set` or `valid_set`)\n    # 'batch_size': Specifies the number of samples to be processed together,\n    # referred to as a 'mini-batch'. (The term 'mini-batch' is used here to\n    # indicate a subset of the entire dataset that is processed together. 
This\n    # is in contrast to processing the entire dataset, known as a 'full batch'.)\n    # 'drop_last': Determines whether the last non-full minibatch should be\n    # dropped.\n    # 'shuffle': Determines if the items should be shuffled.\n    # 'num_replicas': Specifies the number of replicas.\n    # 'drop_uneven_inputs': Determines whether the numbers of minibatches on all\n    # ranks should be kept the same by dropping uneven minibatches.\n    # [Output]:\n    # A DistributedItemSampler object for handling mini-batch sampling on\n    # multiple replicas.\n    ############################################################################\n    datapipe = gb.DistributedItemSampler(\n        item_set=itemset,\n        batch_size=args.batch_size,\n        drop_last=is_train,\n        shuffle=is_train,\n        drop_uneven_inputs=is_train,\n    )\n    ############################################################################\n    # [Note]:\n    # datapipe.copy_to() / gb.CopyTo()\n    # [Input]:\n    # 'device': The specified device that data should be copied to.\n    # [Output]:\n    # A CopyTo object copying data in the datapipe to a specified device.\n    ############################################################################\n    if args.storage_device != \"cpu\":\n        datapipe = datapipe.copy_to(device)\n    datapipe = datapipe.sample_neighbor(\n        graph,\n        args.fanout,\n        overlap_fetch=args.storage_device == \"pinned\",\n        asynchronous=args.storage_device != \"cpu\",\n    )\n    datapipe = datapipe.fetch_feature(features, node_feature_keys=[\"feat\"])\n    if args.storage_device == \"cpu\":\n        datapipe = datapipe.copy_to(device)\n\n    dataloader = gb.DataLoader(datapipe, args.num_workers)\n\n    # Return the fully-initialized DataLoader object.\n    return dataloader\n\n\ndef weighted_reduce(tensor, weight, dst=0):\n    ########################################################################\n    # (HIGHLIGHT) 
Collect accuracy and loss values from sub-processes and\n    # obtain overall average values.\n    #\n    # `torch.distributed.reduce` is used to reduce tensors from all the\n    # sub-processes to a specified process, ReduceOp.SUM is used by default.\n    #\n    # Because the GPUs may have differing numbers of processed items, we\n    # perform a weighted mean to calculate the exact loss and accuracy.\n    ########################################################################\n    dist.reduce(tensor=tensor, dst=dst)\n    weight = torch.tensor(weight, device=tensor.device)\n    dist.reduce(tensor=weight, dst=dst)\n    return tensor / weight\n\n\n@torch.no_grad()\ndef evaluate(rank, model, dataloader, num_classes, device):\n    model.eval()\n    y = []\n    y_hats = []\n\n    for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader:\n        blocks = data.blocks\n        x = data.node_features[\"feat\"]\n        y.append(data.labels)\n        y_hats.append(model.module(blocks, x))\n\n    res = MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(y),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n    return res.to(device), sum(y_i.size(0) for y_i in y)\n\n\ndef train(\n    rank,\n    args,\n    train_dataloader,\n    valid_dataloader,\n    num_classes,\n    model,\n    device,\n):\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n\n    for epoch in range(args.epochs):\n        epoch_start = time.time()\n\n        model.train()\n        total_loss = torch.tensor(0, dtype=torch.float, device=device)\n        num_train_items = 0\n        ########################################################################\n        # (HIGHLIGHT) Use Join Context Manager to solve uneven input problem.\n        #\n        # The mechanics of Distributed Data Parallel (DDP) training in PyTorch\n        # requires the number of inputs are the same for all ranks, otherwise\n        # the program may error or hang. 
To solve it, PyTorch provides Join\n        # Context Manager. Please refer to\n        # https://pytorch.org/tutorials/advanced/generic_join.html for detailed\n        # information.\n        #\n        # Another method is to set `drop_uneven_inputs` as True in GraphBolt's\n        # DistributedItemSampler, which will solve this problem by dropping\n        # uneven inputs.\n        ########################################################################\n        with Join([model]):\n            for data in (\n                tqdm.tqdm(train_dataloader) if rank == 0 else train_dataloader\n            ):\n                # The input features are from the source nodes in the first\n                # layer's computation graph.\n                x = data.node_features[\"feat\"]\n\n                # The ground truth labels are from the destination nodes\n                # in the last layer's computation graph.\n                y = data.labels\n\n                blocks = data.blocks\n\n                y_hat = model(blocks, x)\n\n                # Compute loss.\n                loss = F.cross_entropy(y_hat, y)\n\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n\n                total_loss += loss.detach() * y.size(0)\n                num_train_items += y.size(0)\n\n        # Evaluate the model.\n        if rank == 0:\n            print(\"Validating...\")\n        acc, num_val_items = evaluate(\n            rank,\n            model,\n            valid_dataloader,\n            num_classes,\n            device,\n        )\n\n        total_loss = weighted_reduce(total_loss, num_train_items)\n        acc = weighted_reduce(acc * num_val_items, num_val_items)\n\n        # We synchronize before measuring the epoch time.\n        torch.cuda.synchronize()\n        epoch_end = time.time()\n        if rank == 0:\n            print(\n                f\"Epoch {epoch:05d} | \"\n                f\"Average Loss 
{total_loss.item():.4f} | \"\n                f\"Accuracy {acc.item():.4f} | \"\n                f\"Time {epoch_end - epoch_start:.4f}\"\n            )\n\n\ndef run(rank, world_size, args, devices, dataset):\n    # Set up multiprocessing environment.\n    device = devices[rank]\n    torch.cuda.set_device(device)\n    dist.init_process_group(\n        backend=\"nccl\",  # Use NCCL backend for distributed GPU training\n        init_method=\"tcp://127.0.0.1:12345\",\n        world_size=world_size,\n        rank=rank,\n    )\n\n    # Pin the graph and features to enable GPU access.\n    if args.storage_device == \"pinned\":\n        graph = dataset.graph.pin_memory_()\n        feature = dataset.feature.pin_memory_()\n    else:\n        graph = dataset.graph.to(args.storage_device)\n        feature = dataset.feature.to(args.storage_device)\n\n    train_set = dataset.tasks[0].train_set\n    valid_set = dataset.tasks[0].validation_set\n    test_set = dataset.tasks[0].test_set\n    args.fanout = list(map(int, args.fanout.split(\",\")))\n    num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n\n    in_size = feature.size(\"node\", None, \"feat\")[0]\n    hidden_size = 256\n    out_size = num_classes\n\n    if args.gpu_cache_size > 0 and args.storage_device != \"cuda\":\n        feature[(\"node\", None, \"feat\")] = gb.gpu_cached_feature(\n            feature[(\"node\", None, \"feat\")],\n            args.gpu_cache_size,\n        )\n\n    # Create GraphSAGE model. 
It should be copied onto a GPU as a replica.\n    model = SAGE(in_size, hidden_size, out_size).to(device)\n    model = DDP(model)\n\n    # Create data loaders.\n    train_dataloader = create_dataloader(\n        args,\n        graph,\n        feature,\n        train_set,\n        device,\n        is_train=True,\n    )\n    valid_dataloader = create_dataloader(\n        args,\n        graph,\n        feature,\n        valid_set,\n        device,\n        is_train=False,\n    )\n    test_dataloader = create_dataloader(\n        args,\n        graph,\n        feature,\n        test_set,\n        device,\n        is_train=False,\n    )\n\n    # Model training.\n    if rank == 0:\n        print(\"Training...\")\n    train(\n        rank,\n        args,\n        train_dataloader,\n        valid_dataloader,\n        num_classes,\n        model,\n        device,\n    )\n\n    # Test the model.\n    if rank == 0:\n        print(\"Testing...\")\n    test_acc, num_test_items = evaluate(\n        rank,\n        model,\n        test_dataloader,\n        num_classes,\n        device,\n    )\n    test_acc = weighted_reduce(test_acc * num_test_items, num_test_items)\n\n    if rank == 0:\n        print(f\"Test Accuracy {test_acc.item():.4f}\")\n    dist.destroy_process_group()\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"A script does a multi-gpu training on a GraphSAGE model \"\n        \"for node classification using GraphBolt dataloader.\"\n    )\n    parser.add_argument(\n        \"--gpu\",\n        type=str,\n        default=\"0\",\n        help=\"GPU(s) in use. 
Can be a list of gpu ids for multi-gpu training,\"\n        \" e.g., 0,1,2,3.\",\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=10, help=\"Number of training epochs.\"\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=0.001,\n        help=\"Learning rate for optimization.\",\n    )\n    parser.add_argument(\n        \"--batch-size\", type=int, default=1024, help=\"Batch size for training.\"\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"10,10,10\",\n        help=\"Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)\"\n        \" identical with the number of layers in your model. Default: 10,10,10\",\n    )\n    parser.add_argument(\n        \"--num-workers\", type=int, default=0, help=\"The number of processes.\"\n    )\n    parser.add_argument(\n        \"--gpu-cache-size\",\n        type=int,\n        default=0,\n        help=\"The capacity of the GPU cache in bytes.\",\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-products\",\n        choices=[\"ogbn-arxiv\", \"ogbn-products\", \"ogbn-papers100M\"],\n        help=\"The dataset we can use for node classification example. 
Currently\"\n        \" ogbn-products, ogbn-arxiv, ogbn-papers100M datasets are supported.\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"pinned-cuda\",\n        choices=[\"cpu-cuda\", \"pinned-cuda\", \"cuda-cuda\"],\n        help=\"Dataset storage placement and Train device: 'cpu' for CPU and RAM\"\n        \", 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.\",\n    )\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    if not torch.cuda.is_available():\n        print(f\"Multi-gpu training needs to be in gpu mode.\")\n        exit(0)\n    args.storage_device, _ = args.mode.split(\"-\")\n\n    devices = list(map(int, args.gpu.split(\",\")))\n    world_size = len(devices)\n\n    print(f\"Training with {world_size} gpus.\")\n\n    # Load and preprocess dataset.\n    dataset = gb.BuiltinDataset(args.dataset).load()\n\n    # Thread limiting to avoid resource competition.\n    os.environ[\"OMP_NUM_THREADS\"] = str(mp.cpu_count() // 2 // world_size)\n\n    mp.set_sharing_strategy(\"file_system\")\n    mp.spawn(\n        run,\n        args=(world_size, args, devices, dataset),\n        nprocs=world_size,\n        join=True,\n    )\n"
  },
  {
    "path": "examples/multigpu/node_classification_sage.py",
    "content": "\"\"\"\nThis script trains and tests a GraphSAGE model for node classification on\nmultiple GPUs with distributed data-parallel training (DDP).\n\nBefore reading this example, please familiar yourself with graphsage node\nclassification using neighbor sampling by reading the example in the\n`examples/sampling/node_classification.py`\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> Load and preprocess dataset\n│\n└───> run (multiprocessing) \n      │\n      ├───> Init process group and build distributed SAGE model (HIGHLIGHT)\n      │\n      ├───> train\n      │     │\n      │     ├───> NeighborSampler\n      │     │\n      │     └───> Training loop\n      │           │\n      │           ├───> SAGE.forward\n      │           │\n      │           └───> Collect validation accuracy (HIGHLIGHT)\n      │\n      └───> layerwise_infer\n            │\n            └───> SAGE.inference\n                  │\n                  ├───> MultiLayerFullNeighborSampler\n                  │\n                  └───> Use a shared output tensor\n\"\"\"\nimport argparse\nimport os\nimport time\n\nimport dgl\nimport dgl.nn as dglnn\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nimport tqdm\nfrom dgl.data import AsNodePredDataset\nfrom dgl.dataloading import (\n    DataLoader,\n    MultiLayerFullNeighborSampler,\n    NeighborSampler,\n)\nfrom dgl.multiprocessing import shared_tensor\nfrom ogb.nodeproppred import DglNodePropPredDataset\nfrom torch.nn.parallel import DistributedDataParallel\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hid_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # Three-layer GraphSAGE-mean\n        self.layers.append(dglnn.SAGEConv(in_size, hid_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hid_size, 
hid_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hid_size, out_size, \"mean\"))\n        self.dropout = nn.Dropout(0.5)\n        self.hid_size = hid_size\n        self.out_size = out_size\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = F.relu(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, device, batch_size, use_uva):\n        g.ndata[\"h\"] = g.ndata[\"feat\"]\n        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=[\"h\"])\n        for l, layer in enumerate(self.layers):\n            dataloader = DataLoader(\n                g,\n                torch.arange(g.num_nodes(), device=device),\n                sampler,\n                device=device,\n                batch_size=batch_size,\n                shuffle=False,\n                drop_last=False,\n                num_workers=0,\n                use_ddp=True,  # use DDP\n                use_uva=use_uva,\n            )\n            # In order to prevent running out of GPU memory, allocate a shared\n            # output tensor 'y' in host memory.\n            y = shared_tensor(\n                (\n                    g.num_nodes(),\n                    self.hid_size\n                    if l != len(self.layers) - 1\n                    else self.out_size,\n                )\n            )\n            for input_nodes, output_nodes, blocks in (\n                tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader\n            ):\n                x = blocks[0].srcdata[\"h\"]\n                h = layer(blocks[0], x)  # len(blocks) = 1\n                if l != len(self.layers) - 1:\n                    h = F.relu(h)\n                    h = self.dropout(h)\n                # Non_blocking (with pinned memory) to accelerate data transfer\n                y[output_nodes] = 
h.to(y.device, non_blocking=True)\n            # Use a barrier to make sure all GPUs are done writing to 'y'\n            dist.barrier()\n            g.ndata[\"h\"] = y if use_uva else y.to(device)\n\n        g.ndata.pop(\"h\")\n        return y\n\n\ndef evaluate(device, model, g, num_classes, dataloader):\n    model.eval()\n    ys = []\n    y_hats = []\n    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):\n        with torch.no_grad():\n            blocks = [block.to(device) for block in blocks]\n            x = blocks[0].srcdata[\"feat\"]\n            ys.append(blocks[-1].dstdata[\"label\"])\n            y_hats.append(model(blocks, x))\n    return MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(ys),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\ndef layerwise_infer(\n    proc_id, device, g, num_classes, nid, model, use_uva, batch_size=2**10\n):\n    model.eval()\n    with torch.no_grad():\n        if not use_uva:\n            g = g.to(device)\n        pred = model.module.inference(g, device, batch_size, use_uva)\n        pred = pred[nid]\n        labels = g.ndata[\"label\"][nid].to(pred.device)\n    if proc_id == 0:\n        acc = MF.accuracy(\n            pred, labels, task=\"multiclass\", num_classes=num_classes\n        )\n        print(f\"Test accuracy {acc.item():.4f}\")\n\n\ndef train(\n    proc_id,\n    nprocs,\n    device,\n    args,\n    g,\n    num_classes,\n    train_idx,\n    val_idx,\n    model,\n    use_uva,\n):\n    # Instantiate a neighbor sampler\n    if args.mode == \"benchmark\":\n        # A work-around to prevent CUDA running error. 
For more details, please\n        # see https://github.com/dmlc/dgl/issues/6697.\n        sampler = NeighborSampler([10, 10, 10], fused=False)\n    else:\n        sampler = NeighborSampler(\n            [10, 10, 10],\n            prefetch_node_feats=[\"feat\"],\n            prefetch_labels=[\"label\"],\n        )\n    train_dataloader = DataLoader(\n        g,\n        train_idx,\n        sampler,\n        device=device,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n        use_ddp=True,  # To split the set for each process\n        use_uva=use_uva,\n    )\n    val_dataloader = DataLoader(\n        g,\n        val_idx,\n        sampler,\n        device=device,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n        use_ddp=True,\n        use_uva=use_uva,\n    )\n    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)\n    for epoch in range(args.num_epochs):\n        t0 = time.time()\n        model.train()\n        total_loss = 0\n        for it, (input_nodes, output_nodes, blocks) in enumerate(\n            train_dataloader\n        ):\n            x = blocks[0].srcdata[\"feat\"]\n            y = blocks[-1].dstdata[\"label\"].to(torch.int64)\n            y_hat = model(blocks, x)\n            loss = F.cross_entropy(y_hat, y)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()  # Gradients are synchronized in DDP\n            total_loss += loss\n        #####################################################################\n        # (HIGHLIGHT) Collect accuracy values from sub-processes and obtain\n        # overall accuracy.\n        #\n        # `torch.distributed.reduce` is used to reduce tensors from all the\n        # sub-processes to a specified process, ReduceOp.SUM is used by default.\n        #\n        # Other multiprocess functions supported by the backend are also\n       
 # available. Please refer to\n        # https://pytorch.org/docs/stable/distributed.html\n        # for more information.\n        #####################################################################\n        acc = (\n            evaluate(device, model, g, num_classes, val_dataloader).to(device)\n            / nprocs\n        )\n        t1 = time.time()\n        # Reduce `acc` tensors to process 0.\n        dist.reduce(tensor=acc, dst=0)\n        if proc_id == 0:\n            print(\n                f\"Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f} | \"\n                f\"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}\"\n            )\n\n\ndef run(proc_id, nprocs, devices, g, data, args):\n    # Find corresponding device for current process.\n    device = devices[proc_id]\n    torch.cuda.set_device(device)\n    #########################################################################\n    # (HIGHLIGHT) Build a data-parallel distributed GraphSAGE model.\n    #\n    # DDP in PyTorch provides data parallelism across the devices specified\n    # by the `process_group`. Gradients are synchronized across each model\n    # replica.\n    #\n    # To prepare a training sub-process, there are four steps involved:\n    # 1. Initialize the process group\n    # 2. Unpack data for the sub-process.\n    # 3. Instantiate a GraphSAGE model on the corresponding device.\n    # 4. 
Parallelize the model with `DistributedDataParallel`.\n    #\n    # For the detailed usage of `DistributedDataParallel`, please refer to\n    # PyTorch documentation.\n    #########################################################################\n    dist.init_process_group(\n        backend=\"nccl\",  # Use NCCL backend for distributed GPU training\n        init_method=\"tcp://127.0.0.1:12345\",\n        world_size=nprocs,\n        rank=proc_id,\n    )\n    num_classes, train_idx, val_idx, test_idx = data\n    if args.mode != \"benchmark\":\n        train_idx = train_idx.to(device)\n        val_idx = val_idx.to(device)\n        g = g.to(device if args.mode == \"puregpu\" else \"cpu\")\n    in_size = g.ndata[\"feat\"].shape[1]\n    model = SAGE(in_size, 256, num_classes).to(device)\n    model = DistributedDataParallel(\n        model, device_ids=[device], output_device=device\n    )\n\n    # Training.\n    use_uva = args.mode == \"mixed\"\n\n    if proc_id == 0:\n        print(\"Training...\")\n    train(\n        proc_id,\n        nprocs,\n        device,\n        args,\n        g,\n        num_classes,\n        train_idx,\n        val_idx,\n        model,\n        use_uva,\n    )\n\n    # Testing.\n    if proc_id == 0:\n        print(\"Testing...\")\n    layerwise_infer(proc_id, device, g, num_classes, test_idx, model, use_uva)\n\n    # Cleanup the process group.\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--mode\",\n        default=\"mixed\",\n        choices=[\"mixed\", \"puregpu\", \"benchmark\"],\n        help=\"Training mode. 'mixed' for CPU-GPU mixed training, \"\n        \"'puregpu' for pure-GPU training.\",\n    )\n    parser.add_argument(\n        \"--gpu\",\n        type=str,\n        default=\"0\",\n        help=\"GPU(s) in use. 
Can be a list of gpu ids for multi-gpu training,\"\n        \" e.g., 0,1,2,3.\",\n    )\n    parser.add_argument(\n        \"--num_epochs\",\n        type=int,\n        default=10,\n        help=\"Number of epochs for train.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=\"ogbn-products\",\n        help=\"Dataset name.\",\n    )\n    parser.add_argument(\n        \"--dataset_dir\",\n        type=str,\n        default=\"dataset\",\n        help=\"Root directory of dataset.\",\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=0,\n        help=\"Number of workers\",\n    )\n    args = parser.parse_args()\n    devices = list(map(int, args.gpu.split(\",\")))\n    nprocs = len(devices)\n    assert (\n        torch.cuda.is_available()\n    ), f\"Must have GPUs to enable multi-gpu training.\"\n    print(f\"Training in {args.mode} mode using {nprocs} GPU(s)\")\n\n    # Load and preprocess the dataset.\n    print(\"Loading data\")\n    dataset = AsNodePredDataset(\n        DglNodePropPredDataset(args.dataset_name, root=args.dataset_dir)\n    )\n    g = dataset[0]\n    # Explicitly create desired graph formats before multi-processing to avoid\n    # redundant creation in each sub-process and to save memory.\n    g.create_formats_()\n    if args.dataset_name == \"ogbn-arxiv\":\n        g = dgl.to_bidirected(g, copy_ndata=True)\n        g = dgl.add_self_loop(g)\n    # Thread limiting to avoid resource competition.\n    os.environ[\"OMP_NUM_THREADS\"] = str(mp.cpu_count() // 2 // nprocs)\n    data = (\n        dataset.num_classes,\n        dataset.train_idx,\n        dataset.val_idx,\n        dataset.test_idx,\n    )\n\n    # To use DDP with n GPUs, spawn up n processes.\n    mp.spawn(\n        run,\n        args=(nprocs, devices, g, data, args),\n        nprocs=nprocs,\n    )\n"
  },
  {
    "path": "examples/mxnet/README.md",
    "content": "# Model Examples using DGL (w/ MXNet backend)\n\nSet `DGLBACKEND=mxnet` to use MXNet as DGL's backend.\n\n## Examples:\n\n```\nDGLBACKEND=mxnet python gcn_batch.py --dataset cora\nDGLBACKEND=mxnet python gat_batch.py --dataset cora\n```\n\nEach model is hosted in its own folder. Please read its README.md to see how to\nrun it.\n\nTo understand step by step how these models are implemented in DGL, check out our\n[tutorials](https://docs.dgl.ai/tutorials/models/index.html).\n"
  },
  {
    "path": "examples/mxnet/appnp/README.md",
    "content": "Predict then Propagate: Graph Neural Networks meet Personalized PageRank (APPNP)\n============\n\n- Paper link: [Predict then Propagate: Graph Neural Networks meet Personalized PageRank](https://arxiv.org/abs/1810.05997)\n- Author's code repo: [https://github.com/klicperajo/ppnp](https://github.com/klicperajo/ppnp).\n\nDependencies\n------------\n- MXNet 1.5+\n- requests\n\n```bash\npip install torch requests\n```\n\nCode\n-----\nThe folder contains an implementation of APPNP (`appnp.py`).\n\nResults\n-------\n\nRun with the following command (available datasets: \"cora\", \"citeseer\", \"pubmed\"):\n```bash\nDGLBACKEND=mxnet python3 appnp.py --dataset cora --gpu 0\n```\n\n* cora: 0.8370 (paper: 0.850)\n* citeseer: 0.713 (paper: 0.757)\n* pubmed: 0.798 (paper: 0.797)\n\nExperiments were done on DGL datasets (GCN settings), which are different from those used in the original implementation (discrepancies are detailed in the experimental section of the original paper).\n"
  },
  {
    "path": "examples/mxnet/appnp/appnp.py",
    "content": "import argparse\nimport time\n\nimport dgl\n\nimport mxnet as mx\nimport numpy as np\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom dgl.nn.mxnet.conv import APPNPConv\nfrom mxnet import gluon, nd\nfrom mxnet.gluon import nn\n\n\nclass APPNP(nn.Block):\n    def __init__(\n        self,\n        g,\n        in_feats,\n        hiddens,\n        n_classes,\n        activation,\n        feat_drop,\n        edge_drop,\n        alpha,\n        k,\n    ):\n        super(APPNP, self).__init__()\n        self.g = g\n\n        with self.name_scope():\n            self.layers = nn.Sequential()\n            # input layer\n            self.layers.add(nn.Dense(hiddens[0], in_units=in_feats))\n            # hidden layers\n            for i in range(1, len(hiddens)):\n                self.layers.add(nn.Dense(hiddens[i], in_units=hiddens[i - 1]))\n            # output layer\n            self.layers.add(nn.Dense(n_classes, in_units=hiddens[-1]))\n            self.activation = activation\n            if feat_drop:\n                self.feat_drop = nn.Dropout(feat_drop)\n            else:\n                self.feat_drop = lambda x: x\n            self.propagate = APPNPConv(k, alpha, edge_drop)\n\n    def forward(self, features):\n        # prediction step\n        h = features\n        h = self.feat_drop(h)\n        h = self.activation(self.layers[0](h))\n        for layer in self.layers[1:-1]:\n            h = self.activation(layer(h))\n        h = self.layers[-1](self.feat_drop(h))\n        # propagation step\n        h = self.propagate(self.g, h)\n        return h\n\n\ndef evaluate(model, features, labels, mask):\n    pred = model(features).argmax(axis=1)\n    accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()\n    return accuracy.asscalar()\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n  
  elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n        ctx = mx.cpu(0)\n    else:\n        cuda = True\n        ctx = mx.gpu(args.gpu)\n        g = g.to(ctx)\n\n    features = g.ndata[\"feat\"]\n    labels = mx.nd.array(g.ndata[\"label\"], dtype=\"float32\", ctx=ctx)\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = data.graph.number_of_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.sum().asscalar(),\n            val_mask.sum().asscalar(),\n            test_mask.sum().asscalar(),\n        )\n    )\n\n    # add self loop\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # create APPNP model\n    model = APPNP(\n        g,\n        in_feats,\n        args.hidden_sizes,\n        n_classes,\n        nd.relu,\n        args.in_drop,\n        args.edge_drop,\n        args.alpha,\n        args.k,\n    )\n\n    model.initialize(ctx=ctx)\n    n_train_samples = train_mask.sum().asscalar()\n    loss_fcn = gluon.loss.SoftmaxCELoss()\n\n    # use optimizer\n    print(model.collect_params())\n    trainer = gluon.Trainer(\n        model.collect_params(),\n        \"adam\",\n        {\"learning_rate\": args.lr, \"wd\": args.weight_decay},\n    )\n\n    # initialize graph\n    dur = []\n    for epoch in range(args.n_epochs):\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        with mx.autograd.record():\n            pred = 
model(features)\n            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))\n            loss = loss.sum() / n_train_samples\n\n        loss.backward()\n        trainer.step(batch_size=1)\n\n        if epoch >= 3:\n            loss.asscalar()\n            dur.append(time.time() - t0)\n            acc = evaluate(model, features, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss.asscalar(),\n                    acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n    # test set accuracy\n    acc = evaluate(model, features, labels, test_mask)\n    print(\"Test accuracy {:.2%}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"APPNP\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--in-drop\", type=float, default=0.5, help=\"input feature dropout\"\n    )\n    parser.add_argument(\n        \"--edge-drop\", type=float, default=0.5, help=\"edge propagation dropout\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--hidden_sizes\",\n        type=int,\n        nargs=\"+\",\n        default=[64],\n        help=\"hidden unit sizes for appnp\",\n    )\n    parser.add_argument(\n        \"--k\", type=int, default=10, help=\"Number of propagation steps\"\n    )\n    parser.add_argument(\n        \"--alpha\", type=float, default=0.1, help=\"Teleport Probability\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 
loss\"\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/mxnet/gat/README.md",
    "content": "Graph Attention Networks (GAT)\n============\n\n- Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903)\n- Author's code repo:\n  [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).\n\nNote that the original code for the paper is implemented with TensorFlow.\n\n### Dependencies\n* MXNet nightly build\n* requests\n\n```bash\npip install mxnet --pre\npip install requests\n```\n\n\n### Usage (make sure that DGLBACKEND is set to mxnet)\n```bash\nDGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0\nDGLBACKEND=mxnet python3 train.py --dataset citeseer --gpu 0 --early-stop\nDGLBACKEND=mxnet python3 train.py --dataset pubmed --gpu 0 --early-stop\n```\n"
  },
  {
    "path": "examples/mxnet/gat/gat.py",
    "content": "\"\"\"\nGraph Attention Networks in DGL using SPMV optimization.\nReferences\n----------\nPaper: https://arxiv.org/abs/1710.10903\nAuthor's code: https://github.com/PetarV-/GAT\nPytorch implementation: https://github.com/Diego999/pyGAT\n\"\"\"\n\nimport mxnet.gluon.nn as nn\n\nfrom dgl.nn.mxnet.conv import GATConv\n\n\nclass GAT(nn.Block):\n    def __init__(\n        self,\n        g,\n        num_layers,\n        in_dim,\n        num_hidden,\n        num_classes,\n        heads,\n        activation,\n        feat_drop,\n        attn_drop,\n        alpha,\n        residual,\n    ):\n        super(GAT, self).__init__()\n        self.g = g\n        self.num_layers = num_layers\n        self.gat_layers = []\n        self.activation = activation\n        # input projection (no residual)\n        self.gat_layers.append(\n            GATConv(\n                in_dim, num_hidden, heads[0], feat_drop, attn_drop, alpha, False\n            )\n        )\n        # hidden layers\n        for l in range(1, num_layers):\n            # due to multi-head, the in_dim = num_hidden * num_heads\n            self.gat_layers.append(\n                GATConv(\n                    num_hidden * heads[l - 1],\n                    num_hidden,\n                    heads[l],\n                    feat_drop,\n                    attn_drop,\n                    alpha,\n                    residual,\n                )\n            )\n        # output projection\n        self.gat_layers.append(\n            GATConv(\n                num_hidden * heads[-2],\n                num_classes,\n                heads[-1],\n                feat_drop,\n                attn_drop,\n                alpha,\n                residual,\n            )\n        )\n        for i, layer in enumerate(self.gat_layers):\n            self.register_child(layer, \"gat_layer_{}\".format(i))\n\n    def forward(self, inputs):\n        h = inputs\n        for l in range(self.num_layers):\n            h = 
self.gat_layers[l](self.g, h).flatten()\n            h = self.activation(h)\n        # output projection\n        logits = self.gat_layers[-1](self.g, h).mean(1)\n        return logits\n"
  },
  {
    "path": "examples/mxnet/gat/train.py",
    "content": "\"\"\"\nGraph Attention Networks in DGL using SPMV optimization.\nMultiple heads are also batched together for faster training.\nReferences\n----------\nPaper: https://arxiv.org/abs/1710.10903\nAuthor's code: https://github.com/PetarV-/GAT\nPytorch implementation: https://github.com/Diego999/pyGAT\n\"\"\"\n\nimport argparse\nimport time\n\nimport dgl\n\nimport mxnet as mx\nimport networkx as nx\nimport numpy as np\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom gat import GAT\nfrom mxnet import gluon\nfrom utils import EarlyStopping\n\n\ndef elu(data):\n    return mx.nd.LeakyReLU(data, act_type=\"elu\")\n\n\ndef evaluate(model, features, labels, mask):\n    logits = model(features)\n    logits = logits[mask].asnumpy().squeeze()\n    val_labels = labels[mask].asnumpy().squeeze()\n    max_index = np.argmax(logits, axis=1)\n    accuracy = np.sum(np.where(max_index == val_labels, 1, 0)) / len(val_labels)\n    return accuracy\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n        ctx = mx.cpu(0)\n    else:\n        cuda = True\n        ctx = mx.gpu(args.gpu)\n        g = g.to(ctx)\n\n    features = g.ndata[\"feat\"]\n    labels = mx.nd.array(g.ndata[\"label\"], dtype=\"float32\", ctx=ctx)\n    mask = g.ndata[\"train_mask\"]\n    mask = mx.nd.array(np.nonzero(mask.asnumpy())[0], ctx=ctx)\n    val_mask = g.ndata[\"val_mask\"]\n    val_mask = mx.nd.array(np.nonzero(val_mask.asnumpy())[0], ctx=ctx)\n    test_mask = g.ndata[\"test_mask\"]\n    test_mask = mx.nd.array(np.nonzero(test_mask.asnumpy())[0], 
ctx=ctx)\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = data.graph.number_of_edges()\n\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n    # create model\n    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]\n    model = GAT(\n        g,\n        args.num_layers,\n        in_feats,\n        args.num_hidden,\n        n_classes,\n        heads,\n        elu,\n        args.in_drop,\n        args.attn_drop,\n        args.alpha,\n        args.residual,\n    )\n\n    if args.early_stop:\n        stopper = EarlyStopping(patience=100)\n    model.initialize(ctx=ctx)\n\n    # use optimizer\n    trainer = gluon.Trainer(\n        model.collect_params(), \"adam\", {\"learning_rate\": args.lr}\n    )\n\n    dur = []\n    for epoch in range(args.epochs):\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        with mx.autograd.record():\n            logits = model(features)\n            loss = mx.nd.softmax_cross_entropy(\n                logits[mask].squeeze(), labels[mask].squeeze()\n            )\n            loss.backward()\n        trainer.step(mask.shape[0])\n\n        if epoch >= 3:\n            dur.append(time.time() - t0)\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}\".format(\n                epoch,\n                loss.asnumpy()[0],\n                np.mean(dur),\n                n_edges / np.mean(dur) / 1000,\n            )\n        )\n        val_accuracy = evaluate(model, features, labels, val_mask)\n        print(\"Validation Accuracy {:.4f}\".format(val_accuracy))\n        if args.early_stop:\n            if stopper.step(val_accuracy, model):\n                break\n    print()\n\n    if args.early_stop:\n        model.load_parameters(\"model.param\")\n    test_accuracy = evaluate(model, features, labels, test_mask)\n    print(\"Test Accuracy {:.4f}\".format(test_accuracy))\n\n\nif __name__ == \"__main__\":\n    
parser = argparse.ArgumentParser(description=\"GAT\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=-1,\n        help=\"which GPU to use. Set -1 to use CPU.\",\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--num-heads\",\n        type=int,\n        default=8,\n        help=\"number of hidden attention heads\",\n    )\n    parser.add_argument(\n        \"--num-out-heads\",\n        type=int,\n        default=1,\n        help=\"number of output attention heads\",\n    )\n    parser.add_argument(\n        \"--num-layers\", type=int, default=1, help=\"number of hidden layers\"\n    )\n    parser.add_argument(\n        \"--num-hidden\", type=int, default=8, help=\"number of hidden units\"\n    )\n    parser.add_argument(\n        \"--residual\",\n        action=\"store_true\",\n        default=False,\n        help=\"use residual connection\",\n    )\n    parser.add_argument(\n        \"--in-drop\", type=float, default=0.6, help=\"input feature dropout\"\n    )\n    parser.add_argument(\n        \"--attn-drop\", type=float, default=0.6, help=\"attention dropout\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.005, help=\"learning rate\")\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"weight decay\"\n    )\n    parser.add_argument(\n        \"--alpha\",\n        type=float,\n        default=0.2,\n        help=\"the negative slop of leaky relu\",\n    )\n    parser.add_argument(\n        \"--early-stop\",\n        action=\"store_true\",\n        default=False,\n        help=\"indicates whether to use early stop or not\",\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/mxnet/gat/utils.py",
    "content": "import numpy as np\n\n\nclass EarlyStopping:\n    def __init__(self, patience=10):\n        self.patience = patience\n        self.counter = 0\n        self.best_score = None\n        self.early_stop = False\n\n    def step(self, acc, model):\n        score = acc\n        if self.best_score is None:\n            self.best_score = score\n            self.save_checkpoint(model)\n        elif score < self.best_score:\n            self.counter += 1\n            print(\n                f\"EarlyStopping counter: {self.counter} out of {self.patience}\"\n            )\n            if self.counter >= self.patience:\n                self.early_stop = True\n        else:\n            self.best_score = score\n            self.save_checkpoint(model)\n            self.counter = 0\n        return self.early_stop\n\n    def save_checkpoint(self, model):\n        \"\"\"Saves model when validation loss decrease.\"\"\"\n        model.save_parameters(\"model.param\")\n"
  },
  {
    "path": "examples/mxnet/gcn/README.md",
    "content": "Graph Convolutional Networks (GCN)\n============\n\nPaper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)\nAuthor's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn)\n\nDependencies\n------------\n- MXNet nightly build\n- requests\n\n``bash\npip install mxnet --pre\npip install requests\n``\n\nCodes\n-----\nThe folder contains three implementations of GCN:\n- `gcn.py` uses DGL's predefined graph convolution module.\n- `gcn_mp.py` uses user-defined message and reduce functions.\nModify `train.py` to switch between different implementations.\n\nThe provided implementation in `gcn_concat.py` is a bit different from the\noriginal paper for better performance, credit to @yifeim and @ZiyueHuang.\n\nResults\n-------\nRun with following (available dataset: \"cora\", \"citeseer\", \"pubmed\")\n```bash\nDGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0 --self-loop\n```\n\n* cora: ~0.810 (paper: 0.815)\n* citeseer: ~0.702 (paper: 0.703)\n* pubmed: ~0.780 (paper: 0.790)\n\nResults (`gcn_concat.py vs. gcn.py`)\n------------------------------------\n`gcn_concat.py` uses concatenation of hidden units to account for multi-hop\n  skip-connections. We feel concatenation is superior\nbecause all neighboring information is presented without additional modeling\nassumptions.\nThese results are based on single-run training to minimize the cross-entropy\nloss. We can see clear skip connection can help train a GCN with many layers.\n\nThe experiments show that adding depth may or may not improve accuracy.\nWhile adding depth is a clear way to mimic power iterations of matrix factorizations,\ntraining multiple epochs to obtain stationary points could equivalently solve matrix\nfactorization. 
Given the small datasets, we can't draw such conclusions from these experiments.\n\n```\n# Final accuracy 57.70% MLP without GCN\nDGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset \"citeseer\" --n-epochs 200 --n-layers 0\n\n# Final accuracy 65.70% with 10-layer GCN with skip connection\nDGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset \"citeseer\" --n-epochs 200 --n-layers 2 --normalization 'sym' --self-loop\n\n# Final accuracy 64.70% with 10-layer GCN with skip connection\nDGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset \"citeseer\" --n-epochs 200 --n-layers 10 --normalization 'sym' --self-loop\n\n```\n\n```\n# Final accuracy 53.20% MLP without GCN\nDGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset \"cora\" --n-epochs 200 --n-layers 0\n\n# Final accuracy 72.60% with 2-layer GCN with skip connection\nDGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset \"cora\" --n-epochs 200 --n-layers 2 --normalization 'sym' --self-loop\n\n# Final accuracy 78.90% with 10-layer GCN with skip connection\nDGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset \"cora\" --n-epochs 200 --n-layers 10 --normalization 'sym' --self-loop\n\n```\n\n```\n# Final accuracy 70.30% MLP without GCN\nDGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset \"pubmed\" --n-epochs 200 --n-layers 0\n\n# Final accuracy 78.30% with 2-layer GCN with skip connection\nDGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset \"pubmed\" --n-epochs 200 --n-layers 2 --normalization 'sym' --self-loop\n\n# Final accuracy 76.30% with 10-layer GCN with skip connection\nDGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset \"pubmed\" --n-epochs 200 --n-layers 10 --normalization 'sym' --self-loop\n```\n"
  },
  {
    "path": "examples/mxnet/gcn/gcn.py",
    "content": "\"\"\"GCN using DGL nn package\n\nReferences:\n- Semi-Supervised Classification with Graph Convolutional Networks\n- Paper: https://arxiv.org/abs/1609.02907\n- Code: https://github.com/tkipf/gcn\n\"\"\"\nimport dgl\nimport mxnet as mx\nfrom dgl.nn.mxnet import GraphConv\nfrom mxnet import gluon\n\n\nclass GCN(gluon.Block):\n    def __init__(\n        self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(GCN, self).__init__()\n        self.g = g\n        self.layers = gluon.nn.Sequential()\n        # input layer\n        self.layers.add(GraphConv(in_feats, n_hidden, activation=activation))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.add(\n                GraphConv(n_hidden, n_hidden, activation=activation)\n            )\n        # output layer\n        self.layers.add(GraphConv(n_hidden, n_classes))\n        self.dropout = gluon.nn.Dropout(rate=dropout)\n\n    def forward(self, features):\n        h = features\n        for i, layer in enumerate(self.layers):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(self.g, h)\n        return h\n"
  },
  {
    "path": "examples/mxnet/gcn/gcn_concat.py",
    "content": "\"\"\"\nSemi-Supervised Classification with Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1609.02907\nCode: https://github.com/tkipf/gcn\nGCN with batch processing\n\"\"\"\nimport argparse\nimport time\n\nimport dgl\nimport dgl.function as fn\nimport mxnet as mx\nimport numpy as np\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom mxnet import gluon\n\n\nclass GCNLayer(gluon.Block):\n    def __init__(self, g, out_feats, activation, dropout):\n        super(GCNLayer, self).__init__()\n        self.g = g\n        self.dense = gluon.nn.Dense(out_feats, activation)\n        self.dropout = dropout\n\n    def forward(self, h):\n        self.g.ndata[\"h\"] = h * self.g.ndata[\"out_norm\"]\n        self.g.update_all(\n            fn.copy_u(u=\"h\", out=\"m\"), fn.sum(msg=\"m\", out=\"accum\")\n        )\n        accum = self.g.ndata.pop(\"accum\")\n        accum = self.dense(accum * self.g.ndata[\"in_norm\"])\n        if self.dropout:\n            accum = mx.nd.Dropout(accum, p=self.dropout)\n        h = self.g.ndata.pop(\"h\")\n        h = mx.nd.concat(h / self.g.ndata[\"out_norm\"], accum, dim=1)\n        return h\n\n\nclass GCN(gluon.Block):\n    def __init__(self, g, n_hidden, n_classes, n_layers, activation, dropout):\n        super(GCN, self).__init__()\n        self.inp_layer = gluon.nn.Dense(n_hidden, activation)\n        self.dropout = dropout\n        self.layers = gluon.nn.Sequential()\n        for i in range(n_layers):\n            self.layers.add(GCNLayer(g, n_hidden, activation, dropout))\n        self.out_layer = gluon.nn.Dense(n_classes)\n\n    def forward(self, features):\n        emb_inp = [features, self.inp_layer(features)]\n        if self.dropout:\n            emb_inp[-1] = mx.nd.Dropout(emb_inp[-1], p=self.dropout)\n        h = mx.nd.concat(*emb_inp, dim=1)\n        for layer in self.layers:\n            h = layer(h)\n        h = 
self.out_layer(h)\n        return h\n\n\ndef evaluate(model, features, labels, mask):\n    pred = model(features).argmax(axis=1)\n    accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()\n    return accuracy.asscalar()\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n        ctx = mx.cpu(0)\n    else:\n        cuda = True\n        ctx = mx.gpu(args.gpu)\n        g = g.to(ctx)\n\n    features = g.ndata[\"feat\"]\n    labels = mx.nd.array(g.ndata[\"label\"], dtype=\"float32\", ctx=ctx)\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = data.graph.number_of_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.sum().asscalar(),\n            val_mask.sum().asscalar(),\n            test_mask.sum().asscalar(),\n        )\n    )\n\n    # add self loop\n    if args.self_loop:\n        g = dgl.remove_self_loop(g)\n        g = dgl.add_self_loop(g)\n    # normalization\n    in_degs = g.in_degrees().astype(\"float32\")\n    out_degs = g.out_degrees().astype(\"float32\")\n    in_norm = mx.nd.power(in_degs, -0.5)\n    out_norm = mx.nd.power(out_degs, -0.5)\n    if cuda:\n        in_norm = in_norm.as_in_context(ctx)\n        out_norm = out_norm.as_in_context(ctx)\n    g.ndata[\"in_norm\"] = mx.nd.expand_dims(in_norm, 1)\n    g.ndata[\"out_norm\"] 
= mx.nd.expand_dims(out_norm, 1)\n\n    model = GCN(\n        g,\n        args.n_hidden,\n        n_classes,\n        args.n_layers,\n        \"relu\",\n        args.dropout,\n    )\n    model.initialize(ctx=ctx)\n    n_train_samples = train_mask.sum().asscalar()\n    loss_fcn = gluon.loss.SoftmaxCELoss()\n\n    # use optimizer\n    print(model.collect_params())\n    trainer = gluon.Trainer(\n        model.collect_params(),\n        \"adam\",\n        {\"learning_rate\": args.lr, \"wd\": args.weight_decay},\n    )\n\n    # initialize graph\n    dur = []\n    for epoch in range(args.n_epochs):\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        with mx.autograd.record():\n            pred = model(features)\n            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))\n            loss = loss.sum() / n_train_samples\n\n        loss.backward()\n        trainer.step(batch_size=1)\n\n        if epoch >= 3:\n            dur.append(time.time() - t0)\n            acc = evaluate(model, features, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss.asscalar(),\n                    acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n    # test set accuracy\n    acc = evaluate(model, features, labels, test_mask)\n    print(\"Test accuracy {:.2%}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GCN\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        
\"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden gcn units\"\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=1, help=\"number of hidden gcn layers\"\n    )\n    parser.add_argument(\n        \"--normalization\",\n        choices=[\"sym\", \"left\"],\n        default=None,\n        help=\"graph normalization types (default=None)\",\n    )\n    parser.add_argument(\n        \"--self-loop\",\n        action=\"store_true\",\n        help=\"graph self-loop (default=False)\",\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 loss\"\n    )\n    args = parser.parse_args()\n\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/mxnet/gcn/gcn_mp.py",
    "content": "\"\"\"GCN using basic message passing\n\nReferences:\n- Semi-Supervised Classification with Graph Convolutional Networks\n- Paper: https://arxiv.org/abs/1609.02907\n- Code: https://github.com/tkipf/gcn\n\"\"\"\nimport mxnet as mx\nfrom mxnet import gluon\n\n\ndef gcn_msg(edge):\n    msg = edge.src[\"h\"] * edge.src[\"norm\"]\n    return {\"m\": msg}\n\n\ndef gcn_reduce(node):\n    accum = mx.nd.sum(node.mailbox[\"m\"], 1) * node.data[\"norm\"]\n    return {\"h\": accum}\n\n\nclass NodeUpdate(gluon.Block):\n    def __init__(self, out_feats, activation=None, bias=True):\n        super(NodeUpdate, self).__init__()\n        with self.name_scope():\n            if bias:\n                self.bias = self.params.get(\n                    \"bias\", shape=(out_feats,), init=mx.init.Zero()\n                )\n            else:\n                self.bias = None\n        self.activation = activation\n\n    def forward(self, node):\n        h = node.data[\"h\"]\n        if self.bias is not None:\n            h = h + self.bias.data(h.context)\n        if self.activation:\n            h = self.activation(h)\n        return {\"h\": h}\n\n\nclass GCNLayer(gluon.Block):\n    def __init__(self, g, in_feats, out_feats, activation, dropout, bias=True):\n        super(GCNLayer, self).__init__()\n        self.g = g\n        self.dropout = dropout\n        with self.name_scope():\n            self.weight = self.params.get(\n                \"weight\", shape=(in_feats, out_feats), init=mx.init.Xavier()\n            )\n            self.node_update = NodeUpdate(out_feats, activation, bias)\n\n    def forward(self, h):\n        if self.dropout:\n            h = mx.nd.Dropout(h, p=self.dropout)\n        h = mx.nd.dot(h, self.weight.data(h.context))\n        self.g.ndata[\"h\"] = h\n        self.g.update_all(gcn_msg, gcn_reduce, self.node_update)\n        h = self.g.ndata.pop(\"h\")\n        return h\n\n\nclass GCN(gluon.Block):\n    def __init__(\n        self, g, in_feats, 
n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(GCN, self).__init__()\n        self.layers = gluon.nn.Sequential()\n        # input layer\n        self.layers.add(GCNLayer(g, in_feats, n_hidden, activation, 0))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.add(\n                GCNLayer(g, n_hidden, n_hidden, activation, dropout)\n            )\n        # output layer\n        self.layers.add(GCNLayer(g, n_hidden, n_classes, None, dropout))\n\n    def forward(self, features):\n        h = features\n        for layer in self.layers:\n            h = layer(h)\n        return h\n"
  },
  {
    "path": "examples/mxnet/gcn/train.py",
    "content": "\"\"\"Training GCN model on citation graphs.\"\"\"\nimport argparse\nimport time\n\nimport dgl\n\nimport mxnet as mx\nimport numpy as np\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\nfrom gcn import GCN\nfrom mxnet import gluon\n\n# from gcn_mp import GCN\n# from gcn_spmv import GCN\n\n\ndef evaluate(model, features, labels, mask):\n    pred = model(features).argmax(axis=1)\n    accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()\n    return accuracy.asscalar()\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n        ctx = mx.cpu(0)\n    else:\n        cuda = True\n        ctx = mx.gpu(args.gpu)\n        g = g.int().to(ctx)\n\n    features = g.ndata[\"feat\"]\n    labels = mx.nd.array(g.ndata[\"label\"], dtype=\"float32\", ctx=ctx)\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = data.graph.number_of_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.sum().asscalar(),\n            val_mask.sum().asscalar(),\n            test_mask.sum().asscalar(),\n        )\n    )\n\n    # add self loop\n    if args.self_loop:\n        g = dgl.remove_self_loop(g)\n        g = dgl.add_self_loop(g)\n    # normalization\n    degs = g.in_degrees().astype(\"float32\")\n    norm = 
mx.nd.power(degs, -0.5)\n    if cuda:\n        norm = norm.as_in_context(ctx)\n    g.ndata[\"norm\"] = mx.nd.expand_dims(norm, 1)\n\n    model = GCN(\n        g,\n        in_feats,\n        args.n_hidden,\n        n_classes,\n        args.n_layers,\n        mx.nd.relu,\n        args.dropout,\n    )\n    model.initialize(ctx=ctx)\n    n_train_samples = train_mask.sum().asscalar()\n    loss_fcn = gluon.loss.SoftmaxCELoss()\n\n    # use optimizer\n    print(model.collect_params())\n    trainer = gluon.Trainer(\n        model.collect_params(),\n        \"adam\",\n        {\"learning_rate\": args.lr, \"wd\": args.weight_decay},\n    )\n\n    # initialize graph\n    dur = []\n    for epoch in range(args.n_epochs):\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        with mx.autograd.record():\n            pred = model(features)\n            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))\n            loss = loss.sum() / n_train_samples\n\n        loss.backward()\n        trainer.step(batch_size=1)\n\n        if epoch >= 3:\n            loss.asscalar()\n            dur.append(time.time() - t0)\n            acc = evaluate(model, features, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss.asscalar(),\n                    acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n    # test set accuracy\n    acc = evaluate(model, features, labels, test_mask)\n    print(\"Test accuracy {:.2%}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GCN\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"cora\",\n        help=\"Dataset name ('cora', 'citeseer', 'pubmed').\",\n    )\n    
parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=3e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden gcn units\"\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=1, help=\"number of hidden gcn layers\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 loss\"\n    )\n    parser.add_argument(\n        \"--self-loop\",\n        action=\"store_true\",\n        help=\"graph self-loop (default=False)\",\n    )\n    parser.set_defaults(self_loop=False)\n    args = parser.parse_args()\n\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/mxnet/gin/README.md",
    "content": "Graph Isomorphism Network (GIN)\n============\n\n- Paper link: [arXiv](https://arxiv.org/abs/1810.00826) [OpenReview](https://openreview.net/forum?id=ryGs6iA5Km) \n- Author's code repo: [https://github.com/weihua916/powerful-gnns](https://github.com/weihua916/powerful-gnns).\n\nDependencies\n------------\n- MXNet 1.5+\n- sklearn\n- tqdm\n\n``bash\npip install torch sklearn tqdm\n``\n\nHow to run\n----------\n\nAn experiment on the GIN in default settings can be run with\n\n```bash\nDGLBACKEND=mxnet python main.py\n```\n\nAn experiment on the GIN in customized settings can be run with\n```bash\nDGLBACKEND=mxnet python main.py [--device 0 | --disable-cuda] --dataset COLLAB \\\n               --graph_pooling_type max --neighbor_pooling_type sum\n```\n\nResults\n-------\n\nRun with following with the double SUM pooling way:\n(tested dataset: \"MUTAG\"(default), \"COLLAB\", \"IMDBBINARY\", \"IMDBMULTI\")\n```bash\nDGLBACKEND=mxnet python main.py --dataset MUTAG --device 0  \\\n                --graph_pooling_type sum --neighbor_pooling_type sum\n```\n\n"
  },
  {
    "path": "examples/mxnet/gin/dataloader.py",
    "content": "\"\"\"\nMxNet compatible dataloader\n\"\"\"\n\nimport math\n\nimport dgl\n\nimport numpy as np\nfrom mxnet import nd\nfrom mxnet.gluon.data import DataLoader, Sampler\nfrom sklearn.model_selection import StratifiedKFold\n\n\nclass SubsetRandomSampler(Sampler):\n    def __init__(self, indices):\n        self.indices = indices\n\n    def __iter__(self):\n        return iter(\n            [self.indices[i] for i in np.random.permutation(len(self.indices))]\n        )\n\n    def __len__(self):\n        return len(self.indices)\n\n\n# default collate function\ndef collate(samples):\n    # The input `samples` is a list of pairs (graph, label).\n    graphs, labels = map(list, zip(*samples))\n    for g in graphs:\n        # deal with node feats\n        for key in g.node_attr_schemes().keys():\n            g.ndata[key] = nd.array(g.ndata[key])\n        # no edge feats\n    batched_graph = dgl.batch(graphs)\n    labels = [nd.reshape(label, (1,)) for label in labels]\n    labels = nd.concat(*labels, dim=0)\n    return batched_graph, labels\n\n\nclass GraphDataLoader:\n    def __init__(\n        self,\n        dataset,\n        batch_size,\n        collate_fn=collate,\n        seed=0,\n        shuffle=True,\n        split_name=\"fold10\",\n        fold_idx=0,\n        split_ratio=0.7,\n    ):\n        self.shuffle = shuffle\n        self.seed = seed\n\n        labels = [l for _, l in dataset]\n\n        if split_name == \"fold10\":\n            train_idx, valid_idx = self._split_fold10(\n                labels, fold_idx, seed, shuffle\n            )\n        elif split_name == \"rand\":\n            train_idx, valid_idx = self._split_rand(\n                labels, split_ratio, seed, shuffle\n            )\n        else:\n            raise NotImplementedError()\n\n        train_sampler = SubsetRandomSampler(train_idx)\n        valid_sampler = SubsetRandomSampler(valid_idx)\n\n        self.train_loader = DataLoader(\n            dataset,\n            
sampler=train_sampler,\n            batch_size=batch_size,\n            batchify_fn=collate_fn,\n        )\n        self.valid_loader = DataLoader(\n            dataset,\n            sampler=valid_sampler,\n            batch_size=batch_size,\n            batchify_fn=collate_fn,\n        )\n\n    def train_valid_loader(self):\n        return self.train_loader, self.valid_loader\n\n    def _split_fold10(self, labels, fold_idx=0, seed=0, shuffle=True):\n        \"\"\"10 flod\"\"\"\n        assert 0 <= fold_idx and fold_idx < 10, print(\n            \"fold_idx must be from 0 to 9.\"\n        )\n\n        skf = StratifiedKFold(n_splits=10, shuffle=shuffle, random_state=seed)\n        idx_list = []\n        for idx in skf.split(\n            np.zeros(len(labels)), [label.asnumpy() for label in labels]\n        ):  # split(x, y)\n            idx_list.append(idx)\n        train_idx, valid_idx = idx_list[fold_idx]\n\n        print(\"train_set : test_set = %d : %d\", len(train_idx), len(valid_idx))\n\n        return train_idx, valid_idx\n\n    def _split_rand(self, labels, split_ratio=0.7, seed=0, shuffle=True):\n        num_entries = len(labels)\n        indices = list(range(num_entries))\n        np.random.seed(seed)\n        np.random.shuffle(indices)\n        split = int(math.floor(split_ratio * num_entries))\n        train_idx, valid_idx = indices[:split], indices[split:]\n\n        print(\"train_set : test_set = %d : %d\", len(train_idx), len(valid_idx))\n\n        return train_idx, valid_idx\n"
  },
  {
    "path": "examples/mxnet/gin/gin.py",
    "content": "\"\"\"\nHow Powerful are Graph Neural Networks\nhttps://arxiv.org/abs/1810.00826\nhttps://openreview.net/forum?id=ryGs6iA5Km\nAuthor's implementation: https://github.com/weihua916/powerful-gnns\n\"\"\"\n\nimport mxnet as mx\n\nfrom dgl.nn.mxnet.conv import GINConv\nfrom dgl.nn.mxnet.glob import AvgPooling, MaxPooling, SumPooling\nfrom mxnet import gluon, nd\nfrom mxnet.gluon import nn\n\n\nclass ApplyNodeFunc(nn.Block):\n    \"\"\"Update the node feature hv with MLP, BN and ReLU.\"\"\"\n\n    def __init__(self, mlp):\n        super(ApplyNodeFunc, self).__init__()\n        with self.name_scope():\n            self.mlp = mlp\n            self.bn = nn.BatchNorm(in_channels=self.mlp.output_dim)\n\n    def forward(self, h):\n        h = self.mlp(h)\n        h = self.bn(h)\n        h = nd.relu(h)\n        return h\n\n\nclass MLP(nn.Block):\n    \"\"\"MLP with linear output\"\"\"\n\n    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):\n        \"\"\"MLP layers construction\n\n        Paramters\n        ---------\n        num_layers: int\n            The number of linear layers\n        input_dim: int\n            The dimensionality of input features\n        hidden_dim: int\n            The dimensionality of hidden units at ALL layers\n        output_dim: int\n            The number of classes for prediction\n        \"\"\"\n        super(MLP, self).__init__()\n        self.linear_or_not = True\n        self.num_layers = num_layers\n        self.output_dim = output_dim\n\n        with self.name_scope():\n            if num_layers < 1:\n                raise ValueError(\"number of layers should be positive!\")\n            elif num_layers == 1:\n                # Linear model\n                self.linear = nn.Dense(output_dim, in_units=input_dim)\n            else:\n                self.linear_or_not = False\n                self.linears = nn.Sequential()\n                self.batch_norms = nn.Sequential()\n\n                
self.linears.add(nn.Dense(hidden_dim, in_units=input_dim))\n                for layer in range(num_layers - 2):\n                    self.linears.add(nn.Dense(hidden_dim, in_units=hidden_dim))\n                self.linears.add(nn.Dense(output_dim, in_units=hidden_dim))\n\n                for layer in range(num_layers - 1):\n                    self.batch_norms.add(nn.BatchNorm(in_channels=hidden_dim))\n\n    def forward(self, x):\n        if self.linear_or_not:\n            return self.linear(x)\n        else:\n            h = x\n            for i in range(self.num_layers - 1):\n                h = nd.relu(self.batch_norms[i](self.linears[i](h)))\n            return self.linears[-1](h)\n\n\nclass GIN(nn.Block):\n    \"\"\"GIN model\"\"\"\n\n    def __init__(\n        self,\n        num_layers,\n        num_mlp_layers,\n        input_dim,\n        hidden_dim,\n        output_dim,\n        final_dropout,\n        learn_eps,\n        graph_pooling_type,\n        neighbor_pooling_type,\n    ):\n        \"\"\"model parameters setting\n\n        Paramters\n        ---------\n        num_layers: int\n            The number of linear layers in the neural network\n        num_mlp_layers: int\n            The number of linear layers in mlps\n        input_dim: int\n            The dimensionality of input features\n        hidden_dim: int\n            The dimensionality of hidden units at ALL layers\n        output_dim: int\n            The number of classes for prediction\n        final_dropout: float\n            dropout ratio on the final linear layer\n        learn_eps: boolean\n            If True, learn epsilon to distinguish center nodes from neighbors\n            If False, aggregate neighbors and center nodes altogether.\n        neighbor_pooling_type: str\n            how to aggregate neighbors (sum, mean, or max)\n        graph_pooling_type: str\n            how to aggregate entire nodes in a graph (sum, mean or max)\n\n        \"\"\"\n        super(GIN, 
self).__init__()\n        self.num_layers = num_layers\n        self.learn_eps = learn_eps\n\n        with self.name_scope():\n            # List of MLPs\n            self.ginlayers = nn.Sequential()\n            self.batch_norms = nn.Sequential()\n\n            for i in range(self.num_layers - 1):\n                if i == 0:\n                    mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim)\n                else:\n                    mlp = MLP(\n                        num_mlp_layers, hidden_dim, hidden_dim, hidden_dim\n                    )\n\n                self.ginlayers.add(\n                    GINConv(\n                        ApplyNodeFunc(mlp),\n                        neighbor_pooling_type,\n                        0,\n                        self.learn_eps,\n                    )\n                )\n                self.batch_norms.add(nn.BatchNorm(in_channels=hidden_dim))\n\n            self.linears_prediction = nn.Sequential()\n\n            for i in range(num_layers):\n                if i == 0:\n                    self.linears_prediction.add(\n                        nn.Dense(output_dim, in_units=input_dim)\n                    )\n                else:\n                    self.linears_prediction.add(\n                        nn.Dense(output_dim, in_units=hidden_dim)\n                    )\n\n            self.drop = nn.Dropout(final_dropout)\n\n            if graph_pooling_type == \"sum\":\n                self.pool = SumPooling()\n            elif graph_pooling_type == \"mean\":\n                self.pool = AvgPooling()\n            elif graph_pooling_type == \"max\":\n                self.pool = MaxPooling()\n            else:\n                raise NotImplementedError\n\n    def forward(self, g, h):\n        hidden_rep = [h]\n\n        for i in range(self.num_layers - 1):\n            h = self.ginlayers[i](g, h)\n            h = self.batch_norms[i](h)\n            h = nd.relu(h)\n            hidden_rep.append(h)\n\n        
score_over_layer = 0\n        # perform pooling over all nodes in each graph in every layer\n        for i, h in enumerate(hidden_rep):\n            pooled_h = self.pool(g, h)\n            score_over_layer = score_over_layer + self.drop(\n                self.linears_prediction[i](pooled_h)\n            )\n\n        return score_over_layer\n"
  },
  {
    "path": "examples/mxnet/gin/main.py",
    "content": "import sys\nfrom parser import Parser\n\nimport mxnet as mx\nimport numpy as np\nfrom dataloader import collate, GraphDataLoader\n\nfrom dgl.data.gindt import GINDataset\nfrom gin import GIN\nfrom mxnet import gluon, nd\nfrom mxnet.gluon import nn\nfrom tqdm import tqdm\n\n\ndef train(args, net, trainloader, trainer, criterion, epoch):\n    running_loss = 0\n    total_iters = len(trainloader)\n    # setup the offset to avoid the overlap with mouse cursor\n    bar = tqdm(range(total_iters), unit=\"batch\", position=2, file=sys.stdout)\n\n    for pos, (graphs, labels) in zip(bar, trainloader):\n        # batch graphs will be shipped to device in forward part of model\n        labels = labels.as_in_context(args.device)\n        feat = graphs.ndata[\"attr\"].as_in_context(args.device)\n\n        with mx.autograd.record():\n            graphs = graphs.to(args.device)\n            outputs = net(graphs, feat)\n            loss = criterion(outputs, labels)\n            loss = loss.sum() / len(labels)\n\n        running_loss += loss.asscalar()\n\n        # backprop\n        loss.backward()\n        trainer.step(batch_size=1)\n\n        # report\n        bar.set_description(\"epoch-{}\".format(epoch))\n    bar.close()\n    # the final batch will be aligned\n    running_loss = running_loss / total_iters\n\n    return running_loss\n\n\ndef eval_net(args, net, dataloader, criterion):\n    total = 0\n    total_loss = 0\n    total_correct = 0\n\n    for data in dataloader:\n        graphs, labels = data\n        labels = labels.as_in_context(args.device)\n        feat = graphs.ndata[\"attr\"].as_in_context(args.device)\n\n        total += len(labels)\n        graphs = graphs.to(args.device)\n        outputs = net(graphs, feat)\n        predicted = nd.argmax(outputs, axis=1)\n        predicted = predicted.astype(\"int64\")\n\n        total_correct += (predicted == labels).sum().asscalar()\n        loss = criterion(outputs, labels)\n        # 
crossentropy(reduce=True) for default\n        total_loss += loss.sum().asscalar()\n\n    loss, acc = 1.0 * total_loss / total, 1.0 * total_correct / total\n\n    return loss, acc\n\n\ndef main(args):\n    # set up seeds, args.seed supported\n    mx.random.seed(0)\n    np.random.seed(seed=0)\n\n    if args.device >= 0:\n        args.device = mx.gpu(args.device)\n    else:\n        args.device = mx.cpu()\n\n    dataset = GINDataset(args.dataset, not args.learn_eps)\n\n    trainloader, validloader = GraphDataLoader(\n        dataset,\n        batch_size=args.batch_size,\n        collate_fn=collate,\n        seed=args.seed,\n        shuffle=True,\n        split_name=\"fold10\",\n        fold_idx=args.fold_idx,\n    ).train_valid_loader()\n    # or split_name='rand', split_ratio=0.7\n\n    model = GIN(\n        args.num_layers,\n        args.num_mlp_layers,\n        dataset.dim_nfeats,\n        args.hidden_dim,\n        dataset.gclasses,\n        args.final_dropout,\n        args.learn_eps,\n        args.graph_pooling_type,\n        args.neighbor_pooling_type,\n    )\n    model.initialize(ctx=args.device)\n\n    criterion = gluon.loss.SoftmaxCELoss()\n\n    print(model.collect_params())\n    lr_scheduler = mx.lr_scheduler.FactorScheduler(50, 0.5)\n    trainer = gluon.Trainer(\n        model.collect_params(), \"adam\", {\"lr_scheduler\": lr_scheduler}\n    )\n\n    # it's not cost-effective to hanle the cursor and init 0\n    # https://stackoverflow.com/a/23121189\n    tbar = tqdm(\n        range(args.epochs), unit=\"epoch\", position=3, ncols=0, file=sys.stdout\n    )\n    vbar = tqdm(\n        range(args.epochs), unit=\"epoch\", position=4, ncols=0, file=sys.stdout\n    )\n    lrbar = tqdm(\n        range(args.epochs), unit=\"epoch\", position=5, ncols=0, file=sys.stdout\n    )\n\n    for epoch, _, _ in zip(tbar, vbar, lrbar):\n        train(args, model, trainloader, trainer, criterion, epoch)\n\n        train_loss, train_acc = eval_net(args, model, trainloader, 
criterion)\n        tbar.set_description(\n            \"train set - average loss: {:.4f}, accuracy: {:.0f}%\".format(\n                train_loss, 100.0 * train_acc\n            )\n        )\n\n        valid_loss, valid_acc = eval_net(args, model, validloader, criterion)\n        vbar.set_description(\n            \"valid set - average loss: {:.4f}, accuracy: {:.0f}%\".format(\n                valid_loss, 100.0 * valid_acc\n            )\n        )\n\n        if not args.filename == \"\":\n            with open(args.filename, \"a\") as f:\n                f.write(\n                    \"%s %s %s %s\"\n                    % (\n                        args.dataset,\n                        args.learn_eps,\n                        args.neighbor_pooling_type,\n                        args.graph_pooling_type,\n                    )\n                )\n                f.write(\"\\n\")\n                f.write(\n                    \"%f %f %f %f\"\n                    % (train_loss, train_acc, valid_loss, valid_acc)\n                )\n                f.write(\"\\n\")\n\n        lrbar.set_description(\n            \"Learning eps with learn_eps={}: {}\".format(\n                args.learn_eps,\n                [\n                    layer.eps.data(args.device).asscalar()\n                    for layer in model.ginlayers\n                ],\n            )\n        )\n\n    tbar.close()\n    vbar.close()\n    lrbar.close()\n\n\nif __name__ == \"__main__\":\n    args = Parser(description=\"GIN\").args\n    print(\"show all arguments configuration...\")\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/mxnet/gin/parser.py",
    "content": "\"\"\"Parser for arguments\n\nPut all arguments in one file and group similar arguments\n\"\"\"\nimport argparse\n\n\nclass Parser:\n    def __init__(self, description):\n        \"\"\"\n        arguments parser\n        \"\"\"\n        self.parser = argparse.ArgumentParser(description=description)\n        self.args = None\n        self._parse()\n\n    def _parse(self):\n        # dataset\n        self.parser.add_argument(\n            \"--dataset\",\n            type=str,\n            default=\"MUTAG\",\n            help=\"name of dataset (default: MUTAG)\",\n        )\n        self.parser.add_argument(\n            \"--batch_size\",\n            type=int,\n            default=32,\n            help=\"batch size for training and validation (default: 32)\",\n        )\n        self.parser.add_argument(\n            \"--fold_idx\",\n            type=int,\n            default=0,\n            help=\"the index(<10) of fold in 10-fold validation.\",\n        )\n        self.parser.add_argument(\n            \"--filename\", type=str, default=\"\", help=\"output file\"\n        )\n\n        # device\n        self.parser.add_argument(\n            \"--disable-cuda\", action=\"store_true\", help=\"Disable CUDA\"\n        )\n        self.parser.add_argument(\n            \"--device\",\n            type=int,\n            default=0,\n            help=\"which gpu device to use (default: 0)\",\n        )\n\n        # net\n        self.parser.add_argument(\n            \"--net\", type=str, default=\"gin\", help=\"gnn net (default: gin)\"\n        )\n        self.parser.add_argument(\n            \"--num_layers\",\n            type=int,\n            default=5,\n            help=\"number of layers (default: 5)\",\n        )\n        self.parser.add_argument(\n            \"--num_mlp_layers\",\n            type=int,\n            default=2,\n            help=\"number of MLP layers(default: 2). 
1 means linear model.\",\n        )\n        self.parser.add_argument(\n            \"--hidden_dim\",\n            type=int,\n            default=64,\n            help=\"number of hidden units (default: 64)\",\n        )\n\n        # graph\n        self.parser.add_argument(\n            \"--graph_pooling_type\",\n            type=str,\n            default=\"sum\",\n            choices=[\"sum\", \"mean\", \"max\"],\n            help=\"type of graph pooling: sum, mean or max\",\n        )\n        self.parser.add_argument(\n            \"--neighbor_pooling_type\",\n            type=str,\n            default=\"sum\",\n            choices=[\"sum\", \"mean\", \"max\"],\n            help=\"type of neighboring pooling: sum, mean or max\",\n        )\n        self.parser.add_argument(\n            \"--learn_eps\",\n            action=\"store_true\",\n            help=\"learn the epsilon weighting\",\n        )\n        self.parser.add_argument(\n            \"--degree_as_tag\",\n            action=\"store_true\",\n            help=\"take the degree of nodes as input feature\",\n        )\n\n        # learning\n        self.parser.add_argument(\n            \"--seed\", type=int, default=0, help=\"random seed (default: 0)\"\n        )\n        self.parser.add_argument(\n            \"--epochs\",\n            type=int,\n            default=350,\n            help=\"number of epochs to train (default: 350)\",\n        )\n        self.parser.add_argument(\n            \"--lr\",\n            type=float,\n            default=0.01,\n            help=\"learning rate (default: 0.01)\",\n        )\n        self.parser.add_argument(\n            \"--final_dropout\",\n            type=float,\n            default=0.5,\n            help=\"final layer dropout (default: 0.5)\",\n        )\n\n        # done\n        self.args = self.parser.parse_args()\n"
  },
  {
    "path": "examples/mxnet/graphsage/README.md",
    "content": "Inductive Representation Learning on Large Graphs (GraphSAGE)\n============\n\n- Paper link: [http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf](http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf)\n- Author's code repo: [https://github.com/williamleif/graphsage-simple](https://github.com/williamleif/graphsage-simple). Note that the original code is \nsimple reference implementation of GraphSAGE.\n\nRequirements\n------------\n- requests\n\n``bash\npip install requests\n``\n\n\nResults\n-------\n\nRun with following (available dataset: \"cora\", \"citeseer\", \"pubmed\")\n```bash\npython3 main.py --dataset cora --gpu 0\n```\n\n* cora: ~0.817\n* citeseer: ~0.699\n* pubmed: ~0.790"
  },
  {
    "path": "examples/mxnet/graphsage/main.py",
    "content": "\"\"\"\nInductive Representation Learning on Large Graphs\nPaper: http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf\nCode: https://github.com/williamleif/graphsage-simple\nSimple reference implementation of GraphSAGE.\n\"\"\"\nimport argparse\nimport time\n\nimport dgl\n\nimport mxnet as mx\nimport networkx as nx\nimport numpy as np\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom dgl.nn.mxnet.conv import SAGEConv\nfrom mxnet import gluon, nd\nfrom mxnet.gluon import nn\n\n\nclass GraphSAGE(nn.Block):\n    def __init__(\n        self,\n        g,\n        in_feats,\n        n_hidden,\n        n_classes,\n        n_layers,\n        activation,\n        dropout,\n        aggregator_type,\n    ):\n        super(GraphSAGE, self).__init__()\n        self.g = g\n\n        with self.name_scope():\n            self.layers = nn.Sequential()\n            # input layer\n            self.layers.add(\n                SAGEConv(\n                    in_feats,\n                    n_hidden,\n                    aggregator_type,\n                    feat_drop=dropout,\n                    activation=activation,\n                )\n            )\n            # hidden layers\n            for i in range(n_layers - 1):\n                self.layers.add(\n                    SAGEConv(\n                        n_hidden,\n                        n_hidden,\n                        aggregator_type,\n                        feat_drop=dropout,\n                        activation=activation,\n                    )\n                )\n            # output layer\n            self.layers.add(\n                SAGEConv(\n                    n_hidden,\n                    n_classes,\n                    aggregator_type,\n                    feat_drop=dropout,\n                    activation=None,\n                )\n            )  # activation None\n\n    def 
forward(self, features):\n        h = features\n        for layer in self.layers:\n            h = layer(self.g, h)\n        return h\n\n\ndef evaluate(model, features, labels, mask):\n    pred = model(features).argmax(axis=1)\n    accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()\n    return accuracy.asscalar()\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n        ctx = mx.cpu(0)\n    else:\n        cuda = True\n        ctx = mx.gpu(args.gpu)\n        g = g.int().to(ctx)\n\n    features = g.ndata[\"feat\"]\n    labels = mx.nd.array(g.ndata[\"label\"], dtype=\"float32\", ctx=ctx)\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = data.graph.number_of_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.sum().asscalar(),\n            val_mask.sum().asscalar(),\n            test_mask.sum().asscalar(),\n        )\n    )\n\n    # add self loop\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n    n_edges = g.number_of_edges()\n\n    # create GraphSAGE model\n    model = GraphSAGE(\n        g,\n        in_feats,\n        args.n_hidden,\n        n_classes,\n        args.n_layers,\n        nd.relu,\n        args.dropout,\n        args.aggregator_type,\n    )\n\n    model.initialize(ctx=ctx)\n    n_train_samples = 
train_mask.sum().asscalar()\n    loss_fcn = gluon.loss.SoftmaxCELoss()\n\n    print(model.collect_params())\n    trainer = gluon.Trainer(\n        model.collect_params(),\n        \"adam\",\n        {\"learning_rate\": args.lr, \"wd\": args.weight_decay},\n    )\n\n    # initialize graph\n    dur = []\n    for epoch in range(args.n_epochs):\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        with mx.autograd.record():\n            pred = model(features)\n            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))\n            loss = loss.sum() / n_train_samples\n\n        loss.backward()\n        trainer.step(batch_size=1)\n\n        if epoch >= 3:\n            loss.asscalar()\n            dur.append(time.time() - t0)\n            acc = evaluate(model, features, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss.asscalar(),\n                    acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n    # test set accuracy\n    acc = evaluate(model, features, labels, test_mask)\n    print(\"Test accuracy {:.2%}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GraphSAGE\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden gcn units\"\n    )\n    parser.add_argument(\n    
    \"--n-layers\", type=int, default=1, help=\"number of hidden gcn layers\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 loss\"\n    )\n    parser.add_argument(\n        \"--aggregator-type\",\n        type=str,\n        default=\"gcn\",\n        help=\"Aggregator type: mean/gcn/pool/lstm\",\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/mxnet/monet/README.md",
    "content": "MoNet\n=====\n\n- paper link: [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/pdf/1611.08402.pdf)\n\nDependencies\n============\n\n- MXNet 1.5+\n\nResults\n=======\n\n## Citation networks\nRun with following (available dataset: \"cora\", \"citeseer\", \"pubmed\")\n```bash\npython3 citation.py --dataset cora --gpu 0\n```\n\n- Cora: ~0.814\n- Pubmed: ~0.748\n"
  },
  {
    "path": "examples/mxnet/monet/citation.py",
    "content": "import argparse\nimport time\n\nimport dgl\n\nimport mxnet as mx\nimport networkx as nx\nimport numpy as np\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom dgl.nn.mxnet.conv import GMMConv\nfrom mxnet import gluon, nd\nfrom mxnet.gluon import nn\n\n\nclass MoNet(nn.Block):\n    def __init__(\n        self,\n        g,\n        in_feats,\n        n_hidden,\n        out_feats,\n        n_layers,\n        dim,\n        n_kernels,\n        dropout,\n    ):\n        super(MoNet, self).__init__()\n        self.g = g\n        with self.name_scope():\n            self.layers = nn.Sequential()\n            self.pseudo_proj = nn.Sequential()\n\n            # Input layer\n            self.layers.add(GMMConv(in_feats, n_hidden, dim, n_kernels))\n            self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation=\"tanh\"))\n\n            # Hidden layer\n            for _ in range(n_layers - 1):\n                self.layers.add(GMMConv(n_hidden, n_hidden, dim, n_kernels))\n                self.pseudo_proj.add(\n                    nn.Dense(dim, in_units=2, activation=\"tanh\")\n                )\n\n            # Output layer\n            self.layers.add(GMMConv(n_hidden, out_feats, dim, n_kernels))\n            self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation=\"tanh\"))\n\n            self.dropout = nn.Dropout(dropout)\n\n    def forward(self, feat, pseudo):\n        h = feat\n        for i in range(len(self.layers)):\n            if i > 0:\n                h = self.dropout(h)\n            h = self.layers[i](self.g, h, self.pseudo_proj[i](pseudo))\n        return h\n\n\ndef evaluate(model, features, pseudo, labels, mask):\n    pred = model(features, pseudo).argmax(axis=1)\n    accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()\n    return accuracy.asscalar()\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n     
   data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n        ctx = mx.cpu(0)\n    else:\n        cuda = True\n        ctx = mx.gpu(args.gpu)\n        g = g.to(ctx)\n\n    features = g.ndata[\"feat\"]\n    labels = mx.nd.array(g.ndata[\"label\"], dtype=\"float32\", ctx=ctx)\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = data.graph.number_of_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.sum().asscalar(),\n            val_mask.sum().asscalar(),\n            test_mask.sum().asscalar(),\n        )\n    )\n\n    # add self loop\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    n_edges = g.number_of_edges()\n    us, vs = g.edges()\n    us = us.asnumpy()\n    vs = vs.asnumpy()\n    pseudo = []\n    for i in range(g.number_of_edges()):\n        pseudo.append(\n            [1 / np.sqrt(g.in_degrees(us[i])), 1 / np.sqrt(g.in_degrees(vs[i]))]\n        )\n    pseudo = nd.array(pseudo, ctx=ctx)\n\n    # create GraphSAGE model\n    model = MoNet(\n        g,\n        in_feats,\n        args.n_hidden,\n        n_classes,\n        args.n_layers,\n        args.pseudo_dim,\n        args.n_kernels,\n        args.dropout,\n    )\n    model.initialize(ctx=ctx)\n    n_train_samples = train_mask.sum().asscalar()\n    loss_fcn = gluon.loss.SoftmaxCELoss()\n\n    print(model.collect_params())\n    trainer = gluon.Trainer(\n        
model.collect_params(),\n        \"adam\",\n        {\"learning_rate\": args.lr, \"wd\": args.weight_decay},\n    )\n\n    # initialize graph\n    dur = []\n    for epoch in range(args.n_epochs):\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        with mx.autograd.record():\n            pred = model(features, pseudo)\n            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))\n            loss = loss.sum() / n_train_samples\n\n        loss.backward()\n        trainer.step(batch_size=1)\n\n        if epoch >= 3:\n            loss.asscalar()\n            dur.append(time.time() - t0)\n            acc = evaluate(model, features, pseudo, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss.asscalar(),\n                    acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n    # test set accuracy\n    acc = evaluate(model, features, pseudo, labels, test_mask)\n    print(\"Test accuracy {:.2%}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"MoNet on citation network\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden gcn units\"\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=1, help=\"number of hidden gcn layers\"\n    )\n    
parser.add_argument(\n        \"--pseudo-dim\",\n        type=int,\n        default=2,\n        help=\"Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed\",\n    )\n    parser.add_argument(\n        \"--n-kernels\",\n        type=int,\n        default=3,\n        help=\"Number of kernels in GMMConv layer\",\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-5, help=\"Weight for L2 loss\"\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/mxnet/rgcn/README.md",
    "content": "# Relational-GCN\n\n* Paper: [https://arxiv.org/abs/1703.06103](https://arxiv.org/abs/1703.06103)\n* Author's code for entity classification: [https://github.com/tkipf/relational-gcn](https://github.com/tkipf/relational-gcn)\n* Author's code for link prediction: [https://github.com/MichSchli/RelationPrediction](https://github.com/MichSchli/RelationPrediction)\n\n### Dependencies\nTwo extra python packages are needed for this example:\n\n- MXNet nightly build\n- requests\n- rdflib\n- pandas\n\n```bash\npip install mxnet --pre\npip install requests rdflib pandas\n```\n\nExample code was tested with rdflib 4.2.2 and pandas 0.23.4\n\n### Entity Classification\nAIFB: accuracy 97.22% (5 runs, DGL), 95.83% (paper)\n```\nDGLBACKEND=mxnet python3 entity_classify.py -d aifb --testing --gpu 0\n```\n\nMUTAG: accuracy 70.59% (5 runs, DGL), 73.23% (paper)\n```\nDGLBACKEND=mxnet python3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 40 --testing --gpu 0\n```\n\nBGS: accuracy 86.21% (5 runs, DGL, n-basese=20), 83.10% (paper)\n```\nDGLBACKEND=mxnet python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 20 --testing --gpu 0\n```\n"
  },
  {
    "path": "examples/mxnet/rgcn/entity_classify.py",
    "content": "\"\"\"\nModeling Relational Data with Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1703.06103\nCode: https://github.com/tkipf/relational-gcn\n\nDifference compared to tkipf/relation-gcn\n* l2norm applied to all weights\n* remove nodes that won't be touched\n\"\"\"\n\nimport argparse\nimport time\nfrom functools import partial\n\nimport dgl\nimport mxnet as mx\nimport mxnet.ndarray as F\nimport numpy as np\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\nfrom dgl.nn.mxnet import RelGraphConv\n\nfrom model import BaseRGCN\nfrom mxnet import gluon\n\n\nclass EntityClassify(BaseRGCN):\n    def build_input_layer(self):\n        return RelGraphConv(\n            self.num_nodes,\n            self.h_dim,\n            self.num_rels,\n            \"basis\",\n            self.num_bases,\n            activation=F.relu,\n            self_loop=self.use_self_loop,\n            dropout=self.dropout,\n        )\n\n    def build_hidden_layer(self, idx):\n        return RelGraphConv(\n            self.h_dim,\n            self.h_dim,\n            self.num_rels,\n            \"basis\",\n            self.num_bases,\n            activation=F.relu,\n            self_loop=self.use_self_loop,\n            dropout=self.dropout,\n        )\n\n    def build_output_layer(self):\n        return RelGraphConv(\n            self.h_dim,\n            self.out_dim,\n            self.num_rels,\n            \"basis\",\n            self.num_bases,\n            activation=None,\n            self_loop=self.use_self_loop,\n        )\n\n\ndef main(args):\n    # load graph data\n    if args.dataset == \"aifb\":\n        dataset = AIFBDataset()\n    elif args.dataset == \"mutag\":\n        dataset = MUTAGDataset()\n    elif args.dataset == \"bgs\":\n        dataset = BGSDataset()\n    elif args.dataset == \"am\":\n        dataset = AMDataset()\n    else:\n        raise ValueError()\n\n    # Load from hetero-graph\n    hg = dataset[0]\n\n    num_rels = 
len(hg.canonical_etypes)\n    category = dataset.predict_category\n    num_classes = dataset.num_classes\n    train_mask = hg.nodes[category].data.pop(\"train_mask\")\n    test_mask = hg.nodes[category].data.pop(\"test_mask\")\n    train_idx = mx.nd.array(np.nonzero(train_mask.asnumpy())[0], dtype=\"int64\")\n    test_idx = mx.nd.array(np.nonzero(test_mask.asnumpy())[0], dtype=\"int64\")\n    labels = mx.nd.array(hg.nodes[category].data.pop(\"labels\"), dtype=\"int64\")\n\n    # split dataset into train, validate, test\n    if args.validation:\n        val_idx = train_idx[: len(train_idx) // 5]\n        train_idx = train_idx[len(train_idx) // 5 :]\n    else:\n        val_idx = train_idx\n\n    # calculate norm for each edge type and store in edge\n    for canonical_etype in hg.canonical_etypes:\n        u, v, eid = hg.all_edges(form=\"all\", etype=canonical_etype)\n        v = v.asnumpy()\n        _, inverse_index, count = np.unique(\n            v, return_inverse=True, return_counts=True\n        )\n        degrees = count[inverse_index]\n        norm = np.ones(eid.shape[0]) / degrees\n        hg.edges[canonical_etype].data[\"norm\"] = mx.nd.expand_dims(\n            mx.nd.array(norm), axis=1\n        )\n\n    # get target category id\n    category_id = len(hg.ntypes)\n    for i, ntype in enumerate(hg.ntypes):\n        if ntype == category:\n            category_id = i\n\n    g = dgl.to_homogeneous(hg, edata=[\"norm\"])\n    num_nodes = g.number_of_nodes()\n    node_ids = mx.nd.arange(num_nodes)\n    edge_norm = g.edata[\"norm\"]\n    edge_type = g.edata[dgl.ETYPE]\n\n    # find out the target node ids in g\n    node_tids = g.ndata[dgl.NTYPE]\n    loc = node_tids == category_id\n    loc = mx.nd.array(np.nonzero(loc.asnumpy())[0], dtype=\"int64\")\n    target_idx = node_ids[loc]\n\n    # since the nodes are featureless, the input feature is then the node id.\n    feats = mx.nd.arange(num_nodes, dtype=\"int32\")\n\n    # check cuda\n    use_cuda = args.gpu >= 0\n    
if use_cuda:\n        ctx = mx.gpu(args.gpu)\n        feats = feats.as_in_context(ctx)\n        edge_type = edge_type.as_in_context(ctx)\n        edge_norm = edge_norm.as_in_context(ctx)\n        labels = labels.as_in_context(ctx)\n        train_idx = train_idx.as_in_context(ctx)\n        g = g.to(ctx)\n    else:\n        ctx = mx.cpu(0)\n\n    # create model\n    model = EntityClassify(\n        num_nodes,\n        args.n_hidden,\n        num_classes,\n        num_rels,\n        num_bases=args.n_bases,\n        num_hidden_layers=args.n_layers - 2,\n        dropout=args.dropout,\n        use_self_loop=args.use_self_loop,\n        gpu_id=args.gpu,\n    )\n    model.initialize(ctx=ctx)\n\n    # optimizer\n    trainer = gluon.Trainer(\n        model.collect_params(),\n        \"adam\",\n        {\"learning_rate\": args.lr, \"wd\": args.l2norm},\n    )\n    loss_fcn = gluon.loss.SoftmaxCELoss(from_logits=False)\n\n    # training loop\n    print(\"start training...\")\n    forward_time = []\n    backward_time = []\n    for epoch in range(args.n_epochs):\n        t0 = time.time()\n        with mx.autograd.record():\n            pred = model(g, feats, edge_type, edge_norm)\n            pred = pred[target_idx]\n            loss = loss_fcn(pred[train_idx], labels[train_idx])\n        t1 = time.time()\n        loss.backward()\n        trainer.step(len(train_idx))\n        t2 = time.time()\n\n        forward_time.append(t1 - t0)\n        backward_time.append(t2 - t1)\n        print(\n            \"Epoch {:05d} | Train Forward Time(s) {:.4f} | Backward Time(s) {:.4f}\".format(\n                epoch, forward_time[-1], backward_time[-1]\n            )\n        )\n\n        train_acc = (\n            F.sum(\n                mx.nd.cast(pred[train_idx].argmax(axis=1), \"int64\")\n                == labels[train_idx]\n            ).asscalar()\n            / train_idx.shape[0]\n        )\n        val_acc = F.sum(\n            mx.nd.cast(pred[val_idx].argmax(axis=1), \"int64\") == 
labels[val_idx]\n        ).asscalar() / len(val_idx)\n        print(\n            \"Train Accuracy: {:.4f} | Validation Accuracy: {:.4f}\".format(\n                train_acc, val_acc\n            )\n        )\n    print()\n\n    logits = model.forward(g, feats, edge_type, edge_norm)\n    logits = logits[target_idx]\n    test_acc = F.sum(\n        mx.nd.cast(logits[test_idx].argmax(axis=1), \"int64\") == labels[test_idx]\n    ).asscalar() / len(test_idx)\n    print(\"Test Accuracy: {:.4f}\".format(test_acc))\n    print()\n\n    print(\n        \"Mean forward time: {:4f}\".format(\n            np.mean(forward_time[len(forward_time) // 4 :])\n        )\n    )\n    print(\n        \"Mean backward time: {:4f}\".format(\n            np.mean(backward_time[len(backward_time) // 4 :])\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"RGCN\")\n    parser.add_argument(\n        \"--dropout\", type=float, default=0, help=\"dropout probability\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden units\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-bases\",\n        type=int,\n        default=-1,\n        help=\"number of filter weight matrices, default: -1 [use all]\",\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=2, help=\"number of propagation rounds\"\n    )\n    parser.add_argument(\n        \"-e\",\n        \"--n-epochs\",\n        type=int,\n        default=50,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"-d\", \"--dataset\", type=str, required=True, help=\"dataset to use\"\n    )\n    parser.add_argument(\"--l2norm\", type=float, default=0, help=\"l2 norm coef\")\n    parser.add_argument(\n        \"--use-self-loop\",\n        
default=False,\n        action=\"store_true\",\n        help=\"include self feature as a special relation\",\n    )\n    fp = parser.add_mutually_exclusive_group(required=False)\n    fp.add_argument(\"--validation\", dest=\"validation\", action=\"store_true\")\n    fp.add_argument(\"--testing\", dest=\"validation\", action=\"store_false\")\n    parser.set_defaults(validation=True)\n\n    args = parser.parse_args()\n    print(args)\n    args.bfs_level = args.n_layers + 1  # pruning used nodes for memory\n    main(args)\n"
  },
  {
    "path": "examples/mxnet/rgcn/model.py",
    "content": "import mxnet as mx\nfrom mxnet import gluon\n\n\nclass BaseRGCN(gluon.Block):\n    def __init__(\n        self,\n        num_nodes,\n        h_dim,\n        out_dim,\n        num_rels,\n        num_bases=-1,\n        num_hidden_layers=1,\n        dropout=0,\n        use_self_loop=False,\n        gpu_id=-1,\n    ):\n        super(BaseRGCN, self).__init__()\n        self.num_nodes = num_nodes\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.num_rels = num_rels\n        self.num_bases = num_bases\n        self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.use_self_loop = use_self_loop\n        self.gpu_id = gpu_id\n\n        # create rgcn layers\n        self.build_model()\n\n    def build_model(self):\n        self.layers = gluon.nn.Sequential()\n        # i2h\n        i2h = self.build_input_layer()\n        if i2h is not None:\n            self.layers.add(i2h)\n        # h2h\n        for idx in range(self.num_hidden_layers):\n            h2h = self.build_hidden_layer(idx)\n            self.layers.add(h2h)\n        # h2o\n        h2o = self.build_output_layer()\n        if h2o is not None:\n            self.layers.add(h2o)\n\n    def build_input_layer(self):\n        return None\n\n    def build_hidden_layer(self):\n        raise NotImplementedError\n\n    def build_output_layer(self):\n        return None\n\n    def forward(self, g, h, r, norm):\n        for layer in self.layers:\n            h = layer(g, h, r, norm)\n        return h\n"
  },
  {
    "path": "examples/mxnet/scenegraph/README.md",
    "content": "# Scene Graph Extraction\n\nScene graph extraction aims at not only detect objects in the given image, but also classify the relationships between pairs of them.\n\nThis example reproduces [Graphical Contrastive Losses for Scene Graph Parsing](https://arxiv.org/abs/1903.02728), author's code can be found [here](https://github.com/NVIDIA/ContrastiveLosses4VRD).\n\n![DEMO](https://raw.githubusercontent.com/dmlc/web-data/master/dgl/examples/mxnet/scenegraph/old-couple-pred.png)\n\n## Results\n\n**VisualGenome**\n\n| Model     | Backbone  | mAP@50   | SGDET@20 | SGDET@50 | SGDET@100 | PHRCLS@20 | PHRCLS@50 |PHRCLS@100 | PREDCLS@20 | PREDCLS@50 | PREDCLS@100 |\n| :---      | :---      | :---     | :---     | :---     | :---      | :---      | :---      | :---      | :---       | :---       | :---        |\n| RelDN, L0 | ResNet101 | 29.5     | 22.65    | 30.02    | 35.04     | 32.84     | 35.60     | 36.26     | 60.58      | 65.53      | 66.51       |\n\n## Preparation\n\nThis implementation is based on GluonCV. 
Install GluonCV with:\n\n```\npip install gluoncv --upgrade\n```\n\nThe implementation contains the following files:\n\n```\n.\n|-- data\n|   |-- dataloader.py\n|   |-- __init__.py\n|   |-- object.py\n|   |-- prepare_visualgenome.py\n|   `-- relation.py\n|-- demo_reldn.py\n|-- model\n|   |-- faster_rcnn.py\n|   |-- __init__.py\n|   `-- reldn.py\n|-- README.md\n|-- train_faster_rcnn.py\n|-- train_faster_rcnn.sh\n|-- train_freq_prior.py\n|-- train_reldn.py\n|-- train_reldn.sh\n|-- utils\n|   |-- build_graph.py\n|   |-- __init__.py\n|   |-- metric.py\n|   |-- sampling.py\n|   `-- viz.py\n|-- validate_reldn.py\n`-- validate_reldn.sh\n```\n\n- The folder `data` contains the data preparation script and the definitions of the datasets for object detection and scene graph extraction.\n- The folder `model` contains the model definitions.\n- The folder `utils` contains helper functions for training, validation, and visualization.\n- The script `train_faster_rcnn.py` trains a Faster R-CNN model on the VisualGenome dataset, and `train_faster_rcnn.sh` includes preset parameters.\n- The script `train_freq_prior.py` trains the frequency counts for RelDN model training.\n- The script `train_reldn.py` trains a RelDN model, and `train_reldn.sh` includes preset parameters.\n- The script `validate_reldn.py` validates the trained Faster R-CNN and RelDN models, and `validate_reldn.sh` includes preset parameters.\n- The script `demo_reldn.py` makes use of trained parameters and extracts a scene graph from an arbitrary input image.\n\nBelow are further steps on training your own models. We also provide pretrained model files for validation and the demo:\n\n1. [Faster R-CNN Model for Object Detection](http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params)\n2. [RelDN Model](http://data.dgl.ai/models/SceneGraph/reldn.params)\n3. 
[Faster R-CNN Model for Edge Feature](http://data.dgl.ai/models/SceneGraph/detector_feature.params)\n\n## Data preparation\n\nWe provide a script to download and prepare the VisualGenome dataset. Run it with:\n\n```\npython data/prepare_visualgenome.py\n```\n\n## Object Detector\n\nFirst, one needs to train the object detection model on VisualGenome:\n\n```\nbash train_faster_rcnn.sh\n```\n\nIt runs for about 20 hours on a machine with 64 CPU cores and 8 V100 GPUs.\n\n## Training RelDN\n\nWith a trained Faster R-CNN model, one can start training the RelDN model with:\n\n```\nbash train_reldn.sh\n```\n\nIt runs for about 2 days with a single GPU and 8 CPU cores.\n\n## Validate RelDN\n\nAfter the training, one can evaluate the results with multiple commonly-used metrics:\n\n```\nbash validate_reldn.sh\n```\n\n## Demo\n\nWe provide a demo script, `demo_reldn.py`, that runs the model on real-world pictures. Be aware that you need trained models to generate meaningful results from the demo; if no parameter files are specified, the script downloads the pre-trained models automatically.\n"
  },
  {
    "path": "examples/mxnet/scenegraph/data/__init__.py",
    "content": "from .dataloader import *\nfrom .object import *\nfrom .relation import *\n"
  },
  {
    "path": "examples/mxnet/scenegraph/data/dataloader.py",
    "content": "\"\"\"DataLoader utils.\"\"\"\nimport dgl\nfrom gluoncv.data.batchify import Pad\nfrom mxnet import nd\n\n\ndef dgl_mp_batchify_fn(data):\n    if isinstance(data[0], tuple):\n        data = zip(*data)\n        return [dgl_mp_batchify_fn(i) for i in data]\n\n    for dt in data:\n        if dt is not None:\n            if isinstance(dt, dgl.DGLGraph):\n                return [d for d in data if isinstance(d, dgl.DGLGraph)]\n            elif isinstance(dt, nd.NDArray):\n                pad = Pad(axis=(1, 2), num_shards=1, ret_length=False)\n                data_list = [dt for dt in data if dt is not None]\n                return pad(data_list)\n"
  },
  {
    "path": "examples/mxnet/scenegraph/data/object.py",
    "content": "\"\"\"Pascal VOC object detection dataset.\"\"\"\nfrom __future__ import absolute_import, division\n\nimport json\nimport logging\nimport os\nimport pickle\nimport warnings\nfrom collections import Counter\n\nimport mxnet as mx\nimport numpy as np\nfrom gluoncv.data import COCODetection\n\n\nclass VGObject(COCODetection):\n    CLASSES = [\n        \"airplane\",\n        \"animal\",\n        \"arm\",\n        \"bag\",\n        \"banana\",\n        \"basket\",\n        \"beach\",\n        \"bear\",\n        \"bed\",\n        \"bench\",\n        \"bike\",\n        \"bird\",\n        \"board\",\n        \"boat\",\n        \"book\",\n        \"boot\",\n        \"bottle\",\n        \"bowl\",\n        \"box\",\n        \"boy\",\n        \"branch\",\n        \"building\",\n        \"bus\",\n        \"cabinet\",\n        \"cap\",\n        \"car\",\n        \"cat\",\n        \"chair\",\n        \"child\",\n        \"clock\",\n        \"coat\",\n        \"counter\",\n        \"cow\",\n        \"cup\",\n        \"curtain\",\n        \"desk\",\n        \"dog\",\n        \"door\",\n        \"drawer\",\n        \"ear\",\n        \"elephant\",\n        \"engine\",\n        \"eye\",\n        \"face\",\n        \"fence\",\n        \"finger\",\n        \"flag\",\n        \"flower\",\n        \"food\",\n        \"fork\",\n        \"fruit\",\n        \"giraffe\",\n        \"girl\",\n        \"glass\",\n        \"glove\",\n        \"guy\",\n        \"hair\",\n        \"hand\",\n        \"handle\",\n        \"hat\",\n        \"head\",\n        \"helmet\",\n        \"hill\",\n        \"horse\",\n        \"house\",\n        \"jacket\",\n        \"jean\",\n        \"kid\",\n        \"kite\",\n        \"lady\",\n        \"lamp\",\n        \"laptop\",\n        \"leaf\",\n        \"leg\",\n        \"letter\",\n        \"light\",\n        \"logo\",\n        \"man\",\n        \"men\",\n        \"motorcycle\",\n        \"mountain\",\n        \"mouth\",\n        \"neck\",\n        
\"nose\",\n        \"number\",\n        \"orange\",\n        \"pant\",\n        \"paper\",\n        \"paw\",\n        \"people\",\n        \"person\",\n        \"phone\",\n        \"pillow\",\n        \"pizza\",\n        \"plane\",\n        \"plant\",\n        \"plate\",\n        \"player\",\n        \"pole\",\n        \"post\",\n        \"pot\",\n        \"racket\",\n        \"railing\",\n        \"rock\",\n        \"roof\",\n        \"room\",\n        \"screen\",\n        \"seat\",\n        \"sheep\",\n        \"shelf\",\n        \"shirt\",\n        \"shoe\",\n        \"short\",\n        \"sidewalk\",\n        \"sign\",\n        \"sink\",\n        \"skateboard\",\n        \"ski\",\n        \"skier\",\n        \"sneaker\",\n        \"snow\",\n        \"sock\",\n        \"stand\",\n        \"street\",\n        \"surfboard\",\n        \"table\",\n        \"tail\",\n        \"tie\",\n        \"tile\",\n        \"tire\",\n        \"toilet\",\n        \"towel\",\n        \"tower\",\n        \"track\",\n        \"train\",\n        \"tree\",\n        \"truck\",\n        \"trunk\",\n        \"umbrella\",\n        \"vase\",\n        \"vegetable\",\n        \"vehicle\",\n        \"wave\",\n        \"wheel\",\n        \"window\",\n        \"windshield\",\n        \"wing\",\n        \"wire\",\n        \"woman\",\n        \"zebra\",\n    ]\n\n    def __init__(self, **kwargs):\n        super(VGObject, self).__init__(**kwargs)\n\n    @property\n    def annotation_dir(self):\n        return \"\"\n\n    def _parse_image_path(self, entry):\n        dirname = \"VG_100K\"\n        filename = entry[\"file_name\"]\n        abs_path = os.path.join(self._root, dirname, filename)\n        return abs_path\n"
  },
  {
    "path": "examples/mxnet/scenegraph/data/prepare_visualgenome.py",
    "content": "\"\"\"Prepare Visual Genome datasets\"\"\"\nimport argparse\nimport json\nimport os\nimport pickle\nimport random\nimport shutil\nimport zipfile\n\nimport tqdm\nfrom gluoncv.utils import download, makedirs\n\n_TARGET_DIR = os.path.expanduser(\"~/.mxnet/datasets/visualgenome\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Initialize Visual Genome dataset.\",\n        epilog=\"Example: python visualgenome.py --download-dir ~/visualgenome\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n    parser.add_argument(\n        \"--download-dir\",\n        type=str,\n        default=\"~/visualgenome/\",\n        help=\"dataset directory on disk\",\n    )\n    parser.add_argument(\n        \"--no-download\",\n        action=\"store_true\",\n        help=\"disable automatic download if set\",\n    )\n    parser.add_argument(\n        \"--overwrite\",\n        action=\"store_true\",\n        help=\"overwrite downloaded files if set, in case they are corrupted\",\n    )\n    args = parser.parse_args()\n    return args\n\n\ndef download_vg(path, overwrite=False):\n    _DOWNLOAD_URLS = [\n        (\n            \"https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip\",\n            \"a055367f675dd5476220e9b93e4ca9957b024b94\",\n        ),\n        (\n            \"https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip\",\n            \"2add3aab77623549e92b7f15cda0308f50b64ecf\",\n        ),\n    ]\n    makedirs(path)\n    for url, checksum in _DOWNLOAD_URLS:\n        filename = download(\n            url, path=path, overwrite=overwrite, sha1_hash=checksum\n        )\n        # extract\n        if filename.endswith(\"zip\"):\n            with zipfile.ZipFile(filename) as zf:\n                zf.extractall(path=path)\n    # move all images into folder `VG_100K`\n    vg_100k_path = os.path.join(path, \"VG_100K\")\n    vg_100k_2_path = os.path.join(path, \"VG_100K_2\")\n    files_2 = 
os.listdir(vg_100k_2_path)\n    for fl in files_2:\n        shutil.move(\n            os.path.join(vg_100k_2_path, fl), os.path.join(vg_100k_path, fl)\n        )\n\n\ndef download_json(path, overwrite=False):\n    url = \"https://data.dgl.ai/dataset/vg.zip\"\n    output = \"vg.zip\"\n    download(url, path=path)\n    with zipfile.ZipFile(output) as zf:\n        zf.extractall(path=path)\n    json_path = os.path.join(path, \"vg\")\n    json_files = os.listdir(json_path)\n    for fl in json_files:\n        shutil.move(os.path.join(json_path, fl), os.path.join(path, fl))\n    os.rmdir(json_path)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    path = os.path.expanduser(args.download_dir)\n    if not os.path.isdir(path):\n        if args.no_download:\n            raise ValueError(\n                (\n                    \"{} is not a valid directory, make sure it is present.\"\n                    ' Or you should not disable \"--no-download\" to grab it'.format(\n                        path\n                    )\n                )\n            )\n        else:\n            download_vg(path, overwrite=args.overwrite)\n            download_json(path, overwrite=args.overwrite)\n\n    # make symlink\n    makedirs(os.path.expanduser(\"~/.mxnet/datasets\"))\n    if os.path.isdir(_TARGET_DIR):\n        os.rmdir(_TARGET_DIR)\n    os.symlink(path, _TARGET_DIR)\n"
  },
  {
    "path": "examples/mxnet/scenegraph/data/relation.py",
    "content": "\"\"\"Pascal VOC object detection dataset.\"\"\"\nfrom __future__ import absolute_import, division\n\nimport json\nimport logging\nimport os\nimport pickle\nimport warnings\nfrom collections import Counter\n\nimport dgl\n\nimport mxnet as mx\nimport numpy as np\nfrom gluoncv.data.base import VisionDataset\nfrom gluoncv.data.transforms.presets.rcnn import (\n    FasterRCNNDefaultTrainTransform,\n    FasterRCNNDefaultValTransform,\n)\n\n\nclass VGRelation(VisionDataset):\n    def __init__(\n        self,\n        root=os.path.join(\"~\", \".mxnet\", \"datasets\", \"visualgenome\"),\n        split=\"train\",\n    ):\n        super(VGRelation, self).__init__(root)\n        self._root = os.path.expanduser(root)\n        self._img_path = os.path.join(self._root, \"VG_100K\", \"{}\")\n\n        if split == \"train\":\n            self._dict_path = os.path.join(\n                self._root, \"rel_annotations_train.json\"\n            )\n        elif split == \"val\":\n            self._dict_path = os.path.join(\n                self._root, \"rel_annotations_val.json\"\n            )\n        else:\n            raise NotImplementedError\n        with open(self._dict_path) as f:\n            tmp = f.read()\n            self._dict = json.loads(tmp)\n\n        self._predicates_path = os.path.join(self._root, \"predicates.json\")\n        with open(self._predicates_path, \"r\") as f:\n            tmp = f.read()\n            self.rel_classes = json.loads(tmp)\n        self.num_rel_classes = len(self.rel_classes) + 1\n\n        self._objects_path = os.path.join(self._root, \"objects.json\")\n        with open(self._objects_path, \"r\") as f:\n            tmp = f.read()\n            self.obj_classes = json.loads(tmp)\n        self.num_obj_classes = len(self.obj_classes)\n\n        if split == \"val\":\n            self.img_transform = FasterRCNNDefaultValTransform(\n                short=600, max_size=1000\n            )\n        else:\n            
self.img_transform = FasterRCNNDefaultTrainTransform(\n                short=600, max_size=1000\n            )\n        self.split = split\n\n    def __len__(self):\n        return len(self._dict)\n\n    def _hash_bbox(self, object):\n        num_list = [object[\"category\"]] + object[\"bbox\"]\n        return \"_\".join([str(num) for num in num_list])\n\n    def __getitem__(self, idx):\n        img_id = list(self._dict)[idx]\n        img_path = self._img_path.format(img_id)\n        img = mx.image.imread(img_path)\n\n        item = self._dict[img_id]\n        n_edges = len(item)\n\n        # edge to node ids\n        sub_node_hash = []\n        ob_node_hash = []\n        for i, it in enumerate(item):\n            sub_node_hash.append(self._hash_bbox(it[\"subject\"]))\n            ob_node_hash.append(self._hash_bbox(it[\"object\"]))\n        node_set = sorted(list(set(sub_node_hash + ob_node_hash)))\n        n_nodes = len(node_set)\n        node_to_id = {}\n        for i, node in enumerate(node_set):\n            node_to_id[node] = i\n        sub_id = []\n        ob_id = []\n        for i in range(n_edges):\n            sub_id.append(node_to_id[sub_node_hash[i]])\n            ob_id.append(node_to_id[ob_node_hash[i]])\n\n        # node features\n        bbox = mx.nd.zeros((n_nodes, 4))\n        node_class_ids = mx.nd.zeros((n_nodes, 1))\n        node_visited = [False for i in range(n_nodes)]\n        for i, it in enumerate(item):\n            if not node_visited[sub_id[i]]:\n                ind = sub_id[i]\n                sub = it[\"subject\"]\n                node_class_ids[ind] = sub[\"category\"]\n                # y1y2x1x2 to x1y1x2y2\n                bbox[ind, 0] = sub[\"bbox\"][2]\n                bbox[ind, 1] = sub[\"bbox\"][0]\n                bbox[ind, 2] = sub[\"bbox\"][3]\n                bbox[ind, 3] = sub[\"bbox\"][1]\n\n                node_visited[ind] = True\n\n            if not node_visited[ob_id[i]]:\n                ind = ob_id[i]\n              
  ob = it[\"object\"]\n                node_class_ids[ind] = ob[\"category\"]\n                # y1y2x1x2 to x1y1x2y2\n                bbox[ind, 0] = ob[\"bbox\"][2]\n                bbox[ind, 1] = ob[\"bbox\"][0]\n                bbox[ind, 2] = ob[\"bbox\"][3]\n                bbox[ind, 3] = ob[\"bbox\"][1]\n\n                node_visited[ind] = True\n\n        eta = 0.1\n        node_class_vec = node_class_ids[:, 0].one_hot(\n            self.num_obj_classes,\n            on_value=1 - eta + eta / self.num_obj_classes,\n            off_value=eta / self.num_obj_classes,\n        )\n\n        # augmentation\n        if self.split == \"val\":\n            img, bbox, _ = self.img_transform(img, bbox)\n        else:\n            img, bbox = self.img_transform(img, bbox)\n\n        # build the graph\n        g = dgl.DGLGraph()\n        g.add_nodes(n_nodes)\n        adjmat = np.zeros((n_nodes, n_nodes))\n        predicate = []\n        for i, it in enumerate(item):\n            adjmat[sub_id[i], ob_id[i]] = 1\n            predicate.append(it[\"predicate\"])\n        predicate = mx.nd.array(predicate).expand_dims(1)\n        g.add_edges(sub_id, ob_id, {\"rel_class\": mx.nd.array(predicate) + 1})\n        empty_edge_list = []\n        for i in range(n_nodes):\n            for j in range(n_nodes):\n                if i != j and adjmat[i, j] == 0:\n                    empty_edge_list.append((i, j))\n        if len(empty_edge_list) > 0:\n            src, dst = tuple(zip(*empty_edge_list))\n            g.add_edges(\n                src, dst, {\"rel_class\": mx.nd.zeros((len(empty_edge_list), 1))}\n            )\n\n        # assign features\n        g.ndata[\"bbox\"] = bbox\n        g.ndata[\"node_class\"] = node_class_ids\n        g.ndata[\"node_class_vec\"] = node_class_vec\n\n        return g, img\n"
  },
  {
    "path": "examples/mxnet/scenegraph/demo_reldn.py",
    "content": "import argparse\n\nimport gluoncv as gcv\nimport mxnet as mx\nfrom data import *\nfrom gluoncv.data.transforms import presets\nfrom gluoncv.utilz import download\nfrom model import faster_rcnn_resnet101_v1d_custom, RelDN\nfrom utils import *\n\nimport dgl\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Demo of Scene Graph Extraction.\"\n    )\n    parser.add_argument(\n        \"--image\",\n        type=str,\n        default=\"\",\n        help=\"The image for scene graph extraction.\",\n    )\n    parser.add_argument(\n        \"--gpu\",\n        type=str,\n        default=\"\",\n        help=\"GPU id to use for inference, default is not using GPU.\",\n    )\n    parser.add_argument(\n        \"--pretrained-faster-rcnn-params\",\n        type=str,\n        default=\"\",\n        help=\"Path to saved Faster R-CNN model parameters.\",\n    )\n    parser.add_argument(\n        \"--reldn-params\",\n        type=str,\n        default=\"\",\n        help=\"Path to saved Faster R-CNN model parameters.\",\n    )\n    parser.add_argument(\n        \"--faster-rcnn-params\",\n        type=str,\n        default=\"\",\n        help=\"Path to saved Faster R-CNN model parameters.\",\n    )\n    parser.add_argument(\n        \"--freq-prior\",\n        type=str,\n        default=\"freq_prior.pkl\",\n        help=\"Path to saved frequency prior data.\",\n    )\n    args = parser.parse_args()\n    return args\n\n\nargs = parse_args()\nif args.gpu:\n    ctx = mx.gpu(int(args.gpu))\nelse:\n    ctx = mx.cpu()\n\nnet = RelDN(n_classes=50, prior_pkl=args.freq_prior, semantic_only=False)\nif args.reldn_params == \"\":\n    download(\"http://data.dgl.ai/models/SceneGraph/reldn.params\")\n    net.load_parameters(\"rendl.params\", ctx=ctx)\nelse:\n    net.load_parameters(args.reldn_params, ctx=ctx)\n\n# dataset and dataloader\nvg_val = VGRelation(split=\"val\")\ndetector = faster_rcnn_resnet101_v1d_custom(\n    
classes=vg_val.obj_classes,\n    pretrained_base=False,\n    pretrained=False,\n    additional_output=True,\n)\nif args.pretrained_faster_rcnn_params == \"\":\n    download(\n        \"http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params\"\n    )\n    params_path = \"faster_rcnn_resnet101_v1d_visualgenome.params\"\nelse:\n    params_path = args.pretrained_faster_rcnn_params\ndetector.load_parameters(\n    params_path, ctx=ctx, ignore_extra=True, allow_missing=True\n)\n\ndetector_feat = faster_rcnn_resnet101_v1d_custom(\n    classes=vg_val.obj_classes,\n    pretrained_base=False,\n    pretrained=False,\n    additional_output=True,\n)\ndetector_feat.load_parameters(\n    params_path, ctx=ctx, ignore_extra=True, allow_missing=True\n)\nif args.faster_rcnn_params == \"\":\n    download(\n        \"http://data.dgl.ai/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params\"\n    )\n    detector_feat.features.load_parameters(\n        \"faster_rcnn_resnet101_v1d_visualgenome.params\", ctx=ctx\n    )\nelse:\n    detector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)\n\n# image input\nif args.image:\n    image_path = args.image\nelse:\n    gcv.utils.download(\n        \"https://raw.githubusercontent.com/dmlc/web-data/master/\"\n        + \"dgl/examples/mxnet/scenegraph/old-couple.png\",\n        \"old-couple.png\",\n    )\n    image_path = \"old-couple.png\"\nx, img = presets.rcnn.load_test(\n    args.image, short=detector.short, max_size=detector.max_size\n)\nx = x.as_in_context(ctx)\n# detector prediction\nids, scores, bboxes, feat, feat_ind, spatial_feat = detector(x)\n# build graph, extract edge features\ng = build_graph_validate_pred(\n    x,\n    ids,\n    scores,\n    bboxes,\n    feat_ind,\n    spatial_feat,\n    bbox_improvement=True,\n    scores_top_k=75,\n    overlap=False,\n)\nrel_bbox = g.edata[\"rel_bbox\"].expand_dims(0).as_in_context(ctx)\n_, _, _, spatial_feat_rel = detector_feat(x, None, None, 
rel_bbox)\ng.edata[\"edge_feat\"] = spatial_feat_rel[0]\n# graph prediction\ng = net(g)\n\n_, preds = extract_pred(g, joint_preds=True)\npreds = preds[preds[:, 1].argsort()[::-1]]\n\nplot_sg(img, preds, detector.classes, vg_val.rel_classes, 10)\n"
  },
  {
    "path": "examples/mxnet/scenegraph/model/__init__.py",
    "content": "from .faster_rcnn import *\nfrom .reldn import *\n"
  },
  {
    "path": "examples/mxnet/scenegraph/model/faster_rcnn.py",
    "content": "\"\"\"Faster RCNN Model.\"\"\"\nfrom __future__ import absolute_import\n\nimport os\nimport warnings\n\nimport mxnet as mx\n\nfrom gluoncv.model_zoo.faster_rcnn.rcnn_target import (\n    RCNNTargetGenerator,\n    RCNNTargetSampler,\n)\nfrom gluoncv.model_zoo.rcnn import RCNN\nfrom gluoncv.model_zoo.rpn import RPN\nfrom gluoncv.nn.feature import FPNFeatureExpander\nfrom mxnet import autograd\nfrom mxnet.gluon import nn\nfrom mxnet.gluon.contrib.nn import SyncBatchNorm\n\n__all__ = [\n    \"FasterRCNN\",\n    \"get_faster_rcnn\",\n    \"faster_rcnn_resnet50_v1b_coco\",\n    \"faster_rcnn_resnet50_v1b_custom\",\n    \"faster_rcnn_resnet101_v1d_coco\",\n    \"faster_rcnn_resnet101_v1d_custom\",\n]\n\n\nclass FasterRCNN(RCNN):\n    r\"\"\"Faster RCNN network.\n\n    Parameters\n    ----------\n    features : gluon.HybridBlock\n        Base feature extractor before feature pooling layer.\n    top_features : gluon.HybridBlock\n        Tail feature extractor after feature pooling layer.\n    classes : iterable of str\n        Names of categories, its length is ``num_class``.\n    box_features : gluon.HybridBlock, default is None\n        feature head for transforming shared ROI output (top_features) for box prediction.\n        If set to None, global average pooling will be used.\n    short : int, default is 600.\n        Input image short side size.\n    max_size : int, default is 1000.\n        Maximum size of input image long side.\n    min_stage : int, default is 4\n        Minimum stage NO. for FPN stages.\n    max_stage : int, default is 4\n        Maximum stage NO. for FPN stages.\n    train_patterns : str, default is None.\n        Matching pattern for trainable parameters.\n    nms_thresh : float, default is 0.3.\n        Non-maximum suppression threshold. 
You can specify < 0 or > 1 to disable NMS.\n    nms_topk : int, default is 400\n        Apply NMS to top k detection results, use -1 to disable so that every Detection\n         result is used in NMS.\n    post_nms : int, default is 100\n        Only return top `post_nms` detection results, the rest is discarded. The number is\n        based on COCO dataset which has maximum 100 objects per image. You can adjust this\n        number if expecting more objects. You can use -1 to return all detections.\n    roi_mode : str, default is align\n        ROI pooling mode. Currently support 'pool' and 'align'.\n    roi_size : tuple of int, length 2, default is (14, 14)\n        (height, width) of the ROI region.\n    strides : int/tuple of ints, default is 16\n        Feature map stride with respect to original image.\n        This is usually the ratio between original image size and feature map size.\n        For FPN, use a tuple of ints.\n    clip : float, default is None\n        Clip bounding box target to this value.\n    rpn_channel : int, default is 1024\n        Channel number used in RPN convolutional layers.\n    base_size : int\n        The width(and height) of reference anchor box.\n    scales : iterable of float, default is (8, 16, 32)\n        The areas of anchor boxes.\n        We use the following form to compute the shapes of anchors:\n\n        .. math::\n\n            width_{anchor} = size_{base} \\times scale \\times \\sqrt{ 1 / ratio}\n            height_{anchor} = size_{base} \\times scale \\times \\sqrt{ratio}\n\n    ratios : iterable of float, default is (0.5, 1, 2)\n        The aspect ratios of anchor boxes. We expect it to be a list or tuple.\n    alloc_size : tuple of int\n        Allocate size for the anchor boxes as (H, W).\n        Usually we generate enough anchors for large feature map, e.g. 
128x128.\n        Later in inference we can have variable input sizes,\n        at which time we can crop corresponding anchors from this large\n        anchor map so we can skip re-generating anchors for each input.\n    rpn_train_pre_nms : int, default is 12000\n        Filter top proposals before NMS in training of RPN.\n    rpn_train_post_nms : int, default is 2000\n        Return top proposal results after NMS in training of RPN.\n        Will be set to rpn_train_pre_nms if it is larger than rpn_train_pre_nms.\n    rpn_test_pre_nms : int, default is 6000\n        Filter top proposals before NMS in testing of RPN.\n    rpn_test_post_nms : int, default is 300\n        Return top proposal results after NMS in testing of RPN.\n        Will be set to rpn_test_pre_nms if it is larger than rpn_test_pre_nms.\n    rpn_nms_thresh : float, default is 0.7\n        IOU threshold for NMS. It is used to remove overlapping proposals.\n    rpn_num_sample : int, default is 256\n        Number of samples for RPN targets.\n    rpn_pos_iou_thresh : float, default is 0.7\n        Anchor with IOU larger than ``pos_iou_thresh`` is regarded as positive samples.\n    rpn_neg_iou_thresh : float, default is 0.3\n        Anchor with IOU smaller than ``neg_iou_thresh`` is regarded as negative samples.\n        Anchors with IOU in between ``pos_iou_thresh`` and ``neg_iou_thresh`` are\n        ignored.\n    rpn_pos_ratio : float, default is 0.5\n        ``pos_ratio`` defines how many positive samples (``pos_ratio * num_sample``) is\n        to be sampled.\n    rpn_box_norm : array-like of size 4, default is (1., 1., 1., 1.)\n        Std value to be divided from encoded values.\n    rpn_min_size : int, default is 16\n        Proposals whose size is smaller than ``min_size`` will be discarded.\n    per_device_batch_size : int, default is 1\n        Batch size for each device during training.\n    num_sample : int, default is 128\n        Number of samples for RCNN targets.\n    pos_iou_thresh 
: float, default is 0.5\n        Proposal whose IOU larger than ``pos_iou_thresh`` is regarded as positive samples.\n    pos_ratio : float, default is 0.25\n        ``pos_ratio`` defines how many positive samples (``pos_ratio * num_sample``) is\n        to be sampled.\n    max_num_gt : int, default is 300\n        Maximum ground-truth number in whole training dataset. This is only an upper bound, not\n        necessarily very precise. However, using a very big number may impact the training speed.\n    additional_output : boolean, default is False\n        ``additional_output`` is only used for Mask R-CNN to get internal outputs.\n    force_nms : bool, default is False\n        Appy NMS to all categories, this is to avoid overlapping detection results from different\n        categories.\n\n    Attributes\n    ----------\n    classes : iterable of str\n        Names of categories, its length is ``num_class``.\n    num_class : int\n        Number of positive categories.\n    short : int\n        Input image short side size.\n    max_size : int\n        Maximum size of input image long side.\n    train_patterns : str\n        Matching pattern for trainable parameters.\n    nms_thresh : float\n        Non-maximum suppression threshold. You can specify < 0 or > 1 to disable NMS.\n    nms_topk : int\n        Apply NMS to top k detection results, use -1 to disable so that every Detection\n         result is used in NMS.\n    force_nms : bool\n        Appy NMS to all categories, this is to avoid overlapping detection results\n        from different categories.\n    post_nms : int\n        Only return top `post_nms` detection results, the rest is discarded. The number is\n        based on COCO dataset which has maximum 100 objects per image. You can adjust this\n        number if expecting more objects. 
You can use -1 to return all detections.\n    rpn_target_generator : gluon.Block\n        Generate training targets with cls_target, box_target, and box_mask.\n    target_generator : gluon.Block\n        Generate training targets with boxes, samples, matches, gt_label and gt_box.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        features,\n        top_features,\n        classes,\n        box_features=None,\n        short=600,\n        max_size=1000,\n        min_stage=4,\n        max_stage=4,\n        train_patterns=None,\n        nms_thresh=0.3,\n        nms_topk=400,\n        post_nms=100,\n        roi_mode=\"align\",\n        roi_size=(14, 14),\n        strides=16,\n        clip=None,\n        rpn_channel=1024,\n        base_size=16,\n        scales=(8, 16, 32),\n        ratios=(0.5, 1, 2),\n        alloc_size=(128, 128),\n        rpn_nms_thresh=0.7,\n        rpn_train_pre_nms=12000,\n        rpn_train_post_nms=2000,\n        rpn_test_pre_nms=6000,\n        rpn_test_post_nms=300,\n        rpn_min_size=16,\n        per_device_batch_size=1,\n        num_sample=128,\n        pos_iou_thresh=0.5,\n        pos_ratio=0.25,\n        max_num_gt=300,\n        additional_output=False,\n        force_nms=False,\n        **kwargs\n    ):\n        super(FasterRCNN, self).__init__(\n            features=features,\n            top_features=top_features,\n            classes=classes,\n            box_features=box_features,\n            short=short,\n            max_size=max_size,\n            train_patterns=train_patterns,\n            nms_thresh=nms_thresh,\n            nms_topk=nms_topk,\n            post_nms=post_nms,\n            roi_mode=roi_mode,\n            roi_size=roi_size,\n            strides=strides,\n            clip=clip,\n            force_nms=force_nms,\n            **kwargs\n        )\n        if rpn_train_post_nms > rpn_train_pre_nms:\n            rpn_train_post_nms = rpn_train_pre_nms\n        if rpn_test_post_nms > rpn_test_pre_nms:\n            
rpn_test_post_nms = rpn_test_pre_nms\n\n        self.ashape = alloc_size[0]\n        self._min_stage = min_stage\n        self._max_stage = max_stage\n        self.num_stages = max_stage - min_stage + 1\n        if self.num_stages > 1:\n            assert len(scales) == len(strides) == self.num_stages, (\n                \"The num_stages (%d) must match number of scales (%d) and strides (%d)\"\n                % (self.num_stages, len(scales), len(strides))\n            )\n        self._batch_size = per_device_batch_size\n        self._num_sample = num_sample\n        self._rpn_test_post_nms = rpn_test_post_nms\n        self._target_generator = RCNNTargetGenerator(\n            self.num_class, int(num_sample * pos_ratio), self._batch_size\n        )\n        self._additional_output = additional_output\n        with self.name_scope():\n            self.rpn = RPN(\n                channels=rpn_channel,\n                strides=strides,\n                base_size=base_size,\n                scales=scales,\n                ratios=ratios,\n                alloc_size=alloc_size,\n                clip=clip,\n                nms_thresh=rpn_nms_thresh,\n                train_pre_nms=rpn_train_pre_nms,\n                train_post_nms=rpn_train_post_nms,\n                test_pre_nms=rpn_test_pre_nms,\n                test_post_nms=rpn_test_post_nms,\n                min_size=rpn_min_size,\n                multi_level=self.num_stages > 1,\n                per_level_nms=False,\n            )\n            self.sampler = RCNNTargetSampler(\n                num_image=self._batch_size,\n                num_proposal=rpn_train_post_nms,\n                num_sample=num_sample,\n                pos_iou_thresh=pos_iou_thresh,\n                pos_ratio=pos_ratio,\n                max_num_gt=max_num_gt,\n            )\n\n    @property\n    def target_generator(self):\n        \"\"\"Returns stored target generator\n\n        Returns\n        -------\n        mxnet.gluon.HybridBlock\n      
      The RCNN target generator\n\n        \"\"\"\n        return self._target_generator\n\n    def reset_class(self, classes, reuse_weights=None):\n        \"\"\"Reset class categories and class predictors.\n\n        Parameters\n        ----------\n        classes : iterable of str\n            The new categories. ['apple', 'orange'] for example.\n        reuse_weights : dict\n            A {new_integer : old_integer} or mapping dict or {new_name : old_name} mapping dict,\n            or a list of [name0, name1,...] if class names don't change.\n            This allows the new predictor to reuse the\n            previously trained weights specified.\n\n        Example\n        -------\n        >>> net = gluoncv.model_zoo.get_model('faster_rcnn_resnet50_v1b_coco', pretrained=True)\n        >>> # use direct name to name mapping to reuse weights\n        >>> net.reset_class(classes=['person'], reuse_weights={'person':'person'})\n        >>> # or use interger mapping, person is the 14th category in VOC\n        >>> net.reset_class(classes=['person'], reuse_weights={0:14})\n        >>> # you can even mix them\n        >>> net.reset_class(classes=['person'], reuse_weights={'person':14})\n        >>> # or use a list of string if class name don't change\n        >>> net.reset_class(classes=['person'], reuse_weights=['person'])\n\n        \"\"\"\n        super(FasterRCNN, self).reset_class(classes, reuse_weights)\n        self._target_generator = RCNNTargetGenerator(\n            self.num_class, self.sampler._max_pos, self._batch_size\n        )\n\n    def _pyramid_roi_feats(\n        self,\n        F,\n        features,\n        rpn_rois,\n        roi_size,\n        strides,\n        roi_mode=\"align\",\n        roi_canonical_scale=224.0,\n        eps=1e-6,\n    ):\n        \"\"\"Assign rpn_rois to specific FPN layers according to its area\n           and then perform `ROIPooling` or `ROIAlign` to generate final\n           region proposals aggregated features.\n        
Parameters\n        ----------\n        features : list of mx.ndarray or mx.symbol\n            Features extracted from FPN base network\n        rpn_rois : mx.ndarray or mx.symbol\n            (N, 5) with [[batch_index, x1, y1, x2, y2], ...] like\n        roi_size : tuple\n            The size of each roi with regard to ROI-Wise operation\n            each region proposal will be roi_size spatial shape.\n        strides : tuple e.g. [4, 8, 16, 32]\n            Define the gap that ori image and feature map have\n        roi_mode : str, default is align\n            ROI pooling mode. Currently support 'pool' and 'align'.\n        roi_canonical_scale : float, default is 224.0\n            Hyperparameters for the RoI-to-FPN level mapping heuristic.\n        Returns\n        -------\n        Pooled roi features aggregated according to its roi_level\n        \"\"\"\n        max_stage = self._max_stage\n        if self._max_stage > 5:  # do not use p6 for RCNN\n            max_stage = self._max_stage - 1\n        _, x1, y1, x2, y2 = F.split(rpn_rois, axis=-1, num_outputs=5)\n        h = y2 - y1 + 1\n        w = x2 - x1 + 1\n        roi_level = F.floor(\n            4 + F.log2(F.sqrt(w * h) / roi_canonical_scale + eps)\n        )\n        roi_level = F.squeeze(F.clip(roi_level, self._min_stage, max_stage))\n        # [2,2,..,3,3,...,4,4,...,5,5,...] 
``Prohibit swap order here``\n        # roi_level_sorted_args = F.argsort(roi_level, is_ascend=True)\n        # roi_level = F.sort(roi_level, is_ascend=True)\n        # rpn_rois = F.take(rpn_rois, roi_level_sorted_args, axis=0)\n        pooled_roi_feats = []\n        for i, l in enumerate(range(self._min_stage, max_stage + 1)):\n            if roi_mode == \"pool\":\n                # Pool features with all rois first, and then set invalid pooled features to zero,\n                # at last ele-wise add together to aggregate all features.\n                pooled_feature = F.ROIPooling(\n                    features[i], rpn_rois, roi_size, 1.0 / strides[i]\n                )\n                pooled_feature = F.where(\n                    roi_level == l, pooled_feature, F.zeros_like(pooled_feature)\n                )\n            elif roi_mode == \"align\":\n                if (\n                    \"box_encode\" in F.contrib.__dict__\n                    and \"box_decode\" in F.contrib.__dict__\n                ):\n                    # TODO(jerryzcn): clean this up for once mx 1.6 is released.\n                    masked_rpn_rois = F.where(\n                        roi_level == l, rpn_rois, F.ones_like(rpn_rois) * -1.0\n                    )\n                    pooled_feature = F.contrib.ROIAlign(\n                        features[i],\n                        masked_rpn_rois,\n                        roi_size,\n                        1.0 / strides[i],\n                        sample_ratio=2,\n                    )\n                else:\n                    pooled_feature = F.contrib.ROIAlign(\n                        features[i],\n                        rpn_rois,\n                        roi_size,\n                        1.0 / strides[i],\n                        sample_ratio=2,\n                    )\n                    pooled_feature = F.where(\n                        roi_level == l,\n                        pooled_feature,\n                        
F.zeros_like(pooled_feature),\n                    )\n            else:\n                raise ValueError(\"Invalid roi mode: {}\".format(roi_mode))\n            pooled_roi_feats.append(pooled_feature)\n        # Ele-wise add to aggregate all pooled features\n        pooled_roi_feats = F.ElementWiseSum(*pooled_roi_feats)\n        # Sort all pooled features by asceding order\n        # [2,2,..,3,3,...,4,4,...,5,5,...]\n        # pooled_roi_feats = F.take(pooled_roi_feats, roi_level_sorted_args)\n        # pooled roi feats (B*N, C, 7, 7), N = N2 + N3 + N4 + N5 = num_roi, C=256 in ori paper\n        return pooled_roi_feats\n\n    # pylint: disable=arguments-differ\n    def hybrid_forward(self, F, x, gt_box=None, gt_label=None, m_rpn_box=None):\n        \"\"\"Forward Faster-RCNN network.\n\n        The behavior during training and inference is different.\n\n        Parameters\n        ----------\n        x : mxnet.nd.NDArray or mxnet.symbol\n            The network input tensor.\n        gt_box : type, only required during training\n            The ground-truth bbox tensor with shape (B, N, 4).\n        gt_label : type, only required during training\n            The ground-truth label tensor with shape (B, 1, 4).\n\n        Returns\n        -------\n        (ids, scores, bboxes)\n            During inference, returns final class id, confidence scores, bounding\n            boxes.\n\n        \"\"\"\n\n        def _split(x, axis, num_outputs, squeeze_axis):\n            x = F.split(\n                x, axis=axis, num_outputs=num_outputs, squeeze_axis=squeeze_axis\n            )\n            if isinstance(x, list):\n                return x\n            else:\n                return [x]\n\n        if m_rpn_box is not None:\n            manual_rpn_box = True\n        else:\n            manual_rpn_box = False\n        feat = self.features(x)\n        if not isinstance(feat, (list, tuple)):\n            feat = [feat]\n\n        # RPN proposals\n        if 
autograd.is_training():\n            if manual_rpn_box:\n                rpn_box = m_rpn_box\n                self.nms_thresh = 1\n            else:\n                (\n                    rpn_score,\n                    rpn_box,\n                    raw_rpn_score,\n                    raw_rpn_box,\n                    anchors,\n                ) = self.rpn(F.zeros_like(x), *feat)\n                rpn_box, samples, matches = self.sampler(\n                    rpn_box, rpn_score, gt_box\n                )\n        else:\n            if manual_rpn_box:\n                rpn_box = m_rpn_box\n                self.nms_thresh = 1\n            else:\n                _, rpn_box = self.rpn(F.zeros_like(x), *feat)\n\n        # create batchid for roi\n        if not manual_rpn_box:\n            num_roi = (\n                self._num_sample\n                if autograd.is_training()\n                else self._rpn_test_post_nms\n            )\n            batch_size = self._batch_size if autograd.is_training() else 1\n        else:\n            num_roi = m_rpn_box.shape[1]\n            batch_size = rpn_box.shape[0]\n\n        with autograd.pause():\n            roi_batchid = F.arange(0, batch_size)\n            roi_batchid = F.repeat(roi_batchid, num_roi)\n            # remove batch dim because ROIPooling require 2d input\n            rpn_roi = F.concat(\n                *[roi_batchid.reshape((-1, 1)), rpn_box.reshape((-1, 4))],\n                dim=-1\n            )\n            rpn_roi = F.stop_gradient(rpn_roi)\n\n        if self.num_stages > 1:\n            # using FPN\n            pooled_feat = self._pyramid_roi_feats(\n                F,\n                feat,\n                rpn_roi,\n                self._roi_size,\n                self._strides,\n                roi_mode=self._roi_mode,\n            )\n        else:\n            # ROI features\n            if self._roi_mode == \"pool\":\n                pooled_feat = F.ROIPooling(\n                    feat[0], 
rpn_roi, self._roi_size, 1.0 / self._strides\n                )\n            elif self._roi_mode == \"align\":\n                pooled_feat = F.contrib.ROIAlign(\n                    feat[0],\n                    rpn_roi,\n                    self._roi_size,\n                    1.0 / self._strides,\n                    sample_ratio=2,\n                )\n            else:\n                raise ValueError(\"Invalid roi mode: {}\".format(self._roi_mode))\n\n        # RCNN prediction\n        if self.top_features is not None:\n            top_feat = self.top_features(pooled_feat)\n        else:\n            top_feat = pooled_feat\n        if self.box_features is None:\n            box_feat = F.contrib.AdaptiveAvgPooling2D(top_feat, output_size=1)\n        else:\n            box_feat = self.box_features(top_feat)\n        cls_pred = self.class_predictor(box_feat)\n        # cls_pred (B * N, C) -> (B, N, C)\n        cls_pred = cls_pred.reshape((batch_size, num_roi, self.num_class + 1))\n        if manual_rpn_box:\n            spatial_feat = top_feat.mean(axis=1).reshape(\n                (-4, rpn_box.shape[0], rpn_box.shape[1], -3)\n            )\n            cls_ids, scores = self.cls_decoder(F.softmax(cls_pred, axis=-1))\n            cls_ids = cls_ids.transpose((0, 2, 1)).reshape((0, 0, 0, 1))\n            scores = scores.transpose((0, 2, 1)).reshape((0, 0, 0, 1))\n            cls_ids = _split(\n                cls_ids, axis=0, num_outputs=batch_size, squeeze_axis=True\n            )\n            scores = _split(\n                scores, axis=0, num_outputs=batch_size, squeeze_axis=True\n            )\n            return cls_ids, scores, rpn_box, spatial_feat\n\n        # no need to convert bounding boxes in training, just return\n        if autograd.is_training():\n            (\n                cls_targets,\n                box_targets,\n                box_masks,\n                indices,\n            ) = self._target_generator(\n                rpn_box, samples, 
matches, gt_label, gt_box\n            )\n            box_feat = F.reshape(box_feat.expand_dims(0), (batch_size, -1, 0))\n            box_pred = self.box_predictor(\n                F.concat(\n                    *[\n                        F.take(\n                            F.slice_axis(\n                                box_feat, axis=0, begin=i, end=i + 1\n                            ).squeeze(),\n                            F.slice_axis(\n                                indices, axis=0, begin=i, end=i + 1\n                            ).squeeze(),\n                        )\n                        for i in range(batch_size)\n                    ],\n                    dim=0\n                )\n            )\n            # box_pred (B * N, C * 4) -> (B, N, C, 4)\n            box_pred = box_pred.reshape((batch_size, -1, self.num_class, 4))\n            if self._additional_output:\n                return (\n                    cls_pred,\n                    box_pred,\n                    rpn_box,\n                    samples,\n                    matches,\n                    raw_rpn_score,\n                    raw_rpn_box,\n                    anchors,\n                    cls_targets,\n                    box_targets,\n                    box_masks,\n                    top_feat,\n                    indices,\n                )\n            return (\n                cls_pred,\n                box_pred,\n                rpn_box,\n                samples,\n                matches,\n                raw_rpn_score,\n                raw_rpn_box,\n                anchors,\n                cls_targets,\n                box_targets,\n                box_masks,\n                indices,\n            )\n\n        box_pred = self.box_predictor(box_feat)\n        # box_pred (B * N, C * 4) -> (B, N, C, 4)\n        box_pred = box_pred.reshape((batch_size, num_roi, self.num_class, 4))\n        # cls_ids (B, N, C), scores (B, N, C)\n        cls_ids, scores = 
self.cls_decoder(F.softmax(cls_pred, axis=-1))\n        # cls_ids, scores (B, N, C) -> (B, C, N) -> (B, C, N, 1)\n        cls_ids = cls_ids.transpose((0, 2, 1)).reshape((0, 0, 0, 1))\n        scores = scores.transpose((0, 2, 1)).reshape((0, 0, 0, 1))\n        # box_pred (B, N, C, 4) -> (B, C, N, 4)\n        box_pred = box_pred.transpose((0, 2, 1, 3))\n\n        # rpn_boxes (B, N, 4) -> B * (1, N, 4)\n        rpn_boxes = _split(\n            rpn_box, axis=0, num_outputs=batch_size, squeeze_axis=False\n        )\n        # cls_ids, scores (B, C, N, 1) -> B * (C, N, 1)\n        cls_ids = _split(\n            cls_ids, axis=0, num_outputs=batch_size, squeeze_axis=True\n        )\n        scores = _split(\n            scores, axis=0, num_outputs=batch_size, squeeze_axis=True\n        )\n        # box_preds (B, C, N, 4) -> B * (C, N, 4)\n        box_preds = _split(\n            box_pred, axis=0, num_outputs=batch_size, squeeze_axis=True\n        )\n\n        # per batch predict, nms, each class has topk outputs\n        results = []\n        # add feat index\n        if self._additional_output:\n            sizes = scores[0].shape[0:2]\n            # ind = mx.nd.array(list(range(sizes[1])))\n            ind = mx.nd.linspace(0, 999, 1000)\n            ind = mx.nd.repeat(ind, repeats=sizes[0])\n            ind = (\n                ind.reshape(sizes[1], sizes[0])\n                .transpose((1, 0))\n                .expand_dims(axis=2)\n            )\n        for rpn_box, cls_id, score, box_pred in zip(\n            rpn_boxes, cls_ids, scores, box_preds\n        ):\n            # box_pred (C, N, 4) rpn_box (1, N, 4) -> bbox (C, N, 4)\n            bbox = self.box_decoder(box_pred, rpn_box)\n            if self._additional_output:\n                # res (C, N, 7)\n                res = F.concat(*[cls_id, score, bbox, ind], dim=-1)\n            else:\n                # res (C, N, 6)\n                res = F.concat(*[cls_id, score, bbox], dim=-1)\n            if 
self.force_nms:\n                # res (1, C*N, 6), to allow cross-catogory suppression\n                res = res.reshape((1, -1, 0))\n            # res (C, self.nms_topk, 6)\n            res = F.contrib.box_nms(\n                res,\n                overlap_thresh=self.nms_thresh,\n                topk=self.nms_topk,\n                valid_thresh=0.001,\n                id_index=0,\n                score_index=1,\n                coord_start=2,\n                force_suppress=self.force_nms,\n            )\n            # res (C * self.nms_topk, 6)\n            res = res.reshape((-3, 0))\n            results.append(res)\n\n        # result B * (C * topk, 6) -> (B, C * topk, 6)\n        result = F.stack(*results, axis=0)\n        ids = F.slice_axis(result, axis=-1, begin=0, end=1)\n        scores = F.slice_axis(result, axis=-1, begin=1, end=2)\n        bboxes = F.slice_axis(result, axis=-1, begin=2, end=6)\n        if self._additional_output:\n            feat_ind = F.slice_axis(result, axis=-1, begin=6, end=7)\n            spatial_feat = (\n                top_feat.mean(axis=1).expand_dims(0).reshape(batch_size, 0, -1)\n            )\n            return ids, scores, bboxes, feat, feat_ind, spatial_feat\n        return ids, scores, bboxes\n\n\ndef get_faster_rcnn(\n    name,\n    dataset,\n    pretrained=False,\n    ctx=mx.cpu(),\n    root=os.path.join(\"~\", \".mxnet\", \"models\"),\n    **kwargs\n):\n    r\"\"\"Utility function to return faster rcnn networks.\n\n    Parameters\n    ----------\n    name : str\n        Model name.\n    dataset : str\n        The name of dataset.\n    pretrained : bool or str\n        Boolean value controls whether to load the default pretrained weights for model.\n        String value represents the hashtag for a certain version of pretrained weights.\n    ctx : mxnet.Context\n        Context such as mx.cpu(), mx.gpu(0).\n    root : str\n        Model weights storing path.\n\n    Returns\n    -------\n    mxnet.gluon.HybridBlock\n 
       The Faster-RCNN network.\n\n    \"\"\"\n    net = FasterRCNN(**kwargs)\n    if pretrained:\n        from gluoncv.model_zoo.model_store import get_model_file\n\n        full_name = \"_\".join((\"faster_rcnn\", name, dataset))\n        net.load_parameters(\n            get_model_file(full_name, tag=pretrained, root=root),\n            ctx=ctx,\n            ignore_extra=True,\n            allow_missing=True,\n        )\n    else:\n        for v in net.collect_params().values():\n            try:\n                v.reset_ctx(ctx)\n            except ValueError:\n                pass\n    return net\n\n\ndef faster_rcnn_resnet50_v1b_coco(\n    pretrained=False, pretrained_base=True, **kwargs\n):\n    r\"\"\"Faster RCNN model from the paper\n    \"Ren, S., He, K., Girshick, R., & Sun, J. (2015). Faster r-cnn: Towards\n    real-time object detection with region proposal networks\"\n\n    Parameters\n    ----------\n    pretrained : bool or str\n        Boolean value controls whether to load the default pretrained weights for model.\n        String value represents the hashtag for a certain version of pretrained weights.\n    pretrained_base : bool or str, optional, default is True\n        Load pretrained base network, the extra layers are randomized. 
Note that\n        if pretrained is `True`, this has no effect.\n    ctx : Context, default CPU\n        The context in which to load the pretrained weights.\n    root : str, default '~/.mxnet/models'\n        Location for keeping the model parameters.\n\n    Examples\n    --------\n    >>> model = get_faster_rcnn_resnet50_v1b_coco(pretrained=True)\n    >>> print(model)\n    \"\"\"\n    from gluoncv.data import COCODetection\n    from gluoncv.model_zoo.resnetv1b import resnet50_v1b\n\n    classes = COCODetection.CLASSES\n    pretrained_base = False if pretrained else pretrained_base\n    base_network = resnet50_v1b(\n        pretrained=pretrained_base,\n        dilated=False,\n        use_global_stats=True,\n        **kwargs\n    )\n    features = nn.HybridSequential()\n    top_features = nn.HybridSequential()\n    for layer in [\n        \"conv1\",\n        \"bn1\",\n        \"relu\",\n        \"maxpool\",\n        \"layer1\",\n        \"layer2\",\n        \"layer3\",\n    ]:\n        features.add(getattr(base_network, layer))\n    for layer in [\"layer4\"]:\n        top_features.add(getattr(base_network, layer))\n    train_patterns = \"|\".join(\n        [\".*dense\", \".*rpn\", \".*down(2|3|4)_conv\", \".*layers(2|3|4)_conv\"]\n    )\n    return get_faster_rcnn(\n        name=\"resnet50_v1b\",\n        dataset=\"coco\",\n        pretrained=pretrained,\n        features=features,\n        top_features=top_features,\n        classes=classes,\n        short=800,\n        max_size=1333,\n        train_patterns=train_patterns,\n        nms_thresh=0.7,\n        nms_topk=-1,\n        post_nms=-1,\n        roi_mode=\"align\",\n        roi_size=(14, 14),\n        strides=16,\n        clip=4.14,\n        rpn_channel=1024,\n        base_size=16,\n        scales=(2, 4, 8, 16, 32),\n        ratios=(0.5, 1, 2),\n        alloc_size=(128, 128),\n        rpn_nms_thresh=0.7,\n        rpn_train_pre_nms=12000,\n        rpn_train_post_nms=2000,\n        rpn_test_pre_nms=6000,\n      
  rpn_test_post_nms=1000,\n        rpn_min_size=1,\n        num_sample=128,\n        pos_iou_thresh=0.5,\n        pos_ratio=0.25,\n        max_num_gt=3000,\n        **kwargs\n    )\n\n\ndef faster_rcnn_resnet50_v1b_custom(\n    classes, transfer=None, pretrained_base=True, pretrained=False, **kwargs\n):\n    r\"\"\"Faster RCNN model with resnet50_v1b base network on custom dataset.\n\n    Parameters\n    ----------\n    classes : iterable of str\n        Names of custom foreground classes. `len(classes)` is the number of foreground classes.\n    transfer : str or None\n        If not `None`, will try to reuse pre-trained weights from faster RCNN networks trained\n        on other datasets.\n    pretrained : bool or str\n        Boolean value controls whether to load the default pretrained weights for model.\n        String value represents the hashtag for a certain version of pretrained weights.\n    pretrained_base : bool or str\n        Boolean value controls whether to load the default pretrained weights for model.\n        String value represents the hashtag for a certain version of pretrained weights.\n    ctx : Context, default CPU\n        The context in which to load the pretrained weights.\n    root : str, default '~/.mxnet/models'\n        Location for keeping the model parameters.\n\n    Returns\n    -------\n    mxnet.gluon.HybridBlock\n        Hybrid faster RCNN network.\n    \"\"\"\n    if pretrained:\n        warnings.warn(\n            \"Custom models don't provide `pretrained` weights, ignored.\"\n        )\n    if transfer is None:\n        from gluoncv.model_zoo.resnetv1b import resnet50_v1b\n\n        base_network = resnet50_v1b(\n            pretrained=pretrained_base,\n            dilated=False,\n            use_global_stats=True,\n            **kwargs\n        )\n        features = nn.HybridSequential()\n        top_features = nn.HybridSequential()\n        for layer in [\n            \"conv1\",\n            \"bn1\",\n            \"relu\",\n  
          \"maxpool\",\n            \"layer1\",\n            \"layer2\",\n            \"layer3\",\n        ]:\n            features.add(getattr(base_network, layer))\n        for layer in [\"layer4\"]:\n            top_features.add(getattr(base_network, layer))\n        train_patterns = \"|\".join(\n            [\".*dense\", \".*rpn\", \".*down(2|3|4)_conv\", \".*layers(2|3|4)_conv\"]\n        )\n        return get_faster_rcnn(\n            name=\"resnet50_v1b\",\n            dataset=\"custom\",\n            pretrained=pretrained,\n            features=features,\n            top_features=top_features,\n            classes=classes,\n            short=600,\n            max_size=1000,\n            train_patterns=train_patterns,\n            nms_thresh=0.7,\n            nms_topk=400,\n            post_nms=100,\n            roi_mode=\"align\",\n            roi_size=(14, 14),\n            strides=16,\n            clip=4.14,\n            rpn_channel=1024,\n            base_size=16,\n            scales=(2, 4, 8, 16, 32),\n            ratios=(0.5, 1, 2),\n            alloc_size=(128, 128),\n            rpn_nms_thresh=0.7,\n            rpn_train_pre_nms=12000,\n            rpn_train_post_nms=2000,\n            rpn_test_pre_nms=6000,\n            rpn_test_post_nms=300,\n            rpn_min_size=16,\n            num_sample=128,\n            pos_iou_thresh=0.5,\n            pos_ratio=0.25,\n            max_num_gt=3000,\n            **kwargs\n        )\n    else:\n        from gluoncv.model_zoo import get_model\n\n        net = get_model(\n            \"faster_rcnn_resnet50_v1b_\" + str(transfer),\n            pretrained=True,\n            **kwargs\n        )\n        reuse_classes = [x for x in classes if x in net.classes]\n        net.reset_class(classes, reuse_weights=reuse_classes)\n    return net\n\n\ndef faster_rcnn_resnet101_v1d_coco(\n    pretrained=False, pretrained_base=True, **kwargs\n):\n    r\"\"\"Faster RCNN model from the paper\n    \"Ren, S., He, K., Girshick, 
R., & Sun, J. (2015). Faster r-cnn: Towards\n    real-time object detection with region proposal networks\"\n\n    Parameters\n    ----------\n    pretrained : bool, optional, default is False\n        Load pretrained weights.\n    pretrained_base : bool or str, optional, default is True\n        Load pretrained base network, the extra layers are randomized. Note that\n        if pretrained is `True`, this has no effect.\n    ctx : Context, default CPU\n        The context in which to load the pretrained weights.\n    root : str, default '~/.mxnet/models'\n        Location for keeping the model parameters.\n\n    Examples\n    --------\n    >>> model = get_faster_rcnn_resnet101_v1d_coco(pretrained=True)\n    >>> print(model)\n    \"\"\"\n    from gluoncv.data import COCODetection\n    from gluoncv.model_zoo.resnetv1b import resnet101_v1d\n\n    classes = COCODetection.CLASSES\n    pretrained_base = False if pretrained else pretrained_base\n    base_network = resnet101_v1d(\n        pretrained=pretrained_base,\n        dilated=False,\n        use_global_stats=True,\n        **kwargs\n    )\n    features = nn.HybridSequential()\n    top_features = nn.HybridSequential()\n    for layer in [\n        \"conv1\",\n        \"bn1\",\n        \"relu\",\n        \"maxpool\",\n        \"layer1\",\n        \"layer2\",\n        \"layer3\",\n    ]:\n        features.add(getattr(base_network, layer))\n    for layer in [\"layer4\"]:\n        top_features.add(getattr(base_network, layer))\n    train_patterns = \"|\".join(\n        [\".*dense\", \".*rpn\", \".*down(2|3|4)_conv\", \".*layers(2|3|4)_conv\"]\n    )\n    return get_faster_rcnn(\n        name=\"resnet101_v1d\",\n        dataset=\"coco\",\n        pretrained=pretrained,\n        features=features,\n        top_features=top_features,\n        classes=classes,\n        short=800,\n        max_size=1333,\n        train_patterns=train_patterns,\n        nms_thresh=0.5,\n        nms_topk=-1,\n        post_nms=100,\n        
roi_mode=\"align\",\n        roi_size=(14, 14),\n        strides=16,\n        clip=4.14,\n        rpn_channel=1024,\n        base_size=16,\n        scales=(2, 4, 8, 16, 32),\n        ratios=(0.5, 1, 2),\n        alloc_size=(128, 128),\n        rpn_nms_thresh=0.7,\n        rpn_train_pre_nms=12000,\n        rpn_train_post_nms=2000,\n        rpn_test_pre_nms=6000,\n        rpn_test_post_nms=1000,\n        rpn_min_size=1,\n        num_sample=128,\n        pos_iou_thresh=0.5,\n        pos_ratio=0.25,\n        max_num_gt=3000,\n        **kwargs\n    )\n\n\ndef faster_rcnn_resnet101_v1d_custom(\n    classes, transfer=None, pretrained_base=True, pretrained=False, **kwargs\n):\n    r\"\"\"Faster RCNN model with resnet101_v1d base network on custom dataset.\n\n    Parameters\n    ----------\n    classes : iterable of str\n        Names of custom foreground classes. `len(classes)` is the number of foreground classes.\n    transfer : str or None\n        If not `None`, will try to reuse pre-trained weights from faster RCNN networks trained\n        on other datasets.\n    pretrained_base : bool or str\n        Boolean value controls whether to load the default pretrained weights for model.\n        String value represents the hashtag for a certain version of pretrained weights.\n    ctx : Context, default CPU\n        The context in which to load the pretrained weights.\n    root : str, default '~/.mxnet/models'\n        Location for keeping the model parameters.\n\n    Returns\n    -------\n    mxnet.gluon.HybridBlock\n        Hybrid faster RCNN network.\n    \"\"\"\n    if pretrained:\n        warnings.warn(\n            \"Custom models don't provide `pretrained` weights, ignored.\"\n        )\n    if transfer is None:\n        from gluoncv.model_zoo.resnetv1b import resnet101_v1d\n\n        base_network = resnet101_v1d(\n            pretrained=pretrained_base,\n            dilated=False,\n            use_global_stats=True,\n            **kwargs\n        )\n        features 
= nn.HybridSequential()\n        top_features = nn.HybridSequential()\n        for layer in [\n            \"conv1\",\n            \"bn1\",\n            \"relu\",\n            \"maxpool\",\n            \"layer1\",\n            \"layer2\",\n            \"layer3\",\n        ]:\n            features.add(getattr(base_network, layer))\n        for layer in [\"layer4\"]:\n            top_features.add(getattr(base_network, layer))\n        train_patterns = \"|\".join(\n            [\".*dense\", \".*rpn\", \".*down(2|3|4)_conv\", \".*layers(2|3|4)_conv\"]\n        )\n        return get_faster_rcnn(\n            name=\"resnet101_v1d\",\n            dataset=\"custom\",\n            pretrained=pretrained,\n            features=features,\n            top_features=top_features,\n            classes=classes,\n            short=600,\n            max_size=1000,\n            train_patterns=train_patterns,\n            nms_thresh=0.5,\n            nms_topk=400,\n            post_nms=100,\n            roi_mode=\"align\",\n            roi_size=(14, 14),\n            strides=16,\n            clip=4.14,\n            rpn_channel=1024,\n            base_size=16,\n            scales=(2, 4, 8, 16, 32),\n            ratios=(0.5, 1, 2),\n            alloc_size=(128, 128),\n            rpn_nms_thresh=0.7,\n            rpn_train_pre_nms=12000,\n            rpn_train_post_nms=2000,\n            rpn_test_pre_nms=6000,\n            rpn_test_post_nms=300,\n            rpn_min_size=16,\n            num_sample=128,\n            pos_iou_thresh=0.5,\n            pos_ratio=0.25,\n            max_num_gt=3000,\n            **kwargs\n        )\n    else:\n        net = faster_rcnn_resnet101_v1d_coco(pretrained=True)\n        reuse_classes = [x for x in classes if x in net.classes]\n        net.reset_class(classes, reuse_weights=reuse_classes)\n    return net\n"
  },
  {
    "path": "examples/mxnet/scenegraph/model/reldn.py",
    "content": "import pickle\n\nimport dgl\n\nimport gluoncv as gcv\nimport mxnet as mx\nimport numpy as np\nfrom dgl.nn.mxnet import GraphConv\nfrom dgl.utils import toindex\nfrom mxnet import nd\nfrom mxnet.gluon import nn\n\n__all__ = [\"RelDN\"]\n\n\nclass EdgeConfMLP(nn.Block):\n    \"\"\"compute the confidence for edges\"\"\"\n\n    def __init__(self):\n        super(EdgeConfMLP, self).__init__()\n\n    def forward(self, edges):\n        score_pred = nd.log_softmax(edges.data[\"preds\"])[:, 1:].max(axis=1)\n        score_phr = (\n            score_pred\n            + edges.src[\"node_class_logit\"]\n            + edges.dst[\"node_class_logit\"]\n        )\n        return {\"score_pred\": score_pred, \"score_phr\": score_phr}\n\n\nclass EdgeBBoxExtend(nn.Block):\n    \"\"\"encode the bounding boxes\"\"\"\n\n    def __init__(self):\n        super(EdgeBBoxExtend, self).__init__()\n\n    def bbox_delta(self, bbox_a, bbox_b):\n        n = bbox_a.shape[0]\n        result = nd.zeros((n, 4), ctx=bbox_a.context)\n        result[:, 0] = bbox_a[:, 0] - bbox_b[:, 0]\n        result[:, 1] = bbox_a[:, 1] - bbox_b[:, 1]\n        result[:, 2] = nd.log(\n            (bbox_a[:, 2] - bbox_a[:, 0] + 1e-8)\n            / (bbox_b[:, 2] - bbox_b[:, 0] + 1e-8)\n        )\n        result[:, 3] = nd.log(\n            (bbox_a[:, 3] - bbox_a[:, 1] + 1e-8)\n            / (bbox_b[:, 3] - bbox_b[:, 1] + 1e-8)\n        )\n        return result\n\n    def forward(self, edges):\n        ctx = edges.src[\"pred_bbox\"].context\n        n = edges.src[\"pred_bbox\"].shape[0]\n        delta_src_obj = self.bbox_delta(\n            edges.src[\"pred_bbox\"], edges.dst[\"pred_bbox\"]\n        )\n        delta_src_rel = self.bbox_delta(\n            edges.src[\"pred_bbox\"], edges.data[\"rel_bbox\"]\n        )\n        delta_rel_obj = self.bbox_delta(\n            edges.data[\"rel_bbox\"], edges.dst[\"pred_bbox\"]\n        )\n        result = nd.zeros((n, 12), ctx=ctx)\n        result[:, 0:4] = 
delta_src_obj\n        result[:, 4:8] = delta_src_rel\n        result[:, 8:12] = delta_rel_obj\n        return {\"pred_bbox_additional\": result}\n\n\nclass EdgeFreqPrior(nn.Block):\n    \"\"\"make use of the pre-trained frequency prior\"\"\"\n\n    def __init__(self, prior_pkl):\n        super(EdgeFreqPrior, self).__init__()\n        with open(prior_pkl, \"rb\") as f:\n            freq_prior = pickle.load(f)\n        self.freq_prior = freq_prior\n\n    def forward(self, edges):\n        ctx = edges.src[\"node_class_pred\"].context\n        src_ind = edges.src[\"node_class_pred\"].asnumpy().astype(int)\n        dst_ind = edges.dst[\"node_class_pred\"].asnumpy().astype(int)\n        prob = self.freq_prior[src_ind, dst_ind]\n        out = nd.array(prob, ctx=ctx)\n        return {\"freq_prior\": out}\n\n\nclass EdgeSpatial(nn.Block):\n    \"\"\"spatial feature branch\"\"\"\n\n    def __init__(self, n_classes):\n        super(EdgeSpatial, self).__init__()\n        self.mlp = nn.Sequential()\n        self.mlp.add(nn.Dense(64))\n        self.mlp.add(nn.LeakyReLU(0.1))\n        self.mlp.add(nn.Dense(64))\n        self.mlp.add(nn.LeakyReLU(0.1))\n        self.mlp.add(nn.Dense(n_classes))\n\n    def forward(self, edges):\n        feat = nd.concat(\n            edges.src[\"pred_bbox\"],\n            edges.dst[\"pred_bbox\"],\n            edges.data[\"rel_bbox\"],\n            edges.data[\"pred_bbox_additional\"],\n        )\n        out = self.mlp(feat)\n        return {\"spatial\": out}\n\n\nclass EdgeVisual(nn.Block):\n    \"\"\"visual feature branch\"\"\"\n\n    def __init__(self, n_classes, vis_feat_dim=7 * 7 * 3):\n        super(EdgeVisual, self).__init__()\n        self.dim_in = vis_feat_dim\n        self.mlp_joint = nn.Sequential()\n        self.mlp_joint.add(nn.Dense(vis_feat_dim // 2))\n        self.mlp_joint.add(nn.LeakyReLU(0.1))\n        self.mlp_joint.add(nn.Dense(vis_feat_dim // 3))\n        self.mlp_joint.add(nn.LeakyReLU(0.1))\n        
self.mlp_joint.add(nn.Dense(n_classes))\n\n        self.mlp_sub = nn.Dense(n_classes)\n        self.mlp_ob = nn.Dense(n_classes)\n\n    def forward(self, edges):\n        feat = nd.concat(\n            edges.src[\"node_feat\"],\n            edges.dst[\"node_feat\"],\n            edges.data[\"edge_feat\"],\n        )\n        out_joint = self.mlp_joint(feat)\n        out_sub = self.mlp_sub(edges.src[\"node_feat\"])\n        out_ob = self.mlp_ob(edges.dst[\"node_feat\"])\n        out = out_joint + out_sub + out_ob\n        return {\"visual\": out}\n\n\nclass RelDN(nn.Block):\n    \"\"\"The RelDN Model\"\"\"\n\n    def __init__(self, n_classes, prior_pkl, semantic_only=False):\n        super(RelDN, self).__init__()\n        # output layers\n        self.edge_bbox_extend = EdgeBBoxExtend()\n        # semantic through mlp encoding\n        if prior_pkl is not None:\n            self.freq_prior = EdgeFreqPrior(prior_pkl)\n\n        # with predicate class and a link class\n        self.spatial = EdgeSpatial(n_classes + 1)\n        # with visual features\n        self.visual = EdgeVisual(n_classes + 1)\n        self.edge_conf_mlp = EdgeConfMLP()\n        self.semantic_only = semantic_only\n\n    def forward(self, g):\n        if g is None or g.number_of_nodes() == 0:\n            return g\n        # predictions\n        g.apply_edges(self.freq_prior)\n        if self.semantic_only:\n            g.edata[\"preds\"] = g.edata[\"freq_prior\"]\n        else:\n            # bbox extension\n            g.apply_edges(self.edge_bbox_extend)\n            g.apply_edges(self.spatial)\n            g.apply_edges(self.visual)\n            g.edata[\"preds\"] = (\n                g.edata[\"freq_prior\"] + g.edata[\"spatial\"] + g.edata[\"visual\"]\n            )\n        # subgraph for gconv\n        g.apply_edges(self.edge_conf_mlp)\n        return g\n"
  },
  {
    "path": "examples/mxnet/scenegraph/train_faster_rcnn.py",
    "content": "\"\"\"Train Faster-RCNN end to end.\"\"\"\nimport argparse\nimport os\n\n# disable autotune\nos.environ[\"MXNET_CUDNN_AUTOTUNE_DEFAULT\"] = \"0\"\nimport logging\nimport time\n\nimport gluoncv as gcv\nimport mxnet as mx\nimport numpy as np\nfrom data import *\nfrom gluoncv import data as gdata, utils as gutils\nfrom gluoncv.data.batchify import Append, FasterRCNNTrainBatchify, Tuple\nfrom gluoncv.data.transforms.presets.rcnn import (\n    FasterRCNNDefaultTrainTransform,\n    FasterRCNNDefaultValTransform,\n)\nfrom gluoncv.model_zoo import get_model\nfrom gluoncv.utils.metrics.coco_detection import COCODetectionMetric\nfrom gluoncv.utils.metrics.rcnn import (\n    RCNNAccMetric,\n    RCNNL1LossMetric,\n    RPNAccMetric,\n    RPNL1LossMetric,\n)\nfrom gluoncv.utils.metrics.voc_detection import VOC07MApMetric\nfrom gluoncv.utils.parallel import Parallel, Parallelizable\nfrom model import (\n    faster_rcnn_resnet101_v1d_custom,\n    faster_rcnn_resnet50_v1b_custom,\n)\nfrom mxnet import autograd, gluon\nfrom mxnet.contrib import amp\n\ntry:\n    import horovod.mxnet as hvd\nexcept ImportError:\n    hvd = None\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Train Faster-RCNN networks e2e.\"\n    )\n    parser.add_argument(\n        \"--network\",\n        type=str,\n        default=\"resnet101_v1d\",\n        help=\"Base network name which serves as feature extraction base.\",\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"visualgenome\",\n        help=\"Training dataset. 
Now support voc and coco.\",\n    )\n    parser.add_argument(\n        \"--num-workers\",\n        \"-j\",\n        dest=\"num_workers\",\n        type=int,\n        default=8,\n        help=\"Number of data workers, you can use larger \"\n        \"number to accelerate data loading, \"\n        \"if your CPU and GPUs are powerful.\",\n    )\n    parser.add_argument(\n        \"--batch-size\", type=int, default=8, help=\"Training mini-batch size.\"\n    )\n    parser.add_argument(\n        \"--gpus\",\n        type=str,\n        default=\"0\",\n        help=\"Training with GPUs, you can specify 1,3 for example.\",\n    )\n    parser.add_argument(\n        \"--epochs\", type=str, default=\"\", help=\"Training epochs.\"\n    )\n    parser.add_argument(\n        \"--resume\",\n        type=str,\n        default=\"\",\n        help=\"Resume from previously saved parameters if not None. \"\n        \"For example, you can resume from ./faster_rcnn_xxx_0123.params\",\n    )\n    parser.add_argument(\n        \"--start-epoch\",\n        type=int,\n        default=0,\n        help=\"Starting epoch for resuming, default is 0 for new training.\"\n        \"You can specify it to 100 for example to start from 100 epoch.\",\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=str,\n        default=\"\",\n        help=\"Learning rate, default is 0.001 for voc single gpu training.\",\n    )\n    parser.add_argument(\n        \"--lr-decay\",\n        type=float,\n        default=0.1,\n        help=\"decay rate of learning rate. default is 0.1.\",\n    )\n    parser.add_argument(\n        \"--lr-decay-epoch\",\n        type=str,\n        default=\"\",\n        help=\"epochs at which learning rate decays. 
default is 14,20 for voc.\",\n    )\n    parser.add_argument(\n        \"--lr-warmup\",\n        type=str,\n        default=\"\",\n        help=\"warmup iterations to adjust learning rate, default is 0 for voc.\",\n    )\n    parser.add_argument(\n        \"--lr-warmup-factor\",\n        type=float,\n        default=1.0 / 3.0,\n        help=\"warmup factor of base lr.\",\n    )\n    parser.add_argument(\n        \"--momentum\",\n        type=float,\n        default=0.9,\n        help=\"SGD momentum, default is 0.9\",\n    )\n    parser.add_argument(\n        \"--wd\",\n        type=str,\n        default=\"\",\n        help=\"Weight decay, default is 5e-4 for voc\",\n    )\n    parser.add_argument(\n        \"--log-interval\",\n        type=int,\n        default=100,\n        help=\"Logging mini-batch interval. Default is 100.\",\n    )\n    parser.add_argument(\n        \"--save-prefix\", type=str, default=\"\", help=\"Saving parameter prefix\"\n    )\n    parser.add_argument(\n        \"--save-interval\",\n        type=int,\n        default=1,\n        help=\"Saving parameters epoch interval, best model will always be saved.\",\n    )\n    parser.add_argument(\n        \"--val-interval\",\n        type=int,\n        default=1,\n        help=\"Epoch interval for validation, increase the number will reduce the \"\n        \"training time if validation is slow.\",\n    )\n    parser.add_argument(\n        \"--seed\", type=int, default=233, help=\"Random seed to be fixed.\"\n    )\n    parser.add_argument(\n        \"--verbose\",\n        dest=\"verbose\",\n        action=\"store_true\",\n        help=\"Print helpful debugging info once set.\",\n    )\n    parser.add_argument(\n        \"--mixup\", action=\"store_true\", help=\"Use mixup training.\"\n    )\n    parser.add_argument(\n        \"--no-mixup-epochs\",\n        type=int,\n        default=20,\n        help=\"Disable mixup training if enabled in the last N epochs.\",\n    )\n\n    # Norm layer options\n    
parser.add_argument(\n        \"--norm-layer\",\n        type=str,\n        default=None,\n        help=\"Type of normalization layer to use. \"\n        \"If set to None, backbone normalization layer will be fixed,\"\n        \" and no normalization layer will be used. \"\n        \"Currently supports 'bn', and None, default is None.\"\n        \"Note that if horovod is enabled, sync bn will not work correctly.\",\n    )\n\n    # FPN options\n    parser.add_argument(\n        \"--use-fpn\",\n        action=\"store_true\",\n        help=\"Whether to use feature pyramid network.\",\n    )\n\n    # Performance options\n    parser.add_argument(\n        \"--disable-hybridization\",\n        action=\"store_true\",\n        help=\"Whether to disable hybridize the model. \"\n        \"Memory usage and speed will decrese.\",\n    )\n    parser.add_argument(\n        \"--static-alloc\",\n        action=\"store_true\",\n        help=\"Whether to use static memory allocation. Memory usage will increase.\",\n    )\n    parser.add_argument(\n        \"--amp\",\n        action=\"store_true\",\n        help=\"Use MXNet AMP for mixed precision training.\",\n    )\n    parser.add_argument(\n        \"--horovod\",\n        action=\"store_true\",\n        help=\"Use MXNet Horovod for distributed training. Must be run with OpenMPI. \"\n        \"--gpus is ignored when using --horovod.\",\n    )\n    parser.add_argument(\n        \"--executor-threads\",\n        type=int,\n        default=1,\n        help=\"Number of threads for executor for scheduling ops. \"\n        \"More threads may incur higher GPU memory footprint, \"\n        \"but may speed up throughput. Note that when horovod is used, \"\n        \"it is set to 1.\",\n    )\n    parser.add_argument(\n        \"--kv-store\",\n        type=str,\n        default=\"nccl\",\n        help=\"KV store options. 
local, device, nccl, dist_sync, dist_device_sync, \"\n        \"dist_async are available.\",\n    )\n\n    args = parser.parse_args()\n\n    if args.horovod:\n        if hvd is None:\n            raise SystemExit(\n                \"Horovod not found, please check if you installed it correctly.\"\n            )\n        hvd.init()\n\n    if args.dataset == \"voc\":\n        args.epochs = int(args.epochs) if args.epochs else 20\n        args.lr_decay_epoch = (\n            args.lr_decay_epoch if args.lr_decay_epoch else \"14,20\"\n        )\n        args.lr = float(args.lr) if args.lr else 0.001\n        args.lr_warmup = args.lr_warmup if args.lr_warmup else -1\n        args.wd = float(args.wd) if args.wd else 5e-4\n    elif args.dataset == \"visualgenome\":\n        args.epochs = int(args.epochs) if args.epochs else 20\n        args.lr_decay_epoch = (\n            args.lr_decay_epoch if args.lr_decay_epoch else \"14,20\"\n        )\n        args.lr = float(args.lr) if args.lr else 0.001\n        args.lr_warmup = args.lr_warmup if args.lr_warmup else -1\n        args.wd = float(args.wd) if args.wd else 5e-4\n    elif args.dataset == \"coco\":\n        args.epochs = int(args.epochs) if args.epochs else 26\n        args.lr_decay_epoch = (\n            args.lr_decay_epoch if args.lr_decay_epoch else \"17,23\"\n        )\n        args.lr = float(args.lr) if args.lr else 0.01\n        args.lr_warmup = args.lr_warmup if args.lr_warmup else 1000\n        args.wd = float(args.wd) if args.wd else 1e-4\n    return args\n\n\ndef get_dataset(dataset, args):\n    if dataset.lower() == \"voc\":\n        train_dataset = gdata.VOCDetection(\n            splits=[(2007, \"trainval\"), (2012, \"trainval\")]\n        )\n        val_dataset = gdata.VOCDetection(splits=[(2007, \"test\")])\n        val_metric = VOC07MApMetric(\n            iou_thresh=0.5, class_names=val_dataset.classes\n        )\n    elif dataset.lower() == \"coco\":\n        train_dataset = gdata.COCODetection(\n       
     splits=\"instances_train2017\", use_crowd=False\n        )\n        val_dataset = gdata.COCODetection(\n            splits=\"instances_val2017\", skip_empty=False\n        )\n        val_metric = COCODetectionMetric(\n            val_dataset, args.save_prefix + \"_eval\", cleanup=True\n        )\n    elif dataset.lower() == \"visualgenome\":\n        train_dataset = VGObject(\n            root=os.path.join(\"~\", \".mxnet\", \"datasets\", \"visualgenome\"),\n            splits=\"detections_train\",\n            use_crowd=False,\n        )\n        val_dataset = VGObject(\n            root=os.path.join(\"~\", \".mxnet\", \"datasets\", \"visualgenome\"),\n            splits=\"detections_val\",\n            skip_empty=False,\n        )\n        val_metric = COCODetectionMetric(\n            val_dataset, args.save_prefix + \"_eval\", cleanup=True\n        )\n    else:\n        raise NotImplementedError(\n            \"Dataset: {} not implemented.\".format(dataset)\n        )\n    if args.mixup:\n        from gluoncv.data.mixup import detection\n\n        train_dataset = detection.MixupDetection(train_dataset)\n    return train_dataset, val_dataset, val_metric\n\n\ndef get_dataloader(\n    net,\n    train_dataset,\n    val_dataset,\n    train_transform,\n    val_transform,\n    batch_size,\n    num_shards,\n    args,\n):\n    \"\"\"Get dataloader.\"\"\"\n    train_bfn = FasterRCNNTrainBatchify(net, num_shards)\n    if hasattr(train_dataset, \"get_im_aspect_ratio\"):\n        im_aspect_ratio = train_dataset.get_im_aspect_ratio()\n    else:\n        im_aspect_ratio = [1.0] * len(train_dataset)\n    train_sampler = gcv.nn.sampler.SplitSortedBucketSampler(\n        im_aspect_ratio,\n        batch_size,\n        num_parts=hvd.size() if args.horovod else 1,\n        part_index=hvd.rank() if args.horovod else 0,\n        shuffle=True,\n    )\n    train_loader = mx.gluon.data.DataLoader(\n        train_dataset.transform(\n            train_transform(\n                
net.short,\n                net.max_size,\n                net,\n                ashape=net.ashape,\n                multi_stage=args.use_fpn,\n            )\n        ),\n        batch_sampler=train_sampler,\n        batchify_fn=train_bfn,\n        num_workers=args.num_workers,\n    )\n    if val_dataset is None:\n        val_loader = None\n    else:\n        val_bfn = Tuple(*[Append() for _ in range(3)])\n        short = (\n            net.short[-1] if isinstance(net.short, (tuple, list)) else net.short\n        )\n        # validation use 1 sample per device\n        val_loader = mx.gluon.data.DataLoader(\n            val_dataset.transform(val_transform(short, net.max_size)),\n            num_shards,\n            False,\n            batchify_fn=val_bfn,\n            last_batch=\"keep\",\n            num_workers=args.num_workers,\n        )\n    return train_loader, val_loader\n\n\ndef save_params(\n    net, logger, best_map, current_map, epoch, save_interval, prefix\n):\n    current_map = float(current_map)\n    if current_map > best_map[0]:\n        logger.info(\n            \"[Epoch {}] mAP {} higher than current best {} saving to {}\".format(\n                epoch, current_map, best_map, \"{:s}_best.params\".format(prefix)\n            )\n        )\n        best_map[0] = current_map\n        net.save_parameters(\"{:s}_best.params\".format(prefix))\n        with open(prefix + \"_best_map.log\", \"a\") as f:\n            f.write(\"{:04d}:\\t{:.4f}\\n\".format(epoch, current_map))\n    if save_interval and (epoch + 1) % save_interval == 0:\n        logger.info(\n            \"[Epoch {}] Saving parameters to {}\".format(\n                epoch,\n                \"{:s}_{:04d}_{:.4f}.params\".format(prefix, epoch, current_map),\n            )\n        )\n        net.save_parameters(\n            \"{:s}_{:04d}_{:.4f}.params\".format(prefix, epoch, current_map)\n        )\n\n\ndef split_and_load(batch, ctx_list):\n    \"\"\"Split data to 1 batch each device.\"\"\"\n  
  new_batch = []\n    for i, data in enumerate(batch):\n        if isinstance(data, (list, tuple)):\n            new_data = [x.as_in_context(ctx) for x, ctx in zip(data, ctx_list)]\n        else:\n            new_data = [data.as_in_context(ctx_list[0])]\n        new_batch.append(new_data)\n    return new_batch\n\n\ndef validate(net, val_data, ctx, eval_metric, args):\n    \"\"\"Test on validation dataset.\"\"\"\n    clipper = gcv.nn.bbox.BBoxClipToImage()\n    eval_metric.reset()\n    if not args.disable_hybridization:\n        # input format is differnet than training, thus rehybridization is needed.\n        net.hybridize(static_alloc=args.static_alloc)\n    for i, batch in enumerate(val_data):\n        batch = split_and_load(batch, ctx_list=ctx)\n        det_bboxes = []\n        det_ids = []\n        det_scores = []\n        gt_bboxes = []\n        gt_ids = []\n        gt_difficults = []\n        for x, y, im_scale in zip(*batch):\n            # get prediction results\n            ids, scores, bboxes = net(x)\n            det_ids.append(ids)\n            det_scores.append(scores)\n            # clip to image size\n            det_bboxes.append(clipper(bboxes, x))\n            # rescale to original resolution\n            im_scale = im_scale.reshape((-1)).asscalar()\n            det_bboxes[-1] *= im_scale\n            # split ground truths\n            gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))\n            gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))\n            gt_bboxes[-1] *= im_scale\n            gt_difficults.append(\n                y.slice_axis(axis=-1, begin=5, end=6)\n                if y.shape[-1] > 5\n                else None\n            )\n\n        # update metric\n        for det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff in zip(\n            det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults\n        ):\n            eval_metric.update(\n                det_bbox, det_id, det_score, gt_bbox, gt_id, 
gt_diff\n            )\n    return eval_metric.get()\n\n\ndef get_lr_at_iter(alpha, lr_warmup_factor=1.0 / 3.0):\n    return lr_warmup_factor * (1 - alpha) + alpha\n\n\nclass ForwardBackwardTask(Parallelizable):\n    def __init__(\n        self,\n        net,\n        optimizer,\n        rpn_cls_loss,\n        rpn_box_loss,\n        rcnn_cls_loss,\n        rcnn_box_loss,\n        mix_ratio,\n    ):\n        super(ForwardBackwardTask, self).__init__()\n        self.net = net\n        self._optimizer = optimizer\n        self.rpn_cls_loss = rpn_cls_loss\n        self.rpn_box_loss = rpn_box_loss\n        self.rcnn_cls_loss = rcnn_cls_loss\n        self.rcnn_box_loss = rcnn_box_loss\n        self.mix_ratio = mix_ratio\n\n    def forward_backward(self, x):\n        data, label, rpn_cls_targets, rpn_box_targets, rpn_box_masks = x\n        with autograd.record():\n            gt_label = label[:, :, 4:5]\n            gt_box = label[:, :, :4]\n            (\n                cls_pred,\n                box_pred,\n                roi,\n                samples,\n                matches,\n                rpn_score,\n                rpn_box,\n                anchors,\n                cls_targets,\n                box_targets,\n                box_masks,\n                _,\n            ) = net(data, gt_box, gt_label)\n            # losses of rpn\n            rpn_score = rpn_score.squeeze(axis=-1)\n            num_rpn_pos = (rpn_cls_targets >= 0).sum()\n            rpn_loss1 = (\n                self.rpn_cls_loss(\n                    rpn_score, rpn_cls_targets, rpn_cls_targets >= 0\n                )\n                * rpn_cls_targets.size\n                / num_rpn_pos\n            )\n            rpn_loss2 = (\n                self.rpn_box_loss(rpn_box, rpn_box_targets, rpn_box_masks)\n                * rpn_box.size\n                / num_rpn_pos\n            )\n            # rpn overall loss, use sum rather than average\n            rpn_loss = rpn_loss1 + rpn_loss2\n            
# losses of rcnn\n            num_rcnn_pos = (cls_targets >= 0).sum()\n            rcnn_loss1 = (\n                self.rcnn_cls_loss(\n                    cls_pred, cls_targets, cls_targets.expand_dims(-1) >= 0\n                )\n                * cls_targets.size\n                / num_rcnn_pos\n            )\n            rcnn_loss2 = (\n                self.rcnn_box_loss(box_pred, box_targets, box_masks)\n                * box_pred.size\n                / num_rcnn_pos\n            )\n            rcnn_loss = rcnn_loss1 + rcnn_loss2\n            # overall losses\n            total_loss = (\n                rpn_loss.sum() * self.mix_ratio\n                + rcnn_loss.sum() * self.mix_ratio\n            )\n\n            rpn_loss1_metric = rpn_loss1.mean() * self.mix_ratio\n            rpn_loss2_metric = rpn_loss2.mean() * self.mix_ratio\n            rcnn_loss1_metric = rcnn_loss1.mean() * self.mix_ratio\n            rcnn_loss2_metric = rcnn_loss2.mean() * self.mix_ratio\n            rpn_acc_metric = [\n                [rpn_cls_targets, rpn_cls_targets >= 0],\n                [rpn_score],\n            ]\n            rpn_l1_loss_metric = [[rpn_box_targets, rpn_box_masks], [rpn_box]]\n            rcnn_acc_metric = [[cls_targets], [cls_pred]]\n            rcnn_l1_loss_metric = [[box_targets, box_masks], [box_pred]]\n\n            if args.amp:\n                with amp.scale_loss(\n                    total_loss, self._optimizer\n                ) as scaled_losses:\n                    autograd.backward(scaled_losses)\n            else:\n                total_loss.backward()\n\n        return (\n            rpn_loss1_metric,\n            rpn_loss2_metric,\n            rcnn_loss1_metric,\n            rcnn_loss2_metric,\n            rpn_acc_metric,\n            rpn_l1_loss_metric,\n            rcnn_acc_metric,\n            rcnn_l1_loss_metric,\n        )\n\n\ndef train(net, train_data, val_data, eval_metric, batch_size, ctx, args):\n    \"\"\"Training pipeline\"\"\"\n    
args.kv_store = (\n        \"device\" if (args.amp and \"nccl\" in args.kv_store) else args.kv_store\n    )\n    kv = mx.kvstore.create(args.kv_store)\n    net.collect_params().setattr(\"grad_req\", \"null\")\n    net.collect_train_params().setattr(\"grad_req\", \"write\")\n    optimizer_params = {\n        \"learning_rate\": args.lr,\n        \"wd\": args.wd,\n        \"momentum\": args.momentum,\n    }\n    if args.horovod:\n        hvd.broadcast_parameters(net.collect_params(), root_rank=0)\n        trainer = hvd.DistributedTrainer(\n            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...\n            \"sgd\",\n            optimizer_params,\n        )\n    else:\n        trainer = gluon.Trainer(\n            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...\n            \"sgd\",\n            optimizer_params,\n            update_on_kvstore=(False if args.amp else None),\n            kvstore=kv,\n        )\n\n    if args.amp:\n        amp.init_trainer(trainer)\n\n    # lr decay policy\n    lr_decay = float(args.lr_decay)\n    lr_steps = sorted(\n        [float(ls) for ls in args.lr_decay_epoch.split(\",\") if ls.strip()]\n    )\n    lr_warmup = float(args.lr_warmup)  # avoid int division\n\n    # TODO(zhreshold) losses?\n    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(\n        from_sigmoid=False\n    )\n    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.0)  # == smoothl1\n    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()\n    rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1\n    metrics = [\n        mx.metric.Loss(\"RPN_Conf\"),\n        mx.metric.Loss(\"RPN_SmoothL1\"),\n        mx.metric.Loss(\"RCNN_CrossEntropy\"),\n        mx.metric.Loss(\"RCNN_SmoothL1\"),\n    ]\n\n    rpn_acc_metric = RPNAccMetric()\n    rpn_bbox_metric = RPNL1LossMetric()\n    rcnn_acc_metric = RCNNAccMetric()\n    rcnn_bbox_metric = RCNNL1LossMetric()\n    metrics2 = [\n        rpn_acc_metric,\n        
rpn_bbox_metric,\n        rcnn_acc_metric,\n        rcnn_bbox_metric,\n    ]\n\n    # set up logger\n    logging.basicConfig()\n    logger = logging.getLogger()\n    logger.setLevel(logging.INFO)\n    log_file_path = args.save_prefix + \"_train.log\"\n    log_dir = os.path.dirname(log_file_path)\n    if log_dir and not os.path.exists(log_dir):\n        os.makedirs(log_dir)\n    fh = logging.FileHandler(log_file_path)\n    logger.addHandler(fh)\n    logger.info(args)\n    if args.verbose:\n        logger.info(\"Trainable parameters:\")\n        logger.info(net.collect_train_params().keys())\n    logger.info(\"Start training from [Epoch {}]\".format(args.start_epoch))\n    best_map = [0]\n    for epoch in range(args.start_epoch, args.epochs):\n        mix_ratio = 1.0\n        if not args.disable_hybridization:\n            net.hybridize(static_alloc=args.static_alloc)\n        rcnn_task = ForwardBackwardTask(\n            net,\n            trainer,\n            rpn_cls_loss,\n            rpn_box_loss,\n            rcnn_cls_loss,\n            rcnn_box_loss,\n            mix_ratio=1.0,\n        )\n        executor = (\n            Parallel(args.executor_threads, rcnn_task)\n            if not args.horovod\n            else None\n        )\n        if args.mixup:\n            # TODO(zhreshold) only support evenly mixup now, target generator needs to be modified otherwise\n            train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)\n            mix_ratio = 0.5\n            if epoch >= args.epochs - args.no_mixup_epochs:\n                train_data._dataset._data.set_mixup(None)\n                mix_ratio = 1.0\n        while lr_steps and epoch >= lr_steps[0]:\n            new_lr = trainer.learning_rate * lr_decay\n            lr_steps.pop(0)\n            trainer.set_learning_rate(new_lr)\n            logger.info(\n                \"[Epoch {}] Set learning rate to {}\".format(epoch, new_lr)\n            )\n        for metric in metrics:\n            
metric.reset()\n        tic = time.time()\n        btic = time.time()\n        base_lr = trainer.learning_rate\n        rcnn_task.mix_ratio = mix_ratio\n        logger.info(\"Total Num of Batches: %d\" % (len(train_data)))\n        for i, batch in enumerate(train_data):\n            if epoch == 0 and i <= lr_warmup:\n                # adjust based on real percentage\n                new_lr = base_lr * get_lr_at_iter(\n                    i / lr_warmup, args.lr_warmup_factor\n                )\n                if new_lr != trainer.learning_rate:\n                    if i % args.log_interval == 0:\n                        logger.info(\n                            \"[Epoch 0 Iteration {}] Set learning rate to {}\".format(\n                                i, new_lr\n                            )\n                        )\n                    trainer.set_learning_rate(new_lr)\n            batch = split_and_load(batch, ctx_list=ctx)\n            metric_losses = [[] for _ in metrics]\n            add_losses = [[] for _ in metrics2]\n            if executor is not None:\n                for data in zip(*batch):\n                    executor.put(data)\n            for j in range(len(ctx)):\n                if executor is not None:\n                    result = executor.get()\n                else:\n                    result = rcnn_task.forward_backward(list(zip(*batch))[0])\n                if (not args.horovod) or hvd.rank() == 0:\n                    for k in range(len(metric_losses)):\n                        metric_losses[k].append(result[k])\n                    for k in range(len(add_losses)):\n                        add_losses[k].append(result[len(metric_losses) + k])\n            for metric, record in zip(metrics, metric_losses):\n                metric.update(0, record)\n            for metric, records in zip(metrics2, add_losses):\n                for pred in records:\n                    metric.update(pred[0], pred[1])\n            trainer.step(batch_size)\n\n 
           # update metrics\n            if (\n                (not args.horovod or hvd.rank() == 0)\n                and args.log_interval\n                and not (i + 1) % args.log_interval\n            ):\n                msg = \",\".join(\n                    [\n                        \"{}={:.3f}\".format(*metric.get())\n                        for metric in metrics + metrics2\n                    ]\n                )\n                logger.info(\n                    \"[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}\".format(\n                        epoch,\n                        i,\n                        args.log_interval\n                        * args.batch_size\n                        / (time.time() - btic),\n                        msg,\n                    )\n                )\n                btic = time.time()\n\n        if (not args.horovod) or hvd.rank() == 0:\n            msg = \",\".join(\n                [\"{}={:.3f}\".format(*metric.get()) for metric in metrics]\n            )\n            logger.info(\n                \"[Epoch {}] Training cost: {:.3f}, {}\".format(\n                    epoch, (time.time() - tic), msg\n                )\n            )\n            if not (epoch + 1) % args.val_interval:\n                # consider reduce the frequency of validation to save time\n                if val_data is not None:\n                    map_name, mean_ap = validate(\n                        net, val_data, ctx, eval_metric, args\n                    )\n                    val_msg = \"\\n\".join(\n                        [\n                            \"{}={}\".format(k, v)\n                            for k, v in zip(map_name, mean_ap)\n                        ]\n                    )\n                    logger.info(\n                        \"[Epoch {}] Validation: \\n{}\".format(epoch, val_msg)\n                    )\n                    current_map = float(mean_ap[-1])\n                else:\n                    current_map = 
0\n            else:\n                current_map = 0.0\n            save_params(\n                net,\n                logger,\n                best_map,\n                current_map,\n                epoch,\n                args.save_interval,\n                args.save_prefix,\n            )\n\n\nif __name__ == \"__main__\":\n    import sys\n\n    sys.setrecursionlimit(1100)\n    args = parse_args()\n    # fix seed for mxnet, numpy and python builtin random generator.\n    gutils.random.seed(args.seed)\n\n    if args.amp:\n        amp.init()\n\n    # training contexts\n    if args.horovod:\n        ctx = [mx.gpu(hvd.local_rank())]\n    else:\n        ctx = [mx.gpu(int(i)) for i in args.gpus.split(\",\") if i.strip()]\n        ctx = ctx if ctx else [mx.cpu()]\n\n    # network\n    kwargs = {}\n    module_list = []\n    if args.use_fpn:\n        module_list.append(\"fpn\")\n    if args.norm_layer is not None:\n        module_list.append(args.norm_layer)\n        if args.norm_layer == \"bn\":\n            kwargs[\"num_devices\"] = len(args.gpus.split(\",\"))\n\n    net_name = \"_\".join((\"faster_rcnn\", *module_list, args.network, \"custom\"))\n    args.save_prefix += net_name\n    gutils.makedirs(args.save_prefix)\n    train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)\n    net = faster_rcnn_resnet101_v1d_custom(\n        classes=train_dataset.classes,\n        transfer=\"coco\",\n        pretrained_base=False,\n        additional_output=False,\n        per_device_batch_size=args.batch_size // len(ctx),\n        **kwargs\n    )\n    if args.resume.strip():\n        net.load_parameters(args.resume.strip())\n    else:\n        for param in net.collect_params().values():\n            if param._data is not None:\n                continue\n            param.initialize()\n    net.collect_params().reset_ctx(ctx)\n\n    # training data\n    batch_size = (\n        args.batch_size // len(ctx) if args.horovod else args.batch_size\n    )\n    
train_data, val_data = get_dataloader(\n        net,\n        train_dataset,\n        val_dataset,\n        FasterRCNNDefaultTrainTransform,\n        FasterRCNNDefaultValTransform,\n        batch_size,\n        len(ctx),\n        args,\n    )\n\n    # training\n    train(net, train_data, val_data, eval_metric, batch_size, ctx, args)\n"
  },
  {
    "path": "examples/mxnet/scenegraph/train_faster_rcnn.sh",
    "content": "MXNET_CUDNN_AUTOTUNE_DEFAULT=0 CUDNN_AUTOTUNE_DEFAULT=0 MXNET_GPU_MEM_POOL_TYPE=Round MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF=28 python train_faster_rcnn.py \\\n    --gpus 0,1,2,3,4,5,6,7 --dataset visualgenome -j 60 --batch-size 8 --val-interval 20 --save-prefix faster_rcnn_resnet101_v1d_visualgenome/\n"
  },
  {
    "path": "examples/mxnet/scenegraph/train_freq_prior.py",
    "content": "import argparse\nimport json\nimport os\nimport pickle\n\nimport numpy as np\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Train the Frequency Prior For RelDN.\"\n    )\n    parser.add_argument(\n        \"--overlap\", action=\"store_true\", help=\"Only count overlap boxes.\"\n    )\n    parser.add_argument(\n        \"--json-path\",\n        type=str,\n        default=\"~/.mxnet/datasets/visualgenome\",\n        help=\"Directory containing rel_annotations_train.json.\",\n    )\n    args = parser.parse_args()\n    return args\n\n\nargs = parse_args()\nuse_overlap = args.overlap\nPATH_TO_DATASETS = os.path.expanduser(args.json_path)\npath_to_json = os.path.join(PATH_TO_DATASETS, \"rel_annotations_train.json\")\n\n\n# format in y1y2x1x2\ndef with_overlap(boxA, boxB):\n    xA = max(boxA[2], boxB[2])\n    xB = min(boxA[3], boxB[3])\n\n    if xB > xA:\n        yA = max(boxA[0], boxB[0])\n        yB = min(boxA[1], boxB[1])\n\n        if yB > yA:\n            return 1\n\n    return 0\n\n\ndef box_ious(boxes):\n    n = len(boxes)\n    res = np.zeros((n, n))\n    for i in range(n - 1):\n        for j in range(i + 1, n):\n            iou_val = with_overlap(boxes[i], boxes[j])\n            res[i, j] = iou_val\n            res[j, i] = iou_val\n    return res\n\n\nwith open(path_to_json, \"r\") as f:\n    tmp = f.read()\n    train_data = json.loads(tmp)\n\nfg_matrix = np.zeros((150, 150, 51), dtype=np.int64)\nbg_matrix = np.zeros((150, 150), dtype=np.int64)\n\nfor _, item in train_data.items():\n    gt_box_to_label = {}\n    for rel in item:\n        sub_bbox = rel[\"subject\"][\"bbox\"]\n        ob_bbox = rel[\"object\"][\"bbox\"]\n        sub_class = rel[\"subject\"][\"category\"]\n        ob_class = rel[\"object\"][\"category\"]\n        rel_class = rel[\"predicate\"]\n\n        sub_node = tuple(sub_bbox)\n        ob_node = tuple(ob_bbox)\n        if sub_node not in gt_box_to_label:\n            gt_box_to_label[sub_node] = sub_class\n        
if ob_node not in gt_box_to_label:\n            gt_box_to_label[ob_node] = ob_class\n\n        fg_matrix[sub_class, ob_class, rel_class + 1] += 1\n\n    if use_overlap:\n        gt_boxes = [*gt_box_to_label]\n        gt_classes = np.array([*gt_box_to_label.values()])\n        iou_mat = box_ious(gt_boxes)\n        cols, rows = np.where(iou_mat)\n        if len(cols) and len(rows):\n            for col, row in zip(cols, rows):\n                bg_matrix[gt_classes[col], gt_classes[row]] += 1\n        else:\n            all_possib = np.ones_like(iou_mat, dtype=np.bool_)\n            np.fill_diagonal(all_possib, 0)\n            cols, rows = np.where(all_possib)\n            for col, row in zip(cols, rows):\n                bg_matrix[gt_classes[col], gt_classes[row]] += 1\n    else:\n        for b1, l1 in gt_box_to_label.items():\n            for b2, l2 in gt_box_to_label.items():\n                if b1 == b2:\n                    continue\n                bg_matrix[l1, l2] += 1\n\n\neps = 1e-3\nbg_matrix += 1\nfg_matrix[:, :, 0] = bg_matrix\npred_dist = np.log(fg_matrix / (fg_matrix.sum(2)[:, :, None] + eps) + eps)\n\n\nif use_overlap:\n    with open(\"freq_prior_overlap.pkl\", \"wb\") as f:\n        pickle.dump(pred_dist, f)\nelse:\n    with open(\"freq_prior.pkl\", \"wb\") as f:\n        pickle.dump(pred_dist, f)\n"
  },
  {
    "path": "examples/mxnet/scenegraph/train_reldn.py",
    "content": "import argparse\nimport logging\nimport time\n\nimport mxnet as mx\nimport numpy as np\nfrom data import *\nfrom gluoncv.data.batchify import Pad\nfrom gluoncv.utils import makedirs\nfrom model import faster_rcnn_resnet101_v1d_custom, RelDN\nfrom mxnet import gluon, nd\nfrom utils import *\n\nimport dgl\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Train RelDN Model.\")\n    parser.add_argument(\n        \"--gpus\",\n        type=str,\n        default=\"0\",\n        help=\"Training with GPUs, you can specify 1,3 for example.\",\n    )\n    parser.add_argument(\n        \"--batch-size\",\n        type=int,\n        default=8,\n        help=\"Total batch-size for training.\",\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=9, help=\"Training epochs.\"\n    )\n    parser.add_argument(\n        \"--lr-reldn\",\n        type=float,\n        default=0.01,\n        help=\"Learning rate for RelDN module.\",\n    )\n    parser.add_argument(\n        \"--wd-reldn\",\n        type=float,\n        default=0.0001,\n        help=\"Weight decay for RelDN module.\",\n    )\n    parser.add_argument(\n        \"--lr-faster-rcnn\",\n        type=float,\n        default=0.01,\n        help=\"Learning rate for Faster R-CNN module.\",\n    )\n    parser.add_argument(\n        \"--wd-faster-rcnn\",\n        type=float,\n        default=0.0001,\n        help=\"Weight decay for Faster R-CNN module.\",\n    )\n    parser.add_argument(\n        \"--lr-decay-epochs\",\n        type=str,\n        default=\"5,8\",\n        help=\"Learning rate decay points.\",\n    )\n    parser.add_argument(\n        \"--lr-warmup-iters\",\n        type=int,\n        default=4000,\n        help=\"Learning rate warm-up iterations.\",\n    )\n    parser.add_argument(\n        \"--save-dir\",\n        type=str,\n        default=\"params_resnet101_v1d_reldn\",\n        help=\"Path to save model parameters.\",\n    )\n    parser.add_argument(\n        
\"--log-dir\",\n        type=str,\n        default=\"reldn_output.log\",\n        help=\"Path to save training logs.\",\n    )\n    parser.add_argument(\n        \"--pretrained-faster-rcnn-params\",\n        type=str,\n        required=True,\n        help=\"Path to saved Faster R-CNN model parameters.\",\n    )\n    parser.add_argument(\n        \"--freq-prior\",\n        type=str,\n        default=\"freq_prior.pkl\",\n        help=\"Path to saved frequency prior data.\",\n    )\n    parser.add_argument(\n        \"--verbose-freq\",\n        type=int,\n        default=100,\n        help=\"Frequency of log printing in number of iterations.\",\n    )\n\n    args = parser.parse_args()\n    return args\n\n\nargs = parse_args()\n\nfilehandler = logging.FileHandler(args.log_dir)\nstreamhandler = logging.StreamHandler()\nlogger = logging.getLogger(\"\")\nlogger.setLevel(logging.INFO)\nlogger.addHandler(filehandler)\nlogger.addHandler(streamhandler)\n\n# Hyperparams\nctx = [mx.gpu(int(i)) for i in args.gpus.split(\",\") if i.strip()]\nif ctx:\n    num_gpus = len(ctx)\n    assert args.batch_size % num_gpus == 0\n    per_device_batch_size = int(args.batch_size / num_gpus)\nelse:\n    ctx = [mx.cpu()]\n    per_device_batch_size = args.batch_size\n\naggregate_grad = per_device_batch_size > 1\n\nnepoch = args.epochs\nN_relations = 50\nN_objects = 150\nsave_dir = args.save_dir\nmakedirs(save_dir)\nbatch_verbose_freq = args.verbose_freq\nlr_decay_epochs = [int(i) for i in args.lr_decay_epochs.split(\",\")]\n\n# Dataset and dataloader\nvg_train = VGRelation(split=\"train\")\nlogger.info(\"data loaded!\")\ntrain_data = gluon.data.DataLoader(\n    vg_train,\n    batch_size=len(ctx),\n    shuffle=True,\n    num_workers=8 * num_gpus,\n    batchify_fn=dgl_mp_batchify_fn,\n)\nn_batches = len(train_data)\n\n# Network definition\nnet = RelDN(n_classes=N_relations, prior_pkl=args.freq_prior)\nnet.spatial.initialize(mx.init.Normal(1e-4), ctx=ctx)\nnet.visual.initialize(mx.init.Normal(1e-4), 
ctx=ctx)\nfor k, v in net.collect_params().items():\n    v.grad_req = \"add\" if aggregate_grad else \"write\"\nnet_params = net.collect_params()\nnet_trainer = gluon.Trainer(\n    net.collect_params(),\n    \"adam\",\n    {\"learning_rate\": args.lr_reldn, \"wd\": args.wd_reldn},\n)\n\ndet_params_path = args.pretrained_faster_rcnn_params\ndetector = faster_rcnn_resnet101_v1d_custom(\n    classes=vg_train.obj_classes,\n    pretrained_base=False,\n    pretrained=False,\n    additional_output=True,\n)\ndetector.load_parameters(\n    det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True\n)\nfor k, v in detector.collect_params().items():\n    v.grad_req = \"null\"\n\ndetector_feat = faster_rcnn_resnet101_v1d_custom(\n    classes=vg_train.obj_classes,\n    pretrained_base=False,\n    pretrained=False,\n    additional_output=True,\n)\ndetector_feat.load_parameters(\n    det_params_path, ctx=ctx, ignore_extra=True, allow_missing=True\n)\nfor k, v in detector_feat.collect_params().items():\n    v.grad_req = \"null\"\nfor k, v in detector_feat.features.collect_params().items():\n    v.grad_req = \"add\" if aggregate_grad else \"write\"\ndet_params = detector_feat.features.collect_params()\ndet_trainer = gluon.Trainer(\n    detector_feat.features.collect_params(),\n    \"adam\",\n    {\"learning_rate\": args.lr_faster_rcnn, \"wd\": args.wd_faster_rcnn},\n)\n\n\ndef get_data_batch(g_list, img_list, ctx_list):\n    if g_list is None or len(g_list) == 0:\n        return None, None\n    n_gpu = len(ctx_list)\n    size = len(g_list)\n    if size < n_gpu:\n        raise Exception(\"too small batch\")\n    step = size // n_gpu\n    G_list = [\n        g_list[i * step : (i + 1) * step]\n        if i < n_gpu - 1\n        else g_list[i * step : size]\n        for i in range(n_gpu)\n    ]\n    img_list = [\n        img_list[i * step : (i + 1) * step]\n        if i < n_gpu - 1\n        else img_list[i * step : size]\n        for i in range(n_gpu)\n    ]\n\n    for G_slice, ctx 
in zip(G_list, ctx_list):\n        for G in G_slice:\n            G.ndata[\"bbox\"] = G.ndata[\"bbox\"].as_in_context(ctx)\n            G.ndata[\"node_class\"] = G.ndata[\"node_class\"].as_in_context(ctx)\n            G.ndata[\"node_class_vec\"] = G.ndata[\"node_class_vec\"].as_in_context(\n                ctx\n            )\n            G.edata[\"rel_class\"] = G.edata[\"rel_class\"].as_in_context(ctx)\n    img_list = [img.as_in_context(ctx) for img in img_list]\n    return G_list, img_list\n\n\nL_rel = gluon.loss.SoftmaxCELoss()\n\ntrain_metric = mx.metric.Accuracy(name=\"rel_acc\")\ntrain_metric_top5 = mx.metric.TopKAccuracy(5, name=\"rel_acc_top5\")\nmetric_list = [train_metric, train_metric_top5]\n\n\ndef batch_print(\n    epoch, i, batch_verbose_freq, n_batches, btic, loss_rel_val, metric_list\n):\n    if (i + 1) % batch_verbose_freq == 0:\n        print_txt = \"Epoch[%d] Batch[%d/%d], time: %d, loss_rel=%.4f \" % (\n            epoch,\n            i,\n            n_batches,\n            int(time.time() - btic),\n            loss_rel_val / (i + 1),\n        )\n        for metric in metric_list:\n            metric_name, metric_val = metric.get()\n            print_txt += \"%s=%.4f \" % (metric_name, metric_val)\n        logger.info(print_txt)\n        btic = time.time()\n        loss_rel_val = 0\n    return btic, loss_rel_val\n\n\nfor epoch in range(nepoch):\n    loss_rel_val = 0\n    tic = time.time()\n    btic = time.time()\n    for metric in metric_list:\n        metric.reset()\n    if epoch == 0:\n        net_trainer_base_lr = net_trainer.learning_rate\n        det_trainer_base_lr = det_trainer.learning_rate\n    if epoch == 5 or epoch == 8:\n        net_trainer.set_learning_rate(net_trainer.learning_rate * 0.1)\n        det_trainer.set_learning_rate(det_trainer.learning_rate * 0.1)\n    for i, (G_list, img_list) in enumerate(train_data):\n        if epoch == 0 and i < args.lr_warmup_iters:\n            alpha = i / args.lr_warmup_iters\n            
warmup_factor = 1 / 3 * (1 - alpha) + alpha\n            net_trainer.set_learning_rate(net_trainer_base_lr * warmup_factor)\n            det_trainer.set_learning_rate(det_trainer_base_lr * warmup_factor)\n        G_list, img_list = get_data_batch(G_list, img_list, ctx)\n        if G_list is None or img_list is None:\n            btic, loss_rel_val = batch_print(\n                epoch,\n                i,\n                batch_verbose_freq,\n                n_batches,\n                btic,\n                loss_rel_val,\n                metric_list,\n            )\n            continue\n\n        loss = []\n        detector_res_list = []\n        G_batch = []\n        bbox_pad = Pad(axis=(0))\n        with mx.autograd.record():\n            for G_slice, img in zip(G_list, img_list):\n                cur_ctx = img.context\n                bbox_list = [G.ndata[\"bbox\"] for G in G_slice]\n                bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)\n                with mx.autograd.pause():\n                    ids, scores, bbox, feat, feat_ind, spatial_feat = detector(\n                        img\n                    )\n                g_pred_batch = build_graph_train(\n                    G_slice,\n                    bbox_stack,\n                    img,\n                    ids,\n                    scores,\n                    bbox,\n                    feat_ind,\n                    spatial_feat,\n                    scores_top_k=300,\n                    overlap=False,\n                )\n                g_batch = l0_sample(g_pred_batch)\n                if g_batch is None:\n                    continue\n                rel_bbox = g_batch.edata[\"rel_bbox\"]\n                batch_id = g_batch.edata[\"batch_id\"].asnumpy()\n                n_sample_edges = g_batch.number_of_edges()\n                n_graph = len(G_slice)\n                bbox_rel_list = []\n                for j in range(n_graph):\n                    eids = np.where(batch_id == 
j)[0]\n                    if len(eids) > 0:\n                        bbox_rel_list.append(rel_bbox[eids])\n                bbox_rel_stack = bbox_pad(bbox_rel_list).as_in_context(cur_ctx)\n                img_size = img.shape[2:4]\n                bbox_rel_stack[:, :, 0] *= img_size[1]\n                bbox_rel_stack[:, :, 1] *= img_size[0]\n                bbox_rel_stack[:, :, 2] *= img_size[1]\n                bbox_rel_stack[:, :, 3] *= img_size[0]\n                _, _, _, spatial_feat_rel = detector_feat(\n                    img, None, None, bbox_rel_stack\n                )\n                spatial_feat_rel_list = []\n                for j in range(n_graph):\n                    eids = np.where(batch_id == j)[0]\n                    if len(eids) > 0:\n                        spatial_feat_rel_list.append(\n                            spatial_feat_rel[j, 0 : len(eids)]\n                        )\n                g_batch.edata[\"edge_feat\"] = nd.concat(\n                    *spatial_feat_rel_list, dim=0\n                )\n\n                G_batch.append(g_batch)\n\n            G_batch = [net(G) for G in G_batch]\n\n            for G_pred, img in zip(G_batch, img_list):\n                if G_pred is None or G_pred.number_of_nodes() == 0:\n                    continue\n                loss_rel = L_rel(\n                    G_pred.edata[\"preds\"],\n                    G_pred.edata[\"rel_class\"],\n                    G_pred.edata[\"sample_weights\"],\n                )\n                loss.append(loss_rel.sum())\n                loss_rel_val += loss_rel.mean().asscalar() / num_gpus\n\n        if len(loss) == 0:\n            btic, loss_rel_val = batch_print(\n                epoch,\n                i,\n                batch_verbose_freq,\n                n_batches,\n                btic,\n                loss_rel_val,\n                metric_list,\n            )\n            continue\n        for l in loss:\n            l.backward()\n        if (i + 1) % 
per_device_batch_size == 0 or i == n_batches - 1:\n            net_trainer.step(args.batch_size)\n            det_trainer.step(args.batch_size)\n            if aggregate_grad:\n                for k, v in net_params.items():\n                    v.zero_grad()\n                for k, v in det_params.items():\n                    v.zero_grad()\n        for G_pred, img_slice in zip(G_batch, img_list):\n            if G_pred is None or G_pred.number_of_nodes() == 0:\n                continue\n            link_ind = np.where(G_pred.edata[\"rel_class\"].asnumpy() > 0)[0]\n            if len(link_ind) == 0:\n                continue\n            train_metric.update(\n                [G_pred.edata[\"rel_class\"][link_ind]],\n                [G_pred.edata[\"preds\"][link_ind]],\n            )\n            train_metric_top5.update(\n                [G_pred.edata[\"rel_class\"][link_ind]],\n                [G_pred.edata[\"preds\"][link_ind]],\n            )\n        btic, loss_rel_val = batch_print(\n            epoch,\n            i,\n            batch_verbose_freq,\n            n_batches,\n            btic,\n            loss_rel_val,\n            metric_list,\n        )\n        if (i + 1) % batch_verbose_freq == 0:\n            net.save_parameters(\"%s/model-%d.params\" % (save_dir, epoch))\n            detector_feat.features.save_parameters(\n                \"%s/detector_feat.features-%d.params\" % (save_dir, epoch)\n            )\n    print_txt = \"Epoch[%d], time: %d, loss_rel=%.4f,\" % (\n        epoch,\n        int(time.time() - tic),\n        loss_rel_val / (i + 1),\n    )\n    for metric in metric_list:\n        metric_name, metric_val = metric.get()\n        print_txt += \"%s=%.4f \" % (metric_name, metric_val)\n    logger.info(print_txt)\n    net.save_parameters(\"%s/model-%d.params\" % (save_dir, epoch))\n    detector_feat.features.save_parameters(\n        \"%s/detector_feat.features-%d.params\" % (save_dir, epoch)\n    )\n"
  },
  {
    "path": "examples/mxnet/scenegraph/train_reldn.sh",
    "content": "MXNET_CUDNN_AUTOTUNE_DEFAULT=0 python train_reldn.py \\\n    --pretrained-faster-rcnn-params faster_rcnn_resnet101_v1d_visualgenome/faster_rcnn_resnet101_v1d_custom_best.params\n"
  },
  {
    "path": "examples/mxnet/scenegraph/utils/__init__.py",
    "content": "from .build_graph import *\nfrom .metric import *\nfrom .sampling import *\nfrom .viz import *\n"
  },
  {
    "path": "examples/mxnet/scenegraph/utils/build_graph.py",
    "content": "import dgl\nimport numpy as np\nfrom mxnet import nd\n\n\ndef bbox_improve(bbox):\n    \"\"\"bbox encoding\"\"\"\n    area = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1])\n    return nd.concat(bbox, area.expand_dims(1))\n\n\ndef extract_edge_bbox(g):\n    \"\"\"bbox encoding\"\"\"\n    src, dst = g.edges(order=\"eid\")\n    n = g.number_of_edges()\n    src_bbox = g.ndata[\"pred_bbox\"][src.asnumpy()]\n    dst_bbox = g.ndata[\"pred_bbox\"][dst.asnumpy()]\n    edge_bbox = nd.zeros((n, 4), ctx=g.ndata[\"pred_bbox\"].context)\n    edge_bbox[:, 0] = nd.stack(src_bbox[:, 0], dst_bbox[:, 0]).min(axis=0)\n    edge_bbox[:, 1] = nd.stack(src_bbox[:, 1], dst_bbox[:, 1]).min(axis=0)\n    edge_bbox[:, 2] = nd.stack(src_bbox[:, 2], dst_bbox[:, 2]).max(axis=0)\n    edge_bbox[:, 3] = nd.stack(src_bbox[:, 3], dst_bbox[:, 3]).max(axis=0)\n    return edge_bbox\n\n\ndef build_graph_train(\n    g_slice,\n    gt_bbox,\n    img,\n    ids,\n    scores,\n    bbox,\n    feat_ind,\n    spatial_feat,\n    iou_thresh=0.5,\n    bbox_improvement=True,\n    scores_top_k=50,\n    overlap=False,\n):\n    \"\"\"given ground truth and predicted bboxes, assign the label to the predicted w.r.t iou_thresh\"\"\"\n    # match and re-factor the graph\n    img_size = img.shape[2:4]\n    gt_bbox[:, :, 0] /= img_size[1]\n    gt_bbox[:, :, 1] /= img_size[0]\n    gt_bbox[:, :, 2] /= img_size[1]\n    gt_bbox[:, :, 3] /= img_size[0]\n    bbox[:, :, 0] /= img_size[1]\n    bbox[:, :, 1] /= img_size[0]\n    bbox[:, :, 2] /= img_size[1]\n    bbox[:, :, 3] /= img_size[0]\n\n    n_graph = len(g_slice)\n    g_pred_batch = []\n    for gi in range(n_graph):\n        g = g_slice[gi]\n        ctx = g.ndata[\"bbox\"].context\n        inds = np.where(scores[gi, :, 0].asnumpy() > 0)[0].tolist()\n        if len(inds) == 0:\n            return None\n        if len(inds) > scores_top_k:\n            top_score_inds = (\n                scores[gi, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]\n          
  )\n            inds = np.array(inds)[top_score_inds].tolist()\n\n        n_nodes = len(inds)\n        roi_ind = feat_ind[gi, inds].squeeze(axis=1)\n        g_pred = dgl.DGLGraph()\n        g_pred.add_nodes(\n            n_nodes,\n            {\n                \"pred_bbox\": bbox[gi, inds],\n                \"node_feat\": spatial_feat[gi, roi_ind],\n                \"node_class_pred\": ids[gi, inds, 0],\n                \"node_class_logit\": nd.log(scores[gi, inds, 0] + 1e-7),\n            },\n        )\n\n        # iou matching\n        ious = nd.contrib.box_iou(\n            gt_bbox[gi], g_pred.ndata[\"pred_bbox\"]\n        ).asnumpy()\n        H, W = ious.shape\n        h = H\n        w = W\n        pred_to_gt_ind = np.array([-1 for i in range(W)])\n        pred_to_gt_class_match = [0 for i in range(W)]\n        pred_to_gt_class_match_id = [0 for i in range(W)]\n        while h > 0 and w > 0:\n            ind = int(ious.argmax())\n            row_ind = ind // W\n            col_ind = ind % W\n            if ious[row_ind, col_ind] < iou_thresh:\n                break\n            pred_to_gt_ind[col_ind] = row_ind\n            gt_node_class = g.ndata[\"node_class\"][row_ind]\n            pred_node_class = g_pred.ndata[\"node_class_pred\"][col_ind]\n            if gt_node_class == pred_node_class:\n                pred_to_gt_class_match[col_ind] = 1\n                pred_to_gt_class_match_id[col_ind] = row_ind\n            ious[row_ind, :] = -1\n            ious[:, col_ind] = -1\n            h -= 1\n            w -= 1\n\n        n_nodes = g_pred.number_of_nodes()\n        triplet = []\n        adjmat = np.zeros((n_nodes, n_nodes))\n\n        src, dst = g.all_edges(order=\"eid\")\n        eid_keys = np.column_stack([src.asnumpy(), dst.asnumpy()])\n        eid_dict = {}\n        for i, key in enumerate(eid_keys):\n            k = tuple(key)\n            if k not in eid_dict:\n                eid_dict[k] = [i]\n            else:\n                
eid_dict[k].append(i)\n        ori_rel_class = g.edata[\"rel_class\"].asnumpy()\n        for i in range(n_nodes):\n            for j in range(n_nodes):\n                if i != j:\n                    if pred_to_gt_class_match[i] and pred_to_gt_class_match[j]:\n                        sub_gt_id = pred_to_gt_class_match_id[i]\n                        ob_gt_id = pred_to_gt_class_match_id[j]\n                        eids = eid_dict[(sub_gt_id, ob_gt_id)]\n                        rel_cls = ori_rel_class[eids]\n                        n_edges_between = len(rel_cls)\n                        for ii in range(n_edges_between):\n                            triplet.append((i, j, rel_cls[ii]))\n                        adjmat[i, j] = 1\n                    else:\n                        triplet.append((i, j, 0))\n        src, dst, rel_class = tuple(zip(*triplet))\n        rel_class = nd.array(rel_class, ctx=ctx).expand_dims(1)\n        g_pred.add_edges(src, dst, data={\"rel_class\": rel_class})\n\n        # other operations\n        n_nodes = g_pred.number_of_nodes()\n        n_edges = g_pred.number_of_edges()\n        if bbox_improvement:\n            g_pred.ndata[\"pred_bbox\"] = bbox_improve(g_pred.ndata[\"pred_bbox\"])\n        g_pred.edata[\"rel_bbox\"] = extract_edge_bbox(g_pred)\n        g_pred.edata[\"batch_id\"] = nd.zeros((n_edges, 1), ctx=ctx) + gi\n\n        # remove non-overlapping edges\n        if overlap:\n            overlap_ious = nd.contrib.box_iou(\n                g_pred.ndata[\"pred_bbox\"][:, 0:4],\n                g_pred.ndata[\"pred_bbox\"][:, 0:4],\n            ).asnumpy()\n            cols, rows = np.where(overlap_ious <= 1e-7)\n            if cols.shape[0] > 0:\n                eids = g_pred.edge_ids(cols, rows)[2].asnumpy().tolist()\n                if len(eids):\n                    g_pred.remove_edges(eids)\n                    if g_pred.number_of_edges() == 0:\n                        g_pred = None\n        g_pred_batch.append(g_pred)\n\n    if 
n_graph > 1:\n        return dgl.batch(g_pred_batch)\n    else:\n        return g_pred_batch[0]\n\n\ndef build_graph_validate_gt_obj(\n    img, gt_ids, bbox, spatial_feat, bbox_improvement=True, overlap=False\n):\n    \"\"\"given ground truth bbox and label, build graph for validation\"\"\"\n    n_batch = img.shape[0]\n    img_size = img.shape[2:4]\n    bbox[:, :, 0] /= img_size[1]\n    bbox[:, :, 1] /= img_size[0]\n    bbox[:, :, 2] /= img_size[1]\n    bbox[:, :, 3] /= img_size[0]\n    ctx = img.context\n\n    g_batch = []\n    for btc in range(n_batch):\n        inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()\n        if len(inds) == 0:\n            continue\n        n_nodes = len(inds)\n        g_pred = dgl.DGLGraph()\n        g_pred.add_nodes(\n            n_nodes,\n            {\n                \"pred_bbox\": bbox[btc, inds],\n                \"node_feat\": spatial_feat[btc, inds],\n                \"node_class_pred\": gt_ids[btc, inds, 0],\n                \"node_class_logit\": nd.zeros_like(\n                    gt_ids[btc, inds, 0], ctx=ctx\n                ),\n            },\n        )\n\n        edge_list = []\n        for i in range(n_nodes - 1):\n            for j in range(i + 1, n_nodes):\n                edge_list.append((i, j))\n        src, dst = tuple(zip(*edge_list))\n        g_pred.add_edges(src, dst)\n        g_pred.add_edges(dst, src)\n\n        n_nodes = g_pred.number_of_nodes()\n        n_edges = g_pred.number_of_edges()\n        if bbox_improvement:\n            g_pred.ndata[\"pred_bbox\"] = bbox_improve(g_pred.ndata[\"pred_bbox\"])\n        g_pred.edata[\"rel_bbox\"] = extract_edge_bbox(g_pred)\n        g_pred.edata[\"batch_id\"] = nd.zeros((n_edges, 1), ctx=ctx) + btc\n\n        g_batch.append(g_pred)\n\n    if len(g_batch) == 0:\n        return None\n    if len(g_batch) > 1:\n        return dgl.batch(g_batch)\n    return g_batch[0]\n\n\ndef build_graph_validate_gt_bbox(\n    img,\n    ids,\n    scores,\n    bbox,\n    
spatial_feat,\n    gt_ids=None,\n    bbox_improvement=True,\n    overlap=False,\n):\n    \"\"\"given ground truth bbox, build graph for validation\"\"\"\n    n_batch = img.shape[0]\n    img_size = img.shape[2:4]\n    bbox[:, :, 0] /= img_size[1]\n    bbox[:, :, 1] /= img_size[0]\n    bbox[:, :, 2] /= img_size[1]\n    bbox[:, :, 3] /= img_size[0]\n    ctx = img.context\n\n    g_batch = []\n    for btc in range(n_batch):\n        id_btc = scores[btc][:, :, 0].argmax(0)\n        score_btc = scores[btc][:, :, 0].max(0)\n        inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()\n        if len(inds) == 0:\n            continue\n        n_nodes = len(inds)\n        g_pred = dgl.DGLGraph()\n        g_pred.add_nodes(\n            n_nodes,\n            {\n                \"pred_bbox\": bbox[btc, inds],\n                \"node_feat\": spatial_feat[btc, inds],\n                \"node_class_pred\": id_btc,\n                \"node_class_logit\": nd.log(score_btc + 1e-7),\n            },\n        )\n\n        edge_list = []\n        for i in range(n_nodes - 1):\n            for j in range(i + 1, n_nodes):\n                edge_list.append((i, j))\n        src, dst = tuple(zip(*edge_list))\n        g_pred.add_edges(src, dst)\n        g_pred.add_edges(dst, src)\n\n        n_nodes = g_pred.number_of_nodes()\n        n_edges = g_pred.number_of_edges()\n        if bbox_improvement:\n            g_pred.ndata[\"pred_bbox\"] = bbox_improve(g_pred.ndata[\"pred_bbox\"])\n        g_pred.edata[\"rel_bbox\"] = extract_edge_bbox(g_pred)\n        g_pred.edata[\"batch_id\"] = nd.zeros((n_edges, 1), ctx=ctx) + btc\n\n        g_batch.append(g_pred)\n\n    if len(g_batch) == 0:\n        return None\n    if len(g_batch) > 1:\n        return dgl.batch(g_batch)\n    return g_batch[0]\n\n\ndef build_graph_validate_pred(\n    img,\n    ids,\n    scores,\n    bbox,\n    feat_ind,\n    spatial_feat,\n    bbox_improvement=True,\n    scores_top_k=50,\n    overlap=False,\n):\n    \"\"\"given 
predicted bbox, build graph for validation\"\"\"\n    n_batch = img.shape[0]\n    img_size = img.shape[2:4]\n    bbox[:, :, 0] /= img_size[1]\n    bbox[:, :, 1] /= img_size[0]\n    bbox[:, :, 2] /= img_size[1]\n    bbox[:, :, 3] /= img_size[0]\n    ctx = img.context\n\n    g_batch = []\n    for btc in range(n_batch):\n        inds = np.where(scores[btc, :, 0].asnumpy() > 0)[0].tolist()\n        if len(inds) == 0:\n            continue\n        if len(inds) > scores_top_k:\n            top_score_inds = (\n                scores[btc, inds, 0].asnumpy().argsort()[::-1][0:scores_top_k]\n            )\n            inds = np.array(inds)[top_score_inds].tolist()\n        n_nodes = len(inds)\n        roi_ind = feat_ind[btc, inds].squeeze(axis=1)\n\n        g_pred = dgl.DGLGraph()\n        g_pred.add_nodes(\n            n_nodes,\n            {\n                \"pred_bbox\": bbox[btc, inds],\n                \"node_feat\": spatial_feat[btc, roi_ind],\n                \"node_class_pred\": ids[btc, inds, 0],\n                \"node_class_logit\": nd.log(scores[btc, inds, 0] + 1e-7),\n            },\n        )\n\n        edge_list = []\n        for i in range(n_nodes - 1):\n            for j in range(i + 1, n_nodes):\n                edge_list.append((i, j))\n        src, dst = tuple(zip(*edge_list))\n        g_pred.add_edges(src, dst)\n        g_pred.add_edges(dst, src)\n\n        n_nodes = g_pred.number_of_nodes()\n        n_edges = g_pred.number_of_edges()\n        if bbox_improvement:\n            g_pred.ndata[\"pred_bbox\"] = bbox_improve(g_pred.ndata[\"pred_bbox\"])\n        g_pred.edata[\"rel_bbox\"] = extract_edge_bbox(g_pred)\n        g_pred.edata[\"batch_id\"] = nd.zeros((n_edges, 1), ctx=ctx) + btc\n\n        g_batch.append(g_pred)\n\n    if len(g_batch) == 0:\n        return None\n    if len(g_batch) > 1:\n        return dgl.batch(g_batch)\n    return g_batch[0]\n"
  },
  {
    "path": "examples/mxnet/scenegraph/utils/metric.py",
    "content": "import logging\nimport time\nfrom operator import attrgetter, itemgetter\n\nimport dgl\n\nimport mxnet as mx\nimport numpy as np\nfrom dgl.nn.mxnet import GraphConv\nfrom dgl.utils import toindex\nfrom gluoncv.data.batchify import Pad\nfrom gluoncv.model_zoo import get_model\nfrom mxnet import gluon, nd\nfrom mxnet.gluon import nn\n\n\ndef iou(boxA, boxB):\n    # determine the (x, y)-coordinates of the intersection rectangle\n    xA = max(boxA[0], boxB[0])\n    yA = max(boxA[1], boxB[1])\n    xB = min(boxA[2], boxB[2])\n    yB = min(boxA[3], boxB[3])\n\n    interArea = max(0, xB - xA) * max(0, yB - yA)\n    if interArea < 1e-7:\n        return 0\n\n    boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])\n    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])\n    if boxAArea + boxBArea - interArea < 1e-7:\n        return 0\n\n    iou_val = interArea / float(boxAArea + boxBArea - interArea)\n    return iou_val\n\n\ndef object_iou_thresh(gt_object, pred_object, iou_thresh=0.5):\n    obj_iou = iou(gt_object[1:5], pred_object[1:5])\n    if obj_iou >= iou_thresh:\n        return True\n    return False\n\n\ndef triplet_iou_thresh(pred_triplet, gt_triplet, iou_thresh=0.5):\n    sub_iou = iou(gt_triplet[5:9], pred_triplet[5:9])\n    if sub_iou >= iou_thresh:\n        ob_iou = iou(gt_triplet[9:13], pred_triplet[9:13])\n        if ob_iou >= iou_thresh:\n            return True\n    return False\n\n\n@mx.metric.register\n@mx.metric.alias(\"auc\")\nclass AUCMetric(mx.metric.EvalMetric):\n    def __init__(self, name=\"auc\", eps=1e-12):\n        super(AUCMetric, self).__init__(name)\n        self.eps = eps\n\n    def update(self, labels, preds):\n        mx.metric.check_label_shapes(labels, preds)\n        label_weight = labels[0].asnumpy()\n        preds = preds[0].asnumpy()\n        tmp = []\n        for i in range(preds.shape[0]):\n            tmp.append((label_weight[i], preds[i][1]))\n        tmp = sorted(tmp, key=itemgetter(1), reverse=True)\n        
label_sum = label_weight.sum()\n        if label_sum == 0 or label_sum == label_weight.size:\n            return\n\n        label_one_num = np.count_nonzero(label_weight)\n        label_zero_num = len(label_weight) - label_one_num\n        total_area = label_zero_num * label_one_num\n        height = 0\n        width = 0\n        area = 0\n        for a, _ in tmp:\n            if a == 1.0:\n                height += 1.0\n            else:\n                width += 1.0\n                area += height\n\n        self.sum_metric += area / total_area\n        self.num_inst += 1\n\n\n@mx.metric.register\n@mx.metric.alias(\"predcls\")\nclass PredCls(mx.metric.EvalMetric):\n    \"\"\"Metric with ground truth object location and label\"\"\"\n\n    def __init__(self, topk=20, iou_thresh=0.99):\n        super(PredCls, self).__init__(\"predcls@%d\" % (topk))\n        self.topk = topk\n        self.iou_thresh = iou_thresh\n\n    def update(self, labels, preds):\n        if labels is None or preds is None:\n            self.num_inst += 1\n            return\n        preds = preds[preds[:, 0].argsort()[::-1]]\n        m = min(self.topk, preds.shape[0])\n        count = 0\n        gt_edge_num = labels.shape[0]\n        label_matched = [False for label in labels]\n        for i in range(m):\n            pred = preds[i]\n            for j in range(gt_edge_num):\n                if label_matched[j]:\n                    continue\n                label = labels[j]\n                if int(label[2]) == int(pred[2]) and triplet_iou_thresh(\n                    pred, label, self.iou_thresh\n                ):\n                    count += 1\n                    label_matched[j] = True\n\n        total = labels.shape[0]\n        self.sum_metric += count / total\n        self.num_inst += 1\n\n\n@mx.metric.register\n@mx.metric.alias(\"phrcls\")\nclass PhrCls(mx.metric.EvalMetric):\n    \"\"\"Metric with ground truth object location and predicted object label from detector\"\"\"\n\n    def 
__init__(self, topk=20, iou_thresh=0.99):\n        super(PhrCls, self).__init__(\"phrcls@%d\" % (topk))\n        self.topk = topk\n        self.iou_thresh = iou_thresh\n\n    def update(self, labels, preds):\n        if labels is None or preds is None:\n            self.num_inst += 1\n            return\n        preds = preds[preds[:, 1].argsort()[::-1]]\n        m = min(self.topk, preds.shape[0])\n        count = 0\n        gt_edge_num = labels.shape[0]\n        label_matched = [False for label in labels]\n        for i in range(m):\n            pred = preds[i]\n            for j in range(gt_edge_num):\n                if label_matched[j]:\n                    continue\n                label = labels[j]\n                if (\n                    int(label[2]) == int(pred[2])\n                    and int(label[3]) == int(pred[3])\n                    and int(label[4]) == int(pred[4])\n                    and triplet_iou_thresh(pred, label, self.iou_thresh)\n                ):\n                    count += 1\n                    label_matched[j] = True\n        total = labels.shape[0]\n        self.sum_metric += count / total\n        self.num_inst += 1\n\n\n@mx.metric.register\n@mx.metric.alias(\"sgdet\")\nclass SGDet(mx.metric.EvalMetric):\n    \"\"\"Metric with predicted object information by the detector\"\"\"\n\n    def __init__(self, topk=20, iou_thresh=0.5):\n        super(SGDet, self).__init__(\"sgdet@%d\" % (topk))\n        self.topk = topk\n        self.iou_thresh = iou_thresh\n\n    def update(self, labels, preds):\n        if labels is None or preds is None:\n            self.num_inst += 1\n            return\n        preds = preds[preds[:, 1].argsort()[::-1]]\n        m = min(self.topk, len(preds))\n        count = 0\n        gt_edge_num = labels.shape[0]\n        label_matched = [False for label in labels]\n        for i in range(m):\n            pred = preds[i]\n            for j in range(gt_edge_num):\n                if label_matched[j]:\n           
         continue\n                label = labels[j]\n                if (\n                    int(label[2]) == int(pred[2])\n                    and int(label[3]) == int(pred[3])\n                    and int(label[4]) == int(pred[4])\n                    and triplet_iou_thresh(pred, label, self.iou_thresh)\n                ):\n                    count += 1\n                    label_matched[j] = True\n        total = labels.shape[0]\n        self.sum_metric += count / total\n        self.num_inst += 1\n\n\n@mx.metric.register\n@mx.metric.alias(\"sgdet+\")\nclass SGDetPlus(mx.metric.EvalMetric):\n    \"\"\"Metric proposed by `Graph R-CNN for Scene Graph Generation`\"\"\"\n\n    def __init__(self, topk=20, iou_thresh=0.5):\n        super(SGDetPlus, self).__init__(\"sgdet+@%d\" % (topk))\n        self.topk = topk\n        self.iou_thresh = iou_thresh\n\n    def update(self, labels, preds):\n        label_objects, label_triplets = labels\n        pred_objects, pred_triplets = preds\n        if label_objects is None or pred_objects is None:\n            self.num_inst += 1\n            return\n        count = 0\n        # count objects\n        object_matched = [False for obj in label_objects]\n        m = len(pred_objects)\n        gt_obj_num = label_objects.shape[0]\n        for i in range(m):\n            pred = pred_objects[i]\n            for j in range(gt_obj_num):\n                if object_matched[j]:\n                    continue\n                label = label_objects[j]\n                if int(label[0]) == int(pred[0]) and object_iou_thresh(\n                    pred, label, self.iou_thresh\n                ):\n                    count += 1\n                    object_matched[j] = True\n\n        # count predicate and triplet\n        pred_triplets = pred_triplets[pred_triplets[:, 1].argsort()[::-1]]\n        m = min(self.topk, len(pred_triplets))\n        gt_triplet_num = label_triplets.shape[0]\n        triplet_matched = [False for label in 
label_triplets]\n        predicate_matched = [False for label in label_triplets]\n        for i in range(m):\n            pred = pred_triplets[i]\n            for j in range(gt_triplet_num):\n                label = label_triplets[j]\n                if not predicate_matched:\n                    if int(label[2]) == int(pred[2]) and triplet_iou_thresh(\n                        pred, label, self.iou_thresh\n                    ):\n                        count += label[3]\n                        predicate_matched[j] = True\n                if not triplet_matched[j]:\n                    if (\n                        int(label[2]) == int(pred[2])\n                        and int(label[3]) == int(pred[3])\n                        and int(label[4]) == int(pred[4])\n                        and triplet_iou_thresh(pred, label, self.iou_thresh)\n                    ):\n                        count += 1\n                        triplet_matched[j] = True\n        # compute sum\n        total = labels.shape[0]\n        N = gt_obj_num + 2 * total\n        self.sum_metric += count / N\n        self.num_inst += 1\n\n\ndef extract_gt(g, img_size):\n    \"\"\"extract prediction from ground truth graph\"\"\"\n    if g is None or g.number_of_nodes() == 0:\n        return None, None\n    gt_eids = np.where(g.edata[\"rel_class\"].asnumpy() > 0)[0]\n    if len(gt_eids) == 0:\n        return None, None\n\n    gt_class = g.ndata[\"node_class\"][:, 0].asnumpy()\n    gt_bbox = g.ndata[\"bbox\"].asnumpy()\n    gt_bbox[:, 0] /= img_size[1]\n    gt_bbox[:, 1] /= img_size[0]\n    gt_bbox[:, 2] /= img_size[1]\n    gt_bbox[:, 3] /= img_size[0]\n\n    gt_objects = np.vstack([gt_class, gt_bbox.transpose(1, 0)]).transpose(1, 0)\n\n    gt_node_ids = g.find_edges(gt_eids)\n    gt_node_sub = gt_node_ids[0].asnumpy()\n    gt_node_ob = gt_node_ids[1].asnumpy()\n    gt_rel_class = g.edata[\"rel_class\"][gt_eids, 0].asnumpy() - 1\n    gt_sub_class = gt_class[gt_node_sub]\n    gt_ob_class = 
gt_class[gt_node_ob]\n\n    gt_sub_bbox = gt_bbox[gt_node_sub]\n    gt_ob_bbox = gt_bbox[gt_node_ob]\n\n    n = len(gt_eids)\n    gt_triplets = np.vstack(\n        [\n            np.ones(n),\n            np.ones(n),\n            gt_rel_class,\n            gt_sub_class,\n            gt_ob_class,\n            gt_sub_bbox.transpose(1, 0),\n            gt_ob_bbox.transpose(1, 0),\n        ]\n    ).transpose(1, 0)\n    return gt_objects, gt_triplets\n\n\ndef extract_pred(g, topk=100, joint_preds=False):\n    \"\"\"extract prediction from prediction graph for validation and visualization\"\"\"\n    if g is None or g.number_of_nodes() == 0:\n        return None, None\n\n    pred_class = g.ndata[\"node_class_pred\"].asnumpy()\n    pred_class_prob = g.ndata[\"node_class_logit\"].asnumpy()\n    pred_bbox = g.ndata[\"pred_bbox\"][:, 0:4].asnumpy()\n\n    pred_objects = np.vstack([pred_class, pred_bbox.transpose(1, 0)]).transpose(\n        1, 0\n    )\n\n    score_pred = g.edata[\"score_pred\"].asnumpy()\n    score_phr = g.edata[\"score_phr\"].asnumpy()\n    score_pred_topk_eids = (-score_pred).argsort()[0:topk].tolist()\n    score_phr_topk_eids = (-score_phr).argsort()[0:topk].tolist()\n    topk_eids = sorted(list(set(score_pred_topk_eids + score_phr_topk_eids)))\n\n    pred_rel_prob = g.edata[\"preds\"][topk_eids].asnumpy()\n    if joint_preds:\n        pred_rel_class = pred_rel_prob[:, 1:].argmax(axis=1)\n    else:\n        pred_rel_class = pred_rel_prob.argmax(axis=1)\n\n    pred_node_ids = g.find_edges(topk_eids)\n    pred_node_sub = pred_node_ids[0].asnumpy()\n    pred_node_ob = pred_node_ids[1].asnumpy()\n\n    pred_sub_class = pred_class[pred_node_sub]\n    pred_sub_class_prob = pred_class_prob[pred_node_sub]\n    pred_sub_bbox = pred_bbox[pred_node_sub]\n\n    pred_ob_class = pred_class[pred_node_ob]\n    pred_ob_class_prob = pred_class_prob[pred_node_ob]\n    pred_ob_bbox = pred_bbox[pred_node_ob]\n\n    pred_triplets = np.vstack(\n        [\n            
score_pred[topk_eids],\n            score_phr[topk_eids],\n            pred_rel_class,\n            pred_sub_class,\n            pred_ob_class,\n            pred_sub_bbox.transpose(1, 0),\n            pred_ob_bbox.transpose(1, 0),\n        ]\n    ).transpose(1, 0)\n    return pred_objects, pred_triplets\n"
  },
  {
    "path": "examples/mxnet/scenegraph/utils/sampling.py",
    "content": "import dgl\nimport mxnet as mx\nimport numpy as np\nfrom dgl.utils import toindex\n\n\ndef l0_sample(g, positive_max=128, negative_ratio=3):\n    \"\"\"sampling positive and negative edges\"\"\"\n    if g is None:\n        return None\n    n_eids = g.number_of_edges()\n    pos_eids = np.where(g.edata[\"rel_class\"].asnumpy() > 0)[0]\n    neg_eids = np.where(g.edata[\"rel_class\"].asnumpy() == 0)[0]\n    if len(pos_eids) == 0:\n        return None\n\n    positive_num = min(len(pos_eids), positive_max)\n    negative_num = min(len(neg_eids), positive_num * negative_ratio)\n    pos_sample = np.random.choice(pos_eids, positive_num, replace=False)\n    neg_sample = np.random.choice(neg_eids, negative_num, replace=False)\n    weights = np.zeros(n_eids)\n    # np.add.at(weights, pos_sample, 1)\n    weights[pos_sample] = 1\n    weights[neg_sample] = 1\n    # g.edata['sample_weights'] = mx.nd.array(weights, ctx=g.edata['rel_class'].context)\n    # return g\n    eids = np.where(weights > 0)[0]\n    sub_g = g.edge_subgraph(toindex(eids.tolist()))\n    sub_g.copy_from_parent()\n    sub_g.edata[\"sample_weights\"] = mx.nd.array(\n        weights[eids], ctx=g.edata[\"rel_class\"].context\n    )\n    return sub_g\n"
  },
  {
    "path": "examples/mxnet/scenegraph/utils/viz.py",
    "content": "import gluoncv as gcv\nimport numpy as np\nfrom matplotlib import pyplot as plt\n\n\ndef plot_sg(img, preds, obj_classes, rel_classes, topk=1):\n    \"\"\"visualization of generated scene graph\"\"\"\n    size = img.shape[0:2]\n    box_scale = np.array([size[1], size[0], size[1], size[0]])\n    topk = min(topk, preds.shape[0])\n    ax = gcv.utils.viz.plot_image(img)\n    for i in range(topk):\n        rel = int(preds[i, 2])\n        src = int(preds[i, 3])\n        dst = int(preds[i, 4])\n        src_name = obj_classes[src]\n        dst_name = obj_classes[dst]\n        rel_name = rel_classes[rel]\n        src_bbox = preds[i, 5:9] * box_scale\n        dst_bbox = preds[i, 9:13] * box_scale\n\n        src_center = np.array(\n            [(src_bbox[0] + src_bbox[2]) / 2, (src_bbox[1] + src_bbox[3]) / 2]\n        )\n        dst_center = np.array(\n            [(dst_bbox[0] + dst_bbox[2]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]\n        )\n        rel_center = (src_center + dst_center) / 2\n\n        line_x = np.array(\n            [(src_bbox[0] + src_bbox[2]) / 2, (dst_bbox[0] + dst_bbox[2]) / 2]\n        )\n        line_y = np.array(\n            [(src_bbox[1] + src_bbox[3]) / 2, (dst_bbox[1] + dst_bbox[3]) / 2]\n        )\n\n        ax.plot(\n            line_x, line_y, linewidth=3.0, alpha=0.7, color=plt.cm.cool(rel)\n        )\n\n        ax.text(\n            src_center[0],\n            src_center[1],\n            \"{:s}\".format(src_name),\n            bbox=dict(alpha=0.5),\n            fontsize=12,\n            color=\"white\",\n        )\n        ax.text(\n            dst_center[0],\n            dst_center[1],\n            \"{:s}\".format(dst_name),\n            bbox=dict(alpha=0.5),\n            fontsize=12,\n            color=\"white\",\n        )\n        ax.text(\n            rel_center[0],\n            rel_center[1],\n            \"{:s}\".format(rel_name),\n            bbox=dict(alpha=0.5),\n            fontsize=12,\n            
color=\"white\",\n        )\n    return ax\n\n\nplot_sg(img, preds, 2)\n"
  },
  {
    "path": "examples/mxnet/scenegraph/validate_reldn.py",
    "content": "import argparse\nimport logging\nimport time\n\nimport mxnet as mx\nimport numpy as np\nfrom data import *\nfrom gluoncv.data.batchify import Pad\nfrom model import faster_rcnn_resnet101_v1d_custom, RelDN\nfrom mxnet import gluon, nd\nfrom utils import *\n\nimport dgl\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Validate Pre-trained RelDN Model.\"\n    )\n    parser.add_argument(\n        \"--gpus\",\n        type=str,\n        default=\"0\",\n        help=\"Training with GPUs, you can specify 1,3 for example.\",\n    )\n    parser.add_argument(\n        \"--batch-size\",\n        type=int,\n        default=8,\n        help=\"Total batch-size for training.\",\n    )\n    parser.add_argument(\n        \"--metric\",\n        type=str,\n        default=\"sgdet\",\n        help=\"Evaluation metric, could be 'predcls', 'phrcls', 'sgdet' or 'sgdet+'.\",\n    )\n    parser.add_argument(\n        \"--pretrained-faster-rcnn-params\",\n        type=str,\n        required=True,\n        help=\"Path to saved Faster R-CNN model parameters.\",\n    )\n    parser.add_argument(\n        \"--reldn-params\",\n        type=str,\n        required=True,\n        help=\"Path to saved Faster R-CNN model parameters.\",\n    )\n    parser.add_argument(\n        \"--faster-rcnn-params\",\n        type=str,\n        required=True,\n        help=\"Path to saved Faster R-CNN model parameters.\",\n    )\n    parser.add_argument(\n        \"--log-dir\",\n        type=str,\n        default=\"reldn_output.log\",\n        help=\"Path to save training logs.\",\n    )\n    parser.add_argument(\n        \"--freq-prior\",\n        type=str,\n        default=\"freq_prior.pkl\",\n        help=\"Path to saved frequency prior data.\",\n    )\n    parser.add_argument(\n        \"--verbose-freq\",\n        type=int,\n        default=100,\n        help=\"Frequency of log printing in number of iterations.\",\n    )\n    args = parser.parse_args()\n    
return args\n\n\nargs = parse_args()\n\nfilehandler = logging.FileHandler(args.log_dir)\nstreamhandler = logging.StreamHandler()\nlogger = logging.getLogger(\"\")\nlogger.setLevel(logging.INFO)\nlogger.addHandler(filehandler)\nlogger.addHandler(streamhandler)\n\n# Hyperparams\nctx = [mx.gpu(int(i)) for i in args.gpus.split(\",\") if i.strip()]\nif ctx:\n    num_gpus = len(ctx)\n    assert args.batch_size % num_gpus == 0\n    per_device_batch_size = int(args.batch_size / num_gpus)\nelse:\n    ctx = [mx.cpu()]\n    per_device_batch_size = args.batch_size\nbatch_size = args.batch_size\nN_relations = 50\nN_objects = 150\nbatch_verbose_freq = args.verbose_freq\n\nmode = args.metric\nmetric_list = []\ntopk_list = [20, 50, 100]\nif mode == \"predcls\":\n    for topk in topk_list:\n        metric_list.append(PredCls(topk=topk))\nif mode == \"phrcls\":\n    for topk in topk_list:\n        metric_list.append(PhrCls(topk=topk))\nif mode == \"sgdet\":\n    for topk in topk_list:\n        metric_list.append(SGDet(topk=topk))\nif mode == \"sgdet+\":\n    for topk in topk_list:\n        metric_list.append(SGDetPlus(topk=topk))\nfor metric in metric_list:\n    metric.reset()\n\nsemantic_only = False\nnet = RelDN(\n    n_classes=N_relations,\n    prior_pkl=args.freq_prior,\n    semantic_only=semantic_only,\n)\nnet.load_parameters(args.reldn_params, ctx=ctx)\n\n# dataset and dataloader\nvg_val = VGRelation(split=\"val\")\nlogger.info(\"data loaded!\")\nval_data = gluon.data.DataLoader(\n    vg_val,\n    batch_size=len(ctx),\n    shuffle=False,\n    num_workers=16 * num_gpus,\n    batchify_fn=dgl_mp_batchify_fn,\n)\nn_batches = len(val_data)\n\ndetector = faster_rcnn_resnet101_v1d_custom(\n    classes=vg_val.obj_classes,\n    pretrained_base=False,\n    pretrained=False,\n    additional_output=True,\n)\nparams_path = args.pretrained_faster_rcnn_params\ndetector.load_parameters(\n    params_path, ctx=ctx, ignore_extra=True, allow_missing=True\n)\n\ndetector_feat = 
faster_rcnn_resnet101_v1d_custom(\n    classes=vg_val.obj_classes,\n    pretrained_base=False,\n    pretrained=False,\n    additional_output=True,\n)\ndetector_feat.load_parameters(\n    params_path, ctx=ctx, ignore_extra=True, allow_missing=True\n)\n\ndetector_feat.features.load_parameters(args.faster_rcnn_params, ctx=ctx)\n\n\ndef get_data_batch(g_list, img_list, ctx_list):\n    if g_list is None or len(g_list) == 0:\n        return None, None\n    n_gpu = len(ctx_list)\n    size = len(g_list)\n    if size < n_gpu:\n        raise Exception(\"too small batch\")\n    step = size // n_gpu\n    G_list = [\n        g_list[i * step : (i + 1) * step]\n        if i < n_gpu - 1\n        else g_list[i * step : size]\n        for i in range(n_gpu)\n    ]\n    img_list = [\n        img_list[i * step : (i + 1) * step]\n        if i < n_gpu - 1\n        else img_list[i * step : size]\n        for i in range(n_gpu)\n    ]\n\n    for G_slice, ctx in zip(G_list, ctx_list):\n        for G in G_slice:\n            G.ndata[\"bbox\"] = G.ndata[\"bbox\"].as_in_context(ctx)\n            G.ndata[\"node_class\"] = G.ndata[\"node_class\"].as_in_context(ctx)\n            G.ndata[\"node_class_vec\"] = G.ndata[\"node_class_vec\"].as_in_context(\n                ctx\n            )\n            G.edata[\"rel_class\"] = G.edata[\"rel_class\"].as_in_context(ctx)\n    img_list = [img.as_in_context(ctx) for img in img_list]\n    return G_list, img_list\n\n\nfor i, (G_list, img_list) in enumerate(val_data):\n    G_list, img_list = get_data_batch(G_list, img_list, ctx)\n    if G_list is None or img_list is None:\n        if (i + 1) % batch_verbose_freq == 0:\n            print_txt = \"Batch[%d/%d] \" % (i, n_batches)\n            for metric in metric_list:\n                metric_name, metric_val = metric.get()\n                print_txt += \"%s=%.4f \" % (metric_name, metric_val)\n            logger.info(print_txt)\n        continue\n\n    detector_res_list = []\n    G_batch = []\n    bbox_pad = 
Pad(axis=(0))\n    # loss_cls_val = 0\n    for G_slice, img in zip(G_list, img_list):\n        cur_ctx = img.context\n        if mode == \"predcls\":\n            bbox_list = [G.ndata[\"bbox\"] for G in G_slice]\n            bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)\n            ids, scores, bbox, spatial_feat = detector(\n                img, None, None, bbox_stack\n            )\n\n            node_class_list = [G.ndata[\"node_class\"] for G in G_slice]\n            node_class_stack = bbox_pad(node_class_list).as_in_context(cur_ctx)\n            g_pred_batch = build_graph_validate_gt_obj(\n                img,\n                node_class_stack,\n                bbox,\n                spatial_feat,\n                bbox_improvement=True,\n                overlap=False,\n            )\n        elif mode == \"phrcls\":\n            # use ground truth bbox\n            bbox_list = [G.ndata[\"bbox\"] for G in G_slice]\n            bbox_stack = bbox_pad(bbox_list).as_in_context(cur_ctx)\n            ids, scores, bbox, spatial_feat = detector(\n                img, None, None, bbox_stack\n            )\n\n            g_pred_batch = build_graph_validate_gt_bbox(\n                img,\n                ids,\n                scores,\n                bbox,\n                spatial_feat,\n                bbox_improvement=True,\n                overlap=False,\n            )\n        else:\n            # use predicted bbox\n            ids, scores, bbox, feat, feat_ind, spatial_feat = detector(img)\n            g_pred_batch = build_graph_validate_pred(\n                img,\n                ids,\n                scores,\n                bbox,\n                feat_ind,\n                spatial_feat,\n                bbox_improvement=True,\n                scores_top_k=75,\n                overlap=False,\n            )\n        if not semantic_only:\n            rel_bbox = g_pred_batch.edata[\"rel_bbox\"]\n            batch_id = 
g_pred_batch.edata[\"batch_id\"].asnumpy()\n            n_sample_edges = g_pred_batch.number_of_edges()\n            # g_pred_batch.edata['edge_feat'] = mx.nd.zeros((n_sample_edges, 49), ctx=cur_ctx)\n            n_graph = len(G_slice)\n            bbox_rel_list = []\n            for j in range(n_graph):\n                eids = np.where(batch_id == j)[0]\n                if len(eids) > 0:\n                    bbox_rel_list.append(rel_bbox[eids])\n            bbox_rel_stack = bbox_pad(bbox_rel_list).as_in_context(cur_ctx)\n            _, _, _, spatial_feat_rel = detector_feat(\n                img, None, None, bbox_rel_stack\n            )\n            spatial_feat_rel_list = []\n            for j in range(n_graph):\n                eids = np.where(batch_id == j)[0]\n                if len(eids) > 0:\n                    spatial_feat_rel_list.append(\n                        spatial_feat_rel[j, 0 : len(eids)]\n                    )\n            g_pred_batch.edata[\"edge_feat\"] = nd.concat(\n                *spatial_feat_rel_list, dim=0\n            )\n\n        G_batch.append(g_pred_batch)\n\n    G_batch = [net(G) for G in G_batch]\n\n    for G_slice, G_pred, img_slice in zip(G_list, G_batch, img_list):\n        for G_gt, G_pred_one in zip(G_slice, [G_pred]):\n            if G_pred_one is None or G_pred_one.number_of_nodes() == 0:\n                continue\n            gt_objects, gt_triplet = extract_gt(G_gt, img_slice.shape[2:4])\n            pred_objects, pred_triplet = extract_pred(G_pred, joint_preds=True)\n            for metric in metric_list:\n                if (\n                    isinstance(metric, PredCls)\n                    or isinstance(metric, PhrCls)\n                    or isinstance(metric, SGDet)\n                ):\n                    metric.update(gt_triplet, pred_triplet)\n                else:\n                    metric.update(\n                        (gt_objects, gt_triplet), (pred_objects, pred_triplet)\n                    )\n    if 
(i + 1) % batch_verbose_freq == 0:\n        print_txt = \"Batch[%d/%d] \" % (i, n_batches)\n        for metric in metric_list:\n            metric_name, metric_val = metric.get()\n            print_txt += \"%s=%.4f \" % (metric_name, metric_val)\n        logger.info(print_txt)\n\nprint_txt = \"Batch[%d/%d] \" % (n_batches, n_batches)\nfor metric in metric_list:\n    metric_name, metric_val = metric.get()\n    print_txt += \"%s=%.4f \" % (metric_name, metric_val)\nlogger.info(print_txt)\n"
  },
  {
    "path": "examples/mxnet/scenegraph/validate_reldn.sh",
    "content": "MXNET_CUDNN_AUTOTUNE_DEFAULT=0 python validate_reldn.py \\\n    --pretrained-faster-rcnn-params faster_rcnn_resnet101_v1d_visualgenome/faster_rcnn_resnet101_v1d_custom_best.params \\\n    --reldn-params params_resnet101_v1d_reldn/model-8.params \\\n    --faster-rcnn-params params_resnet101_v1d_reldn/detector_feat.features-8.params\n"
  },
  {
    "path": "examples/mxnet/sgc/README.md",
    "content": "Simple Graph Convolution (SGC)\n============\n\n- Paper link: [Simplifying Graph Convolutional Networks](https://arxiv.org/abs/1902.07153)\n- Author's code repo: [https://github.com/Tiiiger/SGC](https://github.com/Tiiiger/SGC). \n\nDependencies\n------------\n- MXNET 1.5+\n- requests\n\n``bash\npip install torch requests\n``\n\nCodes\n-----\nThe folder contains an implementation of SGC (`sgc.py`).\n\nResults\n-------\n\nRun with following (available dataset: \"cora\", \"citeseer\", \"pubmed\")\n```bash\nDGLBACKEND=mxnet python3 sgc.py --dataset cora --gpu 0\nDGLBACKEND=mxnet python3 sgc.py --dataset citeseer --weight-decay 5e-5 --n-epochs 150 --bias --gpu 0\nDGLBACKEND=mxnet python3 sgc.py --dataset pubmed --weight-decay 5e-5 --bias --gpu 0\n```\n\nOn NVIDIA V100\n\n* cora: 0.818 (paper: 0.810)\n* citeseer: 0.725 (paper: 0.719)\n* pubmed: 0.788 (paper: 0.789)\n"
  },
  {
    "path": "examples/mxnet/sgc/sgc.py",
    "content": "\"\"\"\nThis code was modified from the GCN implementation in DGL examples.\nSimplifying Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1902.07153\nCode: https://github.com/Tiiiger/SGC\nSGC implementation in DGL.\n\"\"\"\nimport argparse\nimport math\nimport time\n\nimport dgl\n\nimport mxnet as mx\nimport numpy as np\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom dgl.nn.mxnet.conv import SGConv\nfrom mxnet import gluon, nd\nfrom mxnet.gluon import nn\n\n\ndef evaluate(model, g, features, labels, mask):\n    pred = model(g, features).argmax(axis=1)\n    accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()\n    return accuracy.asscalar()\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n        ctx = mx.cpu(0)\n    else:\n        cuda = True\n        ctx = mx.gpu(args.gpu)\n        g = g.int().to(ctx)\n\n    features = g.ndata[\"feat\"]\n    labels = mx.nd.array(g.ndata[\"label\"], dtype=\"float32\", ctx=ctx)\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = data.graph.number_of_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.sum().asscalar(),\n            val_mask.sum().asscalar(),\n            
test_mask.sum().asscalar(),\n        )\n    )\n\n    # add self loop\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # create SGC model\n    model = SGConv(in_feats, n_classes, k=2, cached=True, bias=args.bias)\n\n    model.initialize(ctx=ctx)\n    n_train_samples = train_mask.sum().asscalar()\n    loss_fcn = gluon.loss.SoftmaxCELoss()\n\n    # use optimizer\n    print(model.collect_params())\n    trainer = gluon.Trainer(\n        model.collect_params(),\n        \"adam\",\n        {\"learning_rate\": args.lr, \"wd\": args.weight_decay},\n    )\n\n    # initialize graph\n    dur = []\n    for epoch in range(args.n_epochs):\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        with mx.autograd.record():\n            pred = model(g, features)\n            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))\n            loss = loss.sum() / n_train_samples\n\n        loss.backward()\n        trainer.step(batch_size=1)\n\n        if epoch >= 3:\n            loss.asscalar()\n            dur.append(time.time() - t0)\n            acc = evaluate(model, g, features, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss.asscalar(),\n                    acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n    # test set accuracy\n    acc = evaluate(model, g, features, labels, test_mask)\n    print(\"Test accuracy {:.2%}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"SGC\")\n    register_data_args(parser)\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=0.2, help=\"learning rate\")\n    parser.add_argument(\n        \"--bias\", 
action=\"store_true\", default=False, help=\"flag to use bias\"\n    )\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=100, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-6, help=\"Weight for L2 loss\"\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/mxnet/tagcn/README.md",
    "content": "Topology Adaptive Graph Convolutional networks (TAGCN)\n============\n\n- Paper link: [https://arxiv.org/abs/1710.10370](https://arxiv.org/abs/1710.10370)\n\nDependencies\n------------\n- MXNet nightly build\n- requests\n\n``bash\npip install mxnet --pre\npip install requests\n``\n\nResults\n-------\nRun with following (available dataset: \"cora\", \"citeseer\", \"pubmed\")\n```bash\nDGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0 --self-loop\n```\n\n* cora: ~0.820 (paper: 0.833)\n* citeseer: ~0.702 (paper: 0.714)\n* pubmed: ～0.798 (paper: 0.811)"
  },
  {
    "path": "examples/mxnet/tagcn/tagcn.py",
    "content": "\"\"\"TAGCN using DGL nn package\n\nReferences:\n- Topology Adaptive Graph Convolutional Networks\n- Paper: https://arxiv.org/abs/1710.10370\n\"\"\"\nimport dgl\nimport mxnet as mx\nfrom dgl.nn.mxnet import TAGConv\nfrom mxnet import gluon\n\n\nclass TAGCN(gluon.Block):\n    def __init__(\n        self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(TAGCN, self).__init__()\n        self.g = g\n        self.layers = gluon.nn.Sequential()\n        # input layer\n        self.layers.add(TAGConv(in_feats, n_hidden, activation=activation))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.add(TAGConv(n_hidden, n_hidden, activation=activation))\n        # output layer\n        self.layers.add(TAGConv(n_hidden, n_classes))  # activation=None\n        self.dropout = gluon.nn.Dropout(rate=dropout)\n\n    def forward(self, features):\n        h = features\n        for i, layer in enumerate(self.layers):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(self.g, h)\n        return h\n"
  },
  {
    "path": "examples/mxnet/tagcn/train.py",
    "content": "import argparse\nimport time\n\nimport dgl\n\nimport mxnet as mx\nimport networkx as nx\nimport numpy as np\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom mxnet import gluon\nfrom tagcn import TAGCN\n\n\ndef evaluate(model, features, labels, mask):\n    pred = model(features).argmax(axis=1)\n    accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()\n    return accuracy.asscalar()\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n        ctx = mx.cpu(0)\n    else:\n        cuda = True\n        ctx = mx.gpu(args.gpu)\n        g = g.to(ctx)\n\n    features = g.ndata[\"feat\"]\n    labels = mx.nd.array(g.ndata[\"label\"], dtype=\"float32\", ctx=ctx)\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = data.graph.number_of_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.sum().asscalar(),\n            val_mask.sum().asscalar(),\n            test_mask.sum().asscalar(),\n        )\n    )\n\n    # add self loop\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # create TAGCN model\n    model = TAGCN(\n        g,\n        in_feats,\n        args.n_hidden,\n        n_classes,\n        args.n_layers,\n        
mx.nd.relu,\n        args.dropout,\n    )\n\n    model.initialize(ctx=ctx)\n    n_train_samples = train_mask.sum().asscalar()\n    loss_fcn = gluon.loss.SoftmaxCELoss()\n\n    # use optimizer\n    print(model.collect_params())\n    trainer = gluon.Trainer(\n        model.collect_params(),\n        \"adam\",\n        {\"learning_rate\": args.lr, \"wd\": args.weight_decay},\n    )\n\n    # initialize graph\n    dur = []\n    for epoch in range(args.n_epochs):\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        with mx.autograd.record():\n            pred = model(features)\n            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))\n            loss = loss.sum() / n_train_samples\n\n        loss.backward()\n        trainer.step(batch_size=1)\n\n        if epoch >= 3:\n            loss.asscalar()\n            dur.append(time.time() - t0)\n            acc = evaluate(model, features, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss.asscalar(),\n                    acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n    print()\n    acc = evaluate(model, features, labels, val_mask)\n    print(\"Test accuracy {:.2%}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"TAGCN\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        
\"--n-hidden\", type=int, default=16, help=\"number of hidden tagcn units\"\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=1, help=\"number of hidden tagcn layers\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 loss\"\n    )\n    parser.add_argument(\n        \"--self-loop\",\n        action=\"store_true\",\n        help=\"graph self-loop (default=False)\",\n    )\n    parser.set_defaults(self_loop=False)\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/mxnet/tree_lstm/README.md",
    "content": "# Tree-LSTM\r\nThis is a re-implementation of the following paper:\r\n\r\n> [**Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks**](http://arxiv.org/abs/1503.00075)\r\n> *Kai Sheng Tai, Richard Socher, and Christopher Manning*.\r\n\r\nThe provided implementation can achieve a test accuracy of 51.72 which is comparable with the result reported in the original paper: 51.0(±0.5).\r\n\r\n## Dependencies\r\n* MXNet nightly build\r\n* requests\r\n* nltk\r\n\r\n```bash\r\npip install mxnet --pre\r\npip install requests nltk\r\n```\r\n\r\n## Data\r\nThe script will download the [SST dataset] (http://nlp.stanford.edu/sentiment/index.html) and the GloVe 840B.300d embedding automatically if `--use-glove` is specified (note: download may take a while).\r\n\r\n## Usage\r\n```\r\nDGLBACKEND=mxnet python3 train.py --gpu 0\r\n```\r\n\r\n## Speed Test\r\n\r\nSee https://docs.google.com/spreadsheets/d/1eCQrVn7g0uWriz63EbEDdes2ksMdKdlbWMyT8PSU4rc .\r\n\r\n## Note\r\nThe code can work with MXNet 1.5.1\r\n"
  },
  {
    "path": "examples/mxnet/tree_lstm/train.py",
    "content": "import argparse\nimport collections\nimport os\nimport time\nimport warnings\nimport zipfile\n\nos.environ[\"DGLBACKEND\"] = \"mxnet\"\nos.environ[\"MXNET_GPU_MEM_POOL_TYPE\"] = \"Round\"\n\nimport dgl\nimport dgl.data as data\nimport mxnet as mx\nimport numpy as np\nfrom mxnet import gluon\nfrom tree_lstm import TreeLSTM\n\nSSTBatch = collections.namedtuple(\n    \"SSTBatch\", [\"graph\", \"mask\", \"wordid\", \"label\"]\n)\n\n\ndef batcher(ctx):\n    def batcher_dev(batch):\n        batch_trees = dgl.batch(batch)\n        return SSTBatch(\n            graph=batch_trees,\n            mask=batch_trees.ndata[\"mask\"].as_in_context(ctx),\n            wordid=batch_trees.ndata[\"x\"].as_in_context(ctx),\n            label=batch_trees.ndata[\"y\"].as_in_context(ctx),\n        )\n\n    return batcher_dev\n\n\ndef prepare_glove():\n    if not (\n        os.path.exists(\"glove.840B.300d.txt\")\n        and data.utils.check_sha1(\n            \"glove.840B.300d.txt\",\n            sha1_hash=\"294b9f37fa64cce31f9ebb409c266fc379527708\",\n        )\n    ):\n        zip_path = data.utils.download(\n            \"http://nlp.stanford.edu/data/glove.840B.300d.zip\",\n            sha1_hash=\"8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d\",\n        )\n        with zipfile.ZipFile(zip_path, \"r\") as zf:\n            zf.extractall()\n        if not data.utils.check_sha1(\n            \"glove.840B.300d.txt\",\n            sha1_hash=\"294b9f37fa64cce31f9ebb409c266fc379527708\",\n        ):\n            warnings.warn(\n                \"The downloaded glove embedding file checksum mismatch. 
File content \"\n                \"may be corrupted.\"\n            )\n\n\ndef main(args):\n    np.random.seed(args.seed)\n    mx.random.seed(args.seed)\n\n    best_epoch = -1\n    best_dev_acc = 0\n\n    cuda = args.gpu >= 0\n    if cuda:\n        if args.gpu in mx.test_utils.list_gpus():\n            ctx = mx.gpu(args.gpu)\n        else:\n            print(\n                \"Requested GPU id {} was not found. Defaulting to CPU implementation\".format(\n                    args.gpu\n                )\n            )\n            ctx = mx.cpu()\n    else:\n        ctx = mx.cpu()\n\n    if args.use_glove:\n        prepare_glove()\n\n    trainset = data.SSTDataset()\n    train_loader = gluon.data.DataLoader(\n        dataset=trainset,\n        batch_size=args.batch_size,\n        batchify_fn=batcher(ctx),\n        shuffle=True,\n        num_workers=0,\n    )\n    devset = data.SSTDataset(mode=\"dev\")\n    dev_loader = gluon.data.DataLoader(\n        dataset=devset,\n        batch_size=100,\n        batchify_fn=batcher(ctx),\n        shuffle=True,\n        num_workers=0,\n    )\n\n    testset = data.SSTDataset(mode=\"test\")\n    test_loader = gluon.data.DataLoader(\n        dataset=testset,\n        batch_size=100,\n        batchify_fn=batcher(ctx),\n        shuffle=False,\n        num_workers=0,\n    )\n\n    model = TreeLSTM(\n        trainset.vocab_size,\n        args.x_size,\n        args.h_size,\n        trainset.num_classes,\n        args.dropout,\n        cell_type=\"childsum\" if args.child_sum else \"nary\",\n        pretrained_emb=trainset.pretrained_emb,\n        ctx=ctx,\n    )\n    print(model)\n    params_ex_emb = [\n        x\n        for x in model.collect_params().values()\n        if x.grad_req != \"null\" and x.shape[0] != trainset.vocab_size\n    ]\n    params_emb = list(model.embedding.collect_params().values())\n    for p in params_emb:\n        p.lr_mult = 0.1\n\n    model.initialize(mx.init.Xavier(magnitude=1), ctx=ctx)\n    
model.hybridize()\n    trainer = gluon.Trainer(\n        model.collect_params(\"^(?!embedding).*$\"),\n        \"adagrad\",\n        {\"learning_rate\": args.lr, \"wd\": args.weight_decay},\n    )\n    trainer_emb = gluon.Trainer(\n        model.collect_params(\"^embedding.*$\"),\n        \"adagrad\",\n        {\"learning_rate\": args.lr},\n    )\n\n    dur = []\n    L = gluon.loss.SoftmaxCrossEntropyLoss(axis=1)\n    for epoch in range(args.epochs):\n        t_epoch = time.time()\n        for step, batch in enumerate(train_loader):\n            g = batch.graph\n            n = g.number_of_nodes()\n\n            # TODO begin_states function?\n            h = mx.nd.zeros((n, args.h_size), ctx=ctx)\n            c = mx.nd.zeros((n, args.h_size), ctx=ctx)\n            if step >= 3:\n                t0 = time.time()  # tik\n            with mx.autograd.record():\n                pred = model(batch, h, c)\n                loss = L(pred, batch.label)\n\n            loss.backward()\n            trainer.step(args.batch_size)\n            trainer_emb.step(args.batch_size)\n\n            if step >= 3:\n                dur.append(time.time() - t0)  # tok\n\n            if step > 0 and step % args.log_every == 0:\n                pred = pred.argmax(axis=1).astype(batch.label.dtype)\n                acc = (batch.label == pred).sum()\n                root_ids = [\n                    i\n                    for i in range(batch.graph.number_of_nodes())\n                    if batch.graph.out_degrees(i) == 0\n                ]\n                root_acc = np.sum(\n                    batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]\n                )\n\n                print(\n                    \"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}\".format(\n                        epoch,\n                        step,\n                        loss.sum().asscalar(),\n                        1.0 * acc.asscalar() / len(batch.label),\n  
                      1.0 * root_acc / len(root_ids),\n                        np.mean(dur),\n                    )\n                )\n        print(\n            \"Epoch {:05d} training time {:.4f}s\".format(\n                epoch, time.time() - t_epoch\n            )\n        )\n\n        # eval on dev set\n        accs = []\n        root_accs = []\n        for step, batch in enumerate(dev_loader):\n            g = batch.graph\n            n = g.number_of_nodes()\n            h = mx.nd.zeros((n, args.h_size), ctx=ctx)\n            c = mx.nd.zeros((n, args.h_size), ctx=ctx)\n            pred = model(batch, h, c).argmax(1).astype(batch.label.dtype)\n\n            acc = (batch.label == pred).sum().asscalar()\n            accs.append([acc, len(batch.label)])\n            root_ids = [\n                i\n                for i in range(batch.graph.number_of_nodes())\n                if batch.graph.out_degrees(i) == 0\n            ]\n            root_acc = np.sum(\n                batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]\n            )\n            root_accs.append([root_acc, len(root_ids)])\n\n        dev_acc = (\n            1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])\n        )\n        dev_root_acc = (\n            1.0\n            * np.sum([x[0] for x in root_accs])\n            / np.sum([x[1] for x in root_accs])\n        )\n        print(\n            \"Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}\".format(\n                epoch, dev_acc, dev_root_acc\n            )\n        )\n\n        if dev_root_acc > best_dev_acc:\n            best_dev_acc = dev_root_acc\n            best_epoch = epoch\n            model.save_parameters(\"best_{}.params\".format(args.seed))\n        else:\n            if best_epoch <= epoch - 10:\n                break\n\n        # lr decay\n        trainer.set_learning_rate(max(1e-5, trainer.learning_rate * 0.99))\n        print(trainer.learning_rate)\n        trainer_emb.set_learning_rate(\n   
         max(1e-5, trainer_emb.learning_rate * 0.99)\n        )\n        print(trainer_emb.learning_rate)\n\n    # test\n    model.load_parameters(\"best_{}.params\".format(args.seed))\n    accs = []\n    root_accs = []\n    for step, batch in enumerate(test_loader):\n        g = batch.graph\n        n = g.number_of_nodes()\n        h = mx.nd.zeros((n, args.h_size), ctx=ctx)\n        c = mx.nd.zeros((n, args.h_size), ctx=ctx)\n        pred = model(batch, h, c).argmax(axis=1).astype(batch.label.dtype)\n\n        acc = (batch.label == pred).sum().asscalar()\n        accs.append([acc, len(batch.label)])\n        root_ids = [\n            i\n            for i in range(batch.graph.number_of_nodes())\n            if batch.graph.out_degrees(i) == 0\n        ]\n        root_acc = np.sum(\n            batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids]\n        )\n        root_accs.append([root_acc, len(root_ids)])\n\n    test_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])\n    test_root_acc = (\n        1.0\n        * np.sum([x[0] for x in root_accs])\n        / np.sum([x[1] for x in root_accs])\n    )\n    print(\n        \"------------------------------------------------------------------------------------\"\n    )\n    print(\n        \"Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}\".format(\n            best_epoch, test_acc, test_root_acc\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--gpu\", type=int, default=0)\n    parser.add_argument(\"--seed\", type=int, default=41)\n    parser.add_argument(\"--batch-size\", type=int, default=256)\n    parser.add_argument(\"--child-sum\", action=\"store_true\")\n    parser.add_argument(\"--x-size\", type=int, default=300)\n    parser.add_argument(\"--h-size\", type=int, default=150)\n    parser.add_argument(\"--epochs\", type=int, default=100)\n    parser.add_argument(\"--log-every\", type=int, default=5)\n    
parser.add_argument(\"--lr\", type=float, default=0.05)\n    parser.add_argument(\"--weight-decay\", type=float, default=1e-4)\n    parser.add_argument(\"--dropout\", type=float, default=0.5)\n    parser.add_argument(\"--use-glove\", action=\"store_true\")\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/mxnet/tree_lstm/tree_lstm.py",
    "content": "\"\"\"\nImproved Semantic Representations From Tree-Structured Long Short-Term Memory Networks\nhttps://arxiv.org/abs/1503.00075\n\"\"\"\nimport itertools\nimport time\n\nimport dgl\n\nimport mxnet as mx\nimport networkx as nx\nimport numpy as np\nfrom mxnet import gluon\n\n\nclass _TreeLSTMCellNodeFunc(gluon.HybridBlock):\n    def hybrid_forward(self, F, iou, b_iou, c):\n        iou = F.broadcast_add(iou, b_iou)\n        i, o, u = iou.split(num_outputs=3, axis=1)\n        i, o, u = i.sigmoid(), o.sigmoid(), u.tanh()\n        c = i * u + c\n        h = o * c.tanh()\n\n        return h, c\n\n\nclass _TreeLSTMCellReduceFunc(gluon.HybridBlock):\n    def __init__(self, U_iou, U_f):\n        super(_TreeLSTMCellReduceFunc, self).__init__()\n        self.U_iou = U_iou\n        self.U_f = U_f\n\n    def hybrid_forward(self, F, h, c):\n        h_cat = h.reshape((0, -1))\n        f = self.U_f(h_cat).sigmoid().reshape_like(h)\n        c = (f * c).sum(axis=1)\n        iou = self.U_iou(h_cat)\n        return iou, c\n\n\nclass _TreeLSTMCell(gluon.HybridBlock):\n    def __init__(self, h_size):\n        super(_TreeLSTMCell, self).__init__()\n        self._apply_node_func = _TreeLSTMCellNodeFunc()\n        self.b_iou = self.params.get(\n            \"bias\", shape=(1, 3 * h_size), init=\"zeros\"\n        )\n\n    def message_func(self, edges):\n        return {\"h\": edges.src[\"h\"], \"c\": edges.src[\"c\"]}\n\n    def apply_node_func(self, nodes):\n        iou = nodes.data[\"iou\"]\n        b_iou, c = self.b_iou.data(iou.context), nodes.data[\"c\"]\n        h, c = self._apply_node_func(iou, b_iou, c)\n        return {\"h\": h, \"c\": c}\n\n\nclass TreeLSTMCell(_TreeLSTMCell):\n    def __init__(self, x_size, h_size):\n        super(TreeLSTMCell, self).__init__(h_size)\n        self._reduce_func = _TreeLSTMCellReduceFunc(\n            gluon.nn.Dense(3 * h_size, use_bias=False),\n            gluon.nn.Dense(2 * h_size),\n        )\n        self.W_iou = 
gluon.nn.Dense(3 * h_size, use_bias=False)\n\n    def reduce_func(self, nodes):\n        h, c = nodes.mailbox[\"h\"], nodes.mailbox[\"c\"]\n        iou, c = self._reduce_func(h, c)\n        return {\"iou\": iou, \"c\": c}\n\n\nclass ChildSumTreeLSTMCell(_TreeLSTMCell):\n    def __init__(self, x_size, h_size):\n        super(ChildSumTreeLSTMCell, self).__init__()\n        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)\n        self.U_iou = gluon.nn.Dense(3 * h_size, use_bias=False)\n        self.U_f = gluon.nn.Dense(h_size)\n\n    def reduce_func(self, nodes):\n        h_tild = nodes.mailbox[\"h\"].sum(axis=1)\n        f = self.U_f(nodes.mailbox[\"h\"]).sigmoid()\n        c = (f * nodes.mailbox[\"c\"]).sum(axis=1)\n        return {\"iou\": self.U_iou(h_tild), \"c\": c}\n\n\nclass TreeLSTM(gluon.nn.Block):\n    def __init__(\n        self,\n        num_vocabs,\n        x_size,\n        h_size,\n        num_classes,\n        dropout,\n        cell_type=\"nary\",\n        pretrained_emb=None,\n        ctx=None,\n    ):\n        super(TreeLSTM, self).__init__()\n        self.x_size = x_size\n        self.embedding = gluon.nn.Embedding(num_vocabs, x_size)\n        if pretrained_emb is not None:\n            print(\"Using glove\")\n            self.embedding.initialize(ctx=ctx)\n            self.embedding.weight.set_data(pretrained_emb)\n        self.dropout = gluon.nn.Dropout(dropout)\n        self.linear = gluon.nn.Dense(num_classes)\n        cell = TreeLSTMCell if cell_type == \"nary\" else ChildSumTreeLSTMCell\n        self.cell = cell(x_size, h_size)\n        self.ctx = ctx\n\n    def forward(self, batch, h, c):\n        \"\"\"Compute tree-lstm prediction given a batch.\n        Parameters\n        ----------\n        batch : dgl.data.SSTBatch\n            The data batch.\n        h : Tensor\n            Initial hidden state.\n        c : Tensor\n            Initial cell state.\n        Returns\n        -------\n        logits : Tensor\n            The 
prediction of each node.\n        \"\"\"\n        g = batch.graph\n        g = g.to(self.ctx)\n        # feed embedding\n        embeds = self.embedding(batch.wordid * batch.mask)\n        wiou = self.cell.W_iou(self.dropout(embeds))\n        g.ndata[\"iou\"] = wiou * batch.mask.expand_dims(-1).astype(wiou.dtype)\n        g.ndata[\"h\"] = h\n        g.ndata[\"c\"] = c\n        # propagate\n        dgl.prop_nodes_topo(\n            g,\n            message_func=self.cell.message_func,\n            reduce_func=self.cell.reduce_func,\n            apply_node_func=self.cell.apply_node_func,\n        )\n        # compute logits\n        h = self.dropout(g.ndata.pop(\"h\"))\n        logits = self.linear(h)\n        return logits\n"
  },
  {
    "path": "examples/pytorch/GATNE-T/README.md",
    "content": "Representation Learning for Attributed Multiplex Heterogeneous Network (GANTE)\n============\n\n- Paper link: [https://arxiv.org/abs/1905.01669](https://arxiv.org/abs/1905.01669)\n- Author's code repo: [https://github.com/THUDM/GATNE](https://github.com/THUDM/GATNE). Note that only GATNE-T is implemented here.\n\nRequirements\n------------\n- requirements\n\n```bash\npip install -r requirements.txt\n```\n\nAlso requires PyTorch 1.7.0+.\n\nDatasets\n--------\n\nTo prepare the datasets:\n1. ```bash\n   mkdir data\n   cd data\n   ```\n2. Download datasets from the following links:\n    - example: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/example.zip\n    - amazon: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/amazon.zip\n    - youtube: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/youtube.zip\n    - twitter: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/twitter.zip\n3. Unzip the datasets\n\nTraining\n--------\n\nRun with following (available dataset: \"example\", \"youtube\", \"amazon\")\n```bash\npython src/main.py --input data/example\n```\n\nTo run on \"twitter\" dataset, use\n```bash\npython src/main.py --input data/twitter --eval-type 1 --gpu 0\n```\n\nFor a big dataset, use sparse to avoid cuda out of memory in backward\n```bash\npython src/main_sparse.py --input data/example --gpu 0\n```\n\nIf you have multiple GPUs, you can also accelerate training with [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)\n```bash\npython src/main_sparse_multi_gpus.py --input data/example --gpu 0,1\n```\n\n**It is worth noting that DistributedDataParallel will cause more cuda memory consumption and a certain loss of preformance.**\n\n\nResults\n-------\nAll the results match the [official code](https://github.com/THUDM/GATNE/blob/master/src/main_pytorch.py) with the same hyper parameter values, including twiiter 
dataset (auc, pr, f1 is 76.29, 76.17, 69.34, respectively).\n\n|         | auc   | pr    | f1    |\n| ------- | ----- | ----- | ----- |\n| amazon  | 96.88 | 96.31 | 92.12 |\n| youtube | 82.29 | 80.35 | 74.63 |\n| twitter | 72.40 | 74.40 | 65.89 |\n| example | 94.65 | 94.57 | 89.99 |\n"
  },
  {
    "path": "examples/pytorch/GATNE-T/requirements.txt",
    "content": "tqdm\nnumpy\nscikit-learn\nnetworkx\ngensim\nrequests\n--pre dgl-cu101\n"
  },
  {
    "path": "examples/pytorch/GATNE-T/scripts/run_example.sh",
    "content": "python src/main.py --input data/example --gpu 0\n"
  },
  {
    "path": "examples/pytorch/GATNE-T/scripts/run_example_sparse.sh",
    "content": "python src/main_sparse.py --input data/example --gpu 0\n"
  },
  {
    "path": "examples/pytorch/GATNE-T/scripts/run_example_sparse_multi_gpus.sh",
    "content": "python src/main_sparse_multi_gpus.py --input data/example\n"
  },
  {
    "path": "examples/pytorch/GATNE-T/src/main.py",
    "content": "import math\nimport os\nimport sys\nimport time\nfrom collections import defaultdict\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom numpy import random\nfrom torch.nn.parameter import Parameter\nfrom tqdm.auto import tqdm\nfrom utils import *\n\nimport dgl\nimport dgl.function as fn\n\n\ndef get_graph(network_data, vocab):\n    \"\"\"Build graph, treat all nodes as the same type\n\n    Parameters\n    ----------\n    network_data: a dict\n        keys describing the edge types, values representing edges\n    vocab: a dict\n        mapping node IDs to node indices\n    Output\n    ------\n    DGLGraph\n        a heterogenous graph, with one node type and different edge types\n    \"\"\"\n    graphs = []\n\n    node_type = \"_N\"  # '_N' can be replaced by an arbitrary name\n    data_dict = dict()\n    num_nodes_dict = {node_type: len(vocab)}\n\n    for edge_type in network_data:\n        tmp_data = network_data[edge_type]\n        src = []\n        dst = []\n        for edge in tmp_data:\n            src.extend([vocab[edge[0]], vocab[edge[1]]])\n            dst.extend([vocab[edge[1]], vocab[edge[0]]])\n        data_dict[(node_type, edge_type, node_type)] = (src, dst)\n    graph = dgl.heterograph(data_dict, num_nodes_dict)\n\n    return graph\n\n\nclass NeighborSampler(object):\n    def __init__(self, g, num_fanouts):\n        self.g = g\n        self.num_fanouts = num_fanouts\n\n    def sample(self, pairs):\n        heads, tails, types = zip(*pairs)\n        seeds, head_invmap = torch.unique(\n            torch.LongTensor(heads), return_inverse=True\n        )\n        blocks = []\n        for fanout in reversed(self.num_fanouts):\n            sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)\n            sampled_block = dgl.to_block(sampled_graph, seeds)\n            seeds = sampled_block.srcdata[dgl.NID]\n            blocks.insert(0, sampled_block)\n        return (\n            
blocks,\n            torch.LongTensor(head_invmap),\n            torch.LongTensor(tails),\n            torch.LongTensor(types),\n        )\n\n\nclass DGLGATNE(nn.Module):\n    def __init__(\n        self,\n        num_nodes,\n        embedding_size,\n        embedding_u_size,\n        edge_types,\n        edge_type_count,\n        dim_a,\n    ):\n        super(DGLGATNE, self).__init__()\n        self.num_nodes = num_nodes\n        self.embedding_size = embedding_size\n        self.embedding_u_size = embedding_u_size\n        self.edge_types = edge_types\n        self.edge_type_count = edge_type_count\n        self.dim_a = dim_a\n\n        self.node_embeddings = Parameter(\n            torch.FloatTensor(num_nodes, embedding_size)\n        )\n        self.node_type_embeddings = Parameter(\n            torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size)\n        )\n        self.trans_weights = Parameter(\n            torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size)\n        )\n        self.trans_weights_s1 = Parameter(\n            torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)\n        )\n        self.trans_weights_s2 = Parameter(\n            torch.FloatTensor(edge_type_count, dim_a, 1)\n        )\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.node_embeddings.data.uniform_(-1.0, 1.0)\n        self.node_type_embeddings.data.uniform_(-1.0, 1.0)\n        self.trans_weights.data.normal_(\n            std=1.0 / math.sqrt(self.embedding_size)\n        )\n        self.trans_weights_s1.data.normal_(\n            std=1.0 / math.sqrt(self.embedding_size)\n        )\n        self.trans_weights_s2.data.normal_(\n            std=1.0 / math.sqrt(self.embedding_size)\n        )\n\n    # embs: [batch_size, embedding_size]\n    def forward(self, block):\n        input_nodes = block.srcdata[dgl.NID]\n        output_nodes = block.dstdata[dgl.NID]\n        batch_size = block.number_of_dst_nodes()\n        
node_embed = self.node_embeddings\n        node_type_embed = []\n\n        with block.local_scope():\n            for i in range(self.edge_type_count):\n                edge_type = self.edge_types[i]\n                block.srcdata[edge_type] = self.node_type_embeddings[\n                    input_nodes, i\n                ]\n                block.dstdata[edge_type] = self.node_type_embeddings[\n                    output_nodes, i\n                ]\n                block.update_all(\n                    fn.copy_u(edge_type, \"m\"),\n                    fn.sum(\"m\", edge_type),\n                    etype=edge_type,\n                )\n                node_type_embed.append(block.dstdata[edge_type])\n\n            node_type_embed = torch.stack(node_type_embed, 1)\n            tmp_node_type_embed = node_type_embed.unsqueeze(2).view(\n                -1, 1, self.embedding_u_size\n            )\n            trans_w = (\n                self.trans_weights.unsqueeze(0)\n                .repeat(batch_size, 1, 1, 1)\n                .view(-1, self.embedding_u_size, self.embedding_size)\n            )\n            trans_w_s1 = (\n                self.trans_weights_s1.unsqueeze(0)\n                .repeat(batch_size, 1, 1, 1)\n                .view(-1, self.embedding_u_size, self.dim_a)\n            )\n            trans_w_s2 = (\n                self.trans_weights_s2.unsqueeze(0)\n                .repeat(batch_size, 1, 1, 1)\n                .view(-1, self.dim_a, 1)\n            )\n\n            attention = (\n                F.softmax(\n                    torch.matmul(\n                        torch.tanh(\n                            torch.matmul(tmp_node_type_embed, trans_w_s1)\n                        ),\n                        trans_w_s2,\n                    )\n                    .squeeze(2)\n                    .view(-1, self.edge_type_count),\n                    dim=1,\n                )\n                .unsqueeze(1)\n                .repeat(1, 
self.edge_type_count, 1)\n            )\n\n            node_type_embed = torch.matmul(attention, node_type_embed).view(\n                -1, 1, self.embedding_u_size\n            )\n            node_embed = node_embed[output_nodes].unsqueeze(1).repeat(\n                1, self.edge_type_count, 1\n            ) + torch.matmul(node_type_embed, trans_w).view(\n                -1, self.edge_type_count, self.embedding_size\n            )\n            last_node_embed = F.normalize(node_embed, dim=2)\n\n            return (\n                last_node_embed  # [batch_size, edge_type_count, embedding_size]\n            )\n\n\nclass NSLoss(nn.Module):\n    def __init__(self, num_nodes, num_sampled, embedding_size):\n        super(NSLoss, self).__init__()\n        self.num_nodes = num_nodes\n        self.num_sampled = num_sampled\n        self.embedding_size = embedding_size\n        self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))\n        # [ (log(i+2) - log(i+1)) / log(num_nodes + 1)]\n        self.sample_weights = F.normalize(\n            torch.Tensor(\n                [\n                    (math.log(k + 2) - math.log(k + 1))\n                    / math.log(num_nodes + 1)\n                    for k in range(num_nodes)\n                ]\n            ),\n            dim=0,\n        )\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))\n\n    def forward(self, input, embs, label):\n        n = input.shape[0]\n        log_target = torch.log(\n            torch.sigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1))\n        )\n        negs = torch.multinomial(\n            self.sample_weights, self.num_sampled * n, replacement=True\n        ).view(n, self.num_sampled)\n        noise = torch.neg(self.weights[negs])\n        sum_log_sampled = torch.sum(\n            torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1\n        ).squeeze()\n\n   
     loss = log_target + sum_log_sampled\n        return -loss.sum() / n\n\n\ndef train_model(network_data):\n    index2word, vocab, type_nodes = generate_vocab(network_data)\n\n    edge_types = list(network_data.keys())\n    num_nodes = len(index2word)\n    edge_type_count = len(edge_types)\n    epochs = args.epoch\n    batch_size = args.batch_size\n    embedding_size = args.dimensions\n    embedding_u_size = args.edge_dim\n    u_num = edge_type_count\n    num_sampled = args.negative_samples\n    dim_a = args.att_dim\n    att_head = 1\n    neighbor_samples = args.neighbor_samples\n    num_workers = args.workers\n\n    device = torch.device(\n        \"cuda\" if args.gpu is not None and torch.cuda.is_available() else \"cpu\"\n    )\n\n    g = get_graph(network_data, vocab)\n    all_walks = []\n    for i in range(edge_type_count):\n        nodes = torch.LongTensor(type_nodes[i] * args.num_walks)\n        traces, types = dgl.sampling.random_walk(\n            g, nodes, metapath=[edge_types[i]] * (neighbor_samples - 1)\n        )\n        all_walks.append(traces)\n\n    train_pairs = generate_pairs(all_walks, args.window_size, num_workers)\n    neighbor_sampler = NeighborSampler(g, [neighbor_samples])\n    train_dataloader = torch.utils.data.DataLoader(\n        train_pairs,\n        batch_size=batch_size,\n        collate_fn=neighbor_sampler.sample,\n        shuffle=True,\n        num_workers=num_workers,\n        pin_memory=True,\n    )\n    model = DGLGATNE(\n        num_nodes,\n        embedding_size,\n        embedding_u_size,\n        edge_types,\n        edge_type_count,\n        dim_a,\n    )\n    nsloss = NSLoss(num_nodes, num_sampled, embedding_size)\n    model.to(device)\n    nsloss.to(device)\n\n    optimizer = torch.optim.Adam(\n        [{\"params\": model.parameters()}, {\"params\": nsloss.parameters()}],\n        lr=1e-3,\n    )\n\n    best_score = 0\n    patience = 0\n    for epoch in range(epochs):\n        model.train()\n        
random.shuffle(train_pairs)\n\n        data_iter = tqdm(\n            train_dataloader,\n            desc=\"epoch %d\" % (epoch),\n            total=(len(train_pairs) + (batch_size - 1)) // batch_size,\n        )\n        avg_loss = 0.0\n\n        for i, (block, head_invmap, tails, block_types) in enumerate(data_iter):\n            optimizer.zero_grad()\n            # embs: [batch_size, edge_type_count, embedding_size]\n            block_types = block_types.to(device)\n            embs = model(block[0].to(device))[head_invmap]\n            embs = embs.gather(\n                1,\n                block_types.view(-1, 1, 1).expand(\n                    embs.shape[0], 1, embs.shape[2]\n                ),\n            )[:, 0]\n            loss = nsloss(\n                block[0].dstdata[dgl.NID][head_invmap].to(device),\n                embs,\n                tails.to(device),\n            )\n            loss.backward()\n            optimizer.step()\n            avg_loss += loss.item()\n\n            post_fix = {\n                \"epoch\": epoch,\n                \"iter\": i,\n                \"avg_loss\": avg_loss / (i + 1),\n                \"loss\": loss.item(),\n            }\n            data_iter.set_postfix(post_fix)\n\n        model.eval()\n        # {'1': {}, '2': {}}\n        final_model = dict(\n            zip(edge_types, [dict() for _ in range(edge_type_count)])\n        )\n        for i in range(num_nodes):\n            train_inputs = (\n                torch.tensor([i for _ in range(edge_type_count)])\n                .unsqueeze(1)\n                .to(device)\n            )  # [i, i]\n            train_types = (\n                torch.tensor(list(range(edge_type_count)))\n                .unsqueeze(1)\n                .to(device)\n            )  # [0, 1]\n            pairs = torch.cat(\n                (train_inputs, train_inputs, train_types), dim=1\n            )  # (2, 3)\n            (\n                train_blocks,\n                train_invmap,\n 
               fake_tails,\n                train_types,\n            ) = neighbor_sampler.sample(pairs)\n\n            node_emb = model(train_blocks[0].to(device))[train_invmap]\n            node_emb = node_emb.gather(\n                1,\n                train_types.to(device)\n                .view(-1, 1, 1)\n                .expand(node_emb.shape[0], 1, node_emb.shape[2]),\n            )[:, 0]\n\n            for j in range(edge_type_count):\n                final_model[edge_types[j]][index2word[i]] = (\n                    node_emb[j].cpu().detach().numpy()\n                )\n\n        valid_aucs, valid_f1s, valid_prs = [], [], []\n        test_aucs, test_f1s, test_prs = [], [], []\n        for i in range(edge_type_count):\n            if args.eval_type == \"all\" or edge_types[i] in args.eval_type.split(\n                \",\"\n            ):\n                tmp_auc, tmp_f1, tmp_pr = evaluate(\n                    final_model[edge_types[i]],\n                    valid_true_data_by_edge[edge_types[i]],\n                    valid_false_data_by_edge[edge_types[i]],\n                    num_workers,\n                )\n                valid_aucs.append(tmp_auc)\n                valid_f1s.append(tmp_f1)\n                valid_prs.append(tmp_pr)\n\n                tmp_auc, tmp_f1, tmp_pr = evaluate(\n                    final_model[edge_types[i]],\n                    testing_true_data_by_edge[edge_types[i]],\n                    testing_false_data_by_edge[edge_types[i]],\n                    num_workers,\n                )\n                test_aucs.append(tmp_auc)\n                test_f1s.append(tmp_f1)\n                test_prs.append(tmp_pr)\n        print(\"valid auc:\", np.mean(valid_aucs))\n        print(\"valid pr:\", np.mean(valid_prs))\n        print(\"valid f1:\", np.mean(valid_f1s))\n\n        average_auc = np.mean(test_aucs)\n        average_f1 = np.mean(test_f1s)\n        average_pr = np.mean(test_prs)\n\n        cur_score = np.mean(valid_aucs)\n    
    if cur_score > best_score:\n            best_score = cur_score\n            patience = 0\n        else:\n            patience += 1\n            if patience > args.patience:\n                print(\"Early Stopping\")\n                break\n    return average_auc, average_f1, average_pr\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    file_name = args.input\n    print(args)\n\n    training_data_by_type = load_training_data(file_name + \"/train.txt\")\n    valid_true_data_by_edge, valid_false_data_by_edge = load_testing_data(\n        file_name + \"/valid.txt\"\n    )\n    testing_true_data_by_edge, testing_false_data_by_edge = load_testing_data(\n        file_name + \"/test.txt\"\n    )\n    start = time.time()\n    average_auc, average_f1, average_pr = train_model(training_data_by_type)\n    end = time.time()\n\n    print(\"Overall ROC-AUC:\", average_auc)\n    print(\"Overall PR-AUC\", average_pr)\n    print(\"Overall F1:\", average_f1)\n    print(\"Training Time\", end - start)\n"
  },
  {
    "path": "examples/pytorch/GATNE-T/src/main_sparse.py",
    "content": "import math\nimport os\nimport sys\nimport time\nfrom collections import defaultdict\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\nfrom numpy import random\nfrom torch.nn.parameter import Parameter\nfrom utils import *\n\nimport dgl\nimport dgl.function as fn\n\n\ndef get_graph(network_data, vocab):\n    \"\"\"Build graph, treat all nodes as the same type\n\n    Parameters\n    ----------\n    network_data: a dict\n        keys describing the edge types, values representing edges\n    vocab: a dict\n        mapping node IDs to node indices\n    Output\n    ------\n    DGLGraph\n        a heterogenous graph, with one node type and different edge types\n    \"\"\"\n    graphs = []\n\n    node_type = \"_N\"  # '_N' can be replaced by an arbitrary name\n    data_dict = dict()\n    num_nodes_dict = {node_type: len(vocab)}\n\n    for edge_type in network_data:\n        tmp_data = network_data[edge_type]\n        src = []\n        dst = []\n        for edge in tmp_data:\n            src.extend([vocab[edge[0]], vocab[edge[1]]])\n            dst.extend([vocab[edge[1]], vocab[edge[0]]])\n        data_dict[(node_type, edge_type, node_type)] = (src, dst)\n    graph = dgl.heterograph(data_dict, num_nodes_dict)\n\n    return graph\n\n\nclass NeighborSampler(object):\n    def __init__(self, g, num_fanouts):\n        self.g = g\n        self.num_fanouts = num_fanouts\n\n    def sample(self, pairs):\n        pairs = np.stack(pairs)\n        heads, tails, types = pairs[:, 0], pairs[:, 1], pairs[:, 2]\n        seeds, head_invmap = torch.unique(\n            torch.LongTensor(heads), return_inverse=True\n        )\n        blocks = []\n        for fanout in reversed(self.num_fanouts):\n            sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)\n            sampled_block = dgl.to_block(sampled_graph, seeds)\n            seeds = sampled_block.srcdata[dgl.NID]\n            blocks.insert(0, 
sampled_block)\n        return (\n            blocks,\n            torch.LongTensor(head_invmap),\n            torch.LongTensor(tails),\n            torch.LongTensor(types),\n        )\n\n\nclass DGLGATNE(nn.Module):\n    def __init__(\n        self,\n        num_nodes,\n        embedding_size,\n        embedding_u_size,\n        edge_types,\n        edge_type_count,\n        dim_a,\n    ):\n        super(DGLGATNE, self).__init__()\n        self.num_nodes = num_nodes\n        self.embedding_size = embedding_size\n        self.embedding_u_size = embedding_u_size\n        self.edge_types = edge_types\n        self.edge_type_count = edge_type_count\n        self.dim_a = dim_a\n\n        self.node_embeddings = nn.Embedding(\n            num_nodes, embedding_size, sparse=True\n        )\n        self.node_type_embeddings = nn.Embedding(\n            num_nodes * edge_type_count, embedding_u_size, sparse=True\n        )\n        self.trans_weights = Parameter(\n            torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size)\n        )\n        self.trans_weights_s1 = Parameter(\n            torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)\n        )\n        self.trans_weights_s2 = Parameter(\n            torch.FloatTensor(edge_type_count, dim_a, 1)\n        )\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.node_embeddings.weight.data.uniform_(-1.0, 1.0)\n        self.node_type_embeddings.weight.data.uniform_(-1.0, 1.0)\n        self.trans_weights.data.normal_(\n            std=1.0 / math.sqrt(self.embedding_size)\n        )\n        self.trans_weights_s1.data.normal_(\n            std=1.0 / math.sqrt(self.embedding_size)\n        )\n        self.trans_weights_s2.data.normal_(\n            std=1.0 / math.sqrt(self.embedding_size)\n        )\n\n    # embs: [batch_size, embedding_size]\n    def forward(self, block):\n        input_nodes = block.srcdata[dgl.NID]\n        output_nodes = block.dstdata[dgl.NID]\n   
     batch_size = block.number_of_dst_nodes()\n        node_type_embed = []\n\n        with block.local_scope():\n            for i in range(self.edge_type_count):\n                edge_type = self.edge_types[i]\n                block.srcdata[edge_type] = self.node_type_embeddings(\n                    input_nodes * self.edge_type_count + i\n                )\n                block.dstdata[edge_type] = self.node_type_embeddings(\n                    output_nodes * self.edge_type_count + i\n                )\n                block.update_all(\n                    fn.copy_u(edge_type, \"m\"),\n                    fn.sum(\"m\", edge_type),\n                    etype=edge_type,\n                )\n                node_type_embed.append(block.dstdata[edge_type])\n\n            node_type_embed = torch.stack(node_type_embed, 1)\n            tmp_node_type_embed = node_type_embed.unsqueeze(2).view(\n                -1, 1, self.embedding_u_size\n            )\n            trans_w = (\n                self.trans_weights.unsqueeze(0)\n                .repeat(batch_size, 1, 1, 1)\n                .view(-1, self.embedding_u_size, self.embedding_size)\n            )\n            trans_w_s1 = (\n                self.trans_weights_s1.unsqueeze(0)\n                .repeat(batch_size, 1, 1, 1)\n                .view(-1, self.embedding_u_size, self.dim_a)\n            )\n            trans_w_s2 = (\n                self.trans_weights_s2.unsqueeze(0)\n                .repeat(batch_size, 1, 1, 1)\n                .view(-1, self.dim_a, 1)\n            )\n\n            attention = (\n                F.softmax(\n                    torch.matmul(\n                        torch.tanh(\n                            torch.matmul(tmp_node_type_embed, trans_w_s1)\n                        ),\n                        trans_w_s2,\n                    )\n                    .squeeze(2)\n                    .view(-1, self.edge_type_count),\n                    dim=1,\n                )\n                
.unsqueeze(1)\n                .repeat(1, self.edge_type_count, 1)\n            )\n\n            node_type_embed = torch.matmul(attention, node_type_embed).view(\n                -1, 1, self.embedding_u_size\n            )\n            node_embed = self.node_embeddings(output_nodes).unsqueeze(1).repeat(\n                1, self.edge_type_count, 1\n            ) + torch.matmul(node_type_embed, trans_w).view(\n                -1, self.edge_type_count, self.embedding_size\n            )\n            last_node_embed = F.normalize(node_embed, dim=2)\n\n            return (\n                last_node_embed  # [batch_size, edge_type_count, embedding_size]\n            )\n\n\nclass NSLoss(nn.Module):\n    def __init__(self, num_nodes, num_sampled, embedding_size):\n        super(NSLoss, self).__init__()\n        self.num_nodes = num_nodes\n        self.num_sampled = num_sampled\n        self.embedding_size = embedding_size\n\n        # [ (log(i+2) - log(i+1)) / log(num_nodes + 1)]\n        self.sample_weights = F.normalize(\n            torch.Tensor(\n                [\n                    (math.log(k + 2) - math.log(k + 1))\n                    / math.log(num_nodes + 1)\n                    for k in range(num_nodes)\n                ]\n            ),\n            dim=0,\n        )\n        self.weights = nn.Embedding(num_nodes, embedding_size, sparse=True)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.weights.weight.data.normal_(\n            std=1.0 / math.sqrt(self.embedding_size)\n        )\n\n    def forward(self, input, embs, label):\n        n = input.shape[0]\n        log_target = torch.log(\n            torch.sigmoid(torch.sum(torch.mul(embs, self.weights(label)), 1))\n        )\n        negs = (\n            torch.multinomial(\n                self.sample_weights, self.num_sampled * n, replacement=True\n            )\n            .view(n, self.num_sampled)\n            .to(input.device)\n        )\n        noise = 
torch.neg(self.weights(negs))\n        sum_log_sampled = torch.sum(\n            torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1\n        ).squeeze()\n\n        loss = log_target + sum_log_sampled\n        return -loss.sum() / n\n\n\ndef train_model(network_data):\n    index2word, vocab, type_nodes = generate_vocab(network_data)\n\n    edge_types = list(network_data.keys())\n    num_nodes = len(index2word)\n    edge_type_count = len(edge_types)\n    epochs = args.epoch\n    batch_size = args.batch_size\n    embedding_size = args.dimensions\n    embedding_u_size = args.edge_dim\n    u_num = edge_type_count\n    num_sampled = args.negative_samples\n    dim_a = args.att_dim\n    att_head = 1\n    neighbor_samples = args.neighbor_samples\n    num_workers = args.workers\n\n    device = torch.device(\n        \"cuda\" if args.gpu is not None and torch.cuda.is_available() else \"cpu\"\n    )\n\n    g = get_graph(network_data, vocab)\n    all_walks = []\n    for i in range(edge_type_count):\n        nodes = torch.LongTensor(type_nodes[i] * args.num_walks)\n        traces, types = dgl.sampling.random_walk(\n            g, nodes, metapath=[edge_types[i]] * (neighbor_samples - 1)\n        )\n        all_walks.append(traces)\n\n    train_pairs = generate_pairs(all_walks, args.window_size, num_workers)\n    neighbor_sampler = NeighborSampler(g, [neighbor_samples])\n    train_dataloader = torch.utils.data.DataLoader(\n        train_pairs,\n        batch_size=batch_size,\n        collate_fn=neighbor_sampler.sample,\n        shuffle=True,\n        num_workers=num_workers,\n        pin_memory=True,\n    )\n\n    model = DGLGATNE(\n        num_nodes,\n        embedding_size,\n        embedding_u_size,\n        edge_types,\n        edge_type_count,\n        dim_a,\n    )\n\n    nsloss = NSLoss(num_nodes, num_sampled, embedding_size)\n\n    model.to(device)\n    nsloss.to(device)\n\n    embeddings_params = list(\n        map(id, model.node_embeddings.parameters())\n   
 ) + list(map(id, model.node_type_embeddings.parameters()))\n    weights_params = list(map(id, nsloss.weights.parameters()))\n\n    optimizer = torch.optim.Adam(\n        [\n            {\n                \"params\": filter(\n                    lambda p: id(p) not in embeddings_params,\n                    model.parameters(),\n                )\n            },\n            {\n                \"params\": filter(\n                    lambda p: id(p) not in weights_params,\n                    nsloss.parameters(),\n                )\n            },\n        ],\n        lr=1e-3,\n    )\n\n    sparse_optimizer = torch.optim.SparseAdam(\n        [\n            {\"params\": model.node_embeddings.parameters()},\n            {\"params\": model.node_type_embeddings.parameters()},\n            {\"params\": nsloss.weights.parameters()},\n        ],\n        lr=1e-3,\n    )\n\n    best_score = 0\n    patience = 0\n    for epoch in range(epochs):\n        model.train()\n\n        random.shuffle(train_pairs)\n\n        data_iter = tqdm.tqdm(\n            train_dataloader,\n            desc=\"epoch %d\" % (epoch),\n            total=(len(train_pairs) + (batch_size - 1)) // batch_size,\n        )\n        avg_loss = 0.0\n\n        for i, (block, head_invmap, tails, block_types) in enumerate(data_iter):\n            optimizer.zero_grad()\n            sparse_optimizer.zero_grad()\n            # embs: [batch_size, edge_type_count, embedding_size]\n            block_types = block_types.to(device)\n            embs = model(block[0].to(device))[head_invmap]\n            embs = embs.gather(\n                1,\n                block_types.view(-1, 1, 1).expand(\n                    embs.shape[0], 1, embs.shape[2]\n                ),\n            )[:, 0]\n            loss = nsloss(\n                block[0].dstdata[dgl.NID][head_invmap].to(device),\n                embs,\n                tails.to(device),\n            )\n            loss.backward()\n            optimizer.step()\n          
  sparse_optimizer.step()\n            avg_loss += loss.item()\n\n            post_fix = {\n                \"epoch\": epoch,\n                \"iter\": i,\n                \"avg_loss\": avg_loss / (i + 1),\n                \"loss\": loss.item(),\n            }\n            data_iter.set_postfix(post_fix)\n\n        model.eval()\n        # {'1': {}, '2': {}}\n        final_model = dict(\n            zip(edge_types, [dict() for _ in range(edge_type_count)])\n        )\n        for i in range(num_nodes):\n            train_inputs = (\n                torch.tensor([i for _ in range(edge_type_count)])\n                .unsqueeze(1)\n                .to(device)\n            )  # [i, i]\n            train_types = (\n                torch.tensor(list(range(edge_type_count)))\n                .unsqueeze(1)\n                .to(device)\n            )  # [0, 1]\n            pairs = torch.cat(\n                (train_inputs, train_inputs, train_types), dim=1\n            )  # (2, 3)\n            (\n                train_blocks,\n                train_invmap,\n                fake_tails,\n                train_types,\n            ) = neighbor_sampler.sample(pairs.cpu())\n\n            node_emb = model(train_blocks[0].to(device))[train_invmap]\n            node_emb = node_emb.gather(\n                1,\n                train_types.to(device)\n                .view(-1, 1, 1)\n                .expand(node_emb.shape[0], 1, node_emb.shape[2]),\n            )[:, 0]\n\n            for j in range(edge_type_count):\n                final_model[edge_types[j]][index2word[i]] = (\n                    node_emb[j].cpu().detach().numpy()\n                )\n\n        valid_aucs, valid_f1s, valid_prs = [], [], []\n        test_aucs, test_f1s, test_prs = [], [], []\n        for i in range(edge_type_count):\n            if args.eval_type == \"all\" or edge_types[i] in args.eval_type.split(\n                \",\"\n            ):\n                tmp_auc, tmp_f1, tmp_pr = evaluate(\n             
       final_model[edge_types[i]],\n                    valid_true_data_by_edge[edge_types[i]],\n                    valid_false_data_by_edge[edge_types[i]],\n                    num_workers,\n                )\n                valid_aucs.append(tmp_auc)\n                valid_f1s.append(tmp_f1)\n                valid_prs.append(tmp_pr)\n\n                tmp_auc, tmp_f1, tmp_pr = evaluate(\n                    final_model[edge_types[i]],\n                    testing_true_data_by_edge[edge_types[i]],\n                    testing_false_data_by_edge[edge_types[i]],\n                    num_workers,\n                )\n                test_aucs.append(tmp_auc)\n                test_f1s.append(tmp_f1)\n                test_prs.append(tmp_pr)\n        print(\"valid auc:\", np.mean(valid_aucs))\n        print(\"valid pr:\", np.mean(valid_prs))\n        print(\"valid f1:\", np.mean(valid_f1s))\n\n        average_auc = np.mean(test_aucs)\n        average_f1 = np.mean(test_f1s)\n        average_pr = np.mean(test_prs)\n\n        cur_score = np.mean(valid_aucs)\n        if cur_score > best_score:\n            best_score = cur_score\n            patience = 0\n        else:\n            patience += 1\n            if patience > args.patience:\n                print(\"Early Stopping\")\n                break\n    return average_auc, average_f1, average_pr\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    file_name = args.input\n    print(args)\n\n    training_data_by_type = load_training_data(file_name + \"/train.txt\")\n    valid_true_data_by_edge, valid_false_data_by_edge = load_testing_data(\n        file_name + \"/valid.txt\"\n    )\n    testing_true_data_by_edge, testing_false_data_by_edge = load_testing_data(\n        file_name + \"/test.txt\"\n    )\n\n    start = time.time()\n    average_auc, average_f1, average_pr = train_model(training_data_by_type)\n    end = time.time()\n\n    print(\"Overall ROC-AUC:\", average_auc)\n    print(\"Overall PR-AUC\", 
average_pr)\n    print(\"Overall F1:\", average_f1)\n    print(\"Training Time\", end - start)\n"
  },
  {
    "path": "examples/pytorch/GATNE-T/src/main_sparse_multi_gpus.py",
    "content": "import datetime\nimport math\nimport os\nimport sys\nimport time\nfrom collections import defaultdict\n\nimport numpy as np\nimport torch\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom numpy import random\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.nn.parameter import Parameter\nfrom tqdm.auto import tqdm\nfrom utils import *\n\nimport dgl\nimport dgl.function as fn\n\n\ndef setup_seed(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\n\ndef get_graph(network_data, vocab):\n    \"\"\"Build graph, treat all nodes as the same type\n\n    Parameters\n    ----------\n    network_data: a dict\n        keys describing the edge types, values representing edges\n    vocab: a dict\n        mapping node IDs to node indices\n    Output\n    ------\n    DGLGraph\n        a heterogenous graph, with one node type and different edge types\n    \"\"\"\n    graphs = []\n\n    node_type = \"_N\"  # '_N' can be replaced by an arbitrary name\n    data_dict = dict()\n    num_nodes_dict = {node_type: len(vocab)}\n\n    for edge_type in network_data:\n        tmp_data = network_data[edge_type]\n        src = []\n        dst = []\n        for edge in tmp_data:\n            src.extend([vocab[edge[0]], vocab[edge[1]]])\n            dst.extend([vocab[edge[1]], vocab[edge[0]]])\n        data_dict[(node_type, edge_type, node_type)] = (src, dst)\n    graph = dgl.heterograph(data_dict, num_nodes_dict)\n\n    return graph\n\n\nclass NeighborSampler(object):\n    def __init__(self, g, num_fanouts):\n        self.g = g\n        self.num_fanouts = num_fanouts\n\n    def sample(self, pairs):\n        pairs = np.stack(pairs)\n        heads, tails, types = pairs[:, 0], pairs[:, 1], pairs[:, 2]\n        seeds, head_invmap = torch.unique(\n            torch.LongTensor(heads), 
return_inverse=True\n        )\n        blocks = []\n        for fanout in reversed(self.num_fanouts):\n            sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)\n            sampled_block = dgl.to_block(sampled_graph, seeds)\n            seeds = sampled_block.srcdata[dgl.NID]\n            blocks.insert(0, sampled_block)\n        return (\n            blocks,\n            torch.LongTensor(head_invmap),\n            torch.LongTensor(tails),\n            torch.LongTensor(types),\n        )\n\n\nclass DGLGATNE(nn.Module):\n    def __init__(\n        self,\n        num_nodes,\n        embedding_size,\n        embedding_u_size,\n        edge_types,\n        edge_type_count,\n        dim_a,\n    ):\n        super(DGLGATNE, self).__init__()\n        self.num_nodes = num_nodes\n        self.embedding_size = embedding_size\n        self.embedding_u_size = embedding_u_size\n        self.edge_types = edge_types\n        self.edge_type_count = edge_type_count\n        self.dim_a = dim_a\n\n        self.node_embeddings = nn.Embedding(\n            num_nodes, embedding_size, sparse=True\n        )\n        self.node_type_embeddings = nn.Embedding(\n            num_nodes * edge_type_count, embedding_u_size, sparse=True\n        )\n        self.trans_weights = Parameter(\n            torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size)\n        )\n        self.trans_weights_s1 = Parameter(\n            torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)\n        )\n        self.trans_weights_s2 = Parameter(\n            torch.FloatTensor(edge_type_count, dim_a, 1)\n        )\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.node_embeddings.weight.data.uniform_(-1.0, 1.0)\n        self.node_type_embeddings.weight.data.uniform_(-1.0, 1.0)\n        self.trans_weights.data.normal_(\n            std=1.0 / math.sqrt(self.embedding_size)\n        )\n        self.trans_weights_s1.data.normal_(\n            
std=1.0 / math.sqrt(self.embedding_size)\n        )\n        self.trans_weights_s2.data.normal_(\n            std=1.0 / math.sqrt(self.embedding_size)\n        )\n\n    # embs: [batch_size, embedding_size]\n    def forward(self, block):\n        input_nodes = block.srcdata[dgl.NID]\n        output_nodes = block.dstdata[dgl.NID]\n        batch_size = block.number_of_dst_nodes()\n        node_type_embed = []\n\n        with block.local_scope():\n            for i in range(self.edge_type_count):\n                edge_type = self.edge_types[i]\n                block.srcdata[edge_type] = self.node_type_embeddings(\n                    input_nodes * self.edge_type_count + i\n                )\n                block.dstdata[edge_type] = self.node_type_embeddings(\n                    output_nodes * self.edge_type_count + i\n                )\n                block.update_all(\n                    fn.copy_u(edge_type, \"m\"),\n                    fn.sum(\"m\", edge_type),\n                    etype=edge_type,\n                )\n                node_type_embed.append(block.dstdata[edge_type])\n\n            node_type_embed = torch.stack(node_type_embed, 1)\n            tmp_node_type_embed = node_type_embed.unsqueeze(2).view(\n                -1, 1, self.embedding_u_size\n            )\n            trans_w = (\n                self.trans_weights.unsqueeze(0)\n                .repeat(batch_size, 1, 1, 1)\n                .view(-1, self.embedding_u_size, self.embedding_size)\n            )\n            trans_w_s1 = (\n                self.trans_weights_s1.unsqueeze(0)\n                .repeat(batch_size, 1, 1, 1)\n                .view(-1, self.embedding_u_size, self.dim_a)\n            )\n            trans_w_s2 = (\n                self.trans_weights_s2.unsqueeze(0)\n                .repeat(batch_size, 1, 1, 1)\n                .view(-1, self.dim_a, 1)\n            )\n\n            attention = (\n                F.softmax(\n                    torch.matmul(\n                 
       torch.tanh(\n                            torch.matmul(tmp_node_type_embed, trans_w_s1)\n                        ),\n                        trans_w_s2,\n                    )\n                    .squeeze(2)\n                    .view(-1, self.edge_type_count),\n                    dim=1,\n                )\n                .unsqueeze(1)\n                .repeat(1, self.edge_type_count, 1)\n            )\n\n            node_type_embed = torch.matmul(attention, node_type_embed).view(\n                -1, 1, self.embedding_u_size\n            )\n            node_embed = self.node_embeddings(output_nodes).unsqueeze(1).repeat(\n                1, self.edge_type_count, 1\n            ) + torch.matmul(node_type_embed, trans_w).view(\n                -1, self.edge_type_count, self.embedding_size\n            )\n            last_node_embed = F.normalize(node_embed, dim=2)\n\n            return (\n                last_node_embed  # [batch_size, edge_type_count, embedding_size]\n            )\n\n\nclass NSLoss(nn.Module):\n    def __init__(self, num_nodes, num_sampled, embedding_size):\n        super(NSLoss, self).__init__()\n        self.num_nodes = num_nodes\n        self.num_sampled = num_sampled\n        self.embedding_size = embedding_size\n\n        # [ (log(i+2) - log(i+1)) / log(num_nodes + 1)]\n        self.sample_weights = F.normalize(\n            torch.Tensor(\n                [\n                    (math.log(k + 2) - math.log(k + 1))\n                    / math.log(num_nodes + 1)\n                    for k in range(num_nodes)\n                ]\n            ),\n            dim=0,\n        )\n        self.weights = nn.Embedding(num_nodes, embedding_size, sparse=True)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.weights.weight.data.normal_(\n            std=1.0 / math.sqrt(self.embedding_size)\n        )\n\n    def forward(self, input, embs, label):\n        n = input.shape[0]\n        log_target = torch.log(\n           
 torch.sigmoid(torch.sum(torch.mul(embs, self.weights(label)), 1))\n        )\n        negs = (\n            torch.multinomial(\n                self.sample_weights, self.num_sampled * n, replacement=True\n            )\n            .view(n, self.num_sampled)\n            .to(input.device)\n        )\n        noise = torch.neg(self.weights(negs))\n        sum_log_sampled = torch.sum(\n            torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1\n        ).squeeze()\n\n        loss = log_target + sum_log_sampled\n        return -loss.sum() / n\n\n\ndef run(proc_id, n_gpus, args, devices, data):\n    dev_id = devices[proc_id]\n    if n_gpus > 1:\n        dist_init_method = \"tcp://{master_ip}:{master_port}\".format(\n            master_ip=\"127.0.0.1\", master_port=\"12345\"\n        )\n        world_size = n_gpus\n        torch.distributed.init_process_group(\n            backend=\"gloo\",\n            init_method=dist_init_method,\n            world_size=world_size,\n            rank=proc_id,\n            timeout=datetime.timedelta(seconds=100),\n        )\n    torch.cuda.set_device(dev_id)\n\n    g, train_pairs, index2word, edge_types, num_nodes, edge_type_count = data\n\n    epochs = args.epoch\n    batch_size = args.batch_size\n    embedding_size = args.dimensions\n    embedding_u_size = args.edge_dim\n    u_num = edge_type_count\n    num_sampled = args.negative_samples\n    dim_a = args.att_dim\n    att_head = 1\n    neighbor_samples = args.neighbor_samples\n    num_workers = args.workers\n\n    neighbor_sampler = NeighborSampler(g, [neighbor_samples])\n    if n_gpus > 1:\n        train_sampler = torch.utils.data.distributed.DistributedSampler(\n            train_pairs,\n            num_replicas=world_size,\n            rank=proc_id,\n            shuffle=True,\n            drop_last=False,\n        )\n        train_dataloader = torch.utils.data.DataLoader(\n            train_pairs,\n            batch_size=batch_size,\n            
collate_fn=neighbor_sampler.sample,\n            num_workers=num_workers,\n            sampler=train_sampler,\n            pin_memory=True,\n        )\n    else:\n        train_dataloader = torch.utils.data.DataLoader(\n            train_pairs,\n            batch_size=batch_size,\n            collate_fn=neighbor_sampler.sample,\n            num_workers=num_workers,\n            shuffle=True,\n            drop_last=False,\n            pin_memory=True,\n        )\n\n    model = DGLGATNE(\n        num_nodes,\n        embedding_size,\n        embedding_u_size,\n        edge_types,\n        edge_type_count,\n        dim_a,\n    )\n\n    nsloss = NSLoss(num_nodes, num_sampled, embedding_size)\n\n    model.to(dev_id)\n    if n_gpus > 1:\n        model = DistributedDataParallel(\n            model, device_ids=[dev_id], output_device=dev_id\n        )\n\n    nsloss.to(dev_id)\n\n    if n_gpus > 1:\n        mmodel = model.module\n    else:\n        mmodel = model\n\n    embeddings_params = list(\n        map(id, mmodel.node_embeddings.parameters())\n    ) + list(map(id, mmodel.node_type_embeddings.parameters()))\n    weights_params = list(map(id, nsloss.weights.parameters()))\n\n    optimizer = torch.optim.Adam(\n        [\n            {\n                \"params\": filter(\n                    lambda p: id(p) not in embeddings_params,\n                    model.parameters(),\n                )\n            },\n            {\n                \"params\": filter(\n                    lambda p: id(p) not in weights_params,\n                    nsloss.parameters(),\n                )\n            },\n        ],\n        lr=2e-3,\n    )\n\n    sparse_optimizer = torch.optim.SparseAdam(\n        [\n            {\"params\": mmodel.node_embeddings.parameters()},\n            {\"params\": mmodel.node_type_embeddings.parameters()},\n            {\"params\": nsloss.weights.parameters()},\n        ],\n        lr=2e-3,\n    )\n\n    if n_gpus > 1:\n        torch.distributed.barrier()\n\n 
   if proc_id == 0:\n        start = time.time()\n\n    for epoch in range(epochs):\n        if n_gpus > 1:\n            train_sampler.set_epoch(epoch)\n        model.train()\n\n        data_iter = train_dataloader\n        if proc_id == 0:\n            data_iter = tqdm(\n                train_dataloader,\n                desc=\"epoch %d\" % (epoch),\n                total=(len(train_pairs) + (batch_size - 1)) // batch_size,\n            )\n            avg_loss = 0.0\n\n        for i, (block, head_invmap, tails, block_types) in enumerate(data_iter):\n            optimizer.zero_grad()\n            sparse_optimizer.zero_grad()\n            # embs: [batch_size, edge_type_count, embedding_size]\n            block_types = block_types.to(dev_id)\n            embs = model(block[0].to(dev_id))[head_invmap]\n            embs = embs.gather(\n                1,\n                block_types.view(-1, 1, 1).expand(\n                    embs.shape[0], 1, embs.shape[2]\n                ),\n            )[:, 0]\n            loss = nsloss(\n                block[0].dstdata[dgl.NID][head_invmap].to(dev_id),\n                embs,\n                tails.to(dev_id),\n            )\n            loss.backward()\n            optimizer.step()\n            sparse_optimizer.step()\n\n            if proc_id == 0:\n                avg_loss += loss.item()\n\n                post_fix = {\n                    \"avg_loss\": avg_loss / (i + 1),\n                    \"loss\": loss.item(),\n                }\n                data_iter.set_postfix(post_fix)\n\n        if n_gpus > 1:\n            torch.distributed.barrier()\n\n        if proc_id == 0:\n            model.eval()\n            # {'1': {}, '2': {}}\n            final_model = dict(\n                zip(edge_types, [dict() for _ in range(edge_type_count)])\n            )\n            for i in range(num_nodes):\n                train_inputs = (\n                    torch.tensor([i for _ in range(edge_type_count)])\n                    
.unsqueeze(1)\n                    .to(dev_id)\n                )  # [i, i]\n                train_types = (\n                    torch.tensor(list(range(edge_type_count)))\n                    .unsqueeze(1)\n                    .to(dev_id)\n                )  # [0, 1]\n                pairs = torch.cat(\n                    (train_inputs, train_inputs, train_types), dim=1\n                )  # (2, 3)\n                (\n                    train_blocks,\n                    train_invmap,\n                    fake_tails,\n                    train_types,\n                ) = neighbor_sampler.sample(pairs.cpu())\n\n                node_emb = model(train_blocks[0].to(dev_id))[train_invmap]\n                node_emb = node_emb.gather(\n                    1,\n                    train_types.to(dev_id)\n                    .view(-1, 1, 1)\n                    .expand(node_emb.shape[0], 1, node_emb.shape[2]),\n                )[:, 0]\n\n                for j in range(edge_type_count):\n                    final_model[edge_types[j]][index2word[i]] = (\n                        node_emb[j].cpu().detach().numpy()\n                    )\n\n            valid_aucs, valid_f1s, valid_prs = [], [], []\n            test_aucs, test_f1s, test_prs = [], [], []\n            for i in range(edge_type_count):\n                if args.eval_type == \"all\" or edge_types[\n                    i\n                ] in args.eval_type.split(\",\"):\n                    tmp_auc, tmp_f1, tmp_pr = evaluate(\n                        final_model[edge_types[i]],\n                        valid_true_data_by_edge[edge_types[i]],\n                        valid_false_data_by_edge[edge_types[i]],\n                        num_workers,\n                    )\n                    valid_aucs.append(tmp_auc)\n                    valid_f1s.append(tmp_f1)\n                    valid_prs.append(tmp_pr)\n\n                    tmp_auc, tmp_f1, tmp_pr = evaluate(\n                        final_model[edge_types[i]],\n  
                      testing_true_data_by_edge[edge_types[i]],\n                        testing_false_data_by_edge[edge_types[i]],\n                        num_workers,\n                    )\n                    test_aucs.append(tmp_auc)\n                    test_f1s.append(tmp_f1)\n                    test_prs.append(tmp_pr)\n            print(\"valid auc:\", np.mean(valid_aucs))\n            print(\"valid pr:\", np.mean(valid_prs))\n            print(\"valid f1:\", np.mean(valid_f1s))\n\n    if proc_id == 0:\n        end = time.time()\n        average_auc = np.mean(test_aucs)\n        average_f1 = np.mean(test_f1s)\n        average_pr = np.mean(test_prs)\n        print(\"Overall ROC-AUC:\", average_auc)\n        print(\"Overall PR-AUC\", average_pr)\n        print(\"Overall F1:\", average_f1)\n        print(\"Training Time\", end - start)\n\n\ndef train_model(network_data):\n    index2word, vocab, type_nodes = generate_vocab(network_data)\n\n    edge_types = list(network_data.keys())\n    num_nodes = len(index2word)\n    edge_type_count = len(edge_types)\n\n    devices = list(map(int, args.gpu.split(\",\")))\n    n_gpus = len(devices)\n    neighbor_samples = args.neighbor_samples\n    num_workers = args.workers\n\n    g = get_graph(network_data, vocab)\n    all_walks = []\n    for i in range(edge_type_count):\n        nodes = torch.LongTensor(type_nodes[i] * args.num_walks)\n        traces, types = dgl.sampling.random_walk(\n            g, nodes, metapath=[edge_types[i]] * (neighbor_samples - 1)\n        )\n        all_walks.append(traces)\n\n    train_pairs = generate_pairs(all_walks, args.window_size, num_workers)\n    data = g, train_pairs, index2word, edge_types, num_nodes, edge_type_count\n\n    if n_gpus == 1:\n        run(0, n_gpus, args, devices, data)\n    else:\n        mp.spawn(run, args=(n_gpus, args, devices, data), nprocs=n_gpus)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    file_name = args.input\n    print(args)\n    
setup_seed(1234)\n\n    training_data_by_type = load_training_data(file_name + \"/train.txt\")\n    valid_true_data_by_edge, valid_false_data_by_edge = load_testing_data(\n        file_name + \"/valid.txt\"\n    )\n    testing_true_data_by_edge, testing_false_data_by_edge = load_testing_data(\n        file_name + \"/test.txt\"\n    )\n\n    train_model(training_data_by_type)\n"
  },
  {
    "path": "examples/pytorch/GATNE-T/src/utils.py",
    "content": "import argparse\nimport multiprocessing\nimport time\nfrom collections import defaultdict\nfrom functools import partial, reduce, wraps\n\nimport networkx as nx\nimport numpy as np\nimport torch\nfrom gensim.models.keyedvectors import Vocab\nfrom six import iteritems\nfrom sklearn.metrics import auc, f1_score, precision_recall_curve, roc_auc_score\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\n        \"--input\", type=str, default=\"data/amazon\", help=\"Input dataset path\"\n    )\n\n    parser.add_argument(\n        \"--features\", type=str, default=None, help=\"Input node features\"\n    )\n\n    parser.add_argument(\n        \"--epoch\",\n        type=int,\n        default=100,\n        help=\"Number of epoch. Default is 100.\",\n    )\n\n    parser.add_argument(\n        \"--batch-size\",\n        type=int,\n        default=64,\n        help=\"Number of batch_size. Default is 64.\",\n    )\n\n    parser.add_argument(\n        \"--eval-type\",\n        type=str,\n        default=\"all\",\n        help=\"The edge type(s) for evaluation.\",\n    )\n\n    parser.add_argument(\n        \"--schema\",\n        type=str,\n        default=None,\n        help=\"The metapath schema (e.g., U-I-U,I-U-I).\",\n    )\n\n    parser.add_argument(\n        \"--dimensions\",\n        type=int,\n        default=200,\n        help=\"Number of dimensions. Default is 200.\",\n    )\n\n    parser.add_argument(\n        \"--edge-dim\",\n        type=int,\n        default=10,\n        help=\"Number of edge embedding dimensions. Default is 10.\",\n    )\n\n    parser.add_argument(\n        \"--att-dim\",\n        type=int,\n        default=20,\n        help=\"Number of attention dimensions. Default is 20.\",\n    )\n\n    parser.add_argument(\n        \"--walk-length\",\n        type=int,\n        default=10,\n        help=\"Length of walk per source. 
Default is 10.\",\n    )\n\n    parser.add_argument(\n        \"--num-walks\",\n        type=int,\n        default=20,\n        help=\"Number of walks per source. Default is 20.\",\n    )\n\n    parser.add_argument(\n        \"--window-size\",\n        type=int,\n        default=5,\n        help=\"Context size for optimization. Default is 5.\",\n    )\n\n    parser.add_argument(\n        \"--negative-samples\",\n        type=int,\n        default=5,\n        help=\"Negative samples for optimization. Default is 5.\",\n    )\n\n    parser.add_argument(\n        \"--neighbor-samples\",\n        type=int,\n        default=10,\n        help=\"Neighbor samples for aggregation. Default is 10.\",\n    )\n\n    parser.add_argument(\n        \"--patience\",\n        type=int,\n        default=5,\n        help=\"Early stopping patience. Default is 5.\",\n    )\n\n    parser.add_argument(\n        \"--gpu\",\n        type=str,\n        default=None,\n        help=\"Comma separated list of GPU device IDs.\",\n    )\n\n    parser.add_argument(\n        \"--workers\",\n        type=int,\n        default=4,\n        help=\"Number of workers.\",\n    )\n\n    return parser.parse_args()\n\n\n# for each line, the data is [edge_type, node, node]\ndef load_training_data(f_name):\n    print(\"We are loading data from:\", f_name)\n    edge_data_by_type = dict()\n    all_nodes = list()\n    with open(f_name, \"r\") as f:\n        for line in f:\n            words = line[:-1].split(\" \")  # line[-1] == '\\n'\n            if words[0] not in edge_data_by_type:\n                edge_data_by_type[words[0]] = list()\n            x, y = words[1], words[2]\n            edge_data_by_type[words[0]].append((x, y))\n            all_nodes.append(x)\n            all_nodes.append(y)\n    all_nodes = list(set(all_nodes))\n    print(\"Total training nodes: \" + str(len(all_nodes)))\n    return edge_data_by_type\n\n\n# for each line, the data is [edge_type, node, node, true_or_false]\ndef 
load_testing_data(f_name):\n    print(\"We are loading data from:\", f_name)\n    true_edge_data_by_type = dict()\n    false_edge_data_by_type = dict()\n    all_edges = list()\n    all_nodes = list()\n    with open(f_name, \"r\") as f:\n        for line in f:\n            words = line[:-1].split(\" \")\n            x, y = words[1], words[2]\n            if int(words[3]) == 1:\n                if words[0] not in true_edge_data_by_type:\n                    true_edge_data_by_type[words[0]] = list()\n                true_edge_data_by_type[words[0]].append((x, y))\n            else:\n                if words[0] not in false_edge_data_by_type:\n                    false_edge_data_by_type[words[0]] = list()\n                false_edge_data_by_type[words[0]].append((x, y))\n            all_nodes.append(x)\n            all_nodes.append(y)\n    all_nodes = list(set(all_nodes))\n    return true_edge_data_by_type, false_edge_data_by_type\n\n\ndef load_node_type(f_name):\n    print(\"We are loading node type from:\", f_name)\n    node_type = {}\n    with open(f_name, \"r\") as f:\n        for line in f:\n            items = line.strip().split()\n            node_type[items[0]] = items[1]\n    return node_type\n\n\ndef generate_pairs_parallel(walks, skip_window=None, layer_id=None):\n    pairs = []\n    for walk in walks:\n        walk = walk.tolist()\n        for i in range(len(walk)):\n            for j in range(1, skip_window + 1):\n                if i - j >= 0:\n                    pairs.append((walk[i], walk[i - j], layer_id))\n                if i + j < len(walk):\n                    pairs.append((walk[i], walk[i + j], layer_id))\n    return pairs\n\n\ndef generate_pairs(all_walks, window_size, num_workers):\n    # for each node, choose the first neighbor and second neighbor of it to form pairs\n    # Get all worker processes\n    start_time = time.time()\n    print(\"We are generating pairs with {} cores.\".format(num_workers))\n\n    # Start all worker processes\n    
pool = multiprocessing.Pool(processes=num_workers)\n    pairs = []\n    skip_window = window_size // 2\n    for layer_id, walks in enumerate(all_walks):\n        block_num = len(walks) // num_workers\n        if block_num > 0:\n            walks_list = [\n                walks[i * block_num : min((i + 1) * block_num, len(walks))]\n                for i in range(num_workers)\n            ]\n        else:\n            walks_list = [walks]\n        tmp_result = pool.map(\n            partial(\n                generate_pairs_parallel,\n                skip_window=skip_window,\n                layer_id=layer_id,\n            ),\n            walks_list,\n        )\n        pairs += reduce(lambda x, y: x + y, tmp_result)\n\n    pool.close()\n    end_time = time.time()\n    print(\"Generate pairs end, use {}s.\".format(end_time - start_time))\n    return np.array([list(pair) for pair in set(pairs)])\n\n\ndef generate_vocab(network_data):\n    nodes, index2word = [], []\n    for edge_type in network_data:\n        node1, node2 = zip(*network_data[edge_type])\n        index2word = index2word + list(node1) + list(node2)\n\n    index2word = list(set(index2word))\n    vocab = {}\n    i = 0\n    for word in index2word:\n        vocab[word] = i\n        i = i + 1\n\n    for edge_type in network_data:\n        node1, node2 = zip(*network_data[edge_type])\n        tmp_nodes = list(set(list(node1) + list(node2)))\n        tmp_nodes = [vocab[word] for word in tmp_nodes]\n        nodes.append(tmp_nodes)\n\n    return index2word, vocab, nodes\n\n\ndef get_score(local_model, edge):\n    node1, node2 = str(edge[0]), str(edge[1])\n    try:\n        vector1 = local_model[node1]\n        vector2 = local_model[node2]\n        return np.dot(vector1, vector2) / (\n            np.linalg.norm(vector1) * np.linalg.norm(vector2)\n        )\n    except Exception as e:\n        pass\n\n\ndef evaluate(model, true_edges, false_edges, num_workers):\n    true_list = list()\n    prediction_list = 
list()\n    true_num = 0\n\n    # Start all worker processes\n    pool = multiprocessing.Pool(processes=num_workers)\n    tmp_true_score_list = pool.map(partial(get_score, model), true_edges)\n    tmp_false_score_list = pool.map(partial(get_score, model), false_edges)\n    pool.close()\n\n    prediction_list += [\n        tmp_score for tmp_score in tmp_true_score_list if tmp_score is not None\n    ]\n    true_num = len(prediction_list)\n    true_list += [1] * true_num\n\n    prediction_list += [\n        tmp_score for tmp_score in tmp_false_score_list if tmp_score is not None\n    ]\n    true_list += [0] * (len(prediction_list) - true_num)\n\n    sorted_pred = prediction_list[:]\n    sorted_pred.sort()\n    threshold = sorted_pred[-true_num]\n\n    y_pred = np.zeros(len(prediction_list), dtype=np.int32)\n    for i in range(len(prediction_list)):\n        if prediction_list[i] >= threshold:\n            y_pred[i] = 1\n\n    y_true = np.array(true_list)\n    y_scores = np.array(prediction_list)\n    ps, rs, _ = precision_recall_curve(y_true, y_scores)\n    return (\n        roc_auc_score(y_true, y_scores),\n        f1_score(y_true, y_pred),\n        auc(rs, ps),\n    )\n"
  },
  {
    "path": "examples/pytorch/GNN-FiLM/README.md",
    "content": "# DGL Implementation of the GNN-FiLM Model\n\nThis DGL example implements the GNN model proposed in the paper [GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation](https://arxiv.org/pdf/1906.12192.pdf). \nThe author's codes of implementation is in [here](https://github.com/Microsoft/tf-gnn-samples)\n\n\nExample implementor\n----------------------\nThis example was implemented by [Kounianhua Du](https://github.com/KounianhuaDu) during her Software Dev Engineer Intern work at the AWS Shanghai AI Lab.\n\n\nDependencies\n----------------------\n- numpy 1.19.4\n- scikit-learn 0.22.1\n- pytorch 1.4.0\n- dgl 0.5.3\n\n\nThe graph dataset used in this example \n---------------------------------------\nThe DGL's built-in PPIDataset. This is a Protein-Protein Interaction dataset for inductive node classification. The PPIDataset is a toy Protein-Protein Interaction network dataset. The dataset contains 24 graphs. The average number of nodes per graph is 2372. Each node has 50 features and 121 labels. There are 20 graphs for training, 2 for validation, and 2 for testing.\n\nNOTE: Following the paper, in addition to the dataset-provided untyped edges, a fresh \"self-loop\" edge type is added.\n\nStatistics:\n- Train examples: 20\n- Valid examples: 2\n- Test examples: 2\n- AvgNodesPerGraph: 2372\n- NumFeats: 50\n- NumLabels: 121\n\n\nHow to run example files\n--------------------------------\nIn the GNNFiLM folder, run\n\n```bash\npython main.py \n```\n\nIf want to use a GPU, run\n\n```bash\npython main.py --gpu ${your_device_id_here}\n```\n\n\nPerformance\n-------------------------\n\nNOTE: We do not perform grid search or finetune here, so there is a gap between the performance reported in the original paper and this example. 
Below results, mean(standard deviation), were computed over ten runs.\n\n**GNN-FiLM results on PPI task**\n| Model         | Paper (tensorflow)               | ours (dgl)                  |\n| ------------- | -------------------------------- | --------------------------- |\n| Avg. Micro-F1 | 0.992 (0.000)                    | 0.983 (0.001)               |\n"
  },
  {
    "path": "examples/pytorch/GNN-FiLM/data_loader.py",
    "content": "import collections\n\nimport dgl\nfrom dgl.data import PPIDataset\n\nfrom torch.utils.data import DataLoader, Dataset\n\n# implement the collate_fn for dgl graph data class\nPPIBatch = collections.namedtuple(\"PPIBatch\", [\"graph\", \"label\"])\n\n\ndef batcher(device):\n    def batcher_dev(batch):\n        batch_graphs = dgl.batch(batch)\n        return PPIBatch(\n            graph=batch_graphs, label=batch_graphs.ndata[\"label\"].to(device)\n        )\n\n    return batcher_dev\n\n\n# add a fresh \"self-loop\" edge type to the untyped PPI dataset and prepare train, val, test loaders\ndef load_PPI(batch_size=1, device=\"cpu\"):\n    train_set = PPIDataset(mode=\"train\")\n    valid_set = PPIDataset(mode=\"valid\")\n    test_set = PPIDataset(mode=\"test\")\n    # for each graph, add self-loops as a new relation type\n    # here we reconstruct the graph since the schema of a heterograph cannot be changed once constructed\n    for i in range(len(train_set)):\n        g = dgl.heterograph(\n            {\n                (\"_N\", \"_E\", \"_N\"): train_set[i].edges(),\n                (\"_N\", \"self\", \"_N\"): (\n                    train_set[i].nodes(),\n                    train_set[i].nodes(),\n                ),\n            }\n        )\n        g.ndata[\"label\"] = train_set[i].ndata[\"label\"]\n        g.ndata[\"feat\"] = train_set[i].ndata[\"feat\"]\n        g.ndata[\"_ID\"] = train_set[i].ndata[\"_ID\"]\n        g.edges[\"_E\"].data[\"_ID\"] = train_set[i].edata[\"_ID\"]\n        train_set.graphs[i] = g\n    for i in range(len(valid_set)):\n        g = dgl.heterograph(\n            {\n                (\"_N\", \"_E\", \"_N\"): valid_set[i].edges(),\n                (\"_N\", \"self\", \"_N\"): (\n                    valid_set[i].nodes(),\n                    valid_set[i].nodes(),\n                ),\n            }\n        )\n        g.ndata[\"label\"] = valid_set[i].ndata[\"label\"]\n        g.ndata[\"feat\"] = valid_set[i].ndata[\"feat\"]\n   
     g.ndata[\"_ID\"] = valid_set[i].ndata[\"_ID\"]\n        g.edges[\"_E\"].data[\"_ID\"] = valid_set[i].edata[\"_ID\"]\n        valid_set.graphs[i] = g\n    for i in range(len(test_set)):\n        g = dgl.heterograph(\n            {\n                (\"_N\", \"_E\", \"_N\"): test_set[i].edges(),\n                (\"_N\", \"self\", \"_N\"): (\n                    test_set[i].nodes(),\n                    test_set[i].nodes(),\n                ),\n            }\n        )\n        g.ndata[\"label\"] = test_set[i].ndata[\"label\"]\n        g.ndata[\"feat\"] = test_set[i].ndata[\"feat\"]\n        g.ndata[\"_ID\"] = test_set[i].ndata[\"_ID\"]\n        g.edges[\"_E\"].data[\"_ID\"] = test_set[i].edata[\"_ID\"]\n        test_set.graphs[i] = g\n\n    etypes = train_set[0].etypes\n    in_size = train_set[0].ndata[\"feat\"].shape[1]\n    out_size = train_set[0].ndata[\"label\"].shape[1]\n\n    # prepare train, valid, and test dataloaders\n    train_loader = DataLoader(\n        train_set,\n        batch_size=batch_size,\n        collate_fn=batcher(device),\n        shuffle=True,\n    )\n    valid_loader = DataLoader(\n        valid_set,\n        batch_size=batch_size,\n        collate_fn=batcher(device),\n        shuffle=True,\n    )\n    test_loader = DataLoader(\n        test_set,\n        batch_size=batch_size,\n        collate_fn=batcher(device),\n        shuffle=True,\n    )\n    return train_loader, valid_loader, test_loader, etypes, in_size, out_size\n"
  },
  {
    "path": "examples/pytorch/GNN-FiLM/main.py",
    "content": "import argparse\nimport os\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom data_loader import load_PPI\nfrom utils import evaluate_f1_score\n\n\nclass GNNFiLMLayer(nn.Module):\n    def __init__(self, in_size, out_size, etypes, dropout=0.1):\n        super(GNNFiLMLayer, self).__init__()\n        self.in_size = in_size\n        self.out_size = out_size\n\n        # weights for different types of edges\n        self.W = nn.ModuleDict(\n            {name: nn.Linear(in_size, out_size, bias=False) for name in etypes}\n        )\n\n        # hypernets to learn the affine functions for different types of edges\n        self.film = nn.ModuleDict(\n            {\n                name: nn.Linear(in_size, 2 * out_size, bias=False)\n                for name in etypes\n            }\n        )\n\n        # layernorm before each propogation\n        self.layernorm = nn.LayerNorm(out_size)\n\n        # dropout layer\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, g, feat_dict):\n        # the input graph is a multi-relational graph, so treated as hetero-graph.\n\n        funcs = {}  # message and reduce functions dict\n        # for each type of edges, compute messages and reduce them all\n        for srctype, etype, dsttype in g.canonical_etypes:\n            messages = self.W[etype](\n                feat_dict[srctype]\n            )  # apply W_l on src feature\n            film_weights = self.film[etype](\n                feat_dict[dsttype]\n            )  # use dst feature to compute affine function paras\n            gamma = film_weights[\n                :, : self.out_size\n            ]  # \"gamma\" for the affine function\n            beta = film_weights[\n                :, self.out_size :\n            ]  # \"beta\" for the affine function\n            messages = gamma * messages + beta  # compute messages\n            messages = 
F.relu_(messages)\n            g.nodes[srctype].data[etype] = messages  # store in ndata\n            funcs[etype] = (\n                fn.copy_u(etype, \"m\"),\n                fn.sum(\"m\", \"h\"),\n            )  # define message and reduce functions\n        g.multi_update_all(\n            funcs, \"sum\"\n        )  # update all, reduce by first type-wisely then across different types\n        feat_dict = {}\n        for ntype in g.ntypes:\n            feat_dict[ntype] = self.dropout(\n                self.layernorm(g.nodes[ntype].data[\"h\"])\n            )  # apply layernorm and dropout\n        return feat_dict\n\n\nclass GNNFiLM(nn.Module):\n    def __init__(\n        self, etypes, in_size, hidden_size, out_size, num_layers, dropout=0.1\n    ):\n        super(GNNFiLM, self).__init__()\n        self.film_layers = nn.ModuleList()\n        self.film_layers.append(\n            GNNFiLMLayer(in_size, hidden_size, etypes, dropout)\n        )\n        for i in range(num_layers - 1):\n            self.film_layers.append(\n                GNNFiLMLayer(hidden_size, hidden_size, etypes, dropout)\n            )\n        self.predict = nn.Linear(hidden_size, out_size, bias=True)\n\n    def forward(self, g, out_key):\n        h_dict = {\n            ntype: g.nodes[ntype].data[\"feat\"] for ntype in g.ntypes\n        }  # prepare input feature dict\n        for layer in self.film_layers:\n            h_dict = layer(g, h_dict)\n        h = self.predict(\n            h_dict[out_key]\n        )  # use the final embed to predict, out_size = num_classes\n        h = torch.sigmoid(h)\n        return h\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test dataloader ============================= #\n    if args.gpu >= 0 and torch.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n    else:\n        device = \"cpu\"\n\n    if args.dataset == \"PPI\":\n        train_set, valid_set, test_set, etypes, in_size, out_size = 
load_PPI(\n            args.batch_size, device\n        )\n\n    # Step 2: Create model and training components=========================================================== #\n    model = GNNFiLM(\n        etypes, in_size, args.hidden_size, out_size, args.num_layers\n    ).to(device)\n    criterion = nn.BCELoss()\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.wd\n    )\n    scheduler = torch.optim.lr_scheduler.StepLR(\n        optimizer, args.step_size, gamma=args.gamma\n    )\n\n    # Step 4: training epoches ============================================================================== #\n    lastf1 = 0\n    cnt = 0\n    best_val_f1 = 0\n    for epoch in range(args.max_epoch):\n        train_loss = []\n        train_f1 = []\n        val_loss = []\n        val_f1 = []\n        model.train()\n        for batch in train_set:\n            g = batch.graph\n            g = g.to(device)\n            logits = model.forward(g, \"_N\")\n            labels = batch.label\n            loss = criterion(logits, labels)\n            f1 = evaluate_f1_score(\n                logits.detach().cpu().numpy(), labels.detach().cpu().numpy()\n            )\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            train_loss.append(loss.item())\n            train_f1.append(f1)\n\n        train_loss = np.mean(train_loss)\n        train_f1 = np.mean(train_f1)\n        scheduler.step()\n\n        model.eval()\n        with torch.no_grad():\n            for batch in valid_set:\n                g = batch.graph\n                g = g.to(device)\n                logits = model.forward(g, \"_N\")\n                labels = batch.label\n                loss = criterion(logits, labels)\n                f1 = evaluate_f1_score(\n                    logits.detach().cpu().numpy(), labels.detach().cpu().numpy()\n                )\n                val_loss.append(loss.item())\n                
val_f1.append(f1)\n\n        val_loss = np.mean(val_loss)\n        val_f1 = np.mean(val_f1)\n        print(\n            \"Epoch {:d} | Train Loss {:.4f} | Train F1 {:.4f} | Val Loss {:.4f} | Val F1 {:.4f} |\".format(\n                epoch + 1, train_loss, train_f1, val_loss, val_f1\n            )\n        )\n        if val_f1 > best_val_f1:\n            best_val_f1 = val_f1\n            torch.save(\n                model.state_dict(), os.path.join(args.save_dir, args.name)\n            )\n\n        if val_f1 < lastf1:\n            cnt += 1\n            if cnt == args.early_stopping:\n                print(\"Early stop.\")\n                break\n        else:\n            cnt = 0\n            lastf1 = val_f1\n\n    model.eval()\n    test_loss = []\n    test_f1 = []\n    model.load_state_dict(\n        torch.load(os.path.join(args.save_dir, args.name), weights_only=False)\n    )\n    with torch.no_grad():\n        for batch in test_set:\n            g = batch.graph\n            g = g.to(device)\n            logits = model.forward(g, \"_N\")\n            labels = batch.label\n            loss = criterion(logits, labels)\n            f1 = evaluate_f1_score(\n                logits.detach().cpu().numpy(), labels.detach().cpu().numpy()\n            )\n            test_loss.append(loss.item())\n            test_f1.append(f1)\n    test_loss = np.mean(test_loss)\n    test_f1 = np.mean(test_f1)\n\n    print(\"Test F1: {:.4f} | Test loss: {:.4f}\".format(test_f1, test_loss))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GNN-FiLM\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"PPI\",\n        help=\"DGL dataset for this GNN-FiLM\",\n    )\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU Index. 
Default: -1, using CPU.\"\n    )\n    parser.add_argument(\n        \"--in_size\", type=int, default=50, help=\"Input dimensionalities\"\n    )\n    parser.add_argument(\n        \"--hidden_size\",\n        type=int,\n        default=320,\n        help=\"Hidden layer dimensionalities\",\n    )\n    parser.add_argument(\n        \"--out_size\", type=int, default=121, help=\"Output dimensionalities\"\n    )\n    parser.add_argument(\n        \"--num_layers\", type=int, default=4, help=\"Number of GNN layers\"\n    )\n    parser.add_argument(\"--batch_size\", type=int, default=5, help=\"Batch size\")\n    parser.add_argument(\n        \"--max_epoch\",\n        type=int,\n        default=1500,\n        help=\"The max number of epoches. Default: 500\",\n    )\n    parser.add_argument(\n        \"--early_stopping\",\n        type=int,\n        default=80,\n        help=\"Early stopping. Default: 50\",\n    )\n    parser.add_argument(\n        \"--lr\", type=float, default=0.001, help=\"Learning rate. Default: 3e-1\"\n    )\n    parser.add_argument(\n        \"--wd\", type=float, default=0.0009, help=\"Weight decay. Default: 3e-1\"\n    )\n    parser.add_argument(\n        \"--step-size\",\n        type=int,\n        default=40,\n        help=\"Period of learning rate decay.\",\n    )\n    parser.add_argument(\n        \"--gamma\",\n        type=float,\n        default=0.8,\n        help=\"Multiplicative factor of learning rate decay.\",\n    )\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.1, help=\"Dropout rate. Default: 0.9\"\n    )\n    parser.add_argument(\n        \"--save_dir\", type=str, default=\"./out\", help=\"Path to save the model.\"\n    )\n    parser.add_argument(\n        \"--name\", type=str, default=\"GNN-FiLM\", help=\"Saved model name.\"\n    )\n\n    args = parser.parse_args()\n    print(args)\n    if not os.path.exists(args.save_dir):\n        os.mkdir(args.save_dir)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/GNN-FiLM/utils.py",
    "content": "import numpy as np\nfrom sklearn.metrics import f1_score\n\n\n# function to compute f1 score\ndef evaluate_f1_score(pred, label):\n    pred = np.round(pred, 0).astype(np.int16)\n    pred = pred.flatten()\n    label = label.flatten()\n    return f1_score(y_pred=pred, y_true=label)\n"
  },
  {
    "path": "examples/pytorch/NGCF/Data/load_amazon-book.sh",
    "content": "wget https://s3.us-west-2.amazonaws.com/dgl-data/dataset/amazon-book.zip\nunzip amazon-book.zip"
  },
  {
    "path": "examples/pytorch/NGCF/Data/load_gowalla.sh",
    "content": "wget https://s3.us-west-2.amazonaws.com/dgl-data/dataset/gowalla.zip\nunzip gowalla.zip"
  },
  {
    "path": "examples/pytorch/NGCF/NGCF/main.py",
    "content": "import os\nfrom time import time\n\nimport torch\nimport torch.optim as optim\nfrom model import NGCF\nfrom utility.batch_test import *\nfrom utility.helper import early_stopping\n\n\ndef main(args):\n    # Step 1: Prepare graph data and device ================================================================= #\n    if args.gpu >= 0 and torch.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n    else:\n        device = \"cpu\"\n\n    g = data_generator.g\n    g = g.to(device)\n\n    # Step 2: Create model and training components=========================================================== #\n    model = NGCF(\n        g, args.embed_size, args.layer_size, args.mess_dropout, args.regs[0]\n    ).to(device)\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n    # Step 3: training epoches ============================================================================== #\n    n_batch = data_generator.n_train // args.batch_size + 1\n    t0 = time()\n    cur_best_pre_0, stopping_step = 0, 0\n    loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], [], []\n    for epoch in range(args.epoch):\n        t1 = time()\n        loss, mf_loss, emb_loss = 0.0, 0.0, 0.0\n        for idx in range(n_batch):\n            users, pos_items, neg_items = data_generator.sample()\n            u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings = model(\n                g, \"user\", \"item\", users, pos_items, neg_items\n            )\n\n            batch_loss, batch_mf_loss, batch_emb_loss = model.create_bpr_loss(\n                u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings\n            )\n            optimizer.zero_grad()\n            batch_loss.backward()\n            optimizer.step()\n\n            loss += batch_loss\n            mf_loss += batch_mf_loss\n            emb_loss += batch_emb_loss\n\n        if (epoch + 1) % 10 != 0:\n            if args.verbose > 0 and epoch % args.verbose == 0:\n                
perf_str = \"Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f]\" % (\n                    epoch,\n                    time() - t1,\n                    loss,\n                    mf_loss,\n                    emb_loss,\n                )\n                print(perf_str)\n            continue  # end the current epoch and move to the next epoch, let the following evaluation run every 10 epoches\n\n        # evaluate the model every 10 epoches\n        t2 = time()\n        users_to_test = list(data_generator.test_set.keys())\n        ret = test(model, g, users_to_test)\n        t3 = time()\n\n        loss_loger.append(loss)\n        rec_loger.append(ret[\"recall\"])\n        pre_loger.append(ret[\"precision\"])\n        ndcg_loger.append(ret[\"ndcg\"])\n        hit_loger.append(ret[\"hit_ratio\"])\n\n        if args.verbose > 0:\n            perf_str = (\n                \"Epoch %d [%.1fs + %.1fs]: train==[%.5f=%.5f + %.5f], recall=[%.5f, %.5f], \"\n                \"precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]\"\n                % (\n                    epoch,\n                    t2 - t1,\n                    t3 - t2,\n                    loss,\n                    mf_loss,\n                    emb_loss,\n                    ret[\"recall\"][0],\n                    ret[\"recall\"][-1],\n                    ret[\"precision\"][0],\n                    ret[\"precision\"][-1],\n                    ret[\"hit_ratio\"][0],\n                    ret[\"hit_ratio\"][-1],\n                    ret[\"ndcg\"][0],\n                    ret[\"ndcg\"][-1],\n                )\n            )\n            print(perf_str)\n\n        cur_best_pre_0, stopping_step, should_stop = early_stopping(\n            ret[\"recall\"][0],\n            cur_best_pre_0,\n            stopping_step,\n            expected_order=\"acc\",\n            flag_step=5,\n        )\n\n        # early stop\n        if should_stop == True:\n            break\n\n        if ret[\"recall\"][0] == cur_best_pre_0 
and args.save_flag == 1:\n            torch.save(model.state_dict(), args.weights_path + args.model_name)\n            print(\n                \"save the weights in path: \",\n                args.weights_path + args.model_name,\n            )\n\n    recs = np.array(rec_loger)\n    pres = np.array(pre_loger)\n    ndcgs = np.array(ndcg_loger)\n    hit = np.array(hit_loger)\n\n    best_rec_0 = max(recs[:, 0])\n    idx = list(recs[:, 0]).index(best_rec_0)\n\n    final_perf = (\n        \"Best Iter=[%d]@[%.1f]\\trecall=[%s], precision=[%s], hit=[%s], ndcg=[%s]\"\n        % (\n            idx,\n            time() - t0,\n            \"\\t\".join([\"%.5f\" % r for r in recs[idx]]),\n            \"\\t\".join([\"%.5f\" % r for r in pres[idx]]),\n            \"\\t\".join([\"%.5f\" % r for r in hit[idx]]),\n            \"\\t\".join([\"%.5f\" % r for r in ndcgs[idx]]),\n        )\n    )\n    print(final_perf)\n\n\nif __name__ == \"__main__\":\n    if not os.path.exists(args.weights_path):\n        os.mkdir(args.weights_path)\n    args.mess_dropout = eval(args.mess_dropout)\n    args.layer_size = eval(args.layer_size)\n    args.regs = eval(args.regs)\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/NGCF/NGCF/model.py",
    "content": "import dgl.function as fn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass NGCFLayer(nn.Module):\n    def __init__(self, in_size, out_size, norm_dict, dropout):\n        super(NGCFLayer, self).__init__()\n        self.in_size = in_size\n        self.out_size = out_size\n\n        # weights for different types of messages\n        self.W1 = nn.Linear(in_size, out_size, bias=True)\n        self.W2 = nn.Linear(in_size, out_size, bias=True)\n\n        # leaky relu\n        self.leaky_relu = nn.LeakyReLU(0.2)\n\n        # dropout layer\n        self.dropout = nn.Dropout(dropout)\n\n        # initialization\n        torch.nn.init.xavier_uniform_(self.W1.weight)\n        torch.nn.init.constant_(self.W1.bias, 0)\n        torch.nn.init.xavier_uniform_(self.W2.weight)\n        torch.nn.init.constant_(self.W2.bias, 0)\n\n        # norm\n        self.norm_dict = norm_dict\n\n    def forward(self, g, feat_dict):\n        funcs = {}  # message and reduce functions dict\n        # for each type of edges, compute messages and reduce them all\n        for srctype, etype, dsttype in g.canonical_etypes:\n            if srctype == dsttype:  # for self loops\n                messages = self.W1(feat_dict[srctype])\n                g.nodes[srctype].data[etype] = messages  # store in ndata\n                funcs[(srctype, etype, dsttype)] = (\n                    fn.copy_u(etype, \"m\"),\n                    fn.sum(\"m\", \"h\"),\n                )  # define message and reduce functions\n            else:\n                src, dst = g.edges(etype=(srctype, etype, dsttype))\n                norm = self.norm_dict[(srctype, etype, dsttype)]\n                messages = norm * (\n                    self.W1(feat_dict[srctype][src])\n                    + self.W2(feat_dict[srctype][src] * feat_dict[dsttype][dst])\n                )  # compute messages\n                g.edges[(srctype, etype, dsttype)].data[\n                    etype\n            
    ] = messages  # store in edata\n                funcs[(srctype, etype, dsttype)] = (\n                    fn.copy_e(etype, \"m\"),\n                    fn.sum(\"m\", \"h\"),\n                )  # define message and reduce functions\n\n        g.multi_update_all(\n            funcs, \"sum\"\n        )  # update all, reduce by first type-wisely then across different types\n        feature_dict = {}\n        for ntype in g.ntypes:\n            h = self.leaky_relu(g.nodes[ntype].data[\"h\"])  # leaky relu\n            h = self.dropout(h)  # dropout\n            h = F.normalize(h, dim=1, p=2)  # l2 normalize\n            feature_dict[ntype] = h\n        return feature_dict\n\n\nclass NGCF(nn.Module):\n    def __init__(self, g, in_size, layer_size, dropout, lmbd=1e-5):\n        super(NGCF, self).__init__()\n        self.lmbd = lmbd\n        self.norm_dict = dict()\n        for srctype, etype, dsttype in g.canonical_etypes:\n            src, dst = g.edges(etype=(srctype, etype, dsttype))\n            dst_degree = g.in_degrees(\n                dst, etype=(srctype, etype, dsttype)\n            ).float()  # obtain degrees\n            src_degree = g.out_degrees(\n                src, etype=(srctype, etype, dsttype)\n            ).float()\n            norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(\n                1\n            )  # compute norm\n            self.norm_dict[(srctype, etype, dsttype)] = norm\n\n        self.layers = nn.ModuleList()\n        self.layers.append(\n            NGCFLayer(in_size, layer_size[0], self.norm_dict, dropout[0])\n        )\n        self.num_layers = len(layer_size)\n        for i in range(self.num_layers - 1):\n            self.layers.append(\n                NGCFLayer(\n                    layer_size[i],\n                    layer_size[i + 1],\n                    self.norm_dict,\n                    dropout[i + 1],\n                )\n            )\n        self.initializer = nn.init.xavier_uniform_\n\n        # 
embeddings for different types of nodes\n        self.feature_dict = nn.ParameterDict(\n            {\n                ntype: nn.Parameter(\n                    self.initializer(torch.empty(g.num_nodes(ntype), in_size))\n                )\n                for ntype in g.ntypes\n            }\n        )\n\n    def create_bpr_loss(self, users, pos_items, neg_items):\n        pos_scores = (users * pos_items).sum(1)\n        neg_scores = (users * neg_items).sum(1)\n\n        mf_loss = nn.LogSigmoid()(pos_scores - neg_scores).mean()\n        mf_loss = -1 * mf_loss\n\n        regularizer = (\n            torch.norm(users) ** 2\n            + torch.norm(pos_items) ** 2\n            + torch.norm(neg_items) ** 2\n        ) / 2\n        emb_loss = self.lmbd * regularizer / users.shape[0]\n\n        return mf_loss + emb_loss, mf_loss, emb_loss\n\n    def rating(self, u_g_embeddings, pos_i_g_embeddings):\n        return torch.matmul(u_g_embeddings, pos_i_g_embeddings.t())\n\n    def forward(self, g, user_key, item_key, users, pos_items, neg_items):\n        h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}\n        # obtain features of each layer and concatenate them all\n        user_embeds = []\n        item_embeds = []\n        user_embeds.append(h_dict[user_key])\n        item_embeds.append(h_dict[item_key])\n        for layer in self.layers:\n            h_dict = layer(g, h_dict)\n            user_embeds.append(h_dict[user_key])\n            item_embeds.append(h_dict[item_key])\n        user_embd = torch.cat(user_embeds, 1)\n        item_embd = torch.cat(item_embeds, 1)\n\n        u_g_embeddings = user_embd[users, :]\n        pos_i_g_embeddings = item_embd[pos_items, :]\n        neg_i_g_embeddings = item_embd[neg_items, :]\n\n        return u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings\n"
  },
  {
    "path": "examples/pytorch/NGCF/NGCF/utility/batch_test.py",
    "content": "# This file is based on the NGCF author's implementation\n# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/batch_test.py>.\n# It implements the batch test.\n\nimport heapq\nimport multiprocessing\n\nimport utility.metrics as metrics\nfrom utility.load_data import *\nfrom utility.parser import parse_args\n\ncores = multiprocessing.cpu_count()\n\nargs = parse_args()\nKs = eval(args.Ks)\n\ndata_generator = Data(\n    path=args.data_path + args.dataset, batch_size=args.batch_size\n)\nUSR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items\nN_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test\nBATCH_SIZE = args.batch_size\n\n\ndef ranklist_by_heapq(user_pos_test, test_items, rating, Ks):\n    item_score = {}\n    for i in test_items:\n        item_score[i] = rating[i]\n\n    K_max = max(Ks)\n    K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)\n\n    r = []\n    for i in K_max_item_score:\n        if i in user_pos_test:\n            r.append(1)\n        else:\n            r.append(0)\n    auc = 0.0\n    return r, auc\n\n\ndef get_auc(item_score, user_pos_test):\n    item_score = sorted(item_score.items(), key=lambda kv: kv[1])\n    item_score.reverse()\n    item_sort = [x[0] for x in item_score]\n    posterior = [x[1] for x in item_score]\n\n    r = []\n    for i in item_sort:\n        if i in user_pos_test:\n            r.append(1)\n        else:\n            r.append(0)\n    auc = metrics.auc(ground_truth=r, prediction=posterior)\n    return auc\n\n\ndef ranklist_by_sorted(user_pos_test, test_items, rating, Ks):\n    item_score = {}\n    for i in test_items:\n        item_score[i] = rating[i]\n\n    K_max = max(Ks)\n    K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)\n\n    r = []\n    for i in K_max_item_score:\n        if i in user_pos_test:\n            r.append(1)\n        else:\n            r.append(0)\n    auc = 
get_auc(item_score, user_pos_test)\n    return r, auc\n\n\ndef get_performance(user_pos_test, r, auc, Ks):\n    precision, recall, ndcg, hit_ratio = [], [], [], []\n\n    for K in Ks:\n        precision.append(metrics.precision_at_k(r, K))\n        recall.append(metrics.recall_at_k(r, K, len(user_pos_test)))\n        ndcg.append(metrics.ndcg_at_k(r, K))\n        hit_ratio.append(metrics.hit_at_k(r, K))\n\n    return {\n        \"recall\": np.array(recall),\n        \"precision\": np.array(precision),\n        \"ndcg\": np.array(ndcg),\n        \"hit_ratio\": np.array(hit_ratio),\n        \"auc\": auc,\n    }\n\n\ndef test_one_user(x):\n    # user u's ratings for user u\n    rating = x[0]\n    # uid\n    u = x[1]\n    # user u's items in the training set\n    try:\n        training_items = data_generator.train_items[u]\n    except Exception:\n        training_items = []\n    # user u's items in the test set\n    user_pos_test = data_generator.test_set[u]\n\n    all_items = set(range(ITEM_NUM))\n\n    test_items = list(all_items - set(training_items))\n\n    if args.test_flag == \"part\":\n        r, auc = ranklist_by_heapq(user_pos_test, test_items, rating, Ks)\n    else:\n        r, auc = ranklist_by_sorted(user_pos_test, test_items, rating, Ks)\n\n    return get_performance(user_pos_test, r, auc, Ks)\n\n\ndef test(model, g, users_to_test, batch_test_flag=False):\n    result = {\n        \"precision\": np.zeros(len(Ks)),\n        \"recall\": np.zeros(len(Ks)),\n        \"ndcg\": np.zeros(len(Ks)),\n        \"hit_ratio\": np.zeros(len(Ks)),\n        \"auc\": 0.0,\n    }\n\n    pool = multiprocessing.Pool(cores)\n\n    u_batch_size = 5000\n    i_batch_size = BATCH_SIZE\n\n    test_users = users_to_test\n    n_test_users = len(test_users)\n    n_user_batchs = n_test_users // u_batch_size + 1\n\n    count = 0\n\n    for u_batch_id in range(n_user_batchs):\n        start = u_batch_id * u_batch_size\n        end = (u_batch_id + 1) * u_batch_size\n\n        user_batch = 
test_users[start:end]\n\n        if batch_test_flag:\n            # batch-item test\n            n_item_batchs = ITEM_NUM // i_batch_size + 1\n            rate_batch = np.zeros(shape=(len(user_batch), ITEM_NUM))\n\n            i_count = 0\n            for i_batch_id in range(n_item_batchs):\n                i_start = i_batch_id * i_batch_size\n                i_end = min((i_batch_id + 1) * i_batch_size, ITEM_NUM)\n\n                item_batch = range(i_start, i_end)\n\n                u_g_embeddings, pos_i_g_embeddings, _ = model(\n                    g, \"user\", \"item\", user_batch, item_batch, []\n                )\n                i_rate_batch = (\n                    model.rating(u_g_embeddings, pos_i_g_embeddings)\n                    .detach()\n                    .cpu()\n                )\n\n                rate_batch[:, i_start:i_end] = i_rate_batch\n                i_count += i_rate_batch.shape[1]\n\n            assert i_count == ITEM_NUM\n\n        else:\n            # all-item test\n            item_batch = range(ITEM_NUM)\n            u_g_embeddings, pos_i_g_embeddings, _ = model(\n                g, \"user\", \"item\", user_batch, item_batch, []\n            )\n            rate_batch = (\n                model.rating(u_g_embeddings, pos_i_g_embeddings).detach().cpu()\n            )\n\n        user_batch_rating_uid = zip(rate_batch.numpy(), user_batch)\n        batch_result = pool.map(test_one_user, user_batch_rating_uid)\n        count += len(batch_result)\n\n        for re in batch_result:\n            result[\"precision\"] += re[\"precision\"] / n_test_users\n            result[\"recall\"] += re[\"recall\"] / n_test_users\n            result[\"ndcg\"] += re[\"ndcg\"] / n_test_users\n            result[\"hit_ratio\"] += re[\"hit_ratio\"] / n_test_users\n            result[\"auc\"] += re[\"auc\"] / n_test_users\n\n    assert count == n_test_users\n    pool.close()\n    return result\n"
  },
  {
    "path": "examples/pytorch/NGCF/NGCF/utility/helper.py",
    "content": "# This file is copied from the NGCF author's implementation\n# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/helper.py>.\n# It implements the helper functions.\n\"\"\"\nCreated on Aug 19, 2016\n@author: Xiang Wang (xiangwang@u.nus.edu)\n\"\"\"\n__author__ = \"xiangwang\"\nimport os\nimport re\n\n\ndef txt2list(file_src):\n    orig_file = open(file_src, \"r\")\n    lines = orig_file.readlines()\n    return lines\n\n\ndef ensureDir(dir_path):\n    d = os.path.dirname(dir_path)\n    if not os.path.exists(d):\n        os.makedirs(d)\n\n\ndef uni2str(unicode_str):\n    return str(unicode_str.encode(\"ascii\", \"ignore\")).replace(\"\\n\", \"\").strip()\n\n\ndef hasNumbers(inputString):\n    return bool(re.search(r\"\\d\", inputString))\n\n\ndef delMultiChar(inputString, chars):\n    for ch in chars:\n        inputString = inputString.replace(ch, \"\")\n    return inputString\n\n\ndef merge_two_dicts(x, y):\n    z = x.copy()  # start with x's keys and values\n    z.update(y)  # modifies z with y's keys and values & returns None\n    return z\n\n\ndef early_stopping(\n    log_value, best_value, stopping_step, expected_order=\"acc\", flag_step=100\n):\n    # early stopping strategy:\n    assert expected_order in [\"acc\", \"dec\"]\n\n    if (expected_order == \"acc\" and log_value >= best_value) or (\n        expected_order == \"dec\" and log_value <= best_value\n    ):\n        stopping_step = 0\n        best_value = log_value\n    else:\n        stopping_step += 1\n\n    if stopping_step >= flag_step:\n        print(\n            \"Early stopping is trigger at step: {} log:{}\".format(\n                flag_step, log_value\n            )\n        )\n        should_stop = True\n    else:\n        should_stop = False\n    return best_value, stopping_step, should_stop\n"
  },
  {
    "path": "examples/pytorch/NGCF/NGCF/utility/load_data.py",
    "content": "# This file is based on the NGCF author's implementation\n# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/load_data.py>.\n# It implements the data processing and graph construction.\nimport random as rd\n\nimport dgl\n\nimport numpy as np\n\n\nclass Data(object):\n    def __init__(self, path, batch_size):\n        self.path = path\n        self.batch_size = batch_size\n\n        train_file = path + \"/train.txt\"\n        test_file = path + \"/test.txt\"\n\n        # get number of users and items\n        self.n_users, self.n_items = 0, 0\n        self.n_train, self.n_test = 0, 0\n        self.exist_users = []\n\n        user_item_src = []\n        user_item_dst = []\n\n        with open(train_file) as f:\n            for l in f.readlines():\n                if len(l) > 0:\n                    l = l.strip(\"\\n\").split(\" \")\n                    items = [int(i) for i in l[1:]]\n                    uid = int(l[0])\n                    self.exist_users.append(uid)\n                    self.n_items = max(self.n_items, max(items))\n                    self.n_users = max(self.n_users, uid)\n                    self.n_train += len(items)\n                    for i in l[1:]:\n                        user_item_src.append(uid)\n                        user_item_dst.append(int(i))\n\n        with open(test_file) as f:\n            for l in f.readlines():\n                if len(l) > 0:\n                    l = l.strip(\"\\n\")\n                    try:\n                        items = [int(i) for i in l.split(\" \")[1:]]\n                    except Exception:\n                        continue\n                    self.n_items = max(self.n_items, max(items))\n                    self.n_test += len(items)\n        self.n_items += 1\n        self.n_users += 1\n\n        self.print_statistics()\n\n        # training positive items corresponding to each user; testing positive items corresponding to each user\n    
    self.train_items, self.test_set = {}, {}\n        with open(train_file) as f_train:\n            with open(test_file) as f_test:\n                for l in f_train.readlines():\n                    if len(l) == 0:\n                        break\n                    l = l.strip(\"\\n\")\n                    items = [int(i) for i in l.split(\" \")]\n                    uid, train_items = items[0], items[1:]\n                    self.train_items[uid] = train_items\n\n                for l in f_test.readlines():\n                    if len(l) == 0:\n                        break\n                    l = l.strip(\"\\n\")\n                    try:\n                        items = [int(i) for i in l.split(\" \")]\n                    except Exception:\n                        continue\n\n                    uid, test_items = items[0], items[1:]\n                    self.test_set[uid] = test_items\n\n        # construct graph from the train data and add self-loops\n        user_selfs = [i for i in range(self.n_users)]\n        item_selfs = [i for i in range(self.n_items)]\n\n        data_dict = {\n            (\"user\", \"user_self\", \"user\"): (user_selfs, user_selfs),\n            (\"item\", \"item_self\", \"item\"): (item_selfs, item_selfs),\n            (\"user\", \"ui\", \"item\"): (user_item_src, user_item_dst),\n            (\"item\", \"iu\", \"user\"): (user_item_dst, user_item_src),\n        }\n        num_dict = {\"user\": self.n_users, \"item\": self.n_items}\n\n        self.g = dgl.heterograph(data_dict, num_nodes_dict=num_dict)\n\n    def sample(self):\n        if self.batch_size <= self.n_users:\n            users = rd.sample(self.exist_users, self.batch_size)\n        else:\n            users = [\n                rd.choice(self.exist_users) for _ in range(self.batch_size)\n            ]\n\n        def sample_pos_items_for_u(u, num):\n            # sample num pos items for u-th user\n            pos_items = self.train_items[u]\n            n_pos_items = 
len(pos_items)\n            pos_batch = []\n            while True:\n                if len(pos_batch) == num:\n                    break\n                pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0]\n                pos_i_id = pos_items[pos_id]\n\n                if pos_i_id not in pos_batch:\n                    pos_batch.append(pos_i_id)\n            return pos_batch\n\n        def sample_neg_items_for_u(u, num):\n            # sample num neg items for u-th user\n            neg_items = []\n            while True:\n                if len(neg_items) == num:\n                    break\n                neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]\n                if (\n                    neg_id not in self.train_items[u]\n                    and neg_id not in neg_items\n                ):\n                    neg_items.append(neg_id)\n            return neg_items\n\n        pos_items, neg_items = [], []\n        for u in users:\n            pos_items += sample_pos_items_for_u(u, 1)\n            neg_items += sample_neg_items_for_u(u, 1)\n\n        return users, pos_items, neg_items\n\n    def get_num_users_items(self):\n        return self.n_users, self.n_items\n\n    def print_statistics(self):\n        print(\"n_users=%d, n_items=%d\" % (self.n_users, self.n_items))\n        print(\"n_interactions=%d\" % (self.n_train + self.n_test))\n        print(\n            \"n_train=%d, n_test=%d, sparsity=%.5f\"\n            % (\n                self.n_train,\n                self.n_test,\n                (self.n_train + self.n_test) / (self.n_users * self.n_items),\n            )\n        )\n"
  },
  {
    "path": "examples/pytorch/NGCF/NGCF/utility/metrics.py",
    "content": "# This file is copied from the NGCF author's implementation\n# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/metrics.py>.\n# It implements the metrics.\n\"\"\"\nCreated on Oct 10, 2018\nTensorflow Implementation of Neural Graph Collaborative Filtering (NGCF) model in:\nWang Xiang et al. Neural Graph Collaborative Filtering. In SIGIR 2019.\n@author: Xiang Wang (xiangwang@u.nus.edu)\n\"\"\"\n\nimport numpy as np\nfrom sklearn.metrics import roc_auc_score\n\n\ndef recall(rank, ground_truth, N):\n    return len(set(rank[:N]) & set(ground_truth)) / float(\n        len(set(ground_truth))\n    )\n\n\ndef precision_at_k(r, k):\n    \"\"\"Score is precision @ k\n    Relevance is binary (nonzero is relevant).\n    Returns:\n        Precision @ k\n    Raises:\n        ValueError: len(r) must be >= k\n    \"\"\"\n    assert k >= 1\n    r = np.asarray(r)[:k]\n    return np.mean(r)\n\n\ndef average_precision(r, cut):\n    \"\"\"Score is average precision (area under PR curve)\n    Relevance is binary (nonzero is relevant).\n    Returns:\n        Average precision\n    \"\"\"\n    r = np.asarray(r)\n    out = [precision_at_k(r, k + 1) for k in range(cut) if r[k]]\n    if not out:\n        return 0.0\n    return np.sum(out) / float(min(cut, np.sum(r)))\n\n\ndef mean_average_precision(rs):\n    \"\"\"Score is mean average precision\n    Relevance is binary (nonzero is relevant).\n    Returns:\n        Mean average precision\n    \"\"\"\n    return np.mean([average_precision(r) for r in rs])\n\n\ndef dcg_at_k(r, k, method=1):\n    \"\"\"Score is discounted cumulative gain (dcg)\n    Relevance is positive real values.  
Can use binary\n    as the previous methods.\n    Returns:\n        Discounted cumulative gain\n    \"\"\"\n    r = np.asfarray(r)[:k]\n    if r.size:\n        if method == 0:\n            return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))\n        elif method == 1:\n            return np.sum(r / np.log2(np.arange(2, r.size + 2)))\n        else:\n            raise ValueError(\"method must be 0 or 1.\")\n    return 0.0\n\n\ndef ndcg_at_k(r, k, method=1):\n    \"\"\"Score is normalized discounted cumulative gain (ndcg)\n    Relevance is positive real values.  Can use binary\n    as the previous methods.\n    Returns:\n        Normalized discounted cumulative gain\n    \"\"\"\n    dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)\n    if not dcg_max:\n        return 0.0\n    return dcg_at_k(r, k, method) / dcg_max\n\n\ndef recall_at_k(r, k, all_pos_num):\n    r = np.asfarray(r)[:k]\n    return np.sum(r) / all_pos_num\n\n\ndef hit_at_k(r, k):\n    r = np.array(r)[:k]\n    if np.sum(r) > 0:\n        return 1.0\n    else:\n        return 0.0\n\n\ndef F1(pre, rec):\n    if pre + rec > 0:\n        return (2.0 * pre * rec) / (pre + rec)\n    else:\n        return 0.0\n\n\ndef auc(ground_truth, prediction):\n    try:\n        res = roc_auc_score(y_true=ground_truth, y_score=prediction)\n    except Exception:\n        res = 0.0\n    return res\n"
  },
  {
    "path": "examples/pytorch/NGCF/NGCF/utility/parser.py",
    "content": "# This file is based on the NGCF author's implementation\n# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/parser.py>.\n\nimport argparse\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run NGCF.\")\n    parser.add_argument(\n        \"--weights_path\", nargs=\"?\", default=\"model/\", help=\"Store model path.\"\n    )\n    parser.add_argument(\n        \"--data_path\", nargs=\"?\", default=\"../Data/\", help=\"Input data path.\"\n    )\n    parser.add_argument(\n        \"--model_name\", type=str, default=\"NGCF.pkl\", help=\"Saved model name.\"\n    )\n\n    parser.add_argument(\n        \"--dataset\",\n        nargs=\"?\",\n        default=\"gowalla\",\n        help=\"Choose a dataset from {gowalla, yelp2018, amazon-book}\",\n    )\n    parser.add_argument(\n        \"--verbose\", type=int, default=1, help=\"Interval of evaluation.\"\n    )\n    parser.add_argument(\n        \"--epoch\", type=int, default=400, help=\"Number of epoch.\"\n    )\n\n    parser.add_argument(\n        \"--embed_size\", type=int, default=64, help=\"Embedding size.\"\n    )\n    parser.add_argument(\n        \"--layer_size\",\n        nargs=\"?\",\n        default=\"[64,64,64]\",\n        help=\"Output sizes of every layer\",\n    )\n    parser.add_argument(\n        \"--batch_size\", type=int, default=1024, help=\"Batch size.\"\n    )\n\n    parser.add_argument(\n        \"--regs\", nargs=\"?\", default=\"[1e-5]\", help=\"Regularizations.\"\n    )\n    parser.add_argument(\n        \"--lr\", type=float, default=0.0001, help=\"Learning rate.\"\n    )\n\n    parser.add_argument(\n        \"--gpu\", type=int, default=0, help=\"0 for NAIS_prod, 1 for NAIS_concat\"\n    )\n\n    parser.add_argument(\n        \"--mess_dropout\",\n        nargs=\"?\",\n        default=\"[0.1,0.1,0.1]\",\n        help=\"Keep probability w.r.t. message dropout (i.e., 1-dropout_ratio) for each deep layer. 
1: no dropout.\",\n    )\n\n    parser.add_argument(\n        \"--Ks\",\n        nargs=\"?\",\n        default=\"[20, 40]\",\n        help=\"Output sizes of every layer\",\n    )\n\n    parser.add_argument(\n        \"--save_flag\",\n        type=int,\n        default=1,\n        help=\"0: Disable model saver, 1: Activate model saver\",\n    )\n\n    parser.add_argument(\n        \"--test_flag\",\n        nargs=\"?\",\n        default=\"part\",\n        help=\"Specify the test type from {part, full}, indicating whether the reference is done in mini-batch\",\n    )\n\n    parser.add_argument(\n        \"--report\",\n        type=int,\n        default=0,\n        help=\"0: Disable performance report w.r.t. sparsity levels, 1: Show performance report w.r.t. sparsity levels\",\n    )\n    return parser.parse_args()\n"
  },
  {
    "path": "examples/pytorch/NGCF/README.md",
    "content": "# DGL Implementation of the NGCF Model\n\nThis DGL example implements the GNN model proposed in the paper [Neural Graph Collaborative Filtering](https://arxiv.org/abs/1905.08108). \nThe author's codes of implementation is in [here](https://github.com/xiangwang1223/neural_graph_collaborative_filtering). A pytorch re-implementation can be found [here](https://github.com/huangtinglin/NGCF-PyTorch).\n\nExample implementor\n----------------------\nThis example was implemented by [Kounianhua Du](https://github.com/KounianhuaDu) during her Software Dev Engineer Intern work at the AWS Shanghai AI Lab.\n\n\nThe graph dataset used in this example \n---------------------------------------\nGowalla: This is the check-in dataset obtained from Gowalla, where users share their locations by checking-in. To ensure the quality of the dataset, we use the 10-core setting, i.e., retaining users and items with at least ten interactions. The dataset used can be found [here](https://github.com/xiangwang1223/neural_graph_collaborative_filtering/tree/master/Data).\n\nStatistics:\n- Users: 29858\n- Items: 40981\n- Interactions: 1027370\n- Density: 0.00084\n\n\nHow to run example files\n--------------------------------\nFirst to get the data, in the Data folder, run\n\n```bash\nsh load_gowalla.sh\n```\n\nThen, in the NGCF folder, run\n\n```bash\npython main.py --dataset gowalla --regs [1e-5] --embed_size 64 --layer_size [64,64,64] --lr 0.0001 --save_flag 1 --batch_size 1024 --epoch 400 --verbose 1 --mess_dropout [0.1,0.1,0.1] --gpu 0 \n```\n\nNOTE: Following the paper's setting, the node dropout is disabled.\n\n\nPerformance\n-------------------------\nThe following results are the results in 400 epoches.\n\n**NGCF results**\n| Model         | Paper (tensorflow)               | ours (DGL)                  |\n| ------------- | -------------------------------- | --------------------------- |\n| recall@20     | 0.1569                           | 0.1552                      |\n| 
ndcg@20       | 0.1327                           | 0.2707                      |\n\n"
  },
  {
    "path": "examples/pytorch/P-GNN/README.md",
    "content": "# DGL Implementations of P-GNN\n\nThis DGL example implements the GNN model proposed in the paper [Position-aware Graph Neural Networks](http://proceedings.mlr.press/v97/you19b/you19b.pdf). For the original implementation, see [here](https://github.com/JiaxuanYou/P-GNN).\n\nContributor: [RecLusIve-F](https://github.com/RecLusIve-F)\n\n## Requirements\n\nThe codebase is implemented in Python 3.8. For version requirement of packages, see below.\n\n```\ndgl 0.7.2\nnumpy 1.21.2\ntorch 1.10.1\nnetworkx 2.6.3\nscikit-learn 1.0.2\n```\n\n## Instructions for experiments\n\n### Link prediction\n\n```bash\n# Communities-T\npython main.py --task link\n\n# Communities\npython main.py --task link --inductive\n```\n\n### Link pair prediction\n\n```bash\n# Communities\npython main.py --task link_pair --inductive\n```\n\n## Performance\n\n### Link prediction (Grid-T and Communities-T refer to the transductive learning setting of Grid and Communities)\n\n|             Dataset              | Communities-T |  Communities  |\n| :------------------------------: | :-----------: | :-----------: |\n| ROC AUC ( P-GNN-E-2L in Table 1) | 0.988 ± 0.003 | 0.985 ± 0.008 |\n|    ROC AUC (DGL: P-GNN-E-2L)     | 0.984 ± 0.010 | 0.991 ± 0.004 |\n\n### Link pair prediction\n\n|             Dataset              | Communities |\n| :------------------------------: | :---------: |\n| ROC AUC ( P-GNN-E-2L in Table 1) | 1.0 ± 0.001 |\n|    ROC AUC (DGL: P-GNN-E-2L)     | 1.0 ± 0.000 |\n"
  },
  {
    "path": "examples/pytorch/P-GNN/main.py",
    "content": "import os\nimport warnings\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom model import PGNN\nfrom sklearn.metrics import roc_auc_score\nfrom utils import get_dataset, preselect_anchor\n\nwarnings.filterwarnings(\"ignore\")\n\n\ndef get_loss(p, data, out, loss_func, device, get_auc=True):\n    edge_mask = np.concatenate(\n        (\n            data[\"positive_edges_{}\".format(p)],\n            data[\"negative_edges_{}\".format(p)],\n        ),\n        axis=-1,\n    )\n\n    nodes_first = torch.index_select(\n        out, 0, torch.from_numpy(edge_mask[0, :]).long().to(out.device)\n    )\n    nodes_second = torch.index_select(\n        out, 0, torch.from_numpy(edge_mask[1, :]).long().to(out.device)\n    )\n\n    pred = torch.sum(nodes_first * nodes_second, dim=-1)\n\n    label_positive = torch.ones(\n        [\n            data[\"positive_edges_{}\".format(p)].shape[1],\n        ],\n        dtype=pred.dtype,\n    )\n    label_negative = torch.zeros(\n        [\n            data[\"negative_edges_{}\".format(p)].shape[1],\n        ],\n        dtype=pred.dtype,\n    )\n    label = torch.cat((label_positive, label_negative)).to(device)\n    loss = loss_func(pred, label)\n\n    if get_auc:\n        auc = roc_auc_score(\n            label.flatten().cpu().numpy(),\n            torch.sigmoid(pred).flatten().data.cpu().numpy(),\n        )\n        return loss, auc\n    else:\n        return loss\n\n\ndef train_model(data, model, loss_func, optimizer, device, g_data):\n    model.train()\n    out = model(g_data)\n\n    loss = get_loss(\"train\", data, out, loss_func, device, get_auc=False)\n\n    optimizer.zero_grad()\n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n\n    return g_data\n\n\ndef eval_model(data, g_data, model, loss_func, device):\n    model.eval()\n    out = model(g_data)\n\n    # train loss and auc\n    tmp_loss, auc_train = get_loss(\"train\", data, out, loss_func, device)\n    loss_train 
= tmp_loss.cpu().data.numpy()\n\n    # val loss and auc\n    _, auc_val = get_loss(\"val\", data, out, loss_func, device)\n\n    # test loss and auc\n    _, auc_test = get_loss(\"test\", data, out, loss_func, device)\n\n    return loss_train, auc_train, auc_val, auc_test\n\n\ndef main(args):\n    # The mean and standard deviation of the experiment results\n    # are stored in the 'results' folder\n    if not os.path.isdir(\"results\"):\n        os.mkdir(\"results\")\n\n    if torch.cuda.is_available():\n        device = \"cuda:0\"\n    else:\n        device = \"cpu\"\n\n    print(\n        \"Learning Type: {}\".format(\n            [\"Transductive\", \"Inductive\"][args.inductive]\n        ),\n        \"Task: {}\".format(args.task),\n    )\n\n    results = []\n\n    for repeat in range(args.repeat_num):\n        data = get_dataset(args)\n\n        # pre-sample anchor nodes and compute shortest distance values for all epochs\n        (\n            g_list,\n            anchor_eid_list,\n            dist_max_list,\n            edge_weight_list,\n        ) = preselect_anchor(data, args)\n\n        # model\n        model = PGNN(input_dim=data[\"feature\"].shape[1]).to(device)\n\n        # loss\n        optimizer = torch.optim.Adam(\n            model.parameters(), lr=1e-2, weight_decay=5e-4\n        )\n        loss_func = nn.BCEWithLogitsLoss()\n\n        best_auc_val = -1\n        best_auc_test = -1\n\n        for epoch in range(args.epoch_num):\n            if epoch == 200:\n                for param_group in optimizer.param_groups:\n                    param_group[\"lr\"] /= 10\n\n            g = dgl.graph(g_list[epoch])\n            g.ndata[\"feat\"] = torch.FloatTensor(data[\"feature\"])\n            g.edata[\"sp_dist\"] = torch.FloatTensor(edge_weight_list[epoch])\n            g_data = {\n                \"graph\": g.to(device),\n                \"anchor_eid\": anchor_eid_list[epoch],\n                \"dists_max\": dist_max_list[epoch],\n            }\n\n        
    train_model(data, model, loss_func, optimizer, device, g_data)\n\n            loss_train, auc_train, auc_val, auc_test = eval_model(\n                data, g_data, model, loss_func, device\n            )\n            if auc_val > best_auc_val:\n                best_auc_val = auc_val\n                best_auc_test = auc_test\n\n            if epoch % args.epoch_log == 0:\n                print(\n                    repeat,\n                    epoch,\n                    \"Loss {:.4f}\".format(loss_train),\n                    \"Train AUC: {:.4f}\".format(auc_train),\n                    \"Val AUC: {:.4f}\".format(auc_val),\n                    \"Test AUC: {:.4f}\".format(auc_test),\n                    \"Best Val AUC: {:.4f}\".format(best_auc_val),\n                    \"Best Test AUC: {:.4f}\".format(best_auc_test),\n                )\n\n        results.append(best_auc_test)\n\n    results = np.array(results)\n    results_mean = np.mean(results).round(6)\n    results_std = np.std(results).round(6)\n    print(\"-----------------Final-------------------\")\n    print(results_mean, results_std)\n\n    with open(\n        \"results/{}_{}_{}.txt\".format(\n            [\"Transductive\", \"Inductive\"][args.inductive],\n            args.task,\n            args.k_hop_dist,\n        ),\n        \"w\",\n    ) as f:\n        f.write(\"{}, {}\\n\".format(results_mean, results_std))\n\n\nif __name__ == \"__main__\":\n    from argparse import ArgumentParser\n\n    parser = ArgumentParser()\n    parser.add_argument(\n        \"--task\", type=str, default=\"link\", choices=[\"link\", \"link_pair\"]\n    )\n    parser.add_argument(\n        \"--inductive\",\n        action=\"store_true\",\n        help=\"Inductive learning or transductive learning\",\n    )\n    parser.add_argument(\n        \"--k_hop_dist\",\n        default=-1,\n        type=int,\n        help=\"K-hop shortest path distance, -1 means exact shortest path.\",\n    )\n\n    parser.add_argument(\"--epoch_num\", 
type=int, default=2000)\n    parser.add_argument(\"--repeat_num\", type=int, default=10)\n    parser.add_argument(\"--epoch_log\", type=int, default=100)\n\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/P-GNN/model.py",
    "content": "import dgl.function as fn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass PGNN_layer(nn.Module):\n    def __init__(self, input_dim, output_dim):\n        super(PGNN_layer, self).__init__()\n        self.input_dim = input_dim\n\n        self.linear_hidden_u = nn.Linear(input_dim, output_dim)\n        self.linear_hidden_v = nn.Linear(input_dim, output_dim)\n        self.linear_out_position = nn.Linear(output_dim, 1)\n        self.act = nn.ReLU()\n\n    def forward(self, graph, feature, anchor_eid, dists_max):\n        with graph.local_scope():\n            u_feat = self.linear_hidden_u(feature)\n            v_feat = self.linear_hidden_v(feature)\n            graph.srcdata.update({\"u_feat\": u_feat})\n            graph.dstdata.update({\"v_feat\": v_feat})\n\n            graph.apply_edges(fn.u_mul_e(\"u_feat\", \"sp_dist\", \"u_message\"))\n            graph.apply_edges(fn.v_add_e(\"v_feat\", \"u_message\", \"message\"))\n\n            messages = torch.index_select(\n                graph.edata[\"message\"],\n                0,\n                torch.LongTensor(anchor_eid).to(feature.device),\n            )\n            messages = messages.reshape(\n                dists_max.shape[0], dists_max.shape[1], messages.shape[-1]\n            )\n\n            messages = self.act(messages)  # n*m*d\n\n            out_position = self.linear_out_position(messages).squeeze(\n                -1\n            )  # n*m_out\n            out_structure = torch.mean(messages, dim=1)  # n*d\n\n            return out_position, out_structure\n\n\nclass PGNN(nn.Module):\n    def __init__(self, input_dim, feature_dim=32, dropout=0.5):\n        super(PGNN, self).__init__()\n        self.dropout = nn.Dropout(dropout)\n\n        self.linear_pre = nn.Linear(input_dim, feature_dim)\n        self.conv_first = PGNN_layer(feature_dim, feature_dim)\n        self.conv_out = PGNN_layer(feature_dim, feature_dim)\n\n    def forward(self, data):\n        x 
= data[\"graph\"].ndata[\"feat\"]\n        graph = data[\"graph\"]\n        x = self.linear_pre(x)\n        x_position, x = self.conv_first(\n            graph, x, data[\"anchor_eid\"], data[\"dists_max\"]\n        )\n\n        x = self.dropout(x)\n        x_position, x = self.conv_out(\n            graph, x, data[\"anchor_eid\"], data[\"dists_max\"]\n        )\n        x_position = F.normalize(x_position, p=2, dim=-1)\n        return x_position\n"
  },
  {
    "path": "examples/pytorch/P-GNN/utils.py",
    "content": "import multiprocessing as mp\nimport random\nfrom multiprocessing import get_context\n\nimport networkx as nx\nimport numpy as np\nimport torch\nfrom tqdm.auto import tqdm\n\n\ndef get_communities(remove_feature):\n    community_size = 20\n\n    # Create 20 cliques (communities) of size 20,\n    # then rewire a single edge in each clique to a node in an adjacent clique\n    graph = nx.connected_caveman_graph(20, community_size)\n\n    # Randomly rewire 1% edges\n    node_list = list(graph.nodes)\n    for u, v in graph.edges():\n        if random.random() < 0.01:\n            x = random.choice(node_list)\n            if graph.has_edge(u, x):\n                continue\n            graph.remove_edge(u, v)\n            graph.add_edge(u, x)\n\n    # remove self-loops\n    graph.remove_edges_from(nx.selfloop_edges(graph))\n    edge_index = np.array(list(graph.edges))\n    # Add (i, j) for an edge (j, i)\n    edge_index = np.concatenate((edge_index, edge_index[:, ::-1]), axis=0)\n    edge_index = torch.from_numpy(edge_index).long().permute(1, 0)\n\n    n = graph.number_of_nodes()\n    label = np.zeros((n, n), dtype=int)\n    for u in node_list:\n        # the node IDs are simply consecutive integers from 0\n        for v in range(u):\n            if u // community_size == v // community_size:\n                label[u, v] = 1\n\n    if remove_feature:\n        feature = torch.ones((n, 1))\n    else:\n        rand_order = np.random.permutation(n)\n        feature = np.identity(n)[:, rand_order]\n\n    data = {\n        \"edge_index\": edge_index,\n        \"feature\": feature,\n        \"positive_edges\": np.stack(np.nonzero(label)),\n        \"num_nodes\": feature.shape[0],\n    }\n\n    return data\n\n\ndef to_single_directed(edges):\n    edges_new = np.zeros((2, edges.shape[1] // 2), dtype=int)\n    j = 0\n    for i in range(edges.shape[1]):\n        if edges[0, i] < edges[1, i]:\n            edges_new[:, j] = edges[:, i]\n            j += 1\n\n    return 
edges_new\n\n\n# each node at least remain in the new graph\ndef split_edges(p, edges, data, non_train_ratio=0.2):\n    e = edges.shape[1]\n    edges = edges[:, np.random.permutation(e)]\n    split1 = int((1 - non_train_ratio) * e)\n    split2 = int((1 - non_train_ratio / 2) * e)\n\n    data.update(\n        {\n            \"{}_edges_train\".format(p): edges[:, :split1],  # 80%\n            \"{}_edges_val\".format(p): edges[:, split1:split2],  # 10%\n            \"{}_edges_test\".format(p): edges[:, split2:],  # 10%\n        }\n    )\n\n\ndef to_bidirected(edges):\n    return np.concatenate((edges, edges[::-1, :]), axis=-1)\n\n\ndef get_negative_edges(positive_edges, num_nodes, num_negative_edges):\n    positive_edge_set = []\n    positive_edges = to_bidirected(positive_edges)\n    for i in range(positive_edges.shape[1]):\n        positive_edge_set.append(tuple(positive_edges[:, i]))\n    positive_edge_set = set(positive_edge_set)\n\n    negative_edges = np.zeros(\n        (2, num_negative_edges), dtype=positive_edges.dtype\n    )\n    for i in range(num_negative_edges):\n        while True:\n            mask_temp = tuple(\n                np.random.choice(num_nodes, size=(2,), replace=False)\n            )\n            if mask_temp not in positive_edge_set:\n                negative_edges[:, i] = mask_temp\n                break\n\n    return negative_edges\n\n\ndef get_pos_neg_edges(data, infer_link_positive=True):\n    if infer_link_positive:\n        data[\"positive_edges\"] = to_single_directed(data[\"edge_index\"].numpy())\n    split_edges(\"positive\", data[\"positive_edges\"], data)\n\n    # resample edge mask link negative\n    negative_edges = get_negative_edges(\n        data[\"positive_edges\"],\n        data[\"num_nodes\"],\n        num_negative_edges=data[\"positive_edges\"].shape[1],\n    )\n    split_edges(\"negative\", negative_edges, data)\n\n    return data\n\n\ndef shortest_path(graph, node_range, cutoff):\n    dists_dict = {}\n    for node in 
tqdm(node_range, leave=False):\n        dists_dict[node] = nx.single_source_shortest_path_length(\n            graph, node, cutoff\n        )\n    return dists_dict\n\n\ndef merge_dicts(dicts):\n    result = {}\n    for dictionary in dicts:\n        result.update(dictionary)\n    return result\n\n\ndef all_pairs_shortest_path(graph, cutoff=None, num_workers=4):\n    nodes = list(graph.nodes)\n    random.shuffle(nodes)\n    pool = mp.Pool(processes=num_workers)\n    interval_size = len(nodes) / num_workers\n    results = [\n        pool.apply_async(\n            shortest_path,\n            args=(\n                graph,\n                nodes[int(interval_size * i) : int(interval_size * (i + 1))],\n                cutoff,\n            ),\n        )\n        for i in range(num_workers)\n    ]\n    output = [p.get() for p in results]\n    dists_dict = merge_dicts(output)\n    pool.close()\n    pool.join()\n    return dists_dict\n\n\ndef precompute_dist_data(edge_index, num_nodes, approximate=0):\n    \"\"\"\n    Here dist is 1/real_dist, higher actually means closer, 0 means disconnected\n    :return:\n    \"\"\"\n    graph = nx.Graph()\n    edge_list = edge_index.transpose(1, 0).tolist()\n    graph.add_edges_from(edge_list)\n\n    n = num_nodes\n    dists_array = np.zeros((n, n))\n    dists_dict = all_pairs_shortest_path(\n        graph, cutoff=approximate if approximate > 0 else None\n    )\n    node_list = graph.nodes()\n    for node_i in node_list:\n        shortest_dist = dists_dict[node_i]\n        for node_j in node_list:\n            dist = shortest_dist.get(node_j, -1)\n            if dist != -1:\n                dists_array[node_i, node_j] = 1 / (dist + 1)\n    return dists_array\n\n\ndef get_dataset(args):\n    # Generate graph data\n    data_info = get_communities(args.inductive)\n    # Get positive and negative edges\n    data = get_pos_neg_edges(\n        data_info, infer_link_positive=True if args.task == \"link\" else False\n    )\n    # Pre-compute 
shortest path length\n    if args.task == \"link\":\n        dists_removed = precompute_dist_data(\n            data[\"positive_edges_train\"],\n            data[\"num_nodes\"],\n            approximate=args.k_hop_dist,\n        )\n        data[\"dists\"] = torch.from_numpy(dists_removed).float()\n        data[\"edge_index\"] = torch.from_numpy(\n            to_bidirected(data[\"positive_edges_train\"])\n        ).long()\n    else:\n        dists = precompute_dist_data(\n            data[\"edge_index\"].numpy(),\n            data[\"num_nodes\"],\n            approximate=args.k_hop_dist,\n        )\n        data[\"dists\"] = torch.from_numpy(dists).float()\n\n    return data\n\n\ndef get_anchors(n):\n    \"\"\"Get a list of NumPy arrays, each of them is an anchor node set\"\"\"\n    m = int(np.log2(n))\n    anchor_set_id = []\n    for i in range(m):\n        anchor_size = int(n / np.exp2(i + 1))\n        for _ in range(m):\n            anchor_set_id.append(\n                np.random.choice(n, size=anchor_size, replace=False)\n            )\n    return anchor_set_id\n\n\ndef get_dist_max(anchor_set_id, dist):\n    # N x K, N is number of nodes, K is the number of anchor sets\n    dist_max = torch.zeros((dist.shape[0], len(anchor_set_id)))\n    dist_argmax = torch.zeros((dist.shape[0], len(anchor_set_id))).long()\n    for i in range(len(anchor_set_id)):\n        temp_id = torch.as_tensor(anchor_set_id[i], dtype=torch.long)\n        # Get reciprocal of shortest distance to each node in the i-th anchor set\n        dist_temp = torch.index_select(dist, 1, temp_id)\n        # For each node in the graph, find its closest anchor node in the set\n        # and the reciprocal of shortest distance\n        dist_max_temp, dist_argmax_temp = torch.max(dist_temp, dim=-1)\n        dist_max[:, i] = dist_max_temp\n        dist_argmax[:, i] = torch.index_select(temp_id, 0, dist_argmax_temp)\n    return dist_max, dist_argmax\n\n\ndef get_a_graph(dists_max, dists_argmax):\n    src = 
[]\n    dst = []\n    real_src = []\n    real_dst = []\n    edge_weight = []\n    dists_max = dists_max.numpy()\n    for i in range(dists_max.shape[0]):\n        # Get unique closest anchor nodes for node i across all anchor sets\n        tmp_dists_argmax, tmp_dists_argmax_idx = np.unique(\n            dists_argmax[i, :], True\n        )\n        src.extend([i] * tmp_dists_argmax.shape[0])\n        real_src.extend([i] * dists_argmax[i, :].shape[0])\n        real_dst.extend(list(dists_argmax[i, :].numpy()))\n        dst.extend(list(tmp_dists_argmax))\n        edge_weight.extend(dists_max[i, tmp_dists_argmax_idx].tolist())\n    eid_dict = {(u, v): i for i, (u, v) in enumerate(list(zip(dst, src)))}\n    anchor_eid = [eid_dict.get((u, v)) for u, v in zip(real_dst, real_src)]\n    g = (dst, src)\n    return g, anchor_eid, edge_weight\n\n\ndef get_graphs(data, anchor_sets):\n    graphs = []\n    anchor_eids = []\n    dists_max_list = []\n    edge_weights = []\n    for anchor_set in tqdm(anchor_sets, leave=False):\n        dists_max, dists_argmax = get_dist_max(anchor_set, data[\"dists\"])\n        g, anchor_eid, edge_weight = get_a_graph(dists_max, dists_argmax)\n        graphs.append(g)\n        anchor_eids.append(anchor_eid)\n        dists_max_list.append(dists_max)\n        edge_weights.append(edge_weight)\n\n    return graphs, anchor_eids, dists_max_list, edge_weights\n\n\ndef merge_result(outputs):\n    graphs = []\n    anchor_eids = []\n    dists_max_list = []\n    edge_weights = []\n\n    for g, anchor_eid, dists_max, edge_weight in outputs:\n        graphs.extend(g)\n        anchor_eids.extend(anchor_eid)\n        dists_max_list.extend(dists_max)\n        edge_weights.extend(edge_weight)\n\n    return graphs, anchor_eids, dists_max_list, edge_weights\n\n\ndef preselect_anchor(data, args, num_workers=4):\n    pool = get_context(\"spawn\").Pool(processes=num_workers)\n    # Pre-compute anchor sets, a collection of anchor sets per epoch\n    anchor_set_ids = [\n     
   get_anchors(data[\"num_nodes\"]) for _ in range(args.epoch_num)\n    ]\n    interval_size = len(anchor_set_ids) / num_workers\n    results = [\n        pool.apply_async(\n            get_graphs,\n            args=(\n                data,\n                anchor_set_ids[\n                    int(interval_size * i) : int(interval_size * (i + 1))\n                ],\n            ),\n        )\n        for i in range(num_workers)\n    ]\n\n    output = [p.get() for p in results]\n    graphs, anchor_eids, dists_max_list, edge_weights = merge_result(output)\n    pool.close()\n    pool.join()\n\n    return graphs, anchor_eids, dists_max_list, edge_weights\n"
  },
  {
    "path": "examples/pytorch/TAHIN/TAHIN.py",
    "content": "import dgl\nimport dgl.function as fn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import GATConv\n\n\n# Semantic attention in the metapath-based aggregation (the same as that in the HAN)\nclass SemanticAttention(nn.Module):\n    def __init__(self, in_size, hidden_size=128):\n        super(SemanticAttention, self).__init__()\n\n        self.project = nn.Sequential(\n            nn.Linear(in_size, hidden_size),\n            nn.Tanh(),\n            nn.Linear(hidden_size, 1, bias=False),\n        )\n\n    def forward(self, z):\n        \"\"\"\n        Shape of z: (N, M , D*K)\n        N: number of nodes\n        M: number of metapath patterns\n        D: hidden_size\n        K: number of heads\n        \"\"\"\n        w = self.project(z).mean(0)  # (M, 1)\n        beta = torch.softmax(w, dim=0)  # (M, 1)\n        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)\n\n        return (beta * z).sum(1)  # (N, D * K)\n\n\n# Metapath-based aggregation (the same as the HANLayer)\nclass HANLayer(nn.Module):\n    def __init__(\n        self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout\n    ):\n        super(HANLayer, self).__init__()\n\n        # One GAT layer for each meta path based adjacency matrix\n        self.gat_layers = nn.ModuleList()\n        for i in range(len(meta_path_patterns)):\n            self.gat_layers.append(\n                GATConv(\n                    in_size,\n                    out_size,\n                    layer_num_heads,\n                    dropout,\n                    dropout,\n                    activation=F.elu,\n                    allow_zero_in_degree=True,\n                )\n            )\n        self.semantic_attention = SemanticAttention(\n            in_size=out_size * layer_num_heads\n        )\n        self.meta_path_patterns = list(\n            tuple(meta_path_pattern) for meta_path_pattern in meta_path_patterns\n        )\n\n        
self._cached_graph = None\n        self._cached_coalesced_graph = {}\n\n    def forward(self, g, h):\n        semantic_embeddings = []\n        # obtain metapath reachable graph\n        if self._cached_graph is None or self._cached_graph is not g:\n            self._cached_graph = g\n            self._cached_coalesced_graph.clear()\n            for meta_path_pattern in self.meta_path_patterns:\n                self._cached_coalesced_graph[\n                    meta_path_pattern\n                ] = dgl.metapath_reachable_graph(g, meta_path_pattern)\n\n        for i, meta_path_pattern in enumerate(self.meta_path_patterns):\n            new_g = self._cached_coalesced_graph[meta_path_pattern]\n            semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))\n        semantic_embeddings = torch.stack(\n            semantic_embeddings, dim=1\n        )  # (N, M, D * K)\n\n        return self.semantic_attention(semantic_embeddings)  # (N, D * K)\n\n\n# Relational neighbor aggregation\nclass RelationalAGG(nn.Module):\n    def __init__(self, g, in_size, out_size, dropout=0.1):\n        super(RelationalAGG, self).__init__()\n        self.in_size = in_size\n        self.out_size = out_size\n\n        # Transform weights for different types of edges\n        self.W_T = nn.ModuleDict(\n            {\n                name: nn.Linear(in_size, out_size, bias=False)\n                for name in g.etypes\n            }\n        )\n\n        # Attention weights for different types of edges\n        self.W_A = nn.ModuleDict(\n            {name: nn.Linear(out_size, 1, bias=False) for name in g.etypes}\n        )\n\n        # layernorm\n        self.layernorm = nn.LayerNorm(out_size)\n\n        # dropout layer\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, g, feat_dict):\n        funcs = {}\n        for srctype, etype, dsttype in g.canonical_etypes:\n            g.nodes[dsttype].data[\"h\"] = feat_dict[\n                dsttype\n            ]  # 
nodes' original feature\n            g.nodes[srctype].data[\"h\"] = feat_dict[srctype]\n            g.nodes[srctype].data[\"t_h\"] = self.W_T[etype](\n                feat_dict[srctype]\n            )  # src nodes' transformed feature\n\n            # compute the attention numerator (exp)\n            g.apply_edges(fn.u_mul_v(\"t_h\", \"h\", \"x\"), etype=etype)\n            g.edges[etype].data[\"x\"] = torch.exp(\n                self.W_A[etype](g.edges[etype].data[\"x\"])\n            )\n\n            # first update to compute the attention denominator (\\sum exp)\n            funcs[etype] = (fn.copy_e(\"x\", \"m\"), fn.sum(\"m\", \"att\"))\n        g.multi_update_all(funcs, \"sum\")\n\n        funcs = {}\n        for srctype, etype, dsttype in g.canonical_etypes:\n            g.apply_edges(\n                fn.e_div_v(\"x\", \"att\", \"att\"), etype=etype\n            )  # compute attention weights (numerator/denominator)\n            funcs[etype] = (\n                fn.u_mul_e(\"h\", \"att\", \"m\"),\n                fn.sum(\"m\", \"h\"),\n            )  # \\sum(h0*att) -> h1\n        # second update to obtain h1\n        g.multi_update_all(funcs, \"sum\")\n\n        # apply activation, layernorm, and dropout\n        feat_dict = {}\n        for ntype in g.ntypes:\n            feat_dict[ntype] = self.dropout(\n                self.layernorm(F.relu_(g.nodes[ntype].data[\"h\"]))\n            )  # apply activation, layernorm, and dropout\n\n        return feat_dict\n\n\nclass TAHIN(nn.Module):\n    def __init__(\n        self, g, meta_path_patterns, in_size, out_size, num_heads, dropout\n    ):\n        super(TAHIN, self).__init__()\n\n        # embeddings for different types of nodes, h0\n        self.initializer = nn.init.xavier_uniform_\n        self.feature_dict = nn.ParameterDict(\n            {\n                ntype: nn.Parameter(\n                    self.initializer(torch.empty(g.num_nodes(ntype), in_size))\n                )\n                for ntype 
in g.ntypes\n            }\n        )\n\n        # relational neighbor aggregation, this produces h1\n        self.RelationalAGG = RelationalAGG(g, in_size, out_size)\n\n        # metapath-based aggregation modules for user and item, this produces h2\n        self.meta_path_patterns = meta_path_patterns\n        # one HANLayer for user, one HANLayer for item\n        self.hans = nn.ModuleDict(\n            {\n                key: HANLayer(value, in_size, out_size, num_heads, dropout)\n                for key, value in self.meta_path_patterns.items()\n            }\n        )\n\n        # layers to combine h0, h1, and h2\n        # used to update node embeddings\n        self.user_layer1 = nn.Linear(\n            (num_heads + 1) * out_size, out_size, bias=True\n        )\n        self.user_layer2 = nn.Linear(2 * out_size, out_size, bias=True)\n        self.item_layer1 = nn.Linear(\n            (num_heads + 1) * out_size, out_size, bias=True\n        )\n        self.item_layer2 = nn.Linear(2 * out_size, out_size, bias=True)\n\n        # layernorm\n        self.layernorm = nn.LayerNorm(out_size)\n\n        # network to score the node pairs\n        self.pred = nn.Linear(out_size, out_size)\n        self.dropout = nn.Dropout(dropout)\n        self.fc = nn.Linear(out_size, 1)\n\n    def forward(self, g, user_key, item_key, user_idx, item_idx):\n        # relational neighbor aggregation, h1\n        h1 = self.RelationalAGG(g, self.feature_dict)\n\n        # metapath-based aggregation, h2\n        h2 = {}\n        for key in self.meta_path_patterns.keys():\n            h2[key] = self.hans[key](g, self.feature_dict[key])\n\n        # update node embeddings\n        user_emb = torch.cat((h1[user_key], h2[user_key]), 1)\n        item_emb = torch.cat((h1[item_key], h2[item_key]), 1)\n        user_emb = self.user_layer1(user_emb)\n        item_emb = self.item_layer1(item_emb)\n        user_emb = self.user_layer2(\n            torch.cat((user_emb, self.feature_dict[user_key]), 
1)\n        )\n        item_emb = self.item_layer2(\n            torch.cat((item_emb, self.feature_dict[item_key]), 1)\n        )\n\n        # Relu\n        user_emb = F.relu_(user_emb)\n        item_emb = F.relu_(item_emb)\n\n        # layer norm\n        user_emb = self.layernorm(user_emb)\n        item_emb = self.layernorm(item_emb)\n\n        # obtain users/items embeddings and their interactions\n        user_feat = user_emb[user_idx]\n        item_feat = item_emb[item_idx]\n        interaction = user_feat * item_feat\n\n        # score the node pairs\n        pred = self.pred(interaction)\n        pred = self.dropout(pred)  # dropout\n        pred = self.fc(pred)\n        pred = torch.sigmoid(pred)\n\n        return pred.squeeze(1)\n"
  },
  {
    "path": "examples/pytorch/TAHIN/data_loader.py",
    "content": "import os\nimport pickle as pkl\nimport random\n\nimport dgl\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader, Dataset\n\n\n# Split data into train/eval/test\ndef split_data(hg, etype_name):\n    src, dst = hg.edges(etype=etype_name)\n    user_item_src = src.numpy().tolist()\n    user_item_dst = dst.numpy().tolist()\n\n    num_link = len(user_item_src)\n    pos_label = [1] * num_link\n    pos_data = list(zip(user_item_src, user_item_dst, pos_label))\n\n    ui_adj = np.array(hg.adj_external(etype=etype_name).to_dense())\n    full_idx = np.where(ui_adj == 0)\n\n    sample = random.sample(range(0, len(full_idx[0])), num_link)\n    neg_label = [0] * num_link\n    neg_data = list(zip(full_idx[0][sample], full_idx[1][sample], neg_label))\n\n    full_data = pos_data + neg_data\n    random.shuffle(full_data)\n\n    train_size = int(len(full_data) * 0.6)\n    eval_size = int(len(full_data) * 0.2)\n    test_size = len(full_data) - train_size - eval_size\n    train_data = full_data[:train_size]\n    eval_data = full_data[train_size : train_size + eval_size]\n    test_data = full_data[\n        train_size + eval_size : train_size + eval_size + test_size\n    ]\n    train_data = np.array(train_data)\n    eval_data = np.array(eval_data)\n    test_data = np.array(test_data)\n\n    return train_data, eval_data, test_data\n\n\ndef process_amazon(root_path):\n    # User-Item 3584 2753 50903 UIUI\n    # Item-View 2753 3857 5694 UIVI\n    # Item-Brand 2753 334 2753 UIBI\n    # Item-Category 2753 22 5508 UICI\n\n    # Construct graph from raw data.\n    # load data of amazon\n    data_path = os.path.join(root_path, \"Amazon\")\n    if not (os.path.exists(data_path)):\n        print(\n            \"Can not find amazon in {}, please download the dataset first.\".format(\n                data_path\n            )\n        )\n\n    # item_view\n    item_view_src = []\n    item_view_dst = []\n    with open(os.path.join(data_path, 
\"item_view.dat\")) as fin:\n        for line in fin.readlines():\n            _line = line.strip().split(\",\")\n            item, view = int(_line[0]), int(_line[1])\n            item_view_src.append(item)\n            item_view_dst.append(view)\n\n    # user_item\n    user_item_src = []\n    user_item_dst = []\n    with open(os.path.join(data_path, \"user_item.dat\")) as fin:\n        for line in fin.readlines():\n            _line = line.strip().split(\"\\t\")\n            user, item, rate = int(_line[0]), int(_line[1]), int(_line[2])\n            if rate > 3:\n                user_item_src.append(user)\n                user_item_dst.append(item)\n\n    # item_brand\n    item_brand_src = []\n    item_brand_dst = []\n    with open(os.path.join(data_path, \"item_brand.dat\")) as fin:\n        for line in fin.readlines():\n            _line = line.strip().split(\",\")\n            item, brand = int(_line[0]), int(_line[1])\n            item_brand_src.append(item)\n            item_brand_dst.append(brand)\n\n    # item_category\n    item_category_src = []\n    item_category_dst = []\n    with open(os.path.join(data_path, \"item_category.dat\")) as fin:\n        for line in fin.readlines():\n            _line = line.strip().split(\",\")\n            item, category = int(_line[0]), int(_line[1])\n            item_category_src.append(item)\n            item_category_dst.append(category)\n\n    # build graph\n    hg = dgl.heterograph(\n        {\n            (\"item\", \"iv\", \"view\"): (item_view_src, item_view_dst),\n            (\"view\", \"vi\", \"item\"): (item_view_dst, item_view_src),\n            (\"user\", \"ui\", \"item\"): (user_item_src, user_item_dst),\n            (\"item\", \"iu\", \"user\"): (user_item_dst, user_item_src),\n            (\"item\", \"ib\", \"brand\"): (item_brand_src, item_brand_dst),\n            (\"brand\", \"bi\", \"item\"): (item_brand_dst, item_brand_src),\n            (\"item\", \"ic\", \"category\"): (item_category_src, 
item_category_dst),\n            (\"category\", \"ci\", \"item\"): (item_category_dst, item_category_src),\n        }\n    )\n\n    print(\"Graph constructed.\")\n\n    # Split data into train/eval/test\n    train_data, eval_data, test_data = split_data(hg, \"ui\")\n\n    # delete the positive edges in eval/test data in the original graph\n    train_pos = np.nonzero(train_data[:, 2])\n    train_pos_idx = train_pos[0]\n    user_item_src_processed = train_data[train_pos_idx, 0]\n    user_item_dst_processed = train_data[train_pos_idx, 1]\n    edges_dict = {\n        (\"item\", \"iv\", \"view\"): (item_view_src, item_view_dst),\n        (\"view\", \"vi\", \"item\"): (item_view_dst, item_view_src),\n        (\"user\", \"ui\", \"item\"): (\n            user_item_src_processed,\n            user_item_dst_processed,\n        ),\n        (\"item\", \"iu\", \"user\"): (\n            user_item_dst_processed,\n            user_item_src_processed,\n        ),\n        (\"item\", \"ib\", \"brand\"): (item_brand_src, item_brand_dst),\n        (\"brand\", \"bi\", \"item\"): (item_brand_dst, item_brand_src),\n        (\"item\", \"ic\", \"category\"): (item_category_src, item_category_dst),\n        (\"category\", \"ci\", \"item\"): (item_category_dst, item_category_src),\n    }\n    nodes_dict = {\n        \"user\": hg.num_nodes(\"user\"),\n        \"item\": hg.num_nodes(\"item\"),\n        \"view\": hg.num_nodes(\"view\"),\n        \"brand\": hg.num_nodes(\"brand\"),\n        \"category\": hg.num_nodes(\"category\"),\n    }\n    hg_processed = dgl.heterograph(\n        data_dict=edges_dict, num_nodes_dict=nodes_dict\n    )\n    print(\"Graph processed.\")\n\n    # save the processed data\n    with open(os.path.join(root_path, \"amazon_hg.pkl\"), \"wb\") as file:\n        pkl.dump(hg_processed, file)\n    with open(os.path.join(root_path, \"amazon_train.pkl\"), \"wb\") as file:\n        pkl.dump(train_data, file)\n    with open(os.path.join(root_path, \"amazon_test.pkl\"), \"wb\") 
as file:\n        pkl.dump(test_data, file)\n    with open(os.path.join(root_path, \"amazon_eval.pkl\"), \"wb\") as file:\n        pkl.dump(eval_data, file)\n\n    return hg_processed, train_data, eval_data, test_data\n\n\ndef process_movielens(root_path):\n    # User-Movie 943 1682 100000 UMUM\n    # User-Age 943 8 943 UAUM\n    # User-Occupation 943 21 943 UOUM\n    # Movie-Genre 1682 18 2861 UMGM\n\n    data_path = os.path.join(root_path, \"Movielens\")\n    if not (os.path.exists(data_path)):\n        print(\n            \"Can not find movielens in {}, please download the dataset first.\".format(\n                data_path\n            )\n        )\n\n    # Construct graph from raw data.\n    # movie_genre\n    movie_genre_src = []\n    movie_genre_dst = []\n    with open(os.path.join(data_path, \"movie_genre.dat\")) as fin:\n        for line in fin.readlines():\n            _line = line.strip().split(\"\\t\")\n            movie, genre = int(_line[0]), int(_line[1])\n            movie_genre_src.append(movie)\n            movie_genre_dst.append(genre)\n\n    # user_movie\n    user_movie_src = []\n    user_movie_dst = []\n    with open(os.path.join(data_path, \"user_movie.dat\")) as fin:\n        for line in fin.readlines():\n            _line = line.strip().split(\"\\t\")\n            user, item, rate = int(_line[0]), int(_line[1]), int(_line[2])\n            if rate > 3:\n                user_movie_src.append(user)\n                user_movie_dst.append(item)\n\n    # user_occupation\n    user_occupation_src = []\n    user_occupation_dst = []\n    with open(os.path.join(data_path, \"user_occupation.dat\")) as fin:\n        for line in fin.readlines():\n            _line = line.strip().split(\"\\t\")\n            user, occupation = int(_line[0]), int(_line[1])\n            user_occupation_src.append(user)\n            user_occupation_dst.append(occupation)\n\n    # user_age\n    user_age_src = []\n    user_age_dst = []\n    with open(os.path.join(data_path, 
\"user_age.dat\")) as fin:\n        for line in fin.readlines():\n            _line = line.strip().split(\"\\t\")\n            user, age = int(_line[0]), int(_line[1])\n            user_age_src.append(user)\n            user_age_dst.append(age)\n\n    # build graph\n    hg = dgl.heterograph(\n        {\n            (\"movie\", \"mg\", \"genre\"): (movie_genre_src, movie_genre_dst),\n            (\"genre\", \"gm\", \"movie\"): (movie_genre_dst, movie_genre_src),\n            (\"user\", \"um\", \"movie\"): (user_movie_src, user_movie_dst),\n            (\"movie\", \"mu\", \"user\"): (user_movie_dst, user_movie_src),\n            (\"user\", \"uo\", \"occupation\"): (\n                user_occupation_src,\n                user_occupation_dst,\n            ),\n            (\"occupation\", \"ou\", \"user\"): (\n                user_occupation_dst,\n                user_occupation_src,\n            ),\n            (\"user\", \"ua\", \"age\"): (user_age_src, user_age_dst),\n            (\"age\", \"au\", \"user\"): (user_age_dst, user_age_src),\n        }\n    )\n\n    print(\"Graph constructed.\")\n\n    # Split data into train/eval/test\n    train_data, eval_data, test_data = split_data(hg, \"um\")\n\n    # delete the positive edges in eval/test data in the original graph\n    train_pos = np.nonzero(train_data[:, 2])\n    train_pos_idx = train_pos[0]\n    user_movie_src_processed = train_data[train_pos_idx, 0]\n    user_movie_dst_processed = train_data[train_pos_idx, 1]\n    edges_dict = {\n        (\"movie\", \"mg\", \"genre\"): (movie_genre_src, movie_genre_dst),\n        (\"genre\", \"gm\", \"movie\"): (movie_genre_dst, movie_genre_src),\n        (\"user\", \"um\", \"movie\"): (\n            user_movie_src_processed,\n            user_movie_dst_processed,\n        ),\n        (\"movie\", \"mu\", \"user\"): (\n            user_movie_dst_processed,\n            user_movie_src_processed,\n        ),\n        (\"user\", \"uo\", \"occupation\"): (\n            
user_occupation_src,\n            user_occupation_dst,\n        ),\n        (\"occupation\", \"ou\", \"user\"): (\n            user_occupation_dst,\n            user_occupation_src,\n        ),\n        (\"user\", \"ua\", \"age\"): (user_age_src, user_age_dst),\n        (\"age\", \"au\", \"user\"): (user_age_dst, user_age_src),\n    }\n    nodes_dict = {\n        \"user\": hg.num_nodes(\"user\"),\n        \"movie\": hg.num_nodes(\"movie\"),\n        \"genre\": hg.num_nodes(\"genre\"),\n        \"occupation\": hg.num_nodes(\"occupation\"),\n        \"age\": hg.num_nodes(\"age\"),\n    }\n    hg_processed = dgl.heterograph(\n        data_dict=edges_dict, num_nodes_dict=nodes_dict\n    )\n    print(\"Graph processed.\")\n\n    # save the processed data\n    with open(os.path.join(root_path, \"movielens_hg.pkl\"), \"wb\") as file:\n        pkl.dump(hg_processed, file)\n    with open(os.path.join(root_path, \"movielens_train.pkl\"), \"wb\") as file:\n        pkl.dump(train_data, file)\n    with open(os.path.join(root_path, \"movielens_test.pkl\"), \"wb\") as file:\n        pkl.dump(test_data, file)\n    with open(os.path.join(root_path, \"movielens_eval.pkl\"), \"wb\") as file:\n        pkl.dump(eval_data, file)\n\n    return hg_processed, train_data, eval_data, test_data\n\n\nclass MyDataset(Dataset):\n    def __init__(self, triple):\n        self.triple = triple\n        self.len = self.triple.shape[0]\n\n    def __getitem__(self, index):\n        return (\n            self.triple[index, 0],\n            self.triple[index, 1],\n            self.triple[index, 2].float(),\n        )\n\n    def __len__(self):\n        return self.len\n\n\ndef load_data(dataset, batch_size=128, num_workers=10, root_path=\"./data\"):\n    if os.path.exists(os.path.join(root_path, dataset + \"_train.pkl\")):\n        g_file = open(os.path.join(root_path, dataset + \"_hg.pkl\"), \"rb\")\n        hg = pkl.load(g_file)\n        g_file.close()\n        train_set_file = open(\n            
os.path.join(root_path, dataset + \"_train.pkl\"), \"rb\"\n        )\n        train_set = pkl.load(train_set_file)\n        train_set_file.close()\n        test_set_file = open(\n            os.path.join(root_path, dataset + \"_test.pkl\"), \"rb\"\n        )\n        test_set = pkl.load(test_set_file)\n        test_set_file.close()\n        eval_set_file = open(\n            os.path.join(root_path, dataset + \"_eval.pkl\"), \"rb\"\n        )\n        eval_set = pkl.load(eval_set_file)\n        eval_set_file.close()\n    else:\n        if dataset == \"movielens\":\n            hg, train_set, eval_set, test_set = process_movielens(root_path)\n        elif dataset == \"amazon\":\n            hg, train_set, eval_set, test_set = process_amazon(root_path)\n        else:\n            print(\"Available datasets: movielens, amazon.\")\n            raise NotImplementedError\n\n    if dataset == \"movielens\":\n        meta_paths = {\n            \"user\": [[\"um\", \"mu\"]],\n            \"movie\": [[\"mu\", \"um\"], [\"mg\", \"gm\"]],\n        }\n        user_key = \"user\"\n        item_key = \"movie\"\n    elif dataset == \"amazon\":\n        meta_paths = {\n            \"user\": [[\"ui\", \"iu\"]],\n            \"item\": [[\"iu\", \"ui\"], [\"ic\", \"ci\"], [\"ib\", \"bi\"], [\"iv\", \"vi\"]],\n        }\n        user_key = \"user\"\n        item_key = \"item\"\n    else:\n        print(\"Available datasets: movielens, amazon.\")\n        raise NotImplementedError\n\n    train_set = torch.Tensor(train_set).long()\n    eval_set = torch.Tensor(eval_set).long()\n    test_set = torch.Tensor(test_set).long()\n\n    train_set = MyDataset(train_set)\n    train_loader = DataLoader(\n        dataset=train_set,\n        batch_size=batch_size,\n        shuffle=True,\n        num_workers=num_workers,\n    )\n    eval_set = MyDataset(eval_set)\n    eval_loader = DataLoader(\n        dataset=eval_set,\n        batch_size=batch_size,\n        shuffle=True,\n        
num_workers=num_workers,\n    )\n    test_set = MyDataset(test_set)\n    test_loader = DataLoader(\n        dataset=test_set,\n        batch_size=batch_size,\n        shuffle=True,\n        num_workers=num_workers,\n    )\n\n    return (\n        hg,\n        train_loader,\n        eval_loader,\n        test_loader,\n        meta_paths,\n        user_key,\n        item_key,\n    )\n"
  },
  {
    "path": "examples/pytorch/TAHIN/main.py",
    "content": "import argparse\nimport pickle as pkl\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom data_loader import load_data\nfrom TAHIN import TAHIN\nfrom utils import (\n    evaluate_acc,\n    evaluate_auc,\n    evaluate_f1_score,\n    evaluate_logloss,\n)\n\n\ndef main(args):\n    # step 1: Check device\n    if args.gpu >= 0 and torch.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n    else:\n        device = \"cpu\"\n\n    # step 2: Load data\n    (\n        g,\n        train_loader,\n        eval_loader,\n        test_loader,\n        meta_paths,\n        user_key,\n        item_key,\n    ) = load_data(args.dataset, args.batch, args.num_workers, args.path)\n    g = g.to(device)\n    print(\"Data loaded.\")\n\n    # step 3: Create model and training components\n    model = TAHIN(\n        g, meta_paths, args.in_size, args.out_size, args.num_heads, args.dropout\n    )\n    model = model.to(device)\n    criterion = nn.BCELoss()\n    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)\n    print(\"Model created.\")\n\n    # step 4: Training\n    print(\"Start training.\")\n    best_acc = 0.0\n    kill_cnt = 0\n    for epoch in range(args.epochs):\n        # Training and validation using a full graph\n        model.train()\n        train_loss = []\n        for step, batch in enumerate(train_loader):\n            user, item, label = [_.to(device) for _ in batch]\n            logits = model.forward(g, user_key, item_key, user, item)\n\n            # compute loss\n            tr_loss = criterion(logits, label)\n            train_loss.append(tr_loss)\n\n            # backward\n            optimizer.zero_grad()\n            tr_loss.backward()\n            optimizer.step()\n\n        train_loss = torch.stack(train_loss).sum().cpu().item()\n\n        model.eval()\n        with torch.no_grad():\n            validate_loss = []\n            validate_acc = []\n      
      for step, batch in enumerate(eval_loader):\n                user, item, label = [_.to(device) for _ in batch]\n                logits = model.forward(g, user_key, item_key, user, item)\n\n                # compute loss\n                val_loss = criterion(logits, label)\n                val_acc = evaluate_acc(\n                    logits.detach().cpu().numpy(), label.detach().cpu().numpy()\n                )\n                validate_loss.append(val_loss)\n                validate_acc.append(val_acc)\n\n            validate_loss = torch.stack(validate_loss).sum().cpu().item()\n            validate_acc = np.mean(validate_acc)\n\n            # validate\n            if validate_acc > best_acc:\n                best_acc = validate_acc\n                best_epoch = epoch\n                torch.save(model.state_dict(), \"TAHIN\" + \"_\" + args.dataset)\n                kill_cnt = 0\n                print(\"saving model...\")\n            else:\n                kill_cnt += 1\n                if kill_cnt > args.early_stop:\n                    print(\"early stop.\")\n                    print(\"best epoch:{}\".format(best_epoch))\n                    break\n\n            print(\n                \"In epoch {}, Train Loss: {:.4f}, Valid Loss: {:.5}\\n, Valid ACC: {:.5}\".format(\n                    epoch, train_loss, validate_loss, validate_acc\n                )\n            )\n\n    # test use the best model\n    model.eval()\n    with torch.no_grad():\n        model.load_state_dict(\n            torch.load(\"TAHIN\" + \"_\" + args.dataset, weights_only=False)\n        )\n        test_loss = []\n        test_acc = []\n        test_auc = []\n        test_f1 = []\n        test_logloss = []\n        for step, batch in enumerate(test_loader):\n            user, item, label = [_.to(device) for _ in batch]\n            logits = model.forward(g, user_key, item_key, user, item)\n\n            # compute loss\n            loss = criterion(logits, label)\n            acc = 
evaluate_acc(\n                logits.detach().cpu().numpy(), label.detach().cpu().numpy()\n            )\n            auc = evaluate_auc(\n                logits.detach().cpu().numpy(), label.detach().cpu().numpy()\n            )\n            f1 = evaluate_f1_score(\n                logits.detach().cpu().numpy(), label.detach().cpu().numpy()\n            )\n            log_loss = evaluate_logloss(\n                logits.detach().cpu().numpy(), label.detach().cpu().numpy()\n            )\n\n            test_loss.append(loss)\n            test_acc.append(acc)\n            test_auc.append(auc)\n            test_f1.append(f1)\n            test_logloss.append(log_loss)\n\n        test_loss = torch.stack(test_loss).sum().cpu().item()\n        test_acc = np.mean(test_acc)\n        test_auc = np.mean(test_auc)\n        test_f1 = np.mean(test_f1)\n        test_logloss = np.mean(test_logloss)\n        print(\n            \"Test Loss: {:.5}\\n, Test ACC: {:.5}\\n, AUC: {:.5}\\n, F1: {:.5}\\n, Logloss: {:.5}\\n\".format(\n                test_loss, test_acc, test_auc, test_f1, test_logloss\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Parser For Arguments\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n\n    parser.add_argument(\n        \"--dataset\",\n        default=\"movielens\",\n        help=\"Dataset to use, default: movielens\",\n    )\n    parser.add_argument(\n        \"--path\", default=\"./data\", help=\"Path to save the data\"\n    )\n    parser.add_argument(\"--model\", default=\"TAHIN\", help=\"Model Name\")\n\n    parser.add_argument(\"--batch\", default=128, type=int, help=\"Batch size\")\n    parser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=\"0\",\n        help=\"Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0\",\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=500, help=\"Maximum number of 
epochs\"\n    )\n    parser.add_argument(\n        \"--wd\", type=float, default=0, help=\"L2 Regularization for Optimizer\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.001, help=\"Learning Rate\")\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=10,\n        help=\"Number of processes to construct batches\",\n    )\n    parser.add_argument(\n        \"--early_stop\", default=15, type=int, help=\"Patience for early stop.\"\n    )\n\n    parser.add_argument(\n        \"--in_size\",\n        default=128,\n        type=int,\n        help=\"Initial dimension size for entities.\",\n    )\n    parser.add_argument(\n        \"--out_size\",\n        default=128,\n        type=int,\n        help=\"Output dimension size for entities.\",\n    )\n\n    parser.add_argument(\n        \"--num_heads\", default=1, type=int, help=\"Number of attention heads\"\n    )\n    parser.add_argument(\"--dropout\", default=0.1, type=float, help=\"Dropout.\")\n\n    args = parser.parse_args()\n\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/TAHIN/readme.md",
    "content": "# DGL Implementation of the TAHIN\n\nThis DGL example implements the TAHIN module proposed in the paper [HCDIR](https://arxiv.org/pdf/2007.15293.pdf). Since the code and dataset have not been published yet, we implement its main idea and experiment on two other datasets.\n\nExample implementor\n----------------------\nThis example was implemented by [KounianhuaDu](https://github.com/KounianhuaDu) during her software development intern time at the AWS Shanghai AI Lab.\n\nDependencies\n----------------------\n- pytorch 1.7.1\n- dgl 0.6.0\n- scikit-learn 0.22.1\n\nDatasets\n---------------------------------------\nThe datasets used can be downloaded from [here](https://github.com/librahu/HIN-Datasets-for-Recommendation-and-Network-Embedding). For the experiments, all the positive edges are fetched and the same number of negative edges are randomly sampled. The edges are then shuffled and splitted into train/validate/test at a ratio of 6:2:2. The positive edges that appear in the validation and test sets are then removed from the original graph.\n\nThe original graph statistics:\n\n**Movielens** \n\n(Source : https://grouplens.org/datasets/movielens/)\n\n| Entity         |#Entity        |\n| :-------------:|:-------------:|\n| User           | 943           |\n| Age            | 8             |\n| Occupation     | 21            |\n| Movie          | 1,682         |\n| Genre          | 18            |\n\n| Relation            |#Relation      |\n| :-------------:     |:-------------:|\n| User - Movie        | 100,000       |\n| User - User (KNN)   | 47,150        |\n| User - Age          | 943           |\n| User - Occupation   | 943           |\n| Movie - Movie (KNN) | 82,798        |\n| Movie - Genre       | 2,861         |\n\n**Amazon** \n\n(Source : http://jmcauley.ucsd.edu/data/amazon/)\n\n| Entity         |#Entity        |\n| :-------------:|:-------------:|\n| User           | 6,170         |\n| Item           | 2,753         |\n| View           | 
3,857         |           \n| Category       | 22            |\n| Brand          | 334           |\n\n| Relation          |#Relation      |\n| :-------------:   |:-------------:|\n| User - Item       | 195,791       |\n| Item - View       | 5,694         |\n| Item - Category   | 5,508         | \n| Item - Brand      | 2,753         |\n\nHow to run\n--------------------------------\n\n```python\npython main.py --dataset amazon --gpu 0\n```\n\n\n```python\npython main.py --dataset movielens --gpu 0\n```\n\n\nPerformance\n-------------------------\n**Results**\n\n| Dataset |         Movielens        |          Amazon          |\n|---------| ------------------------ | ------------------------ |\n|  Metric |    HAN     /    TAHIN    |    HAN     /    TAHIN    |\n|   AUC   |   0.9297   /   0.9392    |   0.8470   /   0.8442    |\n|   ACC   |   0.8627   /   0.8683    |   0.7672   /   0.7619    |\n|    F1   |   0.8631   /   0.8707    |   0.7628   /   0.7499    |\n| Logloss |   0.3689   /   0.3266    |   0.5311   /   0.5150    |\n"
  },
  {
    "path": "examples/pytorch/TAHIN/utils.py",
    "content": "from sklearn.metrics import accuracy_score, f1_score, log_loss, roc_auc_score\n\n\ndef evaluate_auc(pred, label):\n    res = roc_auc_score(y_score=pred, y_true=label)\n    return res\n\n\ndef evaluate_acc(pred, label):\n    res = []\n    for _value in pred:\n        res.append(1 if _value >= 0.5 else 0)\n    return accuracy_score(y_pred=res, y_true=label)\n\n\ndef evaluate_f1_score(pred, label):\n    res = []\n    for _value in pred:\n        res.append(1 if _value >= 0.5 else 0)\n    return f1_score(y_pred=res, y_true=label)\n\n\ndef evaluate_logloss(pred, label):\n    res = log_loss(y_true=label, y_pred=pred, normalize=True)\n    return res\n"
  },
  {
    "path": "examples/pytorch/appnp/README.md",
    "content": "Predict then Propagate: Graph Neural Networks meet Personalized PageRank (APPNP)\n============\n\n- Paper link: [Predict then Propagate: Graph Neural Networks meet Personalized PageRank](https://arxiv.org/abs/1810.05997)\n- Author's code repo: [https://github.com/klicperajo/ppnp](https://github.com/klicperajo/ppnp). \n\nDependencies\n------------\n- PyTorch 0.4.1+\n- requests\n\n``bash\npip install torch requests\n``\n\nCode\n-----\nThe folder contains an implementation of APPNP (`appnp.py`).\n\nResults\n-------\n\nRun with following (available dataset: \"cora\", \"citeseer\", \"pubmed\")\n```bash\npython3 train.py --dataset cora --gpu 0\n```\n\n* cora: 0.8370 (paper: 0.850)\n* citeseer: 0.715 (paper: 0.757)\n* pubmed: 0.793 (paper: 0.797)\n\nExperiments were done on dgl datasets (GCN settings) which are different from those used in the original implementation. (discrepancies are detailed in experimental section of the original paper)\n"
  },
  {
    "path": "examples/pytorch/appnp/appnp.py",
    "content": "\"\"\"\nAPPNP implementation in DGL.\nReferences\n----------\nPaper: https://arxiv.org/abs/1810.05997\nAuthor's code: https://github.com/klicperajo/ppnp\n\"\"\"\nimport torch.nn as nn\n\nfrom dgl.nn.pytorch.conv import APPNPConv\n\n\nclass APPNP(nn.Module):\n    def __init__(\n        self,\n        g,\n        in_feats,\n        hiddens,\n        n_classes,\n        activation,\n        feat_drop,\n        edge_drop,\n        alpha,\n        k,\n    ):\n        super(APPNP, self).__init__()\n        self.g = g\n        self.layers = nn.ModuleList()\n        # input layer\n        self.layers.append(nn.Linear(in_feats, hiddens[0]))\n        # hidden layers\n        for i in range(1, len(hiddens)):\n            self.layers.append(nn.Linear(hiddens[i - 1], hiddens[i]))\n        # output layer\n        self.layers.append(nn.Linear(hiddens[-1], n_classes))\n        self.activation = activation\n        if feat_drop:\n            self.feat_drop = nn.Dropout(feat_drop)\n        else:\n            self.feat_drop = lambda x: x\n        self.propagate = APPNPConv(k, alpha, edge_drop)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for layer in self.layers:\n            layer.reset_parameters()\n\n    def forward(self, features):\n        # prediction step\n        h = features\n        h = self.feat_drop(h)\n        h = self.activation(self.layers[0](h))\n        for layer in self.layers[1:-1]:\n            h = self.activation(layer(h))\n        h = self.layers[-1](self.feat_drop(h))\n        # propagation step\n        h = self.propagate(self.g, h)\n        return h\n"
  },
  {
    "path": "examples/pytorch/appnp/train.py",
    "content": "import argparse\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom appnp import APPNP\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\n\n\ndef evaluate(model, features, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n    else:\n        cuda = True\n        g = g.to(args.gpu)\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = g.num_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.int().sum().item(),\n            val_mask.int().sum().item(),\n            test_mask.int().sum().item(),\n        )\n    )\n\n    n_edges = g.num_edges()\n    # add self loop\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # create APPNP model\n    model = APPNP(\n        g,\n        in_feats,\n        
args.hidden_sizes,\n        n_classes,\n        F.relu,\n        args.in_drop,\n        args.edge_drop,\n        args.alpha,\n        args.k,\n    )\n\n    if cuda:\n        model.cuda()\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    # use optimizer\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # initialize graph\n    mean = 0\n    for epoch in range(args.n_epochs):\n        model.train()\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        logits = model(features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if epoch >= 3:\n            mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)\n            acc = evaluate(model, features, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    mean,\n                    loss.item(),\n                    acc,\n                    n_edges / mean / 1000,\n                )\n            )\n\n    print()\n    acc = evaluate(model, features, labels, test_mask)\n    print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"APPNP\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--in-drop\", type=float, default=0.5, help=\"input feature dropout\"\n    )\n    parser.add_argument(\n        \"--edge-drop\", type=float, default=0.5, help=\"edge propagation dropout\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    
parser.add_argument(\n        \"--hidden_sizes\",\n        type=int,\n        nargs=\"+\",\n        default=[64],\n        help=\"hidden unit sizes for appnp\",\n    )\n    parser.add_argument(\n        \"--k\", type=int, default=10, help=\"Number of propagation steps\"\n    )\n    parser.add_argument(\n        \"--alpha\", type=float, default=0.1, help=\"Teleport Probability\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 loss\"\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/argo/README.md",
    "content": "# ARGO: An Auto-Tuning Runtime System for Scalable GNN Training on Multi-Core Processor\n\n## Overview\n\nGraph Neural Network (GNN) training suffers from low scalability on multi-core processors. \nARGO is a runtime system that offers scalable performance. \nThe figure below shows an example of GNN training on a Xeon 8380H platform with 112 cores. \nWithout ARGO, there is no performance improvement after applying more than 16 cores; we observe a similar scalability limit on a Xeon 6430L platform with 64 cores as well.\nHowever, with ARGO enabled, we are able to scale over 64 cores, allowing ARGO to speedup GNN training (in terms of epoch time) by up to 4.30x and 3.32x on a Xeon 8380H and a Xeon 6430L, respectively.\n![ARGO](https://github.com/dmlc/dgl/blob/master/examples/pytorch/argo/argo_scale.png)\n\nThis README includes how to:\n1. [Installation](#1-installation)\n2. [Run the example code](#2-running-the-example-GNN-program)\n3. [Modify your own GNN program to enable ARGO.](#3-enabling-ARGO-on-your-own-GNN-program)\n\n## 1. Installation\n\n1. ARGO utilizes the scikit-optimize library for auto-tuning. Please install scikit-optimize to run ARGO:\n\n   ```shell\n   conda install -c conda-forge \"scikit-optimize>=0.9.0\" \n   ```\n   or\n   ```shell\n   pip install scikit-optimize>=0.9\n   ```\n\n## 2. Running the example GNN program\n### Usage\n  ```shell\n  python main.py --dataset ogbn-products --sampler shadow --model sage\n  ``` \n  Important Arguments: \n  - `--dataset`: the training datasets. Available choices [ogbn-products, ogbn-papers100M, reddit, flickr, yelp]\n  - `--sampler`: the mini-batch sampling algorithm. Available choices [shadow, neighbor]\n  - `--model`: GNN model. Available choices [gcn, sage]\n  - `--layer`: number of GNN layers.\n  - `--fan_out`: number of fanout neighbors for each layer.\n  - `--hidden`: hidden feature dimension.\n  - `--batch_size`: the size of the mini-batch.\n\n\n\n## 3. 
Enabling ARGO on your own GNN program\n\nIn this section, we provide a step-by-step tutorial on how to enable ARGO on a DGL program. We use the ```ogb_example.py``` file in this repo as an example.  \n\n>  Note: we also provide the complete example file ```ogb_example_ARGO.py``` which followed the steps below to enable ARGO on ```ogb_example.py```.\n\n1. First, include all necessary packages on top of the file. Please place your file and ```argo.py``` in the same directory.\n\n   ```python\n   import os\n   import torch.distributed as dist\n   from torch.nn.parallel import DistributedDataParallel\n   import torch.multiprocessing as mp\n   from argo import ARGO\n   ```\n\n2. Setup PyTorch Distributed Data Parallel (DDP). \n    1. Add the initialization function on top of the training program, and wrap the ```model``` with the DDP wrapper\n     ```python\n     def train(...):\n       dist.init_process_group('gloo', rank=rank, world_size=world_size) # newly added\n       model = SAGE(...) # original code\n       model = DistributedDataParallel(model) # newly added\n       ...\n     ```\n    2. In the main program, add the following before launching the training function\n    \n     ```python\n     os.environ['MASTER_ADDR'] = '127.0.0.1'\n     os.environ['MASTER_PORT'] = '29501'\n     mp.set_start_method('fork', force=True)\n     train(args, device, data) # original code for launching the training function\n     ```\n\n3. Enable ARGO by initializing the runtime system, and wrapping the training function\n   ```python\n   runtime = ARGO(n_search = 15, epoch = args.num_epochs, batch_size = args.batch_size) #initialization\n   runtime.run(train, args=(args, device, data)) # wrap the training function\n   ```\n   >  ARGO takes three input parameters: number of searches ```n_search```, number of epochs, and the mini-batch size. Increasing ```n_search``` potentially leads to a better configuration with less epoch time; however, searching itself also causes extra overhead. 
We recommend setting ```n_search``` from 15 to 45 for an optimal overall performance. Details of ```n_search``` can be found in the paper.\n\n4. Modify the input of the training function, by directly adding ARGO parameters after the original inputs.\n   This is the original function:\n   ```python\n   def train(args, device, data):\n   ```\n   Add ```rank, world_size, comp_core, load_core, counter, b_size, ep``` like this:\n   ```python\n   def train(args, device, data, rank, world_size, comp_core, load_core, counter, b_size, ep):\n   ```\n\n5. Modify the ```dataloader``` function in the training function\n   ```python\n   dataloader = dgl.dataloading.DataLoader(\n           g,\n           train_nid,\n           sampler,\n           batch_size=b_size, # modified\n           shuffle=True,\n           drop_last=False,\n           num_workers=len(load_core), # modified\n           use_ddp = True) # newly added\n   ```\n\n6. Enable core-binding by adding ```enable_cpu_affinity()``` before the training for-loop, and also change the number of epochs into the variable ```ep```: \n   ```python\n   with dataloader.enable_cpu_affinity(loader_cores=load_core, compute_cores=comp_core): \n     for epoch in range(ep): # change num_epochs to ep\n   ```\n\n7. Last step! Load the model before training and save it afterward.  \n   Original Program:\n   ```python\n   with dataloader.enable_cpu_affinity(loader_cores=load_core, compute_cores=comp_core): \n     for epoch in range(ep): \n       ... # training operations\n   ```\n   Modified:\n   ```python\n   PATH = \"model.pt\"\n   if counter[0] != 0:\n     checkpoint = th.load(PATH)\n     model.load_state_dict(checkpoint['model_state_dict'])\n     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n     epoch = checkpoint['epoch']\n     loss = checkpoint['loss']\n   \n   with dataloader.enable_cpu_affinity(loader_cores=load_core, compute_cores=comp_core): \n     for epoch in range(ep): \n       ... 
# training operations\n   \n   dist.barrier()\n   if rank == 0:\n     th.save({'epoch': counter[0],\n                 'model_state_dict': model.state_dict(),\n                 'optimizer_state_dict': optimizer.state_dict(),\n                 'loss': loss,\n                 }, PATH)\n   \n   ```\n8. Done! You can now run your GNN program with ARGO enabled.\n      ```shell\n      python <your_code>.py\n      ```\n\n## Citation & Acknowledgement\nThis work has been supported by the U.S. National Science Foundation (NSF) under grants CCF-1919289/SPX-2333009, CNS-2009057 and OAC-2209563, and the Semiconductor Research Corporation (SRC).\n```\n@INPROCEEDINGS{argo-ipdps24,\n  author={Yi-Chien Lin and Yuyang Chen and Sameh Gobriel and Nilesh Jain and Gopi Krishna Jha and Viktor Prasanna},\n  booktitle={IEEE International Parallel and Distributed Processing Symposium (IPDPS)}, \n  title={ARGO: An Auto-Tuning Runtime System for Scalable GNN Training on Multi-Core Processor}, \n  year={2024}}\n```\n"
  },
  {
    "path": "examples/pytorch/argo/argo.py",
    "content": "\"\"\"\nARGO: An Auto-Tuning Runtime System for Scalable GNN Training on Multi-Core Processor\n--------------------------------------------\nGraph Neural Network (GNN) training suffers from low scalability on multi-core CPUs. \nSpecificially, the performance often caps at 16 cores, and no improvement is observed when applying more than 16 cores.\nARGO is a runtime system that offers scalable performance by overlapping the computation and communication during GNN training.\nWith ARGO enabled, we are able to scale over 64 cores, allowing ARGO to speedup GNN training (in terms of epoch time) by up to 4.30x and 3.32x on a Xeon 8380H and a Xeon 6430L, respectively.\n--------------------------------------------\nPaper Link: https://arxiv.org/abs/2402.03671\n\"\"\"\n\nimport time\nfrom typing import Callable, List, Tuple\n\nimport dgl.multiprocessing as dmp\nimport numpy as np\nimport psutil\nfrom skopt import gp_minimize\nfrom skopt.space import Normalize\n\n\ndef transform(self, X):\n    X = np.asarray(X)\n    if self.is_int:\n        if np.any(np.round(X) > self.high):\n            raise ValueError(\n                \"All integer values should\" \"be less than %f\" % self.high\n            )\n        if np.any(np.round(X) < self.low):\n            raise ValueError(\n                \"All integer values should\" \"be greater than %f\" % self.low\n            )\n    else:\n        if np.any(X > self.high + self._eps):\n            raise ValueError(\"All values should\" \"be less than %f\" % self.high)\n        if np.any(X < self.low - self._eps):\n            raise ValueError(\n                \"All values should\" \"be greater than %f\" % self.low\n            )\n    if (self.high - self.low) == 0.0:\n        return X * 0.0\n    if self.is_int:\n        return (np.round(X).astype(int) - self.low) / (self.high - self.low)\n    else:\n        return (X - self.low) / (self.high - self.low)\n\n\ndef inverse_transform(self, X):\n    X = np.asarray(X)\n    if 
np.any(X > 1.0 + self._eps):\n        raise ValueError(\"All values should be less than 1.0\")\n    if np.any(X < 0.0 - self._eps):\n        raise ValueError(\"All values should be greater than 0.0\")\n    X_orig = X * (self.high - self.low) + self.low\n    if self.is_int:\n        return np.round(X_orig).astype(int)\n    return X_orig\n\n\n# This is a workaround for scikit-optimize's incompatibility with NumPy, which results in an error::\n# AttributeError: module 'numpy' has no attribute 'int'\nNormalize.transform = transform\nNormalize.inverse_transform = inverse_transform\n\n\nclass ARGO:\n    def __init__(\n        self,\n        n_search=10,\n        epoch=200,\n        batch_size=4096,\n        space=[(2, 8), (1, 4), (1, 32)],\n        random_state=1,\n    ):\n        \"\"\"\n        Initialization\n\n        Parameters\n        ----------\n        n_search: int\n            Number of configuration searches the auto-tuner will conduct\n\n        epoch: int\n            Number of epochs of GNN training\n\n        batch_size: int\n            Size of the mini-batch\n\n        space: list[Tuple(int,int)]\n            Range of the search space; [range of processes, range of samplers for each process, range of trainers for each process]\n\n        random_state: int\n            Number of random initializations before searching\n\n        \"\"\"\n        self.n_search = n_search\n        self.epoch = epoch\n        self.batch_size = batch_size\n        self.space = space\n        self.random_state = random_state\n        self.acq_func = \"EI\"\n        self.counter = [0]\n\n    def core_binder(\n        self, num_cpu_proc: int, n_samp: int, n_train: int, rank: int\n    ) -> Tuple[List[int], List[int]]:\n        \"\"\"\n        Core Binder\n\n        The Core Binder binds CPU cores to perform sampling (i.e., sampling cores) and model propagation (i.e., training cores).\n        The actual binding is done using the CPU affinity function in the data_loader.\n        
The core_binder function here is used to produce the list of CPU IDs for the CPU affinity function.\n\n        Parameters\n        ----------\n        num_cpu_proc: int\n            Number of processes instantiated\n\n        n_samp: int\n            Number of sampling cores for each process\n\n        n_train: int\n            Number of training cores for each process\n\n        rank: int\n            The rank of the current process\n\n        Returns: Tuple[list[int], list[int]]\n        -------\n        load_core: list[int]\n            For a given process rank, the load_core specifies a list of CPU core IDs to be used for sampling, the length of load_core = n_samp.\n\n        comp_core: list[int]\n            For a given process rank, the comp_core specifies a list of CPU core IDs to be used for training, the length of comp_core = n_comp.\n\n        .. note:: Each process is assigned with a unique list of sampling cores and training cores, and no CPU core will appear in two lists or more.\n\n        \"\"\"\n        load_core, comp_core = [], []\n        n = psutil.cpu_count(logical=False)\n        size = num_cpu_proc\n        num_of_samplers = n_samp\n        load_core = list(\n            range(n // size * rank, n // size * rank + num_of_samplers)\n        )\n        comp_core = list(\n            range(\n                n // size * rank + num_of_samplers,\n                n // size * rank + num_of_samplers + n_train,\n            )\n        )\n        return load_core, comp_core\n\n    def auto_tuning(self, train: Callable, args) -> List[int]:\n        \"\"\"\n        Auto-tuner\n\n        The auto-tuner runs Bayesian Optimization (BO) to search for the optimal configuration (number of processes, samplers, trainers).\n        During the search, the auto-tuner explores the design space by collecting the epoch time of various configurations.\n        Specifically, the exploration is done by feeding the Multi-Process Engine with various configurations, and 
record the epoch time.\n        After the searching is done, the optimal configuration will be used repeatedly until the end of model training.\n\n        Parameters\n        ----------\n        train: Callable\n            The GNN training function.\n\n        args:\n            The inputs of the GNN training function.\n\n        Returns\n        -------\n        result: list[int]\n            The optimal configurations (which leads to the shortest epoch time) found by running BO.\n            - result[0]: number of processes to instantiate\n            - result[1]: number of sampling cores for each process\n            - result[2]: number of training cores for each process\n\n        \"\"\"\n        ep = 1\n        result = gp_minimize(\n            lambda x: self.mp_engine(x, train, args, ep),\n            dimensions=self.space,\n            n_calls=self.n_search,\n            random_state=self.random_state,\n            acq_func=self.acq_func,\n        )\n        return result\n\n    def mp_engine(self, x: List[int], train: Callable, args, ep: int) -> float:\n        \"\"\"\n        Multi-Process Engine (MP Engine)\n\n        The MP Engine launches multiple GNN training processes in parallel to overlap computation with communication.\n        Such an approach effectively improves the utilization of the memory bandwidth and the CPU cores.\n        The MP Engine also adjust the batch size according to the number of processes instantiated, so that the effective batch size remains the same as the original program without ARGO.\n\n        Parameters\n        ----------\n        x: list[int]\n            Optimal configurations provided by the auto-tuner.\n            - x[0]: number of processes to instantiate\n            - x[1]: number of sampling cores for each process\n            - x[2]: number of training cores for each process\n\n        train: Callable\n            The GNN training function.\n\n        args:\n            The inputs of the GNN training 
function.\n\n        ep: int\n            number of epochs.\n\n        Returns\n        -------\n        t: float\n            The epoch time using the current configuration `x`.\n        \"\"\"\n        n_proc = x[0]\n        n_samp = x[1]\n        n_train = x[2]\n        n_total = psutil.cpu_count(logical=False)\n\n        if n_proc * (n_samp + n_train) > n_total:  # handling corner cases\n            n_proc = 2\n            n_samp = 2\n            n_train = (n_total // n_proc) - n_samp\n\n        processes = []\n        cnt = self.counter\n        b_size = self.batch_size // n_proc  # adjust batch size\n\n        tik = time.time()\n        for i in range(n_proc):\n            load_core, comp_core = self.core_binder(n_proc, n_samp, n_train, i)\n            p = dmp.Process(\n                target=train,\n                args=(*args, i, n_proc, comp_core, load_core, cnt, b_size, ep),\n            )\n            p.start()\n            processes.append(p)\n        for p in processes:\n            p.join()\n        t = time.time() - tik\n\n        self.counter[0] = self.counter[0] + 1\n\n        return t\n\n    def run(self, train, args):\n        \"\"\"\n        The \"run\" function launches ARGO to train the GNN model\n        Step 1: run the auto-tuner to search for the optimal configuration\n        Step 2: record the optimal configuration\n        Step 3: use the optimal configuration repeatedly until the end of the model training\n\n        Parameters\n        ----------\n        train: Callable\n            The GNN training function.\n\n        args:\n            The inputs of the GNN training function.\n        \"\"\"\n        result = self.auto_tuning(train, args)  # Step 1\n        x = result.x  # Step 2\n        self.mp_engine(\n            x, train, args, ep=(self.epoch - self.n_search)\n        )  # Step 3\n"
  },
  {
    "path": "examples/pytorch/argo/main.py",
    "content": "import argparse\nimport os\n\nimport dgl\nimport dgl.nn as dglnn\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom argo import ARGO\nfrom dgl.data import (\n    AsNodePredDataset,\n    FlickrDataset,\n    RedditDataset,\n    YelpDataset,\n)\nfrom dgl.dataloading import DataLoader, NeighborSampler, ShaDowKHopSampler\nfrom ogb.nodeproppred import DglNodePropPredDataset\nfrom torch.nn.parallel import DistributedDataParallel\n\n\nclass GNN(nn.Module):\n    def __init__(\n        self, in_size, hid_size, out_size, num_layers=3, model_name=\"sage\"\n    ):\n        super().__init__()\n        self.layers = nn.ModuleList()\n\n        # GraphSAGE-mean\n        if model_name.lower() == \"sage\":\n            self.layers.append(dglnn.SAGEConv(in_size, hid_size, \"mean\"))\n            for i in range(num_layers - 2):\n                self.layers.append(dglnn.SAGEConv(hid_size, hid_size, \"mean\"))\n            self.layers.append(dglnn.SAGEConv(hid_size, out_size, \"mean\"))\n        # GCN\n        elif model_name.lower() == \"gcn\":\n            kwargs = {\n                \"norm\": \"both\",\n                \"weight\": True,\n                \"bias\": True,\n                \"allow_zero_in_degree\": True,\n            }\n            self.layers.append(dglnn.GraphConv(in_size, hid_size, **kwargs))\n            for i in range(num_layers - 2):\n                self.layers.append(\n                    dglnn.GraphConv(hid_size, hid_size, **kwargs)\n                )\n            self.layers.append(dglnn.GraphConv(hid_size, out_size, **kwargs))\n        else:\n            raise NotImplementedError\n\n        self.dropout = nn.Dropout(0.5)\n        self.hid_size = hid_size\n        self.out_size = out_size\n\n    def forward(self, blocks, x):\n        h = x\n        if hasattr(blocks, \"__len__\"):\n            for l, (layer, block) in enumerate(zip(self.layers, 
blocks)):\n                h = layer(block, h)\n                if l != len(self.layers) - 1:\n                    h = F.relu(h)\n                    h = self.dropout(h)\n        else:\n            for l, layer in enumerate(self.layers):\n                h = layer(blocks, h)\n                if l != len(self.layers) - 1:\n                    h = F.relu(h)\n                    h = self.dropout(h)\n        return h\n\n\ndef _train(**kwargs):\n    total_loss = 0\n    loader = kwargs[\"loader\"]\n    model = kwargs[\"model\"]\n    opt = kwargs[\"opt\"]\n    load_core = kwargs[\"load_core\"]\n    comp_core = kwargs[\"comp_core\"]\n\n    device = torch.device(\"cpu\")\n    with loader.enable_cpu_affinity(\n        loader_cores=load_core, compute_cores=comp_core\n    ):\n        for it, (input_nodes, output_nodes, blocks) in enumerate(loader):\n            if hasattr(blocks, \"__len__\"):\n                x = blocks[0].srcdata[\"feat\"].to(torch.float32)\n                y = blocks[-1].dstdata[\"label\"]\n            else:\n                x = blocks.srcdata[\"feat\"].to(torch.float32)\n                y = blocks.dstdata[\"label\"]\n            if kwargs[\"device\"] == \"cpu\":  # for papers100M\n                y = y.type(torch.LongTensor)\n                y_hat = model(blocks, x)\n            else:\n                y = y.type(torch.LongTensor).to(device)\n                y_hat = model(blocks, x).to(device)\n            try:\n                loss = F.cross_entropy(\n                    y_hat[: output_nodes.shape[0]], y[: output_nodes.shape[0]]\n                )\n            except:\n                loss = F.binary_cross_entropy_with_logits(\n                    y_hat[: output_nodes.shape[0]].float(),\n                    y[: output_nodes.shape[0]].float(),\n                    reduction=\"sum\",\n                )\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n            del input_nodes, output_nodes, blocks\n            total_loss 
+= loss.item()\n    return total_loss\n\n\ndef train(\n    args, g, data, rank, world_size, comp_core, load_core, counter, b_size, ep\n):\n\n    num_classes, train_idx = data\n    dist.init_process_group(\"gloo\", rank=rank, world_size=world_size)\n    device = torch.device(\"cpu\")\n    hidden = args.hidden\n    # create GraphSAGE model\n    in_size = g.ndata[\"feat\"].shape[1]\n    model = GNN(\n        in_size,\n        hidden,\n        num_classes,\n        num_layers=args.layer,\n        model_name=args.model,\n    ).to(device)\n    model = DistributedDataParallel(model)\n    num_of_samplers = len(load_core)\n    # create loader\n    drop_last, shuffle = True, True\n    if args.sampler.lower() == \"neighbor\":\n        sampler = NeighborSampler(\n            [int(fanout) for fanout in args.fan_out.split(\",\")],\n            prefetch_node_feats=[\"feat\"],\n            prefetch_labels=[\"label\"],\n        )\n        assert len(sampler.fanouts) == args.layer\n    elif args.sampler.lower() == \"shadow\":\n        sampler = ShaDowKHopSampler(\n            [10, 5],\n            output_device=device,\n            prefetch_node_feats=[\"feat\"],\n        )\n    else:\n        raise NotImplementedError\n\n    train_dataloader = DataLoader(\n        g,\n        train_idx.to(device),\n        sampler,\n        device=device,\n        batch_size=b_size,\n        drop_last=drop_last,\n        shuffle=shuffle,\n        num_workers=num_of_samplers,\n        use_ddp=True,\n    )\n\n    # training loop\n    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)\n    params = {\n        # training\n        \"loader\": train_dataloader,\n        \"model\": model,\n        \"opt\": opt,\n        # logging\n        \"rank\": rank,\n        \"train_size\": len(train_idx),\n        \"batch_size\": b_size,\n        \"device\": device,\n        \"process\": world_size,\n    }\n\n    PATH = \"model.pt\"\n    if counter[0] != 0:\n        checkpoint = torch.load(PATH, 
weights_only=False)\n        model.load_state_dict(checkpoint[\"model_state_dict\"])\n        opt.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n        epoch = checkpoint[\"epoch\"]\n        loss = checkpoint[\"loss\"]\n\n    for epoch in range(ep):\n        params[\"epoch\"] = epoch\n        model.train()\n        params[\"load_core\"] = load_core\n        params[\"comp_core\"] = comp_core\n        loss = _train(**params)\n        if rank == 0:\n            print(\"loss:\", loss)\n\n    dist.barrier()\n    EPOCH = counter[0]\n    LOSS = loss\n    if rank == 0:\n        torch.save(\n            {\n                \"epoch\": EPOCH,\n                \"model_state_dict\": model.state_dict(),\n                \"optimizer_state_dict\": opt.state_dict(),\n                \"loss\": LOSS,\n            },\n            PATH,\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-products\",\n        choices=[\n            \"ogbn-papers100M\",\n            \"ogbn-products\",\n            \"reddit\",\n            \"yelp\",\n            \"flickr\",\n        ],\n    )\n    parser.add_argument(\"--batch_size\", type=int, default=1024 * 4)\n    parser.add_argument(\"--layer\", type=int, default=3)\n    parser.add_argument(\"--fan_out\", type=str, default=\"15,10,5\")\n    parser.add_argument(\n        \"--sampler\",\n        type=str,\n        default=\"neighbor\",\n        choices=[\"neighbor\", \"shadow\"],\n    )\n    parser.add_argument(\n        \"--model\", type=str, default=\"sage\", choices=[\"sage\", \"gcn\"]\n    )\n    parser.add_argument(\"--hidden\", type=int, default=128)\n    arguments = parser.parse_args()\n\n    os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:512\"\n\n    if arguments.dataset in [\"reddit\", \"flickr\", \"yelp\"]:\n        if arguments.dataset == \"reddit\":\n            dataset = RedditDataset()\n       
 elif arguments.dataset == \"flickr\":\n            dataset = FlickrDataset()\n        else:\n            dataset = YelpDataset()\n        g = dataset[0]\n        train_mask = g.ndata[\"train_mask\"]\n        idx = []\n        for i in range(len(train_mask)):\n            if train_mask[i]:\n                idx.append(i)\n        dataset.train_idx = torch.tensor(idx)\n    else:\n        dataset = AsNodePredDataset(DglNodePropPredDataset(arguments.dataset))\n        g = dataset[0]\n\n    data = (dataset.num_classes, dataset.train_idx)\n\n    in_size = g.ndata[\"feat\"].shape[1]\n    out_size = dataset.num_classes\n    hidden_size = int(arguments.hidden)\n\n    os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n    os.environ[\"MASTER_PORT\"] = \"29501\"\n    mp.set_start_method(\"fork\", force=True)\n    runtime = ARGO(n_search=10, epoch=20, batch_size=arguments.batch_size)\n    runtime.run(train, args=(arguments, g, data))\n"
  },
  {
    "path": "examples/pytorch/argo/ogb_example.py",
    "content": "\"\"\"\nThis is modified version of: https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-products/graphsage/main.py\n\"\"\"\n\nimport argparse\nimport time\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\nclass SAGE(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        for i in range(1, n_layers - 1):\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            # We need to first copy the representation of nodes on the RHS from the\n            # appropriate nodes on the LHS.\n            # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst\n            # would be (num_nodes_RHS, D)\n            h_dst = h[: block.num_dst_nodes()]\n            # Then we compute the updated representation on the RHS.\n            # The shape of h now becomes (num_nodes_RHS, D)\n            h = layer(block, (h, h_dst))\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, x, device):\n        \"\"\"\n        Inference with the GraphSAGE model on full neighbors (i.e. 
without neighbor sampling).\n        g : the entire graph.\n        x : the input of entire node set.\n        The inference code is written in a fashion that it could handle any number of nodes and\n        layers.\n        \"\"\"\n        # During inference with sampling, multi-layer blocks are very inefficient because\n        # lots of computations in the first few layers are repeated.\n        # Therefore, we compute the representation of all nodes layer by layer.  The nodes\n        # on each layer are of course splitted in batches.\n        # TODO: can we standardize this?\n        for l, layer in enumerate(self.layers):\n            y = th.zeros(\n                g.num_nodes(),\n                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,\n            ).to(device)\n\n            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n            dataloader = dgl.dataloading.DataLoader(\n                g,\n                th.arange(g.num_nodes()),\n                sampler,\n                batch_size=args.batch_size,\n                shuffle=True,\n                drop_last=False,\n                num_workers=args.num_workers,\n            )\n\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(\n                dataloader, disable=None\n            ):\n                block = blocks[0].int().to(device)\n\n                h = x[input_nodes]\n                h_dst = h[: block.num_dst_nodes()]\n                h = layer(block, (h, h_dst))\n                if l != len(self.layers) - 1:\n                    h = self.activation(h)\n                    h = self.dropout(h)\n\n                y[output_nodes] = h\n\n            x = y\n        return y\n\n\ndef compute_acc(pred, labels):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)\n\n\ndef evaluate(model, g, nfeat, labels, val_nid, test_nid, device):\n    \"\"\"\n    
Evaluate the model on the validation set specified by ``val_mask``.\n    g : The entire graph.\n    inputs : The features of all the nodes.\n    labels : The labels of all the nodes.\n    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.\n    device : The GPU device to evaluate on.\n    \"\"\"\n    model.eval()\n    with th.no_grad():\n        pred = model.inference(g, nfeat, device)\n    model.train()\n    return (\n        compute_acc(pred[val_nid], labels[val_nid]),\n        compute_acc(pred[test_nid], labels[test_nid]),\n        pred,\n    )\n\n\ndef load_subtensor(nfeat, labels, seeds, input_nodes):\n    \"\"\"\n    Extracts features and labels for a set of nodes.\n    \"\"\"\n    batch_inputs = nfeat[input_nodes]\n    batch_labels = labels[seeds]\n    return batch_inputs, batch_labels\n\n\n#### Entry point\ndef train(args, device, data):\n    # Unpack data\n    train_nid, val_nid, test_nid, in_feats, labels, n_classes, nfeat, g = data\n\n    # Create PyTorch DataLoader for constructing blocks\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [int(fanout) for fanout in args.fan_out.split(\",\")]\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        train_nid,\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\n\n    # Define model and optimizer\n    model = SAGE(\n        in_feats,\n        args.num_hidden,\n        n_classes,\n        args.num_layers,\n        F.relu,\n        args.dropout,\n    )\n    model = model.to(device)\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)\n\n    # Training loop\n    avg = 0\n    iter_tput = []\n    best_eval_acc = 0\n    best_test_acc = 0\n    with dataloader.enable_cpu_affinity():\n        for epoch in range(args.num_epochs):\n            tic = time.time()\n\n            # Loop over 
the dataloader to sample the computation dependency graph as a list of\n            # blocks.\n            for step, (input_nodes, seeds, blocks) in enumerate(dataloader):\n                tic_step = time.time()\n\n                # copy block to gpu\n                blocks = [blk.int().to(device) for blk in blocks]\n\n                # Load the input features as well as output labels\n                batch_inputs, batch_labels = load_subtensor(\n                    nfeat, labels, seeds, input_nodes\n                )\n\n                # Compute loss and prediction\n                batch_pred = model(blocks, batch_inputs)\n                loss = loss_fcn(batch_pred, batch_labels)\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n\n                iter_tput.append(len(seeds) / (time.time() - tic_step))\n                if step % args.log_every == 0 and step != 0:\n                    acc = compute_acc(batch_pred, batch_labels)\n                    print(\n                        \"Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}\".format(\n                            step,\n                            loss.item(),\n                            acc.item(),\n                            np.mean(iter_tput[3:]),\n                        )\n                    )\n\n            toc = time.time()\n            print(\"Epoch Time(s): {:.4f}\".format(toc - tic))\n            avg += toc - tic\n            if epoch % args.eval_every == 0 and epoch != 0:\n                eval_acc, test_acc, pred = evaluate(\n                    model, g, nfeat, labels, val_nid, test_nid, device\n                )\n                if args.save_pred:\n                    np.savetxt(\n                        args.save_pred + \"%02d\" % epoch,\n                        pred.argmax(1).cpu().numpy(),\n                        \"%d\",\n                    )\n                print(\"Eval Acc {:.4f}\".format(eval_acc))\n       
         if eval_acc > best_eval_acc:\n                    best_eval_acc = eval_acc\n                    best_test_acc = test_acc\n                print(\n                    \"Best Eval Acc {:.4f} Test Acc {:.4f}\".format(\n                        best_eval_acc, best_test_acc\n                    )\n                )\n\n    print(\"Avg epoch time: {}\".format(avg / args.num_epochs))\n    return best_test_acc\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"multi-gpu training\")\n    argparser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=0,\n        help=\"GPU device ID. Use -1 for CPU training\",\n    )\n    argparser.add_argument(\"--num-epochs\", type=int, default=20)\n    argparser.add_argument(\"--num-hidden\", type=int, default=256)\n    argparser.add_argument(\"--num-layers\", type=int, default=3)\n    argparser.add_argument(\"--fan-out\", type=str, default=\"5,10,15\")\n    argparser.add_argument(\"--batch-size\", type=int, default=1000)\n    argparser.add_argument(\"--val-batch-size\", type=int, default=10000)\n    argparser.add_argument(\"--log-every\", type=int, default=20)\n    argparser.add_argument(\"--eval-every\", type=int, default=1)\n    argparser.add_argument(\"--lr\", type=float, default=0.003)\n    argparser.add_argument(\"--dropout\", type=float, default=0.5)\n    argparser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-products\",\n        choices=[\"ogbn-papers100M\", \"ogbn-products\"],\n    )\n    argparser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=4,\n        help=\"Number of sampling processes. 
Use 0 for no extra process.\",\n    )\n    argparser.add_argument(\"--save-pred\", type=str, default=\"\")\n    argparser.add_argument(\"--wd\", type=float, default=0)\n    args = argparser.parse_args()\n\n    device = th.device(\"cpu\")\n\n    # load ogbn-products data\n    data = DglNodePropPredDataset(args.dataset)\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n    nfeat = graph.ndata.pop(\"feat\").to(device)\n    labels = labels[:, 0].to(device)\n\n    in_feats = nfeat.shape[1]\n    n_classes = (labels.max() + 1).item()\n    # Create csr/coo/csc formats before launching sampling processes\n    # This avoids creating certain formats in each data loader process, which saves memory and CPU.\n    graph.create_formats_()\n    # Pack data\n    data = (\n        train_idx,\n        val_idx,\n        test_idx,\n        in_feats,\n        labels,\n        n_classes,\n        nfeat,\n        graph,\n    )\n\n    test_acc = train(args, device, data).cpu().numpy()\n    print(\"Test accuracy:\", test_acc)\n"
  },
  {
    "path": "examples/pytorch/argo/ogb_example_ARGO.py",
    "content": "\"\"\"\nThis is a modified version of: https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-products/graphsage/main.py\nThis example shows how to enable ARGO to automatically instantiate multi-processing and adjust CPU core assignment to achieve better performance.\n\"\"\"\n\nimport argparse\n\nimport ctypes\nimport os\nimport time\nfrom multiprocessing import RawValue\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\n\nimport numpy as np\nimport torch as th\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom argo import ARGO\nfrom ogb.nodeproppred import DglNodePropPredDataset\nfrom torch.nn.parallel import DistributedDataParallel\n\navg_total = RawValue(ctypes.c_float, 0.0)\n\n\nclass SAGE(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        for i in range(1, n_layers - 1):\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            # We need to first copy the representation of nodes on the RHS from the\n            # appropriate nodes on the LHS.\n            # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst\n            # would be (num_nodes_RHS, D)\n            h_dst = h[: block.num_dst_nodes()]\n            # Then we compute the updated representation on the RHS.\n            # 
The shape of h now becomes (num_nodes_RHS, D)\n            h = layer(block, (h, h_dst))\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, x, device):\n        \"\"\"\n        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).\n        g : the entire graph.\n        x : the input of entire node set.\n        The inference code is written in a fashion that it could handle any number of nodes and\n        layers.\n        \"\"\"\n        # During inference with sampling, multi-layer blocks are very inefficient because\n        # lots of computations in the first few layers are repeated.\n        # Therefore, we compute the representation of all nodes layer by layer.  The nodes\n        # on each layer are of course splitted in batches.\n        # TODO: can we standardize this?\n        for l, layer in enumerate(self.layers):\n            y = th.zeros(\n                g.num_nodes(),\n                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,\n            ).to(device)\n\n            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n            dataloader = dgl.dataloading.DataLoader(\n                g,\n                th.arange(g.num_nodes()),\n                sampler,\n                batch_size=args.batch_size,\n                shuffle=True,\n                drop_last=False,\n                num_workers=args.num_workers,\n            )\n\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(\n                dataloader, disable=None\n            ):\n                block = blocks[0].int().to(device)\n\n                h = x[input_nodes]\n                h_dst = h[: block.num_dst_nodes()]\n                h = layer(block, (h, h_dst))\n                if l != len(self.layers) - 1:\n                    h = self.activation(h)\n                    h = self.dropout(h)\n\n         
       y[output_nodes] = h\n\n            x = y\n        return y\n\n\ndef compute_acc(pred, labels):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)\n\n\ndef evaluate(model, g, nfeat, labels, val_nid, test_nid, device):\n    \"\"\"\n    Evaluate the model on the validation set specified by ``val_mask``.\n    g : The entire graph.\n    inputs : The features of all the nodes.\n    labels : The labels of all the nodes.\n    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.\n    device : The GPU device to evaluate on.\n    \"\"\"\n    model.eval()\n    with th.no_grad():\n        pred = model.module.inference(g, nfeat, device)\n    model.train()\n    return (\n        compute_acc(pred[val_nid], labels[val_nid]),\n        compute_acc(pred[test_nid], labels[test_nid]),\n        pred,\n    )\n\n\ndef load_subtensor(nfeat, labels, seeds, input_nodes):\n    \"\"\"\n    Extracts features and labels for a set of nodes.\n    \"\"\"\n    batch_inputs = nfeat[input_nodes]\n    batch_labels = labels[seeds]\n    return batch_inputs, batch_labels\n\n\n#### Entry point\ndef train(\n    args,\n    device,\n    data,\n    rank,\n    world_size,\n    comp_core,\n    load_core,\n    counter,\n    b_size,\n    ep,\n):\n    dist.init_process_group(\"gloo\", rank=rank, world_size=world_size)\n    # Unpack data\n    train_nid, val_nid, test_nid, in_feats, labels, n_classes, nfeat, g = data\n\n    # Create PyTorch DataLoader for constructing blocks\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [int(fanout) for fanout in args.fan_out.split(\",\")]\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        train_nid,\n        sampler,\n        batch_size=b_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=len(load_core),\n        use_ddp=True,\n    )\n\n    # Define model and optimizer\n    
model = SAGE(\n        in_feats,\n        args.num_hidden,\n        n_classes,\n        args.num_layers,\n        F.relu,\n        args.dropout,\n    )\n    model = model.to(device)\n    model = DistributedDataParallel(model)\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)\n\n    # Training loop\n    avg = 0\n    iter_tput = []\n    best_eval_acc = 0\n    best_test_acc = 0\n    PATH = \"model.pt\"\n    if counter[0] != 0:\n        checkpoint = th.load(PATH)\n        model.load_state_dict(checkpoint[\"model_state_dict\"])\n        optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n        epoch = checkpoint[\"epoch\"]\n        loss = checkpoint[\"loss\"]\n\n    with dataloader.enable_cpu_affinity(\n        loader_cores=load_core, compute_cores=comp_core\n    ):\n        for epoch in range(ep):\n            tic = time.time()\n\n            # Loop over the dataloader to sample the computation dependency graph as a list of\n            # blocks.\n            for step, (input_nodes, seeds, blocks) in enumerate(dataloader):\n                tic_step = time.time()\n\n                # copy block to gpu\n                blocks = [blk.int().to(device) for blk in blocks]\n\n                # Load the input features as well as output labels\n                batch_inputs, batch_labels = load_subtensor(\n                    nfeat, labels, seeds, input_nodes\n                )\n\n                # Compute loss and prediction\n                batch_pred = model(blocks, batch_inputs)\n                loss = loss_fcn(batch_pred, batch_labels)\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n\n                iter_tput.append(len(seeds) / (time.time() - tic_step))\n                if step % args.log_every == 0 and step != 0:\n                    acc = compute_acc(batch_pred, batch_labels)\n                    print(\n                        
\"Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}\".format(\n                            step,\n                            loss.item(),\n                            acc.item(),\n                            np.mean(iter_tput[3:]),\n                        )\n                    )\n\n            toc = time.time()\n            print(\"Epoch Time(s): {:.4f}\".format(toc - tic))\n\n            if rank == 0:\n                global avg_total\n                avg_total.value += toc - tic\n                avg += toc - tic\n\n                if epoch % args.eval_every == 0 and epoch != 0:\n                    eval_acc, test_acc, pred = evaluate(\n                        model, g, nfeat, labels, val_nid, test_nid, device\n                    )\n                    if args.save_pred:\n                        np.savetxt(\n                            args.save_pred + \"%02d\" % epoch,\n                            pred.argmax(1).cpu().numpy(),\n                            \"%d\",\n                        )\n                    print(\"Eval Acc {:.4f}\".format(eval_acc))\n                    if eval_acc > best_eval_acc:\n                        best_eval_acc = eval_acc\n                        best_test_acc = test_acc\n                    print(\n                        \"Best Eval Acc {:.4f} Test Acc {:.4f}\".format(\n                            best_eval_acc, best_test_acc\n                        )\n                    )\n\n    dist.barrier()\n    if rank == 0:\n        th.save(\n            {\n                \"epoch\": counter[0],\n                \"model_state_dict\": model.state_dict(),\n                \"optimizer_state_dict\": optimizer.state_dict(),\n                \"loss\": loss,\n            },\n            PATH,\n        )\n        if args.num_epochs == counter[0] + epoch + 1:\n            print(\n                \"Avg epoch time: {}\".format(avg_total.value / args.num_epochs)\n            )\n            print(\n                \"Avg epoch 
time after auto-tuning: {}\".format(avg / (epoch + 1))\n            )\n\n    return best_test_acc\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"multi-gpu training\")\n    argparser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=0,\n        help=\"GPU device ID. Use -1 for CPU training\",\n    )\n    argparser.add_argument(\"--num-epochs\", type=int, default=20)\n    argparser.add_argument(\"--num-hidden\", type=int, default=256)\n    argparser.add_argument(\"--num-layers\", type=int, default=3)\n    argparser.add_argument(\"--fan-out\", type=str, default=\"5,10,15\")\n    argparser.add_argument(\"--batch-size\", type=int, default=1000)\n    argparser.add_argument(\"--val-batch-size\", type=int, default=10000)\n    argparser.add_argument(\"--log-every\", type=int, default=20)\n    argparser.add_argument(\"--eval-every\", type=int, default=1)\n    argparser.add_argument(\"--lr\", type=float, default=0.003)\n    argparser.add_argument(\"--dropout\", type=float, default=0.5)\n    argparser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-products\",\n        choices=[\"ogbn-papers100M\", \"ogbn-products\"],\n    )\n    argparser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=4,\n        help=\"Number of sampling processes. 
Use 0 for no extra process.\",\n    )\n    argparser.add_argument(\"--save-pred\", type=str, default=\"\")\n    argparser.add_argument(\"--wd\", type=float, default=0)\n    args = argparser.parse_args()\n\n    device = th.device(\"cpu\")\n\n    # load ogbn-products data\n    data = DglNodePropPredDataset(args.dataset)\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n    nfeat = graph.ndata.pop(\"feat\").to(device)\n    labels = labels[:, 0].to(device)\n\n    in_feats = nfeat.shape[1]\n    n_classes = (labels.max() + 1).item()\n    # Create csr/coo/csc formats before launching sampling processes\n    # This avoids creating certain formats in each data loader process, which saves memory and CPU.\n    graph.create_formats_()\n    # Pack data\n    data = (\n        train_idx,\n        val_idx,\n        test_idx,\n        in_feats,\n        labels,\n        n_classes,\n        nfeat,\n        graph,\n    )\n    os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n    os.environ[\"MASTER_PORT\"] = \"29501\"\n    mp.set_start_method(\"fork\", force=True)\n    runtime = ARGO(\n        n_search=15, epoch=args.num_epochs, batch_size=args.batch_size\n    )  # initialization\n    runtime.run(train, args=(args, device, data))  # wrap the training function\n"
  },
  {
    "path": "examples/pytorch/arma/README.md",
    "content": "# DGL Implementation of ARMA\n\nThis DGL example implements the GNN model proposed in the paper [Graph Neural Networks with convolutional ARMA filters](https://arxiv.org/abs/1901.01343).\n\nContributor: [xnuohz](https://github.com/xnuohz)\n\n### Requirements\nThe codebase is implemented in Python 3.6. For version requirement of packages, see below.\n\n```\ndgl\nnumpy 1.19.5\nnetworkx 2.5\nscikit-learn 0.24.1\ntqdm 4.56.0\ntorch 1.7.0\n```\n\n### The graph datasets used in this example\n\n###### Node Classification\n\nThe DGL's built-in Cora, Pubmed, Citeseer datasets. Dataset summary:\n\n| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |\n| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |\n| Cora | 2,708 | 10,556 | 1,433 | 7(single label) | 140 | 500 | 1000 |\n| Citeseer | 3,327 | 9,228 | 3,703 | 6(single label) | 120 | 500 | 1000 |\n| Pubmed | 19,717 | 88,651 | 500 | 3(single label) | 60 | 500 | 1000 |\n\n### Usage\n\n###### Dataset options\n```\n--dataset          str     The graph dataset name.             Default is 'Cora'.\n```\n\n###### GPU options\n```\n--gpu              int     GPU index.                          Default is -1, using CPU.\n```\n\n###### Model options\n```\n--epochs           int     Number of training epochs.          Default is 2000.\n--early-stopping   int     Early stopping rounds.              Default is 100.\n--lr               float   Adam optimizer learning rate.       Default is 0.01.\n--lamb             float   L2 regularization coefficient.      Default is 0.0005.\n--hid-dim          int     Hidden layer dimensionalities.      Default is 16.\n--num-stacks       int     Number of K.                        Default is 2.\n--num-layers       int     Number of T.                        Default is 1.\n--dropout          float   Dropout applied at all layers.      
Default is 0.75.\n```\n\n###### Examples\n\nThe following commands learn a neural network and predict on the test set.\nTrain an ARMA model which follows the original hyperparameters on different datasets.\n```bash\n# Cora:\npython citation.py --gpu 0\n\n# Citeseer:\npython citation.py --gpu 0 --dataset Citeseer --num-stacks 3\n\n# Pubmed:\npython citation.py --gpu 0 --dataset Pubmed --dropout 0.25 --num-stacks 1\n```\n\n### Performance\n\n###### Node Classification\n\n| Dataset | Cora | Citeseer | Pubmed |\n| :-: | :-: | :-: | :-: |\n| Metrics(Table 1.Node classification accuracy) | 83.4±0.6 | 72.5±0.4 | 78.9±0.3 |\n| Metrics(PyG) | 82.3±0.5 | 70.9±1.1 | 78.3±0.8 |\n| Metrics(DGL) | 80.9±0.6 | 71.6±0.8 | 75.0±4.2 |"
  },
  {
    "path": "examples/pytorch/arma/citation.py",
    "content": "\"\"\" The main file to train an ARMA model using a full graph \"\"\"\n\nimport argparse\nimport copy\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\nfrom model import ARMA4NC\nfrom tqdm import trange\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # Load from DGL dataset\n    if args.dataset == \"Cora\":\n        dataset = CoraGraphDataset()\n    elif args.dataset == \"Citeseer\":\n        dataset = CiteseerGraphDataset()\n    elif args.dataset == \"Pubmed\":\n        dataset = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Dataset {} is invalid.\".format(args.dataset))\n\n    graph = dataset[0]\n\n    # check cuda\n    device = (\n        f\"cuda:{args.gpu}\"\n        if args.gpu >= 0 and torch.cuda.is_available()\n        else \"cpu\"\n    )\n\n    # retrieve the number of classes\n    n_classes = dataset.num_classes\n\n    # retrieve labels of ground truth\n    labels = graph.ndata.pop(\"label\").to(device).long()\n\n    # Extract node features\n    feats = graph.ndata.pop(\"feat\").to(device)\n    n_features = feats.shape[-1]\n\n    # retrieve masks for train/validation/test\n    train_mask = graph.ndata.pop(\"train_mask\")\n    val_mask = graph.ndata.pop(\"val_mask\")\n    test_mask = graph.ndata.pop(\"test_mask\")\n\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)\n    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)\n    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)\n\n    graph = graph.to(device)\n\n    # Step 2: Create model =================================================================== #\n    model = ARMA4NC(\n        in_dim=n_features,\n        hid_dim=args.hid_dim,\n        out_dim=n_classes,\n        num_stacks=args.num_stacks,\n      
  num_layers=args.num_layers,\n        activation=nn.ReLU(),\n        dropout=args.dropout,\n    ).to(device)\n\n    best_model = copy.deepcopy(model)\n\n    # Step 3: Create training components ===================================================== #\n    loss_fn = nn.CrossEntropyLoss()\n    opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.lamb)\n\n    # Step 4: training epoches =============================================================== #\n    acc = 0\n    no_improvement = 0\n    epochs = trange(args.epochs, desc=\"Accuracy & Loss\")\n\n    for _ in epochs:\n        # Training using a full graph\n        model.train()\n\n        logits = model(graph, feats)\n\n        # compute loss\n        train_loss = loss_fn(logits[train_idx], labels[train_idx])\n        train_acc = torch.sum(\n            logits[train_idx].argmax(dim=1) == labels[train_idx]\n        ).item() / len(train_idx)\n\n        # backward\n        opt.zero_grad()\n        train_loss.backward()\n        opt.step()\n\n        # Validation using a full graph\n        model.eval()\n\n        with torch.no_grad():\n            valid_loss = loss_fn(logits[val_idx], labels[val_idx])\n            valid_acc = torch.sum(\n                logits[val_idx].argmax(dim=1) == labels[val_idx]\n            ).item() / len(val_idx)\n\n        # Print out performance\n        epochs.set_description(\n            \"Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}\".format(\n                train_acc, train_loss.item(), valid_acc, valid_loss.item()\n            )\n        )\n\n        if valid_acc < acc:\n            no_improvement += 1\n            if no_improvement == args.early_stopping:\n                print(\"Early stop.\")\n                break\n        else:\n            no_improvement = 0\n            acc = valid_acc\n            best_model = copy.deepcopy(model)\n\n    best_model.eval()\n    logits = best_model(graph, feats)\n    test_acc = torch.sum(\n        
logits[test_idx].argmax(dim=1) == labels[test_idx]\n    ).item() / len(test_idx)\n\n    print(\"Test Acc {:.4f}\".format(test_acc))\n    return test_acc\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    ARMA Model Hyperparameters\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"ARMA GCN\")\n\n    # data source params\n    parser.add_argument(\n        \"--dataset\", type=str, default=\"Cora\", help=\"Name of dataset.\"\n    )\n    # cuda params\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU index. Default: -1, using CPU.\"\n    )\n    # training params\n    parser.add_argument(\n        \"--epochs\", type=int, default=2000, help=\"Training epochs.\"\n    )\n    parser.add_argument(\n        \"--early-stopping\",\n        type=int,\n        default=100,\n        help=\"Patient epochs to wait before early stopping.\",\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.01, help=\"Learning rate.\")\n    parser.add_argument(\"--lamb\", type=float, default=5e-4, help=\"L2 reg.\")\n    # model params\n    parser.add_argument(\n        \"--hid-dim\", type=int, default=16, help=\"Hidden layer dimensionalities.\"\n    )\n    parser.add_argument(\n        \"--num-stacks\", type=int, default=2, help=\"Number of K.\"\n    )\n    parser.add_argument(\n        \"--num-layers\", type=int, default=1, help=\"Number of T.\"\n    )\n    parser.add_argument(\n        \"--dropout\",\n        type=float,\n        default=0.75,\n        help=\"Dropout applied at all layers.\",\n    )\n\n    args = parser.parse_args()\n    print(args)\n\n    acc_lists = []\n\n    for _ in range(100):\n        acc_lists.append(main(args))\n\n    mean = np.around(np.mean(acc_lists, axis=0), decimals=3)\n    std = np.around(np.std(acc_lists, axis=0), decimals=3)\n    print(\"Total acc: \", acc_lists)\n    print(\"mean\", mean)\n    print(\"std\", std)\n"
  },
  {
    "path": "examples/pytorch/arma/model.py",
    "content": "import math\n\nimport dgl.function as fn\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef glorot(tensor):\n    if tensor is not None:\n        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))\n        tensor.data.uniform_(-stdv, stdv)\n\n\ndef zeros(tensor):\n    if tensor is not None:\n        tensor.data.fill_(0)\n\n\nclass ARMAConv(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        num_stacks,\n        num_layers,\n        activation=None,\n        dropout=0.0,\n        bias=True,\n    ):\n        super(ARMAConv, self).__init__()\n\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.K = num_stacks\n        self.T = num_layers\n        self.activation = activation\n        self.dropout = nn.Dropout(p=dropout)\n\n        # init weight\n        self.w_0 = nn.ModuleDict(\n            {\n                str(k): nn.Linear(in_dim, out_dim, bias=False)\n                for k in range(self.K)\n            }\n        )\n        # deeper weight\n        self.w = nn.ModuleDict(\n            {\n                str(k): nn.Linear(out_dim, out_dim, bias=False)\n                for k in range(self.K)\n            }\n        )\n        # v\n        self.v = nn.ModuleDict(\n            {\n                str(k): nn.Linear(in_dim, out_dim, bias=False)\n                for k in range(self.K)\n            }\n        )\n        # bias\n        if bias:\n            self.bias = nn.Parameter(\n                torch.Tensor(self.K, self.T, 1, self.out_dim)\n            )\n        else:\n            self.register_parameter(\"bias\", None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for k in range(self.K):\n            glorot(self.w_0[str(k)].weight)\n            glorot(self.w[str(k)].weight)\n            glorot(self.v[str(k)].weight)\n        zeros(self.bias)\n\n    def forward(self, g, feats):\n        with g.local_scope():\n          
  init_feats = feats\n            # assume that the graphs are undirected and graph.in_degrees() is the same as graph.out_degrees()\n            degs = g.in_degrees().float().clamp(min=1)\n            norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)\n            output = []\n\n            for k in range(self.K):\n                feats = init_feats\n                for t in range(self.T):\n                    feats = feats * norm\n                    g.ndata[\"h\"] = feats\n                    g.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n                    feats = g.ndata.pop(\"h\")\n                    feats = feats * norm\n\n                    if t == 0:\n                        feats = self.w_0[str(k)](feats)\n                    else:\n                        feats = self.w[str(k)](feats)\n\n                    feats += self.dropout(self.v[str(k)](init_feats))\n                    feats += self.v[str(k)](self.dropout(init_feats))\n\n                    if self.bias is not None:\n                        feats += self.bias[k][t]\n\n                    if self.activation is not None:\n                        feats = self.activation(feats)\n                output.append(feats)\n\n            return torch.stack(output).mean(dim=0)\n\n\nclass ARMA4NC(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        hid_dim,\n        out_dim,\n        num_stacks,\n        num_layers,\n        activation=None,\n        dropout=0.0,\n    ):\n        super(ARMA4NC, self).__init__()\n\n        self.conv1 = ARMAConv(\n            in_dim=in_dim,\n            out_dim=hid_dim,\n            num_stacks=num_stacks,\n            num_layers=num_layers,\n            activation=activation,\n            dropout=dropout,\n        )\n\n        self.conv2 = ARMAConv(\n            in_dim=hid_dim,\n            out_dim=out_dim,\n            num_stacks=num_stacks,\n            num_layers=num_layers,\n            activation=activation,\n            
dropout=dropout,\n        )\n\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, g, feats):\n        feats = F.relu(self.conv1(g, feats))\n        feats = self.dropout(feats)\n        feats = self.conv2(g, feats)\n        return feats\n"
  },
  {
    "path": "examples/pytorch/bgnn/BGNN.py",
    "content": "import itertools\nimport time\nfrom collections import defaultdict as ddict\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.nn.functional as F\nfrom catboost import CatBoostClassifier, CatBoostRegressor, Pool, sum_models\nfrom sklearn import preprocessing\nfrom sklearn.metrics import r2_score\nfrom tqdm import tqdm\n\n\nclass BGNNPredictor:\n    \"\"\"\n    Description\n    -----------\n    Boost GNN predictor for semi-supervised node classification or regression problems.\n    Publication: https://arxiv.org/abs/2101.08543\n\n    Parameters\n    ----------\n    gnn_model : nn.Module\n        DGL implementation of GNN model.\n    task: str, optional\n        Regression or classification task.\n    loss_fn : callable, optional\n        Function that takes torch tensors, pred and true, and returns a scalar.\n    trees_per_epoch : int, optional\n        Number of GBDT trees to build each epoch.\n    backprop_per_epoch : int, optional\n        Number of backpropagation steps to make each epoch.\n    lr : float, optional\n        Learning rate of gradient descent optimizer.\n    append_gbdt_pred : bool, optional\n        Append GBDT predictions or replace original input node features.\n    train_input_features : bool, optional\n        Train original input node features.\n    gbdt_depth : int, optional\n        Depth of each tree in GBDT model.\n    gbdt_lr : float, optional\n        Learning rate of GBDT model.\n    gbdt_alpha : int, optional\n        Weight to combine previous and new GBDT trees.\n    random_seed : int, optional\n        random seed for GNN and GBDT models.\n\n    Examples\n    ----------\n    gnn_model = GAT(10, 20, num_heads=5),\n    bgnn = BGNNPredictor(gnn_model)\n    metrics = bgnn.fit(graph, X, y, train_mask, val_mask, test_mask, cat_features)\n    \"\"\"\n\n    def __init__(\n        self,\n        gnn_model,\n        task=\"regression\",\n        loss_fn=None,\n        trees_per_epoch=10,\n        
backprop_per_epoch=10,\n        lr=0.01,\n        append_gbdt_pred=True,\n        train_input_features=False,\n        gbdt_depth=6,\n        gbdt_lr=0.1,\n        gbdt_alpha=1,\n        random_seed=0,\n    ):\n        self.device = torch.device(\n            \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        )\n\n        self.model = gnn_model.to(self.device)\n        self.task = task\n        self.loss_fn = loss_fn\n        self.trees_per_epoch = trees_per_epoch\n        self.backprop_per_epoch = backprop_per_epoch\n        self.lr = lr\n        self.append_gbdt_pred = append_gbdt_pred\n        self.train_input_features = train_input_features\n        self.gbdt_depth = gbdt_depth\n        self.gbdt_lr = gbdt_lr\n        self.gbdt_alpha = gbdt_alpha\n        self.random_seed = random_seed\n        torch.manual_seed(random_seed)\n        np.random.seed(random_seed)\n\n    def init_gbdt_model(self, num_epochs, epoch):\n        if self.task == \"regression\":\n            catboost_model_obj = CatBoostRegressor\n            catboost_loss_fn = \"RMSE\"\n        else:\n            if epoch == 0:  # we predict multiclass probs at first epoch\n                catboost_model_obj = CatBoostClassifier\n                catboost_loss_fn = \"MultiClass\"\n            else:  # we predict the gradients for each class at epochs > 0\n                catboost_model_obj = CatBoostRegressor\n                catboost_loss_fn = \"MultiRMSE\"\n\n        return catboost_model_obj(\n            iterations=num_epochs,\n            depth=self.gbdt_depth,\n            learning_rate=self.gbdt_lr,\n            loss_function=catboost_loss_fn,\n            random_seed=self.random_seed,\n            nan_mode=\"Min\",\n        )\n\n    def fit_gbdt(self, pool, trees_per_epoch, epoch):\n        gbdt_model = self.init_gbdt_model(trees_per_epoch, epoch)\n        gbdt_model.fit(pool, verbose=False)\n        return gbdt_model\n\n    def append_gbdt_model(self, new_gbdt_model, weights):\n      
  if self.gbdt_model is None:\n            return new_gbdt_model\n        return sum_models([self.gbdt_model, new_gbdt_model], weights=weights)\n\n    def train_gbdt(\n        self,\n        gbdt_X_train,\n        gbdt_y_train,\n        cat_features,\n        epoch,\n        gbdt_trees_per_epoch,\n        gbdt_alpha,\n    ):\n        pool = Pool(gbdt_X_train, gbdt_y_train, cat_features=cat_features)\n        epoch_gbdt_model = self.fit_gbdt(pool, gbdt_trees_per_epoch, epoch)\n        if epoch == 0 and self.task == \"classification\":\n            self.base_gbdt = epoch_gbdt_model\n        else:\n            self.gbdt_model = self.append_gbdt_model(\n                epoch_gbdt_model, weights=[1, gbdt_alpha]\n            )\n\n    def update_node_features(self, node_features, X, original_X):\n        # get predictions from gbdt model\n        if self.task == \"regression\":\n            predictions = np.expand_dims(\n                self.gbdt_model.predict(original_X), axis=1\n            )\n        else:\n            predictions = self.base_gbdt.predict_proba(original_X)\n            if self.gbdt_model is not None:\n                predictions_after_one = self.gbdt_model.predict(original_X)\n                predictions += predictions_after_one\n\n        # update node features with predictions\n        if self.append_gbdt_pred:\n            if self.train_input_features:\n                predictions = np.append(\n                    node_features.detach().cpu().data[:, : -self.out_dim],\n                    predictions,\n                    axis=1,\n                )  # replace old predictions with new predictions\n            else:\n                predictions = np.append(\n                    X, predictions, axis=1\n                )  # append original features with new predictions\n\n        predictions = torch.from_numpy(predictions).to(self.device)\n\n        node_features.data = predictions.float().data\n\n    def update_gbdt_targets(\n        self, 
node_features, node_features_before, train_mask\n    ):\n        return (\n            (node_features - node_features_before)\n            .detach()\n            .cpu()\n            .numpy()[train_mask, -self.out_dim :]\n        )\n\n    def init_node_features(self, X):\n        node_features = torch.empty(\n            X.shape[0], self.in_dim, requires_grad=True, device=self.device\n        )\n        if self.append_gbdt_pred:\n            node_features.data[:, : -self.out_dim] = torch.from_numpy(\n                X.to_numpy(copy=True)\n            )\n        return node_features\n\n    def init_optimizer(\n        self, node_features, optimize_node_features, learning_rate\n    ):\n        params = [self.model.parameters()]\n        if optimize_node_features:\n            params.append([node_features])\n        optimizer = torch.optim.Adam(itertools.chain(*params), lr=learning_rate)\n        return optimizer\n\n    def train_model(self, model_in, target_labels, train_mask, optimizer):\n        y = target_labels[train_mask]\n\n        self.model.train()\n        logits = self.model(*model_in).squeeze()\n        pred = logits[train_mask]\n\n        if self.loss_fn is not None:\n            loss = self.loss_fn(pred, y)\n        else:\n            if self.task == \"regression\":\n                loss = torch.sqrt(F.mse_loss(pred, y))\n            elif self.task == \"classification\":\n                loss = F.cross_entropy(pred, y.long())\n            else:\n                raise NotImplemented(\n                    \"Unknown task. 
Supported tasks: classification, regression.\"\n                )\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        return loss\n\n    def evaluate_model(self, logits, target_labels, mask):\n        metrics = {}\n        y = target_labels[mask]\n        with torch.no_grad():\n            pred = logits[mask]\n            if self.task == \"regression\":\n                metrics[\"loss\"] = torch.sqrt(\n                    F.mse_loss(pred, y).squeeze() + 1e-8\n                )\n                metrics[\"rmsle\"] = torch.sqrt(\n                    F.mse_loss(torch.log(pred + 1), torch.log(y + 1)).squeeze()\n                    + 1e-8\n                )\n                metrics[\"mae\"] = F.l1_loss(pred, y)\n                metrics[\"r2\"] = torch.Tensor(\n                    [r2_score(y.cpu().numpy(), pred.cpu().numpy())]\n                )\n            elif self.task == \"classification\":\n                metrics[\"loss\"] = F.cross_entropy(pred, y.long())\n                metrics[\"accuracy\"] = torch.Tensor(\n                    [(y == pred.max(1)[1]).sum().item() / y.shape[0]]\n                )\n\n            return metrics\n\n    def train_and_evaluate(\n        self,\n        model_in,\n        target_labels,\n        train_mask,\n        val_mask,\n        test_mask,\n        optimizer,\n        metrics,\n        gnn_passes_per_epoch,\n    ):\n        loss = None\n\n        for _ in range(gnn_passes_per_epoch):\n            loss = self.train_model(\n                model_in, target_labels, train_mask, optimizer\n            )\n\n        self.model.eval()\n        logits = self.model(*model_in).squeeze()\n        train_results = self.evaluate_model(logits, target_labels, train_mask)\n        val_results = self.evaluate_model(logits, target_labels, val_mask)\n        test_results = self.evaluate_model(logits, target_labels, test_mask)\n        for metric_name in train_results:\n            
metrics[metric_name].append(\n                (\n                    train_results[metric_name].detach().item(),\n                    val_results[metric_name].detach().item(),\n                    test_results[metric_name].detach().item(),\n                )\n            )\n        return loss\n\n    def update_early_stopping(\n        self,\n        metrics,\n        epoch,\n        best_metric,\n        best_val_epoch,\n        epochs_since_last_best_metric,\n        metric_name,\n        lower_better=False,\n    ):\n        train_metric, val_metric, test_metric = metrics[metric_name][-1]\n        if (lower_better and val_metric < best_metric[1]) or (\n            not lower_better and val_metric > best_metric[1]\n        ):\n            best_metric = metrics[metric_name][-1]\n            best_val_epoch = epoch\n            epochs_since_last_best_metric = 0\n        else:\n            epochs_since_last_best_metric += 1\n        return best_metric, best_val_epoch, epochs_since_last_best_metric\n\n    def log_epoch(\n        self,\n        pbar,\n        metrics,\n        epoch,\n        loss,\n        epoch_time,\n        logging_epochs,\n        metric_name=\"loss\",\n    ):\n        train_metric, val_metric, test_metric = metrics[metric_name][-1]\n        if epoch and epoch % logging_epochs == 0:\n            pbar.set_description(\n                \"Epoch {:05d} | Loss {:.3f} | Loss {:.3f}/{:.3f}/{:.3f} | Time {:.4f}\".format(\n                    epoch,\n                    loss,\n                    train_metric,\n                    val_metric,\n                    test_metric,\n                    epoch_time,\n                )\n            )\n\n    def fit(\n        self,\n        graph,\n        X,\n        y,\n        train_mask,\n        val_mask,\n        test_mask,\n        original_X=None,\n        cat_features=None,\n        num_epochs=100,\n        patience=10,\n        logging_epochs=1,\n        metric_name=\"loss\",\n    ):\n        \"\"\"\n\n      
  :param graph : dgl.DGLGraph\n            Input graph\n        :param X : pd.DataFrame\n            Input node features. Each column represents one input feature. Each row is a node.\n            Values in dataframe are numerical, after preprocessing.\n        :param y : pd.DataFrame\n            Input node targets. Each column represents one target. Each row is a node\n            (order of nodes should be the same as in X).\n        :param train_mask : list[int]\n            Node indexes (rows) that belong to train set.\n        :param val_mask : list[int]\n            Node indexes (rows) that belong to validation set.\n        :param test_mask : list[int]\n            Node indexes (rows) that belong to test set.\n        :param original_X : pd.DataFrame, optional\n            Input node features before preprocessing. Each column represents one input feature. Each row is a node.\n            Values in dataframe can be of any type, including categorical (e.g. string, bool) or\n            missing values (None). 
This is useful if you want to preprocess X with GBDT model.\n        :param cat_features: list[int]\n            Feature indexes (columns) which are categorical features.\n        :param num_epochs : int\n            Number of epochs to run.\n        :param patience : int\n            Number of epochs to wait until early stopping.\n        :param logging_epochs : int\n            Log every n epoch.\n        :param metric_name : str\n            Metric to use for early stopping.\n        :param normalize_features : bool\n            If to normalize original input features X (column wise).\n        :param replace_na: bool\n            If to replace missing values (None) in X.\n        :return: metrics evaluated during training\n        \"\"\"\n\n        # initialize for early stopping and metrics\n        if metric_name in [\"r2\", \"accuracy\"]:\n            best_metric = [np.cfloat(\"-inf\")] * 3  # for train/val/test\n        else:\n            best_metric = [np.cfloat(\"inf\")] * 3  # for train/val/test\n\n        best_val_epoch = 0\n        epochs_since_last_best_metric = 0\n        metrics = ddict(list)\n        if cat_features is None:\n            cat_features = []\n\n        if self.task == \"regression\":\n            self.out_dim = y.shape[1]\n        elif self.task == \"classification\":\n            self.out_dim = len(set(y.iloc[test_mask, 0]))\n        self.in_dim = (\n            self.out_dim + X.shape[1] if self.append_gbdt_pred else self.out_dim\n        )\n\n        if original_X is None:\n            original_X = X.copy()\n            cat_features = []\n\n        gbdt_X_train = original_X.iloc[train_mask]\n        gbdt_y_train = y.iloc[train_mask]\n        gbdt_alpha = self.gbdt_alpha\n        self.gbdt_model = None\n\n        node_features = self.init_node_features(X)\n        optimizer = self.init_optimizer(\n            node_features, optimize_node_features=True, learning_rate=self.lr\n        )\n\n        y = (\n            
torch.from_numpy(y.to_numpy(copy=True))\n            .float()\n            .squeeze()\n            .to(self.device)\n        )\n        graph = graph.to(self.device)\n\n        pbar = tqdm(range(num_epochs))\n        for epoch in pbar:\n            start2epoch = time.time()\n\n            # gbdt part\n            self.train_gbdt(\n                gbdt_X_train,\n                gbdt_y_train,\n                cat_features,\n                epoch,\n                self.trees_per_epoch,\n                gbdt_alpha,\n            )\n\n            self.update_node_features(node_features, X, original_X)\n            node_features_before = node_features.clone()\n            model_in = (graph, node_features)\n            loss = self.train_and_evaluate(\n                model_in,\n                y,\n                train_mask,\n                val_mask,\n                test_mask,\n                optimizer,\n                metrics,\n                self.backprop_per_epoch,\n            )\n            gbdt_y_train = self.update_gbdt_targets(\n                node_features, node_features_before, train_mask\n            )\n\n            self.log_epoch(\n                pbar,\n                metrics,\n                epoch,\n                loss,\n                time.time() - start2epoch,\n                logging_epochs,\n                metric_name=metric_name,\n            )\n\n            # check early stopping\n            (\n                best_metric,\n                best_val_epoch,\n                epochs_since_last_best_metric,\n            ) = self.update_early_stopping(\n                metrics,\n                epoch,\n                best_metric,\n                best_val_epoch,\n                epochs_since_last_best_metric,\n                metric_name,\n                lower_better=(metric_name not in [\"r2\", \"accuracy\"]),\n            )\n            if patience and epochs_since_last_best_metric > patience:\n                break\n\n            if 
np.isclose(gbdt_y_train.sum(), 0.0):\n                print(\"Node embeddings do not change anymore. Stopping...\")\n                break\n\n        print(\n            \"Best {} at iteration {}: {:.3f}/{:.3f}/{:.3f}\".format(\n                metric_name, best_val_epoch, *best_metric\n            )\n        )\n        return metrics\n\n    def predict(self, graph, X, test_mask):\n        graph = graph.to(self.device)\n        node_features = torch.empty(X.shape[0], self.in_dim).to(self.device)\n        self.update_node_features(node_features, X, X)\n        logits = self.model(graph, node_features).squeeze()\n        if self.task == \"regression\":\n            return logits[test_mask]\n        else:\n            return logits[test_mask].max(1)[1]\n\n    def plot_interactive(\n        self,\n        metrics,\n        legend,\n        title,\n        logx=False,\n        logy=False,\n        metric_name=\"loss\",\n        start_from=0,\n    ):\n        import plotly.graph_objects as go\n\n        metric_results = metrics[metric_name]\n        xs = [list(range(len(metric_results)))] * len(metric_results[0])\n        ys = list(zip(*metric_results))\n\n        fig = go.Figure()\n        for i in range(len(ys)):\n            fig.add_trace(\n                go.Scatter(\n                    x=xs[i][start_from:],\n                    y=ys[i][start_from:],\n                    mode=\"lines+markers\",\n                    name=legend[i],\n                )\n            )\n\n        fig.update_layout(\n            title=title,\n            title_x=0.5,\n            xaxis_title=\"Epoch\",\n            yaxis_title=metric_name,\n            font=dict(\n                size=40,\n            ),\n            height=600,\n        )\n\n        if logx:\n            fig.update_layout(xaxis_type=\"log\")\n        if logy:\n            fig.update_layout(yaxis_type=\"log\")\n\n        fig.show()\n"
  },
  {
    "path": "examples/pytorch/bgnn/Readme.md",
    "content": "# Instructions to download datasets:\n\n1. Download datasets from here: https://www.dropbox.com/s/verx1evkykzli88/datasets.zip\n2. Extract zip folder in this directory\n3. Choose the dataset you wish in `run.py` file. \n\n# Details about BGNN model\n`run.py` implements a class for GNN model. You can select GAT, GCN, ChebNet, AGNN, or APPNP gnn models.\nOr you can provide your favorite GNN model. You can also pretrain your model or setup the hyperparameters you like. \n\nHyperparameters of BGNN model. \n* `append_gbdt_pred` -- this decides whether to append GBDT predictions from GNN to original input features or to replace original input features with predictions of GBDT. This can be important for performance, so try both values, True and False. \n* `trees_per_epoch` and `backprop_per_epoch`. Values in the range 5-15 usually gives good results. The more, the longer training is. \n* `lr` is learning rate for GNN. 0.01-0.1 are good values to try.\n* `gbdt_lr` is learning rate for GBDT. Should be that important. \n* `gbdt_depth` number of levels in GBDT tree. 4-8 are good values. The more, the longer it trains. "
  },
  {
    "path": "examples/pytorch/bgnn/run.py",
    "content": "import json\nimport os\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.nn.functional as F\nfrom BGNN import BGNNPredictor\nfrom category_encoders import CatBoostEncoder\n\nfrom dgl.data.utils import load_graphs\nfrom dgl.nn.pytorch import (\n    AGNNConv as AGNNConvDGL,\n    APPNPConv,\n    ChebConv as ChebConvDGL,\n    GATConv as GATConvDGL,\n    GraphConv,\n)\nfrom sklearn import preprocessing\nfrom torch.nn import Dropout, ELU, Linear, ReLU, Sequential\n\n\nclass GNNModelDGL(torch.nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        hidden_dim,\n        out_dim,\n        dropout=0.0,\n        name=\"gat\",\n        residual=True,\n        use_mlp=False,\n        join_with_mlp=False,\n    ):\n        super(GNNModelDGL, self).__init__()\n        self.name = name\n        self.use_mlp = use_mlp\n        self.join_with_mlp = join_with_mlp\n        self.normalize_input_columns = True\n        if name == \"gat\":\n            self.l1 = GATConvDGL(\n                in_dim,\n                hidden_dim // 8,\n                8,\n                feat_drop=dropout,\n                attn_drop=dropout,\n                residual=False,\n                activation=F.elu,\n            )\n            self.l2 = GATConvDGL(\n                hidden_dim,\n                out_dim,\n                1,\n                feat_drop=dropout,\n                attn_drop=dropout,\n                residual=residual,\n                activation=None,\n            )\n        elif name == \"gcn\":\n            self.l1 = GraphConv(in_dim, hidden_dim, activation=F.elu)\n            self.l2 = GraphConv(hidden_dim, out_dim, activation=F.elu)\n            self.drop = Dropout(p=dropout)\n        elif name == \"cheb\":\n            self.l1 = ChebConvDGL(in_dim, hidden_dim, k=3)\n            self.l2 = ChebConvDGL(hidden_dim, out_dim, k=3)\n            self.drop = Dropout(p=dropout)\n        elif name == \"agnn\":\n            self.lin1 = 
Sequential(\n                Dropout(p=dropout), Linear(in_dim, hidden_dim), ELU()\n            )\n            self.l1 = AGNNConvDGL(learn_beta=False)\n            self.l2 = AGNNConvDGL(learn_beta=True)\n            self.lin2 = Sequential(\n                Dropout(p=dropout), Linear(hidden_dim, out_dim), ELU()\n            )\n        elif name == \"appnp\":\n            self.lin1 = Sequential(\n                Dropout(p=dropout),\n                Linear(in_dim, hidden_dim),\n                ReLU(),\n                Dropout(p=dropout),\n                Linear(hidden_dim, out_dim),\n            )\n            self.l1 = APPNPConv(k=10, alpha=0.1, edge_drop=0.0)\n\n    def forward(self, graph, features):\n        h = features\n        if self.use_mlp:\n            if self.join_with_mlp:\n                h = torch.cat((h, self.mlp(features)), 1)\n            else:\n                h = self.mlp(features)\n        if self.name == \"gat\":\n            h = self.l1(graph, h).flatten(1)\n            logits = self.l2(graph, h).mean(1)\n        elif self.name in [\"appnp\"]:\n            h = self.lin1(h)\n            logits = self.l1(graph, h)\n        elif self.name == \"agnn\":\n            h = self.lin1(h)\n            h = self.l1(graph, h)\n            h = self.l2(graph, h)\n            logits = self.lin2(h)\n        elif self.name == \"che3b\":\n            lambda_max = dgl.laplacian_lambda_max(graph)\n            h = self.drop(h)\n            h = self.l1(graph, h, lambda_max)\n            logits = self.l2(graph, h, lambda_max)\n        elif self.name == \"gcn\":\n            h = self.drop(h)\n            h = self.l1(graph, h)\n            logits = self.l2(graph, h)\n\n        return logits\n\n\ndef read_input(input_folder):\n    X = pd.read_csv(f\"{input_folder}/X.csv\")\n    y = pd.read_csv(f\"{input_folder}/y.csv\")\n\n    categorical_columns = []\n    if os.path.exists(f\"{input_folder}/cat_features.txt\"):\n        with open(f\"{input_folder}/cat_features.txt\") as 
f:\n            for line in f:\n                if line.strip():\n                    categorical_columns.append(line.strip())\n\n    cat_features = None\n    if categorical_columns:\n        columns = X.columns\n        cat_features = np.where(columns.isin(categorical_columns))[0]\n\n        for col in list(columns[cat_features]):\n            X[col] = X[col].astype(str)\n\n    gs, _ = load_graphs(f\"{input_folder}/graph.dgl\")\n    graph = gs[0]\n\n    with open(f\"{input_folder}/masks.json\") as f:\n        masks = json.load(f)\n\n    return graph, X, y, cat_features, masks\n\n\ndef normalize_features(X, train_mask, val_mask, test_mask):\n    min_max_scaler = preprocessing.MinMaxScaler()\n    A = X.to_numpy(copy=True)\n    A[train_mask] = min_max_scaler.fit_transform(A[train_mask])\n    A[val_mask + test_mask] = min_max_scaler.transform(A[val_mask + test_mask])\n    return pd.DataFrame(A, columns=X.columns).astype(float)\n\n\ndef replace_na(X, train_mask):\n    if X.isna().any().any():\n        return X.fillna(X.iloc[train_mask].min() - 1)\n    return X\n\n\ndef encode_cat_features(X, y, cat_features, train_mask, val_mask, test_mask):\n    enc = CatBoostEncoder()\n    A = X.to_numpy(copy=True)\n    b = y.to_numpy(copy=True)\n    A[np.ix_(train_mask, cat_features)] = enc.fit_transform(\n        A[np.ix_(train_mask, cat_features)], b[train_mask]\n    )\n    A[np.ix_(val_mask + test_mask, cat_features)] = enc.transform(\n        A[np.ix_(val_mask + test_mask, cat_features)]\n    )\n    A = A.astype(float)\n    return pd.DataFrame(A, columns=X.columns)\n\n\nif __name__ == \"__main__\":\n    # datasets can be found here: https://www.dropbox.com/s/verx1evkykzli88/datasets.zip\n    # Read dataset\n    input_folder = \"datasets/avazu\"\n    graph, X, y, cat_features, masks = read_input(input_folder)\n    train_mask, val_mask, test_mask = (\n        masks[\"0\"][\"train\"],\n        masks[\"0\"][\"val\"],\n        masks[\"0\"][\"test\"],\n    )\n\n    encoded_X = 
X.copy()\n    normalizeFeatures = False\n    replaceNa = True\n\n    if len(cat_features):\n        encoded_X = encode_cat_features(\n            encoded_X, y, cat_features, train_mask, val_mask, test_mask\n        )\n    if normalizeFeatures:\n        encoded_X = normalize_features(\n            encoded_X, train_mask, val_mask, test_mask\n        )\n    if replaceNa:\n        encoded_X = replace_na(encoded_X, train_mask)\n\n    # specify parameters\n    task = \"regression\"\n    hidden_dim = 128\n    trees_per_epoch = 5  # 5-10 are good values to try\n    backprop_per_epoch = 5  # 5-10 are good values to try\n    lr = 0.1  # 0.01-0.1 are good values to try\n    append_gbdt_pred = (\n        False  # this can be important for performance (try True and False)\n    )\n    train_input_features = False\n    gbdt_depth = 6\n    gbdt_lr = 0.1\n\n    out_dim = (\n        y.shape[1] if task == \"regression\" else len(set(y.iloc[test_mask, 0]))\n    )\n    in_dim = out_dim + X.shape[1] if append_gbdt_pred else out_dim\n\n    # specify GNN model\n    gnn_model = GNNModelDGL(in_dim, hidden_dim, out_dim)\n\n    # initialize BGNN model\n    bgnn = BGNNPredictor(\n        gnn_model,\n        task=task,\n        loss_fn=None,\n        trees_per_epoch=trees_per_epoch,\n        backprop_per_epoch=backprop_per_epoch,\n        lr=lr,\n        append_gbdt_pred=append_gbdt_pred,\n        train_input_features=train_input_features,\n        gbdt_depth=gbdt_depth,\n        gbdt_lr=gbdt_lr,\n    )\n\n    # train\n    metrics = bgnn.fit(\n        graph,\n        encoded_X,\n        y,\n        train_mask,\n        val_mask,\n        test_mask,\n        original_X=X,\n        cat_features=cat_features,\n        num_epochs=100,\n        patience=10,\n        metric_name=\"loss\",\n    )\n\n    bgnn.plot_interactive(\n        metrics,\n        legend=[\"train\", \"valid\", \"test\"],\n        title=\"Avazu\",\n        metric_name=\"loss\",\n    )\n"
  },
  {
    "path": "examples/pytorch/bgrl/README.md",
    "content": "# DGL Implementation of BGRL\n\nThis DGL example implements the GNN experiment proposed in the paper [Large-Scale Representation Learning on Graphs via Bootstrapping](https://arxiv.org/abs/2102.06514). For the original implementation, see [here](https://github.com/nerdslab/bgrl).\n\nContributor: [RecLusIve-F](https://github.com/RecLusIve-F)\n\n### Requirements\n\nThe codebase is implemented in Python 3.8. For version requirement of packages, see below.\n\n```\ndgl 0.8.3\nnumpy 1.21.2\ntorch 1.10.2\nscikit-learn 1.0.2\n```\n\n### Dataset\nDataset summary:\n\n|     Dataset      |     Task     | Nodes  |  Edges  | Features |     Classes     |\n|:----------------:|:------------:|:------:|:-------:|:--------:|:---------------:|\n|      WikiCS      | Transductive | 11,701 | 216,123 |   300    |       10        |\n| Amazon Computers | Transductive | 13,752 | 245,861 |   767    |       10        |\n|  Amazon Photos   | Transductive | 7,650  | 119,081 |   745    |        8        |\n|   Coauthor CS    | Transductive | 18,333 | 81,894  |  6,805   |       15        |\n| Coauthor Physics | Transductive | 34,493 | 247,962 |  8,415   |        5        |\n|  PPI(24 graphs)  |  Inductive   | 56,944 | 818,716 |    50    | 121(multilabel) |\n\n### Usage\n\n##### Dataset options\n```\n--dataset                     str         The graph dataset name.                         Default is 'amazon_photos'.\n```\n\n##### Model options\n```\n--graph_encoder_layer         list        Convolutional layer hidden sizes.               Default is [256, 128].\n--predictor_hidden_size       int         Hidden size of predictor.                       Default is 512.\n```\n\n##### Training options\n```\n--epochs                      int         The number of training epochs.                  Default is 10000.\n--lr                          float       The learning rate.                              Default is 0.00001.\n--weight_decay                float       The weight decay.         
                      Default is 0.00001.\n--mm                          float       The momentum for moving average.                Default is 0.99.\n--lr_warmup_epochs            int         Warmup period for learning rate scheduling.     Default is 1000.    \n--weights_dir                 str         Where to save the weights.                      Default is '../weights'.\n```\n\n##### Augmentation options\n```\n--drop_edge_p                 float      Probability of edge dropout.                     Default is [0., 0.].\n--feat_mask_p                 float      Probability of node feature masking.             Default is [0., 0.].\n```\n\n##### Evaluation options\n```\n--eval_epochs                 int        Evaluate every eval_epochs.                      Default is 250.\n--num_eval_splits             int        Number of evaluation splits.                     Default is 20.\n--data_seed                   int        Data split seed for evaluation.                  Default is 1.\n```\n\n### Instructions for experiments\n\n##### Transductive task\n```\n# Coauthor CS\npython main.py --dataset coauthor_cs --graph_encoder_layer 512 256 --drop_edge_p 0.3 0.2 --feat_mask_p 0.3 0.4\n\n# Coauthor Physics\npython main.py --dataset coauthor_physics --graph_encoder_layer 256 128 --drop_edge_p 0.4 0.1 --feat_mask_p 0.1 0.4\n\n# WikiCS\npython main.py --dataset wiki_cs --graph_encoder_layer 512 256 --drop_edge_p 0.2 0.3 --feat_mask_p 0.2 0.1 --lr 5e-4\n\n# Amazon Photos\npython main.py --dataset amazon_photos --graph_encoder_layer 256 128 --drop_edge_p 0.4 0.1 --feat_mask_p 0.1 0.2 --lr 1e-4\n\n# Amazon Computers\npython main.py --dataset amazon_computers --graph_encoder_layer 256 128 --drop_edge_p 0.5 0.4 --feat_mask_p 0.2 0.1 --lr 5e-4\n```\n\n##### Inductive task\n```\n# PPI\npython main.py --dataset ppi --graph_encoder_layer 512 512 --drop_edge_p 0.3 0.25 --feat_mask_p 0.25 0. 
--lr 5e-3\n```\n\n### Performance\n\n##### Transductive Task\n|        Dataset         |    WikiCS    |  Am. Comp.   |  Am. Photos  |    Co. CS    |   Co. Phy    |\n|:----------------------:|:------------:|:------------:|:------------:|:------------:|:------------:|\n|   Accuracy Reported    | 79.98 ± 0.10 | 90.34 ± 0.19 | 93.17 ± 0.30 | 93.31 ± 0.13 | 95.73 ± 0.05 |\n| Accuracy Official Code |    79.94     |    90.62     |    93.45     |    93.42     |    95.74     |\n|      Accuracy DGL      |    80.00     |    90.64     |    93.34     |    93.76     |    95.79     |\n\n##### Inductive Task\n|        Dataset         |     PPI      |\n|:----------------------:|:------------:|\n|   Micro-F1 Reported    | 69.41 ± 0.15 |\n| Micro-F1 Official Code |    68.83     |\n|      Micro-F1 DGL      |    68.65     |\n\n\n##### Accuracy reported is over 20 random dataset splits and model initializations. Micro-F1 reported is over 20 random model initializations.\n\n##### Accuracy official code and Accuracy DGL are only over 1 random dataset split and model initialization. Micro-F1 official code and Micro-F1 DGL are only over 1 random model initialization."
  },
  {
    "path": "examples/pytorch/bgrl/eval_function.py",
    "content": "import numpy as np\nimport torch\nfrom sklearn import metrics\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.model_selection import GridSearchCV, ShuffleSplit, train_test_split\nfrom sklearn.multiclass import OneVsRestClassifier\nfrom sklearn.preprocessing import normalize, OneHotEncoder\n\n\ndef fit_logistic_regression(X, y, data_random_seed=1, repeat=1):\n    # transform targets to one-hot vector\n    one_hot_encoder = OneHotEncoder(categories=\"auto\", sparse=False)\n\n    y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).astype(np.bool_)\n\n    # normalize x\n    X = normalize(X, norm=\"l2\")\n\n    # set random state, this will ensure the dataset will be split exactly the same throughout training\n    rng = np.random.RandomState(data_random_seed)\n\n    accuracies = []\n    for _ in range(repeat):\n        # different random split after each repeat\n        X_train, X_test, y_train, y_test = train_test_split(\n            X, y, test_size=0.8, random_state=rng\n        )\n\n        # grid search with one-vs-rest classifiers\n        logreg = LogisticRegression(solver=\"liblinear\")\n        c = 2.0 ** np.arange(-10, 11)\n        cv = ShuffleSplit(n_splits=5, test_size=0.5)\n        clf = GridSearchCV(\n            estimator=OneVsRestClassifier(logreg),\n            param_grid=dict(estimator__C=c),\n            n_jobs=5,\n            cv=cv,\n            verbose=0,\n        )\n        clf.fit(X_train, y_train)\n\n        y_pred = clf.predict_proba(X_test)\n        y_pred = np.argmax(y_pred, axis=1)\n        y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(\n            np.bool_\n        )\n\n        test_acc = metrics.accuracy_score(y_test, y_pred)\n        accuracies.append(test_acc)\n    return accuracies\n\n\ndef fit_logistic_regression_preset_splits(\n    X, y, train_mask, val_mask, test_mask\n):\n    # transform targets to one-hot vector\n    one_hot_encoder = OneHotEncoder(categories=\"auto\", 
sparse=False)\n    y = one_hot_encoder.fit_transform(y.reshape(-1, 1)).astype(np.bool_)\n\n    # normalize x\n    X = normalize(X, norm=\"l2\")\n\n    accuracies = []\n    for split_id in range(train_mask.shape[1]):\n        # get train/val/test masks\n        tmp_train_mask, tmp_val_mask = (\n            train_mask[:, split_id],\n            val_mask[:, split_id],\n        )\n\n        # make custom cv\n        X_train, y_train = X[tmp_train_mask], y[tmp_train_mask]\n        X_val, y_val = X[tmp_val_mask], y[tmp_val_mask]\n        X_test, y_test = X[test_mask], y[test_mask]\n\n        # grid search with one-vs-rest classifiers\n        best_test_acc, best_acc = 0, 0\n        for c in 2.0 ** np.arange(-10, 11):\n            clf = OneVsRestClassifier(\n                LogisticRegression(solver=\"liblinear\", C=c)\n            )\n            clf.fit(X_train, y_train)\n\n            y_pred = clf.predict_proba(X_val)\n            y_pred = np.argmax(y_pred, axis=1)\n            y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(\n                np.bool_\n            )\n            val_acc = metrics.accuracy_score(y_val, y_pred)\n            if val_acc > best_acc:\n                best_acc = val_acc\n                y_pred = clf.predict_proba(X_test)\n                y_pred = np.argmax(y_pred, axis=1)\n                y_pred = one_hot_encoder.transform(\n                    y_pred.reshape(-1, 1)\n                ).astype(np.bool_)\n                best_test_acc = metrics.accuracy_score(y_test, y_pred)\n\n        accuracies.append(best_test_acc)\n    return accuracies\n\n\ndef fit_ppi_linear(\n    num_classes, train_data, val_data, test_data, device, repeat=1\n):\n    r\"\"\"\n    Trains a linear layer on top of the representations. 
This function is specific to the PPI dataset,\n    which has multiple labels.\n    \"\"\"\n\n    def train(classifier, train_data, optimizer):\n        classifier.train()\n\n        x, label = train_data\n        x, label = x.to(device), label.to(device)\n        for step in range(100):\n            # forward\n            optimizer.zero_grad()\n            pred_logits = classifier(x)\n\n            # loss and backprop\n            loss = criterion(pred_logits, label)\n            loss.backward()\n            optimizer.step()\n\n    def test(classifier, data):\n        classifier.eval()\n        x, label = data\n        label = label.cpu().numpy().squeeze()\n\n        # feed to network and classifier\n        with torch.no_grad():\n            pred_logits = classifier(x.to(device))\n            pred_class = (pred_logits > 0).float().cpu().numpy()\n\n        return (\n            metrics.f1_score(label, pred_class, average=\"micro\")\n            if pred_class.sum() > 0\n            else 0\n        )\n\n    num_feats = train_data[0].size(1)\n    criterion = torch.nn.BCEWithLogitsLoss()\n\n    # normalization\n    mean, std = train_data[0].mean(0, keepdim=True), train_data[0].std(\n        0, unbiased=False, keepdim=True\n    )\n    train_data[0] = (train_data[0] - mean) / std\n    val_data[0] = (val_data[0] - mean) / std\n    test_data[0] = (test_data[0] - mean) / std\n\n    best_val_f1 = []\n    test_f1 = []\n    for _ in range(repeat):\n        tmp_best_val_f1 = 0\n        tmp_test_f1 = 0\n        for weight_decay in 2.0 ** np.arange(-10, 11, 2):\n            classifier = torch.nn.Linear(num_feats, num_classes).to(device)\n            optimizer = torch.optim.AdamW(\n                params=classifier.parameters(),\n                lr=0.01,\n                weight_decay=weight_decay,\n            )\n\n            train(classifier, train_data, optimizer)\n            val_f1 = test(classifier, val_data)\n            if val_f1 > tmp_best_val_f1:\n                
tmp_best_val_f1 = val_f1\n                tmp_test_f1 = test(classifier, test_data)\n        best_val_f1.append(tmp_best_val_f1)\n        test_f1.append(tmp_test_f1)\n\n    return [best_val_f1], [test_f1]\n"
  },
  {
    "path": "examples/pytorch/bgrl/main.py",
    "content": "import copy\nimport os\nimport warnings\n\nimport dgl\n\nimport numpy as np\nimport torch\nfrom eval_function import (\n    fit_logistic_regression,\n    fit_logistic_regression_preset_splits,\n    fit_ppi_linear,\n)\nfrom model import (\n    BGRL,\n    compute_representations,\n    GCN,\n    GraphSAGE_GCN,\n    MLP_Predictor,\n)\nfrom torch.nn.functional import cosine_similarity\nfrom torch.optim import AdamW\nfrom tqdm import tqdm\nfrom utils import CosineDecayScheduler, get_dataset, get_graph_drop_transform\n\nwarnings.filterwarnings(\"ignore\")\n\n\ndef train(\n    step,\n    model,\n    optimizer,\n    lr_scheduler,\n    mm_scheduler,\n    transform_1,\n    transform_2,\n    data,\n    args,\n):\n    model.train()\n\n    # update learning rate\n    lr = lr_scheduler.get(step)\n    for param_group in optimizer.param_groups:\n        param_group[\"lr\"] = lr\n\n    # update momentum\n    mm = 1 - mm_scheduler.get(step)\n\n    # forward\n    optimizer.zero_grad()\n\n    x1, x2 = transform_1(data), transform_2(data)\n\n    if args.dataset != \"ppi\":\n        x1, x2 = dgl.add_self_loop(x1), dgl.add_self_loop(x2)\n\n    q1, y2 = model(x1, x2)\n    q2, y1 = model(x2, x1)\n\n    loss = (\n        2\n        - cosine_similarity(q1, y2.detach(), dim=-1).mean()\n        - cosine_similarity(q2, y1.detach(), dim=-1).mean()\n    )\n    loss.backward()\n\n    # update online network\n    optimizer.step()\n    # update target network\n    model.update_target_network(mm)\n\n    return loss.item()\n\n\ndef eval(model, dataset, device, args, train_data, val_data, test_data):\n    # make temporary copy of encoder\n    tmp_encoder = copy.deepcopy(model.online_encoder).eval()\n    val_scores = None\n\n    if args.dataset == \"ppi\":\n        train_data = compute_representations(tmp_encoder, train_data, device)\n        val_data = compute_representations(tmp_encoder, val_data, device)\n        test_data = compute_representations(tmp_encoder, test_data, device)\n     
   num_classes = train_data[1].shape[1]\n        val_scores, test_scores = fit_ppi_linear(\n            num_classes,\n            train_data,\n            val_data,\n            test_data,\n            device,\n            args.num_eval_splits,\n        )\n    elif args.dataset != \"wiki_cs\":\n        representations, labels = compute_representations(\n            tmp_encoder, dataset, device\n        )\n        test_scores = fit_logistic_regression(\n            representations.cpu().numpy(),\n            labels.cpu().numpy(),\n            data_random_seed=args.data_seed,\n            repeat=args.num_eval_splits,\n        )\n    else:\n        g = dataset[0]\n        train_mask = g.ndata[\"train_mask\"]\n        val_mask = g.ndata[\"val_mask\"]\n        test_mask = g.ndata[\"test_mask\"]\n        representations, labels = compute_representations(\n            tmp_encoder, dataset, device\n        )\n        test_scores = fit_logistic_regression_preset_splits(\n            representations.cpu().numpy(),\n            labels.cpu().numpy(),\n            train_mask,\n            val_mask,\n            test_mask,\n        )\n\n    return val_scores, test_scores\n\n\ndef main(args):\n    # use CUDA_VISIBLE_DEVICES to select gpu\n    device = (\n        torch.device(\"cuda\")\n        if torch.cuda.is_available()\n        else torch.device(\"cpu\")\n    )\n    print(\"Using device:\", device)\n\n    dataset, train_data, val_data, test_data = get_dataset(args.dataset)\n\n    g = dataset[0]\n    g = g.to(device)\n\n    input_size, representation_size = (\n        g.ndata[\"feat\"].size(1),\n        args.graph_encoder_layer[-1],\n    )\n\n    # prepare transforms\n    transform_1 = get_graph_drop_transform(\n        drop_edge_p=args.drop_edge_p[0], feat_mask_p=args.feat_mask_p[0]\n    )\n    transform_2 = get_graph_drop_transform(\n        drop_edge_p=args.drop_edge_p[1], feat_mask_p=args.feat_mask_p[1]\n    )\n\n    # scheduler\n    lr_scheduler = CosineDecayScheduler(\n   
     args.lr, args.lr_warmup_epochs, args.epochs\n    )\n    mm_scheduler = CosineDecayScheduler(1 - args.mm, 0, args.epochs)\n\n    # build networks\n    if args.dataset == \"ppi\":\n        encoder = GraphSAGE_GCN([input_size] + args.graph_encoder_layer)\n    else:\n        encoder = GCN([input_size] + args.graph_encoder_layer)\n    predictor = MLP_Predictor(\n        representation_size,\n        representation_size,\n        hidden_size=args.predictor_hidden_size,\n    )\n    model = BGRL(encoder, predictor).to(device)\n\n    # optimizer\n    optimizer = AdamW(\n        model.trainable_parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # train\n    for epoch in tqdm(range(1, args.epochs + 1), desc=\"  - (Training)  \"):\n        train(\n            epoch - 1,\n            model,\n            optimizer,\n            lr_scheduler,\n            mm_scheduler,\n            transform_1,\n            transform_2,\n            g,\n            args,\n        )\n        if epoch % args.eval_epochs == 0:\n            val_scores, test_scores = eval(\n                model, dataset, device, args, train_data, val_data, test_data\n            )\n            if args.dataset == \"ppi\":\n                print(\n                    \"Epoch: {:04d} | Best Val F1: {:.4f} | Test F1: {:.4f}\".format(\n                        epoch, np.mean(val_scores), np.mean(test_scores)\n                    )\n                )\n            else:\n                print(\n                    \"Epoch: {:04d} | Test Accuracy: {:.4f}\".format(\n                        epoch, np.mean(test_scores)\n                    )\n                )\n\n    # save encoder weights\n    if not os.path.isdir(args.weights_dir):\n        os.mkdir(args.weights_dir)\n    torch.save(\n        {\"model\": model.online_encoder.state_dict()},\n        os.path.join(args.weights_dir, \"bgrl-{}.pt\".format(args.dataset)),\n    )\n\n\nif __name__ == \"__main__\":\n    from argparse import ArgumentParser\n\n   
 parser = ArgumentParser()\n\n    # Dataset options.\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"amazon_photos\",\n        choices=[\n            \"coauthor_cs\",\n            \"coauthor_physics\",\n            \"amazon_photos\",\n            \"amazon_computers\",\n            \"wiki_cs\",\n            \"ppi\",\n        ],\n    )\n\n    # Model options.\n    parser.add_argument(\n        \"--graph_encoder_layer\", type=int, nargs=\"+\", default=[256, 128]\n    )\n    parser.add_argument(\"--predictor_hidden_size\", type=int, default=512)\n\n    # Training options.\n    parser.add_argument(\"--epochs\", type=int, default=10000)\n    parser.add_argument(\"--lr\", type=float, default=1e-5)\n    parser.add_argument(\"--weight_decay\", type=float, default=1e-5)\n    parser.add_argument(\"--mm\", type=float, default=0.99)\n    parser.add_argument(\"--lr_warmup_epochs\", type=int, default=1000)\n    parser.add_argument(\"--weights_dir\", type=str, default=\"../weights\")\n\n    # Augmentations options.\n    parser.add_argument(\n        \"--drop_edge_p\", type=float, nargs=\"+\", default=[0.0, 0.0]\n    )\n    parser.add_argument(\n        \"--feat_mask_p\", type=float, nargs=\"+\", default=[0.0, 0.0]\n    )\n\n    # Evaluation options.\n    parser.add_argument(\"--eval_epochs\", type=int, default=250)\n    parser.add_argument(\"--num_eval_splits\", type=int, default=20)\n    parser.add_argument(\"--data_seed\", type=int, default=1)\n\n    # Experiment options.\n    parser.add_argument(\"--num_experiments\", type=int, default=20)\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/bgrl/model.py",
    "content": "import copy\n\nimport dgl\n\nimport torch\nfrom dgl.nn.pytorch.conv import GraphConv, SAGEConv\nfrom torch import nn\nfrom torch.nn import BatchNorm1d, Parameter\nfrom torch.nn.init import ones_, zeros_\n\n\nclass LayerNorm(nn.Module):\n    def __init__(self, in_channels, eps=1e-5, affine=True):\n        super().__init__()\n        self.in_channels = in_channels\n        self.eps = eps\n\n        if affine:\n            self.weight = Parameter(torch.Tensor(in_channels))\n            self.bias = Parameter(torch.Tensor(in_channels))\n        else:\n            self.register_parameter(\"weight\", None)\n            self.register_parameter(\"bias\", None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        ones_(self.weight)\n        zeros_(self.bias)\n\n    def forward(self, x, batch=None):\n        device = x.device\n        if batch is None:\n            x = x - x.mean()\n            out = x / (x.std(unbiased=False) + self.eps)\n        else:\n            batch_size = int(batch.max()) + 1\n            batch_idx = [batch == i for i in range(batch_size)]\n            norm = (\n                torch.tensor([i.sum() for i in batch_idx], dtype=x.dtype)\n                .clamp_(min=1)\n                .to(device)\n            )\n            norm = norm.mul_(x.size(-1)).view(-1, 1)\n            tmp_list = [x[i] for i in batch_idx]\n            mean = (\n                torch.concat([i.sum(0).unsqueeze(0) for i in tmp_list], dim=0)\n                .sum(dim=-1, keepdim=True)\n                .to(device)\n            )\n            mean = mean / norm\n            x = x - mean.index_select(0, batch.long())\n            var = (\n                torch.concat(\n                    [(i * i).sum(0).unsqueeze(0) for i in tmp_list], dim=0\n                )\n                .sum(dim=-1, keepdim=True)\n                .to(device)\n            )\n            var = var / norm\n            out = x / (var + self.eps).sqrt().index_select(0, 
batch.long())\n\n        if self.weight is not None and self.bias is not None:\n            out = out * self.weight + self.bias\n\n        return out\n\n    def __repr__(self):\n        return f\"{self.__class__.__name__}({self.in_channels})\"\n\n\nclass MLP_Predictor(nn.Module):\n    r\"\"\"MLP used for predictor. The MLP has one hidden layer.\n    Args:\n        input_size (int): Size of input features.\n        output_size (int): Size of output features.\n        hidden_size (int, optional): Size of hidden layer. (default: :obj:`4096`).\n    \"\"\"\n\n    def __init__(self, input_size, output_size, hidden_size=512):\n        super().__init__()\n\n        self.net = nn.Sequential(\n            nn.Linear(input_size, hidden_size, bias=True),\n            nn.PReLU(1),\n            nn.Linear(hidden_size, output_size, bias=True),\n        )\n        self.reset_parameters()\n\n    def forward(self, x):\n        return self.net(x)\n\n    def reset_parameters(self):\n        # kaiming_uniform\n        for m in self.modules():\n            if isinstance(m, nn.Linear):\n                m.reset_parameters()\n\n\nclass GCN(nn.Module):\n    def __init__(self, layer_sizes, batch_norm_mm=0.99):\n        super(GCN, self).__init__()\n\n        self.layers = nn.ModuleList()\n        for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:]):\n            self.layers.append(GraphConv(in_dim, out_dim))\n            self.layers.append(BatchNorm1d(out_dim, momentum=batch_norm_mm))\n            self.layers.append(nn.PReLU())\n\n    def forward(self, g):\n        x = g.ndata[\"feat\"]\n        for layer in self.layers:\n            if isinstance(layer, GraphConv):\n                x = layer(g, x)\n            else:\n                x = layer(x)\n        return x\n\n    def reset_parameters(self):\n        for layer in self.layers:\n            if hasattr(layer, \"reset_parameters\"):\n                layer.reset_parameters()\n\n\nclass GraphSAGE_GCN(nn.Module):\n    def 
__init__(self, layer_sizes):\n        super().__init__()\n\n        input_size, hidden_size, embedding_size = layer_sizes\n\n        self.convs = nn.ModuleList(\n            [\n                SAGEConv(input_size, hidden_size, \"mean\"),\n                SAGEConv(hidden_size, hidden_size, \"mean\"),\n                SAGEConv(hidden_size, embedding_size, \"mean\"),\n            ]\n        )\n\n        self.skip_lins = nn.ModuleList(\n            [\n                nn.Linear(input_size, hidden_size, bias=False),\n                nn.Linear(input_size, hidden_size, bias=False),\n            ]\n        )\n\n        self.layer_norms = nn.ModuleList(\n            [\n                LayerNorm(hidden_size),\n                LayerNorm(hidden_size),\n                LayerNorm(embedding_size),\n            ]\n        )\n\n        self.activations = nn.ModuleList(\n            [\n                nn.PReLU(),\n                nn.PReLU(),\n                nn.PReLU(),\n            ]\n        )\n\n    def forward(self, g):\n        x = g.ndata[\"feat\"]\n        if \"batch\" in g.ndata.keys():\n            batch = g.ndata[\"batch\"]\n        else:\n            batch = None\n\n        h1 = self.convs[0](g, x)\n        h1 = self.layer_norms[0](h1, batch)\n        h1 = self.activations[0](h1)\n\n        x_skip_1 = self.skip_lins[0](x)\n        h2 = self.convs[1](g, h1 + x_skip_1)\n        h2 = self.layer_norms[1](h2, batch)\n        h2 = self.activations[1](h2)\n\n        x_skip_2 = self.skip_lins[1](x)\n        ret = self.convs[2](g, h1 + h2 + x_skip_2)\n        ret = self.layer_norms[2](ret, batch)\n        ret = self.activations[2](ret)\n        return ret\n\n    def reset_parameters(self):\n        for m in self.convs:\n            m.reset_parameters()\n        for m in self.skip_lins:\n            m.reset_parameters()\n        for m in self.activations:\n            m.weight.data.fill_(0.25)\n        for m in self.layer_norms:\n            m.reset_parameters()\n\n\nclass 
BGRL(nn.Module):\n    r\"\"\"BGRL architecture for Graph representation learning.\n    Args:\n        encoder (torch.nn.Module): Encoder network to be duplicated and used in both online and target networks.\n        predictor (torch.nn.Module): Predictor network used to predict the target projection from the online projection.\n    .. note::\n        `encoder` must have a `reset_parameters` method, as the weights of the target network will be initialized\n        differently from the online network.\n    \"\"\"\n\n    def __init__(self, encoder, predictor):\n        super(BGRL, self).__init__()\n        # online network\n        self.online_encoder = encoder\n        self.predictor = predictor\n\n        # target network\n        self.target_encoder = copy.deepcopy(encoder)\n\n        # reinitialize weights\n        self.target_encoder.reset_parameters()\n\n        # stop gradient\n        for param in self.target_encoder.parameters():\n            param.requires_grad = False\n\n    def trainable_parameters(self):\n        r\"\"\"Returns the parameters that will be updated via an optimizer.\"\"\"\n        return list(self.online_encoder.parameters()) + list(\n            self.predictor.parameters()\n        )\n\n    @torch.no_grad()\n    def update_target_network(self, mm):\n        r\"\"\"Performs a momentum update of the target network's weights.\n        Args:\n            mm (float): Momentum used in moving average update.\n        \"\"\"\n        for param_q, param_k in zip(\n            self.online_encoder.parameters(), self.target_encoder.parameters()\n        ):\n            param_k.data.mul_(mm).add_(param_q.data, alpha=1.0 - mm)\n\n    def forward(self, online_x, target_x):\n        # forward online network\n        online_y = self.online_encoder(online_x)\n\n        # prediction\n        online_q = self.predictor(online_y)\n\n        # forward target network\n        with torch.no_grad():\n            target_y = self.target_encoder(target_x).detach()\n   
     return online_q, target_y\n\n\ndef compute_representations(net, dataset, device):\n    r\"\"\"Pre-computes the representations for the entire data.\n    Returns:\n        [torch.Tensor, torch.Tensor]: Representations and labels.\n    \"\"\"\n    net.eval()\n    reps = []\n    labels = []\n\n    if len(dataset) == 1:\n        g = dataset[0]\n        g = dgl.add_self_loop(g)\n        g = g.to(device)\n        with torch.no_grad():\n            reps.append(net(g))\n            labels.append(g.ndata[\"label\"])\n    else:\n        for g in dataset:\n            # forward\n            g = g.to(device)\n            with torch.no_grad():\n                reps.append(net(g))\n                labels.append(g.ndata[\"label\"])\n\n    reps = torch.cat(reps, dim=0)\n    labels = torch.cat(labels, dim=0)\n    return [reps, labels]\n"
  },
  {
    "path": "examples/pytorch/bgrl/utils.py",
    "content": "import copy\n\nimport numpy as np\nimport torch\n\nfrom dgl.data import (\n    AmazonCoBuyComputerDataset,\n    AmazonCoBuyPhotoDataset,\n    CoauthorCSDataset,\n    CoauthorPhysicsDataset,\n    PPIDataset,\n    WikiCSDataset,\n)\nfrom dgl.dataloading import GraphDataLoader\nfrom dgl.transforms import Compose, DropEdge, FeatMask, RowFeatNormalizer\n\n\nclass CosineDecayScheduler:\n    def __init__(self, max_val, warmup_steps, total_steps):\n        self.max_val = max_val\n        self.warmup_steps = warmup_steps\n        self.total_steps = total_steps\n\n    def get(self, step):\n        if step < self.warmup_steps:\n            return self.max_val * step / self.warmup_steps\n        elif self.warmup_steps <= step <= self.total_steps:\n            return (\n                self.max_val\n                * (\n                    1\n                    + np.cos(\n                        (step - self.warmup_steps)\n                        * np.pi\n                        / (self.total_steps - self.warmup_steps)\n                    )\n                )\n                / 2\n            )\n        else:\n            raise ValueError(\n                \"Step ({}) > total number of steps ({}).\".format(\n                    step, self.total_steps\n                )\n            )\n\n\ndef get_graph_drop_transform(drop_edge_p, feat_mask_p):\n    transforms = list()\n\n    # make copy of graph\n    transforms.append(copy.deepcopy)\n\n    # drop edges\n    if drop_edge_p > 0.0:\n        transforms.append(DropEdge(drop_edge_p))\n\n    # drop features\n    if feat_mask_p > 0.0:\n        transforms.append(FeatMask(feat_mask_p, node_feat_names=[\"feat\"]))\n\n    return Compose(transforms)\n\n\ndef get_wiki_cs(transform=RowFeatNormalizer(subtract_min=True)):\n    dataset = WikiCSDataset(transform=transform)\n    g = dataset[0]\n    std, mean = torch.std_mean(g.ndata[\"feat\"], dim=0, unbiased=False)\n    g.ndata[\"feat\"] = (g.ndata[\"feat\"] - mean) / std\n\n    
return [g]\n\n\ndef get_ppi():\n    train_dataset = PPIDataset(mode=\"train\")\n    val_dataset = PPIDataset(mode=\"valid\")\n    test_dataset = PPIDataset(mode=\"test\")\n    train_val_dataset = [i for i in train_dataset] + [i for i in val_dataset]\n    for idx, data in enumerate(train_val_dataset):\n        data.ndata[\"batch\"] = torch.zeros(data.num_nodes()) + idx\n        data.ndata[\"batch\"] = data.ndata[\"batch\"].long()\n\n    g = list(GraphDataLoader(train_val_dataset, batch_size=22, shuffle=True))\n\n    return g, PPIDataset(mode=\"train\"), PPIDataset(mode=\"valid\"), test_dataset\n\n\ndef get_dataset(name, transform=RowFeatNormalizer(subtract_min=True)):\n    dgl_dataset_dict = {\n        \"coauthor_cs\": CoauthorCSDataset,\n        \"coauthor_physics\": CoauthorPhysicsDataset,\n        \"amazon_computers\": AmazonCoBuyComputerDataset,\n        \"amazon_photos\": AmazonCoBuyPhotoDataset,\n        \"wiki_cs\": get_wiki_cs,\n        \"ppi\": get_ppi,\n    }\n\n    dataset_class = dgl_dataset_dict[name]\n    train_data, val_data, test_data = None, None, None\n    if name != \"ppi\":\n        dataset = dataset_class(transform=transform)\n    else:\n        dataset, train_data, val_data, test_data = dataset_class()\n\n    return dataset, train_data, val_data, test_data\n"
  },
  {
    "path": "examples/pytorch/capsule/DGLDigitCapsule.py",
    "content": "import dgl\nimport dgl.function as fn\nimport torch\nfrom DGLRoutingLayer import DGLRoutingLayer\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass DGLDigitCapsuleLayer(nn.Module):\n    def __init__(\n        self,\n        in_nodes_dim=8,\n        in_nodes=1152,\n        out_nodes=10,\n        out_nodes_dim=16,\n        device=\"cpu\",\n    ):\n        super(DGLDigitCapsuleLayer, self).__init__()\n        self.device = device\n        self.in_nodes_dim, self.out_nodes_dim = in_nodes_dim, out_nodes_dim\n        self.in_nodes, self.out_nodes = in_nodes, out_nodes\n        self.weight = nn.Parameter(\n            torch.randn(in_nodes, out_nodes, out_nodes_dim, in_nodes_dim)\n        )\n\n    def forward(self, x):\n        self.batch_size = x.size(0)\n        u_hat = self.compute_uhat(x)\n        routing = DGLRoutingLayer(\n            self.in_nodes,\n            self.out_nodes,\n            self.out_nodes_dim,\n            batch_size=self.batch_size,\n            device=self.device,\n        )\n        routing(u_hat, routing_num=3)\n        out_nodes_feature = routing.g.nodes[routing.out_indx].data[\"v\"]\n        # shape transformation is for further classification\n        return (\n            out_nodes_feature.transpose(0, 1)\n            .unsqueeze(1)\n            .unsqueeze(4)\n            .squeeze(1)\n        )\n\n    def compute_uhat(self, x):\n        # x is the input vextor with shape [batch_size, in_nodes_dim, in_nodes]\n        # Transpose x to [batch_size, in_nodes, in_nodes_dim]\n        x = x.transpose(1, 2)\n        # Expand x to [batch_size, in_nodes, out_nodes, in_nodes_dim, 1]\n        x = torch.stack([x] * self.out_nodes, dim=2).unsqueeze(4)\n        # Expand W from [in_nodes, out_nodes, in_nodes_dim, out_nodes_dim]\n        # to [batch_size, in_nodes, out_nodes, out_nodes_dim, in_nodes_dim]\n        W = self.weight.expand(self.batch_size, *self.weight.size())\n        # u_hat's shape is [in_nodes, out_nodes, 
batch_size, out_nodes_dim]\n        u_hat = torch.matmul(W, x).permute(1, 2, 0, 3, 4).squeeze().contiguous()\n        return u_hat.view(-1, self.batch_size, self.out_nodes_dim)\n"
  },
  {
    "path": "examples/pytorch/capsule/DGLRoutingLayer.py",
    "content": "import dgl\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass DGLRoutingLayer(nn.Module):\n    def __init__(self, in_nodes, out_nodes, f_size, batch_size=0, device=\"cpu\"):\n        super(DGLRoutingLayer, self).__init__()\n        self.batch_size = batch_size\n        self.g = init_graph(in_nodes, out_nodes, f_size, device=device)\n        self.in_nodes = in_nodes\n        self.out_nodes = out_nodes\n        self.in_indx = list(range(in_nodes))\n        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))\n        self.device = device\n\n    def forward(self, u_hat, routing_num=1):\n        self.g.edata[\"u_hat\"] = u_hat\n        batch_size = self.batch_size\n\n        # step 2 (line 5)\n        def cap_message(edges):\n            if batch_size:\n                return {\"m\": edges.data[\"c\"].unsqueeze(1) * edges.data[\"u_hat\"]}\n            else:\n                return {\"m\": edges.data[\"c\"] * edges.data[\"u_hat\"]}\n\n        def cap_reduce(nodes):\n            return {\"s\": th.sum(nodes.mailbox[\"m\"], dim=1)}\n\n        for r in range(routing_num):\n            # step 1 (line 4): normalize over out edges\n            edges_b = self.g.edata[\"b\"].view(self.in_nodes, self.out_nodes)\n            self.g.edata[\"c\"] = F.softmax(edges_b, dim=1).view(-1, 1)\n\n            # Execute step 1 & 2\n            self.g.update_all(message_func=cap_message, reduce_func=cap_reduce)\n\n            # step 3 (line 6)\n            if self.batch_size:\n                self.g.nodes[self.out_indx].data[\"v\"] = squash(\n                    self.g.nodes[self.out_indx].data[\"s\"], dim=2\n                )\n            else:\n                self.g.nodes[self.out_indx].data[\"v\"] = squash(\n                    self.g.nodes[self.out_indx].data[\"s\"], dim=1\n                )\n            # step 4 (line 7)\n            v = th.cat(\n                [self.g.nodes[self.out_indx].data[\"v\"]] * self.in_nodes, dim=0\n   
         )\n            if self.batch_size:\n                self.g.edata[\"b\"] = self.g.edata[\"b\"] + (\n                    self.g.edata[\"u_hat\"] * v\n                ).mean(dim=1).sum(dim=1, keepdim=True)\n            else:\n                self.g.edata[\"b\"] = self.g.edata[\"b\"] + (\n                    self.g.edata[\"u_hat\"] * v\n                ).sum(dim=1, keepdim=True)\n\n\ndef squash(s, dim=1):\n    sq = th.sum(s**2, dim=dim, keepdim=True)\n    s_norm = th.sqrt(sq)\n    s = (sq / (1.0 + sq)) * (s / s_norm)\n    return s\n\n\ndef init_graph(in_nodes, out_nodes, f_size, device=\"cpu\"):\n    src, dst = [], []\n    in_indx = list(range(in_nodes))\n    out_indx = list(range(in_nodes, in_nodes + out_nodes))\n    # add edges use edge broadcasting\n    for u in in_indx:\n        src += [u] * len(out_indx)\n        dst += out_indx\n\n    g = dgl.graph((src, dst))  # dgl.graph once;\n    g.set_n_initializer(dgl.frame.zero_initializer)\n    g = g.to(device)\n    g.edata[\"b\"] = th.zeros(in_nodes * out_nodes, 1).to(device)\n    return g\n"
  },
  {
    "path": "examples/pytorch/capsule/README.md",
    "content": "DGL implementation of Capsule Network\n=====================================\n\nThis repo implements Hinton and his team's [Capsule Network](https://arxiv.org/abs/1710.09829).\nOnly margin loss is implemented, for simplicity to understand the DGL.\n\nDependencies\n--------------\n* PyTorch 0.4.1+\n* torchvision\n\n```bash\npip install torch torchvision\n```\n\nTraining & Evaluation\n----------------------\n```bash\n# Run with default config\npython3 main.py\n# Run with train and test batch size 128, and for 50 epochs\npython3 main.py --batch-size 128 --test-batch-size 128 --epochs 50\n```\n"
  },
  {
    "path": "examples/pytorch/capsule/main.py",
    "content": "import argparse\r\n\r\nimport torch\r\nimport torch.optim as optim\r\nfrom model import Net\r\nfrom torchvision import datasets, transforms\r\n\r\n\r\ndef train(args, model, device, train_loader, optimizer, epoch):\r\n    model.train()\r\n    for batch_idx, (data, target) in enumerate(train_loader):\r\n        data, target = data.to(device), target.to(device)\r\n        optimizer.zero_grad()\r\n        output = model(data)\r\n        loss = model.margin_loss(output, target)\r\n        loss.backward()\r\n        optimizer.step()\r\n        if batch_idx % args.log_interval == 0:\r\n            print(\r\n                \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\r\n                    epoch,\r\n                    batch_idx * len(data),\r\n                    len(train_loader.dataset),\r\n                    100.0 * batch_idx / len(train_loader),\r\n                    loss.item(),\r\n                )\r\n            )\r\n\r\n\r\ndef test(args, model, device, test_loader):\r\n    model.eval()\r\n    test_loss = 0\r\n    correct = 0\r\n    with torch.no_grad():\r\n        for data, target in test_loader:\r\n            data, target = data.to(device), target.to(device)\r\n            output = model(data)\r\n            test_loss += model.margin_loss(\r\n                output, target\r\n            ).item()  # sum up batch loss\r\n            pred = (\r\n                output.norm(dim=2).squeeze().max(1, keepdim=True)[1]\r\n            )  # get the index of the max log-probability\r\n            correct += pred.eq(target.view_as(pred)).sum().item()\r\n\r\n    test_loss /= len(test_loader.dataset)\r\n    print(\r\n        \"\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n\".format(\r\n            test_loss,\r\n            correct,\r\n            len(test_loader.dataset),\r\n            100.0 * correct / len(test_loader.dataset),\r\n        )\r\n    )\r\n\r\n\r\ndef main():\r\n    # Training settings\r\n    parser = 
argparse.ArgumentParser(description=\"PyTorch MNIST Example\")\r\n    parser.add_argument(\r\n        \"--batch-size\",\r\n        type=int,\r\n        default=512,\r\n        metavar=\"N\",\r\n        help=\"input batch size for training (default: 64)\",\r\n    )\r\n    parser.add_argument(\r\n        \"--test-batch-size\",\r\n        type=int,\r\n        default=512,\r\n        metavar=\"N\",\r\n        help=\"input batch size for testing (default: 1000)\",\r\n    )\r\n    parser.add_argument(\r\n        \"--epochs\",\r\n        type=int,\r\n        default=10,\r\n        metavar=\"N\",\r\n        help=\"number of epochs to train (default: 10)\",\r\n    )\r\n    parser.add_argument(\r\n        \"--lr\",\r\n        type=float,\r\n        default=0.01,\r\n        metavar=\"LR\",\r\n        help=\"learning rate (default: 0.01)\",\r\n    )\r\n    parser.add_argument(\r\n        \"--no-cuda\",\r\n        action=\"store_true\",\r\n        default=False,\r\n        help=\"disables CUDA training\",\r\n    )\r\n    parser.add_argument(\r\n        \"--seed\",\r\n        type=int,\r\n        default=1,\r\n        metavar=\"S\",\r\n        help=\"random seed (default: 1)\",\r\n    )\r\n    parser.add_argument(\r\n        \"--log-interval\",\r\n        type=int,\r\n        default=10,\r\n        metavar=\"N\",\r\n        help=\"how many batches to wait before logging training status\",\r\n    )\r\n    args = parser.parse_args()\r\n    use_cuda = not args.no_cuda and torch.cuda.is_available()\r\n\r\n    torch.manual_seed(args.seed)\r\n\r\n    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\r\n\r\n    kwargs = {\"num_workers\": 1, \"pin_memory\": True} if use_cuda else {}\r\n    train_loader = torch.utils.data.DataLoader(\r\n        datasets.MNIST(\r\n            \"../data\",\r\n            train=True,\r\n            download=True,\r\n            transform=transforms.Compose(\r\n                [\r\n                    transforms.ToTensor(),\r\n                    
transforms.Normalize((0.1307,), (0.3081,)),\r\n                ]\r\n            ),\r\n        ),\r\n        batch_size=args.batch_size,\r\n        shuffle=True,\r\n        **kwargs\r\n    )\r\n    test_loader = torch.utils.data.DataLoader(\r\n        datasets.MNIST(\r\n            \"../data\",\r\n            train=False,\r\n            transform=transforms.Compose(\r\n                [\r\n                    transforms.ToTensor(),\r\n                    transforms.Normalize((0.1307,), (0.3081,)),\r\n                ]\r\n            ),\r\n        ),\r\n        batch_size=args.test_batch_size,\r\n        shuffle=True,\r\n        **kwargs\r\n    )\r\n\r\n    model = Net(device=device).to(device)\r\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\r\n\r\n    for epoch in range(1, args.epochs + 1):\r\n        train(args, model, device, train_loader, optimizer, epoch)\r\n        test(args, model, device, test_loader)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "examples/pytorch/capsule/model.py",
    "content": "import torch\r\nfrom DGLDigitCapsule import DGLDigitCapsuleLayer\r\nfrom DGLRoutingLayer import squash\r\nfrom torch import nn\r\n\r\n\r\nclass Net(nn.Module):\r\n    def __init__(self, device=\"cpu\"):\r\n        super(Net, self).__init__()\r\n        self.device = device\r\n        self.conv1 = nn.Sequential(\r\n            nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1),\r\n            nn.ReLU(inplace=True),\r\n        )\r\n\r\n        self.primary = PrimaryCapsuleLayer(device=device)\r\n        self.digits = DGLDigitCapsuleLayer(device=device)\r\n\r\n    def forward(self, x):\r\n        out_conv1 = self.conv1(x)\r\n        out_primary_caps = self.primary(out_conv1)\r\n        out_digit_caps = self.digits(out_primary_caps)\r\n        return out_digit_caps\r\n\r\n    def margin_loss(self, input, target):\r\n        batch_s = target.size(0)\r\n        one_hot_vec = torch.zeros(batch_s, 10).to(self.device)\r\n        for i in range(batch_s):\r\n            one_hot_vec[i, target[i]] = 1.0\r\n        batch_size = input.size(0)\r\n        v_c = torch.sqrt((input**2).sum(dim=2, keepdim=True))\r\n        zero = torch.zeros(1).to(self.device)\r\n        m_plus = 0.9\r\n        m_minus = 0.1\r\n        loss_lambda = 0.5\r\n        max_left = torch.max(m_plus - v_c, zero).view(batch_size, -1) ** 2\r\n        max_right = torch.max(v_c - m_minus, zero).view(batch_size, -1) ** 2\r\n        t_c = one_hot_vec\r\n        l_c = t_c * max_left + loss_lambda * (1.0 - t_c) * max_right\r\n        l_c = l_c.sum(dim=1)\r\n        return l_c.mean()\r\n\r\n\r\nclass PrimaryCapsuleLayer(nn.Module):\r\n    def __init__(self, in_channel=256, num_unit=8, device=\"cpu\"):\r\n        super(PrimaryCapsuleLayer, self).__init__()\r\n        self.in_channel = in_channel\r\n        self.num_unit = num_unit\r\n        self.deivce = device\r\n        self.conv_units = nn.ModuleList(\r\n            [nn.Conv2d(self.in_channel, 32, 9, 2) for _ in 
range(self.num_unit)]\r\n        )\r\n\r\n    def forward(self, x):\r\n        unit = [self.conv_units[i](x) for i, l in enumerate(self.conv_units)]\r\n        unit = torch.stack(unit, dim=1)\r\n        batch_size = x.size(0)\r\n        unit = unit.view(batch_size, 8, -1)\r\n        return squash(unit, dim=2)\r\n"
  },
  {
    "path": "examples/pytorch/capsule/simple_routing.py",
    "content": "import dgl\nimport torch as th\nimport torch.nn as nn\nfrom DGLRoutingLayer import DGLRoutingLayer\nfrom torch.nn import functional as F\n\ng = dgl.DGLGraph()\ng.graph_data = {}\n\nin_nodes = 20\nout_nodes = 10\ng.graph_data[\"in_nodes\"] = in_nodes\ng.graph_data[\"out_nodes\"] = out_nodes\nall_nodes = in_nodes + out_nodes\ng.add_nodes(all_nodes)\n\n\nin_indx = list(range(in_nodes))\nout_indx = list(range(in_nodes, in_nodes + out_nodes))\ng.graph_data[\"in_indx\"] = in_indx\ng.graph_data[\"out_indx\"] = out_indx\n\n# add edges use edge broadcasting\nfor u in out_indx:\n    g.add_edges(in_indx, u)\n# init states\nf_size = 4\ng.ndata[\"v\"] = th.zeros(all_nodes, f_size)\ng.edata[\"u_hat\"] = th.randn(in_nodes * out_nodes, f_size)\ng.edata[\"b\"] = th.randn(in_nodes * out_nodes, 1)\n\nrouting_layer = DGLRoutingLayer(g)\n\nentropy_list = []\nfor i in range(15):\n    routing_layer()\n    dist_matrix = g.edata[\"c\"].view(in_nodes, out_nodes)\n    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=0)\n    entropy_list.append(entropy.data.numpy())\n    std = dist_matrix.std(dim=0)\n"
  },
  {
    "path": "examples/pytorch/caregnn/README.md",
    "content": "# DGL Implementation of the CARE-GNN Paper\n\nThis DGL example implements the CAmouflage-REsistant GNN (CARE-GNN) model proposed in the paper [Enhancing Graph Neural Network-based Fraud Detectors against Camouflaged Fraudsters](https://arxiv.org/abs/2008.08692). The author's codes of implementation is [here](https://github.com/YingtongDou/CARE-GNN).\n\n**NOTE**: The sampling version of this model has been modified according to the feature of the DGL's NodeDataLoader. For the formula 2 in the paper, rather than using the embedding of the last layer, this version uses the embedding of the current layer in the previous epoch to measure the similarity between center nodes and their neighbors.\n\nExample implementor\n----------------------\nThis example was implemented by [Kay Liu](https://github.com/kayzliu) during his SDE intern work at the AWS Shanghai AI Lab.\n\nDependencies\n----------------------\n- Python 3.7.10\n- PyTorch 1.8.1\n- dgl 0.7.1\n- scikit-learn 0.23.2\n\nDataset\n---------------------------------------\nThe datasets used for node classification are DGL's built-in FraudDataset. 
The statistics are summarized as follows:\n\n**Amazon**\n\n- Nodes: 11,944\n- Edges:\n    - U-P-U: 351,216\n    - U-S-U: 7,132,958\n    - U-V-U: 2,073,474\n- Classes:\n    - Positive (fraudulent): 821\n    - Negative (benign): 7,818\n    - Unlabeled: 3,305\n- Positive-Negative ratio: 1 : 10.5\n- Node feature size: 25\n\n**YelpChi**\n\n- Nodes: 45,954\n- Edges:\n    - R-U-R: 98,630\n    - R-T-R: 1,147,232\n    - R-S-R: 6,805,486\n- Classes:\n    - Positive (spam): 6,677\n    - Negative (legitimate): 39,277\n- Positive-Negative ratio: 1 : 5.9\n- Node feature size: 32\n\nHow to run\n--------------------------------\nTo run the full graph version and use early stopping, in the caregnn folder, run\n```\npython main.py --early-stop\n```\n\nIf you want to use a GPU, run\n```\npython main.py --gpu 0\n```\n\nTo train on the Yelp dataset instead of Amazon, run\n```\npython main.py --dataset yelp\n```\n\nTo run the sampling version, run\n```\npython main_sampling.py\n```\n\nPerformance\n-------------------------\nThe results reported in the paper are the best validation results within 30 epochs, and the table below reports the val and test results (same settings as in the paper except for the random seed, here `seed=717`). 
\n\n<table>\n<thead>\n  <tr>\n    <th colspan=\"2\">Dataset</th>\n    <th>Amazon</th>\n    <th>Yelp</th>\n  </tr>\n</thead>\n<tbody>\n  <tr>\n    <td>Metric (val / test)</td>\n    <td>Max Epoch</td>\n    <td>30</td>\n    <td>30 </td>\n  </tr>\n  <tr>\n    <td rowspan=\"3\">AUC (val/test)</td>\n    <td>paper reported</td>\n    <td>0.8973 / -</td>\n    <td>0.7570 / -</td>\n  </tr>\n  <tr>\n    <td>DGL full graph</td>\n    <td>0.8849 / 0.8922</td>\n    <td>0.6856 / 0.6867</td>\n  </tr>\n  <tr>\n    <td>DGL sampling</td>\n    <td>0.9350 / 0.9331</td>\n    <td>0.7857 / 0.7890</td>\n  </tr>\n  <tr>\n    <td rowspan=\"3\">Recall (val/test)</td>\n    <td>paper reported</td>\n    <td>0.8848 / -</td>\n    <td>0.7192 / -</td>\n  </tr>\n  <tr>\n    <td>DGL full graph</td>\n    <td>0.8615 / 0.8544</td>\n    <td>0.6667/ 0.6619</td>\n  </tr>\n  <tr>\n    <td>DGL sampling</td>\n    <td>0.9130 / 0.9045</td>\n    <td>0.7537 / 0.7540</td>\n  </tr>\n</tbody>\n</table>\n\n\n\n"
  },
  {
    "path": "examples/pytorch/caregnn/main.py",
    "content": "import argparse\n\nimport dgl\n\nimport torch as th\nimport torch.optim as optim\nfrom model import CAREGNN\nfrom sklearn.metrics import recall_score, roc_auc_score\nfrom torch.nn.functional import softmax\nfrom utils import EarlyStopping\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # Load dataset\n    dataset = dgl.data.FraudDataset(args.dataset, train_size=0.4)\n    graph = dataset[0]\n    num_classes = dataset.num_classes\n\n    # check cuda\n    if args.gpu >= 0 and th.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n    else:\n        device = \"cpu\"\n\n    # retrieve labels of ground truth\n    labels = graph.ndata[\"label\"].to(device)\n\n    # Extract node features\n    feat = graph.ndata[\"feature\"].to(device)\n\n    # retrieve masks for train/validation/test\n    train_mask = graph.ndata[\"train_mask\"]\n    val_mask = graph.ndata[\"val_mask\"]\n    test_mask = graph.ndata[\"test_mask\"]\n\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)\n    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)\n\n    # Reinforcement learning module only for positive training nodes\n    rl_idx = th.nonzero(\n        train_mask.to(device) & labels.bool(), as_tuple=False\n    ).squeeze(1)\n\n    graph = graph.to(device)\n\n    # Step 2: Create model =================================================================== #\n    model = CAREGNN(\n        in_dim=feat.shape[-1],\n        num_classes=num_classes,\n        hid_dim=args.hid_dim,\n        num_layers=args.num_layers,\n        activation=th.tanh,\n        step_size=args.step_size,\n        edges=graph.canonical_etypes,\n    )\n\n    model = model.to(device)\n\n    # Step 3: Create training components ===================================================== #\n    _, cnt = 
th.unique(labels, return_counts=True)\n    loss_fn = th.nn.CrossEntropyLoss(weight=1 / cnt)\n    optimizer = optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n    if args.early_stop:\n        stopper = EarlyStopping(patience=100)\n\n    # Step 4: training epochs =============================================================== #\n    for epoch in range(args.max_epoch):\n        # Training and validation using a full graph\n        model.train()\n        logits_gnn, logits_sim = model(graph, feat)\n\n        # compute loss\n        tr_loss = loss_fn(\n            logits_gnn[train_idx], labels[train_idx]\n        ) + args.sim_weight * loss_fn(logits_sim[train_idx], labels[train_idx])\n\n        tr_recall = recall_score(\n            labels[train_idx].cpu(),\n            logits_gnn.data[train_idx].argmax(dim=1).cpu(),\n        )\n        tr_auc = roc_auc_score(\n            labels[train_idx].cpu(),\n            softmax(logits_gnn, dim=1).data[train_idx][:, 1].cpu(),\n        )\n\n        # validation\n        val_loss = loss_fn(\n            logits_gnn[val_idx], labels[val_idx]\n        ) + args.sim_weight * loss_fn(logits_sim[val_idx], labels[val_idx])\n        val_recall = recall_score(\n            labels[val_idx].cpu(), logits_gnn.data[val_idx].argmax(dim=1).cpu()\n        )\n        val_auc = roc_auc_score(\n            labels[val_idx].cpu(),\n            softmax(logits_gnn, dim=1).data[val_idx][:, 1].cpu(),\n        )\n\n        # backward\n        optimizer.zero_grad()\n        tr_loss.backward()\n        optimizer.step()\n\n        # Print out performance\n        print(\n            \"Epoch {}, Train: Recall: {:.4f} AUC: {:.4f} Loss: {:.4f} | Val: Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}\".format(\n                epoch,\n                tr_recall,\n                tr_auc,\n                tr_loss.item(),\n                val_recall,\n                val_auc,\n                val_loss.item(),\n            )\n        
)\n\n        # Adjust p value with reinforcement learning module\n        model.RLModule(graph, epoch, rl_idx)\n\n        if args.early_stop:\n            if stopper.step(val_auc, model):\n                break\n\n    # Test after all epoch\n    model.eval()\n    if args.early_stop:\n        model.load_state_dict(th.load(\"es_checkpoint.pt\"))\n\n    # forward\n    logits_gnn, logits_sim = model.forward(graph, feat)\n\n    # compute loss\n    test_loss = loss_fn(\n        logits_gnn[test_idx], labels[test_idx]\n    ) + args.sim_weight * loss_fn(logits_sim[test_idx], labels[test_idx])\n    test_recall = recall_score(\n        labels[test_idx].cpu(), logits_gnn[test_idx].argmax(dim=1).cpu()\n    )\n    test_auc = roc_auc_score(\n        labels[test_idx].cpu(),\n        softmax(logits_gnn, dim=1).data[test_idx][:, 1].cpu(),\n    )\n\n    print(\n        \"Test Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}\".format(\n            test_recall, test_auc, test_loss.item()\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GCN-based Anti-Spam Model\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"amazon\",\n        help=\"DGL dataset for this model (yelp, or amazon)\",\n    )\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU index. Default: -1, using CPU.\"\n    )\n    parser.add_argument(\n        \"--hid_dim\", type=int, default=64, help=\"Hidden layer dimension\"\n    )\n    parser.add_argument(\n        \"--num_layers\", type=int, default=1, help=\"Number of layers\"\n    )\n    parser.add_argument(\n        \"--max_epoch\",\n        type=int,\n        default=30,\n        help=\"The max number of epochs. Default: 30\",\n    )\n    parser.add_argument(\n        \"--lr\", type=float, default=0.01, help=\"Learning rate. 
Default: 0.01\"\n    )\n    parser.add_argument(\n        \"--weight_decay\",\n        type=float,\n        default=0.001,\n        help=\"Weight decay. Default: 0.001\",\n    )\n    parser.add_argument(\n        \"--step_size\",\n        type=float,\n        default=0.02,\n        help=\"RL action step size (lambda 2). Default: 0.02\",\n    )\n    parser.add_argument(\n        \"--sim_weight\",\n        type=float,\n        default=2,\n        help=\"Similarity loss weight (lambda 1). Default: 2\",\n    )\n    parser.add_argument(\n        \"--early-stop\",\n        action=\"store_true\",\n        default=False,\n        help=\"indicates whether to use early stop\",\n    )\n\n    args = parser.parse_args()\n    print(args)\n    th.manual_seed(717)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/caregnn/main_sampling.py",
    "content": "import argparse\n\nimport dgl\n\nimport torch as th\nimport torch.optim as optim\nfrom model_sampling import _l1_dist, CAREGNN, CARESampler\nfrom sklearn.metrics import recall_score, roc_auc_score\nfrom torch.nn.functional import softmax\nfrom utils import EarlyStopping\n\n\ndef evaluate(model, loss_fn, dataloader, device=\"cpu\"):\n    loss = 0\n    auc = 0\n    recall = 0\n    num_blocks = 0\n    for input_nodes, output_nodes, blocks in dataloader:\n        blocks = [b.to(device) for b in blocks]\n        feature = blocks[0].srcdata[\"feature\"]\n        label = blocks[-1].dstdata[\"label\"]\n        logits_gnn, logits_sim = model(blocks, feature)\n\n        # compute loss\n        loss += (\n            loss_fn(logits_gnn, label).item()\n            + args.sim_weight * loss_fn(logits_sim, label).item()\n        )\n        recall += recall_score(\n            label.cpu(), logits_gnn.argmax(dim=1).detach().cpu()\n        )\n        auc += roc_auc_score(\n            label.cpu(), softmax(logits_gnn, dim=1)[:, 1].detach().cpu()\n        )\n        num_blocks += 1\n\n    return recall / num_blocks, auc / num_blocks, loss / num_blocks\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # Load dataset\n    dataset = dgl.data.FraudDataset(args.dataset, train_size=0.4)\n    graph = dataset[0]\n    num_classes = dataset.num_classes\n\n    # check cuda\n    if args.gpu >= 0 and th.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n        args.num_workers = 0\n    else:\n        device = \"cpu\"\n\n    # retrieve labels of ground truth\n    labels = graph.ndata[\"label\"].to(device)\n\n    # Extract node features\n    feat = graph.ndata[\"feature\"].to(device)\n    layers_feat = feat.expand(args.num_layers, -1, -1)\n\n    # retrieve masks for train/validation/test\n    train_mask = graph.ndata[\"train_mask\"]\n    val_mask = graph.ndata[\"val_mask\"]\n    
test_mask = graph.ndata[\"test_mask\"]\n\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)\n    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)\n\n    # Reinforcement learning module only for positive training nodes\n    rl_idx = th.nonzero(\n        train_mask.to(device) & labels.bool(), as_tuple=False\n    ).squeeze(1)\n\n    graph = graph.to(device)\n\n    # Step 2: Create model =================================================================== #\n    model = CAREGNN(\n        in_dim=feat.shape[-1],\n        num_classes=num_classes,\n        hid_dim=args.hid_dim,\n        num_layers=args.num_layers,\n        activation=th.tanh,\n        step_size=args.step_size,\n        edges=graph.canonical_etypes,\n    )\n\n    model = model.to(device)\n\n    # Step 3: Create training components ===================================================== #\n    _, cnt = th.unique(labels, return_counts=True)\n    loss_fn = th.nn.CrossEntropyLoss(weight=1 / cnt)\n    optimizer = optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n    if args.early_stop:\n        stopper = EarlyStopping(patience=100)\n\n    # Step 4: training epochs =============================================================== #\n    for epoch in range(args.max_epoch):\n        # calculate the distance of each edges and sample based on the distance\n        dists = []\n        p = []\n        for i in range(args.num_layers):\n            dist = {}\n            graph.ndata[\"nd\"] = th.tanh(model.layers[i].MLP(layers_feat[i]))\n            for etype in graph.canonical_etypes:\n                graph.apply_edges(_l1_dist, etype=etype)\n                dist[etype] = graph.edges[etype].data.pop(\"ed\").detach().cpu()\n            dists.append(dist)\n            p.append(model.layers[i].p)\n        graph.ndata.pop(\"nd\")\n        sampler = CARESampler(p, 
dists, args.num_layers)\n\n        # train\n        model.train()\n        tr_loss = 0\n        tr_recall = 0\n        tr_auc = 0\n        tr_blk = 0\n        train_dataloader = dgl.dataloading.DataLoader(\n            graph,\n            train_idx,\n            sampler,\n            batch_size=args.batch_size,\n            shuffle=True,\n            drop_last=False,\n            num_workers=args.num_workers,\n        )\n\n        for input_nodes, output_nodes, blocks in train_dataloader:\n            blocks = [b.to(device) for b in blocks]\n            train_feature = blocks[0].srcdata[\"feature\"]\n            train_label = blocks[-1].dstdata[\"label\"]\n            logits_gnn, logits_sim = model(blocks, train_feature)\n\n            # compute loss\n            blk_loss = loss_fn(\n                logits_gnn, train_label\n            ) + args.sim_weight * loss_fn(logits_sim, train_label)\n            tr_loss += blk_loss.item()\n            tr_recall += recall_score(\n                train_label.cpu(), logits_gnn.argmax(dim=1).detach().cpu()\n            )\n            tr_auc += roc_auc_score(\n                train_label.cpu(),\n                softmax(logits_gnn, dim=1)[:, 1].detach().cpu(),\n            )\n            tr_blk += 1\n\n            # backward\n            optimizer.zero_grad()\n            blk_loss.backward()\n            optimizer.step()\n\n        # Reinforcement learning module\n        model.RLModule(graph, epoch, rl_idx, dists)\n\n        # validation\n        model.eval()\n        val_dataloader = dgl.dataloading.DataLoader(\n            graph,\n            val_idx,\n            sampler,\n            batch_size=args.batch_size,\n            shuffle=True,\n            drop_last=False,\n            num_workers=args.num_workers,\n        )\n\n        val_recall, val_auc, val_loss = evaluate(\n            model, loss_fn, val_dataloader, device\n        )\n\n        # Print out performance\n        print(\n            \"In epoch {}, Train Recall: 
{:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; \"\n            \"Valid Recall: {:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}\".format(\n                epoch,\n                tr_recall / tr_blk,\n                tr_auc / tr_blk,\n                tr_loss / tr_blk,\n                val_recall,\n                val_auc,\n                val_loss,\n            )\n        )\n\n        if args.early_stop:\n            if stopper.step(val_auc, model):\n                break\n\n    # Test with mini batch after all epoch\n    model.eval()\n    if args.early_stop:\n        model.load_state_dict(th.load(\"es_checkpoint.pt\"))\n    test_dataloader = dgl.dataloading.DataLoader(\n        graph,\n        test_idx,\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\n\n    test_recall, test_auc, test_loss = evaluate(\n        model, loss_fn, test_dataloader, device\n    )\n\n    print(\n        \"Test Recall: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}\".format(\n            test_recall, test_auc, test_loss\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GCN-based Anti-Spam Model\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"amazon\",\n        help=\"DGL dataset for this model (yelp, or amazon)\",\n    )\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU index. Default: -1, using CPU.\"\n    )\n    parser.add_argument(\n        \"--hid_dim\", type=int, default=64, help=\"Hidden layer dimension\"\n    )\n    parser.add_argument(\n        \"--num_layers\", type=int, default=1, help=\"Number of layers\"\n    )\n    parser.add_argument(\n        \"--batch_size\", type=int, default=256, help=\"Size of mini-batch\"\n    )\n    parser.add_argument(\n        \"--max_epoch\",\n        type=int,\n        default=30,\n        help=\"The max number of epochs. 
Default: 30\",\n    )\n    parser.add_argument(\n        \"--lr\", type=float, default=0.01, help=\"Learning rate. Default: 0.01\"\n    )\n    parser.add_argument(\n        \"--weight_decay\",\n        type=float,\n        default=0.001,\n        help=\"Weight decay. Default: 0.001\",\n    )\n    parser.add_argument(\n        \"--step_size\",\n        type=float,\n        default=0.02,\n        help=\"RL action step size (lambda 2). Default: 0.02\",\n    )\n    parser.add_argument(\n        \"--sim_weight\",\n        type=float,\n        default=2,\n        help=\"Similarity loss weight (lambda 1). Default: 2\",\n    )\n    parser.add_argument(\n        \"--num_workers\", type=int, default=4, help=\"Number of node dataloader\"\n    )\n    parser.add_argument(\n        \"--early-stop\",\n        action=\"store_true\",\n        default=False,\n        help=\"indicates whether to use early stop\",\n    )\n\n    args = parser.parse_args()\n    th.manual_seed(717)\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/caregnn/model.py",
    "content": "import dgl.function as fn\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\n\n\nclass CAREConv(nn.Module):\n    \"\"\"One layer of CARE-GNN.\"\"\"\n\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        num_classes,\n        edges,\n        activation=None,\n        step_size=0.02,\n    ):\n        super(CAREConv, self).__init__()\n\n        self.activation = activation\n        self.step_size = step_size\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.num_classes = num_classes\n        self.edges = edges\n        self.dist = {}\n\n        self.linear = nn.Linear(self.in_dim, self.out_dim)\n        self.MLP = nn.Linear(self.in_dim, self.num_classes)\n\n        self.p = {}\n        self.last_avg_dist = {}\n        self.f = {}\n        self.cvg = {}\n        for etype in edges:\n            self.p[etype] = 0.5\n            self.last_avg_dist[etype] = 0\n            self.f[etype] = []\n            self.cvg[etype] = False\n\n    def _calc_distance(self, edges):\n        # formula 2\n        d = th.norm(\n            th.tanh(self.MLP(edges.src[\"h\"]))\n            - th.tanh(self.MLP(edges.dst[\"h\"])),\n            1,\n            1,\n        )\n        return {\"d\": d}\n\n    def _top_p_sampling(self, g, p):\n        # this implementation is low efficient\n        # optimization requires dgl.sampling.select_top_p requested in issue #3100\n        dist = g.edata[\"d\"]\n        neigh_list = []\n        for node in g.nodes():\n            edges = g.in_edges(node, form=\"eid\")\n            num_neigh = th.ceil(g.in_degrees(node) * p).int().item()\n            neigh_dist = dist[edges]\n            if neigh_dist.shape[0] > num_neigh:\n                neigh_index = np.argpartition(\n                    neigh_dist.cpu().detach(), num_neigh\n                )[:num_neigh]\n            else:\n                neigh_index = np.arange(num_neigh)\n            neigh_list.append(edges[neigh_index])\n   
     return th.cat(neigh_list)\n\n    def forward(self, g, feat):\n        with g.local_scope():\n            g.ndata[\"h\"] = feat\n\n            hr = {}\n            for i, etype in enumerate(g.canonical_etypes):\n                g.apply_edges(self._calc_distance, etype=etype)\n                self.dist[etype] = g.edges[etype].data[\"d\"]\n                sampled_edges = self._top_p_sampling(g[etype], self.p[etype])\n\n                # formula 8\n                g.send_and_recv(\n                    sampled_edges,\n                    fn.copy_u(\"h\", \"m\"),\n                    fn.mean(\"m\", \"h_%s\" % etype[1]),\n                    etype=etype,\n                )\n                hr[etype] = g.ndata[\"h_%s\" % etype[1]]\n                if self.activation is not None:\n                    hr[etype] = self.activation(hr[etype])\n\n            # formula 9 using mean as inter-relation aggregator\n            p_tensor = (\n                th.Tensor(list(self.p.values())).view(-1, 1, 1).to(g.device)\n            )\n            h_homo = th.sum(th.stack(list(hr.values())) * p_tensor, dim=0)\n            h_homo += feat\n            if self.activation is not None:\n                h_homo = self.activation(h_homo)\n\n            return self.linear(h_homo)\n\n\nclass CAREGNN(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        num_classes,\n        hid_dim=64,\n        edges=None,\n        num_layers=2,\n        activation=None,\n        step_size=0.02,\n    ):\n        super(CAREGNN, self).__init__()\n        self.in_dim = in_dim\n        self.hid_dim = hid_dim\n        self.num_classes = num_classes\n        self.edges = edges\n        self.activation = activation\n        self.step_size = step_size\n        self.num_layers = num_layers\n\n        self.layers = nn.ModuleList()\n\n        if self.num_layers == 1:\n            # Single layer\n            self.layers.append(\n                CAREConv(\n                    self.in_dim,\n               
     self.num_classes,\n                    self.num_classes,\n                    self.edges,\n                    activation=self.activation,\n                    step_size=self.step_size,\n                )\n            )\n\n        else:\n            # Input layer\n            self.layers.append(\n                CAREConv(\n                    self.in_dim,\n                    self.hid_dim,\n                    self.num_classes,\n                    self.edges,\n                    activation=self.activation,\n                    step_size=self.step_size,\n                )\n            )\n\n            # Hidden layers with n - 2 layers\n            for i in range(self.num_layers - 2):\n                self.layers.append(\n                    CAREConv(\n                        self.hid_dim,\n                        self.hid_dim,\n                        self.num_classes,\n                        self.edges,\n                        activation=self.activation,\n                        step_size=self.step_size,\n                    )\n                )\n\n            # Output layer\n            self.layers.append(\n                CAREConv(\n                    self.hid_dim,\n                    self.num_classes,\n                    self.num_classes,\n                    self.edges,\n                    activation=self.activation,\n                    step_size=self.step_size,\n                )\n            )\n\n    def forward(self, graph, feat):\n        # For full graph training, directly use the graph\n        # formula 4\n        sim = th.tanh(self.layers[0].MLP(feat))\n\n        # Forward of n layers of CARE-GNN\n        for layer in self.layers:\n            feat = layer(graph, feat)\n\n        return feat, sim\n\n    def RLModule(self, graph, epoch, idx):\n        for layer in self.layers:\n            for etype in self.edges:\n                if not layer.cvg[etype]:\n                    # formula 5\n                    eid = graph.in_edges(idx, 
form=\"eid\", etype=etype)\n                    avg_dist = th.mean(layer.dist[etype][eid])\n\n                    # formula 6\n                    if layer.last_avg_dist[etype] < avg_dist:\n                        if layer.p[etype] - self.step_size > 0:\n                            layer.p[etype] -= self.step_size\n                        layer.f[etype].append(-1)\n                    else:\n                        if layer.p[etype] + self.step_size <= 1:\n                            layer.p[etype] += self.step_size\n                        layer.f[etype].append(+1)\n                    layer.last_avg_dist[etype] = avg_dist\n\n                    # formula 7\n                    if epoch >= 9 and abs(sum(layer.f[etype][-10:])) <= 2:\n                        layer.cvg[etype] = True\n"
  },
  {
    "path": "examples/pytorch/caregnn/model_sampling.py",
    "content": "import dgl\nimport dgl.function as fn\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\n\n\ndef _l1_dist(edges):\n    # formula 2\n    ed = th.norm(edges.src[\"nd\"] - edges.dst[\"nd\"], 1, 1)\n    return {\"ed\": ed}\n\n\nclass CARESampler(dgl.dataloading.BlockSampler):\n    def __init__(self, p, dists, num_layers):\n        super().__init__()\n        self.p = p\n        self.dists = dists\n        self.num_layers = num_layers\n\n    def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):\n        with g.local_scope():\n            new_edges_masks = {}\n            for etype in g.canonical_etypes:\n                edge_mask = th.zeros(g.num_edges(etype))\n                # extract each node from dict because of single node type\n                for node in seed_nodes:\n                    edges = g.in_edges(node, form=\"eid\", etype=etype)\n                    num_neigh = (\n                        th.ceil(\n                            g.in_degrees(node, etype=etype)\n                            * self.p[block_id][etype]\n                        )\n                        .int()\n                        .item()\n                    )\n                    neigh_dist = self.dists[block_id][etype][edges]\n                    if neigh_dist.shape[0] > num_neigh:\n                        neigh_index = np.argpartition(neigh_dist, num_neigh)[\n                            :num_neigh\n                        ]\n                    else:\n                        neigh_index = np.arange(num_neigh)\n                    edge_mask[edges[neigh_index]] = 1\n                new_edges_masks[etype] = edge_mask.bool()\n\n            return dgl.edge_subgraph(g, new_edges_masks, relabel_nodes=False)\n\n    def sample_blocks(self, g, seed_nodes, exclude_eids=None):\n        output_nodes = seed_nodes\n        blocks = []\n        for block_id in reversed(range(self.num_layers)):\n            frontier = self.sample_frontier(block_id, g, 
seed_nodes)\n            eid = frontier.edata[dgl.EID]\n            block = dgl.to_block(frontier, seed_nodes)\n            block.edata[dgl.EID] = eid\n            seed_nodes = block.srcdata[dgl.NID]\n            blocks.insert(0, block)\n\n        return seed_nodes, output_nodes, blocks\n\n    def __len__(self):\n        return self.num_layers\n\n\nclass CAREConv(nn.Module):\n    \"\"\"One layer of CARE-GNN.\"\"\"\n\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        num_classes,\n        edges,\n        activation=None,\n        step_size=0.02,\n    ):\n        super(CAREConv, self).__init__()\n\n        self.activation = activation\n        self.step_size = step_size\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.num_classes = num_classes\n        self.edges = edges\n\n        self.linear = nn.Linear(self.in_dim, self.out_dim)\n        self.MLP = nn.Linear(self.in_dim, self.num_classes)\n\n        self.p = {}\n        self.last_avg_dist = {}\n        self.f = {}\n        # indicate whether the RL converges\n        self.cvg = {}\n        for etype in edges:\n            self.p[etype] = 0.5\n            self.last_avg_dist[etype] = 0\n            self.f[etype] = []\n            self.cvg[etype] = False\n\n    def forward(self, g, feat):\n        g.srcdata[\"h\"] = feat\n\n        # formula 8\n        hr = {}\n        for etype in g.canonical_etypes:\n            g.update_all(fn.copy_u(\"h\", \"m\"), fn.mean(\"m\", \"hr\"), etype=etype)\n            hr[etype] = g.dstdata[\"hr\"]\n            if self.activation is not None:\n                hr[etype] = self.activation(hr[etype])\n\n        # formula 9 using mean as inter-relation aggregator\n        p_tensor = (\n            th.Tensor(list(self.p.values())).view(-1, 1, 1).to(feat.device)\n        )\n        h_homo = th.sum(th.stack(list(hr.values())) * p_tensor, dim=0)\n        h_homo += feat[: g.number_of_dst_nodes()]\n        if self.activation is not None:\n    
        h_homo = self.activation(h_homo)\n\n        return self.linear(h_homo)\n\n\nclass CAREGNN(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        num_classes,\n        hid_dim=64,\n        edges=None,\n        num_layers=2,\n        activation=None,\n        step_size=0.02,\n    ):\n        super(CAREGNN, self).__init__()\n        self.in_dim = in_dim\n        self.hid_dim = hid_dim\n        self.num_classes = num_classes\n        self.edges = edges\n        self.num_layers = num_layers\n        self.activation = activation\n        self.step_size = step_size\n\n        self.layers = nn.ModuleList()\n\n        if self.num_layers == 1:\n            # Single layer\n            self.layers.append(\n                CAREConv(\n                    self.in_dim,\n                    self.num_classes,\n                    self.num_classes,\n                    self.edges,\n                    activation=self.activation,\n                    step_size=self.step_size,\n                )\n            )\n\n        else:\n            # Input layer\n            self.layers.append(\n                CAREConv(\n                    self.in_dim,\n                    self.hid_dim,\n                    self.num_classes,\n                    self.edges,\n                    activation=self.activation,\n                    step_size=self.step_size,\n                )\n            )\n\n            # Hidden layers with n - 2 layers\n            for i in range(self.num_layers - 2):\n                self.layers.append(\n                    CAREConv(\n                        self.hid_dim,\n                        self.hid_dim,\n                        self.num_classes,\n                        self.edges,\n                        activation=self.activation,\n                        step_size=self.step_size,\n                    )\n                )\n\n            # Output layer\n            self.layers.append(\n                CAREConv(\n                    
self.hid_dim,\n                    self.num_classes,\n                    self.num_classes,\n                    self.edges,\n                    activation=self.activation,\n                    step_size=self.step_size,\n                )\n            )\n\n    def forward(self, blocks, feat):\n        # formula 4\n        sim = th.tanh(self.layers[0].MLP(blocks[-1].dstdata[\"feature\"].float()))\n\n        # Forward of n layers of CARE-GNN\n        for block, layer in zip(blocks, self.layers):\n            feat = layer(block, feat)\n        return feat, sim\n\n    def RLModule(self, graph, epoch, idx, dists):\n        for i, layer in enumerate(self.layers):\n            for etype in self.edges:\n                if not layer.cvg[etype]:\n                    # formula 5\n                    eid = graph.in_edges(idx, form=\"eid\", etype=etype)\n                    avg_dist = th.mean(dists[i][etype][eid])\n\n                    # formula 6\n                    if layer.last_avg_dist[etype] < avg_dist:\n                        layer.p[etype] -= self.step_size\n                        layer.f[etype].append(-1)\n                        # avoid overflow, follow the author's implement\n                        if layer.p[etype] < 0:\n                            layer.p[etype] = 0.001\n                    else:\n                        layer.p[etype] += self.step_size\n                        layer.f[etype].append(+1)\n                        if layer.p[etype] > 1:\n                            layer.p[etype] = 0.999\n                    layer.last_avg_dist[etype] = avg_dist\n\n                    # formula 7\n                    if epoch >= 9 and abs(sum(layer.f[etype][-10:])) <= 2:\n                        layer.cvg[etype] = True\n"
  },
  {
    "path": "examples/pytorch/caregnn/utils.py",
    "content": "\"\"\"\nFrom GAT utils\n\"\"\"\nimport torch\n\n\nclass EarlyStopping:\n    def __init__(self, patience=10):\n        self.patience = patience\n        self.counter = 0\n        self.best_score = None\n        self.early_stop = False\n\n    def step(self, acc, model):\n        score = acc\n        if self.best_score is None:\n            self.best_score = score\n            self.save_checkpoint(model)\n        elif score < self.best_score:\n            self.counter += 1\n            print(\n                f\"EarlyStopping counter: {self.counter} out of {self.patience}\"\n            )\n            if self.counter >= self.patience:\n                self.early_stop = True\n        else:\n            self.best_score = score\n            self.save_checkpoint(model)\n            self.counter = 0\n        return self.early_stop\n\n    def save_checkpoint(self, model):\n        \"\"\"Saves the model; step() calls this whenever the score is not below the best seen so far.\"\"\"\n        torch.save(model.state_dict(), \"es_checkpoint.pt\")\n"
  },
  {
    "path": "examples/pytorch/cluster_gcn/README.md",
    "content": "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks\n============\n- Paper link: [Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks](https://arxiv.org/abs/1905.07953)\n- Author's code repo: [https://github.com/google-research/google-research/blob/master/cluster_gcn/](https://github.com/google-research/google-research/blob/master/cluster_gcn/).\n\nThis repo reproduces the reported speed and performance as closely as possible on Reddit and PPI. However, the diag enhancement is not covered, as the GraphSage aggregator already achieves a satisfactory F1 score.\n\nDependencies\n------------\n- Python 3.7+ (for string formatting features)\n- PyTorch 1.9.0+\n- scikit-learn\n- TorchMetrics 0.11.4\n\n## Run Experiments\n\n```bash\npython cluster_gcn.py\n```\n"
  },
  {
    "path": "examples/pytorch/cluster_gcn/cluster_gcn.py",
    "content": "import time\n\nimport dgl\nimport dgl.nn as dglnn\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_feats, n_hidden, n_classes):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(0.5)\n\n    def forward(self, sg, x):\n        h = x\n        for l, layer in enumerate(self.layers):\n            h = layer(sg, h)\n            if l != len(self.layers) - 1:\n                h = F.relu(h)\n                h = self.dropout(h)\n        return h\n\n\ndataset = dgl.data.AsNodePredDataset(DglNodePropPredDataset(\"ogbn-products\"))\ngraph = dataset[\n    0\n]  # already prepares ndata['label'/'train_mask'/'val_mask'/'test_mask']\n\nmodel = SAGE(graph.ndata[\"feat\"].shape[1], 256, dataset.num_classes).cuda()\nopt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)\n\nnum_partitions = 1000\nsampler = dgl.dataloading.ClusterGCNSampler(\n    graph,\n    num_partitions,\n    prefetch_ndata=[\"feat\", \"label\", \"train_mask\", \"val_mask\", \"test_mask\"],\n)\n# DataLoader for generic dataloading with a graph, a set of indices (any indices, like\n# partition IDs here), and a graph sampler.\ndataloader = dgl.dataloading.DataLoader(\n    graph,\n    torch.arange(num_partitions).to(\"cuda\"),\n    sampler,\n    device=\"cuda\",\n    batch_size=100,\n    shuffle=True,\n    drop_last=False,\n    num_workers=0,\n    use_uva=True,\n)\n\ndurations = []\nfor epoch in range(10):\n    t0 = time.time()\n    model.train()\n    for it, sg in enumerate(dataloader):\n        x = sg.ndata[\"feat\"]\n        y = 
sg.ndata[\"label\"]\n        m = sg.ndata[\"train_mask\"].bool()\n        y_hat = model(sg, x)\n        loss = F.cross_entropy(y_hat[m], y[m])\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n        if it % 20 == 0:\n            acc = MF.accuracy(\n                y_hat[m],\n                y[m],\n                task=\"multiclass\",\n                num_classes=dataset.num_classes,\n            )\n            mem = torch.cuda.max_memory_allocated() / 1000000\n            print(\"Loss\", loss.item(), \"Acc\", acc.item(), \"GPU Mem\", mem, \"MB\")\n\n    tt = time.time() - t0\n    print(\"Run time for epoch# %d: %.2fs\" % (epoch, tt))\n    durations.append(tt)\n\n    model.eval()\n    with torch.no_grad():\n        val_preds, test_preds = [], []\n        val_labels, test_labels = [], []\n        for it, sg in enumerate(dataloader):\n            x = sg.ndata[\"feat\"]\n            y = sg.ndata[\"label\"]\n            m_val = sg.ndata[\"val_mask\"].bool()\n            m_test = sg.ndata[\"test_mask\"].bool()\n            y_hat = model(sg, x)\n            val_preds.append(y_hat[m_val])\n            val_labels.append(y[m_val])\n            test_preds.append(y_hat[m_test])\n            test_labels.append(y[m_test])\n        val_preds = torch.cat(val_preds, 0)\n        val_labels = torch.cat(val_labels, 0)\n        test_preds = torch.cat(test_preds, 0)\n        test_labels = torch.cat(test_labels, 0)\n        val_acc = MF.accuracy(\n            val_preds,\n            val_labels,\n            task=\"multiclass\",\n            num_classes=dataset.num_classes,\n        )\n        test_acc = MF.accuracy(\n            test_preds,\n            test_labels,\n            task=\"multiclass\",\n            num_classes=dataset.num_classes,\n        )\n        print(\"Validation acc:\", val_acc.item(), \"Test acc:\", test_acc.item())\n\nprint(\n    \"Average run time for last %d epochs: %.2fs standard deviation: %.3f\"\n    % ((epoch - 3), 
np.mean(durations[4:]), np.std(durations[4:]))\n)\n"
  },
  {
    "path": "examples/pytorch/compGCN/README.md",
    "content": "# DGL Implementation of the CompGCN Paper\n\nThis DGL example implements the GNN model proposed in the paper [CompositionGCN](https://arxiv.org/abs/1911.03082). \nThe authors' implementation is available [here](https://github.com/malllabiisc/CompGCN).\n\nExample implementor\n----------------------\nThis example was implemented by [zhjwy9343](https://github.com/zhjwy9343) and [KounianhuaDu](https://github.com/KounianhuaDu) at the AWS Shanghai AI Lab.\n\nDependencies\n----------------------\n- pytorch 1.9.0\n- dgl 0.7.1\n- numpy 1.20.3\n- ordered_set 4.0.2\n\nDataset\n---------------------------------------\nThe datasets used for link prediction are FB15k-237, constructed from Freebase, and WN18RR, constructed from WordNet. The statistics are summarized as follows:\n\n**FB15k-237** \n\n- Nodes: 14541\n- Relation types: 237\n- Reversed relation types: 237\n- Train: 272115\n- Valid: 17535\n- Test: 20466\n\n**WN18RR** \n\n- Nodes: 40943\n- Relation types: 11\n- Reversed relation types: 11\n- Train: 86835\n- Valid: 3034\n- Test: 3134\n\nHow to run\n--------------------------------\nFirst, to get the data, run\n\n```bash\nsh get_fb15k-237.sh\n```\n```bash\nsh get_wn18rr.sh\n```\n\nThen for FB15k-237, run\n\n```bash\npython main.py --score_func conve --opn ccorr --gpu 0 --data FB15k-237\n```\n\nFor WN18RR, run\n\n```bash\npython main.py --score_func conve --opn ccorr --gpu 0 --data wn18rr\n```\n\n\nPerformance\n-------------------------\n**Link Prediction Results**\n\n| Dataset |        FB15k-237         |          WN18RR          |\n|---------| ------------------------ | ------------------------ |\n|  Metric |    Paper   /  ours (dgl) |    Paper   /  ours (dgl) |\n|   MRR   |    0.355   /    0.348    |    0.479   /    0.466    |\n|   MR    |     197    /     208     |    3533    /     3542    |\n| Hit@10  |    0.535   /   0.527     |    0.546   /    0.525    |\n|  Hit@3  |    0.390   /    0.380    |    0.494   /    0.476    |\n|  Hit@1  |    0.264   /    0.259    |    0.443   /    0.435    |\n\n\n\n\n"
  },
  {
    "path": "examples/pytorch/compGCN/data_loader.py",
    "content": "from collections import defaultdict as ddict\n\nimport dgl\n\nimport numpy as np\nimport torch\nfrom ordered_set import OrderedSet\nfrom torch.utils.data import DataLoader, Dataset\n\n\nclass TrainDataset(Dataset):\n    \"\"\"\n    Training Dataset class.\n    Parameters\n    ----------\n    triples: The triples used for training the model\n    num_ent: Number of entities in the knowledge graph\n    lbl_smooth: Label smoothing\n\n    Returns\n    -------\n    A training Dataset class instance used by DataLoader\n    \"\"\"\n\n    def __init__(self, triples, num_ent, lbl_smooth):\n        self.triples = triples\n        self.num_ent = num_ent\n        self.lbl_smooth = lbl_smooth\n        self.entities = np.arange(self.num_ent, dtype=np.int32)\n\n    def __len__(self):\n        return len(self.triples)\n\n    def __getitem__(self, idx):\n        ele = self.triples[idx]\n        triple, label = torch.LongTensor(ele[\"triple\"]), np.int32(ele[\"label\"])\n        trp_label = self.get_label(label)\n        # label smoothing\n        if self.lbl_smooth != 0.0:\n            trp_label = (1.0 - self.lbl_smooth) * trp_label + (\n                1.0 / self.num_ent\n            )\n\n        return triple, trp_label\n\n    @staticmethod\n    def collate_fn(data):\n        triples = []\n        labels = []\n        for triple, label in data:\n            triples.append(triple)\n            labels.append(label)\n        triple = torch.stack(triples, dim=0)\n        trp_label = torch.stack(labels, dim=0)\n        return triple, trp_label\n\n    # for edges that exist in the graph, the entry is 1.0, otherwise the entry is 0.0\n    def get_label(self, label):\n        y = np.zeros([self.num_ent], dtype=np.float32)\n        for e2 in label:\n            y[e2] = 1.0\n        return torch.FloatTensor(y)\n\n\nclass TestDataset(Dataset):\n    \"\"\"\n    Evaluation Dataset class.\n    Parameters\n    ----------\n    triples: The triples used for evaluating the model\n    
num_ent: Number of entities in the knowledge graph\n\n    Returns\n    -------\n    An evaluation Dataset class instance used by DataLoader for model evaluation\n    \"\"\"\n\n    def __init__(self, triples, num_ent):\n        self.triples = triples\n        self.num_ent = num_ent\n\n    def __len__(self):\n        return len(self.triples)\n\n    def __getitem__(self, idx):\n        ele = self.triples[idx]\n        triple, label = torch.LongTensor(ele[\"triple\"]), np.int32(ele[\"label\"])\n        label = self.get_label(label)\n\n        return triple, label\n\n    @staticmethod\n    def collate_fn(data):\n        triples = []\n        labels = []\n        for triple, label in data:\n            triples.append(triple)\n            labels.append(label)\n        triple = torch.stack(triples, dim=0)\n        label = torch.stack(labels, dim=0)\n        return triple, label\n\n    # for edges that exist in the graph, the entry is 1.0, otherwise the entry is 0.0\n    def get_label(self, label):\n        y = np.zeros([self.num_ent], dtype=np.float32)\n        for e2 in label:\n            y[e2] = 1.0\n        return torch.FloatTensor(y)\n\n\nclass Data(object):\n    def __init__(self, dataset, lbl_smooth, num_workers, batch_size):\n        \"\"\"\n        Reading in raw triples and converts it into a standard format.\n        Parameters\n        ----------\n        dataset:           The name of the dataset\n        lbl_smooth:        Label smoothing\n        num_workers:       Number of workers of dataloaders\n        batch_size:        Batch size of dataloaders\n\n        Returns\n        -------\n        self.ent2id:            Entity to unique identifier mapping\n        self.rel2id:            Relation to unique identifier mapping\n        self.id2ent:            Inverse mapping of self.ent2id\n        self.id2rel:            Inverse mapping of self.rel2id\n        self.num_ent:           Number of entities in the knowledge graph\n        self.num_rel:           
Number of relations in the knowledge graph\n\n        self.g:                 The dgl graph constucted from the edges in the traing set and all the entities in the knowledge graph\n        self.data['train']:     Stores the triples corresponding to training dataset\n        self.data['valid']:     Stores the triples corresponding to validation dataset\n        self.data['test']:      Stores the triples corresponding to test dataset\n        self.data_iter:\t\tThe dataloader for different data splits\n        \"\"\"\n        self.dataset = dataset\n        self.lbl_smooth = lbl_smooth\n        self.num_workers = num_workers\n        self.batch_size = batch_size\n\n        # read in raw data and get mappings\n        ent_set, rel_set = OrderedSet(), OrderedSet()\n        for split in [\"train\", \"test\", \"valid\"]:\n            for line in open(\"./{}/{}.txt\".format(self.dataset, split)):\n                sub, rel, obj = map(str.lower, line.strip().split(\"\\t\"))\n                ent_set.add(sub)\n                rel_set.add(rel)\n                ent_set.add(obj)\n\n        self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}\n        self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}\n        self.rel2id.update(\n            {\n                rel + \"_reverse\": idx + len(self.rel2id)\n                for idx, rel in enumerate(rel_set)\n            }\n        )\n\n        self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}\n        self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}\n\n        self.num_ent = len(self.ent2id)\n        self.num_rel = len(self.rel2id) // 2\n\n        # read in ids of subjects, relations, and objects for train/test/valid\n        self.data = ddict(list)  # stores the triples\n        sr2o = ddict(\n            set\n        )  # The key of sr20 is (subject, relation), and the items are all the successors following (subject, relation)\n        src = []\n        dst = []\n        rels = []\n        
inver_src = []\n        inver_dst = []\n        inver_rels = []\n\n        for split in [\"train\", \"test\", \"valid\"]:\n            for line in open(\"./{}/{}.txt\".format(self.dataset, split)):\n                sub, rel, obj = map(str.lower, line.strip().split(\"\\t\"))\n                sub_id, rel_id, obj_id = (\n                    self.ent2id[sub],\n                    self.rel2id[rel],\n                    self.ent2id[obj],\n                )\n                self.data[split].append((sub_id, rel_id, obj_id))\n\n                if split == \"train\":\n                    sr2o[(sub_id, rel_id)].add(obj_id)\n                    sr2o[(obj_id, rel_id + self.num_rel)].add(\n                        sub_id\n                    )  # append the reversed edges\n                    src.append(sub_id)\n                    dst.append(obj_id)\n                    rels.append(rel_id)\n                    inver_src.append(obj_id)\n                    inver_dst.append(sub_id)\n                    inver_rels.append(rel_id + self.num_rel)\n\n        # construct dgl graph\n        src = src + inver_src\n        dst = dst + inver_dst\n        rels = rels + inver_rels\n        self.g = dgl.graph((src, dst), num_nodes=self.num_ent)\n        self.g.edata[\"etype\"] = torch.Tensor(rels).long()\n\n        # identify in and out edges\n        in_edges_mask = [True] * (self.g.num_edges() // 2) + [False] * (\n            self.g.num_edges() // 2\n        )\n        out_edges_mask = [False] * (self.g.num_edges() // 2) + [True] * (\n            self.g.num_edges() // 2\n        )\n        self.g.edata[\"in_edges_mask\"] = torch.Tensor(in_edges_mask)\n        self.g.edata[\"out_edges_mask\"] = torch.Tensor(out_edges_mask)\n\n        # Prepare train/valid/test data\n        self.data = dict(self.data)\n        self.sr2o = {\n            k: list(v) for k, v in sr2o.items()\n        }  # store only the train data\n\n        for split in [\"test\", \"valid\"]:\n            for sub, rel, obj in 
self.data[split]:\n                sr2o[(sub, rel)].add(obj)\n                sr2o[(obj, rel + self.num_rel)].add(sub)\n\n        self.sr2o_all = {\n            k: list(v) for k, v in sr2o.items()\n        }  # store all the data\n        self.triples = ddict(list)\n\n        for (sub, rel), obj in self.sr2o.items():\n            self.triples[\"train\"].append(\n                {\"triple\": (sub, rel, -1), \"label\": self.sr2o[(sub, rel)]}\n            )\n\n        for split in [\"test\", \"valid\"]:\n            for sub, rel, obj in self.data[split]:\n                rel_inv = rel + self.num_rel\n                self.triples[\"{}_{}\".format(split, \"tail\")].append(\n                    {\n                        \"triple\": (sub, rel, obj),\n                        \"label\": self.sr2o_all[(sub, rel)],\n                    }\n                )\n                self.triples[\"{}_{}\".format(split, \"head\")].append(\n                    {\n                        \"triple\": (obj, rel_inv, sub),\n                        \"label\": self.sr2o_all[(obj, rel_inv)],\n                    }\n                )\n\n        self.triples = dict(self.triples)\n\n        def get_train_data_loader(split, batch_size, shuffle=True):\n            return DataLoader(\n                TrainDataset(\n                    self.triples[split], self.num_ent, self.lbl_smooth\n                ),\n                batch_size=batch_size,\n                shuffle=shuffle,\n                num_workers=max(0, self.num_workers),\n                collate_fn=TrainDataset.collate_fn,\n            )\n\n        def get_test_data_loader(split, batch_size, shuffle=True):\n            return DataLoader(\n                TestDataset(self.triples[split], self.num_ent),\n                batch_size=batch_size,\n                shuffle=shuffle,\n                num_workers=max(0, self.num_workers),\n                collate_fn=TestDataset.collate_fn,\n            )\n\n        # train/valid/test dataloaders\n    
    self.data_iter = {\n            \"train\": get_train_data_loader(\"train\", self.batch_size),\n            \"valid_head\": get_test_data_loader(\"valid_head\", self.batch_size),\n            \"valid_tail\": get_test_data_loader(\"valid_tail\", self.batch_size),\n            \"test_head\": get_test_data_loader(\"test_head\", self.batch_size),\n            \"test_tail\": get_test_data_loader(\"test_tail\", self.batch_size),\n        }\n"
  },
  {
    "path": "examples/pytorch/compGCN/get_fb15k-237.sh",
    "content": "wget https://dgl-data.s3.cn-north-1.amazonaws.com.cn/dataset/FB15k-237.zip\nunzip FB15k-237.zip\n"
  },
  {
    "path": "examples/pytorch/compGCN/get_wn18rr.sh",
    "content": "wget https://dgl-data.s3.cn-north-1.amazonaws.com.cn/dataset/wn18rr.zip\nunzip wn18rr.zip\n"
  },
  {
    "path": "examples/pytorch/compGCN/main.py",
    "content": "import argparse\nfrom time import time\n\nimport numpy as np\nimport torch as th\nimport torch.optim as optim\nfrom data_loader import Data\nfrom models import CompGCN_ConvE\nfrom utils import in_out_norm\n\n\n# predict the tail for (head, rel, -1) or head for (-1, rel, tail)\ndef predict(model, graph, device, data_iter, split=\"valid\", mode=\"tail\"):\n    model.eval()\n    with th.no_grad():\n        results = {}\n        train_iter = iter(data_iter[\"{}_{}\".format(split, mode)])\n\n        for step, batch in enumerate(train_iter):\n            triple, label = batch[0].to(device), batch[1].to(device)\n            sub, rel, obj, label = (\n                triple[:, 0],\n                triple[:, 1],\n                triple[:, 2],\n                label,\n            )\n            pred = model(graph, sub, rel)\n            b_range = th.arange(pred.size()[0], device=device)\n            target_pred = pred[b_range, obj]\n            pred = th.where(label.bool(), -th.ones_like(pred) * 10000000, pred)\n            pred[b_range, obj] = target_pred\n\n            # compute metrics\n            ranks = (\n                1\n                + th.argsort(\n                    th.argsort(pred, dim=1, descending=True),\n                    dim=1,\n                    descending=False,\n                )[b_range, obj]\n            )\n            ranks = ranks.float()\n            results[\"count\"] = th.numel(ranks) + results.get(\"count\", 0.0)\n            results[\"mr\"] = th.sum(ranks).item() + results.get(\"mr\", 0.0)\n            results[\"mrr\"] = th.sum(1.0 / ranks).item() + results.get(\n                \"mrr\", 0.0\n            )\n            for k in [1, 3, 10]:\n                results[\"hits@{}\".format(k)] = th.numel(\n                    ranks[ranks <= (k)]\n                ) + results.get(\"hits@{}\".format(k), 0.0)\n\n    return results\n\n\n# evaluation function, evaluate the head and tail prediction and then combine the results\ndef 
evaluate(model, graph, device, data_iter, split=\"valid\"):\n    # predict for head and tail\n    left_results = predict(model, graph, device, data_iter, split, mode=\"tail\")\n    right_results = predict(model, graph, device, data_iter, split, mode=\"head\")\n    results = {}\n    count = float(left_results[\"count\"])\n\n    # combine the head and tail prediction results\n    # Metrics: MRR, MR, and Hit@k\n    results[\"left_mr\"] = round(left_results[\"mr\"] / count, 5)\n    results[\"left_mrr\"] = round(left_results[\"mrr\"] / count, 5)\n    results[\"right_mr\"] = round(right_results[\"mr\"] / count, 5)\n    results[\"right_mrr\"] = round(right_results[\"mrr\"] / count, 5)\n    results[\"mr\"] = round(\n        (left_results[\"mr\"] + right_results[\"mr\"]) / (2 * count), 5\n    )\n    results[\"mrr\"] = round(\n        (left_results[\"mrr\"] + right_results[\"mrr\"]) / (2 * count), 5\n    )\n    for k in [1, 3, 10]:\n        results[\"left_hits@{}\".format(k)] = round(\n            left_results[\"hits@{}\".format(k)] / count, 5\n        )\n        results[\"right_hits@{}\".format(k)] = round(\n            right_results[\"hits@{}\".format(k)] / count, 5\n        )\n        results[\"hits@{}\".format(k)] = round(\n            (\n                left_results[\"hits@{}\".format(k)]\n                + right_results[\"hits@{}\".format(k)]\n            )\n            / (2 * count),\n            5,\n        )\n    return results\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # check cuda\n    if args.gpu >= 0 and th.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n    else:\n        device = \"cpu\"\n\n    # construct graph, split in/out edges and prepare train/validation/test data_loader\n    data = Data(\n        args.dataset, args.lbl_smooth, args.num_workers, args.batch_size\n    )\n    data_iter = data.data_iter  # train/validation/test data_loader\n    
graph = data.g.to(device)\n    num_rel = th.max(graph.edata[\"etype\"]).item() + 1\n\n    # Compute in/out edge norms and store in edata\n    graph = in_out_norm(graph)\n\n    # Step 2: Create model =================================================================== #\n    compgcn_model = CompGCN_ConvE(\n        num_bases=args.num_bases,\n        num_rel=num_rel,\n        num_ent=graph.num_nodes(),\n        in_dim=args.init_dim,\n        layer_size=args.layer_size,\n        comp_fn=args.opn,\n        batchnorm=True,\n        dropout=args.dropout,\n        layer_dropout=args.layer_dropout,\n        num_filt=args.num_filt,\n        hid_drop=args.hid_drop,\n        feat_drop=args.feat_drop,\n        ker_sz=args.ker_sz,\n        k_w=args.k_w,\n        k_h=args.k_h,\n    )\n    compgcn_model = compgcn_model.to(device)\n\n    # Step 3: Create training components ===================================================== #\n    loss_fn = th.nn.BCELoss()\n    optimizer = optim.Adam(\n        compgcn_model.parameters(), lr=args.lr, weight_decay=args.l2\n    )\n\n    # Step 4: training epoches =============================================================== #\n    best_mrr = 0.0\n    kill_cnt = 0\n    for epoch in range(args.max_epochs):\n        # Training and validation using a full graph\n        compgcn_model.train()\n        train_loss = []\n        t0 = time()\n        for step, batch in enumerate(data_iter[\"train\"]):\n            triple, label = batch[0].to(device), batch[1].to(device)\n            sub, rel, obj, label = (\n                triple[:, 0],\n                triple[:, 1],\n                triple[:, 2],\n                label,\n            )\n            logits = compgcn_model(graph, sub, rel)\n\n            # compute loss\n            tr_loss = loss_fn(logits, label)\n            train_loss.append(tr_loss.item())\n\n            # backward\n            optimizer.zero_grad()\n            tr_loss.backward()\n            optimizer.step()\n\n        train_loss = 
np.sum(train_loss)\n\n        t1 = time()\n        val_results = evaluate(\n            compgcn_model, graph, device, data_iter, split=\"valid\"\n        )\n        t2 = time()\n\n        # validate\n        if val_results[\"mrr\"] > best_mrr:\n            best_mrr = val_results[\"mrr\"]\n            th.save(\n                compgcn_model.state_dict(), \"comp_link\" + \"_\" + args.dataset\n            )\n            kill_cnt = 0\n            print(\"saving model...\")\n        else:\n            kill_cnt += 1\n            if kill_cnt > 100:\n                print(\"early stop.\")\n                break\n        print(\n            \"In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}, Train time: {}, Valid time: {}\".format(\n                epoch, train_loss, val_results[\"mrr\"], t1 - t0, t2 - t1\n            )\n        )\n\n    # test use the best model\n    compgcn_model.eval()\n    compgcn_model.load_state_dict(th.load(\"comp_link\" + \"_\" + args.dataset))\n    test_results = evaluate(\n        compgcn_model, graph, device, data_iter, split=\"test\"\n    )\n    print(\n        \"Test MRR: {:.5}\\n, MR: {:.10}\\n, H@10: {:.5}\\n, H@3: {:.5}\\n, H@1: {:.5}\\n\".format(\n            test_results[\"mrr\"],\n            test_results[\"mr\"],\n            test_results[\"hits@10\"],\n            test_results[\"hits@3\"],\n            test_results[\"hits@1\"],\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Parser For Arguments\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n\n    parser.add_argument(\n        \"--data\",\n        dest=\"dataset\",\n        default=\"FB15k-237\",\n        help=\"Dataset to use, default: FB15k-237\",\n    )\n    parser.add_argument(\n        \"--model\", dest=\"model\", default=\"compgcn\", help=\"Model Name\"\n    )\n    parser.add_argument(\n        \"--score_func\",\n        dest=\"score_func\",\n        default=\"conve\",\n        
help=\"Score Function for Link prediction\",\n    )\n    parser.add_argument(\n        \"--opn\",\n        dest=\"opn\",\n        default=\"ccorr\",\n        help=\"Composition Operation to be used in CompGCN\",\n    )\n\n    parser.add_argument(\n        \"--batch\", dest=\"batch_size\", default=1024, type=int, help=\"Batch size\"\n    )\n    parser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=\"0\",\n        help=\"Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0\",\n    )\n    parser.add_argument(\n        \"--epoch\",\n        dest=\"max_epochs\",\n        type=int,\n        default=500,\n        help=\"Number of epochs\",\n    )\n    parser.add_argument(\n        \"--l2\", type=float, default=0.0, help=\"L2 Regularization for Optimizer\"\n    )\n    parser.add_argument(\n        \"--lr\", type=float, default=0.001, help=\"Starting Learning Rate\"\n    )\n    parser.add_argument(\n        \"--lbl_smooth\",\n        dest=\"lbl_smooth\",\n        type=float,\n        default=0.1,\n        help=\"Label Smoothing\",\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=10,\n        help=\"Number of processes to construct batches\",\n    )\n    parser.add_argument(\n        \"--seed\",\n        dest=\"seed\",\n        default=41504,\n        type=int,\n        help=\"Seed for randomization\",\n    )\n\n    parser.add_argument(\n        \"--num_bases\",\n        dest=\"num_bases\",\n        default=-1,\n        type=int,\n        help=\"Number of basis relation vectors to use\",\n    )\n    parser.add_argument(\n        \"--init_dim\",\n        dest=\"init_dim\",\n        default=100,\n        type=int,\n        help=\"Initial dimension size for entities and relations\",\n    )\n    parser.add_argument(\n        \"--layer_size\",\n        nargs=\"?\",\n        default=\"[200]\",\n        help=\"List of output size for each compGCN layer\",\n    )\n    parser.add_argument(\n        \"--gcn_drop\",\n 
       dest=\"dropout\",\n        default=0.1,\n        type=float,\n        help=\"Dropout to use in GCN Layer\",\n    )\n    parser.add_argument(\n        \"--layer_dropout\",\n        nargs=\"?\",\n        default=\"[0.3]\",\n        help=\"List of dropout value after each compGCN layer\",\n    )\n\n    # ConvE specific hyperparameters\n    parser.add_argument(\n        \"--hid_drop\",\n        dest=\"hid_drop\",\n        default=0.3,\n        type=float,\n        help=\"ConvE: Hidden dropout\",\n    )\n    parser.add_argument(\n        \"--feat_drop\",\n        dest=\"feat_drop\",\n        default=0.3,\n        type=float,\n        help=\"ConvE: Feature Dropout\",\n    )\n    parser.add_argument(\n        \"--k_w\", dest=\"k_w\", default=10, type=int, help=\"ConvE: k_w\"\n    )\n    parser.add_argument(\n        \"--k_h\", dest=\"k_h\", default=20, type=int, help=\"ConvE: k_h\"\n    )\n    parser.add_argument(\n        \"--num_filt\",\n        dest=\"num_filt\",\n        default=200,\n        type=int,\n        help=\"ConvE: Number of filters in convolution\",\n    )\n    parser.add_argument(\n        \"--ker_sz\",\n        dest=\"ker_sz\",\n        default=7,\n        type=int,\n        help=\"ConvE: Kernel size to use\",\n    )\n\n    args = parser.parse_args()\n\n    np.random.seed(args.seed)\n    th.manual_seed(args.seed)\n\n    print(args)\n\n    args.layer_size = eval(args.layer_size)\n    args.layer_dropout = eval(args.layer_dropout)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/compGCN/models.py",
    "content": "import dgl\nimport dgl.function as fn\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom utils import ccorr\n\n\nclass CompGraphConv(nn.Module):\n    \"\"\"One layer of CompGCN.\"\"\"\n\n    def __init__(\n        self, in_dim, out_dim, comp_fn=\"sub\", batchnorm=True, dropout=0.1\n    ):\n        super(CompGraphConv, self).__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.comp_fn = comp_fn\n        self.actvation = th.tanh\n        self.batchnorm = batchnorm\n\n        # define dropout layer\n        self.dropout = nn.Dropout(dropout)\n\n        # define batch norm layer\n        if self.batchnorm:\n            self.bn = nn.BatchNorm1d(out_dim)\n\n        # define in/out/loop transform layer\n        self.W_O = nn.Linear(self.in_dim, self.out_dim)\n        self.W_I = nn.Linear(self.in_dim, self.out_dim)\n        self.W_S = nn.Linear(self.in_dim, self.out_dim)\n\n        # define relation transform layer\n        self.W_R = nn.Linear(self.in_dim, self.out_dim)\n\n        # self loop embedding\n        self.loop_rel = nn.Parameter(th.Tensor(1, self.in_dim))\n        nn.init.xavier_normal_(self.loop_rel)\n\n    def forward(self, g, n_in_feats, r_feats):\n        with g.local_scope():\n            # Assign values to source nodes. 
In a homogeneous graph, this is equal to\n            # assigning them to all nodes.\n            g.srcdata[\"h\"] = n_in_feats\n            # append loop_rel embedding to r_feats\n            r_feats = th.cat((r_feats, self.loop_rel), 0)\n            # Assign features to all edges with the corresponding relation embeddings\n            g.edata[\"h\"] = r_feats[g.edata[\"etype\"]] * g.edata[\"norm\"]\n\n            # Compute composition function in 4 steps\n            # Step 1: compute composition by edge in the edge direction, and store results in edges.\n            if self.comp_fn == \"sub\":\n                g.apply_edges(fn.u_sub_e(\"h\", \"h\", out=\"comp_h\"))\n            elif self.comp_fn == \"mul\":\n                g.apply_edges(fn.u_mul_e(\"h\", \"h\", out=\"comp_h\"))\n            elif self.comp_fn == \"ccorr\":\n                g.apply_edges(\n                    lambda edges: {\n                        \"comp_h\": ccorr(edges.src[\"h\"], edges.data[\"h\"])\n                    }\n                )\n            else:\n                raise Exception(\"Only supports sub, mul, and ccorr\")\n\n            # Step 2: use extracted edge direction to compute in and out edges\n            comp_h = g.edata[\"comp_h\"]\n\n            in_edges_idx = th.nonzero(\n                g.edata[\"in_edges_mask\"], as_tuple=False\n            ).squeeze()\n            out_edges_idx = th.nonzero(\n                g.edata[\"out_edges_mask\"], as_tuple=False\n            ).squeeze()\n\n            comp_h_O = self.W_O(comp_h[out_edges_idx])\n            comp_h_I = self.W_I(comp_h[in_edges_idx])\n\n            new_comp_h = th.zeros(comp_h.shape[0], self.out_dim).to(\n                comp_h.device\n            )\n            new_comp_h[out_edges_idx] = comp_h_O\n            new_comp_h[in_edges_idx] = comp_h_I\n\n            g.edata[\"new_comp_h\"] = new_comp_h\n\n            # Step 3: sum comp results to both src and dst nodes\n            g.update_all(fn.copy_e(\"new_comp_h\", 
\"m\"), fn.sum(\"m\", \"comp_edge\"))\n\n            # Step 4: add results of self-loop\n            if self.comp_fn == \"sub\":\n                comp_h_s = n_in_feats - r_feats[-1]\n            elif self.comp_fn == \"mul\":\n                comp_h_s = n_in_feats * r_feats[-1]\n            elif self.comp_fn == \"ccorr\":\n                comp_h_s = ccorr(n_in_feats, r_feats[-1])\n            else:\n                raise Exception(\"Only supports sub, mul, and ccorr\")\n\n            # Sum all of the comp results as output of nodes and dropout\n            n_out_feats = (\n                self.W_S(comp_h_s) + self.dropout(g.ndata[\"comp_edge\"])\n            ) * (1 / 3)\n\n            # Compute relation output\n            r_out_feats = self.W_R(r_feats)\n\n            # Batch norm\n            if self.batchnorm:\n                n_out_feats = self.bn(n_out_feats)\n\n            # Activation function\n            if self.actvation is not None:\n                n_out_feats = self.actvation(n_out_feats)\n\n        return n_out_feats, r_out_feats[:-1]\n\n\nclass CompGCN(nn.Module):\n    def __init__(\n        self,\n        num_bases,\n        num_rel,\n        num_ent,\n        in_dim=100,\n        layer_size=[200],\n        comp_fn=\"sub\",\n        batchnorm=True,\n        dropout=0.1,\n        layer_dropout=[0.3],\n    ):\n        super(CompGCN, self).__init__()\n\n        self.num_bases = num_bases\n        self.num_rel = num_rel\n        self.num_ent = num_ent\n        self.in_dim = in_dim\n        self.layer_size = layer_size\n        self.comp_fn = comp_fn\n        self.batchnorm = batchnorm\n        self.dropout = dropout\n        self.layer_dropout = layer_dropout\n        self.num_layer = len(layer_size)\n\n        # CompGCN layers\n        self.layers = nn.ModuleList()\n        self.layers.append(\n            CompGraphConv(\n                self.in_dim,\n                self.layer_size[0],\n                comp_fn=self.comp_fn,\n                
batchnorm=self.batchnorm,\n                dropout=self.dropout,\n            )\n        )\n        for i in range(self.num_layer - 1):\n            self.layers.append(\n                CompGraphConv(\n                    self.layer_size[i],\n                    self.layer_size[i + 1],\n                    comp_fn=self.comp_fn,\n                    batchnorm=self.batchnorm,\n                    dropout=self.dropout,\n                )\n            )\n\n        # Initial relation embeddings\n        if self.num_bases > 0:\n            self.basis = nn.Parameter(th.Tensor(self.num_bases, self.in_dim))\n            self.weights = nn.Parameter(th.Tensor(self.num_rel, self.num_bases))\n            nn.init.xavier_normal_(self.basis)\n            nn.init.xavier_normal_(self.weights)\n        else:\n            self.rel_embds = nn.Parameter(th.Tensor(self.num_rel, self.in_dim))\n            nn.init.xavier_normal_(self.rel_embds)\n\n        # Node embeddings\n        self.n_embds = nn.Parameter(th.Tensor(self.num_ent, self.in_dim))\n        nn.init.xavier_normal_(self.n_embds)\n\n        # Dropout after compGCN layers\n        self.dropouts = nn.ModuleList()\n        for i in range(self.num_layer):\n            self.dropouts.append(nn.Dropout(self.layer_dropout[i]))\n\n    def forward(self, graph):\n        # node and relation features\n        n_feats = self.n_embds\n        if self.num_bases > 0:\n            r_embds = th.mm(self.weights, self.basis)\n            r_feats = r_embds\n        else:\n            r_feats = self.rel_embds\n\n        for layer, dropout in zip(self.layers, self.dropouts):\n            n_feats, r_feats = layer(graph, n_feats, r_feats)\n            n_feats = dropout(n_feats)\n\n        return n_feats, r_feats\n\n\n# Use convE as the score function\nclass CompGCN_ConvE(nn.Module):\n    def __init__(\n        self,\n        num_bases,\n        num_rel,\n        num_ent,\n        in_dim,\n        layer_size,\n        comp_fn=\"sub\",\n        
batchnorm=True,\n        dropout=0.1,\n        layer_dropout=[0.3],\n        num_filt=200,\n        hid_drop=0.3,\n        feat_drop=0.3,\n        ker_sz=5,\n        k_w=5,\n        k_h=5,\n    ):\n        super(CompGCN_ConvE, self).__init__()\n\n        self.embed_dim = layer_size[-1]\n        self.hid_drop = hid_drop\n        self.feat_drop = feat_drop\n        self.ker_sz = ker_sz\n        self.k_w = k_w\n        self.k_h = k_h\n        self.num_filt = num_filt\n\n        # compGCN model to get sub/rel embs\n        self.compGCN_Model = CompGCN(\n            num_bases,\n            num_rel,\n            num_ent,\n            in_dim,\n            layer_size,\n            comp_fn,\n            batchnorm,\n            dropout,\n            layer_dropout,\n        )\n\n        # batchnorms to the combined (sub+rel) emb\n        self.bn0 = th.nn.BatchNorm2d(1)\n        self.bn1 = th.nn.BatchNorm2d(self.num_filt)\n        self.bn2 = th.nn.BatchNorm1d(self.embed_dim)\n\n        # dropouts and conv module to the combined (sub+rel) emb\n        self.hidden_drop = th.nn.Dropout(self.hid_drop)\n        self.feature_drop = th.nn.Dropout(self.feat_drop)\n        self.m_conv1 = th.nn.Conv2d(\n            1,\n            out_channels=self.num_filt,\n            kernel_size=(self.ker_sz, self.ker_sz),\n            stride=1,\n            padding=0,\n            bias=False,\n        )\n\n        flat_sz_h = int(2 * self.k_w) - self.ker_sz + 1\n        flat_sz_w = self.k_h - self.ker_sz + 1\n        self.flat_sz = flat_sz_h * flat_sz_w * self.num_filt\n        self.fc = th.nn.Linear(self.flat_sz, self.embed_dim)\n\n        # bias to the score\n        self.bias = nn.Parameter(th.zeros(num_ent))\n\n    # combine entity embeddings and relation embeddings\n    def concat(self, e1_embed, rel_embed):\n        e1_embed = e1_embed.view(-1, 1, self.embed_dim)\n        rel_embed = rel_embed.view(-1, 1, self.embed_dim)\n        stack_inp = th.cat([e1_embed, rel_embed], 1)\n        stack_inp 
= th.transpose(stack_inp, 2, 1).reshape(\n            (-1, 1, 2 * self.k_w, self.k_h)\n        )\n        return stack_inp\n\n    def forward(self, graph, sub, rel):\n        # get sub_emb and rel_emb via compGCN\n        n_feats, r_feats = self.compGCN_Model(graph)\n        sub_emb = n_feats[sub, :]\n        rel_emb = r_feats[rel, :]\n\n        # combine the sub_emb and rel_emb\n        stk_inp = self.concat(sub_emb, rel_emb)\n        # use convE to score the combined emb\n        x = self.bn0(stk_inp)\n        x = self.m_conv1(x)\n        x = self.bn1(x)\n        x = F.relu(x)\n        x = self.feature_drop(x)\n        x = x.view(-1, self.flat_sz)\n        x = self.fc(x)\n        x = self.hidden_drop(x)\n        x = self.bn2(x)\n        x = F.relu(x)\n        # compute score\n        x = th.mm(x, n_feats.transpose(1, 0))\n        # add in bias\n        x += self.bias.expand_as(x)\n        score = th.sigmoid(x)\n        return score\n"
  },
  {
    "path": "examples/pytorch/compGCN/utils.py",
    "content": "# This file is based on the CompGCN author's implementation\n# <https://github.com/malllabiisc/CompGCN/blob/master/helper.py>.\n# It implements the operation of circular convolution in the ccorr function and an additional in_out_norm function for norm computation.\n\nimport dgl\nimport torch as th\n\n\ndef com_mult(a, b):\n    r1, i1 = a[..., 0], a[..., 1]\n    r2, i2 = b[..., 0], b[..., 1]\n    return th.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1)\n\n\ndef conj(a):\n    a[..., 1] = -a[..., 1]\n    return a\n\n\ndef ccorr(a, b):\n    \"\"\"\n    Compute circular correlation of two tensors.\n    Parameters\n    ----------\n    a: Tensor, 1D or 2D\n    b: Tensor, 1D or 2D\n\n    Notes\n    -----\n    Input a and b should have the same dimensions. And this operation supports broadcasting.\n\n    Returns\n    -------\n    Tensor, having the same dimension as the input a.\n    \"\"\"\n    return th.fft.irfftn(\n        th.conj(th.fft.rfftn(a, (-1))) * th.fft.rfftn(b, (-1)), (-1)\n    )\n\n\n# identify in/out edges, compute edge norm for each and store in edata\ndef in_out_norm(graph):\n    src, dst, EID = graph.edges(form=\"all\")\n    graph.edata[\"norm\"] = th.ones(EID.shape[0]).to(graph.device)\n\n    in_edges_idx = th.nonzero(\n        graph.edata[\"in_edges_mask\"], as_tuple=False\n    ).squeeze()\n    out_edges_idx = th.nonzero(\n        graph.edata[\"out_edges_mask\"], as_tuple=False\n    ).squeeze()\n\n    for idx in [in_edges_idx, out_edges_idx]:\n        u, v = src[idx], dst[idx]\n        deg = th.zeros(graph.num_nodes()).to(graph.device)\n        n_idx, inverse_index, count = th.unique(\n            v, return_inverse=True, return_counts=True\n        )\n        deg[n_idx] = count.float()\n        deg_inv = deg.pow(-0.5)  # D^{-0.5}\n        deg_inv[deg_inv == float(\"inf\")] = 0\n        norm = deg_inv[u] * deg_inv[v]\n        graph.edata[\"norm\"][idx] = norm\n    graph.edata[\"norm\"] = graph.edata[\"norm\"].unsqueeze(1)\n\n    
return graph\n"
  },
  {
    "path": "examples/pytorch/correct_and_smooth/README.md",
    "content": "# DGL Implementation of CorrectAndSmooth\n\nThis DGL example implements the GNN model proposed in the paper [Combining Label Propagation and Simple Models Out-performs Graph Neural Networks](https://arxiv.org/abs/2010.13993). For the original implementation, see [here](https://github.com/CUAI/CorrectAndSmooth).\n\nContributor: [xnuohz](https://github.com/xnuohz)\n\n### Requirements\nThe codebase is implemented in Python 3.7. For version requirement of packages, see below.\n\n```\ndgl 0.6.0.post1\ntorch 1.7.0\nogb 1.3.0\n```\n\n### Limitations\n\nSpectral and Diffusion Embeddings used by the authors for feature augmentation are not currently implemented. Without these feature augmentations only the \"Plain\" (without feature augmentations) results from the authors can be replicated.\n\n### The graph datasets used in this example\n\nOpen Graph Benchmark(OGB). Dataset summary:\n\n|    Dataset    |  #Nodes   |   #Edges   | #Node Feats |  Metric  |\n| :-----------: | :-------: | :--------: | :---------: | :------: |\n|  ogbn-arxiv   |  169,343  | 1,166,243  |     128     | Accuracy |\n| ogbn-products | 2,449,029 | 61,859,140 |     100     | Accuracy |\n\n### Usage\n\nTraining a **Base predictor** and using **Correct&Smooth** which follows the original hyperparameters on different datasets.\n\n##### ogbn-arxiv\n\n* **Plain MLP + C&S**\n\n```bash\npython main.py --dropout 0.5\npython main.py --pretrain --correction-adj DA --smoothing-adj AD --autoscale\n```\n\n* **Plain Linear + C&S**\n\n```bash\npython main.py --model linear --dropout 0.5 --epochs 1000\npython main.py --model linear --pretrain --correction-alpha 0.87 --smoothing-alpha 0.81 --correction-adj AD --autoscale\n```\n\n##### ogbn-products\n\n* **Plain Linear + C&S**\n\n```bash\npython main.py --dataset ogbn-products --model linear --dropout 0.5 --epochs 1000 --lr 0.1\npython main.py --dataset ogbn-products --model linear --pretrain --correction-alpha 1. 
--smoothing-alpha 0.9\n```\n\n### Performance\n\n#### ogbn-arxiv\n\n|                 | Plain Linear | Plain Linear + C&S |\n| :-------------: | :----: |    :----------:    |\n| Results(Author) | 52.5   |       71.26        |\n|  Results(DGL)   | 52.48  |       71.26        |\n\n#### ogbn-products\n\n|                 | Plain Linear | Plain Linear + C&S |\n| :-------------: | :----: | :----------: |\n| Results(Author) | 47.67  |    82.34     |\n|  Results(DGL)   | 47.65  |    82.86     |\n\n### Speed\n\n|      ogbn-arxiv      |      Time     | GPU Memory | Params  |\n| :------------------: | :-----------: | :--------: | :-----: |\n| Author, Plain Linear + C&S | 6.3 * 10 ^ -3 |   1,248M   |  5,160  |\n|   DGL, Plain Linear + C&S  | 5.6 * 10 ^ -3 |   1,252M   |  5,160  |\n"
  },
  {
    "path": "examples/pytorch/correct_and_smooth/main.py",
    "content": "import argparse\nimport copy\nimport os\n\nimport dgl\n\nimport torch\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom model import CorrectAndSmooth, MLP, MLPLinear\nfrom ogb.nodeproppred import DglNodePropPredDataset, Evaluator\n\n\ndef evaluate(y_pred, y_true, idx, evaluator):\n    return evaluator.eval({\"y_true\": y_true[idx], \"y_pred\": y_pred[idx]})[\"acc\"]\n\n\ndef main():\n    # check cuda\n    device = (\n        f\"cuda:{args.gpu}\"\n        if torch.cuda.is_available() and args.gpu >= 0\n        else \"cpu\"\n    )\n    # load data\n    dataset = DglNodePropPredDataset(name=args.dataset)\n    evaluator = Evaluator(name=args.dataset)\n\n    split_idx = dataset.get_idx_split()\n    g, labels = dataset[\n        0\n    ]  # graph: DGLGraph object, label: torch tensor of shape (num_nodes, num_tasks)\n\n    if args.dataset == \"ogbn-arxiv\":\n        g = dgl.to_bidirected(g, copy_ndata=True)\n\n        feat = g.ndata[\"feat\"]\n        feat = (feat - feat.mean(0)) / feat.std(0)\n        g.ndata[\"feat\"] = feat\n\n    g = g.to(device)\n    feats = g.ndata[\"feat\"]\n    labels = labels.to(device)\n\n    # load masks for train / validation / test\n    train_idx = split_idx[\"train\"].to(device)\n    valid_idx = split_idx[\"valid\"].to(device)\n    test_idx = split_idx[\"test\"].to(device)\n\n    n_features = feats.size()[-1]\n    n_classes = dataset.num_classes\n\n    # load model\n    if args.model == \"mlp\":\n        model = MLP(\n            n_features, args.hid_dim, n_classes, args.num_layers, args.dropout\n        )\n    elif args.model == \"linear\":\n        model = MLPLinear(n_features, n_classes)\n    else:\n        raise NotImplementedError(f\"Model {args.model} is not supported.\")\n\n    model = model.to(device)\n    print(f\"Model parameters: {sum(p.numel() for p in model.parameters())}\")\n\n    if args.pretrain:\n        print(\"---------- Before ----------\")\n        model.load_state_dict(\n            
torch.load(\n                f\"base/{args.dataset}-{args.model}.pt\", weights_only=False\n            )\n        )\n        model.eval()\n\n        y_soft = model(feats).exp()\n\n        y_pred = y_soft.argmax(dim=-1, keepdim=True)\n        valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)\n        test_acc = evaluate(y_pred, labels, test_idx, evaluator)\n        print(f\"Valid acc: {valid_acc:.4f} | Test acc: {test_acc:.4f}\")\n\n        print(\"---------- Correct & Smoothing ----------\")\n        cs = CorrectAndSmooth(\n            num_correction_layers=args.num_correction_layers,\n            correction_alpha=args.correction_alpha,\n            correction_adj=args.correction_adj,\n            num_smoothing_layers=args.num_smoothing_layers,\n            smoothing_alpha=args.smoothing_alpha,\n            smoothing_adj=args.smoothing_adj,\n            autoscale=args.autoscale,\n            scale=args.scale,\n        )\n\n        y_soft = cs.correct(g, y_soft, labels[train_idx], train_idx)\n        y_soft = cs.smooth(g, y_soft, labels[train_idx], train_idx)\n        y_pred = y_soft.argmax(dim=-1, keepdim=True)\n        valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)\n        test_acc = evaluate(y_pred, labels, test_idx, evaluator)\n        print(f\"Valid acc: {valid_acc:.4f} | Test acc: {test_acc:.4f}\")\n    else:\n        opt = optim.Adam(model.parameters(), lr=args.lr)\n\n        best_acc = 0\n        best_model = copy.deepcopy(model)\n\n        # training\n        print(\"---------- Training ----------\")\n        for i in range(args.epochs):\n            model.train()\n            opt.zero_grad()\n\n            logits = model(feats)\n\n            train_loss = F.nll_loss(\n                logits[train_idx], labels.squeeze(1)[train_idx]\n            )\n            train_loss.backward()\n\n            opt.step()\n\n            model.eval()\n            with torch.no_grad():\n                logits = model(feats)\n\n                y_pred = 
logits.argmax(dim=-1, keepdim=True)\n\n                train_acc = evaluate(y_pred, labels, train_idx, evaluator)\n                valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)\n\n                print(\n                    f\"Epoch {i} | Train loss: {train_loss.item():.4f} | Train acc: {train_acc:.4f} | Valid acc {valid_acc:.4f}\"\n                )\n\n                if valid_acc > best_acc:\n                    best_acc = valid_acc\n                    best_model = copy.deepcopy(model)\n\n        # testing & saving model\n        print(\"---------- Testing ----------\")\n        best_model.eval()\n\n        logits = best_model(feats)\n\n        y_pred = logits.argmax(dim=-1, keepdim=True)\n        test_acc = evaluate(y_pred, labels, test_idx, evaluator)\n        print(f\"Test acc: {test_acc:.4f}\")\n\n        if not os.path.exists(\"base\"):\n            os.makedirs(\"base\")\n\n        torch.save(\n            best_model.state_dict(), f\"base/{args.dataset}-{args.model}.pt\"\n        )\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    Correct & Smoothing Hyperparameters\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"Base predictor(C&S)\")\n\n    # Dataset\n    parser.add_argument(\"--gpu\", type=int, default=0, help=\"-1 for cpu\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-arxiv\",\n        choices=[\"ogbn-arxiv\", \"ogbn-products\"],\n    )\n    # Base predictor\n    parser.add_argument(\n        \"--model\", type=str, default=\"mlp\", choices=[\"mlp\", \"linear\"]\n    )\n    parser.add_argument(\"--num-layers\", type=int, default=3)\n    parser.add_argument(\"--hid-dim\", type=int, default=256)\n    parser.add_argument(\"--dropout\", type=float, default=0.4)\n    parser.add_argument(\"--lr\", type=float, default=0.01)\n    parser.add_argument(\"--epochs\", type=int, default=300)\n    # extra options for gat\n    parser.add_argument(\"--n-heads\", type=int, default=3)\n    
parser.add_argument(\"--attn_drop\", type=float, default=0.05)\n    # C & S\n    parser.add_argument(\n        \"--pretrain\", action=\"store_true\", help=\"Whether to perform C & S\"\n    )\n    parser.add_argument(\"--num-correction-layers\", type=int, default=50)\n    parser.add_argument(\"--correction-alpha\", type=float, default=0.979)\n    parser.add_argument(\"--correction-adj\", type=str, default=\"DAD\")\n    parser.add_argument(\"--num-smoothing-layers\", type=int, default=50)\n    parser.add_argument(\"--smoothing-alpha\", type=float, default=0.756)\n    parser.add_argument(\"--smoothing-adj\", type=str, default=\"DAD\")\n    parser.add_argument(\"--autoscale\", action=\"store_true\")\n    parser.add_argument(\"--scale\", type=float, default=20.0)\n\n    args = parser.parse_args()\n    print(args)\n\n    main()\n"
  },
  {
    "path": "examples/pytorch/correct_and_smooth/model.py",
    "content": "import dgl.function as fn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass MLPLinear(nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super(MLPLinear, self).__init__()\n        self.linear = nn.Linear(in_dim, out_dim)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.linear.reset_parameters()\n\n    def forward(self, x):\n        return F.log_softmax(self.linear(x), dim=-1)\n\n\nclass MLP(nn.Module):\n    def __init__(self, in_dim, hid_dim, out_dim, num_layers, dropout=0.0):\n        super(MLP, self).__init__()\n        assert num_layers >= 2\n\n        self.linears = nn.ModuleList()\n        self.bns = nn.ModuleList()\n        self.linears.append(nn.Linear(in_dim, hid_dim))\n        self.bns.append(nn.BatchNorm1d(hid_dim))\n\n        for _ in range(num_layers - 2):\n            self.linears.append(nn.Linear(hid_dim, hid_dim))\n            self.bns.append(nn.BatchNorm1d(hid_dim))\n\n        self.linears.append(nn.Linear(hid_dim, out_dim))\n        self.dropout = dropout\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for layer in self.linears:\n            layer.reset_parameters()\n        for layer in self.bns:\n            layer.reset_parameters()\n\n    def forward(self, x):\n        for linear, bn in zip(self.linears[:-1], self.bns):\n            x = linear(x)\n            x = F.relu(x, inplace=True)\n            x = bn(x)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n        x = self.linears[-1](x)\n        return F.log_softmax(x, dim=-1)\n\n\nclass LabelPropagation(nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    Introduced in `Learning from Labeled and Unlabeled Data with Label Propagation <https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.3864&rep=rep1&type=pdf>`_\n\n    .. 
math::\n        \\mathbf{Y}^{\\prime} = \\alpha \\cdot \\mathbf{D}^{-1/2} \\mathbf{A}\n        \\mathbf{D}^{-1/2} \\mathbf{Y} + (1 - \\alpha) \\mathbf{Y},\n\n    where unlabeled data is inferred by labeled data via propagation.\n\n    Parameters\n    ----------\n        num_layers: int\n            The number of propagations.\n        alpha: float\n            The :math:`\\alpha` coefficient.\n        adj: str\n            'DAD': D^-0.5 * A * D^-0.5\n            'DA': D^-1 * A\n            'AD': A * D^-1\n    \"\"\"\n\n    def __init__(self, num_layers, alpha, adj=\"DAD\"):\n        super(LabelPropagation, self).__init__()\n\n        self.num_layers = num_layers\n        self.alpha = alpha\n        self.adj = adj\n\n    @torch.no_grad()\n    def forward(\n        self, g, labels, mask=None, post_step=lambda y: y.clamp_(0.0, 1.0)\n    ):\n        with g.local_scope():\n            if labels.dtype == torch.long:\n                labels = F.one_hot(labels.view(-1)).to(torch.float32)\n\n            y = labels\n            if mask is not None:\n                y = torch.zeros_like(labels)\n                y[mask] = labels[mask]\n\n            last = (1 - self.alpha) * y\n            degs = g.in_degrees().float().clamp(min=1)\n            norm = (\n                torch.pow(degs, -0.5 if self.adj == \"DAD\" else -1)\n                .to(labels.device)\n                .unsqueeze(1)\n            )\n\n            for _ in range(self.num_layers):\n                # Assume the graphs to be undirected\n                if self.adj in [\"DAD\", \"AD\"]:\n                    y = norm * y\n\n                g.ndata[\"h\"] = y\n                g.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n                y = self.alpha * g.ndata.pop(\"h\")\n\n                if self.adj in [\"DAD\", \"DA\"]:\n                    y = y * norm\n\n                y = post_step(last + y)\n\n            return y\n\n\nclass CorrectAndSmooth(nn.Module):\n    r\"\"\"\n\n    Description\n    
-----------\n    Introduced in `Combining Label Propagation and Simple Models Out-performs Graph Neural Networks <https://arxiv.org/abs/2010.13993>`_\n\n    Parameters\n    ----------\n        num_correction_layers: int\n            The number of correct propagations.\n        correction_alpha: float\n            The coefficient of correction.\n        correction_adj: str\n            'DAD': D^-0.5 * A * D^-0.5\n            'DA': D^-1 * A\n            'AD': A * D^-1\n        num_smoothing_layers: int\n            The number of smooth propagations.\n        smoothing_alpha: float\n            The coefficient of smoothing.\n        smoothing_adj: str\n            'DAD': D^-0.5 * A * D^-0.5\n            'DA': D^-1 * A\n            'AD': A * D^-1\n        autoscale: bool, optional\n            If set to True, will automatically determine the scaling factor :math:`\\sigma`. Default is True.\n        scale: float, optional\n            The scaling factor :math:`\\sigma`, in case :obj:`autoscale = False`. 
Default is 1.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_correction_layers,\n        correction_alpha,\n        correction_adj,\n        num_smoothing_layers,\n        smoothing_alpha,\n        smoothing_adj,\n        autoscale=True,\n        scale=1.0,\n    ):\n        super(CorrectAndSmooth, self).__init__()\n\n        self.autoscale = autoscale\n        self.scale = scale\n\n        self.prop1 = LabelPropagation(\n            num_correction_layers, correction_alpha, correction_adj\n        )\n        self.prop2 = LabelPropagation(\n            num_smoothing_layers, smoothing_alpha, smoothing_adj\n        )\n\n    def correct(self, g, y_soft, y_true, mask):\n        with g.local_scope():\n            assert abs(float(y_soft.sum()) / y_soft.size(0) - 1.0) < 1e-2\n            numel = (\n                int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)\n            )\n            assert y_true.size(0) == numel\n\n            if y_true.dtype == torch.long:\n                y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to(\n                    y_soft.dtype\n                )\n\n            error = torch.zeros_like(y_soft)\n            error[mask] = y_true - y_soft[mask]\n\n            if self.autoscale:\n                smoothed_error = self.prop1(\n                    g, error, post_step=lambda x: x.clamp_(-1.0, 1.0)\n                )\n                sigma = error[mask].abs().sum() / numel\n                scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True)\n                scale[scale.isinf() | (scale > 1000)] = 1.0\n\n                result = y_soft + scale * smoothed_error\n                result[result.isnan()] = y_soft[result.isnan()]\n                return result\n            else:\n\n                def fix_input(x):\n                    x[mask] = error[mask]\n                    return x\n\n                smoothed_error = self.prop1(g, error, post_step=fix_input)\n\n                result = y_soft + 
self.scale * smoothed_error\n                result[result.isnan()] = y_soft[result.isnan()]\n                return result\n\n    def smooth(self, g, y_soft, y_true, mask):\n        with g.local_scope():\n            numel = (\n                int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)\n            )\n            assert y_true.size(0) == numel\n\n            if y_true.dtype == torch.long:\n                y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to(\n                    y_soft.dtype\n                )\n\n            y_soft[mask] = y_true\n            return self.prop2(g, y_soft)\n"
  },
  {
    "path": "examples/pytorch/dagnn/README.md",
    "content": "# DAGNN\n\nThis DGL example implements the GNN model proposed in the paper [Towards Deeper Graph Neural Networks](https://arxiv.org/abs/2007.09296).\n\nPaper link: https://arxiv.org/abs/2007.09296\n\nAuthor's code: https://github.com/divelab/DeeperGNN\n\nContributor: Liu Tang ([@lt610](https://github.com/lt610))\n\n## Dependencies\n- Python 3.6.10\n- PyTorch 1.4.0\n- numpy 1.18.1\n- dgl 0.5.3\n- tqdm 4.44.1\n\n## Dataset\n\nThe DGL's built-in Cora, Pubmed and Citeseer datasets. Dataset summary:\n\n| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |\n| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |\n| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1000 |\n| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1000 |\n| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1000 |\n\n## Arguments\n\n###### Dataset options\n```\n--dataset          str     The graph dataset name.             Default is 'Cora'.\n```\n\n###### GPU options\n```\n--gpu              int     GPU index.                          Default is -1, using CPU.\n```\n\n###### Model options\n```\n--runs             int     Number of training runs.               Default is 1\n--epochs           int     Number of training epochs.             Default is 1500.\n--early-stopping   int     Early stopping patience rounds.        Default is 100.\n--lr               float   Adam optimizer learning rate.          Default is 0.01.\n--lamb             float   L2 regularization coefficient.         Default is 5e-3.\n--k                int     Number of propagation layers.          Default is 10.\n--hid-dim          int     Hidden layer dimensionalities.         
Default is 64.\n--dropout          float   Dropout rate                           Default is 0.8\n```\n\n## Examples\n\nTrain a model which follows the original hyperparameters on different datasets.\n```bash\n# Cora:\npython main.py --dataset Cora --gpu 0 --runs 100 --lamb 0.005 --k 12\n# Citeseer:\npython main.py --dataset Citeseer --gpu 0 --runs 100 --lamb 0.02 --k 16\n# Pubmed:\npython main.py --dataset Pubmed --gpu 0 --runs 100 --lamb 0.005 --k 20\n```\n### Performance\n\n#### On Cora, Citeseer and Pubmed\n| Dataset | Cora | Citeseer | Pubmed |\n| :-: | :-: | :-: | :-: |\n| Accuracy Reported(100 runs) | 84.4 ± 0.5 | 73.3 ± 0.6 | 80.5 ± 0.5 |\n| Accuracy DGL(100 runs) | 84.3 ± 0.5 | 73.1 ± 0.9 | 80.5 ± 0.4 |\n"
  },
  {
    "path": "examples/pytorch/dagnn/main.py",
    "content": "import argparse\n\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\nfrom torch import nn\nfrom torch.nn import functional as F, Parameter\nfrom tqdm import trange\nfrom utils import evaluate, generate_random_seeds, set_random_state\n\n\nclass DAGNNConv(nn.Module):\n    def __init__(self, in_dim, k):\n        super(DAGNNConv, self).__init__()\n\n        self.s = Parameter(torch.FloatTensor(in_dim, 1))\n        self.k = k\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"sigmoid\")\n        nn.init.xavier_uniform_(self.s, gain=gain)\n\n    def forward(self, graph, feats):\n        with graph.local_scope():\n            results = [feats]\n\n            degs = graph.in_degrees().float()\n            norm = torch.pow(degs, -0.5)\n            norm = norm.to(feats.device).unsqueeze(1)\n\n            for _ in range(self.k):\n                feats = feats * norm\n                graph.ndata[\"h\"] = feats\n                graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n                feats = graph.ndata[\"h\"]\n                feats = feats * norm\n                results.append(feats)\n\n            H = torch.stack(results, dim=1)\n            S = F.sigmoid(torch.matmul(H, self.s))\n            S = S.permute(0, 2, 1)\n            H = torch.matmul(S, H).squeeze()\n\n            return H\n\n\nclass MLPLayer(nn.Module):\n    def __init__(self, in_dim, out_dim, bias=True, activation=None, dropout=0):\n        super(MLPLayer, self).__init__()\n\n        self.linear = nn.Linear(in_dim, out_dim, bias=bias)\n        self.activation = activation\n        self.dropout = nn.Dropout(dropout)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = 1.0\n        if self.activation is F.relu:\n            gain = nn.init.calculate_gain(\"relu\")\n        
nn.init.xavier_uniform_(self.linear.weight, gain=gain)\n        if self.linear.bias is not None:\n            nn.init.zeros_(self.linear.bias)\n\n    def forward(self, feats):\n        feats = self.dropout(feats)\n        feats = self.linear(feats)\n        if self.activation:\n            feats = self.activation(feats)\n\n        return feats\n\n\nclass DAGNN(nn.Module):\n    def __init__(\n        self,\n        k,\n        in_dim,\n        hid_dim,\n        out_dim,\n        bias=True,\n        activation=F.relu,\n        dropout=0,\n    ):\n        super(DAGNN, self).__init__()\n        self.mlp = nn.ModuleList()\n        self.mlp.append(\n            MLPLayer(\n                in_dim=in_dim,\n                out_dim=hid_dim,\n                bias=bias,\n                activation=activation,\n                dropout=dropout,\n            )\n        )\n        self.mlp.append(\n            MLPLayer(\n                in_dim=hid_dim,\n                out_dim=out_dim,\n                bias=bias,\n                activation=None,\n                dropout=dropout,\n            )\n        )\n        self.dagnn = DAGNNConv(in_dim=out_dim, k=k)\n\n    def forward(self, graph, feats):\n        for layer in self.mlp:\n            feats = layer(feats)\n        feats = self.dagnn(graph, feats)\n        return feats\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # Load from DGL dataset\n    if args.dataset == \"Cora\":\n        dataset = CoraGraphDataset()\n    elif args.dataset == \"Citeseer\":\n        dataset = CiteseerGraphDataset()\n    elif args.dataset == \"Pubmed\":\n        dataset = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Dataset {} is invalid.\".format(args.dataset))\n\n    graph = dataset[0]\n    graph = graph.add_self_loop()\n\n    # check cuda\n    if args.gpu >= 0 and torch.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n    else:\n   
     device = \"cpu\"\n\n    # retrieve the number of classes\n    n_classes = dataset.num_classes\n\n    # retrieve labels of ground truth\n    labels = graph.ndata.pop(\"label\").to(device).long()\n\n    # Extract node features\n    feats = graph.ndata.pop(\"feat\").to(device)\n    n_features = feats.shape[-1]\n\n    # retrieve masks for train/validation/test\n    train_mask = graph.ndata.pop(\"train_mask\")\n    val_mask = graph.ndata.pop(\"val_mask\")\n    test_mask = graph.ndata.pop(\"test_mask\")\n\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)\n    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)\n    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)\n\n    graph = graph.to(device)\n\n    # Step 2: Create model =================================================================== #\n    model = DAGNN(\n        k=args.k,\n        in_dim=n_features,\n        hid_dim=args.hid_dim,\n        out_dim=n_classes,\n        dropout=args.dropout,\n    )\n    model = model.to(device)\n\n    # Step 3: Create training components ===================================================== #\n    loss_fn = F.cross_entropy\n    opt = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.lamb\n    )\n\n    # Step 4: training epochs =============================================================== #\n    loss = float(\"inf\")\n    best_acc = 0\n    no_improvement = 0\n    epochs = trange(args.epochs, desc=\"Accuracy & Loss\")\n\n    for _ in epochs:\n        model.train()\n\n        logits = model(graph, feats)\n\n        # compute loss\n        train_loss = loss_fn(logits[train_idx], labels[train_idx])\n\n        # backward\n        opt.zero_grad()\n        train_loss.backward()\n        opt.step()\n\n        (\n            train_loss,\n            train_acc,\n            valid_loss,\n            valid_acc,\n            test_loss,\n            test_acc,\n        ) = evaluate(\n          
  model, graph, feats, labels, (train_idx, val_idx, test_idx)\n        )\n\n        # Print out performance\n        epochs.set_description(\n            \"Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}\".format(\n                train_acc, train_loss.item(), valid_acc, valid_loss.item()\n            )\n        )\n\n        if valid_loss > loss:\n            no_improvement += 1\n            if no_improvement == args.early_stopping:\n                print(\"Early stop.\")\n                break\n        else:\n            no_improvement = 0\n            loss = valid_loss\n            best_acc = test_acc\n\n    print(\"Test Acc {:.4f}\".format(best_acc))\n    return best_acc\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    DAGNN Model Hyperparameters\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"DAGNN\")\n    # data source params\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"Cora\",\n        choices=[\"Cora\", \"Citeseer\", \"Pubmed\"],\n        help=\"Name of dataset.\",\n    )\n    # cuda params\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU index. 
Default: -1, using CPU.\"\n    )\n    # training params\n    parser.add_argument(\"--runs\", type=int, default=1, help=\"Training runs.\")\n    parser.add_argument(\n        \"--epochs\", type=int, default=1500, help=\"Training epochs.\"\n    )\n    parser.add_argument(\n        \"--early-stopping\",\n        type=int,\n        default=100,\n        help=\"Patient epochs to wait before early stopping.\",\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.01, help=\"Learning rate.\")\n    parser.add_argument(\"--lamb\", type=float, default=0.005, help=\"L2 reg.\")\n    # model params\n    parser.add_argument(\n        \"--k\", type=int, default=12, help=\"Number of propagation layers.\"\n    )\n    parser.add_argument(\n        \"--hid-dim\", type=int, default=64, help=\"Hidden layer dimensionalities.\"\n    )\n    parser.add_argument(\"--dropout\", type=float, default=0.8, help=\"dropout\")\n    args = parser.parse_args()\n    print(args)\n\n    acc_lists = []\n    random_seeds = generate_random_seeds(seed=1222, nums=args.runs)\n\n    for run in range(args.runs):\n        set_random_state(random_seeds[run])\n        acc_lists.append(main(args))\n\n    acc_lists = np.array(acc_lists)\n\n    mean = np.around(np.mean(acc_lists, axis=0), decimals=4)\n    std = np.around(np.std(acc_lists, axis=0), decimals=4)\n\n    print(\"Total acc: \", acc_lists)\n    print(\"mean\", mean)\n    print(\"std\", std)\n"
  },
  {
    "path": "examples/pytorch/dagnn/utils.py",
    "content": "import random\n\nimport numpy as np\nimport torch\nfrom torch.nn import functional as F\n\n\ndef evaluate(model, graph, feats, labels, idxs):\n    model.eval()\n    with torch.no_grad():\n        logits = model(graph, feats)\n        results = ()\n        for idx in idxs:\n            loss = F.cross_entropy(logits[idx], labels[idx])\n            acc = torch.sum(\n                logits[idx].argmax(dim=1) == labels[idx]\n            ).item() / len(idx)\n            results += (loss, acc)\n    return results\n\n\ndef generate_random_seeds(seed, nums):\n    random.seed(seed)\n    return [random.randint(1, 999999999) for _ in range(nums)]\n\n\ndef set_random_state(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n        torch.backends.cudnn.deterministic = True\n"
  },
  {
    "path": "examples/pytorch/deepergcn/README.md",
    "content": "# DGL Implementation of DeeperGCN\n\nThis DGL example implements the GNN model proposed in the paper [DeeperGCN: All You Need to Train Deeper GCNs](https://arxiv.org/abs/2006.07739). For the original implementation, see [here](https://github.com/lightaime/deep_gcns_torch).\n\nContributor: [xnuohz](https://github.com/xnuohz)\n\n### Requirements\nThe codebase is implemented in Python 3.7. For version requirement of packages, see below.\n\n```\ndgl 0.6.0.post1\ntorch 1.7.0\nogb 1.3.0\n```\n\n### The graph datasets used in this example\n\nOpen Graph Benchmark(OGB). Dataset summary:\n\n###### Graph Property Prediction\n\n|   Dataset   | #Graphs | #Node Feats | #Edge Feats | Metric  |\n| :---------: | :-----: | :---------: | :---------: | :-----: |\n| ogbg-molhiv | 41,127  |      9      |      3      | ROC-AUC |\n\n### Usage\n\nTrain a model which follows the original hyperparameters on different datasets.\n```bash\n# ogbg-molhiv\npython main.py --gpu 0 --learn-beta\n```\n\n### Performance\n\n* Table 6: Numbers associated with \"Table 6\" are the ones from table 6 in the paper.\n* Author: Numbers associated with \"Author\" are the ones we got by running the original code.\n* DGL: Numbers associated with \"DGL\" are the ones we got by running the DGL example.\n\n|     Dataset      | ogbg-molhiv |\n| :--------------: | :---------: |\n| Results(Table 6) |    0.786    |\n| Results(Author)  |    0.781    |\n|   Results(DGL)   |    0.778    |\n\n### Speed\n\n|     Dataset     | ogbg-molhiv |\n| :-------------: | :---------: |\n| Results(Author) |   11.833    |\n|  Results(DGL)   |    8.965    |\n"
  },
  {
    "path": "examples/pytorch/deepergcn/layers.py",
    "content": "import dgl.function as fn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.functional import edge_softmax\nfrom modules import MessageNorm, MLP\nfrom ogb.graphproppred.mol_encoder import BondEncoder\n\n\nclass GENConv(nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    Generalized Message Aggregator was introduced in \"DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>\"\n\n    Parameters\n    ----------\n    in_dim: int\n        Input size.\n    out_dim: int\n        Output size.\n    aggregator: str\n        Type of aggregation. Default is 'softmax'.\n    beta: float\n        A continuous variable called an inverse temperature. Default is 1.0.\n    learn_beta: bool\n        Whether beta is a learnable variable or not. Default is False.\n    p: float\n        Initial power for power mean aggregation. Default is 1.0.\n    learn_p: bool\n        Whether p is a learnable variable or not. Default is False.\n    msg_norm: bool\n        Whether message normalization is used. Default is False.\n    learn_msg_scale: bool\n        Whether s is a learnable scaling factor or not in message normalization. Default is False.\n    mlp_layers: int\n        The number of MLP layers. Default is 1.\n    eps: float\n        A small positive constant in message construction function. 
Default is 1e-7.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        aggregator=\"softmax\",\n        beta=1.0,\n        learn_beta=False,\n        p=1.0,\n        learn_p=False,\n        msg_norm=False,\n        learn_msg_scale=False,\n        mlp_layers=1,\n        eps=1e-7,\n    ):\n        super(GENConv, self).__init__()\n\n        self.aggr = aggregator\n        self.eps = eps\n\n        channels = [in_dim]\n        for _ in range(mlp_layers - 1):\n            channels.append(in_dim * 2)\n        channels.append(out_dim)\n\n        self.mlp = MLP(channels)\n        self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None\n\n        self.beta = (\n            nn.Parameter(torch.Tensor([beta]), requires_grad=True)\n            if learn_beta and self.aggr == \"softmax\"\n            else beta\n        )\n        self.p = (\n            nn.Parameter(torch.Tensor([p]), requires_grad=True)\n            if learn_p\n            else p\n        )\n\n        self.edge_encoder = BondEncoder(in_dim)\n\n    def forward(self, g, node_feats, edge_feats):\n        with g.local_scope():\n            # Node and edge feature size need to match.\n            g.ndata[\"h\"] = node_feats\n            g.edata[\"h\"] = self.edge_encoder(edge_feats)\n            g.apply_edges(fn.u_add_e(\"h\", \"h\", \"m\"))\n\n            if self.aggr == \"softmax\":\n                g.edata[\"m\"] = F.relu(g.edata[\"m\"]) + self.eps\n                g.edata[\"a\"] = edge_softmax(g, g.edata[\"m\"] * self.beta)\n                g.update_all(\n                    lambda edge: {\"x\": edge.data[\"m\"] * edge.data[\"a\"]},\n                    fn.sum(\"x\", \"m\"),\n                )\n\n            elif self.aggr == \"power\":\n                minv, maxv = 1e-7, 1e1\n                torch.clamp_(g.edata[\"m\"], minv, maxv)\n                g.update_all(\n                    lambda edge: {\"x\": torch.pow(edge.data[\"m\"], self.p)},\n                    
fn.mean(\"x\", \"m\"),\n                )\n                torch.clamp_(g.ndata[\"m\"], minv, maxv)\n                g.ndata[\"m\"] = torch.pow(g.ndata[\"m\"], self.p)\n\n            else:\n                raise NotImplementedError(\n                    f\"Aggregator {self.aggr} is not supported.\"\n                )\n\n            if self.msg_norm is not None:\n                g.ndata[\"m\"] = self.msg_norm(node_feats, g.ndata[\"m\"])\n\n            feats = node_feats + g.ndata[\"m\"]\n\n            return self.mlp(feats)\n"
  },
  {
    "path": "examples/pytorch/deepergcn/main.py",
    "content": "import argparse\nimport copy\nimport time\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom models import DeeperGCN\nfrom ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator\nfrom torch.utils.data import DataLoader\n\n\ndef train(model, device, data_loader, opt, loss_fn):\n    model.train()\n\n    train_loss = []\n    for g, labels in data_loader:\n        g = g.to(device)\n        labels = labels.to(torch.float32).to(device)\n        logits = model(g, g.edata[\"feat\"], g.ndata[\"feat\"])\n        loss = loss_fn(logits, labels)\n        train_loss.append(loss.item())\n\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n    return sum(train_loss) / len(train_loss)\n\n\n@torch.no_grad()\ndef test(model, device, data_loader, evaluator):\n    model.eval()\n    y_true, y_pred = [], []\n\n    for g, labels in data_loader:\n        g = g.to(device)\n        logits = model(g, g.edata[\"feat\"], g.ndata[\"feat\"])\n        y_true.append(labels.detach().cpu())\n        y_pred.append(logits.detach().cpu())\n\n    y_true = torch.cat(y_true, dim=0).numpy()\n    y_pred = torch.cat(y_pred, dim=0).numpy()\n\n    return evaluator.eval({\"y_true\": y_true, \"y_pred\": y_pred})[\"rocauc\"]\n\n\ndef main():\n    # check cuda\n    device = (\n        f\"cuda:{args.gpu}\"\n        if args.gpu >= 0 and torch.cuda.is_available()\n        else \"cpu\"\n    )\n\n    # load ogb dataset & evaluator\n    dataset = DglGraphPropPredDataset(name=\"ogbg-molhiv\")\n    evaluator = Evaluator(name=\"ogbg-molhiv\")\n\n    g, _ = dataset[0]\n    node_feat_dim = g.ndata[\"feat\"].size()[-1]\n    edge_feat_dim = g.edata[\"feat\"].size()[-1]\n    n_classes = dataset.num_tasks\n\n    split_idx = dataset.get_idx_split()\n    train_loader = DataLoader(\n        dataset[split_idx[\"train\"]],\n        batch_size=args.batch_size,\n        shuffle=True,\n        collate_fn=collate_dgl,\n    )\n    valid_loader = DataLoader(\n    
    dataset[split_idx[\"valid\"]],\n        batch_size=args.batch_size,\n        shuffle=False,\n        collate_fn=collate_dgl,\n    )\n    test_loader = DataLoader(\n        dataset[split_idx[\"test\"]],\n        batch_size=args.batch_size,\n        shuffle=False,\n        collate_fn=collate_dgl,\n    )\n\n    # load model\n    model = DeeperGCN(\n        node_feat_dim=node_feat_dim,\n        edge_feat_dim=edge_feat_dim,\n        hid_dim=args.hid_dim,\n        out_dim=n_classes,\n        num_layers=args.num_layers,\n        dropout=args.dropout,\n        learn_beta=args.learn_beta,\n    ).to(device)\n\n    print(model)\n\n    opt = optim.Adam(model.parameters(), lr=args.lr)\n    loss_fn = nn.BCEWithLogitsLoss()\n\n    # training & validation & testing\n    best_auc = 0\n    best_model = copy.deepcopy(model)\n    times = []\n\n    print(\"---------- Training ----------\")\n    for i in range(args.epochs):\n        t1 = time.time()\n        train_loss = train(model, device, train_loader, opt, loss_fn)\n        t2 = time.time()\n\n        if i >= 5:\n            times.append(t2 - t1)\n\n        train_auc = test(model, device, train_loader, evaluator)\n        valid_auc = test(model, device, valid_loader, evaluator)\n\n        print(\n            f\"Epoch {i} | Train Loss: {train_loss:.4f} | Train Auc: {train_auc:.4f} | Valid Auc: {valid_auc:.4f}\"\n        )\n\n        if valid_auc > best_auc:\n            best_auc = valid_auc\n            best_model = copy.deepcopy(model)\n\n    print(\"---------- Testing ----------\")\n    test_auc = test(best_model, device, test_loader, evaluator)\n    print(f\"Test Auc: {test_auc}\")\n    if len(times) > 0:\n        print(\"Times/epoch: \", sum(times) / len(times))\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    DeeperGCN Hyperparameters\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"DeeperGCN\")\n    # training\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU index, -1 for 
CPU.\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=300, help=\"Number of epochs to train.\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.01, help=\"Learning rate.\")\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.2, help=\"Dropout rate.\"\n    )\n    parser.add_argument(\n        \"--batch-size\", type=int, default=2048, help=\"Batch size.\"\n    )\n    # model\n    parser.add_argument(\n        \"--num-layers\", type=int, default=7, help=\"Number of GNN layers.\"\n    )\n    parser.add_argument(\n        \"--hid-dim\", type=int, default=256, help=\"Hidden channel size.\"\n    )\n    # learnable parameters in aggr\n    parser.add_argument(\"--learn-beta\", action=\"store_true\")\n\n    args = parser.parse_args()\n    print(args)\n\n    main()\n"
  },
  {
    "path": "examples/pytorch/deepergcn/models.py",
    "content": "import dgl.function as fn\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch.glob import AvgPooling\nfrom layers import GENConv\nfrom ogb.graphproppred.mol_encoder import AtomEncoder\n\n\nclass DeeperGCN(nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    Introduced in \"DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>\"\n\n    Parameters\n    ----------\n    node_feat_dim: int\n        Size of node feature.\n    edge_feat_dim: int\n        Size of edge feature.\n    hid_dim: int\n        Size of hidden representations.\n    out_dim: int\n        Size of output.\n    num_layers: int\n        Number of graph convolutional layers.\n    dropout: float\n        Dropout rate. Default is 0.\n    beta: float\n        A continuous variable called an inverse temperature. Default is 1.0.\n    learn_beta: bool\n        Whether beta is a learnable weight. Default is False.\n    aggr: str\n        Type of aggregation. Default is 'softmax'.\n    mlp_layers: int\n        Number of MLP layers in message normalization. 
Default is 1.\n    \"\"\"\n\n    def __init__(\n        self,\n        node_feat_dim,\n        edge_feat_dim,\n        hid_dim,\n        out_dim,\n        num_layers,\n        dropout=0.0,\n        beta=1.0,\n        learn_beta=False,\n        aggr=\"softmax\",\n        mlp_layers=1,\n    ):\n        super(DeeperGCN, self).__init__()\n\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.gcns = nn.ModuleList()\n        self.norms = nn.ModuleList()\n\n        for _ in range(self.num_layers):\n            conv = GENConv(\n                in_dim=hid_dim,\n                out_dim=hid_dim,\n                aggregator=aggr,\n                beta=beta,\n                learn_beta=learn_beta,\n                mlp_layers=mlp_layers,\n            )\n\n            self.gcns.append(conv)\n            self.norms.append(nn.BatchNorm1d(hid_dim, affine=True))\n\n        self.node_encoder = AtomEncoder(hid_dim)\n        self.pooling = AvgPooling()\n        self.output = nn.Linear(hid_dim, out_dim)\n\n    def forward(self, g, edge_feats, node_feats=None):\n        with g.local_scope():\n            hv = self.node_encoder(node_feats)\n            he = edge_feats\n\n            for layer in range(self.num_layers):\n                hv1 = self.norms[layer](hv)\n                hv1 = F.relu(hv1)\n                hv1 = F.dropout(hv1, p=self.dropout, training=self.training)\n                hv = self.gcns[layer](g, hv1, he) + hv\n\n            h_g = self.pooling(g, hv)\n\n            return self.output(h_g)\n"
  },
  {
    "path": "examples/pytorch/deepergcn/modules.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass MLP(nn.Sequential):\n    r\"\"\"\n\n    Description\n    -----------\n    From equation (5) in \"DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>\"\n    \"\"\"\n\n    def __init__(self, channels, act=\"relu\", dropout=0.0, bias=True):\n        layers = []\n\n        for i in range(1, len(channels)):\n            layers.append(nn.Linear(channels[i - 1], channels[i], bias))\n            if i < len(channels) - 1:\n                layers.append(nn.BatchNorm1d(channels[i], affine=True))\n                layers.append(nn.ReLU())\n                layers.append(nn.Dropout(dropout))\n\n        super(MLP, self).__init__(*layers)\n\n\nclass MessageNorm(nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    Message normalization was introduced in \"DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>\"\n\n    Parameters\n    ----------\n    learn_scale: bool\n        Whether s is a learnable scaling factor or not. Default is False.\n    \"\"\"\n\n    def __init__(self, learn_scale=False):\n        super(MessageNorm, self).__init__()\n        self.scale = nn.Parameter(\n            torch.FloatTensor([1.0]), requires_grad=learn_scale\n        )\n\n    def forward(self, feats, msg, p=2):\n        msg = F.normalize(msg, p=2, dim=-1)\n        feats_norm = feats.norm(p=p, dim=-1, keepdim=True)\n        return msg * feats_norm * self.scale\n"
  },
  {
    "path": "examples/pytorch/deepwalk/README.md",
    "content": "# DeepWalk\n\n- Paper link: [here](https://arxiv.org/pdf/1403.6652.pdf)\n\nThe example code was moved to examples/pytorch/ogb/deepwalk.\n"
  },
  {
    "path": "examples/pytorch/dgi/README.md",
    "content": "Deep Graph Infomax (DGI)\n========================\n\n- Paper link: [https://arxiv.org/abs/1809.10341](https://arxiv.org/abs/1809.10341)\n- Author's code repo (in Pytorch):\n  [https://github.com/PetarV-/DGI](https://github.com/PetarV-/DGI)\n\nDependencies\n------------\n- PyTorch 0.4.1+\n- requests\n\n```bash\npip install torch requests\n```\n\nHow to run\n----------\n\nRun with following:\n\n```bash\npython3 train.py --dataset=cora --gpu=0 --self-loop\n```\n\n```bash\npython3 train.py --dataset=citeseer --gpu=0\n```\n\n```bash\npython3 train.py --dataset=pubmed --gpu=0\n```\n\nResults\n-------\n* cora: ~81.6 (81.2-82.1) (paper: 82.3)\n* citeseer: ~69.4 (paper: 71.8)\n* pubmed: ~76.1 (paper: 76.8)\n"
  },
  {
    "path": "examples/pytorch/dgi/dgi.py",
    "content": "\"\"\"\nDeep Graph Infomax in DGL\n\nReferences\n----------\nPapers: https://arxiv.org/abs/1809.10341\nAuthor's code: https://github.com/PetarV-/DGI\n\"\"\"\n\nimport math\n\nimport torch\nimport torch.nn as nn\nfrom gcn import GCN\n\n\nclass Encoder(nn.Module):\n    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):\n        super(Encoder, self).__init__()\n        self.g = g\n        self.conv = GCN(\n            g, in_feats, n_hidden, n_hidden, n_layers, activation, dropout\n        )\n\n    def forward(self, features, corrupt=False):\n        if corrupt:\n            perm = torch.randperm(self.g.num_nodes())\n            features = features[perm]\n        features = self.conv(features)\n        return features\n\n\nclass Discriminator(nn.Module):\n    def __init__(self, n_hidden):\n        super(Discriminator, self).__init__()\n        self.weight = nn.Parameter(torch.Tensor(n_hidden, n_hidden))\n        self.reset_parameters()\n\n    def uniform(self, size, tensor):\n        bound = 1.0 / math.sqrt(size)\n        if tensor is not None:\n            tensor.data.uniform_(-bound, bound)\n\n    def reset_parameters(self):\n        size = self.weight.size(0)\n        self.uniform(size, self.weight)\n\n    def forward(self, features, summary):\n        features = torch.matmul(features, torch.matmul(self.weight, summary))\n        return features\n\n\nclass DGI(nn.Module):\n    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):\n        super(DGI, self).__init__()\n        self.encoder = Encoder(\n            g, in_feats, n_hidden, n_layers, activation, dropout\n        )\n        self.discriminator = Discriminator(n_hidden)\n        self.loss = nn.BCEWithLogitsLoss()\n\n    def forward(self, features):\n        positive = self.encoder(features, corrupt=False)\n        negative = self.encoder(features, corrupt=True)\n        summary = torch.sigmoid(positive.mean(dim=0))\n\n        positive = 
self.discriminator(positive, summary)\n        negative = self.discriminator(negative, summary)\n\n        l1 = self.loss(positive, torch.ones_like(positive))\n        l2 = self.loss(negative, torch.zeros_like(negative))\n\n        return l1 + l2\n\n\nclass Classifier(nn.Module):\n    def __init__(self, n_hidden, n_classes):\n        super(Classifier, self).__init__()\n        self.fc = nn.Linear(n_hidden, n_classes)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.fc.reset_parameters()\n\n    def forward(self, features):\n        features = self.fc(features)\n        return torch.log_softmax(features, dim=-1)\n"
  },
  {
    "path": "examples/pytorch/dgi/gcn.py",
    "content": "\"\"\"\nThis code was copied from the GCN implementation in DGL examples.\n\"\"\"\nimport torch\nimport torch.nn as nn\n\nfrom dgl.nn.pytorch import GraphConv\n\n\nclass GCN(nn.Module):\n    def __init__(\n        self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(GCN, self).__init__()\n        self.g = g\n        self.layers = nn.ModuleList()\n        # input layer\n        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(\n                GraphConv(n_hidden, n_hidden, activation=activation)\n            )\n        # output layer\n        self.layers.append(GraphConv(n_hidden, n_classes))\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, features):\n        h = features\n        for i, layer in enumerate(self.layers):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(self.g, h)\n        return h\n"
  },
  {
    "path": "examples/pytorch/dgi/train.py",
    "content": "import argparse, time\n\nimport dgl\nimport networkx as nx\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgi import Classifier, DGI\nfrom dgl import DGLGraph\nfrom dgl.data import load_data, register_data_args\n\n\ndef evaluate(model, features, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef main(args):\n    # load and preprocess dataset\n    data = load_data(args)\n    g = data[0]\n    features = torch.FloatTensor(g.ndata[\"feat\"])\n    labels = torch.LongTensor(g.ndata[\"label\"])\n    if hasattr(torch, \"BoolTensor\"):\n        train_mask = torch.BoolTensor(g.ndata[\"train_mask\"])\n        val_mask = torch.BoolTensor(g.ndata[\"val_mask\"])\n        test_mask = torch.BoolTensor(g.ndata[\"test_mask\"])\n    else:\n        train_mask = torch.ByteTensor(g.ndata[\"train_mask\"])\n        val_mask = torch.ByteTensor(g.ndata[\"val_mask\"])\n        test_mask = torch.ByteTensor(g.ndata[\"test_mask\"])\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = g.num_edges()\n\n    if args.gpu < 0:\n        cuda = False\n    else:\n        cuda = True\n        torch.cuda.set_device(args.gpu)\n        features = features.cuda()\n        labels = labels.cuda()\n        train_mask = train_mask.cuda()\n        val_mask = val_mask.cuda()\n        test_mask = test_mask.cuda()\n\n    # add self loop\n    if args.self_loop:\n        g = dgl.remove_self_loop(g)\n        g = dgl.add_self_loop(g)\n    n_edges = g.num_edges()\n\n    if args.gpu >= 0:\n        g = g.to(args.gpu)\n    # create DGI model\n    dgi = DGI(\n        g,\n        in_feats,\n        args.n_hidden,\n        args.n_layers,\n        nn.PReLU(args.n_hidden),\n       
 args.dropout,\n    )\n\n    if cuda:\n        dgi.cuda()\n\n    dgi_optimizer = torch.optim.Adam(\n        dgi.parameters(), lr=args.dgi_lr, weight_decay=args.weight_decay\n    )\n\n    # train deep graph infomax\n    cnt_wait = 0\n    best = 1e9\n    best_t = 0\n    mean = 0\n    for epoch in range(args.n_dgi_epochs):\n        dgi.train()\n        if epoch >= 3:\n            t0 = time.time()\n\n        dgi_optimizer.zero_grad()\n        loss = dgi(features)\n        loss.backward()\n        dgi_optimizer.step()\n\n        if loss < best:\n            best = loss\n            best_t = epoch\n            cnt_wait = 0\n            torch.save(dgi.state_dict(), \"best_dgi.pkl\")\n        else:\n            cnt_wait += 1\n\n        if cnt_wait == args.patience:\n            print(\"Early stopping!\")\n            break\n\n        if epoch >= 3:\n            mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)\n\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch, mean, loss.item(), n_edges / mean / 1000\n                )\n            )\n\n    # create classifier model\n    classifier = Classifier(args.n_hidden, n_classes)\n    if cuda:\n        classifier.cuda()\n\n    classifier_optimizer = torch.optim.Adam(\n        classifier.parameters(),\n        lr=args.classifier_lr,\n        weight_decay=args.weight_decay,\n    )\n\n    # train classifier\n    print(\"Loading {}th epoch\".format(best_t))\n    dgi.load_state_dict(torch.load(\"best_dgi.pkl\", weights_only=False))\n    embeds = dgi.encoder(features, corrupt=False)\n    embeds = embeds.detach()\n    mean = 0\n    for epoch in range(args.n_classifier_epochs):\n        classifier.train()\n        if epoch >= 3:\n            t0 = time.time()\n\n        classifier_optimizer.zero_grad()\n        preds = classifier(embeds)\n        loss = F.nll_loss(preds[train_mask], labels[train_mask])\n        
loss.backward()\n        classifier_optimizer.step()\n\n        if epoch >= 3:\n            mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)\n\n            acc = evaluate(classifier, embeds, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    mean,\n                    loss.item(),\n                    acc,\n                    n_edges / mean / 1000,\n                )\n            )\n\n    print()\n    acc = evaluate(classifier, embeds, labels, test_mask)\n    print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"DGI\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.0, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\n        \"--dgi-lr\", type=float, default=1e-3, help=\"dgi learning rate\"\n    )\n    parser.add_argument(\n        \"--classifier-lr\",\n        type=float,\n        default=1e-2,\n        help=\"classifier learning rate\",\n    )\n    parser.add_argument(\n        \"--n-dgi-epochs\",\n        type=int,\n        default=300,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"--n-classifier-epochs\",\n        type=int,\n        default=300,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=512, help=\"number of hidden gcn units\"\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=1, help=\"number of hidden gcn layers\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=0.0, help=\"Weight for L2 loss\"\n    )\n    parser.add_argument(\n        \"--patience\", type=int, default=20, help=\"early 
stop patience condition\"\n    )\n    parser.add_argument(\n        \"--self-loop\",\n        action=\"store_true\",\n        help=\"graph self-loop (default=False)\",\n    )\n    parser.set_defaults(self_loop=False)\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/dgmg/README.md",
    "content": "# Learning Deep Generative Models of Graphs\n\nThis is an implementation of [Learning Deep Generative Models of Graphs](https://arxiv.org/pdf/1803.03324.pdf) by \nYujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, Peter Battaglia.\n\nFor molecule generation, see \n[DGL-LifeSci](https://github.com/awslabs/dgl-lifesci/tree/master/examples/generative_models/dgmg).\n\n## Dependencies\n- Python 3.5.2\n- [Pytorch 0.4.1](https://pytorch.org/)\n- [Matplotlib 2.2.2](https://matplotlib.org/)\n\n## Usage\n\n`python3 main.py`\n\n## Performance\n\n90% accuracy for cycles compared with 84% accuracy reported in the original paper.\n\n## Speed\n\nOn AWS p3.2x instance (w/ V100), one epoch takes ~526s.\n\n## Acknowledgement\n\nWe would like to thank Yujia Li for providing details on the implementation.\n"
  },
  {
    "path": "examples/pytorch/dgmg/configure.py",
    "content": "\"\"\"We intend to make our reproduction as close as possible to the original paper.\nThe configuration in the file is mostly from the description in the original paper\nand will be loaded when setting up.\"\"\"\n\n\ndef dataset_based_configure(opts):\n    if opts[\"dataset\"] == \"cycles\":\n        ds_configure = cycles_configure\n    else:\n        raise ValueError(\"Unsupported dataset: {}\".format(opts[\"dataset\"]))\n\n    opts = {**opts, **ds_configure}\n\n    return opts\n\n\nsynthetic_dataset_configure = {\n    \"node_hidden_size\": 16,\n    \"num_propagation_rounds\": 2,\n    \"optimizer\": \"Adam\",\n    \"nepochs\": 25,\n    \"ds_size\": 4000,\n    \"num_generated_samples\": 10000,\n}\n\ncycles_configure = {\n    **synthetic_dataset_configure,\n    **{\n        \"min_size\": 10,\n        \"max_size\": 20,\n        \"lr\": 5e-4,\n    },\n}\n"
  },
  {
    "path": "examples/pytorch/dgmg/cycles.py",
    "content": "import os\nimport pickle\nimport random\n\nimport matplotlib.pyplot as plt\nimport networkx as nx\nfrom torch.utils.data import Dataset\n\n\ndef get_previous(i, v_max):\n    if i == 0:\n        return v_max\n    else:\n        return i - 1\n\n\ndef get_next(i, v_max):\n    if i == v_max:\n        return 0\n    else:\n        return i + 1\n\n\ndef is_cycle(g):\n    size = g.num_nodes()\n\n    if size < 3:\n        return False\n\n    for node in range(size):\n        neighbors = g.successors(node)\n\n        if len(neighbors) != 2:\n            return False\n\n        if get_previous(node, size - 1) not in neighbors:\n            return False\n\n        if get_next(node, size - 1) not in neighbors:\n            return False\n\n    return True\n\n\ndef get_decision_sequence(size):\n    \"\"\"\n    Get the decision sequence for generating valid cycles with DGMG for teacher\n    forcing optimization.\n    \"\"\"\n    decision_sequence = []\n\n    for i in range(size):\n        decision_sequence.append(0)  # Add node\n\n        if i != 0:\n            decision_sequence.append(0)  # Add edge\n            decision_sequence.append(\n                i - 1\n            )  # Set destination to be previous node.\n\n        if i == size - 1:\n            decision_sequence.append(0)  # Add edge\n            decision_sequence.append(0)  # Set destination to be the root.\n\n        decision_sequence.append(1)  # Stop adding edge\n\n    decision_sequence.append(1)  # Stop adding node\n\n    return decision_sequence\n\n\ndef generate_dataset(v_min, v_max, n_samples, fname):\n    samples = []\n    for _ in range(n_samples):\n        size = random.randint(v_min, v_max)\n        samples.append(get_decision_sequence(size))\n\n    with open(fname, \"wb\") as f:\n        pickle.dump(samples, f)\n\n\nclass CycleDataset(Dataset):\n    def __init__(self, fname):\n        super(CycleDataset, self).__init__()\n\n        with open(fname, \"rb\") as f:\n            self.dataset = 
pickle.load(f)\n\n    def __len__(self):\n        return len(self.dataset)\n\n    def __getitem__(self, index):\n        return self.dataset[index]\n\n    def collate_single(self, batch):\n        assert len(batch) == 1, \"Currently we do not support batched training\"\n        return batch[0]\n\n    def collate_batch(self, batch):\n        return batch\n\n\ndef dglGraph_to_adj_list(g):\n    adj_list = {}\n    for node in range(g.num_nodes()):\n        # For undirected graph. successors and\n        # predecessors are equivalent.\n        adj_list[node] = g.successors(node).tolist()\n    return adj_list\n\n\nclass CycleModelEvaluation(object):\n    def __init__(self, v_min, v_max, dir):\n        super(CycleModelEvaluation, self).__init__()\n\n        self.v_min = v_min\n        self.v_max = v_max\n\n        self.dir = dir\n\n    def rollout_and_examine(self, model, num_samples):\n        assert not model.training, \"You need to call model.eval().\"\n\n        num_total_size = 0\n        num_valid_size = 0\n        num_cycle = 0\n        num_valid = 0\n        plot_times = 0\n        adj_lists_to_plot = []\n\n        for i in range(num_samples):\n            sampled_graph = model()\n            if isinstance(sampled_graph, list):\n                # When the model is a batched implementation, a list of\n                # DGLGraph objects is returned. Note that with model(),\n                # we generate a single graph as with the non-batched\n                # implementation. 
We actually support batched generation\n                # during the inference so feel free to modify the code.\n                sampled_graph = sampled_graph[0]\n\n            sampled_adj_list = dglGraph_to_adj_list(sampled_graph)\n            adj_lists_to_plot.append(sampled_adj_list)\n\n            graph_size = sampled_graph.num_nodes()\n            valid_size = self.v_min <= graph_size <= self.v_max\n            cycle = is_cycle(sampled_graph)\n\n            num_total_size += graph_size\n\n            if valid_size:\n                num_valid_size += 1\n\n            if cycle:\n                num_cycle += 1\n\n            if valid_size and cycle:\n                num_valid += 1\n\n            if len(adj_lists_to_plot) >= 4:\n                plot_times += 1\n                fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)\n                axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3}\n                for i in range(4):\n                    nx.draw_circular(\n                        nx.from_dict_of_lists(adj_lists_to_plot[i]),\n                        with_labels=True,\n                        ax=axes[i],\n                    )\n\n                plt.savefig(self.dir + \"/samples/{:d}\".format(plot_times))\n                plt.close()\n\n                adj_lists_to_plot = []\n\n        self.num_samples_examined = num_samples\n        self.average_size = num_total_size / num_samples\n        self.valid_size_ratio = num_valid_size / num_samples\n        self.cycle_ratio = num_cycle / num_samples\n        self.valid_ratio = num_valid / num_samples\n\n    def write_summary(self):\n        def _format_value(v):\n            if isinstance(v, float):\n                return \"{:.4f}\".format(v)\n            elif isinstance(v, int):\n                return \"{:d}\".format(v)\n            else:\n                return \"{}\".format(v)\n\n        statistics = {\n            \"num_samples\": self.num_samples_examined,\n            \"v_min\": self.v_min,\n            
\"v_max\": self.v_max,\n            \"average_size\": self.average_size,\n            \"valid_size_ratio\": self.valid_size_ratio,\n            \"cycle_ratio\": self.cycle_ratio,\n            \"valid_ratio\": self.valid_ratio,\n        }\n\n        model_eval_path = os.path.join(self.dir, \"model_eval.txt\")\n\n        with open(model_eval_path, \"w\") as f:\n            for key, value in statistics.items():\n                msg = \"{}\\t{}\\n\".format(key, _format_value(value))\n                f.write(msg)\n\n        print(\"Saved model evaluation statistics to {}\".format(model_eval_path))\n\n\nclass CyclePrinting(object):\n    def __init__(self, num_epochs, num_batches):\n        super(CyclePrinting, self).__init__()\n\n        self.num_epochs = num_epochs\n        self.num_batches = num_batches\n        self.batch_count = 0\n\n    def update(self, epoch, metrics):\n        self.batch_count = (self.batch_count) % self.num_batches + 1\n\n        msg = \"epoch {:d}/{:d}, batch {:d}/{:d}\".format(\n            epoch, self.num_epochs, self.batch_count, self.num_batches\n        )\n        for key, value in metrics.items():\n            msg += \", {}: {:4f}\".format(key, value)\n        print(msg)\n"
  },
  {
    "path": "examples/pytorch/dgmg/main.py",
    "content": "\"\"\"\nLearning Deep Generative Models of Graphs\nPaper: https://arxiv.org/pdf/1803.03324.pdf\n\nThis implementation works with a minibatch of size 1 only for both training and inference.\n\"\"\"\nimport argparse\nimport datetime\nimport time\n\nimport torch\nfrom model import DGMG\nfrom torch.nn.utils import clip_grad_norm_\nfrom torch.optim import Adam\nfrom torch.utils.data import DataLoader\n\n\ndef main(opts):\n    t1 = time.time()\n\n    # Setup dataset and data loader\n    if opts[\"dataset\"] == \"cycles\":\n        from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting\n\n        dataset = CycleDataset(fname=opts[\"path_to_dataset\"])\n        evaluator = CycleModelEvaluation(\n            v_min=opts[\"min_size\"], v_max=opts[\"max_size\"], dir=opts[\"log_dir\"]\n        )\n        printer = CyclePrinting(\n            num_epochs=opts[\"nepochs\"],\n            num_batches=opts[\"ds_size\"] // opts[\"batch_size\"],\n        )\n    else:\n        raise ValueError(\"Unsupported dataset: {}\".format(opts[\"dataset\"]))\n\n    data_loader = DataLoader(\n        dataset,\n        batch_size=1,\n        shuffle=True,\n        num_workers=0,\n        collate_fn=dataset.collate_single,\n    )\n\n    # Initialize_model\n    model = DGMG(\n        v_max=opts[\"max_size\"],\n        node_hidden_size=opts[\"node_hidden_size\"],\n        num_prop_rounds=opts[\"num_propagation_rounds\"],\n    )\n\n    # Initialize optimizer\n    if opts[\"optimizer\"] == \"Adam\":\n        optimizer = Adam(model.parameters(), lr=opts[\"lr\"])\n    else:\n        raise ValueError(\"Unsupported argument for the optimizer\")\n\n    t2 = time.time()\n\n    # Training\n    model.train()\n    for epoch in range(opts[\"nepochs\"]):\n        batch_count = 0\n        batch_loss = 0\n        batch_prob = 0\n        optimizer.zero_grad()\n\n        for i, data in enumerate(data_loader):\n            log_prob = model(actions=data)\n            prob = 
log_prob.detach().exp()\n\n            loss = -log_prob / opts[\"batch_size\"]\n            prob_averaged = prob / opts[\"batch_size\"]\n\n            loss.backward()\n\n            batch_loss += loss.item()\n            batch_prob += prob_averaged.item()\n            batch_count += 1\n\n            if batch_count % opts[\"batch_size\"] == 0:\n                printer.update(\n                    epoch + 1,\n                    {\"averaged_loss\": batch_loss, \"averaged_prob\": batch_prob},\n                )\n\n                if opts[\"clip_grad\"]:\n                    clip_grad_norm_(model.parameters(), opts[\"clip_bound\"])\n\n                optimizer.step()\n\n                batch_loss = 0\n                batch_prob = 0\n                optimizer.zero_grad()\n\n    t3 = time.time()\n\n    model.eval()\n    evaluator.rollout_and_examine(model, opts[\"num_generated_samples\"])\n    evaluator.write_summary()\n\n    t4 = time.time()\n\n    print(\"It took {} to setup.\".format(datetime.timedelta(seconds=t2 - t1)))\n    print(\n        \"It took {} to finish training.\".format(\n            datetime.timedelta(seconds=t3 - t2)\n        )\n    )\n    print(\n        \"It took {} to finish evaluation.\".format(\n            datetime.timedelta(seconds=t4 - t3)\n        )\n    )\n    print(\n        \"--------------------------------------------------------------------------\"\n    )\n    print(\n        \"On average, an epoch takes {}.\".format(\n            datetime.timedelta(seconds=(t3 - t2) / opts[\"nepochs\"])\n        )\n    )\n\n    del model.g\n    torch.save(model, \"./model.pth\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"DGMG\")\n\n    # configure\n    parser.add_argument(\"--seed\", type=int, default=9284, help=\"random seed\")\n\n    # dataset\n    parser.add_argument(\n        \"--dataset\", choices=[\"cycles\"], default=\"cycles\", help=\"dataset to use\"\n    )\n    parser.add_argument(\n        
\"--path-to-dataset\",\n        type=str,\n        default=\"cycles.p\",\n        help=\"load the dataset if it exists, \"\n        \"generate it and save to the path otherwise\",\n    )\n\n    # log\n    parser.add_argument(\n        \"--log-dir\",\n        default=\"./results\",\n        help=\"folder to save info like experiment configuration \"\n        \"or model evaluation results\",\n    )\n\n    # optimization\n    parser.add_argument(\n        \"--batch-size\",\n        type=int,\n        default=10,\n        help=\"batch size to use for training\",\n    )\n    parser.add_argument(\n        \"--clip-grad\",\n        action=\"store_true\",\n        default=True,\n        help=\"gradient clipping is required to prevent gradient explosion\",\n    )\n    parser.add_argument(\n        \"--clip-bound\",\n        type=float,\n        default=0.25,\n        help=\"constraint of gradient norm for gradient clipping\",\n    )\n\n    args = parser.parse_args()\n    from utils import setup\n\n    opts = setup(args)\n\n    main(opts)\n"
  },
  {
    "path": "examples/pytorch/dgmg/model.py",
    "content": "from functools import partial\n\nimport dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.distributions import Bernoulli, Categorical\n\n\nclass GraphEmbed(nn.Module):\n    def __init__(self, node_hidden_size):\n        super(GraphEmbed, self).__init__()\n\n        # Setting from the paper\n        self.graph_hidden_size = 2 * node_hidden_size\n\n        # Embed graphs\n        self.node_gating = nn.Sequential(\n            nn.Linear(node_hidden_size, 1), nn.Sigmoid()\n        )\n        self.node_to_graph = nn.Linear(node_hidden_size, self.graph_hidden_size)\n\n    def forward(self, g):\n        if g.num_nodes() == 0:\n            return torch.zeros(1, self.graph_hidden_size)\n        else:\n            # Node features are stored as hv in ndata.\n            hvs = g.ndata[\"hv\"]\n            return (self.node_gating(hvs) * self.node_to_graph(hvs)).sum(\n                0, keepdim=True\n            )\n\n\nclass GraphProp(nn.Module):\n    def __init__(self, num_prop_rounds, node_hidden_size):\n        super(GraphProp, self).__init__()\n\n        self.num_prop_rounds = num_prop_rounds\n\n        # Setting from the paper\n        self.node_activation_hidden_size = 2 * node_hidden_size\n\n        message_funcs = []\n        self.reduce_funcs = []\n        node_update_funcs = []\n\n        for t in range(num_prop_rounds):\n            # input being [hv, hu, xuv]\n            message_funcs.append(\n                nn.Linear(\n                    2 * node_hidden_size + 1, self.node_activation_hidden_size\n                )\n            )\n\n            self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))\n            node_update_funcs.append(\n                nn.GRUCell(self.node_activation_hidden_size, node_hidden_size)\n            )\n\n        self.message_funcs = nn.ModuleList(message_funcs)\n        self.node_update_funcs = nn.ModuleList(node_update_funcs)\n\n    def dgmg_msg(self, edges):\n        
\"\"\"For an edge u->v, return concat([h_u, x_uv])\"\"\"\n        return {\"m\": torch.cat([edges.src[\"hv\"], edges.data[\"he\"]], dim=1)}\n\n    def dgmg_reduce(self, nodes, round):\n        hv_old = nodes.data[\"hv\"]\n        m = nodes.mailbox[\"m\"]\n        message = torch.cat(\n            [hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2\n        )\n        node_activation = (self.message_funcs[round](message)).sum(1)\n\n        return {\"a\": node_activation}\n\n    def forward(self, g):\n        if g.num_edges() == 0:\n            return\n        else:\n            for t in range(self.num_prop_rounds):\n                g.update_all(\n                    message_func=self.dgmg_msg, reduce_func=self.reduce_funcs[t]\n                )\n                g.ndata[\"hv\"] = self.node_update_funcs[t](\n                    g.ndata[\"a\"], g.ndata[\"hv\"]\n                )\n\n\ndef bernoulli_action_log_prob(logit, action):\n    \"\"\"Calculate the log p of an action with respect to a Bernoulli\n    distribution. 
Use logit rather than prob for numerical stability.\"\"\"\n    if action == 0:\n        return F.logsigmoid(-logit)\n    else:\n        return F.logsigmoid(logit)\n\n\nclass AddNode(nn.Module):\n    def __init__(self, graph_embed_func, node_hidden_size):\n        super(AddNode, self).__init__()\n\n        self.graph_op = {\"embed\": graph_embed_func}\n\n        self.stop = 1\n        self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)\n\n        # If to add a node, initialize its hv\n        self.node_type_embed = nn.Embedding(1, node_hidden_size)\n        self.initialize_hv = nn.Linear(\n            node_hidden_size + graph_embed_func.graph_hidden_size,\n            node_hidden_size,\n        )\n\n        self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)\n\n    def _initialize_node_repr(self, g, node_type, graph_embed):\n        num_nodes = g.num_nodes()\n        hv_init = self.initialize_hv(\n            torch.cat(\n                [\n                    self.node_type_embed(torch.LongTensor([node_type])),\n                    graph_embed,\n                ],\n                dim=1,\n            )\n        )\n        g.nodes[num_nodes - 1].data[\"hv\"] = hv_init\n        g.nodes[num_nodes - 1].data[\"a\"] = self.init_node_activation\n\n    def prepare_training(self):\n        self.log_prob = []\n\n    def forward(self, g, action=None):\n        graph_embed = self.graph_op[\"embed\"](g)\n\n        logit = self.add_node(graph_embed)\n        prob = torch.sigmoid(logit)\n\n        if not self.training:\n            action = Bernoulli(prob).sample().item()\n        stop = bool(action == self.stop)\n\n        if not stop:\n            g.add_nodes(1)\n            self._initialize_node_repr(g, action, graph_embed)\n\n        if self.training:\n            sample_log_prob = bernoulli_action_log_prob(logit, action)\n            self.log_prob.append(sample_log_prob)\n\n        return stop\n\n\nclass AddEdge(nn.Module):\n    def __init__(self, 
graph_embed_func, node_hidden_size):\n        super(AddEdge, self).__init__()\n\n        self.graph_op = {\"embed\": graph_embed_func}\n        self.add_edge = nn.Linear(\n            graph_embed_func.graph_hidden_size + node_hidden_size, 1\n        )\n\n    def prepare_training(self):\n        self.log_prob = []\n\n    def forward(self, g, action=None):\n        graph_embed = self.graph_op[\"embed\"](g)\n        src_embed = g.nodes[g.num_nodes() - 1].data[\"hv\"]\n\n        logit = self.add_edge(torch.cat([graph_embed, src_embed], dim=1))\n        prob = torch.sigmoid(logit)\n\n        if not self.training:\n            action = Bernoulli(prob).sample().item()\n        to_add_edge = bool(action == 0)\n\n        if self.training:\n            sample_log_prob = bernoulli_action_log_prob(logit, action)\n            self.log_prob.append(sample_log_prob)\n\n        return to_add_edge\n\n\nclass ChooseDestAndUpdate(nn.Module):\n    def __init__(self, graph_prop_func, node_hidden_size):\n        super(ChooseDestAndUpdate, self).__init__()\n\n        self.graph_op = {\"prop\": graph_prop_func}\n        self.choose_dest = nn.Linear(2 * node_hidden_size, 1)\n\n    def _initialize_edge_repr(self, g, src_list, dest_list):\n        # For untyped edges, we only add 1 to indicate its existence.\n        # For multiple edge types, we can use a one hot representation\n        # or an embedding module.\n        edge_repr = torch.ones(len(src_list), 1)\n        g.edges[src_list, dest_list].data[\"he\"] = edge_repr\n\n    def prepare_training(self):\n        self.log_prob = []\n\n    def forward(self, g, dest):\n        src = g.num_nodes() - 1\n        possible_dests = range(src)\n\n        src_embed_expand = g.nodes[src].data[\"hv\"].expand(src, -1)\n        possible_dests_embed = g.nodes[possible_dests].data[\"hv\"]\n\n        dests_scores = self.choose_dest(\n            torch.cat([possible_dests_embed, src_embed_expand], dim=1)\n        ).view(1, -1)\n        dests_probs = 
F.softmax(dests_scores, dim=1)\n\n        if not self.training:\n            dest = Categorical(dests_probs).sample().item()\n\n        if not g.has_edges_between(src, dest):\n            # For undirected graphs, we add edges for both directions\n            # so that we can perform graph propagation.\n            src_list = [src, dest]\n            dest_list = [dest, src]\n\n            g.add_edges(src_list, dest_list)\n            self._initialize_edge_repr(g, src_list, dest_list)\n\n            self.graph_op[\"prop\"](g)\n\n        if self.training:\n            if dests_probs.nelement() > 1:\n                self.log_prob.append(\n                    F.log_softmax(dests_scores, dim=1)[:, dest : dest + 1]\n                )\n\n\nclass DGMG(nn.Module):\n    def __init__(self, v_max, node_hidden_size, num_prop_rounds):\n        super(DGMG, self).__init__()\n\n        # Graph configuration\n        self.v_max = v_max\n\n        # Graph embedding module\n        self.graph_embed = GraphEmbed(node_hidden_size)\n\n        # Graph propagation module\n        self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size)\n\n        # Actions\n        self.add_node_agent = AddNode(self.graph_embed, node_hidden_size)\n        self.add_edge_agent = AddEdge(self.graph_embed, node_hidden_size)\n        self.choose_dest_agent = ChooseDestAndUpdate(\n            self.graph_prop, node_hidden_size\n        )\n\n        # Weight initialization\n        self.init_weights()\n\n    def init_weights(self):\n        from utils import dgmg_message_weight_init, weights_init\n\n        self.graph_embed.apply(weights_init)\n        self.graph_prop.apply(weights_init)\n        self.add_node_agent.apply(weights_init)\n        self.add_edge_agent.apply(weights_init)\n        self.choose_dest_agent.apply(weights_init)\n\n        self.graph_prop.message_funcs.apply(dgmg_message_weight_init)\n\n    @property\n    def action_step(self):\n        old_step_count = self.step_count\n        
self.step_count += 1\n\n        return old_step_count\n\n    def prepare_for_train(self):\n        self.step_count = 0\n\n        self.add_node_agent.prepare_training()\n        self.add_edge_agent.prepare_training()\n        self.choose_dest_agent.prepare_training()\n\n    def add_node_and_update(self, a=None):\n        \"\"\"Decide if to add a new node.\n        If a new node should be added, update the graph.\"\"\"\n\n        return self.add_node_agent(self.g, a)\n\n    def add_edge_or_not(self, a=None):\n        \"\"\"Decide if a new edge should be added.\"\"\"\n\n        return self.add_edge_agent(self.g, a)\n\n    def choose_dest_and_update(self, a=None):\n        \"\"\"Choose destination and connect it to the latest node.\n        Add edges for both directions and update the graph.\"\"\"\n\n        self.choose_dest_agent(self.g, a)\n\n    def get_log_prob(self):\n        return (\n            torch.cat(self.add_node_agent.log_prob).sum()\n            + torch.cat(self.add_edge_agent.log_prob).sum()\n            + torch.cat(self.choose_dest_agent.log_prob).sum()\n        )\n\n    def forward_train(self, actions):\n        self.prepare_for_train()\n\n        stop = self.add_node_and_update(a=actions[self.action_step])\n\n        while not stop:\n            to_add_edge = self.add_edge_or_not(a=actions[self.action_step])\n            while to_add_edge:\n                self.choose_dest_and_update(a=actions[self.action_step])\n                to_add_edge = self.add_edge_or_not(a=actions[self.action_step])\n            stop = self.add_node_and_update(a=actions[self.action_step])\n\n        return self.get_log_prob()\n\n    def forward_inference(self):\n        stop = self.add_node_and_update()\n        while (not stop) and (self.g.num_nodes() < self.v_max + 1):\n            num_trials = 0\n            to_add_edge = self.add_edge_or_not()\n            while to_add_edge and (num_trials < self.g.num_nodes() - 1):\n                self.choose_dest_and_update()\n       
         num_trials += 1\n                to_add_edge = self.add_edge_or_not()\n            stop = self.add_node_and_update()\n\n        return self.g\n\n    def forward(self, actions=None):\n        # The graph we will work on\n        self.g = dgl.DGLGraph()\n\n        # If there are some features for nodes and edges,\n        # zero tensors will be set for those of new nodes and edges.\n        self.g.set_n_initializer(dgl.frame.zero_initializer)\n        self.g.set_e_initializer(dgl.frame.zero_initializer)\n\n        if self.training:\n            return self.forward_train(actions)\n        else:\n            return self.forward_inference()\n"
  },
  {
    "path": "examples/pytorch/dgmg/utils.py",
    "content": "import datetime\nimport os\nimport random\nfrom pprint import pprint\n\nimport matplotlib.pyplot as plt\nimport torch\nimport torch.backends.cudnn as cudnn\nimport torch.nn as nn\nimport torch.nn.init as init\n\n########################################################################################################################\n#                                                    configuration                                                     #\n########################################################################################################################\n\n\ndef mkdir_p(path):\n    import errno\n\n    try:\n        os.makedirs(path)\n        print(\"Created directory {}\".format(path))\n    except OSError as exc:\n        if exc.errno == errno.EEXIST and os.path.isdir(path):\n            print(\"Directory {} already exists.\".format(path))\n        else:\n            raise\n\n\ndef date_filename(base_dir=\"./\"):\n    dt = datetime.datetime.now()\n    return os.path.join(\n        base_dir,\n        \"{}_{:02d}-{:02d}-{:02d}\".format(\n            dt.date(), dt.hour, dt.minute, dt.second\n        ),\n    )\n\n\ndef setup_log_dir(opts):\n    log_dir = \"{}\".format(date_filename(opts[\"log_dir\"]))\n    mkdir_p(log_dir)\n    return log_dir\n\n\ndef save_arg_dict(opts, filename=\"settings.txt\"):\n    def _format_value(v):\n        if isinstance(v, float):\n            return \"{:.4f}\".format(v)\n        elif isinstance(v, int):\n            return \"{:d}\".format(v)\n        else:\n            return \"{}\".format(v)\n\n    save_path = os.path.join(opts[\"log_dir\"], filename)\n    with open(save_path, \"w\") as f:\n        for key, value in opts.items():\n            f.write(\"{}\\t{}\\n\".format(key, _format_value(value)))\n    print(\"Saved settings to {}\".format(save_path))\n\n\ndef setup(args):\n    opts = args.__dict__.copy()\n\n    cudnn.benchmark = False\n    cudnn.deterministic = True\n\n    # Seed\n    if opts[\"seed\"] 
is None:\n        opts[\"seed\"] = random.randint(1, 10000)\n    random.seed(opts[\"seed\"])\n    torch.manual_seed(opts[\"seed\"])\n\n    # Dataset\n    from configure import dataset_based_configure\n\n    opts = dataset_based_configure(opts)\n\n    assert (\n        opts[\"path_to_dataset\"] is not None\n    ), \"Expect path to dataset to be set.\"\n    if not os.path.exists(opts[\"path_to_dataset\"]):\n        if opts[\"dataset\"] == \"cycles\":\n            from cycles import generate_dataset\n\n            generate_dataset(\n                opts[\"min_size\"],\n                opts[\"max_size\"],\n                opts[\"ds_size\"],\n                opts[\"path_to_dataset\"],\n            )\n        else:\n            raise ValueError(\"Unsupported dataset: {}\".format(opts[\"dataset\"]))\n\n    # Optimization\n    if opts[\"clip_grad\"]:\n        assert (\n            opts[\"clip_grad\"] is not None\n        ), \"Expect the gradient norm constraint to be set.\"\n\n    # Log\n    print(\"Prepare logging directory...\")\n    log_dir = setup_log_dir(opts)\n    opts[\"log_dir\"] = log_dir\n    mkdir_p(log_dir + \"/samples\")\n\n    plt.switch_backend(\"Agg\")\n\n    save_arg_dict(opts)\n    pprint(opts)\n\n    return opts\n\n\n########################################################################################################################\n#                                                         model                                                        #\n########################################################################################################################\n\n\ndef weights_init(m):\n    \"\"\"\n    Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5\n    Usage:\n        model = Model()\n        model.apply(weight_init)\n    \"\"\"\n    if isinstance(m, nn.Linear):\n        init.xavier_normal_(m.weight.data)\n        init.normal_(m.bias.data)\n    elif isinstance(m, nn.GRUCell):\n        for param in 
m.parameters():\n            if len(param.shape) >= 2:\n                init.orthogonal_(param.data)\n            else:\n                init.normal_(param.data)\n\n\ndef dgmg_message_weight_init(m):\n    \"\"\"\n    This is similar as the function above where we initialize linear layers from a normal distribution with std\n    1./10 as suggested by the author. This should only be used for the message passing functions, i.e. fe's in the\n    paper.\n    \"\"\"\n\n    def _weight_init(m):\n        if isinstance(m, nn.Linear):\n            init.normal_(m.weight.data, std=1.0 / 10)\n            init.normal_(m.bias.data, std=1.0 / 10)\n        else:\n            raise ValueError(\"Expected the input to be of type nn.Linear!\")\n\n    if isinstance(m, nn.ModuleList):\n        for layer in m:\n            layer.apply(_weight_init)\n    else:\n        m.apply(_weight_init)\n"
  },
  {
    "path": "examples/pytorch/diffpool/README.md",
    "content": "Hierarchical Graph Representation Learning with Differentiable Pooling\n============\n\n\nPaper link: [https://arxiv.org/abs/1806.08804](https://arxiv.org/abs/1806.08804)\n\nAuthor's code repo: [https://github.com/RexYing/diffpool](https://github.com/RexYing/diffpool)\n\nThis folder contains a DGL implementation of the DiffPool model. The first pooling layer is computed with DGL, and following pooling layers are computed with tensorized operation since the pooled graphs are dense.\n\nDependencies\n------------\n* PyTorch 1.0+\n\nHow to run\n----------\n\n```bash\npython train.py --dataset ENZYMES --pool_ratio 0.10 --num_pool 1 --epochs 1000\npython train.py --dataset DD --pool_ratio 0.15 --num_pool 1  --batch-size 10\n```\nPerformance\n-----------\nENZYMES 63.33% (with early stopping)\nDD 79.31% (with early stopping)\n\n\n## Update (2021-03-09)\n\n**Changes:**\n\n* Fix bug in Diffpool: the wrong `assign_dim` parameter\n* Improve efficiency of DiffPool, make the model independent of batch size. 
Remove redundant computation.\n\n\n**Efficiency:**\n\nOn V100-SXM2 16GB\n\n|                    | Train time/epoch (original) (s) | Train time/epoch (improved) (s) |\n| ------------------ | ------------------------------: | ------------------------------: |\n| DD (batch_size=10) |                          21.302 |                      **17.282** |\n| DD (batch_size=20) |                             OOM |                      **44.682** |\n| ENZYMES            |                           1.749 |                       **1.685** |\n\n|                    | Memory usage (original) (MB) | Memory usage (improved) (MB) |\n| ------------------ | ---------------------------: | ---------------------------: |\n| DD (batch_size=10) |                     5274.620 |                 **2928.568** |\n| DD (batch_size=20) |                          OOM |                **10088.889** |\n| ENZYMES            |                       25.685 |                   **21.909** |\n\n**Accuracy**\n\nEach experiment with the improved model was conducted only once, so the results may be noisy.\n\n|         |   Original |   Improved |\n| ------- | ---------: | ---------: |\n| DD      | **79.31%** |     78.33% |\n| ENZYMES |     63.33% | **68.33%** |\n"
  },
  {
    "path": "examples/pytorch/diffpool/data_utils.py",
    "content": "import numpy as np\nimport torch\n\n\ndef one_hotify(labels, pad=-1):\n    \"\"\"\n    cast label to one hot vector\n    \"\"\"\n    num_instances = len(labels)\n    if pad <= 0:\n        dim_embedding = np.max(labels) + 1  # zero-indexed assumed\n    else:\n        assert pad > 0, \"result_dim for padding one hot embedding not set!\"\n        dim_embedding = pad + 1\n    embeddings = np.zeros((num_instances, dim_embedding))\n    embeddings[np.arange(num_instances), labels] = 1\n\n    return embeddings\n\n\ndef pre_process(dataset, prog_args):\n    \"\"\"\n    diffpool specific data partition, pre-process and shuffling\n    \"\"\"\n    if prog_args.data_mode != \"default\":\n        print(\"overwrite node attributes with DiffPool's preprocess setting\")\n        if prog_args.data_mode == \"id\":\n            for g, _ in dataset:\n                id_list = np.arange(g.num_nodes())\n                g.ndata[\"feat\"] = one_hotify(id_list, pad=dataset.max_num_node)\n\n        elif prog_args.data_mode == \"deg-num\":\n            for g, _ in dataset:\n                g.ndata[\"feat\"] = np.expand_dims(g.in_degrees(), axis=1)\n\n        elif prog_args.data_mode == \"deg\":\n            for g in dataset:\n                degs = list(g.in_degrees())\n                degs_one_hot = one_hotify(degs, pad=dataset.max_degrees)\n                g.ndata[\"feat\"] = degs_one_hot\n"
  },
  {
    "path": "examples/pytorch/diffpool/model/__init__.py",
    "content": ""
  },
  {
    "path": "examples/pytorch/diffpool/model/dgl_layers/__init__.py",
    "content": "from .gnn import DiffPoolBatchedGraphLayer, GraphSage, GraphSageLayer\n"
  },
  {
    "path": "examples/pytorch/diffpool/model/dgl_layers/aggregator.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Aggregator(nn.Module):\n    \"\"\"\n    Base Aggregator class. Adapting\n    from PR# 403\n\n    This class is not supposed to be called\n    \"\"\"\n\n    def __init__(self):\n        super(Aggregator, self).__init__()\n\n    def forward(self, node):\n        neighbour = node.mailbox[\"m\"]\n        c = self.aggre(neighbour)\n        return {\"c\": c}\n\n    def aggre(self, neighbour):\n        # N x F\n        raise NotImplementedError\n\n\nclass MeanAggregator(Aggregator):\n    \"\"\"\n    Mean Aggregator for graphsage\n    \"\"\"\n\n    def __init__(self):\n        super(MeanAggregator, self).__init__()\n\n    def aggre(self, neighbour):\n        mean_neighbour = torch.mean(neighbour, dim=1)\n        return mean_neighbour\n\n\nclass MaxPoolAggregator(Aggregator):\n    \"\"\"\n    Maxpooling aggregator for graphsage\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, activation, bias):\n        super(MaxPoolAggregator, self).__init__()\n        self.linear = nn.Linear(in_feats, out_feats, bias=bias)\n        self.activation = activation\n        # Xavier initialization of weight\n        nn.init.xavier_uniform_(\n            self.linear.weight, gain=nn.init.calculate_gain(\"relu\")\n        )\n\n    def aggre(self, neighbour):\n        neighbour = self.linear(neighbour)\n        if self.activation:\n            neighbour = self.activation(neighbour)\n        maxpool_neighbour = torch.max(neighbour, dim=1)[0]\n        return maxpool_neighbour\n\n\nclass LSTMAggregator(Aggregator):\n    \"\"\"\n    LSTM aggregator for graphsage\n    \"\"\"\n\n    def __init__(self, in_feats, hidden_feats):\n        super(LSTMAggregator, self).__init__()\n        self.lstm = nn.LSTM(in_feats, hidden_feats, batch_first=True)\n        self.hidden_dim = hidden_feats\n        self.hidden = self.init_hidden()\n\n        nn.init.xavier_uniform_(\n            self.lstm.weight, 
gain=nn.init.calculate_gain(\"relu\")\n        )\n\n    def init_hidden(self):\n        \"\"\"\n        Defaulted to initialite all zero\n        \"\"\"\n        return (\n            torch.zeros(1, 1, self.hidden_dim),\n            torch.zeros(1, 1, self.hidden_dim),\n        )\n\n    def aggre(self, neighbours):\n        \"\"\"\n        aggregation function\n        \"\"\"\n        # N X F\n        rand_order = torch.randperm(neighbours.size()[1])\n        neighbours = neighbours[:, rand_order, :]\n\n        (lstm_out, self.hidden) = self.lstm(\n            neighbours.view(neighbours.size()[0], neighbours.size()[1], -1)\n        )\n        return lstm_out[:, -1, :]\n\n    def forward(self, node):\n        neighbour = node.mailbox[\"m\"]\n        c = self.aggre(neighbour)\n        return {\"c\": c}\n"
  },
  {
    "path": "examples/pytorch/diffpool/model/dgl_layers/bundler.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Bundler(nn.Module):\n    \"\"\"\n    Bundler, which will be the node_apply function in DGL paradigm\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, activation, dropout, bias=True):\n        super(Bundler, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n        self.linear = nn.Linear(in_feats * 2, out_feats, bias)\n        self.activation = activation\n\n        nn.init.xavier_uniform_(\n            self.linear.weight, gain=nn.init.calculate_gain(\"relu\")\n        )\n\n    def concat(self, h, aggre_result):\n        bundle = torch.cat((h, aggre_result), 1)\n        bundle = self.linear(bundle)\n        return bundle\n\n    def forward(self, node):\n        h = node.data[\"h\"]\n        c = node.data[\"c\"]\n        bundle = self.concat(h, c)\n        bundle = F.normalize(bundle, p=2, dim=1)\n        if self.activation:\n            bundle = self.activation(bundle)\n        return {\"h\": bundle}\n"
  },
  {
    "path": "examples/pytorch/diffpool/model/dgl_layers/gnn.py",
    "content": "import dgl.function as fn\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom scipy.linalg import block_diag\n\nfrom model.loss import EntropyLoss\nfrom ..model_utils import masked_softmax\n\nfrom .aggregator import LSTMAggregator, MaxPoolAggregator, MeanAggregator\nfrom .bundler import Bundler\n\n\nclass GraphSageLayer(nn.Module):\n    \"\"\"\n    GraphSage layer in Inductive learning paper by hamilton\n    Here, graphsage layer is a reduced function in DGL framework\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        activation,\n        dropout,\n        aggregator_type,\n        bn=False,\n        bias=True,\n    ):\n        super(GraphSageLayer, self).__init__()\n        self.use_bn = bn\n        self.bundler = Bundler(\n            in_feats, out_feats, activation, dropout, bias=bias\n        )\n        self.dropout = nn.Dropout(p=dropout)\n\n        if aggregator_type == \"maxpool\":\n            self.aggregator = MaxPoolAggregator(\n                in_feats, in_feats, activation, bias\n            )\n        elif aggregator_type == \"lstm\":\n            self.aggregator = LSTMAggregator(in_feats, in_feats)\n        else:\n            self.aggregator = MeanAggregator()\n\n    def forward(self, g, h):\n        h = self.dropout(h)\n        g.ndata[\"h\"] = h\n        if self.use_bn and not hasattr(self, \"bn\"):\n            device = h.device\n            self.bn = nn.BatchNorm1d(h.size()[1]).to(device)\n        g.update_all(fn.copy_u(u=\"h\", out=\"m\"), self.aggregator, self.bundler)\n        if self.use_bn:\n            h = self.bn(h)\n        h = g.ndata.pop(\"h\")\n        return h\n\n\nclass GraphSage(nn.Module):\n    \"\"\"\n    Grahpsage network that concatenate several graphsage layer\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        n_hidden,\n        n_classes,\n        n_layers,\n        activation,\n        dropout,\n    
    aggregator_type,\n    ):\n        super(GraphSage, self).__init__()\n        self.layers = nn.ModuleList()\n\n        # input layer\n        self.layers.append(\n            GraphSageLayer(\n                in_feats, n_hidden, activation, dropout, aggregator_type\n            )\n        )\n        # hidden layers\n        for _ in range(n_layers - 1):\n            self.layers.append(\n                GraphSageLayer(\n                    n_hidden, n_hidden, activation, dropout, aggregator_type\n                )\n            )\n        # output layer\n        self.layers.append(\n            GraphSageLayer(n_hidden, n_classes, None, dropout, aggregator_type)\n        )\n\n    def forward(self, g, features):\n        h = features\n        for layer in self.layers:\n            h = layer(g, h)\n        return h\n\n\nclass DiffPoolBatchedGraphLayer(nn.Module):\n    def __init__(\n        self,\n        input_dim,\n        assign_dim,\n        output_feat_dim,\n        activation,\n        dropout,\n        aggregator_type,\n        link_pred,\n    ):\n        super(DiffPoolBatchedGraphLayer, self).__init__()\n        self.embedding_dim = input_dim\n        self.assign_dim = assign_dim\n        self.hidden_dim = output_feat_dim\n        self.link_pred = link_pred\n        self.feat_gc = GraphSageLayer(\n            input_dim, output_feat_dim, activation, dropout, aggregator_type\n        )\n        self.pool_gc = GraphSageLayer(\n            input_dim, assign_dim, activation, dropout, aggregator_type\n        )\n        self.reg_loss = nn.ModuleList([])\n        self.loss_log = {}\n        self.reg_loss.append(EntropyLoss())\n\n    def forward(self, g, h):\n        feat = self.feat_gc(\n            g, h\n        )  # size = (sum_N, F_out), sum_N is num of nodes in this batch\n        device = feat.device\n        assign_tensor = self.pool_gc(\n            g, h\n        )  # size = (sum_N, N_a), N_a is num of nodes in pooled graph.\n        assign_tensor = 
F.softmax(assign_tensor, dim=1)\n        assign_tensor = torch.split(assign_tensor, g.batch_num_nodes().tolist())\n        assign_tensor = torch.block_diag(\n            *assign_tensor\n        )  # size = (sum_N, batch_size * N_a)\n\n        h = torch.matmul(torch.t(assign_tensor), feat)\n        adj = g.adj_external(transpose=True, ctx=device)\n        adj_new = torch.sparse.mm(adj, assign_tensor)\n        adj_new = torch.mm(torch.t(assign_tensor), adj_new)\n\n        if self.link_pred:\n            current_lp_loss = torch.norm(\n                adj.to_dense() - torch.mm(assign_tensor, torch.t(assign_tensor))\n            ) / np.power(g.num_nodes(), 2)\n            self.loss_log[\"LinkPredLoss\"] = current_lp_loss\n\n        for loss_layer in self.reg_loss:\n            loss_name = str(type(loss_layer).__name__)\n            self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor)\n\n        return adj_new, h\n"
  },
  {
    "path": "examples/pytorch/diffpool/model/encoder.py",
    "content": "import time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom scipy.linalg import block_diag\nfrom torch.nn import init\n\nfrom .dgl_layers import DiffPoolBatchedGraphLayer, GraphSage, GraphSageLayer\nfrom .model_utils import batch2tensor\nfrom .tensorized_layers import *\n\n\nclass DiffPool(nn.Module):\n    \"\"\"\n    DiffPool Fuse\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dim,\n        hidden_dim,\n        embedding_dim,\n        label_dim,\n        activation,\n        n_layers,\n        dropout,\n        n_pooling,\n        linkpred,\n        batch_size,\n        aggregator_type,\n        assign_dim,\n        pool_ratio,\n        cat=False,\n    ):\n        super(DiffPool, self).__init__()\n        self.link_pred = linkpred\n        self.concat = cat\n        self.n_pooling = n_pooling\n        self.batch_size = batch_size\n        self.link_pred_loss = []\n        self.entropy_loss = []\n\n        # list of GNN modules before the first diffpool operation\n        self.gc_before_pool = nn.ModuleList()\n        self.diffpool_layers = nn.ModuleList()\n\n        # list of list of GNN modules, each list after one diffpool operation\n        self.gc_after_pool = nn.ModuleList()\n        self.assign_dim = assign_dim\n        self.bn = True\n        self.num_aggs = 1\n\n        # constructing layers\n        # layers before diffpool\n        assert n_layers >= 3, \"n_layers too few\"\n        self.gc_before_pool.append(\n            GraphSageLayer(\n                input_dim,\n                hidden_dim,\n                activation,\n                dropout,\n                aggregator_type,\n                self.bn,\n            )\n        )\n        for _ in range(n_layers - 2):\n            self.gc_before_pool.append(\n                GraphSageLayer(\n                    hidden_dim,\n                    hidden_dim,\n                    activation,\n                   
 dropout,\n                    aggregator_type,\n                    self.bn,\n                )\n            )\n        self.gc_before_pool.append(\n            GraphSageLayer(\n                hidden_dim, embedding_dim, None, dropout, aggregator_type\n            )\n        )\n\n        assign_dims = []\n        assign_dims.append(self.assign_dim)\n        if self.concat:\n            # diffpool layer receive pool_emedding_dim node feature tensor\n            # and return pool_embedding_dim node embedding\n            pool_embedding_dim = hidden_dim * (n_layers - 1) + embedding_dim\n        else:\n            pool_embedding_dim = embedding_dim\n\n        self.first_diffpool_layer = DiffPoolBatchedGraphLayer(\n            pool_embedding_dim,\n            self.assign_dim,\n            hidden_dim,\n            activation,\n            dropout,\n            aggregator_type,\n            self.link_pred,\n        )\n        gc_after_per_pool = nn.ModuleList()\n\n        for _ in range(n_layers - 1):\n            gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, hidden_dim))\n        gc_after_per_pool.append(BatchedGraphSAGE(hidden_dim, embedding_dim))\n        self.gc_after_pool.append(gc_after_per_pool)\n\n        self.assign_dim = int(self.assign_dim * pool_ratio)\n        # each pooling module\n        for _ in range(n_pooling - 1):\n            self.diffpool_layers.append(\n                BatchedDiffPool(\n                    pool_embedding_dim,\n                    self.assign_dim,\n                    hidden_dim,\n                    self.link_pred,\n                )\n            )\n            gc_after_per_pool = nn.ModuleList()\n            for _ in range(n_layers - 1):\n                gc_after_per_pool.append(\n                    BatchedGraphSAGE(hidden_dim, hidden_dim)\n                )\n            gc_after_per_pool.append(\n                BatchedGraphSAGE(hidden_dim, embedding_dim)\n            )\n            
self.gc_after_pool.append(gc_after_per_pool)\n            assign_dims.append(self.assign_dim)\n            self.assign_dim = int(self.assign_dim * pool_ratio)\n\n        # predicting layer\n        if self.concat:\n            self.pred_input_dim = (\n                pool_embedding_dim * self.num_aggs * (n_pooling + 1)\n            )\n        else:\n            self.pred_input_dim = embedding_dim * self.num_aggs\n        self.pred_layer = nn.Linear(self.pred_input_dim, label_dim)\n\n        # weight initialization\n        for m in self.modules():\n            if isinstance(m, nn.Linear):\n                m.weight.data = init.xavier_uniform_(\n                    m.weight.data, gain=nn.init.calculate_gain(\"relu\")\n                )\n                if m.bias is not None:\n                    m.bias.data = init.constant_(m.bias.data, 0.0)\n\n    def gcn_forward(self, g, h, gc_layers, cat=False):\n        \"\"\"\n        Return gc_layer embedding cat.\n        \"\"\"\n        block_readout = []\n        for gc_layer in gc_layers[:-1]:\n            h = gc_layer(g, h)\n            block_readout.append(h)\n        h = gc_layers[-1](g, h)\n        block_readout.append(h)\n        if cat:\n            block = torch.cat(block_readout, dim=1)  # N x F, F = F1 + F2 + ...\n        else:\n            block = h\n        return block\n\n    def gcn_forward_tensorized(self, h, adj, gc_layers, cat=False):\n        block_readout = []\n        for gc_layer in gc_layers:\n            h = gc_layer(h, adj)\n            block_readout.append(h)\n        if cat:\n            block = torch.cat(block_readout, dim=2)  # N x F, F = F1 + F2 + ...\n        else:\n            block = h\n        return block\n\n    def forward(self, g):\n        self.link_pred_loss = []\n        self.entropy_loss = []\n        h = g.ndata[\"feat\"]\n        # node feature for assignment matrix computation is the same as the\n        # original node feature\n        h_a = h\n\n        out_all = []\n\n        # 
we use GCN blocks to get an embedding first\n        g_embedding = self.gcn_forward(g, h, self.gc_before_pool, self.concat)\n\n        g.ndata[\"h\"] = g_embedding\n\n        readout = dgl.sum_nodes(g, \"h\")\n        out_all.append(readout)\n        if self.num_aggs == 2:\n            readout = dgl.max_nodes(g, \"h\")\n            out_all.append(readout)\n\n        adj, h = self.first_diffpool_layer(g, g_embedding)\n        node_per_pool_graph = int(adj.size()[0] / len(g.batch_num_nodes()))\n\n        h, adj = batch2tensor(adj, h, node_per_pool_graph)\n        h = self.gcn_forward_tensorized(\n            h, adj, self.gc_after_pool[0], self.concat\n        )\n        readout = torch.sum(h, dim=1)\n        out_all.append(readout)\n        if self.num_aggs == 2:\n            readout, _ = torch.max(h, dim=1)\n            out_all.append(readout)\n\n        for i, diffpool_layer in enumerate(self.diffpool_layers):\n            h, adj = diffpool_layer(h, adj)\n            h = self.gcn_forward_tensorized(\n                h, adj, self.gc_after_pool[i + 1], self.concat\n            )\n            readout = torch.sum(h, dim=1)\n            out_all.append(readout)\n            if self.num_aggs == 2:\n                readout, _ = torch.max(h, dim=1)\n                out_all.append(readout)\n        if self.concat or self.num_aggs > 1:\n            final_readout = torch.cat(out_all, dim=1)\n        else:\n            final_readout = readout\n        ypred = self.pred_layer(final_readout)\n        return ypred\n\n    def loss(self, pred, label):\n        \"\"\"\n        loss function\n        \"\"\"\n        # softmax + CE\n        criterion = nn.CrossEntropyLoss()\n        loss = criterion(pred, label)\n        for key, value in self.first_diffpool_layer.loss_log.items():\n            loss += value\n        for diffpool_layer in self.diffpool_layers:\n            for key, value in diffpool_layer.loss_log.items():\n                loss += value\n        return loss\n"
  },
  {
    "path": "examples/pytorch/diffpool/model/loss.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nclass EntropyLoss(nn.Module):\n    # Return Scalar\n    def forward(self, adj, anext, s_l):\n        entropy = (\n            (torch.distributions.Categorical(probs=s_l).entropy())\n            .sum(-1)\n            .mean(-1)\n        )\n        assert not torch.isnan(entropy)\n        return entropy\n\n\nclass LinkPredLoss(nn.Module):\n    def forward(self, adj, anext, s_l):\n        link_pred_loss = (adj - s_l.matmul(s_l.transpose(-1, -2))).norm(\n            dim=(1, 2)\n        )\n        link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))\n        return link_pred_loss.mean()\n"
  },
  {
    "path": "examples/pytorch/diffpool/model/model_utils.py",
    "content": "import torch as th\nfrom torch.autograd import Function\n\n\ndef batch2tensor(batch_adj, batch_feat, node_per_pool_graph):\n    \"\"\"\n    transform a batched graph to batched adjacency tensor and node feature tensor\n    \"\"\"\n    batch_size = int(batch_adj.size()[0] / node_per_pool_graph)\n    adj_list = []\n    feat_list = []\n    for i in range(batch_size):\n        start = i * node_per_pool_graph\n        end = (i + 1) * node_per_pool_graph\n        adj_list.append(batch_adj[start:end, start:end])\n        feat_list.append(batch_feat[start:end, :])\n    adj_list = list(map(lambda x: th.unsqueeze(x, 0), adj_list))\n    feat_list = list(map(lambda x: th.unsqueeze(x, 0), feat_list))\n    adj = th.cat(adj_list, dim=0)\n    feat = th.cat(feat_list, dim=0)\n\n    return feat, adj\n\n\ndef masked_softmax(\n    matrix, mask, dim=-1, memory_efficient=True, mask_fill_value=-1e32\n):\n    \"\"\"\n    masked_softmax for dgl batch graph\n    code snippet contributed by AllenNLP (https://github.com/allenai/allennlp)\n    \"\"\"\n    if mask is None:\n        result = th.nn.functional.softmax(matrix, dim=dim)\n    else:\n        mask = mask.float()\n        while mask.dim() < matrix.dim():\n            mask = mask.unsqueeze(1)\n        if not memory_efficient:\n            result = th.nn.functional.softmax(matrix * mask, dim=dim)\n            result = result * mask\n            result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)\n        else:\n            masked_matrix = matrix.masked_fill(\n                (1 - mask).byte(), mask_fill_value\n            )\n            result = th.nn.functional.softmax(masked_matrix, dim=dim)\n    return result\n"
  },
  {
    "path": "examples/pytorch/diffpool/model/tensorized_layers/__init__.py",
    "content": "from .diffpool import BatchedDiffPool\nfrom .graphsage import BatchedGraphSAGE\n"
  },
  {
    "path": "examples/pytorch/diffpool/model/tensorized_layers/assignment.py",
    "content": "import torch\nfrom torch import nn as nn\nfrom torch.autograd import Variable\nfrom torch.nn import functional as F\n\nfrom model.tensorized_layers.graphsage import BatchedGraphSAGE\n\n\nclass DiffPoolAssignment(nn.Module):\n    def __init__(self, nfeat, nnext):\n        super().__init__()\n        self.assign_mat = BatchedGraphSAGE(nfeat, nnext, use_bn=True)\n\n    def forward(self, x, adj, log=False):\n        s_l_init = self.assign_mat(x, adj)\n        s_l = F.softmax(s_l_init, dim=-1)\n        return s_l\n"
  },
  {
    "path": "examples/pytorch/diffpool/model/tensorized_layers/diffpool.py",
    "content": "import torch\nfrom torch import nn as nn\n\nfrom model.loss import EntropyLoss, LinkPredLoss\nfrom model.tensorized_layers.assignment import DiffPoolAssignment\nfrom model.tensorized_layers.graphsage import BatchedGraphSAGE\n\n\nclass BatchedDiffPool(nn.Module):\n    def __init__(self, nfeat, nnext, nhid, link_pred=False, entropy=True):\n        super(BatchedDiffPool, self).__init__()\n        self.link_pred = link_pred\n        self.log = {}\n        self.link_pred_layer = LinkPredLoss()\n        self.embed = BatchedGraphSAGE(nfeat, nhid, use_bn=True)\n        self.assign = DiffPoolAssignment(nfeat, nnext)\n        self.reg_loss = nn.ModuleList([])\n        self.loss_log = {}\n        if link_pred:\n            self.reg_loss.append(LinkPredLoss())\n        if entropy:\n            self.reg_loss.append(EntropyLoss())\n\n    def forward(self, x, adj, log=False):\n        z_l = self.embed(x, adj)\n        s_l = self.assign(x, adj)\n        if log:\n            self.log[\"s\"] = s_l.cpu().numpy()\n        xnext = torch.matmul(s_l.transpose(-1, -2), z_l)\n        anext = (s_l.transpose(-1, -2)).matmul(adj).matmul(s_l)\n\n        for loss_layer in self.reg_loss:\n            loss_name = str(type(loss_layer).__name__)\n            self.loss_log[loss_name] = loss_layer(adj, anext, s_l)\n        if log:\n            self.log[\"a\"] = anext.cpu().numpy()\n        return xnext, anext\n"
  },
  {
    "path": "examples/pytorch/diffpool/model/tensorized_layers/graphsage.py",
    "content": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\n\nclass BatchedGraphSAGE(nn.Module):\n    def __init__(\n        self, infeat, outfeat, use_bn=True, mean=False, add_self=False\n    ):\n        super().__init__()\n        self.add_self = add_self\n        self.use_bn = use_bn\n        self.mean = mean\n        self.W = nn.Linear(infeat, outfeat, bias=True)\n\n        nn.init.xavier_uniform_(\n            self.W.weight, gain=nn.init.calculate_gain(\"relu\")\n        )\n\n    def forward(self, x, adj):\n        num_node_per_graph = adj.size(1)\n        if self.use_bn and not hasattr(self, \"bn\"):\n            self.bn = nn.BatchNorm1d(num_node_per_graph).to(adj.device)\n\n        if self.add_self:\n            adj = adj + torch.eye(num_node_per_graph).to(adj.device)\n\n        if self.mean:\n            adj = adj / adj.sum(-1, keepdim=True)\n\n        h_k_N = torch.matmul(adj, x)\n        h_k = self.W(h_k_N)\n        h_k = F.normalize(h_k, dim=2, p=2)\n        h_k = F.relu(h_k)\n        if self.use_bn:\n            h_k = self.bn(h_k)\n        return h_k\n\n    def __repr__(self):\n        if self.use_bn:\n            return \"BN\" + super(BatchedGraphSAGE, self).__repr__()\n        else:\n            return super(BatchedGraphSAGE, self).__repr__()\n"
  },
  {
    "path": "examples/pytorch/diffpool/train.py",
    "content": "import argparse\nimport os\nimport random\nimport time\n\nimport dgl\nimport dgl.function as fn\n\nimport networkx as nx\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.data\nfrom data_utils import pre_process\nfrom dgl import DGLGraph\nfrom dgl.data import tu\nfrom model.encoder import DiffPool\n\nglobal_train_time_per_epoch = []\n\n\ndef arg_parse():\n    \"\"\"\n    argument parser\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"DiffPool arguments\")\n    parser.add_argument(\"--dataset\", dest=\"dataset\", help=\"Input Dataset\")\n    parser.add_argument(\n        \"--pool_ratio\", dest=\"pool_ratio\", type=float, help=\"pooling ratio\"\n    )\n    parser.add_argument(\n        \"--num_pool\", dest=\"num_pool\", type=int, help=\"num_pooling layer\"\n    )\n    parser.add_argument(\n        \"--no_link_pred\",\n        dest=\"linkpred\",\n        action=\"store_false\",\n        help=\"switch of link prediction object\",\n    )\n    parser.add_argument(\"--cuda\", dest=\"cuda\", type=int, help=\"switch cuda\")\n    parser.add_argument(\"--lr\", dest=\"lr\", type=float, help=\"learning rate\")\n    parser.add_argument(\n        \"--clip\", dest=\"clip\", type=float, help=\"gradient clipping\"\n    )\n    parser.add_argument(\n        \"--batch-size\", dest=\"batch_size\", type=int, help=\"batch size\"\n    )\n    parser.add_argument(\"--epochs\", dest=\"epoch\", type=int, help=\"num-of-epoch\")\n    parser.add_argument(\n        \"--train-ratio\",\n        dest=\"train_ratio\",\n        type=float,\n        help=\"ratio of trainning dataset split\",\n    )\n    parser.add_argument(\n        \"--test-ratio\",\n        dest=\"test_ratio\",\n        type=float,\n        help=\"ratio of testing dataset split\",\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        dest=\"n_worker\",\n        type=int,\n        help=\"number of workers when dataloading\",\n    
)\n    parser.add_argument(\n        \"--gc-per-block\",\n        dest=\"gc_per_block\",\n        type=int,\n        help=\"number of graph conv layer per block\",\n    )\n    parser.add_argument(\n        \"--bn\",\n        dest=\"bn\",\n        action=\"store_const\",\n        const=True,\n        default=True,\n        help=\"switch for bn\",\n    )\n    parser.add_argument(\n        \"--dropout\", dest=\"dropout\", type=float, help=\"dropout rate\"\n    )\n    parser.add_argument(\n        \"--bias\",\n        dest=\"bias\",\n        action=\"store_const\",\n        const=True,\n        default=True,\n        help=\"switch for bias\",\n    )\n    parser.add_argument(\n        \"--save_dir\",\n        dest=\"save_dir\",\n        help=\"model saving directory: SAVE_DICT/DATASET\",\n    )\n    parser.add_argument(\n        \"--load_epoch\",\n        dest=\"load_epoch\",\n        type=int,\n        help=\"load trained model params from\\\n                         SAVE_DICT/DATASET/model-LOAD_EPOCH\",\n    )\n    parser.add_argument(\n        \"--data_mode\",\n        dest=\"data_mode\",\n        help=\"data\\\n                        preprocessing mode: default, id, degree, or one-hot\\\n                        vector of degree number\",\n        choices=[\"default\", \"id\", \"deg\", \"deg_num\"],\n    )\n\n    parser.set_defaults(\n        dataset=\"ENZYMES\",\n        pool_ratio=0.15,\n        num_pool=1,\n        cuda=1,\n        lr=1e-3,\n        clip=2.0,\n        batch_size=20,\n        epoch=4000,\n        train_ratio=0.7,\n        test_ratio=0.1,\n        n_worker=1,\n        gc_per_block=3,\n        dropout=0.0,\n        method=\"diffpool\",\n        bn=True,\n        bias=True,\n        save_dir=\"./model_param\",\n        load_epoch=-1,\n        data_mode=\"default\",\n    )\n    return parser.parse_args()\n\n\ndef prepare_data(dataset, prog_args, train=False, pre_process=None):\n    \"\"\"\n    preprocess TU dataset according to DiffPool's paper 
setting and load dataset into dataloader\n    \"\"\"\n    if train:\n        shuffle = True\n    else:\n        shuffle = False\n\n    if pre_process:\n        pre_process(dataset, prog_args)\n\n    # dataset.set_fold(fold)\n    return dgl.dataloading.GraphDataLoader(\n        dataset,\n        batch_size=prog_args.batch_size,\n        shuffle=shuffle,\n        num_workers=prog_args.n_worker,\n    )\n\n\ndef graph_classify_task(prog_args):\n    \"\"\"\n    perform graph classification task\n    \"\"\"\n\n    dataset = tu.LegacyTUDataset(name=prog_args.dataset)\n    train_size = int(prog_args.train_ratio * len(dataset))\n    test_size = int(prog_args.test_ratio * len(dataset))\n    val_size = int(len(dataset) - train_size - test_size)\n\n    dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(\n        dataset, (train_size, val_size, test_size)\n    )\n    train_dataloader = prepare_data(\n        dataset_train, prog_args, train=True, pre_process=pre_process\n    )\n    val_dataloader = prepare_data(\n        dataset_val, prog_args, train=False, pre_process=pre_process\n    )\n    test_dataloader = prepare_data(\n        dataset_test, prog_args, train=False, pre_process=pre_process\n    )\n    input_dim, label_dim, max_num_node = dataset.statistics()\n    print(\"++++++++++STATISTICS ABOUT THE DATASET\")\n    print(\"dataset feature dimension is\", input_dim)\n    print(\"dataset label dimension is\", label_dim)\n    print(\"the max num node is\", max_num_node)\n    print(\"number of graphs is\", len(dataset))\n    # assert len(dataset) % prog_args.batch_size == 0, \"training set not divisible by batch size\"\n\n    hidden_dim = 64  # used to be 64\n    embedding_dim = 64\n\n    # calculate assignment dimension: pool_ratio * largest graph's maximum\n    # number of nodes  in the dataset\n    assign_dim = int(max_num_node * prog_args.pool_ratio)\n    print(\"++++++++++MODEL STATISTICS++++++++\")\n    print(\"model hidden dim is\", hidden_dim)\n   
 print(\"model embedding dim for graph instance embedding\", embedding_dim)\n    print(\"initial batched pool graph dim is\", assign_dim)\n    activation = F.relu\n\n    # initialize model\n    # 'diffpool' : diffpool\n    model = DiffPool(\n        input_dim,\n        hidden_dim,\n        embedding_dim,\n        label_dim,\n        activation,\n        prog_args.gc_per_block,\n        prog_args.dropout,\n        prog_args.num_pool,\n        prog_args.linkpred,\n        prog_args.batch_size,\n        \"meanpool\",\n        assign_dim,\n        prog_args.pool_ratio,\n    )\n\n    if prog_args.load_epoch >= 0 and prog_args.save_dir is not None:\n        model.load_state_dict(\n            torch.load(\n                prog_args.save_dir\n                + \"/\"\n                + prog_args.dataset\n                + \"/model.iter-\"\n                + str(prog_args.load_epoch),\n                weights_only=False,\n            )\n        )\n\n    print(\"model init finished\")\n    print(\"MODEL:::::::\", prog_args.method)\n    if prog_args.cuda:\n        model = model.cuda()\n\n    logger = train(\n        train_dataloader, model, prog_args, val_dataset=val_dataloader\n    )\n    result = evaluate(test_dataloader, model, prog_args, logger)\n    print(\"test  accuracy {:.2f}%\".format(result * 100))\n\n\ndef train(dataset, model, prog_args, same_feat=True, val_dataset=None):\n    \"\"\"\n    training function\n    \"\"\"\n    dir = prog_args.save_dir + \"/\" + prog_args.dataset\n    if not os.path.exists(dir):\n        os.makedirs(dir)\n    dataloader = dataset\n    optimizer = torch.optim.Adam(\n        filter(lambda p: p.requires_grad, model.parameters()), lr=0.001\n    )\n    early_stopping_logger = {\"best_epoch\": -1, \"val_acc\": -1}\n\n    if prog_args.cuda > 0:\n        torch.cuda.set_device(0)\n    for epoch in range(prog_args.epoch):\n        begin_time = time.time()\n        model.train()\n        accum_correct = 0\n        total = 0\n        
print(\"\\nEPOCH ###### {} ######\".format(epoch))\n        computation_time = 0.0\n        for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):\n            for key, value in batch_graph.ndata.items():\n                batch_graph.ndata[key] = value.float()\n            graph_labels = graph_labels.long()\n            if torch.cuda.is_available():\n                batch_graph = batch_graph.to(torch.cuda.current_device())\n                graph_labels = graph_labels.cuda()\n\n            model.zero_grad()\n            compute_start = time.time()\n            ypred = model(batch_graph)\n            indi = torch.argmax(ypred, dim=1)\n            correct = torch.sum(indi == graph_labels).item()\n            accum_correct += correct\n            total += graph_labels.size()[0]\n            loss = model.loss(ypred, graph_labels)\n            loss.backward()\n            batch_compute_time = time.time() - compute_start\n            computation_time += batch_compute_time\n            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)\n            optimizer.step()\n\n        train_accu = accum_correct / total\n        print(\n            \"train accuracy for this epoch {} is {:.2f}%\".format(\n                epoch, train_accu * 100\n            )\n        )\n        elapsed_time = time.time() - begin_time\n        print(\n            \"loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s \".format(\n                loss.item(), elapsed_time, computation_time\n            )\n        )\n        global_train_time_per_epoch.append(elapsed_time)\n        if val_dataset is not None:\n            result = evaluate(val_dataset, model, prog_args)\n            print(\"validation  accuracy {:.2f}%\".format(result * 100))\n            if (\n                result >= early_stopping_logger[\"val_acc\"]\n                and result <= train_accu\n            ):\n                early_stopping_logger.update(best_epoch=epoch, val_acc=result)\n           
     if prog_args.save_dir is not None:\n                    torch.save(\n                        model.state_dict(),\n                        prog_args.save_dir\n                        + \"/\"\n                        + prog_args.dataset\n                        + \"/model.iter-\"\n                        + str(early_stopping_logger[\"best_epoch\"]),\n                    )\n            print(\n                \"best epoch is EPOCH {}, val_acc is {:.2f}%\".format(\n                    early_stopping_logger[\"best_epoch\"],\n                    early_stopping_logger[\"val_acc\"] * 100,\n                )\n            )\n        torch.cuda.empty_cache()\n    return early_stopping_logger\n\n\ndef evaluate(dataloader, model, prog_args, logger=None):\n    \"\"\"\n    evaluate function\n    \"\"\"\n    if logger is not None and prog_args.save_dir is not None:\n        model.load_state_dict(\n            torch.load(\n                prog_args.save_dir\n                + \"/\"\n                + prog_args.dataset\n                + \"/model.iter-\"\n                + str(logger[\"best_epoch\"]),\n                weights_only=False,\n            )\n        )\n    model.eval()\n    correct_label = 0\n    with torch.no_grad():\n        for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):\n            for key, value in batch_graph.ndata.items():\n                batch_graph.ndata[key] = value.float()\n            graph_labels = graph_labels.long()\n            if torch.cuda.is_available():\n                batch_graph = batch_graph.to(torch.cuda.current_device())\n                graph_labels = graph_labels.cuda()\n            ypred = model(batch_graph)\n            indi = torch.argmax(ypred, dim=1)\n            correct = torch.sum(indi == graph_labels)\n            correct_label += correct.item()\n    result = correct_label / (len(dataloader) * prog_args.batch_size)\n    return result\n\n\ndef main():\n    \"\"\"\n    main\n    \"\"\"\n    prog_args = 
arg_parse()\n    print(prog_args)\n    graph_classify_task(prog_args)\n\n    print(\n        \"Train time per epoch: {:.4f}\".format(\n            sum(global_train_time_per_epoch) / len(global_train_time_per_epoch)\n        )\n    )\n    print(\n        \"Max memory usage: {:.4f}\".format(\n            torch.cuda.max_memory_allocated(0) / (1024 * 1024)\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/pytorch/dimenet/README.md",
    "content": "# DGL Implementation of DimeNet and DimeNet++\n\nThis DGL example implements the GNN model proposed in the paper [Directional Message Passing for Molecular Graphs](https://arxiv.org/abs/2003.03123) and [Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules](https://arxiv.org/abs/2011.14115). For the original implementation, see [here](https://github.com/klicperajo/dimenet).\n\nContributor: [xnuohz](https://github.com/xnuohz)\n\n* This example implements both DimeNet and DimeNet++.\n* The advantages of DimeNet++ over DimeNet\n    - Fast interactions: replacing bilinear layer with a simple Hadamard product\n    - Embedding hierarchy: using a higher number of embeddings by reducing the embedding size in blocks via down- and up-projection layers\n    - Other improvements: using fewer interaction blocks\n\n### Requirements\nThe codebase is implemented in Python 3.6. For version requirement of packages, see below.\n\n```\nclick 7.1.2\ndgl 0.6.0\nlogzero 1.6.3\nnumpy 1.19.5\nruamel.yaml 0.16.12\nscikit-learn 0.24.1\nscipy 1.5.4\nsympy 1.7.1\ntorch 1.7.0\ntqdm 4.56.0\n```\n\n### The graph datasets used in this example\n\nThe DGL's built-in QM9 dataset. Dataset summary:\n\n* Number of Molecular Graphs: 130,831\n* Number of Tasks: 12\n\n### Usage\n\n**Note: DimeNet++ is recommended over DimeNet.**\n\n##### Examples\n\nThe following commands learn a neural network and predict on the test set.\nTraining a DimeNet model on QM9 dataset.\n```bash\npython main.py --model-cnf config/dimenet.yaml\n```\nTraining a DimeNet++ model on QM9 dataset.\n```bash\npython main.py --model-cnf config/dimenet_pp.yaml\n```\nFor faster experimentation, you should first put the author's [pretrained](https://github.com/klicperajo/dimenet/tree/master/pretrained) folder here, which contains pre-trained TensorFlow models. 
You can convert a TensorFlow model to a PyTorch model by using the following commands.\n```\npython convert_tf_ckpt_to_pytorch.py --model-cnf config/dimenet_pp.yaml --convert-cnf config/convert.yaml\n```\nThen you can set `flag: True` in `dimenet_pp.yaml` and run the above script, DimeNet++ will use the pretrained weights to predict on the test set.\n\n##### Configuration\n\nFor more details, please see `config/dimenet.yaml` and `config/dimenet_pp.yaml`\n\n###### Model options\n```\n// The following parameters are only used in DimeNet++\nout_emb_size      int    Output embedding size.                                         Default is 256\nint_emb_size      int    Input embedding size.                                          Default is 64\nbasis_emb_size    int    Basis embedding size.                                          Default is 8\nextensive         bool   Readout operator for generating a graph-level representation.  Default is True \n\n// The following parameter is only used in DimeNet\nnum_bilinear      int    Third dimension of the bilinear layer tensor in DimeNet.       Default is 8\n\n// The following parameters are used in both DimeNet and DimeNet++\nemb_size          int    Embedding size used throughout the model.                              Default is 128\nnum_blocks        int    Number of building blocks to be stacked.                               Default is 6 in DimeNet and 4 in DimeNet++   \nnum_spherical     int    Number of spherical harmonics.                                         Default is 7   \nnum_radial        int    Number of radial basis functions.                                      Default is 6   \nenvelope_exponent int    Shape of the smooth cutoff.                                            Default is 5   \ncutoff            float  Cutoff distance for interatomic interactions.                          Default is 5.0 \nnum_before_skip   int    Number of residual layers in interaction block before skip connection. 
Default is 1   \nnum_after_skip    int    Number of residual layers in interaction block after skip connection.  Default is 2   \nnum_dense_output  int    Number of dense layers for the output blocks.                          Default is 3   \ntargets           list   List of targets to predict.                                            Default is ['mu']\noutput_init       string Initial function name for output layer.                                Default is 'GlorotOrthogonal'\n```\n\n###### Training options\n```\nnum_train         int   Number of training samples.                     Default is 110000\nnum_valid         int   Number of validation samples.                   Default is 10000\ndata_seed         int   Random seed.                                    Default is 42\nlr                float Learning rate.                                  Default is 0.001\nweight_decay      float Weight decay.                                   Default is 0.0001\nema_decay         float EMA decay.                                      Default is 0.\nbatch_size        int   Batch size.                                     Default is 100\nepochs            int   Training epochs.                                Default is 300\nearly_stopping    int   Patient epochs to wait before early stopping.   Default is 20\nnum_workers       int   Number of subprocesses to use for data loading. Default is 18\ngpu               int   GPU index.                                      Default is 0, using CUDA:0\ninterval          int   Time intervals for model evaluation.            Default is 50\nstep_size         int   Period of learning rate decay.                  Default is 100\ngamma             float Factor of learning rate decay.                  
Default is 0.3\n```\n\n### Performance\n\n- Batch size is different\n- Linear learning rate warm-up is not used\n- Exponential learning rate decay is not used\n- Exponential moving average (EMA) is not used\n- The values for tasks except mu, alpha, r2, Cv should be x 10^-3\n- The author's code didn't provide the pretrained model for gap task\n- MAE(DimeNet in Table 1) is from [here](https://arxiv.org/abs/2003.03123)\n- MAE(DimeNet++ in Table 2) is from [here](https://arxiv.org/abs/2011.14115)\n\n| Target | mu | alpha | homo | lumo | gap | r2 | zpve | U0 | U | H | G | Cv |\n| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |\n| MAE(DimeNet in Table 1)      | 0.0286 | 0.0469 | 27.8 | 19.7 | 34.8 | 0.331 | 1.29 | 8.02 | 7.89 | 8.11 | 8.98 | 0.0249 |\n| MAE(DimeNet++ in Table 2)    | 0.0297 | 0.0435 | 24.6 | 19.5 | 32.6 | 0.331 | 1.21 | 6.32 | 6.28 | 6.53 | 7.56 | 0.0230 |\n| MAE(DimeNet++, TF, pretrain) | 0.0297 | 0.0435 | 0.0246 | 0.0195 | -      | 0.3312 | 0.00121 | 0.0063 | 0.00628 | 0.00653 | 0.00756 | 0.0230 |\n| MAE(DimeNet++, TF, scratch)  | 0.0330 | 0.0447 | 0.0251 | 0.0227 | 0.0486 | 0.3574 | 0.00123 | 0.0065 | 0.00635 | 0.00658 | 0.00747 | 0.0224 |\n| MAE(DimeNet++, DGL)          | 0.0326 | 0.0537 | 0.0311 | 0.0255 | 0.0490 | 0.4801 | 0.0043 | 0.0141 | 0.0109 | 0.0117 | 0.0150 | 0.0254 |\n\n### Speed\n\n| Model | Original Implementation | DGL Implementation | Improvement |\n| :-: | :-: | :-: | :-: |\n| DimeNet | 2839 | 1345 | 2.1x |\n| DimeNet++ | 624 | 238 | 2.6x |\n"
  },
  {
    "path": "examples/pytorch/dimenet/config/convert.yaml",
    "content": "tf:\n  ckpt_path: 'pretrained/dimenet_pp/mu'\ntorch:\n  dump_path: 'pretrained/converted'"
  },
  {
    "path": "examples/pytorch/dimenet/config/dimenet.yaml",
    "content": "name: \"dimenet\"\n\nmodel:\n  emb_size: 128\n  num_blocks: 6\n  num_bilinear: 8\n  num_spherical: 7\n  num_radial: 6\n  envelope_exponent: 5\n  cutoff: 5.0\n  num_before_skip: 1\n  num_after_skip: 2\n  num_dense_output: 3\n  # ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']\n  targets: ['U0']\n\ntrain:\n  num_train: 110000\n  num_valid: 10000\n  data_seed: 42\n  lr: 0.001\n  weight_decay: 0.0001\n  ema_decay: 0\n  batch_size: 45\n  epochs: 300\n  early_stopping: 20\n  num_workers: 18\n  gpu: 0\n  interval: 50\n  step_size: 100\n  gamma: 0.3\n\npretrain:\n  flag: False\n  path: 'pretrained/converted/'"
  },
  {
    "path": "examples/pytorch/dimenet/config/dimenet_pp.yaml",
    "content": "name: \"dimenet++\"\n\nmodel:\n  emb_size: 128\n  out_emb_size: 256\n  int_emb_size: 64\n  basis_emb_size: 8\n  num_blocks: 4\n  num_spherical: 7\n  num_radial: 6\n  envelope_exponent: 5\n  cutoff: 5.0\n  extensive: True\n  num_before_skip: 1\n  num_after_skip: 2\n  num_dense_output: 3\n  # ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']\n  targets: ['mu']\n\ntrain:\n  num_train: 110000\n  num_valid: 10000\n  data_seed: 42\n  lr: 0.001\n  weight_decay: 0.0001\n  ema_decay: 0\n  batch_size: 100\n  epochs: 300\n  early_stopping: 20\n  num_workers: 18\n  gpu: 0\n  interval: 50\n  step_size: 100\n  gamma: 0.3\n\npretrain:\n  flag: False\n  path: 'pretrained/converted/'"
  },
  {
    "path": "examples/pytorch/dimenet/convert_tf_ckpt_to_pytorch.py",
    "content": "import os\nfrom pathlib import Path\n\nimport click\nimport numpy as np\nimport tensorflow as tf\nimport torch\nimport torch.nn as nn\nfrom logzero import logger\nfrom modules.dimenet_pp import DimeNetPP\nfrom modules.initializers import GlorotOrthogonal\nfrom ruamel.yaml import YAML\n\n\n@click.command()\n@click.option(\n    \"-m\",\n    \"--model-cnf\",\n    type=click.Path(exists=True),\n    help=\"Path of model config yaml.\",\n)\n@click.option(\n    \"-c\",\n    \"--convert-cnf\",\n    type=click.Path(exists=True),\n    help=\"Path of convert config yaml.\",\n)\ndef main(model_cnf, convert_cnf):\n    yaml = YAML(typ=\"safe\")\n    model_cnf = yaml.load(Path(model_cnf))\n    convert_cnf = yaml.load(Path(convert_cnf))\n    model_name, model_params, _ = (\n        model_cnf[\"name\"],\n        model_cnf[\"model\"],\n        model_cnf[\"train\"],\n    )\n    logger.info(f\"Model name: {model_name}\")\n    logger.info(f\"Model params: {model_params}\")\n\n    if model_params[\"targets\"] in [\"mu\", \"homo\", \"lumo\", \"gap\", \"zpve\"]:\n        model_params[\"output_init\"] = nn.init.zeros_\n    else:\n        # 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv\n        model_params[\"output_init\"] = GlorotOrthogonal\n\n    # model initialization\n    logger.info(\"Loading Model\")\n    model = DimeNetPP(\n        emb_size=model_params[\"emb_size\"],\n        out_emb_size=model_params[\"out_emb_size\"],\n        int_emb_size=model_params[\"int_emb_size\"],\n        basis_emb_size=model_params[\"basis_emb_size\"],\n        num_blocks=model_params[\"num_blocks\"],\n        num_spherical=model_params[\"num_spherical\"],\n        num_radial=model_params[\"num_radial\"],\n        cutoff=model_params[\"cutoff\"],\n        envelope_exponent=model_params[\"envelope_exponent\"],\n        num_before_skip=model_params[\"num_before_skip\"],\n        num_after_skip=model_params[\"num_after_skip\"],\n        
num_dense_output=model_params[\"num_dense_output\"],\n        num_targets=len(model_params[\"targets\"]),\n        extensive=model_params[\"extensive\"],\n        output_init=model_params[\"output_init\"],\n    )\n    logger.info(model.state_dict())\n    tf_path, torch_path = (\n        convert_cnf[\"tf\"][\"ckpt_path\"],\n        convert_cnf[\"torch\"][\"dump_path\"],\n    )\n    init_vars = tf.train.list_variables(tf_path)\n    tf_vars_dict = {}\n\n    # 147 keys\n    for name, shape in init_vars:\n        if name == \"_CHECKPOINTABLE_OBJECT_GRAPH\":\n            continue\n        array = tf.train.load_variable(tf_path, name)\n        logger.info(f\"Loading TF weight {name} with shape {shape}\")\n        tf_vars_dict[name] = array\n\n    for name, array in tf_vars_dict.items():\n        name = name.split(\"/\")[:-2]\n        pointer = model\n\n        for m_name in name:\n            if m_name == \"kernel\":\n                pointer = getattr(pointer, \"weight\")\n            elif m_name == \"int_blocks\":\n                pointer = getattr(pointer, \"interaction_blocks\")\n            elif m_name == \"embeddings\":\n                pointer = getattr(pointer, \"embedding\")\n                pointer = getattr(pointer, \"weight\")\n            else:\n                pointer = getattr(pointer, m_name)\n        if name[-1] == \"kernel\":\n            array = np.transpose(array)\n        assert array.shape == pointer.shape\n        logger.info(f\"Initialize PyTorch weight {name}\")\n        pointer.data = torch.from_numpy(array)\n\n    logger.info(f\"Save PyTorch model to {torch_path}\")\n    if not os.path.exists(torch_path):\n        os.makedirs(torch_path)\n    target = model_params[\"targets\"][0]\n    torch.save(model.state_dict(), f\"{torch_path}/{target}.pt\")\n    logger.info(model.state_dict())\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/pytorch/dimenet/main.py",
    "content": "import copy\nfrom pathlib import Path\n\nimport click\n\nimport dgl\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dgl.data.utils import Subset\nfrom logzero import logger\nfrom modules.dimenet import DimeNet\nfrom modules.dimenet_pp import DimeNetPP\nfrom modules.initializers import GlorotOrthogonal\nfrom qm9 import QM9\nfrom ruamel.yaml import YAML\nfrom sklearn.metrics import mean_absolute_error\nfrom torch.utils.data import DataLoader\n\n\ndef split_dataset(\n    dataset, num_train, num_valid, shuffle=False, random_state=None\n):\n    \"\"\"Split dataset into training, validation and test set.\n\n    Parameters\n    ----------\n    dataset\n        We assume that ``len(dataset)`` gives the number of datapoints and ``dataset[i]``\n        gives the ith datapoint.\n    num_train : int\n        Number of training datapoints.\n    num_valid : int\n        Number of validation datapoints.\n    shuffle : bool, optional\n        By default we perform a consecutive split of the dataset. 
If True,\n        we will first randomly shuffle the dataset.\n    random_state : None, int or array_like, optional\n        Random seed used to initialize the pseudo-random number generator.\n        This can be any integer between 0 and 2^32 - 1 inclusive, an array\n        (or other sequence) of such integers, or None (the default value).\n        If seed is None, then RandomState will try to read data from /dev/urandom\n        (or the Windows analogue) if available or seed from the clock otherwise.\n\n    Returns\n    -------\n    list of length 3\n        Subsets for training, validation and test.\n    \"\"\"\n    from itertools import accumulate\n\n    num_data = len(dataset)\n    assert num_train + num_valid < num_data\n    lengths = [num_train, num_valid, num_data - num_train - num_valid]\n    if shuffle:\n        indices = np.random.RandomState(seed=random_state).permutation(num_data)\n    else:\n        indices = np.arange(num_data)\n    return [\n        Subset(dataset, indices[offset - length : offset])\n        for offset, length in zip(accumulate(lengths), lengths)\n    ]\n\n\n@torch.no_grad()\ndef ema(ema_model, model, decay):\n    msd = model.state_dict()\n    for k, ema_v in ema_model.state_dict().items():\n        model_v = msd[k].detach()\n        ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v)\n\n\ndef edge_init(edges):\n    R_src, R_dst = edges.src[\"R\"], edges.dst[\"R\"]\n    dist = torch.sqrt(F.relu(torch.sum((R_src - R_dst) ** 2, -1)))\n    # d: bond length, o: bond orientation\n    return {\"d\": dist, \"o\": R_src - R_dst}\n\n\ndef _collate_fn(batch):\n    graphs, line_graphs, labels = map(list, zip(*batch))\n    g, l_g = dgl.batch(graphs), dgl.batch(line_graphs)\n    labels = torch.tensor(labels, dtype=torch.float32)\n    return g, l_g, labels\n\n\ndef train(device, model, opt, loss_fn, train_loader):\n    model.train()\n    epoch_loss = 0\n    num_samples = 0\n\n    for g, l_g, labels in train_loader:\n        g = g.to(device)\n   
     l_g = l_g.to(device)\n        labels = labels.to(device)\n        logits = model(g, l_g)\n        loss = loss_fn(logits, labels.view([-1, 1]))\n        epoch_loss += loss.data.item() * len(labels)\n        num_samples += len(labels)\n        opt.zero_grad()\n        loss.backward()\n        opt.step()\n\n    return epoch_loss / num_samples\n\n\n@torch.no_grad()\ndef evaluate(device, model, valid_loader):\n    model.eval()\n    predictions_all, labels_all = [], []\n\n    for g, l_g, labels in valid_loader:\n        g = g.to(device)\n        l_g = l_g.to(device)\n        logits = model(g, l_g)\n        labels_all.extend(labels)\n        predictions_all.extend(\n            logits.view(\n                -1,\n            )\n            .cpu()\n            .numpy()\n        )\n\n    return np.array(predictions_all), np.array(labels_all)\n\n\n@click.command()\n@click.option(\n    \"-m\",\n    \"--model-cnf\",\n    type=click.Path(exists=True),\n    help=\"Path of model config yaml.\",\n)\ndef main(model_cnf):\n    yaml = YAML(typ=\"safe\")\n    model_cnf = yaml.load(Path(model_cnf))\n    model_name, model_params, train_params, pretrain_params = (\n        model_cnf[\"name\"],\n        model_cnf[\"model\"],\n        model_cnf[\"train\"],\n        model_cnf[\"pretrain\"],\n    )\n    logger.info(f\"Model name: {model_name}\")\n    logger.info(f\"Model params: {model_params}\")\n    logger.info(f\"Train params: {train_params}\")\n\n    if model_params[\"targets\"] in [\"mu\", \"homo\", \"lumo\", \"gap\", \"zpve\"]:\n        model_params[\"output_init\"] = nn.init.zeros_\n    else:\n        # 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv\n        model_params[\"output_init\"] = GlorotOrthogonal\n\n    logger.info(\"Loading Data Set\")\n    dataset = QM9(label_keys=model_params[\"targets\"], edge_funcs=[edge_init])\n\n    # data split\n    train_data, valid_data, test_data = split_dataset(\n        dataset,\n        num_train=train_params[\"num_train\"],\n        
num_valid=train_params[\"num_valid\"],\n        shuffle=True,\n        random_state=train_params[\"data_seed\"],\n    )\n    logger.info(f\"Size of Training Set: {len(train_data)}\")\n    logger.info(f\"Size of Validation Set: {len(valid_data)}\")\n    logger.info(f\"Size of Test Set: {len(test_data)}\")\n\n    # data loader\n    train_loader = DataLoader(\n        train_data,\n        batch_size=train_params[\"batch_size\"],\n        shuffle=True,\n        collate_fn=_collate_fn,\n        num_workers=train_params[\"num_workers\"],\n    )\n\n    valid_loader = DataLoader(\n        valid_data,\n        batch_size=train_params[\"batch_size\"],\n        shuffle=False,\n        collate_fn=_collate_fn,\n        num_workers=train_params[\"num_workers\"],\n    )\n\n    test_loader = DataLoader(\n        test_data,\n        batch_size=train_params[\"batch_size\"],\n        shuffle=False,\n        collate_fn=_collate_fn,\n        num_workers=train_params[\"num_workers\"],\n    )\n\n    # check cuda\n    gpu = train_params[\"gpu\"]\n    device = f\"cuda:{gpu}\" if gpu >= 0 and torch.cuda.is_available() else \"cpu\"\n\n    # model initialization\n    logger.info(\"Loading Model\")\n    if model_name == \"dimenet\":\n        model = DimeNet(\n            emb_size=model_params[\"emb_size\"],\n            num_blocks=model_params[\"num_blocks\"],\n            num_bilinear=model_params[\"num_bilinear\"],\n            num_spherical=model_params[\"num_spherical\"],\n            num_radial=model_params[\"num_radial\"],\n            cutoff=model_params[\"cutoff\"],\n            envelope_exponent=model_params[\"envelope_exponent\"],\n            num_before_skip=model_params[\"num_before_skip\"],\n            num_after_skip=model_params[\"num_after_skip\"],\n            num_dense_output=model_params[\"num_dense_output\"],\n            num_targets=len(model_params[\"targets\"]),\n            output_init=model_params[\"output_init\"],\n        ).to(device)\n    elif model_name == 
\"dimenet++\":\n        model = DimeNetPP(\n            emb_size=model_params[\"emb_size\"],\n            out_emb_size=model_params[\"out_emb_size\"],\n            int_emb_size=model_params[\"int_emb_size\"],\n            basis_emb_size=model_params[\"basis_emb_size\"],\n            num_blocks=model_params[\"num_blocks\"],\n            num_spherical=model_params[\"num_spherical\"],\n            num_radial=model_params[\"num_radial\"],\n            cutoff=model_params[\"cutoff\"],\n            envelope_exponent=model_params[\"envelope_exponent\"],\n            num_before_skip=model_params[\"num_before_skip\"],\n            num_after_skip=model_params[\"num_after_skip\"],\n            num_dense_output=model_params[\"num_dense_output\"],\n            num_targets=len(model_params[\"targets\"]),\n            extensive=model_params[\"extensive\"],\n            output_init=model_params[\"output_init\"],\n        ).to(device)\n    else:\n        raise ValueError(f\"Invalid Model Name {model_name}\")\n\n    if pretrain_params[\"flag\"]:\n        torch_path = pretrain_params[\"path\"]\n        target = model_params[\"targets\"][0]\n        model.load_state_dict(\n            torch.load(f\"{torch_path}/{target}.pt\", weights_only=False)\n        )\n\n        logger.info(\"Testing with Pretrained model\")\n        predictions, labels = evaluate(device, model, test_loader)\n        test_mae = mean_absolute_error(labels, predictions)\n        logger.info(f\"Test MAE {test_mae:.4f}\")\n\n        return\n    # define loss function and optimization\n    loss_fn = nn.L1Loss()\n    opt = optim.Adam(\n        model.parameters(),\n        lr=train_params[\"lr\"],\n        weight_decay=train_params[\"weight_decay\"],\n        amsgrad=True,\n    )\n    scheduler = optim.lr_scheduler.StepLR(\n        opt, train_params[\"step_size\"], gamma=train_params[\"gamma\"]\n    )\n\n    # model training\n    best_mae = 1e9\n    no_improvement = 0\n\n    # EMA for valid and test\n    
logger.info(\"EMA Init\")\n    ema_model = copy.deepcopy(model)\n    for p in ema_model.parameters():\n        p.requires_grad_(False)\n\n    best_model = copy.deepcopy(ema_model)\n\n    logger.info(\"Training\")\n    for i in range(train_params[\"epochs\"]):\n        train_loss = train(device, model, opt, loss_fn, train_loader)\n        ema(ema_model, model, train_params[\"ema_decay\"])\n        if i % train_params[\"interval\"] == 0:\n            predictions, labels = evaluate(device, ema_model, valid_loader)\n\n            valid_mae = mean_absolute_error(labels, predictions)\n            logger.info(\n                f\"Epoch {i} | Train Loss {train_loss:.4f} | Val MAE {valid_mae:.4f}\"\n            )\n\n            if valid_mae > best_mae:\n                no_improvement += 1\n                if no_improvement == train_params[\"early_stopping\"]:\n                    logger.info(\"Early stop.\")\n                    break\n            else:\n                no_improvement = 0\n                best_mae = valid_mae\n                best_model = copy.deepcopy(ema_model)\n        else:\n            logger.info(f\"Epoch {i} | Train Loss {train_loss:.4f}\")\n\n        scheduler.step()\n\n    logger.info(\"Testing\")\n    predictions, labels = evaluate(device, best_model, test_loader)\n    test_mae = mean_absolute_error(labels, predictions)\n    logger.info(\"Test MAE {:.4f}\".format(test_mae))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/activations.py",
    "content": "import torch\n\n\ndef swish(x):\n    \"\"\"\n    Swish activation function,\n    from Ramachandran, Zopf, Le 2017. \"Searching for Activation Functions\"\n    \"\"\"\n    return x * torch.sigmoid(x)\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/basis_utils.py",
    "content": "import numpy as np\nimport sympy as sym\nfrom scipy import special as sp\nfrom scipy.optimize import brentq\n\n\ndef Jn(r, n):\n    \"\"\"\n    r: int or list\n    n: int or list\n    len(r) == len(n)\n    return value should be the same shape as the input data\n    ===\n    example:\n        r = n = np.array([1, 2, 3, 4])\n        res = [0.3, 0.1, 0.1, 0.1]\n    ===\n    numerical spherical bessel functions of order n\n    \"\"\"\n    return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r)  # the same shape as n\n\n\ndef Jn_zeros(n, k):\n    \"\"\"\n    n: int\n    k: int\n    res: array of shape [n, k]\n\n    Compute the first k zeros of the spherical bessel functions up to order n (excluded)\n    \"\"\"\n    zerosj = np.zeros((n, k), dtype=\"float32\")\n    zerosj[0] = np.arange(1, k + 1) * np.pi\n    points = np.arange(1, k + n) * np.pi\n    racines = np.zeros(k + n - 1, dtype=\"float32\")\n    for i in range(1, n):\n        for j in range(k + n - 1 - i):\n            foo = brentq(Jn, points[j], points[j + 1], (i,))\n            racines[j] = foo\n        points = racines\n        zerosj[i][:k] = racines[:k]\n\n    return zerosj\n\n\ndef spherical_bessel_formulas(n):\n    \"\"\"\n    n: int\n    res: array of shape [n,]\n\n    n sympy functions\n    Computes the sympy formulas for the spherical bessel functions up to order n (excluded)\n    \"\"\"\n    x = sym.symbols(\"x\")\n\n    f = [sym.sin(x) / x]\n    a = sym.sin(x) / x\n    for i in range(1, n):\n        b = sym.diff(a, x) / x\n        f += [sym.simplify(b * (-x) ** i)]\n        a = sym.simplify(b)\n    return f\n\n\ndef bessel_basis(n, k):\n    \"\"\"\n    n: int\n    k: int\n    res: [n, k]\n\n    n * k sympy functions\n    Computes the sympy formulas for the normalized and rescaled spherical bessel functions up to\n    order n (excluded) and maximum frequency k (excluded).\n    \"\"\"\n\n    zeros = Jn_zeros(n, k)\n    normalizer = []\n    for order in range(n):\n        normalizer_tmp = 
[]\n        for i in range(k):\n            normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1) ** 2]\n        normalizer_tmp = 1 / np.array(normalizer_tmp) ** 0.5\n        normalizer += [normalizer_tmp]\n\n    f = spherical_bessel_formulas(n)\n    x = sym.symbols(\"x\")\n    bess_basis = []\n    for order in range(n):\n        bess_basis_tmp = []\n        for i in range(k):\n            bess_basis_tmp += [\n                sym.simplify(\n                    normalizer[order][i] * f[order].subs(x, zeros[order, i] * x)\n                )\n            ]\n        bess_basis += [bess_basis_tmp]\n    return bess_basis\n\n\ndef sph_harm_prefactor(l, m):\n    \"\"\"\n    l: int\n    m: int\n    res: float\n    Computes the constant pre-factor for the spherical harmonic of degree l and order m\n    input:\n    l: int, l>=0\n    m: int, -l<=m<=l\n    \"\"\"\n    return (\n        (2 * l + 1)\n        * np.math.factorial(l - abs(m))\n        / (4 * np.pi * np.math.factorial(l + abs(m)))\n    ) ** 0.5\n\n\ndef associated_legendre_polynomials(l, zero_m_only=True):\n    \"\"\"\n    l: int\n    return: l sympy functions\n    Computes sympy formulas of the associated legendre polynomials up to order l (excluded).\n    \"\"\"\n    z = sym.symbols(\"z\")\n    P_l_m = [[0] * (j + 1) for j in range(l)]\n\n    P_l_m[0][0] = 1\n\n    if l > 0:\n        P_l_m[1][0] = z\n\n        for j in range(2, l):\n            P_l_m[j][0] = sym.simplify(\n                ((2 * j - 1) * z * P_l_m[j - 1][0] - (j - 1) * P_l_m[j - 2][0])\n                / j\n            )\n\n        if not zero_m_only:\n            for i in range(1, l):\n                P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1])\n                if i + 1 < l:\n                    P_l_m[i + 1][i] = sym.simplify(\n                        (2 * i + 1) * z * P_l_m[i][i]\n                    )\n                for j in range(i + 2, l):\n                    P_l_m[j][i] = sym.simplify(\n                        (\n        
                    (2 * j - 1) * z * P_l_m[j - 1][i]\n                            - (i + j - 1) * P_l_m[j - 2][i]\n                        )\n                        / (j - i)\n                    )\n\n    return P_l_m\n\n\ndef real_sph_harm(l, zero_m_only=True, spherical_coordinates=True):\n    \"\"\"\n    return: a sympy function list of length l, for i-th index of the list, it is also a list of length (2 * i + 1)\n    Computes formula strings of the real part of the spherical harmonics up to order l (excluded).\n    Variables are either cartesian coordinates x,y,z on the unit sphere or spherical coordinates phi and theta.\n    \"\"\"\n    if not zero_m_only:\n        S_m = [0]\n        C_m = [1]\n        for i in range(1, l):\n            x = sym.symbols(\"x\")\n            y = sym.symbols(\"y\")\n            S_m += [x * S_m[i - 1] + y * C_m[i - 1]]\n            C_m += [x * C_m[i - 1] - y * S_m[i - 1]]\n\n    P_l_m = associated_legendre_polynomials(l, zero_m_only)\n\n    if spherical_coordinates:\n        theta = sym.symbols(\"theta\")\n        z = sym.symbols(\"z\")\n\n        for i in range(len(P_l_m)):\n            for j in range(len(P_l_m[i])):\n                if type(P_l_m[i][j]) != int:\n                    P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta))\n\n        if not zero_m_only:\n            phi = sym.symbols(\"phi\")\n            for i in range(len(S_m)):\n                S_m[i] = (\n                    S_m[i]\n                    .subs(x, sym.sin(theta) * sym.cos(phi))\n                    .subs(y, sym.sin(theta) * sym.sin(phi))\n                )\n            for i in range(len(C_m)):\n                C_m[i] = (\n                    C_m[i]\n                    .subs(x, sym.sin(theta) * sym.cos(phi))\n                    .subs(y, sym.sin(theta) * sym.sin(phi))\n                )\n\n    Y_func_l_m = [[\"0\"] * (2 * j + 1) for j in range(l)]\n\n    for i in range(l):\n        Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * 
P_l_m[i][0])\n\n    if not zero_m_only:\n        for i in range(1, l):\n            for j in range(1, i + 1):\n                Y_func_l_m[i][j] = sym.simplify(\n                    2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]\n                )\n        for i in range(1, l):\n            for j in range(1, i + 1):\n                Y_func_l_m[i][-j] = sym.simplify(\n                    2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]\n                )\n\n    return Y_func_l_m\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/bessel_basis_layer.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom modules.envelope import Envelope\n\n\nclass BesselBasisLayer(nn.Module):\n    def __init__(self, num_radial, cutoff, envelope_exponent=5):\n        super(BesselBasisLayer, self).__init__()\n\n        self.cutoff = cutoff\n        self.envelope = Envelope(envelope_exponent)\n        self.frequencies = nn.Parameter(torch.Tensor(num_radial))\n        self.reset_params()\n\n    def reset_params(self):\n        with torch.no_grad():\n            torch.arange(\n                1, self.frequencies.numel() + 1, out=self.frequencies\n            ).mul_(np.pi)\n        self.frequencies.requires_grad_()\n\n    def forward(self, g):\n        d_scaled = g.edata[\"d\"] / self.cutoff\n        # Necessary for proper broadcasting behaviour\n        d_scaled = torch.unsqueeze(d_scaled, -1)\n        d_cutoff = self.envelope(d_scaled)\n        g.edata[\"rbf\"] = d_cutoff * torch.sin(self.frequencies * d_scaled)\n        return g\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/dimenet.py",
    "content": "import torch\nimport torch.nn as nn\nfrom modules.activations import swish\nfrom modules.bessel_basis_layer import BesselBasisLayer\nfrom modules.embedding_block import EmbeddingBlock\nfrom modules.interaction_block import InteractionBlock\nfrom modules.output_block import OutputBlock\nfrom modules.spherical_basis_layer import SphericalBasisLayer\n\n\nclass DimeNet(nn.Module):\n    \"\"\"\n    DimeNet model.\n\n    Parameters\n    ----------\n    emb_size\n        Embedding size used throughout the model\n    num_blocks\n        Number of building blocks to be stacked\n    num_bilinear\n        Third dimension of the bilinear layer tensor\n    num_spherical\n        Number of spherical harmonics\n    num_radial\n        Number of radial basis functions\n    cutoff\n        Cutoff distance for interatomic interactions\n    envelope_exponent\n        Shape of the smooth cutoff\n    num_before_skip\n        Number of residual layers in interaction block before skip connection\n    num_after_skip\n        Number of residual layers in interaction block after skip connection\n    num_dense_output\n        Number of dense layers for the output blocks\n    num_targets\n        Number of targets to predict\n    activation\n        Activation function\n    output_init\n        Initial function in output block\n    \"\"\"\n\n    def __init__(\n        self,\n        emb_size,\n        num_blocks,\n        num_bilinear,\n        num_spherical,\n        num_radial,\n        cutoff=5.0,\n        envelope_exponent=5,\n        num_before_skip=1,\n        num_after_skip=2,\n        num_dense_output=3,\n        num_targets=12,\n        activation=swish,\n        output_init=nn.init.zeros_,\n    ):\n        super(DimeNet, self).__init__()\n\n        self.num_blocks = num_blocks\n        self.num_radial = num_radial\n\n        # cosine basis function expansion layer\n        self.rbf_layer = BesselBasisLayer(\n            num_radial=num_radial,\n            
cutoff=cutoff,\n            envelope_exponent=envelope_exponent,\n        )\n\n        self.sbf_layer = SphericalBasisLayer(\n            num_spherical=num_spherical,\n            num_radial=num_radial,\n            cutoff=cutoff,\n            envelope_exponent=envelope_exponent,\n        )\n\n        # embedding block\n        self.emb_block = EmbeddingBlock(\n            emb_size=emb_size,\n            num_radial=num_radial,\n            bessel_funcs=self.sbf_layer.get_bessel_funcs(),\n            cutoff=cutoff,\n            envelope_exponent=envelope_exponent,\n            activation=activation,\n        )\n\n        # output block\n        self.output_blocks = nn.ModuleList(\n            {\n                OutputBlock(\n                    emb_size=emb_size,\n                    num_radial=num_radial,\n                    num_dense=num_dense_output,\n                    num_targets=num_targets,\n                    activation=activation,\n                    output_init=output_init,\n                )\n                for _ in range(num_blocks + 1)\n            }\n        )\n\n        # interaction block\n        self.interaction_blocks = nn.ModuleList(\n            {\n                InteractionBlock(\n                    emb_size=emb_size,\n                    num_radial=num_radial,\n                    num_spherical=num_spherical,\n                    num_bilinear=num_bilinear,\n                    num_before_skip=num_before_skip,\n                    num_after_skip=num_after_skip,\n                    activation=activation,\n                )\n                for _ in range(num_blocks)\n            }\n        )\n\n    def edge_init(self, edges):\n        # Calculate angles k -> j -> i\n        R1, R2 = edges.src[\"o\"], edges.dst[\"o\"]\n        x = torch.sum(R1 * R2, dim=-1)\n        y = torch.cross(R1, R2)\n        y = torch.norm(y, dim=-1)\n        angle = torch.atan2(y, x)\n        # Transform via angles\n        cbf = [f(angle) for f in 
self.sbf_layer.get_sph_funcs()]\n        cbf = torch.stack(cbf, dim=1)  # [None, 7]\n        cbf = cbf.repeat_interleave(self.num_radial, dim=1)  # [None, 42]\n        sbf = edges.src[\"rbf_env\"] * cbf  # [None, 42]\n        return {\"sbf\": sbf}\n\n    def forward(self, g, l_g):\n        # add rbf features for each edge in one batch graph, [num_radial,]\n        g = self.rbf_layer(g)\n        # Embedding block\n        g = self.emb_block(g)\n        # Output block\n        P = self.output_blocks[0](g)  # [batch_size, num_targets]\n        # Prepare sbf feature before the following blocks\n        for k, v in g.edata.items():\n            l_g.ndata[k] = v\n\n        l_g.apply_edges(self.edge_init)\n        # Interaction blocks\n        for i in range(self.num_blocks):\n            g = self.interaction_blocks[i](g, l_g)\n            P += self.output_blocks[i + 1](g)\n\n        return P\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/dimenet_pp.py",
    "content": "import torch\nimport torch.nn as nn\nfrom modules.activations import swish\nfrom modules.bessel_basis_layer import BesselBasisLayer\nfrom modules.embedding_block import EmbeddingBlock\nfrom modules.interaction_pp_block import InteractionPPBlock\nfrom modules.output_pp_block import OutputPPBlock\nfrom modules.spherical_basis_layer import SphericalBasisLayer\n\n\nclass DimeNetPP(nn.Module):\n    \"\"\"\n    DimeNet++ model.\n\n    Parameters\n    ----------\n    emb_size\n        Embedding size used for the messages\n    out_emb_size\n        Embedding size used for atoms in the output block\n    int_emb_size\n        Embedding size used for interaction triplets\n    basis_emb_size\n        Embedding size used inside the basis transformation\n    num_blocks\n        Number of building blocks to be stacked\n    num_spherical\n        Number of spherical harmonics\n    num_radial\n        Number of radial basis functions\n    cutoff\n        Cutoff distance for interatomic interactions\n    envelope_exponent\n        Shape of the smooth cutoff\n    num_before_skip\n        Number of residual layers in interaction block before skip connection\n    num_after_skip\n        Number of residual layers in interaction block after skip connection\n    num_dense_output\n        Number of dense layers for the output blocks\n    num_targets\n        Number of targets to predict\n    activation\n        Activation function\n    extensive\n        Whether the output should be extensive (proportional to the number of atoms)\n    output_init\n        Initial function in output block\n    \"\"\"\n\n    def __init__(\n        self,\n        emb_size,\n        out_emb_size,\n        int_emb_size,\n        basis_emb_size,\n        num_blocks,\n        num_spherical,\n        num_radial,\n        cutoff=5.0,\n        envelope_exponent=5,\n        num_before_skip=1,\n        num_after_skip=2,\n        num_dense_output=3,\n        num_targets=12,\n        activation=swish,\n  
      extensive=True,\n        output_init=nn.init.zeros_,\n    ):\n        super(DimeNetPP, self).__init__()\n\n        self.num_blocks = num_blocks\n        self.num_radial = num_radial\n\n        # cosine basis function expansion layer\n        self.rbf_layer = BesselBasisLayer(\n            num_radial=num_radial,\n            cutoff=cutoff,\n            envelope_exponent=envelope_exponent,\n        )\n\n        self.sbf_layer = SphericalBasisLayer(\n            num_spherical=num_spherical,\n            num_radial=num_radial,\n            cutoff=cutoff,\n            envelope_exponent=envelope_exponent,\n        )\n\n        # embedding block\n        self.emb_block = EmbeddingBlock(\n            emb_size=emb_size,\n            num_radial=num_radial,\n            bessel_funcs=self.sbf_layer.get_bessel_funcs(),\n            cutoff=cutoff,\n            envelope_exponent=envelope_exponent,\n            activation=activation,\n        )\n\n        # output block\n        self.output_blocks = nn.ModuleList(\n            {\n                OutputPPBlock(\n                    emb_size=emb_size,\n                    out_emb_size=out_emb_size,\n                    num_radial=num_radial,\n                    num_dense=num_dense_output,\n                    num_targets=num_targets,\n                    activation=activation,\n                    extensive=extensive,\n                    output_init=output_init,\n                )\n                for _ in range(num_blocks + 1)\n            }\n        )\n\n        # interaction block\n        self.interaction_blocks = nn.ModuleList(\n            {\n                InteractionPPBlock(\n                    emb_size=emb_size,\n                    int_emb_size=int_emb_size,\n                    basis_emb_size=basis_emb_size,\n                    num_radial=num_radial,\n                    num_spherical=num_spherical,\n                    num_before_skip=num_before_skip,\n                    num_after_skip=num_after_skip,\n       
             activation=activation,\n                )\n                for _ in range(num_blocks)\n            }\n        )\n\n    def edge_init(self, edges):\n        # Calculate angles k -> j -> i\n        R1, R2 = edges.src[\"o\"], edges.dst[\"o\"]\n        x = torch.sum(R1 * R2, dim=-1)\n        y = torch.cross(R1, R2)\n        y = torch.norm(y, dim=-1)\n        angle = torch.atan2(y, x)\n        # Transform via angles\n        cbf = [f(angle) for f in self.sbf_layer.get_sph_funcs()]\n        cbf = torch.stack(cbf, dim=1)  # [None, 7]\n        cbf = cbf.repeat_interleave(self.num_radial, dim=1)  # [None, 42]\n        # Notice: it's dst, not src\n        sbf = edges.dst[\"rbf_env\"] * cbf  # [None, 42]\n        return {\"sbf\": sbf}\n\n    def forward(self, g, l_g):\n        # add rbf features for each edge in one batch graph, [num_radial,]\n        g = self.rbf_layer(g)\n        # Embedding block\n        g = self.emb_block(g)\n        # Output block\n        P = self.output_blocks[0](g)  # [batch_size, num_targets]\n        # Prepare sbf feature before the following blocks\n        for k, v in g.edata.items():\n            l_g.ndata[k] = v\n\n        l_g.apply_edges(self.edge_init)\n        # Interaction blocks\n        for i in range(self.num_blocks):\n            g = self.interaction_blocks[i](g, l_g)\n            P += self.output_blocks[i + 1](g)\n\n        return P\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/embedding_block.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom modules.envelope import Envelope\nfrom modules.initializers import GlorotOrthogonal\n\n\nclass EmbeddingBlock(nn.Module):\n    def __init__(\n        self,\n        emb_size,\n        num_radial,\n        bessel_funcs,\n        cutoff,\n        envelope_exponent,\n        num_atom_types=95,\n        activation=None,\n    ):\n        super(EmbeddingBlock, self).__init__()\n\n        self.bessel_funcs = bessel_funcs\n        self.cutoff = cutoff\n        self.activation = activation\n        self.envelope = Envelope(envelope_exponent)\n        self.embedding = nn.Embedding(num_atom_types, emb_size)\n        self.dense_rbf = nn.Linear(num_radial, emb_size)\n        self.dense = nn.Linear(emb_size * 3, emb_size)\n        self.reset_params()\n\n    def reset_params(self):\n        nn.init.uniform_(self.embedding.weight, a=-np.sqrt(3), b=np.sqrt(3))\n        GlorotOrthogonal(self.dense_rbf.weight)\n        GlorotOrthogonal(self.dense.weight)\n\n    def edge_init(self, edges):\n        \"\"\"msg emb init\"\"\"\n        # m init\n        rbf = self.dense_rbf(edges.data[\"rbf\"])\n        if self.activation is not None:\n            rbf = self.activation(rbf)\n\n        m = torch.cat([edges.src[\"h\"], edges.dst[\"h\"], rbf], dim=-1)\n        m = self.dense(m)\n        if self.activation is not None:\n            m = self.activation(m)\n\n        # rbf_env init\n        d_scaled = edges.data[\"d\"] / self.cutoff\n        rbf_env = [f(d_scaled) for f in self.bessel_funcs]\n        rbf_env = torch.stack(rbf_env, dim=1)\n\n        d_cutoff = self.envelope(d_scaled)\n        rbf_env = d_cutoff[:, None] * rbf_env\n\n        return {\"m\": m, \"rbf_env\": rbf_env}\n\n    def forward(self, g):\n        g.ndata[\"h\"] = self.embedding(g.ndata[\"Z\"])\n        g.apply_edges(self.edge_init)\n        return g\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/envelope.py",
    "content": "import torch.nn as nn\n\n\nclass Envelope(nn.Module):\n    \"\"\"\n    Envelope function that ensures a smooth cutoff\n    \"\"\"\n\n    def __init__(self, exponent):\n        super(Envelope, self).__init__()\n\n        self.p = exponent + 1\n        self.a = -(self.p + 1) * (self.p + 2) / 2\n        self.b = self.p * (self.p + 2)\n        self.c = -self.p * (self.p + 1) / 2\n\n    def forward(self, x):\n        # Envelope function divided by r\n        x_p_0 = x.pow(self.p - 1)\n        x_p_1 = x_p_0 * x\n        x_p_2 = x_p_1 * x\n        env_val = 1 / x + self.a * x_p_0 + self.b * x_p_1 + self.c * x_p_2\n        return env_val\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/initializers.py",
    "content": "import torch.nn as nn\n\n\ndef GlorotOrthogonal(tensor, scale=2.0):\n    if tensor is not None:\n        nn.init.orthogonal_(tensor.data)\n        scale /= (tensor.size(-2) + tensor.size(-1)) * tensor.var()\n        tensor.data *= scale.sqrt()\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/interaction_block.py",
    "content": "import dgl.function as fn\nimport torch\nimport torch.nn as nn\nfrom modules.initializers import GlorotOrthogonal\nfrom modules.residual_layer import ResidualLayer\n\n\nclass InteractionBlock(nn.Module):\n    def __init__(\n        self,\n        emb_size,\n        num_radial,\n        num_spherical,\n        num_bilinear,\n        num_before_skip,\n        num_after_skip,\n        activation=None,\n    ):\n        super(InteractionBlock, self).__init__()\n\n        self.activation = activation\n        # Transformations of Bessel and spherical basis representations\n        self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)\n        self.dense_sbf = nn.Linear(\n            num_radial * num_spherical, num_bilinear, bias=False\n        )\n        # Dense transformations of input messages\n        self.dense_ji = nn.Linear(emb_size, emb_size)\n        self.dense_kj = nn.Linear(emb_size, emb_size)\n        # Bilinear layer\n        bilin_initializer = torch.empty(\n            (emb_size, num_bilinear, emb_size)\n        ).normal_(mean=0, std=2 / emb_size)\n        self.W_bilin = nn.Parameter(bilin_initializer)\n        # Residual layers before skip connection\n        self.layers_before_skip = nn.ModuleList(\n            [\n                ResidualLayer(emb_size, activation=activation)\n                for _ in range(num_before_skip)\n            ]\n        )\n        self.final_before_skip = nn.Linear(emb_size, emb_size)\n        # Residual layers after skip connection\n        self.layers_after_skip = nn.ModuleList(\n            [\n                ResidualLayer(emb_size, activation=activation)\n                for _ in range(num_after_skip)\n            ]\n        )\n\n        self.reset_params()\n\n    def reset_params(self):\n        GlorotOrthogonal(self.dense_rbf.weight)\n        GlorotOrthogonal(self.dense_sbf.weight)\n        GlorotOrthogonal(self.dense_ji.weight)\n        GlorotOrthogonal(self.dense_kj.weight)\n        
GlorotOrthogonal(self.final_before_skip.weight)\n\n    def edge_transfer(self, edges):\n        # Transform from Bessel basis to dense vector\n        rbf = self.dense_rbf(edges.data[\"rbf\"])\n        # Initial transformation\n        x_ji = self.dense_ji(edges.data[\"m\"])\n        x_kj = self.dense_kj(edges.data[\"m\"])\n        if self.activation is not None:\n            x_ji = self.activation(x_ji)\n            x_kj = self.activation(x_kj)\n\n        # w: W * e_RBF \\bigodot \\sigma(W * m + b)\n        return {\"x_kj\": x_kj * rbf, \"x_ji\": x_ji}\n\n    def msg_func(self, edges):\n        sbf = self.dense_sbf(edges.data[\"sbf\"])\n        # Apply bilinear layer to interactions and basis function activation\n        # [None, 8] * [128, 8, 128] * [None, 128] -> [None, 128]\n        x_kj = torch.einsum(\n            \"wj,wl,ijl->wi\", sbf, edges.src[\"x_kj\"], self.W_bilin\n        )\n        return {\"x_kj\": x_kj}\n\n    def forward(self, g, l_g):\n        g.apply_edges(self.edge_transfer)\n\n        # nodes correspond to edges and edges correspond to nodes in the original graphs\n        # node: d, rbf, o, rbf_env, x_kj, x_ji\n        for k, v in g.edata.items():\n            l_g.ndata[k] = v\n\n        l_g.update_all(self.msg_func, fn.sum(\"x_kj\", \"m_update\"))\n\n        for k, v in l_g.ndata.items():\n            g.edata[k] = v\n\n        # Transformations before skip connection\n        g.edata[\"m_update\"] = g.edata[\"m_update\"] + g.edata[\"x_ji\"]\n        for layer in self.layers_before_skip:\n            g.edata[\"m_update\"] = layer(g.edata[\"m_update\"])\n        g.edata[\"m_update\"] = self.final_before_skip(g.edata[\"m_update\"])\n        if self.activation is not None:\n            g.edata[\"m_update\"] = self.activation(g.edata[\"m_update\"])\n\n        # Skip connection\n        g.edata[\"m\"] = g.edata[\"m\"] + g.edata[\"m_update\"]\n\n        # Transformations after skip connection\n        for layer in self.layers_after_skip:\n          
  g.edata[\"m\"] = layer(g.edata[\"m\"])\n\n        return g\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/interaction_pp_block.py",
    "content": "import dgl\nimport dgl.function as fn\nimport torch.nn as nn\nfrom modules.initializers import GlorotOrthogonal\nfrom modules.residual_layer import ResidualLayer\n\n\nclass InteractionPPBlock(nn.Module):\n    def __init__(\n        self,\n        emb_size,\n        int_emb_size,\n        basis_emb_size,\n        num_radial,\n        num_spherical,\n        num_before_skip,\n        num_after_skip,\n        activation=None,\n    ):\n        super(InteractionPPBlock, self).__init__()\n\n        self.activation = activation\n        # Transformations of Bessel and spherical basis representations\n        self.dense_rbf1 = nn.Linear(num_radial, basis_emb_size, bias=False)\n        self.dense_rbf2 = nn.Linear(basis_emb_size, emb_size, bias=False)\n        self.dense_sbf1 = nn.Linear(\n            num_radial * num_spherical, basis_emb_size, bias=False\n        )\n        self.dense_sbf2 = nn.Linear(basis_emb_size, int_emb_size, bias=False)\n        # Dense transformations of input messages\n        self.dense_ji = nn.Linear(emb_size, emb_size)\n        self.dense_kj = nn.Linear(emb_size, emb_size)\n        # Embedding projections for interaction triplets\n        self.down_projection = nn.Linear(emb_size, int_emb_size, bias=False)\n        self.up_projection = nn.Linear(int_emb_size, emb_size, bias=False)\n        # Residual layers before skip connection\n        self.layers_before_skip = nn.ModuleList(\n            [\n                ResidualLayer(emb_size, activation=activation)\n                for _ in range(num_before_skip)\n            ]\n        )\n        self.final_before_skip = nn.Linear(emb_size, emb_size)\n        # Residual layers after skip connection\n        self.layers_after_skip = nn.ModuleList(\n            [\n                ResidualLayer(emb_size, activation=activation)\n                for _ in range(num_after_skip)\n            ]\n        )\n\n        self.reset_params()\n\n    def reset_params(self):\n        
GlorotOrthogonal(self.dense_rbf1.weight)\n        GlorotOrthogonal(self.dense_rbf2.weight)\n        GlorotOrthogonal(self.dense_sbf1.weight)\n        GlorotOrthogonal(self.dense_sbf2.weight)\n        GlorotOrthogonal(self.dense_ji.weight)\n        nn.init.zeros_(self.dense_ji.bias)\n        GlorotOrthogonal(self.dense_kj.weight)\n        nn.init.zeros_(self.dense_kj.bias)\n        GlorotOrthogonal(self.down_projection.weight)\n        GlorotOrthogonal(self.up_projection.weight)\n\n    def edge_transfer(self, edges):\n        # Transform from Bessel basis to dense vector\n        rbf = self.dense_rbf1(edges.data[\"rbf\"])\n        rbf = self.dense_rbf2(rbf)\n        # Initial transformation\n        x_ji = self.dense_ji(edges.data[\"m\"])\n        x_kj = self.dense_kj(edges.data[\"m\"])\n        if self.activation is not None:\n            x_ji = self.activation(x_ji)\n            x_kj = self.activation(x_kj)\n\n        x_kj = self.down_projection(x_kj * rbf)\n        if self.activation is not None:\n            x_kj = self.activation(x_kj)\n        return {\"x_kj\": x_kj, \"x_ji\": x_ji}\n\n    def msg_func(self, edges):\n        sbf = self.dense_sbf1(edges.data[\"sbf\"])\n        sbf = self.dense_sbf2(sbf)\n        x_kj = edges.src[\"x_kj\"] * sbf\n        return {\"x_kj\": x_kj}\n\n    def forward(self, g, l_g):\n        g.apply_edges(self.edge_transfer)\n\n        # nodes correspond to edges and edges correspond to nodes in the original graphs\n        # node: d, rbf, o, rbf_env, x_kj, x_ji\n        for k, v in g.edata.items():\n            l_g.ndata[k] = v\n\n        l_g_reverse = dgl.reverse(l_g, copy_edata=True)\n        l_g_reverse.update_all(self.msg_func, fn.sum(\"x_kj\", \"m_update\"))\n\n        g.edata[\"m_update\"] = self.up_projection(l_g_reverse.ndata[\"m_update\"])\n        if self.activation is not None:\n            g.edata[\"m_update\"] = self.activation(g.edata[\"m_update\"])\n        # Transformations before skip connection\n        
g.edata[\"m_update\"] = g.edata[\"m_update\"] + g.edata[\"x_ji\"]\n        for layer in self.layers_before_skip:\n            g.edata[\"m_update\"] = layer(g.edata[\"m_update\"])\n        g.edata[\"m_update\"] = self.final_before_skip(g.edata[\"m_update\"])\n        if self.activation is not None:\n            g.edata[\"m_update\"] = self.activation(g.edata[\"m_update\"])\n\n        # Skip connection\n        g.edata[\"m\"] = g.edata[\"m\"] + g.edata[\"m_update\"]\n\n        # Transformations after skip connection\n        for layer in self.layers_after_skip:\n            g.edata[\"m\"] = layer(g.edata[\"m\"])\n\n        return g\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/output_block.py",
    "content": "import dgl\nimport dgl.function as fn\nimport torch.nn as nn\nfrom modules.initializers import GlorotOrthogonal\n\n\nclass OutputBlock(nn.Module):\n    def __init__(\n        self,\n        emb_size,\n        num_radial,\n        num_dense,\n        num_targets,\n        activation=None,\n        output_init=nn.init.zeros_,\n    ):\n        super(OutputBlock, self).__init__()\n\n        self.activation = activation\n        self.output_init = output_init\n        self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)\n        self.dense_layers = nn.ModuleList(\n            [nn.Linear(emb_size, emb_size) for _ in range(num_dense)]\n        )\n        self.dense_final = nn.Linear(emb_size, num_targets, bias=False)\n        self.reset_params()\n\n    def reset_params(self):\n        GlorotOrthogonal(self.dense_rbf.weight)\n        for layer in self.dense_layers:\n            GlorotOrthogonal(layer.weight)\n        self.output_init(self.dense_final.weight)\n\n    def forward(self, g):\n        with g.local_scope():\n            g.edata[\"tmp\"] = g.edata[\"m\"] * self.dense_rbf(g.edata[\"rbf\"])\n            g.update_all(fn.copy_e(\"tmp\", \"x\"), fn.sum(\"x\", \"t\"))\n\n            for layer in self.dense_layers:\n                g.ndata[\"t\"] = layer(g.ndata[\"t\"])\n                if self.activation is not None:\n                    g.ndata[\"t\"] = self.activation(g.ndata[\"t\"])\n            g.ndata[\"t\"] = self.dense_final(g.ndata[\"t\"])\n            return dgl.readout_nodes(g, \"t\")\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/output_pp_block.py",
    "content": "import dgl\nimport dgl.function as fn\nimport torch.nn as nn\nfrom modules.initializers import GlorotOrthogonal\n\n\nclass OutputPPBlock(nn.Module):\n    def __init__(\n        self,\n        emb_size,\n        out_emb_size,\n        num_radial,\n        num_dense,\n        num_targets,\n        activation=None,\n        output_init=nn.init.zeros_,\n        extensive=True,\n    ):\n        super(OutputPPBlock, self).__init__()\n\n        self.activation = activation\n        self.output_init = output_init\n        self.extensive = extensive\n        self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)\n        self.up_projection = nn.Linear(emb_size, out_emb_size, bias=False)\n        self.dense_layers = nn.ModuleList(\n            [nn.Linear(out_emb_size, out_emb_size) for _ in range(num_dense)]\n        )\n        self.dense_final = nn.Linear(out_emb_size, num_targets, bias=False)\n        self.reset_params()\n\n    def reset_params(self):\n        GlorotOrthogonal(self.dense_rbf.weight)\n        GlorotOrthogonal(self.up_projection.weight)\n        for layer in self.dense_layers:\n            GlorotOrthogonal(layer.weight)\n        self.output_init(self.dense_final.weight)\n\n    def forward(self, g):\n        with g.local_scope():\n            g.edata[\"tmp\"] = g.edata[\"m\"] * self.dense_rbf(g.edata[\"rbf\"])\n            g_reverse = dgl.reverse(g, copy_edata=True)\n            g_reverse.update_all(fn.copy_e(\"tmp\", \"x\"), fn.sum(\"x\", \"t\"))\n            g.ndata[\"t\"] = self.up_projection(g_reverse.ndata[\"t\"])\n\n            for layer in self.dense_layers:\n                g.ndata[\"t\"] = layer(g.ndata[\"t\"])\n                if self.activation is not None:\n                    g.ndata[\"t\"] = self.activation(g.ndata[\"t\"])\n            g.ndata[\"t\"] = self.dense_final(g.ndata[\"t\"])\n            return dgl.readout_nodes(\n                g, \"t\", op=\"sum\" if self.extensive else \"mean\"\n            )\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/residual_layer.py",
    "content": "import torch.nn as nn\nfrom modules.initializers import GlorotOrthogonal\n\n\nclass ResidualLayer(nn.Module):\n    def __init__(self, units, activation=None):\n        super(ResidualLayer, self).__init__()\n\n        self.activation = activation\n        self.dense_1 = nn.Linear(units, units)\n        self.dense_2 = nn.Linear(units, units)\n\n        self.reset_params()\n\n    def reset_params(self):\n        GlorotOrthogonal(self.dense_1.weight)\n        nn.init.zeros_(self.dense_1.bias)\n        GlorotOrthogonal(self.dense_2.weight)\n        nn.init.zeros_(self.dense_2.bias)\n\n    def forward(self, inputs):\n        x = self.dense_1(inputs)\n        if self.activation is not None:\n            x = self.activation(x)\n        x = self.dense_2(x)\n        if self.activation is not None:\n            x = self.activation(x)\n        return inputs + x\n"
  },
  {
    "path": "examples/pytorch/dimenet/modules/spherical_basis_layer.py",
    "content": "import sympy as sym\nimport torch\nimport torch.nn as nn\nfrom modules.basis_utils import bessel_basis, real_sph_harm\nfrom modules.envelope import Envelope\n\n\nclass SphericalBasisLayer(nn.Module):\n    def __init__(self, num_spherical, num_radial, cutoff, envelope_exponent=5):\n        super(SphericalBasisLayer, self).__init__()\n\n        assert num_radial <= 64\n        self.num_radial = num_radial\n        self.num_spherical = num_spherical\n        self.cutoff = cutoff\n        self.envelope = Envelope(envelope_exponent)\n\n        # retrieve formulas\n        self.bessel_formulas = bessel_basis(\n            num_spherical, num_radial\n        )  # x, [num_spherical, num_radial] sympy functions\n        self.sph_harm_formulas = real_sph_harm(\n            num_spherical\n        )  # theta, [num_spherical, ] sympy functions\n        self.sph_funcs = []\n        self.bessel_funcs = []\n\n        # convert to torch functions\n        x = sym.symbols(\"x\")\n        theta = sym.symbols(\"theta\")\n        modules = {\"sin\": torch.sin, \"cos\": torch.cos}\n        for i in range(num_spherical):\n            if i == 0:\n                first_sph = sym.lambdify(\n                    [theta], self.sph_harm_formulas[i][0], modules\n                )(0)\n                self.sph_funcs.append(\n                    lambda tensor: torch.zeros_like(tensor) + first_sph\n                )\n            else:\n                self.sph_funcs.append(\n                    sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)\n                )\n            for j in range(num_radial):\n                self.bessel_funcs.append(\n                    sym.lambdify([x], self.bessel_formulas[i][j], modules)\n                )\n\n    def get_bessel_funcs(self):\n        return self.bessel_funcs\n\n    def get_sph_funcs(self):\n        return self.sph_funcs\n"
  },
  {
    "path": "examples/pytorch/dimenet/qm9.py",
    "content": "\"\"\"QM9 dataset for graph property prediction (regression).\"\"\"\nimport os\n\nimport dgl\n\nimport numpy as np\nimport scipy.sparse as sp\nimport torch\nfrom dgl.convert import graph as dgl_graph\nfrom dgl.data import QM9Dataset\nfrom dgl.data.utils import load_graphs, save_graphs\nfrom tqdm import trange\n\n\nclass QM9(QM9Dataset):\n    r\"\"\"QM9 dataset for graph property prediction (regression)\n\n    This dataset consists of 130,831 molecules with 12 regression targets.\n    Nodes correspond to atoms and edges correspond to bonds.\n\n    Reference:\n\n    - `\"Quantum-Machine.org\" <http://quantum-machine.org/datasets/>`_\n    - `\"Directional Message Passing for Molecular Graphs\" <https://arxiv.org/abs/2003.03123>`_\n\n    Statistics:\n\n    - Number of graphs: 130,831\n    - Number of regression targets: 12\n\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | Keys   | Property                         | Description                                                                       | Unit                                        |\n    +========+==================================+===================================================================================+=============================================+\n    | mu     | :math:`\\mu`                      | Dipole moment                                                                     | :math:`\\textrm{D}`                          |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | alpha  | :math:`\\alpha`                   | Isotropic polarizability                                                          | :math:`{a_0}^3`                             |\n    
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | homo   | :math:`\\epsilon_{\\textrm{HOMO}}` | Highest occupied molecular orbital energy                                         | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | lumo   | :math:`\\epsilon_{\\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy                                        | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | gap    | :math:`\\Delta \\epsilon`          | Gap between :math:`\\epsilon_{\\textrm{HOMO}}` and :math:`\\epsilon_{\\textrm{LUMO}}` | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | r2     | :math:`\\langle R^2 \\rangle`      | Electronic spatial extent                                                         | :math:`{a_0}^2`                             |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | zpve   | :math:`\\textrm{ZPVE}`            | Zero point vibrational energy                                                     | :math:`\\textrm{eV}`                         |\n    
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | U0     | :math:`U_0`                      | Internal energy at 0K                                                             | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | U      | :math:`U`                        | Internal energy at 298.15K                                                        | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | H      | :math:`H`                        | Enthalpy at 298.15K                                                               | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | G      | :math:`G`                        | Free energy at 298.15K                                                            | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | Cv     | :math:`c_{\\textrm{v}}`           | Heat capavity at 298.15K                                                          | :math:`\\frac{\\textrm{cal}}{\\textrm{mol K}}` |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n\n   
 Parameters\n    ----------\n    label_keys: list\n        Names of the regression property, which should be a subset of the keys in the table above.\n    edge_funcs: list\n        A list of edge-wise user-defined functions <https://docs.dgl.ai/en/0.6.x/api/python/udf.html#edge-wise-user-defined-function> for chemical bonds. Default: None\n    cutoff: float\n        Cutoff distance for interatomic interactions, i.e. two atoms are connected in the corresponding graph if the distance between them is no larger than this.\n        Default: 5.0 Angstrom\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose: bool\n        Whether to print out progress information. Default: True\n\n    Attributes\n    ----------\n    num_labels : int\n        Number of labels for each graph, i.e. number of prediction tasks\n\n    Raises\n    ------\n    UserWarning\n        If the raw data is changed in the remote server by the author.\n\n    Examples\n    --------\n    >>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)\n    >>> data.num_classes\n    2\n    >>>\n    >>> # iterate over the dataset\n    >>> for g, label in data:\n    ...     R = g.ndata['R'] # get coordinates of each atom\n    ...     Z = g.ndata['Z'] # get atomic numbers of each atom\n    ...     
# your code here...\n    >>>\n    \"\"\"\n\n    def __init__(\n        self,\n        label_keys,\n        edge_funcs=None,\n        cutoff=5.0,\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n    ):\n        self.edge_funcs = edge_funcs\n        self._keys = [\n            \"mu\",\n            \"alpha\",\n            \"homo\",\n            \"lumo\",\n            \"gap\",\n            \"r2\",\n            \"zpve\",\n            \"U0\",\n            \"U\",\n            \"H\",\n            \"G\",\n            \"Cv\",\n        ]\n\n        super(QM9, self).__init__(\n            label_keys=label_keys,\n            cutoff=cutoff,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n        )\n\n    @property\n    def graph_path(self):\n        return f\"{self.save_path}/dgl_graph.bin\"\n\n    @property\n    def line_graph_path(self):\n        return f\"{self.save_path}/dgl_line_graph.bin\"\n\n    def has_cache(self):\n        \"\"\"step 1, if True, goto step 5; else goto download(step 2), then step 3\"\"\"\n        return os.path.exists(self.graph_path) and os.path.exists(\n            self.line_graph_path\n        )\n\n    def process(self):\n        \"\"\"step 3\"\"\"\n        npz_path = f\"{self.raw_dir}/qm9_eV.npz\"\n        data_dict = np.load(npz_path, allow_pickle=True)\n        # data_dict['N'] contains the number of atoms in each molecule,\n        # data_dict['R'] consists of the atomic coordinates,\n        # data_dict['Z'] consists of the atomic numbers.\n        # Atomic properties (Z and R) of all molecules are concatenated as single tensors,\n        # so you need this value to select the correct atoms for each molecule.\n        self.N = data_dict[\"N\"]\n        self.R = data_dict[\"R\"]\n        self.Z = data_dict[\"Z\"]\n        self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])\n        # graph labels\n        self.label_dict = {}\n        for k in self._keys:\n       
     self.label_dict[k] = torch.tensor(data_dict[k], dtype=torch.float32)\n\n        self.label = torch.stack(\n            [self.label_dict[key] for key in self.label_keys], dim=1\n        )\n        # graphs & features\n        self.graphs, self.line_graphs = self._load_graph()\n\n    def _load_graph(self):\n        num_graphs = self.label.shape[0]\n        graphs = []\n        line_graphs = []\n\n        for idx in trange(num_graphs):\n            n_atoms = self.N[idx]\n            # get all the atomic coordinates of the idx-th molecular graph\n            R = self.R[self.N_cumsum[idx] : self.N_cumsum[idx + 1]]\n            # calculate the distance between all atoms\n            dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)\n            # keep all edges that don't exceed the cutoff and delete self-loops\n            adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(\n                n_atoms, dtype=np.bool_\n            )\n            adj = adj.tocoo()\n            u, v = torch.tensor(adj.row), torch.tensor(adj.col)\n            g = dgl_graph((u, v))\n            g.ndata[\"R\"] = torch.tensor(R, dtype=torch.float32)\n            g.ndata[\"Z\"] = torch.tensor(\n                self.Z[self.N_cumsum[idx] : self.N_cumsum[idx + 1]],\n                dtype=torch.long,\n            )\n\n            # add user-defined features\n            if self.edge_funcs is not None:\n                for func in self.edge_funcs:\n                    g.apply_edges(func)\n\n            graphs.append(g)\n            l_g = dgl.line_graph(g, backtracking=False)\n            line_graphs.append(l_g)\n\n        return graphs, line_graphs\n\n    def save(self):\n        \"\"\"step 4\"\"\"\n        save_graphs(str(self.graph_path), self.graphs, self.label_dict)\n        save_graphs(str(self.line_graph_path), self.line_graphs)\n\n    def load(self):\n        \"\"\"step 5\"\"\"\n        self.graphs, label_dict = load_graphs(self.graph_path)\n        self.line_graphs, _ = 
load_graphs(self.line_graph_path)\n        self.label = torch.stack(\n            [label_dict[key] for key in self.label_keys], dim=1\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph, line graph and label by index\n\n        Parameters\n        ----------\n        idx : int\n            Item index\n\n        Returns\n        -------\n        dgl.DGLGraph\n            The graph contains:\n            - ``ndata['R']``: the coordinates of each atom\n            - ``ndata['Z']``: the atomic number\n        dgl.DGLGraph\n            The line graph of the molecular graph\n        Tensor\n            Property values of molecular graphs\n        \"\"\"\n        return self.graphs[idx], self.line_graphs[idx], self.label[idx]\n"
  },
  {
    "path": "examples/pytorch/dtgrnn/README.md",
    "content": "# Discrete Temporal Dynamic Graph with recurrent structure\n## DGL Implementation of DCRNN and GaAN paper.\n\nThis DGL example implements the GNN model proposed in the paper [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926) and [GaAN:Gated Attention Networks for Learning on Large and Spatiotemporal Graphs](https://arxiv.org/pdf/1803.07294). \n\nModel implementor\n----------------------\nThis example was implemented by [Ericcsr](https://github.com/Ericcsr) during his Internship work at the AWS Shanghai AI Lab.\n\nThe graph dataset used in this example \n---------------------------------------\nMETR-LA dataset. Dataset summary:\n- NumNodes: 207\n- NumEdges: 1722\n- NumFeats: 2\n- TrainingSamples: 70%\n- ValidationSamples: 20%\n- TestSamples: 10%\n\nPEMS-BAY dataset. Dataset Summary:\n\n- NumNodes: 325\n- NumEdges: 2694\n- NumFeats: 2\n- TrainingSamples: 70%\n- ValidationSamples: 20%\n- TestSamples: 10%\n\nHow to run example files\n--------------------------------\nIn the dtdg folder, run\n\n**Please use `train.py`**\n\nTrain the DCRNN model on METR-LA Dataset\n\n```python\npython train.py --dataset LA --model dcrnn\n```\n\nIf want to use a GPU, run\n\n```python\npython train.py --gpu 0 --dataset LA --model dcrnn\n```\n\nif you want to use PEMS-BAY dataset\n\n```python\npython train.py --gpu 0 --dataset BAY --model dcrnn\n```\n\nTrain GaAN model\n\n```python\npython train.py --gpu 0 --model gaan --dataset <LA/BAY>\n```\n\n\nPerformance on METR-LA\n-------------------------\n| Models/Datasets | Test MAE |\n| :-------------- | --------:|\n| DCRNN in DGL    | 2.91 |\n| DCRNN paper     | 3.17 |\n| GaAN in DGL     | 3.20 |\n| GaAN paper      | 3.16 |\n\n\nNotice that Any Graph Convolution module can be plugged into the recurrent discrete temporal dynamic graph template to test performance; simply replace DiffConv or GaAN.\n\n"
  },
  {
    "path": "examples/pytorch/dtgrnn/dataloading.py",
    "content": "import os\nimport ssl\n\nimport dgl\n\nimport numpy as np\nimport torch\nfrom six.moves import urllib\nfrom torch.utils.data import DataLoader, Dataset\n\n\ndef download_file(dataset):\n    print(\"Start Downloading data: {}\".format(dataset))\n    url = \"https://s3.us-west-2.amazonaws.com/dgl-data/dataset/{}\".format(\n        dataset\n    )\n    print(\"Start Downloading File....\")\n    context = ssl._create_unverified_context()\n    data = urllib.request.urlopen(url, context=context)\n    with open(\"./data/{}\".format(dataset), \"wb\") as handle:\n        handle.write(data.read())\n\n\nclass SnapShotDataset(Dataset):\n    def __init__(self, path, npz_file):\n        if not os.path.exists(path + \"/\" + npz_file):\n            if not os.path.exists(path):\n                os.mkdir(path)\n            download_file(npz_file)\n        zipfile = np.load(path + \"/\" + npz_file)\n        self.x = zipfile[\"x\"]\n        self.y = zipfile[\"y\"]\n\n    def __len__(self):\n        return len(self.x)\n\n    def __getitem__(self, idx):\n        if torch.is_tensor(idx):\n            idx = idx.tolist()\n\n        return self.x[idx, ...], self.y[idx, ...]\n\n\ndef METR_LAGraphDataset():\n    if not os.path.exists(\"data/graph_la.bin\"):\n        if not os.path.exists(\"data\"):\n            os.mkdir(\"data\")\n        download_file(\"graph_la.bin\")\n    g, _ = dgl.load_graphs(\"data/graph_la.bin\")\n    return g[0]\n\n\nclass METR_LATrainDataset(SnapShotDataset):\n    def __init__(self):\n        super(METR_LATrainDataset, self).__init__(\"data\", \"metr_la_train.npz\")\n        self.mean = self.x[..., 0].mean()\n        self.std = self.x[..., 0].std()\n\n\nclass METR_LATestDataset(SnapShotDataset):\n    def __init__(self):\n        super(METR_LATestDataset, self).__init__(\"data\", \"metr_la_test.npz\")\n\n\nclass METR_LAValidDataset(SnapShotDataset):\n    def __init__(self):\n        super(METR_LAValidDataset, self).__init__(\"data\", 
\"metr_la_valid.npz\")\n\n\ndef PEMS_BAYGraphDataset():\n    if not os.path.exists(\"data/graph_bay.bin\"):\n        if not os.path.exists(\"data\"):\n            os.mkdir(\"data\")\n        download_file(\"graph_bay.bin\")\n    g, _ = dgl.load_graphs(\"data/graph_bay.bin\")\n    return g[0]\n\n\nclass PEMS_BAYTrainDataset(SnapShotDataset):\n    def __init__(self):\n        super(PEMS_BAYTrainDataset, self).__init__(\"data\", \"pems_bay_train.npz\")\n        self.mean = self.x[..., 0].mean()\n        self.std = self.x[..., 0].std()\n\n\nclass PEMS_BAYTestDataset(SnapShotDataset):\n    def __init__(self):\n        super(PEMS_BAYTestDataset, self).__init__(\"data\", \"pems_bay_test.npz\")\n\n\nclass PEMS_BAYValidDataset(SnapShotDataset):\n    def __init__(self):\n        super(PEMS_BAYValidDataset, self).__init__(\"data\", \"pems_bay_valid.npz\")\n"
  },
  {
    "path": "examples/pytorch/dtgrnn/dcrnn.py",
    "content": "import dgl\nimport dgl.function as fn\nimport numpy as np\nimport scipy.sparse as sparse\nimport torch\nimport torch.nn as nn\nfrom dgl.base import DGLError\n\n\nclass DiffConv(nn.Module):\n    \"\"\"DiffConv is the implementation of diffusion convolution from paper DCRNN\n    It will compute multiple diffusion matrix and perform multiple diffusion conv on it,\n    this layer can be used for traffic prediction, pedamic model.\n    Parameter\n    ==========\n    in_feats : int\n        number of input feature\n\n    out_feats : int\n        number of output feature\n\n    k : int\n        number of diffusion steps\n\n    dir : str [both/in/out]\n        direction of diffusion convolution\n        From paper default both direction\n    \"\"\"\n\n    def __init__(\n        self, in_feats, out_feats, k, in_graph_list, out_graph_list, dir=\"both\"\n    ):\n        super(DiffConv, self).__init__()\n        self.in_feats = in_feats\n        self.out_feats = out_feats\n        self.k = k\n        self.dir = dir\n        self.num_graphs = self.k - 1 if self.dir == \"both\" else 2 * self.k - 2\n        self.project_fcs = nn.ModuleList()\n        for i in range(self.num_graphs):\n            self.project_fcs.append(\n                nn.Linear(self.in_feats, self.out_feats, bias=False)\n            )\n        self.merger = nn.Parameter(torch.randn(self.num_graphs + 1))\n        self.in_graph_list = in_graph_list\n        self.out_graph_list = out_graph_list\n\n    @staticmethod\n    def attach_graph(g, k):\n        device = g.device\n        out_graph_list = []\n        in_graph_list = []\n        wadj, ind, outd = DiffConv.get_weight_matrix(g)\n        adj = sparse.coo_matrix(wadj / outd.cpu().numpy())\n        outg = dgl.from_scipy(adj, eweight_name=\"weight\").to(device)\n        outg.edata[\"weight\"] = outg.edata[\"weight\"].float().to(device)\n        out_graph_list.append(outg)\n        for i in range(k - 1):\n            out_graph_list.append(\n         
       DiffConv.diffuse(out_graph_list[-1], wadj, outd)\n            )\n        adj = sparse.coo_matrix(wadj.T / ind.cpu().numpy())\n        ing = dgl.from_scipy(adj, eweight_name=\"weight\").to(device)\n        ing.edata[\"weight\"] = ing.edata[\"weight\"].float().to(device)\n        in_graph_list.append(ing)\n        for i in range(k - 1):\n            in_graph_list.append(\n                DiffConv.diffuse(in_graph_list[-1], wadj.T, ind)\n            )\n        return out_graph_list, in_graph_list\n\n    @staticmethod\n    def get_weight_matrix(g):\n        adj = g.adj_external(scipy_fmt=\"coo\")\n        ind = g.in_degrees()\n        outd = g.out_degrees()\n        weight = g.edata[\"weight\"]\n        adj.data = weight.cpu().numpy()\n        return adj, ind, outd\n\n    @staticmethod\n    def diffuse(progress_g, weighted_adj, degree):\n        device = progress_g.device\n        progress_adj = progress_g.adj_external(scipy_fmt=\"coo\")\n        progress_adj.data = progress_g.edata[\"weight\"].cpu().numpy()\n        ret_adj = sparse.coo_matrix(\n            progress_adj @ (weighted_adj / degree.cpu().numpy())\n        )\n        ret_graph = dgl.from_scipy(ret_adj, eweight_name=\"weight\").to(device)\n        ret_graph.edata[\"weight\"] = ret_graph.edata[\"weight\"].float().to(device)\n        return ret_graph\n\n    def forward(self, g, x):\n        feat_list = []\n        if self.dir == \"both\":\n            graph_list = self.in_graph_list + self.out_graph_list\n        elif self.dir == \"in\":\n            graph_list = self.in_graph_list\n        elif self.dir == \"out\":\n            graph_list = self.out_graph_list\n\n        for i in range(self.num_graphs):\n            g = graph_list[i]\n            with g.local_scope():\n                g.ndata[\"n\"] = self.project_fcs[i](x)\n                g.update_all(\n                    fn.u_mul_e(\"n\", \"weight\", \"e\"), fn.sum(\"e\", \"feat\")\n                )\n                
feat_list.append(g.ndata[\"feat\"])\n                # Each feat has shape [N,q_feats]\n        feat_list.append(self.project_fcs[-1](x))\n        feat_list = torch.cat(feat_list).view(\n            len(feat_list), -1, self.out_feats\n        )\n        ret = (\n            (self.merger * feat_list.permute(1, 2, 0)).permute(2, 0, 1).mean(0)\n        )\n        return ret\n"
  },
  {
    "path": "examples/pytorch/dtgrnn/gaan.py",
    "content": "import dgl\nimport dgl.function as fn\nimport dgl.nn as dglnn\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom dgl.base import DGLError\nfrom dgl.nn.functional import edge_softmax\n\n\nclass WeightedGATConv(dglnn.GATConv):\n    \"\"\"\n    This model inherit from dgl GATConv for traffic prediction task,\n    it add edge weight when aggregating the node feature.\n    \"\"\"\n\n    def forward(self, graph, feat, get_attention=False):\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. 
Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            if isinstance(feat, tuple):\n                h_src = self.feat_drop(feat[0])\n                h_dst = self.feat_drop(feat[1])\n                if not hasattr(self, \"fc_src\"):\n                    feat_src = self.fc(h_src).view(\n                        -1, self._num_heads, self._out_feats\n                    )\n                    feat_dst = self.fc(h_dst).view(\n                        -1, self._num_heads, self._out_feats\n                    )\n                else:\n                    feat_src = self.fc_src(h_src).view(\n                        -1, self._num_heads, self._out_feats\n                    )\n                    feat_dst = self.fc_dst(h_dst).view(\n                        -1, self._num_heads, self._out_feats\n                    )\n            else:\n                h_src = h_dst = self.feat_drop(feat)\n                feat_src = feat_dst = self.fc(h_src).view(\n                    -1, self._num_heads, self._out_feats\n                )\n                if graph.is_block:\n                    feat_dst = feat_src[: graph.number_of_dst_nodes()]\n            # NOTE: GAT paper uses \"first concatenation then linear projection\"\n            # to compute attention scores, while ours is \"first projection then\n            # addition\", the two approaches are mathematically equivalent:\n            # We decompose the weight vector a mentioned in the paper into\n            # [a_l || a_r], then\n            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j\n            # Our implementation is much efficient because we do not need to\n            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. 
Plus,\n            # addition could be optimized with DGL's built-in function u_add_v,\n            # which further speeds up computation and saves memory footprint.\n            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)\n            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)\n            graph.srcdata.update({\"ft\": feat_src, \"el\": el})\n            graph.dstdata.update({\"er\": er})\n            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.\n            graph.apply_edges(fn.u_add_v(\"el\", \"er\", \"e\"))\n            e = self.leaky_relu(graph.edata.pop(\"e\"))\n            # compute softmax\n            graph.edata[\"a\"] = self.attn_drop(edge_softmax(graph, e))\n            # compute weighted attention\n            graph.edata[\"a\"] = (\n                graph.edata[\"a\"].permute(1, 2, 0) * graph.edata[\"weight\"]\n            ).permute(2, 0, 1)\n            # message passing\n            graph.update_all(fn.u_mul_e(\"ft\", \"a\", \"m\"), fn.sum(\"m\", \"ft\"))\n            rst = graph.dstdata[\"ft\"]\n            # residual\n            if self.res_fc is not None:\n                resval = self.res_fc(h_dst).view(\n                    h_dst.shape[0], -1, self._out_feats\n                )\n                rst = rst + resval\n            # activation\n            if self.activation:\n                rst = self.activation(rst)\n\n            if get_attention:\n                return rst, graph.edata[\"a\"]\n            else:\n                return rst\n\n\nclass GatedGAT(nn.Module):\n    \"\"\"Gated Graph Attention module, it is a general purpose\n    graph attention module proposed in paper GaAN. 
The paper uses\n    it for the traffic prediction task\n    Parameters\n    ==========\n    in_feats : int\n        number of input features\n\n    out_feats : int\n        number of output features\n\n    map_feats : int\n        intermediate feature size for gate computation\n\n    num_heads : int\n        number of heads for multi-head attention\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, map_feats, num_heads):\n        super(GatedGAT, self).__init__()\n        self.in_feats = in_feats\n        self.out_feats = out_feats\n        self.map_feats = map_feats\n        self.num_heads = num_heads\n        self.gatlayer = WeightedGATConv(\n            self.in_feats, self.out_feats, self.num_heads\n        )\n        self.gate_fn = nn.Linear(\n            2 * self.in_feats + self.map_feats, self.num_heads\n        )\n        self.gate_m = nn.Linear(self.in_feats, self.map_feats)\n        self.merger_layer = nn.Linear(\n            self.in_feats + self.out_feats, self.out_feats\n        )\n\n    def forward(self, g, x):\n        with g.local_scope():\n            g.ndata[\"x\"] = x\n            g.ndata[\"z\"] = self.gate_m(x)\n            g.update_all(fn.copy_u(\"x\", \"x\"), fn.mean(\"x\", \"mean_z\"))\n            g.update_all(fn.copy_u(\"z\", \"z\"), fn.max(\"z\", \"max_z\"))\n            nft = torch.cat(\n                [g.ndata[\"x\"], g.ndata[\"max_z\"], g.ndata[\"mean_z\"]], dim=1\n            )\n            gate = self.gate_fn(nft).sigmoid()\n            attn_out = self.gatlayer(g, x)\n            node_num = g.num_nodes()\n            gated_out = (\n                (gate.view(-1) * attn_out.view(-1, self.out_feats).T).T\n            ).view(node_num, self.num_heads, self.out_feats)\n            gated_out = gated_out.mean(1)\n            merge = self.merger_layer(torch.cat([x, gated_out], dim=1))\n            return merge\n"
  },
  {
    "path": "examples/pytorch/dtgrnn/model.py",
    "content": "import dgl\nimport dgl.function as fn\nimport dgl.nn as dglnn\nimport numpy as np\nimport scipy.sparse as sparse\nimport torch\nimport torch.nn as nn\nfrom dgl.base import DGLError\nfrom dgl.nn.functional import edge_softmax\n\n\nclass GraphGRUCell(nn.Module):\n    \"\"\"Graph GRU unit which can use any message passing\n    net to replace the linear layer in the original GRU\n    Parameter\n    ==========\n    in_feats : int\n        number of input features\n\n    out_feats : int\n        number of output features\n\n    net : torch.nn.Module\n        message passing network\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, net):\n        super(GraphGRUCell, self).__init__()\n        self.in_feats = in_feats\n        self.out_feats = out_feats\n        self.dir = dir\n        # net can be any GNN model\n        self.r_net = net(in_feats + out_feats, out_feats)\n        self.u_net = net(in_feats + out_feats, out_feats)\n        self.c_net = net(in_feats + out_feats, out_feats)\n        # Manually add bias Bias\n        self.r_bias = nn.Parameter(torch.rand(out_feats))\n        self.u_bias = nn.Parameter(torch.rand(out_feats))\n        self.c_bias = nn.Parameter(torch.rand(out_feats))\n\n    def forward(self, g, x, h):\n        r = torch.sigmoid(self.r_net(g, torch.cat([x, h], dim=1)) + self.r_bias)\n        u = torch.sigmoid(self.u_net(g, torch.cat([x, h], dim=1)) + self.u_bias)\n        h_ = r * h\n        c = torch.sigmoid(\n            self.c_net(g, torch.cat([x, h_], dim=1)) + self.c_bias\n        )\n        new_h = u * h + (1 - u) * c\n        return new_h\n\n\nclass StackedEncoder(nn.Module):\n    \"\"\"One step encoder unit for hidden representation generation\n    it can stack multiple vertical layers to increase the depth.\n\n    Parameter\n    ==========\n    in_feats : int\n        number if input features\n\n    out_feats : int\n        number of output features\n\n    num_layers : int\n        vertical depth of one step 
encoding unit\n\n    net : torch.nn.Module\n        message passing network for graph computation\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, num_layers, net):\n        super(StackedEncoder, self).__init__()\n        self.in_feats = in_feats\n        self.out_feats = out_feats\n        self.num_layers = num_layers\n        self.net = net\n        self.layers = nn.ModuleList()\n        if self.num_layers <= 0:\n            raise DGLError(\"Layer Number must be greater than 0! \")\n        self.layers.append(\n            GraphGRUCell(self.in_feats, self.out_feats, self.net)\n        )\n        for _ in range(self.num_layers - 1):\n            self.layers.append(\n                GraphGRUCell(self.out_feats, self.out_feats, self.net)\n            )\n\n    # hidden_states should be a list which for different layer\n    def forward(self, g, x, hidden_states):\n        hiddens = []\n        for i, layer in enumerate(self.layers):\n            x = layer(g, x, hidden_states[i])\n            hiddens.append(x)\n        return x, hiddens\n\n\nclass StackedDecoder(nn.Module):\n    \"\"\"One step decoder unit for hidden representation generation\n    it can stack multiple vertical layers to increase the depth.\n\n    Parameter\n    ==========\n    in_feats : int\n        number if input features\n\n    hid_feats : int\n        number of feature before the linear output layer\n\n    out_feats : int\n        number of output features\n\n    num_layers : int\n        vertical depth of one step encoding unit\n\n    net : torch.nn.Module\n        message passing network for graph computation\n    \"\"\"\n\n    def __init__(self, in_feats, hid_feats, out_feats, num_layers, net):\n        super(StackedDecoder, self).__init__()\n        self.in_feats = in_feats\n        self.hid_feats = hid_feats\n        self.out_feats = out_feats\n        self.num_layers = num_layers\n        self.net = net\n        self.out_layer = nn.Linear(self.hid_feats, self.out_feats)\n        
self.layers = nn.ModuleList()\n        if self.num_layers <= 0:\n            raise DGLError(\"Layer Number must be greater than 0!\")\n        self.layers.append(GraphGRUCell(self.in_feats, self.hid_feats, net))\n        for _ in range(self.num_layers - 1):\n            self.layers.append(\n                GraphGRUCell(self.hid_feats, self.hid_feats, net)\n            )\n\n    def forward(self, g, x, hidden_states):\n        hiddens = []\n        for i, layer in enumerate(self.layers):\n            x = layer(g, x, hidden_states[i])\n            hiddens.append(x)\n        x = self.out_layer(x)\n        return x, hiddens\n\n\nclass GraphRNN(nn.Module):\n    \"\"\"Graph Sequence to sequence prediction framework\n    Support multiple backbone GNN. Mainly used for traffic prediction.\n\n    Parameter\n    ==========\n    in_feats : int\n        number of input features\n\n    out_feats : int\n        number of prediction output features\n\n    seq_len : int\n        input and predicted sequence length\n\n    num_layers : int\n        vertical number of layers in encoder and decoder unit\n\n    net : torch.nn.Module\n        Message passing GNN as backbone\n\n    decay_steps : int\n        number of steps for the teacher forcing probability to decay\n    \"\"\"\n\n    def __init__(\n        self, in_feats, out_feats, seq_len, num_layers, net, decay_steps\n    ):\n        super(GraphRNN, self).__init__()\n        self.in_feats = in_feats\n        self.out_feats = out_feats\n        self.seq_len = seq_len\n        self.num_layers = num_layers\n        self.net = net\n        self.decay_steps = decay_steps\n\n        self.encoder = StackedEncoder(\n            self.in_feats, self.out_feats, self.num_layers, self.net\n        )\n\n        self.decoder = StackedDecoder(\n            self.in_feats,\n            self.out_feats,\n            self.in_feats,\n            self.num_layers,\n            self.net,\n        )\n\n    # Threshold For Teacher Forcing\n\n    def 
compute_thresh(self, batch_cnt):\n        return self.decay_steps / (\n            self.decay_steps + np.exp(batch_cnt / self.decay_steps)\n        )\n\n    def encode(self, g, inputs, device):\n        hidden_states = [\n            torch.zeros(g.num_nodes(), self.out_feats).to(device)\n            for _ in range(self.num_layers)\n        ]\n        for i in range(self.seq_len):\n            _, hidden_states = self.encoder(g, inputs[i], hidden_states)\n\n        return hidden_states\n\n    def decode(self, g, teacher_states, hidden_states, batch_cnt, device):\n        outputs = []\n        inputs = torch.zeros(g.num_nodes(), self.in_feats).to(device)\n        for i in range(self.seq_len):\n            if (\n                np.random.random() < self.compute_thresh(batch_cnt)\n                and self.training\n            ):\n                inputs, hidden_states = self.decoder(\n                    g, teacher_states[i], hidden_states\n                )\n            else:\n                inputs, hidden_states = self.decoder(g, inputs, hidden_states)\n            outputs.append(inputs)\n        outputs = torch.stack(outputs)\n        return outputs\n\n    def forward(self, g, inputs, teacher_states, batch_cnt, device):\n        hidden = self.encode(g, inputs, device)\n        outputs = self.decode(g, teacher_states, hidden, batch_cnt, device)\n        return outputs\n"
  },
  {
    "path": "examples/pytorch/dtgrnn/train.py",
    "content": "import argparse\nfrom functools import partial\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom dataloading import (\n    METR_LAGraphDataset,\n    METR_LATestDataset,\n    METR_LATrainDataset,\n    METR_LAValidDataset,\n    PEMS_BAYGraphDataset,\n    PEMS_BAYTestDataset,\n    PEMS_BAYTrainDataset,\n    PEMS_BAYValidDataset,\n)\nfrom dcrnn import DiffConv\nfrom gaan import GatedGAT\nfrom model import GraphRNN\nfrom torch.utils.data import DataLoader\nfrom utils import get_learning_rate, masked_mae_loss, NormalizationLayer\n\nbatch_cnt = [0]\n\n\ndef train(\n    model,\n    graph,\n    dataloader,\n    optimizer,\n    scheduler,\n    normalizer,\n    loss_fn,\n    device,\n    args,\n):\n    total_loss = []\n    graph = graph.to(device)\n    model.train()\n    batch_size = args.batch_size\n    for i, (x, y) in enumerate(dataloader):\n        optimizer.zero_grad()\n        # Padding: Since the diffusion graph is precmputed we need to pad the batch so that\n        # each batch have same batch size\n        if x.shape[0] != batch_size:\n            x_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])\n            y_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])\n            x_buff[: x.shape[0], :, :, :] = x\n            x_buff[x.shape[0] :, :, :, :] = x[-1].repeat(\n                batch_size - x.shape[0], 1, 1, 1\n            )\n            y_buff[: x.shape[0], :, :, :] = y\n            y_buff[x.shape[0] :, :, :, :] = y[-1].repeat(\n                batch_size - x.shape[0], 1, 1, 1\n            )\n            x = x_buff\n            y = y_buff\n        # Permute the dimension for shaping\n        x = x.permute(1, 0, 2, 3)\n        y = y.permute(1, 0, 2, 3)\n\n        x_norm = (\n            normalizer.normalize(x)\n            .reshape(x.shape[0], -1, x.shape[3])\n            .float()\n            .to(device)\n        )\n        y_norm = (\n            normalizer.normalize(y)\n           
 .reshape(x.shape[0], -1, x.shape[3])\n            .float()\n            .to(device)\n        )\n        y = y.reshape(y.shape[0], -1, y.shape[3]).float().to(device)\n\n        batch_graph = dgl.batch([graph] * batch_size)\n        output = model(batch_graph, x_norm, y_norm, batch_cnt[0], device)\n        # Denormalization for loss compute\n        y_pred = normalizer.denormalize(output)\n        loss = loss_fn(y_pred, y)\n        loss.backward()\n        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n        optimizer.step()\n        if get_learning_rate(optimizer) > args.minimum_lr:\n            scheduler.step()\n        total_loss.append(float(loss))\n        batch_cnt[0] += 1\n        print(\"\\rBatch: \", i, end=\"\")\n    return np.mean(total_loss)\n\n\ndef eval(model, graph, dataloader, normalizer, loss_fn, device, args):\n    total_loss = []\n    graph = graph.to(device)\n    model.eval()\n    batch_size = args.batch_size\n    for i, (x, y) in enumerate(dataloader):\n        # Padding: Since the diffusion graph is precmputed we need to pad the batch so that\n        # each batch have same batch size\n        if x.shape[0] != batch_size:\n            x_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])\n            y_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])\n            x_buff[: x.shape[0], :, :, :] = x\n            x_buff[x.shape[0] :, :, :, :] = x[-1].repeat(\n                batch_size - x.shape[0], 1, 1, 1\n            )\n            y_buff[: x.shape[0], :, :, :] = y\n            y_buff[x.shape[0] :, :, :, :] = y[-1].repeat(\n                batch_size - x.shape[0], 1, 1, 1\n            )\n            x = x_buff\n            y = y_buff\n        # Permute the order of dimension\n        x = x.permute(1, 0, 2, 3)\n        y = y.permute(1, 0, 2, 3)\n\n        x_norm = (\n            normalizer.normalize(x)\n            .reshape(x.shape[0], -1, x.shape[3])\n            .float()\n            
.to(device)\n        )\n        y_norm = (\n            normalizer.normalize(y)\n            .reshape(x.shape[0], -1, x.shape[3])\n            .float()\n            .to(device)\n        )\n        y = y.reshape(x.shape[0], -1, x.shape[3]).to(device)\n\n        batch_graph = dgl.batch([graph] * batch_size)\n        output = model(batch_graph, x_norm, y_norm, i, device)\n        y_pred = normalizer.denormalize(output)\n        loss = loss_fn(y_pred, y)\n        total_loss.append(float(loss))\n    return np.mean(total_loss)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    # Define the arguments\n    parser.add_argument(\n        \"--batch_size\",\n        type=int,\n        default=64,\n        help=\"Size of batch for minibatch Training\",\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=0,\n        help=\"Number of workers for parallel dataloading\",\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"dcrnn\",\n        help=\"WHich model to use DCRNN vs GaAN\",\n    )\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU indexm -1 for CPU training\"\n    )\n    parser.add_argument(\n        \"--diffsteps\",\n        type=int,\n        default=2,\n        help=\"Step of constructing the diffusiob matrix\",\n    )\n    parser.add_argument(\n        \"--num_heads\", type=int, default=2, help=\"Number of multiattention head\"\n    )\n    parser.add_argument(\n        \"--decay_steps\",\n        type=int,\n        default=2000,\n        help=\"Teacher forcing probability decay ratio\",\n    )\n    parser.add_argument(\n        \"--lr\", type=float, default=0.01, help=\"Initial learning rate\"\n    )\n    parser.add_argument(\n        \"--minimum_lr\",\n        type=float,\n        default=2e-6,\n        help=\"Lower bound of learning rate\",\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        
default=\"LA\",\n        help=\"dataset LA for METR_LA; BAY for PEMS_BAY\",\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=100, help=\"Number of epoches for training\"\n    )\n    parser.add_argument(\n        \"--max_grad_norm\",\n        type=float,\n        default=5.0,\n        help=\"Maximum gradient norm for update parameters\",\n    )\n    args = parser.parse_args()\n    # Load the datasets\n    if args.dataset == \"LA\":\n        g = METR_LAGraphDataset()\n        train_data = METR_LATrainDataset()\n        test_data = METR_LATestDataset()\n        valid_data = METR_LAValidDataset()\n    elif args.dataset == \"BAY\":\n        g = PEMS_BAYGraphDataset()\n        train_data = PEMS_BAYTrainDataset()\n        test_data = PEMS_BAYTestDataset()\n        valid_data = PEMS_BAYValidDataset()\n\n    if args.gpu == -1:\n        device = torch.device(\"cpu\")\n    else:\n        device = torch.device(\"cuda:{}\".format(args.gpu))\n\n    train_loader = DataLoader(\n        train_data,\n        batch_size=args.batch_size,\n        num_workers=args.num_workers,\n        shuffle=True,\n    )\n    valid_loader = DataLoader(\n        valid_data,\n        batch_size=args.batch_size,\n        num_workers=args.num_workers,\n        shuffle=True,\n    )\n    test_loader = DataLoader(\n        test_data,\n        batch_size=args.batch_size,\n        num_workers=args.num_workers,\n        shuffle=True,\n    )\n    normalizer = NormalizationLayer(train_data.mean, train_data.std)\n\n    if args.model == \"dcrnn\":\n        batch_g = dgl.batch([g] * args.batch_size).to(device)\n        out_gs, in_gs = DiffConv.attach_graph(batch_g, args.diffsteps)\n        net = partial(\n            DiffConv,\n            k=args.diffsteps,\n            in_graph_list=in_gs,\n            out_graph_list=out_gs,\n        )\n    elif args.model == \"gaan\":\n        net = partial(GatedGAT, map_feats=64, num_heads=args.num_heads)\n\n    dcrnn = GraphRNN(\n        in_feats=2,\n   
     out_feats=64,\n        seq_len=12,\n        num_layers=2,\n        net=net,\n        decay_steps=args.decay_steps,\n    ).to(device)\n\n    optimizer = torch.optim.Adam(dcrnn.parameters(), lr=args.lr)\n    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)\n\n    loss_fn = masked_mae_loss\n\n    for e in range(args.epochs):\n        train_loss = train(\n            dcrnn,\n            g,\n            train_loader,\n            optimizer,\n            scheduler,\n            normalizer,\n            loss_fn,\n            device,\n            args,\n        )\n        valid_loss = eval(\n            dcrnn, g, valid_loader, normalizer, loss_fn, device, args\n        )\n        test_loss = eval(\n            dcrnn, g, test_loader, normalizer, loss_fn, device, args\n        )\n        print(\n            \"\\rEpoch: {} Train Loss: {} Valid Loss: {} Test Loss: {}\".format(\n                e, train_loss, valid_loss, test_loss\n            )\n        )\n"
  },
  {
    "path": "examples/pytorch/dtgrnn/utils.py",
    "content": "import dgl\nimport numpy as np\nimport scipy.sparse as sparse\nimport torch\nimport torch.nn as nn\n\n\nclass NormalizationLayer(nn.Module):\n    def __init__(self, mean, std):\n        self.mean = mean\n        self.std = std\n\n    # Here we expect mean and std to be scalars\n    def normalize(self, x):\n        return (x - self.mean) / self.std\n\n    def denormalize(self, x):\n        return x * self.std + self.mean\n\n\ndef masked_mae_loss(y_pred, y_true):\n    mask = (y_true != 0).float()\n    mask /= mask.mean()\n    loss = torch.abs(y_pred - y_true)\n    loss = loss * mask\n    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3\n    loss[loss != loss] = 0\n    return loss.mean()\n\n\ndef get_learning_rate(optimizer):\n    for param in optimizer.param_groups:\n        return param[\"lr\"]\n"
  },
  {
    "path": "examples/pytorch/eeg-gcnn/EEGGraphDataset.py",
    "content": "import math\nfrom itertools import product\n\nimport dgl\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom dgl.data import DGLDataset\n\n\nclass EEGGraphDataset(DGLDataset):\n    \"\"\"Build graph, treat all nodes as the same type\n    Parameters\n    ----------\n    x: edge weights of 8-node complete graph\n        There are 1 x 64 edges\n    y: labels (diseased/healthy)\n    num_nodes: the number of nodes of the graph. In our case, it is 8.\n    indices: Patient level indices. They are used to generate edge weights.\n\n    Output\n    ------\n    a complete 8-node DGLGraph with node features and edge weights\n    \"\"\"\n\n    def __init__(self, x, y, num_nodes, indices):\n        # CAUTION - x and labels are memory-mapped, used as if they are in RAM.\n        self.x = x\n        self.labels = y\n        self.indices = indices\n        self.num_nodes = num_nodes\n\n        # NOTE: this order decides the node index, keep consistent!\n        self.ch_names = [\n            \"F7-F3\",\n            \"F8-F4\",\n            \"T7-C3\",\n            \"T8-C4\",\n            \"P7-P3\",\n            \"P8-P4\",\n            \"O1-P3\",\n            \"O2-P4\",\n        ]\n\n        # in the 10-10 system, in between the 2 10-20 electrodes in ch_names, used for calculating edge weights\n        # Note: \"01\" is for \"P03\", and \"02\" is for \"P04.\"\n        self.ref_names = [\"F5\", \"F6\", \"C5\", \"C6\", \"P5\", \"P6\", \"O1\", \"O2\"]\n\n        # edge indices source to target - 2 x E = 2 x 64\n        # fully connected undirected graph so 8*8=64 edges\n        self.node_ids = range(len(self.ch_names))\n        self.edge_index = (\n            torch.tensor(\n                [[a, b] for a, b in product(self.node_ids, self.node_ids)],\n                dtype=torch.long,\n            )\n            .t()\n            .contiguous()\n        )\n\n        # edge attributes - E x 1\n        # only the spatial distance between electrodes for now - 
standardize between 0 and 1\n        self.distances = self.get_sensor_distances()\n        a = np.array(self.distances)\n        self.distances = (a - np.min(a)) / (np.max(a) - np.min(a))\n        self.spec_coh_values = np.load(\"spec_coh_values.npy\", allow_pickle=True)\n\n    # sensor distances don't depend on window ID\n    def get_sensor_distances(self):\n        coords_1010 = pd.read_csv(\"standard_1010.tsv.txt\", sep=\"\\t\")\n        num_edges = self.edge_index.shape[1]\n        distances = []\n        for edge_idx in range(num_edges):\n            sensor1_idx = self.edge_index[0, edge_idx]\n            sensor2_idx = self.edge_index[1, edge_idx]\n            dist = self.get_geodesic_distance(\n                sensor1_idx, sensor2_idx, coords_1010\n            )\n            distances.append(dist)\n        assert len(distances) == num_edges\n        return distances\n\n    def get_geodesic_distance(\n        self, montage_sensor1_idx, montage_sensor2_idx, coords_1010\n    ):\n        def get_coord(ref_sensor, coord):\n            return float(\n                (coords_1010[coords_1010.label == ref_sensor][coord]).iloc[0]\n            )\n\n        # get the reference sensor in the 10-10 system for the current montage pair in 10-20 system\n        ref_sensor1 = self.ref_names[montage_sensor1_idx]\n        ref_sensor2 = self.ref_names[montage_sensor2_idx]\n\n        x1 = get_coord(ref_sensor1, \"x\")\n        y1 = get_coord(ref_sensor1, \"y\")\n        z1 = get_coord(ref_sensor1, \"z\")\n\n        x2 = get_coord(ref_sensor2, \"x\")\n        y2 = get_coord(ref_sensor2, \"y\")\n        z2 = get_coord(ref_sensor2, \"z\")\n\n        # https://math.stackexchange.com/questions/1304169/distance-between-two-points-on-a-sphere\n        r = 1  # since coords are on unit sphere\n        # rounding is for numerical stability, domain is [-1, 1]\n        dist = r * math.acos(\n            round(((x1 * x2) + (y1 * y2) + (z1 * z2)) / (r**2), 2)\n        )\n        return 
dist\n\n    # returns size of dataset = number of indices\n    def __len__(self):\n        return len(self.indices)\n\n    # retrieve one sample from the dataset after applying all transforms\n    def __getitem__(self, idx):\n        if torch.is_tensor(idx):\n            idx = idx.tolist()\n\n        # map input idx (ranging from 0 to __len__() inside self.indices)\n        # to an idx in the whole dataset (inside self.x)\n        # assert idx < len(self.indices)\n        idx = self.indices[idx]\n        node_features = self.x[idx]\n        node_features = torch.from_numpy(node_features.reshape(8, 6))\n\n        # spectral coherence between 2 montage channels!\n        spec_coh_values = self.spec_coh_values[idx, :]\n\n        # combine edge weights and spect coh values into one value/ one E x 1 tensor\n        edge_weights = self.distances + spec_coh_values\n        edge_weights = torch.tensor(edge_weights)  # trucated to integer\n\n        # create 8-node complete graph\n        src = [\n            [0 for i in range(self.num_nodes)] for j in range(self.num_nodes)\n        ]\n        for i in range(len(src)):\n            for j in range(len(src[i])):\n                src[i][j] = i\n        src = np.array(src).flatten()\n\n        det = [\n            [i for i in range(self.num_nodes)] for j in range(self.num_nodes)\n        ]\n        det = np.array(det).flatten()\n\n        u, v = (torch.tensor(src), torch.tensor(det))\n        g = dgl.graph((u, v))\n\n        # add node features and edge features\n        g.ndata[\"x\"] = node_features\n        g.edata[\"edge_weights\"] = edge_weights\n        return g, torch.tensor(idx), torch.tensor(self.labels[idx])\n"
  },
  {
    "path": "examples/pytorch/eeg-gcnn/README.md",
    "content": "# DGL Implementation of EEG-GCNN Paper\nThis example is a simplified version that presents how to utilize the original EEG-GCNN model proposed in the paper [EEG-GCNN](http://proceedings.mlr.press/v136/wagh20a.html), implemented with DGL library. The example removes cross validation and optimal decision boundary that are used in the original code. The performance stats are slightly different from what is present in the paper. The original code is [here](https://github.com/neerajwagh/eeg-gcnn).\n\n## All References\n- [ML4H Poster](https://drive.google.com/file/d/14nuAQKiIud3p6-c8r9WLV2tAvCyRwRev/view?usp=sharing) can be helpful for understanding data preprocessing, model, and performance of the project. \n- The recording of presentation by the author Neeraj Wagh can be found on [slideslive](https://slideslive.com/38941020/eeggcnn-augmenting-electroencephalogrambased-neurological-disease-diagnosis-using-a-domainguided-graph-convolutional-neural-network?ref=account-folder-62123-folders).\n- The slides used during the presentation can be found [here](https://drive.google.com/file/d/1dXT4QAUXKauf7CAkhrVyhR2PFUsNh4b8/view?usp=sharing).\n- Raw Data can be found with these two links: [MPI LEMON](http://fcon_1000.projects.nitrc.org/indi/retro/MPI_LEMON.html) (no registration needed), [TUH EEG Abnormal Corpus](https://www.isip.piconepress.com/projects/tuh_eeg/downloads/tuh_eeg_abnormal/) ([needs registration](https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php))\n\n## Dependencies\n\n- Python 3.8.1\n- PyTorch 1.7.0\n- DGL 0.6.1\n- numpy 1.20.2\n- Sklearn 0.24.2\n- pandas 1.2.4\n## Dataset\n- Final Models, Pre-computed Features, Training Metadata can be downloaded through [FigShare](https://figshare.com/articles/software/EEG-GCNN_Supporting_Resources_for_Reproducibility/13251452).\n- In ```EEGGraphDataset.py```, we specify the channels and electrodes and use precomputed spectral coherence values to compute the edge weights. 
To adapt this example to your own data, please specify your channels and electrodes in the ```__init__``` function of ```EEGGraphDataset.py```.\n- To generate spectral coherence values, please refer to the [spectral_connectivity](https://mne.tools/stable/generated/mne.connectivity.spectral_connectivity.html) function in the mne library. An example usage may take the following form:\n```python\n     # ....loop over all windows in dataset....\n\n        # window data is 10-second preprocessed multi-channel timeseries (shape: n_channels x n_timepoints) containing all channels in ch_names\n        window_data = np.expand_dims(window_data, axis=0)\n\n        # ch_names are listed in EEGGraphDataset.py\n        for ch_idx, ch in enumerate(ch_names):\n            # number of channels is len(ch_names), which is 8 in our case.\n            spec_coh_values, _, _, _, _ = mne.connectivity.spectral_connectivity(data=window_data, method='coh', indices=([ch_idx]*8, range(8)), sfreq=SAMPLING_FREQ,\n                                              fmin=1.0, fmax=40.0, faverage=True, verbose=False)\n```\n## How to Run\n- First, download ```figshare_upload/master_metadata_index.csv```, ```figshare_upload/psd_features_data_X```, ```figshare_upload/labels_y```, ```figshare_upload/psd_shallow_eeg-gcnn/spec_coh_values.npy```, and ```figshare_upload/psd_shallow_eeg-gcnn/standard_1010.tsv.txt```. Put them in the repo. <br>\n- You may download these files by running:\n```bash\nwget https://ndownloader.figshare.com/files/25518170\n```\n- You will need to unzip the downloaded file.\n- Then run: \n```bash\npython main.py\n```\n- The default model used is ```shallow_EEGGraphConvNet.py```. To use ```deep_EEGGraphConvNet.py```, run:\n```bash\npython main.py --model deep\n```\n- After the code executes, you will see stats similar to those in the Performance section printed. The code will save the trained model from every epoch.\n## Performance\n\n|      DGL          | AUC         | Bal. 
Accuracy |\n|-------------------|-------------|---------------|\n| Shallow EEG-GCNN  | 0.832       | 0.750         |\n| Deep EEG-GCNN     | 0.830       | 0.736         |\n\nShallow_EEGGraphConvNet    |              AUC          |     Bal.Accuracy      |\n:-------------------------:|:-------------------------:|:---------------------:|\n![shallow_loss](https://user-images.githubusercontent.com/53772888/128595442-d185bd74-5c5d-4118-a6b7-b89dd307d3aa.png)  |![shallow_auc](https://user-images.githubusercontent.com/53772888/128595453-2f3b181a-bcb7-4da4-becd-7a7aa62083bc.png)|![shallow_bacc](https://user-images.githubusercontent.com/53772888/128595456-b293c888-bf8c-4f37-bd58-d01885da3832.png)\n\nDeep_EEGGraphConvNet            |  AUC | Bal.Accuracy |\n:-------------------------:|:-------------------------:|:---------------:|\n![deep_loss](https://user-images.githubusercontent.com/53772888/128595458-e4a76591-11cf-405f-9c20-2d161e49c358.png)|![deep_auc](https://user-images.githubusercontent.com/53772888/128595462-7a7bfb67-4601-4e83-8764-d7c44bf979b5.png)|![deep_bacc](https://user-images.githubusercontent.com/53772888/128595467-1a0cd37d-0152-431b-a29b-a40bafb71be5.png)\n\n### Contact\n\n- Email to John(_wei33@illinois.edu_)\n- You may also contact the authors:\n  - Neeraj: nwagh2@illinois.edu / [Website](http://neerajwagh.com/) / [Twitter](https://twitter.com/neeraj_wagh) / [Google Scholar](https://scholar.google.com/citations?hl=en&user=lCy5VsUAAAAJ)\n  - Yoga: varatha2@illinois.edu / [Website](https://sites.google.com/view/yoga-personal/home) / [Google Scholar](https://scholar.google.com/citations?user=XwL4dBgAAAAJ&hl=en)\n\n### Citation\n\nWagh, N. & Varatharajah, Y.. (2020). EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. Proceedings of the Machine Learning for Health NeurIPS Workshop, in PMLR 136:367-378 Available from http://proceedings.mlr.press/v136/wagh20a.html.\n"
  },
  {
    "path": "examples/pytorch/eeg-gcnn/deep_EEGGraphConvNet.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as function\n\nfrom dgl.nn import GraphConv, SumPooling\nfrom torch.nn import BatchNorm1d\n\n\nclass EEGGraphConvNet(nn.Module):\n    \"\"\"EEGGraph Convolution Net\n    Parameters\n    ----------\n    num_feats: the number of features per node. In our case, it is 6.\n    \"\"\"\n\n    def __init__(self, num_feats):\n        super(EEGGraphConvNet, self).__init__()\n\n        self.conv1 = GraphConv(num_feats, 16)\n        self.conv2 = GraphConv(16, 32)\n        self.conv3 = GraphConv(32, 64)\n        self.conv4 = GraphConv(64, 50)\n        self.conv4_bn = BatchNorm1d(\n            50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n        )\n\n        self.fc_block1 = nn.Linear(50, 30)\n        self.fc_block2 = nn.Linear(30, 10)\n        self.fc_block3 = nn.Linear(10, 2)\n\n        # Xavier initializations\n        self.fc_block1.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1))\n        self.fc_block2.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1))\n        self.fc_block3.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1))\n\n        self.sumpool = SumPooling()\n\n    def forward(self, g, return_graph_embedding=False):\n        x = g.ndata[\"x\"]\n        edge_weight = g.edata[\"edge_weights\"]\n\n        x = self.conv1(g, x, edge_weight=edge_weight)\n        x = function.leaky_relu(x, negative_slope=0.01)\n        x = function.dropout(x, p=0.2, training=self.training)\n\n        x = self.conv2(g, x, edge_weight=edge_weight)\n        x = function.leaky_relu(x, negative_slope=0.01)\n        x = function.dropout(x, p=0.2, training=self.training)\n\n        x = self.conv3(g, x, edge_weight=edge_weight)\n        x = function.leaky_relu(x, negative_slope=0.01)\n        x = function.dropout(x, p=0.2, training=self.training)\n\n        x = self.conv4(g, x, edge_weight=edge_weight)\n        x = self.conv4_bn(x)\n        x = function.leaky_relu(x, 
negative_slope=0.01)\n        x = function.dropout(x, p=0.2, training=self.training)\n        # NOTE: this takes node-level features/\"embeddings\"\n        # and aggregates to graph-level - use for graph-level classification\n\n        out = self.sumpool(g, x)\n        if return_graph_embedding:\n            return out\n\n        out = function.leaky_relu(self.fc_block1(out), negative_slope=0.1)\n        out = function.dropout(out, p=0.2, training=self.training)\n\n        out = function.leaky_relu(self.fc_block2(out), negative_slope=0.1)\n        out = function.dropout(out, p=0.2, training=self.training)\n\n        out = self.fc_block3(out)\n        return out\n"
  },
  {
    "path": "examples/pytorch/eeg-gcnn/main.py",
    "content": "import argparse\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.nn as nn\n\nfrom dgl.dataloading import GraphDataLoader\nfrom EEGGraphDataset import EEGGraphDataset\nfrom joblib import dump, load\nfrom sklearn import preprocessing\nfrom sklearn.metrics import balanced_accuracy_score, roc_auc_score\nfrom sklearn.model_selection import train_test_split\nfrom torch.utils.data import WeightedRandomSampler\n\n\ndef _load_memory_mapped_array(file_name):\n    # Due to a legacy problem related to memory alignment in joblib [1], the\n    # data provided in the example may not be byte-aligned. This can be risky\n    # when loading with mmap_mode. To fix the issue, load and re-dump the data.\n    # [1] https://joblib.readthedocs.io/en/latest/developing.html#release-1-2-0\n    dump(load(file_name), file_name)\n    return load(file_name, mmap_mode=\"r\")\n\n\nif __name__ == \"__main__\":\n    # argparse commandline args\n    parser = argparse.ArgumentParser(\n        description=\"Execute training pipeline on a given train/val subjects\"\n    )\n    parser.add_argument(\n        \"--num_feats\",\n        type=int,\n        default=6,\n        help=\"Number of features per node for the graph\",\n    )\n    parser.add_argument(\n        \"--num_nodes\", type=int, default=8, help=\"Number of nodes in the graph\"\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=4,\n        help=\"Number of epochs used to train\",\n    )\n    parser.add_argument(\n        \"--gpu_idx\",\n        type=int,\n        default=0,\n        help=\"index of GPU device that should be used for this run, defaults to 0.\",\n    )\n    parser.add_argument(\n        \"--num_epochs\",\n        type=int,\n        default=40,\n        help=\"Number of epochs used to train\",\n    )\n    parser.add_argument(\n        \"--exp_name\", type=str, default=\"default\", help=\"Name for the test.\"\n    )\n    parser.add_argument(\n        
\"--batch_size\",\n        type=int,\n        default=512,\n        help=\"Batch Size. Default is 512.\",\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"shallow\",\n        help=\"type shallow to use shallow_EEGGraphDataset; \"\n        \"type deep to use deep_EEGGraphDataset. Default is shallow\",\n    )\n    args = parser.parse_args()\n\n    # choose model\n    if args.model == \"shallow\":\n        from shallow_EEGGraphConvNet import EEGGraphConvNet\n\n    if args.model == \"deep\":\n        from deep_EEGGraphConvNet import EEGGraphConvNet\n\n    # set the random seed so that we can reproduce the results\n    np.random.seed(42)\n    torch.manual_seed(42)\n\n    # use GPU when available\n    _GPU_IDX = args.gpu_idx\n    _DEVICE = torch.device(\n        f\"cuda:{_GPU_IDX}\" if torch.cuda.is_available() else \"cpu\"\n    )\n    torch.cuda.set_device(_DEVICE)\n    print(f\" Using device: {_DEVICE} {torch.cuda.get_device_name(_DEVICE)}\")\n\n    # load patient level indices\n    _DATASET_INDEX = pd.read_csv(\"master_metadata_index.csv\", low_memory=False)\n    all_subjects = _DATASET_INDEX[\"patient_ID\"].astype(\"str\").unique()\n    print(f\"Subject list fetched! 
Total subjects are {len(all_subjects)}.\")\n\n    # retrieve inputs\n    num_nodes = args.num_nodes\n    _NUM_EPOCHS = args.num_epochs\n    _EXPERIMENT_NAME = args.exp_name\n    _BATCH_SIZE = args.batch_size\n    num_feats = args.num_feats\n    num_workers = args.num_workers\n\n    # set up input and targets from files\n    x = _load_memory_mapped_array(f\"psd_features_data_X\")\n    y = _load_memory_mapped_array(f\"labels_y\")\n\n    # normalize psd features data\n    normd_x = []\n    for i in range(len(y)):\n        arr = x[i, :]\n        arr = arr.reshape(1, -1)\n        arr2 = preprocessing.normalize(arr)\n        arr2 = arr2.reshape(48)\n        normd_x.append(arr2)\n\n    norm = np.array(normd_x)\n    x = norm.reshape(len(y), 48)\n    # map 0/1 to diseased/healthy\n    label_mapping, y = np.unique(y, return_inverse=True)\n    print(f\"Unique labels 0/1 mapping: {label_mapping}\")\n\n    # split the dataset to train and test. The ratio of test is 0.3.\n    train_and_val_subjects, heldout_subjects = train_test_split(\n        all_subjects, test_size=0.3, random_state=42\n    )\n\n    # split the dataset using patient indices\n    train_window_indices = _DATASET_INDEX.index[\n        _DATASET_INDEX[\"patient_ID\"].astype(\"str\").isin(train_and_val_subjects)\n    ].tolist()\n    heldout_test_window_indices = _DATASET_INDEX.index[\n        _DATASET_INDEX[\"patient_ID\"].astype(\"str\").isin(heldout_subjects)\n    ].tolist()\n\n    # define model, optimizer, scheduler\n    model = EEGGraphConvNet(num_feats)\n    loss_function = nn.CrossEntropyLoss()\n    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n        optimizer, milestones=[i * 10 for i in range(1, 26)], gamma=0.1\n    )\n\n    model = model.to(_DEVICE).double()\n    num_trainable_params = np.sum(\n        [\n            np.prod(p.size()) if p.requires_grad else 0\n            for p in model.parameters()\n        ]\n    )\n\n    # 
Dataloader========================================================================================================\n\n    # use WeightedRandomSampler to balance the training dataset\n\n    labels_unique, counts = np.unique(y, return_counts=True)\n\n    class_weights = np.array([1.0 / x for x in counts])\n    # provide weights for samples in the training set only\n    sample_weights = class_weights[y[train_window_indices]]\n    # sampler needs to come up with training set size number of samples\n    weighted_sampler = WeightedRandomSampler(\n        weights=sample_weights,\n        num_samples=len(train_window_indices),\n        replacement=True,\n    )\n\n    # train data loader\n    train_dataset = EEGGraphDataset(\n        x=x, y=y, num_nodes=num_nodes, indices=train_window_indices\n    )\n\n    train_loader = GraphDataLoader(\n        dataset=train_dataset,\n        batch_size=_BATCH_SIZE,\n        sampler=weighted_sampler,\n        num_workers=num_workers,\n        pin_memory=True,\n    )\n\n    # this loader is used without weighted sampling, to evaluate metrics on full training set after each epoch\n    train_metrics_loader = GraphDataLoader(\n        dataset=train_dataset,\n        batch_size=_BATCH_SIZE,\n        shuffle=False,\n        num_workers=num_workers,\n        pin_memory=True,\n    )\n\n    # test data loader\n    test_dataset = EEGGraphDataset(\n        x=x, y=y, num_nodes=num_nodes, indices=heldout_test_window_indices\n    )\n\n    test_loader = GraphDataLoader(\n        dataset=test_dataset,\n        batch_size=_BATCH_SIZE,\n        shuffle=False,\n        num_workers=num_workers,\n        pin_memory=True,\n    )\n\n    auroc_train_history = []\n    auroc_test_history = []\n    balACC_train_history = []\n    balACC_test_history = []\n    loss_train_history = []\n    loss_test_history = []\n\n    # training=========================================================================================================\n    for epoch in 
range(_NUM_EPOCHS):\n        model.train()\n        train_loss = []\n\n        for batch_idx, batch in enumerate(train_loader):\n            # send batch to GPU\n            g, dataset_idx, y = batch\n            g_batch = g.to(device=_DEVICE, non_blocking=True)\n            y_batch = y.to(device=_DEVICE, non_blocking=True)\n            optimizer.zero_grad()\n\n            # forward pass\n            outputs = model(g_batch)\n            loss = loss_function(outputs, y_batch)\n            train_loss.append(loss.item())\n\n            # backward pass\n            loss.backward()\n            optimizer.step()\n\n        # update learning rate\n        scheduler.step()\n\n        # evaluate model after each epoch for train-metric data============================================================\n        model.eval()\n        with torch.no_grad():\n            y_probs_train = torch.empty(0, 2).to(_DEVICE)\n            y_true_train, y_pred_train = [], []\n\n            for i, batch in enumerate(train_metrics_loader):\n                g, dataset_idx, y = batch\n                g_batch = g.to(device=_DEVICE, non_blocking=True)\n                y_batch = y.to(device=_DEVICE, non_blocking=True)\n\n                # forward pass\n                outputs = model(g_batch)\n\n                _, predicted = torch.max(outputs.data, 1)\n                y_pred_train += predicted.cpu().numpy().tolist()\n                # concatenate along 0th dimension\n                y_probs_train = torch.cat((y_probs_train, outputs.data), 0)\n                y_true_train += y_batch.cpu().numpy().tolist()\n\n        # returning prob distribution over target classes, take softmax over the 1st dimension\n        y_probs_train = (\n            nn.functional.softmax(y_probs_train, dim=1).cpu().numpy()\n        )\n        y_true_train = np.array(y_true_train)\n\n        # evaluate model after each epoch for validation data ==============================================================\n        
y_probs_test = torch.empty(0, 2).to(_DEVICE)\n        y_true_test, minibatch_loss, y_pred_test = [], [], []\n\n        for i, batch in enumerate(test_loader):\n            g, dataset_idx, y = batch\n            g_batch = g.to(device=_DEVICE, non_blocking=True)\n            y_batch = y.to(device=_DEVICE, non_blocking=True)\n\n            # forward pass\n            outputs = model(g_batch)\n            _, predicted = torch.max(outputs.data, 1)\n            y_pred_test += predicted.cpu().numpy().tolist()\n\n            loss = loss_function(outputs, y_batch)\n            minibatch_loss.append(loss.item())\n            y_probs_test = torch.cat((y_probs_test, outputs.data), 0)\n            y_true_test += y_batch.cpu().numpy().tolist()\n\n        # returning prob distribution over target classes, take softmax over the 1st dimension\n        y_probs_test = (\n            torch.nn.functional.softmax(y_probs_test, dim=1).cpu().numpy()\n        )\n        y_true_test = np.array(y_true_test)\n\n        # record training auroc and testing auroc\n        auroc_train_history.append(\n            roc_auc_score(y_true_train, y_probs_train[:, 1])\n        )\n        auroc_test_history.append(\n            roc_auc_score(y_true_test, y_probs_test[:, 1])\n        )\n\n        # record training balanced accuracy and testing balanced accuracy\n        balACC_train_history.append(\n            balanced_accuracy_score(y_true_train, y_pred_train)\n        )\n        balACC_test_history.append(\n            balanced_accuracy_score(y_true_test, y_pred_test)\n        )\n\n        # LOSS - epoch loss is defined as mean of minibatch losses within epoch\n        loss_train_history.append(np.mean(train_loss))\n        loss_test_history.append(np.mean(minibatch_loss))\n\n        # print the metrics\n        print(\n            \"Train loss: {}, test loss: {}\".format(\n                loss_train_history[-1], loss_test_history[-1]\n            )\n        )\n        print(\n            \"Train AUC: 
{}, test AUC: {}\".format(\n                auroc_train_history[-1], auroc_test_history[-1]\n            )\n        )\n        print(\n            \"Train Bal.ACC: {}, test Bal.ACC: {}\".format(\n                balACC_train_history[-1], balACC_test_history[-1]\n            )\n        )\n\n        # save model from each epoch====================================================================================\n        state = {\n            \"epochs\": _NUM_EPOCHS,\n            \"experiment_name\": _EXPERIMENT_NAME,\n            \"model_description\": str(model),\n            \"state_dict\": model.state_dict(),\n            \"optimizer\": optimizer.state_dict(),\n        }\n        torch.save(state, f\"{_EXPERIMENT_NAME}_Epoch_{epoch}.ckpt\")\n"
  },
  {
    "path": "examples/pytorch/eeg-gcnn/shallow_EEGGraphConvNet.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as function\n\nfrom dgl.nn import GraphConv, SumPooling\n\n\nclass EEGGraphConvNet(nn.Module):\n    \"\"\"EEGGraph Convolution Net\n    Parameters\n    ----------\n    num_feats: the number of features per node. In our case, it is 6.\n    \"\"\"\n\n    def __init__(self, num_feats):\n        super(EEGGraphConvNet, self).__init__()\n\n        self.conv1 = GraphConv(num_feats, 32)\n        self.conv2 = GraphConv(32, 20)\n        self.conv2_bn = nn.BatchNorm1d(\n            20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n        )\n        self.fc_block1 = nn.Linear(20, 10)\n        self.fc_block2 = nn.Linear(10, 2)\n\n        # Xavier initializations\n        self.fc_block1.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1))\n        self.fc_block2.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1))\n\n    def forward(self, g, return_graph_embedding=False):\n        x = g.ndata[\"x\"]\n        edge_weight = g.edata[\"edge_weights\"]\n\n        x = function.leaky_relu(self.conv1(g, x, edge_weight=edge_weight))\n        x = function.leaky_relu(\n            self.conv2_bn(self.conv2(g, x, edge_weight=edge_weight))\n        )\n\n        # NOTE: this takes node-level features/\"embeddings\"\n        # and aggregates to graph-level - use for graph-level classification\n        sumpool = SumPooling()\n        out = sumpool(g, x)\n        if return_graph_embedding:\n            return out\n\n        out = function.dropout(out, p=0.2, training=self.training)\n        out = self.fc_block1(out)\n        out = function.leaky_relu(out)\n        out = self.fc_block2(out)\n\n        return out\n"
  },
  {
    "path": "examples/pytorch/eges/.gitignore",
    "content": "__pycache__\n\n"
  },
  {
    "path": "examples/pytorch/eges/README.md",
    "content": "# DGL & Pytorch implementation of Enhanced Graph Embedding with Side information (EGES)\nPaper link: https://arxiv.org/pdf/1803.02349.pdf\nReference code repo: (https://github.com/wangzhegeek/EGES.git)\n\n## How to run\n\n- Create a folder named `data`.\n`mkdir data`\n- Download csv data\n`wget https://raw.githubusercontent.com/Wang-Yu-Qing/dgl_data/master/eges_data/action_head.csv -P data/`\n`wget https://raw.githubusercontent.com/Wang-Yu-Qing/dgl_data/master/eges_data/jdata_product.csv -P data/`\n- Run with the following command (with default configuration)\n`python main.py`\n\n## Result\n```\nEvaluate link prediction AUC: 0.7084\n```\n"
  },
  {
    "path": "examples/pytorch/eges/main.py",
    "content": "import dgl\nimport torch as th\nimport torch.optim as optim\nimport utils\nfrom model import EGES\nfrom sampler import Sampler\nfrom sklearn import metrics\nfrom torch.utils.data import DataLoader\n\n\ndef train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates):\n    sampler = Sampler(\n        train_g,\n        args.walk_length,\n        args.num_walks,\n        args.window_size,\n        args.num_negative,\n    )\n    # for each node in the graph, we sample pos and neg\n    # pairs for it, and feed these sampled pairs into the model.\n    # (nodes in the graph are of course batched before sampling)\n    dataloader = DataLoader(\n        th.arange(train_g.num_nodes()),\n        # this is the batch_size of input nodes\n        batch_size=args.batch_size,\n        shuffle=True,\n        collate_fn=lambda x: sampler.sample(x, sku_info),\n    )\n    model = EGES(args.dim, num_skus, num_brands, num_shops, num_cates)\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n    for epoch in range(args.epochs):\n        epoch_total_loss = 0\n        for step, (srcs, dsts, labels) in enumerate(dataloader):\n            # the batch size of output pairs is unfixed\n            # TODO: shuffle the triples?\n            srcs_embeds, dsts_embeds = model(srcs, dsts)\n            loss = model.loss(srcs_embeds, dsts_embeds, labels)\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            epoch_total_loss += loss.item()\n\n            if step % args.log_every == 0:\n                print(\n                    \"Epoch {:05d} | Step {:05d} | Step Loss {:.4f} | Epoch Avg Loss: {:.4f}\".format(\n                        epoch, step, loss.item(), epoch_total_loss / (step + 1)\n                    )\n                )\n\n        eval(model, test_g, sku_info)\n\n    return model\n\n\ndef eval(model, test_graph, sku_info):\n    preds, labels = [], []\n    for edge in test_graph:\n        src = 
th.tensor(sku_info[edge.src.numpy()[0]]).view(1, 4)\n        dst = th.tensor(sku_info[edge.dst.numpy()[0]]).view(1, 4)\n        # (1, dim)\n        src = model.query_node_embed(src)\n        dst = model.query_node_embed(dst)\n        # (1, dim) -> (1, dim) -> (1, )\n        logit = th.sigmoid(th.sum(src * dst))\n        preds.append(logit.detach().numpy().tolist())\n        labels.append(edge.label)\n\n    fpr, tpr, thresholds = metrics.roc_curve(labels, preds, pos_label=1)\n\n    print(\"Evaluate link prediction AUC: {:.4f}\".format(metrics.auc(fpr, tpr)))\n\n\nif __name__ == \"__main__\":\n    args = utils.init_args()\n\n    valid_sku_raw_ids = utils.get_valid_sku_set(args.item_info_data)\n\n    g, sku_encoder, sku_decoder = utils.construct_graph(\n        args.action_data, args.session_interval_sec, valid_sku_raw_ids\n    )\n\n    train_g, test_g = utils.split_train_test_graph(g)\n\n    sku_info_encoder, sku_info_decoder, sku_info = utils.encode_sku_fields(\n        args.item_info_data, sku_encoder, sku_decoder\n    )\n\n    num_skus = len(sku_encoder)\n    num_brands = len(sku_info_encoder[\"brand\"])\n    num_shops = len(sku_info_encoder[\"shop\"])\n    num_cates = len(sku_info_encoder[\"cate\"])\n\n    print(\n        \"Num skus: {}, num brands: {}, num shops: {}, num cates: {}\".format(\n            num_skus, num_brands, num_shops, num_cates\n        )\n    )\n\n    model = train(\n        args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates\n    )\n"
  },
  {
    "path": "examples/pytorch/eges/model.py",
    "content": "import torch as th\n\n\nclass EGES(th.nn.Module):\n    def __init__(self, dim, num_nodes, num_brands, num_shops, num_cates):\n        super(EGES, self).__init__()\n        self.dim = dim\n        # embeddings for nodes\n        base_embeds = th.nn.Embedding(num_nodes, dim)\n        brand_embeds = th.nn.Embedding(num_brands, dim)\n        shop_embeds = th.nn.Embedding(num_shops, dim)\n        cate_embeds = th.nn.Embedding(num_cates, dim)\n        self.embeds = [base_embeds, brand_embeds, shop_embeds, cate_embeds]\n        # weights for each node's side information\n        self.side_info_weights = th.nn.Embedding(num_nodes, 4)\n\n    def forward(self, srcs, dsts):\n        # srcs: sku_id, brand_id, shop_id, cate_id\n        srcs = self.query_node_embed(srcs)\n        dsts = self.query_node_embed(dsts)\n\n        return srcs, dsts\n\n    def query_node_embed(self, nodes):\n        \"\"\"\n        @nodes: tensor of shape (batch_size, num_side_info)\n        \"\"\"\n        batch_size = nodes.shape[0]\n        # query side info weights, (batch_size, 4)\n        side_info_weights = th.exp(self.side_info_weights(nodes[:, 0]))\n        # merge all embeddings\n        side_info_weighted_embeds_sum = []\n        side_info_weights_sum = []\n        for i in range(4):\n            # weights for i-th side info, (batch_size, ) -> (batch_size, 1)\n            i_th_side_info_weights = side_info_weights[:, i].view(\n                (batch_size, 1)\n            )\n            # batch of i-th side info embedding * its weight, (batch_size, dim)\n            side_info_weighted_embeds_sum.append(\n                i_th_side_info_weights * self.embeds[i](nodes[:, i])\n            )\n            side_info_weights_sum.append(i_th_side_info_weights)\n        # stack: (batch_size, 4, dim), sum: (batch_size, dim)\n        side_info_weighted_embeds_sum = th.sum(\n            th.stack(side_info_weighted_embeds_sum, axis=1), axis=1\n        )\n        # stack: (batch_size, 4), 
sum: (batch_size, )\n        side_info_weights_sum = th.sum(\n            th.stack(side_info_weights_sum, axis=1), axis=1\n        )\n        # (batch_size, dim)\n        H = side_info_weighted_embeds_sum / side_info_weights_sum\n\n        return H\n\n    def loss(self, srcs, dsts, labels):\n        dots = th.sigmoid(th.sum(srcs * dsts, axis=1))\n        dots = th.clamp(dots, min=1e-7, max=1 - 1e-7)\n\n        return th.mean(\n            -(labels * th.log(dots) + (1 - labels) * th.log(1 - dots))\n        )\n"
  },
  {
    "path": "examples/pytorch/eges/sampler.py",
    "content": "import dgl\nimport numpy as np\nimport torch as th\n\n\nclass Sampler:\n    def __init__(\n        self, graph, walk_length, num_walks, window_size, num_negative\n    ):\n        self.graph = graph\n        self.walk_length = walk_length\n        self.num_walks = num_walks\n        self.window_size = window_size\n        self.num_negative = num_negative\n        self.node_weights = self.compute_node_sample_weight()\n\n    def sample(self, batch, sku_info):\n        \"\"\"\n        Given a batch of target nodes, sample postive\n        pairs and negative pairs from the graph\n        \"\"\"\n        batch = np.repeat(batch, self.num_walks)\n\n        pos_pairs = self.generate_pos_pairs(batch)\n        neg_pairs = self.generate_neg_pairs(pos_pairs)\n\n        # get sku info with id\n        srcs, dsts, labels = [], [], []\n        for pair in pos_pairs + neg_pairs:\n            src, dst, label = pair\n            src_info = sku_info[src]\n            dst_info = sku_info[dst]\n\n            srcs.append(src_info)\n            dsts.append(dst_info)\n            labels.append(label)\n\n        return th.tensor(srcs), th.tensor(dsts), th.tensor(labels)\n\n    def filter_padding(self, traces):\n        for i in range(len(traces)):\n            traces[i] = [x for x in traces[i] if x != -1]\n\n    def generate_pos_pairs(self, nodes):\n        \"\"\"\n        For seq [1, 2, 3, 4] and node NO.2,\n        the window_size=1 will generate:\n            (1, 2) and (2, 3)\n        \"\"\"\n        # random walk\n        traces, types = dgl.sampling.random_walk(\n            g=self.graph, nodes=nodes, length=self.walk_length, prob=\"weight\"\n        )\n        traces = traces.tolist()\n        self.filter_padding(traces)\n\n        # skip-gram\n        pairs = []\n        for trace in traces:\n            for i in range(len(trace)):\n                center = trace[i]\n                left = max(0, i - self.window_size)\n                right = min(len(trace), i + 
self.window_size + 1)\n                pairs.extend([[center, x, 1] for x in trace[left:i]])\n                pairs.extend([[center, x, 1] for x in trace[i + 1 : right]])\n\n        return pairs\n\n    def compute_node_sample_weight(self):\n        \"\"\"\n        Using node degree as sample weight\n        \"\"\"\n        return self.graph.in_degrees().float()\n\n    def generate_neg_pairs(self, pos_pairs):\n        \"\"\"\n        Sample based on node freq in traces, frequently shown\n        nodes will have larger chance to be sampled as\n        negative node.\n        \"\"\"\n        # sample `self.num_negative` neg dst node\n        # for each pos node pair's src node.\n        negs = th.multinomial(\n            self.node_weights,\n            len(pos_pairs) * self.num_negative,\n            replacement=True,\n        ).tolist()\n\n        tar = np.repeat([pair[0] for pair in pos_pairs], self.num_negative)\n        assert len(tar) == len(negs)\n        neg_pairs = [[x, y, 0] for x, y in zip(tar, negs)]\n\n        return neg_pairs\n"
  },
  {
    "path": "examples/pytorch/eges/utils.py",
    "content": "import argparse\nimport random\nfrom datetime import datetime\n\nimport dgl\n\nimport networkx as nx\nimport numpy as np\nimport torch as th\n\n\ndef init_args():\n    # TODO: change args\n    argparser = argparse.ArgumentParser()\n    argparser.add_argument(\"--session_interval_sec\", type=int, default=1800)\n    argparser.add_argument(\n        \"--action_data\", type=str, default=\"data/action_head.csv\"\n    )\n    argparser.add_argument(\n        \"--item_info_data\", type=str, default=\"data/jdata_product.csv\"\n    )\n    argparser.add_argument(\"--walk_length\", type=int, default=10)\n    argparser.add_argument(\"--num_walks\", type=int, default=5)\n    argparser.add_argument(\"--batch_size\", type=int, default=64)\n    argparser.add_argument(\"--dim\", type=int, default=16)\n    argparser.add_argument(\"--epochs\", type=int, default=30)\n    argparser.add_argument(\"--window_size\", type=int, default=2)\n    argparser.add_argument(\"--num_negative\", type=int, default=5)\n    argparser.add_argument(\"--lr\", type=float, default=0.001)\n    argparser.add_argument(\"--log_every\", type=int, default=100)\n\n    return argparser.parse_args()\n\n\ndef construct_graph(datapath, session_interval_gap_sec, valid_sku_raw_ids):\n    user_clicks, sku_encoder, sku_decoder = parse_actions(\n        datapath, valid_sku_raw_ids\n    )\n\n    # {src,dst: weight}\n    graph = {}\n    for user_id, action_list in user_clicks.items():\n        # sort by action time\n        _action_list = sorted(action_list, key=lambda x: x[1])\n\n        last_action_time = datetime.strptime(\n            _action_list[0][1], \"%Y-%m-%d %H:%M:%S\"\n        )\n        session = [_action_list[0][0]]\n        # cut sessions and add to graph\n        for sku_id, action_time in _action_list[1:]:\n            action_time = datetime.strptime(action_time, \"%Y-%m-%d %H:%M:%S\")\n            gap = action_time - last_action_time\n            if gap.seconds < session_interval_gap_sec:\n    
            session.append(sku_id)\n            else:\n                # here we have a new session\n                # add prev session to graph\n                add_session(session, graph)\n                # create a new session\n                session = [sku_id]\n        # add last session\n        add_session(session, graph)\n\n    g = convert_to_dgl_graph(graph)\n\n    return g, sku_encoder, sku_decoder\n\n\ndef convert_to_dgl_graph(graph):\n    # directed graph\n    g = nx.DiGraph()\n    for edge, weight in graph.items():\n        nodes = edge.split(\",\")\n        src, dst = int(nodes[0]), int(nodes[1])\n        g.add_edge(src, dst, weight=float(weight))\n\n    return dgl.from_networkx(g, edge_attrs=[\"weight\"])\n\n\ndef add_session(session, graph):\n    \"\"\"\n    For session like:\n        [sku1, sku2, sku3]\n    add 1 weight to each of the following edges:\n        sku1 -> sku2\n        sku2 -> sku3\n    If sesson length < 2, no nodes/edges will be added\n    \"\"\"\n    for i in range(len(session) - 1):\n        edge = str(session[i]) + \",\" + str(session[i + 1])\n        try:\n            graph[edge] += 1\n        except KeyError:\n            graph[edge] = 1\n\n\ndef parse_actions(datapath, valid_sku_raw_ids):\n    user_clicks = {}\n    with open(datapath, \"r\") as f:\n        f.readline()\n        # raw_id -> new_id and new_id -> raw_id\n        sku_encoder, sku_decoder = {}, []\n        sku_id = -1\n        for line in f:\n            line = line.replace(\"\\n\", \"\")\n            fields = line.split(\",\")\n            action_type = fields[-1]\n            # actually, all types in the dataset is \"1\"\n            if action_type == \"1\":\n                user_id = fields[0]\n                sku_raw_id = fields[1]\n                if sku_raw_id in valid_sku_raw_ids:\n                    action_time = fields[2]\n                    # encode sku_id\n                    sku_id = encode_id(\n                        sku_encoder, sku_decoder, 
sku_raw_id, sku_id\n                    )\n\n                    # add to user clicks\n                    try:\n                        user_clicks[user_id].append((sku_id, action_time))\n                    except KeyError:\n                        user_clicks[user_id] = [(sku_id, action_time)]\n\n    return user_clicks, sku_encoder, sku_decoder\n\n\ndef encode_id(encoder, decoder, raw_id, encoded_id):\n    if raw_id in encoder:\n        return encoded_id\n    else:\n        encoded_id += 1\n        encoder[raw_id] = encoded_id\n        decoder.append(raw_id)\n\n    return encoded_id\n\n\ndef get_valid_sku_set(datapath):\n    sku_ids = set()\n    with open(datapath, \"r\") as f:\n        for line in f.readlines():\n            line.replace(\"\\n\", \"\")\n            sku_raw_id = line.split(\",\")[0]\n            sku_ids.add(sku_raw_id)\n\n    return sku_ids\n\n\ndef encode_sku_fields(datapath, sku_encoder, sku_decoder):\n    # sku_id,brand,shop_id,cate,market_time\n    sku_info_encoder = {\"brand\": {}, \"shop\": {}, \"cate\": {}}\n    sku_info_decoder = {\"brand\": [], \"shop\": [], \"cate\": []}\n    sku_info = {}\n    brand_id, shop_id, cate_id = -1, -1, -1\n    with open(datapath, \"r\") as f:\n        f.readline()\n        for line in f:\n            line = line.replace(\"\\n\", \"\")\n            fields = line.split(\",\")\n            sku_raw_id = fields[0]\n\n            brand_raw_id = fields[1]\n            shop_raw_id = fields[2]\n            cate_raw_id = fields[3]\n\n            if sku_raw_id in sku_encoder:\n                sku_id = sku_encoder[sku_raw_id]\n\n                brand_id = encode_id(\n                    sku_info_encoder[\"brand\"],\n                    sku_info_decoder[\"brand\"],\n                    brand_raw_id,\n                    brand_id,\n                )\n\n                shop_id = encode_id(\n                    sku_info_encoder[\"shop\"],\n                    sku_info_decoder[\"shop\"],\n                    shop_raw_id,\n  
                  shop_id,\n                )\n\n                cate_id = encode_id(\n                    sku_info_encoder[\"cate\"],\n                    sku_info_decoder[\"cate\"],\n                    cate_raw_id,\n                    cate_id,\n                )\n\n                sku_info[sku_id] = [sku_id, brand_id, shop_id, cate_id]\n\n    return sku_info_encoder, sku_info_decoder, sku_info\n\n\nclass TestEdge:\n    def __init__(self, src, dst, label):\n        self.src = src\n        self.dst = dst\n        self.label = label\n\n\ndef split_train_test_graph(graph):\n    \"\"\"\n    For test true edges, 1/3 of the edges are randomly chosen\n    and removed as ground truth in the test set,\n    the remaining graph is taken as the training set.\n    \"\"\"\n    test_edges = []\n    neg_sampler = dgl.dataloading.negative_sampler.Uniform(1)\n    sampled_edge_ids = random.sample(\n        range(graph.num_edges()), int(graph.num_edges() / 3)\n    )\n    for edge_id in sampled_edge_ids:\n        src, dst = graph.find_edges(edge_id)\n        test_edges.append(TestEdge(src, dst, 1))\n\n        src, dst = neg_sampler(graph, th.tensor([edge_id]))\n        test_edges.append(TestEdge(src, dst, 0))\n\n    graph.remove_edges(sampled_edge_ids)\n    test_graph = test_edges\n\n    return graph, test_graph\n"
  },
  {
    "path": "examples/pytorch/evolveGCN/README.md",
    "content": "# Implement EvolveGCN with DGL\npaper link: [EvolveGCN](https://arxiv.org/abs/1902.10191)  \nofficial code: [IBM/EvolveGCN](https://github.com/IBM/EvolveGCN)  \nanother implementation: [pyG_temporal](https://github.com/benedekrozemberczki/pytorch_geometric_temporal/blob/master/torch_geometric_temporal/nn/recurrent/evolvegcno.py)  \n\n## Dependency:\n* dgl\n* pandas\n* numpy\n\n## Run\n* download the Elliptic dataset from [kaggle](https://kaggle.com/ellipticco/elliptic-data-set)\n* unzip the dataset into a raw directory, such as /home/Elliptic/elliptic_bitcoin_dataset/\n* make a new dir to save processed data, such as /home/Elliptic/processed/  \n* run train.py by:\n```bash\npython train.py --raw-dir /home/Elliptic/elliptic_bitcoin_dataset/ --processed-dir /home/Elliptic/processed/\n```\n\n## Result\nUsing EvolveGCN-O can match the results of Fig.3 and Fig.4 in the paper.\n(May need to run several times to get the average)\n\n\n## Attention:  \n* Currently only the Elliptic dataset is used.\n* EvolveGCN-H does not give stable results on the Elliptic dataset; the official code shows the same behavior.   \n\nResults of the official code when using EvolveGCN-H:  \n1. With the seed set to 1234, the final result is:\n> TEST epoch 189: TEST measures for class 1 - precision 0.3875 - recall 0.5714 - f1 0.4618  \n2. Without setting the seed manually, running the same code three times gives:\n> TEST epoch 168: TEST measures for class 1 - precision 0.3189 - recall 0.0680 - f1 0.1121  \n> TEST epoch 270: TEST measures for class 1 - precision 0.3517 - recall 0.3018 - f1 0.3249  \n> TEST epoch 455: TEST measures for class 1 - precision 0.2271 - recall 0.2995 - f1 0.2583  \n"
  },
  {
    "path": "examples/pytorch/evolveGCN/dataset.py",
    "content": "import os\n\nimport dgl\n\nimport numpy\nimport pandas\nimport torch\n\n\ndef process_raw_data(raw_dir, processed_dir):\n    r\"\"\"\n\n    Description\n    -----------\n    Preprocess Elliptic dataset like the EvolveGCN official instruction:\n    github.com/IBM/EvolveGCN/blob/master/elliptic_construction.md\n    The main purpose is to convert original idx to contiguous idx start at 0.\n    \"\"\"\n    oid_nid_path = os.path.join(processed_dir, \"oid_nid.npy\")\n    id_label_path = os.path.join(processed_dir, \"id_label.npy\")\n    id_time_features_path = os.path.join(processed_dir, \"id_time_features.npy\")\n    src_dst_time_path = os.path.join(processed_dir, \"src_dst_time.npy\")\n    if (\n        os.path.exists(oid_nid_path)\n        and os.path.exists(id_label_path)\n        and os.path.exists(id_time_features_path)\n        and os.path.exists(src_dst_time_path)\n    ):\n        print(\n            \"The preprocessed data already exists, skip the preprocess stage!\"\n        )\n        return\n    print(\"starting process raw data in {}\".format(raw_dir))\n    id_label = pandas.read_csv(\n        os.path.join(raw_dir, \"elliptic_txs_classes.csv\")\n    )\n    src_dst = pandas.read_csv(\n        os.path.join(raw_dir, \"elliptic_txs_edgelist.csv\")\n    )\n    # elliptic_txs_features.csv has no header, and it has the same order idx with elliptic_txs_classes.csv\n    id_time_features = pandas.read_csv(\n        os.path.join(raw_dir, \"elliptic_txs_features.csv\"), header=None\n    )\n\n    # get oldId_newId\n    oid_nid = id_label.loc[:, [\"txId\"]]\n    oid_nid = oid_nid.rename(columns={\"txId\": \"originalId\"})\n    oid_nid.insert(1, \"newId\", range(len(oid_nid)))\n\n    # map classes unknown,1,2 to -1,1,0 and construct id_label. 
type 1 means illicit.\n    id_label = pandas.concat(\n        [\n            oid_nid[\"newId\"],\n            id_label[\"class\"].map({\"unknown\": -1.0, \"1\": 1.0, \"2\": 0.0}),\n        ],\n        axis=1,\n    )\n\n    # replace originalId to newId.\n    # Attention: the timestamp in features start at 1.\n    id_time_features[0] = oid_nid[\"newId\"]\n\n    # construct originalId2newId dict\n    oid_nid_dict = oid_nid.set_index([\"originalId\"])[\"newId\"].to_dict()\n    # construct newId2timestamp dict\n    nid_time_dict = id_time_features.set_index([0])[1].to_dict()\n\n    # Map id in edgelist to newId, and add a timestamp to each edge.\n    # Attention: From the EvolveGCN official instruction, the timestamp with edgelist start at 0, rather than 1.\n    # see: github.com/IBM/EvolveGCN/blob/master/elliptic_construction.md\n    # Here we dose not follow the official instruction, which means timestamp with edgelist also start at 1.\n    # In EvolveGCN example, the edge timestamp will not be used.\n    #\n    # Note: in the dataset, src and dst node has the same timestamp, so it's easy to set edge's timestamp.\n    new_src = src_dst[\"txId1\"].map(oid_nid_dict).rename(\"newSrc\")\n    new_dst = src_dst[\"txId2\"].map(oid_nid_dict).rename(\"newDst\")\n    edge_time = new_src.map(nid_time_dict).rename(\"timestamp\")\n    src_dst_time = pandas.concat([new_src, new_dst, edge_time], axis=1)\n\n    # save oid_nid, id_label, id_time_features, src_dst_time to disk. we can convert them to numpy.\n    # oid_nid: type int.  id_label: type int.  id_time_features: type float.  
src_dst_time: type int.\n    oid_nid = oid_nid.to_numpy(dtype=int)\n    id_label = id_label.to_numpy(dtype=int)\n    id_time_features = id_time_features.to_numpy(dtype=float)\n    src_dst_time = src_dst_time.to_numpy(dtype=int)\n\n    numpy.save(oid_nid_path, oid_nid)\n    numpy.save(id_label_path, id_label)\n    numpy.save(id_time_features_path, id_time_features)\n    numpy.save(src_dst_time_path, src_dst_time)\n    print(\n        \"Process Elliptic raw data done, data has saved into {}\".format(\n            processed_dir\n        )\n    )\n\n\nclass EllipticDataset:\n    def __init__(\n        self, raw_dir, processed_dir, self_loop=True, reverse_edge=True\n    ):\n        self.raw_dir = raw_dir\n        self.processd_dir = processed_dir\n        self.self_loop = self_loop\n        self.reverse_edge = reverse_edge\n\n    def process(self):\n        process_raw_data(self.raw_dir, self.processd_dir)\n        id_time_features = torch.Tensor(\n            numpy.load(os.path.join(self.processd_dir, \"id_time_features.npy\"))\n        )\n        id_label = torch.IntTensor(\n            numpy.load(os.path.join(self.processd_dir, \"id_label.npy\"))\n        )\n        src_dst_time = torch.IntTensor(\n            numpy.load(os.path.join(self.processd_dir, \"src_dst_time.npy\"))\n        )\n\n        src = src_dst_time[:, 0]\n        dst = src_dst_time[:, 1]\n        # id_label[:, 0] is used to add self loop\n        if self.self_loop:\n            if self.reverse_edge:\n                g = dgl.graph(\n                    data=(\n                        torch.cat((src, dst, id_label[:, 0])),\n                        torch.cat((dst, src, id_label[:, 0])),\n                    ),\n                    num_nodes=id_label.shape[0],\n                )\n                g.edata[\"timestamp\"] = torch.cat(\n                    (\n                        src_dst_time[:, 2],\n                        src_dst_time[:, 2],\n                        id_time_features[:, 1].int(),\n        
            )\n                )\n            else:\n                g = dgl.graph(\n                    data=(\n                        torch.cat((src, id_label[:, 0])),\n                        torch.cat((dst, id_label[:, 0])),\n                    ),\n                    num_nodes=id_label.shape[0],\n                )\n                g.edata[\"timestamp\"] = torch.cat(\n                    (src_dst_time[:, 2], id_time_features[:, 1].int())\n                )\n        else:\n            if self.reverse_edge:\n                g = dgl.graph(\n                    data=(torch.cat((src, dst)), torch.cat((dst, src))),\n                    num_nodes=id_label.shape[0],\n                )\n                g.edata[\"timestamp\"] = torch.cat(\n                    (src_dst_time[:, 2], src_dst_time[:, 2])\n                )\n            else:\n                g = dgl.graph(data=(src, dst), num_nodes=id_label.shape[0])\n                g.edata[\"timestamp\"] = src_dst_time[:, 2]\n\n        time_features = id_time_features[:, 1:]\n        label = id_label[:, 1]\n        g.ndata[\"label\"] = label\n        g.ndata[\"feat\"] = time_features\n\n        # used to construct time-based sub-graph.\n        node_mask_by_time = []\n        start_time = int(torch.min(id_time_features[:, 1]))\n        end_time = int(torch.max(id_time_features[:, 1]))\n        for i in range(start_time, end_time + 1):\n            node_mask = id_time_features[:, 1] == i\n            node_mask_by_time.append(node_mask)\n\n        return g, node_mask_by_time\n\n    @property\n    def num_classes(self):\n        r\"\"\"Number of classes for each node.\"\"\"\n        return 2\n"
  },
  {
    "path": "examples/pytorch/evolveGCN/model.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom dgl.nn.pytorch import GraphConv\nfrom torch.nn import init\nfrom torch.nn.parameter import Parameter\n\n\nclass MatGRUCell(torch.nn.Module):\n    \"\"\"\n    GRU cell for matrix, similar to the official code.\n    Please refer to section 3.4 of the paper for the formula.\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats):\n        super().__init__()\n        self.update = MatGRUGate(in_feats, out_feats, torch.nn.Sigmoid())\n\n        self.reset = MatGRUGate(in_feats, out_feats, torch.nn.Sigmoid())\n\n        self.htilda = MatGRUGate(in_feats, out_feats, torch.nn.Tanh())\n\n    def forward(self, prev_Q, z_topk=None):\n        if z_topk is None:\n            z_topk = prev_Q\n\n        update = self.update(z_topk, prev_Q)\n        reset = self.reset(z_topk, prev_Q)\n\n        h_cap = reset * prev_Q\n        h_cap = self.htilda(z_topk, h_cap)\n\n        new_Q = (1 - update) * prev_Q + update * h_cap\n\n        return new_Q\n\n\nclass MatGRUGate(torch.nn.Module):\n    \"\"\"\n    GRU gate for matrix, similar to the official code.\n    Please refer to section 3.4 of the paper for the formula.\n    \"\"\"\n\n    def __init__(self, rows, cols, activation):\n        super().__init__()\n        self.activation = activation\n        self.W = Parameter(torch.Tensor(rows, rows))\n        self.U = Parameter(torch.Tensor(rows, rows))\n        self.bias = Parameter(torch.Tensor(rows, cols))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        init.xavier_uniform_(self.W)\n        init.xavier_uniform_(self.U)\n        init.zeros_(self.bias)\n\n    def forward(self, x, hidden):\n        out = self.activation(\n            self.W.matmul(x) + self.U.matmul(hidden) + self.bias\n        )\n\n        return out\n\n\nclass TopK(torch.nn.Module):\n    \"\"\"\n    Similar to the official `egcn_h.py`. 
We only consider the node in a timestamp based subgraph,\n    so we need to pay attention to `K` should be less than the min node numbers in all subgraph.\n    Please refer to section 3.4 of the paper for the formula.\n    \"\"\"\n\n    def __init__(self, feats, k):\n        super().__init__()\n        self.scorer = Parameter(torch.Tensor(feats, 1))\n        self.reset_parameters()\n\n        self.k = k\n\n    def reset_parameters(self):\n        init.xavier_uniform_(self.scorer)\n\n    def forward(self, node_embs):\n        scores = node_embs.matmul(self.scorer) / self.scorer.norm().clamp(\n            min=1e-6\n        )\n        vals, topk_indices = scores.view(-1).topk(self.k)\n        out = node_embs[topk_indices] * torch.tanh(\n            scores[topk_indices].view(-1, 1)\n        )\n        # we need to transpose the output\n        return out.t()\n\n\nclass EvolveGCNH(nn.Module):\n    def __init__(\n        self,\n        in_feats=166,\n        n_hidden=76,\n        num_layers=2,\n        n_classes=2,\n        classifier_hidden=510,\n    ):\n        # default parameters follow the official config\n        super(EvolveGCNH, self).__init__()\n        self.num_layers = num_layers\n        self.pooling_layers = nn.ModuleList()\n        self.recurrent_layers = nn.ModuleList()\n        self.gnn_convs = nn.ModuleList()\n        self.gcn_weights_list = nn.ParameterList()\n\n        self.pooling_layers.append(TopK(in_feats, n_hidden))\n        # similar to EvolveGCNO\n        self.recurrent_layers.append(\n            MatGRUCell(in_feats=in_feats, out_feats=n_hidden)\n        )\n        self.gcn_weights_list.append(\n            Parameter(torch.Tensor(in_feats, n_hidden))\n        )\n        self.gnn_convs.append(\n            GraphConv(\n                in_feats=in_feats,\n                out_feats=n_hidden,\n                bias=False,\n                activation=nn.RReLU(),\n                weight=False,\n            )\n        )\n        for _ in 
range(num_layers - 1):\n            self.pooling_layers.append(TopK(n_hidden, n_hidden))\n            self.recurrent_layers.append(\n                MatGRUCell(in_feats=n_hidden, out_feats=n_hidden)\n            )\n            self.gcn_weights_list.append(\n                Parameter(torch.Tensor(n_hidden, n_hidden))\n            )\n            self.gnn_convs.append(\n                GraphConv(\n                    in_feats=n_hidden,\n                    out_feats=n_hidden,\n                    bias=False,\n                    activation=nn.RReLU(),\n                    weight=False,\n                )\n            )\n\n        self.mlp = nn.Sequential(\n            nn.Linear(n_hidden, classifier_hidden),\n            nn.ReLU(),\n            nn.Linear(classifier_hidden, n_classes),\n        )\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for gcn_weight in self.gcn_weights_list:\n            init.xavier_uniform_(gcn_weight)\n\n    def forward(self, g_list):\n        feature_list = []\n        for g in g_list:\n            feature_list.append(g.ndata[\"feat\"])\n        for i in range(self.num_layers):\n            W = self.gcn_weights_list[i]\n            for j, g in enumerate(g_list):\n                X_tilde = self.pooling_layers[i](feature_list[j])\n                W = self.recurrent_layers[i](W, X_tilde)\n                feature_list[j] = self.gnn_convs[i](\n                    g, feature_list[j], weight=W\n                )\n        return self.mlp(feature_list[-1])\n\n\nclass EvolveGCNO(nn.Module):\n    def __init__(\n        self,\n        in_feats=166,\n        n_hidden=256,\n        num_layers=2,\n        n_classes=2,\n        classifier_hidden=307,\n    ):\n        # default parameters follow the official config\n        super(EvolveGCNO, self).__init__()\n        self.num_layers = num_layers\n        self.recurrent_layers = nn.ModuleList()\n        self.gnn_convs = nn.ModuleList()\n        self.gcn_weights_list = 
nn.ParameterList()\n\n        # In the paper, EvolveGCN-O use LSTM as RNN layer. According to the official code,\n        # EvolveGCN-O use GRU as RNN layer. Here we follow the official code.\n        # See: https://github.com/IBM/EvolveGCN/blob/90869062bbc98d56935e3d92e1d9b1b4c25be593/egcn_o.py#L53\n        # PS: I try to use torch.nn.LSTM directly,\n        #     like [pyg_temporal](github.com/benedekrozemberczki/pytorch_geometric_temporal/blob/master/torch_geometric_temporal/nn/recurrent/evolvegcno.py)\n        #     but the performance is worse than use torch.nn.GRU.\n        # PPS: I think torch.nn.GRU can't match the manually implemented GRU cell in the official code,\n        #      we follow the official code here.\n        self.recurrent_layers.append(\n            MatGRUCell(in_feats=in_feats, out_feats=n_hidden)\n        )\n        self.gcn_weights_list.append(\n            Parameter(torch.Tensor(in_feats, n_hidden))\n        )\n        self.gnn_convs.append(\n            GraphConv(\n                in_feats=in_feats,\n                out_feats=n_hidden,\n                bias=False,\n                activation=nn.RReLU(),\n                weight=False,\n            )\n        )\n        for _ in range(num_layers - 1):\n            self.recurrent_layers.append(\n                MatGRUCell(in_feats=n_hidden, out_feats=n_hidden)\n            )\n            self.gcn_weights_list.append(\n                Parameter(torch.Tensor(n_hidden, n_hidden))\n            )\n            self.gnn_convs.append(\n                GraphConv(\n                    in_feats=n_hidden,\n                    out_feats=n_hidden,\n                    bias=False,\n                    activation=nn.RReLU(),\n                    weight=False,\n                )\n            )\n\n        self.mlp = nn.Sequential(\n            nn.Linear(n_hidden, classifier_hidden),\n            nn.ReLU(),\n            nn.Linear(classifier_hidden, n_classes),\n        )\n        self.reset_parameters()\n\n 
   def reset_parameters(self):\n        for gcn_weight in self.gcn_weights_list:\n            init.xavier_uniform_(gcn_weight)\n\n    def forward(self, g_list):\n        feature_list = []\n        for g in g_list:\n            feature_list.append(g.ndata[\"feat\"])\n        for i in range(self.num_layers):\n            W = self.gcn_weights_list[i]\n            for j, g in enumerate(g_list):\n                # Attention: I try to use the below code to set gcn.weight(similar to pyG_temporal),\n                # but it doesn't work. It seems that the gradient function lost in this situation,\n                # more discussion see here: https://github.com/benedekrozemberczki/pytorch_geometric_temporal/issues/80\n                # ====================================================\n                # W = self.gnn_convs[i].weight[None, :, :]\n                # W, _ = self.recurrent_layers[i](W)\n                # self.gnn_convs[i].weight = nn.Parameter(W.squeeze())\n                # ====================================================\n\n                # Remove the following line of code, it will become `GCN`.\n                W = self.recurrent_layers[i](W)\n                feature_list[j] = self.gnn_convs[i](\n                    g, feature_list[j], weight=W\n                )\n        return self.mlp(feature_list[-1])\n"
  },
  {
    "path": "examples/pytorch/evolveGCN/train.py",
    "content": "import argparse\nimport time\n\nimport dgl\n\nimport torch\nimport torch.nn.functional as F\nfrom dataset import EllipticDataset\nfrom model import EvolveGCNH, EvolveGCNO\nfrom utils import Measure\n\n\ndef train(args, device):\n    elliptic_dataset = EllipticDataset(\n        raw_dir=args.raw_dir,\n        processed_dir=args.processed_dir,\n        self_loop=True,\n        reverse_edge=True,\n    )\n\n    g, node_mask_by_time = elliptic_dataset.process()\n    num_classes = elliptic_dataset.num_classes\n\n    cached_subgraph = []\n    cached_labeled_node_mask = []\n    for i in range(len(node_mask_by_time)):\n        # we add self loop edge when we construct full graph, not here\n        node_subgraph = dgl.node_subgraph(graph=g, nodes=node_mask_by_time[i])\n        cached_subgraph.append(node_subgraph.to(device))\n        valid_node_mask = node_subgraph.ndata[\"label\"] >= 0\n        cached_labeled_node_mask.append(valid_node_mask)\n\n    if args.model == \"EvolveGCN-O\":\n        model = EvolveGCNO(\n            in_feats=int(g.ndata[\"feat\"].shape[1]),\n            n_hidden=args.n_hidden,\n            num_layers=args.n_layers,\n        )\n    elif args.model == \"EvolveGCN-H\":\n        model = EvolveGCNH(\n            in_feats=int(g.ndata[\"feat\"].shape[1]), num_layers=args.n_layers\n        )\n    else:\n        return NotImplementedError(\"Unsupported model {}\".format(args.model))\n    model = model.to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n\n    # split train, valid, test(0-30,31-35,36-48)\n    # train/valid/test split follow the paper.\n    train_max_index = 30\n    valid_max_index = 35\n    test_max_index = 48\n    time_window_size = args.n_hist_steps\n    loss_class_weight = [float(w) for w in args.loss_class_weight.split(\",\")]\n    loss_class_weight = torch.Tensor(loss_class_weight).to(device)\n\n    train_measure = Measure(\n        num_classes=num_classes, target_class=args.eval_class_id\n    )\n  
  valid_measure = Measure(\n        num_classes=num_classes, target_class=args.eval_class_id\n    )\n    test_measure = Measure(\n        num_classes=num_classes, target_class=args.eval_class_id\n    )\n\n    test_res_f1 = 0\n    for epoch in range(args.num_epochs):\n        model.train()\n        for i in range(time_window_size, train_max_index + 1):\n            g_list = cached_subgraph[i - time_window_size : i + 1]\n            predictions = model(g_list)\n            # get predictions which has label\n            predictions = predictions[cached_labeled_node_mask[i]]\n            labels = (\n                cached_subgraph[i]\n                .ndata[\"label\"][cached_labeled_node_mask[i]]\n                .long()\n            )\n            loss = F.cross_entropy(\n                predictions, labels, weight=loss_class_weight\n            )\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            train_measure.append_measures(predictions, labels)\n\n        # get each epoch measures during training.\n        cl_precision, cl_recall, cl_f1 = train_measure.get_total_measure()\n        train_measure.update_best_f1(cl_f1, epoch)\n        # reset measures for next epoch\n        train_measure.reset_info()\n\n        print(\n            \"Train Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}\".format(\n                epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1\n            )\n        )\n\n        # eval\n        model.eval()\n        for i in range(train_max_index + 1, valid_max_index + 1):\n            g_list = cached_subgraph[i - time_window_size : i + 1]\n            predictions = model(g_list)\n            # get node predictions which has label\n            predictions = predictions[cached_labeled_node_mask[i]]\n            labels = (\n                cached_subgraph[i]\n                .ndata[\"label\"][cached_labeled_node_mask[i]]\n                .long()\n            )\n\n  
          valid_measure.append_measures(predictions, labels)\n\n        # get each epoch measure during eval.\n        cl_precision, cl_recall, cl_f1 = valid_measure.get_total_measure()\n        valid_measure.update_best_f1(cl_f1, epoch)\n        # reset measures for next epoch\n        valid_measure.reset_info()\n\n        print(\n            \"Eval Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}\".format(\n                epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1\n            )\n        )\n\n        # early stop\n        if epoch - valid_measure.target_best_f1_epoch >= args.patience:\n            print(\n                \"Best eval Epoch {}, Cur Epoch {}\".format(\n                    valid_measure.target_best_f1_epoch, epoch\n                )\n            )\n            break\n        # if cur valid f1 score is best, do test\n        if epoch == valid_measure.target_best_f1_epoch:\n            print(\n                \"###################Epoch {} Test###################\".format(\n                    epoch\n                )\n            )\n            for i in range(valid_max_index + 1, test_max_index + 1):\n                g_list = cached_subgraph[i - time_window_size : i + 1]\n                predictions = model(g_list)\n                # get predictions which has label\n                predictions = predictions[cached_labeled_node_mask[i]]\n                labels = (\n                    cached_subgraph[i]\n                    .ndata[\"label\"][cached_labeled_node_mask[i]]\n                    .long()\n                )\n\n                test_measure.append_measures(predictions, labels)\n\n            # we get each subgraph measure when testing to match fig 4 in EvolveGCN paper.\n            (\n                cl_precisions,\n                cl_recalls,\n                cl_f1s,\n            ) = test_measure.get_each_timestamp_measure()\n            for index, (sub_p, sub_r, sub_f1) in enumerate(\n                
zip(cl_precisions, cl_recalls, cl_f1s)\n            ):\n                print(\n                    \"  Test | Time {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}\".format(\n                        valid_max_index + index + 2, sub_p, sub_r, sub_f1\n                    )\n                )\n\n            # get each epoch measure during test.\n            cl_precision, cl_recall, cl_f1 = test_measure.get_total_measure()\n            test_measure.update_best_f1(cl_f1, epoch)\n            # reset measures for next test\n            test_measure.reset_info()\n\n            test_res_f1 = cl_f1\n\n            print(\n                \"  Test | Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}\".format(\n                    epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1\n                )\n            )\n\n    print(\n        \"Best test f1 is {}, in Epoch {}\".format(\n            test_measure.target_best_f1, test_measure.target_best_f1_epoch\n        )\n    )\n    if test_measure.target_best_f1_epoch != valid_measure.target_best_f1_epoch:\n        print(\n            \"The Epoch get best Valid measure not get the best Test measure, \"\n            \"please checkout the test result in Epoch {}, which f1 is {}\".format(\n                valid_measure.target_best_f1_epoch, test_res_f1\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"EvolveGCN\")\n    argparser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"EvolveGCN-O\",\n        help=\"We can choose EvolveGCN-O or EvolveGCN-H,\"\n        \"but the EvolveGCN-H performance on Elliptic dataset is not good.\",\n    )\n    argparser.add_argument(\n        \"--raw-dir\",\n        type=str,\n        default=\"/home/Elliptic/elliptic_bitcoin_dataset/\",\n        help=\"Dir after unzip downloaded dataset, which contains 3 csv files.\",\n    )\n    argparser.add_argument(\n        \"--processed-dir\",\n        
type=str,\n        default=\"/home/Elliptic/processed/\",\n        help=\"Dir to store processed raw data.\",\n    )\n    argparser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=0,\n        help=\"GPU device ID. Use -1 for CPU training.\",\n    )\n    argparser.add_argument(\"--num-epochs\", type=int, default=1000)\n    argparser.add_argument(\"--n-hidden\", type=int, default=256)\n    argparser.add_argument(\"--n-layers\", type=int, default=2)\n    argparser.add_argument(\n        \"--n-hist-steps\",\n        type=int,\n        default=5,\n        help=\"If it is set to 5, it means in the first batch,\"\n        \"we use historical data of 0-4 to predict the data of time 5.\",\n    )\n    argparser.add_argument(\"--lr\", type=float, default=0.001)\n    argparser.add_argument(\n        \"--loss-class-weight\",\n        type=str,\n        default=\"0.35,0.65\",\n        help=\"Weight for loss function. Follow the official code,\"\n        \"we need to change it to 0.25, 0.75 when use EvolveGCN-H\",\n    )\n    argparser.add_argument(\n        \"--eval-class-id\",\n        type=int,\n        default=1,\n        help=\"Class type to eval. On Elliptic, type 1(illicit) is the main interest.\",\n    )\n    argparser.add_argument(\n        \"--patience\", type=int, default=100, help=\"Patience for early stopping.\"\n    )\n\n    args = argparser.parse_args()\n\n    if args.gpu >= 0:\n        device = torch.device(\"cuda:%d\" % args.gpu)\n    else:\n        device = torch.device(\"cpu\")\n\n    start_time = time.perf_counter()\n    train(args, device)\n    print(\"train time is: {}\".format(time.perf_counter() - start_time))\n"
  },
  {
    "path": "examples/pytorch/evolveGCN/utils.py",
    "content": "def calculate_measure(tp, fn, fp):\n    # avoid nan\n    if tp == 0:\n        return 0, 0, 0\n\n    p = tp * 1.0 / (tp + fp)\n    r = tp * 1.0 / (tp + fn)\n    if (p + r) > 0:\n        f1 = 2.0 * (p * r) / (p + r)\n    else:\n        f1 = 0\n    return p, r, f1\n\n\nclass Measure(object):\n    def __init__(self, num_classes, target_class):\n        \"\"\"\n\n        Args:\n            num_classes: number of classes.\n            target_class: target class we focus on, used to print info and do early stopping.\n        \"\"\"\n        self.num_classes = num_classes\n        self.target_class = target_class\n        self.true_positives = {}\n        self.false_positives = {}\n        self.false_negatives = {}\n        self.target_best_f1 = 0.0\n        self.target_best_f1_epoch = 0\n        self.reset_info()\n\n    def reset_info(self):\n        \"\"\"\n        reset info after each epoch.\n        \"\"\"\n        self.true_positives = {\n            cur_class: [] for cur_class in range(self.num_classes)\n        }\n        self.false_positives = {\n            cur_class: [] for cur_class in range(self.num_classes)\n        }\n        self.false_negatives = {\n            cur_class: [] for cur_class in range(self.num_classes)\n        }\n\n    def append_measures(self, predictions, labels):\n        predicted_classes = predictions.argmax(dim=1)\n        for cl in range(self.num_classes):\n            cl_indices = labels == cl\n            pos = predicted_classes == cl\n            hits = predicted_classes[cl_indices] == labels[cl_indices]\n\n            tp = hits.sum()\n            fn = hits.size(0) - tp\n            fp = pos.sum() - tp\n\n            self.true_positives[cl].append(tp.cpu())\n            self.false_negatives[cl].append(fn.cpu())\n            self.false_positives[cl].append(fp.cpu())\n\n    def get_each_timestamp_measure(self):\n        precisions = []\n        recalls = []\n        f1s = []\n        for i in 
range(len(self.true_positives[self.target_class])):\n            tp = self.true_positives[self.target_class][i]\n            fn = self.false_negatives[self.target_class][i]\n            fp = self.false_positives[self.target_class][i]\n\n            p, r, f1 = calculate_measure(tp, fn, fp)\n            precisions.append(p)\n            recalls.append(r)\n            f1s.append(f1)\n        return precisions, recalls, f1s\n\n    def get_total_measure(self):\n        tp = sum(self.true_positives[self.target_class])\n        fn = sum(self.false_negatives[self.target_class])\n        fp = sum(self.false_positives[self.target_class])\n\n        p, r, f1 = calculate_measure(tp, fn, fp)\n        return p, r, f1\n\n    def update_best_f1(self, cur_f1, cur_epoch):\n        if cur_f1 > self.target_best_f1:\n            self.target_best_f1 = cur_f1\n            self.target_best_f1_epoch = cur_epoch\n"
  },
  {
    "path": "examples/pytorch/gas/README.md",
    "content": "# DGL Implementation of the GAS Paper\n\nThis DGL example implements the Heterogeneous GCN part of the model proposed in the paper [Spam Review Detection with Graph Convolutional Networks](https://arxiv.org/abs/1908.10679).\n\nExample implementor\n----------------------\nThis example was implemented by [Kay Liu](https://github.com/kayzliu) during his SDE intern work at the AWS Shanghai AI Lab.\n\nDependencies\n----------------------\n- Python 3.7.10\n- PyTorch 1.8.1\n- dgl 0.7.0\n- scikit-learn 0.23.2\n\nDataset\n---------------------------------------\nThe datasets used for edge classification are variants of DGL's built-in [fake news datasets](https://github.com/dmlc/dgl/blob/master/python/dgl/data/fakenews.py). The converting process from tree-structured graph to bipartite graph is shown in the figure. \n\n![variant](variant.png)\n\n**NOTE**: Same as the original fake news dataset, this variant is for academic use only as well, and commercial use is prohibited. The statistics are summarized as followings:\n\n**Politifact**\n\n- Nodes:\n    - user (u): 276,277\n    - news (v): 581\n- Edges:\n    - forward: 399,016\n    - backward: 399,016\n- Number of Classes: 2\n- Node feature size: 300\n- Edge feature size: 300\n\n**Gossicop** \n\n- Nodes:\n    - user (u): 565,660\n    - news (v): 10,333\n- Edges:\n    - forward: 1,254,469\n    - backward: 1,254,469\n- Number of Classes: 2\n- Node feature size: 300\n- Edge feature size: 300\n\nHow to run\n--------------------------------\nIn the gas folder, run\n```\npython main.py\n```\n\nIf want to use a GPU, run\n```\npython main.py --gpu 0\n```\n\nIf the mini-batch training is required to run on a GPU, run\n```\npython main_sampling.py --gpu 0\n```\n\nPerformance\n-------------------------\n|Dataset               | Xianyu Graph (paper reported) | Fake News Politifact | Fake News Gossipcop |\n| -------------------- | ----------------- | -------------------- | ------------------- |\n| F1                   | 
0.8143            | 0.9994               | 0.9942              |\n| AUC                  | 0.9860            | 1.0000               | 0.9991              |\n| Recall@90% precision | 0.6702            | 0.9999               | 0.9976              |"
  },
  {
    "path": "examples/pytorch/gas/dataloader.py",
    "content": "import os\n\nimport dgl\n\nimport numpy as np\nimport scipy.io as sio\nimport torch as th\nfrom dgl.data import DGLBuiltinDataset\nfrom dgl.data.utils import _get_dgl_url, load_graphs, save_graphs\n\n\nclass GASDataset(DGLBuiltinDataset):\n    file_urls = {\"pol\": \"dataset/GASPOL.zip\", \"gos\": \"dataset/GASGOS.zip\"}\n\n    def __init__(\n        self, name, raw_dir=None, random_seed=717, train_size=0.7, val_size=0.1\n    ):\n        assert name in [\"gos\", \"pol\"], \"Only supports 'gos' or 'pol'.\"\n        self.seed = random_seed\n        self.train_size = train_size\n        self.val_size = val_size\n        url = _get_dgl_url(self.file_urls[name])\n        super(GASDataset, self).__init__(name=name, url=url, raw_dir=raw_dir)\n\n    def process(self):\n        \"\"\"process raw data to graph, labels and masks\"\"\"\n        data = sio.loadmat(\n            os.path.join(self.raw_path, f\"{self.name}_retweet_graph.mat\")\n        )\n\n        adj = data[\"graph\"].tocoo()\n        num_edges = len(adj.row)\n        row, col = adj.row[: int(num_edges / 2)], adj.col[: int(num_edges / 2)]\n\n        graph = dgl.graph(\n            (np.concatenate((row, col)), np.concatenate((col, row)))\n        )\n        news_labels = data[\"label\"].squeeze()\n        num_news = len(news_labels)\n\n        node_feature = np.load(\n            os.path.join(self.raw_path, f\"{self.name}_node_feature.npy\")\n        )\n        edge_feature = np.load(\n            os.path.join(self.raw_path, f\"{self.name}_edge_feature.npy\")\n        )[: int(num_edges / 2)]\n\n        graph.ndata[\"feat\"] = th.tensor(node_feature)\n        graph.edata[\"feat\"] = th.tensor(np.tile(edge_feature, (2, 1)))\n        pos_news = news_labels.nonzero()[0]\n\n        edge_labels = th.zeros(num_edges)\n        edge_labels[graph.in_edges(pos_news, form=\"eid\")] = 1\n        edge_labels[graph.out_edges(pos_news, form=\"eid\")] = 1\n        graph.edata[\"label\"] = edge_labels\n\n        
ntypes = th.ones(graph.num_nodes(), dtype=int)\n        etypes = th.ones(graph.num_edges(), dtype=int)\n\n        ntypes[graph.nodes() < num_news] = 0\n        etypes[: int(num_edges / 2)] = 0\n\n        graph.ndata[\"_TYPE\"] = ntypes\n        graph.edata[\"_TYPE\"] = etypes\n\n        hg = dgl.to_heterogeneous(graph, [\"v\", \"u\"], [\"forward\", \"backward\"])\n        self._random_split(hg, self.seed, self.train_size, self.val_size)\n\n        self.graph = hg\n\n    @property\n    def graph_path(self):\n        return os.path.join(self.save_path, self.name + \"_dgl_graph.bin\")\n\n    def save(self):\n        \"\"\"save the graph list and the labels\"\"\"\n        save_graphs(str(self.graph_path), self.graph)\n\n    def has_cache(self):\n        \"\"\"check whether there are processed data in `self.save_path`\"\"\"\n        return os.path.exists(self.graph_path)\n\n    def load(self):\n        \"\"\"load processed data from directory `self.save_path`\"\"\"\n        graph, _ = load_graphs(str(self.graph_path))\n        self.graph = graph[0]\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes for each graph, i.e. 
number of prediction tasks.\"\"\"\n        return 2\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph object\n        Parameters\n        ----------\n        idx : int\n            Item index\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n        \"\"\"\n        assert idx == 0, \"This dataset has only one graph\"\n        return self.graph\n\n    def __len__(self):\n        r\"\"\"Number of data examples\n        Return\n        -------\n        int\n        \"\"\"\n        return len(self.graph)\n\n    def _random_split(self, graph, seed=717, train_size=0.7, val_size=0.1):\n        \"\"\"split the dataset into training set, validation set and testing set\"\"\"\n\n        assert 0 <= train_size + val_size <= 1, (\n            \"The sum of valid training set size and validation set size \"\n            \"must between 0 and 1 (inclusive).\"\n        )\n\n        num_edges = graph.num_edges(etype=\"forward\")\n        index = np.arange(num_edges)\n\n        index = np.random.RandomState(seed).permutation(index)\n        train_idx = index[: int(train_size * num_edges)]\n        val_idx = index[num_edges - int(val_size * num_edges) :]\n        test_idx = index[\n            int(train_size * num_edges) : num_edges - int(val_size * num_edges)\n        ]\n        train_mask = np.zeros(num_edges, dtype=np.bool_)\n        val_mask = np.zeros(num_edges, dtype=np.bool_)\n        test_mask = np.zeros(num_edges, dtype=np.bool_)\n        train_mask[train_idx] = True\n        val_mask[val_idx] = True\n        test_mask[test_idx] = True\n        graph.edges[\"forward\"].data[\"train_mask\"] = th.tensor(train_mask)\n        graph.edges[\"forward\"].data[\"val_mask\"] = th.tensor(val_mask)\n        graph.edges[\"forward\"].data[\"test_mask\"] = th.tensor(test_mask)\n        graph.edges[\"backward\"].data[\"train_mask\"] = th.tensor(train_mask)\n        graph.edges[\"backward\"].data[\"val_mask\"] = th.tensor(val_mask)\n        
graph.edges[\"backward\"].data[\"test_mask\"] = th.tensor(test_mask)\n"
  },
  {
    "path": "examples/pytorch/gas/main.py",
    "content": "import argparse\n\nimport torch as th\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dataloader import GASDataset\nfrom model import GAS\nfrom sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # Load dataset\n    dataset = GASDataset(args.dataset)\n    graph = dataset[0]\n\n    # check cuda\n    if args.gpu >= 0 and th.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n    else:\n        device = \"cpu\"\n\n    # binary classification\n    num_classes = dataset.num_classes\n\n    # retrieve labels of ground truth\n    labels = graph.edges[\"forward\"].data[\"label\"].to(device).long()\n\n    # Extract node features\n    e_feat = graph.edges[\"forward\"].data[\"feat\"].to(device)\n    u_feat = graph.nodes[\"u\"].data[\"feat\"].to(device)\n    v_feat = graph.nodes[\"v\"].data[\"feat\"].to(device)\n\n    # retrieve masks for train/validation/test\n    train_mask = graph.edges[\"forward\"].data[\"train_mask\"]\n    val_mask = graph.edges[\"forward\"].data[\"val_mask\"]\n    test_mask = graph.edges[\"forward\"].data[\"test_mask\"]\n\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)\n    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)\n\n    graph = graph.to(device)\n\n    # Step 2: Create model =================================================================== #\n    model = GAS(\n        e_in_dim=e_feat.shape[-1],\n        u_in_dim=u_feat.shape[-1],\n        v_in_dim=v_feat.shape[-1],\n        e_hid_dim=args.e_hid_dim,\n        u_hid_dim=args.u_hid_dim,\n        v_hid_dim=args.v_hid_dim,\n        out_dim=num_classes,\n        num_layers=args.num_layers,\n        dropout=args.dropout,\n        activation=F.relu,\n    )\n\n    model = 
model.to(device)\n\n    # Step 3: Create training components ===================================================== #\n    loss_fn = th.nn.CrossEntropyLoss()\n    optimizer = optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # Step 4: training epochs =============================================================== #\n    for epoch in range(args.max_epoch):\n        # Training and validation using a full graph\n        model.train()\n        logits = model(graph, e_feat, u_feat, v_feat)\n\n        # compute loss\n        tr_loss = loss_fn(logits[train_idx], labels[train_idx])\n        tr_f1 = f1_score(\n            labels[train_idx].cpu(), logits[train_idx].argmax(dim=1).cpu()\n        )\n        tr_auc = roc_auc_score(\n            labels[train_idx].cpu(), logits[train_idx][:, 1].detach().cpu()\n        )\n        tr_pre, tr_re, _ = precision_recall_curve(\n            labels[train_idx].cpu(), logits[train_idx][:, 1].detach().cpu()\n        )\n        tr_rap = tr_re[tr_pre > args.precision].max()\n\n        # validation\n        valid_loss = loss_fn(logits[val_idx], labels[val_idx])\n        valid_f1 = f1_score(\n            labels[val_idx].cpu(), logits[val_idx].argmax(dim=1).cpu()\n        )\n        valid_auc = roc_auc_score(\n            labels[val_idx].cpu(), logits[val_idx][:, 1].detach().cpu()\n        )\n        valid_pre, valid_re, _ = precision_recall_curve(\n            labels[val_idx].cpu(), logits[val_idx][:, 1].detach().cpu()\n        )\n        valid_rap = valid_re[valid_pre > args.precision].max()\n\n        # backward\n        optimizer.zero_grad()\n        tr_loss.backward()\n        optimizer.step()\n\n        # Print out performance\n        print(\n            \"In epoch {}, Train R@P: {:.4f} | Train F1: {:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; \"\n            \"Valid R@P: {:.4f} | Valid F1: {:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}\".format(\n                epoch,\n                
tr_rap,\n                tr_f1,\n                tr_auc,\n                tr_loss.item(),\n                valid_rap,\n                valid_f1,\n                valid_auc,\n                valid_loss.item(),\n            )\n        )\n\n    # Test after all epoch\n    model.eval()\n\n    # forward\n    logits = model(graph, e_feat, u_feat, v_feat)\n\n    # compute loss\n    test_loss = loss_fn(logits[test_idx], labels[test_idx])\n    test_f1 = f1_score(\n        labels[test_idx].cpu(), logits[test_idx].argmax(dim=1).cpu()\n    )\n    test_auc = roc_auc_score(\n        labels[test_idx].cpu(), logits[test_idx][:, 1].detach().cpu()\n    )\n    test_pre, test_re, _ = precision_recall_curve(\n        labels[test_idx].cpu(), logits[test_idx][:, 1].detach().cpu()\n    )\n    test_rap = test_re[test_pre > args.precision].max()\n\n    print(\n        \"Test R@P: {:.4f} | Test F1: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}\".format(\n            test_rap, test_f1, test_auc, test_loss.item()\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GCN-based Anti-Spam Model\")\n    parser.add_argument(\n        \"--dataset\", type=str, default=\"pol\", help=\"'pol', or 'gos'\"\n    )\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU Index. 
Default: -1, using CPU.\"\n    )\n    parser.add_argument(\n        \"--e_hid_dim\",\n        type=int,\n        default=128,\n        help=\"Hidden layer dimension for edges\",\n    )\n    parser.add_argument(\n        \"--u_hid_dim\",\n        type=int,\n        default=128,\n        help=\"Hidden layer dimension for source nodes\",\n    )\n    parser.add_argument(\n        \"--v_hid_dim\",\n        type=int,\n        default=128,\n        help=\"Hidden layer dimension for destination nodes\",\n    )\n    parser.add_argument(\n        \"--num_layers\", type=int, default=2, help=\"Number of GCN layers\"\n    )\n    parser.add_argument(\n        \"--max_epoch\",\n        type=int,\n        default=100,\n        help=\"The max number of epochs. Default: 100\",\n    )\n    parser.add_argument(\n        \"--lr\", type=float, default=0.001, help=\"Learning rate. Default: 1e-3\"\n    )\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.0, help=\"Dropout rate. Default: 0.0\"\n    )\n    parser.add_argument(\n        \"--weight_decay\",\n        type=float,\n        default=5e-4,\n        help=\"Weight Decay. Default: 0.0005\",\n    )\n    parser.add_argument(\n        \"--precision\",\n        type=float,\n        default=0.9,\n        help=\"The value p in recall@p precision. Default: 0.9\",\n    )\n\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/gas/main_sampling.py",
    "content": "import argparse\n\nimport dgl\nimport torch as th\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dataloader import GASDataset\nfrom model_sampling import GAS\nfrom sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score\n\n\ndef evaluate(model, loss_fn, dataloader, device=\"cpu\"):\n    loss = 0\n    f1 = 0\n    auc = 0\n    rap = 0\n    num_blocks = 0\n    for input_nodes, edge_subgraph, blocks in dataloader:\n        blocks = [b.to(device) for b in blocks]\n        edge_subgraph = edge_subgraph.to(device)\n        u_feat = blocks[0].srcdata[\"feat\"][\"u\"]\n        v_feat = blocks[0].srcdata[\"feat\"][\"v\"]\n        f_feat = blocks[0].edges[\"forward\"].data[\"feat\"]\n        b_feat = blocks[0].edges[\"backward\"].data[\"feat\"]\n        labels = edge_subgraph.edges[\"forward\"].data[\"label\"].long()\n        logits = model(edge_subgraph, blocks, f_feat, b_feat, u_feat, v_feat)\n\n        loss += loss_fn(logits, labels).item()\n        f1 += f1_score(labels.cpu(), logits.argmax(dim=1).cpu())\n        auc += roc_auc_score(labels.cpu(), logits[:, 1].detach().cpu())\n        pre, re, _ = precision_recall_curve(\n            labels.cpu(), logits[:, 1].detach().cpu()\n        )\n        rap += re[pre > args.precision].max()\n        num_blocks += 1\n\n    return (\n        rap / num_blocks,\n        f1 / num_blocks,\n        auc / num_blocks,\n        loss / num_blocks,\n    )\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # Load dataset\n    dataset = GASDataset(args.dataset)\n    graph = dataset[0]\n\n    # generate mini-batch only for forward edges\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10])\n    tr_eid_dict = {}\n    val_eid_dict = {}\n    test_eid_dict = {}\n    tr_eid_dict[\"forward\"] = (\n        graph.edges[\"forward\"].data[\"train_mask\"].nonzero().squeeze()\n    )\n    val_eid_dict[\"forward\"] 
= (\n        graph.edges[\"forward\"].data[\"val_mask\"].nonzero().squeeze()\n    )\n    test_eid_dict[\"forward\"] = (\n        graph.edges[\"forward\"].data[\"test_mask\"].nonzero().squeeze()\n    )\n\n    sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)\n    tr_loader = dgl.dataloading.DataLoader(\n        graph,\n        tr_eid_dict,\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\n    val_loader = dgl.dataloading.DataLoader(\n        graph,\n        val_eid_dict,\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\n    test_loader = dgl.dataloading.DataLoader(\n        graph,\n        test_eid_dict,\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\n\n    # check cuda\n    if args.gpu >= 0 and th.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n    else:\n        device = \"cpu\"\n\n    # binary classification\n    num_classes = dataset.num_classes\n\n    # Extract node features\n    e_feats = graph.edges[\"forward\"].data[\"feat\"].shape[-1]\n    u_feats = graph.nodes[\"u\"].data[\"feat\"].shape[-1]\n    v_feats = graph.nodes[\"v\"].data[\"feat\"].shape[-1]\n\n    # Step 2: Create model =================================================================== #\n    model = GAS(\n        e_in_dim=e_feats,\n        u_in_dim=u_feats,\n        v_in_dim=v_feats,\n        e_hid_dim=args.e_hid_dim,\n        u_hid_dim=args.u_hid_dim,\n        v_hid_dim=args.v_hid_dim,\n        out_dim=num_classes,\n        num_layers=args.num_layers,\n        dropout=args.dropout,\n        activation=F.relu,\n    )\n\n    model = model.to(device)\n\n    # Step 3: Create training components ===================================================== #\n    loss_fn = 
th.nn.CrossEntropyLoss()\n    optimizer = optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # Step 4: training epochs =============================================================== #\n    for epoch in range(args.max_epoch):\n        model.train()\n        tr_loss = 0\n        tr_f1 = 0\n        tr_auc = 0\n        tr_rap = 0\n        tr_blocks = 0\n        for input_nodes, edge_subgraph, blocks in tr_loader:\n            blocks = [b.to(device) for b in blocks]\n            edge_subgraph = edge_subgraph.to(device)\n            u_feat = blocks[0].srcdata[\"feat\"][\"u\"]\n            v_feat = blocks[0].srcdata[\"feat\"][\"v\"]\n            f_feat = blocks[0].edges[\"forward\"].data[\"feat\"]\n            b_feat = blocks[0].edges[\"backward\"].data[\"feat\"]\n            labels = edge_subgraph.edges[\"forward\"].data[\"label\"].long()\n            logits = model(\n                edge_subgraph, blocks, f_feat, b_feat, u_feat, v_feat\n            )\n\n            # compute loss\n            batch_loss = loss_fn(logits, labels)\n            tr_loss += batch_loss.item()\n            tr_f1 += f1_score(labels.cpu(), logits.argmax(dim=1).cpu())\n            tr_auc += roc_auc_score(labels.cpu(), logits[:, 1].detach().cpu())\n            tr_pre, tr_re, _ = precision_recall_curve(\n                labels.cpu(), logits[:, 1].detach().cpu()\n            )\n            tr_rap += tr_re[tr_pre > args.precision].max()\n            tr_blocks += 1\n\n            # backward\n            optimizer.zero_grad()\n            batch_loss.backward()\n            optimizer.step()\n\n        # validation\n        model.eval()\n        val_rap, val_f1, val_auc, val_loss = evaluate(\n            model, loss_fn, val_loader, device\n        )\n\n        # Print out performance\n        print(\n            \"In epoch {}, Train R@P: {:.4f} | Train F1: {:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; \"\n            \"Valid R@P: {:.4f} | Valid F1: 
{:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}\".format(\n                epoch,\n                tr_rap / tr_blocks,\n                tr_f1 / tr_blocks,\n                tr_auc / tr_blocks,\n                tr_loss / tr_blocks,\n                val_rap,\n                val_f1,\n                val_auc,\n                val_loss,\n            )\n        )\n\n    # Test with mini batch after all epoch\n    model.eval()\n    test_rap, test_f1, test_auc, test_loss = evaluate(\n        model, loss_fn, test_loader, device\n    )\n    print(\n        \"Test R@P: {:.4f} | Test F1: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}\".format(\n            test_rap, test_f1, test_auc, test_loss\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GCN-based Anti-Spam Model\")\n    parser.add_argument(\n        \"--dataset\", type=str, default=\"pol\", help=\"'pol', or 'gos'\"\n    )\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU Index. Default: -1, using CPU.\"\n    )\n    parser.add_argument(\n        \"--e_hid_dim\",\n        type=int,\n        default=128,\n        help=\"Hidden layer dimension for edges\",\n    )\n    parser.add_argument(\n        \"--u_hid_dim\",\n        type=int,\n        default=128,\n        help=\"Hidden layer dimension for source nodes\",\n    )\n    parser.add_argument(\n        \"--v_hid_dim\",\n        type=int,\n        default=128,\n        help=\"Hidden layer dimension for destination nodes\",\n    )\n    parser.add_argument(\n        \"--num_layers\", type=int, default=2, help=\"Number of GCN layers\"\n    )\n    parser.add_argument(\n        \"--max_epoch\",\n        type=int,\n        default=100,\n        help=\"The max number of epochs. Default: 100\",\n    )\n    parser.add_argument(\n        \"--lr\", type=float, default=0.001, help=\"Learning rate. 
Default: 1e-3\"\n    )\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.0, help=\"Dropout rate. Default: 0.0\"\n    )\n    parser.add_argument(\n        \"--batch_size\",\n        type=int,\n        default=64,\n        help=\"Size of mini-batches. Default: 64\",\n    )\n    parser.add_argument(\n        \"--num_workers\", type=int, default=4, help=\"Number of node dataloader\"\n    )\n    parser.add_argument(\n        \"--weight_decay\",\n        type=float,\n        default=5e-4,\n        help=\"Weight Decay. Default: 0.0005\",\n    )\n    parser.add_argument(\n        \"--precision\",\n        type=float,\n        default=0.9,\n        help=\"The value p in recall@p precision. Default: 0.9\",\n    )\n\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/gas/model.py",
    "content": "import dgl.function as fn\nimport torch as th\nimport torch.nn as nn\nfrom dgl.nn.functional import edge_softmax\n\n\nclass MLP(nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super().__init__()\n        self.W = nn.Linear(in_dim, out_dim)\n\n    def apply_edges(self, edges):\n        h_e = edges.data[\"h\"]\n        h_u = edges.src[\"h\"]\n        h_v = edges.dst[\"h\"]\n        score = self.W(th.cat([h_e, h_u, h_v], -1))\n        return {\"score\": score}\n\n    def forward(self, g, e_feat, u_feat, v_feat):\n        with g.local_scope():\n            g.edges[\"forward\"].data[\"h\"] = e_feat\n            g.nodes[\"u\"].data[\"h\"] = u_feat\n            g.nodes[\"v\"].data[\"h\"] = v_feat\n            g.apply_edges(self.apply_edges, etype=\"forward\")\n            return g.edges[\"forward\"].data[\"score\"]\n\n\nclass GASConv(nn.Module):\n    \"\"\"One layer of GAS.\"\"\"\n\n    def __init__(\n        self,\n        e_in_dim,\n        u_in_dim,\n        v_in_dim,\n        e_out_dim,\n        u_out_dim,\n        v_out_dim,\n        activation=None,\n        dropout=0,\n    ):\n        super(GASConv, self).__init__()\n\n        self.activation = activation\n        self.dropout = nn.Dropout(dropout)\n\n        self.e_linear = nn.Linear(e_in_dim, e_out_dim)\n        self.u_linear = nn.Linear(u_in_dim, e_out_dim)\n        self.v_linear = nn.Linear(v_in_dim, e_out_dim)\n\n        self.W_ATTN_u = nn.Linear(u_in_dim, v_in_dim + e_in_dim)\n        self.W_ATTN_v = nn.Linear(v_in_dim, u_in_dim + e_in_dim)\n\n        # the proportion of h_u and h_Nu are specified as 1/2 in formula 8\n        nu_dim = int(u_out_dim / 2)\n        nv_dim = int(v_out_dim / 2)\n\n        self.W_u = nn.Linear(v_in_dim + e_in_dim, nu_dim)\n        self.W_v = nn.Linear(u_in_dim + e_in_dim, nv_dim)\n\n        self.Vu = nn.Linear(u_in_dim, u_out_dim - nu_dim)\n        self.Vv = nn.Linear(v_in_dim, v_out_dim - nv_dim)\n\n    def forward(self, g, e_feat, u_feat, v_feat):\n 
       with g.local_scope():\n            g.nodes[\"u\"].data[\"h\"] = u_feat\n            g.nodes[\"v\"].data[\"h\"] = v_feat\n            g.edges[\"forward\"].data[\"h\"] = e_feat\n            g.edges[\"backward\"].data[\"h\"] = e_feat\n\n            # formula 3 and 4 (optimized implementation to save memory)\n            g.nodes[\"u\"].data.update({\"he_u\": self.u_linear(u_feat)})\n            g.nodes[\"v\"].data.update({\"he_v\": self.v_linear(v_feat)})\n            g.edges[\"forward\"].data.update({\"he_e\": self.e_linear(e_feat)})\n            g.apply_edges(\n                lambda edges: {\n                    \"he\": edges.data[\"he_e\"]\n                    + edges.src[\"he_u\"]\n                    + edges.dst[\"he_v\"]\n                },\n                etype=\"forward\",\n            )\n            he = g.edges[\"forward\"].data[\"he\"]\n            if self.activation is not None:\n                he = self.activation(he)\n\n            # formula 6\n            g.apply_edges(\n                lambda edges: {\n                    \"h_ve\": th.cat([edges.src[\"h\"], edges.data[\"h\"]], -1)\n                },\n                etype=\"backward\",\n            )\n            g.apply_edges(\n                lambda edges: {\n                    \"h_ue\": th.cat([edges.src[\"h\"], edges.data[\"h\"]], -1)\n                },\n                etype=\"forward\",\n            )\n\n            # formula 7, self-attention\n            g.nodes[\"u\"].data[\"h_att_u\"] = self.W_ATTN_u(u_feat)\n            g.nodes[\"v\"].data[\"h_att_v\"] = self.W_ATTN_v(v_feat)\n\n            # Step 1: dot product\n            g.apply_edges(\n                fn.e_dot_v(\"h_ve\", \"h_att_u\", \"edotv\"), etype=\"backward\"\n            )\n            g.apply_edges(\n                fn.e_dot_v(\"h_ue\", \"h_att_v\", \"edotv\"), etype=\"forward\"\n            )\n\n            # Step 2. 
softmax\n            g.edges[\"backward\"].data[\"sfm\"] = edge_softmax(\n                g[\"backward\"], g.edges[\"backward\"].data[\"edotv\"]\n            )\n            g.edges[\"forward\"].data[\"sfm\"] = edge_softmax(\n                g[\"forward\"], g.edges[\"forward\"].data[\"edotv\"]\n            )\n\n            # Step 3. Broadcast softmax value to each edge, and then attention is done\n            g.apply_edges(\n                lambda edges: {\"attn\": edges.data[\"h_ve\"] * edges.data[\"sfm\"]},\n                etype=\"backward\",\n            )\n            g.apply_edges(\n                lambda edges: {\"attn\": edges.data[\"h_ue\"] * edges.data[\"sfm\"]},\n                etype=\"forward\",\n            )\n\n            # Step 4. Aggregate attention to dst,user nodes, so formula 7 is done\n            g.update_all(\n                fn.copy_e(\"attn\", \"m\"), fn.sum(\"m\", \"agg_u\"), etype=\"backward\"\n            )\n            g.update_all(\n                fn.copy_e(\"attn\", \"m\"), fn.sum(\"m\", \"agg_v\"), etype=\"forward\"\n            )\n\n            # formula 5\n            h_nu = self.W_u(g.nodes[\"u\"].data[\"agg_u\"])\n            h_nv = self.W_v(g.nodes[\"v\"].data[\"agg_v\"])\n            if self.activation is not None:\n                h_nu = self.activation(h_nu)\n                h_nv = self.activation(h_nv)\n\n            # Dropout\n            he = self.dropout(he)\n            h_nu = self.dropout(h_nu)\n            h_nv = self.dropout(h_nv)\n\n            # formula 8\n            hu = th.cat([self.Vu(u_feat), h_nu], -1)\n            hv = th.cat([self.Vv(v_feat), h_nv], -1)\n\n            return he, hu, hv\n\n\nclass GAS(nn.Module):\n    def __init__(\n        self,\n        e_in_dim,\n        u_in_dim,\n        v_in_dim,\n        e_hid_dim,\n        u_hid_dim,\n        v_hid_dim,\n        out_dim,\n        num_layers=2,\n        dropout=0.0,\n        activation=None,\n    ):\n        super(GAS, self).__init__()\n        
self.e_in_dim = e_in_dim\n        self.u_in_dim = u_in_dim\n        self.v_in_dim = v_in_dim\n        self.e_hid_dim = e_hid_dim\n        self.u_hid_dim = u_hid_dim\n        self.v_hid_dim = v_hid_dim\n        self.out_dim = out_dim\n        self.num_layer = num_layers\n        self.dropout = dropout\n        self.activation = activation\n        self.predictor = MLP(e_hid_dim + u_hid_dim + v_hid_dim, out_dim)\n        self.layers = nn.ModuleList()\n\n        # Input layer\n        self.layers.append(\n            GASConv(\n                self.e_in_dim,\n                self.u_in_dim,\n                self.v_in_dim,\n                self.e_hid_dim,\n                self.u_hid_dim,\n                self.v_hid_dim,\n                activation=self.activation,\n                dropout=self.dropout,\n            )\n        )\n\n        # Hidden layers with n - 1 CompGraphConv layers\n        for i in range(self.num_layer - 1):\n            self.layers.append(\n                GASConv(\n                    self.e_hid_dim,\n                    self.u_hid_dim,\n                    self.v_hid_dim,\n                    self.e_hid_dim,\n                    self.u_hid_dim,\n                    self.v_hid_dim,\n                    activation=self.activation,\n                    dropout=self.dropout,\n                )\n            )\n\n    def forward(self, graph, e_feat, u_feat, v_feat):\n        # For full graph training, directly use the graph\n        # Forward of n layers of GAS\n        for layer in self.layers:\n            e_feat, u_feat, v_feat = layer(graph, e_feat, u_feat, v_feat)\n\n        # return the result of final prediction layer\n        return self.predictor(graph, e_feat, u_feat, v_feat)\n"
  },
  {
    "path": "examples/pytorch/gas/model_sampling.py",
    "content": "import dgl.function as fn\nimport torch as th\nimport torch.nn as nn\nfrom dgl.nn.functional import edge_softmax\n\n\nclass MLP(nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super().__init__()\n        self.W = nn.Linear(in_dim, out_dim)\n\n    def apply_edges(self, edges):\n        h_e = edges.data[\"h\"]\n        h_u = edges.src[\"h\"]\n        h_v = edges.dst[\"h\"]\n        score = self.W(th.cat([h_e, h_u, h_v], -1))\n        return {\"score\": score}\n\n    def forward(self, g, e_feat, u_feat, v_feat):\n        with g.local_scope():\n            g.edges[\"forward\"].data[\"h\"] = e_feat\n            g.nodes[\"u\"].data[\"h\"] = u_feat\n            g.nodes[\"v\"].data[\"h\"] = v_feat\n            g.apply_edges(self.apply_edges, etype=\"forward\")\n            return g.edges[\"forward\"].data[\"score\"]\n\n\nclass GASConv(nn.Module):\n    \"\"\"One layer of GAS.\"\"\"\n\n    def __init__(\n        self,\n        e_in_dim,\n        u_in_dim,\n        v_in_dim,\n        e_out_dim,\n        u_out_dim,\n        v_out_dim,\n        activation=None,\n        dropout=0,\n    ):\n        super(GASConv, self).__init__()\n\n        self.activation = activation\n        self.dropout = nn.Dropout(dropout)\n\n        self.e_linear = nn.Linear(e_in_dim, e_out_dim)\n        self.u_linear = nn.Linear(u_in_dim, e_out_dim)\n        self.v_linear = nn.Linear(v_in_dim, e_out_dim)\n\n        self.W_ATTN_u = nn.Linear(u_in_dim, v_in_dim + e_in_dim)\n        self.W_ATTN_v = nn.Linear(v_in_dim, u_in_dim + e_in_dim)\n\n        # the proportion of h_u and h_Nu are specified as 1/2 in formula 8\n        nu_dim = int(u_out_dim / 2)\n        nv_dim = int(v_out_dim / 2)\n\n        self.W_u = nn.Linear(v_in_dim + e_in_dim, nu_dim)\n        self.W_v = nn.Linear(u_in_dim + e_in_dim, nv_dim)\n\n        self.Vu = nn.Linear(u_in_dim, u_out_dim - nu_dim)\n        self.Vv = nn.Linear(v_in_dim, v_out_dim - nv_dim)\n\n    def forward(self, g, f_feat, b_feat, u_feat, 
v_feat):\n        g.srcnodes[\"u\"].data[\"h\"] = u_feat\n        g.srcnodes[\"v\"].data[\"h\"] = v_feat\n        g.dstnodes[\"u\"].data[\"h\"] = u_feat[: g.number_of_dst_nodes(ntype=\"u\")]\n        g.dstnodes[\"v\"].data[\"h\"] = v_feat[: g.number_of_dst_nodes(ntype=\"v\")]\n        g.edges[\"forward\"].data[\"h\"] = f_feat\n        g.edges[\"backward\"].data[\"h\"] = b_feat\n\n        # formula 3 and 4 (optimized implementation to save memory)\n        g.srcnodes[\"u\"].data.update(\n            {\"he_u\": self.u_linear(g.srcnodes[\"u\"].data[\"h\"])}\n        )\n        g.srcnodes[\"v\"].data.update(\n            {\"he_v\": self.v_linear(g.srcnodes[\"v\"].data[\"h\"])}\n        )\n        g.dstnodes[\"u\"].data.update(\n            {\"he_u\": self.u_linear(g.dstnodes[\"u\"].data[\"h\"])}\n        )\n        g.dstnodes[\"v\"].data.update(\n            {\"he_v\": self.v_linear(g.dstnodes[\"v\"].data[\"h\"])}\n        )\n        g.edges[\"forward\"].data.update({\"he_e\": self.e_linear(f_feat)})\n        g.edges[\"backward\"].data.update({\"he_e\": self.e_linear(b_feat)})\n        g.apply_edges(\n            lambda edges: {\n                \"he\": edges.data[\"he_e\"] + edges.dst[\"he_u\"] + edges.src[\"he_v\"]\n            },\n            etype=\"backward\",\n        )\n        g.apply_edges(\n            lambda edges: {\n                \"he\": edges.data[\"he_e\"] + edges.src[\"he_u\"] + edges.dst[\"he_v\"]\n            },\n            etype=\"forward\",\n        )\n        hf = g.edges[\"forward\"].data[\"he\"]\n        hb = g.edges[\"backward\"].data[\"he\"]\n        if self.activation is not None:\n            hf = self.activation(hf)\n            hb = self.activation(hb)\n\n        # formula 6\n        g.apply_edges(\n            lambda edges: {\n                \"h_ve\": th.cat([edges.src[\"h\"], edges.data[\"h\"]], -1)\n            },\n            etype=\"backward\",\n        )\n        g.apply_edges(\n            lambda edges: {\n                
\"h_ue\": th.cat([edges.src[\"h\"], edges.data[\"h\"]], -1)\n            },\n            etype=\"forward\",\n        )\n\n        # formula 7, self-attention\n        g.srcnodes[\"u\"].data[\"h_att_u\"] = self.W_ATTN_u(\n            g.srcnodes[\"u\"].data[\"h\"]\n        )\n        g.srcnodes[\"v\"].data[\"h_att_v\"] = self.W_ATTN_v(\n            g.srcnodes[\"v\"].data[\"h\"]\n        )\n        g.dstnodes[\"u\"].data[\"h_att_u\"] = self.W_ATTN_u(\n            g.dstnodes[\"u\"].data[\"h\"]\n        )\n        g.dstnodes[\"v\"].data[\"h_att_v\"] = self.W_ATTN_v(\n            g.dstnodes[\"v\"].data[\"h\"]\n        )\n\n        # Step 1: dot product\n        g.apply_edges(fn.e_dot_v(\"h_ve\", \"h_att_u\", \"edotv\"), etype=\"backward\")\n        g.apply_edges(fn.e_dot_v(\"h_ue\", \"h_att_v\", \"edotv\"), etype=\"forward\")\n\n        # Step 2. softmax\n        g.edges[\"backward\"].data[\"sfm\"] = edge_softmax(\n            g[\"backward\"], g.edges[\"backward\"].data[\"edotv\"]\n        )\n        g.edges[\"forward\"].data[\"sfm\"] = edge_softmax(\n            g[\"forward\"], g.edges[\"forward\"].data[\"edotv\"]\n        )\n\n        # Step 3. Broadcast softmax value to each edge, and then attention is done\n        g.apply_edges(\n            lambda edges: {\"attn\": edges.data[\"h_ve\"] * edges.data[\"sfm\"]},\n            etype=\"backward\",\n        )\n        g.apply_edges(\n            lambda edges: {\"attn\": edges.data[\"h_ue\"] * edges.data[\"sfm\"]},\n            etype=\"forward\",\n        )\n\n        # Step 4. 
Aggregate attention to dst,user nodes, so formula 7 is done\n        g.update_all(\n            fn.copy_e(\"attn\", \"m\"), fn.sum(\"m\", \"agg_u\"), etype=\"backward\"\n        )\n        g.update_all(\n            fn.copy_e(\"attn\", \"m\"), fn.sum(\"m\", \"agg_v\"), etype=\"forward\"\n        )\n\n        # formula 5\n        h_nu = self.W_u(g.dstnodes[\"u\"].data[\"agg_u\"])\n        h_nv = self.W_v(g.dstnodes[\"v\"].data[\"agg_v\"])\n        if self.activation is not None:\n            h_nu = self.activation(h_nu)\n            h_nv = self.activation(h_nv)\n\n        # Dropout\n        hf = self.dropout(hf)\n        hb = self.dropout(hb)\n        h_nu = self.dropout(h_nu)\n        h_nv = self.dropout(h_nv)\n\n        # formula 8\n        hu = th.cat([self.Vu(g.dstnodes[\"u\"].data[\"h\"]), h_nu], -1)\n        hv = th.cat([self.Vv(g.dstnodes[\"v\"].data[\"h\"]), h_nv], -1)\n\n        return hf, hb, hu, hv\n\n\nclass GAS(nn.Module):\n    def __init__(\n        self,\n        e_in_dim,\n        u_in_dim,\n        v_in_dim,\n        e_hid_dim,\n        u_hid_dim,\n        v_hid_dim,\n        out_dim,\n        num_layers=2,\n        dropout=0.0,\n        activation=None,\n    ):\n        super(GAS, self).__init__()\n        self.e_in_dim = e_in_dim\n        self.u_in_dim = u_in_dim\n        self.v_in_dim = v_in_dim\n        self.e_hid_dim = e_hid_dim\n        self.u_hid_dim = u_hid_dim\n        self.v_hid_dim = v_hid_dim\n        self.out_dim = out_dim\n        self.num_layer = num_layers\n        self.dropout = dropout\n        self.activation = activation\n        self.predictor = MLP(e_hid_dim + u_hid_dim + v_hid_dim, out_dim)\n        self.layers = nn.ModuleList()\n\n        # Input layer\n        self.layers.append(\n            GASConv(\n                self.e_in_dim,\n                self.u_in_dim,\n                self.v_in_dim,\n                self.e_hid_dim,\n                self.u_hid_dim,\n                self.v_hid_dim,\n                
activation=self.activation,\n                dropout=self.dropout,\n            )\n        )\n\n        # Hidden layers with n - 1 CompGraphConv layers\n        for i in range(self.num_layer - 1):\n            self.layers.append(\n                GASConv(\n                    self.e_hid_dim,\n                    self.u_hid_dim,\n                    self.v_hid_dim,\n                    self.e_hid_dim,\n                    self.u_hid_dim,\n                    self.v_hid_dim,\n                    activation=self.activation,\n                    dropout=self.dropout,\n                )\n            )\n\n    def forward(self, subgraph, blocks, f_feat, b_feat, u_feat, v_feat):\n        # Forward of n layers of GAS\n        for layer, block in zip(self.layers, blocks):\n            f_feat, b_feat, u_feat, v_feat = layer(\n                block,\n                f_feat[: block.num_edges(etype=\"forward\")],\n                b_feat[: block.num_edges(etype=\"backward\")],\n                u_feat,\n                v_feat,\n            )\n\n        # return the result of final prediction layer\n        return self.predictor(\n            subgraph,\n            f_feat[: subgraph.num_edges(etype=\"forward\")],\n            u_feat,\n            v_feat,\n        )\n"
  },
  {
    "path": "examples/pytorch/gat/README.md",
    "content": "Graph Attention Networks (GAT)\n============\n\n- Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903)\n- Author's code repo (tensorflow implementation):\n  [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).\n- Popular pytorch implementation:\n  [https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).\n\nHow to run\n-------\n\n> **_NOTE:_**  `train.py` is deprecated and please check the new version in `//examples/core/gat/train.py`.\n\nRun with the following for multiclass node classification (available datasets: \"cora\", \"citeseer\", \"pubmed\")\n```bash\npython3 train.py --dataset cora\n```\n\nRun with the following for multilabel classification with PPI dataset\n```bash\npython3 train_ppi.py\n```\n\n> **_NOTE:_**  Users may occasionally run into low accuracy issue (e.g., test accuracy < 0.8) due to overfitting. This can be resolved by adding Early Stopping or reducing maximum number of training epochs.\n\nSummary\n-------\n* cora: ~0.821\n* citeseer: ~0.710\n* pubmed: ~0.780\n* ppi: ~0.9744\n"
  },
  {
    "path": "examples/pytorch/gat/train.py",
    "content": "import argparse\n\nimport dgl\nimport dgl.nn as dglnn\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import AddSelfLoop\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\n\n\nclass GAT(nn.Module):\n    def __init__(self, in_size, hid_size, out_size, heads):\n        super().__init__()\n        self.gat_layers = nn.ModuleList()\n        # two-layer GAT\n        self.gat_layers.append(\n            dglnn.GATConv(\n                in_size,\n                hid_size,\n                heads[0],\n                feat_drop=0.6,\n                attn_drop=0.6,\n                activation=F.elu,\n            )\n        )\n        self.gat_layers.append(\n            dglnn.GATConv(\n                hid_size * heads[0],\n                out_size,\n                heads[1],\n                feat_drop=0.6,\n                attn_drop=0.6,\n                activation=None,\n            )\n        )\n\n    def forward(self, g, inputs):\n        h = inputs\n        for i, layer in enumerate(self.gat_layers):\n            h = layer(g, h)\n            if i == 1:  # last layer\n                h = h.mean(1)\n            else:  # other layer(s)\n                h = h.flatten(1)\n        return h\n\n\ndef evaluate(g, features, labels, mask, model):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef train(g, features, labels, masks, model):\n    # define train/val samples, loss function and optimizer\n    train_mask = masks[0]\n    val_mask = masks[1]\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)\n\n    # training loop\n    for epoch in range(200):\n        model.train()\n        logits = 
model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        acc = evaluate(g, features, labels, val_mask, model)\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} \".format(\n                epoch, loss.item(), acc\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"cora\",\n        help=\"Dataset name ('cora', 'citeseer', 'pubmed').\",\n    )\n    parser.add_argument(\n        \"--dt\",\n        type=str,\n        default=\"float\",\n        help=\"data type(float, bfloat16)\",\n    )\n    args = parser.parse_args()\n    print(f\"Training with DGL built-in GATConv module.\")\n\n    # load and preprocess dataset\n    transform = (\n        AddSelfLoop()\n    )  # by default, it will first remove self-loops to prevent duplication\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset(transform=transform)\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset(transform=transform)\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset(transform=transform)\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n    g = data[0]\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    g = g.int().to(device)\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    masks = g.ndata[\"train_mask\"], g.ndata[\"val_mask\"], g.ndata[\"test_mask\"]\n\n    # create GAT model\n    in_size = features.shape[1]\n    out_size = data.num_classes\n    model = GAT(in_size, 8, out_size, heads=[8, 1]).to(device)\n\n    # convert model and graph to bfloat16 if needed\n    if args.dt == \"bfloat16\":\n        g = dgl.to_bfloat16(g)\n        features = features.to(dtype=torch.bfloat16)\n        model 
= model.to(dtype=torch.bfloat16)\n\n    # model training\n    print(\"Training...\")\n    train(g, features, labels, masks, model)\n\n    # test the model\n    print(\"Testing...\")\n    acc = evaluate(g, features, labels, masks[2], model)\n    print(\"Test accuracy {:.4f}\".format(acc))\n"
  },
  {
    "path": "examples/pytorch/gat/train_ppi.py",
    "content": "import dgl.nn as dglnn\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data.ppi import PPIDataset\nfrom dgl.dataloading import GraphDataLoader\nfrom sklearn.metrics import f1_score\n\n\nclass GAT(nn.Module):\n    def __init__(self, in_size, hid_size, out_size, heads):\n        super().__init__()\n        self.gat_layers = nn.ModuleList()\n        # three-layer GAT\n        self.gat_layers.append(\n            dglnn.GATConv(in_size, hid_size, heads[0], activation=F.elu)\n        )\n        self.gat_layers.append(\n            dglnn.GATConv(\n                hid_size * heads[0],\n                hid_size,\n                heads[1],\n                residual=True,\n                activation=F.elu,\n            )\n        )\n        self.gat_layers.append(\n            dglnn.GATConv(\n                hid_size * heads[1],\n                out_size,\n                heads[2],\n                residual=True,\n                activation=None,\n            )\n        )\n\n    def forward(self, g, inputs):\n        h = inputs\n        for i, layer in enumerate(self.gat_layers):\n            h = layer(g, h)\n            if i == 2:  # last layer\n                h = h.mean(1)\n            else:  # other layer(s)\n                h = h.flatten(1)\n        return h\n\n\ndef evaluate(g, features, labels, model):\n    model.eval()\n    with torch.no_grad():\n        output = model(g, features)\n        pred = np.where(output.data.cpu().numpy() >= 0, 1, 0)\n        score = f1_score(labels.data.cpu().numpy(), pred, average=\"micro\")\n        return score\n\n\ndef evaluate_in_batches(dataloader, device, model):\n    total_score = 0\n    for batch_id, batched_graph in enumerate(dataloader):\n        batched_graph = batched_graph.to(device)\n        features = batched_graph.ndata[\"feat\"]\n        labels = batched_graph.ndata[\"label\"]\n        score = evaluate(batched_graph, features, labels, model)\n        
total_score += score\n    return total_score / (batch_id + 1)  # return average score\n\n\ndef train(train_dataloader, val_dataloader, device, model):\n    # define loss function and optimizer\n    loss_fcn = nn.BCEWithLogitsLoss()\n    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=0)\n\n    # training loop\n    for epoch in range(400):\n        model.train()\n        logits = []\n        total_loss = 0\n        # mini-batch loop\n        for batch_id, batched_graph in enumerate(train_dataloader):\n            batched_graph = batched_graph.to(device)\n            features = batched_graph.ndata[\"feat\"].float()\n            labels = batched_graph.ndata[\"label\"].float()\n            logits = model(batched_graph, features)\n            loss = loss_fcn(logits, labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n        print(\n            \"Epoch {:05d} | Loss {:.4f} |\".format(\n                epoch, total_loss / (batch_id + 1)\n            )\n        )\n\n        if (epoch + 1) % 5 == 0:\n            avg_score = evaluate_in_batches(\n                val_dataloader, device, model\n            )  # evaluate F1-score instead of loss\n            print(\n                \"                            Acc. 
(F1-score) {:.4f} \".format(\n                    avg_score\n                )\n            )\n\n\nif __name__ == \"__main__\":\n    print(f\"Training PPI Dataset with DGL built-in GATConv module.\")\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    # load and preprocess datasets\n    train_dataset = PPIDataset(mode=\"train\")\n    val_dataset = PPIDataset(mode=\"valid\")\n    test_dataset = PPIDataset(mode=\"test\")\n    features = train_dataset[0].ndata[\"feat\"]\n\n    # create GAT model\n    in_size = features.shape[1]\n    out_size = train_dataset.num_classes\n    model = GAT(in_size, 256, out_size, heads=[4, 4, 6]).to(device)\n\n    # model training\n    print(\"Training...\")\n    train_dataloader = GraphDataLoader(train_dataset, batch_size=2)\n    val_dataloader = GraphDataLoader(val_dataset, batch_size=2)\n    train(train_dataloader, val_dataloader, device, model)\n\n    # test the model\n    print(\"Testing...\")\n    test_dataloader = GraphDataLoader(test_dataset, batch_size=2)\n    avg_score = evaluate_in_batches(test_dataloader, device, model)\n    print(\"Test Accuracy (F1-score) {:.4f}\".format(avg_score))\n"
  },
  {
    "path": "examples/pytorch/gatv2/README.md",
    "content": "Graph Attention Networks v2 (GATv2)\n============\n\n- Paper link: [How Attentive are Graph Attention Networks?](https://arxiv.org/pdf/2105.14491.pdf)\n- Author's code repo: [https://github.com/tech-srl/how_attentive_are_gats](https://github.com/tech-srl/how_attentive_are_gats).\n- Annotated implemetnation: [https://nn.labml.ai/graphs/gatv2/index.html]\n\nDependencies\n------------\n- torch\n- requests\n- scikit-learn\n\nHow to run\n----------\n\nRun with following:\n\n```bash\npython3 train.py --dataset=cora\n```\n\n```bash\npython3 train.py --dataset=citeseer\n```\n\n```bash\npython3 train.py --dataset=pubmed\n```\n\nResults\n-------\n\n| Dataset  | Test Accuracy |\n| -------- | ------------- |\n| Cora     |  82.10        |\n| Citeseer |  70.00        |\n| Pubmed   |  77.2         |\n\n* All the accuracy numbers are obtained after 200 epochs.\n"
  },
  {
    "path": "examples/pytorch/gatv2/gatv2.py",
    "content": "\"\"\"\nGraph Attention Networks in DGL using SPMV optimization.\nReferences\n----------\nPaper: https://arxiv.org/pdf/2105.14491.pdf\nAuthor's code: https://github.com/tech-srl/how_attentive_are_gats\n\"\"\"\n\nimport torch\nimport torch.nn as nn\n\nfrom dgl.nn import GATv2Conv\n\n\nclass GATv2(nn.Module):\n    def __init__(\n        self,\n        num_layers,\n        in_dim,\n        num_hidden,\n        num_classes,\n        heads,\n        activation,\n        feat_drop,\n        attn_drop,\n        negative_slope,\n        residual,\n    ):\n        super(GATv2, self).__init__()\n        self.num_layers = num_layers\n        self.gatv2_layers = nn.ModuleList()\n        self.activation = activation\n        # input projection (no residual)\n        self.gatv2_layers.append(\n            GATv2Conv(\n                in_dim,\n                num_hidden,\n                heads[0],\n                feat_drop,\n                attn_drop,\n                negative_slope,\n                False,\n                self.activation,\n                bias=False,\n                share_weights=True,\n            )\n        )\n        # hidden layers\n        for l in range(1, num_layers):\n            # due to multi-head, the in_dim = num_hidden * num_heads\n            self.gatv2_layers.append(\n                GATv2Conv(\n                    num_hidden * heads[l - 1],\n                    num_hidden,\n                    heads[l],\n                    feat_drop,\n                    attn_drop,\n                    negative_slope,\n                    residual,\n                    self.activation,\n                    bias=False,\n                    share_weights=True,\n                )\n            )\n        # output projection\n        self.gatv2_layers.append(\n            GATv2Conv(\n                num_hidden * heads[-2],\n                num_classes,\n                heads[-1],\n                feat_drop,\n                attn_drop,\n              
  negative_slope,\n                residual,\n                None,\n                bias=False,\n                share_weights=True,\n            )\n        )\n\n    def forward(self, g, inputs):\n        h = inputs\n        for l in range(self.num_layers):\n            h = self.gatv2_layers[l](g, h).flatten(1)\n        # output projection\n        logits = self.gatv2_layers[-1](g, h).mean(1)\n        return logits\n"
  },
  {
    "path": "examples/pytorch/gatv2/train.py",
    "content": "\"\"\"\nGraph Attention Networks v2 (GATv2) in DGL using SPMV optimization.\nMultiple heads are also batched together for faster training.\n\"\"\"\n\nimport argparse\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom gatv2 import GATv2\n\n\nclass EarlyStopping:\n    def __init__(self, patience=10):\n        self.patience = patience\n        self.counter = 0\n        self.best_score = None\n        self.early_stop = False\n\n    def step(self, acc, model):\n        score = acc\n        if self.best_score is None:\n            self.best_score = score\n            self.save_checkpoint(model)\n        elif score < self.best_score:\n            self.counter += 1\n            print(\n                f\"EarlyStopping counter: {self.counter} out of {self.patience}\"\n            )\n            if self.counter >= self.patience:\n                self.early_stop = True\n        else:\n            self.best_score = score\n            self.save_checkpoint(model)\n            self.counter = 0\n        return self.early_stop\n\n    def save_checkpoint(self, model):\n        \"\"\"Saves model when validation loss decrease.\"\"\"\n        torch.save(model.state_dict(), \"es_checkpoint.pt\")\n\n\ndef accuracy(logits, labels):\n    _, indices = torch.max(logits, dim=1)\n    correct = torch.sum(indices == labels)\n    return correct.item() * 1.0 / len(labels)\n\n\ndef evaluate(g, model, features, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)\n        logits = logits[mask]\n        labels = labels[mask]\n        return accuracy(logits, labels)\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n 
   elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n    else:\n        cuda = True\n        g = g.int().to(args.gpu)\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    num_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = g.num_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.int().sum().item(),\n            val_mask.int().sum().item(),\n            test_mask.int().sum().item(),\n        )\n    )\n\n    # add self loop\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n    n_edges = g.num_edges()\n    # create model\n    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]\n    model = GATv2(\n        args.num_layers,\n        num_feats,\n        args.num_hidden,\n        n_classes,\n        heads,\n        F.elu,\n        args.in_drop,\n        args.attn_drop,\n        args.negative_slope,\n        args.residual,\n    )\n    print(model)\n    if args.early_stop:\n        stopper = EarlyStopping(patience=100)\n    if cuda:\n        model.cuda()\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    # use optimizer\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # initialize graph\n    mean = 0\n    for epoch in range(args.epochs):\n        model.train()\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n\n       
 optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if epoch >= 3:\n            mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)\n\n            train_acc = accuracy(logits[train_mask], labels[train_mask])\n\n            if args.fastmode:\n                val_acc = accuracy(logits[val_mask], labels[val_mask])\n            else:\n                val_acc = evaluate(g, model, features, labels, val_mask)\n                if args.early_stop:\n                    if stopper.step(val_acc, model):\n                        break\n\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |\"\n                \" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    mean,\n                    loss.item(),\n                    train_acc,\n                    val_acc,\n                    n_edges / mean / 1000,\n                )\n            )\n\n    print()\n    if args.early_stop:\n        model.load_state_dict(\n            torch.load(\"es_checkpoint.pt\", weights_only=False)\n        )\n    acc = evaluate(g, model, features, labels, test_mask)\n    print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GAT\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=-1,\n        help=\"which GPU to use. 
Set -1 to use CPU.\",\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--num-heads\",\n        type=int,\n        default=8,\n        help=\"number of hidden attention heads\",\n    )\n    parser.add_argument(\n        \"--num-out-heads\",\n        type=int,\n        default=1,\n        help=\"number of output attention heads\",\n    )\n    parser.add_argument(\n        \"--num-layers\", type=int, default=1, help=\"number of hidden layers\"\n    )\n    parser.add_argument(\n        \"--num-hidden\", type=int, default=8, help=\"number of hidden units\"\n    )\n    parser.add_argument(\n        \"--residual\",\n        action=\"store_true\",\n        default=False,\n        help=\"use residual connection\",\n    )\n    parser.add_argument(\n        \"--in-drop\", type=float, default=0.7, help=\"input feature dropout\"\n    )\n    parser.add_argument(\n        \"--attn-drop\", type=float, default=0.7, help=\"attention dropout\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.005, help=\"learning rate\")\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"weight decay\"\n    )\n    parser.add_argument(\n        \"--negative-slope\",\n        type=float,\n        default=0.2,\n        help=\"the negative slope of leaky relu\",\n    )\n    parser.add_argument(\n        \"--early-stop\",\n        action=\"store_true\",\n        default=False,\n        help=\"indicates whether to use early stop or not\",\n    )\n    parser.add_argument(\n        \"--fastmode\",\n        action=\"store_true\",\n        default=False,\n        help=\"skip re-evaluate the validation set\",\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/gcmc/README.md",
    "content": "# Graph Convolutional Matrix Completion\n\nPaper link: [https://arxiv.org/abs/1706.02263](https://arxiv.org/abs/1706.02263)\nAuthor's code: [https://github.com/riannevdberg/gc-mc](https://github.com/riannevdberg/gc-mc)\n\nThe implementation does not handle side-channel features and mini-epoching and thus achieves\nslightly worse performance when using node features.\n\nCredit: Jiani Zhang ([@jennyzhang0215](https://github.com/jennyzhang0215))\n\n## Dependencies\n* PyTorch 1.2+\n* pandas\n* torchtext 0.9+ (if using user and item contents as node features)\n* spacy (if using user and item contents as node features)\n    - You will also need to run `python -m spacy download en_core_web_sm`\n\n## Data\n\nSupported datasets: ml-100k, ml-1m, ml-10m\n\n## How to run\n### Train with full-graph\nml-100k, no feature\n```bash\npython3 train.py --data_name=ml-100k --use_one_hot_fea --gcn_agg_accum=stack\n```\nResults: RMSE=0.9088 (0.910 reported)\n\nml-100k, with feature\n```bash\npython3 train.py --data_name=ml-100k --gcn_agg_accum=stack\n```\nResults: RMSE=0.9448 (0.905 reported)\n\nml-1m, no feature\n```bash\npython3 train.py --data_name=ml-1m --gcn_agg_accum=sum --use_one_hot_fea\n```\nResults: RMSE=0.8377 (0.832 reported)\n\nml-10m, no feature\n```bash\npython3 train.py --data_name=ml-10m --gcn_agg_accum=stack --gcn_dropout=0.3 \\\n                                 --train_lr=0.001 --train_min_lr=0.0001 --train_max_iter=15000 \\\n                                 --use_one_hot_fea --gen_r_num_basis_func=4\n```\nResults: RMSE=0.7800 (0.777 reported)\nTestbed: EC2 p3.2xlarge instance(Amazon Linux 2)\n\n### Train with minibatch on a single GPU\nml-100k, no feature\n```bash\npython3 train_sampling.py --data_name=ml-100k \\\n                          --use_one_hot_fea \\\n                          --gcn_agg_accum=stack \\\n                          --gpu 0\n\n```\nml-100k, no feature with mix_cpu_gpu run, for mix_cpu_gpu run with no feature, the W_r is stored in 
CPU by default other than in GPU.\n```bash\npython3 train_sampling.py --data_name=ml-100k \\\n                          --use_one_hot_fea \\\n                          --gcn_agg_accum=stack \\\n                          --mix_cpu_gpu \\\n                          --gpu 0 \n```\nResults: RMSE=0.9380\n\nml-100k, with feature\n```bash\npython3 train_sampling.py --data_name=ml-100k \\\n                          --gcn_agg_accum=stack \\\n                          --train_max_epoch 90 \\\n                          --gpu 0\n```\nResults: RMSE=0.9574\n\nml-1m, no feature\n```bash\npython3 train_sampling.py --data_name=ml-1m \\\n                          --gcn_agg_accum=sum \\\n                          --use_one_hot_fea \\\n                          --train_max_epoch 160 \\\n                          --gpu 0\n```\nml-1m, no feature with mix_cpu_gpu run\n```bash\npython3 train_sampling.py --data_name=ml-1m \\\n                          --gcn_agg_accum=sum \\\n                          --use_one_hot_fea \\\n                          --train_max_epoch 60 \\\n                          --mix_cpu_gpu \\\n                          --gpu 0\n```\nResults: RMSE=0.8632\n\nml-10m, no feature\n```bash\npython3 train_sampling.py --data_name=ml-10m \\\n                          --gcn_agg_accum=stack \\\n                          --gcn_dropout=0.3 \\\n                          --train_lr=0.001 \\\n                          --train_min_lr=0.0001 \\\n                          --train_max_epoch=60 \\\n                          --use_one_hot_fea \\\n                          --gen_r_num_basis_func=4 \\\n                          --gpu 0\n```\nml-10m, no feature with mix_cpu_gpu run\n```bash\npython3 train_sampling.py --data_name=ml-10m \\\n                          --gcn_agg_accum=stack \\\n                          --gcn_dropout=0.3 \\\n                          --train_lr=0.001 \\\n                          --train_min_lr=0.0001 \\\n                          --train_max_epoch=60 \\\n        
                  --use_one_hot_fea \\\n                          --gen_r_num_basis_func=4 \\\n                          --mix_cpu_gpu \\\n                          --gpu 0\n```\nResults: RMSE=0.8050\nTestbed: EC2 p3.2xlarge instance\n\n### Train with minibatch on multi-GPU\nml-100k, no feature\n```bash\npython train_sampling.py --data_name=ml-100k \\\n                         --gcn_agg_accum=stack \\\n                         --train_max_epoch 30 \\\n                         --train_lr 0.02 \\\n                         --use_one_hot_fea \\\n                         --gpu 0,1,2,3,4,5,6,7\n```\nml-100k, no feature with mix_cpu_gpu run\n```bash\npython train_sampling.py --data_name=ml-100k \\\n                         --gcn_agg_accum=stack \\\n                         --train_max_epoch 30 \\\n                         --train_lr 0.02 \\\n                         --use_one_hot_fea \\\n                         --mix_cpu_gpu \\\n                         --gpu 0,1,2,3,4,5,6,7\n```\nResult: RMSE=0.9397\n\nml-100k, with feature\n```bash\npython train_sampling.py --data_name=ml-100k \\\n                         --gcn_agg_accum=stack \\\n                         --train_max_epoch 30 \\\n                         --gpu 0,1,2,3,4,5,6,7\n```\nResult: RMSE=0.9655\n\nml-1m, no feature\n```bash\npython train_sampling.py --data_name=ml-1m \\\n                         --gcn_agg_accum=sum \\\n                         --train_max_epoch 40 \\\n                         --use_one_hot_fea \\\n                         --gpu 0,1,2,3,4,5,6,7\n```\nml-1m, no feature with mix_cpu_gpu run\n```bash\npython train_sampling.py --data_name=ml-1m \\\n                         --gcn_agg_accum=sum \\\n                         --train_max_epoch 40 \\\n                         --use_one_hot_fea \\\n                         --mix_cpu_gpu \\\n                         --gpu 0,1,2,3,4,5,6,7\n```\nResults: RMSE=0.8621\n\nml-10m, no feature\n```bash\npython train_sampling.py --data_name=ml-10m \\\n                
         --gcn_agg_accum=stack \\\n                         --gcn_dropout=0.3 \\\n                         --train_lr=0.001 \\\n                         --train_min_lr=0.0001 \\\n                         --train_max_epoch=30 \\\n                         --use_one_hot_fea \\\n                         --gen_r_num_basis_func=4 \\\n                         --gpu 0,1,2,3,4,5,6,7\n```\nml-10m, no feature with mix_cpu_gpu run\n```bash\npython train_sampling.py --data_name=ml-10m \\\n                         --gcn_agg_accum=stack \\\n                         --gcn_dropout=0.3 \\\n                         --train_lr=0.001 \\\n                         --train_min_lr=0.0001 \\\n                         --train_max_epoch=30 \\\n                         --use_one_hot_fea \\\n                         --gen_r_num_basis_func=4 \\\n                         --mix_cpu_gpu \\\n                         --gpu 0,1,2,3,4,5,6,7\n```\nResults: RMSE=0.8084\nTestbed: EC2 p3.16xlarge instance\n\n### Train with minibatch on CPU\nml-100k, no feature\n```bash\npython3 train_sampling.py --data_name=ml-100k \\\n                          --use_one_hot_fea \\\n                          --gcn_agg_accum=stack \\\n                          --gpu -1\n```\nTestbed: EC2 r5.xlarge instance\n"
  },
  {
    "path": "examples/pytorch/gcmc/data.py",
    "content": "\"\"\"MovieLens dataset\"\"\"\nimport os\nimport re\n\nimport dgl\nimport numpy as np\nimport pandas as pd\nimport scipy.sparse as sp\nimport torch as th\nfrom dgl.data.utils import download, extract_archive, get_download_dir\nfrom utils import to_etype_name\n\n_urls = {\n    \"ml-100k\": \"http://files.grouplens.org/datasets/movielens/ml-100k.zip\",\n    \"ml-1m\": \"http://files.grouplens.org/datasets/movielens/ml-1m.zip\",\n    \"ml-10m\": \"http://files.grouplens.org/datasets/movielens/ml-10m.zip\",\n}\n\nREAD_DATASET_PATH = get_download_dir()\nGENRES_ML_100K = [\n    \"unknown\",\n    \"Action\",\n    \"Adventure\",\n    \"Animation\",\n    \"Children\",\n    \"Comedy\",\n    \"Crime\",\n    \"Documentary\",\n    \"Drama\",\n    \"Fantasy\",\n    \"Film-Noir\",\n    \"Horror\",\n    \"Musical\",\n    \"Mystery\",\n    \"Romance\",\n    \"Sci-Fi\",\n    \"Thriller\",\n    \"War\",\n    \"Western\",\n]\nGENRES_ML_1M = GENRES_ML_100K[1:]\nGENRES_ML_10M = GENRES_ML_100K + [\"IMAX\"]\n\n\nclass MovieLens(object):\n    \"\"\"MovieLens dataset used by GCMC model\n\n    TODO(minjie): make this dataset more general\n\n    The dataset stores MovieLens ratings in two types of graphs. The encoder graph\n    contains rating value information in the form of edge types. The decoder graph\n    stores plain user-movie pairs in the form of a bipartite graph with no rating\n    information. 
All graphs have two types of nodes: \"user\" and \"movie\".\n\n    The training, validation and test set can be summarized as follows:\n\n    training_enc_graph : training user-movie pairs + rating info\n    training_dec_graph : training user-movie pairs\n    valid_enc_graph : training user-movie pairs + rating info\n    valid_dec_graph : validation user-movie pairs\n    test_enc_graph : training user-movie pairs + validation user-movie pairs + rating info\n    test_dec_graph : test user-movie pairs\n\n    Attributes\n    ----------\n    train_enc_graph : dgl.DGLGraph\n        Encoder graph for training.\n    train_dec_graph : dgl.DGLGraph\n        Decoder graph for training.\n    train_labels : torch.Tensor\n        The categorical label of each user-movie pair\n    train_truths : torch.Tensor\n        The actual rating values of each user-movie pair\n    valid_enc_graph : dgl.DGLGraph\n        Encoder graph for validation.\n    valid_dec_graph : dgl.DGLGraph\n        Decoder graph for validation.\n    valid_labels : torch.Tensor\n        The categorical label of each user-movie pair\n    valid_truths : torch.Tensor\n        The actual rating values of each user-movie pair\n    test_enc_graph : dgl.DGLGraph\n        Encoder graph for test.\n    test_dec_graph : dgl.DGLGraph\n        Decoder graph for test.\n    test_labels : torch.Tensor\n        The categorical label of each user-movie pair\n    test_truths : torch.Tensor\n        The actual rating values of each user-movie pair\n    user_feature : torch.Tensor\n        User feature tensor. If None, representing an identity matrix.\n    movie_feature : torch.Tensor\n        Movie feature tensor. If None, representing an identity matrix.\n    possible_rating_values : np.ndarray\n        Available rating values in the dataset\n\n    Parameters\n    ----------\n    name : str\n        Dataset name. 
Could be \"ml-100k\", \"ml-1m\", \"ml-10m\"\n    device : torch.device\n        Device context\n    mix_cpu_gpu : boo, optional\n        If true, the ``user_feature`` attribute is stored in CPU\n    use_one_hot_fea : bool, optional\n        If true, the ``user_feature`` attribute is None, representing an one-hot identity\n        matrix. (Default: False)\n    symm : bool, optional\n        If true, the use symmetric normalize constant. Otherwise, use left normalize\n        constant. (Default: True)\n    test_ratio : float, optional\n        Ratio of test data\n    valid_ratio : float, optional\n        Ratio of validation data\n\n    \"\"\"\n\n    def __init__(\n        self,\n        name,\n        device,\n        mix_cpu_gpu=False,\n        use_one_hot_fea=False,\n        symm=True,\n        test_ratio=0.1,\n        valid_ratio=0.1,\n    ):\n        self._name = name\n        self._device = device\n        self._symm = symm\n        self._test_ratio = test_ratio\n        self._valid_ratio = valid_ratio\n        # download and extract\n        download_dir = get_download_dir()\n        zip_file_path = \"{}/{}.zip\".format(download_dir, name)\n        download(_urls[name], path=zip_file_path)\n        extract_archive(zip_file_path, \"{}/{}\".format(download_dir, name))\n        if name == \"ml-10m\":\n            root_folder = \"ml-10M100K\"\n        else:\n            root_folder = name\n        self._dir = os.path.join(download_dir, name, root_folder)\n        print(\"Starting processing {} ...\".format(self._name))\n        self._load_raw_user_info()\n        self._load_raw_movie_info()\n        print(\"......\")\n        if self._name == \"ml-100k\":\n            self.all_train_rating_info = self._load_raw_rates(\n                os.path.join(self._dir, \"u1.base\"), \"\\t\"\n            )\n            self.test_rating_info = self._load_raw_rates(\n                os.path.join(self._dir, \"u1.test\"), \"\\t\"\n            )\n            self.all_rating_info = 
pd.concat(\n                [self.all_train_rating_info, self.test_rating_info]\n            )\n        elif self._name == \"ml-1m\" or self._name == \"ml-10m\":\n            self.all_rating_info = self._load_raw_rates(\n                os.path.join(self._dir, \"ratings.dat\"), \"::\"\n            )\n            num_test = int(\n                np.ceil(self.all_rating_info.shape[0] * self._test_ratio)\n            )\n            shuffled_idx = np.random.permutation(self.all_rating_info.shape[0])\n            self.test_rating_info = self.all_rating_info.iloc[\n                shuffled_idx[:num_test]\n            ]\n            self.all_train_rating_info = self.all_rating_info.iloc[\n                shuffled_idx[num_test:]\n            ]\n        else:\n            raise NotImplementedError\n        print(\"......\")\n        num_valid = int(\n            np.ceil(self.all_train_rating_info.shape[0] * self._valid_ratio)\n        )\n        shuffled_idx = np.random.permutation(\n            self.all_train_rating_info.shape[0]\n        )\n        self.valid_rating_info = self.all_train_rating_info.iloc[\n            shuffled_idx[:num_valid]\n        ]\n        self.train_rating_info = self.all_train_rating_info.iloc[\n            shuffled_idx[num_valid:]\n        ]\n        self.possible_rating_values = np.unique(\n            self.train_rating_info[\"rating\"].values\n        )\n\n        print(\"All rating pairs : {}\".format(self.all_rating_info.shape[0]))\n        print(\n            \"\\tAll train rating pairs : {}\".format(\n                self.all_train_rating_info.shape[0]\n            )\n        )\n        print(\n            \"\\t\\tTrain rating pairs : {}\".format(\n                self.train_rating_info.shape[0]\n            )\n        )\n        print(\n            \"\\t\\tValid rating pairs : {}\".format(\n                self.valid_rating_info.shape[0]\n            )\n        )\n        print(\n            \"\\tTest rating pairs  : 
{}\".format(self.test_rating_info.shape[0])\n        )\n\n        self.user_info = self._drop_unseen_nodes(\n            orign_info=self.user_info,\n            cmp_col_name=\"id\",\n            reserved_ids_set=set(self.all_rating_info[\"user_id\"].values),\n            label=\"user\",\n        )\n        self.movie_info = self._drop_unseen_nodes(\n            orign_info=self.movie_info,\n            cmp_col_name=\"id\",\n            reserved_ids_set=set(self.all_rating_info[\"movie_id\"].values),\n            label=\"movie\",\n        )\n\n        # Map user/movie to the global id\n        self.global_user_id_map = {\n            ele: i for i, ele in enumerate(self.user_info[\"id\"])\n        }\n        self.global_movie_id_map = {\n            ele: i for i, ele in enumerate(self.movie_info[\"id\"])\n        }\n        print(\n            \"Total user number = {}, movie number = {}\".format(\n                len(self.global_user_id_map), len(self.global_movie_id_map)\n            )\n        )\n        self._num_user = len(self.global_user_id_map)\n        self._num_movie = len(self.global_movie_id_map)\n\n        ### Generate features\n        if use_one_hot_fea:\n            self.user_feature = None\n            self.movie_feature = None\n        else:\n            # if mix_cpu_gpu, we put features in CPU\n            if mix_cpu_gpu:\n                self.user_feature = th.FloatTensor(self._process_user_fea())\n                self.movie_feature = th.FloatTensor(self._process_movie_fea())\n            else:\n                self.user_feature = th.FloatTensor(self._process_user_fea()).to(\n                    self._device\n                )\n                self.movie_feature = th.FloatTensor(\n                    self._process_movie_fea()\n                ).to(self._device)\n        if self.user_feature is None:\n            self.user_feature_shape = (self.num_user, self.num_user)\n            self.movie_feature_shape = (self.num_movie, self.num_movie)\n        
else:\n            self.user_feature_shape = self.user_feature.shape\n            self.movie_feature_shape = self.movie_feature.shape\n        info_line = \"Feature dim: \"\n        info_line += \"\\nuser: {}\".format(self.user_feature_shape)\n        info_line += \"\\nmovie: {}\".format(self.movie_feature_shape)\n        print(info_line)\n\n        (\n            all_train_rating_pairs,\n            all_train_rating_values,\n        ) = self._generate_pair_value(self.all_train_rating_info)\n        train_rating_pairs, train_rating_values = self._generate_pair_value(\n            self.train_rating_info\n        )\n        valid_rating_pairs, valid_rating_values = self._generate_pair_value(\n            self.valid_rating_info\n        )\n        test_rating_pairs, test_rating_values = self._generate_pair_value(\n            self.test_rating_info\n        )\n\n        def _make_labels(ratings):\n            labels = th.LongTensor(\n                np.searchsorted(self.possible_rating_values, ratings)\n            ).to(device)\n            return labels\n\n        self.train_enc_graph = self._generate_enc_graph(\n            train_rating_pairs, train_rating_values, add_support=True\n        )\n        self.train_dec_graph = self._generate_dec_graph(train_rating_pairs)\n        self.train_labels = _make_labels(train_rating_values)\n        self.train_truths = th.FloatTensor(train_rating_values).to(device)\n\n        self.valid_enc_graph = self.train_enc_graph\n        self.valid_dec_graph = self._generate_dec_graph(valid_rating_pairs)\n        self.valid_labels = _make_labels(valid_rating_values)\n        self.valid_truths = th.FloatTensor(valid_rating_values).to(device)\n\n        self.test_enc_graph = self._generate_enc_graph(\n            all_train_rating_pairs, all_train_rating_values, add_support=True\n        )\n        self.test_dec_graph = self._generate_dec_graph(test_rating_pairs)\n        self.test_labels = _make_labels(test_rating_values)\n        
self.test_truths = th.FloatTensor(test_rating_values).to(device)\n\n        def _npairs(graph):\n            rst = 0\n            for r in self.possible_rating_values:\n                r = to_etype_name(r)\n                rst += graph.num_edges(str(r))\n            return rst\n\n        print(\n            \"Train enc graph: \\t#user:{}\\t#movie:{}\\t#pairs:{}\".format(\n                self.train_enc_graph.num_nodes(\"user\"),\n                self.train_enc_graph.num_nodes(\"movie\"),\n                _npairs(self.train_enc_graph),\n            )\n        )\n        print(\n            \"Train dec graph: \\t#user:{}\\t#movie:{}\\t#pairs:{}\".format(\n                self.train_dec_graph.num_nodes(\"user\"),\n                self.train_dec_graph.num_nodes(\"movie\"),\n                self.train_dec_graph.num_edges(),\n            )\n        )\n        print(\n            \"Valid enc graph: \\t#user:{}\\t#movie:{}\\t#pairs:{}\".format(\n                self.valid_enc_graph.num_nodes(\"user\"),\n                self.valid_enc_graph.num_nodes(\"movie\"),\n                _npairs(self.valid_enc_graph),\n            )\n        )\n        print(\n            \"Valid dec graph: \\t#user:{}\\t#movie:{}\\t#pairs:{}\".format(\n                self.valid_dec_graph.num_nodes(\"user\"),\n                self.valid_dec_graph.num_nodes(\"movie\"),\n                self.valid_dec_graph.num_edges(),\n            )\n        )\n        print(\n            \"Test enc graph: \\t#user:{}\\t#movie:{}\\t#pairs:{}\".format(\n                self.test_enc_graph.num_nodes(\"user\"),\n                self.test_enc_graph.num_nodes(\"movie\"),\n                _npairs(self.test_enc_graph),\n            )\n        )\n        print(\n            \"Test dec graph: \\t#user:{}\\t#movie:{}\\t#pairs:{}\".format(\n                self.test_dec_graph.num_nodes(\"user\"),\n                self.test_dec_graph.num_nodes(\"movie\"),\n                self.test_dec_graph.num_edges(),\n            )\n       
 )\n\n    def _generate_pair_value(self, rating_info):\n        rating_pairs = (\n            np.array(\n                [\n                    self.global_user_id_map[ele]\n                    for ele in rating_info[\"user_id\"]\n                ],\n                dtype=np.int64,\n            ),\n            np.array(\n                [\n                    self.global_movie_id_map[ele]\n                    for ele in rating_info[\"movie_id\"]\n                ],\n                dtype=np.int64,\n            ),\n        )\n        rating_values = rating_info[\"rating\"].values.astype(np.float32)\n        return rating_pairs, rating_values\n\n    def _generate_enc_graph(\n        self, rating_pairs, rating_values, add_support=False\n    ):\n        user_movie_R = np.zeros(\n            (self._num_user, self._num_movie), dtype=np.float32\n        )\n        user_movie_R[rating_pairs] = rating_values\n\n        data_dict = dict()\n        num_nodes_dict = {\"user\": self._num_user, \"movie\": self._num_movie}\n        rating_row, rating_col = rating_pairs\n        for rating in self.possible_rating_values:\n            ridx = np.where(rating_values == rating)\n            rrow = rating_row[ridx]\n            rcol = rating_col[ridx]\n            rating = to_etype_name(rating)\n            data_dict.update(\n                {\n                    (\"user\", str(rating), \"movie\"): (rrow, rcol),\n                    (\"movie\", \"rev-%s\" % str(rating), \"user\"): (rcol, rrow),\n                }\n            )\n        graph = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)\n\n        # sanity check\n        assert (\n            len(rating_pairs[0])\n            == sum([graph.num_edges(et) for et in graph.etypes]) // 2\n        )\n\n        if add_support:\n\n            def _calc_norm(x):\n                x = x.numpy().astype(\"float32\")\n                x[x == 0.0] = np.inf\n                x = th.FloatTensor(1.0 / np.sqrt(x))\n                return 
x.unsqueeze(1)\n\n            user_ci = []\n            user_cj = []\n            movie_ci = []\n            movie_cj = []\n            for r in self.possible_rating_values:\n                r = to_etype_name(r)\n                user_ci.append(graph[\"rev-%s\" % r].in_degrees())\n                movie_ci.append(graph[r].in_degrees())\n                if self._symm:\n                    user_cj.append(graph[r].out_degrees())\n                    movie_cj.append(graph[\"rev-%s\" % r].out_degrees())\n                else:\n                    user_cj.append(th.zeros((self.num_user,)))\n                    movie_cj.append(th.zeros((self.num_movie,)))\n            user_ci = _calc_norm(sum(user_ci))\n            movie_ci = _calc_norm(sum(movie_ci))\n            if self._symm:\n                user_cj = _calc_norm(sum(user_cj))\n                movie_cj = _calc_norm(sum(movie_cj))\n            else:\n                user_cj = th.ones(\n                    self.num_user,\n                )\n                movie_cj = th.ones(\n                    self.num_movie,\n                )\n            graph.nodes[\"user\"].data.update({\"ci\": user_ci, \"cj\": user_cj})\n            graph.nodes[\"movie\"].data.update({\"ci\": movie_ci, \"cj\": movie_cj})\n\n        return graph\n\n    def _generate_dec_graph(self, rating_pairs):\n        ones = np.ones_like(rating_pairs[0])\n        user_movie_ratings_coo = sp.coo_matrix(\n            (ones, rating_pairs),\n            shape=(self.num_user, self.num_movie),\n            dtype=np.float32,\n        )\n        g = dgl.bipartite_from_scipy(\n            user_movie_ratings_coo, utype=\"_U\", etype=\"_E\", vtype=\"_V\"\n        )\n        return dgl.heterograph(\n            {(\"user\", \"rate\", \"movie\"): g.edges()},\n            num_nodes_dict={\"user\": self.num_user, \"movie\": self.num_movie},\n        )\n\n    @property\n    def num_links(self):\n        return self.possible_rating_values.size\n\n    @property\n    def 
num_user(self):\n        return self._num_user\n\n    @property\n    def num_movie(self):\n        return self._num_movie\n\n    def _drop_unseen_nodes(\n        self, orign_info, cmp_col_name, reserved_ids_set, label\n    ):\n        # print(\"  -----------------\")\n        # print(\"{}: {}(reserved) v.s. {}(from info)\".format(label, len(reserved_ids_set),\n        #                                                      len(set(orign_info[cmp_col_name].values))))\n        if reserved_ids_set != set(orign_info[cmp_col_name].values):\n            pd_rating_ids = pd.DataFrame(\n                list(reserved_ids_set), columns=[\"id_graph\"]\n            )\n            # print(\"\\torign_info: ({}, {})\".format(orign_info.shape[0], orign_info.shape[1]))\n            data_info = orign_info.merge(\n                pd_rating_ids,\n                left_on=cmp_col_name,\n                right_on=\"id_graph\",\n                how=\"outer\",\n            )\n            data_info = data_info.dropna(subset=[cmp_col_name, \"id_graph\"])\n            data_info = data_info.drop(columns=[\"id_graph\"])\n            data_info = data_info.reset_index(drop=True)\n            # print(\"\\tAfter dropping, data shape: ({}, {})\".format(data_info.shape[0], data_info.shape[1]))\n            return data_info\n        else:\n            orign_info = orign_info.reset_index(drop=True)\n            return orign_info\n\n    def _load_raw_rates(self, file_path, sep):\n        \"\"\"In MovieLens, the rates have the following format\n\n        ml-100k\n        user id \\t movie id \\t rating \\t timestamp\n\n        ml-1m/10m\n        UserID::MovieID::Rating::Timestamp\n\n        timestamp is unix timestamp and can be converted by pd.to_datetime(X, unit='s')\n\n        Parameters\n        ----------\n        file_path : str\n\n        Returns\n        -------\n        rating_info : pd.DataFrame\n        \"\"\"\n        rating_info = pd.read_csv(\n            file_path,\n            sep=sep,\n     
       header=None,\n            names=[\"user_id\", \"movie_id\", \"rating\", \"timestamp\"],\n            dtype={\n                \"user_id\": np.int32,\n                \"movie_id\": np.int32,\n                \"ratings\": np.float32,\n                \"timestamp\": np.int64,\n            },\n            engine=\"python\",\n        )\n        return rating_info\n\n    def _load_raw_user_info(self):\n        \"\"\"In MovieLens, the user attributes file have the following formats:\n\n        ml-100k:\n        user id | age | gender | occupation | zip code\n\n        ml-1m:\n        UserID::Gender::Age::Occupation::Zip-code\n\n        For ml-10m, there is no user information. We read the user id from the rating file.\n\n        Parameters\n        ----------\n        name : str\n\n        Returns\n        -------\n        user_info : pd.DataFrame\n        \"\"\"\n        if self._name == \"ml-100k\":\n            self.user_info = pd.read_csv(\n                os.path.join(self._dir, \"u.user\"),\n                sep=\"|\",\n                header=None,\n                names=[\"id\", \"age\", \"gender\", \"occupation\", \"zip_code\"],\n                engine=\"python\",\n            )\n        elif self._name == \"ml-1m\":\n            self.user_info = pd.read_csv(\n                os.path.join(self._dir, \"users.dat\"),\n                sep=\"::\",\n                header=None,\n                names=[\"id\", \"gender\", \"age\", \"occupation\", \"zip_code\"],\n                engine=\"python\",\n            )\n        elif self._name == \"ml-10m\":\n            rating_info = pd.read_csv(\n                os.path.join(self._dir, \"ratings.dat\"),\n                sep=\"::\",\n                header=None,\n                names=[\"user_id\", \"movie_id\", \"rating\", \"timestamp\"],\n                dtype={\n                    \"user_id\": np.int32,\n                    \"movie_id\": np.int32,\n                    \"ratings\": np.float32,\n                    
\"timestamp\": np.int64,\n                },\n                engine=\"python\",\n            )\n            self.user_info = pd.DataFrame(\n                np.unique(rating_info[\"user_id\"].values.astype(np.int32)),\n                columns=[\"id\"],\n            )\n        else:\n            raise NotImplementedError\n\n    def _process_user_fea(self):\n        \"\"\"\n\n        Parameters\n        ----------\n        user_info : pd.DataFrame\n        name : str\n        For ml-100k and ml-1m, the column name is ['id', 'gender', 'age', 'occupation', 'zip_code'].\n            We take the age, gender, and the one-hot encoding of the occupation as the user features.\n        For ml-10m, there is no user feature and we set the feature to be a single zero.\n\n        Returns\n        -------\n        user_features : np.ndarray\n\n        \"\"\"\n        if self._name == \"ml-100k\" or self._name == \"ml-1m\":\n            ages = self.user_info[\"age\"].values.astype(np.float32)\n            gender = (self.user_info[\"gender\"] == \"F\").values.astype(np.float32)\n            all_occupations = set(self.user_info[\"occupation\"])\n            occupation_map = {ele: i for i, ele in enumerate(all_occupations)}\n            occupation_one_hot = np.zeros(\n                shape=(self.user_info.shape[0], len(all_occupations)),\n                dtype=np.float32,\n            )\n            occupation_one_hot[\n                np.arange(self.user_info.shape[0]),\n                np.array(\n                    [\n                        occupation_map[ele]\n                        for ele in self.user_info[\"occupation\"]\n                    ]\n                ),\n            ] = 1\n            user_features = np.concatenate(\n                [\n                    ages.reshape((self.user_info.shape[0], 1)) / 50.0,\n                    gender.reshape((self.user_info.shape[0], 1)),\n                    occupation_one_hot,\n                ],\n                axis=1,\n          
  )\n        elif self._name == \"ml-10m\":\n            user_features = np.zeros(\n                shape=(self.user_info.shape[0], 1), dtype=np.float32\n            )\n        else:\n            raise NotImplementedError\n        return user_features\n\n    def _load_raw_movie_info(self):\n        \"\"\"In MovieLens, the movie attributes may have the following formats:\n\n        In ml_100k:\n\n        movie id | movie title | release date | video release date | IMDb URL | [genres]\n\n        In ml_1m, ml_10m:\n\n        MovieID::Title (Release Year)::Genres\n\n        Also, Genres are separated by |, e.g., Adventure|Animation|Children|Comedy|Fantasy\n\n        Parameters\n        ----------\n        name : str\n\n        Returns\n        -------\n        movie_info : pd.DataFrame\n            For ml-100k, the column name is ['id', 'title', 'release_date', 'video_release_date', 'url'] + [GENRES (19)]]\n            For ml-1m and ml-10m, the column name is ['id', 'title'] + [GENRES (18/20)]]\n        \"\"\"\n        if self._name == \"ml-100k\":\n            GENRES = GENRES_ML_100K\n        elif self._name == \"ml-1m\":\n            GENRES = GENRES_ML_1M\n        elif self._name == \"ml-10m\":\n            GENRES = GENRES_ML_10M\n        else:\n            raise NotImplementedError\n\n        if self._name == \"ml-100k\":\n            file_path = os.path.join(self._dir, \"u.item\")\n            self.movie_info = pd.read_csv(\n                file_path,\n                sep=\"|\",\n                header=None,\n                names=[\n                    \"id\",\n                    \"title\",\n                    \"release_date\",\n                    \"video_release_date\",\n                    \"url\",\n                ]\n                + GENRES,\n                encoding=\"iso-8859-1\",\n            )\n        elif self._name == \"ml-1m\" or self._name == \"ml-10m\":\n            file_path = os.path.join(self._dir, \"movies.dat\")\n            movie_info = 
pd.read_csv(\n                file_path,\n                sep=\"::\",\n                header=None,\n                names=[\"id\", \"title\", \"genres\"],\n                encoding=\"iso-8859-1\",\n                engine=\"python\",\n            )\n            genre_map = {ele: i for i, ele in enumerate(GENRES)}\n            genre_map[\"Children's\"] = genre_map[\"Children\"]\n            genre_map[\"Childrens\"] = genre_map[\"Children\"]\n            movie_genres = np.zeros(\n                shape=(movie_info.shape[0], len(GENRES)), dtype=np.float32\n            )\n            for i, genres in enumerate(movie_info[\"genres\"]):\n                for ele in genres.split(\"|\"):\n                    if ele in genre_map:\n                        movie_genres[i, genre_map[ele]] = 1.0\n                    else:\n                        print(\n                            \"genres not found, filled with unknown: {}\".format(\n                                genres\n                            )\n                        )\n                        movie_genres[i, genre_map[\"unknown\"]] = 1.0\n            for idx, genre_name in enumerate(GENRES):\n                assert idx == genre_map[genre_name]\n                movie_info[genre_name] = movie_genres[:, idx]\n            self.movie_info = movie_info.drop(columns=[\"genres\"])\n        else:\n            raise NotImplementedError\n\n    def _process_movie_fea(self):\n        \"\"\"\n\n        Parameters\n        ----------\n        movie_info : pd.DataFrame\n        name :  str\n\n        Returns\n        -------\n        movie_features : np.ndarray\n            Generate movie features by concatenating embedding and the year\n\n        \"\"\"\n        import torchtext\n        from torchtext.data.utils import get_tokenizer\n\n        if self._name == \"ml-100k\":\n            GENRES = GENRES_ML_100K\n        elif self._name == \"ml-1m\":\n            GENRES = GENRES_ML_1M\n        elif self._name == \"ml-10m\":\n         
   GENRES = GENRES_ML_10M\n        else:\n            raise NotImplementedError\n\n        # Old torchtext-legacy API commented below\n        # TEXT = torchtext.legacy.data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')\n        tokenizer = get_tokenizer(\n            \"spacy\", language=\"en_core_web_sm\"\n        )  # new API (torchtext 0.9+)\n        embedding = torchtext.vocab.GloVe(name=\"840B\", dim=300)\n\n        title_embedding = np.zeros(\n            shape=(self.movie_info.shape[0], 300), dtype=np.float32\n        )\n        release_years = np.zeros(\n            shape=(self.movie_info.shape[0], 1), dtype=np.float32\n        )\n        p = re.compile(r\"(.+)\\s*\\((\\d+)\\)\")\n        for i, title in enumerate(self.movie_info[\"title\"]):\n            match_res = p.match(title)\n            if match_res is None:\n                print(\n                    \"{} cannot be matched, index={}, name={}\".format(\n                        title, i, self._name\n                    )\n                )\n                title_context, year = title, 1950\n            else:\n                title_context, year = match_res.groups()\n            # We use average of glove\n            # Upgraded torchtext API:  TEXT.tokenize(title_context) --> tokenizer(title_context)\n            title_embedding[i, :] = (\n                embedding.get_vecs_by_tokens(tokenizer(title_context))\n                .numpy()\n                .mean(axis=0)\n            )\n            release_years[i] = float(year)\n        movie_features = np.concatenate(\n            (\n                title_embedding,\n                (release_years - 1950.0) / 100.0,\n                self.movie_info[GENRES],\n            ),\n            axis=1,\n        )\n        return movie_features\n\n\nif __name__ == \"__main__\":\n    MovieLens(\"ml-100k\", device=th.device(\"cpu\"), symm=True)\n"
  },
  {
    "path": "examples/pytorch/gcmc/model.py",
    "content": "\"\"\"NN modules\"\"\"\nimport dgl.function as fn\nimport dgl.nn.pytorch as dglnn\nimport torch as th\nimport torch.nn as nn\nfrom torch.nn import init\n\nfrom utils import get_activation, to_etype_name\n\n\nclass GCMCGraphConv(nn.Module):\n    \"\"\"Graph convolution module used in the GCMC model.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size.\n    out_feats : int\n        Output feature size.\n    weight : bool, optional\n        If True, apply a linear layer. Otherwise, aggregating the messages\n        without a weight matrix or with an shared weight provided by caller.\n    device: str, optional\n        Which device to put data in. Useful in mix_cpu_gpu training and\n        multi-gpu training\n    \"\"\"\n\n    def __init__(\n        self, in_feats, out_feats, weight=True, device=None, dropout_rate=0.0\n    ):\n        super(GCMCGraphConv, self).__init__()\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self.device = device\n        self.dropout = nn.Dropout(dropout_rate)\n\n        if weight:\n            self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))\n        else:\n            self.register_parameter(\"weight\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"Reinitialize learnable parameters.\"\"\"\n        if self.weight is not None:\n            init.xavier_uniform_(self.weight)\n\n    def forward(self, graph, feat, weight=None):\n        \"\"\"Compute graph convolution.\n\n        Normalizer constant :math:`c_{ij}` is stored as two node data \"ci\"\n        and \"cj\".\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            The input feature\n        weight : torch.Tensor, optional\n            Optional external weight tensor.\n        dropout : torch.nn.Dropout, optional\n            Optional external dropout layer.\n\n        Returns\n   
     -------\n        torch.Tensor\n            The output feature\n        \"\"\"\n        with graph.local_scope():\n            if isinstance(feat, tuple):\n                feat, _ = feat  # dst feature not used\n            cj = graph.srcdata[\"cj\"]\n            ci = graph.dstdata[\"ci\"]\n            if self.device is not None:\n                cj = cj.to(self.device)\n                ci = ci.to(self.device)\n            if weight is not None:\n                if self.weight is not None:\n                    raise DGLError(\n                        \"External weight is provided while at the same time the\"\n                        \" module has defined its own weight parameter. Please\"\n                        \" create the module with flag weight=False.\"\n                    )\n            else:\n                weight = self.weight\n\n            if weight is not None:\n                feat = dot_or_identity(feat, weight, self.device)\n\n            feat = feat * self.dropout(cj)\n            graph.srcdata[\"h\"] = feat\n            graph.update_all(\n                fn.copy_u(u=\"h\", out=\"m\"), fn.sum(msg=\"m\", out=\"h\")\n            )\n            rst = graph.dstdata[\"h\"]\n            rst = rst * ci\n\n        return rst\n\n\nclass GCMCLayer(nn.Module):\n    r\"\"\"GCMC layer\n\n    .. math::\n        z_j^{(l+1)} = \\sigma_{agg}\\left[\\mathrm{agg}\\left(\n        \\sum_{j\\in\\mathcal{N}_1}\\frac{1}{c_{ij}}W_1h_j, \\ldots,\n        \\sum_{j\\in\\mathcal{N}_R}\\frac{1}{c_{ij}}W_Rh_j\n        \\right)\\right]\n\n    After that, apply an extra output projection:\n\n    .. 
math::\n        h_j^{(l+1)} = \\sigma_{out}W_oz_j^{(l+1)}\n\n    The equation is applied to both user nodes and movie nodes and the parameters\n    are not shared unless ``share_user_item_param`` is true.\n\n    Parameters\n    ----------\n    rating_vals : list of int or float\n        Possible rating values.\n    user_in_units : int\n        Size of user input feature\n    movie_in_units : int\n        Size of movie input feature\n    msg_units : int\n        Size of message :math:`W_rh_j`\n    out_units : int\n        Size of of final output user and movie features\n    dropout_rate : float, optional\n        Dropout rate (Default: 0.0)\n    agg : str, optional\n        Function to aggregate messages of different ratings.\n        Could be any of the supported cross type reducers:\n        \"sum\", \"max\", \"min\", \"mean\", \"stack\".\n        (Default: \"stack\")\n    agg_act : callable, str, optional\n        Activation function :math:`sigma_{agg}`. (Default: None)\n    out_act : callable, str, optional\n        Activation function :math:`sigma_{agg}`. (Default: None)\n    share_user_item_param : bool, optional\n        If true, user node and movie node share the same set of parameters.\n        Require ``user_in_units`` and ``move_in_units`` to be the same.\n        (Default: False)\n    device: str, optional\n        Which device to put data in. 
Useful in mix_cpu_gpu training and\n        multi-gpu training\n    \"\"\"\n\n    def __init__(\n        self,\n        rating_vals,\n        user_in_units,\n        movie_in_units,\n        msg_units,\n        out_units,\n        dropout_rate=0.0,\n        agg=\"stack\",  # or 'sum'\n        agg_act=None,\n        out_act=None,\n        share_user_item_param=False,\n        device=None,\n    ):\n        super(GCMCLayer, self).__init__()\n        self.rating_vals = rating_vals\n        self.agg = agg\n        self.share_user_item_param = share_user_item_param\n        self.ufc = nn.Linear(msg_units, out_units)\n        if share_user_item_param:\n            self.ifc = self.ufc\n        else:\n            self.ifc = nn.Linear(msg_units, out_units)\n        if agg == \"stack\":\n            # divide the original msg unit size by number of ratings to keep\n            # the dimensionality\n            assert msg_units % len(rating_vals) == 0\n            msg_units = msg_units // len(rating_vals)\n        self.dropout = nn.Dropout(dropout_rate)\n        self.W_r = nn.ParameterDict()\n        subConv = {}\n        for rating in rating_vals:\n            # PyTorch parameter name can't contain \".\"\n            rating = to_etype_name(rating)\n            rev_rating = \"rev-%s\" % rating\n            if share_user_item_param and user_in_units == movie_in_units:\n                self.W_r[rating] = nn.Parameter(\n                    th.randn(user_in_units, msg_units)\n                )\n                self.W_r[\"rev-%s\" % rating] = self.W_r[rating]\n                subConv[rating] = GCMCGraphConv(\n                    user_in_units,\n                    msg_units,\n                    weight=False,\n                    device=device,\n                    dropout_rate=dropout_rate,\n                )\n                subConv[rev_rating] = GCMCGraphConv(\n                    user_in_units,\n                    msg_units,\n                    weight=False,\n                  
  device=device,\n                    dropout_rate=dropout_rate,\n                )\n            else:\n                self.W_r = None\n                subConv[rating] = GCMCGraphConv(\n                    user_in_units,\n                    msg_units,\n                    weight=True,\n                    device=device,\n                    dropout_rate=dropout_rate,\n                )\n                subConv[rev_rating] = GCMCGraphConv(\n                    movie_in_units,\n                    msg_units,\n                    weight=True,\n                    device=device,\n                    dropout_rate=dropout_rate,\n                )\n        self.conv = dglnn.HeteroGraphConv(subConv, aggregate=agg)\n        self.agg_act = get_activation(agg_act)\n        self.out_act = get_activation(out_act)\n        self.device = device\n        self.reset_parameters()\n\n    def partial_to(self, device):\n        \"\"\"Put parameters into device except W_r\n\n        Parameters\n        ----------\n        device : torch device\n            Which device the parameters are put in.\n        \"\"\"\n        assert device == self.device\n        if device is not None:\n            self.ufc.cuda(device)\n            if self.share_user_item_param is False:\n                self.ifc.cuda(device)\n            self.dropout.cuda(device)\n\n    def reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, graph, ufeat=None, ifeat=None):\n        \"\"\"Forward function\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            User-movie rating graph. It should contain two node types: \"user\"\n            and \"movie\" and many edge types each for one rating value.\n        ufeat : torch.Tensor, optional\n            User features. If None, using an identity matrix.\n        ifeat : torch.Tensor, optional\n            Movie features. 
If None, using an identity matrix.\n\n        Returns\n        -------\n        new_ufeat : torch.Tensor\n            New user features\n        new_ifeat : torch.Tensor\n            New movie features\n        \"\"\"\n        in_feats = {\"user\": ufeat, \"movie\": ifeat}\n        mod_args = {}\n        for i, rating in enumerate(self.rating_vals):\n            rating = to_etype_name(rating)\n            rev_rating = \"rev-%s\" % rating\n            mod_args[rating] = (\n                self.W_r[rating] if self.W_r is not None else None,\n            )\n            mod_args[rev_rating] = (\n                self.W_r[rev_rating] if self.W_r is not None else None,\n            )\n        out_feats = self.conv(graph, in_feats, mod_args=mod_args)\n        ufeat = out_feats[\"user\"]\n        ifeat = out_feats[\"movie\"]\n        ufeat = ufeat.view(ufeat.shape[0], -1)\n        ifeat = ifeat.view(ifeat.shape[0], -1)\n\n        # fc and non-linear\n        ufeat = self.agg_act(ufeat)\n        ifeat = self.agg_act(ifeat)\n        ufeat = self.dropout(ufeat)\n        ifeat = self.dropout(ifeat)\n        ufeat = self.ufc(ufeat)\n        ifeat = self.ifc(ifeat)\n        return self.out_act(ufeat), self.out_act(ifeat)\n\n\nclass BiDecoder(nn.Module):\n    r\"\"\"Bi-linear decoder.\n\n    Given a bipartite graph G, for each edge (i, j) ~ G, compute the likelihood\n    of it being class r by:\n\n    .. math::\n        p(M_{ij}=r) = \\text{softmax}(u_i^TQ_rv_j)\n\n    The trainable parameter :math:`Q_r` is further decomposed to a linear\n    combination of basis weight matrices :math:`P_s`:\n\n    .. math::\n        Q_r = \\sum_{s=1}^{b} a_{rs}P_s\n\n    Parameters\n    ----------\n    in_units : int\n        Size of input user and movie features\n    num_classes : int\n        Number of classes.\n    num_basis : int, optional\n        Number of basis. 
(Default: 2)\n    dropout_rate : float, optional\n        Dropout raite (Default: 0.0)\n    \"\"\"\n\n    def __init__(self, in_units, num_classes, num_basis=2, dropout_rate=0.0):\n        super(BiDecoder, self).__init__()\n        self._num_basis = num_basis\n        self.dropout = nn.Dropout(dropout_rate)\n        self.Ps = nn.ParameterList(\n            nn.Parameter(th.randn(in_units, in_units)) for _ in range(num_basis)\n        )\n        self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, graph, ufeat, ifeat):\n        \"\"\"Forward function.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            \"Flattened\" user-movie graph with only one edge type.\n        ufeat : th.Tensor\n            User embeddings. Shape: (|V_u|, D)\n        ifeat : th.Tensor\n            Movie embeddings. Shape: (|V_m|, D)\n\n        Returns\n        -------\n        th.Tensor\n            Predicting scores for each user-movie edge.\n        \"\"\"\n        with graph.local_scope():\n            ufeat = self.dropout(ufeat)\n            ifeat = self.dropout(ifeat)\n            graph.nodes[\"movie\"].data[\"h\"] = ifeat\n            basis_out = []\n            for i in range(self._num_basis):\n                graph.nodes[\"user\"].data[\"h\"] = ufeat @ self.Ps[i]\n                graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"sr\"))\n                basis_out.append(graph.edata[\"sr\"])\n            out = th.cat(basis_out, dim=1)\n            out = self.combine_basis(out)\n        return out\n\n\nclass DenseBiDecoder(nn.Module):\n    r\"\"\"Dense bi-linear decoder.\n\n    Dense implementation of the bi-linear decoder used in GCMC. 
Suitable when\n    the graph can be efficiently represented by a pair of arrays (one for source\n    nodes; one for destination nodes).\n\n    Parameters\n    ----------\n    in_units : int\n        Size of input user and movie features\n    num_classes : int\n        Number of classes.\n    num_basis : int, optional\n        Number of basis. (Default: 2)\n    dropout_rate : float, optional\n        Dropout raite (Default: 0.0)\n    \"\"\"\n\n    def __init__(self, in_units, num_classes, num_basis=2, dropout_rate=0.0):\n        super().__init__()\n        self._num_basis = num_basis\n        self.dropout = nn.Dropout(dropout_rate)\n        self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))\n        self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, ufeat, ifeat):\n        \"\"\"Forward function.\n\n        Compute logits for each pair ``(ufeat[i], ifeat[i])``.\n\n        Parameters\n        ----------\n        ufeat : th.Tensor\n            User embeddings. Shape: (B, D)\n        ifeat : th.Tensor\n            Movie embeddings. Shape: (B, D)\n\n        Returns\n        -------\n        th.Tensor\n            Predicting scores for each user-movie edge. Shape: (B, num_classes)\n        \"\"\"\n        ufeat = self.dropout(ufeat)\n        ifeat = self.dropout(ifeat)\n        out = th.einsum(\"ai,bij,aj->ab\", ufeat, self.P, ifeat)\n        out = self.combine_basis(out)\n        return out\n\n\ndef dot_or_identity(A, B, device=None):\n    # if A is None, treat as identity matrix\n    if A is None:\n        return B\n    elif len(A.shape) == 1:\n        if device is None:\n            return B[A]\n        else:\n            return B[A].to(device)\n    else:\n        return A @ B\n"
  },
  {
    "path": "examples/pytorch/gcmc/train.py",
    "content": "\"\"\"Training GCMC model on the MovieLens data set.\n\nThe script loads the full graph to the training device.\n\"\"\"\nimport argparse\nimport logging\nimport os\nimport random\nimport string\nimport time\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nfrom data import MovieLens\nfrom model import BiDecoder, GCMCLayer\nfrom utils import (\n    get_activation,\n    get_optimizer,\n    MetricLogger,\n    torch_net_info,\n    torch_total_param_num,\n)\n\n\nclass Net(nn.Module):\n    def __init__(self, args):\n        super(Net, self).__init__()\n        self._act = get_activation(args.model_activation)\n        self.encoder = GCMCLayer(\n            args.rating_vals,\n            args.src_in_units,\n            args.dst_in_units,\n            args.gcn_agg_units,\n            args.gcn_out_units,\n            args.gcn_dropout,\n            args.gcn_agg_accum,\n            agg_act=self._act,\n            share_user_item_param=args.share_param,\n            device=args.device,\n        )\n        self.decoder = BiDecoder(\n            in_units=args.gcn_out_units,\n            num_classes=len(args.rating_vals),\n            num_basis=args.gen_r_num_basis_func,\n        )\n\n    def forward(self, enc_graph, dec_graph, ufeat, ifeat):\n        user_out, movie_out = self.encoder(enc_graph, ufeat, ifeat)\n        pred_ratings = self.decoder(dec_graph, user_out, movie_out)\n        return pred_ratings\n\n\ndef evaluate(args, net, dataset, segment=\"valid\"):\n    possible_rating_values = dataset.possible_rating_values\n    nd_possible_rating_values = th.FloatTensor(possible_rating_values).to(\n        args.device\n    )\n\n    if segment == \"valid\":\n        rating_values = dataset.valid_truths\n        enc_graph = dataset.valid_enc_graph\n        dec_graph = dataset.valid_dec_graph\n    elif segment == \"test\":\n        rating_values = dataset.test_truths\n        enc_graph = dataset.test_enc_graph\n        dec_graph = 
dataset.test_dec_graph\n    else:\n        raise NotImplementedError\n\n    # Evaluate RMSE\n    net.eval()\n    with th.no_grad():\n        pred_ratings = net(\n            enc_graph, dec_graph, dataset.user_feature, dataset.movie_feature\n        )\n    real_pred_ratings = (\n        th.softmax(pred_ratings, dim=1) * nd_possible_rating_values.view(1, -1)\n    ).sum(dim=1)\n    rmse = ((real_pred_ratings - rating_values) ** 2.0).mean().item()\n    rmse = np.sqrt(rmse)\n    return rmse\n\n\ndef train(args):\n    print(args)\n    dataset = MovieLens(\n        args.data_name,\n        args.device,\n        use_one_hot_fea=args.use_one_hot_fea,\n        symm=args.gcn_agg_norm_symm,\n        test_ratio=args.data_test_ratio,\n        valid_ratio=args.data_valid_ratio,\n    )\n    print(\"Loading data finished ...\\n\")\n\n    args.src_in_units = dataset.user_feature_shape[1]\n    args.dst_in_units = dataset.movie_feature_shape[1]\n    args.rating_vals = dataset.possible_rating_values\n\n    ### build the net\n    net = Net(args=args)\n    net = net.to(args.device)\n    nd_possible_rating_values = th.FloatTensor(\n        dataset.possible_rating_values\n    ).to(args.device)\n    rating_loss_net = nn.CrossEntropyLoss()\n    learning_rate = args.train_lr\n    optimizer = get_optimizer(args.train_optimizer)(\n        net.parameters(), lr=learning_rate\n    )\n    print(\"Loading network finished ...\\n\")\n\n    ### perpare training data\n    train_gt_labels = dataset.train_labels\n    train_gt_ratings = dataset.train_truths\n\n    ### prepare the logger\n    train_loss_logger = MetricLogger(\n        [\"iter\", \"loss\", \"rmse\"],\n        [\"%d\", \"%.4f\", \"%.4f\"],\n        os.path.join(args.save_dir, \"train_loss%d.csv\" % args.save_id),\n    )\n    valid_loss_logger = MetricLogger(\n        [\"iter\", \"rmse\"],\n        [\"%d\", \"%.4f\"],\n        os.path.join(args.save_dir, \"valid_loss%d.csv\" % args.save_id),\n    )\n    test_loss_logger = MetricLogger(\n      
  [\"iter\", \"rmse\"],\n        [\"%d\", \"%.4f\"],\n        os.path.join(args.save_dir, \"test_loss%d.csv\" % args.save_id),\n    )\n\n    ### declare the loss information\n    best_valid_rmse = np.inf\n    no_better_valid = 0\n    best_iter = -1\n    count_rmse = 0\n    count_num = 0\n    count_loss = 0\n\n    dataset.train_enc_graph = dataset.train_enc_graph.int().to(args.device)\n    dataset.train_dec_graph = dataset.train_dec_graph.int().to(args.device)\n    dataset.valid_enc_graph = dataset.train_enc_graph\n    dataset.valid_dec_graph = dataset.valid_dec_graph.int().to(args.device)\n    dataset.test_enc_graph = dataset.test_enc_graph.int().to(args.device)\n    dataset.test_dec_graph = dataset.test_dec_graph.int().to(args.device)\n\n    print(\"Start training ...\")\n    dur = []\n    for iter_idx in range(1, args.train_max_iter):\n        if iter_idx > 3:\n            t0 = time.time()\n        net.train()\n        pred_ratings = net(\n            dataset.train_enc_graph,\n            dataset.train_dec_graph,\n            dataset.user_feature,\n            dataset.movie_feature,\n        )\n        loss = rating_loss_net(pred_ratings, train_gt_labels).mean()\n        count_loss += loss.item()\n        optimizer.zero_grad()\n        loss.backward()\n        nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)\n        optimizer.step()\n\n        if iter_idx > 3:\n            dur.append(time.time() - t0)\n\n        if iter_idx == 1:\n            print(\"Total #Param of net: %d\" % (torch_total_param_num(net)))\n            print(\n                torch_net_info(\n                    net,\n                    save_path=os.path.join(\n                        args.save_dir, \"net%d.txt\" % args.save_id\n                    ),\n                )\n            )\n\n        real_pred_ratings = (\n            th.softmax(pred_ratings, dim=1)\n            * nd_possible_rating_values.view(1, -1)\n        ).sum(dim=1)\n        rmse = ((real_pred_ratings - 
train_gt_ratings) ** 2).sum()\n        count_rmse += rmse.item()\n        count_num += pred_ratings.shape[0]\n\n        if iter_idx % args.train_log_interval == 0:\n            train_loss_logger.log(\n                iter=iter_idx,\n                loss=count_loss / (iter_idx + 1),\n                rmse=count_rmse / count_num,\n            )\n            logging_str = \"Iter={:4d}, loss={:.4f}, rmse={:.4f}\".format(\n                iter_idx,\n                count_loss / iter_idx,\n                count_rmse / count_num,\n            )\n            if iter_idx > 3:\n                logging_str += \", time={:.4f}\".format(np.average(dur))\n\n            count_rmse = 0\n            count_num = 0\n\n        if iter_idx % args.train_valid_interval == 0:\n            valid_rmse = evaluate(\n                args=args, net=net, dataset=dataset, segment=\"valid\"\n            )\n            valid_loss_logger.log(iter=iter_idx, rmse=valid_rmse)\n            logging_str += \",\\tVal RMSE={:.4f}\".format(valid_rmse)\n\n            if valid_rmse < best_valid_rmse:\n                best_valid_rmse = valid_rmse\n                no_better_valid = 0\n                best_iter = iter_idx\n                test_rmse = evaluate(\n                    args=args, net=net, dataset=dataset, segment=\"test\"\n                )\n                best_test_rmse = test_rmse\n                test_loss_logger.log(iter=iter_idx, rmse=test_rmse)\n                logging_str += \", Test RMSE={:.4f}\".format(test_rmse)\n            else:\n                no_better_valid += 1\n                if (\n                    no_better_valid > args.train_early_stopping_patience\n                    and learning_rate <= args.train_min_lr\n                ):\n                    logging.info(\n                        \"Early stopping threshold reached. 
Stop training.\"\n                    )\n                    break\n                if no_better_valid > args.train_decay_patience:\n                    new_lr = max(\n                        learning_rate * args.train_lr_decay_factor,\n                        args.train_min_lr,\n                    )\n                    if new_lr < learning_rate:\n                        learning_rate = new_lr\n                        logging.info(\"\\tChange the LR to %g\" % new_lr)\n                        for p in optimizer.param_groups:\n                            p[\"lr\"] = learning_rate\n                        no_better_valid = 0\n        if iter_idx % args.train_log_interval == 0:\n            print(logging_str)\n    print(\n        \"Best Iter Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}\".format(\n            best_iter, best_valid_rmse, best_test_rmse\n        )\n    )\n    train_loss_logger.close()\n    valid_loss_logger.close()\n    test_loss_logger.close()\n\n\ndef config():\n    parser = argparse.ArgumentParser(description=\"GCMC\")\n    parser.add_argument(\"--seed\", default=123, type=int)\n    parser.add_argument(\n        \"--device\",\n        default=\"0\",\n        type=int,\n        help=\"Running device. 
E.g `--device 0`, if using cpu, set `--device -1`\",\n    )\n    parser.add_argument(\"--save_dir\", type=str, help=\"The saving directory\")\n    parser.add_argument(\"--save_id\", type=int, help=\"The saving log id\")\n    parser.add_argument(\"--silent\", action=\"store_true\")\n    parser.add_argument(\n        \"--data_name\",\n        default=\"ml-1m\",\n        type=str,\n        help=\"The dataset name: ml-100k, ml-1m, ml-10m\",\n    )\n    parser.add_argument(\n        \"--data_test_ratio\", type=float, default=0.1\n    )  ## for ml-100k the test ration is 0.2\n    parser.add_argument(\"--data_valid_ratio\", type=float, default=0.1)\n    parser.add_argument(\"--use_one_hot_fea\", action=\"store_true\", default=False)\n    parser.add_argument(\"--model_activation\", type=str, default=\"leaky\")\n    parser.add_argument(\"--gcn_dropout\", type=float, default=0.7)\n    parser.add_argument(\"--gcn_agg_norm_symm\", type=bool, default=True)\n    parser.add_argument(\"--gcn_agg_units\", type=int, default=500)\n    parser.add_argument(\"--gcn_agg_accum\", type=str, default=\"sum\")\n    parser.add_argument(\"--gcn_out_units\", type=int, default=75)\n    parser.add_argument(\"--gen_r_num_basis_func\", type=int, default=2)\n    parser.add_argument(\"--train_max_iter\", type=int, default=2000)\n    parser.add_argument(\"--train_log_interval\", type=int, default=1)\n    parser.add_argument(\"--train_valid_interval\", type=int, default=1)\n    parser.add_argument(\"--train_optimizer\", type=str, default=\"adam\")\n    parser.add_argument(\"--train_grad_clip\", type=float, default=1.0)\n    parser.add_argument(\"--train_lr\", type=float, default=0.01)\n    parser.add_argument(\"--train_min_lr\", type=float, default=0.001)\n    parser.add_argument(\"--train_lr_decay_factor\", type=float, default=0.5)\n    parser.add_argument(\"--train_decay_patience\", type=int, default=50)\n    parser.add_argument(\n        \"--train_early_stopping_patience\", type=int, default=100\n    
)\n    parser.add_argument(\"--share_param\", default=False, action=\"store_true\")\n\n    args = parser.parse_args()\n    args.device = (\n        th.device(args.device) if args.device >= 0 else th.device(\"cpu\")\n    )\n\n    ### configure save_fir to save all the info\n    if args.save_dir is None:\n        args.save_dir = (\n            args.data_name\n            + \"_\"\n            + \"\".join(\n                random.choices(string.ascii_uppercase + string.digits, k=2)\n            )\n        )\n    if args.save_id is None:\n        args.save_id = np.random.randint(20)\n    args.save_dir = os.path.join(\"log\", args.save_dir)\n    if not os.path.isdir(args.save_dir):\n        os.makedirs(args.save_dir)\n\n    return args\n\n\nif __name__ == \"__main__\":\n    args = config()\n    np.random.seed(args.seed)\n    th.manual_seed(args.seed)\n    if th.cuda.is_available():\n        th.cuda.manual_seed_all(args.seed)\n    train(args)\n"
  },
  {
    "path": "examples/pytorch/gcmc/train_sampling.py",
    "content": "\"\"\"Training GCMC model on the MovieLens data set by mini-batch sampling.\n\nThe script loads the full graph in CPU and samples subgraphs for computing\ngradients on the training device. The script also supports multi-GPU for\nfurther acceleration.\n\"\"\"\nimport argparse\nimport logging\nimport os, time\nimport random\nimport string\nimport traceback\n\nimport dgl\nimport numpy as np\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport tqdm\nfrom data import MovieLens\nfrom model import BiDecoder, DenseBiDecoder, GCMCLayer\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.utils.data import DataLoader\nfrom utils import (\n    get_activation,\n    get_optimizer,\n    MetricLogger,\n    to_etype_name,\n    torch_net_info,\n    torch_total_param_num,\n)\n\n\nclass Net(nn.Module):\n    def __init__(self, args, dev_id):\n        super(Net, self).__init__()\n        self._act = get_activation(args.model_activation)\n        self.encoder = GCMCLayer(\n            args.rating_vals,\n            args.src_in_units,\n            args.dst_in_units,\n            args.gcn_agg_units,\n            args.gcn_out_units,\n            args.gcn_dropout,\n            args.gcn_agg_accum,\n            agg_act=self._act,\n            share_user_item_param=args.share_param,\n            device=dev_id,\n        )\n        if args.mix_cpu_gpu and args.use_one_hot_fea:\n            # if use_one_hot_fea, user and movie feature is None\n            # W can be extremely large, with mix_cpu_gpu W should be stored in CPU\n            self.encoder.partial_to(dev_id)\n        else:\n            self.encoder.to(dev_id)\n\n        self.decoder = BiDecoder(\n            in_units=args.gcn_out_units,\n            num_classes=len(args.rating_vals),\n            num_basis=args.gen_r_num_basis_func,\n        )\n        self.decoder.to(dev_id)\n\n    def forward(\n        self, compact_g, frontier, ufeat, ifeat, possible_rating_values\n    
):\n        user_out, movie_out = self.encoder(frontier, ufeat, ifeat)\n        pred_ratings = self.decoder(compact_g, user_out, movie_out)\n        return pred_ratings\n\n\ndef load_subtensor(input_nodes, pair_graph, blocks, dataset, parent_graph):\n    output_nodes = pair_graph.ndata[dgl.NID]\n    head_feat = (\n        input_nodes[\"user\"]\n        if dataset.user_feature is None\n        else dataset.user_feature[input_nodes[\"user\"]]\n    )\n    tail_feat = (\n        input_nodes[\"movie\"]\n        if dataset.movie_feature is None\n        else dataset.movie_feature[input_nodes[\"movie\"]]\n    )\n\n    for block in blocks:\n        block.dstnodes[\"user\"].data[\"ci\"] = parent_graph.nodes[\"user\"].data[\n            \"ci\"\n        ][block.dstnodes[\"user\"].data[dgl.NID]]\n        block.srcnodes[\"user\"].data[\"cj\"] = parent_graph.nodes[\"user\"].data[\n            \"cj\"\n        ][block.srcnodes[\"user\"].data[dgl.NID]]\n        block.dstnodes[\"movie\"].data[\"ci\"] = parent_graph.nodes[\"movie\"].data[\n            \"ci\"\n        ][block.dstnodes[\"movie\"].data[dgl.NID]]\n        block.srcnodes[\"movie\"].data[\"cj\"] = parent_graph.nodes[\"movie\"].data[\n            \"cj\"\n        ][block.srcnodes[\"movie\"].data[dgl.NID]]\n\n    return head_feat, tail_feat, blocks\n\n\ndef flatten_etypes(pair_graph, dataset, segment):\n    n_users = pair_graph.num_nodes(\"user\")\n    n_movies = pair_graph.num_nodes(\"movie\")\n    src = []\n    dst = []\n    labels = []\n    ratings = []\n\n    for rating in dataset.possible_rating_values:\n        src_etype, dst_etype = pair_graph.edges(\n            order=\"eid\", etype=to_etype_name(rating)\n        )\n        src.append(src_etype)\n        dst.append(dst_etype)\n        label = np.searchsorted(dataset.possible_rating_values, rating)\n        ratings.append(th.LongTensor(np.full_like(src_etype, rating)))\n        labels.append(th.LongTensor(np.full_like(src_etype, label)))\n    src = th.cat(src)\n    dst 
= th.cat(dst)\n    ratings = th.cat(ratings)\n    labels = th.cat(labels)\n\n    flattened_pair_graph = dgl.heterograph(\n        {(\"user\", \"rate\", \"movie\"): (src, dst)},\n        num_nodes_dict={\"user\": n_users, \"movie\": n_movies},\n    )\n    flattened_pair_graph.edata[\"rating\"] = ratings\n    flattened_pair_graph.edata[\"label\"] = labels\n\n    return flattened_pair_graph\n\n\ndef evaluate(args, dev_id, net, dataset, dataloader, segment=\"valid\"):\n    possible_rating_values = dataset.possible_rating_values\n    nd_possible_rating_values = th.FloatTensor(possible_rating_values).to(\n        dev_id\n    )\n\n    real_pred_ratings = []\n    true_rel_ratings = []\n    for input_nodes, pair_graph, blocks in dataloader:\n        head_feat, tail_feat, blocks = load_subtensor(\n            input_nodes,\n            pair_graph,\n            blocks,\n            dataset,\n            dataset.valid_enc_graph\n            if segment == \"valid\"\n            else dataset.test_enc_graph,\n        )\n        frontier = blocks[0]\n        true_relation_ratings = (\n            dataset.valid_truths[pair_graph.edata[dgl.EID]]\n            if segment == \"valid\"\n            else dataset.test_truths[pair_graph.edata[dgl.EID]]\n        )\n\n        frontier = frontier.to(dev_id)\n        head_feat = head_feat.to(dev_id)\n        tail_feat = tail_feat.to(dev_id)\n        pair_graph = pair_graph.to(dev_id)\n        with th.no_grad():\n            pred_ratings = net(\n                pair_graph,\n                frontier,\n                head_feat,\n                tail_feat,\n                possible_rating_values,\n            )\n        batch_pred_ratings = (\n            th.softmax(pred_ratings, dim=1)\n            * nd_possible_rating_values.view(1, -1)\n        ).sum(dim=1)\n        real_pred_ratings.append(batch_pred_ratings)\n        true_rel_ratings.append(true_relation_ratings)\n\n    real_pred_ratings = th.cat(real_pred_ratings, dim=0)\n    
true_rel_ratings = th.cat(true_rel_ratings, dim=0).to(dev_id)\n    rmse = ((real_pred_ratings - true_rel_ratings) ** 2.0).mean().item()\n    rmse = np.sqrt(rmse)\n    return rmse\n\n\ndef config():\n    parser = argparse.ArgumentParser(description=\"GCMC\")\n    parser.add_argument(\"--seed\", default=123, type=int)\n    parser.add_argument(\"--gpu\", type=str, default=\"0\")\n    parser.add_argument(\"--save_dir\", type=str, help=\"The saving directory\")\n    parser.add_argument(\"--save_id\", type=int, help=\"The saving log id\")\n    parser.add_argument(\"--silent\", action=\"store_true\")\n    parser.add_argument(\n        \"--data_name\",\n        default=\"ml-1m\",\n        type=str,\n        help=\"The dataset name: ml-100k, ml-1m, ml-10m\",\n    )\n    parser.add_argument(\n        \"--data_test_ratio\", type=float, default=0.1\n    )  ## for ml-100k the test ration is 0.2\n    parser.add_argument(\"--data_valid_ratio\", type=float, default=0.1)\n    parser.add_argument(\"--use_one_hot_fea\", action=\"store_true\", default=False)\n    parser.add_argument(\"--model_activation\", type=str, default=\"leaky\")\n    parser.add_argument(\"--gcn_dropout\", type=float, default=0.7)\n    parser.add_argument(\"--gcn_agg_norm_symm\", type=bool, default=True)\n    parser.add_argument(\"--gcn_agg_units\", type=int, default=500)\n    parser.add_argument(\"--gcn_agg_accum\", type=str, default=\"sum\")\n    parser.add_argument(\"--gcn_out_units\", type=int, default=75)\n    parser.add_argument(\"--gen_r_num_basis_func\", type=int, default=2)\n    parser.add_argument(\"--train_max_epoch\", type=int, default=1000)\n    parser.add_argument(\"--train_log_interval\", type=int, default=1)\n    parser.add_argument(\"--train_valid_interval\", type=int, default=1)\n    parser.add_argument(\"--train_optimizer\", type=str, default=\"adam\")\n    parser.add_argument(\"--train_grad_clip\", type=float, default=1.0)\n    parser.add_argument(\"--train_lr\", type=float, default=0.01)\n    
parser.add_argument(\"--train_min_lr\", type=float, default=0.0001)\n    parser.add_argument(\"--train_lr_decay_factor\", type=float, default=0.5)\n    parser.add_argument(\"--train_decay_patience\", type=int, default=25)\n    parser.add_argument(\"--train_early_stopping_patience\", type=int, default=50)\n    parser.add_argument(\"--share_param\", default=False, action=\"store_true\")\n    parser.add_argument(\"--mix_cpu_gpu\", default=False, action=\"store_true\")\n    parser.add_argument(\"--minibatch_size\", type=int, default=20000)\n    parser.add_argument(\"--num_workers_per_gpu\", type=int, default=8)\n\n    args = parser.parse_args()\n    ### configure save_fir to save all the info\n    if args.save_dir is None:\n        args.save_dir = (\n            args.data_name\n            + \"_\"\n            + \"\".join(\n                random.choices(string.ascii_uppercase + string.digits, k=2)\n            )\n        )\n    if args.save_id is None:\n        args.save_id = np.random.randint(20)\n    args.save_dir = os.path.join(\"log\", args.save_dir)\n    if not os.path.isdir(args.save_dir):\n        os.makedirs(args.save_dir)\n\n    return args\n\n\ndef run(proc_id, n_gpus, args, devices, dataset):\n    dev_id = devices[proc_id]\n    if n_gpus > 1:\n        dist_init_method = \"tcp://{master_ip}:{master_port}\".format(\n            master_ip=\"127.0.0.1\", master_port=\"12345\"\n        )\n        world_size = n_gpus\n        th.distributed.init_process_group(\n            backend=\"nccl\",\n            init_method=dist_init_method,\n            world_size=world_size,\n            rank=dev_id,\n        )\n    if n_gpus > 0:\n        th.cuda.set_device(dev_id)\n\n    train_labels = dataset.train_labels\n    train_truths = dataset.train_truths\n    num_edges = train_truths.shape[0]\n\n    reverse_types = {\n        to_etype_name(k): \"rev-\" + to_etype_name(k)\n        for k in dataset.possible_rating_values\n    }\n    reverse_types.update({v: k for k, v in 
reverse_types.items()})\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [None], return_eids=True\n    )\n    sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)\n    dataloader = dgl.dataloading.DataLoader(\n        dataset.train_enc_graph,\n        {\n            to_etype_name(k): th.arange(\n                dataset.train_enc_graph.num_edges(etype=to_etype_name(k))\n            )\n            for k in dataset.possible_rating_values\n        },\n        sampler,\n        use_ddp=n_gpus > 1,\n        batch_size=args.minibatch_size,\n        shuffle=True,\n        drop_last=False,\n    )\n\n    if proc_id == 0:\n        valid_dataloader = dgl.dataloading.DataLoader(\n            dataset.valid_dec_graph,\n            th.arange(dataset.valid_dec_graph.num_edges()),\n            sampler,\n            g_sampling=dataset.valid_enc_graph,\n            batch_size=args.minibatch_size,\n            shuffle=False,\n            drop_last=False,\n        )\n        test_dataloader = dgl.dataloading.DataLoader(\n            dataset.test_dec_graph,\n            th.arange(dataset.test_dec_graph.num_edges()),\n            sampler,\n            g_sampling=dataset.test_enc_graph,\n            batch_size=args.minibatch_size,\n            shuffle=False,\n            drop_last=False,\n        )\n\n    nd_possible_rating_values = th.FloatTensor(dataset.possible_rating_values)\n    nd_possible_rating_values = nd_possible_rating_values.to(dev_id)\n\n    net = Net(args=args, dev_id=dev_id)\n    net = net.to(dev_id)\n    if n_gpus > 1:\n        net = DistributedDataParallel(\n            net, device_ids=[dev_id], output_device=dev_id\n        )\n    rating_loss_net = nn.CrossEntropyLoss()\n    learning_rate = args.train_lr\n    optimizer = get_optimizer(args.train_optimizer)(\n        net.parameters(), lr=learning_rate\n    )\n    print(\"Loading network finished ...\\n\")\n\n    ### declare the loss information\n    best_valid_rmse = np.inf\n    no_better_valid = 
0\n    best_epoch = -1\n    count_rmse = 0\n    count_num = 0\n    count_loss = 0\n    print(\"Start training ...\")\n    dur = []\n    iter_idx = 1\n\n    for epoch in range(1, args.train_max_epoch):\n        if n_gpus > 1:\n            dataloader.set_epoch(epoch)\n        if epoch > 1:\n            t0 = time.time()\n        net.train()\n        with tqdm.tqdm(dataloader) as tq:\n            for step, (input_nodes, pair_graph, blocks) in enumerate(tq):\n                head_feat, tail_feat, blocks = load_subtensor(\n                    input_nodes,\n                    pair_graph,\n                    blocks,\n                    dataset,\n                    dataset.train_enc_graph,\n                )\n                frontier = blocks[0]\n                compact_g = flatten_etypes(pair_graph, dataset, \"train\").to(\n                    dev_id\n                )\n                true_relation_labels = compact_g.edata[\"label\"]\n                true_relation_ratings = compact_g.edata[\"rating\"]\n\n                head_feat = head_feat.to(dev_id)\n                tail_feat = tail_feat.to(dev_id)\n                frontier = frontier.to(dev_id)\n\n                pred_ratings = net(\n                    compact_g,\n                    frontier,\n                    head_feat,\n                    tail_feat,\n                    dataset.possible_rating_values,\n                )\n                loss = rating_loss_net(\n                    pred_ratings, true_relation_labels.to(dev_id)\n                ).mean()\n                count_loss += loss.item()\n                optimizer.zero_grad()\n                loss.backward()\n                nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)\n                optimizer.step()\n\n                if proc_id == 0 and iter_idx == 1:\n                    print(\n                        \"Total #Param of net: %d\" % (torch_total_param_num(net))\n                    )\n\n                real_pred_ratings = (\n  
                  th.softmax(pred_ratings, dim=1)\n                    * nd_possible_rating_values.view(1, -1)\n                ).sum(dim=1)\n                rmse = (\n                    (real_pred_ratings - true_relation_ratings.to(dev_id)) ** 2\n                ).sum()\n                count_rmse += rmse.item()\n                count_num += pred_ratings.shape[0]\n\n                tq.set_postfix(\n                    {\n                        \"loss\": \"{:.4f}\".format(count_loss / iter_idx),\n                        \"rmse\": \"{:.4f}\".format(count_rmse / count_num),\n                    },\n                    refresh=False,\n                )\n\n                iter_idx += 1\n\n        if epoch > 1:\n            epoch_time = time.time() - t0\n            print(\"Epoch {} time {}\".format(epoch, epoch_time))\n\n        if epoch % args.train_valid_interval == 0:\n            if n_gpus > 1:\n                th.distributed.barrier()\n            if proc_id == 0:\n                valid_rmse = evaluate(\n                    args=args,\n                    dev_id=dev_id,\n                    net=net,\n                    dataset=dataset,\n                    dataloader=valid_dataloader,\n                    segment=\"valid\",\n                )\n                logging_str = \"Val RMSE={:.4f}\".format(valid_rmse)\n\n                if valid_rmse < best_valid_rmse:\n                    best_valid_rmse = valid_rmse\n                    no_better_valid = 0\n                    best_epoch = epoch\n                    test_rmse = evaluate(\n                        args=args,\n                        dev_id=dev_id,\n                        net=net,\n                        dataset=dataset,\n                        dataloader=test_dataloader,\n                        segment=\"test\",\n                    )\n                    best_test_rmse = test_rmse\n                    logging_str += \", Test RMSE={:.4f}\".format(test_rmse)\n                else:\n                 
   no_better_valid += 1\n                    if (\n                        no_better_valid > args.train_early_stopping_patience\n                        and learning_rate <= args.train_min_lr\n                    ):\n                        logging.info(\n                            \"Early stopping threshold reached. Stop training.\"\n                        )\n                        break\n                    if no_better_valid > args.train_decay_patience:\n                        new_lr = max(\n                            learning_rate * args.train_lr_decay_factor,\n                            args.train_min_lr,\n                        )\n                        if new_lr < learning_rate:\n                            logging.info(\"\\tChange the LR to %g\" % new_lr)\n                            learning_rate = new_lr\n                            for p in optimizer.param_groups:\n                                p[\"lr\"] = learning_rate\n                            no_better_valid = 0\n                            print(\"Change the LR to %g\" % new_lr)\n            # sync on evalution\n            if n_gpus > 1:\n                th.distributed.barrier()\n\n        if proc_id == 0:\n            print(logging_str)\n    if proc_id == 0:\n        print(\n            \"Best epoch Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}\".format(\n                best_epoch, best_valid_rmse, best_test_rmse\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    args = config()\n\n    devices = list(map(int, args.gpu.split(\",\")))\n    n_gpus = len(devices)\n\n    # For GCMC based on sampling, we require node has its own features.\n    # Otherwise (node_id is the feature), the model can not scale\n    dataset = MovieLens(\n        args.data_name,\n        \"cpu\",\n        mix_cpu_gpu=args.mix_cpu_gpu,\n        use_one_hot_fea=args.use_one_hot_fea,\n        symm=args.gcn_agg_norm_symm,\n        test_ratio=args.data_test_ratio,\n        
valid_ratio=args.data_valid_ratio,\n    )\n    print(\"Loading data finished ...\\n\")\n\n    args.src_in_units = dataset.user_feature_shape[1]\n    args.dst_in_units = dataset.movie_feature_shape[1]\n    args.rating_vals = dataset.possible_rating_values\n\n    # cpu\n    if devices[0] == -1:\n        run(0, 0, args, [\"cpu\"], dataset)\n    # gpu\n    elif n_gpus == 1:\n        run(0, n_gpus, args, devices, dataset)\n    # multi gpu\n    else:\n        # Create csr/coo/csc formats before launching training processes with multi-gpu.\n        # This avoids creating certain formats in each sub-process, which saves memory and CPU.\n        dataset.train_enc_graph.create_formats_()\n        dataset.train_dec_graph.create_formats_()\n        mp.spawn(run, args=(n_gpus, args, devices, dataset), nprocs=n_gpus)\n"
  },
  {
    "path": "examples/pytorch/gcmc/utils.py",
    "content": "import csv\nimport re\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.optim as optim\n\n\nclass MetricLogger(object):\n    def __init__(self, attr_names, parse_formats, save_path):\n        self._attr_format_dict = OrderedDict(zip(attr_names, parse_formats))\n        self._file = open(save_path, \"w\")\n        self._csv = csv.writer(self._file)\n        self._csv.writerow(attr_names)\n        self._file.flush()\n\n    def log(self, **kwargs):\n        self._csv.writerow(\n            [\n                parse_format % kwargs[attr_name]\n                for attr_name, parse_format in self._attr_format_dict.items()\n            ]\n        )\n        self._file.flush()\n\n    def close(self):\n        self._file.close()\n\n\ndef torch_total_param_num(net):\n    return sum([np.prod(p.shape) for p in net.parameters()])\n\n\ndef torch_net_info(net, save_path=None):\n    info_str = (\n        \"Total Param Number: {}\\n\".format(torch_total_param_num(net))\n        + \"Params:\\n\"\n    )\n    for k, v in net.named_parameters():\n        info_str += \"\\t{}: {}, {}\\n\".format(k, v.shape, np.prod(v.shape))\n    info_str += str(net)\n    if save_path is not None:\n        with open(save_path, \"w\") as f:\n            f.write(info_str)\n    return info_str\n\n\ndef get_activation(act):\n    \"\"\"Get the activation based on the act string\n\n    Parameters\n    ----------\n    act: str or callable function\n\n    Returns\n    -------\n    ret: callable function\n    \"\"\"\n    if act is None:\n        return lambda x: x\n    if isinstance(act, str):\n        if act == \"leaky\":\n            return nn.LeakyReLU(0.1)\n        elif act == \"relu\":\n            return nn.ReLU()\n        elif act == \"tanh\":\n            return nn.Tanh()\n        elif act == \"sigmoid\":\n            return nn.Sigmoid()\n        elif act == \"softsign\":\n            return nn.Softsign()\n        else:\n     
       raise NotImplementedError\n    else:\n        return act\n\n\ndef get_optimizer(opt):\n    if opt == \"sgd\":\n        return optim.SGD\n    elif opt == \"adam\":\n        return optim.Adam\n    else:\n        raise NotImplementedError\n\n\ndef to_etype_name(rating):\n    return str(rating).replace(\".\", \"_\")\n"
  },
  {
    "path": "examples/pytorch/gcn/README.md",
    "content": "Graph Convolutional Networks (GCN)\n============\n\n- Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)\n- Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn).\n\nHow to run\n-------\n\n### DGL built-in GraphConv module\n\nRun with the following (available dataset: \"cora\", \"citeseer\", \"pubmed\")\n```bash\npython3 train.py --dataset cora\n```\n\nSummary\n-------\n* cora: ~0.810 (paper: 0.815)\n* citeseer: ~0.707 (paper: 0.703)\n* pubmed: ~0.792 (paper: 0.790)\n\n"
  },
  {
    "path": "examples/pytorch/gcn/train.py",
    "content": "import argparse\n\nimport dgl\nimport dgl.nn as dglnn\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import AddSelfLoop\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\n\n\nclass GCN(nn.Module):\n    def __init__(self, in_size, hid_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # two-layer GCN\n        self.layers.append(\n            dglnn.GraphConv(in_size, hid_size, activation=F.relu)\n        )\n        self.layers.append(dglnn.GraphConv(hid_size, out_size))\n        self.dropout = nn.Dropout(0.5)\n\n    def forward(self, g, features):\n        h = features\n        for i, layer in enumerate(self.layers):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(g, h)\n        return h\n\n\ndef evaluate(g, features, labels, mask, model):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef train(g, features, labels, masks, model):\n    # define train/val samples, loss function and optimizer\n    train_mask = masks[0]\n    val_mask = masks[1]\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n\n    # training loop\n    for epoch in range(200):\n        model.train()\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        acc = evaluate(g, features, labels, val_mask, model)\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} \".format(\n                epoch, loss.item(), acc\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    
parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"cora\",\n        help=\"Dataset name ('cora', 'citeseer', 'pubmed').\",\n    )\n    parser.add_argument(\n        \"--dt\",\n        type=str,\n        default=\"float\",\n        help=\"data type(float, bfloat16)\",\n    )\n    args = parser.parse_args()\n    print(f\"Training with DGL built-in GraphConv module.\")\n\n    # load and preprocess dataset\n    transform = (\n        AddSelfLoop()\n    )  # by default, it will first remove self-loops to prevent duplication\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset(transform=transform)\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset(transform=transform)\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset(transform=transform)\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n    g = data[0]\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    g = g.int().to(device)\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    masks = g.ndata[\"train_mask\"], g.ndata[\"val_mask\"], g.ndata[\"test_mask\"]\n\n    # create GCN model\n    in_size = features.shape[1]\n    out_size = data.num_classes\n    model = GCN(in_size, 16, out_size).to(device)\n\n    # convert model and graph to bfloat16 if needed\n    if args.dt == \"bfloat16\":\n        g = dgl.to_bfloat16(g)\n        features = features.to(dtype=torch.bfloat16)\n        model = model.to(dtype=torch.bfloat16)\n\n    # model training\n    print(\"Training...\")\n    train(g, features, labels, masks, model)\n\n    # test the model\n    print(\"Testing...\")\n    acc = evaluate(g, features, labels, masks[2], model)\n    print(\"Test accuracy {:.4f}\".format(acc))\n"
  },
  {
    "path": "examples/pytorch/geniepath/README.md",
    "content": "# DGL Implementation of the GeniePath Paper\n\nThis DGL example implements the GNN model proposed in the paper [GeniePath: Graph Neural Networks with Adaptive Receptive Paths](https://arxiv.org/abs/1802.00910).\n\nExample implementor\n----------------------\nThis example was implemented by [Kay Liu](https://github.com/kayzliu) during his SDE intern work at the AWS Shanghai AI Lab.\n\nDependencies\n----------------------\n- Python 3.7.10\n- PyTorch 1.8.1\n- dgl 0.7.0\n- scikit-learn 0.23.2\n\nDataset\n---------------------------------------\nThe datasets used for node classification are [Pubmed citation network dataset](https://docs.dgl.ai/api/python/dgl.data.html#dgl.data.PubmedGraphDataset) (tranductive) and [Protein-Protein Interaction dataset](https://docs.dgl.ai/api/python/dgl.data.html#dgl.data.PPIDataset) (inductive).\n\nHow to run\n--------------------------------\nIf want to train on Pubmed (transductive), run\n```\npython pubmed.py\n```\n\nIf want to use a GPU, run\n```\npython pubmed.py --gpu 0\n```\n\nIf want to train GeniePath-Lazy, run\n```\npython pubmed.py --lazy True\n```\n\nIf want to train on PPI (inductive), run\n```\npython ppi.py\n```\n\nPerformance\n-------------------------\nDataset: Pubmed (ACC)\n|Method | GeniePath|\n| ------ | ----------- |\n| Paper  | 78.5%       |\n| DGL    | 73.0%       |\n\nDataset: PPI (micro-F1)\n|Method | GeniePath| GeniePath-lazy| GeniePath-lazy-residual|\n| ------ | ----------- | ------------- | ------------------ |\n| Paper  | 0.9520      | 0.9790        | 0.9850        |\n| DGL    | 0.9729      | 0.9802        | 0.9798        |\n"
  },
  {
    "path": "examples/pytorch/geniepath/model.py",
    "content": "import torch as th\nimport torch.nn as nn\n\nfrom dgl.nn import GATConv\nfrom torch.nn import LSTM\n\n\nclass GeniePathConv(nn.Module):\n    def __init__(self, in_dim, hid_dim, out_dim, num_heads=1, residual=False):\n        super(GeniePathConv, self).__init__()\n        self.breadth_func = GATConv(\n            in_dim, hid_dim, num_heads=num_heads, residual=residual\n        )\n        self.depth_func = LSTM(hid_dim, out_dim)\n\n    def forward(self, graph, x, h, c):\n        x = self.breadth_func(graph, x)\n        x = th.tanh(x)\n        x = th.mean(x, dim=1)\n        x, (h, c) = self.depth_func(x.unsqueeze(0), (h, c))\n        x = x[0]\n        return x, (h, c)\n\n\nclass GeniePath(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        hid_dim=16,\n        num_layers=2,\n        num_heads=1,\n        residual=False,\n    ):\n        super(GeniePath, self).__init__()\n        self.hid_dim = hid_dim\n        self.linear1 = nn.Linear(in_dim, hid_dim)\n        self.linear2 = nn.Linear(hid_dim, out_dim)\n        self.layers = nn.ModuleList()\n        for i in range(num_layers):\n            self.layers.append(\n                GeniePathConv(\n                    hid_dim,\n                    hid_dim,\n                    hid_dim,\n                    num_heads=num_heads,\n                    residual=residual,\n                )\n            )\n\n    def forward(self, graph, x):\n        h = th.zeros(1, x.shape[0], self.hid_dim).to(x.device)\n        c = th.zeros(1, x.shape[0], self.hid_dim).to(x.device)\n\n        x = self.linear1(x)\n        for layer in self.layers:\n            x, (h, c) = layer(graph, x, h, c)\n        x = self.linear2(x)\n\n        return x\n\n\nclass GeniePathLazy(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        hid_dim=16,\n        num_layers=2,\n        num_heads=1,\n        residual=False,\n    ):\n        super(GeniePathLazy, self).__init__()\n     
   self.hid_dim = hid_dim\n        self.linear1 = nn.Linear(in_dim, hid_dim)\n        self.linear2 = th.nn.Linear(hid_dim, out_dim)\n        self.breaths = nn.ModuleList()\n        self.depths = nn.ModuleList()\n        for i in range(num_layers):\n            self.breaths.append(\n                GATConv(\n                    hid_dim, hid_dim, num_heads=num_heads, residual=residual\n                )\n            )\n            self.depths.append(LSTM(hid_dim * 2, hid_dim))\n\n    def forward(self, graph, x):\n        h = th.zeros(1, x.shape[0], self.hid_dim).to(x.device)\n        c = th.zeros(1, x.shape[0], self.hid_dim).to(x.device)\n\n        x = self.linear1(x)\n        h_tmps = []\n        for layer in self.breaths:\n            h_tmps.append(th.mean(th.tanh(layer(graph, x)), dim=1))\n        x = x.unsqueeze(0)\n        for h_tmp, layer in zip(h_tmps, self.depths):\n            in_cat = th.cat((h_tmp.unsqueeze(0), x), -1)\n            x, (h, c) = layer(in_cat, (h, c))\n        x = self.linear2(x[0])\n\n        return x\n"
  },
  {
    "path": "examples/pytorch/geniepath/ppi.py",
    "content": "import argparse\n\nimport numpy as np\nimport torch as th\nimport torch.optim as optim\n\nfrom dgl.data import PPIDataset\nfrom dgl.dataloading import GraphDataLoader\nfrom model import GeniePath, GeniePathLazy\nfrom sklearn.metrics import f1_score\n\n\ndef evaluate(model, loss_fn, dataloader, device=\"cpu\"):\n    loss = 0\n    f1 = 0\n    num_blocks = 0\n    for subgraph in dataloader:\n        subgraph = subgraph.to(device)\n        label = subgraph.ndata[\"label\"].to(device)\n        feat = subgraph.ndata[\"feat\"]\n        logits = model(subgraph, feat)\n\n        # compute loss\n        loss += loss_fn(logits, label).item()\n        predict = np.where(logits.data.cpu().numpy() >= 0.0, 1, 0)\n        f1 += f1_score(label.cpu(), predict, average=\"micro\")\n        num_blocks += 1\n\n    return f1 / num_blocks, loss / num_blocks\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # Load dataset\n    train_dataset = PPIDataset(mode=\"train\")\n    valid_dataset = PPIDataset(mode=\"valid\")\n    test_dataset = PPIDataset(mode=\"test\")\n    train_dataloader = GraphDataLoader(\n        train_dataset, batch_size=args.batch_size\n    )\n    valid_dataloader = GraphDataLoader(\n        valid_dataset, batch_size=args.batch_size\n    )\n    test_dataloader = GraphDataLoader(test_dataset, batch_size=args.batch_size)\n\n    # check cuda\n    if args.gpu >= 0 and th.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n    else:\n        device = \"cpu\"\n\n    num_classes = train_dataset.num_classes\n\n    # Extract node features\n    graph = train_dataset[0]\n    feat = graph.ndata[\"feat\"]\n\n    # Step 2: Create model =================================================================== #\n    if args.lazy:\n        model = GeniePathLazy(\n            in_dim=feat.shape[-1],\n            out_dim=num_classes,\n            hid_dim=args.hid_dim,\n            
num_layers=args.num_layers,\n            num_heads=args.num_heads,\n            residual=args.residual,\n        )\n    else:\n        model = GeniePath(\n            in_dim=feat.shape[-1],\n            out_dim=num_classes,\n            hid_dim=args.hid_dim,\n            num_layers=args.num_layers,\n            num_heads=args.num_heads,\n            residual=args.residual,\n        )\n\n    model = model.to(device)\n\n    # Step 3: Create training components ===================================================== #\n    loss_fn = th.nn.BCEWithLogitsLoss()\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n    # Step 4: training epochs =============================================================== #\n    for epoch in range(args.max_epoch):\n        model.train()\n        tr_loss = 0\n        tr_f1 = 0\n        num_blocks = 0\n        for subgraph in train_dataloader:\n            subgraph = subgraph.to(device)\n            label = subgraph.ndata[\"label\"]\n            feat = subgraph.ndata[\"feat\"]\n            logits = model(subgraph, feat)\n\n            # compute loss\n            batch_loss = loss_fn(logits, label)\n            tr_loss += batch_loss.item()\n            tr_predict = np.where(logits.data.cpu().numpy() >= 0.0, 1, 0)\n            tr_f1 += f1_score(label.cpu(), tr_predict, average=\"micro\")\n            num_blocks += 1\n\n            # backward\n            optimizer.zero_grad()\n            batch_loss.backward()\n            optimizer.step()\n\n        # validation\n        model.eval()\n        val_f1, val_loss = evaluate(model, loss_fn, valid_dataloader, device)\n\n        print(\n            \"In epoch {}, Train F1: {:.4f} | Train Loss: {:.4f}; Valid F1: {:.4f} | Valid loss: {:.4f}\".format(\n                epoch,\n                tr_f1 / num_blocks,\n                tr_loss / num_blocks,\n                val_f1,\n                val_loss,\n            )\n        )\n\n    # Test after all epoch\n    model.eval()\n    test_f1, 
test_loss = evaluate(model, loss_fn, test_dataloader, device)\n\n    print(\"Test F1: {:.4f} | Test loss: {:.4f}\".format(test_f1, test_loss))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GeniePath\")\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU Index. Default: -1, using CPU.\"\n    )\n    parser.add_argument(\n        \"--hid_dim\", type=int, default=256, help=\"Hidden layer dimension\"\n    )\n    parser.add_argument(\n        \"--num_layers\", type=int, default=3, help=\"Number of GeniePath layers\"\n    )\n    parser.add_argument(\n        \"--max_epoch\",\n        type=int,\n        default=1000,\n        help=\"The max number of epochs. Default: 1000\",\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=0.0004,\n        help=\"Learning rate. Default: 0.0004\",\n    )\n    parser.add_argument(\n        \"--num_heads\",\n        type=int,\n        default=1,\n        help=\"Number of head in breadth function. Default: 1\",\n    )\n    parser.add_argument(\n        \"--residual\", type=bool, default=False, help=\"Residual in GAT or not\"\n    )\n    parser.add_argument(\n        \"--batch_size\",\n        type=int,\n        default=2,\n        help=\"Batch size of graph dataloader\",\n    )\n    parser.add_argument(\n        \"--lazy\", type=bool, default=False, help=\"Variant GeniePath-Lazy\"\n    )\n\n    args = parser.parse_args()\n    print(args)\n    th.manual_seed(16)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/geniepath/pubmed.py",
    "content": "import argparse\n\nimport torch as th\nimport torch.optim as optim\n\nfrom dgl.data import PubmedGraphDataset\nfrom model import GeniePath, GeniePathLazy\nfrom sklearn.metrics import accuracy_score\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # Load dataset\n    dataset = PubmedGraphDataset()\n    graph = dataset[0]\n\n    # check cuda\n    if args.gpu >= 0 and th.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n    else:\n        device = \"cpu\"\n\n    num_classes = dataset.num_classes\n\n    # retrieve label of ground truth\n    label = graph.ndata[\"label\"].to(device)\n\n    # Extract node features\n    feat = graph.ndata[\"feat\"].to(device)\n\n    # retrieve masks for train/validation/test\n    train_mask = graph.ndata[\"train_mask\"]\n    val_mask = graph.ndata[\"val_mask\"]\n    test_mask = graph.ndata[\"test_mask\"]\n\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)\n    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)\n\n    graph = graph.to(device)\n\n    # Step 2: Create model =================================================================== #\n    if args.lazy:\n        model = GeniePathLazy(\n            in_dim=feat.shape[-1],\n            out_dim=num_classes,\n            hid_dim=args.hid_dim,\n            num_layers=args.num_layers,\n            num_heads=args.num_heads,\n            residual=args.residual,\n        )\n    else:\n        model = GeniePath(\n            in_dim=feat.shape[-1],\n            out_dim=num_classes,\n            hid_dim=args.hid_dim,\n            num_layers=args.num_layers,\n            num_heads=args.num_heads,\n            residual=args.residual,\n        )\n\n    model = model.to(device)\n\n    # Step 3: Create training components 
===================================================== #\n    loss_fn = th.nn.CrossEntropyLoss()\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n    # Step 4: training epochs =============================================================== #\n    for epoch in range(args.max_epoch):\n        # Training and validation\n        model.train()\n        logits = model(graph, feat)\n\n        # compute loss\n        tr_loss = loss_fn(logits[train_idx], label[train_idx])\n        tr_acc = accuracy_score(\n            label[train_idx].cpu(), logits[train_idx].argmax(dim=1).cpu()\n        )\n\n        # validation\n        valid_loss = loss_fn(logits[val_idx], label[val_idx])\n        valid_acc = accuracy_score(\n            label[val_idx].cpu(), logits[val_idx].argmax(dim=1).cpu()\n        )\n\n        # backward\n        optimizer.zero_grad()\n        tr_loss.backward()\n        optimizer.step()\n\n        # Print out performance\n        print(\n            \"In epoch {}, Train ACC: {:.4f} | Train Loss: {:.4f}; Valid ACC: {:.4f} | Valid loss: {:.4f}\".format(\n                epoch, tr_acc, tr_loss.item(), valid_acc, valid_loss.item()\n            )\n        )\n\n    # Test after all epoch\n    model.eval()\n\n    # forward\n    logits = model(graph, feat)\n\n    # compute loss\n    test_loss = loss_fn(logits[test_idx], label[test_idx])\n    test_acc = accuracy_score(\n        label[test_idx].cpu(), logits[test_idx].argmax(dim=1).cpu()\n    )\n\n    print(\n        \"Test ACC: {:.4f} | Test loss: {:.4f}\".format(\n            test_acc, test_loss.item()\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GeniePath\")\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU Index. 
Default: -1, using CPU.\"\n    )\n    parser.add_argument(\n        \"--hid_dim\", type=int, default=16, help=\"Hidden layer dimension\"\n    )\n    parser.add_argument(\n        \"--num_layers\", type=int, default=2, help=\"Number of GeniePath layers\"\n    )\n    parser.add_argument(\n        \"--max_epoch\",\n        type=int,\n        default=300,\n        help=\"The max number of epochs. Default: 300\",\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=0.0004,\n        help=\"Learning rate. Default: 0.0004\",\n    )\n    parser.add_argument(\n        \"--num_heads\",\n        type=int,\n        default=1,\n        help=\"Number of head in breadth function. Default: 1\",\n    )\n    parser.add_argument(\n        \"--residual\", type=bool, default=False, help=\"Residual in GAT or not\"\n    )\n    parser.add_argument(\n        \"--lazy\", type=bool, default=False, help=\"Variant GeniePath-Lazy\"\n    )\n\n    args = parser.parse_args()\n    th.manual_seed(16)\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/ggnn/README.md",
    "content": "# Gated Graph Neural Network (GGNN)\n\n- Paper link: https://arxiv.org/pdf/1511.05493.pdf\n\n## Dependencies\n- PyTorch 1.0+\n\n- DGL 0.3.1+\n\n## GGNN implemented in dgl\n\nIn dgl, GGNN is implemented as module `GatedGraphConv`, it can be imported as follows:\n\n```python\nfrom dgl.nn.pytorch import GatedGraphConv\n```\n\n## Solving bAbI tasks\n\nIn this example, we use GGNN to solve some of the [bAbI](https://github.com/facebook/bAbI-tasks) \ntasks solved in the paper.\n\n#### Overview of bAbI tasks\n\nbAbI is a set of question answering tasks that require a system to do multi-step reasoning.\nDatasets of bAbI tasks are generated by templates, which can be natural language or symbolic\nform. In this example, we follow the paper to generate the datasets using symbolic form.\nThere are 20 tasks in bAbI, in this example, we follow the paper to do task 4, 15, 16, 18 and 19.\n\n#### Task 4: Two argument relations: subject vs. object\n\nAn example of task 4 is as follows\n```\n1 C e A\n2 A e B\n3 eval A w\tC\n```\n\nA, B, C are nodes; e, w are edges, there are totally four kinds of edges: `n, s, w, e`, which can \nbe viewed as north, south, west, east.\n\nThe first two lines are conditions, and the third line are the question and answer. \nSo the explanation of the example is:\n```\n1 Go east from C, we can reach A\n2 Go east from A, we can reach B\n3 Question: where can we reach if we go west from A? Answer: C\n```\n\nIf we represent the conditions using a graph, we can view this task as a `Node Selection` task.\nFor different edges in questions, we view them as different question types, we train\n separate models for each question type. The module for solving node selection tasks is\n implemented in `ggnn_ns.py`.\n \nFor four question types `n, s, w, e`, we assign a question id for them ranging from 0 to 3. 
\nFor each question id, run the following commands for training and testing:\n\n```bash\npython train_ns.py --task_id=4 --question_id=0 --train_num=50 --epochs=10\npython train_ns.py --task_id=4 --question_id=1 --train_num=50 --epochs=10\npython train_ns.py --task_id=4 --question_id=2 --train_num=50 --epochs=10\npython train_ns.py --task_id=4 --question_id=3 --train_num=50 --epochs=10\n```\n\nThe training file name `train_ns` means training node selection. `train_num` means the number of \ntraining examples used.\n\n#### Task 15: Basic deduction\n\nTask 15 is similar to task 4, it's also a Node Selection task. An example is shown below:\n\n```\n1 I has_fear C\n2 H is C\n3 G is I\n4 A is B\n5 E has_fear C\n6 C has_fear I\n7 B has_fear C\n8 F is E\n9 eval H has_fear\tI\n```\n\nThere are two types of edges in this task: `is, has_fear`. There is only one question type in\nthis task: `has_fear`, we assign question id `1` for it.\n\nRun the following command for training and testing:\n```bash\npython train_ns.py --task_id=15 --question_id=1 --train_num=50 --epochs=15 --lr=1e-2\n```\n\n#### Task 16: Basic induction\n\nTask 16 is similar to task 15. An example of task 16 is shown below\n\n```\n1 J has_color F\n2 K has_color I\n3 A has_color I\n4 G is D\n5 J is C\n6 H has_color I\n7 H is D\n8 A is D\n9 K is D\n10 eval G has_color\tI\n```\n\nThere are two types of edges in this task: `is, has_color`. 
There is only one question type in\nthis task: `has_color`, we assign question id `1` for it.\n\nRun the following command for training and testing:\n\n```bash\npython train_ns.py --task_id=16 --question_id=1 --train_num=50 --epochs=20 --lr=1e-2\n```\n\n#### Task 18: Reasoning about size\n\nTask 18 is a `Graph Classification` task, an example is shown below:\n\n```\n1 G > B\n2 G > D\n3 E > F\n4 E > A\n5 B > A\n6 E > B\n7 eval G < A\tfalse\n```\n\nLine 1 to line 6 give some conditions for comparison of the size of entities, line 7 is the\nquestion, asking whether `G < A` is `true` or `false`. So the input is a graph, the output is a\nbinary classification result. We view it as a `Graph Classification` task.\n\nFollowing the paper, we use GGNN to encode the graph, followed by a `GlobalAttentionPooling` \nlayer to pool the graph into a hidden vector, which is used to classify the graph.\n\nThe module for solving graph classification tasks is implemented in `ggnn_gc.py`.\n\nThere are two types of edges in this task: `>, <`, and so are the question types. We assign \nquestion ids `0, 1` to them.\n\nRun the following commands for training and testing:\n```bash\npython train_gc.py --task_id=18 --question_id=0 --train_num=50 --batch_size=10 --lr=1e-3 --epochs=20\npython train_gc.py --task_id=18 --question_id=1 --train_num=50 --batch_size=10 --lr=1e-3 --epochs=20\n```\n\n#### Task 19: Path finding\n\nAn example of task 19 is as follows:\n```\n1 D n A\n2 D s E\n3 G w D\n4 E s B\n5 eval path G A\tw,n\n```\n\nSimilar to task 4, there are four types of edges: `n, s, w, e`, which can \nbe viewed as north, south, west, east. The conditions are the same as task 4, the question in \nline 5 means `Question: find a path from G to A. Answer: first go west, then go north`. The \noutput is a sequence of edges. 
So there is no question type in this task.\n\nThe paper uses *Gated Graph Sequence Neural Networks (GGS-NNs)* to solve this kind of problems.\nIn this example, we implemented GGS-NNs in `ggsnn.py`, run the following command for training \nand testing:\n```bash\npython train_path_finding.py --train_num=250 --epochs=200\n```\n\n#### Results\n\nFollowing the paper, we use 10 different test sets for evaluation. The result is the mean and\nstandard deviation of the evaluation performance across the 10 datasets. Numbers in the parentheses\nare the number of training data used.\n\n|  Task ID  |    Reported <br> Accuracy   |      DGL <br> Accuracy       |\n|:---------:|-----------------------------|------------------------------|\n|  4        | 100.0 ± 0.00 (50)           | 100.0 ± 0.00 (50)|\n|  15       | 100.0 ± 0.00 (50)           | 100.0 ± 0.00 (50)|\n|  16       | 100.0 ± 0.00 (50)           | 100.0 ± 0.00 (50)|\n|  18       | 100.0 ± 0.00 (50)           | 100.0 ± 0.00 (50)|\n|  19       | 99.0 ± 1.1 (250)            | 97.8 ± 0.02 (50) |"
  },
  {
    "path": "examples/pytorch/ggnn/data_utils.py",
    "content": "\"\"\"\nData utils for processing bAbI datasets\n\"\"\"\n\nimport os\nimport string\n\nimport dgl\n\nimport torch\nfrom dgl.data.utils import (\n    _get_dgl_url,\n    download,\n    extract_archive,\n    get_download_dir,\n)\nfrom torch.utils.data import DataLoader\n\n\ndef get_babi_dataloaders(batch_size, train_size=50, task_id=4, q_type=0):\n    _download_babi_data()\n\n    node_dict = dict(\n        zip(list(string.ascii_uppercase), range(len(string.ascii_uppercase)))\n    )\n\n    if task_id == 4:\n        edge_dict = {\"n\": 0, \"s\": 1, \"w\": 2, \"e\": 3}\n        reverse_edge = {}\n        return _ns_dataloader(\n            train_size,\n            q_type,\n            batch_size,\n            node_dict,\n            edge_dict,\n            reverse_edge,\n            \"04\",\n        )\n    elif task_id == 15:\n        edge_dict = {\"is\": 0, \"has_fear\": 1}\n        reverse_edge = {}\n        return _ns_dataloader(\n            train_size,\n            q_type,\n            batch_size,\n            node_dict,\n            edge_dict,\n            reverse_edge,\n            \"15\",\n        )\n    elif task_id == 16:\n        edge_dict = {\"is\": 0, \"has_color\": 1}\n        reverse_edge = {0: 0}\n        return _ns_dataloader(\n            train_size,\n            q_type,\n            batch_size,\n            node_dict,\n            edge_dict,\n            reverse_edge,\n            \"16\",\n        )\n    elif task_id == 18:\n        edge_dict = {\">\": 0, \"<\": 1}\n        label_dict = {\"false\": 0, \"true\": 1}\n        reverse_edge = {0: 1, 1: 0}\n        return _gc_dataloader(\n            train_size,\n            q_type,\n            batch_size,\n            node_dict,\n            edge_dict,\n            label_dict,\n            reverse_edge,\n            \"18\",\n        )\n    elif task_id == 19:\n        edge_dict = {\"n\": 0, \"s\": 1, \"w\": 2, \"e\": 3, \"<end>\": 4}\n        reverse_edge = {0: 1, 1: 0, 2: 3, 3: 2}\n        
max_seq_length = 2\n        return _path_finding_dataloader(\n            train_size,\n            batch_size,\n            node_dict,\n            edge_dict,\n            reverse_edge,\n            \"19\",\n            max_seq_length,\n        )\n\n\ndef _ns_dataloader(\n    train_size, q_type, batch_size, node_dict, edge_dict, reverse_edge, path\n):\n    def _collate_fn(batch):\n        graphs = []\n        labels = []\n        for d in batch:\n            edges = d[\"edges\"]\n\n            node_ids = []\n            for s, e, t in edges:\n                if s not in node_ids:\n                    node_ids.append(s)\n                if t not in node_ids:\n                    node_ids.append(t)\n            g = dgl.graph([])\n            g.add_nodes(len(node_ids))\n            g.ndata[\"node_id\"] = torch.tensor(node_ids, dtype=torch.long)\n\n            nid2idx = dict(zip(node_ids, list(range(len(node_ids)))))\n\n            # convert label to node index\n            label = d[\"eval\"][2]\n            label_idx = nid2idx[label]\n            labels.append(label_idx)\n\n            edge_types = []\n            for s, e, t in edges:\n                g.add_edges(nid2idx[s], nid2idx[t])\n                edge_types.append(e)\n                if e in reverse_edge:\n                    g.add_edges(nid2idx[t], nid2idx[s])\n                    edge_types.append(reverse_edge[e])\n            g.edata[\"type\"] = torch.tensor(edge_types, dtype=torch.long)\n            annotation = torch.zeros(len(node_ids), dtype=torch.long)\n            annotation[nid2idx[d[\"eval\"][0]]] = 1\n            g.ndata[\"annotation\"] = annotation.unsqueeze(-1)\n            graphs.append(g)\n        batch_graph = dgl.batch(graphs)\n        labels = torch.tensor(labels, dtype=torch.long)\n        return batch_graph, labels\n\n    def _get_dataloader(data, shuffle):\n        return DataLoader(\n            dataset=data,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            
collate_fn=_collate_fn,\n        )\n\n    train_set, dev_set, test_sets = _convert_ns_dataset(\n        train_size, node_dict, edge_dict, path, q_type\n    )\n    train_dataloader = _get_dataloader(train_set, True)\n    dev_dataloader = _get_dataloader(dev_set, False)\n    test_dataloaders = []\n    for d in test_sets:\n        dl = _get_dataloader(d, False)\n        test_dataloaders.append(dl)\n\n    return train_dataloader, dev_dataloader, test_dataloaders\n\n\ndef _convert_ns_dataset(train_size, node_dict, edge_dict, path, q_type):\n    total_num = 11000\n\n    def convert(file):\n        dataset = []\n        d = dict()\n        with open(file, \"r\") as f:\n            for i, line in enumerate(f.readlines()):\n                line = line.strip().split()\n                if line[0] == \"1\" and len(d) > 0:\n                    d = dict()\n                if line[1] == \"eval\":\n                    # (src, edge, label)\n                    d[\"eval\"] = (\n                        node_dict[line[2]],\n                        edge_dict[line[3]],\n                        node_dict[line[4]],\n                    )\n                    if d[\"eval\"][1] == q_type:\n                        dataset.append(d)\n                        if len(dataset) >= total_num:\n                            break\n                else:\n                    if \"edges\" not in d:\n                        d[\"edges\"] = []\n                    d[\"edges\"].append(\n                        (\n                            node_dict[line[1]],\n                            edge_dict[line[2]],\n                            node_dict[line[3]],\n                        )\n                    )\n        return dataset\n\n    download_dir = get_download_dir()\n    filename = os.path.join(download_dir, \"babi_data\", path, \"data.txt\")\n    data = convert(filename)\n\n    assert len(data) == total_num\n\n    train_set = data[:train_size]\n    dev_set = data[950:1000]\n    test_sets = []\n    for i 
in range(10):\n        test = data[1000 * (i + 1) : 1000 * (i + 2)]\n        test_sets.append(test)\n\n    return train_set, dev_set, test_sets\n\n\ndef _gc_dataloader(\n    train_size,\n    q_type,\n    batch_size,\n    node_dict,\n    edge_dict,\n    label_dict,\n    reverse_edge,\n    path,\n):\n    def _collate_fn(batch):\n        graphs = []\n        labels = []\n        for d in batch:\n            edges = d[\"edges\"]\n\n            node_ids = []\n            for s, e, t in edges:\n                if s not in node_ids:\n                    node_ids.append(s)\n                if t not in node_ids:\n                    node_ids.append(t)\n            g = dgl.graph([])\n            g.add_nodes(len(node_ids))\n            g.ndata[\"node_id\"] = torch.tensor(node_ids, dtype=torch.long)\n\n            nid2idx = dict(zip(node_ids, list(range(len(node_ids)))))\n\n            labels.append(d[\"eval\"][-1])\n\n            edge_types = []\n            for s, e, t in edges:\n                g.add_edges(nid2idx[s], nid2idx[t])\n                edge_types.append(e)\n                if e in reverse_edge:\n                    g.add_edges(nid2idx[t], nid2idx[s])\n                    edge_types.append(reverse_edge[e])\n            g.edata[\"type\"] = torch.tensor(edge_types, dtype=torch.long)\n            annotation = torch.zeros([len(node_ids), 2], dtype=torch.long)\n            annotation[nid2idx[d[\"eval\"][0]]][0] = 1\n            annotation[nid2idx[d[\"eval\"][2]]][1] = 1\n            g.ndata[\"annotation\"] = annotation\n            graphs.append(g)\n        batch_graph = dgl.batch(graphs)\n        labels = torch.tensor(labels, dtype=torch.long)\n        return batch_graph, labels\n\n    def _get_dataloader(data, shuffle):\n        return DataLoader(\n            dataset=data,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            collate_fn=_collate_fn,\n        )\n\n    train_set, dev_set, test_sets = _convert_gc_dataset(\n        train_size, 
node_dict, edge_dict, label_dict, path, q_type\n    )\n    train_dataloader = _get_dataloader(train_set, True)\n    dev_dataloader = _get_dataloader(dev_set, False)\n    test_dataloaders = []\n    for d in test_sets:\n        dl = _get_dataloader(d, False)\n        test_dataloaders.append(dl)\n\n    return train_dataloader, dev_dataloader, test_dataloaders\n\n\ndef _convert_gc_dataset(\n    train_size, node_dict, edge_dict, label_dict, path, q_type\n):\n    total_num = 11000\n\n    def convert(file):\n        dataset = []\n        d = dict()\n        with open(file, \"r\") as f:\n            for i, line in enumerate(f.readlines()):\n                line = line.strip().split()\n                if line[0] == \"1\" and len(d) > 0:\n                    d = dict()\n                if line[1] == \"eval\":\n                    # (src, edge, label)\n                    if \"eval\" not in d:\n                        d[\"eval\"] = (\n                            node_dict[line[2]],\n                            edge_dict[line[3]],\n                            node_dict[line[4]],\n                            label_dict[line[5]],\n                        )\n                        if d[\"eval\"][1] == q_type:\n                            dataset.append(d)\n                            if len(dataset) >= total_num:\n                                break\n                else:\n                    if \"edges\" not in d:\n                        d[\"edges\"] = []\n                    d[\"edges\"].append(\n                        (\n                            node_dict[line[1]],\n                            edge_dict[line[2]],\n                            node_dict[line[3]],\n                        )\n                    )\n        return dataset\n\n    download_dir = get_download_dir()\n    filename = os.path.join(download_dir, \"babi_data\", path, \"data.txt\")\n    data = convert(filename)\n\n    assert len(data) == total_num\n\n    train_set = data[:train_size]\n    dev_set = 
data[950:1000]\n    test_sets = []\n    for i in range(10):\n        test = data[1000 * (i + 1) : 1000 * (i + 2)]\n        test_sets.append(test)\n\n    return train_set, dev_set, test_sets\n\n\ndef _path_finding_dataloader(\n    train_size,\n    batch_size,\n    node_dict,\n    edge_dict,\n    reverse_edge,\n    path,\n    max_seq_length,\n):\n    def _collate_fn(batch):\n        graphs = []\n        ground_truths = []\n        seq_lengths = []\n        for d in batch:\n            edges = d[\"edges\"]\n\n            node_ids = []\n            for s, e, t in edges:\n                if s not in node_ids:\n                    node_ids.append(s)\n                if t not in node_ids:\n                    node_ids.append(t)\n            g = dgl.graph([])\n            g.add_nodes(len(node_ids))\n            g.ndata[\"node_id\"] = torch.tensor(node_ids, dtype=torch.long)\n\n            nid2idx = dict(zip(node_ids, list(range(len(node_ids)))))\n\n            truth = d[\"seq_out\"] + [edge_dict[\"<end>\"]] * (\n                max_seq_length - len(d[\"seq_out\"])\n            )\n            seq_len = len(d[\"seq_out\"])\n            ground_truths.append(truth)\n            seq_lengths.append(seq_len)\n\n            edge_types = []\n            for s, e, t in edges:\n                g.add_edges(nid2idx[s], nid2idx[t])\n                edge_types.append(e)\n                if e in reverse_edge:\n                    g.add_edges(nid2idx[t], nid2idx[s])\n                    edge_types.append(reverse_edge[e])\n            g.edata[\"type\"] = torch.tensor(edge_types, dtype=torch.long)\n            annotation = torch.zeros([len(node_ids), 2], dtype=torch.long)\n            annotation[nid2idx[d[\"eval\"][0]]][0] = 1\n            annotation[nid2idx[d[\"eval\"][1]]][1] = 1\n            g.ndata[\"annotation\"] = annotation\n            graphs.append(g)\n        batch_graph = dgl.batch(graphs)\n        ground_truths = torch.tensor(ground_truths, dtype=torch.long)\n        seq_lengths 
= torch.tensor(seq_lengths, dtype=torch.long)\n        return batch_graph, ground_truths, seq_lengths\n\n    def _get_dataloader(data, shuffle):\n        return DataLoader(\n            dataset=data,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            collate_fn=_collate_fn,\n        )\n\n    train_set, dev_set, test_sets = _convert_path_finding(\n        train_size, node_dict, edge_dict, path\n    )\n    train_dataloader = _get_dataloader(train_set, True)\n    dev_dataloader = _get_dataloader(dev_set, False)\n    test_dataloaders = []\n    for d in test_sets:\n        dl = _get_dataloader(d, False)\n        test_dataloaders.append(dl)\n\n    return train_dataloader, dev_dataloader, test_dataloaders\n\n\ndef _convert_path_finding(train_size, node_dict, edge_dict, path):\n    total_num = 11000\n\n    def convert(file):\n        dataset = []\n        d = dict()\n        with open(file, \"r\") as f:\n            for line in f.readlines():\n                line = line.strip().split()\n                if line[0] == \"1\" and len(d) > 0:\n                    d = dict()\n                if line[1] == \"eval\":\n                    # (src, edge, label)\n                    d[\"eval\"] = (node_dict[line[3]], node_dict[line[4]])\n                    d[\"seq_out\"] = []\n                    seq_out = line[5].split(\",\")\n                    for e in seq_out:\n                        d[\"seq_out\"].append(edge_dict[e])\n                    dataset.append(d)\n                    if len(dataset) >= total_num:\n                        break\n                else:\n                    if \"edges\" not in d:\n                        d[\"edges\"] = []\n                    d[\"edges\"].append(\n                        (\n                            node_dict[line[1]],\n                            edge_dict[line[2]],\n                            node_dict[line[3]],\n                        )\n                    )\n        return dataset\n\n    download_dir 
= get_download_dir()\n    filename = os.path.join(download_dir, \"babi_data\", path, \"data.txt\")\n    data = convert(filename)\n\n    assert len(data) == total_num\n\n    train_set = data[:train_size]\n    dev_set = data[950:1000]\n    test_sets = []\n    for i in range(10):\n        test = data[1000 * (i + 1) : 1000 * (i + 2)]\n        test_sets.append(test)\n\n    return train_set, dev_set, test_sets\n\n\ndef _download_babi_data():\n    download_dir = get_download_dir()\n    zip_file_path = os.path.join(download_dir, \"babi_data.zip\")\n\n    data_url = _get_dgl_url(\"models/ggnn_babi_data.zip\")\n    download(data_url, path=zip_file_path)\n\n    extract_dir = os.path.join(download_dir, \"babi_data\")\n    if not os.path.exists(extract_dir):\n        extract_archive(zip_file_path, extract_dir)\n"
  },
  {
    "path": "examples/pytorch/ggnn/ggnn_gc.py",
    "content": "\"\"\"\nGated Graph Neural Network module for graph classification tasks\n\"\"\"\nimport torch\n\nfrom dgl.nn.pytorch import GatedGraphConv, GlobalAttentionPooling\nfrom torch import nn\n\n\nclass GraphClsGGNN(nn.Module):\n    def __init__(self, annotation_size, out_feats, n_steps, n_etypes, num_cls):\n        super(GraphClsGGNN, self).__init__()\n\n        self.annotation_size = annotation_size\n        self.out_feats = out_feats\n\n        self.ggnn = GatedGraphConv(\n            in_feats=out_feats,\n            out_feats=out_feats,\n            n_steps=n_steps,\n            n_etypes=n_etypes,\n        )\n\n        pooling_gate_nn = nn.Linear(annotation_size + out_feats, 1)\n        self.pooling = GlobalAttentionPooling(pooling_gate_nn)\n        self.output_layer = nn.Linear(annotation_size + out_feats, num_cls)\n\n        self.loss_fn = nn.CrossEntropyLoss()\n\n    def forward(self, graph, labels=None):\n        etypes = graph.edata.pop(\"type\")\n        annotation = graph.ndata.pop(\"annotation\").float()\n\n        assert annotation.size()[-1] == self.annotation_size\n\n        node_num = graph.num_nodes()\n\n        zero_pad = torch.zeros(\n            [node_num, self.out_feats - self.annotation_size],\n            dtype=torch.float,\n            device=annotation.device,\n        )\n\n        h1 = torch.cat([annotation, zero_pad], -1)\n        out = self.ggnn(graph, h1, etypes)\n\n        out = torch.cat([out, annotation], -1)\n\n        out = self.pooling(graph, out)\n\n        logits = self.output_layer(out)\n        preds = torch.argmax(logits, -1)\n\n        if labels is not None:\n            loss = self.loss_fn(logits, labels)\n            return loss, preds\n        return preds\n"
  },
  {
    "path": "examples/pytorch/ggnn/ggnn_ns.py",
    "content": "\"\"\"\nGated Graph Neural Network module for node selection tasks\n\"\"\"\nimport dgl\nimport torch\nfrom dgl.nn.pytorch import GatedGraphConv\nfrom torch import nn\n\n\nclass NodeSelectionGGNN(nn.Module):\n    def __init__(self, annotation_size, out_feats, n_steps, n_etypes):\n        super(NodeSelectionGGNN, self).__init__()\n\n        self.annotation_size = annotation_size\n        self.out_feats = out_feats\n\n        self.ggnn = GatedGraphConv(\n            in_feats=out_feats,\n            out_feats=out_feats,\n            n_steps=n_steps,\n            n_etypes=n_etypes,\n        )\n\n        self.output_layer = nn.Linear(annotation_size + out_feats, 1)\n        self.loss_fn = nn.CrossEntropyLoss()\n\n    def forward(self, graph, labels=None):\n        etypes = graph.edata.pop(\"type\")\n        annotation = graph.ndata.pop(\"annotation\").float()\n\n        assert annotation.size()[-1] == self.annotation_size\n\n        node_num = graph.num_nodes()\n\n        zero_pad = torch.zeros(\n            [node_num, self.out_feats - self.annotation_size],\n            dtype=torch.float,\n            device=annotation.device,\n        )\n\n        h1 = torch.cat([annotation, zero_pad], -1)\n        out = self.ggnn(graph, h1, etypes)\n\n        all_logits = self.output_layer(\n            torch.cat([out, annotation], -1)\n        ).squeeze(-1)\n        graph.ndata[\"logits\"] = all_logits\n\n        batch_g = dgl.unbatch(graph)\n\n        preds = []\n        if labels is not None:\n            loss = 0.0\n        for i, g in enumerate(batch_g):\n            logits = g.ndata[\"logits\"]\n            preds.append(torch.argmax(logits))\n            if labels is not None:\n                logits = logits.unsqueeze(0)\n                y = labels[i].unsqueeze(0)\n                loss += self.loss_fn(logits, y)\n\n        if labels is not None:\n            loss /= float(len(batch_g))\n            return loss, preds\n        return preds\n"
  },
  {
    "path": "examples/pytorch/ggnn/ggsnn.py",
    "content": "\"\"\"\nGated Graph Sequence Neural Network for sequence outputs\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\nfrom dgl.nn.pytorch import GatedGraphConv, GlobalAttentionPooling\nfrom torch import nn\n\n\nclass GGSNN(nn.Module):\n    def __init__(\n        self,\n        annotation_size,\n        out_feats,\n        n_steps,\n        n_etypes,\n        max_seq_length,\n        num_cls,\n    ):\n        super(GGSNN, self).__init__()\n\n        self.annotation_size = annotation_size\n        self.out_feats = out_feats\n        self.max_seq_length = max_seq_length\n\n        self.ggnn = GatedGraphConv(\n            in_feats=out_feats,\n            out_feats=out_feats,\n            n_steps=n_steps,\n            n_etypes=n_etypes,\n        )\n\n        self.annotation_out_layer = nn.Linear(\n            annotation_size + out_feats, annotation_size\n        )\n\n        pooling_gate_nn = nn.Linear(annotation_size + out_feats, 1)\n        self.pooling = GlobalAttentionPooling(pooling_gate_nn)\n\n        self.output_layer = nn.Linear(annotation_size + out_feats, num_cls)\n        self.loss_fn = nn.CrossEntropyLoss(reduction=\"none\")\n\n    def forward(self, graph, seq_lengths, ground_truth=None):\n        etypes = graph.edata.pop(\"type\")\n        annotation = graph.ndata.pop(\"annotation\").float()\n\n        assert annotation.size()[-1] == self.annotation_size\n\n        node_num = graph.num_nodes()\n\n        all_logits = []\n        for _ in range(self.max_seq_length):\n            zero_pad = torch.zeros(\n                [node_num, self.out_feats - self.annotation_size],\n                dtype=torch.float,\n                device=annotation.device,\n            )\n\n            h1 = torch.cat([annotation.detach(), zero_pad], -1)\n            out = self.ggnn(graph, h1, etypes)\n            out = torch.cat([out, annotation], -1)\n            logits = self.pooling(graph, out)\n            logits = self.output_layer(logits)\n            
all_logits.append(logits)\n\n            annotation = self.annotation_out_layer(out)\n            annotation = F.softmax(annotation, -1)\n\n        all_logits = torch.stack(all_logits, 1)\n        preds = torch.argmax(all_logits, -1)\n        if ground_truth is not None:\n            loss = sequence_loss(all_logits, ground_truth, seq_lengths)\n            return loss, preds\n        return preds\n\n\ndef sequence_loss(logits, ground_truth, seq_length=None):\n    def sequence_mask(length):\n        max_length = logits.size(1)\n        batch_size = logits.size(0)\n        range_tensor = torch.arange(\n            0, max_length, dtype=seq_length.dtype, device=seq_length.device\n        )\n        range_tensor = torch.stack([range_tensor] * batch_size, 0)\n\n        expanded_length = torch.stack([length] * max_length, -1)\n        mask = (range_tensor < expanded_length).float()\n        return mask\n\n    loss = nn.CrossEntropyLoss(reduction=\"none\")(\n        logits.permute((0, 2, 1)), ground_truth\n    )\n\n    if seq_length is None:\n        loss = loss.mean()\n    else:\n        mask = sequence_mask(seq_length)\n        loss = (loss * mask).sum(-1) / seq_length.float()\n        loss = loss.mean()\n    return loss\n"
  },
  {
    "path": "examples/pytorch/ggnn/train_gc.py",
    "content": "\"\"\"\nTraining and testing for graph classification tasks in bAbI\n\"\"\"\n\nimport argparse\n\nimport numpy as np\nimport torch\nfrom data_utils import get_babi_dataloaders\nfrom ggnn_gc import GraphClsGGNN\nfrom torch.optim import Adam\n\n\ndef main(args):\n    out_feats = {18: 3}\n    n_etypes = {18: 2}\n\n    train_dataloader, dev_dataloader, test_dataloaders = get_babi_dataloaders(\n        batch_size=args.batch_size,\n        train_size=args.train_num,\n        task_id=args.task_id,\n        q_type=args.question_id,\n    )\n\n    model = GraphClsGGNN(\n        annotation_size=2,\n        out_feats=out_feats[args.task_id],\n        n_steps=5,\n        n_etypes=n_etypes[args.task_id],\n        num_cls=2,\n    )\n    opt = Adam(model.parameters(), lr=args.lr)\n\n    print(f\"Task {args.task_id}, question_id {args.question_id}\")\n\n    print(f\"Training set size: {len(train_dataloader.dataset)}\")\n    print(f\"Dev set size: {len(dev_dataloader.dataset)}\")\n\n    # training and dev stage\n    for epoch in range(args.epochs):\n        model.train()\n        for i, batch in enumerate(train_dataloader):\n            g, labels = batch\n            loss, _ = model(g, labels)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n            if epoch % 20 == 0:\n                print(f\"Epoch {epoch}, batch {i} loss: {loss.data}\")\n\n        if epoch % 20 != 0:\n            continue\n        dev_preds = []\n        dev_labels = []\n        model.eval()\n        for g, labels in dev_dataloader:\n            with torch.no_grad():\n                preds = model(g)\n                preds = preds.data.numpy().tolist()\n                labels = labels.data.numpy().tolist()\n                dev_preds += preds\n                dev_labels += labels\n        acc = np.equal(dev_labels, dev_preds).astype(float).tolist()\n        acc = sum(acc) / len(acc)\n        print(f\"Epoch {epoch}, Dev acc {acc}\")\n\n    # test stage\n    for 
i, dataloader in enumerate(test_dataloaders):\n        print(f\"Test set {i} size: {len(dataloader.dataset)}\")\n\n    test_acc_list = []\n    for dataloader in test_dataloaders:\n        test_preds = []\n        test_labels = []\n        model.eval()\n        for g, labels in dataloader:\n            with torch.no_grad():\n                preds = model(g)\n                preds = preds.data.numpy().tolist()\n                labels = labels.data.numpy().tolist()\n                test_preds += preds\n                test_labels += labels\n        acc = np.equal(test_labels, test_preds).astype(float).tolist()\n        acc = sum(acc) / len(acc)\n        test_acc_list.append(acc)\n\n    test_acc_mean = np.mean(test_acc_list)\n    test_acc_std = np.std(test_acc_list)\n\n    print(\n        f\"Mean of accuracy in 10 test datasets: {test_acc_mean}, std: {test_acc_std}\"\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Gated Graph Neural Networks for graph classification tasks in bAbI\"\n    )\n    parser.add_argument(\n        \"--task_id\", type=int, default=18, help=\"task id from 1 to 20\"\n    )\n    parser.add_argument(\n        \"--question_id\", type=int, default=0, help=\"question id for each task\"\n    )\n    parser.add_argument(\n        \"--train_num\", type=int, default=950, help=\"Number of training examples\"\n    )\n    parser.add_argument(\"--batch_size\", type=int, default=50, help=\"batch size\")\n    parser.add_argument(\"--lr\", type=float, default=1e-3, help=\"learning rate\")\n    parser.add_argument(\n        \"--epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/ggnn/train_ns.py",
    "content": "\"\"\"\nTraining and testing for node selection tasks in bAbI\n\"\"\"\n\nimport argparse\nimport time\n\nimport numpy as np\nimport torch\nfrom data_utils import get_babi_dataloaders\nfrom ggnn_ns import NodeSelectionGGNN\nfrom torch.optim import Adam\n\n\ndef main(args):\n    out_feats = {4: 4, 15: 5, 16: 6}\n    n_etypes = {4: 4, 15: 2, 16: 2}\n\n    train_dataloader, dev_dataloader, test_dataloaders = get_babi_dataloaders(\n        batch_size=args.batch_size,\n        train_size=args.train_num,\n        task_id=args.task_id,\n        q_type=args.question_id,\n    )\n\n    model = NodeSelectionGGNN(\n        annotation_size=1,\n        out_feats=out_feats[args.task_id],\n        n_steps=5,\n        n_etypes=n_etypes[args.task_id],\n    )\n    opt = Adam(model.parameters(), lr=args.lr)\n\n    print(f\"Task {args.task_id}, question_id {args.question_id}\")\n\n    print(f\"Training set size: {len(train_dataloader.dataset)}\")\n    print(f\"Dev set size: {len(dev_dataloader.dataset)}\")\n\n    # training and dev stage\n    for epoch in range(args.epochs):\n        model.train()\n        for i, batch in enumerate(train_dataloader):\n            g, labels = batch\n            loss, _ = model(g, labels)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n            print(f\"Epoch {epoch}, batch {i} loss: {loss.data}\")\n\n        dev_preds = []\n        dev_labels = []\n        model.eval()\n        for g, labels in dev_dataloader:\n            with torch.no_grad():\n                preds = model(g)\n                preds = (\n                    torch.tensor(preds, dtype=torch.long).data.numpy().tolist()\n                )\n                labels = labels.data.numpy().tolist()\n                dev_preds += preds\n                dev_labels += labels\n        acc = np.equal(dev_labels, dev_preds).astype(float).tolist()\n        acc = sum(acc) / len(acc)\n        print(f\"Epoch {epoch}, Dev acc {acc}\")\n\n    # test 
stage\n    for i, dataloader in enumerate(test_dataloaders):\n        print(f\"Test set {i} size: {len(dataloader.dataset)}\")\n\n    test_acc_list = []\n    for dataloader in test_dataloaders:\n        test_preds = []\n        test_labels = []\n        model.eval()\n        for g, labels in dataloader:\n            with torch.no_grad():\n                preds = model(g)\n                preds = (\n                    torch.tensor(preds, dtype=torch.long).data.numpy().tolist()\n                )\n                labels = labels.data.numpy().tolist()\n                test_preds += preds\n                test_labels += labels\n        acc = np.equal(test_labels, test_preds).astype(float).tolist()\n        acc = sum(acc) / len(acc)\n        test_acc_list.append(acc)\n\n    test_acc_mean = np.mean(test_acc_list)\n    test_acc_std = np.std(test_acc_list)\n\n    print(\n        f\"Mean of accuracy in 10 test datasets: {test_acc_mean}, std: {test_acc_std}\"\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Gated Graph Neural Networks for node selection tasks in bAbI\"\n    )\n    parser.add_argument(\n        \"--task_id\", type=int, default=16, help=\"task id from 1 to 20\"\n    )\n    parser.add_argument(\n        \"--question_id\", type=int, default=1, help=\"question id for each task\"\n    )\n    parser.add_argument(\n        \"--train_num\", type=int, default=50, help=\"Number of training examples\"\n    )\n    parser.add_argument(\"--batch_size\", type=int, default=10, help=\"batch size\")\n    parser.add_argument(\"--lr\", type=float, default=1e-3, help=\"learning rate\")\n    parser.add_argument(\n        \"--epochs\", type=int, default=100, help=\"number of training epochs\"\n    )\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/ggnn/train_path_finding.py",
    "content": "\"\"\"\nTraining and testing for sequence output tasks in bAbI.\nHere we take task 19 'Path Finding' as an example\n\"\"\"\n\nimport argparse\n\nimport numpy as np\nimport torch\nfrom data_utils import get_babi_dataloaders\nfrom ggsnn import GGSNN\nfrom torch.optim import Adam\n\n\ndef main(args):\n    out_feats = {19: 6}\n    n_etypes = {19: 4}\n\n    train_dataloader, dev_dataloader, test_dataloaders = get_babi_dataloaders(\n        batch_size=args.batch_size,\n        train_size=args.train_num,\n        task_id=args.task_id,\n        q_type=-1,\n    )\n\n    model = GGSNN(\n        annotation_size=2,\n        out_feats=out_feats[args.task_id],\n        n_steps=5,\n        n_etypes=n_etypes[args.task_id],\n        max_seq_length=2,\n        num_cls=5,\n    )\n    opt = Adam(model.parameters(), lr=args.lr)\n\n    print(f\"Task {args.task_id}\")\n\n    print(f\"Training set size: {len(train_dataloader.dataset)}\")\n    print(f\"Dev set size: {len(dev_dataloader.dataset)}\")\n\n    # training and dev stage\n    for epoch in range(args.epochs):\n        model.train()\n        for i, batch in enumerate(train_dataloader):\n            g, ground_truths, seq_lengths = batch\n            loss, _ = model(g, seq_lengths, ground_truths)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n            if epoch % 20 == 0:\n                print(f\"Epoch {epoch}, batch {i} loss: {loss.data}\")\n\n        if epoch % 20 != 0:\n            continue\n        dev_res = []\n        model.eval()\n        for g, ground_truths, seq_lengths in dev_dataloader:\n            with torch.no_grad():\n                preds = model(g, seq_lengths)\n                preds = preds.data.numpy().tolist()\n                ground_truths = ground_truths.data.numpy().tolist()\n                for i, p in enumerate(preds):\n                    if p == ground_truths[i]:\n                        dev_res.append(1.0)\n                    else:\n                    
    dev_res.append(0.0)\n        acc = sum(dev_res) / len(dev_res)\n        print(f\"Epoch {epoch}, Dev acc {acc}\")\n\n    # test stage\n    for i, dataloader in enumerate(test_dataloaders):\n        print(f\"Test set {i} size: {len(dataloader.dataset)}\")\n\n    test_acc_list = []\n    for dataloader in test_dataloaders:\n        test_res = []\n        model.eval()\n        for g, ground_truths, seq_lengths in dataloader:\n            with torch.no_grad():\n                preds = model(g, seq_lengths)\n                preds = preds.data.numpy().tolist()\n                ground_truths = ground_truths.data.numpy().tolist()\n                for i, p in enumerate(preds):\n                    if p == ground_truths[i]:\n                        test_res.append(1.0)\n                    else:\n                        test_res.append(0.0)\n        acc = sum(test_res) / len(test_res)\n        test_acc_list.append(acc)\n\n    test_acc_mean = np.mean(test_acc_list)\n    test_acc_std = np.std(test_acc_list)\n\n    print(\n        f\"Mean of accuracy in 10 test datasets: {test_acc_mean}, std: {test_acc_std}\"\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Gated Graph Sequence Neural Networks for sequential output tasks in \"\n        \"bAbI\"\n    )\n    parser.add_argument(\n        \"--task_id\", type=int, default=19, help=\"task id from 1 to 20\"\n    )\n    parser.add_argument(\n        \"--train_num\", type=int, default=250, help=\"Number of training examples\"\n    )\n    parser.add_argument(\"--batch_size\", type=int, default=10, help=\"batch size\")\n    parser.add_argument(\"--lr\", type=float, default=1e-3, help=\"learning rate\")\n    parser.add_argument(\n        \"--epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/gin/README.md",
    "content": "Graph Isomorphism Network (GIN)\n============\n\n- Paper link: [arXiv](https://arxiv.org/abs/1810.00826) [OpenReview](https://openreview.net/forum?id=ryGs6iA5Km) \n- Author's code repo: [https://github.com/weihua916/powerful-gnns](https://github.com/weihua916/powerful-gnns).\n\nDependencies\n------------\n- scikit-learn\n\nInstall as follows:\n```bash\npip install scikit-learn\n```\n\nHow to run\n-------\n\nRun with the following for bioinformatics graph classification (available datasets: MUTAG (default), PTC, NCI1, and PROTEINS)\n```bash\npython3 train.py --dataset MUTAG\n```\n\n> **_NOTE:_**  Users may observe that results fluctuate due to randomness, since the datasets are relatively small.  Consistent with the original [paper](https://arxiv.org/abs/1810.00826), five social network datasets, 'COLLAB', 'IMDBBINARY', 'IMDBMULTI', 'REDDITBINARY', and 'REDDITMULTI5K', are also available as input. Users are encouraged to update the script slightly for social network applications, for example, replacing sum readout on bioinformatics datasets with mean readout on social network datasets and using one-hot encodings of node degrees by setting \"degree_as_nlabel=True\" in GINDataset.\n\nSummary (10-fold cross-validation)\n-------\n| Dataset       | Result\n| ------------- | -------\n| MUTAG         | ~89.4\n| PTC           | ~68.5\n| NCI1          | ~82.9\n| PROTEINS      | ~74.1\n"
  },
  {
    "path": "examples/pytorch/gin/train.py",
    "content": "import argparse\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\n\nfrom dgl.data import GINDataset\nfrom dgl.dataloading import GraphDataLoader\nfrom dgl.nn.pytorch.conv import GINConv\nfrom dgl.nn.pytorch.glob import SumPooling\nfrom sklearn.model_selection import StratifiedKFold\nfrom torch.utils.data.sampler import SubsetRandomSampler\n\n\nclass MLP(nn.Module):\n    \"\"\"Construct two-layer MLP-type aggreator for GIN model\"\"\"\n\n    def __init__(self, input_dim, hidden_dim, output_dim):\n        super().__init__()\n        self.linears = nn.ModuleList()\n        # two-layer MLP\n        self.linears.append(nn.Linear(input_dim, hidden_dim, bias=False))\n        self.linears.append(nn.Linear(hidden_dim, output_dim, bias=False))\n        self.batch_norm = nn.BatchNorm1d((hidden_dim))\n\n    def forward(self, x):\n        h = x\n        h = F.relu(self.batch_norm(self.linears[0](h)))\n        return self.linears[1](h)\n\n\nclass GIN(nn.Module):\n    def __init__(self, input_dim, hidden_dim, output_dim):\n        super().__init__()\n        self.ginlayers = nn.ModuleList()\n        self.batch_norms = nn.ModuleList()\n        num_layers = 5\n        # five-layer GCN with two-layer MLP aggregator and sum-neighbor-pooling scheme\n        for layer in range(num_layers - 1):  # excluding the input layer\n            if layer == 0:\n                mlp = MLP(input_dim, hidden_dim, hidden_dim)\n            else:\n                mlp = MLP(hidden_dim, hidden_dim, hidden_dim)\n            self.ginlayers.append(\n                GINConv(mlp, learn_eps=False)\n            )  # set to True if learning epsilon\n            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))\n        # linear functions for graph sum poolings of output of each layer\n        self.linear_prediction = nn.ModuleList()\n        for layer in range(num_layers):\n            if layer == 0:\n                
self.linear_prediction.append(nn.Linear(input_dim, output_dim))\n            else:\n                self.linear_prediction.append(nn.Linear(hidden_dim, output_dim))\n        self.drop = nn.Dropout(0.5)\n        self.pool = (\n            SumPooling()\n        )  # change to mean readout (AvgPooling) on social network datasets\n\n    def forward(self, g, h):\n        # list of hidden representation at each layer (including the input layer)\n        hidden_rep = [h]\n        for i, layer in enumerate(self.ginlayers):\n            h = layer(g, h)\n            h = self.batch_norms[i](h)\n            h = F.relu(h)\n            hidden_rep.append(h)\n        score_over_layer = 0\n        # perform graph sum pooling over all nodes in each layer\n        for i, h in enumerate(hidden_rep):\n            pooled_h = self.pool(g, h)\n            score_over_layer += self.drop(self.linear_prediction[i](pooled_h))\n        return score_over_layer\n\n\ndef split_fold10(labels, fold_idx=0):\n    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)\n    idx_list = []\n    for idx in skf.split(np.zeros(len(labels)), labels):\n        idx_list.append(idx)\n    train_idx, valid_idx = idx_list[fold_idx]\n    return train_idx, valid_idx\n\n\ndef evaluate(dataloader, device, model):\n    model.eval()\n    total = 0\n    total_correct = 0\n    for batched_graph, labels in dataloader:\n        batched_graph = batched_graph.to(device)\n        labels = labels.to(device)\n        feat = batched_graph.ndata.pop(\"attr\")\n        total += len(labels)\n        logits = model(batched_graph, feat)\n        _, predicted = torch.max(logits, 1)\n        total_correct += (predicted == labels).sum().item()\n    acc = 1.0 * total_correct / total\n    return acc\n\n\ndef train(train_loader, val_loader, device, model):\n    # loss function, optimizer and scheduler\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = optim.Adam(model.parameters(), lr=0.01)\n    scheduler = 
optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)\n\n    # training loop\n    for epoch in range(350):\n        model.train()\n        total_loss = 0\n        for batch, (batched_graph, labels) in enumerate(train_loader):\n            batched_graph = batched_graph.to(device)\n            labels = labels.to(device)\n            feat = batched_graph.ndata.pop(\"attr\")\n            logits = model(batched_graph, feat)\n            loss = loss_fcn(logits, labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n        scheduler.step()\n        train_acc = evaluate(train_loader, device, model)\n        valid_acc = evaluate(val_loader, device, model)\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Train Acc. {:.4f} | Validation Acc. {:.4f} \".format(\n                epoch, total_loss / (batch + 1), train_acc, valid_acc\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"MUTAG\",\n        choices=[\"MUTAG\", \"PTC\", \"NCI1\", \"PROTEINS\"],\n        help=\"name of dataset (default: MUTAG)\",\n    )\n    args = parser.parse_args()\n    print(f\"Training with DGL built-in GINConv module with a fixed epsilon = 0\")\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    # load and split dataset\n    dataset = GINDataset(\n        args.dataset, self_loop=True, degree_as_nlabel=False\n    )  # add self_loop and disable one-hot encoding for input features\n    labels = [l for _, l in dataset]\n    train_idx, val_idx = split_fold10(labels)\n\n    # create dataloader\n    train_loader = GraphDataLoader(\n        dataset,\n        sampler=SubsetRandomSampler(train_idx),\n        batch_size=128,\n        pin_memory=torch.cuda.is_available(),\n    )\n    val_loader = GraphDataLoader(\n        dataset,\n        
sampler=SubsetRandomSampler(val_idx),\n        batch_size=128,\n        pin_memory=torch.cuda.is_available(),\n    )\n\n    # create GIN model\n    in_size = dataset.dim_nfeats\n    out_size = dataset.gclasses\n    model = GIN(in_size, 16, out_size).to(device)\n\n    # model training/validating\n    print(\"Training...\")\n    train(train_loader, val_loader, device, model)\n"
  },
  {
    "path": "examples/pytorch/gnn_explainer/README.md",
    "content": "# DGL Implementation of GNNExplainer\n\nThis is a DGL example for [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894). For the authors' original implementation,\nsee [here](https://github.com/RexYing/gnn-model-explainer).\n\nContributors:\n- [Jian Zhang](https://github.com/zhjwy9343)\n- [Kounianhua Du](https://github.com/KounianhuaDu)\n- [Yanjun Zhao](https://github.com/zyj-111)\n\nDatasets\n----------------------\n\nFour built-in synthetic datasets are used in this example.\n\n- [BA-SHAPES](https://docs.dgl.ai/generated/dgl.data.BAShapeDataset.html#dgl.data.BAShapeDataset)\n- [BA-COMMUNITY](https://docs.dgl.ai/generated/dgl.data.BACommunityDataset.html#dgl.data.BACommunityDataset)\n- [TREE-CYCLE](https://docs.dgl.ai/generated/dgl.data.TreeCycleDataset.html#dgl.data.TreeCycleDataset)\n- [TREE-GRID](https://docs.dgl.ai/generated/dgl.data.TreeGridDataset.html#dgl.data.TreeGridDataset)\n\nUsage\n----------------------\n\n**First**, train a GNN model on a dataset.\n\n```bash\npython train_main.py  --dataset $DATASET\n```\n\nValid options for `$DATASET`: `BAShape`, `BACommunity`, `TreeCycle`, `TreeGrid`\n\nThe trained model weights will be saved to `model_{dataset}.pth`\n\n**Second**, install [GNNLens2](https://github.com/dmlc/GNNLens2) with\n\n```bash\npip install -U flask-cors\npip install Flask==2.0.3\npip install gnnlens\n```\n\n**Third**, explain the trained model with the same dataset\n\n```bash\npython explain_main.py --dataset $DATASET\n```\n\n**Finally**, launch `GNNLens2` to visualize the explanations\n\n```bash\ngnnlens --logdir gnn_subgraph\n```\n\nBy entering `localhost:7777` in your web browser address bar, you can see the GNNLens2 interface. `7777` is the default port GNNLens2 uses. You can specify an alternative one by adding `--port xxxx` after the command line and change the address in the web browser accordingly.\n\nA sample visualization is available below. 
For more details on using `GNNLens2`, check its [tutorials](https://github.com/dmlc/GNNLens2#tutorials).\n\n<p align=\"center\">\n  <img src=\"https://data.dgl.ai/asset/image/explain_BAShape.png\"  width=\"600\">\n  <br>\n  <b>Figure</b>: Explanation for node 41 of BAShape\n</p>\n"
  },
  {
    "path": "examples/pytorch/gnn_explainer/explain_main.py",
    "content": "import argparse\nimport os\n\nimport dgl\n\nimport torch as th\nfrom dgl import load_graphs\nfrom dgl.data import (\n    BACommunityDataset,\n    BAShapeDataset,\n    TreeCycleDataset,\n    TreeGridDataset,\n)\nfrom dgl.nn import GNNExplainer\nfrom gnnlens import Writer\nfrom models import Model\n\n\ndef main(args):\n    if args.dataset == \"BAShape\":\n        dataset = BAShapeDataset(seed=0)\n    elif args.dataset == \"BACommunity\":\n        dataset = BACommunityDataset(seed=0)\n    elif args.dataset == \"TreeCycle\":\n        dataset = TreeCycleDataset(seed=0)\n    elif args.dataset == \"TreeGrid\":\n        dataset = TreeGridDataset(seed=0)\n\n    graph = dataset[0]\n    labels = graph.ndata[\"label\"]\n    feats = graph.ndata[\"feat\"]\n    num_classes = dataset.num_classes\n\n    # load an existing model\n    model_path = os.path.join(\"./\", f\"model_{args.dataset}.pth\")\n    model_stat_dict = th.load(model_path)\n    model = Model(feats.shape[-1], num_classes)\n    model.load_state_dict(model_stat_dict)\n\n    # Choose the first node of the class 1 for explaining prediction\n    target_class = 1\n    for n_idx, n_label in enumerate(labels):\n        if n_label == target_class:\n            break\n\n    explainer = GNNExplainer(model, num_hops=3)\n    new_center, sub_graph, feat_mask, edge_mask = explainer.explain_node(\n        n_idx, graph, feats\n    )\n\n    # gnnlens2\n    # Specify the path to create a new directory for dumping data files.\n    writer = Writer(\"gnn_subgraph\")\n    writer.add_graph(\n        name=args.dataset,\n        graph=graph,\n        nlabels=labels,\n        num_nlabel_types=num_classes,\n    )\n    writer.add_subgraph(\n        graph_name=args.dataset,\n        subgraph_name=\"GNNExplainer\",\n        node_id=n_idx,\n        subgraph_nids=sub_graph.ndata[dgl.NID],\n        subgraph_eids=sub_graph.edata[dgl.EID],\n        subgraph_eweights=edge_mask,\n    )\n\n    # Finish dumping.\n    writer.close()\n\n\nif 
__name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Demo of GNN explainer in DGL\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"BAShape\",\n        choices=[\"BAShape\", \"BACommunity\", \"TreeCycle\", \"TreeGrid\"],\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/gnn_explainer/gnn_subgraph/1/graph.json",
    "content": "{\"graph_obj\": {\"name\": \"BAShape\", \"srcs\": [41, 45, 59, 64, 67, 70, 74, 80, 83, 92, 99, 102, 105, 108, 112, 115, 118, 121, 124, 127, 131, 134, 137, 140, 143, 146, 149, 152, 157, 160, 163, 167, 170, 173, 177, 180, 183, 187, 190, 197, 202, 205, 43, 0, 206, 47, 209, 1, 210, 213, 216, 219, 225, 228, 231, 233, 236, 239, 242, 61, 248, 2, 249, 252, 255, 66, 3, 256, 69, 4, 72, 257, 5, 245, 76, 259, 6, 260, 95, 263, 266, 82, 7, 267, 85, 8, 268, 271, 273, 276, 279, 282, 285, 94, 9, 286, 97, 78, 287, 290, 101, 10, 104, 291, 11, 107, 293, 12, 110, 294, 13, 295, 114, 298, 14, 299, 117, 15, 120, 300, 16, 302, 123, 17, 303, 126, 18, 129, 304, 19, 305, 308, 133, 20, 310, 136, 21, 311, 139, 22, 142, 312, 214, 23, 145, 313, 24, 148, 314, 25, 151, 315, 26, 154, 316, 27, 317, 155, 320, 159, 28, 321, 162, 29, 165, 322, 30, 324, 327, 169, 31, 328, 172, 32, 332, 175, 33, 333, 179, 334, 34, 335, 182, 35, 336, 185, 36, 200, 338, 189, 37, 339, 192, 38, 194, 340, 193, 342, 345, 199, 39, 201, 346, 347, 204, 40, 42, 43, 208, 349, 44, 47, 46, 352, 212, 48, 214, 354, 49, 218, 356, 50, 221, 357, 51, 176, 359, 227, 364, 52, 230, 365, 53, 301, 369, 54, 371, 235, 55, 372, 238, 56, 373, 241, 57, 244, 374, 58, 247, 383, 73, 60, 61, 251, 387, 62, 254, 393, 63, 66, 65, 69, 68, 71, 72, 329, 73, 76, 75, 407, 261, 306, 77, 265, 409, 78, 79, 81, 82, 84, 85, 413, 270, 86, 323, 416, 87, 419, 275, 348, 88, 420, 278, 89, 281, 421, 90, 284, 422, 91, 94, 93, 341, 97, 96, 289, 427, 98, 101, 100, 103, 104, 433, 106, 107, 109, 110, 455, 297, 111, 114, 113, 117, 116, 119, 120, 123, 122, 125, 126, 129, 128, 473, 307, 130, 133, 132, 483, 135, 136, 139, 138, 142, 141, 144, 145, 148, 147, 151, 150, 154, 153, 512, 155, 156, 513, 158, 159, 161, 162, 165, 164, 534, 326, 166, 168, 169, 172, 171, 331, 544, 258, 174, 175, 358, 223, 178, 179, 182, 181, 184, 185, 186, 189, 188, 191, 192, 194, 341, 195, 558, 344, 562, 196, 198, 199, 337, 201, 204, 203, 566, 569, 517, 390, 370, 205, 207, 208, 209, 437, 519, 
573, 491, 566, 379, 438, 553, 575, 425, 212, 211, 556, 463, 385, 577, 458, 215, 214, 567, 580, 478, 516, 508, 218, 217, 220, 221, 222, 223, 601, 361, 224, 569, 478, 602, 555, 535, 576, 553, 378, 605, 545, 227, 226, 230, 229, 563, 521, 609, 571, 605, 613, 591, 389, 603, 516, 491, 517, 613, 602, 540, 232, 301, 574, 445, 486, 614, 554, 234, 235, 238, 237, 241, 240, 244, 243, 503, 619, 555, 606, 493, 554, 555, 500, 614, 619, 464, 619, 591, 578, 595, 559, 364, 436, 620, 478, 397, 557, 545, 432, 619, 474, 620, 591, 605, 538, 582, 620, 582, 620, 525, 567, 516, 621, 410, 576, 508, 619, 246, 247, 620, 602, 567, 626, 570, 613, 504, 626, 570, 459, 417, 248, 572, 553, 626, 591, 250, 251, 627, 628, 521, 570, 567, 619, 621, 591, 554, 629, 629, 362, 603, 536, 535, 569, 619, 592, 629, 491, 457, 551, 539, 629, 563, 253, 254, 630, 567, 255, 581, 577, 545, 444, 423, 605, 570, 630, 429, 521, 605, 370, 630, 603, 436, 626, 256, 504, 631, 468, 257, 353, 631, 537, 491, 487, 631, 603, 448, 591, 560, 258, 593, 631, 573, 592, 535, 370, 570, 569, 631, 631, 627, 567, 577, 481, 493, 632, 620, 578, 573, 259, 518, 633, 410, 519, 598, 632, 588, 566, 633, 537, 461, 486, 621, 634, 517, 524, 262, 261, 581, 635, 630, 619, 478, 265, 264, 587, 583, 266, 636, 604, 522, 475, 619, 636, 606, 499, 410, 553, 637, 619, 267, 463, 517, 270, 269, 581, 632, 572, 613, 639, 639, 575, 503, 635, 465, 272, 323, 378, 630, 580, 635, 619, 517, 397, 615, 640, 520, 275, 274, 277, 278, 281, 280, 284, 283, 580, 389, 619, 516, 285, 626, 645, 591, 503, 619, 574, 632, 432, 579, 591, 619, 519, 647, 404, 423, 286, 474, 289, 288, 492, 519, 649, 570, 632, 290, 291, 573, 575, 592, 397, 384, 398, 629, 602, 619, 620, 613, 423, 436, 579, 575, 644, 603, 569, 566, 563, 654, 435, 292, 592, 655, 630, 566, 570, 549, 554, 436, 605, 655, 617, 508, 655, 626, 463, 535, 388, 446, 655, 630, 613, 388, 620, 503, 655, 384, 593, 655, 539, 635, 632, 655, 621, 616, 555, 519, 655, 629, 508, 567, 656, 591, 553, 631, 566, 317, 656, 571, 621, 619, 570, 517, 
656, 627, 378, 655, 630, 637, 436, 498, 656, 293, 518, 554, 486, 491, 656, 555, 573, 294, 553, 656, 630, 537, 596, 410, 656, 386, 631, 468, 656, 493, 460, 581, 656, 559, 563, 390, 471, 556, 362, 656, 456, 620, 466, 635, 493, 656, 296, 297, 657, 630, 566, 620, 613, 298, 299, 619, 545, 464, 657, 384, 520, 656, 574, 619, 657, 580, 657, 620, 655, 636, 300, 456, 657, 602, 375, 437, 519, 301, 614, 538, 555, 657, 573, 657, 641, 498, 499, 656, 658, 478, 619, 592, 554, 302, 656, 630, 619, 555, 605, 303, 619, 545, 658, 631, 457, 516, 658, 638, 569, 468, 570, 593, 486, 658, 659, 619, 498, 634, 554, 631, 304, 659, 468, 629, 535, 579, 659, 629, 573, 591, 605, 583, 660, 628, 478, 606, 496, 582, 660, 629, 307, 306, 655, 570, 621, 631, 619, 594, 580, 631, 661, 622, 308, 468, 516, 599, 661, 634, 594, 591, 655, 661, 626, 657, 631, 620, 659, 575, 399, 398, 486, 662, 457, 660, 662, 567, 650, 446, 496, 640, 584, 478, 631, 640, 596, 547, 602, 557, 485, 665, 309, 657, 666, 488, 310, 478, 569, 655, 436, 666, 567, 632, 311, 535, 555, 657, 666, 445, 570, 619, 629, 545, 666, 632, 444, 666, 555, 572, 603, 492, 487, 667, 545, 667, 655, 545, 554, 621, 667, 569, 666, 633, 570, 603, 491, 667, 556, 604, 579, 627, 570, 667, 643, 668, 630, 655, 520, 389, 668, 468, 637, 312, 667, 584, 656, 566, 655, 592, 313, 591, 353, 619, 314, 634, 669, 629, 581, 315, 474, 553, 669, 659, 669, 615, 367, 488, 620, 447, 669, 655, 492, 458, 538, 536, 634, 631, 591, 569, 464, 658, 619, 553, 657, 597, 492, 367, 670, 445, 607, 670, 655, 397, 474, 671, 492, 593, 316, 486, 658, 629, 566, 605, 631, 479, 603, 437, 672, 661, 545, 672, 630, 646, 556, 468, 669, 499, 673, 457, 564, 318, 674, 515, 319, 569, 630, 445, 631, 675, 604, 675, 655, 666, 619, 491, 320, 675, 655, 576, 620, 666, 571, 602, 572, 675, 656, 321, 655, 569, 675, 567, 659, 571, 478, 675, 655, 602, 535, 569, 571, 675, 556, 675, 627, 322, 671, 574, 559, 538, 619, 423, 675, 553, 446, 657, 592, 660, 675, 546, 630, 481, 571, 675, 323, 667, 645, 630, 675, 593, 537, 582, 
669, 626, 675, 593, 671, 675, 670, 423, 666, 566, 367, 605, 675, 570, 380, 477, 578, 675, 406, 620, 511, 675, 472, 572, 655, 554, 546, 675, 326, 325, 545, 676, 569, 675, 655, 676, 619, 555, 668, 675, 629, 621, 676, 554, 675, 327, 676, 620, 516, 444, 666, 626, 570, 536, 629, 676, 328, 675, 676, 554, 503, 655, 676, 675, 597, 463, 592, 492, 676, 627, 594, 620, 610, 676, 488, 370, 646, 330, 331, 677, 675, 332, 333, 456, 676, 555, 677, 411, 602, 628, 594, 354, 535, 522, 570, 334, 677, 676, 677, 566, 572, 655, 536, 630, 677, 520, 375, 566, 335, 457, 667, 657, 677, 634, 518, 430, 486, 677, 639, 337, 336, 569, 676, 656, 602, 621, 338, 633, 620, 619, 592, 631, 655, 629, 675, 632, 339, 602, 487, 678, 604, 577, 655, 340, 634, 638, 655, 678, 536, 678, 571, 659, 602, 629, 463, 602, 668, 586, 655, 668, 486, 620, 555, 518, 571, 629, 626, 668, 679, 344, 343, 537, 538, 629, 657, 518, 345, 680, 655, 603, 663, 595, 572, 634, 478, 346, 633, 655, 619, 633, 632, 655, 656, 517, 633, 657, 655, 675, 629, 347, 475, 458, 444, 535, 655, 545, 632, 675, 656, 683, 569, 621, 630, 571, 656, 569, 633, 619, 683, 349, 555, 683, 638, 569, 627, 592, 619, 656, 683, 570, 516, 683, 626, 620, 570, 535, 676, 553, 630, 683, 604, 566, 675, 352, 683, 569, 683, 621, 567, 389, 571, 668, 354, 569, 683, 517, 577, 621, 478, 538, 683, 539, 536, 683, 629, 572, 539, 632, 655, 683, 644, 636, 604, 628, 683, 655, 678, 675, 683, 631, 571, 630, 356, 629, 683, 621, 631, 553, 621, 683, 569, 573, 629, 459, 683, 498, 655, 614, 504, 683, 423, 676, 666, 385, 683, 456, 382, 572, 675, 444, 683, 591, 658, 408, 675, 487, 683, 668, 619, 629, 545, 655, 621, 675, 619, 655, 656, 676, 684, 602, 357, 605, 656, 545, 463, 656, 684, 629, 491, 684, 566, 516, 456, 436, 657, 658, 358, 684, 670, 630, 629, 626, 684, 642, 604, 572, 684, 675, 508, 629, 582, 684, 642, 569, 637, 626, 604, 684, 619, 576, 360, 361, 629, 675, 545, 656, 676, 676, 571, 620, 602, 656, 555, 630, 633, 629, 655, 675, 685, 630, 444, 629, 384, 444, 685, 572, 478, 685, 602, 508, 
568, 619, 364, 655, 685, 657, 626, 574, 365, 685, 657, 555, 615, 594, 629, 379, 685, 425, 537, 397, 685, 643, 535, 592, 521, 619, 637, 431, 685, 570, 655, 492, 675, 369, 436, 371, 436, 686, 554, 445, 662, 616, 604, 686, 630, 372, 498, 636, 545, 373, 574, 676, 686, 655, 631, 619, 374, 686, 487, 518, 686, 554, 568, 683, 655, 667, 629, 675, 569, 631, 657, 632, 553, 633, 535, 619, 545, 592, 569, 383, 459, 571, 575, 545, 675, 659, 655, 684, 687, 683, 554, 546, 655, 687, 667, 632, 553, 595, 687, 458, 536, 619, 602, 535, 629, 387, 554, 591, 585, 571, 632, 630, 619, 655, 554, 572, 393, 619, 655, 675, 535, 689, 689, 569, 667, 675, 602, 655, 675, 619, 676, 569, 676, 407, 566, 655, 689, 633, 518, 487, 577, 409, 689, 628, 689, 655, 602, 571, 537, 595, 656, 689, 486, 602, 676, 655, 413, 634, 545, 689, 574, 657, 444, 689, 604, 416, 689, 683, 498, 419, 667, 417, 420, 487, 689, 553, 603, 436, 689, 655, 629, 656, 685, 421, 422, 583, 689, 602, 516, 658, 684, 689, 566, 633, 591, 689, 423, 592, 619, 656, 486, 425, 657, 554, 689, 630, 567, 536, 655, 689, 632, 627, 689, 427, 566, 655, 689, 675, 508, 632, 469, 655, 553, 619, 689, 570, 629, 545, 689, 429, 578, 430, 431, 621, 689, 553, 553, 432, 689, 579, 629, 435, 434, 545, 632, 656, 667, 629, 455, 571, 675, 566, 629, 619, 676, 553, 690, 478, 464, 573, 619, 633, 553, 683, 602, 473, 676, 555, 690, 620, 474, 603, 571, 569, 690, 574, 666, 569, 690, 478, 565, 620, 690, 666, 667, 482, 497, 620, 481, 690, 484, 485, 675, 545, 629, 619, 632, 655, 604, 571, 576, 676, 498, 578, 675, 691, 603, 503, 504, 691, 655, 689, 689, 677, 691, 603, 557, 691, 602, 641, 508, 538, 512, 574, 553, 632, 591, 691, 514, 515, 655, 629, 545, 534, 656, 667, 544, 629, 655, 569, 619, 675, 585, 545, 678, 592, 613, 676, 633, 555, 553, 554, 692, 562, 579, 559, 675, 560, 692, 567, 666, 570, 563, 675, 619, 602, 565, 655, 692, 567, 619, 568, 566, 656, 569, 655, 675, 676, 591, 655, 601, 656, 668, 592, 690, 602, 603, 604, 619, 689, 619, 675, 613, 685, 619, 620, 621, 693, 622, 685, 
693, 626, 627, 628, 629, 633, 654, 656, 632, 655, 655, 676, 665, 659, 657, 656, 684, 683, 666, 667, 668, 674, 675, 677, 678, 681, 676, 684, 689, 686, 683, 685, 691, 689, 692, 693, 697, 690, 696, 698, 699, 695, 697, 699], \"dsts\": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 41, 42, 44, 45, 45, 46, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 59, 60, 62, 63, 64, 64, 65, 67, 67, 68, 70, 70, 71, 73, 74, 74, 75, 77, 78, 79, 80, 80, 81, 83, 83, 84, 86, 87, 88, 89, 90, 91, 92, 92, 93, 95, 95, 96, 98, 99, 99, 100, 102, 102, 103, 105, 105, 106, 108, 108, 109, 111, 112, 112, 113, 115, 115, 116, 118, 118, 119, 121, 121, 122, 124, 124, 125, 127, 127, 128, 130, 131, 131, 132, 134, 134, 135, 137, 137, 138, 140, 140, 140, 141, 143, 143, 144, 146, 146, 147, 149, 149, 150, 152, 152, 153, 156, 156, 157, 157, 158, 160, 160, 161, 163, 163, 164, 166, 167, 167, 168, 170, 170, 171, 173, 173, 174, 176, 177, 177, 178, 180, 180, 181, 183, 183, 184, 186, 187, 187, 188, 190, 190, 191, 193, 193, 195, 196, 197, 197, 198, 200, 200, 202, 202, 203, 205, 205, 206, 206, 207, 209, 209, 210, 210, 211, 213, 213, 215, 216, 216, 217, 219, 219, 220, 222, 224, 225, 225, 226, 228, 228, 229, 231, 231, 232, 233, 233, 234, 236, 236, 237, 239, 239, 240, 242, 242, 243, 245, 245, 246, 248, 248, 249, 249, 250, 252, 252, 253, 255, 255, 256, 256, 257, 257, 258, 258, 259, 259, 260, 260, 261, 262, 263, 263, 264, 264, 266, 266, 267, 267, 268, 268, 269, 271, 271, 272, 273, 273, 273, 274, 276, 276, 277, 279, 279, 280, 282, 282, 283, 285, 285, 285, 286, 286, 287, 287, 288, 290, 290, 291, 291, 292, 293, 293, 294, 294, 295, 295, 296, 298, 298, 299, 299, 300, 300, 302, 302, 303, 303, 304, 304, 305, 305, 306, 308, 308, 309, 310, 310, 311, 311, 312, 312, 313, 313, 314, 314, 315, 315, 316, 316, 317, 317, 318, 319, 320, 320, 321, 321, 322, 322, 324, 324, 325, 327, 327, 328, 328, 329, 329, 330, 332, 332, 333, 
333, 334, 334, 335, 335, 336, 336, 337, 338, 338, 339, 339, 340, 340, 341, 341, 342, 342, 343, 345, 345, 346, 346, 347, 347, 348, 348, 348, 348, 348, 348, 349, 349, 350, 350, 350, 350, 350, 350, 351, 351, 351, 351, 351, 352, 352, 353, 353, 353, 353, 353, 354, 354, 355, 355, 355, 355, 355, 356, 356, 357, 357, 358, 358, 359, 359, 360, 362, 362, 362, 362, 362, 363, 363, 363, 363, 363, 364, 364, 365, 365, 366, 366, 366, 366, 366, 367, 367, 367, 367, 367, 368, 368, 368, 368, 368, 369, 369, 370, 370, 370, 370, 370, 371, 371, 372, 372, 373, 373, 374, 374, 375, 375, 375, 375, 375, 375, 376, 376, 376, 376, 376, 377, 377, 377, 377, 377, 378, 378, 378, 378, 378, 378, 379, 379, 379, 379, 379, 380, 380, 380, 380, 380, 381, 381, 381, 381, 381, 382, 382, 382, 382, 382, 383, 383, 384, 384, 384, 384, 384, 385, 385, 385, 385, 385, 386, 386, 386, 386, 386, 386, 387, 387, 388, 388, 388, 388, 388, 389, 389, 389, 389, 389, 390, 390, 390, 390, 390, 391, 391, 391, 391, 391, 392, 392, 392, 392, 392, 393, 393, 394, 394, 394, 394, 394, 394, 395, 395, 395, 395, 395, 396, 396, 396, 396, 396, 397, 397, 397, 397, 397, 397, 398, 398, 398, 398, 398, 398, 399, 399, 399, 399, 399, 400, 400, 400, 400, 400, 400, 401, 401, 401, 401, 401, 402, 402, 402, 402, 402, 403, 403, 403, 403, 403, 404, 404, 404, 404, 404, 404, 404, 405, 405, 405, 405, 405, 406, 406, 406, 406, 406, 407, 407, 408, 408, 408, 408, 408, 409, 409, 410, 410, 410, 410, 410, 410, 411, 411, 411, 411, 411, 412, 412, 412, 412, 412, 412, 413, 413, 413, 414, 414, 414, 414, 414, 415, 415, 415, 415, 415, 416, 416, 417, 417, 417, 417, 417, 418, 418, 418, 418, 418, 419, 419, 420, 420, 421, 421, 422, 422, 423, 423, 423, 423, 423, 423, 424, 424, 424, 424, 424, 425, 425, 425, 425, 425, 426, 426, 426, 426, 426, 426, 427, 427, 428, 428, 428, 428, 428, 428, 429, 429, 429, 429, 429, 429, 430, 430, 430, 430, 430, 431, 431, 431, 431, 431, 432, 432, 432, 432, 432, 433, 433, 434, 436, 436, 436, 436, 436, 437, 437, 437, 437, 437, 438, 438, 438, 438, 438, 439, 
439, 439, 439, 439, 440, 440, 440, 440, 440, 441, 441, 441, 441, 441, 442, 442, 442, 442, 442, 443, 443, 443, 443, 443, 444, 444, 444, 444, 444, 445, 445, 445, 445, 445, 445, 446, 446, 446, 446, 446, 447, 447, 447, 447, 447, 447, 448, 448, 448, 448, 448, 449, 449, 449, 449, 449, 449, 450, 450, 450, 450, 450, 451, 451, 451, 451, 451, 452, 452, 452, 452, 452, 453, 453, 453, 453, 453, 454, 454, 454, 454, 454, 455, 455, 456, 456, 456, 456, 456, 456, 457, 457, 457, 457, 457, 457, 458, 458, 458, 458, 458, 459, 459, 459, 459, 459, 460, 460, 460, 460, 460, 460, 461, 461, 461, 461, 461, 461, 462, 462, 462, 462, 462, 463, 463, 463, 463, 463, 464, 464, 464, 464, 464, 464, 465, 465, 465, 465, 465, 465, 466, 466, 466, 466, 466, 467, 467, 467, 467, 467, 468, 468, 468, 468, 468, 469, 469, 469, 469, 469, 469, 470, 470, 470, 470, 470, 471, 471, 471, 471, 471, 472, 472, 472, 472, 472, 473, 473, 474, 474, 474, 474, 474, 475, 475, 475, 475, 475, 476, 476, 476, 476, 476, 476, 477, 477, 477, 477, 477, 478, 478, 478, 478, 478, 479, 479, 479, 479, 479, 480, 480, 480, 480, 480, 481, 481, 481, 481, 481, 482, 482, 482, 482, 482, 483, 483, 484, 486, 486, 486, 486, 486, 486, 487, 487, 487, 487, 487, 487, 488, 488, 488, 488, 488, 489, 489, 489, 489, 489, 490, 490, 490, 490, 490, 491, 491, 491, 491, 491, 492, 492, 492, 492, 492, 493, 493, 493, 493, 493, 494, 494, 494, 494, 494, 495, 495, 495, 495, 495, 496, 496, 496, 496, 496, 497, 497, 497, 497, 497, 497, 498, 498, 498, 498, 498, 498, 499, 499, 499, 499, 499, 499, 500, 500, 500, 500, 500, 500, 501, 501, 501, 501, 501, 502, 502, 502, 502, 502, 503, 503, 503, 503, 503, 504, 504, 504, 504, 504, 505, 505, 505, 505, 505, 506, 506, 506, 506, 506, 507, 507, 507, 507, 507, 507, 508, 508, 508, 508, 508, 509, 509, 509, 509, 509, 510, 510, 510, 510, 510, 511, 511, 511, 511, 511, 511, 512, 513, 513, 514, 516, 516, 516, 516, 516, 517, 517, 517, 517, 517, 518, 518, 518, 518, 518, 518, 519, 519, 519, 519, 519, 520, 520, 520, 520, 520, 520, 521, 521, 521, 521, 
521, 522, 522, 522, 522, 522, 523, 523, 523, 523, 523, 523, 523, 524, 524, 524, 524, 524, 525, 525, 525, 525, 525, 526, 526, 526, 526, 526, 526, 527, 527, 527, 527, 527, 528, 528, 528, 528, 528, 529, 529, 529, 529, 529, 530, 530, 530, 530, 530, 531, 531, 531, 531, 531, 532, 532, 532, 532, 532, 533, 533, 533, 533, 533, 534, 534, 535, 535, 535, 535, 535, 536, 536, 536, 536, 536, 537, 537, 537, 537, 537, 537, 538, 538, 538, 538, 538, 539, 539, 539, 539, 539, 539, 540, 540, 540, 540, 540, 541, 541, 541, 541, 541, 542, 542, 542, 542, 542, 543, 543, 543, 543, 543, 544, 544, 546, 546, 546, 546, 546, 546, 546, 547, 547, 547, 547, 547, 547, 548, 548, 548, 548, 548, 548, 549, 549, 549, 549, 549, 550, 550, 550, 550, 550, 550, 551, 551, 551, 551, 551, 552, 552, 552, 552, 552, 553, 553, 553, 553, 553, 553, 553, 554, 554, 554, 554, 554, 554, 555, 555, 555, 555, 555, 555, 556, 556, 556, 556, 556, 557, 557, 557, 557, 557, 557, 558, 558, 558, 558, 558, 559, 559, 559, 559, 559, 560, 560, 560, 560, 560, 561, 561, 561, 561, 561, 562, 562, 563, 563, 563, 563, 563, 563, 564, 564, 564, 564, 564, 565, 565, 565, 565, 565, 565, 566, 566, 566, 566, 566, 567, 567, 567, 567, 567, 568, 568, 568, 568, 568, 568, 569, 569, 569, 569, 569, 570, 570, 570, 570, 570, 571, 571, 571, 571, 571, 571, 572, 572, 572, 572, 572, 573, 573, 573, 573, 573, 574, 574, 574, 574, 574, 575, 575, 575, 575, 575, 576, 576, 576, 576, 576, 576, 577, 577, 577, 577, 577, 578, 578, 578, 578, 578, 578, 579, 579, 579, 579, 579, 580, 580, 580, 580, 580, 581, 581, 581, 581, 581, 582, 582, 582, 582, 582, 583, 583, 583, 583, 583, 584, 584, 584, 584, 584, 584, 585, 585, 585, 585, 585, 586, 586, 586, 586, 586, 587, 587, 587, 587, 587, 588, 588, 588, 588, 588, 589, 589, 589, 589, 589, 590, 590, 590, 590, 590, 591, 591, 591, 591, 591, 592, 592, 592, 592, 592, 593, 593, 593, 593, 593, 593, 594, 594, 594, 594, 594, 595, 595, 595, 595, 595, 596, 596, 596, 596, 596, 596, 597, 597, 597, 597, 597, 598, 598, 598, 598, 598, 599, 599, 599, 599, 
599, 600, 600, 600, 600, 600, 601, 601, 602, 602, 602, 602, 602, 603, 603, 603, 603, 603, 604, 604, 604, 604, 604, 605, 605, 605, 605, 605, 606, 606, 606, 606, 606, 607, 607, 607, 607, 607, 607, 608, 608, 608, 608, 608, 608, 609, 609, 609, 609, 609, 610, 610, 610, 610, 610, 611, 611, 611, 611, 611, 612, 612, 612, 612, 612, 613, 613, 613, 613, 613, 613, 614, 614, 614, 614, 614, 614, 615, 615, 615, 615, 615, 615, 616, 616, 616, 616, 616, 616, 617, 617, 617, 617, 617, 617, 618, 618, 618, 618, 618, 619, 619, 619, 619, 619, 620, 620, 620, 620, 620, 621, 621, 621, 621, 621, 621, 622, 622, 622, 622, 622, 623, 623, 623, 623, 623, 624, 624, 624, 624, 624, 625, 625, 625, 625, 625, 626, 626, 626, 626, 626, 627, 627, 627, 627, 627, 627, 628, 628, 628, 628, 628, 629, 630, 630, 630, 630, 630, 631, 631, 631, 631, 631, 633, 633, 633, 633, 633, 634, 634, 634, 634, 634, 634, 635, 635, 635, 635, 635, 635, 636, 636, 636, 636, 636, 637, 637, 637, 637, 637, 638, 638, 638, 638, 638, 638, 639, 639, 639, 639, 639, 639, 640, 640, 640, 640, 640, 640, 641, 641, 641, 641, 641, 641, 642, 642, 642, 642, 642, 642, 643, 643, 643, 643, 643, 643, 644, 644, 644, 644, 644, 645, 645, 645, 645, 645, 646, 646, 646, 646, 646, 647, 647, 647, 647, 647, 648, 648, 648, 648, 648, 648, 649, 649, 649, 649, 649, 650, 650, 650, 650, 650, 651, 651, 651, 651, 651, 652, 652, 652, 652, 652, 653, 653, 653, 653, 653, 654, 654, 655, 655, 655, 655, 655, 656, 657, 657, 657, 657, 657, 658, 658, 658, 658, 658, 659, 659, 659, 659, 659, 660, 660, 660, 660, 660, 660, 661, 661, 661, 661, 661, 662, 662, 662, 662, 662, 663, 663, 663, 663, 663, 664, 664, 664, 664, 664, 665, 665, 666, 666, 666, 666, 666, 668, 668, 668, 668, 668, 669, 669, 669, 669, 669, 670, 670, 670, 670, 670, 671, 671, 671, 671, 671, 672, 672, 672, 672, 672, 673, 673, 673, 673, 673, 673, 674, 674, 675, 675, 675, 675, 675, 675, 676, 676, 676, 676, 676, 676, 677, 677, 677, 677, 677, 678, 678, 678, 678, 678, 679, 679, 679, 679, 679, 679, 680, 680, 680, 680, 680, 681, 
681, 681, 681, 681, 682, 682, 682, 682, 682, 683, 683, 683, 683, 683, 684, 684, 684, 684, 684, 684, 685, 685, 685, 685, 685, 686, 686, 686, 686, 686, 687, 687, 687, 687, 687, 688, 688, 688, 688, 688, 689, 689, 689, 689, 689, 689, 690, 690, 690, 690, 690, 690, 691, 691, 691, 691, 691, 691, 692, 692, 692, 692, 692, 693, 693, 693, 693, 693, 694, 694, 694, 694, 694, 694, 695, 696, 697, 697, 698, 698], \"num_nodes\": 700, \"nlabels\": [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 3, 2, 1, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 3, 2, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 1, 2, 3, 2, 2, 2, 1, 2, 3, 1, 2, 3, 2, 2, 2, 2, 2, 2, 1, 2, 3, 1, 2, 3, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 3, 1, 1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 1, 2, 3, 1, 2, 3, 1, 3, 2, 2, 1, 2, 3, 1, 3, 1, 2, 3, 1, 1, 2, 3, 1, 1, 2, 3, 1, 3, 2, 1, 2, 3, 1, 2, 3, 2, 3, 2, 1, 2, 3, 1, 2, 3, 1, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 1, 2, 3, 1, 2, 3, 1, 1, 1, 2, 1, 1, 3, 2, 1, 2, 3, 1, 1, 1, 2, 3, 1, 2, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 1, 1, 2, 3, 1, 1, 2, 1, 1, 1, 2, 3, 1, 1, 1, 3, 1, 1, 1, 1, 2, 3, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 3, 1, 2, 3, 1, 1, 1, 2, 3, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 3, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 2, 3, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 2, 3, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 1, 1, 3], \"num_nlabel_types\": 4, \"eweights\": {}}, \"success\": true}"
  },
  {
    "path": "examples/pytorch/gnn_explainer/gnn_subgraph/1/model_list.json",
    "content": "{\"models\": [], \"success\": true}"
  },
  {
    "path": "examples/pytorch/gnn_explainer/gnn_subgraph/1/subgraph_1.json",
    "content": "{\"name\": \"GNNExplainer\", \"success\": true, \"node_subgraphs\": {\"0\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"1\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"2\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"3\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"4\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"5\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"6\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"7\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"8\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"9\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"10\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"11\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"12\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"13\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"14\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"15\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"16\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"17\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"18\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"19\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"20\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"21\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"22\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"23\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"24\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"25\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"26\": {\"nodes\": [], 
\"nweight\": [], \"eids\": [], \"eweight\": []}, \"27\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"28\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"29\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"30\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"31\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"32\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"33\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"34\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"35\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"36\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"37\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"38\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"39\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"40\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"41\": {\"nodes\": [41, 42, 43, 205], \"nweight\": [1.0, 1.0, 1.0, 1.0], \"eids\": [41, 42, 206, 207], \"eweight\": [0.7634373307228088, 0.5627130270004272, 0.5729275941848755, 0.3973456621170044]}, \"42\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"43\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"44\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"45\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"46\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"47\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"48\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"49\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"50\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"51\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": 
[]}, \"52\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"53\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"54\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"55\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"56\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"57\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"58\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"59\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"60\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"61\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"62\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"63\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"64\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"65\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"66\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"67\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"68\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"69\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"70\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"71\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"72\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"73\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"74\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"75\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"76\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"77\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"78\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"79\": {\"nodes\": [], 
\"nweight\": [], \"eids\": [], \"eweight\": []}, \"80\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"81\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"82\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"83\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"84\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"85\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"86\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"87\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"88\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"89\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"90\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"91\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"92\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"93\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"94\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"95\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"96\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"97\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"98\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"99\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"100\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"101\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"102\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"103\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"104\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"105\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"106\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"107\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"108\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"109\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"110\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"111\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"112\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"113\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"114\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"115\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"116\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"117\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"118\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"119\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"120\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"121\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"122\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"123\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"124\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"125\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"126\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"127\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"128\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"129\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"130\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"131\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"132\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"133\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"134\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"135\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"136\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"137\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"138\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"139\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"140\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"141\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"142\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"143\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"144\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"145\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"146\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"147\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"148\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"149\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"150\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"151\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"152\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"153\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"154\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"155\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"156\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"157\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"158\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"159\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"160\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"161\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"162\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"163\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"164\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"165\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"166\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"167\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"168\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"169\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"170\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"171\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"172\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"173\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"174\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"175\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"176\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"177\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"178\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"179\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"180\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"181\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"182\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"183\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"184\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"185\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"186\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"187\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"188\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"189\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"190\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"191\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"192\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"193\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"194\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"195\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"196\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"197\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"198\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"199\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"200\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"201\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"202\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"203\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"204\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"205\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"206\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"207\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"208\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"209\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"210\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"211\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"212\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"213\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"214\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"215\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"216\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"217\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"218\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"219\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"220\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"221\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"222\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"223\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"224\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"225\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"226\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"227\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"228\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"229\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"230\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"231\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"232\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"233\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"234\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"235\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"236\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"237\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"238\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"239\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"240\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"241\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"242\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"243\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"244\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"245\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"246\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"247\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"248\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"249\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"250\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"251\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"252\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"253\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"254\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"255\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"256\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"257\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"258\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"259\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"260\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"261\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"262\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"263\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"264\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"265\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"266\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"267\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"268\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"269\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"270\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"271\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"272\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"273\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"274\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"275\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"276\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"277\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"278\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"279\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"280\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"281\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"282\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"283\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"284\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"285\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"286\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"287\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"288\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"289\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"290\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"291\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"292\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"293\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"294\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"295\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"296\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"297\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"298\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"299\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"300\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"301\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"302\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"303\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"304\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"305\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"306\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"307\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"308\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"309\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"310\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"311\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"312\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"313\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"314\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"315\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"316\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"317\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"318\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"319\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"320\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"321\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"322\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"323\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"324\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"325\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"326\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"327\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"328\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"329\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"330\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"331\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"332\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"333\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"334\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"335\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"336\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"337\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"338\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"339\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"340\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"341\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"342\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"343\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"344\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"345\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"346\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"347\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"348\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"349\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"350\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"351\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"352\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"353\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"354\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"355\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"356\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"357\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"358\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"359\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"360\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"361\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"362\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"363\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"364\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"365\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"366\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"367\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"368\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"369\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"370\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"371\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"372\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"373\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"374\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"375\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"376\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"377\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"378\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"379\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"380\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"381\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"382\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"383\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"384\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"385\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"386\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"387\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"388\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"389\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"390\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"391\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"392\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"393\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"394\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"395\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"396\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"397\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"398\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"399\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"400\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"401\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"402\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"403\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"404\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"405\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"406\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"407\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"408\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"409\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"410\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"411\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"412\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"413\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"414\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"415\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"416\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"417\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"418\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"419\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"420\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"421\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"422\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"423\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"424\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"425\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"426\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"427\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"428\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"429\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"430\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"431\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"432\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"433\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"434\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"435\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"436\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"437\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"438\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"439\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"440\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"441\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"442\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"443\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"444\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"445\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"446\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"447\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"448\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"449\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"450\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"451\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"452\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"453\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"454\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"455\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"456\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"457\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"458\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"459\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"460\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"461\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"462\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"463\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"464\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"465\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"466\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"467\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"468\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"469\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"470\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"471\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"472\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"473\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"474\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"475\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"476\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"477\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"478\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"479\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"480\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"481\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"482\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"483\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"484\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"485\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"486\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"487\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"488\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"489\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"490\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"491\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"492\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"493\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"494\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"495\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"496\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"497\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"498\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"499\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"500\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"501\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"502\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"503\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"504\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"505\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"506\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"507\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"508\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"509\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"510\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"511\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"512\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"513\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"514\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"515\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"516\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"517\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"518\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"519\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"520\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"521\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"522\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"523\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"524\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"525\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"526\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"527\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"528\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"529\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"530\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"531\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"532\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"533\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"534\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"535\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"536\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"537\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"538\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"539\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"540\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"541\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"542\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"543\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"544\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"545\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"546\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"547\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"548\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"549\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"550\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"551\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"552\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"553\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"554\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"555\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"556\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"557\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"558\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"559\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"560\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"561\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"562\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"563\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"564\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"565\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"566\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"567\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"568\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"569\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"570\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"571\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"572\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"573\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"574\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"575\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"576\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"577\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"578\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"579\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"580\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"581\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"582\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"583\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"584\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"585\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"586\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"587\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"588\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"589\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"590\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"591\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"592\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"593\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"594\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"595\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"596\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"597\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"598\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"599\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"600\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"601\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"602\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"603\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"604\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"605\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"606\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"607\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"608\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"609\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"610\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"611\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"612\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"613\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"614\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"615\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"616\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"617\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"618\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"619\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"620\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"621\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"622\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"623\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"624\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"625\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"626\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"627\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"628\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"629\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"630\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"631\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"632\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"633\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"634\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"635\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"636\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"637\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"638\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"639\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"640\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"641\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"642\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"643\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"644\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"645\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"646\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"647\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"648\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"649\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"650\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"651\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"652\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"653\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"654\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"655\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"656\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"657\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"658\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"659\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"660\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"661\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"662\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"663\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"664\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"665\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"666\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"667\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"668\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"669\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"670\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"671\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"672\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"673\": {\"nodes\": [], \"nweight\": [], 
\"eids\": [], \"eweight\": []}, \"674\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"675\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"676\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"677\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"678\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"679\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"680\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"681\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"682\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"683\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"684\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"685\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"686\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"687\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"688\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"689\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"690\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"691\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"692\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"693\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"694\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"695\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"696\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"697\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"698\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}, \"699\": {\"nodes\": [], \"nweight\": [], \"eids\": [], \"eweight\": []}}}"
  },
  {
    "path": "examples/pytorch/gnn_explainer/gnn_subgraph/1/subgraph_list.json",
    "content": "{\"subgraphs\": [{\"id\": 1, \"name\": \"GNNExplainer\"}], \"success\": true}"
  },
  {
    "path": "examples/pytorch/gnn_explainer/gnn_subgraph/dataset_list.json",
    "content": "{\"datasets\": [{\"id\": 1, \"name\": \"BAShape\"}], \"success\": true}"
  },
  {
    "path": "examples/pytorch/gnn_explainer/models.py",
    "content": "import dgl.function as fn\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Layer(nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super().__init__()\n        self.layer = nn.Linear(in_dim * 2, out_dim, bias=True)\n\n    def forward(self, graph, feat, eweight=None):\n        with graph.local_scope():\n            graph.ndata[\"h\"] = feat\n\n            if eweight is None:\n                graph.update_all(fn.copy_u(\"h\", \"m\"), fn.mean(\"m\", \"h\"))\n            else:\n                graph.edata[\"ew\"] = eweight\n                graph.update_all(fn.u_mul_e(\"h\", \"ew\", \"m\"), fn.mean(\"m\", \"h\"))\n\n            h = self.layer(th.cat([graph.ndata[\"h\"], feat], dim=-1))\n\n            return h\n\n\nclass Model(nn.Module):\n    def __init__(self, in_dim, out_dim, hid_dim=40):\n        super().__init__()\n        self.in_layer = Layer(in_dim, hid_dim)\n        self.hid_layer = Layer(hid_dim, hid_dim)\n        self.out_layer = Layer(hid_dim, out_dim)\n\n    def forward(self, graph, feat, eweight=None):\n        h = self.in_layer(graph, feat.float(), eweight)\n        h = F.relu(h)\n        h = self.hid_layer(graph, h, eweight)\n        h = F.relu(h)\n        h = self.out_layer(graph, h, eweight)\n        return h\n"
  },
  {
    "path": "examples/pytorch/gnn_explainer/train_main.py",
    "content": "import argparse\nimport os\n\nimport torch as th\nimport torch.nn as nn\n\nfrom dgl import save_graphs\n\nfrom dgl.data import (\n    BACommunityDataset,\n    BAShapeDataset,\n    TreeCycleDataset,\n    TreeGridDataset,\n)\nfrom models import Model\n\n\ndef main(args):\n    if args.dataset == \"BAShape\":\n        dataset = BAShapeDataset(seed=0)\n    elif args.dataset == \"BACommunity\":\n        dataset = BACommunityDataset(seed=0)\n    elif args.dataset == \"TreeCycle\":\n        dataset = TreeCycleDataset(seed=0)\n    elif args.dataset == \"TreeGrid\":\n        dataset = TreeGridDataset(seed=0)\n\n    graph = dataset[0]\n    labels = graph.ndata[\"label\"]\n    n_feats = graph.ndata[\"feat\"]\n    num_classes = dataset.num_classes\n\n    model = Model(n_feats.shape[-1], num_classes)\n    loss_fn = nn.CrossEntropyLoss()\n    optim = th.optim.Adam(model.parameters(), lr=0.001)\n\n    for epoch in range(500):\n        model.train()\n        # For demo purpose, we train the model on all datapoints\n        # In practice, you should train only on the training datapoints\n        logits = model(graph, n_feats)\n        loss = loss_fn(logits, labels)\n        acc = th.sum(logits.argmax(dim=1) == labels).item() / len(labels)\n\n        optim.zero_grad()\n        loss.backward()\n        optim.step()\n\n        print(f\"In Epoch: {epoch}; Acc: {acc}; Loss: {loss.item()}\")\n\n    model_stat_dict = model.state_dict()\n    model_path = os.path.join(\"./\", f\"model_{args.dataset}.pth\")\n    th.save(model_stat_dict, model_path)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Dummy model training\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"BAShape\",\n        choices=[\"BAShape\", \"BACommunity\", \"TreeCycle\", \"TreeGrid\"],\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/grace/README.md",
    "content": "# DGL Implementation of GRACE\nThis DGL example implements the model proposed in the paper [Deep Graph Contrastive Representation Learning](https://arxiv.org/abs/2006.04131).\n\nAuthor's code: https://github.com/CRIPAC-DIG/GRACE\n\n## Example Implementor\n\nThis example was implemented by [Hengrui Zhang](https://github.com/hengruizhang98) when he was an applied scientist intern at AWS Shanghai AI Lab.\n\n## Dependencies\n\n- Python 3.7\n- PyTorch 1.7.1\n- dgl 0.6.0\n- scikit-learn 0.22.1\n\n## Datasets\n\n##### Unsupervised Node Classification Datasets:\n\n'Cora', 'Citeseer' and 'Pubmed'\n\n| Dataset  | # Nodes | # Edges | # Classes |\n| -------- | ------- | ------- | --------- |\n| Cora     | 2,708   | 10,556  | 7         |\n| Citeseer | 3,327   | 9,228   | 6         |\n| Pubmed   | 19,717  | 88,651  | 3         |\n\n\n## Arguments\n\n```\n--dataname         str     The graph dataset name.                Default is 'cora'.\n--gpu              int     GPU index.                             Default is 0.\n--split            str     Dataset splitting method.              Default is 'random'.\n--epochs           int     Number of training periods.            Default is 500.\n--lr               float   Learning rate.                         Default is 0.001.\n--wd               float   Weight decay.                          Default is 1e-5.\n--temp             float   Temperature.                           Default is 1.0.\n--act_fn           str     Activation function.                   Default is relu.\n--hid_dim          int     Hidden dimension.                      Default is 256.\n--out_dim          int     Output dimension.                      Default is 256.\n--num_layers       int     Number of GNN layers.                  Default is 2.\n--der1             float   Drop edge ratio 1.                     Default is 0.2. \n--der2             float   Drop edge ratio 2.                     Default is 0.2. 
\n--dfr1             float   Drop feature ratio 1.                  Default is 0.2. \n--dfr2             float   Drop feature ratio 2.                  Default is 0.2. \n```\n\n## How to run examples\n\nIn the paper (as well as the authors' repo), the training set and testing set are split randomly with a 1:9 ratio. In order to fairly compare it with other methods with the public split (20 training nodes per class), in this repo we also provide its results using the public split (with fine-tuned hyper-parameters). To run the examples, follow the instructions below.\n\n```python\n# Cora with random split\npython main.py --dataname cora --epochs 200 --lr 5e-4 --wd 1e-5 --hid_dim 128 --out_dim 128 --act_fn relu --der1 0.2 --der2 0.4 --dfr1 0.3 --dfr2 0.4 --temp 0.4\n\n# Cora with public split\npython main.py --dataname cora --split public --epochs 400 --lr 5e-4 --wd 1e-5 --hid_dim 256 --out_dim 256 --act_fn relu --der1 0.3 --der2 0.4 --dfr1 0.3 --dfr2 0.4 --temp 0.4\n\n# Citeseer with random split\npython main.py --dataname citeseer --epochs 200 --lr 1e-3 --wd 1e-5 --hid_dim 256 --out_dim 256 --act_fn prelu --der1 0.2 --der2 0.0 --dfr1 0.3 --dfr2 0.2 --temp 0.9\n\n# Citeseer with public split\npython main.py --dataname citeseer --split public --epochs 100 --lr 1e-3 --wd 1e-5 --hid_dim 512 --out_dim 512 --act_fn prelu --der1 0.3 --der2 0.3 --dfr1 0.3 --dfr2 0.3 --temp 0.4\n\n# Pubmed with random split\npython main.py --dataname pubmed --epochs 1500 --lr 1e-3 --wd 1e-5 --hid_dim 256 --out_dim 256 --act_fn relu --der1 0.4 --der2 0.1 --dfr1 0.0 --dfr2 0.2 --temp 0.7\n\n# Pubmed with public split\npython main.py --dataname pubmed --split public --epochs 1500 --lr 1e-3 --wd 1e-5 --hid_dim 256 --out_dim 256 --act_fn relu --der1 0.4 --der2 0.1 --dfr1 0.0 --dfr2 0.2 --temp 0.7\n```\n\n## Performance\n\nFor random split, we use the hyper-parameters as stated in the paper. 
For public split, we find that the given hyper-parameters lead to poor performance, so we select the hyper-parameters via a small grid search.\n\nRandom split (Train/Test = 1:9)\n\n|      Dataset      | Cora | Citeseer | Pubmed |\n| :---------------: | :--: | :------: | :----: |\n| Accuracy Reported | 83.3 |   72.1   |  86.7  |\n|   Author's Code   | 83.1 |   71.0   |  86.3  |\n|        DGL        | 83.4 |   71.4   |  86.1  |\n\nPublic split\n\n|    Dataset    | Cora | Citeseer | Pubmed |\n| :-----------: | :--: | :------: | :----: |\n| Author's Code | 81.9 |   71.2   |  80.6  |\n|      DGL      | 82.2 |   71.4   |  80.2  |\n\n"
  },
  {
    "path": "examples/pytorch/grace/aug.py",
    "content": "# Data augmentation on graphs via edge dropping and feature masking\n\nimport dgl\nimport numpy as np\nimport torch as th\n\n\ndef aug(graph, x, feat_drop_rate, edge_mask_rate):\n    n_node = graph.num_nodes()\n\n    edge_mask = mask_edge(graph, edge_mask_rate)\n    feat = drop_feature(x, feat_drop_rate)\n\n    src = graph.edges()[0]\n    dst = graph.edges()[1]\n\n    nsrc = src[edge_mask]\n    ndst = dst[edge_mask]\n\n    ng = dgl.graph((nsrc, ndst), num_nodes=n_node)\n    ng = ng.add_self_loop()\n\n    return ng, feat\n\n\ndef drop_feature(x, drop_prob):\n    drop_mask = (\n        th.empty((x.size(1),), dtype=th.float32, device=x.device).uniform_(0, 1)\n        < drop_prob\n    )\n    x = x.clone()\n    x[:, drop_mask] = 0\n\n    return x\n\n\ndef mask_edge(graph, mask_prob):\n    E = graph.num_edges()\n\n    mask_rates = th.FloatTensor(np.ones(E) * mask_prob)\n    masks = th.bernoulli(1 - mask_rates)\n    mask_idx = masks.nonzero().squeeze(1)\n    return mask_idx\n"
  },
  {
    "path": "examples/pytorch/grace/dataset.py",
    "content": "from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\n\n\ndef load(name):\n    if name == \"cora\":\n        dataset = CoraGraphDataset()\n    elif name == \"citeseer\":\n        dataset = CiteseerGraphDataset()\n    elif name == \"pubmed\":\n        dataset = PubmedGraphDataset()\n\n    graph = dataset[0]\n\n    train_mask = graph.ndata.pop(\"train_mask\")\n    test_mask = graph.ndata.pop(\"test_mask\")\n\n    feat = graph.ndata.pop(\"feat\")\n    labels = graph.ndata.pop(\"label\")\n\n    return graph, feat, labels, train_mask, test_mask\n"
  },
  {
    "path": "examples/pytorch/grace/eval.py",
    "content": "\"\"\"\nCode adapted from https://github.com/CRIPAC-DIG/GRACE\nLinear evaluation on learned node embeddings\n\"\"\"\n\nimport functools\n\nimport numpy as np\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.metrics import f1_score\nfrom sklearn.model_selection import GridSearchCV, train_test_split\nfrom sklearn.multiclass import OneVsRestClassifier\nfrom sklearn.preprocessing import normalize, OneHotEncoder\n\n\ndef repeat(n_times):\n    def decorator(f):\n        @functools.wraps(f)\n        def wrapper(*args, **kwargs):\n            results = [f(*args, **kwargs) for _ in range(n_times)]\n            statistics = {}\n            for key in results[0].keys():\n                values = [r[key] for r in results]\n                statistics[key] = {\n                    \"mean\": np.mean(values),\n                    \"std\": np.std(values),\n                }\n            print_statistics(statistics, f.__name__)\n            return statistics\n\n        return wrapper\n\n    return decorator\n\n\ndef prob_to_one_hot(y_pred):\n    ret = np.zeros(y_pred.shape, np.bool_)\n    indices = np.argmax(y_pred, axis=1)\n    for i in range(y_pred.shape[0]):\n        ret[i][indices[i]] = True\n    return ret\n\n\ndef print_statistics(statistics, function_name):\n    print(f\"(E) | {function_name}:\", end=\" \")\n    for i, key in enumerate(statistics.keys()):\n        mean = statistics[key][\"mean\"]\n        std = statistics[key][\"std\"]\n        print(f\"{key}={mean:.4f}+-{std:.4f}\", end=\"\")\n        if i != len(statistics.keys()) - 1:\n            print(\",\", end=\" \")\n        else:\n            print()\n\n\n@repeat(3)\ndef label_classification(\n    embeddings, y, train_mask, test_mask, split=\"random\", ratio=0.1\n):\n    X = embeddings.detach().cpu().numpy()\n    Y = y.detach().cpu().numpy()\n    Y = Y.reshape(-1, 1)\n    onehot_encoder = OneHotEncoder(categories=\"auto\").fit(Y)\n    Y = 
onehot_encoder.transform(Y).toarray().astype(np.bool_)\n\n    X = normalize(X, norm=\"l2\")\n\n    if split == \"random\":\n        X_train, X_test, y_train, y_test = train_test_split(\n            X, Y, test_size=1 - ratio\n        )\n    elif split == \"public\":\n        X_train = X[train_mask]\n        X_test = X[test_mask]\n        y_train = Y[train_mask]\n        y_test = Y[test_mask]\n\n    logreg = LogisticRegression(solver=\"liblinear\")\n    c = 2.0 ** np.arange(-10, 10)\n\n    clf = GridSearchCV(\n        estimator=OneVsRestClassifier(logreg),\n        param_grid=dict(estimator__C=c),\n        n_jobs=8,\n        cv=5,\n        verbose=0,\n    )\n    clf.fit(X_train, y_train)\n\n    y_pred = clf.predict_proba(X_test)\n    y_pred = prob_to_one_hot(y_pred)\n\n    micro = f1_score(y_test, y_pred, average=\"micro\")\n    macro = f1_score(y_test, y_pred, average=\"macro\")\n\n    return {\"F1Mi\": micro, \"F1Ma\": macro}\n"
  },
  {
    "path": "examples/pytorch/grace/main.py",
    "content": "import argparse\nimport warnings\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nfrom aug import aug\nfrom dataset import load\nfrom eval import label_classification\nfrom model import Grace\n\nwarnings.filterwarnings(\"ignore\")\n\n\ndef count_parameters(model):\n    return sum(\n        [np.prod(p.size()) for p in model.parameters() if p.requires_grad]\n    )\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--dataname\", type=str, default=\"cora\")\nparser.add_argument(\"--gpu\", type=int, default=0)\nparser.add_argument(\"--split\", type=str, default=\"random\")\n\nparser.add_argument(\n    \"--epochs\", type=int, default=500, help=\"Number of training periods.\"\n)\nparser.add_argument(\"--lr\", type=float, default=0.001, help=\"Learning rate.\")\nparser.add_argument(\"--wd\", type=float, default=1e-5, help=\"Weight decay.\")\nparser.add_argument(\"--temp\", type=float, default=1.0, help=\"Temperature.\")\n\nparser.add_argument(\"--act_fn\", type=str, default=\"relu\")\n\nparser.add_argument(\n    \"--hid_dim\", type=int, default=256, help=\"Hidden layer dim.\"\n)\nparser.add_argument(\n    \"--out_dim\", type=int, default=256, help=\"Output layer dim.\"\n)\n\nparser.add_argument(\n    \"--num_layers\", type=int, default=2, help=\"Number of GNN layers.\"\n)\nparser.add_argument(\n    \"--der1\",\n    type=float,\n    default=0.2,\n    help=\"Drop edge ratio of the 1st augmentation.\",\n)\nparser.add_argument(\n    \"--der2\",\n    type=float,\n    default=0.2,\n    help=\"Drop edge ratio of the 2nd augmentation.\",\n)\nparser.add_argument(\n    \"--dfr1\",\n    type=float,\n    default=0.2,\n    help=\"Drop feature ratio of the 1st augmentation.\",\n)\nparser.add_argument(\n    \"--dfr2\",\n    type=float,\n    default=0.2,\n    help=\"Drop feature ratio of the 2nd augmentation.\",\n)\n\nargs = parser.parse_args()\n\nif args.gpu != -1 and th.cuda.is_available():\n    args.device = 
\"cuda:{}\".format(args.gpu)\nelse:\n    args.device = \"cpu\"\n\nif __name__ == \"__main__\":\n    # Step 1: Load hyperparameters =================================================================== #\n    lr = args.lr\n    hid_dim = args.hid_dim\n    out_dim = args.out_dim\n\n    num_layers = args.num_layers\n    act_fn = ({\"relu\": nn.ReLU(), \"prelu\": nn.PReLU()})[args.act_fn]\n\n    drop_edge_rate_1 = args.der1\n    drop_edge_rate_2 = args.der2\n    drop_feature_rate_1 = args.dfr1\n    drop_feature_rate_2 = args.dfr2\n\n    temp = args.temp\n    epochs = args.epochs\n    wd = args.wd\n\n    # Step 2: Prepare data =================================================================== #\n    graph, feat, labels, train_mask, test_mask = load(args.dataname)\n    in_dim = feat.shape[1]\n\n    # Step 3: Create model =================================================================== #\n    model = Grace(in_dim, hid_dim, out_dim, num_layers, act_fn, temp)\n    model = model.to(args.device)\n    print(f\"# params: {count_parameters(model)}\")\n\n    optimizer = th.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n\n    # Step 4: Training =======================================================================\n    for epoch in range(epochs):\n        model.train()\n        optimizer.zero_grad()\n        graph1, feat1 = aug(graph, feat, drop_feature_rate_1, drop_edge_rate_1)\n        graph2, feat2 = aug(graph, feat, drop_feature_rate_2, drop_edge_rate_2)\n\n        graph1 = graph1.to(args.device)\n        graph2 = graph2.to(args.device)\n\n        feat1 = feat1.to(args.device)\n        feat2 = feat2.to(args.device)\n\n        loss = model(graph1, graph2, feat1, feat2)\n        loss.backward()\n        optimizer.step()\n\n        print(f\"Epoch={epoch:03d}, loss={loss.item():.4f}\")\n\n    # Step 5: Linear evaluation ============================================================== #\n    print(\"=== Final ===\")\n\n    graph = graph.add_self_loop()\n    graph = 
graph.to(args.device)\n    feat = feat.to(args.device)\n    embeds = model.get_embedding(graph, feat)\n\n    \"\"\"Evaluation Embeddings  \"\"\"\n    label_classification(\n        embeds, labels, train_mask, test_mask, split=args.split\n    )\n"
  },
  {
    "path": "examples/pytorch/grace/model.py",
    "content": "import torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl.nn import GraphConv\n\n\n# Multi-layer Graph Convolutional Networks\nclass GCN(nn.Module):\n    def __init__(self, in_dim, out_dim, act_fn, num_layers=2):\n        super(GCN, self).__init__()\n\n        assert num_layers >= 2\n        self.num_layers = num_layers\n        self.convs = nn.ModuleList()\n\n        self.convs.append(GraphConv(in_dim, out_dim * 2))\n        for _ in range(self.num_layers - 2):\n            self.convs.append(GraphConv(out_dim * 2, out_dim * 2))\n\n        self.convs.append(GraphConv(out_dim * 2, out_dim))\n        self.act_fn = act_fn\n\n    def forward(self, graph, feat):\n        for i in range(self.num_layers):\n            feat = self.act_fn(self.convs[i](graph, feat))\n\n        return feat\n\n\n# Multi-layer(2-layer) Perceptron\nclass MLP(nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super(MLP, self).__init__()\n        self.fc1 = nn.Linear(in_dim, out_dim)\n        self.fc2 = nn.Linear(out_dim, in_dim)\n\n    def forward(self, x):\n        z = F.elu(self.fc1(x))\n        return self.fc2(z)\n\n\nclass Grace(nn.Module):\n    r\"\"\"\n        GRACE model\n    Parameters\n    -----------\n    in_dim: int\n        Input feature size.\n    hid_dim: int\n        Hidden feature size.\n    out_dim: int\n        Output feature size.\n    num_layers: int\n        Number of the GNN encoder layers.\n    act_fn: nn.Module\n        Activation function.\n    temp: float\n        Temperature constant.\n    \"\"\"\n\n    def __init__(self, in_dim, hid_dim, out_dim, num_layers, act_fn, temp):\n        super(Grace, self).__init__()\n        self.encoder = GCN(in_dim, hid_dim, act_fn, num_layers)\n        self.temp = temp\n        self.proj = MLP(hid_dim, out_dim)\n\n    def sim(self, z1, z2):\n        # normalize embeddings across feature dimension\n        z1 = F.normalize(z1)\n        z2 = F.normalize(z2)\n\n        s = th.mm(z1, 
z2.t())\n        return s\n\n    def get_loss(self, z1, z2):\n        # calculate SimCLR loss\n        f = lambda x: th.exp(x / self.temp)\n\n        refl_sim = f(self.sim(z1, z1))  # intra-view pairs\n        between_sim = f(self.sim(z1, z2))  # inter-view pairs\n\n        # between_sim.diag(): positive pairs\n        x1 = refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag()\n        loss = -th.log(between_sim.diag() / x1)\n\n        return loss\n\n    def get_embedding(self, graph, feat):\n        # get embeddings from the model for evaluation\n        h = self.encoder(graph, feat)\n\n        return h.detach()\n\n    def forward(self, graph1, graph2, feat1, feat2):\n        # encoding\n        h1 = self.encoder(graph1, feat1)\n        h2 = self.encoder(graph2, feat2)\n\n        # projection\n        z1 = self.proj(h1)\n        z2 = self.proj(h2)\n\n        # get loss\n        l1 = self.get_loss(z1, z2)\n        l2 = self.get_loss(z2, z1)\n\n        ret = (l1 + l2) * 0.5\n\n        return ret.mean()\n"
  },
  {
    "path": "examples/pytorch/grand/README.md",
    "content": "# Graph Random Neural Network(GRAND)\n\nThis DGL example implements the GNN model proposed in the paper [Graph Random Neural Network for Semi-Supervised Learning on Graphs]( https://arxiv.org/abs/2005.11079).\n\nAuthor's code: https://github.com/THUDM/GRAND\n\n## Example Implementor\n\nThis example was implemented by [Hengrui Zhang](https://github.com/hengruizhang98) when he was an applied scientist intern at AWS Shanghai AI Lab.\n\n## Dependencies\n- Python 3.7\n- PyTorch 1.7.1\n- dgl 0.5.3\n\n## Dataset\n\nThe DGL's built-in Cora, Pubmed and Citeseer datasets. Dataset summary:\n\n| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |\n| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |\n| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1000 |\n| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1000 |\n| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1000 |\n\n## Arguments\n\n###### Dataset options\n```\n--dataname          str     The graph dataset name.             Default is 'cora'.\n```\n\n###### GPU options\n```\n--gpu              int     GPU index.                          Default is -1, using CPU.\n```\n\n###### Model options\n```\n--epochs           int     Number of training epochs.             Default is 2000.\n--early_stopping   int     Early stopping patience rounds.        Default is 200.\n--lr               float   Adam optimizer learning rate.          Default is 0.01.\n--weight_decay     float   L2 regularization coefficient.         Default is 5e-4.\n--dropnode_rate    float   Dropnode rate (1 - keep probability).  Default is 0.5.\n--input_droprate   float   Dropout rate of input layer.           Default is 0.5.\n--hidden_droprate  float   Dropout rate of hidden layer.          Default is 0.5.\n--hid_dim          int     Hidden layer dimensionalities.         Default is 32.\n--order            int     Propagation step.                      
Default is 8.\n--sample           int     Sampling times of dropnode.            Default is 4.\n--tem              float   Sharpening temperature.                Default is 0.5.\n--lam              float   Coefficient of consistency reg.        Default is 1.0.\n--use_bn           bool    Using batch normalization.             Default is False.\n```\n\n## Examples\n\nTrain a model which follows the original hyperparameters on different datasets.\n```bash\n# Cora:\npython main.py --dataname cora --gpu 0 --lam 1.0 --tem 0.5 --order 8 --sample 4 --input_droprate 0.5 --hidden_droprate 0.5 --dropnode_rate 0.5 --hid_dim 32 --early_stopping 100 --lr 1e-2  --epochs 2000\n# Citeseer:\npython main.py --dataname citeseer --gpu 0 --lam 0.7 --tem 0.3 --order 2 --sample 2 --input_droprate 0.0 --hidden_droprate 0.2 --dropnode_rate 0.5 --hid_dim 32 --early_stopping 100 --lr 1e-2  --epochs 2000\n# Pubmed:\npython main.py --dataname pubmed --gpu 0 --lam 1.0 --tem 0.2 --order 5 --sample 4 --input_droprate 0.6 --hidden_droprate 0.8 --dropnode_rate 0.5 --hid_dim 32 --early_stopping 200 --lr 0.2 --epochs 2000 --use_bn\n```\n\n### Performance\n\nThe hyperparameter setting in our implementation is identical to that reported in the paper.\n\n| Dataset | Cora | Citeseer | Pubmed |\n| :-: | :-: | :-: | :-: |\n| Accuracy Reported(100 runs) | **85.4(±0.4)** | **75.4(±0.4)** | 82.7(±0.6) |\n| Accuracy DGL(20 runs) | 85.33(±0.41) | 75.36(±0.36) | **82.90(±0.66)** |\n\n\n\n"
  },
  {
    "path": "examples/pytorch/grand/main.py",
    "content": "import argparse\nimport warnings\n\nimport dgl\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\nfrom model import GRAND\n\nwarnings.filterwarnings(\"ignore\")\n\n\ndef argument():\n    parser = argparse.ArgumentParser(description=\"GRAND\")\n\n    # data source params\n    parser.add_argument(\n        \"--dataname\", type=str, default=\"cora\", help=\"Name of dataset.\"\n    )\n    # cuda params\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU index. Default: -1, using CPU.\"\n    )\n    # training params\n    parser.add_argument(\n        \"--epochs\", type=int, default=200, help=\"Training epochs.\"\n    )\n    parser.add_argument(\n        \"--early_stopping\",\n        type=int,\n        default=200,\n        help=\"Patient epochs to wait before early stopping.\",\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.01, help=\"Learning rate.\")\n    parser.add_argument(\n        \"--weight_decay\", type=float, default=5e-4, help=\"L2 reg.\"\n    )\n    # model params\n    parser.add_argument(\n        \"--hid_dim\", type=int, default=32, help=\"Hidden layer dimensionalities.\"\n    )\n    parser.add_argument(\n        \"--dropnode_rate\",\n        type=float,\n        default=0.5,\n        help=\"Dropnode rate (1 - keep probability).\",\n    )\n    parser.add_argument(\n        \"--input_droprate\",\n        type=float,\n        default=0.0,\n        help=\"dropout rate of input layer\",\n    )\n    parser.add_argument(\n        \"--hidden_droprate\",\n        type=float,\n        default=0.0,\n        help=\"dropout rate of hidden layer\",\n    )\n    parser.add_argument(\"--order\", type=int, default=8, help=\"Propagation step\")\n    parser.add_argument(\n        \"--sample\", type=int, default=4, help=\"Sampling times of dropnode\"\n    )\n    
parser.add_argument(\n        \"--tem\", type=float, default=0.5, help=\"Sharpening temperature\"\n    )\n    parser.add_argument(\n        \"--lam\",\n        type=float,\n        default=1.0,\n        help=\"Coefficient of consistency regularization\",\n    )\n    parser.add_argument(\n        \"--use_bn\",\n        action=\"store_true\",\n        default=False,\n        help=\"Using Batch Normalization\",\n    )\n\n    args = parser.parse_args()\n\n    # check cuda\n    if args.gpu != -1 and th.cuda.is_available():\n        args.device = \"cuda:{}\".format(args.gpu)\n    else:\n        args.device = \"cpu\"\n\n    return args\n\n\ndef consis_loss(logps, temp, lam):\n    ps = [th.exp(p) for p in logps]\n    ps = th.stack(ps, dim=2)\n\n    avg_p = th.mean(ps, dim=2)\n    sharp_p = (\n        th.pow(avg_p, 1.0 / temp)\n        / th.sum(th.pow(avg_p, 1.0 / temp), dim=1, keepdim=True)\n    ).detach()\n\n    sharp_p = sharp_p.unsqueeze(2)\n    loss = th.mean(th.sum(th.pow(ps - sharp_p, 2), dim=1, keepdim=True))\n\n    loss = lam * loss\n    return loss\n\n\nif __name__ == \"__main__\":\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # Load from DGL dataset\n    args = argument()\n    print(args)\n\n    if args.dataname == \"cora\":\n        dataset = CoraGraphDataset()\n    elif args.dataname == \"citeseer\":\n        dataset = CiteseerGraphDataset()\n    elif args.dataname == \"pubmed\":\n        dataset = PubmedGraphDataset()\n\n    graph = dataset[0]\n\n    graph = dgl.add_self_loop(graph)\n    device = args.device\n\n    # retrieve the number of classes\n    n_classes = dataset.num_classes\n\n    # retrieve labels of ground truth\n    labels = graph.ndata.pop(\"label\").to(device).long()\n\n    # Extract node features\n    feats = graph.ndata.pop(\"feat\").to(device)\n    n_features = feats.shape[-1]\n\n    # retrieve masks for train/validation/test\n    train_mask = graph.ndata.pop(\"train_mask\")\n  
  val_mask = graph.ndata.pop(\"val_mask\")\n    test_mask = graph.ndata.pop(\"test_mask\")\n\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze().to(device)\n    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze().to(device)\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze().to(device)\n\n    # Step 2: Create model =================================================================== #\n    model = GRAND(\n        n_features,\n        args.hid_dim,\n        n_classes,\n        args.sample,\n        args.order,\n        args.dropnode_rate,\n        args.input_droprate,\n        args.hidden_droprate,\n        args.use_bn,\n    )\n\n    model = model.to(args.device)\n    graph = graph.to(args.device)\n\n    # Step 3: Create training components ===================================================== #\n    loss_fn = nn.NLLLoss()\n    opt = optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    loss_best = np.inf\n    acc_best = 0\n\n    # Step 4: training epoches =============================================================== #\n    for epoch in range(args.epochs):\n        \"\"\"Training\"\"\"\n        model.train()\n\n        loss_sup = 0\n        logits = model(graph, feats, True)\n\n        # calculate supervised loss\n        for k in range(args.sample):\n            loss_sup += F.nll_loss(logits[k][train_idx], labels[train_idx])\n\n        loss_sup = loss_sup / args.sample\n\n        # calculate consistency loss\n        loss_consis = consis_loss(logits, args.tem, args.lam)\n\n        loss_train = loss_sup + loss_consis\n        acc_train = th.sum(\n            logits[0][train_idx].argmax(dim=1) == labels[train_idx]\n        ).item() / len(train_idx)\n\n        # backward\n        opt.zero_grad()\n        loss_train.backward()\n        opt.step()\n\n        \"\"\" Validating \"\"\"\n        model.eval()\n        with th.no_grad():\n            val_logits = model(graph, feats, False)\n\n         
   loss_val = F.nll_loss(val_logits[val_idx], labels[val_idx])\n            acc_val = th.sum(\n                val_logits[val_idx].argmax(dim=1) == labels[val_idx]\n            ).item() / len(val_idx)\n\n            # Print out performance\n            print(\n                \"In epoch {}, Train Acc: {:.4f} | Train Loss: {:.4f} ,Val Acc: {:.4f} | Val Loss: {:.4f}\".format(\n                    epoch,\n                    acc_train,\n                    loss_train.item(),\n                    acc_val,\n                    loss_val.item(),\n                )\n            )\n\n            # set early stopping counter\n            if loss_val < loss_best or acc_val > acc_best:\n                if loss_val < loss_best:\n                    best_epoch = epoch\n                    th.save(model.state_dict(), args.dataname + \".pkl\")\n                no_improvement = 0\n                loss_best = min(loss_val, loss_best)\n                acc_best = max(acc_val, acc_best)\n            else:\n                no_improvement += 1\n                if no_improvement == args.early_stopping:\n                    print(\"Early stopping.\")\n                    break\n\n    print(\"Optimization Finished!\")\n\n    print(\"Loading {}th epoch\".format(best_epoch))\n    model.load_state_dict(th.load(args.dataname + \".pkl\"))\n\n    \"\"\" Testing \"\"\"\n    model.eval()\n\n    test_logits = model(graph, feats, False)\n    test_acc = th.sum(\n        test_logits[test_idx].argmax(dim=1) == labels[test_idx]\n    ).item() / len(test_idx)\n\n    print(\"Test Acc: {:.4f}\".format(test_acc))\n"
  },
  {
    "path": "examples/pytorch/grand/model.py",
    "content": "import dgl.function as fn\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef drop_node(feats, drop_rate, training):\n    n = feats.shape[0]\n    drop_rates = th.FloatTensor(np.ones(n) * drop_rate)\n\n    if training:\n        masks = th.bernoulli(1.0 - drop_rates).unsqueeze(1)\n        feats = masks.to(feats.device) * feats\n\n    else:\n        feats = feats * (1.0 - drop_rate)\n\n    return feats\n\n\nclass MLP(nn.Module):\n    def __init__(\n        self, nfeat, nhid, nclass, input_droprate, hidden_droprate, use_bn=False\n    ):\n        super(MLP, self).__init__()\n\n        self.layer1 = nn.Linear(nfeat, nhid, bias=True)\n        self.layer2 = nn.Linear(nhid, nclass, bias=True)\n\n        self.input_dropout = nn.Dropout(input_droprate)\n        self.hidden_dropout = nn.Dropout(hidden_droprate)\n        self.bn1 = nn.BatchNorm1d(nfeat)\n        self.bn2 = nn.BatchNorm1d(nhid)\n        self.use_bn = use_bn\n\n    def reset_parameters(self):\n        self.layer1.reset_parameters()\n        self.layer2.reset_parameters()\n\n    def forward(self, x):\n        if self.use_bn:\n            x = self.bn1(x)\n        x = self.input_dropout(x)\n        x = F.relu(self.layer1(x))\n\n        if self.use_bn:\n            x = self.bn2(x)\n        x = self.hidden_dropout(x)\n        x = self.layer2(x)\n\n        return x\n\n\ndef GRANDConv(graph, feats, order):\n    \"\"\"\n    Parameters\n    -----------\n    graph: dgl.Graph\n        The input graph\n    feats: Tensor (n_nodes * feat_dim)\n        Node features\n    order: int\n        Propagation Steps\n    \"\"\"\n    with graph.local_scope():\n        \"\"\"Calculate Symmetric normalized adjacency matrix   \\hat{A}\"\"\"\n        degs = graph.in_degrees().float().clamp(min=1)\n        norm = th.pow(degs, -0.5).to(feats.device).unsqueeze(1)\n\n        graph.ndata[\"norm\"] = norm\n        graph.apply_edges(fn.u_mul_v(\"norm\", \"norm\", \"weight\"))\n\n 
       \"\"\" Graph Conv \"\"\"\n        x = feats\n        y = 0 + feats\n\n        for i in range(order):\n            graph.ndata[\"h\"] = x\n            graph.update_all(fn.u_mul_e(\"h\", \"weight\", \"m\"), fn.sum(\"m\", \"h\"))\n            x = graph.ndata.pop(\"h\")\n            y.add_(x)\n\n    return y / (order + 1)\n\n\nclass GRAND(nn.Module):\n    r\"\"\"\n\n    Parameters\n    -----------\n    in_dim: int\n        Input feature size. i.e, the number of dimensions of: math: `H^{(i)}`.\n    hid_dim: int\n        Hidden feature size.\n    n_class: int\n        Number of classes.\n    S: int\n        Number of Augmentation samples\n    K: int\n        Number of Propagation Steps\n    node_dropout: float\n        Dropout rate on node features.\n    input_dropout: float\n        Dropout rate of the input layer of a MLP\n    hidden_dropout: float\n        Dropout rate of the hidden layer of a MLPx\n    batchnorm: bool, optional\n        If True, use batch normalization.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim,\n        hid_dim,\n        n_class,\n        S=1,\n        K=3,\n        node_dropout=0.0,\n        input_droprate=0.0,\n        hidden_droprate=0.0,\n        batchnorm=False,\n    ):\n        super(GRAND, self).__init__()\n        self.in_dim = in_dim\n        self.hid_dim = hid_dim\n        self.S = S\n        self.K = K\n        self.n_class = n_class\n\n        self.mlp = MLP(\n            in_dim, hid_dim, n_class, input_droprate, hidden_droprate, batchnorm\n        )\n\n        self.dropout = node_dropout\n        self.node_dropout = nn.Dropout(node_dropout)\n\n    def forward(self, graph, feats, training=True):\n        X = feats\n        S = self.S\n\n        if training:  # Training Mode\n            output_list = []\n            for s in range(S):\n                drop_feat = drop_node(X, self.dropout, True)  # Drop node\n                feat = GRANDConv(graph, drop_feat, self.K)  # Graph Convolution\n                
output_list.append(\n                    th.log_softmax(self.mlp(feat), dim=-1)\n                )  # Prediction\n\n            return output_list\n        else:  # Inference Mode\n            drop_feat = drop_node(X, self.dropout, False)\n            X = GRANDConv(graph, drop_feat, self.K)\n\n            return th.log_softmax(self.mlp(X), dim=-1)\n"
  },
  {
    "path": "examples/pytorch/graph_matching/README.md",
    "content": "# Graph Matching Routines\n\nImplementation of various algorithms to compute the Graph Edit Distance (GED) between two DGLGraphs G1 and G2. The graph edit distance between two graphs is a generalization of the string edit distance between strings. The following four algorithms are implemented:\n\n - astar: Calculates exact GED using A* graph traversal algorithm, the heuristic used is the one proposed in (Riesen and Bunke, 2009) [1].\n - beam: Calculates approximate GED using A* graph traversal algorithm, with a threshold on the size of the open list. [2]\n - bipartite: Calculates approximate GED using linear assignment on the nodes, with Jonker-Volgerand (JV) algorithm. [3]\n - hausdorff: Approximation of graph edit distance based on Hausdorff matching [4].\n\n### Dependencies\n  - lapjv (https://github.com/src-d/lapjv): We use the lapjv implementation to solve assignment problem, because of its scalability. Another option is to use the hungarian algorithm provided by scipy (scipy.optimize.linear_sum_assignment).\n\n### Usage\n\nExamples of usage are provided in examples.py. The function signature and an example is also given below:\n\n```sh\ngraph_edit_distance(G1, G2, node_substitution_cost=None, edge_substitution_cost=None, G1_node_deletion_cost=None, G2_node_insertion_cost=None, G1_edge_deletion_cost=None, G2_edge_insertion_cost=None, algorithm='bipartite', max_beam_size=100)\n\"\"\"\nParameters\n----------\nG1, G2: DGLGraphs\n\nnode_substitution_cost, edge_substitution_cost : 2D numpy arrays\nnode_substitution_cost[i,j] is the cost of substitution node i of G1 with node j of G2, similar definition for edge_substitution_cost. If None, default cost of 0 is used.\n\nG1_node_deletion_cost, G1_edge_deletion_cost : 1D numpy arrays\nG1_node_deletion_cost[i] is the cost of deletion of node i of G1, similar definition for G1_edge_deletion_cost. 
If None, default cost of 1 is used.\n    \nG2_node_insertion_cost, G2_edge_insertion_cost : 1D numpy arrays\nG2_node_insertion_cost[i] is the cost of insertion of node i of G2, similar definition for G2_edge_insertion_cost. If None, default cost of 1 is used.\n\nalgorithm : string\nAlgorithm to use to calculate the edit distance. Can be either 'astar', 'beam', 'bipartite' or 'hausdorff'.\n\nmax_beam_size : int\nMaximum number of nodes in the open list, in case the algorithm is 'beam'.\n    \nReturns\n-------\nA tuple of three objects: (edit_distance, node_mapping, edge_mapping)\nedit distance is the calculated edit distance (float).\nnode_mapping is a tuple of size two, containing the node assignments of the two graphs respectively. eg., node_mapping[0][i] is the node mapping of node i of graph G1 (None means that the node is deleted). Similar definition for the edge_mapping.\nFor 'hausdorff', node_mapping and edge_mapping are returned as None, as this approximation does not return a unique edit path.\n\nExamples\n--------\n>>> src1 = [0, 1, 2, 3, 4, 5];\n>>> dst1 = [1, 2, 3, 4, 5, 6];\n>>> src2 = [0, 1, 3, 4, 5];\n>>> dst2 = [1, 2, 4, 5, 6];\n\n>>> G1 = dgl.DGLGraph((src1, dst1))\n>>> G2 = dgl.DGLGraph((src2, dst2))\n>>> distance, node_mapping, edge_mapping = graph_edit_distance(G1, G1, algorithm='astar')\n>>> print(distance)\n0.0\n>>> distance, node_mapping, edge_mapping = graph_edit_distance(G1, G2, algorithm='astar')\n>>> print(distance)\n1.0\n```\n### References\n    [1] Riesen, Kaspar, Stefan Fankhauser, and Horst Bunke. \"Speeding Up Graph Edit Distance Computation with a Bipartite Heuristic.\" MLG. 2007.\n    [2] Neuhaus, Michel, Kaspar Riesen, and Horst Bunke. \"Fast suboptimal algorithms for the computation of graph edit distance.\" Joint IAPR International Workshops on Statistical Techniques in Pattern Recognition (SPR) and Structural and Syntactic Pattern Recognition (SSPR). 2006.\n    [3] Fankhauser, Stefan, Kaspar Riesen, and Horst Bunke. 
\"Speeding up graph edit distance computation through fast bipartite matching.\" International Workshop on Graph-Based Representations in Pattern Recognition. 2011.\n    [4] Fischer, Andreas, et al. \"A hausdorff heuristic for efficient computation of graph edit distance.\" Joint IAPR International Workshops on Statistical Techniques in Pattern Recognition (SPR) and Structural and Syntactic Pattern Recognition (SSPR). 2014.\n\n\n\n"
  },
  {
    "path": "examples/pytorch/graph_matching/examples.py",
    "content": "import dgl\nimport numpy as np\nfrom ged import graph_edit_distance\n\nsrc1 = [0, 1, 2, 3, 4, 5]\ndst1 = [1, 2, 3, 4, 5, 6]\n\nsrc2 = [0, 1, 3, 4, 5]\ndst2 = [1, 2, 4, 5, 6]\n\n\nG1 = dgl.DGLGraph((src1, dst1))\nG2 = dgl.DGLGraph((src2, dst2))\n\n\n# Exact edit distance with astar search\ndistance, node_mapping, edge_mapping = graph_edit_distance(\n    G1, G1, algorithm=\"astar\"\n)\nprint(distance)  # 0.0\ndistance, node_mapping, edge_mapping = graph_edit_distance(\n    G1, G2, algorithm=\"astar\"\n)\nprint(distance)  # 1.0\n\n# With user-input cost matrices\nnode_substitution_cost = np.empty((G1.num_nodes(), G2.num_nodes()))\nG1_node_deletion_cost = np.empty(G1.num_nodes())\nG2_node_insertion_cost = np.empty(G2.num_nodes())\n\nedge_substitution_cost = np.empty((G1.num_edges(), G2.num_edges()))\nG1_edge_deletion_cost = np.empty(G1.num_edges())\nG2_edge_insertion_cost = np.empty(G2.num_edges())\n\n# Node substitution cost of 0 when node-ids are same, else 1\nnode_substitution_cost.fill(1.0)\nfor i in range(G1.num_nodes()):\n    for j in range(G2.num_nodes()):\n        node_substitution_cost[i, j] = 0.0\n\n# Node insertion/deletion cost of 1\nG1_node_deletion_cost.fill(1.0)\nG2_node_insertion_cost.fill(1.0)\n\n# Edge substitution cost of 0\nedge_substitution_cost.fill(0.0)\n\n# Edge insertion/deletion cost of 0.5\nG1_edge_deletion_cost.fill(0.5)\nG2_edge_insertion_cost.fill(0.5)\n\ndistance, node_mapping, edge_mapping = graph_edit_distance(\n    G1,\n    G2,\n    node_substitution_cost,\n    edge_substitution_cost,\n    G1_node_deletion_cost,\n    G2_node_insertion_cost,\n    G1_edge_deletion_cost,\n    G2_edge_insertion_cost,\n    algorithm=\"astar\",\n)\n\nprint(distance)  # 0.5\n\n\n# Approximate edit distance with beam search, it is more than or equal to the exact edit distance\ndistance, node_mapping, edge_mapping = graph_edit_distance(\n    G1, G2, algorithm=\"beam\", max_beam_size=2\n)\nprint(distance)  # 3.0\n\n# Approximate edit distance 
with bipartite heuristic, it is more than or equal to the exact edit distance\ndistance, node_mapping, edge_mapping = graph_edit_distance(\n    G1, G2, algorithm=\"bipartite\"\n)\nprint(\n    distance\n)  # 9.0, can be different as multiple solutions possible for the intermediate LAP used in this approximation\n\n\n# Approximate edit distance with hausdorff heuristic, it is less than or equal to the exact edit distance\ndistance, node_mapping, edge_mapping = graph_edit_distance(\n    G1, G2, algorithm=\"hausdorff\"\n)\nprint(distance)  # 0.0\n"
  },
  {
    "path": "examples/pytorch/graph_matching/ged.py",
    "content": "from copy import deepcopy\nfrom heapq import heapify, heappop, heappush, nsmallest\n\nimport dgl\nimport numpy as np\n\n# We use lapjv implementation (https://github.com/src-d/lapjv) to solve assignment problem, because of its scalability\n# Also see https://github.com/berhane/LAP-solvers for benchmarking of LAP solvers\nfrom lapjv import lapjv\n\nEPSILON = 0.0000001\n\n\ndef validate_cost_functions(\n    G1,\n    G2,\n    node_substitution_cost=None,\n    edge_substitution_cost=None,\n    G1_node_deletion_cost=None,\n    G1_edge_deletion_cost=None,\n    G2_node_insertion_cost=None,\n    G2_edge_insertion_cost=None,\n):\n    \"\"\"Validates cost functions (substitution, insertion, deletion) and initializes them with default=0 for substitution\n    and default=1 for insertion/deletion\n    if the provided ones are None.\n\n\n    Parameters : see graph_edit_distance\n\n    \"\"\"\n    num_G1_nodes = G1.num_nodes()\n    num_G2_nodes = G2.num_nodes()\n\n    num_G1_edges = G1.num_edges()\n    num_G2_edges = G2.num_edges()\n\n    # if any cost matrix is None, initialize it with default costs\n    if node_substitution_cost is None:\n        node_substitution_cost = np.zeros(\n            (num_G1_nodes, num_G2_nodes), dtype=float\n        )\n    else:\n        assert node_substitution_cost.shape == (num_G1_nodes, num_G2_nodes)\n\n    if edge_substitution_cost is None:\n        edge_substitution_cost = np.zeros(\n            (num_G1_edges, num_G2_edges), dtype=float\n        )\n    else:\n        assert edge_substitution_cost.shape == (num_G1_edges, num_G2_edges)\n\n    if G1_node_deletion_cost is None:\n        G1_node_deletion_cost = np.ones(num_G1_nodes, dtype=float)\n    else:\n        assert G1_node_deletion_cost.shape[0] == num_G1_nodes\n\n    if G1_edge_deletion_cost is None:\n        G1_edge_deletion_cost = np.ones(num_G1_edges, dtype=float)\n    else:\n        assert G1_edge_deletion_cost.shape[0] == num_G1_edges\n\n    if G2_node_insertion_cost is 
None:\n        G2_node_insertion_cost = np.ones(num_G2_nodes, dtype=float)\n    else:\n        assert G2_node_insertion_cost.shape[0] == num_G2_nodes\n\n    if G2_edge_insertion_cost is None:\n        G2_edge_insertion_cost = np.ones(num_G2_edges, dtype=float)\n    else:\n        assert G2_edge_insertion_cost.shape[0] == num_G2_edges\n\n    return (\n        node_substitution_cost,\n        edge_substitution_cost,\n        G1_node_deletion_cost,\n        G1_edge_deletion_cost,\n        G2_node_insertion_cost,\n        G2_edge_insertion_cost,\n    )\n\n\ndef construct_cost_functions(\n    G1,\n    G2,\n    node_substitution_cost,\n    edge_substitution_cost,\n    G1_node_deletion_cost,\n    G1_edge_deletion_cost,\n    G2_node_insertion_cost,\n    G2_edge_insertion_cost,\n):\n    \"\"\"Constructs cost matrices for LAP solution\n\n\n    Parameters : see graph_edit_distance\n\n    \"\"\"\n    num_G1_nodes = G1.num_nodes()\n    num_G2_nodes = G2.num_nodes()\n\n    num_G1_edges = G1.num_edges()\n    num_G2_edges = G2.num_edges()\n\n    # cost matrix of node mappings\n    cost_upper_bound = (\n        node_substitution_cost.sum()\n        + G1_node_deletion_cost.sum()\n        + G2_node_insertion_cost.sum()\n        + 1\n    )\n    C_node = np.zeros(\n        (num_G1_nodes + num_G2_nodes, num_G1_nodes + num_G2_nodes), dtype=float\n    )\n\n    C_node[0:num_G1_nodes, 0:num_G2_nodes] = node_substitution_cost\n    C_node[\n        0:num_G1_nodes, num_G2_nodes : num_G2_nodes + num_G1_nodes\n    ] = np.array(\n        [\n            G1_node_deletion_cost[i] if i == j else cost_upper_bound\n            for i in range(num_G1_nodes)\n            for j in range(num_G1_nodes)\n        ]\n    ).reshape(\n        num_G1_nodes, num_G1_nodes\n    )\n    C_node[\n        num_G1_nodes : num_G1_nodes + num_G2_nodes, 0:num_G2_nodes\n    ] = np.array(\n        [\n            G2_node_insertion_cost[i] if i == j else cost_upper_bound\n            for i in range(num_G2_nodes)\n            for 
j in range(num_G2_nodes)\n        ]\n    ).reshape(\n        num_G2_nodes, num_G2_nodes\n    )\n\n    # cost matrix of edge mappings\n    cost_upper_bound = (\n        edge_substitution_cost.sum()\n        + G1_edge_deletion_cost.sum()\n        + G2_edge_insertion_cost.sum()\n        + 1\n    )\n    C_edge = np.zeros(\n        (num_G1_edges + num_G2_edges, num_G1_edges + num_G2_edges), dtype=float\n    )\n\n    C_edge[0:num_G1_edges, 0:num_G2_edges] = edge_substitution_cost\n    C_edge[\n        0:num_G1_edges, num_G2_edges : num_G2_edges + num_G1_edges\n    ] = np.array(\n        [\n            G1_edge_deletion_cost[i] if i == j else cost_upper_bound\n            for i in range(num_G1_edges)\n            for j in range(num_G1_edges)\n        ]\n    ).reshape(\n        num_G1_edges, num_G1_edges\n    )\n    C_edge[\n        num_G1_edges : num_G1_edges + num_G2_edges, 0:num_G2_edges\n    ] = np.array(\n        [\n            G2_edge_insertion_cost[i] if i == j else cost_upper_bound\n            for i in range(num_G2_edges)\n            for j in range(num_G2_edges)\n        ]\n    ).reshape(\n        num_G2_edges, num_G2_edges\n    )\n    return C_node, C_edge\n\n\ndef get_edges_to_match(G, node_id, matched_nodes):\n    # Find the edges in G with one end-point as node_id and other in matched_nodes or node_id\n    incident_edges = np.array([], dtype=int)\n    index = np.array([], dtype=int)\n    direction = np.array([], dtype=int)\n    if G.has_edge_between(node_id, node_id):\n        self_edge_ids = G.edge_ids(node_id, node_id, return_array=True).numpy()\n        incident_edges = np.concatenate((incident_edges, self_edge_ids))\n        index = np.concatenate((index, [-1] * len(self_edge_ids)))\n        direction = np.concatenate((direction, [0] * len(self_edge_ids)))\n    # Find predecessors\n    src, _, eid = G.in_edges([node_id], \"all\")\n    eid = eid.numpy()\n    src = src.numpy()\n    filtered_indices = [\n        (i, matched_nodes.index(src[i]))\n        for i 
in range(len(src))\n        if src[i] in matched_nodes\n    ]\n    matched_index = np.array([_[1] for _ in filtered_indices], dtype=int)\n    eid_index = np.array([_[0] for _ in filtered_indices], dtype=int)\n    index = np.concatenate((index, matched_index))\n    incident_edges = np.concatenate((incident_edges, eid[eid_index]))\n    direction = np.concatenate(\n        (direction, np.array([-1] * len(filtered_indices), dtype=int))\n    )\n    # Find successors\n    _, dst, eid = G.out_edges([node_id], \"all\")\n    eid = eid.numpy()\n    dst = dst.numpy()\n    filtered_indices = [\n        (i, matched_nodes.index(dst[i]))\n        for i in range(len(dst))\n        if dst[i] in matched_nodes\n    ]\n    matched_index = np.array([_[1] for _ in filtered_indices], dtype=int)\n    eid_index = np.array([_[0] for _ in filtered_indices], dtype=int)\n    index = np.concatenate((index, matched_index))\n    incident_edges = np.concatenate((incident_edges, eid[eid_index]))\n    direction = np.concatenate(\n        (direction, np.array([1] * len(filtered_indices), dtype=int))\n    )\n    return incident_edges, index, direction\n\n\ndef subset_cost_matrix(cost_matrix, row_ids, col_ids, num_rows, num_cols):\n    # Extract thr subset of cost matrix corresponding to rows/cols in arrays row_ids/col_ids\n    # Note that the shape of cost_matrix is (num_rows+num_cols) * (num_rows+num_cols)\n    extended_row_ids = np.concatenate(\n        (row_ids, np.array([k + num_rows for k in col_ids]))\n    )\n    extended_col_ids = np.concatenate(\n        (col_ids, np.array([k + num_cols for k in row_ids]))\n    )\n    return cost_matrix[extended_row_ids, :][:, extended_col_ids]\n\n\nclass search_tree_node:\n    def __init__(\n        self,\n        G1,\n        G2,\n        parent_matched_cost,\n        parent_matched_nodes,\n        parent_matched_edges,\n        node_G1,\n        node_G2,\n        parent_unprocessed_nodes_G1,\n        parent_unprocessed_nodes_G2,\n        
parent_unprocessed_edges_G1,\n        parent_unprocessed_edges_G2,\n        cost_matrix_nodes,\n        cost_matrix_edges,\n    ):\n        self.matched_cost = parent_matched_cost\n        self.future_approximate_cost = 0.0\n        self.matched_nodes = deepcopy(parent_matched_nodes)\n        self.matched_nodes[0].append(node_G1)\n        self.matched_nodes[1].append(node_G2)\n        self.matched_edges = deepcopy(parent_matched_edges)\n        self.unprocessed_nodes_G1 = [\n            _ for _ in parent_unprocessed_nodes_G1 if _ != node_G1\n        ]\n        self.unprocessed_nodes_G2 = [\n            _ for _ in parent_unprocessed_nodes_G2 if _ != node_G2\n        ]\n\n        # Add the cost of matching nodes at this tree-node to the matched cost\n        if (\n            node_G1 is not None and node_G2 is not None\n        ):  # Substitute node_G1 with node_G2\n            self.matched_cost += cost_matrix_nodes[node_G1, node_G2]\n        elif node_G1 is not None:  # Delete node_G1\n            self.matched_cost += cost_matrix_nodes[\n                node_G1, node_G1 + G2.num_nodes()\n            ]\n        elif node_G2 is not None:  # Insert node_G2\n            self.matched_cost += cost_matrix_nodes[\n                node_G2 + G1.num_nodes(), node_G2\n            ]\n\n        # Add the cost of matching edges at this tree-node to the matched cost\n        incident_edges_G1 = []\n        if (\n            node_G1 is not None\n        ):  # Find the edges with one end-point as node_G1 and other in matched nodes or node_G1\n            incident_edges_G1, index_G1, direction_G1 = get_edges_to_match(\n                G1, node_G1, parent_matched_nodes[0]\n            )\n\n        incident_edges_G2 = np.array([])\n        if (\n            node_G2 is not None\n        ):  # Find the edges with one end-point as node_G2 and other in matched nodes or node_G2\n            incident_edges_G2, index_G2, direction_G2 = get_edges_to_match(\n                G2, node_G2, 
parent_matched_nodes[1]\n            )\n\n        if (\n            len(incident_edges_G1) > 0 and len(incident_edges_G2) > 0\n        ):  # Consider substituting\n            matched_edges_cost_matrix = subset_cost_matrix(\n                cost_matrix_edges,\n                incident_edges_G1,\n                incident_edges_G2,\n                G1.num_edges(),\n                G2.num_edges(),\n            )\n            max_sum = matched_edges_cost_matrix.sum()\n            # take care of impossible assignments by assigning maximum cost\n            for i in range(len(incident_edges_G1)):\n                for j in range(len(incident_edges_G2)):\n                    # both edges need to have same direction and the other end nodes are matched\n                    if (\n                        direction_G1[i] == direction_G2[j]\n                        and index_G1[i] == index_G2[j]\n                    ):\n                        continue\n                    else:\n                        matched_edges_cost_matrix[i, j] = max_sum\n            # Match the edges as per the LAP solution\n            row_ind, col_ind, _ = lapjv(matched_edges_cost_matrix)\n            lap_cost = 0.00\n            for i in range(len(row_ind)):\n                lap_cost += matched_edges_cost_matrix[i, row_ind[i]]\n\n            # Update matched edges\n            for i in range(len(row_ind)):\n                if i < len(incident_edges_G1):\n                    self.matched_edges[0].append(incident_edges_G1[i])\n                    if row_ind[i] < len(incident_edges_G2):\n                        self.matched_edges[1].append(\n                            incident_edges_G2[row_ind[i]]\n                        )\n                    else:\n                        self.matched_edges[1].append(None)\n                elif row_ind[i] < len(incident_edges_G2):\n                    self.matched_edges[0].append(None)\n                    self.matched_edges[1].append(incident_edges_G2[row_ind[i]])\n 
           self.matched_cost += lap_cost\n\n        elif len(incident_edges_G1) > 0:  # only deletion possible\n            edge_deletion_cost = 0.0\n            for edge in incident_edges_G1:\n                edge_deletion_cost += cost_matrix_edges[\n                    edge, G2.num_edges() + edge\n                ]\n            # Update matched edges\n            for edge in incident_edges_G1:\n                self.matched_edges[0].append(edge)\n                self.matched_edges[1].append(None)\n\n                # Update matched edges\n\n            self.matched_cost += edge_deletion_cost\n\n        elif len(incident_edges_G2) > 0:  # only insertion possible\n            edge_insertion_cost = 0.0\n            for edge in incident_edges_G2:\n                edge_insertion_cost += cost_matrix_edges[\n                    G1.num_edges() + edge, edge\n                ]\n            # Update matched edges\n            for edge in incident_edges_G2:\n                self.matched_edges[0].append(None)\n                self.matched_edges[1].append(edge)\n\n            self.matched_cost += edge_insertion_cost\n\n        # Add the cost of matching of unprocessed nodes to the future approximate cost\n        if (\n            len(self.unprocessed_nodes_G1) > 0\n            and len(self.unprocessed_nodes_G2) > 0\n        ):  # Consider substituting\n            unmatched_nodes_cost_matrix = subset_cost_matrix(\n                cost_matrix_nodes,\n                self.unprocessed_nodes_G1,\n                self.unprocessed_nodes_G2,\n                G1.num_nodes(),\n                G2.num_nodes(),\n            )\n            # Match the edges as per the LAP solution\n            row_ind, col_ind, _ = lapjv(unmatched_nodes_cost_matrix)\n            lap_cost = 0.00\n            for i in range(len(row_ind)):\n                lap_cost += unmatched_nodes_cost_matrix[i, row_ind[i]]\n\n            self.future_approximate_cost += lap_cost\n\n        elif 
len(self.unprocessed_nodes_G1) > 0:  # only deletion possible\n            node_deletion_cost = 0.0\n            for node in self.unprocessed_nodes_G1:\n                node_deletion_cost += cost_matrix_nodes[\n                    node, G2.num_nodes() + node\n                ]\n\n            self.future_approximate_cost += node_deletion_cost\n\n        elif len(self.unprocessed_nodes_G2) > 0:  # only insertion possible\n            node_insertion_cost = 0.0\n            for node in self.unprocessed_nodes_G2:\n                node_insertion_cost += cost_matrix_nodes[\n                    G1.num_nodes() + node, node\n                ]\n\n            self.future_approximate_cost += node_insertion_cost\n\n        # Add the cost of LAP matching of unprocessed edges to the future approximate cost\n        self.unprocessed_edges_G1 = [\n            _ for _ in parent_unprocessed_edges_G1 if _ not in incident_edges_G1\n        ]\n        self.unprocessed_edges_G2 = [\n            _ for _ in parent_unprocessed_edges_G2 if _ not in incident_edges_G2\n        ]\n        if (\n            len(self.unprocessed_edges_G1) > 0\n            and len(self.unprocessed_edges_G2) > 0\n        ):  # Consider substituting\n            unmatched_edges_cost_matrix = subset_cost_matrix(\n                cost_matrix_edges,\n                self.unprocessed_edges_G1,\n                self.unprocessed_edges_G2,\n                G1.num_edges(),\n                G2.num_edges(),\n            )\n            # Match the edges as per the LAP solution\n            row_ind, col_ind, _ = lapjv(unmatched_edges_cost_matrix)\n            lap_cost = 0.00\n            for i in range(len(row_ind)):\n                lap_cost += unmatched_edges_cost_matrix[i, row_ind[i]]\n\n            self.future_approximate_cost += lap_cost\n\n        elif len(self.unprocessed_edges_G1) > 0:  # only deletion possible\n            edge_deletion_cost = 0.0\n            for edge in self.unprocessed_edges_G1:\n                
edge_deletion_cost += cost_matrix_edges[\n                    edge, G2.num_edges() + edge\n                ]\n\n            self.future_approximate_cost += edge_deletion_cost\n\n        elif len(self.unprocessed_edges_G2) > 0:  # only insertion possible\n            edge_insertion_cost = 0.0\n            for edge in self.unprocessed_edges_G2:\n                edge_insertion_cost += cost_matrix_edges[\n                    G1.num_edges() + edge, edge\n                ]\n\n            self.future_approximate_cost += edge_insertion_cost\n\n    # For heap insertion order\n    def __lt__(self, other):\n        if (\n            abs(\n                (self.matched_cost + self.future_approximate_cost)\n                - (other.matched_cost + other.future_approximate_cost)\n            )\n            > EPSILON\n        ):\n            return (self.matched_cost + self.future_approximate_cost) < (\n                other.matched_cost + other.future_approximate_cost\n            )\n        elif abs(self.matched_cost - other.matched_cost) > EPSILON:\n            return other.matched_cost < self.matched_cost\n            # matched cost is closer to reality\n        else:\n            return (\n                len(self.unprocessed_nodes_G1)\n                + len(self.unprocessed_nodes_G2)\n                + len(self.unprocessed_edges_G1)\n                + len(self.unprocessed_edges_G2)\n            ) < (\n                len(other.unprocessed_nodes_G1)\n                + len(other.unprocessed_nodes_G2)\n                + len(other.unprocessed_edges_G1)\n                + len(other.unprocessed_edges_G2)\n            )\n\n\ndef edit_cost_from_node_matching(\n    G1, G2, cost_matrix_nodes, cost_matrix_edges, node_matching\n):\n    matched_cost = 0.0\n    matched_nodes = ([], [])\n    matched_edges = ([], [])\n    # Add the cost of matching nodes\n    for i in range(G1.num_nodes()):\n        matched_cost += cost_matrix_nodes[i, node_matching[i]]\n        matched_nodes[0].append(i)\n 
       if node_matching[i] < G2.num_nodes():\n            matched_nodes[1].append(node_matching[i])\n        else:\n            matched_nodes[1].append(None)\n    for i in range(G1.num_nodes(), len(node_matching)):\n        matched_cost += cost_matrix_nodes[i, node_matching[i]]\n        if node_matching[i] < G2.num_nodes():\n            matched_nodes[0].append(None)\n            matched_nodes[1].append(node_matching[i])\n\n    for i in range(len(matched_nodes[0])):\n        # Add the cost of matching edges\n        incident_edges_G1 = []\n        if (\n            matched_nodes[0][i] is not None\n        ):  # Find the edges with one end-point as node_G1 and other in matched nodes or node_G1\n            incident_edges_G1, index_G1, direction_G1 = get_edges_to_match(\n                G1, matched_nodes[0][i], matched_nodes[0][:i]\n            )\n\n        incident_edges_G2 = np.array([])\n        if (\n            matched_nodes[1][i] is not None\n        ):  # Find the edges with one end-point as node_G2 and other in matched nodes or node_G2\n            incident_edges_G2, index_G2, direction_G2 = get_edges_to_match(\n                G2, matched_nodes[1][i], matched_nodes[1][:i]\n            )\n\n        if (\n            len(incident_edges_G1) > 0 and len(incident_edges_G2) > 0\n        ):  # Consider substituting\n            matched_edges_cost_matrix = subset_cost_matrix(\n                cost_matrix_edges,\n                incident_edges_G1,\n                incident_edges_G2,\n                G1.num_edges(),\n                G2.num_edges(),\n            )\n            max_sum = matched_edges_cost_matrix.sum()\n            # take care of impossible assignments by assigning maximum cost\n            for i in range(len(incident_edges_G1)):\n                for j in range(len(incident_edges_G2)):\n                    # both edges need to have same direction and the other end nodes are matched\n                    if (\n                        direction_G1[i] == 
direction_G2[j]\n                        and index_G1[i] == index_G2[j]\n                    ):\n                        continue\n                    else:\n                        matched_edges_cost_matrix[i, j] = max_sum\n            # Match the edges as per the LAP solution\n            row_ind, col_ind, _ = lapjv(matched_edges_cost_matrix)\n            lap_cost = 0.00\n            for i in range(len(row_ind)):\n                lap_cost += matched_edges_cost_matrix[i, row_ind[i]]\n\n            # Update matched edges\n            for i in range(len(row_ind)):\n                if i < len(incident_edges_G1):\n                    matched_edges[0].append(incident_edges_G1[i])\n                    if row_ind[i] < len(incident_edges_G2):\n                        matched_edges[1].append(incident_edges_G2[row_ind[i]])\n                    else:\n                        matched_edges[1].append(None)\n                elif row_ind[i] < len(incident_edges_G2):\n                    matched_edges[0].append(None)\n                    matched_edges[1].append(incident_edges_G2[row_ind[i]])\n            matched_cost += lap_cost\n\n        elif len(incident_edges_G1) > 0:  # only deletion possible\n            edge_deletion_cost = 0.0\n            for edge in incident_edges_G1:\n                edge_deletion_cost += cost_matrix_edges[\n                    edge, G2.num_edges() + edge\n                ]\n            # Update matched edges\n            for edge in incident_edges_G1:\n                matched_edges[0].append(edge)\n                matched_edges[1].append(None)\n\n                # Update matched edges\n\n            matched_cost += edge_deletion_cost\n\n        elif len(incident_edges_G2) > 0:  # only insertion possible\n            edge_insertion_cost = 0.0\n            for edge in incident_edges_G2:\n                edge_insertion_cost += cost_matrix_edges[\n                    G1.num_edges() + edge, edge\n                ]\n            # Update matched edges\n      
      for edge in incident_edges_G2:\n                matched_edges[0].append(None)\n                matched_edges[1].append(edge)\n\n            matched_cost += edge_insertion_cost\n\n    return (matched_cost, matched_nodes, matched_edges)\n\n\ndef contextual_cost_matrix_construction(\n    G1,\n    G2,\n    node_substitution_cost,\n    edge_substitution_cost,\n    G1_node_deletion_cost,\n    G1_edge_deletion_cost,\n    G2_node_insertion_cost,\n    G2_edge_insertion_cost,\n):\n    # Calculates approximate GED using linear assignment on the nodes with bipartite algorithm\n    # cost matrix of node mappings\n\n    num_G1_nodes = G1.num_nodes()\n    num_G2_nodes = G2.num_nodes()\n\n    num_G1_edges = G1.num_edges()\n    num_G2_edges = G2.num_edges()\n\n    cost_upper_bound = 2 * (\n        node_substitution_cost.sum()\n        + G1_node_deletion_cost.sum()\n        + G2_node_insertion_cost.sum()\n        + 1\n    )\n    cost_matrix = np.zeros(\n        (num_G1_nodes + num_G2_nodes, num_G1_nodes + num_G2_nodes), dtype=float\n    )\n\n    cost_matrix[0:num_G1_nodes, 0:num_G2_nodes] = node_substitution_cost\n    cost_matrix[\n        0:num_G1_nodes, num_G2_nodes : num_G2_nodes + num_G1_nodes\n    ] = np.array(\n        [\n            G1_node_deletion_cost[i] if i == j else cost_upper_bound\n            for i in range(num_G1_nodes)\n            for j in range(num_G1_nodes)\n        ]\n    ).reshape(\n        num_G1_nodes, num_G1_nodes\n    )\n    cost_matrix[\n        num_G1_nodes : num_G1_nodes + num_G2_nodes, 0:num_G2_nodes\n    ] = np.array(\n        [\n            G2_node_insertion_cost[i] if i == j else cost_upper_bound\n            for i in range(num_G2_nodes)\n            for j in range(num_G2_nodes)\n        ]\n    ).reshape(\n        num_G2_nodes, num_G2_nodes\n    )\n\n    self_edge_list_G1 = [np.array([], dtype=int)] * num_G1_nodes\n    self_edge_list_G2 = [np.array([], dtype=int)] * num_G2_nodes\n    incoming_edges_G1 = [np.array([], dtype=int)] * 
num_G1_nodes\n    incoming_edges_G2 = [np.array([], dtype=int)] * num_G2_nodes\n    outgoing_edges_G1 = [np.array([], dtype=int)] * num_G1_nodes\n    outgoing_edges_G2 = [np.array([], dtype=int)] * num_G2_nodes\n\n    for i in range(num_G1_nodes):\n        if G1.has_edge_between(i, i):\n            self_edge_list_G1[i] = sorted(\n                G1.edge_ids(i, i, return_array=True).numpy()\n            )\n        incoming_edges_G1[i] = G1.in_edges([i], \"eid\").numpy()\n        incoming_edges_G1[i] = np.setdiff1d(\n            incoming_edges_G1[i], self_edge_list_G1[i]\n        )\n        outgoing_edges_G1[i] = G1.out_edges([i], \"eid\").numpy()\n        outgoing_edges_G1[i] = np.setdiff1d(\n            outgoing_edges_G1[i], self_edge_list_G1[i]\n        )\n    for i in range(num_G2_nodes):\n        if G2.has_edge_between(i, i):\n            self_edge_list_G2[i] = sorted(\n                G2.edge_ids(i, i, return_array=True).numpy()\n            )\n        incoming_edges_G2[i] = G2.in_edges([i], \"eid\").numpy()\n        incoming_edges_G2[i] = np.setdiff1d(\n            incoming_edges_G2[i], self_edge_list_G2[i]\n        )\n        outgoing_edges_G2[i] = G2.out_edges([i], \"eid\").numpy()\n        outgoing_edges_G2[i] = np.setdiff1d(\n            outgoing_edges_G2[i], self_edge_list_G2[i]\n        )\n\n    selected_deletion_G1 = [\n        G1_edge_deletion_cost[\n            np.concatenate(\n                (\n                    self_edge_list_G1[i],\n                    incoming_edges_G1[i],\n                    outgoing_edges_G1[i],\n                )\n            )\n        ]\n        for i in range(G1.num_nodes())\n    ]\n    selected_insertion_G2 = [\n        G2_edge_insertion_cost[\n            np.concatenate(\n                (\n                    self_edge_list_G2[i],\n                    incoming_edges_G2[i],\n                    outgoing_edges_G2[i],\n                )\n            )\n        ]\n        for i in range(G2.num_nodes())\n    ]\n\n    # Add 
the cost of edge edition which are dependent of a node (see this as the cost associated with a substructure)\n    for i in range(num_G1_nodes):\n        for j in range(num_G2_nodes):\n            m = (\n                len(self_edge_list_G1[i])\n                + len(incoming_edges_G1[i])\n                + len(outgoing_edges_G1[i])\n            )\n            n = (\n                len(self_edge_list_G2[j])\n                + len(incoming_edges_G2[j])\n                + len(outgoing_edges_G2[j])\n            )\n\n            matrix_dim = m + n\n\n            if matrix_dim == 0:\n                continue\n            temp_edge_cost_matrix = np.empty((matrix_dim, matrix_dim))\n            temp_edge_cost_matrix.fill(cost_upper_bound)\n\n            temp_edge_cost_matrix[\n                : len(self_edge_list_G1[i]), : len(self_edge_list_G2[j])\n            ] = edge_substitution_cost[self_edge_list_G1[i], :][\n                :, self_edge_list_G2[j]\n            ]\n            temp_edge_cost_matrix[\n                len(self_edge_list_G1[i]) : len(self_edge_list_G1[i])\n                + len(incoming_edges_G1[i]),\n                len(self_edge_list_G2[j]) : len(self_edge_list_G2[j])\n                + len(incoming_edges_G2[j]),\n            ] = edge_substitution_cost[incoming_edges_G1[i], :][\n                :, incoming_edges_G2[j]\n            ]\n            temp_edge_cost_matrix[\n                len(self_edge_list_G1[i]) + len(incoming_edges_G1[i]) : m,\n                len(self_edge_list_G2[j]) + len(incoming_edges_G2[j]) : n,\n            ] = edge_substitution_cost[outgoing_edges_G1[i], :][\n                :, outgoing_edges_G2[j]\n            ]\n\n            np.fill_diagonal(\n                temp_edge_cost_matrix[:m, n:], selected_deletion_G1[i]\n            )\n            np.fill_diagonal(\n                temp_edge_cost_matrix[m:, :n], selected_insertion_G2[j]\n            )\n\n            temp_edge_cost_matrix[m:, n:].fill(0)\n            row_ind, 
col_ind, _ = lapjv(temp_edge_cost_matrix)\n            lap_cost = 0.00\n            for k in range(len(row_ind)):\n                lap_cost += temp_edge_cost_matrix[k, row_ind[k]]\n\n            cost_matrix[i, j] += lap_cost\n\n    for i in range(num_G1_nodes):\n        cost_matrix[i, num_G2_nodes + i] += selected_deletion_G1[i].sum()\n\n    for i in range(num_G2_nodes):\n        cost_matrix[num_G1_nodes + i, i] += selected_insertion_G2[i].sum()\n\n    return cost_matrix\n\n\ndef hausdorff_matching(\n    G1,\n    G2,\n    node_substitution_cost,\n    edge_substitution_cost,\n    G1_node_deletion_cost,\n    G1_edge_deletion_cost,\n    G2_node_insertion_cost,\n    G2_edge_insertion_cost,\n):\n    # Calculates approximate GED using hausdorff_matching\n    # cost matrix of node mappings\n\n    num_G1_nodes = G1.num_nodes()\n    num_G2_nodes = G2.num_nodes()\n\n    num_G1_edges = G1.num_edges()\n    num_G2_edges = G2.num_edges()\n\n    self_edge_list_G1 = [np.array([], dtype=int)] * num_G1_nodes\n    self_edge_list_G2 = [np.array([], dtype=int)] * num_G2_nodes\n    incoming_edges_G1 = [np.array([], dtype=int)] * num_G1_nodes\n    incoming_edges_G2 = [np.array([], dtype=int)] * num_G2_nodes\n    outgoing_edges_G1 = [np.array([], dtype=int)] * num_G1_nodes\n    outgoing_edges_G2 = [np.array([], dtype=int)] * num_G2_nodes\n\n    for i in range(num_G1_nodes):\n        if G1.has_edge_between(i, i):\n            self_edge_list_G1[i] = sorted(\n                G1.edge_ids(i, i, return_array=True).numpy()\n            )\n        incoming_edges_G1[i] = G1.in_edges([i], \"eid\").numpy()\n        incoming_edges_G1[i] = np.setdiff1d(\n            incoming_edges_G1[i], self_edge_list_G1[i]\n        )\n        outgoing_edges_G1[i] = G1.out_edges([i], \"eid\").numpy()\n        outgoing_edges_G1[i] = np.setdiff1d(\n            outgoing_edges_G1[i], self_edge_list_G1[i]\n        )\n    for i in range(num_G2_nodes):\n        if G2.has_edge_between(i, i):\n            self_edge_list_G2[i] 
= sorted(\n                G2.edge_ids(i, i, return_array=True).numpy()\n            )\n        incoming_edges_G2[i] = G2.in_edges([i], \"eid\").numpy()\n        incoming_edges_G2[i] = np.setdiff1d(\n            incoming_edges_G2[i], self_edge_list_G2[i]\n        )\n        outgoing_edges_G2[i] = G2.out_edges([i], \"eid\").numpy()\n        outgoing_edges_G2[i] = np.setdiff1d(\n            outgoing_edges_G2[i], self_edge_list_G2[i]\n        )\n\n    selected_deletion_self_G1 = [\n        G1_edge_deletion_cost[self_edge_list_G1[i]]\n        for i in range(G1.num_nodes())\n    ]\n    selected_insertion_self_G2 = [\n        G2_edge_insertion_cost[self_edge_list_G2[i]]\n        for i in range(G2.num_nodes())\n    ]\n\n    selected_deletion_incoming_G1 = [\n        G1_edge_deletion_cost[incoming_edges_G1[i]]\n        for i in range(G1.num_nodes())\n    ]\n    selected_insertion_incoming_G2 = [\n        G2_edge_insertion_cost[incoming_edges_G2[i]]\n        for i in range(G2.num_nodes())\n    ]\n\n    selected_deletion_outgoing_G1 = [\n        G1_edge_deletion_cost[outgoing_edges_G1[i]]\n        for i in range(G1.num_nodes())\n    ]\n    selected_insertion_outgoing_G2 = [\n        G2_edge_insertion_cost[outgoing_edges_G2[i]]\n        for i in range(G2.num_nodes())\n    ]\n\n    selected_deletion_G1 = [\n        G1_edge_deletion_cost[\n            np.concatenate(\n                (\n                    self_edge_list_G1[i],\n                    incoming_edges_G1[i],\n                    outgoing_edges_G1[i],\n                )\n            )\n        ]\n        for i in range(G1.num_nodes())\n    ]\n    selected_insertion_G2 = [\n        G2_edge_insertion_cost[\n            np.concatenate(\n                (\n                    self_edge_list_G2[i],\n                    incoming_edges_G2[i],\n                    outgoing_edges_G2[i],\n                )\n            )\n        ]\n        for i in range(G2.num_nodes())\n    ]\n\n    cost_G1 = np.array(\n        [\n           
 (G1_node_deletion_cost[i] + selected_deletion_G1[i].sum() / 2)\n            for i in range(num_G1_nodes)\n        ]\n    )\n    cost_G2 = np.array(\n        [\n            (G2_node_insertion_cost[i] + selected_insertion_G2[i].sum() / 2)\n            for i in range(num_G2_nodes)\n        ]\n    )\n\n    for i in range(num_G1_nodes):\n        for j in range(num_G2_nodes):\n            c1_self = deepcopy(selected_deletion_self_G1[i])\n            c2_self = deepcopy(selected_insertion_self_G2[j])\n            c1_incoming = deepcopy(selected_deletion_incoming_G1[i])\n            c2_incoming = deepcopy(selected_insertion_incoming_G2[j])\n            c1_outgoing = deepcopy(selected_deletion_outgoing_G1[i])\n            c2_outgoing = deepcopy(selected_insertion_outgoing_G2[j])\n\n            for k, a in enumerate(self_edge_list_G1[i]):\n                for l, b in enumerate(self_edge_list_G2[j]):\n                    c1_self[k] = min(\n                        c1_self[k], edge_substitution_cost[a, b] / 2\n                    )\n                    c2_self[l] = min(\n                        c2_self[l], edge_substitution_cost[a, b] / 2\n                    )\n\n            for k, a in enumerate(incoming_edges_G1[i]):\n                for l, b in enumerate(incoming_edges_G2[j]):\n                    c1_incoming[k] = min(\n                        c1_incoming[k], edge_substitution_cost[a, b] / 2\n                    )\n                    c2_incoming[l] = min(\n                        c2_incoming[l], edge_substitution_cost[a, b] / 2\n                    )\n\n            for k, a in enumerate(outgoing_edges_G1[i]):\n                for l, b in enumerate(outgoing_edges_G2[j]):\n                    c1_outgoing[k] = min(\n                        c1_outgoing[k], edge_substitution_cost[a, b] / 2\n                    )\n                    c2_outgoing[l] = min(\n                        c2_outgoing[l], edge_substitution_cost[a, b] / 2\n                    )\n\n            
edge_hausdorff_lower_bound = 0.0\n\n            if len(selected_deletion_G1[i]) > len(selected_insertion_G2[j]):\n                idx = np.argpartition(\n                    selected_deletion_G1[i],\n                    (\n                        len(selected_deletion_G1[i])\n                        - len(selected_insertion_G2[j])\n                    ),\n                )\n                edge_hausdorff_lower_bound = selected_deletion_G1[i][\n                    idx[\n                        : (\n                            len(selected_deletion_G1[i])\n                            - len(selected_insertion_G2[j])\n                        )\n                    ]\n                ].sum()\n            elif len(selected_deletion_G1[i]) < len(selected_insertion_G2[j]):\n                idx = np.argpartition(\n                    selected_insertion_G2[j],\n                    (\n                        len(selected_insertion_G2[j])\n                        - len(selected_deletion_G1[i])\n                    ),\n                )\n                edge_hausdorff_lower_bound = selected_insertion_G2[j][\n                    idx[\n                        : (\n                            len(selected_insertion_G2[j])\n                            - len(selected_deletion_G1[i])\n                        )\n                    ]\n                ].sum()\n\n            sc_cost = 0.5 * (\n                node_substitution_cost[i, j]\n                + 0.5\n                * max(\n                    c1_self.sum()\n                    + c2_self.sum()\n                    + c1_incoming.sum()\n                    + c2_incoming.sum()\n                    + c1_outgoing.sum()\n                    + c2_outgoing.sum(),\n                    edge_hausdorff_lower_bound,\n                )\n            )\n\n            if cost_G1[i] > sc_cost:\n                cost_G1[i] = sc_cost\n            if cost_G2[j] > sc_cost:\n                cost_G2[j] = sc_cost\n\n    graph_hausdorff_lower_bound = 
0.0\n    if num_G1_nodes > num_G2_nodes:\n        idx = np.argpartition(\n            G1_node_deletion_cost, (num_G1_nodes - num_G2_nodes)\n        )\n        graph_hausdorff_lower_bound = G1_node_deletion_cost[\n            idx[: (num_G1_nodes - num_G2_nodes)]\n        ].sum()\n    elif num_G1_nodes < num_G2_nodes:\n        idx = np.argpartition(\n            G2_node_insertion_cost, (num_G2_nodes - num_G1_nodes)\n        )\n        graph_hausdorff_lower_bound = G2_node_insertion_cost[\n            idx[: (num_G2_nodes - num_G1_nodes)]\n        ].sum()\n\n    graph_hausdorff_cost = max(\n        graph_hausdorff_lower_bound, cost_G1.sum() + cost_G2.sum()\n    )\n    return graph_hausdorff_cost\n\n\ndef a_star_search(G1, G2, cost_matrix_nodes, cost_matrix_edges, max_beam_size):\n    # A-star traversal\n    open_list = []\n    # Create first nodes in the A-star search tree, matching node 0 of G1 with all possibilities (each node of G2, and deletion)\n    matched_cost = 0.0\n    matched_nodes = ([], [])\n    # No nodes matched in the beginning\n    matched_edges = ([], [])\n    # No edges matched in the beginning\n    unprocessed_nodes_G1 = [\n        i for i in range(G1.num_nodes())\n    ]  # No nodes matched in the beginning\n    unprocessed_nodes_G2 = [\n        i for i in range(G2.num_nodes())\n    ]  # No nodes matched in the beginning\n    unprocessed_edges_G1 = [\n        i for i in range(G1.num_edges())\n    ]  # No edges matched in the beginning\n    unprocessed_edges_G2 = [\n        i for i in range(G2.num_edges())\n    ]  # No edges matched in the beginning\n\n    for i in range(len(unprocessed_nodes_G2)):\n        tree_node = search_tree_node(\n            G1,\n            G2,\n            matched_cost,\n            matched_nodes,\n            matched_edges,\n            unprocessed_nodes_G1[0],\n            unprocessed_nodes_G2[i],\n            unprocessed_nodes_G1,\n            unprocessed_nodes_G2,\n            unprocessed_edges_G1,\n            
unprocessed_edges_G2,\n            cost_matrix_nodes,\n            cost_matrix_edges,\n        )\n        # Insert into open-list, implemented as a heap\n\n        heappush(open_list, tree_node)\n\n    # Consider node deletion\n    tree_node = search_tree_node(\n        G1,\n        G2,\n        matched_cost,\n        matched_nodes,\n        matched_edges,\n        unprocessed_nodes_G1[0],\n        None,\n        unprocessed_nodes_G1,\n        unprocessed_nodes_G2,\n        unprocessed_edges_G1,\n        unprocessed_edges_G2,\n        cost_matrix_nodes,\n        cost_matrix_edges,\n    )\n    # Insert into open-list, implemented as a heap\n    heappush(open_list, tree_node)\n\n    while len(open_list) > 0:\n        # TODO: Create a node that processes multi node insertion deletion in one search node,\n        # as opposed in multiple search nodes here\n        parent_tree_node = heappop(open_list)\n        matched_cost = parent_tree_node.matched_cost\n        matched_nodes = parent_tree_node.matched_nodes\n        matched_edges = parent_tree_node.matched_edges\n        unprocessed_nodes_G1 = parent_tree_node.unprocessed_nodes_G1\n        unprocessed_nodes_G2 = parent_tree_node.unprocessed_nodes_G2\n        unprocessed_edges_G1 = parent_tree_node.unprocessed_edges_G1\n        unprocessed_edges_G2 = parent_tree_node.unprocessed_edges_G2\n\n        if len(unprocessed_nodes_G1) == 0 and len(unprocessed_nodes_G2) == 0:\n            return (matched_cost, matched_nodes, matched_edges)\n        elif len(unprocessed_nodes_G1) > 0:\n            for i in range(len(unprocessed_nodes_G2)):\n                tree_node = search_tree_node(\n                    G1,\n                    G2,\n                    matched_cost,\n                    matched_nodes,\n                    matched_edges,\n                    unprocessed_nodes_G1[0],\n                    unprocessed_nodes_G2[i],\n                    unprocessed_nodes_G1,\n                    unprocessed_nodes_G2,\n             
       unprocessed_edges_G1,\n                    unprocessed_edges_G2,\n                    cost_matrix_nodes,\n                    cost_matrix_edges,\n                )\n                # Insert into open-list, implemented as a heap\n                heappush(open_list, tree_node)\n\n            # Consider node deletion\n            tree_node = search_tree_node(\n                G1,\n                G2,\n                matched_cost,\n                matched_nodes,\n                matched_edges,\n                unprocessed_nodes_G1[0],\n                None,\n                unprocessed_nodes_G1,\n                unprocessed_nodes_G2,\n                unprocessed_edges_G1,\n                unprocessed_edges_G2,\n                cost_matrix_nodes,\n                cost_matrix_edges,\n            )\n            # Insert into open-list, implemented as a heap\n            heappush(open_list, tree_node)\n\n        elif len(unprocessed_nodes_G2) > 0:\n            for i in range(len(unprocessed_nodes_G2)):\n                tree_node = search_tree_node(\n                    G1,\n                    G2,\n                    matched_cost,\n                    matched_nodes,\n                    matched_edges,\n                    None,\n                    unprocessed_nodes_G2[i],\n                    unprocessed_nodes_G1,\n                    unprocessed_nodes_G2,\n                    unprocessed_edges_G1,\n                    unprocessed_edges_G2,\n                    cost_matrix_nodes,\n                    cost_matrix_edges,\n                )\n                # Insert into open-list, implemented as a heap\n                heappush(open_list, tree_node)\n\n        # Retain the top-k elements in open-list iff algorithm is beam\n        if max_beam_size > 0 and len(open_list) > max_beam_size:\n            open_list = nsmallest(max_beam_size, open_list)\n            heapify(open_list)\n\n    return None\n\n\ndef get_sorted_mapping(mapping_tuple, len1, len2):\n    # Get 
sorted mapping of nodes/edges\n    result_0 = [None] * len1\n    result_1 = [None] * len2\n    for i in range(len(mapping_tuple[0])):\n        if mapping_tuple[0][i] is not None and mapping_tuple[1][i] is not None:\n            result_0[mapping_tuple[0][i]] = mapping_tuple[1][i]\n            result_1[mapping_tuple[1][i]] = mapping_tuple[0][i]\n    return (result_0, result_1)\n\n\ndef graph_edit_distance(\n    G1,\n    G2,\n    node_substitution_cost=None,\n    edge_substitution_cost=None,\n    G1_node_deletion_cost=None,\n    G2_node_insertion_cost=None,\n    G1_edge_deletion_cost=None,\n    G2_edge_insertion_cost=None,\n    algorithm=\"bipartite\",\n    max_beam_size=100,\n):\n    \"\"\"Returns GED (graph edit distance) between DGLGraphs G1 and G2.\n\n\n    Parameters\n    ----------\n    G1, G2: DGLGraphs\n\n    node_substitution_cost, edge_substitution_cost : 2D numpy arrays\n        node_substitution_cost[i,j] is the cost of substitution node i of G1 with node j of G2,\n        similar definition for edge_substitution_cost. If None, default cost of 0 is used.\n\n    G1_node_deletion_cost, G1_edge_deletion_cost : 1D numpy arrays\n        G1_node_deletion_cost[i] is the cost of deletion of node i of G1,\n        similar definition for G1_edge_deletion_cost. If None, default cost of 1 is used.\n\n    G2_node_insertion_cost, G2_edge_insertion_cost : 1D numpy arrays\n        G2_node_insertion_cost[i] is the cost of insertion of node i of G2,\n        similar definition for G2_edge_insertion_cost. If None, default cost of 1 is used.\n\n    algorithm : string\n        Algorithm to use to calculate the edit distance.\n        For now, 4 algorithms are supported\n        i) astar: Calculates exact GED using A* graph traversal algorithm,\n        the heuristic used is the one proposed in (Riesen and Bunke, 2009) [1].\n        ii) beam: Calculates approximate GED using A* graph traversal algorithm,\n        with a maximum number of nodes in the open list. 
[2]\n        iii) bipartite (default): Calculates approximate GED using linear assignment on the nodes,\n        with the JV (Jonker-Volgenant) algorithm. [3]\n        iv) hausdorff: Approximation of graph edit distance based on Hausdorff matching [4].\n\n    max_beam_size : int\n        Maximum number of nodes in the open list, in case the algorithm is 'beam'.\n\n\n    Returns\n    -------\n    A tuple of three objects: (edit_distance, node_mapping, edge_mapping)\n    edit_distance is the calculated edit distance (float)\n    node_mapping is a tuple of size two, containing the node assignments of the two graphs respectively\n    e.g., node_mapping[0][i] is the node mapping of node i of graph G1 (None means that the node is deleted)\n    Similar definition for the edge_mapping\n\n    For 'hausdorff', node_mapping and edge_mapping are returned as None, as this approximation does not return a unique edit path\n\n    Examples\n    --------\n    >>> src1 = [0, 1, 2, 3, 4, 5];\n    >>> dst1 = [1, 2, 3, 4, 5, 6];\n    >>> src2 = [0, 1, 3, 4, 5];\n    >>> dst2 = [1, 2, 4, 5, 6];\n\n    >>> G1 = dgl.DGLGraph((src1, dst1))\n    >>> G2 = dgl.DGLGraph((src2, dst2))\n    >>> distance, node_mapping, edge_mapping = graph_edit_distance(G1, G1, algorithm='astar')\n    >>> print(distance)\n    0.0\n    >>> distance, node_mapping, edge_mapping = graph_edit_distance(G1, G2, algorithm='astar')\n    >>> print(distance)\n    1.0\n\n    References\n    ----------\n    [1] Riesen, Kaspar, Stefan Fankhauser, and Horst Bunke.\n    \"Speeding Up Graph Edit Distance Computation with a Bipartite Heuristic.\"\n    MLG. 2007.\n    [2] Neuhaus, Michel, Kaspar Riesen, and Horst Bunke.\n    \"Fast suboptimal algorithms for the computation of graph edit distance.\"\n    Joint IAPR International Workshops on Statistical Techniques in Pattern Recognition (SPR)\n    and Structural and Syntactic Pattern Recognition (SSPR). 
2006.\n    [3] Fankhauser, Stefan, Kaspar Riesen, and Horst Bunke.\n    \"Speeding up graph edit distance computation through fast bipartite matching.\"\n    International Workshop on Graph-Based Representations in Pattern Recognition. 2011.\n    [4] Fischer, Andreas, et al. \"A hausdorff heuristic for efficient computation of graph edit distance.\"\n    Joint IAPR International Workshops on Statistical Techniques in Pattern Recognition (SPR)\n    and Structural and Syntactic Pattern Recognition (SSPR). 2014.\n\n    \"\"\"\n    # Handle corner cases\n    if G1 is None and G2 is None:\n        return (0.0, ([], []), ([], []))\n    elif G1 is None:\n        edit_cost = 0.0\n\n    # Validate\n    if algorithm != \"beam\":\n        max_beam_size = -1\n    (\n        node_substitution_cost,\n        edge_substitution_cost,\n        G1_node_deletion_cost,\n        G1_edge_deletion_cost,\n        G2_node_insertion_cost,\n        G2_edge_insertion_cost,\n    ) = validate_cost_functions(\n        G1,\n        G2,\n        node_substitution_cost,\n        edge_substitution_cost,\n        G1_node_deletion_cost,\n        G1_edge_deletion_cost,\n        G2_node_insertion_cost,\n        G2_edge_insertion_cost,\n    )\n\n    # cost matrices for LAP solution\n    cost_matrix_nodes, cost_matrix_edges = construct_cost_functions(\n        G1,\n        G2,\n        node_substitution_cost,\n        edge_substitution_cost,\n        G1_node_deletion_cost,\n        G1_edge_deletion_cost,\n        G2_node_insertion_cost,\n        G2_edge_insertion_cost,\n    )\n\n    if algorithm == \"astar\" or algorithm == \"beam\":\n        (matched_cost, matched_nodes, matched_edges) = a_star_search(\n            G1, G2, cost_matrix_nodes, cost_matrix_edges, max_beam_size\n        )\n        return (\n            matched_cost,\n            get_sorted_mapping(matched_nodes, G1.num_nodes(), G2.num_nodes()),\n            get_sorted_mapping(matched_edges, G1.num_edges(), G2.num_edges()),\n        )\n\n    
elif algorithm == \"hausdorff\":\n        hausdorff_cost = hausdorff_matching(\n            G1,\n            G2,\n            node_substitution_cost,\n            edge_substitution_cost,\n            G1_node_deletion_cost,\n            G1_edge_deletion_cost,\n            G2_node_insertion_cost,\n            G2_edge_insertion_cost,\n        )\n\n        return (hausdorff_cost, None, None)\n\n    else:\n        cost_matrix = contextual_cost_matrix_construction(\n            G1,\n            G2,\n            node_substitution_cost,\n            edge_substitution_cost,\n            G1_node_deletion_cost,\n            G1_edge_deletion_cost,\n            G2_node_insertion_cost,\n            G2_edge_insertion_cost,\n        )\n        # Match the nodes as per the LAP solution\n        row_ind, col_ind, _ = lapjv(cost_matrix)\n\n        (\n            matched_cost,\n            matched_nodes,\n            matched_edges,\n        ) = edit_cost_from_node_matching(\n            G1, G2, cost_matrix_nodes, cost_matrix_edges, row_ind\n        )\n\n        return (\n            matched_cost,\n            get_sorted_mapping(matched_nodes, G1.num_nodes(), G2.num_nodes()),\n            get_sorted_mapping(matched_edges, G1.num_edges(), G2.num_edges()),\n        )\n"
  },
  {
    "path": "examples/pytorch/graphsage/README.md",
    "content": "Inductive Representation Learning on Large Graphs (GraphSAGE)\n============\n\n- Paper link: [http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf](http://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf)\n- Author's code repo: [https://github.com/williamleif/graphsage-simple](https://github.com/williamleif/graphsage-simple)\n\nFor advanced usages, including training with multi-gpu/multi-node, and PyTorch Lightning, etc., more examples can be found in [advanced](https://github.com/dmlc/dgl/tree/master/examples/pytorch/graphsage/advanced) and [dist](https://github.com/dmlc/dgl/tree/master/examples/pytorch/graphsage/dist) directory.\n\nRequirements\n------------\n\n```bash\npip install requests torchmetrics==0.11.4 ogb\n```\n\nHow to run\n-------\n\n### Full graph training\n\nRun with following (available dataset: \"cora\", \"citeseer\", \"pubmed\")\n```bash\npython3 train_full.py --dataset cora --gpu 0    # full graph\n```\n\nResults:\n```\n* cora: ~0.8330\n* citeseer: ~0.7110\n* pubmed: ~0.7830\n```\n\n### Minibatch training for node classification\n\nTrain w/ mini-batch sampling in mixed mode (CPU+GPU) for node classification on \"ogbn-products\"\n\n```bash\npython3 node_classification.py\n```\n\nResults:\n```\nTest Accuracy: 0.7632\n```\n\n### PyTorch Lightning for node classification\n\nTrain w/ mini-batch sampling for node classification with PyTorch Lightning on OGB-products. It requires PyTorch Lightning 2.0.1. It works with both single GPU and multiple GPUs:\n\n```bash\npython3 lightning/node_classification.py\n```\n\n### Minibatch training for link prediction\n\nTrain w/ mini-batch sampling for link prediction on OGB-citation2:\n\n```bash\npython3 link_pred.py\n```\n\nResults (10 epochs):\n```\nTest MRR: 0.7386\n```\n"
  },
  {
    "path": "examples/pytorch/graphsage/advanced/README.md",
    "content": "More Examples for Training GraphSAGE\n============================\n\n### Training with PyTorch Lightning\n\nWe provide minibatch training scripts with PyTorch Lightning in `train_lightning_unsupervised.py`.\n\nRequires `pytorch_lightning` and `torchmetrics`.\n\n```bash\npython3 train_lightning_unsupervised.py\n```\n"
  },
  {
    "path": "examples/pytorch/graphsage/advanced/model.py",
    "content": "import dgl\nimport dgl.nn as dglnn\nimport sklearn.linear_model as lm\nimport sklearn.metrics as skm\nimport torch as th\nimport torch.functional as F\nimport torch.nn as nn\nimport tqdm\n\n\nclass SAGE(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.init(in_feats, n_hidden, n_classes, n_layers, activation, dropout)\n\n    def init(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        if n_layers > 1:\n            self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n            for i in range(1, n_layers - 1):\n                self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        else:\n            self.layers.append(dglnn.SAGEConv(in_feats, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, x, device, batch_size, num_workers):\n        \"\"\"\n        Inference with the GraphSAGE model on full neighbors (i.e. 
without neighbor sampling).\n        g : the entire graph.\n        x : the input of the entire node set.\n\n        The inference code is written in a fashion that it could handle any number of nodes and\n        layers.\n        \"\"\"\n        # During inference with sampling, multi-layer blocks are very inefficient because\n        # lots of computations in the first few layers are repeated.\n        # Therefore, we compute the representation of all nodes layer by layer.  The nodes\n        # on each layer are of course split into batches.\n        # TODO: can we standardize this?\n        for l, layer in enumerate(self.layers):\n            y = th.zeros(\n                g.num_nodes(),\n                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,\n            )\n\n            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n            dataloader = dgl.dataloading.DataLoader(\n                g,\n                th.arange(g.num_nodes()).to(g.device),\n                sampler,\n                device=device if num_workers == 0 else None,\n                batch_size=batch_size,\n                shuffle=False,\n                drop_last=False,\n                num_workers=num_workers,\n            )\n\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                block = blocks[0]\n\n                block = block.int().to(device)\n                h = x[input_nodes].to(device)\n                h = layer(block, h)\n                if l != len(self.layers) - 1:\n                    h = self.activation(h)\n                    h = self.dropout(h)\n\n                y[output_nodes] = h.cpu()\n\n            x = y\n        return y\n\n\ndef compute_acc_unsupervised(emb, labels, train_nids, val_nids, test_nids):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    emb = emb.cpu().numpy()\n    labels = labels.cpu().numpy()\n    train_nids = train_nids.cpu().numpy()\n    train_labels 
= labels[train_nids]\n    val_nids = val_nids.cpu().numpy()\n    val_labels = labels[val_nids]\n    test_nids = test_nids.cpu().numpy()\n    test_labels = labels[test_nids]\n\n    emb = (emb - emb.mean(0, keepdims=True)) / emb.std(0, keepdims=True)\n\n    lr = lm.LogisticRegression(multi_class=\"multinomial\", max_iter=10000)\n    lr.fit(emb[train_nids], train_labels)\n\n    pred = lr.predict(emb)\n    f1_micro_eval = skm.f1_score(val_labels, pred[val_nids], average=\"micro\")\n    f1_micro_test = skm.f1_score(test_labels, pred[test_nids], average=\"micro\")\n    return f1_micro_eval, f1_micro_test\n"
  },
  {
    "path": "examples/pytorch/graphsage/advanced/negative_sampler.py",
    "content": "import dgl\nimport torch as th\n\n\nclass NegativeSampler(object):\n    def __init__(self, g, k, neg_share=False, device=None):\n        if device is None:\n            device = g.device\n        self.weights = g.in_degrees().float().to(device) ** 0.75\n        self.k = k\n        self.neg_share = neg_share\n\n    def __call__(self, g, eids):\n        src, _ = g.find_edges(eids)\n        n = len(src)\n        if self.neg_share and n % self.k == 0:\n            dst = self.weights.multinomial(n, replacement=True)\n            dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()\n        else:\n            dst = self.weights.multinomial(n * self.k, replacement=True)\n        src = src.repeat_interleave(self.k)\n        return src, dst\n"
  },
  {
    "path": "examples/pytorch/graphsage/advanced/train_lightning_unsupervised.py",
    "content": "import argparse\nimport glob\nimport os\nimport sys\nimport time\n\nimport dgl\nimport dgl.function as fn\nimport dgl.nn.pytorch as dglnn\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom model import compute_acc_unsupervised as compute_acc, SAGE\n\nfrom negative_sampler import NegativeSampler\nfrom pytorch_lightning import LightningDataModule, LightningModule, Trainer\n\nfrom pytorch_lightning.callbacks import Callback, ModelCheckpoint\n\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\"))\nfrom load_graph import inductive_split, load_ogb, load_reddit\n\n\nclass CrossEntropyLoss(nn.Module):\n    def forward(self, block_outputs, pos_graph, neg_graph):\n        with pos_graph.local_scope():\n            pos_graph.ndata[\"h\"] = block_outputs\n            pos_graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"score\"))\n            pos_score = pos_graph.edata[\"score\"]\n        with neg_graph.local_scope():\n            neg_graph.ndata[\"h\"] = block_outputs\n            neg_graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"score\"))\n            neg_score = neg_graph.edata[\"score\"]\n\n        score = th.cat([pos_score, neg_score])\n        label = th.cat(\n            [th.ones_like(pos_score), th.zeros_like(neg_score)]\n        ).long()\n        loss = F.binary_cross_entropy_with_logits(score, label.float())\n        return loss\n\n\nclass SAGELightning(LightningModule):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, lr\n    ):\n        super().__init__()\n        self.save_hyperparameters()\n        self.module = SAGE(\n            in_feats, n_hidden, n_classes, n_layers, activation, dropout\n        )\n        self.lr = lr\n        self.loss_fcn = CrossEntropyLoss()\n\n    def training_step(self, batch, batch_idx):\n        input_nodes, pos_graph, neg_graph, mfgs = batch\n        mfgs = 
[mfg.int().to(device) for mfg in mfgs]\n        pos_graph = pos_graph.to(device)\n        neg_graph = neg_graph.to(device)\n        batch_inputs = mfgs[0].srcdata[\"features\"]\n        batch_labels = mfgs[-1].dstdata[\"labels\"]\n        batch_pred = self.module(mfgs, batch_inputs)\n        loss = self.loss_fcn(batch_pred, pos_graph, neg_graph)\n        self.log(\n            \"train_loss\", loss, prog_bar=True, on_step=False, on_epoch=True\n        )\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        input_nodes, output_nodes, mfgs = batch\n        mfgs = [mfg.int().to(device) for mfg in mfgs]\n        batch_inputs = mfgs[0].srcdata[\"features\"]\n        batch_labels = mfgs[-1].dstdata[\"labels\"]\n        batch_pred = self.module(mfgs, batch_inputs)\n        return batch_pred\n\n    def configure_optimizers(self):\n        optimizer = th.optim.Adam(self.parameters(), lr=self.lr)\n        return optimizer\n\n\nclass DataModule(LightningDataModule):\n    def __init__(\n        self,\n        dataset_name,\n        data_cpu=False,\n        fan_out=[10, 25],\n        device=th.device(\"cpu\"),\n        batch_size=1000,\n        num_workers=4,\n    ):\n        super().__init__()\n        if dataset_name == \"reddit\":\n            g, n_classes = load_reddit()\n            n_edges = g.num_edges()\n            reverse_eids = th.cat(\n                [th.arange(n_edges // 2, n_edges), th.arange(0, n_edges // 2)]\n            )\n        elif dataset_name == \"ogbn-products\":\n            g, n_classes = load_ogb(\"ogbn-products\")\n            n_edges = g.num_edges()\n            # The reverse edge of edge 0 in OGB products dataset is 1.\n            # The reverse edge of edge 2 is 3.  
So on so forth.\n            reverse_eids = th.arange(n_edges) ^ 1\n        else:\n            raise ValueError(\"unknown dataset\")\n\n        train_nid = th.nonzero(g.ndata[\"train_mask\"], as_tuple=True)[0]\n        val_nid = th.nonzero(g.ndata[\"val_mask\"], as_tuple=True)[0]\n        test_nid = th.nonzero(\n            ~(g.ndata[\"train_mask\"] | g.ndata[\"val_mask\"]), as_tuple=True\n        )[0]\n\n        sampler = dgl.dataloading.MultiLayerNeighborSampler(\n            [int(_) for _ in fan_out]\n        )\n\n        dataloader_device = th.device(\"cpu\")\n        if not data_cpu:\n            train_nid = train_nid.to(device)\n            val_nid = val_nid.to(device)\n            test_nid = test_nid.to(device)\n            g = g.formats([\"csc\"])\n            g = g.to(device)\n            dataloader_device = device\n\n        self.g = g\n        self.train_nid, self.val_nid, self.test_nid = (\n            train_nid,\n            val_nid,\n            test_nid,\n        )\n        self.sampler = sampler\n        self.device = dataloader_device\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n        self.in_feats = g.ndata[\"features\"].shape[1]\n        self.n_classes = n_classes\n        self.reverse_eids = reverse_eids\n\n    def train_dataloader(self):\n        sampler = dgl.dataloading.as_edge_prediction_sampler(\n            self.sampler,\n            exclude=\"reverse_id\",\n            reverse_eids=self.reverse_eids,\n            negative_sampler=NegativeSampler(\n                self.g, args.num_negs, args.neg_share\n            ),\n        )\n        return dgl.dataloading.DataLoader(\n            self.g,\n            np.arange(self.g.num_edges()),\n            sampler,\n            device=self.device,\n            batch_size=self.batch_size,\n            shuffle=True,\n            drop_last=False,\n            num_workers=self.num_workers,\n        )\n\n    def val_dataloader(self):\n        # Note that the 
validation data loader is a DataLoader\n        # as we want to evaluate all the node embeddings.\n        return dgl.dataloading.DataLoader(\n            self.g,\n            np.arange(self.g.num_nodes()),\n            self.sampler,\n            device=self.device,\n            batch_size=self.batch_size,\n            shuffle=False,\n            drop_last=False,\n            num_workers=self.num_workers,\n        )\n\n\nclass UnsupervisedClassification(Callback):\n    def on_validation_epoch_start(self, trainer, pl_module):\n        self.val_outputs = []\n\n    def on_validation_batch_end(\n        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx\n    ):\n        self.val_outputs.append(outputs)\n\n    def on_validation_epoch_end(self, trainer, pl_module):\n        node_emb = th.cat(self.val_outputs, 0)\n        g = trainer.datamodule.g\n        labels = g.ndata[\"labels\"]\n        f1_micro, f1_macro = compute_acc(\n            node_emb,\n            labels,\n            trainer.datamodule.train_nid,\n            trainer.datamodule.val_nid,\n            trainer.datamodule.test_nid,\n        )\n        pl_module.log(\"val_f1_micro\", f1_micro)\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"multi-gpu training\")\n    argparser.add_argument(\"--gpu\", type=int, default=0)\n    argparser.add_argument(\"--dataset\", type=str, default=\"reddit\")\n    argparser.add_argument(\"--num-epochs\", type=int, default=20)\n    argparser.add_argument(\"--num-hidden\", type=int, default=16)\n    argparser.add_argument(\"--num-layers\", type=int, default=2)\n    argparser.add_argument(\"--num-negs\", type=int, default=1)\n    argparser.add_argument(\n        \"--neg-share\",\n        default=False,\n        action=\"store_true\",\n        help=\"sharing neg nodes for positive nodes\",\n    )\n    argparser.add_argument(\"--fan-out\", type=str, default=\"10,25\")\n    argparser.add_argument(\"--batch-size\", type=int, 
default=10000)\n    argparser.add_argument(\"--log-every\", type=int, default=20)\n    argparser.add_argument(\"--eval-every\", type=int, default=1000)\n    argparser.add_argument(\"--lr\", type=float, default=0.003)\n    argparser.add_argument(\"--dropout\", type=float, default=0.5)\n    argparser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=0,\n        help=\"Number of sampling processes. Use 0 for no extra process.\",\n    )\n    args = argparser.parse_args()\n\n    if args.gpu >= 0:\n        device = th.device(\"cuda:%d\" % args.gpu)\n    else:\n        device = th.device(\"cpu\")\n\n    datamodule = DataModule(\n        args.dataset,\n        True,\n        [int(_) for _ in args.fan_out.split(\",\")],\n        device,\n        args.batch_size,\n        args.num_workers,\n    )\n    model = SAGELightning(\n        datamodule.in_feats,\n        args.num_hidden,\n        datamodule.n_classes,\n        args.num_layers,\n        F.relu,\n        args.dropout,\n        args.lr,\n    )\n\n    # Train\n    unsupervised_callback = UnsupervisedClassification()\n    checkpoint_callback = ModelCheckpoint(monitor=\"val_f1_micro\", save_top_k=1)\n    trainer = Trainer(\n        gpus=[args.gpu] if args.gpu != -1 else None,\n        max_epochs=args.num_epochs,\n        val_check_interval=1000,\n        callbacks=[checkpoint_callback, unsupervised_callback],\n        num_sanity_val_steps=0,\n    )\n    trainer.fit(model, datamodule=datamodule)\n"
  },
  {
    "path": "examples/pytorch/graphsage/lightning/node_classification.py",
    "content": "import glob\nimport os\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport torchmetrics.functional as MF\nimport tqdm\nfrom ogb.nodeproppred import DglNodePropPredDataset\nfrom pytorch_lightning import LightningDataModule, LightningModule, Trainer\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom torchmetrics import Accuracy\n\n\nclass SAGE(LightningModule):\n    def __init__(self, in_feats, n_hidden, n_classes):\n        super().__init__()\n        self.save_hyperparameters()\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(0.5)\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.train_acc = Accuracy(task=\"multiclass\", num_classes=n_classes)\n        self.val_acc = Accuracy(task=\"multiclass\", num_classes=n_classes)\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = F.relu(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, device, batch_size, num_workers, buffer_device=None):\n        # The difference between this inference function and the one in the official\n        # example is that the intermediate results can also benefit from prefetching.\n        g.ndata[\"h\"] = g.ndata[\"feat\"]\n        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(\n            1, prefetch_node_feats=[\"h\"]\n        )\n        dataloader = dgl.dataloading.DataLoader(\n            g,\n            torch.arange(g.num_nodes()).to(g.device),\n            sampler,\n       
     device=device,\n            batch_size=batch_size,\n            shuffle=False,\n            drop_last=False,\n            num_workers=num_workers,\n            persistent_workers=(num_workers > 0),\n        )\n        if buffer_device is None:\n            buffer_device = device\n\n        for l, layer in enumerate(self.layers):\n            y = torch.zeros(\n                g.num_nodes(),\n                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,\n                device=buffer_device,\n            )\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                x = blocks[0].srcdata[\"h\"]\n                h = layer(blocks[0], x)\n                if l != len(self.layers) - 1:\n                    h = F.relu(h)\n                    h = self.dropout(h)\n                y[output_nodes] = h.to(buffer_device)\n            g.ndata[\"h\"] = y\n        return y\n\n    def training_step(self, batch, batch_idx):\n        input_nodes, output_nodes, blocks = batch\n        x = blocks[0].srcdata[\"feat\"]\n        y = blocks[-1].dstdata[\"label\"]\n        y_hat = self(blocks, x)\n        loss = F.cross_entropy(y_hat, y)\n        self.train_acc(torch.argmax(y_hat, 1), y)\n        self.log(\n            \"train_acc\",\n            self.train_acc,\n            prog_bar=True,\n            on_step=True,\n            on_epoch=False,\n        )\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        input_nodes, output_nodes, blocks = batch\n        x = blocks[0].srcdata[\"feat\"]\n        y = blocks[-1].dstdata[\"label\"]\n        y_hat = self(blocks, x)\n        self.val_acc(torch.argmax(y_hat, 1), y)\n        self.log(\n            \"val_acc\",\n            self.val_acc,\n            prog_bar=True,\n            on_step=True,\n            on_epoch=True,\n            sync_dist=True,\n        )\n\n    def configure_optimizers(self):\n        optimizer = torch.optim.Adam(\n            
self.parameters(), lr=0.001, weight_decay=5e-4\n        )\n        return optimizer\n\n\nclass DataModule(LightningDataModule):\n    def __init__(\n        self, graph, train_idx, val_idx, fanouts, batch_size, n_classes\n    ):\n        super().__init__()\n\n        sampler = dgl.dataloading.NeighborSampler(\n            fanouts, prefetch_node_feats=[\"feat\"], prefetch_labels=[\"label\"]\n        )\n\n        self.g = graph\n        self.train_idx, self.val_idx = train_idx, val_idx\n        self.sampler = sampler\n        self.batch_size = batch_size\n        self.in_feats = graph.ndata[\"feat\"].shape[1]\n        self.n_classes = n_classes\n\n    def train_dataloader(self):\n        return dgl.dataloading.DataLoader(\n            self.g,\n            self.train_idx.to(\"cuda\"),\n            self.sampler,\n            device=\"cuda\",\n            batch_size=self.batch_size,\n            shuffle=True,\n            drop_last=False,\n            # For CPU sampling, set num_workers to nonzero and use_uva=False\n            # Set use_ddp to False for single GPU.\n            num_workers=0,\n            use_uva=True,\n            use_ddp=True,\n        )\n\n    def val_dataloader(self):\n        return dgl.dataloading.DataLoader(\n            self.g,\n            self.val_idx.to(\"cuda\"),\n            self.sampler,\n            device=\"cuda\",\n            batch_size=self.batch_size,\n            shuffle=True,\n            drop_last=False,\n            num_workers=0,\n            use_uva=True,\n        )\n\n\nif __name__ == \"__main__\":\n    dataset = DglNodePropPredDataset(\"ogbn-products\")\n    graph, labels = dataset[0]\n    graph.ndata[\"label\"] = labels.squeeze()\n    graph.create_formats_()\n    split_idx = dataset.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        split_idx[\"train\"],\n        split_idx[\"valid\"],\n        split_idx[\"test\"],\n    )\n    datamodule = DataModule(\n        graph, train_idx, val_idx, [15, 10, 5], 1024, 
dataset.num_classes\n    )\n    model = SAGE(datamodule.in_feats, 256, datamodule.n_classes)\n\n    # Train\n    checkpoint_callback = ModelCheckpoint(monitor=\"val_acc\", save_top_k=1)\n    # Use this for single GPU\n    # trainer = Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10,\n    #                   callbacks=[checkpoint_callback])\n    trainer = Trainer(\n        accelerator=\"gpu\",\n        devices=[0, 1, 2, 3],\n        max_epochs=10,\n        callbacks=[checkpoint_callback],\n        strategy=\"ddp_spawn\",\n    )\n    trainer.fit(model, datamodule=datamodule)\n\n    # Test\n    dirs = glob.glob(\"./lightning_logs/*\")\n    version = max([int(os.path.split(x)[-1].split(\"_\")[-1]) for x in dirs])\n    logdir = \"./lightning_logs/version_%d\" % version\n    print(\"Evaluating model in\", logdir)\n    ckpt = glob.glob(os.path.join(logdir, \"checkpoints\", \"*\"))[0]\n\n    model = SAGE.load_from_checkpoint(\n        checkpoint_path=ckpt, hparams_file=os.path.join(logdir, \"hparams.yaml\")\n    ).to(\"cuda\")\n    with torch.no_grad():\n        pred = model.inference(graph, \"cuda\", 4096, 12, graph.device)\n        pred = pred[test_idx]\n        label = graph.ndata[\"label\"][test_idx]\n        acc = MF.accuracy(\n            pred, label, task=\"multiclass\", num_classes=datamodule.n_classes\n        )\n    print(\"Test accuracy:\", acc)\n"
  },
  {
    "path": "examples/pytorch/graphsage/link_pred.py",
    "content": "import argparse\n\nimport dgl\nimport dgl.nn as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\nfrom dgl.dataloading import (\n    as_edge_prediction_sampler,\n    DataLoader,\n    MultiLayerFullNeighborSampler,\n    negative_sampler,\n    NeighborSampler,\n)\nfrom ogb.linkproppred import DglLinkPropPredDataset, Evaluator\n\n\ndef to_bidirected_with_reverse_mapping(g):\n    \"\"\"Makes a graph bidirectional, and returns a mapping array ``mapping`` where ``mapping[i]``\n    is the reverse edge of edge ID ``i``. Does not work with graphs that have self-loops.\n    \"\"\"\n    g_simple, mapping = dgl.to_simple(\n        dgl.add_reverse_edges(g), return_counts=\"count\", writeback_mapping=True\n    )\n    c = g_simple.edata[\"count\"]\n    num_edges = g.num_edges()\n    mapping_offset = torch.zeros(\n        g_simple.num_edges() + 1, dtype=g_simple.idtype\n    )\n    mapping_offset[1:] = c.cumsum(0)\n    idx = mapping.argsort()\n    idx_uniq = idx[mapping_offset[:-1]]\n    reverse_idx = torch.where(\n        idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges\n    )\n    reverse_mapping = mapping[reverse_idx]\n    # sanity check\n    src1, dst1 = g_simple.edges()\n    src2, dst2 = g_simple.find_edges(reverse_mapping)\n    assert torch.equal(src1, dst2)\n    assert torch.equal(src2, dst1)\n    return g_simple, reverse_mapping\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hid_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # three-layer GraphSAGE-mean\n        self.layers.append(dglnn.SAGEConv(in_size, hid_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, \"mean\"))\n        self.hid_size = hid_size\n        self.predictor = nn.Sequential(\n            nn.Linear(hid_size, hid_size),\n            nn.ReLU(),\n            nn.Linear(hid_size, 
hid_size),\n            nn.ReLU(),\n            nn.Linear(hid_size, 1),\n        )\n\n    def forward(self, pair_graph, neg_pair_graph, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = F.relu(h)\n        pos_src, pos_dst = pair_graph.edges()\n        neg_src, neg_dst = neg_pair_graph.edges()\n        h_pos = self.predictor(h[pos_src] * h[pos_dst])\n        h_neg = self.predictor(h[neg_src] * h[neg_dst])\n        return h_pos, h_neg\n\n    def inference(self, g, device, batch_size):\n        \"\"\"Layer-wise inference algorithm to compute GNN node embeddings.\"\"\"\n        feat = g.ndata[\"feat\"]\n        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=[\"feat\"])\n        dataloader = DataLoader(\n            g,\n            torch.arange(g.num_nodes()).to(g.device),\n            sampler,\n            device=device,\n            batch_size=batch_size,\n            shuffle=False,\n            drop_last=False,\n            num_workers=0,\n        )\n        buffer_device = torch.device(\"cpu\")\n        pin_memory = buffer_device != device\n        for l, layer in enumerate(self.layers):\n            y = torch.empty(\n                g.num_nodes(),\n                self.hid_size,\n                device=buffer_device,\n                pin_memory=pin_memory,\n            )\n            feat = feat.to(device)\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(\n                dataloader, desc=\"Inference\"\n            ):\n                x = feat[input_nodes]\n                h = layer(blocks[0], x)\n                if l != len(self.layers) - 1:\n                    h = F.relu(h)\n                y[output_nodes] = h.to(buffer_device)\n            feat = y\n        return y\n\n\ndef compute_mrr(\n    model, evaluator, node_emb, src, dst, neg_dst, device, batch_size=500\n):\n    \"\"\"Compute Mean 
Reciprocal Rank (MRR) in batches.\"\"\"\n    rr = torch.zeros(src.shape[0])\n    for start in tqdm.trange(0, src.shape[0], batch_size, desc=\"Evaluate\"):\n        end = min(start + batch_size, src.shape[0])\n        all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)\n        h_src = node_emb[src[start:end]][:, None, :].to(device)\n        h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)\n        pred = model.predictor(h_src * h_dst).squeeze(-1)\n        input_dict = {\"y_pred_pos\": pred[:, 0], \"y_pred_neg\": pred[:, 1:]}\n        rr[start:end] = evaluator.eval(input_dict)[\"mrr_list\"]\n    return rr.mean()\n\n\ndef evaluate(device, graph, edge_split, model, batch_size):\n    model.eval()\n    evaluator = Evaluator(name=\"ogbl-citation2\")\n    with torch.no_grad():\n        node_emb = model.inference(graph, device, batch_size)\n        results = []\n        for split in [\"valid\", \"test\"]:\n            src = edge_split[split][\"source_node\"].to(node_emb.device)\n            dst = edge_split[split][\"target_node\"].to(node_emb.device)\n            neg_dst = edge_split[split][\"target_node_neg\"].to(node_emb.device)\n            results.append(\n                compute_mrr(\n                    model, evaluator, node_emb, src, dst, neg_dst, device\n                )\n            )\n    return results\n\n\ndef train(args, device, g, reverse_eids, seed_edges, model):\n    # create sampler & dataloader\n    sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=[\"feat\"])\n    sampler = as_edge_prediction_sampler(\n        sampler,\n        exclude=\"reverse_id\",\n        reverse_eids=reverse_eids,\n        negative_sampler=negative_sampler.Uniform(1),\n    )\n    use_uva = args.mode == \"mixed\"\n    dataloader = DataLoader(\n        g,\n        seed_edges,\n        sampler,\n        device=device,\n        batch_size=512,\n        shuffle=True,\n        drop_last=False,\n        num_workers=0,\n        
use_uva=use_uva,\n    )\n    opt = torch.optim.Adam(model.parameters(), lr=0.0005)\n    for epoch in range(10):\n        model.train()\n        total_loss = 0\n        for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(\n            dataloader\n        ):\n            x = blocks[0].srcdata[\"feat\"]\n            pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)\n            score = torch.cat([pos_score, neg_score])\n            pos_label = torch.ones_like(pos_score)\n            neg_label = torch.zeros_like(neg_score)\n            labels = torch.cat([pos_label, neg_label])\n            loss = F.binary_cross_entropy_with_logits(score, labels)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n            total_loss += loss.item()\n            if (it + 1) == 1000:\n                break\n        print(\"Epoch {:05d} | Loss {:.4f}\".format(epoch, total_loss / (it + 1)))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--mode\",\n        default=\"mixed\",\n        choices=[\"cpu\", \"mixed\", \"puregpu\"],\n        help=\"Training mode. 
'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, \"\n        \"'puregpu' for pure-GPU training.\",\n    )\n    args = parser.parse_args()\n    if not torch.cuda.is_available():\n        args.mode = \"cpu\"\n    print(f\"Training in {args.mode} mode.\")\n\n    # load and preprocess dataset\n    print(\"Loading data\")\n    dataset = DglLinkPropPredDataset(\"ogbl-citation2\")\n    g = dataset[0]\n    g = g.to(\"cuda\" if args.mode == \"puregpu\" else \"cpu\")\n    device = torch.device(\"cpu\" if args.mode == \"cpu\" else \"cuda\")\n    g, reverse_eids = to_bidirected_with_reverse_mapping(g)\n    reverse_eids = reverse_eids.to(device)\n    seed_edges = torch.arange(g.num_edges()).to(device)\n    edge_split = dataset.get_edge_split()\n\n    # create GraphSAGE model\n    in_size = g.ndata[\"feat\"].shape[1]\n    model = SAGE(in_size, 256).to(device)\n\n    # model training\n    print(\"Training...\")\n    train(args, device, g, reverse_eids, seed_edges, model)\n\n    # validate/test the model\n    print(\"Validation/Testing...\")\n    valid_mrr, test_mrr = evaluate(\n        device, g, edge_split, model, batch_size=1000\n    )\n    print(\n        \"Validation MRR {:.4f}, Test MRR {:.4f}\".format(\n            valid_mrr.item(), test_mrr.item()\n        )\n    )\n"
  },
  {
    "path": "examples/pytorch/graphsage/load_graph.py",
    "content": "import dgl\nimport torch as th\n\n\ndef load_reddit(self_loop=True):\n    from dgl.data import RedditDataset\n\n    # load reddit data\n    data = RedditDataset(self_loop=self_loop)\n    g = data[0]\n    g.ndata[\"features\"] = g.ndata.pop(\"feat\")\n    g.ndata[\"labels\"] = g.ndata.pop(\"label\")\n    return g, data.num_classes\n\n\ndef load_ogb(name, root=\"dataset\"):\n    from ogb.nodeproppred import DglNodePropPredDataset\n\n    print(\"load\", name)\n    data = DglNodePropPredDataset(name=name, root=root)\n    print(\"finish loading\", name)\n    splitted_idx = data.get_idx_split()\n    graph, labels = data[0]\n    labels = labels[:, 0]\n\n    graph.ndata[\"features\"] = graph.ndata.pop(\"feat\")\n    graph.ndata[\"labels\"] = labels\n    in_feats = graph.ndata[\"features\"].shape[1]\n    num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))\n\n    # Find the node IDs in the training, validation, and test set.\n    train_nid, val_nid, test_nid = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    train_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)\n    train_mask[train_nid] = True\n    val_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)\n    val_mask[val_nid] = True\n    test_mask = th.zeros((graph.num_nodes(),), dtype=th.bool)\n    test_mask[test_nid] = True\n    graph.ndata[\"train_mask\"] = train_mask\n    graph.ndata[\"val_mask\"] = val_mask\n    graph.ndata[\"test_mask\"] = test_mask\n    print(\"finish constructing\", name)\n    return graph, num_labels\n\n\ndef inductive_split(g):\n    \"\"\"Split the graph into training graph, validation graph, and test graph by training\n    and validation masks.  Suitable for inductive models.\"\"\"\n    train_g = g.subgraph(g.ndata[\"train_mask\"])\n    val_g = g.subgraph(g.ndata[\"train_mask\"] | g.ndata[\"val_mask\"])\n    test_g = g\n    return train_g, val_g, test_g\n"
  },
  {
    "path": "examples/pytorch/graphsage/node_classification.py",
    "content": "import argparse\n\nimport dgl\nimport dgl.nn as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nimport tqdm\nfrom dgl.data import AsNodePredDataset\nfrom dgl.dataloading import (\n    DataLoader,\n    MultiLayerFullNeighborSampler,\n    NeighborSampler,\n)\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hid_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # three-layer GraphSAGE-mean\n        self.layers.append(dglnn.SAGEConv(in_size, hid_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hid_size, out_size, \"mean\"))\n        self.dropout = nn.Dropout(0.5)\n        self.hid_size = hid_size\n        self.out_size = out_size\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = F.relu(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, device, batch_size):\n        \"\"\"Conduct layer-wise inference to get all the node embeddings.\"\"\"\n        feat = g.ndata[\"feat\"]\n        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=[\"feat\"])\n        dataloader = DataLoader(\n            g,\n            torch.arange(g.num_nodes()).to(g.device),\n            sampler,\n            device=device,\n            batch_size=batch_size,\n            shuffle=False,\n            drop_last=False,\n            num_workers=0,\n        )\n        buffer_device = torch.device(\"cpu\")\n        pin_memory = buffer_device != device\n\n        for l, layer in enumerate(self.layers):\n            y = torch.empty(\n                g.num_nodes(),\n                self.hid_size if l != len(self.layers) - 1 
else self.out_size,\n                dtype=feat.dtype,\n                device=buffer_device,\n                pin_memory=pin_memory,\n            )\n            feat = feat.to(device)\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                x = feat[input_nodes]\n                h = layer(blocks[0], x)  # len(blocks) = 1\n                if l != len(self.layers) - 1:\n                    h = F.relu(h)\n                    h = self.dropout(h)\n                # by design, our output nodes are contiguous\n                y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)\n            feat = y\n        return y\n\n\ndef evaluate(model, graph, dataloader, num_classes):\n    model.eval()\n    ys = []\n    y_hats = []\n    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):\n        with torch.no_grad():\n            x = blocks[0].srcdata[\"feat\"]\n            ys.append(blocks[-1].dstdata[\"label\"])\n            y_hats.append(model(blocks, x))\n    return MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(ys),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\ndef layerwise_infer(device, graph, nid, model, num_classes, batch_size):\n    model.eval()\n    with torch.no_grad():\n        pred = model.inference(\n            graph, device, batch_size\n        )  # pred in buffer_device\n        pred = pred[nid]\n        label = graph.ndata[\"label\"][nid].to(pred.device)\n        return MF.accuracy(\n            pred, label, task=\"multiclass\", num_classes=num_classes\n        )\n\n\ndef train(args, device, g, dataset, model, num_classes):\n    # create sampler & dataloader\n    train_idx = dataset.train_idx.to(device)\n    val_idx = dataset.val_idx.to(device)\n    sampler = NeighborSampler(\n        [10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]\n        prefetch_node_feats=[\"feat\"],\n        prefetch_labels=[\"label\"],\n    )\n    use_uva = args.mode == 
\"mixed\"\n    train_dataloader = DataLoader(\n        g,\n        train_idx,\n        sampler,\n        device=device,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=0,\n        use_uva=use_uva,\n    )\n\n    val_dataloader = DataLoader(\n        g,\n        val_idx,\n        sampler,\n        device=device,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=0,\n        use_uva=use_uva,\n    )\n\n    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)\n\n    for epoch in range(10):\n        model.train()\n        total_loss = 0\n        for it, (input_nodes, output_nodes, blocks) in enumerate(\n            train_dataloader\n        ):\n            x = blocks[0].srcdata[\"feat\"]\n            y = blocks[-1].dstdata[\"label\"]\n            y_hat = model(blocks, x)\n            loss = F.cross_entropy(y_hat, y)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n            total_loss += loss.item()\n        acc = evaluate(model, g, val_dataloader, num_classes)\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} \".format(\n                epoch, total_loss / (it + 1), acc.item()\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--mode\",\n        default=\"mixed\",\n        choices=[\"cpu\", \"mixed\", \"puregpu\"],\n        help=\"Training mode. 
'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, \"\n        \"'puregpu' for pure-GPU training.\",\n    )\n    parser.add_argument(\n        \"--dt\",\n        type=str,\n        default=\"float\",\n        help=\"data type(float, bfloat16)\",\n    )\n    args = parser.parse_args()\n    if not torch.cuda.is_available():\n        args.mode = \"cpu\"\n    print(f\"Training in {args.mode} mode.\")\n\n    # load and preprocess dataset\n    print(\"Loading data\")\n    dataset = AsNodePredDataset(DglNodePropPredDataset(\"ogbn-products\"))\n    g = dataset[0]\n    g = g.to(\"cuda\" if args.mode == \"puregpu\" else \"cpu\")\n    num_classes = dataset.num_classes\n    device = torch.device(\"cpu\" if args.mode == \"cpu\" else \"cuda\")\n\n    # create GraphSAGE model\n    in_size = g.ndata[\"feat\"].shape[1]\n    out_size = dataset.num_classes\n    model = SAGE(in_size, 256, out_size).to(device)\n\n    # convert model and graph to bfloat16 if needed\n    if args.dt == \"bfloat16\":\n        g = dgl.to_bfloat16(g)\n        model = model.to(dtype=torch.bfloat16)\n\n    # model training\n    print(\"Training...\")\n    train(args, device, g, dataset, model, num_classes)\n\n    # test the model\n    print(\"Testing...\")\n    acc = layerwise_infer(\n        device, g, dataset.test_idx, model, num_classes, batch_size=4096\n    )\n    print(\"Test Accuracy {:.4f}\".format(acc.item()))\n"
  },
  {
    "path": "examples/pytorch/graphsage/train_full.py",
    "content": "import argparse\n\nimport dgl\nimport dgl.nn as dglnn\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import AddSelfLoop\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hid_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # two-layer GraphSAGE-mean\n        self.layers.append(dglnn.SAGEConv(in_size, hid_size, \"gcn\"))\n        self.layers.append(dglnn.SAGEConv(hid_size, out_size, \"gcn\"))\n        self.dropout = nn.Dropout(0.5)\n\n    def forward(self, graph, x):\n        h = self.dropout(x)\n        for l, layer in enumerate(self.layers):\n            h = layer(graph, h)\n            if l != len(self.layers) - 1:\n                h = F.relu(h)\n                h = self.dropout(h)\n        return h\n\n\ndef evaluate(g, features, labels, mask, model):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef train(g, features, labels, masks, model):\n    # define train/val samples, loss function and optimizer\n    train_mask, val_mask = masks\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n\n    # training loop\n    for epoch in range(200):\n        model.train()\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        acc = evaluate(g, features, labels, val_mask, model)\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} \".format(\n                epoch, loss.item(), acc\n            )\n        )\n\n\nif __name__ 
== \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GraphSAGE\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"cora\",\n        help=\"Dataset name ('cora', 'citeseer', 'pubmed')\",\n    )\n    parser.add_argument(\n        \"--dt\",\n        type=str,\n        default=\"float\",\n        help=\"data type(float, bfloat16)\",\n    )\n    args = parser.parse_args()\n    print(f\"Training with DGL built-in GraphSage module\")\n\n    # load and preprocess dataset\n    transform = (\n        AddSelfLoop()\n    )  # by default, it will first remove self-loops to prevent duplication\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset(transform=transform)\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset(transform=transform)\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset(transform=transform)\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n    g = data[0]\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    g = g.int().to(device)\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    masks = g.ndata[\"train_mask\"], g.ndata[\"val_mask\"]\n\n    # create GraphSAGE model\n    in_size = features.shape[1]\n    out_size = data.num_classes\n    model = SAGE(in_size, 16, out_size).to(device)\n\n    # convert model and graph to bfloat16 if needed\n    if args.dt == \"bfloat16\":\n        g = dgl.to_bfloat16(g)\n        features = features.to(dtype=torch.bfloat16)\n        model = model.to(dtype=torch.bfloat16)\n\n    # model training\n    print(\"Training...\")\n    train(g, features, labels, masks, model)\n\n    # test the model\n    print(\"Testing...\")\n    acc = evaluate(g, features, labels, g.ndata[\"test_mask\"], model)\n    print(\"Test accuracy {:.4f}\".format(acc))\n"
  },
  {
    "path": "examples/pytorch/graphsaint/README.md",
    "content": "# GraphSAINT\n\nThis DGL example implements the paper: GraphSAINT: Graph Sampling Based Inductive Learning Method.\n\nPaper link: https://arxiv.org/abs/1907.04931\n\nAuthor's code: https://github.com/GraphSAINT/GraphSAINT\n\nContributor: Jiahang Li ([@ljh1064126026](https://github.com/ljh1064126026))  Tang Liu ([@lt610](https://github.com/lt610))\n\nFor built-in GraphSAINT subgraph samplers with online sampling, use `dgl.dataloading.SAINTSampler`.\n\n## Dependencies\n\n- Python 3.7.10\n- PyTorch 1.8.1\n- NumPy 1.19.2\n- Scikit-learn 0.23.2\n- DGL 0.7.1\n\n## Dataset\n\nAll datasets used are provided by Author's [code](https://github.com/GraphSAINT/GraphSAINT). They are available in [Google Drive](https://drive.google.com/drive/folders/1zycmmDES39zVlbVCYs88JTJ1Wm5FbfLz) (alternatively, [Baidu Wangpan (code: f1ao)](https://pan.baidu.com/s/1SOb0SiSAXavwAcNqkttwcg#list/path=%2F)). Dataset summary(\"m\" stands for multi-label binary classification, and \"s\" for single-label.):\n| Dataset | Nodes | Edges | Degree | Feature | Classes |\n| :-: | :-: | :-: | :-: | :-: | :-: |\n| PPI | 14,755 | 225,270 | 15 | 50 | 121(m) |\n| Flickr | 89,250 | 899,756 | 10 | 500 | 7(s) |\n| Reddit | 232,965 | 11,606,919 | 50 | 602 | 41(s) |\n| Yelp | 716,847 | 6,977,410 | 10 | 300 | 100 (m) |\n| Amazon | 1,598,960 | 132,169,734 | 83 | 200 | 107 (m) |\n\nNote that the PPI dataset here is different from DGL's built-in variant.\n\n## Config\n\n- The config file is `config.py`, which contains best configs for experiments below.\n- Please refer to `sampler.py` to see explanations of some key parameters.\n\n### Parameters\n\n| **aggr**                                                     | **arch**                                                     | **dataset**                                                  | **dropout**                                                  |\n| ------------------------------------------------------------ | 
------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |\n| define how to aggregate embeddings of each node and its neighbors' embeddings ,which can be 'concat', 'mean'. The neighbors' embeddings are generated based on GCN | e.g. '1-1-0', means there're three layers, the first and the second layer employ message passing on the graph, then aggregate the embeddings of each node  and its neighbors. The last layer only updates each node's embedding. The message passing  mechanism comes from GCN | the name of dataset, which can be 'ppi', 'flickr', 'reddit', 'yelp', 'amazon' | the dropout of model used in train_sampling.py               |\n| **edge_budget**                                              | **gpu**                                                      | **length**                                                   | **log_dir**                                                  |\n| the expected number of edges in each subgraph, which is specified in the paper | -1 means cpu, otherwise 'cuda:gpu', e.g. 
if gpu=0, use 'cuda:0' | the length of each random walk                               | the directory storing logs                                   |\n| **lr**                                                       | **n_epochs**                                                 | **n_hidden**                                                 | **no_batch_norm**                                            |\n| learning rate                                                | training epochs                                              | hidden dimension                                             | True if do NOT employ batch normalization in each layer      |\n| **node_budget**                                              | **num_subg**                                                 | **num_roots**                                                | **sampler**                                                  |\n| the expected number of nodes in each subgraph, which is specified in the paper | the expected number of pre_sampled subgraphs                 | the number of roots to generate random walks                 | specify which sampler to use, which can be 'node', 'edge', 'rw', corresponding to node, edge, random walk sampler |\n| **use_val**                                                  | **val_every**                                                | **num_workers_sampler**                                      | **num_subg_sampler**                                            |\n| True if use best model to test, which is stored by earlystop mechanism | validate per 'val_every' epochs                              | number of workers (processes) specified for internal dataloader in SAINTSampler, which is to pre-sample subgraphs | the maximal number of pre-sampled subgraphs                  |\n| **batch_size_sampler**                                          | **num_workers**                                              |                                                   
           |                                                              |\n| batch size of internal dataloader in SAINTSampler            | number of workers (processes) specified for external dataloader in train_sampling.py, which is to sample subgraphs in training phase |                                                              |                                                              |\n\n\n\n\n## Minibatch training\n\nRun with following:\n```bash\npython train_sampling.py --task $task $online\n# online sampling: e.g. python train_sampling.py --task ppi_n --online\n# offline sampling: e.g. python train_sampling.py --task flickr_e\n```\n\n- `$task` includes `ppi_n, ppi_e, ppi_rw, flickr_n, flickr_e, flickr_rw, reddit_n, reddit_e, reddit_rw, yelp_n, yelp_e, yelp_rw, amazon_n, amazon_e, amazon_rw`. For example, `ppi_n` represents running experiments on dataset `ppi` with `node sampler`\n- If `$online` is `--online`,  we sample subgraphs on-the-fly in the training phase, while discarding pre-sampled subgraphs. If `$online` is empty, we utilize pre-sampled subgraphs in the training phase.\n\n## Experiments\n\n* Paper: results from the paper\n* Running: results from experiments with the authors' code\n* DGL: results from experiments with the DGL example. The experiment config comes from `config.py`. You can modify parameters in the `config.py` to see different performance of different setup.\n\n> Note that we implement offline sampling and online sampling in training phase. Offline sampling means all subgraphs utilized in training phase come from pre-sampled subgraphs. 
Online sampling means we discard all pre-sampled subgraphs and re-sample new subgraphs in training phase.\n\n> Note that the sampling method in the pre-sampling phase must be offline sampling.\n\n### F1-micro\n\n#### Random node sampler\n\n| Method | PPI | Flickr | Reddit | Yelp | Amazon |\n| --- | --- | --- | --- | --- | --- |\n| Paper | 0.960±0.001 | 0.507±0.001 | 0.962±0.001 | 0.641±0.000 | 0.782±0.004 |\n| Running | 0.9628 | 0.5077 | 0.9622 | 0.6393 | 0.7695 |\n| DGL_offline | 0.9715      | 0.5024 | 0.9645 | 0.6457 | 0.8051 |\n| DGL_online | 0.9730 | 0.5071 | 0.9645 | 0.6444 | 0.8014 |\n\n#### Random edge sampler\n\n| Method      | PPI         | Flickr      | Reddit | Yelp | Amazon |\n| --- | --- | --- | --- | --- | --- |\n| Paper | 0.981±0.007 | 0.510±0.002 | 0.966±0.001 | 0.653±0.003 | 0.807±0.001 |\n| Running | 0.9810 | 0.5066 | 0.9656 | 0.6531 | 0.8071 |\n| DGL_offline | 0.9817      | 0.5077 | 0.9655 | 0.6530 | 0.8034 |\n| DGL_online | 0.9815 | 0.5041 | 0.9653 | 0.6516 | 0.7756 |\n\n#### Random walk sampler\n\n| Method      | PPI         | Flickr      | Reddit      | Yelp        | Amazon      |\n| --- | --- | --- | --- | --- | --- |\n| Paper | 0.981±0.004 | 0.511±0.001 | 0.966±0.001 | 0.653±0.003 | 0.815±0.001 |\n| Running | 0.9812 | 0.5104 | 0.9648      | 0.6527      | 0.8131      |\n| DGL_offline | 0.9833      | 0.5027 | 0.9582      | 0.6514      | 0.8178   |\n| DGL_online | 0.9820 | 0.5110 | 0.9572      | 0.6508      | 0.8157   |\n\n### Sampling time\n\n- Here, sampling time includes the time spent pre-sampling subgraphs and calculating normalization coefficients at the beginning.\n\n#### Random node sampler\n\n| Method      | PPI  | Flickr | Reddit | Yelp | Amazon |\n| --- | --- | --- | --- | --- | --- |\n| Running | 1.46 | 3.49 | 19 | 59.01 | 978.62 |\n| DGL | 2.51 | 1.12 | 27.32 | 60.15 | 929.24 |\n\n#### Random edge sampler\n\n| Method      | PPI  | Flickr | Reddit | Yelp | Amazon |\n| --- | --- | --- | --- | --- | --- |\n| Running | 1.4 | 3.18 | 
13.88 | 39.02 |  |\n| DGL | 3.04 | 1.87 | 52.01 | 48.38 |  |\n\n#### Random walk sampler\n\n| Method      | PPI  | Flickr | Reddit | Yelp | Amazon |\n| --- | --- | --- | --- | --- | --- |\n| Running | 1.7 | 3.82 | 16.97 | 43.25 | 355.68 |\n| DGL | 3.05 | 2.13 | 11.01 | 22.23 | 151.84 |\n\n## Test std of sampling and normalization time\n\n- We ran each experiment 10 times to measure the average and standard deviation of the sampling and normalization time. Here we only measure time, without training the model to completion. Moreover, for efficient testing, the hardware and config employed here are not the same as in the experiments above, so the sampling time might be a bit different from that above. But we keep the environment consistent in all experiments below.\n\n> The only config values that differ from the section above are `num_workers_sampler`, `batch_size_sampler` and `num_workers`, which affect only the sampling speed. Other parameters are kept consistent across the two sections, so the model's performance is not affected.\n\n> Each value is reported as (average, std).\n\n### Random node sampler\n\n| Method                    | PPI             | Flickr       | Reddit        | Yelp          | Amazon          |\n| ------------------------- | --------------- | ------------ | ------------- | ------------- | --------------- |\n| DGL_Sampling(std)         | 2.618, 0.004    | 3.017, 0.507 | 35.356, 2.363 | 69.913, 6.3   | 888.025, 16.004 |\n| DGL_Normalization(std)    | Small to ignore | 0.008, 0.004 | 0.26, 0.047   | 0.189, 0.0288 | 2.443, 0.124    |\n|                           |                 |              |               |               |                 |\n| author_Sampling(std)      | 0.788, 0.661    | 0.728, 0.367 | 8.931, 3.155  | 27.818, 1.384 | 295.597, 4.928  |\n| author_Normalization(std) | 0.665, 0.565    | 4.981, 2.952 | 17.231, 7.116 | 47.449, 2.794 | 279.241, 17.615 |\n\n### Random edge sampler\n\n| Method                    | PPI             | Flickr  
     | Reddit        | Yelp          | Amazon |\n| ------------------------- | --------------- | ------------ | ------------- | ------------- | ------ |\n| DGL_Sampling(std)         | 3.554, 0.292    | 4.722, 0.245 | 47.09, 2.76   | 75.219, 6.442 |        |\n| DGL_Normalization(std)    | Small to ignore | 0.005, 0.007 | 0.235, 0.026  | 0.193, 0.021  |        |\n|                           |                 |              |               |               |        |\n| author_Sampling(std)      | 0.802, 0.667    | 0.761, 0.387 | 6.058, 2.166  | 13.914, 1.864 |        |\n| author_Normalization(std) | 0.667, 0.570    | 5.180, 3.006 | 15.803, 5.867 | 44.278, 5.853 |        |\n\n### Random walk sampler\n\n| Method                    | PPI             | Flickr       | Reddit        | Yelp          | Amazon          |\n| ------------------------- | --------------- | ------------ | ------------- | ------------- | --------------- |\n| DGL_Sampling(std)         | 3.304, 0.08     | 5.487, 1.294 | 37.041, 2.083 | 39.951, 3.094 | 179.613, 18.881 |\n| DGL_Normalization(std)    | Small to ignore | 0.001, 0.003 | 0.235, 0.026  | 0.185, 0.018  | 3.769, 0.326    |\n|                           |                 |              |               |               |                 |\n| author_Sampling(std)      | 0.924, 0.773    | 1.405, 0.718 | 8.608, 3.093  | 19.113, 1.700 | 217.184, 1.546  |\n| author_Normalization(std) | 0.701, 0.596    | 5.025, 2.954 | 18.198, 7.223 | 45.874, 8.020 | 128.272, 3.170  |\n\n"
  },
  {
    "path": "examples/pytorch/graphsaint/config.py",
    "content": "CONFIG = {\n    \"ppi_n\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-0-1-0\",\n        \"dataset\": \"ppi\",\n        \"dropout\": 0,\n        \"edge_budget\": 4000,\n        \"length\": 2,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 50,\n        \"n_hidden\": 512,\n        \"no_batch_norm\": False,\n        \"node_budget\": 6000,\n        \"num_subg\": 50,\n        \"num_roots\": 3000,\n        \"sampler\": \"node\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 0,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": True,\n    },\n    \"ppi_e\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-0-1-0\",\n        \"dataset\": \"ppi\",\n        \"dropout\": 0.1,\n        \"edge_budget\": 4000,\n        \"length\": 2,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 50,\n        \"n_hidden\": 512,\n        \"no_batch_norm\": False,\n        \"node_budget\": 6000,\n        \"num_subg\": 50,\n        \"num_roots\": 3000,\n        \"sampler\": \"edge\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 0,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": True,\n    },\n    \"ppi_rw\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-0-1-0\",\n        \"dataset\": \"ppi\",\n        \"dropout\": 0.1,\n        \"edge_budget\": 4000,\n        \"length\": 2,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 50,\n        \"n_hidden\": 512,\n        \"no_batch_norm\": False,\n        \"node_budget\": 6000,\n        \"num_subg\": 50,\n        \"num_roots\": 3000,\n        \"sampler\": \"rw\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 0,\n        \"num_subg_sampler\": 10000,\n  
      \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": True,\n    },\n    \"flickr_n\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-1-0\",\n        \"dataset\": \"flickr\",\n        \"dropout\": 0.2,\n        \"edge_budget\": 6000,\n        \"length\": 2,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 50,\n        \"n_hidden\": 256,\n        \"no_batch_norm\": False,\n        \"node_budget\": 8000,\n        \"num_subg\": 25,\n        \"num_roots\": 6000,\n        \"sampler\": \"node\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 0,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": False,\n    },\n    \"flickr_e\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-1-0\",\n        \"dataset\": \"flickr\",\n        \"dropout\": 0.2,\n        \"edge_budget\": 6000,\n        \"length\": 2,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 50,\n        \"n_hidden\": 256,\n        \"no_batch_norm\": False,\n        \"node_budget\": 8000,\n        \"num_subg\": 25,\n        \"num_roots\": 6000,\n        \"sampler\": \"edge\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 0,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": False,\n    },\n    \"flickr_rw\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-1-0\",\n        \"dataset\": \"flickr\",\n        \"dropout\": 0.2,\n        \"edge_budget\": 6000,\n        \"length\": 2,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 50,\n        \"n_hidden\": 256,\n        \"no_batch_norm\": False,\n        \"node_budget\": 8000,\n        \"num_subg\": 25,\n        \"num_roots\": 6000,\n        \"sampler\": \"rw\",\n        \"use_val\": True,\n        
\"val_every\": 1,\n        \"num_workers_sampler\": 0,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": False,\n    },\n    \"reddit_n\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-0-1-0\",\n        \"dataset\": \"reddit\",\n        \"dropout\": 0.1,\n        \"edge_budget\": 4000,\n        \"length\": 2,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 20,\n        \"n_hidden\": 128,\n        \"no_batch_norm\": False,\n        \"node_budget\": 8000,\n        \"num_subg\": 50,\n        \"num_roots\": 3000,\n        \"sampler\": \"node\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 8,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": True,\n    },\n    \"reddit_e\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-0-1-0\",\n        \"dataset\": \"reddit\",\n        \"dropout\": 0.1,\n        \"edge_budget\": 6000,\n        \"length\": 2,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 20,\n        \"n_hidden\": 128,\n        \"no_batch_norm\": False,\n        \"node_budget\": 8000,\n        \"num_subg\": 50,\n        \"num_roots\": 3000,\n        \"sampler\": \"edge\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 8,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": True,\n    },\n    \"reddit_rw\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-0-1-0\",\n        \"dataset\": \"reddit\",\n        \"dropout\": 0.1,\n        \"edge_budget\": 6000,\n        \"length\": 4,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 10,\n        \"n_hidden\": 128,\n        \"no_batch_norm\": False,\n        \"node_budget\": 8000,\n        \"num_subg\": 50,\n     
   \"num_roots\": 200,\n        \"sampler\": \"rw\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 8,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": True,\n    },\n    \"yelp_n\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-1-0\",\n        \"dataset\": \"yelp\",\n        \"dropout\": 0.1,\n        \"edge_budget\": 6000,\n        \"length\": 4,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 10,\n        \"n_hidden\": 512,\n        \"no_batch_norm\": False,\n        \"node_budget\": 5000,\n        \"num_subg\": 50,\n        \"num_roots\": 200,\n        \"sampler\": \"node\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 8,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": True,\n    },\n    \"yelp_e\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-1-0\",\n        \"dataset\": \"yelp\",\n        \"dropout\": 0.1,\n        \"edge_budget\": 2500,\n        \"length\": 4,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 10,\n        \"n_hidden\": 512,\n        \"no_batch_norm\": False,\n        \"node_budget\": 5000,\n        \"num_subg\": 50,\n        \"num_roots\": 200,\n        \"sampler\": \"edge\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 8,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": True,\n    },\n    \"yelp_rw\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-1-0\",\n        \"dataset\": \"yelp\",\n        \"dropout\": 0.1,\n        \"edge_budget\": 2500,\n        \"length\": 2,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 10,\n        \"n_hidden\": 512,\n        \"no_batch_norm\": 
False,\n        \"node_budget\": 5000,\n        \"num_subg\": 50,\n        \"num_roots\": 1250,\n        \"sampler\": \"rw\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 8,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": True,\n    },\n    \"amazon_n\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-1-0\",\n        \"dataset\": \"amazon\",\n        \"dropout\": 0.1,\n        \"edge_budget\": 2500,\n        \"length\": 4,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 5,\n        \"n_hidden\": 512,\n        \"no_batch_norm\": False,\n        \"node_budget\": 4500,\n        \"num_subg\": 50,\n        \"num_roots\": 200,\n        \"sampler\": \"node\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 4,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": True,\n    },\n    \"amazon_e\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-1-0\",\n        \"dataset\": \"amazon\",\n        \"dropout\": 0.1,\n        \"edge_budget\": 2000,\n        \"gpu\": 0,\n        \"length\": 4,\n        \"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 10,\n        \"n_hidden\": 512,\n        \"no_batch_norm\": False,\n        \"node_budget\": 5000,\n        \"num_subg\": 50,\n        \"num_roots\": 200,\n        \"sampler\": \"edge\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 20,\n        \"num_subg_sampler\": 5000,\n        \"batch_size_sampler\": 50,\n        \"num_workers\": 26,\n        \"full\": True,\n    },\n    \"amazon_rw\": {\n        \"aggr\": \"concat\",\n        \"arch\": \"1-1-0\",\n        \"dataset\": \"amazon\",\n        \"dropout\": 0.1,\n        \"edge_budget\": 2500,\n        \"gpu\": 0,\n        \"length\": 2,\n        
\"log_dir\": \"none\",\n        \"lr\": 0.01,\n        \"n_epochs\": 5,\n        \"n_hidden\": 512,\n        \"no_batch_norm\": False,\n        \"node_budget\": 5000,\n        \"num_subg\": 50,\n        \"num_roots\": 1500,\n        \"sampler\": \"rw\",\n        \"use_val\": True,\n        \"val_every\": 1,\n        \"num_workers_sampler\": 4,\n        \"num_subg_sampler\": 10000,\n        \"batch_size_sampler\": 200,\n        \"num_workers\": 8,\n        \"full\": True,\n    },\n}\n"
  },
  {
    "path": "examples/pytorch/graphsaint/modules.py",
    "content": "import dgl.function as fn\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass GCNLayer(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        order=1,\n        act=None,\n        dropout=0,\n        batch_norm=False,\n        aggr=\"concat\",\n    ):\n        super(GCNLayer, self).__init__()\n        self.lins = nn.ModuleList()\n        self.bias = nn.ParameterList()\n        for _ in range(order + 1):\n            self.lins.append(nn.Linear(in_dim, out_dim, bias=False))\n            self.bias.append(nn.Parameter(th.zeros(out_dim)))\n\n        self.order = order\n        self.act = act\n        self.dropout = nn.Dropout(dropout)\n\n        self.batch_norm = batch_norm\n        if batch_norm:\n            self.offset, self.scale = nn.ParameterList(), nn.ParameterList()\n            for _ in range(order + 1):\n                self.offset.append(nn.Parameter(th.zeros(out_dim)))\n                self.scale.append(nn.Parameter(th.ones(out_dim)))\n\n        self.aggr = aggr\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for lin in self.lins:\n            nn.init.xavier_normal_(lin.weight)\n\n    def feat_trans(\n        self, features, idx\n    ):  # linear transformation + activation + batch normalization\n        h = self.lins[idx](features) + self.bias[idx]\n\n        if self.act is not None:\n            h = self.act(h)\n\n        if self.batch_norm:\n            mean = h.mean(dim=1).view(h.shape[0], 1)\n            var = h.var(dim=1, unbiased=False).view(h.shape[0], 1) + 1e-9\n            h = (h - mean) * self.scale[idx] * th.rsqrt(var) + self.offset[idx]\n\n        return h\n\n    def forward(self, graph, features):\n        g = graph.local_var()\n        h_in = self.dropout(features)\n        h_hop = [h_in]\n\n        D_norm = (\n            g.ndata[\"train_D_norm\"]\n            if \"train_D_norm\" in g.ndata\n            else 
g.ndata[\"full_D_norm\"]\n        )\n        for _ in range(self.order):  # forward propagation\n            g.ndata[\"h\"] = h_hop[-1]\n            if \"w\" not in g.edata:\n                g.edata[\"w\"] = th.ones((g.num_edges(),)).to(features.device)\n            g.update_all(fn.u_mul_e(\"h\", \"w\", \"m\"), fn.sum(\"m\", \"h\"))\n            h = g.ndata.pop(\"h\")\n            h = h * D_norm\n            h_hop.append(h)\n\n        h_part = [self.feat_trans(ft, idx) for idx, ft in enumerate(h_hop)]\n        if self.aggr == \"mean\":\n            h_out = h_part[0]\n            for i in range(len(h_part) - 1):\n                h_out = h_out + h_part[i + 1]\n        elif self.aggr == \"concat\":\n            h_out = th.cat(h_part, 1)\n        else:\n            raise NotImplementedError\n\n        return h_out\n\n\nclass GCNNet(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        hid_dim,\n        out_dim,\n        arch=\"1-1-0\",\n        act=F.relu,\n        dropout=0,\n        batch_norm=False,\n        aggr=\"concat\",\n    ):\n        super(GCNNet, self).__init__()\n        self.gcn = nn.ModuleList()\n\n        orders = list(map(int, arch.split(\"-\")))\n        self.gcn.append(\n            GCNLayer(\n                in_dim=in_dim,\n                out_dim=hid_dim,\n                order=orders[0],\n                act=act,\n                dropout=dropout,\n                batch_norm=batch_norm,\n                aggr=aggr,\n            )\n        )\n        pre_out = ((aggr == \"concat\") * orders[0] + 1) * hid_dim\n\n        for i in range(1, len(orders) - 1):\n            self.gcn.append(\n                GCNLayer(\n                    in_dim=pre_out,\n                    out_dim=hid_dim,\n                    order=orders[i],\n                    act=act,\n                    dropout=dropout,\n                    batch_norm=batch_norm,\n                    aggr=aggr,\n                )\n            )\n            pre_out = ((aggr == 
\"concat\") * orders[i] + 1) * hid_dim\n\n        self.gcn.append(\n            GCNLayer(\n                in_dim=pre_out,\n                out_dim=hid_dim,\n                order=orders[-1],\n                act=act,\n                dropout=dropout,\n                batch_norm=batch_norm,\n                aggr=aggr,\n            )\n        )\n        pre_out = ((aggr == \"concat\") * orders[-1] + 1) * hid_dim\n\n        self.out_layer = GCNLayer(\n            in_dim=pre_out,\n            out_dim=out_dim,\n            order=0,\n            act=None,\n            dropout=dropout,\n            batch_norm=False,\n            aggr=aggr,\n        )\n\n    def forward(self, graph):\n        h = graph.ndata[\"feat\"]\n\n        for layer in self.gcn:\n            h = layer(graph, h)\n\n        h = F.normalize(h, p=2, dim=1)\n        h = self.out_layer(graph, h)\n\n        return h\n"
  },
  {
    "path": "examples/pytorch/graphsaint/sampler.py",
    "content": "import math\nimport os\nimport random\nimport time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport scipy\nimport torch as th\nfrom dgl.sampling import pack_traces, random_walk\nfrom torch.utils.data import DataLoader\n\n\n# The base class of sampler\nclass SAINTSampler:\n    \"\"\"\n    Description\n    -----------\n    SAINTSampler implements the sampler described in GraphSAINT. This sampler implements offline sampling in\n    pre-sampling phase as well as fully offline sampling, fully online sampling in training phase.\n    Users can conveniently set param 'online' of the sampler to choose different modes.\n\n    Parameters\n    ----------\n    node_budget : int\n        the expected number of nodes in each subgraph, which is specifically explained in the paper. Actually this\n        param specifies the times of sampling nodes from the original graph with replacement. The meaning of edge_budget\n        is similar to the node_budget.\n    dn : str\n        name of dataset.\n    g : DGLGraph\n        the full graph.\n    train_nid : list\n        ids of training nodes.\n    num_workers_sampler : int\n        number of processes to sample subgraphs in pre-sampling procedure using torch.dataloader.\n    num_subg_sampler : int, optional\n        the max number of subgraphs sampled in pre-sampling phase for computing normalization coefficients in the beginning.\n        Actually this param is used as ``__len__`` of sampler in pre-sampling phase.\n        Please make sure that num_subg_sampler is greater than batch_size_sampler so that we can sample enough subgraphs.\n        Defaults: 10000\n    batch_size_sampler : int, optional\n        the number of subgraphs sampled by each process concurrently in pre-sampling phase.\n        Defaults: 200\n    online : bool, optional\n        If `True`, we employ online sampling in training phase. 
Otherwise employing offline sampling.\n        Defaults: True\n    num_subg : int, optional\n        the expected number of sampled subgraphs in pre-sampling phase.\n        It is actually the 'N' in the original paper. Note that this param is different from the num_subg_sampler.\n        This param is just used to control the number of pre-sampled subgraphs.\n        Defaults: 50\n    full : bool, optional\n        True if the number of subgraphs used in the training phase equals to that of pre-sampled subgraphs, or\n        ``math.ceil(self.train_g.num_nodes() / self.node_budget)``. This formula takes the result of A divided by B as\n        the number of subgraphs used in the training phase, where A is the number of training nodes in the original\n        graph, B is the expected number of nodes in each pre-sampled subgraph. Please refer to the paper to check the\n        details.\n        Defaults: True\n\n    Notes\n    -----\n    For parallelism of pre-sampling, we utilize `torch.DataLoader` to concurrently speed up sampling.\n    The `num_subg_sampler` is the return value of `__len__` in pre-sampling phase. Moreover, the param `batch_size_sampler`\n    determines the batch_size of `torch.DataLoader` in internal pre-sampling part. 
But note that if we wanna pass the\n    SAINTSampler to `torch.DataLoader` for concurrently sampling subgraphs in training phase, we need to specify\n    `batch_size` of `DataLoader`, that is, `batch_size_sampler` is not related to how sampler works in training procedure.\n    \"\"\"\n\n    def __init__(\n        self,\n        node_budget,\n        dn,\n        g,\n        train_nid,\n        num_workers_sampler,\n        num_subg_sampler=10000,\n        batch_size_sampler=200,\n        online=True,\n        num_subg=50,\n        full=True,\n    ):\n        self.g = g.cpu()\n        self.node_budget = node_budget\n        self.train_g: dgl.graph = g.subgraph(train_nid)\n        self.dn, self.num_subg = dn, num_subg\n        self.node_counter = th.zeros((self.train_g.num_nodes(),))\n        self.edge_counter = th.zeros((self.train_g.num_edges(),))\n        self.prob = None\n        self.num_subg_sampler = num_subg_sampler\n        self.batch_size_sampler = batch_size_sampler\n        self.num_workers_sampler = num_workers_sampler\n        self.train = False\n        self.online = online\n        self.full = full\n\n        assert (\n            self.num_subg_sampler >= self.batch_size_sampler\n        ), \"num_subg_sampler should be greater than batch_size_sampler\"\n        graph_fn, norm_fn = self.__generate_fn__()\n\n        if os.path.exists(graph_fn):\n            self.subgraphs = np.load(graph_fn, allow_pickle=True)\n            aggr_norm, loss_norm = np.load(norm_fn, allow_pickle=True)\n        else:\n            os.makedirs(\"./subgraphs/\", exist_ok=True)\n\n            self.subgraphs = []\n            self.N, sampled_nodes = 0, 0\n            # N: the number of pre-sampled subgraphs\n\n            # Employ parallelism to speed up the sampling procedure\n            loader = DataLoader(\n                self,\n                batch_size=self.batch_size_sampler,\n                shuffle=True,\n                num_workers=self.num_workers_sampler,\n          
      collate_fn=self.__collate_fn__,\n                drop_last=False,\n            )\n\n            t = time.perf_counter()\n            for num_nodes, subgraphs_nids, subgraphs_eids in loader:\n                self.subgraphs.extend(subgraphs_nids)\n                sampled_nodes += num_nodes\n\n                _subgraphs, _node_counts = np.unique(\n                    np.concatenate(subgraphs_nids), return_counts=True\n                )\n                sampled_nodes_idx = th.from_numpy(_subgraphs)\n                _node_counts = th.from_numpy(_node_counts)\n                self.node_counter[sampled_nodes_idx] += _node_counts\n\n                _subgraphs_eids, _edge_counts = np.unique(\n                    np.concatenate(subgraphs_eids), return_counts=True\n                )\n                sampled_edges_idx = th.from_numpy(_subgraphs_eids)\n                _edge_counts = th.from_numpy(_edge_counts)\n                self.edge_counter[sampled_edges_idx] += _edge_counts\n\n                self.N += len(subgraphs_nids)  # number of subgraphs\n                if sampled_nodes > self.train_g.num_nodes() * num_subg:\n                    break\n\n            print(f\"Sampling time: [{time.perf_counter() - t:.2f}s]\")\n            np.save(graph_fn, self.subgraphs)\n\n            t = time.perf_counter()\n            aggr_norm, loss_norm = self.__compute_norm__()\n            print(f\"Normalization time: [{time.perf_counter() - t:.2f}s]\")\n            np.save(norm_fn, (aggr_norm, loss_norm))\n\n        self.train_g.ndata[\"l_n\"] = th.Tensor(loss_norm)\n        self.train_g.edata[\"w\"] = th.Tensor(aggr_norm)\n        self.__compute_degree_norm()  # basically normalizing adjacent matrix\n\n        random.shuffle(self.subgraphs)\n        self.__clear__()\n        print(\"The number of subgraphs is: \", len(self.subgraphs))\n\n        self.train = True\n\n    def __len__(self):\n        if self.train is False:\n            return self.num_subg_sampler\n        else:\n     
       if self.full:\n                return len(self.subgraphs)\n            else:\n                return math.ceil(self.train_g.num_nodes() / self.node_budget)\n\n    def __getitem__(self, idx):\n        # Only when sampling subgraphs in training procedure and need to utilize sampled subgraphs and we still\n        # have sampled subgraphs we can fetch a subgraph from sampled subgraphs\n        if self.train:\n            if self.online:\n                subgraph = self.__sample__()\n                return dgl.node_subgraph(self.train_g, subgraph)\n            else:\n                return dgl.node_subgraph(self.train_g, self.subgraphs[idx])\n        else:\n            subgraph_nids = self.__sample__()\n            num_nodes = len(subgraph_nids)\n            subgraph_eids = dgl.node_subgraph(\n                self.train_g, subgraph_nids\n            ).edata[dgl.EID]\n            return num_nodes, subgraph_nids, subgraph_eids\n\n    def __collate_fn__(self, batch):\n        if (\n            self.train\n        ):  # sample only one graph each epoch, batch_size in training phase in 1\n            return batch[0]\n        else:\n            sum_num_nodes = 0\n            subgraphs_nids_list = []\n            subgraphs_eids_list = []\n            for num_nodes, subgraph_nids, subgraph_eids in batch:\n                sum_num_nodes += num_nodes\n                subgraphs_nids_list.append(subgraph_nids)\n                subgraphs_eids_list.append(subgraph_eids)\n            return sum_num_nodes, subgraphs_nids_list, subgraphs_eids_list\n\n    def __clear__(self):\n        self.prob = None\n        self.node_counter = None\n        self.edge_counter = None\n        self.g = None\n\n    def __generate_fn__(self):\n        raise NotImplementedError\n\n    def __compute_norm__(self):\n        self.node_counter[self.node_counter == 0] = 1\n        self.edge_counter[self.edge_counter == 0] = 1\n\n        loss_norm = self.N / self.node_counter / self.train_g.num_nodes()\n\n  
      self.train_g.ndata[\"n_c\"] = self.node_counter\n        self.train_g.edata[\"e_c\"] = self.edge_counter\n        self.train_g.apply_edges(fn.v_div_e(\"n_c\", \"e_c\", \"a_n\"))\n        aggr_norm = self.train_g.edata.pop(\"a_n\")\n\n        self.train_g.ndata.pop(\"n_c\")\n        self.train_g.edata.pop(\"e_c\")\n\n        return aggr_norm.numpy(), loss_norm.numpy()\n\n    def __compute_degree_norm(self):\n        self.train_g.ndata[\n            \"train_D_norm\"\n        ] = 1.0 / self.train_g.in_degrees().float().clamp(min=1).unsqueeze(1)\n        self.g.ndata[\"full_D_norm\"] = 1.0 / self.g.in_degrees().float().clamp(\n            min=1\n        ).unsqueeze(1)\n\n    def __sample__(self):\n        raise NotImplementedError\n\n\nclass SAINTNodeSampler(SAINTSampler):\n    \"\"\"\n    Description\n    -----------\n    GraphSAINT with node sampler.\n\n    Parameters\n    ----------\n    node_budget : int\n        the expected number of nodes in each subgraph, which is specifically explained in the paper.\n    \"\"\"\n\n    def __init__(self, node_budget, **kwargs):\n        self.node_budget = node_budget\n        super(SAINTNodeSampler, self).__init__(\n            node_budget=node_budget, **kwargs\n        )\n\n    def __generate_fn__(self):\n        graph_fn = os.path.join(\n            \"./subgraphs/{}_Node_{}_{}.npy\".format(\n                self.dn, self.node_budget, self.num_subg\n            )\n        )\n        norm_fn = os.path.join(\n            \"./subgraphs/{}_Node_{}_{}_norm.npy\".format(\n                self.dn, self.node_budget, self.num_subg\n            )\n        )\n        return graph_fn, norm_fn\n\n    def __sample__(self):\n        if self.prob is None:\n            self.prob = self.train_g.in_degrees().float().clamp(min=1)\n\n        sampled_nodes = th.multinomial(\n            self.prob, num_samples=self.node_budget, replacement=True\n        ).unique()\n        return sampled_nodes.numpy()\n\n\nclass 
SAINTEdgeSampler(SAINTSampler):\n    \"\"\"\n    Description\n    -----------\n    GraphSAINT with edge sampler.\n\n    Parameters\n    ----------\n    edge_budget : int\n        the expected number of edges in each subgraph, which is specifically explained in the paper.\n    \"\"\"\n\n    def __init__(self, edge_budget, **kwargs):\n        self.edge_budget = edge_budget\n        self.rng = np.random.default_rng()\n\n        super(SAINTEdgeSampler, self).__init__(\n            node_budget=edge_budget * 2, **kwargs\n        )\n\n    def __generate_fn__(self):\n        graph_fn = os.path.join(\n            \"./subgraphs/{}_Edge_{}_{}.npy\".format(\n                self.dn, self.edge_budget, self.num_subg\n            )\n        )\n        norm_fn = os.path.join(\n            \"./subgraphs/{}_Edge_{}_{}_norm.npy\".format(\n                self.dn, self.edge_budget, self.num_subg\n            )\n        )\n        return graph_fn, norm_fn\n\n    # TODO: only sample half edges, then add another half edges\n    # TODO: use numpy to implement cython sampling method\n    def __sample__(self):\n        if self.prob is None:\n            src, dst = self.train_g.edges()\n            src_degrees, dst_degrees = self.train_g.in_degrees(\n                src\n            ).float().clamp(min=1), self.train_g.in_degrees(dst).float().clamp(\n                min=1\n            )\n            prob_mat = 1.0 / src_degrees + 1.0 / dst_degrees\n            prob_mat = scipy.sparse.csr_matrix(\n                (prob_mat.numpy(), (src.numpy(), dst.numpy()))\n            )\n            # The edge probability here only contains that of edges in upper triangle adjacency matrix\n            # Because we assume the graph is undirected, that is, the adjacency matrix is symmetric. 
We only need\n            # to consider half of edges in the graph.\n            self.prob = th.tensor(scipy.sparse.triu(prob_mat).data)\n            self.prob /= self.prob.sum()\n            self.adj_nodes = np.stack(prob_mat.nonzero(), axis=1)\n\n        sampled_edges = np.unique(\n            dgl.random.choice(\n                len(self.prob),\n                size=self.edge_budget,\n                prob=self.prob,\n                replace=False,\n            )\n        )\n        sampled_nodes = np.unique(\n            self.adj_nodes[sampled_edges].flatten()\n        ).astype(\"long\")\n        return sampled_nodes\n\n\nclass SAINTRandomWalkSampler(SAINTSampler):\n    \"\"\"\n    Description\n    -----------\n    GraphSAINT with random walk sampler\n\n    Parameters\n    ----------\n    num_roots : int\n        the number of roots to generate random walks.\n    length : int\n        the length of each random walk.\n\n    \"\"\"\n\n    def __init__(self, num_roots, length, **kwargs):\n        self.num_roots, self.length = num_roots, length\n        super(SAINTRandomWalkSampler, self).__init__(\n            node_budget=num_roots * length, **kwargs\n        )\n\n    def __generate_fn__(self):\n        graph_fn = os.path.join(\n            \"./subgraphs/{}_RW_{}_{}_{}.npy\".format(\n                self.dn, self.num_roots, self.length, self.num_subg\n            )\n        )\n        norm_fn = os.path.join(\n            \"./subgraphs/{}_RW_{}_{}_{}_norm.npy\".format(\n                self.dn, self.num_roots, self.length, self.num_subg\n            )\n        )\n        return graph_fn, norm_fn\n\n    def __sample__(self):\n        sampled_roots = th.randint(\n            0, self.train_g.num_nodes(), (self.num_roots,)\n        )\n        traces, types = random_walk(\n            self.train_g, nodes=sampled_roots, length=self.length\n        )\n        sampled_nodes, _, _, _ = pack_traces(traces, types)\n        sampled_nodes = sampled_nodes.unique()\n        return 
sampled_nodes.numpy()\n"
  },
  {
    "path": "examples/pytorch/graphsaint/train_sampling.py",
    "content": "import argparse\nimport os\nimport time\nimport warnings\n\nimport torch\nimport torch.nn.functional as F\nfrom config import CONFIG\nfrom modules import GCNNet\nfrom sampler import SAINTEdgeSampler, SAINTNodeSampler, SAINTRandomWalkSampler\nfrom torch.utils.data import DataLoader\nfrom utils import calc_f1, evaluate, load_data, Logger, save_log_dir\n\n\ndef main(args, task):\n    warnings.filterwarnings(\"ignore\")\n    multilabel_data = {\"ppi\", \"yelp\", \"amazon\"}\n    multilabel = args.dataset in multilabel_data\n\n    # This flag is excluded for too large dataset, like amazon, the graph of which is too large to be directly\n    # shifted to one gpu. So we need to\n    # 1. put the whole graph on cpu, and put the subgraphs on gpu in training phase\n    # 2. put the model on gpu in training phase, and put the model on cpu in validation/testing phase\n    # We need to judge cpu_flag and cuda (below) simultaneously when shift model between cpu and gpu\n    if args.dataset in [\"amazon\"]:\n        cpu_flag = True\n    else:\n        cpu_flag = False\n\n    # load and preprocess dataset\n    data = load_data(args, multilabel)\n    g = data.g\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    labels = g.ndata[\"label\"]\n\n    train_nid = data.train_nid\n\n    in_feats = g.ndata[\"feat\"].shape[1]\n    n_classes = data.num_classes\n    n_nodes = g.num_nodes()\n    n_edges = g.num_edges()\n\n    n_train_samples = train_mask.int().sum().item()\n    n_val_samples = val_mask.int().sum().item()\n    n_test_samples = test_mask.int().sum().item()\n\n    print(\n        \"\"\"----Data statistics------'\n    #Nodes %d\n    #Edges %d\n    #Classes/Labels (multi binary labels) %d\n    #Train samples %d\n    #Val samples %d\n    #Test samples %d\"\"\"\n        % (\n            n_nodes,\n            n_edges,\n            n_classes,\n            n_train_samples,\n            
n_val_samples,\n            n_test_samples,\n        )\n    )\n    # load sampler\n\n    kwargs = {\n        \"dn\": args.dataset,\n        \"g\": g,\n        \"train_nid\": train_nid,\n        \"num_workers_sampler\": args.num_workers_sampler,\n        \"num_subg_sampler\": args.num_subg_sampler,\n        \"batch_size_sampler\": args.batch_size_sampler,\n        \"online\": args.online,\n        \"num_subg\": args.num_subg,\n        \"full\": args.full,\n    }\n\n    if args.sampler == \"node\":\n        saint_sampler = SAINTNodeSampler(args.node_budget, **kwargs)\n    elif args.sampler == \"edge\":\n        saint_sampler = SAINTEdgeSampler(args.edge_budget, **kwargs)\n    elif args.sampler == \"rw\":\n        saint_sampler = SAINTRandomWalkSampler(\n            args.num_roots, args.length, **kwargs\n        )\n    else:\n        raise NotImplementedError\n    loader = DataLoader(\n        saint_sampler,\n        collate_fn=saint_sampler.__collate_fn__,\n        batch_size=1,\n        shuffle=True,\n        num_workers=args.num_workers,\n        drop_last=False,\n    )\n    # set device for dataset tensors\n    if args.gpu < 0:\n        cuda = False\n    else:\n        cuda = True\n        torch.cuda.set_device(args.gpu)\n        val_mask = val_mask.cuda()\n        test_mask = test_mask.cuda()\n        if not cpu_flag:\n            g = g.to(\"cuda:{}\".format(args.gpu))\n\n    print(\"labels shape:\", g.ndata[\"label\"].shape)\n    print(\"features shape:\", g.ndata[\"feat\"].shape)\n\n    model = GCNNet(\n        in_dim=in_feats,\n        hid_dim=args.n_hidden,\n        out_dim=n_classes,\n        arch=args.arch,\n        dropout=args.dropout,\n        batch_norm=not args.no_batch_norm,\n        aggr=args.aggr,\n    )\n\n    if cuda:\n        model.cuda()\n\n    # logger and so on\n    log_dir = save_log_dir(args)\n    logger = Logger(os.path.join(log_dir, \"loggings\"))\n    logger.write(args)\n\n    # use optimizer\n    optimizer = 
torch.optim.Adam(model.parameters(), lr=args.lr)\n\n    # set train_nids to cuda tensor\n    if cuda:\n        train_nid = torch.from_numpy(train_nid).cuda()\n        print(\n            \"GPU memory allocated before training(MB)\",\n            torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024,\n        )\n    start_time = time.time()\n    best_f1 = -1\n\n    for epoch in range(args.n_epochs):\n        for j, subg in enumerate(loader):\n            if cuda:\n                subg = subg.to(torch.cuda.current_device())\n            model.train()\n            # forward\n            pred = model(subg)\n            batch_labels = subg.ndata[\"label\"]\n\n            if multilabel:\n                loss = F.binary_cross_entropy_with_logits(\n                    pred,\n                    batch_labels,\n                    reduction=\"sum\",\n                    weight=subg.ndata[\"l_n\"].unsqueeze(1),\n                )\n            else:\n                loss = F.cross_entropy(pred, batch_labels, reduction=\"none\")\n                loss = (subg.ndata[\"l_n\"] * loss).sum()\n\n            optimizer.zero_grad()\n            loss.backward()\n            torch.nn.utils.clip_grad_norm(model.parameters(), 5)\n            optimizer.step()\n\n            if j == len(loader) - 1:\n                model.eval()\n                with torch.no_grad():\n                    train_f1_mic, train_f1_mac = calc_f1(\n                        batch_labels.cpu().numpy(),\n                        pred.cpu().numpy(),\n                        multilabel,\n                    )\n                    print(\n                        f\"epoch:{epoch + 1}/{args.n_epochs}, Iteration {j + 1}/\"\n                        f\"{len(loader)}:training loss\",\n                        loss.item(),\n                    )\n                    print(\n                        \"Train F1-mic {:.4f}, Train F1-mac {:.4f}\".format(\n                            train_f1_mic, train_f1_mac\n             
           )\n                    )\n        # evaluate\n        model.eval()\n        if epoch % args.val_every == 0:\n            if (\n                cpu_flag and cuda\n            ):  # Only when we have shifted model to gpu and we need to shift it back on cpu\n                model = model.to(\"cpu\")\n            val_f1_mic, val_f1_mac = evaluate(\n                model, g, labels, val_mask, multilabel\n            )\n            print(\n                \"Val F1-mic {:.4f}, Val F1-mac {:.4f}\".format(\n                    val_f1_mic, val_f1_mac\n                )\n            )\n            if val_f1_mic > best_f1:\n                best_f1 = val_f1_mic\n                print(\"new best val f1:\", best_f1)\n                torch.save(\n                    model.state_dict(),\n                    os.path.join(log_dir, \"best_model_{}.pkl\".format(task)),\n                )\n            if cpu_flag and cuda:\n                model.cuda()\n\n    end_time = time.time()\n    print(f\"training using time {end_time - start_time}\")\n\n    # test\n    if args.use_val:\n        model.load_state_dict(\n            torch.load(\n                os.path.join(log_dir, \"best_model_{}.pkl\".format(task)),\n                weights_only=False,\n            )\n        )\n    if cpu_flag and cuda:\n        model = model.to(\"cpu\")\n    test_f1_mic, test_f1_mac = evaluate(model, g, labels, test_mask, multilabel)\n    print(\n        \"Test F1-mic {:.4f}, Test F1-mac {:.4f}\".format(\n            test_f1_mic, test_f1_mac\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    warnings.filterwarnings(\"ignore\")\n\n    parser = argparse.ArgumentParser(description=\"GraphSAINT\")\n    parser.add_argument(\n        \"--task\", type=str, default=\"ppi_n\", help=\"type of tasks\"\n    )\n    parser.add_argument(\n        \"--online\",\n        dest=\"online\",\n        action=\"store_true\",\n        help=\"sampling method in training phase\",\n    )\n    
parser.add_argument(\"--gpu\", type=int, default=0, help=\"the gpu index\")\n    task = parser.parse_args().task\n    args = argparse.Namespace(**CONFIG[task])\n    args.online = parser.parse_args().online\n    args.gpu = parser.parse_args().gpu\n    print(args)\n\n    main(args, task=task)\n"
  },
  {
    "path": "examples/pytorch/graphsaint/utils.py",
    "content": "import json\nimport os\nfrom functools import namedtuple\n\nimport dgl\n\nimport numpy as np\nimport scipy.sparse\nimport torch\nfrom sklearn.metrics import f1_score\nfrom sklearn.preprocessing import StandardScaler\n\n\nclass Logger(object):\n    \"\"\"A custom logger to log stdout to a logging file.\"\"\"\n\n    def __init__(self, path):\n        \"\"\"Initialize the logger.\n\n        Parameters\n        ---------\n        path : str\n            The file path to be stored in.\n        \"\"\"\n        self.path = path\n\n    def write(self, s):\n        with open(self.path, \"a\") as f:\n            f.write(str(s))\n        print(s)\n        return\n\n\ndef save_log_dir(args):\n    log_dir = \"./log/{}/{}\".format(args.dataset, args.log_dir)\n    os.makedirs(log_dir, exist_ok=True)\n    return log_dir\n\n\ndef calc_f1(y_true, y_pred, multilabel):\n    if multilabel:\n        y_pred[y_pred > 0] = 1\n        y_pred[y_pred <= 0] = 0\n    else:\n        y_pred = np.argmax(y_pred, axis=1)\n    return f1_score(y_true, y_pred, average=\"micro\"), f1_score(\n        y_true, y_pred, average=\"macro\"\n    )\n\n\ndef evaluate(model, g, labels, mask, multilabel=False):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g)\n        logits = logits[mask]\n        labels = labels[mask]\n        f1_mic, f1_mac = calc_f1(\n            labels.cpu().numpy(), logits.cpu().numpy(), multilabel\n        )\n        return f1_mic, f1_mac\n\n\n# load data of GraphSAINT and convert them to the format of dgl\ndef load_data(args, multilabel):\n    if not os.path.exists(\"graphsaintdata\") and not os.path.exists(\"data\"):\n        raise ValueError(\"The directory graphsaintdata does not exist!\")\n    elif os.path.exists(\"graphsaintdata\") and not os.path.exists(\"data\"):\n        os.rename(\"graphsaintdata\", \"data\")\n    prefix = \"data/{}\".format(args.dataset)\n    DataType = namedtuple(\"Dataset\", [\"num_classes\", \"train_nid\", \"g\"])\n\n    
adj_full = scipy.sparse.load_npz(\"./{}/adj_full.npz\".format(prefix)).astype(\n        np.bool_\n    )\n    g = dgl.from_scipy(adj_full)\n    num_nodes = g.num_nodes()\n\n    adj_train = scipy.sparse.load_npz(\n        \"./{}/adj_train.npz\".format(prefix)\n    ).astype(np.bool_)\n    train_nid = np.array(list(set(adj_train.nonzero()[0])))\n\n    role = json.load(open(\"./{}/role.json\".format(prefix)))\n    mask = np.zeros((num_nodes,), dtype=bool)\n    train_mask = mask.copy()\n    train_mask[role[\"tr\"]] = True\n    val_mask = mask.copy()\n    val_mask[role[\"va\"]] = True\n    test_mask = mask.copy()\n    test_mask[role[\"te\"]] = True\n\n    feats = np.load(\"./{}/feats.npy\".format(prefix))\n    scaler = StandardScaler()\n    scaler.fit(feats[train_nid])\n    feats = scaler.transform(feats)\n\n    class_map = json.load(open(\"./{}/class_map.json\".format(prefix)))\n    class_map = {int(k): v for k, v in class_map.items()}\n    if multilabel:\n        # Multi-label binary classification\n        num_classes = len(list(class_map.values())[0])\n        class_arr = np.zeros((num_nodes, num_classes))\n        for k, v in class_map.items():\n            class_arr[k] = v\n    else:\n        num_classes = max(class_map.values()) - min(class_map.values()) + 1\n        class_arr = np.zeros((num_nodes,))\n        for k, v in class_map.items():\n            class_arr[k] = v\n\n    g.ndata[\"feat\"] = torch.tensor(feats, dtype=torch.float)\n    g.ndata[\"label\"] = torch.tensor(\n        class_arr, dtype=torch.float if multilabel else torch.long\n    )\n    g.ndata[\"train_mask\"] = torch.tensor(train_mask, dtype=torch.bool)\n    g.ndata[\"val_mask\"] = torch.tensor(val_mask, dtype=torch.bool)\n    g.ndata[\"test_mask\"] = torch.tensor(test_mask, dtype=torch.bool)\n\n    data = DataType(g=g, num_classes=num_classes, train_nid=train_nid)\n    return data\n"
  },
  {
    "path": "examples/pytorch/graphsim/README.md",
    "content": "# GraphParticleSim\n## DGL Implementation of Interaction-Network paper.\n\nThis DGL example implements the GNN model proposed in the paper [Interaction Network](https://arxiv.org/abs/1612.00222.pdf). \n\nGraphParticleSim implementor\n----------------------\nThis example was implemented by [Ericcsr](https://github.com/Ericcsr) during his Internship work at the AWS Shanghai AI Lab.\n\nThe graph dataset used in this example \n---------------------------------------\nThis Example uses Datasets Generate By Physics N-Body Simulator adapted from [This Repo](https://github.com/jsikyoon/Interaction-networks_tensorflow)\n\nn_body:\n    - n Particles/Nodes\n    - Complete Bidirectional Graph\n    - 10 trajectories should be generated\n    - 1000 steps of simulation per trajectory\n\nDependency\n--------------------------------\n- ffmpeg 4.3.8\n- opencv-python 4.2.0\n\nHow to run example files\n--------------------------------\nIn the graphsim folder, run\n**Please first run `n_body_sim.py` to generate some data**\n\nUsing Ground Truth Velocity From Simulator Directly.\n\n```python\npython n_body_sim.py\n```\n\nGenerate Longer trajectory or more trajectories.\n\n```python\npython n_body_sim.py --num_traj <num_traj> --steps <num_steps>\n```\n\n**Please use `train.py`**\n\n\n```python\npython train.py --num_workers 15\n```\n\nTraining with GPU\n```python\npython train.py --gpu 0 --num_workers 15\n```\n\nTraining with visualization: for valid visualization, it might take full 40000 epoch of training\n```python\npython train.py --gpu 0 --num_workers 15 --visualize\n```\n\nOne Step Loss Performance, Loss of test data after 40000 training epochs.\n-------------------------\n| Models/Dataset | 6 Body |\n| :-------------- | -----: |\n| Interaction Network in DGL | 80(10) |\n| Interaction Network in Tensorflow | 60 |\n\n-------------------------\nNotice that The datasets are generated directly from simulator to prevent using Tensorflow to handle the original dataset. 
The training is very unstable: even if the minimum loss is achieved from time to time, there is a chance that the loss will suddenly increase, in both the author's model and our model. Since the original model hasn't been released, the implementation of this model refers to the Tensorflow version implemented in: https://github.com/jsikyoon/Interaction-networks_tensorflow (that implementation was written in consultation with the first author for some implementation details).\n\n"
  },
  {
    "path": "examples/pytorch/graphsim/dataloader.py",
    "content": "import copy\nimport os\n\nimport dgl\n\nimport networkx as nx\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader, Dataset\n\n\ndef build_dense_graph(n_particles):\n    g = nx.complete_graph(n_particles)\n    return dgl.from_networkx(g)\n\n\nclass MultiBodyDataset(Dataset):\n    def __init__(self, path):\n        self.path = path\n        self.zipfile = np.load(self.path)\n        self.node_state = self.zipfile[\"data\"]\n        self.node_label = self.zipfile[\"label\"]\n        self.n_particles = self.zipfile[\"n_particles\"]\n\n    def __len__(self):\n        return self.node_state.shape[0]\n\n    def __getitem__(self, idx):\n        if torch.is_tensor(idx):\n            idx = idx.tolist()\n\n        node_state = self.node_state[idx, :, :]\n        node_label = self.node_label[idx, :, :]\n        return (node_state, node_label)\n\n\nclass MultiBodyTrainDataset(MultiBodyDataset):\n    def __init__(self, data_path=\"./data/\"):\n        super(MultiBodyTrainDataset, self).__init__(\n            data_path + \"n_body_train.npz\"\n        )\n        self.stat_median = self.zipfile[\"median\"]\n        self.stat_max = self.zipfile[\"max\"]\n        self.stat_min = self.zipfile[\"min\"]\n\n\nclass MultiBodyValidDataset(MultiBodyDataset):\n    def __init__(self, data_path=\"./data/\"):\n        super(MultiBodyValidDataset, self).__init__(\n            data_path + \"n_body_valid.npz\"\n        )\n\n\nclass MultiBodyTestDataset(MultiBodyDataset):\n    def __init__(self, data_path=\"./data/\"):\n        super(MultiBodyTestDataset, self).__init__(\n            data_path + \"n_body_test.npz\"\n        )\n        self.test_traj = self.zipfile[\"test_traj\"]\n        self.first_frame = torch.from_numpy(self.zipfile[\"first_frame\"])\n\n\n# Construct fully connected graph\n\n\nclass MultiBodyGraphCollator:\n    def __init__(self, n_particles):\n        self.n_particles = n_particles\n        self.graph = 
dgl.from_networkx(nx.complete_graph(self.n_particles))\n\n    def __call__(self, batch):\n        graph_list = []\n        data_list = []\n        label_list = []\n        for frame in batch:\n            graph_list.append(copy.deepcopy(self.graph))\n            data_list.append(torch.from_numpy(frame[0]))\n            label_list.append(torch.from_numpy(frame[1]))\n\n        graph_batch = dgl.batch(graph_list)\n        data_batch = torch.vstack(data_list)\n        label_batch = torch.vstack(label_list)\n        return graph_batch, data_batch, label_batch\n"
  },
  {
    "path": "examples/pytorch/graphsim/models.py",
    "content": "import copy\nfrom functools import partial\n\nimport dgl\nimport dgl.function as fn\nimport dgl.nn as dglnn\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n\nclass MLP(nn.Module):\n    def __init__(self, in_feats, out_feats, num_layers=2, hidden=128):\n        super(MLP, self).__init__()\n        self.layers = nn.ModuleList()\n        layer = nn.Linear(hidden, out_feats)\n        nn.init.normal_(layer.weight, std=0.1)\n        nn.init.zeros_(layer.bias)\n        self.layers.append(nn.Linear(in_feats, hidden))\n        if num_layers > 2:\n            for i in range(1, num_layers - 1):\n                layer = nn.Linear(hidden, hidden)\n                nn.init.normal_(layer.weight, std=0.1)\n                nn.init.zeros_(layer.bias)\n                self.layers.append(layer)\n        layer = nn.Linear(hidden, out_feats)\n        nn.init.normal_(layer.weight, std=0.1)\n        nn.init.zeros_(layer.bias)\n        self.layers.append(layer)\n\n    def forward(self, x):\n        for l in range(len(self.layers) - 1):\n            x = self.layers[l](x)\n            x = F.relu(x)\n        x = self.layers[-1](x)\n        return x\n\n\nclass PrepareLayer(nn.Module):\n    \"\"\"\n    Generate edge feature for the model input preparation:\n    as well as do the normalization work.\n    Parameters\n    ==========\n    node_feats : int\n        Number of node features\n\n    stat : dict\n        dictionary which represent the statistics needed for normalization\n    \"\"\"\n\n    def __init__(self, node_feats, stat):\n        super(PrepareLayer, self).__init__()\n        self.node_feats = node_feats\n        # stat {'median':median,'max':max,'min':min}\n        self.stat = stat\n\n    def normalize_input(self, node_feature):\n        return (node_feature - self.stat[\"median\"]) * (\n            2 / (self.stat[\"max\"] - self.stat[\"min\"])\n        )\n\n    def forward(self, g, node_feature):\n        with g.local_scope():\n        
    node_feature = self.normalize_input(node_feature)\n            g.ndata[\"feat\"] = node_feature  # Only dynamic feature\n            g.apply_edges(fn.u_sub_v(\"feat\", \"feat\", \"e\"))\n            edge_feature = g.edata[\"e\"]\n            return node_feature, edge_feature\n\n\nclass InteractionNet(nn.Module):\n    \"\"\"\n    Simple Interaction Network\n    One Layer interaction network for stellar multi-body problem simulation,\n    it has the ability to simulate number of body motion no more than 12\n    Parameters\n    ==========\n    node_feats : int\n        Number of node features\n\n    stat : dict\n        Statistcics for Denormalization\n    \"\"\"\n\n    def __init__(self, node_feats, stat):\n        super(InteractionNet, self).__init__()\n        self.node_feats = node_feats\n        self.stat = stat\n        edge_fn = partial(MLP, num_layers=5, hidden=150)\n        node_fn = partial(MLP, num_layers=2, hidden=100)\n\n        self.in_layer = InteractionLayer(\n            node_feats - 3,  # Use velocity only\n            node_feats,\n            out_node_feats=2,\n            out_edge_feats=50,\n            edge_fn=edge_fn,\n            node_fn=node_fn,\n            mode=\"n_n\",\n        )\n\n    # Denormalize Velocity only\n    def denormalize_output(self, out):\n        return (\n            out * (self.stat[\"max\"][3:5] - self.stat[\"min\"][3:5]) / 2\n            + self.stat[\"median\"][3:5]\n        )\n\n    def forward(self, g, n_feat, e_feat, global_feats, relation_feats):\n        with g.local_scope():\n            out_n, out_e = self.in_layer(\n                g, n_feat, e_feat, global_feats, relation_feats\n            )\n            out_n = self.denormalize_output(out_n)\n            return out_n, out_e\n\n\nclass InteractionLayer(nn.Module):\n    \"\"\"\n    Implementation of single layer of interaction network\n    Parameters\n    ==========\n    in_node_feats : int\n        Number of node features\n\n    in_edge_feats : int\n        
Number of edge features\n\n    out_node_feats : int\n        Number of node feature after one interaction\n\n    out_edge_feats : int\n        Number of edge features after one interaction\n\n    global_feats : int\n        Number of global features used as input\n\n    relate_feats : int\n        Feature related to the relation between object themselves\n\n    edge_fn : torch.nn.Module\n        Function to update edge feature in message generation\n\n    node_fn : torch.nn.Module\n        Function to update node feature in message aggregation\n\n    mode : str\n        Type of message should the edge carry\n        nne : [src_feat,dst_feat,edge_feat] node feature concat edge feature.\n        n_n : [src_feat-edge_feat] node feature subtract from each other.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_node_feats,\n        in_edge_feats,\n        out_node_feats,\n        out_edge_feats,\n        global_feats=1,\n        relate_feats=1,\n        edge_fn=nn.Linear,\n        node_fn=nn.Linear,\n        mode=\"nne\",\n    ):  # 'n_n'\n        super(InteractionLayer, self).__init__()\n        self.in_node_feats = in_node_feats\n        self.in_edge_feats = in_edge_feats\n        self.out_edge_feats = out_edge_feats\n        self.out_node_feats = out_node_feats\n        self.mode = mode\n        # MLP for message passing\n        input_shape = (\n            2 * self.in_node_feats + self.in_edge_feats\n            if mode == \"nne\"\n            else self.in_edge_feats + relate_feats\n        )\n        self.edge_fn = edge_fn(\n            input_shape, self.out_edge_feats\n        )  # 50 in IN paper\n\n        self.node_fn = node_fn(\n            self.in_node_feats + self.out_edge_feats + global_feats,\n            self.out_node_feats,\n        )\n\n    # Should be done by apply edge\n    def update_edge_fn(self, edges):\n        x = torch.cat(\n            [edges.src[\"feat\"], edges.dst[\"feat\"], edges.data[\"feat\"]], dim=1\n        )\n        ret = 
F.relu(self.edge_fn(x)) if self.mode == \"nne\" else self.edge_fn(x)\n        return {\"e\": ret}\n\n    # Assume agg comes from build in reduce\n    def update_node_fn(self, nodes):\n        x = torch.cat([nodes.data[\"feat\"], nodes.data[\"agg\"]], dim=1)\n        ret = F.relu(self.node_fn(x)) if self.mode == \"nne\" else self.node_fn(x)\n        return {\"n\": ret}\n\n    def forward(self, g, node_feats, edge_feats, global_feats, relation_feats):\n        # print(node_feats.shape,global_feats.shape)\n        g.ndata[\"feat\"] = torch.cat([node_feats, global_feats], dim=1)\n        g.edata[\"feat\"] = torch.cat([edge_feats, relation_feats], dim=1)\n        if self.mode == \"nne\":\n            g.apply_edges(self.update_edge_fn)\n        else:\n            g.edata[\"e\"] = self.edge_fn(g.edata[\"feat\"])\n\n        g.update_all(\n            fn.copy_e(\"e\", \"msg\"), fn.sum(\"msg\", \"agg\"), self.update_node_fn\n        )\n        return g.ndata[\"n\"], g.edata[\"e\"]\n"
  },
  {
    "path": "examples/pytorch/graphsim/n_body_sim.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport argparse\nimport os\nfrom math import cos, pi, radians, sin\n\nimport numpy as np\n\n\"\"\"\nThis adapted from comes from https://github.com/jsikyoon/Interaction-networks_tensorflow\nwhich generates multi-body dynamic simulation data for Interaction network\n\"\"\"\n\n# 5 features on the state [mass,x,y,x_vel,y_vel]\nfea_num = 5\n# G stand for Gravity constant 10**5 can help numerical stability\nG = 10**5\n# time step\ndiff_t = 0.001\n\n\ndef init(total_state, n_body, fea_num, orbit):\n    data = np.zeros((total_state, n_body, fea_num), dtype=float)\n    if orbit:\n        data[0][0][0] = 100\n        data[0][0][1:5] = 0.0\n        # The position are initialized randomly.\n        for i in range(1, n_body):\n            data[0][i][0] = np.random.rand() * 8.98 + 0.02\n            distance = np.random.rand() * 90.0 + 10.0\n            theta = np.random.rand() * 360\n            theta_rad = pi / 2 - radians(theta)\n            data[0][i][1] = distance * cos(theta_rad)\n            data[0][i][2] = distance * sin(theta_rad)\n            data[0][i][3] = (\n                -1\n                * data[0][i][2]\n                / norm(data[0][i][1:3])\n                * (G * data[0][0][0] / norm(data[0][i][1:3]) ** 2)\n                * distance\n                / 1000\n            )\n            data[0][i][4] = (\n                data[0][i][1]\n                / norm(data[0][i][1:3])\n                * (G * data[0][0][0] / norm(data[0][i][1:3]) ** 2)\n                * distance\n                / 1000\n            )\n    else:\n        for i in range(n_body):\n            data[0][i][0] = np.random.rand() * 8.98 + 0.02\n            distance = np.random.rand() * 90.0 + 10.0\n            theta = np.random.rand() * 360\n            theta_rad = pi / 2 - radians(theta)\n            data[0][i][1] = distance * cos(theta_rad)\n            data[0][i][2] = distance * sin(theta_rad)\n            
data[0][i][3] = np.random.rand() * 6.0 - 3.0\n            data[0][i][4] = np.random.rand() * 6.0 - 3.0\n    return data\n\n\ndef norm(x):\n    return np.sqrt(np.sum(x**2))\n\n\ndef get_f(reciever, sender):\n    diff = sender[1:3] - reciever[1:3]\n    distance = norm(diff)\n    if distance < 1:\n        distance = 1\n    return G * reciever[0] * sender[0] / (distance**3) * diff\n\n\n# Compute stat according to the paper for normalization\ndef compute_stats(train_curr):\n    data = np.vstack(train_curr).reshape(-1, fea_num)\n    stat_median = np.median(data, axis=0)\n    stat_max = np.quantile(data, 0.95, axis=0)\n    stat_min = np.quantile(data, 0.05, axis=0)\n    return stat_median, stat_max, stat_min\n\n\ndef calc(cur_state, n_body):\n    next_state = np.zeros((n_body, fea_num), dtype=float)\n    f_mat = np.zeros((n_body, n_body, 2), dtype=float)\n    f_sum = np.zeros((n_body, 2), dtype=float)\n    acc = np.zeros((n_body, 2), dtype=float)\n    for i in range(n_body):\n        for j in range(i + 1, n_body):\n            if j != i:\n                f = get_f(cur_state[i][:3], cur_state[j][:3])\n                f_mat[i, j] += f\n                f_mat[j, i] -= f\n        f_sum[i] = np.sum(f_mat[i], axis=0)\n        acc[i] = f_sum[i] / cur_state[i][0]\n        next_state[i][0] = cur_state[i][0]\n        next_state[i][3:5] = cur_state[i][3:5] + acc[i] * diff_t\n        next_state[i][1:3] = cur_state[i][1:3] + next_state[i][3:5] * diff_t\n    return next_state\n\n\n# The state is [mass,pos_x,pos_y,vel_x,vel_y]* n_body\ndef gen(n_body, num_steps, orbit):\n    # initialization on just first state\n    data = init(num_steps, n_body, fea_num, orbit)\n    for i in range(1, num_steps):\n        data[i] = calc(data[i - 1], n_body)\n    return data\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser()\n    argparser.add_argument(\"--num_bodies\", type=int, default=6)\n    argparser.add_argument(\"--num_traj\", type=int, default=10)\n    
argparser.add_argument(\"--steps\", type=int, default=1000)\n    argparser.add_argument(\"--data_path\", type=str, default=\"data\")\n\n    args = argparser.parse_args()\n    if not os.path.exists(args.data_path):\n        os.mkdir(args.data_path)\n\n    # Generate data\n    data_curr = []\n    data_next = []\n\n    for i in range(args.num_traj):\n        raw_traj = gen(args.num_bodies, args.steps, True)\n        data_curr.append(raw_traj[:-1])\n        data_next.append(raw_traj[1:])\n        print(\"Train Traj: \", i)\n\n    # Compute normalization statistic from data\n    stat_median, stat_max, stat_min = compute_stats(data_curr)\n    data = np.vstack(data_curr)\n    label = np.vstack(data_next)[:, :, 3:5]\n    shuffle_idx = np.arange(data.shape[0])\n    np.random.shuffle(shuffle_idx)\n    train_split = int(0.9 * data.shape[0])\n    valid_split = train_split + 300\n    data = data[shuffle_idx]\n    label = label[shuffle_idx]\n\n    train_data = data[:train_split]\n    train_label = label[:train_split]\n\n    valid_data = data[train_split:valid_split]\n    valid_label = label[train_split:valid_split]\n\n    test_data = data[valid_split:]\n    test_label = label[valid_split:]\n\n    np.savez(\n        args.data_path + \"/n_body_train.npz\",\n        data=train_data,\n        label=train_label,\n        n_particles=args.num_bodies,\n        median=stat_median,\n        max=stat_max,\n        min=stat_min,\n    )\n\n    np.savez(\n        args.data_path + \"/n_body_valid.npz\",\n        data=valid_data,\n        label=valid_label,\n        n_particles=args.num_bodies,\n    )\n\n    test_traj = gen(args.num_bodies, args.steps, True)\n\n    np.savez(\n        args.data_path + \"/n_body_test.npz\",\n        data=test_data,\n        label=test_label,\n        n_particles=args.num_bodies,\n        first_frame=test_traj[0],\n        test_traj=test_traj,\n    )\n"
  },
  {
    "path": "examples/pytorch/graphsim/train.py",
    "content": "import argparse\nimport time\nimport traceback\n\nimport dgl\n\nimport networkx as nx\nimport numpy as np\nimport torch\nfrom dataloader import (\n    MultiBodyGraphCollator,\n    MultiBodyTestDataset,\n    MultiBodyTrainDataset,\n    MultiBodyValidDataset,\n)\nfrom models import InteractionNet, MLP, PrepareLayer\nfrom torch.utils.data import DataLoader\nfrom utils import make_video\n\n\ndef train(\n    optimizer, loss_fn, reg_fn, model, prep, dataloader, lambda_reg, device\n):\n    total_loss = 0\n    model.train()\n    for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader):\n        graph_batch = graph_batch.to(device)\n        data_batch = data_batch.to(device)\n        label_batch = label_batch.to(device)\n        optimizer.zero_grad()\n        node_feat, edge_feat = prep(graph_batch, data_batch)\n        dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device)\n        dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device)\n        v_pred, out_e = model(\n            graph_batch,\n            node_feat[:, 3:5].float(),\n            edge_feat.float(),\n            dummy_global,\n            dummy_relation,\n        )\n        loss = loss_fn(v_pred, label_batch)\n        total_loss += float(loss)\n        zero_target = torch.zeros_like(out_e)\n        loss = loss + lambda_reg * reg_fn(out_e, zero_target)\n        reg_loss = 0\n        for param in model.parameters():\n            reg_loss = reg_loss + lambda_reg * reg_fn(\n                param, torch.zeros_like(param).float().to(device)\n            )\n        loss = loss + reg_loss\n        loss.backward()\n        optimizer.step()\n    return total_loss / (i + 1)\n\n\n# One step evaluation\n\n\ndef eval(loss_fn, model, prep, dataloader, device):\n    total_loss = 0\n    model.eval()\n    for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader):\n        graph_batch = graph_batch.to(device)\n        data_batch = data_batch.to(device)\n   
     label_batch = label_batch.to(device)\n        node_feat, edge_feat = prep(graph_batch, data_batch)\n        dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device)\n        dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device)\n        v_pred, _ = model(\n            graph_batch,\n            node_feat[:, 3:5].float(),\n            edge_feat.float(),\n            dummy_global,\n            dummy_relation,\n        )\n        loss = loss_fn(v_pred, label_batch)\n        total_loss += float(loss)\n    return total_loss / (i + 1)\n\n\n# Rollout Evaluation based in initial state\n# Need to integrate\n\n\ndef eval_rollout(model, prep, initial_frame, n_object, device):\n    current_frame = initial_frame.to(device)\n    base_graph = nx.complete_graph(n_object)\n    graph = dgl.from_networkx(base_graph).to(device)\n    pos_buffer = []\n    model.eval()\n    for step in range(100):\n        node_feats, edge_feats = prep(graph, current_frame)\n        dummy_relation = torch.zeros(edge_feats.shape[0], 1).float().to(device)\n        dummy_global = torch.zeros(node_feats.shape[0], 1).float().to(device)\n        v_pred, _ = model(\n            graph,\n            node_feats[:, 3:5].float(),\n            edge_feats.float(),\n            dummy_global,\n            dummy_relation,\n        )\n        current_frame[:, [1, 2]] += v_pred * 0.001\n        current_frame[:, 3:5] = v_pred\n        pos_buffer.append(current_frame[:, [1, 2]].cpu().numpy())\n    pos_buffer = np.vstack(pos_buffer).reshape(100, n_object, -1)\n    make_video(pos_buffer, \"video_model.mp4\")\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser()\n    argparser.add_argument(\n        \"--lr\", type=float, default=0.001, help=\"learning rate\"\n    )\n    argparser.add_argument(\n        \"--epochs\", type=int, default=40000, help=\"Number of epochs in training\"\n    )\n    argparser.add_argument(\n        \"--lambda_reg\", type=float, default=0.001, 
help=\"regularization weight\"\n    )\n    argparser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"gpu device code, -1 means cpu\"\n    )\n    argparser.add_argument(\n        \"--batch_size\", type=int, default=100, help=\"size of each mini batch\"\n    )\n    argparser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=0,\n        help=\"number of workers for dataloading\",\n    )\n    argparser.add_argument(\n        \"--visualize\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether enable trajectory rollout mode for visualization\",\n    )\n    args = argparser.parse_args()\n\n    # Select Device to be CPU or GPU\n    if args.gpu != -1:\n        device = torch.device(\"cuda:{}\".format(args.gpu))\n    else:\n        device = torch.device(\"cpu\")\n\n    train_data = MultiBodyTrainDataset()\n    valid_data = MultiBodyValidDataset()\n    test_data = MultiBodyTestDataset()\n    collator = MultiBodyGraphCollator(train_data.n_particles)\n\n    train_dataloader = DataLoader(\n        train_data,\n        args.batch_size,\n        True,\n        collate_fn=collator,\n        num_workers=args.num_workers,\n    )\n    valid_dataloader = DataLoader(\n        valid_data,\n        args.batch_size,\n        True,\n        collate_fn=collator,\n        num_workers=args.num_workers,\n    )\n    test_full_dataloader = DataLoader(\n        test_data,\n        args.batch_size,\n        True,\n        collate_fn=collator,\n        num_workers=args.num_workers,\n    )\n\n    node_feats = 5\n    stat = {\n        \"median\": torch.from_numpy(train_data.stat_median).to(device),\n        \"max\": torch.from_numpy(train_data.stat_max).to(device),\n        \"min\": torch.from_numpy(train_data.stat_min).to(device),\n    }\n    print(\n        \"Weight: \",\n        train_data.stat_median[0],\n        train_data.stat_max[0],\n        train_data.stat_min[0],\n    )\n    print(\n        \"Position: \",\n        
train_data.stat_median[[1, 2]],\n        train_data.stat_max[[1, 2]],\n        train_data.stat_min[[1, 2]],\n    )\n    print(\n        \"Velocity: \",\n        train_data.stat_median[[3, 4]],\n        train_data.stat_max[[3, 4]],\n        train_data.stat_min[[3, 4]],\n    )\n\n    prepare_layer = PrepareLayer(node_feats, stat).to(device)\n    interaction_net = InteractionNet(node_feats, stat).to(device)\n    print(interaction_net)\n    optimizer = torch.optim.Adam(interaction_net.parameters(), lr=args.lr)\n    state_dict = interaction_net.state_dict()\n\n    loss_fn = torch.nn.MSELoss()\n    reg_fn = torch.nn.MSELoss(reduction=\"sum\")\n    try:\n        for e in range(args.epochs):\n            last_t = time.time()\n            loss = train(\n                optimizer,\n                loss_fn,\n                reg_fn,\n                interaction_net,\n                prepare_layer,\n                train_dataloader,\n                args.lambda_reg,\n                device,\n            )\n            print(\"Epoch time: \", time.time() - last_t)\n            if e % 1 == 0:\n                valid_loss = eval(\n                    loss_fn,\n                    interaction_net,\n                    prepare_layer,\n                    valid_dataloader,\n                    device,\n                )\n                test_full_loss = eval(\n                    loss_fn,\n                    interaction_net,\n                    prepare_layer,\n                    test_full_dataloader,\n                    device,\n                )\n                print(\n                    \"Epoch: {}.Loss: Valid: {} Full: {}\".format(\n                        e, valid_loss, test_full_loss\n                    )\n                )\n    except:\n        traceback.print_exc()\n    finally:\n        if args.visualize:\n            eval_rollout(\n                interaction_net,\n                prepare_layer,\n                test_data.first_frame,\n                
test_data.n_particles,\n                device,\n            )\n            make_video(test_data.test_traj[:100, :, [1, 2]], \"video_truth.mp4\")\n"
  },
  {
    "path": "examples/pytorch/graphsim/utils.py",
    "content": "import os\n\nimport cv2 as cv\nimport matplotlib\nimport matplotlib.animation as manimation\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nmatplotlib.use(\"agg\")\n\n# Make video can be used to visualize test data\n\n\ndef make_video(xy, filename):\n    os.system(\"rm -rf pics/*\")\n    FFMpegWriter = manimation.writers[\"ffmpeg\"]\n    metadata = dict(\n        title=\"Movie Test\", artist=\"Matplotlib\", comment=\"Movie support!\"\n    )\n    writer = FFMpegWriter(fps=15, metadata=metadata)\n    fig = plt.figure()\n    plt.xlim(-200, 200)\n    plt.ylim(-200, 200)\n    fig_num = len(xy)\n    color = [\"ro\", \"bo\", \"go\", \"ko\", \"yo\", \"mo\", \"co\"]\n    with writer.saving(fig, filename, len(xy)):\n        for i in range(len(xy)):\n            for j in range(len(xy[0])):\n                plt.plot(xy[i, j, 1], xy[i, j, 0], color[j % len(color)])\n            writer.grab_frame()\n"
  },
  {
    "path": "examples/pytorch/graphwriter/README.md",
    "content": "# GraphWriter-DGL\nIn this example we implement the GraphWriter, [Text Generation from Knowledge Graphs with Graph Transformers](https://arxiv.org/abs/1904.02342) in DGL. And the [author's code](https://github.com/rikdz/GraphWriter). \n\n## Dependencies\n- PyTorch >= 1.2  \n- tqdm   \n- pycoco (only for testing)  \n- multi-bleu.perl and other scripts from mosesdecoder (only for testing)\n\n## Usage\n```\n  # download data\n  sh prepare_data.sh \n  # training\n  sh run.sh\n  # testing\n  sh test.sh\n```\n\n## Result on AGENDA\n| |BLEU|METEOR| training time per epoch|\n|-|-|-|-|\n|Author's implementation|14.3+-1.01| 18.8+-0.28| 1970s|\n|DGL implementation|14.31+-0.34|19.74+-0.69| 1080s|\n\nWe use the author's code for the speed test, and our testbed is V100 GPU.\n\n| |BLEU| detok BLEU| METEOR | \n|-|-|-|-|\n|greedy, two layers| 13.97 +- 0.40| 13.78 +- 0.46| 18.76 +- 0.36|\n|beam 4, length penalty 1.0, two layers| 14.66 +- 0.65| 14.53 +- 0.52| 19.50 +- 0.49|\n|beam 4, length penalty 0.0, two layers| 14.33 +- 0.39| 14.09 +- 0.39| 18.63 +- 0.52|\n|greedy, six layers| 14.17 +- 0.46| 14.01 +- 0.51| 19.18 +- 0.49|\n|beam 4, length penalty 1.0, six layers| 14.31 +- 0.34| 14.35 +- 0.36| 19.74 +- 0.69|\n|beam 4, length penalty 0.0, six layers| 14.40 +- 0.85| 14.15 +- 0.84| 18.86 +- 0.78|\n\nWe repeat the experiment five times. \n\n### Examples\n\nWe also provide the output of our implementation on test set together with the reference text.\n- [GraphWriter's output](https://data.dgl.ai/models/graphwriter/tmp_pred.txt)\n- [Reference text](https://data.dgl.ai/models/graphwriter/tmp_gold.txt)\n\n"
  },
  {
    "path": "examples/pytorch/graphwriter/graphwriter.py",
    "content": "import torch\nfrom modules import BiLSTM, GraphTrans, MSA\nfrom torch import nn\nfrom utlis import *\n\nimport dgl\n\n\nclass GraphWriter(nn.Module):\n    def __init__(self, args):\n        super(GraphWriter, self).__init__()\n        self.args = args\n        if args.title:\n            self.title_emb = nn.Embedding(\n                len(args.title_vocab), args.nhid, padding_idx=0\n            )\n            self.title_enc = BiLSTM(args, enc_type=\"title\")\n            self.title_attn = MSA(args)\n        self.ent_emb = nn.Embedding(\n            len(args.ent_text_vocab), args.nhid, padding_idx=0\n        )\n        self.tar_emb = nn.Embedding(\n            len(args.text_vocab), args.nhid, padding_idx=0\n        )\n        if args.title:\n            nn.init.xavier_normal_(self.title_emb.weight)\n        nn.init.xavier_normal_(self.ent_emb.weight)\n        self.rel_emb = nn.Embedding(\n            len(args.rel_vocab), args.nhid, padding_idx=0\n        )\n        nn.init.xavier_normal_(self.rel_emb.weight)\n        self.decode_lstm = nn.LSTMCell(args.dec_ninp, args.nhid)\n        self.ent_enc = BiLSTM(args, enc_type=\"entity\")\n        self.graph_enc = GraphTrans(args)\n        self.ent_attn = MSA(args)\n        self.copy_attn = MSA(args, mode=\"copy\")\n        self.copy_fc = nn.Linear(args.dec_ninp, 1)\n        self.pred_v_fc = nn.Linear(args.dec_ninp, len(args.text_vocab))\n\n    def enc_forward(\n        self, batch, ent_mask, ent_text_mask, ent_len, rel_mask, title_mask\n    ):\n        title_enc = None\n        if self.args.title:\n            title_enc = self.title_enc(\n                self.title_emb(batch[\"title\"]), title_mask\n            )\n        ent_enc = self.ent_enc(\n            self.ent_emb(batch[\"ent_text\"]),\n            ent_text_mask,\n            ent_len=batch[\"ent_len\"],\n        )\n        rel_emb = self.rel_emb(batch[\"rel\"])\n        g_ent, g_root = self.graph_enc(\n            ent_enc, ent_mask, ent_len, rel_emb, 
rel_mask, batch[\"graph\"]\n        )\n        return g_ent, g_root, title_enc, ent_enc\n\n    def forward(self, batch, beam_size=-1):\n        ent_mask = len2mask(batch[\"ent_len\"], self.args.device)\n        ent_text_mask = batch[\"ent_text\"] == 0\n        rel_mask = batch[\"rel\"] == 0  # 0 means the <PAD>\n        title_mask = batch[\"title\"] == 0\n        g_ent, g_root, title_enc, ent_enc = self.enc_forward(\n            batch,\n            ent_mask,\n            ent_text_mask,\n            batch[\"ent_len\"],\n            rel_mask,\n            title_mask,\n        )\n\n        _h, _c = g_root, g_root.clone().detach()\n        ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)\n        if self.args.title:\n            attn = _h + self.title_attn(_h, title_enc, mask=title_mask)\n            ctx = torch.cat([ctx, attn], 1)\n        if beam_size < 1:\n            # training\n            outs = []\n            tar_inp = self.tar_emb(batch[\"text\"].transpose(0, 1))\n            for t, xt in enumerate(tar_inp):\n                _xt = torch.cat([ctx, xt], 1)\n                _h, _c = self.decode_lstm(_xt, (_h, _c))\n                ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)\n                if self.args.title:\n                    attn = _h + self.title_attn(_h, title_enc, mask=title_mask)\n                    ctx = torch.cat([ctx, attn], 1)\n                outs.append(torch.cat([_h, ctx], 1))\n            outs = torch.stack(outs, 1)\n            copy_gate = torch.sigmoid(self.copy_fc(outs))\n            EPSI = 1e-6\n            # copy\n            pred_v = torch.log(copy_gate + EPSI) + torch.log_softmax(\n                self.pred_v_fc(outs), -1\n            )\n            pred_c = torch.log((1.0 - copy_gate) + EPSI) + torch.log_softmax(\n                self.copy_attn(outs, ent_enc, mask=ent_mask), -1\n            )\n            pred = torch.cat([pred_v, pred_c], -1)\n            return pred\n        else:\n            if beam_size == 1:\n             
   # greedy\n                device = g_ent.device\n                B = g_ent.shape[0]\n                ent_type = batch[\"ent_type\"].view(B, -1)\n                seq = (\n                    torch.ones(\n                        B,\n                    )\n                    .long()\n                    .to(device)\n                    * self.args.text_vocab(\"<BOS>\")\n                ).unsqueeze(1)\n                for t in range(self.args.beam_max_len):\n                    _inp = replace_ent(\n                        seq[:, -1], ent_type, len(self.args.text_vocab)\n                    )\n                    xt = self.tar_emb(_inp)\n                    _xt = torch.cat([ctx, xt], 1)\n                    _h, _c = self.decode_lstm(_xt, (_h, _c))\n                    ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)\n                    if self.args.title:\n                        attn = _h + self.title_attn(\n                            _h, title_enc, mask=title_mask\n                        )\n                        ctx = torch.cat([ctx, attn], 1)\n                    _y = torch.cat([_h, ctx], 1)\n                    copy_gate = torch.sigmoid(self.copy_fc(_y))\n                    pred_v = torch.log(copy_gate) + torch.log_softmax(\n                        self.pred_v_fc(_y), -1\n                    )\n                    pred_c = torch.log((1.0 - copy_gate)) + torch.log_softmax(\n                        self.copy_attn(\n                            _y.unsqueeze(1), ent_enc, mask=ent_mask\n                        ).squeeze(1),\n                        -1,\n                    )\n                    pred = torch.cat([pred_v, pred_c], -1).view(B, -1)\n                    for ban_item in [\"<BOS>\", \"<PAD>\", \"<UNK>\"]:\n                        pred[:, self.args.text_vocab(ban_item)] = -1e8\n                    _, word = pred.max(-1)\n                    seq = torch.cat([seq, word.unsqueeze(1)], 1)\n                return seq\n            else:\n                # 
beam search\n                device = g_ent.device\n                B = g_ent.shape[0]\n                BSZ = B * beam_size\n                _h = _h.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)\n                _c = _c.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)\n                ent_mask = (\n                    ent_mask.view(B, 1, -1)\n                    .repeat(1, beam_size, 1)\n                    .view(BSZ, -1)\n                )\n                if self.args.title:\n                    title_mask = (\n                        title_mask.view(B, 1, -1)\n                        .repeat(1, beam_size, 1)\n                        .view(BSZ, -1)\n                    )\n                    title_enc = (\n                        title_enc.view(B, 1, title_enc.size(1), -1)\n                        .repeat(1, beam_size, 1, 1)\n                        .view(BSZ, title_enc.size(1), -1)\n                    )\n                ctx = ctx.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)\n                ent_type = (\n                    batch[\"ent_type\"]\n                    .view(B, 1, -1)\n                    .repeat(1, beam_size, 1)\n                    .view(BSZ, -1)\n                )\n                g_ent = (\n                    g_ent.view(B, 1, g_ent.size(1), -1)\n                    .repeat(1, beam_size, 1, 1)\n                    .view(BSZ, g_ent.size(1), -1)\n                )\n                ent_enc = (\n                    ent_enc.view(B, 1, ent_enc.size(1), -1)\n                    .repeat(1, beam_size, 1, 1)\n                    .view(BSZ, ent_enc.size(1), -1)\n                )\n\n                beam_best = torch.zeros(B).to(device) - 1e9\n                beam_best_seq = [None] * B\n                beam_seq = (\n                    torch.ones(B, beam_size).long().to(device)\n                    * self.args.text_vocab(\"<BOS>\")\n                ).unsqueeze(-1)\n                beam_score = torch.zeros(B, beam_size).to(device)\n        
        done_flag = torch.zeros(B, beam_size)\n                for t in range(self.args.beam_max_len):\n                    _inp = replace_ent(\n                        beam_seq[:, :, -1].view(-1),\n                        ent_type,\n                        len(self.args.text_vocab),\n                    )\n                    xt = self.tar_emb(_inp)\n                    _xt = torch.cat([ctx, xt], 1)\n                    _h, _c = self.decode_lstm(_xt, (_h, _c))\n                    ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)\n                    if self.args.title:\n                        attn = _h + self.title_attn(\n                            _h, title_enc, mask=title_mask\n                        )\n                        ctx = torch.cat([ctx, attn], 1)\n                    _y = torch.cat([_h, ctx], 1)\n                    copy_gate = torch.sigmoid(self.copy_fc(_y))\n                    pred_v = torch.log(copy_gate) + torch.log_softmax(\n                        self.pred_v_fc(_y), -1\n                    )\n                    pred_c = torch.log((1.0 - copy_gate)) + torch.log_softmax(\n                        self.copy_attn(\n                            _y.unsqueeze(1), ent_enc, mask=ent_mask\n                        ).squeeze(1),\n                        -1,\n                    )\n                    pred = torch.cat([pred_v, pred_c], -1).view(\n                        B, beam_size, -1\n                    )\n                    for ban_item in [\"<BOS>\", \"<PAD>\", \"<UNK>\"]:\n                        pred[:, :, self.args.text_vocab(ban_item)] = -1e8\n                    if t == self.args.beam_max_len - 1:  # force ending\n                        tt = pred[:, :, self.args.text_vocab(\"<EOS>\")]\n                        pred = pred * 0 - 1e8\n                        pred[:, :, self.args.text_vocab(\"<EOS>\")] = tt\n                    cum_score = beam_score.view(B, beam_size, 1) + pred\n                    score, word = cum_score.topk(\n             
           dim=-1, k=beam_size\n                    )  # B, beam_size, beam_size\n                    score, word = score.view(B, -1), word.view(B, -1)\n                    eos_idx = self.args.text_vocab(\"<EOS>\")\n                    if beam_seq.size(2) == 1:\n                        new_idx = torch.arange(beam_size).to(word)\n                        new_idx = new_idx[None, :].repeat(B, 1)\n                    else:\n                        _, new_idx = score.topk(dim=-1, k=beam_size)\n                    new_src, new_score, new_word, new_done = [], [], [], []\n                    LP = beam_seq.size(2) ** self.args.lp\n                    for i in range(B):\n                        for j in range(beam_size):\n                            tmp_score = score[i][new_idx[i][j]]\n                            tmp_word = word[i][new_idx[i][j]]\n                            src_idx = new_idx[i][j] // beam_size\n                            new_src.append(src_idx)\n                            if tmp_word == eos_idx:\n                                new_score.append(-1e8)\n                            else:\n                                new_score.append(tmp_score)\n                            new_word.append(tmp_word)\n\n                            if (\n                                tmp_word == eos_idx\n                                and done_flag[i][src_idx] == 0\n                                and tmp_score / LP > beam_best[i]\n                            ):\n                                beam_best[i] = tmp_score / LP\n                                beam_best_seq[i] = beam_seq[i][src_idx]\n                            if tmp_word == eos_idx:\n                                new_done.append(1)\n                            else:\n                                new_done.append(done_flag[i][src_idx])\n                    new_score = (\n                        torch.Tensor(new_score)\n                        .view(B, beam_size)\n                        .to(beam_score)\n  
                  )\n                    new_word = (\n                        torch.Tensor(new_word).view(B, beam_size).to(beam_seq)\n                    )\n                    new_src = (\n                        torch.LongTensor(new_src).view(B, beam_size).to(device)\n                    )\n                    new_done = (\n                        torch.Tensor(new_done).view(B, beam_size).to(done_flag)\n                    )\n                    beam_score = new_score\n                    done_flag = new_done\n                    beam_seq = beam_seq.view(B, beam_size, -1)[\n                        torch.arange(B)[:, None].to(device), new_src\n                    ]\n                    beam_seq = torch.cat([beam_seq, new_word.unsqueeze(2)], 2)\n                    _h = _h.view(B, beam_size, -1)[\n                        torch.arange(B)[:, None].to(device), new_src\n                    ].view(BSZ, -1)\n                    _c = _c.view(B, beam_size, -1)[\n                        torch.arange(B)[:, None].to(device), new_src\n                    ].view(BSZ, -1)\n                    ctx = ctx.view(B, beam_size, -1)[\n                        torch.arange(B)[:, None].to(device), new_src\n                    ].view(BSZ, -1)\n\n                return beam_best_seq\n"
  },
  {
    "path": "examples/pytorch/graphwriter/modules.py",
    "content": "import math\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\nfrom utlis import *\n\nimport dgl.function as fn\nfrom dgl.nn.functional import edge_softmax\n\n\nclass MSA(nn.Module):\n    # multi-head self-attention, three modes\n    # the first is the copy, determining which entity should be copied.\n    # the second is the normal attention with two sequence inputs\n    # the third is the attention but with one token and a sequence. (gather, attentive pooling)\n\n    def __init__(self, args, mode=\"normal\"):\n        super(MSA, self).__init__()\n        if mode == \"copy\":\n            nhead, head_dim = 1, args.nhid\n            qninp, kninp = args.dec_ninp, args.nhid\n        if mode == \"normal\":\n            nhead, head_dim = args.nhead, args.head_dim\n            qninp, kninp = args.nhid, args.nhid\n        self.attn_drop = nn.Dropout(0.1)\n        self.WQ = nn.Linear(\n            qninp, nhead * head_dim, bias=True if mode == \"copy\" else False\n        )\n        if mode != \"copy\":\n            self.WK = nn.Linear(kninp, nhead * head_dim, bias=False)\n            self.WV = nn.Linear(kninp, nhead * head_dim, bias=False)\n        self.args, self.nhead, self.head_dim, self.mode = (\n            args,\n            nhead,\n            head_dim,\n            mode,\n        )\n\n    def forward(self, inp1, inp2, mask=None):\n        B, L2, H = inp2.shape\n        NH, HD = self.nhead, self.head_dim\n        if self.mode == \"copy\":\n            q, k, v = self.WQ(inp1), inp2, inp2\n        else:\n            q, k, v = self.WQ(inp1), self.WK(inp2), self.WV(inp2)\n        L1 = 1 if inp1.ndim == 2 else inp1.shape[1]\n        if self.mode != \"copy\":\n            q = q / math.sqrt(H)\n        q = q.view(B, L1, NH, HD).permute(0, 2, 1, 3)\n        k = k.view(B, L2, NH, HD).permute(0, 2, 3, 1)\n        v = v.view(B, L2, NH, HD).permute(0, 2, 1, 3)\n        
pre_attn = torch.matmul(q, k)\n        if mask is not None:\n            pre_attn = pre_attn.masked_fill(mask[:, None, None, :], -1e8)\n        if self.mode == \"copy\":\n            return pre_attn.squeeze(1)\n        else:\n            alpha = self.attn_drop(torch.softmax(pre_attn, -1))\n            attn = (\n                torch.matmul(alpha, v)\n                .permute(0, 2, 1, 3)\n                .contiguous()\n                .view(B, L1, NH * HD)\n            )\n            ret = attn\n            if inp1.ndim == 2:\n                return ret.squeeze(1)\n            else:\n                return ret\n\n\nclass BiLSTM(nn.Module):\n    # for entity encoding or the title encoding\n    def __init__(self, args, enc_type=\"title\"):\n        super(BiLSTM, self).__init__()\n        self.enc_type = enc_type\n        self.drop = nn.Dropout(args.emb_drop)\n        self.bilstm = nn.LSTM(\n            args.nhid,\n            args.nhid // 2,\n            bidirectional=True,\n            num_layers=args.enc_lstm_layers,\n            batch_first=True,\n        )\n\n    def forward(self, inp, mask, ent_len=None):\n        inp = self.drop(inp)\n        lens = (mask == 0).sum(-1).long().tolist()\n        pad_seq = pack_padded_sequence(\n            inp, lens, batch_first=True, enforce_sorted=False\n        )\n        y, (_h, _c) = self.bilstm(pad_seq)\n        if self.enc_type == \"title\":\n            y = pad_packed_sequence(y, batch_first=True)[0]\n            return y\n        if self.enc_type == \"entity\":\n            _h = _h.transpose(0, 1).contiguous()\n            _h = _h[:, -2:].view(\n                _h.size(0), -1\n            )  # two directions of the top-layer\n            ret = pad(_h.split(ent_len), out_type=\"tensor\")\n            return ret\n\n\nclass GAT(nn.Module):\n    # a graph attention network with dot-product attention\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        num_heads,\n        ffn_drop=0.0,\n        
attn_drop=0.0,\n        trans=True,\n    ):\n        super(GAT, self).__init__()\n        self._num_heads = num_heads\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self.q_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)\n        self.k_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)\n        self.v_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)\n        self.attn_drop = nn.Dropout(0.1)\n        self.ln1 = nn.LayerNorm(in_feats)\n        self.ln2 = nn.LayerNorm(in_feats)\n        if trans:\n            self.FFN = nn.Sequential(\n                nn.Linear(in_feats, 4 * in_feats),\n                nn.PReLU(4 * in_feats),\n                nn.Linear(4 * in_feats, in_feats),\n                nn.Dropout(0.1),\n            )\n            # a strange FFN, see the author's code\n        self._trans = trans\n\n    def forward(self, graph, feat):\n        graph = graph.local_var()\n        feat_c = feat.clone().detach().requires_grad_(False)\n        q, k, v = self.q_proj(feat), self.k_proj(feat_c), self.v_proj(feat_c)\n        q = q.view(-1, self._num_heads, self._out_feats)\n        k = k.view(-1, self._num_heads, self._out_feats)\n        v = v.view(-1, self._num_heads, self._out_feats)\n        graph.ndata.update(\n            {\"ft\": v, \"el\": k, \"er\": q}\n        )  # k,q instead of q,k, the edge_softmax is applied on incoming edges\n        # compute edge attention\n        graph.apply_edges(fn.u_dot_v(\"el\", \"er\", \"e\"))\n        e = graph.edata.pop(\"e\") / math.sqrt(self._out_feats * self._num_heads)\n        graph.edata[\"a\"] = edge_softmax(graph, e)\n        # message passing\n        graph.update_all(fn.u_mul_e(\"ft\", \"a\", \"m\"), fn.sum(\"m\", \"ft2\"))\n        rst = graph.ndata[\"ft2\"]\n        # residual\n        rst = rst.view(feat.shape) + feat\n        if self._trans:\n            rst = self.ln1(rst)\n            rst = self.ln1(rst + self.FFN(rst))\n            # use 
the same layer norm, see the author's code\n        return rst\n\n\nclass GraphTrans(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n        if args.graph_enc == \"gat\":\n            # we only support gtrans, don't use this one\n            self.gat = nn.ModuleList(\n                [\n                    GAT(\n                        args.nhid,\n                        args.nhid // 4,\n                        4,\n                        attn_drop=args.attn_drop,\n                        trans=False,\n                    )\n                    for _ in range(args.prop)\n                ]\n            )  # untested\n        else:\n            self.gat = nn.ModuleList(\n                [\n                    GAT(\n                        args.nhid,\n                        args.nhid // 4,\n                        4,\n                        attn_drop=args.attn_drop,\n                        ffn_drop=args.drop,\n                        trans=True,\n                    )\n                    for _ in range(args.prop)\n                ]\n            )\n        self.prop = args.prop\n\n    def forward(self, ent, ent_mask, ent_len, rel, rel_mask, graphs):\n        device = ent.device\n        graphs = graphs.to(device)\n        ent_mask = ent_mask == 0  # reverse mask\n        rel_mask = rel_mask == 0\n        init_h = []\n        for i in range(graphs.batch_size):\n            init_h.append(ent[i][ent_mask[i]])\n            init_h.append(rel[i][rel_mask[i]])\n        init_h = torch.cat(init_h, 0)\n        feats = init_h\n        for i in range(self.prop):\n            feats = self.gat[i](graphs, feats)\n        g_root = feats.index_select(\n            0,\n            graphs.filter_nodes(\n                lambda x: x.data[\"type\"] == NODE_TYPE[\"root\"]\n            ).to(device),\n        )\n        g_ent = pad(\n            feats.index_select(\n                0,\n                graphs.filter_nodes(\n                  
  lambda x: x.data[\"type\"] == NODE_TYPE[\"entity\"]\n                ).to(device),\n            ).split(ent_len),\n            out_type=\"tensor\",\n        )\n        return g_ent, g_root\n"
  },
  {
    "path": "examples/pytorch/graphwriter/opts.py",
    "content": "import argparse\n\nimport torch\n\n\ndef fill_config(args):\n    # dirty work\n    args.device = torch.device(args.gpu)\n    args.dec_ninp = args.nhid * 3 if args.title else args.nhid * 2\n    args.fnames = [args.train_file, args.valid_file, args.test_file]\n    return args\n\n\ndef vocab_config(\n    args, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab\n):\n    # dirty work\n    args.ent_vocab = ent_vocab\n    args.rel_vocab = rel_vocab\n    args.text_vocab = text_vocab\n    args.ent_text_vocab = ent_text_vocab\n    args.title_vocab = title_vocab\n    return args\n\n\ndef get_args():\n    args = argparse.ArgumentParser(description=\"Graph Writer in DGL\")\n    args.add_argument(\"--nhid\", default=500, type=int, help=\"hidden size\")\n    args.add_argument(\"--nhead\", default=4, type=int, help=\"number of heads\")\n    args.add_argument(\"--head_dim\", default=125, type=int, help=\"head dim\")\n    args.add_argument(\n        \"--weight_decay\", default=0.0, type=float, help=\"weight decay\"\n    )\n    args.add_argument(\n        \"--prop\", default=6, type=int, help=\"number of layers of gnn\"\n    )\n    args.add_argument(\"--title\", action=\"store_true\", help=\"use title input\")\n    args.add_argument(\"--test\", action=\"store_true\", help=\"inference mode\")\n    args.add_argument(\"--batch_size\", default=32, type=int, help=\"batch_size\")\n    args.add_argument(\n        \"--beam_size\", default=4, type=int, help=\"beam size, 1 for greedy\"\n    )\n    args.add_argument(\"--epoch\", default=20, type=int, help=\"training epoch\")\n    args.add_argument(\n        \"--beam_max_len\",\n        default=200,\n        type=int,\n        help=\"max length of the generated text\",\n    )\n    args.add_argument(\n        \"--enc_lstm_layers\",\n        default=2,\n        type=int,\n        help=\"number of layers of lstm\",\n    )\n    args.add_argument(\"--lr\", default=1e-1, type=float, help=\"learning rate\")\n    # 
args.add_argument('--lr_decay', default=1e-8, type=float, help='')\n    args.add_argument(\"--clip\", default=1, type=float, help=\"gradient clip\")\n    args.add_argument(\n        \"--emb_drop\", default=0.0, type=float, help=\"embedding dropout\"\n    )\n    args.add_argument(\n        \"--attn_drop\", default=0.1, type=float, help=\"attention dropout\"\n    )\n    args.add_argument(\"--drop\", default=0.1, type=float, help=\"dropout\")\n    args.add_argument(\"--lp\", default=1.0, type=float, help=\"length penalty\")\n    args.add_argument(\n        \"--graph_enc\",\n        default=\"gtrans\",\n        type=str,\n        help=\"gnn mode, we only support the graph transformer now\",\n    )\n    args.add_argument(\n        \"--train_file\",\n        default=\"data/unprocessed.train.json\",\n        type=str,\n        help=\"training file\",\n    )\n    args.add_argument(\n        \"--valid_file\",\n        default=\"data/unprocessed.val.json\",\n        type=str,\n        help=\"validation file\",\n    )\n    args.add_argument(\n        \"--test_file\",\n        default=\"data/unprocessed.test.json\",\n        type=str,\n        help=\"test file\",\n    )\n    args.add_argument(\n        \"--save_dataset\",\n        default=\"data.pickle\",\n        type=str,\n        help=\"save path of dataset\",\n    )\n    args.add_argument(\n        \"--save_model\",\n        default=\"saved_model.pt\",\n        type=str,\n        help=\"save path of model\",\n    )\n\n    args.add_argument(\"--gpu\", default=0, type=int, help=\"gpu mode\")\n    args = args.parse_args()\n    args = fill_config(args)\n    return args\n"
  },
  {
    "path": "examples/pytorch/graphwriter/prepare_data.sh",
    "content": "wget https://data.dgl.ai/dataset/AGENDA.tar.gz\nmkdir data\ntar -C data/ -xvzf AGENDA.tar.gz\n"
  },
  {
    "path": "examples/pytorch/graphwriter/run.sh",
    "content": "nohup env CUDA_VISIBLE_DEVICES=0 python -u train.py --prop 6 --save_model tmp_model.pt --title > train_1.log 2>&1 &\n#nohup env CUDA_VISIBLE_DEVICES=2 python -u train.py --prop 6 --save_model tmp_model1.pt --title > train_2.log 2>&1 &\n#nohup env CUDA_VISIBLE_DEVICES=3 python -u train.py --prop 6 --save_model tmp_model2.pt --title > train_3.log 2>&1 &\n#nohup env CUDA_VISIBLE_DEVICES=4 python -u train.py --prop 6 --save_model tmp_model3.pt --title > train_4.log 2>&1 &\n#nohup env CUDA_VISIBLE_DEVICES=5 python -u train.py --prop 2 --save_model tmp_model4.pt --title > train_5.log 2>&1 &\n#nohup env CUDA_VISIBLE_DEVICES=6 python -u train.py --prop 2 --save_model tmp_model5.pt --title > train_6.log 2>&1 &\n"
  },
  {
    "path": "examples/pytorch/graphwriter/test.sh",
    "content": "env CUDA_VISIBLE_DEVICES=0 python -u train.py --save_model tmp_model.ptbest --test  --title --lp 1.0 --beam_size 1\nif [ ! detokenizer.perl ]; then \n    wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/8c5eaa1a122236bbf927bde4ec610906fea599e6/scripts/tokenizer/detokenizer.perl\nfi\nif [ ! multi-bleu.perl ]; then\n    wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/8c5eaa1a122236bbf927bde4ec610906fea599e6/scripts/generic/multi-bleu.perl\nfi\nperl detokenizer.perl -l en < tmp_gold.txt > tmp_gold.txt.a\nperl detokenizer.perl -l en < tmp_pred.txt > tmp_pred.txt.a\nperl multi-bleu.perl tmp_gold.txt < tmp_pred.txt\nperl multi-bleu-detok.perl tmp_gold.txt.a < tmp_pred.txt.a\n"
  },
  {
    "path": "examples/pytorch/graphwriter/train.py",
    "content": "import os\nimport sys\nimport time\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom graphwriter import *\nfrom opts import *\nfrom tqdm import tqdm\nfrom utlis import *\n\nsys.path.append(\"./pycocoevalcap\")\nfrom pycocoevalcap.bleu.bleu import Bleu\nfrom pycocoevalcap.meteor.meteor import Meteor\nfrom pycocoevalcap.rouge.rouge import Rouge\n\n\ndef train_one_epoch(model, dataloader, optimizer, args, epoch):\n    model.train()\n    tloss = 0.0\n    tcnt = 0.0\n    st_time = time.time()\n    with tqdm(dataloader, desc=\"Train Ep \" + str(epoch), mininterval=60) as tq:\n        for batch in tq:\n            pred = model(batch)\n            nll_loss = F.nll_loss(\n                pred.view(-1, pred.shape[-1]),\n                batch[\"tgt_text\"].view(-1),\n                ignore_index=0,\n            )\n            loss = nll_loss\n            optimizer.zero_grad()\n            loss.backward()\n            nn.utils.clip_grad_norm_(model.parameters(), args.clip)\n            optimizer.step()\n            loss = loss.item()\n            if loss != loss:\n                raise ValueError(\"NaN appear\")\n            tloss += loss * len(batch[\"tgt_text\"])\n            tcnt += len(batch[\"tgt_text\"])\n            tq.set_postfix({\"loss\": tloss / tcnt}, refresh=False)\n    print(\n        \"Train Ep \",\n        str(epoch),\n        \"AVG Loss \",\n        tloss / tcnt,\n        \"Steps \",\n        tcnt,\n        \"Time \",\n        time.time() - st_time,\n        \"GPU\",\n        torch.cuda.max_memory_cached() / 1024.0 / 1024.0 / 1024.0,\n    )\n    torch.save(model, args.save_model + str(epoch % 100))\n\n\nval_loss = 2**31\n\n\ndef eval_it(model, dataloader, args, epoch):\n    global val_loss\n    model.eval()\n    tloss = 0.0\n    tcnt = 0.0\n    st_time = time.time()\n    with tqdm(dataloader, desc=\"Eval Ep \" + str(epoch), mininterval=60) as tq:\n        for batch in tq:\n            with 
torch.no_grad():\n                pred = model(batch)\n                nll_loss = F.nll_loss(\n                    pred.view(-1, pred.shape[-1]),\n                    batch[\"tgt_text\"].view(-1),\n                    ignore_index=0,\n                )\n            loss = nll_loss\n            loss = loss.item()\n            tloss += loss * len(batch[\"tgt_text\"])\n            tcnt += len(batch[\"tgt_text\"])\n            tq.set_postfix({\"loss\": tloss / tcnt}, refresh=False)\n    print(\n        \"Eval Ep \",\n        str(epoch),\n        \"AVG Loss \",\n        tloss / tcnt,\n        \"Steps \",\n        tcnt,\n        \"Time \",\n        time.time() - st_time,\n    )\n    if tloss / tcnt < val_loss:\n        print(\"Saving best model \", \"Ep \", epoch, \" loss \", tloss / tcnt)\n        torch.save(model, args.save_model + \"best\")\n        val_loss = tloss / tcnt\n\n\ndef test(model, dataloader, args):\n    scorer = Bleu(4)\n    m_scorer = Meteor()\n    r_scorer = Rouge()\n    hyp = []\n    ref = []\n    model.eval()\n    gold_file = open(\"tmp_gold.txt\", \"w\")\n    pred_file = open(\"tmp_pred.txt\", \"w\")\n    with tqdm(dataloader, desc=\"Test \", mininterval=1) as tq:\n        for batch in tq:\n            with torch.no_grad():\n                seq = model(batch, beam_size=args.beam_size)\n            r = write_txt(batch, batch[\"tgt_text\"], gold_file, args)\n            h = write_txt(batch, seq, pred_file, args)\n            hyp.extend(h)\n            ref.extend(r)\n    hyp = dict(zip(range(len(hyp)), hyp))\n    ref = dict(zip(range(len(ref)), ref))\n    print(hyp[0], ref[0])\n    print(\"BLEU INP\", len(hyp), len(ref))\n    print(\"BLEU\", scorer.compute_score(ref, hyp)[0])\n    print(\"METEOR\", m_scorer.compute_score(ref, hyp)[0])\n    print(\"ROUGE_L\", r_scorer.compute_score(ref, hyp)[0])\n    gold_file.close()\n    pred_file.close()\n\n\ndef main(args):\n    if os.path.exists(args.save_dataset):\n        train_dataset, valid_dataset, 
test_dataset = pickle.load(\n            open(args.save_dataset, \"rb\")\n        )\n    else:\n        train_dataset, valid_dataset, test_dataset = get_datasets(\n            args.fnames, device=args.device, save=args.save_dataset\n        )\n    args = vocab_config(\n        args,\n        train_dataset.ent_vocab,\n        train_dataset.rel_vocab,\n        train_dataset.text_vocab,\n        train_dataset.ent_text_vocab,\n        train_dataset.title_vocab,\n    )\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_sampler=BucketSampler(train_dataset, batch_size=args.batch_size),\n        collate_fn=train_dataset.batch_fn,\n    )\n    valid_dataloader = torch.utils.data.DataLoader(\n        valid_dataset,\n        batch_size=args.batch_size,\n        shuffle=False,\n        collate_fn=train_dataset.batch_fn,\n    )\n    test_dataloader = torch.utils.data.DataLoader(\n        test_dataset,\n        batch_size=args.batch_size,\n        shuffle=False,\n        collate_fn=train_dataset.batch_fn,\n    )\n\n    model = GraphWriter(args)\n    model.to(args.device)\n    if args.test:\n        model = torch.load(args.save_model, weights_only=False)\n        model.args = args\n        print(model)\n        test(model, test_dataloader, args)\n    else:\n        optimizer = torch.optim.SGD(\n            model.parameters(),\n            lr=args.lr,\n            weight_decay=args.weight_decay,\n            momentum=0.9,\n        )\n        print(model)\n        for epoch in range(args.epoch):\n            train_one_epoch(model, train_dataloader, optimizer, args, epoch)\n            eval_it(model, valid_dataloader, args, epoch)\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/graphwriter/utlis.py",
    "content": "import json\nimport pickle\nimport random\n\nimport dgl\n\nimport numpy as np\nimport torch\n\nNODE_TYPE = {\"entity\": 0, \"root\": 1, \"relation\": 2}\n\n\ndef write_txt(batch, seqs, w_file, args):\n    # converting the prediction to real text.\n    ret = []\n    for b, seq in enumerate(seqs):\n        txt = []\n        for token in seq:\n            # copy the entity\n            if token >= len(args.text_vocab):\n                ent_text = batch[\"raw_ent_text\"][b][\n                    token - len(args.text_vocab)\n                ]\n                ent_text = filter(lambda x: x != \"<PAD>\", ent_text)\n                txt.extend(ent_text)\n            else:\n                if int(token) not in [\n                    args.text_vocab(x) for x in [\"<PAD>\", \"<BOS>\", \"<EOS>\"]\n                ]:\n                    txt.append(args.text_vocab(int(token)))\n            if int(token) == args.text_vocab(\"<EOS>\"):\n                break\n        w_file.write(\" \".join([str(x) for x in txt]) + \"\\n\")\n        ret.append([\" \".join([str(x) for x in txt])])\n    return ret\n\n\ndef replace_ent(x, ent, V):\n    # replace the entity\n    mask = x >= V\n    if mask.sum() == 0:\n        return x\n    nz = mask.nonzero()\n    fill_ent = ent[nz, x[mask] - V]\n    x = x.masked_scatter(mask, fill_ent)\n    return x\n\n\ndef len2mask(lens, device):\n    max_len = max(lens)\n    mask = (\n        torch.arange(max_len, device=device)\n        .unsqueeze(0)\n        .expand(len(lens), max_len)\n    )\n    mask = mask >= torch.LongTensor(lens).to(mask).unsqueeze(1)\n    return mask\n\n\ndef pad(var_len_list, out_type=\"list\", flatten=False):\n    if flatten:\n        lens = [len(x) for x in var_len_list]\n        var_len_list = sum(var_len_list, [])\n    max_len = max([len(x) for x in var_len_list])\n    if out_type == \"list\":\n        if flatten:\n            return [\n                x + [\"<PAD>\"] * (max_len - len(x)) for x in var_len_list\n       
     ], lens\n        else:\n            return [x + [\"<PAD>\"] * (max_len - len(x)) for x in var_len_list]\n    if out_type == \"tensor\":\n        if flatten:\n            return (\n                torch.stack(\n                    [\n                        torch.cat(\n                            [\n                                x,\n                                torch.zeros(\n                                    [max_len - len(x)] + list(x.shape[1:])\n                                ).type_as(x),\n                            ],\n                            0,\n                        )\n                        for x in var_len_list\n                    ],\n                    0,\n                ),\n                lens,\n            )\n        else:\n            return torch.stack(\n                [\n                    torch.cat(\n                        [\n                            x,\n                            torch.zeros(\n                                [max_len - len(x)] + list(x.shape[1:])\n                            ).type_as(x),\n                        ],\n                        0,\n                    )\n                    for x in var_len_list\n                ],\n                0,\n            )\n\n\nclass Vocab(object):\n    def __init__(\n        self,\n        max_vocab=2**31,\n        min_freq=-1,\n        sp=[\"<PAD>\", \"<BOS>\", \"<EOS>\", \"<UNK>\"],\n    ):\n        self.i2s = []\n        self.s2i = {}\n        self.wf = {}\n        self.max_vocab, self.min_freq, self.sp = max_vocab, min_freq, sp\n\n    def __len__(self):\n        return len(self.i2s)\n\n    def __str__(self):\n        return \"Total \" + str(len(self.i2s)) + str(self.i2s[:10])\n\n    def update(self, token):\n        if isinstance(token, list):\n            for t in token:\n                self.update(t)\n        else:\n            self.wf[token] = self.wf.get(token, 0) + 1\n\n    def build(self):\n        self.i2s.extend(self.sp)\n        sort_kv = 
sorted(self.wf.items(), key=lambda x: x[1], reverse=True)\n        for k, v in sort_kv:\n            if (\n                len(self.i2s) < self.max_vocab\n                and v >= self.min_freq\n                and k not in self.sp\n            ):\n                self.i2s.append(k)\n        self.s2i.update(list(zip(self.i2s, range(len(self.i2s)))))\n\n    def __call__(self, x):\n        if isinstance(x, int):\n            return self.i2s[x]\n        else:\n            return self.s2i.get(x, self.s2i[\"<UNK>\"])\n\n    def save(self, fname):\n        pass\n\n    def load(self, fname):\n        pass\n\n\ndef at_least(x):\n    # handling the illegal data\n    if len(x) == 0:\n        return [\"<UNK>\"]\n    else:\n        return x\n\n\nclass Example(object):\n    def __init__(self, title, ent_text, ent_type, rel, text):\n        # one object corresponds to a data sample\n        self.raw_title = title.split()\n        self.raw_ent_text = [at_least(x.split()) for x in ent_text]\n        assert min([len(x) for x in self.raw_ent_text]) > 0, str(\n            self.raw_ent_text\n        )\n        self.raw_ent_type = ent_type.split()  # <method> .. 
<>\n        self.raw_rel = []\n        for r in rel:\n            rel_list = r.split()\n            for i in range(len(rel_list)):\n                if (\n                    i > 0\n                    and i < len(rel_list) - 1\n                    and rel_list[i - 1] == \"--\"\n                    and rel_list[i] != rel_list[i].lower()\n                    and rel_list[i + 1] == \"--\"\n                ):\n                    self.raw_rel.append(\n                        [\n                            rel_list[: i - 1],\n                            rel_list[i - 1] + rel_list[i] + rel_list[i + 1],\n                            rel_list[i + 2 :],\n                        ]\n                    )\n                    break\n        self.raw_text = text.split()\n        self.graph = self.build_graph()\n\n    def __str__(self):\n        return \"\\n\".join(\n            [str(k) + \":\\t\" + str(v) for k, v in self.__dict__.items()]\n        )\n\n    def __len__(self):\n        return len(self.raw_text)\n\n    @staticmethod\n    def from_json(json_data):\n        return Example(\n            json_data[\"title\"],\n            json_data[\"entities\"],\n            json_data[\"types\"],\n            json_data[\"relations\"],\n            json_data[\"abstract\"],\n        )\n\n    def build_graph(self):\n        graph = dgl.DGLGraph()\n        ent_len = len(self.raw_ent_text)\n        rel_len = len(\n            self.raw_rel\n        )  # treat the repeated relation as different nodes, refer to the author's code\n\n        graph.add_nodes(\n            ent_len, {\"type\": torch.ones(ent_len) * NODE_TYPE[\"entity\"]}\n        )\n        graph.add_nodes(1, {\"type\": torch.ones(1) * NODE_TYPE[\"root\"]})\n        graph.add_nodes(\n            rel_len * 2,\n            {\"type\": torch.ones(rel_len * 2) * NODE_TYPE[\"relation\"]},\n        )\n        graph.add_edges(ent_len, torch.arange(ent_len))\n        graph.add_edges(torch.arange(ent_len), ent_len)\n        
graph.add_edges(\n            torch.arange(ent_len + 1 + rel_len * 2),\n            torch.arange(ent_len + 1 + rel_len * 2),\n        )\n        adj_edges = []\n        for i, r in enumerate(self.raw_rel):\n            assert len(r) == 3, str(r)\n            st, rt, ed = r\n            st_ent, ed_ent = self.raw_ent_text.index(\n                st\n            ), self.raw_ent_text.index(ed)\n            # according to the edge_softmax operator, we need to reverse the graph\n            adj_edges.append([ent_len + 1 + 2 * i, st_ent])\n            adj_edges.append([ed_ent, ent_len + 1 + 2 * i])\n            adj_edges.append([ent_len + 1 + 2 * i + 1, ed_ent])\n            adj_edges.append([st_ent, ent_len + 1 + 2 * i + 1])\n\n        if len(adj_edges) > 0:\n            graph.add_edges(*list(map(list, zip(*adj_edges))))\n        return graph\n\n    def get_tensor(\n        self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab\n    ):\n        if hasattr(self, \"_cached_tensor\"):\n            return self._cached_tensor\n        else:\n            title_data = [\"<BOS>\"] + self.raw_title + [\"<EOS>\"]\n            title = [title_vocab(x) for x in title_data]\n            ent_text = [\n                [ent_text_vocab(y) for y in x] for x in self.raw_ent_text\n            ]\n            ent_type = [\n                text_vocab(x) for x in self.raw_ent_type\n            ]  # for inference\n            rel_data = [\"--root--\"] + sum(\n                [[x[1], x[1] + \"_INV\"] for x in self.raw_rel], []\n            )\n            rel = [rel_vocab(x) for x in rel_data]\n\n            text_data = [\"<BOS>\"] + self.raw_text + [\"<EOS>\"]\n            text = [text_vocab(x) for x in text_data]\n            tgt_text = []\n            # the input text and decoding target are different since the consideration of the copy mechanism.\n            for i, str1 in enumerate(text_data):\n                if str1[0] == \"<\" and str1[-1] == \">\" and \"_\" in str1:\n         
           a, b = str1[1:-1].split(\"_\")\n                    text[i] = text_vocab(\"<\" + a + \">\")\n                    tgt_text.append(len(text_vocab) + int(b))\n                else:\n                    tgt_text.append(text[i])\n            self._cached_tensor = {\n                \"title\": torch.LongTensor(title),\n                \"ent_text\": [torch.LongTensor(x) for x in ent_text],\n                \"ent_type\": torch.LongTensor(ent_type),\n                \"rel\": torch.LongTensor(rel),\n                \"text\": torch.LongTensor(text[:-1]),\n                \"tgt_text\": torch.LongTensor(tgt_text[1:]),\n                \"graph\": self.graph,\n                \"raw_ent_text\": self.raw_ent_text,\n            }\n            return self._cached_tensor\n\n    def update_vocab(\n        self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab\n    ):\n        ent_vocab.update(self.raw_ent_type)\n        ent_text_vocab.update(self.raw_ent_text)\n        title_vocab.update(self.raw_title)\n        rel_vocab.update(\n            [\"--root--\"]\n            + [x[1] for x in self.raw_rel]\n            + [x[1] + \"_INV\" for x in self.raw_rel]\n        )\n        text_vocab.update(self.raw_ent_type)\n        text_vocab.update(self.raw_text)\n\n\nclass BucketSampler(torch.utils.data.Sampler):\n    def __init__(self, data_source, batch_size=32, bucket=3):\n        self.data_source = data_source\n        self.bucket = bucket\n        self.batch_size = batch_size\n\n    def __iter__(self):\n        # the magic number comes from the author's code\n        perm = torch.randperm(len(self.data_source))\n        lens = torch.Tensor([len(x) for x in self.data_source])\n        lens = lens[perm]\n        t1 = []\n        t2 = []\n        t3 = []\n        for i, l in enumerate(lens):\n            if l < 100:\n                t1.append(perm[i])\n            elif l > 100 and l < 220:\n                t2.append(perm[i])\n            else:\n                
t3.append(perm[i])\n        datas = [t1, t2, t3]\n        random.shuffle(datas)\n        idxs = sum(datas, [])\n        batch = []\n\n        lens = torch.Tensor([len(x) for x in self.data_source])\n        for idx in idxs:\n            batch.append(idx)\n            mlen = max([0] + [lens[x] for x in batch])\n            if (\n                (mlen < 100 and len(batch) == 32)\n                or (mlen > 100 and mlen < 220 and len(batch) >= 24)\n                or (mlen > 220 and len(batch) >= 8)\n                or len(batch) == 32\n            ):\n                yield batch\n                batch = []\n        if len(batch) > 0:\n            yield batch\n\n    def __len__(self):\n        return (len(self.data_source) + self.batch_size - 1) // self.batch_size\n\n\nclass GWdataset(torch.utils.data.Dataset):\n    def __init__(\n        self,\n        exs,\n        ent_vocab=None,\n        rel_vocab=None,\n        text_vocab=None,\n        ent_text_vocab=None,\n        title_vocab=None,\n        device=None,\n    ):\n        super(GWdataset, self).__init__()\n        self.exs = exs\n        (\n            self.ent_vocab,\n            self.rel_vocab,\n            self.text_vocab,\n            self.ent_text_vocab,\n            self.title_vocab,\n            self.device,\n        ) = (\n            ent_vocab,\n            rel_vocab,\n            text_vocab,\n            ent_text_vocab,\n            title_vocab,\n            device,\n        )\n\n    def __iter__(self):\n        return iter(self.exs)\n\n    def __getitem__(self, index):\n        return self.exs[index]\n\n    def __len__(self):\n        return len(self.exs)\n\n    def batch_fn(self, batch_ex):\n        (\n            batch_title,\n            batch_ent_text,\n            batch_ent_type,\n            batch_rel,\n            batch_text,\n            batch_tgt_text,\n            batch_graph,\n        ) = ([], [], [], [], [], [], [])\n        batch_raw_ent_text = []\n        for ex in batch_ex:\n            
ex_data = ex.get_tensor(\n                self.ent_vocab,\n                self.rel_vocab,\n                self.text_vocab,\n                self.ent_text_vocab,\n                self.title_vocab,\n            )\n            batch_title.append(ex_data[\"title\"])\n            batch_ent_text.append(ex_data[\"ent_text\"])\n            batch_ent_type.append(ex_data[\"ent_type\"])\n            batch_rel.append(ex_data[\"rel\"])\n            batch_text.append(ex_data[\"text\"])\n            batch_tgt_text.append(ex_data[\"tgt_text\"])\n            batch_graph.append(ex_data[\"graph\"])\n            batch_raw_ent_text.append(ex_data[\"raw_ent_text\"])\n        batch_title = pad(batch_title, out_type=\"tensor\")\n        batch_ent_text, ent_len = pad(\n            batch_ent_text, out_type=\"tensor\", flatten=True\n        )\n        batch_ent_type = pad(batch_ent_type, out_type=\"tensor\")\n        batch_rel = pad(batch_rel, out_type=\"tensor\")\n        batch_text = pad(batch_text, out_type=\"tensor\")\n        batch_tgt_text = pad(batch_tgt_text, out_type=\"tensor\")\n        batch_graph = dgl.batch(batch_graph)\n        batch_graph.to(self.device)\n        return {\n            \"title\": batch_title.to(self.device),\n            \"ent_text\": batch_ent_text.to(self.device),\n            \"ent_len\": ent_len,\n            \"ent_type\": batch_ent_type.to(self.device),\n            \"rel\": batch_rel.to(self.device),\n            \"text\": batch_text.to(self.device),\n            \"tgt_text\": batch_tgt_text.to(self.device),\n            \"graph\": batch_graph,\n            \"raw_ent_text\": batch_raw_ent_text,\n        }\n\n\ndef get_datasets(\n    fnames,\n    min_freq=-1,\n    sep=\";\",\n    joint_vocab=True,\n    device=None,\n    save=\"tmp.pickle\",\n):\n    # min_freq : not support now since it's very sensitive to the final results, but you can set it via passing min_freq to the Vocab class.\n    # sep : not support now\n    # joint_vocab : not support now\n    
ent_vocab = Vocab(sp=[\"<PAD>\", \"<UNK>\"])\n    title_vocab = Vocab(min_freq=5)\n    rel_vocab = Vocab(sp=[\"<PAD>\", \"<UNK>\"])\n    text_vocab = Vocab(min_freq=5)\n    ent_text_vocab = Vocab(sp=[\"<PAD>\", \"<UNK>\"])\n    datasets = []\n    for fname in fnames:\n        exs = []\n        json_datas = json.loads(open(fname).read())\n        for json_data in json_datas:\n            # construct one data example\n            ex = Example.from_json(json_data)\n            if fname == fnames[0]:  # only training set\n                ex.update_vocab(\n                    ent_vocab,\n                    rel_vocab,\n                    text_vocab,\n                    ent_text_vocab,\n                    title_vocab,\n                )\n            exs.append(ex)\n        datasets.append(exs)\n    ent_vocab.build()\n    rel_vocab.build()\n    text_vocab.build()\n    ent_text_vocab.build()\n    title_vocab.build()\n    datasets = [\n        GWdataset(\n            exs,\n            ent_vocab,\n            rel_vocab,\n            text_vocab,\n            ent_text_vocab,\n            title_vocab,\n            device,\n        )\n        for exs in datasets\n    ]\n    with open(save, \"wb\") as f:\n        pickle.dump(datasets, f)\n    return datasets\n\n\nif __name__ == \"__main__\":\n    ds = get_datasets(\n        [\n            \"data/unprocessed.val.json\",\n            \"data/unprocessed.val.json\",\n            \"data/unprocessed.test.json\",\n        ]\n    )\n    print(ds[0].exs[0])\n    print(\n        ds[0]\n        .exs[0]\n        .get_tensor(\n            ds[0].ent_vocab,\n            ds[0].rel_vocab,\n            ds[0].text_vocab,\n            ds[0].ent_text_vocab,\n            ds[0].title_vocab,\n        )\n    )\n"
  },
  {
    "path": "examples/pytorch/gxn/README.md",
    "content": "# DGL Implementation of Graph Cross Networks with Vertex Infomax Pooling (NeurIPS 2020)\n\nThis DGL example implements the GNN model proposed in the paper [Graph Cross Networks with Vertex Infomax Pooling](https://arxiv.org/pdf/2010.01804.pdf). \nThe author's codes of implementation is in [here](https://github.com/limaosen0/GXN)\n\n\nThe graph dataset used in this example \n---------------------------------------\nThe DGL's built-in LegacyTUDataset. This is a serial of graph kernel datasets for graph classification. We use 'DD', 'PROTEINS', 'ENZYMES', 'IMDB-BINARY', 'IMDB-MULTI' and 'COLLAB' in this GXN implementation. All these datasets are randomly splited to train and test set with ratio 0.9 and 0.1 (which is similar to the setting in the author's implementation).\n\nNOTE: Follow the setting of the author's implementation, for 'DD' and 'PROTEINS', we use one-hot node label as input node features. For ENZYMES', 'IMDB-BINARY', 'IMDB-MULTI' and 'COLLAB', we use the concatenation of one-hot node label (if available) and one-hot node degree as input node features.\n\n|                  | DD     | PROTEINS | ENZYMES | IMDB-BINARY  | IMDB-MULTI | COLLAB   |\n| ---------------- | ------ | -------- | ------- | ------------ | ---------- | -------- |\n| NumGraphs        | 1178   | 1113     | 600     | 1000         | 1500       | 5000     |\n| AvgNodesPerGraph | 284.32 | 39.06    | 32.63   | 19.77        | 13.00      | 74.49    |\n| AvgEdgesPerGraph | 715.66 | 72.82    | 62.14   | 96.53        | 65.94      | 2457.78  |\n| NumFeats         | 89     | 1        | 18      | -            | -          | -        |\n| NumClasses       | 2      | 2        | 6       | 2            | 3          | 2        |\n\n\nHow to run example files\n--------------------------------\nIf you want to reproduce the author's result, at the root directory of this example (gxn), run\n\n```bash\nbash scripts/run_gxn.sh ${dataset_name} ${device_id} ${num_trials} 
${print_trainlog_every}\n```\n\nIf you want to run an early-stop version of the experiment, at the root directory of this example, run\n\n```bash\nbash scripts/run_gxn_early_stop.sh ${dataset_name} ${device_id} ${num_trials} ${print_trainlog_every}\n```\n\nwhere\n- dataset_name: Dataset name used in this experiment. Could be 'DD', 'PROTEINS', 'ENZYMES', 'IMDB-BINARY', 'IMDB-MULTI' or 'COLLAB'.\n- device_id: ID of computation device. -1 for pure CPU computation. For example, if you only have a single GPU, set this value to 0.\n- num_trials: How many times the experiment is conducted.\n- print_trainlog_every: Print the training log every N epochs. -1 for silent training.\n\n\nNOTE: If you have problems when using 'IMDB-BINARY', 'IMDB-MULTI' and 'COLLAB', it could be caused by a bug in `LegacyTUDataset`/`TUDataset` in DGL (see [here](https://github.com/dmlc/dgl/pull/2543)). If your DGL version is less than or equal to 0.5.3 and you encounter problems like \"undefined variable\" (`LegacyTUDataset`) or \"the argument `force_reload=False` does not work\" (`TUDataset`), try:\n- use `TUDataset` with `force_reload=True`\n- delete dataset files \n- change `degree_as_feature(dataset)` and `node_label_as_feature(dataset, mode=mode)` to `degree_as_feature(dataset, save=False)` and `node_label_as_feature(dataset, mode=mode, save=False)` in `main.py`.\n\nPerformance\n-------------------------\n\n**Accuracy**\n\n**NOTE**: Different from our implementation, the author uses a fixed dataset split. Thus there may be differences between our results and the author's results. **To compare our implementation with the author's, we follow the setting in the author's implementation that performs model selection on the test set**. We also try early-stop with patience equal to 1/5 of the total number of epochs for some datasets. 
The results of `Author's Code` in the table below are obtained using the first-fold data as the test dataset.\n\n|                   | DD           | PROTEINS    | ENZYMES     | IMDB-BINARY | IMDB-MULTI | COLLAB     |\n| ------------------| ------------ | ----------- | ----------- | ----------- | ---------- | ---------- |\n| Reported in Paper | 82.68(4.1 )  | 79.91(4.1)  | 57.50(6.1)  | 78.60(2.3)  | 55.20(2.5) | 78.82(1.4) |\n| Author's Code     | 82.05        | 72.07       | 58.33       | 77.00       | 56.00      | 80.40      |\n| DGL               | 82.97(3.0)   | 78.21(2.0)  | 57.50(5.5)  | 78.70(4.0)  | 52.26(2.0) | 80.58(2.4) |\n| DGL(early-stop)   | 78.66(4.3)   | 73.12(3.1)  | 39.83(7.4)  | 68.60(6.7)  | 45.40(9.4) | 76.18(1.9) |\n\n\n**Speed**\n\nDevice: \n- CPU: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz\n- GPU: Tesla V100-SXM2 16GB\n\nIn seconds\n\n|               | DD    | PROTEINS | ENZYMES | IMDB-BINARY | IMDB-MULTI | COLLAB(batch_size=64) | COLLAB(batch_size=20) |\n| ------------- | ----- | -------- | ------- | ----------- | ---------- | --------------------- | --------------------- |\n| Author's Code | 25.32 | 2.93     | 1.53    | 2.42        | 3.58       | 96.69                 | 19.78                 |\n| DGL           | 2.64  | 1.86     | 1.03    | 1.79        | 2.45       | 23.52                 | 32.29                 |\n"
  },
  {
    "path": "examples/pytorch/gxn/data_preprocess.py",
    "content": "import json\nimport logging\nimport os\nimport sys\n\nimport numpy as np\nimport torch\n\nfrom dgl.data import LegacyTUDataset\n\n\ndef _load_check_mark(path: str):\n    if os.path.exists(path):\n        with open(path, \"r\") as f:\n            return json.load(f)\n    else:\n        return {}\n\n\ndef _save_check_mark(path: str, marks: dict):\n    with open(path, \"w\") as f:\n        json.dump(marks, f)\n\n\ndef node_label_as_feature(dataset: LegacyTUDataset, mode=\"concat\", save=True):\n    \"\"\"\n    Description\n    -----------\n    Add node labels to graph node features dict\n\n    Parameters\n    ----------\n    dataset : LegacyTUDataset\n        The dataset object\n    concat : str, optional\n        How to add node label to the graph. Valid options are \"add\",\n        \"replace\" and \"concat\".\n        - \"add\": Directly add node_label to graph node feature dict.\n        - \"concat\": Concatenate \"feat\" and \"node_label\"\n        - \"replace\": Use \"node_label\" as \"feat\"\n        Default: :obj:`\"concat\"`\n    save : bool, optional\n        Save the result dataset.\n        Default: :obj:`True`\n    \"\"\"\n    # check if node label is not available\n    if (\n        not os.path.exists(dataset._file_path(\"node_labels\"))\n        or len(dataset) == 0\n    ):\n        logging.warning(\"No Node Label Data\")\n        return dataset\n\n    # check if has cached value\n    check_mark_name = \"node_label_as_feature\"\n    check_mark_path = os.path.join(\n        dataset.save_path, \"info_{}_{}.json\".format(dataset.name, dataset.hash)\n    )\n    check_mark = _load_check_mark(check_mark_path)\n    if (\n        check_mark_name in check_mark\n        and check_mark[check_mark_name]\n        and not dataset._force_reload\n    ):\n        logging.warning(\"Using cached value in node_label_as_feature\")\n        return dataset\n    logging.warning(\n        \"Adding node labels into node features..., mode={}\".format(mode)\n    
)\n\n    # check if graph has \"feat\"\n    if \"feat\" not in dataset[0][0].ndata:\n        logging.warning(\"Dataset has no node feature 'feat'\")\n        if mode.lower() == \"concat\":\n            mode = \"replace\"\n\n    # first read node labels\n    DS_node_labels = dataset._idx_from_zero(\n        np.loadtxt(dataset._file_path(\"node_labels\"), dtype=int)\n    )\n    one_hot_node_labels = dataset._to_onehot(DS_node_labels)\n\n    # read graph idx\n    DS_indicator = dataset._idx_from_zero(\n        np.genfromtxt(dataset._file_path(\"graph_indicator\"), dtype=int)\n    )\n    node_idx_list = []\n    for idx in range(np.max(DS_indicator) + 1):\n        node_idx = np.where(DS_indicator == idx)\n        node_idx_list.append(node_idx[0])\n\n    # add to node feature dict\n    for idx, g in zip(node_idx_list, dataset.graph_lists):\n        node_labels_tensor = torch.tensor(one_hot_node_labels[idx, :])\n        if mode.lower() == \"concat\":\n            g.ndata[\"feat\"] = torch.cat(\n                (g.ndata[\"feat\"], node_labels_tensor), dim=1\n            )\n        elif mode.lower() == \"add\":\n            g.ndata[\"node_label\"] = node_labels_tensor\n        else:  # replace\n            g.ndata[\"feat\"] = node_labels_tensor\n\n    if save:\n        check_mark[check_mark_name] = True\n        _save_check_mark(check_mark_path, check_mark)\n        dataset.save()\n    return dataset\n\n\ndef degree_as_feature(dataset: LegacyTUDataset, save=True):\n    \"\"\"\n    Description\n    -----------\n    Use node degree (in one-hot format) as node feature\n\n    Parameters\n    ----------\n    dataset : LegacyTUDataset\n        The dataset object\n\n    save : bool, optional\n        Save the result dataset.\n        Default: :obj:`True`\n    \"\"\"\n    # first check if already have such feature\n    check_mark_name = \"degree_as_feat\"\n    feat_name = \"feat\"\n    check_mark_path = os.path.join(\n        dataset.save_path, 
\"info_{}_{}.json\".format(dataset.name, dataset.hash)\n    )\n    check_mark = _load_check_mark(check_mark_path)\n\n    if (\n        check_mark_name in check_mark\n        and check_mark[check_mark_name]\n        and not dataset._force_reload\n    ):\n        logging.warning(\"Using cached value in 'degree_as_feature'\")\n        return dataset\n\n    logging.warning(\"Adding node degree into node features...\")\n    min_degree = sys.maxsize\n    max_degree = 0\n    for i in range(len(dataset)):\n        degrees = dataset.graph_lists[i].in_degrees()\n        min_degree = min(min_degree, degrees.min().item())\n        max_degree = max(max_degree, degrees.max().item())\n\n    vec_len = max_degree - min_degree + 1\n    for i in range(len(dataset)):\n        num_nodes = dataset.graph_lists[i].num_nodes()\n        node_feat = torch.zeros((num_nodes, vec_len))\n        degrees = dataset.graph_lists[i].in_degrees()\n        node_feat[torch.arange(num_nodes), degrees - min_degree] = 1.0\n        dataset.graph_lists[i].ndata[feat_name] = node_feat\n\n    if save:\n        check_mark[check_mark_name] = True\n        dataset.save()\n        _save_check_mark(check_mark_path, check_mark)\n    return dataset\n"
  },
  {
    "path": "examples/pytorch/gxn/layers.py",
    "content": "from typing import Optional\n\nimport dgl\n\nimport torch\nimport torch.nn\nfrom dgl import DGLGraph\nfrom dgl.nn import GraphConv\nfrom torch import Tensor\n\n\nclass GraphConvWithDropout(GraphConv):\n    \"\"\"\n    A GraphConv followed by a Dropout.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        dropout=0.3,\n        norm=\"both\",\n        weight=True,\n        bias=True,\n        activation=None,\n        allow_zero_in_degree=False,\n    ):\n        super(GraphConvWithDropout, self).__init__(\n            in_feats,\n            out_feats,\n            norm,\n            weight,\n            bias,\n            activation,\n            allow_zero_in_degree,\n        )\n        self.dropout = torch.nn.Dropout(p=dropout)\n\n    def call(self, graph, feat, weight=None):\n        feat = self.dropout(feat)\n        return super(GraphConvWithDropout, self).call(graph, feat, weight)\n\n\nclass Discriminator(torch.nn.Module):\n    \"\"\"\n    Description\n    -----------\n    A discriminator used to let the network to discrimate\n    between positive (neighborhood of center node) and\n    negative (any neighborhood in graph) samplings.\n\n    Parameters\n    ----------\n    feat_dim : int\n        The number of channels of node features.\n    \"\"\"\n\n    def __init__(self, feat_dim: int):\n        super(Discriminator, self).__init__()\n        self.affine = torch.nn.Bilinear(feat_dim, feat_dim, 1)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.affine.weight)\n        torch.nn.init.zeros_(self.affine.bias)\n\n    def forward(\n        self,\n        h_x: Tensor,\n        h_pos: Tensor,\n        h_neg: Tensor,\n        bias_pos: Optional[Tensor] = None,\n        bias_neg: Optional[Tensor] = None,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        h_x : torch.Tensor\n            Node features, shape: :obj:`(num_nodes, 
feat_dim)`\n        h_pos : torch.Tensor\n            The node features of positive samples\n            It has the same shape as :obj:`h_x`\n        h_neg : torch.Tensor\n            The node features of negative samples\n            It has the same shape as :obj:`h_x`\n        bias_pos : torch.Tensor\n            Bias parameter vector for positive scores\n            shape: :obj:`(num_nodes)`\n        bias_neg : torch.Tensor\n            Bias parameter vector for negative scores\n            shape: :obj:`(num_nodes)`\n\n        Returns\n        -------\n        (torch.Tensor, torch.Tensor)\n            The output scores with shape (2 * num_nodes,), (num_nodes,)\n        \"\"\"\n        score_pos = self.affine(h_pos, h_x).squeeze()\n        score_neg = self.affine(h_neg, h_x).squeeze()\n        if bias_pos is not None:\n            score_pos = score_pos + bias_pos\n        if bias_neg is not None:\n            score_neg = score_neg + bias_neg\n\n        logits = torch.cat((score_pos, score_neg), 0)\n\n        return logits, score_pos\n\n\nclass DenseLayer(torch.nn.Module):\n    \"\"\"\n    Description\n    -----------\n    Dense layer with a linear layer and an activation function\n    \"\"\"\n\n    def __init__(\n        self, in_dim: int, out_dim: int, act: str = \"prelu\", bias=True\n    ):\n        super(DenseLayer, self).__init__()\n        self.lin = torch.nn.Linear(in_dim, out_dim, bias=bias)\n        self.act_type = act.lower()\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.lin.weight)\n        if self.lin.bias is not None:\n            torch.nn.init.zeros_(self.lin.bias)\n        if self.act_type == \"prelu\":\n            self.act = torch.nn.PReLU()\n        else:\n            self.act = torch.relu\n\n    def forward(self, x):\n        x = self.lin(x)\n        return self.act(x)\n\n\nclass IndexSelect(torch.nn.Module):\n    \"\"\"\n    Description\n    -----------\n    The index selection 
layer used by VIPool\n\n    Parameters\n    ----------\n    pool_ratio : float\n        The pooling ratio (for keeping nodes). For example,\n        if `pool_ratio=0.8`, 80\\% nodes will be preserved.\n    hidden_dim : int\n        The number of channels in node features.\n    act : str, optional\n        The activation function type.\n        Default: :obj:`'prelu'`\n    dist : int, optional\n        DO NOT USE THIS PARAMETER\n    \"\"\"\n\n    def __init__(\n        self,\n        pool_ratio: float,\n        hidden_dim: int,\n        act: str = \"prelu\",\n        dist: int = 1,\n    ):\n        super(IndexSelect, self).__init__()\n        self.pool_ratio = pool_ratio\n        self.dist = dist\n        self.dense = DenseLayer(hidden_dim, hidden_dim, act)\n        self.discriminator = Discriminator(hidden_dim)\n        self.gcn = GraphConvWithDropout(hidden_dim, hidden_dim)\n\n    def forward(\n        self,\n        graph: DGLGraph,\n        h_pos: Tensor,\n        h_neg: Tensor,\n        bias_pos: Optional[Tensor] = None,\n        bias_neg: Optional[Tensor] = None,\n    ):\n        \"\"\"\n        Description\n        -----------\n        Perform index selection\n\n        Parameters\n        ----------\n        graph : dgl.DGLGraph\n            Input graph.\n        h_pos : torch.Tensor\n            The node features of positive samples\n            It has the same shape as :obj:`h_x`\n        h_neg : torch.Tensor\n            The node features of negative samples\n            It has the same shape as :obj:`h_x`\n        bias_pos : torch.Tensor\n            Bias parameter vector for positive scores\n            shape: :obj:`(num_nodes)`\n        bias_neg : torch.Tensor\n            Bias parameter vector for negative scores\n            shape: :obj:`(num_nodes)`\n        \"\"\"\n        # compute scores\n        h_pos = self.dense(h_pos)\n        h_neg = self.dense(h_neg)\n        embed = self.gcn(graph, h_pos)\n        h_center = torch.sigmoid(embed)\n\n        
logit, logit_pos = self.discriminator(\n            h_center, h_pos, h_neg, bias_pos, bias_neg\n        )\n        scores = torch.sigmoid(logit_pos)\n\n        # sort scores\n        scores, idx = torch.sort(scores, descending=True)\n\n        # select top-k\n        num_nodes = graph.num_nodes()\n        num_select_nodes = int(self.pool_ratio * num_nodes)\n        size_list = [num_select_nodes, num_nodes - num_select_nodes]\n        select_scores, _ = torch.split(scores, size_list, dim=0)\n        select_idx, non_select_idx = torch.split(idx, size_list, dim=0)\n\n        return logit, select_scores, select_idx, non_select_idx, embed\n\n\nclass GraphPool(torch.nn.Module):\n    \"\"\"\n    Description\n    -----------\n    The pooling module for graph\n\n    Parameters\n    ----------\n    hidden_dim : int\n        The number of channels of node features.\n    use_gcn : bool, optional\n        Whether use gcn in down sampling process.\n        default: :obj:`False`\n    \"\"\"\n\n    def __init__(self, hidden_dim: int, use_gcn=False):\n        super(GraphPool, self).__init__()\n        self.use_gcn = use_gcn\n        self.down_sample_gcn = (\n            GraphConvWithDropout(hidden_dim, hidden_dim) if use_gcn else None\n        )\n\n    def forward(\n        self,\n        graph: DGLGraph,\n        feat: Tensor,\n        select_idx: Tensor,\n        non_select_idx: Optional[Tensor] = None,\n        scores: Optional[Tensor] = None,\n        pool_graph=False,\n    ):\n        \"\"\"\n        Description\n        -----------\n        Perform graph pooling.\n\n        Parameters\n        ----------\n        graph : dgl.DGLGraph\n            The input graph\n        feat : torch.Tensor\n            The input node feature\n        select_idx : torch.Tensor\n            The index in fine graph of node from\n            coarse graph, this is obtained from\n            previous graph pooling layers.\n        non_select_idx : torch.Tensor, optional\n            The index that 
not included in output graph.\n            default: :obj:`None`\n        scores : torch.Tensor, optional\n            Scores for nodes used for pooling and scaling.\n            default: :obj:`None`\n        pool_graph : bool, optional\n            Whether perform graph pooling on graph topology.\n            default: :obj:`False`\n        \"\"\"\n        if self.use_gcn:\n            feat = self.down_sample_gcn(graph, feat)\n\n        feat = feat[select_idx]\n        if scores is not None:\n            feat = feat * scores.unsqueeze(-1)\n\n        if pool_graph:\n            num_node_batch = graph.batch_num_nodes()\n            graph = dgl.node_subgraph(graph, select_idx)\n            graph.set_batch_num_nodes(num_node_batch)\n            return feat, graph\n        else:\n            return feat\n\n\nclass GraphUnpool(torch.nn.Module):\n    \"\"\"\n    Description\n    -----------\n    The unpooling module for graph\n\n    Parameters\n    ----------\n    hidden_dim : int\n        The number of channels of node features.\n    \"\"\"\n\n    def __init__(self, hidden_dim: int):\n        super(GraphUnpool, self).__init__()\n        self.up_sample_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)\n\n    def forward(self, graph: DGLGraph, feat: Tensor, select_idx: Tensor):\n        \"\"\"\n        Description\n        -----------\n        Perform graph unpooling\n\n        Parameters\n        ----------\n        graph : dgl.DGLGraph\n            The input graph\n        feat : torch.Tensor\n            The input node feature\n        select_idx : torch.Tensor\n            The index in fine graph of node from\n            coarse graph, this is obtained from\n            previous graph pooling layers.\n        \"\"\"\n        fine_feat = torch.zeros(\n            (graph.num_nodes(), feat.size(-1)), device=feat.device\n        )\n        fine_feat[select_idx] = feat\n        fine_feat = self.up_sample_gcn(graph, fine_feat)\n        return fine_feat\n"
  },
  {
    "path": "examples/pytorch/gxn/main.py",
    "content": "import json\nimport os\nfrom datetime import datetime\nfrom time import time\n\nimport dgl\n\nimport torch\nimport torch.nn.functional as F\nfrom data_preprocess import degree_as_feature, node_label_as_feature\nfrom dgl.data import LegacyTUDataset\nfrom dgl.dataloading import GraphDataLoader\nfrom networks import GraphClassifier\nfrom torch import Tensor\nfrom torch.utils.data import random_split\nfrom utils import get_stats, parse_args\n\n\ndef compute_loss(\n    cls_logits: Tensor,\n    labels: Tensor,\n    logits_s1: Tensor,\n    logits_s2: Tensor,\n    epoch: int,\n    total_epochs: int,\n    device: torch.device,\n):\n    # classification loss\n    classify_loss = F.nll_loss(cls_logits, labels.to(device))\n\n    # loss for vertex infomax pooling\n    scale1, scale2 = logits_s1.size(0) // 2, logits_s2.size(0) // 2\n    s1_label_t, s1_label_f = torch.ones(scale1), torch.zeros(scale1)\n    s2_label_t, s2_label_f = torch.ones(scale2), torch.zeros(scale2)\n    s1_label = torch.cat((s1_label_t, s1_label_f), dim=0).to(device)\n    s2_label = torch.cat((s2_label_t, s2_label_f), dim=0).to(device)\n\n    pool_loss_s1 = F.binary_cross_entropy_with_logits(logits_s1, s1_label)\n    pool_loss_s2 = F.binary_cross_entropy_with_logits(logits_s2, s2_label)\n    pool_loss = (pool_loss_s1 + pool_loss_s2) / 2\n\n    loss = classify_loss + (2 - epoch / total_epochs) * pool_loss\n\n    return loss\n\n\ndef train(\n    model: torch.nn.Module,\n    optimizer,\n    trainloader,\n    device,\n    curr_epoch,\n    total_epochs,\n):\n    model.train()\n\n    total_loss = 0.0\n    num_batches = len(trainloader)\n\n    for batch in trainloader:\n        optimizer.zero_grad()\n        batch_graphs, batch_labels = batch\n        batch_graphs = batch_graphs.to(device)\n        batch_labels = batch_labels.long().to(device)\n        out, l1, l2 = model(batch_graphs, batch_graphs.ndata[\"feat\"])\n        loss = compute_loss(\n            out, batch_labels, l1, l2, curr_epoch, 
total_epochs, device\n        )\n        loss.backward()\n        optimizer.step()\n\n        total_loss += loss.item()\n\n    return total_loss / num_batches\n\n\n@torch.no_grad()\ndef test(model: torch.nn.Module, loader, device):\n    model.eval()\n\n    correct = 0.0\n    num_graphs = 0\n\n    for batch in loader:\n        batch_graphs, batch_labels = batch\n        num_graphs += batch_labels.size(0)\n        batch_graphs = batch_graphs.to(device)\n        batch_labels = batch_labels.long().to(device)\n        out, _, _ = model(batch_graphs, batch_graphs.ndata[\"feat\"])\n        pred = out.argmax(dim=1)\n        correct += pred.eq(batch_labels).sum().item()\n\n    return correct / num_graphs\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    dataset = LegacyTUDataset(args.dataset, raw_dir=args.dataset_path)\n\n    # add self loop. We add self loop for each graph here since the function \"add_self_loop\" does not\n    # support batch graph.\n    for i in range(len(dataset)):\n        dataset.graph_lists[i] = dgl.remove_self_loop(dataset.graph_lists[i])\n        dataset.graph_lists[i] = dgl.add_self_loop(dataset.graph_lists[i])\n\n    # preprocess: use node degree/label as node feature\n    if args.degree_as_feature:\n        dataset = degree_as_feature(dataset)\n        mode = \"concat\"\n    else:\n        mode = \"replace\"\n    dataset = node_label_as_feature(dataset, mode=mode)\n\n    num_training = int(len(dataset) * 0.9)\n    num_test = len(dataset) - num_training\n    train_set, test_set = random_split(dataset, [num_training, num_test])\n\n    train_loader = GraphDataLoader(\n        train_set, batch_size=args.batch_size, shuffle=True, num_workers=1\n    )\n    test_loader = GraphDataLoader(\n        test_set, batch_size=args.batch_size, num_workers=1\n    )\n\n    device = torch.device(args.device)\n\n    # Step 2: Create model 
=================================================================== #\n    num_feature, num_classes, _ = dataset.statistics()\n    args.in_dim = int(num_feature)\n    args.out_dim = int(num_classes)\n    args.edge_feat_dim = 0  # No edge feature in datasets that we use.\n\n    model = GraphClassifier(args).to(device)\n\n    # Step 3: Create training components ===================================================== #\n    optimizer = torch.optim.Adam(\n        model.parameters(),\n        lr=args.lr,\n        amsgrad=True,\n        weight_decay=args.weight_decay,\n    )\n\n    # Step 4: training epoches =============================================================== #\n    best_test_acc = 0.0\n    best_epoch = -1\n    train_times = []\n    for e in range(args.epochs):\n        s_time = time()\n        train_loss = train(\n            model, optimizer, train_loader, device, e, args.epochs\n        )\n        train_times.append(time() - s_time)\n        test_acc = test(model, test_loader, device)\n        if test_acc > best_test_acc:\n            best_test_acc = test_acc\n            best_epoch = e + 1\n\n        if (e + 1) % args.print_every == 0:\n            log_format = (\n                \"Epoch {}: loss={:.4f}, test_acc={:.4f}, best_test_acc={:.4f}\"\n            )\n            print(log_format.format(e + 1, train_loss, test_acc, best_test_acc))\n    print(\n        \"Best Epoch {}, final test acc {:.4f}\".format(best_epoch, best_test_acc)\n    )\n    return best_test_acc, sum(train_times) / len(train_times)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    res = []\n    train_times = []\n    for i in range(args.num_trials):\n        print(\"Trial {}/{}\".format(i + 1, args.num_trials))\n        acc, train_time = main(args)\n        # acc, train_time = 0, 0\n        res.append(acc)\n        train_times.append(train_time)\n\n    mean, err_bd = get_stats(res, conf_interval=False)\n    print(\"mean acc: {:.4f}, error bound: {:.4f}\".format(mean, 
err_bd))\n\n    out_dict = {\n        \"hyper-parameters\": vars(args),\n        \"result_date\": str(datetime.now()),\n        \"result\": \"{:.4f}(+-{:.4f})\".format(mean, err_bd),\n        \"train_time\": \"{:.4f}\".format(sum(train_times) / len(train_times)),\n        \"details\": res,\n    }\n\n    with open(\n        os.path.join(args.output_path, \"{}.log\".format(args.dataset)), \"w\"\n    ) as f:\n        json.dump(out_dict, f, sort_keys=True, indent=4)\n"
  },
  {
    "path": "examples/pytorch/gxn/main_early_stop.py",
    "content": "import json\nimport os\nfrom datetime import datetime\nfrom time import time\n\nimport dgl\n\nimport torch\nimport torch.nn.functional as F\nfrom data_preprocess import degree_as_feature, node_label_as_feature\nfrom dgl.data import LegacyTUDataset\nfrom dgl.dataloading import GraphDataLoader\nfrom networks import GraphClassifier\nfrom torch import Tensor\nfrom torch.utils.data import random_split\nfrom utils import get_stats, parse_args\n\n\ndef compute_loss(\n    cls_logits: Tensor,\n    labels: Tensor,\n    logits_s1: Tensor,\n    logits_s2: Tensor,\n    epoch: int,\n    total_epochs: int,\n    device: torch.device,\n):\n    # classification loss\n    classify_loss = F.nll_loss(cls_logits, labels.to(device))\n\n    # loss for vertex infomax pooling\n    scale1, scale2 = logits_s1.size(0) // 2, logits_s2.size(0) // 2\n    s1_label_t, s1_label_f = torch.ones(scale1), torch.zeros(scale1)\n    s2_label_t, s2_label_f = torch.ones(scale2), torch.zeros(scale2)\n    s1_label = torch.cat((s1_label_t, s1_label_f), dim=0).to(device)\n    s2_label = torch.cat((s2_label_t, s2_label_f), dim=0).to(device)\n\n    pool_loss_s1 = F.binary_cross_entropy_with_logits(logits_s1, s1_label)\n    pool_loss_s2 = F.binary_cross_entropy_with_logits(logits_s2, s2_label)\n    pool_loss = (pool_loss_s1 + pool_loss_s2) / 2\n\n    loss = classify_loss + (2 - epoch / total_epochs) * pool_loss\n\n    return loss\n\n\ndef train(\n    model: torch.nn.Module,\n    optimizer,\n    trainloader,\n    device,\n    curr_epoch,\n    total_epochs,\n):\n    model.train()\n\n    total_loss = 0.0\n    num_batches = len(trainloader)\n\n    for batch in trainloader:\n        optimizer.zero_grad()\n        batch_graphs, batch_labels = batch\n        batch_graphs = batch_graphs.to(device)\n        batch_labels = batch_labels.long().to(device)\n        out, l1, l2 = model(batch_graphs, batch_graphs.ndata[\"feat\"])\n        loss = compute_loss(\n            out, batch_labels, l1, l2, curr_epoch, 
total_epochs, device\n        )\n        loss.backward()\n        optimizer.step()\n\n        total_loss += loss.item()\n\n    return total_loss / num_batches\n\n\n@torch.no_grad()\ndef test(model: torch.nn.Module, loader, device):\n    model.eval()\n\n    correct = 0.0\n    num_graphs = 0\n\n    for batch in loader:\n        batch_graphs, batch_labels = batch\n        num_graphs += batch_labels.size(0)\n        batch_graphs = batch_graphs.to(device)\n        batch_labels = batch_labels.long().to(device)\n        out, _, _ = model(batch_graphs, batch_graphs.ndata[\"feat\"])\n        pred = out.argmax(dim=1)\n        correct += pred.eq(batch_labels).sum().item()\n\n    return correct / num_graphs\n\n\n@torch.no_grad()\ndef validate(model: torch.nn.Module, loader, device, curr_epoch, total_epochs):\n    model.eval()\n\n    tt_loss = 0.0\n    correct = 0.0\n    num_graphs = 0\n    num_batchs = len(loader)\n\n    for batch in loader:\n        batch_graphs, batch_labels = batch\n        num_graphs += batch_labels.size(0)\n        batch_graphs = batch_graphs.to(device)\n        batch_labels = batch_labels.long().to(device)\n        out, l1, l2 = model(batch_graphs, batch_graphs.ndata[\"feat\"])\n        tt_loss += compute_loss(\n            out, batch_labels, l1, l2, curr_epoch, total_epochs, device\n        ).item()\n        pred = out.argmax(dim=1)\n        correct += pred.eq(batch_labels).sum().item()\n\n    return correct / num_graphs, tt_loss / num_batchs\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    dataset = LegacyTUDataset(args.dataset, raw_dir=args.dataset_path)\n\n    # add self loop. 
We add self loop for each graph here since the function \"add_self_loop\" does not\n    # support batch graph.\n    for i in range(len(dataset)):\n        dataset.graph_lists[i] = dgl.remove_self_loop(dataset.graph_lists[i])\n        dataset.graph_lists[i] = dgl.add_self_loop(dataset.graph_lists[i])\n\n    # use degree as node feature\n    if args.degree_as_feature:\n        dataset = degree_as_feature(dataset)\n        mode = \"concat\"\n    else:\n        mode = \"replace\"\n    dataset = node_label_as_feature(dataset, mode=mode)\n\n    num_training = int(len(dataset) * 0.8)\n    num_val = int(len(dataset) * 0.1)\n    num_test = len(dataset) - num_training - num_val\n    train_set, val_set, test_set = random_split(\n        dataset, [num_training, num_val, num_test]\n    )\n\n    train_loader = GraphDataLoader(\n        train_set, batch_size=args.batch_size, shuffle=True, num_workers=1\n    )\n    val_loader = GraphDataLoader(\n        val_set, batch_size=args.batch_size, num_workers=1\n    )\n    test_loader = GraphDataLoader(\n        test_set, batch_size=args.batch_size, num_workers=1\n    )\n\n    device = torch.device(args.device)\n\n    # Step 2: Create model =================================================================== #\n    num_feature, num_classes, _ = dataset.statistics()\n    args.in_dim = int(num_feature)\n    args.out_dim = int(num_classes)\n    args.edge_feat_dim = 0  # No edge feature in datasets that we use.\n\n    model = GraphClassifier(args).to(device)\n\n    # Step 3: Create training components ===================================================== #\n    optimizer = torch.optim.Adam(\n        model.parameters(),\n        lr=args.lr,\n        amsgrad=True,\n        weight_decay=args.weight_decay,\n    )\n\n    # Step 4: training epoches =============================================================== #\n    best_test_acc = 0.0\n    best_epoch = -1\n    train_times = []\n\n    bad_count = 0\n    best_val_loss = float(\"inf\")\n    for e in 
range(args.epochs):\n        s_time = time()\n        train_loss = train(\n            model, optimizer, train_loader, device, e, args.epochs\n        )\n        train_times.append(time() - s_time)\n        _, val_loss = validate(model, val_loader, device, e, args.epochs)\n        test_acc = test(model, test_loader, device)\n\n        if best_val_loss > val_loss:\n            best_val_loss = val_loss\n            best_epoch = e\n            bad_count = 0\n            best_test_acc = test_acc\n        else:\n            bad_count += 1\n\n        if bad_count > args.patience:\n            break\n\n        if (e + 1) % args.print_every == 0:\n            log_format = (\n                \"Epoch {}: loss={:.4f}, test_acc={:.4f}, best_test_acc={:.4f}\"\n            )\n            print(log_format.format(e + 1, train_loss, test_acc, best_test_acc))\n    print(\n        \"Best Epoch {}, final test acc {:.4f}\".format(best_epoch, best_test_acc)\n    )\n    return best_test_acc, sum(train_times) / len(train_times)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    res = []\n    train_times = []\n    for i in range(args.num_trials):\n        print(\"Trial {}/{}\".format(i + 1, args.num_trials))\n        acc, train_time = main(args)\n        # acc, train_time = 0, 0\n        res.append(acc)\n        train_times.append(train_time)\n\n    mean, err_bd = get_stats(res, conf_interval=False)\n    print(\"mean acc: {:.4f}, error bound: {:.4f}\".format(mean, err_bd))\n\n    out_dict = {\n        \"hyper-parameters\": vars(args),\n        \"result_date\": str(datetime.now()),\n        \"result\": \"{:.4f}(+-{:.4f})\".format(mean, err_bd),\n        \"train_time\": \"{:.4f}\".format(sum(train_times) / len(train_times)),\n        \"details\": res,\n    }\n\n    with open(\n        os.path.join(args.output_path, \"{}.log\".format(args.dataset)), \"w\"\n    ) as f:\n        json.dump(out_dict, f, sort_keys=True, indent=4)\n"
  },
  {
    "path": "examples/pytorch/gxn/networks.py",
    "content": "from typing import List, Tuple, Union\n\nfrom layers import *\nimport dgl.function as fn\nimport torch\nimport torch.nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch.glob import SortPooling\n\n\nclass GraphCrossModule(torch.nn.Module):\n    \"\"\"\n    Description\n    -----------\n    The Graph Cross Module used by Graph Cross Networks.\n    This module only contains graph cross layers.\n\n    Parameters\n    ----------\n    pool_ratios : Union[float, List[float]]\n        The pooling ratios (for keeping nodes) for each layer.\n        For example, if `pool_ratio=0.8`, 80\\% nodes will be preserved.\n        If a single float number is given, all pooling layers will have the\n        same pooling ratio.\n    in_dim : int\n        The number of input node feature channels.\n    out_dim : int\n        The number of output node feature channels.\n    hidden_dim : int\n        The number of hidden node feature channels.\n    cross_weight : float, optional\n        The weight parameter used in graph cross layers\n        Default: :obj:`1.0`\n    fuse_weight : float, optional\n        The weight parameter used at the end of GXN for channel fusion.\n        Default: :obj:`1.0`\n    \"\"\"\n\n    def __init__(\n        self,\n        pool_ratios: Union[float, List[float]],\n        in_dim: int,\n        out_dim: int,\n        hidden_dim: int,\n        cross_weight: float = 1.0,\n        fuse_weight: float = 1.0,\n        dist: int = 1,\n        num_cross_layers: int = 2,\n    ):\n        super(GraphCrossModule, self).__init__()\n        if isinstance(pool_ratios, float):\n            pool_ratios = (pool_ratios, pool_ratios)\n        self.cross_weight = cross_weight\n        self.fuse_weight = fuse_weight\n        self.num_cross_layers = num_cross_layers\n\n        # build network\n        self.start_gcn_scale1 = GraphConvWithDropout(in_dim, hidden_dim)\n        self.start_gcn_scale2 = GraphConvWithDropout(hidden_dim, hidden_dim)\n        
self.end_gcn = GraphConvWithDropout(2 * hidden_dim, out_dim)\n\n        self.index_select_scale1 = IndexSelect(\n            pool_ratios[0], hidden_dim, act=\"prelu\", dist=dist\n        )\n        self.index_select_scale2 = IndexSelect(\n            pool_ratios[1], hidden_dim, act=\"prelu\", dist=dist\n        )\n        self.start_pool_s12 = GraphPool(hidden_dim)\n        self.start_pool_s23 = GraphPool(hidden_dim)\n        self.end_unpool_s21 = GraphUnpool(hidden_dim)\n        self.end_unpool_s32 = GraphUnpool(hidden_dim)\n\n        self.s1_l1_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)\n        self.s1_l2_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)\n        self.s1_l3_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)\n\n        self.s2_l1_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)\n        self.s2_l2_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)\n        self.s2_l3_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)\n\n        self.s3_l1_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)\n        self.s3_l2_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)\n        self.s3_l3_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)\n\n        if num_cross_layers >= 1:\n            self.pool_s12_1 = GraphPool(hidden_dim, use_gcn=True)\n            self.unpool_s21_1 = GraphUnpool(hidden_dim)\n            self.pool_s23_1 = GraphPool(hidden_dim, use_gcn=True)\n            self.unpool_s32_1 = GraphUnpool(hidden_dim)\n        if num_cross_layers >= 2:\n            self.pool_s12_2 = GraphPool(hidden_dim, use_gcn=True)\n            self.unpool_s21_2 = GraphUnpool(hidden_dim)\n            self.pool_s23_2 = GraphPool(hidden_dim, use_gcn=True)\n            self.unpool_s32_2 = GraphUnpool(hidden_dim)\n\n    def forward(self, graph, feat):\n        # start of scale-1\n        graph_scale1 = graph\n        feat_scale1 = self.start_gcn_scale1(graph_scale1, feat)\n        feat_origin = feat_scale1\n        feat_scale1_neg = feat_scale1[\n            
torch.randperm(feat_scale1.size(0))\n        ]  # negative samples\n        (\n            logit_s1,\n            scores_s1,\n            select_idx_s1,\n            non_select_idx_s1,\n            feat_down_s1,\n        ) = self.index_select_scale1(graph_scale1, feat_scale1, feat_scale1_neg)\n        feat_scale2, graph_scale2 = self.start_pool_s12(\n            graph_scale1,\n            feat_scale1,\n            select_idx_s1,\n            non_select_idx_s1,\n            scores_s1,\n            pool_graph=True,\n        )\n\n        # start of scale-2\n        feat_scale2 = self.start_gcn_scale2(graph_scale2, feat_scale2)\n        feat_scale2_neg = feat_scale2[\n            torch.randperm(feat_scale2.size(0))\n        ]  # negative samples\n        (\n            logit_s2,\n            scores_s2,\n            select_idx_s2,\n            non_select_idx_s2,\n            feat_down_s2,\n        ) = self.index_select_scale2(graph_scale2, feat_scale2, feat_scale2_neg)\n        feat_scale3, graph_scale3 = self.start_pool_s23(\n            graph_scale2,\n            feat_scale2,\n            select_idx_s2,\n            non_select_idx_s2,\n            scores_s2,\n            pool_graph=True,\n        )\n\n        # layer-1\n        res_s1_0, res_s2_0, res_s3_0 = feat_scale1, feat_scale2, feat_scale3\n\n        feat_scale1 = F.relu(self.s1_l1_gcn(graph_scale1, feat_scale1))\n        feat_scale2 = F.relu(self.s2_l1_gcn(graph_scale2, feat_scale2))\n        feat_scale3 = F.relu(self.s3_l1_gcn(graph_scale3, feat_scale3))\n\n        if self.num_cross_layers >= 1:\n            feat_s12_fu = self.pool_s12_1(\n                graph_scale1,\n                feat_scale1,\n                select_idx_s1,\n                non_select_idx_s1,\n                scores_s1,\n            )\n            feat_s21_fu = self.unpool_s21_1(\n                graph_scale1, feat_scale2, select_idx_s1\n            )\n            feat_s23_fu = self.pool_s23_1(\n                graph_scale2,\n            
    feat_scale2,\n                select_idx_s2,\n                non_select_idx_s2,\n                scores_s2,\n            )\n            feat_s32_fu = self.unpool_s32_1(\n                graph_scale2, feat_scale3, select_idx_s2\n            )\n\n            feat_scale1 = (\n                feat_scale1 + self.cross_weight * feat_s21_fu + res_s1_0\n            )\n            feat_scale2 = (\n                feat_scale2\n                + self.cross_weight * (feat_s12_fu + feat_s32_fu) / 2\n                + res_s2_0\n            )\n            feat_scale3 = (\n                feat_scale3 + self.cross_weight * feat_s23_fu + res_s3_0\n            )\n\n        # layer-2\n        feat_scale1 = F.relu(self.s1_l2_gcn(graph_scale1, feat_scale1))\n        feat_scale2 = F.relu(self.s2_l2_gcn(graph_scale2, feat_scale2))\n        feat_scale3 = F.relu(self.s3_l2_gcn(graph_scale3, feat_scale3))\n\n        if self.num_cross_layers >= 2:\n            feat_s12_fu = self.pool_s12_2(\n                graph_scale1,\n                feat_scale1,\n                select_idx_s1,\n                non_select_idx_s1,\n                scores_s1,\n            )\n            feat_s21_fu = self.unpool_s21_2(\n                graph_scale1, feat_scale2, select_idx_s1\n            )\n            feat_s23_fu = self.pool_s23_2(\n                graph_scale2,\n                feat_scale2,\n                select_idx_s2,\n                non_select_idx_s2,\n                scores_s2,\n            )\n            feat_s32_fu = self.unpool_s32_2(\n                graph_scale2, feat_scale3, select_idx_s2\n            )\n\n            cross_weight = self.cross_weight * 0.05\n            feat_scale1 = feat_scale1 + cross_weight * feat_s21_fu\n            feat_scale2 = (\n                feat_scale2 + cross_weight * (feat_s12_fu + feat_s32_fu) / 2\n            )\n            feat_scale3 = feat_scale3 + cross_weight * feat_s23_fu\n\n        # layer-3\n        feat_scale1 = 
F.relu(self.s1_l3_gcn(graph_scale1, feat_scale1))\n        feat_scale2 = F.relu(self.s2_l3_gcn(graph_scale2, feat_scale2))\n        feat_scale3 = F.relu(self.s3_l3_gcn(graph_scale3, feat_scale3))\n\n        # final layers\n        feat_s3_out = (\n            self.end_unpool_s32(graph_scale2, feat_scale3, select_idx_s2)\n            + feat_down_s2\n        )\n        feat_s2_out = self.end_unpool_s21(\n            graph_scale1, feat_scale2 + feat_s3_out, select_idx_s1\n        )\n        feat_agg = (\n            feat_scale1\n            + self.fuse_weight * feat_s2_out\n            + self.fuse_weight * feat_down_s1\n        )\n        feat_agg = torch.cat((feat_agg, feat_origin), dim=1)\n        feat_agg = self.end_gcn(graph_scale1, feat_agg)\n\n        return feat_agg, logit_s1, logit_s2\n\n\nclass GraphCrossNet(torch.nn.Module):\n    \"\"\"\n    Description\n    -----------\n    The Graph Cross Network.\n\n    Parameters\n    ----------\n    in_dim : int\n        The number of input node feature channels.\n    out_dim : int\n        The number of output node feature channels.\n    edge_feat_dim : int, optional\n        The number of input edge feature channels. Edge feature\n        will be passed to a Linear layer and concatenated to\n        input node features. 
Default: :obj:`0`\n    hidden_dim : int, optional\n        The number of hidden node feature channels.\n        Default: :obj:`96`\n    pool_ratios : Union[float, List[float]], optional\n        The pooling ratios (for keeping nodes) for each layer.\n        For example, if `pool_ratio=0.8`, 80\\% nodes will be preserved.\n        If a single float number is given, all pooling layers will have the\n        same pooling ratio.\n        Default: :obj:`[0.9, 0.7]`\n    readout_nodes : int, optional\n        Number of nodes perserved in the final sort pool operation.\n        Default: :obj:`30`\n    conv1d_dims : List[int], optional\n        The number of kernels of Conv1d operations.\n        Default: :obj:`[16, 32]`\n    conv1d_kws : List[int], optional\n        The kernel size of Conv1d.\n        Default: :obj:`[5]`\n    cross_weight : float, optional\n        The weight parameter used in graph cross layers\n        Default: :obj:`1.0`\n    fuse_weight : float, optional\n        The weight parameter used at the end of GXN for channel fusion.\n        Default: :obj:`1.0`\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: int,\n        edge_feat_dim: int = 0,\n        hidden_dim: int = 96,\n        pool_ratios: Union[List[float], float] = [0.9, 0.7],\n        readout_nodes: int = 30,\n        conv1d_dims: List[int] = [16, 32],\n        conv1d_kws: List[int] = [5],\n        cross_weight: float = 1.0,\n        fuse_weight: float = 1.0,\n        dist: int = 1,\n    ):\n        super(GraphCrossNet, self).__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.hidden_dim = hidden_dim\n        self.edge_feat_dim = edge_feat_dim\n        self.readout_nodes = readout_nodes\n        conv1d_kws = [hidden_dim] + conv1d_kws\n\n        if edge_feat_dim > 0:\n            self.in_dim += hidden_dim\n            self.e2l_lin = torch.nn.Linear(edge_feat_dim, hidden_dim)\n        else:\n            self.e2l_lin = 
None\n\n        self.gxn = GraphCrossModule(\n            pool_ratios,\n            in_dim=self.in_dim,\n            out_dim=hidden_dim,\n            hidden_dim=hidden_dim // 2,\n            cross_weight=cross_weight,\n            fuse_weight=fuse_weight,\n            dist=dist,\n        )\n        self.sortpool = SortPooling(readout_nodes)\n\n        # final updates\n        self.final_conv1 = torch.nn.Conv1d(\n            1, conv1d_dims[0], kernel_size=conv1d_kws[0], stride=conv1d_kws[0]\n        )\n        self.final_maxpool = torch.nn.MaxPool1d(2, 2)\n        self.final_conv2 = torch.nn.Conv1d(\n            conv1d_dims[0], conv1d_dims[1], kernel_size=conv1d_kws[1], stride=1\n        )\n        self.final_dense_dim = int((readout_nodes - 2) / 2 + 1)\n        self.final_dense_dim = (\n            self.final_dense_dim - conv1d_kws[1] + 1\n        ) * conv1d_dims[1]\n\n        if self.out_dim > 0:\n            self.out_lin = torch.nn.Linear(self.final_dense_dim, out_dim)\n\n        self.init_weights()\n\n    def init_weights(self):\n        if self.e2l_lin is not None:\n            torch.nn.init.xavier_normal_(self.e2l_lin.weight)\n        torch.nn.init.xavier_normal_(self.final_conv1.weight)\n        torch.nn.init.xavier_normal_(self.final_conv2.weight)\n        if self.out_dim > 0:\n            torch.nn.init.xavier_normal_(self.out_lin.weight)\n\n    def forward(\n        self,\n        graph: DGLGraph,\n        node_feat: Tensor,\n        edge_feat: Optional[Tensor] = None,\n    ):\n        num_batch = graph.batch_size\n        if edge_feat is not None:\n            edge_feat = self.e2l_lin(edge_feat)\n            with graph.local_scope():\n                graph.edata[\"he\"] = edge_feat\n                graph.update_all(fn.copy_e(\"he\", \"m\"), fn.sum(\"m\", \"hn\"))\n                edge2node_feat = graph.ndata.pop(\"hn\")\n                node_feat = torch.cat((node_feat, edge2node_feat), dim=1)\n\n        node_feat, logits1, logits2 = self.gxn(graph, 
node_feat)\n        batch_sortpool_feats = self.sortpool(graph, node_feat)\n\n        # final updates\n        to_conv1d = batch_sortpool_feats.unsqueeze(1)\n        conv1d_result = F.relu(self.final_conv1(to_conv1d))\n        conv1d_result = self.final_maxpool(conv1d_result)\n        conv1d_result = F.relu(self.final_conv2(conv1d_result))\n\n        to_dense = conv1d_result.view(num_batch, -1)\n        if self.out_dim > 0:\n            out = F.relu(self.out_lin(to_dense))\n        else:\n            out = to_dense\n\n        return out, logits1, logits2\n\n\nclass GraphClassifier(torch.nn.Module):\n    \"\"\"\n    Description\n    -----------\n    Graph Classifier for graph classification.\n    GXN + MLP\n    \"\"\"\n\n    def __init__(self, args):\n        super(GraphClassifier, self).__init__()\n        self.gxn = GraphCrossNet(\n            in_dim=args.in_dim,\n            out_dim=args.embed_dim,\n            edge_feat_dim=args.edge_feat_dim,\n            hidden_dim=args.hidden_dim,\n            pool_ratios=args.pool_ratios,\n            readout_nodes=args.readout_nodes,\n            conv1d_dims=args.conv1d_dims,\n            conv1d_kws=args.conv1d_kws,\n            cross_weight=args.cross_weight,\n            fuse_weight=args.fuse_weight,\n        )\n        self.lin1 = torch.nn.Linear(args.embed_dim, args.final_dense_hidden_dim)\n        self.lin2 = torch.nn.Linear(args.final_dense_hidden_dim, args.out_dim)\n        self.dropout = args.dropout\n\n    def forward(\n        self,\n        graph: DGLGraph,\n        node_feat: Tensor,\n        edge_feat: Optional[Tensor] = None,\n    ):\n        embed, logits1, logits2 = self.gxn(graph, node_feat, edge_feat)\n        logits = F.relu(self.lin1(embed))\n        if self.dropout > 0:\n            logits = F.dropout(logits, p=self.dropout, training=self.training)\n        logits = self.lin2(logits)\n        return F.log_softmax(logits, dim=1), logits1, logits2\n"
  },
  {
    "path": "examples/pytorch/gxn/scripts/run_gxn.sh",
    "content": "#!/bin/bash\n\n# input arguments\nDATA=\"${1-DD}\"  # ENZYMES, DD, PROTEINS, COLLAB, IMDB-BINARY, IMDB-MULTI\ndevice=${2-0}\nnum_trials=${3-10}\nprint_every=${4-10}\n\n\n# general settings\nhidden_gxn=96\nk1=0.8\nk2=0.7\nsortpooling_k=30\nhidden_final=128\nbatch_size=64\ndropout=0.5\ncross_weight=1.0\nfuse_weight=0.9\nweight_decay=1e-3\n\n# dataset-specific settings\ncase ${DATA} in\nIMDB-BINARY)\n  num_epochs=200\n  learning_rate=0.001\n  sortpooling_k=31\n  k1=0.8\n  k2=0.5\n  ;;\nIMDB-MULTI)\n  num_epochs=200\n  learning_rate=0.001\n  sortpooling_k=22\n  k1=0.8\n  k2=0.7\n  ;;\nCOLLAB)\n  num_epochs=100\n  learning_rate=0.001\n  sortpooling_k=130\n  k1=0.9\n  k2=0.5\n  ;;\nDD)\n  num_epochs=100\n  learning_rate=0.0005\n  sortpooling_k=291\n  k1=0.8\n  k2=0.6\n  ;;\nPROTEINS)\n  num_epochs=100\n  learning_rate=0.001\n  sortpooling_k=32\n  k1=0.8\n  k2=0.7\n  ;;\nENZYMES)\n  num_epochs=500\n  learning_rate=0.0001\n  sortpooling_k=42\n  k1=0.7\n  k2=0.5\n  ;;\n*)\n  num_epochs=500\n  learning_rate=0.00001\n  ;;\nesac\n\n\npython main.py \\\n      --dataset $DATA \\\n      --lr $learning_rate \\\n      --epochs $num_epochs \\\n      --hidden_dim $hidden_gxn \\\n      --final_dense_hidden_dim $hidden_final \\\n      --readout_nodes $sortpooling_k \\\n      --pool_ratios $k1 $k2 \\\n      --batch_size $batch_size \\\n      --device $device \\\n      --dropout $dropout \\\n      --cross_weight $cross_weight\\\n      --fuse_weight $fuse_weight\\\n      --weight_decay $weight_decay\\\n      --num_trials $num_trials\\\n      --print_every $print_every\\\n"
  },
  {
    "path": "examples/pytorch/gxn/scripts/run_gxn_early_stop.sh",
    "content": "#!/bin/bash\n\n# input arguments\nDATA=\"${1-DD}\"  # ENZYMES, DD, PROTEINS, COLLAB, IMDB-BINARY, IMDB-MULTI\ndevice=${2-0}\nnum_trials=${3-10}\nprint_every=${4-10}\n\n\n# general settings\nhidden_gxn=96\nk1=0.8\nk2=0.7\nsortpooling_k=30\nhidden_final=128\nbatch_size=64\ndropout=0.5\ncross_weight=1.0\nfuse_weight=0.9\nweight_decay=1e-3\n\n# dataset-specific settings\ncase ${DATA} in\nIMDB-BINARY)\n  num_epochs=200\n  patience=40\n  learning_rate=0.001\n  sortpooling_k=31\n  k1=0.8\n  k2=0.5\n  ;;\nIMDB-MULTI)\n  num_epochs=200\n  patience=40\n  learning_rate=0.001\n  sortpooling_k=22\n  k1=0.8\n  k2=0.7\n  ;;\nCOLLAB)\n  num_epochs=100\n  patience=20\n  learning_rate=0.001\n  sortpooling_k=130\n  k1=0.9\n  k2=0.5\n  ;;\nDD)\n  num_epochs=100\n  patience=20\n  learning_rate=0.0005\n  sortpooling_k=291\n  k1=0.8\n  k2=0.6\n  ;;\nPROTEINS)\n  num_epochs=100\n  patience=20\n  learning_rate=0.001\n  sortpooling_k=32\n  k1=0.8\n  k2=0.7\n  ;;\nENZYMES)\n  num_epochs=500\n  patience=100\n  learning_rate=0.0001\n  sortpooling_k=42\n  k1=0.7\n  k2=0.5\n  ;;\n*)\n  num_epochs=500\n  patience=100\n  learning_rate=0.00001\n  ;;\nesac\n\n\npython main_early_stop.py \\\n      --dataset $DATA \\\n      --lr $learning_rate \\\n      --epochs $num_epochs \\\n      --hidden_dim $hidden_gxn \\\n      --final_dense_hidden_dim $hidden_final \\\n      --readout_nodes $sortpooling_k \\\n      --pool_ratios $k1 $k2 \\\n      --batch_size $batch_size \\\n      --device $device \\\n      --dropout $dropout \\\n      --cross_weight $cross_weight\\\n      --fuse_weight $fuse_weight\\\n      --weight_decay $weight_decay\\\n      --num_trials $num_trials\\\n      --print_every $print_every\\\n      --patience $patience\\\n"
  },
  {
    "path": "examples/pytorch/gxn/utils.py",
    "content": "import argparse\nimport logging\nimport math\nimport os\nimport random\n\nimport numpy as np\nimport torch\nimport torch.cuda\nfrom scipy.stats import t\n\n\ndef get_stats(\n    array, conf_interval=False, name=None, stdout=False, logout=False\n):\n    \"\"\"Compute mean and standard deviation from an numerical array\n\n    Args:\n        array (array like obj): The numerical array, this array can be\n            convert to :obj:`torch.Tensor`.\n        conf_interval (bool, optional): If True, compute the confidence interval bound (95%)\n            instead of the std value. (default: :obj:`False`)\n        name (str, optional): The name of this numerical array, for log usage.\n            (default: :obj:`None`)\n        stdout (bool, optional): Whether to output result to the terminal.\n            (default: :obj:`False`)\n        logout (bool, optional): Whether to output result via logging module.\n            (default: :obj:`False`)\n    \"\"\"\n    eps = 1e-9\n    array = torch.Tensor(array)\n    std, mean = torch.std_mean(array)\n    std = std.item()\n    mean = mean.item()\n    center = mean\n\n    if conf_interval:\n        n = array.size(0)\n        se = std / (math.sqrt(n) + eps)\n        t_value = t.ppf(0.975, df=n - 1)\n        err_bound = t_value * se\n    else:\n        err_bound = std\n\n    # log and print\n    if name is None:\n        name = \"array {}\".format(id(array))\n    log = \"{}: {:.4f}(+-{:.4f})\".format(name, center, err_bound)\n    if stdout:\n        print(log)\n    if logout:\n        logging.info(log)\n\n    return center, err_bound\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\"Graph Cross Network\")\n    parser.add_argument(\n        \"--pool_ratios\",\n        nargs=\"+\",\n        type=float,\n        help=\"The pooling ratios used in graph cross layers\",\n    )\n    parser.add_argument(\n        \"--hidden_dim\",\n        type=int,\n        default=96,\n        help=\"The number of hidden 
channels in GXN\",\n    )\n    parser.add_argument(\n        \"--cross_weight\",\n        type=float,\n        default=1.0,\n        help=\"Weight parameter used in graph cross layer\",\n    )\n    parser.add_argument(\n        \"--fuse_weight\",\n        type=float,\n        default=1.0,\n        help=\"Weight parameter for feature fusion\",\n    )\n    parser.add_argument(\n        \"--num_cross_layers\",\n        type=int,\n        default=2,\n        help=\"The number of graph corss layers\",\n    )\n    parser.add_argument(\n        \"--readout_nodes\",\n        type=int,\n        default=30,\n        help=\"Number of nodes for each graph after final graph pooling\",\n    )\n    parser.add_argument(\n        \"--conv1d_dims\",\n        nargs=\"+\",\n        type=int,\n        help=\"Number of channels in conv operations in the end of graph cross net\",\n    )\n    parser.add_argument(\n        \"--conv1d_kws\",\n        nargs=\"+\",\n        type=int,\n        help=\"Kernel sizes of conv1d operations\",\n    )\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.0, help=\"Dropout rate\"\n    )\n    parser.add_argument(\n        \"--embed_dim\",\n        type=int,\n        default=1024,\n        help=\"Number of channels of graph embedding\",\n    )\n    parser.add_argument(\n        \"--final_dense_hidden_dim\",\n        type=int,\n        default=128,\n        help=\"The number of hidden channels in final dense layers\",\n    )\n\n    parser.add_argument(\"--batch_size\", type=int, default=64, help=\"Batch size\")\n    parser.add_argument(\"--lr\", type=float, default=1e-4, help=\"Learning rate\")\n    parser.add_argument(\n        \"--weight_decay\", type=float, default=0.0, help=\"Weight decay rate\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=1000, help=\"Number of training epochs\"\n    )\n    parser.add_argument(\n        \"--patience\", type=int, default=20, help=\"Patience for early stopping\"\n    )\n   
 parser.add_argument(\n        \"--num_trials\", type=int, default=1, help=\"Number of trials\"\n    )\n\n    parser.add_argument(\n        \"--device\",\n        type=int,\n        default=0,\n        help=\"Computation device id, -1 for cpu\",\n    )\n    parser.add_argument(\n        \"--dataset\", type=str, default=\"DD\", help=\"Dataset used for training\"\n    )\n    parser.add_argument(\n        \"--seed\", type=int, default=-1, help=\"Random seed, -1 for unset\"\n    )\n    parser.add_argument(\n        \"--print_every\",\n        type=int,\n        default=10,\n        help=\"Print train log every ? epochs, -1 for silence training\",\n    )\n    parser.add_argument(\n        \"--dataset_path\",\n        type=str,\n        default=\"./datasets\",\n        help=\"Path holding your dataset\",\n    )\n    parser.add_argument(\n        \"--output_path\",\n        type=str,\n        default=\"./output\",\n        help=\"Path holding your result files\",\n    )\n\n    args = parser.parse_args()\n\n    # default value for list hyper-parameters\n    if not args.pool_ratios or len(args.pool_ratios) < 2:\n        args.pool_ratios = [0.8, 0.7]\n        logging.warning(\n            \"No valid pool_ratios is given, \"\n            \"using default value '{}'\".format(args.pool_ratios)\n        )\n    if not args.conv1d_dims or len(args.conv1d_dims) < 2:\n        args.conv1d_dims = [16, 32]\n        logging.warning(\n            \"No valid conv1d_dims is give, \"\n            \"using default value {}\".format(args.conv1d_dims)\n        )\n    if not args.conv1d_kws or len(args.conv1d_kws) < 1:\n        args.conv1d_kws = [5]\n        logging.warning(\n            \"No valid conv1d_kws is given, \"\n            \"using default value '{}'\".format(args.conv1d_kws)\n        )\n\n    # device\n    args.device = \"cpu\" if args.device < 0 else \"cuda:{}\".format(args.device)\n    if not torch.cuda.is_available():\n        logging.warning(\"GPU is not available, using CPU for 
training\")\n        args.device = \"cpu\"\n    else:\n        logging.warning(\"Device: {}\".format(args.device))\n\n    # random seed\n    if args.seed >= 0:\n        torch.manual_seed(args.seed)\n        random.seed(args.seed)\n        np.random.seed(args.seed)\n        if args.device != \"cpu\":\n            torch.cuda.manual_seed(args.seed)\n            torch.backends.cudnn.deterministic = True\n            torch.backends.cudnn.benchmark = False\n\n    # print every\n    if args.print_every < 0:\n        args.print_every = args.epochs + 1\n\n    # path\n    paths = [args.output_path, args.dataset_path]\n    for p in paths:\n        if not os.path.exists(p):\n            os.makedirs(p)\n\n    # datasets ad-hoc\n    if args.dataset in [\"COLLAB\", \"IMDB-BINARY\", \"IMDB-MULTI\", \"ENZYMES\"]:\n        args.degree_as_feature = True\n    else:\n        args.degree_as_feature = False\n\n    return args\n"
  },
  {
    "path": "examples/pytorch/han/README.md",
    "content": "# Heterogeneous Graph Attention Network (HAN) with DGL\n\nThis is an attempt to implement HAN with DGL's latest APIs for heterogeneous graphs.\nThe authors' implementation can be found [here](https://github.com/Jhy1993/HAN).\n\n## Usage\n\n`python main.py` for reproducing HAN's work on their dataset.\n\n`python main.py --hetero` for reproducing HAN's work on DGL's own dataset from\n[here](https://github.com/Jhy1993/HAN/tree/master/data/acm).  The dataset is noisy\nbecause there are same author occurring multiple times as different nodes.\n\nFor sampling-based training, `python train_sampling.py`\n\n## Performance\n\nReference performance numbers for the ACM dataset:\n\n|                     | micro f1 score | macro f1 score |\n| ------------------- | -------------- | -------------- |\n| Paper               | 89.22          | 89.40          |\n| DGL                 | 88.99          | 89.02          |\n| Softmax regression (own dataset) | 89.66  | 89.62     |\n| DGL (own dataset)   | 91.51          | 91.66          |\n\nWe ran a softmax regression to check the easiness of our own dataset.  HAN did show some improvements.\n"
  },
  {
    "path": "examples/pytorch/han/main.py",
    "content": "import torch\nfrom sklearn.metrics import f1_score\nfrom utils import EarlyStopping, load_data\n\n\ndef score(logits, labels):\n    _, indices = torch.max(logits, dim=1)\n    prediction = indices.long().cpu().numpy()\n    labels = labels.cpu().numpy()\n\n    accuracy = (prediction == labels).sum() / len(prediction)\n    micro_f1 = f1_score(labels, prediction, average=\"micro\")\n    macro_f1 = f1_score(labels, prediction, average=\"macro\")\n\n    return accuracy, micro_f1, macro_f1\n\n\ndef evaluate(model, g, features, labels, mask, loss_func):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)\n    loss = loss_func(logits[mask], labels[mask])\n    accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask])\n\n    return loss, accuracy, micro_f1, macro_f1\n\n\ndef main(args):\n    # If args['hetero'] is True, g would be a heterogeneous graph.\n    # Otherwise, it will be a list of homogeneous graphs.\n    (\n        g,\n        features,\n        labels,\n        num_classes,\n        train_idx,\n        val_idx,\n        test_idx,\n        train_mask,\n        val_mask,\n        test_mask,\n    ) = load_data(args[\"dataset\"])\n\n    if hasattr(torch, \"BoolTensor\"):\n        train_mask = train_mask.bool()\n        val_mask = val_mask.bool()\n        test_mask = test_mask.bool()\n\n    features = features.to(args[\"device\"])\n    labels = labels.to(args[\"device\"])\n    train_mask = train_mask.to(args[\"device\"])\n    val_mask = val_mask.to(args[\"device\"])\n    test_mask = test_mask.to(args[\"device\"])\n\n    if args[\"hetero\"]:\n        from model_hetero import HAN\n\n        model = HAN(\n            meta_paths=[[\"pa\", \"ap\"], [\"pf\", \"fp\"]],\n            in_size=features.shape[1],\n            hidden_size=args[\"hidden_units\"],\n            out_size=num_classes,\n            num_heads=args[\"num_heads\"],\n            dropout=args[\"dropout\"],\n        ).to(args[\"device\"])\n        g = 
g.to(args[\"device\"])\n    else:\n        from model import HAN\n\n        model = HAN(\n            num_meta_paths=len(g),\n            in_size=features.shape[1],\n            hidden_size=args[\"hidden_units\"],\n            out_size=num_classes,\n            num_heads=args[\"num_heads\"],\n            dropout=args[\"dropout\"],\n        ).to(args[\"device\"])\n        g = [graph.to(args[\"device\"]) for graph in g]\n\n    stopper = EarlyStopping(patience=args[\"patience\"])\n    loss_fcn = torch.nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args[\"lr\"], weight_decay=args[\"weight_decay\"]\n    )\n\n    for epoch in range(args[\"num_epochs\"]):\n        model.train()\n        logits = model(g, features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        train_acc, train_micro_f1, train_macro_f1 = score(\n            logits[train_mask], labels[train_mask]\n        )\n        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(\n            model, g, features, labels, val_mask, loss_fcn\n        )\n        early_stop = stopper.step(val_loss.data.item(), val_acc, model)\n\n        print(\n            \"Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | \"\n            \"Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}\".format(\n                epoch + 1,\n                loss.item(),\n                train_micro_f1,\n                train_macro_f1,\n                val_loss.item(),\n                val_micro_f1,\n                val_macro_f1,\n            )\n        )\n\n        if early_stop:\n            break\n\n    stopper.load_checkpoint(model)\n    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(\n        model, g, features, labels, test_mask, loss_fcn\n    )\n    print(\n        \"Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 
{:.4f}\".format(\n            test_loss.item(), test_micro_f1, test_macro_f1\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    from utils import setup\n\n    parser = argparse.ArgumentParser(\"HAN\")\n    parser.add_argument(\"-s\", \"--seed\", type=int, default=1, help=\"Random seed\")\n    parser.add_argument(\n        \"-ld\",\n        \"--log-dir\",\n        type=str,\n        default=\"results\",\n        help=\"Dir for saving training results\",\n    )\n    parser.add_argument(\n        \"--hetero\",\n        action=\"store_true\",\n        help=\"Use metapath coalescing with DGL's own dataset\",\n    )\n    args = parser.parse_args().__dict__\n\n    args = setup(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/han/model.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl.nn.pytorch import GATConv\n\n\nclass SemanticAttention(nn.Module):\n    def __init__(self, in_size, hidden_size=128):\n        super(SemanticAttention, self).__init__()\n\n        self.project = nn.Sequential(\n            nn.Linear(in_size, hidden_size),\n            nn.Tanh(),\n            nn.Linear(hidden_size, 1, bias=False),\n        )\n\n    def forward(self, z):\n        w = self.project(z).mean(0)  # (M, 1)\n        beta = torch.softmax(w, dim=0)  # (M, 1)\n        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)\n\n        return (beta * z).sum(1)  # (N, D * K)\n\n\nclass HANLayer(nn.Module):\n    \"\"\"\n    HAN layer.\n\n    Arguments\n    ---------\n    num_meta_paths : number of homogeneous graphs generated from the metapaths.\n    in_size : input feature dimension\n    out_size : output feature dimension\n    layer_num_heads : number of attention heads\n    dropout : Dropout probability\n\n    Inputs\n    ------\n    g : list[DGLGraph]\n        List of graphs\n    h : tensor\n        Input features\n\n    Outputs\n    -------\n    tensor\n        The output feature\n    \"\"\"\n\n    def __init__(\n        self, num_meta_paths, in_size, out_size, layer_num_heads, dropout\n    ):\n        super(HANLayer, self).__init__()\n\n        # One GAT layer for each meta path based adjacency matrix\n        self.gat_layers = nn.ModuleList()\n        for i in range(num_meta_paths):\n            self.gat_layers.append(\n                GATConv(\n                    in_size,\n                    out_size,\n                    layer_num_heads,\n                    dropout,\n                    dropout,\n                    activation=F.elu,\n                )\n            )\n        self.semantic_attention = SemanticAttention(\n            in_size=out_size * layer_num_heads\n        )\n        self.num_meta_paths = num_meta_paths\n\n    def forward(self, gs, 
h):\n        semantic_embeddings = []\n\n        for i, g in enumerate(gs):\n            semantic_embeddings.append(self.gat_layers[i](g, h).flatten(1))\n        semantic_embeddings = torch.stack(\n            semantic_embeddings, dim=1\n        )  # (N, M, D * K)\n\n        return self.semantic_attention(semantic_embeddings)  # (N, D * K)\n\n\nclass HAN(nn.Module):\n    def __init__(\n        self, num_meta_paths, in_size, hidden_size, out_size, num_heads, dropout\n    ):\n        super(HAN, self).__init__()\n\n        self.layers = nn.ModuleList()\n        self.layers.append(\n            HANLayer(\n                num_meta_paths, in_size, hidden_size, num_heads[0], dropout\n            )\n        )\n        for l in range(1, len(num_heads)):\n            self.layers.append(\n                HANLayer(\n                    num_meta_paths,\n                    hidden_size * num_heads[l - 1],\n                    hidden_size,\n                    num_heads[l],\n                    dropout,\n                )\n            )\n        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)\n\n    def forward(self, g, h):\n        for gnn in self.layers:\n            h = gnn(g, h)\n\n        return self.predict(h)\n"
  },
  {
    "path": "examples/pytorch/han/model_hetero.py",
    "content": "\"\"\"This model shows an example of using dgl.metapath_reachable_graph on the original heterogeneous\ngraph.\n\nBecause the original HAN implementation only gives the preprocessed homogeneous graph, this model\ncould not reproduce the result in HAN as they did not provide the preprocessing code, and we\nconstructed another dataset from ACM with a different set of papers, connections, features and\nlabels.\n\"\"\"\n\nimport dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import GATConv\n\n\nclass SemanticAttention(nn.Module):\n    def __init__(self, in_size, hidden_size=128):\n        super(SemanticAttention, self).__init__()\n\n        self.project = nn.Sequential(\n            nn.Linear(in_size, hidden_size),\n            nn.Tanh(),\n            nn.Linear(hidden_size, 1, bias=False),\n        )\n\n    def forward(self, z):\n        w = self.project(z).mean(0)  # (M, 1)\n        beta = torch.softmax(w, dim=0)  # (M, 1)\n        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)\n\n        return (beta * z).sum(1)  # (N, D * K)\n\n\nclass HANLayer(nn.Module):\n    \"\"\"\n    HAN layer.\n\n    Arguments\n    ---------\n    meta_paths : list of metapaths, each as a list of edge types\n    in_size : input feature dimension\n    out_size : output feature dimension\n    layer_num_heads : number of attention heads\n    dropout : Dropout probability\n\n    Inputs\n    ------\n    g : DGLGraph\n        The heterogeneous graph\n    h : tensor\n        Input features\n\n    Outputs\n    -------\n    tensor\n        The output feature\n    \"\"\"\n\n    def __init__(self, meta_paths, in_size, out_size, layer_num_heads, dropout):\n        super(HANLayer, self).__init__()\n\n        # One GAT layer for each meta path based adjacency matrix\n        self.gat_layers = nn.ModuleList()\n        for i in range(len(meta_paths)):\n            self.gat_layers.append(\n                GATConv(\n                    
in_size,\n                    out_size,\n                    layer_num_heads,\n                    dropout,\n                    dropout,\n                    activation=F.elu,\n                    allow_zero_in_degree=True,\n                )\n            )\n        self.semantic_attention = SemanticAttention(\n            in_size=out_size * layer_num_heads\n        )\n        self.meta_paths = list(tuple(meta_path) for meta_path in meta_paths)\n\n        self._cached_graph = None\n        self._cached_coalesced_graph = {}\n\n    def forward(self, g, h):\n        semantic_embeddings = []\n\n        if self._cached_graph is None or self._cached_graph is not g:\n            self._cached_graph = g\n            self._cached_coalesced_graph.clear()\n            for meta_path in self.meta_paths:\n                self._cached_coalesced_graph[\n                    meta_path\n                ] = dgl.metapath_reachable_graph(g, meta_path)\n\n        for i, meta_path in enumerate(self.meta_paths):\n            new_g = self._cached_coalesced_graph[meta_path]\n            semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))\n        semantic_embeddings = torch.stack(\n            semantic_embeddings, dim=1\n        )  # (N, M, D * K)\n\n        return self.semantic_attention(semantic_embeddings)  # (N, D * K)\n\n\nclass HAN(nn.Module):\n    def __init__(\n        self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout\n    ):\n        super(HAN, self).__init__()\n\n        self.layers = nn.ModuleList()\n        self.layers.append(\n            HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout)\n        )\n        for l in range(1, len(num_heads)):\n            self.layers.append(\n                HANLayer(\n                    meta_paths,\n                    hidden_size * num_heads[l - 1],\n                    hidden_size,\n                    num_heads[l],\n                    dropout,\n                )\n            )\n        
self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)\n\n    def forward(self, g, h):\n        for gnn in self.layers:\n            h = gnn(g, h)\n\n        return self.predict(h)\n"
  },
  {
    "path": "examples/pytorch/han/train_sampling.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nHAN mini-batch training by RandomWalkSampler.\nnote: This demo use RandomWalkSampler to sample neighbors, it's hard to get all neighbors when valid or test,\nso we sampled twice as many neighbors during val/test than training.\n\"\"\"\nimport argparse\n\nimport dgl\n\nimport numpy\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import GATConv\nfrom dgl.sampling import RandomWalkNeighborSampler\nfrom model_hetero import SemanticAttention\nfrom sklearn.metrics import f1_score\nfrom torch.utils.data import DataLoader\nfrom utils import EarlyStopping, set_random_seed\n\n\nclass HANLayer(torch.nn.Module):\n    \"\"\"\n    HAN layer.\n\n    Arguments\n    ---------\n    num_metapath : number of metapath based sub-graph\n    in_size : input feature dimension\n    out_size : output feature dimension\n    layer_num_heads : number of attention heads\n    dropout : Dropout probability\n\n    Inputs\n    ------\n    g : DGLGraph\n        The heterogeneous graph\n    h : tensor\n        Input features\n\n    Outputs\n    -------\n    tensor\n        The output feature\n    \"\"\"\n\n    def __init__(\n        self, num_metapath, in_size, out_size, layer_num_heads, dropout\n    ):\n        super(HANLayer, self).__init__()\n\n        # One GAT layer for each meta path based adjacency matrix\n        self.gat_layers = nn.ModuleList()\n        for i in range(num_metapath):\n            self.gat_layers.append(\n                GATConv(\n                    in_size,\n                    out_size,\n                    layer_num_heads,\n                    dropout,\n                    dropout,\n                    activation=F.elu,\n                    allow_zero_in_degree=True,\n                )\n            )\n        self.semantic_attention = SemanticAttention(\n            in_size=out_size * layer_num_heads\n        )\n        self.num_metapath = num_metapath\n\n    def forward(self, 
block_list, h_list):\n        semantic_embeddings = []\n\n        for i, block in enumerate(block_list):\n            semantic_embeddings.append(\n                self.gat_layers[i](block, h_list[i]).flatten(1)\n            )\n        semantic_embeddings = torch.stack(\n            semantic_embeddings, dim=1\n        )  # (N, M, D * K)\n\n        return self.semantic_attention(semantic_embeddings)  # (N, D * K)\n\n\nclass HAN(nn.Module):\n    def __init__(\n        self, num_metapath, in_size, hidden_size, out_size, num_heads, dropout\n    ):\n        super(HAN, self).__init__()\n\n        self.layers = nn.ModuleList()\n        self.layers.append(\n            HANLayer(num_metapath, in_size, hidden_size, num_heads[0], dropout)\n        )\n        for l in range(1, len(num_heads)):\n            self.layers.append(\n                HANLayer(\n                    num_metapath,\n                    hidden_size * num_heads[l - 1],\n                    hidden_size,\n                    num_heads[l],\n                    dropout,\n                )\n            )\n        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)\n\n    def forward(self, g, h):\n        for gnn in self.layers:\n            h = gnn(g, h)\n\n        return self.predict(h)\n\n\nclass HANSampler(object):\n    def __init__(self, g, metapath_list, num_neighbors):\n        self.sampler_list = []\n        for metapath in metapath_list:\n            # note: random walk may get same route(same edge), which will be removed in the sampled graph.\n            # So the sampled graph's edges may be less than num_random_walks(num_neighbors).\n            self.sampler_list.append(\n                RandomWalkNeighborSampler(\n                    G=g,\n                    num_traversals=1,\n                    termination_prob=0,\n                    num_random_walks=num_neighbors,\n                    num_neighbors=num_neighbors,\n                    metapath=metapath,\n                )\n            
)\n\n    def sample_blocks(self, seeds):\n        block_list = []\n        for sampler in self.sampler_list:\n            frontier = sampler(seeds)\n            # add self loop\n            frontier = dgl.remove_self_loop(frontier)\n            frontier.add_edges(torch.tensor(seeds), torch.tensor(seeds))\n            block = dgl.to_block(frontier, seeds)\n            block_list.append(block)\n\n        return seeds, block_list\n\n\ndef score(logits, labels):\n    _, indices = torch.max(logits, dim=1)\n    prediction = indices.long().cpu().numpy()\n    labels = labels.cpu().numpy()\n\n    accuracy = (prediction == labels).sum() / len(prediction)\n    micro_f1 = f1_score(labels, prediction, average=\"micro\")\n    macro_f1 = f1_score(labels, prediction, average=\"macro\")\n\n    return accuracy, micro_f1, macro_f1\n\n\ndef evaluate(\n    model,\n    g,\n    metapath_list,\n    num_neighbors,\n    features,\n    labels,\n    val_nid,\n    loss_fcn,\n    batch_size,\n):\n    model.eval()\n\n    han_valid_sampler = HANSampler(\n        g, metapath_list, num_neighbors=num_neighbors * 2\n    )\n    dataloader = DataLoader(\n        dataset=val_nid,\n        batch_size=batch_size,\n        collate_fn=han_valid_sampler.sample_blocks,\n        shuffle=False,\n        drop_last=False,\n        num_workers=4,\n    )\n    correct = total = 0\n    prediction_list = []\n    labels_list = []\n    with torch.no_grad():\n        for step, (seeds, blocks) in enumerate(dataloader):\n            h_list = load_subtensors(blocks, features)\n            blocks = [block.to(args[\"device\"]) for block in blocks]\n            hs = [h.to(args[\"device\"]) for h in h_list]\n\n            logits = model(blocks, hs)\n            loss = loss_fcn(\n                logits, labels[numpy.asarray(seeds)].to(args[\"device\"])\n            )\n            # get each predict label\n            _, indices = torch.max(logits, dim=1)\n            prediction = indices.long().cpu().numpy()\n            
labels_batch = labels[numpy.asarray(seeds)].cpu().numpy()\n\n            prediction_list.append(prediction)\n            labels_list.append(labels_batch)\n\n            correct += (prediction == labels_batch).sum()\n            total += prediction.shape[0]\n\n    total_prediction = numpy.concatenate(prediction_list)\n    total_labels = numpy.concatenate(labels_list)\n    micro_f1 = f1_score(total_labels, total_prediction, average=\"micro\")\n    macro_f1 = f1_score(total_labels, total_prediction, average=\"macro\")\n    accuracy = correct / total\n\n    return loss, accuracy, micro_f1, macro_f1\n\n\ndef load_subtensors(blocks, features):\n    h_list = []\n    for block in blocks:\n        input_nodes = block.srcdata[dgl.NID]\n        h_list.append(features[input_nodes])\n    return h_list\n\n\ndef main(args):\n    # acm data\n    if args[\"dataset\"] == \"ACMRaw\":\n        from utils import load_data\n\n        (\n            g,\n            features,\n            labels,\n            n_classes,\n            train_nid,\n            val_nid,\n            test_nid,\n            train_mask,\n            val_mask,\n            test_mask,\n        ) = load_data(\"ACMRaw\")\n        metapath_list = [[\"pa\", \"ap\"], [\"pf\", \"fp\"]]\n    else:\n        raise NotImplementedError(\n            \"Unsupported dataset {}\".format(args[\"dataset\"])\n        )\n\n    # Is it need to set different neighbors numbers for different meta-path based graph?\n    num_neighbors = args[\"num_neighbors\"]\n    han_sampler = HANSampler(g, metapath_list, num_neighbors)\n    # Create PyTorch DataLoader for constructing blocks\n    dataloader = DataLoader(\n        dataset=train_nid,\n        batch_size=args[\"batch_size\"],\n        collate_fn=han_sampler.sample_blocks,\n        shuffle=True,\n        drop_last=False,\n        num_workers=4,\n    )\n\n    model = HAN(\n        num_metapath=len(metapath_list),\n        in_size=features.shape[1],\n        
hidden_size=args[\"hidden_units\"],\n        out_size=n_classes,\n        num_heads=args[\"num_heads\"],\n        dropout=args[\"dropout\"],\n    ).to(args[\"device\"])\n\n    total_params = sum(p.numel() for p in model.parameters())\n    print(\"total_params: {:d}\".format(total_params))\n    total_trainable_params = sum(\n        p.numel() for p in model.parameters() if p.requires_grad\n    )\n    print(\"total trainable params: {:d}\".format(total_trainable_params))\n\n    stopper = EarlyStopping(patience=args[\"patience\"])\n    loss_fn = torch.nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args[\"lr\"], weight_decay=args[\"weight_decay\"]\n    )\n\n    for epoch in range(args[\"num_epochs\"]):\n        model.train()\n        for step, (seeds, blocks) in enumerate(dataloader):\n            h_list = load_subtensors(blocks, features)\n            blocks = [block.to(args[\"device\"]) for block in blocks]\n            hs = [h.to(args[\"device\"]) for h in h_list]\n\n            logits = model(blocks, hs)\n            loss = loss_fn(\n                logits, labels[numpy.asarray(seeds)].to(args[\"device\"])\n            )\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            # print info in each batch\n            train_acc, train_micro_f1, train_macro_f1 = score(\n                logits, labels[numpy.asarray(seeds)]\n            )\n            print(\n                \"Epoch {:d} | loss: {:.4f} | train_acc: {:.4f} | train_micro_f1: {:.4f} | train_macro_f1: {:.4f}\".format(\n                    epoch + 1, loss, train_acc, train_micro_f1, train_macro_f1\n                )\n            )\n        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(\n            model,\n            g,\n            metapath_list,\n            num_neighbors,\n            features,\n            labels,\n            val_nid,\n            loss_fn,\n            args[\"batch_size\"],\n      
  )\n        early_stop = stopper.step(val_loss.data.item(), val_acc, model)\n\n        print(\n            \"Epoch {:d} | Val loss {:.4f} | Val Accuracy {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}\".format(\n                epoch + 1, val_loss.item(), val_acc, val_micro_f1, val_macro_f1\n            )\n        )\n\n        if early_stop:\n            break\n\n    stopper.load_checkpoint(model)\n    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(\n        model,\n        g,\n        metapath_list,\n        num_neighbors,\n        features,\n        labels,\n        test_nid,\n        loss_fn,\n        args[\"batch_size\"],\n    )\n    print(\n        \"Test loss {:.4f} | Test Accuracy {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}\".format(\n            test_loss.item(), test_acc, test_micro_f1, test_macro_f1\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"mini-batch HAN\")\n    parser.add_argument(\"-s\", \"--seed\", type=int, default=1, help=\"Random seed\")\n    parser.add_argument(\"--batch_size\", type=int, default=32)\n    parser.add_argument(\"--num_neighbors\", type=int, default=20)\n    parser.add_argument(\"--lr\", type=float, default=0.001)\n    parser.add_argument(\"--num_heads\", type=list, default=[8])\n    parser.add_argument(\"--hidden_units\", type=int, default=8)\n    parser.add_argument(\"--dropout\", type=float, default=0.6)\n    parser.add_argument(\"--weight_decay\", type=float, default=0.001)\n    parser.add_argument(\"--num_epochs\", type=int, default=100)\n    parser.add_argument(\"--patience\", type=int, default=10)\n    parser.add_argument(\"--dataset\", type=str, default=\"ACMRaw\")\n    parser.add_argument(\"--device\", type=str, default=\"cuda:0\")\n\n    args = parser.parse_args().__dict__\n    # set_random_seed(args['seed'])\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/han/utils.py",
    "content": "import datetime\nimport errno\nimport os\nimport pickle\nimport random\nfrom pprint import pprint\n\nimport dgl\n\nimport numpy as np\nimport torch\nfrom dgl.data.utils import _get_dgl_url, download, get_download_dir\nfrom scipy import io as sio, sparse\n\n\ndef set_random_seed(seed=0):\n    \"\"\"Set random seed.\n    Parameters\n    ----------\n    seed : int\n        Random seed to use\n    \"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed(seed)\n\n\ndef mkdir_p(path, log=True):\n    \"\"\"Create a directory for the specified path.\n    Parameters\n    ----------\n    path : str\n        Path name\n    log : bool\n        Whether to print result for directory creation\n    \"\"\"\n    try:\n        os.makedirs(path)\n        if log:\n            print(\"Created directory {}\".format(path))\n    except OSError as exc:\n        if exc.errno == errno.EEXIST and os.path.isdir(path) and log:\n            print(\"Directory {} already exists.\".format(path))\n        else:\n            raise\n\n\ndef get_date_postfix():\n    \"\"\"Get a date based postfix for directory name.\n    Returns\n    -------\n    post_fix : str\n    \"\"\"\n    dt = datetime.datetime.now()\n    post_fix = \"{}_{:02d}-{:02d}-{:02d}\".format(\n        dt.date(), dt.hour, dt.minute, dt.second\n    )\n\n    return post_fix\n\n\ndef setup_log_dir(args, sampling=False):\n    \"\"\"Name and create directory for logging.\n    Parameters\n    ----------\n    args : dict\n        Configuration\n    Returns\n    -------\n    log_dir : str\n        Path for logging directory\n    sampling : bool\n        Whether we are using sampling based training\n    \"\"\"\n    date_postfix = get_date_postfix()\n    log_dir = os.path.join(\n        args[\"log_dir\"], \"{}_{}\".format(args[\"dataset\"], date_postfix)\n    )\n\n    if sampling:\n        log_dir = log_dir + \"_sampling\"\n\n    
mkdir_p(log_dir)\n    return log_dir\n\n\n# The configuration below is from the paper.\ndefault_configure = {\n    \"lr\": 0.005,  # Learning rate\n    \"num_heads\": [8],  # Number of attention heads for node-level attention\n    \"hidden_units\": 8,\n    \"dropout\": 0.6,\n    \"weight_decay\": 0.001,\n    \"num_epochs\": 200,\n    \"patience\": 100,\n}\n\nsampling_configure = {\"batch_size\": 20}\n\n\ndef setup(args):\n    args.update(default_configure)\n    set_random_seed(args[\"seed\"])\n    args[\"dataset\"] = \"ACMRaw\" if args[\"hetero\"] else \"ACM\"\n    args[\"device\"] = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n    args[\"log_dir\"] = setup_log_dir(args)\n    return args\n\n\ndef setup_for_sampling(args):\n    args.update(default_configure)\n    args.update(sampling_configure)\n    set_random_seed()\n    args[\"device\"] = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n    args[\"log_dir\"] = setup_log_dir(args, sampling=True)\n    return args\n\n\ndef get_binary_mask(total_size, indices):\n    mask = torch.zeros(total_size)\n    mask[indices] = 1\n    return mask.byte()\n\n\ndef load_acm(remove_self_loop):\n    url = \"dataset/ACM3025.pkl\"\n    data_path = get_download_dir() + \"/ACM3025.pkl\"\n    download(_get_dgl_url(url), path=data_path)\n\n    with open(data_path, \"rb\") as f:\n        data = pickle.load(f)\n\n    labels, features = (\n        torch.from_numpy(data[\"label\"].todense()).long(),\n        torch.from_numpy(data[\"feature\"].todense()).float(),\n    )\n    num_classes = labels.shape[1]\n    labels = labels.nonzero()[:, 1]\n\n    if remove_self_loop:\n        num_nodes = data[\"label\"].shape[0]\n        data[\"PAP\"] = sparse.csr_matrix(data[\"PAP\"] - np.eye(num_nodes))\n        data[\"PLP\"] = sparse.csr_matrix(data[\"PLP\"] - np.eye(num_nodes))\n\n    # Adjacency matrices for meta path based neighbors\n    # (Mufei): I verified both of them are binary adjacency matrices with self loops\n    author_g = 
dgl.from_scipy(data[\"PAP\"])\n    subject_g = dgl.from_scipy(data[\"PLP\"])\n    gs = [author_g, subject_g]\n\n    train_idx = torch.from_numpy(data[\"train_idx\"]).long().squeeze(0)\n    val_idx = torch.from_numpy(data[\"val_idx\"]).long().squeeze(0)\n    test_idx = torch.from_numpy(data[\"test_idx\"]).long().squeeze(0)\n\n    num_nodes = author_g.num_nodes()\n    train_mask = get_binary_mask(num_nodes, train_idx)\n    val_mask = get_binary_mask(num_nodes, val_idx)\n    test_mask = get_binary_mask(num_nodes, test_idx)\n\n    print(\"dataset loaded\")\n    pprint(\n        {\n            \"dataset\": \"ACM\",\n            \"train\": train_mask.sum().item() / num_nodes,\n            \"val\": val_mask.sum().item() / num_nodes,\n            \"test\": test_mask.sum().item() / num_nodes,\n        }\n    )\n\n    return (\n        gs,\n        features,\n        labels,\n        num_classes,\n        train_idx,\n        val_idx,\n        test_idx,\n        train_mask,\n        val_mask,\n        test_mask,\n    )\n\n\ndef load_acm_raw(remove_self_loop):\n    assert not remove_self_loop\n    url = \"dataset/ACM.mat\"\n    data_path = get_download_dir() + \"/ACM.mat\"\n    download(_get_dgl_url(url), path=data_path)\n\n    data = sio.loadmat(data_path)\n    p_vs_l = data[\"PvsL\"]  # paper-field?\n    p_vs_a = data[\"PvsA\"]  # paper-author\n    p_vs_t = data[\"PvsT\"]  # paper-term, bag of words\n    p_vs_c = data[\"PvsC\"]  # paper-conference, labels come from that\n\n    # We assign\n    # (1) KDD papers as class 0 (data mining),\n    # (2) SIGMOD and VLDB papers as class 1 (database),\n    # (3) SIGCOMM and MOBICOMM papers as class 2 (communication)\n    conf_ids = [0, 1, 9, 10, 13]\n    label_ids = [0, 1, 2, 2, 1]\n\n    p_vs_c_filter = p_vs_c[:, conf_ids]\n    p_selected = (p_vs_c_filter.sum(1) != 0).A1.nonzero()[0]\n    p_vs_l = p_vs_l[p_selected]\n    p_vs_a = p_vs_a[p_selected]\n    p_vs_t = p_vs_t[p_selected]\n    p_vs_c = p_vs_c[p_selected]\n\n    hg = 
dgl.heterograph(\n        {\n            (\"paper\", \"pa\", \"author\"): p_vs_a.nonzero(),\n            (\"author\", \"ap\", \"paper\"): p_vs_a.transpose().nonzero(),\n            (\"paper\", \"pf\", \"field\"): p_vs_l.nonzero(),\n            (\"field\", \"fp\", \"paper\"): p_vs_l.transpose().nonzero(),\n        }\n    )\n\n    features = torch.FloatTensor(p_vs_t.toarray())\n\n    pc_p, pc_c = p_vs_c.nonzero()\n    labels = np.zeros(len(p_selected), dtype=np.int64)\n    for conf_id, label_id in zip(conf_ids, label_ids):\n        labels[pc_p[pc_c == conf_id]] = label_id\n    labels = torch.LongTensor(labels)\n\n    num_classes = 3\n\n    float_mask = np.zeros(len(pc_p))\n    for conf_id in conf_ids:\n        pc_c_mask = pc_c == conf_id\n        float_mask[pc_c_mask] = np.random.permutation(\n            np.linspace(0, 1, pc_c_mask.sum())\n        )\n    train_idx = np.where(float_mask <= 0.2)[0]\n    val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]\n    test_idx = np.where(float_mask > 0.3)[0]\n\n    num_nodes = hg.num_nodes(\"paper\")\n    train_mask = get_binary_mask(num_nodes, train_idx)\n    val_mask = get_binary_mask(num_nodes, val_idx)\n    test_mask = get_binary_mask(num_nodes, test_idx)\n\n    return (\n        hg,\n        features,\n        labels,\n        num_classes,\n        train_idx,\n        val_idx,\n        test_idx,\n        train_mask,\n        val_mask,\n        test_mask,\n    )\n\n\ndef load_data(dataset, remove_self_loop=False):\n    if dataset == \"ACM\":\n        return load_acm(remove_self_loop)\n    elif dataset == \"ACMRaw\":\n        return load_acm_raw(remove_self_loop)\n    else:\n        return NotImplementedError(\"Unsupported dataset {}\".format(dataset))\n\n\nclass EarlyStopping(object):\n    def __init__(self, patience=10):\n        dt = datetime.datetime.now()\n        self.filename = \"early_stop_{}_{:02d}-{:02d}-{:02d}.pth\".format(\n            dt.date(), dt.hour, dt.minute, dt.second\n        )\n        
self.patience = patience\n        self.counter = 0\n        self.best_acc = None\n        self.best_loss = None\n        self.early_stop = False\n\n    def step(self, loss, acc, model):\n        if self.best_loss is None:\n            self.best_acc = acc\n            self.best_loss = loss\n            self.save_checkpoint(model)\n        elif (loss > self.best_loss) and (acc < self.best_acc):\n            self.counter += 1\n            print(\n                f\"EarlyStopping counter: {self.counter} out of {self.patience}\"\n            )\n            if self.counter >= self.patience:\n                self.early_stop = True\n        else:\n            if (loss <= self.best_loss) and (acc >= self.best_acc):\n                self.save_checkpoint(model)\n            self.best_loss = np.min((loss, self.best_loss))\n            self.best_acc = np.max((acc, self.best_acc))\n            self.counter = 0\n        return self.early_stop\n\n    def save_checkpoint(self, model):\n        \"\"\"Saves model when validation loss decreases.\"\"\"\n        torch.save(model.state_dict(), self.filename)\n\n    def load_checkpoint(self, model):\n        \"\"\"Load the latest checkpoint.\"\"\"\n        model.load_state_dict(torch.load(self.filename, weights_only=False))\n"
  },
  {
    "path": "examples/pytorch/hardgat/README.md",
    "content": "# HardGAT\n## DGL Implementation of h/cGAO paper.\n\nThis DGL example implements the GNN model proposed in the paper [HardGraphAttention](https://arxiv.org/abs/1907.04652.pdf). \n\nHardGANet implementor\n----------------------\nThis example was implemented by [Ericcsr](https://github.com/Ericcsr) during his Internship work at the AWS Shanghai AI Lab.\n\nThe graph dataset used in this example \n---------------------------------------\nThe DGL's built-in CoraGraphDataset. Dataset summary:\n- NumNodes: 2708\n- NumEdges: 10556\n- NumFeats: 1433\n- NumClasses: 7\n- NumTrainingSamples: 140\n- NumValidationSamples: 500\n- NumTestSamples: 1000\n\nThe DGL's build-in CiteseerGraphDataset. Dataset Summary:\n\n- NumNodes: 3327\n- NumEdges: 9228\n- NumFeats: 3703\n- NumClasses: 6\n- NumTrainingSamples: 120\n- NumValidationSamples: 500\n- NumTestSamples: 1000\n\nThe DGL's build-in PubmedGraphDataset. Dataset Summary:\n\n- NumNodes: 19717\n- NumEdges: 88651\n- NumFeats: 500\n- NumClasses: 3\n- NumTrainingSamples: 60\n- NumValidationSamples: 500\n- NumTestSamples: 1000\n\nHow to run example files\n--------------------------------\nIn the hgao folder, run\n\n**Please use `train.py`**\n\n\n```python\npython train.py --dataset=cora\n```\n\nIf want to use a GPU, run\n\n```python\npython train.py --gpu 0 --dataset=citeseer\n```\n\nIf you want to use more Graph Hard Attention Modules\n\n```python\npython train.py --num-layers <your number> --dataset=pubmed\n```\n\nIf you want to change the hard attention threshold k\n\n```python\npython train.py --k <your number> --dataset=cora\n```\n\nIf you want to test with vanillia GAT\n\n```python\npython train.py --model <gat/hgat> --dataset=cora\n```\n\n\n\nPerformance\n-------------------------\n| Models/Datasets | Cora | Citeseer | Pubmed |\n| :-------------- | :--: | :------: | -----: |\n| GAT in DGL | 81.5% | 70.1% | 77.7% |\n| HardGAT | 81.8% | 70.2% |78.0%|\n\nNotice that HardGAT Simply replace GATConv with hGAO mentioned in 
paper.\n\n"
  },
  {
    "path": "examples/pytorch/hardgat/hgao.py",
    "content": "\"\"\"\nGraph Representation Learning via Hard Attention Networks in DGL using Adam optimization.\nReferences\n----------\nPaper: https://arxiv.org/abs/1907.04652\n\"\"\"\n\nfrom functools import partial\n\nimport dgl\nimport dgl.function as fn\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.base import DGLError\nfrom dgl.nn.pytorch import edge_softmax\nfrom dgl.nn.pytorch.utils import Identity\nfrom dgl.sampling import select_topk\n\n\nclass HardGAO(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        num_heads=8,\n        feat_drop=0.0,\n        attn_drop=0.0,\n        negative_slope=0.2,\n        residual=True,\n        activation=F.elu,\n        k=8,\n    ):\n        super(HardGAO, self).__init__()\n        self.num_heads = num_heads\n        self.in_feats = in_feats\n        self.out_feats = out_feats\n        self.k = k\n        self.residual = residual\n        # Initialize Parameters for Additive Attention\n        self.fc = nn.Linear(\n            self.in_feats, self.out_feats * self.num_heads, bias=False\n        )\n        self.attn_l = nn.Parameter(\n            torch.FloatTensor(size=(1, self.num_heads, self.out_feats))\n        )\n        self.attn_r = nn.Parameter(\n            torch.FloatTensor(size=(1, self.num_heads, self.out_feats))\n        )\n        # Initialize Parameters for Hard Projection\n        self.p = nn.Parameter(torch.FloatTensor(size=(1, in_feats)))\n        # Initialize Dropouts\n        self.feat_drop = nn.Dropout(feat_drop)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.leaky_relu = nn.LeakyReLU(negative_slope)\n        if self.residual:\n            if self.in_feats == self.out_feats:\n                self.residual_module = Identity()\n            else:\n                self.residual_module = nn.Linear(\n                    self.in_feats, self.out_feats * num_heads, bias=False\n                )\n\n        
self.reset_parameters()\n        self.activation = activation\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_normal_(self.fc.weight, gain=gain)\n        nn.init.xavier_normal_(self.p, gain=gain)\n        nn.init.xavier_normal_(self.attn_l, gain=gain)\n        nn.init.xavier_normal_(self.attn_r, gain=gain)\n        if self.residual:\n            nn.init.xavier_normal_(self.residual_module.weight, gain=gain)\n\n    def forward(self, graph, feat, get_attention=False):\n        # Check in degree and generate error\n        if (graph.in_degrees() == 0).any():\n            raise DGLError(\n                \"There are 0-in-degree nodes in the graph, \"\n                \"output for those nodes will be invalid. \"\n                \"This is harmful for some applications, \"\n                \"causing silent performance regression. \"\n                \"Adding self-loop on the input graph by \"\n                \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                \"the issue. 
Setting ``allow_zero_in_degree`` \"\n                \"to be `True` when constructing this module will \"\n                \"suppress the check and let the code run.\"\n            )\n        # projection process to get importance vector y\n        graph.ndata[\"y\"] = torch.abs(\n            torch.matmul(self.p, feat.T).view(-1)\n        ) / torch.norm(self.p, p=2)\n        # Use edge message passing function to get the weight from src node\n        graph.apply_edges(fn.copy_u(\"y\", \"y\"))\n        # Select Top k neighbors\n        subgraph = select_topk(graph.cpu(), self.k, \"y\").to(graph.device)\n        # Sigmoid as information threshold\n        subgraph.ndata[\"y\"] = torch.sigmoid(subgraph.ndata[\"y\"])\n        # Using vector matrix elementwise mul for acceleration\n        feat = subgraph.ndata[\"y\"].view(-1, 1) * feat\n        feat = self.feat_drop(feat)\n        h = self.fc(feat).view(-1, self.num_heads, self.out_feats)\n        el = (h * self.attn_l).sum(dim=-1).unsqueeze(-1)\n        er = (h * self.attn_r).sum(dim=-1).unsqueeze(-1)\n        # Assign the value on the subgraph\n        subgraph.srcdata.update({\"ft\": h, \"el\": el})\n        subgraph.dstdata.update({\"er\": er})\n        # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.\n        subgraph.apply_edges(fn.u_add_v(\"el\", \"er\", \"e\"))\n        e = self.leaky_relu(subgraph.edata.pop(\"e\"))\n        # compute softmax\n        subgraph.edata[\"a\"] = self.attn_drop(edge_softmax(subgraph, e))\n        # message passing\n        subgraph.update_all(fn.u_mul_e(\"ft\", \"a\", \"m\"), fn.sum(\"m\", \"ft\"))\n        rst = subgraph.dstdata[\"ft\"]\n        # activation\n        if self.activation:\n            rst = self.activation(rst)\n        # Residual\n        if self.residual:\n            rst = rst + self.residual_module(feat).view(\n                feat.shape[0], -1, self.out_feats\n            )\n\n        if get_attention:\n            return rst, 
subgraph.edata[\"a\"]\n        else:\n            return rst\n\n\nclass HardGAT(nn.Module):\n    def __init__(\n        self,\n        g,\n        num_layers,\n        in_dim,\n        num_hidden,\n        num_classes,\n        heads,\n        activation,\n        feat_drop,\n        attn_drop,\n        negative_slope,\n        residual,\n        k,\n    ):\n        super(HardGAT, self).__init__()\n        self.g = g\n        self.num_layers = num_layers\n        self.gat_layers = nn.ModuleList()\n        self.activation = activation\n        gat_layer = partial(HardGAO, k=k)\n        muls = heads\n        # input projection (no residual)\n        self.gat_layers.append(\n            gat_layer(\n                in_dim,\n                num_hidden,\n                heads[0],\n                feat_drop,\n                attn_drop,\n                negative_slope,\n                False,\n                self.activation,\n            )\n        )\n        # hidden layers\n        for l in range(1, num_layers):\n            # due to multi-head, the in_dim = num_hidden * num_heads\n            self.gat_layers.append(\n                gat_layer(\n                    num_hidden * muls[l - 1],\n                    num_hidden,\n                    heads[l],\n                    feat_drop,\n                    attn_drop,\n                    negative_slope,\n                    residual,\n                    self.activation,\n                )\n            )\n        # output projection\n        self.gat_layers.append(\n            gat_layer(\n                num_hidden * muls[-2],\n                num_classes,\n                heads[-1],\n                feat_drop,\n                attn_drop,\n                negative_slope,\n                False,\n                None,\n            )\n        )\n\n    def forward(self, inputs):\n        h = inputs\n        for l in range(self.num_layers):\n            h = self.gat_layers[l](self.g, h).flatten(1)\n        logits = 
self.gat_layers[-1](self.g, h).mean(1)\n        return logits\n"
  },
  {
    "path": "examples/pytorch/hardgat/train.py",
    "content": "\"\"\"\nGraph Representation Learning via Hard Attention Networks in DGL using Adam optimization.\nReferences\n----------\nPaper: https://arxiv.org/abs/1907.04652\n\"\"\"\n\nimport argparse\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom hgao import HardGAT\nfrom utils import EarlyStopping\n\n\ndef accuracy(logits, labels):\n    _, indices = torch.max(logits, dim=1)\n    correct = torch.sum(indices == labels)\n    return correct.item() * 1.0 / len(labels)\n\n\ndef evaluate(model, features, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(features)\n        logits = logits[mask]\n        labels = labels[mask]\n        return accuracy(logits, labels)\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    if args.num_layers <= 0:\n        raise ValueError(\"num layer must be positive int\")\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n    else:\n        cuda = True\n        g = g.to(args.gpu)\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    num_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = g.num_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            
train_mask.int().sum().item(),\n            val_mask.int().sum().item(),\n            test_mask.int().sum().item(),\n        )\n    )\n\n    # add self loop\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n    n_edges = g.num_edges()\n    # create model\n    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]\n    model = HardGAT(\n        g,\n        args.num_layers,\n        num_feats,\n        args.num_hidden,\n        n_classes,\n        heads,\n        F.elu,\n        args.in_drop,\n        args.attn_drop,\n        args.negative_slope,\n        args.residual,\n        args.k,\n    )\n    print(model)\n    if args.early_stop:\n        stopper = EarlyStopping(patience=100)\n    if cuda:\n        model.cuda()\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    # use optimizer\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # initialize graph\n    mean = 0\n    for epoch in range(args.epochs):\n        model.train()\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        logits = model(features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if epoch >= 3:\n            mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)\n            train_acc = accuracy(logits[train_mask], labels[train_mask])\n\n            if args.fastmode:\n                val_acc = accuracy(logits[val_mask], labels[val_mask])\n            else:\n                val_acc = evaluate(model, features, labels, val_mask)\n                if args.early_stop:\n                    if stopper.step(val_acc, model):\n                        break\n\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |\"\n                \" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n              
      mean,\n                    loss.item(),\n                    train_acc,\n                    val_acc,\n                    n_edges / mean / 1000,\n                )\n            )\n\n    print()\n    if args.early_stop:\n        model.load_state_dict(\n            torch.load(\"es_checkpoint.pt\", weights_only=False)\n        )\n    acc = evaluate(model, features, labels, test_mask)\n    print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GAT\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=-1,\n        help=\"which GPU to use. Set -1 to use CPU.\",\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--num-heads\",\n        type=int,\n        default=8,\n        help=\"number of hidden attention heads\",\n    )\n    parser.add_argument(\n        \"--num-out-heads\",\n        type=int,\n        default=1,\n        help=\"number of output attention heads\",\n    )\n    parser.add_argument(\n        \"--num-layers\", type=int, default=1, help=\"number of hidden layers\"\n    )\n    parser.add_argument(\n        \"--num-hidden\", type=int, default=8, help=\"number of hidden units\"\n    )\n    parser.add_argument(\n        \"--residual\",\n        action=\"store_true\",\n        default=False,\n        help=\"use residual connection\",\n    )\n    parser.add_argument(\n        \"--in-drop\", type=float, default=0.6, help=\"input feature dropout\"\n    )\n    parser.add_argument(\n        \"--attn-drop\", type=float, default=0.6, help=\"attention dropout\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.01, help=\"learning rate\")\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"weight decay\"\n    )\n    parser.add_argument(\n        \"--negative-slope\",\n       
 type=float,\n        default=0.2,\n        help=\"the negative slope of leaky relu\",\n    )\n    parser.add_argument(\n        \"--early-stop\",\n        action=\"store_true\",\n        default=False,\n        help=\"indicates whether to use early stop or not\",\n    )\n    parser.add_argument(\n        \"--fastmode\",\n        action=\"store_true\",\n        default=False,\n        help=\"skip re-evaluate the validation set\",\n    )\n    parser.add_argument(\n        \"--k\",\n        type=int,\n        default=8,\n        help=\"top k neighor for attention calculation\",\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/hardgat/utils.py",
    "content": "\"\"\"\nGraph Representation Learning via Hard Attention Networks in DGL using Adam optimization.\nReferences\n----------\nPaper: https://arxiv.org/abs/1907.04652\n\"\"\"\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\n\nclass EarlyStopping:\n    def __init__(self, patience=10):\n        self.patience = patience\n        self.counter = 0\n        self.best_score = None\n        self.early_stop = False\n\n    def step(self, acc, model):\n        score = acc\n        if self.best_score is None:\n            self.best_score = score\n            self.save_checkpoint(model)\n        elif score < self.best_score:\n            self.counter += 1\n            print(\n                f\"EarlyStopping counter: {self.counter} out of {self.patience}\"\n            )\n            if self.counter >= self.patience:\n                self.early_stop = True\n        else:\n            self.best_score = score\n            self.save_checkpoint(model)\n            self.counter = 0\n        return self.early_stop\n\n    def save_checkpoint(self, model):\n        \"\"\"Saves model when validation loss decrease.\"\"\"\n        torch.save(model.state_dict(), \"es_checkpoint.pt\")\n"
  },
  {
    "path": "examples/pytorch/hgp_sl/README.md",
    "content": "# DGL Implementation of the HGP-SL Paper\n\nThis DGL example implements the GNN model proposed in the paper [Hierarchical Graph Pooling with Structure Learning](https://arxiv.org/pdf/1911.05954.pdf). \nThe author's codes of implementation is in [here](https://github.com/cszhangzhen/HGP-SL)\n\n\nExample implementor\n----------------------\nThis example was implemented by [Tianqi Zhang](https://github.com/lygztq) during his Applied Scientist Intern work at the AWS Shanghai AI Lab.\n\n\nThe graph dataset used in this example \n---------------------------------------\nThe DGL's built-in [LegacyTUDataset](https://docs.dgl.ai/api/python/dgl.data.html?highlight=tudataset#dgl.data.LegacyTUDataset). This is a serial of graph kernel datasets for graph classification. We use 'DD', 'PROTEINS', 'NCI1', 'NCI109', 'Mutagenicity' and 'ENZYMES' in this HGP-SL implementation. All these datasets are randomly splited to train, validation and test set with ratio 0.8, 0.1 and 0.1.\n\nNOTE: Since there is no data attributes in some of these datasets, we use node_id (in one-hot vector whose length is the max number of nodes across all graphs) as the node feature. Also note that the node_id in some datasets is not unique (e.g. 
a graph may have two nodes with the same id).\n\n|                  | DD     | PROTEINS | NCI1  | NCI109 | Mutagenicity | ENZYMES |\n| ---------------- | ------ | -------- | ----- | ------ | ------------ | ------- |\n| NumGraphs        | 1178   | 1113     | 4110  | 4127   | 4337         | 600     |\n| AvgNodesPerGraph | 284.32 | 39.06    | 29.87 | 29.68  | 30.32        | 32.63   |\n| AvgEdgesPerGraph | 715.66 | 72.82    | 32.30 | 32.13  | 30.77        | 62.14   |\n| NumFeats         | 89     | 1        | 37    | 38     | 14           | 18      |\n| NumClasses       | 2      | 2        | 2     | 2      | 2            | 6       |\n\n\nHow to run example files\n--------------------------------\nIn the HGP-SL-DGL folder, run\n\n```bash\npython main.py --dataset ${your_dataset_name_here} [hyper-parameters]\n```\n\nIf you want to use a GPU, run\n\n```bash\npython main.py --device ${your_device_id_here} --dataset ${your_dataset_name_here} [hyper-parameters]\n```\n\nFor example, to perform experiments on DD dataset on GPU, run:\n\n```bash\npython main.py --device 0 --dataset DD --lr 0.0001 --batch_size 64 --pool_ratio 0.3 --dropout 0.5 --conv_layers 2\n```\n\nNOTE: Be careful when modifying `batch_size` and `pool_ratio` for large datasets like DD. 
Too large a batch size or pooling ratio may cause out-of-memory and other severe errors.\n\nYou can find the detailed hyper-parameter settings below (in the Performance section).\n\nPerformance\n-------------------------\n\n**Hyper-parameters**\n\nThis part is taken directly from the [author's implementation](https://github.com/cszhangzhen/HGP-SL)\n\n| Datasets      | lr        | weight_decay   | batch_size      | pool_ratio     | dropout  | net_layers |\n| ------------- | --------- | -------------- | --------------- | -------------- | -------- | ---------- |\n| PROTEINS      | 0.001     | 0.001          | 512             | 0.5            | 0.0      | 3          | \n| Mutagenicity  | 0.001     | 0.001          | 512             | 0.8            | 0.0      | 3          |\n| NCI109        | 0.001     | 0.001          | 512             | 0.8            | 0.0      | 3          |\n| NCI1          | 0.001     | 0.001          | 512             | 0.8            | 0.0      | 3          |\n| DD            | 0.0001    | 0.001          | 64              | 0.3            | 0.5      | 2          |\n| ENZYMES       | 0.001     | 0.001          | 128             | 0.8            | 0.0      | 2          |\n\n\n**Accuracy**\n\n**NOTE**: We find that there is a gap between accuracy obtained via author's code and the one reported in the [paper](https://arxiv.org/pdf/1911.05954.pdf). 
An issue has been opened in the author's repo (see [here](https://github.com/cszhangzhen/HGP-SL/issues/8)).\n\n|                            | Mutagenicity | NCI109      | NCI1        | DD          |\n| -------------------------- | ------------ | ----------- | ----------- | ----------- |\n| Reported in Paper          | 82.15(0.58)  | 80.67(1.16) | 78.45(0.77) | 80.96(1.26) |\n| Author's Code (full graph) | 78.44(2.10)  | 74.44(2.05) | 77.37(2.09) | OOM         |\n| Author's Code (sample)     | 79.68(1.68)  | 73.86(1.72) | 76.29(2.14) | 75.46(3.86) |\n| DGL (full graph)           | 79.52(2.21)  | 74.86(1.99) | 74.62(2.22) | OOM         |\n| DGL (sample)               | 79.15(1.62)  | 75.39(1.86) | 73.77(2.04) | 76.47(2.14) |\n\n\n**Speed**\n\nDevice: Tesla V100-SXM2 16GB\n\nIn seconds\n\n|                               | DD(batchsize=64), large graph | Mutagenicity(batchsize=512), small graph |\n| ----------------------------- | ----------------------------- | ---------------------------------------- |\n| Author's code (sample)        | 9.96                          | 12.91                                    |\n| Author's code (full graph)    | OOM                           | 13.03                                    |\n| DGL (sample)                  | 9.50                          | 3.59                                     |\n| DGL (full graph)              | OOM                           | 3.56                                     |\n"
  },
  {
    "path": "examples/pytorch/hgp_sl/functions.py",
    "content": "\"\"\"\nAn original implementation of sparsemax (Martins & Astudillo, 2016) is available at\nhttps://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py.\nSee `From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification, ICML 2016`\nfor detailed description.\n\nHere we implement a graph-edge version of sparsemax where we perform sparsemax for all edges\nwith the same node as end-node in graphs.\n\"\"\"\nimport dgl\nimport torch\nfrom dgl._sparse_ops import _gsddmm, _gspmm\nfrom dgl.backend import astype\nfrom dgl.base import ALL, is_all\nfrom dgl.heterograph_index import HeteroGraphIndex\nfrom torch import Tensor\nfrom torch.autograd import Function\n\n\ndef _neighbor_sort(\n    scores: Tensor,\n    end_n_ids: Tensor,\n    in_degrees: Tensor,\n    cum_in_degrees: Tensor,\n):\n    \"\"\"Sort edge scores for each node\"\"\"\n    num_nodes, max_in_degree = in_degrees.size(0), int(in_degrees.max().item())\n\n    # Compute the index for dense score matrix with size (N x D_{max})\n    # Note that the end_n_ids here is the end_node tensor in dgl graph,\n    # which is not grouped by its node id (i.e. in this form: 0,0,1,1,1,...,N,N).\n    # Thus here we first sort the end_node tensor to make it easier to compute\n    # indexs in dense edge score matrix. 
Since we will need the original order\n    # for following gspmm and gsddmm operations, we also keep the reverse mapping\n    # (the reverse_perm) here.\n    end_n_ids, perm = torch.sort(end_n_ids)\n    scores = scores[perm]\n    _, reverse_perm = torch.sort(perm)\n\n    index = torch.arange(\n        end_n_ids.size(0), dtype=torch.long, device=scores.device\n    )\n    index = (index - cum_in_degrees[end_n_ids]) + (end_n_ids * max_in_degree)\n    index = index.long()\n\n    dense_scores = scores.new_full(\n        (num_nodes * max_in_degree,), torch.finfo(scores.dtype).min\n    )\n    dense_scores[index] = scores\n    dense_scores = dense_scores.view(num_nodes, max_in_degree)\n\n    sorted_dense_scores, dense_reverse_perm = dense_scores.sort(\n        dim=-1, descending=True\n    )\n    _, dense_reverse_perm = torch.sort(dense_reverse_perm, dim=-1)\n    dense_reverse_perm = dense_reverse_perm + cum_in_degrees.view(-1, 1)\n    dense_reverse_perm = dense_reverse_perm.view(-1)\n    cumsum_sorted_dense_scores = sorted_dense_scores.cumsum(dim=-1).view(-1)\n    sorted_dense_scores = sorted_dense_scores.view(-1)\n    arange_vec = torch.arange(\n        1, max_in_degree + 1, dtype=torch.long, device=end_n_ids.device\n    )\n    arange_vec = torch.repeat_interleave(\n        arange_vec.view(1, -1), num_nodes, dim=0\n    ).view(-1)\n\n    valid_mask = sorted_dense_scores != torch.finfo(scores.dtype).min\n    sorted_scores = sorted_dense_scores[valid_mask]\n    cumsum_sorted_scores = cumsum_sorted_dense_scores[valid_mask]\n    arange_vec = arange_vec[valid_mask]\n    dense_reverse_perm = dense_reverse_perm[valid_mask].long()\n\n    return (\n        sorted_scores,\n        cumsum_sorted_scores,\n        arange_vec,\n        reverse_perm,\n        dense_reverse_perm,\n    )\n\n\ndef _threshold_and_support_graph(\n    gidx: HeteroGraphIndex, scores: Tensor, end_n_ids: Tensor\n):\n    \"\"\"Find the threshold for each node and its edges\"\"\"\n    in_degrees = _gspmm(gidx, 
\"copy_rhs\", \"sum\", None, torch.ones_like(scores))[\n        0\n    ]\n    cum_in_degrees = torch.cat(\n        [in_degrees.new_zeros(1), in_degrees.cumsum(dim=0)[:-1]], dim=0\n    )\n\n    # perform sort on edges for each node\n    (\n        sorted_scores,\n        cumsum_scores,\n        rhos,\n        reverse_perm,\n        dense_reverse_perm,\n    ) = _neighbor_sort(scores, end_n_ids, in_degrees, cum_in_degrees)\n    cumsum_scores = cumsum_scores - 1.0\n    support = rhos * sorted_scores > cumsum_scores\n    support = support[dense_reverse_perm]  # from sorted order to unsorted order\n    support = support[reverse_perm]  # from src-dst order to eid order\n\n    support_size = _gspmm(gidx, \"copy_rhs\", \"sum\", None, support.float())[0]\n    support_size = support_size.long()\n    idx = support_size + cum_in_degrees - 1\n\n    # mask invalid index, for example, if batch is not start from 0 or not continuous, it may result in negative index\n    mask = idx < 0\n    idx[mask] = 0\n    tau = cumsum_scores.gather(0, idx.long())\n    tau /= support_size.to(scores.dtype)\n\n    return tau, support_size\n\n\nclass EdgeSparsemaxFunction(Function):\n    r\"\"\"\n    Description\n    -----------\n    Pytorch Auto-Grad Function for edge sparsemax.\n\n    We define this auto-grad function here since\n    sparsemax involves sort and select, which are\n    not derivative.\n    \"\"\"\n\n    @staticmethod\n    def forward(\n        ctx,\n        gidx: HeteroGraphIndex,\n        scores: Tensor,\n        eids: Tensor,\n        end_n_ids: Tensor,\n        norm_by: str,\n    ):\n        if not is_all(eids):\n            gidx = gidx.edge_subgraph([eids], True).graph\n        if norm_by == \"src\":\n            gidx = gidx.reverse()\n\n        # use feat - max(feat) for numerical stability.\n        scores = scores.float()\n        scores_max = _gspmm(gidx, \"copy_rhs\", \"max\", None, scores)[0]\n        scores = _gsddmm(gidx, \"sub\", scores, scores_max, \"e\", \"v\")\n\n     
   # find threshold for each node and perform ReLU(u-t(u)) operation.\n        tau, supp_size = _threshold_and_support_graph(gidx, scores, end_n_ids)\n        out = torch.clamp(_gsddmm(gidx, \"sub\", scores, tau, \"e\", \"v\"), min=0)\n        ctx.backward_cache = gidx\n        ctx.save_for_backward(supp_size, out)\n        torch.cuda.empty_cache()\n        return out\n\n    @staticmethod\n    def backward(ctx, grad_out):\n        gidx = ctx.backward_cache\n        supp_size, out = ctx.saved_tensors\n        grad_in = grad_out.clone()\n\n        # grad for ReLU\n        grad_in[out == 0] = 0\n\n        # dL/dv_i = dL/do_i - 1/k \\sum_{j=1}^k dL/do_j\n        v_hat = _gspmm(gidx, \"copy_rhs\", \"sum\", None, grad_in)[\n            0\n        ] / supp_size.to(out.dtype)\n        grad_in_modify = _gsddmm(gidx, \"sub\", grad_in, v_hat, \"e\", \"v\")\n        grad_in = torch.where(out != 0, grad_in_modify, grad_in)\n        del gidx\n        torch.cuda.empty_cache()\n\n        return None, grad_in, None, None, None\n\n\ndef edge_sparsemax(graph: dgl.DGLGraph, logits, eids=ALL, norm_by=\"dst\"):\n    r\"\"\"\n    Description\n    -----------\n    Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes\n\n    .. math::\n      a_{ij} = \\text{ReLU}(z_{ij} - \\tau(\\z_{i,:}))\n\n    where :math:`z_{ij}` is a signal of edge :math:`j\\rightarrow i`, also\n    called logits in the context of sparsemax. :math:`\\tau` is a function\n    that can be found at the `From Softmax to Sparsemax <https://arxiv.org/pdf/1602.02068.pdf>`\n    paper.\n\n    NOTE: currently only homogeneous graphs are supported.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph to perform edge sparsemax on.\n    logits : torch.Tensor\n        The input edge feature.\n    eids : torch.Tensor or ALL, optional\n        A tensor of edge index on which to apply edge sparsemax. If ALL, apply edge\n        sparsemax on all edges in the graph. 
Default: ALL.\n    norm_by : str, could be 'src' or 'dst'\n        Normalized by source nodes of destination nodes. Default: `dst`.\n\n    Returns\n    -------\n    Tensor\n        Sparsemax value.\n    \"\"\"\n    # we get edge index tensors here since it is\n    # hard to get edge index with HeteroGraphIndex\n    # object without other information like edge_type.\n    row, col = graph.all_edges(order=\"eid\")\n    assert norm_by in [\"dst\", \"src\"]\n    end_n_ids = col if norm_by == \"dst\" else row\n    if not is_all(eids):\n        eids = astype(eids, graph.idtype)\n        end_n_ids = end_n_ids[eids]\n    return EdgeSparsemaxFunction.apply(\n        graph._graph, logits, eids, end_n_ids, norm_by\n    )\n\n\nclass EdgeSparsemax(torch.nn.Module):\n    r\"\"\"\n    Description\n    -----------\n    Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes\n\n    .. math::\n      a_{ij} = \\text{ReLU}(z_{ij} - \\tau(\\z_{i,:}))\n\n    where :math:`z_{ij}` is a signal of edge :math:`j\\rightarrow i`, also\n    called logits in the context of sparsemax. :math:`\\tau` is a function\n    that can be found at the `From Softmax to Sparsemax <https://arxiv.org/pdf/1602.02068.pdf>`\n    paper.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph to perform edge sparsemax on.\n    logits : torch.Tensor\n        The input edge feature.\n    eids : torch.Tensor or ALL, optional\n        A tensor of edge index on which to apply edge sparsemax. If ALL, apply edge\n        sparsemax on all edges in the graph. Default: ALL.\n    norm_by : str, could be 'src' or 'dst'\n        Normalized by source nodes of destination nodes. 
Default: `dst`.\n\n    NOTE: currently only homogeneous graphs are supported.\n\n    Returns\n    -------\n    Tensor\n        Sparsemax value.\n    \"\"\"\n\n    def __init__(self):\n        super(EdgeSparsemax, self).__init__()\n\n    def forward(self, graph, logits, eids=ALL, norm_by=\"dst\"):\n        return edge_sparsemax(graph, logits, eids, norm_by)\n"
  },
  {
    "path": "examples/pytorch/hgp_sl/layers.py",
    "content": "import dgl\nimport dgl.function as fn\nimport scipy.sparse\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import DGLGraph\nfrom dgl.nn import AvgPooling, GraphConv, MaxPooling\nfrom dgl.ops import edge_softmax\n\nfrom functions import edge_sparsemax\nfrom torch import Tensor\nfrom torch.nn import Parameter\nfrom utils import get_batch_id, topk\n\n\nclass WeightedGraphConv(GraphConv):\n    r\"\"\"\n    Description\n    -----------\n    GraphConv with edge weights on homogeneous graphs.\n    If edge weights are not given, directly call GraphConv instead.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph to perform this operation.\n    n_feat : torch.Tensor\n        The node features\n    e_feat : torch.Tensor, optional\n        The edge features. Default: :obj:`None`\n    \"\"\"\n\n    def forward(self, graph: DGLGraph, n_feat, e_feat=None):\n        if e_feat is None:\n            return super(WeightedGraphConv, self).forward(graph, n_feat)\n\n        with graph.local_scope():\n            if self.weight is not None:\n                n_feat = torch.matmul(n_feat, self.weight)\n            src_norm = torch.pow(graph.out_degrees().float().clamp(min=1), -0.5)\n            src_norm = src_norm.view(-1, 1)\n            dst_norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5)\n            dst_norm = dst_norm.view(-1, 1)\n            n_feat = n_feat * src_norm\n            graph.ndata[\"h\"] = n_feat\n            graph.edata[\"e\"] = e_feat\n            graph.update_all(fn.u_mul_e(\"h\", \"e\", \"m\"), fn.sum(\"m\", \"h\"))\n            n_feat = graph.ndata.pop(\"h\")\n            n_feat = n_feat * dst_norm\n            if self.bias is not None:\n                n_feat = n_feat + self.bias\n            if self._activation is not None:\n                n_feat = self._activation(n_feat)\n            return n_feat\n\n\nclass NodeInfoScoreLayer(nn.Module):\n    r\"\"\"\n    Description\n    
-----------\n    Compute a score for each node for sort-pooling. The score of each node\n    is computed via the absolute difference of its first-order random walk\n    result and its features.\n\n    Arguments\n    ---------\n    sym_norm : bool, optional\n        If true, use symmetric norm for adjacency.\n        Default: :obj:`True`\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph to perform this operation.\n    feat : torch.Tensor\n        The node features\n    e_feat : torch.Tensor, optional\n        The edge features. Default: :obj:`None`\n\n    Returns\n    -------\n    Tensor\n        Score for each node.\n    \"\"\"\n\n    def __init__(self, sym_norm: bool = True):\n        super(NodeInfoScoreLayer, self).__init__()\n        self.sym_norm = sym_norm\n\n    def forward(self, graph: dgl.DGLGraph, feat: Tensor, e_feat: Tensor):\n        with graph.local_scope():\n            if self.sym_norm:\n                src_norm = torch.pow(\n                    graph.out_degrees().float().clamp(min=1), -0.5\n                )\n                src_norm = src_norm.view(-1, 1).to(feat.device)\n                dst_norm = torch.pow(\n                    graph.in_degrees().float().clamp(min=1), -0.5\n                )\n                dst_norm = dst_norm.view(-1, 1).to(feat.device)\n\n                src_feat = feat * src_norm\n\n                graph.ndata[\"h\"] = src_feat\n                graph.edata[\"e\"] = e_feat\n                graph = dgl.remove_self_loop(graph)\n                graph.update_all(fn.u_mul_e(\"h\", \"e\", \"m\"), fn.sum(\"m\", \"h\"))\n\n                dst_feat = graph.ndata.pop(\"h\") * dst_norm\n                feat = feat - dst_feat\n            else:\n                dst_norm = 1.0 / graph.in_degrees().float().clamp(min=1)\n                dst_norm = dst_norm.view(-1, 1)\n\n                graph.ndata[\"h\"] = feat\n                graph.edata[\"e\"] = e_feat\n                graph = dgl.remove_self_loop(graph)\n       
         graph.update_all(fn.u_mul_e(\"h\", \"e\", \"m\"), fn.sum(\"m\", \"h\"))\n\n                feat = feat - dst_norm * graph.ndata.pop(\"h\")\n\n            score = torch.sum(torch.abs(feat), dim=1)\n            return score\n\n\nclass HGPSLPool(nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    The HGP-SL pooling layer from\n    `Hierarchical Graph Pooling with Structure Learning <https://arxiv.org/pdf/1911.05954.pdf>`\n\n    Parameters\n    ----------\n    in_feat : int\n        The number of input node feature's channels\n    ratio : float, optional\n        Pooling ratio. Default: 0.8\n    sample : bool, optional\n        Whether use k-hop union graph to increase efficiency.\n        Currently we only support full graph. Default: :obj:`False`\n    sym_score_norm : bool, optional\n        Use symmetric norm for adjacency or not. Default: :obj:`True`\n    sparse : bool, optional\n        Use edge sparsemax instead of edge softmax. Default: :obj:`True`\n    sl : bool, optional\n        Use structure learining module or not. Default: :obj:`True`\n    lamb : float, optional\n        The lambda parameter as weight of raw adjacency as described in the\n        HGP-SL paper. Default: 1.0\n    negative_slop : float, optional\n        Negative slop for leaky_relu. 
Default: 0.2\n\n    Returns\n    -------\n    DGLGraph\n        The pooled graph.\n    torch.Tensor\n        Node features\n    torch.Tensor\n        Edge features\n    torch.Tensor\n        Permutation index\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feat: int,\n        ratio=0.8,\n        sample=True,\n        sym_score_norm=True,\n        sparse=True,\n        sl=True,\n        lamb=1.0,\n        negative_slop=0.2,\n        k_hop=3,\n    ):\n        super(HGPSLPool, self).__init__()\n        self.in_feat = in_feat\n        self.ratio = ratio\n        self.sample = sample\n        self.sparse = sparse\n        self.sl = sl\n        self.lamb = lamb\n        self.negative_slop = negative_slop\n        self.k_hop = k_hop\n\n        self.att = Parameter(torch.Tensor(1, self.in_feat * 2))\n        self.calc_info_score = NodeInfoScoreLayer(sym_norm=sym_score_norm)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.xavier_normal_(self.att.data)\n\n    def forward(self, graph: DGLGraph, feat: Tensor, e_feat=None):\n        # top-k pool first\n        if e_feat is None:\n            e_feat = torch.ones(\n                (graph.num_edges(),), dtype=feat.dtype, device=feat.device\n            )\n        batch_num_nodes = graph.batch_num_nodes()\n        x_score = self.calc_info_score(graph, feat, e_feat)\n        perm, next_batch_num_nodes = topk(\n            x_score, self.ratio, get_batch_id(batch_num_nodes), batch_num_nodes\n        )\n        feat = feat[perm]\n        pool_graph = None\n        if not self.sample or not self.sl:\n            # pool graph\n            graph.edata[\"e\"] = e_feat\n            pool_graph = dgl.node_subgraph(graph, perm)\n            e_feat = pool_graph.edata.pop(\"e\")\n            pool_graph.set_batch_num_nodes(next_batch_num_nodes)\n\n        # no structure learning layer, directly return.\n        if not self.sl:\n            return pool_graph, feat, e_feat, perm\n\n        # 
Structure Learning\n        if self.sample:\n            # A fast mode for large graphs.\n            # In large graphs, learning the possible edge weights between each\n            # pair of nodes is time consuming. To accelerate this process,\n            # we sample it's K-Hop neighbors for each node and then learn the\n            # edge weights between them.\n\n            # first build multi-hop graph\n            row, col = graph.all_edges()\n            num_nodes = graph.num_nodes()\n\n            scipy_adj = scipy.sparse.coo_matrix(\n                (\n                    e_feat.detach().cpu(),\n                    (row.detach().cpu(), col.detach().cpu()),\n                ),\n                shape=(num_nodes, num_nodes),\n            )\n            for _ in range(self.k_hop):\n                two_hop = scipy_adj**2\n                two_hop = two_hop * (1e-5 / two_hop.max())\n                scipy_adj = two_hop + scipy_adj\n            row, col = scipy_adj.nonzero()\n            row = torch.tensor(row, dtype=torch.long, device=graph.device)\n            col = torch.tensor(col, dtype=torch.long, device=graph.device)\n            e_feat = torch.tensor(\n                scipy_adj.data, dtype=torch.float, device=feat.device\n            )\n\n            # perform pooling on multi-hop graph\n            mask = perm.new_full((num_nodes,), -1)\n            i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)\n            mask[perm] = i\n            row, col = mask[row], mask[col]\n            mask = (row >= 0) & (col >= 0)\n            row, col = row[mask], col[mask]\n            e_feat = e_feat[mask]\n\n            # add remaining self loops\n            mask = row != col\n            num_nodes = perm.size(0)  # num nodes after pool\n            loop_index = torch.arange(\n                0, num_nodes, dtype=row.dtype, device=row.device\n            )\n            inv_mask = ~mask\n            loop_weight = torch.full(\n                
(num_nodes,), 0, dtype=e_feat.dtype, device=e_feat.device\n            )\n            remaining_e_feat = e_feat[inv_mask]\n            if remaining_e_feat.numel() > 0:\n                loop_weight[row[inv_mask]] = remaining_e_feat\n            e_feat = torch.cat([e_feat[mask], loop_weight], dim=0)\n            row, col = row[mask], col[mask]\n            row = torch.cat([row, loop_index], dim=0)\n            col = torch.cat([col, loop_index], dim=0)\n\n            # attention scores\n            weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(\n                dim=-1\n            )\n            weights = (\n                F.leaky_relu(weights, self.negative_slop) + e_feat * self.lamb\n            )\n\n            # sl and normalization\n            sl_graph = dgl.graph((row, col))\n            if self.sparse:\n                weights = edge_sparsemax(sl_graph, weights)\n            else:\n                weights = edge_softmax(sl_graph, weights)\n\n            # get final graph\n            mask = torch.abs(weights) > 0\n            row, col, weights = row[mask], col[mask], weights[mask]\n            pool_graph = dgl.graph((row, col))\n            pool_graph.set_batch_num_nodes(next_batch_num_nodes)\n            e_feat = weights\n\n        else:\n            # Learning the possible edge weights between each pair of\n            # nodes in the pooled subgraph, relative slower.\n\n            # construct complete graphs for all graph in the batch\n            # use dense to build, then transform to sparse.\n            # maybe there's more efficient way?\n            batch_num_nodes = next_batch_num_nodes\n            block_begin_idx = torch.cat(\n                [\n                    batch_num_nodes.new_zeros(1),\n                    batch_num_nodes.cumsum(dim=0)[:-1],\n                ],\n                dim=0,\n            )\n            block_end_idx = batch_num_nodes.cumsum(dim=0)\n            dense_adj = torch.zeros(\n                
(pool_graph.num_nodes(), pool_graph.num_nodes()),\n                dtype=torch.float,\n                device=feat.device,\n            )\n            for idx_b, idx_e in zip(block_begin_idx, block_end_idx):\n                dense_adj[idx_b:idx_e, idx_b:idx_e] = 1.0\n            row, col = torch.nonzero(dense_adj).t().contiguous()\n\n            # compute weights for node-pairs\n            weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(\n                dim=-1\n            )\n            weights = F.leaky_relu(weights, self.negative_slop)\n            dense_adj[row, col] = weights\n\n            # add pooled graph structure to weight matrix\n            pool_row, pool_col = pool_graph.all_edges()\n            dense_adj[pool_row, pool_col] += self.lamb * e_feat\n            weights = dense_adj[row, col]\n            del dense_adj\n            torch.cuda.empty_cache()\n\n            # edge softmax/sparsemax\n            complete_graph = dgl.graph((row, col))\n            if self.sparse:\n                weights = edge_sparsemax(complete_graph, weights)\n            else:\n                weights = edge_softmax(complete_graph, weights)\n\n            # get new e_feat and graph structure, clean up.\n            mask = torch.abs(weights) > 1e-9\n            row, col, weights = row[mask], col[mask], weights[mask]\n            e_feat = weights\n            pool_graph = dgl.graph((row, col))\n            pool_graph.set_batch_num_nodes(next_batch_num_nodes)\n\n        return pool_graph, feat, e_feat, perm\n\n\nclass ConvPoolReadout(torch.nn.Module):\n    \"\"\"A helper class. 
(GraphConv -> Pooling -> Readout)\"\"\"\n\n    def __init__(\n        self,\n        in_feat: int,\n        out_feat: int,\n        pool_ratio=0.8,\n        sample: bool = False,\n        sparse: bool = True,\n        sl: bool = True,\n        lamb: float = 1.0,\n        pool: bool = True,\n    ):\n        super(ConvPoolReadout, self).__init__()\n        self.use_pool = pool\n        self.conv = WeightedGraphConv(in_feat, out_feat)\n        if pool:\n            self.pool = HGPSLPool(\n                out_feat,\n                ratio=pool_ratio,\n                sparse=sparse,\n                sample=sample,\n                sl=sl,\n                lamb=lamb,\n            )\n        else:\n            self.pool = None\n        self.avgpool = AvgPooling()\n        self.maxpool = MaxPooling()\n\n    def forward(self, graph, feature, e_feat=None):\n        out = F.relu(self.conv(graph, feature, e_feat))\n        if self.use_pool:\n            graph, out, e_feat, _ = self.pool(graph, out, e_feat)\n        readout = torch.cat(\n            [self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1\n        )\n        return graph, out, e_feat, readout\n"
  },
  {
    "path": "examples/pytorch/hgp_sl/main.py",
    "content": "import argparse\nimport json\nimport logging\nimport os\nfrom time import time\n\nimport dgl\n\nimport torch\nimport torch.nn\nimport torch.nn.functional as F\nfrom dgl.data import LegacyTUDataset\nfrom dgl.dataloading import GraphDataLoader\nfrom networks import HGPSLModel\nfrom torch.utils.data import random_split\nfrom utils import get_stats\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"HGP-SL-DGL\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"DD\",\n        choices=[\"DD\", \"PROTEINS\", \"NCI1\", \"NCI109\", \"Mutagenicity\", \"ENZYMES\"],\n        help=\"DD/PROTEINS/NCI1/NCI109/Mutagenicity/ENZYMES\",\n    )\n    parser.add_argument(\n        \"--batch_size\", type=int, default=512, help=\"batch size\"\n    )\n    parser.add_argument(\n        \"--sample\", type=str, default=\"true\", help=\"use sample method\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=1e-3, help=\"learning rate\")\n    parser.add_argument(\n        \"--weight_decay\", type=float, default=1e-3, help=\"weight decay\"\n    )\n    parser.add_argument(\n        \"--pool_ratio\", type=float, default=0.5, help=\"pooling ratio\"\n    )\n    parser.add_argument(\"--hid_dim\", type=int, default=128, help=\"hidden size\")\n    parser.add_argument(\n        \"--conv_layers\", type=int, default=3, help=\"number of conv layers\"\n    )\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.0, help=\"dropout ratio\"\n    )\n    parser.add_argument(\n        \"--lamb\", type=float, default=1.0, help=\"trade-off parameter\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=1000, help=\"max number of training epochs\"\n    )\n    parser.add_argument(\n        \"--patience\", type=int, default=100, help=\"patience for early stopping\"\n    )\n    parser.add_argument(\n        \"--device\", type=int, default=-1, help=\"device id, -1 for cpu\"\n    )\n    
parser.add_argument(\n        \"--dataset_path\", type=str, default=\"./dataset\", help=\"path to dataset\"\n    )\n    parser.add_argument(\n        \"--print_every\",\n        type=int,\n        default=10,\n        help=\"print trainlog every k epochs, -1 for silent training\",\n    )\n    parser.add_argument(\n        \"--num_trials\", type=int, default=1, help=\"number of trials\"\n    )\n    parser.add_argument(\"--output_path\", type=str, default=\"./output\")\n\n    args = parser.parse_args()\n\n    # device\n    args.device = \"cpu\" if args.device == -1 else \"cuda:{}\".format(args.device)\n    if not torch.cuda.is_available():\n        logging.warning(\"CUDA is not available, use CPU for training.\")\n        args.device = \"cpu\"\n\n    # print every\n    if args.print_every == -1:\n        args.print_every = args.epochs + 1\n\n    # bool args\n    if args.sample.lower() == \"true\":\n        args.sample = True\n    else:\n        args.sample = False\n\n    # paths\n    if not os.path.exists(args.dataset_path):\n        os.makedirs(args.dataset_path)\n    if not os.path.exists(args.output_path):\n        os.makedirs(args.output_path)\n    name = (\n        \"Data={}_Hidden={}_Pool={}_WeightDecay={}_Lr={}_Sample={}.log\".format(\n            args.dataset,\n            args.hid_dim,\n            args.pool_ratio,\n            args.weight_decay,\n            args.lr,\n            args.sample,\n        )\n    )\n    args.output_path = os.path.join(args.output_path, name)\n\n    return args\n\n\ndef train(model: torch.nn.Module, optimizer, trainloader, device):\n    model.train()\n    total_loss = 0.0\n    num_batches = len(trainloader)\n    for batch in trainloader:\n        optimizer.zero_grad()\n        batch_graphs, batch_labels = batch\n        batch_graphs = batch_graphs.to(device)\n        batch_labels = batch_labels.long().to(device)\n        out = model(batch_graphs, batch_graphs.ndata[\"feat\"])\n        loss = F.nll_loss(out, batch_labels)\n        
loss.backward()\n        optimizer.step()\n\n        total_loss += loss.item()\n\n    return total_loss / num_batches\n\n\n@torch.no_grad()\ndef test(model: torch.nn.Module, loader, device):\n    model.eval()\n    correct = 0.0\n    loss = 0.0\n    num_graphs = 0\n    for batch in loader:\n        batch_graphs, batch_labels = batch\n        num_graphs += batch_labels.size(0)\n        batch_graphs = batch_graphs.to(device)\n        batch_labels = batch_labels.long().to(device)\n        out = model(batch_graphs, batch_graphs.ndata[\"feat\"])\n        pred = out.argmax(dim=1)\n        loss += F.nll_loss(out, batch_labels, reduction=\"sum\").item()\n        correct += pred.eq(batch_labels).sum().item()\n    return correct / num_graphs, loss / num_graphs\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    dataset = LegacyTUDataset(args.dataset, raw_dir=args.dataset_path)\n\n    # add self loop. We add self loop for each graph here since the function \"add_self_loop\" does not\n    # support batch graph.\n    for i in range(len(dataset)):\n        dataset.graph_lists[i] = dgl.add_self_loop(dataset.graph_lists[i])\n\n    num_training = int(len(dataset) * 0.8)\n    num_val = int(len(dataset) * 0.1)\n    num_test = len(dataset) - num_val - num_training\n    train_set, val_set, test_set = random_split(\n        dataset, [num_training, num_val, num_test]\n    )\n\n    train_loader = GraphDataLoader(\n        train_set, batch_size=args.batch_size, shuffle=True, num_workers=6\n    )\n    val_loader = GraphDataLoader(\n        val_set, batch_size=args.batch_size, num_workers=2\n    )\n    test_loader = GraphDataLoader(\n        test_set, batch_size=args.batch_size, num_workers=2\n    )\n\n    device = torch.device(args.device)\n\n    # Step 2: Create model =================================================================== #\n    num_feature, num_classes, _ = dataset.statistics()\n\n    model = 
HGPSLModel(\n        in_feat=num_feature,\n        out_feat=num_classes,\n        hid_feat=args.hid_dim,\n        conv_layers=args.conv_layers,\n        dropout=args.dropout,\n        pool_ratio=args.pool_ratio,\n        lamb=args.lamb,\n        sample=args.sample,\n    ).to(device)\n    args.num_feature = int(num_feature)\n    args.num_classes = int(num_classes)\n\n    # Step 3: Create training components ===================================================== #\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # Step 4: training epoches =============================================================== #\n    bad_cound = 0\n    best_val_loss = float(\"inf\")\n    final_test_acc = 0.0\n    best_epoch = 0\n    train_times = []\n    for e in range(args.epochs):\n        s_time = time()\n        train_loss = train(model, optimizer, train_loader, device)\n        train_times.append(time() - s_time)\n        val_acc, val_loss = test(model, val_loader, device)\n        test_acc, _ = test(model, test_loader, device)\n        if best_val_loss > val_loss:\n            best_val_loss = val_loss\n            final_test_acc = test_acc\n            bad_cound = 0\n            best_epoch = e + 1\n        else:\n            bad_cound += 1\n        if bad_cound >= args.patience:\n            break\n\n        if (e + 1) % args.print_every == 0:\n            log_format = (\n                \"Epoch {}: loss={:.4f}, val_acc={:.4f}, final_test_acc={:.4f}\"\n            )\n            print(log_format.format(e + 1, train_loss, val_acc, final_test_acc))\n    print(\n        \"Best Epoch {}, final test acc {:.4f}\".format(\n            best_epoch, final_test_acc\n        )\n    )\n    return final_test_acc, sum(train_times) / len(train_times)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    res = []\n    train_times = []\n    for i in range(args.num_trials):\n        print(\"Trial {}/{}\".format(i + 1, 
args.num_trials))\n        acc, train_time = main(args)\n        res.append(acc)\n        train_times.append(train_time)\n\n    mean, err_bd = get_stats(res, conf_interval=False)\n    print(\"mean acc: {:.4f}, error bound: {:.4f}\".format(mean, err_bd))\n\n    out_dict = {\n        \"hyper-parameters\": vars(args),\n        \"result\": \"{:.4f}(+-{:.4f})\".format(mean, err_bd),\n        \"train_time\": \"{:.4f}\".format(sum(train_times) / len(train_times)),\n    }\n\n    with open(args.output_path, \"w\") as f:\n        json.dump(out_dict, f, sort_keys=True, indent=4)\n"
  },
  {
    "path": "examples/pytorch/hgp_sl/networks.py",
    "content": "import torch\nimport torch.nn\nimport torch.nn.functional as F\n\nfrom dgl.nn import AvgPooling, MaxPooling\nfrom layers import ConvPoolReadout\n\n\nclass HGPSLModel(torch.nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    The graph classification model using HGP-SL pooling.\n\n    Parameters\n    ----------\n    in_feat : int\n        The number of input node feature's channels.\n    out_feat : int\n        The number of output node feature's channels.\n    hid_feat : int\n        The number of hidden state's channels.\n    dropout : float, optional\n        The dropout rate. Default: 0\n    pool_ratio : float, optional\n        The pooling ratio for each pooling layer. Default: 0.5\n    conv_layers : int, optional\n        The number of graph convolution and pooling layers. Default: 3\n    sample : bool, optional\n        Whether use k-hop union graph to increase efficiency.\n        Currently we only support full graph. Default: :obj:`False`\n    sparse : bool, optional\n        Use edge sparsemax instead of edge softmax. Default: :obj:`True`\n    sl : bool, optional\n        Use structure learining module or not. Default: :obj:`True`\n    lamb : float, optional\n        The lambda parameter as weight of raw adjacency as described in the\n        HGP-SL paper. 
Default: 1.0\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feat: int,\n        out_feat: int,\n        hid_feat: int,\n        dropout: float = 0.0,\n        pool_ratio: float = 0.5,\n        conv_layers: int = 3,\n        sample: bool = False,\n        sparse: bool = True,\n        sl: bool = True,\n        lamb: float = 1.0,\n    ):\n        super(HGPSLModel, self).__init__()\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.hid_feat = hid_feat\n        self.dropout = dropout\n        self.num_layers = conv_layers\n        self.pool_ratio = pool_ratio\n\n        convpools = []\n        for i in range(conv_layers):\n            c_in = in_feat if i == 0 else hid_feat\n            c_out = hid_feat\n            use_pool = i != conv_layers - 1\n            convpools.append(\n                ConvPoolReadout(\n                    c_in,\n                    c_out,\n                    pool_ratio=pool_ratio,\n                    sample=sample,\n                    sparse=sparse,\n                    sl=sl,\n                    lamb=lamb,\n                    pool=use_pool,\n                )\n            )\n        self.convpool_layers = torch.nn.ModuleList(convpools)\n\n        self.lin1 = torch.nn.Linear(hid_feat * 2, hid_feat)\n        self.lin2 = torch.nn.Linear(hid_feat, hid_feat // 2)\n        self.lin3 = torch.nn.Linear(hid_feat // 2, self.out_feat)\n\n    def forward(self, graph, n_feat):\n        final_readout = None\n        e_feat = None\n\n        for i in range(self.num_layers):\n            graph, n_feat, e_feat, readout = self.convpool_layers[i](\n                graph, n_feat, e_feat\n            )\n            if final_readout is None:\n                final_readout = readout\n            else:\n                final_readout = final_readout + readout\n\n        n_feat = F.relu(self.lin1(final_readout))\n        n_feat = F.dropout(n_feat, p=self.dropout, training=self.training)\n        n_feat = 
F.relu(self.lin2(n_feat))\n        n_feat = F.dropout(n_feat, p=self.dropout, training=self.training)\n        n_feat = self.lin3(n_feat)\n\n        return F.log_softmax(n_feat, dim=-1)\n"
  },
  {
    "path": "examples/pytorch/hgp_sl/utils.py",
    "content": "import logging\nimport math\n\nimport torch\nfrom scipy.stats import t\n\n\ndef get_stats(\n    array, conf_interval=False, name=None, stdout=False, logout=False\n):\n    \"\"\"Compute mean and standard deviation from an numerical array\n\n    Args:\n        array (array like obj): The numerical array, this array can be\n            convert to :obj:`torch.Tensor`.\n        conf_interval (bool, optional): If True, compute the confidence interval bound (95%)\n            instead of the std value. (default: :obj:`False`)\n        name (str, optional): The name of this numerical array, for log usage.\n            (default: :obj:`None`)\n        stdout (bool, optional): Whether to output result to the terminal.\n            (default: :obj:`False`)\n        logout (bool, optional): Whether to output result via logging module.\n            (default: :obj:`False`)\n    \"\"\"\n    eps = 1e-9\n    array = torch.Tensor(array)\n    std, mean = torch.std_mean(array)\n    std = std.item()\n    mean = mean.item()\n    center = mean\n\n    if conf_interval:\n        n = array.size(0)\n        se = std / (math.sqrt(n) + eps)\n        t_value = t.ppf(0.975, df=n - 1)\n        err_bound = t_value * se\n    else:\n        err_bound = std\n\n    # log and print\n    if name is None:\n        name = \"array {}\".format(id(array))\n    log = \"{}: {:.4f}(+-{:.4f})\".format(name, center, err_bound)\n    if stdout:\n        print(log)\n    if logout:\n        logging.info(log)\n\n    return center, err_bound\n\n\ndef get_batch_id(num_nodes: torch.Tensor):\n    \"\"\"Convert the num_nodes array obtained from batch graph to batch_id array\n    for each node.\n\n    Args:\n        num_nodes (torch.Tensor): The tensor whose element is the number of nodes\n            in each graph in the batch graph.\n    \"\"\"\n    batch_size = num_nodes.size(0)\n    batch_ids = []\n    for i in range(batch_size):\n        item = torch.full(\n            (num_nodes[i],), i, dtype=torch.long, 
device=num_nodes.device\n        )\n        batch_ids.append(item)\n    return torch.cat(batch_ids)\n\n\ndef topk(\n    x: torch.Tensor,\n    ratio: float,\n    batch_id: torch.Tensor,\n    num_nodes: torch.Tensor,\n):\n    \"\"\"The top-k pooling method. Given a graph batch, this method will pool out some\n    nodes from input node feature tensor for each graph according to the given ratio.\n\n    Args:\n        x (torch.Tensor): The input node feature batch-tensor to be pooled.\n        ratio (float): the pool ratio. For example if :obj:`ratio=0.5` then half of the input\n            tensor will be pooled out.\n        batch_id (torch.Tensor): The batch_id of each element in the input tensor.\n        num_nodes (torch.Tensor): The number of nodes of each graph in batch.\n\n    Returns:\n        perm (torch.Tensor): The index in batch to be kept.\n        k (torch.Tensor): The remaining number of nodes for each graph.\n    \"\"\"\n    batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()\n\n    cum_num_nodes = torch.cat(\n        [num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0\n    )\n\n    index = torch.arange(batch_id.size(0), dtype=torch.long, device=x.device)\n    index = (index - cum_num_nodes[batch_id]) + (batch_id * max_num_nodes)\n\n    dense_x = x.new_full(\n        (batch_size * max_num_nodes,), torch.finfo(x.dtype).min\n    )\n    dense_x[index] = x\n    dense_x = dense_x.view(batch_size, max_num_nodes)\n\n    _, perm = dense_x.sort(dim=-1, descending=True)\n    perm = perm + cum_num_nodes.view(-1, 1)\n    perm = perm.view(-1)\n\n    k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)\n    mask = [\n        torch.arange(k[i], dtype=torch.long, device=x.device)\n        + i * max_num_nodes\n        for i in range(batch_size)\n    ]\n\n    mask = torch.cat(mask, dim=0)\n    perm = perm[mask]\n\n    return perm, k\n"
  },
  {
    "path": "examples/pytorch/hgt/README.md",
    "content": "# Heterogeneous Graph Transformer (HGT)\n\n[Alternative PyTorch-Geometric implementation](https://github.com/acbull/pyHGT)\n\n[“**Heterogeneous Graph Transformer**”](https://arxiv.org/abs/2003.01332) is a graph neural network architecture that can deal with large-scale heterogeneous and dynamic graphs.\n\n\nThis toy experiment is based on DGL's official [tutorial](https://docs.dgl.ai/en/0.4.x/generated/dgl.heterograph.html). As the ACM dataset doesn't have input features, we simply assign random features to each node. This step can be replaced with any prepared features.\n\n\nThe reference performance against R-GCN and MLP over 5 runs:\n\n\n| Model        | Test Accuracy    | # Parameter  |\n| ---------    | ---------------  | -------------|\n| 2-layer HGT  | 0.465 ± 0.007   |  2,176,324   |\n| 2-layer RGCN | 0.392 ± 0.013    |  416,340   |\n| MLP          | 0.132 ± 0.003    |  200,974     | \n"
  },
  {
    "path": "examples/pytorch/hgt/model.py",
    "content": "import math\n\nimport dgl\nimport dgl.function as fn\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.functional import edge_softmax\n\n\nclass HGTLayer(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        node_dict,\n        edge_dict,\n        n_heads,\n        dropout=0.2,\n        use_norm=False,\n    ):\n        super(HGTLayer, self).__init__()\n\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.node_dict = node_dict\n        self.edge_dict = edge_dict\n        self.num_types = len(node_dict)\n        self.num_relations = len(edge_dict)\n        self.total_rel = self.num_types * self.num_relations * self.num_types\n        self.n_heads = n_heads\n        self.d_k = out_dim // n_heads\n        self.sqrt_dk = math.sqrt(self.d_k)\n        self.att = None\n\n        self.k_linears = nn.ModuleList()\n        self.q_linears = nn.ModuleList()\n        self.v_linears = nn.ModuleList()\n        self.a_linears = nn.ModuleList()\n        self.norms = nn.ModuleList()\n        self.use_norm = use_norm\n\n        for t in range(self.num_types):\n            self.k_linears.append(nn.Linear(in_dim, out_dim))\n            self.q_linears.append(nn.Linear(in_dim, out_dim))\n            self.v_linears.append(nn.Linear(in_dim, out_dim))\n            self.a_linears.append(nn.Linear(out_dim, out_dim))\n            if use_norm:\n                self.norms.append(nn.LayerNorm(out_dim))\n\n        self.relation_pri = nn.Parameter(\n            torch.ones(self.num_relations, self.n_heads)\n        )\n        self.relation_att = nn.Parameter(\n            torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k)\n        )\n        self.relation_msg = nn.Parameter(\n            torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k)\n        )\n        self.skip = nn.Parameter(torch.ones(self.num_types))\n        self.drop = nn.Dropout(dropout)\n\n        
nn.init.xavier_uniform_(self.relation_att)\n        nn.init.xavier_uniform_(self.relation_msg)\n\n    def forward(self, G, h):\n        with G.local_scope():\n            node_dict, edge_dict = self.node_dict, self.edge_dict\n            for srctype, etype, dsttype in G.canonical_etypes:\n                sub_graph = G[srctype, etype, dsttype]\n\n                k_linear = self.k_linears[node_dict[srctype]]\n                v_linear = self.v_linears[node_dict[srctype]]\n                q_linear = self.q_linears[node_dict[dsttype]]\n\n                k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)\n                v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)\n                q = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)\n\n                e_id = self.edge_dict[etype]\n\n                relation_att = self.relation_att[e_id]\n                relation_pri = self.relation_pri[e_id]\n                relation_msg = self.relation_msg[e_id]\n\n                k = torch.einsum(\"bij,ijk->bik\", k, relation_att)\n                v = torch.einsum(\"bij,ijk->bik\", v, relation_msg)\n\n                sub_graph.srcdata[\"k\"] = k\n                sub_graph.dstdata[\"q\"] = q\n                sub_graph.srcdata[\"v_%d\" % e_id] = v\n\n                sub_graph.apply_edges(fn.v_dot_u(\"q\", \"k\", \"t\"))\n                attn_score = (\n                    sub_graph.edata.pop(\"t\").sum(-1)\n                    * relation_pri\n                    / self.sqrt_dk\n                )\n                attn_score = edge_softmax(sub_graph, attn_score, norm_by=\"dst\")\n\n                sub_graph.edata[\"t\"] = attn_score.unsqueeze(-1)\n\n            G.multi_update_all(\n                {\n                    etype: (\n                        fn.u_mul_e(\"v_%d\" % e_id, \"t\", \"m\"),\n                        fn.sum(\"m\", \"t\"),\n                    )\n                    for etype, e_id in edge_dict.items()\n                },\n                
cross_reducer=\"mean\",\n            )\n\n            new_h = {}\n            for ntype in G.ntypes:\n                \"\"\"\n                Step 3: Target-specific Aggregation\n                x = norm( W[node_type] * gelu( Agg(x) ) + x )\n                \"\"\"\n                n_id = node_dict[ntype]\n                alpha = torch.sigmoid(self.skip[n_id])\n                t = G.nodes[ntype].data[\"t\"].view(-1, self.out_dim)\n                trans_out = self.drop(self.a_linears[n_id](t))\n                trans_out = trans_out * alpha + h[ntype] * (1 - alpha)\n                if self.use_norm:\n                    new_h[ntype] = self.norms[n_id](trans_out)\n                else:\n                    new_h[ntype] = trans_out\n            return new_h\n\n\nclass HGT(nn.Module):\n    def __init__(\n        self,\n        G,\n        node_dict,\n        edge_dict,\n        n_inp,\n        n_hid,\n        n_out,\n        n_layers,\n        n_heads,\n        use_norm=True,\n    ):\n        super(HGT, self).__init__()\n        self.node_dict = node_dict\n        self.edge_dict = edge_dict\n        self.gcs = nn.ModuleList()\n        self.n_inp = n_inp\n        self.n_hid = n_hid\n        self.n_out = n_out\n        self.n_layers = n_layers\n        self.adapt_ws = nn.ModuleList()\n        for t in range(len(node_dict)):\n            self.adapt_ws.append(nn.Linear(n_inp, n_hid))\n        for _ in range(n_layers):\n            self.gcs.append(\n                HGTLayer(\n                    n_hid,\n                    n_hid,\n                    node_dict,\n                    edge_dict,\n                    n_heads,\n                    use_norm=use_norm,\n                )\n            )\n        self.out = nn.Linear(n_hid, n_out)\n\n    def forward(self, G, out_key):\n        h = {}\n        for ntype in G.ntypes:\n            n_id = self.node_dict[ntype]\n            h[ntype] = F.gelu(self.adapt_ws[n_id](G.nodes[ntype].data[\"inp\"]))\n        for i in 
range(self.n_layers):\n            h = self.gcs[i](G, h)\n        return self.out(h[out_key])\n\n\nclass HeteroRGCNLayer(nn.Module):\n    def __init__(self, in_size, out_size, etypes):\n        super(HeteroRGCNLayer, self).__init__()\n        # W_r for each relation\n        self.weight = nn.ModuleDict(\n            {name: nn.Linear(in_size, out_size) for name in etypes}\n        )\n\n    def forward(self, G, feat_dict):\n        # The input is a dictionary of node features for each type\n        funcs = {}\n        for srctype, etype, dsttype in G.canonical_etypes:\n            # Compute W_r * h\n            Wh = self.weight[etype](feat_dict[srctype])\n            # Save it in graph for message passing\n            G.nodes[srctype].data[\"Wh_%s\" % etype] = Wh\n            # Specify per-relation message passing functions: (message_func, reduce_func).\n            # Note that the results are saved to the same destination feature 'h', which\n            # hints the type wise reducer for aggregation.\n            funcs[etype] = (fn.copy_u(\"Wh_%s\" % etype, \"m\"), fn.mean(\"m\", \"h\"))\n        # Trigger message passing of multiple types.\n        # The first argument is the message passing functions for each relation.\n        # The second one is the type wise reducer, could be \"sum\", \"max\",\n        # \"min\", \"mean\", \"stack\"\n        G.multi_update_all(funcs, \"sum\")\n        # return the updated node feature dictionary\n        return {ntype: G.nodes[ntype].data[\"h\"] for ntype in G.ntypes}\n\n\nclass HeteroRGCN(nn.Module):\n    def __init__(self, G, in_size, hidden_size, out_size):\n        super(HeteroRGCN, self).__init__()\n        # create layers\n        self.layer1 = HeteroRGCNLayer(in_size, hidden_size, G.etypes)\n        self.layer2 = HeteroRGCNLayer(hidden_size, out_size, G.etypes)\n\n    def forward(self, G, out_key):\n        input_dict = {ntype: G.nodes[ntype].data[\"inp\"] for ntype in G.ntypes}\n        h_dict = self.layer1(G, 
input_dict)\n        h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}\n        h_dict = self.layer2(G, h_dict)\n        # get appropriate logits\n        return h_dict[out_key]\n"
  },
  {
    "path": "examples/pytorch/hgt/train_acm.py",
    "content": "#!/usr/bin/env python\n# coding: utf-8\n\n# In[1]:\n\n\nimport argparse\nimport math\nimport urllib.request\n\nimport numpy as np\nimport scipy.io\nfrom model import *\n\nimport dgl\n\ntorch.manual_seed(0)\ndata_url = \"https://data.dgl.ai/dataset/ACM.mat\"\ndata_file_path = \"/tmp/ACM.mat\"\n\nurllib.request.urlretrieve(data_url, data_file_path)\ndata = scipy.io.loadmat(data_file_path)\n\n\nparser = argparse.ArgumentParser(\n    description=\"Training GNN on ogbn-products benchmark\"\n)\n\n\nparser.add_argument(\"--n_epoch\", type=int, default=200)\nparser.add_argument(\"--n_hid\", type=int, default=256)\nparser.add_argument(\"--n_inp\", type=int, default=256)\nparser.add_argument(\"--clip\", type=int, default=1.0)\nparser.add_argument(\"--max_lr\", type=float, default=1e-3)\n\nargs = parser.parse_args()\n\n\ndef get_n_params(model):\n    pp = 0\n    for p in list(model.parameters()):\n        nn = 1\n        for s in list(p.size()):\n            nn = nn * s\n        pp += nn\n    return pp\n\n\ndef train(model, G):\n    best_val_acc = torch.tensor(0)\n    best_test_acc = torch.tensor(0)\n    for epoch in np.arange(args.n_epoch) + 1:\n        model.train()\n        logits = model(G, \"paper\")\n        # The loss is computed only for labeled nodes.\n        loss = F.cross_entropy(logits[train_idx], labels[train_idx].to(device))\n        optimizer.zero_grad()\n        loss.backward()\n        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)\n        optimizer.step()\n        scheduler.step()\n        if epoch % 5 == 0:\n            model.eval()\n            logits = model(G, \"paper\")\n            pred = logits.argmax(1).cpu()\n            train_acc = (pred[train_idx] == labels[train_idx]).float().mean()\n            val_acc = (pred[val_idx] == labels[val_idx]).float().mean()\n            test_acc = (pred[test_idx] == labels[test_idx]).float().mean()\n            if best_val_acc < val_acc:\n                best_val_acc = val_acc\n      
          best_test_acc = test_acc\n            print(\n                \"Epoch: %d LR: %.5f Loss %.4f, Train Acc %.4f, Val Acc %.4f (Best %.4f), Test Acc %.4f (Best %.4f)\"\n                % (\n                    epoch,\n                    optimizer.param_groups[0][\"lr\"],\n                    loss.item(),\n                    train_acc.item(),\n                    val_acc.item(),\n                    best_val_acc.item(),\n                    test_acc.item(),\n                    best_test_acc.item(),\n                )\n            )\n\n\ndevice = torch.device(\"cuda:0\")\n\nG = dgl.heterograph(\n    {\n        (\"paper\", \"written-by\", \"author\"): data[\"PvsA\"].nonzero(),\n        (\"author\", \"writing\", \"paper\"): data[\"PvsA\"].transpose().nonzero(),\n        (\"paper\", \"citing\", \"paper\"): data[\"PvsP\"].nonzero(),\n        (\"paper\", \"cited\", \"paper\"): data[\"PvsP\"].transpose().nonzero(),\n        (\"paper\", \"is-about\", \"subject\"): data[\"PvsL\"].nonzero(),\n        (\"subject\", \"has\", \"paper\"): data[\"PvsL\"].transpose().nonzero(),\n    }\n)\nprint(G)\n\npvc = data[\"PvsC\"].tocsr()\np_selected = pvc.tocoo()\n# generate labels\nlabels = pvc.indices\nlabels = torch.tensor(labels).long()\n\n# generate train/val/test split\npid = p_selected.row\nshuffle = np.random.permutation(pid)\ntrain_idx = torch.tensor(shuffle[0:800]).long()\nval_idx = torch.tensor(shuffle[800:900]).long()\ntest_idx = torch.tensor(shuffle[900:]).long()\n\nnode_dict = {}\nedge_dict = {}\nfor ntype in G.ntypes:\n    node_dict[ntype] = len(node_dict)\nfor etype in G.etypes:\n    edge_dict[etype] = len(edge_dict)\n    G.edges[etype].data[\"id\"] = (\n        torch.ones(G.num_edges(etype), dtype=torch.long) * edge_dict[etype]\n    )\n\n#     Random initialize input feature\nfor ntype in G.ntypes:\n    emb = nn.Parameter(\n        torch.Tensor(G.num_nodes(ntype), 256), requires_grad=False\n    )\n    nn.init.xavier_uniform_(emb)\n    G.nodes[ntype].data[\"inp\"] = 
emb\n\nG = G.to(device)\n\nmodel = HGT(\n    G,\n    node_dict,\n    edge_dict,\n    n_inp=args.n_inp,\n    n_hid=args.n_hid,\n    n_out=labels.max().item() + 1,\n    n_layers=2,\n    n_heads=4,\n    use_norm=True,\n).to(device)\noptimizer = torch.optim.AdamW(model.parameters())\nscheduler = torch.optim.lr_scheduler.OneCycleLR(\n    optimizer, total_steps=args.n_epoch, max_lr=args.max_lr\n)\nprint(\"Training HGT with #param: %d\" % (get_n_params(model)))\ntrain(model, G)\n\n\nmodel = HeteroRGCN(\n    G,\n    in_size=args.n_inp,\n    hidden_size=args.n_hid,\n    out_size=labels.max().item() + 1,\n).to(device)\noptimizer = torch.optim.AdamW(model.parameters())\nscheduler = torch.optim.lr_scheduler.OneCycleLR(\n    optimizer, total_steps=args.n_epoch, max_lr=args.max_lr\n)\nprint(\"Training RGCN with #param: %d\" % (get_n_params(model)))\ntrain(model, G)\n\n\nmodel = HGT(\n    G,\n    node_dict,\n    edge_dict,\n    n_inp=args.n_inp,\n    n_hid=args.n_hid,\n    n_out=labels.max().item() + 1,\n    n_layers=0,\n    n_heads=4,\n).to(device)\noptimizer = torch.optim.AdamW(model.parameters())\nscheduler = torch.optim.lr_scheduler.OneCycleLR(\n    optimizer, total_steps=args.n_epoch, max_lr=args.max_lr\n)\nprint(\"Training MLP with #param: %d\" % (get_n_params(model)))\ntrain(model, G)\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/README.md",
    "content": "# PSS\n\nCode for the ECCV '22 submission \"PSS: Progressive Sample Selection for Open-World Visual Representation Learning\".\n\n## Dependencies\n\nWe use python 3.7. The CUDA version needs to be 10.2. Besides DGL==0.6.1, we depend on several packages. To install dependencies using conda:\n\n```commandline\nconda create -n pss python=3.7 # create env\nconda activate pss # activate env\n\nconda install pytorch==1.7.0 torchvision==0.8.0 cudatoolkit=10.2 -c pytorch # install pytorch 1.7 version\nconda install -y cudatoolkit=10.2 faiss-gpu=1.6.5 -c pytorch # install faiss gpu version matching cuda 10.2\npip install dgl-cu102 # install dgl for cuda 10.2\npip install tqdm # install tqdm\npip install matplotlib # install matplotlib\npip install pandas # install pandas\npip install pretrainedmodels # install pretrainedmodels\npip install tensorboardX # install tensorboardX\npip install seaborn # install seaborn\npip install scikit-learn\ncd ..\ngit clone https://github.com/yjxiong/clustering-benchmark.git # install clustering-benchmark for evaluation\ncd clustering-benchmark\npython setup.py install\ncd ../PSS\n```\n\n## Data\n\nWe use the iNaturalist 2018 dataset. \n- download link: https://www.kaggle.com/c/inaturalist-2018/data;\n- annotations are in `Smooth_AP/data/Inaturalist`;\n- annotation txt files for different data splits are in [S3 link]|[[Google Drive](https://drive.google.com/drive/folders/1xrWogJGef4Ex5OGjiImgA06bAnk2MDrK?usp=sharing)]|[[Baidu Netdisk](https://pan.baidu.com/s/14S0Fns29a4o7kFDlNyyPjA?pwd=uwsg)] (password:uwsg).\n\nDownload `train_val2018.tar.gz` and the data split txt files to `data/Inaturalist/` folder. 
Extract the `tar.gz` files.\nThe data folder has the following structure:\n\n```text\nPSS\n|- data\n|  |- Inaturalist\n|  |  |- train2018.json.tar.gz\n|  |  |- train_val2018.tar.gz\n|  |  |- val2018.json.tar.gz\n|  |  |- train_val2018\n|  |  |  |- Actinopterygii\n|  |  |  |- ...\n|  |  |- lin_train_set1.txt\n|  |  |- train_set1.txt\n|  |  |- uin_train_set1.txt\n|  |  |- uout_train_set1.txt\n|  |  |- in_train_set1.txt\n|  |  |- Inaturalist_test_set1.txt\n|- ...\n```\n\n## Training\nRun `bash train.sh` to train the model.\n\n## Test\nRun `bash test.sh` to evaluate on the test set."
  },
  {
    "path": "examples/pytorch/hilander/PSS/Smooth_AP/README.md",
    "content": "# Smooth_AP\n\nBased on the ECCV '20 paper [\"Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval\"](https://www.robots.ox.ac.uk/~vgg/research/smooth-ap/). The reference code is from https://github.com/Andrew-Brown1/Smooth_AP.\n\n\n![teaser](https://github.com/Andrew-Brown1/Smooth_AP/blob/master/ims/teaser.png)\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/Smooth_AP/src/auxiliaries.py",
    "content": "# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines\n\n################## LIBRARIES ##############################\nimport warnings\n\nwarnings.filterwarnings(\"ignore\")\n\nimport csv\nimport datetime\nimport os\nimport pickle as pkl\n\nimport faiss\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom sklearn import metrics\nfrom torch import nn\nfrom tqdm import tqdm\n\n\"\"\"=============================================================================================================\"\"\"\n\n\n################### TensorBoard Settings ###################\ndef args2exp_name(args):\n    exp_name = f\"{args.dataset}_{args.loss}_{args.lr}_bs{args.bs}_spc{args.samples_per_class}_embed{args.embed_dim}_arch{args.arch}_decay{args.decay}_fclr{args.fc_lr_mul}_anneal{args.sigmoid_temperature}\"\n    return exp_name\n\n\n################# ACQUIRE NUMBER OF WEIGHTS #################\ndef gimme_params(model):\n    \"\"\"\n    Provide number of trainable parameters (i.e. 
those requiring gradient computation) for input network.\n\n    Args:\n        model: PyTorch Network\n    Returns:\n        int, number of parameters.\n    \"\"\"\n    model_parameters = filter(lambda p: p.requires_grad, model.parameters())\n    params = sum([np.prod(p.size()) for p in model_parameters])\n    return params\n\n\n################# SAVE TRAINING PARAMETERS IN NICE STRING #################\ndef gimme_save_string(opt):\n    \"\"\"\n    Taking the set of parameters and convert it to easy-to-read string, which can be stored later.\n\n    Args:\n        opt: argparse.Namespace, contains all training-specific parameters.\n    Returns:\n        string, returns string summary of parameters.\n    \"\"\"\n    varx = vars(opt)\n    base_str = \"\"\n    for key in varx:\n        base_str += str(key)\n        if isinstance(varx[key], dict):\n            for sub_key, sub_item in varx[key].items():\n                base_str += \"\\n\\t\" + str(sub_key) + \": \" + str(sub_item)\n        else:\n            base_str += \"\\n\\t\" + str(varx[key])\n        base_str += \"\\n\\n\"\n    return base_str\n\n\ndef f1_score(\n    model_generated_cluster_labels,\n    target_labels,\n    feature_coll,\n    computed_centroids,\n):\n    \"\"\"\n    NOTE: MOSTLY ADAPTED FROM https://github.com/wzzheng/HDML on Hardness-Aware Deep Metric Learning.\n\n    Args:\n        model_generated_cluster_labels: np.ndarray [n_samples x 1], Cluster labels computed on top of data embeddings.\n        target_labels:                  np.ndarray [n_samples x 1], ground truth labels for each data sample.\n        feature_coll:                   np.ndarray [n_samples x embed_dim], total data embedding made by network.\n        computed_centroids:             np.ndarray [num_cluster=num_classes x embed_dim], cluster coordinates\n    Returns:\n        float, F1-score\n    \"\"\"\n    from scipy.special import comb\n\n    d = np.zeros(len(feature_coll))\n    for i in range(len(feature_coll)):\n        
d[i] = np.linalg.norm(\n            feature_coll[i, :]\n            - computed_centroids[model_generated_cluster_labels[i], :]\n        )\n\n    labels_pred = np.zeros(len(feature_coll))\n    for i in np.unique(model_generated_cluster_labels):\n        index = np.where(model_generated_cluster_labels == i)[0]\n        ind = np.argmin(d[index])\n        cid = index[ind]\n        labels_pred[index] = cid\n\n    N = len(target_labels)\n\n    # Cluster n_labels\n    avail_labels = np.unique(target_labels)\n    n_labels = len(avail_labels)\n\n    # Count the number of objects in each cluster\n    count_cluster = np.zeros(n_labels)\n    for i in range(n_labels):\n        count_cluster[i] = len(np.where(target_labels == avail_labels[i])[0])\n\n    # Build a mapping from item_id to item index\n    keys = np.unique(labels_pred)\n    num_item = len(keys)\n    values = range(num_item)\n    item_map = dict()\n    for i in range(len(keys)):\n        item_map.update([(keys[i], values[i])])\n\n    # Count the number of objects of each item\n    count_item = np.zeros(num_item)\n    for i in range(N):\n        index = item_map[labels_pred[i]]\n        count_item[index] = count_item[index] + 1\n\n    # Compute True Positive (TP) plus False Positive (FP) count\n    tp_fp = 0\n    for k in range(n_labels):\n        if count_cluster[k] > 1:\n            tp_fp = tp_fp + comb(count_cluster[k], 2)\n\n    # Compute True Positive (TP) count\n    tp = 0\n    for k in range(n_labels):\n        member = np.where(target_labels == avail_labels[k])[0]\n        member_ids = labels_pred[member]\n\n        count = np.zeros(num_item)\n        for j in range(len(member)):\n            index = item_map[member_ids[j]]\n            count[index] = count[index] + 1\n\n        for i in range(num_item):\n            if count[i] > 1:\n                tp = tp + comb(count[i], 2)\n\n    # Compute  False Positive (FP) count\n    fp = tp_fp - tp\n\n    # Compute False Negative (FN) count\n    count = 0\n    for j 
in range(num_item):\n        if count_item[j] > 1:\n            count = count + comb(count_item[j], 2)\n    fn = count - tp\n\n    # compute F measure\n    beta = 1\n    P = tp / (tp + fp)\n    R = tp / (tp + fn)\n    F1 = (beta * beta + 1) * P * R / (beta * beta * P + R)\n\n    return F1\n\n\n\"\"\"=============================================================================================================\"\"\"\n\n\ndef eval_metrics_one_dataset(model, test_dataloader, device, k_vals, opt):\n    \"\"\"\n    Compute evaluation metrics on test-dataset, e.g. NMI, F1 and Recall @ k.\n\n    Args:\n        model:              PyTorch network, network to compute evaluation metrics for.\n        test_dataloader:    PyTorch Dataloader, dataloader for test dataset, should have no shuffling and correct processing.\n        device:             torch.device, Device to run inference on.\n        k_vals:             list of int, Recall values to compute\n        opt:                argparse.Namespace, contains all training-specific parameters.\n    Returns:\n        F1 score (float), NMI score (float), recall_at_k (list of float), data embedding (np.ndarray)\n    \"\"\"\n    torch.cuda.empty_cache()\n\n    _ = model.eval()\n    n_classes = len(test_dataloader.dataset.avail_classes)\n    with torch.no_grad():\n        ### For all test images, extract features\n        target_labels, feature_coll = [], []\n        final_iter = tqdm(\n            test_dataloader, desc=\"Computing Evaluation Metrics...\"\n        )\n        image_paths = [x[0] for x in test_dataloader.dataset.image_list]\n        for idx, inp in enumerate(final_iter):\n            input_img, target = inp[-1], inp[0]\n            target_labels.extend(target.numpy().tolist())\n            out = model(input_img.to(device), feature=True)\n            feature_coll.extend(out.cpu().detach().numpy().tolist())\n        # pdb.set_trace()\n        target_labels = np.hstack(target_labels).reshape(-1, 1)\n        feature_coll = 
np.vstack(feature_coll).astype(\"float32\")\n\n        torch.cuda.empty_cache()\n\n        ### Set Faiss CPU Cluster index\n        cpu_cluster_index = faiss.IndexFlatL2(feature_coll.shape[-1])\n        kmeans = faiss.Clustering(feature_coll.shape[-1], n_classes)\n        kmeans.niter = 20\n        kmeans.min_points_per_centroid = 1\n        kmeans.max_points_per_centroid = 1000000000\n\n        ### Train Kmeans\n        kmeans.train(feature_coll, cpu_cluster_index)\n        computed_centroids = faiss.vector_float_to_array(\n            kmeans.centroids\n        ).reshape(n_classes, feature_coll.shape[-1])\n\n        ### Assign feature points to clusters\n        faiss_search_index = faiss.IndexFlatL2(computed_centroids.shape[-1])\n        faiss_search_index.add(computed_centroids)\n        _, model_generated_cluster_labels = faiss_search_index.search(\n            feature_coll, 1\n        )\n\n        ### Compute NMI\n        NMI = metrics.cluster.normalized_mutual_info_score(\n            model_generated_cluster_labels.reshape(-1),\n            target_labels.reshape(-1),\n        )\n\n        ### Recover max(k_vals) nehbours to use for recall computation\n        faiss_search_index = faiss.IndexFlatL2(feature_coll.shape[-1])\n        faiss_search_index.add(feature_coll)\n        _, k_closest_points = faiss_search_index.search(\n            feature_coll, int(np.max(k_vals) + 1)\n        )\n        k_closest_classes = target_labels.reshape(-1)[k_closest_points[:, 1:]]\n        print(\"computing recalls\")\n        ### Compute Recall\n        recall_all_k = []\n        for k in k_vals:\n            recall_at_k = np.sum(\n                [\n                    1\n                    for target, recalled_predictions in zip(\n                        target_labels, k_closest_classes\n                    )\n                    if target in recalled_predictions[:k]\n                ]\n            ) / len(target_labels)\n            recall_all_k.append(recall_at_k)\n       
 print(\"finished recalls\")\n        print(\"computing F1\")\n        ### Compute F1 Score\n        F1 = 0\n        # F1 = f1_score(model_generated_cluster_labels, target_labels, feature_coll, computed_centroids)\n        print(\"finished computing f1\")\n\n    return F1, NMI, recall_all_k, feature_coll\n\n\ndef eval_metrics_query_and_gallery_dataset(\n    model, query_dataloader, gallery_dataloader, device, k_vals, opt\n):\n    \"\"\"\n    Compute evaluation metrics on test-dataset, e.g. NMI, F1 and Recall @ k.\n\n    Args:\n        model:               PyTorch network, network to compute evaluation metrics for.\n        query_dataloader:    PyTorch Dataloader, dataloader for query dataset, for which nearest neighbours in the gallery dataset are retrieved.\n        gallery_dataloader:  PyTorch Dataloader, dataloader for gallery dataset, provides target samples which are to be retrieved in correspondance to the query dataset.\n        device:              torch.device, Device to run inference on.\n        k_vals:              list of int, Recall values to compute\n        opt:                 argparse.Namespace, contains all training-specific parameters.\n    Returns:\n        F1 score (float), NMI score (float), recall_at_ks (list of float), query data embedding (np.ndarray), gallery data embedding (np.ndarray)\n    \"\"\"\n    torch.cuda.empty_cache()\n\n    _ = model.eval()\n    n_classes = len(query_dataloader.dataset.avail_classes)\n\n    with torch.no_grad():\n        ### For all query test images, extract features\n        query_target_labels, query_feature_coll = [], []\n        query_image_paths = [x[0] for x in query_dataloader.dataset.image_list]\n        query_iter = tqdm(query_dataloader, desc=\"Extraction Query Features\")\n        for idx, inp in enumerate(query_iter):\n            input_img, target = inp[-1], inp[0]\n            query_target_labels.extend(target.numpy().tolist())\n            out = model(input_img.to(device), feature=True)\n        
    query_feature_coll.extend(out.cpu().detach().numpy().tolist())\n\n        ### For all gallery test images, extract features\n        gallery_target_labels, gallery_feature_coll = [], []\n        gallery_image_paths = [\n            x[0] for x in gallery_dataloader.dataset.image_list\n        ]\n        gallery_iter = tqdm(\n            gallery_dataloader, desc=\"Extraction Gallery Features\"\n        )\n        for idx, inp in enumerate(gallery_iter):\n            input_img, target = inp[-1], inp[0]\n            gallery_target_labels.extend(target.numpy().tolist())\n            out = model(input_img.to(device), feature=True)\n            gallery_feature_coll.extend(out.cpu().detach().numpy().tolist())\n\n        query_target_labels, query_feature_coll = np.hstack(\n            query_target_labels\n        ).reshape(-1, 1), np.vstack(query_feature_coll).astype(\"float32\")\n        gallery_target_labels, gallery_feature_coll = np.hstack(\n            gallery_target_labels\n        ).reshape(-1, 1), np.vstack(gallery_feature_coll).astype(\"float32\")\n\n        torch.cuda.empty_cache()\n\n        ### Set CPU Cluster index\n        stackset = np.concatenate(\n            [query_feature_coll, gallery_feature_coll], axis=0\n        )\n        stacklabels = np.concatenate(\n            [query_target_labels, gallery_target_labels], axis=0\n        )\n        cpu_cluster_index = faiss.IndexFlatL2(stackset.shape[-1])\n        kmeans = faiss.Clustering(stackset.shape[-1], n_classes)\n        kmeans.niter = 20\n        kmeans.min_points_per_centroid = 1\n        kmeans.max_points_per_centroid = 1000000000\n\n        ### Train Kmeans\n        kmeans.train(stackset, cpu_cluster_index)\n        computed_centroids = faiss.vector_float_to_array(\n            kmeans.centroids\n        ).reshape(n_classes, stackset.shape[-1])\n\n        ### Assign feature points to clusters\n        faiss_search_index = faiss.IndexFlatL2(computed_centroids.shape[-1])\n        
faiss_search_index.add(computed_centroids)\n        _, model_generated_cluster_labels = faiss_search_index.search(\n            stackset, 1\n        )\n\n        ### Compute NMI\n        NMI = metrics.cluster.normalized_mutual_info_score(\n            model_generated_cluster_labels.reshape(-1), stacklabels.reshape(-1)\n        )\n\n        ### Recover max(k_vals) nearest neighbours to use for recall computation\n        faiss_search_index = faiss.IndexFlatL2(gallery_feature_coll.shape[-1])\n        faiss_search_index.add(gallery_feature_coll)\n        _, k_closest_points = faiss_search_index.search(\n            query_feature_coll, int(np.max(k_vals))\n        )\n        k_closest_classes = gallery_target_labels.reshape(-1)[k_closest_points]\n\n        ### Compute Recall\n        recall_all_k = []\n        for k in k_vals:\n            recall_at_k = np.sum(\n                [\n                    1\n                    for target, recalled_predictions in zip(\n                        query_target_labels, k_closest_classes\n                    )\n                    if target in recalled_predictions[:k]\n                ]\n            ) / len(query_target_labels)\n            recall_all_k.append(recall_at_k)\n        recall_str = \", \".join(\n            \"@{0}: {1:.4f}\".format(k, rec)\n            for k, rec in zip(k_vals, recall_all_k)\n        )\n\n        ### Compute F1 score\n        F1 = f1_score(\n            model_generated_cluster_labels,\n            stacklabels,\n            stackset,\n            computed_centroids,\n        )\n\n    return F1, NMI, recall_all_k, query_feature_coll, gallery_feature_coll\n\n\n\"\"\"=============================================================================================================\"\"\"\n\n\n####### RECOVER CLOSEST EXAMPLE IMAGES #######\ndef recover_closest_one_dataset(\n    feature_matrix_all, image_paths, save_path, n_image_samples=10, n_closest=3\n):\n    \"\"\"\n    Provide sample recoveries.\n\n    
Args:\n        feature_matrix_all: np.ndarray [n_samples x embed_dim], full data embedding of test samples.\n        image_paths:        list [n_samples], list of datapaths corresponding to <feature_matrix_all>\n        save_path:          str, where to store sample image.\n        n_image_samples:    Number of sample recoveries.\n        n_closest:          Number of closest recoveries to show.\n    Returns:\n        Nothing!\n    \"\"\"\n    image_paths = np.array([x[0] for x in image_paths])\n    sample_idxs = np.random.choice(\n        np.arange(len(feature_matrix_all)), n_image_samples\n    )\n\n    faiss_search_index = faiss.IndexFlatL2(feature_matrix_all.shape[-1])\n    faiss_search_index.add(feature_matrix_all)\n    _, closest_feature_idxs = faiss_search_index.search(\n        feature_matrix_all, n_closest + 1\n    )\n\n    sample_paths = image_paths[closest_feature_idxs][sample_idxs]\n\n    f, axes = plt.subplots(n_image_samples, n_closest + 1)\n    for i, (ax, plot_path) in enumerate(\n        zip(axes.reshape(-1), sample_paths.reshape(-1))\n    ):\n        ax.imshow(np.array(Image.open(plot_path)))\n        ax.set_xticks([])\n        ax.set_yticks([])\n        if i % (n_closest + 1):\n            ax.axvline(x=0, color=\"g\", linewidth=13)\n        else:\n            ax.axvline(x=0, color=\"r\", linewidth=13)\n    f.set_size_inches(10, 20)\n    f.tight_layout()\n    f.savefig(save_path)\n    plt.close()\n\n\n####### RECOVER CLOSEST EXAMPLE IMAGES #######\ndef recover_closest_inshop(\n    query_feature_matrix_all,\n    gallery_feature_matrix_all,\n    query_image_paths,\n    gallery_image_paths,\n    save_path,\n    n_image_samples=10,\n    n_closest=3,\n):\n    \"\"\"\n    Provide sample recoveries.\n\n    Args:\n        query_feature_matrix_all:   np.ndarray [n_query_samples x embed_dim], full data embedding of query samples.\n        gallery_feature_matrix_all: np.ndarray [n_gallery_samples x embed_dim], full data embedding of gallery samples.\n        
query_image_paths:          list [n_samples], list of datapaths corresponding to <query_feature_matrix_all>\n        gallery_image_paths:        list [n_samples], list of datapaths corresponding to <gallery_feature_matrix_all>\n        save_path:          str, where to store sample image.\n        n_image_samples:    Number of sample recoveries.\n        n_closest:          Number of closest recoveries to show.\n    Returns:\n        Nothing!\n    \"\"\"\n    query_image_paths, gallery_image_paths = np.array(\n        query_image_paths\n    ), np.array(gallery_image_paths)\n    sample_idxs = np.random.choice(\n        np.arange(len(query_feature_matrix_all)), n_image_samples\n    )\n\n    faiss_search_index = faiss.IndexFlatL2(gallery_feature_matrix_all.shape[-1])\n    faiss_search_index.add(gallery_feature_matrix_all)\n    _, closest_feature_idxs = faiss_search_index.search(\n        query_feature_matrix_all, n_closest\n    )\n\n    image_paths = gallery_image_paths[closest_feature_idxs]\n    image_paths = np.concatenate(\n        [query_image_paths.reshape(-1, 1), image_paths], axis=-1\n    )\n\n    sample_paths = image_paths[closest_feature_idxs][sample_idxs]\n\n    f, axes = plt.subplots(n_image_samples, n_closest + 1)\n    for i, (ax, plot_path) in enumerate(\n        zip(axes.reshape(-1), sample_paths.reshape(-1))\n    ):\n        ax.imshow(np.array(Image.open(plot_path)))\n        ax.set_xticks([])\n        ax.set_yticks([])\n        if i % (n_closest + 1):\n            ax.axvline(x=0, color=\"g\", linewidth=13)\n        else:\n            ax.axvline(x=0, color=\"r\", linewidth=13)\n    f.set_size_inches(10, 20)\n    f.tight_layout()\n    f.savefig(save_path)\n    plt.close()\n\n\n\"\"\"=============================================================================================================\"\"\"\n\n\n################## SET NETWORK TRAINING CHECKPOINT #####################\ndef set_checkpoint(model, opt, progress_saver, savepath):\n    \"\"\"\n    Store 
relevant parameters (model and progress saver, as well as parameter-namespace).\n    Can be easily extend for other stuff.\n\n    Args:\n        model:          PyTorch network, network whose parameters are to be saved.\n        opt:            argparse.Namespace, includes all training-specific parameters\n        progress_saver: subclass of LOGGER-class, contains a running memory of all training metrics.\n        savepath:       str, where to save checkpoint.\n    Returns:\n        Nothing!\n    \"\"\"\n    torch.save(\n        {\n            \"state_dict\": model.state_dict(),\n            \"opt\": opt,\n            \"progress\": progress_saver,\n        },\n        savepath,\n    )\n\n\n\"\"\"=============================================================================================================\"\"\"\n\n\n################## WRITE TO CSV FILE #####################\nclass CSV_Writer:\n    \"\"\"\n    Class to append newly compute training metrics to a csv file\n    for data logging.\n    Is used together with the LOGGER class.\n    \"\"\"\n\n    def __init__(self, save_path, columns):\n        \"\"\"\n        Args:\n            save_path: str, where to store the csv file\n            columns:   list of str, name of csv columns under which the resp. metrics are stored.\n        Returns:\n            Nothing!\n        \"\"\"\n        self.save_path = save_path\n        self.columns = columns\n\n        with open(self.save_path, \"a\") as csv_file:\n            writer = csv.writer(csv_file, delimiter=\",\")\n            writer.writerow(self.columns)\n\n    def log(self, inputs):\n        \"\"\"\n        log one set of entries to the csv.\n\n        Args:\n            inputs: [list of int/str/float], values to append to the csv. 
Has to be of the same length as self.columns.\n        Returns:\n            Nothing!\n        \"\"\"\n        with open(self.save_path, \"a\") as csv_file:\n            writer = csv.writer(csv_file, delimiter=\",\")\n            writer.writerow(inputs)\n\n\n################## GENERATE LOGGING FOLDER/FILES #######################\ndef set_logging(opt):\n    \"\"\"\n    Generate the folder in which everything is saved.\n    If opt.savename is given, folder will take on said name.\n    If not, a name based on the start time is provided.\n    If the folder already exists, it will by iterated until it can be created without\n    deleting existing data.\n    The current opt.save_path will be extended to account for the new save_folder name.\n\n    Args:\n        opt: argparse.Namespace, contains all training-specific parameters.\n    Returns:\n        Nothing!\n    \"\"\"\n    checkfolder = opt.save_path + \"/\" + str(opt.iter)\n\n    # Create start-time-based name if opt.savename is not give.\n    if opt.savename == \"\":\n        date = datetime.datetime.now()\n        checkfolder = opt.save_path + \"/\" + str(opt.iter)\n\n    # If folder already exists, iterate over it until is doesn't.\n    # counter     = 1\n    # while os.path.exists(checkfolder):\n    #     checkfolder = opt.save_path+'/'+opt.savename+'_'+str(counter)\n    #     counter += 1\n\n    # Create Folder\n    if not os.path.exists(checkfolder):\n        os.makedirs(checkfolder)\n    opt.save_path = checkfolder\n\n    # Store training parameters as text and pickle in said folder.\n    with open(opt.save_path + \"/Parameter_Info.txt\", \"w\") as f:\n        f.write(gimme_save_string(opt))\n    pkl.dump(opt, open(opt.save_path + \"/hypa.pkl\", \"wb\"))\n\n\nimport pdb\n\n\nclass LOGGER:\n    \"\"\"\n    This class provides a collection of logging properties that are useful for training.\n    These include setting the save folder, in which progression of training/testing metrics is visualized,\n    csv 
log-files are stored, sample recoveries are plotted and an internal data saver.\n    \"\"\"\n\n    def __init__(self, opt, metrics_to_log, name=\"Basic\", start_new=True):\n        \"\"\"\n        Args:\n            opt:               argparse.Namespace, contains all training-specific parameters.\n            metrics_to_log:    dict, dictionary which shows in what structure the data should be saved.\n                               is given as the output of aux.metrics_to_examine. Example:\n                               {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],\n                                'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}\n            name:              Name of this logger. Will be used to distinguish logged files from other LOGGER instances.\n            start_new:         If set to true, a new save folder will be created initially.\n        Returns:\n            Nothing!\n        \"\"\"\n        self.prop = opt\n        self.metrics_to_log = metrics_to_log\n\n        ### Make Logging Directories\n        if start_new:\n            set_logging(opt)\n\n        ### Set Progress Saver Dict\n        self.progress_saver = self.provide_progress_saver(metrics_to_log)\n\n        ### Set CSV Writters\n        self.csv_loggers = {\n            mode: CSV_Writer(\n                opt.save_path + \"/log_\" + mode + \"_\" + name + \".csv\", lognames\n            )\n            for mode, lognames in metrics_to_log.items()\n        }\n\n    def provide_progress_saver(self, metrics_to_log):\n        \"\"\"\n        Provide Progress Saver dictionary.\n\n        Args:\n            metrics_to_log: see __init__(). 
Describes the structure of Progress_Saver.\n        \"\"\"\n        Progress_Saver = {\n            key: {sub_key: [] for sub_key in metrics_to_log[key]}\n            for key in metrics_to_log.keys()\n        }\n        return Progress_Saver\n\n    def log(self, main_keys, metric_keys, values):\n        \"\"\"\n        Actually log new values in csv and Progress Saver dict internally.\n        Args:\n            main_keys:      Main key in which data will be stored. Normally is either 'train' for training metrics or 'val' for validation metrics.\n            metric_keys:    Needs to follow the list length of self.progress_saver[main_key(s)]. List of metric keys that are extended with new values.\n            values:         Needs to be a list of the same structure as metric_keys. Actual values that are appended.\n        \"\"\"\n        if not isinstance(main_keys, list):\n            main_keys = [main_keys]\n        if not isinstance(metric_keys, list):\n            metric_keys = [metric_keys]\n        if not isinstance(values, list):\n            values = [values]\n\n        # Log data to progress saver dict.\n        for main_key in main_keys:\n            for value, metric_key in zip(values, metric_keys):\n                self.progress_saver[main_key][metric_key].append(value)\n\n        # Append data to csv.\n        self.csv_loggers[main_key].log(values)\n\n    def update_info_plot(self):\n        \"\"\"\n        Create a new updated version of training/metric progression plot.\n\n        Args:\n            None\n        Returns:\n            Nothing!\n        \"\"\"\n        t_epochs = self.progress_saver[\"val\"][\"Epochs\"]\n        t_loss_list = [self.progress_saver[\"train\"][\"Train Loss\"]]\n        t_legend_handles = [\"Train Loss\"]\n\n        v_epochs = self.progress_saver[\"val\"][\"Epochs\"]\n        # Because Vehicle-ID normally uses three different test sets, a distinction has to be made.\n        if self.prop.dataset != \"vehicle_id\":\n        
    title = \" | \".join(\n                key + \": {0:3.3f}\".format(np.max(item))\n                for key, item in self.progress_saver[\"val\"].items()\n                if key not in [\"Time\", \"Epochs\"]\n            )\n            self.info_plot.title = title\n            v_metric_list = [\n                self.progress_saver[\"val\"][key]\n                for key in self.progress_saver[\"val\"].keys()\n                if key not in [\"Time\", \"Epochs\"]\n            ]\n            v_legend_handles = [\n                key\n                for key in self.progress_saver[\"val\"].keys()\n                if key not in [\"Time\", \"Epochs\"]\n            ]\n\n            self.info_plot.make_plot(\n                t_epochs,\n                v_epochs,\n                t_loss_list,\n                v_metric_list,\n                t_legend_handles,\n                v_legend_handles,\n            )\n        else:\n            # Iterate over all test sets.\n            for i in range(3):\n                title = \" | \".join(\n                    key + \": {0:3.3f}\".format(np.max(item))\n                    for key, item in self.progress_saver[\"val\"].items()\n                    if key not in [\"Time\", \"Epochs\"]\n                    and \"Set {}\".format(i) in key\n                )\n                self.info_plot[\"Set {}\".format(i)].title = title\n                v_metric_list = [\n                    self.progress_saver[\"val\"][key]\n                    for key in self.progress_saver[\"val\"].keys()\n                    if key not in [\"Time\", \"Epochs\"]\n                    and \"Set {}\".format(i) in key\n                ]\n                v_legend_handles = [\n                    key\n                    for key in self.progress_saver[\"val\"].keys()\n                    if key not in [\"Time\", \"Epochs\"]\n                    and \"Set {}\".format(i) in key\n                ]\n                self.info_plot[\"Set {}\".format(i)].make_plot(\n        
            t_epochs,\n                    v_epochs,\n                    t_loss_list,\n                    v_metric_list,\n                    t_legend_handles,\n                    v_legend_handles,\n                    appendix=\"set_{}\".format(i),\n                )\n\n\ndef metrics_to_examine(dataset, k_vals):\n    \"\"\"\n    Please only use either of the following keys:\n    -> Epochs, Time, Train Loss for training\n    -> Epochs, Time, NMI, F1 & Recall @ k for validation\n\n    Args:\n        dataset: str, dataset for which a storing structure for LOGGER.progress_saver is to be made.\n        k_vals:  list of int, Recall @ k - values.\n    Returns:\n        metric_dict: Dictionary representing the storing structure for LOGGER.progress_saver. See LOGGER.__init__() for an example.\n    \"\"\"\n    metric_dict = {\"train\": [\"Epochs\", \"Time\", \"Train Loss\"]}\n\n    if dataset == \"vehicle_id\":\n        metric_dict[\"val\"] = [\"Epochs\", \"Time\"]\n        # Vehicle_ID uses three test sets\n        for i in range(3):\n            metric_dict[\"val\"] += [\n                \"Set {} NMI\".format(i),\n                \"Set {} F1\".format(i),\n            ]\n            for k in k_vals:\n                metric_dict[\"val\"] += [\"Set {} Recall @ {}\".format(i, k)]\n    else:\n        metric_dict[\"val\"] = [\"Epochs\", \"Time\", \"NMI\", \"F1\"]\n        metric_dict[\"val\"] += [\"Recall @ {}\".format(k) for k in k_vals]\n\n    return metric_dict\n\n\ndef bool_flag(s):\n    \"\"\"\n    Parse boolean arguments from the command line.\n    \"\"\"\n    FALSY_STRINGS = {\"off\", \"false\", \"0\"}\n    TRUTHY_STRINGS = {\"on\", \"true\", \"1\"}\n    if s.lower() in FALSY_STRINGS:\n        return False\n    elif s.lower() in TRUTHY_STRINGS:\n        return True\n    else:\n        raise argparse.ArgumentTypeError(\"invalid value for a boolean flag\")\n\n\ndef vis(model, test_dataloader, device, split, opt):\n    linsize = opt.linsize\n    
torch.cuda.empty_cache()\n    if opt.dataset == \"Inaturalist\":\n        if opt.iter > 0:\n            with open(opt.cluster_path, \"rb\") as clusterf:\n                (\n                    path2idx,\n                    global_features,\n                    global_pred_labels,\n                    gt_labels,\n                    masks,\n                ) = pkl.load(clusterf)\n                gt_labels = gt_labels + len(np.unique(global_pred_labels))\n                idx2path = {v: k for k, v in path2idx.items()}\n        else:\n            with open(os.path.join(opt.source_path, \"train_set1.txt\")) as f:\n                filelines = f.readlines()\n                paths = [x.strip() for x in filelines]\n                Lin_paths = paths[:linsize]\n                masks = np.zeros(len(paths))\n                masks[: len(Lin_paths)] = 0\n                masks[len(Lin_paths) :] = 2\n\n    _ = model.eval()\n    path2ids = {}\n\n    with torch.no_grad():\n        ### For all test images, extract features\n        target_labels, feature_coll = [], []\n        final_iter = tqdm(\n            test_dataloader, desc=\"Computing Evaluation Metrics...\"\n        )\n        image_paths = [x[0] for x in test_dataloader.dataset.image_list]\n        for i in range(len(image_paths)):\n            path2ids[image_paths[i]] = i\n        for idx, inp in enumerate(final_iter):\n            input_img, target = inp[-1], inp[0]\n            target_labels.extend(target.numpy().tolist())\n            out = model(input_img.to(device), feature=True)\n            feature_coll.extend(out.cpu().detach().numpy().tolist())\n        # pdb.set_trace()\n        target_labels = np.hstack(target_labels).reshape(-1)\n        feature_coll = np.vstack(feature_coll).astype(\"float32\")\n\n    if (opt.dataset == \"Inaturalist\") and \"all_train\" in split:\n        if opt.iter > 0:\n            predicted_features = np.zeros_like(feature_coll)\n            path2ids_new = {}\n            target_labels_new 
= np.zeros_like(target_labels)\n            for i in range(len(idx2path.keys())):\n                path = idx2path[i]\n                idxx = path2ids[path]\n                path2ids_new[path] = i\n                predicted_features[i] = feature_coll[idxx]\n                target_labels_new[i] = target_labels[idxx]\n\n            path2ids = path2ids_new\n            feature_coll = predicted_features\n            target_labels = target_labels_new\n            gtlabels = target_labels\n            lastuselected = np.where(masks == 1)\n            masks[lastuselected] = 0\n            print(len(np.where(masks == 0)[0]))\n        else:\n            predicted_features = np.zeros_like(feature_coll)\n            path2ids_new = {}\n            target_labels_new = np.zeros_like(target_labels)\n            for i in range(len(paths)):\n                path = paths[i]\n                idxx = path2ids[opt.source_path + \"/\" + path]\n                path2ids_new[opt.source_path + \"/\" + path] = i\n                predicted_features[i] = feature_coll[idxx]\n                target_labels_new[i] = target_labels[idxx]\n\n            path2ids = path2ids_new\n            feature_coll = predicted_features\n            target_labels = target_labels_new\n            gtlabels = target_labels\n\n    if \"all_train\" not in split:\n        print(\"all_train not in split.\")\n        gtlabels = target_labels\n\n    output_feature_path = os.path.join(\n        opt.source_path, split + \"_inat_features.pkl\"\n    )\n    print(\"Dump features into {}.\".format(output_feature_path))\n    with open(output_feature_path, \"wb\") as f:\n        pkl.dump([path2ids, feature_coll, target_labels, gtlabels, masks], f)\n\n    print(target_labels.max())\n    print(\"target_labels:\", target_labels.shape)\n    print(\"feature_coll:\", feature_coll.shape)\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/Smooth_AP/src/datasets.py",
    "content": "# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines\n\n################# LIBRARIES ###############################\nimport pickle\nimport warnings\n\nfrom numpy.core.arrayprint import IntegerFormat\n\nwarnings.filterwarnings(\"ignore\")\n\nimport copy\nimport os\nimport random\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\n\n\"\"\"============================================================================\"\"\"\n\n\n################ FUNCTION TO RETURN ALL DATALOADERS NECESSARY ####################\ndef give_dataloaders(dataset, trainset, testset, opt, cluster_path=\"\"):\n    \"\"\"\n    Args:\n        dataset: string, name of dataset for which the dataloaders should be returned.\n        opt:     argparse.Namespace, contains all training-specific parameters.\n    Returns:\n        dataloaders: dict of dataloaders for training, testing and evaluation on training.\n    \"\"\"\n    # Dataset selection\n    if opt.dataset == \"Inaturalist\":\n        if opt.finetune:\n            datasets = give_inat_datasets_finetune_1head(\n                testset, cluster_path, opt\n            )\n        else:\n            if opt.get_features:\n                datasets = give_inaturalist_datasets_for_features(opt)\n            else:\n                datasets = give_inaturalist_datasets(opt)\n    else:\n        raise Exception(\"No Dataset >{}< available!\".format(dataset))\n\n    # Move datasets to dataloaders.\n    dataloaders = {}\n\n    for key, dataset in datasets.items():\n        if (\n            isinstance(dataset, TrainDatasetsmoothap)\n            or isinstance(dataset, TrainDatasetsmoothap1Head)\n        ) and key in [\"training\", \"clustering\"]:\n            dataloaders[key] = torch.utils.data.DataLoader(\n                dataset,\n                batch_size=opt.bs,\n                
num_workers=opt.kernels,\n                sampler=torch.utils.data.SequentialSampler(dataset),\n                pin_memory=True,\n                drop_last=True,\n            )\n        else:\n            is_val = dataset.is_validation\n            if key == \"training\" or key == \"clustering\":\n                dataloaders[key] = torch.utils.data.DataLoader(\n                    dataset,\n                    batch_size=opt.bs,\n                    num_workers=opt.kernels,\n                    shuffle=not is_val,\n                    pin_memory=True,\n                    drop_last=not is_val,\n                )\n            else:\n                dataloaders[key] = torch.utils.data.DataLoader(\n                    dataset,\n                    batch_size=opt.bs,\n                    num_workers=6,\n                    shuffle=not is_val,\n                    pin_memory=True,\n                    drop_last=not is_val,\n                )\n    return dataloaders\n\n\ndef give_inaturalist_datasets(opt):\n    \"\"\"\n    This function generates a training, testing and evaluation dataloader for Metric Learning on the Inaturalist 2018 dataset.\n    For Metric Learning, training and test sets are provided by given json files. 
Will define a train and test split\n    So no random shuffling of classes.\n\n    Args:\n        opt: argparse.Namespace, contains all traininig-specific parameters.\n    Returns:\n        dict of PyTorch datasets for training, testing and evaluation.\n    \"\"\"\n    # Load text-files containing classes and imagepaths.\n    # Generate image_dicts of shape {class_idx:[list of paths to images belong to this class] ...}\n    train_image_dict, val_image_dict, test_image_dict = {}, {}, {}\n    with open(os.path.join(opt.source_path, opt.trainset)) as f:\n        FileLines = f.readlines()\n        FileLines = [x.strip() for x in FileLines]\n\n        for entry in FileLines:\n            info = entry.split(\"/\")\n            if \"/\".join([info[-3], info[-2]]) not in train_image_dict:\n                train_image_dict[\"/\".join([info[-3], info[-2]])] = []\n            train_image_dict[\"/\".join([info[-3], info[-2]])].append(\n                os.path.join(opt.source_path, entry)\n            )\n\n    with open(os.path.join(opt.source_path, opt.testset)) as f:\n        FileLines = f.readlines()\n        FileLines = [x.strip() for x in FileLines]\n\n        for entry in FileLines:\n            info = entry.split(\"/\")\n            if \"/\".join([info[-3], info[-2]]) not in val_image_dict:\n                val_image_dict[\"/\".join([info[-3], info[-2]])] = []\n            val_image_dict[\"/\".join([info[-3], info[-2]])].append(\n                os.path.join(opt.source_path, entry)\n            )\n\n    with open(os.path.join(opt.source_path, opt.testset)) as f:\n        FileLines = f.readlines()\n        FileLines = [x.strip() for x in FileLines]\n\n        for entry in FileLines:\n            info = entry.split(\"/\")\n            if \"/\".join([info[-3], info[-2]]) not in test_image_dict:\n                test_image_dict[\"/\".join([info[-3], info[-2]])] = []\n            test_image_dict[\"/\".join([info[-3], info[-2]])].append(\n                
os.path.join(opt.source_path, entry)\n            )\n\n    new_train_dict = {}\n    class_ind_ind = 0\n    for cate in train_image_dict:\n        new_train_dict[\"te/%d\" % class_ind_ind] = train_image_dict[cate]\n        class_ind_ind += 1\n    train_image_dict = new_train_dict\n\n    train_dataset = TrainDatasetsmoothap(train_image_dict, opt)\n\n    val_dataset = BaseTripletDataset(val_image_dict, opt, is_validation=True)\n    eval_dataset = BaseTripletDataset(test_image_dict, opt, is_validation=True)\n\n    # train_dataset.conversion       = conversion\n    # val_dataset.conversion         = conversion\n    # eval_dataset.conversion        = conversion\n\n    return {\n        \"training\": train_dataset,\n        \"testing\": val_dataset,\n        \"evaluation\": eval_dataset,\n    }\n\n\ndef give_inaturalist_datasets_for_features(opt):\n    \"\"\"\n    This function generates a training, testing and evaluation dataloader for Metric Learning on the Inaturalist 2018 dataset.\n    For Metric Learning, training and test sets are provided by given json files. 
Will define a train and test split\n    So no random shuffling of classes.\n\n    Args:\n        opt: argparse.Namespace, contains all traininig-specific parameters.\n    Returns:\n        dict of PyTorch datasets for training, testing and evaluation.\n    \"\"\"\n    # Load text-files containing classes and imagepaths.\n    # Generate image_dicts of shape {class_idx:[list of paths to images belong to this class] ...}\n    train_image_dict, test_image_dict, eval_image_dict = {}, {}, {}\n\n    if opt.iter > 0:\n        with open(os.path.join(opt.cluster_path), \"rb\") as clusterf:\n            (\n                path2idx,\n                global_features,\n                global_pred_labels,\n                gt_labels,\n                masks,\n            ) = pickle.load(clusterf)\n            gt_labels = gt_labels + len(np.unique(global_pred_labels))\n\n            for path, idx in path2idx.items():\n                if global_pred_labels[idx] == -1:\n                    if \"te/%d\" % gt_labels[idx] not in test_image_dict:\n                        test_image_dict[\"te/%d\" % gt_labels[idx]] = []\n                    test_image_dict[\"te/%d\" % gt_labels[idx]].append(path)\n                else:\n                    if (\n                        \"te/%d\" % global_pred_labels[idx]\n                        not in train_image_dict\n                    ):\n                        train_image_dict[\"te/%d\" % global_pred_labels[idx]] = []\n                    train_image_dict[\"te/%d\" % global_pred_labels[idx]].append(\n                        path\n                    )\n                    if \"te/%d\" % global_pred_labels[idx] not in test_image_dict:\n                        test_image_dict[\"te/%d\" % global_pred_labels[idx]] = []\n                    test_image_dict[\"te/%d\" % global_pred_labels[idx]].append(\n                        path\n                    )\n    else:\n        with open(os.path.join(opt.source_path, opt.trainset)) as f:\n            FileLines 
= f.readlines()\n            FileLines = [x.strip() for x in FileLines]\n\n            for entry in FileLines:\n                info = entry.split(\"/\")\n                if \"/\".join([info[-3], info[-2]]) not in train_image_dict:\n                    train_image_dict[\"/\".join([info[-3], info[-2]])] = []\n                train_image_dict[\"/\".join([info[-3], info[-2]])].append(\n                    os.path.join(opt.source_path, entry)\n                )\n\n        with open(os.path.join(opt.source_path, opt.all_trainset)) as f:\n            FileLines = f.readlines()\n            FileLines = [x.strip() for x in FileLines]\n            for entry in FileLines:\n                info = entry.split(\"/\")\n                if \"/\".join([info[-3], info[-2]]) not in test_image_dict:\n                    test_image_dict[\"/\".join([info[-3], info[-2]])] = []\n                test_image_dict[\"/\".join([info[-3], info[-2]])].append(\n                    os.path.join(opt.source_path, entry)\n                )\n\n    with open(os.path.join(opt.source_path, opt.testset)) as f:\n        FileLines = f.readlines()\n        FileLines = [x.strip() for x in FileLines]\n\n        for entry in FileLines:\n            info = entry.split(\"/\")\n            if \"/\".join([info[-3], info[-2]]) not in eval_image_dict:\n                eval_image_dict[\"/\".join([info[-3], info[-2]])] = []\n            eval_image_dict[\"/\".join([info[-3], info[-2]])].append(\n                os.path.join(opt.source_path, entry)\n            )\n\n    new_train_dict = {}\n    class_ind_ind = 0\n    for cate in train_image_dict:\n        new_train_dict[\"te/%d\" % class_ind_ind] = train_image_dict[cate]\n        class_ind_ind += 1\n    train_image_dict = new_train_dict\n\n    new_test_dict = {}\n    class_ind_ind = 0\n    for cate in test_image_dict:\n        new_test_dict[\"te/%d\" % class_ind_ind] = test_image_dict[cate]\n        class_ind_ind += 1\n    test_image_dict = new_test_dict\n\n    
new_eval_dict = {}\n    class_ind_ind = 0\n    for cate in eval_image_dict:\n        new_eval_dict[\"te/%d\" % class_ind_ind] = eval_image_dict[cate]\n        class_ind_ind += 1\n    eval_image_dict = new_eval_dict\n\n    train_dataset = BaseTripletDataset(\n        train_image_dict, opt, is_validation=True\n    )\n    test_dataset = BaseTripletDataset(test_image_dict, opt, is_validation=True)\n    eval_dataset = BaseTripletDataset(eval_image_dict, opt, is_validation=True)\n\n    # train_dataset.conversion       = conversion\n    # val_dataset.conversion         = conversion\n    # eval_dataset.conversion        = conversion\n\n    return {\n        \"training\": train_dataset,\n        \"testing\": test_dataset,\n        \"eval\": eval_dataset,\n    }\n\n\ndef give_inat_datasets_finetune_1head(testset, cluster_label_path, opt):\n    \"\"\"\n    This function generates a training, testing and evaluation dataloader for Metric Learning on the Inaturalist 2018 dataset.\n    For Metric Learning, training and test sets are provided by given json files. 
Will define a train and test split\n    So no random shuffling of classes.\n\n    Args:\n        opt: argparse.Namespace, contains all traininig-specific parameters.\n    Returns:\n        dict of PyTorch datasets for training, testing and evaluation.\n    \"\"\"\n    # Load cluster labels from hilander results.\n    import pickle\n\n    train_image_dict, val_image_dict, cluster_image_dict = {}, {}, {}\n    with open(cluster_label_path, \"rb\") as clusterf:\n        (\n            path2idx,\n            global_features,\n            global_pred_labels,\n            gt_labels,\n            masks,\n        ) = pickle.load(clusterf)\n\n        for path, idx in path2idx.items():\n            if global_pred_labels[idx] == -1:\n                continue\n            else:\n                if \"te/%d\" % global_pred_labels[idx] not in train_image_dict:\n                    train_image_dict[\"te/%d\" % global_pred_labels[idx]] = []\n                train_image_dict[\"te/%d\" % global_pred_labels[idx]].append(path)\n\n    with open(os.path.join(opt.source_path, testset)) as f:\n        FileLines = f.readlines()\n        FileLines = [x.strip() for x in FileLines]\n\n        for entry in FileLines:\n            info = entry.split(\"/\")\n            if \"/\".join([info[-3], info[-2]]) not in val_image_dict:\n                val_image_dict[\"/\".join([info[-3], info[-2]])] = []\n            val_image_dict[\"/\".join([info[-3], info[-2]])].append(\n                os.path.join(opt.source_path, entry)\n            )\n\n    train_dataset = TrainDatasetsmoothap(train_image_dict, opt)\n\n    val_dataset = BaseTripletDataset(val_image_dict, opt, is_validation=True)\n\n    # train_dataset.conversion       = conversion\n    # val_dataset.conversion         = conversion\n    # eval_dataset.conversion        = conversion\n\n    return {\n        \"training\": train_dataset,\n        \"testing\": val_dataset,\n        \"evaluation\": val_dataset,\n    }\n\n\n################## BASIC 
PYTORCH DATASET USED FOR ALL DATASETS ##################################\nclass BaseTripletDataset(Dataset):\n    \"\"\"\n    Dataset class to provide (augmented) correctly prepared training samples corresponding to standard DML literature.\n    This includes normalizing to ImageNet-standards, and Random & Resized cropping of shapes 224 for ResNet50 and 227 for\n    GoogLeNet during Training. During validation, only resizing to 256 or center cropping to 224/227 is performed.\n    \"\"\"\n\n    def __init__(\n        self, image_dict, opt, samples_per_class=8, is_validation=False\n    ):\n        \"\"\"\n        Dataset Init-Function.\n\n        Args:\n            image_dict:         dict, Dictionary of shape {class_idx:[list of paths to images belong to this class] ...} providing all the training paths and classes.\n            opt:                argparse.Namespace, contains all training-specific parameters.\n            samples_per_class:  Number of samples to draw from one class before moving to the next when filling the batch.\n            is_validation:      If is true, dataset properties for validation/testing are used instead of ones for training.\n        Returns:\n            Nothing!\n        \"\"\"\n        # Define length of dataset\n        self.n_files = np.sum(\n            [len(image_dict[key]) for key in image_dict.keys()]\n        )\n\n        self.is_validation = is_validation\n\n        self.pars = opt\n        self.image_dict = image_dict\n\n        self.avail_classes = sorted(list(self.image_dict.keys()))\n\n        # Convert image dictionary from classname:content to class_idx:content, because the initial indices are not necessarily from 0 - <n_classes>.\n        self.image_dict = {\n            i: self.image_dict[key] for i, key in enumerate(self.avail_classes)\n        }\n        self.avail_classes = sorted(list(self.image_dict.keys()))\n\n        # Init. 
properties that are used when filling up batches.\n        if not self.is_validation:\n            self.samples_per_class = samples_per_class\n            # Select current class to sample images from up to <samples_per_class>\n            self.current_class = np.random.randint(len(self.avail_classes))\n            self.classes_visited = [self.current_class, self.current_class]\n            self.n_samples_drawn = 0\n\n        # Data augmentation/processing methods.\n        normalize = transforms.Normalize(\n            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n        )\n        transf_list = []\n        if not self.is_validation:\n            transf_list.extend(\n                [\n                    transforms.RandomResizedCrop(size=224)\n                    if opt.arch == \"resnet50\"\n                    else transforms.RandomResizedCrop(size=227),\n                    transforms.RandomHorizontalFlip(0.5),\n                ]\n            )\n        else:\n            transf_list.extend(\n                [\n                    transforms.Resize(256),\n                    transforms.CenterCrop(224)\n                    if opt.arch == \"resnet50\"\n                    else transforms.CenterCrop(227),\n                ]\n            )\n\n        transf_list.extend([transforms.ToTensor(), normalize])\n        self.transform = transforms.Compose(transf_list)\n\n        # Convert Image-Dict to list of (image_path, image_class). 
Allows for easier direct sampling.\n        self.image_list = [\n            [(x, key) for x in self.image_dict[key]]\n            for key in self.image_dict.keys()\n        ]\n        self.image_list = [x for y in self.image_list for x in y]\n\n        # Flag that denotes if dataset is called for the first time.\n        self.is_init = True\n\n    def ensure_3dim(self, img):\n        \"\"\"\n        Function that ensures that the input img is three-dimensional.\n\n        Args:\n            img: PIL.Image, image which is to be checked for three-dimensionality (i.e. if some images are black-and-white in an otherwise coloured dataset).\n        Returns:\n            Checked PIL.Image img.\n        \"\"\"\n        if len(img.size) == 2:\n            img = img.convert(\"RGB\")\n        return img\n\n    def __getitem__(self, idx):\n        \"\"\"\n        Args:\n            idx: Sample idx for training sample\n        Returns:\n            tuple of form (sample_class, torch.Tensor() of input image)\n        \"\"\"\n        if self.pars.loss == \"smoothap\" or self.pars.loss == \"smoothap_element\":\n            if self.is_init:\n                # self.current_class = self.avail_classes[idx%len(self.avail_classes)]\n                self.is_init = False\n\n            if not self.is_validation:\n                if self.samples_per_class == 1:\n                    return self.image_list[idx][-1], self.transform(\n                        self.ensure_3dim(Image.open(self.image_list[idx][0]))\n                    )\n\n                if self.n_samples_drawn == self.samples_per_class:\n                    # Once enough samples per class have been drawn, we choose another class to draw samples from.\n                    # Note that we ensure with self.classes_visited that no class is chosen if it had been chosen\n                    # previously or one before that.\n                    counter = copy.deepcopy(self.avail_classes)\n                    for prev_class in 
self.classes_visited:\n                        if prev_class in counter:\n                            counter.remove(prev_class)\n\n                    self.current_class = counter[idx % len(counter)]\n                    # self.classes_visited = self.classes_visited[1:]+[self.current_class]\n                    # EDIT -> there can be no class repeats\n                    self.classes_visited = self.classes_visited + [\n                        self.current_class\n                    ]\n                    self.n_samples_drawn = 0\n\n                class_sample_idx = idx % len(\n                    self.image_dict[self.current_class]\n                )\n                self.n_samples_drawn += 1\n\n                out_img = self.transform(\n                    self.ensure_3dim(\n                        Image.open(\n                            self.image_dict[self.current_class][\n                                class_sample_idx\n                            ]\n                        )\n                    )\n                )\n                return self.current_class, out_img\n            else:\n                return self.image_list[idx][-1], self.transform(\n                    self.ensure_3dim(Image.open(self.image_list[idx][0]))\n                )\n        else:\n            if self.is_init:\n                self.current_class = self.avail_classes[\n                    idx % len(self.avail_classes)\n                ]\n                self.is_init = False\n            if not self.is_validation:\n                if self.samples_per_class == 1:\n                    return self.image_list[idx][-1], self.transform(\n                        self.ensure_3dim(Image.open(self.image_list[idx][0]))\n                    )\n\n                if self.n_samples_drawn == self.samples_per_class:\n                    # Once enough samples per class have been drawn, we choose another class to draw samples from.\n                    # Note that we ensure with self.classes_visited 
that no class is chosen if it had been chosen\n                    # previously or one before that.\n                    counter = copy.deepcopy(self.avail_classes)\n                    for prev_class in self.classes_visited:\n                        if prev_class in counter:\n                            counter.remove(prev_class)\n\n                    self.current_class = counter[idx % len(counter)]\n                    self.classes_visited = self.classes_visited[1:] + [\n                        self.current_class\n                    ]\n                    self.n_samples_drawn = 0\n\n                class_sample_idx = idx % len(\n                    self.image_dict[self.current_class]\n                )\n                self.n_samples_drawn += 1\n\n                out_img = self.transform(\n                    self.ensure_3dim(\n                        Image.open(\n                            self.image_dict[self.current_class][\n                                class_sample_idx\n                            ]\n                        )\n                    )\n                )\n                return self.current_class, out_img\n            else:\n                return self.image_list[idx][-1], self.transform(\n                    self.ensure_3dim(Image.open(self.image_list[idx][0]))\n                )\n\n    def __len__(self):\n        return self.n_files\n\n\nflatten = lambda l: [item for sublist in l for item in sublist]\n\n######################## dataset for SmoothAP regular training ##################################\n\n\nclass TrainDatasetsmoothap(Dataset):\n    \"\"\"\n    This dataset class allows mini-batch formation pre-epoch, for greater speed\n\n    \"\"\"\n\n    def __init__(self, image_dict, opt):\n        \"\"\"\n        Args:\n            image_dict: two-level dict, `super_dict[super_class_id][class_id]` gives the list of\n                        image paths having the same super-label and class label\n        \"\"\"\n        self.image_dict = 
image_dict\n        self.dataset_name = opt.dataset\n        self.batch_size = opt.bs\n        self.samples_per_class = opt.samples_per_class\n        for sub in self.image_dict:\n            newsub = []\n            for instance in self.image_dict[sub]:\n                newsub.append((sub, instance))\n            self.image_dict[sub] = newsub\n\n        # checks\n        # provide avail_classes\n        self.avail_classes = [*self.image_dict]\n        # Data augmentation/processing methods.\n        normalize = transforms.Normalize(\n            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n        )\n        transf_list = []\n\n        transf_list.extend(\n            [\n                transforms.RandomResizedCrop(size=224)\n                if opt.arch in [\"resnet50\", \"resnet50_mcn\"]\n                else transforms.RandomResizedCrop(size=227),\n                transforms.RandomHorizontalFlip(0.5),\n            ]\n        )\n        transf_list.extend([transforms.ToTensor(), normalize])\n        self.transform = transforms.Compose(transf_list)\n\n        self.reshuffle()\n\n    def ensure_3dim(self, img):\n        if len(img.size) == 2:\n            img = img.convert(\"RGB\")\n        return img\n\n    def reshuffle(self):\n        image_dict = copy.deepcopy(self.image_dict)\n        print(\"shuffling data\")\n        for sub in image_dict:\n            random.shuffle(image_dict[sub])\n\n        classes = [*image_dict]\n        random.shuffle(classes)\n        total_batches = []\n        batch = []\n        finished = 0\n        while finished == 0:\n            for sub_class in classes:\n                if (len(image_dict[sub_class]) >= self.samples_per_class) and (\n                    len(batch) < self.batch_size / self.samples_per_class\n                ):\n                    batch.append(\n                        image_dict[sub_class][: self.samples_per_class]\n                    )\n                    image_dict[sub_class] = 
image_dict[sub_class][\n                        self.samples_per_class :\n                    ]\n\n            if len(batch) == self.batch_size / self.samples_per_class:\n                total_batches.append(batch)\n                batch = []\n            else:\n                finished = 1\n\n        random.shuffle(total_batches)\n        self.dataset = flatten(flatten(total_batches))\n\n    def __getitem__(self, idx):\n        # we use SequentialSampler together with SuperLabelTrainDataset,\n        # so idx==0 indicates the start of a new epoch\n        batch_item = self.dataset[idx]\n\n        if self.dataset_name == \"Inaturalist\":\n            cls = int(batch_item[0].split(\"/\")[1])\n\n        else:\n            cls = batch_item[0]\n        img = Image.open(batch_item[1])\n        return cls, self.transform(self.ensure_3dim(img))\n\n    def __len__(self):\n        return len(self.dataset)\n\n\nclass TrainDatasetsmoothap1Head(Dataset):\n    \"\"\"\n    This dataset class allows mini-batch formation pre-epoch, for greater speed\n\n    \"\"\"\n\n    def __init__(self, image_dict_L, image_dict_U, opt):\n        \"\"\"\n        Args:\n            image_dict: two-level dict, `super_dict[super_class_id][class_id]` gives the list of\n                        image paths having the same super-label and class label\n        \"\"\"\n        self.image_dict_L = image_dict_L\n        self.image_dict_U = image_dict_U\n        self.dataset_name = opt.dataset\n        self.batch_size = opt.bs\n        self.samples_per_class = opt.samples_per_class\n        for sub_L in self.image_dict_L:\n            newsub_L = []\n            for instance in self.image_dict_L[sub_L]:\n                newsub_L.append((sub_L, instance))\n            self.image_dict_L[sub_L] = newsub_L\n\n        for sub_U in self.image_dict_U:\n            newsub_U = []\n            for instance in self.image_dict_U[sub_U]:\n                newsub_U.append((sub_U, instance))\n            
self.image_dict_U[sub_U] = newsub_U\n\n        # checks\n        # provide avail_classes\n        self.avail_classes = [*self.image_dict_L] + [*self.image_dict_U]\n        # Data augmentation/processing methods.\n        normalize = transforms.Normalize(\n            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n        )\n        transf_list = []\n\n        transf_list.extend(\n            [\n                transforms.RandomResizedCrop(size=224)\n                if opt.arch in [\"resnet50\", \"resnet50_mcn\"]\n                else transforms.RandomResizedCrop(size=227),\n                transforms.RandomHorizontalFlip(0.5),\n            ]\n        )\n        transf_list.extend([transforms.ToTensor(), normalize])\n        self.transform = transforms.Compose(transf_list)\n\n        self.reshuffle()\n\n    def sample_same_size(self):\n        image_dict = copy.deepcopy(self.image_dict_L)\n\n        L_size = 0\n        for sub_L in self.image_dict_L:\n            L_size += len(self.image_dict_L[sub_L])\n\n        U_size = 0\n        classes_U = [*self.image_dict_U]\n        # while U_size < len(list(self.image_dict_U)) and U_size < L_size:\n        while len(classes_U) != 0:\n            sub_U = random.choice(classes_U)\n            classes_U.remove(sub_U)\n            sub_U_size = len(self.image_dict_U[sub_U])\n            if sub_U in [*image_dict]:\n                image_dict[sub_U].extend(self.image_dict_U[sub_U])\n            else:\n                image_dict[sub_U] = self.image_dict_U[sub_U]\n            U_size += sub_U_size\n        return image_dict\n\n    def ensure_3dim(self, img):\n        if len(img.size) == 2:\n            img = img.convert(\"RGB\")\n        return img\n\n    def reshuffle(self):\n        image_dict = self.sample_same_size()\n        print(\"shuffling data\")\n        for sub in image_dict:\n            random.shuffle(image_dict[sub])\n\n        classes = [*image_dict]\n        random.shuffle(classes)\n        total_batches = 
[]\n        batch = []\n        finished = 0\n        while finished == 0:\n            for sub_class in classes:\n                if (len(image_dict[sub_class]) >= self.samples_per_class) and (\n                    len(batch) < self.batch_size / self.samples_per_class\n                ):\n                    batch.append(\n                        image_dict[sub_class][: self.samples_per_class]\n                    )\n                    image_dict[sub_class] = image_dict[sub_class][\n                        self.samples_per_class :\n                    ]\n\n            if len(batch) == self.batch_size / self.samples_per_class:\n                total_batches.append(batch)\n                batch = []\n            else:\n                finished = 1\n\n        random.shuffle(total_batches)\n        self.dataset = flatten(flatten(total_batches))\n\n    def __getitem__(self, idx):\n        # we use SequentialSampler together with SuperLabelTrainDataset,\n        # so idx==0 indicates the start of a new epoch\n        batch_item = self.dataset[idx]\n\n        if self.dataset_name == \"Inaturalist\":\n            cls = int(batch_item[0].split(\"/\")[1])\n        else:\n            cls = batch_item[0]\n        img = Image.open(str(batch_item[1]))\n        return cls, self.transform(self.ensure_3dim(img))\n\n    def __len__(self):\n        return len(self.dataset)\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/Smooth_AP/src/evaluate.py",
    "content": "# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines\n\n\n##################################### LIBRARIES ###########################################\nimport warnings\n\nwarnings.filterwarnings(\"ignore\")\n\nimport csv\nimport pickle as pkl\nimport time\n\nimport auxiliaries as aux\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nimport torch.multiprocessing\nimport torch.nn as nn\nfrom scipy.spatial import distance\nfrom sklearn.preprocessing import normalize\nfrom tqdm import tqdm\n\ntorch.multiprocessing.set_sharing_strategy(\"file_system\")\n\n\n\"\"\"==================================================================================================================\"\"\"\n\"\"\"==================================================================================================================\"\"\"\n\"\"\"=========================================================\"\"\"\n\n\ndef evaluate(dataset, LOG, **kwargs):\n    \"\"\"\n    Given a dataset name, applies the correct evaluation function.\n\n    Args:\n        dataset: str, name of dataset.\n        LOG:     aux.LOGGER instance, main logging class.\n        **kwargs: Input Argument Dict, depends on dataset.\n    Returns:\n        (optional) Computed metrics. 
Are normally written directly to LOG and printed.\n    \"\"\"\n    if dataset in [\"Inaturalist\", \"semi_fungi\"]:\n        ret = evaluate_one_dataset(LOG, **kwargs)\n    elif dataset in [\"vehicle_id\"]:\n        ret = evaluate_multiple_datasets(LOG, **kwargs)\n    else:\n        raise Exception(\"No implementation for dataset {} available!\")\n\n    return ret\n\n\n\"\"\"=========================================================\"\"\"\n\n\nclass DistanceMeasure:\n    \"\"\"\n    Container class to run and log the change of distance ratios\n    between intra-class distances and inter-class distances.\n    \"\"\"\n\n    def __init__(self, checkdata, opt, name=\"Train\", update_epochs=1):\n        \"\"\"\n        Args:\n            checkdata: PyTorch DataLoader, data to check distance progression.\n            opt:       argparse.Namespace, contains all training-specific parameters.\n            name:      str, Name of instance. Important for savenames.\n            update_epochs:  int, Only compute distance ratios every said epoch.\n        Returns:\n            Nothing!\n        \"\"\"\n        self.update_epochs = update_epochs\n        self.pars = opt\n        self.save_path = opt.save_path\n\n        self.name = name\n        self.csv_file = opt.save_path + \"/distance_measures_{}.csv\".format(\n            self.name\n        )\n        with open(self.csv_file, \"a\") as csv_file:\n            writer = csv.writer(csv_file, delimiter=\",\")\n            writer.writerow([\"Rel. Intra/Inter Distance\"])\n\n        self.checkdata = checkdata\n\n        self.mean_class_dists = []\n        self.epochs = []\n\n    def measure(self, model, epoch):\n        \"\"\"\n        Compute distance ratios of intra- and interclass distance.\n\n        Args:\n            model: PyTorch Network, network that produces the resp. 
embeddings.\n            epoch: Current epoch.\n        Returns:\n            Nothing!\n        \"\"\"\n        if epoch % self.update_epochs:\n            return\n\n        self.epochs.append(epoch)\n\n        torch.cuda.empty_cache()\n\n        _ = model.eval()\n\n        # Compute Embeddings\n        with torch.no_grad():\n            feature_coll, target_coll = [], []\n            data_iter = tqdm(\n                self.checkdata, desc=\"Estimating Data Distances...\"\n            )\n            for idx, data in enumerate(data_iter):\n                input_img, target = data[1], data[0]\n                features = model(input_img.to(self.pars.device))\n                feature_coll.extend(features.cpu().detach().numpy().tolist())\n                target_coll.extend(target.numpy().tolist())\n\n        feature_coll = np.vstack(feature_coll).astype(\"float32\")\n        target_coll = np.hstack(target_coll).reshape(-1)\n        avail_labels = np.unique(target_coll)\n\n        # Compute indices of embeddings for each class.\n        class_positions = []\n        for lab in avail_labels:\n            class_positions.append(np.where(target_coll == lab)[0])\n\n        # Compute average intra-class distance and center of mass.\n        com_class, dists_class = [], []\n        for class_pos in class_positions:\n            dists = distance.cdist(\n                feature_coll[class_pos], feature_coll[class_pos], \"cosine\"\n            )\n            dists = np.sum(dists) / (len(dists) ** 2 - len(dists))\n            # dists = np.linalg.norm(np.std(feature_coll_aux[class_pos],axis=0).reshape(1,-1)).reshape(-1)\n            com = normalize(\n                np.mean(feature_coll[class_pos], axis=0).reshape(1, -1)\n            ).reshape(-1)\n            dists_class.append(dists)\n            com_class.append(com)\n\n        # Compute mean inter-class distances by the class-coms.\n        mean_inter_dist = distance.cdist(\n            np.array(com_class), np.array(com_class), 
\"cosine\"\n        )\n        mean_inter_dist = np.sum(mean_inter_dist) / (\n            len(mean_inter_dist) ** 2 - len(mean_inter_dist)\n        )\n\n        # Compute distance ratio\n        mean_class_dist = np.mean(np.array(dists_class) / mean_inter_dist)\n        self.mean_class_dists.append(mean_class_dist)\n\n        self.update(mean_class_dist)\n\n    def update(self, mean_class_dist):\n        \"\"\"\n        Update Loggers.\n\n        Args:\n            mean_class_dist: float, Distance Ratio\n        Returns:\n            Nothing!\n        \"\"\"\n        self.update_csv(mean_class_dist)\n        self.update_plot()\n\n    def update_csv(self, mean_class_dist):\n        \"\"\"\n        Update CSV.\n\n        Args:\n            mean_class_dist: float, Distance Ratio\n        Returns:\n            Nothing!\n        \"\"\"\n        with open(self.csv_file, \"a\") as csv_file:\n            writer = csv.writer(csv_file, delimiter=\",\")\n            writer.writerow([mean_class_dist])\n\n    def update_plot(self):\n        \"\"\"\n        Update progression plot.\n\n        Args:\n            None.\n        Returns:\n            Nothing!\n        \"\"\"\n        plt.style.use(\"ggplot\")\n        f, ax = plt.subplots(1)\n        ax.set_title(\"Mean Intra- over Interclassdistances\")\n        ax.plot(self.epochs, self.mean_class_dists, label=\"Class\")\n        f.legend()\n        f.set_size_inches(15, 8)\n        f.savefig(\n            self.save_path + \"/distance_measures_{}.svg\".format(self.name)\n        )\n\n\nclass GradientMeasure:\n    \"\"\"\n    Container for gradient measure functionalities.\n    Measure the gradients coming from the embedding layer to the final conv. layer\n    to examine learning signal.\n    \"\"\"\n\n    def __init__(self, opt, name=\"class-it\"):\n        \"\"\"\n        Args:\n            opt:   argparse.Namespace, contains all training-specific parameters.\n            name:  Name of class instance. 
Important for the savename.\n        Returns:\n            Nothing!\n        \"\"\"\n        self.pars = opt\n        self.name = name\n        self.saver = {\n            \"grad_normal_mean\": [],\n            \"grad_normal_std\": [],\n            \"grad_abs_mean\": [],\n            \"grad_abs_std\": [],\n        }\n\n    def include(self, params):\n        \"\"\"\n        Include the gradients for a set of parameters, normally the final embedding layer.\n\n        Args:\n            params: PyTorch Network layer after .backward() was called.\n        Returns:\n            Nothing!\n        \"\"\"\n        gradients = [params.weight.grad.detach().cpu().numpy()]\n\n        for grad in gradients:\n            ### Shape: 128 x 2048\n            self.saver[\"grad_normal_mean\"].append(np.mean(grad, axis=0))\n            self.saver[\"grad_normal_std\"].append(np.std(grad, axis=0))\n            self.saver[\"grad_abs_mean\"].append(np.mean(np.abs(grad), axis=0))\n            self.saver[\"grad_abs_std\"].append(np.std(np.abs(grad), axis=0))\n\n    def dump(self, epoch):\n        \"\"\"\n        Append all gradients to a pickle file.\n\n        Args:\n            epoch: Current epoch\n        Returns:\n            Nothing!\n        \"\"\"\n        with open(\n            self.pars.save_path + \"/grad_dict_{}.pkl\".format(self.name), \"ab\"\n        ) as f:\n            pkl.dump([self.saver], f)\n        self.saver = {\n            \"grad_normal_mean\": [],\n            \"grad_normal_std\": [],\n            \"grad_abs_mean\": [],\n            \"grad_abs_std\": [],\n        }\n\n\n\"\"\"=========================================================\"\"\"\n\n\ndef evaluate_one_dataset(\n    LOG, dataloader, model, opt, save=True, give_return=True, epoch=0\n):\n    \"\"\"\n    Compute evaluation metrics, update LOGGER and print results.\n\n    Args:\n        LOG:         aux.LOGGER-instance. 
Main Logging Functionality.\n        dataloader:  PyTorch Dataloader, Testdata to be evaluated.\n        model:       PyTorch Network, Network to evaluate.\n        opt:         argparse.Namespace, contains all training-specific parameters.\n        save:        bool, if True, Checkpoints are saved when testing metrics (specifically Recall @ 1) improve.\n        give_return: bool, if True, return computed metrics.\n        epoch:       int, current epoch, required for logger.\n    Returns:\n        (optional) Computed metrics. Are normally written directly to LOG and printed.\n    \"\"\"\n    start = time.time()\n    image_paths = np.array(dataloader.dataset.image_list)\n\n    with torch.no_grad():\n        # Compute Metrics\n        (\n            F1,\n            NMI,\n            recall_at_ks,\n            feature_matrix_all,\n        ) = aux.eval_metrics_one_dataset(\n            model, dataloader, device=opt.device, k_vals=opt.k_vals, opt=opt\n        )\n        # Make printable summary string.\n\n        result_str = \", \".join(\n            \"@{0}: {1:.4f}\".format(k, rec)\n            for k, rec in zip(opt.k_vals, recall_at_ks)\n        )\n        result_str = \"Epoch (Test) {0}: NMI [{1:.4f}] | F1 [{2:.4f}] | Recall [{3}]\".format(\n            epoch, NMI, F1, result_str\n        )\n\n        if LOG is not None:\n            if save:\n                if not len(\n                    LOG.progress_saver[\"val\"][\"Recall @ 1\"]\n                ) or recall_at_ks[0] > np.max(\n                    LOG.progress_saver[\"val\"][\"Recall @ 1\"]\n                ):\n                    # Save Checkpoint\n                    print(\n                        \"Set checkpoint at {}.\".format(\n                            LOG.prop.save_path\n                            + \"/checkpoint_{}.pth.tar\".format(opt.iter)\n                        )\n                    )\n                    aux.set_checkpoint(\n                        model,\n                        opt,\n    
                    LOG.progress_saver,\n                        LOG.prop.save_path\n                        + \"/checkpoint_{}.pth.tar\".format(opt.iter),\n                    )\n                    # aux.recover_closest_one_dataset(feature_matrix_all, image_paths, LOG.prop.save_path+'/sample_recoveries.png')\n            # Update logs.\n            LOG.log(\n                \"val\",\n                LOG.metrics_to_log[\"val\"],\n                [epoch, np.round(time.time() - start), NMI, F1] + recall_at_ks,\n            )\n\n    print(result_str)\n    if give_return:\n        return recall_at_ks, NMI, F1\n    else:\n        None\n\n\n\"\"\"=========================================================\"\"\"\n\n\ndef evaluate_query_and_gallery_dataset(\n    LOG,\n    query_dataloader,\n    gallery_dataloader,\n    model,\n    opt,\n    save=True,\n    give_return=True,\n    epoch=0,\n):\n    \"\"\"\n    Compute evaluation metrics, update LOGGER and print results, specifically for In-Shop Clothes.\n\n    Args:\n        LOG:         aux.LOGGER-instance. Main Logging Functionality.\n        query_dataloader:    PyTorch Dataloader, Query-testdata to be evaluated.\n        gallery_dataloader:  PyTorch Dataloader, Gallery-testdata to be evaluated.\n        model:       PyTorch Network, Network to evaluate.\n        opt:         argparse.Namespace, contains all training-specific parameters.\n        save:        bool, if True, Checkpoints are saved when testing metrics (specifically Recall @ 1) improve.\n        give_return: bool, if True, return computed metrics.\n        epoch:       int, current epoch, required for logger.\n    Returns:\n        (optional) Computed metrics. 
Are normally written directly to LOG and printed.\n    \"\"\"\n    start = time.time()\n    query_image_paths = np.array(\n        [x[0] for x in query_dataloader.dataset.image_list]\n    )\n    gallery_image_paths = np.array(\n        [x[0] for x in gallery_dataloader.dataset.image_list]\n    )\n\n    with torch.no_grad():\n        # Compute Metrics.\n        (\n            F1,\n            NMI,\n            recall_at_ks,\n            query_feature_matrix_all,\n            gallery_feature_matrix_all,\n        ) = aux.eval_metrics_query_and_gallery_dataset(\n            model,\n            query_dataloader,\n            gallery_dataloader,\n            device=opt.device,\n            k_vals=opt.k_vals,\n            opt=opt,\n        )\n        # Generate printable summary string.\n        result_str = \", \".join(\n            \"@{0}: {1:.4f}\".format(k, rec)\n            for k, rec in zip(opt.k_vals, recall_at_ks)\n        )\n        result_str = \"Epoch (Test) {0}: NMI [{1:.4f}] | F1 [{2:.4f}] | Recall [{3}]\".format(\n            epoch, NMI, F1, result_str\n        )\n\n        if LOG is not None:\n            if save:\n                if not len(\n                    LOG.progress_saver[\"val\"][\"Recall @ 1\"]\n                ) or recall_at_ks[0] > np.max(\n                    LOG.progress_saver[\"val\"][\"Recall @ 1\"]\n                ):\n                    # Save Checkpoint\n                    aux.set_checkpoint(\n                        model,\n                        opt,\n                        LOG.progress_saver,\n                        LOG.prop.save_path + \"/checkpoint.pth.tar\",\n                    )\n                    aux.recover_closest_inshop(\n                        query_feature_matrix_all,\n                        gallery_feature_matrix_all,\n                        query_image_paths,\n                        gallery_image_paths,\n                        LOG.prop.save_path + \"/sample_recoveries.png\",\n                    )\n          
  # Update logs.\n            LOG.log(\n                \"val\",\n                LOG.metrics_to_log[\"val\"],\n                [epoch, np.round(time.time() - start), NMI, F1] + recall_at_ks,\n            )\n\n    print(result_str)\n    if give_return:\n        return recall_at_ks, NMI, F1\n    else:\n        None\n\n\n\"\"\"=========================================================\"\"\"\n\n\ndef evaluate_multiple_datasets(\n    LOG, dataloaders, model, opt, save=True, give_return=True, epoch=0\n):\n    \"\"\"\n    Compute evaluation metrics, update LOGGER and print results, specifically for Multi-test datasets s.a. PKU Vehicle ID.\n\n    Args:\n        LOG:         aux.LOGGER-instance. Main Logging Functionality.\n        dataloaders: List of PyTorch Dataloaders, test-dataloaders to evaluate.\n        model:       PyTorch Network, Network to evaluate.\n        opt:         argparse.Namespace, contains all training-specific parameters.\n        save:        bool, if True, Checkpoints are saved when testing metrics (specifically Recall @ 1) improve.\n        give_return: bool, if True, return computed metrics.\n        epoch:       int, current epoch, required for logger.\n    Returns:\n        (optional) Computed metrics. 
Are normally written directly to LOG and printed.\n    \"\"\"\n    start = time.time()\n\n    csv_data = [epoch]\n\n    with torch.no_grad():\n        for i, dataloader in enumerate(dataloaders):\n            print(\"Working on Set {}/{}\".format(i + 1, len(dataloaders)))\n            image_paths = np.array(dataloader.dataset.image_list)\n            # Compute Metrics for specific testset.\n            (\n                F1,\n                NMI,\n                recall_at_ks,\n                feature_matrix_all,\n            ) = aux.eval_metrics_one_dataset(\n                model, dataloader, device=opt.device, k_vals=opt.k_vals, opt=opt\n            )\n            # Generate printable summary string.\n            result_str = \", \".join(\n                \"@{0}: {1:.4f}\".format(k, rec)\n                for k, rec in zip(opt.k_vals, recall_at_ks)\n            )\n            result_str = \"SET {0}: Epoch (Test) {1}: NMI [{2:.4f}] | F1 {3:.4f}| Recall [{4}]\".format(\n                i + 1, epoch, NMI, F1, result_str\n            )\n\n            if LOG is not None:\n                if save:\n                    if not len(\n                        LOG.progress_saver[\"val\"][\"Set {} Recall @ 1\".format(i)]\n                    ) or recall_at_ks[0] > np.max(\n                        LOG.progress_saver[\"val\"][\"Set {} Recall @ 1\".format(i)]\n                    ):\n                        # Save Checkpoint for specific test set.\n                        aux.set_checkpoint(\n                            model,\n                            opt,\n                            LOG.progress_saver,\n                            LOG.prop.save_path\n                            + \"/checkpoint_set{}.pth.tar\".format(i + 1),\n                        )\n                        aux.recover_closest_one_dataset(\n                            feature_matrix_all,\n                            image_paths,\n                            LOG.prop.save_path\n                            
+ \"/sample_recoveries_set{}.png\".format(i + 1),\n                        )\n\n                csv_data += [NMI, F1] + recall_at_ks\n            print(result_str)\n\n    csv_data.insert(0, np.round(time.time() - start))\n    # Update logs.\n    LOG.log(\"val\", LOG.metrics_to_log[\"val\"], csv_data)\n\n    # if give_return:\n    return csv_data[2:]\n    # else:\n    #    None\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/Smooth_AP/src/evaluate_model.py",
    "content": "import argparse\nimport os\n\nimport auxiliaries as aux\nimport datasets as data\nimport evaluate as eval\nimport netlib as netlib\nimport torch\n\nif __name__ == \"__main__\":\n    ################## INPUT ARGUMENTS ###################\n    parser = argparse.ArgumentParser()\n    ####### Main Parameter: Dataset to use for Training\n    parser.add_argument(\n        \"--dataset\",\n        default=\"vehicle_id\",\n        type=str,\n        help=\"Dataset to use.\",\n        choices=[\"Inaturalist\", \"vehicle_id\"],\n    )\n    parser.add_argument(\n        \"--source_path\",\n        default=\"/scratch/shared/beegfs/abrown/datasets\",\n        type=str,\n        help=\"Path to training data.\",\n    )\n    parser.add_argument(\n        \"--save_path\",\n        default=os.getcwd() + \"/Training_Results\",\n        type=str,\n        help=\"Where to save everything.\",\n    )\n    parser.add_argument(\n        \"--savename\",\n        default=\"\",\n        type=str,\n        help=\"Save folder name if any special information is to be included.\",\n    )\n\n    ### General Training Parameters\n    parser.add_argument(\n        \"--kernels\",\n        default=8,\n        type=int,\n        help=\"Number of workers for pytorch dataloader.\",\n    )\n    parser.add_argument(\n        \"--bs\", default=112, type=int, help=\"Mini-Batchsize to use.\"\n    )\n    parser.add_argument(\n        \"--samples_per_class\",\n        default=4,\n        type=int,\n        help=\"Number of samples in one class drawn before choosing the next class. 
Set to >1 for losses other than ProxyNCA.\",\n    )\n    parser.add_argument(\"--loss\", default=\"smoothap\", type=str)\n\n    ##### Evaluation Settings\n    parser.add_argument(\n        \"--k_vals\",\n        nargs=\"+\",\n        default=[1, 2, 4, 8],\n        type=int,\n        help=\"Recall @ Values.\",\n    )\n    ##### Network parameters\n    parser.add_argument(\n        \"--embed_dim\",\n        default=512,\n        type=int,\n        help=\"Embedding dimensionality of the network. Note: in literature, dim=128 is used for ResNet50 and dim=512 for GoogLeNet.\",\n    )\n    parser.add_argument(\n        \"--arch\",\n        default=\"resnet50\",\n        type=str,\n        help=\"Network backend choice: resnet50, googlenet, BNinception\",\n    )\n    parser.add_argument(\n        \"--gpu\", default=0, type=int, help=\"GPU-id for GPU to use.\"\n    )\n    parser.add_argument(\n        \"--resume\",\n        default=\"\",\n        type=str,\n        help=\"path to where weights to be evaluated are saved.\",\n    )\n    parser.add_argument(\n        \"--not_pretrained\",\n        action=\"store_true\",\n        help=\"If added, the network will be trained WITHOUT ImageNet-pretrained weights.\",\n    )\n\n    parser.add_argument(\"--trainset\", default=\"lin_train_set1.txt\", type=str)\n    parser.add_argument(\n        \"--testset\", default=\"Inaturalist_test_set1.txt\", type=str\n    )\n    parser.add_argument(\"--cluster_path\", default=\"\", type=str)\n    parser.add_argument(\"--finetune\", default=\"false\", type=str)\n    parser.add_argument(\"--class_num\", default=948, type=int)\n    parser.add_argument(\"--get_features\", default=\"false\", type=str)\n    parser.add_argument(\n        \"--patch_size\", default=16, type=int, help=\"vit patch size\"\n    )\n    parser.add_argument(\n        \"--pretrained_weights\",\n        default=\"\",\n        type=str,\n        help=\"pretrained weight path\",\n    )\n    parser.add_argument(\n        
\"--use_bn_in_head\",\n        default=False,\n        type=aux.bool_flag,\n        help=\"Whether to use batch normalizations in projection head (Default: False)\",\n    )\n    parser.add_argument(\n        \"--checkpoint_key\",\n        default=\"teacher\",\n        type=str,\n        help='Key to use in the checkpoint (example: \"teacher\")',\n    )\n    parser.add_argument(\n        \"--drop_path_rate\",\n        default=0.1,\n        type=float,\n        help=\"stochastic depth rate\",\n    )\n    parser.add_argument(\n        \"--norm_last_layer\",\n        default=True,\n        type=aux.bool_flag,\n        help=\"\"\"Whether or not to weight normalize the last layer of the DINO head.\n        Not normalizing leads to better performance but can make the training unstable.\n        In our experiments, we typically set this paramater to False with vit_small and True with vit_base.\"\"\",\n    )\n    parser.add_argument(\n        \"--linsize\", default=29011, type=int, help=\"Lin data size.\"\n    )\n    parser.add_argument(\n        \"--uinsize\", default=18403, type=int, help=\"Uin data size.\"\n    )\n    opt = parser.parse_args()\n\n    \"\"\"============================================================================\"\"\"\n    opt.source_path += \"/\" + opt.dataset\n\n    if opt.dataset == \"Inaturalist\":\n        opt.n_epochs = 90\n        opt.tau = [40, 70]\n        opt.k_vals = [1, 4, 16, 32]\n\n    if opt.dataset == \"vehicle_id\":\n        opt.k_vals = [1, 5]\n\n    if opt.finetune == \"true\":\n        opt.finetune = True\n    elif opt.finetune == \"false\":\n        opt.finetune = False\n\n    if opt.get_features == \"true\":\n        opt.get_features = True\n    elif opt.get_features == \"false\":\n        opt.get_features = False\n\n    metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)\n    LOG = aux.LOGGER(opt, metrics_to_log, name=\"Base\", start_new=True)\n\n    
\"\"\"============================================================================\"\"\"\n    ##################### NETWORK SETUP ##################\n\n    opt.device = torch.device(\"cuda\")\n    model = netlib.networkselect(opt)\n\n    # Push to Device\n    _ = model.to(opt.device)\n\n    \"\"\"============================================================================\"\"\"\n    #################### DATALOADER SETUPS ##################\n    # Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders.\n    # The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader\n    # Is simply using the training set, however running under the same rules as 'testing' dataloader,\n    # i.e. no shuffling and no random cropping.\n    dataloaders = data.give_dataloaders(\n        opt.dataset, opt.trainset, opt.testset, opt\n    )\n    # Because the number of supervised classes is dataset dependent, we store them after\n    # initializing the dataloader\n    opt.num_classes = len(dataloaders[\"training\"].dataset.avail_classes)\n\n    if opt.dataset == \"Inaturalist\":\n        eval_params = {\n            \"dataloader\": dataloaders[\"testing\"],\n            \"model\": model,\n            \"opt\": opt,\n            \"epoch\": 0,\n        }\n\n    elif opt.dataset == \"vehicle_id\":\n        eval_params = {\n            \"dataloaders\": [\n                dataloaders[\"testing_set1\"],\n                dataloaders[\"testing_set2\"],\n                dataloaders[\"testing_set3\"],\n            ],\n            \"model\": model,\n            \"opt\": opt,\n            \"epoch\": 0,\n        }\n\n    \"\"\"============================================================================\"\"\"\n    ####################evaluation ##################\n\n    results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params)\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/Smooth_AP/src/finetune_1head.py",
    "content": "# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines\n\n\"\"\"to do:\n\nclean all of the files - particularly the main.py and also the losses and dataset files and the file for doing the dataloading\n\n-- fast loading etc\n\nneed to change all of the copyrights at the top of all of the files\n\n\"\"\"\n\n#################### LIBRARIES ########################\nimport warnings\n\nwarnings.filterwarnings(\"ignore\")\n\nimport argparse\nimport datetime\nimport os\nimport random\n\nimport matplotlib\nimport numpy as np\n\nos.chdir(os.path.dirname(os.path.realpath(__file__)))\nfrom pathlib import Path\n\nmatplotlib.use(\"agg\")\nimport auxiliaries as aux\nimport datasets as data\nimport evaluate as eval\nimport losses as losses\nimport netlib as netlib\nimport torch.multiprocessing\nfrom tensorboardX import SummaryWriter\nfrom tqdm import tqdm\n\ntorch.multiprocessing.set_sharing_strategy(\"file_system\")\n\nimport time\n\nstart = time.time()\n\n################### INPUT ARGUMENTS ###################\nparser = argparse.ArgumentParser()\n####### Main Parameter: Dataset to use for Training\nparser.add_argument(\n    \"--dataset\",\n    default=\"Inaturalist\",\n    type=str,\n    help=\"Dataset to use.\",\n    choices=[\"Inaturalist\", \"semi_fungi\"],\n)\n### General Training Parameters\nparser.add_argument(\n    \"--lr\",\n    default=0.00001,\n    type=float,\n    help=\"Learning Rate for network parameters.\",\n)\nparser.add_argument(\n    \"--fc_lr_mul\",\n    default=5,\n    type=float,\n    help=\"OPTIONAL: Multiply the embedding layer learning rate by this value. 
If set to 0, the embedding layer shares the same learning rate.\",\n)\nparser.add_argument(\n    \"--n_epochs\", default=400, type=int, help=\"Number of training epochs.\"\n)\nparser.add_argument(\n    \"--kernels\",\n    default=8,\n    type=int,\n    help=\"Number of workers for pytorch dataloader.\",\n)\nparser.add_argument(\n    \"--bs\", default=112, type=int, help=\"Mini-Batchsize to use.\"\n)\nparser.add_argument(\n    \"--samples_per_class\",\n    default=4,\n    type=int,\n    help=\"Number of samples in one class drawn before choosing the next class\",\n)\nparser.add_argument(\n    \"--seed\", default=1, type=int, help=\"Random seed for reproducibility.\"\n)\nparser.add_argument(\n    \"--scheduler\",\n    default=\"step\",\n    type=str,\n    help=\"Type of learning rate scheduling. Currently: step & exp.\",\n)\nparser.add_argument(\n    \"--gamma\",\n    default=0.3,\n    type=float,\n    help=\"Learning rate reduction after tau epochs.\",\n)\nparser.add_argument(\n    \"--decay\", default=0.001, type=float, help=\"Weight decay for optimizer.\"\n)\nparser.add_argument(\n    \"--tau\",\n    default=[200, 300],\n    nargs=\"+\",\n    type=int,\n    help=\"Stepsize(s) before reducing learning rate.\",\n)\nparser.add_argument(\n    \"--infrequent_eval\",\n    default=0,\n    type=int,\n    help=\"only compute evaluation metrics every 10 epochs\",\n)\nparser.add_argument(\"--opt\", default=\"adam\", help=\"adam or sgd\")\n##### Loss-specific Settings\nparser.add_argument(\"--loss\", default=\"smoothap\", type=str)\nparser.add_argument(\n    \"--sigmoid_temperature\",\n    default=0.01,\n    type=float,\n    help=\"SmoothAP: the temperature of the sigmoid used in SmoothAP loss\",\n)\n##### Evaluation Settings\nparser.add_argument(\n    \"--k_vals\",\n    nargs=\"+\",\n    default=[1, 2, 4, 8],\n    type=int,\n    help=\"Recall @ Values.\",\n)\nparser.add_argument(\n    \"--resume\",\n    default=\"\",\n    type=str,\n    help=\"path to checkpoint to load 
weights from (if empty then ImageNet pre-trained weights are loaded\",\n)\n##### Network parameters\nparser.add_argument(\n    \"--embed_dim\",\n    default=512,\n    type=int,\n    help=\"Embedding dimensionality of the network\",\n)\nparser.add_argument(\n    \"--arch\",\n    default=\"resnet50\",\n    type=str,\n    help=\"Network backend choice: resnet50, googlenet, BNinception\",\n)\nparser.add_argument(\n    \"--grad_measure\",\n    action=\"store_true\",\n    help=\"If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.\",\n)\nparser.add_argument(\n    \"--dist_measure\",\n    action=\"store_true\",\n    help=\"If added, the ratio between intra- and interclass distances is stored after each epoch.\",\n)\nparser.add_argument(\n    \"--not_pretrained\",\n    action=\"store_true\",\n    help=\"If added, the network will be trained WITHOUT ImageNet-pretrained weights.\",\n)\n##### Setup Parameters\nparser.add_argument(\"--gpu\", default=0, type=int, help=\"GPU-id for GPU to use.\")\nparser.add_argument(\n    \"--savename\",\n    default=\"\",\n    type=str,\n    help=\"Save folder name if any special information is to be included.\",\n)\n### Paths to datasets and storage folder\nparser.add_argument(\n    \"--source_path\",\n    default=\"/scratch/shared/beegfs/abrown/datasets\",\n    type=str,\n    help=\"Path to data\",\n)\nparser.add_argument(\n    \"--save_path\",\n    default=os.getcwd() + \"/Training_Results\",\n    type=str,\n    help=\"Where to save the checkpoints\",\n)\n### additional parameters\nparser.add_argument(\"--trainset\", default=\"lin_train_set1.txt\", type=str)\nparser.add_argument(\"--testset\", default=\"Inaturalist_test_set1.txt\", type=str)\nparser.add_argument(\"--cluster_path\", default=\"\", type=str)\nparser.add_argument(\"--finetune\", default=\"true\", type=str)\nparser.add_argument(\"--class_num\", default=948, type=int)\nparser.add_argument(\n    \"--pretrained_weights\", 
default=\"\", type=str, help=\"pretrained weight path\"\n)\nparser.add_argument(\n    \"--use_bn_in_head\",\n    default=False,\n    type=aux.bool_flag,\n    help=\"Whether to use batch normalizations in projection head (Default: False)\",\n)\nparser.add_argument(\n    \"--checkpoint_key\",\n    default=\"teacher\",\n    type=str,\n    help='Key to use in the checkpoint (example: \"teacher\")',\n)\nparser.add_argument(\n    \"--drop_path_rate\", default=0.1, type=float, help=\"stochastic depth rate\"\n)\nparser.add_argument(\"--iter\", default=1, type=int)\n\nopt = parser.parse_args()\n\"\"\"============================================================================\"\"\"\nopt.source_path += \"/\" + opt.dataset\nopt.save_path += \"/\" + opt.dataset + \"_\" + str(opt.embed_dim)\n\nif opt.dataset == \"Inaturalist\":\n    # opt.n_epochs = 90\n    opt.tau = [40, 70]\n    opt.k_vals = [1, 4, 16, 32]\n\nif opt.dataset == \"semi_fungi\":\n    opt.tau = [40, 70]\n    opt.k_vals = [1, 4, 16, 32]\n\nif opt.finetune == \"true\":\n    opt.finetune = True\nelif opt.finetune == \"false\":\n    opt.finetune = False\n\n\"\"\"===========================================================================\"\"\"\n################### TensorBoard Settings ##################\ntimestamp = datetime.datetime.now().strftime(r\"%Y-%m-%d_%H-%M-%S\")\nexp_name = aux.args2exp_name(opt)\nopt.save_name = f\"weights_{exp_name}\" + \"/\" + timestamp\nrandom.seed(opt.seed)\nnp.random.seed(opt.seed)\ntorch.manual_seed(opt.seed)\ntorch.cuda.manual_seed(opt.seed)\ntorch.cuda.manual_seed_all(opt.seed)\ntensorboard_path = Path(f\"logs/logs_{exp_name}\") / timestamp\n\ntensorboard_path.parent.mkdir(exist_ok=True, parents=True)\nglobal writer\nwriter = SummaryWriter(tensorboard_path)\n\"\"\"============================================================================\"\"\"\n################### GPU SETTINGS ###########################\nos.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n# 
os.environ[\"CUDA_VISIBLE_DEVICES\"]= str(opt.gpu)\nprint(\"using #GPUs:\", torch.cuda.device_count())\n\n\"\"\"============================================================================\"\"\"\n#################### DATALOADER SETUPS ##################\n# Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders.\n# The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader\n# Is simply using the training set, however running under the same rules as 'testing' dataloader,\n# i.e. no shuffling and no random cropping.\ndataloaders = data.give_dataloaders(\n    opt.dataset, opt.trainset, opt.testset, opt, cluster_path=opt.cluster_path\n)\n# Because the number of supervised classes is dataset dependent, we store them after\n# initializing the dataloader\nopt.num_classes = len(dataloaders[\"training\"].dataset.avail_classes)\nprint(\"num_classes:\", opt.num_classes)\nprint(\"train dataset size:\", len(dataloaders[\"training\"]))\n\n\"\"\"============================================================================\"\"\"\n##################### NETWORK SETUP ##################\n\nopt.device = torch.device(\"cuda\")\nmodel = netlib.networkselect(opt)\n\n# Push to Device\nif torch.cuda.device_count() > 1:\n    model = torch.nn.DataParallel(model)\n_ = model.to(opt.device)\n# Place trainable parameter in list of parameters to train:\n\nif \"fc_lr_mul\" in vars(opt).keys() and opt.fc_lr_mul != 0:\n    all_but_fc_params = list(\n        filter(lambda x: \"last_linear\" not in x[0], model.named_parameters())\n    )\n\n    for ind, param in enumerate(all_but_fc_params):\n        all_but_fc_params[ind] = param[1]\n\n    if torch.cuda.device_count() > 1:\n        fc_params = model.module.model.last_linear.parameters()\n    else:\n        fc_params = model.model.last_linear.parameters()\n\n    to_optim = [\n        {\"params\": all_but_fc_params, \"lr\": opt.lr, \"weight_decay\": opt.decay},\n        {\n            
\"params\": fc_params,\n            \"lr\": opt.lr * opt.fc_lr_mul,\n            \"weight_decay\": opt.decay,\n        },\n    ]\nelse:\n    to_optim = [\n        {\"params\": model.parameters(), \"lr\": opt.lr, \"weight_decay\": opt.decay}\n    ]\n\"\"\"============================================================================\"\"\"\n#################### CREATE LOGGING FILES ###############\n# Each dataset usually has a set of standard metrics to log. aux.metrics_to_examine()\n# returns a dict which lists metrics to log for training ('train') and validation/testing ('val')\n\nmetrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)\n# example output: {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],\n#                  'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}\n\n# Using the provided metrics of interest, we generate a LOGGER instance.\n# Note that 'start_new' denotes that a new folder should be made in which everything will be stored.\n# This includes network weights as well.\nLOG = aux.LOGGER(opt, metrics_to_log, name=\"Base\", start_new=True)\n# If graphviz is installed on the system, a computational graph of the underlying\n# network will be made as well.\n\n\"\"\"============================================================================\"\"\"\n#################### LOSS SETUP ####################\n# Depending on opt.loss and opt.sampling, the respective criterion is returned,\n# and if the loss has trainable parameters, to_optim is appended.\ncriterion, to_optim = losses.loss_select(opt.loss, opt, to_optim)\n_ = criterion.to(opt.device)\n\n\"\"\"============================================================================\"\"\"\n##################### OPTIONAL EVALUATIONS #####################\n# Store the averaged gradients returned from the embedding to the last conv. 
layer.\nif opt.grad_measure:\n    grad_measure = eval.GradientMeasure(opt, name=\"baseline\")\n# Store the relative distances between average intra- and inter-class distance.\nif opt.dist_measure:\n    # Add a distance measure for training distance ratios\n    distance_measure = eval.DistanceMeasure(\n        dataloaders[\"evaluation\"], opt, name=\"Train\", update_epochs=1\n    )\n    # #If uncommented: Do the same for the test set\n    # distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1)\n\n\"\"\"============================================================================\"\"\"\n#################### OPTIM SETUP ####################\n# As optimizer, Adam with standard parameters is used.\nif opt.opt == \"adam\":\n    optimizer = torch.optim.Adam(to_optim)\nelif opt.opt == \"sgd\":\n    optimizer = torch.optim.SGD(to_optim)\nelse:\n    raise Exception(\"unknown optimiser\")\n# for the SOA measures in the paper - need to use SGD and 0.05 learning rate\n# optimizer    = torch.optim.Adam(to_optim)\n# optimizer    = torch.optim.SGD(to_optim)\nif opt.scheduler == \"exp\":\n    scheduler = torch.optim.lr_scheduler.ExponentialLR(\n        optimizer, gamma=opt.gamma\n    )\nelif opt.scheduler == \"step\":\n    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n        optimizer, milestones=opt.tau, gamma=opt.gamma\n    )\nelif opt.scheduler == \"none\":\n    print(\"No scheduling used!\")\nelse:\n    raise Exception(\"No scheduling option for input: {}\".format(opt.scheduler))\n\n\ndef same_model(model1, model2):\n    for p1, p2 in zip(model1.parameters(), model2.parameters()):\n        if p1.data.ne(p2.data).sum() > 0:\n            return False\n    return True\n\n\n\"\"\"============================================================================\"\"\"\n\n\n#################### TRAINER FUNCTION ############################\ndef train_one_epoch_finetune(\n    train_dataloader, model, optimizer, criterion, opt, 
epoch\n):\n    \"\"\"\n    This function is called every epoch to perform training of the network over one full\n    (randomized) iteration of the dataset.\n\n    Args:\n        train_dataloader: torch.utils.data.DataLoader, returns (augmented) training data.\n        model:            Network to train.\n        optimizer:        Optimizer to use for training.\n        criterion:        criterion to use during training.\n        opt:              argparse.Namespace, Contains all relevant parameters.\n        epoch:            int, Current epoch.\n\n    Returns:\n        Nothing!\n    \"\"\"\n\n    loss_collect = []\n\n    start = time.time()\n    data_iterator = tqdm(\n        train_dataloader, desc=\"Epoch {} Training gt labels...\".format(epoch)\n    )\n    for i, (class_labels, input) in enumerate(data_iterator):\n        # Compute embeddings for input batch\n        features = model(input.to(opt.device))\n\n        # Compute loss.\n        if opt.loss != \"smoothap\":\n            loss = criterion(features, class_labels)\n        else:\n            loss = criterion(features)\n\n        # Ensure gradients are set to zero at beginning\n        optimizer.zero_grad()\n        # Compute gradient\n        loss.backward()\n\n        train_dataloader.dataset.classes_visited = []\n\n        if opt.grad_measure:\n            # If desired, save computed gradients.\n            grad_measure.include(model.model.last_linear)\n\n        # Update weights using comp. 
gradients.\n        optimizer.step()\n\n        # Store loss per iteration.\n        loss_collect.append(loss.item())\n        if i == len(train_dataloader) - 1:\n            data_iterator.set_description(\n                \"Epoch (Train) {0}: Mean Loss [{1:.4f}]\".format(\n                    epoch, np.mean(loss_collect)\n                )\n            )\n\n    # Save metrics\n    LOG.log(\n        \"train\",\n        LOG.metrics_to_log[\"train\"],\n        [epoch, np.round(time.time() - start, 4), np.mean(loss_collect)],\n    )\n    writer.add_scalar(\"global/training_loss\", np.mean(loss_collect), epoch)\n    if opt.grad_measure:\n        # Dump stored gradients to Pickle-File.\n        grad_measure.dump(epoch)\n\n\n\"\"\"============================================================================\"\"\"\n\"\"\"========================== MAIN TRAINING PART ==============================\"\"\"\n\"\"\"============================================================================\"\"\"\n################### SCRIPT MAIN ##########################\nprint(\"\\n-----\\n\")\n# Each dataset requires slightly different dataloaders.\n\nif opt.dataset == \"Inaturalist\" or \"semi_fungi\":\n    eval_params = {\n        \"dataloader\": dataloaders[\"testing\"],\n        \"model\": model,\n        \"opt\": opt,\n        \"epoch\": 0,\n    }\n\n# Compute Evaluation metrics, print them and store in LOG.\nprint(\"epochs -> \" + str(opt.n_epochs))\nimport time\n\nfor epoch in range(opt.n_epochs):\n    ### Print current learning rates for all parameters\n    if opt.scheduler != \"none\":\n        print(\n            \"Running with learning rates {}...\".format(\n                \" | \".join(\"{}\".format(x) for x in scheduler.get_lr())\n            )\n        )\n\n    ### Train one epoch\n    _ = model.train()\n\n    train_one_epoch_finetune(\n        dataloaders[\"training\"], model, optimizer, criterion, opt, epoch\n    )\n\n    dataloaders[\"training\"].dataset.reshuffle()\n    ### 
Evaluate\n    _ = model.eval()\n    # Each dataset requires slightly different dataloaders.\n    if opt.dataset == \"Inaturalist\":\n        eval_params = {\n            \"dataloader\": dataloaders[\"testing\"],\n            \"model\": model,\n            \"opt\": opt,\n            \"epoch\": epoch,\n        }\n    elif opt.dataset == \"semi_fungi\":\n        eval_params = {\n            \"dataloader\": dataloaders[\"testing\"],\n            \"model\": model,\n            \"opt\": opt,\n            \"epoch\": epoch,\n        }\n\n    # Compute Evaluation metrics, print them and store in LOG.\n    if opt.infrequent_eval == 1:\n        epoch_freq = 10\n    else:\n        epoch_freq = 1\n\n    if epoch % epoch_freq == 0:\n        results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params)\n        writer.add_scalar(\"global/recall1\", results[0][0], epoch + 1)\n        writer.add_scalar(\"global/recall2\", results[0][1], epoch + 1)\n        writer.add_scalar(\"global/recall3\", results[0][2], epoch + 1)\n        writer.add_scalar(\"global/recall4\", results[0][3], epoch + 1)\n        writer.add_scalar(\"global/NMI\", results[1], epoch + 1)\n        writer.add_scalar(\"global/F1\", results[2], epoch + 1)\n\n    # Update the Metric Plot and save it.\n    # LOG.update_info_plot()\n    # (optional) compute ratio of intra- to interdistances.\n    if opt.dist_measure:\n        distance_measure.measure(model, epoch)\n        # distance_measure_test.measure(model, epoch)\n\n    ### Learning Rate Scheduling Step\n    if opt.scheduler != \"none\":\n        scheduler.step()\n\n    print(\"\\n-----\\n\")\n\nprint(\"Time:\", time.time() - start)\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/Smooth_AP/src/get_features.py",
    "content": "# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines\n\n\"\"\"to do:\n\nclean all of the files - particularly the main.py and also the losses and dataset files and the file for doing the dataloading\n\n-- fast loading etc\n\nneed to change all of the copyrights at the top of all of the files\n\n\"\"\"\n\n#################### LIBRARIES ########################\nimport warnings\n\nwarnings.filterwarnings(\"ignore\")\n\nimport argparse\nimport datetime\nimport os\nimport random\n\nimport matplotlib\nimport numpy as np\n\nos.chdir(os.path.dirname(os.path.realpath(__file__)))\nmatplotlib.use(\"agg\")\n\nimport auxiliaries as aux\nimport datasets as data\nimport evaluate as eval\nimport losses as losses\nimport netlib as netlib\nimport torch.multiprocessing\n\ntorch.multiprocessing.set_sharing_strategy(\"file_system\")\n\n################### INPUT ARGUMENTS ###################\nparser = argparse.ArgumentParser()\n####### Main Parameter: Dataset to use for Training\nparser.add_argument(\n    \"--dataset\",\n    default=\"Inaturalist\",\n    type=str,\n    help=\"Dataset to use.\",\n    choices=[\"Inaturalist\", \"semi_fungi\"],\n)\n### General Training Parameters\nparser.add_argument(\n    \"--lr\",\n    default=0.00001,\n    type=float,\n    help=\"Learning Rate for network parameters.\",\n)\nparser.add_argument(\n    \"--fc_lr_mul\",\n    default=5,\n    type=float,\n    help=\"OPTIONAL: Multiply the embedding layer learning rate by this value. 
If set to 0, the embedding layer shares the same learning rate.\",\n)\nparser.add_argument(\n    \"--n_epochs\", default=400, type=int, help=\"Number of training epochs.\"\n)\nparser.add_argument(\n    \"--kernels\",\n    default=8,\n    type=int,\n    help=\"Number of workers for pytorch dataloader.\",\n)\nparser.add_argument(\n    \"--bs\", default=112, type=int, help=\"Mini-Batchsize to use.\"\n)\nparser.add_argument(\n    \"--samples_per_class\",\n    default=4,\n    type=int,\n    help=\"Number of samples in one class drawn before choosing the next class\",\n)\nparser.add_argument(\n    \"--seed\", default=1, type=int, help=\"Random seed for reproducibility.\"\n)\nparser.add_argument(\n    \"--scheduler\",\n    default=\"step\",\n    type=str,\n    help=\"Type of learning rate scheduling. Currently: step & exp.\",\n)\nparser.add_argument(\n    \"--gamma\",\n    default=0.3,\n    type=float,\n    help=\"Learning rate reduction after tau epochs.\",\n)\nparser.add_argument(\n    \"--decay\", default=0.0004, type=float, help=\"Weight decay for optimizer.\"\n)\nparser.add_argument(\n    \"--tau\",\n    default=[200, 300],\n    nargs=\"+\",\n    type=int,\n    help=\"Stepsize(s) before reducing learning rate.\",\n)\nparser.add_argument(\n    \"--infrequent_eval\",\n    default=0,\n    type=int,\n    help=\"only compute evaluation metrics every 10 epochs\",\n)\nparser.add_argument(\"--opt\", default=\"adam\", help=\"adam or sgd\")\n##### Loss-specific Settings\nparser.add_argument(\"--loss\", default=\"smoothap\", type=str)\nparser.add_argument(\n    \"--sigmoid_temperature\",\n    default=0.01,\n    type=float,\n    help=\"SmoothAP: the temperature of the sigmoid used in SmoothAP loss\",\n)\n##### Evaluation Settings\nparser.add_argument(\n    \"--k_vals\",\n    nargs=\"+\",\n    default=[1, 2, 4, 8],\n    type=int,\n    help=\"Recall @ Values.\",\n)\nparser.add_argument(\n    \"--resume\",\n    default=\"\",\n    type=str,\n    help=\"path to checkpoint to load 
weights from (if empty then ImageNet pre-trained weights are loaded\",\n)\n##### Network parameters\nparser.add_argument(\n    \"--embed_dim\",\n    default=512,\n    type=int,\n    help=\"Embedding dimensionality of the network\",\n)\nparser.add_argument(\n    \"--arch\",\n    default=\"resnet50\",\n    type=str,\n    help=\"Network backend choice: resnet50, googlenet, BNinception\",\n)\nparser.add_argument(\n    \"--grad_measure\",\n    action=\"store_true\",\n    help=\"If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.\",\n)\nparser.add_argument(\n    \"--dist_measure\",\n    action=\"store_true\",\n    help=\"If added, the ratio between intra- and interclass distances is stored after each epoch.\",\n)\nparser.add_argument(\n    \"--not_pretrained\",\n    action=\"store_true\",\n    help=\"If added, the network will be trained WITHOUT ImageNet-pretrained weights.\",\n)\n##### Setup Parameters\nparser.add_argument(\"--gpu\", default=0, type=int, help=\"GPU-id for GPU to use.\")\nparser.add_argument(\n    \"--savename\",\n    default=\"\",\n    type=str,\n    help=\"Save folder name if any special information is to be included.\",\n)\n### Paths to datasets and storage folder\nparser.add_argument(\n    \"--source_path\",\n    default=\"/scratch/shared/beegfs/abrown/datasets\",\n    type=str,\n    help=\"Path to data\",\n)\nparser.add_argument(\n    \"--save_path\",\n    default=os.getcwd() + \"/Training_Results\",\n    type=str,\n    help=\"Where to save the checkpoints\",\n)\n### adational\nparser.add_argument(\"--trainset\", default=\"lin_train_set1.txt\", type=str)\nparser.add_argument(\"--all_trainset\", default=\"train_set1.txt\", type=str)\nparser.add_argument(\"--testset\", default=\"test_set1.txt\", type=str)\nparser.add_argument(\"--finetune\", default=\"true\", type=str)\nparser.add_argument(\"--cluster_path\", default=\"\", type=str)\nparser.add_argument(\"--get_features\", default=\"false\", 
type=str)\nparser.add_argument(\"--class_num\", default=948, type=int)\nparser.add_argument(\"--iter\", default=0, type=int)\nparser.add_argument(\n    \"--pretrained_weights\", default=\"\", type=str, help=\"pretrained weight path\"\n)\nparser.add_argument(\n    \"--use_bn_in_head\",\n    default=False,\n    type=aux.bool_flag,\n    help=\"Whether to use batch normalizations in projection head (Default: False)\",\n)\nparser.add_argument(\n    \"--checkpoint_key\",\n    default=\"teacher\",\n    type=str,\n    help='Key to use in the checkpoint (example: \"teacher\")',\n)\nparser.add_argument(\n    \"--drop_path_rate\", default=0.1, type=float, help=\"stochastic depth rate\"\n)\nparser.add_argument(\"--linsize\", default=29011, type=int, help=\"Lin data size.\")\nparser.add_argument(\"--uinsize\", default=18403, type=int, help=\"Uin data size.\")\nopt = parser.parse_args()\n\"\"\"============================================================================\"\"\"\nopt.source_path += \"/\" + opt.dataset\nopt.save_path += \"/\" + opt.dataset + \"_\" + str(opt.embed_dim)\n\nif opt.dataset == \"Inaturalist\":\n    opt.n_epochs = 90\n    opt.tau = [40, 70]\n    opt.k_vals = [1, 4, 16, 32]\n\nif opt.dataset == \"semi_fungi\":\n    opt.tau = [40, 70]\n    opt.k_vals = [1, 4, 16, 32]\n\nif opt.get_features == \"true\":\n    opt.get_features = True\nif opt.get_features == \"false\":\n    opt.get_features = False\n\nif opt.finetune == \"true\":\n    opt.finetune = True\nelif opt.finetune == \"false\":\n    opt.finetune = False\n\n\"\"\"===========================================================================\"\"\"\n################### TensorBoard Settings ##################\ntimestamp = datetime.datetime.now().strftime(r\"%Y-%m-%d_%H-%M-%S\")\nexp_name = aux.args2exp_name(opt)\nopt.save_name = f\"weights_{exp_name}\" + \"/\" + 
timestamp\nrandom.seed(opt.seed)\nnp.random.seed(opt.seed)\ntorch.manual_seed(opt.seed)\ntorch.cuda.manual_seed(opt.seed)\ntorch.cuda.manual_seed_all(opt.seed)\n\n\"\"\"============================================================================\"\"\"\n################### GPU SETTINGS ###########################\nos.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n# os.environ[\"CUDA_VISIBLE_DEVICES\"]= str(opt.gpu)\nprint(\"using #GPUs:\", torch.cuda.device_count())\n\n\"\"\"============================================================================\"\"\"\n##################### NETWORK SETUP ##################\n\nopt.device = torch.device(\"cuda\")\nmodel = netlib.networkselect(opt)\n\n# Push to Device\nif torch.cuda.device_count() > 1:\n    model = torch.nn.DataParallel(model)\n_ = model.to(opt.device)\n\n# Place trainable parameter in list of parameters to train:\n\nif \"fc_lr_mul\" in vars(opt).keys() and opt.fc_lr_mul != 0:\n    all_but_fc_params = list(\n        filter(lambda x: \"last_linear\" not in x[0], model.named_parameters())\n    )\n\n    for ind, param in enumerate(all_but_fc_params):\n        all_but_fc_params[ind] = param[1]\n\n    if torch.cuda.device_count() > 1:\n        fc_params = model.module.model.last_linear.parameters()\n    else:\n        fc_params = model.model.last_linear.parameters()\n\n    to_optim = [\n        {\"params\": all_but_fc_params, \"lr\": opt.lr, \"weight_decay\": opt.decay},\n        {\n            \"params\": fc_params,\n            \"lr\": opt.lr * opt.fc_lr_mul,\n            \"weight_decay\": opt.decay,\n        },\n    ]\nelse:\n    to_optim = [\n        {\"params\": model.parameters(), \"lr\": opt.lr, \"weight_decay\": opt.decay}\n    ]\n\"\"\"============================================================================\"\"\"\n#################### DATALOADER SETUPS ##################\n# Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders.\n# The 'testing'-dataloader corresponds to the 
validation set, and the 'evaluation'-dataloader\n# Is simply using the training set, however running under the same rules as 'testing' dataloader,\n# i.e. no shuffling and no random cropping.\ndataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt)\n# Because the number of supervised classes is dataset dependent, we store them after\n# initializing the dataloader\nopt.num_classes = len(dataloaders[\"training\"].dataset.avail_classes)\n\n\"\"\"============================================================================\"\"\"\n#################### CREATE LOGGING FILES ###############\n# Each dataset usually has a set of standard metrics to log. aux.metrics_to_examine()\n# returns a dict which lists metrics to log for training ('train') and validation/testing ('val')\n\nmetrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)\n# example output: {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],\n#                  'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}\n\n# Using the provided metrics of interest, we generate a LOGGER instance.\n# Note that 'start_new' denotes that a new folder should be made in which everything will be stored.\n# This includes network weights as well.\n# If graphviz is installed on the system, a computational graph of the underlying\n# network will be made as well.\n\n\"\"\"============================================================================\"\"\"\n#################### LOSS SETUP ####################\n# Depending on opt.loss and opt.sampling, the respective criterion is returned,\n# and if the loss has trainable parameters, to_optim is appended.\nLOG = aux.LOGGER(opt, metrics_to_log, name=\"Base\", start_new=True)\ncriterion, to_optim = losses.loss_select(opt.loss, opt, to_optim)\n_ = criterion.to(opt.device)\n\n\"\"\"============================================================================\"\"\"\n##################### OPTIONAL EVALUATIONS 
#####################\n# Store the averaged gradients returned from the embedding to the last conv. layer.\nif opt.grad_measure:\n    grad_measure = eval.GradientMeasure(opt, name=\"baseline\")\n# Store the relative distances between average intra- and inter-class distance.\nif opt.dist_measure:\n    # Add a distance measure for training distance ratios\n    distance_measure = eval.DistanceMeasure(\n        dataloaders[\"evaluation\"], opt, name=\"Train\", update_epochs=1\n    )\n    # #If uncommented: Do the same for the test set\n    # distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1)\n\n\"\"\"============================================================================\"\"\"\n#################### OPTIM SETUP ####################\n# As optimizer, Adam with standard parameters is used.\nif opt.opt == \"adam\":\n    optimizer = torch.optim.Adam(to_optim)\nelif opt.opt == \"sgd\":\n    optimizer = torch.optim.SGD(to_optim)\nelse:\n    raise Exception(\"unknown optimiser\")\n# for the SOA measures in the paper - need to use SGD and 0.05 learning rate\n# optimizer    = torch.optim.Adam(to_optim)\n# optimizer    = torch.optim.SGD(to_optim)\nif opt.scheduler == \"exp\":\n    scheduler = torch.optim.lr_scheduler.ExponentialLR(\n        optimizer, gamma=opt.gamma\n    )\nelif opt.scheduler == \"step\":\n    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n        optimizer, milestones=opt.tau, gamma=opt.gamma\n    )\nelif opt.scheduler == \"none\":\n    print(\"No scheduling used!\")\nelse:\n    raise Exception(\"No scheduling option for input: {}\".format(opt.scheduler))\n\n\ndef same_model(model1, model2):\n    for p1, p2 in zip(model1.parameters(), model2.parameters()):\n        if p1.data.ne(p2.data).sum() > 0:\n            return False\n    return True\n\n\n\"\"\"============================================================================\"\"\"\n\"\"\"================================ TESTING 
===================================\"\"\"\n\"\"\"============================================================================\"\"\"\n################### SCRIPT MAIN ##########################\nprint(\"\\n-----\\n\")\n# Compute Evaluation metrics, print them and store in LOG.\n\n_ = model.eval()\naux.vis(\n    model,\n    dataloaders[\"training\"],\n    opt.device,\n    split=\"T_train_iter\" + str(opt.iter) + \"_\" + str(opt.loss),\n    opt=opt,\n)\naux.vis(\n    model,\n    dataloaders[\"testing\"],\n    opt.device,\n    split=\"all_train_iter\" + str(opt.iter) + \"_\" + str(opt.loss),\n    opt=opt,\n)\naux.vis(\n    model,\n    dataloaders[\"eval\"],\n    opt.device,\n    split=\"test_iter\" + str(opt.iter) + \"_\" + str(opt.loss),\n    opt=opt,\n)\n# Update the Metric Plot and save it.\nprint(\"\\n-----\\n\")\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/Smooth_AP/src/losses.py",
    "content": "# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines\n\n###################### LIBRARIES #################################################\nimport warnings\n\nwarnings.filterwarnings(\"ignore\")\n\nimport faiss\nimport numpy as np\nimport torch\nfrom scipy import sparse\n\n\"\"\"=================================================================================================\"\"\"\n\n\n############ LOSS SELECTION FUNCTION #####################\ndef loss_select(loss, opt, to_optim):\n    \"\"\"\n    Selection function which returns the respective criterion while appending to list of trainable parameters if required.\n\n    Args:\n        loss:     str, name of loss function to return.\n        opt:      argparse.Namespace, contains all training-specific parameters.\n        to_optim: list of trainable parameters. Is extend if loss function contains those as well.\n    Returns:\n        criterion (torch.nn.Module inherited), to_optim (optionally appended)\n    \"\"\"\n    if loss == \"smoothap\":\n        loss_params = {\n            \"anneal\": opt.sigmoid_temperature,\n            \"batch_size\": opt.bs,\n            \"num_id\": int(opt.bs / opt.samples_per_class),\n            \"feat_dims\": opt.embed_dim,\n        }\n        criterion = SmoothAP(**loss_params)\n    else:\n        raise Exception(\"Loss {} not available!\".format(loss))\n\n    return criterion, to_optim\n\n\n\"\"\"==============================================Smooth-AP========================================\"\"\"\n\n\ndef sigmoid(tensor, temp=1.0):\n    \"\"\"temperature controlled sigmoid\n    takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp\n    \"\"\"\n    exponent = -tensor / temp\n    # clamp the input tensor for stability\n    exponent = torch.clamp(exponent, min=-50, max=50)\n    y = 1.0 / (1.0 + torch.exp(exponent))\n    return y\n\n\ndef compute_aff(x):\n    \"\"\"computes the 
affinity matrix between an input vector and itself\"\"\"\n    return torch.mm(x, x.t())\n\n\nclass BinarizedF(torch.autograd.Function):\n    def forward(self, inp):\n        self.save_for_backward(inp)\n        a = torch.ones_like(inp)\n        b = torch.zeros_like(inp)\n        output = torch.where(inp > 0, a, b)\n        return output\n\n    def backward(self, output_grad):\n        (inp,) = self.saved_tensors\n        input_abs = torch.abs(inp)\n        ones = torch.ones_like(inp)\n        zeros = torch.zeros_like(inp)\n        input_grad = torch.where(input_abs > 0, ones, zeros)\n        return input_grad\n\n\nclass BinarizedModule(torch.nn.Module):\n    def __init__(self):\n        super(BinarizedModule, self).__init__()\n        self.BF = BinarizedF()\n\n    def forward(self, inp):\n        output = self.BF(inp)\n        return output\n\n\nclass SmoothAP(torch.nn.Module):\n    \"\"\"PyTorch implementation of the Smooth-AP loss.\n    implementation of the Smooth-AP loss. Takes as input the mini-batch of CNN-produced feature embeddings and returns\n    the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must\n    have the same number of instances represented in the mini-batch and must be ordered sequentially by class.\n    e.g. the labels for a mini-batch with batch size 9, and 3 represented classes (A,B,C) must look like:\n        labels = ( A, A, A, B, B, B, C, C, C)\n    (the order of the classes however does not matter)\n    For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and the rest of the\n    mini-batch is used as the retrieval set. The positive set is formed of the other instances in the batch from the\n    same class. The loss returns the average Smooth-AP across all instances in the mini-batch.\n    Args:\n        anneal : float\n            the temperature of the sigmoid that is used to smooth the ranking function. 
A low value of the temperature\n            results in a steep sigmoid, that tightly approximates the heaviside step function in the ranking function.\n        batch_size : int\n            the batch size being used during training.\n        num_id : int\n            the number of different classes that are represented in the batch.\n        feat_dims : int\n            the dimension of the input feature embeddings\n    Shape:\n        - Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor)\n        - Output: scalar\n    Examples::\n        >>> loss = SmoothAP(0.01, 60, 6, 256)\n        >>> input = torch.randn(60, 256, requires_grad=True).cuda()\n        >>> output = loss(input)\n        >>> output.backward()\n    \"\"\"\n\n    def __init__(self, anneal, batch_size, num_id, feat_dims):\n        \"\"\"\n        Parameters\n        ----------\n        anneal : float\n            the temperature of the sigmoid that is used to smooth the ranking function\n        batch_size : int\n            the batch size being used\n        num_id : int\n            the number of different classes that are represented in the batch\n        feat_dims : int\n            the dimension of the input feature embeddings\n        \"\"\"\n        super(SmoothAP, self).__init__()\n\n        assert batch_size % num_id == 0\n\n        self.anneal = anneal\n        self.batch_size = batch_size\n        self.num_id = num_id\n        self.feat_dims = feat_dims\n\n    def forward(self, preds):\n        \"\"\"Forward pass for all input predictions: preds - (batch_size x feat_dims)\"\"\"\n\n        # ------ differentiable ranking of all retrieval set ------\n        # compute the mask which ignores the relevance score of the query to itself\n        mask = 1.0 - torch.eye(self.batch_size)\n        mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)\n        # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors\n        sim_all = 
compute_aff(preds)\n        sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)\n        # compute the difference matrix\n        sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)\n        # pass through the sigmoid\n        sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask.cuda()\n        # compute the rankings\n        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1\n\n        # ------ differentiable ranking of only positive set in retrieval set ------\n        # compute the mask which only gives non-zero weights to the positive set\n        xs = preds.view(\n            self.num_id, int(self.batch_size / self.num_id), self.feat_dims\n        )\n        pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id))\n        pos_mask = (\n            pos_mask.unsqueeze(dim=0)\n            .unsqueeze(dim=0)\n            .repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1)\n        )\n        # compute the relevance scores\n        sim_pos = torch.bmm(xs, xs.permute(0, 2, 1))\n        sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(\n            1, 1, int(self.batch_size / self.num_id), 1\n        )\n        # compute the difference matrix\n        sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)\n        # pass through the sigmoid\n        sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal) * pos_mask.cuda()\n        # compute the rankings of the positive set\n        sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1\n\n        # sum the values of the Smooth-AP for all instances in the mini-batch\n        ap = torch.zeros(1).cuda()\n        group = int(self.batch_size / self.num_id)\n        for ind in range(self.num_id):\n            pos_divide = torch.sum(\n                sim_pos_rk[ind]\n                / (\n                    sim_all_rk[\n                        (ind * group) : ((ind + 1) * group),\n                        (ind * group) : ((ind + 1) * group),\n                    ]\n                )\n     
       )\n            ap = ap + ((pos_divide / group) / self.batch_size)\n        return 1 - ap\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/Smooth_AP/src/main.py",
    "content": "# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines\n\n\"\"\"to do:\n\nclean all of the files - particularly the main.py and also the losses and dataset files and the file for doing the dataloading\n\n-- fast loading etc\n\nneed to change all of the copyrights at the top of all of the files\n\n\"\"\"\n\n#################### LIBRARIES ########################\nimport warnings\n\nwarnings.filterwarnings(\"ignore\")\n\nimport argparse\nimport datetime\nimport os\nimport random\n\nimport matplotlib\nimport numpy as np\n\nos.chdir(os.path.dirname(os.path.realpath(__file__)))\nfrom pathlib import Path\n\nmatplotlib.use(\"agg\")\nimport auxiliaries as aux\nimport datasets as data\nimport evaluate as eval\nimport losses as losses\nimport netlib as netlib\nimport torch.multiprocessing\nfrom tensorboardX import SummaryWriter\nfrom tqdm import tqdm\n\ntorch.multiprocessing.set_sharing_strategy(\"file_system\")\n\n################### INPUT ARGUMENTS ###################\nparser = argparse.ArgumentParser()\n####### Main Parameter: Dataset to use for Training\nparser.add_argument(\n    \"--dataset\",\n    default=\"vehicle_id\",\n    type=str,\n    help=\"Dataset to use.\",\n    choices=[\"SoftInaturalist\", \"Inaturalist\", \"vehicle_id\", \"semi_fungi\"],\n)\n### General Training Parameters\nparser.add_argument(\n    \"--lr\",\n    default=0.00001,\n    type=float,\n    help=\"Learning Rate for network parameters.\",\n)\nparser.add_argument(\n    \"--fc_lr_mul\",\n    default=5,\n    type=float,\n    help=\"OPTIONAL: Multiply the embedding layer learning rate by this value. 
If set to 0, the embedding layer shares the same learning rate.\",\n)\nparser.add_argument(\n    \"--n_epochs\", default=400, type=int, help=\"Number of training epochs.\"\n)\nparser.add_argument(\n    \"--kernels\",\n    default=8,\n    type=int,\n    help=\"Number of workers for pytorch dataloader.\",\n)\nparser.add_argument(\n    \"--bs\", default=112, type=int, help=\"Mini-Batchsize to use.\"\n)\nparser.add_argument(\n    \"--samples_per_class\",\n    default=4,\n    type=int,\n    help=\"Number of samples in one class drawn before choosing the next class\",\n)\nparser.add_argument(\n    \"--seed\", default=1, type=int, help=\"Random seed for reproducibility.\"\n)\nparser.add_argument(\n    \"--scheduler\",\n    default=\"step\",\n    type=str,\n    help=\"Type of learning rate scheduling. Currently: step & exp.\",\n)\nparser.add_argument(\n    \"--gamma\",\n    default=0.3,\n    type=float,\n    help=\"Learning rate reduction after tau epochs.\",\n)\nparser.add_argument(\n    \"--decay\", default=0.0004, type=float, help=\"Weight decay for optimizer.\"\n)\nparser.add_argument(\n    \"--tau\",\n    default=[200, 300],\n    nargs=\"+\",\n    type=int,\n    help=\"Stepsize(s) before reducing learning rate.\",\n)\nparser.add_argument(\n    \"--infrequent_eval\",\n    default=0,\n    type=int,\n    help=\"only compute evaluation metrics every 10 epochs\",\n)\nparser.add_argument(\"--opt\", default=\"adam\", help=\"adam or sgd\")\n##### Loss-specific Settings\nparser.add_argument(\"--loss\", default=\"smoothap\", type=str)\nparser.add_argument(\n    \"--sigmoid_temperature\",\n    default=0.01,\n    type=float,\n    help=\"SmoothAP: the temperature of the sigmoid used in SmoothAP loss\",\n)\n##### Evaluation Settings\nparser.add_argument(\n    \"--k_vals\",\n    nargs=\"+\",\n    default=[1, 2, 4, 8],\n    type=int,\n    help=\"Recall @ Values.\",\n)\nparser.add_argument(\n    \"--resume\",\n    default=\"\",\n    type=str,\n    help=\"path to checkpoint to load 
weights from (if empty then ImageNet pre-trained weights are loaded\",\n)\n##### Network parameters\nparser.add_argument(\n    \"--embed_dim\",\n    default=512,\n    type=int,\n    help=\"Embedding dimensionality of the network\",\n)\nparser.add_argument(\n    \"--arch\",\n    default=\"resnet50\",\n    type=str,\n    help=\"Network backend choice: resnet50\",\n)\nparser.add_argument(\n    \"--pretrained_weights\", default=\"\", type=str, help=\"pretrained weight path\"\n)\nparser.add_argument(\n    \"--use_bn_in_head\",\n    default=False,\n    type=aux.bool_flag,\n    help=\"Whether to use batch normalizations in projection head (Default: False)\",\n)\nparser.add_argument(\n    \"--checkpoint_key\",\n    default=\"teacher\",\n    type=str,\n    help='Key to use in the checkpoint (example: \"teacher\")',\n)\nparser.add_argument(\n    \"--drop_path_rate\", default=0.1, type=float, help=\"stochastic depth rate\"\n)\nparser.add_argument(\n    \"--grad_measure\",\n    action=\"store_true\",\n    help=\"If added, gradients passed from embedding layer to the last conv-layer are stored in each iteration.\",\n)\nparser.add_argument(\n    \"--dist_measure\",\n    action=\"store_true\",\n    help=\"If added, the ratio between intra- and interclass distances is stored after each epoch.\",\n)\nparser.add_argument(\n    \"--not_pretrained\",\n    action=\"store_true\",\n    help=\"If added, the network will be trained WITHOUT ImageNet-pretrained weights.\",\n)\n##### Setup Parameters\nparser.add_argument(\"--gpu\", default=0, type=int, help=\"GPU-id for GPU to use.\")\nparser.add_argument(\n    \"--savename\",\n    default=\"\",\n    type=str,\n    help=\"Save folder name if any special information is to be included.\",\n)\n### Paths to datasets and storage folder\nparser.add_argument(\n    \"--source_path\",\n    default=\"/scratch/shared/beegfs/abrown/datasets\",\n    type=str,\n    help=\"Path to data\",\n)\nparser.add_argument(\n    \"--save_path\",\n    
default=os.getcwd() + \"/Training_Results\",\n    type=str,\n    help=\"Where to save the checkpoints\",\n)\n### additional parameters\nparser.add_argument(\"--trainset\", default=\"lin_train_set1.txt\", type=str)\nparser.add_argument(\"--testset\", default=\"Inaturalist_test_set1.txt\", type=str)\nparser.add_argument(\"--cluster_path\", default=\"\", type=str)\nparser.add_argument(\"--finetune\", default=\"false\", type=str)\nparser.add_argument(\"--class_num\", default=948, type=int)\nparser.add_argument(\"--get_features\", default=\"false\", type=str)\nparser.add_argument(\"--linsize\", default=29011, type=int, help=\"Lin data size.\")\nparser.add_argument(\"--uinsize\", default=18403, type=int, help=\"Uin data size.\")\nparser.add_argument(\"--iter\", default=0, type=int)\n\nopt = parser.parse_args()\n\"\"\"============================================================================\"\"\"\nif opt.dataset == \"SoftInaturalist\":\n    opt.source_path += \"/Inaturalist\"\n    opt.save_path += \"/Inaturalist\" + \"_\" + str(opt.embed_dim)\nelse:\n    opt.source_path += \"/\" + opt.dataset\n    opt.save_path += \"/\" + opt.dataset + \"_\" + str(opt.embed_dim)\n\nif opt.dataset == \"Inaturalist\":\n    # opt.n_epochs = 90\n    opt.tau = [40, 70]\n    opt.k_vals = [1, 4, 16, 32]\n\nif opt.dataset == \"SoftInaturalist\":\n    # opt.n_epochs = 90\n    opt.tau = [40, 70]\n    opt.k_vals = [1, 4, 16, 32]\n\nif opt.dataset == \"vehicle_id\":\n    opt.k_vals = [1, 5]\n\nif opt.dataset == \"semi_fungi\":\n    opt.tau = [40, 70]\n    opt.k_vals = [1, 4, 16, 32]\n\nif opt.finetune == \"true\":\n    opt.finetune = True\nelif opt.finetune == \"false\":\n    opt.finetune = False\n\nif opt.get_features == \"true\":\n    opt.get_features = True\nelif opt.get_features == \"false\":\n    opt.get_features = False\n\n\"\"\"===========================================================================\"\"\"\n################### TensorBoard Settings ##################\ntimestamp = 
datetime.datetime.now().strftime(r\"%Y-%m-%d_%H-%M-%S\")\nexp_name = aux.args2exp_name(opt)\nopt.save_name = f\"weights_{exp_name}\" + \"/\" + timestamp\nrandom.seed(opt.seed)\nnp.random.seed(opt.seed)\ntorch.manual_seed(opt.seed)\ntorch.cuda.manual_seed(opt.seed)\ntorch.cuda.manual_seed_all(opt.seed)\ntensorboard_path = Path(f\"logs/logs_{exp_name}\") / timestamp\n\ntensorboard_path.parent.mkdir(exist_ok=True, parents=True)\nglobal writer\nwriter = SummaryWriter(tensorboard_path)\n\"\"\"============================================================================\"\"\"\n################### GPU SETTINGS ###########################\nos.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n# os.environ[\"CUDA_VISIBLE_DEVICES\"]= str(opt.gpu)\nprint(\"using #GPUs:\", torch.cuda.device_count())\n\n\"\"\"============================================================================\"\"\"\n##################### NETWORK SETUP ##################\n\nopt.device = torch.device(\"cuda\")\nmodel = netlib.networkselect(opt)\n\n# Push to Device\nif torch.cuda.device_count() > 1:\n    model = torch.nn.DataParallel(model)\n_ = model.to(opt.device)\n# Place trainable parameter in list of parameters to train:\n\nif \"fc_lr_mul\" in vars(opt).keys() and opt.fc_lr_mul != 0:\n    all_but_fc_params = list(\n        filter(lambda x: \"last_linear\" not in x[0], model.named_parameters())\n    )\n\n    for ind, param in enumerate(all_but_fc_params):\n        all_but_fc_params[ind] = param[1]\n\n    if torch.cuda.device_count() > 1:\n        fc_params = model.module.model.last_linear.parameters()\n    else:\n        fc_params = model.model.last_linear.parameters()\n\n    to_optim = [\n        {\"params\": all_but_fc_params, \"lr\": opt.lr, \"weight_decay\": opt.decay},\n        {\n            \"params\": fc_params,\n            \"lr\": opt.lr * opt.fc_lr_mul,\n            \"weight_decay\": opt.decay,\n        },\n    ]\nelse:\n    to_optim = [\n        {\"params\": model.parameters(), \"lr\": 
opt.lr, \"weight_decay\": opt.decay}\n    ]\n\"\"\"============================================================================\"\"\"\n#################### DATALOADER SETUPS ##################\n# Returns a dictionary containing 'training', 'testing', and 'evaluation' dataloaders.\n# The 'testing'-dataloader corresponds to the validation set, and the 'evaluation'-dataloader\n# Is simply using the training set, however running under the same rules as 'testing' dataloader,\n# i.e. no shuffling and no random cropping.\ndataloaders = data.give_dataloaders(opt.dataset, opt.trainset, opt.testset, opt)\n# Because the number of supervised classes is dataset dependent, we store them after\n# initializing the dataloader\nopt.num_classes = len(dataloaders[\"training\"].dataset.avail_classes)\n\n\"\"\"============================================================================\"\"\"\n#################### CREATE LOGGING FILES ###############\n# Each dataset usually has a set of standard metrics to log. 
aux.metrics_to_examine()\n# returns a dict which lists metrics to log for training ('train') and validation/testing ('val')\n\nmetrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)\n# example output: {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],\n#                  'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}\n\n# Using the provided metrics of interest, we generate a LOGGER instance.\n# Note that 'start_new' denotes that a new folder should be made in which everything will be stored.\n# This includes network weights as well.\nLOG = aux.LOGGER(opt, metrics_to_log, name=\"Base\", start_new=True)\n# If graphviz is installed on the system, a computational graph of the underlying\n# network will be made as well.\n\n\"\"\"============================================================================\"\"\"\n#################### LOSS SETUP ####################\n# Depending on opt.loss and opt.sampling, the respective criterion is returned,\n# and if the loss has trainable parameters, to_optim is appended.\ncriterion, to_optim = losses.loss_select(opt.loss, opt, to_optim)\n_ = criterion.to(opt.device)\n\n\"\"\"============================================================================\"\"\"\n##################### OPTIONAL EVALUATIONS #####################\n# Store the averaged gradients returned from the embedding to the last conv. 
layer.\nif opt.grad_measure:\n    grad_measure = eval.GradientMeasure(opt, name=\"baseline\")\n# Store the relative distances between average intra- and inter-class distance.\nif opt.dist_measure:\n    # Add a distance measure for training distance ratios\n    distance_measure = eval.DistanceMeasure(\n        dataloaders[\"evaluation\"], opt, name=\"Train\", update_epochs=1\n    )\n    # #If uncommented: Do the same for the test set\n    # distance_measure_test = eval.DistanceMeasure(dataloaders['testing'], opt, name='Train', update_epochs=1)\n\n\"\"\"============================================================================\"\"\"\n#################### OPTIM SETUP ####################\n# As optimizer, Adam with standard parameters is used.\nif opt.opt == \"adam\":\n    optimizer = torch.optim.Adam(to_optim)\nelif opt.opt == \"sgd\":\n    optimizer = torch.optim.SGD(to_optim)\nelse:\n    raise Exception(\"unknown optimiser\")\n# for the SOA measures in the paper - need to use SGD and 0.05 learning rate\n# optimizer    = torch.optim.Adam(to_optim)\n# optimizer    = torch.optim.SGD(to_optim)\nif opt.scheduler == \"exp\":\n    scheduler = torch.optim.lr_scheduler.ExponentialLR(\n        optimizer, gamma=opt.gamma\n    )\nelif opt.scheduler == \"step\":\n    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n        optimizer, milestones=opt.tau, gamma=opt.gamma\n    )\nelif opt.scheduler == \"none\":\n    print(\"No scheduling used!\")\nelse:\n    raise Exception(\"No scheduling option for input: {}\".format(opt.scheduler))\n\n\ndef same_model(model1, model2):\n    for p1, p2 in zip(model1.parameters(), model2.parameters()):\n        if p1.data.ne(p2.data).sum() > 0:\n            return False\n    return True\n\n\n\"\"\"============================================================================\"\"\"\n\n\n#################### TRAINER FUNCTION ############################\ndef train_one_epoch(train_dataloader, model, optimizer, criterion, opt, epoch):\n    \"\"\"\n   
 This function is called every epoch to perform training of the network over one full\n    (randomized) iteration of the dataset.\n\n    Args:\n        train_dataloader: torch.utils.data.DataLoader, returns (augmented) training data.\n        model:            Network to train.\n        optimizer:        Optimizer to use for training.\n        criterion:        criterion to use during training.\n        opt:              argparse.Namespace, Contains all relevant parameters.\n        epoch:            int, Current epoch.\n\n    Returns:\n        Nothing!\n    \"\"\"\n\n    loss_collect = []\n\n    start = time.time()\n    data_iterator = tqdm(\n        train_dataloader, desc=\"Epoch {} Training...\".format(epoch)\n    )\n\n    for i, (class_labels, input) in enumerate(data_iterator):\n        # Compute embeddings for input batch\n        features = model(input.to(opt.device))\n\n        # Compute loss.\n        if opt.loss != \"smoothap\":\n            loss = criterion(features, class_labels)\n        else:\n            loss = criterion(features)\n\n        # Ensure gradients are set to zero at beginning\n        optimizer.zero_grad()\n        # Compute gradient\n        loss.backward()\n\n        train_dataloader.dataset.classes_visited = []\n\n        if opt.grad_measure:\n            # If desired, save computed gradients.\n            grad_measure.include(model.model.last_linear)\n\n        # Update weights using comp. 
gradients.\n        optimizer.step()\n\n        # Store loss per iteration.\n        loss_collect.append(loss.item())\n        if i == len(train_dataloader) - 1:\n            data_iterator.set_description(\n                \"Epoch (Train) {0}: Mean Loss [{1:.4f}]\".format(\n                    epoch, np.mean(loss_collect)\n                )\n            )\n\n    # Save metrics\n    LOG.log(\n        \"train\",\n        LOG.metrics_to_log[\"train\"],\n        [epoch, np.round(time.time() - start, 4), np.mean(loss_collect)],\n    )\n    writer.add_scalar(\"global/training_loss\", np.mean(loss_collect), epoch)\n    if opt.grad_measure:\n        # Dump stored gradients to Pickle-File.\n        grad_measure.dump(epoch)\n\n\n\"\"\"============================================================================\"\"\"\n\"\"\"========================== MAIN TRAINING PART ==============================\"\"\"\n\"\"\"============================================================================\"\"\"\n################### SCRIPT MAIN ##########################\nprint(\"\\n-----\\n\")\n# Each dataset requires slightly different dataloaders.\n\nif opt.dataset == \"SoftInaturalist\" or \"Inaturalist\" or \"semi_fungi\":\n    eval_params = {\n        \"dataloader\": dataloaders[\"testing\"],\n        \"model\": model,\n        \"opt\": opt,\n        \"epoch\": 0,\n    }\n\nelif opt.dataset == \"vehicle_id\":\n    eval_params = {\n        \"dataloaders\": [\n            dataloaders[\"testing_set1\"],\n            dataloaders[\"testing_set2\"],\n            dataloaders[\"testing_set3\"],\n        ],\n        \"model\": model,\n        \"opt\": opt,\n        \"epoch\": 0,\n    }\n# Compute Evaluation metrics, print them and store in LOG.\nprint(\"epochs -> \" + str(opt.n_epochs))\nimport time\n\nfor epoch in range(opt.n_epochs):\n    ### Print current learning rates for all parameters\n    if opt.scheduler != \"none\":\n        print(\n            \"Running with learning rates 
{}...\".format(\n                \" | \".join(\"{}\".format(x) for x in scheduler.get_lr())\n            )\n        )\n\n    ### Train one epoch\n    _ = model.train()\n\n    train_one_epoch(\n        dataloaders[\"training\"], model, optimizer, criterion, opt, epoch\n    )\n\n    dataloaders[\"training\"].dataset.reshuffle()\n    ### Evaluate\n    _ = model.eval()\n    # Each dataset requires slightly different dataloaders.\n    if opt.dataset == \"Inaturalist\":\n        eval_params = {\n            \"dataloader\": dataloaders[\"evaluation\"],\n            \"model\": model,\n            \"opt\": opt,\n            \"epoch\": epoch,\n        }\n    elif opt.dataset == \"vehicle_id\":\n        eval_params = {\n            \"dataloaders\": [\n                dataloaders[\"testing_set1\"],\n                dataloaders[\"testing_set2\"],\n                dataloaders[\"testing_set3\"],\n            ],\n            \"model\": model,\n            \"opt\": opt,\n            \"epoch\": epoch,\n        }\n    elif opt.dataset == \"semi_fungi\":\n        eval_params = {\n            \"dataloader\": dataloaders[\"testing\"],\n            \"model\": model,\n            \"opt\": opt,\n            \"epoch\": epoch,\n        }\n\n    # Compute Evaluation metrics, print them and store in LOG.\n    if opt.infrequent_eval == 1:\n        epoch_freq = 5\n    else:\n        epoch_freq = 1\n\n    if not opt.dataset == \"vehicle_id\":\n        if epoch % epoch_freq == 0:\n            results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params)\n            writer.add_scalar(\"global/recall1\", results[0][0], epoch + 1)\n            writer.add_scalar(\"global/recall2\", results[0][1], epoch + 1)\n            writer.add_scalar(\"global/recall3\", results[0][2], epoch + 1)\n            writer.add_scalar(\"global/recall4\", results[0][3], epoch + 1)\n            writer.add_scalar(\"global/NMI\", results[1], epoch + 1)\n            writer.add_scalar(\"global/F1\", results[2], epoch + 
1)\n\n    else:\n        results = eval.evaluate(opt.dataset, LOG, save=True, **eval_params)\n        writer.add_scalar(\"global/recall1\", results[2], epoch + 1)\n        writer.add_scalar(\n            \"global/recall2\", results[3], epoch + 1\n        )  # writer.add_scalar('global/recall3',results[0][2],0)\n        writer.add_scalar(\"global/recall3\", results[6], epoch + 1)\n        writer.add_scalar(\"global/recall4\", results[7], epoch + 1)\n        writer.add_scalar(\"global/recall5\", results[10], epoch + 1)\n        writer.add_scalar(\"global/recall6\", results[11], epoch + 1)\n    # Update the Metric Plot and save it.\n    # LOG.update_info_plot()\n    # (optional) compute ratio of intra- to interdistances.\n    if opt.dist_measure:\n        distance_measure.measure(model, epoch)\n        # distance_measure_test.measure(model, epoch)\n\n    ### Learning Rate Scheduling Step\n    if opt.scheduler != \"none\":\n        scheduler.step()\n\n    print(\"\\n-----\\n\")\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/Smooth_AP/src/netlib.py",
    "content": "# repo originally forked from https://github.com/Confusezius/Deep-Metric-Learning-Baselines\n\n############################ LIBRARIES ######################################\nimport os\nfrom collections import OrderedDict\n\nimport auxiliaries as aux\nimport pretrainedmodels as ptm\nimport torch\nimport torch.nn as nn\n\n\"\"\"=============================================================\"\"\"\n\n\ndef initialize_weights(model):\n    \"\"\"\n    Function to initialize network weights.\n    NOTE: NOT USED IN MAIN SCRIPT.\n\n    Args:\n        model: PyTorch Network\n    Returns:\n        Nothing!\n    \"\"\"\n    for idx, module in enumerate(model.modules()):\n        if isinstance(module, nn.Conv2d):\n            nn.init.kaiming_normal_(\n                module.weight, mode=\"fan_out\", nonlinearity=\"relu\"\n            )\n        elif isinstance(module, nn.BatchNorm2d):\n            nn.init.constant_(module.weight, 1)\n            nn.init.constant_(module.bias, 0)\n        elif isinstance(module, nn.Linear):\n            module.weight.data.normal_(0, 0.01)\n            module.bias.data.zero_()\n\n\n\"\"\"==================================================================================================================================\"\"\"\n\n\n### ATTRIBUTE CHANGE HELPER\ndef rename_attr(model, attr, name):\n    \"\"\"\n    Rename attribute in a class. 
Simply helper function.\n\n    Args:\n        model:  General Class for which attributes should be renamed.\n        attr:   str, Name of target attribute.\n        name:   str, New attribute name.\n    \"\"\"\n    setattr(model, name, getattr(model, attr))\n    delattr(model, attr)\n\n\n\"\"\"==================================================================================================================================\"\"\"\n\n\n### NETWORK SELECTION FUNCTION\ndef networkselect(opt):\n    \"\"\"\n    Selection function for available networks.\n\n    Args:\n        opt: argparse.Namespace, contains all training-specific training parameters.\n    Returns:\n        Network of choice\n    \"\"\"\n    if opt.arch == \"resnet50\":\n        network = ResNet50(opt)\n    else:\n        raise Exception(\"Network {} not available!\".format(opt.arch))\n\n    if opt.resume:\n        weights = torch.load(\n            os.path.join(opt.save_path, opt.resume), weights_only=False\n        )\n        weights_state_dict = weights[\"state_dict\"]\n\n        if torch.cuda.device_count() > 1:\n            encoder_state_dict = OrderedDict()\n            for k, v in weights_state_dict.items():\n                k = k.replace(\"module.\", \"\")\n                encoder_state_dict[k] = v\n\n            network.load_state_dict(encoder_state_dict)\n        else:\n            network.load_state_dict(weights_state_dict)\n\n    # print(\"=================== network =======================\")\n    # for parameter in network.parameters():\n    #     parameter.requires_grad = False\n    # for parameter in network.layer_blocks[-1].parameters():\n    #     parameter.requires_grad = True\n\n    return network\n\n\n\"\"\"=============================================================\"\"\"\n\n\nclass ResNet50(nn.Module):\n    \"\"\"\n    Container for ResNet50 s.t. 
it can be used for metric learning.\n    The Network has been broken down to allow for higher modularity, if one wishes\n    to target specific layers/blocks directly.\n    \"\"\"\n\n    def __init__(self, opt, list_style=False, no_norm=False):\n        super(ResNet50, self).__init__()\n\n        self.pars = opt\n\n        if not opt.not_pretrained:\n            print(\"Getting pretrained weights...\")\n            self.model = ptm.__dict__[\"resnet50\"](\n                num_classes=1000, pretrained=\"imagenet\"\n            )\n            print(\"Done.\")\n        else:\n            print(\"Not utilizing pretrained weights!\")\n            self.model = ptm.__dict__[\"resnet50\"](\n                num_classes=1000, pretrained=None\n            )\n        for module in filter(\n            lambda m: type(m) == nn.BatchNorm2d, self.model.modules()\n        ):\n            module.eval()\n            module.train = lambda _: None\n\n        if opt.embed_dim != 2048:\n            self.model.last_linear = torch.nn.Linear(\n                self.model.last_linear.in_features, opt.embed_dim\n            )\n\n        self.layer_blocks = nn.ModuleList(\n            [\n                self.model.layer1,\n                self.model.layer2,\n                self.model.layer3,\n                self.model.layer4,\n            ]\n        )\n        self.loss = opt.loss\n        self.feature = True\n\n    def forward(self, x, feature=False, is_init_cluster_generation=False):\n        x = self.model.maxpool(\n            self.model.relu(self.model.bn1(self.model.conv1(x)))\n        )\n\n        for layerblock in self.layer_blocks:\n            x = layerblock(x)\n\n        x = self.model.avgpool(x)\n        x = x.view(x.size(0), -1)\n\n        if self.pars.embed_dim != 2048:\n            mod_x = self.model.last_linear(x)\n        else:\n            mod_x = x\n\n        feat = torch.nn.functional.normalize(mod_x, dim=-1)\n\n        if feature or self.loss == \"smoothap\":\n            
return feat\n        else:\n            pred = self.linear(feat)\n            return pred\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/__init__.py",
    "content": ""
  },
  {
    "path": "examples/pytorch/hilander/PSS/test.sh",
    "content": "python Smooth_AP/src/evaluate_model.py \\\n--dataset Inaturalist \\\n--bs 384 \\\n--source_path ~/code/Smooth_AP/data/ --embed_dim 128 \\\n--resume $CHECKPOINT_PATH \\\n--class_num 948 --loss smoothap \\\n--trainset lin_train_set1.txt \\\n--testset Inaturalist_test_set1.txt \\\n--linsize 29011 --uinsize 18403"
  },
  {
    "path": "examples/pytorch/hilander/PSS/test_subg_inat.py",
    "content": "import argparse, os, pickle, time\nimport random\nimport sys\n\nsys.path.append(\"..\")\n\nimport shutil\n\nimport dgl\nimport numpy as np\nimport seaborn\nimport torch\nimport torch.optim as optim\nfrom dataset import LanderDataset\n\nfrom matplotlib import pyplot as plt\n\nfrom models import LANDER\nfrom utils import build_next_level, decode, evaluation, stop_iterating\nfrom utils.deduce import get_edge_dist\n\nSTATISTIC = False\n\n###########\n# ArgParser\nparser = argparse.ArgumentParser()\n\n# Dataset\nparser.add_argument(\"--data_path\", type=str, required=True)\nparser.add_argument(\"--model_filename\", type=str, default=\"lander.pth\")\nparser.add_argument(\"--faiss_gpu\", action=\"store_true\")\nparser.add_argument(\"--num_workers\", type=int, default=0)\nparser.add_argument(\"--output_filename\", type=str, default=\"data/features.pkl\")\n\n# HyperParam\nparser.add_argument(\"--knn_k\", type=int, default=10)\nparser.add_argument(\"--levels\", type=int, default=1)\nparser.add_argument(\"--tau\", type=float, default=0.5)\nparser.add_argument(\"--threshold\", type=str, default=\"prob\")\nparser.add_argument(\"--metrics\", type=str, default=\"pairwise,bcubed,nmi\")\nparser.add_argument(\"--early_stop\", action=\"store_true\")\n\n# Model\nparser.add_argument(\"--hidden\", type=int, default=512)\nparser.add_argument(\"--num_conv\", type=int, default=4)\nparser.add_argument(\"--dropout\", type=float, default=0.0)\nparser.add_argument(\"--gat\", action=\"store_true\")\nparser.add_argument(\"--gat_k\", type=int, default=1)\nparser.add_argument(\"--balance\", action=\"store_true\")\nparser.add_argument(\"--use_cluster_feat\", action=\"store_true\")\nparser.add_argument(\"--use_focal_loss\", action=\"store_true\")\nparser.add_argument(\"--use_gt\", action=\"store_true\")\n\n# Subgraph\nparser.add_argument(\"--batch_size\", type=int, default=4096)\nparser.add_argument(\"--mode\", type=str, default=\"1head\")\nparser.add_argument(\"--midpoint\", 
type=str, default=\"false\")\nparser.add_argument(\"--linsize\", type=int, default=29011)\nparser.add_argument(\"--uinsize\", type=int, default=18403)\nparser.add_argument(\"--inclasses\", type=int, default=948)\nparser.add_argument(\"--thresh\", type=float, default=1.0)\n\nparser.add_argument(\"--draw\", type=str, default=\"false\")\nparser.add_argument(\n    \"--density_distance_pkl\", type=str, default=\"density_distance.pkl\"\n)\nparser.add_argument(\n    \"--density_lindistance_jpg\", type=str, default=\"density_lindistance.jpg\"\n)\n\nargs = parser.parse_args()\nprint(args)\nMODE = args.mode\nlinsize = args.linsize\nuinsize = args.uinsize\ninclasses = args.inclasses\n\nif args.draw == \"false\":\n    args.draw = False\nelif args.draw == \"true\":\n    args.draw = True\n\n###########################\n# Environment Configuration\nif torch.cuda.is_available():\n    device = torch.device(\"cuda\")\nelse:\n    device = torch.device(\"cpu\")\n\n##################\n# Data Preparation\nwith open(args.data_path, \"rb\") as f:\n    loaded_data = pickle.load(f)\n    path2idx, features, pred_labels, labels, masks = loaded_data\n\nidx2path = {v: k for k, v in path2idx.items()}\ngtlabels = labels\n\norifeatures = features\norilabels = gtlabels\n\nif MODE == \"selectbydensity\":\n    lastusim = np.where(masks == 1)\n    masks[lastusim] = 2\n    selectedidx = np.where(masks != 0)\n    features = features[selectedidx]\n    labels = gtlabels[selectedidx]\n    selectmasks = masks[selectedidx]\n    print(\"filtered features:\", len(features))\n    print(\"mask0:\", len(np.where(masks == 0)[0]))\n    print(\"mask1:\", len(np.where(masks == 1)[0]))\n    print(\"mask2:\", len(np.where(masks == 2)[0]))\nelif MODE == \"recluster\":\n    selectedidx = np.where(masks == 1)\n    features = features[selectedidx]\n    labels = gtlabels[selectedidx]\n    labelspred = pred_labels[selectedidx]\n    selectmasks = masks[selectedidx]\n    gtlabels = gtlabels[selectedidx]\n    print(\"filtered 
features:\", len(features))\nelse:\n    selectedidx = np.where(masks != 0)\n    features = features[selectedidx]\n    labels = gtlabels[selectedidx]\n    labelspred = pred_labels[selectedidx]\n    selectmasks = masks[selectedidx]\n    gtlabels = gtlabels[selectedidx]\n    print(\"filtered features:\", len(features))\n\nglobal_features = features.copy()  # global features\ndataset = LanderDataset(\n    features=features, labels=labels, k=args.knn_k, levels=1, faiss_gpu=False\n)\ng = dataset.gs[0]\ng.ndata[\"pred_den\"] = torch.zeros((g.num_nodes()))\ng.edata[\"prob_conn\"] = torch.zeros((g.num_edges(), 2))\nglobal_labels = labels.copy()\nids = np.arange(g.num_nodes())\nglobal_edges = ([], [])\nglobal_peaks = np.array([], dtype=np.long)\nglobal_edges_len = len(global_edges[0])\nglobal_num_nodes = g.num_nodes()\n\nglobal_densities = g.ndata[\"density\"][:linsize]\nglobal_densities = np.sort(global_densities)\nxs = np.arange(len(global_densities))\n\nfanouts = [args.knn_k - 1 for i in range(args.num_conv + 1)]\nsampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)\n# fix the number of edges\ntest_loader = dgl.dataloading.DataLoader(\n    g,\n    torch.arange(g.num_nodes()),\n    sampler,\n    batch_size=args.batch_size,\n    shuffle=False,\n    drop_last=False,\n    num_workers=args.num_workers,\n)\n\n##################\n# Model Definition\nif not args.use_gt:\n    feature_dim = g.ndata[\"features\"].shape[1]\n    model = LANDER(\n        feature_dim=feature_dim,\n        nhid=args.hidden,\n        num_conv=args.num_conv,\n        dropout=args.dropout,\n        use_GAT=args.gat,\n        K=args.gat_k,\n        balance=args.balance,\n        use_cluster_feat=args.use_cluster_feat,\n        use_focal_loss=args.use_focal_loss,\n    )\n    model.load_state_dict(torch.load(args.model_filename, weights_only=False))\n    model = model.to(device)\n    model.eval()\n\n# number of edges added is the indicator for early stopping\nnum_edges_add_last_level = 
np.Inf\n##################################\n# Predict connectivity and density\nfor level in range(args.levels):\n    print(\"level:\", level)\n    if not args.use_gt:\n        total_batches = len(test_loader)\n        for batch, minibatch in enumerate(test_loader):\n            input_nodes, sub_g, bipartites = minibatch\n            sub_g = sub_g.to(device)\n            bipartites = [b.to(device) for b in bipartites]\n            with torch.no_grad():\n                output_bipartite = model(bipartites)\n            global_nid = output_bipartite.dstdata[dgl.NID]\n            global_eid = output_bipartite.edata[\"global_eid\"]\n            g.ndata[\"pred_den\"][global_nid] = output_bipartite.dstdata[\n                \"pred_den\"\n            ].to(\"cpu\")\n            g.edata[\"prob_conn\"][global_eid] = output_bipartite.edata[\n                \"prob_conn\"\n            ].to(\"cpu\")\n            torch.cuda.empty_cache()\n            if (batch + 1) % 10 == 0:\n                print(\"Batch %d / %d for inference\" % (batch, total_batches))\n\n    (\n        new_pred_labels,\n        peaks,\n        global_edges,\n        global_pred_labels,\n        global_peaks,\n    ) = decode(\n        g,\n        args.tau,\n        args.threshold,\n        args.use_gt,\n        ids,\n        global_edges,\n        global_num_nodes,\n        global_peaks,\n    )\n    if level == 0:\n        global_pred_densities = g.ndata[\"pred_den\"]\n        global_densities = g.ndata[\"density\"]\n        g.edata[\"prob_conn\"] = torch.zeros((g.num_edges(), 2))\n\n    ids = ids[peaks]\n    new_global_edges_len = len(global_edges[0])\n    num_edges_add_this_level = new_global_edges_len - global_edges_len\n    if stop_iterating(\n        level,\n        args.levels,\n        args.early_stop,\n        num_edges_add_this_level,\n        num_edges_add_last_level,\n        args.knn_k,\n    ):\n        break\n    global_edges_len = new_global_edges_len\n    num_edges_add_last_level = 
num_edges_add_this_level\n\n    # build new dataset\n    features, labels, cluster_features = build_next_level(\n        features,\n        labels,\n        peaks,\n        global_features,\n        global_pred_labels,\n        global_peaks,\n    )\n    # After the first level, the number of nodes reduce a lot. Using cpu faiss is faster.\n    dataset = LanderDataset(\n        features=features,\n        labels=labels,\n        k=args.knn_k,\n        levels=1,\n        faiss_gpu=False,\n        cluster_features=cluster_features,\n    )\n    g = dataset.gs[0]\n    g.ndata[\"pred_den\"] = torch.zeros((g.num_nodes()))\n    g.edata[\"prob_conn\"] = torch.zeros((g.num_edges(), 2))\n    test_loader = dgl.dataloading.DataLoader(\n        g,\n        torch.arange(g.num_nodes()),\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=False,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\n\nif MODE == \"selectbydensity\":\n    thresh = args.thresh\n    global_pred_densities = np.array(global_pred_densities).astype(float)\n    global_densities = np.array(global_densities).astype(float)\n    distance = np.abs(global_pred_densities - global_densities)\n    print(\"densities shape\", global_pred_densities.shape)\n    print(global_pred_densities.max(), global_pred_densities.min())\n\n    selectidx = np.where(global_pred_densities > thresh)[0]\n    selected_pred_densities = global_pred_densities[selectidx]\n    selected_densities = global_densities[selectidx]\n    selected_distance = np.abs(selected_pred_densities - selected_densities)\n    print(np.mean(selected_distance))\n    print(\"number of selected samples:\", len(selectidx))\n\n    notselectidx = np.where(global_pred_densities <= thresh)\n    print(\"not selected:\", len(notselectidx[0]))\n    global_pred_labels[notselectidx] = -1\n\n    global_pred_labels_new = np.zeros_like(orilabels)\n    global_pred_labels_new[:] = -1\n    Tidx = np.where(masks != 2)\n    print(\"T:\", 
len(Tidx[0]))\n\n    l_in_gt = orilabels[Tidx]\n    l_in_features = orifeatures[Tidx]\n    l_in_gt_new = np.zeros_like(l_in_gt)\n    l_in_unique = np.unique(l_in_gt)\n    for i in range(len(l_in_unique)):\n        l_in = l_in_unique[i]\n        l_in_idx = np.where(l_in_gt == l_in)\n        l_in_gt_new[l_in_idx] = i\n    print(\"len(l_in_unique)\", len(l_in_unique))\n\n    if args.draw:\n        prototypes = np.zeros((len(l_in_unique), features.shape[1]))\n        for i in range(len(l_in_unique)):\n            idx = np.where(l_in_gt_new == i)\n            prototypes[i] = np.mean(l_in_features[idx], axis=0)\n\n        similarity_matrix = torch.mm(\n            torch.from_numpy(global_features.astype(np.float32)),\n            torch.from_numpy(prototypes.astype(np.float32)).t(),\n        )\n        similarity_matrix = (1 - similarity_matrix) / 2\n        minvalues, selected_pred_labels = torch.min(similarity_matrix, 1)\n        # far-close ratio\n        closeidx = np.where(minvalues < 0.15)\n        faridx = np.where(minvalues >= 0.15)\n        print(\"far:\", len(faridx[0]))\n        print(\"close:\", len(closeidx[0]))\n\n        cutidx = np.where(global_pred_densities >= 0.5)\n        draw_minvalues = minvalues[cutidx]\n        draw_densities = global_pred_densities[cutidx]\n        with open(args.density_distance_pkl, \"wb\") as f:\n            pickle.dump((global_pred_densities, minvalues), f)\n        print(\"dumped.\")\n        plt.clf()\n        fig, ax = plt.subplots()\n        import random\n\n        if len(draw_densities) > 10000:\n            samples_idx = random.sample(range(len(draw_minvalues)), 10000)\n            ax.plot(\n                draw_densities[random],\n                draw_minvalues[random],\n                color=\"tab:blue\",\n                marker=\"*\",\n                linestyle=\"None\",\n                markersize=1,\n            )\n        else:\n            ax.plot(\n                draw_densities[random],\n                
draw_minvalues[random],\n                color=\"tab:blue\",\n                marker=\"*\",\n                linestyle=\"None\",\n                markersize=1,\n            )\n        plt.savefig(args.density_lindistance_jpg)\n\n    global_pred_labels_new[Tidx] = l_in_gt_new\n    global_pred_labels[selectidx] = global_pred_labels[selectidx] + len(\n        l_in_unique\n    )\n    global_pred_labels_new[selectedidx] = global_pred_labels\n\n    global_pred_labels = global_pred_labels_new\n    linunique = np.unique(global_pred_labels[Tidx])\n    uunique = np.unique(global_pred_labels[selectedidx])\n    allnique = np.unique(global_pred_labels)\n    print(\"labels\")\n    print(len(linunique), len(uunique), len(allnique))\n\n    global_masks = np.zeros_like(masks)\n    global_masks[:] = 1\n    global_masks[np.array(selectedidx[0])[notselectidx]] = 2\n    Tidx = np.where(masks != 2)\n    global_masks[Tidx] = 0\n    print(\"mask0\", len(np.where(global_masks == 0)[0]))\n    print(\"mask1\", len(np.where(global_masks == 1)[0]))\n    print(\"mask2\", len(np.where(global_masks == 2)[0]))\n    print(\"all\", len(masks), len(orilabels), len(orifeatures))\n\n    global_gt_labels = orilabels\n\nif MODE == \"recluster\":\n    global_pred_labels_new = np.zeros_like(orilabels)\n    global_pred_labels_new[:] = -1\n    Tidx = np.where(masks == 0)\n    print(\"T:\", len(Tidx[0]))\n\n    l_in_gt = orilabels[Tidx]\n    l_in_features = orifeatures[Tidx]\n    l_in_gt_new = np.zeros_like(l_in_gt)\n    l_in_unique = np.unique(l_in_gt)\n    for i in range(len(l_in_unique)):\n        l_in = l_in_unique[i]\n        l_in_idx = np.where(l_in_gt == l_in)\n        l_in_gt_new[l_in_idx] = i\n    print(\"len(l_in_unique)\", len(l_in_unique))\n\n    global_pred_labels_new[Tidx] = l_in_gt_new\n    print(len(global_pred_labels))\n    print(len(selectedidx[0]))\n    global_pred_labels_new[selectedidx[0]] = global_pred_labels + len(\n        l_in_unique\n    )\n    global_pred_labels = 
global_pred_labels_new\n    global_masks = masks\n    print(\"mask0\", len(np.where(global_masks == 0)[0]))\n    print(\"mask1\", len(np.where(global_masks == 1)[0]))\n    print(\"mask2\", len(np.where(global_masks == 2)[0]))\n    print(\"all\", len(masks), len(orilabels), len(orifeatures))\n    global_gt_labels = orilabels\n\nif MODE == \"donothing\":\n    global_masks = masks\n    pass\n\nprint(\"##################### L_in ########################\")\nprint(linsize)\nif len(global_pred_labels) >= linsize:\n    evaluation(\n        global_pred_labels[:linsize], global_gt_labels[:linsize], args.metrics\n    )\nelse:\n    print(\"No samples in L_in!\")\nprint(\"##################### U_in ########################\")\nuinidx = np.where(global_pred_labels[linsize : linsize + uinsize] != -1)[0]\nuinidx = uinidx + linsize\nprint(len(uinidx))\nif len(uinidx):\n    evaluation(\n        global_pred_labels[uinidx], global_gt_labels[uinidx], args.metrics\n    )\nelse:\n    print(\"No samples in U_in!\")\nprint(\"##################### U_out ########################\")\nuoutidx = np.where(global_pred_labels[linsize + uinsize :] != -1)[0]\nuoutidx = uoutidx + linsize + uinsize\nprint(len(uoutidx))\nif len(uoutidx):\n    evaluation(\n        global_pred_labels[uoutidx], global_gt_labels[uoutidx], args.metrics\n    )\nelse:\n    print(\"No samples in U_out!\")\nprint(\"##################### U ########################\")\nuidx = np.where(global_pred_labels[linsize:] != -1)[0]\nuidx = uidx + linsize\nprint(len(uidx))\nif len(uidx):\n    evaluation(global_pred_labels[uidx], global_gt_labels[uidx], args.metrics)\nelse:\n    print(\"No samples in U!\")\nprint(\"##################### L+U ########################\")\nluidx = np.where(global_pred_labels != -1)[0]\nprint(len(luidx))\nevaluation(global_pred_labels[luidx], global_gt_labels[luidx], args.metrics)\nprint(\"##################### new selected samples ########################\")\nsidx = np.where(global_masks == 
1)[0]\nprint(len(sidx))\nif len(sidx) != 0:\n    evaluation(global_pred_labels[sidx], global_gt_labels[sidx], args.metrics)\nprint(\"##################### not selected samples ########################\")\nnsidx = np.where(global_masks == 2)[0]\nprint(len(nsidx))\nif len(nsidx) != 0:\n    evaluation(global_pred_labels[nsidx], global_gt_labels[nsidx], args.metrics)\n\nwith open(args.output_filename, \"wb\") as f:\n    print(orifeatures.shape)\n    print(global_pred_labels.shape)\n    print(global_gt_labels.shape)\n    print(global_masks.shape)\n    pickle.dump(\n        [\n            path2idx,\n            orifeatures,\n            global_pred_labels,\n            global_gt_labels,\n            global_masks,\n        ],\n        f,\n    )\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/train.sh",
    "content": "#!/bin/bash\n\nmkdir hilander_checkpoint\n\n####################### ITER 0 #######################\n# iter 0 (supervised baseline) - train Smooth-AP\nCUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\npython Smooth_AP/src/main.py \\\n--dataset Inaturalist --lr 1e-5 --fc_lr_mul 1 \\\n--n_epochs 400 --bs 384 \\\n--source_path \"../../data/\" --embed_dim 128 \\\n--class_num 948 --loss smoothap --infrequent_eval 1 \\\n--trainset lin_train_set1.txt --testset Inaturalist_test_set1.txt\n\n# iter 0 (supervised baseline) - get feature\npython Smooth_AP/src/get_features.py \\\n--dataset Inaturalist --lr 1e-5 --fc_lr_mul 1 \\\n--n_epochs 400 --bs 384 \\\n--source_path \"../../data/\" --embed_dim 128 \\\n--resume \"0/checkpoint_0.pth.tar\" \\\n--finetune false --get_features true --iter 0 \\\n--class_num 948 --loss smoothap \\\n--trainset lin_train_set1.txt \\\n--all_trainset train_set1.txt \\\n--testset Inaturalist_test_set1.txt \\\n--linsize 29011\n\n# iter 0 (supervised baseline) - train hi-lander\npython train_subg_inat.py \\\n--data_path \"/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/data/Inaturalist/T_train_iter0_smoothap_inat_features.pkl\" \\\n--model_filename '/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/hilander_checkpoint/inat_l_smoothap_iter0.pth' \\\n--knn_k 10,5,3 --levels 2,3,4 \\\n--hidden 512 --epochs 1000 --lr 0.01 \\\n--batch_size 4096 --num_conv 1 --gat --balance\n\n# iter 0 (supervised baseline) - get pseudo labels\npython test_subg_inat.py \\\n--data_path '/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/data/Inaturalist/all_train_iter0_smoothap_inat_features.pkl' \\\n--model_filename '/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/hilander_checkpoint/inat_l_smoothap_iter0.pth'  --knn_k 10 \\\n--tau 0.9 --level 10 --threshold prob \\\n--hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop \\\n--mode selectbydensity --thresh 0.8 \\\n--linsize 29011 --uinsize 18403 --inclasses 948 \\\n--output_filename 
'data/inat_hilander_l_smoothap_train_selectbydensity_iter0.pkl'\n\n\nfor i in {1..4} ; do\n  last_iter=`expr $i - 1`\n  echo ${last_iter}\n  # iter i - train Smooth-AP\n  python Smooth_AP/src/finetune_1head.py \\\n  --dataset Inaturalist --lr 1e-5 --fc_lr_mul 1 \\\n  --n_epochs 400 --bs 384 --class_num 1024 \\\n  --source_path \"../../data/\" --embed_dim 128 \\\n  --trainset lin_train_set1.txt --testset Inaturalist_test_set1.txt \\\n  --cluster_path \"../../data/inat_hilander_l_smoothap_train_selectbydensity_iter${last_iter}.pkl\" \\\n  --finetune true --loss smoothap --infrequent_eval 1 --iter ${i}\n\n  # iter i - get feature\n  python Smooth_AP/src/get_features.py \\\n  --dataset Inaturalist --lr 1e-5 --fc_lr_mul 1 \\\n  --n_epochs 400 --bs 384 \\\n  --source_path \"../../data/\" --embed_dim 128 \\\n  --resume \"${i}/checkpoint_${i}.pth.tar\" \\\n  --finetune false --get_features true --iter ${i} \\\n  --class_num 948 --loss smoothap \\\n  --trainset lin_train_set1.txt \\\n  --all_trainset train_set1.txt \\\n  --testset Inaturalist_test_set1.txt \\\n  --linsize 29011 --uinsize 18403 \\\n  --cluster_path \"../../data/inat_hilander_l_smoothap_train_selectbydensity_iter${last_iter}.pkl\"\n\n  # iter i - train hi-lander\n  python train_subg_inat.py \\\n  --data_path \"/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/data/Inaturalist/T_train_iter${i}_smoothap_inat_features.pkl\" \\\n  --model_filename \"/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/hilander_checkpoint/inat_l_smoothap_iter${i}.pth\" \\\n  --knn_k 10,5,3 --levels 2,3,4 \\\n  --hidden 512 --epochs 1000 --lr 0.01 \\\n  --batch_size 4096 --num_conv 1 --gat --balance\n\n  # iter i - get pseudo labels\n  python test_subg_inat.py \\\n  --data_path \"/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/data/Inaturalist/all_train_iter${i}_smoothap_inat_features.pkl\" \\\n  --model_filename \"/home/ubuntu/code/dgl/examples/pytorch/hilander/PSS/hilander_checkpoint/inat_l_smoothap_iter${i}.pth\"  --knn_k 
10 \\\n  --tau 0.9 --level 10 --threshold prob \\\n  --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop \\\n  --mode selectbydensity --thresh 0.8 \\\n  --linsize 29011 --uinsize 18403 --inclasses 948 \\\n  --output_filename \"data/inat_hilander_l_smoothap_train_selectbydensity_iter${i}.pkl\"\ndone\n"
  },
  {
    "path": "examples/pytorch/hilander/PSS/train_subg_inat.py",
    "content": "import argparse, os, pickle, time\nimport random\n\nimport sys\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.optim as optim\n\nsys.path.append(\"..\")\nfrom dataset import LanderDataset\nfrom models import LANDER\n\n###########\n# ArgParser\nparser = argparse.ArgumentParser()\n\n# Dataset\nparser.add_argument(\"--data_path\", type=str, required=True)\nparser.add_argument(\"--levels\", type=str, default=\"1\")\nparser.add_argument(\"--faiss_gpu\", action=\"store_true\")\nparser.add_argument(\"--model_filename\", type=str, default=\"lander.pth\")\n\n# KNN\nparser.add_argument(\"--knn_k\", type=str, default=\"10\")\nparser.add_argument(\"--num_workers\", type=int, default=0)\n\n# Model\nparser.add_argument(\"--hidden\", type=int, default=512)\nparser.add_argument(\"--num_conv\", type=int, default=1)\nparser.add_argument(\"--dropout\", type=float, default=0.0)\nparser.add_argument(\"--gat\", action=\"store_true\")\nparser.add_argument(\"--gat_k\", type=int, default=1)\nparser.add_argument(\"--balance\", action=\"store_true\")\nparser.add_argument(\"--use_cluster_feat\", action=\"store_true\")\nparser.add_argument(\"--use_focal_loss\", action=\"store_true\")\n\n# Training\nparser.add_argument(\"--epochs\", type=int, default=100)\nparser.add_argument(\"--batch_size\", type=int, default=1024)\nparser.add_argument(\"--lr\", type=float, default=0.1)\nparser.add_argument(\"--momentum\", type=float, default=0.9)\nparser.add_argument(\"--weight_decay\", type=float, default=1e-5)\n\nargs = parser.parse_args()\nprint(args)\n\n###########################\n# Environment Configuration\nif torch.cuda.is_available():\n    device = torch.device(\"cuda\")\nelse:\n    device = torch.device(\"cpu\")\n\n\ndef setup_seed(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\n\n# setup_seed(20)\n\n##################\n# Data Preparation\nwith 
open(args.data_path, \"rb\") as f:\n    path2idx, features, labels, _, masks = pickle.load(f)\n    # lidx = np.where(masks==0)\n    # features = features[lidx]\n    # labels = labels[lidx]\n    print(\"features.shape:\", features.shape)\n    print(\"labels.shape:\", labels.shape)\n\n\nk_list = [int(k) for k in args.knn_k.split(\",\")]\nlvl_list = [int(l) for l in args.levels.split(\",\")]\ngs = []\nnbrs = []\nks = []\ndatasets = []\nfor k, l in zip(k_list, lvl_list):\n    print(\"k:\", k)\n    print(\"levels:\", l)\n    dataset = LanderDataset(\n        features=features,\n        labels=labels,\n        k=k,\n        levels=l,\n        faiss_gpu=args.faiss_gpu,\n    )\n    gs += [g for g in dataset.gs]\n    ks += [k for g in dataset.gs]\n    nbrs += [nbr for nbr in dataset.nbrs]\n    datasets.append(dataset)\n\n# with open(\"./dataset.pkl\", 'rb') as f:\n#     datasets = pickle.load(f)\n# for i in range(len(datasets)):\n#     dataset = datasets[i]\n#     k = k_list[i]\n#     gs += [g for g in dataset.gs]\n#     ks += [k for g in dataset.gs]\n#     nbrs += [nbr for nbr in dataset.nbrs]\n\n\nwith open(\"./dataset.pkl\", \"wb\") as f:\n    pickle.dump(datasets, f)\n\nprint(\"Dataset Prepared.\")\n\n\ndef set_train_sampler_loader(g, k):\n    fanouts = [k - 1 for i in range(args.num_conv + 1)]\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)\n    # fix the number of edges\n    train_dataloader = dgl.dataloading.DataLoader(\n        g,\n        torch.arange(g.num_nodes()),\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\n    return train_dataloader\n\n\ntrain_loaders = []\nfor gidx, g in enumerate(gs):\n    train_dataloader = set_train_sampler_loader(gs[gidx], ks[gidx])\n    train_loaders.append(train_dataloader)\n\n##################\n# Model Definition\nfeature_dim = gs[0].ndata[\"features\"].shape[1]\nprint(\"feature dimension:\", feature_dim)\nmodel 
= LANDER(\n    feature_dim=feature_dim,\n    nhid=args.hidden,\n    num_conv=args.num_conv,\n    dropout=args.dropout,\n    use_GAT=args.gat,\n    K=args.gat_k,\n    balance=args.balance,\n    use_cluster_feat=args.use_cluster_feat,\n    use_focal_loss=args.use_focal_loss,\n)\nmodel = model.to(device)\nmodel.train()\n\n#################\n# Hyperparameters\nopt = optim.SGD(\n    model.parameters(),\n    lr=args.lr,\n    momentum=args.momentum,\n    weight_decay=args.weight_decay,\n)\n\n# keep num_batch_per_loader the same for every sub_dataloader\nnum_batch_per_loader = len(train_loaders[0])\ntrain_loaders = [iter(train_loader) for train_loader in train_loaders]\nnum_loaders = len(train_loaders)\nscheduler = optim.lr_scheduler.CosineAnnealingLR(\n    opt, T_max=args.epochs * num_batch_per_loader * num_loaders, eta_min=1e-5\n)\n\nprint(\"Start Training.\")\n\n###############\n# Training Loop\nfor epoch in range(args.epochs):\n    loss_den_val_total = []\n    loss_conn_val_total = []\n    loss_val_total = []\n    for batch in range(num_batch_per_loader):\n        for loader_id in range(num_loaders):\n            try:\n                minibatch = next(train_loaders[loader_id])\n            except:\n                train_loaders[loader_id] = iter(\n                    set_train_sampler_loader(gs[loader_id], ks[loader_id])\n                )\n                minibatch = next(train_loaders[loader_id])\n            input_nodes, sub_g, bipartites = minibatch\n            sub_g = sub_g.to(device)\n            bipartites = [b.to(device) for b in bipartites]\n            # get the feature for the input_nodes\n            opt.zero_grad()\n            output_bipartite = model(bipartites)\n            loss, loss_den_val, loss_conn_val = model.compute_loss(\n                output_bipartite\n            )\n            loss_den_val_total.append(loss_den_val)\n            loss_conn_val_total.append(loss_conn_val)\n            loss_val_total.append(loss.item())\n            
loss.backward()\n            opt.step()\n            if (batch + 1) % 10 == 0:\n                print(\n                    \"epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f\"\n                    % (\n                        epoch,\n                        batch,\n                        num_batch_per_loader,\n                        loader_id,\n                        num_loaders,\n                        loss.item(),\n                        loss_den_val,\n                        loss_conn_val,\n                    )\n                )\n            scheduler.step()\n    print(\n        \"epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f\"\n        % (\n            epoch,\n            np.array(loss_val_total).mean(),\n            np.array(loss_den_val_total).mean(),\n            np.array(loss_conn_val_total).mean(),\n        )\n    )\n    torch.save(model.state_dict(), args.model_filename)\n\ntorch.save(model.state_dict(), args.model_filename)\n"
  },
  {
    "path": "examples/pytorch/hilander/README.md",
    "content": "Learning Hierarchical Graph Neural Networks for Image Clustering\n================================================================\n\nThis folder contains the official code for [Learning Hierarchical Graph Neural Networks for Image Clustering](https://arxiv.org/abs/2107.01319).\n\n## Setup\n\nWe use Python 3.7. The CUDA version needs to be 10.2. Besides DGL (>=0.8), we depend on several packages. To install dependencies using conda:\n```bash\nconda create -n Hilander # create env\nconda activate Hilander # activate env\nconda install pytorch==1.7.0 torchvision==0.8.0 cudatoolkit=10.2 -c pytorch # install pytorch 1.7 version\nconda install -y cudatoolkit=10.2 faiss-gpu=1.6.5 -c pytorch # install faiss gpu version matching cuda 10.2\npip install dgl-cu102 dglgo -f https://data.dgl.ai/wheels/repo.html # install the latest dgl for cuda 10.2\npip install tqdm # install tqdm\ngit clone https://github.com/yjxiong/clustering-benchmark.git # install clustering-benchmark for evaluation\ncd clustering-benchmark\npython setup.py install\ncd ../\n```\n\n## Data\n\nThe datasets used for training and testing are hosted by several services.\n\n[AWS S3](https://dgl-data.s3.us-west-2.amazonaws.com/dataset/hilander/data.tar.gz) | [Google Drive](https://drive.google.com/file/d/1KLa3uu9ndaCc7YjnSVRLHpcJVMSz868v/view?usp=sharing) | [BaiduPan](https://pan.baidu.com/s/11iRcp84esfkkvdcw3kmPAw) (pwd: wbmh)\n\nAfter downloading, unpack the pickled files into `data/`.\n\n## Training\n\nWe provide training scripts for different datasets.\n\nFor training on DeepGlint, one can run\n\n```bash\nbash scripts/train_deepglint.sh\n```\nDeepGlint is a large-scale dataset, so we randomly select 10% of the classes to construct a training subset.\n\nFor training on the full iNaturalist dataset, one can run\n\n```bash\nbash scripts/train_inat.sh\n```\n\nFor training on the re-sampled iNaturalist dataset, one can run\n\n```bash\nbash scripts/train_inat_resampled_1_in_6_per_class.sh\n```\nWe sample a subset of 
the full iNat2018-Train to attain a train-time cluster size distribution that is drastically different from that of iNat2018-Test; this subset is named inat_resampled_1_in_6_per_class.\n\n## Inference\n\nIn the paper, we have two experiment settings: Clustering with Seen Test Data Distribution and Clustering with Unseen Test Data Distribution.\n\nFor Clustering with Seen Test Data Distribution, one can run\n\n```bash\nbash scripts/test_deepglint_imbd_sampled_as_deepglint.sh\n\nbash scripts/test_inat.sh\n```\n\n**Clustering with Seen Test Data Distribution Performance**\n|                    |              IMDB-Test-SameDist |                   iNat2018-Test |\n| ------------------ | ------------------------------: | ------------------------------: |\n|                 Fp |                           0.779 |                           0.330 |\n|                 Fb |                           0.819 |                           0.350 |\n|                NMI |                           0.949 |                           0.774 |\n* The results might fluctuate a little due to the randomness introduced by gpu knn building using faiss-gpu.\n\n\nFor Clustering with Unseen Test Data Distribution, one can run\n\n```bash\nbash scripts/test_deepglint_hannah.sh\n\nbash scripts/test_deepglint_imdb.sh\n\nbash scripts/test_inat_train_on_resampled_1_in_6_per_class.sh\n```\n\n**Clustering with Unseen Test Data Distribution Performance**\n|                    |                          Hannah |                            IMDB |                   iNat2018-Test |\n| ------------------ | ------------------------------: | ------------------------------: | ------------------------------: |\n|                 Fp |                           0.741 |                           0.717 |                           0.294 |\n|                 Fb |                           0.706 |                           0.810 |                           0.352 |\n|                NMI |                           0.810 |                      
     0.953 |                           0.764 |\n* The results might fluctuate a little due to the randomness introduced by gpu knn building using faiss-gpu.\n\n"
  },
  {
    "path": "examples/pytorch/hilander/__init__.py",
    "content": ""
  },
  {
    "path": "examples/pytorch/hilander/checkpoint/.gitkeep",
    "content": ""
  },
  {
    "path": "examples/pytorch/hilander/data/.gitkeep",
    "content": ""
  },
  {
    "path": "examples/pytorch/hilander/models/__init__.py",
    "content": "from .graphconv import GraphConv\nfrom .lander import LANDER\n"
  },
  {
    "path": "examples/pytorch/hilander/models/focal_loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\n\n# The code below is based on\n# https://zhuanlan.zhihu.com/p/28527749\n\n\nclass FocalLoss(nn.Module):\n    r\"\"\"\n    This criterion is an implementation of Focal Loss, which is proposed in\n    Focal Loss for Dense Object Detection.\n\n        Loss(x, class) = - \\alpha (1-softmax(x)[class])^gamma \\log(softmax(x)[class])\n\n    The losses are averaged across observations for each minibatch.\n\n    Args:\n        alpha(1D Tensor, Variable) : the scalar factor for this criterion\n        gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),\n                               putting more focus on hard, misclassified examples\n        size_average(bool): By default, the losses are averaged over observations for each minibatch.\n                            However, if the field size_average is set to False, the losses are\n                            instead summed for each minibatch.\n\n\n    \"\"\"\n\n    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):\n        super(FocalLoss, self).__init__()\n        if alpha is None:\n            self.alpha = Variable(torch.ones(class_num, 1))\n        else:\n            if isinstance(alpha, Variable):\n                self.alpha = alpha\n            else:\n                self.alpha = Variable(alpha)\n        self.gamma = gamma\n        self.class_num = class_num\n        self.size_average = size_average\n\n    def forward(self, inputs, targets):\n        N = inputs.size(0)\n        C = inputs.size(1)\n        P = F.softmax(inputs)\n\n        class_mask = inputs.data.new(N, C).fill_(0)\n        class_mask = Variable(class_mask)\n        ids = targets.view(-1, 1)\n        class_mask.scatter_(1, ids.data, 1.0)\n\n        if inputs.is_cuda and not self.alpha.is_cuda:\n            self.alpha = self.alpha.cuda()\n        alpha = 
self.alpha[ids.data.view(-1)]\n\n        probs = (P * class_mask).sum(1).view(-1, 1)\n\n        log_p = probs.log()\n\n        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p\n\n        if self.size_average:\n            loss = batch_loss.mean()\n        else:\n            loss = batch_loss.sum()\n        return loss\n"
  },
  {
    "path": "examples/pytorch/hilander/models/graphconv.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nimport dgl.function as fn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import GATConv\nfrom torch.nn import init\n\n\nclass GraphConvLayer(nn.Module):\n    def __init__(self, in_feats, out_feats, bias=True):\n        super(GraphConvLayer, self).__init__()\n        self.mlp = nn.Linear(in_feats * 2, out_feats, bias=bias)\n\n    def forward(self, bipartite, feat):\n        if isinstance(feat, tuple):\n            srcfeat, dstfeat = feat\n        else:\n            srcfeat = feat\n            dstfeat = feat[: bipartite.num_dst_nodes()]\n        graph = bipartite.local_var()\n\n        graph.srcdata[\"h\"] = srcfeat\n        graph.update_all(\n            fn.u_mul_e(\"h\", \"affine\", \"m\"), fn.sum(msg=\"m\", out=\"h\")\n        )\n\n        gcn_feat = torch.cat([dstfeat, graph.dstdata[\"h\"]], dim=-1)\n        out = self.mlp(gcn_feat)\n        return out\n\n\nclass GraphConv(nn.Module):\n    def __init__(self, in_dim, out_dim, dropout=0, use_GAT=False, K=1):\n        super(GraphConv, self).__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        if use_GAT:\n            self.gcn_layer = GATConv(\n                in_dim, out_dim, K, allow_zero_in_degree=True\n            )\n            self.bias = nn.Parameter(torch.Tensor(K, out_dim))\n            init.constant_(self.bias, 0)\n        else:\n            self.gcn_layer = GraphConvLayer(in_dim, out_dim, bias=True)\n\n        self.dropout = dropout\n        self.use_GAT = use_GAT\n\n    def forward(self, bipartite, features):\n        out = self.gcn_layer(bipartite, features)\n\n        if self.use_GAT:\n            out = torch.mean(out + self.bias, dim=1)\n\n        out = out.reshape(out.shape[0], -1)\n        out = F.relu(out)\n        if self.dropout > 0:\n            out = F.dropout(out, self.dropout, training=self.training)\n\n        return out\n"
  },
  {
    "path": "examples/pytorch/hilander/models/lander.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\nimport dgl\nimport dgl.function as fn\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .focal_loss import FocalLoss\nfrom .graphconv import GraphConv\n\n\nclass LANDER(nn.Module):\n    def __init__(\n        self,\n        feature_dim,\n        nhid,\n        num_conv=4,\n        dropout=0,\n        use_GAT=True,\n        K=1,\n        balance=False,\n        use_cluster_feat=True,\n        use_focal_loss=True,\n        **kwargs\n    ):\n        super(LANDER, self).__init__()\n        nhid_half = int(nhid / 2)\n        self.use_cluster_feat = use_cluster_feat\n        self.use_focal_loss = use_focal_loss\n\n        if self.use_cluster_feat:\n            self.feature_dim = feature_dim * 2\n        else:\n            self.feature_dim = feature_dim\n\n        input_dim = (feature_dim, nhid, nhid, nhid_half)\n        output_dim = (nhid, nhid, nhid_half, nhid_half)\n        self.conv = nn.ModuleList()\n        self.conv.append(GraphConv(self.feature_dim, nhid, dropout, use_GAT, K))\n        for i in range(1, num_conv):\n            self.conv.append(\n                GraphConv(input_dim[i], output_dim[i], dropout, use_GAT, K)\n            )\n\n        self.src_mlp = nn.Linear(output_dim[num_conv - 1], nhid_half)\n        self.dst_mlp = nn.Linear(output_dim[num_conv - 1], nhid_half)\n\n        self.classifier_conn = nn.Sequential(\n            nn.PReLU(nhid_half),\n            nn.Linear(nhid_half, nhid_half),\n            nn.PReLU(nhid_half),\n            nn.Linear(nhid_half, 2),\n        )\n\n        if self.use_focal_loss:\n            self.loss_conn = FocalLoss(2)\n        else:\n            self.loss_conn = nn.CrossEntropyLoss()\n        self.loss_den = nn.MSELoss()\n\n        self.balance = balance\n\n    def pred_conn(self, edges):\n        src_feat = self.src_mlp(edges.src[\"conv_features\"])\n        dst_feat = 
self.dst_mlp(edges.dst[\"conv_features\"])\n        pred_conn = self.classifier_conn(src_feat + dst_feat)\n        return {\"pred_conn\": pred_conn}\n\n    def pred_den_msg(self, edges):\n        prob = edges.data[\"prob_conn\"]\n        res = edges.data[\"raw_affine\"] * (prob[:, 1] - prob[:, 0])\n        return {\"pred_den_msg\": res}\n\n    def forward(self, bipartites):\n        if isinstance(bipartites, dgl.DGLGraph):\n            bipartites = [bipartites] * len(self.conv)\n            if self.use_cluster_feat:\n                neighbor_x = torch.cat(\n                    [\n                        bipartites[0].ndata[\"features\"],\n                        bipartites[0].ndata[\"cluster_features\"],\n                    ],\n                    axis=1,\n                )\n            else:\n                neighbor_x = bipartites[0].ndata[\"features\"]\n\n            for i in range(len(self.conv)):\n                neighbor_x = self.conv[i](bipartites[i], neighbor_x)\n\n            output_bipartite = bipartites[-1]\n            output_bipartite.ndata[\"conv_features\"] = neighbor_x\n        else:\n            if self.use_cluster_feat:\n                neighbor_x_src = torch.cat(\n                    [\n                        bipartites[0].srcdata[\"features\"],\n                        bipartites[0].srcdata[\"cluster_features\"],\n                    ],\n                    axis=1,\n                )\n                center_x_src = torch.cat(\n                    [\n                        bipartites[1].srcdata[\"features\"],\n                        bipartites[1].srcdata[\"cluster_features\"],\n                    ],\n                    axis=1,\n                )\n            else:\n                neighbor_x_src = bipartites[0].srcdata[\"features\"]\n                center_x_src = bipartites[1].srcdata[\"features\"]\n\n            for i in range(len(self.conv)):\n                neighbor_x_dst = neighbor_x_src[: bipartites[i].num_dst_nodes()]\n              
  neighbor_x_src = self.conv[i](\n                    bipartites[i], (neighbor_x_src, neighbor_x_dst)\n                )\n                center_x_dst = center_x_src[: bipartites[i + 1].num_dst_nodes()]\n                center_x_src = self.conv[i](\n                    bipartites[i + 1], (center_x_src, center_x_dst)\n                )\n\n            output_bipartite = bipartites[-1]\n            output_bipartite.srcdata[\"conv_features\"] = neighbor_x_src\n            output_bipartite.dstdata[\"conv_features\"] = center_x_src\n\n        output_bipartite.apply_edges(self.pred_conn)\n        output_bipartite.edata[\"prob_conn\"] = F.softmax(\n            output_bipartite.edata[\"pred_conn\"], dim=1\n        )\n        output_bipartite.update_all(\n            self.pred_den_msg, fn.mean(\"pred_den_msg\", \"pred_den\")\n        )\n        return output_bipartite\n\n    def compute_loss(self, bipartite):\n        pred_den = bipartite.dstdata[\"pred_den\"]\n        loss_den = self.loss_den(pred_den, bipartite.dstdata[\"density\"])\n\n        labels_conn = bipartite.edata[\"labels_conn\"]\n        mask_conn = bipartite.edata[\"mask_conn\"]\n\n        if self.balance:\n            labels_conn = bipartite.edata[\"labels_conn\"]\n            neg_check = torch.logical_and(\n                bipartite.edata[\"labels_conn\"] == 0, mask_conn\n            )\n            num_neg = torch.sum(neg_check).item()\n            neg_indices = torch.where(neg_check)[0]\n            pos_check = torch.logical_and(\n                bipartite.edata[\"labels_conn\"] == 1, mask_conn\n            )\n            num_pos = torch.sum(pos_check).item()\n            pos_indices = torch.where(pos_check)[0]\n            if num_pos > num_neg:\n                mask_conn[\n                    pos_indices[\n                        np.random.choice(\n                            num_pos, num_pos - num_neg, replace=False\n                        )\n                    ]\n                ] = 0\n            elif 
num_pos < num_neg:\n                mask_conn[\n                    neg_indices[\n                        np.random.choice(\n                            num_neg, num_neg - num_pos, replace=False\n                        )\n                    ]\n                ] = 0\n\n        # In subgraph training, it may happen that all edges are masked in a batch\n        if mask_conn.sum() > 0:\n            loss_conn = self.loss_conn(\n                bipartite.edata[\"pred_conn\"][mask_conn], labels_conn[mask_conn]\n            )\n            loss = loss_den + loss_conn\n            loss_den_val = loss_den.item()\n            loss_conn_val = loss_conn.item()\n        else:\n            loss = loss_den\n            loss_den_val = loss_den.item()\n            loss_conn_val = 0\n\n        return loss, loss_den_val, loss_conn_val\n"
  },
  {
    "path": "examples/pytorch/hilander/scripts/test_deepglint_hannah.sh",
    "content": "python test_subg.py --data_path data/subcenter_arcface_deepglint_hannah_features.pkl --model_filename checkpoint/deepglint_sampler.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat\n"
  },
  {
    "path": "examples/pytorch/hilander/scripts/test_deepglint_imdb.sh",
    "content": "python test_subg.py --data_path data/subcenter_arcface_deepglint_imdb_features.pkl --model_filename checkpoint/deepglint_sampler.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat\n"
  },
  {
    "path": "examples/pytorch/hilander/scripts/test_deepglint_imdb_sampled_as_deepglint.sh",
    "content": "python test_subg.py --data_path data/subcenter_arcface_deepglint_imdb_features_sampled_as_deepglint_1_in_10.pkl --model_filename checkpoint/deepglint_sampler.pth --knn_k 10 --tau 0.8 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --batch_size 4096 --early_stop --use_cluster_feat\n"
  },
  {
    "path": "examples/pytorch/hilander/scripts/test_inat.sh",
    "content": "python test_subg.py --data_path data/inat2018_test.pkl --model_filename checkpoint/inat.ckpt --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop\n"
  },
  {
    "path": "examples/pytorch/hilander/scripts/test_inat_train_on_resampled_1_in_6_per_class.sh",
    "content": "python test_subg.py --data_path data/inat2018_test.pkl --model_filename checkpoint/inat_resampled_1_in_6_per_class.ckpt --knn_k 10 --tau 0.1 --level 10 --threshold prob --faiss_gpu --hidden 512 --num_conv 1 --gat --batch_size 4096 --early_stop\n"
  },
  {
    "path": "examples/pytorch/hilander/scripts/train_deepglint.sh",
    "content": "python train_subg.py --data_path data/subcenter_arcface_deepglint_train_1_in_10_recreated.pkl --model_filename checkpoint/deepglint_sampler.pth --knn_k 10,5,3 --levels 2,3,4 --faiss_gpu --hidden 512 --epochs 250 --lr 0.01 --batch_size 4096 --num_conv 1 --balance --use_cluster_feat\n"
  },
  {
    "path": "examples/pytorch/hilander/scripts/train_inat.sh",
    "content": "python train_subg.py --data_path data/inat2018_train_dedup_inter_intra.pkl --model_filename  checkpoint/inat.ckpt --knn_k 10,5,3 --levels 2,3,4 --faiss_gpu --hidden 512 --epochs 250 --lr 0.01 --batch_size 4096 --num_conv 1 --gat --balance\n"
  },
  {
    "path": "examples/pytorch/hilander/scripts/train_inat_resampled_1_in_6_per_class.sh",
    "content": "python train_subg.py --data_path data/inat2018_train_dedup_inter_intra_1_in_6_per_class.pkl --model_filename  checkpoint/inat_resampled_1_in_6_per_class.ckpt --knn_k 10,5,3 --levels 2,3,4 --faiss_gpu --hidden 512 --epochs 250 --lr 0.01 --batch_size 4096 --num_conv 1 --gat --balance\n"
  },
  {
    "path": "examples/pytorch/hilander/test.py",
    "content": "import argparse\nimport os\nimport pickle\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.optim as optim\nfrom dataset import LanderDataset\nfrom models import LANDER\nfrom utils import build_next_level, decode, evaluation, stop_iterating\n\n###########\n# ArgParser\nparser = argparse.ArgumentParser()\n\n# Dataset\nparser.add_argument(\"--data_path\", type=str, required=True)\nparser.add_argument(\"--model_filename\", type=str, default=\"lander.pth\")\nparser.add_argument(\"--faiss_gpu\", action=\"store_true\")\nparser.add_argument(\"--early_stop\", action=\"store_true\")\n\n# HyperParam\nparser.add_argument(\"--knn_k\", type=int, default=10)\nparser.add_argument(\"--levels\", type=int, default=1)\nparser.add_argument(\"--tau\", type=float, default=0.5)\nparser.add_argument(\"--threshold\", type=str, default=\"prob\")\nparser.add_argument(\"--metrics\", type=str, default=\"pairwise,bcubed,nmi\")\n\n# Model\nparser.add_argument(\"--hidden\", type=int, default=512)\nparser.add_argument(\"--num_conv\", type=int, default=4)\nparser.add_argument(\"--dropout\", type=float, default=0.0)\nparser.add_argument(\"--gat\", action=\"store_true\")\nparser.add_argument(\"--gat_k\", type=int, default=1)\nparser.add_argument(\"--balance\", action=\"store_true\")\nparser.add_argument(\"--use_cluster_feat\", action=\"store_true\")\nparser.add_argument(\"--use_focal_loss\", action=\"store_true\")\nparser.add_argument(\"--use_gt\", action=\"store_true\")\n\nargs = parser.parse_args()\n\n###########################\n# Environment Configuration\nif torch.cuda.is_available():\n    device = torch.device(\"cuda\")\nelse:\n    device = torch.device(\"cpu\")\n\n##################\n# Data Preparation\nwith open(args.data_path, \"rb\") as f:\n    features, labels = pickle.load(f)\nglobal_features = features.copy()\ndataset = LanderDataset(\n    features=features,\n    labels=labels,\n    k=args.knn_k,\n    levels=1,\n    
faiss_gpu=args.faiss_gpu,\n)\ng = dataset.gs[0].to(device)\nglobal_labels = labels.copy()\nids = np.arange(g.num_nodes())\nglobal_edges = ([], [])\nglobal_edges_len = len(global_edges[0])\nglobal_num_nodes = g.num_nodes()\n\n##################\n# Model Definition\nif not args.use_gt:\n    feature_dim = g.ndata[\"features\"].shape[1]\n    model = LANDER(\n        feature_dim=feature_dim,\n        nhid=args.hidden,\n        num_conv=args.num_conv,\n        dropout=args.dropout,\n        use_GAT=args.gat,\n        K=args.gat_k,\n        balance=args.balance,\n        use_cluster_feat=args.use_cluster_feat,\n        use_focal_loss=args.use_focal_loss,\n    )\n    model.load_state_dict(torch.load(args.model_filename, weights_only=False))\n    model = model.to(device)\n    model.eval()\n\n# number of edges added is the indicator for early stopping\nnum_edges_add_last_level = np.Inf\n##################################\n# Predict connectivity and density\nfor level in range(args.levels):\n    if not args.use_gt:\n        with torch.no_grad():\n            g = model(g)\n    (\n        new_pred_labels,\n        peaks,\n        global_edges,\n        global_pred_labels,\n        global_peaks,\n    ) = decode(\n        g,\n        args.tau,\n        args.threshold,\n        args.use_gt,\n        ids,\n        global_edges,\n        global_num_nodes,\n    )\n    ids = ids[peaks]\n    new_global_edges_len = len(global_edges[0])\n    num_edges_add_this_level = new_global_edges_len - global_edges_len\n    if stop_iterating(\n        level,\n        args.levels,\n        args.early_stop,\n        num_edges_add_this_level,\n        num_edges_add_last_level,\n        args.knn_k,\n    ):\n        break\n    global_edges_len = new_global_edges_len\n    num_edges_add_last_level = num_edges_add_this_level\n\n    # build new dataset\n    features, labels, cluster_features = build_next_level(\n        features,\n        labels,\n        peaks,\n        global_features,\n        
global_pred_labels,\n        global_peaks,\n    )\n    # After the first level, the number of nodes reduce a lot. Using cpu faiss is faster.\n    dataset = LanderDataset(\n        features=features,\n        labels=labels,\n        k=args.knn_k,\n        levels=1,\n        faiss_gpu=False,\n        cluster_features=cluster_features,\n    )\n    if len(dataset.gs) == 0:\n        break\n    g = dataset.gs[0].to(device)\nevaluation(global_pred_labels, global_labels, args.metrics)\n"
  },
  {
    "path": "examples/pytorch/hilander/test_subg.py",
    "content": "import argparse\nimport os\nimport pickle\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.optim as optim\nfrom dataset import LanderDataset\nfrom models import LANDER\nfrom utils import build_next_level, decode, evaluation, stop_iterating\n\n###########\n# ArgParser\nparser = argparse.ArgumentParser()\n\n# Dataset\nparser.add_argument(\"--data_path\", type=str, required=True)\nparser.add_argument(\"--model_filename\", type=str, default=\"lander.pth\")\nparser.add_argument(\"--faiss_gpu\", action=\"store_true\")\nparser.add_argument(\"--num_workers\", type=int, default=0)\n\n# HyperParam\nparser.add_argument(\"--knn_k\", type=int, default=10)\nparser.add_argument(\"--levels\", type=int, default=1)\nparser.add_argument(\"--tau\", type=float, default=0.5)\nparser.add_argument(\"--threshold\", type=str, default=\"prob\")\nparser.add_argument(\"--metrics\", type=str, default=\"pairwise,bcubed,nmi\")\nparser.add_argument(\"--early_stop\", action=\"store_true\")\n\n# Model\nparser.add_argument(\"--hidden\", type=int, default=512)\nparser.add_argument(\"--num_conv\", type=int, default=4)\nparser.add_argument(\"--dropout\", type=float, default=0.0)\nparser.add_argument(\"--gat\", action=\"store_true\")\nparser.add_argument(\"--gat_k\", type=int, default=1)\nparser.add_argument(\"--balance\", action=\"store_true\")\nparser.add_argument(\"--use_cluster_feat\", action=\"store_true\")\nparser.add_argument(\"--use_focal_loss\", action=\"store_true\")\nparser.add_argument(\"--use_gt\", action=\"store_true\")\n\n# Subgraph\nparser.add_argument(\"--batch_size\", type=int, default=4096)\n\nargs = parser.parse_args()\nprint(args)\n\n###########################\n# Environment Configuration\nif torch.cuda.is_available():\n    device = torch.device(\"cuda\")\nelse:\n    device = torch.device(\"cpu\")\n\n##################\n# Data Preparation\nwith open(args.data_path, \"rb\") as f:\n    features, labels = pickle.load(f)\nglobal_features = 
features.copy()\ndataset = LanderDataset(\n    features=features,\n    labels=labels,\n    k=args.knn_k,\n    levels=1,\n    faiss_gpu=args.faiss_gpu,\n)\ng = dataset.gs[0]\ng.ndata[\"pred_den\"] = torch.zeros((g.num_nodes()))\ng.edata[\"prob_conn\"] = torch.zeros((g.num_edges(), 2))\nglobal_labels = labels.copy()\nids = np.arange(g.num_nodes())\nglobal_edges = ([], [])\nglobal_peaks = np.array([], dtype=np.long)\nglobal_edges_len = len(global_edges[0])\nglobal_num_nodes = g.num_nodes()\n\nfanouts = [args.knn_k - 1 for i in range(args.num_conv + 1)]\nsampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)\n# fix the number of edges\ntest_loader = dgl.dataloading.DataLoader(\n    g,\n    torch.arange(g.num_nodes()),\n    sampler,\n    batch_size=args.batch_size,\n    shuffle=False,\n    drop_last=False,\n    num_workers=args.num_workers,\n)\n\n##################\n# Model Definition\nif not args.use_gt:\n    feature_dim = g.ndata[\"features\"].shape[1]\n    model = LANDER(\n        feature_dim=feature_dim,\n        nhid=args.hidden,\n        num_conv=args.num_conv,\n        dropout=args.dropout,\n        use_GAT=args.gat,\n        K=args.gat_k,\n        balance=args.balance,\n        use_cluster_feat=args.use_cluster_feat,\n        use_focal_loss=args.use_focal_loss,\n    )\n    model.load_state_dict(torch.load(args.model_filename, weights_only=False))\n    model = model.to(device)\n    model.eval()\n\n# number of edges added is the indicator for early stopping\nnum_edges_add_last_level = np.Inf\n##################################\n# Predict connectivity and density\nfor level in range(args.levels):\n    if not args.use_gt:\n        total_batches = len(test_loader)\n        for batch, minibatch in enumerate(test_loader):\n            input_nodes, sub_g, bipartites = minibatch\n            sub_g = sub_g.to(device)\n            bipartites = [b.to(device) for b in bipartites]\n            with torch.no_grad():\n                output_bipartite = model(bipartites)\n 
           global_nid = output_bipartite.dstdata[dgl.NID]\n            global_eid = output_bipartite.edata[\"global_eid\"]\n            g.ndata[\"pred_den\"][global_nid] = output_bipartite.dstdata[\n                \"pred_den\"\n            ].to(\"cpu\")\n            g.edata[\"prob_conn\"][global_eid] = output_bipartite.edata[\n                \"prob_conn\"\n            ].to(\"cpu\")\n            torch.cuda.empty_cache()\n            if (batch + 1) % 10 == 0:\n                print(\"Batch %d / %d for inference\" % (batch, total_batches))\n\n    (\n        new_pred_labels,\n        peaks,\n        global_edges,\n        global_pred_labels,\n        global_peaks,\n    ) = decode(\n        g,\n        args.tau,\n        args.threshold,\n        args.use_gt,\n        ids,\n        global_edges,\n        global_num_nodes,\n        global_peaks,\n    )\n    ids = ids[peaks]\n    new_global_edges_len = len(global_edges[0])\n    num_edges_add_this_level = new_global_edges_len - global_edges_len\n    if stop_iterating(\n        level,\n        args.levels,\n        args.early_stop,\n        num_edges_add_this_level,\n        num_edges_add_last_level,\n        args.knn_k,\n    ):\n        break\n    global_edges_len = new_global_edges_len\n    num_edges_add_last_level = num_edges_add_this_level\n\n    # build new dataset\n    features, labels, cluster_features = build_next_level(\n        features,\n        labels,\n        peaks,\n        global_features,\n        global_pred_labels,\n        global_peaks,\n    )\n    # After the first level, the number of nodes reduce a lot. 
Using cpu faiss is faster.\n    dataset = LanderDataset(\n        features=features,\n        labels=labels,\n        k=args.knn_k,\n        levels=1,\n        faiss_gpu=False,\n        cluster_features=cluster_features,\n    )\n    g = dataset.gs[0]\n    g.ndata[\"pred_den\"] = torch.zeros((g.num_nodes()))\n    g.edata[\"prob_conn\"] = torch.zeros((g.num_edges(), 2))\n    test_loader = dgl.dataloading.DataLoader(\n        g,\n        torch.arange(g.num_nodes()),\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=False,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\nevaluation(global_pred_labels, global_labels, args.metrics)\n"
  },
  {
    "path": "examples/pytorch/hilander/train.py",
    "content": "import argparse\nimport os\nimport pickle\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.optim as optim\nfrom dataset import LanderDataset\nfrom models import LANDER\n\n###########\n# ArgParser\nparser = argparse.ArgumentParser()\n\n# Dataset\nparser.add_argument(\"--data_path\", type=str, required=True)\nparser.add_argument(\"--test_data_path\", type=str, required=True)\nparser.add_argument(\"--levels\", type=str, default=\"1\")\nparser.add_argument(\"--faiss_gpu\", action=\"store_true\")\nparser.add_argument(\"--model_filename\", type=str, default=\"lander.pth\")\n\n# KNN\nparser.add_argument(\"--knn_k\", type=str, default=\"10\")\n\n# Model\nparser.add_argument(\"--hidden\", type=int, default=512)\nparser.add_argument(\"--num_conv\", type=int, default=4)\nparser.add_argument(\"--dropout\", type=float, default=0.0)\nparser.add_argument(\"--gat\", action=\"store_true\")\nparser.add_argument(\"--gat_k\", type=int, default=1)\nparser.add_argument(\"--balance\", action=\"store_true\")\nparser.add_argument(\"--use_cluster_feat\", action=\"store_true\")\nparser.add_argument(\"--use_focal_loss\", action=\"store_true\")\n\n# Training\nparser.add_argument(\"--epochs\", type=int, default=100)\nparser.add_argument(\"--lr\", type=float, default=0.1)\nparser.add_argument(\"--momentum\", type=float, default=0.9)\nparser.add_argument(\"--weight_decay\", type=float, default=1e-5)\n\nargs = parser.parse_args()\n\n###########################\n# Environment Configuration\nif torch.cuda.is_available():\n    device = torch.device(\"cuda\")\nelse:\n    device = torch.device(\"cpu\")\n\n\n##################\n# Data Preparation\ndef prepare_dataset_graphs(data_path, k_list, lvl_list):\n    with open(data_path, \"rb\") as f:\n        features, labels = pickle.load(f)\n    gs = []\n    for k, l in zip(k_list, lvl_list):\n        dataset = LanderDataset(\n            features=features,\n            labels=labels,\n            k=k,\n            
levels=l,\n            faiss_gpu=args.faiss_gpu,\n        )\n        gs += [g.to(device) for g in dataset.gs]\n    return gs\n\n\nk_list = [int(k) for k in args.knn_k.split(\",\")]\nlvl_list = [int(l) for l in args.levels.split(\",\")]\ngs = prepare_dataset_graphs(args.data_path, k_list, lvl_list)\ntest_gs = prepare_dataset_graphs(args.test_data_path, k_list, lvl_list)\n\n##################\n# Model Definition\nfeature_dim = gs[0].ndata[\"features\"].shape[1]\nmodel = LANDER(\n    feature_dim=feature_dim,\n    nhid=args.hidden,\n    num_conv=args.num_conv,\n    dropout=args.dropout,\n    use_GAT=args.gat,\n    K=args.gat_k,\n    balance=args.balance,\n    use_cluster_feat=args.use_cluster_feat,\n    use_focal_loss=args.use_focal_loss,\n)\nmodel = model.to(device)\nmodel.train()\nbest_model = None\nbest_loss = np.Inf\n\n#################\n# Hyperparameters\nopt = optim.SGD(\n    model.parameters(),\n    lr=args.lr,\n    momentum=args.momentum,\n    weight_decay=args.weight_decay,\n)\nscheduler = optim.lr_scheduler.CosineAnnealingLR(\n    opt, T_max=args.epochs, eta_min=1e-5\n)\n\n###############\n# Training Loop\nfor epoch in range(args.epochs):\n    all_loss_den_val = 0\n    all_loss_conn_val = 0\n    for g in gs:\n        opt.zero_grad()\n        g = model(g)\n        loss, loss_den_val, loss_conn_val = model.compute_loss(g)\n        all_loss_den_val += loss_den_val\n        all_loss_conn_val += loss_conn_val\n        loss.backward()\n        opt.step()\n    scheduler.step()\n    print(\n        \"Training, epoch: %d, loss_den: %.6f, loss_conn: %.6f\"\n        % (epoch, all_loss_den_val, all_loss_conn_val)\n    )\n    # Report test\n    all_test_loss_den_val = 0\n    all_test_loss_conn_val = 0\n    with torch.no_grad():\n        for g in test_gs:\n            g = model(g)\n            loss, loss_den_val, loss_conn_val = model.compute_loss(g)\n            all_test_loss_den_val += loss_den_val\n            all_test_loss_conn_val += loss_conn_val\n    print(\n        
\"Testing, epoch: %d, loss_den: %.6f, loss_conn: %.6f\"\n        % (epoch, all_test_loss_den_val, all_test_loss_conn_val)\n    )\n    if all_test_loss_conn_val + all_test_loss_den_val < best_loss:\n        best_loss = all_test_loss_conn_val + all_test_loss_den_val\n        print(\"New best epoch\", epoch)\n        torch.save(model.state_dict(), args.model_filename + \"_best\")\n    torch.save(model.state_dict(), args.model_filename)\n\ntorch.save(model.state_dict(), args.model_filename)\n"
  },
  {
    "path": "examples/pytorch/hilander/train_subg.py",
    "content": "import argparse\nimport os\nimport pickle\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.optim as optim\nfrom dataset import LanderDataset\nfrom models import LANDER\n\n###########\n# ArgParser\nparser = argparse.ArgumentParser()\n\n# Dataset\nparser.add_argument(\"--data_path\", type=str, required=True)\nparser.add_argument(\"--levels\", type=str, default=\"1\")\nparser.add_argument(\"--faiss_gpu\", action=\"store_true\")\nparser.add_argument(\"--model_filename\", type=str, default=\"lander.pth\")\n\n# KNN\nparser.add_argument(\"--knn_k\", type=str, default=\"10\")\nparser.add_argument(\"--num_workers\", type=int, default=0)\n\n# Model\nparser.add_argument(\"--hidden\", type=int, default=512)\nparser.add_argument(\"--num_conv\", type=int, default=1)\nparser.add_argument(\"--dropout\", type=float, default=0.0)\nparser.add_argument(\"--gat\", action=\"store_true\")\nparser.add_argument(\"--gat_k\", type=int, default=1)\nparser.add_argument(\"--balance\", action=\"store_true\")\nparser.add_argument(\"--use_cluster_feat\", action=\"store_true\")\nparser.add_argument(\"--use_focal_loss\", action=\"store_true\")\n\n# Training\nparser.add_argument(\"--epochs\", type=int, default=100)\nparser.add_argument(\"--batch_size\", type=int, default=1024)\nparser.add_argument(\"--lr\", type=float, default=0.1)\nparser.add_argument(\"--momentum\", type=float, default=0.9)\nparser.add_argument(\"--weight_decay\", type=float, default=1e-5)\n\nargs = parser.parse_args()\nprint(args)\n\n###########################\n# Environment Configuration\nif torch.cuda.is_available():\n    device = torch.device(\"cuda\")\nelse:\n    device = torch.device(\"cpu\")\n\n##################\n# Data Preparation\nwith open(args.data_path, \"rb\") as f:\n    features, labels = pickle.load(f)\n\nk_list = [int(k) for k in args.knn_k.split(\",\")]\nlvl_list = [int(l) for l in args.levels.split(\",\")]\ngs = []\nnbrs = []\nks = []\nfor k, l in zip(k_list, 
lvl_list):\n    dataset = LanderDataset(\n        features=features,\n        labels=labels,\n        k=k,\n        levels=l,\n        faiss_gpu=args.faiss_gpu,\n    )\n    gs += [g for g in dataset.gs]\n    ks += [k for g in dataset.gs]\n    nbrs += [nbr for nbr in dataset.nbrs]\n\nprint(\"Dataset Prepared.\")\n\n\ndef set_train_sampler_loader(g, k):\n    fanouts = [k - 1 for i in range(args.num_conv + 1)]\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)\n    # fix the number of edges\n    train_dataloader = dgl.dataloading.DataLoader(\n        g,\n        torch.arange(g.num_nodes()),\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\n    return train_dataloader\n\n\ntrain_loaders = []\nfor gidx, g in enumerate(gs):\n    train_dataloader = set_train_sampler_loader(gs[gidx], ks[gidx])\n    train_loaders.append(train_dataloader)\n\n##################\n# Model Definition\nfeature_dim = gs[0].ndata[\"features\"].shape[1]\nmodel = LANDER(\n    feature_dim=feature_dim,\n    nhid=args.hidden,\n    num_conv=args.num_conv,\n    dropout=args.dropout,\n    use_GAT=args.gat,\n    K=args.gat_k,\n    balance=args.balance,\n    use_cluster_feat=args.use_cluster_feat,\n    use_focal_loss=args.use_focal_loss,\n)\nmodel = model.to(device)\nmodel.train()\n\n#################\n# Hyperparameters\nopt = optim.SGD(\n    model.parameters(),\n    lr=args.lr,\n    momentum=args.momentum,\n    weight_decay=args.weight_decay,\n)\n\n# keep num_batch_per_loader the same for every sub_dataloader\nnum_batch_per_loader = len(train_loaders[0])\ntrain_loaders = [iter(train_loader) for train_loader in train_loaders]\nnum_loaders = len(train_loaders)\nscheduler = optim.lr_scheduler.CosineAnnealingLR(\n    opt, T_max=args.epochs * num_batch_per_loader * num_loaders, eta_min=1e-5\n)\n\nprint(\"Start Training.\")\n\n###############\n# Training Loop\nfor epoch in range(args.epochs):\n    
loss_den_val_total = []\n    loss_conn_val_total = []\n    loss_val_total = []\n    for batch in range(num_batch_per_loader):\n        for loader_id in range(num_loaders):\n            try:\n                minibatch = next(train_loaders[loader_id])\n            except:\n                train_loaders[loader_id] = iter(\n                    set_train_sampler_loader(gs[loader_id], ks[loader_id])\n                )\n                minibatch = next(train_loaders[loader_id])\n            input_nodes, sub_g, bipartites = minibatch\n            sub_g = sub_g.to(device)\n            bipartites = [b.to(device) for b in bipartites]\n            # get the feature for the input_nodes\n            opt.zero_grad()\n            output_bipartite = model(bipartites)\n            loss, loss_den_val, loss_conn_val = model.compute_loss(\n                output_bipartite\n            )\n            loss_den_val_total.append(loss_den_val)\n            loss_conn_val_total.append(loss_conn_val)\n            loss_val_total.append(loss.item())\n            loss.backward()\n            opt.step()\n            if (batch + 1) % 10 == 0:\n                print(\n                    \"epoch: %d, batch: %d / %d, loader_id : %d / %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f\"\n                    % (\n                        epoch,\n                        batch,\n                        num_batch_per_loader,\n                        loader_id,\n                        num_loaders,\n                        loss.item(),\n                        loss_den_val,\n                        loss_conn_val,\n                    )\n                )\n            scheduler.step()\n    print(\n        \"epoch: %d, loss: %.6f, loss_den: %.6f, loss_conn: %.6f\"\n        % (\n            epoch,\n            np.array(loss_val_total).mean(),\n            np.array(loss_den_val_total).mean(),\n            np.array(loss_conn_val_total).mean(),\n        )\n    )\n    torch.save(model.state_dict(), 
args.model_filename)\n\ntorch.save(model.state_dict(), args.model_filename)\n"
  },
  {
    "path": "examples/pytorch/hilander/utils/__init__.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nfrom .adjacency import *\nfrom .deduce import *\nfrom .density import *\nfrom .evaluate import *\nfrom .faiss_gpu import faiss_search_approx_knn\nfrom .faiss_search import faiss_search_knn\nfrom .knn import *\nfrom .metrics import *\nfrom .misc import *\n"
  },
  {
    "path": "examples/pytorch/hilander/utils/adjacency.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\"\"\"\nThis file re-uses implementation from https://github.com/yl-1993/learn-to-cluster\n\"\"\"\n\nimport numpy as np\nimport scipy.sparse as sp\nfrom scipy.sparse import coo_matrix\n\n\ndef row_normalize(mx):\n    \"\"\"Row-normalize sparse matrix\"\"\"\n    rowsum = np.array(mx.sum(1))\n    # if rowsum <= 0, keep its previous value\n    rowsum[rowsum <= 0] = 1\n    r_inv = np.power(rowsum, -1).flatten()\n    r_inv[np.isinf(r_inv)] = 0.0\n    r_mat_inv = sp.diags(r_inv)\n    mx = r_mat_inv.dot(mx)\n    return mx, r_inv\n\n\ndef sparse_mx_to_indices_values(sparse_mx):\n    sparse_mx = sparse_mx.tocoo().astype(np.float32)\n    indices = np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)\n    values = sparse_mx.data\n    shape = np.array(sparse_mx.shape)\n    return indices, values, shape\n"
  },
  {
    "path": "examples/pytorch/hilander/utils/deduce.py",
    "content": "\"\"\"\nThis file re-uses implementation from https://github.com/yl-1993/learn-to-cluster\n\"\"\"\nimport dgl\nimport numpy as np\nimport torch\nfrom sklearn import mixture\n\nfrom .density import density_to_peaks, density_to_peaks_vectorize\n\n__all__ = [\n    \"peaks_to_labels\",\n    \"edge_to_connected_graph\",\n    \"decode\",\n    \"build_next_level\",\n]\n\n\ndef _find_parent(parent, u):\n    idx = []\n    # parent is a fixed point\n    while u != parent[u]:\n        idx.append(u)\n        u = parent[u]\n    for i in idx:\n        parent[i] = u\n    return u\n\n\ndef edge_to_connected_graph(edges, num):\n    parent = list(range(num))\n    for u, v in edges:\n        p_u = _find_parent(parent, u)\n        p_v = _find_parent(parent, v)\n        parent[p_u] = p_v\n\n    for i in range(num):\n        parent[i] = _find_parent(parent, i)\n    remap = {}\n    uf = np.unique(np.array(parent))\n    for i, f in enumerate(uf):\n        remap[f] = i\n    cluster_id = np.array([remap[f] for f in parent])\n    return cluster_id\n\n\ndef peaks_to_edges(peaks, dist2peak, tau):\n    edges = []\n    for src in peaks:\n        dsts = peaks[src]\n        dists = dist2peak[src]\n        for dst, dist in zip(dsts, dists):\n            if src == dst or dist >= 1 - tau:\n                continue\n            edges.append([src, dst])\n    return edges\n\n\ndef peaks_to_labels(peaks, dist2peak, tau, inst_num):\n    edges = peaks_to_edges(peaks, dist2peak, tau)\n    pred_labels = edge_to_connected_graph(edges, inst_num)\n    return pred_labels, edges\n\n\ndef get_dists(g, nbrs, use_gt):\n    k = nbrs.shape[1]\n    src_id = nbrs[:, 1:].reshape(-1)\n    dst_id = nbrs[:, 0].repeat(k - 1)\n    eids = g.edge_ids(src_id, dst_id)\n    if use_gt:\n        new_dists = (\n            (1 - g.edata[\"labels_edge\"][eids]).reshape(-1, k - 1).float()\n        )\n    else:\n        new_dists = g.edata[\"prob_conn\"][eids, 0].reshape(-1, k - 1)\n    ind = torch.argsort(new_dists, 1)\n 
   offset = torch.LongTensor(\n        (nbrs[:, 0] * (k - 1)).repeat(k - 1).reshape(-1, k - 1)\n    ).to(g.device)\n    ind = ind + offset\n    nbrs = torch.LongTensor(nbrs).to(g.device)\n    new_nbrs = torch.take(nbrs[:, 1:], ind)\n    new_dists = torch.cat(\n        [torch.zeros((new_dists.shape[0], 1)).to(g.device), new_dists], dim=1\n    )\n    new_nbrs = torch.cat(\n        [torch.arange(new_nbrs.shape[0]).view(-1, 1).to(g.device), new_nbrs],\n        dim=1,\n    )\n    return new_nbrs.cpu().detach().numpy(), new_dists.cpu().detach().numpy()\n\n\ndef get_edge_dist(g, threshold):\n    if threshold == \"prob\":\n        return g.edata[\"prob_conn\"][:, 0]\n    return 1 - g.edata[\"raw_affine\"]\n\n\ndef tree_generation(ng):\n    ng.ndata[\"keep_eid\"] = torch.zeros(ng.num_nodes()).long() - 1\n\n    def message_func(edges):\n        return {\"mval\": edges.data[\"edge_dist\"], \"meid\": edges.data[dgl.EID]}\n\n    def reduce_func(nodes):\n        ind = torch.min(nodes.mailbox[\"mval\"], dim=1)[1]\n        keep_eid = nodes.mailbox[\"meid\"].gather(1, ind.view(-1, 1))\n        return {\"keep_eid\": keep_eid[:, 0]}\n\n    node_order = dgl.traversal.topological_nodes_generator(ng)\n    ng.prop_nodes(node_order, message_func, reduce_func)\n    eids = ng.ndata[\"keep_eid\"]\n    eids = eids[eids > -1]\n    edges = ng.find_edges(eids)\n    treeg = dgl.graph(edges, num_nodes=ng.num_nodes())\n    return treeg\n\n\ndef peak_propogation(treeg):\n    treeg.ndata[\"pred_labels\"] = torch.zeros(treeg.num_nodes()).long() - 1\n    peaks = torch.where(treeg.in_degrees() == 0)[0].cpu().numpy()\n    treeg.ndata[\"pred_labels\"][peaks] = torch.arange(peaks.shape[0])\n\n    def message_func(edges):\n        return {\"mlb\": edges.src[\"pred_labels\"]}\n\n    def reduce_func(nodes):\n        return {\"pred_labels\": nodes.mailbox[\"mlb\"][:, 0]}\n\n    node_order = dgl.traversal.topological_nodes_generator(treeg)\n    treeg.prop_nodes(node_order, message_func, reduce_func)\n    
pred_labels = treeg.ndata[\"pred_labels\"].cpu().numpy()\n    return peaks, pred_labels\n\n\ndef decode(\n    g,\n    tau,\n    threshold,\n    use_gt,\n    ids=None,\n    global_edges=None,\n    global_num_nodes=None,\n    global_peaks=None,\n):\n    # Edge filtering with tau and density\n    den_key = \"density\" if use_gt else \"pred_den\"\n    g = g.local_var()\n    g.edata[\"edge_dist\"] = get_edge_dist(g, threshold)\n    g.apply_edges(\n        lambda edges: {\n            \"keep\": (edges.src[den_key] > edges.dst[den_key]).long()\n            * (edges.data[\"edge_dist\"] < 1 - tau).long()\n        }\n    )\n    eids = torch.where(g.edata[\"keep\"] == 0)[0]\n    ng = dgl.remove_edges(g, eids)\n\n    # Tree generation\n    ng.edata[dgl.EID] = torch.arange(ng.num_edges())\n    treeg = tree_generation(ng)\n    # Label propogation\n    peaks, pred_labels = peak_propogation(treeg)\n\n    if ids is None:\n        return pred_labels, peaks\n\n    # Merge with previous layers\n    src, dst = treeg.edges()\n    new_global_edges = (\n        global_edges[0] + ids[src.numpy()].tolist(),\n        global_edges[1] + ids[dst.numpy()].tolist(),\n    )\n    global_treeg = dgl.graph(new_global_edges, num_nodes=global_num_nodes)\n    global_peaks, global_pred_labels = peak_propogation(global_treeg)\n    return (\n        pred_labels,\n        peaks,\n        new_global_edges,\n        global_pred_labels,\n        global_peaks,\n    )\n\n\ndef build_next_level(\n    features, labels, peaks, global_features, global_pred_labels, global_peaks\n):\n    global_peak_to_label = global_pred_labels[global_peaks]\n    global_label_to_peak = np.zeros_like(global_peak_to_label)\n    for i, pl in enumerate(global_peak_to_label):\n        global_label_to_peak[pl] = i\n    cluster_ind = np.split(\n        np.argsort(global_pred_labels),\n        np.unique(np.sort(global_pred_labels), return_index=True)[1][1:],\n    )\n    cluster_features = np.zeros((len(peaks), global_features.shape[1]))\n    
for pi in range(len(peaks)):\n        cluster_features[global_label_to_peak[pi], :] = np.mean(\n            global_features[cluster_ind[pi], :], axis=0\n        )\n    features = features[peaks]\n    labels = labels[peaks]\n    return features, labels, cluster_features\n"
  },
  {
    "path": "examples/pytorch/hilander/utils/density.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\"\"\"\nThis file re-uses implementation from https://github.com/yl-1993/learn-to-cluster\n\"\"\"\n\nfrom itertools import groupby\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\n__all__ = [\n    \"density_estimation\",\n    \"density_to_peaks\",\n    \"density_to_peaks_vectorize\",\n]\n\n\ndef density_estimation(dists, nbrs, labels, **kwargs):\n    \"\"\"use supervised density defined on neigborhood\"\"\"\n    num, k_knn = dists.shape\n    conf = np.ones((num,), dtype=np.float32)\n    ind_array = labels[nbrs] == np.expand_dims(labels, 1).repeat(k_knn, 1)\n    pos = ((1 - dists[:, 1:]) * ind_array[:, 1:]).sum(1)\n    neg = ((1 - dists[:, 1:]) * (1 - ind_array[:, 1:])).sum(1)\n    conf = (pos - neg) * conf\n    conf /= k_knn - 1\n    return conf\n\n\ndef density_to_peaks_vectorize(dists, nbrs, density, max_conn=1, name=\"\"):\n    # just calculate 1 connectivity\n    assert dists.shape[0] == density.shape[0]\n    assert dists.shape == nbrs.shape\n\n    num, k = dists.shape\n\n    if name == \"gcn_feat\":\n        include_mask = nbrs != np.arange(0, num).reshape(-1, 1)\n        secondary_mask = (\n            np.sum(include_mask, axis=1) == k\n        )  # TODO: the condition == k should not happen as distance to the node self should be smallest, check for numerical stability; TODO: make top M instead of only supporting top 1\n        include_mask[secondary_mask, -1] = False\n        nbrs_exclude_self = nbrs[include_mask].reshape(-1, k - 1)  # (V, 79)\n        dists_exclude_self = dists[include_mask].reshape(-1, k - 1)  # (V, 79)\n    else:\n        include_mask = nbrs != np.arange(0, num).reshape(-1, 1)\n        nbrs_exclude_self = nbrs[include_mask].reshape(-1, k - 1)  # (V, 79)\n        dists_exclude_self = dists[include_mask].reshape(-1, k - 1)  # (V, 79)\n\n    compare_map = density[nbrs_exclude_self] > density.reshape(-1, 1)\n    peak_index = np.argmax(np.where(compare_map, 1, 0), 
axis=1)  # (V,)\n    compare_map_sum = np.sum(compare_map.cpu().data.numpy(), axis=1)  # (V,)\n\n    dist2peak = {\n        i: []\n        if compare_map_sum[i] == 0\n        else [dists_exclude_self[i, peak_index[i]]]\n        for i in range(num)\n    }\n    peaks = {\n        i: []\n        if compare_map_sum[i] == 0\n        else [nbrs_exclude_self[i, peak_index[i]]]\n        for i in range(num)\n    }\n\n    return dist2peak, peaks\n\n\ndef density_to_peaks(dists, nbrs, density, max_conn=1, sort=\"dist\"):\n    # Note that dists has been sorted in ascending order\n    assert dists.shape[0] == density.shape[0]\n    assert dists.shape == nbrs.shape\n\n    num, _ = dists.shape\n    dist2peak = {i: [] for i in range(num)}\n    peaks = {i: [] for i in range(num)}\n\n    for i, nbr in tqdm(enumerate(nbrs)):\n        nbr_conf = density[nbr]\n        for j, c in enumerate(nbr_conf):\n            nbr_idx = nbr[j]\n            if i == nbr_idx or c <= density[i]:\n                continue\n            dist2peak[i].append(dists[i, j])\n            peaks[i].append(nbr_idx)\n            if len(dist2peak[i]) >= max_conn:\n                break\n\n    return dist2peak, peaks\n"
  },
  {
    "path": "examples/pytorch/hilander/utils/evaluate.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nimport argparse\nimport inspect\n\nimport numpy as np\nfrom clustering_benchmark import ClusteringBenchmark\nfrom utils import metrics, TextColors, Timer\n\n\ndef _read_meta(fn):\n    labels = list()\n    lb_set = set()\n    with open(fn) as f:\n        for lb in f.readlines():\n            lb = int(lb.strip())\n            labels.append(lb)\n            lb_set.add(lb)\n    return np.array(labels), lb_set\n\n\ndef evaluate(gt_labels, pred_labels, metric=\"pairwise\"):\n    if isinstance(gt_labels, str) and isinstance(pred_labels, str):\n        print(\"[gt_labels] {}\".format(gt_labels))\n        print(\"[pred_labels] {}\".format(pred_labels))\n        gt_labels, gt_lb_set = _read_meta(gt_labels)\n        pred_labels, pred_lb_set = _read_meta(pred_labels)\n\n        print(\n            \"#inst: gt({}) vs pred({})\".format(len(gt_labels), len(pred_labels))\n        )\n        print(\n            \"#cls: gt({}) vs pred({})\".format(len(gt_lb_set), len(pred_lb_set))\n        )\n\n    metric_func = metrics.__dict__[metric]\n\n    with Timer(\n        \"evaluate with {}{}{}\".format(TextColors.FATAL, metric, TextColors.ENDC)\n    ):\n        result = metric_func(gt_labels, pred_labels)\n    if isinstance(result, float):\n        print(\n            \"{}{}: {:.4f}{}\".format(\n                TextColors.OKGREEN, metric, result, TextColors.ENDC\n            )\n        )\n    else:\n        ave_pre, ave_rec, fscore = result\n        print(\n            \"{}ave_pre: {:.4f}, ave_rec: {:.4f}, fscore: {:.4f}{}\".format(\n                TextColors.OKGREEN, ave_pre, ave_rec, fscore, TextColors.ENDC\n            )\n        )\n\n\ndef evaluation(pred_labels, labels, metrics):\n    print(\"==> evaluation\")\n    # pred_labels = g.ndata['pred_labels'].cpu().numpy()\n    max_cluster = np.max(pred_labels)\n    # gt_labels_all = g.ndata['labels'].cpu().numpy()\n    gt_labels_all = labels\n    pred_labels_all = 
pred_labels\n    metric_list = metrics.split(\",\")\n    for metric in metric_list:\n        evaluate(gt_labels_all, pred_labels_all, metric)\n    # H and C-scores\n    gt_dict = {}\n    pred_dict = {}\n    for i in range(len(gt_labels_all)):\n        gt_dict[str(i)] = gt_labels_all[i]\n        pred_dict[str(i)] = pred_labels_all[i]\n    bm = ClusteringBenchmark(gt_dict)\n    scores = bm.evaluate_vmeasure(pred_dict)\n    fmi_scores = bm.evaluate_fowlkes_mallows_score(pred_dict)\n    print(scores)\n"
  },
  {
    "path": "examples/pytorch/hilander/utils/faiss_gpu.py",
    "content": "\"\"\"\nThis file re-uses implementation from https://github.com/yl-1993/learn-to-cluster\n\"\"\"\nimport gc\nimport os\n\nimport faiss\nimport numpy as np\nfrom tqdm import tqdm\n\n__all__ = [\"faiss_search_approx_knn\"]\n\n\nclass faiss_index_wrapper:\n    def __init__(\n        self,\n        target,\n        nprobe=128,\n        index_factory_str=None,\n        verbose=False,\n        mode=\"proxy\",\n        using_gpu=True,\n    ):\n        self._res_list = []\n\n        num_gpu = faiss.get_num_gpus()\n        print(\"[faiss gpu] #GPU: {}\".format(num_gpu))\n\n        size, dim = target.shape\n        assert size > 0, \"size: {}\".format(size)\n        index_factory_str = (\n            \"IVF{},PQ{}\".format(min(8192, 16 * round(np.sqrt(size))), 32)\n            if index_factory_str is None\n            else index_factory_str\n        )\n        cpu_index = faiss.index_factory(dim, index_factory_str)\n        cpu_index.nprobe = nprobe\n\n        if mode == \"proxy\":\n            co = faiss.GpuClonerOptions()\n            co.useFloat16 = True\n            co.usePrecomputed = False\n\n            index = faiss.IndexProxy()\n            for i in range(num_gpu):\n                res = faiss.StandardGpuResources()\n                self._res_list.append(res)\n                sub_index = (\n                    faiss.index_cpu_to_gpu(res, i, cpu_index, co)\n                    if using_gpu\n                    else cpu_index\n                )\n                index.addIndex(sub_index)\n        elif mode == \"shard\":\n            co = faiss.GpuMultipleClonerOptions()\n            co.useFloat16 = True\n            co.usePrecomputed = False\n            co.shard = True\n            index = faiss.index_cpu_to_all_gpus(cpu_index, co, ngpu=num_gpu)\n        else:\n            raise KeyError(\"Unknown index mode\")\n\n        index = faiss.IndexIDMap(index)\n        index.verbose = verbose\n\n        # get nlist to decide how many samples used for 
training\n        nlist = int(\n            float(\n                [\n                    item\n                    for item in index_factory_str.split(\",\")\n                    if \"IVF\" in item\n                ][0].replace(\"IVF\", \"\")\n            )\n        )\n\n        # training\n        if not index.is_trained:\n            indexes_sample_for_train = np.random.randint(0, size, nlist * 256)\n            index.train(target[indexes_sample_for_train])\n\n        # add with ids\n        target_ids = np.arange(0, size)\n        index.add_with_ids(target, target_ids)\n        self.index = index\n\n    def search(self, *args, **kargs):\n        return self.index.search(*args, **kargs)\n\n    def __del__(self):\n        self.index.reset()\n        del self.index\n        for res in self._res_list:\n            del res\n\n\ndef batch_search(index, query, k, bs, verbose=False):\n    n = len(query)\n    dists = np.zeros((n, k), dtype=np.float32)\n    nbrs = np.zeros((n, k), dtype=np.int64)\n\n    for sid in tqdm(\n        range(0, n, bs), desc=\"faiss searching...\", disable=not verbose\n    ):\n        eid = min(n, sid + bs)\n        dists[sid:eid], nbrs[sid:eid] = index.search(query[sid:eid], k)\n    return dists, nbrs\n\n\ndef faiss_search_approx_knn(\n    query,\n    target,\n    k,\n    nprobe=128,\n    bs=int(1e6),\n    index_factory_str=None,\n    verbose=False,\n):\n    index = faiss_index_wrapper(\n        target,\n        nprobe=nprobe,\n        index_factory_str=index_factory_str,\n        verbose=verbose,\n    )\n    dists, nbrs = batch_search(index, query, k=k, bs=bs, verbose=verbose)\n\n    del index\n    gc.collect()\n    return dists, nbrs\n"
  },
  {
    "path": "examples/pytorch/hilander/utils/faiss_search.py",
    "content": "\"\"\"\nThis file re-uses implementation from https://github.com/yl-1993/learn-to-cluster\n\"\"\"\nimport gc\n\nfrom tqdm import tqdm\n\nfrom .faiss_gpu import faiss_search_approx_knn\n\n__all__ = [\"faiss_search_knn\"]\n\n\ndef precise_dist(feat, nbrs, num_process=4, sort=True, verbose=False):\n    import torch\n\n    feat_share = torch.from_numpy(feat).share_memory_()\n    nbrs_share = torch.from_numpy(nbrs).share_memory_()\n    dist_share = torch.zeros_like(nbrs_share).float().share_memory_()\n\n    precise_dist_share_mem(\n        feat_share,\n        nbrs_share,\n        dist_share,\n        num_process=num_process,\n        sort=sort,\n        verbose=verbose,\n    )\n\n    del feat_share\n    gc.collect()\n    return dist_share.numpy(), nbrs_share.numpy()\n\n\ndef precise_dist_share_mem(\n    feat,\n    nbrs,\n    dist,\n    num_process=16,\n    sort=True,\n    process_unit=4000,\n    verbose=False,\n):\n    from torch import multiprocessing as mp\n\n    num, _ = feat.shape\n    num_per_proc = int(num / num_process) + 1\n\n    for pi in range(num_process):\n        sid = pi * num_per_proc\n        eid = min(sid + num_per_proc, num)\n\n        kwargs = {\n            \"feat\": feat,\n            \"nbrs\": nbrs,\n            \"dist\": dist,\n            \"sid\": sid,\n            \"eid\": eid,\n            \"sort\": sort,\n            \"process_unit\": process_unit,\n            \"verbose\": verbose,\n        }\n        bmm(**kwargs)\n\n\ndef bmm(\n    feat, nbrs, dist, sid, eid, sort=True, process_unit=4000, verbose=False\n):\n    import torch\n\n    _, cols = dist.shape\n    batch_sim = torch.zeros((eid - sid, cols), dtype=torch.float32)\n    for s in tqdm(\n        range(sid, eid, process_unit), desc=\"bmm\", disable=not verbose\n    ):\n        e = min(eid, s + process_unit)\n        query = feat[s:e].unsqueeze(1)\n        gallery = feat[nbrs[s:e]].permute(0, 2, 1)\n        batch_sim[s - sid : e - sid] = torch.clamp(\n            
torch.bmm(query, gallery).view(-1, cols), 0.0, 1.0\n        )\n\n    if sort:\n        sort_unit = int(1e6)\n        batch_nbr = nbrs[sid:eid]\n        for s in range(0, batch_sim.shape[0], sort_unit):\n            e = min(s + sort_unit, eid)\n            batch_sim[s:e], indices = torch.sort(\n                batch_sim[s:e], descending=True\n            )\n            batch_nbr[s:e] = torch.gather(batch_nbr[s:e], 1, indices)\n        nbrs[sid:eid] = batch_nbr\n    dist[sid:eid] = 1.0 - batch_sim\n\n\ndef faiss_search_knn(\n    feat,\n    k,\n    nprobe=128,\n    num_process=4,\n    is_precise=True,\n    sort=True,\n    verbose=False,\n):\n    dists, nbrs = faiss_search_approx_knn(\n        query=feat, target=feat, k=k, nprobe=nprobe, verbose=verbose\n    )\n\n    if is_precise:\n        print(\"compute precise dist among k={} nearest neighbors\".format(k))\n        dists, nbrs = precise_dist(\n            feat, nbrs, num_process=num_process, sort=sort, verbose=verbose\n        )\n\n    return dists, nbrs\n"
  },
  {
    "path": "examples/pytorch/hilander/utils/knn.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\"\"\"\nThis file re-uses implementation from https://github.com/yl-1993/learn-to-cluster\n\"\"\"\n\nimport math\nimport multiprocessing as mp\nimport os\n\nimport numpy as np\nfrom tqdm import tqdm\nfrom utils import Timer\n\nfrom .faiss_search import faiss_search_knn\n\n__all__ = [\n    \"knn_faiss\",\n    \"knn_faiss_gpu\",\n    \"fast_knns2spmat\",\n    \"build_knns\",\n    \"knns2ordered_nbrs\",\n]\n\n\ndef knns2ordered_nbrs(knns, sort=True):\n    if isinstance(knns, list):\n        knns = np.array(knns)\n    nbrs = knns[:, 0, :].astype(np.int32)\n    dists = knns[:, 1, :]\n    if sort:\n        # sort dists from low to high\n        nb_idx = np.argsort(dists, axis=1)\n        idxs = np.arange(nb_idx.shape[0]).reshape(-1, 1)\n        dists = dists[idxs, nb_idx]\n        nbrs = nbrs[idxs, nb_idx]\n    return dists, nbrs\n\n\ndef fast_knns2spmat(knns, k, th_sim=0, use_sim=True, fill_value=None):\n    # convert knns to symmetric sparse matrix\n    from scipy.sparse import csr_matrix\n\n    eps = 1e-5\n    n = len(knns)\n    if isinstance(knns, list):\n        knns = np.array(knns)\n    if len(knns.shape) == 2:\n        # knns saved by hnsw has different shape\n        n = len(knns)\n        ndarr = np.ones([n, 2, k])\n        ndarr[:, 0, :] = -1  # assign unknown dist to 1 and nbr to -1\n        for i, (nbr, dist) in enumerate(knns):\n            size = len(nbr)\n            assert size == len(dist)\n            ndarr[i, 0, :size] = nbr[:size]\n            ndarr[i, 1, :size] = dist[:size]\n        knns = ndarr\n    nbrs = knns[:, 0, :]\n    dists = knns[:, 1, :]\n    assert (\n        -eps <= dists.min() <= dists.max() <= 1 + eps\n    ), \"min: {}, max: {}\".format(dists.min(), dists.max())\n    if use_sim:\n        sims = 1.0 - dists\n    else:\n        sims = dists\n    if fill_value is not None:\n        print(\"[fast_knns2spmat] edge fill value:\", fill_value)\n        sims.fill(fill_value)\n    
row, col = np.where(sims >= th_sim)\n    # remove the self-loop\n    idxs = np.where(row != nbrs[row, col])\n    row = row[idxs]\n    col = col[idxs]\n    data = sims[row, col]\n    col = nbrs[row, col]  # convert to absolute column\n    assert len(row) == len(col) == len(data)\n    spmat = csr_matrix((data, (row, col)), shape=(n, n))\n    return spmat\n\n\ndef build_knns(feats, k, knn_method, dump=True):\n    with Timer(\"build index\"):\n        if knn_method == \"faiss\":\n            index = knn_faiss(feats, k, omp_num_threads=None)\n        elif knn_method == \"faiss_gpu\":\n            index = knn_faiss_gpu(feats, k)\n        else:\n            raise KeyError(\n                \"Only support faiss and faiss_gpu currently ({}).\".format(\n                    knn_method\n                )\n            )\n        knns = index.get_knns()\n    return knns\n\n\nclass knn:\n    def __init__(self, feats, k, index_path=\"\", verbose=True):\n        pass\n\n    def filter_by_th(self, i):\n        th_nbrs = []\n        th_dists = []\n        nbrs, dists = self.knns[i]\n        for n, dist in zip(nbrs, dists):\n            if 1 - dist < self.th:\n                continue\n            th_nbrs.append(n)\n            th_dists.append(dist)\n        th_nbrs = np.array(th_nbrs)\n        th_dists = np.array(th_dists)\n        return (th_nbrs, th_dists)\n\n    def get_knns(self, th=None):\n        if th is None or th <= 0.0:\n            return self.knns\n        # TODO: optimize the filtering process by numpy\n        # nproc = mp.cpu_count()\n        nproc = 1\n        with Timer(\n            \"filter edges by th {} (CPU={})\".format(th, nproc), self.verbose\n        ):\n            self.th = th\n            self.th_knns = []\n            tot = len(self.knns)\n            if nproc > 1:\n                pool = mp.Pool(nproc)\n                th_knns = list(\n                    tqdm(pool.imap(self.filter_by_th, range(tot)), total=tot)\n                )\n                
pool.close()\n            else:\n                th_knns = [self.filter_by_th(i) for i in range(tot)]\n            return th_knns\n\n\nclass knn_faiss(knn):\n    def __init__(\n        self,\n        feats,\n        k,\n        nprobe=128,\n        omp_num_threads=None,\n        rebuild_index=True,\n        verbose=True,\n        **kwargs\n    ):\n        import faiss\n\n        if omp_num_threads is not None:\n            faiss.omp_set_num_threads(omp_num_threads)\n        self.verbose = verbose\n        with Timer(\"[faiss] build index\", verbose):\n            feats = feats.astype(\"float32\")\n            size, dim = feats.shape\n            index = faiss.IndexFlatIP(dim)\n            index.add(feats)\n        with Timer(\"[faiss] query topk {}\".format(k), verbose):\n            sims, nbrs = index.search(feats, k=k)\n            self.knns = [\n                (\n                    np.array(nbr, dtype=np.int32),\n                    1 - np.array(sim, dtype=np.float32),\n                )\n                for nbr, sim in zip(nbrs, sims)\n            ]\n\n\nclass knn_faiss_gpu(knn):\n    def __init__(\n        self,\n        feats,\n        k,\n        nprobe=128,\n        num_process=4,\n        is_precise=True,\n        sort=True,\n        verbose=True,\n        **kwargs\n    ):\n        with Timer(\"[faiss_gpu] query topk {}\".format(k), verbose):\n            dists, nbrs = faiss_search_knn(\n                feats,\n                k=k,\n                nprobe=nprobe,\n                num_process=num_process,\n                is_precise=is_precise,\n                sort=sort,\n                verbose=verbose,\n            )\n\n            self.knns = [\n                (\n                    np.array(nbr, dtype=np.int32),\n                    np.array(dist, dtype=np.float32),\n                )\n                for nbr, dist in zip(nbrs, dists)\n            ]\n"
  },
  {
    "path": "examples/pytorch/hilander/utils/metrics.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\"\"\"\nThis file re-uses implementation from https://github.com/yl-1993/learn-to-cluster\n\"\"\"\n\nfrom __future__ import division\n\nimport numpy as np\nfrom sklearn.metrics import precision_score, recall_score\nfrom sklearn.metrics.cluster import (\n    contingency_matrix,\n    normalized_mutual_info_score,\n)\n\n__all__ = [\"pairwise\", \"bcubed\", \"nmi\", \"precision\", \"recall\", \"accuracy\"]\n\n\ndef _check(gt_labels, pred_labels):\n    if gt_labels.ndim != 1:\n        raise ValueError(\n            \"gt_labels must be 1D: shape is %r\" % (gt_labels.shape,)\n        )\n    if pred_labels.ndim != 1:\n        raise ValueError(\n            \"pred_labels must be 1D: shape is %r\" % (pred_labels.shape,)\n        )\n    if gt_labels.shape != pred_labels.shape:\n        raise ValueError(\n            \"gt_labels and pred_labels must have same size, got %d and %d\"\n            % (gt_labels.shape[0], pred_labels.shape[0])\n        )\n    return gt_labels, pred_labels\n\n\ndef _get_lb2idxs(labels):\n    lb2idxs = {}\n    for idx, lb in enumerate(labels):\n        if lb not in lb2idxs:\n            lb2idxs[lb] = []\n        lb2idxs[lb].append(idx)\n    return lb2idxs\n\n\ndef _compute_fscore(pre, rec):\n    return 2.0 * pre * rec / (pre + rec)\n\n\ndef fowlkes_mallows_score(gt_labels, pred_labels, sparse=True):\n    \"\"\"The original function is from `sklearn.metrics.fowlkes_mallows_score`.\n    We output the pairwise precision, pairwise recall and F-measure,\n    instead of calculating the geometry mean of precision and recall.\n    \"\"\"\n    (n_samples,) = gt_labels.shape\n\n    c = contingency_matrix(gt_labels, pred_labels, sparse=sparse)\n    tk = np.dot(c.data, c.data) - n_samples\n    pk = np.sum(np.asarray(c.sum(axis=0)).ravel() ** 2) - n_samples\n    qk = np.sum(np.asarray(c.sum(axis=1)).ravel() ** 2) - n_samples\n\n    avg_pre = tk / pk\n    avg_rec = tk / qk\n    fscore = 
_compute_fscore(avg_pre, avg_rec)\n\n    return avg_pre, avg_rec, fscore\n\n\ndef pairwise(gt_labels, pred_labels, sparse=True):\n    _check(gt_labels, pred_labels)\n    return fowlkes_mallows_score(gt_labels, pred_labels, sparse)\n\n\ndef bcubed(gt_labels, pred_labels):\n    _check(gt_labels, pred_labels)\n\n    gt_lb2idxs = _get_lb2idxs(gt_labels)\n    pred_lb2idxs = _get_lb2idxs(pred_labels)\n\n    num_lbs = len(gt_lb2idxs)\n    pre = np.zeros(num_lbs)\n    rec = np.zeros(num_lbs)\n    gt_num = np.zeros(num_lbs)\n\n    for i, gt_idxs in enumerate(gt_lb2idxs.values()):\n        all_pred_lbs = np.unique(pred_labels[gt_idxs])\n        gt_num[i] = len(gt_idxs)\n        for pred_lb in all_pred_lbs:\n            pred_idxs = pred_lb2idxs[pred_lb]\n            n = 1.0 * np.intersect1d(gt_idxs, pred_idxs).size\n            pre[i] += n**2 / len(pred_idxs)\n            rec[i] += n**2 / gt_num[i]\n\n    gt_num = gt_num.sum()\n    avg_pre = pre.sum() / gt_num\n    avg_rec = rec.sum() / gt_num\n    fscore = _compute_fscore(avg_pre, avg_rec)\n\n    return avg_pre, avg_rec, fscore\n\n\ndef nmi(gt_labels, pred_labels):\n    return normalized_mutual_info_score(pred_labels, gt_labels)\n\n\ndef precision(gt_labels, pred_labels):\n    return precision_score(gt_labels, pred_labels)\n\n\ndef recall(gt_labels, pred_labels):\n    return recall_score(gt_labels, pred_labels)\n\n\ndef accuracy(gt_labels, pred_labels):\n    return np.mean(gt_labels == pred_labels)\n"
  },
  {
    "path": "examples/pytorch/hilander/utils/misc.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\"\"\"\nThis file re-uses implementation from https://github.com/yl-1993/learn-to-cluster\n\"\"\"\n\nimport json\nimport os\nimport pickle\nimport random\nimport time\n\nimport numpy as np\n\n\nclass TextColors:\n    HEADER = \"\\033[35m\"\n    OKBLUE = \"\\033[34m\"\n    OKGREEN = \"\\033[32m\"\n    WARNING = \"\\033[33m\"\n    FATAL = \"\\033[31m\"\n    ENDC = \"\\033[0m\"\n    BOLD = \"\\033[1m\"\n    UNDERLINE = \"\\033[4m\"\n\n\nclass Timer:\n    def __init__(self, name=\"task\", verbose=True):\n        self.name = name\n        self.verbose = verbose\n\n    def __enter__(self):\n        self.start = time.time()\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        if self.verbose:\n            print(\n                \"[Time] {} consumes {:.4f} s\".format(\n                    self.name, time.time() - self.start\n                )\n            )\n        return exc_type is None\n\n\ndef set_random_seed(seed, cuda=False):\n    import torch\n\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    if cuda:\n        torch.cuda.manual_seed_all(seed)\n\n\ndef l2norm(vec):\n    vec /= np.linalg.norm(vec, axis=1).reshape(-1, 1)\n    return vec\n\n\ndef is_l2norm(features, size):\n    rand_i = random.choice(range(size))\n    norm_ = np.dot(features[rand_i, :], features[rand_i, :])\n    return abs(norm_ - 1) < 1e-6\n\n\ndef is_spmat_eq(a, b):\n    return (a != b).nnz == 0\n\n\ndef aggregate(features, adj, times):\n    dtype = features.dtype\n    for i in range(times):\n        features = adj * features\n    return features.astype(dtype)\n\n\ndef mkdir_if_no_exists(path, subdirs=[\"\"], is_folder=False):\n    if path == \"\":\n        return\n    for sd in subdirs:\n        if sd != \"\" or is_folder:\n            d = os.path.dirname(os.path.join(path, sd))\n        else:\n            d = os.path.dirname(path)\n        if not os.path.exists(d):\n     
       os.makedirs(d)\n\n\ndef stop_iterating(\n    current_l,\n    total_l,\n    early_stop,\n    num_edges_add_this_level,\n    num_edges_add_last_level,\n    knn_k,\n):\n    # Stopping rule 1: run all levels\n    if current_l == total_l - 1:\n        return True\n    # Stopping rule 2: no new edges\n    if num_edges_add_this_level == 0:\n        return True\n    # Stopping rule 3: early stopping, two levels start to produce similar numbers of edges\n    if (\n        early_stop\n        and float(num_edges_add_last_level) / num_edges_add_this_level\n        < knn_k - 1\n    ):\n        return True\n    return False\n"
  },
  {
    "path": "examples/pytorch/infograph/README.md",
    "content": "# DGL Implementation of InfoGraph\nThis DGL example implements the model proposed in the paper [InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization](https://arxiv.org/abs/1908.01000).\n\nAuthor's code: https://github.com/fanyun-sun/InfoGraph\n\n## Example Implementor\n\nThis example was implemented by [Hengrui Zhang](https://github.com/hengruizhang98) when he was an applied scientist intern at AWS Shanghai AI Lab.\n\n## Dependencies\n\n- Python 3.7\n- PyTorch 1.7.1\n- dgl 0.6.0\n\n## Datasets\n\n##### Unsupervised Graph Classification Dataset:\n\n 'MUTAG', 'PTC', 'IMDBBINARY'(IMDB-B), 'IMDBMULTI'(IMDB-M), 'REDDITBINARY'(RDT-B), 'REDDITMULTI5K'(RDT-M5K) of dgl.data.GINDataset.\n\n| Dataset         | MUTAG | PTC   | RDT-B  | RDT-M5K | IMDB-B | IMDB-M |\n| --------------- | ----- | ----- | ------ | ------- | ------ | ------ |\n| # Graphs        | 188   | 344   | 2000   | 4999    | 1000   | 1500   |\n| # Classes       | 2     | 2     | 2      | 5       | 2      | 3      |\n| Avg. 
Graph Size | 17.93 | 14.29 | 429.63 | 508.52  | 19.77  | 13.00  |\n\n**Semi-supervised Graph Regression Dataset:**\n\nQM9 dataset for graph property prediction (regression)\n\n| Dataset | # Graphs | # Regression Tasks |\n| ------- | -------- | ------------------ |\n| QM9     | 130,831  | 12                 |\n\nThe 12 tasks are:\n\n| Keys  | Description                                |\n| ----- | :----------------------------------------- |\n| mu    | Dipole moment                              |\n| alpha | Isotropic polarizability                   |\n| homo  | Highest occupied molecular orbital energy  |\n| lumo  | Lowest unoccupied molecular orbital energy |\n| gap   | Gap between 'homo' and 'lumo'              |\n| r2    | Electronic spatial extent                  |\n| zpve  | Zero point vibrational energy              |\n| U0    | Internal energy at 0K                      |\n| U     | Internal energy at 298.15K                 |\n| H     | Enthalpy at 298.15K                        |\n| G     | Free energy at 298.15K                     |\n| Cv    | Heat capacity at 298.15K                   |\n\n## Arguments\n\n##### \tUnsupervised Graph Classification:\n\n###### Dataset options\n\n```\n--dataname        str      The graph dataset name.                Default is 'MUTAG'.\n```\n\n###### GPU options\n\n```\n--gpu              int     GPU index.                             Default is -1, using CPU.\n```\n\n###### Training options\n\n```\n--epochs           int     Number of training periods.            Default is 20.\n--batch_size       int     Size of a training batch.              Default is 128.\n--lr               float   Adam optimizer learning rate.          Default is 0.01.\n--log_interval     int     Interval between two evaluations.\t  Default is 1.\n```\n\n###### Model options\n\n```\n--n_layers         int     Number of GIN layers.                  Default is 3.\n--hid_dim          int     Dimension of hidden layers.            
Default is 32.\n```\n\n##### \tSemi-supervised Graph Regression:\n\n###### Dataset options\n\n```\n --target          str     The regression Task.                   Default is 'mu'.\n --train_num       int     Number of supervised examples.         Default is 5000.\n```\n\n###### GPU options\n\n```\n--gpu              int     GPU index.                             Default is -1, using CPU.\n```\n\n###### Training options\n\n```\n--epochs           int     Number of training periods.            Default is 200.\n--batch_size       int     Size of a training batch.              Default is 20.\n--val_batch_size   int     Size of a validation batch.            Default is 100.\n--lr               float   Adam optimizer learning rate.          Default is 0.001.\n```\n\n###### Model options\n\n```\n--hid_dim          int     Dimension of hidden layers.            Default is 64.\n--reg              float   Regularization weight.                 Default is 0.001.\n```\n\n## How to run examples\n\nTraining and testing unsupervised model on MUTAG.\n\n (As graphs in these datasets are quite small and sparse, moving graphs from cpu to gpu would take a longer time than training, we recommend using **cpu** for these datasets).\n\n```bash\n# MUTAG:\npython unsupervised.py --dataname MUTAG --n_layers 4 --hid_dim 32\n```\n\nReplace 'MUTAG' with dataname in ['MUTAG', 'PTC', 'IMDBBINARY', 'IMDBMULTI', 'REDDITBINARY', 'REDDITMULTI5K'] if you'd like to try other datasets.\n\nTraining and testing semi-supervised model on QM9 for graph property 'mu' with gpu.\n\n```bash\n# QM9:\npython semisupervised.py --gpu 0 --target mu\n```\n\nReplace 'mu' with other target names above.\n\n## \tPerformance\n\nThe hyperparameter setting in our implementation is identical to that reported in the paper.\n\n##### Unsupervised Graph Classification:\n\n|      Dataset      | MUTAG |  PTC  | RDT-B | RDT-M5K | IMDB-B | IMDB-M |\n| :---------------: | :---: | :---: | :---: | ------- | ------ | ------ |\n| 
Accuracy Reported | 89.01 | 61.65 | 82.50 | 53.46   | 73.03  | 49.69  |\n|        DGL        | 89.88 | 63.54 | 88.50 | 56.27   | 72.70  | 50.13  |\n\n* REDDIT-M dataset would take quite a long time to load and evaluate. \n\n##### Semisupervised Graph Regression on QM9:\n\nHere we only provide the results of 'mu', 'alpha', 'homo'.\n\n\n\n|      Target       |   mu   | alpha  |  homo  |\n| :---------------: | :----: | :----: | :----: |\n|   MAE Reported    | 0.3169 | 0.5444 | 0.0060 |\n| The authors' code | 0.2411 | 0.5192 | 0.1560 |\n|        DGL        | 0.2355 | 0.5483 | 0.1581 |\n\n* The source of QM9 Dataset has changed so there's a gap between the MAE reported in the paper and that we reproduced.\n* See this [issue](https://github.com/fanyun-sun/InfoGraph/issues/8) for authors' response. \n"
  },
  {
    "path": "examples/pytorch/infograph/evaluate_embedding.py",
    "content": "\"\"\" Evaluate unsupervised embedding using a variety of basic classifiers. \"\"\"\n\"\"\" Credit: https://github.com/fanyun-sun/InfoGraph \"\"\"\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom sklearn import preprocessing\nfrom sklearn.metrics import accuracy_score\nfrom sklearn.model_selection import GridSearchCV, StratifiedKFold\nfrom sklearn.svm import SVC\n\n\nclass LogReg(nn.Module):\n    def __init__(self, ft_in, nb_classes):\n        super(LogReg, self).__init__()\n        self.fc = nn.Linear(ft_in, nb_classes)\n\n    def weights_init(self, m):\n        if isinstance(m, nn.Linear):\n            torch.nn.init.xavier_uniform_(m.weight.data)\n            if m.bias is not None:\n                m.bias.data.fill_(0.0)\n\n    def forward(self, seq):\n        ret = self.fc(seq)\n        return ret\n\n\ndef logistic_classify(x, y, device=\"cpu\"):\n    nb_classes = np.unique(y).shape[0]\n    xent = nn.CrossEntropyLoss()\n    hid_units = x.shape[1]\n\n    accs = []\n    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)\n    for train_index, test_index in kf.split(x, y):\n        train_embs, test_embs = x[train_index], x[test_index]\n        train_lbls, test_lbls = y[train_index], y[test_index]\n\n        train_embs, train_lbls = torch.from_numpy(train_embs).to(\n            device\n        ), torch.from_numpy(train_lbls).to(device)\n        test_embs, test_lbls = torch.from_numpy(test_embs).to(\n            device\n        ), torch.from_numpy(test_lbls).to(device)\n\n        log = LogReg(hid_units, nb_classes)\n        log = log.to(device)\n\n        opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)\n\n        for it in range(100):\n            log.train()\n            opt.zero_grad()\n\n            logits = log(train_embs)\n            loss = xent(logits, train_lbls)\n\n            loss.backward()\n            opt.step()\n\n        logits = log(test_embs)\n        preds = torch.argmax(logits, 
dim=1)\n        acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]\n        accs.append(acc.item())\n    return np.mean(accs)\n\n\ndef svc_classify(x, y, search):\n    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)\n    accuracies = []\n    for train_index, test_index in kf.split(x, y):\n        x_train, x_test = x[train_index], x[test_index]\n        y_train, y_test = y[train_index], y[test_index]\n\n        if search:\n            params = {\"C\": [0.001, 0.01, 0.1, 1, 10, 100, 1000]}\n            classifier = GridSearchCV(\n                SVC(), params, cv=5, scoring=\"accuracy\", verbose=0\n            )\n        else:\n            classifier = SVC(C=10)\n        classifier.fit(x_train, y_train)\n        accuracies.append(accuracy_score(y_test, classifier.predict(x_test)))\n    return np.mean(accuracies)\n\n\ndef evaluate_embedding(embeddings, labels, search=True, device=\"cpu\"):\n    labels = preprocessing.LabelEncoder().fit_transform(labels)\n    x, y = np.array(embeddings), np.array(labels)\n\n    logreg_accuracy = logistic_classify(x, y, device)\n    print(\"LogReg\", logreg_accuracy)\n    svc_accuracy = svc_classify(x, y, search)\n    print(\"svc\", svc_accuracy)\n\n    return logreg_accuracy, svc_accuracy\n"
  },
  {
    "path": "examples/pytorch/infograph/model.py",
    "content": "import torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl.nn import GINConv, NNConv, Set2Set\nfrom dgl.nn.pytorch.glob import SumPooling\nfrom torch.nn import BatchNorm1d, GRU, Linear, ModuleList, ReLU, Sequential\nfrom utils import global_global_loss_, local_global_loss_\n\n\"\"\" Feedforward neural network\"\"\"\n\n\nclass FeedforwardNetwork(nn.Module):\n\n    \"\"\"\n    3-layer feed-forward neural networks with jumping connections\n    Parameters\n    -----------\n    in_dim: int\n        Input feature size.\n    hid_dim: int\n        Hidden feature size.\n\n    Functions\n    -----------\n    forward(feat):\n        feat: Tensor\n            [N * D], input features\n    \"\"\"\n\n    def __init__(self, in_dim, hid_dim):\n        super(FeedforwardNetwork, self).__init__()\n\n        self.block = Sequential(\n            Linear(in_dim, hid_dim),\n            ReLU(),\n            Linear(hid_dim, hid_dim),\n            ReLU(),\n            Linear(hid_dim, hid_dim),\n            ReLU(),\n        )\n\n        self.jump_con = Linear(in_dim, hid_dim)\n\n    def forward(self, feat):\n        block_out = self.block(feat)\n        jump_out = self.jump_con(feat)\n\n        out = block_out + jump_out\n\n        return out\n\n\n\"\"\" Unsupervised Setting \"\"\"\n\n\nclass GINEncoder(nn.Module):\n    \"\"\"\n    Encoder based on dgl.nn.GINConv &  dgl.nn.SumPooling\n    Parameters\n    -----------\n    in_dim: int\n        Input feature size.\n    hid_dim: int\n        Hidden feature size.\n    n_layer:\n        Number of GIN layers.\n\n    Functions\n    -----------\n    forward(graph, feat):\n        graph: DGLGraph\n        feat: Tensor\n            [N * D], node features\n    \"\"\"\n\n    def __init__(self, in_dim, hid_dim, n_layer):\n        super(GINEncoder, self).__init__()\n\n        self.n_layer = n_layer\n\n        self.convs = ModuleList()\n        self.bns = ModuleList()\n\n        for i in range(n_layer):\n           
 if i == 0:\n                n_in = in_dim\n            else:\n                n_in = hid_dim\n            n_out = hid_dim\n            block = Sequential(\n                Linear(n_in, n_out), ReLU(), Linear(hid_dim, hid_dim)\n            )\n\n            conv = GINConv(apply_func=block, aggregator_type=\"sum\")\n            bn = BatchNorm1d(hid_dim)\n\n            self.convs.append(conv)\n            self.bns.append(bn)\n\n        # sum pooling\n        self.pool = SumPooling()\n\n    def forward(self, graph, feat):\n        xs = []\n        x = feat\n        for i in range(self.n_layer):\n            x = F.relu(self.convs[i](graph, x))\n            x = self.bns[i](x)\n            xs.append(x)\n\n        local_emb = th.cat(xs, 1)  # patch-level embedding\n        global_emb = self.pool(graph, local_emb)  # graph-level embedding\n\n        return global_emb, local_emb\n\n\nclass InfoGraph(nn.Module):\n    r\"\"\"\n        InfoGraph model for unsupervised setting\n\n    Parameters\n    -----------\n    in_dim: int\n        Input feature size.\n    hid_dim: int\n        Hidden feature size.\n    n_layer: int\n        Number of the GNN encoder layers.\n\n    Functions\n    -----------\n    forward(graph):\n        graph: DGLGraph\n\n    \"\"\"\n\n    def __init__(self, in_dim, hid_dim, n_layer):\n        super(InfoGraph, self).__init__()\n\n        self.in_dim = in_dim\n        self.hid_dim = hid_dim\n\n        self.n_layer = n_layer\n        embedding_dim = hid_dim * n_layer\n\n        self.encoder = GINEncoder(in_dim, hid_dim, n_layer)\n\n        self.local_d = FeedforwardNetwork(\n            embedding_dim, embedding_dim\n        )  # local discriminator (node-level)\n        self.global_d = FeedforwardNetwork(\n            embedding_dim, embedding_dim\n        )  # global discriminator (graph-level)\n\n    def get_embedding(self, graph, feat):\n        # get_embedding function for evaluation the learned embeddings\n\n        with th.no_grad():\n            
global_emb, _ = self.encoder(graph, feat)\n\n        return global_emb\n\n    def forward(self, graph, feat, graph_id):\n        global_emb, local_emb = self.encoder(graph, feat)\n\n        global_h = self.global_d(global_emb)  # global hidden representation\n        local_h = self.local_d(local_emb)  # local hidden representation\n\n        loss = local_global_loss_(local_h, global_h, graph_id)\n\n        return loss\n\n\n\"\"\" Semisupervised Setting \"\"\"\n\n\nclass NNConvEncoder(nn.Module):\n\n    \"\"\"\n    Encoder based on dgl.nn.NNConv & GRU & dgl.nn.set2set pooling\n    Parameters\n    -----------\n    in_dim: int\n        Input feature size.\n    hid_dim: int\n        Hidden feature size.\n\n    Functions\n    -----------\n    forward(graph, nfeat, efeat):\n        graph: DGLGraph\n        nfeat: Tensor\n            [N * D1], node features\n        efeat: Tensor\n            [E * D2], edge features\n    \"\"\"\n\n    def __init__(self, in_dim, hid_dim):\n        super(NNConvEncoder, self).__init__()\n\n        self.lin0 = Linear(in_dim, hid_dim)\n\n        # mlp for edge convolution in NNConv\n        block = Sequential(\n            Linear(5, 128), ReLU(), Linear(128, hid_dim * hid_dim)\n        )\n\n        self.conv = NNConv(\n            hid_dim,\n            hid_dim,\n            edge_func=block,\n            aggregator_type=\"mean\",\n            residual=False,\n        )\n        self.gru = GRU(hid_dim, hid_dim)\n\n        # set2set pooling\n        self.set2set = Set2Set(hid_dim, n_iters=3, n_layers=1)\n\n    def forward(self, graph, nfeat, efeat):\n        out = F.relu(self.lin0(nfeat))\n        h = out.unsqueeze(0)\n\n        feat_map = []\n\n        # Convolution layer number is 3\n        for i in range(3):\n            m = F.relu(self.conv(graph, out, efeat))\n            out, h = self.gru(m.unsqueeze(0), h)\n            out = out.squeeze(0)\n            feat_map.append(out)\n\n        out = self.set2set(graph, out)\n\n        # out: global 
embedding, feat_map[-1]: local embedding\n        return out, feat_map[-1]\n\n\nclass InfoGraphS(nn.Module):\n\n    \"\"\"\n    InfoGraph* model for semi-supervised setting\n    Parameters\n    -----------\n    in_dim: int\n        Input feature size.\n    hid_dim: int\n        Hidden feature size.\n\n    Functions\n    -----------\n    forward(graph):\n        graph: DGLGraph\n\n    unsupforward(graph):\n        graph: DGLGraph\n\n    \"\"\"\n\n    def __init__(self, in_dim, hid_dim):\n        super(InfoGraphS, self).__init__()\n\n        self.sup_encoder = NNConvEncoder(in_dim, hid_dim)\n        self.unsup_encoder = NNConvEncoder(in_dim, hid_dim)\n\n        self.fc1 = Linear(2 * hid_dim, hid_dim)\n        self.fc2 = Linear(hid_dim, 1)\n\n        # unsupervised local discriminator and global discriminator for local-global infomax\n        self.unsup_local_d = FeedforwardNetwork(hid_dim, hid_dim)\n        self.unsup_global_d = FeedforwardNetwork(2 * hid_dim, hid_dim)\n\n        # supervised global discriminator and unsupervised global discriminator for global-global infomax\n        self.sup_d = FeedforwardNetwork(2 * hid_dim, hid_dim)\n        self.unsup_d = FeedforwardNetwork(2 * hid_dim, hid_dim)\n\n    def forward(self, graph, nfeat, efeat):\n        sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)\n\n        sup_global_pred = self.fc2(F.relu(self.fc1(sup_global_emb)))\n        sup_global_pred = sup_global_pred.view(-1)\n\n        return sup_global_pred\n\n    def unsup_forward(self, graph, nfeat, efeat, graph_id):\n        sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)\n        unsup_global_emb, unsup_local_emb = self.unsup_encoder(\n            graph, nfeat, efeat\n        )\n\n        g_enc = self.unsup_global_d(unsup_global_emb)\n        l_enc = self.unsup_local_d(unsup_local_emb)\n\n        sup_g_enc = self.sup_d(sup_global_emb)\n        unsup_g_enc = self.unsup_d(unsup_global_emb)\n\n        # Calculate loss\n   
     unsup_loss = local_global_loss_(l_enc, g_enc, graph_id)\n        con_loss = global_global_loss_(sup_g_enc, unsup_g_enc)\n\n        return unsup_loss, con_loss\n"
  },
  {
    "path": "examples/pytorch/infograph/semisupervised.py",
    "content": "import argparse\n\nimport dgl\n\nimport numpy as np\nimport torch as th\nimport torch.nn.functional as F\nfrom dgl.data import QM9EdgeDataset\nfrom dgl.data.utils import Subset\nfrom dgl.dataloading import GraphDataLoader\nfrom model import InfoGraphS\n\n\ndef argument():\n    parser = argparse.ArgumentParser(description=\"InfoGraphS\")\n\n    # data source params\n    parser.add_argument(\n        \"--target\", type=str, default=\"mu\", help=\"Choose regression task\"\n    )\n    parser.add_argument(\n        \"--train_num\", type=int, default=5000, help=\"Size of training set\"\n    )\n\n    # training params\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU index, default:-1, using CPU.\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=200, help=\"Training epochs.\"\n    )\n    parser.add_argument(\n        \"--batch_size\", type=int, default=20, help=\"Training batch size.\"\n    )\n    parser.add_argument(\n        \"--val_batch_size\", type=int, default=100, help=\"Validation batch size.\"\n    )\n\n    parser.add_argument(\n        \"--lr\", type=float, default=0.001, help=\"Learning rate.\"\n    )\n    parser.add_argument(\"--wd\", type=float, default=0, help=\"Weight decay.\")\n\n    # model params\n    parser.add_argument(\n        \"--hid_dim\", type=int, default=64, help=\"Hidden layer dimensionality\"\n    )\n    parser.add_argument(\n        \"--reg\", type=float, default=0.001, help=\"Regularization coefficient\"\n    )\n\n    args = parser.parse_args()\n\n    # check cuda\n    if args.gpu != -1 and th.cuda.is_available():\n        args.device = \"cuda:{}\".format(args.gpu)\n    else:\n        args.device = \"cpu\"\n\n    return args\n\n\nclass DenseQM9EdgeDataset(QM9EdgeDataset):\n    def __getitem__(self, idx):\n        r\"\"\"Get graph and label by index\n\n        Parameters\n        ----------\n        idx : int\n            Item index\n\n        Returns\n        -------\n    
    dgl.DGLGraph\n           The graph contains:\n\n           - ``ndata['pos']``: the coordinates of each atom\n           - ``ndata['attr']``: the features of each atom\n           - ``edata['edge_attr']``: the features of each bond\n\n        Tensor\n            Property values of molecular graphs\n        \"\"\"\n\n        pos = self.node_pos[self.n_cumsum[idx] : self.n_cumsum[idx + 1]]\n        src = self.src[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]\n        dst = self.dst[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]\n\n        g = dgl.graph((src, dst))\n\n        g.ndata[\"pos\"] = th.tensor(pos).float()\n        g.ndata[\"attr\"] = th.tensor(\n            self.node_attr[self.n_cumsum[idx] : self.n_cumsum[idx + 1]]\n        ).float()\n        g.edata[\"edge_attr\"] = th.tensor(\n            self.edge_attr[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]\n        ).float()\n\n        label = th.tensor(self.targets[idx][self.label_keys]).float()\n\n        n_nodes = g.num_nodes()\n        row = th.arange(n_nodes)\n        col = th.arange(n_nodes)\n\n        row = row.view(-1, 1).repeat(1, n_nodes).view(-1)\n        col = col.repeat(n_nodes)\n\n        src = g.edges()[0]\n        dst = g.edges()[1]\n\n        idx = src * n_nodes + dst\n        size = list(g.edata[\"edge_attr\"].size())\n        size[0] = n_nodes * n_nodes\n        edge_attr = g.edata[\"edge_attr\"].new_zeros(size)\n\n        edge_attr[idx] = g.edata[\"edge_attr\"]\n\n        pos = g.ndata[\"pos\"]\n        dist = th.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1)\n\n        new_edge_attr = th.cat([edge_attr, dist.type_as(edge_attr)], dim=-1)\n\n        graph = dgl.graph((row, col))\n        graph.ndata[\"attr\"] = g.ndata[\"attr\"]\n        graph.edata[\"edge_attr\"] = new_edge_attr\n        graph = graph.remove_self_loop()\n\n        return graph, label\n\n\ndef collate(samples):\n    \"\"\"collate function for building graph dataloader\"\"\"\n\n    # generate batched graphs and 
labels\n    graphs, targets = map(list, zip(*samples))\n    batched_graph = dgl.batch(graphs)\n    batched_targets = th.Tensor(targets)\n\n    n_graphs = len(graphs)\n    graph_id = th.arange(n_graphs)\n    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)\n\n    batched_graph.ndata[\"graph_id\"] = graph_id\n\n    return batched_graph, batched_targets\n\n\ndef evaluate(model, loader, num, device):\n    error = 0\n    for graphs, targets in loader:\n        graphs = graphs.to(device)\n\n        nfeat, efeat = graphs.ndata[\"attr\"], graphs.edata[\"edge_attr\"]\n        targets = targets.to(device)\n        error += (model(graphs, nfeat, efeat) - targets).abs().sum().item()\n\n    error = error / num\n\n    return error\n\n\nif __name__ == \"__main__\":\n    # Step 1: Prepare graph data   ===================================== #\n    args = argument()\n    label_keys = [args.target]\n    print(args)\n\n    dataset = DenseQM9EdgeDataset(label_keys=label_keys)\n\n    # Train/Val/Test Splitting\n    N = dataset.targets.shape[0]\n    all_idx = np.arange(N)\n    np.random.shuffle(all_idx)\n\n    val_num = 10000\n    test_num = 10000\n\n    val_idx = all_idx[:val_num]\n    test_idx = all_idx[val_num : val_num + test_num]\n    train_idx = all_idx[\n        val_num + test_num : val_num + test_num + args.train_num\n    ]\n\n    train_data = Subset(dataset, train_idx)\n    val_data = Subset(dataset, val_idx)\n    test_data = Subset(dataset, test_idx)\n\n    unsup_idx = all_idx[val_num + test_num :]\n    unsup_data = Subset(dataset, unsup_idx)\n\n    # generate supervised training dataloader and unsupervised training dataloader\n    train_loader = GraphDataLoader(\n        train_data,\n        batch_size=args.batch_size,\n        collate_fn=collate,\n        drop_last=False,\n        shuffle=True,\n    )\n\n    unsup_loader = GraphDataLoader(\n        unsup_data,\n        batch_size=args.batch_size,\n        collate_fn=collate,\n        drop_last=False,\n        
shuffle=True,\n    )\n\n    # generate validation & testing dataloader\n    val_loader = GraphDataLoader(\n        val_data,\n        batch_size=args.val_batch_size,\n        collate_fn=collate,\n        drop_last=False,\n        shuffle=True,\n    )\n\n    test_loader = GraphDataLoader(\n        test_data,\n        batch_size=args.val_batch_size,\n        collate_fn=collate,\n        drop_last=False,\n        shuffle=True,\n    )\n\n    print(\"======== target = {} ========\".format(args.target))\n\n    in_dim = dataset[0][0].ndata[\"attr\"].shape[1]\n\n    # Step 2: Create model =================================================================== #\n    model = InfoGraphS(in_dim, args.hid_dim)\n    model = model.to(args.device)\n\n    # Step 3: Create training components ===================================================== #\n    optimizer = th.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.wd\n    )\n    scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(\n        optimizer, mode=\"min\", factor=0.7, patience=5, min_lr=0.000001\n    )\n\n    # Step 4: training epochs =============================================================== #\n    best_val_error = float(\"inf\")\n    test_error = float(\"inf\")\n\n    for epoch in range(args.epochs):\n        \"\"\"Training\"\"\"\n        model.train()\n        lr = scheduler.optimizer.param_groups[0][\"lr\"]\n\n        iteration = 0\n        sup_loss_all = 0\n        unsup_loss_all = 0\n        consis_loss_all = 0\n\n        for sup_data, unsup_data in zip(train_loader, unsup_loader):\n            sup_graph, sup_target = sup_data\n            unsup_graph, _ = unsup_data\n\n            sup_graph = sup_graph.to(args.device)\n            unsup_graph = unsup_graph.to(args.device)\n\n            sup_nfeat, sup_efeat = (\n                sup_graph.ndata[\"attr\"],\n                sup_graph.edata[\"edge_attr\"],\n            )\n            unsup_nfeat, unsup_efeat, unsup_graph_id = (\n                
unsup_graph.ndata[\"attr\"],\n                unsup_graph.edata[\"edge_attr\"],\n                unsup_graph.ndata[\"graph_id\"],\n            )\n\n            sup_target = sup_target\n            sup_target = sup_target.to(args.device)\n\n            optimizer.zero_grad()\n\n            sup_loss = F.mse_loss(\n                model(sup_graph, sup_nfeat, sup_efeat), sup_target\n            )\n            unsup_loss, consis_loss = model.unsup_forward(\n                unsup_graph, unsup_nfeat, unsup_efeat, unsup_graph_id\n            )\n\n            loss = sup_loss + unsup_loss + args.reg * consis_loss\n\n            loss.backward()\n\n            sup_loss_all += sup_loss.item()\n            unsup_loss_all += unsup_loss.item()\n            consis_loss_all += consis_loss.item()\n\n            optimizer.step()\n\n        print(\n            \"Epoch: {}, Sup_Loss: {:4f}, Unsup_loss: {:.4f}, Consis_loss: {:.4f}\".format(\n                epoch, sup_loss_all, unsup_loss_all, consis_loss_all\n            )\n        )\n\n        model.eval()\n\n        val_error = evaluate(model, val_loader, val_num, args.device)\n        scheduler.step(val_error)\n\n        if val_error < best_val_error:\n            best_val_error = val_error\n            test_error = evaluate(model, test_loader, test_num, args.device)\n\n        print(\n            \"Epoch: {}, LR: {}, val_error: {:.4f}, best_test_error: {:.4f}\".format(\n                epoch, lr, val_error, test_error\n            )\n        )\n"
  },
  {
    "path": "examples/pytorch/infograph/unsupervised.py",
    "content": "import argparse\n\nimport dgl\n\nimport torch as th\nfrom dgl.data import GINDataset\nfrom dgl.dataloading import GraphDataLoader\nfrom evaluate_embedding import evaluate_embedding\nfrom model import InfoGraph\n\n\ndef argument():\n    parser = argparse.ArgumentParser(description=\"InfoGraph\")\n    # data source params\n    parser.add_argument(\n        \"--dataname\", type=str, default=\"MUTAG\", help=\"Name of dataset.\"\n    )\n\n    # training params\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU index, default:-1, using CPU.\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=20, help=\"Training epochs.\"\n    )\n    parser.add_argument(\n        \"--batch_size\", type=int, default=128, help=\"Training batch size.\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.01, help=\"Learning rate.\")\n    parser.add_argument(\n        \"--log_interval\",\n        type=int,\n        default=1,\n        help=\"Interval between two evaluations.\",\n    )\n\n    # model params\n    parser.add_argument(\n        \"--n_layers\",\n        type=int,\n        default=3,\n        help=\"Number of graph convolution layers before each pooling.\",\n    )\n    parser.add_argument(\n        \"--hid_dim\", type=int, default=32, help=\"Hidden layer dimensionalities.\"\n    )\n\n    args = parser.parse_args()\n\n    # check cuda\n    if args.gpu != -1 and th.cuda.is_available():\n        args.device = \"cuda:{}\".format(args.gpu)\n    else:\n        args.device = \"cpu\"\n\n    return args\n\n\ndef collate(samples):\n    \"\"\"collate function for building graph dataloader\"\"\"\n\n    graphs, labels = map(list, zip(*samples))\n\n    # generate batched graphs and labels\n    batched_graph = dgl.batch(graphs)\n    batched_labels = th.tensor(labels)\n\n    n_graphs = len(graphs)\n    graph_id = th.arange(n_graphs)\n    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)\n\n    
batched_graph.ndata[\"graph_id\"] = graph_id\n\n    return batched_graph, batched_labels\n\n\nif __name__ == \"__main__\":\n    # Step 1: Prepare graph data   ===================================== #\n    args = argument()\n    print(args)\n\n    # load dataset from dgl.data.GINDataset\n    dataset = GINDataset(args.dataname, self_loop=False)\n\n    # get graphs and labels\n    graphs, labels = map(list, zip(*dataset))\n\n    # generate a full-graph with all examples for evaluation\n    wholegraph = dgl.batch(graphs)\n    wholegraph.ndata[\"attr\"] = wholegraph.ndata[\"attr\"].to(th.float32)\n\n    # create dataloader for batch training\n    dataloader = GraphDataLoader(\n        dataset,\n        batch_size=args.batch_size,\n        collate_fn=collate,\n        drop_last=False,\n        shuffle=True,\n    )\n\n    in_dim = wholegraph.ndata[\"attr\"].shape[1]\n\n    # Step 2: Create model =================================================================== #\n    model = InfoGraph(in_dim, args.hid_dim, args.n_layers)\n    model = model.to(args.device)\n\n    # Step 3: Create training components ===================================================== #\n    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)\n\n    print(\"===== Before training ======\")\n\n    wholegraph = wholegraph.to(args.device)\n    wholefeat = wholegraph.ndata[\"attr\"]\n\n    emb = model.get_embedding(wholegraph, wholefeat).cpu()\n    res = evaluate_embedding(emb, labels, args.device)\n\n    \"\"\" Evaluate the initialized embeddings \"\"\"\n    \"\"\" using logistic regression and SVM(non-linear) \"\"\"\n    print(\"logreg {:4f}, svc {:4f}\".format(res[0], res[1]))\n\n    best_logreg = 0\n    best_logreg_epoch = 0\n    best_svc = 0\n    best_svc_epoch = 0\n\n    # Step 4: training epochs =============================================================== #\n    for epoch in range(args.epochs):\n        loss_all = 0\n        model.train()\n\n        for graph, label in dataloader:\n            
graph = graph.to(args.device)\n            feat = graph.ndata[\"attr\"]\n            graph_id = graph.ndata[\"graph_id\"]\n\n            n_graph = label.shape[0]\n\n            optimizer.zero_grad()\n            loss = model(graph, feat, graph_id)\n            loss.backward()\n            optimizer.step()\n            loss_all += loss.item()\n\n        print(\"Epoch {}, Loss {:.4f}\".format(epoch, loss_all))\n\n        if epoch % args.log_interval == 0:\n            # evaluate embeddings\n            model.eval()\n            emb = model.get_embedding(wholegraph, wholefeat).cpu()\n            res = evaluate_embedding(emb, labels, args.device)\n\n            if res[0] > best_logreg:\n                best_logreg = res[0]\n                best_logreg_epoch = epoch\n\n            if res[1] > best_svc:\n                best_svc = res[1]\n                best_svc_epoch = epoch\n\n            print(\n                \"best logreg {:4f}, epoch {} | best svc: {:4f}, epoch {}\".format(\n                    best_logreg, best_logreg_epoch, best_svc, best_svc_epoch\n                )\n            )\n\n    print(\"Training End\")\n    print(\"best logreg {:4f} ,best svc {:4f}\".format(best_logreg, best_svc))\n"
  },
  {
    "path": "examples/pytorch/infograph/utils.py",
    "content": "\"\"\" Credit: https://github.com/fanyun-sun/InfoGraph \"\"\"\n\nimport math\n\nimport torch as th\nimport torch.nn.functional as F\n\n\ndef get_positive_expectation(p_samples, average=True):\n    \"\"\"Computes the positive part of a JS Divergence.\n    Args:\n        p_samples: Positive samples.\n        average: Average the result over samples.\n    Returns:\n        th.Tensor\n    \"\"\"\n    log_2 = math.log(2.0)\n    Ep = log_2 - F.softplus(-p_samples)\n\n    if average:\n        return Ep.mean()\n    else:\n        return Ep\n\n\ndef get_negative_expectation(q_samples, average=True):\n    \"\"\"Computes the negative part of a JS Divergence.\n    Args:\n        q_samples: Negative samples.\n        average: Average the result over samples.\n    Returns:\n        th.Tensor\n    \"\"\"\n    log_2 = math.log(2.0)\n    Eq = F.softplus(-q_samples) + q_samples - log_2\n\n    if average:\n        return Eq.mean()\n    else:\n        return Eq\n\n\ndef local_global_loss_(l_enc, g_enc, graph_id):\n    num_graphs = g_enc.shape[0]\n    num_nodes = l_enc.shape[0]\n\n    device = g_enc.device\n\n    pos_mask = th.zeros((num_nodes, num_graphs)).to(device)\n    neg_mask = th.ones((num_nodes, num_graphs)).to(device)\n\n    for nodeidx, graphidx in enumerate(graph_id):\n        pos_mask[nodeidx][graphidx] = 1.0\n        neg_mask[nodeidx][graphidx] = 0.0\n\n    res = th.mm(l_enc, g_enc.t())\n\n    E_pos = get_positive_expectation(res * pos_mask, average=False).sum()\n    E_pos = E_pos / num_nodes\n    E_neg = get_negative_expectation(res * neg_mask, average=False).sum()\n    E_neg = E_neg / (num_nodes * (num_graphs - 1))\n\n    return E_neg - E_pos\n\n\ndef global_global_loss_(sup_enc, unsup_enc):\n    num_graphs = sup_enc.shape[0]\n    device = sup_enc.device\n\n    pos_mask = th.eye(num_graphs).to(device)\n    neg_mask = 1 - pos_mask\n\n    res = th.mm(sup_enc, unsup_enc.t())\n\n    E_pos = get_positive_expectation(res * pos_mask, average=False)\n    E_pos = 
(E_pos * pos_mask).sum() / pos_mask.sum()\n    E_neg = get_negative_expectation(res * neg_mask, average=False)\n    E_neg = (E_neg * neg_mask).sum() / neg_mask.sum()\n\n    return E_neg - E_pos\n"
  },
  {
    "path": "examples/pytorch/jknet/README.md",
    "content": "# DGL Implementation of JKNet\n\nThis DGL example implements the GNN model proposed in the paper [Representation Learning on Graphs with Jumping Knowledge Networks](https://arxiv.org/abs/1806.03536).\n\nContributor: [xnuohz](https://github.com/xnuohz)\n\n### Requirements\nThe codebase is implemented in Python 3.6. For version requirement of packages, see below.\n\n```\ndgl 0.6.0\nscikit-learn 0.24.1\ntqdm 4.56.0\ntorch 1.7.1\n```\n\n### The graph datasets used in this example\n\n###### Node Classification\n\nThe DGL's built-in Cora, Citeseer datasets. Dataset summary:\n\n| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |\n| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |\n| Cora | 2,708 | 10,556 | 1,433 | 7(single label) | 60% | 20% | 20% |\n| Citeseer | 3,327 | 9,228 | 3,703 | 6(single label) | 60% | 20% | 20% |\n\n### Usage\n\n###### Dataset options\n```\n--dataset          str     The graph dataset name.             Default is 'Cora'.\n```\n\n###### GPU options\n```\n--gpu              int     GPU index.                          Default is -1, using CPU.\n```\n\n###### Model options\n```\n--run              int     Number of running times.                    Default is 10.\n--epochs           int     Number of training epochs.                  Default is 500.\n--lr               float   Adam optimizer learning rate.               Default is 0.01.\n--lamb             float   L2 regularization coefficient.              Default is 0.0005.\n--hid-dim          int     Hidden layer dimensionalities.              Default is 32.\n--num-layers       int     Number of T.                                Default is 5.\n--mode             str     Type of aggregation ['cat', 'max', 'lstm']. Default is 'cat'.\n--dropout          float   Dropout applied at all layers.              
Default is 0.5.\n```\n\n###### Examples\n\nThe following commands learn a neural network and predict on the test set.\nTrain a JKNet which follows the original hyperparameters on different datasets.\n```bash\n# Cora:\npython main.py --gpu 0 --mode max --num-layers 6\npython main.py --gpu 0 --mode cat --num-layers 6\npython main.py --gpu 0 --mode lstm --num-layers 1\n\n# Citeseer:\npython main.py --gpu 0 --dataset Citeseer --mode max --num-layers 1\npython main.py --gpu 0 --dataset Citeseer --mode cat --num-layers 1\npython main.py --gpu 0 --dataset Citeseer --mode lstm --num-layers 2\n```\n\n### Performance\n\n**As the author does not release the code, we don't have the access to the data splits they used.**\n\n###### Node Classification\n\n* Cora\n\n|  | JK-Maxpool | JK-Concat | JK-LSTM |\n| :-: | :-: | :-: | :-: |\n| Metrics(Table 2) | 89.6±0.5 | 89.1±1.1 | 85.8±1.0 |\n| Metrics(DGL) | 86.1±1.5 | 85.1±1.6 | 84.2±1.6 |\n\n* Citeseer\n\n|  | JK-Maxpool | JK-Concat | JK-LSTM |\n| :-: | :-: | :-: | :-: |\n| Metrics(Table 2) | 77.7±0.5 | 78.3±0.8 | 74.7±0.9 |\n| Metrics(DGL) | 70.9±1.9 | 73.0±1.5 | 69.0±1.7 |"
  },
  {
    "path": "examples/pytorch/jknet/main.py",
    "content": "\"\"\" The main file to train a JKNet model using a full graph \"\"\"\n\nimport argparse\nimport copy\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset\nfrom model import JKNet\nfrom sklearn.model_selection import train_test_split\nfrom tqdm import trange\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # Load from DGL dataset\n    if args.dataset == \"Cora\":\n        dataset = CoraGraphDataset()\n    elif args.dataset == \"Citeseer\":\n        dataset = CiteseerGraphDataset()\n    else:\n        raise ValueError(\"Dataset {} is invalid.\".format(args.dataset))\n\n    graph = dataset[0]\n\n    # check cuda\n    device = (\n        f\"cuda:{args.gpu}\"\n        if args.gpu >= 0 and torch.cuda.is_available()\n        else \"cpu\"\n    )\n\n    # retrieve the number of classes\n    n_classes = dataset.num_classes\n\n    # retrieve labels of ground truth\n    labels = graph.ndata.pop(\"label\").to(device).long()\n\n    # Extract node features\n    feats = graph.ndata.pop(\"feat\").to(device)\n    n_features = feats.shape[-1]\n\n    # create masks for train / validation / test\n    # train : val : test = 6 : 2 : 2\n    n_nodes = graph.num_nodes()\n    idx = torch.arange(n_nodes).to(device)\n    train_idx, test_idx = train_test_split(idx, test_size=0.2)\n    train_idx, val_idx = train_test_split(train_idx, test_size=0.25)\n\n    graph = graph.to(device)\n\n    # Step 2: Create model =================================================================== #\n    model = JKNet(\n        in_dim=n_features,\n        hid_dim=args.hid_dim,\n        out_dim=n_classes,\n        num_layers=args.num_layers,\n        mode=args.mode,\n        dropout=args.dropout,\n    ).to(device)\n\n    best_model = copy.deepcopy(model)\n\n    # Step 3: Create training components 
===================================================== #\n    loss_fn = nn.CrossEntropyLoss()\n    opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.lamb)\n\n    # Step 4: training epochs =============================================================== #\n    acc = 0\n    epochs = trange(args.epochs, desc=\"Accuracy & Loss\")\n\n    for _ in epochs:\n        # Training using a full graph\n        model.train()\n\n        logits = model(graph, feats)\n\n        # compute loss\n        train_loss = loss_fn(logits[train_idx], labels[train_idx])\n        train_acc = torch.sum(\n            logits[train_idx].argmax(dim=1) == labels[train_idx]\n        ).item() / len(train_idx)\n\n        # backward\n        opt.zero_grad()\n        train_loss.backward()\n        opt.step()\n\n        # Validation using a full graph\n        model.eval()\n\n        with torch.no_grad():\n            valid_loss = loss_fn(logits[val_idx], labels[val_idx])\n            valid_acc = torch.sum(\n                logits[val_idx].argmax(dim=1) == labels[val_idx]\n            ).item() / len(val_idx)\n\n        # Print out performance\n        epochs.set_description(\n            \"Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}\".format(\n                train_acc, train_loss.item(), valid_acc, valid_loss.item()\n            )\n        )\n\n        if valid_acc > acc:\n            acc = valid_acc\n            best_model = copy.deepcopy(model)\n\n    best_model.eval()\n    logits = best_model(graph, feats)\n    test_acc = torch.sum(\n        logits[test_idx].argmax(dim=1) == labels[test_idx]\n    ).item() / len(test_idx)\n\n    print(\"Test Acc {:.4f}\".format(test_acc))\n    return test_acc\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    JKNet Hyperparameters\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"JKNet\")\n\n    # data source params\n    parser.add_argument(\n        \"--dataset\", type=str, default=\"Cora\", help=\"Name of 
dataset.\"\n    )\n    # cuda params\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU index. Default: -1, using CPU.\"\n    )\n    # training params\n    parser.add_argument(\"--run\", type=int, default=10, help=\"Running times.\")\n    parser.add_argument(\n        \"--epochs\", type=int, default=500, help=\"Training epochs.\"\n    )\n    parser.add_argument(\n        \"--lr\", type=float, default=0.005, help=\"Learning rate.\"\n    )\n    parser.add_argument(\"--lamb\", type=float, default=0.0005, help=\"L2 reg.\")\n    # model params\n    parser.add_argument(\n        \"--hid-dim\", type=int, default=32, help=\"Hidden layer dimensionalities.\"\n    )\n    parser.add_argument(\n        \"--num-layers\", type=int, default=5, help=\"Number of GCN layers.\"\n    )\n    parser.add_argument(\n        \"--mode\",\n        type=str,\n        default=\"cat\",\n        help=\"Type of aggregation.\",\n        choices=[\"cat\", \"max\", \"lstm\"],\n    )\n    parser.add_argument(\n        \"--dropout\",\n        type=float,\n        default=0.5,\n        help=\"Dropout applied at all layers.\",\n    )\n\n    args = parser.parse_args()\n    print(args)\n\n    acc_lists = []\n\n    for _ in range(args.run):\n        acc_lists.append(main(args))\n\n    mean = np.around(np.mean(acc_lists, axis=0), decimals=3)\n    std = np.around(np.std(acc_lists, axis=0), decimals=3)\n    print(\"total acc: \", acc_lists)\n    print(\"mean\", mean)\n    print(\"std\", std)\n"
  },
  {
    "path": "examples/pytorch/jknet/model.py",
    "content": "import dgl.function as fn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn import GraphConv, JumpingKnowledge\n\n\nclass JKNet(nn.Module):\n    def __init__(\n        self, in_dim, hid_dim, out_dim, num_layers=1, mode=\"cat\", dropout=0.0\n    ):\n        super(JKNet, self).__init__()\n\n        self.mode = mode\n        self.dropout = nn.Dropout(dropout)\n        self.layers = nn.ModuleList()\n        self.layers.append(GraphConv(in_dim, hid_dim, activation=F.relu))\n        for _ in range(num_layers):\n            self.layers.append(GraphConv(hid_dim, hid_dim, activation=F.relu))\n\n        if self.mode == \"lstm\":\n            self.jump = JumpingKnowledge(mode, hid_dim, num_layers)\n        else:\n            self.jump = JumpingKnowledge(mode)\n\n        if self.mode == \"cat\":\n            hid_dim = hid_dim * (num_layers + 1)\n\n        self.output = nn.Linear(hid_dim, out_dim)\n        self.reset_params()\n\n    def reset_params(self):\n        self.output.reset_parameters()\n        for layers in self.layers:\n            layers.reset_parameters()\n        self.jump.reset_parameters()\n\n    def forward(self, g, feats):\n        feat_lst = []\n        for layer in self.layers:\n            feats = self.dropout(layer(g, feats))\n            feat_lst.append(feats)\n\n        if self.mode == \"lstm\":\n            self.jump.lstm.flatten_parameters()\n\n        g.ndata[\"h\"] = self.jump(feat_lst)\n        g.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n\n        return self.output(g.ndata[\"h\"])\n"
  },
  {
    "path": "examples/pytorch/jtnn/README.md",
    "content": "Junction Tree VAE - example for training\n==========================================\n\nThis is a direct modification from https://github.com/wengong-jin/icml18-jtnn\n\nDependencies\n--------------\n* PyTorch 0.4.1+\n* RDKit=2018.09.3.0\n* requests\n\nHow to run\n-----------\n\nTo run the model, use\n```\npython3 vaetrain_dgl.py\n```\nThe script will automatically download the data, which is the same as the one in the\noriginal repository.\n\nTo disable CUDA, run with `NOCUDA` variable set:\n```\nNOCUDA=1 python3 vaetrain_dgl.py\n```\n\nTo decode for new molecules, run\n```\npython3 vaetrain_dgl.py -T\n```\n\nCurrently, decoding involves encoding a training example, sampling from the posterior\ndistribution, and decoding a molecule from that.\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/__init__.py",
    "content": "from .chemutils import decode_stereo\nfrom .datautils import JTNNCollator, JTNNDataset\nfrom .jtnn_vae import DGLJTNNVAE\nfrom .mol_tree import Vocab\nfrom .mpn import DGLMPN\nfrom .nnutils import cuda\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/chemutils.py",
    "content": "from collections import defaultdict\n\nimport rdkit.Chem as Chem\nfrom rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers\nfrom scipy.sparse import csr_matrix\nfrom scipy.sparse.csgraph import minimum_spanning_tree\n\nMST_MAX_WEIGHT = 100\nMAX_NCAND = 2000\n\n\ndef set_atommap(mol, num=0):\n    for atom in mol.GetAtoms():\n        atom.SetAtomMapNum(num)\n\n\ndef get_mol(smiles):\n    mol = Chem.MolFromSmiles(smiles)\n    if mol is None:\n        return None\n    Chem.Kekulize(mol)\n    return mol\n\n\ndef get_smiles(mol):\n    return Chem.MolToSmiles(mol, kekuleSmiles=True)\n\n\ndef decode_stereo(smiles2D):\n    mol = Chem.MolFromSmiles(smiles2D)\n    dec_isomers = list(EnumerateStereoisomers(mol))\n\n    dec_isomers = [\n        Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True))\n        for mol in dec_isomers\n    ]\n    smiles3D = [\n        Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers\n    ]\n\n    chiralN = [\n        atom.GetIdx()\n        for atom in dec_isomers[0].GetAtoms()\n        if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == \"N\"\n    ]\n    if len(chiralN) > 0:\n        for mol in dec_isomers:\n            for idx in chiralN:\n                mol.GetAtomWithIdx(idx).SetChiralTag(\n                    Chem.rdchem.ChiralType.CHI_UNSPECIFIED\n                )\n            smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))\n\n    return smiles3D\n\n\ndef sanitize(mol):\n    try:\n        smiles = get_smiles(mol)\n        mol = get_mol(smiles)\n    except Exception as e:\n        return None\n    return mol\n\n\ndef copy_atom(atom):\n    new_atom = Chem.Atom(atom.GetSymbol())\n    new_atom.SetFormalCharge(atom.GetFormalCharge())\n    new_atom.SetAtomMapNum(atom.GetAtomMapNum())\n    return new_atom\n\n\ndef copy_edit_mol(mol):\n    new_mol = Chem.RWMol(Chem.MolFromSmiles(\"\"))\n    for atom in mol.GetAtoms():\n        new_atom = copy_atom(atom)\n        
new_mol.AddAtom(new_atom)\n    for bond in mol.GetBonds():\n        a1 = bond.GetBeginAtom().GetIdx()\n        a2 = bond.GetEndAtom().GetIdx()\n        bt = bond.GetBondType()\n        new_mol.AddBond(a1, a2, bt)\n    return new_mol\n\n\ndef get_clique_mol(mol, atoms):\n    smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)\n    new_mol = Chem.MolFromSmiles(smiles, sanitize=False)\n    new_mol = copy_edit_mol(new_mol).GetMol()\n    new_mol = sanitize(new_mol)  # We assume this is not None\n    return new_mol\n\n\ndef tree_decomp(mol):\n    n_atoms = mol.GetNumAtoms()\n    if n_atoms == 1:\n        return [[0]], []\n\n    cliques = []\n    for bond in mol.GetBonds():\n        a1 = bond.GetBeginAtom().GetIdx()\n        a2 = bond.GetEndAtom().GetIdx()\n        if not bond.IsInRing():\n            cliques.append([a1, a2])\n\n    ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]\n    cliques.extend(ssr)\n\n    nei_list = [[] for i in range(n_atoms)]\n    for i in range(len(cliques)):\n        for atom in cliques[i]:\n            nei_list[atom].append(i)\n\n    # Merge Rings with intersection > 2 atoms\n    for i in range(len(cliques)):\n        if len(cliques[i]) <= 2:\n            continue\n        for atom in cliques[i]:\n            for j in nei_list[atom]:\n                if i >= j or len(cliques[j]) <= 2:\n                    continue\n                inter = set(cliques[i]) & set(cliques[j])\n                if len(inter) > 2:\n                    cliques[i].extend(cliques[j])\n                    cliques[i] = list(set(cliques[i]))\n                    cliques[j] = []\n\n    cliques = [c for c in cliques if len(c) > 0]\n    nei_list = [[] for i in range(n_atoms)]\n    for i in range(len(cliques)):\n        for atom in cliques[i]:\n            nei_list[atom].append(i)\n\n    # Build edges and add singleton cliques\n    edges = defaultdict(int)\n    for atom in range(n_atoms):\n        if len(nei_list[atom]) <= 1:\n            continue\n        cnei = 
nei_list[atom]\n        bonds = [c for c in cnei if len(cliques[c]) == 2]\n        rings = [c for c in cnei if len(cliques[c]) > 4]\n        # In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.\n        if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):\n            cliques.append([atom])\n            c2 = len(cliques) - 1\n            for c1 in cnei:\n                edges[(c1, c2)] = 1\n        elif len(rings) > 2:  # Multiple (n>2) complex rings\n            cliques.append([atom])\n            c2 = len(cliques) - 1\n            for c1 in cnei:\n                edges[(c1, c2)] = MST_MAX_WEIGHT - 1\n        else:\n            for i in range(len(cnei)):\n                for j in range(i + 1, len(cnei)):\n                    c1, c2 = cnei[i], cnei[j]\n                    inter = set(cliques[c1]) & set(cliques[c2])\n                    if edges[(c1, c2)] < len(inter):\n                        edges[(c1, c2)] = len(\n                            inter\n                        )  # cnei[i] < cnei[j] by construction\n\n    edges = [u + (MST_MAX_WEIGHT - v,) for u, v in edges.items()]\n    if len(edges) == 0:\n        return cliques, edges\n\n    # Compute Maximum Spanning Tree\n    row, col, data = list(zip(*edges))\n    n_clique = len(cliques)\n    clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))\n    junc_tree = minimum_spanning_tree(clique_graph)\n    row, col = junc_tree.nonzero()\n    edges = [(row[i], col[i]) for i in range(len(row))]\n    return (cliques, edges)\n\n\ndef atom_equal(a1, a2):\n    return (\n        a1.GetSymbol() == a2.GetSymbol()\n        and a1.GetFormalCharge() == a2.GetFormalCharge()\n    )\n\n\n# Bond type not considered because all aromatic (so SINGLE matches DOUBLE)\ndef ring_bond_equal(b1, b2, reverse=False):\n    b1 = (b1.GetBeginAtom(), b1.GetEndAtom())\n    if reverse:\n        b2 = (b2.GetEndAtom(), b2.GetBeginAtom())\n    else:\n        b2 = 
(b2.GetBeginAtom(), b2.GetEndAtom())\n    return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])\n\n\ndef attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):\n    prev_nids = [node[\"nid\"] for node in prev_nodes]\n    for nei_node in prev_nodes + neighbors:\n        nei_id, nei_mol = nei_node[\"nid\"], nei_node[\"mol\"]\n        amap = nei_amap[nei_id]\n        for atom in nei_mol.GetAtoms():\n            if atom.GetIdx() not in amap:\n                new_atom = copy_atom(atom)\n                amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)\n\n        if nei_mol.GetNumBonds() == 0:\n            nei_atom = nei_mol.GetAtomWithIdx(0)\n            ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])\n            ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())\n        else:\n            for bond in nei_mol.GetBonds():\n                a1 = amap[bond.GetBeginAtom().GetIdx()]\n                a2 = amap[bond.GetEndAtom().GetIdx()]\n                if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:\n                    ctr_mol.AddBond(a1, a2, bond.GetBondType())\n                elif nei_id in prev_nids:  # father node overrides\n                    ctr_mol.RemoveBond(a1, a2)\n                    ctr_mol.AddBond(a1, a2, bond.GetBondType())\n    return ctr_mol\n\n\ndef local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):\n    ctr_mol = copy_edit_mol(ctr_mol)\n    nei_amap = {nei[\"nid\"]: {} for nei in prev_nodes + neighbors}\n\n    for nei_id, ctr_atom, nei_atom in amap_list:\n        nei_amap[nei_id][nei_atom] = ctr_atom\n\n    ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)\n    return ctr_mol.GetMol()\n\n\n# This version records idx mapping between ctr_mol and nei_mol\ndef enum_attach_nx(ctr_mol, nei_node, amap, singletons):\n    nei_mol, nei_idx = nei_node[\"mol\"], nei_node[\"nid\"]\n    att_confs = []\n    black_list = [\n        atom_idx for nei_id, atom_idx, _ in amap if nei_id in singletons\n    ]\n    ctr_atoms = [\n        atom for 
atom in ctr_mol.GetAtoms() if atom.GetIdx() not in black_list\n    ]\n    ctr_bonds = [bond for bond in ctr_mol.GetBonds()]\n\n    if nei_mol.GetNumBonds() == 0:  # neighbor singleton\n        nei_atom = nei_mol.GetAtomWithIdx(0)\n        used_list = [atom_idx for _, atom_idx, _ in amap]\n        for atom in ctr_atoms:\n            if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:\n                new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]\n                att_confs.append(new_amap)\n\n    elif nei_mol.GetNumBonds() == 1:  # neighbor is a bond\n        bond = nei_mol.GetBondWithIdx(0)\n        bond_val = int(bond.GetBondTypeAsDouble())\n        b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()\n\n        for atom in ctr_atoms:\n            # Optimize if atom is carbon (other atoms may change valence)\n            if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:\n                continue\n            if atom_equal(atom, b1):\n                new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]\n                att_confs.append(new_amap)\n            elif atom_equal(atom, b2):\n                new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]\n                att_confs.append(new_amap)\n    else:\n        # intersection is an atom\n        for a1 in ctr_atoms:\n            for a2 in nei_mol.GetAtoms():\n                if atom_equal(a1, a2):\n                    # Optimize if atom is carbon (other atoms may change valence)\n                    if (\n                        a1.GetAtomicNum() == 6\n                        and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4\n                    ):\n                        continue\n                    new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]\n                    att_confs.append(new_amap)\n\n        # intersection is an bond\n        if ctr_mol.GetNumBonds() > 1:\n            for b1 in ctr_bonds:\n                for b2 in nei_mol.GetBonds():\n                   
 if ring_bond_equal(b1, b2):\n                        new_amap = amap + [\n                            (\n                                nei_idx,\n                                b1.GetBeginAtom().GetIdx(),\n                                b2.GetBeginAtom().GetIdx(),\n                            ),\n                            (\n                                nei_idx,\n                                b1.GetEndAtom().GetIdx(),\n                                b2.GetEndAtom().GetIdx(),\n                            ),\n                        ]\n                        att_confs.append(new_amap)\n\n                    if ring_bond_equal(b1, b2, reverse=True):\n                        new_amap = amap + [\n                            (\n                                nei_idx,\n                                b1.GetBeginAtom().GetIdx(),\n                                b2.GetEndAtom().GetIdx(),\n                            ),\n                            (\n                                nei_idx,\n                                b1.GetEndAtom().GetIdx(),\n                                b2.GetBeginAtom().GetIdx(),\n                            ),\n                        ]\n                        att_confs.append(new_amap)\n    return att_confs\n\n\n# Try rings first: Speed-Up\ndef enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):\n    all_attach_confs = []\n    singletons = [\n        nei_node[\"nid\"]\n        for nei_node in neighbors + prev_nodes\n        if nei_node[\"mol\"].GetNumAtoms() == 1\n    ]\n\n    def search(cur_amap, depth):\n        if len(all_attach_confs) > MAX_NCAND:\n            return\n        if depth == len(neighbors):\n            all_attach_confs.append(cur_amap)\n            return\n\n        nei_node = neighbors[depth]\n        cand_amap = enum_attach_nx(node[\"mol\"], nei_node, cur_amap, singletons)\n        cand_smiles = set()\n        candidates = []\n        for amap in cand_amap:\n            cand_mol = 
local_attach_nx(\n                node[\"mol\"], neighbors[: depth + 1], prev_nodes, amap\n            )\n            cand_mol = sanitize(cand_mol)\n            if cand_mol is None:\n                continue\n            smiles = get_smiles(cand_mol)\n            if smiles in cand_smiles:\n                continue\n            cand_smiles.add(smiles)\n            candidates.append(amap)\n\n        if len(candidates) == 0:\n            return []\n\n        for new_amap in candidates:\n            search(new_amap, depth + 1)\n\n    search(prev_amap, 0)\n    cand_smiles = set()\n    candidates = []\n    for amap in all_attach_confs:\n        cand_mol = local_attach_nx(node[\"mol\"], neighbors, prev_nodes, amap)\n        cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))\n        smiles = Chem.MolToSmiles(cand_mol)\n        if smiles in cand_smiles:\n            continue\n        cand_smiles.add(smiles)\n        Chem.Kekulize(cand_mol)\n        candidates.append((smiles, cand_mol, amap))\n\n    return candidates\n\n\n# Only used for debugging purpose\ndef dfs_assemble_nx(\n    graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id\n):\n    cur_node = graph.nodes_dict[cur_node_id]\n    fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None\n\n    fa_nid = fa_node[\"nid\"] if fa_node is not None else -1\n    prev_nodes = [fa_node] if fa_node is not None else []\n\n    children_id = [\n        nei\n        for nei in graph[cur_node_id]\n        if graph.nodes_dict[nei][\"nid\"] != fa_nid\n    ]\n    children = [graph.nodes_dict[nei] for nei in children_id]\n    neighbors = [nei for nei in children if nei[\"mol\"].GetNumAtoms() > 1]\n    neighbors = sorted(\n        neighbors, key=lambda x: x[\"mol\"].GetNumAtoms(), reverse=True\n    )\n    singletons = [nei for nei in children if nei[\"mol\"].GetNumAtoms() == 1]\n    neighbors = singletons + neighbors\n\n    cur_amap = [\n        (fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == 
cur_node[\"nid\"]\n    ]\n    cands = enum_assemble_nx(\n        graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap\n    )\n    if len(cands) == 0:\n        return\n\n    cand_smiles, _, cand_amap = zip(*cands)\n    label_idx = cand_smiles.index(cur_node[\"label\"])\n    label_amap = cand_amap[label_idx]\n\n    for nei_id, ctr_atom, nei_atom in label_amap:\n        if nei_id == fa_nid:\n            continue\n        global_amap[nei_id][nei_atom] = global_amap[cur_node[\"nid\"]][ctr_atom]\n\n    cur_mol = attach_mols_nx(\n        cur_mol, children, [], global_amap\n    )  # father is already attached\n    for nei_node_id, nei_node in zip(children_id, children):\n        if not nei_node[\"is_leaf\"]:\n            dfs_assemble_nx(\n                graph,\n                cur_mol,\n                global_amap,\n                label_amap,\n                nei_node_id,\n                cur_node_id,\n            )\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/datautils.py",
    "content": "import dgl\nimport torch\nfrom dgl.data.utils import (\n    _get_dgl_url,\n    download,\n    extract_archive,\n    get_download_dir,\n)\nfrom torch.utils.data import Dataset\n\nfrom .jtmpn import (\n    ATOM_FDIM as ATOM_FDIM_DEC,\n    BOND_FDIM as BOND_FDIM_DEC,\n    mol2dgl_single as mol2dgl_dec,\n)\nfrom .mol_tree import Vocab\nfrom .mol_tree_nx import DGLMolTree\nfrom .mpn import mol2dgl_single as mol2dgl_enc\n\n\ndef _unpack_field(examples, field):\n    return [e[field] for e in examples]\n\n\ndef _set_node_id(mol_tree, vocab):\n    wid = []\n    for i, node in enumerate(mol_tree.nodes_dict):\n        mol_tree.nodes_dict[node][\"idx\"] = i\n        wid.append(vocab.get_index(mol_tree.nodes_dict[node][\"smiles\"]))\n\n    return wid\n\n\nclass JTNNDataset(Dataset):\n    def __init__(self, data, vocab, training=True):\n        self.dir = get_download_dir()\n        self.zip_file_path = \"{}/jtnn.zip\".format(self.dir)\n\n        download(_get_dgl_url(\"dgllife/jtnn.zip\"), path=self.zip_file_path)\n        extract_archive(self.zip_file_path, \"{}/jtnn\".format(self.dir))\n        print(\"Loading data...\")\n        data_file = \"{}/jtnn/{}.txt\".format(self.dir, data)\n        with open(data_file) as f:\n            self.data = [line.strip(\"\\r\\n \").split()[0] for line in f]\n        self.vocab_file = \"{}/jtnn/{}.txt\".format(self.dir, vocab)\n        print(\"Loading finished.\")\n        print(\"\\tNum samples:\", len(self.data))\n        print(\"\\tVocab file:\", self.vocab_file)\n        self.training = training\n        self.vocab = Vocab([x.strip(\"\\r\\n \") for x in open(self.vocab_file)])\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, idx):\n        smiles = self.data[idx]\n        mol_tree = DGLMolTree(smiles)\n        mol_tree.recover()\n        mol_tree.assemble()\n\n        wid = _set_node_id(mol_tree, self.vocab)\n\n        # prebuild the molecule graph\n        mol_graph, atom_x_enc, 
bond_x_enc = mol2dgl_enc(mol_tree.smiles)\n\n        result = {\n            \"mol_tree\": mol_tree,\n            \"mol_graph\": mol_graph,\n            \"atom_x_enc\": atom_x_enc,\n            \"bond_x_enc\": bond_x_enc,\n            \"wid\": wid,\n        }\n\n        if not self.training:\n            return result\n\n        # prebuild the candidate graph list\n        cands = []\n        for node_id, node in mol_tree.nodes_dict.items():\n            # fill in ground truth\n            if node[\"label\"] not in node[\"cands\"]:\n                node[\"cands\"].append(node[\"label\"])\n                node[\"cand_mols\"].append(node[\"label_mol\"])\n\n            if node[\"is_leaf\"] or len(node[\"cands\"]) == 1:\n                continue\n            cands.extend(\n                [(cand, mol_tree, node_id) for cand in node[\"cand_mols\"]]\n            )\n        if len(cands) > 0:\n            (\n                cand_graphs,\n                atom_x_dec,\n                bond_x_dec,\n                tree_mess_src_e,\n                tree_mess_tgt_e,\n                tree_mess_tgt_n,\n            ) = mol2dgl_dec(cands)\n        else:\n            cand_graphs = []\n            atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)\n            bond_x_dec = torch.zeros(0, BOND_FDIM_DEC)\n            tree_mess_src_e = torch.zeros(0, 2).long()\n            tree_mess_tgt_e = torch.zeros(0, 2).long()\n            tree_mess_tgt_n = torch.zeros(0).long()\n\n        # prebuild the stereoisomers\n        cands = mol_tree.stereo_cands\n        if len(cands) > 1:\n            if mol_tree.smiles3D not in cands:\n                cands.append(mol_tree.smiles3D)\n\n            stereo_graphs = [mol2dgl_enc(c) for c in cands]\n            stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = zip(\n                *stereo_graphs\n            )\n            stereo_atom_x_enc = torch.cat(stereo_atom_x_enc)\n            stereo_bond_x_enc = torch.cat(stereo_bond_x_enc)\n            
stereo_cand_label = [(cands.index(mol_tree.smiles3D), len(cands))]\n        else:\n            stereo_cand_graphs = []\n            stereo_atom_x_enc = torch.zeros(0, atom_x_enc.shape[1])\n            stereo_bond_x_enc = torch.zeros(0, bond_x_enc.shape[1])\n            stereo_cand_label = []\n\n        result.update(\n            {\n                \"cand_graphs\": cand_graphs,\n                \"atom_x_dec\": atom_x_dec,\n                \"bond_x_dec\": bond_x_dec,\n                \"tree_mess_src_e\": tree_mess_src_e,\n                \"tree_mess_tgt_e\": tree_mess_tgt_e,\n                \"tree_mess_tgt_n\": tree_mess_tgt_n,\n                \"stereo_cand_graphs\": stereo_cand_graphs,\n                \"stereo_atom_x_enc\": stereo_atom_x_enc,\n                \"stereo_bond_x_enc\": stereo_bond_x_enc,\n                \"stereo_cand_label\": stereo_cand_label,\n            }\n        )\n\n        return result\n\n\nclass JTNNCollator(object):\n    def __init__(self, vocab, training):\n        self.vocab = vocab\n        self.training = training\n\n    @staticmethod\n    def _batch_and_set(graphs, atom_x, bond_x, flatten):\n        if flatten:\n            graphs = [g for f in graphs for g in f]\n        graph_batch = dgl.batch(graphs)\n        graph_batch.ndata[\"x\"] = atom_x\n        graph_batch.edata.update(\n            {\n                \"x\": bond_x,\n                \"src_x\": atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_(),\n            }\n        )\n        return graph_batch\n\n    def __call__(self, examples):\n        # get list of trees\n        mol_trees = _unpack_field(examples, \"mol_tree\")\n        wid = _unpack_field(examples, \"wid\")\n        for _wid, mol_tree in zip(wid, mol_trees):\n            mol_tree.graph.ndata[\"wid\"] = torch.LongTensor(_wid)\n\n        # TODO: either support pickling or get around ctypes pointers using scipy\n        # batch molecule graphs\n        mol_graphs = _unpack_field(examples, \"mol_graph\")\n        
atom_x = torch.cat(_unpack_field(examples, \"atom_x_enc\"))\n        bond_x = torch.cat(_unpack_field(examples, \"bond_x_enc\"))\n        mol_graph_batch = self._batch_and_set(mol_graphs, atom_x, bond_x, False)\n\n        result = {\n            \"mol_trees\": mol_trees,\n            \"mol_graph_batch\": mol_graph_batch,\n        }\n\n        if not self.training:\n            return result\n\n        # batch candidate graphs\n        cand_graphs = _unpack_field(examples, \"cand_graphs\")\n        cand_batch_idx = []\n        atom_x = torch.cat(_unpack_field(examples, \"atom_x_dec\"))\n        bond_x = torch.cat(_unpack_field(examples, \"bond_x_dec\"))\n        tree_mess_src_e = _unpack_field(examples, \"tree_mess_src_e\")\n        tree_mess_tgt_e = _unpack_field(examples, \"tree_mess_tgt_e\")\n        tree_mess_tgt_n = _unpack_field(examples, \"tree_mess_tgt_n\")\n\n        n_graph_nodes = 0\n        n_tree_nodes = 0\n        for i in range(len(cand_graphs)):\n            tree_mess_tgt_e[i] += n_graph_nodes\n            tree_mess_src_e[i] += n_tree_nodes\n            tree_mess_tgt_n[i] += n_graph_nodes\n            n_graph_nodes += sum(g.num_nodes() for g in cand_graphs[i])\n            n_tree_nodes += mol_trees[i].graph.num_nodes()\n            cand_batch_idx.extend([i] * len(cand_graphs[i]))\n        tree_mess_tgt_e = torch.cat(tree_mess_tgt_e)\n        tree_mess_src_e = torch.cat(tree_mess_src_e)\n        tree_mess_tgt_n = torch.cat(tree_mess_tgt_n)\n\n        cand_graph_batch = self._batch_and_set(\n            cand_graphs, atom_x, bond_x, True\n        )\n\n        # batch stereoisomers\n        stereo_cand_graphs = _unpack_field(examples, \"stereo_cand_graphs\")\n        atom_x = torch.cat(_unpack_field(examples, \"stereo_atom_x_enc\"))\n        bond_x = torch.cat(_unpack_field(examples, \"stereo_bond_x_enc\"))\n        stereo_cand_batch_idx = []\n        for i in range(len(stereo_cand_graphs)):\n            stereo_cand_batch_idx.extend([i] * 
len(stereo_cand_graphs[i]))\n\n        if len(stereo_cand_batch_idx) > 0:\n            stereo_cand_labels = [\n                (label, length)\n                for ex in _unpack_field(examples, \"stereo_cand_label\")\n                for label, length in ex\n            ]\n            stereo_cand_labels, stereo_cand_lengths = zip(*stereo_cand_labels)\n            stereo_cand_graph_batch = self._batch_and_set(\n                stereo_cand_graphs, atom_x, bond_x, True\n            )\n        else:\n            stereo_cand_labels = []\n            stereo_cand_lengths = []\n            stereo_cand_graph_batch = None\n            stereo_cand_batch_idx = []\n\n        result.update(\n            {\n                \"cand_graph_batch\": cand_graph_batch,\n                \"cand_batch_idx\": cand_batch_idx,\n                \"tree_mess_tgt_e\": tree_mess_tgt_e,\n                \"tree_mess_src_e\": tree_mess_src_e,\n                \"tree_mess_tgt_n\": tree_mess_tgt_n,\n                \"stereo_cand_graph_batch\": stereo_cand_graph_batch,\n                \"stereo_cand_batch_idx\": stereo_cand_batch_idx,\n                \"stereo_cand_labels\": stereo_cand_labels,\n                \"stereo_cand_lengths\": stereo_cand_lengths,\n            }\n        )\n\n        return result\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/jtmpn.py",
    "content": "import os\n\nimport dgl\nimport dgl.function as DGLF\nimport rdkit.Chem as Chem\nimport torch\nimport torch.nn as nn\nfrom dgl import line_graph, mean_nodes\n\nfrom .nnutils import cuda\n\nELEM_LIST = [\n    \"C\",\n    \"N\",\n    \"O\",\n    \"S\",\n    \"F\",\n    \"Si\",\n    \"P\",\n    \"Cl\",\n    \"Br\",\n    \"Mg\",\n    \"Na\",\n    \"Ca\",\n    \"Fe\",\n    \"Al\",\n    \"I\",\n    \"B\",\n    \"K\",\n    \"Se\",\n    \"Zn\",\n    \"H\",\n    \"Cu\",\n    \"Mn\",\n    \"unknown\",\n]\n\nATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1\nBOND_FDIM = 5\nMAX_NB = 10\n\nPAPER = os.getenv(\"PAPER\", False)\n\n\ndef onek_encoding_unk(x, allowable_set):\n    if x not in allowable_set:\n        x = allowable_set[-1]\n    return [x == s for s in allowable_set]\n\n\n# Note that during graph decoding they don't predict stereochemistry-related\n# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans).  Instead, they decode\n# the 2-D graph first, then enumerate all possible 3-D forms and find the\n# one with highest score.\ndef atom_features(atom):\n    return torch.Tensor(\n        onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)\n        + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])\n        + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])\n        + [atom.GetIsAromatic()]\n    )\n\n\ndef bond_features(bond):\n    bt = bond.GetBondType()\n    return torch.Tensor(\n        [\n            bt == Chem.rdchem.BondType.SINGLE,\n            bt == Chem.rdchem.BondType.DOUBLE,\n            bt == Chem.rdchem.BondType.TRIPLE,\n            bt == Chem.rdchem.BondType.AROMATIC,\n            bond.IsInRing(),\n        ]\n    )\n\n\ndef mol2dgl_single(cand_batch):\n    cand_graphs = []\n    tree_mess_source_edges = []  # map these edges from trees to...\n    tree_mess_target_edges = []  # these edges on candidate graphs\n    tree_mess_target_nodes = []\n    n_nodes = 0\n    n_edges = 0\n    atom_x = []\n    bond_x = []\n\n    for mol, mol_tree, 
ctr_node_id in cand_batch:\n        n_atoms = mol.GetNumAtoms()\n        n_bonds = mol.GetNumBonds()\n\n        ctr_node = mol_tree.nodes_dict[ctr_node_id]\n        ctr_bid = ctr_node[\"idx\"]\n        mol_tree_graph = getattr(mol_tree, \"graph\", mol_tree)\n\n        for i, atom in enumerate(mol.GetAtoms()):\n            assert i == atom.GetIdx()\n            atom_x.append(atom_features(atom))\n\n        bond_src = []\n        bond_dst = []\n        for i, bond in enumerate(mol.GetBonds()):\n            a1 = bond.GetBeginAtom()\n            a2 = bond.GetEndAtom()\n            begin_idx = a1.GetIdx()\n            end_idx = a2.GetIdx()\n            features = bond_features(bond)\n\n            bond_src.append(begin_idx)\n            bond_dst.append(end_idx)\n            bond_x.append(features)\n            bond_src.append(end_idx)\n            bond_dst.append(begin_idx)\n            bond_x.append(features)\n\n            x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()\n            # Tree node ID in the batch\n            x_bid = mol_tree.nodes_dict[x_nid - 1][\"idx\"] if x_nid > 0 else -1\n            y_bid = mol_tree.nodes_dict[y_nid - 1][\"idx\"] if y_nid > 0 else -1\n            if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:\n                if mol_tree_graph.has_edges_between(x_bid, y_bid):\n                    tree_mess_target_edges.append(\n                        (begin_idx + n_nodes, end_idx + n_nodes)\n                    )\n                    tree_mess_source_edges.append((x_bid, y_bid))\n                    tree_mess_target_nodes.append(end_idx + n_nodes)\n                if mol_tree_graph.has_edges_between(y_bid, x_bid):\n                    tree_mess_target_edges.append(\n                        (end_idx + n_nodes, begin_idx + n_nodes)\n                    )\n                    tree_mess_source_edges.append((y_bid, x_bid))\n                    tree_mess_target_nodes.append(begin_idx + n_nodes)\n\n        n_nodes += n_atoms\n        g = 
dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)\n        cand_graphs.append(g)\n\n    return (\n        cand_graphs,\n        torch.stack(atom_x),\n        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0),\n        torch.LongTensor(tree_mess_source_edges),\n        torch.LongTensor(tree_mess_target_edges),\n        torch.LongTensor(tree_mess_target_nodes),\n    )\n\n\nclass LoopyBPUpdate(nn.Module):\n    def __init__(self, hidden_size):\n        super(LoopyBPUpdate, self).__init__()\n        self.hidden_size = hidden_size\n\n        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)\n\n    def forward(self, node):\n        msg_input = node.data[\"msg_input\"]\n        msg_delta = self.W_h(node.data[\"accum_msg\"] + node.data[\"alpha\"])\n        msg = torch.relu(msg_input + msg_delta)\n        return {\"msg\": msg}\n\n\nif PAPER:\n    mpn_gather_msg = [\n        DGLF.copy_e(edge=\"msg\", out=\"msg\"),\n        DGLF.copy_e(edge=\"alpha\", out=\"alpha\"),\n    ]\nelse:\n    mpn_gather_msg = DGLF.copy_e(edge=\"msg\", out=\"msg\")\n\n\nif PAPER:\n    mpn_gather_reduce = [\n        DGLF.sum(msg=\"msg\", out=\"m\"),\n        DGLF.sum(msg=\"alpha\", out=\"accum_alpha\"),\n    ]\nelse:\n    mpn_gather_reduce = DGLF.sum(msg=\"msg\", out=\"m\")\n\n\nclass GatherUpdate(nn.Module):\n    def __init__(self, hidden_size):\n        super(GatherUpdate, self).__init__()\n        self.hidden_size = hidden_size\n\n        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)\n\n    def forward(self, node):\n        if PAPER:\n            # m = node['m']\n            m = node.data[\"m\"] + node.data[\"accum_alpha\"]\n        else:\n            m = node.data[\"m\"] + node.data[\"alpha\"]\n        return {\n            \"h\": torch.relu(self.W_o(torch.cat([node.data[\"x\"], m], 1))),\n        }\n\n\nclass DGLJTMPN(nn.Module):\n    def __init__(self, hidden_size, depth):\n        nn.Module.__init__(self)\n\n        self.depth = depth\n\n        self.W_i = 
nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)\n\n        self.loopy_bp_updater = LoopyBPUpdate(hidden_size)\n        self.gather_updater = GatherUpdate(hidden_size)\n        self.hidden_size = hidden_size\n\n        self.n_samples_total = 0\n        self.n_nodes_total = 0\n        self.n_edges_total = 0\n        self.n_passes = 0\n\n    def forward(self, cand_batch, mol_tree_batch):\n        (\n            cand_graphs,\n            tree_mess_src_edges,\n            tree_mess_tgt_edges,\n            tree_mess_tgt_nodes,\n        ) = cand_batch\n\n        n_samples = len(cand_graphs)\n\n        cand_line_graph = line_graph(\n            cand_graphs, backtracking=False, shared=True\n        )\n\n        n_nodes = cand_graphs.num_nodes()\n        n_edges = cand_graphs.num_edges()\n\n        cand_graphs = self.run(\n            cand_graphs,\n            cand_line_graph,\n            tree_mess_src_edges,\n            tree_mess_tgt_edges,\n            tree_mess_tgt_nodes,\n            mol_tree_batch,\n        )\n\n        g_repr = mean_nodes(cand_graphs, \"h\")\n\n        self.n_samples_total += n_samples\n        self.n_nodes_total += n_nodes\n        self.n_edges_total += n_edges\n        self.n_passes += 1\n\n        return g_repr\n\n    def run(\n        self,\n        cand_graphs,\n        cand_line_graph,\n        tree_mess_src_edges,\n        tree_mess_tgt_edges,\n        tree_mess_tgt_nodes,\n        mol_tree_batch,\n    ):\n        n_nodes = cand_graphs.num_nodes()\n\n        cand_graphs.apply_edges(\n            func=lambda edges: {\"src_x\": edges.src[\"x\"]},\n        )\n        cand_line_graph.ndata.update(cand_graphs.edata)\n\n        bond_features = cand_line_graph.ndata[\"x\"]\n        source_features = cand_line_graph.ndata[\"src_x\"]\n        features = torch.cat([source_features, bond_features], 1)\n        msg_input = self.W_i(features)\n        cand_line_graph.ndata.update(\n            {\n                \"msg_input\": msg_input,\n        
        \"msg\": torch.relu(msg_input),\n                \"accum_msg\": torch.zeros_like(msg_input),\n            }\n        )\n        zero_node_state = bond_features.new(n_nodes, self.hidden_size).zero_()\n        cand_graphs.ndata.update(\n            {\n                \"m\": zero_node_state.clone(),\n                \"h\": zero_node_state.clone(),\n            }\n        )\n\n        cand_graphs.edata[\"alpha\"] = cuda(\n            torch.zeros(cand_graphs.num_edges(), self.hidden_size)\n        )\n        cand_graphs.ndata[\"alpha\"] = zero_node_state\n        if tree_mess_src_edges.shape[0] > 0:\n            if PAPER:\n                src_u, src_v = tree_mess_src_edges.unbind(1)\n                tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)\n                src_u = src_u.to(mol_tree_batch.device)\n                src_v = src_v.to(mol_tree_batch.device)\n                eid = mol_tree_batch.edge_ids(src_u, src_v)\n                alpha = mol_tree_batch.edata[\"m\"][eid]\n                cand_graphs.edges[tgt_u, tgt_v].data[\"alpha\"] = alpha\n            else:\n                src_u, src_v = tree_mess_src_edges.unbind(1)\n                src_u = src_u.to(mol_tree_batch.device)\n                src_v = src_v.to(mol_tree_batch.device)\n                eid = mol_tree_batch.edge_ids(src_u, src_v)\n                alpha = mol_tree_batch.edata[\"m\"][eid]\n                node_idx = tree_mess_tgt_nodes.to(\n                    device=zero_node_state.device\n                )[:, None].expand_as(alpha)\n                node_alpha = zero_node_state.clone().scatter_add(\n                    0, node_idx, alpha\n                )\n                cand_graphs.ndata[\"alpha\"] = node_alpha\n                cand_graphs.apply_edges(\n                    func=lambda edges: {\"alpha\": edges.src[\"alpha\"]},\n                )\n\n        cand_line_graph.ndata.update(cand_graphs.edata)\n        for i in range(self.depth - 1):\n            cand_line_graph.update_all(\n            
    DGLF.copy_u(\"msg\", \"msg\"), DGLF.sum(\"msg\", \"accum_msg\")\n            )\n            cand_line_graph.apply_nodes(self.loopy_bp_updater)\n\n        cand_graphs.edata.update(cand_line_graph.ndata)\n\n        cand_graphs.update_all(DGLF.copy_e(\"msg\", \"msg\"), DGLF.sum(\"msg\", \"m\"))\n        if PAPER:\n            cand_graphs.update_all(\n                DGLF.copy_e(\"alpha\", \"alpha\"), DGLF.sum(\"alpha\", \"accum_alpha\")\n            )\n        cand_graphs.apply_nodes(self.gather_updater)\n\n        return cand_graphs\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/jtnn_dec.py",
    "content": "import dgl.function as DGLF\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import batch, dfs_labeled_edges_generator, line_graph\n\nfrom .chemutils import enum_assemble_nx, get_mol\nfrom .mol_tree_nx import DGLMolTree\nfrom .nnutils import cuda, GRUUpdate, tocpu\n\nMAX_NB = 8\nMAX_DECODE_LEN = 100\n\n\ndef dfs_order(forest, roots):\n    forest = tocpu(forest)\n    edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)\n    for e, l in zip(*edges):\n        # I exploited the fact that the reverse edge ID equal to 1 xor forward\n        # edge ID for molecule trees.  Normally, I should locate reverse edges\n        # using find_edges().\n        yield e ^ l, l\n\n\ndec_tree_node_msg = DGLF.copy_e(edge=\"m\", out=\"m\")\ndec_tree_node_reduce = DGLF.sum(msg=\"m\", out=\"h\")\n\n\ndef dec_tree_node_update(nodes):\n    return {\"new\": nodes.data[\"new\"].clone().zero_()}\n\n\ndef have_slots(fa_slots, ch_slots):\n    if len(fa_slots) > 2 and len(ch_slots) > 2:\n        return True\n    matches = []\n    for i, s1 in enumerate(fa_slots):\n        a1, c1, h1 = s1\n        for j, s2 in enumerate(ch_slots):\n            a2, c2, h2 = s2\n            if a1 == a2 and c1 == c2 and (a1 != \"C\" or h1 + h2 >= 4):\n                matches.append((i, j))\n\n    if len(matches) == 0:\n        return False\n\n    fa_match, ch_match = list(zip(*matches))\n    if (\n        len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2\n    ):  # never remove atom from ring\n        fa_slots.pop(fa_match[0])\n    if (\n        len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2\n    ):  # never remove atom from ring\n        ch_slots.pop(ch_match[0])\n\n    return True\n\n\ndef can_assemble(mol_tree, u, v_node_dict):\n    u_node_dict = mol_tree.nodes_dict[u]\n    u_neighbors = mol_tree.graph.successors(u)\n    u_neighbors_node_dict = [\n        mol_tree.nodes_dict[_u]\n        for _u in u_neighbors\n        if _u 
in mol_tree.nodes_dict\n    ]\n    neis = u_neighbors_node_dict + [v_node_dict]\n    for i, nei in enumerate(neis):\n        nei[\"nid\"] = i\n\n    neighbors = [nei for nei in neis if nei[\"mol\"].GetNumAtoms() > 1]\n    neighbors = sorted(\n        neighbors, key=lambda x: x[\"mol\"].GetNumAtoms(), reverse=True\n    )\n    singletons = [nei for nei in neis if nei[\"mol\"].GetNumAtoms() == 1]\n    neighbors = singletons + neighbors\n    cands = enum_assemble_nx(u_node_dict, neighbors)\n    return len(cands) > 0\n\n\ndef create_node_dict(smiles, clique=[]):\n    return dict(\n        smiles=smiles,\n        mol=get_mol(smiles),\n        clique=clique,\n    )\n\n\nclass DGLJTNNDecoder(nn.Module):\n    def __init__(self, vocab, hidden_size, latent_size, embedding=None):\n        nn.Module.__init__(self)\n\n        self.hidden_size = hidden_size\n        self.vocab_size = vocab.size()\n        self.vocab = vocab\n\n        if embedding is None:\n            self.embedding = nn.Embedding(self.vocab_size, hidden_size)\n        else:\n            self.embedding = embedding\n\n        self.dec_tree_edge_update = GRUUpdate(hidden_size)\n\n        self.W = nn.Linear(latent_size + hidden_size, hidden_size)\n        self.U = nn.Linear(latent_size + 2 * hidden_size, hidden_size)\n        self.W_o = nn.Linear(hidden_size, self.vocab_size)\n        self.U_s = nn.Linear(hidden_size, 1)\n\n    def forward(self, mol_trees, tree_vec):\n        \"\"\"\n        The training procedure which computes the prediction loss given the\n        ground truth tree\n        \"\"\"\n        mol_tree_batch = batch(mol_trees)\n        mol_tree_batch_lg = line_graph(\n            mol_tree_batch, backtracking=False, shared=True\n        )\n        n_trees = len(mol_trees)\n\n        return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)\n\n    def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):\n        node_offset = np.cumsum(\n            
np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0)\n        )\n        root_ids = node_offset[:-1]\n        n_nodes = mol_tree_batch.num_nodes()\n        n_edges = mol_tree_batch.num_edges()\n\n        mol_tree_batch.ndata.update(\n            {\n                \"x\": self.embedding(mol_tree_batch.ndata[\"wid\"]),\n                \"h\": cuda(torch.zeros(n_nodes, self.hidden_size)),\n                \"new\": cuda(\n                    torch.ones(n_nodes).bool()\n                ),  # whether it's newly generated node\n            }\n        )\n\n        mol_tree_batch.edata.update(\n            {\n                \"s\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"m\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"r\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"z\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"src_x\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"dst_x\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"rm\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"accum_rm\": cuda(torch.zeros(n_edges, self.hidden_size)),\n            }\n        )\n\n        mol_tree_batch.apply_edges(\n            func=lambda edges: {\n                \"src_x\": edges.src[\"x\"],\n                \"dst_x\": edges.dst[\"x\"],\n            },\n        )\n\n        # input tensors for stop prediction (p) and label prediction (q)\n        p_inputs = []\n        p_targets = []\n        q_inputs = []\n        q_targets = []\n\n        # Predict root\n        mol_tree_batch.pull(root_ids, DGLF.copy_e(\"m\", \"m\"), DGLF.sum(\"m\", \"h\"))\n        mol_tree_batch.apply_nodes(dec_tree_node_update, v=root_ids)\n        # Extract hidden states and store them for stop/label prediction\n        h = mol_tree_batch.nodes[root_ids].data[\"h\"]\n        x = mol_tree_batch.nodes[root_ids].data[\"x\"]\n        p_inputs.append(torch.cat([x, h, 
tree_vec], 1))\n        # If the out degree is 0 we don't generate any edges at all\n        root_out_degrees = mol_tree_batch.out_degrees(root_ids)\n        q_inputs.append(torch.cat([h, tree_vec], 1))\n        q_targets.append(mol_tree_batch.nodes[root_ids].data[\"wid\"])\n\n        # Traverse the tree and predict on children\n        for eid, p in dfs_order(mol_tree_batch, root_ids):\n            eid = eid.to(mol_tree_batch.device)\n            p = p.to(mol_tree_batch.device)\n            u, v = mol_tree_batch.find_edges(eid)\n\n            p_target_list = torch.zeros_like(root_out_degrees)\n            p_target_list[root_out_degrees > 0] = 1 - p\n            p_target_list = p_target_list[root_out_degrees >= 0]\n            p_targets.append(torch.tensor(p_target_list))\n\n            root_out_degrees -= (root_out_degrees == 0).long()\n            root_out_degrees -= torch.tensor(\n                np.isin(root_ids, v.cpu().numpy())\n            ).to(root_out_degrees)\n\n            mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)\n            mol_tree_batch_lg.pull(\n                eid, DGLF.copy_u(\"m\", \"m\"), DGLF.sum(\"m\", \"s\")\n            )\n            mol_tree_batch_lg.pull(\n                eid, DGLF.copy_u(\"rm\", \"rm\"), DGLF.sum(\"rm\", \"accum_rm\")\n            )\n            mol_tree_batch_lg.apply_nodes(self.dec_tree_edge_update, v=eid)\n            mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)\n\n            is_new = mol_tree_batch.nodes[v].data[\"new\"]\n            mol_tree_batch.pull(v, DGLF.copy_e(\"m\", \"m\"), DGLF.sum(\"m\", \"h\"))\n            mol_tree_batch.apply_nodes(dec_tree_node_update, v=v)\n\n            # Extract\n            n_repr = mol_tree_batch.nodes[v].data\n            h = n_repr[\"h\"]\n            x = n_repr[\"x\"]\n            tree_vec_set = tree_vec[root_out_degrees >= 0]\n            wid = n_repr[\"wid\"]\n            p_inputs.append(torch.cat([x, h, tree_vec_set], 1))\n            # Only newly 
generated nodes are needed for label prediction\n            # NOTE: The following works since the uncomputed messages are zeros.\n\n            q_input = torch.cat([h, tree_vec_set], 1)[is_new]\n            q_target = wid[is_new]\n            if q_input.shape[0] > 0:\n                q_inputs.append(q_input)\n                q_targets.append(q_target)\n        p_targets.append(\n            torch.zeros(\n                (root_out_degrees == 0).sum(),\n                device=root_out_degrees.device,\n                dtype=torch.int64,\n            )\n        )\n\n        # Batch compute the stop/label prediction losses\n        p_inputs = torch.cat(p_inputs, 0)\n        p_targets = cuda(torch.cat(p_targets, 0))\n        q_inputs = torch.cat(q_inputs, 0)\n        q_targets = torch.cat(q_targets, 0)\n\n        q = self.W_o(torch.relu(self.W(q_inputs)))\n        p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]\n\n        p_loss = (\n            F.binary_cross_entropy_with_logits(\n                p, p_targets.float(), size_average=False\n            )\n            / n_trees\n        )\n        q_loss = F.cross_entropy(q, q_targets, size_average=False) / n_trees\n        p_acc = ((p > 0).long() == p_targets).sum().float() / p_targets.shape[0]\n        q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]\n\n        self.q_inputs = q_inputs\n        self.q_targets = q_targets\n        self.q = q\n        self.p_inputs = p_inputs\n        self.p_targets = p_targets\n        self.p = p\n\n        return q_loss, p_loss, q_acc, p_acc\n\n    def decode(self, mol_vec):\n        assert mol_vec.shape[0] == 1\n\n        mol_tree = DGLMolTree(None)\n        mol_tree.graph = mol_tree.graph.to(mol_vec.device)\n        mol_tree_graph = mol_tree.graph\n\n        init_hidden = cuda(torch.zeros(1, self.hidden_size))\n\n        root_hidden = torch.cat([init_hidden, mol_vec], 1)\n        root_hidden = F.relu(self.W(root_hidden))\n        root_score = 
self.W_o(root_hidden)\n        _, root_wid = torch.max(root_score, 1)\n        root_wid = root_wid.view(1)\n\n        mol_tree_graph.add_nodes(1)  # root\n        mol_tree_graph.ndata[\"wid\"] = root_wid\n        mol_tree_graph.ndata[\"x\"] = self.embedding(root_wid)\n        mol_tree_graph.ndata[\"h\"] = init_hidden\n        mol_tree_graph.ndata[\"fail\"] = cuda(torch.tensor([0]))\n        mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(\n            self.vocab.get_smiles(root_wid)\n        )\n\n        stack, trace = [], []\n        stack.append((0, self.vocab.get_slots(root_wid)))\n\n        all_nodes = {0: root_node_dict}\n        h = {}\n        first = True\n        new_node_id = 0\n        new_edge_id = 0\n\n        for step in range(MAX_DECODE_LEN):\n            u, u_slots = stack[-1]\n            x = mol_tree_graph.ndata[\"x\"][u : u + 1]\n            h = mol_tree_graph.ndata[\"h\"][u : u + 1]\n\n            # Predict stop\n            p_input = torch.cat([x, h, mol_vec], 1)\n            p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))\n            backtrack = p_score.item() < 0.5\n\n            if not backtrack:\n                # Predict next clique.  
Note that the prediction may fail due\n                # to lack of assemblable components\n                mol_tree_graph.add_nodes(1)\n                new_node_id += 1\n                v = new_node_id\n                mol_tree_graph.add_edges(u, v)\n                uv = new_edge_id\n                new_edge_id += 1\n\n                if first:\n                    mol_tree_graph.edata.update(\n                        {\n                            \"s\": cuda(torch.zeros(1, self.hidden_size)),\n                            \"m\": cuda(torch.zeros(1, self.hidden_size)),\n                            \"r\": cuda(torch.zeros(1, self.hidden_size)),\n                            \"z\": cuda(torch.zeros(1, self.hidden_size)),\n                            \"src_x\": cuda(torch.zeros(1, self.hidden_size)),\n                            \"dst_x\": cuda(torch.zeros(1, self.hidden_size)),\n                            \"rm\": cuda(torch.zeros(1, self.hidden_size)),\n                            \"accum_rm\": cuda(torch.zeros(1, self.hidden_size)),\n                        }\n                    )\n                    first = False\n\n                mol_tree_graph.edata[\"src_x\"][uv] = mol_tree_graph.ndata[\"x\"][u]\n                # keeping dst_x 0 is fine as h on new edge doesn't depend on that.\n\n                # DGL doesn't dynamically maintain a line graph.\n                mol_tree_graph_lg = line_graph(\n                    mol_tree_graph, backtracking=False, shared=True\n                )\n\n                mol_tree_graph_lg.pull(\n                    uv, DGLF.copy_u(\"m\", \"m\"), DGLF.sum(\"m\", \"s\")\n                )\n                mol_tree_graph_lg.pull(\n                    uv, DGLF.copy_u(\"rm\", \"rm\"), DGLF.sum(\"rm\", \"accum_rm\")\n                )\n                mol_tree_graph_lg.apply_nodes(\n                    self.dec_tree_edge_update.update_zm, v=uv\n                )\n                mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)\n      
          mol_tree_graph.pull(\n                    v, DGLF.copy_e(\"m\", \"m\"), DGLF.sum(\"m\", \"h\")\n                )\n\n                h_v = mol_tree_graph.ndata[\"h\"][v : v + 1]\n                q_input = torch.cat([h_v, mol_vec], 1)\n                q_score = torch.softmax(\n                    self.W_o(torch.relu(self.W(q_input))), -1\n                )\n                _, sort_wid = torch.sort(q_score, 1, descending=True)\n                sort_wid = sort_wid.squeeze()\n\n                next_wid = None\n                for wid in sort_wid.tolist()[:5]:\n                    slots = self.vocab.get_slots(wid)\n                    cand_node_dict = create_node_dict(\n                        self.vocab.get_smiles(wid)\n                    )\n                    if have_slots(u_slots, slots) and can_assemble(\n                        mol_tree, u, cand_node_dict\n                    ):\n                        next_wid = wid\n                        next_slots = slots\n                        next_node_dict = cand_node_dict\n                        break\n\n                if next_wid is None:\n                    # Failed adding an actual children; v is a spurious node\n                    # and we mark it.\n                    mol_tree_graph.ndata[\"fail\"][v] = cuda(torch.tensor([1]))\n                    backtrack = True\n                else:\n                    next_wid = cuda(torch.tensor([next_wid]))\n                    mol_tree_graph.ndata[\"wid\"][v] = next_wid\n                    mol_tree_graph.ndata[\"x\"][v] = self.embedding(next_wid)\n                    mol_tree.nodes_dict[v] = next_node_dict\n                    all_nodes[v] = next_node_dict\n                    stack.append((v, next_slots))\n                    mol_tree_graph.add_edges(v, u)\n                    vu = new_edge_id\n                    new_edge_id += 1\n                    mol_tree_graph.edata[\"dst_x\"][uv] = mol_tree_graph.ndata[\n                        \"x\"\n              
      ][v]\n                    mol_tree_graph.edata[\"src_x\"][vu] = mol_tree_graph.ndata[\n                        \"x\"\n                    ][v]\n                    mol_tree_graph.edata[\"dst_x\"][vu] = mol_tree_graph.ndata[\n                        \"x\"\n                    ][u]\n\n                    # DGL doesn't dynamically maintain a line graph.\n                    mol_tree_graph_lg = line_graph(\n                        mol_tree_graph, backtracking=False, shared=True\n                    )\n                    mol_tree_graph_lg.apply_nodes(\n                        self.dec_tree_edge_update.update_r, uv\n                    )\n                    mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)\n\n            if backtrack:\n                if len(stack) == 1:\n                    break  # At root, terminate\n\n                pu, _ = stack[-2]\n                u_pu = mol_tree_graph.edge_ids(u, pu)\n\n                mol_tree_graph_lg.pull(\n                    u_pu, DGLF.copy_u(\"m\", \"m\"), DGLF.sum(\"m\", \"s\")\n                )\n                mol_tree_graph_lg.pull(\n                    u_pu, DGLF.copy_u(\"rm\", \"rm\"), DGLF.sum(\"rm\", \"accum_rm\")\n                )\n                mol_tree_graph_lg.apply_nodes(self.dec_tree_edge_update, v=u_pu)\n                mol_tree_graph.edata.update(mol_tree_graph_lg.ndata)\n                mol_tree_graph.pull(\n                    pu, DGLF.copy_e(\"m\", \"m\"), DGLF.sum(\"m\", \"h\")\n                )\n                stack.pop()\n\n        effective_nodes = mol_tree_graph.filter_nodes(\n            lambda nodes: nodes.data[\"fail\"] != 1\n        )\n        effective_nodes, _ = torch.sort(effective_nodes)\n        return mol_tree, all_nodes, effective_nodes\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/jtnn_enc.py",
    "content": "import dgl.function as DGLF\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom dgl import batch, bfs_edges_generator, line_graph\n\nfrom .nnutils import cuda, GRUUpdate, tocpu\n\nMAX_NB = 8\n\n\ndef level_order(forest, roots):\n    forest = tocpu(forest)\n    edges = bfs_edges_generator(forest, roots)\n    if len(edges) == 0:\n        # no edges in the tree; do not perform loopy BP\n        return\n    _, leaves = forest.find_edges(edges[-1])\n    edges_back = bfs_edges_generator(forest, roots, reverse=True)\n    yield from reversed(edges_back)\n    yield from edges\n\n\nclass EncoderGatherUpdate(nn.Module):\n    def __init__(self, hidden_size):\n        nn.Module.__init__(self)\n        self.hidden_size = hidden_size\n\n        self.W = nn.Linear(2 * hidden_size, hidden_size)\n\n    def forward(self, nodes):\n        x = nodes.data[\"x\"]\n        m = nodes.data[\"m\"]\n        return {\n            \"h\": torch.relu(self.W(torch.cat([x, m], 1))),\n        }\n\n\nclass DGLJTNNEncoder(nn.Module):\n    def __init__(self, vocab, hidden_size, embedding=None):\n        nn.Module.__init__(self)\n        self.hidden_size = hidden_size\n        self.vocab_size = vocab.size()\n        self.vocab = vocab\n\n        if embedding is None:\n            self.embedding = nn.Embedding(self.vocab_size, hidden_size)\n        else:\n            self.embedding = embedding\n\n        self.enc_tree_update = GRUUpdate(hidden_size)\n        self.enc_tree_gather_update = EncoderGatherUpdate(hidden_size)\n\n    def forward(self, mol_trees):\n        mol_tree_batch = batch(mol_trees)\n\n        # Build line graph to prepare for belief propagation\n        mol_tree_batch_lg = line_graph(\n            mol_tree_batch, backtracking=False, shared=True\n        )\n\n        return self.run(mol_tree_batch, mol_tree_batch_lg)\n\n    def run(self, mol_tree_batch, mol_tree_batch_lg):\n        # Since tree roots are designated to 0.  
In the batched graph we can\n        # simply find the corresponding node ID by looking at node_offset\n        node_offset = np.cumsum(\n            np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0)\n        )\n        root_ids = node_offset[:-1]\n        n_nodes = mol_tree_batch.num_nodes()\n        n_edges = mol_tree_batch.num_edges()\n\n        # Assign structure embeddings to tree nodes\n        mol_tree_batch.ndata.update(\n            {\n                \"x\": self.embedding(mol_tree_batch.ndata[\"wid\"]),\n                \"m\": cuda(torch.zeros(n_nodes, self.hidden_size)),\n                \"h\": cuda(torch.zeros(n_nodes, self.hidden_size)),\n            }\n        )\n\n        # Initialize the intermediate variables according to Eq (4)-(8).\n        # Also initialize the src_x and dst_x fields.\n        # TODO: context?\n        mol_tree_batch.edata.update(\n            {\n                \"s\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"m\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"r\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"z\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"src_x\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"dst_x\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"rm\": cuda(torch.zeros(n_edges, self.hidden_size)),\n                \"accum_rm\": cuda(torch.zeros(n_edges, self.hidden_size)),\n            }\n        )\n\n        # Send the source/destination node features to edges\n        mol_tree_batch.apply_edges(\n            func=lambda edges: {\n                \"src_x\": edges.src[\"x\"],\n                \"dst_x\": edges.dst[\"x\"],\n            },\n        )\n\n        # Message passing\n        # I exploited the fact that the reduce function is a sum of incoming\n        # messages, and the uncomputed messages are zero vectors.  
Essentially,\n        # we can always compute s_ij as the sum of incoming m_ij, no matter\n        # if m_ij is actually computed or not.\n        mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)\n        for eid in level_order(mol_tree_batch, root_ids):\n            eid = eid.to(mol_tree_batch_lg.device)\n            mol_tree_batch_lg.pull(\n                eid, DGLF.copy_u(\"m\", \"m\"), DGLF.sum(\"m\", \"s\")\n            )\n            mol_tree_batch_lg.pull(\n                eid, DGLF.copy_u(\"rm\", \"rm\"), DGLF.sum(\"rm\", \"accum_rm\")\n            )\n            mol_tree_batch_lg.apply_nodes(self.enc_tree_update, v=eid)\n\n        # Readout\n        mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)\n        mol_tree_batch.update_all(DGLF.copy_e(\"m\", \"m\"), DGLF.sum(\"m\", \"m\"))\n        mol_tree_batch.apply_nodes(self.enc_tree_gather_update)\n\n        root_vecs = mol_tree_batch.nodes[root_ids].data[\"h\"]\n\n        return mol_tree_batch, root_vecs\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/jtnn_vae.py",
    "content": "import copy\n\nimport rdkit.Chem as Chem\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl import batch, unbatch\n\nfrom .chemutils import (\n    attach_mols_nx,\n    copy_edit_mol,\n    decode_stereo,\n    enum_assemble_nx,\n    set_atommap,\n)\nfrom .jtmpn import DGLJTMPN, mol2dgl_single as mol2dgl_dec\nfrom .jtnn_dec import DGLJTNNDecoder\nfrom .jtnn_enc import DGLJTNNEncoder\nfrom .mpn import DGLMPN, mol2dgl_single as mol2dgl_enc\nfrom .nnutils import cuda\n\n\nclass DGLJTNNVAE(nn.Module):\n    def __init__(self, vocab, hidden_size, latent_size, depth):\n        super(DGLJTNNVAE, self).__init__()\n        self.vocab = vocab\n        self.hidden_size = hidden_size\n        self.latent_size = latent_size\n        self.depth = depth\n\n        self.embedding = nn.Embedding(vocab.size(), hidden_size)\n        self.mpn = DGLMPN(hidden_size, depth)\n        self.jtnn = DGLJTNNEncoder(vocab, hidden_size, self.embedding)\n        self.decoder = DGLJTNNDecoder(\n            vocab, hidden_size, latent_size // 2, self.embedding\n        )\n        self.jtmpn = DGLJTMPN(hidden_size, depth)\n\n        self.T_mean = nn.Linear(hidden_size, latent_size // 2)\n        self.T_var = nn.Linear(hidden_size, latent_size // 2)\n        self.G_mean = nn.Linear(hidden_size, latent_size // 2)\n        self.G_var = nn.Linear(hidden_size, latent_size // 2)\n\n        self.n_nodes_total = 0\n        self.n_passes = 0\n        self.n_edges_total = 0\n        self.n_tree_nodes_total = 0\n\n    @staticmethod\n    def move_to_cuda(mol_batch):\n        for i in range(len(mol_batch[\"mol_trees\"])):\n            mol_batch[\"mol_trees\"][i].graph = cuda(\n                mol_batch[\"mol_trees\"][i].graph\n            )\n\n        mol_batch[\"mol_graph_batch\"] = cuda(mol_batch[\"mol_graph_batch\"])\n        if \"cand_graph_batch\" in mol_batch:\n            mol_batch[\"cand_graph_batch\"] = cuda(mol_batch[\"cand_graph_batch\"])\n        if 
mol_batch.get(\"stereo_cand_graph_batch\") is not None:\n            mol_batch[\"stereo_cand_graph_batch\"] = cuda(\n                mol_batch[\"stereo_cand_graph_batch\"]\n            )\n\n    def encode(self, mol_batch):\n        mol_graphs = mol_batch[\"mol_graph_batch\"]\n        mol_vec = self.mpn(mol_graphs)\n\n        mol_tree_batch, tree_vec = self.jtnn(\n            [t.graph for t in mol_batch[\"mol_trees\"]]\n        )\n\n        self.n_nodes_total += mol_graphs.num_nodes()\n        self.n_edges_total += mol_graphs.num_edges()\n        self.n_tree_nodes_total += sum(\n            t.graph.num_nodes() for t in mol_batch[\"mol_trees\"]\n        )\n        self.n_passes += 1\n\n        return mol_tree_batch, tree_vec, mol_vec\n\n    def sample(self, tree_vec, mol_vec, e1=None, e2=None):\n        tree_mean = self.T_mean(tree_vec)\n        tree_log_var = -torch.abs(self.T_var(tree_vec))\n        mol_mean = self.G_mean(mol_vec)\n        mol_log_var = -torch.abs(self.G_var(mol_vec))\n\n        epsilon = cuda(torch.randn(*tree_mean.shape)) if e1 is None else e1\n        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon\n        epsilon = cuda(torch.randn(*mol_mean.shape)) if e2 is None else e2\n        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon\n\n        z_mean = torch.cat([tree_mean, mol_mean], 1)\n        z_log_var = torch.cat([tree_log_var, mol_log_var], 1)\n\n        return tree_vec, mol_vec, z_mean, z_log_var\n\n    def forward(self, mol_batch, beta=0, e1=None, e2=None):\n        self.move_to_cuda(mol_batch)\n\n        mol_trees = mol_batch[\"mol_trees\"]\n        batch_size = len(mol_trees)\n\n        mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)\n\n        tree_vec, mol_vec, z_mean, z_log_var = self.sample(\n            tree_vec, mol_vec, e1, e2\n        )\n        kl_loss = (\n            -0.5\n            * torch.sum(\n                1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)\n            )\n          
  / batch_size\n        )\n\n        word_loss, topo_loss, word_acc, topo_acc = self.decoder(\n            [t.graph for t in mol_trees], tree_vec\n        )\n        assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)\n        stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)\n\n        loss = (\n            word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss\n        )\n\n        return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc\n\n    def assm(self, mol_batch, mol_tree_batch, mol_vec):\n        cands = [\n            mol_batch[\"cand_graph_batch\"],\n            cuda(mol_batch[\"tree_mess_src_e\"]),\n            cuda(mol_batch[\"tree_mess_tgt_e\"]),\n            cuda(mol_batch[\"tree_mess_tgt_n\"]),\n        ]\n        cand_vec = self.jtmpn(cands, mol_tree_batch)\n        cand_vec = self.G_mean(cand_vec)\n\n        batch_idx = cuda(torch.LongTensor(mol_batch[\"cand_batch_idx\"]))\n        mol_vec = mol_vec[batch_idx]\n\n        mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)\n        cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)\n        scores = (mol_vec @ cand_vec)[:, 0, 0]\n\n        cnt, tot, acc = 0, 0, 0\n        all_loss = []\n        for i, mol_tree in enumerate(mol_batch[\"mol_trees\"]):\n            comp_nodes = [\n                node_id\n                for node_id, node in mol_tree.nodes_dict.items()\n                if len(node[\"cands\"]) > 1 and not node[\"is_leaf\"]\n            ]\n            cnt += len(comp_nodes)\n            # segmented accuracy and cross entropy\n            for node_id in comp_nodes:\n                node = mol_tree.nodes_dict[node_id]\n                label = node[\"cands\"].index(node[\"label\"])\n                ncand = len(node[\"cands\"])\n                cur_score = scores[tot : tot + ncand]\n                tot += ncand\n\n                if cur_score[label].item() >= cur_score.max().item():\n                    acc += 1\n\n                label = 
cuda(torch.LongTensor([label]))\n                all_loss.append(\n                    F.cross_entropy(\n                        cur_score.view(1, -1), label, size_average=False\n                    )\n                )\n\n        all_loss = sum(all_loss) / len(mol_batch[\"mol_trees\"])\n        return all_loss, acc / cnt\n\n    def stereo(self, mol_batch, mol_vec):\n        stereo_cands = mol_batch[\"stereo_cand_graph_batch\"]\n        batch_idx = mol_batch[\"stereo_cand_batch_idx\"]\n        labels = mol_batch[\"stereo_cand_labels\"]\n        lengths = mol_batch[\"stereo_cand_lengths\"]\n\n        if len(labels) == 0:\n            # Only one stereoisomer exists; do nothing\n            return cuda(torch.tensor(0.0)), 1.0\n\n        batch_idx = cuda(torch.LongTensor(batch_idx))\n        stereo_cands = self.mpn(stereo_cands)\n        stereo_cands = self.G_mean(stereo_cands)\n        stereo_labels = mol_vec[batch_idx]\n        scores = F.cosine_similarity(stereo_cands, stereo_labels)\n\n        st, acc = 0, 0\n        all_loss = []\n        for label, le in zip(labels, lengths):\n            cur_scores = scores[st : st + le]\n            if cur_scores.data[label].item() >= cur_scores.max().item():\n                acc += 1\n            label = cuda(torch.LongTensor([label]))\n            all_loss.append(\n                F.cross_entropy(\n                    cur_scores.view(1, -1), label, size_average=False\n                )\n            )\n            st += le\n\n        all_loss = sum(all_loss) / len(labels)\n        return all_loss, acc / len(labels)\n\n    def decode(self, tree_vec, mol_vec):\n        mol_tree, nodes_dict, effective_nodes = self.decoder.decode(tree_vec)\n        effective_nodes_list = effective_nodes.tolist()\n        nodes_dict = [nodes_dict[v] for v in effective_nodes_list]\n\n        for i, (node_id, node) in enumerate(\n            zip(effective_nodes_list, nodes_dict)\n        ):\n            node[\"idx\"] = i\n            node[\"nid\"] = 
i + 1\n            node[\"is_leaf\"] = True\n            if mol_tree.graph.in_degrees(node_id) > 1:\n                node[\"is_leaf\"] = False\n                set_atommap(node[\"mol\"], node[\"nid\"])\n\n        mol_tree_sg = mol_tree.graph.subgraph(\n            effective_nodes.to(tree_vec.device)\n        )\n        mol_tree_msg, _ = self.jtnn([mol_tree_sg])\n        mol_tree_msg = unbatch(mol_tree_msg)[0]\n        mol_tree_msg.nodes_dict = nodes_dict\n\n        cur_mol = copy_edit_mol(nodes_dict[0][\"mol\"])\n        global_amap = [{}] + [{} for node in nodes_dict]\n        global_amap[1] = {\n            atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()\n        }\n\n        cur_mol = self.dfs_assemble(\n            mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None\n        )\n        if cur_mol is None:\n            return None\n\n        cur_mol = cur_mol.GetMol()\n        set_atommap(cur_mol)\n        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))\n        if cur_mol is None:\n            return None\n\n        smiles2D = Chem.MolToSmiles(cur_mol)\n        stereo_cands = decode_stereo(smiles2D)\n        if len(stereo_cands) == 1:\n            return stereo_cands[0]\n        stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]\n        stereo_cand_graphs, atom_x, bond_x = zip(*stereo_graphs)\n        stereo_cand_graphs = cuda(batch(stereo_cand_graphs))\n        atom_x = cuda(torch.cat(atom_x))\n        bond_x = cuda(torch.cat(bond_x))\n        stereo_cand_graphs.ndata[\"x\"] = atom_x\n        stereo_cand_graphs.edata[\"x\"] = bond_x\n        stereo_cand_graphs.edata[\"src_x\"] = atom_x.new(\n            bond_x.shape[0], atom_x.shape[1]\n        ).zero_()\n        stereo_vecs = self.mpn(stereo_cand_graphs)\n        stereo_vecs = self.G_mean(stereo_vecs)\n        scores = F.cosine_similarity(stereo_vecs, mol_vec)\n        _, max_id = scores.max(0)\n        return stereo_cands[max_id.item()]\n\n    def dfs_assemble(\n        self,\n  
      mol_tree_msg,\n        mol_vec,\n        cur_mol,\n        global_amap,\n        fa_amap,\n        cur_node_id,\n        fa_node_id,\n    ):\n        nodes_dict = mol_tree_msg.nodes_dict\n        fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None\n        cur_node = nodes_dict[cur_node_id]\n\n        fa_nid = fa_node[\"nid\"] if fa_node is not None else -1\n        prev_nodes = [fa_node] if fa_node is not None else []\n\n        children_node_id = [\n            v\n            for v in mol_tree_msg.successors(cur_node_id).tolist()\n            if nodes_dict[v][\"nid\"] != fa_nid\n        ]\n        children = [nodes_dict[v] for v in children_node_id]\n        neighbors = [nei for nei in children if nei[\"mol\"].GetNumAtoms() > 1]\n        neighbors = sorted(\n            neighbors, key=lambda x: x[\"mol\"].GetNumAtoms(), reverse=True\n        )\n        singletons = [nei for nei in children if nei[\"mol\"].GetNumAtoms() == 1]\n        neighbors = singletons + neighbors\n\n        cur_amap = [\n            (fa_nid, a2, a1)\n            for nid, a1, a2 in fa_amap\n            if nid == cur_node[\"nid\"]\n        ]\n        cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)\n        if len(cands) == 0:\n            return None\n        cand_smiles, cand_mols, cand_amap = list(zip(*cands))\n\n        cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]\n        (\n            cand_graphs,\n            atom_x,\n            bond_x,\n            tree_mess_src_edges,\n            tree_mess_tgt_edges,\n            tree_mess_tgt_nodes,\n        ) = mol2dgl_dec(cands)\n        cand_graphs = batch([g.to(mol_vec.device) for g in cand_graphs])\n        atom_x = cuda(atom_x)\n        bond_x = cuda(bond_x)\n        cand_graphs.ndata[\"x\"] = atom_x\n        cand_graphs.edata[\"x\"] = bond_x\n        cand_graphs.edata[\"src_x\"] = atom_x.new(\n            bond_x.shape[0], atom_x.shape[1]\n        ).zero_()\n\n        
cand_vecs = self.jtmpn(\n            (\n                cand_graphs,\n                tree_mess_src_edges,\n                tree_mess_tgt_edges,\n                tree_mess_tgt_nodes,\n            ),\n            mol_tree_msg,\n        )\n        cand_vecs = self.G_mean(cand_vecs)\n        mol_vec = mol_vec.squeeze()\n        scores = cand_vecs @ mol_vec\n\n        _, cand_idx = torch.sort(scores, descending=True)\n\n        backup_mol = Chem.RWMol(cur_mol)\n        for i in range(len(cand_idx)):\n            cur_mol = Chem.RWMol(backup_mol)\n            pred_amap = cand_amap[cand_idx[i].item()]\n            new_global_amap = copy.deepcopy(global_amap)\n\n            for nei_id, ctr_atom, nei_atom in pred_amap:\n                if nei_id == fa_nid:\n                    continue\n                new_global_amap[nei_id][nei_atom] = new_global_amap[\n                    cur_node[\"nid\"]\n                ][ctr_atom]\n\n            cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)\n            new_mol = cur_mol.GetMol()\n            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))\n\n            if new_mol is None:\n                continue\n\n            result = True\n            for nei_node_id, nei_node in zip(children_node_id, children):\n                if nei_node[\"is_leaf\"]:\n                    continue\n                cur_mol = self.dfs_assemble(\n                    mol_tree_msg,\n                    mol_vec,\n                    cur_mol,\n                    new_global_amap,\n                    pred_amap,\n                    nei_node_id,\n                    cur_node_id,\n                )\n                if cur_mol is None:\n                    result = False\n                    break\n\n            if result:\n                return cur_mol\n\n        return None\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/line_profiler_integration.py",
    "content": "\"\"\"\nline_profiler integration\n\"\"\"\nimport os\n\nif os.getenv(\"PROFILE\", 0):\n    import atexit\n\n    import line_profiler\n\n    profile = line_profiler.LineProfiler()\n\n    profile_output = os.getenv(\"PROFILE_OUTPUT\", None)\n    if profile_output:\n        from functools import partial\n\n        atexit.register(partial(profile.dump_stats, profile_output))\n    else:\n        atexit.register(profile.print_stats)\nelse:\n\n    def profile(f):\n        return f\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/mol_tree.py",
    "content": "import copy\n\nimport rdkit.Chem as Chem\n\n\ndef get_slots(smiles):\n    mol = Chem.MolFromSmiles(smiles)\n    return [\n        (atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs())\n        for atom in mol.GetAtoms()\n    ]\n\n\nclass Vocab(object):\n    def __init__(self, smiles_list):\n        self.vocab = smiles_list\n        self.vmap = {x: i for i, x in enumerate(self.vocab)}\n        self.slots = [get_slots(smiles) for smiles in self.vocab]\n\n    def get_index(self, smiles):\n        return self.vmap[smiles]\n\n    def get_smiles(self, idx):\n        return self.vocab[idx]\n\n    def get_slots(self, idx):\n        return copy.deepcopy(self.slots[idx])\n\n    def size(self):\n        return len(self.vocab)\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/mol_tree_nx.py",
    "content": "import dgl\nimport numpy as np\nimport rdkit.Chem as Chem\n\nfrom .chemutils import (\n    decode_stereo,\n    enum_assemble_nx,\n    get_clique_mol,\n    get_mol,\n    get_smiles,\n    set_atommap,\n    tree_decomp,\n)\n\n\nclass DGLMolTree(object):\n    def __init__(self, smiles):\n        self.nodes_dict = {}\n\n        if smiles is None:\n            self.graph = dgl.graph(([], []))\n            return\n\n        self.smiles = smiles\n        self.mol = get_mol(smiles)\n\n        # Stereo Generation\n        mol = Chem.MolFromSmiles(smiles)\n        self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)\n        self.smiles2D = Chem.MolToSmiles(mol)\n        self.stereo_cands = decode_stereo(self.smiles2D)\n\n        # cliques: a list of list of atom indices\n        cliques, edges = tree_decomp(self.mol)\n        root = 0\n        for i, c in enumerate(cliques):\n            cmol = get_clique_mol(self.mol, c)\n            csmiles = get_smiles(cmol)\n            self.nodes_dict[i] = dict(\n                smiles=csmiles,\n                mol=get_mol(csmiles),\n                clique=c,\n            )\n            if min(c) == 0:\n                root = i\n\n        # The clique with atom ID 0 becomes root\n        if root > 0:\n            for attr in self.nodes_dict[0]:\n                self.nodes_dict[0][attr], self.nodes_dict[root][attr] = (\n                    self.nodes_dict[root][attr],\n                    self.nodes_dict[0][attr],\n                )\n\n        src = np.zeros((len(edges) * 2,), dtype=\"int\")\n        dst = np.zeros((len(edges) * 2,), dtype=\"int\")\n        for i, (_x, _y) in enumerate(edges):\n            x = 0 if _x == root else root if _x == 0 else _x\n            y = 0 if _y == root else root if _y == 0 else _y\n            src[2 * i] = x\n            dst[2 * i] = y\n            src[2 * i + 1] = y\n            dst[2 * i + 1] = x\n        self.graph = dgl.graph((src, dst), num_nodes=len(cliques))\n\n        for i 
in self.nodes_dict:\n            self.nodes_dict[i][\"nid\"] = i + 1\n            if self.graph.out_degrees(i) > 1:  # Leaf node mol is not marked\n                set_atommap(\n                    self.nodes_dict[i][\"mol\"], self.nodes_dict[i][\"nid\"]\n                )\n            self.nodes_dict[i][\"is_leaf\"] = self.graph.out_degrees(i) == 1\n\n    def treesize(self):\n        return self.graph.num_nodes()\n\n    def _recover_node(self, i, original_mol):\n        node = self.nodes_dict[i]\n\n        clique = []\n        clique.extend(node[\"clique\"])\n        if not node[\"is_leaf\"]:\n            for cidx in node[\"clique\"]:\n                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(node[\"nid\"])\n\n        for j in self.graph.successors(i).numpy():\n            nei_node = self.nodes_dict[j]\n            clique.extend(nei_node[\"clique\"])\n            if nei_node[\"is_leaf\"]:  # Leaf node, no need to mark\n                continue\n            for cidx in nei_node[\"clique\"]:\n                # allow singleton node override the atom mapping\n                if cidx not in node[\"clique\"] or len(nei_node[\"clique\"]) == 1:\n                    atom = original_mol.GetAtomWithIdx(cidx)\n                    atom.SetAtomMapNum(nei_node[\"nid\"])\n\n        clique = list(set(clique))\n        label_mol = get_clique_mol(original_mol, clique)\n        node[\"label\"] = Chem.MolToSmiles(\n            Chem.MolFromSmiles(get_smiles(label_mol))\n        )\n        node[\"label_mol\"] = get_mol(node[\"label\"])\n\n        for cidx in clique:\n            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)\n\n        return node[\"label\"]\n\n    def _assemble_node(self, i):\n        neighbors = [\n            self.nodes_dict[j]\n            for j in self.graph.successors(i).numpy()\n            if self.nodes_dict[j][\"mol\"].GetNumAtoms() > 1\n        ]\n        neighbors = sorted(\n            neighbors, key=lambda x: x[\"mol\"].GetNumAtoms(), 
reverse=True\n        )\n        singletons = [\n            self.nodes_dict[j]\n            for j in self.graph.successors(i).numpy()\n            if self.nodes_dict[j][\"mol\"].GetNumAtoms() == 1\n        ]\n        neighbors = singletons + neighbors\n\n        cands = enum_assemble_nx(self.nodes_dict[i], neighbors)\n\n        if len(cands) > 0:\n            (\n                self.nodes_dict[i][\"cands\"],\n                self.nodes_dict[i][\"cand_mols\"],\n                _,\n            ) = list(zip(*cands))\n            self.nodes_dict[i][\"cands\"] = list(self.nodes_dict[i][\"cands\"])\n            self.nodes_dict[i][\"cand_mols\"] = list(\n                self.nodes_dict[i][\"cand_mols\"]\n            )\n        else:\n            self.nodes_dict[i][\"cands\"] = []\n            self.nodes_dict[i][\"cand_mols\"] = []\n\n    def recover(self):\n        for i in self.nodes_dict:\n            self._recover_node(i, self.mol)\n\n    def assemble(self):\n        for i in self.nodes_dict:\n            self._assemble_node(i)\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/mpn.py",
    "content": "import dgl\nimport dgl.function as DGLF\nimport rdkit.Chem as Chem\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import line_graph, mean_nodes\n\nfrom .chemutils import get_mol\n\nELEM_LIST = [\n    \"C\",\n    \"N\",\n    \"O\",\n    \"S\",\n    \"F\",\n    \"Si\",\n    \"P\",\n    \"Cl\",\n    \"Br\",\n    \"Mg\",\n    \"Na\",\n    \"Ca\",\n    \"Fe\",\n    \"Al\",\n    \"I\",\n    \"B\",\n    \"K\",\n    \"Se\",\n    \"Zn\",\n    \"H\",\n    \"Cu\",\n    \"Mn\",\n    \"unknown\",\n]\n\nATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1\nBOND_FDIM = 5 + 6\nMAX_NB = 6\n\n\ndef onek_encoding_unk(x, allowable_set):\n    if x not in allowable_set:\n        x = allowable_set[-1]\n    return [x == s for s in allowable_set]\n\n\ndef atom_features(atom):\n    return torch.Tensor(\n        onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)\n        + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])\n        + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])\n        + onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])\n        + [atom.GetIsAromatic()]\n    )\n\n\ndef bond_features(bond):\n    bt = bond.GetBondType()\n    stereo = int(bond.GetStereo())\n    fbond = [\n        bt == Chem.rdchem.BondType.SINGLE,\n        bt == Chem.rdchem.BondType.DOUBLE,\n        bt == Chem.rdchem.BondType.TRIPLE,\n        bt == Chem.rdchem.BondType.AROMATIC,\n        bond.IsInRing(),\n    ]\n    fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])\n    return torch.Tensor(fbond + fstereo)\n\n\ndef mol2dgl_single(smiles):\n    n_edges = 0\n\n    atom_x = []\n    bond_x = []\n\n    mol = get_mol(smiles)\n    n_atoms = mol.GetNumAtoms()\n    n_bonds = mol.GetNumBonds()\n    for i, atom in enumerate(mol.GetAtoms()):\n        assert i == atom.GetIdx()\n        atom_x.append(atom_features(atom))\n\n    bond_src = []\n    bond_dst = []\n    for i, bond in enumerate(mol.GetBonds()):\n        begin_idx = 
bond.GetBeginAtom().GetIdx()\n        end_idx = bond.GetEndAtom().GetIdx()\n        features = bond_features(bond)\n        bond_src.append(begin_idx)\n        bond_dst.append(end_idx)\n        bond_x.append(features)\n        # set up the reverse direction\n        bond_src.append(end_idx)\n        bond_dst.append(begin_idx)\n        bond_x.append(features)\n    graph = dgl.graph((bond_src, bond_dst), num_nodes=n_atoms)\n    n_edges += n_bonds\n    return (\n        graph,\n        torch.stack(atom_x),\n        torch.stack(bond_x) if len(bond_x) > 0 else torch.zeros(0),\n    )\n\n\nclass LoopyBPUpdate(nn.Module):\n    def __init__(self, hidden_size):\n        super(LoopyBPUpdate, self).__init__()\n        self.hidden_size = hidden_size\n\n        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)\n\n    def forward(self, nodes):\n        msg_input = nodes.data[\"msg_input\"]\n        msg_delta = self.W_h(nodes.data[\"accum_msg\"])\n        msg = F.relu(msg_input + msg_delta)\n        return {\"msg\": msg}\n\n\nclass GatherUpdate(nn.Module):\n    def __init__(self, hidden_size):\n        super(GatherUpdate, self).__init__()\n        self.hidden_size = hidden_size\n\n        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)\n\n    def forward(self, nodes):\n        m = nodes.data[\"m\"]\n        return {\n            \"h\": F.relu(self.W_o(torch.cat([nodes.data[\"x\"], m], 1))),\n        }\n\n\nclass DGLMPN(nn.Module):\n    def __init__(self, hidden_size, depth):\n        super(DGLMPN, self).__init__()\n\n        self.depth = depth\n\n        self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)\n\n        self.loopy_bp_updater = LoopyBPUpdate(hidden_size)\n        self.gather_updater = GatherUpdate(hidden_size)\n        self.hidden_size = hidden_size\n\n        self.n_samples_total = 0\n        self.n_nodes_total = 0\n        self.n_edges_total = 0\n        self.n_passes = 0\n\n    def forward(self, mol_graph):\n        n_samples = 
mol_graph.batch_size\n\n        mol_line_graph = line_graph(mol_graph, backtracking=False, shared=True)\n\n        n_nodes = mol_graph.num_nodes()\n        n_edges = mol_graph.num_edges()\n\n        mol_graph = self.run(mol_graph, mol_line_graph)\n\n        # TODO: replace with unbatch or readout\n        g_repr = mean_nodes(mol_graph, \"h\")\n\n        self.n_samples_total += n_samples\n        self.n_nodes_total += n_nodes\n        self.n_edges_total += n_edges\n        self.n_passes += 1\n\n        return g_repr\n\n    def run(self, mol_graph, mol_line_graph):\n        n_nodes = mol_graph.num_nodes()\n\n        mol_graph.apply_edges(\n            func=lambda edges: {\"src_x\": edges.src[\"x\"]},\n        )\n        mol_line_graph.ndata.update(mol_graph.edata)\n\n        e_repr = mol_line_graph.ndata\n        bond_features = e_repr[\"x\"]\n        source_features = e_repr[\"src_x\"]\n\n        features = torch.cat([source_features, bond_features], 1)\n        msg_input = self.W_i(features)\n        mol_line_graph.ndata.update(\n            {\n                \"msg_input\": msg_input,\n                \"msg\": F.relu(msg_input),\n                \"accum_msg\": torch.zeros_like(msg_input),\n            }\n        )\n        mol_graph.ndata.update(\n            {\n                \"m\": bond_features.new(n_nodes, self.hidden_size).zero_(),\n                \"h\": bond_features.new(n_nodes, self.hidden_size).zero_(),\n            }\n        )\n\n        for i in range(self.depth - 1):\n            mol_line_graph.update_all(\n                DGLF.copy_u(\"msg\", \"msg\"), DGLF.sum(\"msg\", \"accum_msg\")\n            )\n            mol_line_graph.apply_nodes(self.loopy_bp_updater)\n\n        mol_graph.edata.update(mol_line_graph.ndata)\n        mol_graph.update_all(DGLF.copy_e(\"msg\", \"msg\"), DGLF.sum(\"msg\", \"m\"))\n        mol_graph.apply_nodes(self.gather_updater)\n\n        return mol_graph\n"
  },
  {
    "path": "examples/pytorch/jtnn/jtnn/nnutils.py",
    "content": "import os\n\nimport dgl\n\nimport torch\nimport torch.nn as nn\n\n\ndef cuda(x):\n    if torch.cuda.is_available() and not os.getenv(\"NOCUDA\", None):\n        return x.to(torch.device(\"cuda\"))  # works for both DGLGraph and tensor\n    else:\n        return x\n\n\nclass GRUUpdate(nn.Module):\n    def __init__(self, hidden_size):\n        nn.Module.__init__(self)\n        self.hidden_size = hidden_size\n\n        self.W_z = nn.Linear(2 * hidden_size, hidden_size)\n        self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)\n        self.U_r = nn.Linear(hidden_size, hidden_size)\n        self.W_h = nn.Linear(2 * hidden_size, hidden_size)\n\n    def update_zm(self, node):\n        src_x = node.data[\"src_x\"]\n        s = node.data[\"s\"]\n        rm = node.data[\"accum_rm\"]\n        z = torch.sigmoid(self.W_z(torch.cat([src_x, s], 1)))\n        m = torch.tanh(self.W_h(torch.cat([src_x, rm], 1)))\n        m = (1 - z) * s + z * m\n        return {\"m\": m, \"z\": z}\n\n    def update_r(self, node, zm=None):\n        dst_x = node.data[\"dst_x\"]\n        m = node.data[\"m\"] if zm is None else zm[\"m\"]\n        r_1 = self.W_r(dst_x)\n        r_2 = self.U_r(m)\n        r = torch.sigmoid(r_1 + r_2)\n        return {\"r\": r, \"rm\": r * m}\n\n    def forward(self, node):\n        dic = self.update_zm(node)\n        dic.update(self.update_r(node, zm=dic))\n        return dic\n\n\ndef tocpu(g):\n    src, dst = g.edges()\n    src = src.cpu()\n    dst = dst.cpu()\n    return dgl.graph((src, dst), num_nodes=g.num_nodes())\n"
  },
  {
    "path": "examples/pytorch/jtnn/vaetrain_dgl.py",
    "content": "import math\nimport random\nimport sys\nfrom collections import deque\nfrom optparse import OptionParser\n\nimport rdkit\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.optim.lr_scheduler as lr_scheduler\nimport tqdm\nfrom jtnn import *\nfrom torch.utils.data import DataLoader\n\ntorch.multiprocessing.set_sharing_strategy(\"file_system\")\n\n\ndef worker_init_fn(id_):\n    lg = rdkit.RDLogger.logger()\n    lg.setLevel(rdkit.RDLogger.CRITICAL)\n\n\nworker_init_fn(None)\n\nparser = OptionParser()\nparser.add_option(\n    \"-t\", \"--train\", dest=\"train\", default=\"train\", help=\"Training file name\"\n)\nparser.add_option(\n    \"-v\", \"--vocab\", dest=\"vocab\", default=\"vocab\", help=\"Vocab file name\"\n)\nparser.add_option(\"-s\", \"--save_dir\", dest=\"save_path\")\nparser.add_option(\"-m\", \"--model\", dest=\"model_path\", default=None)\nparser.add_option(\"-b\", \"--batch\", dest=\"batch_size\", default=40)\nparser.add_option(\"-w\", \"--hidden\", dest=\"hidden_size\", default=200)\nparser.add_option(\"-l\", \"--latent\", dest=\"latent_size\", default=56)\nparser.add_option(\"-d\", \"--depth\", dest=\"depth\", default=3)\nparser.add_option(\"-z\", \"--beta\", dest=\"beta\", default=1.0)\nparser.add_option(\"-q\", \"--lr\", dest=\"lr\", default=1e-3)\nparser.add_option(\"-T\", \"--test\", dest=\"test\", action=\"store_true\")\nopts, args = parser.parse_args()\n\ndataset = JTNNDataset(data=opts.train, vocab=opts.vocab, training=True)\nvocab = dataset.vocab\n\nbatch_size = int(opts.batch_size)\nhidden_size = int(opts.hidden_size)\nlatent_size = int(opts.latent_size)\ndepth = int(opts.depth)\nbeta = float(opts.beta)\nlr = float(opts.lr)\n\nmodel = DGLJTNNVAE(vocab, hidden_size, latent_size, depth)\n\nif opts.model_path is not None:\n    model.load_state_dict(torch.load(opts.model_path, weights_only=False))\nelse:\n    for param in model.parameters():\n        if param.dim() == 1:\n            
nn.init.constant(param, 0)\n        else:\n            nn.init.xavier_normal(param)\n\nmodel = cuda(model)\nprint(\n    \"Model #Params: %dK\"\n    % (sum([x.nelement() for x in model.parameters()]) / 1000,)\n)\n\noptimizer = optim.Adam(model.parameters(), lr=lr)\nscheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)\nscheduler.step()\n\nMAX_EPOCH = 100\nPRINT_ITER = 20\n\n\ndef train():\n    dataset.training = True\n    dataloader = DataLoader(\n        dataset,\n        batch_size=batch_size,\n        shuffle=True,\n        num_workers=4,\n        collate_fn=JTNNCollator(vocab, True),\n        drop_last=True,\n        worker_init_fn=worker_init_fn,\n    )\n\n    for epoch in range(MAX_EPOCH):\n        word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0\n\n        for it, batch in enumerate(tqdm.tqdm(dataloader)):\n            model.zero_grad()\n            try:\n                loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)\n            except:\n                print([t.smiles for t in batch[\"mol_trees\"]])\n                raise\n            loss.backward()\n            optimizer.step()\n\n            word_acc += wacc\n            topo_acc += tacc\n            assm_acc += sacc\n            steo_acc += dacc\n\n            if (it + 1) % PRINT_ITER == 0:\n                word_acc = word_acc / PRINT_ITER * 100\n                topo_acc = topo_acc / PRINT_ITER * 100\n                assm_acc = assm_acc / PRINT_ITER * 100\n                steo_acc = steo_acc / PRINT_ITER * 100\n\n                print(\n                    \"KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f\"\n                    % (\n                        kl_div,\n                        word_acc,\n                        topo_acc,\n                        assm_acc,\n                        steo_acc,\n                        loss.item(),\n                    )\n                )\n                word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0\n                
sys.stdout.flush()\n\n            if (it + 1) % 1500 == 0:  # Fast annealing\n                scheduler.step()\n                print(\"learning rate: %.6f\" % scheduler.get_lr()[0])\n                torch.save(\n                    model.state_dict(),\n                    opts.save_path + \"/model.iter-%d-%d\" % (epoch, it + 1),\n                )\n\n        scheduler.step()\n        print(\"learning rate: %.6f\" % scheduler.get_lr()[0])\n        torch.save(\n            model.state_dict(), opts.save_path + \"/model.iter-\" + str(epoch)\n        )\n\n\ndef test():\n    dataset.training = False\n    dataloader = DataLoader(\n        dataset,\n        batch_size=1,\n        shuffle=False,\n        num_workers=0,\n        collate_fn=JTNNCollator(vocab, False),\n        drop_last=True,\n        worker_init_fn=worker_init_fn,\n    )\n\n    # Just an example of molecule decoding; in reality you may want to sample\n    # tree and molecule vectors.\n    for it, batch in enumerate(dataloader):\n        gt_smiles = batch[\"mol_trees\"][0].smiles\n        print(gt_smiles)\n        model.move_to_cuda(batch)\n        _, tree_vec, mol_vec = model.encode(batch)\n        tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)\n        smiles = model.decode(tree_vec, mol_vec)\n        print(smiles)\n\n\nif __name__ == \"__main__\":\n    if opts.test:\n        test()\n    else:\n        train()\n\n    print(\"# passes:\", model.n_passes)\n    print(\"Total # nodes processed:\", model.n_nodes_total)\n    print(\"Total # edges processed:\", model.n_edges_total)\n    print(\"Total # tree nodes processed:\", model.n_tree_nodes_total)\n    print(\"Graph decoder: # passes:\", model.jtmpn.n_passes)\n    print(\n        \"Graph decoder: Total # candidates processed:\",\n        model.jtmpn.n_samples_total,\n    )\n    print(\"Graph decoder: Total # nodes processed:\", model.jtmpn.n_nodes_total)\n    print(\"Graph decoder: Total # edges processed:\", model.jtmpn.n_edges_total)\n    
print(\"Graph encoder: # passes:\", model.mpn.n_passes)\n    print(\n        \"Graph encoder: Total # candidates processed:\",\n        model.mpn.n_samples_total,\n    )\n    print(\"Graph encoder: Total # nodes processed:\", model.mpn.n_nodes_total)\n    print(\"Graph encoder: Total # edges processed:\", model.mpn.n_edges_total)\n"
  },
  {
    "path": "examples/pytorch/label_propagation/README.md",
    "content": "# DGL Implementation of Label Propagation\n\nThis DGL example implements the method proposed in the paper [Learning from Labeled and Unlabeled Data with Label Propagation](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.3864&rep=rep1&type=pdf).\n\nContributor: [xnuohz](https://github.com/xnuohz)\n\n### Requirements\nThe codebase is implemented in Python 3.7. Package version requirements are listed below.\n\n```\ndgl 0.6.0.post1\ntorch 1.7.0\n```\n\n### The graph datasets used in this example\n\nThis example uses DGL's built-in Cora, Pubmed, and Citeseer datasets. Dataset summary:\n\n| Dataset  | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |\n| :------: | :----: | :----: | :----: | :------: | :----------: | :--------: | :---------: |\n| Citeseer | 3,327  | 9,228  | 3,703  |    6     |     120      |    500     |    1000     |\n|   Cora   | 2,708  | 10,556 | 1,433  |    7     |     140      |    500     |    1000     |\n|  Pubmed  | 19,717 | 88,651 |  500   |    3     |      60      |    500     |    1000     |\n\n### Usage\n\n```bash\n# Cora\npython main.py\n\n# Citeseer\npython main.py --dataset Citeseer --num-layers 100 --alpha 0.99\n\n# Pubmed\npython main.py --dataset Pubmed --num-layers 60 --alpha 1\n```\n\n### Performance\n\n|   Dataset    | Cora  | Citeseer | Pubmed |\n| :----------: | :---: | :------: | :----: |\n| Results(DGL) | 69.20 | 51.30 | 71.40 |\n"
  },
  {
    "path": "examples/pytorch/label_propagation/main.py",
    "content": "import argparse\n\nimport dgl\n\nimport torch\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\nfrom dgl.nn import LabelPropagation\n\n\ndef main():\n    # check cuda\n    device = (\n        f\"cuda:{args.gpu}\"\n        if torch.cuda.is_available() and args.gpu >= 0\n        else \"cpu\"\n    )\n\n    # load data\n    if args.dataset == \"Cora\":\n        dataset = CoraGraphDataset()\n    elif args.dataset == \"Citeseer\":\n        dataset = CiteseerGraphDataset()\n    elif args.dataset == \"Pubmed\":\n        dataset = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Dataset {} is invalid.\".format(args.dataset))\n\n    g = dataset[0]\n    g = dgl.add_self_loop(g)\n\n    labels = g.ndata.pop(\"label\").to(device).long()\n\n    # load masks for train / test, valid is not used.\n    train_mask = g.ndata.pop(\"train_mask\")\n    test_mask = g.ndata.pop(\"test_mask\")\n\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)\n    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)\n\n    g = g.to(device)\n\n    # label propagation\n    lp = LabelPropagation(args.num_layers, args.alpha)\n    logits = lp(g, labels, mask=train_idx)\n\n    test_acc = torch.sum(\n        logits[test_idx].argmax(dim=1) == labels[test_idx]\n    ).item() / len(test_idx)\n    print(\"Test Acc {:.4f}\".format(test_acc))\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    Label Propagation Hyperparameters\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"LP\")\n    parser.add_argument(\"--gpu\", type=int, default=0)\n    parser.add_argument(\"--dataset\", type=str, default=\"Cora\")\n    parser.add_argument(\"--num-layers\", type=int, default=10)\n    parser.add_argument(\"--alpha\", type=float, default=0.5)\n\n    args = parser.parse_args()\n    print(args)\n\n    main()\n"
  },
  {
    "path": "examples/pytorch/labor/README.md",
    "content": "Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs\n============\n\n- Paper link: [https://arxiv.org/abs/2210.13339](https://arxiv.org/abs/2210.13339)\n\nThis is the official Labor sampling example to reproduce the results in the original\npaper with the GraphSAGE GNN model. The model can be changed to any other model where\nNeighborSampler can be used.\n\nA more modern and performant version is provided in the\n`examples/graphbolt/pyg/labor` folder.\n\nRequirements\n------------\n\n```bash\npip install requests lightning==2.0.6 ogb\n```\n\nHow to run\n-------\n\n### Minibatch training for node classification\n\nTrain with mini-batch sampling on the GPU for node classification on \"ogbn-products\"\n\n```bash\npython3 train_lightning.py --dataset=ogbn-products\n```\n\nResults:\n```\nTest Accuracy: 0.797\n```\n\nAny integer passed as the `--importance-sampling=i` argument runs the corresponding\nLABOR-i variant. `--importance-sampling=-1` runs the LABOR-* variant.\n\nThe `--vertex-limit` argument is used if a vertex sampling budget is needed. It adjusts\nthe batch size at the end of every epoch so that the average number of sampled vertices\nconverges to the provided vertex limit. It can be used to replicate the vertex sampling\nbudget experiments in the Labor paper.\n\nDuring training runs, statistics about the number of sampled vertices and edges, as well as\ncache miss rates, will be reported. 
One can use tensorboard to look at their plots\nduring/after training:\n\n```bash\ntensorboard --logdir tb_logs\n```\n\n## Utilize a GPU feature cache for UVA training\n\n```bash\npython3 train_lightning.py --dataset=ogbn-products --use-uva --cache-size=500000\n```\n\n## Reduce GPU feature cache miss rate for UVA training\n\n```bash\npython3 train_lightning.py --dataset=ogbn-products --use-uva --cache-size=500000 --batch-dependency=64\n```\n\n## Force all layers to share the same neighborhood for shared vertices\n\n```bash\npython3 train_lightning.py --dataset=ogbn-products --layer-dependency\n```"
  },
  {
    "path": "examples/pytorch/labor/ladies_sampler.py",
    "content": "# referenced the following implementation: https://github.com/BarclayII/dgl/blob/ladies/examples/pytorch/ladies/ladies2.py\n\nimport dgl\nimport dgl.function as fn\nimport torch\n\n\ndef find_indices_in(a, b):\n    b_sorted, indices = torch.sort(b)\n    sorted_indices = torch.searchsorted(b_sorted, a)\n    sorted_indices[sorted_indices >= indices.shape[0]] = 0\n    return indices[sorted_indices]\n\n\ndef union(*arrays):\n    return torch.unique(torch.cat(arrays))\n\n\ndef normalized_edata(g, weight=None):\n    with g.local_scope():\n        if weight is None:\n            weight = \"W\"\n            g.edata[weight] = torch.ones(g.number_of_edges(), device=g.device)\n        g.update_all(fn.copy_e(weight, weight), fn.sum(weight, \"v\"))\n        g.apply_edges(lambda edges: {\"w\": 1 / edges.dst[\"v\"]})\n        return g.edata[\"w\"]\n\n\nclass LadiesSampler(dgl.dataloading.BlockSampler):\n    def __init__(\n        self,\n        nodes_per_layer,\n        importance_sampling=True,\n        weight=\"w\",\n        out_weight=\"edge_weights\",\n        replace=False,\n    ):\n        super().__init__()\n        self.nodes_per_layer = nodes_per_layer\n        self.importance_sampling = importance_sampling\n        self.edge_weight = weight\n        self.output_weight = out_weight\n        self.replace = replace\n\n    def compute_prob(self, g, seed_nodes, weight, num):\n        \"\"\"\n        g : the whole graph\n        seed_nodes : the output nodes for the current layer\n        weight : the weight of the edges\n        return : the unnormalized probability of the candidate nodes, as well as the subgraph\n                 containing all the edges from the candidate nodes to the output nodes.\n        \"\"\"\n        insg = dgl.in_subgraph(g, seed_nodes)\n        insg = dgl.compact_graphs(insg, seed_nodes)\n        if self.importance_sampling:\n            out_frontier = dgl.reverse(insg, copy_edata=True)\n            weight = 
weight[out_frontier.edata[dgl.EID].long()]\n            prob = dgl.ops.copy_e_sum(out_frontier, weight**2)\n            # prob = torch.sqrt(prob)\n        else:\n            prob = torch.ones(insg.num_nodes())\n            prob[insg.out_degrees() == 0] = 0\n        return prob, insg\n\n    def select_neighbors(self, prob, num):\n        \"\"\"\n        seed_nodes : output nodes\n        cand_nodes : candidate nodes.  Must contain all output nodes in @seed_nodes\n        prob : unnormalized probability of each candidate node\n        num : number of neighbors to sample\n        return : the set of input nodes in terms of their indices in @cand_nodes, and also the indices of\n                 seed nodes in the selected nodes.\n        \"\"\"\n        # The returned nodes should be a union of seed_nodes plus @num nodes from cand_nodes.\n        # Because compute_prob returns a compacted subgraph and a list of probabilities,\n        # we need to find the corresponding local IDs of the resulting union in the subgraph\n        # so that we can compute the edge weights of the block.\n        # This is why we need a find_indices_in() function.\n        neighbor_nodes_idx = torch.multinomial(\n            prob, min(num, prob.shape[0]), replacement=self.replace\n        )\n        return neighbor_nodes_idx\n\n    def generate_block(self, insg, neighbor_nodes_idx, seed_nodes, P_sg, W_sg):\n        \"\"\"\n        insg : the subgraph yielded by compute_prob()\n        neighbor_nodes_idx : the sampled nodes from the subgraph @insg, yielded by select_neighbors()\n        seed_nodes_local_idx : the indices of seed nodes in the selected neighbor nodes, also yielded\n                               by select_neighbors()\n        P_sg : unnormalized probability of each node being sampled, yielded by compute_prob()\n        W_sg : edge weights of @insg\n        return : the block.\n        \"\"\"\n        seed_nodes_idx = find_indices_in(seed_nodes, insg.ndata[dgl.NID])\n        
u_nodes = union(neighbor_nodes_idx, seed_nodes_idx)\n        sg = insg.subgraph(u_nodes.type(insg.idtype))\n        u, v = sg.edges()\n        lu = sg.ndata[dgl.NID][u.long()]\n        s = find_indices_in(lu, neighbor_nodes_idx)\n        eg = dgl.edge_subgraph(\n            sg, lu == neighbor_nodes_idx[s], relabel_nodes=False\n        )\n        eg.ndata[dgl.NID] = sg.ndata[dgl.NID][: eg.num_nodes()]\n        eg.edata[dgl.EID] = sg.edata[dgl.EID][eg.edata[dgl.EID].long()]\n        sg = eg\n        nids = insg.ndata[dgl.NID][sg.ndata[dgl.NID].long()]\n        P = P_sg[u_nodes.long()]\n        W = W_sg[sg.edata[dgl.EID].long()]\n        W_tilde = dgl.ops.e_div_u(sg, W, P)\n        W_tilde_sum = dgl.ops.copy_e_sum(sg, W_tilde)\n        d = sg.in_degrees()\n        W_tilde = dgl.ops.e_mul_v(sg, W_tilde, d / W_tilde_sum)\n\n        block = dgl.to_block(sg, seed_nodes_idx.type(sg.idtype))\n        block.edata[self.output_weight] = W_tilde\n        # correct node ID mapping\n        block.srcdata[dgl.NID] = nids[block.srcdata[dgl.NID].long()]\n        block.dstdata[dgl.NID] = nids[block.dstdata[dgl.NID].long()]\n\n        sg_eids = insg.edata[dgl.EID][sg.edata[dgl.EID].long()]\n        block.edata[dgl.EID] = sg_eids[block.edata[dgl.EID].long()]\n        return block\n\n    def sample_blocks(self, g, seed_nodes, exclude_eids=None):\n        output_nodes = seed_nodes\n        blocks = []\n        for block_id in reversed(range(len(self.nodes_per_layer))):\n            num_nodes_to_sample = self.nodes_per_layer[block_id]\n            W = g.edata[self.edge_weight]\n            prob, insg = self.compute_prob(\n                g, seed_nodes, W, num_nodes_to_sample\n            )\n            neighbor_nodes_idx = self.select_neighbors(\n                prob, num_nodes_to_sample\n            )\n            block = self.generate_block(\n                insg,\n                neighbor_nodes_idx.type(g.idtype),\n                seed_nodes.type(g.idtype),\n                prob,\n     
           W[insg.edata[dgl.EID].long()],\n            )\n            seed_nodes = block.srcdata[dgl.NID]\n            blocks.insert(0, block)\n        return seed_nodes, output_nodes, blocks\n\n\nclass PoissonLadiesSampler(LadiesSampler):\n    def __init__(\n        self,\n        nodes_per_layer,\n        importance_sampling=True,\n        weight=\"w\",\n        out_weight=\"edge_weights\",\n        skip=False,\n    ):\n        super().__init__(\n            nodes_per_layer, importance_sampling, weight, out_weight\n        )\n        self.eps = 0.9999\n        self.skip = skip\n\n    def compute_prob(self, g, seed_nodes, weight, num):\n        \"\"\"\n        g : the whole graph\n        seed_nodes : the output nodes for the current layer\n        weight : the weight of the edges\n        return : the unnormalized probability of the candidate nodes, as well as the subgraph\n                 containing all the edges from the candidate nodes to the output nodes.\n        \"\"\"\n        prob, insg = super().compute_prob(g, seed_nodes, weight, num)\n\n        one = torch.ones_like(prob)\n        if prob.shape[0] <= num:\n            return one, insg\n\n        c = 1.0\n        for i in range(50):\n            S = torch.sum(torch.minimum(prob * c, one).to(torch.float64)).item()\n            if min(S, num) / max(S, num) >= self.eps:\n                break\n            else:\n                c *= num / S\n\n        if self.skip:\n            skip_nodes = find_indices_in(seed_nodes, insg.ndata[dgl.NID])\n            prob[skip_nodes] = float(\"inf\")\n\n        return torch.minimum(prob * c, one), insg\n\n    def select_neighbors(self, prob, num):\n        \"\"\"\n        seed_nodes : output nodes\n        cand_nodes : candidate nodes.  
Must contain all output nodes in @seed_nodes\n        prob : unnormalized probability of each candidate node\n        num : number of neighbors to sample\n        return : the set of input nodes in terms of their indices in @cand_nodes, and also the indices of\n                 seed nodes in the selected nodes.\n        \"\"\"\n        # The returned nodes should be a union of seed_nodes plus @num nodes from cand_nodes.\n        # Because compute_prob returns a compacted subgraph and a list of probabilities,\n        # we need to find the corresponding local IDs of the resulting union in the subgraph\n        # so that we can compute the edge weights of the block.\n        # This is why we need a find_indices_in() function.\n        neighbor_nodes_idx = torch.arange(prob.shape[0], device=prob.device)[\n            torch.bernoulli(prob) == 1\n        ]\n        return neighbor_nodes_idx\n"
  },
  {
    "path": "examples/pytorch/labor/load_graph.py",
    "content": "import dgl\nimport torch as th\n\n\ndef load_data(data):\n    g = data[0]\n    g.ndata[\"features\"] = g.ndata.pop(\"feat\")\n    g.ndata[\"labels\"] = g.ndata.pop(\"label\")\n    return g, data.num_classes\n\n\ndef load_dgl(name):\n    from dgl.data import (\n        CiteseerGraphDataset,\n        CoraGraphDataset,\n        FlickrDataset,\n        PubmedGraphDataset,\n        RedditDataset,\n        YelpDataset,\n    )\n\n    d = {\n        \"cora\": CoraGraphDataset,\n        \"citeseer\": CiteseerGraphDataset,\n        \"pubmed\": PubmedGraphDataset,\n        \"reddit\": RedditDataset,\n        \"yelp\": YelpDataset,\n        \"flickr\": FlickrDataset,\n    }\n\n    return load_data(d[name]())\n\n\ndef load_reddit(self_loop=True):\n    from dgl.data import RedditDataset\n\n    # load reddit data\n    data = RedditDataset(self_loop=self_loop)\n    return load_data(data)\n\n\ndef load_mag240m(root=\"dataset\"):\n    from os.path import join\n\n    import numpy as np\n    from ogb.lsc import MAG240MDataset\n\n    dataset = MAG240MDataset(root=root)\n\n    print(\"Loading graph\")\n    (g,), _ = dgl.load_graphs(join(root, \"mag240m_kddcup2021/graph.dgl\"))\n\n    print(\"Loading features\")\n    paper_offset = dataset.num_authors + dataset.num_institutions\n    num_nodes = paper_offset + dataset.num_papers\n    num_features = dataset.num_paper_features\n    feats = th.from_numpy(\n        np.memmap(\n            join(root, \"mag240m_kddcup2021/full.npy\"),\n            mode=\"r\",\n            dtype=\"float16\",\n            shape=(num_nodes, num_features),\n        )\n    ).float()\n    g.ndata[\"features\"] = feats\n    train_nid = th.LongTensor(dataset.get_idx_split(\"train\")) + paper_offset\n    val_nid = th.LongTensor(dataset.get_idx_split(\"valid\")) + paper_offset\n    test_nid = th.LongTensor(dataset.get_idx_split(\"test-dev\")) + paper_offset\n    train_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)\n    train_mask[train_nid] = 
True\n    val_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)\n    val_mask[val_nid] = True\n    test_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)\n    test_mask[test_nid] = True\n    g.ndata[\"train_mask\"] = train_mask\n    g.ndata[\"val_mask\"] = val_mask\n    g.ndata[\"test_mask\"] = test_mask\n    labels = th.tensor(dataset.paper_label)\n    num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))\n    g.ndata[\"labels\"] = -th.ones(g.number_of_nodes(), dtype=th.int64)\n    g.ndata[\"labels\"][train_nid] = labels[train_nid - paper_offset].long()\n    g.ndata[\"labels\"][val_nid] = labels[val_nid - paper_offset].long()\n    return g, num_labels\n\n\ndef load_ogb(name, root=\"dataset\"):\n    if name == \"ogbn-mag240M\":\n        return load_mag240m(root)\n\n    from ogb.nodeproppred import DglNodePropPredDataset\n\n    print(\"load\", name)\n    data = DglNodePropPredDataset(name=name, root=root)\n    print(\"finish loading\", name)\n    splitted_idx = data.get_idx_split()\n    graph, labels = data[0]\n    labels = labels[:, 0]\n\n    graph.ndata[\"features\"] = graph.ndata.pop(\"feat\")\n    num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))\n    graph.ndata[\"labels\"] = labels.type(th.LongTensor)\n    in_feats = graph.ndata[\"features\"].shape[1]\n\n    # Find the node IDs in the training, validation, and test set.\n    train_nid, val_nid, test_nid = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    train_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)\n    train_mask[train_nid] = True\n    val_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)\n    val_mask[val_nid] = True\n    test_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)\n    test_mask[test_nid] = True\n    graph.ndata[\"train_mask\"] = train_mask\n    graph.ndata[\"val_mask\"] = val_mask\n    graph.ndata[\"test_mask\"] = test_mask\n    print(\"finish 
constructing\", name)\n    return graph, num_labels\n\n\ndef load_dataset(dataset_name):\n    multilabel = False\n    if dataset_name in [\n        \"reddit\",\n        \"cora\",\n        \"citeseer\",\n        \"pubmed\",\n        \"yelp\",\n        \"flickr\",\n    ]:\n        g, n_classes = load_dgl(dataset_name)\n        multilabel = dataset_name in [\"yelp\"]\n        if multilabel:\n            g.ndata[\"labels\"] = g.ndata[\"labels\"].to(dtype=th.float32)\n    elif dataset_name in [\n        \"ogbn-products\",\n        \"ogbn-arxiv\",\n        \"ogbn-papers100M\",\n        \"ogbn-mag240M\",\n    ]:\n        g, n_classes = load_ogb(dataset_name)\n    else:\n        raise ValueError(\"unknown dataset\")\n\n    return g, n_classes, multilabel\n"
  },
  {
    "path": "examples/pytorch/labor/model.py",
    "content": "import dgl\nimport dgl.nn as dglnn\nimport sklearn.linear_model as lm\nimport sklearn.metrics as skm\nimport torch as th\nimport torch.functional as F\nimport torch.nn as nn\nimport tqdm\n\n\nclass SAGE(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.init(in_feats, n_hidden, n_classes, n_layers, activation, dropout)\n\n    def init(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        if n_layers > 1:\n            self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n            for i in range(1, n_layers - 1):\n                self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        else:\n            self.layers.append(dglnn.SAGEConv(in_feats, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(\n                block,\n                h,\n                edge_weight=block.edata[\"edge_weights\"]\n                if \"edge_weights\" in block.edata\n                else None,\n            )\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, device, batch_size, use_uva, num_workers):\n        # The difference between this inference function and the one in the official\n        # example is that the intermediate results can also benefit from prefetching.\n        g.ndata[\"h\"] = g.ndata[\"features\"]\n        sampler = 
dgl.dataloading.MultiLayerFullNeighborSampler(\n            1, prefetch_node_feats=[\"h\"]\n        )\n        pin_memory = g.device != device and use_uva\n        dataloader = dgl.dataloading.DataLoader(\n            g,\n            th.arange(g.num_nodes(), dtype=g.idtype, device=g.device),\n            sampler,\n            device=device,\n            batch_size=batch_size,\n            shuffle=False,\n            drop_last=False,\n            use_uva=use_uva,\n            num_workers=num_workers,\n            persistent_workers=(num_workers > 0),\n        )\n\n        self.eval()\n\n        for l, layer in enumerate(self.layers):\n            y = th.empty(\n                g.num_nodes(),\n                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,\n                dtype=g.ndata[\"h\"].dtype,\n                device=g.device,\n                pin_memory=pin_memory,\n            )\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                x = blocks[0].srcdata[\"h\"]\n                h = layer(blocks[0], x)\n                if l < len(self.layers) - 1:\n                    h = self.activation(h)\n                    h = self.dropout(h)\n                # by design, our output nodes are contiguous\n                y[output_nodes[0].item() : output_nodes[-1].item() + 1] = h.to(\n                    y.device\n                )\n            g.ndata[\"h\"] = y\n        return y\n"
  },
  {
    "path": "examples/pytorch/labor/train_lightning.py",
    "content": "# /*!\n#  *   Copyright (c) 2022, NVIDIA Corporation\n#  *   Copyright (c) 2022, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n#  *   All rights reserved.\n#  *\n#  *   Licensed under the Apache License, Version 2.0 (the \"License\");\n#  *   you may not use this file except in compliance with the License.\n#  *   You may obtain a copy of the License at\n#  *\n#  *       http://www.apache.org/licenses/LICENSE-2.0\n#  *\n#  *   Unless required by applicable law or agreed to in writing, software\n#  *   distributed under the License is distributed on an \"AS IS\" BASIS,\n#  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  *   See the License for the specific language governing permissions and\n#  *   limitations under the License.\n#  *\n#  * @file train_lightning.py\n#  * @brief labor sampling example\n#  */\n\nimport argparse\nimport glob\nimport math\nimport os\nimport time\n\nimport dgl\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom ladies_sampler import LadiesSampler, normalized_edata, PoissonLadiesSampler\n\nfrom load_graph import load_dataset\nfrom model import SAGE\nfrom pytorch_lightning import LightningDataModule, LightningModule, Trainer\nfrom pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint\nfrom pytorch_lightning.loggers import TensorBoardLogger\n\nfrom torchmetrics.classification import MulticlassF1Score, MultilabelF1Score\n\n\nclass SAGELightning(LightningModule):\n    def __init__(\n        self,\n        in_feats,\n        n_hidden,\n        n_classes,\n        n_layers,\n        activation,\n        dropout,\n        lr,\n        multilabel,\n    ):\n        super().__init__()\n        self.save_hyperparameters()\n        self.module = SAGE(\n            in_feats, n_hidden, n_classes, n_layers, activation, dropout\n        )\n        self.lr = lr\n        self.f1score_class = lambda: (\n            MulticlassF1Score if not 
multilabel else MultilabelF1Score\n        )(n_classes, average=\"micro\")\n        self.train_acc = self.f1score_class()\n        self.val_acc = self.f1score_class()\n        self.num_steps = 0\n        self.cum_sampled_nodes = [0 for _ in range(n_layers + 1)]\n        self.cum_sampled_edges = [0 for _ in range(n_layers)]\n        self.w = 0.99\n        self.loss_fn = (\n            nn.CrossEntropyLoss() if not multilabel else nn.BCEWithLogitsLoss()\n        )\n        self.pt = 0\n\n    def num_sampled_nodes(self, i):\n        return (\n            self.cum_sampled_nodes[i] / self.num_steps\n            if self.w >= 1\n            else self.cum_sampled_nodes[i]\n            * (1 - self.w)\n            / (1 - self.w**self.num_steps)\n        )\n\n    def num_sampled_edges(self, i):\n        return (\n            self.cum_sampled_edges[i] / self.num_steps\n            if self.w >= 1\n            else self.cum_sampled_edges[i]\n            * (1 - self.w)\n            / (1 - self.w**self.num_steps)\n        )\n\n    def training_step(self, batch, batch_idx):\n        input_nodes, output_nodes, mfgs = batch\n        mfgs = [mfg.int().to(device) for mfg in mfgs]\n        self.num_steps += 1\n        for i, mfg in enumerate(mfgs):\n            self.cum_sampled_nodes[i] = (\n                self.cum_sampled_nodes[i] * self.w + mfg.num_src_nodes()\n            )\n            self.cum_sampled_edges[i] = (\n                self.cum_sampled_edges[i] * self.w + mfg.num_edges()\n            )\n            self.log(\n                \"num_nodes/{}\".format(i),\n                self.num_sampled_nodes(i),\n                prog_bar=True,\n                on_step=True,\n                on_epoch=False,\n            )\n            self.log(\n                \"num_edges/{}\".format(i),\n                self.num_sampled_edges(i),\n                prog_bar=True,\n                on_step=True,\n                on_epoch=False,\n            )\n        # for batch size monitoring\n        i 
= len(mfgs)\n        self.cum_sampled_nodes[i] = (\n            self.cum_sampled_nodes[i] * self.w + mfgs[-1].num_dst_nodes()\n        )\n        self.log(\n            \"num_nodes/{}\".format(i),\n            self.num_sampled_nodes(i),\n            prog_bar=True,\n            on_step=True,\n            on_epoch=False,\n        )\n\n        batch_inputs = mfgs[0].srcdata[\"features\"]\n        batch_labels = mfgs[-1].dstdata[\"labels\"]\n        self.st = time.time()\n        batch_pred = self.module(mfgs, batch_inputs)\n        loss = self.loss_fn(batch_pred, batch_labels)\n        self.train_acc(batch_pred, batch_labels.int())\n        self.log(\n            \"train_acc\",\n            self.train_acc,\n            prog_bar=True,\n            on_step=True,\n            on_epoch=True,\n            batch_size=batch_labels.shape[0],\n        )\n        self.log(\n            \"train_loss\",\n            loss,\n            on_step=True,\n            on_epoch=True,\n            batch_size=batch_labels.shape[0],\n        )\n        t = time.time()\n        self.log(\n            \"iter_time\",\n            t - self.pt,\n            prog_bar=True,\n            on_step=True,\n            on_epoch=False,\n        )\n        self.pt = t\n        return loss\n\n    def on_train_batch_end(self, outputs, batch, batch_idx):\n        self.log(\n            \"forward_backward_time\",\n            time.time() - self.st,\n            prog_bar=True,\n            on_step=True,\n            on_epoch=False,\n        )\n\n    def validation_step(self, batch, batch_idx, dataloader_idx=0):\n        input_nodes, output_nodes, mfgs = batch\n        mfgs = [mfg.int().to(device) for mfg in mfgs]\n        batch_inputs = mfgs[0].srcdata[\"features\"]\n        batch_labels = mfgs[-1].dstdata[\"labels\"]\n        batch_pred = self.module(mfgs, batch_inputs)\n        loss = self.loss_fn(batch_pred, batch_labels)\n        self.val_acc(batch_pred, batch_labels.int())\n        self.log(\n            
\"val_acc\",\n            self.val_acc,\n            prog_bar=True,\n            on_step=False,\n            on_epoch=True,\n            sync_dist=True,\n            batch_size=batch_labels.shape[0],\n        )\n        self.log(\n            \"val_loss\",\n            loss,\n            on_step=False,\n            on_epoch=True,\n            sync_dist=True,\n            batch_size=batch_labels.shape[0],\n        )\n\n    def configure_optimizers(self):\n        optimizer = th.optim.Adam(self.parameters(), lr=self.lr)\n        return optimizer\n\n\nclass DataModule(LightningDataModule):\n    def __init__(\n        self,\n        dataset_name,\n        undirected,\n        data_cpu=False,\n        use_uva=False,\n        fan_out=[10, 25],\n        lad_out=[11000, 5000],\n        device=th.device(\"cpu\"),\n        batch_size=1000,\n        num_workers=4,\n        sampler=\"labor\",\n        importance_sampling=0,\n        layer_dependency=False,\n        batch_dependency=1,\n        cache_size=0,\n    ):\n        super().__init__()\n\n        g, n_classes, multilabel = load_dataset(dataset_name)\n        if undirected:\n            src, dst = g.all_edges()\n            g.add_edges(dst, src)\n        cast_to_int = max(g.num_nodes(), g.num_edges()) <= 2e9\n        if cast_to_int:\n            g = g.int()\n\n        train_nid = th.nonzero(g.ndata[\"train_mask\"], as_tuple=True)[0]\n        val_nid = th.nonzero(g.ndata[\"val_mask\"], as_tuple=True)[0]\n        test_nid = th.nonzero(g.ndata[\"test_mask\"], as_tuple=True)[0]\n\n        fanouts = [int(_) for _ in fan_out]\n        ladouts = [int(_) for _ in lad_out]\n        if sampler == \"neighbor\":\n            sampler = dgl.dataloading.NeighborSampler(\n                fanouts,\n                prefetch_node_feats=[\"features\"],\n                prefetch_edge_feats=[\"etype\"] if \"etype\" in g.edata else [],\n                prefetch_labels=[\"labels\"],\n            )\n        elif \"ladies\" in sampler:\n          
  g.edata[\"w\"] = normalized_edata(g)\n            sampler = (\n                PoissonLadiesSampler if \"poisson\" in sampler else LadiesSampler\n            )(ladouts)\n        else:\n            sampler = dgl.dataloading.LaborSampler(\n                fanouts,\n                importance_sampling=importance_sampling,\n                layer_dependency=layer_dependency,\n                batch_dependency=batch_dependency,\n                prefetch_node_feats=[\"features\"],\n                prefetch_edge_feats=[\"etype\"] if \"etype\" in g.edata else [],\n                prefetch_labels=[\"labels\"],\n            )\n\n        dataloader_device = th.device(\"cpu\")\n        g = g.formats([\"csc\"])\n        if use_uva or not data_cpu:\n            train_nid = train_nid.to(device)\n            val_nid = val_nid.to(device)\n            test_nid = test_nid.to(device)\n            if not data_cpu and not use_uva:\n                g = g.to(device)\n            dataloader_device = device\n\n        self.g = g\n        self.train_nid = train_nid.to(g.idtype)\n        self.val_nid = val_nid.to(g.idtype)\n        self.test_nid = test_nid.to(g.idtype)\n        self.sampler = sampler\n        self.device = dataloader_device\n        self.use_uva = use_uva\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n        self.in_feats = g.ndata[\"features\"].shape[1]\n        self.n_classes = n_classes\n        self.multilabel = multilabel\n        self.gpu_cache_arg = {\"node\": {\"features\": cache_size}}\n\n    def train_dataloader(self):\n        return dgl.dataloading.DataLoader(\n            self.g,\n            self.train_nid,\n            self.sampler,\n            device=self.device,\n            use_uva=self.use_uva,\n            batch_size=self.batch_size,\n            shuffle=True,\n            drop_last=True,\n            num_workers=self.num_workers,\n            gpu_cache=self.gpu_cache_arg,\n        )\n\n    def val_dataloader(self):\n      
  return dgl.dataloading.DataLoader(\n            self.g,\n            self.val_nid,\n            self.sampler,\n            device=self.device,\n            use_uva=self.use_uva,\n            batch_size=self.batch_size,\n            shuffle=False,\n            drop_last=False,\n            num_workers=self.num_workers,\n            gpu_cache=self.gpu_cache_arg,\n        )\n\n\nclass BatchSizeCallback(Callback):\n    def __init__(self, limit, factor=3):\n        super().__init__()\n        self.limit = limit\n        self.factor = factor\n        self.clear()\n\n    def clear(self):\n        self.n = 0\n        self.m = 0\n        self.s = 0\n\n    def push(self, x):\n        self.n += 1\n        m = self.m\n        self.m += (x - m) / self.n\n        self.s += (x - m) * (x - self.m)\n\n    @property\n    def var(self):\n        return self.s / (self.n - 1)\n\n    @property\n    def std(self):\n        return math.sqrt(self.var)\n\n    def on_train_batch_start(self, trainer, datamodule, batch, batch_idx):\n        input_nodes, output_nodes, mfgs = batch\n        features = mfgs[0].srcdata[\"features\"]\n        if hasattr(features, \"__cache_miss__\"):\n            trainer.strategy.model.log(\n                \"cache_miss\",\n                features.__cache_miss__,\n                prog_bar=True,\n                on_step=True,\n                on_epoch=False,\n            )\n\n    def on_train_batch_end(\n        self, trainer, datamodule, outputs, batch, batch_idx\n    ):\n        input_nodes, output_nodes, mfgs = batch\n        self.push(mfgs[0].num_src_nodes())\n\n    def on_train_epoch_end(self, trainer, datamodule):\n        if (\n            self.limit > 0\n            and self.n >= 2\n            and abs(self.limit - self.m) * self.n >= self.std * self.factor\n        ):\n            trainer.datamodule.batch_size = int(\n                trainer.datamodule.batch_size * self.limit / self.m\n            )\n            loop = trainer._active_loop\n            
assert loop is not None\n            loop._combined_loader = None\n            loop.setup_data()\n            self.clear()\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser()\n    argparser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=0 if th.cuda.is_available() else -1,\n        help=\"GPU device ID. Use -1 for CPU training\",\n    )\n    argparser.add_argument(\"--dataset\", type=str, default=\"reddit\")\n    argparser.add_argument(\"--num-epochs\", type=int, default=-1)\n    argparser.add_argument(\"--num-steps\", type=int, default=-1)\n    argparser.add_argument(\"--min-steps\", type=int, default=0)\n    argparser.add_argument(\"--num-hidden\", type=int, default=256)\n    argparser.add_argument(\"--num-layers\", type=int, default=3)\n    argparser.add_argument(\"--fan-out\", type=str, default=\"10,10,10\")\n    argparser.add_argument(\"--lad-out\", type=str, default=\"16000,11000,5000\")\n    argparser.add_argument(\"--batch-size\", type=int, default=1024)\n    argparser.add_argument(\"--lr\", type=float, default=0.001)\n    argparser.add_argument(\"--dropout\", type=float, default=0.5)\n    argparser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=0,\n        help=\"Number of sampling processes. Use 0 for no extra process.\",\n    )\n    argparser.add_argument(\n        \"--data-cpu\",\n        action=\"store_true\",\n        help=\"By default the script puts the node features and labels \"\n        \"on GPU when using it to save time for data copy. This may \"\n        \"be undesired if they cannot fit in GPU memory at once. 
\"\n        \"This flag disables that.\",\n    )\n    argparser.add_argument(\n        \"--sampler\",\n        type=str,\n        default=\"labor\",\n        choices=[\"neighbor\", \"labor\", \"ladies\", \"poisson-ladies\"],\n    )\n    argparser.add_argument(\"--importance-sampling\", type=int, default=0)\n    argparser.add_argument(\"--layer-dependency\", action=\"store_true\")\n    argparser.add_argument(\"--batch-dependency\", type=int, default=1)\n    argparser.add_argument(\"--logdir\", type=str, default=\"tb_logs\")\n    argparser.add_argument(\"--vertex-limit\", type=int, default=-1)\n    argparser.add_argument(\"--use-uva\", action=\"store_true\")\n    argparser.add_argument(\"--cache-size\", type=int, default=0)\n    argparser.add_argument(\"--undirected\", action=\"store_true\")\n    argparser.add_argument(\"--val-acc-target\", type=float, default=1)\n    argparser.add_argument(\"--early-stopping-patience\", type=int, default=10)\n    argparser.add_argument(\"--disable-checkpoint\", action=\"store_true\")\n    argparser.add_argument(\"--precision\", type=str, default=\"highest\")\n    args = argparser.parse_args()\n\n    if args.precision != \"highest\":\n        th.set_float32_matmul_precision(args.precision)\n\n    if args.gpu >= 0:\n        device = th.device(\"cuda:%d\" % args.gpu)\n    else:\n        device = th.device(\"cpu\")\n\n    datamodule = DataModule(\n        args.dataset,\n        args.undirected,\n        args.data_cpu,\n        args.use_uva,\n        [int(_) for _ in args.fan_out.split(\",\")],\n        [int(_) for _ in args.lad_out.split(\",\")],\n        device,\n        args.batch_size,\n        args.num_workers,\n        args.sampler,\n        args.importance_sampling,\n        args.layer_dependency,\n        args.batch_dependency,\n        args.cache_size,\n    )\n    model = SAGELightning(\n        datamodule.in_feats,\n        args.num_hidden,\n        datamodule.n_classes,\n        args.num_layers,\n        F.relu,\n        
args.dropout,\n        args.lr,\n        datamodule.multilabel,\n    )\n\n    # Train\n    callbacks = []\n    if not args.disable_checkpoint:\n        callbacks.append(\n            ModelCheckpoint(monitor=\"val_acc\", save_top_k=1, mode=\"max\")\n        )\n    callbacks.append(BatchSizeCallback(args.vertex_limit))\n    callbacks.append(\n        EarlyStopping(\n            monitor=\"val_acc\",\n            stopping_threshold=args.val_acc_target,\n            mode=\"max\",\n            patience=args.early_stopping_patience,\n        )\n    )\n    subdir = \"{}_{}_{}_{}_{}\".format(\n        args.dataset,\n        args.sampler,\n        args.importance_sampling,\n        args.layer_dependency,\n        args.batch_dependency,\n    )\n    logger = TensorBoardLogger(args.logdir, name=subdir)\n    trainer = Trainer(\n        accelerator=\"gpu\" if args.gpu != -1 else \"cpu\",\n        devices=[args.gpu] if args.gpu != -1 else \"auto\",\n        max_epochs=args.num_epochs,\n        max_steps=args.num_steps,\n        min_steps=args.min_steps,\n        callbacks=callbacks,\n        logger=logger,\n    )\n    trainer.fit(model, datamodule=datamodule)\n\n    # Test\n    if not args.disable_checkpoint:\n        logdir = os.path.join(args.logdir, subdir)\n        dirs = glob.glob(\"./{}/*\".format(logdir))\n        version = max([int(os.path.split(x)[-1].split(\"_\")[-1]) for x in dirs])\n        logdir = \"./{}/version_{}\".format(logdir, version)\n        print(\"Evaluating model in\", logdir)\n        ckpt = glob.glob(os.path.join(logdir, \"checkpoints\", \"*\"))[0]\n\n        model = SAGELightning.load_from_checkpoint(\n            checkpoint_path=ckpt,\n            hparams_file=os.path.join(logdir, \"hparams.yaml\"),\n        ).to(device)\n    with th.no_grad():\n        graph = datamodule.g\n        pred = model.module.inference(\n            graph,\n            f\"cuda:{args.gpu}\" if args.gpu != -1 else \"cpu\",\n            4096,\n            args.use_uva,\n         
   args.num_workers,\n        )\n        for nid, split_name in zip(\n            [datamodule.train_nid, datamodule.val_nid, datamodule.test_nid],\n            [\"Train\", \"Validation\", \"Test\"],\n        ):\n            nid = nid.to(pred.device).long()\n            pred_nid = pred[nid]\n            label = graph.ndata[\"labels\"][nid]\n            f1score = model.f1score_class().to(pred.device)\n            acc = f1score(pred_nid, label)\n            print(f\"{split_name} accuracy: {acc.item()}\")\n"
  },
  {
    "path": "examples/pytorch/lda/README.md",
    "content": "Latent Dirichlet Allocation\n===\nLDA is a classical algorithm for probabilistic graphical models. It assumes \nhierarchical Bayes models with discrete variables on sparse doc/word graphs.\nThis example shows how it can be done on DGL,\nwhere the corpus is represented as a bipartite multi-graph G.\nThere is no back-propagation, because gradient descent is typically considered\ninefficient on probability simplex.\nOn the provided small-scale example on 20 news groups dataset, our DGL-LDA model runs\n50% faster on GPU than sklearn model without joblib parallel.\nFor larger graphs, thanks to subgraph sampling and low-memory implementation, we may fit 100 million unique words with 256 topic dimensions on a large multi-gpu machine.\n(The runtime memory is often less than 2x of parameter storage.)\n\n\nKey equations\n---\n\n<!-- https://editor.codecogs.com/ -->\n\nLet k be the topic index variable with one-hot encoded vector representation z. The rest of the variables are:\n\n|             | z_d\\~p(θ_d) | w_k\\~p(β_k) | z_dw\\~q(ϕ_dw) |\n|-------------|-------------|-------------|---------------|\n| Prior       | Dir(α)      | Dir(η)      |     (n/a)     |\n| Posterior   | Dir(γ_d)    | Dir(λ_k)    |     (n/a)     |\n\nWe overload w with bold-symbol-w, which represents the entire observed document-word multi-graph. 
The difference is better shown in the original paper.\n\n**Multinomial PCA**\n\nMultinomial PCA is a \"latent allocation\" model without the \"Dirichlet\".\nIts data likelihood sums over the latent topic-index variable k,\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;p(w_{di}|\\theta_d,\\beta)=\\sum_k\\theta_{dk}\\beta_{kw}\"/>,\nwhere θ_d and β_k are shared within the same document and topic, respectively.\n\nIf we perform gradient descent, we may need additional steps to project the parameters to the probability simplices:\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;\\sum_k\\theta_{dk}=1\"/>\nand\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;\\sum_w\\beta_{kw}=1\"/>.\nInstead, a more efficient solution is to borrow ideas from evidence lower-bound (ELBO) decomposition:\n\n<!--\n\\log p(w) \\geq \\mathcal{L}(w,\\phi)\n\\stackrel{def}{=}\n\\mathbb{E}_q [\\log p(w,z;\\theta,\\beta) - \\log q(z;\\phi)]\n\\\\=\n\\mathbb{E}_q [\\log p(w|z;\\beta) + \\log p(z;\\theta) - \\log q(z;\\phi)]\n\\\\=\n\\sum_{dwk}n_{dw}\\phi_{dwk} [\\log\\beta_{kw} + \\log \\theta_{dk} - \\log \\phi_{dwk}]\n-->\n\n<img src=\"https://latex.codecogs.com/svg.image?\\log&space;p(w)&space;\\geq&space;\\mathcal{L}(w,\\phi)\\stackrel{def}{=}\\mathbb{E}_q&space;[\\log&space;p(w,z;\\theta,\\beta)&space;-&space;\\log&space;q(z;\\phi)]\\\\=\\mathbb{E}_q&space;[\\log&space;p(w|z;\\beta)&space;&plus;&space;\\log&space;p(z;\\theta)&space;-&space;\\log&space;q(z;\\phi)]\\\\=\\sum_{dwk}n_{dw}\\phi_{dwk}&space;[\\log\\beta_{kw}&space;&plus;&space;\\log&space;\\theta_{dk}&space;-&space;\\log&space;\\phi_{dwk}]\"/>\n\nThe solutions for\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;\\theta_{dk}\\propto\\sum_wn_{dw}\\phi_{dwk}\"/>\nand\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;\\beta_{kw}\\propto\\sum_dn_{dw}\\phi_{dwk}\"/>\nfollow from the maximization of cross-entropy loss.\nThe solution for\n<img 
src=\"https://latex.codecogs.com/svg.image?\\inline&space;\\phi_{dwk}\\propto&space;\\theta_{dk}\\beta_{kw}\"/>\nfollows from Kullback-Leibler divergence.\nAfter normalizing to\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;\\sum_k\\phi_{dwk}=1\"/>,\nthe difference\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;\\ell_{dw}=\\log\\beta_{kw}+\\log\\theta_{dk}-\\log\\phi_{dwk}\"/>\nbecomes constant in k,\nwhich is connected to the likelihood for the observed document-word pairs.\n\nNote that after learning, the document vector θ_d considers the correlation between all words in d and similarly the topic distribution vector β_k considers the correlations in all observed documents.\n\n**Variational Bayes**\n\nA Bayesian model adds Dirichlet priors to θ_d and β_z, which leads to a similar ELBO if we assume independence\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;q(z,\\theta,\\beta;\\phi,\\gamma,\\lambda)=q(z;\\phi)q(\\theta;\\gamma)q(\\beta;\\lambda)\"/>,\ni.e.:\n\n<!--\n\\log p(w;\\alpha,\\eta) \\geq \\mathcal{L}(w,\\phi,\\gamma,\\lambda)\n\\stackrel{def}{=}\n\\mathbb{E}_q [\\log p(w,z,\\theta,\\beta;\\alpha,\\eta) - \\log q(z,\\theta,\\beta;\\phi,\\gamma,\\lambda)]\n\\\\=\n\\mathbb{E}_q \\left[\n\\log p(w|z,\\beta) + \\log p(z|\\theta) - \\log q(z;\\phi)\n+\\log p(\\theta;\\alpha) - \\log q(\\theta;\\gamma)\n+\\log p(\\beta;\\eta) - \\log q(\\beta;\\lambda)\n\\right]\n\\\\=\n\\sum_{dwk}n_{dw}\\phi_{dwk} (\\mathbb{E}_{\\lambda_k}[\\log\\beta_{kw}] + \\mathbb{E}_{\\gamma_d}[\\log \\theta_{dk}] - \\log \\phi_{dwk})\n\\\\+\\sum_{d}\\left[\n(\\alpha-\\gamma_d)^\\top\\mathbb{E}_{\\gamma_d}[\\log\\theta_d]\n-(\\log B(\\alpha 1_K) - \\log B(\\gamma_d))\n\\right]\n\\\\+\\sum_{k}\\left[\n(\\eta-\\lambda_k)^\\top\\mathbb{E}_{\\lambda_k}[\\log\\beta_k]\n-(\\log B(\\eta 1_W) - \\log B(\\lambda_k))\n\\right]\n -->\n\n<img 
src=\"https://latex.codecogs.com/svg.image?\\log&space;p(w;\\alpha,\\eta)&space;\\geq&space;\\mathcal{L}(w,\\phi,\\gamma,\\lambda)\\stackrel{def}{=}\\mathbb{E}_q&space;[\\log&space;p(w,z,\\theta,\\beta;\\alpha,\\eta)&space;-&space;\\log&space;q(z,\\theta,\\beta;\\phi,\\gamma,\\lambda)]\\\\=\\mathbb{E}_q&space;\\left[\\log&space;p(w|z,\\beta)&space;&plus;&space;\\log&space;p(z|\\theta)&space;-&space;\\log&space;q(z;\\phi)&plus;\\log&space;p(\\theta;\\alpha)&space;-&space;\\log&space;q(\\theta;\\gamma)&plus;\\log&space;p(\\beta;\\eta)&space;-&space;\\log&space;q(\\beta;\\lambda)\\right]\\\\=\\sum_{dwk}n_{dw}\\phi_{dwk}&space;(\\mathbb{E}_{\\lambda_k}[\\log\\beta_{kw}]&space;&plus;&space;\\mathbb{E}_{\\gamma_d}[\\log&space;\\theta_{dk}]&space;-&space;\\log&space;\\phi_{dwk})\\\\&plus;\\sum_{d}\\left[(\\alpha-\\gamma_d)^\\top\\mathbb{E}_{\\gamma_d}[\\log\\theta_d]-(\\log&space;B(\\alpha&space;1_K)&space;-&space;\\log&space;B(\\gamma_d))\\right]\\\\&plus;\\sum_{k}\\left[(\\eta-\\lambda_k)^\\top\\mathbb{E}_{\\lambda_k}[\\log\\beta_k]-(\\log&space;B(\\eta&space;1_W)&space;-&space;\\log&space;B(\\lambda_k))\\right]\"/>\n\n\n**Solutions**\n\nThe solutions to VB subsume the solutions to multinomial PCA when n goes to infinity.\nThe solution for ϕ is\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;\\log\\phi_{dwk}=\\mathbb{E}_{\\gamma_d}[\\log\\theta_{dk}]+\\mathbb{E}_{\\lambda_k}[\\log\\beta_{kw}]-\\ell_{dw}\"/>,\nwhere the additional expectation can be expressed via digamma functions\nand\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;\\ell_{dw}=\\log\\sum_k\\exp(\\mathbb{E}_{\\gamma_d}[\\log\\theta_{dk}]+\\mathbb{E}_{\\lambda_k}[\\log\\beta_{kw}])\"/>\nis the log-partition function.\nThe solutions for\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;\\gamma_{dk}=\\alpha+\\sum_wn_{dw}\\phi_{dwk}\"/>\nand\n<img src=\"https://latex.codecogs.com/svg.image?\\inline&space;\\lambda_{kw}=\\eta+\\sum_dn_{dw}\\phi_{dwk}\"/>\ncome from 
direct gradient calculation.\nAfter substituting the optimal solutions, we compute the marginal likelihood by adding the three terms, which are all connected to (the negative of) Kullback-Leibler divergence.\n\nDGL usage\n---\n\nThe corpus is represented as a bipartite multi-graph G.\nWe use DGL to propagate information through the edges and aggregate the distributions at doc/word nodes.\nFor scalability, the phi variables are transient and updated during message passing.\nThe gamma / lambda variables are updated after the nodes receive all edge messages.\nFollowing the conventions in [1], the gamma update is called E-step and the lambda update is called M-step.\nThe lambda variable is further recorded by the trainer.\nA separate function is used to produce perplexity, which is based on the ELBO objective function divided by the total numbers of word/doc occurrences.\n\nExample\n---\n`%run example_20newsgroups.py`\n\n * Approximately matches scikit-learn training perplexity after 10 rounds of training.\n * Exactly matches scikit-learn training perplexity if word_z is set to lda.components_.T\n * There is a difference in how we compute testing perplexity. We weigh the beta contributions by the training word counts, whereas sklearn weighs them by test word counts.\n * The DGL-LDA model runs 50% faster on GPU devices compared with sklearn without joblib parallel.\n\nAdvanced configurations\n---\n * Set `0<rho<1` for online learning with partial_fit.\n * Set `mult[\"doc\"]=100` or `mult[\"word\"]=100` or some large value to disable the corresponding Bayesian priors.\n\nReferences\n---\n\n1. Matthew Hoffman, Francis Bach, David Blei. Online Learning for Latent\nDirichlet Allocation. Advances in Neural Information Processing Systems 23\n(NIPS 2010).\n2. Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model\n"
  },
  {
    "path": "examples/pytorch/lda/example_20newsgroups.py",
    "content": "# Copyright 2021 Yifei Ma\n# Modified from scikit-learn example \"plot_topics_extraction_with_nmf_lda.py\"\n# with the following original authors with BSD 3-Clause:\n# * Olivier Grisel <olivier.grisel@ensta.org>\n# * Lars Buitinck\n# * Chyi-Kwei Yau <chyikwei.yau@gmail.com>\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport warnings\nfrom time import time\n\nimport dgl\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport scipy.sparse as ss\nimport torch\nfrom dgl import function as fn\nfrom lda_model import LatentDirichletAllocation as LDAModel\nfrom sklearn.datasets import fetch_20newsgroups\nfrom sklearn.decomposition import LatentDirichletAllocation, NMF\nfrom sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer\n\nn_samples = 2000\nn_features = 1000\nn_components = 10\nn_top_words = 20\ndevice = \"cuda\"\n\n\ndef plot_top_words(model, feature_names, n_top_words, title):\n    fig, axes = plt.subplots(2, 5, figsize=(30, 15), sharex=True)\n    axes = axes.flatten()\n    for topic_idx, topic in enumerate(model.components_):\n        top_features_ind = topic.argsort()[: -n_top_words - 1 : -1]\n        top_features = [feature_names[i] for i in top_features_ind]\n        weights = topic[top_features_ind]\n\n        ax = axes[topic_idx]\n        ax.barh(top_features, weights, height=0.7)\n        ax.set_title(f\"Topic {topic_idx +1}\", fontdict={\"fontsize\": 30})\n        ax.invert_yaxis()\n        
ax.tick_params(axis=\"both\", which=\"major\", labelsize=20)\n        for i in \"top right left\".split():\n            ax.spines[i].set_visible(False)\n        fig.suptitle(title, fontsize=40)\n\n    plt.subplots_adjust(top=0.90, bottom=0.05, wspace=0.90, hspace=0.3)\n    plt.show()\n\n\n# Load the 20 newsgroups dataset and vectorize it. We use a few heuristics\n# to filter out useless terms early on: the posts are stripped of headers,\n# footers and quoted replies, and common English words, words occurring in\n# only one document or in at least 95% of the documents are removed.\n\nprint(\"Loading dataset...\")\nt0 = time()\ndata, _ = fetch_20newsgroups(\n    shuffle=True,\n    random_state=1,\n    remove=(\"headers\", \"footers\", \"quotes\"),\n    return_X_y=True,\n)\ndata_samples = data[:n_samples]\ndata_test = data[n_samples : 2 * n_samples]\nprint(\"done in %0.3fs.\" % (time() - t0))\n\n# Use tf (raw term count) features for LDA.\nprint(\"Extracting tf features for LDA...\")\ntf_vectorizer = CountVectorizer(\n    max_df=0.95, min_df=2, max_features=n_features, stop_words=\"english\"\n)\nt0 = time()\ntf_vectorizer.fit(data)\ntf = tf_vectorizer.transform(data_samples)\ntt = tf_vectorizer.transform(data_test)\n\ntf_feature_names = tf_vectorizer.get_feature_names()\ntf_uv = [\n    (u, v)\n    for u, v, e in zip(tf.tocoo().row, tf.tocoo().col, tf.tocoo().data)\n    for _ in range(e)\n]\ntt_uv = [\n    (u, v)\n    for u, v, e in zip(tt.tocoo().row, tt.tocoo().col, tt.tocoo().data)\n    for _ in range(e)\n]\nprint(\"done in %0.3fs.\" % (time() - t0))\nprint()\n\nprint(\"Preparing dgl graphs...\")\nt0 = time()\nG = dgl.heterograph({(\"doc\", \"topic\", \"word\"): tf_uv}, device=device)\nGt = dgl.heterograph({(\"doc\", \"topic\", \"word\"): tt_uv}, device=device)\nprint(\"done in %0.3fs.\" % (time() - t0))\nprint()\n\nprint(\"Training dgl-lda model...\")\nt0 = time()\nmodel = LDAModel(G.num_nodes(\"word\"), n_components)\nmodel.fit(G)\nprint(\"done in %0.3fs.\" % 
(time() - t0))\nprint()\n\nprint(f\"dgl-lda training perplexity {model.perplexity(G):.3f}\")\nprint(f\"dgl-lda testing perplexity {model.perplexity(Gt):.3f}\")\n\nword_nphi = np.vstack([nphi.tolist() for nphi in model.word_data.nphi])\nplot_top_words(\n    type(\"dummy\", (object,), {\"components_\": word_nphi}),\n    tf_feature_names,\n    n_top_words,\n    \"Topics in LDA model\",\n)\n\nprint(\"Training scikit-learn model...\")\n\nprint(\n    \"\\n\" * 2,\n    \"Fitting LDA models with tf features, \"\n    \"n_samples=%d and n_features=%d...\" % (n_samples, n_features),\n)\nlda = LatentDirichletAllocation(\n    n_components=n_components,\n    max_iter=5,\n    learning_method=\"online\",\n    learning_offset=50.0,\n    random_state=0,\n    verbose=1,\n)\nt0 = time()\nlda.fit(tf)\nprint(\"done in %0.3fs.\" % (time() - t0))\nprint()\n\nprint(f\"scikit-learn training perplexity {lda.perplexity(tf):.3f}\")\nprint(f\"scikit-learn testing perplexity {lda.perplexity(tt):.3f}\")\n"
  },
  {
    "path": "examples/pytorch/lda/lda_model.py",
    "content": "# Copyright 2021 Yifei Ma\n# with references from \"sklearn.decomposition.LatentDirichletAllocation\"\n# with the following original authors:\n# * Chyi-Kwei Yau (the said scikit-learn implementation)\n# * Matthew D. Hoffman (original onlineldavb implementation)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport collections\nimport functools\nimport io\nimport os\nimport warnings\n\nimport dgl\n\nimport numpy as np\nimport scipy as sp\nimport torch\n\ntry:\n    from functools import cached_property\nexcept ImportError:\n    try:\n        from backports.cached_property import cached_property\n    except ImportError:\n        warnings.warn(\"cached_property not found - using property instead\")\n        cached_property = property\n\n\nclass EdgeData:\n    def __init__(self, src_data, dst_data):\n        self.src_data = src_data\n        self.dst_data = dst_data\n\n    @property\n    def loglike(self):\n        return (self.src_data[\"Elog\"] + self.dst_data[\"Elog\"]).logsumexp(1)\n\n    @property\n    def phi(self):\n        return (\n            self.src_data[\"Elog\"]\n            + self.dst_data[\"Elog\"]\n            - self.loglike.unsqueeze(1)\n        ).exp()\n\n    @property\n    def expectation(self):\n        return (\n            self.src_data[\"expectation\"] * self.dst_data[\"expectation\"]\n        ).sum(1)\n\n\nclass _Dirichlet:\n    def __init__(self, prior, nphi, _chunksize=int(1e6)):\n        self.prior = prior\n        
self.nphi = nphi\n        self.device = nphi.device\n        self._sum_by_parts = lambda map_fn: functools.reduce(\n            torch.add,\n            [\n                map_fn(slice(i, min(i + _chunksize, nphi.shape[1]))).sum(1)\n                for i in list(range(0, nphi.shape[1], _chunksize))\n            ],\n        )\n\n    def _posterior(self, _ID=slice(None)):\n        return self.prior + self.nphi[:, _ID]\n\n    @cached_property\n    def posterior_sum(self):\n        return self.nphi.sum(1) + self.prior * self.nphi.shape[1]\n\n    def _Elog(self, _ID=slice(None)):\n        return torch.digamma(self._posterior(_ID)) - torch.digamma(\n            self.posterior_sum.unsqueeze(1)\n        )\n\n    @cached_property\n    def loglike(self):\n        neg_evid = -self._sum_by_parts(\n            lambda s: (self.nphi[:, s] * self._Elog(s))\n        )\n\n        prior = torch.as_tensor(self.prior).to(self.nphi)\n        K = self.nphi.shape[1]\n        log_B_prior = torch.lgamma(prior) * K - torch.lgamma(prior * K)\n\n        log_B_posterior = self._sum_by_parts(\n            lambda s: torch.lgamma(self._posterior(s))\n        ) - torch.lgamma(self.posterior_sum)\n\n        return neg_evid - log_B_prior + log_B_posterior\n\n    @cached_property\n    def n(self):\n        return self.nphi.sum(1)\n\n    @cached_property\n    def cdf(self):\n        cdf = self._posterior()\n        torch.cumsum(cdf, 1, out=cdf)\n        cdf /= cdf[:, -1:].clone()\n        return cdf\n\n    def _expectation(self, _ID=slice(None)):\n        expectation = self._posterior(_ID)\n        expectation /= self.posterior_sum.unsqueeze(1)\n        return expectation\n\n    @cached_property\n    def Bayesian_gap(self):\n        return 1.0 - self._sum_by_parts(lambda s: self._Elog(s).exp())\n\n    _cached_properties = [\n        \"posterior_sum\",\n        \"loglike\",\n        \"n\",\n        \"cdf\",\n        \"Bayesian_gap\",\n    ]\n\n    def clear_cache(self):\n        for name in 
self._cached_properties:\n            try:\n                delattr(self, name)\n            except AttributeError:\n                pass\n\n    def update(self, new, _ID=slice(None), rho=1):\n        \"\"\"inplace: old * (1-rho) + new * rho\"\"\"\n        self.clear_cache()\n        mean_change = (self.nphi[:, _ID] - new).abs().mean().tolist()\n\n        self.nphi *= 1 - rho\n        self.nphi[:, _ID] += new * rho\n        return mean_change\n\n\nclass DocData(_Dirichlet):\n    \"\"\"nphi (n_docs by n_topics)\"\"\"\n\n    def prepare_graph(self, G, key=\"Elog\"):\n        G.nodes[\"doc\"].data[key] = getattr(self, \"_\" + key)().to(G.device)\n\n    def update_from(self, G, mult):\n        new = G.nodes[\"doc\"].data[\"nphi\"] * mult\n        return self.update(new.to(self.device))\n\n\nclass _Distributed(collections.UserList):\n    \"\"\"split on dim=0 and store on multiple devices\"\"\"\n\n    def __init__(self, prior, nphi):\n        self.prior = prior\n        self.nphi = nphi\n        super().__init__([_Dirichlet(self.prior, nphi) for nphi in self.nphi])\n\n    def split_device(self, other, dim=0):\n        split_sections = [x.shape[0] for x in self.nphi]\n        out = torch.split(other, split_sections, dim)\n        return [y.to(x.device) for x, y in zip(self.nphi, out)]\n\n\nclass WordData(_Distributed):\n    \"\"\"distributed nphi (n_topics by n_words), transpose to/from graph nodes data\"\"\"\n\n    def prepare_graph(self, G, key=\"Elog\"):\n        if \"_ID\" in G.nodes[\"word\"].data:\n            _ID = G.nodes[\"word\"].data[\"_ID\"]\n        else:\n            _ID = slice(None)\n\n        out = [getattr(part, \"_\" + key)(_ID).to(G.device) for part in self]\n        G.nodes[\"word\"].data[key] = torch.cat(out).T\n\n    def update_from(self, G, mult, rho):\n        nphi = G.nodes[\"word\"].data[\"nphi\"].T * mult\n\n        if \"_ID\" in G.nodes[\"word\"].data:\n            _ID = G.nodes[\"word\"].data[\"_ID\"]\n        else:\n            _ID = 
slice(None)\n\n        mean_change = [\n            x.update(y, _ID, rho) for x, y in zip(self, self.split_device(nphi))\n        ]\n        return np.mean(mean_change)\n\n\nclass Gamma(collections.namedtuple(\"Gamma\", \"concentration, rate\")):\n    \"\"\"articulate the difference between torch gamma and numpy gamma\"\"\"\n\n    @property\n    def shape(self):\n        return self.concentration\n\n    @property\n    def scale(self):\n        return 1 / self.rate\n\n    def sample(self, shape, device):\n        return torch.distributions.gamma.Gamma(\n            torch.as_tensor(self.concentration, device=device),\n            torch.as_tensor(self.rate, device=device),\n        ).sample(shape)\n\n\nclass LatentDirichletAllocation:\n    \"\"\"LDA model that works with a HeteroGraph with doc->word meta paths.\n    The model alters the attributes of G arbitrarily.\n    This is inspired by [1] and its corresponding scikit-learn implementation.\n\n    Inputs\n    ---\n    * G: a template graph or an integer showing n_words\n    * n_components: latent feature dimension; automatically set priors if missing.\n    * prior: parameters in the Dirichlet prior; default to 1/n_components and 1/n_words\n    * rho: new_nphi = (1-rho)*old_nphi + rho*nphi; default to 1 for full gradients.\n    * mult: multiplier for nphi-update; a large value effectively disables prior.\n    * init: sklearn initializers (100.0, 100.0); the sample points concentrate around 1.0\n    * device_list: accelerate word_data updates.\n\n    Notes\n    ---\n    Some differences between this and sklearn.decomposition.LatentDirichletAllocation:\n    * default word perplexity is normalized by training set instead of testing set.\n\n    References\n    ---\n    [1] Matthew Hoffman, Francis Bach, David Blei. Online Learning for Latent\n    Dirichlet Allocation. 
Advances in Neural Information Processing Systems 23\n    (NIPS 2010).\n    [2] Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model\n    \"\"\"\n\n    def __init__(\n        self,\n        n_words,\n        n_components,\n        prior=None,\n        rho=1,\n        mult={\"doc\": 1, \"word\": 1},\n        init={\"doc\": (100.0, 100.0), \"word\": (100.0, 100.0)},\n        device_list=[\"cpu\"],\n        verbose=True,\n    ):\n        self.n_words = n_words\n        self.n_components = n_components\n\n        if prior is None:\n            prior = {\"doc\": 1.0 / n_components, \"word\": 1.0 / n_components}\n        self.prior = prior\n\n        self.rho = rho\n        self.mult = mult\n        self.init = init\n\n        assert not isinstance(device_list, str), \"plz wrap devices in a list\"\n        self.device_list = device_list[:n_components]  # avoid edge cases\n        self.verbose = verbose\n\n        self._init_word_data()\n\n    def _init_word_data(self):\n        split_sections = np.diff(\n            np.linspace(0, self.n_components, len(self.device_list) + 1).astype(\n                int\n            )\n        )\n        word_nphi = [\n            Gamma(*self.init[\"word\"]).sample((s, self.n_words), device)\n            for s, device in zip(split_sections, self.device_list)\n        ]\n        self.word_data = WordData(self.prior[\"word\"], word_nphi)\n\n    def _init_doc_data(self, n_docs, device):\n        doc_nphi = Gamma(*self.init[\"doc\"]).sample(\n            (n_docs, self.n_components), device\n        )\n        return DocData(self.prior[\"doc\"], doc_nphi)\n\n    def save(self, f):\n        for w in self.word_data:\n            w.clear_cache()\n        torch.save(\n            {\n                \"prior\": self.prior,\n                \"rho\": self.rho,\n                \"mult\": self.mult,\n                \"init\": self.init,\n                \"word_data\": [part.nphi for part in self.word_data],\n            },\n         
   f,\n        )\n\n    def _prepare_graph(self, G, doc_data, key=\"Elog\"):\n        doc_data.prepare_graph(G, key)\n        self.word_data.prepare_graph(G, key)\n\n    def _e_step(self, G, doc_data=None, mean_change_tol=1e-3, max_iters=100):\n        \"\"\"_e_step implements doc data sampling until convergence or max_iters\"\"\"\n        if doc_data is None:\n            doc_data = self._init_doc_data(G.num_nodes(\"doc\"), G.device)\n\n        G_rev = G.reverse()  # word -> doc\n        self.word_data.prepare_graph(G_rev)\n\n        for i in range(max_iters):\n            doc_data.prepare_graph(G_rev)\n            G_rev.update_all(\n                lambda edges: {\"phi\": EdgeData(edges.src, edges.dst).phi},\n                dgl.function.sum(\"phi\", \"nphi\"),\n            )\n            mean_change = doc_data.update_from(G_rev, self.mult[\"doc\"])\n            if mean_change < mean_change_tol:\n                break\n\n        if self.verbose:\n            print(\n                f\"e-step num_iters={i+1} with mean_change={mean_change:.4f}, \"\n                f\"perplexity={self.perplexity(G, doc_data):.4f}\"\n            )\n\n        return doc_data\n\n    transform = _e_step\n\n    def predict(self, doc_data):\n        pred_scores = [\n            # d_exp @ w._expectation()\n            (lambda x: x @ w.nphi + x.sum(1, keepdims=True) * w.prior)(\n                d_exp / w.posterior_sum.unsqueeze(0)\n            )\n            for (d_exp, w) in zip(\n                self.word_data.split_device(doc_data._expectation(), dim=1),\n                self.word_data,\n            )\n        ]\n        x = torch.zeros_like(pred_scores[0], device=doc_data.device)\n        for p in pred_scores:\n            x += p.to(x.device)\n        return x\n\n    def sample(self, doc_data, num_samples):\n        \"\"\"draw independent words and return the marginal probabilities,\n        i.e., the expectations in Dirichlet distributions.\n        \"\"\"\n\n        def fn(cdf):\n     
       u = torch.rand(cdf.shape[0], num_samples, device=cdf.device)\n            return torch.searchsorted(cdf, u).to(doc_data.device)\n\n        topic_ids = fn(doc_data.cdf)\n        word_ids = torch.cat([fn(part.cdf) for part in self.word_data])\n        ids = torch.gather(\n            word_ids, 0, topic_ids\n        )  # pick components by topic_ids\n\n        # compute expectation scores on sampled ids\n        src_ids = (\n            torch.arange(ids.shape[0], dtype=ids.dtype, device=ids.device)\n            .reshape((-1, 1))\n            .expand(ids.shape)\n        )\n        unique_ids, inverse_ids = torch.unique(\n            ids, sorted=False, return_inverse=True\n        )\n\n        G = dgl.heterograph(\n            {(\"doc\", \"\", \"word\"): (src_ids.ravel(), inverse_ids.ravel())}\n        )\n        G.nodes[\"word\"].data[\"_ID\"] = unique_ids\n        self._prepare_graph(G, doc_data, \"expectation\")\n        G.apply_edges(\n            lambda e: {\"expectation\": EdgeData(e.src, e.dst).expectation}\n        )\n        expectation = G.edata.pop(\"expectation\").reshape(ids.shape)\n\n        return ids, expectation\n\n    def _m_step(self, G, doc_data):\n        \"\"\"_m_step implements word data sampling and stores word_z stats.\n        mean_change is in the sense of full graph with rho=1.\n        \"\"\"\n        G = G.clone()\n        self._prepare_graph(G, doc_data)\n        G.update_all(\n            lambda edges: {\"phi\": EdgeData(edges.src, edges.dst).phi},\n            dgl.function.sum(\"phi\", \"nphi\"),\n        )\n        self._last_mean_change = self.word_data.update_from(\n            G, self.mult[\"word\"], self.rho\n        )\n\n        if self.verbose:\n            print(f\"m-step mean_change={self._last_mean_change:.4f}, \", end=\"\")\n            Bayesian_gap = np.mean(\n                [part.Bayesian_gap.mean().tolist() for part in self.word_data]\n            )\n            print(f\"Bayesian_gap={Bayesian_gap:.4f}\")\n\n    def 
partial_fit(self, G):\n        doc_data = self._e_step(G)\n        self._m_step(G, doc_data)\n        return self\n\n    def fit(self, G, mean_change_tol=1e-3, max_epochs=10):\n        for i in range(max_epochs):\n            if self.verbose:\n                print(f\"epoch {i+1}, \", end=\"\")\n            self.partial_fit(G)\n\n            if self._last_mean_change < mean_change_tol:\n                break\n        return self\n\n    def perplexity(self, G, doc_data=None):\n        \"\"\"ppl = exp{-sum[log(p(w1,...,wn|d))] / n}\n        Follows Eq (15) in Hoffman et al., 2010.\n        \"\"\"\n        if doc_data is None:\n            doc_data = self._e_step(G)\n\n        # compute E[log p(docs | theta, beta)]\n        G = G.clone()\n        self._prepare_graph(G, doc_data)\n        G.apply_edges(\n            lambda edges: {\"loglike\": EdgeData(edges.src, edges.dst).loglike}\n        )\n        edge_elbo = (G.edata[\"loglike\"].sum() / G.num_edges()).tolist()\n        if self.verbose:\n            print(f\"neg_elbo phi: {-edge_elbo:.3f}\", end=\" \")\n\n        # compute E[log p(theta | alpha) - log q(theta | gamma)]\n        doc_elbo = (doc_data.loglike.sum() / doc_data.n.sum()).tolist()\n        if self.verbose:\n            print(f\"theta: {-doc_elbo:.3f}\", end=\" \")\n\n        # compute E[log p(beta | eta) - log q(beta | lambda)]\n        # The denominator n for extrapolation perplexity is undefined.\n        # We use the train set, whereas sklearn uses the test set.\n        word_elbo = sum(\n            [part.loglike.sum().tolist() for part in self.word_data]\n        ) / sum([part.n.sum().tolist() for part in self.word_data])\n        if self.verbose:\n            print(f\"beta: {-word_elbo:.3f}\")\n\n        ppl = np.exp(-edge_elbo - doc_elbo - word_elbo)\n        if G.num_edges() > 0 and np.isnan(ppl):\n            warnings.warn(\"numerical issue in perplexity\")\n        return ppl\n\n\ndef doc_subgraph(G, doc_ids):\n    sampler = 
dgl.dataloading.MultiLayerFullNeighborSampler(1)\n    _, _, (block,) = sampler.sample(\n        G.reverse(), {\"doc\": torch.as_tensor(doc_ids)}\n    )\n    B = dgl.DGLGraph(\n        block._graph, [\"_\", \"word\", \"doc\", \"_\"], block.etypes\n    ).reverse()\n    B.nodes[\"word\"].data[\"_ID\"] = block.nodes[\"word\"].data[\"_ID\"]\n    return B\n\n\nif __name__ == \"__main__\":\n    print(\"Testing LatentDirichletAllocation ...\")\n    G = dgl.heterograph(\n        {(\"doc\", \"\", \"word\"): [(0, 0), (1, 3)]}, {\"doc\": 2, \"word\": 5}\n    )\n    model = LatentDirichletAllocation(n_words=5, n_components=10, verbose=False)\n    model.fit(G)\n    model.transform(G)\n    model.predict(model.transform(G))\n    if hasattr(torch, \"searchsorted\"):\n        model.sample(model.transform(G), 3)\n    model.perplexity(G)\n\n    for doc_id in range(2):\n        B = doc_subgraph(G, [doc_id])\n        model.partial_fit(B)\n\n    with io.BytesIO() as f:\n        model.save(f)\n        f.seek(0)\n        print(torch.load(f, weights_only=False))\n\n    print(\"Testing LatentDirichletAllocation passed!\")\n"
  },
  {
    "path": "examples/pytorch/line_graph/README.md",
    "content": "Community Detection with Graph Neural Networks (CDGNN)\n============\n\nPaper link: [https://openreview.net/pdf?id=H1g0Z3A9Fm](https://openreview.net/pdf?id=H1g0Z3A9Fm)\n\nAuthor's code repo: [https://github.com/zhengdao-chen/GNN4CD](https://github.com/zhengdao-chen/GNN4CD)\n\nThis folder contains a DGL implementation of the CDGNN model.\n\nDependencies\n--------------\n* PyTorch 0.4.1+\n* requests\n\n```bash\npip install torch requests\n```\n\nHow to run\n----------\n\nAn experiment on the Stochastic Block Model in default settings can be run with\n\n```bash\npython3 train.py\n```\n\nAn experiment on the Stochastic Block Model in customized settings can be run with\n```bash\npython3 train.py --batch-size BATCH_SIZE --gpu GPU --n-communities N_COMMUNITIES \\\n                --n-features N_FEATURES --n-graphs N_GRAPH --n-iterations N_ITERATIONS \\\n                --n-layers N_LAYER --n-nodes N_NODE --model-path MODEL_PATH --radius RADIUS\n```\n"
  },
  {
    "path": "examples/pytorch/line_graph/gnn.py",
    "content": "import copy\nimport itertools\n\nimport dgl\nimport dgl.function as fn\nimport networkx as nx\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass GNNModule(nn.Module):\n    def __init__(self, in_feats, out_feats, radius):\n        super().__init__()\n        self.out_feats = out_feats\n        self.radius = radius\n\n        new_linear = lambda: nn.Linear(in_feats, out_feats)\n        new_linear_list = lambda: nn.ModuleList(\n            [new_linear() for i in range(radius)]\n        )\n\n        self.theta_x, self.theta_deg, self.theta_y = (\n            new_linear(),\n            new_linear(),\n            new_linear(),\n        )\n        self.theta_list = new_linear_list()\n\n        self.gamma_y, self.gamma_deg, self.gamma_x = (\n            new_linear(),\n            new_linear(),\n            new_linear(),\n        )\n        self.gamma_list = new_linear_list()\n\n        self.bn_x = nn.BatchNorm1d(out_feats)\n        self.bn_y = nn.BatchNorm1d(out_feats)\n\n    def aggregate(self, g, z):\n        z_list = []\n        g.ndata[\"z\"] = z\n        g.update_all(fn.copy_u(u=\"z\", out=\"m\"), fn.sum(msg=\"m\", out=\"z\"))\n        z_list.append(g.ndata[\"z\"])\n        for i in range(self.radius - 1):\n            for j in range(2**i):\n                g.update_all(\n                    fn.copy_u(u=\"z\", out=\"m\"), fn.sum(msg=\"m\", out=\"z\")\n                )\n            z_list.append(g.ndata[\"z\"])\n        return z_list\n\n    def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):\n        pmpd_x = F.embedding(pm_pd, x)\n\n        sum_x = sum(\n            theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))\n        )\n\n        g.edata[\"y\"] = y\n        g.update_all(fn.copy_e(e=\"y\", out=\"m\"), fn.sum(\"m\", \"pmpd_y\"))\n        pmpd_y = g.ndata.pop(\"pmpd_y\")\n\n        x = (\n            self.theta_x(x)\n            + self.theta_deg(deg_g * x)\n            + 
sum_x\n            + self.theta_y(pmpd_y)\n        )\n        n = self.out_feats // 2\n        x = th.cat([x[:, :n], F.relu(x[:, n:])], 1)\n        x = self.bn_x(x)\n\n        sum_y = sum(\n            gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))\n        )\n\n        y = (\n            self.gamma_y(y)\n            + self.gamma_deg(deg_lg * y)\n            + sum_y\n            + self.gamma_x(pmpd_x)\n        )\n        y = th.cat([y[:, :n], F.relu(y[:, n:])], 1)\n        y = self.bn_y(y)\n\n        return x, y\n\n\nclass GNN(nn.Module):\n    def __init__(self, feats, radius, n_classes):\n        super(GNN, self).__init__()\n        self.linear = nn.Linear(feats[-1], n_classes)\n        self.module_list = nn.ModuleList(\n            [GNNModule(m, n, radius) for m, n in zip(feats[:-1], feats[1:])]\n        )\n\n    def forward(self, g, lg, deg_g, deg_lg, pm_pd):\n        x, y = deg_g, deg_lg\n        for module in self.module_list:\n            x, y = module(g, lg, x, y, deg_g, deg_lg, pm_pd)\n        return self.linear(x)\n"
  },
  {
    "path": "examples/pytorch/line_graph/train.py",
    "content": "\"\"\"\nSupervised Community Detection with Hierarchical Graph Neural Networks\nhttps://arxiv.org/abs/1705.08415\n\nAuthor's implementation: https://github.com/joanbruna/GNN_community\n\"\"\"\n\nfrom __future__ import division\n\nimport argparse\nimport time\nfrom itertools import permutations\n\nimport gnn\nimport numpy as np\nimport torch as th\nimport torch.nn.functional as F\nimport torch.optim as optim\n\nfrom dgl.data import SBMMixtureDataset\nfrom torch.utils.data import DataLoader\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--batch-size\", type=int, help=\"Batch size\", default=1)\nparser.add_argument(\"--gpu\", type=int, help=\"GPU index\", default=-1)\nparser.add_argument(\"--lr\", type=float, help=\"Learning rate\", default=0.001)\nparser.add_argument(\n    \"--n-communities\", type=int, help=\"Number of communities\", default=2\n)\nparser.add_argument(\n    \"--n-epochs\", type=int, help=\"Number of epochs\", default=100\n)\nparser.add_argument(\n    \"--n-features\", type=int, help=\"Number of features\", default=16\n)\nparser.add_argument(\"--n-graphs\", type=int, help=\"Number of graphs\", default=10)\nparser.add_argument(\"--n-layers\", type=int, help=\"Number of layers\", default=30)\nparser.add_argument(\n    \"--n-nodes\", type=int, help=\"Number of nodes\", default=10000\n)\nparser.add_argument(\"--optim\", type=str, help=\"Optimizer\", default=\"Adam\")\nparser.add_argument(\"--radius\", type=int, help=\"Radius\", default=3)\nparser.add_argument(\"--verbose\", action=\"store_true\")\nargs = parser.parse_args()\n\ndev = th.device(\"cpu\") if args.gpu < 0 else th.device(\"cuda:%d\" % args.gpu)\nK = args.n_communities\n\ntraining_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)\ntraining_loader = DataLoader(\n    training_dataset,\n    args.batch_size,\n    collate_fn=training_dataset.collate_fn,\n    drop_last=True,\n)\n\nones = th.ones(args.n_nodes // K)\ny_list = [\n    th.cat([x * ones for x in 
p]).long().to(dev) for p in permutations(range(K))\n]\n\nfeats = [1] + [args.n_features] * args.n_layers + [K]\nmodel = gnn.GNN(feats, args.radius, K).to(dev)\noptimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)\n\n\ndef compute_overlap(z_list):\n    ybar_list = [th.max(z, 1)[1] for z in z_list]\n    overlap_list = []\n    for y_bar in ybar_list:\n        accuracy = max(th.sum(y_bar == y).item() for y in y_list) / args.n_nodes\n        overlap = (accuracy - 1 / K) / (1 - 1 / K)\n        overlap_list.append(overlap)\n    return sum(overlap_list) / len(overlap_list)\n\n\ndef from_np(f, *args):\n    def wrap(*args):\n        new = [\n            th.from_numpy(x) if isinstance(x, np.ndarray) else x for x in args\n        ]\n        return f(*new)\n\n    return wrap\n\n\n@from_np\ndef step(i, j, g, lg, deg_g, deg_lg, pm_pd):\n    \"\"\"One step of training.\"\"\"\n    g = g.to(dev)\n    lg = lg.to(dev)\n    deg_g = deg_g.to(dev).unsqueeze(1)\n    deg_lg = deg_lg.to(dev).unsqueeze(1)\n    pm_pd = pm_pd.to(dev)\n    t0 = time.time()\n    z = model(g, lg, deg_g, deg_lg, pm_pd)\n    t_forward = time.time() - t0\n\n    z_list = th.chunk(z, args.batch_size, 0)\n    loss = (\n        sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list)\n        / args.batch_size\n    )\n    overlap = compute_overlap(z_list)\n\n    optimizer.zero_grad()\n    t0 = time.time()\n    loss.backward()\n    t_backward = time.time() - t0\n    optimizer.step()\n\n    return loss, overlap, t_forward, t_backward\n\n\n@from_np\ndef inference(g, lg, deg_g, deg_lg, pm_pd):\n    g = g.to(dev)\n    lg = lg.to(dev)\n    deg_g = deg_g.to(dev).unsqueeze(1)\n    deg_lg = deg_lg.to(dev).unsqueeze(1)\n    pm_pd = pm_pd.to(dev)\n\n    z = model(g, lg, deg_g, deg_lg, pm_pd)\n\n    return z\n\n\ndef test():\n    p_list = [6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]\n    q_list = [0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]\n    N = 1\n    overlap_list = []\n    for p, q in zip(p_list, q_list):\n        dataset = 
SBMMixtureDataset(N, args.n_nodes, K, pq=[[p, q]] * N)\n        loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn)\n        g, lg, deg_g, deg_lg, pm_pd = next(iter(loader))\n        z = inference(g, lg, deg_g, deg_lg, pm_pd)\n        overlap_list.append(compute_overlap(th.chunk(z, N, 0)))\n    return overlap_list\n\n\nn_iterations = args.n_graphs // args.batch_size\nfor i in range(args.n_epochs):\n    total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0\n    for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):\n        loss, overlap, t_forward, t_backward = step(\n            i, j, g, lg, deg_g, deg_lg, pm_pd\n        )\n\n        total_loss += loss\n        total_overlap += overlap\n        s_forward += t_forward\n        s_backward += t_backward\n\n        epoch = \"0\" * (len(str(args.n_epochs)) - len(str(i)))\n        iteration = \"0\" * (len(str(n_iterations)) - len(str(j)))\n        if args.verbose:\n            print(\n                \"[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f\"\n                % (epoch, i, iteration, j, loss, overlap)\n            )\n\n    epoch = \"0\" * (len(str(args.n_epochs)) - len(str(i)))\n    loss = total_loss / (j + 1)\n    overlap = total_overlap / (j + 1)\n    t_forward = s_forward / (j + 1)\n    t_backward = s_backward / (j + 1)\n    print(\n        \"[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs\"\n        % (epoch, i, loss, overlap, t_forward, t_backward)\n    )\n\n    overlap_list = test()\n    overlap_str = \" - \".join([\"%.3f\" % overlap for overlap in overlap_list])\n    print(\"[epoch %s%d]overlap: %s\" % (epoch, i, overlap_str))\n"
  },
  {
    "path": "examples/pytorch/metapath2vec/README.md",
    "content": "Metapath2vec\n============\n\n- Paper link: [metapath2vec: Scalable Representation Learning for Heterogeneous Networks](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf)\n- Author's code repo: [https://ericdongyx.github.io/metapath2vec/m2v.html](https://ericdongyx.github.io/metapath2vec/m2v.html). \n\nDependencies\n------------\n- PyTorch 1.0.1+\n\nHow to run the code\n-----\nRun with either of the following procedures:\n\n* Running with default AMiner dataset:\n  1. Directly run the following command:\n\n     ```bash\n     python metapath2vec.py --aminer --path \"where/you/want/to/download\" --output_file \"your_model_output_path\"\n     ```\n* Running with another AMiner-like dataset\n  1. Prepare the data in the same format as the ones of AMiner and DBIS in Section B of [Author's code repo](https://ericdongyx.github.io/metapath2vec/m2v.html).\n  2. Run `sampler.py` on your graph dataset with, for instance,\n\n     ```bash\n     python sampler.py net_dbis\n     ```\n  3. Run the following command:\n\n     ```bash\n     python metapath2vec.py --path net_dbis/output_path.txt --output_file \"your_model_output_path\"\n     ```\n\nTips: Change num_workers based on your GPU instances; Running 3 or 4 epochs is actually enough. 
\n\nTricks included in the implementation:\n-------\n1, Sub-sampling;\n\n2, Negative Sampling without repeatedly calling numpy random choices;\n\nPerformance and Explanations:\n-------\nVenue Classification Results for Metapath2vec:\n\n| Metric | 5% | 10% | 20% | 30% | 40% | 50% | 60% | 70% | 80% | 90% |\n| ------ | -- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| Macro-F1 | 0.3033 | 0.5247 | 0.8033 | 0.8971 | 0.9406 | 0.9532 | 0.9529 | 0.9701 | 0.9683 | 0.9670 |\n| Micro-F1 | 0.4173 | 0.5975 | 0.8327 | 0.9011 | 0.9400 | 0.9522 | 0.9537 | 0.9725 | 0.9815 | 0.9857 |\n\nAuthor Classification Results for Metapath2vec:\n\n| Metric | 5% | 10% | 20% | 30% | 40% | 50% | 60% | 70% | 80% | 90% |\n| ------ | -- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| Macro-F1 | 0.9216 | 0.9262 | 0.9292 | 0.9303 | 0.9309 | 0.9314 | 0.9315 | 0.9316 | 0.9319 | 0.9320 |\n| Micro-F1 | 0.9279 | 0.9319 | 0.9346 | 0.9356 | 0.9361 | 0.9365 | 0.9365 | 0.9365 | 0.9367 | 0.9369 |\n\nNote that: \n\nTesting files are available in \"label 2\" file;\n\nThe above are results listed in the paper, in real experiments, exact numbers might be slightly different:\n\n1, For venue node classification results, when the size of the training dataset is small (e.g. 5%), the variance of the performance is large since the number of available labeled venues is small. \n\n2, For author node classification results, the performance is stable since the number of available labeled authors is huge, so even 5% training data would be sufficient.\n\n3, In test.py, you can change the number of experiment runs; testing author classification is very slow, so you may want to run it only 1 or 2 times.\n"
  },
  {
    "path": "examples/pytorch/metapath2vec/download.py",
    "content": "import os\n\nimport torch as th\nimport torch.nn as nn\nimport tqdm\n\n\nclass PBar(object):\n    def __enter__(self):\n        self.t = None\n        return self\n\n    def __call__(self, blockno, readsize, totalsize):\n        if self.t is None:\n            self.t = tqdm.tqdm(total=totalsize)\n        self.t.update(readsize)\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.t.close()\n\n\nclass AminerDataset(object):\n    \"\"\"\n    Download Aminer Dataset from Amazon S3 bucket.\n    \"\"\"\n\n    def __init__(self, path):\n        self.url = \"https://data.dgl.ai/dataset/aminer.zip\"\n\n        if not os.path.exists(os.path.join(path, \"aminer.txt\")):\n            print(\"File not found. Downloading from\", self.url)\n            self._download_and_extract(path, \"aminer.zip\")\n        self.fn = os.path.join(path, \"aminer.txt\")\n\n    def _download_and_extract(self, path, filename):\n        import shutil, zipfile, zlib\n        import urllib.request\n\n        from tqdm import tqdm\n\n        fn = os.path.join(path, filename)\n        with PBar() as pb:\n            urllib.request.urlretrieve(self.url, fn, pb)\n        print(\"Download finished. Unzipping the file...\")\n\n        with zipfile.ZipFile(fn) as zf:\n            zf.extractall(path)\n        print(\"Unzip finished.\")\n\n\nclass CustomDataset(object):\n    \"\"\"\n    Custom dataset generated by sampler.py (e.g. NetDBIS)\n    \"\"\"\n\n    def __init__(self, path):\n        self.fn = path\n"
  },
  {
    "path": "examples/pytorch/metapath2vec/metapath2vec.py",
    "content": "import argparse\n\nimport torch\nimport torch.optim as optim\nfrom download import AminerDataset, CustomDataset\nfrom model import SkipGramModel\n\nfrom reading_data import DataReader, Metapath2vecDataset\nfrom torch.utils.data import DataLoader\n\nfrom tqdm import tqdm\n\n\nclass Metapath2VecTrainer:\n    def __init__(self, args):\n        if args.aminer:\n            dataset = AminerDataset(args.path)\n        else:\n            dataset = CustomDataset(args.path)\n        self.data = DataReader(dataset, args.min_count, args.care_type)\n        dataset = Metapath2vecDataset(self.data, args.window_size)\n        self.dataloader = DataLoader(\n            dataset,\n            batch_size=args.batch_size,\n            shuffle=True,\n            num_workers=args.num_workers,\n            collate_fn=dataset.collate,\n        )\n\n        self.output_file_name = args.output_file\n        self.emb_size = len(self.data.word2id)\n        self.emb_dimension = args.dim\n        self.batch_size = args.batch_size\n        self.iterations = args.iterations\n        self.initial_lr = args.initial_lr\n        self.skip_gram_model = SkipGramModel(self.emb_size, self.emb_dimension)\n\n        self.use_cuda = torch.cuda.is_available()\n        self.device = torch.device(\"cuda\" if self.use_cuda else \"cpu\")\n        if self.use_cuda:\n            self.skip_gram_model.cuda()\n\n    def train(self):\n        optimizer = optim.SparseAdam(\n            list(self.skip_gram_model.parameters()), lr=self.initial_lr\n        )\n        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n            optimizer, len(self.dataloader)\n        )\n\n        for iteration in range(self.iterations):\n            print(\"\\n\\n\\nIteration: \" + str(iteration + 1))\n            running_loss = 0.0\n            for i, sample_batched in enumerate(tqdm(self.dataloader)):\n                if len(sample_batched[0]) > 1:\n                    pos_u = sample_batched[0].to(self.device)\n 
                   pos_v = sample_batched[1].to(self.device)\n                    neg_v = sample_batched[2].to(self.device)\n\n                    scheduler.step()\n                    optimizer.zero_grad()\n                    loss = self.skip_gram_model.forward(pos_u, pos_v, neg_v)\n                    loss.backward()\n                    optimizer.step()\n\n                    running_loss = running_loss * 0.9 + loss.item() * 0.1\n                    if i > 0 and i % 500 == 0:\n                        print(\" Loss: \" + str(running_loss))\n\n        self.skip_gram_model.save_embedding(\n            self.data.id2word, self.output_file_name\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Metapath2vec\")\n    # parser.add_argument('--input_file', type=str, help=\"input_file\")\n    parser.add_argument(\n        \"--aminer\", action=\"store_true\", help=\"Use AMiner dataset\"\n    )\n    parser.add_argument(\"--path\", type=str, help=\"input_path\")\n    parser.add_argument(\"--output_file\", type=str, help=\"output_file\")\n    parser.add_argument(\n        \"--dim\", default=128, type=int, help=\"embedding dimensions\"\n    )\n    parser.add_argument(\n        \"--window_size\", default=7, type=int, help=\"context window size\"\n    )\n    parser.add_argument(\"--iterations\", default=5, type=int, help=\"iterations\")\n    parser.add_argument(\"--batch_size\", default=50, type=int, help=\"batch size\")\n    parser.add_argument(\n        \"--care_type\",\n        default=0,\n        type=int,\n        help=\"if 1, heterogeneous negative sampling, else normal negative sampling\",\n    )\n    parser.add_argument(\n        \"--initial_lr\", default=0.025, type=float, help=\"learning rate\"\n    )\n    parser.add_argument(\"--min_count\", default=5, type=int, help=\"min count\")\n    parser.add_argument(\n        \"--num_workers\", default=16, type=int, help=\"number of workers\"\n    )\n    args = parser.parse_args()\n 
   m2v = Metapath2VecTrainer(args)\n    m2v.train()\n"
  },
  {
    "path": "examples/pytorch/metapath2vec/model.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n\n\"\"\"\n    u_embedding: Embedding for center word.\n    v_embedding: Embedding for neighbor words.\n\"\"\"\n\n\nclass SkipGramModel(nn.Module):\n    def __init__(self, emb_size, emb_dimension):\n        super(SkipGramModel, self).__init__()\n        self.emb_size = emb_size\n        self.emb_dimension = emb_dimension\n        self.u_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)\n        self.v_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)\n\n        initrange = 1.0 / self.emb_dimension\n        init.uniform_(self.u_embeddings.weight.data, -initrange, initrange)\n        init.constant_(self.v_embeddings.weight.data, 0)\n\n    def forward(self, pos_u, pos_v, neg_v):\n        emb_u = self.u_embeddings(pos_u)\n        emb_v = self.v_embeddings(pos_v)\n        emb_neg_v = self.v_embeddings(neg_v)\n\n        score = torch.sum(torch.mul(emb_u, emb_v), dim=1)\n        score = torch.clamp(score, max=10, min=-10)\n        score = -F.logsigmoid(score)\n\n        neg_score = torch.bmm(emb_neg_v, emb_u.unsqueeze(2)).squeeze()\n        neg_score = torch.clamp(neg_score, max=10, min=-10)\n        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)\n\n        return torch.mean(score + neg_score)\n\n    def save_embedding(self, id2word, file_name):\n        embedding = self.u_embeddings.weight.cpu().data.numpy()\n        with open(file_name, \"w\") as f:\n            f.write(\"%d %d\\n\" % (len(id2word), self.emb_dimension))\n            for wid, w in id2word.items():\n                e = \" \".join(map(lambda x: str(x), embedding[wid]))\n                f.write(\"%s %s\\n\" % (w, e))\n"
  },
  {
    "path": "examples/pytorch/metapath2vec/reading_data.py",
    "content": "import numpy as np\nimport torch\nfrom download import AminerDataset\nfrom torch.utils.data import Dataset\n\nnp.random.seed(12345)\n\n\nclass DataReader:\n    NEGATIVE_TABLE_SIZE = 1e8\n\n    def __init__(self, dataset, min_count, care_type):\n        self.negatives = []\n        self.discards = []\n        self.negpos = 0\n        self.care_type = care_type\n        self.word2id = dict()\n        self.id2word = dict()\n        self.sentences_count = 0\n        self.token_count = 0\n        self.word_frequency = dict()\n        self.inputFileName = dataset.fn\n        self.read_words(min_count)\n        self.initTableNegatives()\n        self.initTableDiscards()\n\n    def read_words(self, min_count):\n        word_frequency = dict()\n        for line in open(self.inputFileName, encoding=\"ISO-8859-1\"):\n            line = line.split()\n            if len(line) > 1:\n                self.sentences_count += 1\n                for word in line:\n                    if len(word) > 0:\n                        self.token_count += 1\n                        word_frequency[word] = word_frequency.get(word, 0) + 1\n\n                        if self.token_count % 1000000 == 0:\n                            print(\n                                \"Read \"\n                                + str(int(self.token_count / 1000000))\n                                + \"M words.\"\n                            )\n\n        wid = 0\n        for w, c in word_frequency.items():\n            if c < min_count:\n                continue\n            self.word2id[w] = wid\n            self.id2word[wid] = w\n            self.word_frequency[wid] = c\n            wid += 1\n\n        self.word_count = len(self.word2id)\n        print(\"Total embeddings: \" + str(len(self.word2id)))\n\n    def initTableDiscards(self):\n        # get a frequency table for sub-sampling. 
Note that the frequency is adjusted by\n        # sub-sampling tricks.\n        t = 0.0001\n        f = np.array(list(self.word_frequency.values())) / self.token_count\n        self.discards = np.sqrt(t / f) + (t / f)\n\n    def initTableNegatives(self):\n        # get a table for negative sampling, if word with index 2 appears twice, then 2 will be listed\n        # in the table twice.\n        pow_frequency = np.array(list(self.word_frequency.values())) ** 0.75\n        words_pow = sum(pow_frequency)\n        ratio = pow_frequency / words_pow\n        count = np.round(ratio * DataReader.NEGATIVE_TABLE_SIZE)\n        for wid, c in enumerate(count):\n            self.negatives += [wid] * int(c)\n        self.negatives = np.array(self.negatives)\n        np.random.shuffle(self.negatives)\n        self.sampling_prob = ratio\n\n    def getNegatives(self, target, size):  # TODO check equality with target\n        if self.care_type == 0:\n            response = self.negatives[self.negpos : self.negpos + size]\n            self.negpos = (self.negpos + size) % len(self.negatives)\n            if len(response) != size:\n                return np.concatenate(\n                    (response, self.negatives[0 : self.negpos])\n                )\n        return response\n\n\n# -----------------------------------------------------------------------------------------------------------------\n\n\nclass Metapath2vecDataset(Dataset):\n    def __init__(self, data, window_size):\n        # read in data, window_size and input filename\n        self.data = data\n        self.window_size = window_size\n        self.input_file = open(data.inputFileName, encoding=\"ISO-8859-1\")\n\n    def __len__(self):\n        # return the number of walks\n        return self.data.sentences_count\n\n    def __getitem__(self, idx):\n        # return the list of pairs (center, context, 5 negatives)\n        while True:\n            line = self.input_file.readline()\n            if not line:\n              
  self.input_file.seek(0, 0)\n                line = self.input_file.readline()\n\n            if len(line) > 1:\n                words = line.split()\n\n                if len(words) > 1:\n                    word_ids = [\n                        self.data.word2id[w]\n                        for w in words\n                        if w in self.data.word2id\n                        and np.random.rand()\n                        < self.data.discards[self.data.word2id[w]]\n                    ]\n\n                    pair_catch = []\n                    for i, u in enumerate(word_ids):\n                        for j, v in enumerate(\n                            word_ids[\n                                max(i - self.window_size, 0) : i\n                                + self.window_size\n                            ]\n                        ):\n                            assert u < self.data.word_count\n                            assert v < self.data.word_count\n                            if i == j:\n                                continue\n                            pair_catch.append(\n                                (u, v, self.data.getNegatives(v, 5))\n                            )\n                    return pair_catch\n\n    @staticmethod\n    def collate(batches):\n        all_u = [u for batch in batches for u, _, _ in batch if len(batch) > 0]\n        all_v = [v for batch in batches for _, v, _ in batch if len(batch) > 0]\n        all_neg_v = [\n            neg_v\n            for batch in batches\n            for _, _, neg_v in batch\n            if len(batch) > 0\n        ]\n\n        return (\n            torch.LongTensor(all_u),\n            torch.LongTensor(all_v),\n            torch.LongTensor(all_neg_v),\n        )\n"
  },
  {
    "path": "examples/pytorch/metapath2vec/sampler.py",
    "content": "import os\nimport random\nimport sys\nimport time\n\nimport dgl\nimport numpy as np\nimport tqdm\n\nnum_walks_per_node = 1000\nwalk_length = 100\npath = sys.argv[1]\n\n\ndef construct_graph():\n    paper_ids = []\n    paper_names = []\n    author_ids = []\n    author_names = []\n    conf_ids = []\n    conf_names = []\n    f_3 = open(os.path.join(path, \"id_author.txt\"), encoding=\"ISO-8859-1\")\n    f_4 = open(os.path.join(path, \"id_conf.txt\"), encoding=\"ISO-8859-1\")\n    f_5 = open(os.path.join(path, \"paper.txt\"), encoding=\"ISO-8859-1\")\n    while True:\n        z = f_3.readline()\n        if not z:\n            break\n        z = z.strip().split()\n        identity = int(z[0])\n        author_ids.append(identity)\n        author_names.append(z[1])\n    while True:\n        w = f_4.readline()\n        if not w:\n            break\n        w = w.strip().split()\n        identity = int(w[0])\n        conf_ids.append(identity)\n        conf_names.append(w[1])\n    while True:\n        v = f_5.readline()\n        if not v:\n            break\n        v = v.strip().split()\n        identity = int(v[0])\n        paper_name = \"p\" + \"\".join(v[1:])\n        paper_ids.append(identity)\n        paper_names.append(paper_name)\n    f_3.close()\n    f_4.close()\n    f_5.close()\n\n    author_ids_invmap = {x: i for i, x in enumerate(author_ids)}\n    conf_ids_invmap = {x: i for i, x in enumerate(conf_ids)}\n    paper_ids_invmap = {x: i for i, x in enumerate(paper_ids)}\n\n    paper_author_src = []\n    paper_author_dst = []\n    paper_conf_src = []\n    paper_conf_dst = []\n    f_1 = open(os.path.join(path, \"paper_author.txt\"), \"r\")\n    f_2 = open(os.path.join(path, \"paper_conf.txt\"), \"r\")\n    for x in f_1:\n        x = x.split(\"\\t\")\n        x[0] = int(x[0])\n        x[1] = int(x[1].strip(\"\\n\"))\n        paper_author_src.append(paper_ids_invmap[x[0]])\n        paper_author_dst.append(author_ids_invmap[x[1]])\n    for y in f_2:\n      
  y = y.split(\"\\t\")\n        y[0] = int(y[0])\n        y[1] = int(y[1].strip(\"\\n\"))\n        paper_conf_src.append(paper_ids_invmap[y[0]])\n        paper_conf_dst.append(conf_ids_invmap[y[1]])\n    f_1.close()\n    f_2.close()\n\n    hg = dgl.heterograph(\n        {\n            (\"paper\", \"pa\", \"author\"): (paper_author_src, paper_author_dst),\n            (\"author\", \"ap\", \"paper\"): (paper_author_dst, paper_author_src),\n            (\"paper\", \"pc\", \"conf\"): (paper_conf_src, paper_conf_dst),\n            (\"conf\", \"cp\", \"paper\"): (paper_conf_dst, paper_conf_src),\n        }\n    )\n    return hg, author_names, conf_names, paper_names\n\n\n# \"conference - paper - Author - paper - conference\" metapath sampling\ndef generate_metapath():\n    output_path = open(os.path.join(path, \"output_path.txt\"), \"w\")\n    count = 0\n\n    hg, author_names, conf_names, paper_names = construct_graph()\n\n    for conf_idx in tqdm.trange(hg.num_nodes(\"conf\")):\n        traces, _ = dgl.sampling.random_walk(\n            hg,\n            [conf_idx] * num_walks_per_node,\n            metapath=[\"cp\", \"pa\", \"ap\", \"pc\"] * walk_length,\n        )\n        for tr in traces:\n            outline = \" \".join(\n                (conf_names if i % 4 == 0 else author_names)[tr[i]]\n                for i in range(0, len(tr), 2)\n            )  # skip paper\n            print(outline, file=output_path)\n    output_path.close()\n\n\nif __name__ == \"__main__\":\n    generate_metapath()\n"
  },
  {
    "path": "examples/pytorch/metapath2vec/test.py",
    "content": "import numpy as np\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.metrics import f1_score\n\n\nif __name__ == \"__main__\":\n    venue_count = 133\n    author_count = 246678\n    experiment_times = 1\n    percent = 0.05\n    file = open(\".../output_file_path/...\")\n    file_1 = open(\".../label 2/googlescholar.8area.venue.label.txt\")\n    file_2 = open(\".../label 2/googlescholar.8area.author.label.txt\")\n    check_venue = {}\n    check_author = {}\n    for line in file_1:\n        venue_label = line.strip().split(\" \")\n        check_venue[venue_label[0]] = int(venue_label[1])\n    for line in file_2:\n        author_label = line.strip().split(\" \")\n        check_author[author_label[0]] = int(author_label[1])\n    venue_embed_dict = {}\n    author_embed_dict = {}\n    # collect embeddings separately in dictionary form\n    file.readline()\n    print(\"read line by line\")\n    for line in file:\n        embed = line.strip().split(\" \")\n        if embed[0] in check_venue:\n            venue_embed_dict[embed[0]] = []\n            for i in range(1, len(embed), 1):\n                venue_embed_dict[embed[0]].append(float(embed[i]))\n        if embed[0] in check_author:\n            author_embed_dict[embed[0]] = []\n            for j in range(1, len(embed), 1):\n                author_embed_dict[embed[0]].append(float(embed[j]))\n    # get venue embeddings\n    print(\"reading finished\")\n    venues = list(venue_embed_dict.keys())\n    authors = list(author_embed_dict.keys())\n    macro_average_venue = 0\n    micro_average_venue = 0\n    macro_average_author = 0\n    micro_average_author = 0\n    for time in range(experiment_times):\n        print(\"one more time\")\n        np.random.shuffle(venues)\n        np.random.shuffle(authors)\n        venue_embedding = np.array([])\n        author_embedding = np.array([])\n        print(\"collecting venue embeddings\")\n        for venue in venues:\n            temp = 
np.array(venue_embed_dict[venue])\n            if len(venue_embedding) == 0:\n                venue_embedding = temp\n            else:\n                venue_embedding = np.vstack((venue_embedding, temp))\n        print(\"collecting author embeddings\")\n        count = 0\n        for author in authors:\n            count += 1\n            # print(\"one more author \" + str(count))\n            temp_1 = np.array(author_embed_dict[author])\n            if len(author_embedding) == 0:\n                author_embedding = temp_1\n            else:\n                author_embedding = np.vstack((author_embedding, temp_1))\n        # split data into training and testing\n        print(\"splitting\")\n        venue_split = int(venue_count * percent)\n        venue_training = venue_embedding[:venue_split, :]\n        venue_testing = venue_embedding[venue_split:, :]\n        author_split = int(author_count * percent)\n        author_training = author_embedding[:author_split, :]\n        author_testing = author_embedding[author_split:, :]\n        # split label into training and testing\n        venue_label = []\n        venue_true = []\n        author_label = []\n        author_true = []\n        for i in range(len(venues)):\n            if i < venue_split:\n                venue_label.append(check_venue[venues[i]])\n            else:\n                venue_true.append(check_venue[venues[i]])\n        venue_label = np.array(venue_label)\n        venue_true = np.array(venue_true)\n        for j in range(len(authors)):\n            if j < author_split:\n                author_label.append(check_author[authors[j]])\n            else:\n                author_true.append(check_author[authors[j]])\n        author_label = np.array(author_label)\n        author_true = np.array(author_true)\n        file.close()\n        print(\"beging predicting\")\n        clf_venue = LogisticRegression(\n            random_state=0, solver=\"lbfgs\", multi_class=\"multinomial\"\n        
).fit(venue_training, venue_label)\n        y_pred_venue = clf_venue.predict(venue_testing)\n        clf_author = LogisticRegression(\n            random_state=0, solver=\"lbfgs\", multi_class=\"multinomial\"\n        ).fit(author_training, author_label)\n        y_pred_author = clf_author.predict(author_testing)\n        macro_average_venue += f1_score(\n            venue_true, y_pred_venue, average=\"macro\"\n        )\n        micro_average_venue += f1_score(\n            venue_true, y_pred_venue, average=\"micro\"\n        )\n        macro_average_author += f1_score(\n            author_true, y_pred_author, average=\"macro\"\n        )\n        micro_average_author += f1_score(\n            author_true, y_pred_author, average=\"micro\"\n        )\n    print(macro_average_venue / float(experiment_times))\n    print(micro_average_venue / float(experiment_times))\n    print(macro_average_author / float(experiment_times))\n    print(micro_average_author / float(experiment_times))\n"
  },
  {
    "path": "examples/pytorch/mixhop/README.md",
    "content": "# DGL Implementations of MixHop\n\nThis DGL example implements the GNN model proposed in the paper [MixHop: Higher-Order Graph Convolution Architectures via Sparsified Neighborhood Mixing](https://arxiv.org/abs/1905.00067). For the original implementation, see [here](https://github.com/samihaija/mixhop).\n\nContributor: [xnuohz](https://github.com/xnuohz)\n\n### Requirements\nThe codebase is implemented in Python 3.6. For version requirement of packages, see below.\n\n```\ndgl 0.5.2\nnumpy 1.19.4\npandas 1.1.4\ntqdm 4.53.0\ntorch 1.7.0\n```\n\n### The graph datasets used in this example\n\nThe DGL's built-in Cora, Pubmed and Citeseer datasets. Dataset summary:\n\n| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |\n| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |\n| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1000 |\n| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1000 |\n| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1000 |\n\n### Usage\n\n###### Dataset options\n```\n--dataset          str     The graph dataset name.             Default is 'Cora'.\n```\n\n###### GPU options\n```\n--gpu              int     GPU index.                          Default is -1, using CPU.\n```\n\n###### Model options\n```\n--epochs           int     Number of training epochs.          Default is 2000.\n--early-stopping   int     Early stopping rounds.              Default is 200.\n--lr               float   Adam optimizer learning rate.       Default is 0.5.\n--lamb             float   L2 regularization coefficient.      Default is 0.0005.\n--step-size        int     Period of learning rate decay.      Default is 40.\n--gamma            float   Factor of learning rate decay.      Default is 0.01.\n--hid-dim          int     Hidden layer dimensionalities.      Default is 60.\n--num-layers       int     Number of GNN layers.               Default is 4.\n--input-dropout    float   Dropout applied at input layer.     
Default is 0.7.\n--layer-dropout    float   Dropout applied at hidden layers.   Default is 0.9.\n--p                list    List of powers of adjacency matrix. Default is [0, 1, 2].\n```\n\n###### Examples\n\nThe following commands learn a neural network and predict on the test set.\nTraining a MixHop model on the default dataset.\n```bash\npython main.py\n```\nTrain a model for 200 epochs and perform an early stop if the validation accuracy stops getting improved for 10 epochs.\n```bash\npython main.py --epochs 200 --early-stopping 10\n```\nTrain a model with a different learning rate and regularization coefficient.\n```bash\npython main.py --lr 0.001 --lamb 0.1\n```\nTrain a model with different model hyperparameters.\n```bash\npython main.py --num-layers 6 --p 2 4 6\n```\nTrain a model which follows the original hyperparameters on different datasets.\n```bash\n# Cora:\npython main.py --gpu 0 --dataset Cora --lr 1 --input-dropout 0.6 --lamb 5e-3 --hid-dim 100 --num-layers 3\n\n# Citeseer:\npython main.py --gpu 0 --dataset Citeseer --lr 0.25 --input-dropout 0.5 --lamb 5e-3 --hid-dim 60 --num-layers 3\n\n# Pubmed:\npython main.py --gpu 0 --dataset Pubmed --lr 0.5 --input-dropout 0.7 --lamb 5e-3 --hid-dim 60 --num-layers 3\n```\n\n### Performance\n\n| Dataset | Cora | Pubmed | Citeseer |\n| :-: | :-: | :-: | :-: |\n| Accuracy(MixHop: default architecture in Table 1) | 0.818 | 0.800 | 0.714 |\n| Accuracy(official code) | 0.610(0.156) | 0.746(0.065) | 0.700(0.017) |\n| Accuracy(DGL) | 0.801(0.005) | 0.780(0.005) | 0.692(0.005) |"
  },
  {
    "path": "examples/pytorch/mixhop/main.py",
    "content": "\"\"\" The main file to train a MixHop model using a full graph \"\"\"\n\nimport argparse\nimport copy\nimport random\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\nfrom tqdm import trange\n\n\nclass MixHopConv(nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    MixHop Graph Convolutional layer from paper `MixHop: Higher-Order Graph Convolutional Architecturesvia Sparsified Neighborhood Mixing\n     <https://arxiv.org/pdf/1905.00067.pdf>`__.\n\n    .. math::\n        H^{(i+1)} =\\underset{j \\in P}{\\Bigg\\Vert} \\sigma\\left(\\widehat{A}^j H^{(i)} W_j^{(i)}\\right),\n\n    where :math:`\\widehat{A}` denotes the symmetrically normalized adjacencymatrix with self-connections,\n    :math:`D_{ii} = \\sum_{j=0} \\widehat{A}_{ij}` its diagonal degree matrix,\n    :math:`W_j^{(i)}` denotes the trainable weight matrix of different MixHop layers.\n\n    Parameters\n    ----------\n    in_dim : int\n        Input feature size. i.e, the number of dimensions of :math:`H^{(i)}`.\n    out_dim : int\n        Output feature size for each power.\n    p: list\n        List of powers of adjacency matrix. Defaults: ``[0, 1, 2]``.\n    dropout: float, optional\n        Dropout rate on node features. Defaults: ``0``.\n    activation: callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n    batchnorm: bool, optional\n        If True, use batch normalization. 
Defaults: ``False``.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        p=[0, 1, 2],\n        dropout=0,\n        activation=None,\n        batchnorm=False,\n    ):\n        super(MixHopConv, self).__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.p = p\n        self.activation = activation\n        self.batchnorm = batchnorm\n\n        # define dropout layer\n        self.dropout = nn.Dropout(dropout)\n\n        # define batch norm layer\n        if self.batchnorm:\n            self.bn = nn.BatchNorm1d(out_dim * len(p))\n\n        # define weight dict for each power j\n        self.weights = nn.ModuleDict(\n            {str(j): nn.Linear(in_dim, out_dim, bias=False) for j in p}\n        )\n\n    def forward(self, graph, feats):\n        with graph.local_scope():\n            # assume that the graphs are undirected and graph.in_degrees() is the same as graph.out_degrees()\n            degs = graph.in_degrees().float().clamp(min=1)\n            norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)\n            max_j = max(self.p) + 1\n            outputs = []\n            for j in range(max_j):\n                if j in self.p:\n                    output = self.weights[str(j)](feats)\n                    outputs.append(output)\n\n                feats = feats * norm\n                graph.ndata[\"h\"] = feats\n                graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n                feats = graph.ndata.pop(\"h\")\n                feats = feats * norm\n\n            final = torch.cat(outputs, dim=1)\n\n            if self.batchnorm:\n                final = self.bn(final)\n\n            if self.activation is not None:\n                final = self.activation(final)\n\n            final = self.dropout(final)\n\n            return final\n\n\nclass MixHop(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        hid_dim,\n        out_dim,\n        
num_layers=2,\n        p=[0, 1, 2],\n        input_dropout=0.0,\n        layer_dropout=0.0,\n        activation=None,\n        batchnorm=False,\n    ):\n        super(MixHop, self).__init__()\n        self.in_dim = in_dim\n        self.hid_dim = hid_dim\n        self.out_dim = out_dim\n        self.num_layers = num_layers\n        self.p = p\n        self.input_dropout = input_dropout\n        self.layer_dropout = layer_dropout\n        self.activation = activation\n        self.batchnorm = batchnorm\n\n        self.layers = nn.ModuleList()\n        self.dropout = nn.Dropout(self.input_dropout)\n\n        # Input layer\n        self.layers.append(\n            MixHopConv(\n                self.in_dim,\n                self.hid_dim,\n                p=self.p,\n                dropout=self.input_dropout,\n                activation=self.activation,\n                batchnorm=self.batchnorm,\n            )\n        )\n\n        # Hidden layers with n - 1 MixHopConv layers\n        for i in range(self.num_layers - 2):\n            self.layers.append(\n                MixHopConv(\n                    self.hid_dim * len(args.p),\n                    self.hid_dim,\n                    p=self.p,\n                    dropout=self.layer_dropout,\n                    activation=self.activation,\n                    batchnorm=self.batchnorm,\n                )\n            )\n\n        self.fc_layers = nn.Linear(\n            self.hid_dim * len(args.p), self.out_dim, bias=False\n        )\n\n    def forward(self, graph, feats):\n        feats = self.dropout(feats)\n        for layer in self.layers:\n            feats = layer(graph, feats)\n\n        feats = self.fc_layers(feats)\n\n        return feats\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    # Load from DGL dataset\n    if args.dataset == \"Cora\":\n        dataset = CoraGraphDataset()\n    elif args.dataset == \"Citeseer\":\n        
dataset = CiteseerGraphDataset()\n    elif args.dataset == \"Pubmed\":\n        dataset = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Dataset {} is invalid.\".format(args.dataset))\n\n    graph = dataset[0]\n    graph = dgl.add_self_loop(graph)\n\n    # check cuda\n    if args.gpu >= 0 and torch.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu)\n    else:\n        device = \"cpu\"\n\n    # retrieve the number of classes\n    n_classes = dataset.num_classes\n\n    # retrieve labels of ground truth\n    labels = graph.ndata.pop(\"label\").to(device).long()\n\n    # Extract node features\n    feats = graph.ndata.pop(\"feat\").to(device)\n    n_features = feats.shape[-1]\n\n    # retrieve masks for train/validation/test\n    train_mask = graph.ndata.pop(\"train_mask\")\n    val_mask = graph.ndata.pop(\"val_mask\")\n    test_mask = graph.ndata.pop(\"test_mask\")\n\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)\n    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)\n    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)\n\n    graph = graph.to(device)\n\n    # Step 2: Create model =================================================================== #\n    model = MixHop(\n        in_dim=n_features,\n        hid_dim=args.hid_dim,\n        out_dim=n_classes,\n        num_layers=args.num_layers,\n        p=args.p,\n        input_dropout=args.input_dropout,\n        layer_dropout=args.layer_dropout,\n        activation=torch.tanh,\n        batchnorm=True,\n    )\n\n    model = model.to(device)\n    best_model = copy.deepcopy(model)\n\n    # Step 3: Create training components ===================================================== #\n    loss_fn = nn.CrossEntropyLoss()\n    opt = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.lamb)\n    scheduler = optim.lr_scheduler.StepLR(opt, args.step_size, gamma=args.gamma)\n\n    # Step 4: training epoches 
=============================================================== #\n    acc = 0\n    no_improvement = 0\n    epochs = trange(args.epochs, desc=\"Accuracy & Loss\")\n\n    for _ in epochs:\n        # Training using a full graph\n        model.train()\n\n        logits = model(graph, feats)\n\n        # compute loss\n        train_loss = loss_fn(logits[train_idx], labels[train_idx])\n        train_acc = torch.sum(\n            logits[train_idx].argmax(dim=1) == labels[train_idx]\n        ).item() / len(train_idx)\n\n        # backward\n        opt.zero_grad()\n        train_loss.backward()\n        opt.step()\n\n        # Validation using a full graph\n        model.eval()\n\n        with torch.no_grad():\n            valid_loss = loss_fn(logits[val_idx], labels[val_idx])\n            valid_acc = torch.sum(\n                logits[val_idx].argmax(dim=1) == labels[val_idx]\n            ).item() / len(val_idx)\n\n        # Print out performance\n        epochs.set_description(\n            \"Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}\".format(\n                train_acc, train_loss.item(), valid_acc, valid_loss.item()\n            )\n        )\n\n        if valid_acc < acc:\n            no_improvement += 1\n            if no_improvement == args.early_stopping:\n                print(\"Early stop.\")\n                break\n        else:\n            no_improvement = 0\n            acc = valid_acc\n            best_model = copy.deepcopy(model)\n\n        scheduler.step()\n\n    best_model.eval()\n    logits = best_model(graph, feats)\n    test_acc = torch.sum(\n        logits[test_idx].argmax(dim=1) == labels[test_idx]\n    ).item() / len(test_idx)\n\n    print(\"Test Acc {:.4f}\".format(test_acc))\n    return test_acc\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    MixHop Model Hyperparameters\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"MixHop GCN\")\n\n    # data source params\n    parser.add_argument(\n        
\"--dataset\", type=str, default=\"Cora\", help=\"Name of dataset.\"\n    )\n    # cuda params\n    parser.add_argument(\n        \"--gpu\", type=int, default=-1, help=\"GPU index. Default: -1, using CPU.\"\n    )\n    # training params\n    parser.add_argument(\n        \"--epochs\", type=int, default=2000, help=\"Training epochs.\"\n    )\n    parser.add_argument(\n        \"--early-stopping\",\n        type=int,\n        default=200,\n        help=\"Patient epochs to wait before early stopping.\",\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.5, help=\"Learning rate.\")\n    parser.add_argument(\"--lamb\", type=float, default=5e-4, help=\"L2 reg.\")\n    parser.add_argument(\n        \"--step-size\",\n        type=int,\n        default=40,\n        help=\"Period of learning rate decay.\",\n    )\n    parser.add_argument(\n        \"--gamma\",\n        type=float,\n        default=0.01,\n        help=\"Multiplicative factor of learning rate decay.\",\n    )\n    # model params\n    parser.add_argument(\n        \"--hid-dim\", type=int, default=60, help=\"Hidden layer dimensionalities.\"\n    )\n    parser.add_argument(\n        \"--num-layers\", type=int, default=4, help=\"Number of GNN layers.\"\n    )\n    parser.add_argument(\n        \"--input-dropout\",\n        type=float,\n        default=0.7,\n        help=\"Dropout applied at input layer.\",\n    )\n    parser.add_argument(\n        \"--layer-dropout\",\n        type=float,\n        default=0.9,\n        help=\"Dropout applied at hidden layers.\",\n    )\n    parser.add_argument(\n        \"--p\", nargs=\"+\", type=int, help=\"List of powers of adjacency matrix.\"\n    )\n\n    parser.set_defaults(p=[0, 1, 2])\n\n    args = parser.parse_args()\n    print(args)\n\n    acc_lists = []\n\n    for _ in range(100):\n        acc_lists.append(main(args))\n\n    acc_lists.sort()\n    acc_lists_top = np.array(acc_lists[50:])\n\n    mean = np.around(np.mean(acc_lists_top, axis=0), decimals=3)\n    
std = np.around(np.std(acc_lists_top, axis=0), decimals=3)\n    print(\"Total acc: \", acc_lists)\n    print(\"Top 50 acc:\", acc_lists_top)\n    print(\"mean\", mean)\n    print(\"std\", std)\n"
  },
  {
    "path": "examples/pytorch/model_zoo/README.md",
    "content": "Model Zoo\n==========\n\nHere are examples of using the model zoo.\n"
  },
  {
    "path": "examples/pytorch/model_zoo/citation_network/README.md",
    "content": "# Node Classification on Citation Networks\n\nThis example shows how to use modules defined in `dgl.nn.pytorch.conv` to do node classification on\ncitation network datasets.\n\n## Datasets\n\n- Cora\n- Citeseer\n- Pubmed\n\n## Models\n\n- GCN: [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/pdf/1609.02907)\n- GAT: [Graph Attention Networks](https://arxiv.org/abs/1710.10903)\n- GraphSAGE [Inductive Representation Learning on Large Graphs](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf)\n- APPNP: [Predict then Propagate: Graph Neural Networks meet Personalized PageRank](https://arxiv.org/pdf/1810.05997)\n- GIN: [How Powerful are Graph Neural Networks?](https://arxiv.org/abs/1810.00826)\n- TAGCN: [Topology Adaptive Graph Convolutional Networks](https://arxiv.org/abs/1710.10370)\n- SGC: [Simplifying Graph Convolutional Networks](https://arxiv.org/abs/1902.07153)\n- AGNN: [Attention-based Graph Neural Network for Semi-supervised Learning](https://arxiv.org/pdf/1803.03735.pdf)\n- ChebNet: [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375)\n\n## Usage\n\n```\npython run.py [--gpu GPU] --model MODEL_NAME --dataset DATASET_NAME [--self-loop]\n```\n\nThe hyperparameters might not be the optimal, you could specify them manually in `conf.py`.\n"
  },
  {
    "path": "examples/pytorch/model_zoo/citation_network/conf.py",
    "content": "import torch as th\nimport torch.nn.functional as F\n\nGCN_CONFIG = {\n    \"extra_args\": [16, 1, F.relu, 0.5],\n    \"lr\": 1e-2,\n    \"weight_decay\": 5e-4,\n}\n\nGAT_CONFIG = {\n    \"extra_args\": [8, 1, [8] * 1 + [1], F.elu, 0.6, 0.6, 0.2, False],\n    \"lr\": 0.005,\n    \"weight_decay\": 5e-4,\n}\n\nGRAPHSAGE_CONFIG = {\n    \"extra_args\": [16, 1, F.relu, 0.5, \"gcn\"],\n    \"lr\": 1e-2,\n    \"weight_decay\": 5e-4,\n}\n\nAPPNP_CONFIG = {\n    \"extra_args\": [64, 1, F.relu, 0.5, 0.5, 0.1, 10],\n    \"lr\": 1e-2,\n    \"weight_decay\": 5e-4,\n}\n\nTAGCN_CONFIG = {\n    \"extra_args\": [16, 1, F.relu, 0.5],\n    \"lr\": 1e-2,\n    \"weight_decay\": 5e-4,\n}\n\nAGNN_CONFIG = {\n    \"extra_args\": [32, 2, 1.0, True, 0.5],\n    \"lr\": 1e-2,\n    \"weight_decay\": 5e-4,\n}\n\nSGC_CONFIG = {\n    \"extra_args\": [None, 2, False],\n    \"lr\": 0.2,\n    \"weight_decay\": 5e-6,\n}\n\nGIN_CONFIG = {\n    \"extra_args\": [16, 1, 0, True],\n    \"lr\": 1e-2,\n    \"weight_decay\": 5e-6,\n}\n\nCHEBNET_CONFIG = {\n    \"extra_args\": [32, 1, 2, True],\n    \"lr\": 1e-2,\n    \"weight_decay\": 5e-4,\n}\n"
  },
  {
    "path": "examples/pytorch/model_zoo/citation_network/models.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom dgl.nn.pytorch import (\n    AGNNConv,\n    APPNPConv,\n    ChebConv,\n    GATConv,\n    GINConv,\n    GraphConv,\n    SAGEConv,\n    SGConv,\n    TAGConv,\n)\n\n\nclass GCN(nn.Module):\n    def __init__(\n        self, g, in_feats, n_classes, n_hidden, n_layers, activation, dropout\n    ):\n        super(GCN, self).__init__()\n        self.g = g\n        self.layers = nn.ModuleList()\n        # input layer\n        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(\n                GraphConv(n_hidden, n_hidden, activation=activation)\n            )\n        # output layer\n        self.layers.append(GraphConv(n_hidden, n_classes))\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, features):\n        h = features\n        for i, layer in enumerate(self.layers):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(self.g, h)\n        return h\n\n\nclass GAT(nn.Module):\n    def __init__(\n        self,\n        g,\n        in_dim,\n        num_classes,\n        num_hidden,\n        num_layers,\n        heads,\n        activation,\n        feat_drop,\n        attn_drop,\n        negative_slope,\n        residual,\n    ):\n        super(GAT, self).__init__()\n        self.g = g\n        self.num_layers = num_layers\n        self.gat_layers = nn.ModuleList()\n        self.activation = activation\n        # input projection (no residual)\n        self.gat_layers.append(\n            GATConv(\n                in_dim,\n                num_hidden,\n                heads[0],\n                feat_drop,\n                attn_drop,\n                negative_slope,\n                False,\n                self.activation,\n            )\n        )\n        # hidden layers\n        for l in range(1, num_layers):\n            # due to multi-head, the 
in_dim = num_hidden * num_heads\n            self.gat_layers.append(\n                GATConv(\n                    num_hidden * heads[l - 1],\n                    num_hidden,\n                    heads[l],\n                    feat_drop,\n                    attn_drop,\n                    negative_slope,\n                    residual,\n                    self.activation,\n                )\n            )\n        # output projection\n        self.gat_layers.append(\n            GATConv(\n                num_hidden * heads[-2],\n                num_classes,\n                heads[-1],\n                feat_drop,\n                attn_drop,\n                negative_slope,\n                residual,\n                None,\n            )\n        )\n\n    def forward(self, inputs):\n        h = inputs\n        for l in range(self.num_layers):\n            h = self.gat_layers[l](self.g, h).flatten(1)\n        # output projection\n        logits = self.gat_layers[-1](self.g, h).mean(1)\n        return logits\n\n\nclass GraphSAGE(nn.Module):\n    def __init__(\n        self,\n        g,\n        in_feats,\n        n_classes,\n        n_hidden,\n        n_layers,\n        activation,\n        dropout,\n        aggregator_type,\n    ):\n        super(GraphSAGE, self).__init__()\n        self.layers = nn.ModuleList()\n        self.g = g\n\n        # input layer\n        self.layers.append(\n            SAGEConv(\n                in_feats,\n                n_hidden,\n                aggregator_type,\n                feat_drop=dropout,\n                activation=activation,\n            )\n        )\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(\n                SAGEConv(\n                    n_hidden,\n                    n_hidden,\n                    aggregator_type,\n                    feat_drop=dropout,\n                    activation=activation,\n                )\n            )\n        # output layer\n        
self.layers.append(\n            SAGEConv(\n                n_hidden,\n                n_classes,\n                aggregator_type,\n                feat_drop=dropout,\n                activation=None,\n            )\n        )  # activation None\n\n    def forward(self, features):\n        h = features\n        for layer in self.layers:\n            h = layer(self.g, h)\n        return h\n\n\nclass APPNP(nn.Module):\n    def __init__(\n        self,\n        g,\n        in_feats,\n        n_classes,\n        n_hidden,\n        n_layers,\n        activation,\n        feat_drop,\n        edge_drop,\n        alpha,\n        k,\n    ):\n        super(APPNP, self).__init__()\n        self.g = g\n        self.layers = nn.ModuleList()\n        # input layer\n        self.layers.append(nn.Linear(in_feats, n_hidden))\n        # hidden layers\n        for i in range(1, n_layers):\n            self.layers.append(nn.Linear(n_hidden, n_hidden))\n        # output layer\n        self.layers.append(nn.Linear(n_hidden, n_classes))\n        self.activation = activation\n        if feat_drop:\n            self.feat_drop = nn.Dropout(feat_drop)\n        else:\n            self.feat_drop = lambda x: x\n        self.propagate = APPNPConv(k, alpha, edge_drop)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for layer in self.layers:\n            layer.reset_parameters()\n\n    def forward(self, features):\n        # prediction step\n        h = features\n        h = self.feat_drop(h)\n        h = self.activation(self.layers[0](h))\n        for layer in self.layers[1:-1]:\n            h = self.activation(layer(h))\n        h = self.layers[-1](self.feat_drop(h))\n        # propagation step\n        h = self.propagate(self.g, h)\n        return h\n\n\nclass TAGCN(nn.Module):\n    def __init__(\n        self, g, in_feats, n_classes, n_hidden, n_layers, activation, dropout\n    ):\n        super(TAGCN, self).__init__()\n        self.g = g\n        self.layers = 
nn.ModuleList()\n        # input layer\n        self.layers.append(TAGConv(in_feats, n_hidden, activation=activation))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(\n                TAGConv(n_hidden, n_hidden, activation=activation)\n            )\n        # output layer\n        self.layers.append(TAGConv(n_hidden, n_classes))  # activation=None\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, features):\n        h = features\n        for i, layer in enumerate(self.layers):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(self.g, h)\n        return h\n\n\nclass AGNN(nn.Module):\n    def __init__(\n        self,\n        g,\n        in_feats,\n        n_classes,\n        n_hidden,\n        n_layers,\n        init_beta,\n        learn_beta,\n        dropout,\n    ):\n        super(AGNN, self).__init__()\n        self.g = g\n        self.layers = nn.ModuleList(\n            [AGNNConv(init_beta, learn_beta) for _ in range(n_layers)]\n        )\n        self.proj = nn.Sequential(\n            nn.Dropout(dropout), nn.Linear(in_feats, n_hidden), nn.ReLU()\n        )\n        self.cls = nn.Sequential(\n            nn.Dropout(dropout), nn.Linear(n_hidden, n_classes)\n        )\n\n    def forward(self, features):\n        h = self.proj(features)\n        for layer in self.layers:\n            h = layer(self.g, h)\n        return self.cls(h)\n\n\nclass SGC(nn.Module):\n    def __init__(self, g, in_feats, n_classes, n_hidden, k, bias):\n        super(SGC, self).__init__()\n        self.g = g\n        self.net = SGConv(in_feats, n_classes, k=k, cached=True, bias=bias)\n\n    def forward(self, features):\n        return self.net(self.g, features)\n\n\nclass GIN(nn.Module):\n    def __init__(\n        self, g, in_feats, n_classes, n_hidden, n_layers, init_eps, learn_eps\n    ):\n        super(GIN, self).__init__()\n        self.g = g\n        self.layers = 
nn.ModuleList()\n        self.layers.append(\n            GINConv(\n                nn.Sequential(\n                    nn.Dropout(0.6),\n                    nn.Linear(in_feats, n_hidden),\n                    nn.ReLU(),\n                ),\n                \"mean\",\n                init_eps,\n                learn_eps,\n            )\n        )\n        for i in range(n_layers - 1):\n            self.layers.append(\n                GINConv(\n                    nn.Sequential(\n                        nn.Dropout(0.6),\n                        nn.Linear(n_hidden, n_hidden),\n                        nn.ReLU(),\n                    ),\n                    \"mean\",\n                    init_eps,\n                    learn_eps,\n                )\n            )\n        self.layers.append(\n            GINConv(\n                nn.Sequential(\n                    nn.Dropout(0.6),\n                    nn.Linear(n_hidden, n_classes),\n                ),\n                \"mean\",\n                init_eps,\n                learn_eps,\n            )\n        )\n\n    def forward(self, features):\n        h = features\n        for layer in self.layers:\n            h = layer(self.g, h)\n        return h\n\n\nclass ChebNet(nn.Module):\n    def __init__(self, g, in_feats, n_classes, n_hidden, n_layers, k, bias):\n        super(ChebNet, self).__init__()\n        self.g = g\n        self.layers = nn.ModuleList()\n        self.layers.append(ChebConv(in_feats, n_hidden, k, bias=bias))\n        for _ in range(n_layers - 1):\n            self.layers.append(ChebConv(n_hidden, n_hidden, k, bias=bias))\n\n        self.layers.append(ChebConv(n_hidden, n_classes, k, bias=bias))\n\n    def forward(self, features):\n        h = features\n        for layer in self.layers:\n            h = layer(self.g, h, [2])\n        return h\n"
  },
  {
    "path": "examples/pytorch/model_zoo/citation_network/run.py",
    "content": "import argparse\nimport time\n\nimport networkx as nx\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom conf import *\nfrom models import *\n\nimport dgl\nfrom dgl.data import load_data, register_data_args\n\n\ndef get_model_and_config(name):\n    name = name.lower()\n    if name == \"gcn\":\n        return GCN, GCN_CONFIG\n    elif name == \"gat\":\n        return GAT, GAT_CONFIG\n    elif name == \"graphsage\":\n        return GraphSAGE, GRAPHSAGE_CONFIG\n    elif name == \"appnp\":\n        return APPNP, APPNP_CONFIG\n    elif name == \"tagcn\":\n        return TAGCN, TAGCN_CONFIG\n    elif name == \"agnn\":\n        return AGNN, AGNN_CONFIG\n    elif name == \"sgc\":\n        return SGC, SGC_CONFIG\n    elif name == \"gin\":\n        return GIN, GIN_CONFIG\n    elif name == \"chebnet\":\n        return ChebNet, CHEBNET_CONFIG\n\n\ndef evaluate(model, features, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef main(args):\n    # load and preprocess dataset\n    data = load_data(args)\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n    else:\n        cuda = True\n        g = g.to(args.gpu)\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = g.num_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            
train_mask.int().sum().item(),\n            val_mask.int().sum().item(),\n            test_mask.int().sum().item(),\n        )\n    )\n\n    # graph preprocess and calculate normalization factor\n    # add self loop\n    if args.self_loop:\n        g = g.remove_self_loop().add_self_loop()\n    n_edges = g.num_edges()\n\n    # normalization\n    degs = g.in_degrees().float()\n    norm = torch.pow(degs, -0.5)\n    norm[torch.isinf(norm)] = 0\n    g.ndata[\"norm\"] = norm.unsqueeze(1)\n\n    # create GCN model\n    GNN, config = get_model_and_config(args.model)\n    model = GNN(g, in_feats, n_classes, *config[\"extra_args\"])\n\n    if cuda:\n        model = model.cuda()\n\n    print(model)\n\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    # use optimizer\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=config[\"lr\"], weight_decay=config[\"weight_decay\"]\n    )\n\n    # initialize graph\n    mean = 0\n    for epoch in range(200):\n        model.train()\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        logits = model(features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if epoch >= 3:\n            mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)\n            acc = evaluate(model, features, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    mean,\n                    loss.item(),\n                    acc,\n                    n_edges / mean / 1000,\n                )\n            )\n\n    print()\n    acc = evaluate(model, features, labels, test_mask)\n    print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Node classification on 
citation networks.\"\n    )\n    register_data_args(parser)\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"gcn\",\n        help=\"model to use, available models are gcn, gat, graphsage, gin,\"\n        \"appnp, tagcn, sgc, agnn\",\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\n        \"--self-loop\",\n        action=\"store_true\",\n        help=\"graph self-loop (default=False)\",\n    )\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/model_zoo/geometric/.gitignore",
    "content": "MNIST/\n"
  },
  {
    "path": "examples/pytorch/model_zoo/geometric/README.md",
    "content": "Geometric Deep Learning models\n=========\n\nThis example shows how to use geometric deep learning models defined in `dgl.nn.pytorch.conv` for\ngraph classification.\n\nCurrently we support following models:\n- [ChebNet](https://arxiv.org/pdf/1606.09375.pdf)\n- [MoNet](https://arxiv.org/pdf/1611.08402.pdf)\n\n## Image Classification on MNIST\n\nBy transforming images to graphs, graph classifcation algorithms could\nbe applied to image classification problems.\n\n### Usage\n```bash\npython mnist.py --model cheb --gpu 0\npython mnist.py --model monet --gpu 0\n```\n\n### Acknowledgement\nWe thank [Xavier Bresson](https://github.com/xbresson) for providing \ncode for graph coarsening algorithm and grid graph building in  \n[CE7454_2019 Labs](https://github.com/xbresson/CE7454_2019/tree/master/codes/labs_lecture14/lab01_ChebGCNs).\n"
  },
  {
    "path": "examples/pytorch/model_zoo/geometric/coarsening.py",
    "content": "# author: xbresson\n# code link: https://github.com/xbresson/CE7454_2019/blob/master/codes/labs_lecture14/lab01_ChebGCNs/lib/coarsening.py\n\nimport numpy as np\nimport scipy.sparse\nimport sklearn.metrics\n\n\ndef laplacian(W, normalized=True):\n    \"\"\"Return graph Laplacian\"\"\"\n\n    # Degree matrix.\n    d = W.sum(axis=0)\n\n    # Laplacian matrix.\n    if not normalized:\n        D = scipy.sparse.diags(d.A.squeeze(), 0)\n        L = D - W\n    else:\n        d += np.spacing(np.array(0, W.dtype))\n        d = 1 / np.sqrt(d)\n        D = scipy.sparse.diags(d.A.squeeze(), 0)\n        I = scipy.sparse.identity(d.size, dtype=W.dtype)\n        L = I - D * W * D\n\n    assert np.abs(L - L.T).mean() < 1e-9\n    assert type(L) is scipy.sparse.csr.csr_matrix\n    return L\n\n\ndef rescale_L(L, lmax=2):\n    \"\"\"Rescale Laplacian eigenvalues to [-1,1]\"\"\"\n    M, M = L.shape\n    I = scipy.sparse.identity(M, format=\"csr\", dtype=L.dtype)\n    L /= lmax * 2\n    L -= I\n    return L\n\n\ndef lmax_L(L):\n    \"\"\"Compute largest Laplacian eigenvalue\"\"\"\n    return scipy.sparse.linalg.eigsh(\n        L, k=1, which=\"LM\", return_eigenvectors=False\n    )[0]\n\n\n# graph coarsening with Heavy Edge Matching\ndef coarsen(A, levels):\n    graphs, parents = HEM(A, levels)\n    perms = compute_perm(parents)\n\n    laplacians = []\n    for i, A in enumerate(graphs):\n        M, M = A.shape\n\n        if i < levels:\n            A = perm_adjacency(A, perms[i])\n\n        A = A.tocsr()\n        A.eliminate_zeros()\n        Mnew, Mnew = A.shape\n        print(\n            \"Layer {0}: M_{0} = |V| = {1} nodes ({2} added), |E| = {3} edges\".format(\n                i, Mnew, Mnew - M, A.nnz // 2\n            )\n        )\n\n        L = laplacian(A, normalized=True)\n        laplacians.append(L)\n\n    return laplacians, perms[0] if len(perms) > 0 else None\n\n\ndef HEM(W, levels, rid=None):\n    \"\"\"\n    Coarsen a graph multiple times using the Heavy 
Edge Matching (HEM).\n    Input\n    W: symmetric sparse weight (adjacency) matrix\n    levels: the number of coarsened graphs\n    Output\n    graph[0]: original graph of size N_1\n    graph[2]: coarser graph of size N_2 < N_1\n    graph[levels]: coarsest graph of Size N_levels < ... < N_2 < N_1\n    parents[i] is a vector of size N_i with entries ranging from 1 to N_{i+1}\n        which indicate the parents in the coarser graph[i+1]\n    nd_sz{i} is a vector of size N_i that contains the size of the supernode in the graph{i}\n    Note\n    if \"graph\" is a list of length k, then \"parents\" will be a list of length k-1\n    \"\"\"\n\n    N, N = W.shape\n\n    if rid is None:\n        rid = np.random.permutation(range(N))\n\n    ss = np.array(W.sum(axis=0)).squeeze()\n    rid = np.argsort(ss)\n\n    parents = []\n    degree = W.sum(axis=0) - W.diagonal()\n    graphs = []\n    graphs.append(W)\n\n    print(\"Heavy Edge Matching coarsening with Xavier version\")\n\n    for _ in range(levels):\n        # CHOOSE THE WEIGHTS FOR THE PAIRING\n        # weights = ones(N,1)       # metis weights\n        weights = degree  # graclus weights\n        # weights = supernode_size  # other possibility\n        weights = np.array(weights).squeeze()\n\n        # PAIR THE VERTICES AND CONSTRUCT THE ROOT VECTOR\n        idx_row, idx_col, val = scipy.sparse.find(W)\n        cc = idx_row\n        rr = idx_col\n        vv = val\n\n        # TO BE SPEEDUP\n        if not (list(cc) == list(np.sort(cc))):\n            tmp = cc\n            cc = rr\n            rr = tmp\n\n        cluster_id = HEM_one_level(cc, rr, vv, rid, weights)  # cc is ordered\n        parents.append(cluster_id)\n\n        # COMPUTE THE EDGES WEIGHTS FOR THE NEW GRAPH\n        nrr = cluster_id[rr]\n        ncc = cluster_id[cc]\n        nvv = vv\n        Nnew = cluster_id.max() + 1\n        # CSR is more appropriate: row,val pairs appear multiple times\n        W = scipy.sparse.csr_matrix((nvv, (nrr, ncc)), 
shape=(Nnew, Nnew))\n        W.eliminate_zeros()\n\n        # Add new graph to the list of all coarsened graphs\n        graphs.append(W)\n        N, N = W.shape\n\n        # COMPUTE THE DEGREE (OMIT OR NOT SELF LOOPS)\n        degree = W.sum(axis=0)\n        # degree = W.sum(axis=0) - W.diagonal()\n\n        # CHOOSE THE ORDER IN WHICH VERTICES WILL BE VISTED AT THE NEXT PASS\n        # [~, rid]=sort(ss);     # arthur strategy\n        # [~, rid]=sort(supernode_size);    #  thomas strategy\n        # rid=randperm(N);                  #  metis/graclus strategy\n        ss = np.array(W.sum(axis=0)).squeeze()\n        rid = np.argsort(ss)\n\n    return graphs, parents\n\n\n# Coarsen a graph given by rr,cc,vv.  rr is assumed to be ordered\ndef HEM_one_level(rr, cc, vv, rid, weights):\n    nnz = rr.shape[0]\n    N = rr[nnz - 1] + 1\n\n    marked = np.zeros(N, np.bool_)\n    rowstart = np.zeros(N, np.int32)\n    rowlength = np.zeros(N, np.int32)\n    cluster_id = np.zeros(N, np.int32)\n\n    oldval = rr[0]\n    count = 0\n    clustercount = 0\n\n    for ii in range(nnz):\n        rowlength[count] = rowlength[count] + 1\n        if rr[ii] > oldval:\n            oldval = rr[ii]\n            rowstart[count + 1] = ii\n            count = count + 1\n\n    for ii in range(N):\n        tid = rid[ii]\n        if not marked[tid]:\n            wmax = 0.0\n            rs = rowstart[tid]\n            marked[tid] = True\n            bestneighbor = -1\n            for jj in range(rowlength[tid]):\n                nid = cc[rs + jj]\n                if marked[nid]:\n                    tval = 0.0\n                else:\n                    # First approach\n                    if 2 == 1:\n                        tval = vv[rs + jj] * (\n                            1.0 / weights[tid] + 1.0 / weights[nid]\n                        )\n\n                    # Second approach\n                    if 1 == 1:\n                        Wij = vv[rs + jj]\n                        Wii = 
vv[rowstart[tid]]\n                        Wjj = vv[rowstart[nid]]\n                        di = weights[tid]\n                        dj = weights[nid]\n                        tval = (2.0 * Wij + Wii + Wjj) * 1.0 / (di + dj + 1e-9)\n\n                if tval > wmax:\n                    wmax = tval\n                    bestneighbor = nid\n\n            cluster_id[tid] = clustercount\n\n            if bestneighbor > -1:\n                cluster_id[bestneighbor] = clustercount\n                marked[bestneighbor] = True\n\n            clustercount += 1\n\n    return cluster_id\n\n\ndef compute_perm(parents):\n    \"\"\"\n    Return a list of indices to reorder the adjacency and data matrices so\n    that the union of two neighbors from layer to layer forms a binary tree.\n    \"\"\"\n\n    # Order of last layer is random (chosen by the clustering algorithm).\n    indices = []\n    if len(parents) > 0:\n        M_last = max(parents[-1]) + 1\n        indices.append(list(range(M_last)))\n\n    for parent in parents[::-1]:\n        # Fake nodes go after real ones.\n        pool_singeltons = len(parent)\n\n        indices_layer = []\n        for i in indices[-1]:\n            indices_node = list(np.where(parent == i)[0])\n            assert 0 <= len(indices_node) <= 2\n\n            # Add a node to go with a singelton.\n            if len(indices_node) == 1:\n                indices_node.append(pool_singeltons)\n                pool_singeltons += 1\n\n            # Add two nodes as children of a singelton in the parent.\n            elif len(indices_node) == 0:\n                indices_node.append(pool_singeltons + 0)\n                indices_node.append(pool_singeltons + 1)\n                pool_singeltons += 2\n\n            indices_layer.extend(indices_node)\n        indices.append(indices_layer)\n\n    # Sanity checks.\n    for i, indices_layer in enumerate(indices):\n        M = M_last * 2**i\n        # Reduction by 2 at each layer (binary tree).\n        assert 
len(indices[0] == M)\n        # The new ordering does not omit an indice.\n        assert sorted(indices_layer) == list(range(M))\n\n    return indices[::-1]\n\n\nassert compute_perm(\n    [np.array([4, 1, 1, 2, 2, 3, 0, 0, 3]), np.array([2, 1, 0, 1, 0])]\n) == [[3, 4, 0, 9, 1, 2, 5, 8, 6, 7, 10, 11], [2, 4, 1, 3, 0, 5], [0, 1, 2]]\n\n\ndef perm_adjacency(A, indices):\n    \"\"\"\n    Permute adjacency matrix, i.e. exchange node ids,\n    so that binary unions form the clustering tree.\n    \"\"\"\n    if indices is None:\n        return A\n\n    M, M = A.shape\n    Mnew = len(indices)\n    A = A.tocoo()\n\n    # Add Mnew - M isolated vertices.\n    rows = scipy.sparse.coo_matrix((Mnew - M, M), dtype=np.float32)\n    cols = scipy.sparse.coo_matrix((Mnew, Mnew - M), dtype=np.float32)\n    A = scipy.sparse.vstack([A, rows])\n    A = scipy.sparse.hstack([A, cols])\n\n    # Permute the rows and the columns.\n    perm = np.argsort(indices)\n    A.row = np.array(perm)[A.row]\n    A.col = np.array(perm)[A.col]\n\n    assert np.abs(A - A.T).mean() < 1e-8  # 1e-9\n    assert type(A) is scipy.sparse.coo.coo_matrix\n    return A\n\n\ndef perm_data(x, indices):\n    \"\"\"\n    Permute data matrix, i.e. exchange node ids,\n    so that binary unions form the clustering tree.\n    \"\"\"\n    if indices is None:\n        return x\n\n    N, M = x.shape\n    Mnew = len(indices)\n    assert Mnew >= M\n    xnew = np.empty((N, Mnew))\n    for i, j in enumerate(indices):\n        # Existing vertex, i.e. real data.\n        if j < M:\n            xnew[:, i] = x[:, j]\n        # Fake vertex because of singeltons.\n        # They will stay 0 so that max pooling chooses the singelton.\n        # Or -infty ?\n        else:\n            xnew[:, i] = np.zeros(N)\n    return xnew\n"
  },
  {
    "path": "examples/pytorch/model_zoo/geometric/coordinate.py",
    "content": "import torch as th\n\n\"\"\"Compute x,y coordinate for nodes in the graph\"\"\"\neps = 1e-8\n\n\ndef get_coordinates(graphs, grid_side, coarsening_levels, perm):\n    rst = []\n    for l in range(coarsening_levels + 1):\n        xs, ys = [], []\n        for i in range(graphs[l].num_nodes()):\n            cnt = eps\n            x_accum = 0\n            y_accum = 0\n            for j in range(i * 2**l, (i + 1) * 2**l):\n                if perm[j] < grid_side**2:\n                    x_accum += perm[j] // grid_side\n                    y_accum += perm[j] % grid_side\n                    cnt += 1\n            xs.append(x_accum / cnt)\n            ys.append(y_accum / cnt)\n        rst.append(\n            th.cat([th.tensor(xs).view(-1, 1), th.tensor(ys).view(-1, 1)], -1)\n        )\n    return rst\n\n\n\"\"\"Cartesian coordinate to polar coordinate\"\"\"\n\n\ndef z2polar(edges):\n    z = edges.dst[\"xy\"] - edges.src[\"xy\"]\n    rho = th.norm(z, dim=-1, p=2)\n    x, y = z.unbind(dim=-1)\n    phi = th.atan2(y, x)\n    return {\"u\": th.cat([rho.unsqueeze(-1), phi.unsqueeze(-1)], -1)}\n"
  },
  {
    "path": "examples/pytorch/model_zoo/geometric/grid_graph.py",
    "content": "# author: xbresson\n# code link: https://github.com/xbresson/CE7454_2019/blob/master/codes/labs_lecture14/lab01_ChebGCNs/lib/grid_graph.py\n\nimport numpy as np\nimport scipy.sparse  # scipy.spatial.distance\nimport scipy.sparse.linalg\nimport sklearn\nimport sklearn.metrics\n\n\ndef grid_graph(grid_side, number_edges, metric):\n    \"\"\"Generate graph of a grid\"\"\"\n    z = grid(grid_side)\n    dist, idx = distance_sklearn_metrics(z, k=number_edges, metric=metric)\n    A = adjacency(dist, idx)\n    print(\"nb edges: \", A.nnz)\n    return A\n\n\ndef grid(m, dtype=np.float32):\n    \"\"\"Return coordinates of grid points\"\"\"\n    M = m**2\n    x = np.linspace(0, 1, m, dtype=dtype)\n    y = np.linspace(0, 1, m, dtype=dtype)\n    xx, yy = np.meshgrid(x, y)\n    z = np.empty((M, 2), dtype)\n    z[:, 0] = xx.reshape(M)\n    z[:, 1] = yy.reshape(M)\n    return z\n\n\ndef distance_sklearn_metrics(z, k=4, metric=\"euclidean\"):\n    \"\"\"Compute pairwise distances\"\"\"\n    # d = sklearn.metrics.pairwise.pairwise_distances(z, metric=metric, n_jobs=-2)\n    d = sklearn.metrics.pairwise.pairwise_distances(z, metric=metric, n_jobs=1)\n    # k-NN\n    idx = np.argsort(d)[:, 1 : k + 1]\n    d.sort()\n    d = d[:, 1 : k + 1]\n    return d, idx\n\n\ndef adjacency(dist, idx):\n    \"\"\"Return adjacency matrix of a kNN graph\"\"\"\n    M, k = dist.shape\n    assert M, k == idx.shape\n    assert dist.min() >= 0\n    assert dist.max() <= 1\n\n    # Pairwise distances\n    sigma2 = np.mean(dist[:, -1]) ** 2\n    dist = np.exp(-(dist**2) / sigma2)\n\n    # Weight matrix\n    I = np.arange(0, M).repeat(k)\n    J = idx.reshape(M * k)\n    V = dist.reshape(M * k)\n    W = scipy.sparse.coo_matrix((V, (I, J)), shape=(M, M))\n\n    # No self-connections\n    W.setdiag(0)\n\n    # Undirected graph\n    bigger = W.T > W\n    W = W - W.multiply(bigger) + W.T.multiply(bigger)\n\n    assert W.nnz % 2 == 0\n    assert np.abs(W - W.T).mean() < 1e-10\n    assert type(W) is 
scipy.sparse.csr.csr_matrix\n    return W\n"
  },
  {
    "path": "examples/pytorch/model_zoo/geometric/mnist.py",
    "content": "import argparse\nimport time\n\nimport dgl\n\nimport networkx as nx\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom coarsening import coarsen\nfrom coordinate import get_coordinates, z2polar\nfrom dgl.data import load_data, register_data_args\nfrom dgl.nn.pytorch.conv import ChebConv, GMMConv\nfrom dgl.nn.pytorch.glob import MaxPooling\nfrom grid_graph import grid_graph\nfrom torch.utils.data import DataLoader\nfrom torchvision import datasets, transforms\n\nargparser = argparse.ArgumentParser(\"MNIST\")\nargparser.add_argument(\n    \"--gpu\", type=int, default=-1, help=\"gpu id, use cpu if set to -1\"\n)\nargparser.add_argument(\n    \"--model\", type=str, default=\"chebnet\", help=\"model to use, chebnet/monet\"\n)\nargparser.add_argument(\"--batch-size\", type=int, default=100, help=\"batch size\")\nargs = argparser.parse_args()\n\ngrid_side = 28\nnumber_edges = 8\nmetric = \"euclidean\"\n\nA = grid_graph(28, 8, metric)\n\ncoarsening_levels = 4\nL, perm = coarsen(A, coarsening_levels)\ng_arr = [dgl.from_scipy(csr) for csr in L]\n\ncoordinate_arr = get_coordinates(g_arr, grid_side, coarsening_levels, perm)\nstr_to_torch_dtype = {\n    \"float16\": torch.half,\n    \"float32\": torch.float32,\n    \"float64\": torch.float64,\n}\ncoordinate_arr = [\n    coord.to(dtype=str_to_torch_dtype[str(A.dtype)]) for coord in coordinate_arr\n]\nfor g, coordinate_arr in zip(g_arr, coordinate_arr):\n    g.ndata[\"xy\"] = coordinate_arr\n    g.apply_edges(z2polar)\n\n\ndef batcher(batch):\n    g_batch = [[] for _ in range(coarsening_levels + 1)]\n    x_batch = []\n    y_batch = []\n    for x, y in batch:\n        x = torch.cat([x.view(-1), x.new_zeros(len(perm) - 28**2)], 0)\n        x = x[perm]\n        x_batch.append(x)\n        y_batch.append(y)\n        for i in range(coarsening_levels + 1):\n            g_batch[i].append(g_arr[i])\n\n    x_batch = torch.cat(x_batch).unsqueeze(-1)\n    y_batch = 
torch.LongTensor(y_batch)\n    g_batch = [dgl.batch(g) for g in g_batch]\n    return g_batch, x_batch, y_batch\n\n\ntrainset = datasets.MNIST(\n    root=\".\", train=True, download=True, transform=transforms.ToTensor()\n)\ntestset = datasets.MNIST(\n    root=\".\", train=False, download=True, transform=transforms.ToTensor()\n)\n\ntrain_loader = DataLoader(\n    trainset,\n    batch_size=args.batch_size,\n    shuffle=True,\n    collate_fn=batcher,\n    num_workers=6,\n)\ntest_loader = DataLoader(\n    testset,\n    batch_size=args.batch_size,\n    shuffle=False,\n    collate_fn=batcher,\n    num_workers=6,\n)\n\n\nclass MoNet(nn.Module):\n    def __init__(self, n_kernels, in_feats, hiddens, out_feats):\n        super(MoNet, self).__init__()\n        self.pool = nn.MaxPool1d(2)\n        self.layers = nn.ModuleList()\n        self.readout = MaxPooling()\n\n        # Input layer\n        self.layers.append(GMMConv(in_feats, hiddens[0], 2, n_kernels))\n\n        # Hidden layer\n        for i in range(1, len(hiddens)):\n            self.layers.append(\n                GMMConv(hiddens[i - 1], hiddens[i], 2, n_kernels)\n            )\n\n        self.cls = nn.Sequential(\n            nn.Linear(hiddens[-1], out_feats), nn.LogSoftmax(dim=1)\n        )\n\n    def forward(self, g_arr, feat):\n        for g, layer in zip(g_arr, self.layers):\n            u = g.edata[\"u\"]\n            feat = (\n                self.pool(layer(g, feat, u).transpose(-1, -2).unsqueeze(0))\n                .squeeze(0)\n                .transpose(-1, -2)\n            )\n        return self.cls(self.readout(g_arr[-1], feat))\n\n\nclass ChebNet(nn.Module):\n    def __init__(self, k, in_feats, hiddens, out_feats):\n        super(ChebNet, self).__init__()\n        self.pool = nn.MaxPool1d(2)\n        self.layers = nn.ModuleList()\n        self.readout = MaxPooling()\n\n        # Input layer\n        self.layers.append(ChebConv(in_feats, hiddens[0], k))\n\n        for i in range(1, len(hiddens)):\n       
     self.layers.append(ChebConv(hiddens[i - 1], hiddens[i], k))\n\n        self.cls = nn.Sequential(\n            nn.Linear(hiddens[-1], out_feats), nn.LogSoftmax(dim=1)\n        )\n\n    def forward(self, g_arr, feat):\n        for g, layer in zip(g_arr, self.layers):\n            feat = (\n                self.pool(\n                    layer(g, feat, [2] * g.batch_size)\n                    .transpose(-1, -2)\n                    .unsqueeze(0)\n                )\n                .squeeze(0)\n                .transpose(-1, -2)\n            )\n        return self.cls(self.readout(g_arr[-1], feat))\n\n\nif args.gpu == -1:\n    device = torch.device(\"cpu\")\nelse:\n    device = torch.device(args.gpu)\n\nif args.model == \"chebnet\":\n    model = ChebNet(2, 1, [32, 64, 128, 256], 10)\nelse:\n    model = MoNet(10, 1, [32, 64, 128, 256], 10)\n\nmodel = model.to(device)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\nlog_interval = 50\n\nfor epoch in range(10):\n    print(\"epoch {} starts\".format(epoch))\n    model.train()\n    hit, tot = 0, 0\n    loss_accum = 0\n    for i, (g, x, y) in enumerate(train_loader):\n        x = x.to(device)\n        y = y.to(device)\n        g = [g_i.to(device) for g_i in g]\n        out = model(g, x)\n        hit += (out.max(-1)[1] == y).sum().item()\n        tot += len(y)\n        loss = F.nll_loss(out, y)\n        loss_accum += loss.item()\n\n        if (i + 1) % log_interval == 0:\n            print(\n                \"loss: {}, acc: {}\".format(loss_accum / log_interval, hit / tot)\n            )\n            hit, tot = 0, 0\n            loss_accum = 0\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n    model.eval()\n    hit, tot = 0, 0\n    for g, x, y in test_loader:\n        x = x.to(device)\n        y = y.to(device)\n        g = [g_i.to(device) for g_i in g]\n        out = model(g, x)\n        hit += (out.max(-1)[1] == y).sum().item()\n        tot += len(y)\n\n    
print(\"test acc: \", hit / tot)\n"
  },
  {
    "path": "examples/pytorch/monet/README.md",
    "content": "MoNet\n=====\n\n- paper link: [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/pdf/1611.08402.pdf)\n\nDependencies\n============\n\n- pytorch 1.1+\n\nResults\n=======\n\n## Citation networks\nRun with following (available dataset: \"cora\", \"citeseer\", \"pubmed\")\n```bash\npython3 citation.py --dataset cora --gpu 0\n```\n\n- Cora: ~0.816\n- Pubmed: ~0.763\n\n## Image classification:\n- please refer to [model_zoo/geometric](../model_zoo/geometric)."
  },
  {
    "path": "examples/pytorch/monet/citation.py",
    "content": "import argparse\nimport time\n\nimport networkx as nx\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl import DGLGraph\nfrom dgl.data import load_data, register_data_args\nfrom dgl.nn.pytorch.conv import GMMConv\n\n\nclass MoNet(nn.Module):\n    def __init__(\n        self,\n        g,\n        in_feats,\n        n_hidden,\n        out_feats,\n        n_layers,\n        dim,\n        n_kernels,\n        dropout,\n    ):\n        super(MoNet, self).__init__()\n        self.g = g\n        self.layers = nn.ModuleList()\n        self.pseudo_proj = nn.ModuleList()\n\n        # Input layer\n        self.layers.append(GMMConv(in_feats, n_hidden, dim, n_kernels))\n        self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh()))\n\n        # Hidden layer\n        for _ in range(n_layers - 1):\n            self.layers.append(GMMConv(n_hidden, n_hidden, dim, n_kernels))\n            self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh()))\n\n        # Output layer\n        self.layers.append(GMMConv(n_hidden, out_feats, dim, n_kernels))\n        self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh()))\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, feat, pseudo):\n        h = feat\n        for i in range(len(self.layers)):\n            if i != 0:\n                h = self.dropout(h)\n            h = self.layers[i](self.g, h, self.pseudo_proj[i](pseudo))\n        return h\n\n\ndef evaluate(model, features, pseudo, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(features, pseudo)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef main(args):\n    # load and preprocess dataset\n    data = load_data(args)\n    g = data[0]\n    if args.gpu < 0:\n        cuda 
= False\n    else:\n        cuda = True\n        g = g.to(args.gpu)\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = g.num_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.sum().item(),\n            val_mask.sum().item(),\n            test_mask.sum().item(),\n        )\n    )\n\n    # graph preprocess and calculate normalization factor\n    g = g.remove_self_loop().add_self_loop()\n    n_edges = g.num_edges()\n    us, vs = g.edges(order=\"eid\")\n    udeg, vdeg = 1 / torch.sqrt(g.in_degrees(us).float()), 1 / torch.sqrt(\n        g.in_degrees(vs).float()\n    )\n    pseudo = torch.cat([udeg.unsqueeze(1), vdeg.unsqueeze(1)], dim=1)\n\n    # create GraphSAGE model\n    model = MoNet(\n        g,\n        in_feats,\n        args.n_hidden,\n        n_classes,\n        args.n_layers,\n        args.pseudo_dim,\n        args.n_kernels,\n        args.dropout,\n    )\n\n    if cuda:\n        model.cuda()\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    # use optimizer\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # initialize graph\n    mean = 0\n    for epoch in range(args.n_epochs):\n        model.train()\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        logits = model(features, pseudo)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if epoch >= 3:\n            mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)\n          
  acc = evaluate(model, features, pseudo, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    mean,\n                    loss.item(),\n                    acc,\n                    n_edges / mean / 1000,\n                )\n            )\n\n    print()\n    acc = evaluate(model, features, pseudo, labels, test_mask)\n    print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"MoNet on citation network\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden gcn units\"\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=1, help=\"number of hidden gcn layers\"\n    )\n    parser.add_argument(\n        \"--pseudo-dim\",\n        type=int,\n        default=2,\n        help=\"Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed\",\n    )\n    parser.add_argument(\n        \"--n-kernels\",\n        type=int,\n        default=3,\n        help=\"Number of kernels in GMMConv layer\",\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 loss\"\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/multigpu/README.md",
    "content": "Multiple GPU Training\n============\n\nRequirements\n------------\n\n```bash\npip install torchmetrics==0.11.4\n```\n\nHow to run\n-------\n\n### Graph property prediction\n\n\nRun with following (available dataset: \"ogbg-molhiv\", \"ogbg-molpcba\")\n```bash\npython3 multi_gpu_graph_prediction.py --dataset ogbg-molhiv\n```\n\n#### __Results__\n```\n* ogbg-molhiv: ~0.7965\n* ogbg-molpcba: ~0.2239\n```\n\n#### __Scalability__\nWe test scalability of the code with dataset \"ogbg-molhiv\" in a machine of type <a href=\"https://aws.amazon.com/blogs/aws/now-available-ec2-instances-g4-with-nvidia-t4-tensor-core-gpus/\">Amazon EC2 g4dn.metal</a>\n, which has **8 Nvidia T4 Tensor Core GPUs**.\n\n\n|GPU number |Speed Up |Batch size |Test accuracy |Average epoch Time|\n| --- | ----------- | ----------- | -----------|-----------|\n| 1 | x | 32 | 0.7765| 45.0s|\n| 2 | 3.7x |64 | 0.7761|12.1s|\n| 4 | 5.9x| 128 |  0.7854|7.6s|\n| 8 | 9.5x| 256 |  0.7751|4.7s|\n\n\n### Node classification\n\n\nRun with following on dataset \"ogbn-products\"\n\n```bash\npython3 multi_gpu_node_classification.py\n```\n\n#### __Results__\n```\nTest Accuracy: ~0.7632\n```\n\n### Link prediction\n\n\nRun with following (available dataset: \"ogbn-products\", \"reddit\")\n\n```bash\npython3 multi_gpu_link_prediction.py --dataset ogbn-products\n```\n\n#### __Results__\n```\nEval F1-score: ~0.7999  Test F1-score: ~0.6383\n```\n\nNotably,\n\n* The loss function is defined by predicting whether an edge exists between two nodes or not.\n* When computing the score of `(u, v)`, the connections between node `u` and `v` are removed from neighbor sampling.\n* The performance of the learned embeddings are measured by training a softmax regression with scikit-learn.\n"
  },
  {
    "path": "examples/pytorch/multigpu/multi_gpu_graph_prediction.py",
    "content": "import argparse\n\nimport dgl\nimport dgl.nn as dglnn\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dgl.data import AsGraphPredDataset\nfrom dgl.dataloading import GraphDataLoader\nfrom ogb.graphproppred import DglGraphPropPredDataset, Evaluator\nfrom ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder\nfrom tqdm import tqdm\n\n\nclass MLP(nn.Module):\n    def __init__(self, in_feats):\n        super().__init__()\n        self.mlp = nn.Sequential(\n            nn.Linear(in_feats, 2 * in_feats),\n            nn.BatchNorm1d(2 * in_feats),\n            nn.ReLU(),\n            nn.Linear(2 * in_feats, in_feats),\n            nn.BatchNorm1d(in_feats),\n        )\n\n    def forward(self, h):\n        return self.mlp(h)\n\n\nclass GIN(nn.Module):\n    def __init__(self, n_hidden, n_output, n_layers=5):\n        super().__init__()\n        self.node_encoder = AtomEncoder(n_hidden)\n        self.edge_encoders = nn.ModuleList(\n            [BondEncoder(n_hidden) for _ in range(n_layers)]\n        )\n\n        self.pool = dglnn.AvgPooling()\n        self.dropout = nn.Dropout(0.5)\n        self.layers = nn.ModuleList()\n        for _ in range(n_layers):\n            self.layers.append(dglnn.GINEConv(MLP(n_hidden), learn_eps=True))\n        self.predictor = nn.Linear(n_hidden, n_output)\n\n        # add virtual node\n        self.virtual_emb = nn.Embedding(1, n_hidden)\n        nn.init.constant_(self.virtual_emb.weight.data, 0)\n        self.virtual_layers = nn.ModuleList()\n        for _ in range(n_layers - 1):\n            self.virtual_layers.append(MLP(n_hidden))\n        self.virtual_pool = dglnn.SumPooling()\n\n    def forward(self, g, x, x_e):\n        v_emb = self.virtual_emb.weight.expand(g.batch_size, -1)\n        hn = self.node_encoder(x)\n        for i in range(len(self.layers)):\n            v_hn = dgl.broadcast_nodes(g, v_emb)\n            hn = hn 
+ v_hn\n            he = self.edge_encoders[i](x_e)\n            hn = self.layers[i](g, hn, he)\n            hn = F.relu(hn)\n            hn = self.dropout(hn)\n            if i != len(self.layers) - 1:\n                v_emb_tmp = self.virtual_pool(g, hn) + v_emb\n                v_emb = self.virtual_layers[i](v_emb_tmp)\n                v_emb = self.dropout(F.relu(v_emb))\n        hn = self.pool(g, hn)\n        return self.predictor(hn)\n\n\n@torch.no_grad()\ndef evaluate(dataloader, device, model, evaluator):\n    model.eval()\n    y_true = []\n    y_pred = []\n    for batched_graph, labels in tqdm(dataloader):\n        batched_graph, labels = batched_graph.to(device), labels.to(device)\n        node_feat, edge_feat = (\n            batched_graph.ndata[\"feat\"],\n            batched_graph.edata[\"feat\"],\n        )\n        y_hat = model(batched_graph, node_feat, edge_feat)\n        y_true.append(labels.view(y_hat.shape).detach().cpu())\n        y_pred.append(y_hat.detach().cpu())\n    y_true = torch.cat(y_true, dim=0).numpy()\n    y_pred = torch.cat(y_pred, dim=0).numpy()\n    input_dict = {\"y_true\": y_true, \"y_pred\": y_pred}\n    return evaluator.eval(input_dict)\n\n\ndef train(rank, world_size, dataset_name, root):\n    dist.init_process_group(\n        \"nccl\", \"tcp://127.0.0.1:12347\", world_size=world_size, rank=rank\n    )\n    torch.cuda.set_device(rank)\n\n    dataset = AsGraphPredDataset(DglGraphPropPredDataset(dataset_name, root))\n    evaluator = Evaluator(dataset_name)\n\n    model = GIN(300, dataset.num_tasks).to(rank)\n    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])\n    optimizer = optim.Adam(model.parameters(), lr=0.001)\n    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)\n\n    train_dataloader = GraphDataLoader(\n        dataset[dataset.train_idx], batch_size=256, use_ddp=True, shuffle=True\n    )\n    valid_dataloader = GraphDataLoader(dataset[dataset.val_idx], batch_size=256)\n    
test_dataloader = GraphDataLoader(dataset[dataset.test_idx], batch_size=256)\n\n    for epoch in range(50):\n        model.train()\n        train_dataloader.set_epoch(epoch)\n        for batched_graph, labels in train_dataloader:\n            batched_graph, labels = batched_graph.to(rank), labels.to(rank)\n            node_feat, edge_feat = (\n                batched_graph.ndata[\"feat\"],\n                batched_graph.edata[\"feat\"],\n            )\n            logits = model(batched_graph, node_feat, edge_feat)\n            optimizer.zero_grad()\n            is_labeled = labels == labels\n            loss = F.binary_cross_entropy_with_logits(\n                logits.float()[is_labeled], labels.float()[is_labeled]\n            )\n            loss.backward()\n            optimizer.step()\n        scheduler.step()\n\n        if rank == 0:\n            val_metric = evaluate(\n                valid_dataloader, rank, model.module, evaluator\n            )[evaluator.eval_metric]\n            test_metric = evaluate(\n                test_dataloader, rank, model.module, evaluator\n            )[evaluator.eval_metric]\n\n            print(\n                f\"Epoch: {epoch:03d}, Loss: {loss:.4f}, \"\n                f\"Val: {val_metric:.4f}, Test: {test_metric:.4f}\"\n            )\n\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbg-molhiv\",\n        choices=[\"ogbg-molhiv\", \"ogbg-molpcba\"],\n        help=\"name of dataset (default: ogbg-molhiv)\",\n    )\n    dataset_name = parser.parse_args().dataset\n    root = \"./data/OGB\"\n    DglGraphPropPredDataset(dataset_name, root)\n\n    world_size = torch.cuda.device_count()\n    print(\"Let's use\", world_size, \"GPUs!\")\n    args = (world_size, dataset_name, root)\n    import torch.multiprocessing as mp\n\n    mp.spawn(train, args=args, nprocs=world_size, join=True)\n"
  },
  {
    "path": "examples/pytorch/multigpu/multi_gpu_link_prediction.py",
    "content": "import argparse\nimport os\nimport time\n\nimport dgl.function as fn\n\nimport dgl.nn as dglnn\nimport numpy as np\nimport sklearn.linear_model as lm\nimport sklearn.metrics as skm\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\nfrom dgl.data import AsNodePredDataset, RedditDataset\nfrom dgl.dataloading import (\n    as_edge_prediction_sampler,\n    DataLoader,\n    MultiLayerFullNeighborSampler,\n    NeighborSampler,\n)\nfrom dgl.multiprocessing import shared_tensor\nfrom ogb.nodeproppred import DglNodePropPredDataset\nfrom torch.nn.parallel import DistributedDataParallel\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hid_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # two-layer GraphSAGE-mean\n        self.layers.append(dglnn.SAGEConv(in_size, hid_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hid_size, out_size, \"mean\"))\n        self.dropout = nn.Dropout(0.5)\n        self.hid_size = hid_size\n        self.out_size = out_size\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = F.relu(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, device, batch_size, use_uva):\n        g.ndata[\"h\"] = g.ndata[\"feat\"]\n        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=[\"h\"])\n        for l, layer in enumerate(self.layers):\n            dataloader = DataLoader(\n                g,\n                torch.arange(g.num_nodes(), device=device),\n                sampler,\n                device=device,\n                batch_size=batch_size,\n                shuffle=False,\n                drop_last=False,\n                num_workers=0,\n                
use_ddp=True,\n                use_uva=use_uva,\n            )\n            # in order to prevent running out of GPU memory, allocate a\n            # shared output tensor 'y' in host memory\n            y = shared_tensor(\n                (\n                    g.num_nodes(),\n                    self.hid_size\n                    if l != len(self.layers) - 1\n                    else self.out_size,\n                )\n            )\n            for input_nodes, output_nodes, blocks in (\n                tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader\n            ):\n                x = blocks[0].srcdata[\"h\"]\n                h = layer(blocks[0], x)  # len(blocks) = 1\n                if l != len(self.layers) - 1:\n                    h = F.relu(h)\n                    h = self.dropout(h)\n                # non_blocking (with pinned memory) to accelerate data transfer\n                y[output_nodes] = h.to(y.device, non_blocking=True)\n            # make sure all GPUs are done writing to 'y'\n            dist.barrier()\n            g.ndata[\"h\"] = y if use_uva else y.to(device)\n\n        g.ndata.pop(\"h\")\n        return y\n\n\nclass NegativeSampler(object):\n    def __init__(self, g, k, neg_share=False, device=None):\n        if device is None:\n            device = g.device\n        self.weights = g.in_degrees().float().to(device) ** 0.75\n        self.k = k\n        self.neg_share = neg_share\n\n    def __call__(self, g, eids):\n        src, _ = g.find_edges(eids)\n        n = len(src)\n        if self.neg_share and n % self.k == 0:\n            dst = self.weights.multinomial(n, replacement=True)\n            dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()\n        else:\n            dst = self.weights.multinomial(n * self.k, replacement=True)\n        src = src.repeat_interleave(self.k)\n        return src, dst\n\n\nclass CrossEntropyLoss(nn.Module):\n    def forward(self, block_outputs, pos_graph, neg_graph):\n        with 
pos_graph.local_scope():\n            pos_graph.ndata[\"h\"] = block_outputs\n            pos_graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"score\"))\n            pos_score = pos_graph.edata[\"score\"]\n        with neg_graph.local_scope():\n            neg_graph.ndata[\"h\"] = block_outputs\n            neg_graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"score\"))\n            neg_score = neg_graph.edata[\"score\"]\n\n        score = torch.cat([pos_score, neg_score])\n        label = torch.cat(\n            [torch.ones_like(pos_score), torch.zeros_like(neg_score)]\n        ).long()\n        loss = F.binary_cross_entropy_with_logits(score, label.float())\n        return loss\n\n\ndef compute_acc_unsupervised(emb, labels, train_nids, val_nids, test_nids):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    emb = emb.cpu().numpy()\n    labels = labels.cpu().numpy()\n    train_nids = train_nids.cpu().numpy()\n    train_labels = labels[train_nids]\n    val_nids = val_nids.cpu().numpy()\n    val_labels = labels[val_nids]\n    test_nids = test_nids.cpu().numpy()\n    test_labels = labels[test_nids]\n    emb = (emb - emb.mean(0, keepdims=True)) / emb.std(0, keepdims=True)\n    lr = lm.LogisticRegression(multi_class=\"multinomial\", max_iter=10000)\n    lr.fit(emb[train_nids], train_labels)\n    pred = lr.predict(emb)\n    f1_micro_eval = skm.f1_score(val_labels, pred[val_nids], average=\"micro\")\n    f1_micro_test = skm.f1_score(test_labels, pred[test_nids], average=\"micro\")\n    return f1_micro_eval, f1_micro_test\n\n\ndef evaluate(proc_id, model, g, device, use_uva):\n    model.eval()\n    batch_size = 10000\n    with torch.no_grad():\n        pred = model.module.inference(g, device, batch_size, use_uva)\n    return pred\n\n\ndef train(\n    proc_id, nprocs, device, g, train_idx, val_idx, test_idx, model, use_uva\n):\n    # Create PyTorch DataLoader for constructing blocks\n    n_edges = g.num_edges()\n    train_seeds = 
torch.arange(n_edges).to(device)\n    labels = g.ndata[\"label\"].to(\"cpu\")\n\n    sampler = NeighborSampler([10, 25], prefetch_node_feats=[\"feat\"])\n    sampler = as_edge_prediction_sampler(\n        sampler,\n        exclude=\"reverse_id\",\n        # For each edge with ID e in Reddit dataset, the reverse edge is e ± |E|/2.\n        reverse_eids=torch.cat(\n            [torch.arange(n_edges // 2, n_edges), torch.arange(0, n_edges // 2)]\n        ).to(train_seeds),\n        # num_negs = 1, neg_share = False\n        negative_sampler=NegativeSampler(\n            g, 1, False, device if use_uva else None\n        ),\n    )\n    train_dataloader = DataLoader(\n        g,\n        train_seeds,\n        sampler,\n        device=device,\n        batch_size=10000,\n        shuffle=True,\n        drop_last=False,\n        num_workers=0,\n        use_ddp=True,\n        use_uva=use_uva,\n    )\n    opt = torch.optim.Adam(model.parameters(), lr=0.003)\n    loss_fcn = CrossEntropyLoss()\n    iter_pos = []\n    iter_neg = []\n    for epoch in range(10):\n        tic = time.time()\n        model.train()\n        for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(\n            train_dataloader\n        ):\n            x = blocks[0].srcdata[\"feat\"]\n            y_hat = model(blocks, x)\n            loss = loss_fcn(y_hat, pos_graph, neg_graph)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n\n            if step % 20 == 0 and proc_id == 0:  # log every 20 steps\n                # gpu memory reserved by PyTorch\n                gpu_mem_alloc = (\n                    torch.cuda.max_memory_allocated() / 1000000\n                    if torch.cuda.is_available()\n                    else 0\n                )\n                print(\n                    f\"Epoch {epoch:05d} | Step {step:05d} | Loss {loss.item():.4f} | GPU {gpu_mem_alloc:.1f} MB\"\n                )\n\n        t = time.time() - tic\n        if proc_id == 0:\n       
     print(f\"Epoch Time(s): {t:.4f}\")\n        if (epoch + 1) % 5 == 0:  # eval every 5 epochs\n            pred = evaluate(proc_id, model, g, device, use_uva)  # in parallel\n            if proc_id == 0:\n                # only master proc does the accuracy computation\n                eval_acc, test_acc = compute_acc_unsupervised(\n                    pred, labels, train_idx, val_idx, test_idx\n                )\n                print(\n                    f\"Epoch {epoch:05d} | Eval F1-score {eval_acc:.4f} | Test F1-Score {test_acc:.4f}\"\n                )\n\n\ndef run(proc_id, nprocs, devices, g, data, mode):\n    # find corresponding device for my rank\n    device = devices[proc_id]\n    torch.cuda.set_device(device)\n    # initialize process group and unpack data for sub-processes\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=\"tcp://127.0.0.1:12345\",\n        world_size=nprocs,\n        rank=proc_id,\n    )\n    out_size, train_idx, val_idx, test_idx = data\n    g = g.to(device if mode == \"puregpu\" else \"cpu\")\n    # create GraphSAGE model (distributed)\n    in_size = g.ndata[\"feat\"].shape[1]\n    model = SAGE(in_size, 16, 16).to(device)\n    model = DistributedDataParallel(\n        model, device_ids=[device], output_device=device\n    )\n    # training + testing\n    use_uva = mode == \"mixed\"\n    train(\n        proc_id, nprocs, device, g, train_idx, val_idx, test_idx, model, use_uva\n    )\n    # cleanup process group\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbn-products\",\n        choices=[\"ogbn-products\", \"reddit\"],\n        help=\"name of dataset (default: ogbn-products)\",\n    )\n    parser.add_argument(\n        \"--mode\",\n        default=\"mixed\",\n        choices=[\"mixed\", \"puregpu\"],\n        help=\"Training mode. 
'mixed' for CPU-GPU mixed training, \"\n        \"'puregpu' for pure-GPU training.\",\n    )\n    parser.add_argument(\n        \"--gpu\",\n        type=str,\n        default=\"0\",\n        help=\"GPU(s) in use. Can be a list of gpu ids for multi-gpu training,\"\n        \" e.g., 0,1,2,3.\",\n    )\n    args = parser.parse_args()\n    devices = list(map(int, args.gpu.split(\",\")))\n    nprocs = len(devices)\n    assert (\n        torch.cuda.is_available()\n    ), f\"Must have GPUs to enable multi-gpu training.\"\n    print(f\"Training in {args.mode} mode using {nprocs} GPU(s)\")\n\n    # load and preprocess dataset\n    print(\"Loading data\")\n    if args.dataset == \"ogbn-products\":\n        # can it be AsLinkPredDataset?\n        dataset = AsNodePredDataset(DglNodePropPredDataset(\"ogbn-products\"))\n    elif args.dataset == \"reddit\":\n        dataset = AsNodePredDataset(RedditDataset(self_loop=False))\n\n    g = dataset[0]\n    # avoid creating certain graph formats in each sub-process to save momory\n    g.create_formats_()\n    # thread limiting to avoid resource competition\n    os.environ[\"OMP_NUM_THREADS\"] = str(mp.cpu_count() // 2 // nprocs)\n    data = (\n        dataset.num_classes,\n        dataset.train_idx,\n        dataset.val_idx,\n        dataset.test_idx,\n    )\n\n    mp.spawn(run, args=(nprocs, devices, g, data, args.mode), nprocs=nprocs)\n"
  },
  {
    "path": "examples/pytorch/multigpu/multi_gpu_node_classification.py",
    "content": "import argparse\nimport os\n\nimport dgl.nn as dglnn\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nimport tqdm\nfrom dgl.data import AsNodePredDataset\nfrom dgl.dataloading import (\n    DataLoader,\n    MultiLayerFullNeighborSampler,\n    NeighborSampler,\n)\nfrom dgl.multiprocessing import shared_tensor\nfrom ogb.nodeproppred import DglNodePropPredDataset\nfrom torch.nn.parallel import DistributedDataParallel\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hid_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # three-layer GraphSAGE-mean\n        self.layers.append(dglnn.SAGEConv(in_size, hid_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(hid_size, out_size, \"mean\"))\n        self.dropout = nn.Dropout(0.5)\n        self.hid_size = hid_size\n        self.out_size = out_size\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            if l != len(self.layers) - 1:\n                h = F.relu(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, device, batch_size, use_uva):\n        g.ndata[\"h\"] = g.ndata[\"feat\"]\n        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=[\"h\"])\n        for l, layer in enumerate(self.layers):\n            dataloader = DataLoader(\n                g,\n                torch.arange(g.num_nodes(), device=device),\n                sampler,\n                device=device,\n                batch_size=batch_size,\n                shuffle=False,\n                drop_last=False,\n                num_workers=0,\n                use_ddp=True,\n                use_uva=use_uva,\n            )\n   
         # in order to prevent running out of GPU memory, allocate a\n            # shared output tensor 'y' in host memory\n            y = shared_tensor(\n                (\n                    g.num_nodes(),\n                    self.hid_size\n                    if l != len(self.layers) - 1\n                    else self.out_size,\n                )\n            )\n            for input_nodes, output_nodes, blocks in (\n                tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader\n            ):\n                x = blocks[0].srcdata[\"h\"]\n                h = layer(blocks[0], x)  # len(blocks) = 1\n                if l != len(self.layers) - 1:\n                    h = F.relu(h)\n                    h = self.dropout(h)\n                # non_blocking (with pinned memory) to accelerate data transfer\n                y[output_nodes] = h.to(y.device, non_blocking=True)\n            # make sure all GPUs are done writing to 'y'\n            dist.barrier()\n            g.ndata[\"h\"] = y if use_uva else y.to(device)\n\n        g.ndata.pop(\"h\")\n        return y\n\n\ndef evaluate(model, g, num_classes, dataloader):\n    model.eval()\n    ys = []\n    y_hats = []\n    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):\n        with torch.no_grad():\n            x = blocks[0].srcdata[\"feat\"]\n            ys.append(blocks[-1].dstdata[\"label\"])\n            y_hats.append(model(blocks, x))\n    return MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(ys),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\ndef layerwise_infer(\n    proc_id, device, g, num_classes, nid, model, use_uva, batch_size=2**16\n):\n    model.eval()\n    with torch.no_grad():\n        pred = model.module.inference(g, device, batch_size, use_uva)\n        pred = pred[nid]\n        labels = g.ndata[\"label\"][nid].to(pred.device)\n    if proc_id == 0:\n        acc = MF.accuracy(\n            pred, labels, task=\"multiclass\", 
num_classes=num_classes\n        )\n        print(\"Test Accuracy {:.4f}\".format(acc.item()))\n\n\ndef train(\n    proc_id, nprocs, device, g, num_classes, train_idx, val_idx, model, use_uva\n):\n    sampler = NeighborSampler(\n        [10, 10, 10], prefetch_node_feats=[\"feat\"], prefetch_labels=[\"label\"]\n    )\n    train_dataloader = DataLoader(\n        g,\n        train_idx,\n        sampler,\n        device=device,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=0,\n        use_ddp=True,\n        use_uva=use_uva,\n    )\n    val_dataloader = DataLoader(\n        g,\n        val_idx,\n        sampler,\n        device=device,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        num_workers=0,\n        use_ddp=True,\n        use_uva=use_uva,\n    )\n    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)\n    for epoch in range(10):\n        model.train()\n        total_loss = 0\n        for it, (input_nodes, output_nodes, blocks) in enumerate(\n            train_dataloader\n        ):\n            x = blocks[0].srcdata[\"feat\"]\n            y = blocks[-1].dstdata[\"label\"]\n            y_hat = model(blocks, x)\n            loss = F.cross_entropy(y_hat, y)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n            total_loss += loss\n        acc = (\n            evaluate(model, g, num_classes, val_dataloader).to(device) / nprocs\n        )\n        dist.reduce(acc, 0)\n        if proc_id == 0:\n            print(\n                \"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} \".format(\n                    epoch, total_loss / (it + 1), acc.item()\n                )\n            )\n\n\ndef run(proc_id, nprocs, devices, g, data, mode):\n    # find corresponding device for my rank\n    device = devices[proc_id]\n    torch.cuda.set_device(device)\n    # initialize process group and unpack data for sub-processes\n    
dist.init_process_group(\n        backend=\"nccl\",\n        init_method=\"tcp://127.0.0.1:12345\",\n        world_size=nprocs,\n        rank=proc_id,\n    )\n    num_classes, train_idx, val_idx, test_idx = data\n    train_idx = train_idx.to(device)\n    val_idx = val_idx.to(device)\n    g = g.to(device if mode == \"puregpu\" else \"cpu\")\n    # create GraphSAGE model (distributed)\n    in_size = g.ndata[\"feat\"].shape[1]\n    model = SAGE(in_size, 256, num_classes).to(device)\n    model = DistributedDataParallel(\n        model, device_ids=[device], output_device=device\n    )\n    # training + testing\n    use_uva = mode == \"mixed\"\n    train(\n        proc_id,\n        nprocs,\n        device,\n        g,\n        num_classes,\n        train_idx,\n        val_idx,\n        model,\n        use_uva,\n    )\n    layerwise_infer(proc_id, device, g, num_classes, test_idx, model, use_uva)\n    # cleanup process group\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--mode\",\n        default=\"mixed\",\n        choices=[\"mixed\", \"puregpu\"],\n        help=\"Training mode. 'mixed' for CPU-GPU mixed training, \"\n        \"'puregpu' for pure-GPU training.\",\n    )\n    parser.add_argument(\n        \"--gpu\",\n        type=str,\n        default=\"0\",\n        help=\"GPU(s) in use. 
Can be a list of gpu ids for multi-gpu training,\"\n        \" e.g., 0,1,2,3.\",\n    )\n    args = parser.parse_args()\n    devices = list(map(int, args.gpu.split(\",\")))\n    nprocs = len(devices)\n    assert (\n        torch.cuda.is_available()\n    ), f\"Must have GPUs to enable multi-gpu training.\"\n    print(f\"Training in {args.mode} mode using {nprocs} GPU(s)\")\n\n    # load and preprocess dataset\n    print(\"Loading data\")\n    dataset = AsNodePredDataset(DglNodePropPredDataset(\"ogbn-products\"))\n    g = dataset[0]\n    # avoid creating certain graph formats in each sub-process to save memory\n    g.create_formats_()\n    # thread limiting to avoid resource competition\n    os.environ[\"OMP_NUM_THREADS\"] = str(mp.cpu_count() // 2 // nprocs)\n    data = (\n        dataset.num_classes,\n        dataset.train_idx,\n        dataset.val_idx,\n        dataset.test_idx,\n    )\n\n    mp.spawn(run, args=(nprocs, devices, g, data, args.mode), nprocs=nprocs)\n"
  },
  {
    "path": "examples/pytorch/mvgrl/README.md",
    "content": "# DGL Implementation of MVGRL\nThis DGL example implements the model proposed in the paper [Contrastive Multi-View Representation Learning on Graphs](https://arxiv.org/abs/2006.05582).\n\nAuthor's code: https://github.com/kavehhassani/mvgrl\n\n## Example Implementor\n\nThis example was implemented by [Hengrui Zhang](https://github.com/hengruizhang98) when he was an applied scientist intern at AWS Shanghai AI Lab.\n\n## Dependencies\n\n- Python 3.7\n- PyTorch 1.7.1\n- dgl 0.6.0\n- networkx\n- scipy\n\n## Datasets\n\n##### Unsupervised Graph Classification Datasets:\n\n 'MUTAG', 'PTC_MR', 'REDDIT-BINARY', 'IMDB-BINARY', 'IMDB-MULTI'.\n\n| Dataset         | MUTAG | PTC_MR | RDT-B  | IMDB-B | IMDB-M |\n| --------------- | ----- | ------ | ------ | ------ | ------ |\n| # Graphs        | 188   | 344    | 2000   | 1000   | 1500   |\n| # Classes       | 2     | 2      | 2      | 2      | 3      |\n| Avg. Graph Size | 17.93 | 14.29  | 429.63 | 19.77  | 13.00  |\n* RDT-B, IMDB-B, IMDB-M are short for REDDIT-BINARY, IMDB-BINARY and IMDB-MULTI respectively.\n\n##### Unsupervised Node Classification Datasets:\n\n'Cora', 'Citeseer' and 'Pubmed'\n\n| Dataset  | # Nodes | # Edges | # Classes |\n| -------- | ------- | ------- | --------- |\n| Cora     | 2,708   | 10,556  | 7         |\n| Citeseer | 3,327   | 9,228   | 6         |\n| Pubmed   | 19,717  | 88,651  | 3         |\n\n\n## Arguments\n\n##### \tGraph Classification:\n\n```\n--dataname         str     The graph dataset name.                Default is 'MUTAG'.\n--gpu              int     GPU index.                             Default is -1, using cpu.\n--epochs           int     Number of training periods.            Default is 200.\n--patience         int     Early stopping steps.                  Default is 20.\n--lr               float   Learning rate.                         Default is 0.001.\n--wd               float   Weight decay.                          
Default is 0.0.\n--batch_size       int     Size of a training batch.              Default is 64.\n--n_layers         int     Number of GNN layers.                  Default is 4.\n--hid_dim          int     Embedding dimension.                   Default is 32.\n```\n\n##### \tNode Classification:\n\n```\n--dataname         str     The graph dataset name.                Default is 'cora'.\n--gpu              int     GPU index.                             Default is -1, using cpu.\n--epochs           int     Number of training periods.            Default is 500.\n--patience         int     Early stopping steps.                  Default is 20.\n--lr1              float   Learning rate of main model.           Default is 0.001.\n--lr2              float   Learning rate of linear classifer.     Default is 0.01.\n--wd1              float   Weight decay of main model.            Default is 0.0.\n--wd2              float   Weight decay of linear classifier.     Default is 0.0.\n--epsilon          float   Edge mask threshold.                   Default is 0.01.\n--hid_dim          int     Embedding dimension.                   Default is 512.\n--sample_size      int     Subgraph size.                         Default is 2000.\n```\n\n## How to run examples\n\n###### Graph Classification\n\n```python\n# Enter the 'graph' directory\ncd graph\n\n# MUTAG:\npython main.py --dataname MUTAG --epochs 20\n\n# PTC_MR:\npython main.py --dataname PTC_MR --epochs 32 --hid_dim 128\n\n# REDDIT-BINARY\npython main.py --dataname REDDIT-BINARY --epochs 20 --hid_dim 128\n\n# IMDB-BINARY\npython main.py --dataname IMDB-BINARY --epochs 20 --hid_dim 512 --n_layers 2\n\n# IMDB-MULTI\npython main.py --dataname IMDB-MULTI --epochs 20 --hid_dim 512 --n_layers 2\n```\n###### Node Classification\n\nFor semi-supervised node classification on 'Cora', 'Citeseer' and 'Pubmed', we provide two implementations:\n\n1. 
full-graph training, see 'main.py', where we contrast the local and global representations of the whole graph.\n2. subgraph training, see 'main_sample.py', where we contrast the local and global representations of a sampled subgraph with fixed number of nodes.\n\nFor larger graphs(e.g. Pubmed), it would be hard to calculate the graph diffusion matrix(i.e., PPR matrix), so we try to approximate it with [APPNP](https://arxiv.org/abs/1810.05997), see function 'process_dataset_appnp'  in 'node/dataset.py' for details.\n\n```python\n# Enter the 'node' directory\ncd node\n\n# Cora with full graph\npython main.py --dataname cora --gpu 0\n\n# Cora with sampled subgraphs\npython main_sample.py --dataname cora --gpu 0\n\n# Citeseer with full graph\npython main.py --dataname citeseer --wd1 0.001 --wd2 0.01 --epochs 200 --gpu 0\n\n# Citeseer with sampled subgraphs\npython main_sample.py --dataname citeseer --wd2 0.01 --gpu 0\n\n# Pubmed with sampled subgraphs\npython main_sample.py --dataname pubmed --sample_size 4000 --epochs 400 --patience 999 --gpu 0\n```\n\n## \tPerformance\n\nWe use the same  hyper-parameter settings as stated in the original paper.\n\n##### Graph classification:\n\n|      Dataset      | MUTAG | PTC-MR | REDDIT-B | IMDB-B | IMDB-M |\n| :---------------: | :---: | :----: | :------: | :----: | :----: |\n| Accuracy Reported | 89.7  |  62.5  |   84.5   |  74.2  |  51.2  |\n|        DGL        | 89.4  |  62.2  |   85.0   |  73.8  |  51.1  |\n\n* The datasets that the authors used are slightly different from standard TUDataset (see dgl.data.GINDataset) in the nodes' features(e.g. 
The node features of 'MUTAG' dataset are of dimensionality 11 rather than 7).\n\n##### Node classification:\n\n|      Dataset      | Cora | Citeseer | Pubmed |\n| :---------------: | :--: | :------: | :----: |\n| Accuracy Reported | 86.8 |   73.3   |  80.1  |\n|    DGL-sample     | 83.2 |   72.6   |  79.8  |\n|     DGL-full      | 83.5 |   73.7   |  OOM   |\n\n* We fail to reproduce the reported accuracy on 'Cora', even with the authors' code.\n* The accuracy reported by the original paper is based on fixed-sized subgraph-training.\n"
  },
  {
    "path": "examples/pytorch/mvgrl/graph/dataset.py",
    "content": "\"\"\" Code adapted from https://github.com/kavehhassani/mvgrl \"\"\"\nimport os\nimport re\nfrom collections import Counter\n\nimport dgl\n\nimport networkx as nx\nimport numpy as np\nimport torch as th\nfrom dgl.data import DGLDataset\nfrom scipy.linalg import fractional_matrix_power, inv\n\n\"\"\" Compute Personalized Page Ranking\"\"\"\n\n\ndef compute_ppr(graph: nx.Graph, alpha=0.2, self_loop=True):\n    a = nx.convert_matrix.to_numpy_array(graph)\n    if self_loop:\n        a = a + np.eye(a.shape[0])  # A^ = A + I_n\n    d = np.diag(np.sum(a, 1))  # D^ = Sigma A^_ii\n    dinv = fractional_matrix_power(d, -0.5)  # D^(-1/2)\n    at = np.matmul(np.matmul(dinv, a), dinv)  # A~ = D^(-1/2) x A^ x D^(-1/2)\n    return alpha * inv(\n        (np.eye(a.shape[0]) - (1 - alpha) * at)\n    )  # a(I_n-(1-a)A~)^-1\n\n\ndef download(dataset, datadir):\n    os.makedirs(datadir)\n    url = \"https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/{0}.zip\".format(\n        dataset\n    )\n    zipfile = os.path.basename(url)\n    os.system(\"wget {0}; unzip {1}\".format(url, zipfile))\n    os.system(\"mv {0}/* {1}\".format(dataset, datadir))\n    os.system(\"rm -r {0}\".format(dataset))\n    os.system(\"rm {0}\".format(zipfile))\n\n\ndef process(dataset):\n    src = os.path.join(os.path.dirname(__file__), \"data\")\n    prefix = os.path.join(src, dataset, dataset)\n\n    # assign each node to the corresponding graph\n    graph_node_dict = {}\n    with open(\"{0}_graph_indicator.txt\".format(prefix), \"r\") as f:\n        for idx, line in enumerate(f):\n            graph_node_dict[idx + 1] = int(line.strip(\"\\n\"))\n\n    node_labels = []\n    if os.path.exists(\"{0}_node_labels.txt\".format(prefix)):\n        with open(\"{0}_node_labels.txt\".format(prefix), \"r\") as f:\n            for line in f:\n                node_labels += [int(line.strip(\"\\n\")) - 1]\n            num_unique_node_labels = max(node_labels) + 1\n    else:\n        print(\"No 
node labels\")\n\n    node_attrs = []\n    if os.path.exists(\"{0}_node_attributes.txt\".format(prefix)):\n        with open(\"{0}_node_attributes.txt\".format(prefix), \"r\") as f:\n            for line in f:\n                node_attrs.append(\n                    np.array(\n                        [\n                            float(attr)\n                            for attr in re.split(\"[,\\s]+\", line.strip(\"\\s\\n\"))\n                            if attr\n                        ],\n                        dtype=float,\n                    )\n                )\n    else:\n        print(\"No node attributes\")\n\n    graph_labels = []\n    unique_labels = set()\n    with open(\"{0}_graph_labels.txt\".format(prefix), \"r\") as f:\n        for line in f:\n            val = int(line.strip(\"\\n\"))\n            if val not in unique_labels:\n                unique_labels.add(val)\n            graph_labels.append(val)\n    label_idx_dict = {val: idx for idx, val in enumerate(unique_labels)}\n    graph_labels = np.array([label_idx_dict[l] for l in graph_labels])\n\n    adj_list = {idx: [] for idx in range(1, len(graph_labels) + 1)}\n    index_graph = {idx: [] for idx in range(1, len(graph_labels) + 1)}\n    with open(\"{0}_A.txt\".format(prefix), \"r\") as f:\n        for line in f:\n            u, v = tuple(map(int, line.strip(\"\\n\").split(\",\")))\n            adj_list[graph_node_dict[u]].append((u, v))\n            index_graph[graph_node_dict[u]] += [u, v]\n\n    for k in index_graph.keys():\n        index_graph[k] = [u - 1 for u in set(index_graph[k])]\n\n    graphs, pprs = [], []\n    for idx in range(1, 1 + len(adj_list)):\n        graph = nx.from_edgelist(adj_list[idx])\n\n        graph.graph[\"label\"] = graph_labels[idx - 1]\n        for u in graph.nodes():\n            if len(node_labels) > 0:\n                node_label_one_hot = [0] * num_unique_node_labels\n                node_label = node_labels[u - 1]\n                
node_label_one_hot[node_label] = 1\n                graph.nodes[u][\"label\"] = node_label_one_hot\n            if len(node_attrs) > 0:\n                graph.nodes[u][\"feat\"] = node_attrs[u - 1]\n        if len(node_attrs) > 0:\n            graph.graph[\"feat_dim\"] = node_attrs[0].shape[0]\n\n        # relabeling\n        mapping = {}\n        for node_idx, node in enumerate(graph.nodes()):\n            mapping[node] = node_idx\n\n        graphs.append(nx.relabel_nodes(graph, mapping))\n        pprs.append(compute_ppr(graph, alpha=0.2))\n\n    if \"feat_dim\" in graphs[0].graph:\n        pass\n    else:\n        max_deg = max([max(dict(graph.degree).values()) for graph in graphs])\n        for graph in graphs:\n            for u in graph.nodes(data=True):\n                f = np.zeros(max_deg + 1)\n                f[graph.degree[u[0]]] = 1.0\n                if \"label\" in u[1]:\n                    f = np.concatenate(\n                        (np.array(u[1][\"label\"], dtype=float), f)\n                    )\n                graph.nodes[u[0]][\"feat\"] = f\n    return graphs, pprs\n\n\ndef load(dataset):\n    basedir = os.path.dirname(os.path.abspath(__file__))\n    datadir = os.path.join(basedir, \"data\", dataset)\n\n    if not os.path.exists(datadir):\n        download(dataset, datadir)\n        graphs, diff = process(dataset)\n        feat, adj, labels = [], [], []\n\n        for idx, graph in enumerate(graphs):\n            adj.append(nx.to_numpy_array(graph))\n            labels.append(graph.graph[\"label\"])\n            feat.append(\n                np.array(list(nx.get_node_attributes(graph, \"feat\").values()))\n            )\n\n        adj, diff, feat, labels = (\n            np.array(adj),\n            np.array(diff),\n            np.array(feat),\n            np.array(labels),\n        )\n\n        np.save(f\"{datadir}/adj.npy\", adj)\n        np.save(f\"{datadir}/diff.npy\", diff)\n        np.save(f\"{datadir}/feat.npy\", feat)\n        
np.save(f\"{datadir}/labels.npy\", labels)\n    else:\n        adj = np.load(f\"{datadir}/adj.npy\", allow_pickle=True)\n        diff = np.load(f\"{datadir}/diff.npy\", allow_pickle=True)\n        feat = np.load(f\"{datadir}/feat.npy\", allow_pickle=True)\n        labels = np.load(f\"{datadir}/labels.npy\", allow_pickle=True)\n\n    n_graphs = adj.shape[0]\n\n    graphs = []\n    diff_graphs = []\n    lbls = []\n\n    for i in range(n_graphs):\n        a = adj[i]\n        edge_indexes = a.nonzero()\n\n        graph = dgl.graph(edge_indexes)\n        graph = graph.add_self_loop()\n        graph.ndata[\"feat\"] = th.tensor(feat[i]).float()\n\n        diff_adj = diff[i]\n        diff_indexes = diff_adj.nonzero()\n        diff_weight = th.tensor(diff_adj[diff_indexes]).float()\n\n        diff_graph = dgl.graph(diff_indexes)\n        diff_graph.edata[\"edge_weight\"] = diff_weight\n        label = labels[i]\n        graphs.append(graph)\n        diff_graphs.append(diff_graph)\n        lbls.append(label)\n\n    labels = th.tensor(lbls)\n\n    dataset = TUDataset(graphs, diff_graphs, labels)\n    return dataset\n\n\nclass TUDataset(DGLDataset):\n    def __init__(self, graphs, diff_graphs, labels):\n        super(TUDataset, self).__init__(name=\"tu\")\n        self.graphs = graphs\n        self.diff_graphs = diff_graphs\n        self.labels = labels\n\n    def process(self):\n        return\n\n    def __len__(self):\n        return len(self.graphs)\n\n    def __getitem__(self, idx):\n        return self.graphs[idx], self.diff_graphs[idx], self.labels[idx]\n"
  },
  {
    "path": "examples/pytorch/mvgrl/graph/main.py",
    "content": "import argparse\nimport warnings\n\nimport dgl\n\nimport torch as th\nfrom dataset import load\nfrom dgl.dataloading import GraphDataLoader\n\nwarnings.filterwarnings(\"ignore\")\n\nfrom model import MVGRL\nfrom utils import linearsvc\n\nparser = argparse.ArgumentParser(description=\"mvgrl\")\n\nparser.add_argument(\n    \"--dataname\", type=str, default=\"MUTAG\", help=\"Name of dataset.\"\n)\nparser.add_argument(\n    \"--gpu\", type=int, default=-1, help=\"GPU index. Default: -1, using cpu.\"\n)\nparser.add_argument(\n    \"--epochs\", type=int, default=200, help=\" Number of training periods.\"\n)\nparser.add_argument(\n    \"--patience\", type=int, default=20, help=\"Early stopping steps.\"\n)\nparser.add_argument(\n    \"--lr\", type=float, default=0.001, help=\"Learning rate of mvgrl.\"\n)\nparser.add_argument(\n    \"--wd\", type=float, default=0.0, help=\"Weight decay of mvgrl.\"\n)\nparser.add_argument(\"--batch_size\", type=int, default=64, help=\"Batch size.\")\nparser.add_argument(\n    \"--n_layers\", type=int, default=4, help=\"Number of GNN layers.\"\n)\nparser.add_argument(\"--hid_dim\", type=int, default=32, help=\"Hidden layer dim.\")\n\nargs = parser.parse_args()\n\n# check cuda\nif args.gpu != -1 and th.cuda.is_available():\n    args.device = \"cuda:{}\".format(args.gpu)\nelse:\n    args.device = \"cpu\"\n\n\ndef collate(samples):\n    \"\"\"collate function for building the graph dataloader\"\"\"\n    graphs, diff_graphs, labels = map(list, zip(*samples))\n\n    # generate batched graphs and labels\n    batched_graph = dgl.batch(graphs)\n    batched_labels = th.tensor(labels)\n    batched_diff_graph = dgl.batch(diff_graphs)\n\n    n_graphs = len(graphs)\n    graph_id = th.arange(n_graphs)\n    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)\n\n    batched_graph.ndata[\"graph_id\"] = graph_id\n\n    return batched_graph, batched_diff_graph, batched_labels\n\n\nif __name__ == \"__main__\":\n    # Step 1: Prepare data 
=================================================================== #\n    dataset = load(args.dataname)\n\n    graphs, diff_graphs, labels = map(list, zip(*dataset))\n    print(\"Number of graphs:\", len(graphs))\n    # generate a full-graph with all examples for evaluation\n\n    wholegraph = dgl.batch(graphs)\n    whole_dg = dgl.batch(diff_graphs)\n\n    # create dataloader for batch training\n    dataloader = GraphDataLoader(\n        dataset,\n        batch_size=args.batch_size,\n        collate_fn=collate,\n        drop_last=False,\n        shuffle=True,\n    )\n\n    in_dim = wholegraph.ndata[\"feat\"].shape[1]\n\n    # Step 2: Create model =================================================================== #\n    model = MVGRL(in_dim, args.hid_dim, args.n_layers)\n    model = model.to(args.device)\n\n    # Step 3: Create training components ===================================================== #\n    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)\n\n    print(\"===== Before training ======\")\n\n    wholegraph = wholegraph.to(args.device)\n    whole_dg = whole_dg.to(args.device)\n    wholefeat = wholegraph.ndata.pop(\"feat\")\n    whole_weight = whole_dg.edata.pop(\"edge_weight\")\n\n    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)\n    lbls = th.LongTensor(labels)\n    acc_mean, acc_std = linearsvc(embs, lbls)\n    print(\"accuracy_mean, {:.4f}\".format(acc_mean))\n\n    best = float(\"inf\")\n    cnt_wait = 0\n    # Step 4: Training epochs =============================================================== #\n    for epoch in range(args.epochs):\n        loss_all = 0\n        model.train()\n\n        for graph, diff_graph, label in dataloader:\n            graph = graph.to(args.device)\n            diff_graph = diff_graph.to(args.device)\n\n            feat = graph.ndata[\"feat\"]\n            graph_id = graph.ndata[\"graph_id\"]\n            edge_weight = diff_graph.edata[\"edge_weight\"]\n            n_graph = 
label.shape[0]\n\n            optimizer.zero_grad()\n            loss = model(graph, diff_graph, feat, edge_weight, graph_id)\n            loss_all += loss.item()\n            loss.backward()\n            optimizer.step()\n\n        print(\"Epoch {}, Loss {:.4f}\".format(epoch, loss_all))\n\n        if loss_all < best:\n            best = loss_all\n            best_t = epoch\n            cnt_wait = 0\n            th.save(model.state_dict(), f\"{args.dataname}.pkl\")\n        else:\n            cnt_wait += 1\n\n        if cnt_wait == args.patience:\n            print(\"Early stopping\")\n            break\n\n    print(\"Training End\")\n\n    # Step 5:  Linear evaluation ========================================================== #\n    model.load_state_dict(th.load(f\"{args.dataname}.pkl\"))\n    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)\n\n    acc_mean, acc_std = linearsvc(embs, lbls)\n    print(\"accuracy_mean, {:.4f}\".format(acc_mean))\n"
  },
  {
    "path": "examples/pytorch/mvgrl/graph/model.py",
    "content": "import torch as th\nimport torch.nn as nn\n\nfrom dgl.nn.pytorch import GraphConv\nfrom dgl.nn.pytorch.glob import SumPooling\nfrom utils import local_global_loss_\n\n\nclass MLP(nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super(MLP, self).__init__()\n        self.fcs = nn.Sequential(\n            nn.Linear(in_dim, out_dim),\n            nn.PReLU(),\n            nn.Linear(out_dim, out_dim),\n            nn.PReLU(),\n            nn.Linear(out_dim, out_dim),\n            nn.PReLU(),\n        )\n        self.linear_shortcut = nn.Linear(in_dim, out_dim)\n\n    def forward(self, x):\n        return self.fcs(x) + self.linear_shortcut(x)\n\n\nclass GCN(nn.Module):\n    def __init__(self, in_dim, out_dim, num_layers, norm):\n        super(GCN, self).__init__()\n\n        self.num_layers = num_layers\n        self.layers = nn.ModuleList()\n\n        self.layers.append(\n            GraphConv(\n                in_dim, out_dim, bias=False, norm=norm, activation=nn.PReLU()\n            )\n        )\n        self.pooling = SumPooling()\n\n        for _ in range(num_layers - 1):\n            self.layers.append(\n                GraphConv(\n                    out_dim,\n                    out_dim,\n                    bias=False,\n                    norm=norm,\n                    activation=nn.PReLU(),\n                )\n            )\n\n    def forward(self, graph, feat, edge_weight=None):\n        h = self.layers[0](graph, feat, edge_weight=edge_weight)\n        hg = self.pooling(graph, h)\n\n        for idx in range(self.num_layers - 1):\n            h = self.layers[idx + 1](graph, h, edge_weight=edge_weight)\n            hg = th.cat((hg, self.pooling(graph, h)), -1)\n\n        return h, hg\n\n\nclass MVGRL(nn.Module):\n    r\"\"\"\n        mvgrl model\n    Parameters\n    -----------\n    in_dim: int\n        Input feature size.\n    out_dim: int\n        Output feature size.\n    num_layers: int\n        Number of the GNN encoder 
layers.\n    Functions\n    -----------\n    forward(graph1, graph2, feat, edge_weight):\n        graph1: DGLGraph\n            The original graph\n        graph2: DGLGraph\n            The diffusion graph\n        feat: tensor\n            Node features\n        edge_weight: tensor\n            Edge weight of the diffusion graph\n    \"\"\"\n\n    def __init__(self, in_dim, out_dim, num_layers):\n        super(MVGRL, self).__init__()\n        self.local_mlp = MLP(out_dim, out_dim)\n        self.global_mlp = MLP(num_layers * out_dim, out_dim)\n        self.encoder1 = GCN(in_dim, out_dim, num_layers, norm=\"both\")\n        self.encoder2 = GCN(in_dim, out_dim, num_layers, norm=\"none\")\n\n    def get_embedding(self, graph1, graph2, feat, edge_weight):\n        local_v1, global_v1 = self.encoder1(graph1, feat)\n        local_v2, global_v2 = self.encoder2(\n            graph2, feat, edge_weight=edge_weight\n        )\n\n        global_v1 = self.global_mlp(global_v1)\n        global_v2 = self.global_mlp(global_v2)\n\n        return (global_v1 + global_v2).detach()\n\n    def forward(self, graph1, graph2, feat, edge_weight, graph_id):\n        # calculate node embeddings and graph embeddings\n        local_v1, global_v1 = self.encoder1(graph1, feat)\n        local_v2, global_v2 = self.encoder2(\n            graph2, feat, edge_weight=edge_weight\n        )\n\n        local_v1 = self.local_mlp(local_v1)\n        local_v2 = self.local_mlp(local_v2)\n\n        global_v1 = self.global_mlp(global_v1)\n        global_v2 = self.global_mlp(global_v2)\n\n        # calculate loss\n        loss1 = local_global_loss_(local_v1, global_v2, graph_id)\n        loss2 = local_global_loss_(local_v2, global_v1, graph_id)\n\n        loss = loss1 + loss2\n\n        return loss\n"
  },
  {
    "path": "examples/pytorch/mvgrl/graph/utils.py",
    "content": "\"\"\" Code adapted from https://github.com/fanyun-sun/InfoGraph \"\"\"\nimport math\n\nimport numpy as np\nimport torch as th\nimport torch.nn.functional as F\nfrom sklearn.metrics import accuracy_score\nfrom sklearn.model_selection import GridSearchCV, StratifiedKFold\nfrom sklearn.svm import LinearSVC\n\n\ndef linearsvc(embeds, labels):\n    x = embeds.cpu().numpy()\n    y = labels.cpu().numpy()\n    params = {\"C\": [0.001, 0.01, 0.1, 1, 10, 100, 1000]}\n    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)\n    accuracies = []\n    for train_index, test_index in kf.split(x, y):\n        x_train, x_test = x[train_index], x[test_index]\n        y_train, y_test = y[train_index], y[test_index]\n        classifier = GridSearchCV(\n            LinearSVC(), params, cv=5, scoring=\"accuracy\", verbose=0\n        )\n        classifier.fit(x_train, y_train)\n        accuracies.append(accuracy_score(y_test, classifier.predict(x_test)))\n    return np.mean(accuracies), np.std(accuracies)\n\n\ndef get_positive_expectation(p_samples, average=True):\n    \"\"\"Computes the positive part of a JS Divergence.\n    Args:\n        p_samples: Positive samples.\n        average: Average the result over samples.\n    Returns:\n        th.Tensor\n    \"\"\"\n    log_2 = math.log(2.0)\n    Ep = log_2 - F.softplus(-p_samples)\n\n    if average:\n        return Ep.mean()\n    else:\n        return Ep\n\n\ndef get_negative_expectation(q_samples, average=True):\n    \"\"\"Computes the negative part of a JS Divergence.\n    Args:\n        q_samples: Negative samples.\n        average: Average the result over samples.\n    Returns:\n        th.Tensor\n    \"\"\"\n    log_2 = math.log(2.0)\n    Eq = F.softplus(-q_samples) + q_samples - log_2\n\n    if average:\n        return Eq.mean()\n    else:\n        return Eq\n\n\ndef local_global_loss_(l_enc, g_enc, graph_id):\n    num_graphs = g_enc.shape[0]\n    num_nodes = l_enc.shape[0]\n\n    device = 
g_enc.device\n\n    pos_mask = th.zeros((num_nodes, num_graphs)).to(device)\n    neg_mask = th.ones((num_nodes, num_graphs)).to(device)\n\n    for nodeidx, graphidx in enumerate(graph_id):\n        pos_mask[nodeidx][graphidx] = 1.0\n        neg_mask[nodeidx][graphidx] = 0.0\n\n    res = th.mm(l_enc, g_enc.t())\n\n    E_pos = get_positive_expectation(res * pos_mask, average=False).sum()\n    E_pos = E_pos / num_nodes\n    E_neg = get_negative_expectation(res * neg_mask, average=False).sum()\n    E_neg = E_neg / (num_nodes * (num_graphs - 1))\n\n    return E_neg - E_pos\n"
  },
  {
    "path": "examples/pytorch/mvgrl/node/dataset.py",
    "content": "\"\"\" Code adapted from https://github.com/kavehhassani/mvgrl \"\"\"\nimport dgl\nimport networkx as nx\nimport numpy as np\nimport scipy.sparse as sp\nimport torch as th\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\nfrom dgl.nn import APPNPConv\nfrom scipy.linalg import fractional_matrix_power, inv\nfrom sklearn.preprocessing import MinMaxScaler\n\n\ndef preprocess_features(features):\n    \"\"\"Row-normalize feature matrix and convert to tuple representation\"\"\"\n    rowsum = np.array(features.sum(1))\n    r_inv = np.power(rowsum, -1).flatten()\n    r_inv[np.isinf(r_inv)] = 0.0\n    r_mat_inv = sp.diags(r_inv)\n    features = r_mat_inv.dot(features)\n    if isinstance(features, np.ndarray):\n        return features\n    else:\n        return features.todense(), sparse_to_tuple(features)\n\n\ndef sparse_to_tuple(sparse_mx):\n    \"\"\"Convert sparse matrix to tuple representation.\"\"\"\n\n    def to_tuple(mx):\n        if not sp.isspmatrix_coo(mx):\n            mx = mx.tocoo()\n        coords = np.vstack((mx.row, mx.col)).transpose()\n        values = mx.data\n        shape = mx.shape\n        return coords, values, shape\n\n    if isinstance(sparse_mx, list):\n        for i in range(len(sparse_mx)):\n            sparse_mx[i] = to_tuple(sparse_mx[i])\n    else:\n        sparse_mx = to_tuple(sparse_mx)\n\n    return sparse_mx\n\n\ndef compute_ppr(graph: nx.Graph, alpha=0.2, self_loop=True):\n    a = nx.convert_matrix.to_numpy_array(graph)\n    if self_loop:\n        a = a + np.eye(a.shape[0])  # A^ = A + I_n\n    d = np.diag(np.sum(a, 1))  # D^ = Sigma A^_ii\n    dinv = fractional_matrix_power(d, -0.5)  # D^(-1/2)\n    at = np.matmul(np.matmul(dinv, a), dinv)  # A~ = D^(-1/2) x A^ x D^(-1/2)\n    return alpha * inv(\n        (np.eye(a.shape[0]) - (1 - alpha) * at)\n    )  # a(I_n-(1-a)A~)^-1\n\n\ndef process_dataset(name, epsilon):\n    if name == \"cora\":\n        dataset = CoraGraphDataset()\n    elif name 
== \"citeseer\":\n        dataset = CiteseerGraphDataset()\n\n    graph = dataset[0]\n    feat = graph.ndata.pop(\"feat\")\n    label = graph.ndata.pop(\"label\")\n\n    train_mask = graph.ndata.pop(\"train_mask\")\n    val_mask = graph.ndata.pop(\"val_mask\")\n    test_mask = graph.ndata.pop(\"test_mask\")\n\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()\n    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze()\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()\n\n    nx_g = dgl.to_networkx(graph)\n\n    print(\"computing ppr\")\n    diff_adj = compute_ppr(nx_g, 0.2)\n    print(\"computing end\")\n\n    if name == \"citeseer\":\n        print(\"additional processing\")\n        feat = th.tensor(preprocess_features(feat.numpy())).float()\n        diff_adj[diff_adj < epsilon] = 0\n        scaler = MinMaxScaler()\n        scaler.fit(diff_adj)\n        diff_adj = scaler.transform(diff_adj)\n\n    diff_edges = np.nonzero(diff_adj)\n    diff_weight = diff_adj[diff_edges]\n    diff_graph = dgl.graph(diff_edges)\n\n    graph = graph.add_self_loop()\n\n    return (\n        graph,\n        diff_graph,\n        feat,\n        label,\n        train_idx,\n        val_idx,\n        test_idx,\n        diff_weight,\n    )\n\n\ndef process_dataset_appnp(epsilon):\n    k = 20\n    alpha = 0.2\n    dataset = PubmedGraphDataset()\n    graph = dataset[0]\n    feat = graph.ndata.pop(\"feat\")\n    label = graph.ndata.pop(\"label\")\n\n    train_mask = graph.ndata.pop(\"train_mask\")\n    val_mask = graph.ndata.pop(\"val_mask\")\n    test_mask = graph.ndata.pop(\"test_mask\")\n\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()\n    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze()\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()\n\n    appnp = APPNPConv(k, alpha)\n    id = th.eye(graph.num_nodes()).float()\n    diff_adj = appnp(graph.add_self_loop(), id).numpy()\n\n    diff_adj[diff_adj < epsilon] = 0\n    scaler 
= MinMaxScaler()\n    scaler.fit(diff_adj)\n    diff_adj = scaler.transform(diff_adj)\n    diff_edges = np.nonzero(diff_adj)\n    diff_weight = diff_adj[diff_edges]\n    diff_graph = dgl.graph(diff_edges)\n\n    return (\n        graph,\n        diff_graph,\n        feat,\n        label,\n        train_idx,\n        val_idx,\n        test_idx,\n        diff_weight,\n    )\n"
  },
  {
    "path": "examples/pytorch/mvgrl/node/main.py",
    "content": "import argparse\nimport warnings\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\n\nwarnings.filterwarnings(\"ignore\")\n\nfrom dataset import process_dataset\nfrom model import LogReg, MVGRL\n\nparser = argparse.ArgumentParser(description=\"mvgrl\")\n\nparser.add_argument(\n    \"--dataname\", type=str, default=\"cora\", help=\"Name of dataset.\"\n)\nparser.add_argument(\n    \"--gpu\", type=int, default=0, help=\"GPU index. Default: -1, using cpu.\"\n)\nparser.add_argument(\"--epochs\", type=int, default=500, help=\"Training epochs.\")\nparser.add_argument(\n    \"--patience\",\n    type=int,\n    default=20,\n    help=\"Patient epochs to wait before early stopping.\",\n)\nparser.add_argument(\n    \"--lr1\", type=float, default=0.001, help=\"Learning rate of mvgrl.\"\n)\nparser.add_argument(\n    \"--lr2\", type=float, default=0.01, help=\"Learning rate of linear evaluator.\"\n)\nparser.add_argument(\n    \"--wd1\", type=float, default=0.0, help=\"Weight decay of mvgrl.\"\n)\nparser.add_argument(\n    \"--wd2\", type=float, default=0.0, help=\"Weight decay of linear evaluator.\"\n)\nparser.add_argument(\n    \"--epsilon\",\n    type=float,\n    default=0.01,\n    help=\"Edge mask threshold of diffusion graph.\",\n)\nparser.add_argument(\n    \"--hid_dim\", type=int, default=512, help=\"Hidden layer dim.\"\n)\n\nargs = parser.parse_args()\n\n# check cuda\nif args.gpu != -1 and th.cuda.is_available():\n    args.device = \"cuda:{}\".format(args.gpu)\nelse:\n    args.device = \"cpu\"\n\nif __name__ == \"__main__\":\n    print(args)\n\n    # Step 1: Prepare data =================================================================== #\n    (\n        graph,\n        diff_graph,\n        feat,\n        label,\n        train_idx,\n        val_idx,\n        test_idx,\n        edge_weight,\n    ) = process_dataset(args.dataname, args.epsilon)\n    n_feat = feat.shape[1]\n    n_classes = np.unique(label).shape[0]\n\n    graph = 
graph.to(args.device)\n    diff_graph = diff_graph.to(args.device)\n    feat = feat.to(args.device)\n    edge_weight = th.tensor(edge_weight).float().to(args.device)\n\n    train_idx = train_idx.to(args.device)\n    val_idx = val_idx.to(args.device)\n    test_idx = test_idx.to(args.device)\n\n    n_node = graph.num_nodes()\n    lbl1 = th.ones(n_node * 2)\n    lbl2 = th.zeros(n_node * 2)\n    lbl = th.cat((lbl1, lbl2))\n\n    # Step 2: Create model =================================================================== #\n    model = MVGRL(n_feat, args.hid_dim)\n    model = model.to(args.device)\n\n    lbl = lbl.to(args.device)\n\n    # Step 3: Create training components ===================================================== #\n    optimizer = th.optim.Adam(\n        model.parameters(), lr=args.lr1, weight_decay=args.wd1\n    )\n    loss_fn = nn.BCEWithLogitsLoss()\n\n    # Step 4: Training epochs ================================================================ #\n    best = float(\"inf\")\n    cnt_wait = 0\n    for epoch in range(args.epochs):\n        model.train()\n        optimizer.zero_grad()\n\n        shuf_idx = np.random.permutation(n_node)\n        shuf_feat = feat[shuf_idx, :]\n        shuf_feat = shuf_feat.to(args.device)\n\n        out = model(graph, diff_graph, feat, shuf_feat, edge_weight)\n        loss = loss_fn(out, lbl)\n\n        loss.backward()\n        optimizer.step()\n\n        print(\"Epoch: {0}, Loss: {1:0.4f}\".format(epoch, loss.item()))\n\n        if loss < best:\n            best = loss\n            cnt_wait = 0\n            th.save(model.state_dict(), \"model.pkl\")\n        else:\n            cnt_wait += 1\n\n        if cnt_wait == args.patience:\n            print(\"Early stopping\")\n            break\n\n    model.load_state_dict(th.load(\"model.pkl\"))\n    embeds = model.get_embedding(graph, diff_graph, feat, edge_weight)\n\n    train_embs = embeds[train_idx]\n    test_embs = embeds[test_idx]\n\n    label = label.to(args.device)\n    
train_labels = label[train_idx]\n    test_labels = label[test_idx]\n    accs = []\n\n    # Step 5:  Linear evaluation ========================================================== #\n    for _ in range(5):\n        model = LogReg(args.hid_dim, n_classes)\n        opt = th.optim.Adam(\n            model.parameters(), lr=args.lr2, weight_decay=args.wd2\n        )\n\n        model = model.to(args.device)\n        loss_fn = nn.CrossEntropyLoss()\n        for epoch in range(300):\n            model.train()\n            opt.zero_grad()\n            logits = model(train_embs)\n            loss = loss_fn(logits, train_labels)\n            loss.backward()\n            opt.step()\n\n        model.eval()\n        logits = model(test_embs)\n        preds = th.argmax(logits, dim=1)\n        acc = th.sum(preds == test_labels).float() / test_labels.shape[0]\n        accs.append(acc * 100)\n\n    accs = th.stack(accs)\n    print(accs.mean().item(), accs.std().item())\n"
  },
  {
    "path": "examples/pytorch/mvgrl/node/main_sample.py",
    "content": "import argparse\nimport random\nimport warnings\n\nimport dgl\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\n\nwarnings.filterwarnings(\"ignore\")\n\nfrom dataset import process_dataset, process_dataset_appnp\nfrom model import LogReg, MVGRL\n\nparser = argparse.ArgumentParser(description=\"mvgrl\")\n\nparser.add_argument(\n    \"--dataname\", type=str, default=\"cora\", help=\"Name of dataset.\"\n)\nparser.add_argument(\n    \"--gpu\", type=int, default=-1, help=\"GPU index. Default: -1, using cpu.\"\n)\nparser.add_argument(\"--epochs\", type=int, default=500, help=\"Training epochs.\")\nparser.add_argument(\n    \"--patience\",\n    type=int,\n    default=20,\n    help=\"Patient epochs to wait before early stopping.\",\n)\nparser.add_argument(\n    \"--lr1\", type=float, default=0.001, help=\"Learning rate of mvgrl.\"\n)\nparser.add_argument(\n    \"--lr2\", type=float, default=0.01, help=\"Learning rate of linear evaluator.\"\n)\nparser.add_argument(\n    \"--wd1\", type=float, default=0.0, help=\"Weight decay of mvgrl.\"\n)\nparser.add_argument(\n    \"--wd2\", type=float, default=0.0, help=\"Weight decay of linear evaluator.\"\n)\nparser.add_argument(\n    \"--epsilon\",\n    type=float,\n    default=0.01,\n    help=\"Edge mask threshold of diffusion graph.\",\n)\nparser.add_argument(\n    \"--hid_dim\", type=int, default=512, help=\"Hidden layer dim.\"\n)\nparser.add_argument(\n    \"--sample_size\", type=int, default=2000, help=\"Subgraph size.\"\n)\n\nargs = parser.parse_args()\n\n# check cuda\nif args.gpu != -1 and th.cuda.is_available():\n    args.device = \"cuda:{}\".format(args.gpu)\nelse:\n    args.device = \"cpu\"\n\nif __name__ == \"__main__\":\n    print(args)\n\n    # Step 1: Prepare data =================================================================== #\n    if args.dataname == \"pubmed\":\n        (\n            graph,\n            diff_graph,\n            feat,\n            label,\n            train_idx,\n   
         val_idx,\n            test_idx,\n            edge_weight,\n        ) = process_dataset_appnp(args.epsilon)\n    else:\n        (\n            graph,\n            diff_graph,\n            feat,\n            label,\n            train_idx,\n            val_idx,\n            test_idx,\n            edge_weight,\n        ) = process_dataset(args.dataname, args.epsilon)\n    edge_weight = th.tensor(edge_weight).float()\n    graph.ndata[\"feat\"] = feat\n    diff_graph.edata[\"edge_weight\"] = edge_weight\n\n    n_feat = feat.shape[1]\n    n_classes = np.unique(label).shape[0]\n    edge_weight = th.tensor(edge_weight).float()\n\n    train_idx = train_idx.to(args.device)\n    val_idx = val_idx.to(args.device)\n    test_idx = test_idx.to(args.device)\n\n    n_node = graph.num_nodes()\n\n    sample_size = args.sample_size\n\n    lbl1 = th.ones(sample_size * 2)\n    lbl2 = th.zeros(sample_size * 2)\n    lbl = th.cat((lbl1, lbl2))\n    lbl = lbl.to(args.device)\n\n    # Step 2: Create model =================================================================== #\n    model = MVGRL(n_feat, args.hid_dim)\n    model = model.to(args.device)\n\n    # Step 3: Create training components ===================================================== #\n    optimizer = th.optim.Adam(\n        model.parameters(), lr=args.lr1, weight_decay=args.wd1\n    )\n    loss_fn = nn.BCEWithLogitsLoss()\n\n    node_list = list(range(n_node))\n\n    # Step 4: Training epochs ================================================================ #\n    best = float(\"inf\")\n    cnt_wait = 0\n    for epoch in range(args.epochs):\n        model.train()\n        optimizer.zero_grad()\n\n        sample_idx = random.sample(node_list, sample_size)\n\n        g = dgl.node_subgraph(graph, sample_idx)\n        dg = dgl.node_subgraph(diff_graph, sample_idx)\n\n        f = g.ndata.pop(\"feat\")\n        ew = dg.edata.pop(\"edge_weight\")\n\n        shuf_idx = np.random.permutation(sample_size)\n        sf = f[shuf_idx, 
:]\n\n        g = g.to(args.device)\n        dg = dg.to(args.device)\n        f = f.to(args.device)\n        ew = ew.to(args.device)\n\n        sf = sf.to(args.device)\n\n        out = model(g, dg, f, sf, ew)\n        loss = loss_fn(out, lbl)\n\n        loss.backward()\n        optimizer.step()\n\n        print(\"Epoch: {0}, Loss: {1:0.4f}\".format(epoch, loss.item()))\n\n        if loss < best:\n            best = loss\n            cnt_wait = 0\n            th.save(model.state_dict(), \"model.pkl\")\n        else:\n            cnt_wait += 1\n\n        if cnt_wait == args.patience:\n            print(\"Early stopping\")\n            break\n\n    model.load_state_dict(th.load(\"model.pkl\"))\n\n    graph = graph.to(args.device)\n    diff_graph = diff_graph.to(args.device)\n    feat = feat.to(args.device)\n    edge_weight = edge_weight.to(args.device)\n    embeds = model.get_embedding(graph, diff_graph, feat, edge_weight)\n\n    train_embs = embeds[train_idx]\n    test_embs = embeds[test_idx]\n\n    label = label.to(args.device)\n    train_labels = label[train_idx]\n    test_labels = label[test_idx]\n    accs = []\n\n    # Step 5:  Linear evaluation ========================================================== #\n    for _ in range(5):\n        model = LogReg(args.hid_dim, n_classes)\n        opt = th.optim.Adam(\n            model.parameters(), lr=args.lr2, weight_decay=args.wd2\n        )\n\n        model = model.to(args.device)\n        loss_fn = nn.CrossEntropyLoss()\n        for epoch in range(300):\n            model.train()\n            opt.zero_grad()\n            logits = model(train_embs)\n            loss = loss_fn(logits, train_labels)\n            loss.backward()\n            opt.step()\n\n        model.eval()\n        logits = model(test_embs)\n        preds = th.argmax(logits, dim=1)\n        acc = th.sum(preds == test_labels).float() / test_labels.shape[0]\n        accs.append(acc * 100)\n\n    accs = th.stack(accs)\n    print(accs.mean().item(), 
accs.std().item())\n"
  },
  {
    "path": "examples/pytorch/mvgrl/node/model.py",
    "content": "import torch as th\nimport torch.nn as nn\n\nfrom dgl.nn.pytorch import GraphConv\nfrom dgl.nn.pytorch.glob import AvgPooling\n\n\nclass LogReg(nn.Module):\n    def __init__(self, hid_dim, n_classes):\n        super(LogReg, self).__init__()\n\n        self.fc = nn.Linear(hid_dim, n_classes)\n\n    def forward(self, x):\n        ret = self.fc(x)\n        return ret\n\n\nclass Discriminator(nn.Module):\n    def __init__(self, dim):\n        super(Discriminator, self).__init__()\n        self.fn = nn.Bilinear(dim, dim, 1)\n\n    def forward(self, h1, h2, h3, h4, c1, c2):\n        c_x1 = c1.expand_as(h1).contiguous()\n        c_x2 = c2.expand_as(h2).contiguous()\n\n        # positive\n        sc_1 = self.fn(h2, c_x1).squeeze(1)\n        sc_2 = self.fn(h1, c_x2).squeeze(1)\n\n        # negative\n        sc_3 = self.fn(h4, c_x1).squeeze(1)\n        sc_4 = self.fn(h3, c_x2).squeeze(1)\n\n        logits = th.cat((sc_1, sc_2, sc_3, sc_4))\n\n        return logits\n\n\nclass MVGRL(nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super(MVGRL, self).__init__()\n\n        self.encoder1 = GraphConv(\n            in_dim, out_dim, norm=\"both\", bias=True, activation=nn.PReLU()\n        )\n        self.encoder2 = GraphConv(\n            in_dim, out_dim, norm=\"none\", bias=True, activation=nn.PReLU()\n        )\n        self.pooling = AvgPooling()\n\n        self.disc = Discriminator(out_dim)\n        self.act_fn = nn.Sigmoid()\n\n    def get_embedding(self, graph, diff_graph, feat, edge_weight):\n        h1 = self.encoder1(graph, feat)\n        h2 = self.encoder2(diff_graph, feat, edge_weight=edge_weight)\n\n        return (h1 + h2).detach()\n\n    def forward(self, graph, diff_graph, feat, shuf_feat, edge_weight):\n        h1 = self.encoder1(graph, feat)\n        h2 = self.encoder2(diff_graph, feat, edge_weight=edge_weight)\n\n        h3 = self.encoder1(graph, shuf_feat)\n        h4 = self.encoder2(diff_graph, shuf_feat, edge_weight=edge_weight)\n\n 
       c1 = self.act_fn(self.pooling(graph, h1))\n        c2 = self.act_fn(self.pooling(graph, h2))\n\n        out = self.disc(h1, h2, h3, h4, c1, c2)\n\n        return out\n"
  },
  {
    "path": "examples/pytorch/node2vec/README.md",
    "content": "# DGL Implementation of the Node2vec\nThis DGL example implements the graph embedding model proposed in the paper \n[node2vec: Scalable Feature Learning for Networks](https://arxiv.org/abs/1607.00653) \n\nThe author's codes of implementation is in [Node2vec](https://github.com/aditya-grover/node2vec) \n\n\nExample implementor\n----------------------\nThis example was implemented by [Smile](https://github.com/Smilexuhc) during his intern work at the AWS Shanghai AI Lab.\n\nThe graph dataset used in this example \n---------------------------------------\n\ncora\n - NumNodes: 2708\n - NumEdges: 10556\n\nogbn-products\n - NumNodes: 2449029\n - NumEdges: 61859140\n\n \nDependencies\n--------------------------------\n\n- python 3.6+\n- Pytorch 1.5.0+\n- ogb  \n\n\n How to run example files\n--------------------------------\nTo train a node2vec model:\n```shell script\npython main.py --task=\"train\"\n```\n\nTo time node2vec random walks:\n```shell script\npython main.py --task=\"time\" --runs=10\n```\n\nPerformance\n-------------------------\n\n**Setting:** `walk_length=50, p=0.25, q=4.0`\n| Dataset  |     DGL     |     PyG     |\n| -------- | :---------: | :---------: |\n| cora     | 0.0092s | 0.0179s |\n| products | 66.22s  | 77.65s  |\nNote that the number in table are the average results of multiple trials.  \nFor cora, we run 50 trials.  For ogbn-products, we run 10 trials.\n"
  },
  {
    "path": "examples/pytorch/node2vec/main.py",
    "content": "import time\n\nfrom dgl.sampling import node2vec_random_walk\n\nfrom model import Node2vecModel\nfrom utils import load_graph, parse_arguments\n\n\ndef time_randomwalk(graph, args):\n    \"\"\"\n    Test cost time of random walk\n    \"\"\"\n\n    start_time = time.time()\n\n    # default setting for testing\n    params = {\"p\": 0.25, \"q\": 4, \"walk_length\": 50}\n\n    for i in range(args.runs):\n        node2vec_random_walk(graph, graph.nodes(), **params)\n    end_time = time.time()\n    cost_time_avg = (end_time - start_time) / args.runs\n    print(\n        \"Run dataset {} {} trials, mean run time: {:.3f}s\".format(\n            args.dataset, args.runs, cost_time_avg\n        )\n    )\n\n\ndef train_node2vec(graph, eval_set, args):\n    \"\"\"\n    Train node2vec model\n    \"\"\"\n    trainer = Node2vecModel(\n        graph,\n        embedding_dim=args.embedding_dim,\n        walk_length=args.walk_length,\n        p=args.p,\n        q=args.q,\n        num_walks=args.num_walks,\n        eval_set=eval_set,\n        eval_steps=1,\n        device=args.device,\n    )\n\n    trainer.train(\n        epochs=args.epochs, batch_size=args.batch_size, learning_rate=0.01\n    )\n\n\nif __name__ == \"__main__\":\n    args = parse_arguments()\n    graph, eval_set = load_graph(args.dataset)\n\n    if args.task == \"train\":\n        print(\"Perform training node2vec model\")\n        train_node2vec(graph, eval_set, args)\n    elif args.task == \"time\":\n        print(\"Timing random walks\")\n        time_randomwalk(graph, args)\n    else:\n        raise ValueError(\"Task type error!\")\n"
  },
  {
    "path": "examples/pytorch/node2vec/model.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom dgl.sampling import node2vec_random_walk\nfrom sklearn.linear_model import LogisticRegression\nfrom torch.utils.data import DataLoader\n\n\nclass Node2vec(nn.Module):\n    \"\"\"Node2vec model from paper node2vec: Scalable Feature Learning for Networks <https://arxiv.org/abs/1607.00653>\n    Attributes\n    ----------\n    g: DGLGraph\n        The graph.\n    embedding_dim: int\n        Dimension of node embedding.\n    walk_length: int\n        Length of each trace.\n    p: float\n        Likelihood of immediately revisiting a node in the walk.  Same notation as in the paper.\n    q: float\n        Control parameter to interpolate between breadth-first strategy and depth-first strategy.\n        Same notation as in the paper.\n    num_walks: int\n        Number of random walks for each node. Default: 10.\n    window_size: int\n        Maximum distance between the center node and predicted node. Default: 5.\n    num_negatives: int\n        The number of negative samples for each positive sample.  Default: 5.\n    use_sparse: bool\n        If set to True, use PyTorch's sparse embedding and optimizer. Default: ``True``.\n    weight_name : str, optional\n        The name of the edge feature tensor on the graph storing the (unnormalized)\n        probabilities associated with each edge for choosing the next node.\n\n        The feature tensor must be non-negative and the sum of the probabilities\n        must be positive for the outbound edges of all nodes (although they don't have\n        to sum up to one).  
The result will be undefined otherwise.\n\n        If omitted, DGL assumes that the neighbors are picked uniformly.\n    \"\"\"\n\n    def __init__(\n        self,\n        g,\n        embedding_dim,\n        walk_length,\n        p,\n        q,\n        num_walks=10,\n        window_size=5,\n        num_negatives=5,\n        use_sparse=True,\n        weight_name=None,\n    ):\n        super(Node2vec, self).__init__()\n\n        assert walk_length >= window_size\n\n        self.g = g\n        self.embedding_dim = embedding_dim\n        self.walk_length = walk_length\n        self.p = p\n        self.q = q\n        self.num_walks = num_walks\n        self.window_size = window_size\n        self.num_negatives = num_negatives\n        self.N = self.g.num_nodes()\n        if weight_name is not None:\n            self.prob = weight_name\n        else:\n            self.prob = None\n\n        self.embedding = nn.Embedding(self.N, embedding_dim, sparse=use_sparse)\n\n    def reset_parameters(self):\n        self.embedding.reset_parameters()\n\n    def sample(self, batch):\n        \"\"\"\n        Generate positive and negative samples.\n        Positive samples are generated from random walk\n        Negative samples are generated from random sampling\n        \"\"\"\n        if not isinstance(batch, torch.Tensor):\n            batch = torch.tensor(batch)\n\n        batch = batch.repeat(self.num_walks)\n        # positive\n        pos_traces = node2vec_random_walk(\n            self.g, batch, self.p, self.q, self.walk_length, self.prob\n        )\n        pos_traces = pos_traces.unfold(1, self.window_size, 1)  # rolling window\n        pos_traces = pos_traces.contiguous().view(-1, self.window_size)\n\n        # negative\n        neg_batch = batch.repeat(self.num_negatives)\n        neg_traces = torch.randint(\n            self.N, (neg_batch.size(0), self.walk_length)\n        )\n        neg_traces = torch.cat([neg_batch.view(-1, 1), neg_traces], dim=-1)\n        
neg_traces = neg_traces.unfold(1, self.window_size, 1)  # rolling window\n        neg_traces = neg_traces.contiguous().view(-1, self.window_size)\n\n        return pos_traces, neg_traces\n\n    def forward(self, nodes=None):\n        \"\"\"\n        Returns the embeddings of the input nodes\n        Parameters\n        ----------\n        nodes: Tensor, optional\n            Input nodes, if set `None`, will return all the node embedding.\n\n        Returns\n        -------\n        Tensor\n            Node embedding\n\n        \"\"\"\n        emb = self.embedding.weight\n        if nodes is None:\n            return emb\n        else:\n            return emb[nodes]\n\n    def loss(self, pos_trace, neg_trace):\n        \"\"\"\n        Computes the loss given positive and negative random walks.\n        Parameters\n        ----------\n        pos_trace: Tensor\n            positive random walk trace\n        neg_trace: Tensor\n            negative random walk trace\n\n        \"\"\"\n        e = 1e-15\n\n        # Positive\n        pos_start, pos_rest = (\n            pos_trace[:, 0],\n            pos_trace[:, 1:].contiguous(),\n        )  # start node and following trace\n        w_start = self.embedding(pos_start).unsqueeze(dim=1)\n        w_rest = self.embedding(pos_rest)\n        pos_out = (w_start * w_rest).sum(dim=-1).view(-1)\n\n        # Negative\n        neg_start, neg_rest = neg_trace[:, 0], neg_trace[:, 1:].contiguous()\n\n        w_start = self.embedding(neg_start).unsqueeze(dim=1)\n        w_rest = self.embedding(neg_rest)\n        neg_out = (w_start * w_rest).sum(dim=-1).view(-1)\n\n        # compute loss\n        pos_loss = -torch.log(torch.sigmoid(pos_out) + e).mean()\n        neg_loss = -torch.log(1 - torch.sigmoid(neg_out) + e).mean()\n\n        return pos_loss + neg_loss\n\n    def loader(self, batch_size):\n        \"\"\"\n\n        Parameters\n        ----------\n        batch_size: int\n            batch size\n\n        Returns\n        
-------\n        DataLoader\n            Node2vec training data loader\n\n        \"\"\"\n        return DataLoader(\n            torch.arange(self.N),\n            batch_size=batch_size,\n            shuffle=True,\n            collate_fn=self.sample,\n        )\n\n    @torch.no_grad()\n    def evaluate(self, x_train, y_train, x_val, y_val):\n        \"\"\"\n        Evaluate the quality of embedding vector via a downstream classification task with logistic regression.\n        \"\"\"\n        x_train = self.forward(x_train)\n        x_val = self.forward(x_val)\n\n        x_train, y_train = x_train.cpu().numpy(), y_train.cpu().numpy()\n        x_val, y_val = x_val.cpu().numpy(), y_val.cpu().numpy()\n        lr = LogisticRegression(\n            solver=\"lbfgs\", multi_class=\"auto\", max_iter=150\n        ).fit(x_train, y_train)\n\n        return lr.score(x_val, y_val)\n\n\nclass Node2vecModel(object):\n    \"\"\"\n    Wrapper of the ``Node2Vec`` class with a ``train`` method.\n    Attributes\n    ----------\n    g: DGLGraph\n        The graph.\n    embedding_dim: int\n        Dimension of node embedding.\n    walk_length: int\n        Length of each trace.\n    p: float\n        Likelihood of immediately revisiting a node in the walk.\n    q: float\n        Control parameter to interpolate between breadth-first strategy and depth-first strategy.\n    num_walks: int\n        Number of random walks for each node. Default: 10.\n    window_size: int\n        Maximum distance between the center node and predicted node. Default: 5.\n    num_negatives: int\n        The number of negative samples for each positive sample.  Default: 5.\n    use_sparse: bool\n        If set to True, uses PyTorch's sparse embedding and optimizer. 
Default: ``True``.\n    weight_name : str, optional\n        The name of the edge feature tensor on the graph storing the (unnormalized)\n        probabilities associated with each edge for choosing the next node.\n\n        The feature tensor must be non-negative and the sum of the probabilities\n        must be positive for the outbound edges of all nodes (although they don't have\n        to sum up to one).  The result will be undefined otherwise.\n\n        If omitted, DGL assumes that the neighbors are picked uniformly. Default: ``None``.\n    eval_set: list of tuples (Tensor, Tensor)\n        [(nodes_train,y_train),(nodes_val,y_val)]\n        If omitted, model will not be evaluated. Default: ``None``.\n    eval_steps: int\n        Interval steps of evaluation.\n        if set <= 0, model will not be evaluated. Default: ``None``.\n    device: str\n        device, default 'cpu'.\n    \"\"\"\n\n    def __init__(\n        self,\n        g,\n        embedding_dim,\n        walk_length,\n        p=1.0,\n        q=1.0,\n        num_walks=1,\n        window_size=5,\n        num_negatives=5,\n        use_sparse=True,\n        weight_name=None,\n        eval_set=None,\n        eval_steps=-1,\n        device=\"cpu\",\n    ):\n        self.model = Node2vec(\n            g,\n            embedding_dim,\n            walk_length,\n            p,\n            q,\n            num_walks,\n            window_size,\n            num_negatives,\n            use_sparse,\n            weight_name,\n        )\n        self.g = g\n        self.use_sparse = use_sparse\n        self.eval_steps = eval_steps\n        self.eval_set = eval_set\n\n        if device == \"cpu\":\n            self.device = device\n        else:\n            self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    def _train_step(self, model, loader, optimizer, device):\n        model.train()\n        total_loss = 0\n        for pos_traces, neg_traces in loader:\n            pos_traces, neg_traces = 
pos_traces.to(device), neg_traces.to(\n                device\n            )\n            optimizer.zero_grad()\n            loss = model.loss(pos_traces, neg_traces)\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n        return total_loss / len(loader)\n\n    @torch.no_grad()\n    def _evaluate_step(self):\n        nodes_train, y_train = self.eval_set[0]\n        nodes_val, y_val = self.eval_set[1]\n\n        acc = self.model.evaluate(nodes_train, y_train, nodes_val, y_val)\n        return acc\n\n    def train(self, epochs, batch_size, learning_rate=0.01):\n        \"\"\"\n\n        Parameters\n        ----------\n        epochs: int\n            num of train epoch\n        batch_size: int\n            batch size\n        learning_rate: float\n            learning rate. Default 0.01.\n\n        \"\"\"\n\n        self.model = self.model.to(self.device)\n        loader = self.model.loader(batch_size)\n        if self.use_sparse:\n            optimizer = torch.optim.SparseAdam(\n                list(self.model.parameters()), lr=learning_rate\n            )\n        else:\n            optimizer = torch.optim.Adam(\n                self.model.parameters(), lr=learning_rate\n            )\n        for i in range(epochs):\n            loss = self._train_step(self.model, loader, optimizer, self.device)\n            if self.eval_steps > 0:\n                if epochs % self.eval_steps == 0:\n                    acc = self._evaluate_step()\n                    print(\n                        \"Epoch: {}, Train Loss: {:.4f}, Val Acc: {:.4f}\".format(\n                            i, loss, acc\n                        )\n                    )\n\n    def embedding(self, nodes=None):\n        \"\"\"\n        Returns the embeddings of the input nodes\n        Parameters\n        ----------\n        nodes: Tensor, optional\n            Input nodes, if set `None`, will return all the node embedding.\n\n        Returns\n        
-------\n        Tensor\n            Node embedding.\n        \"\"\"\n\n        return self.model(nodes)\n"
  },
  {
    "path": "examples/pytorch/node2vec/utils.py",
    "content": "import argparse\n\nfrom ogb.linkproppred import *\nfrom ogb.nodeproppred import *\n\nfrom dgl.data import CitationGraphDataset\n\n\ndef load_graph(name):\n    cite_graphs = [\"cora\", \"citeseer\", \"pubmed\"]\n\n    if name in cite_graphs:\n        dataset = CitationGraphDataset(name)\n        graph = dataset[0]\n\n        nodes = graph.nodes()\n        y = graph.ndata[\"label\"]\n        train_mask = graph.ndata[\"train_mask\"]\n        val_mask = graph.ndata[\"test_mask\"]\n\n        nodes_train, y_train = nodes[train_mask], y[train_mask]\n        nodes_val, y_val = nodes[val_mask], y[val_mask]\n        eval_set = [(nodes_train, y_train), (nodes_val, y_val)]\n\n    elif name.startswith(\"ogbn\"):\n        dataset = DglNodePropPredDataset(name)\n        graph, y = dataset[0]\n        split_nodes = dataset.get_idx_split()\n        nodes = graph.nodes()\n\n        train_idx = split_nodes[\"train\"]\n        val_idx = split_nodes[\"valid\"]\n\n        nodes_train, y_train = nodes[train_idx], y[train_idx]\n        nodes_val, y_val = nodes[val_idx], y[val_idx]\n        eval_set = [(nodes_train, y_train), (nodes_val, y_val)]\n\n    else:\n        raise ValueError(\"Dataset name error!\")\n\n    return graph, eval_set\n\n\ndef parse_arguments():\n    \"\"\"\n    Parse arguments\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"Node2vec\")\n    parser.add_argument(\"--dataset\", type=str, default=\"cora\")\n    # 'train' for training node2vec model, 'time' for testing speed of random walk\n    parser.add_argument(\"--task\", type=str, default=\"train\")\n    parser.add_argument(\"--runs\", type=int, default=10)\n    parser.add_argument(\"--device\", type=str, default=\"cpu\")\n    parser.add_argument(\"--embedding_dim\", type=int, default=128)\n    parser.add_argument(\"--walk_length\", type=int, default=50)\n    parser.add_argument(\"--p\", type=float, default=0.25)\n    parser.add_argument(\"--q\", type=float, default=4.0)\n    
parser.add_argument(\"--num_walks\", type=int, default=10)\n    parser.add_argument(\"--epochs\", type=int, default=100)\n    parser.add_argument(\"--batch_size\", type=int, default=128)\n\n    args = parser.parse_args()\n\n    return args\n"
  },
  {
    "path": "examples/pytorch/ogb/README.md",
    "content": "# OGB Submissions\n\nThis directory lists the submissions made from DGL Team to the OGB Leaderboard.\n\nCurrently it contains:\n\n* OGBN-Products\n  * GraphSAGE with Neighbor Sampling\n  * SIGN\n* OGBN-Proteins\n  * MWE-GCN and MWE-DGCN ([GCN models for graphs with multi-dimensionally weighted edges](https://cims.nyu.edu/~chenzh/files/GCN_with_edge_weights.pdf))\n* OGBN-Arxiv\n  * SIGN\n* OGBN-Mag\n  * SIGN\n"
  },
  {
    "path": "examples/pytorch/ogb/cluster-gat/README.md",
    "content": "# ClusterGAT \nParams: 1540848\n\n## OGB Products\nRun `main.py` and you should directly see the result.\n\nValid over 10 runs: 0.8985 ± 0.00224\nAccuracy over 10 runs: 0.79232 ± 0.007786 \n"
  },
  {
    "path": "examples/pytorch/ogb/cluster-gat/main.py",
    "content": "import argparse\nimport time\nfrom functools import partial\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom ogb.nodeproppred import DglNodePropPredDataset\nfrom sampler import ClusterIter, subgraph_collate_fn\nfrom torch.utils.data import DataLoader\n\n\nclass GAT(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        num_heads,\n        n_hidden,\n        n_classes,\n        n_layers,\n        activation,\n        dropout=0.0,\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.num_heads = num_heads\n        self.layers.append(\n            dglnn.GATConv(\n                in_feats,\n                n_hidden,\n                num_heads=num_heads,\n                feat_drop=dropout,\n                attn_drop=dropout,\n                activation=activation,\n                negative_slope=0.2,\n            )\n        )\n        for i in range(1, n_layers - 1):\n            self.layers.append(\n                dglnn.GATConv(\n                    n_hidden * num_heads,\n                    n_hidden,\n                    num_heads=num_heads,\n                    feat_drop=dropout,\n                    attn_drop=dropout,\n                    activation=activation,\n                    negative_slope=0.2,\n                )\n            )\n        self.layers.append(\n            dglnn.GATConv(\n                n_hidden * num_heads,\n                n_classes,\n                num_heads=num_heads,\n                feat_drop=dropout,\n                attn_drop=dropout,\n                activation=None,\n                negative_slope=0.2,\n            )\n        )\n\n    def forward(self, g, x):\n        h = x\n        for l, conv in enumerate(self.layers):\n 
           h = conv(g, h)\n            if l < len(self.layers) - 1:\n                h = h.flatten(1)\n        h = h.mean(1)\n        return h.log_softmax(dim=-1)\n\n    def inference(self, g, x, batch_size, device):\n        \"\"\"\n        Inference with the GAT model on full neighbors (i.e. without neighbor sampling).\n        g : the entire graph.\n        x : the input of entire node set.\n        The inference code is written in a fashion that it could handle any number of nodes and\n        layers.\n        \"\"\"\n        num_heads = self.num_heads\n        for l, layer in enumerate(self.layers):\n            if l < self.n_layers - 1:\n                y = th.zeros(\n                    g.num_nodes(),\n                    self.n_hidden * num_heads\n                    if l != len(self.layers) - 1\n                    else self.n_classes,\n                )\n            else:\n                y = th.zeros(\n                    g.num_nodes(),\n                    self.n_hidden\n                    if l != len(self.layers) - 1\n                    else self.n_classes,\n                )\n            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n            dataloader = dgl.dataloading.DataLoader(\n                g,\n                th.arange(g.num_nodes()),\n                sampler,\n                batch_size=batch_size,\n                shuffle=False,\n                drop_last=False,\n                num_workers=args.num_workers,\n            )\n\n            with dataloader.enable_cpu_affinity():\n                for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                    block = blocks[0].int().to(device)\n                    h = x[input_nodes].to(device)\n                    if l < self.n_layers - 1:\n                        h = layer(block, h).flatten(1)\n                    else:\n                        h = layer(block, h)\n                        h = h.mean(1)\n                        h = 
h.log_softmax(dim=-1)\n\n                    y[output_nodes] = h.cpu()\n            x = y\n        return y\n\n\ndef compute_acc(pred, labels):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)\n\n\ndef evaluate(model, g, nfeat, labels, val_nid, test_nid, batch_size, device):\n    \"\"\"\n    Evaluate the model on the validation set specified by ``val_mask``.\n    g : The entire graph.\n    inputs : The features of all the nodes.\n    labels : The labels of all the nodes.\n    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.\n    batch_size : Number of nodes to compute at the same time.\n    device : The GPU device to evaluate on.\n    \"\"\"\n    model.eval()\n    with th.no_grad():\n        pred = model.inference(g, nfeat, batch_size, device)\n    model.train()\n    labels_cpu = labels.to(th.device(\"cpu\"))\n    return (\n        compute_acc(pred[val_nid], labels_cpu[val_nid]),\n        compute_acc(pred[test_nid], labels_cpu[test_nid]),\n        pred,\n    )\n\n\ndef model_param_summary(model):\n    \"\"\"Count the model parameters\"\"\"\n    cnt = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    print(\"Total Params {}\".format(cnt))\n\n\n#### Entry point\ndef run(args, device, data, nfeat):\n    # Unpack data\n    (\n        train_nid,\n        val_nid,\n        test_nid,\n        in_feats,\n        labels,\n        n_classes,\n        g,\n        cluster_iterator,\n    ) = data\n    labels = labels.to(device)\n\n    # Define model and optimizer\n    model = GAT(\n        in_feats,\n        args.num_heads,\n        args.num_hidden,\n        n_classes,\n        args.num_layers,\n        F.relu,\n        args.dropout,\n    )\n    model_param_summary(model)\n    model = model.to(device)\n    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)\n\n    # Training loop\n    avg = 0\n   
 best_eval_acc = 0\n    best_test_acc = 0\n    for epoch in range(args.num_epochs):\n        iter_load = 0\n        iter_far = 0\n        iter_back = 0\n        tic = time.time()\n\n        # Loop over the dataloader to sample the computation dependency graph as a list of\n        # blocks.\n        tic_start = time.time()\n        for step, cluster in enumerate(cluster_iterator):\n            mask = cluster.ndata.pop(\"train_mask\")\n            if mask.sum() == 0:\n                continue\n            cluster.edata.pop(dgl.EID)\n            cluster = cluster.int().to(device)\n            input_nodes = cluster.ndata[dgl.NID]\n            batch_inputs = nfeat[input_nodes]\n            batch_labels = labels[input_nodes]\n            tic_step = time.time()\n\n            # Compute loss and prediction\n            batch_pred = model(cluster, batch_inputs)\n            batch_pred = batch_pred[mask]\n            batch_labels = batch_labels[mask]\n            loss = nn.functional.nll_loss(batch_pred, batch_labels)\n            optimizer.zero_grad()\n            tic_far = time.time()\n            loss.backward()\n            optimizer.step()\n            tic_back = time.time()\n            iter_load += tic_step - tic_start\n            iter_far += tic_far - tic_step\n            iter_back += tic_back - tic_far\n\n            if step % args.log_every == 0:\n                acc = compute_acc(batch_pred, batch_labels)\n                gpu_mem_alloc = (\n                    th.cuda.max_memory_allocated() / 1000000\n                    if th.cuda.is_available()\n                    else 0\n                )\n                print(\n                    \"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | GPU {:.1f} MB\".format(\n                        epoch, step, loss.item(), acc.item(), gpu_mem_alloc\n                    )\n                )\n                tic_start = time.time()\n\n        toc = time.time()\n        print(\n            \"Epoch Time(s): {:.4f} 
Load {:.4f} Forward {:.4f} Backward {:.4f}\".format(\n                toc - tic, iter_load, iter_far, iter_back\n            )\n        )\n        if epoch >= 5:\n            avg += toc - tic\n\n        if epoch % args.eval_every == 0 and epoch != 0:\n            eval_acc, test_acc, pred = evaluate(\n                model,\n                g,\n                nfeat,\n                labels,\n                val_nid,\n                test_nid,\n                args.val_batch_size,\n                device,\n            )\n            model = model.to(device)\n            if args.save_pred:\n                np.savetxt(\n                    args.save_pred + \"%02d\" % epoch,\n                    pred.argmax(1).cpu().numpy(),\n                    \"%d\",\n                )\n            print(\"Eval Acc {:.4f}\".format(eval_acc))\n            if eval_acc > best_eval_acc:\n                best_eval_acc = eval_acc\n                best_test_acc = test_acc\n            print(\n                \"Best Eval Acc {:.4f} Test Acc {:.4f}\".format(\n                    best_eval_acc, best_test_acc\n                )\n            )\n\n    if epoch >= 5:\n        print(\"Avg epoch time: {}\".format(avg / (epoch - 4)))\n    return best_test_acc.to(th.device(\"cpu\"))\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"multi-gpu training\")\n    argparser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=0,\n        help=\"GPU device ID. 
Use -1 for CPU training\",\n    )\n    argparser.add_argument(\"--num_epochs\", type=int, default=20)\n    argparser.add_argument(\"--num_hidden\", type=int, default=128)\n    argparser.add_argument(\"--num_layers\", type=int, default=3)\n    argparser.add_argument(\"--num_heads\", type=int, default=8)\n    argparser.add_argument(\"--batch_size\", type=int, default=32)\n    argparser.add_argument(\"--val_batch_size\", type=int, default=2000)\n    argparser.add_argument(\"--log_every\", type=int, default=20)\n    argparser.add_argument(\"--eval_every\", type=int, default=1)\n    argparser.add_argument(\"--lr\", type=float, default=0.001)\n    argparser.add_argument(\"--dropout\", type=float, default=0.5)\n    argparser.add_argument(\"--save_pred\", type=str, default=\"\")\n    argparser.add_argument(\"--wd\", type=float, default=0)\n    argparser.add_argument(\"--num_partitions\", type=int, default=15000)\n    argparser.add_argument(\"--num_workers\", type=int, default=4)\n    argparser.add_argument(\n        \"--data_cpu\",\n        action=\"store_true\",\n        help=\"By default the script puts all node features and labels \"\n        \"on GPU when using it to save time for data copy. This may \"\n        \"be undesired if they cannot fit in GPU memory at once. 
\"\n        \"This flag disables that.\",\n    )\n    args = argparser.parse_args()\n\n    if args.gpu >= 0:\n        device = th.device(\"cuda:%d\" % args.gpu)\n    else:\n        device = th.device(\"cpu\")\n\n    # load ogbn-products data\n    data = DglNodePropPredDataset(name=\"ogbn-products\")\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n    labels = labels[:, 0]\n    print(\"Total edges before adding self-loop {}\".format(graph.num_edges()))\n    graph = dgl.remove_self_loop(graph)\n    graph = dgl.add_self_loop(graph)\n    print(\"Total edges after adding self-loop {}\".format(graph.num_edges()))\n    num_nodes = train_idx.shape[0] + val_idx.shape[0] + test_idx.shape[0]\n    assert num_nodes == graph.num_nodes()\n    mask = th.zeros(num_nodes, dtype=th.bool)\n    mask[train_idx] = True\n    graph.ndata[\"train_mask\"] = mask\n\n    graph.in_degrees(0)\n    graph.out_degrees(0)\n    graph.find_edges(0)\n\n    cluster_iter_data = ClusterIter(\n        \"ogbn-products\", graph, args.num_partitions, args.batch_size\n    )\n    cluster_iterator = DataLoader(\n        cluster_iter_data,\n        batch_size=args.batch_size,\n        shuffle=True,\n        pin_memory=True,\n        num_workers=args.num_workers,\n        collate_fn=partial(subgraph_collate_fn, graph),\n    )\n\n    in_feats = graph.ndata[\"feat\"].shape[1]\n    n_classes = (labels.max() + 1).item()\n    # Pack data\n    data = (\n        train_idx,\n        val_idx,\n        test_idx,\n        in_feats,\n        labels,\n        n_classes,\n        graph,\n        cluster_iterator,\n    )\n\n    # Run 10 times\n    test_accs = []\n    nfeat = graph.ndata.pop(\"feat\").to(device)\n    for i in range(10):\n        test_accs.append(run(args, device, data, nfeat))\n\n    print(\"Average test accuracy:\", np.mean(test_accs), \"±\", 
np.std(test_accs))\n"
  },
  {
    "path": "examples/pytorch/ogb/cluster-gat/partition_utils.py",
    "content": "from time import time\n\nimport dgl\n\nimport numpy as np\nfrom dgl import backend as F\nfrom dgl.transforms import metis_partition\n\n\ndef get_partition_list(g, psize):\n    p_gs = metis_partition(g, psize)\n    graphs = []\n    for k, val in p_gs.items():\n        nids = val.ndata[dgl.NID]\n        nids = F.asnumpy(nids)\n        graphs.append(nids)\n    return graphs\n"
  },
  {
    "path": "examples/pytorch/ogb/cluster-gat/sampler.py",
    "content": "import os\n\nimport torch\nfrom partition_utils import *\n\n\nclass ClusterIter(object):\n    \"\"\"The partition sampler given a DGLGraph and partition number.\n    The metis is used as the graph partition backend.\n    \"\"\"\n\n    def __init__(self, dn, g, psize, batch_size):\n        \"\"\"Initialize the sampler.\n\n        Paramters\n        ---------\n        dn : str\n            The dataset name.\n        g  : DGLGraph\n            The full graph of dataset\n        psize: int\n            The partition number\n        batch_size: int\n            The number of partitions in one batch\n        \"\"\"\n        self.psize = psize\n        self.batch_size = batch_size\n        # cache the partitions of known datasets&partition number\n        if dn:\n            fn = os.path.join(\"./datasets/\", dn + \"_{}.npy\".format(psize))\n            if os.path.exists(fn):\n                self.par_li = np.load(fn, allow_pickle=True)\n            else:\n                os.makedirs(\"./datasets/\", exist_ok=True)\n                self.par_li = get_partition_list(g, psize)\n                self.par_li = np.array(self.par_li, dtype=object)\n                np.save(fn, self.par_li)\n        else:\n            self.par_li = get_partition_list(g, psize)\n        par_list = []\n        for p in self.par_li:\n            par = torch.Tensor(p)\n            par_list.append(par)\n        self.par_list = par_list\n\n    def __len__(self):\n        return self.psize\n\n    def __getitem__(self, idx):\n        return self.par_li[idx]\n\n\ndef subgraph_collate_fn(g, batch):\n    nids = np.concatenate(batch).reshape(-1).astype(np.int64)\n    g1 = g.subgraph(nids)\n    g1 = dgl.remove_self_loop(g1)\n    g1 = dgl.add_self_loop(g1)\n    return g1\n"
  },
  {
    "path": "examples/pytorch/ogb/cluster-sage/README.md",
    "content": "# Cluster-SAGE on OGB Dataset\n\nRequires DGL 0.4.3post2 or later versions.\nWe use builtin metis to do the graph partition.\n\n## OGB-Product\n\nRun `main.py` and you should directly see the result.\n\nAccuracy over 10 runs: 0.7830701 ± 0.0035093208\n"
  },
  {
    "path": "examples/pytorch/ogb/cluster-sage/main.py",
    "content": "import argparse\nimport time\nimport traceback\nfrom functools import partial\n\nimport dgl\nimport dgl.function as fn\nimport dgl.nn.pytorch as dglnn\n\nimport numpy as np\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom dgl.data import RedditDataset\nfrom ogb.nodeproppred import DglNodePropPredDataset\nfrom sampler import ClusterIter, subgraph_collate_fn\nfrom torch.utils.data import DataLoader\n\n#### Neighbor sampler\n\n\nclass SAGE(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        for i in range(1, n_layers - 1):\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, g, x):\n        h = x\n        for l, conv in enumerate(self.layers):\n            h = conv(g, h)\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, x, batch_size, device):\n        \"\"\"\n        Inference with the GraphSAGE model on full neighbors (i.e. 
without neighbor sampling).\n        g : the entire graph.\n        x : the input of entire node set.\n        The inference code is written in a fashion that it could handle any number of nodes and\n        layers.\n        \"\"\"\n        # During inference with sampling, multi-layer blocks are very inefficient because\n        # lots of computations in the first few layers are repeated.\n        # Therefore, we compute the representation of all nodes layer by layer.  The nodes\n        # on each layer are of course splitted in batches.\n        # TODO: can we standardize this?\n        h = x\n        for l, conv in enumerate(self.layers):\n            h = conv(g, h)\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n\n        return h\n\n\ndef compute_acc(pred, labels):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)\n\n\ndef evaluate(model, g, labels, val_nid, test_nid, batch_size, device):\n    \"\"\"\n    Evaluate the model on the validation set specified by ``val_mask``.\n    g : The entire graph.\n    inputs : The features of all the nodes.\n    labels : The labels of all the nodes.\n    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.\n    batch_size : Number of nodes to compute at the same time.\n    device : The GPU device to evaluate on.\n    \"\"\"\n    model.eval()\n    with th.no_grad():\n        inputs = g.ndata[\"feat\"]\n        model = model.cpu()\n        pred = model.inference(g, inputs, batch_size, device)\n    model.train()\n    return (\n        compute_acc(pred[val_nid], labels[val_nid]),\n        compute_acc(pred[test_nid], labels[test_nid]),\n        pred,\n    )\n\n\ndef load_subtensor(g, labels, seeds, input_nodes, device):\n    \"\"\"\n    Copys features and labels of a set of nodes onto GPU.\n    \"\"\"\n    batch_inputs = 
g.ndata[\"feat\"][input_nodes].to(device)\n    batch_labels = labels[seeds].to(device)\n    return batch_inputs, batch_labels\n\n\n#### Entry point\ndef run(args, device, data):\n    # Unpack data\n    (\n        train_nid,\n        val_nid,\n        test_nid,\n        in_feats,\n        labels,\n        n_classes,\n        g,\n        cluster_iterator,\n    ) = data\n\n    # Define model and optimizer\n    model = SAGE(\n        in_feats,\n        args.num_hidden,\n        n_classes,\n        args.num_layers,\n        F.relu,\n        args.dropout,\n    )\n    model = model.to(device)\n    loss_fcn = nn.CrossEntropyLoss()\n    loss_fcn = loss_fcn.to(device)\n    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)\n\n    # Training loop\n    avg = 0\n    iter_tput = []\n    best_eval_acc = 0\n    best_test_acc = 0\n    for epoch in range(args.num_epochs):\n        iter_load = 0\n        iter_far = 0\n        iter_back = 0\n        iter_tl = 0\n        tic = time.time()\n\n        # Loop over the dataloader to sample the computation dependency graph as a list of\n        # blocks.\n        tic_start = time.time()\n        for step, cluster in enumerate(cluster_iterator):\n            cluster = cluster.int().to(device)\n            mask = cluster.ndata[\"train_mask\"].to(device)\n            if mask.sum() == 0:\n                continue\n            feat = cluster.ndata[\"feat\"].to(device)\n            batch_labels = cluster.ndata[\"labels\"].to(device)\n            tic_step = time.time()\n\n            batch_pred = model(cluster, feat)\n            batch_pred = batch_pred[mask]\n            batch_labels = batch_labels[mask]\n            loss = loss_fcn(batch_pred, batch_labels)\n            optimizer.zero_grad()\n            tic_far = time.time()\n            loss.backward()\n            optimizer.step()\n            tic_back = time.time()\n            iter_load += tic_step - tic_start\n            iter_far += tic_far - tic_step\n           
 iter_back += tic_back - tic_far\n\n            tic_start = time.time()\n            if step % args.log_every == 0:\n                acc = compute_acc(batch_pred, batch_labels)\n                gpu_mem_alloc = (\n                    th.cuda.max_memory_allocated() / 1000000\n                    if th.cuda.is_available()\n                    else 0\n                )\n                print(\n                    \"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | GPU {:.1f} MB\".format(\n                        epoch, step, loss.item(), acc.item(), gpu_mem_alloc\n                    )\n                )\n\n        toc = time.time()\n        print(\n            \"Epoch Time(s): {:.4f} Load {:.4f} Forward {:.4f} Backward {:.4f}\".format(\n                toc - tic, iter_load, iter_far, iter_back\n            )\n        )\n        if epoch >= 5:\n            avg += toc - tic\n\n        if epoch % args.eval_every == 0 and epoch != 0:\n            eval_acc, test_acc, pred = evaluate(\n                model, g, labels, val_nid, test_nid, args.val_batch_size, device\n            )\n            model = model.to(device)\n            if args.save_pred:\n                np.savetxt(\n                    args.save_pred + \"%02d\" % epoch,\n                    pred.argmax(1).cpu().numpy(),\n                    \"%d\",\n                )\n            print(\"Eval Acc {:.4f}\".format(eval_acc))\n            if eval_acc > best_eval_acc:\n                best_eval_acc = eval_acc\n                best_test_acc = test_acc\n            print(\n                \"Best Eval Acc {:.4f} Test Acc {:.4f}\".format(\n                    best_eval_acc, best_test_acc\n                )\n            )\n    print(\"Avg epoch time: {}\".format(avg / (epoch - 4)))\n    return best_test_acc\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"multi-gpu training\")\n    argparser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=0,\n        
help=\"GPU device ID. Use -1 for CPU training\",\n    )\n    argparser.add_argument(\"--num-epochs\", type=int, default=30)\n    argparser.add_argument(\"--num-hidden\", type=int, default=256)\n    argparser.add_argument(\"--num-layers\", type=int, default=3)\n    argparser.add_argument(\"--batch-size\", type=int, default=32)\n    argparser.add_argument(\"--val-batch-size\", type=int, default=10000)\n    argparser.add_argument(\"--log-every\", type=int, default=20)\n    argparser.add_argument(\"--eval-every\", type=int, default=1)\n    argparser.add_argument(\"--lr\", type=float, default=0.001)\n    argparser.add_argument(\"--dropout\", type=float, default=0.5)\n    argparser.add_argument(\"--save-pred\", type=str, default=\"\")\n    argparser.add_argument(\"--wd\", type=float, default=0)\n    argparser.add_argument(\"--num_partitions\", type=int, default=15000)\n    args = argparser.parse_args()\n\n    if args.gpu >= 0:\n        device = th.device(\"cuda:%d\" % args.gpu)\n    else:\n        device = th.device(\"cpu\")\n\n    # load ogbn-products data\n    data = DglNodePropPredDataset(name=\"ogbn-products\")\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n    labels = labels[:, 0]\n    num_nodes = train_idx.shape[0] + val_idx.shape[0] + test_idx.shape[0]\n    assert num_nodes == graph.num_nodes()\n    graph.ndata[\"labels\"] = labels\n    mask = th.zeros(num_nodes, dtype=th.bool)\n    mask[train_idx] = True\n    graph.ndata[\"train_mask\"] = mask\n    mask = th.zeros(num_nodes, dtype=th.bool)\n    mask[val_idx] = True\n    graph.ndata[\"valid_mask\"] = mask\n    mask = th.zeros(num_nodes, dtype=th.bool)\n    mask[test_idx] = True\n    graph.ndata[\"test_mask\"] = mask\n\n    graph.in_degrees(0)\n    graph.out_degrees(0)\n    graph.find_edges(0)\n\n    cluster_iter_data = ClusterIter(\n        
\"ogbn-products\",\n        graph,\n        args.num_partitions,\n        args.batch_size,\n        th.cat([train_idx, val_idx, test_idx]),\n    )\n    idx = th.arange(args.num_partitions // args.batch_size)\n    cluster_iterator = DataLoader(\n        cluster_iter_data,\n        batch_size=32,\n        shuffle=True,\n        pin_memory=True,\n        num_workers=4,\n        collate_fn=partial(subgraph_collate_fn, graph),\n    )\n\n    in_feats = graph.ndata[\"feat\"].shape[1]\n    print(in_feats)\n    n_classes = (labels.max() + 1).item()\n    # Pack data\n    data = (\n        train_idx,\n        val_idx,\n        test_idx,\n        in_feats,\n        labels,\n        n_classes,\n        graph,\n        cluster_iterator,\n    )\n\n    # Run 10 times\n    test_accs = []\n    for i in range(10):\n        test_accs.append(run(args, device, data))\n        print(\n            \"Average test accuracy:\", np.mean(test_accs), \"±\", np.std(test_accs)\n        )\n"
  },
  {
    "path": "examples/pytorch/ogb/cluster-sage/partition_utils.py",
    "content": "from time import time\n\nimport dgl\n\nimport numpy as np\nfrom dgl import backend as F\nfrom dgl.transforms import metis_partition\n\n\ndef get_partition_list(g, psize):\n    p_gs = metis_partition(g, psize)\n    graphs = []\n    for k, val in p_gs.items():\n        nids = val.ndata[dgl.NID]\n        nids = F.asnumpy(nids)\n        graphs.append(nids)\n    return graphs\n"
  },
  {
    "path": "examples/pytorch/ogb/cluster-sage/sampler.py",
    "content": "import os\nimport random\nimport time\n\nimport torch\nfrom partition_utils import *\n\nimport dgl.function as fn\n\n\nclass ClusterIter(object):\n    \"\"\"The partition sampler given a DGLGraph and partition number.\n    The metis is used as the graph partition backend.\n    \"\"\"\n\n    def __init__(self, dn, g, psize, batch_size, seed_nid):\n        \"\"\"Initialize the sampler.\n\n        Paramters\n        ---------\n        dn : str\n            The dataset name.\n        g  : DGLGraph\n            The full graph of dataset\n        psize: int\n            The partition number\n        batch_size: int\n            The number of partitions in one batch\n        seed_nid: np.ndarray\n            The training nodes ids, used to extract the training graph\n        \"\"\"\n        self.psize = psize\n        self.batch_size = batch_size\n        # cache the partitions of known datasets&partition number\n        if dn:\n            fn = os.path.join(\"./datasets/\", dn + \"_{}.npy\".format(psize))\n            if os.path.exists(fn):\n                self.par_li = np.load(fn, allow_pickle=True)\n            else:\n                os.makedirs(\"./datasets/\", exist_ok=True)\n                self.par_li = get_partition_list(g, psize)\n                self.par_li = np.array(self.par_li, dtype=object)\n                np.save(fn, self.par_li)\n        else:\n            self.par_li = get_partition_list(g, psize)\n        par_list = []\n        for p in self.par_li:\n            par = torch.Tensor(p)\n            par_list.append(par)\n        self.par_list = par_list\n\n    # use one side normalization\n    def get_norm(self, g):\n        norm = 1.0 / g.in_degrees().float().unsqueeze(1)\n        norm[torch.isinf(norm)] = 0\n        norm = norm.to(self.g.ndata[\"feat\"].device)\n        return norm\n\n    def __len__(self):\n        return self.psize\n\n    def __getitem__(self, idx):\n        return self.par_li[idx]\n\n\ndef subgraph_collate_fn(g, 
batch):\n    nids = np.concatenate(batch).reshape(-1).astype(np.int64)\n    g1 = g.subgraph(nids)\n    return g1\n"
  },
  {
    "path": "examples/pytorch/ogb/deepwalk/README.md",
    "content": "# DeepWalk Example\n\n- Paper link: [here](https://arxiv.org/pdf/1403.6652.pdf)\n- Other implementation: [gensim](https://github.com/phanein/deepwalk), [deepwalk-c](https://github.com/xgfs/deepwalk-c)\n\nThe implementation includes multi-processing training with CPU and mixed training with CPU and multi-GPU.\n\n## Dependencies\n- PyTorch 1.5.0+\n\n## Tested version\n- PyTorch 1.5.0\n- DGL 0.5.0\n\n## Input data\nCurrently, we support two builtin dataset: youtube and blog. Use --data\\_file youtube to select youtube dataset and --data\\_file blog to select blog dataset.\nThe data is avaliable at  https://data.dgl.ai/dataset/DeepWalk/youtube.zip and https://data.dgl.ai/dataset/DeepWalk/blog.zip\nThe youtube.zip includes both youtube-net.txt, youtube-vocab.txt and youtube-label.txt; The blog.zip includes both blog-net.txt, blog-vocab.txt and blog-label.txt. \n\nFor other datasets please pass the full path to the trainer through --data\\_file and the format of a network file should follow:\n```\n1(node id) 2(node id)\n1 3\n1 4\n2 4\n...\n```\n\n### How to run the code\nTo run the code:\n```\npython3 deepwalk.py --data_file youtube --output_emb_file emb.txt --mix --lr 0.2 --gpus 0 1 2 3 --batch_size 100 --negative 5\n```\n\n### How to save the embedding\nBy default the trained embedding is saved under --output\\_embe\\_file FILE\\_NAME as a numpy object.\nTo save the trained embedding in raw format(txt format), please use --save\\_in\\_txt argument.\n\n### Evaluation\n\nTo evalutate embedding on multi-label classification, please refer to [here](https://github.com/ShawXh/Evaluate-Embedding)\n\nYouTube (1M nodes).\n\n| Implementation | Macro-F1 (%) <br> 1% &emsp;&emsp; 3% &emsp;&emsp; 5% &emsp;&emsp; 7% &emsp;&emsp; 9% | Micro-F1 (%) <br> 1% &emsp;&emsp; 3% &emsp;&emsp; 5% &emsp;&emsp; 7% &emsp;&emsp; 9% |\n|----|----|----|\n| gensim.word2vec(hs) | 28.73 &emsp; 32.51 &emsp; 33.67 &emsp; 34.28 &emsp; 34.79 | 35.73 &emsp; 38.34 &emsp; 39.37 &emsp; 40.08 
&emsp; 40.77 | \n| gensim.word2vec(ns) | 28.18 &emsp; 32.25 &emsp; 33.56 &emsp; 34.60 &emsp; 35.22 | 35.35 &emsp; 37.69 &emsp; 38.08 &emsp; 40.24 &emsp; 41.09 | \n|        ours         | 24.58 &emsp; 31.23 &emsp; 33.97 &emsp; 35.41 &emsp; 36.48 | 38.93 &emsp; 43.17 &emsp; 44.73 &emsp; 45.42 &emsp; 45.92 | \n\nThe comparison between running time is shown as below, where the numbers in the brackets denote time used on random-walk.\n\n| Implementation | gensim.word2vec(hs) | gensim.word2vec(ns) | Ours |\n|----|----|----|----|\n| Time (s) |     27119.6(1759.8)    |    10580.3(1704.3)    | 428.89 |\n\nParameters.\n- walk_length = 80, number_walks = 10, window_size = 5\n- Ours: 4GPU (Tesla V100), lr = 0.2, batchs_size = 128, neg_weight = 5, negative = 1, num_thread = 4\n- Others: workers = 8, negative = 5\n\nSpeeding-up with mixed CPU & multi-GPU. The used parameters are the same as above.\n|  #GPUs   |   1   |   2   |   4   |\n|----------|-------|-------|-------|\n| Time (s) |1419.64| 952.04|428.89 |\n\n## OGB Dataset\n### How to load ogb data\nYou can run the code directly with:\n```\npython3 deepwalk --ogbl_name xxx --load_from_ogbl\n```\nHowever, ogb.linkproppred might not be compatible with mixed training with multi-gpu. 
If you want to do mixed training, please use no more than 1 gpu by the command above.\n\n### Evaluation\nFor evaluatation we follow the code mlp.py provided by ogb [here](https://github.com/snap-stanford/ogb/blob/master/examples/linkproppred/collab/mlp.py).\n\n### Used config\nogbl-collab\n```\npython3 deepwalk.py --ogbl_name ogbl-collab --load_from_ogbl --save_in_pt --output_emb_file collab-embedding.pt --num_walks 50 --window_size 2 --walk_length 40 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.01 --mix --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 128 --use_context_weight\ncd ./ogb/blob/master/examples/linkproppred/collab/\ncp embedding_pt_file_path ./\npython3 mlp.py --device 0 --runs 10 --use_node_embedding\n```\n\nogbl-ddi\n```\npython3 deepwalk.py --ogbl_name ogbl-ddi --load_from_ogbl --save_in_pt --output_emb_file ddi-embedding.pt --num_walks 50 --window_size 2 --walk_length 80 --lr 0.1 --negative 1 --neg_weight 1 --lap_norm 0.05 --only_gpu --gpus 0 --num_threads 4 --print_interval 2000 --print_loss --batch_size 16 --use_context_weight\ncd ./ogb/blob/master/examples/linkproppred/ddi/\ncp embedding_pt_file_path ./\npython3 mlp.py --device 0 --runs 10 --epochs 100\n```\n\nogbl-ppa\n```\npython3 deepwalk.py --ogbl_name ogbl-ppa --load_from_ogbl --save_in_pt --output_emb_file ppa-embedding.pt --negative 1 --neg_weight 1 --batch_size 64 --print_interval 2000 --print_loss --window_size 1 --num_walks 30 --walk_length 80 --lr 0.1 --lap_norm 0.02 --mix --gpus 0 --num_threads 4\ncp embedding_pt_file_path ./\npython3 mlp.py --device 2 --runs 10\n```\n\nogbl-citation\n```\npython3 deepwalk.py --ogbl_name ogbl-citation --load_from_ogbl --save_in_pt --output_emb_file embedding.pt --window_size 2 --num_walks 10 --negative 1 --neg_weight 1 --walk_length 80 --batch_size 128 --print_loss --print_interval 1000 --mix --gpus 0 --use_context_weight --num_threads 4 --lap_norm 0.01 --lr 0.1\ncp embedding_pt_file_path ./\npython3 mlp.py --device 2 
--runs 10 --use_node_embedding\n```\n\n### OGBL Results\nogbl-collab\n<br>#params: 61258346(model) + 131841(mlp) = 61390187\n<br>Hits@10\n<br>&emsp;Highest Train: 74.83 ± 4.79\n<br>&emsp;Highest Valid: 40.03 ± 2.98\n<br>&emsp;&emsp;Final Train: 74.51 ± 4.92\n<br>&emsp;&emsp;Final Test: 31.13 ± 2.47\n<br>Hits@50\n<br>&emsp;Highest Train: 98.83 ± 0.15\n<br>&emsp;Highest Valid: 60.61 ± 0.32\n<br>&emsp;&emsp;Final Train: 98.74 ± 0.17\n<br>&emsp;&emsp;Final Test: 50.37 ± 0.34\n<br>Hits@100\n<br>&emsp;Highest Train: 99.86 ± 0.04\n<br>&emsp;Highest Valid: 66.64 ± 0.32\n<br>&emsp;&emsp;Final Train: 99.84 ± 0.06\n<br>&emsp;&emsp;Final Test: 56.88 ± 0.37\n\n<br>obgl-ddi\n<br>#params: 1444840(model) + 99073(mlp) = 1543913\n<br>Hits@10\n<br>&emsp;Highest Train: 33.91 ± 2.01\n<br>&emsp;Highest Valid: 30.96 ± 1.89\n<br>&emsp;&emsp;Final Train: 33.90 ± 2.00\n<br>&emsp;&emsp;Final Test: 15.16 ± 4.28\n<br>Hits@20\n<br>&emsp;Highest Train: 44.64 ± 1.71\n<br>&emsp;Highest Valid: 41.32 ± 1.69\n<br>&emsp;&emsp;Final Train: 44.62 ± 1.69\n<br>&emsp;&emsp;Final Test: 26.42 ± 6.10\n<br>Hits@30\n<br>&emsp;Highest Train: 51.01 ± 1.72\n<br>&emsp;Highest Valid: 47.64 ± 1.71\n<br>&emsp;&emsp;Final Train: 50.99 ± 1.72\n<br>&emsp;&emsp;Final Test: 33.56 ± 3.95\n\n<br>ogbl-ppa\n<br>#params: 150024820(model) + 113921(mlp) = 150138741\n<br>Hits@10\n<br>&emsp;Highest Train: 4.78 ± 0.73\n<br>&emsp;Highest Valid: 4.30 ± 0.68\n<br>&emsp;&emsp;Final Train: 4.77 ± 0.73\n<br>&emsp;&emsp;Final Test: 2.67 ± 0.42\n<br>Hits@50\n<br>&emsp;Highest Train: 18.82 ± 1.07\n<br>&emsp;Highest Valid: 17.26 ± 1.01\n<br>&emsp;&emsp;Final Train: 18.82 ± 1.07\n<br>&emsp;&emsp;Final Test: 17.34 ± 2.09\n<br>Hits@100\n<br>&emsp;Highest Train: 31.29 ± 2.11\n<br>&emsp;Highest Valid: 28.97 ± 1.92\n<br>&emsp;&emsp;Final Train: 31.28 ± 2.12\n<br>&emsp;&emsp;Final Test: 28.88 ± 1.53\n\n<br>ogbl-citation\n<br>#params: 757811178(model) + 131841(mlp) = 757943019\n<br>MRR\n<br>&emsp;Highest Train: 0.9381 ± 0.0003\n<br>&emsp;Highest 
Valid: 0.8469 ± 0.0003\n<br>&emsp;&emsp;Final Train: 0.9377 ± 0.0004\n<br>&emsp;&emsp;Final Test: 0.8479 ± 0.0003\n\n### Notes\n#### Multi-GPU issues\nFor efficiency, the results of ogbl-collab, ogbl-ppa, and ogbl-ddi were obtained with multiple GPUs. Since ogb is somehow incompatible with our multi-GPU implementation, some preprocessing is needed first. The command is:\n```\npython3 load_dataset.py --name dataset_name\n```\nIt will write a data file to the current working directory. For example, if `dataset_name` is `ogbl-collab`, then a file `ogbl-collab-net.txt` will be generated. Then we run\n```\npython3 deepwalk.py --data_file data_file_path\n```\nwhere the other parameters are the same as in the configs above, except that `--load_from_ogbl` and `--ogbl_name` are not used.\n\n#### Others\nThe performance on ogbl-ddi and ogbl-ppa may be somewhat unstable."
  },
  {
    "path": "examples/pytorch/ogb/deepwalk/deepwalk.py",
    "content": "import argparse\nimport os\nimport random\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.multiprocessing as mp\nfrom model import SkipGramModel\nfrom reading_data import DeepwalkDataset\nfrom torch.utils.data import DataLoader\nfrom utils import shuffle_walks, sum_up_params\n\n\nclass DeepwalkTrainer:\n    def __init__(self, args):\n        \"\"\"Initializing the trainer with the input arguments\"\"\"\n        self.args = args\n        self.dataset = DeepwalkDataset(\n            net_file=args.data_file,\n            map_file=args.map_file,\n            walk_length=args.walk_length,\n            window_size=args.window_size,\n            num_walks=args.num_walks,\n            batch_size=args.batch_size,\n            negative=args.negative,\n            gpus=args.gpus,\n            fast_neg=args.fast_neg,\n            ogbl_name=args.ogbl_name,\n            load_from_ogbl=args.load_from_ogbl,\n        )\n        self.emb_size = self.dataset.G.num_nodes()\n        self.emb_model = None\n\n    def init_device_emb(self):\n        \"\"\"set the device before training\n        will be called once in fast_train_mp / fast_train\n        \"\"\"\n        choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])\n        assert (\n            choices == 1\n        ), \"Must choose only *one* training mode in [only_cpu, only_gpu, mix]\"\n\n        # initializing embedding on CPU\n        self.emb_model = SkipGramModel(\n            emb_size=self.emb_size,\n            emb_dimension=self.args.dim,\n            walk_length=self.args.walk_length,\n            window_size=self.args.window_size,\n            batch_size=self.args.batch_size,\n            only_cpu=self.args.only_cpu,\n            only_gpu=self.args.only_gpu,\n            mix=self.args.mix,\n            neg_weight=self.args.neg_weight,\n            negative=self.args.negative,\n            lr=self.args.lr,\n            lap_norm=self.args.lap_norm,\n            
fast_neg=self.args.fast_neg,\n            record_loss=self.args.print_loss,\n            norm=self.args.norm,\n            use_context_weight=self.args.use_context_weight,\n            async_update=self.args.async_update,\n            num_threads=self.args.num_threads,\n        )\n\n        torch.set_num_threads(self.args.num_threads)\n        if self.args.only_gpu:\n            print(\"Run in 1 GPU\")\n            assert self.args.gpus[0] >= 0\n            self.emb_model.all_to_device(self.args.gpus[0])\n        elif self.args.mix:\n            print(\"Mix CPU with %d GPU\" % len(self.args.gpus))\n            if len(self.args.gpus) == 1:\n                assert (\n                    self.args.gpus[0] >= 0\n                ), \"mix CPU with GPU should have available GPU\"\n                self.emb_model.set_device(self.args.gpus[0])\n        else:\n            print(\"Run in CPU process\")\n            self.args.gpus = [torch.device(\"cpu\")]\n\n    def train(self):\n        \"\"\"train the embedding\"\"\"\n        if len(self.args.gpus) > 1:\n            self.fast_train_mp()\n        else:\n            self.fast_train()\n\n    def fast_train_mp(self):\n        \"\"\"multi-cpu-core or mix cpu & multi-gpu\"\"\"\n        self.init_device_emb()\n        self.emb_model.share_memory()\n\n        if self.args.count_params:\n            sum_up_params(self.emb_model)\n\n        start_all = time.time()\n        ps = []\n\n        for i in range(len(self.args.gpus)):\n            p = mp.Process(\n                target=self.fast_train_sp, args=(i, self.args.gpus[i])\n            )\n            ps.append(p)\n            p.start()\n\n        for p in ps:\n            p.join()\n\n        print(\"Used time: %.2fs\" % (time.time() - start_all))\n        if self.args.save_in_txt:\n            self.emb_model.save_embedding_txt(\n                self.dataset, self.args.output_emb_file\n            )\n        elif self.args.save_in_pt:\n            
self.emb_model.save_embedding_pt(\n                self.dataset, self.args.output_emb_file\n            )\n        else:\n            self.emb_model.save_embedding(\n                self.dataset, self.args.output_emb_file\n            )\n\n    def fast_train_sp(self, rank, gpu_id):\n        \"\"\"a subprocess for fast_train_mp\"\"\"\n        if self.args.mix:\n            self.emb_model.set_device(gpu_id)\n\n        torch.set_num_threads(self.args.num_threads)\n        if self.args.async_update:\n            self.emb_model.create_async_update()\n\n        sampler = self.dataset.create_sampler(rank)\n\n        dataloader = DataLoader(\n            dataset=sampler.seeds,\n            batch_size=self.args.batch_size,\n            collate_fn=sampler.sample,\n            shuffle=False,\n            drop_last=False,\n            num_workers=self.args.num_sampler_threads,\n        )\n        num_batches = len(dataloader)\n        print(\n            \"num batchs: %d in process [%d] GPU [%d]\"\n            % (num_batches, rank, gpu_id)\n        )\n        # number of positive node pairs in a sequence\n        num_pos = int(\n            2 * self.args.walk_length * self.args.window_size\n            - self.args.window_size * (self.args.window_size + 1)\n        )\n\n        start = time.time()\n        with torch.no_grad():\n            for i, walks in enumerate(dataloader):\n                if self.args.fast_neg:\n                    self.emb_model.fast_learn(walks)\n                else:\n                    # do negative sampling\n                    bs = len(walks)\n                    neg_nodes = torch.LongTensor(\n                        np.random.choice(\n                            self.dataset.neg_table,\n                            bs * num_pos * self.args.negative,\n                            replace=True,\n                        )\n                    )\n                    self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)\n\n                if i > 0 and i 
% self.args.print_interval == 0:\n                    if self.args.print_loss:\n                        print(\n                            \"GPU-[%d] batch %d time: %.2fs loss: %.4f\"\n                            % (\n                                gpu_id,\n                                i,\n                                time.time() - start,\n                                -sum(self.emb_model.loss)\n                                / self.args.print_interval,\n                            )\n                        )\n                        self.emb_model.loss = []\n                    else:\n                        print(\n                            \"GPU-[%d] batch %d time: %.2fs\"\n                            % (gpu_id, i, time.time() - start)\n                        )\n                    start = time.time()\n\n            if self.args.async_update:\n                self.emb_model.finish_async_update()\n\n    def fast_train(self):\n        \"\"\"fast train with dataloader with only gpu / only cpu\"\"\"\n        # the number of postive node pairs of a node sequence\n        num_pos = (\n            2 * self.args.walk_length * self.args.window_size\n            - self.args.window_size * (self.args.window_size + 1)\n        )\n        num_pos = int(num_pos)\n\n        self.init_device_emb()\n\n        if self.args.async_update:\n            self.emb_model.share_memory()\n            self.emb_model.create_async_update()\n\n        if self.args.count_params:\n            sum_up_params(self.emb_model)\n\n        sampler = self.dataset.create_sampler(0)\n\n        dataloader = DataLoader(\n            dataset=sampler.seeds,\n            batch_size=self.args.batch_size,\n            collate_fn=sampler.sample,\n            shuffle=False,\n            drop_last=False,\n            num_workers=self.args.num_sampler_threads,\n        )\n\n        num_batches = len(dataloader)\n        print(\"num batchs: %d\\n\" % num_batches)\n\n        start_all = time.time()\n    
    start = time.time()\n        with torch.no_grad():\n            max_i = num_batches\n            for i, walks in enumerate(dataloader):\n                if self.args.fast_neg:\n                    self.emb_model.fast_learn(walks)\n                else:\n                    # do negative sampling\n                    bs = len(walks)\n                    neg_nodes = torch.LongTensor(\n                        np.random.choice(\n                            self.dataset.neg_table,\n                            bs * num_pos * self.args.negative,\n                            replace=True,\n                        )\n                    )\n                    self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)\n\n                if i > 0 and i % self.args.print_interval == 0:\n                    if self.args.print_loss:\n                        print(\n                            \"Batch %d training time: %.2fs loss: %.4f\"\n                            % (\n                                i,\n                                time.time() - start,\n                                -sum(self.emb_model.loss)\n                                / self.args.print_interval,\n                            )\n                        )\n                        self.emb_model.loss = []\n                    else:\n                        print(\n                            \"Batch %d, training time: %.2fs\"\n                            % (i, time.time() - start)\n                        )\n                    start = time.time()\n\n            if self.args.async_update:\n                self.emb_model.finish_async_update()\n\n        print(\"Training used time: %.2fs\" % (time.time() - start_all))\n        if self.args.save_in_txt:\n            self.emb_model.save_embedding_txt(\n                self.dataset, self.args.output_emb_file\n            )\n        elif self.args.save_in_pt:\n            self.emb_model.save_embedding_pt(\n                self.dataset, 
self.args.output_emb_file\n            )\n        else:\n            self.emb_model.save_embedding(\n                self.dataset, self.args.output_emb_file\n            )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"DeepWalk\")\n    # input files\n    ## personal datasets\n    parser.add_argument(\n        \"--data_file\",\n        type=str,\n        help=\"path of the txt network file, builtin dataset include youtube-net and blog-net\",\n    )\n    ## ogbl datasets\n    parser.add_argument(\n        \"--ogbl_name\", type=str, help=\"name of ogbl dataset, e.g. ogbl-ddi\"\n    )\n    parser.add_argument(\n        \"--load_from_ogbl\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether load dataset from ogbl\",\n    )\n\n    # output files\n    parser.add_argument(\n        \"--save_in_txt\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether save dat in txt format or npy\",\n    )\n    parser.add_argument(\n        \"--save_in_pt\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether save dat in pt format or npy\",\n    )\n    parser.add_argument(\n        \"--output_emb_file\",\n        type=str,\n        default=\"emb.npy\",\n        help=\"path of the output npy embedding file\",\n    )\n    parser.add_argument(\n        \"--map_file\",\n        type=str,\n        default=\"nodeid_to_index.pickle\",\n        help=\"path of the mapping dict that maps node ids to embedding index\",\n    )\n    parser.add_argument(\n        \"--norm\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether to do normalization over node embedding after training\",\n    )\n\n    # model parameters\n    parser.add_argument(\n        \"--dim\", default=128, type=int, help=\"embedding dimensions\"\n    )\n    parser.add_argument(\n        \"--window_size\", default=5, type=int, help=\"context window size\"\n    )\n    
parser.add_argument(\n        \"--use_context_weight\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether to add weights over nodes in the context window\",\n    )\n    parser.add_argument(\n        \"--num_walks\",\n        default=10,\n        type=int,\n        help=\"number of walks for each node\",\n    )\n    parser.add_argument(\n        \"--negative\",\n        default=1,\n        type=int,\n        help=\"negative samples for each positve node pair\",\n    )\n    parser.add_argument(\n        \"--batch_size\",\n        default=128,\n        type=int,\n        help=\"number of node sequences in each batch\",\n    )\n    parser.add_argument(\n        \"--walk_length\",\n        default=80,\n        type=int,\n        help=\"number of nodes in a sequence\",\n    )\n    parser.add_argument(\n        \"--neg_weight\", default=1.0, type=float, help=\"negative weight\"\n    )\n    parser.add_argument(\n        \"--lap_norm\",\n        default=0.01,\n        type=float,\n        help=\"weight of laplacian normalization, recommend to set as 0.1 / windoe_size\",\n    )\n\n    # training parameters\n    parser.add_argument(\n        \"--print_interval\",\n        default=100,\n        type=int,\n        help=\"number of batches between printing\",\n    )\n    parser.add_argument(\n        \"--print_loss\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether print loss during training\",\n    )\n    parser.add_argument(\"--lr\", default=0.2, type=float, help=\"learning rate\")\n\n    # optimization settings\n    parser.add_argument(\n        \"--mix\",\n        default=False,\n        action=\"store_true\",\n        help=\"mixed training with CPU and GPU\",\n    )\n    parser.add_argument(\n        \"--gpus\",\n        type=int,\n        default=[-1],\n        nargs=\"+\",\n        help=\"a list of active gpu ids, e.g. 
0, used with --mix\",\n    )\n    parser.add_argument(\n        \"--only_cpu\",\n        default=False,\n        action=\"store_true\",\n        help=\"training with CPU\",\n    )\n    parser.add_argument(\n        \"--only_gpu\",\n        default=False,\n        action=\"store_true\",\n        help=\"training with GPU\",\n    )\n    parser.add_argument(\n        \"--async_update\",\n        default=False,\n        action=\"store_true\",\n        help=\"mixed training asynchronously, not recommended\",\n    )\n\n    parser.add_argument(\n        \"--true_neg\",\n        default=False,\n        action=\"store_true\",\n        help=\"If not specified, this program will use \"\n        \"a faster negative sampling method, \"\n        \"but the samples might be false negative \"\n        \"with a small probability. If specified, \"\n        \"this program will generate a true negative sample table,\"\n        \"and select from it when doing negative samling\",\n    )\n    parser.add_argument(\n        \"--num_threads\",\n        default=8,\n        type=int,\n        help=\"number of threads used for each CPU-core/GPU\",\n    )\n    parser.add_argument(\n        \"--num_sampler_threads\",\n        default=2,\n        type=int,\n        help=\"number of threads used for sampling\",\n    )\n\n    parser.add_argument(\n        \"--count_params\",\n        default=False,\n        action=\"store_true\",\n        help=\"count the params, exit once counting over\",\n    )\n\n    args = parser.parse_args()\n    args.fast_neg = not args.true_neg\n    if args.async_update:\n        assert args.mix, \"--async_update only with --mix\"\n\n    start_time = time.time()\n    trainer = DeepwalkTrainer(args)\n    trainer.train()\n    print(\"Total used time: %.2f\" % (time.time() - start_time))\n"
  },
  {
    "path": "examples/pytorch/ogb/deepwalk/load_dataset.py",
    "content": "\"\"\" load dataset from ogb \"\"\"\n\nimport argparse\nimport time\n\nfrom ogb.linkproppred import DglLinkPropPredDataset\n\n\ndef load_from_ogbl_with_name(name):\n    choices = [\"ogbl-collab\", \"ogbl-ddi\", \"ogbl-ppa\", \"ogbl-citation\"]\n    assert name in choices, \"name must be selected from \" + str(choices)\n    dataset = DglLinkPropPredDataset(name)\n    return dataset[0]\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--name\",\n        type=str,\n        choices=[\"ogbl-collab\", \"ogbl-ddi\", \"ogbl-ppa\", \"ogbl-citation\"],\n        default=\"ogbl-collab\",\n        help=\"name of datasets by ogb\",\n    )\n    args = parser.parse_args()\n\n    print(\"loading graph... it might take some time\")\n    name = args.name\n    g = load_from_ogbl_with_name(name=name)\n\n    try:\n        w = g.edata[\"edge_weight\"]\n        weighted = True\n    except:\n        weighted = False\n\n    edge_num = g.edges()[0].shape[0]\n    src = list(g.edges()[0])\n    tgt = list(g.edges()[1])\n    if weighted:\n        weight = list(g.edata[\"edge_weight\"])\n\n    print(\"writing...\")\n    start_time = time.time()\n    with open(name + \"-net.txt\", \"w\") as f:\n        for i in range(edge_num):\n            if weighted:\n                f.write(\n                    str(src[i].item())\n                    + \" \"\n                    + str(tgt[i].item())\n                    + \" \"\n                    + str(weight[i].item())\n                    + \"\\n\"\n                )\n            else:\n                f.write(\n                    str(src[i].item()) + \" \" + str(tgt[i].item()) + \" \" + \"1\\n\"\n                )\n    print(\"writing used time: %d s\" % int(time.time() - start_time))\n"
  },
  {
    "path": "examples/pytorch/ogb/deepwalk/model.py",
    "content": "import random\n\nimport numpy as np\nimport torch\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.multiprocessing import Queue\nfrom torch.nn import init\n\n\ndef init_emb2pos_index(walk_length, window_size, batch_size):\n    \"\"\"select embedding of positive nodes from a batch of node embeddings\n\n    Return\n    ------\n    index_emb_posu torch.LongTensor : the indices of u_embeddings\n    index_emb_posv torch.LongTensor : the indices of v_embeddings\n\n    Usage\n    -----\n    # emb_u.shape: [batch_size * walk_length, dim]\n    batch_emb2posu = torch.index_select(emb_u, 0, index_emb_posu)\n    \"\"\"\n    idx_list_u = []\n    idx_list_v = []\n    for b in range(batch_size):\n        for i in range(walk_length):\n            for j in range(i - window_size, i):\n                if j >= 0:\n                    idx_list_u.append(j + b * walk_length)\n                    idx_list_v.append(i + b * walk_length)\n            for j in range(i + 1, i + 1 + window_size):\n                if j < walk_length:\n                    idx_list_u.append(j + b * walk_length)\n                    idx_list_v.append(i + b * walk_length)\n\n    # [num_pos * batch_size]\n    index_emb_posu = torch.LongTensor(idx_list_u)\n    index_emb_posv = torch.LongTensor(idx_list_v)\n\n    return index_emb_posu, index_emb_posv\n\n\ndef init_emb2neg_index(walk_length, window_size, negative, batch_size):\n    \"\"\"select embedding of negative nodes from a batch of node embeddings\n    for fast negative sampling\n\n    Return\n    ------\n    index_emb_negu torch.LongTensor : the indices of u_embeddings\n    index_emb_negv torch.LongTensor : the indices of v_embeddings\n\n    Usage\n    -----\n    # emb_u.shape: [batch_size * walk_length, dim]\n    batch_emb2negu = torch.index_select(emb_u, 0, index_emb_negu)\n    \"\"\"\n    idx_list_u = []\n    for b in range(batch_size):\n        for i in range(walk_length):\n            for 
j in range(i - window_size, i):\n                if j >= 0:\n                    idx_list_u += [i + b * walk_length] * negative\n            for j in range(i + 1, i + 1 + window_size):\n                if j < walk_length:\n                    idx_list_u += [i + b * walk_length] * negative\n\n    idx_list_v = (\n        list(range(batch_size * walk_length)) * negative * window_size * 2\n    )\n    random.shuffle(idx_list_v)\n    idx_list_v = idx_list_v[: len(idx_list_u)]\n\n    # [bs * walk_length * negative]\n    index_emb_negu = torch.LongTensor(idx_list_u)\n    index_emb_negv = torch.LongTensor(idx_list_v)\n\n    return index_emb_negu, index_emb_negv\n\n\ndef init_weight(walk_length, window_size, batch_size):\n    \"\"\"init context weight\"\"\"\n    weight = []\n    for b in range(batch_size):\n        for i in range(walk_length):\n            for j in range(i - window_size, i):\n                if j >= 0:\n                    weight.append(1.0 - float(i - j - 1) / float(window_size))\n            for j in range(i + 1, i + 1 + window_size):\n                if j < walk_length:\n                    weight.append(1.0 - float(j - i - 1) / float(window_size))\n\n    # [num_pos * batch_size]\n    return torch.Tensor(weight).unsqueeze(1)\n\n\ndef init_empty_grad(emb_dimension, walk_length, batch_size):\n    \"\"\"initialize gradient matrix\"\"\"\n    grad_u = torch.zeros((batch_size * walk_length, emb_dimension))\n    grad_v = torch.zeros((batch_size * walk_length, emb_dimension))\n\n    return grad_u, grad_v\n\n\ndef adam(grad, state_sum, nodes, lr, device, only_gpu):\n    \"\"\"calculate gradients according to adam\"\"\"\n    grad_sum = (grad * grad).mean(1)\n    if not only_gpu:\n        grad_sum = grad_sum.cpu()\n    state_sum.index_add_(0, nodes, grad_sum)  # cpu\n    std = state_sum[nodes].to(device)  # gpu\n    std_values = std.sqrt_().add_(1e-10).unsqueeze(1)\n    grad = lr * grad / std_values  # gpu\n\n    return grad\n\n\ndef async_update(num_threads, model, 
queue):\n    \"\"\"asynchronous embedding update\"\"\"\n    torch.set_num_threads(num_threads)\n    while True:\n        (grad_u, grad_v, grad_v_neg, nodes, neg_nodes) = queue.get()\n        if grad_u is None:\n            return\n        with torch.no_grad():\n            model.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)\n            model.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)\n            if neg_nodes is not None:\n                model.v_embeddings.weight.data.index_add_(\n                    0, neg_nodes.view(-1), grad_v_neg\n                )\n\n\nclass SkipGramModel(nn.Module):\n    \"\"\"Negative sampling based skip-gram\"\"\"\n\n    def __init__(\n        self,\n        emb_size,\n        emb_dimension,\n        walk_length,\n        window_size,\n        batch_size,\n        only_cpu,\n        only_gpu,\n        mix,\n        neg_weight,\n        negative,\n        lr,\n        lap_norm,\n        fast_neg,\n        record_loss,\n        norm,\n        use_context_weight,\n        async_update,\n        num_threads,\n    ):\n        \"\"\"initialize embedding on CPU\n\n        Paremeters\n        ----------\n        emb_size int : number of nodes\n        emb_dimension int : embedding dimension\n        walk_length int : number of nodes in a sequence\n        window_size int : context window size\n        batch_size int : number of node sequences in each batch\n        only_cpu bool : training with CPU\n        only_gpu bool : training with GPU\n        mix bool : mixed training with CPU and GPU\n        negative int : negative samples for each positve node pair\n        neg_weight float : negative weight\n        lr float : initial learning rate\n        lap_norm float : weight of laplacian normalization\n        fast_neg bool : do negative sampling inside a batch\n        record_loss bool : print the loss during training\n        norm bool : do normalizatin on the embedding after training\n        
use_context_weight : give different weights to the nodes in a context window\n        async_update : asynchronous training\n        \"\"\"\n        super(SkipGramModel, self).__init__()\n        self.emb_size = emb_size\n        self.emb_dimension = emb_dimension\n        self.walk_length = walk_length\n        self.window_size = window_size\n        self.batch_size = batch_size\n        self.only_cpu = only_cpu\n        self.only_gpu = only_gpu\n        self.mixed_train = mix\n        self.neg_weight = neg_weight\n        self.negative = negative\n        self.lr = lr\n        self.lap_norm = lap_norm\n        self.fast_neg = fast_neg\n        self.record_loss = record_loss\n        self.norm = norm\n        self.use_context_weight = use_context_weight\n        self.async_update = async_update\n        self.num_threads = num_threads\n\n        # initialize the device as cpu\n        self.device = torch.device(\"cpu\")\n\n        # content embedding\n        self.u_embeddings = nn.Embedding(\n            self.emb_size, self.emb_dimension, sparse=True\n        )\n        # context embedding\n        self.v_embeddings = nn.Embedding(\n            self.emb_size, self.emb_dimension, sparse=True\n        )\n        # initialze embedding\n        initrange = 1.0 / self.emb_dimension\n        init.uniform_(self.u_embeddings.weight.data, -initrange, initrange)\n        init.constant_(self.v_embeddings.weight.data, 0)\n\n        # lookup_table is used for fast sigmoid computing\n        self.lookup_table = torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))\n        self.lookup_table[0] = 0.0\n        self.lookup_table[-1] = 1.0\n        if self.record_loss:\n            self.logsigmoid_table = torch.log(\n                torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))\n            )\n            self.loss = []\n\n        # indexes to select positive/negative node pairs from batch_walks\n        self.index_emb_posu, self.index_emb_posv = init_emb2pos_index(\n            
self.walk_length, self.window_size, self.batch_size\n        )\n        self.index_emb_negu, self.index_emb_negv = init_emb2neg_index(\n            self.walk_length, self.window_size, self.negative, self.batch_size\n        )\n\n        if self.use_context_weight:\n            self.context_weight = init_weight(\n                self.walk_length, self.window_size, self.batch_size\n            )\n\n        # adam\n        self.state_sum_u = torch.zeros(self.emb_size)\n        self.state_sum_v = torch.zeros(self.emb_size)\n\n        # gradients of nodes in batch_walks\n        self.grad_u, self.grad_v = init_empty_grad(\n            self.emb_dimension, self.walk_length, self.batch_size\n        )\n\n    def create_async_update(self):\n        \"\"\"Set up the async update subprocess.\"\"\"\n        self.async_q = Queue(1)\n        self.async_p = mp.Process(\n            target=async_update, args=(self.num_threads, self, self.async_q)\n        )\n        self.async_p.start()\n\n    def finish_async_update(self):\n        \"\"\"Notify the async update subprocess to quit.\"\"\"\n        self.async_q.put((None, None, None, None, None))\n        self.async_p.join()\n\n    def share_memory(self):\n        \"\"\"share the parameters across subprocesses\"\"\"\n        self.u_embeddings.weight.share_memory_()\n        self.v_embeddings.weight.share_memory_()\n        self.state_sum_u.share_memory_()\n        self.state_sum_v.share_memory_()\n\n    def set_device(self, gpu_id):\n        \"\"\"set gpu device\"\"\"\n        self.device = torch.device(\"cuda:%d\" % gpu_id)\n        print(\"The device is\", self.device)\n        self.lookup_table = self.lookup_table.to(self.device)\n        if self.record_loss:\n            self.logsigmoid_table = self.logsigmoid_table.to(self.device)\n        self.index_emb_posu = self.index_emb_posu.to(self.device)\n        self.index_emb_posv = self.index_emb_posv.to(self.device)\n        self.index_emb_negu = 
self.index_emb_negu.to(self.device)\n        self.index_emb_negv = self.index_emb_negv.to(self.device)\n        self.grad_u = self.grad_u.to(self.device)\n        self.grad_v = self.grad_v.to(self.device)\n        if self.use_context_weight:\n            self.context_weight = self.context_weight.to(self.device)\n\n    def all_to_device(self, gpu_id):\n        \"\"\"move all of the parameters to a single GPU\"\"\"\n        self.device = torch.device(\"cuda:%d\" % gpu_id)\n        self.set_device(gpu_id)\n        self.u_embeddings = self.u_embeddings.cuda(gpu_id)\n        self.v_embeddings = self.v_embeddings.cuda(gpu_id)\n        self.state_sum_u = self.state_sum_u.to(self.device)\n        self.state_sum_v = self.state_sum_v.to(self.device)\n\n    def fast_sigmoid(self, score):\n        \"\"\"do fast sigmoid by looking up in a pre-defined table\"\"\"\n        idx = torch.floor((score + 6.01) / 0.01).long()\n        return self.lookup_table[idx]\n\n    def fast_logsigmoid(self, score):\n        \"\"\"do fast logsigmoid by looking up in a pre-defined table\"\"\"\n        idx = torch.floor((score + 6.01) / 0.01).long()\n        return self.logsigmoid_table[idx]\n\n    def fast_learn(self, batch_walks, neg_nodes=None):\n        \"\"\"Learn a batch of random walks in a fast way. It has the following features:\n            1. It calculating the gradients directly without the forward operation.\n            2. 
It does sigmoid by a looking up table.\n\n        Specifically, for each positive/negative node pair (i,j), the updating procedure is as following:\n            score = self.fast_sigmoid(u_embedding[i].dot(v_embedding[j]))\n            # label = 1 for positive samples; label = 0 for negative samples.\n            u_embedding[i] += (label - score) * v_embedding[j]\n            v_embedding[i] += (label - score) * u_embedding[j]\n\n        Parameters\n        ----------\n        batch_walks list : a list of node sequnces\n        lr float : current learning rate\n        neg_nodes torch.LongTensor : a long tensor of sampled true negative nodes. If neg_nodes is None,\n            then do negative sampling randomly from the nodes in batch_walks as an alternative.\n\n        Usage example\n        -------------\n        batch_walks = [torch.LongTensor([1,2,3,4]),\n                       torch.LongTensor([2,3,4,2])])\n        lr = 0.01\n        neg_nodes = None\n        \"\"\"\n        lr = self.lr\n\n        # [batch_size, walk_length]\n        if isinstance(batch_walks, list):\n            nodes = torch.stack(batch_walks)\n        elif isinstance(batch_walks, torch.LongTensor):\n            nodes = batch_walks\n        if self.only_gpu:\n            nodes = nodes.to(self.device)\n            if neg_nodes is not None:\n                neg_nodes = neg_nodes.to(self.device)\n        emb_u = (\n            self.u_embeddings(nodes)\n            .view(-1, self.emb_dimension)\n            .to(self.device)\n        )\n        emb_v = (\n            self.v_embeddings(nodes)\n            .view(-1, self.emb_dimension)\n            .to(self.device)\n        )\n\n        ## Postive\n        bs = len(batch_walks)\n        if bs < self.batch_size:\n            index_emb_posu, index_emb_posv = init_emb2pos_index(\n                self.walk_length, self.window_size, bs\n            )\n            index_emb_posu = index_emb_posu.to(self.device)\n            index_emb_posv = 
index_emb_posv.to(self.device)\n        else:\n            index_emb_posu = self.index_emb_posu\n            index_emb_posv = self.index_emb_posv\n\n        # num_pos: the number of positive node pairs generated by a single walk sequence\n        # [batch_size * num_pos, dim]\n        emb_pos_u = torch.index_select(emb_u, 0, index_emb_posu)\n        emb_pos_v = torch.index_select(emb_v, 0, index_emb_posv)\n\n        pos_score = torch.sum(torch.mul(emb_pos_u, emb_pos_v), dim=1)\n        pos_score = torch.clamp(pos_score, max=6, min=-6)\n        # [batch_size * num_pos, 1]\n        score = (1 - self.fast_sigmoid(pos_score)).unsqueeze(1)\n        if self.record_loss:\n            self.loss.append(torch.mean(self.fast_logsigmoid(pos_score)).item())\n\n        # [batch_size * num_pos, dim]\n        if self.lap_norm > 0:\n            grad_u_pos = score * emb_pos_v + self.lap_norm * (\n                emb_pos_v - emb_pos_u\n            )\n            grad_v_pos = score * emb_pos_u + self.lap_norm * (\n                emb_pos_u - emb_pos_v\n            )\n        else:\n            grad_u_pos = score * emb_pos_v\n            grad_v_pos = score * emb_pos_u\n\n        if self.use_context_weight:\n            if bs < self.batch_size:\n                context_weight = init_weight(\n                    self.walk_length, self.window_size, bs\n                ).to(self.device)\n            else:\n                context_weight = self.context_weight\n            grad_u_pos *= context_weight\n            grad_v_pos *= context_weight\n\n        # [batch_size * walk_length, dim]\n        if bs < self.batch_size:\n            grad_u, grad_v = init_empty_grad(\n                self.emb_dimension, self.walk_length, bs\n            )\n            grad_u = grad_u.to(self.device)\n            grad_v = grad_v.to(self.device)\n        else:\n            self.grad_u = self.grad_u.to(self.device)\n            self.grad_u.zero_()\n            self.grad_v = self.grad_v.to(self.device)\n          
  self.grad_v.zero_()\n            grad_u = self.grad_u\n            grad_v = self.grad_v\n        grad_u.index_add_(0, index_emb_posu, grad_u_pos)\n        grad_v.index_add_(0, index_emb_posv, grad_v_pos)\n\n        ## Negative\n        if bs < self.batch_size:\n            index_emb_negu, index_emb_negv = init_emb2neg_index(\n                self.walk_length, self.window_size, self.negative, bs\n            )\n            index_emb_negu = index_emb_negu.to(self.device)\n            index_emb_negv = index_emb_negv.to(self.device)\n        else:\n            index_emb_negu = self.index_emb_negu\n            index_emb_negv = self.index_emb_negv\n        emb_neg_u = torch.index_select(emb_u, 0, index_emb_negu)\n\n        if neg_nodes is None:\n            emb_neg_v = torch.index_select(emb_v, 0, index_emb_negv)\n        else:\n            emb_neg_v = self.v_embeddings.weight[neg_nodes].to(self.device)\n\n        # [batch_size * walk_length * negative, dim]\n        neg_score = torch.sum(torch.mul(emb_neg_u, emb_neg_v), dim=1)\n        neg_score = torch.clamp(neg_score, max=6, min=-6)\n        # [batch_size * walk_length * negative, 1]\n        score = -self.fast_sigmoid(neg_score).unsqueeze(1)\n        if self.record_loss:\n            self.loss.append(\n                self.negative\n                * self.neg_weight\n                * torch.mean(self.fast_logsigmoid(-neg_score)).item()\n            )\n\n        grad_u_neg = self.neg_weight * score * emb_neg_v\n        grad_v_neg = self.neg_weight * score * emb_neg_u\n\n        grad_u.index_add_(0, index_emb_negu, grad_u_neg)\n        if neg_nodes is None:\n            grad_v.index_add_(0, index_emb_negv, grad_v_neg)\n\n        ## Update\n        nodes = nodes.view(-1)\n\n        # use adam optimizer\n        grad_u = adam(\n            grad_u, self.state_sum_u, nodes, lr, self.device, self.only_gpu\n        )\n        grad_v = adam(\n            grad_v, self.state_sum_v, nodes, lr, self.device, self.only_gpu\n      
  )\n        if neg_nodes is not None:\n            grad_v_neg = adam(\n                grad_v_neg,\n                self.state_sum_v,\n                neg_nodes,\n                lr,\n                self.device,\n                self.only_gpu,\n            )\n\n        if self.mixed_train:\n            grad_u = grad_u.cpu()\n            grad_v = grad_v.cpu()\n            if neg_nodes is not None:\n                grad_v_neg = grad_v_neg.cpu()\n            else:\n                grad_v_neg = None\n\n            if self.async_update:\n                grad_u.share_memory_()\n                grad_v.share_memory_()\n                nodes.share_memory_()\n                if neg_nodes is not None:\n                    neg_nodes.share_memory_()\n                    grad_v_neg.share_memory_()\n                self.async_q.put((grad_u, grad_v, grad_v_neg, nodes, neg_nodes))\n\n        if not self.async_update:\n            self.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)\n            self.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)\n            if neg_nodes is not None:\n                self.v_embeddings.weight.data.index_add_(\n                    0, neg_nodes.view(-1), grad_v_neg\n                )\n        return\n\n    def forward(self, pos_u, pos_v, neg_v):\n        \"\"\"Do forward and backward. 
It is designed for future use.\"\"\"\n        emb_u = self.u_embeddings(pos_u)\n        emb_v = self.v_embeddings(pos_v)\n        emb_neg_v = self.v_embeddings(neg_v)\n\n        score = torch.sum(torch.mul(emb_u, emb_v), dim=1)\n        score = torch.clamp(score, max=6, min=-6)\n        score = -F.logsigmoid(score)\n\n        neg_score = torch.bmm(emb_neg_v, emb_u.unsqueeze(2)).squeeze()\n        neg_score = torch.clamp(neg_score, max=6, min=-6)\n        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)\n\n        # return torch.mean(score + neg_score)\n        return torch.sum(score), torch.sum(neg_score)\n\n    def save_embedding(self, dataset, file_name):\n        \"\"\"Write embedding to local file. Only used when node ids are numbers.\n\n        Parameter\n        ---------\n        dataset DeepwalkDataset : the dataset\n        file_name str : the file name\n        \"\"\"\n        embedding = self.u_embeddings.weight.cpu().data.numpy()\n        if self.norm:\n            embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(\n                -1, 1\n            )\n        np.save(file_name, embedding)\n\n    def save_embedding_pt(self, dataset, file_name):\n        \"\"\"For ogb leaderboard.\"\"\"\n        try:\n            max_node_id = max(dataset.node2id.keys())\n            if max_node_id + 1 != self.emb_size:\n                print(\"WARNING: The node ids are not serial.\")\n\n            embedding = torch.zeros(max_node_id + 1, self.emb_dimension)\n            index = torch.LongTensor(\n                list(\n                    map(\n                        lambda id: dataset.id2node[id],\n                        list(range(self.emb_size)),\n                    )\n                )\n            )\n            embedding.index_add_(0, index, self.u_embeddings.weight.cpu().data)\n\n            if self.norm:\n                embedding /= torch.sqrt(\n                    torch.sum(embedding.mul(embedding), 1) + 1e-6\n                
).unsqueeze(1)\n            torch.save(embedding, file_name)\n        except:\n            self.save_embedding_pt_dgl_graph(dataset, file_name)\n\n    def save_embedding_pt_dgl_graph(self, dataset, file_name):\n        \"\"\"For ogb leaderboard\"\"\"\n        embedding = torch.zeros_like(self.u_embeddings.weight.cpu().data)\n        valid_seeds = torch.LongTensor(dataset.valid_seeds)\n        valid_embedding = self.u_embeddings.weight.cpu().data.index_select(\n            0, valid_seeds\n        )\n        embedding.index_add_(0, valid_seeds, valid_embedding)\n\n        if self.norm:\n            embedding /= torch.sqrt(\n                torch.sum(embedding.mul(embedding), 1) + 1e-6\n            ).unsqueeze(1)\n\n        torch.save(embedding, file_name)\n\n    def save_embedding_txt(self, dataset, file_name):\n        \"\"\"Write embedding to local file. For future use.\n\n        Parameter\n        ---------\n        dataset DeepwalkDataset : the dataset\n        file_name str : the file name\n        \"\"\"\n        embedding = self.u_embeddings.weight.cpu().data.numpy()\n        if self.norm:\n            embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(\n                -1, 1\n            )\n        with open(file_name, \"w\") as f:\n            f.write(\"%d %d\\n\" % (self.emb_size, self.emb_dimension))\n            for wid in range(self.emb_size):\n                e = \" \".join(map(lambda x: str(x), embedding[wid]))\n                f.write(\"%s %s\\n\" % (str(dataset.id2node[wid]), e))\n"
  },
  {
    "path": "examples/pytorch/ogb/deepwalk/reading_data.py",
    "content": "import os\nimport pickle\nimport random\nimport time\n\nimport dgl\n\nimport numpy as np\nimport scipy.sparse as sp\nimport torch\nfrom dgl.data.utils import (\n    _get_dgl_url,\n    download,\n    extract_archive,\n    get_download_dir,\n)\nfrom torch.utils.data import DataLoader\nfrom utils import shuffle_walks\n\n\ndef ReadTxtNet(file_path=\"\", undirected=True):\n    \"\"\"Read the txt network file.\n    Notations: The network is unweighted.\n\n    Parameters\n    ----------\n    file_path str : path of network file\n    undirected bool : whether the edges are undirected\n\n    Return\n    ------\n    net dict : a dict recording the connections in the graph\n    node2id dict : a dict mapping the nodes to their embedding indices\n    id2node dict : a dict mapping nodes embedding indices to the nodes\n    \"\"\"\n    if file_path == \"youtube\" or file_path == \"blog\":\n        name = file_path\n        dir = get_download_dir()\n        zip_file_path = \"{}/{}.zip\".format(dir, name)\n        download(\n            _get_dgl_url(\n                os.path.join(\"dataset/DeepWalk/\", \"{}.zip\".format(file_path))\n            ),\n            path=zip_file_path,\n        )\n        extract_archive(zip_file_path, \"{}/{}\".format(dir, name))\n        file_path = \"{}/{}/{}-net.txt\".format(dir, name, name)\n\n    node2id = {}\n    id2node = {}\n    cid = 0\n\n    src = []\n    dst = []\n    weight = []\n    net = {}\n    with open(file_path, \"r\") as f:\n        for line in f.readlines():\n            tup = list(map(int, line.strip().split(\" \")))\n            assert len(tup) in [\n                2,\n                3,\n            ], \"The format of network file is unrecognizable.\"\n            if len(tup) == 3:\n                n1, n2, w = tup\n            elif len(tup) == 2:\n                n1, n2 = tup\n                w = 1\n            if n1 not in node2id:\n                node2id[n1] = cid\n                id2node[cid] = n1\n             
   cid += 1\n            if n2 not in node2id:\n                node2id[n2] = cid\n                id2node[cid] = n2\n                cid += 1\n\n            n1 = node2id[n1]\n            n2 = node2id[n2]\n            if n1 not in net:\n                net[n1] = {n2: w}\n                src.append(n1)\n                dst.append(n2)\n                weight.append(w)\n            elif n2 not in net[n1]:\n                net[n1][n2] = w\n                src.append(n1)\n                dst.append(n2)\n                weight.append(w)\n\n            if undirected:\n                if n2 not in net:\n                    net[n2] = {n1: w}\n                    src.append(n2)\n                    dst.append(n1)\n                    weight.append(w)\n                elif n1 not in net[n2]:\n                    net[n2][n1] = w\n                    src.append(n2)\n                    dst.append(n1)\n                    weight.append(w)\n\n    print(\"node num: %d\" % len(net))\n    print(\"edge num: %d\" % len(src))\n    assert max(net.keys()) == len(net) - 1, \"error reading net, quit\"\n\n    sm = sp.coo_matrix((np.array(weight), (src, dst)), dtype=np.float32)\n\n    return net, node2id, id2node, sm\n\n\ndef net2graph(net_sm):\n    \"\"\"Transform the network to DGL graph\n\n    Return\n    ------\n    G DGLGraph : graph by DGL\n    \"\"\"\n    start = time.time()\n    G = dgl.from_scipy(net_sm)\n    end = time.time()\n    t = end - start\n    print(\"Building DGLGraph in %.2fs\" % t)\n    return G\n\n\ndef make_undirected(G):\n    G.add_edges(G.edges()[1], G.edges()[0])\n    return G\n\n\ndef find_connected_nodes(G):\n    nodes = G.out_degrees().nonzero().squeeze(-1)\n    return nodes\n\n\nclass DeepwalkDataset:\n    def __init__(\n        self,\n        net_file,\n        map_file,\n        walk_length,\n        window_size,\n        num_walks,\n        batch_size,\n        negative=5,\n        gpus=[0],\n        fast_neg=True,\n        ogbl_name=\"\",\n        
load_from_ogbl=False,\n    ):\n        \"\"\"This class has the following functions:\n        1. Transform the txt network file into DGL graph;\n        2. Generate random walk sequences for the trainer;\n        3. Provide the negative table if the user hopes to sample negative\n        nodes according to nodes' degrees;\n\n        Parameter\n        ---------\n        net_file str : path of the txt network file\n        walk_length int : number of nodes in a sequence\n        window_size int : context window size\n        num_walks int : number of walks for each node\n        batch_size int : number of node sequences in each batch\n        negative int : negative samples for each positve node pair\n        fast_neg bool : whether do negative sampling inside a batch\n        \"\"\"\n        self.walk_length = walk_length\n        self.window_size = window_size\n        self.num_walks = num_walks\n        self.batch_size = batch_size\n        self.negative = negative\n        self.num_procs = len(gpus)\n        self.fast_neg = fast_neg\n\n        if load_from_ogbl:\n            assert (\n                len(gpus) == 1\n            ), \"ogb.linkproppred is not compatible with multi-gpu training (CUDA error).\"\n            from load_dataset import load_from_ogbl_with_name\n\n            self.G = load_from_ogbl_with_name(ogbl_name)\n            self.G = make_undirected(self.G)\n        else:\n            self.net, self.node2id, self.id2node, self.sm = ReadTxtNet(net_file)\n            self.save_mapping(map_file)\n            self.G = net2graph(self.sm)\n\n        self.num_nodes = self.G.num_nodes()\n\n        # random walk seeds\n        start = time.time()\n        self.valid_seeds = find_connected_nodes(self.G)\n        if len(self.valid_seeds) != self.num_nodes:\n            print(\n                \"WARNING: The node ids are not serial. 
Some nodes are invalid.\"\n            )\n\n        seeds = torch.cat([torch.LongTensor(self.valid_seeds)] * num_walks)\n        self.seeds = torch.split(\n            shuffle_walks(seeds),\n            int(\n                np.ceil(len(self.valid_seeds) * self.num_walks / self.num_procs)\n            ),\n            0,\n        )\n        end = time.time()\n        t = end - start\n        print(\"%d seeds in %.2fs\" % (len(seeds), t))\n\n        # negative table for true negative sampling\n        if not fast_neg:\n            node_degree = self.G.out_degrees(self.valid_seeds).numpy()\n            node_degree = np.power(node_degree, 0.75)\n            node_degree /= np.sum(node_degree)\n            node_degree = np.array(node_degree * 1e8, dtype=int)\n            self.neg_table = []\n\n            for idx, node in enumerate(self.valid_seeds):\n                self.neg_table += [node] * node_degree[idx]\n            self.neg_table_size = len(self.neg_table)\n            self.neg_table = np.array(self.neg_table, dtype=int)\n            del node_degree\n\n    def create_sampler(self, i):\n        \"\"\"create random walk sampler\"\"\"\n        return DeepwalkSampler(self.G, self.seeds[i], self.walk_length)\n\n    def save_mapping(self, map_file):\n        \"\"\"save the mapping dict that maps node IDs to embedding indices\"\"\"\n        with open(map_file, \"wb\") as f:\n            pickle.dump(self.node2id, f)\n\n\nclass DeepwalkSampler(object):\n    def __init__(self, G, seeds, walk_length):\n        \"\"\"random walk sampler\n\n        Parameter\n        ---------\n        G dgl.Graph : the input graph\n        seeds torch.LongTensor : starting nodes\n        walk_length int : walk length\n        \"\"\"\n        self.G = G\n        self.seeds = seeds\n        self.walk_length = walk_length\n\n    def sample(self, seeds):\n        walks = dgl.sampling.random_walk(\n            self.G, seeds, length=self.walk_length - 1\n        )[0]\n        return walks\n"
  },
  {
    "path": "examples/pytorch/ogb/deepwalk/utils.py",
    "content": "import torch\n\n\ndef shuffle_walks(walks):\n    seeds = torch.randperm(walks.size()[0])\n    return walks[seeds]\n\n\ndef sum_up_params(model):\n    \"\"\"Count the model parameters\"\"\"\n    n = []\n    n.append(model.u_embeddings.weight.cpu().data.numel() * 2)\n    n.append(model.lookup_table.cpu().numel())\n    n.append(model.index_emb_posu.cpu().numel() * 2)\n    n.append(model.grad_u.cpu().numel() * 2)\n\n    try:\n        n.append(model.index_emb_negu.cpu().numel() * 2)\n    except:\n        pass\n    try:\n        n.append(model.state_sum_u.cpu().numel() * 2)\n    except:\n        pass\n    try:\n        n.append(model.grad_avg.cpu().numel())\n    except:\n        pass\n    try:\n        n.append(model.context_weight.cpu().numel())\n    except:\n        pass\n\n    print(\"#params \" + str(sum(n)))\n    exit()\n"
  },
  {
    "path": "examples/pytorch/ogb/directional_GSN/README.md",
    "content": "# directional_GSN\n\n## Introduction\n\nThis is an example of implementing [directional_GSN](https://arxiv.org/abs/2006.09252) for graph classification in DGL.\n\ndirectional_GSN is a combination of Graph Substructure Networks ([GSN](https://arxiv.org/abs/2006.09252)) with Directional Graph Networks ([DGN](https://arxiv.org/pdf/2010.02863.pdf)), where the vector field is defined from substructure encodings instead of Laplacian eigenvectors.\n\nThe scripts in this folder evaluate directional_GSN on the ogbg-molpcba dataset.\n\n## Installation requirements\n```\nconda create --name gsn python=3.7\nconda activate gsn\nconda install pytorch==1.11.0 cudatoolkit=10.2 -c pytorch\npip install tqdm\npip install networkx\nconda install -c conda-forge graph-tool\npip install ogb\npip install dgl-cu102 -f https://data.dgl.ai/wheels/repo.html\n```\n\n## Experiments\n\nWe fix the random seed to 41, and train the model on a single Tesla T4 GPU with 16GB memory.\n\n### ogbg-molpcba\n\n#### performance\n\n|                  | train_AP | valid_AP | test_AP | #parameters |\n| ---------------- | ---------| -------- | ------- | ----------- |\n| directional_GSN  | 0.4301   | 0.2598   | 0.2438  | 5142713     |\n\n\n#### Reproduction of performance\n\n```{.bash}\npython preprocessing.py\npython main.py --seed 41\n```\n\n`main.py` only accepts `--gpu_id`, `--seed` and `--batch_size`. The remaining settings used for the reported result (450 epochs, hidden and output dimension 420, dropout 0.2) are fixed inside `main.py`.\n\n## References\n\n```{.tex}\n@article{bouritsas2020improving,\n  title={Improving graph neural network expressivity via subgraph isomorphism counting},\n  author={Bouritsas, Giorgos and \n          Frasca, Fabrizio and \n          Zafeiriou, Stefanos and \n          Bronstein, Michael M},\n  journal={arXiv preprint arXiv:2006.09252},\n  year={2020}\n}\n```"
  },
  {
    "path": "examples/pytorch/ogb/directional_GSN/main.py",
    "content": "import argparse\nimport random\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dgl.dataloading import GraphDataLoader\nfrom ogb.graphproppred import Evaluator\nfrom ogb.graphproppred.mol_encoder import AtomEncoder\nfrom preprocessing import prepare_dataset\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\n\n\ndef aggregate_mean(h, vector_field, h_in):\n    return torch.mean(h, dim=1)\n\n\ndef aggregate_max(h, vector_field, h_in):\n    return torch.max(h, dim=1)[0]\n\n\ndef aggregate_sum(h, vector_field, h_in):\n    return torch.sum(h, dim=1)\n\n\ndef aggregate_dir_dx(h, vector_field, h_in, eig_idx=1):\n    eig_w = (\n        (vector_field[:, :, eig_idx])\n        / (\n            torch.sum(\n                torch.abs(vector_field[:, :, eig_idx]), keepdim=True, dim=1\n            )\n            + 1e-8\n        )\n    ).unsqueeze(-1)\n    h_mod = torch.mul(h, eig_w)\n    return torch.abs(torch.sum(h_mod, dim=1) - torch.sum(eig_w, dim=1) * h_in)\n\n\nclass FCLayer(nn.Module):\n    def __init__(self, in_size, out_size):\n        super(FCLayer, self).__init__()\n\n        self.in_size = in_size\n        self.out_size = out_size\n        self.linear = nn.Linear(in_size, out_size, bias=True)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.xavier_uniform_(self.linear.weight, 1 / self.in_size)\n        self.linear.bias.data.zero_()\n\n    def forward(self, x):\n        h = self.linear(x)\n        return h\n\n\nclass MLP(nn.Module):\n    def __init__(self, in_size, out_size):\n        super(MLP, self).__init__()\n\n        self.in_size = in_size\n        self.out_size = out_size\n        self.fc = FCLayer(in_size, out_size)\n\n    def forward(self, x):\n        x = self.fc(x)\n        return x\n\n\nclass DGNLayer(nn.Module):\n    def __init__(self, in_dim, out_dim, dropout, aggregators):\n        
super().__init__()\n\n        self.dropout = dropout\n\n        self.aggregators = aggregators\n\n        self.batchnorm_h = nn.BatchNorm1d(out_dim)\n        self.pretrans = MLP(in_size=2 * in_dim, out_size=in_dim)\n        self.posttrans = MLP(\n            in_size=(len(aggregators) * 1 + 1) * in_dim, out_size=out_dim\n        )\n\n    def pretrans_edges(self, edges):\n        z2 = torch.cat([edges.src[\"h\"], edges.dst[\"h\"]], dim=1)\n        vector_field = edges.data[\"eig\"]\n        return {\"e\": self.pretrans(z2), \"vector_field\": vector_field}\n\n    def message_func(self, edges):\n        return {\n            \"e\": edges.data[\"e\"],\n            \"vector_field\": edges.data[\"vector_field\"],\n        }\n\n    def reduce_func(self, nodes):\n        h_in = nodes.data[\"h\"]\n        h = nodes.mailbox[\"e\"]\n\n        vector_field = nodes.mailbox[\"vector_field\"]\n\n        h = torch.cat(\n            [\n                aggregate(h, vector_field, h_in)\n                for aggregate in self.aggregators\n            ],\n            dim=1,\n        )\n\n        return {\"h\": h}\n\n    def forward(self, g, h, snorm_n):\n        g.ndata[\"h\"] = h\n\n        # pretransformation\n        g.apply_edges(self.pretrans_edges)\n\n        # aggregation\n        g.update_all(self.message_func, self.reduce_func)\n        h = torch.cat([h, g.ndata[\"h\"]], dim=1)\n\n        # posttransformation\n        h = self.posttrans(h)\n\n        # graph and batch normalization\n        h = h * snorm_n\n        h = self.batchnorm_h(h)\n        h = F.relu(h)\n\n        h = F.dropout(h, self.dropout, training=self.training)\n\n        return h\n\n\nclass MLPReadout(nn.Module):\n    def __init__(self, input_dim, output_dim, L=2):  # L=nb_hidden_layers\n        super().__init__()\n        list_FC_layers = [\n            nn.Linear(input_dim // 2**l, input_dim // 2 ** (l + 1), bias=True)\n            for l in range(L)\n        ]\n        list_FC_layers.append(\n            
nn.Linear(input_dim // 2**L, output_dim, bias=True)\n        )\n        self.FC_layers = nn.ModuleList(list_FC_layers)\n        self.L = L\n\n    def forward(self, x):\n        y = x\n        for l in range(self.L):\n            y = self.FC_layers[l](y)\n            y = F.relu(y)\n        y = self.FC_layers[self.L](y)\n        return y\n\n\nclass DGNNet(nn.Module):\n    def __init__(self, hidden_dim=420, out_dim=420, dropout=0.2, n_layers=4):\n        super().__init__()\n\n        self.embedding_h = AtomEncoder(emb_dim=hidden_dim)\n        self.aggregators = [\n            aggregate_mean,\n            aggregate_sum,\n            aggregate_max,\n            aggregate_dir_dx,\n        ]\n\n        self.layers = nn.ModuleList(\n            [\n                DGNLayer(\n                    in_dim=hidden_dim,\n                    out_dim=hidden_dim,\n                    dropout=dropout,\n                    aggregators=self.aggregators,\n                )\n                for _ in range(n_layers - 1)\n            ]\n        )\n        self.layers.append(\n            DGNLayer(\n                in_dim=hidden_dim,\n                out_dim=out_dim,\n                dropout=dropout,\n                aggregators=self.aggregators,\n            )\n        )\n\n        # 128 out dim since ogbg-molpcba has 128 tasks\n        self.MLP_layer = MLPReadout(out_dim, 128)\n\n    def forward(self, g, h, snorm_n):\n        h = self.embedding_h(h)\n\n        for i, conv in enumerate(self.layers):\n            h_t = conv(g, h, snorm_n)\n            h = h_t\n\n        g.ndata[\"h\"] = h\n\n        hg = dgl.mean_nodes(g, \"h\")\n\n        return self.MLP_layer(hg)\n\n    def loss(self, scores, labels):\n        is_labeled = labels == labels\n        loss = nn.BCEWithLogitsLoss()(\n            scores[is_labeled], labels[is_labeled].float()\n        )\n        return loss\n\n\ndef train_epoch(model, optimizer, device, data_loader):\n    model.train()\n    epoch_loss = 0\n    epoch_train_AP = 
0\n    list_scores = []\n    list_labels = []\n    for iter, (batch_graphs, batch_labels, batch_snorm_n) in enumerate(\n        data_loader\n    ):\n        batch_graphs = batch_graphs.to(device)\n        batch_x = batch_graphs.ndata[\"feat\"]  # num x feat\n        batch_snorm_n = batch_snorm_n.to(device)\n        batch_labels = batch_labels.to(device)\n        optimizer.zero_grad()\n\n        batch_scores = model(batch_graphs, batch_x, batch_snorm_n)\n\n        loss = model.loss(batch_scores, batch_labels)\n        loss.backward()\n        optimizer.step()\n        epoch_loss += loss.item()\n        list_scores.append(batch_scores)\n        list_labels.append(batch_labels)\n\n    epoch_loss /= iter + 1\n\n    evaluator = Evaluator(name=\"ogbg-molpcba\")\n    epoch_train_AP = evaluator.eval(\n        {\"y_pred\": torch.cat(list_scores), \"y_true\": torch.cat(list_labels)}\n    )[\"ap\"]\n\n    return epoch_loss, epoch_train_AP\n\n\ndef evaluate_network(model, device, data_loader):\n    model.eval()\n    epoch_test_loss = 0\n    epoch_test_AP = 0\n    with torch.no_grad():\n        list_scores = []\n        list_labels = []\n        for iter, (batch_graphs, batch_labels, batch_snorm_n) in enumerate(\n            data_loader\n        ):\n            batch_graphs = batch_graphs.to(device)\n            batch_x = batch_graphs.ndata[\"feat\"]\n            batch_snorm_n = batch_snorm_n.to(device)\n            batch_labels = batch_labels.to(device)\n\n            batch_scores = model(batch_graphs, batch_x, batch_snorm_n)\n\n            loss = model.loss(batch_scores, batch_labels)\n            epoch_test_loss += loss.item()\n            list_scores.append(batch_scores)\n            list_labels.append(batch_labels)\n\n        epoch_test_loss /= iter + 1\n\n        evaluator = Evaluator(name=\"ogbg-molpcba\")\n        epoch_test_AP = evaluator.eval(\n            {\"y_pred\": torch.cat(list_scores), \"y_true\": torch.cat(list_labels)}\n        )[\"ap\"]\n\n    return 
epoch_test_loss, epoch_test_AP\n\n\ndef train(dataset, params):\n    trainset, valset, testset = dataset.train, dataset.val, dataset.test\n    device = params.device\n\n    print(\"Training Graphs: \", len(trainset))\n    print(\"Validation Graphs: \", len(valset))\n    print(\"Test Graphs: \", len(testset))\n\n    model = DGNNet()\n    model = model.to(device)\n\n    # view model parameters\n    total_param = 0\n    print(\"MODEL DETAILS:\\n\")\n    for param in model.parameters():\n        total_param += np.prod(list(param.data.size()))\n    print(\"DGN Total parameters:\", total_param)\n\n    optimizer = optim.Adam(model.parameters(), lr=0.0008, weight_decay=1e-5)\n    scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n        optimizer, mode=\"min\", factor=0.8, patience=8\n    )\n\n    epoch_train_losses, epoch_val_losses = [], []\n    epoch_train_APs, epoch_val_APs, epoch_test_APs = [], [], []\n\n    train_loader = GraphDataLoader(\n        trainset,\n        batch_size=params.batch_size,\n        shuffle=True,\n        collate_fn=dataset.collate,\n        pin_memory=True,\n    )\n    val_loader = GraphDataLoader(\n        valset,\n        batch_size=params.batch_size,\n        shuffle=False,\n        collate_fn=dataset.collate,\n        pin_memory=True,\n    )\n    test_loader = GraphDataLoader(\n        testset,\n        batch_size=params.batch_size,\n        shuffle=False,\n        collate_fn=dataset.collate,\n        pin_memory=True,\n    )\n\n    with tqdm(range(450), unit=\"epoch\") as t:\n        for epoch in t:\n            t.set_description(\"Epoch %d\" % epoch)\n\n            epoch_train_loss, epoch_train_ap = train_epoch(\n                model, optimizer, device, train_loader\n            )\n            epoch_val_loss, epoch_val_ap = evaluate_network(\n                model, device, val_loader\n            )\n\n            epoch_train_losses.append(epoch_train_loss)\n            epoch_val_losses.append(epoch_val_loss)\n            
epoch_train_APs.append(epoch_train_ap.item())\n            epoch_val_APs.append(epoch_val_ap.item())\n\n            _, epoch_test_ap = evaluate_network(model, device, test_loader)\n\n            epoch_test_APs.append(epoch_test_ap.item())\n\n            t.set_postfix(\n                train_loss=epoch_train_loss,\n                train_AP=epoch_train_ap.item(),\n                val_AP=epoch_val_ap.item(),\n                refresh=False,\n            )\n\n            scheduler.step(-epoch_val_ap.item())\n\n            if optimizer.param_groups[0][\"lr\"] < 1e-5:\n                print(\"\\n!! LR EQUAL TO MIN LR SET.\")\n                break\n\n            print(\"\")\n\n    best_val_epoch = np.argmax(np.array(epoch_val_APs))\n    best_train_epoch = np.argmax(np.array(epoch_train_APs))\n    best_val_ap = epoch_val_APs[best_val_epoch]\n    best_val_test_ap = epoch_test_APs[best_val_epoch]\n    best_val_train_ap = epoch_train_APs[best_val_epoch]\n    best_train_ap = epoch_train_APs[best_train_epoch]\n\n    print(\"Best Train AP: {:.4f}\".format(best_train_ap))\n    print(\"Best Val AP: {:.4f}\".format(best_val_ap))\n    print(\"Test AP of Best Val: {:.4f}\".format(best_val_test_ap))\n    print(\"Train AP of Best Val: {:.4f}\".format(best_val_train_ap))\n\n\nclass Subset(object):\n    def __init__(self, dataset, labels, indices):\n        dataset = [dataset[idx] for idx in indices]\n        labels = [labels[idx] for idx in indices]\n        self.dataset, self.labels = [], []\n        for i, g in enumerate(dataset):\n            if g.num_nodes() > 5:\n                self.dataset.append(g)\n                self.labels.append(labels[i])\n        self.len = len(self.dataset)\n\n    def __getitem__(self, item):\n        return self.dataset[item], self.labels[item]\n\n    def __len__(self):\n        return self.len\n\n\nclass PCBADataset(Dataset):\n    def __init__(self, name):\n        print(\"[I] Loading dataset %s...\" % (name))\n        self.name = name\n\n        
self.dataset, self.split_idx = prepare_dataset(name)\n        print(\"One hot encoding substructure counts... \", end=\"\")\n        self.d_id = [1] * self.dataset[0].edata[\"subgraph_counts\"].shape[1]\n\n        for g in self.dataset:\n            g.edata[\"eig\"] = g.edata[\"subgraph_counts\"].float()\n\n        self.train = Subset(\n            self.dataset, self.split_idx[\"label\"], self.split_idx[\"train\"]\n        )\n        self.val = Subset(\n            self.dataset, self.split_idx[\"label\"], self.split_idx[\"valid\"]\n        )\n        self.test = Subset(\n            self.dataset, self.split_idx[\"label\"], self.split_idx[\"test\"]\n        )\n\n        print(\n            \"train, test, val sizes :\",\n            len(self.train),\n            len(self.test),\n            len(self.val),\n        )\n        print(\"[I] Finished loading.\")\n\n    # form a mini batch from a given list of samples = [(graph, label) pairs]\n    def collate(self, samples):\n        # The input samples is a list of pairs (graph, label).\n        graphs, labels = map(list, zip(*samples))\n        labels = torch.stack(labels)\n\n        tab_sizes_n = [g.num_nodes() for g in graphs]\n        tab_snorm_n = [\n            torch.FloatTensor(size, 1).fill_(1.0 / size) for size in tab_sizes_n\n        ]\n        snorm_n = torch.cat(tab_snorm_n).sqrt()\n        batched_graph = dgl.batch(graphs)\n\n        return batched_graph, labels, snorm_n\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--gpu_id\", default=0, type=int, help=\"Please give a value for gpu id\"\n    )\n    parser.add_argument(\n        \"--seed\", default=41, type=int, help=\"Please give a value for seed\"\n    )\n    parser.add_argument(\n        \"--batch_size\",\n        default=2048,\n        type=int,\n        help=\"Please give a value for batch_size\",\n    )\n    args = parser.parse_args()\n\n    # device\n    args.device = torch.device(\n      
  \"cuda:{}\".format(args.gpu_id) if torch.cuda.is_available() else \"cpu\"\n    )\n\n    # setting seeds\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed(args.seed)\n\n    dataset = PCBADataset(\"ogbg-molpcba\")\n    train(dataset, args)\n"
  },
  {
    "path": "examples/pytorch/ogb/directional_GSN/preprocessing.py",
    "content": "import os\n\nimport graph_tool as gt\nimport graph_tool.topology as gt_topology\nimport networkx as nx\nimport numpy as np\nimport torch\n\nfrom dgl.data.utils import load_graphs, save_graphs\nfrom ogb.graphproppred import DglGraphPropPredDataset\nfrom tqdm import tqdm\n\n\ndef to_undirected(edge_index):\n    row, col = edge_index.transpose(1, 0)\n    row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)\n    edge_index = torch.stack([row, col], dim=0)\n\n    return edge_index.transpose(1, 0).tolist()\n\n\ndef induced_edge_automorphism_orbits(edge_list):\n    ##### node automorphism orbits #####\n    graph = gt.Graph(directed=False)\n    graph.add_edge_list(edge_list)\n    gt.stats.remove_self_loops(graph)\n    gt.stats.remove_parallel_edges(graph)\n\n    # compute the node automorphism group\n    aut_group = gt_topology.subgraph_isomorphism(\n        graph, graph, induced=False, subgraph=True, generator=False\n    )\n\n    orbit_membership = {}\n    for v in graph.get_vertices():\n        orbit_membership[v] = v\n\n    # whenever two nodes can be mapped via some automorphism, they are assigned the same orbit\n    for aut in aut_group:\n        for original, node in enumerate(aut):\n            role = min(original, orbit_membership[node])\n            orbit_membership[node] = role\n\n    orbit_membership_list = [[], []]\n    for node, om_curr in orbit_membership.items():\n        orbit_membership_list[0].append(node)\n        orbit_membership_list[1].append(om_curr)\n\n    # make orbit list contiguous (i.e. 
0,1,2,...O)\n    _, contiguous_orbit_membership = np.unique(\n        orbit_membership_list[1], return_inverse=True\n    )\n\n    orbit_membership = {\n        node: contiguous_orbit_membership[i]\n        for i, node in enumerate(orbit_membership_list[0])\n    }\n\n    aut_count = len(aut_group)\n\n    ##### induced edge automorphism orbits (according to the node automorphism group) #####\n    edge_orbit_partition = dict()\n    edge_orbit_membership = dict()\n    edge_orbits2inds = dict()\n    ind = 0\n\n    edge_list = to_undirected(torch.tensor(graph.get_edges()))\n\n    # infer edge automorphisms from the node automorphisms\n    for i, edge in enumerate(edge_list):\n        edge_orbit = frozenset(\n            [orbit_membership[edge[0]], orbit_membership[edge[1]]]\n        )\n        if edge_orbit not in edge_orbits2inds:\n            edge_orbits2inds[edge_orbit] = ind\n            ind_edge_orbit = ind\n            ind += 1\n        else:\n            ind_edge_orbit = edge_orbits2inds[edge_orbit]\n\n        if ind_edge_orbit not in edge_orbit_partition:\n            edge_orbit_partition[ind_edge_orbit] = [tuple(edge)]\n        else:\n            edge_orbit_partition[ind_edge_orbit] += [tuple(edge)]\n\n        edge_orbit_membership[i] = ind_edge_orbit\n\n    print(\n        \"Edge orbit partition of given substructure: {}\".format(\n            edge_orbit_partition\n        )\n    )\n    print(\"Number of edge orbits: {}\".format(len(edge_orbit_partition)))\n    print(\"Graph (node) automorphism count: {}\".format(aut_count))\n\n    return graph, edge_orbit_partition, edge_orbit_membership, aut_count\n\n\ndef subgraph_isomorphism_edge_counts(edge_index, subgraph_dict):\n    ##### edge structural identifiers #####\n\n    edge_index = edge_index.transpose(1, 0).cpu().numpy()\n    edge_dict = {}\n    for i, edge in enumerate(edge_index):\n        edge_dict[tuple(edge)] = i\n\n    subgraph_edges = to_undirected(\n        
torch.tensor(subgraph_dict[\"subgraph\"].get_edges().tolist())\n    )\n\n    G_gt = gt.Graph(directed=False)\n    G_gt.add_edge_list(list(edge_index))\n    gt.stats.remove_self_loops(G_gt)\n    gt.stats.remove_parallel_edges(G_gt)\n\n    # compute all subgraph isomorphisms\n    sub_iso = gt_topology.subgraph_isomorphism(\n        subgraph_dict[\"subgraph\"],\n        G_gt,\n        induced=True,\n        subgraph=True,\n        generator=True,\n    )\n\n    counts = np.zeros(\n        (edge_index.shape[0], len(subgraph_dict[\"orbit_partition\"]))\n    )\n\n    for sub_iso_curr in sub_iso:\n        mapping = sub_iso_curr.get_array()\n        for i, edge in enumerate(subgraph_edges):\n            # for every edge in the graph H, find the edge in the subgraph G_S to which it is mapped\n            # (by finding where its endpoints are matched).\n            # Then, increase the count of the matched edge w.r.t. the corresponding orbit\n            # Repeat for the reverse edge (the one with the opposite direction)\n\n            edge_orbit = subgraph_dict[\"orbit_membership\"][i]\n            mapped_edge = tuple([mapping[edge[0]], mapping[edge[1]]])\n            counts[edge_dict[mapped_edge], edge_orbit] += 1\n\n    counts = counts / subgraph_dict[\"aut_count\"]\n\n    counts = torch.tensor(counts)\n\n    return counts\n\n\ndef prepare_dataset(name):\n    # maximum size of cycle graph\n    k = 8\n\n    path = os.path.join(\"./\", \"dataset\", name)\n    data_folder = os.path.join(path, \"processed\")\n    os.makedirs(data_folder, exist_ok=True)\n\n    data_file = os.path.join(\n        data_folder, \"cycle_graph_induced_{}.bin\".format(k)\n    )\n\n    # try to load\n    if os.path.exists(data_file):  # load\n        print(\"Loading dataset from {}\".format(data_file))\n        g_list, split_idx = load_graphs(data_file)\n    else:  # generate\n        g_list, split_idx = generate_dataset(path, name)\n        print(\"Saving dataset to {}\".format(data_file))\n        
save_graphs(data_file, g_list, split_idx)\n\n    return g_list, split_idx\n\n\ndef generate_dataset(path, name):\n    ### compute the orbits of each substructure in the list, as well as the node automorphism count\n    subgraph_dicts = []\n\n    edge_lists = []\n    for k in range(3, 8 + 1):\n        graphs_nx = nx.cycle_graph(k)\n        edge_lists.append(list(graphs_nx.edges))\n\n    for edge_list in edge_lists:\n        (\n            subgraph,\n            orbit_partition,\n            orbit_membership,\n            aut_count,\n        ) = induced_edge_automorphism_orbits(edge_list=edge_list)\n        subgraph_dicts.append(\n            {\n                \"subgraph\": subgraph,\n                \"orbit_partition\": orbit_partition,\n                \"orbit_membership\": orbit_membership,\n                \"aut_count\": aut_count,\n            }\n        )\n\n    ### load and preprocess dataset\n    dataset = DglGraphPropPredDataset(name=name, root=path)\n    split_idx = dataset.get_idx_split()\n\n    # computation of subgraph isomorphisms & creation of data structure\n    graphs_dgl = list()\n    split_idx[\"label\"] = []\n    for i, datapoint in tqdm(enumerate(dataset)):\n        g, label = datapoint\n        g = _prepare(g, subgraph_dicts)\n        graphs_dgl.append(g)\n        split_idx[\"label\"].append(label)\n\n    split_idx[\"label\"] = torch.stack(split_idx[\"label\"])\n\n    return graphs_dgl, split_idx\n\n\ndef _prepare(g, subgraph_dicts):\n    edge_index = torch.stack(g.edges())\n\n    identifiers = None\n    for subgraph_dict in subgraph_dicts:\n        counts = subgraph_isomorphism_edge_counts(edge_index, subgraph_dict)\n        identifiers = (\n            counts\n            if identifiers is None\n            else torch.cat((identifiers, counts), 1)\n        )\n\n    g.edata[\"subgraph_counts\"] = identifiers.long()\n\n    return g\n\n\nif __name__ == \"__main__\":\n    prepare_dataset(\"ogbg-molpcba\")\n"
  },
  {
    "path": "examples/pytorch/ogb/line/README.md",
    "content": "# LINE Example\n- Paper link: [here](https://arxiv.org/pdf/1503.03578)\n- Official implementation: [here](https://github.com/tangjianpku/LINE)\n\nThis implementation includes both LINE-1st and LINE-2nd. The detailed usage is shown in the arguments in line.py.\n\n## How to load ogb data\nTo load an ogb dataset, you need to run the following command, which will output a graph file, ogbn-proteins-graph.bin:\n```\npython3 load_dataset.py --name ogbn-proteins\n```\nOr you can run the code directly with:\n```\npython3 line.py --ogbn_name xxx --load_from_ogbn\n```\nHowever, ogb.nodeproppred might not be compatible with mixed training with multi-gpu. If you want to do mixed training, please use no more than 1 gpu with the command above. We leave the commands to run with multi-gpu at the end.\n\n## Evaluation\nFor evaluation we follow the code mlp.py provided by ogb [here](https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/).\n\n## Used config\nogbn-arxiv\n```\npython3 line.py --save_in_pt --dim 128 --lap_norm 0.1 --mix --gpus 0 --batch_size 1024 --output_emb_file arxiv-embedding.pt --num_samples 1000 --print_interval 1000 --negative 5 --fast_neg --load_from_ogbn --ogbn_name ogbn-arxiv\ncd ./ogb/blob/master/examples/nodeproppred/arxiv\ncp embedding_pt_file_path ./\npython3 mlp.py --device 0 --use_node_embedding\n```\n\nogbn-proteins\n```\npython3 line.py --save_in_pt --dim 128 --lap_norm 0.01 --mix --gpus 1 --batch_size 1024 --output_emb_file protein-embedding.pt --num_samples 600 --print_interval 1000 --negative 1 --fast_neg --load_from_ogbn --ogbn_name ogbn-proteins --print_loss\ncd ./ogb/blob/master/examples/nodeproppred/proteins\ncp embedding_pt_file_path ./\npython3 mlp.py --device 0 --use_node_embedding\n```\n\nogbn-products\n```\npython3 line.py --save_in_pt --dim 128 --lap_norm 0.01 --mix --gpus 0 --batch_size 4096 --output_emb_file products-embedding.pt --num_samples 3000 --print_interval 1000 --negative 1 --fast_neg --load_from_ogbn --ogbn_name ogbn-products --print_loss\ncd ./ogb/blob/master/examples/nodeproppred/products\ncp embedding_pt_file_path ./\npython3 mlp.py --device 0 --use_node_embedding\n```\n\n## Results\nogbn-arxiv\n<br>#params: 33023343(model) + 142888(mlp) = 33166231\n<br>Highest Train: 82.94 ± 0.11\n<br>Highest Valid: 71.76 ± 0.08\n<br>Final Train: 80.74 ± 1.30\n<br>Final Test: 70.47 ± 0.19\n\n<br>ogbn-proteins\n<br>#params: 25853524(model) + 129648(mlp) = 25983172\n<br>Highest Train: 93.11 ± 0.04\n<br>Highest Valid: 70.50 ± 1.29\n<br>Final Train: 77.66 ± 10.27\n<br>Final Test: 62.07 ± 1.25\n\n<br>ogbn-products\n<br>#params: 477570049(model) + 136495(mlp) = 477706544\n<br>Highest Train: 98.01 ± 0.32\n<br>Highest Valid: 89.57 ± 0.09\n<br>Final Train: 94.96 ± 0.43\n<br>Final Test: 72.52 ± 0.29\n\n## Notes\nTo utilize multi-GPU training, we need to save the dataset as a local file before training with the following command:\n```\npython3 load_dataset.py --name dataset_name\n```\nwhere `dataset_name` can be `ogbn-arxiv`, `ogbn-proteins`, and `ogbn-products`. After that, a local file `$dataset_name$-graph.bin` will be generated. Then run:\n```\npython3 line.py --data_file $dataset_name$-graph.bin\n```\nwhere the other parameters are the same as in the used configs, but without `--load_from_ogbn` and `--ogbn_name`."
  },
  {
    "path": "examples/pytorch/ogb/line/line.py",
    "content": "import argparse\nimport os\nimport random\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.multiprocessing as mp\nfrom model import SkipGramModel\nfrom reading_data import LineDataset\nfrom torch.utils.data import DataLoader\nfrom utils import check_args, sum_up_params\n\n\nclass LineTrainer:\n    def __init__(self, args):\n        \"\"\"Initializing the trainer with the input arguments\"\"\"\n        self.args = args\n        self.dataset = LineDataset(\n            net_file=args.data_file,\n            batch_size=args.batch_size,\n            negative=args.negative,\n            gpus=args.gpus,\n            fast_neg=args.fast_neg,\n            ogbl_name=args.ogbl_name,\n            load_from_ogbl=args.load_from_ogbl,\n            ogbn_name=args.ogbn_name,\n            load_from_ogbn=args.load_from_ogbn,\n            num_samples=args.num_samples * 1000000,\n        )\n        self.emb_size = self.dataset.G.num_nodes()\n        self.emb_model = None\n\n    def init_device_emb(self):\n        \"\"\"set the device before training\n        will be called once in fast_train_mp / fast_train\n        \"\"\"\n        choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])\n        assert (\n            choices == 1\n        ), \"Must choose only *one* training mode in [only_cpu, only_gpu, mix]\"\n\n        # initializing embedding on CPU\n        self.emb_model = SkipGramModel(\n            emb_size=self.emb_size,\n            emb_dimension=self.args.dim,\n            batch_size=self.args.batch_size,\n            only_cpu=self.args.only_cpu,\n            only_gpu=self.args.only_gpu,\n            only_fst=self.args.only_fst,\n            only_snd=self.args.only_snd,\n            mix=self.args.mix,\n            neg_weight=self.args.neg_weight,\n            negative=self.args.negative,\n            lr=self.args.lr,\n            lap_norm=self.args.lap_norm,\n            fast_neg=self.args.fast_neg,\n            
record_loss=self.args.print_loss,\n            async_update=self.args.async_update,\n            num_threads=self.args.num_threads,\n        )\n\n        torch.set_num_threads(self.args.num_threads)\n        if self.args.only_gpu:\n            print(\"Run in 1 GPU\")\n            assert self.args.gpus[0] >= 0\n            self.emb_model.all_to_device(self.args.gpus[0])\n        elif self.args.mix:\n            print(\"Mix CPU with %d GPU\" % len(self.args.gpus))\n            if len(self.args.gpus) == 1:\n                assert (\n                    self.args.gpus[0] >= 0\n                ), \"mix CPU with GPU should have avaliable GPU\"\n                self.emb_model.set_device(self.args.gpus[0])\n        else:\n            print(\"Run in CPU process\")\n\n    def train(self):\n        \"\"\"train the embedding\"\"\"\n        if len(self.args.gpus) > 1:\n            self.fast_train_mp()\n        else:\n            self.fast_train()\n\n    def fast_train_mp(self):\n        \"\"\"multi-cpu-core or mix cpu & multi-gpu\"\"\"\n        self.init_device_emb()\n        self.emb_model.share_memory()\n\n        sum_up_params(self.emb_model)\n\n        start_all = time.time()\n        ps = []\n\n        for i in range(len(self.args.gpus)):\n            p = mp.Process(\n                target=self.fast_train_sp, args=(i, self.args.gpus[i])\n            )\n            ps.append(p)\n            p.start()\n\n        for p in ps:\n            p.join()\n\n        print(\"Used time: %.2fs\" % (time.time() - start_all))\n        if self.args.save_in_pt:\n            self.emb_model.save_embedding_pt(\n                self.dataset, self.args.output_emb_file\n            )\n        else:\n            self.emb_model.save_embedding(\n                self.dataset, self.args.output_emb_file\n            )\n\n    def fast_train_sp(self, rank, gpu_id):\n        \"\"\"a subprocess for fast_train_mp\"\"\"\n        if self.args.mix:\n            self.emb_model.set_device(gpu_id)\n\n        
torch.set_num_threads(self.args.num_threads)\n        if self.args.async_update:\n            self.emb_model.create_async_update()\n\n        sampler = self.dataset.create_sampler(rank)\n\n        dataloader = DataLoader(\n            dataset=sampler.seeds,\n            batch_size=self.args.batch_size,\n            collate_fn=sampler.sample,\n            shuffle=False,\n            drop_last=False,\n            num_workers=self.args.num_sampler_threads,\n        )\n        num_batches = len(dataloader)\n        print(\n            \"num batchs: %d in process [%d] GPU [%d]\"\n            % (num_batches, rank, gpu_id)\n        )\n\n        start = time.time()\n        with torch.no_grad():\n            for i, edges in enumerate(dataloader):\n                if self.args.fast_neg:\n                    self.emb_model.fast_learn(edges)\n                else:\n                    # do negative sampling\n                    bs = edges.size()[0]\n                    neg_nodes = torch.LongTensor(\n                        np.random.choice(\n                            self.dataset.neg_table,\n                            bs * self.args.negative,\n                            replace=True,\n                        )\n                    )\n                    self.emb_model.fast_learn(edges, neg_nodes=neg_nodes)\n\n                if i > 0 and i % self.args.print_interval == 0:\n                    if self.args.print_loss:\n                        if self.args.only_fst:\n                            print(\n                                \"GPU-[%d] batch %d time: %.2fs fst-loss: %.4f\"\n                                % (\n                                    gpu_id,\n                                    i,\n                                    time.time() - start,\n                                    -sum(self.emb_model.loss_fst)\n                                    / self.args.print_interval,\n                                )\n                            )\n                     
   elif self.args.only_snd:\n                            print(\n                                \"GPU-[%d] batch %d time: %.2fs snd-loss: %.4f\"\n                                % (\n                                    gpu_id,\n                                    i,\n                                    time.time() - start,\n                                    -sum(self.emb_model.loss_snd)\n                                    / self.args.print_interval,\n                                )\n                            )\n                        else:\n                            print(\n                                \"GPU-[%d] batch %d time: %.2fs fst-loss: %.4f snd-loss: %.4f\"\n                                % (\n                                    gpu_id,\n                                    i,\n                                    time.time() - start,\n                                    -sum(self.emb_model.loss_fst)\n                                    / self.args.print_interval,\n                                    -sum(self.emb_model.loss_snd)\n                                    / self.args.print_interval,\n                                )\n                            )\n                        self.emb_model.loss_fst = []\n                        self.emb_model.loss_snd = []\n                    else:\n                        print(\n                            \"GPU-[%d] batch %d time: %.2fs\"\n                            % (gpu_id, i, time.time() - start)\n                        )\n                    start = time.time()\n\n            if self.args.async_update:\n                self.emb_model.finish_async_update()\n\n    def fast_train(self):\n        \"\"\"fast train with dataloader with only gpu / only cpu\"\"\"\n        self.init_device_emb()\n\n        if self.args.async_update:\n            self.emb_model.share_memory()\n            self.emb_model.create_async_update()\n\n        sum_up_params(self.emb_model)\n\n        sampler = 
self.dataset.create_sampler(0)\n\n        dataloader = DataLoader(\n            dataset=sampler.seeds,\n            batch_size=self.args.batch_size,\n            collate_fn=sampler.sample,\n            shuffle=False,\n            drop_last=False,\n            num_workers=self.args.num_sampler_threads,\n        )\n\n        num_batches = len(dataloader)\n        print(\"num batchs: %d\\n\" % num_batches)\n\n        start_all = time.time()\n        start = time.time()\n        with torch.no_grad():\n            for i, edges in enumerate(dataloader):\n                if self.args.fast_neg:\n                    self.emb_model.fast_learn(edges)\n                else:\n                    # do negative sampling\n                    bs = edges.size()[0]\n                    neg_nodes = torch.LongTensor(\n                        np.random.choice(\n                            self.dataset.neg_table,\n                            bs * self.args.negative,\n                            replace=True,\n                        )\n                    )\n                    self.emb_model.fast_learn(edges, neg_nodes=neg_nodes)\n\n                if i > 0 and i % self.args.print_interval == 0:\n                    if self.args.print_loss:\n                        if self.args.only_fst:\n                            print(\n                                \"Batch %d time: %.2fs fst-loss: %.4f\"\n                                % (\n                                    i,\n                                    time.time() - start,\n                                    -sum(self.emb_model.loss_fst)\n                                    / self.args.print_interval,\n                                )\n                            )\n                        elif self.args.only_snd:\n                            print(\n                                \"Batch %d time: %.2fs snd-loss: %.4f\"\n                                % (\n                                    i,\n                                  
  time.time() - start,\n                                    -sum(self.emb_model.loss_snd)\n                                    / self.args.print_interval,\n                                )\n                            )\n                        else:\n                            print(\n                                \"Batch %d time: %.2fs fst-loss: %.4f snd-loss: %.4f\"\n                                % (\n                                    i,\n                                    time.time() - start,\n                                    -sum(self.emb_model.loss_fst)\n                                    / self.args.print_interval,\n                                    -sum(self.emb_model.loss_snd)\n                                    / self.args.print_interval,\n                                )\n                            )\n                        self.emb_model.loss_fst = []\n                        self.emb_model.loss_snd = []\n                    else:\n                        print(\n                            \"Batch %d, training time: %.2fs\"\n                            % (i, time.time() - start)\n                        )\n                    start = time.time()\n\n            if self.args.async_update:\n                self.emb_model.finish_async_update()\n\n        print(\"Training used time: %.2fs\" % (time.time() - start_all))\n        if self.args.save_in_pt:\n            self.emb_model.save_embedding_pt(\n                self.dataset, self.args.output_emb_file\n            )\n        else:\n            self.emb_model.save_embedding(\n                self.dataset, self.args.output_emb_file\n            )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Implementation of LINE.\")\n    # input files\n    ## personal datasets\n    parser.add_argument(\"--data_file\", type=str, help=\"path of dgl graphs\")\n    ## ogbl datasets\n    parser.add_argument(\n        \"--ogbl_name\", type=str, help=\"name of ogbl 
dataset, e.g. ogbl-ddi\"\n    )\n    parser.add_argument(\n        \"--load_from_ogbl\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether load dataset from ogbl\",\n    )\n    parser.add_argument(\n        \"--ogbn_name\", type=str, help=\"name of ogbn dataset, e.g. ogbn-proteins\"\n    )\n    parser.add_argument(\n        \"--load_from_ogbn\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether load dataset from ogbn\",\n    )\n\n    # output files\n    parser.add_argument(\n        \"--save_in_pt\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether save dat in pt format or npy\",\n    )\n    parser.add_argument(\n        \"--output_emb_file\",\n        type=str,\n        default=\"emb.npy\",\n        help=\"path of the output npy embedding file\",\n    )\n\n    # model parameters\n    parser.add_argument(\n        \"--dim\", default=128, type=int, help=\"embedding dimensions\"\n    )\n    parser.add_argument(\n        \"--num_samples\",\n        default=1,\n        type=int,\n        help=\"number of samples during training (million)\",\n    )\n    parser.add_argument(\n        \"--negative\",\n        default=1,\n        type=int,\n        help=\"negative samples for each positve node pair\",\n    )\n    parser.add_argument(\n        \"--batch_size\",\n        default=128,\n        type=int,\n        help=\"number of edges in each batch\",\n    )\n    parser.add_argument(\n        \"--neg_weight\", default=1.0, type=float, help=\"negative weight\"\n    )\n    parser.add_argument(\n        \"--lap_norm\",\n        default=0.01,\n        type=float,\n        help=\"weight of laplacian normalization\",\n    )\n\n    # training parameters\n    parser.add_argument(\n        \"--only_fst\",\n        default=False,\n        action=\"store_true\",\n        help=\"only do first-order proximity embedding\",\n    )\n    parser.add_argument(\n        \"--only_snd\",\n        
default=False,\n        action=\"store_true\",\n        help=\"only do second-order proximity embedding\",\n    )\n    parser.add_argument(\n        \"--print_interval\",\n        default=100,\n        type=int,\n        help=\"number of batches between printing\",\n    )\n    parser.add_argument(\n        \"--print_loss\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether print loss during training\",\n    )\n    parser.add_argument(\"--lr\", default=0.2, type=float, help=\"learning rate\")\n\n    # optimization settings\n    parser.add_argument(\n        \"--mix\",\n        default=False,\n        action=\"store_true\",\n        help=\"mixed training with CPU and GPU\",\n    )\n    parser.add_argument(\n        \"--gpus\",\n        type=int,\n        default=[-1],\n        nargs=\"+\",\n        help=\"a list of active gpu ids, e.g. 0, used with --mix\",\n    )\n    parser.add_argument(\n        \"--only_cpu\",\n        default=False,\n        action=\"store_true\",\n        help=\"training with CPU\",\n    )\n    parser.add_argument(\n        \"--only_gpu\",\n        default=False,\n        action=\"store_true\",\n        help=\"training with a single GPU (all of the parameters are moved on the GPU)\",\n    )\n    parser.add_argument(\n        \"--async_update\",\n        default=False,\n        action=\"store_true\",\n        help=\"mixed training asynchronously, recommend not to use this\",\n    )\n\n    parser.add_argument(\n        \"--fast_neg\",\n        default=False,\n        action=\"store_true\",\n        help=\"do negative sampling inside a batch\",\n    )\n    parser.add_argument(\n        \"--num_threads\",\n        default=2,\n        type=int,\n        help=\"number of threads used for each CPU-core/GPU\",\n    )\n    parser.add_argument(\n        \"--num_sampler_threads\",\n        default=2,\n        type=int,\n        help=\"number of threads used for sampling\",\n    )\n\n    args = parser.parse_args()\n\n    if 
args.async_update:\n        assert args.mix, \"--async_update only with --mix\"\n\n    start_time = time.time()\n    trainer = LineTrainer(args)\n    trainer.train()\n    print(\"Total used time: %.2f\" % (time.time() - start_time))\n"
  },
  {
    "path": "examples/pytorch/ogb/line/load_dataset.py",
    "content": "\"\"\" load dataset from ogb \"\"\"\n\nimport argparse\n\nimport dgl\n\nfrom ogb.linkproppred import DglLinkPropPredDataset\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\ndef load_from_ogbl_with_name(name):\n    choices = [\"ogbl-collab\", \"ogbl-ddi\", \"ogbl-ppa\", \"ogbl-citation\"]\n    assert name in choices, \"name must be selected from \" + str(choices)\n    dataset = DglLinkPropPredDataset(name)\n    return dataset[0]\n\n\ndef load_from_ogbn_with_name(name):\n    choices = [\n        \"ogbn-products\",\n        \"ogbn-proteins\",\n        \"ogbn-arxiv\",\n        \"ogbn-papers100M\",\n    ]\n    assert name in choices, \"name must be selected from \" + str(choices)\n    dataset, label = DglNodePropPredDataset(name)[0]\n    return dataset\n\n\nif __name__ == \"__main__\":\n    \"\"\"load datasets as net.txt format\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--name\",\n        type=str,\n        choices=[\n            \"ogbl-collab\",\n            \"ogbl-ddi\",\n            \"ogbl-ppa\",\n            \"ogbl-citation\",\n            \"ogbn-products\",\n            \"ogbn-proteins\",\n            \"ogbn-arxiv\",\n            \"ogbn-papers100M\",\n        ],\n        default=\"ogbl-collab\",\n        help=\"name of datasets by ogb\",\n    )\n    args = parser.parse_args()\n\n    name = args.name\n    if name.startswith(\"ogbl\"):\n        g = load_from_ogbl_with_name(name=name)\n    else:\n        g = load_from_ogbn_with_name(name=name)\n\n    dgl.save_graphs(name + \"-graph.bin\", g)\n"
  },
  {
    "path": "examples/pytorch/ogb/line/model.py",
    "content": "import random\n\nimport numpy as np\nimport torch\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.multiprocessing import Queue\nfrom torch.nn import init\n\n\ndef init_emb2neg_index(negative, batch_size):\n    \"\"\"select embedding of negative nodes from a batch of node embeddings\n    for fast negative sampling\n\n    Return\n    ------\n    index_emb_negu torch.LongTensor : the indices of u_embeddings\n    index_emb_negv torch.LongTensor : the indices of v_embeddings\n\n    Usage\n    -----\n    # emb_u.shape: [batch_size, dim]\n    batch_emb2negu = torch.index_select(emb_u, 0, index_emb_negu)\n    \"\"\"\n    idx_list_u = list(range(batch_size)) * negative\n    idx_list_v = list(range(batch_size)) * negative\n    random.shuffle(idx_list_v)\n\n    index_emb_negu = torch.LongTensor(idx_list_u)\n    index_emb_negv = torch.LongTensor(idx_list_v)\n\n    return index_emb_negu, index_emb_negv\n\n\ndef adam(grad, state_sum, nodes, lr, device, only_gpu):\n    \"\"\"calculate gradients according to adam\"\"\"\n    grad_sum = (grad * grad).mean(1)\n    if not only_gpu:\n        grad_sum = grad_sum.cpu()\n    state_sum.index_add_(0, nodes, grad_sum)  # cpu\n    std = state_sum[nodes].to(device)  # gpu\n    std_values = std.sqrt_().add_(1e-10).unsqueeze(1)\n    grad = lr * grad / std_values  # gpu\n\n    return grad\n\n\ndef async_update(num_threads, model, queue):\n    \"\"\"Asynchronous embedding update for entity embeddings.\"\"\"\n    torch.set_num_threads(num_threads)\n    print(\"async start\")\n    while True:\n        (grad_u, grad_v, grad_v_neg, nodes, neg_nodes, first_flag) = queue.get()\n        if grad_u is None:\n            return\n        with torch.no_grad():\n            if first_flag:\n                model.fst_u_embeddings.weight.data.index_add_(\n                    0, nodes[:, 0], grad_u\n                )\n                model.fst_u_embeddings.weight.data.index_add_(\n             
       0, nodes[:, 1], grad_v\n                )\n                if neg_nodes is not None:\n                    model.fst_u_embeddings.weight.data.index_add_(\n                        0, neg_nodes, grad_v_neg\n                    )\n            else:\n                model.snd_u_embeddings.weight.data.index_add_(\n                    0, nodes[:, 0], grad_u\n                )\n                model.snd_v_embeddings.weight.data.index_add_(\n                    0, nodes[:, 1], grad_v\n                )\n                if neg_nodes is not None:\n                    model.snd_v_embeddings.weight.data.index_add_(\n                        0, neg_nodes, grad_v_neg\n                    )\n\n\nclass SkipGramModel(nn.Module):\n    \"\"\"Negative sampling based skip-gram\"\"\"\n\n    def __init__(\n        self,\n        emb_size,\n        emb_dimension,\n        batch_size,\n        only_cpu,\n        only_gpu,\n        only_fst,\n        only_snd,\n        mix,\n        neg_weight,\n        negative,\n        lr,\n        lap_norm,\n        fast_neg,\n        record_loss,\n        async_update,\n        num_threads,\n    ):\n        \"\"\"initialize embedding on CPU\n\n        Paremeters\n        ----------\n        emb_size int : number of nodes\n        emb_dimension int : embedding dimension\n        batch_size int : number of node sequences in each batch\n        only_cpu bool : training with CPU\n        only_gpu bool : training with GPU\n        only_fst bool : only embedding for first-order proximity\n        only_snd bool : only embedding for second-order proximity\n        mix bool : mixed training with CPU and GPU\n        negative int : negative samples for each positve node pair\n        neg_weight float : negative weight\n        lr float : initial learning rate\n        lap_norm float : weight of laplacian normalization\n        fast_neg bool : do negative sampling inside a batch\n        record_loss bool : print the loss during training\n        
use_context_weight : give different weights to the nodes in a context window\n        async_update : asynchronous training\n        \"\"\"\n        super(SkipGramModel, self).__init__()\n        self.emb_size = emb_size\n        self.batch_size = batch_size\n        self.only_cpu = only_cpu\n        self.only_gpu = only_gpu\n        if only_fst:\n            self.fst = True\n            self.snd = False\n            self.emb_dimension = emb_dimension\n        elif only_snd:\n            self.fst = False\n            self.snd = True\n            self.emb_dimension = emb_dimension\n        else:\n            self.fst = True\n            self.snd = True\n            self.emb_dimension = int(emb_dimension / 2)\n        self.mixed_train = mix\n        self.neg_weight = neg_weight\n        self.negative = negative\n        self.lr = lr\n        self.lap_norm = lap_norm\n        self.fast_neg = fast_neg\n        self.record_loss = record_loss\n        self.async_update = async_update\n        self.num_threads = num_threads\n\n        # initialize the device as cpu\n        self.device = torch.device(\"cpu\")\n\n        # embedding\n        initrange = 1.0 / self.emb_dimension\n        if self.fst:\n            self.fst_u_embeddings = nn.Embedding(\n                self.emb_size, self.emb_dimension, sparse=True\n            )\n            init.uniform_(\n                self.fst_u_embeddings.weight.data, -initrange, initrange\n            )\n        if self.snd:\n            self.snd_u_embeddings = nn.Embedding(\n                self.emb_size, self.emb_dimension, sparse=True\n            )\n            init.uniform_(\n                self.snd_u_embeddings.weight.data, -initrange, initrange\n            )\n            self.snd_v_embeddings = nn.Embedding(\n                self.emb_size, self.emb_dimension, sparse=True\n            )\n            init.constant_(self.snd_v_embeddings.weight.data, 0)\n\n        # lookup_table is used for fast sigmoid computing\n        
self.lookup_table = torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))\n        self.lookup_table[0] = 0.0\n        self.lookup_table[-1] = 1.0\n        if self.record_loss:\n            self.logsigmoid_table = torch.log(\n                torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))\n            )\n            self.loss_fst = []\n            self.loss_snd = []\n\n        # indexes to select positive/negative node pairs from batch_walks\n        self.index_emb_negu, self.index_emb_negv = init_emb2neg_index(\n            self.negative, self.batch_size\n        )\n\n        # adam\n        if self.fst:\n            self.fst_state_sum_u = torch.zeros(self.emb_size)\n        if self.snd:\n            self.snd_state_sum_u = torch.zeros(self.emb_size)\n            self.snd_state_sum_v = torch.zeros(self.emb_size)\n\n    def create_async_update(self):\n        \"\"\"Set up the async update subprocess.\"\"\"\n        self.async_q = Queue(1)\n        self.async_p = mp.Process(\n            target=async_update, args=(self.num_threads, self, self.async_q)\n        )\n        self.async_p.start()\n\n    def finish_async_update(self):\n        \"\"\"Notify the async update subprocess to quit.\"\"\"\n        self.async_q.put((None, None, None, None, None))\n        self.async_p.join()\n\n    def share_memory(self):\n        \"\"\"share the parameters across subprocesses\"\"\"\n        if self.fst:\n            self.fst_u_embeddings.weight.share_memory_()\n            self.fst_state_sum_u.share_memory_()\n        if self.snd:\n            self.snd_u_embeddings.weight.share_memory_()\n            self.snd_v_embeddings.weight.share_memory_()\n            self.snd_state_sum_u.share_memory_()\n            self.snd_state_sum_v.share_memory_()\n\n    def set_device(self, gpu_id):\n        \"\"\"set gpu device\"\"\"\n        self.device = torch.device(\"cuda:%d\" % gpu_id)\n        print(\"The device is\", self.device)\n        self.lookup_table = self.lookup_table.to(self.device)\n        
if self.record_loss:\n            self.logsigmoid_table = self.logsigmoid_table.to(self.device)\n        self.index_emb_negu = self.index_emb_negu.to(self.device)\n        self.index_emb_negv = self.index_emb_negv.to(self.device)\n\n    def all_to_device(self, gpu_id):\n        \"\"\"move all of the parameters to a single GPU\"\"\"\n        self.device = torch.device(\"cuda:%d\" % gpu_id)\n        self.set_device(gpu_id)\n        if self.fst:\n            self.fst_u_embeddings = self.fst_u_embeddings.cuda(gpu_id)\n            self.fst_state_sum_u = self.fst_state_sum_u.to(self.device)\n        if self.snd:\n            self.snd_u_embeddings = self.snd_u_embeddings.cuda(gpu_id)\n            self.snd_v_embeddings = self.snd_v_embeddings.cuda(gpu_id)\n            self.snd_state_sum_u = self.snd_state_sum_u.to(self.device)\n            self.snd_state_sum_v = self.snd_state_sum_v.to(self.device)\n\n    def fast_sigmoid(self, score):\n        \"\"\"do fast sigmoid by looking up in a pre-defined table\"\"\"\n        idx = torch.floor((score + 6.01) / 0.01).long()\n        return self.lookup_table[idx]\n\n    def fast_logsigmoid(self, score):\n        \"\"\"do fast logsigmoid by looking up in a pre-defined table\"\"\"\n        idx = torch.floor((score + 6.01) / 0.01).long()\n        return self.logsigmoid_table[idx]\n\n    def fast_pos_bp(self, emb_pos_u, emb_pos_v, first_flag):\n        \"\"\"get grad for positve samples\"\"\"\n        pos_score = torch.sum(torch.mul(emb_pos_u, emb_pos_v), dim=1)\n        pos_score = torch.clamp(pos_score, max=6, min=-6)\n        # [batch_size, 1]\n        score = (1 - self.fast_sigmoid(pos_score)).unsqueeze(1)\n        if self.record_loss:\n            if first_flag:\n                self.loss_fst.append(\n                    torch.mean(self.fast_logsigmoid(pos_score)).item()\n                )\n            else:\n                self.loss_snd.append(\n                    torch.mean(self.fast_logsigmoid(pos_score)).item()\n               
 )\n\n        # [batch_size, dim]\n        if self.lap_norm > 0:\n            grad_u_pos = score * emb_pos_v + self.lap_norm * (\n                emb_pos_v - emb_pos_u\n            )\n            grad_v_pos = score * emb_pos_u + self.lap_norm * (\n                emb_pos_u - emb_pos_v\n            )\n        else:\n            grad_u_pos = score * emb_pos_v\n            grad_v_pos = score * emb_pos_u\n\n        return grad_u_pos, grad_v_pos\n\n    def fast_neg_bp(self, emb_neg_u, emb_neg_v, first_flag):\n        \"\"\"get grad for negative samples\"\"\"\n        neg_score = torch.sum(torch.mul(emb_neg_u, emb_neg_v), dim=1)\n        neg_score = torch.clamp(neg_score, max=6, min=-6)\n        # [batch_size * negative, 1]\n        score = -self.fast_sigmoid(neg_score).unsqueeze(1)\n        if self.record_loss:\n            if first_flag:\n                self.loss_fst.append(\n                    self.negative\n                    * self.neg_weight\n                    * torch.mean(self.fast_logsigmoid(-neg_score)).item()\n                )\n            else:\n                self.loss_snd.append(\n                    self.negative\n                    * self.neg_weight\n                    * torch.mean(self.fast_logsigmoid(-neg_score)).item()\n                )\n\n        grad_u_neg = self.neg_weight * score * emb_neg_v\n        grad_v_neg = self.neg_weight * score * emb_neg_u\n\n        return grad_u_neg, grad_v_neg\n\n    def fast_learn(self, batch_edges, neg_nodes=None):\n        \"\"\"Learn a batch of edges in a fast way. It has the following features:\n            1. It calculating the gradients directly without the forward operation.\n            2. 
It computes the sigmoid with a lookup table.\n\n        Specifically, for each positive/negative node pair (i,j), the updating procedure is as follows:\n            score = self.fast_sigmoid(u_embedding[i].dot(v_embedding[j]))\n            # label = 1 for positive samples; label = 0 for negative samples.\n            u_embedding[i] += (label - score) * v_embedding[j]\n            v_embedding[j] += (label - score) * u_embedding[i]\n\n        Parameters\n        ----------\n        batch_edges torch.LongTensor : a [batch_size, 2] tensor of edges (node pairs)\n        neg_nodes torch.LongTensor : a long tensor of sampled true negative nodes. If neg_nodes is None,\n            then do negative sampling randomly from the nodes in batch_edges as an alternative.\n\n        Usage example\n        -------------\n        batch_edges = torch.LongTensor([[1,2], [3,4], [5,6]])\n        neg_nodes = None\n        \"\"\"\n        lr = self.lr\n\n        # [batch_size, 2]\n        nodes = batch_edges\n        if self.only_gpu:\n            nodes = nodes.to(self.device)\n            if neg_nodes is not None:\n                neg_nodes = neg_nodes.to(self.device)\n        bs = len(nodes)\n\n        if self.fst:\n            emb_u = (\n                self.fst_u_embeddings(nodes[:, 0])\n                .view(-1, self.emb_dimension)\n                .to(self.device)\n            )\n            emb_v = (\n                self.fst_u_embeddings(nodes[:, 1])\n                .view(-1, self.emb_dimension)\n                .to(self.device)\n            )\n\n            ## Positive\n            emb_pos_u, emb_pos_v = emb_u, emb_v\n            grad_u_pos, grad_v_pos = self.fast_pos_bp(\n                emb_pos_u, emb_pos_v, True\n            )\n\n            ## Negative\n            emb_neg_u = emb_pos_u.repeat((self.negative, 1))\n\n            if bs < self.batch_size:\n                index_emb_negu, index_emb_negv = init_emb2neg_index(\n                    self.negative, bs\n                )\n                index_emb_negu = 
index_emb_negu.to(self.device)\n                index_emb_negv = index_emb_negv.to(self.device)\n            else:\n                index_emb_negu = self.index_emb_negu\n                index_emb_negv = self.index_emb_negv\n\n            if neg_nodes is None:\n                emb_neg_v = torch.index_select(emb_v, 0, index_emb_negv)\n            else:\n                emb_neg_v = self.fst_u_embeddings.weight[neg_nodes].to(\n                    self.device\n                )\n\n            grad_u_neg, grad_v_neg = self.fast_neg_bp(\n                emb_neg_u, emb_neg_v, True\n            )\n\n            ## Update\n            grad_u_pos.index_add_(0, index_emb_negu, grad_u_neg)\n            grad_u = grad_u_pos\n            if neg_nodes is None:\n                grad_v_pos.index_add_(0, index_emb_negv, grad_v_neg)\n                grad_v = grad_v_pos\n            else:\n                grad_v = grad_v_pos\n\n            # use adam optimizer\n            grad_u = adam(\n                grad_u,\n                self.fst_state_sum_u,\n                nodes[:, 0],\n                lr,\n                self.device,\n                self.only_gpu,\n            )\n            grad_v = adam(\n                grad_v,\n                self.fst_state_sum_u,\n                nodes[:, 1],\n                lr,\n                self.device,\n                self.only_gpu,\n            )\n            if neg_nodes is not None:\n                grad_v_neg = adam(\n                    grad_v_neg,\n                    self.fst_state_sum_u,\n                    neg_nodes,\n                    lr,\n                    self.device,\n                    self.only_gpu,\n                )\n\n            if self.mixed_train:\n                grad_u = grad_u.cpu()\n                grad_v = grad_v.cpu()\n                if neg_nodes is not None:\n                    grad_v_neg = grad_v_neg.cpu()\n                else:\n                    grad_v_neg = None\n\n                if 
self.async_update:\n                    grad_u.share_memory_()\n                    grad_v.share_memory_()\n                    nodes.share_memory_()\n                    if neg_nodes is not None:\n                        neg_nodes.share_memory_()\n                        grad_v_neg.share_memory_()\n                    self.async_q.put(\n                        (grad_u, grad_v, grad_v_neg, nodes, neg_nodes, True)\n                    )\n\n            if not self.async_update:\n                self.fst_u_embeddings.weight.data.index_add_(\n                    0, nodes[:, 0], grad_u\n                )\n                self.fst_u_embeddings.weight.data.index_add_(\n                    0, nodes[:, 1], grad_v\n                )\n                if neg_nodes is not None:\n                    self.fst_u_embeddings.weight.data.index_add_(\n                        0, neg_nodes, grad_v_neg\n                    )\n\n        if self.snd:\n            emb_u = (\n                self.snd_u_embeddings(nodes[:, 0])\n                .view(-1, self.emb_dimension)\n                .to(self.device)\n            )\n            emb_v = (\n                self.snd_v_embeddings(nodes[:, 1])\n                .view(-1, self.emb_dimension)\n                .to(self.device)\n            )\n\n            ## Postive\n            emb_pos_u, emb_pos_v = emb_u, emb_v\n            grad_u_pos, grad_v_pos = self.fast_pos_bp(\n                emb_pos_u, emb_pos_v, False\n            )\n\n            ## Negative\n            emb_neg_u = emb_pos_u.repeat((self.negative, 1))\n\n            if bs < self.batch_size:\n                index_emb_negu, index_emb_negv = init_emb2neg_index(\n                    self.negative, bs\n                )\n                index_emb_negu = index_emb_negu.to(self.device)\n                index_emb_negv = index_emb_negv.to(self.device)\n            else:\n                index_emb_negu = self.index_emb_negu\n                index_emb_negv = self.index_emb_negv\n\n          
  if neg_nodes is None:\n                emb_neg_v = torch.index_select(emb_v, 0, index_emb_negv)\n            else:\n                emb_neg_v = self.snd_v_embeddings.weight[neg_nodes].to(\n                    self.device\n                )\n\n            grad_u_neg, grad_v_neg = self.fast_neg_bp(\n                emb_neg_u, emb_neg_v, False\n            )\n\n            ## Update\n            grad_u_pos.index_add_(0, index_emb_negu, grad_u_neg)\n            grad_u = grad_u_pos\n            if neg_nodes is None:\n                grad_v_pos.index_add_(0, index_emb_negv, grad_v_neg)\n                grad_v = grad_v_pos\n            else:\n                grad_v = grad_v_pos\n\n            # use adam optimizer\n            grad_u = adam(\n                grad_u,\n                self.snd_state_sum_u,\n                nodes[:, 0],\n                lr,\n                self.device,\n                self.only_gpu,\n            )\n            grad_v = adam(\n                grad_v,\n                self.snd_state_sum_v,\n                nodes[:, 1],\n                lr,\n                self.device,\n                self.only_gpu,\n            )\n            if neg_nodes is not None:\n                grad_v_neg = adam(\n                    grad_v_neg,\n                    self.snd_state_sum_v,\n                    neg_nodes,\n                    lr,\n                    self.device,\n                    self.only_gpu,\n                )\n\n            if self.mixed_train:\n                grad_u = grad_u.cpu()\n                grad_v = grad_v.cpu()\n                if neg_nodes is not None:\n                    grad_v_neg = grad_v_neg.cpu()\n                else:\n                    grad_v_neg = None\n\n                if self.async_update:\n                    grad_u.share_memory_()\n                    grad_v.share_memory_()\n                    nodes.share_memory_()\n                    if neg_nodes is not None:\n                        neg_nodes.share_memory_()\n    
                    grad_v_neg.share_memory_()\n                    self.async_q.put(\n                        (grad_u, grad_v, grad_v_neg, nodes, neg_nodes, False)\n                    )\n\n            if not self.async_update:\n                self.snd_u_embeddings.weight.data.index_add_(\n                    0, nodes[:, 0], grad_u\n                )\n                self.snd_v_embeddings.weight.data.index_add_(\n                    0, nodes[:, 1], grad_v\n                )\n                if neg_nodes is not None:\n                    self.snd_v_embeddings.weight.data.index_add_(\n                        0, neg_nodes, grad_v_neg\n                    )\n\n        return\n\n    def get_embedding(self):\n        if self.fst:\n            embedding_fst = self.fst_u_embeddings.weight.cpu().data.numpy()\n            embedding_fst /= np.sqrt(\n                np.sum(embedding_fst * embedding_fst, 1)\n            ).reshape(-1, 1)\n        if self.snd:\n            embedding_snd = self.snd_u_embeddings.weight.cpu().data.numpy()\n            embedding_snd /= np.sqrt(\n                np.sum(embedding_snd * embedding_snd, 1)\n            ).reshape(-1, 1)\n        if self.fst and self.snd:\n            embedding = np.concatenate((embedding_fst, embedding_snd), 1)\n            embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(\n                -1, 1\n            )\n        elif self.fst and not self.snd:\n            embedding = embedding_fst\n        elif self.snd and not self.fst:\n            embedding = embedding_snd\n        else:\n            pass\n\n        return embedding\n\n    def save_embedding(self, dataset, file_name):\n        \"\"\"Write embedding to local file. 
Only used when node ids are numbers.\n\n        Parameters\n        ----------\n        dataset LineDataset : the dataset (not used by this method)\n        file_name str : the file name\n        \"\"\"\n        embedding = self.get_embedding()\n        np.save(file_name, embedding)\n\n    def save_embedding_pt(self, dataset, file_name):\n        \"\"\"Save the embedding as a torch tensor (for the ogb leaderboard).\n\n        Rows of nodes that are not in dataset.valid_nodes are left as zeros.\n        \"\"\"\n        embedding = torch.Tensor(self.get_embedding()).cpu()\n        embedding_empty = torch.zeros_like(embedding.data)\n        valid_nodes = torch.LongTensor(dataset.valid_nodes)\n        valid_embedding = embedding.data.index_select(0, valid_nodes)\n        embedding_empty.index_add_(0, valid_nodes, valid_embedding)\n\n        torch.save(embedding_empty, file_name)\n"
  },
  {
    "path": "examples/pytorch/ogb/line/reading_data.py",
    "content": "import os\nimport pickle\nimport random\nimport time\n\nimport dgl\n\nimport numpy as np\nimport scipy.sparse as sp\nimport torch\nfrom dgl.data.utils import (\n    _get_dgl_url,\n    download,\n    extract_archive,\n    get_download_dir,\n)\nfrom torch.utils.data import DataLoader\n\n\ndef ReadTxtNet(file_path=\"\", undirected=True):\n    \"\"\"Read the txt network file.\n    Notations: The network is unweighted.\n\n    Parameters\n    ----------\n    file_path str : path of network file\n    undirected bool : whether the edges are undirected\n\n    Return\n    ------\n    net dict : a dict recording the connections in the graph\n    node2id dict : a dict mapping the nodes to their embedding indices\n    id2node dict : a dict mapping nodes embedding indices to the nodes\n    \"\"\"\n    if file_path == \"youtube\" or file_path == \"blog\":\n        name = file_path\n        dir = get_download_dir()\n        zip_file_path = \"{}/{}.zip\".format(dir, name)\n        download(\n            _get_dgl_url(\n                os.path.join(\"dataset/DeepWalk/\", \"{}.zip\".format(file_path))\n            ),\n            path=zip_file_path,\n        )\n        extract_archive(zip_file_path, \"{}/{}\".format(dir, name))\n        file_path = \"{}/{}/{}-net.txt\".format(dir, name, name)\n\n    node2id = {}\n    id2node = {}\n    cid = 0\n\n    src = []\n    dst = []\n    weight = []\n    net = {}\n    with open(file_path, \"r\") as f:\n        for line in f.readlines():\n            tup = list(map(int, line.strip().split(\" \")))\n            assert len(tup) in [\n                2,\n                3,\n            ], \"The format of network file is unrecognizable.\"\n            if len(tup) == 3:\n                n1, n2, w = tup\n            elif len(tup) == 2:\n                n1, n2 = tup\n                w = 1\n            if n1 not in node2id:\n                node2id[n1] = cid\n                id2node[cid] = n1\n                cid += 1\n            if n2 
not in node2id:\n                node2id[n2] = cid\n                id2node[cid] = n2\n                cid += 1\n\n            n1 = node2id[n1]\n            n2 = node2id[n2]\n            if n1 not in net:\n                net[n1] = {n2: w}\n                src.append(n1)\n                dst.append(n2)\n                weight.append(w)\n            elif n2 not in net[n1]:\n                net[n1][n2] = w\n                src.append(n1)\n                dst.append(n2)\n                weight.append(w)\n\n            if undirected:\n                if n2 not in net:\n                    net[n2] = {n1: w}\n                    src.append(n2)\n                    dst.append(n1)\n                    weight.append(w)\n                elif n1 not in net[n2]:\n                    net[n2][n1] = w\n                    src.append(n2)\n                    dst.append(n1)\n                    weight.append(w)\n\n    print(\"node num: %d\" % len(net))\n    print(\"edge num: %d\" % len(src))\n    assert max(net.keys()) == len(net) - 1, \"error reading net, quit\"\n\n    sm = sp.coo_matrix((np.array(weight), (src, dst)), dtype=np.float32)\n\n    return net, node2id, id2node, sm\n\n\ndef net2graph(net_sm):\n    \"\"\"Transform the network to DGL graph\n\n    Return\n    ------\n    G DGLGraph : graph by DGL\n    \"\"\"\n    start = time.time()\n    G = dgl.DGLGraph(net_sm)\n    end = time.time()\n    t = end - start\n    print(\"Building DGLGraph in %.2fs\" % t)\n    return G\n\n\ndef make_undirected(G):\n    G.add_edges(G.edges()[1], G.edges()[0])\n    return G\n\n\ndef find_connected_nodes(G):\n    nodes = torch.nonzero(G.out_degrees(), as_tuple=False).squeeze(-1)\n    return nodes\n\n\nclass LineDataset:\n    def __init__(\n        self,\n        net_file,\n        batch_size,\n        num_samples,\n        negative=5,\n        gpus=[0],\n        fast_neg=True,\n        ogbl_name=\"\",\n        load_from_ogbl=False,\n        ogbn_name=\"\",\n        load_from_ogbn=False,\n    ):\n 
       \"\"\"This class has the following functions:\n        1. Transform the txt network file into DGL graph;\n        2. Generate random walk sequences for the trainer;\n        3. Provide the negative table if the user hopes to sample negative\n        nodes according to nodes' degrees;\n\n        Parameter\n        ---------\n        net_file str : path of the dgl network file\n        walk_length int : number of nodes in a sequence\n        window_size int : context window size\n        num_walks int : number of walks for each node\n        batch_size int : number of node sequences in each batch\n        negative int : negative samples for each positve node pair\n        fast_neg bool : whether do negative sampling inside a batch\n        \"\"\"\n        self.batch_size = batch_size\n        self.negative = negative\n        self.num_samples = num_samples\n        self.num_procs = len(gpus)\n        self.fast_neg = fast_neg\n\n        if load_from_ogbl:\n            assert (\n                len(gpus) == 1\n            ), \"ogb.linkproppred is not compatible with multi-gpu training.\"\n            from load_dataset import load_from_ogbl_with_name\n\n            self.G = load_from_ogbl_with_name(ogbl_name)\n        elif load_from_ogbn:\n            assert (\n                len(gpus) == 1\n            ), \"ogb.linkproppred is not compatible with multi-gpu training.\"\n            from load_dataset import load_from_ogbn_with_name\n\n            self.G = load_from_ogbn_with_name(ogbn_name)\n        else:\n            self.G = dgl.load_graphs(net_file)[0][0]\n        self.G = make_undirected(self.G)\n        print(\"Finish reading graph\")\n\n        self.num_nodes = self.G.num_nodes()\n\n        start = time.time()\n        seeds = np.random.choice(\n            np.arange(self.G.num_edges()), self.num_samples, replace=True\n        )  # edge index\n        self.seeds = torch.split(\n            torch.LongTensor(seeds),\n            int(np.ceil(self.num_samples / 
self.num_procs)),\n            0,\n        )\n        end = time.time()\n        t = end - start\n        print(\"generate %d samples in %.2fs\" % (len(seeds), t))\n\n        # negative table for true negative sampling\n        self.valid_nodes = find_connected_nodes(self.G)\n        if not fast_neg:\n            node_degree = self.G.out_degrees(self.valid_nodes).numpy()\n            node_degree = np.power(node_degree, 0.75)\n            node_degree /= np.sum(node_degree)\n            node_degree = np.array(node_degree * 1e8, dtype=int)\n            self.neg_table = []\n\n            for idx, node in enumerate(self.valid_nodes):\n                self.neg_table += [node] * node_degree[idx]\n            self.neg_table_size = len(self.neg_table)\n            self.neg_table = np.array(self.neg_table, dtype=int)\n            del node_degree\n\n    def create_sampler(self, i):\n        \"\"\"create the edge sampler for the i-th split of the sampled edge indices\"\"\"\n        return EdgeSampler(self.G, self.seeds[i])\n\n    def save_mapping(self, map_file):\n        with open(map_file, \"wb\") as f:\n            pickle.dump(self.node2id, f)\n\n\nclass EdgeSampler(object):\n    def __init__(self, G, seeds):\n        self.G = G\n        self.seeds = seeds\n        self.edges = torch.cat(\n            (self.G.edges()[0].unsqueeze(0), self.G.edges()[1].unsqueeze(0)), 0\n        ).t()\n\n    def sample(self, seeds):\n        \"\"\"Return the (src, dst) node pairs of the edges indexed by seeds.\n\n        seeds torch.LongTensor : a batch of indices of edges\n        \"\"\"\n        return self.edges[torch.LongTensor(seeds)]\n"
  },
  {
    "path": "examples/pytorch/ogb/line/utils.py",
    "content": "import torch\n\n\ndef check_args(args):\n    flag = sum([args.only_1st, args.only_2nd])\n    assert (\n        flag <= 1\n    ), \"no more than one selection from --only_1st and --only_2nd\"\n    if flag == 0:\n        assert args.dim % 2 == 0, \"embedding dimension must be an even number\"\n    if args.async_update:\n        assert args.mix, \"please use --async_update with --mix\"\n\n\ndef sum_up_params(model):\n    \"\"\"Count the model parameters\"\"\"\n    n = []\n    if model.fst:\n        p = model.fst_u_embeddings.weight.cpu().data.numel()\n        n.append(p)\n        p = model.fst_state_sum_u.cpu().data.numel()\n        n.append(p)\n    if model.snd:\n        p = model.snd_u_embeddings.weight.cpu().data.numel() * 2\n        n.append(p)\n        p = model.snd_state_sum_u.cpu().data.numel() * 2\n        n.append(p)\n    n.append(model.lookup_table.cpu().numel())\n    try:\n        n.append(model.index_emb_negu.cpu().numel() * 2)\n    except:\n        pass\n    print(\"#params \" + str(sum(n)))\n"
  },
  {
    "path": "examples/pytorch/ogb/ngnn/README.md",
    "content": "# NGNN + GraphSage/GCN\n\n## Introduction\n\nThis is an example of implementing [NGNN](https://arxiv.org/abs/2111.11638) for link prediction in DGL.\n\nWe use a model-agnostic methodology, namely Network In Graph Neural Network (NGNN), which allows arbitrary GNN models to increase their model capacity.\n\nThe script in this folder runs full-batch GCN/GraphSage experiments (with/without NGNN) on the datasets ogbl-ddi, ogbl-collab, and ogbl-ppa.\n\n## Installation requirements\n```\nogb>=1.3.3\ntorch>=1.11.0\ndgl>=0.8\n```\n\n## Experiments\n\nWe do not fix random seeds, and report the mean ± standard deviation over 10 runs for each of our models (50 runs for ogbl-ddi). All models are trained on a single V100 GPU with 16GB memory.\n\n### ogbl-ddi\n\n#### Performance\n\n<table>\n   <tr>\n      <th></th>\n      <th colspan=3 style=\"text-align: center;\">test set</th>\n      <th colspan=3 style=\"text-align: center;\">validation set</th>\n      <th>#parameters</th>\n   </tr>\n   <tr>\n      <td></td>\n      <td>Hits@20</td>\n      <td>Hits@50</td>\n      <td>Hits@100</td>\n      <td>Hits@20</td>\n      <td>Hits@50</td>\n      <td>Hits@100</td>\n      <td></td>\n   </tr>\n   <tr>\n      <td>GCN+NGNN(paper)</td>\n      <td>48.22% ± 7.00%</td>\n      <td>82.56% ± 4.03%</td>\n      <td>89.48% ± 1.68%</td>\n      <td>65.95% ± 1.16%</td>\n      <td>70.24% ± 0.50%</td>\n      <td>72.54% ± 0.62%</td>\n      <td rowspan=2>1,487,361</td>\n   </tr>\n   <tr>\n      <td>GCN+NGNN(ours; 50runs)</td>\n      <td><b>54.83% ± 15.81%</b></td>\n      <td><b>93.15% ± 2.59%</b></td>\n      <td><b>97.05% ± 0.56%</b></td>\n      <td>71.21% ± 0.38%</td>\n      <td>73.55% ± 0.25%</td>\n      <td>76.24% ± 1.33%</td>\n   </tr>\n   <tr>\n      <td>GraphSage+NGNN(paper)</td>\n      <td>60.75% ± 4.94%</td>\n      <td>84.58% ± 1.89%</td>\n      <td>92.58% ± 0.88%</td>\n      <td>68.05% ± 0.68%</td>\n      <td>71.14% ± 0.33%</td>\n      <td>72.77% ± 0.09%</td>\n      <td rowspan=2>1,618,433</td>\n   </tr>\n   <tr>\n      <td>GraphSage+NGNN(ours; 
50runs)</td>\n      <td>57.70% ± 15.23%</td>\n      <td><b>96.18% ± 0.94%</b></td>\n      <td><b>98.58% ± 0.17%</b></td>\n      <td>73.23% ± 0.40%</td>\n      <td>87.20% ± 5.29%</td>\n      <td>98.71% ± 0.22%</td>\n   </tr>\n</table>\n\nA 3-layer MLP is used as LinkPredictor here, while a 2-layer one is used by the NGNN paper. This is the main reason for the better performance.\n\n#### Reproduction of performance\n\n- GCN + NGNN\n```{.bash}\npython main.py --dataset ogbl-ddi --device 0 --ngnn_type input --epochs 800 --dropout 0.5 --num_layers 2 --lr 0.0025 --batch_size 16384 --runs 50\n```\n\n- GraphSage + NGNN\n```{.bash}\npython main.py --dataset ogbl-ddi --device 1 --ngnn_type input --use_sage --epochs 600 --dropout 0.25 --num_layers 2 --lr 0.0012 --batch_size 32768 --runs 50\n```\n\n### ogbl-collab\n\n#### Performance\n\n<table>\n   <tr>\n      <th></th>\n      <th colspan=3 style=\"text-align: center;\">test set</th>\n      <th colspan=3 style=\"text-align: center;\">validation set</th>\n      <th>#parameters</th>\n   </tr>\n   <tr>\n      <td></td>\n      <td>Hits@10</td>\n      <td>Hits@50</td>\n      <td>Hits@100</td>\n      <td>Hits@10</td>\n      <td>Hits@50</td>\n      <td>Hits@100</td>\n      <td></td>\n   </tr>\n   <tr>\n      <td>GCN+NGNN(paper)</td>\n      <td>36.69% ± 0.82%</td>\n      <td>51.83% ± 0.50%</td>\n      <td>57.41% ± 0.22%</td>\n      <td>44.97% ± 0.97%</td>\n      <td>60.84% ± 0.63%</td>\n      <td>66.09% ± 0.30%</td>\n      <td rowspan=2>428,033</td>\n   </tr>\n   <tr>\n      <td>GCN+NGNN(ours)</td>\n      <td><b>39.29% ± 1.21%</b></td>\n      <td><b>53.48% ± 0.40%</b></td>\n      <td>58.34% ± 0.45%</td>\n      <td>48.28% ± 1.39%</td>\n      <td>62.73% ± 0.40%</td>\n      <td>67.13% ± 0.39%</td>\n   </tr>\n   <tr>\n      <td>GraphSage+NGNN(paper)</td>\n      <td>36.83% ± 2.56%</td>\n      <td>52.62% ± 1.04%</td>\n      <td>57.96% ± 0.56%</td>\n      <td>45.62% ± 2.56%</td>\n      <td>61.34% ± 1.05%</td>\n      <td>66.26% ± 0.44%</td>\n 
     <td rowspan=2>591,873</td>\n   </tr>\n   <tr>\n      <td>GraphSage+NGNN(ours)</td>\n      <td><b>40.30% ± 1.03%</b></td>\n      <td>53.59% ± 0.56%</td>\n      <td>58.75% ± 0.57%</td>\n      <td>49.85% ± 1.07%</td>\n      <td>62.81% ± 0.46%</td>\n      <td>67.33% ± 0.38%</td>\n   </tr>\n</table>\n\n#### Reproduction of performance\n\n- GCN + NGNN\n```{.bash}\npython main.py --dataset ogbl-collab --device 2 --ngnn_type hidden --epochs 600 --dropout 0.2 --num_layers 3 --lr 0.001 --batch_size 32768 --runs 10\n```\n\n- GraphSage + NGNN\n```{.bash}\npython main.py --dataset ogbl-collab --device 3 --ngnn_type input --use_sage --epochs 800 --dropout 0.2 --num_layers 3 --lr 0.0005 --batch_size 32768 --runs 10\n```\n\n### ogbl-ppa\n\n#### Performance\n\n<table>\n   <tr>\n      <th></th>\n      <th colspan=3 style=\"text-align: center;\">test set</th>\n      <th colspan=3 style=\"text-align: center;\">validation set</th>\n      <th>#parameters</th>\n   </tr>\n   <tr>\n      <td></td>\n      <td>Hits@10</td>\n      <td>Hits@50</td>\n      <td>Hits@100</td>\n      <td>Hits@10</td>\n      <td>Hits@50</td>\n      <td>Hits@100</td>\n      <td></td>\n   </tr>\n   <tr>\n      <td>GCN+NGNN(paper)</td>\n      <td>5.64% ± 0.93%</td>\n      <td>18.44% ± 1.88%</td>\n      <td>26.78% ± 0.9%</td>\n      <td>8.14% ± 0.71%</td>\n      <td>19.69% ± 0.94%</td>\n      <td>27.86% ± 0.81%</td>\n      <td rowspan=1>673,281</td>\n   </tr>\n   <tr>\n      <td>GCN+NGNN(ours)</td>\n      <td><b>13.07% ± 3.24%</b></td>\n      <td><b>28.55% ± 1.62%</b></td>\n      <td><b>36.83% ± 0.99%</b></td>\n      <td>16.36% ± 1.89%</td>\n      <td>30.56% ± 0.72%</td>\n      <td>38.34% ± 0.82%</td>\n      <td>410,113</td>\n   </tr>\n   <tr>\n      <td>GraphSage+NGNN(paper)</td>\n      <td>3.52% ± 1.24%</td>\n      <td>15.55% ± 1.92%</td>\n      <td>24.45% ± 2.34%</td>\n      <td>5.59% ± 0.93%</td>\n      <td>17.21% ± 0.69%</td>\n      <td>25.42% ± 0.50%</td>\n      <td rowspan=1>819,201</td>\n   </tr>\n   
<tr>\n      <td>GraphSage+NGNN(ours)</td>\n      <td><b>11.73% ± 2.42%</b></td>\n      <td><b>29.88% ± 1.84%</b></td>\n      <td><b>40.05% ± 1.38%</b></td>\n      <td>14.73% ± 2.36%</td>\n      <td>31.59% ± 1.72%</td>\n      <td>40.58% ± 1.23%</td>\n      <td>556,033</td>\n   </tr>\n</table>\n\nThe main difference between this implementation and the NGNN paper is the position of the NGNN layers (all -> input).\n\n#### Reproduction of performance\n\n- GCN + NGNN\n```{.bash}\npython main.py --dataset ogbl-ppa --device 4 --ngnn_type input --epochs 80 --dropout 0.2 --num_layers 3 --lr 0.001 --batch_size 49152 --runs 10\n```\n\n- GraphSage + NGNN\n```{.bash}\npython main.py --dataset ogbl-ppa --device 5 --ngnn_type input --use_sage --epochs 80 --dropout 0.2 --num_layers 3 --lr 0.001 --batch_size 49152 --runs 10\n```\n\n## References\n\n```{.tex}\n@article{DBLP:journals/corr/abs-2111-11638,\n  author    = {Xiang Song and\n               Runjie Ma and\n               Jiahang Li and\n               Muhan Zhang and\n               David Paul Wipf},\n  title     = {Network In Graph Neural Network},\n  journal   = {CoRR},\n  volume    = {abs/2111.11638},\n  year      = {2021},\n  url       = {https://arxiv.org/abs/2111.11638},\n  eprinttype = {arXiv},\n  eprint    = {2111.11638},\n  timestamp = {Fri, 26 Nov 2021 13:48:43 +0100},\n  biburl    = {https://dblp.org/rec/journals/corr/abs-2111-11638.bib},\n  bibsource = {dblp computer science bibliography, https://dblp.org}\n}\n```\n"
  },
  {
    "path": "examples/pytorch/ogb/ngnn/main.py",
    "content": "import argparse\nimport math\n\nimport dgl\n\nimport torch\nimport torch.nn.functional as F\nfrom dgl.dataloading.negative_sampler import GlobalUniform\nfrom dgl.nn.pytorch import GraphConv, SAGEConv\nfrom ogb.linkproppred import DglLinkPropPredDataset, Evaluator\nfrom torch.nn import Linear\nfrom torch.utils.data import DataLoader\n\n\nclass Logger(object):\n    def __init__(self, runs, info=None):\n        self.info = info\n        self.results = [[] for _ in range(runs)]\n\n    def add_result(self, run, result):\n        assert len(result) == 3\n        assert run >= 0 and run < len(self.results)\n        self.results[run].append(result)\n\n    def print_statistics(self, run=None):\n        if run is not None:\n            result = 100 * torch.tensor(self.results[run])\n            argmax = result[:, 1].argmax().item()\n            print(f\"Run {run + 1:02d}:\")\n            print(f\"Highest Train: {result[:, 0].max():.2f}\")\n            print(f\"Highest Valid: {result[:, 1].max():.2f}\")\n            print(f\"  Final Train: {result[argmax, 0]:.2f}\")\n            print(f\"   Final Test: {result[argmax, 2]:.2f}\")\n        else:\n            result = 100 * torch.tensor(self.results)\n\n            best_results = []\n            for r in result:\n                train1 = r[:, 0].max().item()\n                valid = r[:, 1].max().item()\n                train2 = r[r[:, 1].argmax(), 0].item()\n                test = r[r[:, 1].argmax(), 2].item()\n                best_results.append((train1, valid, train2, test))\n\n            best_result = torch.tensor(best_results)\n\n            print(f\"All runs:\")\n            r = best_result[:, 0]\n            print(f\"Highest Train: {r.mean():.2f} ± {r.std():.2f}\")\n            r = best_result[:, 1]\n            print(f\"Highest Valid: {r.mean():.2f} ± {r.std():.2f}\")\n            r = best_result[:, 2]\n            print(f\"  Final Train: {r.mean():.2f} ± {r.std():.2f}\")\n            r = best_result[:, 
3]\n            print(f\"   Final Test: {r.mean():.2f} ± {r.std():.2f}\")\n\n\nclass NGNN_GCNConv(torch.nn.Module):\n    def __init__(\n        self, in_channels, hidden_channels, out_channels, num_nonl_layers\n    ):\n        super(NGNN_GCNConv, self).__init__()\n        self.num_nonl_layers = (\n            num_nonl_layers  # number of nonlinear layers in each conv layer\n        )\n        self.conv = GraphConv(in_channels, hidden_channels)\n        self.fc = Linear(hidden_channels, hidden_channels)\n        self.fc2 = Linear(hidden_channels, out_channels)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.conv.reset_parameters()\n        gain = torch.nn.init.calculate_gain(\"relu\")\n        torch.nn.init.xavier_uniform_(self.fc.weight, gain=gain)\n        torch.nn.init.xavier_uniform_(self.fc2.weight, gain=gain)\n        for bias in [self.fc.bias, self.fc2.bias]:\n            stdv = 1.0 / math.sqrt(bias.size(0))\n            bias.data.uniform_(-stdv, stdv)\n\n    def forward(self, g, x):\n        x = self.conv(g, x)\n\n        if self.num_nonl_layers == 2:\n            x = F.relu(x)\n            x = self.fc(x)\n\n        x = F.relu(x)\n        x = self.fc2(x)\n        return x\n\n\nclass GCN(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        hidden_channels,\n        out_channels,\n        num_layers,\n        dropout,\n        ngnn_type,\n        dataset,\n    ):\n        super(GCN, self).__init__()\n\n        self.dataset = dataset\n        self.convs = torch.nn.ModuleList()\n\n        num_nonl_layers = (\n            1 if num_layers <= 2 else 2\n        )  # number of nonlinear layers in each conv layer\n        if ngnn_type == \"input\":\n            self.convs.append(\n                NGNN_GCNConv(\n                    in_channels,\n                    hidden_channels,\n                    hidden_channels,\n                    num_nonl_layers,\n                )\n            )\n            for 
_ in range(num_layers - 2):\n                self.convs.append(GraphConv(hidden_channels, hidden_channels))\n        elif ngnn_type == \"hidden\":\n            self.convs.append(GraphConv(in_channels, hidden_channels))\n            for _ in range(num_layers - 2):\n                self.convs.append(\n                    NGNN_GCNConv(\n                        hidden_channels,\n                        hidden_channels,\n                        hidden_channels,\n                        num_nonl_layers,\n                    )\n                )\n\n        self.convs.append(GraphConv(hidden_channels, out_channels))\n\n        self.dropout = dropout\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for conv in self.convs:\n            conv.reset_parameters()\n\n    def forward(self, g, x):\n        for conv in self.convs[:-1]:\n            x = conv(g, x)\n            x = F.relu(x)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n        x = self.convs[-1](g, x)\n        return x\n\n\nclass NGNN_SAGEConv(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        hidden_channels,\n        out_channels,\n        num_nonl_layers,\n        *,\n        reduce,\n    ):\n        super(NGNN_SAGEConv, self).__init__()\n        self.num_nonl_layers = (\n            num_nonl_layers  # number of nonlinear layers in each conv layer\n        )\n        self.conv = SAGEConv(in_channels, hidden_channels, reduce)\n        self.fc = Linear(hidden_channels, hidden_channels)\n        self.fc2 = Linear(hidden_channels, out_channels)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.conv.reset_parameters()\n        gain = torch.nn.init.calculate_gain(\"relu\")\n        torch.nn.init.xavier_uniform_(self.fc.weight, gain=gain)\n        torch.nn.init.xavier_uniform_(self.fc2.weight, gain=gain)\n        for bias in [self.fc.bias, self.fc2.bias]:\n            stdv = 1.0 / math.sqrt(bias.size(0))\n 
           bias.data.uniform_(-stdv, stdv)\n\n    def forward(self, g, x):\n        x = self.conv(g, x)\n\n        if self.num_nonl_layers == 2:\n            x = F.relu(x)\n            x = self.fc(x)\n\n        x = F.relu(x)\n        x = self.fc2(x)\n        return x\n\n\nclass SAGE(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        hidden_channels,\n        out_channels,\n        num_layers,\n        dropout,\n        ngnn_type,\n        dataset,\n        reduce=\"mean\",\n    ):\n        super(SAGE, self).__init__()\n\n        self.dataset = dataset\n        self.convs = torch.nn.ModuleList()\n\n        num_nonl_layers = (\n            1 if num_layers <= 2 else 2\n        )  # number of nonlinear layers in each conv layer\n        if ngnn_type == \"input\":\n            self.convs.append(\n                NGNN_SAGEConv(\n                    in_channels,\n                    hidden_channels,\n                    hidden_channels,\n                    num_nonl_layers,\n                    reduce=reduce,\n                )\n            )\n            for _ in range(num_layers - 2):\n                self.convs.append(\n                    SAGEConv(hidden_channels, hidden_channels, reduce)\n                )\n        elif ngnn_type == \"hidden\":\n            self.convs.append(SAGEConv(in_channels, hidden_channels, reduce))\n            for _ in range(num_layers - 2):\n                self.convs.append(\n                    NGNN_SAGEConv(\n                        hidden_channels,\n                        hidden_channels,\n                        hidden_channels,\n                        num_nonl_layers,\n                        reduce=reduce,\n                    )\n                )\n\n        self.convs.append(SAGEConv(hidden_channels, out_channels, reduce))\n\n        self.dropout = dropout\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for conv in self.convs:\n            conv.reset_parameters()\n\n    
def forward(self, g, x):\n        for conv in self.convs[:-1]:\n            x = conv(g, x)\n            x = F.relu(x)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n        x = self.convs[-1](g, x)\n        return x\n\n\nclass LinkPredictor(torch.nn.Module):\n    def __init__(\n        self, in_channels, hidden_channels, out_channels, num_layers, dropout\n    ):\n        super(LinkPredictor, self).__init__()\n\n        self.lins = torch.nn.ModuleList()\n        self.lins.append(Linear(in_channels, hidden_channels))\n        for _ in range(num_layers - 2):\n            self.lins.append(Linear(hidden_channels, hidden_channels))\n        self.lins.append(Linear(hidden_channels, out_channels))\n\n        self.dropout = dropout\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for lin in self.lins:\n            lin.reset_parameters()\n\n    def forward(self, x_i, x_j):\n        x = x_i * x_j\n        for lin in self.lins[:-1]:\n            x = lin(x)\n            x = F.relu(x)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n        x = self.lins[-1](x)\n        return torch.sigmoid(x)\n\n\ndef train(model, predictor, g, x, split_edge, optimizer, batch_size):\n    model.train()\n    predictor.train()\n\n    pos_train_edge = split_edge[\"train\"][\"edge\"].to(x.device)\n    neg_sampler = GlobalUniform(1)\n    total_loss = total_examples = 0\n    for perm in DataLoader(\n        range(pos_train_edge.size(0)), batch_size, shuffle=True\n    ):\n        optimizer.zero_grad()\n\n        h = model(g, x)\n\n        edge = pos_train_edge[perm].t()\n\n        pos_out = predictor(h[edge[0]], h[edge[1]])\n        pos_loss = -torch.log(pos_out + 1e-15).mean()\n\n        edge = neg_sampler(g, edge[0])\n\n        neg_out = predictor(h[edge[0]], h[edge[1]])\n        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()\n\n        loss = pos_loss + neg_loss\n        loss.backward()\n\n        if model.dataset == 
\"ogbl-ddi\":\n            torch.nn.utils.clip_grad_norm_(x, 1.0)\n        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n        torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)\n\n        optimizer.step()\n\n        num_examples = pos_out.size(0)\n        total_loss += loss.item() * num_examples\n        total_examples += num_examples\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(model, predictor, g, x, split_edge, evaluator, batch_size):\n    model.eval()\n    predictor.eval()\n\n    h = model(g, x)\n\n    pos_train_edge = split_edge[\"eval_train\"][\"edge\"].to(h.device)\n    pos_valid_edge = split_edge[\"valid\"][\"edge\"].to(h.device)\n    neg_valid_edge = split_edge[\"valid\"][\"edge_neg\"].to(h.device)\n    pos_test_edge = split_edge[\"test\"][\"edge\"].to(h.device)\n    neg_test_edge = split_edge[\"test\"][\"edge_neg\"].to(h.device)\n\n    def get_pred(test_edges, h):\n        preds = []\n        for perm in DataLoader(range(test_edges.size(0)), batch_size):\n            edge = test_edges[perm].t()\n            preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]\n        pred = torch.cat(preds, dim=0)\n        return pred\n\n    pos_train_pred = get_pred(pos_train_edge, h)\n    pos_valid_pred = get_pred(pos_valid_edge, h)\n    neg_valid_pred = get_pred(neg_valid_edge, h)\n    pos_test_pred = get_pred(pos_test_edge, h)\n    neg_test_pred = get_pred(neg_test_edge, h)\n\n    results = {}\n    for K in [20, 50, 100]:\n        evaluator.K = K\n        train_hits = evaluator.eval(\n            {\n                \"y_pred_pos\": pos_train_pred,\n                \"y_pred_neg\": neg_valid_pred,\n            }\n        )[f\"hits@{K}\"]\n        valid_hits = evaluator.eval(\n            {\n                \"y_pred_pos\": pos_valid_pred,\n                \"y_pred_neg\": neg_valid_pred,\n            }\n        )[f\"hits@{K}\"]\n        test_hits = evaluator.eval(\n            {\n                \"y_pred_pos\": 
pos_test_pred,\n                \"y_pred_neg\": neg_test_pred,\n            }\n        )[f\"hits@{K}\"]\n\n        results[f\"Hits@{K}\"] = (train_hits, valid_hits, test_hits)\n\n    return results\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"OGBL(Full Batch GCN/GraphSage + NGNN)\"\n    )\n\n    # dataset setting\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"ogbl-ddi\",\n        choices=[\"ogbl-ddi\", \"ogbl-collab\", \"ogbl-ppa\"],\n    )\n\n    # device setting\n    parser.add_argument(\n        \"--device\",\n        type=int,\n        default=0,\n        help=\"GPU device ID. Use -1 for CPU training.\",\n    )\n\n    # model structure settings\n    parser.add_argument(\n        \"--use_sage\",\n        action=\"store_true\",\n        help=\"If not set, use GCN by default.\",\n    )\n    parser.add_argument(\n        \"--ngnn_type\",\n        type=str,\n        default=\"input\",\n        choices=[\"input\", \"hidden\"],\n        help=\"You can set this value from 'input' or 'hidden' to apply NGNN to different GNN layers.\",\n    )\n    parser.add_argument(\n        \"--num_layers\", type=int, default=3, help=\"number of GNN layers\"\n    )\n    parser.add_argument(\"--hidden_channels\", type=int, default=256)\n    parser.add_argument(\"--dropout\", type=float, default=0.0)\n    parser.add_argument(\"--batch_size\", type=int, default=64 * 1024)\n    parser.add_argument(\"--lr\", type=float, default=0.001)\n    parser.add_argument(\"--epochs\", type=int, default=400)\n\n    # training settings\n    parser.add_argument(\"--eval_steps\", type=int, default=1)\n    parser.add_argument(\"--runs\", type=int, default=10)\n    args = parser.parse_args()\n    print(args)\n\n    device = (\n        f\"cuda:{args.device}\"\n        if args.device != -1 and torch.cuda.is_available()\n        else \"cpu\"\n    )\n    device = torch.device(device)\n\n    dataset = 
DglLinkPropPredDataset(name=args.dataset)\n    g = dataset[0]\n    split_edge = dataset.get_edge_split()\n\n    # We randomly pick some training samples that we want to evaluate on:\n    idx = torch.randperm(split_edge[\"train\"][\"edge\"].size(0))\n    idx = idx[: split_edge[\"valid\"][\"edge\"].size(0)]\n    split_edge[\"eval_train\"] = {\"edge\": split_edge[\"train\"][\"edge\"][idx]}\n\n    if dataset.name == \"ogbl-ppa\":\n        g.ndata[\"feat\"] = g.ndata[\"feat\"].to(torch.float)\n\n    if dataset.name == \"ogbl-ddi\":\n        emb = torch.nn.Embedding(g.num_nodes(), args.hidden_channels).to(device)\n        in_channels = args.hidden_channels\n    else:  # ogbl-collab, ogbl-ppa\n        in_channels = g.ndata[\"feat\"].size(-1)\n\n    # select model\n    if args.use_sage:\n        model = SAGE(\n            in_channels,\n            args.hidden_channels,\n            args.hidden_channels,\n            args.num_layers,\n            args.dropout,\n            args.ngnn_type,\n            dataset.name,\n        )\n    else:  # GCN\n        g = dgl.add_self_loop(g)\n        model = GCN(\n            in_channels,\n            args.hidden_channels,\n            args.hidden_channels,\n            args.num_layers,\n            args.dropout,\n            args.ngnn_type,\n            dataset.name,\n        )\n\n    predictor = LinkPredictor(\n        args.hidden_channels, args.hidden_channels, 1, 3, args.dropout\n    )\n\n    g, model, predictor = map(lambda x: x.to(device), (g, model, predictor))\n\n    evaluator = Evaluator(name=dataset.name)\n    loggers = {\n        \"Hits@20\": Logger(args.runs, args),\n        \"Hits@50\": Logger(args.runs, args),\n        \"Hits@100\": Logger(args.runs, args),\n    }\n\n    for run in range(args.runs):\n        model.reset_parameters()\n        predictor.reset_parameters()\n        if dataset.name == \"ogbl-ddi\":\n            torch.nn.init.xavier_uniform_(emb.weight)\n            g.ndata[\"feat\"] = emb.weight\n        
optimizer = torch.optim.Adam(\n            list(model.parameters())\n            + list(predictor.parameters())\n            + (list(emb.parameters()) if dataset.name == \"ogbl-ddi\" else []),\n            lr=args.lr,\n        )\n        for epoch in range(1, 1 + args.epochs):\n            loss = train(\n                model,\n                predictor,\n                g,\n                g.ndata[\"feat\"],\n                split_edge,\n                optimizer,\n                args.batch_size,\n            )\n\n            if epoch % args.eval_steps == 0:\n                results = test(\n                    model,\n                    predictor,\n                    g,\n                    g.ndata[\"feat\"],\n                    split_edge,\n                    evaluator,\n                    args.batch_size,\n                )\n                for key, result in results.items():\n                    loggers[key].add_result(run, result)\n                    train_hits, valid_hits, test_hits = result\n                    print(key)\n                    print(\n                        f\"Run: {run + 1:02d}, \"\n                        f\"Epoch: {epoch:02d}, \"\n                        f\"Loss: {loss:.4f}, \"\n                        f\"Train: {100 * train_hits:.2f}%, \"\n                        f\"Valid: {100 * valid_hits:.2f}%, \"\n                        f\"Test: {100 * test_hits:.2f}%\"\n                    )\n                print(\"---\")\n\n        for key in loggers.keys():\n            print(key)\n            loggers[key].print_statistics(run)\n\n    for key in loggers.keys():\n        print(key)\n        loggers[key].print_statistics()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/pytorch/ogb/ngnn_seal/README.md",
    "content": "# NGNN + SEAL\n\n## Introduction\n\nThis is an implementation of [NGNN](https://arxiv.org/abs/2111.11638) + [SEAL](https://arxiv.org/pdf/2010.16103.pdf), submitted to the OGB link prediction leaderboards. Some code is adapted from [https://github.com/facebookresearch/SEAL_OGB](https://github.com/facebookresearch/SEAL_OGB).\n\n## Installation Requirements\n```\nogb>=1.3.4\ntorch>=1.12.0\ndgl>=0.8\nscipy, numpy, tqdm...\n```\n\n## Experiments\n\nWe do not fix random seeds, and report results over 10 runs for all models. All models are trained on a machine with a single T4 GPU (16GB memory) and 96 vCPUs.\n\n### ogbl-ppa\n\n#### performance\n\n|              | Test Hits@100 | Validation Hits@100 | #Parameters |\n|:------------:|:-------------------:|:-----------------:|:------------:|\n| SEAL | 48.80% ± 3.16% | 51.25% ± 2.52% | 709,122 |\n| SEAL + NGNN | 59.71% ± 2.45% | 59.95% ± 2.05% | 735,426 |\n\n#### Reproduction of performance\n\n```{.bash}\npython main.py --dataset ogbl-ppa --ngnn_type input --hidden_channels 48 --epochs 50 --lr 0.00015 --batch_size 128 --num_workers 48  --train_percent 5 --val_percent 8 --eval_hits_K 10 --use_feature --dynamic_train --dynamic_val --dynamic_test --runs 10\n```\n\nBecause training is very costly, we select the best model by evaluating on a subset of the validation edges with a lower K for Hits@K. We then evaluate the selected model on the full validation and test sets to obtain the reported metrics.  
\n\n### ogbl-citation2\n\n#### performance\n\n|              | Test MRR | Validation MRR | #Parameters |\n|:------------:|:-------------------:|:-----------------:|:------------:|\n| SEAL | 0.8767 ± 0.0032 | 0.8757 ± 0.0031 | 260,802 |\n| SEAL + NGNN | 0.8891 ± 0.0022 | 0.8879 ± 0.0022 | 1,134,402 |\n\n#### Reproduction of performance\n\n```{.bash}\npython main.py --dataset ogbl-citation2 --ngnn_type all --hidden_channels 256 --epochs 15 --lr 2e-05 --batch_size 64 --num_workers 24  --train_percent 8 --val_percent 4 --num_ngnn_layers 2 --use_feature --use_edge_weight --dynamic_train --dynamic_val --dynamic_test --runs 10\n```\n\nFor all datasets, if you specify `--dynamic_train`, the enclosing subgraphs of the training links are extracted on the fly instead of being preprocessed and saved to disk. The same applies to `--dynamic_val` and `--dynamic_test`. You can increase `--num_workers` to accelerate the dynamic subgraph extraction process.  \nYou can also specify the `--val_percent` and `--eval_hits_K` arguments in the above command to adjust the proportion of the validation dataset to use and the K to use for Hits@K.\n\n## Reference\n\n    @article{DBLP:journals/corr/abs-2111-11638,\n      author    = {Xiang Song and\n                   Runjie Ma and\n                   Jiahang Li and\n                   Muhan Zhang and\n                   David Paul Wipf},\n      title     = {Network In Graph Neural Network},\n      journal   = {CoRR},\n      volume    = {abs/2111.11638},\n      year      = {2021},\n      url       = {https://arxiv.org/abs/2111.11638},\n      eprinttype = {arXiv},\n      eprint    = {2111.11638},\n      timestamp = {Fri, 26 Nov 2021 13:48:43 +0100},\n      biburl    = {https://dblp.org/rec/journals/corr/abs-2111-11638.bib},\n      bibsource = {dblp computer science bibliography, https://dblp.org}\n    }\n    \n    @article{zhang2021labeling,\n        title={Labeling Trick: A Theory of Using Graph Neural Networks for Multi-Node Representation Learning},\n   
     author={Zhang, Muhan and Li, Pan and Xia, Yinglong and Wang, Kai and Jin, Long},\n        journal={Advances in Neural Information Processing Systems},\n        volume={34},\n        year={2021}\n        }\n    \n    @inproceedings{zhang2018link,\n      title={Link prediction based on graph neural networks},\n      author={Zhang, Muhan and Chen, Yixin},\n      booktitle={Advances in Neural Information Processing Systems},\n      pages={5165--5175},\n      year={2018}\n    }"
  },
  {
    "path": "examples/pytorch/ogb/ngnn_seal/main.py",
    "content": "import argparse\nimport datetime\nimport os\nimport sys\nimport time\n\nimport dgl\nimport torch\nfrom dgl.data.utils import load_graphs, save_graphs\nfrom dgl.dataloading import GraphDataLoader\nfrom ogb.linkproppred import DglLinkPropPredDataset, Evaluator\nfrom torch.nn import BCEWithLogitsLoss\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\n\nfrom models import *\nfrom utils import *\n\n\nclass SEALOGBLDataset(Dataset):\n    def __init__(\n        self,\n        root,\n        graph,\n        split_edge,\n        percent=100,\n        split=\"train\",\n        ratio_per_hop=1.0,\n        directed=False,\n        dynamic=True,\n    ) -> None:\n        super().__init__()\n        self.root = root\n        self.graph = graph\n        self.split = split\n        self.split_edge = split_edge\n        self.percent = percent\n        self.ratio_per_hop = ratio_per_hop\n        self.directed = directed\n        self.dynamic = dynamic\n\n        if \"weights\" in self.graph.edata:\n            self.edge_weights = self.graph.edata[\"weights\"]\n        else:\n            self.edge_weights = None\n        if \"feat\" in self.graph.ndata:\n            self.node_features = self.graph.ndata[\"feat\"]\n        else:\n            self.node_features = None\n\n        pos_edge, neg_edge = get_pos_neg_edges(\n            self.split, self.split_edge, self.graph, self.percent\n        )\n        self.links = torch.cat([pos_edge, neg_edge], 0)  # [Np + Nn, 2]\n        self.labels = np.array([1] * len(pos_edge) + [0] * len(neg_edge))\n\n        if not self.dynamic:\n            self.g_list, tensor_dict = self.load_cached()\n            self.labels = tensor_dict[\"y\"]\n\n    def __len__(self):\n        return len(self.labels)\n\n    def __getitem__(self, idx):\n        if not self.dynamic:\n            g, y = self.g_list[idx], self.labels[idx]\n            x = None if \"x\" not in g.ndata else g.ndata[\"x\"]\n            w = None if \"w\" not in g.edata 
else g.edata[\"w\"]\n            return g, g.ndata[\"z\"], x, w, y\n\n        src, dst = self.links[idx][0].item(), self.links[idx][1].item()\n        y = self.labels[idx]\n        subg = k_hop_subgraph(\n            src, dst, 1, self.graph, self.ratio_per_hop, self.directed\n        )\n\n        # Remove the link between src and dst.\n        direct_links = [[], []]\n        for s, t in [(0, 1), (1, 0)]:\n            if subg.has_edges_between(s, t):\n                direct_links[0].append(s)\n                direct_links[1].append(t)\n        if len(direct_links[0]):\n            subg.remove_edges(subg.edge_ids(*direct_links))\n\n        NIDs, EIDs = subg.ndata[dgl.NID], subg.edata[dgl.EID]\n\n        z = drnl_node_labeling(subg.adj_external(scipy_fmt=\"csr\"), 0, 1)\n        edge_weights = (\n            self.edge_weights[EIDs] if self.edge_weights is not None else None\n        )\n        x = self.node_features[NIDs] if self.node_features is not None else None\n\n        subg_aug = subg.add_self_loop()\n        if edge_weights is not None:\n            edge_weights = torch.cat(\n                [\n                    edge_weights,\n                    torch.ones(subg_aug.num_edges() - subg.num_edges()),\n                ]\n            )\n        return subg_aug, z, x, edge_weights, y\n\n    @property\n    def cached_name(self):\n        return f\"SEAL_{self.split}_{self.percent}%.pt\"\n\n    def process(self):\n        g_list, labels = [], []\n        self.dynamic = True\n        for i in tqdm(range(len(self))):\n            g, z, x, weights, y = self[i]\n            g.ndata[\"z\"] = z\n            if x is not None:\n                g.ndata[\"x\"] = x\n            if weights is not None:\n                g.edata[\"w\"] = weights\n            g_list.append(g)\n            labels.append(y)\n        self.dynamic = False\n        return g_list, {\"y\": torch.tensor(labels)}\n\n    def load_cached(self):\n        path = os.path.join(self.root, self.cached_name)\n      
  if os.path.exists(path):\n            return load_graphs(path)\n\n        if not os.path.exists(self.root):\n            os.makedirs(self.root)\n\n        g_list, labels = self.process()\n        save_graphs(path, g_list, labels)\n        return g_list, labels\n\n\ndef ogbl_collate_fn(batch):\n    gs, zs, xs, ws, ys = zip(*batch)\n    batched_g = dgl.batch(gs)\n    z = torch.cat(zs, dim=0)\n    if xs[0] is not None:\n        x = torch.cat(xs, dim=0)\n    else:\n        x = None\n    if ws[0] is not None:\n        edge_weights = torch.cat(ws, dim=0)\n    else:\n        edge_weights = None\n    y = torch.tensor(ys)\n\n    return batched_g, z, x, edge_weights, y\n\n\ndef train():\n    model.train()\n    loss_fnt = BCEWithLogitsLoss()\n    total_loss = 0\n    pbar = tqdm(train_loader, ncols=70)\n    for batch in pbar:\n        g, z, x, edge_weights, y = [\n            item.to(device) if item is not None else None for item in batch\n        ]\n        optimizer.zero_grad()\n        logits = model(g, z, x, edge_weight=edge_weights)\n        loss = loss_fnt(logits.view(-1), y.to(torch.float))\n        loss.backward()\n        optimizer.step()\n        total_loss += loss.item() * g.batch_size\n\n    return total_loss / len(train_dataset)\n\n\n@torch.no_grad()\ndef test(dataloader, hits_K=[\"hits@100\"]):\n    model.eval()\n\n    if isinstance(hits_K, (int, str)):\n        hits_K = [hits_K]\n    y_pred, y_true = [], []\n    for batch in tqdm(dataloader, ncols=70):\n        g, z, x, edge_weights, y = [\n            item.to(device) if item is not None else None for item in batch\n        ]\n        logits = model(g, z, x, edge_weight=edge_weights)\n        y_pred.append(logits.view(-1).cpu())\n        y_true.append(y.view(-1).cpu().to(torch.float))\n    y_pred, y_true = torch.cat(y_pred), torch.cat(y_true)\n    pos_y_pred = y_pred[y_true == 1]\n    neg_y_pred = y_pred[y_true == 0]\n\n    if dataset.eval_metric.startswith(\"hits@\"):\n        results = 
evaluate_hits(pos_y_pred, neg_y_pred, hits_K)\n    elif dataset.eval_metric == \"mrr\":\n        results = evaluate_mrr(pos_y_pred, neg_y_pred)\n    elif dataset.eval_metric == \"rocauc\":\n        results = evaluate_rocauc(pos_y_pred, neg_y_pred)\n\n    return results\n\n\ndef evaluate_hits(y_pred_pos, y_pred_neg, hits_K):\n    results = {}\n    hits_K = map(\n        lambda x: (int(x.split(\"@\")[1]) if isinstance(x, str) else x), hits_K\n    )\n    for K in hits_K:\n        evaluator.K = K\n        hits = evaluator.eval(\n            {\n                \"y_pred_pos\": y_pred_pos,\n                \"y_pred_neg\": y_pred_neg,\n            }\n        )[f\"hits@{K}\"]\n\n        results[f\"hits@{K}\"] = hits\n\n    return results\n\n\ndef evaluate_mrr(y_pred_pos, y_pred_neg):\n    y_pred_neg = y_pred_neg.view(y_pred_pos.shape[0], -1)\n    results = {}\n    mrr = (\n        evaluator.eval(\n            {\n                \"y_pred_pos\": y_pred_pos,\n                \"y_pred_neg\": y_pred_neg,\n            }\n        )[\"mrr_list\"]\n        .mean()\n        .item()\n    )\n\n    results[\"mrr\"] = mrr\n\n    return results\n\n\ndef evaluate_rocauc(y_pred_pos, y_pred_neg):\n    results = {}\n    rocauc = evaluator.eval(\n        {\n            \"y_pred_pos\": y_pred_pos,\n            \"y_pred_neg\": y_pred_neg,\n        }\n    )[\"rocauc\"]\n\n    results[\"rocauc\"] = rocauc\n\n    return results\n\n\ndef print_log(*x, sep=\"\\n\", end=\"\\n\", mode=\"a\"):\n    print(*x, sep=sep, end=end)\n    with open(log_file, mode=mode) as f:\n        print(*x, sep=sep, end=end, file=f)\n\n\nif __name__ == \"__main__\":\n    # Data settings\n    parser = argparse.ArgumentParser(description=\"OGBL (SEAL)\")\n    parser.add_argument(\"--dataset\", type=str, default=\"ogbl-vessel\")\n    # GNN settings\n    parser.add_argument(\n        \"--max_z\",\n        type=int,\n        default=1000,\n        help=\"max number of labels as embeddings to look up\",\n    )\n    
parser.add_argument(\"--sortpool_k\", type=float, default=0.6)\n    parser.add_argument(\"--num_layers\", type=int, default=3)\n    parser.add_argument(\"--hidden_channels\", type=int, default=32)\n    parser.add_argument(\"--batch_size\", type=int, default=32)\n    parser.add_argument(\n        \"--ngnn_type\",\n        type=str,\n        default=\"none\",\n        choices=[\"none\", \"input\", \"hidden\", \"output\", \"all\"],\n        help=\"You can set this value to 'none', 'input', 'hidden', 'output' or 'all' \"\n        \"to apply NGNN to different GNN layers.\",\n    )\n    parser.add_argument(\n        \"--num_ngnn_layers\", type=int, default=1, choices=[1, 2]\n    )\n    # Subgraph extraction settings\n    parser.add_argument(\"--ratio_per_hop\", type=float, default=1.0)\n    parser.add_argument(\n        \"--use_feature\",\n        action=\"store_true\",\n        help=\"whether to use raw node features as GNN input\",\n    )\n    parser.add_argument(\n        \"--use_edge_weight\",\n        action=\"store_true\",\n        help=\"whether to consider edge weight in GNN\",\n    )\n    # Training settings\n    parser.add_argument(\n        \"--device\",\n        type=int,\n        default=0,\n        help=\"GPU device ID. 
Use -1 for CPU training.\",\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.001)\n    parser.add_argument(\"--epochs\", type=int, default=5)\n    parser.add_argument(\"--dropout\", type=float, default=0.0)\n    parser.add_argument(\"--runs\", type=int, default=10)\n    parser.add_argument(\"--train_percent\", type=float, default=1)\n    parser.add_argument(\"--val_percent\", type=float, default=1)\n    parser.add_argument(\"--final_val_percent\", type=float, default=100)\n    parser.add_argument(\"--test_percent\", type=float, default=100)\n    parser.add_argument(\"--no_test\", action=\"store_true\")\n    parser.add_argument(\n        \"--dynamic_train\",\n        action=\"store_true\",\n        help=\"dynamically extract enclosing subgraphs on the fly\",\n    )\n    parser.add_argument(\"--dynamic_val\", action=\"store_true\")\n    parser.add_argument(\"--dynamic_test\", action=\"store_true\")\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=24,\n        help=\"number of workers for dynamic dataloaders; \"\n        \"using a larger value for dynamic dataloading is recommended\",\n    )\n    # Testing settings\n    parser.add_argument(\n        \"--use_valedges_as_input\",\n        action=\"store_true\",\n        help=\"available for ogbl-collab\",\n    )\n    parser.add_argument(\"--eval_steps\", type=int, default=1)\n    parser.add_argument(\n        \"--eval_hits_K\",\n        type=int,\n        nargs=\"*\",\n        default=[10],\n        help=\"hits@K for each eval step; \"\n        \"only available for datasets with hits@xx as the eval metric\",\n    )\n    parser.add_argument(\n        \"--test_topk\",\n        type=int,\n        default=1,\n        help=\"select best k models for full validation/test each run.\",\n    )\n    args = parser.parse_args()\n\n    data_appendix = \"_rph{}\".format(\"\".join(str(args.ratio_per_hop).split(\".\")))\n    if args.use_valedges_as_input:\n        data_appendix += 
\"_uvai\"\n\n    args.res_dir = os.path.join(\n        f'results{\"_NoTest\" if args.no_test else \"\"}',\n        f'{args.dataset.split(\"-\")[1]}-{args.ngnn_type}+{time.strftime(\"%m%d%H%M%S\")}',\n    )\n    print(f\"Results will be saved in {args.res_dir}\")\n    if not os.path.exists(args.res_dir):\n        os.makedirs(args.res_dir)\n    log_file = os.path.join(args.res_dir, \"log.txt\")\n    # Save command line input.\n    cmd_input = \"python \" + \" \".join(sys.argv) + \"\\n\"\n    with open(os.path.join(args.res_dir, \"cmd_input.txt\"), \"a\") as f:\n        f.write(cmd_input)\n    print(f\"Command line input is saved.\")\n    print_log(f\"{cmd_input}\")\n\n    dataset = DglLinkPropPredDataset(name=args.dataset)\n    split_edge = dataset.get_edge_split()\n    graph = dataset[0]\n\n    # Re-format the data of ogbl-citation2.\n    if args.dataset == \"ogbl-citation2\":\n        for k in [\"train\", \"valid\", \"test\"]:\n            src = split_edge[k][\"source_node\"]\n            tgt = split_edge[k][\"target_node\"]\n            split_edge[k][\"edge\"] = torch.stack([src, tgt], dim=1)\n            if k != \"train\":\n                tgt_neg = split_edge[k][\"target_node_neg\"]\n                split_edge[k][\"edge_neg\"] = torch.stack(\n                    [src[:, None].repeat(1, tgt_neg.size(1)), tgt_neg], dim=-1\n                )  # [Ns, Nt, 2]\n\n    # Reconstruct the graph for ogbl-collab data\n    # for validation edge augmentation and coalesce.\n    if args.dataset == \"ogbl-collab\":\n        # Float edata for to_simple transformation.\n        graph.edata.pop(\"year\")\n        graph.edata[\"weight\"] = graph.edata[\"weight\"].to(torch.float)\n        if args.use_valedges_as_input:\n            val_edges = split_edge[\"valid\"][\"edge\"]\n            row, col = val_edges.t()\n            val_weights = torch.ones(size=(val_edges.size(0), 1))\n            graph.add_edges(\n                torch.cat([row, col]),\n                torch.cat([col, 
row]),\n                {\"weight\": val_weights},\n            )\n        graph = graph.to_simple(copy_edata=True, aggregator=\"sum\")\n\n    if args.dataset == \"ogbl-vessel\":\n        graph.ndata[\"feat\"][:, 0] = torch.nn.functional.normalize(\n            graph.ndata[\"feat\"][:, 0], dim=0\n        )\n        graph.ndata[\"feat\"][:, 1] = torch.nn.functional.normalize(\n            graph.ndata[\"feat\"][:, 1], dim=0\n        )\n        graph.ndata[\"feat\"][:, 2] = torch.nn.functional.normalize(\n            graph.ndata[\"feat\"][:, 2], dim=0\n        )\n        graph.ndata[\"feat\"] = graph.ndata[\"feat\"].to(torch.float)\n\n    if not args.use_edge_weight and \"weight\" in graph.edata:\n        del graph.edata[\"weight\"]\n    if not args.use_feature and \"feat\" in graph.ndata:\n        del graph.ndata[\"feat\"]\n\n    directed = args.dataset.startswith(\"ogbl-citation\")\n\n    evaluator = Evaluator(name=args.dataset)\n    if dataset.eval_metric.startswith(\"hits@\"):\n        loggers = {\n            f\"hits@{k}\": Logger(args.runs, args) for k in args.eval_hits_K\n        }\n    elif dataset.eval_metric == \"mrr\":\n        loggers = {\n            \"mrr\": Logger(args.runs, args),\n        }\n    elif dataset.eval_metric == \"rocauc\":\n        loggers = {\n            \"rocauc\": Logger(args.runs, args),\n        }\n\n    device = (\n        f\"cuda:{args.device}\"\n        if args.device != -1 and torch.cuda.is_available()\n        else \"cpu\"\n    )\n    device = torch.device(device)\n    path = f\"{dataset.root}_seal{data_appendix}\"\n\n    if not (args.dynamic_train or args.dynamic_val or args.dynamic_test):\n        args.num_workers = 0\n\n    train_dataset, val_dataset, final_val_dataset, test_dataset = [\n        SEALOGBLDataset(\n            path,\n            graph,\n            split_edge,\n            percent=percent,\n            split=split,\n            ratio_per_hop=args.ratio_per_hop,\n            directed=directed,\n            
dynamic=dynamic,\n        )\n        for percent, split, dynamic in zip(\n            [\n                args.train_percent,\n                args.val_percent,\n                args.final_val_percent,\n                args.test_percent,\n            ],\n            [\"train\", \"valid\", \"valid\", \"test\"],\n            [\n                args.dynamic_train,\n                args.dynamic_val,\n                args.dynamic_test,\n                args.dynamic_test,\n            ],\n        )\n    ]\n\n    train_loader = GraphDataLoader(\n        train_dataset,\n        batch_size=args.batch_size,\n        shuffle=True,\n        collate_fn=ogbl_collate_fn,\n        num_workers=args.num_workers,\n    )\n    val_loader = GraphDataLoader(\n        val_dataset,\n        batch_size=args.batch_size,\n        shuffle=False,\n        collate_fn=ogbl_collate_fn,\n        num_workers=args.num_workers,\n    )\n    final_val_loader = GraphDataLoader(\n        final_val_dataset,\n        batch_size=args.batch_size,\n        shuffle=False,\n        collate_fn=ogbl_collate_fn,\n        num_workers=args.num_workers,\n    )\n    test_loader = GraphDataLoader(\n        test_dataset,\n        batch_size=args.batch_size,\n        shuffle=False,\n        collate_fn=ogbl_collate_fn,\n        num_workers=args.num_workers,\n    )\n\n    if 0 < args.sortpool_k <= 1:  # Transform percentile to number.\n        if args.dataset.startswith(\"ogbl-citation\"):\n            # For this dataset, subgraphs extracted around positive edges are\n            # rather larger than negative edges. 
Thus we sample from 1000\n            # positive and 1000 negative edges to estimate the k (number of\n            # nodes to hold for each graph) used in SortPooling.\n            # You can certainly set k manually, instead of estimating from\n            # a percentage of sampled subgraphs.\n            _sampled_indices = list(range(1000)) + list(\n                range(len(train_dataset) - 1000, len(train_dataset))\n            )\n        else:\n            _sampled_indices = list(range(1000))\n        _num_nodes = sorted(\n            [train_dataset[i][0].num_nodes() for i in _sampled_indices]\n        )\n        _k = _num_nodes[int(math.ceil(args.sortpool_k * len(_num_nodes))) - 1]\n        model_k = max(10, _k)\n    else:\n        raise argparse.ArgumentTypeError(\"sortpool_k must be in range (0, 1].\")\n\n    print_log(f\"training starts: {datetime.datetime.now()}\")\n\n    for run in range(args.runs):\n        stime = datetime.datetime.now()\n        print_log(f\"\\n++++++\\n\\nstart run [{run+1}], {stime}\")\n\n        model = DGCNN(\n            args.hidden_channels,\n            args.num_layers,\n            args.max_z,\n            model_k,\n            feature_dim=graph.ndata[\"feat\"].size(1)\n            if (args.use_feature and \"feat\" in graph.ndata)\n            else 0,\n            dropout=args.dropout,\n            ngnn_type=args.ngnn_type,\n            num_ngnn_layers=args.num_ngnn_layers,\n        ).to(device)\n        parameters = list(model.parameters())\n        optimizer = torch.optim.Adam(params=parameters, lr=args.lr)\n        total_params = sum(p.numel() for param in parameters for p in param)\n        print_log(\n            f\"Total number of parameters is {total_params}\",\n            f\"SortPooling k is set to {model.k}\",\n        )\n\n        start_epoch = 1\n        # Training starts.\n        for epoch in range(start_epoch, start_epoch + args.epochs):\n            epo_stime = datetime.datetime.now()\n            loss = 
train()\n            epo_train_etime = datetime.datetime.now()\n            print_log(\n                f\"[epoch: {epoch}]\",\n                f\"   <Train> starts: {epo_stime}, \"\n                f\"ends: {epo_train_etime}, \"\n                f\"spent time:{epo_train_etime - epo_stime}\",\n            )\n            if epoch % args.eval_steps == 0:\n                epo_eval_stime = datetime.datetime.now()\n                results = test(val_loader, loggers.keys())\n                epo_eval_etime = datetime.datetime.now()\n                print_log(\n                    f\"   <Validation> starts: {epo_eval_stime}, \"\n                    f\"ends: {epo_eval_etime}, \"\n                    f\"spent time:{epo_eval_etime - epo_eval_stime}\"\n                )\n                for key, valid_res in results.items():\n                    loggers[key].add_result(run, valid_res)\n                    to_print = (\n                        f\"Run: {run + 1:02d}, \"\n                        f\"Epoch: {epoch:02d}, \"\n                        f\"Loss: {loss:.4f}, \"\n                        f\"Valid ({args.val_percent}%) [{key}]: {valid_res:.4f}\"\n                    )\n                    print_log(key, to_print)\n\n                model_name = os.path.join(\n                    args.res_dir, f\"run{run+1}_model_checkpoint{epoch}.pth\"\n                )\n                optimizer_name = os.path.join(\n                    args.res_dir, f\"run{run+1}_optimizer_checkpoint{epoch}.pth\"\n                )\n                torch.save(model.state_dict(), model_name)\n                torch.save(optimizer.state_dict(), optimizer_name)\n\n        print_log()\n        tested = dict()\n        for eval_metric in loggers.keys():\n            # Select models according to the eval_metric of the dataset.\n            res = torch.tensor(loggers[eval_metric].results[\"valid\"][run])\n            if args.no_test:\n                epoch = torch.argmax(res).item() + 1\n                val_res = 
loggers[eval_metric].results[\"valid\"][run][epoch - 1]\n                loggers[eval_metric].add_result(run, (epoch, val_res), \"test\")\n                print_log(\n                    f\"No Test; Best Valid:\",\n                    f\"   Run: {run + 1:02d}, \"\n                    f\"Epoch: {epoch:02d}, \"\n                    f\"Valid ({args.val_percent}%) [{eval_metric}]: {val_res:.4f}\",\n                )\n                continue\n\n            idx_to_test = (\n                torch.topk(res, args.test_topk, largest=True).indices + 1\n            ).tolist()  # indices of top k valid results\n            print_log(\n                f\"Eval Metric: {eval_metric}\",\n                f\"Run: {run + 1:02d}, \"\n                f\"Top {args.test_topk} Eval Points: {idx_to_test}\",\n            )\n            for _idx, epoch in enumerate(idx_to_test):\n                print_log(\n                    f\"Test Point[{_idx+1}]: \"\n                    f\"Epoch {epoch:02d}, \"\n                    f\"Test Metric: {dataset.eval_metric}\"\n                )\n                if epoch not in tested:\n                    model_name = os.path.join(\n                        args.res_dir, f\"run{run+1}_model_checkpoint{epoch}.pth\"\n                    )\n                    optimizer_name = os.path.join(\n                        args.res_dir,\n                        f\"run{run+1}_optimizer_checkpoint{epoch}.pth\",\n                    )\n                    model.load_state_dict(\n                        torch.load(model_name, weights_only=False)\n                    )\n                    optimizer.load_state_dict(\n                        torch.load(optimizer_name, weights_only=False)\n                    )\n                    tested[epoch] = (\n                        test(final_val_loader, dataset.eval_metric)[\n                            dataset.eval_metric\n                        ],\n                        test(test_loader, dataset.eval_metric)[\n                   
         dataset.eval_metric\n                        ],\n                    )\n\n                val_res, test_res = tested[epoch]\n                loggers[eval_metric].add_result(\n                    run, (epoch, val_res, test_res), \"test\"\n                )\n                print_log(\n                    f\"   Run: {run + 1:02d}, \"\n                    f\"Epoch: {epoch:02d}, \"\n                    f\"Valid ({args.val_percent}%) [{eval_metric}]: \"\n                    f\"{loggers[eval_metric].results['valid'][run][epoch-1]:.4f}, \"\n                    f\"Valid (final) [{dataset.eval_metric}]: {val_res:.4f}, \"\n                    f\"Test [{dataset.eval_metric}]: {test_res:.4f}\"\n                )\n\n        etime = datetime.datetime.now()\n        print_log(\n            f\"end run [{run}], {etime}\",\n            f\"spent time:{etime-stime}\",\n        )\n\n    for key in loggers.keys():\n        print(f\"\\n{key}\")\n        loggers[key].print_statistics()\n        with open(log_file, \"a\") as f:\n            print(f\"\\n{key}\", file=f)\n            loggers[key].print_statistics(f=f)\n    print(f\"Total number of parameters is {total_params}\")\n    print(f\"Results are saved in {args.res_dir}\")\n"
  },
  {
    "path": "examples/pytorch/ogb/ngnn_seal/models.py",
    "content": "import math\n\nimport torch\nimport torch.nn.functional as F\nfrom dgl.nn import GraphConv, SortPooling\nfrom torch.nn import Conv1d, Embedding, Linear, MaxPool1d, ModuleList\n\n\nclass NGNN_GCNConv(torch.nn.Module):\n    def __init__(\n        self, input_channels, hidden_channels, output_channels, num_layers\n    ):\n        super(NGNN_GCNConv, self).__init__()\n        self.conv = GraphConv(input_channels, hidden_channels)\n        self.fc = Linear(hidden_channels, hidden_channels)\n        self.fc2 = Linear(hidden_channels, output_channels)\n        self.num_layers = num_layers\n\n    def reset_parameters(self):\n        self.conv.reset_parameters()\n        gain = torch.nn.init.calculate_gain(\"relu\")\n        torch.nn.init.xavier_uniform_(self.fc.weight, gain=gain)\n        torch.nn.init.xavier_uniform_(self.fc2.weight, gain=gain)\n        for bias in [self.fc.bias, self.fc2.bias]:\n            stdv = 1.0 / math.sqrt(bias.size(0))\n            bias.data.uniform_(-stdv, stdv)\n\n    def forward(self, g, x, edge_weight=None):\n        x = self.conv(g, x, edge_weight)\n        if self.num_layers == 2:\n            x = F.relu(x)\n            x = self.fc(x)\n        x = F.relu(x)\n        x = self.fc2(x)\n        return x\n\n\n# An end-to-end deep learning architecture for graph classification, AAAI-18.\nclass DGCNN(torch.nn.Module):\n    def __init__(\n        self,\n        hidden_channels,\n        num_layers,\n        max_z,\n        k,\n        feature_dim=0,\n        GNN=GraphConv,\n        NGNN=NGNN_GCNConv,\n        dropout=0.0,\n        ngnn_type=\"all\",\n        num_ngnn_layers=1,\n    ):\n        super(DGCNN, self).__init__()\n\n        self.feature_dim = feature_dim\n        self.dropout = dropout\n\n        self.k = k\n        self.sort_pool = SortPooling(k=self.k)\n\n        self.max_z = max_z\n        self.z_embedding = Embedding(self.max_z, hidden_channels)\n\n        self.convs = ModuleList()\n        initial_channels = 
hidden_channels + self.feature_dim\n\n        self.num_ngnn_layers = num_ngnn_layers\n        if ngnn_type in [\"input\", \"all\"]:\n            self.convs.append(\n                NGNN(\n                    initial_channels,\n                    hidden_channels,\n                    hidden_channels,\n                    self.num_ngnn_layers,\n                )\n            )\n        else:\n            self.convs.append(GNN(initial_channels, hidden_channels))\n\n        if ngnn_type in [\"hidden\", \"all\"]:\n            for _ in range(0, num_layers - 1):\n                self.convs.append(\n                    NGNN(\n                        hidden_channels,\n                        hidden_channels,\n                        hidden_channels,\n                        self.num_ngnn_layers,\n                    )\n                )\n        else:\n            for _ in range(0, num_layers - 1):\n                self.convs.append(GNN(hidden_channels, hidden_channels))\n\n        if ngnn_type in [\"output\", \"all\"]:\n            self.convs.append(\n                NGNN(hidden_channels, hidden_channels, 1, self.num_ngnn_layers)\n            )\n        else:\n            self.convs.append(GNN(hidden_channels, 1))\n\n        conv1d_channels = [16, 32]\n        total_latent_dim = hidden_channels * num_layers + 1\n        conv1d_kws = [total_latent_dim, 5]\n        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])\n        self.maxpool1d = MaxPool1d(2, 2)\n        self.conv2 = Conv1d(\n            conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1\n        )\n        dense_dim = int((self.k - 2) / 2 + 1)\n        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]\n        self.lin1 = Linear(dense_dim, 128)\n        self.lin2 = Linear(128, 1)\n\n    def forward(self, g, z, x=None, edge_weight=None):\n        z_emb = self.z_embedding(z)\n        if z_emb.ndim == 3:  # in case z has multiple integer labels\n            z_emb = 
z_emb.sum(dim=1)\n        if x is not None:\n            x = torch.cat([z_emb, x.to(torch.float)], 1)\n        else:\n            x = z_emb\n        xs = [x]\n\n        for conv in self.convs:\n            xs += [\n                F.dropout(\n                    torch.tanh(conv(g, xs[-1], edge_weight=edge_weight)),\n                    p=self.dropout,\n                    training=self.training,\n                )\n            ]\n        x = torch.cat(xs[1:], dim=-1)\n\n        # global pooling\n        x = self.sort_pool(g, x)\n        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]\n        x = F.relu(self.conv1(x))\n        x = self.maxpool1d(x)\n        x = F.relu(self.conv2(x))\n        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]\n\n        # MLP.\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return x\n"
  },
  {
    "path": "examples/pytorch/ogb/ngnn_seal/utils.py",
    "content": "import random\nimport sys\n\nimport numpy as np\nimport torch\nfrom dgl.sampling import global_uniform_negative_sampling\nfrom scipy.sparse.csgraph import shortest_path\n\n\ndef k_hop_subgraph(src, dst, num_hops, g, sample_ratio=1.0, directed=False):\n    # Extract the k-hop enclosing subgraph around link (src, dst) from g\n    nodes = [src, dst]\n    visited = set([src, dst])\n    fringe = set([src, dst])\n    for _ in range(num_hops):\n        if not directed:\n            _, fringe = g.out_edges(list(fringe))\n            fringe = fringe.tolist()\n        else:\n            _, out_neighbors = g.out_edges(list(fringe))\n            in_neighbors, _ = g.in_edges(list(fringe))\n            fringe = in_neighbors.tolist() + out_neighbors.tolist()\n        fringe = set(fringe) - visited\n        visited = visited.union(fringe)\n\n        if sample_ratio < 1.0:\n            fringe = random.sample(fringe, int(sample_ratio * len(fringe)))\n        if len(fringe) == 0:\n            break\n\n        nodes = nodes + list(fringe)\n\n    subg = g.subgraph(nodes, store_ids=True)\n\n    return subg\n\n\ndef drnl_node_labeling(adj, src, dst):\n    # Double Radius Node Labeling (DRNL).\n    src, dst = (dst, src) if src > dst else (src, dst)\n\n    idx = list(range(src)) + list(range(src + 1, adj.shape[0]))\n    adj_wo_src = adj[idx, :][:, idx]\n\n    idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))\n    adj_wo_dst = adj[idx, :][:, idx]\n\n    dist2src = shortest_path(\n        adj_wo_dst, directed=False, unweighted=True, indices=src\n    )\n    dist2src = np.insert(dist2src, dst, 0, axis=0)\n    dist2src = torch.from_numpy(dist2src)\n\n    dist2dst = shortest_path(\n        adj_wo_src, directed=False, unweighted=True, indices=dst - 1\n    )\n    dist2dst = np.insert(dist2dst, src, 0, axis=0)\n    dist2dst = torch.from_numpy(dist2dst)\n\n    dist = dist2src + dist2dst\n    dist_over_2, dist_mod_2 = (\n        torch.div(dist, 2, rounding_mode=\"floor\"),\n 
       dist % 2,\n    )\n\n    z = 1 + torch.min(dist2src, dist2dst)\n    z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)\n    z[src] = 1.0\n    z[dst] = 1.0\n    # shortest path may include inf values\n    z[torch.isnan(z)] = 0.0\n\n    return z.to(torch.long)\n\n\ndef get_pos_neg_edges(split, split_edge, g, percent=100):\n    pos_edge = split_edge[split][\"edge\"]\n    if split == \"train\":\n        neg_edge = torch.stack(\n            global_uniform_negative_sampling(\n                g, num_samples=pos_edge.size(0), exclude_self_loops=True\n            ),\n            dim=1,\n        )\n    else:\n        neg_edge = split_edge[split][\"edge_neg\"]\n\n    # sampling according to the percent param\n    np.random.seed(123)\n    # pos sampling\n    num_pos = pos_edge.size(0)\n    perm = np.random.permutation(num_pos)\n    perm = perm[: int(percent / 100 * num_pos)]\n    pos_edge = pos_edge[perm]\n    # neg sampling\n    if neg_edge.dim() > 2:  # [Np, Nn, 2]\n        neg_edge = neg_edge[perm].view(-1, 2)\n    else:\n        np.random.seed(123)\n        num_neg = neg_edge.size(0)\n        perm = np.random.permutation(num_neg)\n        perm = perm[: int(percent / 100 * num_neg)]\n        neg_edge = neg_edge[perm]\n\n    return pos_edge, neg_edge  # ([2, Np], [2, Nn]) -> ([Np, 2], [Nn, 2])\n\n\nclass Logger(object):\n    def __init__(self, runs, info=None):\n        self.info = info\n        self.results = {\n            \"valid\": [[] for _ in range(runs)],\n            \"test\": [[] for _ in range(runs)],\n        }\n\n    def add_result(self, run, result, split=\"valid\"):\n        assert run >= 0 and run < len(self.results[\"valid\"])\n        assert split in [\"valid\", \"test\"]\n        self.results[split][run].append(result)\n\n    def print_statistics(self, run=None, f=sys.stdout):\n        if run is not None:\n            result = torch.tensor(self.results[\"valid\"][run])\n            print(f\"Run {run + 1:02d}:\", file=f)\n            print(f\"Highest 
Valid: {result.max():.4f}\", file=f)\n            print(f\"Highest Eval Point: {result.argmax().item()+1}\", file=f)\n            if not self.info.no_test:\n                print(\n                    f'   Final Test Point[1]: {self.results[\"test\"][run][0][0]}',\n                    f'   Final Valid: {self.results[\"test\"][run][0][1]}',\n                    f'   Final Test: {self.results[\"test\"][run][0][2]}',\n                    sep=\"\\n\",\n                    file=f,\n                )\n        else:\n            best_result = torch.tensor(\n                [test_res[0] for test_res in self.results[\"test\"]]\n            )\n\n            print(f\"All runs:\", file=f)\n            r = best_result[:, 1]\n            print(f\"Highest Valid: {r.mean():.4f} ± {r.std():.4f}\", file=f)\n            if not self.info.no_test:\n                r = best_result[:, 2]\n                print(f\"   Final Test: {r.mean():.4f} ± {r.std():.4f}\", file=f)\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-arxiv/README.md",
    "content": "# DGL examples for ogbn-arxiv\n\nDGL implementation of GCN and GAT for [ogbn-arxiv](https://ogb.stanford.edu/docs/nodeprop/). Using some of the techniques from *Bag of Tricks for Node Classification with Graph Neural Networks* ([https://arxiv.org/abs/2103.13355](https://arxiv.org/abs/2103.13355)).\n\nRequires DGL 0.5 or later versions.\n\n### GCN\n\nFor the best score, run `gcn.py` with `--use-linear` and `--use-labels` enabled and you should directly see the result.\n\n```bash\npython3 gcn.py --use-linear --use-labels\n```\n\n### GAT\n\nFor the score of `GAT(norm. adj.)+labels`, run the following command and you should directly see the result.\n\n```bash\npython3 gat.py --use-norm --use-labels --no-attn-dst --edge-drop=0.1 --input-drop=0.1\n```\n\nFor the score of `GAT(norm. adj.)+label reuse`, run the following command and you should directly see the result.\n\n```bash\npython3 gat.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25\n```\n\nFor the score of `GAT(norm. adj.)+label reuse+C&S`, run the following command and you should directly see the result.\n\n```bash\npython3 gat.py --use-norm --use-labels --n-label-iters=1 --no-attn-dst --edge-drop=0.3 --input-drop=0.25 --save-pred\npython3 correct_and_smooth.py --use-norm\n```\n\n## Usage & Options\n\n### GCN\n\n```\nusage: GCN on OGBN-Arxiv [-h] [--cpu] [--gpu GPU] [--n-runs N_RUNS] [--n-epochs N_EPOCHS] [--use-labels] [--use-linear] [--lr LR] [--n-layers N_LAYERS] [--n-hidden N_HIDDEN]\n                         [--dropout DROPOUT] [--wd WD] [--log-every LOG_EVERY] [--plot-curves]\n\noptional arguments:\n  -h, --help            show this help message and exit\n  --cpu                 CPU mode. This option overrides --gpu. (default: False)\n  --gpu GPU             GPU device ID. 
(default: 0)\n  --n-runs N_RUNS       running times (default: 10)\n  --n-epochs N_EPOCHS   number of epochs (default: 1000)\n  --use-labels          Use labels in the training set as input features. (default: False)\n  --use-linear          Use linear layer. (default: False)\n  --lr LR               learning rate (default: 0.005)\n  --n-layers N_LAYERS   number of layers (default: 3)\n  --n-hidden N_HIDDEN   number of hidden units (default: 256)\n  --dropout DROPOUT     dropout rate (default: 0.75)\n  --wd WD               weight decay (default: 0)\n  --log-every LOG_EVERY\n                        log every LOG_EVERY epochs (default: 20)\n  --plot-curves         plot learning curves (default: False)\n```\n\n### GAT\n\n```\nusage: GAT on OGBN-Arxiv [-h] [--cpu] [--gpu GPU] [--n-runs N_RUNS] [--n-epochs N_EPOCHS] [--use-labels] [--n-label-iters N_LABEL_ITERS] [--no-attn-dst]\n                         [--use-norm] [--lr LR] [--n-layers N_LAYERS] [--n-heads N_HEADS] [--n-hidden N_HIDDEN] [--dropout DROPOUT] [--input-drop INPUT_DROP]\n                         [--attn-drop ATTN_DROP] [--edge-drop EDGE_DROP] [--wd WD] [--log-every LOG_EVERY] [--plot-curves]\n\noptional arguments:\n  -h, --help            show this help message and exit\n  --cpu                 CPU mode. This option overrides --gpu. (default: False)\n  --gpu GPU             GPU device ID. (default: 0)\n  --n-runs N_RUNS       running times (default: 10)\n  --n-epochs N_EPOCHS   number of epochs (default: 2000)\n  --use-labels          Use labels in the training set as input features. (default: False)\n  --n-label-iters N_LABEL_ITERS\n                        number of label iterations (default: 0)\n  --no-attn-dst         Don't use attn_dst. (default: False)\n  --use-norm            Use symmetrically normalized adjacency matrix. 
(default: False)\n  --lr LR               learning rate (default: 0.002)\n  --n-layers N_LAYERS   number of layers (default: 3)\n  --n-heads N_HEADS     number of heads (default: 3)\n  --n-hidden N_HIDDEN   number of hidden units (default: 250)\n  --dropout DROPOUT     dropout rate (default: 0.75)\n  --input-drop INPUT_DROP\n                        input drop rate (default: 0.1)\n  --attn-drop ATTN_DROP\n                        attention dropout rate (default: 0.0)\n  --edge-drop EDGE_DROP\n                        edge drop rate (default: 0.0)\n  --wd WD               weight decay (default: 0)\n  --log-every LOG_EVERY\n                        log every LOG_EVERY epochs (default: 20)\n  --plot-curves         plot learning curves (default: False)\n```\n\n## Results\n\nHere are the results over at least 10 runs.\n\n|             Method              | Validation Accuracy |  Test Accuracy  | #Parameters |\n|:-------------------------------:|:-------------------:|:---------------:|:-----------:|\n|               GCN               |   0.7361 ± 0.0009   | 0.7246 ± 0.0021 |   109,608   |\n|           GCN+linear            |   0.7397 ± 0.0010   | 0.7270 ± 0.0016 |   218,152   |\n|           GCN+labels            |   0.7399 ± 0.0008   | 0.7259 ± 0.0006 |   119,848   |\n|        GCN+linear+labels        |   0.7442 ± 0.0012   | 0.7306 ± 0.0024 |   238,632   |\n|     GAT(norm. adj.)+labels      |   0.7508 ± 0.0009   | 0.7366 ± 0.0011 |  1,441,580  |\n|   GAT(norm. adj.)+label reuse   |   0.7516 ± 0.0008   | 0.7391 ± 0.0012 |  1,441,580  |\n| GAT(norm. adj.)+label reuse+C&S |   0.7519 ± 0.0008   | 0.7395 ± 0.0012 |  1,441,580  |\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-arxiv/correct_and_smooth.py",
    "content": "import argparse\nimport glob\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom dgl import function as fn\nfrom ogb.nodeproppred import DglNodePropPredDataset, Evaluator\n\ndevice = None\n\ndataset = \"ogbn-arxiv\"\nn_node_feats, n_classes = 0, 0\n\n\ndef load_data(dataset):\n    global n_node_feats, n_classes\n\n    data = DglNodePropPredDataset(name=dataset)\n    evaluator = Evaluator(name=dataset)\n\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n\n    n_node_feats = graph.ndata[\"feat\"].shape[1]\n    n_classes = (labels.max() + 1).item()\n\n    return graph, labels, train_idx, val_idx, test_idx, evaluator\n\n\ndef preprocess(graph):\n    global n_node_feats\n\n    # add reverse edges\n    srcs, dsts = graph.all_edges()\n    graph.add_edges(dsts, srcs)\n\n    # add self-loop\n    print(f\"Total edges before adding self-loop {graph.num_edges()}\")\n    graph = graph.remove_self_loop().add_self_loop()\n    print(f\"Total edges after adding self-loop {graph.num_edges()}\")\n\n    graph.create_formats_()\n\n    return graph\n\n\ndef general_outcome_correlation(\n    graph, y0, n_prop=50, alpha=0.8, use_norm=False, post_step=None\n):\n    with graph.local_scope():\n        y = y0\n        for _ in range(n_prop):\n            if use_norm:\n                degs = graph.in_degrees().float().clamp(min=1)\n                norm = torch.pow(degs, -0.5)\n                shp = norm.shape + (1,) * (y.dim() - 1)\n                norm = torch.reshape(norm, shp)\n                y = y * norm\n\n            graph.srcdata.update({\"y\": y})\n            graph.update_all(fn.copy_u(\"y\", \"m\"), fn.mean(\"m\", \"y\"))\n            y = graph.dstdata[\"y\"]\n\n            if use_norm:\n                degs = graph.in_degrees().float().clamp(min=1)\n                norm = 
torch.pow(degs, 0.5)\n                shp = norm.shape + (1,) * (y.dim() - 1)\n                norm = torch.reshape(norm, shp)\n                y = y * norm\n\n            y = alpha * y + (1 - alpha) * y0\n\n            if post_step is not None:\n                y = post_step(y)\n\n        return y\n\n\ndef evaluate(labels, pred, train_idx, val_idx, test_idx, evaluator):\n    return (\n        evaluator(pred[train_idx], labels[train_idx]),\n        evaluator(pred[val_idx], labels[val_idx]),\n        evaluator(pred[test_idx], labels[test_idx]),\n    )\n\n\ndef run(args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator):\n    evaluator_wrapper = lambda pred, labels: evaluator.eval(\n        {\"y_pred\": pred.argmax(dim=-1, keepdim=True), \"y_true\": labels}\n    )[\"acc\"]\n\n    y = pred.clone()\n    y[train_idx] = F.one_hot(labels[train_idx], n_classes).float().squeeze(1)\n    # dy = torch.zeros(graph.num_nodes(), n_classes, device=device)\n    # dy[train_idx] = F.one_hot(labels[train_idx], n_classes).float().squeeze(1) - pred[train_idx]\n\n    _train_acc, val_acc, test_acc = evaluate(\n        labels, y, train_idx, val_idx, test_idx, evaluator_wrapper\n    )\n\n    # print(\"train acc:\", _train_acc)\n    print(\"original val acc:\", val_acc)\n    print(\"original test acc:\", test_acc)\n\n    # NOTE: Only \"smooth\" is performed here.\n    # smoothed_dy = general_outcome_correlation(\n    #     graph, dy, alpha=args.alpha1, use_norm=args.use_norm, post_step=lambda x: x.clamp(-1, 1)\n    # )\n\n    # y[train_idx] = F.one_hot(labels[train_idx], n_classes).float().squeeze(1)\n    # smoothed_dy = smoothed_dy\n    # y = y + args.alpha2 * smoothed_dy  # .clamp(0, 1)\n\n    smoothed_y = general_outcome_correlation(\n        graph,\n        y,\n        alpha=args.alpha,\n        use_norm=args.use_norm,\n        post_step=lambda x: x.clamp(0, 1),\n    )\n\n    _train_acc, val_acc, test_acc = evaluate(\n        labels, smoothed_y, train_idx, val_idx, test_idx, 
evaluator_wrapper\n    )\n\n    # print(\"train acc:\", _train_acc)\n    print(\"val acc:\", val_acc)\n    print(\"test acc:\", test_acc)\n\n    return val_acc, test_acc\n\n\ndef main():\n    global device\n\n    argparser = argparse.ArgumentParser(description=\"implementation of C&S)\")\n    argparser.add_argument(\n        \"--cpu\",\n        action=\"store_true\",\n        help=\"CPU mode. This option overrides --gpu.\",\n    )\n    argparser.add_argument(\"--gpu\", type=int, default=0, help=\"GPU device ID.\")\n    argparser.add_argument(\n        \"--use-norm\",\n        action=\"store_true\",\n        help=\"Use symmetrically normalized adjacency matrix.\",\n    )\n    argparser.add_argument(\"--alpha\", type=float, default=0.6, help=\"alpha\")\n    argparser.add_argument(\n        \"--pred-files\",\n        type=str,\n        default=\"./output/*.pt\",\n        help=\"address of prediction files\",\n    )\n    args = argparser.parse_args()\n\n    if args.cpu:\n        device = torch.device(\"cpu\")\n    else:\n        device = torch.device(f\"cuda:{args.gpu}\")\n\n    # load data & preprocess\n    graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)\n    graph = preprocess(graph)\n\n    graph, labels, train_idx, val_idx, test_idx = map(\n        lambda x: x.to(device), (graph, labels, train_idx, val_idx, test_idx)\n    )\n\n    # run\n    val_accs, test_accs = [], []\n\n    for pred_file in glob.iglob(args.pred_files):\n        print(\"load:\", pred_file)\n        pred = torch.load(pred_file, weights_only=False)\n        val_acc, test_acc = run(\n            args, graph, labels, pred, train_idx, val_idx, test_idx, evaluator\n        )\n        val_accs.append(val_acc)\n        test_accs.append(test_acc)\n\n    print(args)\n    print(f\"Runned {len(val_accs)} times\")\n    print(\"Val Accs:\", val_accs)\n    print(\"Test Accs:\", test_accs)\n    print(f\"Average val accuracy: {np.mean(val_accs)} ± {np.std(val_accs)}\")\n    
print(f\"Average test accuracy: {np.mean(test_accs)} ± {np.std(test_accs)}\")\n\n\nif __name__ == \"__main__\":\n    main()\n\n# Namespace(alpha=0.6, cpu=False, gpu=0, pred_files='./output/*.pt', use_norm=True)\n# Runned 20 times\n# Val Accs: [0.7523742407463337, 0.750729890264774, 0.7524077989194268, 0.7527098224772644, 0.752508473438706, 0.7509983556495184, 0.751904426323031, 0.7514010537266351, 0.7524077989194268, 0.753716567670056, 0.7523071244001477, 0.7518373099768448, 0.7528440551696366, 0.7509983556495184, 0.7521057753615893, 0.7520386590154032, 0.7500251686298198, 0.7513674955535421, 0.7509312393033323, 0.7518037518037518]\n# Test Accs: [0.7392753533732486, 0.7381437359833755, 0.7412093903668497, 0.7402629467316832, 0.7386169578009588, 0.7380408616752052, 0.7397280003291978, 0.7401189227002448, 0.7424233072032591, 0.7397280003291978, 0.7378351130588647, 0.7400160483920746, 0.740921342303973, 0.7385758080776906, 0.7411682406435817, 0.7389667304487377, 0.7396457008826616, 0.7384935086311545, 0.7396251260210275, 0.7379997119519371]\n# Average val accuracy: 0.751870868149938 ± 0.0008415008835817228\n# Average test accuracy: 0.7395397403452462 ± 0.0012162384423867229\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-arxiv/gat.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nimport argparse\nimport math\nimport os\nimport random\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom matplotlib import pyplot as plt\nfrom matplotlib.ticker import AutoMinorLocator, MultipleLocator\nfrom models import GAT\nfrom ogb.nodeproppred import DglNodePropPredDataset, Evaluator\n\nepsilon = 1 - math.log(2)\n\ndevice = None\n\ndataset = \"ogbn-arxiv\"\nn_node_feats, n_classes = 0, 0\n\n\ndef seed(seed=0):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n    dgl.random.seed(seed)\n\n\ndef load_data(dataset):\n    global n_node_feats, n_classes\n\n    data = DglNodePropPredDataset(name=dataset)\n    evaluator = Evaluator(name=dataset)\n\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n\n    n_node_feats = graph.ndata[\"feat\"].shape[1]\n    n_classes = (labels.max() + 1).item()\n\n    return graph, labels, train_idx, val_idx, test_idx, evaluator\n\n\ndef preprocess(graph):\n    global n_node_feats\n\n    # make bidirected\n    feat = graph.ndata[\"feat\"]\n    graph = dgl.to_bidirected(graph)\n    graph.ndata[\"feat\"] = feat\n\n    # add self-loop\n    print(f\"Total edges before adding self-loop {graph.num_edges()}\")\n    graph = graph.remove_self_loop().add_self_loop()\n    print(f\"Total edges after adding self-loop {graph.num_edges()}\")\n\n    graph.create_formats_()\n\n    return graph\n\n\ndef gen_model(args):\n    if args.use_labels:\n        n_node_feats_ = n_node_feats + n_classes\n    else:\n        n_node_feats_ = n_node_feats\n\n    model = 
GAT(\n        n_node_feats_,\n        n_classes,\n        n_hidden=args.n_hidden,\n        n_layers=args.n_layers,\n        n_heads=args.n_heads,\n        activation=F.relu,\n        dropout=args.dropout,\n        input_drop=args.input_drop,\n        attn_drop=args.attn_drop,\n        edge_drop=args.edge_drop,\n        use_attn_dst=not args.no_attn_dst,\n        use_symmetric_norm=args.use_norm,\n    )\n\n    return model\n\n\ndef custom_loss_function(x, labels):\n    y = F.cross_entropy(x, labels[:, 0], reduction=\"none\")\n    y = torch.log(epsilon + y) - math.log(epsilon)\n    return torch.mean(y)\n\n\ndef add_labels(feat, labels, idx):\n    onehot = torch.zeros([feat.shape[0], n_classes], device=device)\n    onehot[idx, labels[idx, 0]] = 1\n    return torch.cat([feat, onehot], dim=-1)\n\n\ndef adjust_learning_rate(optimizer, lr, epoch):\n    if epoch <= 50:\n        for param_group in optimizer.param_groups:\n            param_group[\"lr\"] = lr * epoch / 50\n\n\ndef train(\n    args,\n    model,\n    graph,\n    labels,\n    train_idx,\n    val_idx,\n    test_idx,\n    optimizer,\n    evaluator,\n):\n    model.train()\n\n    feat = graph.ndata[\"feat\"]\n\n    if args.use_labels:\n        mask = torch.rand(train_idx.shape) < args.mask_rate\n\n        train_labels_idx = train_idx[mask]\n        train_pred_idx = train_idx[~mask]\n\n        feat = add_labels(feat, labels, train_labels_idx)\n    else:\n        mask = torch.rand(train_idx.shape) < args.mask_rate\n\n        train_pred_idx = train_idx[mask]\n\n    optimizer.zero_grad()\n    pred = model(graph, feat)\n\n    if args.n_label_iters > 0:\n        unlabel_idx = torch.cat([train_pred_idx, val_idx, test_idx])\n        for _ in range(args.n_label_iters):\n            pred = pred.detach()\n            torch.cuda.empty_cache()\n            feat[unlabel_idx, -n_classes:] = F.softmax(\n                pred[unlabel_idx], dim=-1\n            )\n            pred = model(graph, feat)\n\n    loss = 
custom_loss_function(pred[train_pred_idx], labels[train_pred_idx])\n    loss.backward()\n    optimizer.step()\n\n    return evaluator(pred[train_idx], labels[train_idx]), loss.item()\n\n\n@torch.no_grad()\ndef evaluate(\n    args, model, graph, labels, train_idx, val_idx, test_idx, evaluator\n):\n    model.eval()\n\n    feat = graph.ndata[\"feat\"]\n\n    if args.use_labels:\n        feat = add_labels(feat, labels, train_idx)\n\n    pred = model(graph, feat)\n\n    if args.n_label_iters > 0:\n        unlabel_idx = torch.cat([val_idx, test_idx])\n        for _ in range(args.n_label_iters):\n            feat[unlabel_idx, -n_classes:] = F.softmax(\n                pred[unlabel_idx], dim=-1\n            )\n            pred = model(graph, feat)\n\n    train_loss = custom_loss_function(pred[train_idx], labels[train_idx])\n    val_loss = custom_loss_function(pred[val_idx], labels[val_idx])\n    test_loss = custom_loss_function(pred[test_idx], labels[test_idx])\n\n    return (\n        evaluator(pred[train_idx], labels[train_idx]),\n        evaluator(pred[val_idx], labels[val_idx]),\n        evaluator(pred[test_idx], labels[test_idx]),\n        train_loss,\n        val_loss,\n        test_loss,\n        pred,\n    )\n\n\ndef run(\n    args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running\n):\n    evaluator_wrapper = lambda pred, labels: evaluator.eval(\n        {\"y_pred\": pred.argmax(dim=-1, keepdim=True), \"y_true\": labels}\n    )[\"acc\"]\n\n    # define model and optimizer\n    model = gen_model(args).to(device)\n    optimizer = optim.RMSprop(\n        model.parameters(), lr=args.lr, weight_decay=args.wd\n    )\n\n    # training loop\n    total_time = 0\n    best_val_acc, final_test_acc, best_val_loss = 0, 0, float(\"inf\")\n    final_pred = None\n\n    accs, train_accs, val_accs, test_accs = [], [], [], []\n    losses, train_losses, val_losses, test_losses = [], [], [], []\n\n    for epoch in range(1, args.n_epochs + 1):\n        tic = 
time.time()\n\n        adjust_learning_rate(optimizer, args.lr, epoch)\n\n        acc, loss = train(\n            args,\n            model,\n            graph,\n            labels,\n            train_idx,\n            val_idx,\n            test_idx,\n            optimizer,\n            evaluator_wrapper,\n        )\n\n        (\n            train_acc,\n            val_acc,\n            test_acc,\n            train_loss,\n            val_loss,\n            test_loss,\n            pred,\n        ) = evaluate(\n            args,\n            model,\n            graph,\n            labels,\n            train_idx,\n            val_idx,\n            test_idx,\n            evaluator_wrapper,\n        )\n\n        toc = time.time()\n        total_time += toc - tic\n\n        if val_loss < best_val_loss:\n            best_val_loss = val_loss\n            best_val_acc = val_acc\n            final_test_acc = test_acc\n            final_pred = pred\n\n        if epoch == args.n_epochs or epoch % args.log_every == 0:\n            print(\n                f\"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}\\n\"\n                f\"Loss: {loss:.4f}, Acc: {acc:.4f}\\n\"\n                f\"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\\n\"\n                f\"Train/Val/Test/Best val/Final test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{final_test_acc:.4f}\"\n            )\n\n        for l, e in zip(\n            [\n                accs,\n                train_accs,\n                val_accs,\n                test_accs,\n                losses,\n                train_losses,\n                val_losses,\n                test_losses,\n            ],\n            [\n                acc,\n                train_acc,\n                val_acc,\n                test_acc,\n                loss,\n                train_loss,\n                val_loss,\n                test_loss,\n     
       ],\n        ):\n            l.append(e)\n\n    print(\"*\" * 50)\n    print(f\"Best val acc: {best_val_acc}, Final test acc: {final_test_acc}\")\n    print(\"*\" * 50)\n\n    # plot learning curves\n    if args.plot_curves:\n        fig = plt.figure(figsize=(24, 24))\n        ax = fig.gca()\n        ax.set_xticks(np.arange(0, args.n_epochs, 100))\n        ax.set_yticks(np.linspace(0, 1.0, 101))\n        ax.tick_params(labeltop=True, labelright=True)\n        for y, label in zip(\n            [accs, train_accs, val_accs, test_accs],\n            [\"acc\", \"train acc\", \"val acc\", \"test acc\"],\n        ):\n            plt.plot(range(args.n_epochs), y, label=label, linewidth=1)\n        ax.xaxis.set_major_locator(MultipleLocator(100))\n        ax.xaxis.set_minor_locator(AutoMinorLocator(1))\n        ax.yaxis.set_major_locator(MultipleLocator(0.01))\n        ax.yaxis.set_minor_locator(AutoMinorLocator(2))\n        plt.grid(which=\"major\", color=\"red\", linestyle=\"dotted\")\n        plt.grid(which=\"minor\", color=\"orange\", linestyle=\"dotted\")\n        plt.legend()\n        plt.tight_layout()\n        plt.savefig(f\"gat_acc_{n_running}.png\")\n\n        fig = plt.figure(figsize=(24, 24))\n        ax = fig.gca()\n        ax.set_xticks(np.arange(0, args.n_epochs, 100))\n        ax.tick_params(labeltop=True, labelright=True)\n        for y, label in zip(\n            [losses, train_losses, val_losses, test_losses],\n            [\"loss\", \"train loss\", \"val loss\", \"test loss\"],\n        ):\n            plt.plot(range(args.n_epochs), y, label=label, linewidth=1)\n        ax.xaxis.set_major_locator(MultipleLocator(100))\n        ax.xaxis.set_minor_locator(AutoMinorLocator(1))\n        ax.yaxis.set_major_locator(MultipleLocator(0.1))\n        ax.yaxis.set_minor_locator(AutoMinorLocator(5))\n        plt.grid(which=\"major\", color=\"red\", linestyle=\"dotted\")\n        plt.grid(which=\"minor\", color=\"orange\", linestyle=\"dotted\")\n        
plt.legend()\n        plt.tight_layout()\n        plt.savefig(f\"gat_loss_{n_running}.png\")\n\n    if args.save_pred:\n        os.makedirs(\"./output\", exist_ok=True)\n        torch.save(F.softmax(final_pred, dim=1), f\"./output/{n_running}.pt\")\n\n    return best_val_acc, final_test_acc\n\n\ndef count_parameters(args):\n    model = gen_model(args)\n    return sum([p.numel() for p in model.parameters() if p.requires_grad])\n\n\ndef main():\n    global device, n_node_feats, n_classes, epsilon\n\n    argparser = argparse.ArgumentParser(\n        \"GAT implementation on ogbn-arxiv\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n    argparser.add_argument(\n        \"--cpu\",\n        action=\"store_true\",\n        help=\"CPU mode. This option overrides --gpu.\",\n    )\n    argparser.add_argument(\"--gpu\", type=int, default=0, help=\"GPU device ID.\")\n    argparser.add_argument(\"--seed\", type=int, default=0, help=\"seed\")\n    argparser.add_argument(\n        \"--n-runs\", type=int, default=10, help=\"running times\"\n    )\n    argparser.add_argument(\n        \"--n-epochs\", type=int, default=2000, help=\"number of epochs\"\n    )\n    argparser.add_argument(\n        \"--use-labels\",\n        action=\"store_true\",\n        help=\"Use labels in the training set as input features.\",\n    )\n    argparser.add_argument(\n        \"--n-label-iters\",\n        type=int,\n        default=0,\n        help=\"number of label iterations\",\n    )\n    argparser.add_argument(\n        \"--mask-rate\", type=float, default=0.5, help=\"mask rate\"\n    )\n    argparser.add_argument(\n        \"--no-attn-dst\", action=\"store_true\", help=\"Don't use attn_dst.\"\n    )\n    argparser.add_argument(\n        \"--use-norm\",\n        action=\"store_true\",\n        help=\"Use symmetrically normalized adjacency matrix.\",\n    )\n    argparser.add_argument(\n        \"--lr\", type=float, default=0.002, help=\"learning rate\"\n    )\n    
argparser.add_argument(\n        \"--n-layers\", type=int, default=3, help=\"number of layers\"\n    )\n    argparser.add_argument(\n        \"--n-heads\", type=int, default=3, help=\"number of heads\"\n    )\n    argparser.add_argument(\n        \"--n-hidden\", type=int, default=250, help=\"number of hidden units\"\n    )\n    argparser.add_argument(\n        \"--dropout\", type=float, default=0.75, help=\"dropout rate\"\n    )\n    argparser.add_argument(\n        \"--input-drop\", type=float, default=0.1, help=\"input drop rate\"\n    )\n    argparser.add_argument(\n        \"--attn-drop\", type=float, default=0.0, help=\"attention drop rate\"\n    )\n    argparser.add_argument(\n        \"--edge-drop\", type=float, default=0.0, help=\"edge drop rate\"\n    )\n    argparser.add_argument(\"--wd\", type=float, default=0, help=\"weight decay\")\n    argparser.add_argument(\n        \"--log-every\", type=int, default=20, help=\"log every LOG_EVERY epochs\"\n    )\n    argparser.add_argument(\n        \"--plot-curves\", action=\"store_true\", help=\"plot learning curves\"\n    )\n    argparser.add_argument(\n        \"--save-pred\", action=\"store_true\", help=\"save final predictions\"\n    )\n    args = argparser.parse_args()\n\n    if not args.use_labels and args.n_label_iters > 0:\n        raise ValueError(\n            \"'--use-labels' must be enabled when n_label_iters > 0\"\n        )\n\n    if args.cpu:\n        device = torch.device(\"cpu\")\n    else:\n        device = torch.device(f\"cuda:{args.gpu}\")\n\n    # load data & preprocess\n    graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)\n    graph = preprocess(graph)\n\n    graph, labels, train_idx, val_idx, test_idx = map(\n        lambda x: x.to(device), (graph, labels, train_idx, val_idx, test_idx)\n    )\n\n    # run\n    val_accs, test_accs = [], []\n\n    for i in range(args.n_runs):\n        seed(args.seed + i)\n        val_acc, test_acc = run(\n            args, graph, 
labels, train_idx, val_idx, test_idx, evaluator, i + 1\n        )\n        val_accs.append(val_acc)\n        test_accs.append(test_acc)\n\n    print(args)\n    print(f\"Runned {args.n_runs} times\")\n    print(\"Val Accs:\", val_accs)\n    print(\"Test Accs:\", test_accs)\n    print(f\"Average val accuracy: {np.mean(val_accs)} ± {np.std(val_accs)}\")\n    print(f\"Average test accuracy: {np.mean(test_accs)} ± {np.std(test_accs)}\")\n    print(f\"Number of params: {count_parameters(args)}\")\n\n\nif __name__ == \"__main__\":\n    main()\n\n\n# Namespace(attn_drop=0.0, cpu=False, dropout=0.75, edge_drop=0.1, gpu=0, input_drop=0.1, log_every=20, lr=0.002, n_epochs=2000, n_heads=3, n_hidden=250, n_label_iters=0, n_layers=3, n_runs=10, no_attn_dst=True, plot_curves=True, use_labels=True, use_norm=True, wd=0)\n# Runned 10 times\n# Val Accs: [0.7492868888217725, 0.7524413570925199, 0.7505620993993087, 0.7500251686298198, 0.7501929594952851, 0.7513003792073559, 0.7516695191113796, 0.7505285412262156, 0.7504949830531226, 0.7515017282459143]\n# Test Accs: [0.7366829208073575, 0.7384112091846182, 0.7368886694236981, 0.7345019854741477, 0.7373001666563792, 0.7362508487130424, 0.7352221056313396, 0.736477172191017, 0.7380614365368393, 0.7362919984363105]\n# Average val accuracy: 0.7508003624282694 ± 0.0008760483047616948\n# Average test accuracy: 0.736608851305475 ± 0.0011192876013651112\n# Number of params: 1441580\n\n# Namespace(attn_drop=0.0, cpu=False, dropout=0.75, edge_drop=0.3, gpu=0, input_drop=0.25, log_every=20, lr=0.002, n_epochs=2000, n_heads=3, n_hidden=250, n_label_iters=1, n_layers=3, n_runs=10, no_attn_dst=True, plot_curves=True, use_labels=True, use_norm=True, wd=0)\n# Runned 20 times\n# Val Accs: [0.7529782878620088, 0.7521393335346823, 0.7521728917077755, 0.7504949830531226, 0.7518037518037518, 0.7518373099768448, 0.7516359609382866, 0.7511325883418907, 0.7509312393033323, 0.7515017282459143, 0.7511325883418907, 0.7514346118997282, 0.7509312393033323, 
0.7521393335346823, 0.7528776133427296, 0.7522735662270545, 0.7504949830531226, 0.7522735662270545, 0.7511661465149837, 0.7501258431490989]\n# Test Accs: [0.7390901796185421, 0.7398720243606361, 0.7394605271279551, 0.7384523589078863, 0.7388638561405675, 0.7397280003291978, 0.7414151389831903, 0.7376499393041582, 0.7399748986688065, 0.7400366232537087, 0.7392547785116145, 0.7388844310022015, 0.7374853404110857, 0.7384317840462523, 0.7418677859391396, 0.737937987367035, 0.7381643108450096, 0.7399543238071724, 0.7377322387506944, 0.7385758080776906]\n# Average val accuracy: 0.7515738783180644 ± 0.0007617982474634186\n# Average test accuracy: 0.7391416167726272 ± 0.0011522198067958794\n# Number of params: 1441580\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-arxiv/gcn.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nimport argparse\nimport math\nimport time\n\nimport numpy as np\nimport torch as th\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom matplotlib import pyplot as plt\nfrom matplotlib.ticker import AutoMinorLocator, MultipleLocator\nfrom models import GCN\nfrom ogb.nodeproppred import DglNodePropPredDataset, Evaluator\n\ndevice = None\nin_feats, n_classes = None, None\nepsilon = 1 - math.log(2)\n\n\ndef gen_model(args):\n    if args.use_labels:\n        model = GCN(\n            in_feats + n_classes,\n            args.n_hidden,\n            n_classes,\n            args.n_layers,\n            F.relu,\n            args.dropout,\n            args.use_linear,\n        )\n    else:\n        model = GCN(\n            in_feats,\n            args.n_hidden,\n            n_classes,\n            args.n_layers,\n            F.relu,\n            args.dropout,\n            args.use_linear,\n        )\n    return model\n\n\ndef cross_entropy(x, labels):\n    y = F.cross_entropy(x, labels[:, 0], reduction=\"none\")\n    y = th.log(epsilon + y) - math.log(epsilon)\n    return th.mean(y)\n\n\ndef compute_acc(pred, labels, evaluator):\n    return evaluator.eval(\n        {\"y_pred\": pred.argmax(dim=-1, keepdim=True), \"y_true\": labels}\n    )[\"acc\"]\n\n\ndef add_labels(feat, labels, idx):\n    onehot = th.zeros([feat.shape[0], n_classes]).to(device)\n    onehot[idx, labels[idx, 0]] = 1\n    return th.cat([feat, onehot], dim=-1)\n\n\ndef adjust_learning_rate(optimizer, lr, epoch):\n    if epoch <= 50:\n        for param_group in optimizer.param_groups:\n            param_group[\"lr\"] = lr * epoch / 50\n\n\ndef train(model, graph, labels, train_idx, optimizer, use_labels):\n    model.train()\n\n    feat = graph.ndata[\"feat\"]\n\n    if use_labels:\n        mask_rate = 0.5\n        mask = th.rand(train_idx.shape) < mask_rate\n\n        train_labels_idx = train_idx[mask]\n        train_pred_idx = 
train_idx[~mask]\n\n        feat = add_labels(feat, labels, train_labels_idx)\n    else:\n        mask_rate = 0.5\n        mask = th.rand(train_idx.shape) < mask_rate\n\n        train_pred_idx = train_idx[mask]\n\n    optimizer.zero_grad()\n    pred = model(graph, feat)\n    loss = cross_entropy(pred[train_pred_idx], labels[train_pred_idx])\n    loss.backward()\n    optimizer.step()\n\n    return loss, pred\n\n\n@th.no_grad()\ndef evaluate(\n    model, graph, labels, train_idx, val_idx, test_idx, use_labels, evaluator\n):\n    model.eval()\n\n    feat = graph.ndata[\"feat\"]\n\n    if use_labels:\n        feat = add_labels(feat, labels, train_idx)\n\n    pred = model(graph, feat)\n    train_loss = cross_entropy(pred[train_idx], labels[train_idx])\n    val_loss = cross_entropy(pred[val_idx], labels[val_idx])\n    test_loss = cross_entropy(pred[test_idx], labels[test_idx])\n\n    return (\n        compute_acc(pred[train_idx], labels[train_idx], evaluator),\n        compute_acc(pred[val_idx], labels[val_idx], evaluator),\n        compute_acc(pred[test_idx], labels[test_idx], evaluator),\n        train_loss,\n        val_loss,\n        test_loss,\n    )\n\n\ndef run(\n    args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running\n):\n    # define model and optimizer\n    model = gen_model(args)\n    model = model.to(device)\n\n    optimizer = optim.AdamW(\n        model.parameters(), lr=args.lr, weight_decay=args.wd\n    )\n    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n        optimizer,\n        mode=\"min\",\n        factor=0.5,\n        patience=100,\n        min_lr=1e-3,\n    )\n\n    # training loop\n    total_time = 0\n    best_val_acc, final_test_acc, best_val_loss = 0, 0, float(\"inf\")\n\n    accs, train_accs, val_accs, test_accs = [], [], [], []\n    losses, train_losses, val_losses, test_losses = [], [], [], []\n\n    for epoch in range(1, args.n_epochs + 1):\n        tic = time.time()\n\n        adjust_learning_rate(optimizer, 
args.lr, epoch)\n\n        loss, pred = train(\n            model, graph, labels, train_idx, optimizer, args.use_labels\n        )\n        acc = compute_acc(pred[train_idx], labels[train_idx], evaluator)\n\n        (\n            train_acc,\n            val_acc,\n            test_acc,\n            train_loss,\n            val_loss,\n            test_loss,\n        ) = evaluate(\n            model,\n            graph,\n            labels,\n            train_idx,\n            val_idx,\n            test_idx,\n            args.use_labels,\n            evaluator,\n        )\n\n        lr_scheduler.step(loss)\n\n        toc = time.time()\n        total_time += toc - tic\n\n        if val_loss < best_val_loss:\n            best_val_loss = val_loss\n            best_val_acc = val_acc\n            final_test_acc = test_acc\n\n        if epoch % args.log_every == 0:\n            print(\n                f\"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}\\n\"\n                f\"Loss: {loss.item():.4f}, Acc: {acc:.4f}\\n\"\n                f\"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\\n\"\n                f\"Train/Val/Test/Best val/Final test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{final_test_acc:.4f}\"\n            )\n\n        for l, e in zip(\n            [\n                accs,\n                train_accs,\n                val_accs,\n                test_accs,\n                losses,\n                train_losses,\n                val_losses,\n                test_losses,\n            ],\n            [\n                acc,\n                train_acc,\n                val_acc,\n                test_acc,\n                loss,\n                train_loss,\n                val_loss,\n                test_loss,\n            ],\n        ):\n            l.append(e)\n\n    print(\"*\" * 50)\n    print(f\"Best val acc: {best_val_acc}, Final test acc: 
{final_test_acc}\")\n    print(\"*\" * 50)\n\n    if args.plot_curves:\n        fig = plt.figure(figsize=(24, 24))\n        ax = fig.gca()\n        ax.set_xticks(np.arange(0, args.n_epochs, 100))\n        ax.set_yticks(np.linspace(0, 1.0, 101))\n        ax.tick_params(labeltop=True, labelright=True)\n        for y, label in zip(\n            [accs, train_accs, val_accs, test_accs],\n            [\"acc\", \"train acc\", \"val acc\", \"test acc\"],\n        ):\n            plt.plot(range(args.n_epochs), y, label=label)\n        ax.xaxis.set_major_locator(MultipleLocator(100))\n        ax.xaxis.set_minor_locator(AutoMinorLocator(1))\n        ax.yaxis.set_major_locator(MultipleLocator(0.01))\n        ax.yaxis.set_minor_locator(AutoMinorLocator(2))\n        plt.grid(which=\"major\", color=\"red\", linestyle=\"dotted\")\n        plt.grid(which=\"minor\", color=\"orange\", linestyle=\"dotted\")\n        plt.legend()\n        plt.tight_layout()\n        plt.savefig(f\"gcn_acc_{n_running}.png\")\n\n        fig = plt.figure(figsize=(24, 24))\n        ax = fig.gca()\n        ax.set_xticks(np.arange(0, args.n_epochs, 100))\n        ax.tick_params(labeltop=True, labelright=True)\n        for y, label in zip(\n            [losses, train_losses, val_losses, test_losses],\n            [\"loss\", \"train loss\", \"val loss\", \"test loss\"],\n        ):\n            plt.plot(range(args.n_epochs), y, label=label)\n        ax.xaxis.set_major_locator(MultipleLocator(100))\n        ax.xaxis.set_minor_locator(AutoMinorLocator(1))\n        ax.yaxis.set_major_locator(MultipleLocator(0.1))\n        ax.yaxis.set_minor_locator(AutoMinorLocator(5))\n        plt.grid(which=\"major\", color=\"red\", linestyle=\"dotted\")\n        plt.grid(which=\"minor\", color=\"orange\", linestyle=\"dotted\")\n        plt.legend()\n        plt.tight_layout()\n        plt.savefig(f\"gcn_loss_{n_running}.png\")\n\n    return best_val_acc, final_test_acc\n\n\ndef count_parameters(args):\n    model = 
gen_model(args)\n    return sum(\n        [np.prod(p.size()) for p in model.parameters() if p.requires_grad]\n    )\n\n\ndef main():\n    global device, in_feats, n_classes\n\n    argparser = argparse.ArgumentParser(\n        \"GCN on OGBN-Arxiv\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n    argparser.add_argument(\n        \"--cpu\",\n        action=\"store_true\",\n        help=\"CPU mode. This option overrides --gpu.\",\n    )\n    argparser.add_argument(\"--gpu\", type=int, default=0, help=\"GPU device ID.\")\n    argparser.add_argument(\n        \"--n-runs\", type=int, default=10, help=\"running times\"\n    )\n    argparser.add_argument(\n        \"--n-epochs\", type=int, default=1000, help=\"number of epochs\"\n    )\n    argparser.add_argument(\n        \"--use-labels\",\n        action=\"store_true\",\n        help=\"Use labels in the training set as input features.\",\n    )\n    argparser.add_argument(\n        \"--use-linear\", action=\"store_true\", help=\"Use linear layer.\"\n    )\n    argparser.add_argument(\n        \"--lr\", type=float, default=0.005, help=\"learning rate\"\n    )\n    argparser.add_argument(\n        \"--n-layers\", type=int, default=3, help=\"number of layers\"\n    )\n    argparser.add_argument(\n        \"--n-hidden\", type=int, default=256, help=\"number of hidden units\"\n    )\n    argparser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout rate\"\n    )\n    argparser.add_argument(\"--wd\", type=float, default=0, help=\"weight decay\")\n    argparser.add_argument(\n        \"--log-every\", type=int, default=20, help=\"log every LOG_EVERY epochs\"\n    )\n    argparser.add_argument(\n        \"--plot-curves\", action=\"store_true\", help=\"plot learning curves\"\n    )\n    args = argparser.parse_args()\n\n    if args.cpu:\n        device = th.device(\"cpu\")\n    else:\n        device = th.device(\"cuda:%d\" % args.gpu)\n\n    # load data\n    data = 
DglNodePropPredDataset(name=\"ogbn-arxiv\")\n    evaluator = Evaluator(name=\"ogbn-arxiv\")\n\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n\n    # add reverse edges\n    srcs, dsts = graph.all_edges()\n    graph.add_edges(dsts, srcs)\n\n    # add self-loop\n    print(f\"Total edges before adding self-loop {graph.num_edges()}\")\n    graph = graph.remove_self_loop().add_self_loop()\n    print(f\"Total edges after adding self-loop {graph.num_edges()}\")\n\n    in_feats = graph.ndata[\"feat\"].shape[1]\n    n_classes = (labels.max() + 1).item()\n    graph.create_formats_()\n\n    train_idx = train_idx.to(device)\n    val_idx = val_idx.to(device)\n    test_idx = test_idx.to(device)\n    labels = labels.to(device)\n    graph = graph.to(device)\n\n    # run\n    val_accs = []\n    test_accs = []\n\n    for i in range(args.n_runs):\n        val_acc, test_acc = run(\n            args, graph, labels, train_idx, val_idx, test_idx, evaluator, i\n        )\n        val_accs.append(val_acc)\n        test_accs.append(test_acc)\n\n    print(f\"Runned {args.n_runs} times\")\n    print(\"Val Accs:\", val_accs)\n    print(\"Test Accs:\", test_accs)\n    print(f\"Average val accuracy: {np.mean(val_accs)} ± {np.std(val_accs)}\")\n    print(f\"Average test accuracy: {np.mean(test_accs)} ± {np.std(test_accs)}\")\n    print(f\"Number of params: {count_parameters(args)}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-arxiv/models.py",
    "content": "import dgl.nn.pytorch as dglnn\nimport torch\nimport torch.nn as nn\nfrom dgl import function as fn\nfrom dgl.ops import edge_softmax\nfrom dgl.utils import expand_as_pair\n\n\nclass ElementWiseLinear(nn.Module):\n    def __init__(self, size, weight=True, bias=True, inplace=False):\n        super().__init__()\n        if weight:\n            self.weight = nn.Parameter(torch.Tensor(size))\n        else:\n            self.weight = None\n        if bias:\n            self.bias = nn.Parameter(torch.Tensor(size))\n        else:\n            self.bias = None\n        self.inplace = inplace\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        if self.weight is not None:\n            nn.init.ones_(self.weight)\n        if self.bias is not None:\n            nn.init.zeros_(self.bias)\n\n    def forward(self, x):\n        if self.inplace:\n            if self.weight is not None:\n                x.mul_(self.weight)\n            if self.bias is not None:\n                x.add_(self.bias)\n        else:\n            if self.weight is not None:\n                x = x * self.weight\n            if self.bias is not None:\n                x = x + self.bias\n        return x\n\n\nclass GCN(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        n_hidden,\n        n_classes,\n        n_layers,\n        activation,\n        dropout,\n        use_linear,\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.use_linear = use_linear\n\n        self.convs = nn.ModuleList()\n        if use_linear:\n            self.linear = nn.ModuleList()\n        self.norms = nn.ModuleList()\n\n        for i in range(n_layers):\n            in_hidden = n_hidden if i > 0 else in_feats\n            out_hidden = n_hidden if i < n_layers - 1 else n_classes\n            bias = i == n_layers - 1\n\n            self.convs.append(\n                
dglnn.GraphConv(in_hidden, out_hidden, \"both\", bias=bias)\n            )\n            if use_linear:\n                self.linear.append(nn.Linear(in_hidden, out_hidden, bias=False))\n            if i < n_layers - 1:\n                self.norms.append(nn.BatchNorm1d(out_hidden))\n\n        self.input_drop = nn.Dropout(min(0.1, dropout))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, graph, feat):\n        h = feat\n        h = self.input_drop(h)\n\n        for i in range(self.n_layers):\n            conv = self.convs[i](graph, h)\n\n            if self.use_linear:\n                linear = self.linear[i](h)\n                h = conv + linear\n            else:\n                h = conv\n\n            if i < self.n_layers - 1:\n                h = self.norms[i](h)\n                h = self.activation(h)\n                h = self.dropout(h)\n\n        return h\n\n\nclass GATConv(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        num_heads=1,\n        feat_drop=0.0,\n        attn_drop=0.0,\n        edge_drop=0.0,\n        negative_slope=0.2,\n        use_attn_dst=True,\n        residual=False,\n        activation=None,\n        allow_zero_in_degree=False,\n        use_symmetric_norm=False,\n    ):\n        super(GATConv, self).__init__()\n        self._num_heads = num_heads\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        self._allow_zero_in_degree = allow_zero_in_degree\n        self._use_symmetric_norm = use_symmetric_norm\n        if isinstance(in_feats, tuple):\n            self.fc_src = nn.Linear(\n                self._in_src_feats, out_feats * num_heads, bias=False\n            )\n            self.fc_dst = nn.Linear(\n                self._in_dst_feats, out_feats * num_heads, bias=False\n            )\n        else:\n            self.fc = nn.Linear(\n                self._in_src_feats, 
out_feats * num_heads, bias=False\n            )\n        self.attn_l = nn.Parameter(\n            torch.FloatTensor(size=(1, num_heads, out_feats))\n        )\n        if use_attn_dst:\n            self.attn_r = nn.Parameter(\n                torch.FloatTensor(size=(1, num_heads, out_feats))\n            )\n        else:\n            self.register_buffer(\"attn_r\", None)\n        self.feat_drop = nn.Dropout(feat_drop)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.edge_drop = edge_drop\n        self.leaky_relu = nn.LeakyReLU(negative_slope)\n        if residual:\n            self.res_fc = nn.Linear(\n                self._in_dst_feats, num_heads * out_feats, bias=False\n            )\n        else:\n            self.register_buffer(\"res_fc\", None)\n        self.reset_parameters()\n        self._activation = activation\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        if hasattr(self, \"fc\"):\n            nn.init.xavier_normal_(self.fc.weight, gain=gain)\n        else:\n            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)\n            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)\n        nn.init.xavier_normal_(self.attn_l, gain=gain)\n        if isinstance(self.attn_r, nn.Parameter):\n            nn.init.xavier_normal_(self.attn_r, gain=gain)\n        if isinstance(self.res_fc, nn.Linear):\n            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)\n\n    def set_allow_zero_in_degree(self, set_value):\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat):\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    assert False\n\n            if isinstance(feat, tuple):\n                h_src = self.feat_drop(feat[0])\n                h_dst = self.feat_drop(feat[1])\n                if not hasattr(self, \"fc_src\"):\n                    self.fc_src, 
self.fc_dst = self.fc, self.fc\n                feat_src, feat_dst = h_src, h_dst\n                feat_src = self.fc_src(h_src).view(\n                    -1, self._num_heads, self._out_feats\n                )\n                feat_dst = self.fc_dst(h_dst).view(\n                    -1, self._num_heads, self._out_feats\n                )\n            else:\n                h_src = self.feat_drop(feat)\n                feat_src = h_src\n                feat_src = self.fc(h_src).view(\n                    -1, self._num_heads, self._out_feats\n                )\n                if graph.is_block:\n                    h_dst = h_src[: graph.number_of_dst_nodes()]\n                    feat_dst = feat_src[: graph.number_of_dst_nodes()]\n                else:\n                    h_dst = h_src\n                    feat_dst = feat_src\n\n            if self._use_symmetric_norm:\n                degs = graph.out_degrees().float().clamp(min=1)\n                norm = torch.pow(degs, -0.5)\n                shp = norm.shape + (1,) * (feat_src.dim() - 1)\n                norm = torch.reshape(norm, shp)\n                feat_src = feat_src * norm\n\n            # NOTE: GAT paper uses \"first concatenation then linear projection\"\n            # to compute attention scores, while ours is \"first projection then\n            # addition\", the two approaches are mathematically equivalent:\n            # We decompose the weight vector a mentioned in the paper into\n            # [a_l || a_r], then\n            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j\n            # Our implementation is much efficient because we do not need to\n            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. 
Plus,\n            # addition could be optimized with DGL's built-in function u_add_v,\n            # which further speeds up computation and saves memory footprint.\n            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)\n            graph.srcdata.update({\"ft\": feat_src, \"el\": el})\n            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.\n            if self.attn_r is not None:\n                er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)\n                graph.dstdata.update({\"er\": er})\n                graph.apply_edges(fn.u_add_v(\"el\", \"er\", \"e\"))\n            else:\n                graph.apply_edges(fn.copy_u(\"el\", \"e\"))\n            e = self.leaky_relu(graph.edata.pop(\"e\"))\n\n            if self.training and self.edge_drop > 0:\n                perm = torch.randperm(graph.num_edges(), device=e.device)\n                bound = int(graph.num_edges() * self.edge_drop)\n                eids = perm[bound:]\n                graph.edata[\"a\"] = torch.zeros_like(e)\n                graph.edata[\"a\"][eids] = self.attn_drop(\n                    edge_softmax(graph, e[eids], eids=eids)\n                )\n            else:\n                graph.edata[\"a\"] = self.attn_drop(edge_softmax(graph, e))\n\n            # message passing\n            graph.update_all(fn.u_mul_e(\"ft\", \"a\", \"m\"), fn.sum(\"m\", \"ft\"))\n            rst = graph.dstdata[\"ft\"]\n\n            if self._use_symmetric_norm:\n                degs = graph.in_degrees().float().clamp(min=1)\n                norm = torch.pow(degs, 0.5)\n                shp = norm.shape + (1,) * (feat_dst.dim() - 1)\n                norm = torch.reshape(norm, shp)\n                rst = rst * norm\n\n            # residual\n            if self.res_fc is not None:\n                resval = self.res_fc(h_dst).view(\n                    h_dst.shape[0], -1, self._out_feats\n                )\n                rst = rst + resval\n\n            # 
activation\n            if self._activation is not None:\n                rst = self._activation(rst)\n\n            return rst\n\n\nclass GAT(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        n_classes,\n        n_hidden,\n        n_layers,\n        n_heads,\n        activation,\n        dropout=0.0,\n        input_drop=0.0,\n        attn_drop=0.0,\n        edge_drop=0.0,\n        use_attn_dst=True,\n        use_symmetric_norm=False,\n    ):\n        super().__init__()\n        self.in_feats = in_feats\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.n_layers = n_layers\n        self.num_heads = n_heads\n\n        self.convs = nn.ModuleList()\n        self.norms = nn.ModuleList()\n\n        for i in range(n_layers):\n            in_hidden = n_heads * n_hidden if i > 0 else in_feats\n            out_hidden = n_hidden if i < n_layers - 1 else n_classes\n            num_heads = n_heads if i < n_layers - 1 else 1\n            out_channels = n_heads\n\n            self.convs.append(\n                GATConv(\n                    in_hidden,\n                    out_hidden,\n                    num_heads=num_heads,\n                    attn_drop=attn_drop,\n                    edge_drop=edge_drop,\n                    use_attn_dst=use_attn_dst,\n                    use_symmetric_norm=use_symmetric_norm,\n                    residual=True,\n                )\n            )\n\n            if i < n_layers - 1:\n                self.norms.append(nn.BatchNorm1d(out_channels * out_hidden))\n\n        self.bias_last = ElementWiseLinear(\n            n_classes, weight=False, bias=True, inplace=True\n        )\n\n        self.input_drop = nn.Dropout(input_drop)\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, graph, feat):\n        h = feat\n        h = self.input_drop(h)\n\n        for i in range(self.n_layers):\n            conv = self.convs[i](graph, h)\n\n      
      h = conv\n\n            if i < self.n_layers - 1:\n                h = h.flatten(1)\n                h = self.norms[i](h)\n                h = self.activation(h, inplace=True)\n                h = self.dropout(h)\n\n        h = h.mean(1)\n        h = self.bias_last(h)\n\n        return h\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-mag/README.md",
    "content": "## Running\nThe task can be run with default parameters as follows:  `python hetero_rgcn.py`\n\nThe following options can be specified via command line arguments:\n```\noptional arguments:\n  -h, --help            show this help message and exit\n  --runs RUNS\n```\n\n### Performance\nRunning the task with default parameters should yield performance similar to below:\n\n```\nFinal performance:\nAll runs:\nHighest Train: 84.67 ± 0.37\nHighest Valid: 48.75 ± 0.39\n  Final Train: 71.08 ± 7.09\n   Final Test: 47.81 ± 0.37\n```\n\nThis is a result of 10 experiments where each experiment is run for 3 epochs.  In the table above, \"Highest\" corresponds to the maximum value over the 3 epochs and \"Final\" corresponds to the value obtained when evaluating with the model parameters _as they were when the Validation accuracy was its maximum_.  For example, if the best Valid Accuracy was achieved at the end of epoch 2, then \"Final Train\" and \"Final Test\" are the Train and Test accuracies after epoch 2.  The values reported in the table are the average and standard deviations of these metrics from 10 runs.\n\nTypically, the best Validation performance is obtained after the 1st or 2nd epoch, after which it begins to overfit.  This is why \"Highest Train\" (typically occuring at the end of the 3rd epoch), is significantly higher than \"Final Train\" (corresponding to epoch of maximal Validation performance).\n\n## Background\nThe purpose of this example is to faithfully recreate the ogbn-mag NeighborSampling (R-GCN aggr) [PyG implementation](https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/sampler.py) using DGL's HeteroGraph API.  This effort is a result of a deep-dive in [#3511](https://github.com/dmlc/dgl/issues/3511), which uncovered a number of differences between a simple R-GCN minibatch DGL implementation (e.g. 
like [this one](https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify_mb.py)) and one specific to the OGB MAG dataset.\n\nSome examples of such differences:\n- Instead of reversing `(paper, cites, paper)` into a new relation like `(paper, rev-cites, paper)`, the PyG implementation instead just made these into undirected edges ([code](https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/sampler.py#L54))\n- In the PyG implementation there's a separate \"self\" linear projection matrix for each _node-type_ ([code](https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/sampler.py#L106)).  This is different from the R-GCN [paper](https://arxiv.org/abs/1703.06103), which has a single \"self\" linear projection matrix for each R-GCN layer, not a different one for each node-type.\n\n### Neighborhood sampling differences\nAlthough the model architectures, hyperparameter values and initialization methods are identical between the implementation here and the PyG one as of this writing, there is still a significant difference in the way neighbors are sampled, which results in the DGL implementation achieving significantly faster overfitting to the training dataset and slightly improved performance on the Test dataset.\n\nIn DGL, sampling on heterogeneous graphs with a `fanout = N` parameter means there are N samples _per incoming relation type_.  In the PyG implementation, the heterogeneous graph is represented as a homogeneous graph and there are N samples total, regardless of relation type.  This effectively means that given the same `fanout` value, there are R times as many neighbors sampled for DGL than PyG, where R is the number of edge-types that are directed inward to a node.  
Since there are significantly more nodes involved in the computation, there are likewise more nodes receiving gradient updates and therefore more significant overfitting given the same number of epochs.\n\nAn effort was made to mitigate this increase by reducing the fanout from `[25, 20]` to `[6, 5]`, which gives roughly the same number of neighbors between PyG and DGL and similar final training performance.  However, the DGL implementation has significantly worse Test performance in this case.  This is likely due to the fact that sampling e.g., 5 nodes from 4 different edge types is not the same as sampling 20 nodes by ignoring edge type unless the edge types are uniformly distributed.\n\n### Input features\nThe `paper` nodes have 128-dimensional features that are derived from word embeddings of the words found in the title and abstract of the papers.  Following the PyG implementation, all node types except `paper` receive 128-dimensional learnable embeddings as node features.  This results in 154,029,312 learnable parameters for just the node features.\n\n```\nParameterDict(\n    (author): Parameter containing: [torch.FloatTensor of size 1134649x128]\n    (field_of_study): Parameter containing: [torch.FloatTensor of size 59965x128]\n    (institution): Parameter containing: [torch.FloatTensor of size 8740x128]\n)\n```\n\n### Model architecture\nThe input features are passed to a modified version of the R-GCN architecture.  As in the R-GCN paper, each _edge-type_ has its own linear projection matrix (the \"weight\" ModuleDict below).  Different from the original paper, however, each _node-type_ has its own \"self\" linear projection matrix (the \"loop_weights\" ModuleDict below).  There are 7 edge-types:  4 natural edge-types (\"cites\", \"affiliated_with\", \"has_topic\" and \"writes\") and 3 manufactured reverse edge-types (\"rev-affiliated_with\", \"rev-has_topic\", \"rev-writes\").  
As mentioned above, note that there is _not_ a reverse edge type like \"rev-cites\", and instead the reverse edges are given the same type of \"cites\".  This exception was presumably made because the source and destination nodes are of type \"paper\".  Whereas the 7 \"relation\" linear layers do not have a bias, the 4 \"self\" linear layers do.\n\nWith two of these layers, a hidden dimension size of 64 and 349 output classes, we end up with 337,460 R-GCN model parameters.\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-mag/hetero_rgcn.py",
    "content": "import argparse\nimport itertools\nimport sys\n\nimport dgl\nimport dgl.nn as dglnn\n\nimport psutil\n\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import AddReverse, Compose, ToSimple\nfrom dgl.nn import HeteroEmbedding\nfrom ogb.nodeproppred import DglNodePropPredDataset, Evaluator\nfrom tqdm import tqdm\n\nv_t = dgl.__version__\n\n\ndef prepare_data(args, device):\n    dataset = DglNodePropPredDataset(name=\"ogbn-mag\")\n    split_idx = dataset.get_idx_split()\n    # graph: dgl graph object, label: torch tensor of shape (num_nodes, num_tasks)\n    g, labels = dataset[0]\n    labels = labels[\"paper\"].flatten()\n\n    transform = Compose([ToSimple(), AddReverse()])\n    g = transform(g)\n\n    print(\"Loaded graph: {}\".format(g))\n\n    logger = Logger(args.runs)\n\n    # train sampler\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20])\n    num_workers = args.num_workers\n    train_loader = dgl.dataloading.DataLoader(\n        g,\n        split_idx[\"train\"],\n        sampler,\n        batch_size=1024,\n        shuffle=True,\n        num_workers=num_workers,\n        device=device,\n    )\n\n    return g, labels, dataset.num_classes, split_idx, logger, train_loader\n\n\ndef extract_embed(node_embed, input_nodes):\n    emb = node_embed(\n        {ntype: input_nodes[ntype] for ntype in input_nodes if ntype != \"paper\"}\n    )\n    return emb\n\n\ndef rel_graph_embed(graph, embed_size):\n    node_num = {}\n    for ntype in graph.ntypes:\n        if ntype == \"paper\":\n            continue\n        node_num[ntype] = graph.num_nodes(ntype)\n    embeds = HeteroEmbedding(node_num, embed_size)\n    return embeds\n\n\nclass RelGraphConvLayer(nn.Module):\n    def __init__(\n        self, in_feat, out_feat, ntypes, rel_names, activation=None, dropout=0.0\n    ):\n        super(RelGraphConvLayer, self).__init__()\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.ntypes 
= ntypes\n        self.rel_names = rel_names\n        self.activation = activation\n\n        self.conv = dglnn.HeteroGraphConv(\n            {\n                rel: dglnn.GraphConv(\n                    in_feat, out_feat, norm=\"right\", weight=False, bias=False\n                )\n                for rel in rel_names\n            }\n        )\n\n        self.weight = nn.ModuleDict(\n            {\n                rel_name: nn.Linear(in_feat, out_feat, bias=False)\n                for rel_name in self.rel_names\n            }\n        )\n\n        # weight for self loop\n        self.loop_weights = nn.ModuleDict(\n            {\n                ntype: nn.Linear(in_feat, out_feat, bias=True)\n                for ntype in self.ntypes\n            }\n        )\n\n        self.dropout = nn.Dropout(dropout)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for layer in self.weight.values():\n            layer.reset_parameters()\n\n        for layer in self.loop_weights.values():\n            layer.reset_parameters()\n\n    def forward(self, g, inputs):\n        \"\"\"\n        Parameters\n        ----------\n        g : DGLGraph\n            Input graph.\n        inputs : dict[str, torch.Tensor]\n            Node feature for each node type.\n\n        Returns\n        -------\n        dict[str, torch.Tensor]\n            New node features for each node type.\n        \"\"\"\n        g = g.local_var()\n        wdict = {\n            rel_name: {\"weight\": self.weight[rel_name].weight.T}\n            for rel_name in self.rel_names\n        }\n\n        inputs_dst = {\n            k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()\n        }\n\n        hs = self.conv(g, inputs, mod_kwargs=wdict)\n\n        def _apply(ntype, h):\n            h = h + self.loop_weights[ntype](inputs_dst[ntype])\n            if self.activation:\n                h = self.activation(h)\n            return self.dropout(h)\n\n        return {ntype: 
_apply(ntype, h) for ntype, h in hs.items()}\n\n\nclass EntityClassify(nn.Module):\n    def __init__(self, g, in_dim, out_dim):\n        super(EntityClassify, self).__init__()\n        self.in_dim = in_dim\n        self.h_dim = 64\n        self.out_dim = out_dim\n        self.rel_names = list(set(g.etypes))\n        self.rel_names.sort()\n        self.dropout = 0.5\n\n        self.layers = nn.ModuleList()\n        # i2h\n        self.layers.append(\n            RelGraphConvLayer(\n                self.in_dim,\n                self.h_dim,\n                g.ntypes,\n                self.rel_names,\n                activation=F.relu,\n                dropout=self.dropout,\n            )\n        )\n\n        # h2o\n        self.layers.append(\n            RelGraphConvLayer(\n                self.h_dim,\n                self.out_dim,\n                g.ntypes,\n                self.rel_names,\n                activation=None,\n            )\n        )\n\n    def reset_parameters(self):\n        for layer in self.layers:\n            layer.reset_parameters()\n\n    def forward(self, h, blocks):\n        for layer, block in zip(self.layers, blocks):\n            h = layer(block, h)\n        return h\n\n\nclass Logger(object):\n    r\"\"\"\n    This class was taken directly from the PyG implementation and can be found\n    here: https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/logger.py\n\n    This was done to ensure that performance was measured in precisely the same way\n    \"\"\"\n\n    def __init__(self, runs):\n        self.results = [[] for _ in range(runs)]\n\n    def add_result(self, run, result):\n        assert len(result) == 3\n        assert run >= 0 and run < len(self.results)\n        self.results[run].append(result)\n\n    def print_statistics(self, run=None):\n        if run is not None:\n            result = 100 * th.tensor(self.results[run])\n            argmax = result[:, 1].argmax().item()\n            print(f\"Run {run + 
1:02d}:\")\n            print(f\"Highest Train: {result[:, 0].max():.2f}\")\n            print(f\"Highest Valid: {result[:, 1].max():.2f}\")\n            print(f\"  Final Train: {result[argmax, 0]:.2f}\")\n            print(f\"   Final Test: {result[argmax, 2]:.2f}\")\n        else:\n            result = 100 * th.tensor(self.results)\n\n            best_results = []\n            for r in result:\n                train1 = r[:, 0].max().item()\n                valid = r[:, 1].max().item()\n                train2 = r[r[:, 1].argmax(), 0].item()\n                test = r[r[:, 1].argmax(), 2].item()\n                best_results.append((train1, valid, train2, test))\n\n            best_result = th.tensor(best_results)\n\n            print(f\"All runs:\")\n            r = best_result[:, 0]\n            print(f\"Highest Train: {r.mean():.2f} ± {r.std():.2f}\")\n            r = best_result[:, 1]\n            print(f\"Highest Valid: {r.mean():.2f} ± {r.std():.2f}\")\n            r = best_result[:, 2]\n            print(f\"  Final Train: {r.mean():.2f} ± {r.std():.2f}\")\n            r = best_result[:, 3]\n            print(f\"   Final Test: {r.mean():.2f} ± {r.std():.2f}\")\n\n\ndef train(\n    g,\n    model,\n    node_embed,\n    optimizer,\n    train_loader,\n    split_idx,\n    labels,\n    logger,\n    device,\n    run,\n):\n    print(\"start training...\")\n    category = \"paper\"\n\n    for epoch in range(3):\n        num_train = split_idx[\"train\"][category].shape[0]\n        pbar = tqdm(total=num_train)\n        pbar.set_description(f\"Epoch {epoch:02d}\")\n        model.train()\n\n        total_loss = 0\n\n        for input_nodes, seeds, blocks in train_loader:\n            blocks = [blk.to(device) for blk in blocks]\n            seeds = seeds[\n                category\n            ]  # we only predict the nodes with type \"category\"\n            batch_size = seeds.shape[0]\n            input_nodes_indexes = input_nodes[\"paper\"].to(g.device)\n            
seeds = seeds.to(labels.device)\n\n            emb = extract_embed(node_embed, input_nodes)\n            # Add the batch's raw \"paper\" features\n            emb.update({\"paper\": g.ndata[\"feat\"][\"paper\"][input_nodes_indexes]})\n\n            emb = {k: e.to(device) for k, e in emb.items()}\n            lbl = labels[seeds].to(device)\n\n            optimizer.zero_grad()\n            logits = model(emb, blocks)[category]\n\n            y_hat = logits.log_softmax(dim=-1)\n            loss = F.nll_loss(y_hat, lbl)\n            loss.backward()\n            optimizer.step()\n\n            total_loss += loss.item() * batch_size\n            pbar.update(batch_size)\n\n        pbar.close()\n        loss = total_loss / num_train\n\n        result = test(g, model, node_embed, labels, device, split_idx)\n        logger.add_result(run, result)\n        train_acc, valid_acc, test_acc = result\n        print(\n            f\"Run: {run + 1:02d}, \"\n            f\"Epoch: {epoch +1 :02d}, \"\n            f\"Loss: {loss:.4f}, \"\n            f\"Train: {100 * train_acc:.2f}%, \"\n            f\"Valid: {100 * valid_acc:.2f}%, \"\n            f\"Test: {100 * test_acc:.2f}%\"\n        )\n\n    return logger\n\n\n@th.no_grad()\ndef test(g, model, node_embed, y_true, device, split_idx):\n    model.eval()\n    category = \"paper\"\n    evaluator = Evaluator(name=\"ogbn-mag\")\n\n    # 2 GNN layers\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n    loader = dgl.dataloading.DataLoader(\n        g,\n        {\"paper\": th.arange(g.num_nodes(\"paper\"))},\n        sampler,\n        batch_size=16384,\n        shuffle=False,\n        num_workers=0,\n        device=device,\n    )\n\n    pbar = tqdm(total=y_true.size(0))\n    pbar.set_description(f\"Inference\")\n\n    y_hats = list()\n\n    for input_nodes, seeds, blocks in loader:\n        blocks = [blk.to(device) for blk in blocks]\n        seeds = seeds[\n            category\n        ]  # we only predict the nodes with 
type \"category\"\n        batch_size = seeds.shape[0]\n        input_nodes_indexes = input_nodes[\"paper\"].to(g.device)\n\n        emb = extract_embed(node_embed, input_nodes)\n        # Get the batch's raw \"paper\" features\n        emb.update({\"paper\": g.ndata[\"feat\"][\"paper\"][input_nodes_indexes]})\n        emb = {k: e.to(device) for k, e in emb.items()}\n\n        logits = model(emb, blocks)[category]\n        y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True)\n        y_hats.append(y_hat.cpu())\n\n        pbar.update(batch_size)\n\n    pbar.close()\n\n    y_pred = th.cat(y_hats, dim=0)\n    y_true = th.unsqueeze(y_true, 1)\n\n    train_acc = evaluator.eval(\n        {\n            \"y_true\": y_true[split_idx[\"train\"][\"paper\"]],\n            \"y_pred\": y_pred[split_idx[\"train\"][\"paper\"]],\n        }\n    )[\"acc\"]\n    valid_acc = evaluator.eval(\n        {\n            \"y_true\": y_true[split_idx[\"valid\"][\"paper\"]],\n            \"y_pred\": y_pred[split_idx[\"valid\"][\"paper\"]],\n        }\n    )[\"acc\"]\n    test_acc = evaluator.eval(\n        {\n            \"y_true\": y_true[split_idx[\"test\"][\"paper\"]],\n            \"y_pred\": y_pred[split_idx[\"test\"][\"paper\"]],\n        }\n    )[\"acc\"]\n\n    return train_acc, valid_acc, test_acc\n\n\ndef is_support_affinity(v_t):\n    # dgl supports enable_cpu_affinity since 0.9.1\n    return v_t >= \"0.9.1\"\n\n\ndef main(args):\n    device = f\"cuda:0\" if th.cuda.is_available() else \"cpu\"\n\n    g, labels, num_classes, split_idx, logger, train_loader = prepare_data(\n        args, device\n    )\n\n    embed_layer = rel_graph_embed(g, 128).to(device)\n    model = EntityClassify(g, 128, num_classes).to(device)\n\n    print(\n        f\"Number of embedding parameters: {sum(p.numel() for p in embed_layer.parameters())}\"\n    )\n    print(\n        f\"Number of model parameters: {sum(p.numel() for p in model.parameters())}\"\n    )\n\n    for run in range(args.runs):\n  
      try:\n            embed_layer.reset_parameters()\n            model.reset_parameters()\n        except:\n            # old pytorch version doesn't support reset_parameters() API\n            pass\n\n        # optimizer\n        all_params = itertools.chain(\n            model.parameters(), embed_layer.parameters()\n        )\n        optimizer = th.optim.Adam(all_params, lr=0.01)\n\n        if (\n            args.num_workers != 0\n            and device == \"cpu\"\n            and is_support_affinity(v_t)\n        ):\n            expected_max = int(psutil.cpu_count(logical=False))\n            if args.num_workers >= expected_max:\n                print(\n                    f\"[ERROR] You specified num_workers are larger than physical cores, please set any number less than {expected_max}\",\n                    file=sys.stderr,\n                )\n            with train_loader.enable_cpu_affinity():\n                logger = train(\n                    g,\n                    model,\n                    embed_layer,\n                    optimizer,\n                    train_loader,\n                    split_idx,\n                    labels,\n                    logger,\n                    device,\n                    run,\n                )\n        else:\n            logger = train(\n                g,\n                model,\n                embed_layer,\n                optimizer,\n                train_loader,\n                split_idx,\n                labels,\n                logger,\n                device,\n                run,\n            )\n        logger.print_statistics(run)\n\n    print(\"Final performance: \")\n    logger.print_statistics()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"RGCN\")\n    parser.add_argument(\"--runs\", type=int, default=10)\n    parser.add_argument(\"--num_workers\", type=int, default=0)\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-products/gat/README.md",
    "content": "# DGL examples for ogbn-products\n\n## Sample-based GAT\n\nRequires DGL 0.4.3post2 or later versions.\n\nRun `main.py` and you should directly see the result.\n\nAccuracy over 5 runs: 0.7863197 ± 0.00072568655\n\n## GAT (another implementation)\n\nRequires DGL 0.5 or later versions.\n\nFor the score of `GAT`, run the following command and you should directly see the result.\n\n```bash\npython3 gat.py\n```\n\nOr, if you want to speed up during training time, run with `--estimation-mode` enabled.\nThis option will do a complete evaluation when the training is over.\n\n```bash\npython3 gat.py --estimation-mode\n```\n\n## Results\n\nHere are the results over 10 runs.\n\n|    Method     | Validation Accuracy |  Test Accuracy  | #Parameters |\n|:-------------:|:-------------------:|:---------------:|:-----------:|\n| GAT (main.py) |         N/A         | 0.7863 ± 0.0007 |     N/A     |\n| GAT (gat.py)  |   0.9327 ± 0.0003   | 0.8126 ± 0.0018 |  1,065,127  |\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-products/gat/gat.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nimport argparse\nimport math\nimport random\nimport time\nfrom collections import OrderedDict\n\nimport dgl\nimport dgl.function as fn\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dgl.dataloading import (\n    DataLoader,\n    MultiLayerFullNeighborSampler,\n    MultiLayerNeighborSampler,\n)\nfrom matplotlib.ticker import AutoMinorLocator, MultipleLocator\nfrom models import GAT\nfrom ogb.nodeproppred import DglNodePropPredDataset, Evaluator\nfrom torch import nn\nfrom tqdm import tqdm\n\nepsilon = 1 - math.log(2)\n\ndevice = None\ndataset = \"ogbn-products\"\nn_node_feats, n_edge_feats, n_classes = 0, 0, 0\n\n\ndef seed(seed=0):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n    dgl.random.seed(seed)\n\n\ndef load_data(dataset):\n    data = DglNodePropPredDataset(name=dataset)\n    evaluator = Evaluator(name=dataset)\n\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n    graph.ndata[\"labels\"] = labels\n\n    return graph, labels, train_idx, val_idx, test_idx, evaluator\n\n\ndef preprocess(graph, labels, train_idx):\n    global n_node_feats, n_classes\n    n_node_feats = graph.ndata[\"feat\"].shape[1]\n    n_classes = (labels.max() + 1).item()\n\n    # graph = graph.remove_self_loop().add_self_loop()\n    n_node_feats = graph.ndata[\"feat\"].shape[-1]\n\n    graph.ndata[\"train_labels_onehot\"] = torch.zeros(\n        graph.num_nodes(), n_classes\n    )\n    graph.ndata[\"train_labels_onehot\"][train_idx, labels[train_idx, 0]] = 1\n\n    
graph.ndata[\"is_train\"] = torch.zeros(graph.num_nodes(), dtype=torch.bool)\n    graph.ndata[\"is_train\"][train_idx] = 1\n\n    graph.create_formats_()\n\n    return graph, labels\n\n\ndef gen_model(args):\n    if args.use_labels:\n        n_node_feats_ = n_node_feats + n_classes\n    else:\n        n_node_feats_ = n_node_feats\n\n    model = GAT(\n        n_node_feats_,\n        n_edge_feats,\n        n_classes,\n        n_layers=args.n_layers,\n        n_heads=args.n_heads,\n        n_hidden=args.n_hidden,\n        edge_emb=0,\n        activation=F.relu,\n        dropout=args.dropout,\n        input_drop=args.input_drop,\n        attn_drop=args.attn_dropout,\n        edge_drop=args.edge_drop,\n        use_attn_dst=not args.no_attn_dst,\n        allow_zero_in_degree=True,\n        residual=False,\n    )\n\n    return model\n\n\ndef custom_loss_function(x, labels):\n    y = F.cross_entropy(x, labels[:, 0], reduction=\"none\")\n    y = torch.log(epsilon + y) - math.log(epsilon)\n    return torch.mean(y)\n\n\ndef add_soft_labels(graph, soft_labels):\n    feat = graph.srcdata[\"feat\"]\n    graph.srcdata[\"feat\"] = torch.cat([feat, soft_labels], dim=-1)\n\n\ndef update_hard_labels(graph, idx=None):\n    if idx is None:\n        idx = torch.arange(graph.srcdata[\"is_train\"].shape[0])[\n            graph.srcdata[\"is_train\"]\n        ]\n\n    graph.srcdata[\"feat\"][idx, -n_classes:] = graph.srcdata[\n        \"train_labels_onehot\"\n    ][idx]\n\n\ndef train(\n    args, model, dataloader, labels, train_idx, criterion, optimizer, evaluator\n):\n    model.train()\n\n    loss_sum, total = 0, 0\n\n    preds = torch.zeros(labels.shape[0], n_classes)\n\n    for it in range(args.n_label_iters + 1):\n        preds_old = preds.clone()\n        for input_nodes, output_nodes, subgraphs in dataloader:\n            subgraphs = [b.to(device) for b in subgraphs]\n            new_train_idx = torch.arange(len(output_nodes))\n\n            if args.use_labels:\n                mask 
= torch.rand(new_train_idx.shape) < args.mask_rate\n\n                train_labels_idx = torch.cat(\n                    [\n                        new_train_idx[~mask],\n                        torch.arange(len(output_nodes), len(input_nodes)),\n                    ]\n                )\n                train_pred_idx = new_train_idx[mask]\n\n                add_soft_labels(\n                    subgraphs[0],\n                    F.softmax(preds_old[input_nodes].to(device), dim=-1),\n                )\n                update_hard_labels(subgraphs[0], train_labels_idx)\n            else:\n                train_pred_idx = new_train_idx\n\n            pred = model(subgraphs)\n\n            preds[output_nodes] = pred.cpu().detach()\n\n            # NOTE: This is not a complete implementation of label reuse, since it is too expensive\n            # to predict the nodes in validation and test set during training time.\n            if it == args.n_label_iters:\n                loss = criterion(\n                    pred[train_pred_idx],\n                    subgraphs[-1].dstdata[\"labels\"][train_pred_idx],\n                )\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n\n                count = len(train_pred_idx)\n                loss_sum += loss.item() * count\n                total += count\n\n            torch.cuda.empty_cache()\n\n    return (\n        evaluator(preds[train_idx], labels[train_idx]),\n        loss_sum / total,\n    )\n\n\n@torch.no_grad()\ndef evaluate(\n    args,\n    model,\n    dataloader,\n    labels,\n    train_idx,\n    val_idx,\n    test_idx,\n    criterion,\n    evaluator,\n):\n    model.eval()\n\n    # Due to the limitation of memory capacity, we calculate the average of logits 'eval_times' times.\n    eval_times = 1\n\n    preds_avg = torch.zeros(labels.shape[0], n_classes)\n    for _ in range(eval_times):\n        preds = torch.zeros(labels.shape[0], n_classes)\n\n        for 
_it in range(args.n_label_iters + 1):\n            preds_old = preds.clone()\n            for input_nodes, output_nodes, subgraphs in dataloader:\n                subgraphs = [b.to(device) for b in subgraphs]\n\n                if args.use_labels:\n                    add_soft_labels(\n                        subgraphs[0],\n                        F.softmax(preds_old[input_nodes].to(device), dim=-1),\n                    )\n                    update_hard_labels(subgraphs[0])\n\n                pred = model(subgraphs, inference=True)\n                preds[output_nodes] = pred.cpu()\n\n                torch.cuda.empty_cache()\n\n        preds_avg += preds\n\n    preds_avg = preds_avg.to(device)\n    preds_avg /= eval_times\n\n    train_loss = criterion(preds_avg[train_idx], labels[train_idx]).item()\n    val_loss = criterion(preds_avg[val_idx], labels[val_idx]).item()\n    test_loss = criterion(preds_avg[test_idx], labels[test_idx]).item()\n\n    return (\n        evaluator(preds_avg[train_idx], labels[train_idx]),\n        evaluator(preds_avg[val_idx], labels[val_idx]),\n        evaluator(preds_avg[test_idx], labels[test_idx]),\n        train_loss,\n        val_loss,\n        test_loss,\n    )\n\n\ndef run(\n    args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running\n):\n    evaluator_wrapper = lambda pred, labels: evaluator.eval(\n        {\"y_pred\": pred.argmax(dim=-1, keepdim=True), \"y_true\": labels}\n    )[\"acc\"]\n    criterion = custom_loss_function\n\n    n_train_samples = train_idx.shape[0]\n    train_batch_size = (n_train_samples + 29) // 30\n    train_sampler = MultiLayerNeighborSampler(\n        [10 for _ in range(args.n_layers)]\n    )\n    train_dataloader = DataLoader(\n        graph.cpu(),\n        train_idx.cpu(),\n        train_sampler,\n        batch_size=train_batch_size,\n        shuffle=True,\n        num_workers=4,\n    )\n\n    eval_batch_size = 32768\n    eval_sampler = MultiLayerNeighborSampler([15 for _ in 
range(args.n_layers)])\n\n    if args.estimation_mode:\n        test_idx_during_training = test_idx[\n            torch.arange(start=0, end=len(test_idx), step=45)\n        ]\n    else:\n        test_idx_during_training = test_idx\n\n    eval_idx = torch.cat(\n        [train_idx.cpu(), val_idx.cpu(), test_idx_during_training.cpu()]\n    )\n    eval_dataloader = DataLoader(\n        graph.cpu(),\n        eval_idx,\n        eval_sampler,\n        batch_size=eval_batch_size,\n        shuffle=False,\n        num_workers=4,\n    )\n\n    model = gen_model(args).to(device)\n\n    optimizer = optim.AdamW(\n        model.parameters(), lr=args.lr, weight_decay=args.wd\n    )\n    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n        optimizer,\n        mode=\"max\",\n        factor=0.7,\n        patience=20,\n        min_lr=1e-4,\n    )\n\n    best_model_state_dict = None\n\n    total_time = 0\n    val_score, best_val_score, final_test_score = 0, 0, 0\n\n    scores, train_scores, val_scores, test_scores = [], [], [], []\n    losses, train_losses, val_losses, test_losses = [], [], [], []\n\n    for epoch in range(1, args.n_epochs + 1):\n        tic = time.time()\n\n        score, loss = train(\n            args,\n            model,\n            train_dataloader,\n            labels,\n            train_idx,\n            criterion,\n            optimizer,\n            evaluator_wrapper,\n        )\n\n        toc = time.time()\n        total_time += toc - tic\n\n        if (\n            epoch == args.n_epochs\n            or epoch % args.eval_every == 0\n            or epoch % args.log_every == 0\n        ):\n            (\n                train_score,\n                val_score,\n                test_score,\n                train_loss,\n                val_loss,\n                test_loss,\n            ) = evaluate(\n                args,\n                model,\n                eval_dataloader,\n                labels,\n                train_idx,\n                
val_idx,\n                test_idx_during_training,\n                criterion,\n                evaluator_wrapper,\n            )\n\n            if val_score > best_val_score:\n                best_val_score = val_score\n                final_test_score = test_score\n                if args.estimation_mode:\n                    best_model_state_dict = {\n                        k: v.to(\"cpu\") for k, v in model.state_dict().items()\n                    }\n\n            if epoch == args.n_epochs or epoch % args.log_every == 0:\n                print(\n                    f\"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}\\n\"\n                    f\"Loss: {loss:.4f}, Score: {score:.4f}\\n\"\n                    f\"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\\n\"\n                    f\"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}\"\n                )\n\n            for l, e in zip(\n                [\n                    scores,\n                    train_scores,\n                    val_scores,\n                    test_scores,\n                    losses,\n                    train_losses,\n                    val_losses,\n                    test_losses,\n                ],\n                [\n                    score,\n                    train_score,\n                    val_score,\n                    test_score,\n                    loss,\n                    train_loss,\n                    val_loss,\n                    test_loss,\n                ],\n            ):\n                l.append(e)\n\n        lr_scheduler.step(val_score)\n\n    if args.estimation_mode:\n        model.load_state_dict(best_model_state_dict)\n        eval_dataloader = DataLoader(\n            graph.cpu(),\n            test_idx.cpu(),\n            eval_sampler,\n            batch_size=eval_batch_size,\n  
          shuffle=False,\n            num_workers=4,\n        )\n        final_test_score = evaluate(\n            args,\n            model,\n            eval_dataloader,\n            labels,\n            train_idx,\n            val_idx,\n            test_idx,\n            criterion,\n            evaluator_wrapper,\n        )[2]\n\n    print(\"*\" * 50)\n    print(\n        f\"Best val score: {best_val_score}, Final test score: {final_test_score}\"\n    )\n    print(\"*\" * 50)\n\n    if args.plot_curves:\n        fig = plt.figure(figsize=(24, 24))\n        ax = fig.gca()\n        ax.set_xticks(np.arange(0, args.n_epochs, 100))\n        ax.set_yticks(np.linspace(0, 1.0, 101))\n        ax.tick_params(labeltop=True, labelright=True)\n        for y, label in zip(\n            [train_scores, val_scores, test_scores],\n            [\"train score\", \"val score\", \"test score\"],\n        ):\n            plt.plot(\n                range(1, args.n_epochs + 1, args.log_every),\n                y,\n                label=label,\n                linewidth=1,\n            )\n        ax.xaxis.set_major_locator(MultipleLocator(10))\n        ax.xaxis.set_minor_locator(AutoMinorLocator(1))\n        ax.yaxis.set_major_locator(MultipleLocator(0.01))\n        ax.yaxis.set_minor_locator(AutoMinorLocator(2))\n        plt.grid(which=\"major\", color=\"red\", linestyle=\"dotted\")\n        plt.grid(which=\"minor\", color=\"orange\", linestyle=\"dotted\")\n        plt.legend()\n        plt.tight_layout()\n        plt.savefig(f\"gat_score_{n_running}.png\")\n\n        fig = plt.figure(figsize=(24, 24))\n        ax = fig.gca()\n        ax.set_xticks(np.arange(0, args.n_epochs, 100))\n        ax.tick_params(labeltop=True, labelright=True)\n        for y, label in zip(\n            [losses, train_losses, val_losses, test_losses],\n            [\"loss\", \"train loss\", \"val loss\", \"test loss\"],\n        ):\n            plt.plot(\n                range(1, args.n_epochs + 1, 
args.log_every),\n                y,\n                label=label,\n                linewidth=1,\n            )\n        ax.xaxis.set_major_locator(MultipleLocator(10))\n        ax.xaxis.set_minor_locator(AutoMinorLocator(1))\n        ax.yaxis.set_major_locator(MultipleLocator(0.1))\n        ax.yaxis.set_minor_locator(AutoMinorLocator(5))\n        plt.grid(which=\"major\", color=\"red\", linestyle=\"dotted\")\n        plt.grid(which=\"minor\", color=\"orange\", linestyle=\"dotted\")\n        plt.legend()\n        plt.tight_layout()\n        plt.savefig(f\"gat_loss_{n_running}.png\")\n\n    return best_val_score, final_test_score\n\n\ndef count_parameters(args):\n    model = gen_model(args)\n    return sum(\n        [np.prod(p.size()) for p in model.parameters() if p.requires_grad]\n    )\n\n\ndef main():\n    global device\n\n    argparser = argparse.ArgumentParser(\n        \"GAT implementation on ogbn-products\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n    argparser.add_argument(\n        \"--cpu\",\n        action=\"store_true\",\n        help=\"CPU mode. 
This option overrides '--gpu'.\",\n    )\n    argparser.add_argument(\"--gpu\", type=int, default=0, help=\"GPU device ID\")\n    argparser.add_argument(\"--seed\", type=int, default=0, help=\"seed\")\n    argparser.add_argument(\n        \"--n-runs\", type=int, default=10, help=\"running times\"\n    )\n    argparser.add_argument(\n        \"--n-epochs\", type=int, default=250, help=\"number of epochs\"\n    )\n    argparser.add_argument(\n        \"--use-labels\",\n        action=\"store_true\",\n        help=\"Use labels in the training set as input features.\",\n    )\n    argparser.add_argument(\n        \"--n-label-iters\",\n        type=int,\n        default=0,\n        help=\"number of label iterations\",\n    )\n    argparser.add_argument(\n        \"--no-attn-dst\", action=\"store_true\", help=\"Don't use attn_dst.\"\n    )\n    argparser.add_argument(\n        \"--mask-rate\", type=float, default=0.5, help=\"mask rate\"\n    )\n    argparser.add_argument(\n        \"--n-heads\", type=int, default=4, help=\"number of heads\"\n    )\n    argparser.add_argument(\n        \"--lr\", type=float, default=0.01, help=\"learning rate\"\n    )\n    argparser.add_argument(\n        \"--n-layers\", type=int, default=3, help=\"number of layers\"\n    )\n    argparser.add_argument(\n        \"--n-hidden\", type=int, default=120, help=\"number of hidden units\"\n    )\n    argparser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout rate\"\n    )\n    argparser.add_argument(\n        \"--input-drop\", type=float, default=0.1, help=\"input drop rate\"\n    )\n    argparser.add_argument(\n        \"--attn-dropout\", type=float, default=0.0, help=\"attention drop rate\"\n    )\n    argparser.add_argument(\n        \"--edge-drop\", type=float, default=0.1, help=\"edge drop rate\"\n    )\n    argparser.add_argument(\"--wd\", type=float, default=0, help=\"weight decay\")\n    argparser.add_argument(\n        \"--eval-every\", type=int, default=2, 
help=\"log every EVAL_EVERY epochs\"\n    )\n    argparser.add_argument(\n        \"--estimation-mode\",\n        action=\"store_true\",\n        help=\"Estimate the score of test set for speed during training.\",\n    )\n    argparser.add_argument(\n        \"--log-every\", type=int, default=2, help=\"log every LOG_EVERY epochs\"\n    )\n    argparser.add_argument(\n        \"--plot-curves\", action=\"store_true\", help=\"plot learning curves\"\n    )\n    args = argparser.parse_args()\n\n    if args.cpu:\n        device = torch.device(\"cpu\")\n    else:\n        device = torch.device(\"cuda:%d\" % args.gpu)\n\n    # load data & preprocess\n    graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)\n    graph, labels = preprocess(graph, labels, train_idx)\n\n    labels, train_idx, val_idx, test_idx = map(\n        lambda x: x.to(device), (labels, train_idx, val_idx, test_idx)\n    )\n\n    # run\n    val_scores, test_scores = [], []\n\n    for i in range(1, args.n_runs + 1):\n        seed(args.seed + i)\n        val_score, test_score = run(\n            args, graph, labels, train_idx, val_idx, test_idx, evaluator, i\n        )\n        val_scores.append(val_score)\n        test_scores.append(test_score)\n\n    print(args)\n    print(f\"Runned {args.n_runs} times\")\n    print(\"Val scores:\", val_scores)\n    print(\"Test scores:\", test_scores)\n    print(f\"Average val score: {np.mean(val_scores)} ± {np.std(val_scores)}\")\n    print(f\"Average test score: {np.mean(test_scores)} ± {np.std(test_scores)}\")\n    print(f\"Number of params: {count_parameters(args)}\")\n\n    if args.estimation_mode:\n        print(\n            \"WARNING: Estimation mode is enabled. 
The final test score is accurate, but not accurate during training time.\"\n        )\n\n\nif __name__ == \"__main__\":\n    main()\n\n# Namespace(attn_dropout=0.0, cpu=False, dropout=0.5, edge_drop=0.1, estimation_mode=True, eval_every=2, gpu=1, input_drop=0.1, log_every=2, lr=0.01, mask_rate=0.5, n_epochs=250, n_heads=4, n_hidden=120, n_label_iters=0, n_layers=3, n_runs=10, no_attn_dst=False, plot_curves=True, seed=0, use_labels=False, wd=0)\n# Runned 10 times\n# Val scores: [0.9326348447473489, 0.9330163008926073, 0.9327619967957684, 0.932355110240826, 0.9330163008926073, 0.9327365663860845, 0.9329145792538718, 0.9322788190117742, 0.9321516669633548, 0.9329908704829235]\n# Test scores: [0.8147550191112792, 0.8115680737936217, 0.8128332725586069, 0.8134062268564646, 0.8118784993477448, 0.8145462613150566, 0.8151228304665284, 0.8115274066904614, 0.8108545920615103, 0.8094583548530088]\n# Average val score: 0.9326857055667167 ± 0.00030580001557474636\n# Average test score: 0.8125950537054282 ± 0.001765025824381352\n# Number of params: 1065127\n\n# Namespace(attn_dropout=0.0, cpu=False, dropout=0.5, edge_drop=0.1, estimation_mode=True, eval_every=2, gpu=0, input_drop=0.1, log_every=2, lr=0.01, mask_rate=0.5, n_epochs=250, n_heads=4, n_hidden=120, n_label_iters=0, n_layers=3, n_runs=5, no_attn_dst=True, plot_curves=True, seed=0, use_labels=False, wd=0)\n# Runned 10 times\n# Val scores: [0.9332451745797625, 0.9330417313022913, 0.9328128576151362, 0.9323296798311421, 0.9324568318795616, 0.9327874272054523, 0.9327619967957684, 0.9328128576151362, 0.9322025277827226, 0.9329400096635557]\n# Test scores: [0.8103399272781824, 0.8115870517750965, 0.8107294277551171, 0.8115771109276573, 0.8130244079434601, 0.8094628734200265, 0.8105681149125815, 0.809217063374258, 0.8108085026779287, 0.8151549122923549]\n# Average val score: 0.932739109427053 ± 0.0003061065079170266\n# Average test score: 0.8112469392356664 ± 0.0016644261188834386\n# Number of params: 1060887\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-products/gat/main.py",
    "content": "import argparse\nimport time\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\nclass GAT(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, num_heads, activation\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(\n            dglnn.GATConv(\n                (in_feats, in_feats),\n                n_hidden,\n                num_heads=num_heads,\n                activation=activation,\n            )\n        )\n        for i in range(1, n_layers - 1):\n            self.layers.append(\n                dglnn.GATConv(\n                    (n_hidden * num_heads, n_hidden * num_heads),\n                    n_hidden,\n                    num_heads=num_heads,\n                    activation=activation,\n                )\n            )\n        self.layers.append(\n            dglnn.GATConv(\n                (n_hidden * num_heads, n_hidden * num_heads),\n                n_classes,\n                num_heads=num_heads,\n                activation=None,\n            )\n        )\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            # We need to first copy the representation of nodes on the RHS from the\n            # appropriate nodes on the LHS.\n            # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst\n            # would be (num_nodes_RHS, D)\n            h_dst = h[: block.num_dst_nodes()]\n            # Then we compute the updated representation on the RHS.\n            # The shape of h now becomes (num_nodes_RHS, D)\n            if l < self.n_layers - 1:\n               
 h = layer(block, (h, h_dst)).flatten(1)\n            else:\n                h = layer(block, (h, h_dst))\n        h = h.mean(1)\n        return h.log_softmax(dim=-1)\n\n    def inference(self, g, x, num_heads, device):\n        \"\"\"\n        Inference with the GAT model on full neighbors (i.e. without neighbor sampling).\n        g : the entire graph.\n        x : the input features of the entire node set.\n        The inference code is written in a fashion that it could handle any number of nodes and\n        layers.\n        \"\"\"\n        # During inference with sampling, multi-layer blocks are very inefficient because\n        # lots of computations in the first few layers are repeated.\n        # Therefore, we compute the representation of all nodes layer by layer.  The nodes\n        # on each layer are of course split into batches.\n        # TODO: can we standardize this?\n        for l, layer in enumerate(self.layers):\n            if l < self.n_layers - 1:\n                y = th.zeros(\n                    g.num_nodes(),\n                    self.n_hidden * num_heads\n                    if l != len(self.layers) - 1\n                    else self.n_classes,\n                )\n            else:\n                y = th.zeros(\n                    g.num_nodes(),\n                    self.n_hidden\n                    if l != len(self.layers) - 1\n                    else self.n_classes,\n                )\n\n            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n            dataloader = dgl.dataloading.DataLoader(\n                g,\n                th.arange(g.num_nodes()),\n                sampler,\n                batch_size=args.batch_size,\n                shuffle=True,\n                drop_last=False,\n                num_workers=args.num_workers,\n            )\n\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                block = blocks[0].int().to(device)\n\n                h = 
x[input_nodes].to(device)\n                h_dst = h[: block.num_dst_nodes()]\n                if l < self.n_layers - 1:\n                    h = layer(block, (h, h_dst)).flatten(1)\n                else:\n                    h = layer(block, (h, h_dst))\n                    h = h.mean(1)\n                    h = h.log_softmax(dim=-1)\n\n                y[output_nodes] = h.cpu()\n\n            x = y\n        return y.to(device)\n\n\ndef compute_acc(pred, labels):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)\n\n\ndef evaluate(model, g, nfeat, labels, val_nid, test_nid, num_heads, device):\n    \"\"\"\n    Evaluate the model on the validation and test sets given by ``val_nid`` and ``test_nid``.\n    model : The model to evaluate.\n    g : The entire graph.\n    nfeat : The features of all the nodes.\n    labels : The labels of all the nodes.\n    val_nid : The IDs of the validation nodes to compute the accuracy for.\n    test_nid : The IDs of the test nodes to compute the accuracy for.\n    num_heads : Number of attention heads, forwarded to ``model.inference``.\n    device : The device to evaluate on.\n    Returns the validation accuracy, the test accuracy and the predictions for all nodes.\n    \"\"\"\n    model.eval()\n    with th.no_grad():\n        pred = model.inference(g, nfeat, num_heads, device)\n    model.train()\n    return (\n        compute_acc(pred[val_nid], labels[val_nid]),\n        compute_acc(pred[test_nid], labels[test_nid]),\n        pred,\n    )\n\n\ndef load_subtensor(nfeat, labels, seeds, input_nodes):\n    \"\"\"\n    Extracts features and labels for a set of nodes.\n    \"\"\"\n    batch_inputs = nfeat[input_nodes]\n    batch_labels = labels[seeds]\n    return batch_inputs, batch_labels\n\n\n#### Entry point\ndef run(args, device, data):\n    # Unpack data\n    (\n        train_nid,\n        val_nid,\n        test_nid,\n        in_feats,\n        labels,\n        n_classes,\n        nfeat,\n        g,\n        num_heads,\n    ) = data\n\n    # Create PyTorch DataLoader for constructing blocks\n    sampler = 
dgl.dataloading.MultiLayerNeighborSampler(\n        [int(fanout) for fanout in args.fan_out.split(\",\")]\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        train_nid,\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\n\n    # Define model and optimizer\n    model = GAT(\n        in_feats, args.num_hidden, n_classes, args.num_layers, num_heads, F.relu\n    )\n    model = model.to(device)\n    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)\n\n    # Training loop\n    avg = 0\n    iter_tput = []\n    best_eval_acc = 0\n    best_test_acc = 0\n    for epoch in range(args.num_epochs):\n        tic = time.time()\n\n        # Loop over the dataloader to sample the computation dependency graph as a list of\n        # blocks.\n        for step, (input_nodes, seeds, blocks) in enumerate(dataloader):\n            tic_step = time.time()\n\n            # copy block to gpu\n            blocks = [blk.to(device) for blk in blocks]\n\n            # Load the input features as well as output labels\n            batch_inputs, batch_labels = load_subtensor(\n                nfeat, labels, seeds, input_nodes\n            )\n\n            # Compute loss and prediction\n            batch_pred = model(blocks, batch_inputs)\n            loss = F.nll_loss(batch_pred, batch_labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            iter_tput.append(len(seeds) / (time.time() - tic_step))\n            if step % args.log_every == 0:\n                acc = compute_acc(batch_pred, batch_labels)\n                gpu_mem_alloc = (\n                    th.cuda.max_memory_allocated() / 1000000\n                    if th.cuda.is_available()\n                    else 0\n                )\n                print(\n                    \"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | 
Speed (samples/sec) {:.4f} | GPU {:.1f} MB\".format(\n                        epoch,\n                        step,\n                        loss.item(),\n                        acc.item(),\n                        np.mean(iter_tput[3:]),\n                        gpu_mem_alloc,\n                    )\n                )\n\n        toc = time.time()\n        print(\"Epoch Time(s): {:.4f}\".format(toc - tic))\n        if epoch >= 5:\n            avg += toc - tic\n        if epoch % args.eval_every == 0 and epoch != 0:\n            eval_acc, test_acc, pred = evaluate(\n                model, g, nfeat, labels, val_nid, test_nid, num_heads, device\n            )\n            if args.save_pred:\n                np.savetxt(\n                    args.save_pred + \"%02d\" % epoch,\n                    pred.argmax(1).cpu().numpy(),\n                    \"%d\",\n                )\n            print(\"Eval Acc {:.4f}\".format(eval_acc))\n            if eval_acc > best_eval_acc:\n                best_eval_acc = eval_acc\n                best_test_acc = test_acc\n            print(\n                \"Best Eval Acc {:.4f} Test Acc {:.4f}\".format(\n                    best_eval_acc, best_test_acc\n                )\n            )\n\n    print(\"Avg epoch time: {}\".format(avg / (epoch - 4)))\n    return best_test_acc\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"multi-gpu training\")\n    argparser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=0,\n        help=\"GPU device ID. 
Use -1 for CPU training\",\n    )\n    argparser.add_argument(\"--num-epochs\", type=int, default=100)\n    argparser.add_argument(\"--num-hidden\", type=int, default=128)\n    argparser.add_argument(\"--num-layers\", type=int, default=3)\n    argparser.add_argument(\"--fan-out\", type=str, default=\"10,10,10\")\n    argparser.add_argument(\"--batch-size\", type=int, default=512)\n    argparser.add_argument(\"--val-batch-size\", type=int, default=512)\n    argparser.add_argument(\"--log-every\", type=int, default=20)\n    argparser.add_argument(\"--eval-every\", type=int, default=1)\n    argparser.add_argument(\"--lr\", type=float, default=0.001)\n    argparser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=8,\n        help=\"Number of sampling processes. Use 0 for no extra process.\",\n    )\n    argparser.add_argument(\"--save-pred\", type=str, default=\"\")\n    argparser.add_argument(\"--head\", type=int, default=4)\n    argparser.add_argument(\"--wd\", type=float, default=0)\n    args = argparser.parse_args()\n\n    if args.gpu >= 0:\n        device = th.device(\"cuda:%d\" % args.gpu)\n    else:\n        device = th.device(\"cpu\")\n\n    # load data\n    data = DglNodePropPredDataset(name=\"ogbn-products\")\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n    nfeat = graph.ndata.pop(\"feat\").to(device)\n    labels = labels[:, 0].to(device)\n\n    print(\"Total edges before adding self-loop {}\".format(graph.num_edges()))\n    graph = graph.remove_self_loop().add_self_loop()\n    print(\"Total edges after adding self-loop {}\".format(graph.num_edges()))\n\n    in_feats = nfeat.shape[1]\n    n_classes = (labels.max() + 1).item()\n\n    # Create csr/coo/csc formats before launching sampling processes\n    # This avoids creating certain formats in each data loader 
process, which saves memory and CPU.\n    graph.create_formats_()\n    # Pack data\n    data = (\n        train_idx,\n        val_idx,\n        test_idx,\n        in_feats,\n        labels,\n        n_classes,\n        nfeat,\n        graph,\n        args.head,\n    )\n\n    # Run 10 times\n    test_accs = []\n    for i in range(10):\n        test_accs.append(run(args, device, data).cpu().numpy())\n        print(\n            \"Average test accuracy:\", np.mean(test_accs), \"±\", np.std(test_accs)\n        )\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-products/gat/models.py",
    "content": "# update time: 2020.11.02 17:33\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl import function as fn\nfrom dgl.ops import edge_softmax\nfrom dgl.utils import expand_as_pair\n\n\nclass GATConv(nn.Module):\n    def __init__(\n        self,\n        node_feats,\n        edge_feats,\n        out_feats,\n        n_heads=1,\n        attn_drop=0.0,\n        edge_drop=0.0,\n        negative_slope=0.2,\n        residual=True,\n        activation=None,\n        use_attn_dst=True,\n        allow_zero_in_degree=True,\n        use_symmetric_norm=False,\n    ):\n        super(GATConv, self).__init__()\n        self._n_heads = n_heads\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(node_feats)\n        self._out_feats = out_feats\n        self._allow_zero_in_degree = allow_zero_in_degree\n        self._use_symmetric_norm = use_symmetric_norm\n\n        # feat fc\n        self.src_fc = nn.Linear(\n            self._in_src_feats, out_feats * n_heads, bias=False\n        )\n        if residual:\n            self.dst_fc = nn.Linear(self._in_src_feats, out_feats * n_heads)\n            self.bias = None\n        else:\n            self.dst_fc = None\n            self.bias = nn.Parameter(out_feats * n_heads)\n\n        # attn fc\n        self.attn_src_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)\n        if use_attn_dst:\n            self.attn_dst_fc = nn.Linear(\n                self._in_src_feats, n_heads, bias=False\n            )\n        else:\n            self.attn_dst_fc = None\n        if edge_feats > 0:\n            self.attn_edge_fc = nn.Linear(edge_feats, n_heads, bias=False)\n        else:\n            self.attn_edge_fc = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.edge_drop = edge_drop\n        self.leaky_relu = nn.LeakyReLU(negative_slope, inplace=True)\n        self.activation = activation\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        
gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_normal_(self.src_fc.weight, gain=gain)\n        if self.dst_fc is not None:\n            nn.init.xavier_normal_(self.dst_fc.weight, gain=gain)\n\n        nn.init.xavier_normal_(self.attn_src_fc.weight, gain=gain)\n        if self.attn_dst_fc is not None:\n            nn.init.xavier_normal_(self.attn_dst_fc.weight, gain=gain)\n        if self.attn_edge_fc is not None:\n            nn.init.xavier_normal_(self.attn_edge_fc.weight, gain=gain)\n\n        if self.bias is not None:\n            nn.init.zeros_(self.bias)\n\n    def set_allow_zero_in_degree(self, set_value):\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat_src, feat_edge=None):\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    assert False\n\n            if graph.is_block:\n                feat_dst = feat_src[: graph.number_of_dst_nodes()]\n            else:\n                feat_dst = feat_src\n\n            if self._use_symmetric_norm:\n                degs = graph.out_degrees().float().clamp(min=1)\n                norm = torch.pow(degs, -0.5)\n                shp = norm.shape + (1,) * (feat_src.dim() - 1)\n                norm = torch.reshape(norm, shp)\n                feat_src = feat_src * norm\n\n            feat_src_fc = self.src_fc(feat_src).view(\n                -1, self._n_heads, self._out_feats\n            )\n            feat_dst_fc = self.dst_fc(feat_dst).view(\n                -1, self._n_heads, self._out_feats\n            )\n            attn_src = self.attn_src_fc(feat_src).view(-1, self._n_heads, 1)\n\n            # NOTE: GAT paper uses \"first concatenation then linear projection\"\n            # to compute attention scores, while ours is \"first projection then\n            # addition\", the two approaches are mathematically equivalent:\n            # We decompose the weight vector a 
mentioned in the paper into\n            # [a_l || a_r], then\n            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j\n            # Our implementation is more efficient because we do not need to\n            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,\n            # addition could be optimized with DGL's built-in function u_add_v,\n            # which further speeds up computation and reduces memory footprint.\n            graph.srcdata.update(\n                {\"feat_src_fc\": feat_src_fc, \"attn_src\": attn_src}\n            )\n\n            if self.attn_dst_fc is not None:\n                attn_dst = self.attn_dst_fc(feat_dst).view(-1, self._n_heads, 1)\n                graph.dstdata.update({\"attn_dst\": attn_dst})\n                graph.apply_edges(\n                    fn.u_add_v(\"attn_src\", \"attn_dst\", \"attn_node\")\n                )\n            else:\n                graph.apply_edges(fn.copy_u(\"attn_src\", \"attn_node\"))\n\n            e = graph.edata[\"attn_node\"]\n            if feat_edge is not None:\n                attn_edge = self.attn_edge_fc(feat_edge).view(\n                    -1, self._n_heads, 1\n                )\n                graph.edata.update({\"attn_edge\": attn_edge})\n                e += graph.edata[\"attn_edge\"]\n            e = self.leaky_relu(e)\n\n            if self.training and self.edge_drop > 0:\n                perm = torch.randperm(graph.num_edges(), device=e.device)\n                bound = int(graph.num_edges() * self.edge_drop)\n                eids = perm[bound:]\n                graph.edata[\"a\"] = torch.zeros_like(e)\n                graph.edata[\"a\"][eids] = self.attn_drop(\n                    edge_softmax(graph, e[eids], eids=eids)\n                )\n            else:\n                graph.edata[\"a\"] = self.attn_drop(edge_softmax(graph, e))\n\n            # message passing\n            graph.update_all(\n                fn.u_mul_e(\"feat_src_fc\", \"a\", \"m\"), 
fn.sum(\"m\", \"feat_src_fc\")\n            )\n            rst = graph.dstdata[\"feat_src_fc\"]\n\n            if self._use_symmetric_norm:\n                degs = graph.in_degrees().float().clamp(min=1)\n                norm = torch.pow(degs, 0.5)\n                shp = norm.shape + (1,) * (feat_dst.dim())\n                norm = torch.reshape(norm, shp)\n                rst = rst * norm\n\n            # residual\n            if self.dst_fc is not None:\n                rst += feat_dst_fc\n            else:\n                rst += self.bias\n\n            # activation\n            if self.activation is not None:\n                rst = self.activation(rst, inplace=True)\n\n            return rst\n\n\nclass GAT(nn.Module):\n    def __init__(\n        self,\n        node_feats,\n        edge_feats,\n        n_classes,\n        n_layers,\n        n_heads,\n        n_hidden,\n        edge_emb,\n        activation,\n        dropout,\n        input_drop,\n        attn_drop,\n        edge_drop,\n        use_attn_dst=True,\n        allow_zero_in_degree=False,\n        residual=False,\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_heads = n_heads\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n\n        self.convs = nn.ModuleList()\n        self.norms = nn.ModuleList()\n\n        self.node_encoder = nn.Linear(node_feats, n_hidden)\n        if edge_emb > 0:\n            self.edge_encoder = nn.ModuleList()\n        else:\n            self.edge_encoder = None\n\n        for i in range(n_layers):\n            in_hidden = n_heads * n_hidden if i > 0 else node_feats\n            out_hidden = n_hidden\n\n            if self.edge_encoder is not None:\n                self.edge_encoder.append(nn.Linear(edge_feats, edge_emb))\n            self.convs.append(\n                GATConv(\n                    in_hidden,\n                    edge_emb,\n                    out_hidden,\n                    n_heads=n_heads,\n        
            attn_drop=attn_drop,\n                    edge_drop=edge_drop,\n                    use_attn_dst=use_attn_dst,\n                    allow_zero_in_degree=allow_zero_in_degree,\n                )\n            )\n            self.norms.append(nn.BatchNorm1d(n_heads * out_hidden))\n\n        self.pred_linear = nn.Linear(n_heads * n_hidden, n_classes)\n\n        self.input_drop = nn.Dropout(input_drop)\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n        self.residual = residual\n\n    def forward(self, g, inference=False):\n        if not isinstance(g, list):\n            subgraphs = [g] * self.n_layers\n        else:\n            subgraphs = g\n\n        h = subgraphs[0].srcdata[\"feat\"]\n        h = self.input_drop(h)\n\n        h_last = None\n\n        for i in range(self.n_layers):\n            if self.edge_encoder is not None:\n                efeat = subgraphs[i].edata[\"feat\"]\n                efeat_emb = self.edge_encoder[i](efeat)\n                efeat_emb = F.relu(efeat_emb, inplace=True)\n            else:\n                efeat_emb = None\n\n            h = self.convs[i](subgraphs[i], h, efeat_emb).flatten(1, -1)\n\n            if self.residual and h_last is not None:\n                h += h_last[: h.shape[0], :]\n\n            h_last = h\n\n            h = self.norms[i](h)\n            h = self.activation(h, inplace=True)\n            h = self.dropout(h)\n\n            if inference:\n                torch.cuda.empty_cache()\n\n        h = self.pred_linear(h)\n\n        return h\n\n\nclass MLP(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        n_classes,\n        n_layers,\n        n_hidden,\n        activation,\n        dropout=0.0,\n        input_drop=0.0,\n        residual=False,\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n\n        self.linears = nn.ModuleList()\n        self.norms = 
nn.ModuleList()\n\n        for i in range(n_layers):\n            in_hidden = n_hidden if i > 0 else in_feats\n            out_hidden = n_hidden if i < n_layers - 1 else n_classes\n\n            self.linears.append(nn.Linear(in_hidden, out_hidden))\n\n            if i < n_layers - 1:\n                self.norms.append(nn.BatchNorm1d(out_hidden))\n\n        self.activation = activation\n        self.input_drop = nn.Dropout(input_drop)\n        self.dropout = nn.Dropout(dropout)\n        self.residual = residual\n\n    def forward(self, h):\n        h = self.input_drop(h)\n\n        h_last = None\n\n        for i in range(self.n_layers):\n            h = self.linears[i](h)\n\n            if self.residual and 0 < i < self.n_layers - 1:\n                h += h_last\n\n            h_last = h\n\n            if i < self.n_layers - 1:\n                h = self.norms[i](h)\n                h = self.activation(h, inplace=True)\n                h = self.dropout(h)\n\n        return h\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-products/graphsage/README.md",
    "content": "# GraphSAGE on OGB Products\n\nRequires DGL 0.4.3post2 or later versions.\n\nRun `main.py` and you should directly see the result.\n\nAccuracy over 10 runs: 0.7828772 ± 0.001568163\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-products/graphsage/main.py",
    "content": "import argparse\nimport time\n\nimport dgl\nimport dgl.nn.pytorch as dglnn\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\nclass SAGE(nn.Module):\n    def __init__(\n        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, \"mean\"))\n        for i in range(1, n_layers - 1):\n            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, \"mean\"))\n        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, \"mean\"))\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            # We need to first copy the representation of nodes on the RHS from the\n            # appropriate nodes on the LHS.\n            # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst\n            # would be (num_nodes_RHS, D)\n            h_dst = h[: block.num_dst_nodes()]\n            # Then we compute the updated representation on the RHS.\n            # The shape of h now becomes (num_nodes_RHS, D)\n            h = layer(block, (h, h_dst))\n            if l != len(self.layers) - 1:\n                h = self.activation(h)\n                h = self.dropout(h)\n        return h\n\n    def inference(self, g, x, device):\n        \"\"\"\n        Inference with the GraphSAGE model on full neighbors (i.e. 
without neighbor sampling).\n        g : the entire graph.\n        x : the input of the entire node set.\n        The inference code is written in a fashion that it could handle any number of nodes and\n        layers.\n        \"\"\"\n        # During inference with sampling, multi-layer blocks are very inefficient because\n        # lots of computations in the first few layers are repeated.\n        # Therefore, we compute the representation of all nodes layer by layer.  The nodes\n        # on each layer are of course split into batches.\n        # TODO: can we standardize this?\n        for l, layer in enumerate(self.layers):\n            y = th.zeros(\n                g.num_nodes(),\n                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,\n            ).to(device)\n\n            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n            dataloader = dgl.dataloading.DataLoader(\n                g,\n                th.arange(g.num_nodes()),\n                sampler,\n                batch_size=args.batch_size,\n                shuffle=True,\n                drop_last=False,\n                num_workers=args.num_workers,\n            )\n\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                block = blocks[0].int().to(device)\n\n                h = x[input_nodes]\n                h_dst = h[: block.num_dst_nodes()]\n                h = layer(block, (h, h_dst))\n                if l != len(self.layers) - 1:\n                    h = self.activation(h)\n                    h = self.dropout(h)\n\n                y[output_nodes] = h\n\n            x = y\n        return y\n\n\ndef compute_acc(pred, labels):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)\n\n\ndef evaluate(model, g, nfeat, labels, val_nid, test_nid, device):\n    \"\"\"\n    Evaluate the model on the validation set specified 
by ``val_nid`` and the test set specified by ``test_nid``.\n    g : The entire graph.\n    nfeat : The features of all the nodes.\n    labels : The labels of all the nodes.\n    val_nid : The indices of the validation nodes to compute the accuracy for.\n    test_nid : The indices of the test nodes to compute the accuracy for.\n    device : The device to evaluate on.\n    \"\"\"\n    model.eval()\n    with th.no_grad():\n        pred = model.inference(g, nfeat, device)\n    model.train()\n    return (\n        compute_acc(pred[val_nid], labels[val_nid]),\n        compute_acc(pred[test_nid], labels[test_nid]),\n        pred,\n    )\n\n\ndef load_subtensor(nfeat, labels, seeds, input_nodes):\n    \"\"\"\n    Extracts features and labels for a set of nodes.\n    \"\"\"\n    batch_inputs = nfeat[input_nodes]\n    batch_labels = labels[seeds]\n    return batch_inputs, batch_labels\n\n\n#### Entry point\ndef run(args, device, data):\n    # Unpack data\n    train_nid, val_nid, test_nid, in_feats, labels, n_classes, nfeat, g = data\n\n    # Create PyTorch DataLoader for constructing blocks\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [int(fanout) for fanout in args.fan_out.split(\",\")]\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        train_nid,\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers,\n    )\n\n    # Define model and optimizer\n    model = SAGE(\n        in_feats,\n        args.num_hidden,\n        n_classes,\n        args.num_layers,\n        F.relu,\n        args.dropout,\n    )\n    model = model.to(device)\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)\n\n    # Training loop\n    avg = 0\n    iter_tput = []\n    best_eval_acc = 0\n    best_test_acc = 0\n    for epoch in range(args.num_epochs):\n        tic = time.time()\n\n        # Loop over the dataloader to sample the computation dependency graph as a list of\n        # blocks.\n        for step, 
(input_nodes, seeds, blocks) in enumerate(dataloader):\n            tic_step = time.time()\n\n            # copy block to gpu\n            blocks = [blk.int().to(device) for blk in blocks]\n\n            # Load the input features as well as output labels\n            batch_inputs, batch_labels = load_subtensor(\n                nfeat, labels, seeds, input_nodes\n            )\n\n            # Compute loss and prediction\n            batch_pred = model(blocks, batch_inputs)\n            loss = loss_fcn(batch_pred, batch_labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            iter_tput.append(len(seeds) / (time.time() - tic_step))\n            if step % args.log_every == 0:\n                acc = compute_acc(batch_pred, batch_labels)\n                gpu_mem_alloc = (\n                    th.cuda.max_memory_allocated() / 1000000\n                    if th.cuda.is_available()\n                    else 0\n                )\n                print(\n                    \"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB\".format(\n                        epoch,\n                        step,\n                        loss.item(),\n                        acc.item(),\n                        np.mean(iter_tput[3:]),\n                        gpu_mem_alloc,\n                    )\n                )\n\n        toc = time.time()\n        print(\"Epoch Time(s): {:.4f}\".format(toc - tic))\n        if epoch >= 5:\n            avg += toc - tic\n        if epoch % args.eval_every == 0 and epoch != 0:\n            eval_acc, test_acc, pred = evaluate(\n                model, g, nfeat, labels, val_nid, test_nid, device\n            )\n            if args.save_pred:\n                np.savetxt(\n                    args.save_pred + \"%02d\" % epoch,\n                    pred.argmax(1).cpu().numpy(),\n                    \"%d\",\n                )\n            print(\"Eval 
Acc {:.4f}\".format(eval_acc))\n            if eval_acc > best_eval_acc:\n                best_eval_acc = eval_acc\n                best_test_acc = test_acc\n            print(\n                \"Best Eval Acc {:.4f} Test Acc {:.4f}\".format(\n                    best_eval_acc, best_test_acc\n                )\n            )\n\n    print(\"Avg epoch time: {}\".format(avg / (epoch - 4)))\n    return best_test_acc\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"multi-gpu training\")\n    argparser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=0,\n        help=\"GPU device ID. Use -1 for CPU training\",\n    )\n    argparser.add_argument(\"--num-epochs\", type=int, default=20)\n    argparser.add_argument(\"--num-hidden\", type=int, default=256)\n    argparser.add_argument(\"--num-layers\", type=int, default=3)\n    argparser.add_argument(\"--fan-out\", type=str, default=\"5,10,15\")\n    argparser.add_argument(\"--batch-size\", type=int, default=1000)\n    argparser.add_argument(\"--val-batch-size\", type=int, default=10000)\n    argparser.add_argument(\"--log-every\", type=int, default=20)\n    argparser.add_argument(\"--eval-every\", type=int, default=1)\n    argparser.add_argument(\"--lr\", type=float, default=0.003)\n    argparser.add_argument(\"--dropout\", type=float, default=0.5)\n    argparser.add_argument(\n        \"--num-workers\",\n        type=int,\n        default=4,\n        help=\"Number of sampling processes. 
Use 0 for no extra process.\",\n    )\n    argparser.add_argument(\"--save-pred\", type=str, default=\"\")\n    argparser.add_argument(\"--wd\", type=float, default=0)\n    args = argparser.parse_args()\n\n    if args.gpu >= 0:\n        device = th.device(\"cuda:%d\" % args.gpu)\n    else:\n        device = th.device(\"cpu\")\n\n    # load ogbn-products data\n    data = DglNodePropPredDataset(name=\"ogbn-products\")\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n    nfeat = graph.ndata.pop(\"feat\").to(device)\n    labels = labels[:, 0].to(device)\n\n    in_feats = nfeat.shape[1]\n    n_classes = (labels.max() + 1).item()\n    # Create csr/coo/csc formats before launching sampling processes\n    # This avoids creating certain formats in each data loader process, which saves memory and CPU.\n    graph.create_formats_()\n    # Pack data\n    data = (\n        train_idx,\n        val_idx,\n        test_idx,\n        in_feats,\n        labels,\n        n_classes,\n        nfeat,\n        graph,\n    )\n\n    # Run 10 times\n    test_accs = []\n    for i in range(10):\n        test_accs.append(run(args, device, data).cpu().numpy())\n        print(\n            \"Average test accuracy:\", np.mean(test_accs), \"±\", np.std(test_accs)\n        )\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-products/mlp/README.md",
    "content": "# DGL examples for ogbn-products\n\nRequires DGL 0.5 or later versions.\n\nFor the score of `MLP`, run the following command and you should directly see the result.\n\n```bash\npython3 mlp.py --eval-last\n```\n\n## Results\n\nHere are the results over 10 runs.\n\n| Method | Validation Accuracy |  Test Accuracy  | #Parameters |\n|:------:|:-------------------:|:---------------:|:-----------:|\n|  MLP   |   0.7841 ± 0.0014   | 0.6320 ± 0.0013 |   535,727   |\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-products/mlp/mlp.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nimport argparse\nimport math\nimport random\nimport time\nfrom collections import OrderedDict\n\nimport dgl.function as fn\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dgl.dataloading import (\n    DataLoader,\n    MultiLayerFullNeighborSampler,\n    MultiLayerNeighborSampler,\n)\nfrom matplotlib.ticker import AutoMinorLocator, MultipleLocator\nfrom models import MLP\nfrom ogb.nodeproppred import DglNodePropPredDataset, Evaluator\nfrom torch import nn\nfrom tqdm import tqdm\n\nepsilon = 1 - math.log(2)\n\ndevice = None\ndataset = \"ogbn-products\"\nn_node_feats, n_edge_feats, n_classes = 0, 0, 0\n\n\ndef seed(seed=0):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n\ndef load_data(dataset):\n    data = DglNodePropPredDataset(name=dataset)\n    evaluator = Evaluator(name=dataset)\n\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n    graph.ndata[\"labels\"] = labels\n\n    return graph, labels, train_idx, val_idx, test_idx, evaluator\n\n\ndef preprocess(graph, labels):\n    global n_node_feats, n_classes\n    n_node_feats = graph.ndata[\"feat\"].shape[1]\n    n_classes = (labels.max() + 1).item()\n\n    # graph = graph.remove_self_loop().add_self_loop()\n    n_node_feats = graph.ndata[\"feat\"].shape[-1]\n\n    return graph, labels\n\n\ndef gen_model(args):\n    model = MLP(\n        n_node_feats,\n        n_classes,\n        n_layers=args.n_layers,\n        n_hidden=args.n_hidden,\n        activation=F.relu,\n        dropout=args.dropout,\n        
input_drop=args.input_drop,\n        residual=False,\n    )\n\n    return model\n\n\ndef custom_loss_function(x, labels):\n    y = F.cross_entropy(x, labels[:, 0], reduction=\"none\")\n    y = torch.log(epsilon + y) - math.log(epsilon)\n    return torch.mean(y)\n\n\ndef train(\n    args, model, dataloader, labels, train_idx, criterion, optimizer, evaluator\n):\n    model.train()\n\n    loss_sum, total = 0, 0\n\n    preds = torch.zeros(labels.shape[0], n_classes)\n\n    with dataloader.enable_cpu_affinity():\n        for _input_nodes, output_nodes, subgraphs in dataloader:\n            subgraphs = [b.to(device) for b in subgraphs]\n            new_train_idx = list(range(len(output_nodes)))\n\n            pred = model(subgraphs[0].srcdata[\"feat\"])\n            preds[output_nodes] = pred.cpu().detach()\n\n            loss = criterion(\n                pred[new_train_idx], labels[output_nodes][new_train_idx]\n            )\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            count = len(new_train_idx)\n            loss_sum += loss.item() * count\n            total += count\n\n    preds = preds.to(train_idx.device)\n    return (\n        loss_sum / total,\n        evaluator(preds[train_idx], labels[train_idx]),\n    )\n\n\n@torch.no_grad()\ndef evaluate(\n    args,\n    model,\n    dataloader,\n    labels,\n    train_idx,\n    val_idx,\n    test_idx,\n    criterion,\n    evaluator,\n):\n    model.eval()\n\n    preds = torch.zeros(labels.shape[0], n_classes, device=device)\n\n    eval_times = 1  # Due to the limitation of memory capacity, we calculate the average of logits 'eval_times' times.\n\n    for _ in range(eval_times):\n        with dataloader.enable_cpu_affinity():\n            for _input_nodes, output_nodes, subgraphs in dataloader:\n                subgraphs = [b.to(device) for b in subgraphs]\n\n                pred = model(subgraphs[0].srcdata[\"feat\"])\n                preds[output_nodes] = pred\n\n 
   preds /= eval_times\n\n    train_loss = criterion(preds[train_idx], labels[train_idx]).item()\n    val_loss = criterion(preds[val_idx], labels[val_idx]).item()\n    test_loss = criterion(preds[test_idx], labels[test_idx]).item()\n\n    return (\n        evaluator(preds[train_idx], labels[train_idx]),\n        evaluator(preds[val_idx], labels[val_idx]),\n        evaluator(preds[test_idx], labels[test_idx]),\n        train_loss,\n        val_loss,\n        test_loss,\n    )\n\n\ndef run(\n    args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running\n):\n    evaluator_wrapper = lambda pred, labels: evaluator.eval(\n        {\"y_pred\": pred.argmax(dim=-1, keepdim=True), \"y_true\": labels}\n    )[\"acc\"]\n    criterion = custom_loss_function\n\n    train_batch_size = 4096\n    train_sampler = MultiLayerNeighborSampler(\n        [0 for _ in range(args.n_layers)]\n    )  # do not sample neighbors\n    train_dataloader = DataLoader(\n        graph.cpu(),\n        train_idx.cpu(),\n        train_sampler,\n        batch_size=train_batch_size,\n        shuffle=True,\n        num_workers=4,\n    )\n\n    eval_batch_size = 4096\n    eval_sampler = MultiLayerNeighborSampler(\n        [0 for _ in range(args.n_layers)]\n    )  # do not sample neighbors\n    if args.eval_last:\n        eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu()])\n    else:\n        eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()])\n    eval_dataloader = DataLoader(\n        graph.cpu(),\n        eval_idx,\n        eval_sampler,\n        batch_size=eval_batch_size,\n        shuffle=False,\n        num_workers=4,\n    )\n\n    model = gen_model(args).to(device)\n\n    optimizer = optim.AdamW(\n        model.parameters(), lr=args.lr, weight_decay=args.wd\n    )\n    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n        optimizer,\n        mode=\"max\",\n        factor=0.7,\n        patience=20,\n        min_lr=1e-4,\n    )\n\n    best_model_state_dict = 
None\n\n    total_time = 0\n    val_score, best_val_score, final_test_score = 0, 0, 0\n\n    scores, train_scores, val_scores, test_scores = [], [], [], []\n    losses, train_losses, val_losses, test_losses = [], [], [], []\n\n    for epoch in range(1, args.n_epochs + 1):\n        tic = time.time()\n        loss, score = train(\n            args,\n            model,\n            train_dataloader,\n            labels,\n            train_idx,\n            criterion,\n            optimizer,\n            evaluator_wrapper,\n        )\n\n        toc = time.time()\n        total_time += toc - tic\n\n        if epoch % args.eval_every == 0 or epoch % args.log_every == 0:\n            (\n                train_score,\n                val_score,\n                test_score,\n                train_loss,\n                val_loss,\n                test_loss,\n            ) = evaluate(\n                args,\n                model,\n                eval_dataloader,\n                labels,\n                train_idx,\n                val_idx,\n                test_idx,\n                criterion,\n                evaluator_wrapper,\n            )\n\n            if val_score > best_val_score:\n                best_val_score = val_score\n                final_test_score = test_score\n                if args.eval_last:\n                    best_model_state_dict = {\n                        k: v.to(\"cpu\") for k, v in model.state_dict().items()\n                    }\n                    best_model_state_dict = OrderedDict(best_model_state_dict)\n\n            if epoch % args.log_every == 0:\n                print(\n                    f\"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch}\"\n                )\n                print(\n                    f\"Loss: {loss:.4f}, Score: {score:.4f}\\n\"\n                    f\"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\\n\"\n                    
f\"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}\"\n                )\n\n            for l, e in zip(\n                [\n                    scores,\n                    train_scores,\n                    val_scores,\n                    test_scores,\n                    losses,\n                    train_losses,\n                    val_losses,\n                    test_losses,\n                ],\n                [\n                    score,\n                    train_score,\n                    val_score,\n                    test_score,\n                    loss,\n                    train_loss,\n                    val_loss,\n                    test_loss,\n                ],\n            ):\n                l.append(e)\n\n        lr_scheduler.step(val_score)\n\n    if args.eval_last:\n        model.load_state_dict(best_model_state_dict)\n        eval_dataloader = DataLoader(\n            graph.cpu(),\n            test_idx.cpu(),\n            eval_sampler,\n            batch_size=eval_batch_size,\n            shuffle=False,\n            num_workers=4,\n        )\n        final_test_score = evaluate(\n            args,\n            model,\n            eval_dataloader,\n            labels,\n            train_idx,\n            val_idx,\n            test_idx,\n            criterion,\n            evaluator_wrapper,\n        )[2]\n\n    print(\"*\" * 50)\n    print(\n        f\"Average epoch time: {total_time / args.n_epochs}, Test score: {final_test_score}\"\n    )\n\n    if args.plot_curves:\n        fig = plt.figure(figsize=(24, 24))\n        ax = fig.gca()\n        ax.set_xticks(np.arange(0, args.n_epochs, 100))\n        ax.set_yticks(np.linspace(0, 1.0, 101))\n        ax.tick_params(labeltop=True, labelright=True)\n        for y, label in zip(\n            [train_scores, val_scores, test_scores],\n            [\"train score\", \"val score\", \"test score\"],\n        
):\n            plt.plot(\n                range(1, args.n_epochs + 1, args.log_every),\n                y,\n                label=label,\n                linewidth=1,\n            )\n        ax.xaxis.set_major_locator(MultipleLocator(20))\n        ax.xaxis.set_minor_locator(AutoMinorLocator(1))\n        ax.yaxis.set_major_locator(MultipleLocator(0.01))\n        ax.yaxis.set_minor_locator(AutoMinorLocator(2))\n        plt.grid(which=\"major\", color=\"red\", linestyle=\"dotted\")\n        plt.grid(which=\"minor\", color=\"orange\", linestyle=\"dotted\")\n        plt.legend()\n        plt.tight_layout()\n        plt.savefig(f\"gat_score_{n_running}.png\")\n\n        fig = plt.figure(figsize=(24, 24))\n        ax = fig.gca()\n        ax.set_xticks(np.arange(0, args.n_epochs, 100))\n        ax.tick_params(labeltop=True, labelright=True)\n        for y, label in zip(\n            [losses, train_losses, val_losses, test_losses],\n            [\"loss\", \"train loss\", \"val loss\", \"test loss\"],\n        ):\n            plt.plot(\n                range(1, args.n_epochs + 1, args.log_every),\n                y,\n                label=label,\n                linewidth=1,\n            )\n        ax.xaxis.set_major_locator(MultipleLocator(20))\n        ax.xaxis.set_minor_locator(AutoMinorLocator(1))\n        ax.yaxis.set_major_locator(MultipleLocator(0.1))\n        ax.yaxis.set_minor_locator(AutoMinorLocator(5))\n        plt.grid(which=\"major\", color=\"red\", linestyle=\"dotted\")\n        plt.grid(which=\"minor\", color=\"orange\", linestyle=\"dotted\")\n        plt.legend()\n        plt.tight_layout()\n        plt.savefig(f\"gat_loss_{n_running}.png\")\n\n    return best_val_score, final_test_score\n\n\ndef count_parameters(args):\n    model = gen_model(args)\n    return sum(\n        [np.prod(p.size()) for p in model.parameters() if p.requires_grad]\n    )\n\n\ndef main():\n    global device\n\n    argparser = argparse.ArgumentParser(\n        \"GAT on 
OGBN-Proteins\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n    argparser.add_argument(\n        \"--cpu\",\n        action=\"store_true\",\n        help=\"CPU mode. This option overrides '--gpu'.\",\n    )\n    argparser.add_argument(\"--gpu\", type=int, default=0, help=\"GPU device ID.\")\n    argparser.add_argument(\"--seed\", type=int, help=\"seed\", default=0)\n    argparser.add_argument(\"--n-runs\", type=int, default=10)\n    argparser.add_argument(\"--n-epochs\", type=int, default=500)\n    argparser.add_argument(\"--lr\", type=float, default=0.01)\n    argparser.add_argument(\"--n-layers\", type=int, default=4)\n    argparser.add_argument(\"--n-hidden\", type=int, default=480)\n    argparser.add_argument(\"--dropout\", type=float, default=0.2)\n    argparser.add_argument(\"--input-drop\", type=float, default=0)\n    argparser.add_argument(\"--wd\", type=float, default=0)\n    argparser.add_argument(\n        \"--estimation-mode\",\n        action=\"store_true\",\n        help=\"Estimate the score of test set for speed.\",\n    )\n    argparser.add_argument(\n        \"--eval-last\",\n        action=\"store_true\",\n        help=\"Evaluate the score of test set at last.\",\n    )\n    argparser.add_argument(\"--eval-every\", type=int, default=1)\n    argparser.add_argument(\"--log-every\", type=int, default=1)\n    argparser.add_argument(\"--plot-curves\", action=\"store_true\")\n    args = argparser.parse_args()\n\n    if args.cpu:\n        device = torch.device(\"cpu\")\n    else:\n        device = torch.device(\"cuda:%d\" % args.gpu)\n\n    if args.estimation_mode:\n        print(\n            \"WARNING: Estimation mode is enabled. 
The test score is not accurate.\"\n        )\n\n    seed(args.seed)\n\n    graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)\n    graph, labels = preprocess(graph, labels)\n    graph.create_formats_()\n\n    # graph = graph.to(device)\n    labels = labels.to(device)\n    train_idx = train_idx.to(device)\n    val_idx = val_idx.to(device)\n    test_idx = test_idx.to(device)\n    if args.estimation_mode:\n        test_idx = test_idx[torch.arange(start=0, end=len(test_idx), step=50)]\n\n    val_scores, test_scores = [], []\n\n    for i in range(1, args.n_runs + 1):\n        val_score, test_score = run(\n            args, graph, labels, train_idx, val_idx, test_idx, evaluator, i\n        )\n        val_scores.append(val_score)\n        test_scores.append(test_score)\n\n    print(args)\n    print(f\"Runned {args.n_runs} times\")\n    print(\"Val scores:\", val_scores)\n    print(\"Test scores:\", test_scores)\n    print(f\"Average val score: {np.mean(val_scores)} ± {np.std(val_scores)}\")\n    print(f\"Average test score: {np.mean(test_scores)} ± {np.std(test_scores)}\")\n    print(f\"Number of params: {count_parameters(args)}\")\n\n    if args.estimation_mode:\n        print(\n            \"WARNING: Estimation mode is enabled. 
The test score is not accurate.\"\n        )\n\n\nif __name__ == \"__main__\":\n    main()\n\n# Namespace(cpu=False, dropout=0.2, estimation_mode=False, eval_every=1, eval_last=True, gpu=2, input_drop=0, log_every=1, lr=0.01, n_epochs=500, n_hidden=480, n_layers=4, n_runs=10, plot_curves=True, seed=0, wd=0)\n# Runned 10 times\n# Val scores: [0.7846298603870508, 0.7811713246700405, 0.7828751621188618, 0.7839941001449533, 0.7843501258805279, 0.7841466826030568, 0.7846298603870508, 0.7865880019327112, 0.7832057574447524, 0.7851384685807289]\n# Test scores: [0.6318660190656417, 0.6304137516261193, 0.6329961126767946, 0.6312885462007662, 0.6340624944929965, 0.6301507710256831, 0.6314534738969161, 0.6334637843631373, 0.6312465235275007, 0.6329857199726536]\n# Average val score: 0.7840729344149735 ± 0.0013702460721628086\n# Average test score: 0.6319927196848208 ± 0.001252448369121226\n# Number of params: 535727\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-products/mlp/models.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass MLP(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        n_classes,\n        n_layers,\n        n_hidden,\n        activation,\n        dropout=0.0,\n        input_drop=0.0,\n        residual=False,\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n\n        self.linears = nn.ModuleList()\n        self.norms = nn.ModuleList()\n\n        for i in range(n_layers):\n            in_hidden = n_hidden if i > 0 else in_feats\n            out_hidden = n_hidden if i < n_layers - 1 else n_classes\n\n            self.linears.append(nn.Linear(in_hidden, out_hidden))\n\n            if i < n_layers - 1:\n                self.norms.append(nn.BatchNorm1d(out_hidden))\n\n        self.activation = activation\n        self.input_drop = nn.Dropout(input_drop)\n        self.dropout = nn.Dropout(dropout)\n        self.residual = residual\n\n    def forward(self, h):\n        h = self.input_drop(h)\n\n        h_last = None\n\n        for i in range(self.n_layers):\n            h = self.linears[i](h)\n\n            if self.residual and 0 < i < self.n_layers - 1:\n                h += h_last\n\n            h_last = h\n\n            if i < self.n_layers - 1:\n                h = self.norms[i](h)\n                h = self.activation(h, inplace=True)\n                h = self.dropout(h)\n\n        return h\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-proteins/README.md",
    "content": "# DGL for ogbn-proteins\n\n## GAT\n\nDGL implementation of GAT for [ogbn-proteins](https://ogb.stanford.edu/docs/nodeprop/). Using some of the techniques from *Bag of Tricks for Node Classification with Graph Neural Networks* ([https://arxiv.org/abs/2103.13355](https://arxiv.org/abs/2103.13355)).\n\nRequires DGL 0.5 or later versions.\n\n### Usage\n\nFor the best score, run `gat.py` and you should directly see the result.\n\n```bash\npython3 gat.py\n```\n\nFor the score of `GAT+labels`, run `gat.py` with `--use-labels` enabled and you should directly see the result.\n\n```bash\npython3 gat.py --use-labels\n```\n\n### Results\n\nHere are the results over 10 runs.\n\n|   Method   | Validation ROC-AUC |  Test ROC-AUC   | #Parameters |\n|:----------:|:------------------:|:---------------:|:-----------:|\n|    GAT     |  0.9276 ± 0.0007   | 0.8747 ± 0.0016 |  2,475,232  |\n| GAT+labels |  0.9280 ± 0.0008   | 0.8765 ± 0.0008 |  2,484,192  |\n\n## MWE-GCN and MWE-DGCN\n\n### Models\n[MWE-GCN and MWE-DGCN](https://cims.nyu.edu/~chenzh/files/GCN_with_edge_weights.pdf) are GCN models designed for graphs whose edges contain multi-dimensional edge weights that indicate the strengths of the relations represented by the edges.\n\n### Dependencies\n- DGL 0.5.2\n- PyTorch 1.4.0\n- OGB 1.2.0\n- Tensorboard 2.1.1\n\n### Usage\n\nTo use MWE-GCN:\n```python\npython main_proteins_full_dgl.py --model MWE-GCN\n```\n\nTo use MWE-DGCN:\n```python\npython main_proteins_full_dgl.py --model MWE-DGCN\n```\n\nAdditional optional arguments include 'rand_seed' (the random seed), 'cuda' (the cuda device number, if available), 'postfix' (a string appended to the saved-model file)\n\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-proteins/configure.py",
    "content": "\"\"\"Best hyperparameters found.\"\"\"\nimport torch\n\nMWE_GCN_proteins = {\n    \"num_ew_channels\": 8,\n    \"num_epochs\": 2000,\n    \"in_feats\": 1,\n    \"hidden_feats\": 10,\n    \"out_feats\": 112,\n    \"n_layers\": 3,\n    \"lr\": 2e-2,\n    \"weight_decay\": 0,\n    \"patience\": 1000,\n    \"dropout\": 0.2,\n    \"aggr_mode\": \"sum\",  ## 'sum' or 'concat' for the aggregation across channels\n    \"ewnorm\": \"both\",\n}\n\nMWE_DGCN_proteins = {\n    \"num_ew_channels\": 8,\n    \"num_epochs\": 2000,\n    \"in_feats\": 1,\n    \"hidden_feats\": 10,\n    \"out_feats\": 112,\n    \"n_layers\": 2,\n    \"lr\": 1e-2,\n    \"weight_decay\": 0,\n    \"patience\": 300,\n    \"dropout\": 0.5,\n    \"aggr_mode\": \"sum\",\n    \"residual\": True,\n    \"ewnorm\": \"none\",\n}\n\n\ndef get_exp_configure(args):\n    if args[\"model\"] == \"MWE-GCN\":\n        return MWE_GCN_proteins\n    elif args[\"model\"] == \"MWE-DGCN\":\n        return MWE_DGCN_proteins\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-proteins/gat.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nimport argparse\nimport os\nimport random\nimport sys\nimport time\n\nimport dgl\nimport dgl.function as fn\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dgl.dataloading import (\n    DataLoader,\n    MultiLayerFullNeighborSampler,\n    MultiLayerNeighborSampler,\n)\nfrom matplotlib.ticker import AutoMinorLocator, MultipleLocator\nfrom models import GAT\nfrom ogb.nodeproppred import DglNodePropPredDataset, Evaluator\nfrom torch import nn\n\ndevice = None\ndataset = \"ogbn-proteins\"\nn_node_feats, n_edge_feats, n_classes = 0, 8, 112\n\n\ndef seed(seed=0):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n    dgl.random.seed(seed)\n\n\ndef load_data(dataset):\n    data = DglNodePropPredDataset(name=dataset)\n    evaluator = Evaluator(name=dataset)\n\n    splitted_idx = data.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        splitted_idx[\"train\"],\n        splitted_idx[\"valid\"],\n        splitted_idx[\"test\"],\n    )\n    graph, labels = data[0]\n    graph.ndata[\"labels\"] = labels\n\n    return graph, labels, train_idx, val_idx, test_idx, evaluator\n\n\ndef preprocess(graph, labels, train_idx):\n    global n_node_feats\n\n    # The sum of the weights of adjacent edges is used as node features.\n    graph.update_all(\n        fn.copy_e(\"feat\", \"feat_copy\"), fn.sum(\"feat_copy\", \"feat\")\n    )\n    n_node_feats = graph.ndata[\"feat\"].shape[-1]\n\n    # Only the labels in the training set are used as features, while others are filled with zeros.\n    graph.ndata[\"train_labels_onehot\"] = torch.zeros(\n        graph.num_nodes(), n_classes\n    )\n    graph.ndata[\"train_labels_onehot\"][train_idx, 
labels[train_idx, 0]] = 1\n    graph.ndata[\"deg\"] = graph.out_degrees().float().clamp(min=1)\n\n    graph.create_formats_()\n\n    return graph, labels\n\n\ndef gen_model(args):\n    if args.use_labels:\n        n_node_feats_ = n_node_feats + n_classes\n    else:\n        n_node_feats_ = n_node_feats\n\n    model = GAT(\n        n_node_feats_,\n        n_edge_feats,\n        n_classes,\n        n_layers=args.n_layers,\n        n_heads=args.n_heads,\n        n_hidden=args.n_hidden,\n        edge_emb=16,\n        activation=F.relu,\n        dropout=args.dropout,\n        input_drop=args.input_drop,\n        attn_drop=args.attn_drop,\n        edge_drop=args.edge_drop,\n        use_attn_dst=not args.no_attn_dst,\n    )\n\n    return model\n\n\ndef add_labels(graph, idx):\n    feat = graph.srcdata[\"feat\"]\n    train_labels_onehot = torch.zeros([feat.shape[0], n_classes], device=device)\n    train_labels_onehot[idx] = graph.srcdata[\"train_labels_onehot\"][idx]\n    graph.srcdata[\"feat\"] = torch.cat([feat, train_labels_onehot], dim=-1)\n\n\ndef train(\n    args,\n    model,\n    dataloader,\n    _labels,\n    _train_idx,\n    criterion,\n    optimizer,\n    _evaluator,\n):\n    model.train()\n\n    loss_sum, total = 0, 0\n\n    for input_nodes, output_nodes, subgraphs in dataloader:\n        subgraphs = [b.to(device) for b in subgraphs]\n        new_train_idx = torch.arange(len(output_nodes), device=device)\n\n        if args.use_labels:\n            train_labels_idx = torch.arange(\n                len(output_nodes), len(input_nodes), device=device\n            )\n            train_pred_idx = new_train_idx\n\n            add_labels(subgraphs[0], train_labels_idx)\n        else:\n            train_pred_idx = new_train_idx\n\n        pred = model(subgraphs)\n        loss = criterion(\n            pred[train_pred_idx],\n            subgraphs[-1].dstdata[\"labels\"][train_pred_idx].float(),\n        )\n        optimizer.zero_grad()\n        loss.backward()\n        
optimizer.step()\n\n        count = len(train_pred_idx)\n        loss_sum += loss.item() * count\n        total += count\n\n        # torch.cuda.empty_cache()\n\n    return loss_sum / total\n\n\n@torch.no_grad()\ndef evaluate(\n    args,\n    model,\n    dataloader,\n    labels,\n    train_idx,\n    val_idx,\n    test_idx,\n    criterion,\n    evaluator,\n):\n    model.eval()\n\n    preds = torch.zeros(labels.shape).to(device)\n\n    # Due to the memory capacity constraints, we use sampling for inference and calculate the average of the predictions 'eval_times' times.\n    eval_times = 1\n\n    for _ in range(eval_times):\n        for input_nodes, output_nodes, subgraphs in dataloader:\n            subgraphs = [b.to(device) for b in subgraphs]\n            new_train_idx = list(range(len(input_nodes)))\n\n            if args.use_labels:\n                add_labels(subgraphs[0], new_train_idx)\n\n            pred = model(subgraphs)\n            preds[output_nodes] += pred\n\n            # torch.cuda.empty_cache()\n\n    preds /= eval_times\n\n    train_loss = criterion(preds[train_idx], labels[train_idx].float()).item()\n    val_loss = criterion(preds[val_idx], labels[val_idx].float()).item()\n    test_loss = criterion(preds[test_idx], labels[test_idx].float()).item()\n\n    return (\n        evaluator(preds[train_idx], labels[train_idx]),\n        evaluator(preds[val_idx], labels[val_idx]),\n        evaluator(preds[test_idx], labels[test_idx]),\n        train_loss,\n        val_loss,\n        test_loss,\n        preds,\n    )\n\n\ndef run(\n    args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running\n):\n    evaluator_wrapper = lambda pred, labels: evaluator.eval(\n        {\"y_pred\": pred, \"y_true\": labels}\n    )[\"rocauc\"]\n\n    train_batch_size = (len(train_idx) + 9) // 10\n    # batch_size = len(train_idx)\n    train_sampler = MultiLayerNeighborSampler(\n        [32 for _ in range(args.n_layers)]\n    )\n    # sampler = 
MultiLayerFullNeighborSampler(args.n_layers)\n    train_dataloader = DataLoader(\n        graph.cpu(),\n        train_idx.cpu(),\n        train_sampler,\n        batch_size=train_batch_size,\n        num_workers=10,\n    )\n\n    eval_sampler = MultiLayerNeighborSampler(\n        [100 for _ in range(args.n_layers)]\n    )\n    # sampler = MultiLayerFullNeighborSampler(args.n_layers)\n    eval_dataloader = DataLoader(\n        graph.cpu(),\n        torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]),\n        eval_sampler,\n        batch_size=65536,\n        num_workers=10,\n    )\n\n    criterion = nn.BCEWithLogitsLoss()\n\n    model = gen_model(args).to(device)\n\n    optimizer = optim.AdamW(\n        model.parameters(), lr=args.lr, weight_decay=args.wd\n    )\n    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n        optimizer, mode=\"max\", factor=0.75, patience=50\n    )\n\n    total_time = 0\n    val_score, best_val_score, final_test_score = 0, 0, 0\n\n    train_scores, val_scores, test_scores = [], [], []\n    losses, train_losses, val_losses, test_losses = [], [], [], []\n    final_pred = None\n\n    for epoch in range(1, args.n_epochs + 1):\n        tic = time.time()\n\n        loss = train(\n            args,\n            model,\n            train_dataloader,\n            labels,\n            train_idx,\n            criterion,\n            optimizer,\n            evaluator_wrapper,\n        )\n\n        toc = time.time()\n        total_time += toc - tic\n\n        if (\n            epoch == args.n_epochs\n            or epoch % args.eval_every == 0\n            or epoch % args.log_every == 0\n        ):\n            (\n                train_score,\n                val_score,\n                test_score,\n                train_loss,\n                val_loss,\n                test_loss,\n                pred,\n            ) = evaluate(\n                args,\n                model,\n                eval_dataloader,\n                labels,\n 
               train_idx,\n                val_idx,\n                test_idx,\n                criterion,\n                evaluator_wrapper,\n            )\n\n            if val_score > best_val_score:\n                best_val_score = val_score\n                final_test_score = test_score\n                final_pred = pred\n\n            if epoch % args.log_every == 0:\n                print(\n                    f\"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}s\"\n                )\n                print(\n                    f\"Loss: {loss:.4f}\\n\"\n                    f\"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\\n\"\n                    f\"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}\"\n                )\n\n            for l, e in zip(\n                [\n                    train_scores,\n                    val_scores,\n                    test_scores,\n                    losses,\n                    train_losses,\n                    val_losses,\n                    test_losses,\n                ],\n                [\n                    train_score,\n                    val_score,\n                    test_score,\n                    loss,\n                    train_loss,\n                    val_loss,\n                    test_loss,\n                ],\n            ):\n                l.append(e)\n\n        lr_scheduler.step(val_score)\n\n    print(\"*\" * 50)\n    print(\n        f\"Best val score: {best_val_score}, Final test score: {final_test_score}\"\n    )\n    print(\"*\" * 50)\n\n    if args.plot:\n        fig = plt.figure(figsize=(24, 24))\n        ax = fig.gca()\n        ax.set_xticks(np.arange(0, args.n_epochs, 100))\n        ax.set_yticks(np.linspace(0, 1.0, 101))\n        ax.tick_params(labeltop=True, labelright=True)\n        for y, label in zip(\n          
  [train_scores, val_scores, test_scores],\n            [\"train score\", \"val score\", \"test score\"],\n        ):\n            plt.plot(\n                range(1, args.n_epochs + 1, args.log_every),\n                y,\n                label=label,\n                linewidth=1,\n            )\n        ax.xaxis.set_major_locator(MultipleLocator(100))\n        ax.xaxis.set_minor_locator(AutoMinorLocator(1))\n        ax.yaxis.set_major_locator(MultipleLocator(0.01))\n        ax.yaxis.set_minor_locator(AutoMinorLocator(2))\n        plt.grid(which=\"major\", color=\"red\", linestyle=\"dotted\")\n        plt.grid(which=\"minor\", color=\"orange\", linestyle=\"dotted\")\n        plt.legend()\n        plt.tight_layout()\n        plt.savefig(f\"gat_score_{n_running}.png\")\n\n        fig = plt.figure(figsize=(24, 24))\n        ax = fig.gca()\n        ax.set_xticks(np.arange(0, args.n_epochs, 100))\n        ax.tick_params(labeltop=True, labelright=True)\n        for y, label in zip(\n            [losses, train_losses, val_losses, test_losses],\n            [\"loss\", \"train loss\", \"val loss\", \"test loss\"],\n        ):\n            plt.plot(\n                range(1, args.n_epochs + 1, args.log_every),\n                y,\n                label=label,\n                linewidth=1,\n            )\n        ax.xaxis.set_major_locator(MultipleLocator(100))\n        ax.xaxis.set_minor_locator(AutoMinorLocator(1))\n        ax.yaxis.set_major_locator(MultipleLocator(0.1))\n        ax.yaxis.set_minor_locator(AutoMinorLocator(5))\n        plt.grid(which=\"major\", color=\"red\", linestyle=\"dotted\")\n        plt.grid(which=\"minor\", color=\"orange\", linestyle=\"dotted\")\n        plt.legend()\n        plt.tight_layout()\n        plt.savefig(f\"gat_loss_{n_running}.png\")\n\n    if args.save_pred:\n        os.makedirs(\"./output\", exist_ok=True)\n        torch.save(F.softmax(final_pred, dim=1), f\"./output/{n_running}.pt\")\n\n    return best_val_score, 
final_test_score\n\n\ndef count_parameters(args):\n    model = gen_model(args)\n    return sum(\n        [np.prod(p.size()) for p in model.parameters() if p.requires_grad]\n    )\n\n\ndef main():\n    global device\n\n    argparser = argparse.ArgumentParser(\n        \"GAT implementation on ogbn-proteins\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n    argparser.add_argument(\n        \"--cpu\",\n        action=\"store_true\",\n        help=\"CPU mode. This option overrides '--gpu'.\",\n    )\n    argparser.add_argument(\"--gpu\", type=int, default=0, help=\"GPU device ID\")\n    argparser.add_argument(\"--seed\", type=int, default=0, help=\"random seed\")\n    argparser.add_argument(\n        \"--n-runs\", type=int, default=10, help=\"running times\"\n    )\n    argparser.add_argument(\n        \"--n-epochs\", type=int, default=1200, help=\"number of epochs\"\n    )\n    argparser.add_argument(\n        \"--use-labels\",\n        action=\"store_true\",\n        help=\"Use labels in the training set as input features.\",\n    )\n    argparser.add_argument(\n        \"--no-attn-dst\", action=\"store_true\", help=\"Don't use attn_dst.\"\n    )\n    argparser.add_argument(\n        \"--n-heads\", type=int, default=6, help=\"number of heads\"\n    )\n    argparser.add_argument(\n        \"--lr\", type=float, default=0.01, help=\"learning rate\"\n    )\n    argparser.add_argument(\n        \"--n-layers\", type=int, default=6, help=\"number of layers\"\n    )\n    argparser.add_argument(\n        \"--n-hidden\", type=int, default=80, help=\"number of hidden units\"\n    )\n    argparser.add_argument(\n        \"--dropout\", type=float, default=0.25, help=\"dropout rate\"\n    )\n    argparser.add_argument(\n        \"--input-drop\", type=float, default=0.1, help=\"input drop rate\"\n    )\n    argparser.add_argument(\n        \"--attn-drop\", type=float, default=0.0, help=\"attention dropout rate\"\n    )\n    argparser.add_argument(\n      
  \"--edge-drop\", type=float, default=0.1, help=\"edge drop rate\"\n    )\n    argparser.add_argument(\"--wd\", type=float, default=0, help=\"weight decay\")\n    argparser.add_argument(\n        \"--eval-every\",\n        type=int,\n        default=5,\n        help=\"evaluate every EVAL_EVERY epochs\",\n    )\n    argparser.add_argument(\n        \"--log-every\", type=int, default=5, help=\"log every LOG_EVERY epochs\"\n    )\n    argparser.add_argument(\n        \"--plot\", action=\"store_true\", help=\"plot learning curves\"\n    )\n    argparser.add_argument(\n        \"--save-pred\", action=\"store_true\", help=\"save final predictions\"\n    )\n    args = argparser.parse_args()\n\n    if args.cpu:\n        device = torch.device(\"cpu\")\n    else:\n        device = torch.device(f\"cuda:{args.gpu}\")\n\n    # load data & preprocess\n    print(\"Loading data\")\n    graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)\n    print(\"Preprocessing\")\n    graph, labels = preprocess(graph, labels, train_idx)\n\n    labels, train_idx, val_idx, test_idx = map(\n        lambda x: x.to(device), (labels, train_idx, val_idx, test_idx)\n    )\n\n    # run\n    val_scores, test_scores = [], []\n\n    for i in range(args.n_runs):\n        print(\"Running\", i)\n        seed(args.seed + i)\n        val_score, test_score = run(\n            args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1\n        )\n        val_scores.append(val_score)\n        test_scores.append(test_score)\n\n    print(\" \".join(sys.argv))\n    print(args)\n    print(f\"Runned {args.n_runs} times\")\n    print(\"Val scores:\", val_scores)\n    print(\"Test scores:\", test_scores)\n    print(f\"Average val score: {np.mean(val_scores)} ± {np.std(val_scores)}\")\n    print(f\"Average test score: {np.mean(test_scores)} ± {np.std(test_scores)}\")\n    print(f\"Number of params: {count_parameters(args)}\")\n\n\nif __name__ == \"__main__\":\n    main()\n\n# 
Namespace(attn_drop=0.0, cpu=False, dropout=0.25, edge_drop=0.1, eval_every=5, gpu=6, input_drop=0.1, log_every=5, lr=0.01, n_epochs=1200, n_heads=6, n_hidden=80, n_layers=6, n_runs=10, no_attn_dst=False, plot=True, save_pred=False, seed=0, use_labels=False, wd=0)\n# Runned 10 times\n# Val scores: [0.927741031859485, 0.9272113161947824, 0.9271363901359605, 0.9275579074100136, 0.9264291968462317, 0.9275278541203443, 0.9286381790529751, 0.9288245051991526, 0.9269289529175155, 0.9278177920224489]\n# Test scores: [0.8754403567694566, 0.8749781870941457, 0.8735933245353141, 0.8759835445000637, 0.8745950242855286, 0.8742530369108132, 0.8784892022402326, 0.873345314887444, 0.8724393129004984, 0.874077975765639]\n# Average val score: 0.927581312575891 ± 0.0006953509986591492\n# Average test score: 0.8747195279889135 ± 0.001593598488797452\n# Number of params: 2475232\n\n# Namespace(attn_drop=0.0, cpu=False, dropout=0.25, edge_drop=0.1, eval_every=5, gpu=7, input_drop=0.1, log_every=5, lr=0.01, n_epochs=1200, n_heads=6, n_hidden=80, n_layers=6, n_runs=10, no_attn_dst=False, plot=True, save_pred=False, seed=0, use_labels=True, wd=0)\n# Runned 10 times\n# Val scores: [0.9293776332568928, 0.9281066322254939, 0.9286775378440911, 0.9270252685136046, 0.9267937838323375, 0.9277731792338011, 0.9285615428437761, 0.9270819730221879, 0.9276822010553241, 0.9287115722177839]\n# Test scores: [0.8761623033485811, 0.8773002619440896, 0.8756680817047869, 0.8751873860287073, 0.875781797307807, 0.8764533839446703, 0.8771202308989311, 0.8765888651476396, 0.8773581283481205, 0.8777751912293709]\n# Average val score: 0.9279791324045293 ± 0.0008115348697502517\n# Average test score: 0.8765395629902706 ± 0.0008016806017700173\n# Number of params: 2484192\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-proteins/main_proteins_full_dgl.py",
    "content": "import os\nimport time\n\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom ogb.nodeproppred import Evaluator\nfrom ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset\nfrom torch.optim import Adam\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\nfrom torch.utils.tensorboard import SummaryWriter\nfrom utils import load_model, set_random_seed\n\n\ndef normalize_edge_weights(graph, device, num_ew_channels):\n    degs = graph.in_degrees().float()\n    degs = torch.clamp(degs, min=1)\n    norm = torch.pow(degs, 0.5)\n    norm = norm.to(args[\"device\"])\n    graph.ndata[\"norm\"] = norm.unsqueeze(1)\n    graph.apply_edges(fn.e_div_u(\"feat\", \"norm\", \"feat\"))\n    graph.apply_edges(fn.e_div_v(\"feat\", \"norm\", \"feat\"))\n    for channel in range(num_ew_channels):\n        graph.edata[\"feat_\" + str(channel)] = graph.edata[\"feat\"][\n            :, channel : channel + 1\n        ]\n\n\ndef run_a_train_epoch(graph, node_idx, model, criterion, optimizer, evaluator):\n    model.train()\n    logits = model(graph)[node_idx]\n    labels = graph.ndata[\"labels\"][node_idx]\n    loss = criterion(logits, labels)\n    optimizer.zero_grad()\n    loss.backward()\n    optimizer.step()\n\n    loss = loss.data.item()\n    labels = labels.cpu().numpy()\n    preds = logits.cpu().detach().numpy()\n\n    return loss, evaluator.eval({\"y_true\": labels, \"y_pred\": preds})[\"rocauc\"]\n\n\ndef run_an_eval_epoch(graph, splitted_idx, model, evaluator):\n    model.eval()\n    with torch.no_grad():\n        logits = model(graph)\n    labels = graph.ndata[\"labels\"].cpu().numpy()\n    preds = logits.cpu().detach().numpy()\n\n    train_score = evaluator.eval(\n        {\n            \"y_true\": labels[splitted_idx[\"train\"]],\n            \"y_pred\": preds[splitted_idx[\"train\"]],\n        }\n    )\n    val_score = evaluator.eval(\n        {\n            \"y_true\": 
labels[splitted_idx[\"valid\"]],\n            \"y_pred\": preds[splitted_idx[\"valid\"]],\n        }\n    )\n    test_score = evaluator.eval(\n        {\n            \"y_true\": labels[splitted_idx[\"test\"]],\n            \"y_pred\": preds[splitted_idx[\"test\"]],\n        }\n    )\n\n    return train_score[\"rocauc\"], val_score[\"rocauc\"], test_score[\"rocauc\"]\n\n\ndef main(args):\n    print(args)\n    if args[\"rand_seed\"] > -1:\n        set_random_seed(args[\"rand_seed\"])\n\n    dataset = DglNodePropPredDataset(name=args[\"dataset\"])\n    print(dataset.meta_info)\n    splitted_idx = dataset.get_idx_split()\n    graph = dataset.graph[0]\n    graph.ndata[\"labels\"] = dataset.labels.float().to(args[\"device\"])\n    graph.edata[\"feat\"] = graph.edata[\"feat\"].float().to(args[\"device\"])\n\n    if args[\"ewnorm\"] == \"both\":\n        print(\"Symmetric normalization of edge weights by degree\")\n        normalize_edge_weights(graph, args[\"device\"], args[\"num_ew_channels\"])\n    elif args[\"ewnorm\"] == \"none\":\n        print(\"Not normalizing edge weights\")\n        for channel in range(args[\"num_ew_channels\"]):\n            graph.edata[\"feat_\" + str(channel)] = graph.edata[\"feat\"][\n                :, channel : channel + 1\n            ]\n\n    model = load_model(args).to(args[\"device\"])\n    optimizer = Adam(\n        model.parameters(), lr=args[\"lr\"], weight_decay=args[\"weight_decay\"]\n    )\n    min_lr = 1e-3\n    scheduler = ReduceLROnPlateau(\n        optimizer, \"max\", factor=0.7, patience=100, verbose=True, min_lr=min_lr\n    )\n    print(\"scheduler min_lr\", min_lr)\n\n    criterion = nn.BCEWithLogitsLoss()\n    evaluator = Evaluator(args[\"dataset\"])\n\n    print(\"model\", args[\"model\"])\n    print(\"n_layers\", args[\"n_layers\"])\n    print(\"hidden dim\", args[\"hidden_feats\"])\n    print(\"lr\", args[\"lr\"])\n\n    dur = []\n    best_val_score = 0.0\n    num_patient_epochs = 0\n    model_folder = 
\"./saved_models/\"\n    model_path = (\n        model_folder + str(args[\"exp_name\"]) + \"_\" + str(args[\"postfix\"])\n    )\n\n    if not os.path.exists(model_folder):\n        os.makedirs(model_folder)\n\n    for epoch in range(1, args[\"num_epochs\"] + 1):\n        if epoch >= 3:\n            t0 = time.time()\n\n        loss, train_score = run_a_train_epoch(\n            graph, splitted_idx[\"train\"], model, criterion, optimizer, evaluator\n        )\n\n        if epoch >= 3:\n            dur.append(time.time() - t0)\n            avg_time = np.mean(dur)\n        else:\n            avg_time = None\n\n        train_score, val_score, test_score = run_an_eval_epoch(\n            graph, splitted_idx, model, evaluator\n        )\n\n        scheduler.step(val_score)\n\n        # Early stop\n        if val_score > best_val_score:\n            torch.save(model.state_dict(), model_path)\n            best_val_score = val_score\n            num_patient_epochs = 0\n        else:\n            num_patient_epochs += 1\n\n        print(\n            \"Epoch {:d}, loss {:.4f}, train score {:.4f}, \"\n            \"val score {:.4f}, avg time {}, num patient epochs {:d}\".format(\n                epoch,\n                loss,\n                train_score,\n                val_score,\n                avg_time,\n                num_patient_epochs,\n            )\n        )\n\n        if num_patient_epochs == args[\"patience\"]:\n            break\n\n    model.load_state_dict(torch.load(model_path, weights_only=False))\n    train_score, val_score, test_score = run_an_eval_epoch(\n        graph, splitted_idx, model, evaluator\n    )\n    print(\"Train score {:.4f}\".format(train_score))\n    print(\"Valid score {:.4f}\".format(val_score))\n    print(\"Test score {:.4f}\".format(test_score))\n\n    with open(\"results.txt\", \"w\") as f:\n        f.write(\"loss {:.4f}\\n\".format(loss))\n        f.write(\"Best validation rocauc {:.4f}\\n\".format(best_val_score))\n        
f.write(\"Test rocauc {:.4f}\\n\".format(test_score))\n\n    print(args)\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    from configure import get_exp_configure\n\n    parser = argparse.ArgumentParser(\n        description=\"OGB node property prediction with DGL using full graph training\"\n    )\n    parser.add_argument(\n        \"-m\",\n        \"--model\",\n        type=str,\n        choices=[\"MWE-GCN\", \"MWE-DGCN\"],\n        default=\"MWE-DGCN\",\n        help=\"Model to use\",\n    )\n    parser.add_argument(\"-c\", \"--cuda\", type=str, default=\"none\")\n    parser.add_argument(\n        \"--postfix\",\n        type=str,\n        default=\"\",\n        help=\"a string appended to the file name of the saved model\",\n    )\n    parser.add_argument(\n        \"--rand_seed\",\n        type=int,\n        default=-1,\n        help=\"random seed for torch and numpy\",\n    )\n    parser.add_argument(\"--residual\", action=\"store_true\")\n    parser.add_argument(\n        \"--ewnorm\", type=str, default=\"none\", choices=[\"none\", \"both\"]\n    )\n    args = parser.parse_args().__dict__\n\n    # Get experiment configuration\n    args[\"dataset\"] = \"ogbn-proteins\"\n    args[\"exp_name\"] = \"_\".join([args[\"model\"], args[\"dataset\"]])\n    args.update(get_exp_configure(args))\n\n    if not (args[\"cuda\"] == \"none\"):\n        args[\"device\"] = torch.device(\"cuda: \" + str(args[\"cuda\"]))\n    else:\n        args[\"device\"] = torch.device(\"cpu\")\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-proteins/models.py",
    "content": "import math\nfrom functools import partial\n\nimport dgl.function as fn\nimport dgl.nn.pytorch as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import function as fn\nfrom dgl._ffi.base import DGLError\nfrom dgl.base import ALL\nfrom dgl.nn.pytorch.utils import Identity\nfrom dgl.ops import edge_softmax\nfrom dgl.utils import expand_as_pair\nfrom torch.nn import init\nfrom torch.utils.checkpoint import checkpoint\n\n\nclass MWEConv(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        activation,\n        bias=True,\n        num_channels=8,\n        aggr_mode=\"sum\",\n    ):\n        super(MWEConv, self).__init__()\n        self.num_channels = num_channels\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self.weight = nn.Parameter(\n            torch.Tensor(in_feats, out_feats, num_channels)\n        )\n\n        if bias:\n            self.bias = nn.Parameter(torch.Tensor(out_feats, num_channels))\n        else:\n            self.bias = None\n        self.reset_parameters()\n        self.activation = activation\n\n        if aggr_mode == \"concat\":\n            self.aggr_mode = \"concat\"\n            self.final = nn.Linear(out_feats * self.num_channels, out_feats)\n        elif aggr_mode == \"sum\":\n            self.aggr_mode = \"sum\"\n            self.final = nn.Linear(out_feats, out_feats)\n\n    def reset_parameters(self):\n        stdv = 1.0 / math.sqrt(self.weight.size(1))\n        self.weight.data.uniform_(-stdv, stdv)\n        if self.bias is not None:\n            stdv = 1.0 / math.sqrt(self.bias.size(0))\n            self.bias.data.uniform_(-stdv, stdv)\n\n    def forward(self, g, node_state_prev):\n        node_state = node_state_prev\n\n        # if self.dropout:\n        #     node_states = self.dropout(node_state)\n\n        g = g.local_var()\n\n        new_node_states = []\n\n        ## perform weighted convolution for 
every channel of edge weight\n        for c in range(self.num_channels):\n            node_state_c = node_state\n            if self._out_feats < self._in_feats:\n                g.ndata[\"feat_\" + str(c)] = torch.mm(\n                    node_state_c, self.weight[:, :, c]\n                )\n            else:\n                g.ndata[\"feat_\" + str(c)] = node_state_c\n            g.update_all(\n                fn.u_mul_e(\"feat_\" + str(c), \"feat_\" + str(c), \"m\"),\n                fn.sum(\"m\", \"feat_\" + str(c) + \"_new\"),\n            )\n            node_state_c = g.ndata.pop(\"feat_\" + str(c) + \"_new\")\n            if self._out_feats >= self._in_feats:\n                node_state_c = torch.mm(node_state_c, self.weight[:, :, c])\n            if self.bias is not None:\n                node_state_c = node_state_c + self.bias[:, c]\n            node_state_c = self.activation(node_state_c)\n            new_node_states.append(node_state_c)\n        if self.aggr_mode == \"sum\":\n            node_states = torch.stack(new_node_states, dim=1).sum(1)\n        elif self.aggr_mode == \"concat\":\n            node_states = torch.cat(new_node_states, dim=1)\n\n        node_states = self.final(node_states)\n\n        return node_states\n\n\nclass MWE_GCN(nn.Module):\n    def __init__(\n        self,\n        n_input,\n        n_hidden,\n        n_output,\n        n_layers,\n        activation,\n        dropout,\n        aggr_mode=\"sum\",\n        device=\"cpu\",\n    ):\n        super(MWE_GCN, self).__init__()\n        self.dropout = dropout\n        self.activation = activation\n        self.layers = nn.ModuleList()\n\n        self.layers.append(\n            MWEConv(\n                n_input, n_hidden, activation=activation, aggr_mode=aggr_mode\n            )\n        )\n        for i in range(n_layers - 1):\n            self.layers.append(\n                MWEConv(\n                    n_hidden,\n                    n_hidden,\n                    
activation=activation,\n                    aggr_mode=aggr_mode,\n                )\n            )\n\n        self.pred_out = nn.Linear(n_hidden, n_output)\n        self.device = device\n\n    def forward(self, g, node_state=None):\n        node_state = torch.ones(g.num_nodes(), 1).float().to(self.device)\n\n        for layer in self.layers:\n            node_state = F.dropout(\n                node_state, p=self.dropout, training=self.training\n            )\n            node_state = layer(g, node_state)\n            node_state = self.activation(node_state)\n\n        out = self.pred_out(node_state)\n        return out\n\n\nclass MWE_DGCN(nn.Module):\n    def __init__(\n        self,\n        n_input,\n        n_hidden,\n        n_output,\n        n_layers,\n        activation,\n        dropout,\n        residual=False,\n        aggr_mode=\"sum\",\n        device=\"cpu\",\n    ):\n        super(MWE_DGCN, self).__init__()\n        self.n_layers = n_layers\n        self.activation = activation\n        self.dropout = dropout\n        self.residual = residual\n\n        self.layers = nn.ModuleList()\n        self.layer_norms = nn.ModuleList()\n\n        self.layers.append(\n            MWEConv(\n                n_input, n_hidden, activation=activation, aggr_mode=aggr_mode\n            )\n        )\n\n        for i in range(n_layers - 1):\n            self.layers.append(\n                MWEConv(\n                    n_hidden,\n                    n_hidden,\n                    activation=activation,\n                    aggr_mode=aggr_mode,\n                )\n            )\n\n        for i in range(n_layers):\n            self.layer_norms.append(\n                nn.LayerNorm(n_hidden, elementwise_affine=True)\n            )\n\n        self.pred_out = nn.Linear(n_hidden, n_output)\n        self.device = device\n\n    def forward(self, g, node_state=None):\n        node_state = torch.ones(g.num_nodes(), 1).float().to(self.device)\n\n        node_state = 
self.layers[0](g, node_state)\n\n        for layer in range(1, self.n_layers):\n            node_state_new = self.layer_norms[layer - 1](node_state)\n            node_state_new = self.activation(node_state_new)\n            node_state_new = F.dropout(\n                node_state_new, p=self.dropout, training=self.training\n            )\n\n            if self.residual == \"true\":\n                node_state = node_state + self.layers[layer](g, node_state_new)\n            else:\n                node_state = self.layers[layer](g, node_state_new)\n\n        node_state = self.layer_norms[self.n_layers - 1](node_state)\n        node_state = self.activation(node_state)\n        node_state = F.dropout(\n            node_state, p=self.dropout, training=self.training\n        )\n\n        out = self.pred_out(node_state)\n\n        return out\n\n\nclass GATConv(nn.Module):\n    def __init__(\n        self,\n        node_feats,\n        edge_feats,\n        out_feats,\n        n_heads=1,\n        attn_drop=0.0,\n        edge_drop=0.0,\n        negative_slope=0.2,\n        residual=True,\n        activation=None,\n        use_attn_dst=True,\n        allow_zero_in_degree=True,\n        use_symmetric_norm=False,\n    ):\n        super(GATConv, self).__init__()\n        self._n_heads = n_heads\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(node_feats)\n        self._out_feats = out_feats\n        self._allow_zero_in_degree = allow_zero_in_degree\n        self._use_symmetric_norm = use_symmetric_norm\n\n        # feat fc\n        self.src_fc = nn.Linear(\n            self._in_src_feats, out_feats * n_heads, bias=False\n        )\n        if residual:\n            self.dst_fc = nn.Linear(self._in_src_feats, out_feats * n_heads)\n            self.bias = None\n        else:\n            self.dst_fc = None\n            self.bias = nn.Parameter(out_feats * n_heads)\n\n        # attn fc\n        self.attn_src_fc = nn.Linear(self._in_src_feats, n_heads, bias=False)\n  
      if use_attn_dst:\n            self.attn_dst_fc = nn.Linear(\n                self._in_src_feats, n_heads, bias=False\n            )\n        else:\n            self.attn_dst_fc = None\n        if edge_feats > 0:\n            self.attn_edge_fc = nn.Linear(edge_feats, n_heads, bias=False)\n        else:\n            self.attn_edge_fc = None\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.edge_drop = edge_drop\n        self.leaky_relu = nn.LeakyReLU(negative_slope, inplace=True)\n        self.activation = activation\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_normal_(self.src_fc.weight, gain=gain)\n        if self.dst_fc is not None:\n            nn.init.xavier_normal_(self.dst_fc.weight, gain=gain)\n\n        nn.init.xavier_normal_(self.attn_src_fc.weight, gain=gain)\n        if self.attn_dst_fc is not None:\n            nn.init.xavier_normal_(self.attn_dst_fc.weight, gain=gain)\n        if self.attn_edge_fc is not None:\n            nn.init.xavier_normal_(self.attn_edge_fc.weight, gain=gain)\n\n        if self.bias is not None:\n            nn.init.zeros_(self.bias)\n\n    def set_allow_zero_in_degree(self, set_value):\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat_src, feat_edge=None):\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    assert False\n\n            if graph.is_block:\n                feat_dst = feat_src[: graph.number_of_dst_nodes()]\n            else:\n                feat_dst = feat_src\n\n            if self._use_symmetric_norm:\n                degs = graph.srcdata[\"deg\"]\n                # degs = graph.out_degrees().float().clamp(min=1)\n                norm = torch.pow(degs, -0.5)\n                shp = norm.shape + (1,) * (feat_src.dim() - 1)\n                norm = 
torch.reshape(norm, shp)\n                feat_src = feat_src * norm\n\n            feat_src_fc = self.src_fc(feat_src).view(\n                -1, self._n_heads, self._out_feats\n            )\n            feat_dst_fc = self.dst_fc(feat_dst).view(\n                -1, self._n_heads, self._out_feats\n            )\n            attn_src = self.attn_src_fc(feat_src).view(-1, self._n_heads, 1)\n\n            # NOTE: GAT paper uses \"first concatenation then linear projection\"\n            # to compute attention scores, while ours is \"first projection then\n            # addition\", the two approaches are mathematically equivalent:\n            # We decompose the weight vector a mentioned in the paper into\n            # [a_l || a_r], then\n            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j\n            # Our implementation is much efficient because we do not need to\n            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,\n            # addition could be optimized with DGL's built-in function u_add_v,\n            # which further speeds up computation and saves memory footprint.\n            graph.srcdata.update(\n                {\"feat_src_fc\": feat_src_fc, \"attn_src\": attn_src}\n            )\n\n            if self.attn_dst_fc is not None:\n                attn_dst = self.attn_dst_fc(feat_dst).view(-1, self._n_heads, 1)\n                graph.dstdata.update({\"attn_dst\": attn_dst})\n                graph.apply_edges(\n                    fn.u_add_v(\"attn_src\", \"attn_dst\", \"attn_node\")\n                )\n            else:\n                graph.apply_edges(fn.copy_u(\"attn_src\", \"attn_node\"))\n\n            e = graph.edata[\"attn_node\"]\n            if feat_edge is not None:\n                attn_edge = self.attn_edge_fc(feat_edge).view(\n                    -1, self._n_heads, 1\n                )\n                graph.edata.update({\"attn_edge\": attn_edge})\n                e += graph.edata[\"attn_edge\"]\n            
e = self.leaky_relu(e)\n\n            if self.training and self.edge_drop > 0:\n                perm = torch.randperm(graph.num_edges(), device=e.device)\n                bound = int(graph.num_edges() * self.edge_drop)\n                eids = perm[bound:]\n                graph.edata[\"a\"] = torch.zeros_like(e)\n                graph.edata[\"a\"][eids] = self.attn_drop(\n                    edge_softmax(graph, e[eids], eids=eids)\n                )\n            else:\n                graph.edata[\"a\"] = self.attn_drop(edge_softmax(graph, e))\n\n            # message passing\n            graph.update_all(\n                fn.u_mul_e(\"feat_src_fc\", \"a\", \"m\"), fn.sum(\"m\", \"feat_src_fc\")\n            )\n\n            rst = graph.dstdata[\"feat_src_fc\"]\n\n            if self._use_symmetric_norm:\n                degs = graph.dstdata[\"deg\"]\n                # degs = graph.in_degrees().float().clamp(min=1)\n                norm = torch.pow(degs, 0.5)\n                shp = norm.shape + (1,) * (feat_dst.dim())\n                norm = torch.reshape(norm, shp)\n                rst = rst * norm\n\n            # residual\n            if self.dst_fc is not None:\n                rst += feat_dst_fc\n            else:\n                rst += self.bias\n\n            # activation\n            if self.activation is not None:\n                rst = self.activation(rst, inplace=True)\n\n            return rst\n\n\nclass GAT(nn.Module):\n    def __init__(\n        self,\n        node_feats,\n        edge_feats,\n        n_classes,\n        n_layers,\n        n_heads,\n        n_hidden,\n        edge_emb,\n        activation,\n        dropout,\n        input_drop,\n        attn_drop,\n        edge_drop,\n        use_attn_dst=True,\n        allow_zero_in_degree=False,\n    ):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_heads = n_heads\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n\n        self.convs = 
nn.ModuleList()\n        self.norms = nn.ModuleList()\n\n        self.node_encoder = nn.Linear(node_feats, n_hidden)\n        if edge_emb > 0:\n            self.edge_encoder = nn.ModuleList()\n\n        for i in range(n_layers):\n            in_hidden = n_heads * n_hidden if i > 0 else n_hidden\n            out_hidden = n_hidden\n            # bias = i == n_layers - 1\n\n            if edge_emb > 0:\n                self.edge_encoder.append(nn.Linear(edge_feats, edge_emb))\n            self.convs.append(\n                GATConv(\n                    in_hidden,\n                    edge_emb,\n                    out_hidden,\n                    n_heads=n_heads,\n                    attn_drop=attn_drop,\n                    edge_drop=edge_drop,\n                    use_attn_dst=use_attn_dst,\n                    allow_zero_in_degree=allow_zero_in_degree,\n                    use_symmetric_norm=False,\n                )\n            )\n            self.norms.append(nn.BatchNorm1d(n_heads * out_hidden))\n\n        self.pred_linear = nn.Linear(n_heads * n_hidden, n_classes)\n\n        self.input_drop = nn.Dropout(input_drop)\n        self.dropout = nn.Dropout(dropout)\n        self.activation = activation\n\n    def forward(self, g):\n        if not isinstance(g, list):\n            subgraphs = [g] * self.n_layers\n        else:\n            subgraphs = g\n\n        h = subgraphs[0].srcdata[\"feat\"]\n        h = self.node_encoder(h)\n        h = F.relu(h, inplace=True)\n        h = self.input_drop(h)\n\n        h_last = None\n\n        for i in range(self.n_layers):\n            if self.edge_encoder is not None:\n                efeat = subgraphs[i].edata[\"feat\"]\n                efeat_emb = self.edge_encoder[i](efeat)\n                efeat_emb = F.relu(efeat_emb, inplace=True)\n            else:\n                efeat_emb = None\n\n            h = self.convs[i](subgraphs[i], h, efeat_emb).flatten(1, -1)\n\n            if h_last is not None:\n                h += 
h_last[: h.shape[0], :]\n\n            h_last = h\n\n            h = self.norms[i](h)\n            h = self.activation(h, inplace=True)\n            h = self.dropout(h)\n\n        h = self.pred_linear(h)\n\n        return h\n"
  },
  {
    "path": "examples/pytorch/ogb/ogbn-proteins/utils.py",
    "content": "import random\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom models import MWE_DGCN, MWE_GCN\n\n\ndef set_random_seed(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed(seed)\n    print(\"random seed set to be \" + str(seed))\n\n\ndef load_model(args):\n    if args[\"model\"] == \"MWE-GCN\":\n        model = MWE_GCN(\n            n_input=args[\"in_feats\"],\n            n_hidden=args[\"hidden_feats\"],\n            n_output=args[\"out_feats\"],\n            n_layers=args[\"n_layers\"],\n            activation=torch.nn.Tanh(),\n            dropout=args[\"dropout\"],\n            aggr_mode=args[\"aggr_mode\"],\n            device=args[\"device\"],\n        )\n    elif args[\"model\"] == \"MWE-DGCN\":\n        model = MWE_DGCN(\n            n_input=args[\"in_feats\"],\n            n_hidden=args[\"hidden_feats\"],\n            n_output=args[\"out_feats\"],\n            n_layers=args[\"n_layers\"],\n            activation=torch.nn.ReLU(),\n            dropout=args[\"dropout\"],\n            aggr_mode=args[\"aggr_mode\"],\n            residual=args[\"residual\"],\n            device=args[\"device\"],\n        )\n    else:\n        raise ValueError(\"Unexpected model {}\".format(args[\"model\"]))\n\n    return model\n\n\nclass Logger(object):\n    def __init__(self, runs, info=None):\n        self.info = info\n        self.results = [[] for _ in range(runs)]\n\n    def add_result(self, run, result):\n        assert len(result) == 3\n        assert run >= 0 and run < len(self.results)\n        self.results[run].append(result)\n\n    def print_statistics(self, run=None):\n        if run is not None:\n            result = 100 * torch.tensor(self.results[run])\n            argmax = result[:, 1].argmax().item()\n            print(f\"Run {run + 1:02d}:\")\n            print(f\"Highest Train: {result[:, 0].max():.2f}\")\n            
print(f\"Highest Valid: {result[:, 1].max():.2f}\")\n            print(f\"  Final Train: {result[argmax, 0]:.2f}\")\n            print(f\"   Final Test: {result[argmax, 2]:.2f}\")\n        else:\n            result = 100 * torch.tensor(self.results)\n\n            best_results = []\n            for r in result:\n                train1 = r[:, 0].max().item()\n                valid = r[:, 1].max().item()\n                train2 = r[r[:, 1].argmax(), 0].item()\n                test = r[r[:, 1].argmax(), 2].item()\n                best_results.append((train1, valid, train2, test))\n\n            best_result = torch.tensor(best_results)\n\n            print(f\"All runs:\")\n            r = best_result[:, 0]\n            print(f\"Highest Train: {r.mean():.2f} ± {r.std():.2f}\")\n            r = best_result[:, 1]\n            print(f\"Highest Valid: {r.mean():.2f} ± {r.std():.2f}\")\n            r = best_result[:, 2]\n            print(f\"  Final Train: {r.mean():.2f} ± {r.std():.2f}\")\n            r = best_result[:, 3]\n            print(f\"   Final Test: {r.mean():.2f} ± {r.std():.2f}\")\n"
  },
  {
    "path": "examples/pytorch/ogb/seal_ogbl/README.md",
    "content": "# SEAL Implementation for OGBL in DGL\n\nIntroduction\n------------\nThis is an example of implementing [SEAL](https://arxiv.org/pdf/2010.16103.pdf) for link prediction in DGL. Some parts are migrated from [https://github.com/facebookresearch/SEAL_OGB](https://github.com/facebookresearch/SEAL_OGB).\n\nRequirements\n------------\n[PyTorch](https://pytorch.org/), [DGL](https://www.dgl.ai/), [OGB](https://ogb.stanford.edu/docs/home/), and other python libraries: numpy, scipy, tqdm, scikit-learn, etc.\n\nUsages\n------\nRun the following command for results on each benchmark\n```bash\n# ogbl-ppa\npython main.py \\\n    --dataset ogbl-ppa \\\n    --use_feature \\\n    --use_edge_weight \\\n    --eval_steps 5 \\\n    --epochs 20 \\\n    --train_percent 5 \n\n# ogbl-collab\npython main.py \\\n    --dataset ogbl-collab \\\n    --train_percent 15 \\\n    --hidden_channels 256 \\\n    --use_valedges_as_input\n\n# ogbl-ddi\npython main.py \\\n    --dataset ogbl-ddi \\\n    --ratio_per_hop 0.2 \\\n    --use_edge_weight \\\n    --eval_steps 1 \\\n    --epochs 10 \\\n    --train_percent 5\n\n# ogbl-citation2\npython main.py \\\n    --dataset ogbl-citation2 \\\n    --use_feature \\\n    --use_edge_weight \\\n    --eval_steps 1 \\\n    --epochs 10 \\\n    --train_percent 2 \\\n    --val_percent 1 \\\n    --test_percent 1\n```\n\nResults\n-------\n\n|              | ogbl-ppa (Hits@100) | ogbl-collab (Hits@50) | ogbl-ddi (Hits@20) | ogbl-citation2 (MRRd) |\n|--------------|---------------------|-----------------------|--------------------|---------------------|\n| Paper Test Results |  48.80%&plusmn;3.16% |    64.74%&plusmn;0.43% | 30.56%&plusmn;3.86%* |   87.67%&plusmn;0.32r% |\n| Our Test Results |  49.48%&plusmn;2.52% |    64.23%&plusmn;0.57% | 27.93%&plusmn;4.19% |   86.29%&plusmn;0.47% |\n\n\\* Note that the relatively large gap on ogbl-ddi may come from the high variance of results on this dataset. 
We get 28.77%&plusmn;3.43% by only changing the sampling seed.\n\nReference\n---------\n\n    @article{zhang2021labeling,\n        title={Labeling Trick: A Theory of Using Graph Neural Networks for Multi-Node Representation Learning},\n        author={Zhang, Muhan and Li, Pan and Xia, Yinglong and Wang, Kai and Jin, Long},\n        journal={Advances in Neural Information Processing Systems},\n        volume={34},\n        year={2021}\n        }\n\n    @inproceedings{zhang2018link,\n      title={Link prediction based on graph neural networks},\n      author={Zhang, Muhan and Chen, Yixin},\n      booktitle={Advances in Neural Information Processing Systems},\n      pages={5165--5175},\n      year={2018}\n    }\n"
  },
  {
    "path": "examples/pytorch/ogb/seal_ogbl/main.py",
    "content": "import argparse\nimport math\nimport os\nimport random\nimport sys\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom dgl.dataloading import DataLoader, Sampler\nfrom dgl.nn import GraphConv, SortPooling\nfrom dgl.sampling import global_uniform_negative_sampling\nfrom ogb.linkproppred import DglLinkPropPredDataset, Evaluator\nfrom scipy.sparse.csgraph import shortest_path\nfrom torch.nn import (\n    BCEWithLogitsLoss,\n    Conv1d,\n    Embedding,\n    Linear,\n    MaxPool1d,\n    ModuleList,\n)\nfrom tqdm import tqdm\n\n\nclass Logger(object):\n    def __init__(self, runs, info=None):\n        self.info = info\n        self.results = [[] for _ in range(runs)]\n\n    def add_result(self, run, result):\n        # result is in the format of (val_score, test_score)\n        assert len(result) == 2\n        assert run >= 0 and run < len(self.results)\n        self.results[run].append(result)\n\n    def print_statistics(self, run=None, f=sys.stdout):\n        if run is not None:\n            result = 100 * torch.tensor(self.results[run])\n            argmax = result[:, 0].argmax().item()\n            print(f\"Run {run + 1:02d}:\", file=f)\n            print(f\"Highest Valid: {result[:, 0].max():.2f}\", file=f)\n            print(f\"Highest Eval Point: {argmax + 1}\", file=f)\n            print(f\"   Final Test: {result[argmax, 1]:.2f}\", file=f)\n        else:\n            result = 100 * torch.tensor(self.results)\n\n            best_results = []\n            for r in result:\n                valid = r[:, 0].max().item()\n                test = r[r[:, 0].argmax(), 1].item()\n                best_results.append((valid, test))\n\n            best_result = torch.tensor(best_results)\n\n            print(f\"All runs:\", file=f)\n            r = best_result[:, 0]\n            print(f\"Highest Valid: {r.mean():.2f} ± {r.std():.2f}\", file=f)\n            r = best_result[:, 1]\n            print(f\"   Final 
Test: {r.mean():.2f} ± {r.std():.2f}\", file=f)\n\n\nclass SealSampler(Sampler):\n    def __init__(\n        self,\n        g,\n        num_hops=1,\n        sample_ratio=1.0,\n        directed=False,\n        prefetch_node_feats=None,\n        prefetch_edge_feats=None,\n    ):\n        super().__init__()\n        self.g = g\n        self.num_hops = num_hops\n        self.sample_ratio = sample_ratio\n        self.directed = directed\n        self.prefetch_node_feats = prefetch_node_feats\n        self.prefetch_edge_feats = prefetch_edge_feats\n\n    def _double_radius_node_labeling(self, adj):\n        N = adj.shape[0]\n        adj_wo_src = adj[range(1, N), :][:, range(1, N)]\n        idx = list(range(1)) + list(range(2, N))\n        adj_wo_dst = adj[idx, :][:, idx]\n\n        dist2src = shortest_path(\n            adj_wo_dst, directed=False, unweighted=True, indices=0\n        )\n        dist2src = np.insert(dist2src, 1, 0, axis=0)\n        dist2src = torch.from_numpy(dist2src)\n\n        dist2dst = shortest_path(\n            adj_wo_src, directed=False, unweighted=True, indices=0\n        )\n        dist2dst = np.insert(dist2dst, 0, 0, axis=0)\n        dist2dst = torch.from_numpy(dist2dst)\n\n        dist = dist2src + dist2dst\n        dist_over_2, dist_mod_2 = (\n            torch.div(dist, 2, rounding_mode=\"floor\"),\n            dist % 2,\n        )\n\n        z = 1 + torch.min(dist2src, dist2dst)\n        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)\n        z[0:2] = 1.0\n        # shortest path may include inf values\n        z[torch.isnan(z)] = 0.0\n\n        return z.to(torch.long)\n\n    def sample(self, aug_g, seed_edges):\n        g = self.g\n        subgraphs = []\n        # construct k-hop enclosing graph for each link\n        for eid in seed_edges:\n            src, dst = map(int, aug_g.find_edges(eid))\n            # construct the enclosing graph\n            visited, nodes, fringe = [np.unique([src, dst]) for _ in range(3)]\n            for _ 
in range(self.num_hops):\n                if not self.directed:\n                    _, fringe = g.out_edges(fringe)\n                else:\n                    _, out_neighbors = g.out_edges(fringe)\n                    in_neighbors, _ = g.in_edges(fringe)\n                    fringe = np.union1d(in_neighbors, out_neighbors)\n                fringe = np.setdiff1d(fringe, visited)\n                visited = np.union1d(visited, fringe)\n                if self.sample_ratio < 1.0:\n                    fringe = np.random.choice(\n                        fringe,\n                        int(self.sample_ratio * len(fringe)),\n                        replace=False,\n                    )\n                if len(fringe) == 0:\n                    break\n                nodes = np.union1d(nodes, fringe)\n            subg = g.subgraph(nodes, store_ids=True)\n\n            # remove edges to predict\n            edges_to_remove = [\n                subg.edge_ids(s, t)\n                for s, t in [(0, 1), (1, 0)]\n                if subg.has_edges_between(s, t)\n            ]\n            subg.remove_edges(edges_to_remove)\n            # add double radius node labeling\n            subg.ndata[\"z\"] = self._double_radius_node_labeling(\n                subg.adj_external(scipy_fmt=\"csr\")\n            )\n            subg_aug = subg.add_self_loop()\n            if \"weight\" in subg.edata:\n                subg_aug.edata[\"weight\"][subg.num_edges() :] = torch.ones(\n                    subg_aug.num_edges() - subg.num_edges()\n                )\n            subgraphs.append(subg_aug)\n\n        subgraphs = dgl.batch(subgraphs)\n        dgl.set_src_lazy_features(subg_aug, self.prefetch_node_feats)\n        dgl.set_edge_lazy_features(subg_aug, self.prefetch_edge_feats)\n\n        return subgraphs, aug_g.edata[\"y\"][seed_edges]\n\n\n# An end-to-end deep learning architecture for graph classification, AAAI-18.\nclass DGCNN(torch.nn.Module):\n    def __init__(\n        self, 
hidden_channels, num_layers, k, GNN=GraphConv, feature_dim=0\n    ):\n        super(DGCNN, self).__init__()\n        self.feature_dim = feature_dim\n        self.k = k\n        self.sort_pool = SortPooling(k=k)\n\n        self.max_z = 1000\n        self.z_embedding = Embedding(self.max_z, hidden_channels)\n\n        self.convs = ModuleList()\n        initial_channels = hidden_channels + self.feature_dim\n\n        self.convs.append(GNN(initial_channels, hidden_channels))\n        for _ in range(0, num_layers - 1):\n            self.convs.append(GNN(hidden_channels, hidden_channels))\n        self.convs.append(GNN(hidden_channels, 1))\n\n        conv1d_channels = [16, 32]\n        total_latent_dim = hidden_channels * num_layers + 1\n        conv1d_kws = [total_latent_dim, 5]\n        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])\n        self.maxpool1d = MaxPool1d(2, 2)\n        self.conv2 = Conv1d(\n            conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1\n        )\n        dense_dim = int((self.k - 2) / 2 + 1)\n        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]\n        self.lin1 = Linear(dense_dim, 128)\n        self.lin2 = Linear(128, 1)\n\n    def forward(self, g, z, x=None, edge_weight=None):\n        z_emb = self.z_embedding(z)\n        if z_emb.ndim == 3:  # in case z has multiple integer labels\n            z_emb = z_emb.sum(dim=1)\n        if x is not None:\n            x = torch.cat([z_emb, x.to(torch.float)], 1)\n        else:\n            x = z_emb\n        xs = [x]\n\n        for conv in self.convs:\n            xs += [torch.tanh(conv(g, xs[-1], edge_weight=edge_weight))]\n        x = torch.cat(xs[1:], dim=-1)\n\n        # global pooling\n        x = self.sort_pool(g, x)\n        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]\n        x = F.relu(self.conv1(x))\n        x = self.maxpool1d(x)\n        x = F.relu(self.conv2(x))\n        x = x.view(x.size(0), -1)  # [num_graphs, 
dense_dim]\n\n        # MLP.\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return x\n\n\ndef get_pos_neg_edges(split, split_edge, g, percent=100):\n    pos_edge = split_edge[split][\"edge\"]\n    if split == \"train\":\n        neg_edge = torch.stack(\n            global_uniform_negative_sampling(\n                g, num_samples=pos_edge.size(0), exclude_self_loops=True\n            ),\n            dim=1,\n        )\n    else:\n        neg_edge = split_edge[split][\"edge_neg\"]\n\n    # sampling according to the percent param\n    np.random.seed(123)\n    # pos sampling\n    num_pos = pos_edge.size(0)\n    perm = np.random.permutation(num_pos)\n    perm = perm[: int(percent / 100 * num_pos)]\n    pos_edge = pos_edge[perm]\n    # neg sampling\n    if neg_edge.dim() > 2:  # [Np, Nn, 2]\n        neg_edge = neg_edge[perm].view(-1, 2)\n    else:\n        np.random.seed(123)\n        num_neg = neg_edge.size(0)\n        perm = np.random.permutation(num_neg)\n        perm = perm[: int(percent / 100 * num_neg)]\n        neg_edge = neg_edge[perm]\n\n    return pos_edge, neg_edge  # ([2, Np], [2, Nn]) -> ([Np, 2], [Nn, 2])\n\n\ndef train():\n    model.train()\n    loss_fnt = BCEWithLogitsLoss()\n    total_loss = 0\n    total = 0\n    pbar = tqdm(train_loader, ncols=70)\n    for gs, y in pbar:\n        optimizer.zero_grad()\n        logits = model(\n            gs,\n            gs.ndata[\"z\"],\n            gs.ndata.get(\"feat\", None),\n            edge_weight=gs.edata.get(\"weight\", None),\n        )\n        loss = loss_fnt(logits.view(-1), y.to(torch.float))\n        loss.backward()\n        optimizer.step()\n        total_loss += loss.item() * gs.batch_size\n        total += gs.batch_size\n\n    return total_loss / total\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n\n    y_pred, y_true = [], []\n    for gs, y in tqdm(val_loader, ncols=70):\n        logits = model(\n            gs,\n  
          gs.ndata[\"z\"],\n            gs.ndata.get(\"feat\", None),\n            edge_weight=gs.edata.get(\"weight\", None),\n        )\n        y_pred.append(logits.view(-1).cpu())\n        y_true.append(y.view(-1).cpu().to(torch.float))\n    val_pred, val_true = torch.cat(y_pred), torch.cat(y_true)\n    pos_val_pred = val_pred[val_true == 1]\n    neg_val_pred = val_pred[val_true == 0]\n\n    y_pred, y_true = [], []\n    for gs, y in tqdm(test_loader, ncols=70):\n        logits = model(\n            gs,\n            gs.ndata[\"z\"],\n            gs.ndata.get(\"feat\", None),\n            edge_weight=gs.edata.get(\"weight\", None),\n        )\n        y_pred.append(logits.view(-1).cpu())\n        y_true.append(y.view(-1).cpu().to(torch.float))\n    test_pred, test_true = torch.cat(y_pred), torch.cat(y_true)\n    pos_test_pred = test_pred[test_true == 1]\n    neg_test_pred = test_pred[test_true == 0]\n\n    if args.eval_metric == \"hits\":\n        results = evaluate_hits(\n            pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred\n        )\n    elif args.eval_metric == \"mrr\":\n        results = evaluate_mrr(\n            pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred\n        )\n\n    return results\n\n\ndef evaluate_hits(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):\n    results = {}\n    for K in [20, 50, 100]:\n        evaluator.K = K\n        valid_hits = evaluator.eval(\n            {\n                \"y_pred_pos\": pos_val_pred,\n                \"y_pred_neg\": neg_val_pred,\n            }\n        )[f\"hits@{K}\"]\n        test_hits = evaluator.eval(\n            {\n                \"y_pred_pos\": pos_test_pred,\n                \"y_pred_neg\": neg_test_pred,\n            }\n        )[f\"hits@{K}\"]\n\n        results[f\"Hits@{K}\"] = (valid_hits, test_hits)\n\n    return results\n\n\ndef evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):\n    print(\n        pos_val_pred.size(),\n        
neg_val_pred.size(),\n        pos_test_pred.size(),\n        neg_test_pred.size(),\n    )\n    neg_val_pred = neg_val_pred.view(pos_val_pred.shape[0], -1)\n    neg_test_pred = neg_test_pred.view(pos_test_pred.shape[0], -1)\n    results = {}\n    valid_mrr = (\n        evaluator.eval(\n            {\n                \"y_pred_pos\": pos_val_pred,\n                \"y_pred_neg\": neg_val_pred,\n            }\n        )[\"mrr_list\"]\n        .mean()\n        .item()\n    )\n\n    test_mrr = (\n        evaluator.eval(\n            {\n                \"y_pred_pos\": pos_test_pred,\n                \"y_pred_neg\": neg_test_pred,\n            }\n        )[\"mrr_list\"]\n        .mean()\n        .item()\n    )\n\n    results[\"MRR\"] = (valid_mrr, test_mrr)\n\n    return results\n\n\nif __name__ == \"__main__\":\n    # Data settings\n    parser = argparse.ArgumentParser(description=\"OGBL (SEAL)\")\n    parser.add_argument(\"--dataset\", type=str, default=\"ogbl-collab\")\n    # GNN settings\n    parser.add_argument(\"--sortpool_k\", type=float, default=0.6)\n    parser.add_argument(\"--num_layers\", type=int, default=3)\n    parser.add_argument(\"--hidden_channels\", type=int, default=32)\n    parser.add_argument(\"--batch_size\", type=int, default=32)\n    # Subgraph extraction settings\n    parser.add_argument(\"--ratio_per_hop\", type=float, default=1.0)\n    parser.add_argument(\n        \"--use_feature\",\n        action=\"store_true\",\n        help=\"whether to use raw node features as GNN input\",\n    )\n    parser.add_argument(\n        \"--use_edge_weight\",\n        action=\"store_true\",\n        help=\"whether to consider edge weight in GNN\",\n    )\n    # Training settings\n    parser.add_argument(\"--lr\", type=float, default=0.0001)\n    parser.add_argument(\"--epochs\", type=int, default=50)\n    parser.add_argument(\"--runs\", type=int, default=10)\n    parser.add_argument(\"--train_percent\", type=float, default=100)\n    
parser.add_argument(\"--val_percent\", type=float, default=100)\n    parser.add_argument(\"--test_percent\", type=float, default=100)\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=8,\n        help=\"number of workers for dynamic dataloaders\",\n    )\n    # Testing settings\n    parser.add_argument(\"--use_valedges_as_input\", action=\"store_true\")\n    parser.add_argument(\"--eval_steps\", type=int, default=1)\n    args = parser.parse_args()\n\n    data_appendix = \"_rph{}\".format(\"\".join(str(args.ratio_per_hop).split(\".\")))\n    if args.use_valedges_as_input:\n        data_appendix += \"_uvai\"\n\n    args.res_dir = os.path.join(\n        \"results/{}_{}\".format(args.dataset, time.strftime(\"%Y%m%d%H%M%S\"))\n    )\n    print(\"Results will be saved in \" + args.res_dir)\n    if not os.path.exists(args.res_dir):\n        os.makedirs(args.res_dir)\n    log_file = os.path.join(args.res_dir, \"log.txt\")\n    # Save command line input.\n    cmd_input = \"python \" + \" \".join(sys.argv) + \"\\n\"\n    with open(os.path.join(args.res_dir, \"cmd_input.txt\"), \"a\") as f:\n        f.write(cmd_input)\n    print(\"Command line input: \" + cmd_input + \" is saved.\")\n    with open(log_file, \"a\") as f:\n        f.write(\"\\n\" + cmd_input)\n\n    dataset = DglLinkPropPredDataset(name=args.dataset)\n    split_edge = dataset.get_edge_split()\n    graph = dataset[0]\n\n    # re-format the data of citation2\n    if args.dataset == \"ogbl-citation2\":\n        for k in [\"train\", \"valid\", \"test\"]:\n            src = split_edge[k][\"source_node\"]\n            tgt = split_edge[k][\"target_node\"]\n            split_edge[k][\"edge\"] = torch.stack([src, tgt], dim=1)\n            if k != \"train\":\n                tgt_neg = split_edge[k][\"target_node_neg\"]\n                split_edge[k][\"edge_neg\"] = torch.stack(\n                    [src[:, None].repeat(1, tgt_neg.size(1)), tgt_neg], dim=-1\n                )  # 
[Ns, Nt, 2]\n\n    # reconstruct the graph for ogbl-collab data for validation edge augmentation and coalesce\n    if args.dataset == \"ogbl-collab\":\n        graph.edata.pop(\"year\")\n        # float edata for to_simple transform\n        graph.edata[\"weight\"] = graph.edata[\"weight\"].to(torch.float)\n        if args.use_valedges_as_input:\n            val_edges = split_edge[\"valid\"][\"edge\"]\n            row, col = val_edges.t()\n            val_weights = torch.ones(size=(val_edges.size(0), 1))\n            graph.add_edges(\n                torch.cat([row, col]),\n                torch.cat([col, row]),\n                {\"weight\": val_weights},\n            )\n        graph = graph.to_simple(copy_edata=True, aggregator=\"sum\")\n\n    if not args.use_edge_weight and \"weight\" in graph.edata:\n        graph.edata.pop(\"weight\")\n    if not args.use_feature and \"feat\" in graph.ndata:\n        graph.ndata.pop(\"feat\")\n\n    if args.dataset.startswith(\"ogbl-citation\"):\n        args.eval_metric = \"mrr\"\n        directed = True\n    else:\n        args.eval_metric = \"hits\"\n        directed = False\n\n    evaluator = Evaluator(name=args.dataset)\n    if args.eval_metric == \"hits\":\n        loggers = {\n            \"Hits@20\": Logger(args.runs, args),\n            \"Hits@50\": Logger(args.runs, args),\n            \"Hits@100\": Logger(args.runs, args),\n        }\n    elif args.eval_metric == \"mrr\":\n        loggers = {\n            \"MRR\": Logger(args.runs, args),\n        }\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    path = dataset.root + \"_seal{}\".format(data_appendix)\n\n    loaders = []\n    prefetch_node_feats = [\"feat\"] if \"feat\" in graph.ndata else None\n    prefetch_edge_feats = [\"weight\"] if \"weight\" in graph.edata else None\n\n    train_edge, train_edge_neg = get_pos_neg_edges(\n        \"train\", split_edge, graph, args.train_percent\n    )\n    val_edge, val_edge_neg = 
get_pos_neg_edges(\n        \"valid\", split_edge, graph, args.val_percent\n    )\n    test_edge, test_edge_neg = get_pos_neg_edges(\n        \"test\", split_edge, graph, args.test_percent\n    )\n    # create an augmented graph for sampling\n    aug_g = dgl.graph(graph.edges())\n    aug_g.edata[\"y\"] = torch.ones(aug_g.num_edges())\n    aug_edges = torch.cat(\n        [val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg]\n    )\n    aug_labels = torch.cat(\n        [\n            torch.ones(len(val_edge) + len(test_edge)),\n            torch.zeros(\n                len(train_edge_neg) + len(val_edge_neg) + len(test_edge_neg)\n            ),\n        ]\n    )\n    aug_g.add_edges(aug_edges[:, 0], aug_edges[:, 1], {\"y\": aug_labels})\n    # eids for sampling\n    split_len = [graph.num_edges()] + list(\n        map(\n            len,\n            [val_edge, test_edge, train_edge_neg, val_edge_neg, test_edge_neg],\n        )\n    )\n    train_eids = torch.cat(\n        [\n            graph.edge_ids(train_edge[:, 0], train_edge[:, 1]),\n            torch.arange(sum(split_len[:3]), sum(split_len[:4])),\n        ]\n    )\n    val_eids = torch.cat(\n        [\n            torch.arange(sum(split_len[:1]), sum(split_len[:2])),\n            torch.arange(sum(split_len[:4]), sum(split_len[:5])),\n        ]\n    )\n    test_eids = torch.cat(\n        [\n            torch.arange(sum(split_len[:2]), sum(split_len[:3])),\n            torch.arange(sum(split_len[:5]), sum(split_len[:6])),\n        ]\n    )\n    sampler = SealSampler(\n        graph,\n        1,\n        args.ratio_per_hop,\n        directed,\n        prefetch_node_feats,\n        prefetch_edge_feats,\n    )\n    # force to be dynamic for consistent dataloading\n    for split, shuffle, eids in zip(\n        [\"train\", \"valid\", \"test\"],\n        [True, False, False],\n        [train_eids, val_eids, test_eids],\n    ):\n        data_loader = DataLoader(\n            aug_g,\n            eids,\n      
      sampler,\n            shuffle=shuffle,\n            device=device,\n            batch_size=args.batch_size,\n            num_workers=args.num_workers,\n        )\n        loaders.append(data_loader)\n    train_loader, val_loader, test_loader = loaders\n\n    # convert sortpool_k from percentile to number.\n    num_nodes = []\n    for subgs, _ in train_loader:\n        subgs = dgl.unbatch(subgs)\n        if len(num_nodes) > 1000:\n            break\n        for subg in subgs:\n            num_nodes.append(subg.num_nodes())\n    num_nodes = sorted(num_nodes)\n    k = num_nodes[int(math.ceil(args.sortpool_k * len(num_nodes))) - 1]\n    k = max(k, 10)\n\n    for run in range(args.runs):\n        model = DGCNN(\n            args.hidden_channels,\n            args.num_layers,\n            k,\n            feature_dim=graph.ndata[\"feat\"].size(1) if args.use_feature else 0,\n        ).to(device)\n        parameters = list(model.parameters())\n        optimizer = torch.optim.Adam(params=parameters, lr=args.lr)\n        total_params = sum(p.numel() for param in parameters for p in param)\n        print(f\"Total number of parameters is {total_params}\")\n        print(f\"SortPooling k is set to {k}\")\n        with open(log_file, \"a\") as f:\n            print(f\"Total number of parameters is {total_params}\", file=f)\n            print(f\"SortPooling k is set to {k}\", file=f)\n\n        start_epoch = 1\n        # Training starts\n        for epoch in range(start_epoch, start_epoch + args.epochs):\n            loss = train()\n\n            if epoch % args.eval_steps == 0:\n                results = test()\n                for key, result in results.items():\n                    loggers[key].add_result(run, result)\n\n                model_name = os.path.join(\n                    args.res_dir,\n                    \"run{}_model_checkpoint{}.pth\".format(run + 1, epoch),\n                )\n                optimizer_name = os.path.join(\n                    
args.res_dir,\n                    \"run{}_optimizer_checkpoint{}.pth\".format(run + 1, epoch),\n                )\n                torch.save(model.state_dict(), model_name)\n                torch.save(optimizer.state_dict(), optimizer_name)\n\n                for key, result in results.items():\n                    valid_res, test_res = result\n                    to_print = (\n                        f\"Run: {run + 1:02d}, Epoch: {epoch:02d}, \"\n                        + f\"Loss: {loss:.4f}, Valid: {100 * valid_res:.2f}%, \"\n                        + f\"Test: {100 * test_res:.2f}%\"\n                    )\n                    print(key)\n                    print(to_print)\n                    with open(log_file, \"a\") as f:\n                        print(key, file=f)\n                        print(to_print, file=f)\n\n        for key in loggers.keys():\n            print(key)\n            loggers[key].print_statistics(run)\n            with open(log_file, \"a\") as f:\n                print(key, file=f)\n                loggers[key].print_statistics(run, f=f)\n\n    for key in loggers.keys():\n        print(key)\n        loggers[key].print_statistics()\n        with open(log_file, \"a\") as f:\n            print(key, file=f)\n            loggers[key].print_statistics(f=f)\n    print(f\"Total number of parameters is {total_params}\")\n    print(f\"Results are saved in {args.res_dir}\")\n"
  },
  {
    "path": "examples/pytorch/ogb/sign/.gitignore",
    "content": "dataset\n"
  },
  {
    "path": "examples/pytorch/ogb/sign/README.md",
    "content": "SIGN: Scalable Inception Graph Neural Network\n==========================\nPaper: [https://arxiv.org/abs/2004.11198](https://arxiv.org/abs/2004.11198)\n\n\nDependencies\n------------\n- pytorch 1.5\n- dgl 0.5 nightly build\n    - `pip install --pre dgl`\n- ogb 1.2.3\n\n\nHow to run\n-------------\n### ogbn-products\n```python\npython3 sign.py --dataset ogbn-products --eval-ev 10 --R 5 --input-d 0.3 --num-h 512 \\\n    --dr 0.4 --lr 0.001 --batch-size 50000 --num-runs 10\n```\n\n### ogbn-arxiv\n```python\npython3 sign.py --dataset ogbn-arxiv --eval-ev 10 --R 5 --input-d 0.1 --num-h 512 \\\n    --dr 0.5 --lr 0.001 --eval-b 100000 --num-runs 10\n```\n\n### ogbn-mag\nogbn-mag is a heterogeneous graph and the task is to predict publishing venue\nof papers. Since SIGN model is designed for homogeneous graph, we simply ignore\nheterogeneous information (i.e. node and edge types) and treat the graph as a\nhomogeneous one. For node types that don't have input feature, we featurize them\nwith the average of their neighbors' features.\n\n```python\npython3 sign.py --dataset ogbn-mag --eval-ev 10 --R 5 --input-d 0 --num-h 512 \\\n    --dr 0.5 --lr 0.001 --batch-size 50000 --num-runs 10\n```\n\n\nResults\n----------\nTable below shows the average and standard deviation (over 10 times) of\naccuracy. Experiments were performed on Tesla T4 (15GB) GPU on Oct 29.\n\n| Dataset         | Test Accuracy   | Validation Accuracy   | # Params    |\n| :-------------: | :-------------: | :-------------------: | :---------: |\n| ogbn-products   | 0.8052±0.0016   | 0.9299±0.0004         | 3,483,703   |\n| ogbn-arxiv      | 0.7195±0.0011   | 0.7323±0.0006         | 3,566,128   |\n| ogbn-mag        | 0.4046±0.0012   | 0.4068±0.0010         | 3,724,645   |\n"
  },
  {
    "path": "examples/pytorch/ogb/sign/dataset.py",
    "content": "import dgl\nimport dgl.function as fn\nimport numpy as np\nimport torch\nfrom ogb.nodeproppred import DglNodePropPredDataset, Evaluator\n\n\ndef get_ogb_evaluator(dataset):\n    \"\"\"\n    Get evaluator from Open Graph Benchmark based on dataset\n    \"\"\"\n    evaluator = Evaluator(name=dataset)\n    return lambda preds, labels: evaluator.eval(\n        {\n            \"y_true\": labels.view(-1, 1),\n            \"y_pred\": preds.view(-1, 1),\n        }\n    )[\"acc\"]\n\n\ndef convert_mag_to_homograph(g, device):\n    \"\"\"\n    Featurize node types that don't have input features (i.e. author,\n    institution, field_of_study) by averaging their neighbor features.\n    Then convert the graph to a undirected homogeneous graph.\n    \"\"\"\n    src_writes, dst_writes = g.all_edges(etype=\"writes\")\n    src_topic, dst_topic = g.all_edges(etype=\"has_topic\")\n    src_aff, dst_aff = g.all_edges(etype=\"affiliated_with\")\n    new_g = dgl.heterograph(\n        {\n            (\"paper\", \"written\", \"author\"): (dst_writes, src_writes),\n            (\"paper\", \"has_topic\", \"field\"): (src_topic, dst_topic),\n            (\"author\", \"aff\", \"inst\"): (src_aff, dst_aff),\n        }\n    )\n    new_g = new_g.to(device)\n    new_g.nodes[\"paper\"].data[\"feat\"] = g.nodes[\"paper\"].data[\"feat\"]\n    new_g[\"written\"].update_all(fn.copy_u(\"feat\", \"m\"), fn.mean(\"m\", \"feat\"))\n    new_g[\"has_topic\"].update_all(fn.copy_u(\"feat\", \"m\"), fn.mean(\"m\", \"feat\"))\n    new_g[\"aff\"].update_all(fn.copy_u(\"feat\", \"m\"), fn.mean(\"m\", \"feat\"))\n    g.nodes[\"author\"].data[\"feat\"] = new_g.nodes[\"author\"].data[\"feat\"]\n    g.nodes[\"institution\"].data[\"feat\"] = new_g.nodes[\"inst\"].data[\"feat\"]\n    g.nodes[\"field_of_study\"].data[\"feat\"] = new_g.nodes[\"field\"].data[\"feat\"]\n\n    # Convert to homogeneous graph\n    # Get DGL type id for paper type\n    target_type_id = g.get_ntype_id(\"paper\")\n    g = 
dgl.to_homogeneous(g, ndata=[\"feat\"])\n    g = dgl.add_reverse_edges(g, copy_ndata=True)\n    # Mask for paper nodes\n    g.ndata[\"target_mask\"] = g.ndata[dgl.NTYPE] == target_type_id\n    return g\n\n\ndef load_dataset(name, device):\n    \"\"\"\n    Load dataset and move graph and features to device\n    \"\"\"\n    if name not in [\"ogbn-products\", \"ogbn-arxiv\", \"ogbn-mag\"]:\n        raise RuntimeError(\"Dataset {} is not supported\".format(name))\n    dataset = DglNodePropPredDataset(name=name)\n    splitted_idx = dataset.get_idx_split()\n    train_nid = splitted_idx[\"train\"]\n    val_nid = splitted_idx[\"valid\"]\n    test_nid = splitted_idx[\"test\"]\n    g, labels = dataset[0]\n    g = g.to(device)\n    if name == \"ogbn-arxiv\":\n        g = dgl.add_reverse_edges(g, copy_ndata=True)\n        g = dgl.add_self_loop(g)\n        g.ndata[\"feat\"] = g.ndata[\"feat\"].float()\n    elif name == \"ogbn-mag\":\n        # MAG is a heterogeneous graph. The task is to make prediction for\n        # paper nodes\n        labels = labels[\"paper\"]\n        train_nid = train_nid[\"paper\"]\n        val_nid = val_nid[\"paper\"]\n        test_nid = test_nid[\"paper\"]\n        g = convert_mag_to_homograph(g, device)\n    else:\n        g.ndata[\"feat\"] = g.ndata[\"feat\"].float()\n    n_classes = dataset.num_classes\n    labels = labels.squeeze()\n    evaluator = get_ogb_evaluator(name)\n\n    print(\n        f\"# Nodes: {g.num_nodes()}\\n\"\n        f\"# Edges: {g.num_edges()}\\n\"\n        f\"# Train: {len(train_nid)}\\n\"\n        f\"# Val: {len(val_nid)}\\n\"\n        f\"# Test: {len(test_nid)}\\n\"\n        f\"# Classes: {n_classes}\"\n    )\n\n    return g, labels, n_classes, train_nid, val_nid, test_nid, evaluator\n"
  },
  {
    "path": "examples/pytorch/ogb/sign/sign.py",
    "content": "import argparse\nimport time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom dataset import load_dataset\n\n\nclass FeedForwardNet(nn.Module):\n    def __init__(self, in_feats, hidden, out_feats, n_layers, dropout):\n        super(FeedForwardNet, self).__init__()\n        self.layers = nn.ModuleList()\n        self.n_layers = n_layers\n        if n_layers == 1:\n            self.layers.append(nn.Linear(in_feats, out_feats))\n        else:\n            self.layers.append(nn.Linear(in_feats, hidden))\n            for i in range(n_layers - 2):\n                self.layers.append(nn.Linear(hidden, hidden))\n            self.layers.append(nn.Linear(hidden, out_feats))\n        if self.n_layers > 1:\n            self.prelu = nn.PReLU()\n            self.dropout = nn.Dropout(dropout)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        for layer in self.layers:\n            nn.init.xavier_uniform_(layer.weight, gain=gain)\n            nn.init.zeros_(layer.bias)\n\n    def forward(self, x):\n        for layer_id, layer in enumerate(self.layers):\n            x = layer(x)\n            if layer_id < self.n_layers - 1:\n                x = self.dropout(self.prelu(x))\n        return x\n\n\nclass SIGN(nn.Module):\n    def __init__(\n        self,\n        in_feats,\n        hidden,\n        out_feats,\n        num_hops,\n        n_layers,\n        dropout,\n        input_drop,\n    ):\n        super(SIGN, self).__init__()\n        self.dropout = nn.Dropout(dropout)\n        self.prelu = nn.PReLU()\n        self.inception_ffs = nn.ModuleList()\n        self.input_drop = nn.Dropout(input_drop)\n        for hop in range(num_hops):\n            self.inception_ffs.append(\n                FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout)\n            )\n        self.project = FeedForwardNet(\n            num_hops * hidden, 
hidden, out_feats, n_layers, dropout\n        )\n\n    def forward(self, feats):\n        feats = [self.input_drop(feat) for feat in feats]\n        hidden = []\n        for feat, ff in zip(feats, self.inception_ffs):\n            hidden.append(ff(feat))\n        out = self.project(self.dropout(self.prelu(torch.cat(hidden, dim=-1))))\n        return out\n\n    def reset_parameters(self):\n        for ff in self.inception_ffs:\n            ff.reset_parameters()\n        self.project.reset_parameters()\n\n\ndef get_n_params(model):\n    pp = 0\n    for p in list(model.parameters()):\n        nn = 1\n        for s in list(p.size()):\n            nn = nn * s\n        pp += nn\n    return pp\n\n\ndef neighbor_average_features(g, args):\n    \"\"\"\n    Compute multi-hop neighbor-averaged node features\n    \"\"\"\n    print(\"Compute neighbor-averaged feats\")\n    g.ndata[\"feat_0\"] = g.ndata[\"feat\"]\n    for hop in range(1, args.R + 1):\n        g.update_all(\n            fn.copy_u(f\"feat_{hop-1}\", \"msg\"), fn.mean(\"msg\", f\"feat_{hop}\")\n        )\n    res = []\n    for hop in range(args.R + 1):\n        res.append(g.ndata.pop(f\"feat_{hop}\"))\n\n    if args.dataset == \"ogbn-mag\":\n        # For MAG dataset, only return features for target node types (i.e.\n        # paper nodes)\n        target_mask = g.ndata[\"target_mask\"]\n        target_ids = g.ndata[dgl.NID][target_mask]\n        num_target = target_mask.sum().item()\n        new_res = []\n        for x in res:\n            feat = torch.zeros(\n                (num_target,) + x.shape[1:], dtype=x.dtype, device=x.device\n            )\n            feat[target_ids] = x[target_mask]\n            new_res.append(feat)\n        res = new_res\n    return res\n\n\ndef prepare_data(device, args):\n    \"\"\"\n    Load dataset and compute neighbor-averaged node features used by SIGN model\n    \"\"\"\n    data = load_dataset(args.dataset, device)\n    g, labels, n_classes, train_nid, val_nid, test_nid, 
evaluator = data\n    in_feats = g.ndata[\"feat\"].shape[1]\n    feats = neighbor_average_features(g, args)\n    labels = labels.to(device)\n    # move to device\n    train_nid = train_nid.to(device)\n    val_nid = val_nid.to(device)\n    test_nid = test_nid.to(device)\n    return (\n        feats,\n        labels,\n        in_feats,\n        n_classes,\n        train_nid,\n        val_nid,\n        test_nid,\n        evaluator,\n    )\n\n\ndef train(model, feats, labels, loss_fcn, optimizer, train_loader):\n    model.train()\n    device = labels.device\n    for batch in train_loader:\n        batch_feats = [x[batch].to(device) for x in feats]\n        loss = loss_fcn(model(batch_feats), labels[batch])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n\ndef test(\n    model, feats, labels, test_loader, evaluator, train_nid, val_nid, test_nid\n):\n    model.eval()\n    device = labels.device\n    preds = []\n    for batch in test_loader:\n        batch_feats = [feat[batch].to(device) for feat in feats]\n        preds.append(torch.argmax(model(batch_feats), dim=-1))\n    # Concat mini-batch prediction results along node dimension\n    preds = torch.cat(preds, dim=0)\n    train_res = evaluator(preds[train_nid], labels[train_nid])\n    val_res = evaluator(preds[val_nid], labels[val_nid])\n    test_res = evaluator(preds[test_nid], labels[test_nid])\n    return train_res, val_res, test_res\n\n\ndef run(args, data, device):\n    (\n        feats,\n        labels,\n        in_size,\n        num_classes,\n        train_nid,\n        val_nid,\n        test_nid,\n        evaluator,\n    ) = data\n    train_loader = torch.utils.data.DataLoader(\n        train_nid, batch_size=args.batch_size, shuffle=True, drop_last=False\n    )\n    test_loader = torch.utils.data.DataLoader(\n        torch.arange(labels.shape[0]),\n        batch_size=args.eval_batch_size,\n        shuffle=False,\n        drop_last=False,\n    )\n\n    # Initialize model and 
optimizer for each run\n    num_hops = args.R + 1\n    model = SIGN(\n        in_size,\n        args.num_hidden,\n        num_classes,\n        num_hops,\n        args.ff_layer,\n        args.dropout,\n        args.input_dropout,\n    )\n    model = model.to(device)\n    print(\"# Params:\", get_n_params(model))\n\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # Start training\n    best_epoch = 0\n    best_val = 0\n    best_test = 0\n    for epoch in range(1, args.num_epochs + 1):\n        start = time.time()\n        train(model, feats, labels, loss_fcn, optimizer, train_loader)\n\n        if epoch % args.eval_every == 0:\n            with torch.no_grad():\n                acc = test(\n                    model,\n                    feats,\n                    labels,\n                    test_loader,\n                    evaluator,\n                    train_nid,\n                    val_nid,\n                    test_nid,\n                )\n            end = time.time()\n            log = \"Epoch {}, Time(s): {:.4f}, \".format(epoch, end - start)\n            log += \"Acc: Train {:.4f}, Val {:.4f}, Test {:.4f}\".format(*acc)\n            print(log)\n            if acc[1] > best_val:\n                best_epoch = epoch\n                best_val = acc[1]\n                best_test = acc[2]\n\n    print(\n        \"Best Epoch {}, Val {:.4f}, Test {:.4f}\".format(\n            best_epoch, best_val, best_test\n        )\n    )\n    return best_val, best_test\n\n\ndef main(args):\n    if args.gpu < 0:\n        device = \"cpu\"\n    else:\n        device = \"cuda:{}\".format(args.gpu)\n\n    with torch.no_grad():\n        data = prepare_data(device, args)\n    val_accs = []\n    test_accs = []\n    for i in range(args.num_runs):\n        print(f\"Run {i} start training\")\n        best_val, best_test = run(args, data, device)\n        
val_accs.append(best_val)\n        test_accs.append(best_test)\n\n    print(\n        f\"Average val accuracy: {np.mean(val_accs):.4f}, \"\n        f\"std: {np.std(val_accs):.4f}\"\n    )\n    print(\n        f\"Average test accuracy: {np.mean(test_accs):.4f}, \"\n        f\"std: {np.std(test_accs):.4f}\"\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"SIGN\")\n    parser.add_argument(\"--num-epochs\", type=int, default=1000)\n    parser.add_argument(\"--num-hidden\", type=int, default=512)\n    parser.add_argument(\"--R\", type=int, default=5, help=\"number of hops\")\n    parser.add_argument(\"--lr\", type=float, default=0.001)\n    parser.add_argument(\"--dataset\", type=str, default=\"ogbn-mag\")\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout on activation\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=0)\n    parser.add_argument(\"--weight-decay\", type=float, default=0)\n    parser.add_argument(\"--eval-every\", type=int, default=10)\n    parser.add_argument(\"--batch-size\", type=int, default=50000)\n    parser.add_argument(\n        \"--eval-batch-size\",\n        type=int,\n        default=100000,\n        help=\"evaluation batch size\",\n    )\n    parser.add_argument(\n        \"--ff-layer\", type=int, default=2, help=\"number of feed-forward layers\"\n    )\n    parser.add_argument(\n        \"--input-dropout\",\n        type=float,\n        default=0,\n        help=\"dropout on input features\",\n    )\n    parser.add_argument(\n        \"--num-runs\",\n        type=int,\n        default=10,\n        help=\"number of times to repeat the experiment\",\n    )\n    args = parser.parse_args()\n\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/ogb_lsc/MAG240M/README.md",
    "content": "# Baseline Code for MAG240M\n\nThe code is ported from the R-GAT examples [here](https://github.com/snap-stanford/ogb/tree/master/examples/lsc/mag240m). Please refer to the [OGB-LSC paper](https://arxiv.org/abs/2103.09430) for the detailed setting.\n\n## Installation Requirements\n\n```\nogb>=1.3.0\ntorch>=1.7.0\n```\n\n## Running Preprocessing Script\n\n```\npython preprocess.py \\\n    --rootdir . \\\n    --author-output-path ./author.npy \\\n    --inst-output-path ./inst.npy \\\n    --graph-output-path ./graph.dgl \\\n    --graph-as-homogeneous \\\n    --full-output-path ./full.npy\n```\n\nThis will give you the following files:\n\n* `author.npy`: The author features, preprocessed by averaging the neighboring paper features.\n* `inst.npy`: The institution features, preprocessed by averaging the neighboring author features.\n* `graph.dgl`: The *homogenized* DGL graph stored in CSC format, which is friendly for neighbor sampling.\n  Edge types are stored on the edges as an `int8` feature.  Nodes are in the order of author, institution,\n  and paper.\n* `full.npy`: The concatenated author, institution, and paper features.\n\nSince that will usually take a long time, we also offer the above files for download:\n\n* [`author.npy`](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/author.npy)\n* [`inst.npy`](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/inst.npy)\n* [`graph.dgl`](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/graph.dgl)\n* [`full.npy`](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/full.npy)\n\nIn addition, we offer\n\n* [`full_feat.npy`](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/full_feat.npy): The preprocessed full feature matrix\n  for running OGB's own baseline. Note that the features are concatenated in the order of paper, author, and\n  institution, unlike the one in our baseline code.  
It is also preprocessed in float32 arithmetic instead\n  of float16 arithmetic.\n\n## Running Training Script\n\n```\npython train.py \\\n    --rootdir . \\\n    --graph-preprocess-path ./graph.dgl \\\n    --full-preprocess-path ./full.npy\n```\n\nThe validation accuracy is 0.701.  We do not have ground truth test labels so we do not report\ntest accuracy.\n\n## Hardware configurations\n\nWe successfully ran 8 experiments in parallel on an AWS p4d.24xlarge instance with the preprocessed feature\nmatrices stored on an NVMe SSD to enable fast disk reads.  Each experiment requires less than 128GB CPU\nmemory and less than 12GB GPU memory to run.  Every epoch takes around 6 minutes 30 seconds to train and\n1 minute 40 seconds to validate.\n\nIf your hard drive is slow, it is best to load all the features into memory for a reasonable training speed.\nNote that CPU memory consumption can then go as high as 512GB.\n"
  },
  {
    "path": "examples/pytorch/ogb_lsc/MAG240M/preprocess.py",
    "content": "import argparse\nimport os\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport ogb\nimport torch\nimport tqdm\nfrom ogb.lsc import MAG240MDataset\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--rootdir\",\n    type=str,\n    default=\".\",\n    help=\"Directory to download the OGB dataset.\",\n)\nparser.add_argument(\n    \"--author-output-path\", type=str, help=\"Path to store the author features.\"\n)\nparser.add_argument(\n    \"--inst-output-path\",\n    type=str,\n    help=\"Path to store the institution features.\",\n)\nparser.add_argument(\n    \"--graph-output-path\", type=str, help=\"Path to store the graph.\"\n)\nparser.add_argument(\n    \"--graph-format\",\n    type=str,\n    default=\"csc\",\n    help=\"Graph format (coo, csr or csc).\",\n)\nparser.add_argument(\n    \"--graph-as-homogeneous\",\n    action=\"store_true\",\n    help=\"Store the graph as DGL homogeneous graph.\",\n)\nparser.add_argument(\n    \"--full-output-path\",\n    type=str,\n    help=\"Path to store features of all nodes.  
Effective only when graph is homogeneous.\",\n)\nargs = parser.parse_args()\n\nprint(\"Building graph\")\ndataset = MAG240MDataset(root=args.rootdir)\nei_writes = dataset.edge_index(\"author\", \"writes\", \"paper\")\nei_cites = dataset.edge_index(\"paper\", \"paper\")\nei_affiliated = dataset.edge_index(\"author\", \"institution\")\n\n# We sort the nodes starting with the papers, then the authors, then the institutions.\nauthor_offset = 0\ninst_offset = author_offset + dataset.num_authors\npaper_offset = inst_offset + dataset.num_institutions\n\ng = dgl.heterograph(\n    {\n        (\"author\", \"write\", \"paper\"): (ei_writes[0], ei_writes[1]),\n        (\"paper\", \"write-by\", \"author\"): (ei_writes[1], ei_writes[0]),\n        (\"author\", \"affiliate-with\", \"institution\"): (\n            ei_affiliated[0],\n            ei_affiliated[1],\n        ),\n        (\"institution\", \"affiliate\", \"author\"): (\n            ei_affiliated[1],\n            ei_affiliated[0],\n        ),\n        (\"paper\", \"cite\", \"paper\"): (\n            np.concatenate([ei_cites[0], ei_cites[1]]),\n            np.concatenate([ei_cites[1], ei_cites[0]]),\n        ),\n    }\n)\n\npaper_feat = dataset.paper_feat\nauthor_feat = np.memmap(\n    args.author_output_path,\n    mode=\"w+\",\n    dtype=\"float16\",\n    shape=(dataset.num_authors, dataset.num_paper_features),\n)\ninst_feat = np.memmap(\n    args.inst_output_path,\n    mode=\"w+\",\n    dtype=\"float16\",\n    shape=(dataset.num_institutions, dataset.num_paper_features),\n)\n\n# Iteratively process author features along the feature dimension.\nBLOCK_COLS = 16\nwith tqdm.trange(0, dataset.num_paper_features, BLOCK_COLS) as tq:\n    for start in tq:\n        tq.set_postfix_str(\"Reading paper features...\")\n        g.nodes[\"paper\"].data[\"x\"] = torch.FloatTensor(\n            paper_feat[:, start : start + BLOCK_COLS].astype(\"float32\")\n        )\n        # Compute author features...\n        
tq.set_postfix_str(\"Computing author features...\")\n        g.update_all(fn.copy_u(\"x\", \"m\"), fn.mean(\"m\", \"x\"), etype=\"write-by\")\n        # Then institution features...\n        tq.set_postfix_str(\"Computing institution features...\")\n        g.update_all(\n            fn.copy_u(\"x\", \"m\"), fn.mean(\"m\", \"x\"), etype=\"affiliate-with\"\n        )\n        tq.set_postfix_str(\"Writing author features...\")\n        author_feat[:, start : start + BLOCK_COLS] = (\n            g.nodes[\"author\"].data[\"x\"].numpy().astype(\"float16\")\n        )\n        tq.set_postfix_str(\"Writing institution features...\")\n        inst_feat[:, start : start + BLOCK_COLS] = (\n            g.nodes[\"institution\"].data[\"x\"].numpy().astype(\"float16\")\n        )\n        del g.nodes[\"paper\"].data[\"x\"]\n        del g.nodes[\"author\"].data[\"x\"]\n        del g.nodes[\"institution\"].data[\"x\"]\nauthor_feat.flush()\ninst_feat.flush()\n\n# Convert to homogeneous if needed.  (The RGAT baseline needs homogeneous graph)\nif args.graph_as_homogeneous:\n    # Process graph\n    g = dgl.to_homogeneous(g)\n    # DGL ensures that nodes with the same type are put together with the order preserved.\n    # DGL also ensures that the node types are sorted in ascending order.\n    assert torch.equal(\n        g.ndata[dgl.NTYPE],\n        torch.cat(\n            [\n                torch.full((dataset.num_authors,), 0),\n                torch.full((dataset.num_institutions,), 1),\n                torch.full((dataset.num_papers,), 2),\n            ]\n        ),\n    )\n    assert torch.equal(\n        g.ndata[dgl.NID],\n        torch.cat(\n            [\n                torch.arange(dataset.num_authors),\n                torch.arange(dataset.num_institutions),\n                torch.arange(dataset.num_papers),\n            ]\n        ),\n    )\n    g.edata[\"etype\"] = g.edata[dgl.ETYPE].byte()\n    del g.edata[dgl.ETYPE]\n    del g.ndata[dgl.NTYPE]\n    del 
g.ndata[dgl.NID]\n\n    # Process feature\n    full_feat = np.memmap(\n        args.full_output_path,\n        mode=\"w+\",\n        dtype=\"float16\",\n        shape=(\n            dataset.num_authors + dataset.num_institutions + dataset.num_papers,\n            dataset.num_paper_features,\n        ),\n    )\n    BLOCK_ROWS = 100000\n    for start in tqdm.trange(0, dataset.num_authors, BLOCK_ROWS):\n        end = min(dataset.num_authors, start + BLOCK_ROWS)\n        full_feat[author_offset + start : author_offset + end] = author_feat[\n            start:end\n        ]\n    for start in tqdm.trange(0, dataset.num_institutions, BLOCK_ROWS):\n        end = min(dataset.num_institutions, start + BLOCK_ROWS)\n        full_feat[inst_offset + start : inst_offset + end] = inst_feat[\n            start:end\n        ]\n    for start in tqdm.trange(0, dataset.num_papers, BLOCK_ROWS):\n        end = min(dataset.num_papers, start + BLOCK_ROWS)\n        full_feat[paper_offset + start : paper_offset + end] = paper_feat[\n            start:end\n        ]\n\n# Convert the graph to the given format and save.  (The RGAT baseline needs CSC graph)\ng = g.formats(args.graph_format)\ndgl.save_graphs(args.graph_output_path, g)\n"
  },
  {
    "path": "examples/pytorch/ogb_lsc/MAG240M/train.py",
    "content": "#!/usr/bin/env python\n# coding: utf-8\n\nimport argparse\nimport time\n\nimport dgl\nimport dgl.function as fn\nimport dgl.nn as dglnn\n\nimport numpy as np\nimport ogb\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\nfrom ogb.lsc import MAG240MDataset, MAG240MEvaluator\n\n\nclass RGAT(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        hidden_channels,\n        num_etypes,\n        num_layers,\n        num_heads,\n        dropout,\n        pred_ntype,\n    ):\n        super().__init__()\n        self.convs = nn.ModuleList()\n        self.norms = nn.ModuleList()\n        self.skips = nn.ModuleList()\n\n        self.convs.append(\n            nn.ModuleList(\n                [\n                    dglnn.GATConv(\n                        in_channels,\n                        hidden_channels // num_heads,\n                        num_heads,\n                        allow_zero_in_degree=True,\n                    )\n                    for _ in range(num_etypes)\n                ]\n            )\n        )\n        self.norms.append(nn.BatchNorm1d(hidden_channels))\n        self.skips.append(nn.Linear(in_channels, hidden_channels))\n        for _ in range(num_layers - 1):\n            self.convs.append(\n                nn.ModuleList(\n                    [\n                        dglnn.GATConv(\n                            hidden_channels,\n                            hidden_channels // num_heads,\n                            num_heads,\n                            allow_zero_in_degree=True,\n                        )\n                        for _ in range(num_etypes)\n                    ]\n                )\n            )\n            self.norms.append(nn.BatchNorm1d(hidden_channels))\n            self.skips.append(nn.Linear(hidden_channels, hidden_channels))\n\n        self.mlp = nn.Sequential(\n            nn.Linear(hidden_channels, hidden_channels),\n            
nn.BatchNorm1d(hidden_channels),\n            nn.ReLU(),\n            nn.Dropout(dropout),\n            nn.Linear(hidden_channels, out_channels),\n        )\n        self.dropout = nn.Dropout(dropout)\n\n        self.hidden_channels = hidden_channels\n        self.pred_ntype = pred_ntype\n        self.num_etypes = num_etypes\n\n    def forward(self, mfgs, x):\n        for i in range(len(mfgs)):\n            mfg = mfgs[i]\n            x_dst = x[: mfg.num_dst_nodes()]\n            n_src = mfg.num_src_nodes()\n            n_dst = mfg.num_dst_nodes()\n            mfg = dgl.block_to_graph(mfg)\n            x_skip = self.skips[i](x_dst)\n            for j in range(self.num_etypes):\n                subg = mfg.edge_subgraph(\n                    mfg.edata[\"etype\"] == j, relabel_nodes=False\n                )\n                x_skip += self.convs[i][j](subg, (x, x_dst)).view(\n                    -1, self.hidden_channels\n                )\n            x = self.norms[i](x_skip)\n            x = F.elu(x)\n            x = self.dropout(x)\n        return self.mlp(x)\n\n\nclass ExternalNodeCollator(dgl.dataloading.NodeCollator):\n    def __init__(self, g, idx, sampler, offset, feats, label):\n        super().__init__(g, idx, sampler)\n        self.offset = offset\n        self.feats = feats\n        self.label = label\n\n    def collate(self, items):\n        input_nodes, output_nodes, mfgs = super().collate(items)\n        # Copy input features\n        mfgs[0].srcdata[\"x\"] = torch.FloatTensor(self.feats[input_nodes])\n        mfgs[-1].dstdata[\"y\"] = torch.LongTensor(\n            self.label[output_nodes - self.offset]\n        )\n        return input_nodes, output_nodes, mfgs\n\n\ndef train(args, dataset, g, feats, paper_offset):\n    print(\"Loading masks and labels\")\n    train_idx = torch.LongTensor(dataset.get_idx_split(\"train\")) + paper_offset\n    valid_idx = torch.LongTensor(dataset.get_idx_split(\"valid\")) + paper_offset\n    label = dataset.paper_label\n\n 
   print(\"Initializing dataloader...\")\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])\n    train_collator = ExternalNodeCollator(\n        g, train_idx, sampler, paper_offset, feats, label\n    )\n    valid_collator = ExternalNodeCollator(\n        g, valid_idx, sampler, paper_offset, feats, label\n    )\n    train_dataloader = torch.utils.data.DataLoader(\n        train_collator.dataset,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        collate_fn=train_collator.collate,\n        num_workers=4,\n    )\n    valid_dataloader = torch.utils.data.DataLoader(\n        valid_collator.dataset,\n        batch_size=1024,\n        shuffle=True,\n        drop_last=False,\n        collate_fn=valid_collator.collate,\n        num_workers=2,\n    )\n\n    print(\"Initializing model...\")\n    model = RGAT(\n        dataset.num_paper_features,\n        dataset.num_classes,\n        1024,\n        5,\n        2,\n        4,\n        0.5,\n        \"paper\",\n    ).cuda()\n    opt = torch.optim.Adam(model.parameters(), lr=0.001)\n    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)\n\n    best_acc = 0\n\n    for _ in range(args.epochs):\n        model.train()\n        with tqdm.tqdm(train_dataloader) as tq:\n            for i, (input_nodes, output_nodes, mfgs) in enumerate(tq):\n                mfgs = [g.to(\"cuda\") for g in mfgs]\n                x = mfgs[0].srcdata[\"x\"]\n                y = mfgs[-1].dstdata[\"y\"]\n                y_hat = model(mfgs, x)\n                loss = F.cross_entropy(y_hat, y)\n                opt.zero_grad()\n                loss.backward()\n                opt.step()\n                acc = (y_hat.argmax(1) == y).float().mean()\n                tq.set_postfix(\n                    {\"loss\": \"%.4f\" % loss.item(), \"acc\": \"%.4f\" % acc.item()},\n                    refresh=False,\n                )\n\n        model.eval()\n        correct = total = 0\n        for i, 
(input_nodes, output_nodes, mfgs) in enumerate(\n            tqdm.tqdm(valid_dataloader)\n        ):\n            with torch.no_grad():\n                mfgs = [g.to(\"cuda\") for g in mfgs]\n                x = mfgs[0].srcdata[\"x\"]\n                y = mfgs[-1].dstdata[\"y\"]\n                y_hat = model(mfgs, x)\n                correct += (y_hat.argmax(1) == y).sum().item()\n                total += y_hat.shape[0]\n        acc = correct / total\n        print(\"Validation accuracy:\", acc)\n\n        sched.step()\n\n        if best_acc < acc:\n            best_acc = acc\n            print(\"Updating best model...\")\n            torch.save(model.state_dict(), args.model_path)\n\n\ndef test(args, dataset, g, feats, paper_offset):\n    print(\"Loading masks and labels...\")\n    valid_idx = torch.LongTensor(dataset.get_idx_split(\"valid\")) + paper_offset\n    test_idx = torch.LongTensor(dataset.get_idx_split(\"test\")) + paper_offset\n    label = dataset.paper_label\n\n    print(\"Initializing data loader...\")\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160])\n    valid_collator = ExternalNodeCollator(\n        g, valid_idx, sampler, paper_offset, feats, label\n    )\n    valid_dataloader = torch.utils.data.DataLoader(\n        valid_collator.dataset,\n        batch_size=16,\n        shuffle=False,\n        drop_last=False,\n        collate_fn=valid_collator.collate,\n        num_workers=2,\n    )\n    test_collator = ExternalNodeCollator(\n        g, test_idx, sampler, paper_offset, feats, label\n    )\n    test_dataloader = torch.utils.data.DataLoader(\n        test_collator.dataset,\n        batch_size=16,\n        shuffle=False,\n        drop_last=False,\n        collate_fn=test_collator.collate,\n        num_workers=4,\n    )\n\n    print(\"Loading model...\")\n    model = RGAT(\n        dataset.num_paper_features,\n        dataset.num_classes,\n        1024,\n        5,\n        2,\n        4,\n        0.5,\n        \"paper\",\n    
).cuda()\n    model.load_state_dict(torch.load(args.model_path, weights_only=False))\n\n    model.eval()\n    correct = total = 0\n    for i, (input_nodes, output_nodes, mfgs) in enumerate(\n        tqdm.tqdm(valid_dataloader)\n    ):\n        with torch.no_grad():\n            mfgs = [g.to(\"cuda\") for g in mfgs]\n            x = mfgs[0].srcdata[\"x\"]\n            y = mfgs[-1].dstdata[\"y\"]\n            y_hat = model(mfgs, x)\n            correct += (y_hat.argmax(1) == y).sum().item()\n            total += y_hat.shape[0]\n    acc = correct / total\n    print(\"Validation accuracy:\", acc)\n    evaluator = MAG240MEvaluator()\n    y_preds = []\n    for i, (input_nodes, output_nodes, mfgs) in enumerate(\n        tqdm.tqdm(test_dataloader)\n    ):\n        with torch.no_grad():\n            mfgs = [g.to(\"cuda\") for g in mfgs]\n            x = mfgs[0].srcdata[\"x\"]\n            y = mfgs[-1].dstdata[\"y\"]\n            y_hat = model(mfgs, x)\n            y_preds.append(y_hat.argmax(1).cpu())\n    evaluator.save_test_submission(\n        {\"y_pred\": torch.cat(y_preds)}, args.submission_path\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--rootdir\",\n        type=str,\n        default=\".\",\n        help=\"Directory to download the OGB dataset.\",\n    )\n    parser.add_argument(\n        \"--graph-path\",\n        type=str,\n        default=\"./graph.dgl\",\n        help=\"Path to the graph.\",\n    )\n    parser.add_argument(\n        \"--full-feature-path\",\n        type=str,\n        default=\"./full.npy\",\n        help=\"Path to the features of all nodes.\",\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=100, help=\"Number of epochs.\"\n    )\n    parser.add_argument(\n        \"--model-path\",\n        type=str,\n        default=\"./model.pt\",\n        help=\"Path to store the best model.\",\n    )\n    parser.add_argument(\n        \"--submission-path\",\n  
      type=str,\n        default=\"./results\",\n        help=\"Submission directory.\",\n    )\n    args = parser.parse_args()\n\n    dataset = MAG240MDataset(root=args.rootdir)\n\n    print(\"Loading graph\")\n    (g,), _ = dgl.load_graphs(args.graph_path)\n    g = g.formats([\"csc\"])\n\n    print(\"Loading features\")\n    paper_offset = dataset.num_authors + dataset.num_institutions\n    num_nodes = paper_offset + dataset.num_papers\n    num_features = dataset.num_paper_features\n    feats = np.memmap(\n        args.full_feature_path,\n        mode=\"r\",\n        dtype=\"float16\",\n        shape=(num_nodes, num_features),\n    )\n\n    if args.epochs != 0:\n        train(args, dataset, g, feats, paper_offset)\n    test(args, dataset, g, feats, paper_offset)\n"
  },
  {
    "path": "examples/pytorch/ogb_lsc/MAG240M/train_multi_gpus.py",
    "content": "#!/usr/bin/env python\n# coding: utf-8\nimport argparse\nimport math\nimport sys\nfrom collections import OrderedDict\n\nimport dgl\nimport dgl.nn as dglnn\n\nimport numpy as np\nimport torch\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\nfrom ogb.lsc import MAG240MDataset, MAG240MEvaluator\nfrom torch.nn.parallel import DistributedDataParallel\n\n\nclass RGAT(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        hidden_channels,\n        num_etypes,\n        num_layers,\n        num_heads,\n        dropout,\n        pred_ntype,\n    ):\n        super().__init__()\n        self.convs = nn.ModuleList()\n        self.norms = nn.ModuleList()\n        self.skips = nn.ModuleList()\n\n        self.convs.append(\n            nn.ModuleList(\n                [\n                    dglnn.GATConv(\n                        in_channels,\n                        hidden_channels // num_heads,\n                        num_heads,\n                        allow_zero_in_degree=True,\n                    )\n                    for _ in range(num_etypes)\n                ]\n            )\n        )\n        self.norms.append(nn.BatchNorm1d(hidden_channels))\n        self.skips.append(nn.Linear(in_channels, hidden_channels))\n        for _ in range(num_layers - 1):\n            self.convs.append(\n                nn.ModuleList(\n                    [\n                        dglnn.GATConv(\n                            hidden_channels,\n                            hidden_channels // num_heads,\n                            num_heads,\n                            allow_zero_in_degree=True,\n                        )\n                        for _ in range(num_etypes)\n                    ]\n                )\n            )\n            self.norms.append(nn.BatchNorm1d(hidden_channels))\n            self.skips.append(nn.Linear(hidden_channels, hidden_channels))\n\n        
self.mlp = nn.Sequential(\n            nn.Linear(hidden_channels, hidden_channels),\n            nn.BatchNorm1d(hidden_channels),\n            nn.ReLU(),\n            nn.Dropout(dropout),\n            nn.Linear(hidden_channels, out_channels),\n        )\n        self.dropout = nn.Dropout(dropout)\n\n        self.hidden_channels = hidden_channels\n        self.pred_ntype = pred_ntype\n        self.num_etypes = num_etypes\n\n    def forward(self, mfgs, x):\n        for i in range(len(mfgs)):\n            mfg = mfgs[i]\n            x_dst = x[: mfg.num_dst_nodes()]\n            n_src = mfg.num_src_nodes()\n            n_dst = mfg.num_dst_nodes()\n            mfg = dgl.block_to_graph(mfg)\n            x_skip = self.skips[i](x_dst)\n            for j in range(self.num_etypes):\n                subg = mfg.edge_subgraph(\n                    mfg.edata[\"etype\"] == j, relabel_nodes=False\n                )\n                x_skip += self.convs[i][j](subg, (x, x_dst)).view(\n                    -1, self.hidden_channels\n                )\n            x = self.norms[i](x_skip)\n            x = F.elu(x)\n            x = self.dropout(x)\n        return self.mlp(x)\n\n\nclass ExternalNodeCollator(dgl.dataloading.NodeCollator):\n    def __init__(self, g, idx, sampler, offset, feats, label):\n        super().__init__(g, idx, sampler)\n        self.offset = offset\n        self.feats = feats\n        self.label = label\n\n    def collate(self, items):\n        input_nodes, output_nodes, mfgs = super().collate(items)\n        # Copy input features\n        mfgs[0].srcdata[\"x\"] = torch.FloatTensor(self.feats[input_nodes])\n        mfgs[-1].dstdata[\"y\"] = torch.LongTensor(\n            self.label[output_nodes - self.offset]\n        )\n        return input_nodes, output_nodes, mfgs\n\n\ndef train(proc_id, n_gpus, args, dataset, g, feats, paper_offset):\n    dev_id = devices[proc_id]\n    if n_gpus > 1:\n        dist_init_method = \"tcp://{master_ip}:{master_port}\".format(\n      
      master_ip=\"127.0.0.1\", master_port=\"12346\"\n        )\n        world_size = n_gpus\n        torch.distributed.init_process_group(\n            backend=\"nccl\",\n            init_method=dist_init_method,\n            world_size=world_size,\n            rank=proc_id,\n        )\n\n    torch.cuda.set_device(dev_id)\n\n    print(\"Loading masks and labels\")\n    train_idx = torch.LongTensor(dataset.get_idx_split(\"train\")) + paper_offset\n    valid_idx = torch.LongTensor(dataset.get_idx_split(\"valid\")) + paper_offset\n    label = dataset.paper_label\n\n    print(\"Initializing dataloader...\")\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])\n\n    train_collator = ExternalNodeCollator(\n        g, train_idx, sampler, paper_offset, feats, label\n    )\n    valid_collator = ExternalNodeCollator(\n        g, valid_idx, sampler, paper_offset, feats, label\n    )\n    # Necessary according to https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html\n    train_sampler = torch.utils.data.distributed.DistributedSampler(\n        train_collator.dataset,\n        num_replicas=world_size,\n        rank=proc_id,\n        shuffle=True,\n        drop_last=False,\n    )\n    valid_sampler = torch.utils.data.distributed.DistributedSampler(\n        valid_collator.dataset,\n        num_replicas=world_size,\n        rank=proc_id,\n        shuffle=True,\n        drop_last=False,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_collator.dataset,\n        batch_size=1024,\n        collate_fn=train_collator.collate,\n        num_workers=4,\n        sampler=train_sampler,\n    )\n\n    valid_dataloader = torch.utils.data.DataLoader(\n        valid_collator.dataset,\n        batch_size=1024,\n        collate_fn=valid_collator.collate,\n        num_workers=2,\n        sampler=valid_sampler,\n    )\n\n    print(\"Initializing model...\")\n    model = RGAT(\n        dataset.num_paper_features,\n        
dataset.num_classes,\n        1024,\n        5,\n        2,\n        4,\n        0.5,\n        \"paper\",\n    ).to(dev_id)\n\n    # convert BN to SyncBatchNorm. see https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html\n    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)\n\n    model = DistributedDataParallel(\n        model, device_ids=[dev_id], output_device=dev_id\n    )\n    opt = torch.optim.Adam(model.parameters(), lr=0.001)\n    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)\n\n    best_acc = 0\n\n    for i in range(args.epochs):\n        # make shuffling work properly across multiple epochs.\n        # see https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler\n        train_sampler.set_epoch(i)\n        model.train()\n        with tqdm.tqdm(train_dataloader) as tq:\n            for i, (input_nodes, output_nodes, mfgs) in enumerate(tq):\n                mfgs = [g.to(dev_id) for g in mfgs]\n                x = mfgs[0].srcdata[\"x\"]\n                y = mfgs[-1].dstdata[\"y\"]\n                y_hat = model(mfgs, x)\n                loss = F.cross_entropy(y_hat, y)\n                opt.zero_grad()\n                loss.backward()\n                opt.step()\n                acc = (y_hat.argmax(1) == y).float().mean()\n                tq.set_postfix(\n                    {\"loss\": \"%.4f\" % loss.item(), \"acc\": \"%.4f\" % acc.item()},\n                    refresh=False,\n                )\n\n        # eval in each process\n        model.eval()\n        correct = torch.LongTensor([0]).to(dev_id)\n        total = torch.LongTensor([0]).to(dev_id)\n        for i, (input_nodes, output_nodes, mfgs) in enumerate(\n            tqdm.tqdm(valid_dataloader)\n        ):\n            with torch.no_grad():\n                mfgs = [g.to(dev_id) for g in mfgs]\n                x = mfgs[0].srcdata[\"x\"]\n                y = mfgs[-1].dstdata[\"y\"]\n                y_hat = model(mfgs, x)\n   
             correct += (y_hat.argmax(1) == y).sum().item()\n                total += y_hat.shape[0]\n\n        # `reduce` data into process 0\n        torch.distributed.reduce(\n            correct, dst=0, op=torch.distributed.ReduceOp.SUM\n        )\n        torch.distributed.reduce(\n            total, dst=0, op=torch.distributed.ReduceOp.SUM\n        )\n        acc = (correct / total).item()\n\n        sched.step()\n\n        # process 0 print accuracy and save model\n        if proc_id == 0:\n            print(\"Validation accuracy:\", acc)\n\n            if best_acc < acc:\n                best_acc = acc\n                print(\"Updating best model...\")\n                torch.save(model.state_dict(), args.model_path)\n\n\ndef test(args, dataset, g, feats, paper_offset):\n    print(\"Loading masks and labels...\")\n    valid_idx = torch.LongTensor(dataset.get_idx_split(\"valid\")) + paper_offset\n    test_idx = torch.LongTensor(dataset.get_idx_split(\"test\")) + paper_offset\n    label = dataset.paper_label\n\n    print(\"Initializing data loader...\")\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160])\n    valid_collator = ExternalNodeCollator(\n        g, valid_idx, sampler, paper_offset, feats, label\n    )\n    valid_dataloader = torch.utils.data.DataLoader(\n        valid_collator.dataset,\n        batch_size=16,\n        shuffle=False,\n        drop_last=False,\n        collate_fn=valid_collator.collate,\n        num_workers=2,\n    )\n    test_collator = ExternalNodeCollator(\n        g, test_idx, sampler, paper_offset, feats, label\n    )\n    test_dataloader = torch.utils.data.DataLoader(\n        test_collator.dataset,\n        batch_size=16,\n        shuffle=False,\n        drop_last=False,\n        collate_fn=test_collator.collate,\n        num_workers=4,\n    )\n\n    print(\"Loading model...\")\n    model = RGAT(\n        dataset.num_paper_features,\n        dataset.num_classes,\n        1024,\n        5,\n        2,\n        
4,\n        0.5,\n        \"paper\",\n    ).cuda()\n\n    # load ddp's model parameters, we need to remove the name of 'module.'\n    state_dict = torch.load(args.model_path, weights_only=False)\n    new_state_dict = OrderedDict()\n    for k, v in state_dict.items():\n        name = k[7:]\n        new_state_dict[name] = v\n    model.load_state_dict(new_state_dict)\n\n    model.eval()\n    correct = total = 0\n    for i, (input_nodes, output_nodes, mfgs) in enumerate(\n        tqdm.tqdm(valid_dataloader)\n    ):\n        with torch.no_grad():\n            mfgs = [g.to(\"cuda\") for g in mfgs]\n            x = mfgs[0].srcdata[\"x\"]\n            y = mfgs[-1].dstdata[\"y\"]\n            y_hat = model(mfgs, x)\n            correct += (y_hat.argmax(1) == y).sum().item()\n            total += y_hat.shape[0]\n    acc = correct / total\n    print(\"Validation accuracy:\", acc)\n    evaluator = MAG240MEvaluator()\n    y_preds = []\n    for i, (input_nodes, output_nodes, mfgs) in enumerate(\n        tqdm.tqdm(test_dataloader)\n    ):\n        with torch.no_grad():\n            mfgs = [g.to(\"cuda\") for g in mfgs]\n            x = mfgs[0].srcdata[\"x\"]\n            y = mfgs[-1].dstdata[\"y\"]\n            y_hat = model(mfgs, x)\n            y_preds.append(y_hat.argmax(1).cpu())\n    evaluator.save_test_submission(\n        {\"y_pred\": torch.cat(y_preds)}, args.submission_path\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--rootdir\",\n        type=str,\n        default=\".\",\n        help=\"Directory to download the OGB dataset.\",\n    )\n    parser.add_argument(\n        \"--graph-path\",\n        type=str,\n        default=\"./graph.dgl\",\n        help=\"Path to the graph.\",\n    )\n    parser.add_argument(\n        \"--full-feature-path\",\n        type=str,\n        default=\"./full.npy\",\n        help=\"Path to the features of all nodes.\",\n    )\n    parser.add_argument(\n        
\"--epochs\", type=int, default=100, help=\"Number of epochs.\"\n    )\n    parser.add_argument(\n        \"--model-path\",\n        type=str,\n        default=\"./model_ddp.pt\",\n        help=\"Path to store the best model.\",\n    )\n    parser.add_argument(\n        \"--submission-path\",\n        type=str,\n        default=\"./results_ddp\",\n        help=\"Submission directory.\",\n    )\n    parser.add_argument(\"--gpus\", type=str, default=\"0,1,2\")\n    args = parser.parse_args()\n\n    devices = list(map(int, args.gpus.split(\",\")))\n    n_gpus = len(devices)\n\n    if n_gpus <= 1:\n        print(\"make sure the number of gpus greater than 1!\")\n        sys.exit()\n\n    dataset = MAG240MDataset(root=args.rootdir)\n\n    print(\"Loading graph\")\n    (g,), _ = dgl.load_graphs(args.graph_path)\n    g = g.formats([\"csc\"])\n\n    print(\"Loading features\")\n    paper_offset = dataset.num_authors + dataset.num_institutions\n    num_nodes = paper_offset + dataset.num_papers\n    num_features = dataset.num_paper_features\n    feats = np.memmap(\n        args.full_feature_path,\n        mode=\"r\",\n        dtype=\"float16\",\n        shape=(num_nodes, num_features),\n    )\n\n    mp.spawn(\n        train,\n        args=(n_gpus, args, dataset, g, feats, paper_offset),\n        nprocs=n_gpus,\n    )\n\n    test(args, dataset, g, feats, paper_offset)\n"
  },
  {
    "path": "examples/pytorch/ogb_lsc/PCQM4M/README.md",
    "content": "# Baseline Code for PCQM4M-LSC\n\nThe code is ported from the official examples [here](https://github.com/snap-stanford/ogb/tree/master/examples/lsc/pcqm4m). Please refer to the [OGB-LSC paper](https://arxiv.org/abs/2103.09430) for the detailed setting.\n\n## Installation Requirements\n\n```\nogb>=1.3.0\nrdkit>=2019.03.1\ntorch>=1.7.0\n```\n\nWe recommend installing RDKit with `conda install -c rdkit rdkit==2019.03.1`.\n\n## Commandline Arguments\n\n- `LOG_DIR`: Tensorboard log directory.\n- `CHECKPOINT_DIR`: Directory to save the best validation checkpoint. The checkpoint file will be saved at `${CHECKPOINT_DIR}/checkpoint.pt`.\n- `TEST_DIR`: Directory path to save the test submission. The test file will be saved at `${TEST_DIR}/y_pred_pcqm4m.npz`.\n\n## Baseline Models\n\n### GIN [1]\n\n```\npython main.py --gnn gin --log_dir $LOG_DIR --checkpoint_dir $CHECKPOINT_DIR --save_test_dir $TEST_DIR\n```\n\n### GIN-virtual [1,3]\n\n```\npython main.py --gnn gin-virtual --log_dir $LOG_DIR --checkpoint_dir $CHECKPOINT_DIR --save_test_dir $TEST_DIR\n```\n\n### GCN [2]\n\n```\npython main.py --gnn gcn --log_dir $LOG_DIR --checkpoint_dir $CHECKPOINT_DIR --save_test_dir $TEST_DIR\n```\n\n### GCN-virtual [2,3]\n\n```\npython main.py --gnn gcn-virtual --log_dir $LOG_DIR --checkpoint_dir $CHECKPOINT_DIR --save_test_dir $TEST_DIR\n```\n\n## Measuring the Test Inference Time\n\nThe code below takes **the raw SMILES strings as input**, uses the saved checkpoint, and performs inference over for all the 377,423 test molecules.\n\n```\npython test_inference.py --gnn $GNN --checkpoint_dir $CHECKPOINT_DIR --save_test_dir $TEST_DIR\n```\n\nFor your model, **the total inference time needs to be less than 12 hours on a single GPU and a CPU**. Ideally, you \nshould use the CPU/GPU spec of the organizers, which consists of a single GeForce RTX 2080 GPU and an Intel(R) Xeon(R) \nGold 6148 CPU @ 2.40GHz. 
However, the organizers also allow the use of other GPU/CPU specs, as long as the specs are \nclearly reported in the final submission.\n\n## Performance\n\n| Model       | Original Valid MAE | DGL Valid MAE | #Parameters | \n| ----------- | ------------------ | ------------- | ----------- | \n| GIN         | 0.1536             | 0.1536        | 3.8M        | \n| GIN-virtual | 0.1396             | 0.1407        | 6.7M        |\n| GCN         | 0.1684             | 0.1683        | 2.0M        |\n| GCN-virtual | 0.1510             | 0.1557        | 4.9M        |\n\n## References\n\n[1] Xu, K., Hu, W., Leskovec, J., & Jegelka, S. (2019). How powerful are graph neural networks?. ICLR 2019\n\n[2] Kipf, T. N., & Welling, M. (2017). Semi-supervised classification with graph convolutional networks. ICLR 2017\n\n[3] Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., & Dahl, G. E. Neural message passing for quantum chemistry. ICML 2017.\n"
  },
  {
    "path": "examples/pytorch/ogb_lsc/PCQM4M/conv.py",
    "content": "import dgl\nimport dgl.function as fn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import SumPooling\nfrom ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder\n\n\n### GIN convolution along the graph structure\nclass GINConv(nn.Module):\n    def __init__(self, emb_dim):\n        \"\"\"\n        emb_dim (int): node embedding dimensionality\n        \"\"\"\n\n        super(GINConv, self).__init__()\n\n        self.mlp = nn.Sequential(\n            nn.Linear(emb_dim, emb_dim),\n            nn.BatchNorm1d(emb_dim),\n            nn.ReLU(),\n            nn.Linear(emb_dim, emb_dim),\n        )\n        self.eps = nn.Parameter(torch.Tensor([0]))\n\n        self.bond_encoder = BondEncoder(emb_dim=emb_dim)\n\n    def forward(self, g, x, edge_attr):\n        with g.local_scope():\n            edge_embedding = self.bond_encoder(edge_attr)\n            g.ndata[\"x\"] = x\n            g.apply_edges(fn.copy_u(\"x\", \"m\"))\n            g.edata[\"m\"] = F.relu(g.edata[\"m\"] + edge_embedding)\n            g.update_all(fn.copy_e(\"m\", \"m\"), fn.sum(\"m\", \"new_x\"))\n            out = self.mlp((1 + self.eps) * x + g.ndata[\"new_x\"])\n\n            return out\n\n\n### GCN convolution along the graph structure\nclass GCNConv(nn.Module):\n    def __init__(self, emb_dim):\n        \"\"\"\n        emb_dim (int): node embedding dimensionality\n        \"\"\"\n\n        super(GCNConv, self).__init__()\n\n        self.linear = nn.Linear(emb_dim, emb_dim)\n        self.root_emb = nn.Embedding(1, emb_dim)\n        self.bond_encoder = BondEncoder(emb_dim=emb_dim)\n\n    def forward(self, g, x, edge_attr):\n        with g.local_scope():\n            x = self.linear(x)\n            edge_embedding = self.bond_encoder(edge_attr)\n\n            # Molecular graphs are undirected\n            # g.out_degrees() is the same as g.in_degrees()\n            degs = (g.out_degrees().float() + 1).to(x.device)\n            norm = 
torch.pow(degs, -0.5).unsqueeze(-1)  # (N, 1)\n            g.ndata[\"norm\"] = norm\n            g.apply_edges(fn.u_mul_v(\"norm\", \"norm\", \"norm\"))\n\n            g.ndata[\"x\"] = x\n            g.apply_edges(fn.copy_u(\"x\", \"m\"))\n            g.edata[\"m\"] = g.edata[\"norm\"] * F.relu(\n                g.edata[\"m\"] + edge_embedding\n            )\n            g.update_all(fn.copy_e(\"m\", \"m\"), fn.sum(\"m\", \"new_x\"))\n            out = g.ndata[\"new_x\"] + F.relu(\n                x + self.root_emb.weight\n            ) * 1.0 / degs.view(-1, 1)\n\n            return out\n\n\n### GNN to generate node embedding\nclass GNN_node(nn.Module):\n    \"\"\"\n    Output:\n        node representations\n    \"\"\"\n\n    def __init__(\n        self,\n        num_layers,\n        emb_dim,\n        drop_ratio=0.5,\n        JK=\"last\",\n        residual=False,\n        gnn_type=\"gin\",\n    ):\n        \"\"\"\n        num_layers (int): number of GNN message passing layers\n        emb_dim (int): node embedding dimensionality\n        \"\"\"\n\n        super(GNN_node, self).__init__()\n        self.num_layers = num_layers\n        self.drop_ratio = drop_ratio\n        self.JK = JK\n        ### add residual connection or not\n        self.residual = residual\n\n        if self.num_layers < 2:\n            raise ValueError(\"Number of GNN layers must be greater than 1.\")\n\n        self.atom_encoder = AtomEncoder(emb_dim)\n\n        ###List of GNNs\n        self.convs = nn.ModuleList()\n        self.batch_norms = nn.ModuleList()\n\n        for layer in range(num_layers):\n            if gnn_type == \"gin\":\n                self.convs.append(GINConv(emb_dim))\n            elif gnn_type == \"gcn\":\n                self.convs.append(GCNConv(emb_dim))\n            else:\n                ValueError(\"Undefined GNN type called {}\".format(gnn_type))\n\n            self.batch_norms.append(nn.BatchNorm1d(emb_dim))\n\n    def forward(self, g, x, edge_attr):\n        ### 
computing input node embedding\n        h_list = [self.atom_encoder(x)]\n        for layer in range(self.num_layers):\n            h = self.convs[layer](g, h_list[layer], edge_attr)\n            h = self.batch_norms[layer](h)\n\n            if layer == self.num_layers - 1:\n                # remove relu for the last layer\n                h = F.dropout(h, self.drop_ratio, training=self.training)\n            else:\n                h = F.dropout(\n                    F.relu(h), self.drop_ratio, training=self.training\n                )\n\n            if self.residual:\n                h += h_list[layer]\n\n            h_list.append(h)\n\n        ### Different implementations of Jk-concat\n        if self.JK == \"last\":\n            node_representation = h_list[-1]\n        elif self.JK == \"sum\":\n            node_representation = 0\n            for layer in range(self.num_layers):\n                node_representation += h_list[layer]\n\n        return node_representation\n\n\n### Virtual GNN to generate node embedding\nclass GNN_node_Virtualnode(nn.Module):\n    \"\"\"\n    Output:\n        node representations\n    \"\"\"\n\n    def __init__(\n        self,\n        num_layers,\n        emb_dim,\n        drop_ratio=0.5,\n        JK=\"last\",\n        residual=False,\n        gnn_type=\"gin\",\n    ):\n        \"\"\"\n        num_layers (int): number of GNN message passing layers\n        emb_dim (int): node embedding dimensionality\n        \"\"\"\n\n        super(GNN_node_Virtualnode, self).__init__()\n        self.num_layers = num_layers\n        self.drop_ratio = drop_ratio\n        self.JK = JK\n        ### add residual connection or not\n        self.residual = residual\n\n        if self.num_layers < 2:\n            raise ValueError(\"Number of GNN layers must be greater than 1.\")\n\n        self.atom_encoder = AtomEncoder(emb_dim)\n\n        ### set the initial virtual node embedding to 0.\n        self.virtualnode_embedding = nn.Embedding(1, emb_dim)\n  
      nn.init.constant_(self.virtualnode_embedding.weight.data, 0)\n\n        ### List of GNNs\n        self.convs = nn.ModuleList()\n        ### batch norms applied to node embeddings\n        self.batch_norms = nn.ModuleList()\n\n        ### List of MLPs to transform virtual node at every layer\n        self.mlp_virtualnode_list = nn.ModuleList()\n\n        for layer in range(num_layers):\n            if gnn_type == \"gin\":\n                self.convs.append(GINConv(emb_dim))\n            elif gnn_type == \"gcn\":\n                self.convs.append(GCNConv(emb_dim))\n            else:\n                ValueError(\"Undefined GNN type called {}\".format(gnn_type))\n\n            self.batch_norms.append(nn.BatchNorm1d(emb_dim))\n\n        for layer in range(num_layers - 1):\n            self.mlp_virtualnode_list.append(\n                nn.Sequential(\n                    nn.Linear(emb_dim, emb_dim),\n                    nn.BatchNorm1d(emb_dim),\n                    nn.ReLU(),\n                    nn.Linear(emb_dim, emb_dim),\n                    nn.BatchNorm1d(emb_dim),\n                    nn.ReLU(),\n                )\n            )\n        self.pool = SumPooling()\n\n    def forward(self, g, x, edge_attr):\n        ### virtual node embeddings for graphs\n        virtualnode_embedding = self.virtualnode_embedding(\n            torch.zeros(g.batch_size).to(x.dtype).to(x.device)\n        )\n\n        h_list = [self.atom_encoder(x)]\n        batch_id = dgl.broadcast_nodes(\n            g, torch.arange(g.batch_size).to(x.device)\n        )\n        for layer in range(self.num_layers):\n            ### add message from virtual nodes to graph nodes\n            h_list[layer] = h_list[layer] + virtualnode_embedding[batch_id]\n\n            ### Message passing among graph nodes\n            h = self.convs[layer](g, h_list[layer], edge_attr)\n            h = self.batch_norms[layer](h)\n            if layer == self.num_layers - 1:\n                # remove relu for the 
last layer\n                h = F.dropout(h, self.drop_ratio, training=self.training)\n            else:\n                h = F.dropout(\n                    F.relu(h), self.drop_ratio, training=self.training\n                )\n\n            if self.residual:\n                h = h + h_list[layer]\n\n            h_list.append(h)\n\n            ### update the virtual nodes\n            if layer < self.num_layers - 1:\n                ### add message from graph nodes to virtual nodes\n                virtualnode_embedding_temp = (\n                    self.pool(g, h_list[layer]) + virtualnode_embedding\n                )\n                ### transform virtual nodes using MLP\n                virtualnode_embedding_temp = self.mlp_virtualnode_list[layer](\n                    virtualnode_embedding_temp\n                )\n\n                if self.residual:\n                    virtualnode_embedding = virtualnode_embedding + F.dropout(\n                        virtualnode_embedding_temp,\n                        self.drop_ratio,\n                        training=self.training,\n                    )\n                else:\n                    virtualnode_embedding = F.dropout(\n                        virtualnode_embedding_temp,\n                        self.drop_ratio,\n                        training=self.training,\n                    )\n\n        ### Different implementations of Jk-concat\n        if self.JK == \"last\":\n            node_representation = h_list[-1]\n        elif self.JK == \"sum\":\n            node_representation = 0\n            for layer in range(self.num_layers):\n                node_representation += h_list[layer]\n\n        return node_representation\n"
  },
  {
    "path": "examples/pytorch/ogb_lsc/PCQM4M/gnn.py",
    "content": "import torch\nimport torch.nn as nn\nfrom conv import GNN_node, GNN_node_Virtualnode\n\nfrom dgl.nn.pytorch import (\n    AvgPooling,\n    GlobalAttentionPooling,\n    MaxPooling,\n    Set2Set,\n    SumPooling,\n)\n\n\nclass GNN(nn.Module):\n    def __init__(\n        self,\n        num_tasks=1,\n        num_layers=5,\n        emb_dim=300,\n        gnn_type=\"gin\",\n        virtual_node=True,\n        residual=False,\n        drop_ratio=0,\n        JK=\"last\",\n        graph_pooling=\"sum\",\n    ):\n        \"\"\"\n        num_tasks (int): number of labels to be predicted\n        virtual_node (bool): whether to add virtual node or not\n        \"\"\"\n        super(GNN, self).__init__()\n\n        self.num_layers = num_layers\n        self.drop_ratio = drop_ratio\n        self.JK = JK\n        self.emb_dim = emb_dim\n        self.num_tasks = num_tasks\n        self.graph_pooling = graph_pooling\n\n        if self.num_layers < 2:\n            raise ValueError(\"Number of GNN layers must be greater than 1.\")\n\n        ### GNN to generate node embeddings\n        if virtual_node:\n            self.gnn_node = GNN_node_Virtualnode(\n                num_layers,\n                emb_dim,\n                JK=JK,\n                drop_ratio=drop_ratio,\n                residual=residual,\n                gnn_type=gnn_type,\n            )\n        else:\n            self.gnn_node = GNN_node(\n                num_layers,\n                emb_dim,\n                JK=JK,\n                drop_ratio=drop_ratio,\n                residual=residual,\n                gnn_type=gnn_type,\n            )\n\n        ### Pooling function to generate whole-graph embeddings\n        if self.graph_pooling == \"sum\":\n            self.pool = SumPooling()\n        elif self.graph_pooling == \"mean\":\n            self.pool = AvgPooling()\n        elif self.graph_pooling == \"max\":\n            self.pool = MaxPooling\n        elif self.graph_pooling == \"attention\":\n  
          self.pool = GlobalAttentionPooling(\n                gate_nn=nn.Sequential(\n                    nn.Linear(emb_dim, 2 * emb_dim),\n                    nn.BatchNorm1d(2 * emb_dim),\n                    nn.ReLU(),\n                    nn.Linear(2 * emb_dim, 1),\n                )\n            )\n\n        elif self.graph_pooling == \"set2set\":\n            self.pool = Set2Set(emb_dim, n_iters=2, n_layers=2)\n        else:\n            raise ValueError(\"Invalid graph pooling type.\")\n\n        if graph_pooling == \"set2set\":\n            self.graph_pred_linear = nn.Linear(2 * self.emb_dim, self.num_tasks)\n        else:\n            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)\n\n    def forward(self, g, x, edge_attr):\n        h_node = self.gnn_node(g, x, edge_attr)\n\n        h_graph = self.pool(g, h_node)\n        output = self.graph_pred_linear(h_graph)\n\n        if self.training:\n            return output\n        else:\n            return torch.clamp(output, min=0, max=50)\n"
  },
  {
    "path": "examples/pytorch/ogb_lsc/PCQM4M/main.py",
    "content": "import argparse\nimport os\nimport random\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.optim as optim\nfrom gnn import GNN\nfrom ogb.lsc import DglPCQM4MDataset, PCQM4MEvaluator\nfrom torch.optim.lr_scheduler import StepLR\nfrom torch.utils.data import DataLoader\nfrom torch.utils.tensorboard import SummaryWriter\nfrom tqdm import tqdm\n\nreg_criterion = torch.nn.L1Loss()\n\n\ndef collate_dgl(samples):\n    graphs, labels = map(list, zip(*samples))\n    batched_graph = dgl.batch(graphs)\n    labels = torch.stack(labels)\n\n    return batched_graph, labels\n\n\ndef train(model, device, loader, optimizer):\n    model.train()\n    loss_accum = 0\n\n    for step, (bg, labels) in enumerate(tqdm(loader, desc=\"Iteration\")):\n        bg = bg.to(device)\n        x = bg.ndata.pop(\"feat\")\n        edge_attr = bg.edata.pop(\"feat\")\n        labels = labels.to(device)\n\n        pred = model(bg, x, edge_attr).view(\n            -1,\n        )\n        optimizer.zero_grad()\n        loss = reg_criterion(pred, labels)\n        loss.backward()\n        optimizer.step()\n\n        loss_accum += loss.detach().cpu().item()\n\n    return loss_accum / (step + 1)\n\n\ndef eval(model, device, loader, evaluator):\n    model.eval()\n    y_true = []\n    y_pred = []\n\n    for step, (bg, labels) in enumerate(tqdm(loader, desc=\"Iteration\")):\n        bg = bg.to(device)\n        x = bg.ndata.pop(\"feat\")\n        edge_attr = bg.edata.pop(\"feat\")\n        labels = labels.to(device)\n\n        with torch.no_grad():\n            pred = model(bg, x, edge_attr).view(\n                -1,\n            )\n\n        y_true.append(labels.view(pred.shape).detach().cpu())\n        y_pred.append(pred.detach().cpu())\n\n    y_true = torch.cat(y_true, dim=0)\n    y_pred = torch.cat(y_pred, dim=0)\n\n    input_dict = {\"y_true\": y_true, \"y_pred\": y_pred}\n\n    return evaluator.eval(input_dict)[\"mae\"]\n\n\ndef test(model, device, loader):\n    model.eval()\n 
   y_pred = []\n\n    for step, (bg, _) in enumerate(tqdm(loader, desc=\"Iteration\")):\n        bg = bg.to(device)\n        x = bg.ndata.pop(\"feat\")\n        edge_attr = bg.edata.pop(\"feat\")\n\n        with torch.no_grad():\n            pred = model(bg, x, edge_attr).view(\n                -1,\n            )\n\n        y_pred.append(pred.detach().cpu())\n\n    y_pred = torch.cat(y_pred, dim=0)\n\n    return y_pred\n\n\ndef main():\n    # Training settings\n    parser = argparse.ArgumentParser(\n        description=\"GNN baselines on pcqm4m with DGL\"\n    )\n    parser.add_argument(\n        \"--seed\", type=int, default=42, help=\"random seed to use (default: 42)\"\n    )\n    parser.add_argument(\n        \"--device\",\n        type=int,\n        default=0,\n        help=\"which gpu to use if any (default: 0)\",\n    )\n    parser.add_argument(\n        \"--gnn\",\n        type=str,\n        default=\"gin-virtual\",\n        help=\"GNN to use, which can be from \"\n        \"[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)\",\n    )\n    parser.add_argument(\n        \"--graph_pooling\",\n        type=str,\n        default=\"sum\",\n        help=\"graph pooling strategy mean or sum (default: sum)\",\n    )\n    parser.add_argument(\n        \"--drop_ratio\", type=float, default=0, help=\"dropout ratio (default: 0)\"\n    )\n    parser.add_argument(\n        \"--num_layers\",\n        type=int,\n        default=5,\n        help=\"number of GNN message passing layers (default: 5)\",\n    )\n    parser.add_argument(\n        \"--emb_dim\",\n        type=int,\n        default=600,\n        help=\"dimensionality of hidden units in GNNs (default: 600)\",\n    )\n    parser.add_argument(\n        \"--train_subset\",\n        action=\"store_true\",\n        help=\"use 10% of the training set for training\",\n    )\n    parser.add_argument(\n        \"--batch_size\",\n        type=int,\n        default=256,\n        help=\"input batch size for training 
(default: 256)\",\n    )\n    parser.add_argument(\n        \"--epochs\",\n        type=int,\n        default=100,\n        help=\"number of epochs to train (default: 100)\",\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=0,\n        help=\"number of workers (default: 0)\",\n    )\n    parser.add_argument(\n        \"--log_dir\",\n        type=str,\n        default=\"\",\n        help=\"tensorboard log directory. If not specified, \"\n        \"tensorboard will not be used.\",\n    )\n    parser.add_argument(\n        \"--checkpoint_dir\",\n        type=str,\n        default=\"\",\n        help=\"directory to save checkpoint\",\n    )\n    parser.add_argument(\n        \"--save_test_dir\",\n        type=str,\n        default=\"\",\n        help=\"directory to save test submission file\",\n    )\n    args = parser.parse_args()\n\n    print(args)\n\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    random.seed(args.seed)\n\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed(args.seed)\n        device = torch.device(\"cuda:\" + str(args.device))\n    else:\n        device = torch.device(\"cpu\")\n\n    ### automatic dataloading and splitting\n    dataset = DglPCQM4MDataset(root=\"dataset/\")\n\n    # split_idx['train'], split_idx['valid'], split_idx['test']\n    # separately gives a 1D int64 tensor\n    split_idx = dataset.get_idx_split()\n\n    ### automatic evaluator.\n    evaluator = PCQM4MEvaluator()\n\n    if args.train_subset:\n        subset_ratio = 0.1\n        subset_idx = torch.randperm(len(split_idx[\"train\"]))[\n            : int(subset_ratio * len(split_idx[\"train\"]))\n        ]\n        train_loader = DataLoader(\n            dataset[split_idx[\"train\"][subset_idx]],\n            batch_size=args.batch_size,\n            shuffle=True,\n            num_workers=args.num_workers,\n            collate_fn=collate_dgl,\n        )\n    else:\n        train_loader = 
DataLoader(\n            dataset[split_idx[\"train\"]],\n            batch_size=args.batch_size,\n            shuffle=True,\n            num_workers=args.num_workers,\n            collate_fn=collate_dgl,\n        )\n\n    valid_loader = DataLoader(\n        dataset[split_idx[\"valid\"]],\n        batch_size=args.batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n        collate_fn=collate_dgl,\n    )\n\n    if args.save_test_dir != \"\":\n        test_loader = DataLoader(\n            dataset[split_idx[\"test\"]],\n            batch_size=args.batch_size,\n            shuffle=False,\n            num_workers=args.num_workers,\n            collate_fn=collate_dgl,\n        )\n\n    if args.checkpoint_dir != \"\":\n        os.makedirs(args.checkpoint_dir, exist_ok=True)\n\n    shared_params = {\n        \"num_layers\": args.num_layers,\n        \"emb_dim\": args.emb_dim,\n        \"drop_ratio\": args.drop_ratio,\n        \"graph_pooling\": args.graph_pooling,\n    }\n\n    if args.gnn == \"gin\":\n        model = GNN(gnn_type=\"gin\", virtual_node=False, **shared_params).to(\n            device\n        )\n    elif args.gnn == \"gin-virtual\":\n        model = GNN(gnn_type=\"gin\", virtual_node=True, **shared_params).to(\n            device\n        )\n    elif args.gnn == \"gcn\":\n        model = GNN(gnn_type=\"gcn\", virtual_node=False, **shared_params).to(\n            device\n        )\n    elif args.gnn == \"gcn-virtual\":\n        model = GNN(gnn_type=\"gcn\", virtual_node=True, **shared_params).to(\n            device\n        )\n    else:\n        raise ValueError(\"Invalid GNN type\")\n\n    num_params = sum(p.numel() for p in model.parameters())\n    print(f\"#Params: {num_params}\")\n\n    optimizer = optim.Adam(model.parameters(), lr=0.001)\n\n    if args.log_dir != \"\":\n        writer = SummaryWriter(log_dir=args.log_dir)\n\n    best_valid_mae = 1000\n\n    if args.train_subset:\n        scheduler = StepLR(optimizer, 
step_size=300, gamma=0.25)\n        args.epochs = 1000\n    else:\n        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)\n\n    for epoch in range(1, args.epochs + 1):\n        print(\"=====Epoch {}\".format(epoch))\n        print(\"Training...\")\n        train_mae = train(model, device, train_loader, optimizer)\n\n        print(\"Evaluating...\")\n        valid_mae = eval(model, device, valid_loader, evaluator)\n\n        print({\"Train\": train_mae, \"Validation\": valid_mae})\n\n        if args.log_dir != \"\":\n            writer.add_scalar(\"valid/mae\", valid_mae, epoch)\n            writer.add_scalar(\"train/mae\", train_mae, epoch)\n\n        if valid_mae < best_valid_mae:\n            best_valid_mae = valid_mae\n            if args.checkpoint_dir != \"\":\n                print(\"Saving checkpoint...\")\n                checkpoint = {\n                    \"epoch\": epoch,\n                    \"model_state_dict\": model.state_dict(),\n                    \"optimizer_state_dict\": optimizer.state_dict(),\n                    \"scheduler_state_dict\": scheduler.state_dict(),\n                    \"best_val_mae\": best_valid_mae,\n                    \"num_params\": num_params,\n                }\n                torch.save(\n                    checkpoint,\n                    os.path.join(args.checkpoint_dir, \"checkpoint.pt\"),\n                )\n\n            if args.save_test_dir != \"\":\n                print(\"Predicting on test data...\")\n                y_pred = test(model, device, test_loader)\n                print(\"Saving test submission file...\")\n                evaluator.save_test_submission(\n                    {\"y_pred\": y_pred}, args.save_test_dir\n                )\n\n        scheduler.step()\n\n        print(f\"Best validation MAE so far: {best_valid_mae}\")\n\n    if args.log_dir != \"\":\n        writer.close()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/pytorch/ogb_lsc/PCQM4M/test_inference.py",
    "content": "import argparse\nimport os\nimport random\n\nimport dgl\n\nimport numpy as np\nimport torch\nfrom gnn import GNN\nfrom ogb.lsc import PCQM4MDataset, PCQM4MEvaluator\nfrom ogb.utils import smiles2graph\nfrom torch.utils.data import DataLoader\nfrom tqdm import tqdm\n\n\ndef collate_dgl(graphs):\n    batched_graph = dgl.batch(graphs)\n\n    return batched_graph\n\n\ndef test(model, device, loader):\n    model.eval()\n    y_pred = []\n\n    for step, bg in enumerate(tqdm(loader, desc=\"Iteration\")):\n        bg = bg.to(device)\n        x = bg.ndata.pop(\"feat\")\n        edge_attr = bg.edata.pop(\"feat\")\n\n        with torch.no_grad():\n            pred = model(bg, x, edge_attr).view(\n                -1,\n            )\n\n        y_pred.append(pred.detach().cpu())\n\n    y_pred = torch.cat(y_pred, dim=0)\n\n    return y_pred\n\n\nclass OnTheFlyPCQMDataset(object):\n    def __init__(self, smiles_list, smiles2graph=smiles2graph):\n        super(OnTheFlyPCQMDataset, self).__init__()\n        self.smiles_list = smiles_list\n        self.smiles2graph = smiles2graph\n\n    def __getitem__(self, idx):\n        \"\"\"Get datapoint with index\"\"\"\n        smiles, _ = self.smiles_list[idx]\n        graph = self.smiles2graph(smiles)\n\n        dgl_graph = dgl.graph(\n            (graph[\"edge_index\"][0], graph[\"edge_index\"][1]),\n            num_nodes=graph[\"num_nodes\"],\n        )\n        dgl_graph.edata[\"feat\"] = torch.from_numpy(graph[\"edge_feat\"]).to(\n            torch.int64\n        )\n        dgl_graph.ndata[\"feat\"] = torch.from_numpy(graph[\"node_feat\"]).to(\n            torch.int64\n        )\n\n        return dgl_graph\n\n    def __len__(self):\n        \"\"\"Length of the dataset\n        Returns\n        -------\n        int\n            Length of Dataset\n        \"\"\"\n        return len(self.smiles_list)\n\n\ndef main():\n    # Training settings\n    parser = argparse.ArgumentParser(\n        description=\"GNN baselines on 
pcqm4m with DGL\"\n    )\n    parser.add_argument(\n        \"--seed\", type=int, default=42, help=\"random seed to use (default: 42)\"\n    )\n    parser.add_argument(\n        \"--device\",\n        type=int,\n        default=0,\n        help=\"which gpu to use if any (default: 0)\",\n    )\n    parser.add_argument(\n        \"--gnn\",\n        type=str,\n        default=\"gin-virtual\",\n        help=\"GNN to use, which can be from \"\n        \"[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)\",\n    )\n    parser.add_argument(\n        \"--graph_pooling\",\n        type=str,\n        default=\"sum\",\n        help=\"graph pooling strategy mean or sum (default: sum)\",\n    )\n    parser.add_argument(\n        \"--drop_ratio\", type=float, default=0, help=\"dropout ratio (default: 0)\"\n    )\n    parser.add_argument(\n        \"--num_layers\",\n        type=int,\n        default=5,\n        help=\"number of GNN message passing layers (default: 5)\",\n    )\n    parser.add_argument(\n        \"--emb_dim\",\n        type=int,\n        default=600,\n        help=\"dimensionality of hidden units in GNNs (default: 600)\",\n    )\n    parser.add_argument(\n        \"--batch_size\",\n        type=int,\n        default=256,\n        help=\"input batch size for training (default: 256)\",\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=0,\n        help=\"number of workers (default: 0)\",\n    )\n    parser.add_argument(\n        \"--checkpoint_dir\",\n        type=str,\n        default=\"\",\n        help=\"directory to save checkpoint\",\n    )\n    parser.add_argument(\n        \"--save_test_dir\",\n        type=str,\n        default=\"\",\n        help=\"directory to save test submission file\",\n    )\n    args = parser.parse_args()\n\n    print(args)\n\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n    random.seed(args.seed)\n\n    if torch.cuda.is_available():\n        
torch.cuda.manual_seed(args.seed)\n        device = torch.device(\"cuda:\" + str(args.device))\n    else:\n        device = torch.device(\"cpu\")\n\n    ### automatic data loading and splitting\n    ### Read in the raw SMILES strings\n    smiles_dataset = PCQM4MDataset(root=\"dataset/\", only_smiles=True)\n    split_idx = smiles_dataset.get_idx_split()\n\n    test_smiles_dataset = [smiles_dataset[i] for i in split_idx[\"test\"]]\n    onthefly_dataset = OnTheFlyPCQMDataset(test_smiles_dataset)\n    test_loader = DataLoader(\n        onthefly_dataset,\n        batch_size=args.batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n        collate_fn=collate_dgl,\n    )\n\n    ### automatic evaluator.\n    evaluator = PCQM4MEvaluator()\n\n    shared_params = {\n        \"num_layers\": args.num_layers,\n        \"emb_dim\": args.emb_dim,\n        \"drop_ratio\": args.drop_ratio,\n        \"graph_pooling\": args.graph_pooling,\n    }\n\n    if args.gnn == \"gin\":\n        model = GNN(gnn_type=\"gin\", virtual_node=False, **shared_params).to(\n            device\n        )\n    elif args.gnn == \"gin-virtual\":\n        model = GNN(gnn_type=\"gin\", virtual_node=True, **shared_params).to(\n            device\n        )\n    elif args.gnn == \"gcn\":\n        model = GNN(gnn_type=\"gcn\", virtual_node=False, **shared_params).to(\n            device\n        )\n    elif args.gnn == \"gcn-virtual\":\n        model = GNN(gnn_type=\"gcn\", virtual_node=True, **shared_params).to(\n            device\n        )\n    else:\n        raise ValueError(\"Invalid GNN type\")\n\n    num_params = sum(p.numel() for p in model.parameters())\n    print(f\"#Params: {num_params}\")\n\n    checkpoint_path = os.path.join(args.checkpoint_dir, \"checkpoint.pt\")\n    if not os.path.exists(checkpoint_path):\n        raise RuntimeError(f\"Checkpoint file not found at {checkpoint_path}\")\n\n    ## reading in checkpoint\n    checkpoint = torch.load(checkpoint_path, 
weights_only=False)\n    model.load_state_dict(checkpoint[\"model_state_dict\"])\n\n    print(\"Predicting on test data...\")\n    y_pred = test(model, device, test_loader)\n    print(\"Saving test submission file...\")\n    evaluator.save_test_submission({\"y_pred\": y_pred}, args.save_test_dir)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/pytorch/ogb_lsc/README.md",
    "content": "# Baselines for OGB Large-Scale Challenge (LSC) at KDD Cup 2021\n\n**Please upgrade your OGB to 1.3.1 to enable faster downloads**:\n\n- [Node Classification with MAG240M](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/mag240m_kddcup2021.zip)\n- [Link Prediction with WikiKG90M](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/wikikg90m_kddcup2021.zip)\n- [Graph Classification with PCQM4M](https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip)\n\n\n# checksum md5sum of the files\nmag240m_kddcup2021.zip     : ```bd61c9446f557fbe4430d9a7ce108b34```\n\nwikikg90m_kddcup2021.zip   : ```73d4f5dde29d78669330b4db4c12fc9c``` \n\npcqm4m_kddcup2021.zip.     : ```5144ebaa7c67d24da1a2acbe41f57f6a``` \n"
  },
  {
    "path": "examples/pytorch/ogc/README.md",
    "content": "# Optimized Graph Convolution (OGC)\n\nThis DGL example implements the OGC method from the paper: [From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited](https://arxiv.org/abs/2309.13599).\nWith only one trainable layer, OGC is a very simple but powerful graph convolution method.\n\n\n## Example Implementor\n\nThis example was implemented by [Sinuo Xu](https://github.com/SinuoXu) when she was an undergraduate at SJTU.\n\n\n## Dependencies\n\nPython     3.11.5\nPyTorch    2.0.1 \nDGL       1.1.2 \nscikit-learn 1.3.1\n\n\n## Dataset\n\nThe DGL's built-in Cora, Pubmed and Citeseer datasets, as follows:\n\n| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |\n| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |\n| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1000 |\n| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1000 |\n| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1000 |\n\n\n## Usage\n\n```bash\npython main.py --dataset cora\npython main.py --dataset citeseer\npython main.py --dataset pubmed\n```\n\n## Performance\n\n| Dataset | Cora | Citeseer | Pubmed |\n| :-: | :-: | :-: | :-: |\n| OGC (DGL) | **86.9(±0.2)** | **77.4(±0.1)** | **83.6(±0.1)** |\n| OGC (Reported) | **86.9(±0.0)** | **77.4(±0.0)** | 83.4(±0.0) |\n"
  },
  {
    "path": "examples/pytorch/ogc/ogc.py",
    "content": "import dgl.sparse as dglsp\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom utils import LinearNeuralNetwork\n\n\nclass OGC(nn.Module):\n    def __init__(self, graph):\n        super(OGC, self).__init__()\n        self.linear_clf = LinearNeuralNetwork(\n            nfeat=graph.ndata[\"feat\"].shape[1],\n            nclass=graph.ndata[\"label\"].max().item() + 1,\n            bias=False,\n        )\n\n        self.label = graph.ndata[\"label\"]\n        self.label_one_hot = F.one_hot(graph.ndata[\"label\"]).float()\n        # LIM trick, else use both train and val set to construct this matrix.\n        self.label_idx_mat = dglsp.diag(graph.ndata[\"train_mask\"]).float()\n\n        self.test_mask = graph.ndata[\"test_mask\"]\n        self.tv_mask = graph.ndata[\"train_mask\"] + graph.ndata[\"val_mask\"]\n\n    def forward(self, x):\n        return self.linear_clf(x)\n\n    def update_embeds(self, embeds, lazy_adj, args):\n        \"\"\"Update classifier's weight by training a linear supervised model.\"\"\"\n        pred_label = self(embeds).data\n        clf_weight = self.linear_clf.W.weight.data\n\n        # Update the smoothness loss via LGC.\n        embeds = dglsp.spmm(lazy_adj, embeds)\n\n        # Update the supervised loss via SEB.\n        deriv_sup = 2 * dglsp.matmul(\n            dglsp.spmm(self.label_idx_mat, -self.label_one_hot + pred_label),\n            clf_weight,\n        )\n        embeds = embeds - args.lr_sup * deriv_sup\n\n        args.lr_sup = args.lr_sup * args.decline\n        return embeds\n"
  },
  {
    "path": "examples/pytorch/ogc/train.py",
    "content": "import argparse\nimport time\n\nimport dgl.sparse as dglsp\n\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom dgl import AddSelfLoop\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\n\nfrom ogc import OGC\nfrom utils import model_test, symmetric_normalize_adjacency\n\n\ndef train(model, embeds, lazy_adj, args):\n    patience = 0\n    _, _, last_acc, last_output = model_test(model, embeds)\n\n    tv_mask = model.tv_mask\n    optimizer = optim.SGD(model.parameters(), lr=args.lr_clf)\n\n    for i in range(64):\n        model.train()\n        output = model(embeds)\n        loss_tv = F.mse_loss(\n            output[tv_mask], model.label_one_hot[tv_mask], reduction=\"sum\"\n        )\n        optimizer.zero_grad()\n        loss_tv.backward()\n        optimizer.step()\n\n        # Updating node embeds by LGC and SEB jointly.\n        embeds = model.update_embeds(embeds, lazy_adj, args)\n\n        loss_tv, acc_tv, acc_test, pred = model_test(model, embeds)\n        print(\n            \"epoch {} loss_tv {:.4f} acc_tv {:.4f} acc_test {:.4f}\".format(\n                i + 1, loss_tv, acc_tv, acc_test\n            )\n        )\n\n        sim_rate = float(int((pred == last_output).sum()) / int(pred.shape[0]))\n        if sim_rate > args.max_sim_rate:\n            patience += 1\n            if patience > args.max_patience:\n                break\n        last_acc = acc_test\n        last_output = pred\n    return last_acc\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"citeseer\",\n        choices=[\"cora\", \"citeseer\", \"pubmed\"],\n        help=\"dataset to use\",\n    )\n    parser.add_argument(\n        \"--decline\", type=float, default=0.9, help=\"decline rate\"\n    )\n    parser.add_argument(\n        \"--lr_sup\",\n        type=float,\n        default=0.001,\n        help=\"learning rate 
for supervised loss\",\n    )\n    parser.add_argument(\n        \"--lr_clf\",\n        type=float,\n        default=0.5,\n        help=\"learning rate for the used linear classifier\",\n    )\n    parser.add_argument(\n        \"--beta\",\n        type=float,\n        default=0.1,\n        help=\"moving probability that a node moves to its neighbors\",\n    )\n    parser.add_argument(\n        \"--max_sim_rate\",\n        type=float,\n        default=0.995,\n        help=\"max label prediction similarity between iterations\",\n    )\n    parser.add_argument(\n        \"--max_patience\",\n        type=int,\n        default=2,\n        help=\"tolerance for consecutively similar test predictions\",\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=\"cpu\",\n        choices=[\"cpu\", \"cuda\"],\n        help=\"device to use\",\n    )\n    args, _ = parser.parse_known_args()\n\n    # Load and preprocess dataset.\n    transform = AddSelfLoop()\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset(transform=transform)\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset(transform=transform)\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset(transform=transform)\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n    graph = data[0].to(args.device)\n    features = graph.ndata[\"feat\"]\n    adj = symmetric_normalize_adjacency(graph)\n    I_N = dglsp.identity((features.shape[0], features.shape[0]))\n    # Lazy random walk (also known as lazy graph convolution).\n    lazy_adj = dglsp.add((1 - args.beta) * I_N, args.beta * adj).to(args.device)\n\n    model = OGC(graph).to(args.device)\n    start_time = time.time()\n    res = train(model, features, lazy_adj, args)\n    time_tot = time.time() - start_time\n\n    print(f\"Test Acc:{res:.4f}\")\n    print(f\"Total Time:{time_tot:.4f}\")\n"
  },
  {
    "path": "examples/pytorch/ogc/utils.py",
    "content": "import dgl.sparse as dglsp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass LinearNeuralNetwork(nn.Module):\n    def __init__(self, nfeat, nclass, bias=True):\n        super(LinearNeuralNetwork, self).__init__()\n        self.W = nn.Linear(nfeat, nclass, bias=bias)\n\n    def forward(self, x):\n        return self.W(x)\n\n\ndef symmetric_normalize_adjacency(graph):\n    \"\"\"Symmetric normalize graph adjacency matrix.\"\"\"\n    indices = torch.stack(graph.edges())\n    n = graph.num_nodes()\n    adj = dglsp.spmatrix(indices, shape=(n, n))\n    deg_invsqrt = dglsp.diag(adj.sum(0)) ** -0.5\n    return deg_invsqrt @ adj @ deg_invsqrt\n\n\ndef model_test(model, embeds):\n    model.eval()\n    with torch.no_grad():\n        output = model(embeds)\n        pred = output.argmax(dim=-1)\n        test_mask, tv_mask = model.test_mask, model.tv_mask\n        loss_tv = F.mse_loss(output[tv_mask], model.label_one_hot[tv_mask])\n    accs = []\n    for mask in [tv_mask, test_mask]:\n        accs.append(float((pred[mask] == model.label[mask]).sum() / mask.sum()))\n    return loss_tv.item(), accs[0], accs[1], pred\n"
  },
  {
    "path": "examples/pytorch/pagerank.py",
    "content": "import dgl\nimport dgl.function as fn\nimport networkx as nx\nimport torch\n\nN = 100\nnetwork = nx.erdos_renyi_graph(N, 0.05)\ng = dgl.from_networkx(network)\n\nDAMP = 0.85\nK = 10\n\n\ndef compute_pagerank(g):\n    g.ndata[\"pv\"] = torch.ones(N) / N\n    degrees = g.out_degrees(g.nodes()).type(torch.float32)\n    for k in range(K):\n        g.ndata[\"pv\"] = g.ndata[\"pv\"] / degrees\n        g.update_all(\n            message_func=fn.copy_u(u=\"pv\", out=\"m\"),\n            reduce_func=fn.sum(msg=\"m\", out=\"pv\"),\n        )\n        g.ndata[\"pv\"] = (1 - DAMP) / N + DAMP * g.ndata[\"pv\"]\n    return g.ndata[\"pv\"]\n\n\npv = compute_pagerank(g)\nprint(pv)\n"
  },
  {
    "path": "examples/pytorch/pinsage/README.md",
    "content": "# PinSAGE example\n\n## Requirements\n\n- dask\n- pandas\n- torchtext>=0.9.0\n\n## Prepare datasets\n\n### MovieLens 1M\n\n1. Download and extract the MovieLens-1M dataset from http://files.grouplens.org/datasets/movielens/ml-1m.zip\n   into the current directory.\n2. Run `python process_movielens1m.py ./ml-1m ./data_processed`.\n   Replace `ml-1m` with the directory you put the `.dat` files, and replace `data_processed` with\n   any path you wish to put the output files.\n\n### Nowplaying-rs\n\n1. Download and extract the Nowplaying-rs dataset from https://zenodo.org/record/3248543/files/nowplayingrs.zip?download=1\n   into the current directory.\n2. Run `python process_nowplaying_rs.py ./nowplaying_rs_dataset ./data_processed`\n\n## Run model\n\n### Nearest-neighbor recommendation\n\nThis model returns items that are K nearest neighbors of the latest item the user has\ninteracted.  The distance between two items are measured by Euclidean distance of\nitem embeddings, which are learned as outputs of PinSAGE.\n\n```\npython model.py data_processed --num-epochs 300 --num-workers 2 --device cuda:0 --hidden-dims 64\n```\n\nThe implementation here also assigns a learnable vector to each item.  If your hidden\nstate size is so large that the learnable vectors cannot fit into GPU, use this script\nfor sparse embedding update (written with `torch.optim.SparseAdam`) instead:\n\n\n```\npython model_sparse.py data_processed --num-epochs 300 --num-workers 2 --device cuda:0 --hidden-dims 1024\n```\n\nNote that since the embedding update is done on CPU, it will be significantly slower than doing\neverything on GPU.\n\nThe HITS@10 is 0.01241, compared to 0.01220 with SLIM with the same dimensionality.\\\n\n## Difference from the paper\n\nThe implementation here is different from what being described in the paper:\n\n1. The paper described a supervised setting where the authors have a ground truth set of which items are\n   relevant.  
However, in traditional recommender system datasets we don't have such labels other than\n   which items are interacted with by which users (as well as the user/item's own features).  Therefore, I\n   adapted PinSAGE to an unsupervised setting where I predict whether two items are cointeracted by the\n   same user.\n2. The PinSAGE paper explicitly states that the items do not have learnable node embeddings, but that it directly\n   expresses the embeddings as a function of node features.  While this is reasonable for rich datasets like\n   Pinterest's where images and texts are rich enough to distinguish the items from each other, it is\n   unfortunately not the case for traditional recommender system datasets like MovieLens or Nowplaying-RS\n   where we only have a bunch of categorical or numeric variables.  I found adding a learnable embedding\n   for each item still helpful for those datasets.\n3. The PinSAGE paper directly passes the GNN output to an MLP and makes the result the final item\n   representation.  Here, I'm adding the GNN output to the node's own learnable embedding as\n   the final item representation instead.\n"
  },
  {
    "path": "examples/pytorch/pinsage/builder.py",
    "content": "\"\"\"Graph builder from pandas dataframes\"\"\"\nfrom collections import namedtuple\n\nimport dgl\n\nfrom pandas.api.types import (\n    is_categorical,\n    is_categorical_dtype,\n    is_numeric_dtype,\n)\n\n__all__ = [\"PandasGraphBuilder\"]\n\n\ndef _series_to_tensor(series):\n    if is_categorical(series):\n        return torch.LongTensor(series.cat.codes.values.astype(\"int64\"))\n    else:  # numeric\n        return torch.FloatTensor(series.values)\n\n\nclass PandasGraphBuilder(object):\n    \"\"\"Creates a heterogeneous graph from multiple pandas dataframes.\n\n    Examples\n    --------\n    Let's say we have the following three pandas dataframes:\n\n    User table ``users``:\n\n    ===========  ===========  =======\n    ``user_id``  ``country``  ``age``\n    ===========  ===========  =======\n    XYZZY        U.S.         25\n    FOO          China        24\n    BAR          China        23\n    ===========  ===========  =======\n\n    Game table ``games``:\n\n    ===========  =========  ==============  ==================\n    ``game_id``  ``title``  ``is_sandbox``  ``is_multiplayer``\n    ===========  =========  ==============  ==================\n    1            Minecraft  True            True\n    2            Tetris 99  False           True\n    ===========  =========  ==============  ==================\n\n    Play relationship table ``plays``:\n\n    ===========  ===========  =========\n    ``user_id``  ``game_id``  ``hours``\n    ===========  ===========  =========\n    XYZZY        1            24\n    FOO          1            20\n    FOO          2            16\n    BAR          2            28\n    ===========  ===========  =========\n\n    One could then create a bidirectional bipartite graph as follows:\n    >>> builder = PandasGraphBuilder()\n    >>> builder.add_entities(users, 'user_id', 'user')\n    >>> builder.add_entities(games, 'game_id', 'game')\n    >>> builder.add_binary_relations(plays, 'user_id', 'game_id', 
'plays')\n    >>> builder.add_binary_relations(plays, 'game_id', 'user_id', 'played-by')\n    >>> g = builder.build()\n    >>> g.num_nodes('user')\n    3\n    >>> g.num_edges('plays')\n    4\n    \"\"\"\n\n    def __init__(self):\n        self.entity_tables = {}\n        self.relation_tables = {}\n\n        self.entity_pk_to_name = (\n            {}\n        )  # mapping from primary key name to entity name\n        self.entity_pk = {}  # mapping from entity name to primary key\n        self.entity_key_map = (\n            {}\n        )  # mapping from entity names to primary key values\n        self.num_nodes_per_type = {}\n        self.edges_per_relation = {}\n        self.relation_name_to_etype = {}\n        self.relation_src_key = {}  # mapping from relation name to source key\n        self.relation_dst_key = (\n            {}\n        )  # mapping from relation name to destination key\n\n    def add_entities(self, entity_table, primary_key, name):\n        entities = entity_table[primary_key].astype(\"category\")\n        if not (entities.value_counts() == 1).all():\n            raise ValueError(\n                \"Different entity with the same primary key detected.\"\n            )\n        # preserve the category order in the original entity table\n        entities = entities.cat.reorder_categories(\n            entity_table[primary_key].values\n        )\n\n        self.entity_pk_to_name[primary_key] = name\n        self.entity_pk[name] = primary_key\n        self.num_nodes_per_type[name] = entity_table.shape[0]\n        self.entity_key_map[name] = entities\n        self.entity_tables[name] = entity_table\n\n    def add_binary_relations(\n        self, relation_table, source_key, destination_key, name\n    ):\n        src = relation_table[source_key].astype(\"category\")\n        src = src.cat.set_categories(\n            self.entity_key_map[\n                self.entity_pk_to_name[source_key]\n            ].cat.categories\n        )\n        dst = 
relation_table[destination_key].astype(\"category\")\n        dst = dst.cat.set_categories(\n            self.entity_key_map[\n                self.entity_pk_to_name[destination_key]\n            ].cat.categories\n        )\n        if src.isnull().any():\n            raise ValueError(\n                \"Some source entities in relation %s do not exist in entity %s.\"\n                % (name, source_key)\n            )\n        if dst.isnull().any():\n            raise ValueError(\n                \"Some destination entities in relation %s do not exist in entity %s.\"\n                % (name, destination_key)\n            )\n\n        srctype = self.entity_pk_to_name[source_key]\n        dsttype = self.entity_pk_to_name[destination_key]\n        etype = (srctype, name, dsttype)\n        self.relation_name_to_etype[name] = etype\n        self.edges_per_relation[etype] = (\n            src.cat.codes.values.astype(\"int64\"),\n            dst.cat.codes.values.astype(\"int64\"),\n        )\n        self.relation_tables[name] = relation_table\n        self.relation_src_key[name] = source_key\n        self.relation_dst_key[name] = destination_key\n\n    def build(self):\n        # Create heterograph\n        graph = dgl.heterograph(\n            self.edges_per_relation, self.num_nodes_per_type\n        )\n        return graph\n"
  },
  {
    "path": "examples/pytorch/pinsage/data_utils.py",
    "content": "import dask.dataframe as dd\n\nimport dgl\nimport numpy as np\nimport scipy.sparse as ssp\nimport torch\nimport tqdm\n\n\n# This is the train-test split method most of the recommender system papers running on MovieLens\n# takes.  It essentially follows the intuition of \"training on the past and predict the future\".\n# One can also change the threshold to make validation and test set take larger proportions.\ndef train_test_split_by_time(df, timestamp, user):\n    df[\"train_mask\"] = np.ones((len(df),), dtype=np.bool_)\n    df[\"val_mask\"] = np.zeros((len(df),), dtype=np.bool_)\n    df[\"test_mask\"] = np.zeros((len(df),), dtype=np.bool_)\n    df = dd.from_pandas(df, npartitions=10)\n\n    def train_test_split(df):\n        df = df.sort_values([timestamp])\n        if df.shape[0] > 1:\n            df.iloc[-1, -3] = False\n            df.iloc[-1, -1] = True\n        if df.shape[0] > 2:\n            df.iloc[-2, -3] = False\n            df.iloc[-2, -2] = True\n        return df\n\n    meta_df = {\n        \"user_id\": np.int64,\n        \"movie_id\": np.int64,\n        \"rating\": np.int64,\n        \"timestamp\": np.int64,\n        \"user_id\": np.int64,\n        \"train_mask\": bool,\n        \"val_mask\": bool,\n        \"test_mask\": bool,\n    }\n\n    df = (\n        df.groupby(user, group_keys=False)\n        .apply(train_test_split, meta=meta_df)\n        .compute(scheduler=\"processes\")\n        .sort_index()\n    )\n    print(df[df[user] == df[user].unique()[0]].sort_values(timestamp))\n    return (\n        df[\"train_mask\"].to_numpy().nonzero()[0],\n        df[\"val_mask\"].to_numpy().nonzero()[0],\n        df[\"test_mask\"].to_numpy().nonzero()[0],\n    )\n\n\ndef build_train_graph(g, train_indices, utype, itype, etype, etype_rev):\n    train_g = g.edge_subgraph(\n        {etype: train_indices, etype_rev: train_indices}, relabel_nodes=False\n    )\n\n    # copy features\n    for ntype in g.ntypes:\n        for col, data in 
g.nodes[ntype].data.items():\n            train_g.nodes[ntype].data[col] = data\n    for etype in g.etypes:\n        for col, data in g.edges[etype].data.items():\n            train_g.edges[etype].data[col] = data[\n                train_g.edges[etype].data[dgl.EID]\n            ]\n\n    return train_g\n\n\ndef build_val_test_matrix(g, val_indices, test_indices, utype, itype, etype):\n    n_users = g.num_nodes(utype)\n    n_items = g.num_nodes(itype)\n    val_src, val_dst = g.find_edges(val_indices, etype=etype)\n    test_src, test_dst = g.find_edges(test_indices, etype=etype)\n    val_src = val_src.numpy()\n    val_dst = val_dst.numpy()\n    test_src = test_src.numpy()\n    test_dst = test_dst.numpy()\n    val_matrix = ssp.coo_matrix(\n        (np.ones_like(val_src), (val_src, val_dst)), (n_users, n_items)\n    )\n    test_matrix = ssp.coo_matrix(\n        (np.ones_like(test_src), (test_src, test_dst)), (n_users, n_items)\n    )\n\n    return val_matrix, test_matrix\n\n\ndef linear_normalize(values):\n    return (values - values.min(0, keepdims=True)) / (\n        values.max(0, keepdims=True) - values.min(0, keepdims=True)\n    )\n"
  },
  {
    "path": "examples/pytorch/pinsage/evaluation.py",
    "content": "import argparse\nimport pickle\n\nimport dgl\n\nimport numpy as np\nimport torch\n\n\ndef prec(recommendations, ground_truth):\n    n_users, n_items = ground_truth.shape\n    K = recommendations.shape[1]\n    user_idx = np.repeat(np.arange(n_users), K)\n    item_idx = recommendations.flatten()\n    relevance = ground_truth[user_idx, item_idx].reshape((n_users, K))\n    hit = relevance.any(axis=1).mean()\n    return hit\n\n\nclass LatestNNRecommender(object):\n    def __init__(\n        self, user_ntype, item_ntype, user_to_item_etype, timestamp, batch_size\n    ):\n        self.user_ntype = user_ntype\n        self.item_ntype = item_ntype\n        self.user_to_item_etype = user_to_item_etype\n        self.batch_size = batch_size\n        self.timestamp = timestamp\n\n    def recommend(self, full_graph, K, h_user, h_item):\n        \"\"\"\n        Return a (n_user, K) matrix of recommended items for each user\n        \"\"\"\n        graph_slice = full_graph.edge_type_subgraph([self.user_to_item_etype])\n        n_users = full_graph.num_nodes(self.user_ntype)\n        latest_interactions = dgl.sampling.select_topk(\n            graph_slice, 1, self.timestamp, edge_dir=\"out\"\n        )\n        user, latest_items = latest_interactions.all_edges(\n            form=\"uv\", order=\"srcdst\"\n        )\n        # each user should have at least one \"latest\" interaction\n        assert torch.equal(user, torch.arange(n_users))\n\n        recommended_batches = []\n        user_batches = torch.arange(n_users).split(self.batch_size)\n        for user_batch in user_batches:\n            latest_item_batch = latest_items[user_batch].to(\n                device=h_item.device\n            )\n            dist = h_item[latest_item_batch] @ h_item.t()\n            # exclude items that are already interacted\n            for i, u in enumerate(user_batch.tolist()):\n                interacted_items = full_graph.successors(\n                    u, 
etype=self.user_to_item_etype\n                )\n                dist[i, interacted_items] = -np.inf\n            recommended_batches.append(dist.topk(K, 1)[1])\n\n        recommendations = torch.cat(recommended_batches, 0)\n        return recommendations\n\n\ndef evaluate_nn(dataset, h_item, k, batch_size):\n    g = dataset[\"train-graph\"]\n    val_matrix = dataset[\"val-matrix\"].tocsr()\n    test_matrix = dataset[\"test-matrix\"].tocsr()\n    item_texts = dataset[\"item-texts\"]\n    user_ntype = dataset[\"user-type\"]\n    item_ntype = dataset[\"item-type\"]\n    user_to_item_etype = dataset[\"user-to-item-type\"]\n    timestamp = dataset[\"timestamp-edge-column\"]\n\n    rec_engine = LatestNNRecommender(\n        user_ntype, item_ntype, user_to_item_etype, timestamp, batch_size\n    )\n\n    recommendations = rec_engine.recommend(g, k, None, h_item).cpu().numpy()\n    return prec(recommendations, val_matrix)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"dataset_path\", type=str)\n    parser.add_argument(\"item_embedding_path\", type=str)\n    parser.add_argument(\"-k\", type=int, default=10)\n    parser.add_argument(\"--batch-size\", type=int, default=32)\n    args = parser.parse_args()\n\n    with open(args.dataset_path, \"rb\") as f:\n        dataset = pickle.load(f)\n    with open(args.item_embedding_path, \"rb\") as f:\n        emb = torch.FloatTensor(pickle.load(f))\n    print(evaluate_nn(dataset, emb, args.k, args.batch_size))\n"
  },
  {
    "path": "examples/pytorch/pinsage/layers.py",
    "content": "import dgl\nimport dgl.function as fn\nimport dgl.nn.pytorch as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef disable_grad(module):\n    for param in module.parameters():\n        param.requires_grad = False\n\n\ndef _init_input_modules(g, ntype, textset, hidden_dims):\n    # We initialize the linear projections of each input feature ``x`` as\n    # follows:\n    # * If ``x`` is a scalar integral feature, we assume that ``x`` is a categorical\n    #   feature, and assume the range of ``x`` is 0..max(x).\n    # * If ``x`` is a float one-dimensional feature, we assume that ``x`` is a\n    #   numeric vector.\n    # * If ``x`` is a field of a textset, we process it as bag of words.\n    module_dict = nn.ModuleDict()\n\n    for column, data in g.nodes[ntype].data.items():\n        if column == dgl.NID:\n            continue\n        if data.dtype == torch.float32:\n            assert data.ndim == 2\n            m = nn.Linear(data.shape[1], hidden_dims)\n            nn.init.xavier_uniform_(m.weight)\n            nn.init.constant_(m.bias, 0)\n            module_dict[column] = m\n        elif data.dtype == torch.int64:\n            assert data.ndim == 1\n            m = nn.Embedding(data.max() + 2, hidden_dims, padding_idx=-1)\n            nn.init.xavier_uniform_(m.weight)\n            module_dict[column] = m\n\n    if textset is not None:\n        for column, field in textset.items():\n            textlist, vocab, pad_var, batch_first = field\n            module_dict[column] = BagOfWords(vocab, hidden_dims)\n\n    return module_dict\n\n\nclass BagOfWords(nn.Module):\n    def __init__(self, vocab, hidden_dims):\n        super().__init__()\n\n        self.emb = nn.Embedding(\n            len(vocab.get_itos()),\n            hidden_dims,\n            padding_idx=vocab.get_stoi()[\"<pad>\"],\n        )\n        nn.init.xavier_uniform_(self.emb.weight)\n\n    def forward(self, x, length):\n        return self.emb(x).sum(1) 
/ length.unsqueeze(1).float()\n\n\nclass LinearProjector(nn.Module):\n    \"\"\"\n    Projects each input feature of the graph linearly and sums them up\n    \"\"\"\n\n    def __init__(self, full_graph, ntype, textset, hidden_dims):\n        super().__init__()\n\n        self.ntype = ntype\n        self.inputs = _init_input_modules(\n            full_graph, ntype, textset, hidden_dims\n        )\n\n    def forward(self, ndata):\n        projections = []\n        for feature, data in ndata.items():\n            if feature == dgl.NID or feature.endswith(\"__len\"):\n                # This is an additional feature indicating the length of the ``feature``\n                # column; we shouldn't process this.\n                continue\n\n            module = self.inputs[feature]\n            if isinstance(module, BagOfWords):\n                # Textual feature; find the length and pass it to the textual module.\n                length = ndata[feature + \"__len\"]\n                result = module(data, length)\n            else:\n                result = module(data)\n            projections.append(result)\n\n        return torch.stack(projections, 1).sum(1)\n\n\nclass WeightedSAGEConv(nn.Module):\n    def __init__(self, input_dims, hidden_dims, output_dims, act=F.relu):\n        super().__init__()\n\n        self.act = act\n        self.Q = nn.Linear(input_dims, hidden_dims)\n        self.W = nn.Linear(input_dims + hidden_dims, output_dims)\n        self.reset_parameters()\n        self.dropout = nn.Dropout(0.5)\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_uniform_(self.Q.weight, gain=gain)\n        nn.init.xavier_uniform_(self.W.weight, gain=gain)\n        nn.init.constant_(self.Q.bias, 0)\n        nn.init.constant_(self.W.bias, 0)\n\n    def forward(self, g, h, weights):\n        \"\"\"\n        g : graph\n        h : node features\n        weights : scalar edge weights\n        \"\"\"\n        h_src, 
h_dst = h\n        with g.local_scope():\n            g.srcdata[\"n\"] = self.act(self.Q(self.dropout(h_src)))\n            g.edata[\"w\"] = weights.float()\n            g.update_all(fn.u_mul_e(\"n\", \"w\", \"m\"), fn.sum(\"m\", \"n\"))\n            g.update_all(fn.copy_e(\"w\", \"m\"), fn.sum(\"m\", \"ws\"))\n            n = g.dstdata[\"n\"]\n            ws = g.dstdata[\"ws\"].unsqueeze(1).clamp(min=1)\n            z = self.act(self.W(self.dropout(torch.cat([n / ws, h_dst], 1))))\n            z_norm = z.norm(2, 1, keepdim=True)\n            z_norm = torch.where(\n                z_norm == 0, torch.tensor(1.0).to(z_norm), z_norm\n            )\n            z = z / z_norm\n            return z\n\n\nclass SAGENet(nn.Module):\n    def __init__(self, hidden_dims, n_layers):\n        \"\"\"\n        g : DGLGraph\n            The user-item interaction graph.\n            This is only for finding the range of categorical variables.\n        item_textsets : torchtext.data.Dataset\n            The textual features of each item node.\n        \"\"\"\n        super().__init__()\n\n        self.convs = nn.ModuleList()\n        for _ in range(n_layers):\n            self.convs.append(\n                WeightedSAGEConv(hidden_dims, hidden_dims, hidden_dims)\n            )\n\n    def forward(self, blocks, h):\n        for layer, block in zip(self.convs, blocks):\n            h_dst = h[: block.num_nodes(\"DST/\" + block.ntypes[0])]\n            h = layer(block, (h, h_dst), block.edata[\"weights\"])\n        return h\n\n\nclass ItemToItemScorer(nn.Module):\n    def __init__(self, full_graph, ntype):\n        super().__init__()\n\n        n_nodes = full_graph.num_nodes(ntype)\n        self.bias = nn.Parameter(torch.zeros(n_nodes, 1))\n\n    def _add_bias(self, edges):\n        bias_src = self.bias[edges.src[dgl.NID]]\n        bias_dst = self.bias[edges.dst[dgl.NID]]\n        return {\"s\": edges.data[\"s\"] + bias_src + bias_dst}\n\n    def forward(self, item_item_graph, h):\n      
  \"\"\"\n        item_item_graph : graph consists of edges connecting the pairs\n        h : hidden state of every node\n        \"\"\"\n        with item_item_graph.local_scope():\n            item_item_graph.ndata[\"h\"] = h\n            item_item_graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"s\"))\n            item_item_graph.apply_edges(self._add_bias)\n            pair_score = item_item_graph.edata[\"s\"]\n        return pair_score\n"
  },
  {
    "path": "examples/pytorch/pinsage/model.py",
    "content": "import argparse\nimport os\nimport pickle\n\nimport dgl\n\nimport evaluation\nimport layers\nimport numpy as np\nimport sampler as sampler_module\nimport torch\nimport torch.nn as nn\nimport torchtext\nimport tqdm\nfrom torch.utils.data import DataLoader\nfrom torchtext.data.utils import get_tokenizer\nfrom torchtext.vocab import build_vocab_from_iterator\n\n\nclass PinSAGEModel(nn.Module):\n    def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):\n        super().__init__()\n\n        self.proj = layers.LinearProjector(\n            full_graph, ntype, textsets, hidden_dims\n        )\n        self.sage = layers.SAGENet(hidden_dims, n_layers)\n        self.scorer = layers.ItemToItemScorer(full_graph, ntype)\n\n    def forward(self, pos_graph, neg_graph, blocks):\n        h_item = self.get_repr(blocks)\n        pos_score = self.scorer(pos_graph, h_item)\n        neg_score = self.scorer(neg_graph, h_item)\n        return (neg_score - pos_score + 1).clamp(min=0)\n\n    def get_repr(self, blocks):\n        h_item = self.proj(blocks[0].srcdata)\n        h_item_dst = self.proj(blocks[-1].dstdata)\n        return h_item_dst + self.sage(blocks, h_item)\n\n\ndef train(dataset, args):\n    g = dataset[\"train-graph\"]\n    val_matrix = dataset[\"val-matrix\"].tocsr()\n    test_matrix = dataset[\"test-matrix\"].tocsr()\n    item_texts = dataset[\"item-texts\"]\n    user_ntype = dataset[\"user-type\"]\n    item_ntype = dataset[\"item-type\"]\n    user_to_item_etype = dataset[\"user-to-item-type\"]\n    timestamp = dataset[\"timestamp-edge-column\"]\n\n    device = torch.device(args.device)\n\n    # Assign user and movie IDs and use them as features (to learn an individual trainable\n    # embedding for each entity)\n    g.nodes[user_ntype].data[\"id\"] = torch.arange(g.num_nodes(user_ntype))\n    g.nodes[item_ntype].data[\"id\"] = torch.arange(g.num_nodes(item_ntype))\n\n    # Prepare torchtext dataset and Vocabulary\n    textset = {}\n    
tokenizer = get_tokenizer(None)\n\n    textlist = []\n    batch_first = True\n\n    for i in range(g.num_nodes(item_ntype)):\n        for key in item_texts.keys():\n            l = tokenizer(item_texts[key][i].lower())\n            textlist.append(l)\n    for key, field in item_texts.items():\n        vocab2 = build_vocab_from_iterator(\n            textlist, specials=[\"<unk>\", \"<pad>\"]\n        )\n        textset[key] = (\n            textlist,\n            vocab2,\n            vocab2.get_stoi()[\"<pad>\"],\n            batch_first,\n        )\n\n    # Sampler\n    batch_sampler = sampler_module.ItemToItemBatchSampler(\n        g, user_ntype, item_ntype, args.batch_size\n    )\n    neighbor_sampler = sampler_module.NeighborSampler(\n        g,\n        user_ntype,\n        item_ntype,\n        args.random_walk_length,\n        args.random_walk_restart_prob,\n        args.num_random_walks,\n        args.num_neighbors,\n        args.num_layers,\n    )\n    collator = sampler_module.PinSAGECollator(\n        neighbor_sampler, g, item_ntype, textset\n    )\n    dataloader = DataLoader(\n        batch_sampler,\n        collate_fn=collator.collate_train,\n        num_workers=args.num_workers,\n    )\n    dataloader_test = DataLoader(\n        torch.arange(g.num_nodes(item_ntype)),\n        batch_size=args.batch_size,\n        collate_fn=collator.collate_test,\n        num_workers=args.num_workers,\n    )\n    dataloader_it = iter(dataloader)\n\n    # Model\n    model = PinSAGEModel(\n        g, item_ntype, textset, args.hidden_dims, args.num_layers\n    ).to(device)\n    # Optimizer\n    opt = torch.optim.Adam(model.parameters(), lr=args.lr)\n\n    # For each batch of head-tail-negative triplets...\n    for epoch_id in range(args.num_epochs):\n        model.train()\n        for batch_id in tqdm.trange(args.batches_per_epoch):\n            pos_graph, neg_graph, blocks = next(dataloader_it)\n            # Copy to GPU\n            for i in range(len(blocks)):\n         
       blocks[i] = blocks[i].to(device)\n            pos_graph = pos_graph.to(device)\n            neg_graph = neg_graph.to(device)\n\n            loss = model(pos_graph, neg_graph, blocks).mean()\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n\n        # Evaluate\n        model.eval()\n        with torch.no_grad():\n            item_batches = torch.arange(g.num_nodes(item_ntype)).split(\n                args.batch_size\n            )\n            h_item_batches = []\n            for blocks in dataloader_test:\n                for i in range(len(blocks)):\n                    blocks[i] = blocks[i].to(device)\n\n                h_item_batches.append(model.get_repr(blocks))\n            h_item = torch.cat(h_item_batches, 0)\n\n            print(\n                evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size)\n            )\n\n\nif __name__ == \"__main__\":\n    # Arguments\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"dataset_path\", type=str)\n    parser.add_argument(\"--random-walk-length\", type=int, default=2)\n    parser.add_argument(\"--random-walk-restart-prob\", type=float, default=0.5)\n    parser.add_argument(\"--num-random-walks\", type=int, default=10)\n    parser.add_argument(\"--num-neighbors\", type=int, default=3)\n    parser.add_argument(\"--num-layers\", type=int, default=2)\n    parser.add_argument(\"--hidden-dims\", type=int, default=16)\n    parser.add_argument(\"--batch-size\", type=int, default=32)\n    parser.add_argument(\n        \"--device\", type=str, default=\"cpu\"\n    )  # can also be \"cuda:0\"\n    parser.add_argument(\"--num-epochs\", type=int, default=1)\n    parser.add_argument(\"--batches-per-epoch\", type=int, default=20000)\n    parser.add_argument(\"--num-workers\", type=int, default=0)\n    parser.add_argument(\"--lr\", type=float, default=3e-5)\n    parser.add_argument(\"-k\", type=int, default=10)\n    args = parser.parse_args()\n\n    # Load 
dataset\n    data_info_path = os.path.join(args.dataset_path, \"data.pkl\")\n    with open(data_info_path, \"rb\") as f:\n        dataset = pickle.load(f)\n    train_g_path = os.path.join(args.dataset_path, \"train_g.bin\")\n    g_list, _ = dgl.load_graphs(train_g_path)\n    dataset[\"train-graph\"] = g_list[0]\n    train(dataset, args)\n"
  },
  {
    "path": "examples/pytorch/pinsage/model_sparse.py",
    "content": "import argparse\nimport os\nimport pickle\n\nimport dgl\n\nimport evaluation\nimport layers\nimport numpy as np\nimport sampler as sampler_module\nimport torch\nimport torch.nn as nn\nimport torchtext\nimport tqdm\nfrom torch.utils.data import DataLoader\nfrom torchtext.data.utils import get_tokenizer\nfrom torchtext.vocab import build_vocab_from_iterator\n\n\nclass PinSAGEModel(nn.Module):\n    def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):\n        super().__init__()\n\n        self.proj = layers.LinearProjector(\n            full_graph, ntype, textsets, hidden_dims\n        )\n        self.sage = layers.SAGENet(hidden_dims, n_layers)\n        self.scorer = layers.ItemToItemScorer(full_graph, ntype)\n\n    def forward(self, pos_graph, neg_graph, blocks, item_emb):\n        h_item = self.get_repr(blocks, item_emb)\n        pos_score = self.scorer(pos_graph, h_item)\n        neg_score = self.scorer(neg_graph, h_item)\n        return (neg_score - pos_score + 1).clamp(min=0)\n\n    def get_repr(self, blocks, item_emb):\n        # project features\n        h_item = self.proj(blocks[0].srcdata)\n        h_item_dst = self.proj(blocks[-1].dstdata)\n\n        # add to the item embedding itself\n        h_item = h_item + item_emb(blocks[0].srcdata[dgl.NID].cpu()).to(h_item)\n        h_item_dst = h_item_dst + item_emb(\n            blocks[-1].dstdata[dgl.NID].cpu()\n        ).to(h_item_dst)\n\n        return h_item_dst + self.sage(blocks, h_item)\n\n\ndef train(dataset, args):\n    g = dataset[\"train-graph\"]\n    val_matrix = dataset[\"val-matrix\"].tocsr()\n    test_matrix = dataset[\"test-matrix\"].tocsr()\n    item_texts = dataset[\"item-texts\"]\n    user_ntype = dataset[\"user-type\"]\n    item_ntype = dataset[\"item-type\"]\n    user_to_item_etype = dataset[\"user-to-item-type\"]\n    timestamp = dataset[\"timestamp-edge-column\"]\n\n    device = torch.device(args.device)\n\n    # Prepare torchtext dataset and vocabulary\n    
textset = {}\n    tokenizer = get_tokenizer(None)\n\n    textlist = []\n    batch_first = True\n\n    for i in range(g.num_nodes(item_ntype)):\n        for key in item_texts.keys():\n            l = tokenizer(item_texts[key][i].lower())\n            textlist.append(l)\n    for key, field in item_texts.items():\n        vocab2 = build_vocab_from_iterator(\n            textlist, specials=[\"<unk>\", \"<pad>\"]\n        )\n        textset[key] = (\n            textlist,\n            vocab2,\n            vocab2.get_stoi()[\"<pad>\"],\n            batch_first,\n        )\n\n    # Sampler\n    batch_sampler = sampler_module.ItemToItemBatchSampler(\n        g, user_ntype, item_ntype, args.batch_size\n    )\n    neighbor_sampler = sampler_module.NeighborSampler(\n        g,\n        user_ntype,\n        item_ntype,\n        args.random_walk_length,\n        args.random_walk_restart_prob,\n        args.num_random_walks,\n        args.num_neighbors,\n        args.num_layers,\n    )\n    collator = sampler_module.PinSAGECollator(\n        neighbor_sampler, g, item_ntype, textset\n    )\n    dataloader = DataLoader(\n        batch_sampler,\n        collate_fn=collator.collate_train,\n        num_workers=args.num_workers,\n    )\n    dataloader_test = DataLoader(\n        torch.arange(g.num_nodes(item_ntype)),\n        batch_size=args.batch_size,\n        collate_fn=collator.collate_test,\n        num_workers=args.num_workers,\n    )\n    dataloader_it = iter(dataloader)\n\n    # Model\n    model = PinSAGEModel(\n        g, item_ntype, textset, args.hidden_dims, args.num_layers\n    ).to(device)\n    item_emb = nn.Embedding(\n        g.num_nodes(item_ntype), args.hidden_dims, sparse=True\n    )\n    # Optimizer\n    opt = torch.optim.Adam(model.parameters(), lr=args.lr)\n    opt_emb = torch.optim.SparseAdam(item_emb.parameters(), lr=args.lr)\n\n    # For each batch of head-tail-negative triplets...\n    for epoch_id in range(args.num_epochs):\n        model.train()\n        for 
batch_id in tqdm.trange(args.batches_per_epoch):\n            pos_graph, neg_graph, blocks = next(dataloader_it)\n            # Copy to GPU\n            for i in range(len(blocks)):\n                blocks[i] = blocks[i].to(device)\n            pos_graph = pos_graph.to(device)\n            neg_graph = neg_graph.to(device)\n\n            loss = model(pos_graph, neg_graph, blocks, item_emb).mean()\n            opt.zero_grad()\n            opt_emb.zero_grad()\n            loss.backward()\n            opt.step()\n            opt_emb.step()\n\n        # Evaluate\n        model.eval()\n        with torch.no_grad():\n            item_batches = torch.arange(g.num_nodes(item_ntype)).split(\n                args.batch_size\n            )\n            h_item_batches = []\n            for blocks in tqdm.tqdm(dataloader_test):\n                for i in range(len(blocks)):\n                    blocks[i] = blocks[i].to(device)\n\n                h_item_batches.append(model.get_repr(blocks, item_emb))\n            h_item = torch.cat(h_item_batches, 0)\n\n            print(\n                evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size)\n            )\n\n\nif __name__ == \"__main__\":\n    # Arguments\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"dataset_path\", type=str)\n    parser.add_argument(\"--random-walk-length\", type=int, default=2)\n    parser.add_argument(\"--random-walk-restart-prob\", type=float, default=0.5)\n    parser.add_argument(\"--num-random-walks\", type=int, default=10)\n    parser.add_argument(\"--num-neighbors\", type=int, default=3)\n    parser.add_argument(\"--num-layers\", type=int, default=2)\n    parser.add_argument(\"--hidden-dims\", type=int, default=16)\n    parser.add_argument(\"--batch-size\", type=int, default=32)\n    parser.add_argument(\n        \"--device\", type=str, default=\"cpu\"\n    )  # can also be \"cuda:0\"\n    parser.add_argument(\"--num-epochs\", type=int, default=1)\n    
parser.add_argument(\"--batches-per-epoch\", type=int, default=20000)\n    parser.add_argument(\"--num-workers\", type=int, default=0)\n    parser.add_argument(\"--lr\", type=float, default=3e-5)\n    parser.add_argument(\"-k\", type=int, default=10)\n    args = parser.parse_args()\n\n    # Load dataset\n    data_info_path = os.path.join(args.dataset_path, \"data.pkl\")\n    with open(data_info_path, \"rb\") as f:\n        dataset = pickle.load(f)\n    train_g_path = os.path.join(args.dataset_path, \"train_g.bin\")\n    g_list, _ = dgl.load_graphs(train_g_path)\n    dataset[\"train-graph\"] = g_list[0]\n    train(dataset, args)\n"
  },
  {
    "path": "examples/pytorch/pinsage/process_movielens1m.py",
    "content": "\"\"\"\nScript that reads from raw MovieLens-1M data and dumps into a pickle\nfile the following:\n\n* A heterogeneous graph with categorical features.\n* A list with all the movie titles.  The movie titles correspond to\n  the movie nodes in the heterogeneous graph.\n\nThis script exemplifies how to prepare tabular data with textual\nfeatures.  Since DGL graphs do not store variable-length features, we\ninstead put variable-length features into a more suitable container\n(e.g. torchtext to handle list of texts)\n\"\"\"\n\nimport argparse\nimport os\nimport pickle\nimport re\n\nimport numpy as np\nimport pandas as pd\nimport scipy.sparse as ssp\nimport torch\nimport torchtext\nfrom builder import PandasGraphBuilder\nfrom data_utils import *\n\nimport dgl\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"directory\", type=str)\n    parser.add_argument(\"out_directory\", type=str)\n    args = parser.parse_args()\n    directory = args.directory\n    out_directory = args.out_directory\n    os.makedirs(out_directory, exist_ok=True)\n\n    ## Build heterogeneous graph\n\n    # Load data\n    users = []\n    with open(os.path.join(directory, \"users.dat\"), encoding=\"latin1\") as f:\n        for l in f:\n            id_, gender, age, occupation, zip_ = l.strip().split(\"::\")\n            users.append(\n                {\n                    \"user_id\": int(id_),\n                    \"gender\": gender,\n                    \"age\": age,\n                    \"occupation\": occupation,\n                    \"zip\": zip_,\n                }\n            )\n    users = pd.DataFrame(users).astype(\"category\")\n\n    movies = []\n    with open(os.path.join(directory, \"movies.dat\"), encoding=\"latin1\") as f:\n        for l in f:\n            id_, title, genres = l.strip().split(\"::\")\n            genres_set = set(genres.split(\"|\"))\n\n            # extract year\n            assert 
re.match(r\".*\\([0-9]{4}\\)$\", title)\n            year = title[-5:-1]\n            title = title[:-6].strip()\n\n            data = {\"movie_id\": int(id_), \"title\": title, \"year\": year}\n            for g in genres_set:\n                data[g] = True\n            movies.append(data)\n    movies = pd.DataFrame(movies).astype({\"year\": \"category\"})\n\n    ratings = []\n    with open(os.path.join(directory, \"ratings.dat\"), encoding=\"latin1\") as f:\n        for l in f:\n            user_id, movie_id, rating, timestamp = [\n                int(_) for _ in l.split(\"::\")\n            ]\n            ratings.append(\n                {\n                    \"user_id\": user_id,\n                    \"movie_id\": movie_id,\n                    \"rating\": rating,\n                    \"timestamp\": timestamp,\n                }\n            )\n    ratings = pd.DataFrame(ratings)\n\n    # Filter the users and items that never appear in the rating table.\n    distinct_users_in_ratings = ratings[\"user_id\"].unique()\n    distinct_movies_in_ratings = ratings[\"movie_id\"].unique()\n    users = users[users[\"user_id\"].isin(distinct_users_in_ratings)]\n    movies = movies[movies[\"movie_id\"].isin(distinct_movies_in_ratings)]\n\n    # Group the movie features into genres (a vector), year (a category), title (a string)\n    genre_columns = movies.columns.drop([\"movie_id\", \"title\", \"year\"])\n    movies[genre_columns] = movies[genre_columns].fillna(False).astype(\"bool\")\n    movies_categorical = movies.drop(\"title\", axis=1)\n\n    # Build graph\n    graph_builder = PandasGraphBuilder()\n    graph_builder.add_entities(users, \"user_id\", \"user\")\n    graph_builder.add_entities(movies_categorical, \"movie_id\", \"movie\")\n    graph_builder.add_binary_relations(\n        ratings, \"user_id\", \"movie_id\", \"watched\"\n    )\n    graph_builder.add_binary_relations(\n        ratings, \"movie_id\", \"user_id\", \"watched-by\"\n    )\n\n    g = 
graph_builder.build()\n\n    # Assign features.\n    # Note that variable-sized features such as texts or images are handled elsewhere.\n    for data_type in [\"gender\", \"age\", \"occupation\", \"zip\"]:\n        g.nodes[\"user\"].data[data_type] = torch.LongTensor(\n            np.array(users[data_type].cat.codes.values)\n        )\n\n    g.nodes[\"movie\"].data[\"year\"] = torch.LongTensor(\n        np.array(movies[\"year\"].cat.codes.values)\n    )\n    g.nodes[\"movie\"].data[\"genre\"] = torch.FloatTensor(\n        np.array(movies[genre_columns].values)\n    )\n\n    for edge_type in [\"watched\", \"watched-by\"]:\n        for data_type in [\"rating\", \"timestamp\"]:\n            g.edges[edge_type].data[data_type] = torch.LongTensor(\n                np.array(ratings[data_type].values)\n            )\n\n    # Train-validation-test split\n    # This is a little bit tricky as we want to select the last interaction for test, and the\n    # second-to-last interaction for validation.\n    train_indices, val_indices, test_indices = train_test_split_by_time(\n        ratings, \"timestamp\", \"user_id\"\n    )\n\n    # Build the graph with training interactions only.\n    train_g = build_train_graph(\n        g, train_indices, \"user\", \"movie\", \"watched\", \"watched-by\"\n    )\n    assert train_g.out_degrees(etype=\"watched\").min() > 0\n\n    # Build the user-item sparse matrix for validation and test set.\n    val_matrix, test_matrix = build_val_test_matrix(\n        g, val_indices, test_indices, \"user\", \"movie\", \"watched\"\n    )\n\n    ## Build title set\n\n    movie_textual_dataset = {\"title\": movies[\"title\"].values}\n\n    # The model should build their own vocabulary and process the texts.  
Here is one example\n    # of using torchtext to pad and numericalize a batch of strings.\n    #     field = torchtext.data.Field(include_lengths=True, lower=True, batch_first=True)\n    #     examples = [torchtext.data.Example.fromlist([t], [('title', title_field)]) for t in texts]\n    #     titleset = torchtext.data.Dataset(examples, [('title', title_field)])\n    #     field.build_vocab(titleset.title, vectors='fasttext.simple.300d')\n    #     token_ids, lengths = field.process([examples[0].title, examples[1].title])\n\n    ## Dump the graph and the datasets\n\n    dgl.save_graphs(os.path.join(out_directory, \"train_g.bin\"), train_g)\n\n    dataset = {\n        \"val-matrix\": val_matrix,\n        \"test-matrix\": test_matrix,\n        \"item-texts\": movie_textual_dataset,\n        \"item-images\": None,\n        \"user-type\": \"user\",\n        \"item-type\": \"movie\",\n        \"user-to-item-type\": \"watched\",\n        \"item-to-user-type\": \"watched-by\",\n        \"timestamp-edge-column\": \"timestamp\",\n    }\n\n    with open(os.path.join(out_directory, \"data.pkl\"), \"wb\") as f:\n        pickle.dump(dataset, f)\n"
  },
  {
    "path": "examples/pytorch/pinsage/process_nowplaying_rs.py",
    "content": "\"\"\"\nScript that reads from raw Nowplaying-RS data and dumps into a pickle\nfile a heterogeneous graph with categorical and numeric features.\n\"\"\"\n\nimport argparse\nimport os\nimport pickle\n\nimport pandas as pd\nimport scipy.sparse as ssp\nfrom builder import PandasGraphBuilder\nfrom data_utils import *\n\nimport dgl\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"directory\", type=str)\n    parser.add_argument(\"out_directory\", type=str)\n    args = parser.parse_args()\n    directory = args.directory\n    out_directory = args.out_directory\n    os.makedirs(out_directory, exist_ok=True)\n\n    data = pd.read_csv(os.path.join(directory, \"context_content_features.csv\"))\n    track_feature_cols = list(data.columns[1:13])\n    data = data[\n        [\"user_id\", \"track_id\", \"created_at\"] + track_feature_cols\n    ].dropna()\n\n    users = data[[\"user_id\"]].drop_duplicates()\n    tracks = data[[\"track_id\"] + track_feature_cols].drop_duplicates()\n    assert tracks[\"track_id\"].value_counts().max() == 1\n    tracks = tracks.astype(\n        {\"mode\": \"int64\", \"key\": \"int64\", \"artist_id\": \"category\"}\n    )\n    events = data[[\"user_id\", \"track_id\", \"created_at\"]]\n    events[\"created_at\"] = (\n        events[\"created_at\"].values.astype(\"datetime64[s]\").astype(\"int64\")\n    )\n\n    graph_builder = PandasGraphBuilder()\n    graph_builder.add_entities(users, \"user_id\", \"user\")\n    graph_builder.add_entities(tracks, \"track_id\", \"track\")\n    graph_builder.add_binary_relations(\n        events, \"user_id\", \"track_id\", \"listened\"\n    )\n    graph_builder.add_binary_relations(\n        events, \"track_id\", \"user_id\", \"listened-by\"\n    )\n\n    g = graph_builder.build()\n\n    float_cols = []\n    for col in tracks.columns:\n        if col == \"track_id\":\n            continue\n        elif col == \"artist_id\":\n            
g.nodes[\"track\"].data[col] = torch.LongTensor(\n                tracks[col].cat.codes.values\n            )\n        elif tracks.dtypes[col] == \"float64\":\n            float_cols.append(col)\n        else:\n            g.nodes[\"track\"].data[col] = torch.LongTensor(tracks[col].values)\n    g.nodes[\"track\"].data[\"song_features\"] = torch.FloatTensor(\n        linear_normalize(tracks[float_cols].values)\n    )\n    g.edges[\"listened\"].data[\"created_at\"] = torch.LongTensor(\n        events[\"created_at\"].values\n    )\n    g.edges[\"listened-by\"].data[\"created_at\"] = torch.LongTensor(\n        events[\"created_at\"].values\n    )\n\n    n_edges = g.num_edges(\"listened\")\n    train_indices, val_indices, test_indices = train_test_split_by_time(\n        events, \"created_at\", \"user_id\"\n    )\n    train_g = build_train_graph(\n        g, train_indices, \"user\", \"track\", \"listened\", \"listened-by\"\n    )\n    assert train_g.out_degrees(etype=\"listened\").min() > 0\n    val_matrix, test_matrix = build_val_test_matrix(\n        g, val_indices, test_indices, \"user\", \"track\", \"listened\"\n    )\n\n    dgl.save_graphs(os.path.join(out_directory, \"train_g.bin\"), train_g)\n\n    dataset = {\n        \"val-matrix\": val_matrix,\n        \"test-matrix\": test_matrix,\n        \"item-texts\": {},\n        \"item-images\": None,\n        \"user-type\": \"user\",\n        \"item-type\": \"track\",\n        \"user-to-item-type\": \"listened\",\n        \"item-to-user-type\": \"listened-by\",\n        \"timestamp-edge-column\": \"created_at\",\n    }\n\n    with open(os.path.join(out_directory, \"data.pkl\"), \"wb\") as f:\n        pickle.dump(dataset, f)\n"
  },
  {
    "path": "examples/pytorch/pinsage/sampler.py",
    "content": "import dgl\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader, IterableDataset\nfrom torchtext.data.functional import numericalize_tokens_from_iterator\n\n\ndef padding(array, yy, val):\n    \"\"\"\n    :param array: torch tensor array\n    :param yy: desired width\n    :param val: padded value\n    :return: padded array\n    \"\"\"\n    w = array.shape[0]\n    b = 0\n    bb = yy - b - w\n\n    return torch.nn.functional.pad(\n        array, pad=(b, bb), mode=\"constant\", value=val\n    )\n\n\ndef compact_and_copy(frontier, seeds):\n    block = dgl.to_block(frontier, seeds)\n    for col, data in frontier.edata.items():\n        if col == dgl.EID:\n            continue\n        block.edata[col] = data[block.edata[dgl.EID]]\n    return block\n\n\nclass ItemToItemBatchSampler(IterableDataset):\n    def __init__(self, g, user_type, item_type, batch_size):\n        self.g = g\n        self.user_type = user_type\n        self.item_type = item_type\n        self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]\n        self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]\n        self.batch_size = batch_size\n\n    def __iter__(self):\n        while True:\n            heads = torch.randint(\n                0, self.g.num_nodes(self.item_type), (self.batch_size,)\n            )\n            tails = dgl.sampling.random_walk(\n                self.g,\n                heads,\n                metapath=[self.item_to_user_etype, self.user_to_item_etype],\n            )[0][:, 2]\n            neg_tails = torch.randint(\n                0, self.g.num_nodes(self.item_type), (self.batch_size,)\n            )\n\n            mask = tails != -1\n            yield heads[mask], tails[mask], neg_tails[mask]\n\n\nclass NeighborSampler(object):\n    def __init__(\n        self,\n        g,\n        user_type,\n        item_type,\n        random_walk_length,\n        random_walk_restart_prob,\n        
num_random_walks,\n        num_neighbors,\n        num_layers,\n    ):\n        self.g = g\n        self.user_type = user_type\n        self.item_type = item_type\n        self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]\n        self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]\n        self.samplers = [\n            dgl.sampling.PinSAGESampler(\n                g,\n                item_type,\n                user_type,\n                random_walk_length,\n                random_walk_restart_prob,\n                num_random_walks,\n                num_neighbors,\n            )\n            for _ in range(num_layers)\n        ]\n\n    def sample_blocks(self, seeds, heads=None, tails=None, neg_tails=None):\n        blocks = []\n        for sampler in self.samplers:\n            frontier = sampler(seeds)\n            if heads is not None:\n                eids = frontier.edge_ids(\n                    torch.cat([heads, heads]),\n                    torch.cat([tails, neg_tails]),\n                    return_uv=True,\n                )[2]\n                if len(eids) > 0:\n                    old_frontier = frontier\n                    frontier = dgl.remove_edges(old_frontier, eids)\n                    # print(old_frontier)\n                    # print(frontier)\n                    # print(frontier.edata['weights'])\n                    # frontier.edata['weights'] = old_frontier.edata['weights'][frontier.edata[dgl.EID]]\n            block = compact_and_copy(frontier, seeds)\n            seeds = block.srcdata[dgl.NID]\n            blocks.insert(0, block)\n        return blocks\n\n    def sample_from_item_pairs(self, heads, tails, neg_tails):\n        # Create a graph with positive connections only and another graph with negative\n        # connections only.\n        pos_graph = dgl.graph(\n            (heads, tails), num_nodes=self.g.num_nodes(self.item_type)\n        )\n        neg_graph = dgl.graph(\n            (heads, 
neg_tails), num_nodes=self.g.num_nodes(self.item_type)\n        )\n        pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])\n        seeds = pos_graph.ndata[dgl.NID]\n\n        blocks = self.sample_blocks(seeds, heads, tails, neg_tails)\n        return pos_graph, neg_graph, blocks\n\n\ndef assign_simple_node_features(ndata, g, ntype, assign_id=False):\n    \"\"\"\n    Copies data to the given block from the corresponding nodes in the original graph.\n    \"\"\"\n    for col in g.nodes[ntype].data.keys():\n        if not assign_id and col == dgl.NID:\n            continue\n        induced_nodes = ndata[dgl.NID]\n        ndata[col] = g.nodes[ntype].data[col][induced_nodes]\n\n\ndef assign_textual_node_features(ndata, textset, ntype):\n    \"\"\"\n    Assigns numericalized tokens from a torchtext dataset to given block.\n\n    The numericalized tokens would be stored in the block as node features\n    with the same name as ``field_name``.\n\n    The length would be stored as another node feature with name\n    ``field_name + '__len'``.\n\n    block : DGLGraph\n        First element of the compacted blocks, with \"dgl.NID\" as the\n        corresponding node ID in the original graph, hence the index to the\n        text dataset.\n\n        The numericalized tokens (and lengths if available) would be stored\n        onto the blocks as new node features.\n    textset : torchtext.data.Dataset\n        A torchtext dataset whose number of examples is the same as that\n        of nodes in the original graph.\n    \"\"\"\n    node_ids = ndata[dgl.NID].numpy()\n\n    for field_name, field in textset.items():\n        textlist, vocab, pad_var, batch_first = field\n\n        examples = [textlist[i] for i in node_ids]\n        ids_iter = numericalize_tokens_from_iterator(vocab, examples)\n\n        maxsize = max([len(textlist[i]) for i in node_ids])\n        ids = next(ids_iter)\n        x = torch.asarray([num for num in ids])\n        lengths = 
torch.tensor([len(x)])\n        tokens = padding(x, maxsize, pad_var)\n\n        for ids in ids_iter:\n            x = torch.asarray([num for num in ids])\n            l = torch.tensor([len(x)])\n            y = padding(x, maxsize, pad_var)\n            tokens = torch.vstack((tokens, y))\n            lengths = torch.cat((lengths, l))\n\n        if not batch_first:\n            tokens = tokens.t()\n\n        ndata[field_name] = tokens\n        ndata[field_name + \"__len\"] = lengths\n\n\ndef assign_features_to_blocks(blocks, g, textset, ntype):\n    # For the first block (which is closest to the input), copy the features from\n    # the original graph as well as the texts.\n    assign_simple_node_features(blocks[0].srcdata, g, ntype)\n    assign_textual_node_features(blocks[0].srcdata, textset, ntype)\n    assign_simple_node_features(blocks[-1].dstdata, g, ntype)\n    assign_textual_node_features(blocks[-1].dstdata, textset, ntype)\n\n\nclass PinSAGECollator(object):\n    def __init__(self, sampler, g, ntype, textset):\n        self.sampler = sampler\n        self.ntype = ntype\n        self.g = g\n        self.textset = textset\n\n    def collate_train(self, batches):\n        heads, tails, neg_tails = batches[0]\n        # Construct multilayer neighborhood via PinSAGE...\n        pos_graph, neg_graph, blocks = self.sampler.sample_from_item_pairs(\n            heads, tails, neg_tails\n        )\n        assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)\n\n        return pos_graph, neg_graph, blocks\n\n    def collate_test(self, samples):\n        batch = torch.LongTensor(samples)\n        blocks = self.sampler.sample_blocks(batch)\n        assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)\n        return blocks\n"
  },
  {
    "path": "examples/pytorch/pointcloud/bipointnet/ModelNetDataLoader.py",
    "content": "import os\nimport warnings\n\nimport numpy as np\nfrom torch.utils.data import Dataset\n\nwarnings.filterwarnings(\"ignore\")\n\n\ndef pc_normalize(pc):\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n    pc = pc / m\n    return pc\n\n\ndef farthest_point_sample(point, npoint):\n    \"\"\"\n    Farthest point sampler works as follows:\n    1. Initialize the sample set S with a random point\n    2. Pick point P not in S, which maximizes the distance d(P, S)\n    3. Repeat step 2 until |S| = npoint\n\n    Input:\n        xyz: pointcloud data, [N, D]\n        npoint: number of samples\n    Return:\n        centroids: sampled pointcloud index, [npoint, D]\n    \"\"\"\n    N, D = point.shape\n    xyz = point[:, :3]\n    centroids = np.zeros((npoint,))\n    distance = np.ones((N,)) * 1e10\n    farthest = np.random.randint(0, N)\n    for i in range(npoint):\n        centroids[i] = farthest\n        centroid = xyz[farthest, :]\n        dist = np.sum((xyz - centroid) ** 2, -1)\n        mask = dist < distance\n        distance[mask] = dist[mask]\n        farthest = np.argmax(distance, -1)\n    point = point[centroids.astype(np.int32)]\n    return point\n\n\nclass ModelNetDataLoader(Dataset):\n    def __init__(\n        self,\n        root,\n        npoint=1024,\n        split=\"train\",\n        fps=False,\n        normal_channel=True,\n        cache_size=15000,\n    ):\n        \"\"\"\n        Input:\n            root: the root path to the local data files\n            npoint: number of points from each cloud\n            split: which split of the data, 'train' or 'test'\n            fps: whether to sample points with farthest point sampler\n            normal_channel: whether to use additional channel\n            cache_size: the cache size of in-memory point clouds\n        \"\"\"\n        self.root = root\n        self.npoints = npoint\n        self.fps = fps\n        self.catfile = 
os.path.join(self.root, \"modelnet40_shape_names.txt\")\n\n        self.cat = [line.rstrip() for line in open(self.catfile)]\n        self.classes = dict(zip(self.cat, range(len(self.cat))))\n        self.normal_channel = normal_channel\n\n        shape_ids = {}\n        shape_ids[\"train\"] = [\n            line.rstrip()\n            for line in open(os.path.join(self.root, \"modelnet40_train.txt\"))\n        ]\n        shape_ids[\"test\"] = [\n            line.rstrip()\n            for line in open(os.path.join(self.root, \"modelnet40_test.txt\"))\n        ]\n\n        assert split == \"train\" or split == \"test\"\n        shape_names = [\"_\".join(x.split(\"_\")[0:-1]) for x in shape_ids[split]]\n        # list of (shape_name, shape_txt_file_path) tuple\n        self.datapath = [\n            (\n                shape_names[i],\n                os.path.join(self.root, shape_names[i], shape_ids[split][i])\n                + \".txt\",\n            )\n            for i in range(len(shape_ids[split]))\n        ]\n        print(\"The size of %s data is %d\" % (split, len(self.datapath)))\n\n        self.cache_size = cache_size\n        self.cache = {}\n\n    def __len__(self):\n        return len(self.datapath)\n\n    def _get_item(self, index):\n        if index in self.cache:\n            point_set, cls = self.cache[index]\n        else:\n            fn = self.datapath[index]\n            cls = self.classes[self.datapath[index][0]]\n            cls = np.array([cls]).astype(np.int32)\n            point_set = np.loadtxt(fn[1], delimiter=\",\").astype(np.float32)\n            if self.fps:\n                point_set = farthest_point_sample(point_set, self.npoints)\n            else:\n                point_set = point_set[0 : self.npoints, :]\n\n            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])\n\n            if not self.normal_channel:\n                point_set = point_set[:, 0:3]\n\n            if len(self.cache) < self.cache_size:\n                
self.cache[index] = (point_set, cls)\n\n        return point_set, cls\n"
  },
  {
    "path": "examples/pytorch/pointcloud/bipointnet/README.md",
    "content": "## *BiPointNet: Binary Neural Network for Point Clouds*\n\nCreated by [Haotong Qin](https://htqin.github.io/), [Zhongang Cai](https://scholar.google.com/citations?user=WrDKqIAAAAAJ&hl=en), [Mingyuan Zhang](https://scholar.google.com/citations?user=2QLD4fAAAAAJ&hl=en), Yifu Ding, Haiyu Zhao, Shuai Yi, [Xianglong Liu](http://sites.nlsde.buaa.edu.cn/~xlliu/), and [Hao Su](https://cseweb.ucsd.edu/~haosu/) from Beihang University, SenseTime, and UCSD.\n\n![prediction example](https://htqin.github.io/Imgs/ICLR/overview_v1.png)\n\n### Introduction\n\nThis project is the official implementation of our accepted ICLR 2021 paper *BiPointNet: Binary Neural Network for Point Clouds* [[PDF]( https://openreview.net/forum?id=9QLRCVysdlO)]. To alleviate the resource constraint for real-time point cloud applications that run on edge devices, in this paper we present ***BiPointNet***, the first model binarization approach for efficient deep learning on point clouds. We first discover that the immense performance drop of binarized models for point clouds mainly stems from two challenges: aggregation-induced feature homogenization that leads to a degradation of information entropy, and scale distortion that hinders optimization and invalidates scale-sensitive structures. With theoretical justifications and in-depth analysis, our BiPointNet introduces Entropy-Maximizing Aggregation (EMA) to modulate the distribution before aggregation for the maximum information entropy, and Layer-wise Scale Recovery (LSR) to efficiently restore feature representation capacity. Extensive experiments show that BiPointNet outperforms existing binarization methods by convincing margins, at the level even comparable with the full precision counterpart. 
We highlight that our techniques are generic, guaranteeing significant improvements on various fundamental tasks and mainstream backbones, e.g., BiPointNet gives an impressive 14.7x speedup and 18.9x storage saving on real-world resource-constrained devices. In addition, our inference framework is dabnn.\n\n### How to Run\n\n```shell script\npython train_cls.py --model ${MODEL}\n```\n\nHere, `MODEL` has two choices: `bipointnet` and `bipointnet2_ssg`.\n\n# Performance\n\n## Classification\n\n| Model           | Dataset    | Metric   | Score |\n| --------------- | ---------- | -------- | ----- |\n| BiPointNet      | ModelNet40 | Accuracy | 88.4  |\n| BiPointNet2_SSG | ModelNet40 | Accuracy | 83.1  |\n\nBecause of the difference in implementation brought by the application of DGL, this version performs even better than the results reported in the original paper.\n\n### Citation\n\nIf you find our work useful in your research, please consider citing:\n\n```\n@inproceedings{Qin:iclr21,\n  author    = {Haotong Qin and Zhongang Cai and Mingyuan Zhang \n  and Yifu Ding and Haiyu Zhao and Shuai Yi \n  and Xianglong Liu and Hao Su},\n  title     = {BiPointNet: Binary Neural Network for Point Clouds},\n  booktitle = {ICLR},\n  year      = {2021}\n}\n```"
  },
  {
    "path": "examples/pytorch/pointcloud/bipointnet/basic.py",
    "content": "import dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Function\nfrom torch.nn import Parameter\nfrom torch.nn.modules.utils import _single\n\n\nclass BinaryQuantize(Function):\n    @staticmethod\n    def forward(ctx, input):\n        ctx.save_for_backward(input)\n        out = torch.sign(input)\n        return out\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input = ctx.saved_tensors\n        grad_input = grad_output\n        grad_input[input[0].gt(1)] = 0\n        grad_input[input[0].lt(-1)] = 0\n        return grad_input\n\n\nclass BiLinearLSR(torch.nn.Linear):\n    def __init__(self, in_features, out_features, bias=False, binary_act=True):\n        super(BiLinearLSR, self).__init__(in_features, out_features, bias=bias)\n        self.binary_act = binary_act\n\n        # must register a nn.Parameter placeholder for model loading\n        # self.register_parameter('scale', None) doesn't register None into state_dict\n        # so it leads to unexpected key error when loading saved model\n        # hence, init scale with Parameter\n        # however, Parameter(None) actually has size [0], not [] as a scalar\n        # hence, init it using the following trick\n        self.register_parameter(\n            \"scale\", Parameter(torch.Tensor([0.0]).squeeze())\n        )\n\n    def reset_scale(self, input):\n        bw = self.weight\n        ba = input\n        bw = bw - bw.mean()\n        self.scale = Parameter(\n            (\n                F.linear(ba, bw).std()\n                / F.linear(torch.sign(ba), torch.sign(bw)).std()\n            )\n            .float()\n            .to(ba.device)\n        )\n        # corner case when ba is all 0.0\n        if torch.isnan(self.scale):\n            self.scale = Parameter(\n                (bw.std() / torch.sign(bw).std()).float().to(ba.device)\n            )\n\n    def forward(self, input):\n        bw = self.weight\n        ba = 
input\n        bw = bw - bw.mean()\n\n        if self.scale.item() == 0.0:\n            self.reset_scale(input)\n\n        bw = BinaryQuantize().apply(bw)\n        bw = bw * self.scale\n        if self.binary_act:\n            ba = BinaryQuantize().apply(ba)\n        output = F.linear(ba, bw)\n        return output\n\n\nclass BiLinear(torch.nn.Linear):\n    def __init__(self, in_features, out_features, bias=True, binary_act=True):\n        super(BiLinear, self).__init__(in_features, out_features, bias=True)\n        self.binary_act = binary_act\n        self.output_ = None\n\n    def forward(self, input):\n        bw = self.weight\n        ba = input\n        bw = BinaryQuantize().apply(bw)\n        if self.binary_act:\n            ba = BinaryQuantize().apply(ba)\n        output = F.linear(ba, bw, self.bias)\n        self.output_ = output\n        return output\n\n\nclass BiConv2d(torch.nn.Conv2d):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        stride=1,\n        padding=0,\n        dilation=1,\n        groups=1,\n        bias=True,\n        padding_mode=\"zeros\",\n    ):\n        super(BiConv2d, self).__init__(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            groups,\n            bias,\n            padding_mode,\n        )\n\n    def forward(self, input):\n        bw = self.weight\n        ba = input\n        bw = bw - bw.mean()\n        bw = BinaryQuantize().apply(bw)\n        ba = BinaryQuantize().apply(ba)\n\n        if self.padding_mode == \"circular\":\n            expanded_padding = (\n                (self.padding[0] + 1) // 2,\n                self.padding[0] // 2,\n            )\n            return F.conv2d(\n                F.pad(ba, expanded_padding, mode=\"circular\"),\n                bw,\n                self.bias,\n                self.stride,\n                _single(0),\n       
         self.dilation,\n                self.groups,\n            )\n        return F.conv2d(\n            ba,\n            bw,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.groups,\n        )\n\n\ndef square_distance(src, dst):\n    \"\"\"\n    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n    \"\"\"\n    B, N, _ = src.shape\n    _, M, _ = dst.shape\n    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))\n    dist += torch.sum(src**2, -1).view(B, N, 1)\n    dist += torch.sum(dst**2, -1).view(B, 1, M)\n    return dist\n\n\ndef index_points(points, idx):\n    \"\"\"\n    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n    \"\"\"\n    device = points.device\n    B = points.shape[0]\n    view_shape = list(idx.shape)\n    view_shape[1:] = [1] * (len(view_shape) - 1)\n    repeat_shape = list(idx.shape)\n    repeat_shape[0] = 1\n    batch_indices = (\n        torch.arange(B, dtype=torch.long)\n        .to(device)\n        .view(view_shape)\n        .repeat(repeat_shape)\n    )\n    new_points = points[batch_indices, idx, :]\n    return new_points\n\n\nclass FixedRadiusNearNeighbors(nn.Module):\n    \"\"\"\n    Ball Query - Find the neighbors with-in a fixed radius\n    \"\"\"\n\n    def __init__(self, radius, n_neighbor):\n        super(FixedRadiusNearNeighbors, self).__init__()\n        self.radius = radius\n        self.n_neighbor = n_neighbor\n\n    def forward(self, pos, centroids):\n        \"\"\"\n        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n        \"\"\"\n        device = pos.device\n        B, N, _ = pos.shape\n        center_pos = index_points(pos, centroids)\n        _, S, _ = center_pos.shape\n        group_idx = (\n            torch.arange(N, dtype=torch.long)\n            .to(device)\n            .view(1, 1, N)\n            .repeat([B, S, 1])\n        )\n        sqrdists = square_distance(center_pos, pos)\n     
   group_idx[sqrdists > self.radius**2] = N\n        group_idx = group_idx.sort(dim=-1)[0][:, :, : self.n_neighbor]\n        group_first = (\n            group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, self.n_neighbor])\n        )\n        mask = group_idx == N\n        group_idx[mask] = group_first[mask]\n        return group_idx\n\n\nclass FixedRadiusNNGraph(nn.Module):\n    \"\"\"\n    Build NN graph\n    \"\"\"\n\n    def __init__(self, radius, n_neighbor):\n        super(FixedRadiusNNGraph, self).__init__()\n        self.radius = radius\n        self.n_neighbor = n_neighbor\n        self.frnn = FixedRadiusNearNeighbors(radius, n_neighbor)\n\n    def forward(self, pos, centroids, feat=None):\n        dev = pos.device\n        group_idx = self.frnn(pos, centroids)\n        B, N, _ = pos.shape\n        glist = []\n        for i in range(B):\n            center = torch.zeros((N)).to(dev)\n            center[centroids[i]] = 1\n            src = group_idx[i].contiguous().view(-1)\n            dst = centroids[i].view(-1, 1).repeat(1, self.n_neighbor).view(-1)\n\n            unified = torch.cat([src, dst])\n            uniq, inv_idx = torch.unique(unified, return_inverse=True)\n            src_idx = inv_idx[: src.shape[0]]\n            dst_idx = inv_idx[src.shape[0] :]\n\n            g = dgl.graph((src_idx, dst_idx))\n            g.ndata[\"pos\"] = pos[i][uniq]\n            g.ndata[\"center\"] = center[uniq]\n            if feat is not None:\n                g.ndata[\"feat\"] = feat[i][uniq]\n            glist.append(g)\n        bg = dgl.batch(glist)\n        return bg\n\n\nclass RelativePositionMessage(nn.Module):\n    \"\"\"\n    Compute the input feature from neighbors\n    \"\"\"\n\n    def __init__(self, n_neighbor):\n        super(RelativePositionMessage, self).__init__()\n        self.n_neighbor = n_neighbor\n\n    def forward(self, edges):\n        pos = edges.src[\"pos\"] - edges.dst[\"pos\"]\n        if \"feat\" in edges.src:\n            res = 
torch.cat([pos, edges.src[\"feat\"]], 1)\n        else:\n            res = pos\n        return {\"agg_feat\": res}\n"
  },
  {
    "path": "examples/pytorch/pointcloud/bipointnet/bipointnet2.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom basic import (\n    BiConv2d,\n    BiLinearLSR,\n    FixedRadiusNNGraph,\n    RelativePositionMessage,\n)\nfrom dgl.geometry import farthest_point_sampler\n\n\nclass BiPointNetConv(nn.Module):\n    \"\"\"\n    Feature aggregation\n    \"\"\"\n\n    def __init__(self, sizes, batch_size):\n        super(BiPointNetConv, self).__init__()\n        self.batch_size = batch_size\n        self.conv = nn.ModuleList()\n        self.bn = nn.ModuleList()\n        for i in range(1, len(sizes)):\n            self.conv.append(BiConv2d(sizes[i - 1], sizes[i], 1))\n            self.bn.append(nn.BatchNorm2d(sizes[i]))\n\n    def forward(self, nodes):\n        shape = nodes.mailbox[\"agg_feat\"].shape\n        h = (\n            nodes.mailbox[\"agg_feat\"]\n            .view(self.batch_size, -1, shape[1], shape[2])\n            .permute(0, 3, 2, 1)\n        )\n        for conv, bn in zip(self.conv, self.bn):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n        h = torch.max(h, 2)[0]\n        feat_dim = h.shape[1]\n        h = h.permute(0, 2, 1).reshape(-1, feat_dim)\n        return {\"new_feat\": h}\n\n    def group_all(self, pos, feat):\n        \"\"\"\n        Feature aggregation and pooling for the non-sampling layer\n        \"\"\"\n        if feat is not None:\n            h = torch.cat([pos, feat], 2)\n        else:\n            h = pos\n        B, N, D = h.shape\n        _, _, C = pos.shape\n        new_pos = torch.zeros(B, 1, C)\n        h = h.permute(0, 2, 1).view(B, -1, N, 1)\n        for conv, bn in zip(self.conv, self.bn):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n        h = torch.max(h[:, :, :, 0], 2)[0]  # [B,D]\n        return new_pos, h\n\n\nclass BiSAModule(nn.Module):\n    \"\"\"\n    The Set Abstraction Layer\n    \"\"\"\n\n    def __init__(\n        self,\n        npoints,\n        batch_size,\n        
radius,\n        mlp_sizes,\n        n_neighbor=64,\n        group_all=False,\n    ):\n        super(BiSAModule, self).__init__()\n        self.group_all = group_all\n        if not group_all:\n            self.npoints = npoints\n            self.frnn_graph = FixedRadiusNNGraph(radius, n_neighbor)\n        self.message = RelativePositionMessage(n_neighbor)\n        self.conv = BiPointNetConv(mlp_sizes, batch_size)\n        self.batch_size = batch_size\n\n    def forward(self, pos, feat):\n        if self.group_all:\n            return self.conv.group_all(pos, feat)\n\n        centroids = farthest_point_sampler(pos, self.npoints)\n        g = self.frnn_graph(pos, centroids, feat)\n        g.update_all(self.message, self.conv)\n\n        mask = g.ndata[\"center\"] == 1\n        pos_dim = g.ndata[\"pos\"].shape[-1]\n        feat_dim = g.ndata[\"new_feat\"].shape[-1]\n        pos_res = g.ndata[\"pos\"][mask].view(self.batch_size, -1, pos_dim)\n        feat_res = g.ndata[\"new_feat\"][mask].view(self.batch_size, -1, feat_dim)\n        return pos_res, feat_res\n\n\nclass BiPointNet2SSGCls(nn.Module):\n    def __init__(\n        self, output_classes, batch_size, input_dims=3, dropout_prob=0.4\n    ):\n        super(BiPointNet2SSGCls, self).__init__()\n        self.input_dims = input_dims\n\n        self.sa_module1 = BiSAModule(\n            512, batch_size, 0.2, [input_dims, 64, 64, 128]\n        )\n        self.sa_module2 = BiSAModule(\n            128, batch_size, 0.4, [128 + 3, 128, 128, 256]\n        )\n        self.sa_module3 = BiSAModule(\n            None, batch_size, None, [256 + 3, 256, 512, 1024], group_all=True\n        )\n\n        self.mlp1 = BiLinearLSR(1024, 512)\n        self.bn1 = nn.BatchNorm1d(512)\n        self.drop1 = nn.Dropout(dropout_prob)\n\n        self.mlp2 = BiLinearLSR(512, 256)\n        self.bn2 = nn.BatchNorm1d(256)\n        self.drop2 = nn.Dropout(dropout_prob)\n\n        self.mlp_out = BiLinearLSR(256, output_classes)\n\n    def 
forward(self, x):\n        if x.shape[-1] > 3:\n            pos = x[:, :, :3]\n            feat = x[:, :, 3:]\n        else:\n            pos = x\n            feat = None\n        pos, feat = self.sa_module1(pos, feat)\n        pos, feat = self.sa_module2(pos, feat)\n        _, h = self.sa_module3(pos, feat)\n\n        h = self.mlp1(h)\n        h = self.bn1(h)\n        h = F.relu(h)\n        h = self.drop1(h)\n        h = self.mlp2(h)\n        h = self.bn2(h)\n        h = F.relu(h)\n        h = self.drop2(h)\n\n        out = self.mlp_out(h)\n        return out\n"
  },
  {
    "path": "examples/pytorch/pointcloud/bipointnet/bipointnet_cls.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom basic import BiLinear\nfrom torch.autograd import Variable\n\noffset_map = {1024: -3.2041, 2048: -3.4025, 4096: -3.5836}\n\n\nclass Conv1d(nn.Module):\n    def __init__(self, inplane, outplane, Linear):\n        super().__init__()\n        self.lin = Linear(inplane, outplane)\n\n    def forward(self, x):\n        B, C, N = x.shape\n        x = x.permute(0, 2, 1).contiguous().view(-1, C)\n        x = self.lin(x).view(B, N, -1).permute(0, 2, 1).contiguous()\n        return x\n\n\nclass EmaMaxPool(nn.Module):\n    def __init__(self, kernel_size, affine=True, Linear=BiLinear, use_bn=True):\n        super(EmaMaxPool, self).__init__()\n        self.kernel_size = kernel_size\n        self.bn3 = nn.BatchNorm1d(1024, affine=affine)\n        self.use_bn = use_bn\n\n    def forward(self, x):\n        batchsize, D, N = x.size()\n        if self.use_bn:\n            x = torch.max(x, 2, keepdim=True)[0] + offset_map[N]\n        else:\n            x = torch.max(x, 2, keepdim=True)[0] - 0.3\n        return x\n\n\nclass BiPointNetCls(nn.Module):\n    def __init__(\n        self,\n        output_classes,\n        input_dims=3,\n        conv1_dim=64,\n        use_transform=True,\n        Linear=BiLinear,\n    ):\n        super(BiPointNetCls, self).__init__()\n        self.input_dims = input_dims\n        self.conv1 = nn.ModuleList()\n        self.conv1.append(Conv1d(input_dims, conv1_dim, Linear=Linear))\n        self.conv1.append(Conv1d(conv1_dim, conv1_dim, Linear=Linear))\n        self.conv1.append(Conv1d(conv1_dim, conv1_dim, Linear=Linear))\n\n        self.bn1 = nn.ModuleList()\n        self.bn1.append(nn.BatchNorm1d(conv1_dim))\n        self.bn1.append(nn.BatchNorm1d(conv1_dim))\n        self.bn1.append(nn.BatchNorm1d(conv1_dim))\n\n        self.conv2 = nn.ModuleList()\n        self.conv2.append(Conv1d(conv1_dim, conv1_dim * 2, Linear=Linear))\n        
self.conv2.append(Conv1d(conv1_dim * 2, conv1_dim * 16, Linear=Linear))\n\n        self.bn2 = nn.ModuleList()\n        self.bn2.append(nn.BatchNorm1d(conv1_dim * 2))\n        self.bn2.append(nn.BatchNorm1d(conv1_dim * 16))\n\n        self.maxpool = EmaMaxPool(conv1_dim * 16, Linear=Linear, use_bn=True)\n        self.pool_feat_len = conv1_dim * 16\n\n        self.mlp3 = nn.ModuleList()\n        self.mlp3.append(Linear(conv1_dim * 16, conv1_dim * 8))\n        self.mlp3.append(Linear(conv1_dim * 8, conv1_dim * 4))\n\n        self.bn3 = nn.ModuleList()\n        self.bn3.append(nn.BatchNorm1d(conv1_dim * 8))\n        self.bn3.append(nn.BatchNorm1d(conv1_dim * 4))\n\n        self.dropout = nn.Dropout(0.3)\n        self.mlp_out = Linear(conv1_dim * 4, output_classes)\n\n        self.use_transform = use_transform\n        if use_transform:\n            self.transform1 = TransformNet(input_dims)\n            self.trans_bn1 = nn.BatchNorm1d(input_dims)\n            self.transform2 = TransformNet(conv1_dim)\n            self.trans_bn2 = nn.BatchNorm1d(conv1_dim)\n\n    def forward(self, x):\n        batch_size = x.shape[0]\n        h = x.permute(0, 2, 1)\n        if self.use_transform:\n            trans = self.transform1(h)\n            h = h.transpose(2, 1)\n            h = torch.bmm(h, trans)\n            h = h.transpose(2, 1)\n            h = F.relu(self.trans_bn1(h))\n\n        for conv, bn in zip(self.conv1, self.bn1):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        if self.use_transform:\n            trans = self.transform2(h)\n            h = h.transpose(2, 1)\n            h = torch.bmm(h, trans)\n            h = h.transpose(2, 1)\n            h = F.relu(self.trans_bn2(h))\n\n        for conv, bn in zip(self.conv2, self.bn2):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        h = self.maxpool(h).view(-1, self.pool_feat_len)\n        for mlp, bn in zip(self.mlp3, self.bn3):\n            h = 
mlp(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        h = self.dropout(h)\n        out = self.mlp_out(h)\n        return out\n\n\nclass TransformNet(nn.Module):\n    def __init__(self, input_dims=3, conv1_dim=64, Linear=BiLinear):\n        super(TransformNet, self).__init__()\n        self.conv = nn.ModuleList()\n        self.conv.append(Conv1d(input_dims, conv1_dim, Linear=Linear))\n        self.conv.append(Conv1d(conv1_dim, conv1_dim * 2, Linear=Linear))\n        self.conv.append(Conv1d(conv1_dim * 2, conv1_dim * 16, Linear=Linear))\n\n        self.bn = nn.ModuleList()\n        self.bn.append(nn.BatchNorm1d(conv1_dim))\n        self.bn.append(nn.BatchNorm1d(conv1_dim * 2))\n        self.bn.append(nn.BatchNorm1d(conv1_dim * 16))\n\n        # self.maxpool = nn.MaxPool1d(conv1_dim * 16)\n        self.maxpool = EmaMaxPool(conv1_dim * 16, Linear=Linear, use_bn=True)\n        self.pool_feat_len = conv1_dim * 16\n\n        self.mlp2 = nn.ModuleList()\n        self.mlp2.append(Linear(conv1_dim * 16, conv1_dim * 8))\n        self.mlp2.append(Linear(conv1_dim * 8, conv1_dim * 4))\n\n        self.bn2 = nn.ModuleList()\n        self.bn2.append(nn.BatchNorm1d(conv1_dim * 8))\n        self.bn2.append(nn.BatchNorm1d(conv1_dim * 4))\n\n        self.input_dims = input_dims\n        self.mlp_out = Linear(conv1_dim * 4, input_dims * input_dims)\n\n    def forward(self, h):\n        batch_size = h.shape[0]\n        for conv, bn in zip(self.conv, self.bn):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        h = self.maxpool(h).view(-1, self.pool_feat_len)\n        for mlp, bn in zip(self.mlp2, self.bn2):\n            h = mlp(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        out = self.mlp_out(h)\n\n        iden = Variable(\n            torch.from_numpy(\n                np.eye(self.input_dims).flatten().astype(np.float32)\n            )\n        )\n        iden = iden.view(1, self.input_dims * self.input_dims).repeat(\n    
        batch_size, 1\n        )\n        if out.is_cuda:\n            iden = iden.cuda()\n        out = out + iden\n        out = out.view(-1, self.input_dims, self.input_dims)\n        return out\n"
  },
  {
    "path": "examples/pytorch/pointcloud/bipointnet/train_cls.py",
    "content": "import argparse\nimport os\nimport urllib\nfrom functools import partial\n\nimport dgl\nimport provider\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom bipointnet2 import BiPointNet2SSGCls\nfrom bipointnet_cls import BiPointNetCls\nfrom dgl.data.utils import download, get_download_dir\nfrom ModelNetDataLoader import ModelNetDataLoader\nfrom torch.utils.data import DataLoader\n\ntorch.backends.cudnn.enabled = False\n\n\n# from dataset import ModelNet\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--model\", type=str, default=\"bipointnet\")\nparser.add_argument(\"--dataset-path\", type=str, default=\"\")\nparser.add_argument(\"--load-model-path\", type=str, default=\"\")\nparser.add_argument(\"--save-model-path\", type=str, default=\"\")\nparser.add_argument(\"--num-epochs\", type=int, default=200)\nparser.add_argument(\"--num-workers\", type=int, default=0)\nparser.add_argument(\"--batch-size\", type=int, default=32)\nargs = parser.parse_args()\n\nnum_workers = args.num_workers\nbatch_size = args.batch_size\n\ndata_filename = \"modelnet40_normal_resampled.zip\"\ndownload_path = os.path.join(get_download_dir(), data_filename)\nlocal_path = args.dataset_path or os.path.join(\n    get_download_dir(), \"modelnet40_normal_resampled\"\n)\nif not os.path.exists(local_path):\n    download(\n        \"https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip\",\n        download_path,\n        verify_ssl=False,\n    )\n    from zipfile import ZipFile\n\n    with ZipFile(download_path) as z:\n        z.extractall(path=get_download_dir())\n\nCustomDataLoader = partial(\n    DataLoader,\n    num_workers=num_workers,\n    batch_size=batch_size,\n    shuffle=True,\n    drop_last=True,\n)\n\n\ndef train(net, opt, scheduler, train_loader, dev):\n    net.train()\n\n    total_loss = 0\n    num_batches = 0\n    total_correct = 0\n    count = 0\n    loss_f = 
nn.CrossEntropyLoss()\n    with tqdm.tqdm(train_loader, ascii=True) as tq:\n        for data, label in tq:\n            data = data.data.numpy()\n            data = provider.random_point_dropout(data)\n            data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])\n            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])\n            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])\n            data = torch.tensor(data)\n            label = label[:, 0]\n\n            num_examples = label.shape[0]\n            data, label = data.to(dev), label.to(dev).squeeze().long()\n            opt.zero_grad()\n            logits = net(data)\n            loss = loss_f(logits, label)\n            loss.backward()\n            opt.step()\n\n            _, preds = logits.max(1)\n\n            num_batches += 1\n            count += num_examples\n            loss = loss.item()\n            correct = (preds == label).sum().item()\n            total_loss += loss\n            total_correct += correct\n\n            tq.set_postfix(\n                {\n                    \"AvgLoss\": \"%.5f\" % (total_loss / num_batches),\n                    \"AvgAcc\": \"%.5f\" % (total_correct / count),\n                }\n            )\n    scheduler.step()\n\n\ndef evaluate(net, test_loader, dev):\n    net.eval()\n\n    total_correct = 0\n    count = 0\n\n    with torch.no_grad():\n        with tqdm.tqdm(test_loader, ascii=True) as tq:\n            for data, label in tq:\n                label = label[:, 0]\n                num_examples = label.shape[0]\n                data, label = data.to(dev), label.to(dev).squeeze().long()\n                logits = net(data)\n                _, preds = logits.max(1)\n\n                correct = (preds == label).sum().item()\n                total_correct += correct\n                count += num_examples\n\n                tq.set_postfix({\"AvgAcc\": \"%.5f\" % (total_correct / count)})\n\n    return total_correct / 
count\n\n\ndev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nif args.model == \"bipointnet\":\n    net = BiPointNetCls(40, input_dims=6)\nelif args.model == \"bipointnet2_ssg\":\n    net = BiPointNet2SSGCls(40, batch_size, input_dims=6)\n\nnet = net.to(dev)\nif args.load_model_path:\n    net.load_state_dict(\n        torch.load(args.load_model_path, weights_only=False, map_location=dev)\n    )\n\nopt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)\n\nscheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)\n\ntrain_dataset = ModelNetDataLoader(local_path, 1024, split=\"train\")\ntest_dataset = ModelNetDataLoader(local_path, 1024, split=\"test\")\ntrain_loader = torch.utils.data.DataLoader(\n    train_dataset,\n    batch_size=batch_size,\n    shuffle=True,\n    num_workers=num_workers,\n    drop_last=True,\n)\ntest_loader = torch.utils.data.DataLoader(\n    test_dataset,\n    batch_size=batch_size,\n    shuffle=False,\n    num_workers=num_workers,\n    drop_last=True,\n)\n\nbest_test_acc = 0\n\nfor epoch in range(args.num_epochs):\n    train(net, opt, scheduler, train_loader, dev)\n    if (epoch + 1) % 1 == 0:\n        print(\"Epoch #%d Testing\" % epoch)\n        test_acc = evaluate(net, test_loader, dev)\n        if test_acc > best_test_acc:\n            best_test_acc = test_acc\n            if args.save_model_path:\n                torch.save(net.state_dict(), args.save_model_path)\n        print(\"Current test acc: %.5f (best: %.5f)\" % (test_acc, best_test_acc))\n"
  },
  {
    "path": "examples/pytorch/pointcloud/edgeconv/README.md",
    "content": "Dynamic EdgeConv\n====\n\nThis is a reproduction of the paper [Dynamic Graph CNN for Learning on Point\nClouds](https://arxiv.org/pdf/1801.07829.pdf).\n\nThe reproduced experiment is the 40-class classification on the ModelNet40\ndataset.  The sampled point clouds are identical to those of\n[PointNet](https://github.com/charlesq34/pointnet).\n\nTo train and test the model, simply run\n\n```bash\npython main.py\n```\n\nThe model currently takes 3 minutes to train an epoch on a Tesla V100, and an\nadditional 17 seconds to run a validation and 20 seconds to run a test.\n\nThe best validation performance is 93.5% with a test performance of 91.8%.\n\n## Dependencies\n\n* `h5py`\n* `tqdm`\n"
  },
  {
    "path": "examples/pytorch/pointcloud/edgeconv/main.py",
    "content": "import argparse\nimport os\nimport urllib\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\n\nfrom dgl.data.utils import download, get_download_dir\nfrom model import compute_loss, Model\nfrom modelnet import ModelNet\nfrom torch.utils.data import DataLoader\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--dataset-path\", type=str, default=\"\")\nparser.add_argument(\"--load-model-path\", type=str, default=\"\")\nparser.add_argument(\"--save-model-path\", type=str, default=\"\")\nparser.add_argument(\"--num-epochs\", type=int, default=100)\nparser.add_argument(\"--num-workers\", type=int, default=0)\nparser.add_argument(\"--batch-size\", type=int, default=32)\nargs = parser.parse_args()\n\nnum_workers = args.num_workers\nbatch_size = args.batch_size\ndata_filename = \"modelnet40-sampled-2048.h5\"\nlocal_path = args.dataset_path or os.path.join(\n    get_download_dir(), data_filename\n)\n\nif not os.path.exists(local_path):\n    download(\n        \"https://data.dgl.ai/dataset/modelnet40-sampled-2048.h5\", local_path\n    )\n\nCustomDataLoader = partial(\n    DataLoader,\n    num_workers=num_workers,\n    batch_size=batch_size,\n    shuffle=True,\n    drop_last=True,\n)\n\n\ndef train(model, opt, scheduler, train_loader, dev):\n    scheduler.step()\n\n    model.train()\n\n    total_loss = 0\n    num_batches = 0\n    total_correct = 0\n    count = 0\n    with tqdm.tqdm(train_loader, ascii=True) as tq:\n        for data, label in tq:\n            num_examples = label.shape[0]\n            data, label = data.to(dev), label.to(dev).squeeze().long()\n            opt.zero_grad()\n            logits = model(data)\n            loss = compute_loss(logits, label)\n            loss.backward()\n            opt.step()\n\n            _, preds = logits.max(1)\n\n            num_batches += 1\n            count += num_examples\n            loss = 
loss.item()\n            correct = (preds == label).sum().item()\n            total_loss += loss\n            total_correct += correct\n\n            tq.set_postfix(\n                {\n                    \"Loss\": \"%.5f\" % loss,\n                    \"AvgLoss\": \"%.5f\" % (total_loss / num_batches),\n                    \"Acc\": \"%.5f\" % (correct / num_examples),\n                    \"AvgAcc\": \"%.5f\" % (total_correct / count),\n                }\n            )\n\n\ndef evaluate(model, test_loader, dev):\n    model.eval()\n\n    total_correct = 0\n    count = 0\n\n    with torch.no_grad():\n        with tqdm.tqdm(test_loader, ascii=True) as tq:\n            for data, label in tq:\n                num_examples = label.shape[0]\n                data, label = data.to(dev), label.to(dev).squeeze().long()\n                logits = model(data)\n                _, preds = logits.max(1)\n\n                correct = (preds == label).sum().item()\n                total_correct += correct\n                count += num_examples\n\n                tq.set_postfix(\n                    {\n                        \"Acc\": \"%.5f\" % (correct / num_examples),\n                        \"AvgAcc\": \"%.5f\" % (total_correct / count),\n                    }\n                )\n\n    return total_correct / count\n\n\ndev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nmodel = Model(20, [64, 64, 128, 256], [512, 512, 256], 40)\nmodel = model.to(dev)\nif args.load_model_path:\n    model.load_state_dict(\n        torch.load(args.load_model_path, weights_only=False, map_location=dev)\n    )\n\nopt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)\n\nscheduler = optim.lr_scheduler.CosineAnnealingLR(\n    opt, args.num_epochs, eta_min=0.001\n)\n\nmodelnet = ModelNet(local_path, 1024)\n\ntrain_loader = CustomDataLoader(modelnet.train())\nvalid_loader = CustomDataLoader(modelnet.valid())\ntest_loader = 
CustomDataLoader(modelnet.test())\n\nbest_valid_acc = 0\nbest_test_acc = 0\n\nfor epoch in range(args.num_epochs):\n    print(\"Epoch #%d Validating\" % epoch)\n    valid_acc = evaluate(model, valid_loader, dev)\n    test_acc = evaluate(model, test_loader, dev)\n    if valid_acc > best_valid_acc:\n        best_valid_acc = valid_acc\n        best_test_acc = test_acc\n        if args.save_model_path:\n            torch.save(model.state_dict(), args.save_model_path)\n    print(\n        \"Current validation acc: %.5f (best: %.5f), test acc: %.5f (best: %.5f)\"\n        % (valid_acc, best_valid_acc, test_acc, best_test_acc)\n    )\n\n    train(model, opt, scheduler, train_loader, dev)\n"
  },
  {
    "path": "examples/pytorch/pointcloud/edgeconv/model.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl.nn.pytorch import EdgeConv, KNNGraph\n\n\nclass Model(nn.Module):\n    def __init__(\n        self,\n        k,\n        feature_dims,\n        emb_dims,\n        output_classes,\n        input_dims=3,\n        dropout_prob=0.5,\n    ):\n        super(Model, self).__init__()\n\n        self.nng = KNNGraph(k)\n        self.conv = nn.ModuleList()\n\n        self.num_layers = len(feature_dims)\n        for i in range(self.num_layers):\n            self.conv.append(\n                EdgeConv(\n                    feature_dims[i - 1] if i > 0 else input_dims,\n                    feature_dims[i],\n                    batch_norm=True,\n                )\n            )\n\n        self.proj = nn.Linear(sum(feature_dims), emb_dims[0])\n\n        self.embs = nn.ModuleList()\n        self.bn_embs = nn.ModuleList()\n        self.dropouts = nn.ModuleList()\n\n        self.num_embs = len(emb_dims) - 1\n        for i in range(1, self.num_embs + 1):\n            self.embs.append(\n                nn.Linear(\n                    # * 2 because of concatenation of max- and mean-pooling\n                    emb_dims[i - 1] if i > 1 else (emb_dims[i - 1] * 2),\n                    emb_dims[i],\n                )\n            )\n            self.bn_embs.append(nn.BatchNorm1d(emb_dims[i]))\n            self.dropouts.append(nn.Dropout(dropout_prob))\n\n        self.proj_output = nn.Linear(emb_dims[-1], output_classes)\n\n    def forward(self, x):\n        hs = []\n        batch_size, n_points, x_dims = x.shape\n        h = x\n\n        for i in range(self.num_layers):\n            g = self.nng(h).to(h.device)\n            h = h.view(batch_size * n_points, -1)\n            h = self.conv[i](g, h)\n            h = F.leaky_relu(h, 0.2)\n            h = h.view(batch_size, n_points, -1)\n            hs.append(h)\n\n        h = torch.cat(hs, 2)\n        h = self.proj(h)\n        h_max, _ = 
torch.max(h, 1)\n        h_avg = torch.mean(h, 1)\n        h = torch.cat([h_max, h_avg], 1)\n\n        for i in range(self.num_embs):\n            h = self.embs[i](h)\n            h = self.bn_embs[i](h)\n            h = F.leaky_relu(h, 0.2)\n            h = self.dropouts[i](h)\n\n        h = self.proj_output(h)\n        return h\n\n\ndef compute_loss(logits, y, eps=0.2):\n    num_classes = logits.shape[1]\n    one_hot = torch.zeros_like(logits).scatter_(1, y.view(-1, 1), 1)\n    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (num_classes - 1)\n    log_prob = F.log_softmax(logits, 1)\n    loss = -(one_hot * log_prob).sum(1).mean()\n    return loss\n"
  },
  {
    "path": "examples/pytorch/pointcloud/edgeconv/modelnet.py",
    "content": "import numpy as np\nfrom torch.utils.data import Dataset\n\n\nclass ModelNet(object):\n    def __init__(self, path, num_points):\n        import h5py\n\n        self.f = h5py.File(path)\n        self.num_points = num_points\n\n        self.n_train = self.f[\"train/data\"].shape[0]\n        self.n_valid = int(self.n_train / 5)\n        self.n_train -= self.n_valid\n        self.n_test = self.f[\"test/data\"].shape[0]\n\n    def train(self):\n        return ModelNetDataset(self, \"train\")\n\n    def valid(self):\n        return ModelNetDataset(self, \"valid\")\n\n    def test(self):\n        return ModelNetDataset(self, \"test\")\n\n\nclass ModelNetDataset(Dataset):\n    def __init__(self, modelnet, mode):\n        super(ModelNetDataset, self).__init__()\n        self.num_points = modelnet.num_points\n        self.mode = mode\n\n        if mode == \"train\":\n            self.data = modelnet.f[\"train/data\"][: modelnet.n_train]\n            self.label = modelnet.f[\"train/label\"][: modelnet.n_train]\n        elif mode == \"valid\":\n            self.data = modelnet.f[\"train/data\"][modelnet.n_train :]\n            self.label = modelnet.f[\"train/label\"][modelnet.n_train :]\n        elif mode == \"test\":\n            self.data = modelnet.f[\"test/data\"].value\n            self.label = modelnet.f[\"test/label\"].value\n\n    def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2)):\n        xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[3])\n        xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[3])\n        x = np.add(np.multiply(x, xyz1), xyz2).astype(\"float32\")\n        return x\n\n    def __len__(self):\n        return self.data.shape[0]\n\n    def __getitem__(self, i):\n        x = self.data[i][: self.num_points]\n        y = self.label[i]\n        if self.mode == \"train\":\n            x = self.translate(x)\n            np.random.shuffle(x)\n        return x, y\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pct/ModelNetDataLoader.py",
    "content": "import os\nimport warnings\n\nimport numpy as np\nfrom torch.utils.data import Dataset\n\nwarnings.filterwarnings(\"ignore\")\n\n\ndef pc_normalize(pc):\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n    pc = pc / m\n    return pc\n\n\ndef farthest_point_sample(point, npoint):\n    \"\"\"\n    Farthest point sampler works as follows:\n    1. Initialize the sample set S with a random point\n    2. Pick point P not in S, which maximizes the distance d(P, S)\n    3. Repeat step 2 until |S| = npoint\n\n    Input:\n        xyz: pointcloud data, [N, D]\n        npoint: number of samples\n    Return:\n        centroids: sampled pointcloud index, [npoint, D]\n    \"\"\"\n    N, D = point.shape\n    xyz = point[:, :3]\n    centroids = np.zeros((npoint,))\n    distance = np.ones((N,)) * 1e10\n    farthest = np.random.randint(0, N)\n    for i in range(npoint):\n        centroids[i] = farthest\n        centroid = xyz[farthest, :]\n        dist = np.sum((xyz - centroid) ** 2, -1)\n        mask = dist < distance\n        distance[mask] = dist[mask]\n        farthest = np.argmax(distance, -1)\n    point = point[centroids.astype(np.int32)]\n    return point\n\n\nclass ModelNetDataLoader(Dataset):\n    def __init__(\n        self,\n        root,\n        npoint=1024,\n        split=\"train\",\n        fps=False,\n        normal_channel=True,\n        cache_size=15000,\n    ):\n        \"\"\"\n        Input:\n            root: the root path to the local data files\n            npoint: number of points from each cloud\n            split: which split of the data, 'train' or 'test'\n            fps: whether to sample points with farthest point sampler\n            normal_channel: whether to use additional channel\n            cache_size: the cache size of in-memory point clouds\n        \"\"\"\n        self.root = root\n        self.npoints = npoint\n        self.fps = fps\n        self.catfile = 
os.path.join(self.root, \"modelnet40_shape_names.txt\")\n\n        self.cat = [line.rstrip() for line in open(self.catfile)]\n        self.classes = dict(zip(self.cat, range(len(self.cat))))\n        self.normal_channel = normal_channel\n\n        shape_ids = {}\n        shape_ids[\"train\"] = [\n            line.rstrip()\n            for line in open(os.path.join(self.root, \"modelnet40_train.txt\"))\n        ]\n        shape_ids[\"test\"] = [\n            line.rstrip()\n            for line in open(os.path.join(self.root, \"modelnet40_test.txt\"))\n        ]\n\n        assert split == \"train\" or split == \"test\"\n        shape_names = [\"_\".join(x.split(\"_\")[0:-1]) for x in shape_ids[split]]\n        # list of (shape_name, shape_txt_file_path) tuple\n        self.datapath = [\n            (\n                shape_names[i],\n                os.path.join(self.root, shape_names[i], shape_ids[split][i])\n                + \".txt\",\n            )\n            for i in range(len(shape_ids[split]))\n        ]\n        print(\"The size of %s data is %d\" % (split, len(self.datapath)))\n\n        self.cache_size = cache_size\n        self.cache = {}\n\n    def __len__(self):\n        return len(self.datapath)\n\n    def _get_item(self, index):\n        if index in self.cache:\n            point_set, cls = self.cache[index]\n        else:\n            fn = self.datapath[index]\n            cls = self.classes[self.datapath[index][0]]\n            cls = np.array([cls]).astype(np.int32)\n            point_set = np.loadtxt(fn[1], delimiter=\",\").astype(np.float32)\n            if self.fps:\n                point_set = farthest_point_sample(point_set, self.npoints)\n            else:\n                point_set = point_set[0 : self.npoints, :]\n\n            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])\n\n            if not self.normal_channel:\n                point_set = point_set[:, 0:3]\n\n            if len(self.cache) < self.cache_size:\n                
self.cache[index] = (point_set, cls)\n\n        return point_set, cls\n\n    def __getitem__(self, index):\n        return self._get_item(index)\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pct/README.md",
    "content": "PCT\n====\n\nThis is a reproduction of the paper: [PCT: Point cloud transformer](http://arxiv.org/abs/2012.09688).\n\n# Performance\n| Task           | Dataset    | Metric   | Score - Paper  | Score - DGL (Adam) | Time(s) - DGL |\n|-----------------|------------|----------|------------------|-------------|-------------------|\n| Classification        | ModelNet40 | Accuracy | 93.2   | 92.1      | 740.0          |\n| Part Segmentation        | ShapeNet   | mIoU     | 86.4            | 85.6       | 390.0         |\n\n+ Time(s) is the average training time per epoch, measured on an EC2 g4dn.12xlarge instance with a Tesla T4 GPU.\n+ We run the code with the preprocessing used in [PointNet++](../pointnet). With the preprocessing described in the paper, we can only reach 84.5 on classification:\n    > During training, a random translation in [−0.2, 0.2], a random anisotropic scaling in [0.67, 1.5] and a random input dropout were applied to augment the input data.\n\n\n# How to Run\n\nFor point cloud classification, run with\n\n```bash\npython train_cls.py\n```\n\nFor point cloud part-segmentation, run with\n\n```bash\npython train_partseg.py\n```\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pct/ShapeNet.py",
    "content": "import json\nimport os\nfrom zipfile import ZipFile\n\nimport dgl\n\nimport numpy as np\nimport tqdm\nfrom dgl.data.utils import download, get_download_dir\nfrom scipy.sparse import csr_matrix\nfrom torch.utils.data import Dataset\n\n\nclass ShapeNet(object):\n    def __init__(self, num_points=2048, normal_channel=True):\n        self.num_points = num_points\n        self.normal_channel = normal_channel\n\n        SHAPENET_DOWNLOAD_URL = \"https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip\"\n        download_path = get_download_dir()\n        data_filename = (\n            \"shapenetcore_partanno_segmentation_benchmark_v0_normal.zip\"\n        )\n        data_path = os.path.join(\n            download_path,\n            \"shapenetcore_partanno_segmentation_benchmark_v0_normal\",\n        )\n        if not os.path.exists(data_path):\n            local_path = os.path.join(download_path, data_filename)\n            if not os.path.exists(local_path):\n                download(SHAPENET_DOWNLOAD_URL, local_path, verify_ssl=False)\n            with ZipFile(local_path) as z:\n                z.extractall(path=download_path)\n\n        synset_file = \"synsetoffset2category.txt\"\n        with open(os.path.join(data_path, synset_file)) as f:\n            synset = [t.split(\"\\n\")[0].split(\"\\t\") for t in f.readlines()]\n        self.synset_dict = {}\n        for syn in synset:\n            self.synset_dict[syn[1]] = syn[0]\n        self.seg_classes = {\n            \"Airplane\": [0, 1, 2, 3],\n            \"Bag\": [4, 5],\n            \"Cap\": [6, 7],\n            \"Car\": [8, 9, 10, 11],\n            \"Chair\": [12, 13, 14, 15],\n            \"Earphone\": [16, 17, 18],\n            \"Guitar\": [19, 20, 21],\n            \"Knife\": [22, 23],\n            \"Lamp\": [24, 25, 26, 27],\n            \"Laptop\": [28, 29],\n            \"Motorbike\": [30, 31, 32, 33, 34, 35],\n            \"Mug\": [36, 37],\n        
    \"Pistol\": [38, 39, 40],\n            \"Rocket\": [41, 42, 43],\n            \"Skateboard\": [44, 45, 46],\n            \"Table\": [47, 48, 49],\n        }\n\n        train_split_json = \"shuffled_train_file_list.json\"\n        val_split_json = \"shuffled_val_file_list.json\"\n        test_split_json = \"shuffled_test_file_list.json\"\n        split_path = os.path.join(data_path, \"train_test_split\")\n        with open(os.path.join(split_path, train_split_json)) as f:\n            tmp = f.read()\n            self.train_file_list = [\n                os.path.join(data_path, t.replace(\"shape_data/\", \"\") + \".txt\")\n                for t in json.loads(tmp)\n            ]\n        with open(os.path.join(split_path, val_split_json)) as f:\n            tmp = f.read()\n            self.val_file_list = [\n                os.path.join(data_path, t.replace(\"shape_data/\", \"\") + \".txt\")\n                for t in json.loads(tmp)\n            ]\n        with open(os.path.join(split_path, test_split_json)) as f:\n            tmp = f.read()\n            self.test_file_list = [\n                os.path.join(data_path, t.replace(\"shape_data/\", \"\") + \".txt\")\n                for t in json.loads(tmp)\n            ]\n\n    def train(self):\n        return ShapeNetDataset(\n            self, \"train\", self.num_points, self.normal_channel\n        )\n\n    def valid(self):\n        return ShapeNetDataset(\n            self, \"valid\", self.num_points, self.normal_channel\n        )\n\n    def trainval(self):\n        return ShapeNetDataset(\n            self, \"trainval\", self.num_points, self.normal_channel\n        )\n\n    def test(self):\n        return ShapeNetDataset(\n            self, \"test\", self.num_points, self.normal_channel\n        )\n\n\nclass ShapeNetDataset(Dataset):\n    def __init__(self, shapenet, mode, num_points, normal_channel=True):\n        super(ShapeNetDataset, self).__init__()\n        self.mode = mode\n        self.num_points = 
num_points\n        if not normal_channel:\n            self.dim = 3\n        else:\n            self.dim = 6\n\n        if mode == \"train\":\n            self.file_list = shapenet.train_file_list\n        elif mode == \"valid\":\n            self.file_list = shapenet.val_file_list\n        elif mode == \"test\":\n            self.file_list = shapenet.test_file_list\n        elif mode == \"trainval\":\n            self.file_list = shapenet.train_file_list + shapenet.val_file_list\n        else:\n            raise \"Not supported `mode`\"\n\n        data_list = []\n        label_list = []\n        category_list = []\n        print(\"Loading data from split \" + self.mode)\n        for fn in tqdm.tqdm(self.file_list, ascii=True):\n            with open(fn) as f:\n                data = np.array(\n                    [t.split(\"\\n\")[0].split(\" \") for t in f.readlines()]\n                ).astype(np.float)\n            data_list.append(data[:, 0 : self.dim])\n            label_list.append(data[:, 6].astype(int))\n            category_list.append(shapenet.synset_dict[fn.split(\"/\")[-2]])\n        self.data = data_list\n        self.label = label_list\n        self.category = category_list\n\n    def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2), size=3):\n        xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])\n        xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])\n        x = np.add(np.multiply(x, xyz1), xyz2).astype(\"float32\")\n        return x\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, i):\n        inds = np.random.choice(\n            self.data[i].shape[0], self.num_points, replace=True\n        )\n        x = self.data[i][inds, : self.dim]\n        y = self.label[i][inds]\n        cat = self.category[i]\n        if self.mode == \"train\":\n            x = self.translate(x, size=self.dim)\n        x = x.astype(np.float)\n        y = y.astype(int)\n        return 
x, y, cat\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pct/helper.py",
    "content": "import dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.geometry import farthest_point_sampler\n\n\"\"\"\nPart of the code are adapted from\nhttps://github.com/yanx27/Pointnet_Pointnet2_pytorch\n\"\"\"\n\n\ndef square_distance(src, dst):\n    \"\"\"\n    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n    \"\"\"\n    B, N, _ = src.shape\n    _, M, _ = dst.shape\n    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))\n    dist += torch.sum(src**2, -1).view(B, N, 1)\n    dist += torch.sum(dst**2, -1).view(B, 1, M)\n    return dist\n\n\ndef index_points(points, idx):\n    \"\"\"\n    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n    \"\"\"\n    device = points.device\n    B = points.shape[0]\n    view_shape = list(idx.shape)\n    view_shape[1:] = [1] * (len(view_shape) - 1)\n    repeat_shape = list(idx.shape)\n    repeat_shape[0] = 1\n    batch_indices = (\n        torch.arange(B, dtype=torch.long)\n        .to(device)\n        .view(view_shape)\n        .repeat(repeat_shape)\n    )\n    new_points = points[batch_indices, idx, :]\n    return new_points\n\n\nclass KNearNeighbors(nn.Module):\n    \"\"\"\n    Find the k nearest neighbors\n    \"\"\"\n\n    def __init__(self, n_neighbor):\n        super(KNearNeighbors, self).__init__()\n        self.n_neighbor = n_neighbor\n\n    def forward(self, pos, centroids):\n        \"\"\"\n        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n        \"\"\"\n        center_pos = index_points(pos, centroids)\n        sqrdists = square_distance(center_pos, pos)\n        group_idx = sqrdists.argsort(dim=-1)[:, :, : self.n_neighbor]\n        return group_idx\n\n\nclass KNNGraphBuilder(nn.Module):\n    \"\"\"\n    Build NN graph\n    \"\"\"\n\n    def __init__(self, n_neighbor):\n        super(KNNGraphBuilder, self).__init__()\n        self.n_neighbor = n_neighbor\n        self.knn = KNearNeighbors(n_neighbor)\n\n    def 
forward(self, pos, centroids, feat=None):\n        dev = pos.device\n        group_idx = self.knn(pos, centroids)\n        B, N, _ = pos.shape\n        glist = []\n        for i in range(B):\n            center = torch.zeros((N)).to(dev)\n            center[centroids[i]] = 1\n            src = group_idx[i].contiguous().view(-1)\n            dst = (\n                centroids[i]\n                .view(-1, 1)\n                .repeat(\n                    1, min(self.n_neighbor, src.shape[0] // centroids.shape[1])\n                )\n                .view(-1)\n            )\n\n            unified = torch.cat([src, dst])\n            uniq, inv_idx = torch.unique(unified, return_inverse=True)\n            src_idx = inv_idx[: src.shape[0]]\n            dst_idx = inv_idx[src.shape[0] :]\n\n            g = dgl.graph((src_idx, dst_idx))\n            g.ndata[\"pos\"] = pos[i][uniq]\n            g.ndata[\"center\"] = center[uniq]\n            if feat is not None:\n                g.ndata[\"feat\"] = feat[i][uniq]\n            glist.append(g)\n        bg = dgl.batch(glist)\n        return bg\n\n\nclass KNNMessage(nn.Module):\n    \"\"\"\n    Compute the input feature from neighbors\n    \"\"\"\n\n    def __init__(self, n_neighbor):\n        super(KNNMessage, self).__init__()\n        self.n_neighbor = n_neighbor\n\n    def forward(self, edges):\n        norm = edges.src[\"feat\"] - edges.dst[\"feat\"]\n        if \"feat\" in edges.src:\n            res = torch.cat([norm, edges.src[\"feat\"]], 1)\n        else:\n            res = norm\n        return {\"agg_feat\": res}\n\n\nclass KNNConv(nn.Module):\n    \"\"\"\n    Feature aggregation\n    \"\"\"\n\n    def __init__(self, sizes):\n        super(KNNConv, self).__init__()\n        self.conv = nn.ModuleList()\n        self.bn = nn.ModuleList()\n        for i in range(1, len(sizes)):\n            self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))\n            self.bn.append(nn.BatchNorm2d(sizes[i]))\n\n    def forward(self, 
nodes):\n        shape = nodes.mailbox[\"agg_feat\"].shape\n        h = (\n            nodes.mailbox[\"agg_feat\"]\n            .view(shape[0], -1, shape[1], shape[2])\n            .permute(0, 3, 2, 1)\n        )\n        for conv, bn in zip(self.conv, self.bn):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n        h = torch.max(h, 2)[0]\n        feat_dim = h.shape[1]\n        h = h.permute(0, 2, 1).reshape(-1, feat_dim)\n        return {\"new_feat\": h}\n\n\nclass TransitionDown(nn.Module):\n    \"\"\"\n    The Transition Down Module\n    \"\"\"\n\n    def __init__(self, in_channels, out_channels, n_neighbor=64):\n        super(TransitionDown, self).__init__()\n        self.frnn_graph = KNNGraphBuilder(n_neighbor)\n        self.message = KNNMessage(n_neighbor)\n        self.conv = KNNConv([in_channels, out_channels, out_channels])\n\n    def forward(self, pos, feat, n_point):\n        batch_size = pos.shape[0]\n        centroids = farthest_point_sampler(pos, n_point)\n        g = self.frnn_graph(pos, centroids, feat)\n        g.update_all(self.message, self.conv)\n\n        mask = g.ndata[\"center\"] == 1\n        pos_dim = g.ndata[\"pos\"].shape[-1]\n        feat_dim = g.ndata[\"new_feat\"].shape[-1]\n        pos_res = g.ndata[\"pos\"][mask].view(batch_size, -1, pos_dim)\n        feat_res = g.ndata[\"new_feat\"][mask].view(batch_size, -1, feat_dim)\n        return pos_res, feat_res\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pct/pct.py",
    "content": "import torch\nfrom helper import TransitionDown\nfrom torch import nn\n\n\"\"\"\nPart of the code are adapted from\nhttps://github.com/MenghaoGuo/PCT\n\"\"\"\n\n\nclass PCTPositionEmbedding(nn.Module):\n    def __init__(self, channels=256):\n        super(PCTPositionEmbedding, self).__init__()\n        self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)\n        self.conv_pos = nn.Conv1d(3, channels, kernel_size=1, bias=False)\n\n        self.bn1 = nn.BatchNorm1d(channels)\n\n        self.sa1 = SALayerCLS(channels)\n        self.sa2 = SALayerCLS(channels)\n        self.sa3 = SALayerCLS(channels)\n        self.sa4 = SALayerCLS(channels)\n\n        self.relu = nn.ReLU()\n\n    def forward(self, x, xyz):\n        # add position embedding\n        xyz = xyz.permute(0, 2, 1)\n        xyz = self.conv_pos(xyz)\n\n        x = self.relu(self.bn1(self.conv1(x)))  # B, D, N\n\n        x1 = self.sa1(x, xyz)\n        x2 = self.sa2(x1, xyz)\n        x3 = self.sa3(x2, xyz)\n        x4 = self.sa4(x3, xyz)\n\n        x = torch.cat((x1, x2, x3, x4), dim=1)\n\n        return x\n\n\nclass SALayerCLS(nn.Module):\n    def __init__(self, channels):\n        super(SALayerCLS, self).__init__()\n        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)\n        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)\n        self.q_conv.weight = self.k_conv.weight\n        self.v_conv = nn.Conv1d(channels, channels, 1)\n        self.trans_conv = nn.Conv1d(channels, channels, 1)\n        self.after_norm = nn.BatchNorm1d(channels)\n        self.act = nn.ReLU()\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, xyz):\n        x = x + xyz\n        x_q = self.q_conv(x).permute(0, 2, 1)  # b, n, c\n        x_k = self.k_conv(x)  # b, c, n\n        x_v = self.v_conv(x)\n        energy = torch.bmm(x_q, x_k)  # b, n, n\n        attention = self.softmax(energy)\n        attention = attention / (1e-9 + attention.sum(dim=1, 
keepdims=True))\n        x_r = torch.bmm(x_v, attention)  # b, c, n\n        x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))\n        x = x + x_r\n        return x\n\n\nclass SALayerSeg(nn.Module):\n    def __init__(self, channels):\n        super(SALayerSeg, self).__init__()\n        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)\n        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)\n        self.q_conv.weight = self.k_conv.weight\n        self.v_conv = nn.Conv1d(channels, channels, 1)\n        self.trans_conv = nn.Conv1d(channels, channels, 1)\n        self.after_norm = nn.BatchNorm1d(channels)\n        self.act = nn.ReLU()\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x):\n        x_q = self.q_conv(x).permute(0, 2, 1)  # b, n, c\n        x_k = self.k_conv(x)  # b, c, n\n        x_v = self.v_conv(x)\n        energy = torch.bmm(x_q, x_k)  # b, n, n\n        attention = self.softmax(energy)\n        attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))\n        x_r = torch.bmm(x_v, attention)  # b, c, n\n        x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))\n        x = x + x_r\n        return x\n\n\nclass PointTransformerCLS(nn.Module):\n    def __init__(self, output_channels=40):\n        super(PointTransformerCLS, self).__init__()\n        self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)\n        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm1d(64)\n        self.bn2 = nn.BatchNorm1d(64)\n        self.g_op0 = TransitionDown(\n            in_channels=128, out_channels=128, n_neighbor=32\n        )\n        self.g_op1 = TransitionDown(\n            in_channels=256, out_channels=256, n_neighbor=32\n        )\n\n        self.pt_last = PCTPositionEmbedding()\n\n        self.relu = nn.ReLU()\n        self.conv_fuse = nn.Sequential(\n            nn.Conv1d(1280, 1024, kernel_size=1, bias=False),\n            
nn.BatchNorm1d(1024),\n            nn.LeakyReLU(negative_slope=0.2),\n        )\n\n        self.linear1 = nn.Linear(1024, 512, bias=False)\n        self.bn6 = nn.BatchNorm1d(512)\n        self.dp1 = nn.Dropout(p=0.5)\n        self.linear2 = nn.Linear(512, 256)\n        self.bn7 = nn.BatchNorm1d(256)\n        self.dp2 = nn.Dropout(p=0.5)\n        self.linear3 = nn.Linear(256, output_channels)\n\n    def forward(self, x):\n        xyz = x[..., :3]\n        x = x[..., 3:].permute(0, 2, 1)\n        batch_size, _, _ = x.size()\n        x = self.relu(self.bn1(self.conv1(x)))  # B, D, N\n        x = self.relu(self.bn2(self.conv2(x)))  # B, D, N\n        x = x.permute(0, 2, 1)\n\n        new_xyz, feature_0 = self.g_op0(xyz, x, n_point=512)\n        new_xyz, feature_1 = self.g_op1(new_xyz, feature_0, n_point=256)\n\n        # add position embedding on each layer\n        x = self.pt_last(feature_1, new_xyz)\n\n        x = torch.cat([x, feature_1], dim=1)\n        x = self.conv_fuse(x)\n        x, _ = torch.max(x, 2)\n        x = x.view(batch_size, -1)\n\n        x = self.relu(self.bn6(self.linear1(x)))\n        x = self.dp1(x)\n        x = self.relu(self.bn7(self.linear2(x)))\n        x = self.dp2(x)\n        x = self.linear3(x)\n\n        return x\n\n\nclass PointTransformerSeg(nn.Module):\n    def __init__(self, part_num=50):\n        super(PointTransformerSeg, self).__init__()\n        self.part_num = part_num\n        self.conv1 = nn.Conv1d(3, 128, kernel_size=1, bias=False)\n        self.conv2 = nn.Conv1d(128, 128, kernel_size=1, bias=False)\n\n        self.bn1 = nn.BatchNorm1d(128)\n        self.bn2 = nn.BatchNorm1d(128)\n\n        self.sa1 = SALayerSeg(128)\n        self.sa2 = SALayerSeg(128)\n        self.sa3 = SALayerSeg(128)\n        self.sa4 = SALayerSeg(128)\n\n        self.conv_fuse = nn.Sequential(\n            nn.Conv1d(512, 1024, kernel_size=1, bias=False),\n            nn.BatchNorm1d(1024),\n            nn.LeakyReLU(negative_slope=0.2),\n        )\n\n       
 self.label_conv = nn.Sequential(\n            nn.Conv1d(16, 64, kernel_size=1, bias=False),\n            nn.BatchNorm1d(64),\n            nn.LeakyReLU(negative_slope=0.2),\n        )\n\n        self.convs1 = nn.Conv1d(1024 * 3 + 64, 512, 1)\n        self.dp1 = nn.Dropout(0.5)\n        self.convs2 = nn.Conv1d(512, 256, 1)\n        self.convs3 = nn.Conv1d(256, self.part_num, 1)\n        self.bns1 = nn.BatchNorm1d(512)\n        self.bns2 = nn.BatchNorm1d(256)\n\n        self.relu = nn.ReLU()\n\n    def forward(self, x, cls_label):\n        x = x.permute(0, 2, 1)\n        batch_size, _, N = x.size()\n        x = self.relu(self.bn1(self.conv1(x)))  # B, D, N\n        x = self.relu(self.bn2(self.conv2(x)))\n        x1 = self.sa1(x)\n        x2 = self.sa2(x1)\n        x3 = self.sa3(x2)\n        x4 = self.sa4(x3)\n        x = torch.cat((x1, x2, x3, x4), dim=1)\n        x = self.conv_fuse(x)\n        x_max, _ = torch.max(x, 2)\n        x_avg = torch.mean(x, 2)\n        x_max_feature = x_max.view(batch_size, -1).unsqueeze(-1).repeat(1, 1, N)\n        x_avg_feature = x_avg.view(batch_size, -1).unsqueeze(-1).repeat(1, 1, N)\n        cls_label_feature = self.label_conv(cls_label).repeat(1, 1, N)\n        x_global_feature = torch.cat(\n            (x_max_feature, x_avg_feature, cls_label_feature), 1\n        )\n        x = torch.cat((x, x_global_feature), 1)\n        x = self.relu(self.bns1(self.convs1(x)))\n        x = self.dp1(x)\n        x = self.relu(self.bns2(self.convs2(x)))\n        x = self.convs3(x)\n        return x\n\n\nclass PartSegLoss(nn.Module):\n    def __init__(self, eps=0.2):\n        super(PartSegLoss, self).__init__()\n        self.eps = eps\n        self.loss = nn.CrossEntropyLoss()\n\n    def forward(self, logits, y):\n        num_classes = logits.shape[1]\n        logits = logits.permute(0, 2, 1).contiguous().view(-1, num_classes)\n        loss = self.loss(logits, y)\n        return loss\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pct/provider.py",
    "content": "\"\"\"\nAdapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py\n\"\"\"\nimport numpy as np\n\n\ndef normalize_data(batch_data):\n    \"\"\"Normalize the batch data, use coordinates of the block centered at origin,\n    Input:\n        BxNxC array\n    Output:\n        BxNxC array\n    \"\"\"\n    B, N, C = batch_data.shape\n    normal_data = np.zeros((B, N, C))\n    for b in range(B):\n        pc = batch_data[b]\n        centroid = np.mean(pc, axis=0)\n        pc = pc - centroid\n        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n        pc = pc / m\n        normal_data[b] = pc\n    return normal_data\n\n\ndef shuffle_data(data, labels):\n    \"\"\"Shuffle data and labels.\n    Input:\n      data: B,N,... numpy array\n      label: B,... numpy array\n    Return:\n      shuffled data, label and shuffle indices\n    \"\"\"\n    idx = np.arange(len(labels))\n    np.random.shuffle(idx)\n    return data[idx, ...], labels[idx], idx\n\n\ndef shuffle_points(batch_data):\n    \"\"\"Shuffle orders of points in each point cloud -- changes FPS behavior.\n    Use the same shuffling idx for the entire batch.\n    Input:\n        BxNxC array\n    Output:\n        BxNxC array\n    \"\"\"\n    idx = np.arange(batch_data.shape[1])\n    np.random.shuffle(idx)\n    return batch_data[:, idx, :]\n\n\ndef rotate_point_cloud(batch_data):\n    \"\"\"Randomly rotate the point clouds to augument the dataset\n    rotation is per shape based along up direction\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n  
      )\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_point_cloud_z(batch_data):\n    \"\"\"Randomly rotate the point clouds to augument the dataset\n    rotation is per shape based along up direction\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]\n        )\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_point_cloud_with_normal(batch_xyz_normal):\n    \"\"\"Randomly rotate XYZ, normal point cloud.\n    Input:\n        batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal\n    Output:\n        B,N,6, rotated XYZ, normal point cloud\n    \"\"\"\n    for k in range(batch_xyz_normal.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n        )\n        shape_pc = batch_xyz_normal[k, :, 0:3]\n        shape_normal = batch_xyz_normal[k, :, 3:6]\n        batch_xyz_normal[k, :, 0:3] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n        batch_xyz_normal[k, :, 3:6] = np.dot(\n            shape_normal.reshape((-1, 3)), rotation_matrix\n        )\n    return batch_xyz_normal\n\n\ndef 
rotate_perturbation_point_cloud_with_normal(\n    batch_data, angle_sigma=0.06, angle_clip=0.18\n):\n    \"\"\"Randomly perturb the point clouds by small rotations\n    Input:\n      BxNx6 array, original batch of point clouds and point normals\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        angles = np.clip(\n            angle_sigma * np.random.randn(3), -angle_clip, angle_clip\n        )\n        Rx = np.array(\n            [\n                [1, 0, 0],\n                [0, np.cos(angles[0]), -np.sin(angles[0])],\n                [0, np.sin(angles[0]), np.cos(angles[0])],\n            ]\n        )\n        Ry = np.array(\n            [\n                [np.cos(angles[1]), 0, np.sin(angles[1])],\n                [0, 1, 0],\n                [-np.sin(angles[1]), 0, np.cos(angles[1])],\n            ]\n        )\n        Rz = np.array(\n            [\n                [np.cos(angles[2]), -np.sin(angles[2]), 0],\n                [np.sin(angles[2]), np.cos(angles[2]), 0],\n                [0, 0, 1],\n            ]\n        )\n        R = np.dot(Rz, np.dot(Ry, Rx))\n        shape_pc = batch_data[k, :, 0:3]\n        shape_normal = batch_data[k, :, 3:6]\n        rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)\n        rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)\n    return rotated_data\n\n\ndef rotate_point_cloud_by_angle(batch_data, rotation_angle):\n    \"\"\"Rotate the point cloud along up direction with certain angle.\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        # rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = 
np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n        )\n        shape_pc = batch_data[k, :, 0:3]\n        rotated_data[k, :, 0:3] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):\n    \"\"\"Rotate the point cloud along up direction with certain angle.\n    Input:\n      BxNx6 array, original batch of point clouds with normal\n      scalar, angle of rotation\n    Return:\n      BxNx6 array, rotated batch of point clouds iwth normal\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        # rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n        )\n        shape_pc = batch_data[k, :, 0:3]\n        shape_normal = batch_data[k, :, 3:6]\n        rotated_data[k, :, 0:3] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n        rotated_data[k, :, 3:6] = np.dot(\n            shape_normal.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_perturbation_point_cloud(\n    batch_data, angle_sigma=0.06, angle_clip=0.18\n):\n    \"\"\"Randomly perturb the point clouds by small rotations\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        angles = np.clip(\n            angle_sigma * np.random.randn(3), -angle_clip, angle_clip\n        )\n        Rx = np.array(\n            [\n                [1, 0, 0],\n                [0, np.cos(angles[0]), 
-np.sin(angles[0])],\n                [0, np.sin(angles[0]), np.cos(angles[0])],\n            ]\n        )\n        Ry = np.array(\n            [\n                [np.cos(angles[1]), 0, np.sin(angles[1])],\n                [0, 1, 0],\n                [-np.sin(angles[1]), 0, np.cos(angles[1])],\n            ]\n        )\n        Rz = np.array(\n            [\n                [np.cos(angles[2]), -np.sin(angles[2]), 0],\n                [np.sin(angles[2]), np.cos(angles[2]), 0],\n                [0, 0, 1],\n            ]\n        )\n        R = np.dot(Rz, np.dot(Ry, Rx))\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)\n    return rotated_data\n\n\ndef jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):\n    \"\"\"Randomly jitter points. jittering is per point.\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, jittered batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    assert clip > 0\n    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)\n    jittered_data += batch_data\n    return jittered_data\n\n\ndef shift_point_cloud(batch_data, shift_range=0.1):\n    \"\"\"Randomly shift point cloud. Shift is per point cloud.\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, shifted batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))\n    for batch_index in range(B):\n        batch_data[batch_index, :, :] += shifts[batch_index, :]\n    return batch_data\n\n\ndef random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):\n    \"\"\"Randomly scale the point cloud. 
Scale is per point cloud.\n    Input:\n        BxNx3 array, original batch of point clouds\n    Return:\n        BxNx3 array, scaled batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    scales = np.random.uniform(scale_low, scale_high, B)\n    for batch_index in range(B):\n        batch_data[batch_index, :, :] *= scales[batch_index]\n    return batch_data\n\n\ndef random_point_dropout(batch_pc, max_dropout_ratio=0.875):\n    \"\"\"batch_pc: BxNx3\"\"\"\n    for b in range(batch_pc.shape[0]):\n        dropout_ratio = np.random.random() * max_dropout_ratio  # 0~0.875\n        drop_idx = np.where(\n            np.random.random((batch_pc.shape[1])) <= dropout_ratio\n        )[0]\n        if len(drop_idx) > 0:\n            dropout_ratio = (\n                np.random.random() * max_dropout_ratio\n            )  # 0~0.875 # not need\n            batch_pc[b, drop_idx, :] = batch_pc[\n                b, 0, :\n            ]  # set to the first point\n    return batch_pc\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pct/train_cls.py",
    "content": "import argparse\nimport os\nimport time\nfrom functools import partial\n\nimport provider\nimport torch\nimport torch.nn as nn\nimport tqdm\n\nfrom dgl.data.utils import download, get_download_dir\nfrom ModelNetDataLoader import ModelNetDataLoader\nfrom pct import PointTransformerCLS\nfrom torch.utils.data import DataLoader\n\ntorch.backends.cudnn.enabled = False\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--dataset-path\", type=str, default=\"\")\nparser.add_argument(\"--load-model-path\", type=str, default=\"\")\nparser.add_argument(\"--save-model-path\", type=str, default=\"\")\nparser.add_argument(\"--num-epochs\", type=int, default=250)\nparser.add_argument(\"--num-workers\", type=int, default=8)\nparser.add_argument(\"--batch-size\", type=int, default=32)\nargs = parser.parse_args()\n\nnum_workers = args.num_workers\nbatch_size = args.batch_size\n\ndata_filename = \"modelnet40_normal_resampled.zip\"\ndownload_path = os.path.join(get_download_dir(), data_filename)\nlocal_path = args.dataset_path or os.path.join(\n    get_download_dir(), \"modelnet40_normal_resampled\"\n)\n\nif not os.path.exists(local_path):\n    download(\n        \"https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip\",\n        download_path,\n        verify_ssl=False,\n    )\n    from zipfile import ZipFile\n\n    with ZipFile(download_path) as z:\n        z.extractall(path=get_download_dir())\n\nCustomDataLoader = partial(\n    DataLoader,\n    num_workers=num_workers,\n    batch_size=batch_size,\n    shuffle=True,\n    drop_last=True,\n)\n\n\ndef train(net, opt, scheduler, train_loader, dev):\n    net.train()\n\n    total_loss = 0\n    num_batches = 0\n    total_correct = 0\n    count = 0\n    loss_f = nn.CrossEntropyLoss()\n    start_time = time.time()\n    with tqdm.tqdm(train_loader, ascii=True) as tq:\n        for data, label in tq:\n            data = data.data.numpy()\n            data = provider.random_point_dropout(data)\n        
    data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])\n            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])\n            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])\n            data = torch.tensor(data)\n            label = label[:, 0]\n\n            num_examples = label.shape[0]\n            data, label = data.to(dev), label.to(dev).squeeze().long()\n            opt.zero_grad()\n            logits = net(data)\n            loss = loss_f(logits, label)\n            loss.backward()\n            opt.step()\n\n            _, preds = logits.max(1)\n\n            num_batches += 1\n            count += num_examples\n            loss = loss.item()\n            correct = (preds == label).sum().item()\n            total_loss += loss\n            total_correct += correct\n\n            tq.set_postfix(\n                {\n                    \"AvgLoss\": \"%.5f\" % (total_loss / num_batches),\n                    \"AvgAcc\": \"%.5f\" % (total_correct / count),\n                }\n            )\n    print(\n        \"[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s\".format(\n            total_loss / num_batches,\n            total_correct / count,\n            time.time() - start_time,\n        )\n    )\n    scheduler.step()\n\n\ndef evaluate(net, test_loader, dev):\n    net.eval()\n\n    total_correct = 0\n    count = 0\n    start_time = time.time()\n    with torch.no_grad():\n        with tqdm.tqdm(test_loader, ascii=True) as tq:\n            for data, label in tq:\n                label = label[:, 0]\n                num_examples = label.shape[0]\n                data, label = data.to(dev), label.to(dev).squeeze().long()\n                logits = net(data)\n                _, preds = logits.max(1)\n\n                correct = (preds == label).sum().item()\n                total_correct += correct\n                count += num_examples\n\n                tq.set_postfix({\"AvgAcc\": \"%.5f\" % (total_correct / 
count)})\n    print(\n        \"[Test]  AvgAcc: {:.5}, Time: {:.5}s\".format(\n            total_correct / count, time.time() - start_time\n        )\n    )\n    return total_correct / count\n\n\ndev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nnet = PointTransformerCLS()\n\nnet = net.to(dev)\nif args.load_model_path:\n    net.load_state_dict(\n        torch.load(args.load_model_path, weights_only=False, map_location=dev)\n    )\n\n\nopt = torch.optim.SGD(\n    net.parameters(), lr=0.01, weight_decay=1e-4, momentum=0.9\n)\n\nscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n    opt, T_max=args.num_epochs\n)\n\ntrain_dataset = ModelNetDataLoader(local_path, 1024, split=\"train\")\ntest_dataset = ModelNetDataLoader(local_path, 1024, split=\"test\")\ntrain_loader = torch.utils.data.DataLoader(\n    train_dataset,\n    batch_size=batch_size,\n    shuffle=True,\n    num_workers=num_workers,\n    drop_last=True,\n)\ntest_loader = torch.utils.data.DataLoader(\n    test_dataset,\n    batch_size=batch_size,\n    shuffle=False,\n    num_workers=num_workers,\n    drop_last=True,\n)\n\nbest_test_acc = 0\n\nfor epoch in range(args.num_epochs):\n    print(\"Epoch #{}: \".format(epoch))\n    train(net, opt, scheduler, train_loader, dev)\n    if (epoch + 1) % 1 == 0:\n        test_acc = evaluate(net, test_loader, dev)\n        if test_acc > best_test_acc:\n            best_test_acc = test_acc\n            if args.save_model_path:\n                torch.save(net.state_dict(), args.save_model_path)\n        print(\"Current test acc: %.5f (best: %.5f)\" % (test_acc, best_test_acc))\n    print()\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pct/train_partseg.py",
    "content": "import argparse\nimport time\nfrom functools import partial\n\nimport dgl\n\nimport numpy as np\nimport provider\nimport torch\nimport torch.optim as optim\nimport tqdm\nfrom pct import PartSegLoss, PointTransformerSeg\nfrom ShapeNet import ShapeNet\nfrom torch.utils.data import DataLoader\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--dataset-path\", type=str, default=\"\")\nparser.add_argument(\"--load-model-path\", type=str, default=\"\")\nparser.add_argument(\"--save-model-path\", type=str, default=\"\")\nparser.add_argument(\"--num-epochs\", type=int, default=500)\nparser.add_argument(\"--num-workers\", type=int, default=8)\nparser.add_argument(\"--batch-size\", type=int, default=16)\nparser.add_argument(\"--tensorboard\", action=\"store_true\")\nargs = parser.parse_args()\n\nnum_workers = args.num_workers\nbatch_size = args.batch_size\n\n\ndef collate(samples):\n    graphs, cat = map(list, zip(*samples))\n    return dgl.batch(graphs), cat\n\n\nCustomDataLoader = partial(\n    DataLoader,\n    num_workers=num_workers,\n    batch_size=batch_size,\n    shuffle=True,\n    drop_last=True,\n)\n\n\ndef train(net, opt, scheduler, train_loader, dev):\n    category_list = sorted(list(shapenet.seg_classes.keys()))\n    eye_mat = np.eye(16)\n    net.train()\n\n    total_loss = 0\n    num_batches = 0\n    total_correct = 0\n    count = 0\n    start = time.time()\n    with tqdm.tqdm(train_loader, ascii=True) as tq:\n        for data, label, cat in tq:\n            num_examples = data.shape[0]\n            data = data.to(dev, dtype=torch.float)\n            label = label.to(dev, dtype=torch.long).view(-1)\n            opt.zero_grad()\n            cat_ind = [category_list.index(c) for c in cat]\n            # An one-hot encoding for the object category\n            cat_tensor = torch.tensor(eye_mat[cat_ind]).to(\n                dev, dtype=torch.float\n            )\n            cat_tensor = cat_tensor.view(num_examples, 16, 1)\n            
logits = net(data, cat_tensor)\n            loss = L(logits, label)\n            loss.backward()\n            opt.step()\n\n            _, preds = logits.max(1)\n\n            count += num_examples * 2048\n            loss = loss.item()\n            total_loss += loss\n            num_batches += 1\n            correct = (preds.view(-1) == label).sum().item()\n            total_correct += correct\n\n            AvgLoss = total_loss / num_batches\n            AvgAcc = total_correct / count\n\n            tq.set_postfix(\n                {\"AvgLoss\": \"%.5f\" % AvgLoss, \"AvgAcc\": \"%.5f\" % AvgAcc}\n            )\n    scheduler.step()\n    end = time.time()\n    print(\n        \"[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s\".format(\n            total_loss / num_batches, total_correct / count, end - start\n        )\n    )\n    return data, preds, AvgLoss, AvgAcc, end - start\n\n\ndef mIoU(preds, label, cat, cat_miou, seg_classes):\n    for i in range(preds.shape[0]):\n        shape_iou = 0\n        n = len(seg_classes[cat[i]])\n        for cls in seg_classes[cat[i]]:\n            pred_set = set(np.where(preds[i, :] == cls)[0])\n            label_set = set(np.where(label[i, :] == cls)[0])\n            union = len(pred_set.union(label_set))\n            inter = len(pred_set.intersection(label_set))\n            if union == 0:\n                shape_iou += 1\n            else:\n                shape_iou += inter / union\n        shape_iou /= n\n        cat_miou[cat[i]][0] += shape_iou\n        cat_miou[cat[i]][1] += 1\n\n    return cat_miou\n\n\ndef evaluate(net, test_loader, dev, per_cat_verbose=False):\n    category_list = sorted(list(shapenet.seg_classes.keys()))\n    eye_mat = np.eye(16)\n    net.eval()\n\n    cat_miou = {}\n    for k in shapenet.seg_classes.keys():\n        cat_miou[k] = [0, 0]\n    miou = 0\n    count = 0\n    per_cat_miou = 0\n    per_cat_count = 0\n\n    with torch.no_grad():\n        with tqdm.tqdm(test_loader, ascii=True) as tq:\n   
         for data, label, cat in tq:\n                num_examples = data.shape[0]\n                data = data.to(dev, dtype=torch.float)\n                label = label.to(dev, dtype=torch.long)\n                cat_ind = [category_list.index(c) for c in cat]\n                cat_tensor = torch.tensor(eye_mat[cat_ind]).to(\n                    dev, dtype=torch.float\n                )\n                cat_tensor = cat_tensor.view(num_examples, 16, 1)\n                logits = net(data, cat_tensor)\n                _, preds = logits.max(1)\n\n                cat_miou = mIoU(\n                    preds.cpu().numpy(),\n                    label.view(num_examples, -1).cpu().numpy(),\n                    cat,\n                    cat_miou,\n                    shapenet.seg_classes,\n                )\n                for _, v in cat_miou.items():\n                    if v[1] > 0:\n                        miou += v[0]\n                        count += v[1]\n                        per_cat_miou += v[0] / v[1]\n                        per_cat_count += 1\n                tq.set_postfix(\n                    {\n                        \"mIoU\": \"%.5f\" % (miou / count),\n                        \"per Category mIoU\": \"%.5f\"\n                        % (per_cat_miou / per_cat_count),\n                    }\n                )\n    print(\n        \"[Test] mIoU: %.5f, per Category mIoU: %.5f\"\n        % (miou / count, per_cat_miou / per_cat_count)\n    )\n    if per_cat_verbose:\n        print(\"-\" * 60)\n        print(\"Per-Category mIoU:\")\n        for k, v in cat_miou.items():\n            if v[1] > 0:\n                print(\"%s mIoU=%.5f\" % (k, v[0] / v[1]))\n            else:\n                print(\"%s mIoU=%.5f\" % (k, 1))\n        print(\"-\" * 60)\n    return miou / count, per_cat_miou / per_cat_count\n\n\ndev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nnet = PointTransformerSeg()\n\nnet = net.to(dev)\nif args.load_model_path:\n    
net.load_state_dict(\n        torch.load(args.load_model_path, weights_only=False, map_location=dev)\n    )\n\nopt = torch.optim.SGD(\n    net.parameters(), lr=0.01, weight_decay=1e-4, momentum=0.9\n)\n\nscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n    opt, T_max=args.num_epochs\n)\n\nL = PartSegLoss()\n\nshapenet = ShapeNet(2048, normal_channel=False)\n\ntrain_loader = CustomDataLoader(shapenet.trainval())\ntest_loader = CustomDataLoader(shapenet.test())\n\n# Tensorboard\nif args.tensorboard:\n    import torchvision\n    from torch.utils.tensorboard import SummaryWriter\n    from torchvision import datasets, transforms\n\n    writer = SummaryWriter()\n# Select 50 distinct colors for different parts\ncolor_map = torch.tensor(\n    [\n        [47, 79, 79],\n        [139, 69, 19],\n        [112, 128, 144],\n        [85, 107, 47],\n        [139, 0, 0],\n        [128, 128, 0],\n        [72, 61, 139],\n        [0, 128, 0],\n        [188, 143, 143],\n        [60, 179, 113],\n        [205, 133, 63],\n        [0, 139, 139],\n        [70, 130, 180],\n        [205, 92, 92],\n        [154, 205, 50],\n        [0, 0, 139],\n        [50, 205, 50],\n        [250, 250, 250],\n        [218, 165, 32],\n        [139, 0, 139],\n        [10, 10, 10],\n        [176, 48, 96],\n        [72, 209, 204],\n        [153, 50, 204],\n        [255, 69, 0],\n        [255, 145, 0],\n        [0, 0, 205],\n        [255, 255, 0],\n        [0, 255, 0],\n        [233, 150, 122],\n        [220, 20, 60],\n        [0, 191, 255],\n        [160, 32, 240],\n        [192, 192, 192],\n        [173, 255, 47],\n        [218, 112, 214],\n        [216, 191, 216],\n        [255, 127, 80],\n        [255, 0, 255],\n        [100, 149, 237],\n        [128, 128, 128],\n        [221, 160, 221],\n        [144, 238, 144],\n        [123, 104, 238],\n        [255, 160, 122],\n        [175, 238, 238],\n        [238, 130, 238],\n        [127, 255, 212],\n        [255, 218, 185],\n        [255, 105, 180],\n    
]\n)\n# paint each point according to its pred\n\n\ndef paint(batched_points):\n    B, N = batched_points.shape\n    colored = color_map[batched_points].squeeze(2)\n    return colored\n\n\nbest_test_miou = 0\nbest_test_per_cat_miou = 0\n\nfor epoch in range(args.num_epochs):\n    print(\"Epoch #{}: \".format(epoch))\n    data, preds, AvgLoss, AvgAcc, training_time = train(\n        net, opt, scheduler, train_loader, dev\n    )\n    if (epoch + 1) % 5 == 0 or epoch == 0:\n        test_miou, test_per_cat_miou = evaluate(net, test_loader, dev, True)\n        if test_miou > best_test_miou:\n            best_test_miou = test_miou\n            best_test_per_cat_miou = test_per_cat_miou\n            if args.save_model_path:\n                torch.save(net.state_dict(), args.save_model_path)\n        print(\n            \"Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)\"\n            % (\n                test_miou,\n                best_test_miou,\n                test_per_cat_miou,\n                best_test_per_cat_miou,\n            )\n        )\n    # Tensorboard\n    if args.tensorboard:\n        colored = paint(preds)\n        writer.add_mesh(\n            \"data\", vertices=data, colors=colored, global_step=epoch\n        )\n        writer.add_scalar(\n            \"training time for one epoch\", training_time, global_step=epoch\n        )\n        writer.add_scalar(\"AvgLoss\", AvgLoss, global_step=epoch)\n        writer.add_scalar(\"AvgAcc\", AvgAcc, global_step=epoch)\n        if (epoch + 1) % 5 == 0:\n            writer.add_scalar(\"test mIoU\", test_miou, global_step=epoch)\n            writer.add_scalar(\n                \"best test mIoU\", best_test_miou, global_step=epoch\n            )\n    print()\n"
  },
  {
    "path": "examples/pytorch/pointcloud/point_transformer/ModelNetDataLoader.py",
    "content": "import os\nimport warnings\n\nimport numpy as np\nfrom torch.utils.data import Dataset\n\nwarnings.filterwarnings(\"ignore\")\n\n\ndef pc_normalize(pc):\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n    pc = pc / m\n    return pc\n\n\ndef farthest_point_sample(point, npoint):\n    \"\"\"\n    Farthest point sampler works as follows:\n    1. Initialize the sample set S with a random point\n    2. Pick point P not in S, which maximizes the distance d(P, S)\n    3. Repeat step 2 until |S| = npoint\n\n    Input:\n        xyz: pointcloud data, [N, D]\n        npoint: number of samples\n    Return:\n        centroids: sampled pointcloud index, [npoint, D]\n    \"\"\"\n    N, D = point.shape\n    xyz = point[:, :3]\n    centroids = np.zeros((npoint,))\n    distance = np.ones((N,)) * 1e10\n    farthest = np.random.randint(0, N)\n    for i in range(npoint):\n        centroids[i] = farthest\n        centroid = xyz[farthest, :]\n        dist = np.sum((xyz - centroid) ** 2, -1)\n        mask = dist < distance\n        distance[mask] = dist[mask]\n        farthest = np.argmax(distance, -1)\n    point = point[centroids.astype(np.int32)]\n    return point\n\n\nclass ModelNetDataLoader(Dataset):\n    def __init__(\n        self,\n        root,\n        npoint=1024,\n        split=\"train\",\n        fps=False,\n        normal_channel=True,\n        cache_size=15000,\n    ):\n        \"\"\"\n        Input:\n            root: the root path to the local data files\n            npoint: number of points from each cloud\n            split: which split of the data, 'train' or 'test'\n            fps: whether to sample points with farthest point sampler\n            normal_channel: whether to use additional channel\n            cache_size: the cache size of in-memory point clouds\n        \"\"\"\n        self.root = root\n        self.npoints = npoint\n        self.fps = fps\n        self.catfile = 
os.path.join(self.root, \"modelnet40_shape_names.txt\")\n\n        self.cat = [line.rstrip() for line in open(self.catfile)]\n        self.classes = dict(zip(self.cat, range(len(self.cat))))\n        self.normal_channel = normal_channel\n\n        shape_ids = {}\n        shape_ids[\"train\"] = [\n            line.rstrip()\n            for line in open(os.path.join(self.root, \"modelnet40_train.txt\"))\n        ]\n        shape_ids[\"test\"] = [\n            line.rstrip()\n            for line in open(os.path.join(self.root, \"modelnet40_test.txt\"))\n        ]\n\n        assert split == \"train\" or split == \"test\"\n        shape_names = [\"_\".join(x.split(\"_\")[0:-1]) for x in shape_ids[split]]\n        # list of (shape_name, shape_txt_file_path) tuple\n        self.datapath = [\n            (\n                shape_names[i],\n                os.path.join(self.root, shape_names[i], shape_ids[split][i])\n                + \".txt\",\n            )\n            for i in range(len(shape_ids[split]))\n        ]\n        print(\"The size of %s data is %d\" % (split, len(self.datapath)))\n\n        self.cache_size = cache_size\n        self.cache = {}\n\n    def __len__(self):\n        return len(self.datapath)\n\n    def _get_item(self, index):\n        if index in self.cache:\n            point_set, cls = self.cache[index]\n        else:\n            fn = self.datapath[index]\n            cls = self.classes[self.datapath[index][0]]\n            cls = np.array([cls]).astype(np.int32)\n            point_set = np.loadtxt(fn[1], delimiter=\",\").astype(np.float32)\n            if self.fps:\n                point_set = farthest_point_sample(point_set, self.npoints)\n            else:\n                point_set = point_set[0 : self.npoints, :]\n\n            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])\n\n            if not self.normal_channel:\n                point_set = point_set[:, 0:3]\n\n            if len(self.cache) < self.cache_size:\n                
self.cache[index] = (point_set, cls)\n\n        return point_set, cls\n\n    def __getitem__(self, index):\n        return self._get_item(index)\n"
  },
  {
    "path": "examples/pytorch/pointcloud/point_transformer/README.md",
    "content": "Point Transformer\n====\n\n> This model is implemented on August 27, 2021 when there is no official code released.    \nThus we implemented this model based on the code from <https://github.com/qq456cvb/Point-Transformers>.\n\nThis is a reproduction of the paper: [Point Transformer](http://arxiv.org/abs/2012.09164).\n\n# Performance\n| Task           | Dataset    | Metric   | Score - Paper  | Score - DGL (Adam) | Score - DGL (SGD) | Time(s) - DGL |\n|-----------------|------------|----------|------------------|-------------|-------------|-------------------|\n| Classification        | ModelNet40 | Accuracy | 93.7   | 92.0        |  91.5        | 117.0          |\n| Part Segmentation        | ShapeNet   | mIoU     | 86.6            | 84.3        |  85.1        | 260.0         |\n\n+ Time(s) are the average training time per epoch, measured on EC2 p3.8xlarge instance w/ Tesla V100 GPU.\n\n# How to Run\n\nFor point cloud classification, run with\n\n```python\npython train_cls.py --opt [sgd/adam]\n```\n\nFor point cloud part-segmentation, run with\n\n```python\npython train_partseg.py --opt [sgd/adam]\n```\n"
  },
  {
    "path": "examples/pytorch/pointcloud/point_transformer/ShapeNet.py",
    "content": "import json\nimport os\nfrom zipfile import ZipFile\n\nimport dgl\n\nimport numpy as np\nimport tqdm\nfrom dgl.data.utils import download, get_download_dir\nfrom scipy.sparse import csr_matrix\nfrom torch.utils.data import Dataset\n\n\nclass ShapeNet(object):\n    def __init__(self, num_points=2048, normal_channel=True):\n        self.num_points = num_points\n        self.normal_channel = normal_channel\n\n        SHAPENET_DOWNLOAD_URL = \"https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip\"\n        download_path = get_download_dir()\n        data_filename = (\n            \"shapenetcore_partanno_segmentation_benchmark_v0_normal.zip\"\n        )\n        data_path = os.path.join(\n            download_path,\n            \"shapenetcore_partanno_segmentation_benchmark_v0_normal\",\n        )\n        if not os.path.exists(data_path):\n            local_path = os.path.join(download_path, data_filename)\n            if not os.path.exists(local_path):\n                download(SHAPENET_DOWNLOAD_URL, local_path, verify_ssl=False)\n            with ZipFile(local_path) as z:\n                z.extractall(path=download_path)\n\n        synset_file = \"synsetoffset2category.txt\"\n        with open(os.path.join(data_path, synset_file)) as f:\n            synset = [t.split(\"\\n\")[0].split(\"\\t\") for t in f.readlines()]\n        self.synset_dict = {}\n        for syn in synset:\n            self.synset_dict[syn[1]] = syn[0]\n        self.seg_classes = {\n            \"Airplane\": [0, 1, 2, 3],\n            \"Bag\": [4, 5],\n            \"Cap\": [6, 7],\n            \"Car\": [8, 9, 10, 11],\n            \"Chair\": [12, 13, 14, 15],\n            \"Earphone\": [16, 17, 18],\n            \"Guitar\": [19, 20, 21],\n            \"Knife\": [22, 23],\n            \"Lamp\": [24, 25, 26, 27],\n            \"Laptop\": [28, 29],\n            \"Motorbike\": [30, 31, 32, 33, 34, 35],\n            \"Mug\": [36, 37],\n        
    \"Pistol\": [38, 39, 40],\n            \"Rocket\": [41, 42, 43],\n            \"Skateboard\": [44, 45, 46],\n            \"Table\": [47, 48, 49],\n        }\n\n        train_split_json = \"shuffled_train_file_list.json\"\n        val_split_json = \"shuffled_val_file_list.json\"\n        test_split_json = \"shuffled_test_file_list.json\"\n        split_path = os.path.join(data_path, \"train_test_split\")\n        with open(os.path.join(split_path, train_split_json)) as f:\n            tmp = f.read()\n            self.train_file_list = [\n                os.path.join(data_path, t.replace(\"shape_data/\", \"\") + \".txt\")\n                for t in json.loads(tmp)\n            ]\n        with open(os.path.join(split_path, val_split_json)) as f:\n            tmp = f.read()\n            self.val_file_list = [\n                os.path.join(data_path, t.replace(\"shape_data/\", \"\") + \".txt\")\n                for t in json.loads(tmp)\n            ]\n        with open(os.path.join(split_path, test_split_json)) as f:\n            tmp = f.read()\n            self.test_file_list = [\n                os.path.join(data_path, t.replace(\"shape_data/\", \"\") + \".txt\")\n                for t in json.loads(tmp)\n            ]\n\n    def train(self):\n        return ShapeNetDataset(\n            self, \"train\", self.num_points, self.normal_channel\n        )\n\n    def valid(self):\n        return ShapeNetDataset(\n            self, \"valid\", self.num_points, self.normal_channel\n        )\n\n    def trainval(self):\n        return ShapeNetDataset(\n            self, \"trainval\", self.num_points, self.normal_channel\n        )\n\n    def test(self):\n        return ShapeNetDataset(\n            self, \"test\", self.num_points, self.normal_channel\n        )\n\n\nclass ShapeNetDataset(Dataset):\n    def __init__(self, shapenet, mode, num_points, normal_channel=True):\n        super(ShapeNetDataset, self).__init__()\n        self.mode = mode\n        self.num_points = 
num_points\n        if not normal_channel:\n            self.dim = 3\n        else:\n            self.dim = 6\n\n        if mode == \"train\":\n            self.file_list = shapenet.train_file_list\n        elif mode == \"valid\":\n            self.file_list = shapenet.val_file_list\n        elif mode == \"test\":\n            self.file_list = shapenet.test_file_list\n        elif mode == \"trainval\":\n            self.file_list = shapenet.train_file_list + shapenet.val_file_list\n        else:\n            raise \"Not supported `mode`\"\n\n        data_list = []\n        label_list = []\n        category_list = []\n        print(\"Loading data from split \" + self.mode)\n        for fn in tqdm.tqdm(self.file_list, ascii=True):\n            with open(fn) as f:\n                data = np.array(\n                    [t.split(\"\\n\")[0].split(\" \") for t in f.readlines()]\n                ).astype(float)\n            data_list.append(data[:, 0 : self.dim])\n            label_list.append(data[:, 6].astype(int))\n            category_list.append(shapenet.synset_dict[fn.split(\"/\")[-2]])\n        self.data = data_list\n        self.label = label_list\n        self.category = category_list\n\n    def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2), size=3):\n        xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])\n        xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])\n        x = np.add(np.multiply(x, xyz1), xyz2).astype(\"float32\")\n        return x\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, i):\n        inds = np.random.choice(\n            self.data[i].shape[0], self.num_points, replace=True\n        )\n        x = self.data[i][inds, : self.dim]\n        y = self.label[i][inds]\n        cat = self.category[i]\n        if self.mode == \"train\":\n            x = self.translate(x, size=self.dim)\n        x = x.astype(float)\n        y = y.astype(int)\n        return x, y, 
cat\n"
  },
  {
    "path": "examples/pytorch/pointcloud/point_transformer/helper.py",
    "content": "import dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.geometry import farthest_point_sampler\n\n\"\"\"\nPart of the code are adapted from\nhttps://github.com/yanx27/Pointnet_Pointnet2_pytorch\n\"\"\"\n\n\ndef square_distance(src, dst):\n    \"\"\"\n    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n    \"\"\"\n    B, N, _ = src.shape\n    _, M, _ = dst.shape\n    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))\n    dist += torch.sum(src**2, -1).view(B, N, 1)\n    dist += torch.sum(dst**2, -1).view(B, 1, M)\n    return dist\n\n\ndef index_points(points, idx):\n    \"\"\"\n    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n    \"\"\"\n    device = points.device\n    B = points.shape[0]\n    view_shape = list(idx.shape)\n    view_shape[1:] = [1] * (len(view_shape) - 1)\n    repeat_shape = list(idx.shape)\n    repeat_shape[0] = 1\n    batch_indices = (\n        torch.arange(B, dtype=torch.long)\n        .to(device)\n        .view(view_shape)\n        .repeat(repeat_shape)\n    )\n    new_points = points[batch_indices, idx, :]\n    return new_points\n\n\nclass KNearNeighbors(nn.Module):\n    \"\"\"\n    Find the k nearest neighbors\n    \"\"\"\n\n    def __init__(self, n_neighbor):\n        super(KNearNeighbors, self).__init__()\n        self.n_neighbor = n_neighbor\n\n    def forward(self, pos, centroids):\n        \"\"\"\n        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n        \"\"\"\n        center_pos = index_points(pos, centroids)\n        sqrdists = square_distance(center_pos, pos)\n        group_idx = sqrdists.argsort(dim=-1)[:, :, : self.n_neighbor]\n        return group_idx\n\n\nclass KNNGraphBuilder(nn.Module):\n    \"\"\"\n    Build NN graph\n    \"\"\"\n\n    def __init__(self, n_neighbor):\n        super(KNNGraphBuilder, self).__init__()\n        self.n_neighbor = n_neighbor\n        self.knn = KNearNeighbors(n_neighbor)\n\n    def 
forward(self, pos, centroids, feat=None):\n        dev = pos.device\n        group_idx = self.knn(pos, centroids)\n        B, N, _ = pos.shape\n        glist = []\n        for i in range(B):\n            center = torch.zeros((N)).to(dev)\n            center[centroids[i]] = 1\n            src = group_idx[i].contiguous().view(-1)\n            dst = (\n                centroids[i]\n                .view(-1, 1)\n                .repeat(\n                    1, min(self.n_neighbor, src.shape[0] // centroids.shape[1])\n                )\n                .view(-1)\n            )\n\n            unified = torch.cat([src, dst])\n            uniq, inv_idx = torch.unique(unified, return_inverse=True)\n            src_idx = inv_idx[: src.shape[0]]\n            dst_idx = inv_idx[src.shape[0] :]\n\n            g = dgl.graph((src_idx, dst_idx))\n            g.ndata[\"pos\"] = pos[i][uniq]\n            g.ndata[\"center\"] = center[uniq]\n            if feat is not None:\n                g.ndata[\"feat\"] = feat[i][uniq]\n            glist.append(g)\n        bg = dgl.batch(glist)\n        return bg\n\n\nclass RelativePositionMessage(nn.Module):\n    \"\"\"\n    Compute the input feature from neighbors\n    \"\"\"\n\n    def __init__(self, n_neighbor):\n        super(RelativePositionMessage, self).__init__()\n        self.n_neighbor = n_neighbor\n\n    def forward(self, edges):\n        pos = edges.src[\"pos\"] - edges.dst[\"pos\"]\n        if \"feat\" in edges.src:\n            res = torch.cat([pos, edges.src[\"feat\"]], 1)\n        else:\n            res = pos\n        return {\"agg_feat\": res}\n\n\nclass KNNConv(nn.Module):\n    \"\"\"\n    Feature aggregation\n    \"\"\"\n\n    def __init__(self, sizes, batch_size):\n        super(KNNConv, self).__init__()\n        self.batch_size = batch_size\n        self.conv = nn.ModuleList()\n        self.bn = nn.ModuleList()\n        for i in range(1, len(sizes)):\n            self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))\n        
    self.bn.append(nn.BatchNorm2d(sizes[i]))\n\n    def forward(self, nodes):\n        shape = nodes.mailbox[\"agg_feat\"].shape\n        h = (\n            nodes.mailbox[\"agg_feat\"]\n            .view(self.batch_size, -1, shape[1], shape[2])\n            .permute(0, 3, 2, 1)\n        )\n        for conv, bn in zip(self.conv, self.bn):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n        h = torch.max(h, 2)[0]\n        feat_dim = h.shape[1]\n        h = h.permute(0, 2, 1).reshape(-1, feat_dim)\n        return {\"new_feat\": h}\n\n    def group_all(self, pos, feat):\n        \"\"\"\n        Feature aggregation and pooling for the non-sampling layer\n        \"\"\"\n        if feat is not None:\n            h = torch.cat([pos, feat], 2)\n        else:\n            h = pos\n        B, N, D = h.shape\n        _, _, C = pos.shape\n        new_pos = torch.zeros(B, 1, C)\n        h = h.permute(0, 2, 1).view(B, -1, N, 1)\n        for conv, bn in zip(self.conv, self.bn):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n        h = torch.max(h[:, :, :, 0], 2)[0]  # [B,D]\n        return new_pos, h\n\n\nclass TransitionDown(nn.Module):\n    \"\"\"\n    The Transition Down Module\n    \"\"\"\n\n    def __init__(self, n_points, batch_size, mlp_sizes, n_neighbors=64):\n        super(TransitionDown, self).__init__()\n        self.n_points = n_points\n        self.frnn_graph = KNNGraphBuilder(n_neighbors)\n        self.message = RelativePositionMessage(n_neighbors)\n        self.conv = KNNConv(mlp_sizes, batch_size)\n        self.batch_size = batch_size\n\n    def forward(self, pos, feat):\n        centroids = farthest_point_sampler(pos, self.n_points)\n        g = self.frnn_graph(pos, centroids, feat)\n        g.update_all(self.message, self.conv)\n\n        mask = g.ndata[\"center\"] == 1\n        pos_dim = g.ndata[\"pos\"].shape[-1]\n        feat_dim = g.ndata[\"new_feat\"].shape[-1]\n        pos_res = 
g.ndata[\"pos\"][mask].view(self.batch_size, -1, pos_dim)\n        feat_res = g.ndata[\"new_feat\"][mask].view(self.batch_size, -1, feat_dim)\n        return pos_res, feat_res\n\n\nclass FeaturePropagation(nn.Module):\n    \"\"\"\n    The FeaturePropagation Layer\n    \"\"\"\n\n    def __init__(self, input_dims, sizes):\n        super(FeaturePropagation, self).__init__()\n        self.convs = nn.ModuleList()\n        self.bns = nn.ModuleList()\n\n        sizes = [input_dims] + sizes\n        for i in range(1, len(sizes)):\n            self.convs.append(nn.Conv1d(sizes[i - 1], sizes[i], 1))\n            self.bns.append(nn.BatchNorm1d(sizes[i]))\n\n    def forward(self, x1, x2, feat1, feat2):\n        \"\"\"\n        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n            Input:\n                x1: input points position data, [B, N, C]\n                x2: sampled input points position data, [B, S, C]\n                feat1: input points data, [B, N, D]\n                feat2: input points data, [B, S, D]\n            Return:\n                new_feat: upsampled points data, [B, D', N]\n        \"\"\"\n        B, N, C = x1.shape\n        _, S, _ = x2.shape\n\n        if S == 1:\n            interpolated_feat = feat2.repeat(1, N, 1)\n        else:\n            dists = square_distance(x1, x2)\n            dists, idx = dists.sort(dim=-1)\n            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]\n\n            dist_recip = 1.0 / (dists + 1e-8)\n            norm = torch.sum(dist_recip, dim=2, keepdim=True)\n            weight = dist_recip / norm\n            interpolated_feat = torch.sum(\n                index_points(feat2, idx) * weight.view(B, N, 3, 1), dim=2\n            )\n\n        if feat1 is not None:\n            new_feat = torch.cat([feat1, interpolated_feat], dim=-1)\n        else:\n            new_feat = interpolated_feat\n\n        new_feat = new_feat.permute(0, 2, 1)  # [B, D, S]\n        for i, conv in 
enumerate(self.convs):\n            bn = self.bns[i]\n            new_feat = F.relu(bn(conv(new_feat)))\n        return new_feat\n\n\nclass SwapAxes(nn.Module):\n    def __init__(self, dim1=1, dim2=2):\n        super(SwapAxes, self).__init__()\n        self.dim1 = dim1\n        self.dim2 = dim2\n\n    def forward(self, x):\n        return x.transpose(self.dim1, self.dim2)\n\n\nclass TransitionUp(nn.Module):\n    \"\"\"\n    The Transition Up Module\n    \"\"\"\n\n    def __init__(self, dim1, dim2, dim_out):\n        super(TransitionUp, self).__init__()\n        self.fc1 = nn.Sequential(\n            nn.Linear(dim1, dim_out),\n            SwapAxes(),\n            nn.BatchNorm1d(dim_out),  # TODO\n            SwapAxes(),\n            nn.ReLU(),\n        )\n        self.fc2 = nn.Sequential(\n            nn.Linear(dim2, dim_out),\n            SwapAxes(),\n            nn.BatchNorm1d(dim_out),  # TODO\n            SwapAxes(),\n            nn.ReLU(),\n        )\n        self.fp = FeaturePropagation(-1, [])\n\n    def forward(self, pos1, feat1, pos2, feat2):\n        h1 = self.fc1(feat1)\n        h2 = self.fc2(feat2)\n        h1 = self.fp(pos2, pos1, None, h1).transpose(1, 2)\n        return h1 + h2\n"
  },
  {
    "path": "examples/pytorch/pointcloud/point_transformer/point_transformer.py",
    "content": "import numpy as np\nimport torch\nfrom helper import index_points, square_distance, TransitionDown, TransitionUp\nfrom torch import nn\n\n\"\"\"\nPart of the code are adapted from\nhttps://github.com/qq456cvb/Point-Transformers\n\"\"\"\n\n\nclass PointTransformerBlock(nn.Module):\n    def __init__(self, input_dim, n_neighbors, transformer_dim=None):\n        super(PointTransformerBlock, self).__init__()\n        if transformer_dim is None:\n            transformer_dim = input_dim\n        self.fc1 = nn.Linear(input_dim, transformer_dim)\n        self.fc2 = nn.Linear(transformer_dim, input_dim)\n        self.fc_delta = nn.Sequential(\n            nn.Linear(3, transformer_dim),\n            nn.ReLU(),\n            nn.Linear(transformer_dim, transformer_dim),\n        )\n        self.fc_gamma = nn.Sequential(\n            nn.Linear(transformer_dim, transformer_dim),\n            nn.ReLU(),\n            nn.Linear(transformer_dim, transformer_dim),\n        )\n        self.w_qs = nn.Linear(transformer_dim, transformer_dim, bias=False)\n        self.w_ks = nn.Linear(transformer_dim, transformer_dim, bias=False)\n        self.w_vs = nn.Linear(transformer_dim, transformer_dim, bias=False)\n        self.n_neighbors = n_neighbors\n\n    def forward(self, x, pos):\n        dists = square_distance(pos, pos)\n        knn_idx = dists.argsort()[:, :, : self.n_neighbors]  # b x n x k\n        knn_pos = index_points(pos, knn_idx)\n\n        h = self.fc1(x)\n        q, k, v = (\n            self.w_qs(h),\n            index_points(self.w_ks(h), knn_idx),\n            index_points(self.w_vs(h), knn_idx),\n        )\n\n        pos_enc = self.fc_delta(pos[:, :, None] - knn_pos)  # b x n x k x f\n\n        attn = self.fc_gamma(q[:, :, None] - k + pos_enc)\n        attn = torch.softmax(\n            attn / np.sqrt(k.size(-1)), dim=-2\n        )  # b x n x k x f\n\n        res = torch.einsum(\"bmnf,bmnf->bmf\", attn, v + pos_enc)\n        res = self.fc2(res) + x\n        
return res, attn\n\n\nclass PointTransformer(nn.Module):\n    def __init__(\n        self,\n        n_points,\n        batch_size,\n        feature_dim=3,\n        n_blocks=4,\n        downsampling_rate=4,\n        hidden_dim=32,\n        transformer_dim=None,\n        n_neighbors=16,\n    ):\n        super(PointTransformer, self).__init__()\n        self.fc = nn.Sequential(\n            nn.Linear(feature_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n        )\n        self.ptb = PointTransformerBlock(\n            hidden_dim, n_neighbors, transformer_dim\n        )\n        self.transition_downs = nn.ModuleList()\n        self.transformers = nn.ModuleList()\n        for i in range(n_blocks):\n            block_hidden_dim = hidden_dim * 2 ** (i + 1)\n            block_n_points = n_points // (downsampling_rate ** (i + 1))\n            self.transition_downs.append(\n                TransitionDown(\n                    block_n_points,\n                    batch_size,\n                    [\n                        block_hidden_dim // 2 + 3,\n                        block_hidden_dim,\n                        block_hidden_dim,\n                    ],\n                    n_neighbors=n_neighbors,\n                )\n            )\n            self.transformers.append(\n                PointTransformerBlock(\n                    block_hidden_dim, n_neighbors, transformer_dim\n                )\n            )\n\n    def forward(self, x):\n        if x.shape[-1] > 3:\n            pos = x[:, :, :3]\n        else:\n            pos = x\n\n        feat = x\n        h = self.fc(feat)\n        h, _ = self.ptb(h, pos)\n\n        hidden_state = [(pos, h)]\n        for td, tf in zip(self.transition_downs, self.transformers):\n            pos, h = td(pos, h)\n            h, _ = tf(h, pos)\n            hidden_state.append((pos, h))\n\n        return h, hidden_state\n\n\nclass PointTransformerCLS(nn.Module):\n    def __init__(\n        
self,\n        out_classes,\n        batch_size,\n        n_points=1024,\n        feature_dim=3,\n        n_blocks=4,\n        downsampling_rate=4,\n        hidden_dim=32,\n        transformer_dim=None,\n        n_neighbors=16,\n    ):\n        super(PointTransformerCLS, self).__init__()\n        self.backbone = PointTransformer(\n            n_points,\n            batch_size,\n            feature_dim,\n            n_blocks,\n            downsampling_rate,\n            hidden_dim,\n            transformer_dim,\n            n_neighbors,\n        )\n        self.out = self.fc2 = nn.Sequential(\n            nn.Linear(hidden_dim * 2 ** (n_blocks), 256),\n            nn.ReLU(),\n            nn.Linear(256, 64),\n            nn.ReLU(),\n            nn.Linear(64, out_classes),\n        )\n\n    def forward(self, x):\n        h, _ = self.backbone(x)\n        out = self.out(torch.mean(h, dim=1))\n        return out\n\n\nclass PointTransformerSeg(nn.Module):\n    def __init__(\n        self,\n        out_classes,\n        batch_size,\n        n_points=2048,\n        feature_dim=3,\n        n_blocks=4,\n        downsampling_rate=4,\n        hidden_dim=32,\n        transformer_dim=None,\n        n_neighbors=16,\n    ):\n        super().__init__()\n        self.backbone = PointTransformer(\n            n_points,\n            batch_size,\n            feature_dim,\n            n_blocks,\n            downsampling_rate,\n            hidden_dim,\n            transformer_dim,\n            n_neighbors,\n        )\n\n        self.fc = nn.Sequential(\n            nn.Linear(32 * 2**n_blocks, 512),\n            nn.ReLU(),\n            nn.Linear(512, 512),\n            nn.ReLU(),\n            nn.Linear(512, 32 * 2**n_blocks),\n        )\n        self.ptb = PointTransformerBlock(\n            32 * 2**n_blocks, n_neighbors, transformer_dim\n        )\n\n        self.n_blocks = n_blocks\n        self.transition_ups = nn.ModuleList()\n        self.transformers = nn.ModuleList()\n        for i 
in reversed(range(n_blocks)):\n            block_hidden_dim = 32 * 2**i\n            self.transition_ups.append(\n                TransitionUp(\n                    block_hidden_dim * 2, block_hidden_dim, block_hidden_dim\n                )\n            )\n            self.transformers.append(\n                PointTransformerBlock(\n                    block_hidden_dim, n_neighbors, transformer_dim\n                )\n            )\n\n        self.out = nn.Sequential(\n            nn.Linear(32 + 16, 64),\n            nn.ReLU(),\n            nn.Linear(64, 64),\n            nn.ReLU(),\n            nn.Linear(64, out_classes),\n        )\n\n    def forward(self, x, cat_vec=None):\n        _, hidden_state = self.backbone(x)\n        pos, h = hidden_state[-1]\n        h, _ = self.ptb(self.fc(h), pos)\n\n        for i in range(self.n_blocks):\n            h = self.transition_ups[i](\n                pos, h, hidden_state[-i - 2][0], hidden_state[-i - 2][1]\n            )\n            pos = hidden_state[-i - 2][0]\n            h, _ = self.transformers[i](h, pos)\n        return self.out(torch.cat([h, cat_vec], dim=-1))\n\n\nclass PartSegLoss(nn.Module):\n    def __init__(self, eps=0.2):\n        super(PartSegLoss, self).__init__()\n        self.eps = eps\n        self.loss = nn.CrossEntropyLoss()\n\n    def forward(self, logits, y):\n        num_classes = logits.shape[1]\n        logits = logits.permute(0, 2, 1).contiguous().view(-1, num_classes)\n        loss = self.loss(logits, y)\n        return loss\n"
  },
  {
    "path": "examples/pytorch/pointcloud/point_transformer/provider.py",
    "content": "\"\"\"\nAdapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py\n\"\"\"\nimport numpy as np\n\n\ndef normalize_data(batch_data):\n    \"\"\"Normalize the batch data, use coordinates of the block centered at origin,\n    Input:\n        BxNxC array\n    Output:\n        BxNxC array\n    \"\"\"\n    B, N, C = batch_data.shape\n    normal_data = np.zeros((B, N, C))\n    for b in range(B):\n        pc = batch_data[b]\n        centroid = np.mean(pc, axis=0)\n        pc = pc - centroid\n        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n        pc = pc / m\n        normal_data[b] = pc\n    return normal_data\n\n\ndef shuffle_data(data, labels):\n    \"\"\"Shuffle data and labels.\n    Input:\n      data: B,N,... numpy array\n      label: B,... numpy array\n    Return:\n      shuffled data, label and shuffle indices\n    \"\"\"\n    idx = np.arange(len(labels))\n    np.random.shuffle(idx)\n    return data[idx, ...], labels[idx], idx\n\n\ndef shuffle_points(batch_data):\n    \"\"\"Shuffle orders of points in each point cloud -- changes FPS behavior.\n    Use the same shuffling idx for the entire batch.\n    Input:\n        BxNxC array\n    Output:\n        BxNxC array\n    \"\"\"\n    idx = np.arange(batch_data.shape[1])\n    np.random.shuffle(idx)\n    return batch_data[:, idx, :]\n\n\ndef rotate_point_cloud(batch_data):\n    \"\"\"Randomly rotate the point clouds to augument the dataset\n    rotation is per shape based along up direction\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n  
      )\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_point_cloud_z(batch_data):\n    \"\"\"Randomly rotate the point clouds to augument the dataset\n    rotation is per shape based along up direction\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]\n        )\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_point_cloud_with_normal(batch_xyz_normal):\n    \"\"\"Randomly rotate XYZ, normal point cloud.\n    Input:\n        batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal\n    Output:\n        B,N,6, rotated XYZ, normal point cloud\n    \"\"\"\n    for k in range(batch_xyz_normal.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n        )\n        shape_pc = batch_xyz_normal[k, :, 0:3]\n        shape_normal = batch_xyz_normal[k, :, 3:6]\n        batch_xyz_normal[k, :, 0:3] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n        batch_xyz_normal[k, :, 3:6] = np.dot(\n            shape_normal.reshape((-1, 3)), rotation_matrix\n        )\n    return batch_xyz_normal\n\n\ndef 
rotate_perturbation_point_cloud_with_normal(\n    batch_data, angle_sigma=0.06, angle_clip=0.18\n):\n    \"\"\"Randomly perturb the point clouds by small rotations\n    Input:\n      BxNx6 array, original batch of point clouds and point normals\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        angles = np.clip(\n            angle_sigma * np.random.randn(3), -angle_clip, angle_clip\n        )\n        Rx = np.array(\n            [\n                [1, 0, 0],\n                [0, np.cos(angles[0]), -np.sin(angles[0])],\n                [0, np.sin(angles[0]), np.cos(angles[0])],\n            ]\n        )\n        Ry = np.array(\n            [\n                [np.cos(angles[1]), 0, np.sin(angles[1])],\n                [0, 1, 0],\n                [-np.sin(angles[1]), 0, np.cos(angles[1])],\n            ]\n        )\n        Rz = np.array(\n            [\n                [np.cos(angles[2]), -np.sin(angles[2]), 0],\n                [np.sin(angles[2]), np.cos(angles[2]), 0],\n                [0, 0, 1],\n            ]\n        )\n        R = np.dot(Rz, np.dot(Ry, Rx))\n        shape_pc = batch_data[k, :, 0:3]\n        shape_normal = batch_data[k, :, 3:6]\n        rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)\n        rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)\n    return rotated_data\n\n\ndef rotate_point_cloud_by_angle(batch_data, rotation_angle):\n    \"\"\"Rotate the point cloud along up direction with certain angle.\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        # rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = 
np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n        )\n        shape_pc = batch_data[k, :, 0:3]\n        rotated_data[k, :, 0:3] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):\n    \"\"\"Rotate the point cloud along up direction with certain angle.\n    Input:\n      BxNx6 array, original batch of point clouds with normal\n      scalar, angle of rotation\n    Return:\n      BxNx6 array, rotated batch of point clouds iwth normal\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        # rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n        )\n        shape_pc = batch_data[k, :, 0:3]\n        shape_normal = batch_data[k, :, 3:6]\n        rotated_data[k, :, 0:3] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n        rotated_data[k, :, 3:6] = np.dot(\n            shape_normal.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_perturbation_point_cloud(\n    batch_data, angle_sigma=0.06, angle_clip=0.18\n):\n    \"\"\"Randomly perturb the point clouds by small rotations\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        angles = np.clip(\n            angle_sigma * np.random.randn(3), -angle_clip, angle_clip\n        )\n        Rx = np.array(\n            [\n                [1, 0, 0],\n                [0, np.cos(angles[0]), 
-np.sin(angles[0])],\n                [0, np.sin(angles[0]), np.cos(angles[0])],\n            ]\n        )\n        Ry = np.array(\n            [\n                [np.cos(angles[1]), 0, np.sin(angles[1])],\n                [0, 1, 0],\n                [-np.sin(angles[1]), 0, np.cos(angles[1])],\n            ]\n        )\n        Rz = np.array(\n            [\n                [np.cos(angles[2]), -np.sin(angles[2]), 0],\n                [np.sin(angles[2]), np.cos(angles[2]), 0],\n                [0, 0, 1],\n            ]\n        )\n        R = np.dot(Rz, np.dot(Ry, Rx))\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)\n    return rotated_data\n\n\ndef jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):\n    \"\"\"Randomly jitter points. jittering is per point.\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, jittered batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    assert clip > 0\n    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)\n    jittered_data += batch_data\n    return jittered_data\n\n\ndef shift_point_cloud(batch_data, shift_range=0.1):\n    \"\"\"Randomly shift point cloud. Shift is per point cloud.\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, shifted batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))\n    for batch_index in range(B):\n        batch_data[batch_index, :, :] += shifts[batch_index, :]\n    return batch_data\n\n\ndef random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):\n    \"\"\"Randomly scale the point cloud. 
Scale is per point cloud.\n    Input:\n        BxNx3 array, original batch of point clouds\n    Return:\n        BxNx3 array, scaled batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    scales = np.random.uniform(scale_low, scale_high, B)\n    for batch_index in range(B):\n        batch_data[batch_index, :, :] *= scales[batch_index]\n    return batch_data\n\n\ndef random_point_dropout(batch_pc, max_dropout_ratio=0.875):\n    \"\"\"batch_pc: BxNx3\"\"\"\n    for b in range(batch_pc.shape[0]):\n        dropout_ratio = np.random.random() * max_dropout_ratio  # 0~0.875\n        drop_idx = np.where(\n            np.random.random((batch_pc.shape[1])) <= dropout_ratio\n        )[0]\n        if len(drop_idx) > 0:\n            dropout_ratio = (\n                np.random.random() * max_dropout_ratio\n            )  # 0~0.875 # not need\n            batch_pc[b, drop_idx, :] = batch_pc[\n                b, 0, :\n            ]  # set to the first point\n    return batch_pc\n"
  },
  {
    "path": "examples/pytorch/pointcloud/point_transformer/train_cls.py",
    "content": "import argparse\nimport os\nimport time\nfrom functools import partial\n\nimport provider\nimport torch\nimport torch.nn as nn\nimport tqdm\n\nfrom dgl.data.utils import download, get_download_dir\nfrom ModelNetDataLoader import ModelNetDataLoader\nfrom point_transformer import PointTransformerCLS\nfrom torch.utils.data import DataLoader\n\ntorch.backends.cudnn.enabled = False\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--dataset-path\", type=str, default=\"\")\nparser.add_argument(\"--load-model-path\", type=str, default=\"\")\nparser.add_argument(\"--save-model-path\", type=str, default=\"\")\nparser.add_argument(\"--num-epochs\", type=int, default=200)\nparser.add_argument(\"--num-workers\", type=int, default=8)\nparser.add_argument(\"--batch-size\", type=int, default=16)\nparser.add_argument(\"--opt\", type=str, default=\"adam\")\nargs = parser.parse_args()\n\nnum_workers = args.num_workers\nbatch_size = args.batch_size\n\ndata_filename = \"modelnet40_normal_resampled.zip\"\ndownload_path = os.path.join(get_download_dir(), data_filename)\nlocal_path = args.dataset_path or os.path.join(\n    get_download_dir(), \"modelnet40_normal_resampled\"\n)\n\nif not os.path.exists(local_path):\n    download(\n        \"https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip\",\n        download_path,\n        verify_ssl=False,\n    )\n    from zipfile import ZipFile\n\n    with ZipFile(download_path) as z:\n        z.extractall(path=get_download_dir())\n\nCustomDataLoader = partial(\n    DataLoader,\n    num_workers=num_workers,\n    batch_size=batch_size,\n    shuffle=True,\n    drop_last=True,\n)\n\n\ndef train(net, opt, scheduler, train_loader, dev):\n    net.train()\n\n    total_loss = 0\n    num_batches = 0\n    total_correct = 0\n    count = 0\n    loss_f = nn.CrossEntropyLoss()\n    start_time = time.time()\n    with tqdm.tqdm(train_loader, ascii=True) as tq:\n        for data, label in tq:\n            data = 
data.data.numpy()\n            data = provider.random_point_dropout(data)\n            data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])\n            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])\n            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])\n            data = torch.tensor(data)\n            label = label[:, 0]\n\n            num_examples = label.shape[0]\n            data, label = data.to(dev), label.to(dev).squeeze().long()\n            opt.zero_grad()\n            logits = net(data)\n            loss = loss_f(logits, label)\n            loss.backward()\n            opt.step()\n\n            _, preds = logits.max(1)\n\n            num_batches += 1\n            count += num_examples\n            loss = loss.item()\n            correct = (preds == label).sum().item()\n            total_loss += loss\n            total_correct += correct\n\n            tq.set_postfix(\n                {\n                    \"AvgLoss\": \"%.5f\" % (total_loss / num_batches),\n                    \"AvgAcc\": \"%.5f\" % (total_correct / count),\n                }\n            )\n    print(\n        \"[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s\".format(\n            total_loss / num_batches,\n            total_correct / count,\n            time.time() - start_time,\n        )\n    )\n    scheduler.step()\n\n\ndef evaluate(net, test_loader, dev):\n    net.eval()\n\n    total_correct = 0\n    count = 0\n    start_time = time.time()\n    with torch.no_grad():\n        with tqdm.tqdm(test_loader, ascii=True) as tq:\n            for data, label in tq:\n                label = label[:, 0]\n                num_examples = label.shape[0]\n                data, label = data.to(dev), label.to(dev).squeeze().long()\n                logits = net(data)\n                _, preds = logits.max(1)\n\n                correct = (preds == label).sum().item()\n                total_correct += correct\n                count += 
num_examples\n\n                tq.set_postfix({\"AvgAcc\": \"%.5f\" % (total_correct / count)})\n    print(\n        \"[Test]  AvgAcc: {:.5}, Time: {:.5}s\".format(\n            total_correct / count, time.time() - start_time\n        )\n    )\n    return total_correct / count\n\n\ndev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nnet = PointTransformerCLS(40, batch_size, feature_dim=6)\n\nnet = net.to(dev)\nif args.load_model_path:\n    net.load_state_dict(\n        torch.load(args.load_model_path, weights_only=False, map_location=dev)\n    )\n\nif args.opt == \"sgd\":\n    # The optimizer strategy described in paper:\n    opt = torch.optim.SGD(\n        net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4\n    )\n    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n        opt, milestones=[120, 160], gamma=0.1\n    )\nelif args.opt == \"adam\":\n    # The optimizer strategy proposed by\n    # https://github.com/qq456cvb/Point-Transformers:\n    opt = torch.optim.Adam(\n        net.parameters(),\n        lr=1e-3,\n        betas=(0.9, 0.999),\n        eps=1e-08,\n        weight_decay=1e-4,\n    )\n    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.3)\n\ntrain_dataset = ModelNetDataLoader(local_path, 1024, split=\"train\")\ntest_dataset = ModelNetDataLoader(local_path, 1024, split=\"test\")\ntrain_loader = torch.utils.data.DataLoader(\n    train_dataset,\n    batch_size=batch_size,\n    shuffle=True,\n    num_workers=num_workers,\n    drop_last=True,\n)\ntest_loader = torch.utils.data.DataLoader(\n    test_dataset,\n    batch_size=batch_size,\n    shuffle=False,\n    num_workers=num_workers,\n    drop_last=True,\n)\n\nbest_test_acc = 0\n\nfor epoch in range(args.num_epochs):\n    print(\"Epoch #{}: \".format(epoch))\n    train(net, opt, scheduler, train_loader, dev)\n    if (epoch + 1) % 1 == 0:\n        test_acc = evaluate(net, test_loader, dev)\n        if test_acc > best_test_acc:\n            best_test_acc = 
test_acc\n            if args.save_model_path:\n                torch.save(net.state_dict(), args.save_model_path)\n        print(\"Current test acc: %.5f (best: %.5f)\" % (test_acc, best_test_acc))\n    print()\n"
  },
  {
    "path": "examples/pytorch/pointcloud/point_transformer/train_partseg.py",
    "content": "import argparse\nimport time\nfrom functools import partial\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.optim as optim\nimport tqdm\nfrom point_transformer import PartSegLoss, PointTransformerSeg\nfrom ShapeNet import ShapeNet\nfrom torch.utils.data import DataLoader\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--dataset-path\", type=str, default=\"\")\nparser.add_argument(\"--load-model-path\", type=str, default=\"\")\nparser.add_argument(\"--save-model-path\", type=str, default=\"\")\nparser.add_argument(\"--num-epochs\", type=int, default=250)\nparser.add_argument(\"--num-workers\", type=int, default=8)\nparser.add_argument(\"--batch-size\", type=int, default=16)\nparser.add_argument(\"--tensorboard\", action=\"store_true\")\nparser.add_argument(\"--opt\", type=str, default=\"adam\")\nargs = parser.parse_args()\n\nnum_workers = args.num_workers\nbatch_size = args.batch_size\n\n\ndef collate(samples):\n    graphs, cat = map(list, zip(*samples))\n    return dgl.batch(graphs), cat\n\n\nCustomDataLoader = partial(\n    DataLoader,\n    num_workers=num_workers,\n    batch_size=batch_size,\n    shuffle=True,\n    drop_last=True,\n)\n\n\ndef train(net, opt, scheduler, train_loader, dev):\n    category_list = sorted(list(shapenet.seg_classes.keys()))\n    eye_mat = np.eye(16)\n    net.train()\n\n    total_loss = 0\n    num_batches = 0\n    total_correct = 0\n    count = 0\n    start = time.time()\n    with tqdm.tqdm(train_loader, ascii=True) as tq:\n        for data, label, cat in tq:\n            num_examples = data.shape[0]\n            data = data.to(dev, dtype=torch.float)\n            label = label.to(dev, dtype=torch.long).view(-1)\n            opt.zero_grad()\n            cat_ind = [category_list.index(c) for c in cat]\n            # An one-hot encoding for the object category\n            cat_tensor = (\n                torch.tensor(eye_mat[cat_ind])\n                .to(dev, dtype=torch.float)\n                
.repeat(1, 2048)\n            )\n            cat_tensor = cat_tensor.view(num_examples, -1, 16)\n            logits = net(data, cat_tensor).permute(0, 2, 1)\n            loss = L(logits, label)\n            loss.backward()\n            opt.step()\n\n            _, preds = logits.max(1)\n\n            count += num_examples * 2048\n            loss = loss.item()\n            total_loss += loss\n            num_batches += 1\n            correct = (preds.view(-1) == label).sum().item()\n            total_correct += correct\n\n            AvgLoss = total_loss / num_batches\n            AvgAcc = total_correct / count\n\n            tq.set_postfix(\n                {\"AvgLoss\": \"%.5f\" % AvgLoss, \"AvgAcc\": \"%.5f\" % AvgAcc}\n            )\n    scheduler.step()\n    end = time.time()\n    print(\n        \"[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s\".format(\n            total_loss / num_batches, total_correct / count, end - start\n        )\n    )\n    return data, preds, AvgLoss, AvgAcc, end - start\n\n\ndef mIoU(preds, label, cat, cat_miou, seg_classes):\n    for i in range(preds.shape[0]):\n        shape_iou = 0\n        n = len(seg_classes[cat[i]])\n        for cls in seg_classes[cat[i]]:\n            pred_set = set(np.where(preds[i, :] == cls)[0])\n            label_set = set(np.where(label[i, :] == cls)[0])\n            union = len(pred_set.union(label_set))\n            inter = len(pred_set.intersection(label_set))\n            if union == 0:\n                shape_iou += 1\n            else:\n                shape_iou += inter / union\n        shape_iou /= n\n        cat_miou[cat[i]][0] += shape_iou\n        cat_miou[cat[i]][1] += 1\n\n    return cat_miou\n\n\ndef evaluate(net, test_loader, dev, per_cat_verbose=False):\n    category_list = sorted(list(shapenet.seg_classes.keys()))\n    eye_mat = np.eye(16)\n    net.eval()\n\n    cat_miou = {}\n    for k in shapenet.seg_classes.keys():\n        cat_miou[k] = [0, 0]\n    miou = 0\n    count = 0\n    
per_cat_miou = 0\n    per_cat_count = 0\n\n    with torch.no_grad():\n        with tqdm.tqdm(test_loader, ascii=True) as tq:\n            for data, label, cat in tq:\n                num_examples = data.shape[0]\n                data = data.to(dev, dtype=torch.float)\n                label = label.to(dev, dtype=torch.long)\n                cat_ind = [category_list.index(c) for c in cat]\n                cat_tensor = (\n                    torch.tensor(eye_mat[cat_ind])\n                    .to(dev, dtype=torch.float)\n                    .repeat(1, 2048)\n                )\n                cat_tensor = cat_tensor.view(num_examples, -1, 16)\n                logits = net(data, cat_tensor).permute(0, 2, 1)\n                _, preds = logits.max(1)\n\n                cat_miou = mIoU(\n                    preds.cpu().numpy(),\n                    label.view(num_examples, -1).cpu().numpy(),\n                    cat,\n                    cat_miou,\n                    shapenet.seg_classes,\n                )\n                for _, v in cat_miou.items():\n                    if v[1] > 0:\n                        miou += v[0]\n                        count += v[1]\n                        per_cat_miou += v[0] / v[1]\n                        per_cat_count += 1\n                tq.set_postfix(\n                    {\n                        \"mIoU\": \"%.5f\" % (miou / count),\n                        \"per Category mIoU\": \"%.5f\"\n                        % (per_cat_miou / per_cat_count),\n                    }\n                )\n    print(\n        \"[Test] mIoU: %.5f, per Category mIoU: %.5f\"\n        % (miou / count, per_cat_miou / per_cat_count)\n    )\n    if per_cat_verbose:\n        print(\"-\" * 60)\n        print(\"Per-Category mIoU:\")\n        for k, v in cat_miou.items():\n            if v[1] > 0:\n                print(\"%s mIoU=%.5f\" % (k, v[0] / v[1]))\n            else:\n                print(\"%s mIoU=%.5f\" % (k, 1))\n        print(\"-\" * 60)\n    
return miou / count, per_cat_miou / per_cat_count\n\n\ndev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nnet = PointTransformerSeg(50, batch_size)\n\nnet = net.to(dev)\nif args.load_model_path:\n    net.load_state_dict(\n        torch.load(args.load_model_path, weights_only=False, map_location=dev)\n    )\n\nif args.opt == \"sgd\":\n    # The optimizer strategy described in paper:\n    opt = torch.optim.SGD(\n        net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4\n    )\n    scheduler = torch.optim.lr_scheduler.MultiStepLR(\n        opt, milestones=[120, 160], gamma=0.1\n    )\nelif args.opt == \"adam\":\n    # The optimizer strategy proposed by\n    # https://github.com/qq456cvb/Point-Transformers:\n    opt = torch.optim.Adam(\n        net.parameters(),\n        lr=1e-3,\n        betas=(0.9, 0.999),\n        eps=1e-08,\n        weight_decay=1e-4,\n    )\n    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.3)\n\nL = PartSegLoss()\n\nshapenet = ShapeNet(2048, normal_channel=False)\n\ntrain_loader = CustomDataLoader(shapenet.trainval())\ntest_loader = CustomDataLoader(shapenet.test())\n\n# Tensorboard\nif args.tensorboard:\n    import torchvision\n    from torch.utils.tensorboard import SummaryWriter\n    from torchvision import datasets, transforms\n\n    writer = SummaryWriter()\n# Select 50 distinct colors for different parts\ncolor_map = torch.tensor(\n    [\n        [47, 79, 79],\n        [139, 69, 19],\n        [112, 128, 144],\n        [85, 107, 47],\n        [139, 0, 0],\n        [128, 128, 0],\n        [72, 61, 139],\n        [0, 128, 0],\n        [188, 143, 143],\n        [60, 179, 113],\n        [205, 133, 63],\n        [0, 139, 139],\n        [70, 130, 180],\n        [205, 92, 92],\n        [154, 205, 50],\n        [0, 0, 139],\n        [50, 205, 50],\n        [250, 250, 250],\n        [218, 165, 32],\n        [139, 0, 139],\n        [10, 10, 10],\n        [176, 48, 96],\n        [72, 209, 204],\n   
     [153, 50, 204],\n        [255, 69, 0],\n        [255, 145, 0],\n        [0, 0, 205],\n        [255, 255, 0],\n        [0, 255, 0],\n        [233, 150, 122],\n        [220, 20, 60],\n        [0, 191, 255],\n        [160, 32, 240],\n        [192, 192, 192],\n        [173, 255, 47],\n        [218, 112, 214],\n        [216, 191, 216],\n        [255, 127, 80],\n        [255, 0, 255],\n        [100, 149, 237],\n        [128, 128, 128],\n        [221, 160, 221],\n        [144, 238, 144],\n        [123, 104, 238],\n        [255, 160, 122],\n        [175, 238, 238],\n        [238, 130, 238],\n        [127, 255, 212],\n        [255, 218, 185],\n        [255, 105, 180],\n    ]\n)\n# paint each point according to its pred\n\n\ndef paint(batched_points):\n    B, N = batched_points.shape\n    colored = color_map[batched_points].squeeze(2)\n    return colored\n\n\nbest_test_miou = 0\nbest_test_per_cat_miou = 0\n\nfor epoch in range(args.num_epochs):\n    print(\"Epoch #{}: \".format(epoch))\n    data, preds, AvgLoss, AvgAcc, training_time = train(\n        net, opt, scheduler, train_loader, dev\n    )\n    if (epoch + 1) % 5 == 0 or epoch == 0:\n        test_miou, test_per_cat_miou = evaluate(net, test_loader, dev, True)\n        if test_miou > best_test_miou:\n            best_test_miou = test_miou\n            best_test_per_cat_miou = test_per_cat_miou\n            if args.save_model_path:\n                torch.save(net.state_dict(), args.save_model_path)\n        print(\n            \"Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)\"\n            % (\n                test_miou,\n                best_test_miou,\n                test_per_cat_miou,\n                best_test_per_cat_miou,\n            )\n        )\n    # Tensorboard\n    if args.tensorboard:\n        colored = paint(preds)\n        writer.add_mesh(\n            \"data\", vertices=data, colors=colored, global_step=epoch\n        )\n        writer.add_scalar(\n            \"training 
time for one epoch\", training_time, global_step=epoch\n        )\n        writer.add_scalar(\"AvgLoss\", AvgLoss, global_step=epoch)\n        writer.add_scalar(\"AvgAcc\", AvgAcc, global_step=epoch)\n        if (epoch + 1) % 5 == 0:\n            writer.add_scalar(\"test mIoU\", test_miou, global_step=epoch)\n            writer.add_scalar(\n                \"best test mIoU\", best_test_miou, global_step=epoch\n            )\n    print()\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pointnet/ModelNetDataLoader.py",
    "content": "import os\nimport warnings\n\nimport numpy as np\nfrom torch.utils.data import Dataset\n\nwarnings.filterwarnings(\"ignore\")\n\n\ndef pc_normalize(pc):\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n    pc = pc / m\n    return pc\n\n\ndef farthest_point_sample(point, npoint):\n    \"\"\"\n    Farthest point sampler works as follows:\n    1. Initialize the sample set S with a random point\n    2. Pick point P not in S, which maximizes the distance d(P, S)\n    3. Repeat step 2 until |S| = npoint\n\n    Input:\n        xyz: pointcloud data, [N, D]\n        npoint: number of samples\n    Return:\n        centroids: sampled pointcloud index, [npoint, D]\n    \"\"\"\n    N, D = point.shape\n    xyz = point[:, :3]\n    centroids = np.zeros((npoint,))\n    distance = np.ones((N,)) * 1e10\n    farthest = np.random.randint(0, N)\n    for i in range(npoint):\n        centroids[i] = farthest\n        centroid = xyz[farthest, :]\n        dist = np.sum((xyz - centroid) ** 2, -1)\n        mask = dist < distance\n        distance[mask] = dist[mask]\n        farthest = np.argmax(distance, -1)\n    point = point[centroids.astype(np.int32)]\n    return point\n\n\nclass ModelNetDataLoader(Dataset):\n    def __init__(\n        self,\n        root,\n        npoint=1024,\n        split=\"train\",\n        fps=False,\n        normal_channel=True,\n        cache_size=15000,\n    ):\n        \"\"\"\n        Input:\n            root: the root path to the local data files\n            npoint: number of points from each cloud\n            split: which split of the data, 'train' or 'test'\n            fps: whether to sample points with farthest point sampler\n            normal_channel: whether to use additional channel\n            cache_size: the cache size of in-memory point clouds\n        \"\"\"\n        self.root = root\n        self.npoints = npoint\n        self.fps = fps\n        self.catfile = 
os.path.join(self.root, \"modelnet40_shape_names.txt\")\n\n        self.cat = [line.rstrip() for line in open(self.catfile)]\n        self.classes = dict(zip(self.cat, range(len(self.cat))))\n        self.normal_channel = normal_channel\n\n        shape_ids = {}\n        shape_ids[\"train\"] = [\n            line.rstrip()\n            for line in open(os.path.join(self.root, \"modelnet40_train.txt\"))\n        ]\n        shape_ids[\"test\"] = [\n            line.rstrip()\n            for line in open(os.path.join(self.root, \"modelnet40_test.txt\"))\n        ]\n\n        assert split == \"train\" or split == \"test\"\n        shape_names = [\"_\".join(x.split(\"_\")[0:-1]) for x in shape_ids[split]]\n        # list of (shape_name, shape_txt_file_path) tuple\n        self.datapath = [\n            (\n                shape_names[i],\n                os.path.join(self.root, shape_names[i], shape_ids[split][i])\n                + \".txt\",\n            )\n            for i in range(len(shape_ids[split]))\n        ]\n        print(\"The size of %s data is %d\" % (split, len(self.datapath)))\n\n        self.cache_size = cache_size\n        self.cache = {}\n\n    def __len__(self):\n        return len(self.datapath)\n\n    def _get_item(self, index):\n        if index in self.cache:\n            point_set, cls = self.cache[index]\n        else:\n            fn = self.datapath[index]\n            cls = self.classes[self.datapath[index][0]]\n            cls = np.array([cls]).astype(np.int32)\n            point_set = np.loadtxt(fn[1], delimiter=\",\").astype(np.float32)\n            if self.fps:\n                point_set = farthest_point_sample(point_set, self.npoints)\n            else:\n                point_set = point_set[0 : self.npoints, :]\n\n            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])\n\n            if not self.normal_channel:\n                point_set = point_set[:, 0:3]\n\n            if len(self.cache) < self.cache_size:\n                
self.cache[index] = (point_set, cls)\n\n        return point_set, cls\n\n    def __getitem__(self, index):\n        return self._get_item(index)\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pointnet/README.md",
    "content": "PointNet and PointNet++ for Point Cloud Classification and Segmentation\n====\n\nThis is a reproduction of the papers\n- [PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation](https://arxiv.org/abs/1612.00593).\n- [PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space](https://arxiv.org/abs/1706.02413).\n\n# Performance\n\n## Classification\n| Model           | Dataset    | Metric   | Score - PyTorch  | Score - DGL | Time(s) - PyTorch | Time(s) - DGL |\n|-----------------|------------|----------|------------------|-------------|-------------------|---------------|\n| PointNet        | ModelNet40 | Accuracy | 89.2(Official)   | 89.3        | 181.8             | 95.0          |\n| PointNet++(SSG) | ModelNet40 | Accuracy | 92.4             | 93.3        | 182.6             | 133.7         |\n| PointNet++(MSG) | ModelNet40 | Accuracy | 92.8             | 93.3        | 383.6             | 240.5         |\n\n## Part Segmentation\n\n| Model           | Dataset    | Metric   | Score - PyTorch | Score - DGL | Time(s) - PyTorch | Time(s) - DGL |\n|-----------------|------------|----------|-----------------|-------------|-------------------|---------------|\n| PointNet        | ShapeNet   | mIoU     | 84.3            | 83.6        | 251.6             | 234.0         |\n| PointNet++(SSG) | ShapeNet   | mIoU     | 84.9            | 84.5        | 361.7             | 240.1         |\n| PointNet++(MSG) | ShapeNet   | mIoU     | 85.4            | 84.6        | 817.3             | 821.8         |\n\n+ Score - PyTorch are collected from [this repo](https://github.com/yanx27/Pointnet_Pointnet2_pytorch).\n+ Time(s) are the average training time per epoch, measured on EC2 g4dn.4xlarge instance w/ Tesla T4 GPU.\n# How to Run\n\nFor point cloud classification, run with\n\n```python\npython train_cls.py\n```\n\nFor point cloud part-segmentation, run with\n\n```python\npython train_partseg.py\n```\n\n## To Visualize Part 
Segmentation in TensorBoard\n![Screenshot](vis.png)\n\nFirst, install TensorBoard with ``pip install tensorboard``, then run\n```bash\npython train_partseg.py --tensorboard\n```\nTo display the results in TensorBoard, run\n``tensorboard --logdir=runs``\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pointnet/ShapeNet.py",
    "content": "import json\nimport os\nfrom zipfile import ZipFile\n\nimport dgl\n\nimport numpy as np\nimport tqdm\nfrom dgl.data.utils import download, get_download_dir\nfrom scipy.sparse import csr_matrix\nfrom torch.utils.data import Dataset\n\n\nclass ShapeNet(object):\n    def __init__(self, num_points=2048, normal_channel=True):\n        self.num_points = num_points\n        self.normal_channel = normal_channel\n\n        SHAPENET_DOWNLOAD_URL = \"https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip\"\n        download_path = get_download_dir()\n        data_filename = (\n            \"shapenetcore_partanno_segmentation_benchmark_v0_normal.zip\"\n        )\n        data_path = os.path.join(\n            download_path,\n            \"shapenetcore_partanno_segmentation_benchmark_v0_normal\",\n        )\n        if not os.path.exists(data_path):\n            local_path = os.path.join(download_path, data_filename)\n            if not os.path.exists(local_path):\n                download(SHAPENET_DOWNLOAD_URL, local_path, verify_ssl=False)\n            with ZipFile(local_path) as z:\n                z.extractall(path=download_path)\n\n        synset_file = \"synsetoffset2category.txt\"\n        with open(os.path.join(data_path, synset_file)) as f:\n            synset = [t.split(\"\\n\")[0].split(\"\\t\") for t in f.readlines()]\n        self.synset_dict = {}\n        for syn in synset:\n            self.synset_dict[syn[1]] = syn[0]\n        self.seg_classes = {\n            \"Airplane\": [0, 1, 2, 3],\n            \"Bag\": [4, 5],\n            \"Cap\": [6, 7],\n            \"Car\": [8, 9, 10, 11],\n            \"Chair\": [12, 13, 14, 15],\n            \"Earphone\": [16, 17, 18],\n            \"Guitar\": [19, 20, 21],\n            \"Knife\": [22, 23],\n            \"Lamp\": [24, 25, 26, 27],\n            \"Laptop\": [28, 29],\n            \"Motorbike\": [30, 31, 32, 33, 34, 35],\n            \"Mug\": [36, 37],\n        
    \"Pistol\": [38, 39, 40],\n            \"Rocket\": [41, 42, 43],\n            \"Skateboard\": [44, 45, 46],\n            \"Table\": [47, 48, 49],\n        }\n\n        train_split_json = \"shuffled_train_file_list.json\"\n        val_split_json = \"shuffled_val_file_list.json\"\n        test_split_json = \"shuffled_test_file_list.json\"\n        split_path = os.path.join(data_path, \"train_test_split\")\n        with open(os.path.join(split_path, train_split_json)) as f:\n            tmp = f.read()\n            self.train_file_list = [\n                os.path.join(data_path, t.replace(\"shape_data/\", \"\") + \".txt\")\n                for t in json.loads(tmp)\n            ]\n        with open(os.path.join(split_path, val_split_json)) as f:\n            tmp = f.read()\n            self.val_file_list = [\n                os.path.join(data_path, t.replace(\"shape_data/\", \"\") + \".txt\")\n                for t in json.loads(tmp)\n            ]\n        with open(os.path.join(split_path, test_split_json)) as f:\n            tmp = f.read()\n            self.test_file_list = [\n                os.path.join(data_path, t.replace(\"shape_data/\", \"\") + \".txt\")\n                for t in json.loads(tmp)\n            ]\n\n    def train(self):\n        return ShapeNetDataset(\n            self, \"train\", self.num_points, self.normal_channel\n        )\n\n    def valid(self):\n        return ShapeNetDataset(\n            self, \"valid\", self.num_points, self.normal_channel\n        )\n\n    def trainval(self):\n        return ShapeNetDataset(\n            self, \"trainval\", self.num_points, self.normal_channel\n        )\n\n    def test(self):\n        return ShapeNetDataset(\n            self, \"test\", self.num_points, self.normal_channel\n        )\n\n\nclass ShapeNetDataset(Dataset):\n    def __init__(self, shapenet, mode, num_points, normal_channel=True):\n        super(ShapeNetDataset, self).__init__()\n        self.mode = mode\n        self.num_points = 
num_points\n        if not normal_channel:\n            self.dim = 3\n        else:\n            self.dim = 6\n\n        if mode == \"train\":\n            self.file_list = shapenet.train_file_list\n        elif mode == \"valid\":\n            self.file_list = shapenet.val_file_list\n        elif mode == \"test\":\n            self.file_list = shapenet.test_file_list\n        elif mode == \"trainval\":\n            self.file_list = shapenet.train_file_list + shapenet.val_file_list\n        else:\n            raise \"Not supported `mode`\"\n\n        data_list = []\n        label_list = []\n        category_list = []\n        print(\"Loading data from split \" + self.mode)\n        for fn in tqdm.tqdm(self.file_list, ascii=True):\n            with open(fn) as f:\n                data = np.array(\n                    [t.split(\"\\n\")[0].split(\" \") for t in f.readlines()]\n                ).astype(float)\n            data_list.append(data[:, 0 : self.dim])\n            label_list.append(data[:, 6].astype(int))\n            category_list.append(shapenet.synset_dict[fn.split(\"/\")[-2]])\n        self.data = data_list\n        self.label = label_list\n        self.category = category_list\n\n    def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2), size=3):\n        xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])\n        xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])\n        x = np.add(np.multiply(x, xyz1), xyz2).astype(\"float32\")\n        return x\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, i):\n        inds = np.random.choice(\n            self.data[i].shape[0], self.num_points, replace=True\n        )\n        x = self.data[i][inds, : self.dim]\n        y = self.label[i][inds]\n        cat = self.category[i]\n        if self.mode == \"train\":\n            x = self.translate(x, size=self.dim)\n        x = x.astype(float)\n        y = y.astype(int)\n        return x, y, 
cat\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pointnet/pointnet2.py",
    "content": "import dgl\nimport dgl.function as fn\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.geometry import (\n    farthest_point_sampler,\n)  # dgl.geometry.pytorch -> dgl.geometry\nfrom torch.autograd import Variable\n\n\"\"\"\nPart of the code are adapted from\nhttps://github.com/yanx27/Pointnet_Pointnet2_pytorch\n\"\"\"\n\n\ndef square_distance(src, dst):\n    \"\"\"\n    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n    \"\"\"\n    B, N, _ = src.shape\n    _, M, _ = dst.shape\n    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))\n    dist += torch.sum(src**2, -1).view(B, N, 1)\n    dist += torch.sum(dst**2, -1).view(B, 1, M)\n    return dist\n\n\ndef index_points(points, idx):\n    \"\"\"\n    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n    \"\"\"\n    device = points.device\n    B = points.shape[0]\n    view_shape = list(idx.shape)\n    view_shape[1:] = [1] * (len(view_shape) - 1)\n    repeat_shape = list(idx.shape)\n    repeat_shape[0] = 1\n    batch_indices = (\n        torch.arange(B, dtype=torch.long)\n        .to(device)\n        .view(view_shape)\n        .repeat(repeat_shape)\n    )\n    new_points = points[batch_indices, idx, :]\n    return new_points\n\n\nclass FixedRadiusNearNeighbors(nn.Module):\n    \"\"\"\n    Ball Query - Find the neighbors with-in a fixed radius\n    \"\"\"\n\n    def __init__(self, radius, n_neighbor):\n        super(FixedRadiusNearNeighbors, self).__init__()\n        self.radius = radius\n        self.n_neighbor = n_neighbor\n\n    def forward(self, pos, centroids):\n        \"\"\"\n        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n        \"\"\"\n        device = pos.device\n        B, N, _ = pos.shape\n        center_pos = index_points(pos, centroids)\n        _, S, _ = center_pos.shape\n        group_idx = (\n            torch.arange(N, dtype=torch.long)\n            .to(device)\n            
.view(1, 1, N)\n            .repeat([B, S, 1])\n        )\n        sqrdists = square_distance(center_pos, pos)\n        group_idx[sqrdists > self.radius**2] = N\n        group_idx = group_idx.sort(dim=-1)[0][:, :, : self.n_neighbor]\n        group_first = (\n            group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, self.n_neighbor])\n        )\n        mask = group_idx == N\n        group_idx[mask] = group_first[mask]\n        return group_idx\n\n\nclass FixedRadiusNNGraph(nn.Module):\n    \"\"\"\n    Build NN graph\n    \"\"\"\n\n    def __init__(self, radius, n_neighbor):\n        super(FixedRadiusNNGraph, self).__init__()\n        self.radius = radius\n        self.n_neighbor = n_neighbor\n        self.frnn = FixedRadiusNearNeighbors(radius, n_neighbor)\n\n    def forward(self, pos, centroids, feat=None):\n        dev = pos.device\n        group_idx = self.frnn(pos, centroids)\n        B, N, _ = pos.shape\n        glist = []\n        for i in range(B):\n            center = torch.zeros((N)).to(dev)\n            center[centroids[i]] = 1\n            src = group_idx[i].contiguous().view(-1)\n            dst = centroids[i].view(-1, 1).repeat(1, self.n_neighbor).view(-1)\n\n            unified = torch.cat([src, dst])\n            uniq, inv_idx = torch.unique(unified, return_inverse=True)\n            src_idx = inv_idx[: src.shape[0]]\n            dst_idx = inv_idx[src.shape[0] :]\n\n            g = dgl.graph((src_idx, dst_idx))\n            g.ndata[\"pos\"] = pos[i][uniq]\n            g.ndata[\"center\"] = center[uniq]\n            if feat is not None:\n                g.ndata[\"feat\"] = feat[i][uniq]\n            glist.append(g)\n        bg = dgl.batch(glist)\n        return bg\n\n\nclass RelativePositionMessage(nn.Module):\n    \"\"\"\n    Compute the input feature from neighbors\n    \"\"\"\n\n    def __init__(self, n_neighbor):\n        super(RelativePositionMessage, self).__init__()\n        self.n_neighbor = n_neighbor\n\n    def forward(self, edges):\n      
  pos = edges.src[\"pos\"] - edges.dst[\"pos\"]\n        if \"feat\" in edges.src:\n            res = torch.cat([pos, edges.src[\"feat\"]], 1)\n        else:\n            res = pos\n        return {\"agg_feat\": res}\n\n\nclass PointNetConv(nn.Module):\n    \"\"\"\n    Feature aggregation\n    \"\"\"\n\n    def __init__(self, sizes, batch_size):\n        super(PointNetConv, self).__init__()\n        self.batch_size = batch_size\n        self.conv = nn.ModuleList()\n        self.bn = nn.ModuleList()\n        for i in range(1, len(sizes)):\n            self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))\n            self.bn.append(nn.BatchNorm2d(sizes[i]))\n\n    def forward(self, nodes):\n        shape = nodes.mailbox[\"agg_feat\"].shape\n        h = (\n            nodes.mailbox[\"agg_feat\"]\n            .view(self.batch_size, -1, shape[1], shape[2])\n            .permute(0, 3, 2, 1)\n        )\n        for conv, bn in zip(self.conv, self.bn):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n        h = torch.max(h, 2)[0]\n        feat_dim = h.shape[1]\n        h = h.permute(0, 2, 1).reshape(-1, feat_dim)\n        return {\"new_feat\": h}\n\n    def group_all(self, pos, feat):\n        \"\"\"\n        Feature aggregation and pooling for the non-sampling layer\n        \"\"\"\n        if feat is not None:\n            h = torch.cat([pos, feat], 2)\n        else:\n            h = pos\n        B, N, D = h.shape\n        _, _, C = pos.shape\n        new_pos = torch.zeros(B, 1, C)\n        h = h.permute(0, 2, 1).view(B, -1, N, 1)\n        for conv, bn in zip(self.conv, self.bn):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n        h = torch.max(h[:, :, :, 0], 2)[0]  # [B,D]\n        return new_pos, h\n\n\nclass SAModule(nn.Module):\n    \"\"\"\n    The Set Abstraction Layer\n    \"\"\"\n\n    def __init__(\n        self,\n        npoints,\n        batch_size,\n        radius,\n        mlp_sizes,\n        
n_neighbor=64,\n        group_all=False,\n    ):\n        super(SAModule, self).__init__()\n        self.group_all = group_all\n        if not group_all:\n            self.npoints = npoints\n            self.frnn_graph = FixedRadiusNNGraph(radius, n_neighbor)\n        self.message = RelativePositionMessage(n_neighbor)\n        self.conv = PointNetConv(mlp_sizes, batch_size)\n        self.batch_size = batch_size\n\n    def forward(self, pos, feat):\n        if self.group_all:\n            return self.conv.group_all(pos, feat)\n\n        centroids = farthest_point_sampler(pos, self.npoints)\n        g = self.frnn_graph(pos, centroids, feat)\n        g.update_all(self.message, self.conv)\n\n        mask = g.ndata[\"center\"] == 1\n        pos_dim = g.ndata[\"pos\"].shape[-1]\n        feat_dim = g.ndata[\"new_feat\"].shape[-1]\n        pos_res = g.ndata[\"pos\"][mask].view(self.batch_size, -1, pos_dim)\n        feat_res = g.ndata[\"new_feat\"][mask].view(self.batch_size, -1, feat_dim)\n        return pos_res, feat_res\n\n\nclass SAMSGModule(nn.Module):\n    \"\"\"\n    The Set Abstraction Multi-Scale grouping Layer\n    \"\"\"\n\n    def __init__(\n        self, npoints, batch_size, radius_list, n_neighbor_list, mlp_sizes_list\n    ):\n        super(SAMSGModule, self).__init__()\n        self.batch_size = batch_size\n        self.group_size = len(radius_list)\n\n        self.npoints = npoints\n        self.frnn_graph_list = nn.ModuleList()\n        self.message_list = nn.ModuleList()\n        self.conv_list = nn.ModuleList()\n        for i in range(self.group_size):\n            self.frnn_graph_list.append(\n                FixedRadiusNNGraph(radius_list[i], n_neighbor_list[i])\n            )\n            self.message_list.append(\n                RelativePositionMessage(n_neighbor_list[i])\n            )\n            self.conv_list.append(PointNetConv(mlp_sizes_list[i], batch_size))\n\n    def forward(self, pos, feat):\n        centroids = farthest_point_sampler(pos, 
self.npoints)\n        feat_res_list = []\n\n        for i in range(self.group_size):\n            g = self.frnn_graph_list[i](pos, centroids, feat)\n            g.update_all(self.message_list[i], self.conv_list[i])\n            mask = g.ndata[\"center\"] == 1\n            pos_dim = g.ndata[\"pos\"].shape[-1]\n            feat_dim = g.ndata[\"new_feat\"].shape[-1]\n            if i == 0:\n                pos_res = g.ndata[\"pos\"][mask].view(\n                    self.batch_size, -1, pos_dim\n                )\n            feat_res = g.ndata[\"new_feat\"][mask].view(\n                self.batch_size, -1, feat_dim\n            )\n            feat_res_list.append(feat_res)\n\n        feat_res = torch.cat(feat_res_list, 2)\n        return pos_res, feat_res\n\n\nclass PointNet2FP(nn.Module):\n    \"\"\"\n    The Feature Propagation Layer\n    \"\"\"\n\n    def __init__(self, input_dims, sizes):\n        super(PointNet2FP, self).__init__()\n        self.convs = nn.ModuleList()\n        self.bns = nn.ModuleList()\n\n        sizes = [input_dims] + sizes\n        for i in range(1, len(sizes)):\n            self.convs.append(nn.Conv1d(sizes[i - 1], sizes[i], 1))\n            self.bns.append(nn.BatchNorm1d(sizes[i]))\n\n    def forward(self, x1, x2, feat1, feat2):\n        \"\"\"\n        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch\n            Input:\n                x1: input points position data, [B, N, C]\n                x2: sampled input points position data, [B, S, C]\n                feat1: input points data, [B, N, D]\n                feat2: input points data, [B, S, D]\n            Return:\n                new_feat: upsampled points data, [B, D', N]\n        \"\"\"\n        B, N, C = x1.shape\n        _, S, _ = x2.shape\n\n        if S == 1:\n            interpolated_feat = feat2.repeat(1, N, 1)\n        else:\n            dists = square_distance(x1, x2)\n            dists, idx = dists.sort(dim=-1)\n            dists, idx = dists[:, :, :3], 
idx[:, :, :3]  # [B, N, 3]\n\n            dist_recip = 1.0 / (dists + 1e-8)\n            norm = torch.sum(dist_recip, dim=2, keepdim=True)\n            weight = dist_recip / norm\n            interpolated_feat = torch.sum(\n                index_points(feat2, idx) * weight.view(B, N, 3, 1), dim=2\n            )\n\n        if feat1 is not None:\n            new_feat = torch.cat([feat1, interpolated_feat], dim=-1)\n        else:\n            new_feat = interpolated_feat\n\n        new_feat = new_feat.permute(0, 2, 1)  # [B, D, S]\n        for i, conv in enumerate(self.convs):\n            bn = self.bns[i]\n            new_feat = F.relu(bn(conv(new_feat)))\n        return new_feat\n\n\nclass PointNet2SSGCls(nn.Module):\n    def __init__(\n        self, output_classes, batch_size, input_dims=3, dropout_prob=0.4\n    ):\n        super(PointNet2SSGCls, self).__init__()\n        self.input_dims = input_dims\n\n        self.sa_module1 = SAModule(\n            512, batch_size, 0.2, [input_dims, 64, 64, 128]\n        )\n        self.sa_module2 = SAModule(\n            128, batch_size, 0.4, [128 + 3, 128, 128, 256]\n        )\n        self.sa_module3 = SAModule(\n            None, batch_size, None, [256 + 3, 256, 512, 1024], group_all=True\n        )\n\n        self.mlp1 = nn.Linear(1024, 512)\n        self.bn1 = nn.BatchNorm1d(512)\n        self.drop1 = nn.Dropout(dropout_prob)\n\n        self.mlp2 = nn.Linear(512, 256)\n        self.bn2 = nn.BatchNorm1d(256)\n        self.drop2 = nn.Dropout(dropout_prob)\n\n        self.mlp_out = nn.Linear(256, output_classes)\n\n    def forward(self, x):\n        if x.shape[-1] > 3:\n            pos = x[:, :, :3]\n            feat = x[:, :, 3:]\n        else:\n            pos = x\n            feat = None\n        pos, feat = self.sa_module1(pos, feat)\n        pos, feat = self.sa_module2(pos, feat)\n        _, h = self.sa_module3(pos, feat)\n\n        h = self.mlp1(h)\n        h = self.bn1(h)\n        h = F.relu(h)\n        h = 
self.drop1(h)\n        h = self.mlp2(h)\n        h = self.bn2(h)\n        h = F.relu(h)\n        h = self.drop2(h)\n\n        out = self.mlp_out(h)\n        return out\n\n\nclass PointNet2MSGCls(nn.Module):\n    def __init__(\n        self, output_classes, batch_size, input_dims=3, dropout_prob=0.4\n    ):\n        super(PointNet2MSGCls, self).__init__()\n        self.input_dims = input_dims\n\n        self.sa_msg_module1 = SAMSGModule(\n            512,\n            batch_size,\n            [0.1, 0.2, 0.4],\n            [16, 32, 128],\n            [\n                [input_dims, 32, 32, 64],\n                [input_dims, 64, 64, 128],\n                [input_dims, 64, 96, 128],\n            ],\n        )\n        self.sa_msg_module2 = SAMSGModule(\n            128,\n            batch_size,\n            [0.2, 0.4, 0.8],\n            [32, 64, 128],\n            [\n                [320 + 3, 64, 64, 128],\n                [320 + 3, 128, 128, 256],\n                [320 + 3, 128, 128, 256],\n            ],\n        )\n        self.sa_module3 = SAModule(\n            None, batch_size, None, [640 + 3, 256, 512, 1024], group_all=True\n        )\n\n        self.mlp1 = nn.Linear(1024, 512)\n        self.bn1 = nn.BatchNorm1d(512)\n        self.drop1 = nn.Dropout(dropout_prob)\n\n        self.mlp2 = nn.Linear(512, 256)\n        self.bn2 = nn.BatchNorm1d(256)\n        self.drop2 = nn.Dropout(dropout_prob)\n\n        self.mlp_out = nn.Linear(256, output_classes)\n\n    def forward(self, x):\n        if x.shape[-1] > 3:\n            pos = x[:, :, :3]\n            feat = x[:, :, 3:]\n        else:\n            pos = x\n            feat = None\n        pos, feat = self.sa_msg_module1(pos, feat)\n        pos, feat = self.sa_msg_module2(pos, feat)\n        _, h = self.sa_module3(pos, feat)\n\n        h = self.mlp1(h)\n        h = self.bn1(h)\n        h = F.relu(h)\n        h = self.drop1(h)\n        h = self.mlp2(h)\n        h = self.bn2(h)\n        h = F.relu(h)\n        h = 
self.drop2(h)\n\n        out = self.mlp_out(h)\n        return out\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pointnet/pointnet2_partseg.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom pointnet2 import PointNet2FP, SAModule, SAMSGModule\nfrom torch.autograd import Variable\n\n\nclass PointNet2SSGPartSeg(nn.Module):\n    def __init__(self, output_classes, batch_size, input_dims=6):\n        super(PointNet2SSGPartSeg, self).__init__()\n        # if normal_channel == true, input_dims = 6+3\n        self.input_dims = input_dims\n\n        self.sa_module1 = SAModule(\n            512, batch_size, 0.2, [input_dims, 64, 64, 128], n_neighbor=32\n        )\n        self.sa_module2 = SAModule(\n            128, batch_size, 0.4, [128 + 3, 128, 128, 256]\n        )\n        self.sa_module3 = SAModule(\n            None, batch_size, None, [256 + 3, 256, 512, 1024], group_all=True\n        )\n\n        self.fp3 = PointNet2FP(1280, [256, 256])\n        self.fp2 = PointNet2FP(384, [256, 128])\n        # if normal_channel == true, 128+16+6+3\n        self.fp1 = PointNet2FP(128 + 16 + 6, [128, 128, 128])\n\n        self.conv1 = nn.Conv1d(128, 128, 1)\n        self.bn1 = nn.BatchNorm1d(128)\n        self.drop1 = nn.Dropout(0.5)\n        self.conv2 = nn.Conv1d(128, output_classes, 1)\n\n    def forward(self, x, cat_vec=None):\n        if x.shape[-1] > 3:\n            l0_pos = x[:, :, :3]\n            l0_feat = x\n        else:\n            l0_pos = x\n            l0_feat = x\n        # Set Abstraction layers\n        l1_pos, l1_feat = self.sa_module1(l0_pos, l0_feat)  # l1_feat: [B, N, D]\n        l2_pos, l2_feat = self.sa_module2(l1_pos, l1_feat)\n        l3_pos, l3_feat = self.sa_module3(l2_pos, l2_feat)  # [B, N, C], [B, D]\n        # Feature Propagation layers\n        l2_feat = self.fp3(\n            l2_pos, l3_pos, l2_feat, l3_feat.unsqueeze(1)\n        )  # l2_feat: [B, D, N]\n        l1_feat = self.fp2(l1_pos, l2_pos, l1_feat, l2_feat.permute(0, 2, 1))\n        l0_feat = torch.cat([cat_vec.permute(0, 2, 1), l0_pos, l0_feat], 2)\n        l0_feat = 
self.fp1(l0_pos, l1_pos, l0_feat, l1_feat.permute(0, 2, 1))\n        # FC layers\n        feat = F.relu(self.bn1(self.conv1(l0_feat)))\n        out = self.drop1(feat)\n        out = self.conv2(out)  # [B, output_classes, N]\n        return out\n\n\nclass PointNet2MSGPartSeg(nn.Module):\n    def __init__(self, output_classes, batch_size, input_dims=6):\n        super(PointNet2MSGPartSeg, self).__init__()\n\n        self.sa_msg_module1 = SAMSGModule(\n            512,\n            batch_size,\n            [0.1, 0.2, 0.4],\n            [32, 64, 128],\n            [\n                [input_dims, 32, 32, 64],\n                [input_dims, 64, 64, 128],\n                [input_dims, 64, 96, 128],\n            ],\n        )\n        self.sa_msg_module2 = SAMSGModule(\n            128,\n            batch_size,\n            [0.4, 0.8],\n            [64, 128],\n            [\n                [128 + 128 + 64 + 3, 128, 128, 256],\n                [128 + 128 + 64 + 3, 128, 196, 256],\n            ],\n        )\n        self.sa_module3 = SAModule(\n            None, batch_size, None, [512 + 3, 256, 512, 1024], group_all=True\n        )\n\n        self.fp3 = PointNet2FP(1536, [256, 256])\n        self.fp2 = PointNet2FP(576, [256, 128])\n        # if normal_channel == true, 150 + 3\n        self.fp1 = PointNet2FP(150, [128, 128])\n\n        self.conv1 = nn.Conv1d(128, 128, 1)\n        self.bn1 = nn.BatchNorm1d(128)\n        self.drop1 = nn.Dropout(0.5)\n        self.conv2 = nn.Conv1d(128, output_classes, 1)\n\n    def forward(self, x, cat_vec=None):\n        if x.shape[-1] > 3:\n            l0_pos = x[:, :, :3]\n            l0_feat = x\n        else:\n            l0_pos = x\n            l0_feat = x\n        # Set Abstraction layers\n        l1_pos, l1_feat = self.sa_msg_module1(l0_pos, l0_feat)\n        l2_pos, l2_feat = self.sa_msg_module2(l1_pos, l1_feat)\n        l3_pos, l3_feat = self.sa_module3(l2_pos, l2_feat)\n        # Feature Propagation layers\n        l2_feat = 
self.fp3(l2_pos, l3_pos, l2_feat, l3_feat.unsqueeze(1))\n        l1_feat = self.fp2(l1_pos, l2_pos, l1_feat, l2_feat.permute(0, 2, 1))\n        l0_feat = torch.cat([cat_vec.permute(0, 2, 1), l0_pos, l0_feat], 2)\n        l0_feat = self.fp1(l0_pos, l1_pos, l0_feat, l1_feat.permute(0, 2, 1))\n        # FC layers\n        feat = F.relu(self.bn1(self.conv1(l0_feat)))\n        out = self.drop1(feat)\n        out = self.conv2(out)\n        return out\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pointnet/pointnet_cls.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\n\n\nclass PointNetCls(nn.Module):\n    def __init__(\n        self,\n        output_classes,\n        input_dims=3,\n        conv1_dim=64,\n        dropout_prob=0.5,\n        use_transform=True,\n    ):\n        super(PointNetCls, self).__init__()\n        self.input_dims = input_dims\n        self.conv1 = nn.ModuleList()\n        self.conv1.append(nn.Conv1d(input_dims, conv1_dim, 1))\n        self.conv1.append(nn.Conv1d(conv1_dim, conv1_dim, 1))\n        self.conv1.append(nn.Conv1d(conv1_dim, conv1_dim, 1))\n\n        self.bn1 = nn.ModuleList()\n        self.bn1.append(nn.BatchNorm1d(conv1_dim))\n        self.bn1.append(nn.BatchNorm1d(conv1_dim))\n        self.bn1.append(nn.BatchNorm1d(conv1_dim))\n\n        self.conv2 = nn.ModuleList()\n        self.conv2.append(nn.Conv1d(conv1_dim, conv1_dim * 2, 1))\n        self.conv2.append(nn.Conv1d(conv1_dim * 2, conv1_dim * 16, 1))\n\n        self.bn2 = nn.ModuleList()\n        self.bn2.append(nn.BatchNorm1d(conv1_dim * 2))\n        self.bn2.append(nn.BatchNorm1d(conv1_dim * 16))\n\n        self.maxpool = nn.MaxPool1d(conv1_dim * 16)\n        self.pool_feat_len = conv1_dim * 16\n\n        self.mlp3 = nn.ModuleList()\n        self.mlp3.append(nn.Linear(conv1_dim * 16, conv1_dim * 8))\n        self.mlp3.append(nn.Linear(conv1_dim * 8, conv1_dim * 4))\n\n        self.bn3 = nn.ModuleList()\n        self.bn3.append(nn.BatchNorm1d(conv1_dim * 8))\n        self.bn3.append(nn.BatchNorm1d(conv1_dim * 4))\n\n        self.dropout = nn.Dropout(0.3)\n        self.mlp_out = nn.Linear(conv1_dim * 4, output_classes)\n\n        self.use_transform = use_transform\n        if use_transform:\n            self.transform1 = TransformNet(input_dims)\n            self.trans_bn1 = nn.BatchNorm1d(input_dims)\n            self.transform2 = TransformNet(conv1_dim)\n            self.trans_bn2 = 
nn.BatchNorm1d(conv1_dim)\n\n    def forward(self, x):\n        batch_size = x.shape[0]\n        h = x.permute(0, 2, 1)\n        if self.use_transform:\n            trans = self.transform1(h)\n            h = h.transpose(2, 1)\n            h = torch.bmm(h, trans)\n            h = h.transpose(2, 1)\n            h = F.relu(self.trans_bn1(h))\n\n        for conv, bn in zip(self.conv1, self.bn1):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        if self.use_transform:\n            trans = self.transform2(h)\n            h = h.transpose(2, 1)\n            h = torch.bmm(h, trans)\n            h = h.transpose(2, 1)\n            h = F.relu(self.trans_bn2(h))\n\n        for conv, bn in zip(self.conv2, self.bn2):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        h = self.maxpool(h).view(-1, self.pool_feat_len)\n        for mlp, bn in zip(self.mlp3, self.bn3):\n            h = mlp(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        h = self.dropout(h)\n        out = self.mlp_out(h)\n        return out\n\n\nclass TransformNet(nn.Module):\n    def __init__(self, input_dims=3, conv1_dim=64):\n        super(TransformNet, self).__init__()\n        self.conv = nn.ModuleList()\n        self.conv.append(nn.Conv1d(input_dims, conv1_dim, 1))\n        self.conv.append(nn.Conv1d(conv1_dim, conv1_dim * 2, 1))\n        self.conv.append(nn.Conv1d(conv1_dim * 2, conv1_dim * 16, 1))\n\n        self.bn = nn.ModuleList()\n        self.bn.append(nn.BatchNorm1d(conv1_dim))\n        self.bn.append(nn.BatchNorm1d(conv1_dim * 2))\n        self.bn.append(nn.BatchNorm1d(conv1_dim * 16))\n\n        self.maxpool = nn.MaxPool1d(conv1_dim * 16)\n        self.pool_feat_len = conv1_dim * 16\n\n        self.mlp2 = nn.ModuleList()\n        self.mlp2.append(nn.Linear(conv1_dim * 16, conv1_dim * 8))\n        self.mlp2.append(nn.Linear(conv1_dim * 8, conv1_dim * 4))\n\n        self.bn2 = nn.ModuleList()\n        
self.bn2.append(nn.BatchNorm1d(conv1_dim * 8))\n        self.bn2.append(nn.BatchNorm1d(conv1_dim * 4))\n\n        self.input_dims = input_dims\n        self.mlp_out = nn.Linear(conv1_dim * 4, input_dims * input_dims)\n\n    def forward(self, h):\n        batch_size = h.shape[0]\n        for conv, bn in zip(self.conv, self.bn):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        h = self.maxpool(h).view(-1, self.pool_feat_len)\n        for mlp, bn in zip(self.mlp2, self.bn2):\n            h = mlp(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        out = self.mlp_out(h)\n\n        iden = Variable(\n            torch.from_numpy(\n                np.eye(self.input_dims).flatten().astype(np.float32)\n            )\n        )\n        iden = iden.view(1, self.input_dims * self.input_dims).repeat(\n            batch_size, 1\n        )\n        if out.is_cuda:\n            iden = iden.cuda()\n        out = out + iden\n        out = out.view(-1, self.input_dims, self.input_dims)\n        return out\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pointnet/pointnet_partseg.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\n\n\nclass PointNetPartSeg(nn.Module):\n    def __init__(\n        self, output_classes, input_dims=3, num_points=2048, use_transform=True\n    ):\n        super(PointNetPartSeg, self).__init__()\n        self.input_dims = input_dims\n\n        self.conv1 = nn.ModuleList()\n        self.conv1.append(nn.Conv1d(input_dims, 64, 1))\n        self.conv1.append(nn.Conv1d(64, 128, 1))\n        self.conv1.append(nn.Conv1d(128, 128, 1))\n\n        self.bn1 = nn.ModuleList()\n        self.bn1.append(nn.BatchNorm1d(64))\n        self.bn1.append(nn.BatchNorm1d(128))\n        self.bn1.append(nn.BatchNorm1d(128))\n\n        self.conv2 = nn.ModuleList()\n        self.conv2.append(nn.Conv1d(128, 512, 1))\n\n        self.bn2 = nn.ModuleList()\n        self.bn2.append(nn.BatchNorm1d(512))\n\n        self.conv_max = nn.Conv1d(512, 2048, 1)\n        self.bn_max = nn.BatchNorm1d(2048)\n\n        self.maxpool = nn.MaxPool1d(num_points)\n        self.pool_feat_len = 2048\n\n        self.conv3 = nn.ModuleList()\n        self.conv3.append(nn.Conv1d(2048 + 64 + 128 * 3 + 512 + 16, 256, 1))\n        self.conv3.append(nn.Conv1d(256, 256, 1))\n        self.conv3.append(nn.Conv1d(256, 128, 1))\n\n        self.bn3 = nn.ModuleList()\n        self.bn3.append(nn.BatchNorm1d(256))\n        self.bn3.append(nn.BatchNorm1d(256))\n        self.bn3.append(nn.BatchNorm1d(128))\n\n        self.conv_out = nn.Conv1d(128, output_classes, 1)\n\n        self.use_transform = use_transform\n        if use_transform:\n            self.transform1 = TransformNet(self.input_dims)\n            self.trans_bn1 = nn.BatchNorm1d(self.input_dims)\n            self.transform2 = TransformNet(128)\n            self.trans_bn2 = nn.BatchNorm1d(128)\n\n    def forward(self, x, cat_vec=None):\n        batch_size = x.shape[0]\n        h = x.permute(0, 2, 1)\n        num_points = h.shape[2]\n 
       if self.use_transform:\n            trans = self.transform1(h)\n            h = h.transpose(2, 1)\n            h = torch.bmm(h, trans)\n            h = h.transpose(2, 1)\n            h = F.relu(self.trans_bn1(h))\n\n        mid_feat = []\n        for conv, bn in zip(self.conv1, self.bn1):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n            mid_feat.append(h)\n\n        if self.use_transform:\n            trans = self.transform2(h)\n            h = h.transpose(2, 1)\n            h = torch.bmm(h, trans)\n            h = h.transpose(2, 1)\n            h = F.relu(self.trans_bn2(h))\n            mid_feat.append(h)\n\n        for conv, bn in zip(self.conv2, self.bn2):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n            mid_feat.append(h)\n\n        h = self.conv_max(h)\n        h = self.bn_max(h)\n        h = self.maxpool(h).view(batch_size, -1, 1).repeat(1, 1, num_points)\n        mid_feat.append(h)\n        if cat_vec is not None:\n            mid_feat.append(cat_vec)\n        h = torch.cat(mid_feat, 1)\n        for conv, bn in zip(self.conv3, self.bn3):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        out = self.conv_out(h)\n        return out\n\n\nclass TransformNet(nn.Module):\n    def __init__(self, input_dims=3, num_points=2048):\n        super(TransformNet, self).__init__()\n        self.conv = nn.ModuleList()\n        self.conv.append(nn.Conv1d(input_dims, 64, 1))\n        self.conv.append(nn.Conv1d(64, 128, 1))\n        self.conv.append(nn.Conv1d(128, 1024, 1))\n\n        self.bn = nn.ModuleList()\n        self.bn.append(nn.BatchNorm1d(64))\n        self.bn.append(nn.BatchNorm1d(128))\n        self.bn.append(nn.BatchNorm1d(1024))\n\n        self.maxpool = nn.MaxPool1d(num_points)\n        self.pool_feat_len = 1024\n\n        self.mlp2 = nn.ModuleList()\n        self.mlp2.append(nn.Linear(1024, 512))\n        self.mlp2.append(nn.Linear(512, 
256))\n\n        self.bn2 = nn.ModuleList()\n        self.bn2.append(nn.BatchNorm1d(512))\n        self.bn2.append(nn.BatchNorm1d(256))\n\n        self.input_dims = input_dims\n        self.mlp_out = nn.Linear(256, input_dims * input_dims)\n\n    def forward(self, h):\n        batch_size = h.shape[0]\n        for conv, bn in zip(self.conv, self.bn):\n            h = conv(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        h = self.maxpool(h).view(-1, self.pool_feat_len)\n        for mlp, bn in zip(self.mlp2, self.bn2):\n            h = mlp(h)\n            h = bn(h)\n            h = F.relu(h)\n\n        out = self.mlp_out(h)\n\n        iden = Variable(\n            torch.from_numpy(\n                np.eye(self.input_dims).flatten().astype(np.float32)\n            )\n        )\n        iden = iden.view(1, self.input_dims * self.input_dims).repeat(\n            batch_size, 1\n        )\n        if out.is_cuda:\n            iden = iden.cuda()\n        out = out + iden\n        out = out.view(-1, self.input_dims, self.input_dims)\n        return out\n\n\nclass PartSegLoss(nn.Module):\n    def __init__(self, eps=0.2):\n        super(PartSegLoss, self).__init__()\n        self.eps = eps\n        self.loss = nn.CrossEntropyLoss()\n\n    def forward(self, logits, y):\n        num_classes = logits.shape[1]\n        logits = logits.permute(0, 2, 1).contiguous().view(-1, num_classes)\n        loss = self.loss(logits, y)\n        return loss\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pointnet/provider.py",
    "content": "\"\"\"\nAdapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py\n\"\"\"\nimport numpy as np\n\n\ndef normalize_data(batch_data):\n    \"\"\"Normalize the batch data, use coordinates of the block centered at origin,\n    Input:\n        BxNxC array\n    Output:\n        BxNxC array\n    \"\"\"\n    B, N, C = batch_data.shape\n    normal_data = np.zeros((B, N, C))\n    for b in range(B):\n        pc = batch_data[b]\n        centroid = np.mean(pc, axis=0)\n        pc = pc - centroid\n        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n        pc = pc / m\n        normal_data[b] = pc\n    return normal_data\n\n\ndef shuffle_data(data, labels):\n    \"\"\"Shuffle data and labels.\n    Input:\n      data: B,N,... numpy array\n      label: B,... numpy array\n    Return:\n      shuffled data, label and shuffle indices\n    \"\"\"\n    idx = np.arange(len(labels))\n    np.random.shuffle(idx)\n    return data[idx, ...], labels[idx], idx\n\n\ndef shuffle_points(batch_data):\n    \"\"\"Shuffle orders of points in each point cloud -- changes FPS behavior.\n    Use the same shuffling idx for the entire batch.\n    Input:\n        BxNxC array\n    Output:\n        BxNxC array\n    \"\"\"\n    idx = np.arange(batch_data.shape[1])\n    np.random.shuffle(idx)\n    return batch_data[:, idx, :]\n\n\ndef rotate_point_cloud(batch_data):\n    \"\"\"Randomly rotate the point clouds to augument the dataset\n    rotation is per shape based along up direction\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n  
      )\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_point_cloud_z(batch_data):\n    \"\"\"Randomly rotate the point clouds to augument the dataset\n    rotation is per shape based along up direction\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]\n        )\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_point_cloud_with_normal(batch_xyz_normal):\n    \"\"\"Randomly rotate XYZ, normal point cloud.\n    Input:\n        batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal\n    Output:\n        B,N,6, rotated XYZ, normal point cloud\n    \"\"\"\n    for k in range(batch_xyz_normal.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n        )\n        shape_pc = batch_xyz_normal[k, :, 0:3]\n        shape_normal = batch_xyz_normal[k, :, 3:6]\n        batch_xyz_normal[k, :, 0:3] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n        batch_xyz_normal[k, :, 3:6] = np.dot(\n            shape_normal.reshape((-1, 3)), rotation_matrix\n        )\n    return batch_xyz_normal\n\n\ndef 
rotate_perturbation_point_cloud_with_normal(\n    batch_data, angle_sigma=0.06, angle_clip=0.18\n):\n    \"\"\"Randomly perturb the point clouds by small rotations\n    Input:\n      BxNx6 array, original batch of point clouds and point normals\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        angles = np.clip(\n            angle_sigma * np.random.randn(3), -angle_clip, angle_clip\n        )\n        Rx = np.array(\n            [\n                [1, 0, 0],\n                [0, np.cos(angles[0]), -np.sin(angles[0])],\n                [0, np.sin(angles[0]), np.cos(angles[0])],\n            ]\n        )\n        Ry = np.array(\n            [\n                [np.cos(angles[1]), 0, np.sin(angles[1])],\n                [0, 1, 0],\n                [-np.sin(angles[1]), 0, np.cos(angles[1])],\n            ]\n        )\n        Rz = np.array(\n            [\n                [np.cos(angles[2]), -np.sin(angles[2]), 0],\n                [np.sin(angles[2]), np.cos(angles[2]), 0],\n                [0, 0, 1],\n            ]\n        )\n        R = np.dot(Rz, np.dot(Ry, Rx))\n        shape_pc = batch_data[k, :, 0:3]\n        shape_normal = batch_data[k, :, 3:6]\n        rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)\n        rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)\n    return rotated_data\n\n\ndef rotate_point_cloud_by_angle(batch_data, rotation_angle):\n    \"\"\"Rotate the point cloud along up direction with certain angle.\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        # rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = 
np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n        )\n        shape_pc = batch_data[k, :, 0:3]\n        rotated_data[k, :, 0:3] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):\n    \"\"\"Rotate the point cloud along up direction with certain angle.\n    Input:\n      BxNx6 array, original batch of point clouds with normal\n      scalar, angle of rotation\n    Return:\n      BxNx6 array, rotated batch of point clouds iwth normal\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        # rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array(\n            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]\n        )\n        shape_pc = batch_data[k, :, 0:3]\n        shape_normal = batch_data[k, :, 3:6]\n        rotated_data[k, :, 0:3] = np.dot(\n            shape_pc.reshape((-1, 3)), rotation_matrix\n        )\n        rotated_data[k, :, 3:6] = np.dot(\n            shape_normal.reshape((-1, 3)), rotation_matrix\n        )\n    return rotated_data\n\n\ndef rotate_perturbation_point_cloud(\n    batch_data, angle_sigma=0.06, angle_clip=0.18\n):\n    \"\"\"Randomly perturb the point clouds by small rotations\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        angles = np.clip(\n            angle_sigma * np.random.randn(3), -angle_clip, angle_clip\n        )\n        Rx = np.array(\n            [\n                [1, 0, 0],\n                [0, np.cos(angles[0]), 
-np.sin(angles[0])],\n                [0, np.sin(angles[0]), np.cos(angles[0])],\n            ]\n        )\n        Ry = np.array(\n            [\n                [np.cos(angles[1]), 0, np.sin(angles[1])],\n                [0, 1, 0],\n                [-np.sin(angles[1]), 0, np.cos(angles[1])],\n            ]\n        )\n        Rz = np.array(\n            [\n                [np.cos(angles[2]), -np.sin(angles[2]), 0],\n                [np.sin(angles[2]), np.cos(angles[2]), 0],\n                [0, 0, 1],\n            ]\n        )\n        R = np.dot(Rz, np.dot(Ry, Rx))\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)\n    return rotated_data\n\n\ndef jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):\n    \"\"\"Randomly jitter points. jittering is per point.\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, jittered batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    assert clip > 0\n    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)\n    jittered_data += batch_data\n    return jittered_data\n\n\ndef shift_point_cloud(batch_data, shift_range=0.1):\n    \"\"\"Randomly shift point cloud. Shift is per point cloud.\n    Input:\n      BxNx3 array, original batch of point clouds\n    Return:\n      BxNx3 array, shifted batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))\n    for batch_index in range(B):\n        batch_data[batch_index, :, :] += shifts[batch_index, :]\n    return batch_data\n\n\ndef random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):\n    \"\"\"Randomly scale the point cloud. 
Scale is per point cloud.\n    Input:\n        BxNx3 array, original batch of point clouds\n    Return:\n        BxNx3 array, scaled batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    scales = np.random.uniform(scale_low, scale_high, B)\n    for batch_index in range(B):\n        batch_data[batch_index, :, :] *= scales[batch_index]\n    return batch_data\n\n\ndef random_point_dropout(batch_pc, max_dropout_ratio=0.875):\n    \"\"\"batch_pc: BxNx3\"\"\"\n    for b in range(batch_pc.shape[0]):\n        dropout_ratio = np.random.random() * max_dropout_ratio  # 0~0.875\n        drop_idx = np.where(\n            np.random.random((batch_pc.shape[1])) <= dropout_ratio\n        )[0]\n        if len(drop_idx) > 0:\n            dropout_ratio = (\n                np.random.random() * max_dropout_ratio\n            )  # 0~0.875 # not need\n            batch_pc[b, drop_idx, :] = batch_pc[\n                b, 0, :\n            ]  # set to the first point\n    return batch_pc\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pointnet/train_cls.py",
    "content": "import argparse\nimport os\nimport urllib\nfrom functools import partial\n\nimport dgl\n\nimport provider\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom dgl.data.utils import download, get_download_dir\nfrom ModelNetDataLoader import ModelNetDataLoader\nfrom pointnet2 import PointNet2MSGCls, PointNet2SSGCls\nfrom pointnet_cls import PointNetCls\nfrom torch.utils.data import DataLoader\n\ntorch.backends.cudnn.enabled = False\n\n\n# from dataset import ModelNet\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--model\", type=str, default=\"pointnet\")\nparser.add_argument(\"--dataset-path\", type=str, default=\"\")\nparser.add_argument(\"--load-model-path\", type=str, default=\"\")\nparser.add_argument(\"--save-model-path\", type=str, default=\"\")\nparser.add_argument(\"--num-epochs\", type=int, default=200)\nparser.add_argument(\"--num-workers\", type=int, default=8)\nparser.add_argument(\"--batch-size\", type=int, default=32)\nargs = parser.parse_args()\n\nnum_workers = args.num_workers\nbatch_size = args.batch_size\n\ndata_filename = \"modelnet40_normal_resampled.zip\"\ndownload_path = os.path.join(get_download_dir(), data_filename)\nlocal_path = args.dataset_path or os.path.join(\n    get_download_dir(), \"modelnet40_normal_resampled\"\n)\n\nif not os.path.exists(local_path):\n    download(\n        \"https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip\",\n        download_path,\n        verify_ssl=False,\n    )\n    from zipfile import ZipFile\n\n    with ZipFile(download_path) as z:\n        z.extractall(path=get_download_dir())\n\nCustomDataLoader = partial(\n    DataLoader,\n    num_workers=num_workers,\n    batch_size=batch_size,\n    shuffle=True,\n    drop_last=True,\n)\n\n\ndef train(net, opt, scheduler, train_loader, dev):\n    net.train()\n\n    total_loss = 0\n    num_batches = 0\n    total_correct = 0\n    count = 0\n    loss_f = 
nn.CrossEntropyLoss()\n    with tqdm.tqdm(train_loader, ascii=True) as tq:\n        for data, label in tq:\n            data = data.data.numpy()\n            data = provider.random_point_dropout(data)\n            data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])\n            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])\n            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])\n            data = torch.tensor(data)\n            label = label[:, 0]\n\n            num_examples = label.shape[0]\n            data, label = data.to(dev), label.to(dev).squeeze().long()\n            opt.zero_grad()\n            logits = net(data)\n            loss = loss_f(logits, label)\n            loss.backward()\n            opt.step()\n\n            _, preds = logits.max(1)\n\n            num_batches += 1\n            count += num_examples\n            loss = loss.item()\n            correct = (preds == label).sum().item()\n            total_loss += loss\n            total_correct += correct\n\n            tq.set_postfix(\n                {\n                    \"AvgLoss\": \"%.5f\" % (total_loss / num_batches),\n                    \"AvgAcc\": \"%.5f\" % (total_correct / count),\n                }\n            )\n    scheduler.step()\n\n\ndef evaluate(net, test_loader, dev):\n    net.eval()\n\n    total_correct = 0\n    count = 0\n\n    with torch.no_grad():\n        with tqdm.tqdm(test_loader, ascii=True) as tq:\n            for data, label in tq:\n                label = label[:, 0]\n                num_examples = label.shape[0]\n                data, label = data.to(dev), label.to(dev).squeeze().long()\n                logits = net(data)\n                _, preds = logits.max(1)\n\n                correct = (preds == label).sum().item()\n                total_correct += correct\n                count += num_examples\n\n                tq.set_postfix({\"AvgAcc\": \"%.5f\" % (total_correct / count)})\n\n    return total_correct / 
count\n\n\ndev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nif args.model == \"pointnet\":\n    net = PointNetCls(40, input_dims=6)\nelif args.model == \"pointnet2_ssg\":\n    net = PointNet2SSGCls(40, batch_size, input_dims=6)\nelif args.model == \"pointnet2_msg\":\n    net = PointNet2MSGCls(40, batch_size, input_dims=6)\n\nnet = net.to(dev)\nif args.load_model_path:\n    net.load_state_dict(\n        torch.load(args.load_model_path, weights_only=False, map_location=dev)\n    )\n\nopt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)\n\nscheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)\n\ntrain_dataset = ModelNetDataLoader(local_path, 1024, split=\"train\")\ntest_dataset = ModelNetDataLoader(local_path, 1024, split=\"test\")\ntrain_loader = torch.utils.data.DataLoader(\n    train_dataset,\n    batch_size=batch_size,\n    shuffle=True,\n    num_workers=num_workers,\n    drop_last=True,\n)\ntest_loader = torch.utils.data.DataLoader(\n    test_dataset,\n    batch_size=batch_size,\n    shuffle=False,\n    num_workers=num_workers,\n    drop_last=True,\n)\n\nbest_test_acc = 0\n\nfor epoch in range(args.num_epochs):\n    train(net, opt, scheduler, train_loader, dev)\n    if (epoch + 1) % 1 == 0:\n        print(\"Epoch #%d Testing\" % epoch)\n        test_acc = evaluate(net, test_loader, dev)\n        if test_acc > best_test_acc:\n            best_test_acc = test_acc\n            if args.save_model_path:\n                torch.save(net.state_dict(), args.save_model_path)\n        print(\"Current test acc: %.5f (best: %.5f)\" % (test_acc, best_test_acc))\n"
  },
  {
    "path": "examples/pytorch/pointcloud/pointnet/train_partseg.py",
    "content": "import argparse\nimport os\nimport time\nimport urllib\nfrom functools import partial\n\nimport dgl\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom dgl.data.utils import download, get_download_dir\nfrom pointnet2_partseg import PointNet2MSGPartSeg, PointNet2SSGPartSeg\nfrom pointnet_partseg import PartSegLoss, PointNetPartSeg\nfrom ShapeNet import ShapeNet\nfrom torch.utils.data import DataLoader\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--model\", type=str, default=\"pointnet\")\nparser.add_argument(\"--dataset-path\", type=str, default=\"\")\nparser.add_argument(\"--load-model-path\", type=str, default=\"\")\nparser.add_argument(\"--save-model-path\", type=str, default=\"\")\nparser.add_argument(\"--num-epochs\", type=int, default=250)\nparser.add_argument(\"--num-workers\", type=int, default=4)\nparser.add_argument(\"--batch-size\", type=int, default=16)\nparser.add_argument(\"--tensorboard\", action=\"store_true\")\nargs = parser.parse_args()\n\nnum_workers = args.num_workers\nbatch_size = args.batch_size\n\n\ndef collate(samples):\n    graphs, cat = map(list, zip(*samples))\n    return dgl.batch(graphs), cat\n\n\nCustomDataLoader = partial(\n    DataLoader,\n    num_workers=num_workers,\n    batch_size=batch_size,\n    shuffle=True,\n    drop_last=True,\n)\n\n\ndef train(net, opt, scheduler, train_loader, dev):\n    category_list = sorted(list(shapenet.seg_classes.keys()))\n    eye_mat = np.eye(16)\n    net.train()\n\n    total_loss = 0\n    num_batches = 0\n    total_correct = 0\n    count = 0\n    start = time.time()\n    with tqdm.tqdm(train_loader, ascii=True) as tq:\n        for data, label, cat in tq:\n            num_examples = data.shape[0]\n            data = data.to(dev, dtype=torch.float)\n            label = label.to(dev, dtype=torch.long).view(-1)\n            opt.zero_grad()\n            cat_ind = [category_list.index(c) 
for c in cat]\n            # An one-hot encoding for the object category\n            cat_tensor = (\n                torch.tensor(eye_mat[cat_ind])\n                .to(dev, dtype=torch.float)\n                .repeat(1, 2048)\n            )\n            cat_tensor = cat_tensor.view(num_examples, -1, 16).permute(0, 2, 1)\n            logits = net(data, cat_tensor)\n            loss = L(logits, label)\n            loss.backward()\n            opt.step()\n\n            _, preds = logits.max(1)\n\n            count += num_examples * 2048\n            loss = loss.item()\n            total_loss += loss\n            num_batches += 1\n            correct = (preds.view(-1) == label).sum().item()\n            total_correct += correct\n\n            AvgLoss = total_loss / num_batches\n            AvgAcc = total_correct / count\n\n            tq.set_postfix(\n                {\"AvgLoss\": \"%.5f\" % AvgLoss, \"AvgAcc\": \"%.5f\" % AvgAcc}\n            )\n    scheduler.step()\n    end = time.time()\n    return data, preds, AvgLoss, AvgAcc, end - start\n\n\ndef mIoU(preds, label, cat, cat_miou, seg_classes):\n    for i in range(preds.shape[0]):\n        shape_iou = 0\n        n = len(seg_classes[cat[i]])\n        for cls in seg_classes[cat[i]]:\n            pred_set = set(np.where(preds[i, :] == cls)[0])\n            label_set = set(np.where(label[i, :] == cls)[0])\n            union = len(pred_set.union(label_set))\n            inter = len(pred_set.intersection(label_set))\n            if union == 0:\n                shape_iou += 1\n            else:\n                shape_iou += inter / union\n        shape_iou /= n\n        cat_miou[cat[i]][0] += shape_iou\n        cat_miou[cat[i]][1] += 1\n\n    return cat_miou\n\n\ndef evaluate(net, test_loader, dev, per_cat_verbose=False):\n    category_list = sorted(list(shapenet.seg_classes.keys()))\n    eye_mat = np.eye(16)\n    net.eval()\n\n    cat_miou = {}\n    for k in shapenet.seg_classes.keys():\n        cat_miou[k] = [0, 0]\n  
  miou = 0\n    count = 0\n    per_cat_miou = 0\n    per_cat_count = 0\n\n    with torch.no_grad():\n        with tqdm.tqdm(test_loader, ascii=True) as tq:\n            for data, label, cat in tq:\n                num_examples = data.shape[0]\n                data = data.to(dev, dtype=torch.float)\n                label = label.to(dev, dtype=torch.long)\n                cat_ind = [category_list.index(c) for c in cat]\n                cat_tensor = (\n                    torch.tensor(eye_mat[cat_ind])\n                    .to(dev, dtype=torch.float)\n                    .repeat(1, 2048)\n                )\n                cat_tensor = cat_tensor.view(num_examples, -1, 16).permute(\n                    0, 2, 1\n                )\n                logits = net(data, cat_tensor)\n                _, preds = logits.max(1)\n\n                cat_miou = mIoU(\n                    preds.cpu().numpy(),\n                    label.view(num_examples, -1).cpu().numpy(),\n                    cat,\n                    cat_miou,\n                    shapenet.seg_classes,\n                )\n                for _, v in cat_miou.items():\n                    if v[1] > 0:\n                        miou += v[0]\n                        count += v[1]\n                        per_cat_miou += v[0] / v[1]\n                        per_cat_count += 1\n                tq.set_postfix(\n                    {\n                        \"mIoU\": \"%.5f\" % (miou / count),\n                        \"per Category mIoU\": \"%.5f\" % (miou / count),\n                    }\n                )\n    if per_cat_verbose:\n        print(\"Per-Category mIoU:\")\n        for k, v in cat_miou.items():\n            if v[1] > 0:\n                print(\"%s mIoU=%.5f\" % (k, v[0] / v[1]))\n            else:\n                print(\"%s mIoU=%.5f\" % (k, 1))\n    return miou / count, per_cat_miou / per_cat_count\n\n\ndev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n# dev = \"cpu\"\nif args.model 
== \"pointnet\":\n    net = PointNetPartSeg(50, 3, 2048)\nelif args.model == \"pointnet2_ssg\":\n    net = PointNet2SSGPartSeg(50, batch_size, input_dims=6)\nelif args.model == \"pointnet2_msg\":\n    net = PointNet2MSGPartSeg(50, batch_size, input_dims=6)\n\nnet = net.to(dev)\nif args.load_model_path:\n    net.load_state_dict(\n        torch.load(args.load_model_path, weights_only=False, map_location=dev)\n    )\n\nopt = optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-4)\nscheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.5)\nL = PartSegLoss()\n\nshapenet = ShapeNet(2048, normal_channel=False)\n\ntrain_loader = CustomDataLoader(shapenet.trainval())\ntest_loader = CustomDataLoader(shapenet.test())\n\n# Tensorboard\nif args.tensorboard:\n    import torchvision\n    from torch.utils.tensorboard import SummaryWriter\n    from torchvision import datasets, transforms\n\n    writer = SummaryWriter()\n# Select 50 distinct colors for different parts\ncolor_map = torch.tensor(\n    [\n        [47, 79, 79],\n        [139, 69, 19],\n        [112, 128, 144],\n        [85, 107, 47],\n        [139, 0, 0],\n        [128, 128, 0],\n        [72, 61, 139],\n        [0, 128, 0],\n        [188, 143, 143],\n        [60, 179, 113],\n        [205, 133, 63],\n        [0, 139, 139],\n        [70, 130, 180],\n        [205, 92, 92],\n        [154, 205, 50],\n        [0, 0, 139],\n        [50, 205, 50],\n        [250, 250, 250],\n        [218, 165, 32],\n        [139, 0, 139],\n        [10, 10, 10],\n        [176, 48, 96],\n        [72, 209, 204],\n        [153, 50, 204],\n        [255, 69, 0],\n        [255, 145, 0],\n        [0, 0, 205],\n        [255, 255, 0],\n        [0, 255, 0],\n        [233, 150, 122],\n        [220, 20, 60],\n        [0, 191, 255],\n        [160, 32, 240],\n        [192, 192, 192],\n        [173, 255, 47],\n        [218, 112, 214],\n        [216, 191, 216],\n        [255, 127, 80],\n        [255, 0, 255],\n        [100, 149, 237],\n        [128, 
128, 128],\n        [221, 160, 221],\n        [144, 238, 144],\n        [123, 104, 238],\n        [255, 160, 122],\n        [175, 238, 238],\n        [238, 130, 238],\n        [127, 255, 212],\n        [255, 218, 185],\n        [255, 105, 180],\n    ]\n)\n\n\n# paint each point according to its pred\ndef paint(batched_points):\n    B, N = batched_points.shape\n    colored = color_map[batched_points].squeeze(2)\n    return colored\n\n\nbest_test_miou = 0\nbest_test_per_cat_miou = 0\n\nfor epoch in range(args.num_epochs):\n    data, preds, AvgLoss, AvgAcc, training_time = train(\n        net, opt, scheduler, train_loader, dev\n    )\n    if (epoch + 1) % 5 == 0:\n        print(\"Epoch #%d Testing\" % epoch)\n        test_miou, test_per_cat_miou = evaluate(\n            net, test_loader, dev, (epoch + 1) % 5 == 0\n        )\n        if test_miou > best_test_miou:\n            best_test_miou = test_miou\n            best_test_per_cat_miou = test_per_cat_miou\n            if args.save_model_path:\n                torch.save(net.state_dict(), args.save_model_path)\n        print(\n            \"Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)\"\n            % (\n                test_miou,\n                best_test_miou,\n                test_per_cat_miou,\n                best_test_per_cat_miou,\n            )\n        )\n    # Tensorboard\n    if args.tensorboard:\n        colored = paint(preds)\n        writer.add_mesh(\n            \"data\", vertices=data, colors=colored, global_step=epoch\n        )\n        writer.add_scalar(\n            \"training time for one epoch\", training_time, global_step=epoch\n        )\n        writer.add_scalar(\"AvgLoss\", AvgLoss, global_step=epoch)\n        writer.add_scalar(\"AvgAcc\", AvgAcc, global_step=epoch)\n        if (epoch + 1) % 5 == 0:\n            writer.add_scalar(\"test mIoU\", test_miou, global_step=epoch)\n            writer.add_scalar(\n                \"best test mIoU\", best_test_miou, 
global_step=epoch\n            )\n"
  },
  {
    "path": "examples/pytorch/rect/README.md",
    "content": "# **DGL Implementation of RECT (TKDE20)**\r\n\r\nThis DGL example implements the GNN model **RECT** (or more specifically its supervised part **RECT-L**) proposed in the paper [Network Embedding with Completely-imbalanced Labels](https://ieeexplore.ieee.org/document/8979355). The authors' original implementation can be found [here](https://github.com/zhengwang100/RECT).\r\n\r\n\r\n\r\n## Example Implementor\r\n\r\nThis example was implemented by [Tingzhang Zhao](https://github.com/Fizyhsp) when he was an undergraduate at USTB.\r\n\r\n\r\n\r\n## **Dataset and experimental setting**\r\n\r\nTwo DGL's build-in datasets (Cora and Citeseer) with their default train/val/test settings are used in this example. In addition, as this paper considers the zero-shot (i.e., completely-imbalanced) label setting, those \"unseen\" classes should be removed from the training set, as suggested in the paper. In this example, in each dataset, we simply remove the 2-3 classes (i.e., these 2-3 classes are unseen classes) from the labeled training set. Then, we obtain graph embedding results by different models. 
Finally, with the obtained embedding results and the original balanced labels, we train a logistic regression classifier to evaluate the model performance.\r\n\r\n\r\n\r\n## **Usage** \r\n\r\n`python main.py --dataset cora --gpu 0 --model-opt RECT-L --removed-class 0 1 2` #reproducing the RECT-L on \"cora\" datasets in the zero-shot label setting using GPU\r\n\r\n`python main.py --dataset cora --gpu 0 --model-opt GCN --removed-class 0 1 2` #reproducing the GCN on \"cora\" datasets in the zero-shot label setting using GPU\r\n\r\n`python main.py --dataset cora --gpu 0 --model-opt NodeFeats --removed-class 0 1 2` # evaluating the original node features using GPU\r\n\r\n\r\n\r\n## **Performance**\r\n\r\nThe performance results are as follows:\r\n\r\n| **Datasets/Models** | **NodeFeats** | **GCN** | **RECT-L** |\r\n| :-----------------: | :-----------: | :-----: | :--------: |\r\n|      **Cora**       |     47.56     |  51.26  | **68.60**  |\r\n|    **Citeseer**     |     42.04     |  37.55  | **56.32**  |\r\n\r\n<center>Table 1：node classification results with the first three classes as \"unseen\"</center>\r\n<br/><br/>\r\n\r\n\r\n| **Datasets/Models** | **NodeFeats** | **GCN** | **RECT-L** |\r\n| :-----------------: | :-----------: | :-----: | :--------: |\r\n|      **Cora**       |     47.56     |  56.91  | **69.30**  |\r\n|    **Citeseer**     |     42.04     |  45.69  | **61.85**  |\r\n\r\n<center>Table 2：node classification results with the last two classes as \"unseen\"</center>\r\n<br/>\r\n"
  },
  {
    "path": "examples/pytorch/rect/classify.py",
    "content": "from statistics import mean\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\n\r\nclass LogisticRegressionClassifier(nn.Module):\r\n    \"\"\"Define a logistic regression classifier to evaluate the quality of embedding results\"\"\"\r\n\r\n    def __init__(self, nfeat, nclass):\r\n        super(LogisticRegressionClassifier, self).__init__()\r\n        self.lrc = nn.Linear(nfeat, nclass)\r\n\r\n    def forward(self, x):\r\n        preds = self.lrc(x)\r\n        return preds\r\n\r\n\r\ndef _evaluate(model, features, labels, test_mask):\r\n    model.eval()\r\n    with torch.no_grad():\r\n        logits = model(features)\r\n        logits = logits[test_mask]\r\n        labels = labels[test_mask]\r\n        _, indices = torch.max(logits, dim=1)\r\n        correct = torch.sum(indices == labels)\r\n        return correct.item() * 1.0 / len(labels)\r\n\r\n\r\ndef _train_test_with_lrc(model, features, labels, train_mask, test_mask):\r\n    \"\"\"Under the pre-defined balanced train/test label setting, train a lrc to evaluate the embedding results.\"\"\"\r\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.2, weight_decay=5e-06)\r\n    for _ in range(100):\r\n        model.train()\r\n        optimizer.zero_grad()\r\n        output = model(features)\r\n        loss_train = F.cross_entropy(output[train_mask], labels[train_mask])\r\n        loss_train.backward()\r\n        optimizer.step()\r\n    return _evaluate(\r\n        model=model, features=features, labels=labels, test_mask=test_mask\r\n    )\r\n\r\n\r\ndef evaluate_embeds(\r\n    features, labels, train_mask, test_mask, n_classes, cuda, test_times=10\r\n):\r\n    print(\r\n        \"Training a logistic regression classifier with the pre-defined train/test split setting ...\"\r\n    )\r\n    res_list = []\r\n    for _ in range(test_times):\r\n        model = LogisticRegressionClassifier(\r\n            nfeat=features.shape[1], nclass=n_classes\r\n        )\r\n  
      if cuda:\r\n            model.cuda()\r\n        res = _train_test_with_lrc(\r\n            model=model,\r\n            features=features,\r\n            labels=labels,\r\n            train_mask=train_mask,\r\n            test_mask=test_mask,\r\n        )\r\n        res_list.append(res)\r\n    return mean(res_list)\r\n"
  },
  {
    "path": "examples/pytorch/rect/label_utils.py",
    "content": "from collections import defaultdict\r\n\r\nimport numpy as np\r\nimport torch\r\n\r\n\r\ndef remove_unseen_classes_from_training(train_mask, labels, removed_class):\r\n    \"\"\"Remove the unseen classes (the first three classes by default) to get the zero-shot (i.e., completely imbalanced) label setting\r\n    Input: train_mask, labels, removed_class\r\n    Output: train_mask_zs: the bool list only containing seen classes\r\n    \"\"\"\r\n    train_mask_zs = train_mask.clone()\r\n    for i in range(train_mask_zs.numel()):\r\n        if train_mask_zs[i] == 1 and (labels[i].item() in removed_class):\r\n            train_mask_zs[i] = 0\r\n    return train_mask_zs\r\n\r\n\r\ndef get_class_set(labels):\r\n    \"\"\"Get the class set.\r\n    Input: labels [l, [c1, c2, ..]]\r\n    Output：the labeled class set dict_keys([k1, k2, ..])\r\n    \"\"\"\r\n    mydict = {}\r\n    for y in labels:\r\n        for label in y:\r\n            mydict[int(label)] = 1\r\n    return mydict.keys()\r\n\r\n\r\ndef get_label_attributes(train_mask_zs, nodeids, labellist, features):\r\n    \"\"\"Get the class-center (semanic knowledge) of each seen class.\r\n    Suppose a node i is labeled as c, then attribute[c] += node_i_attribute, finally mean(attribute[c])\r\n    Input: train_mask_zs, nodeids, labellist, features\r\n    Output: label_attribute{}: label -> average_labeled_node_features (class centers)\r\n    \"\"\"\r\n    _, feat_num = features.shape\r\n    labels = get_class_set(labellist)\r\n    label_attribute_nodes = defaultdict(list)\r\n    for nodeid, labels in zip(nodeids, labellist):\r\n        for label in labels:\r\n            label_attribute_nodes[int(label)].append(int(nodeid))\r\n    label_attribute = {}\r\n    for label in label_attribute_nodes.keys():\r\n        nodes = label_attribute_nodes[int(label)]\r\n        selected_features = features[nodes, :]\r\n        label_attribute[int(label)] = np.mean(selected_features, axis=0)\r\n    return 
label_attribute\r\n\r\n\r\ndef get_labeled_nodes_label_attribute(train_mask_zs, labels, features, cuda):\r\n    \"\"\"Replace the original labels by their class-centers.\r\n    For each label c in the training set, the following operations will be performed:\r\n    Get label_attribute{} through function get_label_attributes, then res[i, :] = label_attribute[c]\r\n    Input: train_mask_zs, labels, features\r\n    Output: Y_{semantic} [l, ft]: tensor\r\n    \"\"\"\r\n    X = torch.LongTensor(range(features.shape[0]))\r\n    nodeids = []\r\n    labellist = []\r\n    for i in X[train_mask_zs].numpy().tolist():\r\n        nodeids.append(str(i))\r\n    for i in labels[train_mask_zs].cpu().numpy().tolist():\r\n        labellist.append([str(i)])\r\n\r\n    # 1. get the semantic knowledge (class centers) of all seen classes\r\n    label_attribute = get_label_attributes(\r\n        train_mask_zs=train_mask_zs,\r\n        nodeids=nodeids,\r\n        labellist=labellist,\r\n        features=features.cpu().numpy(),\r\n    )\r\n\r\n    # 2. replace original labels by their class centers (semantic knowledge)\r\n    res = np.zeros([len(nodeids), features.shape[1]])\r\n    for i, labels in enumerate(labellist):\r\n        # support mutiple labels\r\n        c = len(labels)\r\n        temp = np.zeros([c, features.shape[1]])\r\n        for ii, label in enumerate(labels):\r\n            temp[ii, :] = label_attribute[int(label)]\r\n        temp = np.mean(temp, axis=0)\r\n        res[i, :] = temp\r\n    if cuda:\r\n        res = torch.FloatTensor(res).cuda()\r\n    else:\r\n        res = torch.FloatTensor(res)\r\n    return res\r\n"
  },
  {
    "path": "examples/pytorch/rect/main.py",
    "content": "import torch\r\nimport torch.nn as nn\r\nfrom classify import evaluate_embeds\r\nfrom label_utils import (\r\n    get_labeled_nodes_label_attribute,\r\n    remove_unseen_classes_from_training,\r\n)\r\nfrom model import GCN, RECT_L\r\nfrom utils import load_data, process_classids, svd_feature\r\n\r\n\r\ndef main(args):\r\n    g, features, labels, train_mask, test_mask, n_classes, cuda = load_data(\r\n        args\r\n    )\r\n    # adopt any number of classes as the unseen classes (the first three classes by default)\r\n    removed_class = args.removed_class\r\n    if len(removed_class) > n_classes:\r\n        raise ValueError(\r\n            \"unseen number is greater than the number of classes: {}\".format(\r\n                len(removed_class)\r\n            )\r\n        )\r\n    for i in removed_class:\r\n        if i not in labels:\r\n            raise ValueError(\"class out of bounds: {}\".format(i))\r\n\r\n    # remove these unseen classes from the training set, to construct the zero-shot label setting\r\n    train_mask_zs = remove_unseen_classes_from_training(\r\n        train_mask=train_mask, labels=labels, removed_class=removed_class\r\n    )\r\n    print(\r\n        \"after removing the unseen classes, seen class labeled node num:\",\r\n        sum(train_mask_zs).item(),\r\n    )\r\n\r\n    if args.model_opt == \"RECT-L\":\r\n        model = RECT_L(\r\n            g=g,\r\n            in_feats=args.n_hidden,\r\n            n_hidden=args.n_hidden,\r\n            activation=nn.PReLU(),\r\n        )\r\n\r\n        if cuda:\r\n            model.cuda()\r\n        features = svd_feature(features=features, d=args.n_hidden)\r\n        attribute_labels = get_labeled_nodes_label_attribute(\r\n            train_mask_zs=train_mask_zs,\r\n            labels=labels,\r\n            features=features,\r\n            cuda=cuda,\r\n        )\r\n        loss_fcn = nn.MSELoss(reduction=\"sum\")\r\n        optimizer = torch.optim.Adam(\r\n            
model.parameters(), lr=args.lr, weight_decay=args.weight_decay\r\n        )\r\n\r\n        for epoch in range(args.n_epochs):\r\n            model.train()\r\n            optimizer.zero_grad()\r\n            logits = model(features)\r\n            loss_train = loss_fcn(attribute_labels, logits[train_mask_zs])\r\n            print(\r\n                \"Epoch {:d} | Train Loss {:.5f}\".format(\r\n                    epoch + 1, loss_train.item()\r\n                )\r\n            )\r\n            loss_train.backward()\r\n            optimizer.step()\r\n        model.eval()\r\n        embeds = model.embed(features)\r\n\r\n    elif args.model_opt == \"GCN\":\r\n        model = GCN(\r\n            g=g,\r\n            in_feats=features.shape[1],\r\n            n_hidden=args.n_hidden,\r\n            n_classes=n_classes - len(removed_class),\r\n            activation=nn.PReLU(),\r\n            dropout=args.dropout,\r\n        )\r\n\r\n        if cuda:\r\n            model.cuda()\r\n        loss_fcn = nn.CrossEntropyLoss()\r\n        optimizer = torch.optim.Adam(\r\n            model.parameters(), lr=args.lr, weight_decay=args.weight_decay\r\n        )\r\n\r\n        for epoch in range(args.n_epochs):\r\n            model.train()\r\n            logits = model(features)\r\n            labels_train = process_classids(labels_temp=labels[train_mask_zs])\r\n            loss_train = loss_fcn(logits[train_mask_zs], labels_train)\r\n            optimizer.zero_grad()\r\n            print(\r\n                \"Epoch {:d} | Train Loss {:.5f}\".format(\r\n                    epoch + 1, loss_train.item()\r\n                )\r\n            )\r\n            loss_train.backward()\r\n            optimizer.step()\r\n        model.eval()\r\n        embeds = model.embed(features)\r\n\r\n    elif args.model_opt == \"NodeFeats\":\r\n        embeds = svd_feature(features)\r\n\r\n    # evaluate the quality of embedding results with the original balanced labels, to assess the model performance (as 
suggested in the paper)\r\n    res = evaluate_embeds(\r\n        features=embeds,\r\n        labels=labels,\r\n        train_mask=train_mask,\r\n        test_mask=test_mask,\r\n        n_classes=n_classes,\r\n        cuda=cuda,\r\n    )\r\n    print(\"Test Accuracy of {:s}: {:.4f}\".format(args.model_opt, res))\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    import argparse\r\n\r\n    parser = argparse.ArgumentParser(description=\"MODEL\")\r\n    parser.add_argument(\r\n        \"--model-opt\",\r\n        type=str,\r\n        default=\"RECT-L\",\r\n        choices=[\"RECT-L\", \"GCN\", \"NodeFeats\"],\r\n        help=\"model option\",\r\n    )\r\n    parser.add_argument(\r\n        \"--dataset\",\r\n        type=str,\r\n        default=\"cora\",\r\n        choices=[\"cora\", \"citeseer\"],\r\n        help=\"dataset\",\r\n    )\r\n    parser.add_argument(\r\n        \"--dropout\", type=float, default=0.0, help=\"dropout probability\"\r\n    )\r\n    parser.add_argument(\"--gpu\", type=int, default=0, help=\"gpu\")\r\n    parser.add_argument(\r\n        \"--removed-class\",\r\n        type=int,\r\n        nargs=\"*\",\r\n        default=[0, 1, 2],\r\n        help=\"remove the unseen classes\",\r\n    )\r\n    parser.add_argument(\"--lr\", type=float, default=1e-3, help=\"learning rate\")\r\n    parser.add_argument(\r\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\r\n    )\r\n    parser.add_argument(\r\n        \"--n-hidden\", type=int, default=200, help=\"number of hidden gcn units\"\r\n    )\r\n    parser.add_argument(\r\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 loss\"\r\n    )\r\n    args = parser.parse_args()\r\n\r\n    main(args)\r\n"
  },
  {
    "path": "examples/pytorch/rect/model.py",
    "content": "import torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\nfrom dgl.nn.pytorch import GraphConv\r\n\r\n\r\nclass GCN(nn.Module):\r\n    def __init__(self, g, in_feats, n_hidden, n_classes, activation, dropout):\r\n        super(GCN, self).__init__()\r\n        self.g = g\r\n        self.gcn_1 = GraphConv(in_feats, n_hidden, activation=activation)\r\n        self.gcn_2 = GraphConv(n_hidden, n_classes)\r\n        self.dropout = nn.Dropout(p=dropout)\r\n\r\n    def forward(self, features):\r\n        h = self.gcn_1(self.g, features)\r\n        h = self.dropout(h)\r\n        preds = self.gcn_2(self.g, h)\r\n        return preds\r\n\r\n    def embed(self, inputs):\r\n        h_1 = self.gcn_1(self.g, inputs)\r\n        return h_1.detach()\r\n\r\n\r\nclass RECT_L(nn.Module):\r\n    def __init__(self, g, in_feats, n_hidden, activation, dropout=0.0):\r\n        super(RECT_L, self).__init__()\r\n        self.g = g\r\n        self.gcn_1 = GraphConv(in_feats, n_hidden, activation=activation)\r\n        self.fc = nn.Linear(n_hidden, in_feats)\r\n        self.dropout = dropout\r\n        nn.init.xavier_uniform_(self.fc.weight.data)\r\n\r\n    def forward(self, inputs):\r\n        h_1 = self.gcn_1(self.g, inputs)\r\n        h_1 = F.dropout(h_1, p=self.dropout, training=self.training)\r\n        preds = self.fc(h_1)\r\n        return preds\r\n\r\n    # Detach the return variables\r\n    def embed(self, inputs):\r\n        h_1 = self.gcn_1(self.g, inputs)\r\n        return h_1.detach()\r\n"
  },
  {
    "path": "examples/pytorch/rect/utils.py",
    "content": "import dgl\nimport torch\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset\n\n\ndef load_data(args):\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n    else:\n        cuda = True\n        g = g.int().to(args.gpu)\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    g = dgl.add_self_loop(g)\n    return g, features, labels, train_mask, test_mask, data.num_classes, cuda\n\n\ndef svd_feature(features, d=200):\n    \"\"\"Get 200-dimensional node features, to avoid curse of dimensionality\"\"\"\n    if features.shape[1] <= d:\n        return features\n    U, S, VT = torch.svd(features)\n    res = torch.mm(U[:, 0:d], torch.diag(S[0:d]))\n    return res\n\n\ndef process_classids(labels_temp):\n    \"\"\"Reorder the remaining classes with unseen classes removed.\n    Input: the label only removing unseen classes\n    Output: the label with reordered classes\n    \"\"\"\n    labeldict = {}\n    num = 0\n    for i in labels_temp:\n        labeldict[int(i)] = 1\n    labellist = sorted(labeldict)\n    for label in labellist:\n        labeldict[int(label)] = num\n        num = num + 1\n    for i in range(labels_temp.numel()):\n        labels_temp[i] = labeldict[int(labels_temp[i])]\n    return labels_temp\n"
  },
  {
    "path": "examples/pytorch/rgat/README.md",
    "content": "Relational Graph Attention Networks (RGAT)\n==============\nThis is an adaptation of RGCN where graph convolution is replaced with graph attention.\n\nDependencies\n------------\n- torchmetrics 0.11.4\n\nInstall as follows:\n```bash\npip install torchmetrics==0.11.4\n```\n\nHow to Run\n-------\n\nRun with the following for node classification on ogbn-mag dataset\n```bash\npython train.py\n```\n\n\nSummary\n-------\n* ogbn-mag (test acc.): ~0.3647\n"
  },
  {
    "path": "examples/pytorch/rgat/train.py",
    "content": "import dgl\nimport dgl.function as fn\nimport dgl.nn as dglnn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nimport tqdm\nfrom dgl import apply_each\nfrom dgl.dataloading import DataLoader, NeighborSampler\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\nclass HeteroGAT(nn.Module):\n    def __init__(self, etypes, in_size, hid_size, out_size, n_heads=4):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        self.layers.append(\n            dglnn.HeteroGraphConv(\n                {\n                    etype: dglnn.GATConv(in_size, hid_size // n_heads, n_heads)\n                    for etype in etypes\n                }\n            )\n        )\n        self.layers.append(\n            dglnn.HeteroGraphConv(\n                {\n                    etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)\n                    for etype in etypes\n                }\n            )\n        )\n        self.layers.append(\n            dglnn.HeteroGraphConv(\n                {\n                    etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)\n                    for etype in etypes\n                }\n            )\n        )\n        self.dropout = nn.Dropout(0.5)\n        self.linear = nn.Linear(hid_size, out_size)  # Should be HeteroLinear\n\n    def forward(self, blocks, x):\n        h = x\n        for l, (layer, block) in enumerate(zip(self.layers, blocks)):\n            h = layer(block, h)\n            # One thing is that h might return tensors with zero rows if the number of dst nodes\n            # of one node type is 0.  
x.view(x.shape[0], -1) wouldn't work in this case.\n            h = apply_each(\n                h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2])\n            )\n            if l != len(self.layers) - 1:\n                h = apply_each(h, F.relu)\n                h = apply_each(h, self.dropout)\n        return self.linear(h[\"paper\"])\n\n\ndef evaluate(num_classes, model, dataloader, desc):\n    preds = []\n    labels = []\n    with torch.no_grad():\n        for input_nodes, output_nodes, blocks in tqdm.tqdm(\n            dataloader, desc=desc\n        ):\n            x = blocks[0].srcdata[\"feat\"]\n            y = blocks[-1].dstdata[\"label\"][\"paper\"][:, 0]\n            y_hat = model(blocks, x)\n            preds.append(y_hat.cpu())\n            labels.append(y.cpu())\n        preds = torch.cat(preds, 0)\n        labels = torch.cat(labels, 0)\n        acc = MF.accuracy(\n            preds, labels, task=\"multiclass\", num_classes=num_classes\n        )\n        return acc\n\n\ndef train(train_loader, val_loader, test_loader, num_classes, model):\n    # loss function and optimizer\n    loss_fcn = nn.CrossEntropyLoss()\n    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)\n\n    # training loop\n    for epoch in range(10):\n        model.train()\n        total_loss = 0\n        for it, (input_nodes, output_nodes, blocks) in enumerate(\n            tqdm.tqdm(train_dataloader, desc=\"Train\")\n        ):\n            x = blocks[0].srcdata[\"feat\"]\n            y = blocks[-1].dstdata[\"label\"][\"paper\"][:, 0]\n            y_hat = model(blocks, x)\n            loss = loss_fcn(y_hat, y)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n            total_loss += loss.item()\n        model.eval()\n        val_acc = evaluate(num_classes, model, val_dataloader, \"Val. 
\")\n        test_acc = evaluate(num_classes, model, test_dataloader, \"Test \")\n        print(\n            f\"Epoch {epoch:05d} | Loss {total_loss/(it+1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    print(\n        f\"Training with DGL built-in HeteroGraphConv using GATConv as its convolution sub-modules\"\n    )\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    # load and preprocess dataset\n    print(\"Loading data\")\n    dataset = DglNodePropPredDataset(\"ogbn-mag\")\n    graph, labels = dataset[0]\n    graph.ndata[\"label\"] = labels\n    # add reverse edges in \"cites\" relation, and add reverse edge types for the rest etypes\n    graph = dgl.AddReverse()(graph)\n    # precompute the author, topic, and institution features\n    graph.update_all(\n        fn.copy_u(\"feat\", \"m\"), fn.mean(\"m\", \"feat\"), etype=\"rev_writes\"\n    )\n    graph.update_all(\n        fn.copy_u(\"feat\", \"m\"), fn.mean(\"m\", \"feat\"), etype=\"has_topic\"\n    )\n    graph.update_all(\n        fn.copy_u(\"feat\", \"m\"), fn.mean(\"m\", \"feat\"), etype=\"affiliated_with\"\n    )\n    # find train/val/test indexes\n    split_idx = dataset.get_idx_split()\n    train_idx, val_idx, test_idx = (\n        split_idx[\"train\"],\n        split_idx[\"valid\"],\n        split_idx[\"test\"],\n    )\n    train_idx = apply_each(train_idx, lambda x: x.to(device))\n    val_idx = apply_each(val_idx, lambda x: x.to(device))\n    test_idx = apply_each(test_idx, lambda x: x.to(device))\n\n    # create RGAT model\n    in_size = graph.ndata[\"feat\"][\"paper\"].shape[1]\n    num_classes = dataset.num_classes\n    model = HeteroGAT(graph.etypes, in_size, 256, num_classes).to(device)\n\n    # dataloader + model training + testing\n    train_sampler = NeighborSampler(\n        [5, 5, 5],\n        prefetch_node_feats={k: [\"feat\"] for k in graph.ntypes},\n        
prefetch_labels={\"paper\": [\"label\"]},\n    )\n    val_sampler = NeighborSampler(\n        [10, 10, 10],\n        prefetch_node_feats={k: [\"feat\"] for k in graph.ntypes},\n        prefetch_labels={\"paper\": [\"label\"]},\n    )\n    train_dataloader = DataLoader(\n        graph,\n        train_idx,\n        train_sampler,\n        device=device,\n        batch_size=1000,\n        shuffle=True,\n        drop_last=False,\n        num_workers=0,\n        use_uva=torch.cuda.is_available(),\n    )\n    val_dataloader = DataLoader(\n        graph,\n        val_idx,\n        val_sampler,\n        device=device,\n        batch_size=1000,\n        shuffle=False,\n        drop_last=False,\n        num_workers=0,\n        use_uva=torch.cuda.is_available(),\n    )\n    test_dataloader = DataLoader(\n        graph,\n        test_idx,\n        val_sampler,\n        device=device,\n        batch_size=1000,\n        shuffle=False,\n        drop_last=False,\n        num_workers=0,\n        use_uva=torch.cuda.is_available(),\n    )\n\n    train(train_dataloader, val_dataloader, test_dataloader, num_classes, model)\n"
  },
  {
    "path": "examples/pytorch/rgcn/README.md",
    "content": "# Relational-GCN\n\n* Paper: [Modeling Relational Data with Graph Convolutional Networks](https://arxiv.org/abs/1703.06103)\n* Author's code for entity classification: [https://github.com/tkipf/relational-gcn](https://github.com/tkipf/relational-gcn)\n* Author's code for link prediction: [https://github.com/MichSchli/RelationPrediction](https://github.com/MichSchli/RelationPrediction)\n\n### Dependencies\n- rdflib\n- torchmetrics 0.11.4\n\nInstall as follows:\n```bash\npip install rdflib\npip install torchmetrics==0.11.4\n```\n\nHow to run\n-------\n\n### Entity Classification\n\nRun with the following for entity classification (available datasets: aifb (default), mutag, bgs, and am)\n```bash\npython3 entity.py --dataset aifb\n```\n\nFor mini-batch training, run with the following (available datasets are the same as above)\n```bash\npython3 entity_sample.py --dataset aifb\n```\nFor multi-gpu training (with sampling), run with the following (same datasets and GPU IDs separated by comma)\n```bash\npython3 entity_sample_multi_gpu.py --dataset aifb --gpu 0,1\n```\n\n### Link Prediction\n\nRun with the following for link prediction on dataset FB15k-237 with filtered-MRR\n\n```bash\npython link.py\n```\n> **_NOTE:_** By default, we use uniform edge sampling instead of neighbor-based edge sampling as in [author's code](https://github.com/MichSchli/RelationPrediction). In practice, we find that it can achieve similar MRR.\n\n\nSummary\n-------\n\n### Entity Classification\n\n| Dataset       | Full-graph | Mini-batch\n| ------------- | -------    |  ------\n| aifb          | ~0.85      | ~0.82\n| mutag         | ~0.70      | ~0.50\n| bgs           | ~0.86      | ~0.64\n| am            | ~0.78      | ~0.42\n\n### Link Prediction\n| Dataset       | Best MRR\n| ------------- | -------\n| FB15k-237     | ~0.2397\n"
  },
  {
    "path": "examples/pytorch/rgcn/entity.py",
    "content": "import argparse\n\nimport dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\nfrom dgl.nn.pytorch import RelGraphConv\nfrom torchmetrics.functional import accuracy\n\n\nclass RGCN(nn.Module):\n    def __init__(self, num_nodes, h_dim, out_dim, num_rels):\n        super().__init__()\n        self.emb = nn.Embedding(num_nodes, h_dim)\n        # two-layer RGCN\n        self.conv1 = RelGraphConv(\n            h_dim,\n            h_dim,\n            num_rels,\n            regularizer=\"basis\",\n            num_bases=num_rels,\n            self_loop=False,\n        )\n        self.conv2 = RelGraphConv(\n            h_dim,\n            out_dim,\n            num_rels,\n            regularizer=\"basis\",\n            num_bases=num_rels,\n            self_loop=False,\n        )\n\n    def forward(self, g):\n        x = self.emb.weight\n        h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata[\"norm\"]))\n        h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata[\"norm\"])\n        return h\n\n\ndef evaluate(g, target_idx, labels, num_classes, test_mask, model):\n    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()\n    model.eval()\n    with torch.no_grad():\n        logits = model(g)\n    logits = logits[target_idx]\n    return accuracy(\n        logits[test_idx].argmax(dim=1),\n        labels[test_idx],\n        task=\"multiclass\",\n        num_classes=num_classes,\n    ).item()\n\n\ndef train(g, target_idx, labels, num_classes, train_mask, model):\n    # define train idx, loss function and optimizer\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n\n    model.train()\n    for epoch in range(50):\n        logits = model(g)\n        logits = logits[target_idx]\n        loss = loss_fcn(logits[train_idx], 
labels[train_idx])\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        acc = accuracy(\n            logits[train_idx].argmax(dim=1),\n            labels[train_idx],\n            task=\"multiclass\",\n            num_classes=num_classes,\n        ).item()\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Train Accuracy {:.4f} \".format(\n                epoch, loss.item(), acc\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"RGCN for entity classification\"\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"aifb\",\n        help=\"Dataset name ('aifb', 'mutag', 'bgs', 'am').\",\n    )\n    args = parser.parse_args()\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    print(f\"Training with DGL built-in RGCN module.\")\n\n    # load and preprocess dataset\n    if args.dataset == \"aifb\":\n        data = AIFBDataset()\n    elif args.dataset == \"mutag\":\n        data = MUTAGDataset()\n    elif args.dataset == \"bgs\":\n        data = BGSDataset()\n    elif args.dataset == \"am\":\n        data = AMDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n    g = data[0]\n    g = g.int().to(device)\n    num_rels = len(g.canonical_etypes)\n    category = data.predict_category\n    labels = g.nodes[category].data.pop(\"labels\")\n    train_mask = g.nodes[category].data.pop(\"train_mask\")\n    test_mask = g.nodes[category].data.pop(\"test_mask\")\n    # calculate normalization weight for each edge, and find target category and node id\n    for cetype in g.canonical_etypes:\n        g.edges[cetype].data[\"norm\"] = dgl.norm_by_dst(g, cetype).unsqueeze(1)\n    category_id = g.ntypes.index(category)\n    g = dgl.to_homogeneous(g, edata=[\"norm\"])\n    node_ids = torch.arange(g.num_nodes()).to(device)\n    target_idx = 
node_ids[g.ndata[dgl.NTYPE] == category_id]\n    # create RGCN model\n    in_size = g.num_nodes()  # featureless with one-hot encoding\n    num_classes = data.num_classes\n    model = RGCN(in_size, 16, num_classes, num_rels).to(device)\n\n    train(g, target_idx, labels, num_classes, train_mask, model)\n    acc = evaluate(g, target_idx, labels, num_classes, test_mask, model)\n    print(\"Test accuracy {:.4f}\".format(acc))\n"
  },
  {
    "path": "examples/pytorch/rgcn/entity_sample.py",
    "content": "import argparse\n\nimport dgl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\nfrom dgl.dataloading import DataLoader, MultiLayerNeighborSampler\nfrom dgl.nn.pytorch import RelGraphConv\nfrom torchmetrics.functional import accuracy\n\n\nclass RGCN(nn.Module):\n    def __init__(self, num_nodes, h_dim, out_dim, num_rels):\n        super().__init__()\n        self.emb = nn.Embedding(num_nodes, h_dim)\n        # two-layer RGCN\n        self.conv1 = RelGraphConv(\n            h_dim,\n            h_dim,\n            num_rels,\n            regularizer=\"basis\",\n            num_bases=num_rels,\n            self_loop=False,\n        )\n        self.conv2 = RelGraphConv(\n            h_dim,\n            out_dim,\n            num_rels,\n            regularizer=\"basis\",\n            num_bases=num_rels,\n            self_loop=False,\n        )\n\n    def forward(self, g):\n        x = self.emb(g[0].srcdata[dgl.NID])\n        h = F.relu(\n            self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata[\"norm\"])\n        )\n        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata[\"norm\"])\n        return h\n\n\ndef evaluate(model, labels, num_classes, dataloader, inv_target):\n    model.eval()\n    eval_logits = []\n    eval_seeds = []\n    with torch.no_grad():\n        for input_nodes, output_nodes, blocks in dataloader:\n            output_nodes = inv_target[output_nodes]\n            for block in blocks:\n                block.edata[\"norm\"] = dgl.norm_by_dst(block).unsqueeze(1)\n            logits = model(blocks)\n            eval_logits.append(logits.cpu().detach())\n            eval_seeds.append(output_nodes.cpu().detach())\n    eval_logits = torch.cat(eval_logits)\n    eval_seeds = torch.cat(eval_seeds)\n    return accuracy(\n        eval_logits.argmax(dim=1),\n        labels[eval_seeds].cpu(),\n        task=\"multiclass\",\n        
num_classes=num_classes,\n    ).item()\n\n\ndef train(device, g, target_idx, labels, train_mask, num_classes, model):\n    # define train idx, loss function and optimizer\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n    # construct sampler and dataloader\n    sampler = MultiLayerNeighborSampler([4, 4])\n    train_loader = DataLoader(\n        g,\n        target_idx[train_idx],\n        sampler,\n        device=device,\n        batch_size=100,\n        shuffle=True,\n    )\n    # no separate validation subset, use train index instead for validation\n    val_loader = DataLoader(\n        g,\n        target_idx[train_idx],\n        sampler,\n        device=device,\n        batch_size=100,\n        shuffle=False,\n    )\n    for epoch in range(50):\n        model.train()\n        total_loss = 0\n        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):\n            output_nodes = inv_target[output_nodes]\n            for block in blocks:\n                block.edata[\"norm\"] = dgl.norm_by_dst(block).unsqueeze(1)\n            logits = model(blocks)\n            loss = loss_fcn(logits, labels[output_nodes])\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n        acc = evaluate(model, labels, num_classes, val_loader, inv_target)\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Val. 
Accuracy {:.4f} \".format(\n                epoch, total_loss / (it + 1), acc\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"RGCN for entity classification with sampling\"\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"aifb\",\n        help=\"Dataset name ('aifb', 'mutag', 'bgs', 'am').\",\n    )\n    args = parser.parse_args()\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    print(f\"Training with DGL built-in RGCN module with sampling.\")\n\n    # load and preprocess dataset\n    if args.dataset == \"aifb\":\n        data = AIFBDataset()\n    elif args.dataset == \"mutag\":\n        data = MUTAGDataset()\n    elif args.dataset == \"bgs\":\n        data = BGSDataset()\n    elif args.dataset == \"am\":\n        data = AMDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n    g = data[0]\n    num_rels = len(g.canonical_etypes)\n    category = data.predict_category\n    labels = g.nodes[category].data.pop(\"labels\").to(device)\n    train_mask = g.nodes[category].data.pop(\"train_mask\")\n    test_mask = g.nodes[category].data.pop(\"test_mask\")\n    # find target category and node id\n    category_id = g.ntypes.index(category)\n    g = dgl.to_homogeneous(g)\n    node_ids = torch.arange(g.num_nodes())\n    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]\n    # rename the fields as they can be changed by DataLoader\n    g.ndata[\"ntype\"] = g.ndata.pop(dgl.NTYPE)\n    g.ndata[\"type_id\"] = g.ndata.pop(dgl.NID)\n    # find the mapping (inv_target) from global node IDs to type-specific node IDs\n    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)\n    inv_target[target_idx] = torch.arange(\n        0, target_idx.shape[0], dtype=inv_target.dtype\n    ).to(device)\n\n    # create RGCN model\n    in_size = g.num_nodes()  # featureless with 
one-hot encoding\n    num_classes = data.num_classes\n    model = RGCN(in_size, 16, num_classes, num_rels).to(device)\n\n    train(device, g, target_idx, labels, train_mask, num_classes, model)\n    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()\n    test_sampler = MultiLayerNeighborSampler(\n        [-1, -1]\n    )  # -1 for sampling all neighbors\n    test_loader = DataLoader(\n        g,\n        target_idx[test_idx],\n        test_sampler,\n        device=device,\n        batch_size=32,\n        shuffle=False,\n    )\n    acc = evaluate(model, labels, num_classes, test_loader, inv_target)\n    print(\"Test accuracy {:.4f}\".format(acc))\n"
  },
  {
    "path": "examples/pytorch/rgcn/entity_sample_multi_gpu.py",
    "content": "import argparse\nimport os\n\nimport dgl\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\nfrom dgl.dataloading import DataLoader, MultiLayerNeighborSampler\nfrom dgl.nn.pytorch import RelGraphConv\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torchmetrics.functional import accuracy\n\n\nclass RGCN(nn.Module):\n    def __init__(self, num_nodes, h_dim, out_dim, num_rels):\n        super().__init__()\n        self.emb = nn.Embedding(num_nodes, h_dim)\n        # two-layer RGCN\n        self.conv1 = RelGraphConv(\n            h_dim,\n            h_dim,\n            num_rels,\n            regularizer=\"basis\",\n            num_bases=num_rels,\n            self_loop=False,\n        )\n        self.conv2 = RelGraphConv(\n            h_dim,\n            out_dim,\n            num_rels,\n            regularizer=\"basis\",\n            num_bases=num_rels,\n            self_loop=False,\n        )\n\n    def forward(self, g):\n        x = self.emb(g[0].srcdata[dgl.NID])\n        h = F.relu(\n            self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata[\"norm\"])\n        )\n        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata[\"norm\"])\n        return h\n\n\ndef evaluate(model, labels, num_classes, dataloader, inv_target):\n    model.eval()\n    eval_logits = []\n    eval_seeds = []\n    with torch.no_grad():\n        for input_nodes, output_nodes, blocks in dataloader:\n            output_nodes = inv_target[output_nodes]\n            for block in blocks:\n                block.edata[\"norm\"] = dgl.norm_by_dst(block).unsqueeze(1)\n            logits = model(blocks)\n            eval_logits.append(logits.cpu().detach())\n            eval_seeds.append(output_nodes.cpu().detach())\n    eval_logits = torch.cat(eval_logits)\n    eval_seeds = torch.cat(eval_seeds)\n    
num_seeds = len(eval_seeds)\n    loc_sum = accuracy(\n        eval_logits.argmax(dim=1),\n        labels[eval_seeds].cpu(),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    ) * float(num_seeds)\n    return torch.tensor([loc_sum.item(), float(num_seeds)])\n\n\ndef train(\n    proc_id,\n    device,\n    g,\n    target_idx,\n    labels,\n    num_classes,\n    train_idx,\n    inv_target,\n    model,\n):\n    # define loss function and optimizer\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n    # construct sampler and dataloader\n    sampler = MultiLayerNeighborSampler([4, 4])\n    train_loader = DataLoader(\n        g,\n        target_idx[train_idx],\n        sampler,\n        device=device,\n        batch_size=100,\n        shuffle=True,\n        use_ddp=True,\n    )\n    # no separate validation subset, use train index instead for validation\n    val_loader = DataLoader(\n        g,\n        target_idx[train_idx],\n        sampler,\n        device=device,\n        batch_size=100,\n        shuffle=False,\n        use_ddp=True,\n    )\n    for epoch in range(50):\n        model.train()\n        total_loss = 0\n        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):\n            output_nodes = inv_target[output_nodes]\n            for block in blocks:\n                block.edata[\"norm\"] = dgl.norm_by_dst(block).unsqueeze(1)\n            logits = model(blocks)\n            loss = loss_fcn(logits, labels[output_nodes])\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n        # torchmetric accuracy defined as num_correct_labels / num_train_nodes\n        # loc_acc_split = [loc_accuracy * loc_num_train_nodes, loc_num_train_nodes]\n        loc_acc_split = evaluate(\n            model, labels, num_classes, val_loader, inv_target\n        ).to(device)\n        
dist.reduce(loc_acc_split, 0)\n        if proc_id == 0:\n            acc = loc_acc_split[0] / loc_acc_split[1]\n            print(\n                \"Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} \".format(\n                    epoch, total_loss / (it + 1), acc.item()\n                )\n            )\n\n\ndef run(proc_id, nprocs, devices, g, data):\n    # find corresponding device for my rank\n    device = devices[proc_id]\n    torch.cuda.set_device(device)\n    # initialize process group and unpack data for sub-processes\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=\"tcp://127.0.0.1:12345\",\n        world_size=nprocs,\n        rank=proc_id,\n    )\n    (\n        num_rels,\n        num_classes,\n        labels,\n        train_idx,\n        test_idx,\n        target_idx,\n        inv_target,\n    ) = data\n    labels = labels.to(device)\n    inv_target = inv_target.to(device)\n    # create RGCN model (distributed)\n    in_size = g.num_nodes()\n    model = RGCN(in_size, 16, num_classes, num_rels).to(device)\n    model = DistributedDataParallel(\n        model, device_ids=[device], output_device=device\n    )\n    # training + testing\n    train(\n        proc_id,\n        device,\n        g,\n        target_idx,\n        labels,\n        num_classes,\n        train_idx,\n        inv_target,\n        model,\n    )\n    test_sampler = MultiLayerNeighborSampler(\n        [-1, -1]\n    )  # -1 for sampling all neighbors\n    test_loader = DataLoader(\n        g,\n        target_idx[test_idx],\n        test_sampler,\n        device=device,\n        batch_size=32,\n        shuffle=False,\n        use_ddp=True,\n    )\n    loc_acc_split = evaluate(\n        model, labels, num_classes, test_loader, inv_target\n    ).to(device)\n    dist.reduce(loc_acc_split, 0)\n    if proc_id == 0:\n        acc = loc_acc_split[0] / loc_acc_split[1]\n        print(\"Test accuracy {:.4f}\".format(acc))\n    # cleanup process group\n    
dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"RGCN for entity classification with sampling (multi-gpu)\"\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"aifb\",\n        help=\"Dataset name ('aifb', 'mutag', 'bgs', 'am').\",\n    )\n    parser.add_argument(\n        \"--gpu\",\n        type=str,\n        default=\"0\",\n        help=\"GPU(s) in use. Can be a list of gpu ids for multi-gpu training,\"\n        \" e.g., 0,1,2,3.\",\n    )\n    args = parser.parse_args()\n    devices = list(map(int, args.gpu.split(\",\")))\n    nprocs = len(devices)\n    print(\n        f\"Training with DGL built-in RGCN module with sampling using\",\n        nprocs,\n        f\"GPU(s)\",\n    )\n\n    # load and preprocess dataset at master(parent) process\n    if args.dataset == \"aifb\":\n        data = AIFBDataset()\n    elif args.dataset == \"mutag\":\n        data = MUTAGDataset()\n    elif args.dataset == \"bgs\":\n        data = BGSDataset()\n    elif args.dataset == \"am\":\n        data = AMDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n    g = data[0]\n    num_rels = len(g.canonical_etypes)\n    category = data.predict_category\n    labels = g.nodes[category].data.pop(\"labels\")\n    train_mask = g.nodes[category].data.pop(\"train_mask\")\n    test_mask = g.nodes[category].data.pop(\"test_mask\")\n    # find target category and node id\n    category_id = g.ntypes.index(category)\n    g = dgl.to_homogeneous(g)\n    node_ids = torch.arange(g.num_nodes())\n    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]\n    # rename the fields as they can be changed by DataLoader\n    g.ndata[\"ntype\"] = g.ndata.pop(dgl.NTYPE)\n    g.ndata[\"type_id\"] = g.ndata.pop(dgl.NID)\n    # find the mapping (inv_target) from global node IDs to type-specific node IDs\n    inv_target = torch.empty((g.num_nodes(),), 
dtype=torch.int64)\n    inv_target[target_idx] = torch.arange(\n        0, target_idx.shape[0], dtype=inv_target.dtype\n    )\n    # avoid creating certain graph formats and train/test indexes in each sub-process to save memory\n    g.create_formats_()\n    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()\n    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()\n    # thread limiting to avoid resource competition\n    os.environ[\"OMP_NUM_THREADS\"] = str(mp.cpu_count() // 2 // nprocs)\n\n    data = (\n        num_rels,\n        data.num_classes,\n        labels,\n        train_idx,\n        test_idx,\n        target_idx,\n        inv_target,\n    )\n    mp.spawn(run, args=(nprocs, devices, g, data), nprocs=nprocs)\n"
  },
  {
    "path": "examples/pytorch/rgcn/entity_utils.py",
    "content": "import dgl\nimport torch as th\n\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\n\n\ndef load_data(data_name, get_norm=False, inv_target=False):\n    if data_name == \"aifb\":\n        dataset = AIFBDataset()\n    elif data_name == \"mutag\":\n        dataset = MUTAGDataset()\n    elif data_name == \"bgs\":\n        dataset = BGSDataset()\n    else:\n        dataset = AMDataset()\n\n    # Load hetero-graph\n    hg = dataset[0]\n\n    num_rels = len(hg.canonical_etypes)\n    category = dataset.predict_category\n    num_classes = dataset.num_classes\n    labels = hg.nodes[category].data.pop(\"labels\")\n    train_mask = hg.nodes[category].data.pop(\"train_mask\")\n    test_mask = hg.nodes[category].data.pop(\"test_mask\")\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()\n\n    if get_norm:\n        # Calculate normalization weight for each edge,\n        # 1. / d, d is the degree of the destination node\n        for cetype in hg.canonical_etypes:\n            hg.edges[cetype].data[\"norm\"] = dgl.norm_by_dst(\n                hg, cetype\n            ).unsqueeze(1)\n        edata = [\"norm\"]\n    else:\n        edata = None\n\n    # get target category id\n    category_id = hg.ntypes.index(category)\n\n    g = dgl.to_homogeneous(hg, edata=edata)\n    # Rename the fields as they can be changed by for example DataLoader\n    g.ndata[\"ntype\"] = g.ndata.pop(dgl.NTYPE)\n    g.ndata[\"type_id\"] = g.ndata.pop(dgl.NID)\n    node_ids = th.arange(g.num_nodes())\n\n    # find out the target node ids in g\n    loc = g.ndata[\"ntype\"] == category_id\n    target_idx = node_ids[loc]\n\n    if inv_target:\n        # Map global node IDs to type-specific node IDs. 
This is required for\n        # looking up type-specific labels in a minibatch\n        inv_target = th.empty((g.num_nodes(),), dtype=th.int64)\n        inv_target[target_idx] = th.arange(\n            0, target_idx.shape[0], dtype=inv_target.dtype\n        )\n        return (\n            g,\n            num_rels,\n            num_classes,\n            labels,\n            train_idx,\n            test_idx,\n            target_idx,\n            inv_target,\n        )\n    else:\n        return g, num_rels, num_classes, labels, train_idx, test_idx, target_idx\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/README.md",
    "content": "## Distributed training\n\nThis is an example of training RGCN node classification in a distributed fashion. Currently, the example train RGCN graphs with input node features. The current implementation follows ../rgcn/entity_claasify_mp.py.\n\nBefore training, install python libs by pip:\n\n```bash\npip3 install ogb pyarrow\n```\n\nTo train RGCN, it has four steps:\n\n### Step 0: Setup a Distributed File System\n* You may skip this step if your cluster already has folder(s) synchronized across machines.\n\nTo perform distributed training, files and codes need to be accessed across multiple machines. A distributed file system would perfectly handle the job (i.e., NFS, Ceph).\n\n#### Server side setup\nHere is an example of how to setup NFS. First, install essential libs on the storage server\n```bash\nsudo apt-get install nfs-kernel-server\n```\n\nBelow we assume the user account is `ubuntu` and we create a directory of `workspace` in the home directory.\n```bash\nmkdir -p /home/ubuntu/workspace\n```\n\nWe assume that the all servers are under a subnet with ip range `192.168.0.0` to `192.168.255.255`. The exports configuration needs to be modifed to\n\n```bash\nsudo vim /etc/exports\n# add the following line\n/home/ubuntu/workspace  192.168.0.0/16(rw,sync,no_subtree_check)\n```\n\nThe server's internal ip can be checked  via `ifconfig` or `ip`. 
If the ip does not begin with `192.168`, then you may use\n```bash\n# for ip range 10.0.0.0 – 10.255.255.255\n/home/ubuntu/workspace  10.0.0.0/8(rw,sync,no_subtree_check)\n# for ip range 172.16.0.0 – 172.31.255.255\n/home/ubuntu/workspace  172.16.0.0/12(rw,sync,no_subtree_check)\n```\n\nThen restart NFS, and the setup on the server side is finished.\n\n```\nsudo systemctl restart nfs-kernel-server\n```\n\nFor configuration details, please refer to [NFS ArchWiki](https://wiki.archlinux.org/index.php/NFS).\n\n\n#### Client side setup\n\nTo use NFS, clients also need to install essential packages\n\n```\nsudo apt-get install nfs-common\n```\n\nYou can either mount the NFS manually\n\n```\nmkdir -p /home/ubuntu/workspace\nsudo mount -t nfs <nfs-server-ip>:/home/ubuntu/workspace /home/ubuntu/workspace\n```\n\nor edit the fstab so the folder will be mounted automatically\n\n```\n# vim /etc/fstab\n## append the following line to the file\n<nfs-server-ip>:/home/ubuntu/workspace   /home/ubuntu/workspace   nfs   defaults\t0 0\n```\n\nThen run `mount -a`.\n\nNow go to `/home/ubuntu/workspace` and clone the DGL GitHub repository.\n\n### Step 1: set IP configuration file.\n\nUsers need to set their own IP configuration file `ip_config.txt` before training. For example, if we have four machines in the current cluster, the IP configuration could look like this:\n\n```bash\n172.31.0.1\n172.31.0.2\n172.31.0.3\n172.31.0.4\n```\n\nUsers need to make sure that the master node (node-0) has the right permission to ssh to all the other nodes without password authentication.\n[This link](https://linuxize.com/post/how-to-setup-passwordless-ssh-login/) provides instructions for setting passwordless SSH login.\n\n### Step 2: partition the graph.\n\nThe example provides a script to partition some builtin graphs such as the ogbn-mag graph.\nIf we want to train RGCN on 4 machines, we need to partition the graph into 4 parts.\n\nIn this example, we partition the ogbn-mag graph into 4 parts with Metis. 
The partitions are balanced with respect to\nthe number of nodes, the number of edges and the number of labelled nodes.\n```bash\npython3 partition_graph.py --dataset ogbn-mag --num_parts 4 --balance_train --balance_edges\n```\n\n### Step 3: Launch distributed jobs\n\nDGL provides a script to launch the training job in the cluster. `part_config` and `ip_config`\nspecify relative paths to the path of the workspace.\n\nThe command below launches one training process on each machine and each training process has 4 sampling processes.\n\n```bash\npython3 ~/workspace/dgl/tools/launch.py \\\n--workspace ~/workspace/dgl/examples/pytorch/rgcn/experimental/ \\\n--num_trainers 1 \\\n--num_servers 1 \\\n--num_samplers 4 \\\n--part_config data/ogbn-mag.json \\\n--ip_config ip_config.txt \\\n\"python3 entity_classify_dist.py --graph-name ogbn-mag --dataset ogbn-mag --fanout='25,25' --batch-size 1024  --n-hidden 64 --lr 0.01 --eval-batch-size 1024  --low-mem --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --layer-norm --ip-config ip_config.txt  --sparse-embedding --sparse-lr 0.06 --num_gpus 1\"\n```\n\nWe can get the performance score at the second epoch:\n```\nVal Acc 0.4323, Test Acc 0.4255, time: 128.0379\n```\n\nThe command below launches the same distributed training job using dgl distributed DistEmbedding\n```bash\npython3 ~/workspace/dgl/tools/launch.py \\\n--workspace ~/workspace/dgl/examples/pytorch/rgcn/experimental/ \\\n--num_trainers 1 \\\n--num_servers 1 \\\n--num_samplers 4 \\\n--part_config data/ogbn-mag.json \\\n--ip_config ip_config.txt \\\n\"python3 entity_classify_dist.py --graph-name ogbn-mag --dataset ogbn-mag --fanout='25,25' --batch-size 1024  --n-hidden 64 --lr 0.01 --eval-batch-size 1024  --low-mem --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --layer-norm --ip-config ip_config.txt  --sparse-embedding --sparse-lr 0.06 --num_gpus 1 --dgl-sparse\"\n```\n\nWe can get the performance score at the second epoch:\n```\nVal Acc 0.4410, Test Acc 
0.4282, time: 32.5274\n```\n\n**Note:** if you are using conda or other virtual environments on the remote machines, you need to replace `python3` in the command string (i.e. the last argument) with the path to the Python interpreter in that environment.\n\n## Partition a graph with ParMETIS\n\nThere are four steps to partition a graph with ParMETIS for DGL's distributed training.\nMore details about the four steps are explained in our\n[user guide](https://doc.dgl.ai/guide/distributed-preprocessing.html).\n\n### Step 1: write the graph into files.\n\nThe graph structure should be written as a node file and an edge file. The node features and edge features\ncan be written as DGL tensors. `write_mag.py` shows an example of writing the OGB MAG graph into files.\n\nAs `pm_dglpart` cannot handle self-loops and duplicate edges correctly, these edges are removed and stored\nin `mag_removed_edges.txt` when calling `write_mag.py`. When converting ParMETIS outputs into DGLGraph\nin the next steps, `mag_removed_edges.txt` should be passed in. Refer to Step 3 for more details.\n\n```bash\npython3 write_mag.py\n```\n\n### Step 2: partition the graph with ParMETIS\nRun the program called `pm_dglpart` in ParMETIS to read the node file and the edge file output in Step 1\nto partition the graph.\n\n```bash\npm_dglpart mag 2\n```\nThis partitions the graph into two parts with a single process.\n\n```\nmpirun -np 4 pm_dglpart mag 2\n```\nThis partitions the graph into eight parts with four processes.\n\n```\nmpirun --hostfile hostfile -np 4 pm_dglpart mag 2\n```\nThis partitions the graph into eight parts with four processes on multiple machines.\n`hostfile` specifies the IPs of the machines, one line per machine. The input files\nshould reside on the machine where the command line runs. Each process will write\nthe partitions to files on the local machine. 
For simplicity, we recommend users to\nwrite the files on NFS.\n\n### Step 3: Convert the ParMETIS partitions into DGLGraph\n\nDGL provides a tool called `convert_partition.py` to load one partition at a time and convert it into a DGLGraph\nand save it into a file. As mentioned in Step 1, please pass `mag_removed_edges.txt` if any self-loops and\nduplicate edges are removed.\n\n```bash\npython3 ~/workspace/dgl/tools/convert_partition.py --input-dir . --graph-name mag --schema mag.json --num-parts 2 --num-node-weights 4 --output outputs --removed-edges mag_removed_edges.txt\n```\n\n### Step 4: Read node data and edge data for each partition\n\nThis shows an example of reading node data and edge data of each partition and saving them into files located in the same directory as the DGLGraph file.\n\n```bash\npython3 get_mag_data.py\n```\n\n### Step 5: Verify the partition result (Optional)\n\n```bash\npython3 verify_mag_partitions.py\n```\n\n## Distributed code runs in the standalone mode\n\nThe standalone mode is mainly used for development and testing. The procedure to run the code is much simpler.\n\n### Step 1: graph construction.\nWhen testing the standalone mode of the training script, we should construct a graph with one partition.\n```bash\npython3 partition_graph.py --dataset ogbn-mag --num_parts 1\n```\n\n### Step 2: run the training script\n```bash\nDGL_DIST_MODE=standalone python3 entity_classify_dist.py --graph-name ogbn-mag  --dataset ogbn-mag --fanout='25,25' --batch-size 512 --n-hidden 64 --lr 0.01 --eval-batch-size 128 --low-mem --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --layer-norm --ip-config ip_config.txt --conf-path 'data/ogbn-mag.json' --standalone  --sparse-embedding  --sparse-lr 0.06\n```\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/entity_classify_dist.py",
    "content": "\"\"\"\nModeling Relational Data with Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1703.06103\nCode: https://github.com/tkipf/relational-gcn\nDifference compared to tkipf/relation-gcn\n* l2norm applied to all weights\n* remove nodes that won't be touched\n\"\"\"\n\nimport argparse\nimport gc, os\nimport itertools\nimport time\n\nimport numpy as np\n\nos.environ[\"DGLBACKEND\"] = \"pytorch\"\n\nfrom functools import partial\n\nimport dgl\nimport dgl.distributed\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport tqdm\nfrom dgl import DGLGraph, nn as dglnn\nfrom dgl.distributed import DistDataLoader\n\nfrom ogb.nodeproppred import DglNodePropPredDataset\nfrom torch.multiprocessing import Queue\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.utils.data import DataLoader\n\n\nclass RelGraphConvLayer(nn.Module):\n    r\"\"\"Relational graph convolution layer.\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size.\n    out_feat : int\n        Output feature size.\n    rel_names : list[str]\n        Relation names.\n    num_bases : int, optional\n        Number of bases. If is none, use number of relations. Default: None.\n    weight : bool, optional\n        True if a linear layer is applied after message passing. Default: True\n    bias : bool, optional\n        True if bias is added. Default: True\n    activation : callable, optional\n        Activation function. Default: None\n    self_loop : bool, optional\n        True to include self loop message. Default: False\n    dropout : float, optional\n        Dropout rate. 
Default: 0.0\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feat,\n        out_feat,\n        rel_names,\n        num_bases,\n        *,\n        weight=True,\n        bias=True,\n        activation=None,\n        self_loop=False,\n        dropout=0.0\n    ):\n        super(RelGraphConvLayer, self).__init__()\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.rel_names = rel_names\n        self.num_bases = num_bases\n        self.bias = bias\n        self.activation = activation\n        self.self_loop = self_loop\n\n        self.conv = dglnn.HeteroGraphConv(\n            {\n                rel: dglnn.GraphConv(\n                    in_feat, out_feat, norm=\"right\", weight=False, bias=False\n                )\n                for rel in rel_names\n            }\n        )\n\n        self.use_weight = weight\n        self.use_basis = num_bases < len(self.rel_names) and weight\n        if self.use_weight:\n            if self.use_basis:\n                self.basis = dglnn.WeightBasis(\n                    (in_feat, out_feat), num_bases, len(self.rel_names)\n                )\n            else:\n                self.weight = nn.Parameter(\n                    th.Tensor(len(self.rel_names), in_feat, out_feat)\n                )\n                nn.init.xavier_uniform_(\n                    self.weight, gain=nn.init.calculate_gain(\"relu\")\n                )\n\n        # bias\n        if bias:\n            self.h_bias = nn.Parameter(th.Tensor(out_feat))\n            nn.init.zeros_(self.h_bias)\n\n        # weight for self loop\n        if self.self_loop:\n            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))\n            nn.init.xavier_uniform_(\n                self.loop_weight, gain=nn.init.calculate_gain(\"relu\")\n            )\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, g, inputs):\n        \"\"\"Forward computation\n        Parameters\n        ----------\n        g : 
DGLGraph\n            Input graph.\n        inputs : dict[str, torch.Tensor]\n            Node feature for each node type.\n        Returns\n        -------\n        dict[str, torch.Tensor]\n            New node features for each node type.\n        \"\"\"\n        g = g.local_var()\n        if self.use_weight:\n            weight = self.basis() if self.use_basis else self.weight\n            wdict = {\n                self.rel_names[i]: {\"weight\": w.squeeze(0)}\n                for i, w in enumerate(th.split(weight, 1, dim=0))\n            }\n        else:\n            wdict = {}\n\n        if g.is_block:\n            inputs_src = inputs\n            inputs_dst = {\n                k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()\n            }\n        else:\n            inputs_src = inputs_dst = inputs\n\n        hs = self.conv(g, inputs, mod_kwargs=wdict)\n\n        def _apply(ntype, h):\n            if self.self_loop:\n                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)\n            if self.bias:\n                h = h + self.h_bias\n            if self.activation:\n                h = self.activation(h)\n            return self.dropout(h)\n\n        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}\n\n\nclass EntityClassify(nn.Module):\n    \"\"\"Entity classification class for RGCN\n    Parameters\n    ----------\n    device : int\n        Device to run the layer.\n    num_nodes : int\n        Number of nodes.\n    h_dim : int\n        Hidden dim size.\n    out_dim : int\n        Output dim size.\n    rel_names : list of str\n        A list of relation names.\n    num_bases : int\n        Number of bases. 
If is none, use number of relations.\n    num_hidden_layers : int\n        Number of hidden RelGraphConv Layer\n    dropout : float\n        Dropout\n    use_self_loop : bool\n        Use self loop if True, default False.\n    \"\"\"\n\n    def __init__(\n        self,\n        device,\n        h_dim,\n        out_dim,\n        rel_names,\n        num_bases=None,\n        num_hidden_layers=1,\n        dropout=0,\n        use_self_loop=False,\n        layer_norm=False,\n    ):\n        super(EntityClassify, self).__init__()\n        self.device = device\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.num_bases = None if num_bases < 0 else num_bases\n        self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.use_self_loop = use_self_loop\n        self.layer_norm = layer_norm\n\n        self.layers = nn.ModuleList()\n        # i2h\n        self.layers.append(\n            RelGraphConvLayer(\n                self.h_dim,\n                self.h_dim,\n                rel_names,\n                self.num_bases,\n                activation=F.relu,\n                self_loop=self.use_self_loop,\n                dropout=self.dropout,\n            )\n        )\n        # h2h\n        for idx in range(self.num_hidden_layers):\n            self.layers.append(\n                RelGraphConvLayer(\n                    self.h_dim,\n                    self.h_dim,\n                    rel_names,\n                    self.num_bases,\n                    activation=F.relu,\n                    self_loop=self.use_self_loop,\n                    dropout=self.dropout,\n                )\n            )\n        # h2o\n        self.layers.append(\n            RelGraphConvLayer(\n                self.h_dim,\n                self.out_dim,\n                rel_names,\n                self.num_bases,\n                activation=None,\n                self_loop=self.use_self_loop,\n            )\n        )\n\n    def forward(self, 
blocks, feats, norm=None):\n        if blocks is None:\n            # full graph training\n            blocks = [self.g] * len(self.layers)\n        h = feats\n        for layer, block in zip(self.layers, blocks):\n            block = block.to(self.device)\n            h = layer(block, h)\n        return h\n\n\ndef init_emb(shape, dtype):\n    arr = th.zeros(shape, dtype=dtype)\n    nn.init.uniform_(arr, -1.0, 1.0)\n    return arr\n\n\nclass DistEmbedLayer(nn.Module):\n    r\"\"\"Embedding layer for featureless heterograph.\n    Parameters\n    ----------\n    dev_id : int\n        Device to run the layer.\n    g : DistGraph\n        training graph\n    embed_size : int\n        Output embed size\n    sparse_emb: bool\n        Whether to use sparse embedding\n        Default: False\n    dgl_sparse_emb: bool\n        Whether to use DGL sparse embedding\n        Default: False\n    embed_name : str, optional\n        Embed name\n    \"\"\"\n\n    def __init__(\n        self,\n        dev_id,\n        g,\n        embed_size,\n        sparse_emb=False,\n        dgl_sparse_emb=False,\n        feat_name=\"feat\",\n        embed_name=\"node_emb\",\n    ):\n        super(DistEmbedLayer, self).__init__()\n        self.dev_id = dev_id\n        self.embed_size = embed_size\n        self.embed_name = embed_name\n        self.feat_name = feat_name\n        self.sparse_emb = sparse_emb\n        self.g = g\n        self.ntype_id_map = {g.get_ntype_id(ntype): ntype for ntype in g.ntypes}\n\n        self.node_projs = nn.ModuleDict()\n        for ntype in g.ntypes:\n            if feat_name in g.nodes[ntype].data:\n                self.node_projs[ntype] = nn.Linear(\n                    g.nodes[ntype].data[feat_name].shape[1], embed_size\n                )\n                nn.init.xavier_uniform_(self.node_projs[ntype].weight)\n                print(\"node {} has data {}\".format(ntype, feat_name))\n        if sparse_emb:\n            if dgl_sparse_emb:\n                
self.node_embeds = {}\n                for ntype in g.ntypes:\n                    # We only create embeddings for nodes without node features.\n                    if feat_name not in g.nodes[ntype].data:\n                        part_policy = g.get_node_partition_policy(ntype)\n                        self.node_embeds[ntype] = dgl.distributed.DistEmbedding(\n                            g.num_nodes(ntype),\n                            self.embed_size,\n                            embed_name + \"_\" + ntype,\n                            init_emb,\n                            part_policy,\n                        )\n            else:\n                self.node_embeds = nn.ModuleDict()\n                for ntype in g.ntypes:\n                    # We only create embeddings for nodes without node features.\n                    if feat_name not in g.nodes[ntype].data:\n                        self.node_embeds[ntype] = th.nn.Embedding(\n                            g.num_nodes(ntype),\n                            self.embed_size,\n                            sparse=self.sparse_emb,\n                        )\n                        nn.init.uniform_(\n                            self.node_embeds[ntype].weight, -1.0, 1.0\n                        )\n        else:\n            self.node_embeds = nn.ModuleDict()\n            for ntype in g.ntypes:\n                # We only create embeddings for nodes without node features.\n                if feat_name not in g.nodes[ntype].data:\n                    self.node_embeds[ntype] = th.nn.Embedding(\n                        g.num_nodes(ntype), self.embed_size\n                    )\n                    nn.init.uniform_(self.node_embeds[ntype].weight, -1.0, 1.0)\n\n    def forward(self, node_ids):\n        \"\"\"Forward computation\n        Parameters\n        ----------\n        node_ids : dict of Tensor\n            node ids to generate embedding for.\n        Returns\n        -------\n        tensor\n            embeddings as 
the input of the next layer\n        \"\"\"\n        embeds = {}\n        for ntype in node_ids:\n            if self.feat_name in self.g.nodes[ntype].data:\n                embeds[ntype] = self.node_projs[ntype](\n                    self.g.nodes[ntype]\n                    .data[self.feat_name][node_ids[ntype]]\n                    .to(self.dev_id)\n                )\n            else:\n                embeds[ntype] = self.node_embeds[ntype](node_ids[ntype]).to(\n                    self.dev_id\n                )\n        return embeds\n\n\ndef compute_acc(results, labels):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    labels = labels.long()\n    return (results == labels).float().sum() / len(results)\n\n\ndef evaluate(\n    g,\n    model,\n    embed_layer,\n    labels,\n    eval_loader,\n    test_loader,\n    all_val_nid,\n    all_test_nid,\n):\n    model.eval()\n    embed_layer.eval()\n    eval_logits = []\n    eval_seeds = []\n\n    global_results = dgl.distributed.DistTensor(\n        labels.shape, th.long, \"results\", persistent=True\n    )\n\n    with th.no_grad():\n        th.cuda.empty_cache()\n        for sample_data in tqdm.tqdm(eval_loader):\n            input_nodes, seeds, blocks = sample_data\n            seeds = seeds[\"paper\"]\n            feats = embed_layer(input_nodes)\n            logits = model(blocks, feats)\n            assert len(logits) == 1\n            logits = logits[\"paper\"]\n            eval_logits.append(logits.cpu().detach())\n            assert np.all(seeds.numpy() < g.num_nodes(\"paper\"))\n            eval_seeds.append(seeds.cpu().detach())\n    eval_logits = th.cat(eval_logits)\n    eval_seeds = th.cat(eval_seeds)\n    global_results[eval_seeds] = eval_logits.argmax(dim=1)\n\n    test_logits = []\n    test_seeds = []\n    with th.no_grad():\n        th.cuda.empty_cache()\n        for sample_data in tqdm.tqdm(test_loader):\n            input_nodes, seeds, blocks = sample_data\n         
   seeds = seeds[\"paper\"]\n            feats = embed_layer(input_nodes)\n            logits = model(blocks, feats)\n            assert len(logits) == 1\n            logits = logits[\"paper\"]\n            test_logits.append(logits.cpu().detach())\n            assert np.all(seeds.numpy() < g.num_nodes(\"paper\"))\n            test_seeds.append(seeds.cpu().detach())\n    test_logits = th.cat(test_logits)\n    test_seeds = th.cat(test_seeds)\n    global_results[test_seeds] = test_logits.argmax(dim=1)\n\n    g.barrier()\n    if g.rank() == 0:\n        return compute_acc(\n            global_results[all_val_nid], labels[all_val_nid]\n        ), compute_acc(global_results[all_test_nid], labels[all_test_nid])\n    else:\n        return -1, -1\n\n\ndef run(args, device, data):\n    (\n        g,\n        num_classes,\n        train_nid,\n        val_nid,\n        test_nid,\n        labels,\n        all_val_nid,\n        all_test_nid,\n    ) = data\n\n    fanouts = [int(fanout) for fanout in args.fanout.split(\",\")]\n    val_fanouts = [int(fanout) for fanout in args.validation_fanout.split(\",\")]\n\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)\n    dataloader = dgl.distributed.DistNodeDataLoader(\n        g,\n        {\"paper\": train_nid},\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        drop_last=False,\n    )\n\n    valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts)\n    valid_dataloader = dgl.distributed.DistNodeDataLoader(\n        g,\n        {\"paper\": val_nid},\n        valid_sampler,\n        batch_size=args.batch_size,\n        shuffle=False,\n        drop_last=False,\n    )\n\n    test_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts)\n    test_dataloader = dgl.distributed.DistNodeDataLoader(\n        g,\n        {\"paper\": test_nid},\n        test_sampler,\n        batch_size=args.eval_batch_size,\n        shuffle=False,\n        drop_last=False,\n    )\n\n    
embed_layer = DistEmbedLayer(\n        device,\n        g,\n        args.n_hidden,\n        sparse_emb=args.sparse_embedding,\n        dgl_sparse_emb=args.dgl_sparse,\n        feat_name=\"feat\",\n    )\n\n    model = EntityClassify(\n        device,\n        args.n_hidden,\n        num_classes,\n        g.etypes,\n        num_bases=args.n_bases,\n        num_hidden_layers=args.n_layers - 2,\n        dropout=args.dropout,\n        use_self_loop=args.use_self_loop,\n        layer_norm=args.layer_norm,\n    )\n    model = model.to(device)\n\n    if not args.standalone:\n        if args.num_gpus == -1:\n            model = DistributedDataParallel(model)\n            # If there are dense parameters in the embedding layer\n            # or we use Pytorch saprse embeddings.\n            if len(embed_layer.node_projs) > 0 or not args.dgl_sparse:\n                embed_layer = DistributedDataParallel(embed_layer)\n        else:\n            dev_id = g.rank() % args.num_gpus\n            model = DistributedDataParallel(\n                model, device_ids=[dev_id], output_device=dev_id\n            )\n            # If there are dense parameters in the embedding layer\n            # or we use Pytorch saprse embeddings.\n            if len(embed_layer.node_projs) > 0 or not args.dgl_sparse:\n                embed_layer = embed_layer.to(device)\n                embed_layer = DistributedDataParallel(\n                    embed_layer, device_ids=[dev_id], output_device=dev_id\n                )\n\n    if args.sparse_embedding:\n        if args.dgl_sparse and args.standalone:\n            emb_optimizer = dgl.distributed.optim.SparseAdam(\n                list(embed_layer.node_embeds.values()), lr=args.sparse_lr\n            )\n            print(\n                \"optimize DGL sparse embedding:\", embed_layer.node_embeds.keys()\n            )\n        elif args.dgl_sparse:\n            emb_optimizer = dgl.distributed.optim.SparseAdam(\n                
list(embed_layer.module.node_embeds.values()), lr=args.sparse_lr\n            )\n            print(\n                \"optimize DGL sparse embedding:\",\n                embed_layer.module.node_embeds.keys(),\n            )\n        elif args.standalone:\n            emb_optimizer = th.optim.SparseAdam(\n                list(embed_layer.node_embeds.parameters()), lr=args.sparse_lr\n            )\n            print(\"optimize Pytorch sparse embedding:\", embed_layer.node_embeds)\n        else:\n            emb_optimizer = th.optim.SparseAdam(\n                list(embed_layer.module.node_embeds.parameters()),\n                lr=args.sparse_lr,\n            )\n            print(\n                \"optimize Pytorch sparse embedding:\",\n                embed_layer.module.node_embeds,\n            )\n\n        dense_params = list(model.parameters())\n        if args.standalone:\n            dense_params += list(embed_layer.node_projs.parameters())\n            print(\"optimize dense projection:\", embed_layer.node_projs)\n        else:\n            dense_params += list(embed_layer.module.node_projs.parameters())\n            print(\"optimize dense projection:\", embed_layer.module.node_projs)\n        optimizer = th.optim.Adam(\n            dense_params, lr=args.lr, weight_decay=args.l2norm\n        )\n    else:\n        all_params = list(model.parameters()) + list(embed_layer.parameters())\n        optimizer = th.optim.Adam(\n            all_params, lr=args.lr, weight_decay=args.l2norm\n        )\n\n    # training loop\n    print(\"start training...\")\n    for epoch in range(args.n_epochs):\n        tic = time.time()\n\n        sample_time = 0\n        copy_time = 0\n        forward_time = 0\n        backward_time = 0\n        update_time = 0\n        number_train = 0\n        number_input = 0\n\n        step_time = []\n        iter_t = []\n        sample_t = []\n        feat_copy_t = []\n        forward_t = []\n        backward_t = []\n        update_t = []\n       
 iter_tput = []\n\n        start = time.time()\n        # Loop over the dataloader to sample the computation dependency graph as a list of\n        # blocks.\n        step_time = []\n        for step, sample_data in enumerate(dataloader):\n            input_nodes, seeds, blocks = sample_data\n            seeds = seeds[\"paper\"]\n            number_train += seeds.shape[0]\n            number_input += np.sum(\n                [blocks[0].num_src_nodes(ntype) for ntype in blocks[0].ntypes]\n            )\n            tic_step = time.time()\n            sample_time += tic_step - start\n            sample_t.append(tic_step - start)\n\n            feats = embed_layer(input_nodes)\n            label = labels[seeds].to(device)\n            copy_time = time.time()\n            feat_copy_t.append(copy_time - tic_step)\n\n            # forward\n            logits = model(blocks, feats)\n            assert len(logits) == 1\n            logits = logits[\"paper\"]\n            loss = F.cross_entropy(logits, label)\n            forward_end = time.time()\n\n            # backward\n            optimizer.zero_grad()\n            if args.sparse_embedding:\n                emb_optimizer.zero_grad()\n            loss.backward()\n            compute_end = time.time()\n            forward_t.append(forward_end - copy_time)\n            backward_t.append(compute_end - forward_end)\n\n            # Update model parameters\n            optimizer.step()\n            if args.sparse_embedding:\n                emb_optimizer.step()\n            update_t.append(time.time() - compute_end)\n            step_t = time.time() - start\n            step_time.append(step_t)\n\n            train_acc = th.sum(logits.argmax(dim=1) == label).item() / len(\n                seeds\n            )\n\n            if step % args.log_every == 0:\n                print(\n                    \"[{}] Epoch {:05d} | Step {:05d} | Train acc {:.4f} | Loss {:.4f} | time {:.3f} s\"\n                    \"| sample {:.3f} | 
copy {:.3f} | forward {:.3f} | backward {:.3f} | update {:.3f}\".format(\n                        g.rank(),\n                        epoch,\n                        step,\n                        train_acc,\n                        loss.item(),\n                        np.sum(step_time[-args.log_every :]),\n                        np.sum(sample_t[-args.log_every :]),\n                        np.sum(feat_copy_t[-args.log_every :]),\n                        np.sum(forward_t[-args.log_every :]),\n                        np.sum(backward_t[-args.log_every :]),\n                        np.sum(update_t[-args.log_every :]),\n                    )\n                )\n            start = time.time()\n\n        gc.collect()\n        print(\n            \"[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #train: {}, #input: {}\".format(\n                g.rank(),\n                np.sum(step_time),\n                np.sum(sample_t),\n                np.sum(feat_copy_t),\n                np.sum(forward_t),\n                np.sum(backward_t),\n                np.sum(update_t),\n                number_train,\n                number_input,\n            )\n        )\n        epoch += 1\n\n        start = time.time()\n        g.barrier()\n        val_acc, test_acc = evaluate(\n            g,\n            model,\n            embed_layer,\n            labels,\n            valid_dataloader,\n            test_dataloader,\n            all_val_nid,\n            all_test_nid,\n        )\n        if val_acc >= 0:\n            print(\n                \"Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}\".format(\n                    val_acc, test_acc, time.time() - start\n                )\n            )\n\n\ndef main(args):\n    dgl.distributed.initialize(args.ip_config)\n    if not args.standalone:\n        th.distributed.init_process_group(backend=\"gloo\")\n\n    g = dgl.distributed.DistGraph(args.graph_name, 
part_config=args.conf_path)\n    print(\"rank:\", g.rank())\n\n    pb = g.get_partition_book()\n    if \"trainer_id\" in g.nodes[\"paper\"].data:\n        train_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"train_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n            node_trainer_ids=g.nodes[\"paper\"].data[\"trainer_id\"],\n        )\n        val_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"val_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n            node_trainer_ids=g.nodes[\"paper\"].data[\"trainer_id\"],\n        )\n        test_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"test_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n            node_trainer_ids=g.nodes[\"paper\"].data[\"trainer_id\"],\n        )\n    else:\n        train_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"train_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n        )\n        val_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"val_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n        )\n        test_nid = dgl.distributed.node_split(\n            g.nodes[\"paper\"].data[\"test_mask\"],\n            pb,\n            ntype=\"paper\",\n            force_even=True,\n        )\n    local_nid = pb.partid2nids(pb.partid, \"paper\").detach().numpy()\n    print(\n        \"part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})\".format(\n            g.rank(),\n            len(train_nid),\n            len(np.intersect1d(train_nid.numpy(), local_nid)),\n            len(val_nid),\n            len(np.intersect1d(val_nid.numpy(), local_nid)),\n            len(test_nid),\n            len(np.intersect1d(test_nid.numpy(), local_nid)),\n        )\n    )\n   
 if args.num_gpus == -1:\n        device = th.device(\"cpu\")\n    else:\n        dev_id = g.rank() % args.num_gpus\n        device = th.device(\"cuda:\" + str(dev_id))\n    labels = g.nodes[\"paper\"].data[\"labels\"][np.arange(g.num_nodes(\"paper\"))]\n    all_val_nid = th.LongTensor(\n        np.nonzero(\n            g.nodes[\"paper\"].data[\"val_mask\"][np.arange(g.num_nodes(\"paper\"))]\n        )\n    ).squeeze()\n    all_test_nid = th.LongTensor(\n        np.nonzero(\n            g.nodes[\"paper\"].data[\"test_mask\"][np.arange(g.num_nodes(\"paper\"))]\n        )\n    ).squeeze()\n    n_classes = len(th.unique(labels[labels >= 0]))\n    print(\"#classes:\", n_classes)\n\n    run(\n        args,\n        device,\n        (\n            g,\n            n_classes,\n            train_nid,\n            val_nid,\n            test_nid,\n            labels,\n            all_val_nid,\n            all_test_nid,\n        ),\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"RGCN\")\n    # distributed training related\n    parser.add_argument(\"--graph-name\", type=str, help=\"graph name\")\n    parser.add_argument(\"--id\", type=int, help=\"the partition id\")\n    parser.add_argument(\n        \"--ip-config\", type=str, help=\"The file for IP configuration\"\n    )\n    parser.add_argument(\n        \"--conf-path\", type=str, help=\"The path to the partition config file\"\n    )\n\n    # rgcn related\n    parser.add_argument(\n        \"--num_gpus\",\n        type=int,\n        default=-1,\n        help=\"the number of GPU device. 
Use -1 for CPU training\",\n    )\n    parser.add_argument(\n        \"--dropout\", type=float, default=0, help=\"dropout probability\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden units\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--sparse-lr\", type=float, default=1e-2, help=\"sparse lr rate\"\n    )\n    parser.add_argument(\n        \"--n-bases\",\n        type=int,\n        default=-1,\n        help=\"number of filter weight matrices, default: -1 [use all]\",\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=2, help=\"number of propagation rounds\"\n    )\n    parser.add_argument(\n        \"-e\",\n        \"--n-epochs\",\n        type=int,\n        default=50,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"-d\", \"--dataset\", type=str, required=True, help=\"dataset to use\"\n    )\n    parser.add_argument(\"--l2norm\", type=float, default=0, help=\"l2 norm coef\")\n    parser.add_argument(\n        \"--relabel\",\n        default=False,\n        action=\"store_true\",\n        help=\"remove untouched nodes and relabel\",\n    )\n    parser.add_argument(\n        \"--fanout\",\n        type=str,\n        default=\"4, 4\",\n        help=\"Fan-out of neighbor sampling.\",\n    )\n    parser.add_argument(\n        \"--validation-fanout\",\n        type=str,\n        default=None,\n        help=\"Fan-out of neighbor sampling during validation.\",\n    )\n    parser.add_argument(\n        \"--use-self-loop\",\n        default=False,\n        action=\"store_true\",\n        help=\"include self feature as a special relation\",\n    )\n    parser.add_argument(\n        \"--batch-size\", type=int, default=100, help=\"Mini-batch size. \"\n    )\n    parser.add_argument(\n        \"--eval-batch-size\", type=int, default=128, help=\"Mini-batch size. 
\"\n    )\n    parser.add_argument(\"--log-every\", type=int, default=20)\n    parser.add_argument(\n        \"--low-mem\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether use low mem RelGraphCov\",\n    )\n    parser.add_argument(\n        \"--sparse-embedding\",\n        action=\"store_true\",\n        help=\"Use sparse embedding for node embeddings.\",\n    )\n    parser.add_argument(\n        \"--dgl-sparse\",\n        action=\"store_true\",\n        help=\"Whether to use DGL sparse embedding\",\n    )\n    parser.add_argument(\n        \"--layer-norm\",\n        default=False,\n        action=\"store_true\",\n        help=\"Use layer norm\",\n    )\n    parser.add_argument(\n        \"--local_rank\", type=int, help=\"get rank of the process\"\n    )\n    parser.add_argument(\n        \"--standalone\", action=\"store_true\", help=\"run in the standalone mode\"\n    )\n    args = parser.parse_args()\n\n    # if validation_fanout is None, set it with args.fanout\n    if args.validation_fanout is None:\n        args.validation_fanout = args.fanout\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/get_mag_data.py",
    "content": "import json\n\nimport dgl\nimport numpy as np\nimport torch as th\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n# Load OGB-MAG.\ndataset = DglNodePropPredDataset(name=\"ogbn-mag\")\nhg_orig, labels = dataset[0]\nsubgs = {}\nfor etype in hg_orig.canonical_etypes:\n    u, v = hg_orig.all_edges(etype=etype)\n    subgs[etype] = (u, v)\n    subgs[(etype[2], \"rev-\" + etype[1], etype[0])] = (v, u)\nhg = dgl.heterograph(subgs)\nhg.nodes[\"paper\"].data[\"feat\"] = hg_orig.nodes[\"paper\"].data[\"feat\"]\n\nsplit_idx = dataset.get_idx_split()\ntrain_idx = split_idx[\"train\"][\"paper\"]\nval_idx = split_idx[\"valid\"][\"paper\"]\ntest_idx = split_idx[\"test\"][\"paper\"]\npaper_labels = labels[\"paper\"].squeeze()\n\ntrain_mask = th.zeros((hg.num_nodes(\"paper\"),), dtype=th.bool)\ntrain_mask[train_idx] = True\nval_mask = th.zeros((hg.num_nodes(\"paper\"),), dtype=th.bool)\nval_mask[val_idx] = True\ntest_mask = th.zeros((hg.num_nodes(\"paper\"),), dtype=th.bool)\ntest_mask[test_idx] = True\nhg.nodes[\"paper\"].data[\"train_mask\"] = train_mask\nhg.nodes[\"paper\"].data[\"val_mask\"] = val_mask\nhg.nodes[\"paper\"].data[\"test_mask\"] = test_mask\nhg.nodes[\"paper\"].data[\"labels\"] = paper_labels\n\nwith open(\"outputs/mag.json\") as json_file:\n    metadata = json.load(json_file)\n\nfor part_id in range(metadata[\"num_parts\"]):\n    subg = dgl.load_graphs(\"outputs/part{}/graph.dgl\".format(part_id))[0][0]\n\n    node_data = {}\n    for ntype in hg.ntypes:\n        local_node_idx = th.logical_and(\n            subg.ndata[\"inner_node\"].bool(),\n            subg.ndata[dgl.NTYPE] == hg.get_ntype_id(ntype),\n        )\n        local_nodes = subg.ndata[\"orig_id\"][local_node_idx].numpy()\n        for name in hg.nodes[ntype].data:\n            node_data[ntype + \"/\" + name] = hg.nodes[ntype].data[name][\n                local_nodes\n            ]\n    print(\"node features:\", node_data.keys())\n    dgl.data.utils.save_tensors(\n        
\"outputs/\" + metadata[\"part-{}\".format(part_id)][\"node_feats\"],\n        node_data,\n    )\n\n    edge_data = {}\n    for etype in hg.etypes:\n        local_edges = subg.edata[\"orig_id\"][\n            subg.edata[dgl.ETYPE] == hg.get_etype_id(etype)\n        ]\n        for name in hg.edges[etype].data:\n            edge_data[etype + \"/\" + name] = hg.edges[etype].data[name][\n                local_edges\n            ]\n    print(\"edge features:\", edge_data.keys())\n    dgl.data.utils.save_tensors(\n        \"outputs/\" + metadata[\"part-{}\".format(part_id)][\"edge_feats\"],\n        edge_data,\n    )\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/partition_graph.py",
    "content": "import argparse\nimport time\n\nimport dgl\nimport numpy as np\nimport torch as th\n\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\ndef load_ogb(dataset):\n    if dataset == \"ogbn-mag\":\n        dataset = DglNodePropPredDataset(name=dataset)\n        split_idx = dataset.get_idx_split()\n        train_idx = split_idx[\"train\"][\"paper\"]\n        val_idx = split_idx[\"valid\"][\"paper\"]\n        test_idx = split_idx[\"test\"][\"paper\"]\n        hg_orig, labels = dataset[0]\n        subgs = {}\n        for etype in hg_orig.canonical_etypes:\n            u, v = hg_orig.all_edges(etype=etype)\n            subgs[etype] = (u, v)\n            subgs[(etype[2], \"rev-\" + etype[1], etype[0])] = (v, u)\n        hg = dgl.heterograph(subgs)\n        hg.nodes[\"paper\"].data[\"feat\"] = hg_orig.nodes[\"paper\"].data[\"feat\"]\n        paper_labels = labels[\"paper\"].squeeze()\n\n        num_rels = len(hg.canonical_etypes)\n        num_of_ntype = len(hg.ntypes)\n        num_classes = dataset.num_classes\n        category = \"paper\"\n        print(\"Number of relations: {}\".format(num_rels))\n        print(\"Number of class: {}\".format(num_classes))\n        print(\"Number of train: {}\".format(len(train_idx)))\n        print(\"Number of valid: {}\".format(len(val_idx)))\n        print(\"Number of test: {}\".format(len(test_idx)))\n\n        # get target category id\n        category_id = len(hg.ntypes)\n        for i, ntype in enumerate(hg.ntypes):\n            if ntype == category:\n                category_id = i\n\n        train_mask = th.zeros((hg.num_nodes(\"paper\"),), dtype=th.bool)\n        train_mask[train_idx] = True\n        val_mask = th.zeros((hg.num_nodes(\"paper\"),), dtype=th.bool)\n        val_mask[val_idx] = True\n        test_mask = th.zeros((hg.num_nodes(\"paper\"),), dtype=th.bool)\n        test_mask[test_idx] = True\n        hg.nodes[\"paper\"].data[\"train_mask\"] = train_mask\n        
hg.nodes[\"paper\"].data[\"val_mask\"] = val_mask\n        hg.nodes[\"paper\"].data[\"test_mask\"] = test_mask\n\n        hg.nodes[\"paper\"].data[\"labels\"] = paper_labels\n        return hg\n    else:\n        raise (\"Do not support other ogbn datasets.\")\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"Partition builtin graphs\")\n    argparser.add_argument(\n        \"--dataset\", type=str, default=\"ogbn-mag\", help=\"datasets: ogbn-mag\"\n    )\n    argparser.add_argument(\n        \"--num_parts\", type=int, default=4, help=\"number of partitions\"\n    )\n    argparser.add_argument(\n        \"--part_method\", type=str, default=\"metis\", help=\"the partition method\"\n    )\n    argparser.add_argument(\n        \"--balance_train\",\n        action=\"store_true\",\n        help=\"balance the training size in each partition.\",\n    )\n    argparser.add_argument(\n        \"--undirected\",\n        action=\"store_true\",\n        help=\"turn the graph into an undirected graph.\",\n    )\n    argparser.add_argument(\n        \"--balance_edges\",\n        action=\"store_true\",\n        help=\"balance the number of edges in each partition.\",\n    )\n    argparser.add_argument(\n        \"--num_trainers_per_machine\",\n        type=int,\n        default=1,\n        help=\"the number of trainers per machine. 
The trainer ids are stored\\\n                                in the node feature 'trainer_id'\",\n    )\n    argparser.add_argument(\n        \"--output\",\n        type=str,\n        default=\"data\",\n        help=\"Output path of partitioned graph.\",\n    )\n    args = argparser.parse_args()\n\n    start = time.time()\n    g = load_ogb(args.dataset)\n\n    print(\n        \"load {} takes {:.3f} seconds\".format(args.dataset, time.time() - start)\n    )\n    print(\"|V|={}, |E|={}\".format(g.num_nodes(), g.num_edges()))\n    print(\n        \"train: {}, valid: {}, test: {}\".format(\n            th.sum(g.nodes[\"paper\"].data[\"train_mask\"]),\n            th.sum(g.nodes[\"paper\"].data[\"val_mask\"]),\n            th.sum(g.nodes[\"paper\"].data[\"test_mask\"]),\n        )\n    )\n\n    if args.balance_train:\n        balance_ntypes = {\"paper\": g.nodes[\"paper\"].data[\"train_mask\"]}\n    else:\n        balance_ntypes = None\n\n    dgl.distributed.partition_graph(\n        g,\n        args.dataset,\n        args.num_parts,\n        args.output,\n        part_method=args.part_method,\n        balance_ntypes=balance_ntypes,\n        balance_edges=args.balance_edges,\n        num_trainers_per_machine=args.num_trainers_per_machine,\n    )\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/preprocessing_dist_training/edges/identity1/sample.csv",
    "content": "identity1,0,0,1\nidentity1,1,0,2\n\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/preprocessing_dist_training/edges/identity2/sample.csv",
    "content": "identity2,0,1,2\nidentity2,1,0,2\n\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/preprocessing_dist_training/edges/identity3/sample.csv",
    "content": "identity3,0,0,2\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/preprocessing_dist_training/metis_creation.py",
    "content": "import argparse\nimport glob\nimport json\nimport os\nfrom collections import defaultdict\n\nimport pandas as pd\n\npath = os.getcwd()\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"-n\", \"--name\", help=\"name of graph to create\", default=\"order\"\n)\nparser.add_argument(\n    \"-nc\",\n    \"--node_column\",\n    nargs=\"+\",\n    default=[\"order_id\", \"entity_index\", \"order_datetime\", \"cid\"],\n)\nparser.add_argument(\"-nk\", \"--node_key\", default=\"entity_index\")\nparser.add_argument(\n    \"-ec\",\n    \"--edge_column\",\n    nargs=\"+\",\n    default=[\n        \"predicate_type\",\n        \"predicate_index\",\n        \"entity_index\",\n        \"entity_index_y\",\n    ],\n)\nparser.add_argument(\"-es\", \"--edge_start\", default=\"entity_index\")\nparser.add_argument(\"-en\", \"--edge_end\", default=\"entity_index_y\")\nargs = parser.parse_args()\n\n\n# Store all types of node in nodes folder\nnodes_list = sorted(glob.glob(os.path.join(path, \"nodes/*\")))\n\nif os.path.exists(\"{}_nodes.txt\".format(args.name)):\n    os.remove(\"{}_nodes.txt\".format(args.name))\n\nschema_dict = defaultdict(dict)\n\nnode_type_id = 0\nall_nodes_count = 0\nfor node_type_name in nodes_list:\n    nodes_count = 0\n    csv_files = sorted(glob.glob(os.path.join(node_type_name, \"*.csv\")))\n    for file_name in csv_files:\n        df = pd.read_csv(\n            file_name,\n            error_bad_lines=False,\n            escapechar=\"\\\\\",\n            names=args.node_column,\n            usecols=[*range(len(args.node_column))],\n        )\n        df_entity = pd.DataFrame(df[args.node_key], columns=[args.node_key])\n        df_entity[\"type\"] = node_type_id\n        column_list = [\"type\"]\n        for weight_index in range(len(nodes_list)):\n            weight_num = \"weight{}\".format(weight_index)\n            column_list.append(weight_num)\n            if weight_index == node_type_id:\n                df_entity[weight_num] = 
1\n            else:\n                df_entity[weight_num] = 0\n        nodes_count += len(df_entity.index)\n        column_list.append(args.node_key)\n        # This loop is trying to create a file which serves as an input for Metis Algorithm.\n        # More details about metis input can be found here : https://docs.dgl.ai/en/0.6.x/guide/distributed-preprocessing.html#input-format-for-parmetis\n        df_entity.to_csv(\n            \"{}_nodes.txt\".format(args.name),\n            columns=column_list,\n            sep=\" \",\n            index=False,\n            header=False,\n            mode=\"a\",\n        )\n    schema_dict[\"nid\"][os.path.basename(node_type_name)] = [\n        all_nodes_count,\n        nodes_count + all_nodes_count,\n    ]\n    all_nodes_count += nodes_count\n    node_type_id += 1\n\n\nif os.path.exists(\"{}_edges.txt\".format(args.name)):\n    os.remove(\"{}_edges.txt\".format(args.name))\n\n# Store all types of edge in edges folder\nedges_list = sorted(glob.glob(os.path.join(path, \"edges/*\")))\n\n\nall_edges_count = 0\nedge_type_id = 0\nfor edge_type_name in edges_list:\n    edge_count = 0\n    csv_files = sorted(glob.glob(os.path.join(edge_type_name, \"*.csv\")))\n    for file_name in csv_files:\n        df = pd.read_csv(\n            file_name,\n            error_bad_lines=False,\n            escapechar=\"\\\\\",\n            names=args.edge_column,\n            usecols=[*range(len(args.edge_column))],\n        )\n        df_entity = pd.DataFrame(\n            df[[args.edge_start, args.edge_end]],\n            columns=[args.edge_start, args.edge_end],\n        )\n        df_entity[\"type\"] = edge_type_id\n        df_entity = df_entity.reset_index()\n        df_entity[\"number\"] = df_entity.index + edge_count\n        edge_count += len(df_entity.index)\n        # This loop is trying to create a file which serves as an input for Metis Algorithm.\n        # More details about metis input can be found here : 
https://docs.dgl.ai/en/0.6.x/guide/distributed-preprocessing.html#input-format-for-parmetis\n        df_entity.to_csv(\n            \"{}_edges.txt\".format(args.name),\n            columns=[args.edge_start, args.edge_end, \"number\", \"type\"],\n            sep=\" \",\n            index=False,\n            header=False,\n            mode=\"a\",\n        )\n    schema_dict[\"eid\"][os.path.basename(edge_type_name)] = [\n        all_edges_count,\n        all_edges_count + edge_count,\n    ]\n    edge_type_id += 1\n    all_edges_count += edge_count\n\nif os.path.exists(\"{}_stats.txt\".format(args.name)):\n    os.remove(\"{}_stats.txt\".format(args.name))\n\n\ndf = pd.DataFrame(\n    [[all_nodes_count, all_edges_count, len(nodes_list)]],\n    columns=[\"nodes_count\", \"edges_count\", \"weight_count\"],\n)\ndf.to_csv(\n    \"{}_stats.txt\".format(args.name),\n    columns=[\"nodes_count\", \"edges_count\", \"weight_count\"],\n    sep=\" \",\n    index=False,\n    header=False,\n)\n\nif os.path.exists(\"{}.json\".format(args.name)):\n    os.remove(\"{}.json\".format(args.name))\n\nwith open(\"{}.json\".format(args.name), \"w\", encoding=\"utf8\") as json_file:\n    json.dump(schema_dict, json_file, ensure_ascii=False)\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/preprocessing_dist_training/nodes/order/sample.csv",
    "content": "171-0000102-1785122,0,2021-06-01 21:15:33,18604601535\n171-0000550-1206725,1,2021-06-08 12:53:53,19613747325\n171-0000784-4201160,2,2021-06-05 16:27:42,8348611025\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/preprocessing_dist_training/pre_process_dist_training.sh",
    "content": "#!/bin/bash\n\n\ncur_dir=$(pwd)\nhost_count=`cat hostfile | wc -l`\ngraph_name=\"order\"\nperhost_part=2\ncurrent_host=`ifconfig | grep -Eo 'inet (addr:)?([0-9]*\\.){3}[0-9]*' | grep -Eo '([0-9]*\\.){3}[0-9]*' | grep -v '127.0.0.1'`\n\necho \"metis creation start\"\n\n##Nodes \n`python3 metis_creation.py -n ${graph_name}`\n\necho \"metis creation ends\"\n\necho \"directory creation starts\"\nwhile read p; do\n  if [ \"$p\" != \"$current_host\" ]; then\n    `ssh ${p} \"mkdir -p ${cur_dir}\" < /dev/null`\n  fi\ndone <hostfile\n\necho \"directory creation ends\"\n\necho \"partioning starts\"\n`mpirun --hostfile hostfile -np ${host_count} pm_dglpart ${graph_name} ${perhost_part} > mpirun.out`\necho \"partioning ends\"\n\n\necho \"scp starts\"\nwhile read p; do\n  if [ \"$p\" != \"$current_host\" ]; then\n    `scp ${p}:${cur_dir}/* ./ < /dev/null`\n  fi\ndone <hostfile\necho \"scp ends\"\n\necho \"fetching removed edges starts\"\n`cat mpirun.out | grep \"Duplicate edges with metadata\" | awk -F'[][]' '{print $4}' > remove.csv`\necho \"fetching removed edges ends\"\n\necho \"homo graph to herto graph starts\"\n`python3 substitute_to_hetero.py -n order -r remove.csv`\necho \"homo graph to herto graph ends\"\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/verify_mag_partitions.py",
    "content": "import json\nimport os\n\nimport dgl\nimport numpy as np\nimport torch as th\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\npartitions_folder = \"outputs\"\ngraph_name = \"mag\"\nwith open(\"{}/{}.json\".format(partitions_folder, graph_name)) as json_file:\n    metadata = json.load(json_file)\nnum_parts = metadata[\"num_parts\"]\n\n# Load OGB-MAG.\ndataset = DglNodePropPredDataset(name=\"ogbn-mag\")\nhg_orig, labels = dataset[0]\nsubgs = {}\nfor etype in hg_orig.canonical_etypes:\n    u, v = hg_orig.all_edges(etype=etype)\n    subgs[etype] = (u, v)\n    subgs[(etype[2], \"rev-\" + etype[1], etype[0])] = (v, u)\nhg = dgl.heterograph(subgs)\nhg.nodes[\"paper\"].data[\"feat\"] = hg_orig.nodes[\"paper\"].data[\"feat\"]\n\n# Construct node data and edge data after reshuffling.\nnode_feats = {}\nedge_feats = {}\nfor partid in range(num_parts):\n    part_node_feats = dgl.data.utils.load_tensors(\n        \"{}/part{}/node_feat.dgl\".format(partitions_folder, partid)\n    )\n    part_edge_feats = dgl.data.utils.load_tensors(\n        \"{}/part{}/edge_feat.dgl\".format(partitions_folder, partid)\n    )\n    for key in part_node_feats:\n        if key in node_feats:\n            node_feats[key].append(part_node_feats[key])\n        else:\n            node_feats[key] = [part_node_feats[key]]\n    for key in part_edge_feats:\n        if key in edge_feats:\n            edge_feats[key].append(part_edge_feats[key])\n        else:\n            edge_feats[key] = [part_edge_feats[key]]\nfor key in node_feats:\n    node_feats[key] = th.cat(node_feats[key])\nfor key in edge_feats:\n    edge_feats[key] = th.cat(edge_feats[key])\n\nntype_map = metadata[\"ntypes\"]\nntypes = [None] * len(ntype_map)\nfor key in ntype_map:\n    ntype_id = ntype_map[key]\n    ntypes[ntype_id] = key\netype_map = metadata[\"etypes\"]\netypes = [None] * len(etype_map)\nfor key in etype_map:\n    etype_id = etype_map[key]\n    etypes[etype_id] = key\n\netype2canonical = {\n    etype: 
(srctype, etype, dsttype)\n    for srctype, etype, dsttype in hg.canonical_etypes\n}\n\nnode_map = metadata[\"node_map\"]\nfor key in node_map:\n    node_map[key] = th.stack([th.tensor(row) for row in node_map[key]], 0)\nnid_map = dgl.distributed.id_map.IdMap(node_map)\nedge_map = metadata[\"edge_map\"]\nfor key in edge_map:\n    edge_map[key] = th.stack([th.tensor(row) for row in edge_map[key]], 0)\neid_map = dgl.distributed.id_map.IdMap(edge_map)\n\nfor ntype in node_map:\n    assert hg.num_nodes(ntype) == th.sum(\n        node_map[ntype][:, 1] - node_map[ntype][:, 0]\n    )\nfor etype in edge_map:\n    assert hg.num_edges(etype) == th.sum(\n        edge_map[etype][:, 1] - edge_map[etype][:, 0]\n    )\n\n# verify part_0 with graph_partition_book\neid = []\ngpb = dgl.distributed.graph_partition_book.RangePartitionBook(\n    0,\n    num_parts,\n    node_map,\n    edge_map,\n    {ntype: i for i, ntype in enumerate(hg.ntypes)},\n    {etype: i for i, etype in enumerate(hg.etypes)},\n)\nsubg0 = dgl.load_graphs(\"{}/part0/graph.dgl\".format(partitions_folder))[0][0]\nfor etype in hg.etypes:\n    type_eid = th.zeros((1,), dtype=th.int64)\n    eid.append(gpb.map_to_homo_eid(type_eid, etype))\neid = th.cat(eid)\npart_id = gpb.eid2partid(eid)\nassert th.all(part_id == 0)\nlocal_eid = gpb.eid2localeid(eid, 0)\nassert th.all(local_eid == eid)\nassert th.all(subg0.edata[dgl.EID][local_eid] == eid)\nlsrc, ldst = subg0.find_edges(local_eid)\ngsrc, gdst = subg0.ndata[dgl.NID][lsrc], subg0.ndata[dgl.NID][ldst]\n# The destination nodes are owned by the partition.\nassert th.all(gdst == ldst)\n# gdst which is not assigned into current partition is not required to equal ldst\nassert th.all(th.logical_or(gdst == ldst, subg0.ndata[\"inner_node\"][ldst] == 0))\netids, _ = gpb.map_to_per_etype(eid)\nsrc_tids, _ = gpb.map_to_per_ntype(gsrc)\ndst_tids, _ = gpb.map_to_per_ntype(gdst)\ncanonical_etypes = []\netype_ids = th.arange(0, len(etypes))\nfor src_tid, etype_id, dst_tid in 
zip(src_tids, etype_ids, dst_tids):\n    canonical_etypes.append(\n        (ntypes[src_tid], etypes[etype_id], ntypes[dst_tid])\n    )\nfor etype in canonical_etypes:\n    assert etype in hg.canonical_etypes\n\n# Load the graph partition structure.\norig_node_ids = {ntype: [] for ntype in hg.ntypes}\norig_edge_ids = {etype: [] for etype in hg.etypes}\nfor partid in range(num_parts):\n    print(\"test part\", partid)\n    part_file = \"{}/part{}/graph.dgl\".format(partitions_folder, partid)\n    subg = dgl.load_graphs(part_file)[0][0]\n    subg_src_id, subg_dst_id = subg.edges()\n    orig_src_id = subg.ndata[\"orig_id\"][subg_src_id]\n    orig_dst_id = subg.ndata[\"orig_id\"][subg_dst_id]\n    global_src_id = subg.ndata[dgl.NID][subg_src_id]\n    global_dst_id = subg.ndata[dgl.NID][subg_dst_id]\n    subg_ntype = subg.ndata[dgl.NTYPE]\n    subg_etype = subg.edata[dgl.ETYPE]\n    for ntype_id in th.unique(subg_ntype):\n        ntype = ntypes[ntype_id]\n        idx = subg_ntype == ntype_id\n        # This is global IDs after reshuffle.\n        nid = subg.ndata[dgl.NID][idx]\n        ntype_ids1, type_nid = nid_map(nid)\n        orig_type_nid = subg.ndata[\"orig_id\"][idx]\n        inner_node = subg.ndata[\"inner_node\"][idx]\n        # All nodes should have the same node type.\n        assert np.all(ntype_ids1.numpy() == int(ntype_id))\n        assert np.all(\n            nid[inner_node == 1].numpy()\n            == np.arange(node_map[ntype][partid, 0], node_map[ntype][partid, 1])\n        )\n        orig_node_ids[ntype].append(orig_type_nid[inner_node == 1])\n\n        # Check the degree of the inner nodes.\n        inner_nids = th.nonzero(\n            th.logical_and(subg_ntype == ntype_id, subg.ndata[\"inner_node\"]),\n            as_tuple=True,\n        )[0]\n        subg_deg = subg.in_degrees(inner_nids)\n        orig_nids = subg.ndata[\"orig_id\"][inner_nids]\n        # Calculate the in-degrees of nodes of a particular node type.\n        glob_deg = 
th.zeros(len(subg_deg), dtype=th.int64)\n        for etype in hg.canonical_etypes:\n            dst_ntype = etype[2]\n            if dst_ntype == ntype:\n                glob_deg += hg.in_degrees(orig_nids, etype=etype)\n        assert np.all(glob_deg.numpy() == subg_deg.numpy())\n\n        # Check node data.\n        for name in hg.nodes[ntype].data:\n            local_data = node_feats[ntype + \"/\" + name][type_nid]\n            local_data1 = hg.nodes[ntype].data[name][orig_type_nid]\n            assert np.all(local_data.numpy() == local_data1.numpy())\n\n    for etype_id in th.unique(subg_etype):\n        etype = etypes[etype_id]\n        srctype, _, dsttype = etype2canonical[etype]\n        idx = subg_etype == etype_id\n        exist = hg[etype].has_edges_between(orig_src_id[idx], orig_dst_id[idx])\n        assert np.all(exist.numpy())\n        eid = hg[etype].edge_ids(orig_src_id[idx], orig_dst_id[idx])\n        assert np.all(eid.numpy() == subg.edata[\"orig_id\"][idx].numpy())\n\n        ntype_ids, type_nid = nid_map(global_src_id[idx])\n        assert len(th.unique(ntype_ids)) == 1\n        assert ntypes[ntype_ids[0]] == srctype\n        ntype_ids, type_nid = nid_map(global_dst_id[idx])\n        assert len(th.unique(ntype_ids)) == 1\n        assert ntypes[ntype_ids[0]] == dsttype\n\n        # This is global IDs after reshuffle.\n        eid = subg.edata[dgl.EID][idx]\n        etype_ids1, type_eid = eid_map(eid)\n        orig_type_eid = subg.edata[\"orig_id\"][idx]\n        inner_edge = subg.edata[\"inner_edge\"][idx]\n        # All edges should have the same edge type.\n        assert np.all(etype_ids1.numpy() == int(etype_id))\n        assert np.all(\n            np.sort(eid[inner_edge == 1].numpy())\n            == np.arange(edge_map[etype][partid, 0], edge_map[etype][partid, 1])\n        )\n        orig_edge_ids[etype].append(orig_type_eid[inner_edge == 1])\n\n        # Check edge data.\n        for name in hg.edges[etype].data:\n            local_data = 
edge_feats[etype + \"/\" + name][type_eid]\n            local_data1 = hg.edges[etype].data[name][orig_type_eid]\n            assert np.all(local_data.numpy() == local_data1.numpy())\n\nfor ntype in orig_node_ids:\n    nids = th.cat(orig_node_ids[ntype])\n    nids = th.sort(nids)[0]\n    assert np.all((nids == th.arange(hg.num_nodes(ntype))).numpy())\n\nfor etype in orig_edge_ids:\n    eids = th.cat(orig_edge_ids[etype])\n    eids = th.sort(eids)[0]\n    assert np.all((eids == th.arange(hg.num_edges(etype))).numpy())\n"
  },
  {
    "path": "examples/pytorch/rgcn/experimental/write_mag.py",
    "content": "import json\n\nimport dgl\nimport numpy as np\nimport torch as th\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n# Load OGB-MAG.\ndataset = DglNodePropPredDataset(name=\"ogbn-mag\")\nhg_orig, labels = dataset[0]\nsubgs = {}\nfor etype in hg_orig.canonical_etypes:\n    u, v = hg_orig.all_edges(etype=etype)\n    subgs[etype] = (u, v)\n    subgs[(etype[2], \"rev-\" + etype[1], etype[0])] = (v, u)\nhg = dgl.heterograph(subgs)\nhg.nodes[\"paper\"].data[\"feat\"] = hg_orig.nodes[\"paper\"].data[\"feat\"]\nprint(hg)\n\n# OGB-MAG is stored in heterogeneous format. We need to convert it into homogeneous format.\ng = dgl.to_homogeneous(hg)\ng.ndata[\"orig_id\"] = g.ndata[dgl.NID]\ng.edata[\"orig_id\"] = g.edata[dgl.EID]\nprint(\"|V|=\" + str(g.num_nodes()))\nprint(\"|E|=\" + str(g.num_edges()))\nprint(\"|NTYPE|=\" + str(len(th.unique(g.ndata[dgl.NTYPE]))))\n\n# Store the metadata of nodes.\nnum_node_weights = 0\nnode_data = [g.ndata[dgl.NTYPE].numpy()]\nfor ntype_id in th.unique(g.ndata[dgl.NTYPE]):\n    node_data.append((g.ndata[dgl.NTYPE] == ntype_id).numpy())\n    num_node_weights += 1\nnode_data.append(g.ndata[\"orig_id\"].numpy())\nnode_data = np.stack(node_data, 1)\nnp.savetxt(\"mag_nodes.txt\", node_data, fmt=\"%d\", delimiter=\" \")\n\n# Store the node features\nnode_feats = {}\nfor ntype in hg.ntypes:\n    for name in hg.nodes[ntype].data:\n        node_feats[ntype + \"/\" + name] = hg.nodes[ntype].data[name]\ndgl.data.utils.save_tensors(\"node_feat.dgl\", node_feats)\n\n# Store the metadata of edges.\n# ParMETIS cannot handle duplicated edges and self-loops. 
We should remove them\n# in the preprocessing.\nsrc_id, dst_id = g.edges()\n# Remove self-loops\nself_loop_idx = src_id == dst_id\nnot_self_loop_idx = src_id != dst_id\nself_loop_src_id = src_id[self_loop_idx]\nself_loop_dst_id = dst_id[self_loop_idx]\nself_loop_orig_id = g.edata[\"orig_id\"][self_loop_idx]\nself_loop_etype = g.edata[dgl.ETYPE][self_loop_idx]\nsrc_id = src_id[not_self_loop_idx]\ndst_id = dst_id[not_self_loop_idx]\norig_id = g.edata[\"orig_id\"][not_self_loop_idx]\netype = g.edata[dgl.ETYPE][not_self_loop_idx]\n# Remove duplicated edges.\nids = (src_id * g.num_nodes() + dst_id).numpy()\nuniq_ids, idx = np.unique(ids, return_index=True)\nduplicate_idx = np.setdiff1d(np.arange(len(ids)), idx)\nduplicate_src_id = src_id[duplicate_idx]\nduplicate_dst_id = dst_id[duplicate_idx]\nduplicate_orig_id = orig_id[duplicate_idx]\nduplicate_etype = etype[duplicate_idx]\nsrc_id = src_id[idx]\ndst_id = dst_id[idx]\norig_id = orig_id[idx]\netype = etype[idx]\nedge_data = th.stack([src_id, dst_id, orig_id, etype], 1)\nnp.savetxt(\"mag_edges.txt\", edge_data.numpy(), fmt=\"%d\", delimiter=\" \")\nremoved_edge_data = th.stack(\n    [\n        th.cat([self_loop_src_id, duplicate_src_id]),\n        th.cat([self_loop_dst_id, duplicate_dst_id]),\n        th.cat([self_loop_orig_id, duplicate_orig_id]),\n        th.cat([self_loop_etype, duplicate_etype]),\n    ],\n    1,\n)\nnp.savetxt(\n    \"mag_removed_edges.txt\", removed_edge_data.numpy(), fmt=\"%d\", delimiter=\" \"\n)\nprint(\n    \"There are {} edges, remove {} self-loops and {} duplicated edges\".format(\n        g.num_edges(), len(self_loop_src_id), len(duplicate_src_id)\n    )\n)\n\n# Store the edge features\nedge_feats = {}\nfor etype in hg.etypes:\n    for name in hg.edges[etype].data:\n        edge_feats[etype + \"/\" + name] = hg.edges[etype].data[name]\ndgl.data.utils.save_tensors(\"edge_feat.dgl\", edge_feats)\n\n# Store the basic metadata of the graph.\ngraph_stats = [g.num_nodes(), len(src_id), 
num_node_weights]\nwith open(\"mag_stats.txt\", \"w\") as filehandle:\n    filehandle.writelines(\n        \"{} {} {}\".format(graph_stats[0], graph_stats[1], graph_stats[2])\n    )\n\n# Store the ID ranges of nodes and edges of the entire graph.\nnid_ranges = {}\neid_ranges = {}\nfor ntype in hg.ntypes:\n    ntype_id = hg.get_ntype_id(ntype)\n    nid = th.nonzero(g.ndata[dgl.NTYPE] == ntype_id, as_tuple=True)[0]\n    per_type_nid = g.ndata[\"orig_id\"][nid]\n    assert np.all((per_type_nid == th.arange(len(per_type_nid))).numpy())\n    assert np.all((nid == th.arange(nid[0], nid[-1] + 1)).numpy())\n    nid_ranges[ntype] = [int(nid[0]), int(nid[-1] + 1)]\nfor etype in hg.etypes:\n    etype_id = hg.get_etype_id(etype)\n    eid = th.nonzero(g.edata[dgl.ETYPE] == etype_id, as_tuple=True)[0]\n    assert np.all((eid == th.arange(eid[0], eid[-1] + 1)).numpy())\n    eid_ranges[etype] = [int(eid[0]), int(eid[-1] + 1)]\nwith open(\"mag.json\", \"w\") as outfile:\n    json.dump({\"nid\": nid_ranges, \"eid\": eid_ranges}, outfile, indent=4)\n"
  },
  {
    "path": "examples/pytorch/rgcn/link.py",
    "content": "import dgl\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\nfrom dgl.data.knowledge_graph import FB15k237Dataset\nfrom dgl.dataloading import GraphDataLoader\nfrom dgl.nn.pytorch import RelGraphConv\n\n\n# for building training/testing graphs\ndef get_subset_g(g, mask, num_rels, bidirected=False):\n    src, dst = g.edges()\n    sub_src = src[mask]\n    sub_dst = dst[mask]\n    sub_rel = g.edata[\"etype\"][mask]\n\n    if bidirected:\n        sub_src, sub_dst = torch.cat([sub_src, sub_dst]), torch.cat(\n            [sub_dst, sub_src]\n        )\n        sub_rel = torch.cat([sub_rel, sub_rel + num_rels])\n\n    sub_g = dgl.graph((sub_src, sub_dst), num_nodes=g.num_nodes())\n    sub_g.edata[dgl.ETYPE] = sub_rel\n    return sub_g\n\n\nclass GlobalUniform:\n    def __init__(self, g, sample_size):\n        self.sample_size = sample_size\n        self.eids = np.arange(g.num_edges())\n\n    def sample(self):\n        return torch.from_numpy(np.random.choice(self.eids, self.sample_size))\n\n\nclass NegativeSampler:\n    def __init__(self, k=10):  # negative sampling rate = 10\n        self.k = k\n\n    def sample(self, pos_samples, num_nodes):\n        batch_size = len(pos_samples)\n        neg_batch_size = batch_size * self.k\n        neg_samples = np.tile(pos_samples, (self.k, 1))\n\n        values = np.random.randint(num_nodes, size=neg_batch_size)\n        choices = np.random.uniform(size=neg_batch_size)\n        subj = choices > 0.5\n        obj = choices <= 0.5\n        neg_samples[subj, 0] = values[subj]\n        neg_samples[obj, 2] = values[obj]\n        samples = np.concatenate((pos_samples, neg_samples))\n\n        # binary labels indicating positive and negative samples\n        labels = np.zeros(batch_size * (self.k + 1), dtype=np.float32)\n        labels[:batch_size] = 1\n\n        return torch.from_numpy(samples), torch.from_numpy(labels)\n\n\nclass SubgraphIterator:\n    def __init__(self, 
g, num_rels, sample_size=30000, num_epochs=6000):\n        self.g = g\n        self.num_rels = num_rels\n        self.sample_size = sample_size\n        self.num_epochs = num_epochs\n        self.pos_sampler = GlobalUniform(g, sample_size)\n        self.neg_sampler = NegativeSampler()\n\n    def __len__(self):\n        return self.num_epochs\n\n    def __getitem__(self, i):\n        eids = self.pos_sampler.sample()\n        src, dst = self.g.find_edges(eids)\n        src, dst = src.numpy(), dst.numpy()\n        rel = self.g.edata[dgl.ETYPE][eids].numpy()\n\n        # relabel nodes to have consecutive node IDs\n        uniq_v, edges = np.unique((src, dst), return_inverse=True)\n        num_nodes = len(uniq_v)\n        # edges is the concatenation of src, dst with relabeled ID\n        src, dst = np.reshape(edges, (2, -1))\n        relabeled_data = np.stack((src, rel, dst)).transpose()\n\n        samples, labels = self.neg_sampler.sample(relabeled_data, num_nodes)\n\n        # use only half of the positive edges\n        chosen_ids = np.random.choice(\n            np.arange(self.sample_size),\n            size=int(self.sample_size / 2),\n            replace=False,\n        )\n        src = src[chosen_ids]\n        dst = dst[chosen_ids]\n        rel = rel[chosen_ids]\n        src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))\n        rel = np.concatenate((rel, rel + self.num_rels))\n        sub_g = dgl.graph((src, dst), num_nodes=num_nodes)\n        sub_g.edata[dgl.ETYPE] = torch.from_numpy(rel)\n        sub_g.edata[\"norm\"] = dgl.norm_by_dst(sub_g).unsqueeze(-1)\n        uniq_v = torch.from_numpy(uniq_v).view(-1).long()\n\n        return sub_g, uniq_v, samples, labels\n\n\nclass RGCN(nn.Module):\n    def __init__(self, num_nodes, h_dim, num_rels):\n        super().__init__()\n        # two-layer RGCN\n        self.emb = nn.Embedding(num_nodes, h_dim)\n        self.conv1 = RelGraphConv(\n            h_dim,\n            h_dim,\n            num_rels,\n  
          regularizer=\"bdd\",\n            num_bases=100,\n            self_loop=True,\n        )\n        self.conv2 = RelGraphConv(\n            h_dim,\n            h_dim,\n            num_rels,\n            regularizer=\"bdd\",\n            num_bases=100,\n            self_loop=True,\n        )\n        self.dropout = nn.Dropout(0.2)\n\n    def forward(self, g, nids):\n        x = self.emb(nids)\n        h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata[\"norm\"]))\n        h = self.dropout(h)\n        h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata[\"norm\"])\n        return self.dropout(h)\n\n\nclass LinkPredict(nn.Module):\n    def __init__(self, num_nodes, num_rels, h_dim=500, reg_param=0.01):\n        super().__init__()\n        self.rgcn = RGCN(num_nodes, h_dim, num_rels * 2)\n        self.reg_param = reg_param\n        self.w_relation = nn.Parameter(torch.Tensor(num_rels, h_dim))\n        nn.init.xavier_uniform_(\n            self.w_relation, gain=nn.init.calculate_gain(\"relu\")\n        )\n\n    def calc_score(self, embedding, triplets):\n        s = embedding[triplets[:, 0]]\n        r = self.w_relation[triplets[:, 1]]\n        o = embedding[triplets[:, 2]]\n        score = torch.sum(s * r * o, dim=1)\n        return score\n\n    def forward(self, g, nids):\n        return self.rgcn(g, nids)\n\n    def regularization_loss(self, embedding):\n        return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2))\n\n    def get_loss(self, embed, triplets, labels):\n        # each row in the triplets is a 3-tuple of (source, relation, destination)\n        score = self.calc_score(embed, triplets)\n        predict_loss = F.binary_cross_entropy_with_logits(score, labels)\n        reg_loss = self.regularization_loss(embed)\n        return predict_loss + self.reg_param * reg_loss\n\n\ndef filter(\n    triplets_to_filter, target_s, target_r, target_o, num_nodes, filter_o=True\n):\n    \"\"\"Get candidate heads or tails to score\"\"\"\n    
target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)\n    # Add the ground truth node first\n    if filter_o:\n        candidate_nodes = [target_o]\n    else:\n        candidate_nodes = [target_s]\n    for e in range(num_nodes):\n        triplet = (\n            (target_s, target_r, e) if filter_o else (e, target_r, target_o)\n        )\n        # Do not consider a node if it leads to a real triplet\n        if triplet not in triplets_to_filter:\n            candidate_nodes.append(e)\n    return torch.LongTensor(candidate_nodes)\n\n\ndef perturb_and_get_filtered_rank(\n    emb, w, s, r, o, test_size, triplets_to_filter, filter_o=True\n):\n    \"\"\"Perturb subject or object in the triplets\"\"\"\n    num_nodes = emb.shape[0]\n    ranks = []\n    for idx in tqdm.tqdm(range(test_size), desc=\"Evaluate\"):\n        target_s = s[idx]\n        target_r = r[idx]\n        target_o = o[idx]\n        candidate_nodes = filter(\n            triplets_to_filter,\n            target_s,\n            target_r,\n            target_o,\n            num_nodes,\n            filter_o=filter_o,\n        )\n        if filter_o:\n            emb_s = emb[target_s]\n            emb_o = emb[candidate_nodes]\n        else:\n            emb_s = emb[candidate_nodes]\n            emb_o = emb[target_o]\n        target_idx = 0\n        emb_r = w[target_r]\n        emb_triplet = emb_s * emb_r * emb_o\n        scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))\n\n        _, indices = torch.sort(scores, descending=True)\n        rank = int((indices == target_idx).nonzero())\n        ranks.append(rank)\n    return torch.LongTensor(ranks)\n\n\ndef calc_mrr(emb, w, mask, triplets_to_filter, batch_size=100, filter=True):\n    with torch.no_grad():\n        test_triplets = triplets_to_filter[mask]\n        s, r, o = test_triplets[:, 0], test_triplets[:, 1], test_triplets[:, 2]\n        test_size = len(s)\n        triplets_to_filter = {\n            tuple(triplet) for triplet in 
triplets_to_filter.tolist()\n        }\n        ranks_s = perturb_and_get_filtered_rank(\n            emb, w, s, r, o, test_size, triplets_to_filter, filter_o=False\n        )\n        ranks_o = perturb_and_get_filtered_rank(\n            emb, w, s, r, o, test_size, triplets_to_filter\n        )\n        ranks = torch.cat([ranks_s, ranks_o])\n        ranks += 1  # change to 1-indexed\n        mrr = torch.mean(1.0 / ranks.float()).item()\n    return mrr\n\n\ndef train(\n    dataloader,\n    test_g,\n    test_nids,\n    val_mask,\n    triplets,\n    device,\n    model_state_file,\n    model,\n):\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n    best_mrr = 0\n    for epoch, batch_data in enumerate(dataloader):  # single graph batch\n        model.train()\n        g, train_nids, edges, labels = batch_data\n        g = g.to(device)\n        train_nids = train_nids.to(device)\n        edges = edges.to(device)\n        labels = labels.to(device)\n\n        embed = model(g, train_nids)\n        loss = model.get_loss(embed, edges, labels)\n        optimizer.zero_grad()\n        loss.backward()\n        nn.utils.clip_grad_norm_(\n            model.parameters(), max_norm=1.0\n        )  # clip gradients\n        optimizer.step()\n        print(\n            \"Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}\".format(\n                epoch, loss.item(), best_mrr\n            )\n        )\n        if (epoch + 1) % 500 == 0:\n            # perform validation on CPU because full graph is too large\n            model = model.cpu()\n            model.eval()\n            embed = model(test_g, test_nids)\n            mrr = calc_mrr(\n                embed, model.w_relation, val_mask, triplets, batch_size=500\n            )\n            # save best model\n            if best_mrr < mrr:\n                best_mrr = mrr\n                torch.save(\n                    {\"state_dict\": model.state_dict(), \"epoch\": epoch},\n                    model_state_file,\n          
      )\n            model = model.to(device)\n\n\nif __name__ == \"__main__\":\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    print(f\"Training with DGL built-in RGCN module\")\n\n    # load and preprocess dataset\n    data = FB15k237Dataset(reverse=False)\n    g = data[0]\n    num_nodes = g.num_nodes()\n    num_rels = data.num_rels\n    train_g = get_subset_g(g, g.edata[\"train_mask\"], num_rels)\n    test_g = get_subset_g(g, g.edata[\"train_mask\"], num_rels, bidirected=True)\n    test_g.edata[\"norm\"] = dgl.norm_by_dst(test_g).unsqueeze(-1)\n    test_nids = torch.arange(0, num_nodes)\n    val_mask = g.edata[\"val_mask\"]\n    test_mask = g.edata[\"test_mask\"]\n    subg_iter = SubgraphIterator(train_g, num_rels)  # uniform edge sampling\n    dataloader = GraphDataLoader(\n        subg_iter, batch_size=1, collate_fn=lambda x: x[0]\n    )\n\n    # Prepare data for metric computation\n    src, dst = g.edges()\n    triplets = torch.stack([src, g.edata[\"etype\"], dst], dim=1)\n\n    # create RGCN model\n    model = LinkPredict(num_nodes, num_rels).to(device)\n\n    # train\n    model_state_file = \"model_state.pth\"\n    train(\n        dataloader,\n        test_g,\n        test_nids,\n        val_mask,\n        triplets,\n        device,\n        model_state_file,\n        model,\n    )\n\n    # testing\n    print(\"Testing...\")\n    checkpoint = torch.load(model_state_file, weights_only=False)\n    model = model.cpu()  # test on CPU\n    model.eval()\n    model.load_state_dict(checkpoint[\"state_dict\"])\n    embed = model(test_g, test_nids)\n    best_mrr = calc_mrr(\n        embed, model.w_relation, test_mask, triplets, batch_size=500\n    )\n    print(\n        \"Best MRR {:.4f} achieved using the epoch {:04d}\".format(\n            best_mrr, checkpoint[\"epoch\"]\n        )\n    )\n"
  },
  {
    "path": "examples/pytorch/rgcn/model.py",
    "content": "import dgl\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import DGLGraph\n\nfrom dgl.nn.pytorch import RelGraphConv\n\n\nclass RGCN(nn.Module):\n    def __init__(\n        self,\n        num_nodes,\n        h_dim,\n        out_dim,\n        num_rels,\n        regularizer=\"basis\",\n        num_bases=-1,\n        dropout=0.0,\n        self_loop=False,\n        ns_mode=False,\n    ):\n        super(RGCN, self).__init__()\n\n        if num_bases == -1:\n            num_bases = num_rels\n        self.emb = nn.Embedding(num_nodes, h_dim)\n        self.conv1 = RelGraphConv(\n            h_dim, h_dim, num_rels, regularizer, num_bases, self_loop=self_loop\n        )\n        self.conv2 = RelGraphConv(\n            h_dim,\n            out_dim,\n            num_rels,\n            regularizer,\n            num_bases,\n            self_loop=self_loop,\n        )\n        self.dropout = nn.Dropout(dropout)\n        self.ns_mode = ns_mode\n\n    def forward(self, g, nids=None):\n        if self.ns_mode:\n            # forward for neighbor sampling\n            x = self.emb(g[0].srcdata[dgl.NID])\n            h = self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata[\"norm\"])\n            h = self.dropout(F.relu(h))\n            h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata[\"norm\"])\n            return h\n        else:\n            x = self.emb.weight if nids is None else self.emb(nids)\n            h = self.conv1(g, x, g.edata[dgl.ETYPE], g.edata[\"norm\"])\n            h = self.dropout(F.relu(h))\n            h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata[\"norm\"])\n            return h\n"
  },
  {
    "path": "examples/pytorch/rgcn-hetero/.gitignore",
    "content": "*.pth\n*.pt\n"
  },
  {
    "path": "examples/pytorch/rgcn-hetero/README.md",
    "content": "# Relational-GCN\n\n* Paper: [https://arxiv.org/abs/1703.06103](https://arxiv.org/abs/1703.06103)\n* Author's code for entity classification: [https://github.com/tkipf/relational-gcn](https://github.com/tkipf/relational-gcn)\n* Author's code for link prediction: [https://github.com/MichSchli/RelationPrediction](https://github.com/MichSchli/RelationPrediction)\n\nThe preprocessing is slightly different from the author's code. We directly load and preprocess\nraw RDF data. For AIFB, BGS and AM,\nall literal nodes are pruned from the graph. For AIFB, some training/testing nodes\nthus become orphan and are excluded from the training/testing set. The resulting graph\nhas fewer entities and relations. As a reference (numbers include reverse edges and relations):\n\n| Dataset | #Nodes | #Edges | #Relations | #Labeled |\n| --- | --- | --- | --- | --- |\n| AIFB | 8,285 | 58,086 | 90 | 176 |\n| AIFB-hetero | 7,262 | 48,810 | 78 | 176 |\n| MUTAG | 23,644 | 148,454 | 46 | 340 |\n| MUTAG-hetero | 27,163 | 148,100 | 46 | 340 |\n| BGS | 333,845 | 1,832,398 | 206 | 146 |\n| BGS-hetero | 94,806 | 672,884 | 96 | 146 |\n| AM | 1,666,764 | 11,976,642 | 266 | 1000 |\n| AM-hetero | 881,680 | 5,668,682 | 96 | 1000 |\n\n### Dependencies\n* PyTorch 1.0+\n* requests\n* rdflib\n\n```\npip install requests torch rdflib pandas\n```\n\nExample code was tested with rdflib 4.2.2 and pandas 0.23.4\n\n### Entity Classification\n\nAll experiments use one-hot encoding as featureless input. 
The best accuracy is reported.\n\n\nAIFB: accuracy 96.11% (5 runs, DGL), 95.83% (paper)\n```\npython3 entity_classify.py -d aifb --testing --gpu 0\n```\n\nMUTAG: accuracy 72.06% (5 runs, DGL), 73.23% (paper)\n```\npython3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0\n```\n\nBGS: accuracy 91.73% (5 runs, DGL), 83.10% (paper)\n```\npython3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0\n```\n\nAM: accuracy 88.28% (5 runs, DGL), 89.29% (paper)\n```\npython3 entity_classify.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0\n```\n\n### Entity Classification w/ minibatch training\n\nAccuracy numbers are reported over 5 runs.\n\nAIFB: accuracy best=97.22% avg=94.44%\n```\npython3 entity_classify_mb.py -d aifb --testing --gpu 0 --fanout=8\n```\n\nMUTAG: accuracy best=76.47% avg=67.37%\n```\npython3 entity_classify_mb.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0 --batch-size=50 --fanout=8\n```\n\nBGS: accuracy best=96.55% avg=91.04%\n```\npython3 entity_classify_mb.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0\n```\n\nAM: accuracy best=89.39% avg=88.55%\n```\npython3 entity_classify_mb.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0\n```\n\n### Offline Inference\nA trained model can be exported by passing the '--model\\_path <PATH>' argument to entity\\_classify.py. 
Then test\\_classify.py can load the saved model and run the testing offline.\n\nAIFB:\n```\npython3 entity_classify.py -d aifb --testing --gpu 0 --model_path \"aifb.pt\"\npython3 test_classify.py -d aifb --gpu 0 --model_path \"aifb.pt\"\n```\n\nMUTAG:\n```\npython3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0 --model_path \"mutag.pt\"\npython3 test_classify.py -d mutag --n-bases 30 --gpu 0 --model_path \"mutag.pt\"\n```\n\nBGS:\n```\npython3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --model_path \"bgs.pt\"\npython3 test_classify.py -d bgs --n-bases 40 --gpu 0 --model_path \"bgs.pt\"\n```\n\nAM:\n```\npython3 entity_classify.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --model_path \"am.pt\"\npython3 test_classify.py -d am --n-bases 40 --gpu 0 --model_path \"am.pt\"\n```\n"
  },
  {
    "path": "examples/pytorch/rgcn-hetero/entity_classify.py",
    "content": "\"\"\"Modeling Relational Data with Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1703.06103\nReference Code: https://github.com/tkipf/relational-gcn\n\"\"\"\nimport argparse\nimport time\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\nfrom model import EntityClassify\n\n\ndef main(args):\n    # load graph data\n    if args.dataset == \"aifb\":\n        dataset = AIFBDataset()\n    elif args.dataset == \"mutag\":\n        dataset = MUTAGDataset()\n    elif args.dataset == \"bgs\":\n        dataset = BGSDataset()\n    elif args.dataset == \"am\":\n        dataset = AMDataset()\n    else:\n        raise ValueError()\n\n    g = dataset[0]\n    category = dataset.predict_category\n    num_classes = dataset.num_classes\n    train_mask = g.nodes[category].data.pop(\"train_mask\")\n    test_mask = g.nodes[category].data.pop(\"test_mask\")\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()\n    labels = g.nodes[category].data.pop(\"labels\")\n    category_id = len(g.ntypes)\n    for i, ntype in enumerate(g.ntypes):\n        if ntype == category:\n            category_id = i\n\n    # split dataset into train, validate, test\n    if args.validation:\n        val_idx = train_idx[: len(train_idx) // 5]\n        train_idx = train_idx[len(train_idx) // 5 :]\n    else:\n        val_idx = train_idx\n\n    # check cuda\n    use_cuda = args.gpu >= 0 and th.cuda.is_available()\n    if use_cuda:\n        th.cuda.set_device(args.gpu)\n        g = g.to(\"cuda:%d\" % args.gpu)\n        labels = labels.cuda()\n        train_idx = train_idx.cuda()\n        test_idx = test_idx.cuda()\n\n    # create model\n    model = EntityClassify(\n        g,\n        args.n_hidden,\n        num_classes,\n        num_bases=args.n_bases,\n        
num_hidden_layers=args.n_layers - 2,\n        dropout=args.dropout,\n        use_self_loop=args.use_self_loop,\n    )\n\n    if use_cuda:\n        model.cuda()\n\n    # optimizer\n    optimizer = th.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.l2norm\n    )\n\n    # training loop\n    print(\"start training...\")\n    dur = []\n    model.train()\n    for epoch in range(args.n_epochs):\n        optimizer.zero_grad()\n        if epoch > 5:\n            t0 = time.time()\n        logits = model()[category]\n        loss = F.cross_entropy(logits[train_idx], labels[train_idx])\n        loss.backward()\n        optimizer.step()\n        t1 = time.time()\n\n        if epoch > 5:\n            dur.append(t1 - t0)\n        train_acc = th.sum(\n            logits[train_idx].argmax(dim=1) == labels[train_idx]\n        ).item() / len(train_idx)\n        val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])\n        val_acc = th.sum(\n            logits[val_idx].argmax(dim=1) == labels[val_idx]\n        ).item() / len(val_idx)\n        print(\n            \"Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}\".format(\n                epoch,\n                train_acc,\n                loss.item(),\n                val_acc,\n                val_loss.item(),\n                np.average(dur),\n            )\n        )\n    print()\n    if args.model_path is not None:\n        th.save(model.state_dict(), args.model_path)\n\n    model.eval()\n    logits = model.forward()[category]\n    test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])\n    test_acc = th.sum(\n        logits[test_idx].argmax(dim=1) == labels[test_idx]\n    ).item() / len(test_idx)\n    print(\n        \"Test Acc: {:.4f} | Test loss: {:.4f}\".format(\n            test_acc, test_loss.item()\n        )\n    )\n    print()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"RGCN\")\n   
 parser.add_argument(\n        \"--dropout\", type=float, default=0, help=\"dropout probability\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden units\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-bases\",\n        type=int,\n        default=-1,\n        help=\"number of filter weight matrices, default: -1 [use all]\",\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=2, help=\"number of propagation rounds\"\n    )\n    parser.add_argument(\n        \"-e\",\n        \"--n-epochs\",\n        type=int,\n        default=50,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"-d\", \"--dataset\", type=str, required=True, help=\"dataset to use\"\n    )\n    parser.add_argument(\n        \"--model_path\", type=str, default=None, help=\"path for save the model\"\n    )\n    parser.add_argument(\"--l2norm\", type=float, default=0, help=\"l2 norm coef\")\n    parser.add_argument(\n        \"--use-self-loop\",\n        default=False,\n        action=\"store_true\",\n        help=\"include self feature as a special relation\",\n    )\n    fp = parser.add_mutually_exclusive_group(required=False)\n    fp.add_argument(\"--validation\", dest=\"validation\", action=\"store_true\")\n    fp.add_argument(\"--testing\", dest=\"validation\", action=\"store_false\")\n    parser.set_defaults(validation=True)\n\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/rgcn-hetero/entity_classify_heteroAPI.py",
    "content": "\"\"\"Modeling Relational Data with Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1703.06103\nReference Code: https://github.com/tkipf/relational-gcn\n\"\"\"\nimport argparse\nimport time\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\nfrom model import EntityClassify_HeteroAPI\n\n\ndef main(args):\n    # load graph data\n    if args.dataset == \"aifb\":\n        dataset = AIFBDataset()\n    elif args.dataset == \"mutag\":\n        dataset = MUTAGDataset()\n    elif args.dataset == \"bgs\":\n        dataset = BGSDataset()\n    elif args.dataset == \"am\":\n        dataset = AMDataset()\n    else:\n        raise ValueError()\n\n    g = dataset[0]\n    category = dataset.predict_category\n    num_classes = dataset.num_classes\n    train_mask = g.nodes[category].data.pop(\"train_mask\")\n    test_mask = g.nodes[category].data.pop(\"test_mask\")\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()\n    labels = g.nodes[category].data.pop(\"labels\")\n    category_id = len(g.ntypes)\n    for i, ntype in enumerate(g.ntypes):\n        if ntype == category:\n            category_id = i\n\n    # split dataset into train, validate, test\n    if args.validation:\n        val_idx = train_idx[: len(train_idx) // 5]\n        train_idx = train_idx[len(train_idx) // 5 :]\n    else:\n        val_idx = train_idx\n\n    # check cuda\n    use_cuda = args.gpu >= 0 and th.cuda.is_available()\n    if use_cuda:\n        th.cuda.set_device(args.gpu)\n        g = g.to(\"cuda:%d\" % args.gpu)\n        labels = labels.cuda()\n        train_idx = train_idx.cuda()\n        test_idx = test_idx.cuda()\n\n    # create model\n    model = EntityClassify_HeteroAPI(\n        g,\n        args.n_hidden,\n        num_classes,\n        num_bases=args.n_bases,\n        
num_hidden_layers=args.n_layers - 2,\n        dropout=args.dropout,\n        use_self_loop=args.use_self_loop,\n    )\n\n    if use_cuda:\n        model.cuda()\n\n    # optimizer\n    optimizer = th.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.l2norm\n    )\n\n    # training loop\n    print(\"start training...\")\n    dur = []\n    model.train()\n    for epoch in range(args.n_epochs):\n        optimizer.zero_grad()\n        t0 = time.time()\n        logits = model()[category]\n        loss = F.cross_entropy(logits[train_idx], labels[train_idx])\n        loss.backward()\n        optimizer.step()\n        t1 = time.time()\n\n        dur.append(t1 - t0)\n        train_acc = th.sum(\n            logits[train_idx].argmax(dim=1) == labels[train_idx]\n        ).item() / len(train_idx)\n        val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])\n        val_acc = th.sum(\n            logits[val_idx].argmax(dim=1) == labels[val_idx]\n        ).item() / len(val_idx)\n        print(\n            \"Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}\".format(\n                epoch,\n                train_acc,\n                loss.item(),\n                val_acc,\n                val_loss.item(),\n                np.average(dur),\n            )\n        )\n    print()\n    if args.model_path is not None:\n        th.save(model.state_dict(), args.model_path)\n\n    model.eval()\n    logits = model.forward()[category]\n    test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])\n    test_acc = th.sum(\n        logits[test_idx].argmax(dim=1) == labels[test_idx]\n    ).item() / len(test_idx)\n    print(\n        \"Test Acc: {:.4f} | Test loss: {:.4f}\".format(\n            test_acc, test_loss.item()\n        )\n    )\n    print()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"RGCN\")\n    parser.add_argument(\n        \"--dropout\", 
type=float, default=0, help=\"dropout probability\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden units\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-bases\",\n        type=int,\n        default=-1,\n        help=\"number of filter weight matrices, default: -1 [use all]\",\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=2, help=\"number of propagation rounds\"\n    )\n    parser.add_argument(\n        \"-e\",\n        \"--n-epochs\",\n        type=int,\n        default=50,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"-d\", \"--dataset\", type=str, required=True, help=\"dataset to use\"\n    )\n    parser.add_argument(\n        \"--model_path\", type=str, default=None, help=\"path for save the model\"\n    )\n    parser.add_argument(\"--l2norm\", type=float, default=0, help=\"l2 norm coef\")\n    parser.add_argument(\n        \"--use-self-loop\",\n        default=False,\n        action=\"store_true\",\n        help=\"include self feature as a special relation\",\n    )\n    fp = parser.add_mutually_exclusive_group(required=False)\n    fp.add_argument(\"--validation\", dest=\"validation\", action=\"store_true\")\n    fp.add_argument(\"--testing\", dest=\"validation\", action=\"store_false\")\n    parser.set_defaults(validation=True)\n\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/rgcn-hetero/entity_classify_mb.py",
    "content": "\"\"\"Modeling Relational Data with Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1703.06103\nReference Code: https://github.com/tkipf/relational-gcn\n\"\"\"\nimport argparse\nimport itertools\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch as th\nimport torch.nn.functional as F\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\nfrom model import EntityClassify, RelGraphEmbed\n\n\ndef extract_embed(node_embed, input_nodes):\n    emb = {}\n    for ntype, nid in input_nodes.items():\n        nid = input_nodes[ntype]\n        emb[ntype] = node_embed[ntype][nid]\n    return emb\n\n\ndef evaluate(model, loader, node_embed, labels, category, device):\n    model.eval()\n    total_loss = 0\n    total_acc = 0\n    count = 0\n    with loader.enable_cpu_affinity():\n        for input_nodes, seeds, blocks in loader:\n            blocks = [blk.to(device) for blk in blocks]\n            seeds = seeds[category]\n            emb = extract_embed(node_embed, input_nodes)\n            emb = {k: e.to(device) for k, e in emb.items()}\n            lbl = labels[seeds].to(device)\n            logits = model(emb, blocks)[category]\n            loss = F.cross_entropy(logits, lbl)\n            acc = th.sum(logits.argmax(dim=1) == lbl).item()\n            total_loss += loss.item() * len(seeds)\n            total_acc += acc\n            count += len(seeds)\n    return total_loss / count, total_acc / count\n\n\ndef main(args):\n    # check cuda\n    device = \"cpu\"\n    use_cuda = args.gpu >= 0 and th.cuda.is_available()\n    if use_cuda:\n        th.cuda.set_device(args.gpu)\n        device = \"cuda:%d\" % args.gpu\n\n    # load graph data\n    if args.dataset == \"aifb\":\n        dataset = AIFBDataset()\n    elif args.dataset == \"mutag\":\n        dataset = MUTAGDataset()\n    elif args.dataset == \"bgs\":\n        dataset = BGSDataset()\n    elif args.dataset == \"am\":\n        dataset = AMDataset()\n    else:\n        
raise ValueError()\n\n    g = dataset[0]\n    category = dataset.predict_category\n    num_classes = dataset.num_classes\n    train_mask = g.nodes[category].data.pop(\"train_mask\")\n    test_mask = g.nodes[category].data.pop(\"test_mask\")\n    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()\n    labels = g.nodes[category].data.pop(\"labels\")\n\n    # split dataset into train, validate, test\n    if args.validation:\n        val_idx = train_idx[: len(train_idx) // 5]\n        train_idx = train_idx[len(train_idx) // 5 :]\n    else:\n        val_idx = train_idx\n\n    # create embeddings\n    embed_layer = RelGraphEmbed(g, args.n_hidden)\n\n    if not args.data_cpu:\n        labels = labels.to(device)\n        embed_layer = embed_layer.to(device)\n\n    if args.num_workers <= 0:\n        raise ValueError(\n            \"The '--num_workers' parameter value is expected \"\n            \"to be >0, but got {}.\".format(args.num_workers)\n        )\n\n    node_embed = embed_layer()\n    # create model\n    model = EntityClassify(\n        g,\n        args.n_hidden,\n        num_classes,\n        num_bases=args.n_bases,\n        num_hidden_layers=args.n_layers - 2,\n        dropout=args.dropout,\n        use_self_loop=args.use_self_loop,\n    )\n\n    if use_cuda:\n        model.cuda()\n\n    # train sampler\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [args.fanout] * args.n_layers\n    )\n    loader = dgl.dataloading.DataLoader(\n        g,\n        {category: train_idx},\n        sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        num_workers=args.num_workers,\n    )\n\n    # validation sampler\n    # we do not use full neighbor to save computation resources\n    val_sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [args.fanout] * args.n_layers\n    )\n    val_loader = dgl.dataloading.DataLoader(\n        g,\n        {category: 
val_idx},\n        val_sampler,\n        batch_size=args.batch_size,\n        shuffle=True,\n        num_workers=args.num_workers,\n    )\n\n    # optimizer\n    all_params = itertools.chain(model.parameters(), embed_layer.parameters())\n    optimizer = th.optim.Adam(all_params, lr=args.lr, weight_decay=args.l2norm)\n\n    # training loop\n    print(\"start training...\")\n    mean = 0\n    for epoch in range(args.n_epochs):\n        model.train()\n        optimizer.zero_grad()\n        if epoch > 3:\n            t0 = time.time()\n\n        with loader.enable_cpu_affinity():\n            for i, (input_nodes, seeds, blocks) in enumerate(loader):\n                blocks = [blk.to(device) for blk in blocks]\n                seeds = seeds[\n                    category\n                ]  # we only predict the nodes with type \"category\"\n                batch_tic = time.time()\n                emb = extract_embed(node_embed, input_nodes)\n                lbl = labels[seeds]\n                if use_cuda:\n                    emb = {k: e.cuda() for k, e in emb.items()}\n                    lbl = lbl.cuda()\n                logits = model(emb, blocks)[category]\n                loss = F.cross_entropy(logits, lbl)\n                loss.backward()\n                optimizer.step()\n\n                train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(\n                    seeds\n                )\n                print(\n                    f\"Epoch {epoch:05d} | Batch {i:03d} | Train Acc: \"\n                    \"{train_acc:.4f} | Train Loss: {loss.item():.4f} | Time: \"\n                    \"{time.time() - batch_tic:.4f}\"\n                )\n\n        if epoch > 3:\n            mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)\n\n            val_loss, val_acc = evaluate(\n                model, val_loader, node_embed, labels, category, device\n            )\n            print(\n                f\"Epoch {epoch:05d} | Valid Acc: {val_acc:.4f} | 
Valid loss: \"\n                \"{val_loss:.4f} | Time: {mean:.4f}\"\n            )\n    print()\n    if args.model_path is not None:\n        th.save(model.state_dict(), args.model_path)\n\n    output = model.inference(\n        g,\n        args.batch_size,\n        \"cuda\" if use_cuda else \"cpu\",\n        args.num_workers,\n        node_embed,\n    )\n    test_pred = output[category][test_idx]\n    test_labels = labels[test_idx].to(test_pred.device)\n    test_acc = (test_pred.argmax(1) == test_labels).float().mean()\n    print(\"Test Acc: {:.4f}\".format(test_acc))\n    print()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"RGCN\")\n    parser.add_argument(\n        \"--dropout\", type=float, default=0, help=\"dropout probability\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden units\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-bases\",\n        type=int,\n        default=-1,\n        help=\"number of filter weight matrices, default: -1 [use all]\",\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=2, help=\"number of propagation rounds\"\n    )\n    parser.add_argument(\n        \"-e\",\n        \"--n-epochs\",\n        type=int,\n        default=20,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"-d\", \"--dataset\", type=str, required=True, help=\"dataset to use\"\n    )\n    parser.add_argument(\n        \"--model_path\", type=str, default=None, help=\"path for save the model\"\n    )\n    parser.add_argument(\"--l2norm\", type=float, default=0, help=\"l2 norm coef\")\n    parser.add_argument(\n        \"--use-self-loop\",\n        default=False,\n        action=\"store_true\",\n        help=\"include self feature as a special 
relation\",\n    )\n    parser.add_argument(\n        \"--batch-size\",\n        type=int,\n        default=100,\n        help=\"Mini-batch size. If -1, use full graph training.\",\n    )\n    parser.add_argument(\n        \"--fanout\", type=int, default=4, help=\"Fan-out of neighbor sampling.\"\n    )\n    parser.add_argument(\n        \"--data-cpu\",\n        action=\"store_true\",\n        help=\"By default the script puts all node features and labels \"\n        \"on GPU when using it to save time for data copy. This may \"\n        \"be undesired if they cannot fit in GPU memory at once. \"\n        \"This flag disables that.\",\n    )\n    parser.add_argument(\n        \"--num_workers\", type=int, default=4, help=\"Number of node dataloader\"\n    )\n\n    fp = parser.add_mutually_exclusive_group(required=False)\n    fp.add_argument(\"--validation\", dest=\"validation\", action=\"store_true\")\n    fp.add_argument(\"--testing\", dest=\"validation\", action=\"store_false\")\n    parser.set_defaults(validation=True)\n\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/rgcn-hetero/model.py",
    "content": "\"\"\"RGCN layer implementation\"\"\"\nfrom collections import defaultdict\n\nimport dgl\nimport dgl.function as fn\nimport dgl.nn as dglnn\n\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\n\n\nclass RelGraphConvLayer(nn.Module):\n    r\"\"\"Relational graph convolution layer.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size.\n    out_feat : int\n        Output feature size.\n    rel_names : list[str]\n        Relation names.\n    num_bases : int, optional\n        Number of bases. If is none, use number of relations. Default: None.\n    weight : bool, optional\n        True if a linear layer is applied after message passing. Default: True\n    bias : bool, optional\n        True if bias is added. Default: True\n    activation : callable, optional\n        Activation function. Default: None\n    self_loop : bool, optional\n        True to include self loop message. Default: False\n    dropout : float, optional\n        Dropout rate. 
Default: 0.0\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feat,\n        out_feat,\n        rel_names,\n        num_bases,\n        *,\n        weight=True,\n        bias=True,\n        activation=None,\n        self_loop=False,\n        dropout=0.0\n    ):\n        super(RelGraphConvLayer, self).__init__()\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.rel_names = rel_names\n        self.num_bases = num_bases\n        self.bias = bias\n        self.activation = activation\n        self.self_loop = self_loop\n\n        self.conv = dglnn.HeteroGraphConv(\n            {\n                rel: dglnn.GraphConv(\n                    in_feat, out_feat, norm=\"right\", weight=False, bias=False\n                )\n                for rel in rel_names\n            }\n        )\n\n        self.use_weight = weight\n        self.use_basis = num_bases < len(self.rel_names) and weight\n        if self.use_weight:\n            if self.use_basis:\n                self.basis = dglnn.WeightBasis(\n                    (in_feat, out_feat), num_bases, len(self.rel_names)\n                )\n            else:\n                self.weight = nn.Parameter(\n                    th.Tensor(len(self.rel_names), in_feat, out_feat)\n                )\n                nn.init.xavier_uniform_(\n                    self.weight, gain=nn.init.calculate_gain(\"relu\")\n                )\n\n        # bias\n        if bias:\n            self.h_bias = nn.Parameter(th.Tensor(out_feat))\n            nn.init.zeros_(self.h_bias)\n\n        # weight for self loop\n        if self.self_loop:\n            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))\n            nn.init.xavier_uniform_(\n                self.loop_weight, gain=nn.init.calculate_gain(\"relu\")\n            )\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, g, inputs):\n        \"\"\"Forward computation\n\n        Parameters\n        ----------\n        g : 
DGLGraph\n            Input graph.\n        inputs : dict[str, torch.Tensor]\n            Node feature for each node type.\n\n        Returns\n        -------\n        dict[str, torch.Tensor]\n            New node features for each node type.\n        \"\"\"\n        g = g.local_var()\n        if self.use_weight:\n            weight = self.basis() if self.use_basis else self.weight\n            wdict = {\n                self.rel_names[i]: {\"weight\": w.squeeze(0)}\n                for i, w in enumerate(th.split(weight, 1, dim=0))\n            }\n        else:\n            wdict = {}\n\n        if g.is_block:\n            inputs_src = inputs\n            inputs_dst = {\n                k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()\n            }\n        else:\n            inputs_src = inputs_dst = inputs\n\n        hs = self.conv(g, inputs, mod_kwargs=wdict)\n\n        def _apply(ntype, h):\n            if self.self_loop:\n                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)\n            if self.bias:\n                h = h + self.h_bias\n            if self.activation:\n                h = self.activation(h)\n            return self.dropout(h)\n\n        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}\n\n\nclass RelGraphConvLayerHeteroAPI(nn.Module):\n    r\"\"\"Relational graph convolution layer.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size.\n    out_feat : int\n        Output feature size.\n    rel_names : list[str]\n        Relation names.\n    num_bases : int, optional\n        Number of bases. If is none, use number of relations. Default: None.\n    weight : bool, optional\n        True if a linear layer is applied after message passing. Default: True\n    bias : bool, optional\n        True if bias is added. Default: True\n    activation : callable, optional\n        Activation function. Default: None\n    self_loop : bool, optional\n        True to include self loop message. 
Default: False\n    dropout : float, optional\n        Dropout rate. Default: 0.0\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feat,\n        out_feat,\n        rel_names,\n        num_bases,\n        *,\n        weight=True,\n        bias=True,\n        activation=None,\n        self_loop=False,\n        dropout=0.0\n    ):\n        super(RelGraphConvLayerHeteroAPI, self).__init__()\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.rel_names = rel_names\n        self.num_bases = num_bases\n        self.bias = bias\n        self.activation = activation\n        self.self_loop = self_loop\n\n        self.use_weight = weight\n        self.use_basis = num_bases < len(self.rel_names) and weight\n        if self.use_weight:\n            if self.use_basis:\n                self.basis = dglnn.WeightBasis(\n                    (in_feat, out_feat), num_bases, len(self.rel_names)\n                )\n            else:\n                self.weight = nn.Parameter(\n                    th.Tensor(len(self.rel_names), in_feat, out_feat)\n                )\n                nn.init.xavier_uniform_(\n                    self.weight, gain=nn.init.calculate_gain(\"relu\")\n                )\n\n        # bias\n        if bias:\n            self.h_bias = nn.Parameter(th.Tensor(out_feat))\n            nn.init.zeros_(self.h_bias)\n\n        # weight for self loop\n        if self.self_loop:\n            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))\n            nn.init.xavier_uniform_(\n                self.loop_weight, gain=nn.init.calculate_gain(\"relu\")\n            )\n\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, g, inputs):\n        \"\"\"Forward computation\n\n        Parameters\n        ----------\n        g : DGLGraph\n            Input graph.\n        inputs : dict[str, torch.Tensor]\n            Node feature for each node type.\n\n        Returns\n        -------\n        dict[str, 
torch.Tensor]\n            New node features for each node type.\n        \"\"\"\n        g = g.local_var()\n        if self.use_weight:\n            weight = self.basis() if self.use_basis else self.weight\n            wdict = {\n                self.rel_names[i]: {\"weight\": w.squeeze(0)}\n                for i, w in enumerate(th.split(weight, 1, dim=0))\n            }\n        else:\n            wdict = {}\n\n        inputs_src = inputs_dst = inputs\n\n        for srctype, _, _ in g.canonical_etypes:\n            g.nodes[srctype].data[\"h\"] = inputs[srctype]\n\n        if self.use_weight:\n            g.apply_edges(fn.copy_u(\"h\", \"m\"))\n            m = g.edata[\"m\"]\n            for rel in g.canonical_etypes:\n                _, etype, _ = rel\n                g.edges[rel].data[\"h*w_r\"] = th.matmul(\n                    m[rel], wdict[etype][\"weight\"]\n                )\n        else:\n            g.apply_edges(fn.copy_u(\"h\", \"h*w_r\"))\n\n        g.update_all(fn.copy_e(\"h*w_r\", \"m\"), fn.sum(\"m\", \"h\"))\n\n        def _apply(ntype):\n            h = g.nodes[ntype].data[\"h\"]\n            if self.self_loop:\n                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)\n            if self.bias:\n                h = h + self.h_bias\n            if self.activation:\n                h = self.activation(h)\n            return self.dropout(h)\n\n        return {ntype: _apply(ntype) for ntype in g.dsttypes}\n\n\nclass RelGraphEmbed(nn.Module):\n    r\"\"\"Embedding layer for featureless heterograph.\"\"\"\n\n    def __init__(\n        self, g, embed_size, embed_name=\"embed\", activation=None, dropout=0.0\n    ):\n        super(RelGraphEmbed, self).__init__()\n        self.g = g\n        self.embed_size = embed_size\n        self.embed_name = embed_name\n        self.activation = activation\n        self.dropout = nn.Dropout(dropout)\n\n        # create weight embeddings for each node for each relation\n        self.embeds = 
nn.ParameterDict()\n        for ntype in g.ntypes:\n            embed = nn.Parameter(th.Tensor(g.num_nodes(ntype), self.embed_size))\n            nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain(\"relu\"))\n            self.embeds[ntype] = embed\n\n    def forward(self, block=None):\n        \"\"\"Forward computation\n\n        Parameters\n        ----------\n        block : DGLGraph, optional\n            If not specified, directly return the full graph with embeddings stored in\n            :attr:`embed_name`. Otherwise, extract and store the embeddings to the block\n            graph and return.\n\n        Returns\n        -------\n        DGLGraph\n            The block graph fed with embeddings.\n        \"\"\"\n        return self.embeds\n\n\nclass EntityClassify(nn.Module):\n    def __init__(\n        self,\n        g,\n        h_dim,\n        out_dim,\n        num_bases,\n        num_hidden_layers=1,\n        dropout=0,\n        use_self_loop=False,\n    ):\n        super(EntityClassify, self).__init__()\n        self.g = g\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.rel_names = list(set(g.etypes))\n        self.rel_names.sort()\n        if num_bases < 0 or num_bases > len(self.rel_names):\n            self.num_bases = len(self.rel_names)\n        else:\n            self.num_bases = num_bases\n        self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.use_self_loop = use_self_loop\n\n        self.embed_layer = RelGraphEmbed(g, self.h_dim)\n        self.layers = nn.ModuleList()\n        # i2h\n        self.layers.append(\n            RelGraphConvLayer(\n                self.h_dim,\n                self.h_dim,\n                self.rel_names,\n                self.num_bases,\n                activation=F.relu,\n                self_loop=self.use_self_loop,\n                dropout=self.dropout,\n                weight=False,\n            )\n        )\n        # h2h\n        for i in 
range(self.num_hidden_layers):\n            self.layers.append(\n                RelGraphConvLayer(\n                    self.h_dim,\n                    self.h_dim,\n                    self.rel_names,\n                    self.num_bases,\n                    activation=F.relu,\n                    self_loop=self.use_self_loop,\n                    dropout=self.dropout,\n                )\n            )\n        # h2o\n        self.layers.append(\n            RelGraphConvLayer(\n                self.h_dim,\n                self.out_dim,\n                self.rel_names,\n                self.num_bases,\n                activation=None,\n                self_loop=self.use_self_loop,\n            )\n        )\n\n    def forward(self, h=None, blocks=None):\n        if h is None:\n            # full graph training\n            h = self.embed_layer()\n        if blocks is None:\n            # full graph training\n            for layer in self.layers:\n                h = layer(self.g, h)\n        else:\n            # minibatch training\n            for layer, block in zip(self.layers, blocks):\n                h = layer(block, h)\n        return h\n\n    def inference(self, g, batch_size, device, num_workers, x=None):\n        \"\"\"Minibatch inference of final representation over all node types.\n\n        ***NOTE***\n        For node classification, the model is trained to predict on only one node type's\n        label.  
Therefore, only that type's final representation is meaningful.\n        \"\"\"\n\n        if x is None:\n            x = self.embed_layer()\n\n        for l, layer in enumerate(self.layers):\n            y = {\n                k: th.zeros(\n                    g.num_nodes(k),\n                    self.h_dim if l != len(self.layers) - 1 else self.out_dim,\n                )\n                for k in g.ntypes\n            }\n\n            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n            dataloader = dgl.dataloading.DataLoader(\n                g,\n                {k: th.arange(g.num_nodes(k)) for k in g.ntypes},\n                sampler,\n                batch_size=batch_size,\n                shuffle=True,\n                drop_last=False,\n                num_workers=num_workers,\n            )\n\n            with dataloader.enable_cpu_affinity():\n                for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                    block = blocks[0].to(device)\n\n                    h = {\n                        k: x[k][input_nodes[k]].to(device)\n                        for k in input_nodes.keys()\n                    }\n                    h = layer(block, h)\n\n                    for k in output_nodes.keys():\n                        y[k][output_nodes[k]] = h[k].cpu()\n\n            x = y\n        return y\n\n\nclass EntityClassify_HeteroAPI(nn.Module):\n    def __init__(\n        self,\n        g,\n        h_dim,\n        out_dim,\n        num_bases,\n        num_hidden_layers=1,\n        dropout=0,\n        use_self_loop=False,\n    ):\n        super(EntityClassify_HeteroAPI, self).__init__()\n        self.g = g\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.rel_names = list(set(g.etypes))\n        self.rel_names.sort()\n        if num_bases < 0 or num_bases > len(self.rel_names):\n            self.num_bases = len(self.rel_names)\n        else:\n            self.num_bases = num_bases\n        
self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.use_self_loop = use_self_loop\n\n        self.embed_layer = RelGraphEmbed(g, self.h_dim)\n        self.layers = nn.ModuleList()\n        # i2h\n        self.layers.append(\n            RelGraphConvLayerHeteroAPI(\n                self.h_dim,\n                self.h_dim,\n                self.rel_names,\n                self.num_bases,\n                activation=F.relu,\n                self_loop=self.use_self_loop,\n                dropout=self.dropout,\n                weight=False,\n            )\n        )\n        # h2h\n        for i in range(self.num_hidden_layers):\n            self.layers.append(\n                RelGraphConvLayerHeteroAPI(\n                    self.h_dim,\n                    self.h_dim,\n                    self.rel_names,\n                    self.num_bases,\n                    activation=F.relu,\n                    self_loop=self.use_self_loop,\n                    dropout=self.dropout,\n                )\n            )\n        # h2o\n        self.layers.append(\n            RelGraphConvLayerHeteroAPI(\n                self.h_dim,\n                self.out_dim,\n                self.rel_names,\n                self.num_bases,\n                activation=None,\n                self_loop=self.use_self_loop,\n            )\n        )\n\n    def forward(self, h=None, blocks=None):\n        if h is None:\n            # full graph training\n            h = self.embed_layer()\n        if blocks is None:\n            # full graph training\n            for layer in self.layers:\n                h = layer(self.g, h)\n        else:\n            # minibatch training\n            for layer, block in zip(self.layers, blocks):\n                h = layer(block, h)\n        return h\n\n    def inference(self, g, batch_size, device, num_workers, x=None):\n        \"\"\"Minibatch inference of final representation over all node types.\n\n        ***NOTE***\n        
For node classification, the model is trained to predict on only one node type's\n        label.  Therefore, only that type's final representation is meaningful.\n        \"\"\"\n\n        if x is None:\n            x = self.embed_layer()\n\n        for l, layer in enumerate(self.layers):\n            y = {\n                k: th.zeros(\n                    g.num_nodes(k),\n                    self.h_dim if l != len(self.layers) - 1 else self.out_dim,\n                )\n                for k in g.ntypes\n            }\n\n            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n            dataloader = dgl.dataloading.DataLoader(\n                g,\n                {k: th.arange(g.num_nodes(k)) for k in g.ntypes},\n                sampler,\n                batch_size=batch_size,\n                shuffle=True,\n                drop_last=False,\n                num_workers=num_workers,\n            )\n\n            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):\n                block = blocks[0].to(device)\n\n                h = {\n                    k: x[k][input_nodes[k]].to(device)\n                    for k in input_nodes.keys()\n                }\n                h = layer(block, h)\n\n                for k in h.keys():\n                    y[k][output_nodes[k]] = h[k].cpu()\n\n            x = y\n        return y\n"
  },
  {
    "path": "examples/pytorch/rgcn-hetero/test_classify.py",
    "content": "\"\"\"Infering Relational Data with Graph Convolutional Networks\n\"\"\"\nimport argparse\nfrom functools import partial\n\nimport torch as th\nimport torch.nn.functional as F\n\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\nfrom entity_classify import EntityClassify\n\n\ndef main(args):\n    # load graph data\n    if args.dataset == \"aifb\":\n        dataset = AIFBDataset()\n    elif args.dataset == \"mutag\":\n        dataset = MUTAGDataset()\n    elif args.dataset == \"bgs\":\n        dataset = BGSDataset()\n    elif args.dataset == \"am\":\n        dataset = AMDataset()\n    else:\n        raise ValueError()\n\n    g = dataset[0]\n    category = dataset.predict_category\n    num_classes = dataset.num_classes\n    test_mask = g.nodes[category].data.pop(\"test_mask\")\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()\n    labels = g.nodes[category].data.pop(\"labels\")\n\n    # check cuda\n    use_cuda = args.gpu >= 0 and th.cuda.is_available()\n    if use_cuda:\n        th.cuda.set_device(args.gpu)\n        labels = labels.cuda()\n        test_idx = test_idx.cuda()\n        g = g.to(\"cuda:%d\" % args.gpu)\n\n    # create model\n    model = EntityClassify(\n        g,\n        args.n_hidden,\n        num_classes,\n        num_bases=args.n_bases,\n        num_hidden_layers=args.n_layers - 2,\n        use_self_loop=args.use_self_loop,\n    )\n    model.load_state_dict(th.load(args.model_path))\n    if use_cuda:\n        model.cuda()\n\n    print(\"start testing...\")\n    model.eval()\n    logits = model.forward()[category]\n    test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])\n    test_acc = th.sum(\n        logits[test_idx].argmax(dim=1) == labels[test_idx]\n    ).item() / len(test_idx)\n    print(\n        \"Test Acc: {:.4f} | Test loss: {:.4f}\".format(\n            test_acc, test_loss.item()\n        )\n    )\n    print()\n\n\nif __name__ == \"__main__\":\n    parser = 
argparse.ArgumentParser(description=\"RGCN\")\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden units\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-bases\",\n        type=int,\n        default=-1,\n        help=\"number of filter weight matrices, default: -1 [use all]\",\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=2, help=\"number of propagation rounds\"\n    )\n    parser.add_argument(\n        \"-d\", \"--dataset\", type=str, required=True, help=\"dataset to use\"\n    )\n    parser.add_argument(\n        \"--model_path\", type=str, help=\"path of the model to load from\"\n    )\n    parser.add_argument(\n        \"--use-self-loop\",\n        default=False,\n        action=\"store_true\",\n        help=\"include self feature as a special relation\",\n    )\n\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/rrn/README.md",
    "content": "# Recurrent Relational Network (RRN)\n\n* Paper link: https://arxiv.org/abs/1711.08028\n* Author's code repo: https://github.com/rasmusbergpalm/recurrent-relational-networks\n\n## Dependencies\n\n* PyTorch 1.0+\n* DGL 0.5+\n\n## Codes\n\nThe folder contains a DGL implementation of Recurrent Relational Network, and its\napplication on sudoku solving.\n\n## Usage\n\n- To train the RRN for sudoku, run the following\n```\npython3 train_sudoku.py --output_dir out/ --do_train\n```\n\n- Test with specified aggregation steps:\n```\npython3 train_sudoku.py --output_dir out/ --do_eval --steps 64\n```\n\n  Test accuracy (puzzle-level): \n\n|       | 32 steps | 64 steps |\n| ----- | :------: | :------: |\n| Paper | 94.1     | 96.6     |\n| DGL   | 95.3     | 98.9     |\n\n\n- To use the trained model for solving sudoku, follow the example below:\n\n```python\nfrom sudoku_solver import solve_sudoku\n\nq = [[9, 7, 0, 4, 0, 2, 0, 5, 3],\n     [0, 4, 6, 0, 9, 0, 0, 0, 0],\n     [0, 0, 8, 6, 0, 1, 4, 0, 7],\n     [0, 0, 0, 0, 0, 3, 5, 0, 0],\n     [7, 6, 0, 0, 0, 0, 0, 8, 2],\n     [0, 0, 2, 8, 0, 0, 0, 0, 0],\n     [6, 0, 5, 1, 0, 7, 2, 0, 0],\n     [0, 0, 0, 0, 6, 0, 7, 4, 0],\n     [4, 3, 0, 2, 0, 9, 0, 6, 1]\n    ]\n\nanswer = solve_sudoku(q)\nprint(answer)\n'''\n[[9 7 1 4 8 2 6 5 3]\n [3 4 6 7 9 5 1 2 8]\n [2 5 8 6 3 1 4 9 7]\n [8 1 4 9 2 3 5 7 6]\n [7 6 3 5 1 4 9 8 2]\n [5 9 2 8 7 6 3 1 4]\n [6 8 5 1 4 7 2 3 9]\n [1 2 9 3 6 8 7 4 5]\n [4 3 7 2 5 9 8 6 1]]\n'''\n```\n"
  },
  {
    "path": "examples/pytorch/rrn/rrn.py",
    "content": "\"\"\"\nRecurrent Relational Network(RRN) module\n\nReferences:\n- Recurrent Relational Networks\n- Paper: https://arxiv.org/abs/1711.08028\n- Original Code: https://github.com/rasmusbergpalm/recurrent-relational-networks\n\"\"\"\n\nimport dgl.function as fn\nimport torch\nfrom torch import nn\n\n\nclass RRNLayer(nn.Module):\n    def __init__(self, msg_layer, node_update_func, edge_drop):\n        super(RRNLayer, self).__init__()\n        self.msg_layer = msg_layer\n        self.node_update_func = node_update_func\n        self.edge_dropout = nn.Dropout(edge_drop)\n\n    def forward(self, g):\n        g.apply_edges(self.get_msg)\n        g.edata[\"e\"] = self.edge_dropout(g.edata[\"e\"])\n        g.update_all(\n            message_func=fn.copy_e(\"e\", \"msg\"), reduce_func=fn.sum(\"msg\", \"m\")\n        )\n        g.apply_nodes(self.node_update)\n\n    def get_msg(self, edges):\n        e = torch.cat([edges.src[\"h\"], edges.dst[\"h\"]], -1)\n        e = self.msg_layer(e)\n        return {\"e\": e}\n\n    def node_update(self, nodes):\n        return self.node_update_func(nodes)\n\n\nclass RRN(nn.Module):\n    def __init__(self, msg_layer, node_update_func, num_steps, edge_drop):\n        super(RRN, self).__init__()\n        self.num_steps = num_steps\n        self.rrn_layer = RRNLayer(msg_layer, node_update_func, edge_drop)\n\n    def forward(self, g, get_all_outputs=True):\n        outputs = []\n        for _ in range(self.num_steps):\n            self.rrn_layer(g)\n            if get_all_outputs:\n                outputs.append(g.ndata[\"h\"])\n        if get_all_outputs:\n            outputs = torch.stack(outputs, 0)  # num_steps x n_nodes x h_dim\n        else:\n            outputs = g.ndata[\"h\"]  # n_nodes x h_dim\n        return outputs\n"
  },
  {
    "path": "examples/pytorch/rrn/sudoku.py",
    "content": "\"\"\"\nSudokuNN module based on RRN for solving sudoku puzzles\n\"\"\"\n\nimport torch\nfrom rrn import RRN\nfrom torch import nn\n\n\nclass SudokuNN(nn.Module):\n    def __init__(self, num_steps, embed_size=16, hidden_dim=96, edge_drop=0.1):\n        super(SudokuNN, self).__init__()\n        self.num_steps = num_steps\n\n        self.digit_embed = nn.Embedding(10, embed_size)\n        self.row_embed = nn.Embedding(9, embed_size)\n        self.col_embed = nn.Embedding(9, embed_size)\n\n        self.input_layer = nn.Sequential(\n            nn.Linear(3 * embed_size, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n        )\n\n        self.lstm = nn.LSTMCell(hidden_dim * 2, hidden_dim, bias=False)\n\n        msg_layer = nn.Sequential(\n            nn.Linear(2 * hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n        )\n\n        self.rrn = RRN(msg_layer, self.node_update_func, num_steps, edge_drop)\n\n        self.output_layer = nn.Linear(hidden_dim, 10)\n\n        self.loss_func = nn.CrossEntropyLoss()\n\n    def forward(self, g, is_training=True):\n        labels = g.ndata.pop(\"a\")\n\n        input_digits = self.digit_embed(g.ndata.pop(\"q\"))\n        rows = self.row_embed(g.ndata.pop(\"row\"))\n        cols = self.col_embed(g.ndata.pop(\"col\"))\n\n        x = self.input_layer(torch.cat([input_digits, rows, cols], -1))\n        g.ndata[\"x\"] = x\n        g.ndata[\"h\"] = x\n        g.ndata[\"rnn_h\"] = torch.zeros_like(x, dtype=torch.float)\n        g.ndata[\"rnn_c\"] = torch.zeros_like(x, dtype=torch.float)\n\n        outputs = self.rrn(g, is_training)\n        logits = 
self.output_layer(outputs)\n\n        preds = torch.argmax(logits, -1)\n\n        if is_training:\n            labels = torch.stack([labels] * self.num_steps, 0)\n        logits = logits.view([-1, 10])\n        labels = labels.view([-1])\n        loss = self.loss_func(logits, labels)\n        return preds, loss\n\n    def node_update_func(self, nodes):\n        x, h, m, c = (\n            nodes.data[\"x\"],\n            nodes.data[\"rnn_h\"],\n            nodes.data[\"m\"],\n            nodes.data[\"rnn_c\"],\n        )\n        new_h, new_c = self.lstm(torch.cat([x, m], -1), (h, c))\n        return {\"h\": new_h, \"rnn_c\": new_c, \"rnn_h\": new_h}\n"
  },
  {
    "path": "examples/pytorch/rrn/sudoku_data.py",
    "content": "import csv\nimport os\nimport urllib.request\nimport zipfile\nfrom copy import copy\n\nimport dgl\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader, RandomSampler, SequentialSampler\nfrom torch.utils.data.dataset import Dataset\n\n\ndef _basic_sudoku_graph():\n    grids = [\n        [0, 1, 2, 9, 10, 11, 18, 19, 20],\n        [3, 4, 5, 12, 13, 14, 21, 22, 23],\n        [6, 7, 8, 15, 16, 17, 24, 25, 26],\n        [27, 28, 29, 36, 37, 38, 45, 46, 47],\n        [30, 31, 32, 39, 40, 41, 48, 49, 50],\n        [33, 34, 35, 42, 43, 44, 51, 52, 53],\n        [54, 55, 56, 63, 64, 65, 72, 73, 74],\n        [57, 58, 59, 66, 67, 68, 75, 76, 77],\n        [60, 61, 62, 69, 70, 71, 78, 79, 80],\n    ]\n    edges = set()\n    for i in range(81):\n        row, col = i // 9, i % 9\n        # same row and col\n        row_src = row * 9\n        col_src = col\n        for _ in range(9):\n            edges.add((row_src, i))\n            edges.add((col_src, i))\n            row_src += 1\n            col_src += 9\n        # same grid\n        grid_row, grid_col = row // 3, col // 3\n        for n in grids[grid_row * 3 + grid_col]:\n            if n != i:\n                edges.add((n, i))\n    edges = list(edges)\n    g = dgl.graph(edges)\n    return g\n\n\nclass ListDataset(Dataset):\n    def __init__(self, *lists_of_data):\n        assert all(len(lists_of_data[0]) == len(d) for d in lists_of_data)\n        self.lists_of_data = lists_of_data\n\n    def __getitem__(self, index):\n        return tuple(d[index] for d in self.lists_of_data)\n\n    def __len__(self):\n        return len(self.lists_of_data[0])\n\n\ndef _get_sudoku_dataset(segment=\"train\"):\n    assert segment in [\"train\", \"valid\", \"test\"]\n    url = \"https://data.dgl.ai/dataset/sudoku-hard.zip\"\n    zip_fname = \"/tmp/sudoku-hard.zip\"\n    dest_dir = \"/tmp/sudoku-hard/\"\n\n    if not os.path.exists(dest_dir):\n        print(\"Downloading data...\")\n\n        
urllib.request.urlretrieve(url, zip_fname)\n        with zipfile.ZipFile(zip_fname) as f:\n            f.extractall(\"/tmp/\")\n\n    def read_csv(fname):\n        print(\"Reading %s...\" % fname)\n        with open(dest_dir + fname) as f:\n            reader = csv.reader(f, delimiter=\",\")\n            return [(q, a) for q, a in reader]\n\n    data = read_csv(segment + \".csv\")\n\n    def encode(samples):\n        def parse(x):\n            return list(map(int, list(x)))\n\n        encoded = [(parse(q), parse(a)) for q, a in samples]\n        return encoded\n\n    data = encode(data)\n    print(f\"Number of puzzles in {segment} set : {len(data)}\")\n\n    return data\n\n\ndef sudoku_dataloader(batch_size, segment=\"train\"):\n    \"\"\"\n    Get a DataLoader instance for dataset of sudoku. Every iteration of the dataloader returns\n    a DGLGraph instance, the ndata of the graph contains:\n    'q': question, e.g. the sudoku puzzle to be solved, the position is to be filled with number from 1-9\n         if the value in the position is 0\n    'a': answer, the ground truth of the sudoku puzzle\n    'row': row index for each position in the grid\n    'col': column index for each position in the grid\n    :param batch_size: Batch size for the dataloader\n    :param segment: The segment of the datasets, must in ['train', 'valid', 'test']\n    :return: A pytorch DataLoader instance\n    \"\"\"\n    data = _get_sudoku_dataset(segment)\n    q, a = zip(*data)\n\n    dataset = ListDataset(q, a)\n    if segment == \"train\":\n        data_sampler = RandomSampler(dataset)\n    else:\n        data_sampler = SequentialSampler(dataset)\n\n    basic_graph = _basic_sudoku_graph()\n    sudoku_indices = np.arange(0, 81)\n    rows = sudoku_indices // 9\n    cols = sudoku_indices % 9\n\n    def collate_fn(batch):\n        graph_list = []\n        for q, a in batch:\n            q = torch.tensor(q, dtype=torch.long)\n            a = torch.tensor(a, dtype=torch.long)\n            
graph = copy(basic_graph)\n            graph.ndata[\"q\"] = q  # q means question\n            graph.ndata[\"a\"] = a  # a means answer\n            graph.ndata[\"row\"] = torch.tensor(rows, dtype=torch.long)\n            graph.ndata[\"col\"] = torch.tensor(cols, dtype=torch.long)\n            graph_list.append(graph)\n        batch_graph = dgl.batch(graph_list)\n        return batch_graph\n\n    dataloader = DataLoader(\n        dataset, batch_size, sampler=data_sampler, collate_fn=collate_fn\n    )\n    return dataloader\n"
  },
  {
    "path": "examples/pytorch/rrn/sudoku_solver.py",
    "content": "import os\nimport urllib.request\n\nimport numpy as np\nimport torch\nfrom sudoku import SudokuNN\nfrom sudoku_data import _basic_sudoku_graph\n\n\ndef solve_sudoku(puzzle):\n    \"\"\"\n    Solve sudoku puzzle using RRN.\n    :param puzzle: an array-like data with shape [9, 9], blank positions are filled with 0\n    :return: a [9, 9] shaped numpy array\n    \"\"\"\n    puzzle = np.array(puzzle, dtype=int).reshape([-1])\n    model_path = \"ckpt\"\n    if not os.path.exists(model_path):\n        os.mkdir(model_path)\n\n    model_filename = os.path.join(model_path, \"rrn-sudoku.pkl\")\n    if not os.path.exists(model_filename):\n        print(\"Downloading model...\")\n        url = \"https://data.dgl.ai/models/rrn-sudoku.pkl\"\n        urllib.request.urlretrieve(url, model_filename)\n\n    model = SudokuNN(num_steps=64, edge_drop=0.0)\n    model.load_state_dict(\n        torch.load(model_filename, weights_only=False, map_location=\"cpu\")\n    )\n    model.eval()\n\n    g = _basic_sudoku_graph()\n    sudoku_indices = np.arange(0, 81)\n    rows = sudoku_indices // 9\n    cols = sudoku_indices % 9\n\n    g.ndata[\"row\"] = torch.tensor(rows, dtype=torch.long)\n    g.ndata[\"col\"] = torch.tensor(cols, dtype=torch.long)\n    g.ndata[\"q\"] = torch.tensor(puzzle, dtype=torch.long)\n    g.ndata[\"a\"] = torch.tensor(puzzle, dtype=torch.long)\n\n    pred, _ = model(g, False)\n    pred = pred.cpu().data.numpy().reshape([9, 9])\n    return pred\n\n\nif __name__ == \"__main__\":\n    q = [\n        [9, 7, 0, 4, 0, 2, 0, 5, 3],\n        [0, 4, 6, 0, 9, 0, 0, 0, 0],\n        [0, 0, 8, 6, 0, 1, 4, 0, 7],\n        [0, 0, 0, 0, 0, 3, 5, 0, 0],\n        [7, 6, 0, 0, 0, 0, 0, 8, 2],\n        [0, 0, 2, 8, 0, 0, 0, 0, 0],\n        [6, 0, 5, 1, 0, 7, 2, 0, 0],\n        [0, 0, 0, 0, 6, 0, 7, 4, 0],\n        [4, 3, 0, 2, 0, 9, 0, 6, 1],\n    ]\n\n    answer = solve_sudoku(q)\n    print(answer)\n"
  },
  {
    "path": "examples/pytorch/rrn/train_sudoku.py",
    "content": "import argparse\nimport os\n\nimport numpy as np\nimport torch\nfrom sudoku import SudokuNN\nfrom sudoku_data import sudoku_dataloader\nfrom torch.optim import Adam\n\n\ndef main(args):\n    if args.gpu < 0 or not torch.cuda.is_available():\n        device = torch.device(\"cpu\")\n    else:\n        device = torch.device(\"cuda\", args.gpu)\n\n    model = SudokuNN(num_steps=args.steps, edge_drop=args.edge_drop)\n\n    if args.do_train:\n        if not os.path.exists(args.output_dir):\n            os.mkdir(args.output_dir)\n        model.to(device)\n        train_dataloader = sudoku_dataloader(args.batch_size, segment=\"train\")\n        dev_dataloader = sudoku_dataloader(args.batch_size, segment=\"valid\")\n\n        opt = Adam(\n            model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n        )\n\n        best_dev_acc = 0.0\n        for epoch in range(args.epochs):\n            model.train()\n            for i, g in enumerate(train_dataloader):\n                g = g.to(device)\n                _, loss = model(g)\n                opt.zero_grad()\n                loss.backward()\n                opt.step()\n                if i % 100 == 0:\n                    print(f\"Epoch {epoch}, batch {i}, loss {loss.cpu().data}\")\n\n            # dev\n            print(\"\\n=========Dev step========\")\n            model.eval()\n            dev_loss = []\n            dev_res = []\n            for g in dev_dataloader:\n                g = g.to(device)\n                target = g.ndata[\"a\"]\n                target = target.view([-1, 81])\n\n                with torch.no_grad():\n                    preds, loss = model(g, is_training=False)\n                    preds = preds.view([-1, 81])\n\n                    for i in range(preds.size(0)):\n                        dev_res.append(\n                            int(torch.equal(preds[i, :], target[i, :]))\n                        )\n\n                    
dev_loss.append(loss.cpu().detach().data)\n\n            dev_acc = sum(dev_res) / len(dev_res)\n            print(f\"Dev loss {np.mean(dev_loss)}, accuracy {dev_acc}\")\n            if dev_acc >= best_dev_acc:\n                torch.save(\n                    model.state_dict(),\n                    os.path.join(args.output_dir, \"model_best.bin\"),\n                )\n                best_dev_acc = dev_acc\n            print(f\"Best dev accuracy {best_dev_acc}\\n\")\n\n        torch.save(\n            model.state_dict(), os.path.join(args.output_dir, \"model_final.bin\")\n        )\n\n    if args.do_eval:\n        model_path = os.path.join(args.output_dir, \"model_best.bin\")\n        if not os.path.exists(model_path):\n            raise FileNotFoundError(\"Saved model not Found!\")\n\n        model.load_state_dict(torch.load(model_path, weights_only=False))\n        model.to(device)\n\n        test_dataloader = sudoku_dataloader(args.batch_size, segment=\"test\")\n\n        print(\"\\n=========Test step========\")\n        model.eval()\n        test_loss = []\n        test_res = []\n        for g in test_dataloader:\n            g = g.to(device)\n            target = g.ndata[\"a\"]\n            target = target.view([-1, 81])\n\n            with torch.no_grad():\n                preds, loss = model(g, is_training=False)\n                preds = preds\n                preds = preds.view([-1, 81])\n\n                for i in range(preds.size(0)):\n                    test_res.append(int(torch.equal(preds[i, :], target[i, :])))\n\n                test_loss.append(loss.cpu().detach().data)\n\n        test_acc = sum(test_res) / len(test_res)\n        print(f\"Test loss {np.mean(test_loss)}, accuracy {test_acc}\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Recurrent Relational Network on sudoku task.\"\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=None,\n        
required=True,\n        help=\"The directory to save model\",\n    )\n    parser.add_argument(\n        \"--do_train\", default=False, action=\"store_true\", help=\"Train the model\"\n    )\n    parser.add_argument(\n        \"--do_eval\",\n        default=False,\n        action=\"store_true\",\n        help=\"Evaluate the model on test data\",\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=100, help=\"Number of training epochs\"\n    )\n    parser.add_argument(\"--batch_size\", type=int, default=64, help=\"Batch size\")\n    parser.add_argument(\n        \"--edge_drop\", type=float, default=0.4, help=\"Dropout rate at edges.\"\n    )\n    parser.add_argument(\n        \"--steps\", type=int, default=32, help=\"Number of message passing steps.\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=2e-4, help=\"Learning rate\")\n    parser.add_argument(\n        \"--weight_decay\",\n        type=float,\n        default=1e-4,\n        help=\"weight decay (L2 penalty)\",\n    )\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/sagpool/README.md",
    "content": "# DGL Implementation of the SAGPool Paper\n\nThis DGL example implements the GNN model proposed in the paper [Self Attention Graph Pooling](https://arxiv.org/pdf/1904.08082.pdf). \nThe authors' implementation is available [here](https://github.com/inyeoplee77/SAGPool).\n\n\nThe graph dataset used in this example \n---------------------------------------\nDGL's built-in LegacyTUDataset. This is a series of graph kernel datasets for graph classification. We use 'DD', 'PROTEINS', 'NCI1', 'NCI109' and 'Mutagenicity' in this SAGPool implementation. All these datasets are randomly split into train, validation and test sets with ratios 0.8, 0.1 and 0.1.\n\nNOTE: Since there are no data attributes in some of these datasets, we use node_id (as a one-hot vector whose length is the max number of nodes across all graphs) as the node feature. Also note that the node_id in some datasets is not unique (e.g. a graph may have two nodes with the same id).\n\n|                  | DD     | PROTEINS | NCI1  | NCI109 | Mutagenicity |\n| ---------------- | ------ | -------- | ----- | ------ | ------------ |\n| NumGraphs        | 1178   | 1113     | 4110  | 4127   | 4337         |\n| AvgNodesPerGraph | 284.32 | 39.06    | 29.87 | 29.68  | 30.32        |\n| AvgEdgesPerGraph | 715.66 | 72.82    | 32.30 | 32.13  | 30.77        |\n| NumFeats         | 89     | 1        | 37    | 38     | 14           |\n| NumClasses       | 2      | 2        | 2     | 2      | 2            |\n\n\nHow to run example files\n--------------------------------\nThe valid dataset names (you can find a full list [here](https://chrsmrrs.github.io/datasets/docs/datasets/)):\n- 'DD' for D&D\n- 'PROTEINS' for PROTEINS\n- 'NCI1' for NCI1\n- 'NCI109' for NCI109\n- 'Mutagenicity' for Mutagenicity\n\nIn the sagpool folder, run\n\n```bash\npython main.py --dataset ${your_dataset_name_here}\n```\n\nIf you want to use a GPU, run\n\n```bash\npython main.py --device ${your_device_id_here} --dataset 
${your_dataset_name_here}\n```\n\nIf you want to perform a grid search, modify the parameter settings in `grid_search_config.json` and run\n```bash\npython grid_search.py --device ${your_device_id_here} --num_trials ${num_of_trials_here}\n```\n\nPerformance\n-------------------------\n\nNOTE: We do not perform grid search or fine-tuning here, so there may be a gap between the results in the paper and our results. Also, we only perform 10 trials for each experiment, which is different from the 200 trials per experiment in the paper.\n\n**The global architecture result**\n| Dataset       | paper result (global)            | ours (global)               |\n| ------------- | -------------------------------- | --------------------------- |\n| D&D           | 76.19 (0.94)                     | 74.79 (2.69)                |\n| PROTEINS      | 70.04 (1.47)                     | 70.36 (5.90)                |\n| NCI1          | 74.18 (1.20)                     | 72.82 (2.36)                |\n| NCI109        | 74.06 (0.78)                     | 71.64 (2.65)                |\n| Mutagenicity  | N/A                              | 76.55 (2.89)                |\n\n**The hierarchical architecture result**\n| Dataset       | paper result (hierarchical)      | ours (hierarchical)         |\n| ------------- | -------------------------------- | --------------------------- |\n| D&D           | 76.45 (0.97)                     | 75.38 (4.17)                |\n| PROTEINS      | 71.86 (0.97)                     | 70.36 (5.68)                |\n| NCI1          | 67.45 (1.11)                     | 70.61 (2.25)                |\n| NCI109        | 67.86 (1.41)                     | 69.13 (3.85)                |\n| Mutagenicity  | N/A                              | 75.20 (1.95)                |\n"
  },
  {
    "path": "examples/pytorch/sagpool/grid_search.py",
    "content": "import json\nimport os\nfrom copy import deepcopy\n\nfrom main import main, parse_args\nfrom utils import get_stats\n\n\ndef load_config(path=\"./grid_search_config.json\"):\n    with open(path, \"r\") as f:\n        return json.load(f)\n\n\ndef run_experiments(args):\n    res = []\n    for i in range(args.num_trials):\n        print(\"Trial {}/{}\".format(i + 1, args.num_trials))\n        acc, _ = main(args)\n        res.append(acc)\n\n    mean, err_bd = get_stats(res, conf_interval=True)\n    return mean, err_bd\n\n\ndef grid_search(config: dict):\n    args = parse_args()\n    results = {}\n\n    for d in config[\"dataset\"]:\n        args.dataset = d\n        best_acc, err_bd = 0.0, 0.0\n        best_args = vars(args)\n        for arch in config[\"arch\"]:\n            args.architecture = arch\n            for hidden in config[\"hidden\"]:\n                args.hid_dim = hidden\n                for pool_ratio in config[\"pool_ratio\"]:\n                    args.pool_ratio = pool_ratio\n                    for lr in config[\"lr\"]:\n                        args.lr = lr\n                        for weight_decay in config[\"weight_decay\"]:\n                            args.weight_decay = weight_decay\n                            acc, bd = run_experiments(args)\n                            if acc > best_acc:\n                                best_acc = acc\n                                err_bd = bd\n                                best_args = deepcopy(vars(args))\n        args.output_path = \"./output\"\n        if not os.path.exists(args.output_path):\n            os.makedirs(args.output_path)\n        args.output_path = \"./output/{}.log\".format(d)\n        result = {\n            \"params\": best_args,\n            \"result\": \"{:.4f}({:.4f})\".format(best_acc, err_bd),\n        }\n        with open(args.output_path, \"w\") as f:\n            json.dump(result, f, sort_keys=True, indent=4)\n\n\ngrid_search(load_config())\n"
  },
  {
    "path": "examples/pytorch/sagpool/grid_search_config.json",
    "content": "{\n    \"arch\": [\"hierarchical\", \"global\"],\n    \"hidden\": [16, 32, 64, 128],\n    \"pool_ratio\": [0.25, 0.5],\n    \"lr\": [1e-2, 5e-2, 1e-3, 5e-3, 1e-4, 5e-4],\n    \"weight_decay\": [1e-2, 1e-3, 1e-4, 1e-5],\n    \"dataset\": [\"DD\", \"PROTEINS\", \"NCI1\", \"NCI109\", \"Mutagenicity\"]\n}\n"
  },
  {
    "path": "examples/pytorch/sagpool/layer.py",
    "content": "import dgl\nimport torch\nimport torch.nn.functional as F\nfrom dgl.nn import AvgPooling, GraphConv, MaxPooling\nfrom utils import get_batch_id, topk\n\n\nclass SAGPool(torch.nn.Module):\n    \"\"\"The Self-Attention Pooling layer in paper\n    `Self Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>`\n\n    Args:\n        in_dim (int): The dimension of node feature.\n        ratio (float, optional): The pool ratio which determines the amount of nodes\n            remain after pooling. (default: :obj:`0.5`)\n        conv_op (torch.nn.Module, optional): The graph convolution layer in dgl used to\n        compute scale for each node. (default: :obj:`dgl.nn.GraphConv`)\n        non_linearity (Callable, optional): The non-linearity function, a pytorch function.\n            (default: :obj:`torch.tanh`)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        ratio=0.5,\n        conv_op=GraphConv,\n        non_linearity=torch.tanh,\n    ):\n        super(SAGPool, self).__init__()\n        self.in_dim = in_dim\n        self.ratio = ratio\n        self.score_layer = conv_op(in_dim, 1)\n        self.non_linearity = non_linearity\n\n    def forward(self, graph: dgl.DGLGraph, feature: torch.Tensor):\n        score = self.score_layer(graph, feature).squeeze()\n        perm, next_batch_num_nodes = topk(\n            score,\n            self.ratio,\n            get_batch_id(graph.batch_num_nodes()),\n            graph.batch_num_nodes(),\n        )\n        feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)\n        graph = dgl.node_subgraph(graph, perm)\n\n        # node_subgraph currently does not support batch-graph,\n        # the 'batch_num_nodes' of the result subgraph is None.\n        # So we manually set the 'batch_num_nodes' here.\n        # Since global pooling has nothing to do with 'batch_num_edges',\n        # we can leave it to be None or unchanged.\n        
graph.set_batch_num_nodes(next_batch_num_nodes)\n\n        return graph, feature, perm\n\n\nclass ConvPoolBlock(torch.nn.Module):\n    \"\"\"A combination of GCN layer and SAGPool layer,\n    followed by a concatenated (mean||sum) readout operation.\n    \"\"\"\n\n    def __init__(self, in_dim: int, out_dim: int, pool_ratio=0.8):\n        super(ConvPoolBlock, self).__init__()\n        self.conv = GraphConv(in_dim, out_dim)\n        self.pool = SAGPool(out_dim, ratio=pool_ratio)\n        self.avgpool = AvgPooling()\n        self.maxpool = MaxPooling()\n\n    def forward(self, graph, feature):\n        out = F.relu(self.conv(graph, feature))\n        graph, out, _ = self.pool(graph, out)\n        g_out = torch.cat(\n            [self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1\n        )\n        return graph, out, g_out\n"
  },
  {
    "path": "examples/pytorch/sagpool/main.py",
    "content": "import argparse\nimport json\nimport logging\nimport os\nfrom time import time\n\nimport dgl\n\nimport torch\nimport torch.nn\nimport torch.nn.functional as F\nfrom dgl.data import LegacyTUDataset\nfrom dgl.dataloading import GraphDataLoader\nfrom network import get_sag_network\nfrom torch.utils.data import random_split\nfrom utils import get_stats\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Self-Attention Graph Pooling\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"DD\",\n        choices=[\"DD\", \"PROTEINS\", \"NCI1\", \"NCI109\", \"Mutagenicity\"],\n        help=\"DD/PROTEINS/NCI1/NCI109/Mutagenicity\",\n    )\n    parser.add_argument(\n        \"--batch_size\", type=int, default=128, help=\"batch size\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=5e-4, help=\"learning rate\")\n    parser.add_argument(\n        \"--weight_decay\", type=float, default=1e-4, help=\"weight decay\"\n    )\n    parser.add_argument(\n        \"--pool_ratio\", type=float, default=0.5, help=\"pooling ratio\"\n    )\n    parser.add_argument(\"--hid_dim\", type=int, default=128, help=\"hidden size\")\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout ratio\"\n    )\n    parser.add_argument(\n        \"--epochs\",\n        type=int,\n        default=100000,\n        help=\"max number of training epochs\",\n    )\n    parser.add_argument(\n        \"--patience\", type=int, default=50, help=\"patience for early stopping\"\n    )\n    parser.add_argument(\n        \"--device\", type=int, default=-1, help=\"device id, -1 for cpu\"\n    )\n    parser.add_argument(\n        \"--architecture\",\n        type=str,\n        default=\"hierarchical\",\n        choices=[\"hierarchical\", \"global\"],\n        help=\"model architecture\",\n    )\n    parser.add_argument(\n        \"--dataset_path\", type=str, default=\"./dataset\", help=\"path to 
dataset\"\n    )\n    parser.add_argument(\n        \"--conv_layers\", type=int, default=3, help=\"number of conv layers\"\n    )\n    parser.add_argument(\n        \"--print_every\",\n        type=int,\n        default=10,\n        help=\"print trainlog every k epochs, -1 for silent training\",\n    )\n    parser.add_argument(\n        \"--num_trials\", type=int, default=1, help=\"number of trials\"\n    )\n    parser.add_argument(\"--output_path\", type=str, default=\"./output\")\n\n    args = parser.parse_args()\n\n    # device\n    args.device = \"cpu\" if args.device == -1 else \"cuda:{}\".format(args.device)\n    if not torch.cuda.is_available():\n        logging.warning(\"CUDA is not available, use CPU for training.\")\n        args.device = \"cpu\"\n\n    # print every\n    if args.print_every == -1:\n        args.print_every = args.epochs + 1\n\n    # paths\n    if not os.path.exists(args.dataset_path):\n        os.makedirs(args.dataset_path)\n    if not os.path.exists(args.output_path):\n        os.makedirs(args.output_path)\n    name = \"Data={}_Hidden={}_Arch={}_Pool={}_WeightDecay={}_Lr={}.log\".format(\n        args.dataset,\n        args.hid_dim,\n        args.architecture,\n        args.pool_ratio,\n        args.weight_decay,\n        args.lr,\n    )\n    args.output_path = os.path.join(args.output_path, name)\n\n    return args\n\n\ndef train(model: torch.nn.Module, optimizer, trainloader, device):\n    model.train()\n    total_loss = 0.0\n    num_batches = len(trainloader)\n    for batch in trainloader:\n        optimizer.zero_grad()\n        batch_graphs, batch_labels = batch\n        batch_graphs = batch_graphs.to(device)\n        batch_labels = batch_labels.long().to(device)\n        out = model(batch_graphs)\n        loss = F.nll_loss(out, batch_labels)\n        loss.backward()\n        optimizer.step()\n\n        total_loss += loss.item()\n\n    return total_loss / num_batches\n\n\n@torch.no_grad()\ndef test(model: torch.nn.Module, loader, 
device):\n    model.eval()\n    correct = 0.0\n    loss = 0.0\n    num_graphs = 0\n    for batch in loader:\n        batch_graphs, batch_labels = batch\n        num_graphs += batch_labels.size(0)\n        batch_graphs = batch_graphs.to(device)\n        batch_labels = batch_labels.long().to(device)\n        out = model(batch_graphs)\n        pred = out.argmax(dim=1)\n        loss += F.nll_loss(out, batch_labels, reduction=\"sum\").item()\n        correct += pred.eq(batch_labels).sum().item()\n    return correct / num_graphs, loss / num_graphs\n\n\ndef main(args):\n    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #\n    dataset = LegacyTUDataset(args.dataset, raw_dir=args.dataset_path)\n\n    # add self loop. We add self loop for each graph here since the function \"add_self_loop\" does not\n    # support batch graph.\n    for i in range(len(dataset)):\n        dataset.graph_lists[i] = dgl.add_self_loop(dataset.graph_lists[i])\n\n    num_training = int(len(dataset) * 0.8)\n    num_val = int(len(dataset) * 0.1)\n    num_test = len(dataset) - num_val - num_training\n    train_set, val_set, test_set = random_split(\n        dataset, [num_training, num_val, num_test]\n    )\n\n    train_loader = GraphDataLoader(\n        train_set, batch_size=args.batch_size, shuffle=True, num_workers=6\n    )\n    val_loader = GraphDataLoader(\n        val_set, batch_size=args.batch_size, num_workers=2\n    )\n    test_loader = GraphDataLoader(\n        test_set, batch_size=args.batch_size, num_workers=2\n    )\n\n    device = torch.device(args.device)\n\n    # Step 2: Create model =================================================================== #\n    num_feature, num_classes, _ = dataset.statistics()\n    model_op = get_sag_network(args.architecture)\n    model = model_op(\n        in_dim=num_feature,\n        hid_dim=args.hid_dim,\n        out_dim=num_classes,\n        num_convs=args.conv_layers,\n        
pool_ratio=args.pool_ratio,\n        dropout=args.dropout,\n    ).to(device)\n    args.num_feature = int(num_feature)\n    args.num_classes = int(num_classes)\n\n    # Step 3: Create training components ===================================================== #\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # Step 4: training epoches =============================================================== #\n    bad_cound = 0\n    best_val_loss = float(\"inf\")\n    final_test_acc = 0.0\n    best_epoch = 0\n    train_times = []\n    for e in range(args.epochs):\n        s_time = time()\n        train_loss = train(model, optimizer, train_loader, device)\n        train_times.append(time() - s_time)\n        val_acc, val_loss = test(model, val_loader, device)\n        test_acc, _ = test(model, test_loader, device)\n        if best_val_loss > val_loss:\n            best_val_loss = val_loss\n            final_test_acc = test_acc\n            bad_cound = 0\n            best_epoch = e + 1\n        else:\n            bad_cound += 1\n        if bad_cound >= args.patience:\n            break\n\n        if (e + 1) % args.print_every == 0:\n            log_format = (\n                \"Epoch {}: loss={:.4f}, val_acc={:.4f}, final_test_acc={:.4f}\"\n            )\n            print(log_format.format(e + 1, train_loss, val_acc, final_test_acc))\n    print(\n        \"Best Epoch {}, final test acc {:.4f}\".format(\n            best_epoch, final_test_acc\n        )\n    )\n    return final_test_acc, sum(train_times) / len(train_times)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    res = []\n    train_times = []\n    for i in range(args.num_trials):\n        print(\"Trial {}/{}\".format(i + 1, args.num_trials))\n        acc, train_time = main(args)\n        res.append(acc)\n        train_times.append(train_time)\n\n    mean, err_bd = get_stats(res)\n    print(\"mean acc: {:.4f}, error bound: 
{:.4f}\".format(mean, err_bd))\n\n    out_dict = {\n        \"hyper-parameters\": vars(args),\n        \"result\": \"{:.4f}(+-{:.4f})\".format(mean, err_bd),\n        \"train_time\": \"{:.4f}\".format(sum(train_times) / len(train_times)),\n    }\n\n    with open(args.output_path, \"w\") as f:\n        json.dump(out_dict, f, sort_keys=True, indent=4)\n"
  },
  {
    "path": "examples/pytorch/sagpool/network.py",
    "content": "import dgl\nimport torch\nimport torch.nn\nimport torch.nn.functional as F\nfrom dgl.nn import AvgPooling, GraphConv, MaxPooling\nfrom layer import ConvPoolBlock, SAGPool\n\n\nclass SAGNetworkHierarchical(torch.nn.Module):\n    \"\"\"The Self-Attention Graph Pooling Network with hierarchical readout in paper\n    `Self Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>`\n\n    Args:\n        in_dim (int): The input node feature dimension.\n        hid_dim (int): The hidden dimension for node feature.\n        out_dim (int): The output dimension.\n        num_convs (int, optional): The number of graph convolution layers.\n            (default: 3)\n        pool_ratio (float, optional): The pool ratio which determines the amount of nodes\n            remain after pooling. (default: :obj:`0.5`)\n        dropout (float, optional): The dropout ratio for each layer. (default: 0)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        hid_dim: int,\n        out_dim: int,\n        num_convs=3,\n        pool_ratio: float = 0.5,\n        dropout: float = 0.0,\n    ):\n        super(SAGNetworkHierarchical, self).__init__()\n\n        self.dropout = dropout\n        self.num_convpools = num_convs\n\n        convpools = []\n        for i in range(num_convs):\n            _i_dim = in_dim if i == 0 else hid_dim\n            _o_dim = hid_dim\n            convpools.append(\n                ConvPoolBlock(_i_dim, _o_dim, pool_ratio=pool_ratio)\n            )\n        self.convpools = torch.nn.ModuleList(convpools)\n\n        self.lin1 = torch.nn.Linear(hid_dim * 2, hid_dim)\n        self.lin2 = torch.nn.Linear(hid_dim, hid_dim // 2)\n        self.lin3 = torch.nn.Linear(hid_dim // 2, out_dim)\n\n    def forward(self, graph: dgl.DGLGraph):\n        feat = graph.ndata[\"feat\"]\n        final_readout = None\n\n        for i in range(self.num_convpools):\n            graph, feat, readout = self.convpools[i](graph, feat)\n            if 
final_readout is None:\n                final_readout = readout\n            else:\n                final_readout = final_readout + readout\n\n        feat = F.relu(self.lin1(final_readout))\n        feat = F.dropout(feat, p=self.dropout, training=self.training)\n        feat = F.relu(self.lin2(feat))\n        feat = F.log_softmax(self.lin3(feat), dim=-1)\n\n        return feat\n\n\nclass SAGNetworkGlobal(torch.nn.Module):\n    \"\"\"The Self-Attention Graph Pooling Network with global readout in paper\n    `Self Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>`\n\n    Args:\n        in_dim (int): The input node feature dimension.\n        hid_dim (int): The hidden dimension for node feature.\n        out_dim (int): The output dimension.\n        num_convs (int, optional): The number of graph convolution layers.\n            (default: 3)\n        pool_ratio (float, optional): The pool ratio which determines the amount of nodes\n            remain after pooling. (default: :obj:`0.5`)\n        dropout (float, optional): The dropout ratio for each layer. 
(default: 0)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        hid_dim: int,\n        out_dim: int,\n        num_convs=3,\n        pool_ratio: float = 0.5,\n        dropout: float = 0.0,\n    ):\n        super(SAGNetworkGlobal, self).__init__()\n        self.dropout = dropout\n        self.num_convs = num_convs\n\n        convs = []\n        for i in range(num_convs):\n            _i_dim = in_dim if i == 0 else hid_dim\n            _o_dim = hid_dim\n            convs.append(GraphConv(_i_dim, _o_dim))\n        self.convs = torch.nn.ModuleList(convs)\n\n        concat_dim = num_convs * hid_dim\n        self.pool = SAGPool(concat_dim, ratio=pool_ratio)\n        self.avg_readout = AvgPooling()\n        self.max_readout = MaxPooling()\n\n        self.lin1 = torch.nn.Linear(concat_dim * 2, hid_dim)\n        self.lin2 = torch.nn.Linear(hid_dim, hid_dim // 2)\n        self.lin3 = torch.nn.Linear(hid_dim // 2, out_dim)\n\n    def forward(self, graph: dgl.DGLGraph):\n        feat = graph.ndata[\"feat\"]\n        conv_res = []\n\n        for i in range(self.num_convs):\n            feat = self.convs[i](graph, feat)\n            conv_res.append(feat)\n\n        conv_res = torch.cat(conv_res, dim=-1)\n        graph, feat, _ = self.pool(graph, conv_res)\n        feat = torch.cat(\n            [self.avg_readout(graph, feat), self.max_readout(graph, feat)],\n            dim=-1,\n        )\n\n        feat = F.relu(self.lin1(feat))\n        feat = F.dropout(feat, p=self.dropout, training=self.training)\n        feat = F.relu(self.lin2(feat))\n        feat = F.log_softmax(self.lin3(feat), dim=-1)\n\n        return feat\n\n\ndef get_sag_network(net_type: str = \"hierarchical\"):\n    if net_type == \"hierarchical\":\n        return SAGNetworkHierarchical\n    elif net_type == \"global\":\n        return SAGNetworkGlobal\n    else:\n        raise ValueError(\n            \"SAGNetwork type {} is not supported.\".format(net_type)\n        )\n"
  },
  {
    "path": "examples/pytorch/sagpool/utils.py",
    "content": "import logging\nimport math\n\nimport torch\nfrom scipy.stats import t\n\n\ndef get_stats(\n    array, conf_interval=False, name=None, stdout=False, logout=False\n):\n    \"\"\"Compute mean and standard deviation from an numerical array\n\n    Args:\n        array (array like obj): The numerical array, this array can be\n            convert to :obj:`torch.Tensor`.\n        conf_interval (bool, optional): If True, compute the confidence interval bound (95%)\n            instead of the std value. (default: :obj:`False`)\n        name (str, optional): The name of this numerical array, for log usage.\n            (default: :obj:`None`)\n        stdout (bool, optional): Whether to output result to the terminal.\n            (default: :obj:`False`)\n        logout (bool, optional): Whether to output result via logging module.\n            (default: :obj:`False`)\n    \"\"\"\n    eps = 1e-9\n    array = torch.Tensor(array)\n    std, mean = torch.std_mean(array)\n    std = std.item()\n    mean = mean.item()\n    center = mean\n\n    if conf_interval:\n        n = array.size(0)\n        se = std / (math.sqrt(n) + eps)\n        t_value = t.ppf(0.975, df=n - 1)\n        err_bound = t_value * se\n    else:\n        err_bound = std\n\n    # log and print\n    if name is None:\n        name = \"array {}\".format(id(array))\n    log = \"{}: {:.4f}(+-{:.4f})\".format(name, center, err_bound)\n    if stdout:\n        print(log)\n    if logout:\n        logging.info(log)\n\n    return center, err_bound\n\n\ndef get_batch_id(num_nodes: torch.Tensor):\n    \"\"\"Convert the num_nodes array obtained from batch graph to batch_id array\n    for each node.\n\n    Args:\n        num_nodes (torch.Tensor): The tensor whose element is the number of nodes\n            in each graph in the batch graph.\n    \"\"\"\n    batch_size = num_nodes.size(0)\n    batch_ids = []\n    for i in range(batch_size):\n        item = torch.full(\n            (num_nodes[i],), i, dtype=torch.long, 
device=num_nodes.device\n        )\n        batch_ids.append(item)\n    return torch.cat(batch_ids)\n\n\ndef topk(\n    x: torch.Tensor,\n    ratio: float,\n    batch_id: torch.Tensor,\n    num_nodes: torch.Tensor,\n):\n    \"\"\"The top-k pooling method. Given a graph batch, this method will pool out some\n    nodes from input node feature tensor for each graph according to the given ratio.\n\n    Args:\n        x (torch.Tensor): The input node feature batch-tensor to be pooled.\n        ratio (float): the pool ratio. For example if :obj:`ratio=0.5` then half of the input\n            tensor will be pooled out.\n        batch_id (torch.Tensor): The batch_id of each element in the input tensor.\n        num_nodes (torch.Tensor): The number of nodes of each graph in batch.\n\n    Returns:\n        perm (torch.Tensor): The index in batch to be kept.\n        k (torch.Tensor): The remaining number of nodes for each graph.\n    \"\"\"\n    batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()\n\n    cum_num_nodes = torch.cat(\n        [num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0\n    )\n\n    index = torch.arange(batch_id.size(0), dtype=torch.long, device=x.device)\n    index = (index - cum_num_nodes[batch_id]) + (batch_id * max_num_nodes)\n\n    dense_x = x.new_full(\n        (batch_size * max_num_nodes,), torch.finfo(x.dtype).min\n    )\n    dense_x[index] = x\n    dense_x = dense_x.view(batch_size, max_num_nodes)\n\n    _, perm = dense_x.sort(dim=-1, descending=True)\n    perm = perm + cum_num_nodes.view(-1, 1)\n    perm = perm.view(-1)\n\n    k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)\n    mask = [\n        torch.arange(k[i], dtype=torch.long, device=x.device)\n        + i * max_num_nodes\n        for i in range(batch_size)\n    ]\n\n    mask = torch.cat(mask, dim=0)\n    perm = perm[mask]\n\n    return perm, k\n"
  },
  {
    "path": "examples/pytorch/seal/README.md",
    "content": "# DGL Implementation of the SEAL Paper\nThis DGL example implements the link prediction model proposed in the paper \n[Link Prediction Based on Graph Neural Networks](https://arxiv.org/pdf/1802.09691.pdf) \nand [REVISITING GRAPH NEURAL NETWORKS FOR LINK PREDICTION](https://arxiv.org/pdf/2010.16103.pdf)  \nThe author's codes of implementation is in [SEAL](https://github.com/muhanzhang/SEAL) (pytorch)\nand [SEAL_ogb](https://github.com/facebookresearch/SEAL_OGB) (torch_geometric)\n\nExample implementor\n----------------------\nThis example was implemented by [Smile](https://github.com/Smilexuhc) during his intern work at the AWS Shanghai AI Lab.\n\nThe graph dataset used in this example \n---------------------------------------\n\nogbl-collab\n - NumNodes: 235868\n - NumEdges: 2358104\n - NumNodeFeats: 128\n - NumEdgeWeights: 1\n - NumValidEdges: 160084\n - NumTestEdges: 146329\n\nDependencies\n--------------------------------\n\n- python 3.6+\n- Pytorch 1.5.0+\n- dgl 0.6.0 +\n- ogb  \n- pandas\n- tqdm\n- scipy\n\n\n How to run example files\n--------------------------------\nIn the seal_dgl folder    \nrun on cpu:  \n```shell script\npython main.py --gpu_id=-1 --subsample_ratio=0.1\n```\nrun on gpu:  \n```shell script\npython main.py --gpu_id=0  --subsample_ratio=0.1\n```\n\nPerformance\n-------------------------\nexperiment on `ogbl-collab`\n\n| method | valid-hits@50 | test-hits@50 |\n| ------ | ------------- | ------------ |\n| paper  | 63.89(0.49)         | 53.71(0.47)        |\n| ours     | 63.56(0.71)         | 53.61(0.78)        |\n\nNote: We only perform 5 trails in the experiment. "
  },
  {
    "path": "examples/pytorch/seal/logger.py",
    "content": "import logging\nimport os\nimport time\n\n\ndef _transform_log_level(str_level):\n    if str_level == \"info\":\n        return logging.INFO\n    elif str_level == \"warning\":\n        return logging.WARNING\n    elif str_level == \"critical\":\n        return logging.CRITICAL\n    elif str_level == \"debug\":\n        return logging.DEBUG\n    elif str_level == \"error\":\n        return logging.ERROR\n    else:\n        raise KeyError(\"Log level error\")\n\n\nclass LightLogging(object):\n    def __init__(self, log_path=None, log_name=\"lightlog\", log_level=\"debug\"):\n        log_level = _transform_log_level(log_level)\n\n        if log_path:\n            if not log_path.endswith(\"/\"):\n                log_path += \"/\"\n            if not os.path.exists(log_path):\n                os.mkdir(log_path)\n\n            if log_name.endswith(\"-\") or log_name.endswith(\"_\"):\n                log_name = (\n                    log_path\n                    + log_name\n                    + time.strftime(\n                        \"%Y-%m-%d-%H:%M\", time.localtime(time.time())\n                    )\n                    + \".log\"\n                )\n            else:\n                log_name = (\n                    log_path\n                    + log_name\n                    + \"_\"\n                    + time.strftime(\n                        \"%Y-%m-%d-%H-%M\", time.localtime(time.time())\n                    )\n                    + \".log\"\n                )\n\n            logging.basicConfig(\n                level=log_level,\n                format=\"%(asctime)s %(levelname)s: %(message)s\",\n                datefmt=\"%Y-%m-%d-%H:%M\",\n                handlers=[\n                    logging.FileHandler(log_name, mode=\"w\"),\n                    logging.StreamHandler(),\n                ],\n            )\n            logging.info(\"Start Logging\")\n            logging.info(\"Log file path: {}\".format(log_name))\n\n        else:\n     
       logging.basicConfig(\n                level=log_level,\n                format=\"%(asctime)s %(levelname)s: %(message)s\",\n                datefmt=\"%Y-%m-%d-%H:%M\",\n                handlers=[logging.StreamHandler()],\n            )\n            logging.info(\"Start Logging\")\n\n    def debug(self, msg):\n        logging.debug(msg)\n\n    def info(self, msg):\n        logging.info(msg)\n\n    def critical(self, msg):\n        logging.critical(msg)\n\n    def warning(self, msg):\n        logging.warning(msg)\n\n    def error(self, msg):\n        logging.error(msg)\n"
  },
  {
    "path": "examples/pytorch/seal/main.py",
    "content": "import time\n\nimport numpy as np\nimport torch\nimport torch.multiprocessing\n\nfrom dgl import EID, NID\nfrom dgl.dataloading import GraphDataLoader\nfrom logger import LightLogging\nfrom model import DGCNN, GCN\nfrom sampler import SEALData\nfrom torch.nn import BCEWithLogitsLoss\nfrom tqdm import tqdm\nfrom utils import evaluate_hits, load_ogb_dataset, parse_arguments\n\ntorch.multiprocessing.set_sharing_strategy(\"file_system\")\n\n\"\"\"\nPart of the code are adapted from\nhttps://github.com/facebookresearch/SEAL_OGB\n\"\"\"\n\n\ndef train(\n    model,\n    dataloader,\n    loss_fn,\n    optimizer,\n    device,\n    num_graphs=32,\n    total_graphs=None,\n):\n    model.train()\n\n    total_loss = 0\n    for g, labels in tqdm(dataloader, ncols=100):\n        g = g.to(device)\n        labels = labels.to(device)\n        optimizer.zero_grad()\n        logits = model(g, g.ndata[\"z\"], g.ndata[NID], g.edata[EID])\n        loss = loss_fn(logits, labels)\n        loss.backward()\n        optimizer.step()\n        total_loss += loss.item() * num_graphs\n\n    return total_loss / total_graphs\n\n\n@torch.no_grad()\ndef evaluate(model, dataloader, device):\n    model.eval()\n\n    y_pred, y_true = [], []\n    for g, labels in tqdm(dataloader, ncols=100):\n        g = g.to(device)\n        logits = model(g, g.ndata[\"z\"], g.ndata[NID], g.edata[EID])\n        y_pred.append(logits.view(-1).cpu())\n        y_true.append(labels.view(-1).cpu().to(torch.float))\n\n    y_pred, y_true = torch.cat(y_pred), torch.cat(y_true)\n    pos_pred = y_pred[y_true == 1]\n    neg_pred = y_pred[y_true == 0]\n\n    return pos_pred, neg_pred\n\n\ndef main(args, print_fn=print):\n    print_fn(\"Experiment arguments: {}\".format(args))\n\n    if args.random_seed:\n        torch.manual_seed(args.random_seed)\n    else:\n        torch.manual_seed(123)\n    # Load dataset\n    if args.dataset.startswith(\"ogbl\"):\n        graph, split_edge = load_ogb_dataset(args.dataset)\n    
else:\n        raise NotImplementedError\n\n    num_nodes = graph.num_nodes()\n\n    # set gpu\n    if args.gpu_id >= 0 and torch.cuda.is_available():\n        device = \"cuda:{}\".format(args.gpu_id)\n    else:\n        device = \"cpu\"\n\n    if args.dataset == \"ogbl-collab\":\n        # ogbl-collab dataset is multi-edge graph\n        use_coalesce = True\n    else:\n        use_coalesce = False\n\n    # Generate positive and negative edges and corresponding labels\n    # Sampling subgraphs and generate node labeling features\n    seal_data = SEALData(\n        g=graph,\n        split_edge=split_edge,\n        hop=args.hop,\n        neg_samples=args.neg_samples,\n        subsample_ratio=args.subsample_ratio,\n        use_coalesce=use_coalesce,\n        prefix=args.dataset,\n        save_dir=args.save_dir,\n        num_workers=args.num_workers,\n        print_fn=print_fn,\n    )\n    node_attribute = seal_data.ndata[\"feat\"]\n    edge_weight = seal_data.edata[\"weight\"].float()\n\n    train_data = seal_data(\"train\")\n    val_data = seal_data(\"valid\")\n    test_data = seal_data(\"test\")\n\n    train_graphs = len(train_data.graph_list)\n\n    # Set data loader\n\n    train_loader = GraphDataLoader(\n        train_data, batch_size=args.batch_size, num_workers=args.num_workers\n    )\n    val_loader = GraphDataLoader(\n        val_data, batch_size=args.batch_size, num_workers=args.num_workers\n    )\n    test_loader = GraphDataLoader(\n        test_data, batch_size=args.batch_size, num_workers=args.num_workers\n    )\n\n    # set model\n    if args.model == \"gcn\":\n        model = GCN(\n            num_layers=args.num_layers,\n            hidden_units=args.hidden_units,\n            gcn_type=args.gcn_type,\n            pooling_type=args.pooling,\n            node_attributes=node_attribute,\n            edge_weights=edge_weight,\n            node_embedding=None,\n            use_embedding=True,\n            num_nodes=num_nodes,\n            
dropout=args.dropout,\n        )\n    elif args.model == \"dgcnn\":\n        model = DGCNN(\n            num_layers=args.num_layers,\n            hidden_units=args.hidden_units,\n            k=args.sort_k,\n            gcn_type=args.gcn_type,\n            node_attributes=node_attribute,\n            edge_weights=edge_weight,\n            node_embedding=None,\n            use_embedding=True,\n            num_nodes=num_nodes,\n            dropout=args.dropout,\n        )\n    else:\n        raise ValueError(\"Model error\")\n\n    model = model.to(device)\n    parameters = model.parameters()\n    optimizer = torch.optim.Adam(parameters, lr=args.lr)\n    loss_fn = BCEWithLogitsLoss()\n    print_fn(\n        \"Total parameters: {}\".format(\n            sum([p.numel() for p in model.parameters()])\n        )\n    )\n\n    # train and evaluate loop\n    summary_val = []\n    summary_test = []\n    for epoch in range(args.epochs):\n        start_time = time.time()\n        loss = train(\n            model=model,\n            dataloader=train_loader,\n            loss_fn=loss_fn,\n            optimizer=optimizer,\n            device=device,\n            num_graphs=args.batch_size,\n            total_graphs=train_graphs,\n        )\n        train_time = time.time()\n        if epoch % args.eval_steps == 0:\n            val_pos_pred, val_neg_pred = evaluate(\n                model=model, dataloader=val_loader, device=device\n            )\n            test_pos_pred, test_neg_pred = evaluate(\n                model=model, dataloader=test_loader, device=device\n            )\n\n            val_metric = evaluate_hits(\n                args.dataset, val_pos_pred, val_neg_pred, args.hits_k\n            )\n            test_metric = evaluate_hits(\n                args.dataset, test_pos_pred, test_neg_pred, args.hits_k\n            )\n            evaluate_time = time.time()\n            print_fn(\n                \"Epoch-{}, train loss: {:.4f}, hits@{}: val-{:.4f}, test-{:.4f}, 
\"\n                \"cost time: train-{:.1f}s, total-{:.1f}s\".format(\n                    epoch,\n                    loss,\n                    args.hits_k,\n                    val_metric,\n                    test_metric,\n                    train_time - start_time,\n                    evaluate_time - start_time,\n                )\n            )\n            summary_val.append(val_metric)\n            summary_test.append(test_metric)\n\n    summary_test = np.array(summary_test)\n\n    print_fn(\"Experiment Results:\")\n    print_fn(\n        \"Best hits@{}: {:.4f}, epoch: {}\".format(\n            args.hits_k, np.max(summary_test), np.argmax(summary_test)\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    args = parse_arguments()\n    logger = LightLogging(log_name=\"SEAL\", log_path=\"./logs\")\n    main(args, logger.info)\n"
  },
  {
    "path": "examples/pytorch/seal/model.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl.nn.pytorch import GraphConv, SAGEConv, SortPooling, SumPooling\n\n\nclass GCN(nn.Module):\n    \"\"\"\n    GCN Model\n\n    Attributes:\n        num_layers(int): num of gcn layers\n        hidden_units(int): num of hidden units\n        gcn_type(str): type of gcn layer, 'gcn' for GraphConv and 'sage' for SAGEConv\n        pooling_type(str): type of graph pooling to get subgraph representation\n                           'sum' for sum pooling and 'center' for center pooling.\n        node_attributes(Tensor, optional): node attribute\n        edge_weights(Tensor, optional): edge weight\n        node_embedding(Tensor, optional): pre-trained node embedding\n        use_embedding(bool, optional): whether to use node embedding. Note that if 'use_embedding' is set True\n                             and 'node_embedding' is None, will automatically randomly initialize node embedding.\n        num_nodes(int, optional): num of nodes\n        dropout(float, optional): dropout rate\n        max_z(int, optional): default max vocab size of node labeling, default 1000.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        num_layers,\n        hidden_units,\n        gcn_type=\"gcn\",\n        pooling_type=\"sum\",\n        node_attributes=None,\n        edge_weights=None,\n        node_embedding=None,\n        use_embedding=False,\n        num_nodes=None,\n        dropout=0.5,\n        max_z=1000,\n    ):\n        super(GCN, self).__init__()\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.pooling_type = pooling_type\n        self.use_attribute = False if node_attributes is None else True\n        self.use_embedding = use_embedding\n        self.use_edge_weight = False if edge_weights is None else True\n\n        self.z_embedding = nn.Embedding(max_z, hidden_units)\n        if node_attributes is not None:\n            self.node_attributes_lookup = 
nn.Embedding.from_pretrained(\n                node_attributes\n            )\n            self.node_attributes_lookup.weight.requires_grad = False\n        if edge_weights is not None:\n            self.edge_weights_lookup = nn.Embedding.from_pretrained(\n                edge_weights\n            )\n            self.edge_weights_lookup.weight.requires_grad = False\n        if node_embedding is not None:\n            self.node_embedding = nn.Embedding.from_pretrained(node_embedding)\n            self.node_embedding.weight.requires_grad = False\n        elif use_embedding:\n            self.node_embedding = nn.Embedding(num_nodes, hidden_units)\n\n        initial_dim = hidden_units\n        if self.use_attribute:\n            initial_dim += self.node_attributes_lookup.embedding_dim\n        if self.use_embedding:\n            initial_dim += self.node_embedding.embedding_dim\n\n        self.layers = nn.ModuleList()\n        if gcn_type == \"gcn\":\n            self.layers.append(\n                GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)\n            )\n            for _ in range(num_layers - 1):\n                self.layers.append(\n                    GraphConv(\n                        hidden_units, hidden_units, allow_zero_in_degree=True\n                    )\n                )\n        elif gcn_type == \"sage\":\n            self.layers.append(\n                SAGEConv(initial_dim, hidden_units, aggregator_type=\"gcn\")\n            )\n            for _ in range(num_layers - 1):\n                self.layers.append(\n                    SAGEConv(hidden_units, hidden_units, aggregator_type=\"gcn\")\n                )\n        else:\n            raise ValueError(\"Gcn type error.\")\n\n        self.linear_1 = nn.Linear(hidden_units, hidden_units)\n        self.linear_2 = nn.Linear(hidden_units, 1)\n        if pooling_type != \"sum\":\n            raise ValueError(\"Pooling type error.\")\n        self.pooling = SumPooling()\n\n    def 
reset_parameters(self):\n        for layer in self.layers:\n            layer.reset_parameters()\n\n    def forward(self, g, z, node_id=None, edge_id=None):\n        \"\"\"\n        Args:\n            g(DGLGraph): the graph\n            z(Tensor): node labeling tensor, shape [N, 1]\n            node_id(Tensor, optional): node id tensor, shape [N, 1]\n            edge_id(Tensor, optional): edge id tensor, shape [E, 1]\n        Returns:\n            x(Tensor): output tensor\n\n        \"\"\"\n\n        z_emb = self.z_embedding(z)\n\n        if self.use_attribute:\n            x = self.node_attributes_lookup(node_id)\n            x = torch.cat([z_emb, x], 1)\n        else:\n            x = z_emb\n\n        if self.use_edge_weight:\n            edge_weight = self.edge_weights_lookup(edge_id)\n        else:\n            edge_weight = None\n\n        if self.use_embedding:\n            n_emb = self.node_embedding(node_id)\n            x = torch.cat([x, n_emb], 1)\n\n        for layer in self.layers[:-1]:\n            x = layer(g, x, edge_weight=edge_weight)\n            x = F.relu(x)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n        x = self.layers[-1](g, x, edge_weight=edge_weight)\n\n        x = self.pooling(g, x)\n        x = F.relu(self.linear_1(x))\n        F.dropout(x, p=self.dropout, training=self.training)\n        x = self.linear_2(x)\n\n        return x\n\n\nclass DGCNN(nn.Module):\n    \"\"\"\n    An end-to-end deep learning architecture for graph classification.\n    paper link: https://muhanzhang.github.io/papers/AAAI_2018_DGCNN.pdf\n\n    Attributes:\n        num_layers(int): num of gcn layers\n        hidden_units(int): num of hidden units\n        k(int, optional): The number of nodes to hold for each graph in SortPooling.\n        gcn_type(str): type of gcn layer, 'gcn' for GraphConv and 'sage' for SAGEConv\n        node_attributes(Tensor, optional): node attribute\n        edge_weights(Tensor, optional): edge weight\n        
node_embedding(Tensor, optional): pre-trained node embedding\n        use_embedding(bool, optional): whether to use node embedding. Note that if 'use_embedding' is set True\n                             and 'node_embedding' is None, will automatically randomly initialize node embedding.\n        num_nodes(int, optional): num of nodes\n        dropout(float, optional): dropout rate\n        max_z(int, optional): default max vocab size of node labeling, default 1000.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_layers,\n        hidden_units,\n        k=10,\n        gcn_type=\"gcn\",\n        node_attributes=None,\n        edge_weights=None,\n        node_embedding=None,\n        use_embedding=False,\n        num_nodes=None,\n        dropout=0.5,\n        max_z=1000,\n    ):\n        super(DGCNN, self).__init__()\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.use_attribute = False if node_attributes is None else True\n        self.use_embedding = use_embedding\n        self.use_edge_weight = False if edge_weights is None else True\n\n        self.z_embedding = nn.Embedding(max_z, hidden_units)\n\n        if node_attributes is not None:\n            self.node_attributes_lookup = nn.Embedding.from_pretrained(\n                node_attributes\n            )\n            self.node_attributes_lookup.weight.requires_grad = False\n        if edge_weights is not None:\n            self.edge_weights_lookup = nn.Embedding.from_pretrained(\n                edge_weights\n            )\n            self.edge_weights_lookup.weight.requires_grad = False\n        if node_embedding is not None:\n            self.node_embedding = nn.Embedding.from_pretrained(node_embedding)\n            self.node_embedding.weight.requires_grad = False\n        elif use_embedding:\n            self.node_embedding = nn.Embedding(num_nodes, hidden_units)\n\n        initial_dim = hidden_units\n        if self.use_attribute:\n            initial_dim += 
self.node_attributes_lookup.embedding_dim\n        if self.use_embedding:\n            initial_dim += self.node_embedding.embedding_dim\n\n        self.layers = nn.ModuleList()\n        if gcn_type == \"gcn\":\n            self.layers.append(\n                GraphConv(initial_dim, hidden_units, allow_zero_in_degree=True)\n            )\n            for _ in range(num_layers - 1):\n                self.layers.append(\n                    GraphConv(\n                        hidden_units, hidden_units, allow_zero_in_degree=True\n                    )\n                )\n            self.layers.append(\n                GraphConv(hidden_units, 1, allow_zero_in_degree=True)\n            )\n        elif gcn_type == \"sage\":\n            self.layers.append(\n                SAGEConv(initial_dim, hidden_units, aggregator_type=\"gcn\")\n            )\n            for _ in range(num_layers - 1):\n                self.layers.append(\n                    SAGEConv(hidden_units, hidden_units, aggregator_type=\"gcn\")\n                )\n            self.layers.append(SAGEConv(hidden_units, 1, aggregator_type=\"gcn\"))\n        else:\n            raise ValueError(\"Gcn type error.\")\n\n        self.pooling = SortPooling(k=k)\n        conv1d_channels = [16, 32]\n        total_latent_dim = hidden_units * num_layers + 1\n        conv1d_kws = [total_latent_dim, 5]\n        self.conv_1 = nn.Conv1d(\n            1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0]\n        )\n        self.maxpool1d = nn.MaxPool1d(2, 2)\n        self.conv_2 = nn.Conv1d(\n            conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1\n        )\n        dense_dim = int((k - 2) / 2 + 1)\n        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]\n        self.linear_1 = nn.Linear(dense_dim, 128)\n        self.linear_2 = nn.Linear(128, 1)\n\n    def forward(self, g, z, node_id=None, edge_id=None):\n        \"\"\"\n        Args:\n            g(DGLGraph): the graph\n            
z(Tensor): node labeling tensor, shape [N, 1]\n            node_id(Tensor, optional): node id tensor, shape [N, 1]\n            edge_id(Tensor, optional): edge id tensor, shape [E, 1]\n        Returns:\n            x(Tensor): output tensor\n        \"\"\"\n        z_emb = self.z_embedding(z)\n        if self.use_attribute:\n            x = self.node_attributes_lookup(node_id)\n            x = torch.cat([z_emb, x], 1)\n        else:\n            x = z_emb\n        if self.use_edge_weight:\n            edge_weight = self.edge_weights_lookup(edge_id)\n        else:\n            edge_weight = None\n\n        if self.use_embedding:\n            n_emb = self.node_embedding(node_id)\n            x = torch.cat([x, n_emb], 1)\n\n        xs = [x]\n        for layer in self.layers:\n            out = torch.tanh(layer(g, xs[-1], edge_weight=edge_weight))\n            xs += [out]\n\n        x = torch.cat(xs[1:], dim=-1)\n\n        # SortPooling\n        x = self.pooling(g, x)\n        x = x.unsqueeze(1)\n        x = F.relu(self.conv_1(x))\n        x = self.maxpool1d(x)\n        x = F.relu(self.conv_2(x))\n        x = x.view(x.size(0), -1)\n\n        x = F.relu(self.linear_1(x))\n        F.dropout(x, p=self.dropout, training=self.training)\n        x = self.linear_2(x)\n\n        return x\n"
  },
  {
    "path": "examples/pytorch/seal/sampler.py",
    "content": "import os.path as osp\nfrom copy import deepcopy\n\nimport dgl\n\nimport torch\nfrom dgl import add_self_loop, DGLGraph, NID\nfrom dgl.dataloading.negative_sampler import Uniform\nfrom torch.utils.data import DataLoader, Dataset\nfrom tqdm import tqdm\nfrom utils import drnl_node_labeling\n\n\nclass GraphDataSet(Dataset):\n    \"\"\"\n    GraphDataset for torch DataLoader\n    \"\"\"\n\n    def __init__(self, graph_list, tensor):\n        self.graph_list = graph_list\n        self.tensor = tensor\n\n    def __len__(self):\n        return len(self.graph_list)\n\n    def __getitem__(self, index):\n        return (self.graph_list[index], self.tensor[index])\n\n\nclass PosNegEdgesGenerator(object):\n    \"\"\"\n    Generate positive and negative samples\n    Attributes:\n        g(dgl.DGLGraph): graph\n        split_edge(dict): split edge\n        neg_samples(int): num of negative samples per positive sample\n        subsample_ratio(float): ratio of subsample\n        shuffle(bool): if shuffle generated graph list\n    \"\"\"\n\n    def __init__(\n        self, g, split_edge, neg_samples=1, subsample_ratio=0.1, shuffle=True\n    ):\n        self.neg_sampler = Uniform(neg_samples)\n        self.subsample_ratio = subsample_ratio\n        self.split_edge = split_edge\n        self.g = g\n        self.shuffle = shuffle\n\n    def __call__(self, split_type):\n        if split_type == \"train\":\n            subsample_ratio = self.subsample_ratio\n        else:\n            subsample_ratio = 1\n\n        pos_edges = self.split_edge[split_type][\"edge\"]\n        if split_type == \"train\":\n            # Adding self loop in train avoids sampling the source node itself.\n            g = add_self_loop(self.g)\n            eids = g.edge_ids(pos_edges[:, 0], pos_edges[:, 1])\n            neg_edges = torch.stack(self.neg_sampler(g, eids), dim=1)\n        else:\n            neg_edges = self.split_edge[split_type][\"edge_neg\"]\n        pos_edges = 
self.subsample(pos_edges, subsample_ratio).long()\n        neg_edges = self.subsample(neg_edges, subsample_ratio).long()\n\n        edges = torch.cat([pos_edges, neg_edges])\n        labels = torch.cat(\n            [\n                torch.ones(pos_edges.size(0), 1),\n                torch.zeros(neg_edges.size(0), 1),\n            ]\n        )\n        if self.shuffle:\n            perm = torch.randperm(edges.size(0))\n            edges = edges[perm]\n            labels = labels[perm]\n        return edges, labels\n\n    def subsample(self, edges, subsample_ratio):\n        \"\"\"\n        Subsample generated edges.\n        Args:\n            edges(Tensor): edges to subsample\n            subsample_ratio(float): ratio of subsample\n\n        Returns:\n            edges(Tensor):  edges\n\n        \"\"\"\n\n        num_edges = edges.size(0)\n        perm = torch.randperm(num_edges)\n        perm = perm[: int(subsample_ratio * num_edges)]\n        edges = edges[perm]\n        return edges\n\n\nclass EdgeDataSet(Dataset):\n    \"\"\"\n    Assistant Dataset for speeding up the SEALSampler\n    \"\"\"\n\n    def __init__(self, edges, labels, transform):\n        self.edges = edges\n        self.transform = transform\n        self.labels = labels\n\n    def __len__(self):\n        return len(self.edges)\n\n    def __getitem__(self, index):\n        subgraph = self.transform(self.edges[index])\n        return (subgraph, self.labels[index])\n\n\nclass SEALSampler(object):\n    \"\"\"\n    Sampler for SEAL in paper(no-block version)\n    The  strategy is to sample all the k-hop neighbors around the two target nodes.\n    Attributes:\n        graph(DGLGraph): The graph\n        hop(int): num of hop\n        num_workers(int): num of workers\n\n    \"\"\"\n\n    def __init__(self, graph, hop=1, num_workers=32, print_fn=print):\n        self.graph = graph\n        self.hop = hop\n        self.print_fn = print_fn\n        self.num_workers = num_workers\n\n    def 
sample_subgraph(self, target_nodes):\n        \"\"\"\n        Args:\n            target_nodes(Tensor): Tensor of two target nodes\n        Returns:\n            subgraph(DGLGraph): subgraph\n        \"\"\"\n        sample_nodes = [target_nodes]\n        frontiers = target_nodes\n\n        for i in range(self.hop):\n            frontiers = self.graph.out_edges(frontiers)[1]\n            frontiers = torch.unique(frontiers)\n            sample_nodes.append(frontiers)\n\n        sample_nodes = torch.cat(sample_nodes)\n        sample_nodes = torch.unique(sample_nodes)\n        subgraph = dgl.node_subgraph(self.graph, sample_nodes)\n\n        # Each node should have unique node id in the new subgraph\n        u_id = int(\n            torch.nonzero(\n                subgraph.ndata[NID] == int(target_nodes[0]), as_tuple=False\n            )\n        )\n        v_id = int(\n            torch.nonzero(\n                subgraph.ndata[NID] == int(target_nodes[1]), as_tuple=False\n            )\n        )\n\n        # remove link between target nodes in positive subgraphs.\n        if subgraph.has_edges_between(u_id, v_id):\n            link_id = subgraph.edge_ids(u_id, v_id, return_uv=True)[2]\n            subgraph.remove_edges(link_id)\n        if subgraph.has_edges_between(v_id, u_id):\n            link_id = subgraph.edge_ids(v_id, u_id, return_uv=True)[2]\n            subgraph.remove_edges(link_id)\n\n        z = drnl_node_labeling(subgraph, u_id, v_id)\n        subgraph.ndata[\"z\"] = z\n\n        return subgraph\n\n    def _collate(self, batch):\n        batch_graphs, batch_labels = map(list, zip(*batch))\n\n        batch_graphs = dgl.batch(batch_graphs)\n        batch_labels = torch.stack(batch_labels)\n        return batch_graphs, batch_labels\n\n    def __call__(self, edges, labels):\n        subgraph_list = []\n        labels_list = []\n        edge_dataset = EdgeDataSet(\n            edges, labels, transform=self.sample_subgraph\n        )\n        self.print_fn(\n   
         \"Using {} workers in sampling job.\".format(self.num_workers)\n        )\n        sampler = DataLoader(\n            edge_dataset,\n            batch_size=32,\n            num_workers=self.num_workers,\n            shuffle=False,\n            collate_fn=self._collate,\n        )\n        for subgraph, label in tqdm(sampler, ncols=100):\n            label_copy = deepcopy(label)\n            subgraph = dgl.unbatch(subgraph)\n\n            del label\n            subgraph_list += subgraph\n            labels_list.append(label_copy)\n\n        return subgraph_list, torch.cat(labels_list)\n\n\nclass SEALData(object):\n    \"\"\"\n    1. Generate positive and negative samples\n    2. Subgraph sampling\n\n    Attributes:\n        g(dgl.DGLGraph): graph\n        split_edge(dict): split edge\n        hop(int): num of hop\n        neg_samples(int): num of negative samples per positive sample\n        subsample_ratio(float): ratio of subsample\n        use_coalesce(bool): True for coalesce graph. 
Graph with multi-edge need to coalesce\n    \"\"\"\n\n    def __init__(\n        self,\n        g,\n        split_edge,\n        hop=1,\n        neg_samples=1,\n        subsample_ratio=1,\n        prefix=None,\n        save_dir=None,\n        num_workers=32,\n        shuffle=True,\n        use_coalesce=True,\n        print_fn=print,\n    ):\n        self.g = g\n        self.hop = hop\n        self.subsample_ratio = subsample_ratio\n        self.prefix = prefix\n        self.save_dir = save_dir\n        self.print_fn = print_fn\n\n        self.generator = PosNegEdgesGenerator(\n            g=self.g,\n            split_edge=split_edge,\n            neg_samples=neg_samples,\n            subsample_ratio=subsample_ratio,\n            shuffle=shuffle,\n        )\n        if use_coalesce:\n            for k, v in g.edata.items():\n                g.edata[k] = v.float()  # dgl.to_simple() requires data is float\n            self.g = dgl.to_simple(\n                g, copy_ndata=True, copy_edata=True, aggregator=\"sum\"\n            )\n\n        self.ndata = {k: v for k, v in self.g.ndata.items()}\n        self.edata = {k: v for k, v in self.g.edata.items()}\n        self.g.ndata.clear()\n        self.g.edata.clear()\n        self.print_fn(\"Save ndata and edata in class.\")\n        self.print_fn(\"Clear ndata and edata in graph.\")\n\n        self.sampler = SEALSampler(\n            graph=self.g, hop=hop, num_workers=num_workers, print_fn=print_fn\n        )\n\n    def __call__(self, split_type):\n        if split_type == \"train\":\n            subsample_ratio = self.subsample_ratio\n        else:\n            subsample_ratio = 1\n\n        path = osp.join(\n            self.save_dir or \"\",\n            \"{}_{}_{}-hop_{}-subsample.bin\".format(\n                self.prefix, split_type, self.hop, subsample_ratio\n            ),\n        )\n\n        if osp.exists(path):\n            self.print_fn(\"Load existing processed {} files\".format(split_type))\n            
graph_list, data = dgl.load_graphs(path)\n            dataset = GraphDataSet(graph_list, data[\"labels\"])\n\n        else:\n            self.print_fn(\"Processed {} files not exist.\".format(split_type))\n\n            edges, labels = self.generator(split_type)\n            self.print_fn(\"Generate {} edges totally.\".format(edges.size(0)))\n\n            graph_list, labels = self.sampler(edges, labels)\n            dataset = GraphDataSet(graph_list, labels)\n            dgl.save_graphs(path, graph_list, {\"labels\": labels})\n            self.print_fn(\"Save preprocessed subgraph to {}\".format(path))\n        return dataset\n"
  },
  {
    "path": "examples/pytorch/seal/utils.py",
    "content": "import argparse\n\nimport dgl\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom ogb.linkproppred import DglLinkPropPredDataset, Evaluator\nfrom scipy.sparse.csgraph import shortest_path\n\n\ndef parse_arguments():\n    \"\"\"\n    Parse arguments\n    \"\"\"\n    parser = argparse.ArgumentParser(description=\"SEAL\")\n    parser.add_argument(\"--dataset\", type=str, default=\"ogbl-collab\")\n    parser.add_argument(\"--gpu_id\", type=int, default=0)\n    parser.add_argument(\"--hop\", type=int, default=1)\n    parser.add_argument(\"--model\", type=str, default=\"dgcnn\")\n    parser.add_argument(\"--gcn_type\", type=str, default=\"gcn\")\n    parser.add_argument(\"--num_layers\", type=int, default=3)\n    parser.add_argument(\"--hidden_units\", type=int, default=32)\n    parser.add_argument(\"--sort_k\", type=int, default=30)\n    parser.add_argument(\"--pooling\", type=str, default=\"sum\")\n    parser.add_argument(\"--dropout\", type=str, default=0.5)\n    parser.add_argument(\"--hits_k\", type=int, default=50)\n    parser.add_argument(\"--lr\", type=float, default=0.0001)\n    parser.add_argument(\"--neg_samples\", type=int, default=1)\n    parser.add_argument(\"--subsample_ratio\", type=float, default=0.1)\n    parser.add_argument(\"--epochs\", type=int, default=60)\n    parser.add_argument(\"--batch_size\", type=int, default=32)\n    parser.add_argument(\"--eval_steps\", type=int, default=5)\n    parser.add_argument(\"--num_workers\", type=int, default=32)\n    parser.add_argument(\"--random_seed\", type=int, default=2021)\n    parser.add_argument(\"--save_dir\", type=str, default=\"./processed\")\n    args = parser.parse_args()\n\n    return args\n\n\ndef load_ogb_dataset(dataset):\n    \"\"\"\n    Load OGB dataset\n    Args:\n        dataset(str): name of dataset (ogbl-collab, ogbl-ddi, ogbl-citation)\n\n    Returns:\n        graph(DGLGraph): graph\n        split_edge(dict): split edge\n\n    \"\"\"\n    dataset = 
DglLinkPropPredDataset(name=dataset)\n    split_edge = dataset.get_edge_split()\n    graph = dataset[0]\n\n    return graph, split_edge\n\n\ndef drnl_node_labeling(subgraph, src, dst):\n    \"\"\"\n    Double Radius Node labeling\n    d = r(i,u)+r(i,v)\n    label = 1+ min(r(i,u),r(i,v))+ (d//2)*(d//2+d%2-1)\n    Isolated nodes in subgraph will be set as zero.\n    Extreme large graph may cause memory error.\n\n    Args:\n        subgraph(DGLGraph): The graph\n        src(int): node id of one of src node in new subgraph\n        dst(int): node id of one of dst node in new subgraph\n    Returns:\n        z(Tensor): node labeling tensor\n    \"\"\"\n    adj = subgraph.adj_external().to_dense().numpy()\n    src, dst = (dst, src) if src > dst else (src, dst)\n\n    idx = list(range(src)) + list(range(src + 1, adj.shape[0]))\n    adj_wo_src = adj[idx, :][:, idx]\n\n    idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))\n    adj_wo_dst = adj[idx, :][:, idx]\n\n    dist2src = shortest_path(\n        adj_wo_dst, directed=False, unweighted=True, indices=src\n    )\n    dist2src = np.insert(dist2src, dst, 0, axis=0)\n    dist2src = torch.from_numpy(dist2src)\n\n    dist2dst = shortest_path(\n        adj_wo_src, directed=False, unweighted=True, indices=dst - 1\n    )\n    dist2dst = np.insert(dist2dst, src, 0, axis=0)\n    dist2dst = torch.from_numpy(dist2dst)\n\n    dist = dist2src + dist2dst\n    dist_over_2, dist_mod_2 = dist // 2, dist % 2\n\n    z = 1 + torch.min(dist2src, dist2dst)\n    z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)\n    z[src] = 1.0\n    z[dst] = 1.0\n    z[torch.isnan(z)] = 0.0\n\n    return z.to(torch.long)\n\n\ndef evaluate_hits(name, pos_pred, neg_pred, K):\n    \"\"\"\n    Compute hits\n    Args:\n        name(str): name of dataset\n        pos_pred(Tensor): predict value of positive edges\n        neg_pred(Tensor): predict value of negative edges\n        K(int): num of hits\n\n    Returns:\n        hits(float): score of hits\n\n\n    
\"\"\"\n    evaluator = Evaluator(name)\n    evaluator.K = K\n    hits = evaluator.eval(\n        {\n            \"y_pred_pos\": pos_pred,\n            \"y_pred_neg\": neg_pred,\n        }\n    )[f\"hits@{K}\"]\n\n    return hits\n"
  },
  {
    "path": "examples/pytorch/sgc/README.md",
    "content": "Simple Graph Convolution (SGC)\n============\n\n- Paper link: [Simplifying Graph Convolutional Networks](https://arxiv.org/abs/1902.07153)\n- Author's code repo: [https://github.com/Tiiiger/SGC](https://github.com/Tiiiger/SGC). \n\nDependencies\n------------\n- PyTorch 0.4.1+\n- requests\n\n```bash\npip install torch requests\n```\n\nCodes\n-----\nThe folder contains an implementation of SGC (`sgc.py`).\n`sgc_reddit.py` contains an example of training SGC on the reddit dataset.\n\nResults\n-------\n\nRun with the following commands (available datasets: \"cora\", \"citeseer\", \"pubmed\"):\n```bash\npython3 sgc.py --dataset cora --gpu 0\npython3 sgc.py --dataset citeseer --weight-decay 5e-5 --n-epochs 150 --bias --gpu 0\npython3 sgc.py --dataset pubmed --weight-decay 5e-5 --bias --gpu 0\n```\nRun the following command to train on the reddit dataset.\n```bash\npython sgc_reddit.py --gpu 0\n```\n\nOn NVIDIA V100\n\n* cora: 0.819 (paper: 0.810), 0.0008s/epoch\n* citeseer: 0.725 (paper: 0.719), 0.0008s/epoch\n* pubmed: 0.788 (paper: 0.789), 0.0007s/epoch\n* reddit: 0.947 (paper: 0.949), 0.6872s in total\n"
  },
  {
    "path": "examples/pytorch/sgc/sgc.py",
    "content": "\"\"\"\nThis code was modified from the GCN implementation in DGL examples.\nSimplifying Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1902.07153\nCode: https://github.com/Tiiiger/SGC\nSGC implementation in DGL.\n\"\"\"\nimport argparse\nimport math\nimport time\n\nimport dgl\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom dgl.nn.pytorch.conv import SGConv\n\n\ndef evaluate(model, g, features, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(g, features)[mask]  # only compute the evaluation set\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n    else:\n        cuda = True\n        g = g.int().to(args.gpu)\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = g.num_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.int().sum().item(),\n      
      val_mask.int().sum().item(),\n            test_mask.int().sum().item(),\n        )\n    )\n\n    n_edges = g.num_edges()\n    # add self loop\n    g = dgl.remove_self_loop(g)\n    g = dgl.add_self_loop(g)\n\n    # create SGC model\n    model = SGConv(in_feats, n_classes, k=2, cached=True, bias=args.bias)\n\n    if cuda:\n        model.cuda()\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    # use optimizer\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # initialize graph\n    dur = []\n    for epoch in range(args.n_epochs):\n        model.train()\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        logits = model(g, features)  # only compute the train set\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if epoch >= 3:\n            dur.append(time.time() - t0)\n\n        acc = evaluate(model, g, features, labels, val_mask)\n        print(\n            \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n            \"ETputs(KTEPS) {:.2f}\".format(\n                epoch,\n                np.mean(dur),\n                loss.item(),\n                acc,\n                n_edges / np.mean(dur) / 1000,\n            )\n        )\n\n    print()\n    acc = evaluate(model, g, features, labels, test_mask)\n    print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"SGC\")\n    register_data_args(parser)\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=0.2, help=\"learning rate\")\n    parser.add_argument(\n        \"--bias\", action=\"store_true\", default=False, help=\"flag to use bias\"\n    )\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=100, help=\"number of 
training epochs\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-6, help=\"Weight for L2 loss\"\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/sgc/sgc_reddit.py",
    "content": "\"\"\"\nThis code was modified from the GCN implementation in DGL examples.\nSimplifying Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1902.07153\nCode: https://github.com/Tiiiger/SGC\nSGC implementation in DGL.\n\"\"\"\nimport argparse\nimport math\nimport time\n\nimport dgl.function as fn\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import DGLGraph\nfrom dgl.data import load_data, register_data_args\nfrom dgl.nn.pytorch.conv import SGConv\n\n\ndef normalize(h):\n    return (h - h.mean(0)) / h.std(0)\n\n\ndef evaluate(model, features, graph, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(graph, features)[mask]  # only compute the evaluation set\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef main(args):\n    # load and preprocess dataset\n    args.dataset = \"reddit-self-loop\"\n    data = load_data(args)\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n    else:\n        cuda = True\n        g = g.int().to(args.gpu)\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = g.num_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            g.ndata[\"train_mask\"].int().sum().item(),\n            g.ndata[\"val_mask\"].int().sum().item(),\n            g.ndata[\"test_mask\"].int().sum().item(),\n        )\n    )\n\n    # graph preprocess and calculate normalization factor\n    n_edges = g.num_edges()\n    
# normalization\n    degs = g.in_degrees().float()\n    norm = torch.pow(degs, -0.5)\n    norm[torch.isinf(norm)] = 0\n    g.ndata[\"norm\"] = norm.unsqueeze(1)\n\n    # create SGC model\n    model = SGConv(\n        in_feats, n_classes, k=2, cached=True, bias=True, norm=normalize\n    )\n    if args.gpu >= 0:\n        model = model.cuda()\n\n    # use optimizer\n    optimizer = torch.optim.LBFGS(model.parameters())\n\n    # define loss closure\n    def closure():\n        optimizer.zero_grad()\n        output = model(g, features)[train_mask]\n        loss_train = F.cross_entropy(output, labels[train_mask])\n        loss_train.backward()\n        return loss_train\n\n    # initialize graph\n    for epoch in range(args.n_epochs):\n        model.train()\n        optimizer.step(closure)\n\n    acc = evaluate(model, features, g, labels, test_mask)\n    print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"SGC\")\n    register_data_args(parser)\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\n        \"--bias\", action=\"store_true\", default=False, help=\"flag to use bias\"\n    )\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=2, help=\"number of training epochs\"\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/sign/README.md",
    "content": "SIGN: Scalable Inception Graph Neural Networks\n===============\n\n- paper link: [https://arxiv.org/pdf/2004.11198.pdf](https://arxiv.org/pdf/2004.11198.pdf)\n\nRequirements\n----------------\n\n```bash\npip install requests ogb\n```\n\nResults\n---------------\n### [Ogbn-products](https://ogb.stanford.edu/docs/nodeprop/#ogbn-products) (Amazon co-purchase dataset)\n\n```bash\npython sign.py --dataset amazon\n```\n\nTest accuracy: mean 0.78672, std 0.00059\n\n### Reddit\n```bash\npython sign.py --dataset reddit\n```\n\nTest accuracy: mean 0.96326, std 0.00010\n"
  },
  {
    "path": "examples/pytorch/sign/dataset.py",
    "content": "import dgl\nimport numpy as np\nimport torch\n\n\ndef load_dataset(name):\n    dataset = name.lower()\n    if dataset == \"amazon\":\n        from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset\n\n        dataset = DglNodePropPredDataset(name=\"ogbn-products\")\n        splitted_idx = dataset.get_idx_split()\n        train_nid = splitted_idx[\"train\"]\n        val_nid = splitted_idx[\"valid\"]\n        test_nid = splitted_idx[\"test\"]\n        g, labels = dataset[0]\n        n_classes = int(labels.max() - labels.min() + 1)\n        g.ndata[\"label\"] = labels.squeeze()\n        g.ndata[\"feat\"] = g.ndata[\"feat\"].float()\n    elif dataset in [\"reddit\", \"cora\"]:\n        if dataset == \"reddit\":\n            from dgl.data import RedditDataset\n\n            data = RedditDataset(self_loop=True)\n            g = data[0]\n        else:\n            from dgl.data import CitationGraphDataset\n\n            data = CitationGraphDataset(\"cora\")\n            g = data[0]\n        n_classes = data.num_classes\n        train_mask = g.ndata[\"train_mask\"]\n        val_mask = g.ndata[\"val_mask\"]\n        test_mask = g.ndata[\"test_mask\"]\n        train_nid = torch.LongTensor(train_mask.nonzero().squeeze())\n        val_nid = torch.LongTensor(val_mask.nonzero().squeeze())\n        test_nid = torch.LongTensor(test_mask.nonzero().squeeze())\n    else:\n        print(\"Dataset {} is not supported\".format(name))\n        assert 0\n\n    return g, n_classes, train_nid, val_nid, test_nid\n"
  },
  {
    "path": "examples/pytorch/sign/sign.py",
    "content": "import argparse\nimport os\nimport time\n\nimport dgl\nimport dgl.function as fn\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dataset import load_dataset\n\n\nclass FeedForwardNet(nn.Module):\n    def __init__(self, in_feats, hidden, out_feats, n_layers, dropout):\n        super(FeedForwardNet, self).__init__()\n        self.layers = nn.ModuleList()\n        self.n_layers = n_layers\n        if n_layers == 1:\n            self.layers.append(nn.Linear(in_feats, out_feats))\n        else:\n            self.layers.append(nn.Linear(in_feats, hidden))\n            for i in range(n_layers - 2):\n                self.layers.append(nn.Linear(hidden, hidden))\n            self.layers.append(nn.Linear(hidden, out_feats))\n        if self.n_layers > 1:\n            self.prelu = nn.PReLU()\n            self.dropout = nn.Dropout(dropout)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        for layer in self.layers:\n            nn.init.xavier_uniform_(layer.weight, gain=gain)\n            nn.init.zeros_(layer.bias)\n\n    def forward(self, x):\n        for layer_id, layer in enumerate(self.layers):\n            x = layer(x)\n            if layer_id < self.n_layers - 1:\n                x = self.dropout(self.prelu(x))\n        return x\n\n\nclass Model(nn.Module):\n    def __init__(self, in_feats, hidden, out_feats, R, n_layers, dropout):\n        super(Model, self).__init__()\n        self.dropout = nn.Dropout(dropout)\n        self.prelu = nn.PReLU()\n        self.inception_ffs = nn.ModuleList()\n        for hop in range(R + 1):\n            self.inception_ffs.append(\n                FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout)\n            )\n        # self.linear = nn.Linear(hidden * (R + 1), out_feats)\n        self.project = FeedForwardNet(\n            (R + 1) * hidden, hidden, out_feats, n_layers, dropout\n        )\n\n    def 
forward(self, feats):\n        hidden = []\n        for feat, ff in zip(feats, self.inception_ffs):\n            hidden.append(ff(feat))\n        out = self.project(self.dropout(self.prelu(torch.cat(hidden, dim=-1))))\n        return out\n\n\ndef calc_weight(g):\n    \"\"\"\n    Compute row_normalized(D^(-1/2)AD^(-1/2))\n    \"\"\"\n    with g.local_scope():\n        # compute D^(-0.5)*D(-1/2), assuming A is Identity\n        g.ndata[\"in_deg\"] = g.in_degrees().float().pow(-0.5)\n        g.ndata[\"out_deg\"] = g.out_degrees().float().pow(-0.5)\n        g.apply_edges(fn.u_mul_v(\"out_deg\", \"in_deg\", \"weight\"))\n        # row-normalize weight\n        g.update_all(fn.copy_e(\"weight\", \"msg\"), fn.sum(\"msg\", \"norm\"))\n        g.apply_edges(fn.e_div_v(\"weight\", \"norm\", \"weight\"))\n        return g.edata[\"weight\"]\n\n\ndef preprocess(g, features, args):\n    \"\"\"\n    Pre-compute the average of n-th hop neighbors\n    \"\"\"\n    with torch.no_grad():\n        g.edata[\"weight\"] = calc_weight(g)\n        g.ndata[\"feat_0\"] = features\n        for hop in range(1, args.R + 1):\n            g.update_all(\n                fn.u_mul_e(f\"feat_{hop-1}\", \"weight\", \"msg\"),\n                fn.sum(\"msg\", f\"feat_{hop}\"),\n            )\n        res = []\n        for hop in range(args.R + 1):\n            res.append(g.ndata.pop(f\"feat_{hop}\"))\n        return res\n\n\ndef prepare_data(device, args):\n    data = load_dataset(args.dataset)\n    g, n_classes, train_nid, val_nid, test_nid = data\n    g = g.to(device)\n    in_feats = g.ndata[\"feat\"].shape[1]\n    feats = preprocess(g, g.ndata[\"feat\"], args)\n    labels = g.ndata[\"label\"]\n    # move to device\n    train_nid = train_nid.to(device)\n    val_nid = val_nid.to(device)\n    test_nid = test_nid.to(device)\n    train_feats = [x[train_nid] for x in feats]\n    train_labels = labels[train_nid]\n    return (\n        feats,\n        labels,\n        train_feats,\n        train_labels,\n     
   in_feats,\n        n_classes,\n        train_nid,\n        val_nid,\n        test_nid,\n    )\n\n\ndef evaluate(epoch, args, model, feats, labels, train, val, test):\n    with torch.no_grad():\n        batch_size = args.eval_batch_size\n        if batch_size <= 0:\n            pred = model(feats)\n        else:\n            pred = []\n            num_nodes = labels.shape[0]\n            n_batch = (num_nodes + batch_size - 1) // batch_size\n            for i in range(n_batch):\n                batch_start = i * batch_size\n                batch_end = min((i + 1) * batch_size, num_nodes)\n                batch_feats = [feat[batch_start:batch_end] for feat in feats]\n                pred.append(model(batch_feats))\n            pred = torch.cat(pred)\n\n        pred = torch.argmax(pred, dim=1)\n        correct = (pred == labels).float()\n        train_acc = correct[train].sum() / len(train)\n        val_acc = correct[val].sum() / len(val)\n        test_acc = correct[test].sum() / len(test)\n        return train_acc, val_acc, test_acc\n\n\ndef main(args):\n    if args.gpu < 0:\n        device = \"cpu\"\n    else:\n        device = \"cuda:{}\".format(args.gpu)\n\n    data = prepare_data(device, args)\n    (\n        feats,\n        labels,\n        train_feats,\n        train_labels,\n        in_size,\n        num_classes,\n        train_nid,\n        val_nid,\n        test_nid,\n    ) = data\n\n    model = Model(\n        in_size,\n        args.num_hidden,\n        num_classes,\n        args.R,\n        args.ff_layer,\n        args.dropout,\n    )\n    model = model.to(device)\n    loss_fcn = nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    best_epoch = 0\n    best_val = 0\n    best_test = 0\n\n    for epoch in range(1, args.num_epochs + 1):\n        start = time.time()\n        model.train()\n        loss = loss_fcn(model(train_feats), train_labels)\n        
optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if epoch % args.eval_every == 0:\n            model.eval()\n            acc = evaluate(\n                epoch, args, model, feats, labels, train_nid, val_nid, test_nid\n            )\n            end = time.time()\n            log = \"Epoch {}, Times(s): {:.4f}\".format(epoch, end - start)\n            log += \", Accuracy: Train {:.4f}, Val {:.4f}, Test {:.4f}\".format(\n                *acc\n            )\n            print(log)\n            if acc[1] > best_val:\n                best_val = acc[1]\n                best_epoch = epoch\n                best_test = acc[2]\n\n    print(\n        \"Best Epoch {}, Val {:.4f}, Test {:.4f}\".format(\n            best_epoch, best_val, best_test\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"SIGN\")\n    parser.add_argument(\"--num-epochs\", type=int, default=1000)\n    parser.add_argument(\"--num-hidden\", type=int, default=256)\n    parser.add_argument(\"--R\", type=int, default=3, help=\"number of hops\")\n    parser.add_argument(\"--lr\", type=float, default=0.003)\n    parser.add_argument(\"--dataset\", type=str, default=\"amazon\")\n    parser.add_argument(\"--dropout\", type=float, default=0.5)\n    parser.add_argument(\"--gpu\", type=int, default=0)\n    parser.add_argument(\"--weight-decay\", type=float, default=0)\n    parser.add_argument(\"--eval-every\", type=int, default=50)\n    parser.add_argument(\n        \"--eval-batch-size\",\n        type=int,\n        default=250000,\n        help=\"evaluation batch size, -1 for full batch\",\n    )\n    parser.add_argument(\n        \"--ff-layer\", type=int, default=2, help=\"number of feed-forward layers\"\n    )\n    args = parser.parse_args()\n\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/stgcn_wave/README.md",
    "content": "Spatio-Temporal Graph Convolutional Networks\r\n============\r\n\r\n- Paper link: [arXiv](https://arxiv.org/pdf/1709.04875v4.pdf)\r\n- Author's code repo: https://github.com/VeritasYin/STGCN_IJCAI-18.\r\n- See [this blog](https://towardsdatascience.com/build-your-first-graph-neural-network-model-to-predict-traffic-speed-in-20-minutes-b593f8f838e5) for more details about running the code.\r\n- Dependencies\r\n  - PyTorch 1.1.0+\r\n  - scikit-learn\r\n  - dgl\r\n  - tables\r\n\r\n\r\nHow to run\r\n----------\r\nPlease get the METR_LA dataset from [this Google drive](https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX)\r\nand [this GitHub repo](https://github.com/chnsh/DCRNN_PyTorch).\r\n\r\nAn experiment in default settings can be run with\r\n\r\n```bash\r\npython main.py\r\n```\r\n\r\nAn experiment on the METR_LA dataset in customized settings can be run with\r\n```bash\r\npython main.py --lr <learning-rate> --disablecuda --batch_size <batch-size> --epochs <number-of-epochs>\r\n```\r\n\r\nIf one wishes to adjust the model structure, you can change the arguments `control_str` and `channels`\r\n```bash\r\npython main.py --control_str <control-string> --channels <n-input-channel> <n-hidden-channels-1> <n-hidden-channels-2> ... <n-output-channels>\r\n```\r\n\r\n`<control-string>` is a string of the following characters representing a sequence of neural network modules:\r\n\r\n* `T`: representing a dilated temporal convolution layer, working on the temporal dimension.  The dilation factor is always twice as much as the previous temporal convolution layer.\r\n* `S`: representing a graph convolution layer, working on the spatial dimension.  The input channels and output channels are the same.\r\n* `N`: a Layer Normalization.\r\n\r\nThe argument list following `--channels` represents the output channels on each temporal convolution layer.  
The list should have `N + 1` elements, where `N` is the number of `T`'s in `<control-string>`.\r\n\r\nThe activation function between two layers are always ReLU.\r\n\r\nFor example, the following command\r\n```bash\r\npython main.py --control_str TNTSTNTST --channels 1 16 32 32 64 128\r\n```\r\nspecifies the following architecture:\r\n\r\n```\r\n+------------------------------------------------------------+\r\n|                          Input                             |\r\n+------------------------------------------------------------+\r\n|  1D Conv, in_channel = 1, out_channel = 16, dilation = 1   |\r\n+------------------------------------------------------------+\r\n|                   Layer Normalization                      |\r\n+------------------------------------------------------------+\r\n|  1D Conv, in_channel = 16, out_channel = 32, dilation = 2  |\r\n+------------------------------------------------------------+\r\n|       Graph Conv, in_channel = 32, out_channel = 32        |\r\n+------------------------------------------------------------+\r\n|  1D Conv, in_channel = 32, out_channel = 32, dilation = 4  |\r\n+------------------------------------------------------------+\r\n|                   Layer Normalization                      |\r\n+------------------------------------------------------------+\r\n|  1D Conv, in_channel = 32, out_channel = 64, dilation = 8  |\r\n+------------------------------------------------------------+\r\n|       Graph Conv, in_channel = 64, out_channel = 64        |\r\n+------------------------------------------------------------+\r\n| 1D Conv, in_channel = 64, out_channel = 128, dilation = 16 |\r\n+------------------------------------------------------------+\r\n```\r\n\r\nResults\r\n-------\r\n\r\n```bash\r\npython main.py\r\n```\r\nMETR_LA MAE: ~5.76\r\n"
  },
  {
    "path": "examples/pytorch/stgcn_wave/load_data.py",
    "content": "import numpy as np\nimport pandas as pd\nimport torch\n\n\ndef load_data(file_path, len_train, len_val):\n    df = pd.read_csv(file_path, header=None).values.astype(float)\n    train = df[:len_train]\n    val = df[len_train : len_train + len_val]\n    test = df[len_train + len_val :]\n    return train, val, test\n\n\ndef data_transform(data, n_his, n_pred, device):\n    # produce data slices for training and testing\n    n_route = data.shape[1]\n    l = len(data)\n    num = l - n_his - n_pred\n    x = np.zeros([num, 1, n_his, n_route])\n    y = np.zeros([num, n_route])\n\n    cnt = 0\n    for i in range(l - n_his - n_pred):\n        head = i\n        tail = i + n_his\n        x[cnt, :, :, :] = data[head:tail].reshape(1, n_his, n_route)\n        y[cnt] = data[tail + n_pred - 1]\n        cnt += 1\n    return torch.Tensor(x).to(device), torch.Tensor(y).to(device)\n"
  },
  {
    "path": "examples/pytorch/stgcn_wave/main.py",
    "content": "import argparse\nimport random\n\nimport numpy as np\nimport pandas as pd\nimport scipy.sparse as sp\nimport torch\nimport torch.nn as nn\nfrom load_data import *\nfrom model import *\nfrom sensors2graph import *\nfrom sklearn.preprocessing import StandardScaler\nfrom utils import *\n\nimport dgl\n\nparser = argparse.ArgumentParser(description=\"STGCN_WAVE\")\nparser.add_argument(\"--lr\", default=0.001, type=float, help=\"learning rate\")\nparser.add_argument(\"--disablecuda\", action=\"store_true\", help=\"Disable CUDA\")\nparser.add_argument(\n    \"--batch_size\",\n    type=int,\n    default=50,\n    help=\"batch size for training and validation (default: 50)\",\n)\nparser.add_argument(\n    \"--epochs\", type=int, default=50, help=\"epochs for training  (default: 50)\"\n)\nparser.add_argument(\n    \"--num_layers\", type=int, default=9, help=\"number of layers\"\n)\nparser.add_argument(\"--window\", type=int, default=144, help=\"window length\")\nparser.add_argument(\n    \"--sensorsfilepath\",\n    type=str,\n    default=\"./data/sensor_graph/graph_sensor_ids.txt\",\n    help=\"sensors file path\",\n)\nparser.add_argument(\n    \"--disfilepath\",\n    type=str,\n    default=\"./data/sensor_graph/distances_la_2012.csv\",\n    help=\"distance file path\",\n)\nparser.add_argument(\n    \"--tsfilepath\", type=str, default=\"./data/metr-la.h5\", help=\"ts file path\"\n)\nparser.add_argument(\n    \"--savemodelpath\",\n    type=str,\n    default=\"stgcnwavemodel.pt\",\n    help=\"save model path\",\n)\nparser.add_argument(\n    \"--pred_len\",\n    type=int,\n    default=5,\n    help=\"how many steps away we want to predict\",\n)\nparser.add_argument(\n    \"--control_str\",\n    type=str,\n    default=\"TNTSTNTST\",\n    help=\"model strcture controller, T: Temporal Layer, S: Spatio Layer, N: Norm Layer\",\n)\nparser.add_argument(\n    \"--channels\",\n    type=int,\n    nargs=\"+\",\n    default=[1, 16, 32, 64, 32, 128],\n    help=\"model strcture 
controller, T: Temporal Layer, S: Spatio Layer, N: Norm Layer\",\n)\nargs = parser.parse_args()\n\ndevice = (\n    torch.device(\"cuda\")\n    if torch.cuda.is_available() and not args.disablecuda\n    else torch.device(\"cpu\")\n)\n\nwith open(args.sensorsfilepath) as f:\n    sensor_ids = f.read().strip().split(\",\")\ndistance_df = pd.read_csv(args.disfilepath, dtype={\"from\": \"str\", \"to\": \"str\"})\n\nadj_mx = get_adjacency_matrix(distance_df, sensor_ids)\nsp_mx = sp.coo_matrix(adj_mx)\nG = dgl.from_scipy(sp_mx)\n\n\ndf = pd.read_hdf(args.tsfilepath)\nnum_samples, num_nodes = df.shape\n\ntsdata = df.to_numpy()\n\n\nn_his = args.window\n\nsave_path = args.savemodelpath\n\n\nn_pred = args.pred_len\nn_route = num_nodes\nblocks = args.channels\n# blocks = [1, 16, 32, 64, 32, 128]\ndrop_prob = 0\nnum_layers = args.num_layers\n\nbatch_size = args.batch_size\nepochs = args.epochs\nlr = args.lr\n\n\nW = adj_mx\nlen_val = round(num_samples * 0.1)\nlen_train = round(num_samples * 0.7)\ntrain = df[:len_train]\nval = df[len_train : len_train + len_val]\ntest = df[len_train + len_val :]\n\nscaler = StandardScaler()\ntrain = scaler.fit_transform(train)\nval = scaler.transform(val)\ntest = scaler.transform(test)\n\n\nx_train, y_train = data_transform(train, n_his, n_pred, device)\nx_val, y_val = data_transform(val, n_his, n_pred, device)\nx_test, y_test = data_transform(test, n_his, n_pred, device)\n\ntrain_data = torch.utils.data.TensorDataset(x_train, y_train)\ntrain_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)\nval_data = torch.utils.data.TensorDataset(x_val, y_val)\nval_iter = torch.utils.data.DataLoader(val_data, batch_size)\ntest_data = torch.utils.data.TensorDataset(x_test, y_test)\ntest_iter = torch.utils.data.DataLoader(test_data, batch_size)\n\n\nloss = nn.MSELoss()\nG = G.to(device)\nmodel = STGCN_WAVE(\n    blocks, n_his, n_route, G, drop_prob, num_layers, device, args.control_str\n).to(device)\noptimizer = 
torch.optim.RMSprop(model.parameters(), lr=lr)\n\nscheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)\n\nmin_val_loss = np.inf\nfor epoch in range(1, epochs + 1):\n    l_sum, n = 0.0, 0\n    model.train()\n    for x, y in train_iter:\n        y_pred = model(x).view(len(x), -1)\n        l = loss(y_pred, y)\n        optimizer.zero_grad()\n        l.backward()\n        optimizer.step()\n        l_sum += l.item() * y.shape[0]\n        n += y.shape[0]\n    scheduler.step()\n    val_loss = evaluate_model(model, loss, val_iter)\n    if val_loss < min_val_loss:\n        min_val_loss = val_loss\n        torch.save(model.state_dict(), save_path)\n    print(\n        \"epoch\",\n        epoch,\n        \", train loss:\",\n        l_sum / n,\n        \", validation loss:\",\n        val_loss,\n    )\nbest_model = STGCN_WAVE(\n    blocks, n_his, n_route, G, drop_prob, num_layers, device, args.control_str\n).to(device)\nbest_model.load_state_dict(torch.load(save_path, weights_only=False))\n\n\nl = evaluate_model(best_model, loss, test_iter)\nMAE, MAPE, RMSE = evaluate_metric(best_model, test_iter, scaler)\nprint(\"test loss:\", l, \"\\nMAE:\", MAE, \", MAPE:\", MAPE, \", RMSE:\", RMSE)\n"
  },
  {
    "path": "examples/pytorch/stgcn_wave/model.py",
    "content": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.nn.init as init\n\nfrom dgl.nn.pytorch import GraphConv\nfrom dgl.nn.pytorch.conv import ChebConv\n\n\nclass TemporalConvLayer(nn.Module):\n    \"\"\"Temporal convolution layer.\n\n    arguments\n    ---------\n    c_in : int\n        The number of input channels (features)\n    c_out : int\n        The number of output channels (features)\n    dia : int\n        The dilation size\n    \"\"\"\n\n    def __init__(self, c_in, c_out, dia=1):\n        super(TemporalConvLayer, self).__init__()\n        self.c_out = c_out\n        self.c_in = c_in\n        self.conv = nn.Conv2d(\n            c_in, c_out, (2, 1), 1, dilation=dia, padding=(0, 0)\n        )\n\n    def forward(self, x):\n        return torch.relu(self.conv(x))\n\n\nclass SpatioConvLayer(nn.Module):\n    def __init__(self, c, Lk):  # c : hidden dimension Lk: graph matrix\n        super(SpatioConvLayer, self).__init__()\n        self.g = Lk\n        self.gc = GraphConv(c, c, activation=F.relu)\n        # self.gc = ChebConv(c, c, 3)\n\n    def init(self):\n        stdv = 1.0 / math.sqrt(self.W.weight.size(1))\n        self.W.weight.data.uniform_(-stdv, stdv)\n\n    def forward(self, x):\n        x = x.transpose(0, 3)\n        x = x.transpose(1, 3)\n        output = self.gc(self.g, x)\n        output = output.transpose(1, 3)\n        output = output.transpose(0, 3)\n        return torch.relu(output)\n\n\nclass FullyConvLayer(nn.Module):\n    def __init__(self, c):\n        super(FullyConvLayer, self).__init__()\n        self.conv = nn.Conv2d(c, 1, 1)\n\n    def forward(self, x):\n        return self.conv(x)\n\n\nclass OutputLayer(nn.Module):\n    def __init__(self, c, T, n):\n        super(OutputLayer, self).__init__()\n        self.tconv1 = nn.Conv2d(c, c, (T, 1), 1, dilation=1, padding=(0, 0))\n        self.ln = nn.LayerNorm([n, c])\n        self.tconv2 = nn.Conv2d(c, c, (1, 1), 1, dilation=1, 
padding=(0, 0))\n        self.fc = FullyConvLayer(c)\n\n    def forward(self, x):\n        x_t1 = self.tconv1(x)\n        x_ln = self.ln(x_t1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)\n        x_t2 = self.tconv2(x_ln)\n        return self.fc(x_t2)\n\n\nclass STGCN_WAVE(nn.Module):\n    def __init__(\n        self, c, T, n, Lk, p, num_layers, device, control_str=\"TNTSTNTST\"\n    ):\n        super(STGCN_WAVE, self).__init__()\n        self.control_str = control_str  # model structure controller\n        self.num_layers = len(control_str)\n        self.layers = nn.ModuleList([])\n        cnt = 0\n        diapower = 0\n        for i in range(self.num_layers):\n            i_layer = control_str[i]\n            if i_layer == \"T\":  # Temporal Layer\n                self.layers.append(\n                    TemporalConvLayer(c[cnt], c[cnt + 1], dia=2**diapower)\n                )\n                diapower += 1\n                cnt += 1\n            if i_layer == \"S\":  # Spatio Layer\n                self.layers.append(SpatioConvLayer(c[cnt], Lk))\n            if i_layer == \"N\":  # Norm Layer\n                self.layers.append(nn.LayerNorm([n, c[cnt]]))\n        self.output = OutputLayer(c[cnt], T + 1 - 2 ** (diapower), n)\n        for layer in self.layers:\n            layer = layer.to(device)\n\n    def forward(self, x):\n        for i in range(self.num_layers):\n            i_layer = self.control_str[i]\n            if i_layer == \"N\":\n                x = self.layers[i](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)\n            else:\n                x = self.layers[i](x)\n        return self.output(x)\n"
  },
  {
    "path": "examples/pytorch/stgcn_wave/sensors2graph.py",
    "content": "import numpy as np\r\n\r\n\r\ndef get_adjacency_matrix(distance_df, sensor_ids, normalized_k=0.1):\r\n    \"\"\"\r\n    :param distance_df: data frame with three columns: [from, to, distance].\r\n    :param sensor_ids: list of sensor ids.\r\n    :param normalized_k: entries that become lower than normalized_k after normalization are set to zero for sparsity.\r\n    :return: adjacency matrix\r\n    \"\"\"\r\n    num_sensors = len(sensor_ids)\r\n    dist_mx = np.zeros((num_sensors, num_sensors), dtype=np.float32)\r\n    dist_mx[:] = np.inf\r\n    # Builds sensor id to index map.\r\n    sensor_id_to_ind = {}\r\n    for i, sensor_id in enumerate(sensor_ids):\r\n        sensor_id_to_ind[sensor_id] = i\r\n    # Fills cells in the matrix with distances.\r\n    for row in distance_df.values:\r\n        if row[0] not in sensor_id_to_ind or row[1] not in sensor_id_to_ind:\r\n            continue\r\n        dist_mx[sensor_id_to_ind[row[0]], sensor_id_to_ind[row[1]]] = row[2]\r\n\r\n    # Calculates the standard deviation as theta.\r\n    distances = dist_mx[~np.isinf(dist_mx)].flatten()\r\n    std = distances.std()\r\n    adj_mx = np.exp(-np.square(dist_mx / std))\r\n    # Make the adjacent matrix symmetric by taking the max.\r\n    # adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])\r\n\r\n    # Sets entries that lower than a threshold, i.e., k, to zero for sparsity.\r\n    adj_mx[adj_mx < normalized_k] = 0\r\n    return adj_mx\r\n"
  },
  {
    "path": "examples/pytorch/stgcn_wave/utils.py",
    "content": "import numpy as np\nimport torch\n\n\ndef evaluate_model(model, loss, data_iter):\n    model.eval()\n    l_sum, n = 0.0, 0\n    with torch.no_grad():\n        for x, y in data_iter:\n            y_pred = model(x).view(len(x), -1)\n            l = loss(y_pred, y)\n            l_sum += l.item() * y.shape[0]\n            n += y.shape[0]\n        return l_sum / n\n\n\ndef evaluate_metric(model, data_iter, scaler):\n    model.eval()\n    with torch.no_grad():\n        mae, mape, mse = [], [], []\n        for x, y in data_iter:\n            y = scaler.inverse_transform(y.cpu().numpy()).reshape(-1)\n            y_pred = scaler.inverse_transform(\n                model(x).view(len(x), -1).cpu().numpy()\n            ).reshape(-1)\n            d = np.abs(y - y_pred)\n            mae += d.tolist()\n            mape += (d / y).tolist()\n            mse += (d**2).tolist()\n        MAE = np.array(mae).mean()\n        MAPE = np.array(mape).mean()\n        RMSE = np.sqrt(np.array(mse).mean())\n        return MAE, MAPE, RMSE\n"
  },
  {
    "path": "examples/pytorch/tagcn/README.md",
    "content": "Topology Adaptive Graph Convolutional networks (TAGCN)\n============\n\n- Paper link: [https://arxiv.org/abs/1710.10370](https://arxiv.org/abs/1710.10370)\n\nDependencies\n------------\n- PyTorch 0.4.1+\n- requests\n\n```bash\npip install torch requests\n```\n\nResults\n-------\nRun with the following command (available datasets: \"cora\", \"citeseer\", \"pubmed\"):\n```bash\npython3 train.py --dataset cora --gpu 0 --self-loop\n```\n\n* cora: ~0.812 (0.804-0.823) (paper: 0.833)\n* citeseer: ~0.715 (paper: 0.714)\n* pubmed: ~0.794 (paper: 0.811)"
  },
  {
    "path": "examples/pytorch/tagcn/tagcn.py",
    "content": "\"\"\"TAGCN using DGL nn package\n\nReferences:\n- Topology Adaptive Graph Convolutional Networks\n- Paper: https://arxiv.org/abs/1710.10370\n\"\"\"\nimport torch\nimport torch.nn as nn\n\nfrom dgl.nn.pytorch.conv import TAGConv\n\n\nclass TAGCN(nn.Module):\n    def __init__(\n        self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(TAGCN, self).__init__()\n        self.g = g\n        self.layers = nn.ModuleList()\n        # input layer\n        self.layers.append(TAGConv(in_feats, n_hidden, activation=activation))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(\n                TAGConv(n_hidden, n_hidden, activation=activation)\n            )\n        # output layer\n        self.layers.append(TAGConv(n_hidden, n_classes))  # activation=None\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, features):\n        h = features\n        for i, layer in enumerate(self.layers):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(self.g, h)\n        return h\n"
  },
  {
    "path": "examples/pytorch/tagcn/train.py",
    "content": "import argparse\nimport time\n\nimport networkx as nx\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl import DGLGraph\nfrom dgl.data import load_data, register_data_args\nfrom tagcn import TAGCN\n\n\ndef evaluate(model, features, labels, mask):\n    model.eval()\n    with torch.no_grad():\n        logits = model(features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = torch.max(logits, dim=1)\n        correct = torch.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\ndef main(args):\n    # load and preprocess dataset\n    data = load_data(args)\n    g = data[0]\n    if args.gpu < 0:\n        cuda = False\n    else:\n        cuda = True\n        g = g.to(args.gpu)\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    in_feats = features.shape[1]\n    n_classes = data.num_classes\n    n_edges = g.num_edges()\n    print(\n        \"\"\"----Data statistics------'\n      #Edges %d\n      #Classes %d\n      #Train samples %d\n      #Val samples %d\n      #Test samples %d\"\"\"\n        % (\n            n_edges,\n            n_classes,\n            train_mask.int().sum().item(),\n            val_mask.int().sum().item(),\n            test_mask.int().sum().item(),\n        )\n    )\n\n    # graph preprocess and calculate normalization factor\n    # add self loop\n    if args.self_loop:\n        g = g.remove_self_loop().add_self_loop()\n    n_edges = g.num_edges()\n\n    # create TAGCN model\n    model = TAGCN(\n        g,\n        in_feats,\n        args.n_hidden,\n        n_classes,\n        args.n_layers,\n        F.relu,\n        args.dropout,\n    )\n\n    if cuda:\n        model.cuda()\n    loss_fcn = torch.nn.CrossEntropyLoss()\n\n    # use optimizer\n    optimizer = torch.optim.Adam(\n        
model.parameters(), lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # initialize graph\n    dur = []\n    for epoch in range(args.n_epochs):\n        model.train()\n        if epoch >= 3:\n            t0 = time.time()\n        # forward\n        logits = model(features)\n        loss = loss_fcn(logits[train_mask], labels[train_mask])\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if epoch >= 3:\n            dur.append(time.time() - t0)\n\n        acc = evaluate(model, features, labels, val_mask)\n        print(\n            \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n            \"ETputs(KTEPS) {:.2f}\".format(\n                epoch,\n                np.mean(dur),\n                loss.item(),\n                acc,\n                n_edges / np.mean(dur) / 1000,\n            )\n        )\n\n    print()\n    acc = evaluate(model, features, labels, test_mask)\n    print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"TAGCN\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden tagcn units\"\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=1, help=\"number of hidden tagcn layers\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 loss\"\n    )\n    parser.add_argument(\n        \"--self-loop\",\n        action=\"store_true\",\n        help=\"graph self-loop 
(default=False)\",\n    )\n    parser.set_defaults(self_loop=False)\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/tgn/README.md",
    "content": "Temporal Graph Neural Network (TGN)\n===\n\nThe example was temporarily removed due to the change in the `DataLoader`\ninterface in DGL 1.0.  Please refer to the v0.9 example\n[here](https://github.com/dmlc/dgl/tree/0.9.x/examples/pytorch/tgn).\n"
  },
  {
    "path": "examples/pytorch/tree_lstm/README.md",
    "content": "# Tree-LSTM\r\nThis is a re-implementation of the following paper:\r\n\r\n> [**Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks**](http://arxiv.org/abs/1503.00075) \r\n> *Kai Sheng Tai, Richard Socher, and Christopher Manning*. \r\n\r\nThe provided implementation can achieve a test accuracy of 51.72, which is comparable with the result reported in the original paper: 51.0(±0.5).\r\n\r\n## Data\r\nThe script will download the [SST dataset](http://nlp.stanford.edu/sentiment/index.html) automatically, but you need to download the GloVe word vectors yourself. From the command line, you can do so as follows:\r\n```\r\nwget http://nlp.stanford.edu/data/glove.840B.300d.zip\r\nunzip glove.840B.300d.zip\r\n```\r\n\r\n## Dependencies\r\n* PyTorch 0.4.1+\r\n* requests\r\n* nltk\r\n\r\n```\r\npip install torch requests nltk\r\n```\r\n\r\n## Usage\r\n```\r\npython3 train.py --gpu 0\r\n```\r\n\r\n## Speed\r\n\r\nOn an AWS p3.2x instance, it can achieve 3.18s per epoch when setting the batch size to 256.\r\n"
  },
  {
    "path": "examples/pytorch/tree_lstm/train.py",
    "content": "import argparse\nimport collections\nimport time\n\nimport dgl\n\nimport numpy as np\nimport torch as th\nimport torch.nn.functional as F\nimport torch.nn.init as INIT\nimport torch.optim as optim\nfrom dgl.data.tree import SSTDataset\nfrom torch.utils.data import DataLoader\nfrom tree_lstm import TreeLSTM\n\nSSTBatch = collections.namedtuple(\n    \"SSTBatch\", [\"graph\", \"mask\", \"wordid\", \"label\"]\n)\n\n\ndef batcher(device):\n    def batcher_dev(batch):\n        batch_trees = dgl.batch(batch)\n        return SSTBatch(\n            graph=batch_trees,\n            mask=batch_trees.ndata[\"mask\"].to(device),\n            wordid=batch_trees.ndata[\"x\"].to(device),\n            label=batch_trees.ndata[\"y\"].to(device),\n        )\n\n    return batcher_dev\n\n\ndef main(args):\n    np.random.seed(args.seed)\n    th.manual_seed(args.seed)\n    th.cuda.manual_seed(args.seed)\n\n    best_epoch = -1\n    best_dev_acc = 0\n\n    cuda = args.gpu >= 0\n    device = th.device(\"cuda:{}\".format(args.gpu)) if cuda else th.device(\"cpu\")\n    if cuda:\n        th.cuda.set_device(args.gpu)\n\n    trainset = SSTDataset()\n    train_loader = DataLoader(\n        dataset=trainset,\n        batch_size=args.batch_size,\n        collate_fn=batcher(device),\n        shuffle=True,\n        num_workers=0,\n    )\n    devset = SSTDataset(mode=\"dev\")\n    dev_loader = DataLoader(\n        dataset=devset,\n        batch_size=100,\n        collate_fn=batcher(device),\n        shuffle=False,\n        num_workers=0,\n    )\n\n    testset = SSTDataset(mode=\"test\")\n    test_loader = DataLoader(\n        dataset=testset,\n        batch_size=100,\n        collate_fn=batcher(device),\n        shuffle=False,\n        num_workers=0,\n    )\n\n    model = TreeLSTM(\n        trainset.vocab_size,\n        args.x_size,\n        args.h_size,\n        trainset.num_classes,\n        args.dropout,\n        cell_type=\"childsum\" if args.child_sum else \"nary\",\n        
pretrained_emb=trainset.pretrained_emb,\n    ).to(device)\n    print(model)\n    params_ex_emb = [\n        x\n        for x in list(model.parameters())\n        if x.requires_grad and x.size(0) != trainset.vocab_size\n    ]\n    params_emb = list(model.embedding.parameters())\n\n    for p in params_ex_emb:\n        if p.dim() > 1:\n            INIT.xavier_uniform_(p)\n\n    optimizer = optim.Adagrad(\n        [\n            {\n                \"params\": params_ex_emb,\n                \"lr\": args.lr,\n                \"weight_decay\": args.weight_decay,\n            },\n            {\"params\": params_emb, \"lr\": 0.1 * args.lr},\n        ]\n    )\n\n    dur = []\n    for epoch in range(args.epochs):\n        t_epoch = time.time()\n        model.train()\n        for step, batch in enumerate(train_loader):\n            g = batch.graph.to(device)\n            n = g.num_nodes()\n            h = th.zeros((n, args.h_size)).to(device)\n            c = th.zeros((n, args.h_size)).to(device)\n            if step >= 3:\n                t0 = time.time()  # tik\n\n            logits = model(batch, g, h, c)\n            logp = F.log_softmax(logits, 1)\n            loss = F.nll_loss(logp, batch.label, reduction=\"sum\")\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            if step >= 3:\n                dur.append(time.time() - t0)  # tok\n\n            if step > 0 and step % args.log_every == 0:\n                pred = th.argmax(logits, 1)\n                acc = th.sum(th.eq(batch.label, pred))\n                root_ids = [\n                    i for i in range(g.num_nodes()) if g.out_degrees(i) == 0\n                ]\n                root_acc = np.sum(\n                    batch.label.cpu().data.numpy()[root_ids]\n                    == pred.cpu().data.numpy()[root_ids]\n                )\n\n                print(\n                    \"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) 
{:.4f}\".format(\n                        epoch,\n                        step,\n                        loss.item(),\n                        1.0 * acc.item() / len(batch.label),\n                        1.0 * root_acc / len(root_ids),\n                        np.mean(dur),\n                    )\n                )\n        print(\n            \"Epoch {:05d} training time {:.4f}s\".format(\n                epoch, time.time() - t_epoch\n            )\n        )\n\n        # eval on dev set\n        accs = []\n        root_accs = []\n        model.eval()\n        for step, batch in enumerate(dev_loader):\n            g = batch.graph.to(device)\n            n = g.num_nodes()\n            with th.no_grad():\n                h = th.zeros((n, args.h_size)).to(device)\n                c = th.zeros((n, args.h_size)).to(device)\n                logits = model(batch, g, h, c)\n\n            pred = th.argmax(logits, 1)\n            acc = th.sum(th.eq(batch.label, pred)).item()\n            accs.append([acc, len(batch.label)])\n            root_ids = [\n                i for i in range(g.num_nodes()) if g.out_degrees(i) == 0\n            ]\n            root_acc = np.sum(\n                batch.label.cpu().data.numpy()[root_ids]\n                == pred.cpu().data.numpy()[root_ids]\n            )\n            root_accs.append([root_acc, len(root_ids)])\n\n        dev_acc = (\n            1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])\n        )\n        dev_root_acc = (\n            1.0\n            * np.sum([x[0] for x in root_accs])\n            / np.sum([x[1] for x in root_accs])\n        )\n        print(\n            \"Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}\".format(\n                epoch, dev_acc, dev_root_acc\n            )\n        )\n\n        if dev_root_acc > best_dev_acc:\n            best_dev_acc = dev_root_acc\n            best_epoch = epoch\n            th.save(model.state_dict(), \"best_{}.pkl\".format(args.seed))\n        else:\n   
         if best_epoch <= epoch - 10:\n                break\n\n        # lr decay\n        for param_group in optimizer.param_groups:\n            param_group[\"lr\"] = max(1e-5, param_group[\"lr\"] * 0.99)  # 10\n            print(param_group[\"lr\"])\n\n    # test\n    model.load_state_dict(th.load(\"best_{}.pkl\".format(args.seed)))\n    accs = []\n    root_accs = []\n    model.eval()\n    for step, batch in enumerate(test_loader):\n        g = batch.graph.to(device)\n        n = g.num_nodes()\n        with th.no_grad():\n            h = th.zeros((n, args.h_size)).to(device)\n            c = th.zeros((n, args.h_size)).to(device)\n            logits = model(batch, g, h, c)\n\n        pred = th.argmax(logits, 1)\n        acc = th.sum(th.eq(batch.label, pred)).item()\n        accs.append([acc, len(batch.label)])\n        root_ids = [i for i in range(g.num_nodes()) if g.out_degrees(i) == 0]\n        root_acc = np.sum(\n            batch.label.cpu().data.numpy()[root_ids]\n            == pred.cpu().data.numpy()[root_ids]\n        )\n        root_accs.append([root_acc, len(root_ids)])\n\n    test_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])\n    test_root_acc = (\n        1.0\n        * np.sum([x[0] for x in root_accs])\n        / np.sum([x[1] for x in root_accs])\n    )\n    print(\n        \"------------------------------------------------------------------------------------\"\n    )\n    print(\n        \"Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}\".format(\n            best_epoch, test_acc, test_root_acc\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--gpu\", type=int, default=-1)\n    parser.add_argument(\"--seed\", type=int, default=41)\n    parser.add_argument(\"--batch-size\", type=int, default=20)\n    parser.add_argument(\"--child-sum\", action=\"store_true\")\n    parser.add_argument(\"--x-size\", type=int, default=300)\n    
parser.add_argument(\"--h-size\", type=int, default=150)\n    parser.add_argument(\"--epochs\", type=int, default=100)\n    parser.add_argument(\"--log-every\", type=int, default=5)\n    parser.add_argument(\"--lr\", type=float, default=0.05)\n    parser.add_argument(\"--weight-decay\", type=float, default=1e-4)\n    parser.add_argument(\"--dropout\", type=float, default=0.5)\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/pytorch/tree_lstm/tree_lstm.py",
    "content": "\"\"\"\nImproved Semantic Representations From Tree-Structured Long Short-Term Memory Networks\nhttps://arxiv.org/abs/1503.00075\n\"\"\"\nimport itertools\nimport time\n\nimport dgl\nimport networkx as nx\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass TreeLSTMCell(nn.Module):\n    def __init__(self, x_size, h_size):\n        super(TreeLSTMCell, self).__init__()\n        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)\n        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)\n        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))\n        self.U_f = nn.Linear(2 * h_size, 2 * h_size)\n\n    def message_func(self, edges):\n        return {\"h\": edges.src[\"h\"], \"c\": edges.src[\"c\"]}\n\n    def reduce_func(self, nodes):\n        h_cat = nodes.mailbox[\"h\"].view(nodes.mailbox[\"h\"].size(0), -1)\n        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox[\"h\"].size())\n        c = th.sum(f * nodes.mailbox[\"c\"], 1)\n        return {\"iou\": self.U_iou(h_cat), \"c\": c}\n\n    def apply_node_func(self, nodes):\n        iou = nodes.data[\"iou\"] + self.b_iou\n        i, o, u = th.chunk(iou, 3, 1)\n        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)\n        c = i * u + nodes.data[\"c\"]\n        h = o * th.tanh(c)\n        return {\"h\": h, \"c\": c}\n\n\nclass ChildSumTreeLSTMCell(nn.Module):\n    def __init__(self, x_size, h_size):\n        super(ChildSumTreeLSTMCell, self).__init__()\n        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)\n        self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)\n        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))\n        self.U_f = nn.Linear(h_size, h_size)\n\n    def message_func(self, edges):\n        return {\"h\": edges.src[\"h\"], \"c\": edges.src[\"c\"]}\n\n    def reduce_func(self, nodes):\n        h_tild = th.sum(nodes.mailbox[\"h\"], 1)\n        f = th.sigmoid(self.U_f(nodes.mailbox[\"h\"]))\n 
       c = th.sum(f * nodes.mailbox[\"c\"], 1)\n        return {\"iou\": self.U_iou(h_tild), \"c\": c}\n\n    def apply_node_func(self, nodes):\n        iou = nodes.data[\"iou\"] + self.b_iou\n        i, o, u = th.chunk(iou, 3, 1)\n        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)\n        c = i * u + nodes.data[\"c\"]\n        h = o * th.tanh(c)\n        return {\"h\": h, \"c\": c}\n\n\nclass TreeLSTM(nn.Module):\n    def __init__(\n        self,\n        num_vocabs,\n        x_size,\n        h_size,\n        num_classes,\n        dropout,\n        cell_type=\"nary\",\n        pretrained_emb=None,\n    ):\n        super(TreeLSTM, self).__init__()\n        self.x_size = x_size\n        self.embedding = nn.Embedding(num_vocabs, x_size)\n        if pretrained_emb is not None:\n            print(\"Using glove\")\n            self.embedding.weight.data.copy_(pretrained_emb)\n            self.embedding.weight.requires_grad = True\n        self.dropout = nn.Dropout(dropout)\n        self.linear = nn.Linear(h_size, num_classes)\n        cell = TreeLSTMCell if cell_type == \"nary\" else ChildSumTreeLSTMCell\n        self.cell = cell(x_size, h_size)\n\n    def forward(self, batch, g, h, c):\n        \"\"\"Compute tree-lstm prediction given a batch.\n        Parameters\n        ----------\n        batch : dgl.data.SSTBatch\n            The data batch.\n        g : dgl.DGLGraph\n            Tree for computation.\n        h : Tensor\n            Initial hidden state.\n        c : Tensor\n            Initial cell state.\n        Returns\n        -------\n        logits : Tensor\n            The prediction of each node.\n        \"\"\"\n        # feed embedding\n        embeds = self.embedding(batch.wordid * batch.mask)\n        g.ndata[\"iou\"] = self.cell.W_iou(\n            self.dropout(embeds)\n        ) * batch.mask.float().unsqueeze(-1)\n        g.ndata[\"h\"] = h\n        g.ndata[\"c\"] = c\n        # propagate\n        dgl.prop_nodes_topo(\n            g,\n      
      self.cell.message_func,\n            self.cell.reduce_func,\n            apply_node_func=self.cell.apply_node_func,\n        )\n        # compute logits\n        h = self.dropout(g.ndata.pop(\"h\"))\n        logits = self.linear(h)\n        return logits\n"
  },
  {
    "path": "examples/pytorch/vgae/README.md",
    "content": "# Variational Graph Auto-Encoders\n\n- Paper link: https://arxiv.org/abs/1611.07308\n- Author's code repo: https://github.com/tkipf/gae\n\n## Requirements\n\n- PyTorch\n- Python 3.x \n- DGL 0.6\n- scikit-learn\n\n## Run the demo\n\nRun with the following command (available datasets: \"cora\", \"citeseer\", \"pubmed\"):\n\n```\npython train.py\n```\n\n## Dataset\n\nThis example supports two data sources: DGL's built-in datasets (CoraGraphDataset, CiteseerGraphDataset and PubmedGraphDataset) and the Planetoid repository at https://github.com/kimiyoung/planetoid.\n\nYou can specify a dataset as follows:\n\n```\npython train.py --datasrc dgl --dataset cora  // from DGL\npython train.py --datasrc website --dataset cora  // from website\n```\n\n**Note**: If you want to train on a dataset from the website, download the folder https://github.com/kimiyoung/planetoid/tree/master/data and put it under the project folder.\n\n## Results\n\nWe report *area under the ROC curve* (AUC) and *average precision* (AP) scores for each model on the test set. Numbers show mean results and standard error for 10 runs with random initializations on fixed dataset splits.\n\n### Dataset from DGL\n\n| Dataset  | AUC            | AP            |\n| -------- | -------------- | ------------- |\n| Cora     | 91.8$\\pm$ 0.01 | 92.5$\\pm$0.01 |\n| Citeseer | 89.2$\\pm$0.02  | 90.8$\\pm$0.01 |\n| Pubmed   | 94.5$\\pm$0.01  | 94.6$\\pm$0.01 |\n\n### Dataset from website\n\n| Dataset  | AUC            | AP             |\n| -------- | -------------- | -------------- |\n| Cora     | 90.9$\\pm$ 0.01 | 92.1$\\pm$0.01  |\n| Citeseer | 90.3$\\pm$0.01  | 91.8$\\pm$0.01  |\n| Pubmed   | 94.4$\\pm$ 0.01 | 94.6$\\pm$ 0.01 |\n\n### Reported results in paper\n\n| Dataset  | AUC            | AP            |\n| -------- | -------------- | ------------- |\n| Cora     | 91.4$\\pm$ 0.01 | 92.6$\\pm$0.01 |\n| Citeseer | 90.8$\\pm$0.02  | 92.0$\\pm$0.02 |\n| Pubmed   | 94.4$\\pm$0.02  | 94.7$\\pm$0.02 |\n\n"
  },
  {
    "path": "examples/pytorch/vgae/input_data.py",
    "content": "\"\"\"\r\n****************NOTE*****************\r\nCREDITS : Thomas Kipf\r\nsince datasets are the same as those in kipf's implementation, \r\nTheir preprocessing source was used as-is.\r\n*************************************\r\n\"\"\"\r\nimport pickle as pkl\r\nimport sys\r\n\r\nimport networkx as nx\r\nimport numpy as np\r\nimport scipy.sparse as sp\r\n\r\n\r\ndef parse_index_file(filename):\r\n    index = []\r\n    for line in open(filename):\r\n        index.append(int(line.strip()))\r\n    return index\r\n\r\n\r\ndef load_data(dataset):\r\n    # load the data: x, tx, allx, graph\r\n    names = [\"x\", \"tx\", \"allx\", \"graph\"]\r\n    objects = []\r\n    for i in range(len(names)):\r\n        with open(\"data/ind.{}.{}\".format(dataset, names[i]), \"rb\") as f:\r\n            if sys.version_info > (3, 0):\r\n                objects.append(pkl.load(f, encoding=\"latin1\"))\r\n            else:\r\n                objects.append(pkl.load(f))\r\n    x, tx, allx, graph = tuple(objects)\r\n    test_idx_reorder = parse_index_file(\r\n        \"data/ind.{}.test.index\".format(dataset)\r\n    )\r\n    test_idx_range = np.sort(test_idx_reorder)\r\n\r\n    if dataset == \"citeseer\":\r\n        # Fix citeseer dataset (there are some isolated nodes in the graph)\r\n        # Find isolated nodes, add them as zero-vecs into the right position\r\n        test_idx_range_full = range(\r\n            min(test_idx_reorder), max(test_idx_reorder) + 1\r\n        )\r\n        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))\r\n        tx_extended[test_idx_range - min(test_idx_range), :] = tx\r\n        tx = tx_extended\r\n\r\n    features = sp.vstack((allx, tx)).tolil()\r\n    features[test_idx_reorder, :] = features[test_idx_range, :]\r\n    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))\r\n\r\n    return adj, features\r\n"
  },
  {
    "path": "examples/pytorch/vgae/model.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl.nn.pytorch import GraphConv\nfrom train import device\n\n\nclass VGAEModel(nn.Module):\n    def __init__(self, in_dim, hidden1_dim, hidden2_dim):\n        super(VGAEModel, self).__init__()\n        self.in_dim = in_dim\n        self.hidden1_dim = hidden1_dim\n        self.hidden2_dim = hidden2_dim\n\n        layers = [\n            GraphConv(\n                self.in_dim,\n                self.hidden1_dim,\n                activation=F.relu,\n                allow_zero_in_degree=True,\n            ),\n            GraphConv(\n                self.hidden1_dim,\n                self.hidden2_dim,\n                activation=lambda x: x,\n                allow_zero_in_degree=True,\n            ),\n            GraphConv(\n                self.hidden1_dim,\n                self.hidden2_dim,\n                activation=lambda x: x,\n                allow_zero_in_degree=True,\n            ),\n        ]\n        self.layers = nn.ModuleList(layers)\n\n    def encoder(self, g, features):\n        h = self.layers[0](g, features)\n        self.mean = self.layers[1](g, h)\n        self.log_std = self.layers[2](g, h)\n        gaussian_noise = torch.randn(features.size(0), self.hidden2_dim).to(\n            device\n        )\n        sampled_z = self.mean + gaussian_noise * torch.exp(self.log_std).to(\n            device\n        )\n        return sampled_z\n\n    def decoder(self, z):\n        adj_rec = torch.sigmoid(torch.matmul(z, z.t()))\n        return adj_rec\n\n    def forward(self, g, features):\n        z = self.encoder(g, features)\n        adj_rec = self.decoder(z)\n        return adj_rec\n"
  },
  {
    "path": "examples/pytorch/vgae/preprocess.py",
    "content": "import numpy as np\nimport scipy.sparse as sp\nimport torch\n\n\ndef mask_test_edges(adj):\n    # Function to build test set with 10% positive links\n    # NOTE: Splits are randomized and results might slightly deviate from reported numbers in the paper.\n    # TODO: Clean up.\n\n    # Remove diagonal elements\n    adj = adj - sp.dia_matrix(\n        (adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape\n    )\n    adj.eliminate_zeros()\n    # Check that diag is zero:\n    assert np.diag(adj.todense()).sum() == 0\n\n    adj_triu = sp.triu(adj)\n    adj_tuple = sparse_to_tuple(adj_triu)\n    edges = adj_tuple[0]\n    edges_all = sparse_to_tuple(adj)[0]\n    num_test = int(np.floor(edges.shape[0] / 10.0))\n    num_val = int(np.floor(edges.shape[0] / 20.0))\n\n    all_edge_idx = list(range(edges.shape[0]))\n    np.random.shuffle(all_edge_idx)\n    val_edge_idx = all_edge_idx[:num_val]\n    test_edge_idx = all_edge_idx[num_val : (num_val + num_test)]\n    test_edges = edges[test_edge_idx]\n    val_edges = edges[val_edge_idx]\n    train_edges = np.delete(\n        edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0\n    )\n\n    def ismember(a, b, tol=5):\n        rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)\n        return np.any(rows_close)\n\n    test_edges_false = []\n    while len(test_edges_false) < len(test_edges):\n        idx_i = np.random.randint(0, adj.shape[0])\n        idx_j = np.random.randint(0, adj.shape[0])\n        if idx_i == idx_j:\n            continue\n        if ismember([idx_i, idx_j], edges_all):\n            continue\n        if test_edges_false:\n            if ismember([idx_j, idx_i], np.array(test_edges_false)):\n                continue\n            if ismember([idx_i, idx_j], np.array(test_edges_false)):\n                continue\n        test_edges_false.append([idx_i, idx_j])\n\n    val_edges_false = []\n    while len(val_edges_false) < len(val_edges):\n        idx_i = np.random.randint(0, 
adj.shape[0])\n        idx_j = np.random.randint(0, adj.shape[0])\n        if idx_i == idx_j:\n            continue\n        if ismember([idx_i, idx_j], train_edges):\n            continue\n        if ismember([idx_j, idx_i], train_edges):\n            continue\n        if ismember([idx_i, idx_j], val_edges):\n            continue\n        if ismember([idx_j, idx_i], val_edges):\n            continue\n        if val_edges_false:\n            if ismember([idx_j, idx_i], np.array(val_edges_false)):\n                continue\n            if ismember([idx_i, idx_j], np.array(val_edges_false)):\n                continue\n        val_edges_false.append([idx_i, idx_j])\n\n    assert ~ismember(test_edges_false, edges_all)\n    assert ~ismember(val_edges_false, edges_all)\n    assert ~ismember(val_edges, train_edges)\n    assert ~ismember(test_edges, train_edges)\n    assert ~ismember(val_edges, test_edges)\n\n    data = np.ones(train_edges.shape[0])\n\n    # Re-build adj matrix\n    adj_train = sp.csr_matrix(\n        (data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape\n    )\n    adj_train = adj_train + adj_train.T\n\n    # NOTE: these edge lists only contain single direction of edge!\n    return (\n        adj_train,\n        train_edges,\n        val_edges,\n        val_edges_false,\n        test_edges,\n        test_edges_false,\n    )\n\n\ndef mask_test_edges_dgl(graph, adj):\n    src, dst = graph.edges()\n    edges_all = torch.stack([src, dst], dim=0)\n    edges_all = edges_all.t().cpu().numpy()\n    num_test = int(np.floor(edges_all.shape[0] / 10.0))\n    num_val = int(np.floor(edges_all.shape[0] / 20.0))\n\n    all_edge_idx = list(range(edges_all.shape[0]))\n    np.random.shuffle(all_edge_idx)\n    val_edge_idx = all_edge_idx[:num_val]\n    test_edge_idx = all_edge_idx[num_val : (num_val + num_test)]\n    train_edge_idx = all_edge_idx[(num_val + num_test) :]\n    test_edges = edges_all[test_edge_idx]\n    val_edges = edges_all[val_edge_idx]\n    
train_edges = np.delete(\n        edges_all, np.hstack([test_edge_idx, val_edge_idx]), axis=0\n    )\n\n    def ismember(a, b, tol=5):\n        rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)\n        return np.any(rows_close)\n\n    test_edges_false = []\n    while len(test_edges_false) < len(test_edges):\n        idx_i = np.random.randint(0, adj.shape[0])\n        idx_j = np.random.randint(0, adj.shape[0])\n        if idx_i == idx_j:\n            continue\n        if ismember([idx_i, idx_j], edges_all):\n            continue\n        if test_edges_false:\n            if ismember([idx_j, idx_i], np.array(test_edges_false)):\n                continue\n            if ismember([idx_i, idx_j], np.array(test_edges_false)):\n                continue\n        test_edges_false.append([idx_i, idx_j])\n\n    val_edges_false = []\n    while len(val_edges_false) < len(val_edges):\n        idx_i = np.random.randint(0, adj.shape[0])\n        idx_j = np.random.randint(0, adj.shape[0])\n        if idx_i == idx_j:\n            continue\n        if ismember([idx_i, idx_j], train_edges):\n            continue\n        if ismember([idx_j, idx_i], train_edges):\n            continue\n        if ismember([idx_i, idx_j], val_edges):\n            continue\n        if ismember([idx_j, idx_i], val_edges):\n            continue\n        if val_edges_false:\n            if ismember([idx_j, idx_i], np.array(val_edges_false)):\n                continue\n            if ismember([idx_i, idx_j], np.array(val_edges_false)):\n                continue\n        val_edges_false.append([idx_i, idx_j])\n\n    assert ~ismember(test_edges_false, edges_all)\n    assert ~ismember(val_edges_false, edges_all)\n    assert ~ismember(val_edges, train_edges)\n    assert ~ismember(test_edges, train_edges)\n    assert ~ismember(val_edges, test_edges)\n\n    # NOTE: these edge lists only contain single direction of edge!\n    return (\n        train_edge_idx,\n        val_edges,\n        
val_edges_false,\n        test_edges,\n        test_edges_false,\n    )\n\n\ndef sparse_to_tuple(sparse_mx):\n    if not sp.isspmatrix_coo(sparse_mx):\n        sparse_mx = sparse_mx.tocoo()\n    coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()\n    values = sparse_mx.data\n    shape = sparse_mx.shape\n    return coords, values, shape\n\n\ndef preprocess_graph(adj):\n    adj = sp.coo_matrix(adj)\n    adj_ = adj + sp.eye(adj.shape[0])\n    rowsum = np.array(adj_.sum(1))\n    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())\n    adj_normalized = (\n        adj_.dot(degree_mat_inv_sqrt)\n        .transpose()\n        .dot(degree_mat_inv_sqrt)\n        .tocoo()\n    )\n    return adj_normalized, sparse_to_tuple(adj_normalized)\n"
  },
  {
    "path": "examples/pytorch/vgae/train.py",
    "content": "import argparse\nimport os\nimport time\n\nimport dgl\n\nimport model\nimport numpy as np\nimport scipy.sparse as sp\nimport torch\nimport torch.nn.functional as F\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\nfrom input_data import load_data\nfrom preprocess import (\n    mask_test_edges,\n    mask_test_edges_dgl,\n    preprocess_graph,\n    sparse_to_tuple,\n)\nfrom sklearn.metrics import average_precision_score, roc_auc_score\n\nos.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"True\"\n\nparser = argparse.ArgumentParser(description=\"Variant Graph Auto Encoder\")\nparser.add_argument(\n    \"--learning_rate\", type=float, default=0.01, help=\"Initial learning rate.\"\n)\nparser.add_argument(\n    \"--epochs\", \"-e\", type=int, default=200, help=\"Number of epochs to train.\"\n)\nparser.add_argument(\n    \"--hidden1\",\n    \"-h1\",\n    type=int,\n    default=32,\n    help=\"Number of units in hidden layer 1.\",\n)\nparser.add_argument(\n    \"--hidden2\",\n    \"-h2\",\n    type=int,\n    default=16,\n    help=\"Number of units in hidden layer 2.\",\n)\nparser.add_argument(\n    \"--datasrc\",\n    \"-s\",\n    type=str,\n    default=\"dgl\",\n    help=\"Dataset download from dgl Dataset or website.\",\n)\nparser.add_argument(\n    \"--dataset\", \"-d\", type=str, default=\"cora\", help=\"Dataset string.\"\n)\nparser.add_argument(\"--gpu_id\", type=int, default=0, help=\"GPU id to use.\")\nargs = parser.parse_args()\n\n\n# check device\ndevice = torch.device(\n    \"cuda:{}\".format(args.gpu_id) if torch.cuda.is_available() else \"cpu\"\n)\n# device = \"cpu\"\n\n# roc_means = []\n# ap_means = []\n\n\ndef compute_loss_para(adj):\n    pos_weight = (adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()\n    norm = (\n        adj.shape[0]\n        * adj.shape[0]\n        / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)\n    )\n    weight_mask = adj.view(-1) == 1\n    weight_tensor = 
torch.ones(weight_mask.size(0)).to(device)\n    weight_tensor[weight_mask] = pos_weight\n    return weight_tensor, norm\n\n\ndef get_acc(adj_rec, adj_label):\n    labels_all = adj_label.view(-1).long()\n    preds_all = (adj_rec > 0.5).view(-1).long()\n    accuracy = (preds_all == labels_all).sum().float() / labels_all.size(0)\n    return accuracy\n\n\ndef get_scores(edges_pos, edges_neg, adj_rec):\n    def sigmoid(x):\n        return 1 / (1 + np.exp(-x))\n\n    adj_rec = adj_rec.cpu()\n    # Predict on test set of edges\n    preds = []\n    for e in edges_pos:\n        preds.append(sigmoid(adj_rec[e[0], e[1]].item()))\n\n    preds_neg = []\n    for e in edges_neg:\n        preds_neg.append(sigmoid(adj_rec[e[0], e[1]].data))\n\n    preds_all = np.hstack([preds, preds_neg])\n    labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])\n    roc_score = roc_auc_score(labels_all, preds_all)\n    ap_score = average_precision_score(labels_all, preds_all)\n\n    return roc_score, ap_score\n\n\ndef dgl_main():\n    # Load from DGL dataset\n    if args.dataset == \"cora\":\n        dataset = CoraGraphDataset(reverse_edge=False)\n    elif args.dataset == \"citeseer\":\n        dataset = CiteseerGraphDataset(reverse_edge=False)\n    elif args.dataset == \"pubmed\":\n        dataset = PubmedGraphDataset(reverse_edge=False)\n    else:\n        raise NotImplementedError\n    graph = dataset[0]\n\n    # Extract node features\n    feats = graph.ndata.pop(\"feat\").to(device)\n    in_dim = feats.shape[-1]\n\n    # generate input\n    adj_orig = graph.adj_external().to_dense()\n\n    # build test set with 10% positive links\n    (\n        train_edge_idx,\n        val_edges,\n        val_edges_false,\n        test_edges,\n        test_edges_false,\n    ) = mask_test_edges_dgl(graph, adj_orig)\n\n    graph = graph.to(device)\n\n    # create train graph\n    train_edge_idx = torch.tensor(train_edge_idx).to(device)\n    train_graph = dgl.edge_subgraph(graph, 
train_edge_idx, relabel_nodes=False)\n    train_graph = train_graph.to(device)\n    adj = train_graph.adj_external().to_dense().to(device)\n\n    # compute loss parameters\n    weight_tensor, norm = compute_loss_para(adj)\n\n    # create model\n    vgae_model = model.VGAEModel(in_dim, args.hidden1, args.hidden2)\n    vgae_model = vgae_model.to(device)\n\n    # create training component\n    optimizer = torch.optim.Adam(vgae_model.parameters(), lr=args.learning_rate)\n    print(\n        \"Total Parameters:\",\n        sum([p.nelement() for p in vgae_model.parameters()]),\n    )\n\n    # create training epoch\n    for epoch in range(args.epochs):\n        t = time.time()\n\n        # Training and validation using a full graph\n        vgae_model.train()\n\n        logits = vgae_model.forward(graph, feats)\n\n        # compute loss\n        loss = norm * F.binary_cross_entropy(\n            logits.view(-1), adj.view(-1), weight=weight_tensor\n        )\n        kl_divergence = (\n            0.5\n            / logits.size(0)\n            * (\n                1\n                + 2 * vgae_model.log_std\n                - vgae_model.mean**2\n                - torch.exp(vgae_model.log_std) ** 2\n            )\n            .sum(1)\n            .mean()\n        )\n        loss -= kl_divergence\n\n        # backward\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        train_acc = get_acc(logits, adj)\n\n        val_roc, val_ap = get_scores(val_edges, val_edges_false, logits)\n\n        # Print out performance\n        print(\n            \"Epoch:\",\n            \"%04d\" % (epoch + 1),\n            \"train_loss=\",\n            \"{:.5f}\".format(loss.item()),\n            \"train_acc=\",\n            \"{:.5f}\".format(train_acc),\n            \"val_roc=\",\n            \"{:.5f}\".format(val_roc),\n            \"val_ap=\",\n            \"{:.5f}\".format(val_ap),\n            \"time=\",\n            \"{:.5f}\".format(time.time() - 
t),\n        )\n\n    test_roc, test_ap = get_scores(test_edges, test_edges_false, logits)\n    # roc_means.append(test_roc)\n    # ap_means.append(test_ap)\n    print(\n        \"End of training!\",\n        \"test_roc=\",\n        \"{:.5f}\".format(test_roc),\n        \"test_ap=\",\n        \"{:.5f}\".format(test_ap),\n    )\n\n\ndef web_main():\n    adj, features = load_data(args.dataset)\n\n    features = sparse_to_tuple(features.tocoo())\n\n    # Store original adjacency matrix (without diagonal entries) for later\n    adj_orig = adj\n    adj_orig = adj_orig - sp.dia_matrix(\n        (adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape\n    )\n    adj_orig.eliminate_zeros()\n\n    (\n        adj_train,\n        train_edges,\n        val_edges,\n        val_edges_false,\n        test_edges,\n        test_edges_false,\n    ) = mask_test_edges(adj)\n    adj = adj_train\n\n    # # Create model\n    # graph = dgl.from_scipy(adj)\n    # graph.add_self_loop()\n\n    # Some preprocessing\n    adj_normalization, adj_norm = preprocess_graph(adj)\n\n    # Create model\n    graph = dgl.from_scipy(adj_normalization)\n    graph.add_self_loop()\n\n    # Create Model\n    pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()\n    norm = (\n        adj.shape[0]\n        * adj.shape[0]\n        / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)\n    )\n\n    adj_label = adj_train + sp.eye(adj_train.shape[0])\n    adj_label = sparse_to_tuple(adj_label)\n\n    adj_norm = torch.sparse.FloatTensor(\n        torch.LongTensor(adj_norm[0].T),\n        torch.FloatTensor(adj_norm[1]),\n        torch.Size(adj_norm[2]),\n    )\n    adj_label = torch.sparse.FloatTensor(\n        torch.LongTensor(adj_label[0].T),\n        torch.FloatTensor(adj_label[1]),\n        torch.Size(adj_label[2]),\n    )\n    features = torch.sparse.FloatTensor(\n        torch.LongTensor(features[0].T),\n        torch.FloatTensor(features[1]),\n        torch.Size(features[2]),\n    
)\n\n    weight_mask = adj_label.to_dense().view(-1) == 1\n    weight_tensor = torch.ones(weight_mask.size(0))\n    weight_tensor[weight_mask] = pos_weight\n\n    features = features.to_dense()\n    in_dim = features.shape[-1]\n\n    vgae_model = model.VGAEModel(in_dim, args.hidden1, args.hidden2)\n    # create training component\n    optimizer = torch.optim.Adam(vgae_model.parameters(), lr=args.learning_rate)\n    print(\n        \"Total Parameters:\",\n        sum([p.nelement() for p in vgae_model.parameters()]),\n    )\n\n    def get_scores(edges_pos, edges_neg, adj_rec):\n        def sigmoid(x):\n            return 1 / (1 + np.exp(-x))\n\n        # Predict on test set of edges\n        preds = []\n        pos = []\n        for e in edges_pos:\n            # print(e)\n            # print(adj_rec[e[0], e[1]])\n            preds.append(sigmoid(adj_rec[e[0], e[1]].item()))\n            pos.append(adj_orig[e[0], e[1]])\n\n        preds_neg = []\n        neg = []\n        for e in edges_neg:\n            preds_neg.append(sigmoid(adj_rec[e[0], e[1]].data))\n            neg.append(adj_orig[e[0], e[1]])\n\n        preds_all = np.hstack([preds, preds_neg])\n        labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])\n        roc_score = roc_auc_score(labels_all, preds_all)\n        ap_score = average_precision_score(labels_all, preds_all)\n\n        return roc_score, ap_score\n\n    def get_acc(adj_rec, adj_label):\n        labels_all = adj_label.to_dense().view(-1).long()\n        preds_all = (adj_rec > 0.5).view(-1).long()\n        accuracy = (preds_all == labels_all).sum().float() / labels_all.size(0)\n        return accuracy\n\n    # create training epoch\n    for epoch in range(args.epochs):\n        t = time.time()\n\n        # Training and validation using a full graph\n        vgae_model.train()\n\n        logits = vgae_model.forward(graph, features)\n\n        # compute loss\n        loss = norm * F.binary_cross_entropy(\n            
logits.view(-1), adj_label.to_dense().view(-1), weight=weight_tensor\n        )\n        kl_divergence = (\n            0.5\n            / logits.size(0)\n            * (\n                1\n                + 2 * vgae_model.log_std\n                - vgae_model.mean**2\n                - torch.exp(vgae_model.log_std) ** 2\n            )\n            .sum(1)\n            .mean()\n        )\n        loss -= kl_divergence\n\n        # backward\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        train_acc = get_acc(logits, adj_label)\n\n        val_roc, val_ap = get_scores(val_edges, val_edges_false, logits)\n\n        # Print out performance\n        print(\n            \"Epoch:\",\n            \"%04d\" % (epoch + 1),\n            \"train_loss=\",\n            \"{:.5f}\".format(loss.item()),\n            \"train_acc=\",\n            \"{:.5f}\".format(train_acc),\n            \"val_roc=\",\n            \"{:.5f}\".format(val_roc),\n            \"val_ap=\",\n            \"{:.5f}\".format(val_ap),\n            \"time=\",\n            \"{:.5f}\".format(time.time() - t),\n        )\n\n    test_roc, test_ap = get_scores(test_edges, test_edges_false, logits)\n    print(\n        \"End of training!\",\n        \"test_roc=\",\n        \"{:.5f}\".format(test_roc),\n        \"test_ap=\",\n        \"{:.5f}\".format(test_ap),\n    )\n    # roc_means.append(test_roc)\n    # ap_means.append(test_ap)\n\n\n# if __name__ == '__main__':\n#     for i in range(10):\n#         web_main()\n#\n#     roc_mean = np.mean(roc_means)\n#     roc_std = np.std(roc_means, ddof=1)\n#     ap_mean = np.mean(ap_means)\n#     ap_std = np.std(ap_means, ddof=1)\n#     print(\"roc_mean=\", \"{:.5f}\".format(roc_mean), \"roc_std=\", \"{:.5f}\".format(roc_std), \"ap_mean=\",\n#           \"{:.5f}\".format(ap_mean), \"ap_std=\", \"{:.5f}\".format(ap_std))\n\nif __name__ == \"__main__\":\n    if args.datasrc == \"dgl\":\n        dgl_main()\n    elif args.datasrc == 
\"website\":\n        web_main()\n"
  },
  {
    "path": "examples/pytorch/vrgcn/README.md",
    "content": "VRGCN (control variate sampling)\n================================\n\nPaper: https://arxiv.org/abs/1710.10568\n\nRun with\n\n```bash\npython3 train_cv.py --num-epochs 30\npython3 train_cv_multi_gpu.py --num-epochs 30 --gpu 0,1,2,3  # multi-GPU\n```\n"
  },
  {
    "path": "examples/pytorch/vrgcn/train_cv.py",
    "content": "import argparse\nimport time\n\nimport dgl\nimport dgl.function as fn\nimport dgl.nn.pytorch as dglnn\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom dgl.data import RedditDataset\nfrom torch.utils.data import DataLoader\n\n\nclass SAGEConvWithCV(nn.Module):\n    def __init__(self, in_feats, out_feats, activation):\n        super().__init__()\n        self.W = nn.Linear(in_feats * 2, out_feats)\n        self.activation = activation\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_uniform_(self.W.weight, gain=gain)\n        nn.init.constant_(self.W.bias, 0)\n\n    def forward(self, block, H, HBar=None):\n        if self.training:\n            with block.local_scope():\n                H_src, H_dst = H\n                HBar_src, agg_HBar_dst = HBar\n                block.dstdata[\"agg_hbar\"] = agg_HBar_dst\n                block.srcdata[\"hdelta\"] = H_src - HBar_src\n                block.update_all(\n                    fn.copy_u(\"hdelta\", \"m\"), fn.mean(\"m\", \"hdelta_new\")\n                )\n                h_neigh = (\n                    block.dstdata[\"agg_hbar\"] + block.dstdata[\"hdelta_new\"]\n                )\n                h = self.W(th.cat([H_dst, h_neigh], 1))\n                if self.activation is not None:\n                    h = self.activation(h)\n                return h\n        else:\n            with block.local_scope():\n                H_src, H_dst = H\n                block.srcdata[\"h\"] = H_src\n                block.update_all(fn.copy_u(\"h\", \"m\"), fn.mean(\"m\", \"h_new\"))\n                h_neigh = block.dstdata[\"h_new\"]\n                h = self.W(th.cat([H_dst, h_neigh], 1))\n                if self.activation is not None:\n                    h = self.activation(h)\n                return h\n\n\nclass 
SAGE(nn.Module):\n    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(SAGEConvWithCV(in_feats, n_hidden, activation))\n        for i in range(1, n_layers - 1):\n            self.layers.append(SAGEConvWithCV(n_hidden, n_hidden, activation))\n        self.layers.append(SAGEConvWithCV(n_hidden, n_classes, None))\n\n    def forward(self, blocks):\n        h = blocks[0].srcdata[\"features\"]\n        updates = []\n        for layer, block in zip(self.layers, blocks):\n            # We need to first copy the representation of nodes on the RHS from the\n            # appropriate nodes on the LHS.\n            # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst\n            # would be (num_nodes_RHS, D)\n            h_dst = h[: block.number_of_dst_nodes()]\n            hbar_src = block.srcdata[\"hist\"]\n            agg_hbar_dst = block.dstdata[\"agg_hist\"]\n            # Then we compute the updated representation on the RHS.\n            # The shape of h now becomes (num_nodes_RHS, D)\n            h = layer(block, (h, h_dst), (hbar_src, agg_hbar_dst))\n            block.dstdata[\"h_new\"] = h\n        return h\n\n    def inference(self, g, x, batch_size, device):\n        \"\"\"\n        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).\n        g : the entire graph.\n        x : the input of entire node set.\n\n        The inference code is written in a fashion that it could handle any number of nodes and\n        layers.\n        \"\"\"\n        # During inference with sampling, multi-layer blocks are very inefficient because\n        # lots of computations in the first few layers are repeated.\n        # Therefore, we compute the representation of all nodes layer by layer.  
The nodes\n        # on each layer are of course splitted in batches.\n        # TODO: can we standardize this?\n        nodes = th.arange(g.num_nodes())\n        ys = []\n        for l, layer in enumerate(self.layers):\n            y = th.zeros(\n                g.num_nodes(),\n                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,\n            )\n\n            for start in tqdm.trange(0, len(nodes), batch_size):\n                end = start + batch_size\n                batch_nodes = nodes[start:end]\n                block = dgl.to_block(\n                    dgl.in_subgraph(g, batch_nodes), batch_nodes\n                )\n                block = block.int().to(device)\n                induced_nodes = block.srcdata[dgl.NID]\n\n                h = x[induced_nodes].to(device)\n                h_dst = h[: block.number_of_dst_nodes()]\n                h = layer(block, (h, h_dst))\n\n                y[start:end] = h.cpu()\n\n            ys.append(y)\n            x = y\n        return y, ys\n\n\nclass NeighborSampler(object):\n    def __init__(self, g, fanouts):\n        self.g = g\n        self.fanouts = fanouts\n\n    def sample_blocks(self, seeds):\n        seeds = th.LongTensor(seeds)\n        blocks = []\n        hist_blocks = []\n        for fanout in self.fanouts:\n            # For each seed node, sample ``fanout`` neighbors.\n            frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout)\n            hist_frontier = dgl.in_subgraph(self.g, seeds)\n            # Then we compact the frontier into a bipartite graph for message passing.\n            block = dgl.to_block(frontier, seeds)\n            hist_block = dgl.to_block(hist_frontier, seeds)\n            # Obtain the seed nodes for next layer.\n            seeds = block.srcdata[dgl.NID]\n\n            blocks.insert(0, block)\n            hist_blocks.insert(0, hist_block)\n        return blocks, hist_blocks\n\n\ndef compute_acc(pred, labels):\n    \"\"\"\n    Compute the 
accuracy of prediction given the labels.\n    \"\"\"\n    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)\n\n\ndef evaluate(model, g, labels, val_mask, batch_size, device):\n    \"\"\"\n    Evaluate the model on the validation set specified by ``val_mask``.\n    g : The entire graph.\n    inputs : The features of all the nodes.\n    labels : The labels of all the nodes.\n    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.\n    batch_size : Number of nodes to compute at the same time.\n    device : The GPU device to evaluate on.\n    \"\"\"\n    model.eval()\n    with th.no_grad():\n        inputs = g.ndata[\"features\"]\n        pred, _ = model.inference(g, inputs, batch_size, device)\n    model.train()\n    return compute_acc(pred[val_mask], labels[val_mask])\n\n\ndef load_subtensor(\n    g, labels, blocks, hist_blocks, dev_id, aggregation_on_device=False\n):\n    \"\"\"\n    Copys features and labels of a set of nodes onto GPU.\n    \"\"\"\n    blocks[0].srcdata[\"features\"] = g.ndata[\"features\"][\n        blocks[0].srcdata[dgl.NID]\n    ]\n    blocks[-1].dstdata[\"label\"] = labels[blocks[-1].dstdata[dgl.NID]]\n    ret_blocks = []\n    ret_hist_blocks = []\n    for i, (block, hist_block) in enumerate(zip(blocks, hist_blocks)):\n        hist_col = \"features\" if i == 0 else \"hist_%d\" % i\n        block.srcdata[\"hist\"] = g.ndata[hist_col][block.srcdata[dgl.NID]]\n\n        # Aggregate history\n        hist_block.srcdata[\"hist\"] = g.ndata[hist_col][\n            hist_block.srcdata[dgl.NID]\n        ]\n        if aggregation_on_device:\n            hist_block = hist_block.to(dev_id)\n        hist_block.update_all(fn.copy_u(\"hist\", \"m\"), fn.mean(\"m\", \"agg_hist\"))\n\n        block = block.int().to(dev_id)\n        if not aggregation_on_device:\n            hist_block = hist_block.to(dev_id)\n        block.dstdata[\"agg_hist\"] = hist_block.dstdata[\"agg_hist\"]\n        
ret_blocks.append(block)\n        ret_hist_blocks.append(hist_block)\n    return ret_blocks, ret_hist_blocks\n\n\ndef init_history(g, model, dev_id):\n    with th.no_grad():\n        history = model.inference(g, g.ndata[\"features\"], 1000, dev_id)[1]\n        for layer in range(args.num_layers + 1):\n            if layer > 0:\n                hist_col = \"hist_%d\" % layer\n                g.ndata[\"hist_%d\" % layer] = history[layer - 1]\n\n\ndef update_history(g, blocks):\n    with th.no_grad():\n        for i, block in enumerate(blocks):\n            ids = block.dstdata[dgl.NID].cpu()\n            hist_col = \"hist_%d\" % (i + 1)\n\n            h_new = block.dstdata[\"h_new\"].cpu()\n            g.ndata[hist_col][ids] = h_new\n\n\ndef run(args, dev_id, data):\n    dropout = 0.2\n\n    th.cuda.set_device(dev_id)\n\n    # Unpack data\n    train_mask, val_mask, in_feats, labels, n_classes, g = data\n    train_nid = train_mask.nonzero().squeeze()\n    val_nid = val_mask.nonzero().squeeze()\n\n    # Create sampler\n    sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(\",\")])\n\n    # Create PyTorch DataLoader for constructing blocks\n    dataloader = DataLoader(\n        dataset=train_nid.numpy(),\n        batch_size=args.batch_size,\n        collate_fn=sampler.sample_blocks,\n        shuffle=True,\n        drop_last=False,\n        num_workers=args.num_workers_per_gpu,\n    )\n\n    # Define model\n    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu)\n\n    # Move the model to GPU and define optimizer\n    model = model.to(dev_id)\n    loss_fcn = nn.CrossEntropyLoss()\n    loss_fcn = loss_fcn.to(dev_id)\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n    # Compute history tensor and their aggregation before training on CPU\n    model.eval()\n    init_history(g, model, dev_id)\n    model.train()\n\n    # Training loop\n    avg = 0\n    iter_tput = []\n    for epoch in range(args.num_epochs):\n        tic = 
time.time()\n        model.train()\n        tic_step = time.time()\n        for step, (blocks, hist_blocks) in enumerate(dataloader):\n            # The nodes for input lies at the LHS side of the first block.\n            # The nodes for output lies at the RHS side of the last block.\n            input_nodes = blocks[0].srcdata[dgl.NID]\n            seeds = blocks[-1].dstdata[dgl.NID]\n\n            blocks, hist_blocks = load_subtensor(\n                g, labels, blocks, hist_blocks, dev_id, True\n            )\n\n            # forward\n            batch_pred = model(blocks)\n            # update history\n            update_history(g, blocks)\n            # compute loss\n            batch_labels = blocks[-1].dstdata[\"label\"]\n            loss = loss_fcn(batch_pred, batch_labels)\n            # backward\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            iter_tput.append(len(seeds) / (time.time() - tic_step))\n            if step % args.log_every == 0:\n                acc = compute_acc(batch_pred, batch_labels)\n                print(\n                    \"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}\".format(\n                        epoch,\n                        step,\n                        loss.item(),\n                        acc.item(),\n                        np.mean(iter_tput[3:]),\n                    )\n                )\n            tic_step = time.time()\n        toc = time.time()\n        print(\"Epoch Time(s): {:.4f}\".format(toc - tic))\n        if epoch >= 5:\n            avg += toc - tic\n        if epoch % args.eval_every == 0 and epoch != 0:\n            model.eval()\n            eval_acc = evaluate(\n                model, g, labels, val_nid, args.val_batch_size, dev_id\n            )\n            print(\"Eval Acc {:.4f}\".format(eval_acc))\n\n    print(\"Avg epoch time: {}\".format(avg / (epoch - 4)))\n\n\nif __name__ == \"__main__\":\n   
 argparser = argparse.ArgumentParser(\"multi-gpu training\")\n    argparser.add_argument(\"--gpu\", type=str, default=\"0\")\n    argparser.add_argument(\"--num-epochs\", type=int, default=20)\n    argparser.add_argument(\"--num-hidden\", type=int, default=16)\n    argparser.add_argument(\"--num-layers\", type=int, default=2)\n    argparser.add_argument(\"--fan-out\", type=str, default=\"1,1\")\n    argparser.add_argument(\"--batch-size\", type=int, default=1000)\n    argparser.add_argument(\"--val-batch-size\", type=int, default=1000)\n    argparser.add_argument(\"--log-every\", type=int, default=20)\n    argparser.add_argument(\"--eval-every\", type=int, default=5)\n    argparser.add_argument(\"--lr\", type=float, default=0.003)\n    argparser.add_argument(\"--num-workers-per-gpu\", type=int, default=0)\n    args = argparser.parse_args()\n\n    # load reddit data\n    data = RedditDataset(self_loop=True)\n    n_classes = data.num_classes\n    g = data[0]\n    features = g.ndata[\"feat\"]\n    in_feats = features.shape[1]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    g.ndata[\"features\"] = features\n    g.create_formats_()\n    # Pack data\n    data = train_mask, val_mask, in_feats, labels, n_classes, g\n\n    run(args, int(args.gpu), data)\n"
  },
  {
    "path": "examples/pytorch/vrgcn/train_cv_multi_gpu.py",
    "content": "import argparse\nimport math\nimport time\nimport traceback\n\nimport dgl\nimport dgl.function as fn\nimport dgl.nn.pytorch as dglnn\n\nimport numpy as np\nimport torch as th\nimport torch.multiprocessing as mp\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nfrom dgl.data import RedditDataset\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.utils.data import DataLoader\n\n\nclass SAGEConvWithCV(nn.Module):\n    def __init__(self, in_feats, out_feats, activation):\n        super().__init__()\n        self.W = nn.Linear(in_feats * 2, out_feats)\n        self.activation = activation\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_uniform_(self.W.weight, gain=gain)\n        nn.init.constant_(self.W.bias, 0)\n\n    def forward(self, block, H, HBar=None):\n        if self.training:\n            with block.local_scope():\n                H_src, H_dst = H\n                HBar_src, agg_HBar_dst = HBar\n                block.dstdata[\"agg_hbar\"] = agg_HBar_dst\n                block.srcdata[\"hdelta\"] = H_src - HBar_src\n                block.update_all(\n                    fn.copy_u(\"hdelta\", \"m\"), fn.mean(\"m\", \"hdelta_new\")\n                )\n                h_neigh = (\n                    block.dstdata[\"agg_hbar\"] + block.dstdata[\"hdelta_new\"]\n                )\n                h = self.W(th.cat([H_dst, h_neigh], 1))\n                if self.activation is not None:\n                    h = self.activation(h)\n                return h\n        else:\n            with block.local_scope():\n                H_src, H_dst = H\n                block.srcdata[\"h\"] = H_src\n                block.update_all(fn.copy_u(\"h\", \"m\"), fn.mean(\"m\", \"h_new\"))\n                h_neigh = block.dstdata[\"h_new\"]\n                h = self.W(th.cat([H_dst, h_neigh], 1))\n                if 
self.activation is not None:\n                    h = self.activation(h)\n                return h\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation):\n        super().__init__()\n        self.n_layers = n_layers\n        self.n_hidden = n_hidden\n        self.n_classes = n_classes\n        self.layers = nn.ModuleList()\n        self.layers.append(SAGEConvWithCV(in_feats, n_hidden, activation))\n        for i in range(1, n_layers - 1):\n            self.layers.append(SAGEConvWithCV(n_hidden, n_hidden, activation))\n        self.layers.append(SAGEConvWithCV(n_hidden, n_classes, None))\n\n    def forward(self, blocks):\n        h = blocks[0].srcdata[\"features\"]\n        updates = []\n        for layer, block in zip(self.layers, blocks):\n            # We need to first copy the representation of nodes on the RHS from the\n            # appropriate nodes on the LHS.\n            # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst\n            # would be (num_nodes_RHS, D)\n            h_dst = h[: block.number_of_dst_nodes()]\n            hbar_src = block.srcdata[\"hist\"]\n            agg_hbar_dst = block.dstdata[\"agg_hist\"]\n            # Then we compute the updated representation on the RHS.\n            # The shape of h now becomes (num_nodes_RHS, D)\n            h = layer(block, (h, h_dst), (hbar_src, agg_hbar_dst))\n            block.dstdata[\"h_new\"] = h\n        return h\n\n    def inference(self, g, x, batch_size, device):\n        \"\"\"\n        Inference with the GraphSAGE model on full neighbors (i.e. 
without neighbor sampling).\n        g : the entire graph.\n        x : the input of entire node set.\n\n        The inference code is written in a fashion that it could handle any number of nodes and\n        layers.\n        \"\"\"\n        # During inference with sampling, multi-layer blocks are very inefficient because\n        # lots of computations in the first few layers are repeated.\n        # Therefore, we compute the representation of all nodes layer by layer.  The nodes\n        # on each layer are of course splitted in batches.\n        # TODO: can we standardize this?\n        nodes = th.arange(g.num_nodes())\n        for l, layer in enumerate(self.layers):\n            y = g.ndata[\"hist_%d\" % (l + 1)]\n\n            for start in tqdm.trange(0, len(nodes), batch_size):\n                end = start + batch_size\n                batch_nodes = nodes[start:end]\n                block = dgl.to_block(\n                    dgl.in_subgraph(g, batch_nodes), batch_nodes\n                )\n                induced_nodes = block.srcdata[dgl.NID]\n\n                h = x[induced_nodes].to(device)\n                block = block.to(device)\n                h_dst = h[: block.number_of_dst_nodes()]\n                h = layer(block, (h, h_dst))\n\n                y[start:end] = h.cpu()\n\n            x = y\n        return y\n\n\nclass NeighborSampler(object):\n    def __init__(self, g, fanouts):\n        self.g = g\n        self.fanouts = fanouts\n\n    def sample_blocks(self, seeds):\n        seeds = th.LongTensor(seeds)\n        blocks = []\n        hist_blocks = []\n        for fanout in self.fanouts:\n            # For each seed node, sample ``fanout`` neighbors.\n            frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout)\n            # For history aggregation we sample all neighbors.\n            hist_frontier = dgl.in_subgraph(self.g, seeds)\n            # Then we compact the frontier into a bipartite graph for message passing.\n            
block = dgl.to_block(frontier, seeds)\n            hist_block = dgl.to_block(hist_frontier, seeds)\n            # Obtain the seed nodes for next layer.\n            seeds = block.srcdata[dgl.NID]\n\n            blocks.insert(0, block)\n            hist_blocks.insert(0, hist_block)\n        return blocks, hist_blocks\n\n\ndef compute_acc(pred, labels):\n    \"\"\"\n    Compute the accuracy of prediction given the labels.\n    \"\"\"\n    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)\n\n\ndef evaluate(model, g, labels, val_mask, batch_size, device):\n    \"\"\"\n    Evaluate the model on the validation set specified by ``val_mask``.\n    g : The entire graph.\n    inputs : The features of all the nodes.\n    labels : The labels of all the nodes.\n    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.\n    batch_size : Number of nodes to compute at the same time.\n    device : The GPU device to evaluate on.\n    \"\"\"\n    model.eval()\n    with th.no_grad():\n        inputs = g.ndata[\"features\"]\n        pred = model.inference(\n            g, inputs, batch_size, device\n        )  # also recomputes history tensors\n    model.train()\n    return compute_acc(pred[val_mask], labels[val_mask])\n\n\ndef load_subtensor(\n    g, labels, blocks, hist_blocks, dev_id, aggregation_on_device=False\n):\n    \"\"\"\n    Copys features and labels of a set of nodes onto GPU.\n    \"\"\"\n    blocks[0].srcdata[\"features\"] = g.ndata[\"features\"][\n        blocks[0].srcdata[dgl.NID]\n    ]\n    blocks[-1].dstdata[\"label\"] = labels[blocks[-1].dstdata[dgl.NID]]\n    ret_blocks = []\n    ret_hist_blocks = []\n    for i, (block, hist_block) in enumerate(zip(blocks, hist_blocks)):\n        hist_col = \"features\" if i == 0 else \"hist_%d\" % i\n        block.srcdata[\"hist\"] = g.ndata[hist_col][block.srcdata[dgl.NID]]\n\n        # Aggregate history\n        hist_block.srcdata[\"hist\"] = g.ndata[hist_col][\n            
hist_block.srcdata[dgl.NID]\n        ]\n        if aggregation_on_device:\n            hist_block = hist_block.to(dev_id)\n            hist_block.srcdata[\"hist\"] = hist_block.srcdata[\"hist\"]\n        hist_block.update_all(fn.copy_u(\"hist\", \"m\"), fn.mean(\"m\", \"agg_hist\"))\n\n        block = block.to(dev_id)\n        if not aggregation_on_device:\n            hist_block = hist_block.to(dev_id)\n        block.dstdata[\"agg_hist\"] = hist_block.dstdata[\"agg_hist\"]\n        ret_blocks.append(block)\n        ret_hist_blocks.append(hist_block)\n    return ret_blocks, ret_hist_blocks\n\n\ndef create_history_storage(g, args, n_classes):\n    # Initialize history storage\n    for l in range(args.num_layers):\n        dim = args.num_hidden if l != args.num_layers - 1 else n_classes\n        g.ndata[\"hist_%d\" % (l + 1)] = th.zeros(\n            g.num_nodes(), dim\n        ).share_memory_()\n\n\ndef init_history(g, model, dev_id, batch_size):\n    with th.no_grad():\n        model.inference(\n            g, g.ndata[\"features\"], batch_size, dev_id\n        )  # replaces hist_i features in-place\n\n\ndef update_history(g, blocks):\n    with th.no_grad():\n        for i, block in enumerate(blocks):\n            ids = block.dstdata[dgl.NID].cpu()\n            hist_col = \"hist_%d\" % (i + 1)\n\n            h_new = block.dstdata[\"h_new\"].cpu()\n            g.ndata[hist_col][ids] = h_new\n\n\ndef run(proc_id, n_gpus, args, devices, data):\n    dropout = 0.2\n\n    dev_id = devices[proc_id]\n    if n_gpus > 1:\n        dist_init_method = \"tcp://{master_ip}:{master_port}\".format(\n            master_ip=\"127.0.0.1\", master_port=\"12345\"\n        )\n        world_size = n_gpus\n        th.distributed.init_process_group(\n            backend=\"nccl\",\n            init_method=dist_init_method,\n            world_size=world_size,\n            rank=proc_id,\n        )\n    th.cuda.set_device(dev_id)\n\n    # Unpack data\n    train_mask, val_mask, in_feats, labels, 
n_classes, g = data\n    train_nid = train_mask.nonzero().squeeze()\n    val_nid = val_mask.nonzero().squeeze()\n\n    # Create sampler\n    sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(\",\")])\n\n    # Create PyTorch DataLoader for constructing blocks\n    if n_gpus > 1:\n        dist_sampler = th.utils.data.distributed.DistributedSampler(\n            train_nid.numpy(), shuffle=True, drop_last=False\n        )\n        dataloader = DataLoader(\n            dataset=train_nid.numpy(),\n            batch_size=args.batch_size,\n            collate_fn=sampler.sample_blocks,\n            sampler=dist_sampler,\n            num_workers=args.num_workers_per_gpu,\n        )\n    else:\n        dataloader = DataLoader(\n            dataset=train_nid.numpy(),\n            batch_size=args.batch_size,\n            collate_fn=sampler.sample_blocks,\n            shuffle=True,\n            drop_last=False,\n            num_workers=args.num_workers_per_gpu,\n        )\n\n    # Define model\n    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu)\n\n    # Move the model to GPU and define optimizer\n    model = model.to(dev_id)\n    if n_gpus > 1:\n        model = DistributedDataParallel(\n            model, device_ids=[dev_id], output_device=dev_id\n        )\n    loss_fcn = nn.CrossEntropyLoss()\n    loss_fcn = loss_fcn.to(dev_id)\n    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n\n    # Compute history tensor and their aggregation before training on CPU\n    model.eval()\n    if n_gpus > 1:\n        if proc_id == 0:\n            init_history(g, model.module, dev_id, args.val_batch_size)\n        th.distributed.barrier()\n    else:\n        init_history(g, model, dev_id, args.val_batch_size)\n    model.train()\n\n    # Training loop\n    avg = 0\n    iter_tput = []\n    for epoch in range(args.num_epochs):\n        if n_gpus > 1:\n            dist_sampler.set_epoch(epoch)\n        tic = time.time()\n        model.train()\n    
    for step, (blocks, hist_blocks) in enumerate(dataloader):\n            if proc_id == 0:\n                tic_step = time.time()\n\n            # The nodes for input lies at the LHS side of the first block.\n            # The nodes for output lies at the RHS side of the last block.\n            seeds = blocks[-1].dstdata[dgl.NID]\n\n            blocks, hist_blocks = load_subtensor(\n                g, labels, blocks, hist_blocks, dev_id, True\n            )\n\n            # forward\n            batch_pred = model(blocks)\n            # update history\n            update_history(g, blocks)\n            # compute loss\n            batch_labels = blocks[-1].dstdata[\"label\"]\n            loss = loss_fcn(batch_pred, batch_labels)\n            # backward\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            if proc_id == 0:\n                iter_tput.append(len(seeds) * n_gpus / (time.time() - tic_step))\n            if step % args.log_every == 0 and proc_id == 0:\n                acc = compute_acc(batch_pred, batch_labels)\n                print(\n                    \"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}\".format(\n                        epoch,\n                        step,\n                        loss.item(),\n                        acc.item(),\n                        np.mean(iter_tput[3:]),\n                    )\n                )\n\n        if n_gpus > 1:\n            th.distributed.barrier()\n\n        toc = time.time()\n        if proc_id == 0:\n            print(\"Epoch Time(s): {:.4f}\".format(toc - tic))\n            if epoch >= 5:\n                avg += toc - tic\n            if epoch % args.eval_every == 0 and epoch != 0:\n                model.eval()\n                eval_acc = evaluate(\n                    model if n_gpus == 1 else model.module,\n                    g,\n                    labels,\n                    val_nid,\n              
      args.val_batch_size,\n                    dev_id,\n                )\n                print(\"Eval Acc {:.4f}\".format(eval_acc))\n\n    if n_gpus > 1:\n        th.distributed.barrier()\n    if proc_id == 0:\n        print(\"Avg epoch time: {}\".format(avg / (epoch - 4)))\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\"multi-gpu training\")\n    argparser.add_argument(\"--gpu\", type=str, default=\"0\")\n    argparser.add_argument(\"--num-epochs\", type=int, default=20)\n    argparser.add_argument(\"--num-hidden\", type=int, default=16)\n    argparser.add_argument(\"--num-layers\", type=int, default=2)\n    argparser.add_argument(\"--fan-out\", type=str, default=\"1,1\")\n    argparser.add_argument(\"--batch-size\", type=int, default=1000)\n    argparser.add_argument(\"--val-batch-size\", type=int, default=1000)\n    argparser.add_argument(\"--log-every\", type=int, default=20)\n    argparser.add_argument(\"--eval-every\", type=int, default=5)\n    argparser.add_argument(\"--lr\", type=float, default=0.003)\n    argparser.add_argument(\"--num-workers-per-gpu\", type=int, default=0)\n    args = argparser.parse_args()\n\n    devices = list(map(int, args.gpu.split(\",\")))\n    n_gpus = len(devices)\n\n    # load reddit data\n    data = RedditDataset(self_loop=True)\n    n_classes = data.num_classes\n    g = data[0]\n    features = g.ndata[\"feat\"]\n    in_feats = features.shape[1]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    g.ndata[\"features\"] = features.share_memory_()\n    create_history_storage(g, args, n_classes)\n\n    # Create csr/coo/csc formats before launching training processes with multi-gpu.\n    # This avoids creating certain formats in each sub-process, which saves momory and CPU.\n    g.create_formats_()\n    # Pack data\n    data = train_mask, val_mask, in_feats, labels, n_classes, g\n\n    if n_gpus == 1:\n        run(0, n_gpus, args, 
devices, data)\n    else:\n        mp.spawn(run, args=(n_gpus, args, devices, data), nprocs=n_gpus)\n"
  },
  {
    "path": "examples/sparse/appnp.py",
    "content": "\"\"\"\n[Predict then Propagate: Graph Neural Networks meet Personalized PageRank]\n(https://arxiv.org/abs/1810.05997)\n\"\"\"\n\nimport dgl.sparse as dglsp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data import CoraGraphDataset\nfrom torch.optim import Adam\n\n\nclass APPNP(nn.Module):\n    def __init__(\n        self,\n        in_size,\n        out_size,\n        hidden_size=64,\n        dropout=0.1,\n        num_hops=10,\n        alpha=0.1,\n    ):\n        super().__init__()\n\n        self.f_theta = nn.Sequential(\n            nn.Dropout(dropout),\n            nn.Linear(in_size, hidden_size),\n            nn.ReLU(),\n            nn.Dropout(dropout),\n            nn.Linear(hidden_size, out_size),\n        )\n        self.num_hops = num_hops\n        self.A_dropout = nn.Dropout(dropout)\n        self.alpha = alpha\n\n    def forward(self, A_hat, X):\n        Z_0 = Z = self.f_theta(X)\n        for _ in range(self.num_hops):\n            A_drop = dglsp.val_like(A_hat, self.A_dropout(A_hat.val))\n            Z = (1 - self.alpha) * A_drop @ Z + self.alpha * Z_0\n        return Z\n\n\ndef evaluate(g, pred):\n    label = g.ndata[\"label\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    # Compute accuracy on validation/test set.\n    val_acc = (pred[val_mask] == label[val_mask]).float().mean()\n    test_acc = (pred[test_mask] == label[test_mask]).float().mean()\n    return val_acc, test_acc\n\n\ndef train(model, g, A_hat, X):\n    label = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    optimizer = Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n\n    for epoch in range(50):\n        # Forward.\n        model.train()\n        logits = model(A_hat, X)\n\n        # Compute loss with nodes in training set.\n        loss = F.cross_entropy(logits[train_mask], label[train_mask])\n\n        # Backward.\n        optimizer.zero_grad()\n        loss.backward()\n      
  optimizer.step()\n\n        # Compute prediction.\n        model.eval()\n        logits = model(A_hat, X)\n        pred = logits.argmax(dim=1)\n\n        # Evaluate the prediction.\n        val_acc, test_acc = evaluate(g, pred)\n        print(\n            f\"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test\"\n            f\" acc: {test_acc:.3f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    # If CUDA is available, use GPU to accelerate the training, use CPU\n    # otherwise.\n    dev = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n    # Load graph from the existing dataset.\n    dataset = CoraGraphDataset()\n    g = dataset[0].to(dev)\n\n    # Create the sparse adjacency matrix A.\n    indices = torch.stack(g.edges())\n    N = g.num_nodes()\n    A = dglsp.spmatrix(indices, shape=(N, N))\n\n    # Calculate the symmetrically normalized adjacency matrix.\n    I = dglsp.identity(A.shape, device=dev)\n    A_hat = A + I\n    D_hat = dglsp.diag(A_hat.sum(dim=1)) ** -0.5\n    A_hat = D_hat @ A_hat @ D_hat\n\n    # Create APPNP model.\n    X = g.ndata[\"feat\"]\n    in_size = X.shape[1]\n    out_size = dataset.num_classes\n    model = APPNP(in_size, out_size).to(dev)\n\n    # Kick off training.\n    train(model, g, A_hat, X)\n"
  },
  {
    "path": "examples/sparse/c_and_s.py",
    "content": "\"\"\"\n[Combining Label Propagation and Simple Models Out-performs\nGraph Neural Networks](https://arxiv.org/abs/2010.13993)\n\"\"\"\nimport dgl.sparse as dglsp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data import CoraGraphDataset\nfrom torch.optim import Adam\n\n\n###############################################################################\n# (HIGHLIGHT) Compute Label Propagation with Sparse Matrix API\n###############################################################################\n@torch.no_grad()\ndef label_propagation(A_hat, label, num_layers=20, alpha=0.9):\n    Y = label\n    for _ in range(num_layers):\n        Y = alpha * A_hat @ Y + (1 - alpha) * label\n        Y = Y.clamp_(0.0, 1.0)\n    return Y\n\n\ndef correct(A_hat, label, soft_label, mask):\n    # Compute error.\n    error = torch.zeros_like(soft_label)\n    error[mask] = label[mask] - soft_label[mask]\n\n    # Smooth error.\n    smoothed_error = label_propagation(A_hat, error)\n\n    # Autoscale.\n    sigma = error[mask].abs()\n    sigma = sigma.sum() / sigma.shape[0]\n    scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True)\n    scale[scale.isinf() | (scale > 1000)] = 1.0\n\n    # Correct.\n    result = soft_label + scale * smoothed_error\n    return result\n\n\ndef smooth(A_hat, label, soft_label, mask):\n    soft_label[mask] = label[mask].float()\n    return label_propagation(A_hat, soft_label)\n\n\ndef evaluate(g, pred):\n    label = g.ndata[\"label\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    # Compute accuracy on validation/test set.\n    val_acc = (pred[val_mask] == label[val_mask]).float().mean()\n    test_acc = (pred[test_mask] == label[test_mask]).float().mean()\n    return val_acc, test_acc\n\n\ndef train(base_model, g, X):\n    label = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n\n    optimizer = Adam(base_model.parameters(), lr=0.01)\n\n    for epoch in 
range(10):\n        # Forward.\n        base_model.train()\n        logits = base_model(X)\n\n        # Compute loss with nodes in training set.\n        loss = F.cross_entropy(logits[train_mask], label[train_mask])\n\n        # Backward.\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # Compute prediction.\n        base_model.eval()\n        logits = base_model(X)\n        pred = logits.argmax(dim=1)\n\n        # Evaluate the prediction.\n        val_acc, test_acc = evaluate(g, pred)\n        print(\n            f\"Base model, In epoch {epoch}, loss: {loss:.3f}, \"\n            f\"val acc: {val_acc:.3f}, test acc: {test_acc:.3f}\"\n        )\n    return logits\n\n\nif __name__ == \"__main__\":\n    # If CUDA is available, use GPU to accelerate the training, use CPU\n    # otherwise.\n    dev = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n    # Load graph from the existing dataset.\n    dataset = CoraGraphDataset()\n    g = dataset[0].to(dev)\n\n    # Create the sparse adjacency matrix A.\n    indices = torch.stack(g.edges())\n    N = g.num_nodes()\n    A = dglsp.spmatrix(indices, shape=(N, N))\n\n    # Calculate the symmetrically normalized adjacency matrix.\n    I = dglsp.identity(A.shape, device=dev)\n    A_hat = A + I\n    D_hat = dglsp.diag(A_hat.sum(dim=1)) ** -0.5\n    A_hat = D_hat @ A_hat @ D_hat\n\n    # Create models.\n    X = g.ndata[\"feat\"]\n    in_size = X.shape[1]\n    out_size = dataset.num_classes\n    base_model = nn.Linear(in_size, out_size).to(dev)\n\n    # Stage1: Train the base model.\n    logits = train(base_model, g, X)\n\n    # Stage2: Correct and Smooth.\n    soft_label = F.softmax(logits, dim=1)\n    label = F.one_hot(g.ndata[\"label\"])\n    soft_label = correct(A_hat, label, soft_label, g.ndata[\"train_mask\"])\n    soft_label = smooth(A_hat, label, soft_label, g.ndata[\"train_mask\"])\n    pred = soft_label.argmax(dim=1)\n    val_acc, test_acc = evaluate(g, pred)\n 
   print(f\"val acc: {val_acc:.3f}, test acc: {test_acc:.3f}\")\n"
  },
  {
    "path": "examples/sparse/gat.py",
    "content": "\"\"\"\n[Graph Attention Networks]\n(https://arxiv.org/abs/1710.10903)\n\"\"\"\n\nimport dgl.sparse as dglsp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data import CoraGraphDataset\nfrom torch.optim import Adam\n\n\nclass GATConv(nn.Module):\n    def __init__(self, in_size, out_size, num_heads, dropout):\n        super().__init__()\n\n        self.out_size = out_size\n        self.num_heads = num_heads\n\n        self.dropout = nn.Dropout(dropout)\n        self.W = nn.Linear(in_size, out_size * num_heads)\n        self.a_l = nn.Parameter(torch.zeros(1, out_size, num_heads))\n        self.a_r = nn.Parameter(torch.zeros(1, out_size, num_heads))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_normal_(self.W.weight, gain=gain)\n        nn.init.xavier_normal_(self.a_l, gain=gain)\n        nn.init.xavier_normal_(self.a_r, gain=gain)\n\n    ###########################################################################\n    # (HIGHLIGHT) Take the advantage of DGL sparse APIs to implement\n    # multihead attention.\n    ###########################################################################\n    def forward(self, A_hat, Z):\n        Z = self.dropout(Z)\n        Z = self.W(Z).view(Z.shape[0], self.out_size, self.num_heads)\n\n        # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j\n        e_l = (Z * self.a_l).sum(dim=1)\n        e_r = (Z * self.a_r).sum(dim=1)\n        e = e_l[A_hat.row] + e_r[A_hat.col]\n\n        a = F.leaky_relu(e)\n        A_atten = dglsp.val_like(A_hat, a).softmax()\n        a_drop = self.dropout(A_atten.val)\n        A_atten = dglsp.val_like(A_atten, a_drop)\n        return dglsp.bspmm(A_atten, Z)\n\n\nclass GAT(nn.Module):\n    def __init__(\n        self, in_size, out_size, hidden_size=8, num_heads=8, dropout=0.6\n    ):\n        super().__init__()\n\n        self.in_conv = GATConv(\n            in_size, 
hidden_size, num_heads=num_heads, dropout=dropout\n        )\n        self.out_conv = GATConv(\n            hidden_size * num_heads, out_size, num_heads=1, dropout=dropout\n        )\n\n    def forward(self, A_hat, X):\n        # Flatten the head and feature dimension.\n        Z = F.elu(self.in_conv(A_hat, X)).flatten(1)\n        # Average over the head dimension.\n        Z = self.out_conv(A_hat, Z).mean(-1)\n        return Z\n\n\ndef evaluate(g, pred):\n    label = g.ndata[\"label\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    # Compute accuracy on validation/test set.\n    val_acc = (pred[val_mask] == label[val_mask]).float().mean()\n    test_acc = (pred[test_mask] == label[test_mask]).float().mean()\n    return val_acc, test_acc\n\n\ndef train(model, g, A_hat, X):\n    label = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    optimizer = Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n\n    for epoch in range(50):\n        # Forward.\n        model.train()\n        logits = model(A_hat, X)\n\n        # Compute loss with nodes in training set.\n        loss = F.cross_entropy(logits[train_mask], label[train_mask])\n\n        # Backward.\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # Compute prediction.\n        model.eval()\n        logits = model(A_hat, X)\n        pred = logits.argmax(dim=1)\n\n        # Evaluate the prediction.\n        val_acc, test_acc = evaluate(g, pred)\n        print(\n            f\"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test\"\n            f\" acc: {test_acc:.3f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    # If CUDA is available, use GPU to accelerate the training, use CPU\n    # otherwise.\n    dev = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n    # Load graph from the existing dataset.\n    dataset = CoraGraphDataset()\n    g = dataset[0].to(dev)\n\n    # Create the 
sparse adjacency matrix A.\n    indices = torch.stack(g.edges())\n    N = g.num_nodes()\n    A = dglsp.spmatrix(indices, shape=(N, N))\n\n    # Add self-loops.\n    I = dglsp.identity(A.shape, device=dev)\n    A_hat = A + I\n\n    # Create GAT model.\n    X = g.ndata[\"feat\"]\n    in_size = X.shape[1]\n    out_size = dataset.num_classes\n    model = GAT(in_size, out_size).to(dev)\n\n    # Kick off training.\n    train(model, g, A_hat, X)\n"
  },
  {
    "path": "examples/sparse/gcn.py",
    "content": "\"\"\"\n[Semi-Supervised Classification with Graph Convolutional Networks]\n(https://arxiv.org/abs/1609.02907)\n\"\"\"\n\nimport dgl.sparse as dglsp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data import CoraGraphDataset\nfrom torch.optim import Adam\n\n\nclass GCN(nn.Module):\n    def __init__(self, in_size, out_size, hidden_size=16):\n        super().__init__()\n\n        # Two-layer GCN.\n        self.W1 = nn.Linear(in_size, hidden_size)\n        self.W2 = nn.Linear(hidden_size, out_size)\n\n    ############################################################################\n    # (HIGHLIGHT) Take the advantage of DGL sparse APIs to implement the GCN\n    # forward process.\n    ############################################################################\n    def forward(self, A_norm, X):\n        X = A_norm @ self.W1(X)\n        X = F.relu(X)\n        X = A_norm @ self.W2(X)\n        return X\n\n\ndef evaluate(g, pred):\n    label = g.ndata[\"label\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    # Compute accuracy on validation/test set.\n    val_acc = (pred[val_mask] == label[val_mask]).float().mean()\n    test_acc = (pred[test_mask] == label[test_mask]).float().mean()\n    return val_acc, test_acc\n\n\ndef train(model, g, A_norm, X):\n    label = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    optimizer = Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\n    loss_fcn = nn.CrossEntropyLoss()\n\n    for epoch in range(200):\n        model.train()\n\n        # Forward.\n        logits = model(A_norm, X)\n\n        # Compute loss with nodes in the training set.\n        loss = loss_fcn(logits[train_mask], label[train_mask])\n\n        # Backward.\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # Compute prediction.\n        pred = logits.argmax(dim=1)\n\n        # Evaluate the prediction.\n        val_acc, 
test_acc = evaluate(g, pred)\n        if epoch % 20 == 0:\n            print(\n                f\"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}\"\n                f\", test acc: {test_acc:.3f}\"\n            )\n\n\nif __name__ == \"__main__\":\n    # If CUDA is available, use GPU to accelerate the training, use CPU\n    # otherwise.\n    dev = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n    # Load graph from the existing dataset.\n    dataset = CoraGraphDataset()\n    g = dataset[0].to(dev)\n    num_classes = dataset.num_classes\n    X = g.ndata[\"feat\"]\n\n    # Create the adjacency matrix of graph.\n    indices = torch.stack(g.edges())\n    N = g.num_nodes()\n    A = dglsp.spmatrix(indices, shape=(N, N))\n\n    ############################################################################\n    # (HIGHLIGHT) Compute the symmetrically normalized adjacency matrix with\n    # Sparse Matrix API\n    ############################################################################\n    I = dglsp.identity(A.shape, device=dev)\n    A_hat = A + I\n    D_hat = dglsp.diag(A_hat.sum(1)) ** -0.5\n    A_norm = D_hat @ A_hat @ D_hat\n\n    # Create model.\n    in_size = X.shape[1]\n    out_size = num_classes\n    model = GCN(in_size, out_size).to(dev)\n\n    # Kick off training.\n    train(model, g, A_norm, X)\n"
  },
  {
    "path": "examples/sparse/gcnii.py",
    "content": "\"\"\"\n[Simple and Deep Graph Convolutional Networks]\n(https://arxiv.org/abs/2007.02133)\n\"\"\"\n\nimport math\n\nimport dgl.sparse as dglsp\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data import CoraGraphDataset\nfrom torch.optim import Adam\n\n\nclass GCNIIConvolution(nn.Module):\n    def __init__(self, in_size, out_size):\n        super().__init__()\n        self.out_size = out_size\n        self.weight = nn.Linear(in_size, out_size, bias=False)\n\n    ############################################################################\n    # (HIGHLIGHT) Take the advantage of DGL sparse APIs to implement the GCNII\n    # forward process.\n    ############################################################################\n    def forward(self, A_norm, H, H0, lamda, alpha, l):\n        beta = math.log(lamda / l + 1)\n\n        # Multiply a sparse matrix by a dense matrix.\n        H = A_norm @ H\n        H = (1 - alpha) * H + alpha * H0\n        H = (1 - beta) * H + beta * self.weight(H)\n        return H\n\n\nclass GCNII(nn.Module):\n    def __init__(\n        self,\n        in_size,\n        out_size,\n        hidden_size,\n        n_layers,\n        lamda,\n        alpha,\n        dropout=0.5,\n    ):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.n_layers = n_layers\n        self.lamda = lamda\n        self.alpha = alpha\n\n        # The GCNII model.\n        self.layers = nn.ModuleList()\n        self.layers.append(nn.Linear(in_size, hidden_size))\n        for _ in range(n_layers):\n            self.layers.append(GCNIIConvolution(hidden_size, hidden_size))\n        self.layers.append(nn.Linear(hidden_size, out_size))\n\n        self.activation = nn.ReLU()\n        self.dropout = dropout\n\n    def forward(self, A_norm, feature):\n        H = feature\n        H = F.dropout(H, self.dropout, training=self.training)\n        H = self.layers[0](H)\n        H = self.activation(H)\n   
     H0 = H\n\n        # The GCNII convolution forward.\n        for i, conv in enumerate(self.layers[1:-1]):\n            H = F.dropout(H, self.dropout, training=self.training)\n            H = conv(A_norm, H, H0, self.lamda, self.alpha, i + 1)\n            H = self.activation(H)\n\n        H = F.dropout(H, self.dropout, training=self.training)\n        H = self.layers[-1](H)\n\n        return H\n\n\ndef evaluate(model, A_norm, H, label, val_mask, test_mask):\n    model.eval()\n    logits = model(A_norm, H)\n    pred = logits.argmax(dim=1)\n\n    # Compute accuracy on validation/test set.\n    val_acc = (pred[val_mask] == label[val_mask]).float().mean()\n    test_acc = (pred[test_mask] == label[test_mask]).float().mean()\n    return val_acc, test_acc\n\n\ndef train(model, g, A_norm, H):\n    label = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    optimizer = Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\n    loss_fcn = nn.CrossEntropyLoss()\n\n    for epoch in range(100):\n        model.train()\n        optimizer.zero_grad()\n\n        # Forward.\n        logits = model(A_norm, H)\n\n        # Compute loss with nodes in the training set.\n        loss = loss_fcn(logits[train_mask], label[train_mask])\n\n        # Backward.\n        loss.backward()\n        optimizer.step()\n\n        # Evaluate the prediction.\n        val_acc, test_acc = evaluate(\n            model, A_norm, H, label, val_mask, test_mask\n        )\n        if epoch % 5 == 0:\n            print(\n                f\"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}\"\n                f\", test acc: {test_acc:.3f}\"\n            )\n\n\nif __name__ == \"__main__\":\n    # If CUDA is available, use GPU to accelerate the training, use CPU\n    # otherwise.\n    dev = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n    # Load graph from the existing dataset.\n    
dataset = CoraGraphDataset()\n    g = dataset[0].to(dev)\n    num_classes = dataset.num_classes\n    H = g.ndata[\"feat\"]\n\n    # Create the adjacency matrix of graph.\n    indices = torch.stack(g.edges())\n    N = g.num_nodes()\n    A = dglsp.spmatrix(indices, shape=(N, N))\n\n    ############################################################################\n    # (HIGHLIGHT) Compute the symmetrically normalized adjacency matrix with\n    # Sparse Matrix API\n    ############################################################################\n    I = dglsp.identity(A.shape, device=dev)\n    A_hat = A + I\n    D_hat = dglsp.diag(A_hat.sum(1)) ** -0.5\n    A_norm = D_hat @ A_hat @ D_hat\n\n    # Create model.\n    in_size = H.shape[1]\n    out_size = num_classes\n    model = GCNII(\n        in_size,\n        out_size,\n        hidden_size=64,\n        n_layers=64,\n        lamda=0.5,\n        alpha=0.2,\n        dropout=0.5,\n    ).to(dev)\n\n    # Kick off training.\n    train(model, g, A_norm, H)\n"
  },
  {
    "path": "examples/sparse/graph_transformer.py",
    "content": "\"\"\"\n[A Generalization of Transformer Networks to Graphs]\n(https://arxiv.org/abs/2012.09699)\n\"\"\"\n\nimport dgl\nimport dgl.nn as dglnn\nimport dgl.sparse as dglsp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\n\nfrom dgl.data import AsGraphPredDataset\nfrom dgl.dataloading import GraphDataLoader\nfrom ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator\nfrom ogb.graphproppred.mol_encoder import AtomEncoder\nfrom tqdm import tqdm\n\n\nclass SparseMHA(nn.Module):\n    \"\"\"Sparse Multi-head Attention Module\"\"\"\n\n    def __init__(self, hidden_size=80, num_heads=8):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.num_heads = num_heads\n        self.head_dim = hidden_size // num_heads\n        self.scaling = self.head_dim**-0.5\n\n        self.q_proj = nn.Linear(hidden_size, hidden_size)\n        self.k_proj = nn.Linear(hidden_size, hidden_size)\n        self.v_proj = nn.Linear(hidden_size, hidden_size)\n        self.out_proj = nn.Linear(hidden_size, hidden_size)\n\n    def forward(self, A, h):\n        N = len(h)\n        q = self.q_proj(h).reshape(N, self.head_dim, self.num_heads)\n        q *= self.scaling\n        k = self.k_proj(h).reshape(N, self.head_dim, self.num_heads)\n        v = self.v_proj(h).reshape(N, self.head_dim, self.num_heads)\n\n        ######################################################################\n        # (HIGHLIGHT) Compute the multi-head attention with Sparse Matrix API\n        ######################################################################\n        attn = dglsp.bsddmm(A, q, k.transpose(1, 0))  # [N, N, nh]\n        attn = attn.softmax()\n        out = dglsp.bspmm(attn, v)\n\n        return self.out_proj(out.reshape(N, -1))\n\n\nclass GTLayer(nn.Module):\n    \"\"\"Graph Transformer Layer\"\"\"\n\n    def __init__(self, hidden_size=80, num_heads=8):\n        super().__init__()\n        
self.MHA = SparseMHA(hidden_size=hidden_size, num_heads=num_heads)\n        self.batchnorm1 = nn.BatchNorm1d(hidden_size)\n        self.batchnorm2 = nn.BatchNorm1d(hidden_size)\n        self.FFN1 = nn.Linear(hidden_size, hidden_size * 2)\n        self.FFN2 = nn.Linear(hidden_size * 2, hidden_size)\n\n    def forward(self, A, h):\n        h1 = h\n        h = self.MHA(A, h)\n        h = self.batchnorm1(h + h1)\n\n        h2 = h\n        h = self.FFN2(F.relu(self.FFN1(h)))\n        h = h2 + h\n\n        return self.batchnorm2(h)\n\n\nclass GTModel(nn.Module):\n    def __init__(\n        self,\n        out_size,\n        hidden_size=80,\n        pos_enc_size=2,\n        num_layers=8,\n        num_heads=8,\n    ):\n        super().__init__()\n        self.atom_encoder = AtomEncoder(hidden_size)\n        self.pos_linear = nn.Linear(pos_enc_size, hidden_size)\n        self.layers = nn.ModuleList(\n            [GTLayer(hidden_size, num_heads) for _ in range(num_layers)]\n        )\n        self.pooler = dglnn.SumPooling()\n        self.predictor = nn.Sequential(\n            nn.Linear(hidden_size, hidden_size // 2),\n            nn.ReLU(),\n            nn.Linear(hidden_size // 2, hidden_size // 4),\n            nn.ReLU(),\n            nn.Linear(hidden_size // 4, out_size),\n        )\n\n    def forward(self, g, X, pos_enc):\n        indices = torch.stack(g.edges())\n        N = g.num_nodes()\n        A = dglsp.spmatrix(indices, shape=(N, N))\n        h = self.atom_encoder(X) + self.pos_linear(pos_enc)\n        for layer in self.layers:\n            h = layer(A, h)\n        h = self.pooler(g, h)\n\n        return self.predictor(h)\n\n\n@torch.no_grad()\ndef evaluate(model, dataloader, evaluator, device):\n    model.eval()\n    y_true = []\n    y_pred = []\n    for batched_g, labels in dataloader:\n        batched_g, labels = batched_g.to(device), labels.to(device)\n        y_hat = model(batched_g, batched_g.ndata[\"feat\"], batched_g.ndata[\"PE\"])\n        
y_true.append(labels.view(y_hat.shape).detach().cpu())\n        y_pred.append(y_hat.detach().cpu())\n    y_true = torch.cat(y_true, dim=0).numpy()\n    y_pred = torch.cat(y_pred, dim=0).numpy()\n    input_dict = {\"y_true\": y_true, \"y_pred\": y_pred}\n    return evaluator.eval(input_dict)[\"rocauc\"]\n\n\ndef train(model, dataset, evaluator, device):\n    train_dataloader = GraphDataLoader(\n        dataset[dataset.train_idx],\n        batch_size=256,\n        shuffle=True,\n        collate_fn=collate_dgl,\n    )\n    valid_dataloader = GraphDataLoader(\n        dataset[dataset.val_idx], batch_size=256, collate_fn=collate_dgl\n    )\n    test_dataloader = GraphDataLoader(\n        dataset[dataset.test_idx], batch_size=256, collate_fn=collate_dgl\n    )\n    optimizer = optim.Adam(model.parameters(), lr=0.001)\n    num_epochs = 50\n    scheduler = optim.lr_scheduler.StepLR(\n        optimizer, step_size=num_epochs, gamma=0.5\n    )\n    loss_fcn = nn.BCEWithLogitsLoss()\n\n    for epoch in range(num_epochs):\n        model.train()\n        total_loss = 0.0\n        for batched_g, labels in train_dataloader:\n            batched_g, labels = batched_g.to(device), labels.to(device)\n            logits = model(\n                batched_g, batched_g.ndata[\"feat\"], batched_g.ndata[\"PE\"]\n            )\n            loss = loss_fcn(logits, labels.float())\n            total_loss += loss.item()\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n        scheduler.step()\n        avg_loss = total_loss / len(train_dataloader)\n        val_metric = evaluate(model, valid_dataloader, evaluator, device)\n        test_metric = evaluate(model, test_dataloader, evaluator, device)\n        print(\n            f\"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, \"\n            f\"Val: {val_metric:.4f}, Test: {test_metric:.4f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    # If CUDA is available, use GPU to accelerate the training, use 
CPU\n    # otherwise.\n    dev = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n    # load dataset\n    pos_enc_size = 8\n    dataset = AsGraphPredDataset(\n        DglGraphPropPredDataset(\"ogbg-molhiv\", \"./data/OGB\")\n    )\n    evaluator = Evaluator(\"ogbg-molhiv\")\n    # laplacian positional encoding\n    for g, _ in tqdm(dataset, desc=\"Computing Laplacian PE\"):\n        g.ndata[\"PE\"] = dgl.lap_pe(g, k=pos_enc_size, padding=True)\n\n    # Create model.\n    out_size = dataset.num_tasks\n    model = GTModel(out_size=out_size, pos_enc_size=pos_enc_size).to(dev)\n\n    # Kick off training.\n    train(model, dataset, evaluator, dev)\n"
  },
  {
    "path": "examples/sparse/han.py",
    "content": "\"\"\"\n[Heterogeneous Graph Attention Network]\n(https://arxiv.org/abs/1903.07293)\n\"\"\"\n\nimport pickle\n\nimport dgl.sparse as dglsp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl.data.utils import _get_dgl_url, download, get_download_dir\nfrom torch.optim import Adam\n\n\nclass GATConv(nn.Module):\n    def __init__(self, in_size, out_size, num_heads, dropout):\n        super().__init__()\n\n        self.out_size = out_size\n        self.num_heads = num_heads\n\n        self.dropout = nn.Dropout(dropout)\n        self.W = nn.Linear(in_size, out_size * num_heads)\n        self.a_l = nn.Parameter(torch.zeros(1, out_size, num_heads))\n        self.a_r = nn.Parameter(torch.zeros(1, out_size, num_heads))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_normal_(self.W.weight, gain=gain)\n        nn.init.xavier_normal_(self.a_l, gain=gain)\n        nn.init.xavier_normal_(self.a_r, gain=gain)\n\n    ###########################################################################\n    # (HIGHLIGHT) Take the advantage of DGL sparse APIs to implement\n    # multihead attention.\n    ###########################################################################\n    def forward(self, A_hat, Z):\n        Z = self.dropout(Z)\n        Z = self.W(Z).view(Z.shape[0], self.out_size, self.num_heads)\n\n        # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j\n        e_l = (Z * self.a_l).sum(dim=1)\n        e_r = (Z * self.a_r).sum(dim=1)\n        e = e_l[A_hat.row] + e_r[A_hat.col]\n\n        a = F.leaky_relu(e)\n        A_atten = dglsp.val_like(A_hat, a).softmax()\n        a_drop = self.dropout(A_atten.val)\n        A_atten = dglsp.val_like(A_atten, a_drop)\n        return dglsp.bspmm(A_atten, Z)\n\n\nclass SemanticAttention(nn.Module):\n    def __init__(self, in_size, hidden_size=128):\n        super().__init__()\n\n        self.project = 
nn.Sequential(\n            nn.Linear(in_size, hidden_size),\n            nn.Tanh(),\n            nn.Linear(hidden_size, 1, bias=False),\n        )\n\n    def forward(self, z):\n        w = self.project(z).mean(0)\n        beta = torch.softmax(w, dim=0)\n        beta = beta.expand((z.shape[0],) + beta.shape)\n\n        return (beta * z).sum(1)\n\n\nclass HAN(nn.Module):\n    def __init__(\n        self,\n        num_meta_paths,\n        in_size,\n        out_size,\n        hidden_size=8,\n        num_heads=8,\n        dropout=0.6,\n    ):\n        super().__init__()\n\n        self.gat_layers = nn.ModuleList()\n        for _ in range(num_meta_paths):\n            self.gat_layers.append(\n                GATConv(in_size, hidden_size, num_heads, dropout)\n            )\n\n        in_size = hidden_size * num_heads\n        self.semantic_attention = SemanticAttention(in_size)\n        self.predict = nn.Linear(in_size, out_size)\n\n    def forward(self, A_list, X):\n        meta_path_Z_list = []\n        for i, A in enumerate(A_list):\n            meta_path_Z_list.append(self.gat_layers[i](A, X).flatten(1))\n\n        # (num_nodes, num_meta_paths, hidden_size * num_heads)\n        meta_path_Z = torch.stack(meta_path_Z_list, dim=1)\n\n        Z = self.semantic_attention(meta_path_Z)\n        Z = self.predict(Z)\n\n        return Z\n\n\ndef evaluate(label, val_idx, test_idx, pred):\n    # Compute accuracy on validation/test set.\n    val_acc = (pred[val_idx] == label[val_idx]).float().mean()\n    test_acc = (pred[test_idx] == label[test_idx]).float().mean()\n    return val_acc, test_acc\n\n\ndef train(model, data, A_list, X, label):\n    dev = X.device\n    train_idx = torch.from_numpy(data[\"train_idx\"]).long().squeeze(0).to(dev)\n    val_idx = torch.from_numpy(data[\"val_idx\"]).long().squeeze(0).to(dev)\n    test_idx = torch.from_numpy(data[\"test_idx\"]).long().squeeze(0).to(dev)\n    optimizer = Adam(model.parameters(), lr=0.005, weight_decay=0.001)\n\n    for epoch 
in range(70):\n        # Forward.\n        model.train()\n        logits = model(A_list, X)\n\n        # Compute loss with nodes in training set.\n        loss = F.cross_entropy(logits[train_idx], label[train_idx])\n\n        # Backward.\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # Compute prediction.\n        model.eval()\n        logits = model(A_list, X)\n        pred = logits.argmax(dim=1)\n\n        # Evaluate the prediction.\n        val_acc, test_acc = evaluate(label, val_idx, test_idx, pred)\n        print(\n            f\"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test\"\n            f\" acc: {test_acc:.3f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    # If CUDA is available, use GPU to accelerate the training, use CPU\n    # otherwise.\n    dev = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n    # (TODO): Move the logic to a built-in dataset.\n    # Load the data.\n    url = \"dataset/ACM3025.pkl\"\n    data_path = get_download_dir() + \"/ACM3025.pkl\"\n    download(_get_dgl_url(url), path=data_path)\n\n    with open(data_path, \"rb\") as f:\n        data = pickle.load(f)\n\n    # Create sparse adjacency matrices corresponding to two meta paths.\n    # Self-loops already added.\n    PAP_dst, PAP_src = data[\"PAP\"].nonzero()\n    PAP_indices = torch.stack(\n        [torch.from_numpy(PAP_src).long(), torch.from_numpy(PAP_dst).long()]\n    ).to(dev)\n    PAP_A = dglsp.spmatrix(PAP_indices)\n\n    PLP_dst, PLP_src = data[\"PLP\"].nonzero()\n    PLP_indices = torch.stack(\n        [torch.from_numpy(PLP_src).long(), torch.from_numpy(PLP_dst).long()]\n    ).to(dev)\n    PLP_A = dglsp.spmatrix(PLP_indices)\n    A_list = [PAP_A, PLP_A]\n\n    # Create HAN model.\n    X = torch.from_numpy(data[\"feature\"].todense()).float().to(dev)\n    label = torch.from_numpy(data[\"label\"].todense())\n    out_size = label.shape[1]\n    label = label.nonzero()[:, 1].to(dev)\n    
in_size = X.shape[1]\n    model = HAN(len(A_list), in_size, out_size).to(dev)\n\n    # Kick off training.\n    train(model, data, A_list, X, label)\n"
  },
  {
    "path": "examples/sparse/hetero-rgcn.py",
    "content": "\"\"\"\nModeling Relational Data with Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1703.06103\nReference Code: https://github.com/tkipf/relational-gcn\n\nThis script trains and tests a Hetero Relational Graph Convolutional Networks \n(Hetero-RGCN) model based on the information of a full graph.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> Load and preprocess full dataset\n│\n├───> Instantiate Hetero-RGCN model\n│\n├───> train\n│     │\n│     └───> Training loop\n│           │\n│           └───> Hetero-RGCN.forward\n└───> test\n      │\n      └───> Evaluate the model\n\"\"\"\nimport argparse\nimport time\n\nimport dgl\nimport dgl.sparse as dglsp\n\nimport numpy as np\n\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\n\n\nclass RelGraphEmbed(nn.Module):\n    r\"\"\"Embedding layer for featureless heterograph.\"\"\"\n\n    def __init__(\n        self,\n        ntype_num,\n        embed_size,\n    ):\n        super(RelGraphEmbed, self).__init__()\n        self.embed_size = embed_size\n        self.dropout = nn.Dropout(0.0)\n\n        # Create weight embeddings for each node for each relation.\n        self.embeds = nn.ParameterDict()\n        for ntype, num_nodes in ntype_num.items():\n            embed = nn.Parameter(th.Tensor(num_nodes, self.embed_size))\n            nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain(\"relu\"))\n            self.embeds[ntype] = embed\n\n    def forward(self):\n        return self.embeds\n\n\nclass HeteroRelationalGraphConv(nn.Module):\n    r\"\"\"HeteroRelational graph convolution layer.\n\n    Parameters\n    ----------\n    in_size : int\n        Input feature size.\n    out_size : int\n        Output feature size.\n    relation_names : list[str]\n        Relation names.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_size,\n   
     out_size,\n        relation_names,\n        activation=None,\n    ):\n        super(HeteroRelationalGraphConv, self).__init__()\n        self.in_size = in_size\n        self.out_size = out_size\n        self.relation_names = relation_names\n        self.activation = activation\n\n        ########################################################################\n        # (HIGHLIGHT) HeteroGraphConv is a graph convolution operator over\n        # heterogeneous graphs. A dictionary is passed where the key is the\n        # relation name and the value is the instance of conv layer.\n        ########################################################################\n        self.W = nn.ModuleDict(\n            {str(rel): nn.Linear(in_size, out_size) for rel in relation_names}\n        )\n\n        self.dropout = nn.Dropout(0.0)\n\n    def forward(self, A, inputs):\n        \"\"\"Forward computation\n\n        Parameters\n        ----------\n        A : Hetero Sparse Matrix\n            Input graph.\n        inputs : dict[str, torch.Tensor]\n            Node feature for each node type.\n\n        Returns\n        -------\n        dict[str, torch.Tensor]\n            New node features for each node type.\n        \"\"\"\n        hs = {}\n        for rel in A:\n            src_type, edge_type, dst_type = rel\n            if dst_type not in hs:\n                hs[dst_type] = th.zeros(\n                    inputs[dst_type].shape[0], self.out_size\n                )\n            ####################################################################\n            # (HIGHLIGHT) Sparse library use hetero sparse matrix to present\n            # heterogeneous graphs. A dictionary is passed where the key is\n            # the tuple of (source node type, edge type, destination node type)\n            # and the value is the sparse matrix constructed from the key on\n            # global graph. 
The convolution operation is the multiplication of\n            # sparse matrix and convolutional layer.\n            ####################################################################\n            hs[dst_type] = hs[dst_type] + (\n                A[rel].T @ self.W[str(edge_type)](inputs[src_type])\n            )\n            if self.activation:\n                hs[dst_type] = self.activation(hs[dst_type])\n            hs[dst_type] = self.dropout(hs[dst_type])\n\n        return hs\n\n\nclass EntityClassify(nn.Module):\n    def __init__(\n        self,\n        in_size,\n        out_size,\n        relation_names,\n        embed_layer,\n    ):\n        super(EntityClassify, self).__init__()\n        self.in_size = in_size\n        self.out_size = out_size\n        self.relation_names = relation_names\n        self.relation_names.sort()\n        self.embed_layer = embed_layer\n\n        self.layers = nn.ModuleList()\n        # Input to hidden.\n        self.layers.append(\n            HeteroRelationalGraphConv(\n                self.in_size,\n                self.in_size,\n                self.relation_names,\n                activation=F.relu,\n            )\n        )\n        # Hidden to output.\n        self.layers.append(\n            HeteroRelationalGraphConv(\n                self.in_size,\n                self.out_size,\n                self.relation_names,\n            )\n        )\n\n    def forward(self, A):\n        h = self.embed_layer()\n        for layer in self.layers:\n            h = layer(A, h)\n        return h\n\n\ndef main(args):\n    # Load graph data.\n    if args.dataset == \"aifb\":\n        dataset = AIFBDataset()\n    elif args.dataset == \"bgs\":\n        dataset = BGSDataset()\n    else:\n        raise ValueError()\n\n    g = dataset[0]\n    category = dataset.predict_category\n    num_classes = dataset.num_classes\n    train_mask = g.nodes[category].data.pop(\"train_mask\")\n    test_mask = g.nodes[category].data.pop(\"test_mask\")\n    
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()\n    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()\n    labels = g.nodes[category].data.pop(\"labels\")\n\n    # Split dataset into train, validate, test.\n    val_idx = train_idx[: len(train_idx) // 5]\n    train_idx = train_idx[len(train_idx) // 5 :]\n\n    embed_layer = RelGraphEmbed(\n        {ntype: g.num_nodes(ntype) for ntype in g.ntypes}, 16\n    )\n\n    # Create model.\n    model = EntityClassify(\n        16,\n        num_classes,\n        list(set(g.etypes)),\n        embed_layer,\n    )\n\n    # Optimizer.\n    optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0)\n\n    # Construct hetero sparse matrix.\n    A = {}\n    for stype, etype, dtype in g.canonical_etypes:\n        eg = g[stype, etype, dtype]\n        indices = th.stack(eg.edges(\"uv\"))\n        A[(stype, etype, dtype)] = dglsp.spmatrix(\n            indices, shape=(g.num_nodes(stype), g.num_nodes(dtype))\n        )\n        ###########################################################\n        # (HIGHLIGHT) Compute the normalized adjacency matrix with\n        # Sparse Matrix API\n        ###########################################################\n        D1_hat = dglsp.diag(A[(stype, etype, dtype)].sum(1)) ** -0.5\n        D2_hat = dglsp.diag(A[(stype, etype, dtype)].sum(0)) ** -0.5\n        A[(stype, etype, dtype)] = D1_hat @ A[(stype, etype, dtype)] @ D2_hat\n\n    # Training loop.\n    print(\"start training...\")\n    model.train()\n    for epoch in range(10):\n        optimizer.zero_grad()\n        logits = model(A)[category]\n        loss = F.cross_entropy(logits[train_idx], labels[train_idx])\n        loss.backward()\n        optimizer.step()\n\n        train_acc = th.sum(\n            logits[train_idx].argmax(dim=1) == labels[train_idx]\n        ).item() / len(train_idx)\n        val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])\n        val_acc = th.sum(\n            
logits[val_idx].argmax(dim=1) == labels[val_idx]\n        ).item() / len(val_idx)\n        print(\n            f\"Epoch {epoch:05d} | Train Acc: {train_acc:.4f} | \"\n            f\"Train Loss: {loss.item():.4f} | Valid Acc: {val_acc:.4f} | \"\n            f\"Valid loss: {val_loss.item():.4f} \"\n        )\n    print()\n\n    model.eval()\n    logits = model.forward(A)[category]\n    test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])\n    test_acc = th.sum(\n        logits[test_idx].argmax(dim=1) == labels[test_idx]\n    ).item() / len(test_idx)\n    print(\n        \"Test Acc: {:.4f} | Test loss: {:.4f}\".format(\n            test_acc, test_loss.item()\n        )\n    )\n    print()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"RGCN\")\n    parser.add_argument(\n        \"-d\", \"--dataset\", type=str, required=True, help=\"dataset to use\"\n    )\n\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/sparse/hgnn.py",
    "content": "\"\"\"\nHypergraph Neural Networks (https://arxiv.org/pdf/1809.09401.pdf)\n\"\"\"\nimport dgl.sparse as dglsp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\nfrom dgl.data import CoraGraphDataset\nfrom torchmetrics.functional import accuracy\n\n\nclass HGNN(nn.Module):\n    def __init__(self, H, in_size, out_size, hidden_dims=16):\n        super().__init__()\n\n        self.Theta1 = nn.Linear(in_size, hidden_dims)\n        self.Theta2 = nn.Linear(hidden_dims, out_size)\n        self.dropout = nn.Dropout(0.5)\n\n        ###########################################################\n        # (HIGHLIGHT) Compute the Laplacian with Sparse Matrix API\n        ###########################################################\n        d_V = H.sum(1)  # node degree\n        d_E = H.sum(0)  # edge degree\n        n_edges = d_E.shape[0]\n        D_V_invsqrt = dglsp.diag(d_V**-0.5)  # D_V ** (-1/2)\n        D_E_inv = dglsp.diag(d_E**-1)  # D_E ** (-1)\n        W = dglsp.identity((n_edges, n_edges))\n        self.laplacian = D_V_invsqrt @ H @ W @ D_E_inv @ H.T @ D_V_invsqrt\n\n    def forward(self, X):\n        X = self.laplacian @ self.Theta1(self.dropout(X))\n        X = F.relu(X)\n        X = self.laplacian @ self.Theta2(self.dropout(X))\n        return X\n\n\ndef train(model, optimizer, X, Y, train_mask):\n    model.train()\n    Y_hat = model(X)\n    loss = F.cross_entropy(Y_hat[train_mask], Y[train_mask])\n    optimizer.zero_grad()\n    loss.backward()\n    optimizer.step()\n\n\ndef evaluate(model, X, Y, val_mask, test_mask, num_classes):\n    model.eval()\n    Y_hat = model(X)\n    val_acc = accuracy(\n        Y_hat[val_mask], Y[val_mask], task=\"multiclass\", num_classes=num_classes\n    )\n    test_acc = accuracy(\n        Y_hat[test_mask],\n        Y[test_mask],\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n    return val_acc, test_acc\n\n\ndef load_data():\n    dataset = CoraGraphDataset()\n\n    
graph = dataset[0]\n    # The paper created a hypergraph from the original graph. For each node in\n    # the original graph, a hyperedge in the hypergraph is created to connect\n    # its neighbors and itself. In this case, the incidence matrix of the\n    # hypergraph is the same as the adjacency matrix of the original graph (with\n    # self-loops).\n    # We follow the paper and assume that the rows of the incidence matrix\n    # are for nodes and the columns are for edges.\n    indices = torch.stack(graph.edges())\n    H = dglsp.spmatrix(indices)\n    H = H + dglsp.identity(H.shape)\n\n    X = graph.ndata[\"feat\"]\n    Y = graph.ndata[\"label\"]\n    train_mask = graph.ndata[\"train_mask\"]\n    val_mask = graph.ndata[\"val_mask\"]\n    test_mask = graph.ndata[\"test_mask\"]\n    return H, X, Y, dataset.num_classes, train_mask, val_mask, test_mask\n\n\ndef main():\n    H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()\n    model = HGNN(H, X.shape[1], num_classes)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n    with tqdm.trange(500) as tq:\n        for epoch in tq:\n            train(model, optimizer, X, Y, train_mask)\n            val_acc, test_acc = evaluate(\n                model, X, Y, val_mask, test_mask, num_classes\n            )\n            tq.set_postfix(\n                {\n                    \"Val acc\": f\"{val_acc:.5f}\",\n                    \"Test acc\": f\"{test_acc:.5f}\",\n                },\n                refresh=False,\n            )\n\n    print(f\"Test acc: {test_acc:.3f}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/sparse/hypergraphatt.py",
    "content": "\"\"\"\nHypergraph Convolution and Hypergraph Attention\n(https://arxiv.org/pdf/1901.08150.pdf).\n\"\"\"\nimport argparse\n\nimport dgl.sparse as dglsp\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport tqdm\nfrom dgl.data import CoraGraphDataset\nfrom torchmetrics.functional import accuracy\n\n\ndef hypergraph_laplacian(H):\n    ###########################################################\n    # (HIGHLIGHT) Compute the Laplacian with Sparse Matrix API\n    ###########################################################\n    d_V = H.sum(1)  # node degree\n    d_E = H.sum(0)  # edge degree\n    n_edges = d_E.shape[0]\n    D_V_invsqrt = dglsp.diag(d_V**-0.5)  # D_V ** (-1/2)\n    D_E_inv = dglsp.diag(d_E**-1)  # D_E ** (-1)\n    W = dglsp.identity((n_edges, n_edges))\n    return D_V_invsqrt @ H @ W @ D_E_inv @ H.T @ D_V_invsqrt\n\n\nclass HypergraphAttention(nn.Module):\n    \"\"\"Hypergraph Attention module as in the paper\n    `Hypergraph Convolution and Hypergraph Attention\n    <https://arxiv.org/pdf/1901.08150.pdf>`_.\n    \"\"\"\n\n    def __init__(self, in_size, out_size):\n        super().__init__()\n\n        self.P = nn.Linear(in_size, out_size)\n        self.a = nn.Linear(2 * out_size, 1)\n\n    def forward(self, H, X, X_edges):\n        Z = self.P(X)\n        Z_edges = self.P(X_edges)\n        sim = self.a(torch.cat([Z[H.row], Z_edges[H.col]], 1))\n        sim = F.leaky_relu(sim, 0.2).squeeze(1)\n        # Reassign the hypergraph new weights.\n        H_att = dglsp.val_like(H, sim)\n        H_att = H_att.softmax()\n        return hypergraph_laplacian(H_att) @ Z\n\n\nclass Net(nn.Module):\n    def __init__(self, in_size, out_size, hidden_size=16):\n        super().__init__()\n\n        self.layer1 = HypergraphAttention(in_size, hidden_size)\n        self.layer2 = HypergraphAttention(hidden_size, out_size)\n\n    def forward(self, H, X):\n        Z = self.layer1(H, X, X)\n        Z = F.elu(Z)\n        Z = 
self.layer2(H, Z, Z)\n        return Z\n\n\ndef train(model, optimizer, H, X, Y, train_mask):\n    model.train()\n    Y_hat = model(H, X)\n    loss = F.cross_entropy(Y_hat[train_mask], Y[train_mask])\n    optimizer.zero_grad()\n    loss.backward()\n    optimizer.step()\n    return loss.item()\n\n\ndef evaluate(model, H, X, Y, val_mask, test_mask, num_classes):\n    model.eval()\n    Y_hat = model(H, X)\n    val_acc = accuracy(\n        Y_hat[val_mask], Y[val_mask], task=\"multiclass\", num_classes=num_classes\n    )\n    test_acc = accuracy(\n        Y_hat[test_mask],\n        Y[test_mask],\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n    return val_acc, test_acc\n\n\ndef load_data():\n    dataset = CoraGraphDataset()\n\n    graph = dataset[0]\n    # The paper created a hypergraph from the original graph. For each node in\n    # the original graph, a hyperedge in the hypergraph is created to connect\n    # its neighbors and itself. In this case, the incidence matrix of the\n    # hypergraph is the same as the adjacency matrix of the original graph (with\n    # self-loops).\n    # We follow the paper and assume that the rows of the incidence matrix\n    # are for nodes and the columns are for edges.\n    indices = torch.stack(graph.edges())\n    H = dglsp.spmatrix(indices)\n    H = H + dglsp.identity(H.shape)\n\n    X = graph.ndata[\"feat\"]\n    Y = graph.ndata[\"label\"]\n    train_mask = graph.ndata[\"train_mask\"]\n    val_mask = graph.ndata[\"val_mask\"]\n    test_mask = graph.ndata[\"test_mask\"]\n    return H, X, Y, dataset.num_classes, train_mask, val_mask, test_mask\n\n\ndef main(args):\n    H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()\n    model = Net(X.shape[1], num_classes)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n    with tqdm.trange(args.epochs) as tq:\n        for epoch in tq:\n            loss = train(model, optimizer, H, X, Y, train_mask)\n            val_acc, test_acc = 
evaluate(\n                model, H, X, Y, val_mask, test_mask, num_classes\n            )\n            tq.set_postfix(\n                {\n                    \"Loss\": f\"{loss:.5f}\",\n                    \"Val acc\": f\"{val_acc:.5f}\",\n                    \"Test acc\": f\"{test_acc:.5f}\",\n                },\n                refresh=False,\n            )\n\n    print(f\"Test acc: {test_acc:.3f}\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Hypergraph Attention Example\")\n    parser.add_argument(\n        \"--epochs\", type=int, default=500, help=\"Number of training epochs.\"\n    )\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/sparse/pagerank.py",
    "content": "import dgl.sparse as dglsp\nimport networkx as nx\nimport torch\n\nN = 100\nDAMP = 0.85\nK = 10\n\n\ndef pagerank(A):\n    D = A.sum(0)\n    V = torch.ones(N) / N\n    for _ in range(K):\n        ########################################################################\n        # (HIGHLIGHT) Take the advantage of DGL sparse APIs to calculate the\n        # page rank.\n        ########################################################################\n        V = (1 - DAMP) / N + DAMP * A @ (V / D)\n    return V\n\n\nif __name__ == \"__main__\":\n    g = nx.erdos_renyi_graph(N, 0.05, seed=10086)\n\n    # Create the adjacency matrix of graph.\n    edges = list(g.to_directed().edges())\n    indices = torch.tensor(edges).transpose(0, 1)\n    A = dglsp.spmatrix(indices, shape=(N, N))\n\n    V = pagerank(A)\n    print(V)\n"
  },
  {
    "path": "examples/sparse/sampling/graphsage.py",
    "content": "\"\"\"\nThis script demonstrate how to use dgl sparse library to sample on graph and \ntrain model. It trains and tests a GraphSAGE model using the sparse sample and \ncompact operators to sample submatrix from the whole matrix.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> Load and preprocess full dataset\n│\n├───> Instantiate SAGE model\n│\n├───> train\n│     │\n│     └───> Training loop\n│           │\n│           ├───> Sample submatrix\n│           │\n│           └───> SAGE.forward\n└───> test\n      │\n      ├───> Sample submatrix\n      │\n      └───> Evaluate the model\n\"\"\"\nimport argparse\n\nimport dgl.sparse as dglsp\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nfrom dgl.data import AsNodePredDataset\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\nclass SAGEConv(nn.Module):\n    r\"\"\"GraphSAGE layer from `Inductive Representation Learning on\n    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__\n    \"\"\"\n\n    def __init__(\n        self,\n        in_size,\n        out_size,\n    ):\n        super(SAGEConv, self).__init__()\n        self._in_src_feats, self._in_dst_feats = in_size, in_size\n        self._out_size = out_size\n\n        self.fc_neigh = nn.Linear(self._in_src_feats, out_size, bias=False)\n        self.fc_self = nn.Linear(self._in_dst_feats, out_size, bias=True)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)\n        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)\n\n    def forward(self, A, feat):\n        feat_src = feat\n        feat_dst = feat[: A.shape[1]]\n\n        # Aggregator type: mean.\n        srcdata = self.fc_neigh(feat_src)\n        # Divided by degree.\n        D_hat = dglsp.diag(A.sum(0)) ** -1\n        A_div = A @ D_hat\n        
# Conv neighbors.\n        dstdata = A_div.T @ srcdata\n\n        rst = self.fc_self(feat_dst) + dstdata\n        return rst\n\n\nclass SAGE(nn.Module):\n    def __init__(self, in_size, hid_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # Three-layer GraphSAGE-gcn.\n        self.layers.append(SAGEConv(in_size, hid_size))\n        self.layers.append(SAGEConv(hid_size, hid_size))\n        self.layers.append(SAGEConv(hid_size, out_size))\n        self.dropout = nn.Dropout(0.5)\n        self.hid_size = hid_size\n        self.out_size = out_size\n\n    def forward(self, sampled_matrices, x):\n        hidden_x = x\n        for layer_idx, (layer, sampled_matrix) in enumerate(\n            zip(self.layers, sampled_matrices)\n        ):\n            hidden_x = layer(sampled_matrix, hidden_x)\n            if layer_idx != len(self.layers) - 1:\n                hidden_x = F.relu(hidden_x)\n                hidden_x = self.dropout(hidden_x)\n        return hidden_x\n\n\ndef multilayer_sample(A, fanouts, seeds, ndata):\n    sampled_matrices = []\n    src = seeds\n\n    #####################################################################\n    # (HIGHLIGHT) Using the sparse sample operator to preform random\n    # sampling on the neighboring nodes of the seeds nodes. 
The sparse\n    # compact operator is then employed to compact and relabel the sampled\n    # matrix, resulting in the sampled matrix and the relabel index.\n    #####################################################################\n\n    for fanout in fanouts:\n        # Sample neighbors.\n        sampled_matrix = A.sample(1, fanout, ids=src).coalesce()\n        # Compact the sampled matrix.\n        compacted_mat, row_ids = sampled_matrix.compact(0)\n        sampled_matrices.insert(0, compacted_mat)\n        src = row_ids\n\n    x = ndata[\"feat\"][src]\n    y = ndata[\"label\"][seeds]\n    return sampled_matrices, x, y\n\n\ndef evaluate(model, A, dataloader, ndata, num_classes):\n    model.eval()\n    ys = []\n    y_hats = []\n    fanouts = [10, 10, 10]\n    for it, seeds in enumerate(dataloader):\n        with torch.no_grad():\n            sampled_matrices, x, y = multilayer_sample(A, fanouts, seeds, ndata)\n            ys.append(y)\n            y_hats.append(model(sampled_matrices, x))\n\n    return MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(ys),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\ndef validate(device, A, ndata, dataset, model, batch_size):\n    inf_id = dataset.test_idx.to(device)\n    inf_dataloader = torch.utils.data.DataLoader(inf_id, batch_size=batch_size)\n    acc = evaluate(model, A, inf_dataloader, ndata, dataset.num_classes)\n    return acc\n\n\ndef train(device, A, ndata, dataset, model):\n    # Create sampler & dataloader.\n    train_idx = dataset.train_idx.to(device)\n    val_idx = dataset.val_idx.to(device)\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_idx, batch_size=1024, shuffle=True\n    )\n    val_dataloader = torch.utils.data.DataLoader(val_idx, batch_size=1024)\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)\n\n    fanouts = [10, 10, 10]\n    for epoch in range(10):\n        model.train()\n        total_loss = 0\n        for 
it, seeds in enumerate(train_dataloader):\n            sampled_matrices, x, y = multilayer_sample(A, fanouts, seeds, ndata)\n            y_hat = model(sampled_matrices, x)\n            loss = F.cross_entropy(y_hat, y)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n\n        acc = evaluate(model, A, val_dataloader, ndata, dataset.num_classes)\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} \".format(\n                epoch, total_loss / (it + 1), acc.item()\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GraphSAGE\")\n    parser.add_argument(\n        \"--mode\",\n        default=\"gpu\",\n        choices=[\"cpu\", \"gpu\"],\n        help=\"Training mode. 'cpu' for CPU training, 'gpu' for GPU training.\",\n    )\n    args = parser.parse_args()\n    if not torch.cuda.is_available():\n        args.mode = \"cpu\"\n    print(f\"Training in {args.mode} mode.\")\n\n    #####################################################################\n    # (HIGHLIGHT) This example implements a graphSAGE algorithm by sparse\n    # operators, which involves sampling a subgraph from a full graph and\n    # conducting training.\n    #\n    # First, the whole graph is loaded onto the CPU or GPU and transformed\n    # to sparse matrix. To obtain the training subgraph, it samples three\n    # submatrices by seed nodes, which contains their randomly sampled\n    # 1-hop, 2-hop, and 3-hop neighbors. 
Then, the features of the\n    # subgraph are input to the network for training.\n    #####################################################################\n\n    # Load and preprocess dataset.\n    print(\"Loading data\")\n    device = torch.device(\"cpu\" if args.mode == \"cpu\" else \"cuda\")\n    dataset = AsNodePredDataset(DglNodePropPredDataset(\"ogbn-products\"))\n    g = dataset[0]\n    g = g.to(device)\n\n    # Create GraphSAGE model.\n    in_size = g.ndata[\"feat\"].shape[1]\n    out_size = dataset.num_classes\n    model = SAGE(in_size, 256, out_size).to(device)\n\n    # Create sparse.\n    indices = torch.stack(g.edges())\n    N = g.num_nodes()\n    A = dglsp.spmatrix(indices, shape=(N, N))\n\n    # Model training.\n    print(\"Training...\")\n    train(device, A, g.ndata, dataset, model)\n\n    # Test the model.\n    print(\"Testing...\")\n    acc = validate(device, A, g.ndata, dataset, model, batch_size=4096)\n    print(f\"Test accuracy {acc:.4f}\")\n"
  },
  {
    "path": "examples/sparse/sampling/ladies.py",
    "content": "\"\"\"\nThis script demonstrates how to use dgl sparse library to sample on graph and \ntrain model. It trains and tests a LADIES model using the sparse power and \nsp_broadcast_v operators to sample submatrix from the whole matrix.\n\nThis flowchart describes the main functional sequence of the provided example.\nmain\n│\n├───> Load and preprocess full dataset\n│\n├───> Instantiate LADIES model\n│\n├───> train\n│     │\n│     └───> Training loop\n│           │\n│           ├───> Sample submatrix\n│           │\n│           └───> LADIES.forward\n└───> test\n      │\n      ├───> Sample submatrix\n      │\n      └───> Evaluate the model\n\"\"\"\nimport argparse\n\nimport dgl.sparse as dglsp\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchmetrics.functional as MF\nfrom dgl.data import AsNodePredDataset\nfrom dgl.sparse import sp_broadcast_v\nfrom ogb.nodeproppred import DglNodePropPredDataset\n\n\nclass SAGEConv(nn.Module):\n    r\"\"\"LADIES layer from `Layer-Dependent Importance Sampling\n    for Training Deep and Large Graph Convolutional Networks\n    <https://arxiv.org/abs/1911.07323.pdf>`__\"\"\"\n\n    def __init__(\n        self,\n        in_size,\n        out_size,\n    ):\n        super(SAGEConv, self).__init__()\n        self._in_src_feats, self._in_dst_feats = in_size, in_size\n        self._out_size = out_size\n\n        self.fc_neigh = nn.Linear(self._in_src_feats, out_size, bias=False)\n        self.fc_self = nn.Linear(self._in_dst_feats, out_size, bias=True)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)\n        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)\n\n    def forward(self, A, feat):\n        feat_src = feat\n        feat_dst = feat[: A.shape[1]]\n\n        # Aggregator type: mean.\n        srcdata = self.fc_neigh(feat_src)\n        # Divided by 
degree.\n        D_hat = dglsp.diag(A.sum(0)) ** -1\n        A_div = A @ D_hat\n        # Conv neighbors.\n        dstdata = A_div.T @ srcdata\n\n        rst = self.fc_self(feat_dst) + dstdata\n        return rst\n\n\nclass LADIES(nn.Module):\n    def __init__(self, in_size, hid_size, out_size):\n        super().__init__()\n        self.layers = nn.ModuleList()\n        # Three-layer LADIES.\n        self.layers.append(SAGEConv(in_size, hid_size))\n        self.layers.append(SAGEConv(hid_size, hid_size))\n        self.layers.append(SAGEConv(hid_size, out_size))\n\n        self.dropout = nn.Dropout(0.5)\n        self.hid_size = hid_size\n        self.out_size = out_size\n\n    def forward(self, sampled_matrices, x):\n        hidden_x = x\n        for layer_idx, (layer, sampled_matrix) in enumerate(\n            zip(self.layers, sampled_matrices)\n        ):\n            hidden_x = layer(sampled_matrix, hidden_x)\n            if layer_idx != len(self.layers) - 1:\n                hidden_x = F.relu(hidden_x)\n                hidden_x = self.dropout(hidden_x)\n        return hidden_x\n\n\ndef multilayer_sample(A, fanouts, seeds, ndata):\n    sampled_matrices = []\n    src = seeds\n\n    #########################################################################\n    # (HIGHLIGHT) Using the sparse sample operator to preform LADIES sampling\n    # algorithm from the neighboring nodes of the seeds nodes.\n    # The sparse sp_power operator is applied to compute sample probability,\n    # and sp_broadcast_v is then employed to normalize weight by performing\n    # division operations on column.\n    #########################################################################\n\n    for fanout in fanouts:\n        # Sample neighbors.\n        sub_A = A.index_select(1, src)\n        # Compute probability weight.\n        row_probs = (sub_A**2).sum(1)\n        row_probs = row_probs / row_probs.sum(0)\n        # Layer-wise sample nodes.\n        row_ids = 
torch.multinomial(row_probs, fanout, replacement=False)\n        # Add self-loop.\n        row_ids = torch.cat((row_ids, src), 0).unique()\n        sampled_matrix = sub_A.index_select(0, row_ids)\n        # Normalize edge weights.\n        div_matirx = sp_broadcast_v(\n            sampled_matrix, row_probs[row_ids].reshape(-1, 1), \"truediv\"\n        )\n        div_matirx = sp_broadcast_v(div_matirx, div_matirx.sum(0), \"truediv\")\n\n        # Save the sampled matrix.\n        sampled_matrices.insert(0, div_matirx)\n        src = row_ids\n\n    x = ndata[\"feat\"][src]\n    y = ndata[\"label\"][seeds]\n    return sampled_matrices, x, y\n\n\ndef evaluate(model, A, dataloader, ndata, num_classes):\n    model.eval()\n    ys = []\n    y_hats = []\n    fanouts = [4000, 4000, 4000]\n    for seeds in dataloader:\n        with torch.no_grad():\n            sampled_matrices, x, y = multilayer_sample(A, fanouts, seeds, ndata)\n            ys.append(y)\n            y_hats.append(model(sampled_matrices, x))\n\n    return MF.accuracy(\n        torch.cat(y_hats),\n        torch.cat(ys),\n        task=\"multiclass\",\n        num_classes=num_classes,\n    )\n\n\ndef validate(device, A, ndata, dataset, model, batch_size):\n    inf_id = dataset.test_idx.to(device)\n    inf_dataloader = torch.utils.data.DataLoader(inf_id, batch_size=batch_size)\n    acc = evaluate(model, A, inf_dataloader, ndata, dataset.num_classes)\n    return acc\n\n\ndef train(device, A, ndata, dataset, model):\n    # Create sampler & dataloader.\n    train_idx = dataset.train_idx.to(device)\n    val_idx = dataset.val_idx.to(device)\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_idx, batch_size=1024, shuffle=True\n    )\n    val_dataloader = torch.utils.data.DataLoader(val_idx, batch_size=1024)\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)\n\n    fanouts = [4000, 4000, 4000]\n    for epoch in range(20):\n        model.train()\n        total_loss = 
0\n        for it, seeds in enumerate(train_dataloader):\n            sampled_matrices, x, y = multilayer_sample(A, fanouts, seeds, ndata)\n            y_hat = model(sampled_matrices, x)\n            loss = F.cross_entropy(y_hat, y)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n\n        acc = evaluate(model, A, val_dataloader, ndata, dataset.num_classes)\n        print(\n            \"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} \".format(\n                epoch, total_loss / (it + 1), acc.item()\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"LADIESConv\")\n    parser.add_argument(\n        \"--mode\",\n        default=\"gpu\",\n        choices=[\"cpu\", \"gpu\"],\n        help=\"Training mode. 'cpu' for CPU training, 'gpu' for GPU training.\",\n    )\n    args = parser.parse_args()\n    if not torch.cuda.is_available():\n        args.mode = \"cpu\"\n    print(f\"Training in {args.mode} mode.\")\n\n    #####################################################################\n    # (HIGHLIGHT) This example implements a LADIES algorithm by sparse\n    # operators, which involves sampling a subgraph from a full graph and\n    # conducting training.\n    #\n    # First, the whole graph is loaded onto the CPU or GPU and transformed\n    # to sparse matrix. To obtain the training subgraph, it samples three\n    # submatrices by seed nodes, which contains their layer-wise sampled\n    # 1-hop, 2-hop, and 3-hop neighbors. 
Then, the features of the\n    # subgraph are input to the network for training.\n    #####################################################################\n\n    # Load and preprocess dataset.\n    print(\"Loading data\")\n    device = torch.device(\"cpu\" if args.mode == \"cpu\" else \"cuda\")\n    dataset = AsNodePredDataset(DglNodePropPredDataset(\"ogbn-products\"))\n    g = dataset[0]\n\n    # Create LADIES model.\n    in_size = g.ndata[\"feat\"].shape[1]\n    out_size = dataset.num_classes\n    model = LADIES(in_size, 256, out_size).to(device)\n\n    # Create sparse.\n    indices = torch.stack(g.edges())\n    N = g.num_nodes()\n    A = dglsp.spmatrix(indices, shape=(N, N)).coalesce()\n    I = dglsp.identity(A.shape)\n\n    # Initialize laplacian matrix.\n    A_hat = A + I\n    D_hat = dglsp.diag(A_hat.sum(1)) ** -0.5\n    A_norm = D_hat @ A_hat @ D_hat\n    A_norm = A_norm.to(device)\n    g = g.to(device)\n\n    # Model training.\n    print(\"Training...\")\n    train(device, A_norm, g.ndata, dataset, model)\n\n    # Test the model.\n    print(\"Testing...\")\n    acc = validate(device, A_norm, g.ndata, dataset, model, batch_size=2048)\n    print(f\"Test accuracy {acc:.4f}\")\n"
  },
  {
    "path": "examples/sparse/sgc.py",
    "content": "\"\"\"\n[Simplifying Graph Convolutional Networks]\n(https://arxiv.org/abs/1902.07153)\n\"\"\"\n\nimport dgl.sparse as dglsp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data import CoraGraphDataset\nfrom torch.optim import Adam\n\n\n################################################################################\n# (HIGHLIGHT) Take the advantage of DGL sparse APIs to implement the feature\n# pre-computation.\n################################################################################\ndef pre_compute(A, X, k):\n    for _ in range(k):\n        X = A @ X\n    return X\n\n\ndef evaluate(g, pred):\n    label = g.ndata[\"label\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    # Compute accuracy on validation/test set.\n    val_acc = (pred[val_mask] == label[val_mask]).float().mean()\n    test_acc = (pred[test_mask] == label[test_mask]).float().mean()\n    return val_acc, test_acc\n\n\ndef train(model, g, X_sgc):\n    label = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    optimizer = Adam(model.parameters(), lr=2e-1, weight_decay=5e-6)\n\n    for epoch in range(20):\n        # Forward.\n        logits = model(X_sgc)\n\n        # Compute loss with nodes in the training set.\n        loss = F.cross_entropy(logits[train_mask], label[train_mask])\n\n        # Backward.\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # Compute prediction.\n        pred = logits.argmax(dim=1)\n\n        # Evaluate the prediction.\n        val_acc, test_acc = evaluate(g, pred)\n        print(\n            f\"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test\"\n            f\" acc: {test_acc:.3f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    # If CUDA is available, use GPU to accelerate the training, use CPU\n    # otherwise.\n    dev = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n    # Load 
graph from the existing dataset.\n    dataset = CoraGraphDataset()\n    g = dataset[0].to(dev)\n\n    # Create the sparse adjacency matrix A\n    indices = torch.stack(g.edges())\n    N = g.num_nodes()\n    A = dglsp.spmatrix(indices, shape=(N, N))\n\n    # Calculate the symmetrically normalized adjacency matrix.\n    I = dglsp.identity(A.shape, device=dev)\n    A_hat = A + I\n    D_hat = dglsp.diag(A_hat.sum(dim=1)) ** -0.5\n    A_hat = D_hat @ A_hat @ D_hat\n\n    # 2-hop diffusion.\n    k = 2\n    X = g.ndata[\"feat\"]\n    X_sgc = pre_compute(A_hat, X, k)\n\n    # Create model.\n    in_size = X.shape[1]\n    out_size = dataset.num_classes\n    model = nn.Linear(in_size, out_size).to(dev)\n\n    # Kick off training.\n    train(model, g, X_sgc)\n"
  },
  {
    "path": "examples/sparse/sign.py",
    "content": "\"\"\"\n[SIGN: Scalable Inception Graph Neural Networks]\n(https://arxiv.org/abs/2004.11198)\n\nThis example shows a simplified version of SIGN: a precomputed 2-hops diffusion\noperator on top of symmetrically normalized adjacency matrix A_hat.\n\"\"\"\n\nimport dgl.sparse as dglsp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data import CoraGraphDataset\nfrom torch.optim import Adam\n\n\n################################################################################\n# (HIGHLIGHT) Take the advantage of DGL sparse APIs to implement the feature\n# diffusion in SIGN laconically.\n################################################################################\ndef sign_diffusion(A, X, r):\n    # Perform the r-hop diffusion operation.\n    X_sign = [X]\n    for _ in range(r):\n        X = A @ X\n        X_sign.append(X)\n    return X_sign\n\n\nclass SIGN(nn.Module):\n    def __init__(self, in_size, out_size, r, hidden_size=256):\n        super().__init__()\n        # Note that theta and omega refer to the learnable matrices in the\n        # original paper correspondingly. 
The variable r refers to subscript to\n        # theta.\n        self.theta = nn.ModuleList(\n            [nn.Linear(in_size, hidden_size) for _ in range(r + 1)]\n        )\n        self.omega = nn.Linear(hidden_size * (r + 1), out_size)\n\n    def forward(self, X_sign):\n        results = []\n        for i in range(len(X_sign)):\n            results.append(self.theta[i](X_sign[i]))\n        Z = F.relu(torch.cat(results, dim=1))\n        return self.omega(Z)\n\n\ndef evaluate(g, pred):\n    label = g.ndata[\"label\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    # Compute accuracy on validation/test set.\n    val_acc = (pred[val_mask] == label[val_mask]).float().mean()\n    test_acc = (pred[test_mask] == label[test_mask]).float().mean()\n    return val_acc, test_acc\n\n\ndef train(model, g, X_sign):\n    label = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    optimizer = Adam(model.parameters(), lr=3e-3)\n\n    for epoch in range(10):\n        # Switch the model to training mode.\n        model.train()\n\n        # Forward.\n        logits = model(X_sign)\n\n        # Compute loss with nodes in training set.\n        loss = F.cross_entropy(logits[train_mask], label[train_mask])\n\n        # Backward.\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # Switch the model to evaluating mode.\n        model.eval()\n\n        # Compute prediction.\n        logits = model(X_sign)\n        pred = logits.argmax(1)\n\n        # Evaluate the prediction.\n        val_acc, test_acc = evaluate(g, pred)\n        print(\n            f\"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test\"\n            f\" acc: {test_acc:.3f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    # If CUDA is available, use GPU to accelerate the training, use CPU\n    # otherwise.\n    dev = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n    # Load graph from the 
existing dataset.\n    dataset = CoraGraphDataset()\n    g = dataset[0].to(dev)\n\n    # Create the sparse adjacency matrix A (note that W was used as the notation\n    # for adjacency matrix in the original paper).\n    indices = torch.stack(g.edges())\n    N = g.num_nodes()\n    A = dglsp.spmatrix(indices, shape=(N, N))\n\n    # Calculate the symmetrically normalized adjacency matrix.\n    I = dglsp.identity(A.shape, device=dev)\n    A_hat = A + I\n    D_hat = dglsp.diag(A_hat.sum(dim=1)) ** -0.5\n    A_hat = D_hat @ A_hat @ D_hat\n\n    # 2-hop diffusion.\n    r = 2\n    X = g.ndata[\"feat\"]\n    X_sign = sign_diffusion(A_hat, X, r)\n\n    # Create SIGN model.\n    in_size = X.shape[1]\n    out_size = dataset.num_classes\n    model = SIGN(in_size, out_size, r).to(dev)\n\n    # Kick off training.\n    train(model, g, X_sign)\n"
  },
  {
    "path": "examples/sparse/twirls.py",
    "content": "\"\"\"\n[Graph Neural Networks Inspired by Classical Iterative Algorithms]\n(https://arxiv.org/pdf/2103.06064.pdf)\n\nThis example shows a simplified version of the TWIRLS model proposed\nin the paper. It implements two variants. One is the basic iterative\ngraph diffusion algorithm. The other is an advanced implementation\nwith attention.\n\"\"\"\n\nimport argparse\n\nimport dgl.sparse as dglsp\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl.data import CoraGraphDataset\nfrom torch.optim import Adam\n\n\nclass MLP(nn.Module):\n    def __init__(self, in_size, hidden_size):\n        super().__init__()\n        self.linear_1 = nn.Linear(in_size, hidden_size)\n        self.linear_2 = nn.Linear(hidden_size, hidden_size)\n        self.dropout = nn.Dropout(0.8)\n\n    def forward(self, X):\n        H = self.linear_1(X)\n        H = F.relu(H)\n        H = self.dropout(H)\n        H = self.linear_2(H)\n        return H\n\n\n################################################################################\n# (HIGHLIGHT) Use DGL sparse API to implement the iterative graph diffusion\n# algorithm.\n################################################################################\nclass TWIRLS(nn.Module):\n    def __init__(\n        self,\n        in_size,\n        out_size,\n        hidden_size=128,\n        num_steps=16,\n        lam=1.0,\n        alpha=0.5,\n    ):\n        super().__init__()\n        self.num_steps = num_steps\n        self.lam = lam\n        self.alpha = alpha\n        self.mlp = MLP(in_size, hidden_size)\n        self.linear_out = nn.Linear(hidden_size, out_size)\n\n    def forward(self, A, X):\n        # Compute Y = Y0 = f(X; W) using a two-layer MLP.\n        Y = Y0 = self.mlp(X)\n\n        # Compute diagonal matrix D_tild.\n        I = dglsp.identity(A.shape, device=A.device)\n        D_tild = self.lam * dglsp.diag(A.sum(1)) + I\n\n        # Iteratively compute new Y by equation (6) in the paper.\n        
for k in range(self.num_steps):\n            Y_hat = self.lam * A @ Y + Y0\n            # The inverse of a diagonal matrix inverses its diagonal values.\n            Y = (1 - self.alpha) * Y + self.alpha * (D_tild**-1) @ Y_hat\n\n        # Apply a linear layer on the final output.\n        return self.linear_out(Y)\n\n\n################################################################################\n# (HIGHLIGHT) Implementation of the advanced TWIRLS model with attention\n# to show the usage of differentiable weighted sparse matrix.\n################################################################################\nclass TWIRLSWithAttention(nn.Module):\n    def __init__(\n        self,\n        in_size,\n        out_size,\n        hidden_size=128,\n        num_steps=16,\n        lam=1.0,\n        alpha=0.5,\n    ):\n        super().__init__()\n        self.num_steps = num_steps\n        self.lam = lam\n        self.alpha = alpha\n        self.mlp = MLP(in_size, hidden_size)\n        self.linear_out = nn.Linear(hidden_size, out_size)\n\n    def forward(self, A, X):\n        # Compute Y = Y0 = f(X; W) using a two-layer MLP.\n        Y = Y0 = self.mlp(X)\n\n        # Compute diagonal matrix D_tild.\n        I = dglsp.identity(A.shape, device=A.device)\n        D_tild = self.lam * dglsp.diag(A.sum(1)) + I\n\n        # Conduct half of the diffusion steps.\n        for k in range(self.num_steps // 2):\n            Y_hat = self.lam * A @ Y + Y0\n            Y = (1 - self.alpha) * Y + self.alpha * (D_tild**-1) @ Y_hat\n\n        # Calculate attention weight by equation (25) in the paper.\n        Y_i = Y[A.row]\n        Y_j = Y[A.col]\n        norm_ij = torch.linalg.vector_norm(Y_i - Y_j, dim=1)\n        # Bound the attention value within [0.0, 1.0).\n        gamma_ij = torch.clamp(0.5 / (norm_ij + 1e-7), min=0.0, max=1.0)\n        # Create a new adjacency matrix with the new weight.\n        A = dglsp.val_like(A, gamma_ij)\n        # Recompute D_tild.\n        D_tild = 
self.lam * dglsp.diag(A.sum(1)) + I\n\n        # Conduct the other half of the diffusion steps.\n        for k in range(self.num_steps // 2):\n            Y_hat = self.lam * A @ Y + Y0\n            Y = (1 - self.alpha) * Y + self.alpha * (D_tild**-1) @ Y_hat\n\n        # Apply a linear layer on the final output.\n        return self.linear_out(Y)\n\n\ndef evaluate(g, pred):\n    model.eval()\n    label = g.ndata[\"label\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    # Compute accuracy on validation/test set.\n    val_acc = (pred[val_mask] == label[val_mask]).float().mean()\n    test_acc = (pred[test_mask] == label[test_mask]).float().mean()\n    return val_acc, test_acc\n\n\ndef train(g, model, A, X):\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    optimizer = Adam(model.parameters(), lr=5e-4)\n\n    for epoch in range(300):\n        model.train()\n        # Forward.\n        logits = model(A, X)\n\n        # Compute loss with nodes in training set.\n        loss = F.cross_entropy(logits[train_mask], labels[train_mask])\n\n        # Backward.\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # Compute prediction.\n        pred = logits.argmax(1)\n\n        # Evaluate the prediction.\n        val_acc, test_acc = evaluate(g, pred)\n        print(\n            f\"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test\"\n            f\" acc: {test_acc:.3f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"TWIRLS example in DGL Sparse.\")\n    parser.add_argument(\n        \"--attention\", action=\"store_true\", help=\"Use TWIRLS with attention.\"\n    )\n    args = parser.parse_args()\n    # If CUDA is available, use GPU to accelerate the training, use CPU\n    # otherwise.\n    dev = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n    # Load graph from the existing dataset.\n    dataset 
= CoraGraphDataset()\n    g = dataset[0].to(dev)\n    X = g.ndata[\"feat\"]\n\n    # Create the sparse adjacency matrix A.\n    indices = torch.stack(g.edges())\n    N = g.num_nodes()\n    A = dglsp.spmatrix(indices, shape=(N, N))\n\n    # Create the TWIRLS model.\n    in_size = X.shape[1]\n    out_size = dataset.num_classes\n    if args.attention:\n        model = TWIRLSWithAttention(in_size, out_size).to(dev)\n    else:\n        model = TWIRLS(in_size, out_size).to(dev)\n\n    # Kick off training.\n    train(g, model, A, X)\n"
  },
  {
    "path": "examples/tensorflow/dgi/README.md",
    "content": "Deep Graph Infomax (DGI)\n========================\n\n- Paper link: [https://arxiv.org/abs/1809.10341](https://arxiv.org/abs/1809.10341)\n- Author's code repo (in Pytorch):\n  [https://github.com/PetarV-/DGI](https://github.com/PetarV-/DGI)\n\nDependencies\n------------\n- tensorflow 2.1+\n- requests\n\n```bash\npip install tensorflow requests\n```\n\nHow to run\n----------\n\nRun with following:\n\n```bash\npython3 train.py --dataset=cora --gpu=0 --self-loop\n```\n\n```bash\npython3 train.py --dataset=citeseer --gpu=0\n```\n\n```bash\npython3 train.py --dataset=pubmed --gpu=0\n```\n\nResults\n-------\n* cora: ~81.6 (80.9-82.9) (paper: 82.3)\n* citeseer: ~70.2 (paper: 71.8)\n* pubmed: ~77.2 (paper: 76.8)\n"
  },
  {
    "path": "examples/tensorflow/dgi/dgi.py",
    "content": "\"\"\"\nDeep Graph Infomax in DGL\n\nReferences\n----------\nPapers: https://arxiv.org/abs/1809.10341\nAuthor's code: https://github.com/PetarV-/DGI\n\"\"\"\n\nimport math\n\nimport numpy as np\nimport tensorflow as tf\nfrom gcn import GCN\nfrom tensorflow.keras import layers\n\n\nclass Encoder(layers.Layer):\n    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):\n        super(Encoder, self).__init__()\n        self.g = g\n        self.conv = GCN(\n            g, in_feats, n_hidden, n_hidden, n_layers, activation, dropout\n        )\n\n    def call(self, features, corrupt=False):\n        if corrupt:\n            perm = np.random.permutation(self.g.number_of_nodes())\n            features = tf.gather(features, perm)\n        features = self.conv(features)\n        return features\n\n\nclass Discriminator(layers.Layer):\n    def __init__(self, n_hidden):\n        super(Discriminator, self).__init__()\n        uinit = tf.keras.initializers.RandomUniform(\n            -1.0 / math.sqrt(n_hidden), 1.0 / math.sqrt(n_hidden)\n        )\n        self.weight = tf.Variable(\n            initial_value=uinit(shape=(n_hidden, n_hidden), dtype=\"float32\"),\n            trainable=True,\n        )\n\n    def call(self, features, summary):\n        features = tf.matmul(\n            features, tf.matmul(self.weight, tf.expand_dims(summary, -1))\n        )\n        return features\n\n\nclass DGI(tf.keras.Model):\n    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):\n        super(DGI, self).__init__()\n        self.encoder = Encoder(\n            g, in_feats, n_hidden, n_layers, activation, dropout\n        )\n        self.discriminator = Discriminator(n_hidden)\n        self.loss = tf.nn.sigmoid_cross_entropy_with_logits\n\n    def call(self, features):\n        positive = self.encoder(features, corrupt=False)\n        negative = self.encoder(features, corrupt=True)\n        summary = 
tf.nn.sigmoid(tf.reduce_mean(positive, axis=0))\n\n        positive = self.discriminator(positive, summary)\n        negative = self.discriminator(negative, summary)\n\n        l1 = self.loss(tf.ones(positive.shape), positive)\n        l2 = self.loss(tf.zeros(negative.shape), negative)\n\n        return tf.reduce_mean(l1) + tf.reduce_mean(l2)\n\n\nclass Classifier(layers.Layer):\n    def __init__(self, n_hidden, n_classes):\n        super(Classifier, self).__init__()\n        self.fc = layers.Dense(n_classes)\n\n    def call(self, features):\n        features = self.fc(features)\n        return features\n"
  },
  {
    "path": "examples/tensorflow/dgi/gcn.py",
    "content": "\"\"\"\nThis code was copied from the GCN implementation in DGL examples.\n\"\"\"\nimport tensorflow as tf\n\nfrom dgl.nn.tensorflow import GraphConv\nfrom tensorflow.keras import layers\n\n\nclass GCN(layers.Layer):\n    def __init__(\n        self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(GCN, self).__init__()\n        self.g = g\n        self.layers = []\n        # input layer\n        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(\n                GraphConv(n_hidden, n_hidden, activation=activation)\n            )\n        # output layer\n        self.layers.append(GraphConv(n_hidden, n_classes))\n        self.dropout = layers.Dropout(dropout)\n\n    def call(self, features):\n        h = features\n        for i, layer in enumerate(self.layers):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(self.g, h)\n        return h\n"
  },
  {
    "path": "examples/tensorflow/dgi/train.py",
    "content": "import argparse\nimport time\n\nimport dgl\n\nimport networkx as nx\nimport numpy as np\nimport tensorflow as tf\nfrom dgi import Classifier, DGI\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom tensorflow.keras import layers\n\n\ndef evaluate(model, features, labels, mask):\n    logits = model(features, training=False)\n    logits = logits[mask]\n    labels = labels[mask]\n    indices = tf.math.argmax(logits, axis=1)\n    acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))\n    return acc.numpy().item()\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        device = \"/cpu:0\"\n    else:\n        device = \"/gpu:{}\".format(args.gpu)\n        g = g.to(device)\n\n    with tf.device(device):\n        features = g.ndata[\"feat\"]\n        labels = g.ndata[\"label\"]\n        train_mask = g.ndata[\"train_mask\"]\n        val_mask = g.ndata[\"val_mask\"]\n        test_mask = g.ndata[\"test_mask\"]\n        in_feats = features.shape[1]\n        n_classes = data.num_classes\n        n_edges = g.number_of_edges()\n\n        # add self loop\n        if args.self_loop:\n            g = dgl.remove_self_loop(g)\n            g = dgl.add_self_loop(g)\n        n_edges = g.number_of_edges()\n\n        # create DGI model\n        dgi = DGI(\n            g,\n            in_feats,\n            args.n_hidden,\n            args.n_layers,\n            tf.keras.layers.PReLU(\n                alpha_initializer=tf.constant_initializer(0.25)\n            ),\n            args.dropout,\n        )\n\n        dgi_optimizer = 
tf.keras.optimizers.Adam(learning_rate=args.dgi_lr)\n\n        # train deep graph infomax\n        cnt_wait = 0\n        best = 1e9\n        best_t = 0\n        dur = []\n        for epoch in range(args.n_dgi_epochs):\n            if epoch >= 3:\n                t0 = time.time()\n\n            with tf.GradientTape() as tape:\n                loss = dgi(features)\n                # Manually Weight Decay\n                # We found Tensorflow has a different implementation on weight decay\n                # of Adam(W) optimizer with PyTorch. And this results in worse results.\n                # Manually adding weights to the loss to do weight decay solves this problem.\n                for weight in dgi.trainable_weights:\n                    loss = loss + args.weight_decay * tf.nn.l2_loss(weight)\n                grads = tape.gradient(loss, dgi.trainable_weights)\n                dgi_optimizer.apply_gradients(zip(grads, dgi.trainable_weights))\n\n            if loss < best:\n                best = loss\n                best_t = epoch\n                cnt_wait = 0\n                dgi.save_weights(\"best_dgi.pkl\")\n            else:\n                cnt_wait += 1\n\n            if cnt_wait == args.patience:\n                print(\"Early stopping!\")\n                break\n\n            if epoch >= 3:\n                dur.append(time.time() - t0)\n\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss.numpy().item(),\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n        # create classifier model\n        classifier = Classifier(args.n_hidden, n_classes)\n\n        classifier_optimizer = tf.keras.optimizers.Adam(\n            learning_rate=args.classifier_lr\n        )\n\n        # train classifier\n        print(\"Loading {}th epoch\".format(best_t))\n  
      dgi.load_weights(\"best_dgi.pkl\")\n        embeds = dgi.encoder(features, corrupt=False)\n        embeds = tf.stop_gradient(embeds)\n        dur = []\n        loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(\n            from_logits=True\n        )\n        for epoch in range(args.n_classifier_epochs):\n            if epoch >= 3:\n                t0 = time.time()\n            with tf.GradientTape() as tape:\n                preds = classifier(embeds)\n                loss = loss_fcn(labels[train_mask], preds[train_mask])\n                # Manually Weight Decay\n                # We found Tensorflow has a different implementation on weight decay\n                # of Adam(W) optimizer with PyTorch. And this results in worse results.\n                # Manually adding weights to the loss to do weight decay solves this problem.\n                # In original code, there's no weight decay applied in this part\n                # link: https://github.com/PetarV-/DGI/blob/master/execute.py#L121\n                # for weight in classifier.trainable_weights:\n                #     loss = loss + \\\n                #         args.weight_decay * tf.nn.l2_loss(weight)\n                grads = tape.gradient(loss, classifier.trainable_weights)\n                classifier_optimizer.apply_gradients(\n                    zip(grads, classifier.trainable_weights)\n                )\n            if epoch >= 3:\n                dur.append(time.time() - t0)\n\n            acc = evaluate(classifier, embeds, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss.numpy().item(),\n                    acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n        print()\n        acc = evaluate(classifier, embeds, labels, 
test_mask)\n        print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"DGI\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.0, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\n        \"--dgi-lr\", type=float, default=1e-3, help=\"dgi learning rate\"\n    )\n    parser.add_argument(\n        \"--classifier-lr\",\n        type=float,\n        default=1e-2,\n        help=\"classifier learning rate\",\n    )\n    parser.add_argument(\n        \"--n-dgi-epochs\",\n        type=int,\n        default=300,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"--n-classifier-epochs\",\n        type=int,\n        default=300,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=512, help=\"number of hidden gcn units\"\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=1, help=\"number of hidden gcn layers\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=0.0, help=\"Weight for L2 loss\"\n    )\n    parser.add_argument(\n        \"--patience\", type=int, default=20, help=\"early stop patience condition\"\n    )\n    parser.add_argument(\n        \"--self-loop\",\n        action=\"store_true\",\n        help=\"graph self-loop (default=False)\",\n    )\n    parser.set_defaults(self_loop=False)\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/tensorflow/gat/README.md",
    "content": "Graph Attention Networks (GAT)\n============\n\n- Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903)\n- Author's code repo (in Tensorflow):\n  [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).\n- Popular pytorch implementation:\n  [https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).\n\nDependencies\n------------\n- tensorflow 2.1.0+\n- requests\n\n```bash\npip install tensorflow requests\nDGLBACKEND=tensorflow\n```\n\nHow to run\n----------\n\nRun with following:\n\n```bash\npython3 train.py --dataset=cora --gpu=0\n```\n\n```bash\npython3 train.py --dataset=citeseer --gpu=0 --early-stop\n```\n\n```bash\npython3 train.py --dataset=pubmed --gpu=0 --num-out-heads=8 --weight-decay=0.001 --early-stop\n```\n\n\nResults\n-------\n\n| Dataset  | Test Accuracy | Baseline (paper) |\n| -------- | ------------- | ---------------- |\n| Cora     | 84.2          | 83.0(+-0.7)      |\n| Citeseer | 70.9          | 72.5(+-0.7)      |\n| Pubmed   | 78.5          | 79.0(+-0.3)      |\n\n* All the accuracy numbers are obtained after 200 epochs.\n* All time is measured on EC2 p3.2xlarge instance w/ V100 GPU.\n"
  },
  {
    "path": "examples/tensorflow/gat/gat.py",
    "content": "\"\"\"\nGraph Attention Networks in DGL using SPMV optimization.\nReferences\n----------\nPaper: https://arxiv.org/abs/1710.10903\nAuthor's code: https://github.com/PetarV-/GAT\nPytorch implementation: https://github.com/Diego999/pyGAT\n\"\"\"\n\nimport dgl.function as fn\nimport tensorflow as tf\nfrom dgl.nn import GATConv\nfrom tensorflow.keras import layers\n\n\nclass GAT(tf.keras.Model):\n    def __init__(\n        self,\n        g,\n        num_layers,\n        in_dim,\n        num_hidden,\n        num_classes,\n        heads,\n        activation,\n        feat_drop,\n        attn_drop,\n        negative_slope,\n        residual,\n    ):\n        super(GAT, self).__init__()\n        self.g = g\n        self.num_layers = num_layers\n        self.gat_layers = []\n        self.activation = activation\n        # input projection (no residual)\n        self.gat_layers.append(\n            GATConv(\n                in_dim,\n                num_hidden,\n                heads[0],\n                feat_drop,\n                attn_drop,\n                negative_slope,\n                False,\n                self.activation,\n            )\n        )\n        # hidden layers\n        for l in range(1, num_layers):\n            # due to multi-head, the in_dim = num_hidden * num_heads\n            self.gat_layers.append(\n                GATConv(\n                    num_hidden * heads[l - 1],\n                    num_hidden,\n                    heads[l],\n                    feat_drop,\n                    attn_drop,\n                    negative_slope,\n                    residual,\n                    self.activation,\n                )\n            )\n        # output projection\n        self.gat_layers.append(\n            GATConv(\n                num_hidden * heads[-2],\n                num_classes,\n                heads[-1],\n                feat_drop,\n                attn_drop,\n                negative_slope,\n                residual,\n     
           None,\n            )\n        )\n\n    def call(self, inputs):\n        h = inputs\n        for l in range(self.num_layers):\n            h = self.gat_layers[l](self.g, h)\n            h = tf.reshape(h, (h.shape[0], -1))\n        # output projection\n        logits = tf.reduce_mean(self.gat_layers[-1](self.g, h), axis=1)\n        return logits\n"
  },
  {
    "path": "examples/tensorflow/gat/train.py",
    "content": "\"\"\"\nGraph Attention Networks in DGL using SPMV optimization.\nMultiple heads are also batched together for faster training.\nCompared with the original paper, this code does not implement\nearly stopping.\nReferences\n----------\nPaper: https://arxiv.org/abs/1710.10903\nAuthor's code: https://github.com/PetarV-/GAT\nPytorch implementation: https://github.com/Diego999/pyGAT\n\"\"\"\n\nimport argparse\nimport time\n\nimport dgl\n\nimport networkx as nx\nimport numpy as np\nimport tensorflow as tf\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom gat import GAT\nfrom utils import EarlyStopping\n\n\ndef accuracy(logits, labels):\n    indices = tf.math.argmax(logits, axis=1)\n    acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))\n    return acc.numpy().item()\n\n\ndef evaluate(model, features, labels, mask):\n    logits = model(features, training=False)\n    logits = logits[mask]\n    labels = labels[mask]\n    return accuracy(logits, labels)\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        device = \"/cpu:0\"\n    else:\n        device = \"/gpu:{}\".format(args.gpu)\n        g = g.to(device)\n\n    with tf.device(device):\n        features = g.ndata[\"feat\"]\n        labels = g.ndata[\"label\"]\n        train_mask = g.ndata[\"train_mask\"]\n        val_mask = g.ndata[\"val_mask\"]\n        test_mask = g.ndata[\"test_mask\"]\n        num_feats = features.shape[1]\n        n_classes = data.num_classes\n        n_edges = g.number_of_edges()\n        print(\n            \"\"\"----Data 
statistics------'\n        #Edges %d\n        #Classes %d\n        #Train samples %d\n        #Val samples %d\n        #Test samples %d\"\"\"\n            % (\n                n_edges,\n                n_classes,\n                train_mask.numpy().sum(),\n                val_mask.numpy().sum(),\n                test_mask.numpy().sum(),\n            )\n        )\n\n        g = dgl.remove_self_loop(g)\n        g = dgl.add_self_loop(g)\n        n_edges = g.number_of_edges()\n        # create model\n        heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]\n        model = GAT(\n            g,\n            args.num_layers,\n            num_feats,\n            args.num_hidden,\n            n_classes,\n            heads,\n            tf.nn.elu,\n            args.in_drop,\n            args.attn_drop,\n            args.negative_slope,\n            args.residual,\n        )\n        print(model)\n        if args.early_stop:\n            stopper = EarlyStopping(patience=100)\n\n        # loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(\n        #     from_logits=False)\n        loss_fcn = tf.nn.sparse_softmax_cross_entropy_with_logits\n\n        # use optimizer\n        optimizer = tf.keras.optimizers.Adam(\n            learning_rate=args.lr, epsilon=1e-8\n        )\n\n        # initialize graph\n        dur = []\n        for epoch in range(args.epochs):\n            if epoch >= 3:\n                t0 = time.time()\n            # forward\n            with tf.GradientTape() as tape:\n                tape.watch(model.trainable_weights)\n                logits = model(features, training=True)\n                loss_value = tf.reduce_mean(\n                    loss_fcn(\n                        labels=labels[train_mask], logits=logits[train_mask]\n                    )\n                )\n                # Manually Weight Decay\n                # We found Tensorflow has a different implementation on weight decay\n                # of Adam(W) optimizer 
with PyTorch. And this results in worse results.\n                # Manually adding weights to the loss to do weight decay solves this problem.\n                for weight in model.trainable_weights:\n                    loss_value = loss_value + args.weight_decay * tf.nn.l2_loss(\n                        weight\n                    )\n\n                grads = tape.gradient(loss_value, model.trainable_weights)\n                optimizer.apply_gradients(zip(grads, model.trainable_weights))\n\n            if epoch >= 3:\n                dur.append(time.time() - t0)\n\n            train_acc = accuracy(logits[train_mask], labels[train_mask])\n\n            if args.fastmode:\n                val_acc = accuracy(logits[val_mask], labels[val_mask])\n            else:\n                val_acc = evaluate(model, features, labels, val_mask)\n                if args.early_stop:\n                    if stopper.step(val_acc, model):\n                        break\n\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |\"\n                \" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss_value.numpy().item(),\n                    train_acc,\n                    val_acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n        print()\n        if args.early_stop:\n            model.load_weights(\"es_checkpoint.pb\")\n        acc = evaluate(model, features, labels, test_mask)\n        print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GAT\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=-1,\n        help=\"which GPU to use. 
Set -1 to use CPU.\",\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--num-heads\",\n        type=int,\n        default=8,\n        help=\"number of hidden attention heads\",\n    )\n    parser.add_argument(\n        \"--num-out-heads\",\n        type=int,\n        default=1,\n        help=\"number of output attention heads\",\n    )\n    parser.add_argument(\n        \"--num-layers\", type=int, default=1, help=\"number of hidden layers\"\n    )\n    parser.add_argument(\n        \"--num-hidden\", type=int, default=8, help=\"number of hidden units\"\n    )\n    parser.add_argument(\n        \"--residual\",\n        action=\"store_true\",\n        default=False,\n        help=\"use residual connection\",\n    )\n    parser.add_argument(\n        \"--in-drop\", type=float, default=0.6, help=\"input feature dropout\"\n    )\n    parser.add_argument(\n        \"--attn-drop\", type=float, default=0.6, help=\"attention dropout\"\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.005, help=\"learning rate\")\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"weight decay\"\n    )\n    parser.add_argument(\n        \"--negative-slope\",\n        type=float,\n        default=0.2,\n        help=\"the negative slope of leaky relu\",\n    )\n    parser.add_argument(\n        \"--early-stop\",\n        action=\"store_true\",\n        default=False,\n        help=\"indicates whether to use early stop or not\",\n    )\n    parser.add_argument(\n        \"--fastmode\",\n        action=\"store_true\",\n        default=False,\n        help=\"skip re-evaluate the validation set\",\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/tensorflow/gat/utils.py",
    "content": "import numpy as np\n\n\nclass EarlyStopping:\n    def __init__(self, patience=10):\n        self.patience = patience\n        self.counter = 0\n        self.best_score = None\n        self.early_stop = False\n\n    def step(self, acc, model):\n        score = acc\n        if self.best_score is None:\n            self.best_score = score\n            self.save_checkpoint(model)\n        elif score < self.best_score:\n            self.counter += 1\n            print(\n                f\"EarlyStopping counter: {self.counter} out of {self.patience}\"\n            )\n            if self.counter >= self.patience:\n                self.early_stop = True\n        else:\n            self.best_score = score\n            self.save_checkpoint(model)\n            self.counter = 0\n        return self.early_stop\n\n    def save_checkpoint(self, model):\n        \"\"\"Saves model when validation loss decrease.\"\"\"\n        model.save_weights(\"es_checkpoint.pb\")\n"
  },
  {
    "path": "examples/tensorflow/gcn/README.md",
    "content": "Graph Convolutional Networks (GCN)\n============\n\n- Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)\n- Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn). Note that the original code is\nimplemented with Tensorflow for the paper.\n\nDependencies\n------------\n- Tensorflow 2.1+\n- requests\n\n``bash\npip install tensorflow requests\nexport DGLBACKEND=tensorflow\n``\n\nCodes\n-----\nThe folder contains three implementations of GCN:\n- `gcn.py` uses DGL's predefined graph convolution module.\n- `gcn_mp.py` uses user-defined message and reduce functions.\n- `gcn_builtin.py` improves from `gcn_mp.py` by using DGL's builtin functions\n   so SPMV optimization could be applied.\n\nResults\n-------\n\nRun with following (available dataset: \"cora\", \"citeseer\", \"pubmed\")\n```bash\npython3 train.py --dataset cora --gpu 0 --self-loop\n```\n\n* cora: ~0.810 (0.79-0.83) (paper: 0.815)\n* citeseer: 0.707 (paper: 0.703)\n* pubmed: 0.792 (paper: 0.790)\n"
  },
  {
    "path": "examples/tensorflow/gcn/gcn.py",
    "content": "\"\"\"GCN using DGL nn package\n\nReferences:\n- Semi-Supervised Classification with Graph Convolutional Networks\n- Paper: https://arxiv.org/abs/1609.02907\n- Code: https://github.com/tkipf/gcn\n\"\"\"\nimport tensorflow as tf\n\nfrom dgl.nn.tensorflow import GraphConv\nfrom tensorflow.keras import layers\n\n\nclass GCN(tf.keras.Model):\n    def __init__(\n        self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(GCN, self).__init__()\n        self.g = g\n        self.layer_list = []\n        # input layer\n        self.layer_list.append(\n            GraphConv(in_feats, n_hidden, activation=activation)\n        )\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layer_list.append(\n                GraphConv(n_hidden, n_hidden, activation=activation)\n            )\n        # output layer\n        self.layer_list.append(GraphConv(n_hidden, n_classes))\n        self.dropout = layers.Dropout(dropout)\n\n    def call(self, features):\n        h = features\n        for i, layer in enumerate(self.layer_list):\n            if i != 0:\n                h = self.dropout(h)\n            h = layer(self.g, h)\n        return h\n"
  },
  {
    "path": "examples/tensorflow/gcn/gcn_builtin.py",
    "content": "import argparse\nimport math\nimport time\n\nimport dgl\nimport dgl.function as fn\nimport networkx as nx\nimport numpy as np\nimport tensorflow as tf\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom tensorflow.keras import layers\n\n\nclass GCNLayer(layers.Layer):\n    def __init__(self, g, in_feats, out_feats, activation, dropout, bias=True):\n        super(GCNLayer, self).__init__()\n        self.g = g\n\n        w_init = tf.keras.initializers.VarianceScaling(\n            scale=1.0, mode=\"fan_out\", distribution=\"uniform\"\n        )\n        self.weight = tf.Variable(\n            initial_value=w_init(shape=(in_feats, out_feats), dtype=\"float32\"),\n            trainable=True,\n        )\n        if dropout:\n            self.dropout = layers.Dropout(rate=dropout)\n        else:\n            self.dropout = 0.0\n        if bias:\n            b_init = tf.zeros_initializer()\n            self.bias = tf.Variable(\n                initial_value=b_init(shape=(out_feats,), dtype=\"float32\"),\n                trainable=True,\n            )\n        else:\n            self.bias = None\n        self.activation = activation\n\n    def call(self, h):\n        if self.dropout:\n            h = self.dropout(h)\n        self.g.ndata[\"h\"] = tf.matmul(h, self.weight)\n        self.g.ndata[\"norm_h\"] = self.g.ndata[\"h\"] * self.g.ndata[\"norm\"]\n        self.g.update_all(fn.copy_u(\"norm_h\", \"m\"), fn.sum(\"m\", \"h\"))\n        h = self.g.ndata[\"h\"]\n        if self.bias is not None:\n            h = h + self.bias\n        if self.activation:\n            h = self.activation(h)\n        return h\n\n\nclass GCN(layers.Layer):\n    def __init__(\n        self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(GCN, self).__init__()\n        self.layers = []\n\n        # input layer\n        self.layers.append(GCNLayer(g, in_feats, 
n_hidden, activation, dropout))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(\n                GCNLayer(g, n_hidden, n_hidden, activation, dropout)\n            )\n        # output layer\n        self.layers.append(GCNLayer(g, n_hidden, n_classes, None, dropout))\n\n    def call(self, features):\n        h = features\n        for layer in self.layers:\n            h = layer(h)\n        return h\n\n\ndef evaluate(model, features, labels, mask):\n    logits = model(features, training=False)\n    logits = logits[mask]\n    labels = labels[mask]\n    indices = tf.math.argmax(logits, axis=1)\n    acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))\n    return acc.numpy().item()\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        device = \"/cpu:0\"\n    else:\n        device = \"/gpu:{}\".format(args.gpu)\n        g = g.to(device)\n\n    with tf.device(device):\n        features = g.ndata[\"feat\"]\n        labels = g.ndata[\"label\"]\n        train_mask = g.ndata[\"train_mask\"]\n        val_mask = g.ndata[\"val_mask\"]\n        test_mask = g.ndata[\"test_mask\"]\n        in_feats = features.shape[1]\n        n_classes = data.num_classes\n        n_edges = data.graph.number_of_edges()\n        print(\n            \"\"\"----Data statistics------'\n        #Edges %d\n        #Classes %d\n        #Train samples %d\n        #Val samples %d\n        #Test samples %d\"\"\"\n            % (\n                n_edges,\n                n_classes,\n                train_mask.numpy().sum(),\n                val_mask.numpy().sum(),\n                
test_mask.numpy().sum(),\n            )\n        )\n\n        # add self loop\n        g = dgl.remove_self_loop(g)\n        g = dgl.add_self_loop(g)\n        n_edges = g.number_of_edges()\n        # # normalization\n        degs = tf.cast(tf.identity(g.in_degrees()), dtype=tf.float32)\n        norm = tf.math.pow(degs, -0.5)\n        norm = tf.where(tf.math.is_inf(norm), tf.zeros_like(norm), norm)\n\n        g.ndata[\"norm\"] = tf.expand_dims(norm, -1)\n\n        # create GCN model\n        model = GCN(\n            g,\n            in_feats,\n            args.n_hidden,\n            n_classes,\n            args.n_layers,\n            tf.nn.relu,\n            args.dropout,\n        )\n\n        optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)\n\n        loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(\n            from_logits=True\n        )\n        # initialize graph\n        dur = []\n        for epoch in range(args.n_epochs):\n            if epoch >= 3:\n                t0 = time.time()\n            # forward\n            with tf.GradientTape() as tape:\n                logits = model(features)\n                loss_value = loss_fcn(labels[train_mask], logits[train_mask])\n                # Manually Weight Decay\n                # We found Tensorflow has a different implementation on weight decay\n                # of Adam(W) optimizer with PyTorch. 
And this results in worse results.\n                # Manually adding weights to the loss to do weight decay solves this problem.\n                for weight in model.trainable_weights:\n                    loss_value = loss_value + args.weight_decay * tf.nn.l2_loss(\n                        weight\n                    )\n\n                grads = tape.gradient(loss_value, model.trainable_weights)\n                optimizer.apply_gradients(zip(grads, model.trainable_weights))\n\n            if epoch >= 3:\n                dur.append(time.time() - t0)\n\n            acc = evaluate(model, features, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss_value.numpy().item(),\n                    acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n        acc = evaluate(model, features, labels, test_mask)\n        print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GCN\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden gcn units\"\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=1, help=\"number of hidden gcn layers\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 loss\"\n    )\n    args = 
parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/tensorflow/gcn/gcn_mp.py",
    "content": "import argparse\nimport math\nimport time\n\nimport dgl\n\nimport networkx as nx\nimport numpy as np\nimport tensorflow as tf\nfrom dgl.data import (\n    CiteseerGraphDataset,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n    register_data_args,\n)\nfrom tensorflow.keras import layers\n\n\ndef gcn_msg(edge):\n    msg = edge.src[\"h\"] * edge.src[\"norm\"]\n    return {\"m\": msg}\n\n\ndef gcn_reduce(node):\n    accum = tf.reduce_sum(node.mailbox[\"m\"], 1) * node.data[\"norm\"]\n    return {\"h\": accum}\n\n\nclass GCNLayer(layers.Layer):\n    def __init__(self, g, in_feats, out_feats, activation, dropout, bias=True):\n        super(GCNLayer, self).__init__()\n        self.g = g\n\n        w_init = tf.random_normal_initializer()\n        self.weight = tf.Variable(\n            initial_value=w_init(shape=(in_feats, out_feats), dtype=\"float32\"),\n            trainable=True,\n        )\n        if dropout:\n            self.dropout = layers.Dropout(rate=dropout)\n        else:\n            self.dropout = 0.0\n        if bias:\n            b_init = tf.zeros_initializer()\n            self.bias = tf.Variable(\n                initial_value=b_init(shape=(out_feats,), dtype=\"float32\"),\n                trainable=True,\n            )\n        else:\n            self.bias = None\n        self.activation = activation\n\n    def call(self, h):\n        if self.dropout:\n            h = self.dropout(h)\n        self.g.ndata[\"h\"] = tf.matmul(h, self.weight)\n        self.g.update_all(gcn_msg, gcn_reduce)\n        h = self.g.ndata[\"h\"]\n        if self.bias is not None:\n            h = h + self.bias\n        if self.activation:\n            h = self.activation(h)\n        return h\n\n\nclass GCN(layers.Layer):\n    def __init__(\n        self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout\n    ):\n        super(GCN, self).__init__()\n        self.layers = []\n\n        # input layer\n        self.layers.append(GCNLayer(g, in_feats, 
n_hidden, activation, dropout))\n        # hidden layers\n        for i in range(n_layers - 1):\n            self.layers.append(\n                GCNLayer(g, n_hidden, n_hidden, activation, dropout)\n            )\n        # output layer\n        self.layers.append(GCNLayer(g, n_hidden, n_classes, None, dropout))\n\n    def call(self, features):\n        h = features\n        for layer in self.layers:\n            h = layer(h)\n        return h\n\n\ndef evaluate(model, features, labels, mask):\n    logits = model(features, training=False)\n    logits = logits[mask]\n    labels = labels[mask]\n    indices = tf.math.argmax(logits, axis=1)\n    acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))\n    return acc.numpy().item()\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        device = \"/cpu:0\"\n    else:\n        device = \"/gpu:{}\".format(args.gpu)\n        g = g.to(device)\n\n    with tf.device(device):\n        features = g.ndata[\"feat\"]\n        labels = g.ndata[\"label\"]\n        train_mask = g.ndata[\"train_mask\"]\n        val_mask = g.ndata[\"val_mask\"]\n        test_mask = g.ndata[\"test_mask\"]\n        in_feats = features.shape[1]\n        n_classes = data.num_classes\n        n_edges = data.graph.number_of_edges()\n        print(\n            \"\"\"----Data statistics------'\n        #Edges %d\n        #Classes %d\n        #Train samples %d\n        #Val samples %d\n        #Test samples %d\"\"\"\n            % (\n                n_edges,\n                n_classes,\n                train_mask.numpy().sum(),\n                val_mask.numpy().sum(),\n                
test_mask.numpy().sum(),\n            )\n        )\n\n        # add self loop\n        if args.self_loop:\n            g = dgl.remove_self_loop(g)\n            g = dgl.add_self_loop(g)\n        n_edges = g.number_of_edges()\n        n_edges = g.number_of_edges()\n        # # normalization\n        degs = tf.cast(tf.identity(g.in_degrees()), dtype=tf.float32)\n        norm = tf.math.pow(degs, -0.5)\n        norm = tf.where(tf.math.is_inf(norm), tf.zeros_like(norm), norm)\n\n        g.ndata[\"norm\"] = tf.expand_dims(norm, -1)\n\n        # create GCN model\n        model = GCN(\n            g,\n            in_feats,\n            args.n_hidden,\n            n_classes,\n            args.n_layers,\n            tf.nn.relu,\n            args.dropout,\n        )\n\n        optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)\n\n        loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(\n            from_logits=True\n        )\n        # initialize graph\n        dur = []\n        for epoch in range(args.n_epochs):\n            if epoch >= 3:\n                t0 = time.time()\n            # forward\n            with tf.GradientTape() as tape:\n                logits = model(features)\n                loss_value = loss_fcn(labels[train_mask], logits[train_mask])\n                # Manually Weight Decay\n                # We found Tensorflow has a different implementation on weight decay\n                # of Adam(W) optimizer with PyTorch. 
And this results in worse results.\n                # Manually adding weights to the loss to do weight decay solves this problem.\n                for weight in model.trainable_weights:\n                    loss_value = loss_value + args.weight_decay * tf.nn.l2_loss(\n                        weight\n                    )\n                grads = tape.gradient(loss_value, model.trainable_weights)\n                optimizer.apply_gradients(zip(grads, model.trainable_weights))\n\n            if epoch >= 3:\n                dur.append(time.time() - t0)\n\n            acc = evaluate(model, features, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss_value.numpy().item(),\n                    acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n        acc = evaluate(model, features, labels, test_mask)\n        print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GCN\")\n    register_data_args(parser)\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden gcn units\"\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=1, help=\"number of hidden gcn layers\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 loss\"\n    )\n    parser.add_argument(\n  
      \"--self-loop\",\n        action=\"store_true\",\n        help=\"graph self-loop (default=False)\",\n    )\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/tensorflow/gcn/train.py",
    "content": "import argparse\nimport time\n\nimport dgl\n\nimport numpy as np\nimport tensorflow as tf\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\nfrom gcn import GCN\n\n\ndef evaluate(model, features, labels, mask):\n    logits = model(features, training=False)\n    logits = logits[mask]\n    labels = labels[mask]\n    indices = tf.math.argmax(logits, axis=1)\n    acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))\n    return acc.numpy().item()\n\n\ndef main(args):\n    # load and preprocess dataset\n    if args.dataset == \"cora\":\n        data = CoraGraphDataset()\n    elif args.dataset == \"citeseer\":\n        data = CiteseerGraphDataset()\n    elif args.dataset == \"pubmed\":\n        data = PubmedGraphDataset()\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n\n    g = data[0]\n    if args.gpu < 0:\n        device = \"/cpu:0\"\n    else:\n        device = \"/gpu:{}\".format(args.gpu)\n        g = g.to(device)\n\n    with tf.device(device):\n        features = g.ndata[\"feat\"]\n        labels = g.ndata[\"label\"]\n        train_mask = g.ndata[\"train_mask\"]\n        val_mask = g.ndata[\"val_mask\"]\n        test_mask = g.ndata[\"test_mask\"]\n        in_feats = features.shape[1]\n        n_classes = data.num_classes\n        n_edges = g.number_of_edges()\n        print(\n            \"\"\"----Data statistics------'\n        #Edges %d\n        #Classes %d\n        #Train samples %d\n        #Val samples %d\n        #Test samples %d\"\"\"\n            % (\n                n_edges,\n                n_classes,\n                train_mask.numpy().sum(),\n                val_mask.numpy().sum(),\n                test_mask.numpy().sum(),\n            )\n        )\n\n        # add self loop\n        if args.self_loop:\n            g = dgl.remove_self_loop(g)\n            g = dgl.add_self_loop(g)\n        n_edges = g.number_of_edges()\n        # normalization\n        degs = 
tf.cast(tf.identity(g.in_degrees()), dtype=tf.float32)\n        norm = tf.math.pow(degs, -0.5)\n        norm = tf.where(tf.math.is_inf(norm), tf.zeros_like(norm), norm)\n\n        g.ndata[\"norm\"] = tf.expand_dims(norm, -1)\n\n        # create GCN model\n        model = GCN(\n            g,\n            in_feats,\n            args.n_hidden,\n            n_classes,\n            args.n_layers,\n            tf.nn.relu,\n            args.dropout,\n        )\n\n        loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(\n            from_logits=True\n        )\n        # use optimizer\n        optimizer = tf.keras.optimizers.Adam(\n            learning_rate=args.lr, epsilon=1e-8\n        )\n\n        # initialize graph\n        dur = []\n        for epoch in range(args.n_epochs):\n            if epoch >= 3:\n                t0 = time.time()\n            # forward\n            with tf.GradientTape() as tape:\n                logits = model(features)\n                loss_value = loss_fcn(labels[train_mask], logits[train_mask])\n                # Manually Weight Decay\n                # We found Tensorflow has a different implementation on weight decay\n                # of Adam(W) optimizer with PyTorch. 
And this results in worse results.\n                # Manually adding weights to the loss to do weight decay solves this problem.\n                for weight in model.trainable_weights:\n                    loss_value = loss_value + args.weight_decay * tf.nn.l2_loss(\n                        weight\n                    )\n\n                grads = tape.gradient(loss_value, model.trainable_weights)\n                optimizer.apply_gradients(zip(grads, model.trainable_weights))\n            if epoch >= 3:\n                dur.append(time.time() - t0)\n\n            acc = evaluate(model, features, labels, val_mask)\n            print(\n                \"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | \"\n                \"ETputs(KTEPS) {:.2f}\".format(\n                    epoch,\n                    np.mean(dur),\n                    loss_value.numpy().item(),\n                    acc,\n                    n_edges / np.mean(dur) / 1000,\n                )\n            )\n\n        acc = evaluate(model, features, labels, test_mask)\n        print(\"Test Accuracy {:.4f}\".format(acc))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"GCN\")\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"cora\",\n        help=\"Dataset name ('cora', 'citeseer', 'pubmed').\",\n    )\n    parser.add_argument(\n        \"--dropout\", type=float, default=0.5, help=\"dropout probability\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=200, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden gcn units\"\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=1, help=\"number of hidden gcn layers\"\n    )\n    
parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-4, help=\"Weight for L2 loss\"\n    )\n    parser.add_argument(\n        \"--self-loop\",\n        action=\"store_true\",\n        help=\"graph self-loop (default=False)\",\n    )\n    parser.set_defaults(self_loop=False)\n    args = parser.parse_args()\n    print(args)\n\n    main(args)\n"
  },
  {
    "path": "examples/tensorflow/rgcn/README.md",
    "content": "# Relational-GCN\n\n* Paper: [https://arxiv.org/abs/1703.06103](https://arxiv.org/abs/1703.06103)\n* Author's code for entity classification: [https://github.com/tkipf/relational-gcn](https://github.com/tkipf/relational-gcn)\n* Author's code for link prediction: [https://github.com/MichSchli/RelationPrediction](https://github.com/MichSchli/RelationPrediction)\n\n### Dependencies\n* TensorFlow 2.2+\n* requests\n* rdflib\n* pandas\n\n```\npip install requests tensorflow rdflib pandas\nexport DGLBACKEND=tensorflow\n```\n\nExample code was tested with rdflib 4.2.2 and pandas 0.23.4.\n\n### Entity Classification\nAIFB: accuracy 92.78% (5 runs, DGL), 95.83% (paper)\n```\npython3 entity_classify.py -d aifb --testing --gpu 0\n```\n\nMUTAG: accuracy 71.47% (5 runs, DGL), 73.23% (paper)\n```\npython3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0\n```\n\nBGS: accuracy 93.10% (5 runs, DGL n-bases=25), 83.10% (paper n-bases=40)\n```\npython3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 25 --testing --gpu 0\n```\n"
  },
  {
    "path": "examples/tensorflow/rgcn/entity_classify.py",
    "content": "\"\"\"\nModeling Relational Data with Graph Convolutional Networks\nPaper: https://arxiv.org/abs/1703.06103\nCode: https://github.com/tkipf/relational-gcn\n\nDifference compared to tkipf/relation-gcn\n* l2norm applied to all weights\n* remove nodes that won't be touched\n\"\"\"\n\nimport argparse\nimport time\nfrom functools import partial\n\nimport dgl\n\nimport numpy as np\nimport tensorflow as tf\nfrom dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\nfrom dgl.nn.tensorflow import RelGraphConv\nfrom model import BaseRGCN\nfrom tensorflow.keras import layers\n\n\nclass EntityClassify(BaseRGCN):\n    def create_features(self):\n        features = tf.range(self.num_nodes)\n        return features\n\n    def build_input_layer(self):\n        return RelGraphConv(\n            self.num_nodes,\n            self.h_dim,\n            self.num_rels,\n            \"basis\",\n            self.num_bases,\n            activation=tf.nn.relu,\n            self_loop=self.use_self_loop,\n            dropout=self.dropout,\n        )\n\n    def build_hidden_layer(self, idx):\n        return RelGraphConv(\n            self.h_dim,\n            self.h_dim,\n            self.num_rels,\n            \"basis\",\n            self.num_bases,\n            activation=tf.nn.relu,\n            self_loop=self.use_self_loop,\n            dropout=self.dropout,\n        )\n\n    def build_output_layer(self):\n        return RelGraphConv(\n            self.h_dim,\n            self.out_dim,\n            self.num_rels,\n            \"basis\",\n            self.num_bases,\n            activation=partial(tf.nn.softmax, axis=1),\n            self_loop=self.use_self_loop,\n        )\n\n\ndef acc(logits, labels, mask):\n    logits = tf.gather(logits, mask)\n    labels = tf.gather(labels, mask)\n    indices = tf.math.argmax(logits, axis=1)\n    acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))\n    return acc\n\n\ndef main(args):\n    # load graph data\n   
 if args.dataset == \"aifb\":\n        dataset = AIFBDataset()\n    elif args.dataset == \"mutag\":\n        dataset = MUTAGDataset()\n    elif args.dataset == \"bgs\":\n        dataset = BGSDataset()\n    elif args.dataset == \"am\":\n        dataset = AMDataset()\n    else:\n        raise ValueError()\n\n    # preprocessing in cpu\n    with tf.device(\"/cpu:0\"):\n        # Load from hetero-graph\n        hg = dataset[0]\n\n        num_rels = len(hg.canonical_etypes)\n        category = dataset.predict_category\n        num_classes = dataset.num_classes\n        train_mask = hg.nodes[category].data.pop(\"train_mask\")\n        test_mask = hg.nodes[category].data.pop(\"test_mask\")\n        train_idx = tf.squeeze(tf.where(train_mask))\n        test_idx = tf.squeeze(tf.where(test_mask))\n        labels = hg.nodes[category].data.pop(\"labels\")\n\n        # split dataset into train, validate, test\n        if args.validation:\n            val_idx = train_idx[: len(train_idx) // 5]\n            train_idx = train_idx[len(train_idx) // 5 :]\n        else:\n            val_idx = train_idx\n\n        # calculate norm for each edge type and store in edge\n        for canonical_etype in hg.canonical_etypes:\n            u, v, eid = hg.all_edges(form=\"all\", etype=canonical_etype)\n            _, inverse_index, count = tf.unique_with_counts(v)\n            degrees = tf.gather(count, inverse_index)\n            norm = tf.ones(eid.shape[0]) / tf.cast(degrees, tf.float32)\n            norm = tf.expand_dims(norm, 1)\n            hg.edges[canonical_etype].data[\"norm\"] = norm\n\n        # get target category id\n        category_id = len(hg.ntypes)\n        for i, ntype in enumerate(hg.ntypes):\n            if ntype == category:\n                category_id = i\n\n        # edge type and normalization factor\n        g = dgl.to_homogeneous(hg, edata=[\"norm\"])\n\n    # check cuda\n    if args.gpu < 0:\n        device = \"/cpu:0\"\n        use_cuda = False\n    else:\n        
device = \"/gpu:{}\".format(args.gpu)\n        g = g.to(device)\n        use_cuda = True\n    num_nodes = g.number_of_nodes()\n    node_ids = tf.range(num_nodes, dtype=tf.int64)\n    edge_norm = g.edata[\"norm\"]\n    edge_type = tf.cast(g.edata[dgl.ETYPE], tf.int64)\n\n    # find out the target node ids in g\n    node_tids = g.ndata[dgl.NTYPE]\n    loc = node_tids == category_id\n    target_idx = tf.squeeze(tf.where(loc))\n\n    # since the nodes are featureless, the input feature is then the node id.\n    feats = tf.range(num_nodes, dtype=tf.int64)\n\n    with tf.device(device):\n        # create model\n        model = EntityClassify(\n            num_nodes,\n            args.n_hidden,\n            num_classes,\n            num_rels,\n            num_bases=args.n_bases,\n            num_hidden_layers=args.n_layers - 2,\n            dropout=args.dropout,\n            use_self_loop=args.use_self_loop,\n            use_cuda=use_cuda,\n        )\n\n        # optimizer\n        optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)\n        # training loop\n        print(\"start training...\")\n        forward_time = []\n        backward_time = []\n        loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(\n            from_logits=False\n        )\n        for epoch in range(args.n_epochs):\n            t0 = time.time()\n            with tf.GradientTape() as tape:\n                logits = model(g, feats, edge_type, edge_norm)\n                logits = tf.gather(logits, target_idx)\n                loss = loss_fcn(\n                    tf.gather(labels, train_idx), tf.gather(logits, train_idx)\n                )\n                # Manually Weight Decay\n                # We found Tensorflow has a different implementation on weight decay\n                # of Adam(W) optimizer with PyTorch. 
And this results in worse results.\n                # Manually adding weights to the loss to do weight decay solves this problem.\n                for weight in model.trainable_weights:\n                    loss = loss + args.l2norm * tf.nn.l2_loss(weight)\n                t1 = time.time()\n                grads = tape.gradient(loss, model.trainable_weights)\n                optimizer.apply_gradients(zip(grads, model.trainable_weights))\n                t2 = time.time()\n\n            forward_time.append(t1 - t0)\n            backward_time.append(t2 - t1)\n            print(\n                \"Epoch {:05d} | Train Forward Time(s) {:.4f} | Backward Time(s) {:.4f}\".format(\n                    epoch, forward_time[-1], backward_time[-1]\n                )\n            )\n            train_acc = acc(logits, labels, train_idx)\n            val_loss = loss_fcn(\n                tf.gather(labels, val_idx), tf.gather(logits, val_idx)\n            )\n            val_acc = acc(logits, labels, val_idx)\n            print(\n                \"Train Accuracy: {:.4f} | Train Loss: {:.4f} | Validation Accuracy: {:.4f} | Validation loss: {:.4f}\".format(\n                    train_acc,\n                    loss.numpy().item(),\n                    val_acc,\n                    val_loss.numpy().item(),\n                )\n            )\n        print()\n\n        logits = model(g, feats, edge_type, edge_norm)\n        logits = tf.gather(logits, target_idx)\n        test_loss = loss_fcn(\n            tf.gather(labels, test_idx), tf.gather(logits, test_idx)\n        )\n        test_acc = acc(logits, labels, test_idx)\n        print(\n            \"Test Accuracy: {:.4f} | Test loss: {:.4f}\".format(\n                test_acc, test_loss.numpy().item()\n            )\n        )\n        print()\n\n        print(\n            \"Mean forward time: {:4f}\".format(\n                np.mean(forward_time[len(forward_time) // 4 :])\n            )\n        )\n        print(\n            \"Mean 
backward time: {:4f}\".format(\n                np.mean(backward_time[len(backward_time) // 4 :])\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"RGCN\")\n    parser.add_argument(\n        \"--dropout\", type=float, default=0, help=\"dropout probability\"\n    )\n    parser.add_argument(\n        \"--n-hidden\", type=int, default=16, help=\"number of hidden units\"\n    )\n    parser.add_argument(\"--gpu\", type=int, default=-1, help=\"gpu\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"learning rate\")\n    parser.add_argument(\n        \"--n-bases\",\n        type=int,\n        default=-1,\n        help=\"number of filter weight matrices, default: -1 [use all]\",\n    )\n    parser.add_argument(\n        \"--n-layers\", type=int, default=2, help=\"number of propagation rounds\"\n    )\n    parser.add_argument(\n        \"-e\",\n        \"--n-epochs\",\n        type=int,\n        default=50,\n        help=\"number of training epochs\",\n    )\n    parser.add_argument(\n        \"-d\", \"--dataset\", type=str, required=True, help=\"dataset to use\"\n    )\n    parser.add_argument(\"--l2norm\", type=float, default=0, help=\"l2 norm coef\")\n    parser.add_argument(\n        \"--use-self-loop\",\n        default=False,\n        action=\"store_true\",\n        help=\"include self feature as a special relation\",\n    )\n    fp = parser.add_mutually_exclusive_group(required=False)\n    fp.add_argument(\"--validation\", dest=\"validation\", action=\"store_true\")\n    fp.add_argument(\"--testing\", dest=\"validation\", action=\"store_false\")\n    parser.set_defaults(validation=True)\n\n    args = parser.parse_args()\n    print(args)\n    args.bfs_level = args.n_layers + 1  # pruning used nodes for memory\n    main(args)\n"
  },
  {
    "path": "examples/tensorflow/rgcn/model.py",
    "content": "import tensorflow as tf\nfrom tensorflow.keras import layers\n\n\nclass BaseRGCN(layers.Layer):\n    def __init__(\n        self,\n        num_nodes,\n        h_dim,\n        out_dim,\n        num_rels,\n        num_bases,\n        num_hidden_layers=1,\n        dropout=0,\n        use_self_loop=False,\n        use_cuda=False,\n    ):\n        super(BaseRGCN, self).__init__()\n        self.num_nodes = num_nodes\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.num_rels = num_rels\n        self.num_bases = None if num_bases < 0 else num_bases\n        self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.use_self_loop = use_self_loop\n        self.use_cuda = use_cuda\n\n        # create rgcn layers\n        self.build_model()\n\n    def build_model(self):\n        self.layers = []\n        # i2h\n        i2h = self.build_input_layer()\n        if i2h is not None:\n            self.layers.append(i2h)\n        # h2h\n        for idx in range(self.num_hidden_layers):\n            h2h = self.build_hidden_layer(idx)\n            self.layers.append(h2h)\n        # h2o\n        h2o = self.build_output_layer()\n        if h2o is not None:\n            self.layers.append(h2o)\n\n    def build_input_layer(self):\n        return None\n\n    def build_hidden_layer(self, idx):\n        raise NotImplementedError\n\n    def build_output_layer(self):\n        return None\n\n    def call(self, g, h, r, norm):\n        for layer in self.layers:\n            h = layer(g, h, r, norm)\n        return h\n"
  },
  {
    "path": "examples/tensorflow/rgcn/utils.py",
    "content": "\"\"\"\nUtility functions for link prediction\nMost code is adapted from authors' implementation of RGCN link prediction:\nhttps://github.com/MichSchli/RelationPrediction\n\n\"\"\"\n\nimport dgl\nimport numpy as np\nimport tensorflow as tf\n\n#######################################################################\n#\n# Utility function for building training and testing graphs\n#\n#######################################################################\n\n\ndef get_adj_and_degrees(num_nodes, triplets):\n    \"\"\"Get adjacency list and degrees of the graph\"\"\"\n    adj_list = [[] for _ in range(num_nodes)]\n    for i, triplet in enumerate(triplets):\n        adj_list[triplet[0]].append([i, triplet[2]])\n        adj_list[triplet[2]].append([i, triplet[0]])\n\n    degrees = np.array([len(a) for a in adj_list])\n    adj_list = [np.array(a) for a in adj_list]\n    return adj_list, degrees\n\n\ndef sample_edge_neighborhood(adj_list, degrees, n_triplets, sample_size):\n    \"\"\"Sample edges by neighborhool expansion.\n\n    This guarantees that the sampled edges form a connected graph, which\n    may help deeper GNNs that require information from more than one hop.\n    \"\"\"\n    edges = np.zeros((sample_size), dtype=np.int32)\n\n    # initialize\n    sample_counts = np.array([d for d in degrees])\n    picked = np.array([False for _ in range(n_triplets)])\n    seen = np.array([False for _ in degrees])\n\n    for i in range(0, sample_size):\n        weights = sample_counts * seen\n\n        if np.sum(weights) == 0:\n            weights = np.ones_like(weights)\n            weights[np.where(sample_counts == 0)] = 0\n\n        probabilities = (weights) / np.sum(weights)\n        chosen_vertex = np.random.choice(\n            np.arange(degrees.shape[0]), p=probabilities\n        )\n        chosen_adj_list = adj_list[chosen_vertex]\n        seen[chosen_vertex] = True\n\n        chosen_edge = np.random.choice(np.arange(chosen_adj_list.shape[0]))\n        
chosen_edge = chosen_adj_list[chosen_edge]\n        edge_number = chosen_edge[0]\n\n        while picked[edge_number]:\n            chosen_edge = np.random.choice(np.arange(chosen_adj_list.shape[0]))\n            chosen_edge = chosen_adj_list[chosen_edge]\n            edge_number = chosen_edge[0]\n\n        edges[i] = edge_number\n        other_vertex = chosen_edge[1]\n        picked[edge_number] = True\n        sample_counts[chosen_vertex] -= 1\n        sample_counts[other_vertex] -= 1\n        seen[other_vertex] = True\n\n    return edges\n\n\ndef sample_edge_uniform(adj_list, degrees, n_triplets, sample_size):\n    \"\"\"Sample edges uniformly from all the edges.\"\"\"\n    all_edges = np.arange(n_triplets)\n    return np.random.choice(all_edges, sample_size, replace=False)\n\n\ndef generate_sampled_graph_and_labels(\n    triplets,\n    sample_size,\n    split_size,\n    num_rels,\n    adj_list,\n    degrees,\n    negative_rate,\n    sampler=\"uniform\",\n):\n    \"\"\"Get training graph and signals\n    First perform edge neighborhood sampling on graph, then perform negative\n    sampling to generate negative samples\n    \"\"\"\n    # perform edge neighbor sampling\n    if sampler == \"uniform\":\n        edges = sample_edge_uniform(\n            adj_list, degrees, len(triplets), sample_size\n        )\n    elif sampler == \"neighbor\":\n        edges = sample_edge_neighborhood(\n            adj_list, degrees, len(triplets), sample_size\n        )\n    else:\n        raise ValueError(\"Sampler type must be either 'uniform' or 'neighbor'.\")\n\n    # relabel nodes to have consecutive node ids\n    edges = triplets[edges]\n    src, rel, dst = edges.transpose()\n    uniq_v, edges = np.unique((src, dst), return_inverse=True)\n    src, dst = np.reshape(edges, (2, -1))\n    relabeled_edges = np.stack((src, rel, dst)).transpose()\n\n    # negative sampling\n    samples, labels = negative_sampling(\n        relabeled_edges, len(uniq_v), negative_rate\n    )\n\n    # 
further split graph, only half of the edges will be used as graph\n    # structure, while the rest half is used as unseen positive samples\n    split_size = int(sample_size * split_size)\n    graph_split_ids = np.random.choice(\n        np.arange(sample_size), size=split_size, replace=False\n    )\n    src = src[graph_split_ids]\n    dst = dst[graph_split_ids]\n    rel = rel[graph_split_ids]\n\n    # build DGL graph\n    print(\"# sampled nodes: {}\".format(len(uniq_v)))\n    print(\"# sampled edges: {}\".format(len(src) * 2))\n    g, rel, norm = build_graph_from_triplets(\n        len(uniq_v), num_rels, (src, rel, dst)\n    )\n    return g, uniq_v, rel, norm, samples, labels\n\n\ndef comp_deg_norm(g):\n    g = g.local_var()\n    in_deg = g.in_degrees(range(g.number_of_nodes())).float().numpy()\n    norm = 1.0 / in_deg\n    norm[np.isinf(norm)] = 0\n    return norm\n\n\ndef build_graph_from_triplets(num_nodes, num_rels, triplets):\n    \"\"\"Create a DGL graph. The graph is bidirectional because RGCN authors\n    use reversed relations.\n    This function also generates edge type and normalization factor\n    (reciprocal of node incoming degree)\n    \"\"\"\n    g = dgl.DGLGraph()\n    g.add_nodes(num_nodes)\n    src, rel, dst = triplets\n    src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))\n    rel = np.concatenate((rel, rel + num_rels))\n    edges = sorted(zip(dst, src, rel))\n    dst, src, rel = np.array(edges).transpose()\n    g.add_edges(src, dst)\n    norm = comp_deg_norm(g)\n    print(\"# nodes: {}, # edges: {}\".format(num_nodes, len(src)))\n    return g, rel, norm\n\n\ndef build_test_graph(num_nodes, num_rels, edges):\n    src, rel, dst = edges.transpose()\n    print(\"Test graph:\")\n    return build_graph_from_triplets(num_nodes, num_rels, (src, rel, dst))\n\n\ndef negative_sampling(pos_samples, num_entity, negative_rate):\n    size_of_batch = len(pos_samples)\n    num_to_generate = size_of_batch * negative_rate\n    neg_samples = 
np.tile(pos_samples, (negative_rate, 1))\n    labels = np.zeros(size_of_batch * (negative_rate + 1), dtype=np.float32)\n    labels[:size_of_batch] = 1\n    values = np.random.randint(num_entity, size=num_to_generate)\n    choices = np.random.uniform(size=num_to_generate)\n    subj = choices > 0.5\n    obj = choices <= 0.5\n    neg_samples[subj, 0] = values[subj]\n    neg_samples[obj, 2] = values[obj]\n\n    return np.concatenate((pos_samples, neg_samples)), labels\n"
  },
  {
    "path": "examples/tensorflow/sgc/README.md",
    "content": "# Simple Graph Convolution (SGC)\n\n> Graph Convolutional Networks derive inspiration primarily from recent deep learning approaches, and as a result, may inherit unnecessary complexity and redundant computation. In this paper, we reduce this excess complexity through successively removing nonlinearities and collapsing weight matrices between consecutive layers. We theoretically analyze the resulting linear model and show that it corresponds to a fixed low-pass filter followed by a linear classifier.\n\n* [Paper](https://arxiv.org/abs/1902.07153)\n* [Author Implementation](https://github.com/Tiiiger/SGC)\n\nNote: TensorFlow uses a different implementation of weight decay in AdamW to PyTorch. This results in differences in performance. You can see this by manually adding the L2 of the weights to the loss like [this](https://github.com/dmlc/dgl/blob/d696558b0bbcb60f1c4cf68dc93cd22c1077ce06/examples/tensorflow/gcn/train.py#L99) for comparison.\n\n## Requirements\n\nThis example is tested with TensorFlow 2.3.0.\n\n```bash\n$ pip install dgl tensorflow tensorflow_addons\n```\n\n## Usage\n```bash\n$ python sgc.py --help\nusage: sgc.py [-h] [--dataset DATASET] [--lr LR] [--bias]\n              [--n-epochs N_EPOCHS] [--weight-decay WEIGHT_DECAY]\n\nRun experiment for Simple Graph Convolution (SGC)\n\noptional arguments:\n  -h, --help                    show this help message and exit\n  --dataset DATASET             dataset to run\n  --lr LR                       learning rate\n  --bias                        flag to use bias\n  --n-epochs N_EPOCHS           number of training epochs\n  --weight-decay WEIGHT_DECAY   weight for L2 loss\n```\n\n## Results\n```bash\n# Cora citation network dataset\n$ python sgc.py --dataset cora --lr 0.2 --n-epochs 100 --weight-decay 5e-6\n...\nEpoch 100/100\n1/1 [==============================] - 0s 40ms/step - loss: 0.0313 - accuracy: 1.0000 - val_loss: 0.7870 - val_accuracy: 0.7620\nTest Accuracy: 77.2%\n\n# Citeseer 
citation network dataset\n$ python sgc.py --dataset citeseer --lr 0.2 --n-epochs 150 --bias --weight-decay 5e-5\n...\nEpoch 150/150\n1/1 [==============================] - 0s 65ms/step - loss: 0.0160 - accuracy: 1.0000 - val_loss: 1.1021 - val_accuracy: 0.6420\nTest Accuracy: 63.9%\n\n# Pubmed citation network dataset\n$ python sgc.py --dataset pubmed --lr 0.2 --n-epochs 100 --bias --weight-decay 5e-5\n...\nEpoch 100/100\n1/1 [==============================] - 0s 52ms/step - loss: 0.0421 - accuracy: 1.0000 - val_loss: 0.5862 - val_accuracy: 0.7680\nTest Accuracy: 76.3%\n```\n\n| Dataset  | Accuracy | Paper |\n|----------|----------|-------|\n| Cora     | 77.3%    | 81.0% |\n| Citeseer | 63.9%    | 71.9% |\n| Pubmed   | 76.4%    | 78.9% |\n"
  },
  {
    "path": "examples/tensorflow/sgc/sgc.py",
    "content": "\"\"\"\nThis code was modified from implementations of SGC in other backends.\n\nSimplifying Graph Convolutional Networks (Wu, Zhang and Souza et al, 2019)\nPaper: https://arxiv.org/abs/1902.07153\nAuthor Implementation: https://github.com/Tiiiger/SGC\n\nSGC implementation in DGL.\n\"\"\"\nimport argparse\nimport textwrap\n\nimport tensorflow as tf\nimport tensorflow_addons as tfa\n\nfrom dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset\nfrom dgl.nn.tensorflow.conv import SGConv\n\n_DATASETS = {\n    \"citeseer\": CiteseerGraphDataset(verbose=False),\n    \"cora\": CoraGraphDataset(verbose=False),\n    \"pubmed\": PubmedGraphDataset(verbose=False),\n}\n\n\ndef load_data(dataset):\n    return _DATASETS[dataset]\n\n\ndef _sum_boolean_tensor(x):\n    return tf.reduce_sum(tf.cast(x, dtype=\"int64\"))\n\n\ndef describe_data(data):\n    g = data[0]\n\n    n_edges = g.number_of_edges()\n    num_classes = data.num_classes\n\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n\n    description = textwrap.dedent(\n        f\"\"\"\n        ----Data statistics----\n        Edges           {n_edges:,.0f}\n        Classes         {num_classes:,.0f}\n        Train samples   {_sum_boolean_tensor(train_mask):,.0f}\n        Val samples     {_sum_boolean_tensor(val_mask):,.0f}\n        Test samples    {_sum_boolean_tensor(test_mask):,.0f}\n        \"\"\"\n    )\n    return description\n\n\nclass SGC(tf.keras.Model):\n    def __init__(self, g, num_classes, bias=False):\n        super().__init__()\n        self.num_classes = num_classes\n        self.g = self.ensure_self_loop(g)\n        self.conv = SGConv(\n            in_feats=self.in_feats,\n            out_feats=self.num_classes,\n            k=2,\n            cached=True,\n            bias=bias,\n        )\n\n    def call(self, inputs):\n        return self.conv(self.g, inputs)\n\n    @property\n    def in_feats(self):\n  
      return self.g.ndata[\"feat\"].shape[1]\n\n    @property\n    def num_nodes(self):\n        return self.g.num_nodes()\n\n    @staticmethod\n    def ensure_self_loop(g):\n        g = g.remove_self_loop()\n        g = g.add_self_loop()\n        return g\n\n    def train_step(self, data):\n        X, y = data\n        mask = self.g.ndata[\"train_mask\"]\n\n        with tf.GradientTape() as tape:\n            y_pred = self(X, training=True)\n            loss = self.compiled_loss(y[mask], y_pred[mask])\n\n        trainable_variables = self.trainable_variables\n        gradients = tape.gradient(loss, trainable_variables)\n        self.optimizer.apply_gradients(zip(gradients, trainable_variables))\n        self.compiled_metrics.update_state(y[mask], y_pred[mask])\n        return {m.name: m.result() for m in self.metrics}\n\n    def test_step(self, data):\n        X, y = data\n        mask = self.g.ndata[\"val_mask\"]\n        y_pred = self(X, training=False)\n        self.compiled_loss(y[mask], y_pred[mask])\n        self.compiled_metrics.update_state(y[mask], y_pred[mask])\n        return {m.name: m.result() for m in self.metrics}\n\n    def compile(self, *args, **kwargs):\n        super().compile(*args, **kwargs, run_eagerly=True)\n\n    def fit(self, *args, **kwargs):\n        kwargs[\"batch_size\"] = self.num_nodes\n        kwargs[\"shuffle\"] = False\n        super().fit(*args, **kwargs)\n\n    def predict(self, *args, **kwargs):\n        kwargs[\"batch_size\"] = self.num_nodes\n        return super().predict(*args, **kwargs)\n\n\ndef main(dataset, lr, bias, n_epochs, weight_decay):\n    data = load_data(dataset)\n    print(describe_data(data))\n\n    g = data[0]\n    X = g.ndata[\"feat\"]\n    y = g.ndata[\"label\"]\n\n    model = SGC(g=g, num_classes=data.num_classes, bias=bias)\n\n    loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)\n    optimizer = tfa.optimizers.AdamW(weight_decay, lr)\n    accuracy = 
tf.metrics.SparseCategoricalAccuracy(name=\"accuracy\")\n\n    model.compile(optimizer, loss, metrics=[accuracy])\n    model.fit(x=X, y=y, epochs=n_epochs, validation_data=(X, y))\n\n    y_pred = model.predict(X, batch_size=len(X))\n    test_mask = g.ndata[\"test_mask\"]\n    test_accuracy = accuracy(y[test_mask], y_pred[test_mask])\n    print(f\"Test Accuracy: {test_accuracy:.1%}\")\n\n\ndef _parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Run experiment for Simple Graph Convolution (SGC)\"\n    )\n    parser.add_argument(\"--dataset\", default=\"cora\", help=\"dataset to run\")\n    parser.add_argument(\"--lr\", type=float, default=0.2, help=\"learning rate\")\n    parser.add_argument(\n        \"--bias\", action=\"store_true\", default=False, help=\"flag to use bias\"\n    )\n    parser.add_argument(\n        \"--n-epochs\", type=int, default=100, help=\"number of training epochs\"\n    )\n    parser.add_argument(\n        \"--weight-decay\", type=float, default=5e-6, help=\"weight for L2 loss\"\n    )\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = _parse_args()\n    main(\n        dataset=args.dataset,\n        lr=args.lr,\n        bias=args.bias,\n        n_epochs=args.n_epochs,\n        weight_decay=args.weight_decay,\n    )\n"
  },
  {
    "path": "graphbolt/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.18)\nproject(graphbolt C CXX)\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CXX_STANDARD_REQUIRED ON)\n\nif(USE_CUDA)\n  message(STATUS \"Build graphbolt with CUDA support\")\n  enable_language(CUDA)\n  add_definitions(-DGRAPHBOLT_USE_CUDA)\nendif()\n\n# For windows, define NOMINMAX to avoid conflict with std::min/max\nif(MSVC)\n  add_definitions(-DNOMINMAX)\nendif()\n\n# Find PyTorch cmake files and PyTorch versions with the python interpreter\n# $PYTHON_INTERP (\"python3\" or \"python\" if empty)\nif(NOT PYTHON_INTERP)\n  find_program(PYTHON_INTERP NAMES python3 python)\nendif()\n\nmessage(STATUS \"Using Python interpreter: ${PYTHON_INTERP}\")\n\nfile(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/find_cmake.py FIND_CMAKE_PY)\nexecute_process(\n  COMMAND ${PYTHON_INTERP} ${FIND_CMAKE_PY}\n  OUTPUT_VARIABLE TORCH_PREFIX_VER\n  OUTPUT_STRIP_TRAILING_WHITESPACE\n)\n\nmessage(STATUS \"find_cmake.py output: ${TORCH_PREFIX_VER}\")\nlist(GET TORCH_PREFIX_VER 0 TORCH_PREFIX)\nlist(GET TORCH_PREFIX_VER 1 TORCH_VER)\n\nmessage(STATUS \"Configuring for PyTorch ${TORCH_VER}\")\nstring(REPLACE \".\" \";\" TORCH_VERSION_LIST ${TORCH_VER})\n\nset(Torch_DIR \"${TORCH_PREFIX}/Torch\")\nmessage(STATUS \"Setting directory to ${Torch_DIR}\")\n\nfind_package(Torch REQUIRED)\nset(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} ${TORCH_C_FLAGS}\")\nset(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}\")\nset(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb\")\n\nset(LIB_GRAPHBOLT_NAME \"graphbolt_pytorch_${TORCH_VER}\")\noption(BUILD_WITH_TASKFLOW \"Use taskflow as parallel backend\" ON)\noption(USE_OPENMP \"Use OpenMP for graphbolt\" ON)\noption(USE_LIBURING \"Build graphbolt with liburing support\" ON)\n\nset(BOLT_DIR \"${CMAKE_CURRENT_SOURCE_DIR}/src\")\nset(BOLT_INCLUDE \"${CMAKE_CURRENT_SOURCE_DIR}/include\")\nfile(GLOB BOLT_HEADERS ${BOLT_INCLUDE})\nfile(GLOB BOLT_SRC ${BOLT_DIR}/*.cc)\nif(USE_CUDA)\n  file(GLOB BOLT_CUDA_SRC\n    
${BOLT_DIR}/cuda/*.cu\n    ${BOLT_DIR}/cuda/*.cc\n  )\n  list(APPEND BOLT_SRC ${BOLT_CUDA_SRC})\n  if(DEFINED ENV{CUDAARCHS})\n    set(CMAKE_CUDA_ARCHITECTURES $ENV{CUDAARCHS})\n  endif()\n  set(CMAKE_CUDA_ARCHITECTURES_FILTERED ${CMAKE_CUDA_ARCHITECTURES})\n  # CUDA extension supports only sm_70 and up (Volta+).\n  list(FILTER CMAKE_CUDA_ARCHITECTURES_FILTERED EXCLUDE REGEX \"[2-6][0-9]\")\n  list(LENGTH CMAKE_CUDA_ARCHITECTURES_FILTERED CMAKE_CUDA_ARCHITECTURES_FILTERED_LEN)\n  if(CMAKE_CUDA_ARCHITECTURES_FILTERED_LEN EQUAL 0)\n    # Build the CUDA extension at least build for Volta.\n    set(CMAKE_CUDA_ARCHITECTURES_FILTERED \"70\")\n  endif()\n  set(LIB_GRAPHBOLT_CUDA_NAME \"${LIB_GRAPHBOLT_NAME}_cuda\")\nendif()\n\nadd_library(${LIB_GRAPHBOLT_NAME} SHARED ${BOLT_SRC} ${BOLT_HEADERS})\ninclude_directories(BEFORE ${BOLT_DIR}\n                           ${BOLT_HEADERS}\n                           # For CXX20 features:\n                           # `std::atomic_ref`, `std::counting_semaphore`\n                           \"../third_party/cccl/libcudacxx/include\"\n                           \"../third_party/pcg/include\"\n                           \"../third_party/tsl_robin_map/include\")\ntarget_link_libraries(${LIB_GRAPHBOLT_NAME} \"${TORCH_LIBRARIES}\")\nif(BUILD_WITH_TASKFLOW)\n  target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE \"../third_party/taskflow\")\n  target_compile_definitions(${LIB_GRAPHBOLT_NAME} PRIVATE BUILD_WITH_TASKFLOW=1)\nendif()\n\nif(USE_OPENMP)\n  find_package(OpenMP REQUIRED)\n  target_link_libraries(${LIB_GRAPHBOLT_NAME} OpenMP::OpenMP_CXX)\n  message(STATUS \"Build graphbolt with OpenMP.\")\nendif(USE_OPENMP)\n\nif(CMAKE_SYSTEM_NAME MATCHES \"Linux\")\n  if(USE_LIBURING)\n    add_definitions(-DHAVE_LIBRARY_LIBURING)\n    include(ExternalProject)\n    set(LIBURING_INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/third_party/liburing)\n    set(LIBURING_C_COMPILER \"${CMAKE_C_COMPILER} -w\")\n    ExternalProject_Add(\n      liburing\n      
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/liburing\n      CONFIGURE_COMMAND <SOURCE_DIR>/configure --cc=${LIBURING_C_COMPILER} --cxx=${CMAKE_CXX_COMPILER} --prefix=/\n      # In order to avoid the error `error: redefinition of 'struct in6_pktinfo'` on ubi7\n      # when building examples, let's build src only.\n      BUILD_COMMAND bash -c \"make -j 4 -C src/\"\n      BUILD_IN_SOURCE ON\n      INSTALL_COMMAND make install DESTDIR=${LIBURING_INSTALL_DIR}\n      BUILD_BYPRODUCTS ${LIBURING_INSTALL_DIR}/lib/liburing.a\n      BUILD_BYPRODUCTS ${LIBURING_INSTALL_DIR}/include\n      DOWNLOAD_EXTRACT_TIMESTAMP true\n    )\n    set(LIBURING_INCLUDE ${LIBURING_INSTALL_DIR}/include)\n    set(LIBURING ${LIBURING_INSTALL_DIR}/lib/liburing.a)\n\n    target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE ${LIBURING_INCLUDE})\n    add_dependencies(${LIB_GRAPHBOLT_NAME} liburing)\n    target_link_libraries(${LIB_GRAPHBOLT_NAME} ${CMAKE_CURRENT_BINARY_DIR}/third_party/liburing/lib/liburing.a)\n    message(STATUS \"Build graphbolt with liburing.\")\n  endif(USE_LIBURING)\nendif()\n\nif(USE_CUDA)\n  file(GLOB BOLT_CUDA_EXTENSION_SRC\n    ${BOLT_DIR}/cuda/extension/*.cu\n    ${BOLT_DIR}/cuda/extension/*.cc\n    ../third_party/HugeCTR/gpu_cache/src/nv_gpu_cache.cu\n  )\n  # Until https://github.com/NVIDIA/cccl/issues/1083 is resolved, we need to\n  # compile the cuda/extension folder with Volta+ CUDA architectures.\n  add_library(${LIB_GRAPHBOLT_CUDA_NAME} STATIC ${BOLT_CUDA_EXTENSION_SRC} ${BOLT_HEADERS})\n  target_link_libraries(${LIB_GRAPHBOLT_CUDA_NAME} \"${TORCH_LIBRARIES}\")\n\n  set_target_properties(${LIB_GRAPHBOLT_NAME} PROPERTIES CUDA_STANDARD 17)\n  set_target_properties(${LIB_GRAPHBOLT_CUDA_NAME} PROPERTIES CUDA_STANDARD 17)\n  set_target_properties(${LIB_GRAPHBOLT_CUDA_NAME} PROPERTIES CUDA_ARCHITECTURES \"${CMAKE_CUDA_ARCHITECTURES_FILTERED}\")\n  set_target_properties(${LIB_GRAPHBOLT_CUDA_NAME} PROPERTIES POSITION_INDEPENDENT_CODE TRUE)\n  # Enables 
libcudacxx for gpu_cache. \n  target_compile_definitions(${LIB_GRAPHBOLT_CUDA_NAME} PRIVATE LIBCUDACXX_VERSION)\n  include_directories(AFTER \"../third_party/HugeCTR/gpu_cache/include\")\n  message(STATUS \"Build graphbolt extension with HugeCTR GPU embedding cache.\")\n\n  message(STATUS \"Use external CCCL library for a consistent API and performance for graphbolt.\")\n  include_directories(BEFORE\n                      \"../third_party/cccl/thrust\"\n                      \"../third_party/cccl/cub\"\n                      \"../third_party/cuco/include\")\n  \n  get_property(archs TARGET ${LIB_GRAPHBOLT_NAME} PROPERTY CUDA_ARCHITECTURES)\n  message(STATUS \"CUDA_ARCHITECTURES for graphbolt: ${archs}\")\n\n  get_property(archs TARGET ${LIB_GRAPHBOLT_CUDA_NAME} PROPERTY CUDA_ARCHITECTURES)\n  message(STATUS \"CUDA_ARCHITECTURES for graphbolt extension: ${archs}\")\n\n  target_link_libraries(${LIB_GRAPHBOLT_NAME} ${LIB_GRAPHBOLT_CUDA_NAME})\nendif()\n\n# The Torch CMake configuration only sets up the path for the MKL library when\n# using the conda distribution. The following is a workaround to address this\n# when using a standalone installation of MKL.\nif(DEFINED MKL_LIBRARIES)\n  target_link_directories(${LIB_GRAPHBOLT_NAME} PRIVATE\n                          ${MKL_ROOT}/lib/${MKL_ARCH})\nendif()\n"
  },
  {
    "path": "graphbolt/build.bat",
    "content": "REM Helper script to build Graphbolt libraries for PyTorch\n@ECHO OFF\nSETLOCAL EnableDelayedExpansion\n\nMD \"%BINDIR%\\graphbolt\"\nDEL /S /Q build\nMD build\nPUSHD build\n\nIF x%1x == xx GOTO single\n\nFOR %%X IN (%*) DO (\n  DEL /S /Q *\n  \"%CMAKE_COMMAND%\" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X -DTORCH_CUDA_ARCH_LIST=Volta .. -G \"Visual Studio 16 2019\" || EXIT /B 1\n  msbuild graphbolt.sln /m /nr:false || EXIT /B 1\n  COPY /Y Release\\*.dll \"%BINDIR%\\graphbolt\" || EXIT /B 1\n)\n\nGOTO end\n\n:single\n\nDEL /S /Q *\n\"%CMAKE_COMMAND%\" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DTORCH_CUDA_ARCH_LIST=Volta .. -G \"Visual Studio 16 2019\" || EXIT /B 1\nmsbuild graphbolt.sln /m /nr:false || EXIT /B 1\nCOPY /Y Release\\*.dll \"%BINDIR%\\graphbolt\" || EXIT /B 1\n\n:end\nPOPD\n\nENDLOCAL\n"
  },
  {
    "path": "graphbolt/build.sh",
    "content": "#!/bin/bash\n# Helper script to build graphbolt libraries for PyTorch\nset -e\n\nmkdir -p build\nmkdir -p $BINDIR/graphbolt\ncd build\n\nif [ $(uname) = 'Darwin' ]; then\n  CPSOURCE=*.dylib\nelse\n  CPSOURCE=*.so\nfi\n\n# We build for the same architectures as DGL, thus we hardcode\n# TORCH_CUDA_ARCH_LIST and we need to at least compile for Volta. Until\n# https://github.com/NVIDIA/cccl/issues/1083 is resolved, we need to compile the\n# cuda/extension folder with Volta+ CUDA architectures.\nTORCH_CUDA_ARCH_LIST=\"Volta\"\nif ! [[ -z \"${CUDAARCHS}\" ]]; then\n  # The architecture list is passed as an environment variable, we set\n  # TORCH_CUDA_ARCH_LIST to the latest architecture.\n  CUDAARCHSARR=(${CUDAARCHS//;/ })\n  LAST_ARCHITECTURE=${CUDAARCHSARR[-1]}\n  # TORCH_CUDA_ARCH_LIST has to be at least 70 to override Volta default.\n  if (( $LAST_ARCHITECTURE >= 70 )); then\n    # Convert \"75\" to \"7.5\".\n    TORCH_CUDA_ARCH_LIST=${LAST_ARCHITECTURE:0:-1}'.'${LAST_ARCHITECTURE: -1}\n  fi\nfi\nCMAKE_FLAGS=\"-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST\"\necho \"graphbolt cmake flags: $CMAKE_FLAGS\"\n\nif [ $# -eq 0 ]; then\n  $CMAKE_COMMAND $CMAKE_FLAGS ..\n  make -j\n  cp -v $CPSOURCE $BINDIR/graphbolt\nelse\n  for PYTHON_INTERP in $@; do\n    TORCH_VER=$($PYTHON_INTERP -c 'import torch; print(torch.__version__.split(\"+\")[0])')\n    mkdir -p $TORCH_VER\n    cd $TORCH_VER\n    $CMAKE_COMMAND $CMAKE_FLAGS -DPYTHON_INTERP=$PYTHON_INTERP ../..\n    make -j\n    cp -v $CPSOURCE $BINDIR/graphbolt\n    cd ..\n  done\nfi\n"
  },
  {
    "path": "graphbolt/find_cmake.py",
    "content": "import os\n\nimport torch\n\ncmake_prefix_path = getattr(\n    torch.utils,\n    \"cmake_prefix_path\",\n    os.path.join(os.path.dirname(torch.__file__), \"share\", \"cmake\"),\n)\nversion = torch.__version__.split(\"+\")[0]\nprint(\";\".join([cmake_prefix_path, version]))\n"
  },
  {
    "path": "graphbolt/include/graphbolt/async.h",
    "content": "/**\n *   Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file graphbolt/async.h\n * @brief Provides asynchronous task utilities for GraphBolt.\n */\n#ifndef GRAPHBOLT_ASYNC_H_\n#define GRAPHBOLT_ASYNC_H_\n\n#include <ATen/Parallel.h>\n#include <torch/script.h>\n\n#include <future>\n#include <memory>\n#include <mutex>\n#include <variant>\n\n#ifdef BUILD_WITH_TASKFLOW\n#include <taskflow/algorithm/for_each.hpp>\n#include <taskflow/taskflow.hpp>\n#else\n#include <atomic>\n#include <exception>\n#include <type_traits>\n#endif\n\n#ifdef GRAPHBOLT_USE_CUDA\n#include <ATen/cuda/CUDAEvent.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <c10/cuda/CUDAStream.h>\n#include <torch/csrc/api/include/torch/cuda.h>\n#endif\n\nnamespace graphbolt {\n\nenum ThreadPool { intraop, interop };\n\n#ifdef BUILD_WITH_TASKFLOW\n\ntemplate <ThreadPool pool_type>\ninline tf::Executor& _get_thread_pool() {\n  static std::unique_ptr<tf::Executor> pool;\n  static std::once_flag flag;\n  std::call_once(flag, [&] {\n    const int num_threads = pool_type == ThreadPool::intraop\n                                ? 
torch::get_num_threads()\n                                : torch::get_num_interop_threads();\n    pool = std::make_unique<tf::Executor>(num_threads);\n  });\n  return *pool.get();\n}\n\ninline tf::Executor& intraop_pool() {\n  return _get_thread_pool<ThreadPool::intraop>();\n}\n\ninline tf::Executor& interop_pool() {\n  return _get_thread_pool<ThreadPool::interop>();\n}\n\ninline tf::Executor& get_thread_pool(ThreadPool pool_type) {\n  return pool_type == ThreadPool::intraop ? intraop_pool() : interop_pool();\n}\n#endif  // BUILD_WITH_TASKFLOW\n\ninline int get_num_threads() {\n#ifdef BUILD_WITH_TASKFLOW\n  return intraop_pool().num_workers();\n#else\n  return torch::get_num_threads();\n#endif\n}\n\ninline int get_num_interop_threads() {\n#ifdef BUILD_WITH_TASKFLOW\n  return interop_pool().num_workers();\n#else\n  return torch::get_num_interop_threads();\n#endif\n}\n\ntemplate <typename T>\nclass Future : public torch::CustomClassHolder {\n#ifdef GRAPHBOLT_USE_CUDA\n  using T_no_event = std::conditional_t<std::is_void_v<T>, std::monostate, T>;\n  using T_with_event = std::conditional_t<\n      std::is_void_v<T>, at::cuda::CUDAEvent,\n      std::pair<T, at::cuda::CUDAEvent>>;\n  using future_type = std::future<std::variant<T_no_event, T_with_event>>;\n#else\n  using future_type = std::future<T>;\n#endif\n\n public:\n#ifdef GRAPHBOLT_USE_CUDA\n  using return_type = std::variant<T_no_event, T_with_event>;\n#else\n  using return_type = T;\n#endif\n\n  Future(future_type&& future) : future_(std::move(future)) {}\n\n  Future() = default;\n\n  T Wait() {\n#ifdef GRAPHBOLT_USE_CUDA\n    auto result = future_.get();\n    if constexpr (std::is_void_v<T>) {\n      if (std::holds_alternative<T_with_event>(result)) {\n        auto&& event = std::get<T_with_event>(result);\n        event.block(c10::cuda::getCurrentCUDAStream());\n      }\n      return;\n    } else if (std::holds_alternative<T_with_event>(result)) {\n      auto&& [value, event] = 
std::get<T_with_event>(result);\n      event.block(c10::cuda::getCurrentCUDAStream());\n      return value;\n    } else {\n      return std::get<T_no_event>(result);\n    }\n#else\n    return future_.get();\n#endif\n  }\n\n private:\n  future_type future_;\n};\n\n/**\n * @brief Utilizes at::launch to launch an async task in the interop thread\n * pool. We should not make use of any native CPU torch ops inside the launched\n * task to avoid spawning a new OpenMP threadpool on each interop thread.\n */\ntemplate <typename F>\ninline auto async(F&& function, bool is_cuda = false) {\n  using T = decltype(function());\n#ifdef GRAPHBOLT_USE_CUDA\n  struct c10::StreamData3 stream_data;\n  if (is_cuda) {\n    stream_data = c10::cuda::getCurrentCUDAStream().pack3();\n  }\n#endif\n  using return_type = typename Future<T>::return_type;\n  auto fn = [=, func = std::move(function)]() -> return_type {\n#ifdef GRAPHBOLT_USE_CUDA\n    // We make sure to use the same CUDA stream as the thread launching the\n    // async operation.\n    if (is_cuda) {\n      auto stream = c10::cuda::CUDAStream::unpack3(\n          stream_data.stream_id, stream_data.device_index,\n          stream_data.device_type);\n      c10::cuda::CUDAStreamGuard guard(stream);\n      at::cuda::CUDAEvent event;\n      // Might be executed on the GPU so we record an event to be able to\n      // synchronize with it later, in case it is executed on an alternative\n      // CUDA stream.\n      if constexpr (std::is_void_v<T>) {\n        func();\n        event.record();\n        return event;\n      } else {\n        auto result = func();\n        event.record();\n        return std::make_pair(std::move(result), std::move(event));\n      }\n    }\n    if constexpr (std::is_void_v<T>) {\n      func();\n      return std::monostate{};\n    } else {\n      return func();\n    }\n#else\n    return func();\n#endif\n  };\n#ifdef BUILD_WITH_TASKFLOW\n  auto future = interop_pool().async(std::move(fn));\n#else\n  auto promise 
= std::make_shared<std::promise<return_type>>();\n  auto future = promise->get_future();\n  at::launch([promise, func = std::move(fn)]() {\n    if constexpr (std::is_void_v<return_type>) {\n      func();\n      promise->set_value();\n    } else\n      promise->set_value(func());\n  });\n#endif\n  return c10::make_intrusive<Future<T>>(std::move(future));\n}\n\ntemplate <ThreadPool pool_type, bool for_each, typename F>\ninline void _parallel_for(\n    const int64_t begin, const int64_t end, const int64_t grain_size,\n    const F& f) {\n  if (begin >= end) return;\n  int64_t num_threads = get_num_threads();\n  const auto num_iter = end - begin;\n  const bool use_parallel =\n      (num_iter > grain_size && num_iter > 1 && num_threads > 1);\n  if (!use_parallel) {\n    if constexpr (for_each) {\n      for (int64_t i = begin; i < end; i++) f(i);\n    } else {\n      f(begin, end);\n    }\n    return;\n  }\n  if (grain_size > 0) {\n    num_threads = std::min(num_threads, at::divup(end - begin, grain_size));\n  }\n  int64_t chunk_size = at::divup((end - begin), num_threads);\n#ifdef BUILD_WITH_TASKFLOW\n  tf::Taskflow flow;\n  flow.for_each_index(int64_t{0}, num_threads, int64_t{1}, [=](int64_t tid) {\n    const int64_t begin_tid = begin + tid * chunk_size;\n    if (begin_tid < end) {\n      const int64_t end_tid = std::min(end, begin_tid + chunk_size);\n      if constexpr (for_each) {\n        for (int64_t i = begin_tid; i < end_tid; i++) f(i);\n      } else {\n        f(begin_tid, end_tid);\n      }\n    }\n  });\n  _get_thread_pool<pool_type>().run(flow).get();\n#else\n  std::promise<void> promise;\n  std::future<void> future;\n  std::atomic_flag err_flag = ATOMIC_FLAG_INIT;\n  std::exception_ptr eptr;\n  int num_launched = 0;\n  std::atomic<int> num_finished = 0;\n  for (int tid = num_threads - 1; tid >= 0; tid--) {\n    const int64_t begin_tid = begin + tid * chunk_size;\n    if (begin_tid < end) {\n      const int64_t end_tid = std::min(end, begin_tid + 
chunk_size);\n      if (tid == 0) {\n        // Launch the thread 0's work inline.\n        if constexpr (for_each) {\n          for (int64_t i = begin_tid; i < end_tid; i++) f(i);\n        } else {\n          f(begin_tid, end_tid);\n        }\n        continue;\n      }\n      if (!future.valid()) {\n        future = promise.get_future();\n        num_launched = tid;\n      }\n      at::launch([&f, &err_flag, &eptr, &promise, &num_finished, num_launched,\n                  begin_tid, end_tid] {\n        try {\n          if constexpr (for_each) {\n            for (int64_t i = begin_tid; i < end_tid; i++) f(i);\n          } else {\n            f(begin_tid, end_tid);\n          }\n        } catch (...) {\n          if (!err_flag.test_and_set()) {\n            eptr = std::current_exception();\n          }\n        }\n        auto ticket = num_finished.fetch_add(1, std::memory_order_release);\n        if (1 + ticket == num_launched) {\n          // The last thread signals the end of execution.\n          promise.set_value();\n        }\n      });\n    }\n  }\n  // Wait for the launched work to finish.\n  if (num_launched > 0) {\n    future.get();\n    if (eptr) {\n      std::rethrow_exception(eptr);\n    }\n  }\n#endif\n}\n\n/**\n * @brief GraphBolt's version of torch::parallel_for. Since torch::parallel_for\n * uses OpenMP threadpool, async tasks can not make use of it due to multiple\n * OpenMP threadpools being created for each async thread. 
Moreover, inside\n * graphbolt::parallel_for, we should not make use of any native CPU torch ops\n * as they will spawn an OpenMP threadpool.\n */\ntemplate <typename F>\ninline void parallel_for(\n    const int64_t begin, const int64_t end, const int64_t grain_size,\n    const F& f) {\n  _parallel_for<ThreadPool::intraop, false>(begin, end, grain_size, f);\n}\n\n/**\n * @brief Compared to parallel_for, it expects the passed function to take a\n * single argument for each iteration.\n */\ntemplate <typename F>\ninline void parallel_for_each(\n    const int64_t begin, const int64_t end, const int64_t grain_size,\n    const F& f) {\n  _parallel_for<ThreadPool::intraop, true>(begin, end, grain_size, f);\n}\n\n/**\n * @brief Same as parallel_for but uses the interop thread pool.\n */\ntemplate <typename F>\ninline void parallel_for_interop(\n    const int64_t begin, const int64_t end, const int64_t grain_size,\n    const F& f) {\n  _parallel_for<ThreadPool::interop, false>(begin, end, grain_size, f);\n}\n\n/**\n * @brief Compared to parallel_for_interop, it expects the passed function to\n * take a single argument for each iteration.\n */\ntemplate <typename F>\ninline void parallel_for_each_interop(\n    const int64_t begin, const int64_t end, const int64_t grain_size,\n    const F& f) {\n  _parallel_for<ThreadPool::interop, true>(begin, end, grain_size, f);\n}\n\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_ASYNC_H_\n"
  },
  {
    "path": "graphbolt/include/graphbolt/continuous_seed.h",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file graphbolt/continuous_seed.h\n * @brief CPU and CUDA implementation for continuous random seeds\n */\n#ifndef GRAPHBOLT_CONTINUOUS_SEED_H_\n#define GRAPHBOLT_CONTINUOUS_SEED_H_\n\n#include <torch/script.h>\n\n#include <cmath>\n\n#ifdef __CUDACC__\n#include <curand_kernel.h>\n#else\n#include <pcg_random.hpp>\n#include <random>\n#endif  // __CUDA_ARCH__\n\n#ifndef M_SQRT1_2\n#define M_SQRT1_2 0.707106781186547524401\n#endif  // M_SQRT1_2\n\nnamespace graphbolt {\n\nclass continuous_seed {\n  uint64_t s[2];\n  float c[2];\n\n public:\n  /* implicit */ continuous_seed(const int64_t seed) {  // NOLINT\n    s[0] = s[1] = seed;\n    c[0] = c[1] = 0;\n  }\n\n  continuous_seed(torch::Tensor seed_arr, float r) {\n    auto seed = seed_arr.data_ptr<int64_t>();\n    s[0] = seed[0];\n    s[1] = seed[seed_arr.size(0) - 1];\n    const auto pi = std::acos(-1.0);\n    c[0] = std::cos(pi * r / 2);\n    c[1] = std::sin(pi * r / 2);\n  }\n\n  uint64_t get_seed(int i) const { return s[i != 0]; }\n\n#ifdef __CUDACC__\n  __device__ inline float uniform(const uint64_t t) const {\n    const uint64_t kCurandSeed = 999961;  // Could be any random number.\n    curandStatePhilox4_32_10_t rng;\n    curand_init(kCurandSeed, s[0], t, &rng);\n    float rnd;\n    if (s[0] != s[1]) 
{\n      rnd = c[0] * curand_normal(&rng);\n      curand_init(kCurandSeed, s[1], t, &rng);\n      rnd += c[1] * curand_normal(&rng);\n      rnd = normcdff(rnd);\n    } else {\n      rnd = curand_uniform(&rng);\n    }\n    return rnd;\n  }\n#else\n  inline float uniform(const uint64_t t) const {\n    pcg32 ng0(s[0], t);\n    float rnd;\n    if (s[0] != s[1]) {\n      std::normal_distribution<float> norm;\n      rnd = c[0] * norm(ng0);\n      pcg32 ng1(s[1], t);\n      norm.reset();\n      rnd += c[1] * norm(ng1);\n      rnd = std::erfc(-rnd * static_cast<float>(M_SQRT1_2)) / 2.0f;\n    } else {\n      std::uniform_real_distribution<float> uni;\n      rnd = uni(ng0);\n    }\n    return rnd;\n  }\n#endif  // __CUDA_ARCH__\n};\n\nclass single_seed {\n  uint64_t seed_;\n\n public:\n  /* implicit */ single_seed(const int64_t seed) : seed_(seed) {}  // NOLINT\n\n  single_seed(torch::Tensor seed_arr)\n      : seed_(seed_arr.data_ptr<int64_t>()[0]) {}\n\n#ifdef __CUDACC__\n  __device__ inline float uniform(const uint64_t id) const {\n    const uint64_t kCurandSeed = 999961;  // Could be any random number.\n    curandStatePhilox4_32_10_t rng;\n    curand_init(kCurandSeed, seed_, id, &rng);\n    return curand_uniform(&rng);\n  }\n#else\n  inline float uniform(const uint64_t id) const {\n    pcg32 ng0(seed_, id);\n    std::uniform_real_distribution<float> uni;\n    return uni(ng0);\n  }\n#endif  // __CUDA_ARCH__\n};\n\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_CONTINUOUS_SEED_H_\n"
  },
  {
    "path": "graphbolt/include/graphbolt/cuda_ops.h",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file graphbolt/cuda_ops.h\n * @brief Available CUDA operations in Graphbolt.\n */\n#ifndef GRAPHBOLT_CUDA_OPS_H_\n#define GRAPHBOLT_CUDA_OPS_H_\n\n#include <torch/script.h>\n\n#include <type_traits>\n\nnamespace graphbolt {\nnamespace ops {\n\n/**\n * @brief Sorts the given input and optionally returns the original indexes.\n *\n * @param input         A pointer to storage containing IDs.\n * @param num_items     Size of the input storage.\n * @param num_bits      An integer such that all elements of input tensor are\n *                      are less than (1 << num_bits).\n *\n * @return\n * - A tuple of tensors if return_original_positions is true, where the first\n * one includes sorted input, the second contains original positions of the\n * sorted result. 
If return_original_positions is false, then returns only the\n * sorted input.\n */\ntemplate <bool return_original_positions, typename scalar_t>\nstd::conditional_t<\n    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,\n    torch::Tensor>\nSort(const scalar_t* input, int64_t num_items, int num_bits);\n\n/**\n * @brief Sorts the given input and optionally returns the original indexes.\n *\n * @param input         A tensor containing IDs.\n * @param num_bits      An integer such that all elements of input tensor\n *                      are less than (1 << num_bits).\n *\n * @return\n * - A tuple of tensors if return_original_positions is true, where the first\n * one includes sorted input, the second contains original positions of the\n * sorted result. If return_original_positions is false, then returns only the\n * sorted input.\n */\ntemplate <bool return_original_positions = true>\nstd::conditional_t<\n    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,\n    torch::Tensor>\nSort(torch::Tensor input, int num_bits = 0);\n\n/**\n * @brief Tests if each element of elements is in test_elements. Returns a\n * boolean tensor of the same shape as elements that is True for elements\n * in test_elements and False otherwise. Enhances torch.isin by implementing\n * multi-threaded searching, as detailed in the documentation at\n * https://pytorch.org/docs/stable/generated/torch.isin.html.\n *\n * @param elements        Input elements\n * @param test_elements   Values against which to test for each input element.\n *\n * @return\n * A boolean tensor of the same shape as elements that is True for elements\n * in test_elements and False otherwise.\n */\ntorch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);\n\n/**\n * @brief Returns the indexes of the nonzero elements in the given boolean mask\n * if logical_not is false. 
Otherwise, returns the indexes of the zero elements\n * instead.\n *\n * @param mask        Input boolean mask.\n * @param logical_not Whether mask should be treated as ~mask.\n *\n * @return An int64_t tensor of the same shape as mask containing the indexes\n * of the selected elements.\n */\ntorch::Tensor Nonzero(torch::Tensor mask, bool logical_not);\n\n/**\n * @brief Select columns for a sparse matrix in a CSC format according to nodes\n * tensor.\n *\n * NOTE: The shape of all tensors must be 1-D.\n *\n * @param in_degree Indegree tensor containing degrees of nodes being copied.\n * @param sliced_indptr Sliced_indptr tensor containing indptr values of nodes\n * being copied.\n * @param indices Indices tensor with edge information of shape (indptr[N],).\n * @param nodes Nodes tensor with shape (M,).\n * @param nodes_max An upperbound on `nodes.max()`.\n * @param output_size The total number of edges being copied.\n * @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of\n * shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).\n */\nstd::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(\n    torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,\n    torch::Tensor nodes, int64_t nodes_max,\n    torch::optional<int64_t> output_size = torch::nullopt);\n\n/**\n * @brief Select columns for a sparse matrix in a CSC format according to nodes\n * tensor.\n *\n * NOTE: The shape of all tensors must be 1-D.\n *\n * @param indptr Indptr tensor containing offsets with shape (N,).\n * @param indices Indices tensor with edge information of shape (indptr[N],).\n * @param nodes Nodes tensor with shape (M,).\n * @param output_size The total number of edges being copied.\n * @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of\n * shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).\n */\nstd::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(\n    torch::Tensor indptr, 
torch::Tensor indices, torch::Tensor nodes,\n    torch::optional<int64_t> output_size = torch::nullopt);\n\n/**\n * @brief Select columns for a sparse matrix in a CSC format according to nodes\n * tensor for a given list of tensors.\n *\n * NOTE: The shape of all tensors must be 1-D.\n *\n * @param indptr Indptr tensor containing offsets with shape (N,).\n * @param indices_list Vector of indices tensor with edge information of shape\n * (indptr[N],).\n * @param nodes Nodes tensor with shape (M,).\n * @param with_edge_ids Whether to return edge ids tensor corresponding to\n * sliced edges as the last element of the output.\n * @param output_size The total number of edges being copied.\n * @return (torch::Tensor, std::vector<torch::Tensor>) Output indptr and vector\n * of indices tensors of shapes (M + 1,) and ((indptr[nodes + 1] -\n * indptr[nodes]).sum(),).\n */\nstd::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatchedImpl(\n    torch::Tensor indptr, std::vector<torch::Tensor> indices_list,\n    torch::Tensor nodes, bool with_edge_ids,\n    torch::optional<int64_t> output_size);\n\n/**\n * @brief Slices the indptr tensor with nodes and returns the indegrees of the\n * given nodes and their indptr values.\n *\n * @param indptr The indptr tensor.\n * @param nodes  The nodes to read from indptr. 
If not provided, assumed to be\n * equal to torch.arange(indptr.size(0) - 1).\n *\n * @return Tuple of tensors with values:\n * (indptr[nodes + 1] - indptr[nodes], indptr[nodes]), the returned indegrees\n * tensor (first one) has size nodes.size(0) + 1 so that calling ExclusiveCumSum\n * on it gives the output indptr.\n */\nstd::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(\n    torch::Tensor indptr, torch::optional<torch::Tensor> nodes);\n\n/**\n * @brief Given the compacted sub_indptr tensor, edge type tensor and\n * sliced_indptr tensor of the original graph, returns the heterogeneous\n * versions of sub_indptr, indegrees and sliced_indptr.\n *\n * @param sub_indptr     The compacted indptr tensor.\n * @param etypes         The compacted type_per_edge tensor.\n * @param sliced_indptr  The sliced_indptr tensor of original graph.\n * @param num_fanouts    The number of fanout values.\n *\n * @return Tuple of tensors (new_sub_indptr, new_indegrees, new_sliced_indptr).\n */\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SliceCSCIndptrHetero(\n    torch::Tensor sub_indptr, torch::Tensor etypes, torch::Tensor sliced_indptr,\n    int64_t num_fanouts);\n\n/**\n * @brief Computes the exclusive prefix sum of the given input.\n *\n * @param input The input tensor.\n *\n * @return The prefix sum result such that r[i] = \\sum_{j=0}^{i-1} input[j]\n */\ntorch::Tensor ExclusiveCumSum(torch::Tensor input);\n\n/**\n * @brief Computes the gather operation on a given input and index tensor.\n *\n * @param input The input tensor.\n * @param index The index tensor.\n * @param dtype The optional output dtype. If not given, inferred from the input\n * tensor.\n *\n * @return The result of the input.gather(0, index).to(dtype) operation.\n */\ntorch::Tensor Gather(\n    torch::Tensor input, torch::Tensor index,\n    torch::optional<torch::ScalarType> dtype = torch::nullopt);\n\n/**\n * @brief Select rows from input tensor according to index tensor.\n *\n * NOTE:\n * 1. 
The shape of input tensor can be multi-dimensional, but the index tensor\n * must be 1-D.\n * 2. Should be called if input is on pinned memory and index is on pinned\n * memory or GPU memory.\n *\n * @param input Input tensor with shape (N, ...).\n * @param index Index tensor with shape (M,).\n * @return torch::Tensor Output tensor with shape (M, ...).\n */\ntorch::Tensor UVAIndexSelectImpl(torch::Tensor input, torch::Tensor index);\n\n/**\n * @brief ExpandIndptrImpl implements conversion from a given indptr offset\n * tensor to a COO format tensor. If node_ids is not given, it is assumed to be\n * equal to torch::arange(indptr.size(0) - 1, dtype=dtype).\n *\n * @param indptr       The indptr offset tensor.\n * @param dtype        The dtype of the returned output tensor.\n * @param node_ids     Optional 1D tensor represents the node ids.\n * @param output_size  Optional value of indptr[-1]. Passing it eliminates CPU\n * GPU synchronization.\n *\n * @return The resulting tensor.\n */\ntorch::Tensor ExpandIndptrImpl(\n    torch::Tensor indptr, torch::ScalarType dtype,\n    torch::optional<torch::Tensor> node_ids = torch::nullopt,\n    torch::optional<int64_t> output_size = torch::nullopt);\n\n/**\n * @brief IndptrEdgeIdsImpl implements conversion from a given indptr offset\n * tensor to a COO edge ids tensor. For a given indptr [0, 2, 5, 7] and offset\n * tensor [0, 100, 200], the output will be [0, 1, 100, 101, 102, 201, 202]. If\n * offset was not provided, the output would be [0, 1, 0, 1, 2, 0, 1].\n *\n * @param indptr       The indptr offset tensor.\n * @param dtype        The dtype of the returned output tensor.\n * @param offset       The offset tensor.\n * @param output_size  Optional value of indptr[-1]. 
Passing it eliminates CPU\n * GPU synchronization.\n *\n * @return The resulting tensor.\n */\ntorch::Tensor IndptrEdgeIdsImpl(\n    torch::Tensor indptr, torch::ScalarType dtype,\n    torch::optional<torch::Tensor> offset,\n    torch::optional<int64_t> output_size);\n\n/**\n * @brief Removes duplicate elements from the concatenated 'unique_dst_ids' and\n * 'src_ids' tensor and applies the uniqueness information to compact both\n * source and destination tensors.\n *\n * The function performs two main operations:\n *   1. Unique Operation: 'unique(concat(unique_dst_ids, src_ids))', in which\n * the unique operator will guarantee the 'unique_dst_ids' are at the head of\n * the result tensor.\n *   2. Compact Operation: Utilizes the reverse mapping derived from the unique\n * operation to transform 'src_ids' and 'dst_ids' into compacted IDs.\n *\n * When world_size is greater than 1, then the given ids are partitioned between\n * the available ranks. The ids corresponding to the given rank are guaranteed\n * to come before the ids of other ranks. To do this, the partition ids are\n * rotated backwards by the given rank so that the ids are ordered as:\n * [rank, rank + 1, ..., world_size - 1, 0, ..., rank - 1]. This is supported\n * only for Volta and later generation NVIDIA GPUs.\n *\n * @param src_ids         A tensor containing source IDs.\n * @param dst_ids         A tensor containing destination IDs.\n * @param unique_dst_ids  A tensor containing unique destination IDs, which is\n *                        exactly all the unique elements in 'dst_ids'.\n * @param rank            The rank of the current GPU.\n * @param world_size      The total number of GPUs (world size).\n *\n * @return (unique_ids, compacted_src_ids, compacted_dst_ids, unique_offsets)\n * - A tensor representing all unique elements in 'src_ids' and 'dst_ids' after\n * removing duplicates. 
The indices in this tensor precisely match the compacted\n * IDs of the corresponding elements.\n * - The tensor corresponding to the 'src_ids' tensor, where the entries are\n * mapped to compacted IDs.\n * - The tensor corresponding to the 'dst_ids' tensor, where the entries are\n * mapped to compacted IDs.\n * - The tensor corresponding to the offsets into the unique_ids tensor. Has\n * size `world_size + 1` and unique_ids[offsets[i]: offsets[i + 1]] belongs to\n * the rank `(rank + i) % world_size`.\n *\n * @example\n *   torch::Tensor src_ids = src;\n *   torch::Tensor dst_ids = dst;\n *   torch::Tensor unique_dst_ids = torch::unique(dst);\n *   auto result = UniqueAndCompact(\n *       src_ids, dst_ids, unique_dst_ids, rank, world_size);\n *   torch::Tensor unique_ids = std::get<0>(result);\n *   torch::Tensor compacted_src_ids = std::get<1>(result);\n *   torch::Tensor compacted_dst_ids = std::get<2>(result);\n */\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nUniqueAndCompact(\n    const torch::Tensor src_ids, const torch::Tensor dst_ids,\n    const torch::Tensor unique_dst_ids, const int64_t rank,\n    const int64_t world_size);\n\n/**\n * @brief Batched version of UniqueAndCompact. The ith element of the return\n * value is equal to the result of passing the ith elements of the input\n * arguments to UniqueAndCompact.\n */\nstd::vector<\n    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>\nUniqueAndCompactBatched(\n    const std::vector<torch::Tensor>& src_ids,\n    const std::vector<torch::Tensor>& dst_ids,\n    const std::vector<torch::Tensor>& unique_dst_ids, const int64_t rank,\n    const int64_t world_size);\n\n}  //  namespace ops\n}  //  namespace graphbolt\n\n#endif  // GRAPHBOLT_CUDA_OPS_H_\n"
  },
  {
    "path": "graphbolt/include/graphbolt/cuda_sampling_ops.h",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file graphbolt/cuda_sampling_ops.h\n * @brief Available CUDA sampling operations in Graphbolt.\n */\n#ifndef GRAPHBOLT_CUDA_SAMPLING_OPS_H_\n#define GRAPHBOLT_CUDA_SAMPLING_OPS_H_\n\n#include <graphbolt/fused_sampled_subgraph.h>\n#include <torch/script.h>\n\nnamespace graphbolt {\nnamespace ops {\n\n/**\n * @brief Sample neighboring edges of the given nodes and return the induced\n * subgraph.\n *\n * @param indptr Index pointer array of the CSC.\n * @param indices Indices array of the CSC.\n * @param seeds The nodes from which to sample neighbors. 
If not provided,\n * assumed to be equal to torch.arange(indptr.size(0) - 1).\n * @param seed_offsets The offsets of the given seeds,\n * seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type i.\n * @param fanouts The number of edges to be sampled for each node with or\n * without considering edge types.\n *   - When the length is 1, it indicates that the fanout applies to all\n * neighbors of the node as a collective, regardless of the edge type.\n *   - Otherwise, the length should be equal to the number of edge types, and\n * each fanout value corresponds to a specific edge type of the node.\n * The value of each fanout should be >= 0 or = -1.\n *   - When the value is -1, all neighbors will be chosen for sampling. It is\n * equivalent to selecting all neighbors with non-zero probability when the\n * fanout is >= the number of neighbors (and replacement is set to false).\n *   - When the value is a non-negative integer, it serves as a minimum\n * threshold for selecting neighbors.\n * @param replace Boolean indicating whether the sample is performed with or\n * without replacement. If True, a value can be selected multiple times.\n * Otherwise, each value can be selected only once.\n * @param layer Boolean indicating whether neighbors should be sampled in a\n * layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of\n * sampled edges, see arXiv:2210.13339.\n * @param returning_indices_is_optional Boolean indicating whether returning\n * indices tensor is optional.\n * @param type_per_edge A tensor representing the type of each edge, if present.\n * @param probs_or_mask An optional tensor with (unnormalized) probabilities\n * corresponding to each neighboring edge of a node. It must be\n * a 1D tensor, with the number of elements equaling the total number of edges.\n * @param node_type_to_id A dictionary mapping node type names to type IDs. The\n * length of it is equal to the number of node types. 
The key is the node type\n * name, and the value is the corresponding type ID.\n * @param edge_type_to_id A dictionary mapping edge type names to type IDs. The\n * length of it is equal to the number of edge types. The key is the edge type\n * name, and the value is the corresponding type ID.\n * @param random_seed The random seed for the sampler for layer=True.\n * @param seed2_contribution The contribution of the second random seed, [0, 1)\n * for layer=True.\n * @param seeds_timestamp The timestamp of the seeds.\n * @param seeds_pre_time_window The time window of the seeds represents a period\n * of time before `seeds_timestamp`. If provided, only neighbors and related\n * edges whose timestamps fall within\n * `[seeds_timestamp - seeds_pre_time_window, seeds_timestamp]` will be\n * filtered.\n * @param node_timestamp An optional tensor that contains the timestamp of nodes\n * in the graph.\n * @param edge_timestamp An optional tensor that contains the timestamp of edges\n * in the graph.\n *\n * @return An intrusive pointer to a FusedSampledSubgraph object containing\n * the sampled graph's information.\n */\nc10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(\n    torch::Tensor indptr, torch::Tensor indices,\n    torch::optional<torch::Tensor> seeds,\n    torch::optional<std::vector<int64_t>> seed_offsets,\n    const std::vector<int64_t>& fanouts, bool replace, bool layer,\n    bool returning_indices_is_optional,\n    torch::optional<torch::Tensor> type_per_edge = torch::nullopt,\n    torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,\n    torch::optional<torch::Tensor> node_type_offset = torch::nullopt,\n    torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id =\n        torch::nullopt,\n    torch::optional<torch::Dict<std::string, int64_t>> edge_type_to_id =\n        torch::nullopt,\n    torch::optional<torch::Tensor> random_seed = torch::nullopt,\n    float seed2_contribution = .0f,\n    // Optional temporal 
sampling arguments begin.\n    torch::optional<torch::Tensor> seeds_timestamp = torch::nullopt,\n    torch::optional<torch::Tensor> seeds_pre_time_window = torch::nullopt,\n    torch::optional<torch::Tensor> node_timestamp = torch::nullopt,\n    torch::optional<torch::Tensor> edge_timestamp = torch::nullopt\n    // Optional temporal sampling arguments end.\n);\n\n/**\n * @brief Return the subgraph induced on the inbound edges of the given nodes.\n * @param nodes Type agnostic node IDs to form the subgraph.\n *\n * @return FusedSampledSubgraph.\n */\nc10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(\n    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,\n    torch::optional<torch::Tensor> type_per_edge);\n\n}  //  namespace ops\n}  //  namespace graphbolt\n\n#endif  // GRAPHBOLT_CUDA_SAMPLING_OPS_H_\n"
  },
  {
    "path": "graphbolt/include/graphbolt/fused_csc_sampling_graph.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file graphbolt/fused_csc_sampling_graph.h\n * @brief Header file of csc sampling graph.\n */\n#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_\n#define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_\n\n#include <graphbolt/async.h>\n#include <graphbolt/continuous_seed.h>\n#include <graphbolt/fused_sampled_subgraph.h>\n#include <graphbolt/shared_memory.h>\n#include <torch/torch.h>\n\n#include <string>\n#include <vector>\n\nnamespace graphbolt {\nnamespace sampling {\n\nenum SamplerType { NEIGHBOR, LABOR, LABOR_DEPENDENT };\nenum TemporalOption { NOT_TEMPORAL, TEMPORAL };\n\nconstexpr bool is_labor(SamplerType S) {\n  return S == SamplerType::LABOR || S == SamplerType::LABOR_DEPENDENT;\n}\n\ntemplate <SamplerType S>\nstruct SamplerArgs;\n\ntemplate <>\nstruct SamplerArgs<SamplerType::NEIGHBOR> {};\n\ntemplate <>\nstruct SamplerArgs<SamplerType::LABOR> {\n  const torch::Tensor& indices;\n  single_seed random_seed;\n  int64_t num_nodes;\n};\n\ntemplate <>\nstruct SamplerArgs<SamplerType::LABOR_DEPENDENT> {\n  const torch::Tensor& indices;\n  continuous_seed random_seed;\n  int64_t num_nodes;\n};\n\n/**\n * @brief A sampling oriented csc format graph.\n *\n * Example usage:\n *\n * Suppose the graph has 3 node types, 3 edge types and 6 edges\n * auto node_type_offset = {0, 2, 4, 6}\n * auto type_per_edge = {0, 1, 0, 2, 1, 2}\n * auto graph = FusedCSCSamplingGraph(..., ..., node_type_offset, type_per_edge)\n *\n * The `node_type_offset` tensor represents the offset array of node type, the\n * given array indicates that node [0, 2) has type id 0, [2, 4) has type id 1,\n * and [4, 6) has type id 2. 
And the `type_per_edge` tensor represents the type\n * id of each edge.\n */\nclass FusedCSCSamplingGraph : public torch::CustomClassHolder {\n public:\n  using NodeTypeToIDMap = torch::Dict<std::string, int64_t>;\n  using EdgeTypeToIDMap = torch::Dict<std::string, int64_t>;\n  using NodeAttrMap = torch::Dict<std::string, torch::Tensor>;\n  using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>;\n  /** @brief Default constructor. */\n  FusedCSCSamplingGraph() = default;\n\n  /**\n   * @brief Constructor for CSC with data.\n   * @param indptr The CSC format index pointer array.\n   * @param indices The CSC format index array.\n   * @param node_type_offset A tensor representing the offset of node types, if\n   * present.\n   * @param type_per_edge A tensor representing the type of each edge, if\n   * present.\n   * @param node_type_to_id A dictionary mapping node type names to type IDs, if\n   * present.\n   * @param edge_type_to_id A dictionary mapping edge type names to type IDs, if\n   * present.\n   * @param node_attributes A dictionary of node attributes, if present.\n   * @param edge_attributes A dictionary of edge attributes, if present.\n   *\n   */\n  FusedCSCSamplingGraph(\n      const torch::Tensor& indptr, const torch::Tensor& indices,\n      const torch::optional<torch::Tensor>& node_type_offset = torch::nullopt,\n      const torch::optional<torch::Tensor>& type_per_edge = torch::nullopt,\n      const torch::optional<NodeTypeToIDMap>& node_type_to_id = torch::nullopt,\n      const torch::optional<EdgeTypeToIDMap>& edge_type_to_id = torch::nullopt,\n      const torch::optional<NodeAttrMap>& node_attributes = torch::nullopt,\n      const torch::optional<EdgeAttrMap>& edge_attributes = torch::nullopt);\n\n  /**\n   * @brief Create a fused CSC graph from tensors of CSC format.\n   * @param indptr Index pointer array of the CSC.\n   * @param indices Indices array of the CSC.\n   * @param node_type_offset A tensor representing the offset of node types, 
if\n   * present.\n   * @param type_per_edge A tensor representing the type of each edge, if\n   * present.\n   * @param node_type_to_id A dictionary mapping node type names to type IDs, if\n   * present.\n   * @param edge_type_to_id A dictionary mapping edge type names to type IDs, if\n   * present.\n   * @param node_attributes A dictionary of node attributes, if present.\n   * @param edge_attributes A dictionary of edge attributes, if present.\n   *\n   * @return FusedCSCSamplingGraph\n   */\n  static c10::intrusive_ptr<FusedCSCSamplingGraph> Create(\n      const torch::Tensor& indptr, const torch::Tensor& indices,\n      const torch::optional<torch::Tensor>& node_type_offset,\n      const torch::optional<torch::Tensor>& type_per_edge,\n      const torch::optional<NodeTypeToIDMap>& node_type_to_id,\n      const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,\n      const torch::optional<NodeAttrMap>& node_attributes,\n      const torch::optional<EdgeAttrMap>& edge_attributes);\n\n  /** @brief Get the number of nodes. */\n  int64_t NumNodes() const { return indptr_.size(0) - 1; }\n\n  /** @brief Get the number of edges. */\n  int64_t NumEdges() const { return indices_.size(0); }\n\n  /** @brief Get the csc index pointer tensor. */\n  const torch::Tensor CSCIndptr() const { return indptr_; }\n\n  /** @brief Get the index tensor. */\n  const torch::Tensor Indices() const { return indices_; }\n\n  /** @brief Get the node type offset tensor for a heterogeneous graph. */\n  inline const torch::optional<torch::Tensor> NodeTypeOffset() const {\n    return node_type_offset_;\n  }\n\n  /** @brief Get the edge type tensor for a heterogeneous graph. 
*/\n  inline const torch::optional<torch::Tensor> TypePerEdge() const {\n    return type_per_edge_;\n  }\n\n  /**\n   * @brief Get the node type to id map for a heterogeneous graph.\n   * @note The map is a dictionary mapping node type names to type IDs.\n   */\n  inline const torch::optional<NodeTypeToIDMap> NodeTypeToID() const {\n    return node_type_to_id_;\n  }\n\n  /**\n   * @brief Get the edge type to id map for a heterogeneous graph.\n   * @note The map is a dictionary mapping edge type names to type IDs.\n   */\n  inline const torch::optional<EdgeTypeToIDMap> EdgeTypeToID() const {\n    return edge_type_to_id_;\n  }\n\n  /** @brief Get the node attributes dictionary. */\n  inline const torch::optional<EdgeAttrMap> NodeAttributes() const {\n    return node_attributes_;\n  }\n\n  /** @brief Get the edge attributes dictionary. */\n  inline const torch::optional<EdgeAttrMap> EdgeAttributes() const {\n    return edge_attributes_;\n  }\n\n  /**\n   * @brief Get the node attribute tensor by name.\n   *\n   * If the input name is empty, return nullopt. Otherwise, return the node\n   * attribute tensor by name.\n   */\n  inline torch::optional<torch::Tensor> NodeAttribute(\n      torch::optional<std::string> name) const {\n    if (!name.has_value()) {\n      return torch::nullopt;\n    }\n    TORCH_CHECK(\n        node_attributes_.has_value() &&\n            node_attributes_.value().contains(name.value()),\n        \"Node attribute \", name.value(), \" does not exist.\");\n    return torch::optional<torch::Tensor>(\n        node_attributes_.value().at(name.value()));\n  }\n\n  /**\n   * @brief Get the edge attribute tensor by name.\n   *\n   * If the input name is empty, return nullopt. 
Otherwise, return the edge\n   * attribute tensor by name.\n   */\n  inline torch::optional<torch::Tensor> EdgeAttribute(\n      torch::optional<std::string> name) const {\n    if (!name.has_value()) {\n      return torch::nullopt;\n    }\n    TORCH_CHECK(\n        edge_attributes_.has_value() &&\n            edge_attributes_.value().contains(name.value()),\n        \"Edge attribute \", name.value(), \" does not exist.\");\n    return torch::optional<torch::Tensor>(\n        edge_attributes_.value().at(name.value()));\n  }\n\n  /** @brief Set the csc index pointer tensor. */\n  inline void SetCSCIndptr(const torch::Tensor& indptr) { indptr_ = indptr; }\n\n  /** @brief Set the index tensor. */\n  inline void SetIndices(const torch::Tensor& indices) { indices_ = indices; }\n\n  /** @brief Set the node type offset tensor for a heterogeneous graph. */\n  inline void SetNodeTypeOffset(\n      const torch::optional<torch::Tensor>& node_type_offset) {\n    node_type_offset_ = node_type_offset;\n  }\n\n  /** @brief Set the edge type tensor for a heterogeneous graph. */\n  inline void SetTypePerEdge(\n      const torch::optional<torch::Tensor>& type_per_edge) {\n    type_per_edge_ = type_per_edge;\n  }\n\n  /**\n   * @brief Set the node type to id map for a heterogeneous graph.\n   * @note The map is a dictionary mapping node type names to type IDs.\n   */\n  inline void SetNodeTypeToID(\n      const torch::optional<NodeTypeToIDMap>& node_type_to_id) {\n    node_type_to_id_ = node_type_to_id;\n  }\n\n  /**\n   * @brief Set the edge type to id map for a heterogeneous graph.\n   * @note The map is a dictionary mapping edge type names to type IDs.\n   */\n  inline void SetEdgeTypeToID(\n      const torch::optional<EdgeTypeToIDMap>& edge_type_to_id) {\n    edge_type_to_id_ = edge_type_to_id;\n  }\n\n  /** @brief Set the node attributes dictionary. 
*/\n  inline void SetNodeAttributes(\n      const torch::optional<EdgeAttrMap>& node_attributes) {\n    node_attributes_ = node_attributes;\n  }\n\n  /** @brief Set the edge attributes dictionary. */\n  inline void SetEdgeAttributes(\n      const torch::optional<EdgeAttrMap>& edge_attributes) {\n    edge_attributes_ = edge_attributes;\n  }\n\n  /** @brief Add node attribute by name. */\n  inline void AddNodeAttribute(\n      const std::string& name, const torch::Tensor& node_attribute) {\n    if (!node_attributes_.has_value()) {\n      node_attributes_ = NodeAttrMap();\n    }\n    node_attributes_.value().insert_or_assign(name, node_attribute);\n  }\n\n  /** @brief Add edge attribute by name. */\n  inline void AddEdgeAttribute(\n      const std::string& name, const torch::Tensor& edge_attribute) {\n    if (!edge_attributes_.has_value()) {\n      edge_attributes_ = EdgeAttrMap();\n    }\n    edge_attributes_.value().insert_or_assign(name, edge_attribute);\n  }\n\n  /**\n   * @brief Magic number to indicate graph version in serialize/deserialize\n   * stage.\n   */\n  static constexpr int64_t kCSCSamplingGraphSerializeMagic = 0xDD2E60F0F6B4A128;\n\n  /**\n   * @brief Load graph from stream.\n   * @param archive Input stream for deserializing.\n   */\n  void Load(torch::serialize::InputArchive& archive);\n\n  /**\n   * @brief Save graph to stream.\n   * @param archive Output stream for serializing.\n   */\n  void Save(torch::serialize::OutputArchive& archive) const;\n\n  /**\n   * @brief Pickle method for deserializing.\n   * @param state The state of serialized FusedCSCSamplingGraph.\n   */\n  void SetState(\n      const torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>&\n          state);\n\n  /**\n   * @brief Pickle method for serializing.\n   * @returns The state of this FusedCSCSamplingGraph.\n   */\n  torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>> GetState()\n      const;\n\n  /**\n   * @brief Return the subgraph induced on the 
inbound edges of the given nodes.\n   * @param nodes Type agnostic node IDs to form the subgraph.\n   *\n   * @return FusedSampledSubgraph.\n   */\n  c10::intrusive_ptr<FusedSampledSubgraph> InSubgraph(\n      const torch::Tensor& nodes) const;\n\n  /**\n   * @brief Sample neighboring edges of the given nodes and return the induced\n   * subgraph.\n   *\n   * @param seeds The nodes from which to sample neighbors. If not provided,\n   * assumed to be equal to torch.arange(NumNodes()).\n   * @param seed_offsets The offsets of the given seeds,\n   * seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type id i.\n   * @param fanouts The number of edges to be sampled for each node with or\n   * without considering edge types.\n   *   - When the length is 1, it indicates that the fanout applies to all\n   * neighbors of the node as a collective, regardless of the edge type.\n   *   - Otherwise, the length should be equal to the number of edge types, and\n   * each fanout value corresponds to a specific edge type of the node.\n   * The value of each fanout should be >= 0 or = -1.\n   *   - When the value is -1, all neighbors will be chosen for sampling. It is\n   * equivalent to selecting all neighbors with non-zero probability when the\n   * fanout is >= the number of neighbors (and replacement is set to false).\n   *   - When the value is a non-negative integer, it serves as a minimum\n   * threshold for selecting neighbors.\n   * @param replace Boolean indicating whether the sample is performed with or\n   * without replacement. If True, a value can be selected multiple times.\n   * Otherwise, each value can be selected only once.\n   * @param layer Boolean indicating whether neighbors should be sampled in a\n   * layer sampling fashion. 
Uses the LABOR-0 algorithm to increase overlap of\n   * sampled edges, see arXiv:2210.13339.\n   * @param returning_indices_is_optional Boolean indicating whether returning\n   * indices tensor is optional.\n   * @param probs_or_mask An optional edge attribute tensor for probablities\n   * or masks. This attribute tensor should contain (unnormalized)\n   * probabilities corresponding to each neighboring edge of a node. It must be\n   * a 1D floating-point or boolean tensor, with the number of elements\n   * equalling the total number of edges.\n   * @param random_seed The random seed for the sampler for layer=True.\n   * @param seed2_contribution The contribution of the second random seed,\n   * [0, 1) for layer=True.\n   *\n   * @return An intrusive pointer to a FusedSampledSubgraph object containing\n   * the sampled graph's information.\n   */\n  c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(\n      torch::optional<torch::Tensor> seeds,\n      torch::optional<std::vector<int64_t>> seed_offsets,\n      const std::vector<int64_t>& fanouts, bool replace, bool layer,\n      bool returning_indices_is_optional,\n      torch::optional<torch::Tensor> probs_or_mask,\n      torch::optional<torch::Tensor> random_seed,\n      double seed2_contribution) const;\n\n  c10::intrusive_ptr<Future<c10::intrusive_ptr<FusedSampledSubgraph>>>\n  SampleNeighborsAsync(\n      torch::optional<torch::Tensor> seeds,\n      torch::optional<std::vector<int64_t>> seed_offsets,\n      const std::vector<int64_t>& fanouts, bool replace, bool layer,\n      bool returning_indices_is_optional,\n      torch::optional<torch::Tensor> probs_or_mask,\n      torch::optional<torch::Tensor> random_seed,\n      double seed2_contribution) const;\n\n  /**\n   * @brief Sample neighboring edges of the given nodes with a temporal\n   * constraint. 
If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is\n   * given, the sampled neighbors or edges of an input node must have a\n   * timestamp that is smaller than that of the input node.\n   *\n   * @param seeds The seeds nodes from which to sample neighbors.\n   * @param seed_offsets The offsets of the given seeds,\n   * seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type id i.\n   * @param seeds_timestamp The timestamp of the nodes.\n   * @param fanouts The number of edges to be sampled for each node with or\n   * without considering edge types, following the same rules as in\n   * SampleNeighbors.\n   * @param replace Boolean indicating whether the sample is preformed with or\n   * without replacement. If True, a value can be selected multiple times.\n   * Otherwise, each value can be selected only once.\n   * @param layer Boolean indicating whether neighbors should be sampled in a\n   * layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of\n   * sampled edges, see arXiv:2210.13339.\n   * @param returning_indices_is_optional Boolean indicating whether returning\n   * indices tensor is optional.\n   * @param seeds_pre_time_window The time window of the seed nodes represents\n   * a period of time before `seeds_timestamp`. 
If provided, only\n   * neighbors and related edges whose timestamps fall within\n   * `[seeds_timestamp - seeds_pre_time_window, seeds_timestamp]` will be\n   * filtered.\n   * @param probs_or_mask An optional edge attribute tensor for probablities\n   * or masks, following the same rules as in SampleNeighbors.\n   * @param node_timestamp_attr_name An optional string specifying the name of\n   * the node attribute that contains the timestamp of nodes in the graph.\n   * @param edge_timestamp_attr_name An optional string specifying the name of\n   * the edge attribute that contains the timestamp of edges in the graph.\n   *\n   * @return An intrusive pointer to a FusedSampledSubgraph object containing\n   * the sampled graph's information.\n   *\n   */\n  c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighbors(\n      const torch::optional<torch::Tensor>& seeds,\n      const torch::optional<std::vector<int64_t>>& seed_offsets,\n      const torch::Tensor& seeds_timestamp, const std::vector<int64_t>& fanouts,\n      bool replace, bool layer, bool returning_indices_is_optional,\n      torch::optional<torch::Tensor> seeds_pre_time_window,\n      torch::optional<torch::Tensor> probs_or_mask,\n      torch::optional<std::string> node_timestamp_attr_name,\n      torch::optional<std::string> edge_timestamp_attr_name,\n      torch::optional<torch::Tensor> random_seed,\n      double seed2_contribution) const;\n\n  /**\n   * @brief Copy the graph to shared memory.\n   * @param shared_memory_name The name of the shared memory.\n   *\n   * @return A new FusedCSCSamplingGraph object on shared memory.\n   */\n  c10::intrusive_ptr<FusedCSCSamplingGraph> CopyToSharedMemory(\n      const std::string& shared_memory_name);\n\n  /**\n   * @brief Load the graph from shared memory.\n   * @param shared_memory_name The name of the shared memory.\n   *\n   * @return A new FusedCSCSamplingGraph object on shared memory.\n   */\n  static c10::intrusive_ptr<FusedCSCSamplingGraph> 
LoadFromSharedMemory(\n      const std::string& shared_memory_name);\n\n  /**\n   * @brief Hold the shared memory objects of the the tensor metadata and data.\n   * @note Shared memory used to hold the tensor metadata and data of this\n   * class. By storing its shared memory objects, the graph controls the\n   * resources of shared memory, which will be released automatically when the\n   * graph is destroyed. This function is for internal use by CopyToSharedMemory\n   * and LoadFromSharedMemory. Please contact the DGL team if you need to use\n   * it.\n   * @param tensor_metadata_shm The shared memory objects of tensor metadata.\n   * @param tensor_data_shm The shared memory objects of tensor data.\n   */\n  void HoldSharedMemoryObject(\n      SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm);\n\n private:\n  template <TemporalOption Temporal, typename NumPickFn, typename PickFn>\n  c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl(\n      const torch::Tensor& seeds,\n      const torch::optional<std::vector<int64_t>>& seed_offsets,\n      const std::vector<int64_t>& fanouts, NumPickFn num_pick_fn,\n      PickFn pick_fn) const;\n\n  /** @brief CSC format index pointer array. */\n  torch::Tensor indptr_;\n\n  /** @brief CSC format index array. */\n  torch::Tensor indices_;\n\n  /**\n   * @brief Offset array of node type. The length of it is equal to the number\n   * of node types + 1. The tensor is in ascending order as nodes of the same\n   * type have continuous IDs, and larger node IDs are paired with larger node\n   * type IDs. Its first value is 0 and last value is the number of nodes. And\n   * nodes with ID between `node_type_offset_[i] ~ node_type_offset_[i+1]` are\n   * of type id `i`.\n   */\n  torch::optional<torch::Tensor> node_type_offset_;\n\n  /**\n   * @brief Type id of each edge, where type id is the corresponding index of\n   * edge types. 
The length of it is equal to the number of edges.\n   */\n  torch::optional<torch::Tensor> type_per_edge_;\n\n  /**\n   * @brief A dictionary mapping node type names to type IDs. The length of it\n   * is equal to the number of node types. The key is the node type name, and\n   * the value is the corresponding type ID.\n   */\n  torch::optional<NodeTypeToIDMap> node_type_to_id_;\n\n  /**\n   * @brief A dictionary mapping edge type names to type IDs. The length of it\n   * is equal to the number of edge types. The key is the edge type name, and\n   * the value is the corresponding type ID.\n   */\n  torch::optional<EdgeTypeToIDMap> edge_type_to_id_;\n\n  /**\n   * @brief A dictionary of node attributes. Each key represents the attribute's\n   * name, while the corresponding value holds the attribute's specific value.\n   * The length of each value should match the total number of nodes.\"\n   */\n  torch::optional<NodeAttrMap> node_attributes_;\n\n  /**\n   * @brief A dictionary of edge attributes. Each key represents the attribute's\n   * name, while the corresponding value holds the attribute's specific value.\n   * The length of each value should match the total number of edges.\"\n   */\n  torch::optional<EdgeAttrMap> edge_attributes_;\n\n  /**\n   * @brief Shared memory used to hold the tensor metadata and data of this\n   * class. By storing its shared memory objects, the graph controls the\n   * resources of shared memory, which will be released automatically when the\n   * graph is destroyed.\n   */\n  SharedMemoryPtr tensor_metadata_shm_, tensor_data_shm_;\n};\n\n/**\n * @brief Calculate the number of the neighbors to be picked for the given node.\n *\n * @param fanout The number of edges to be sampled for each node. It should be\n * >= 0 or -1.\n *  - When the value is -1, all neighbors (with non-zero probability, if\n * weighted) will be chosen for sampling. 
It is equivalent to selecting all\n * neighbors with non-zero probability when the fanout is >= the number of\n * neighbors (and replacement is set to false).\n *  - When the value is a non-negative integer, it serves as a minimum\n * threshold for selecting neighbors.\n * @param replace Boolean indicating whether the sample is performed with or\n * without replacement. If True, a value can be selected multiple times.\n * Otherwise, each value can be selected only once.\n * @param probs_or_mask Optional tensor containing the (unnormalized)\n * probabilities associated with each neighboring edge of a node in the original\n * graph. It must be a 1D floating-point tensor with the number of elements\n * equal to the number of edges in the graph.\n * @param offset The starting edge ID for the connected neighbors of the given\n * node.\n * @param num_neighbors The number of neighbors of this node.\n * @param num_picked_ptr The pointer of the tensor which stores the pick\n * numbers.\n */\ntemplate <typename PickedNumType>\nvoid NumPick(\n    int64_t fanout, bool replace,\n    const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,\n    int64_t num_neighbors, PickedNumType* num_picked_ptr);\n\nint64_t TemporalNumPick(\n    torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout,\n    bool replace, const torch::optional<torch::Tensor>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,\n    int64_t offset, int64_t num_neighbors);\n\ntemplate <typename PickedNumType>\nvoid NumPickByEtype(\n    bool with_seed_offsets, const std::vector<int64_t>& fanouts, bool replace,\n    const torch::Tensor& type_per_edge,\n    const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,\n    int64_t num_neighbors, PickedNumType* num_picked_ptr, int64_t seed_index,\n    const 
std::vector<int64_t>& etype_id_to_num_picked_offset);\n\nint64_t TemporalNumPickByEtype(\n    torch::Tensor seed_timestamp, torch::Tensor csc_indices,\n    const std::vector<int64_t>& fanouts, bool replace,\n    const torch::Tensor& type_per_edge,\n    const torch::optional<torch::Tensor>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,\n    int64_t offset, int64_t num_neighbors);\n\n/**\n * @brief Picks a specified number of neighbors for a node, starting from the\n * given offset and having the specified number of neighbors.\n *\n * If 'probs_or_mask' is provided, it indicates that the sampling is\n * non-uniform. In such cases:\n * - When the number of neighbors with non-zero probability is less than or\n * equal to fanout, all neighbors with non-zero probability will be selected.\n * - When the number of neighbors with non-zero probability exceeds fanout, the\n * sampling process will select 'fanout' elements based on their respective\n * probabilities. Higher probabilities will increase the chances of being chosen\n * during the sampling process.\n *\n * @param offset The starting edge ID for the connected neighbors of the sampled\n * node.\n * @param num_neighbors The number of neighbors to pick.\n * @param fanout The number of edges to be sampled for each node. It should be\n * >= 0 or -1.\n *  - When the value is -1, all neighbors will be chosen for sampling. It is\n * equivalent to selecting all neighbors with non-zero probability when the\n * fanout is >= the number of neighbors (and replacement is set to false).\n *  - When the value is a non-negative integer, it serves as a minimum\n * threshold for selecting neighbors.\n * @param replace Boolean indicating whether the sample is preformed with or\n * without replacement. 
If True, a value can be selected multiple times.\n * Otherwise, each value can be selected only once.\n * @param options Tensor options specifying the desired data type of the result.\n * @param probs_or_mask Optional tensor containing the (unnormalized)\n * probabilities associated with each neighboring edge of a node in the original\n * graph. It must be a 1D floating-point tensor with the number of elements\n * equal to the number of edges in the graph.\n * @param picked_data_ptr The destination address where the picked neighbors\n * should be put. Enough memory space should be allocated in advance.\n */\ntemplate <typename PickedType>\nint64_t Pick(\n    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,\n    const torch::TensorOptions& options,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);\n\ntemplate <SamplerType S, typename PickedType>\nstd::enable_if_t<is_labor(S), int64_t> Pick(\n    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,\n    const torch::TensorOptions& options,\n    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,\n    PickedType* picked_data_ptr);\n\ntemplate <typename PickedType>\nint64_t TemporalPick(\n    torch::Tensor seed_timestamp, torch::Tensor csc_indices,\n    int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout,\n    bool replace, const torch::TensorOptions& options,\n    const torch::optional<torch::Tensor>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp,\n    PickedType* picked_data_ptr);\n\n/**\n * @brief Picks a specified number of neighbors for a node per edge type,\n * starting from the given offset and having the specified number of neighbors.\n *\n * @param offset The starting edge ID for the connected 
neighbors of the sampled\n * node.\n * @param num_neighbors The number of neighbors to pick.\n * @param fanouts The edge sampling numbers corresponding to each edge type for\n * a single node. The value of each fanout should be >= 0 or = -1.\n *  - When the value is -1, all neighbors with non-zero probability will be\n * chosen for sampling. It is equivalent to selecting all neighbors when the\n * fanout is >= the number of neighbors (and replacement is set to false).\n *  - When the value is a non-negative integer, it serves as a minimum threshold\n * for selecting neighbors.\n * @param replace Boolean indicating whether the sample is performed with or\n * without replacement. If True, a value can be selected multiple times.\n * Otherwise, each value can be selected only once.\n * @param options Tensor options specifying the desired data type of the result.\n * @param type_per_edge Tensor representing the type of each edge in the\n * original graph.\n * @param probs_or_mask Optional tensor containing the (unnormalized)\n * probabilities associated with each neighboring edge of a node in the original\n * graph. It must be a 1D floating-point tensor with the number of elements\n * equal to the number of edges in the graph.\n * @param picked_data_ptr The pointer of the tensor where the picked neighbors\n * should be put. 
Enough memory space should be allocated in advance.\n * @param seed_offset The offset(index) of the seed among the group of seeds\n * which share the same node type.\n * @param subgraph_indptr_ptr The pointer of the tensor which stores the indptr\n * of the sampled subgraph.\n * @param etype_id_to_num_picked_offset A vector storing the mappings from each\n * etype_id to the offset of its pick numbers in the tensor.\n */\ntemplate <SamplerType S, typename PickedType>\nint64_t PickByEtype(\n    bool with_seed_offsets, int64_t offset, int64_t num_neighbors,\n    const std::vector<int64_t>& fanouts, bool replace,\n    const torch::TensorOptions& options, const torch::Tensor& type_per_edge,\n    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,\n    PickedType* picked_data_ptr, int64_t seed_offset,\n    PickedType* subgraph_indptr_ptr,\n    const std::vector<int64_t>& etype_id_to_num_picked_offset);\n\ntemplate <typename PickedType>\nint64_t TemporalPickByEtype(\n    torch::Tensor seed_timestamp, torch::Tensor csc_indices,\n    int64_t seed_offset, int64_t offset, int64_t num_neighbors,\n    const std::vector<int64_t>& fanouts, bool replace,\n    const torch::TensorOptions& options, const torch::Tensor& type_per_edge,\n    const torch::optional<torch::Tensor>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp,\n    PickedType* picked_data_ptr);\n\ntemplate <\n    bool NonUniform, bool Replace, typename ProbsType, SamplerType S,\n    typename PickedType, int StackSize = 1024>\nstd::enable_if_t<is_labor(S), int64_t> LaborPick(\n    int64_t offset, int64_t num_neighbors, int64_t fanout,\n    const torch::TensorOptions& options,\n    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,\n    PickedType* picked_data_ptr);\n\n}  // namespace sampling\n}  // namespace graphbolt\n\n#endif  
// GRAPHBOLT_CSC_SAMPLING_GRAPH_H_\n"
  },
  {
    "path": "graphbolt/include/graphbolt/fused_sampled_subgraph.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file graphbolt/fused_sampled_subgraph.h\n * @brief Header file of sampled sub graph.\n */\n\n#ifndef GRAPHBOLT_FUSED_SAMPLED_SUBGRAPH_H_\n#define GRAPHBOLT_FUSED_SAMPLED_SUBGRAPH_H_\n\n#include <torch/custom_class.h>\n#include <torch/torch.h>\n\nnamespace graphbolt {\nnamespace sampling {\n\n/**\n * @brief Struct representing a sampled subgraph.\n *\n * Example usage:\n *\n * Suppose the subgraph has 3 nodes and 4 edges.\n * ```\n * auto indptr = torch::tensor({0, 2, 3, 4}, {torch::kInt64});\n * auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});\n * auto original_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});\n *\n * FusedSampledSubgraph sampledSubgraph(indptr, indices,\n * original_column_node_ids);\n * ```\n *\n * The `original_column_node_ids` indicates that nodes `[3, 3, 101]` in the\n * original graph are mapped to `[0, 1, 2]` in this subgraph, and because\n * `original_row_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just\n * the original node ids without compaction.\n *\n * If `original_row_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,\n * it would indicate a different mapping for the row nodes. 
Note this is\n * inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in\n * the column while `2` in the row.\n */\nstruct FusedSampledSubgraph : torch::CustomClassHolder {\n public:\n  /**\n   * @brief Constructor for the FusedSampledSubgraph struct.\n   *\n   * @param indptr CSC format index pointer array.\n   * @param indices CSC format index array.\n   * @param original_column_node_ids Row's reverse node ids in the original\n   * graph.\n   * @param original_row_node_ids Column's reverse node ids in the original\n   * graph.\n   * @param original_edge_ids Mapping of subgraph edge IDs to original\n   * FusedCSCSamplingGraph edge IDs.\n   * @param type_per_edge Type id of each edge.\n   * @param etype_offsets Edge offsets for the sampled edges for the sampled\n   * edges that are sorted w.r.t. edge types.\n   */\n  FusedSampledSubgraph(\n      torch::Tensor indptr, torch::optional<torch::Tensor> indices,\n      torch::Tensor original_edge_ids,\n      torch::optional<torch::Tensor> original_column_node_ids,\n      torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt,\n      torch::optional<torch::Tensor> type_per_edge = torch::nullopt,\n      torch::optional<torch::Tensor> etype_offsets = torch::nullopt)\n      : indptr(indptr),\n        indices(indices),\n        original_edge_ids(original_edge_ids),\n        original_column_node_ids(original_column_node_ids),\n        original_row_node_ids(original_row_node_ids),\n        type_per_edge(type_per_edge),\n        etype_offsets(etype_offsets) {}\n\n  FusedSampledSubgraph() = default;\n\n  /**\n   * @brief CSC format index pointer array, where the implicit node ids are\n   * already compacted. And the original ids are stored in the\n   * `original_column_node_ids` field. 
Its length is equal to:\n   * 1 + \\sum_{etype} #seeds with dst_node_type(etype)\n   */\n  torch::Tensor indptr;\n\n  /**\n   * @brief CSC format index array, where the node ids can be compacted ids or\n   * original ids. If compacted, the original ids are stored in the\n   * `original_row_node_ids` field. The indices are sorted w.r.t. their edge\n   * types for the heterogenous case.\n   *\n   * @note This is optional if its fetch operation will be performed later using\n   * the original_edge_ids tensor.\n   */\n  torch::optional<torch::Tensor> indices;\n\n  /**\n   * @brief Mapping of subgraph edge IDs to original FusedCSCSamplingGraph\n   * edge IDs.\n   *\n   * In this subgraph, the edge at index i corresponds to the edge with ID\n   * original_edge_ids[i] in the original FusedCSCSamplingGraph. Edges are\n   * sorted by type for heterogeneous graphs.\n   *\n   * Note: To retrieve the actual original edge IDs for feature fetching, use\n   * the `_ORIGINAL_EDGE_ID` edge attribute in FusedCSCSamplingGraph to map the\n   * `original_edge_ids` agin, as IDs may have been remapped during conversion\n   * to FusedCSCSamplingGraph.\n   */\n  torch::Tensor original_edge_ids;\n\n  /**\n   * @brief Column's reverse node ids in the original graph. A graph structure\n   * can be treated as a coordinated row and column pair, and this is the the\n   * mapped ids of the column.\n   *\n   * @note This is optional and the mapping relations can be inconsistent with\n   * column's. It can be missing when the sampling algorithm is called via a\n   * sliced sampled subgraph with missing seeds argument.\n   */\n  torch::optional<torch::Tensor> original_column_node_ids;\n\n  /**\n   * @brief Row's reverse node ids in the original graph. 
A graph structure\n   * can be treated as a coordinated row and column pair, and this is the\n   * mapped ids of the row.\n   *\n   * @note This is optional and the mapping relations can be inconsistent with\n   * column's.\n   */\n  torch::optional<torch::Tensor> original_row_node_ids;\n\n  /**\n   * @brief Type id of each edge, where type id is the corresponding index of\n   * edge types. The length of it is equal to the number of edges in the\n   * subgraph.\n   *\n   * @note This output is not created by the CUDA implementation as the edges\n   * are sorted w.r.t edge types, one has to use etype_offsets to infer the edge\n   * type information. This field is going to be deprecated. It can be generated\n   * when needed by computing gb.expand_indptr(etype_offsets).\n   */\n  torch::optional<torch::Tensor> type_per_edge;\n\n  /**\n   * @brief Offsets of each etype,\n   * type_per_edge[etype_offsets[i]: etype_offsets[i + 1]] == i\n   * It has length equal to (1 + #etype), and the edges are guaranteed to be\n   * sorted w.r.t. their edge types.\n   */\n  torch::optional<torch::Tensor> etype_offsets;\n};\n\n}  // namespace sampling\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_FUSED_SAMPLED_SUBGRAPH_H_\n"
  },
  {
    "path": "graphbolt/include/graphbolt/isin.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n *\n * @file graphbolt/isin.h\n * @brief isin op.\n */\n#ifndef GRAPHBOLT_ISIN_H_\n#define GRAPHBOLT_ISIN_H_\n\n#include <graphbolt/async.h>\n#include <torch/torch.h>\n\nnamespace graphbolt {\nnamespace sampling {\n\n/**\n * @brief Tests if each element of elements is in test_elements. Returns a\n * boolean tensor of the same shape as elements that is True for elements\n * in test_elements and False otherwise. Enhance torch.isin by implementing\n * multi-threaded searching, as detailed in the documentation at\n * https://pytorch.org/docs/stable/generated/torch.isin.html.\"\n *\n * @param elements        Input elements\n * @param test_elements   Values against which to test for each input element.\n *\n * @return\n * A boolean tensor of the same shape as elements that is True for elements\n * in test_elements and False otherwise.\n */\ntorch::Tensor IsIn(\n    const torch::Tensor& elements, const torch::Tensor& test_elements);\n\n/**\n * @brief Tests if each element of elements is not in test_elements. Returns an\n * int64_t tensor of the same shape as elements containing the indexes of the\n * elements not found in test_elements.\n *\n * @param elements        Input elements\n * @param test_elements   Values against which to test for each input element.\n *\n * @return An int64_t tensor of the same shape as elements containing indexes of\n * elements not found in test_elements.\n */\ntorch::Tensor IsNotInIndex(\n    const torch::Tensor& elements, const torch::Tensor& test_elements);\n\nc10::intrusive_ptr<Future<torch::Tensor>> IsNotInIndexAsync(\n    const torch::Tensor& elements, const torch::Tensor& test_elements);\n\n}  // namespace sampling\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_ISIN_H_\n"
  },
  {
    "path": "graphbolt/include/graphbolt/serialize.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file graphbolt/serialize.h\n * @brief Utility functions for serialize and deserialize.\n */\n\n#ifndef GRAPHBOLT_SERIALIZE_H_\n#define GRAPHBOLT_SERIALIZE_H_\n\n#include <graphbolt/fused_csc_sampling_graph.h>\n#include <torch/torch.h>\n\n#include <string>\n#include <vector>\n\n/**\n * @brief Overload stream operator to enable `torch::save()` and `torch.load()`\n * for FusedCSCSamplingGraph.\n */\nnamespace torch {\n\n/**\n * @brief Overload input stream operator for FusedCSCSamplingGraph\n * deserialization. This enables `torch::load()` for FusedCSCSamplingGraph.\n *\n * @param archive Input stream for deserializing.\n * @param graph FusedCSCSamplingGraph.\n *\n * @return archive\n *\n * @code\n * auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();\n * torch::load(*graph, filename);\n */\ninline serialize::InputArchive& operator>>(\n    serialize::InputArchive& archive,\n    graphbolt::sampling::FusedCSCSamplingGraph& graph);\n\n/**\n * @brief Overload output stream operator for FusedCSCSamplingGraph\n * serialization. 
This enables `torch::save()` for FusedCSCSamplingGraph.\n * @param archive Output stream for serializing.\n * @param graph FusedCSCSamplingGraph.\n *\n * @return archive\n *\n * @code\n * auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();\n * torch::save(*graph, filename);\n */\ninline serialize::OutputArchive& operator<<(\n    serialize::OutputArchive& archive,\n    const graphbolt::sampling::FusedCSCSamplingGraph& graph);\n\n}  // namespace torch\n\nnamespace graphbolt {\n\n/**\n * @brief Read data from archive and format to specified type.\n * @param archive Input archive.\n * @param key Key name of data.\n *\n * @return data.\n */\ntemplate <typename T>\nT read_from_archive(\n    torch::serialize::InputArchive& archive, const std::string& key) {\n  torch::IValue data;\n  archive.read(key, data);\n  return data.to<T>();\n}\n\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_SERIALIZE_H_\n"
  },
  {
    "path": "graphbolt/include/graphbolt/shared_memory.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n *\n * @file graphbolt/shared_memory.h\n * @brief Header file of graphbolt shared memory.\n */\n#ifndef GRAPHBOLT_SHARED_MEMORY_H_\n#define GRAPHBOLT_SHARED_MEMORY_H_\n\n#ifdef _WIN32\n#include <windows.h>\n#endif  // _WIN32\n\n#include <memory>\n#include <string>\n\nnamespace graphbolt {\nnamespace sampling {\n\n/**\n * @brief The SharedMemory is responsible for storing all the necessary\n * parameters of the buffer. Each SharedMemory instance is associated with a\n * shared memory object. The object will be removed when the associated\n * SharedMemory instance is destroyed.\n */\nclass SharedMemory {\n public:\n  /**\n   * @brief Constructor of the shared memory.\n   * @param name The name of the shared memory.\n   */\n  explicit SharedMemory(const std::string& name);\n\n  SharedMemory(const SharedMemory&) = delete;\n  SharedMemory& operator=(const SharedMemory&) = delete;\n\n  /**\n   * @brief The destructor is responsible for unmapping the shared memory and\n   * removing the associated shared memory object.\n   */\n  ~SharedMemory();\n\n  /** @brief Get the name of shared memory. */\n  std::string GetName() const { return name_; }\n\n  /** @brief Get the pointer to the shared memory. */\n  void* GetMemory() const { return ptr_; }\n\n  /** @brief Get the size of the shared memory. 
*/\n  size_t GetSize() const { return size_; }\n\n  /**\n   * @brief Creates the shared memory object and map the shared memory.\n   *\n   * @param size The size of the shared memory.\n   * @return The pointer to the shared memory.\n   */\n  void* Create(size_t size);\n\n  /**\n   * @brief Open the created shared memory object and map the shared memory.\n   *\n   */\n  void* Open();\n\n  /**\n   * @brief Check if the shared memory exists.\n   *\n   * @param name The name of the shared memory.\n   * @return True if the shared memory exists, otherwise False.\n   */\n  static bool Exists(const std::string& name);\n\n private:\n  /** @brief The name of the shared memory. */\n  std::string name_;\n\n  /** @brief The size of the shared memory. */\n  size_t size_;\n\n  /** @brief The pointer of the shared memory. */\n  void* ptr_;\n\n#ifdef _WIN32\n\n  /** @brief The handle of the shared memory object. */\n  HANDLE handle_;\n\n#else  // _WIN32\n\n  /** @brief The file descriptor of the shared memory object. */\n  int file_descriptor_;\n\n  /**\n   * @brief Whether the shared memory is created by the instance.\n   *\n   * The instance that creates the shared memory object is responsible for\n   * unlinking the shared memory object.\n   */\n  bool is_creator_;\n\n#endif  // _WIN32\n};\n\nusing SharedMemoryPtr = std::unique_ptr<SharedMemory>;\n\n}  // namespace sampling\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_SHARED_MEMORY_H_\n"
  },
  {
    "path": "graphbolt/include/graphbolt/unique_and_compact.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n *\n * @file unique_and_compact.h\n * @brief Unique and compact op.\n */\n#ifndef GRAPHBOLT_UNIQUE_AND_COMPACT_H_\n#define GRAPHBOLT_UNIQUE_AND_COMPACT_H_\n\n#include <graphbolt/async.h>\n#include <torch/torch.h>\n\nnamespace graphbolt {\nnamespace sampling {\n/**\n * @brief Removes duplicate elements from the concatenated 'unique_dst_ids' and\n * 'src_ids' tensor and applies the uniqueness information to compact both\n * source and destination tensors.\n *\n * The function performs two main operations:\n *   1. Unique Operation: 'unique(concat(unique_dst_ids, src_ids))', in which\n * the unique operator will guarantee the 'unique_dst_ids' are at the head of\n * the result tensor.\n *   2. Compact Operation: Utilizes the reverse mapping derived from the unique\n * operation to transform 'src_ids' and 'dst_ids' into compacted IDs.\n *\n * When world_size is greater than 1, then the given ids are partitioned between\n * the available ranks. The ids corresponding to the given rank are guaranteed\n * to come before the ids of other ranks. To do this, the partition ids are\n * rotated backwards by the given rank so that the ids are ordered as:\n * [rank, rank + 1, world_size, 0, ..., rank - 1]. This is supported only for\n * Volta and later generation NVIDIA GPUs.\n *\n * @param src_ids         A tensor containing source IDs.\n * @param dst_ids         A tensor containing destination IDs.\n * @param unique_dst_ids  A tensor containing unique destination IDs, which is\n *                        exactly all the unique elements in 'dst_ids'.\n * @param rank            The rank of the current GPU.\n * @param world_size      The total # GPUs, world size.\n *\n * @return (unique_ids, compacted_src_ids, compacted_dst_ids, unique_offsets)\n * - A tensor representing all unique elements in 'src_ids' and 'dst_ids' after\n * removing duplicates. 
The indices in this tensor precisely match the compacted\n * IDs of the corresponding elements.\n * - The tensor corresponding to the 'src_ids' tensor, where the entries are\n * mapped to compacted IDs.\n * - The tensor corresponding to the 'dst_ids' tensor, where the entries are\n * mapped to compacted IDs.\n * - The tensor corresponding to the offsets into the unique_ids tensor. Has\n * size `world_size + 1` and unique_ids[offsets[i]: offsets[i + 1]] belongs to\n * the rank `(rank + i) % world_size`.\n *\n * @example\n *   torch::Tensor src_ids = src\n *   torch::Tensor dst_ids = dst\n *   torch::Tensor unique_dst_ids = torch::unique(dst);\n *   auto result = UniqueAndCompact(src_ids, dst_ids, unique_dst_ids);\n *   torch::Tensor unique_ids = std::get<0>(result);\n *   torch::Tensor compacted_src_ids = std::get<1>(result);\n *   torch::Tensor compacted_dst_ids = std::get<2>(result);\n */\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nUniqueAndCompact(\n    const torch::Tensor& src_ids, const torch::Tensor& dst_ids,\n    const torch::Tensor unique_dst_ids, const int64_t rank,\n    const int64_t world_size);\n\nstd::vector<\n    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>\nUniqueAndCompactBatched(\n    const std::vector<torch::Tensor>& src_ids,\n    const std::vector<torch::Tensor>& dst_ids,\n    const std::vector<torch::Tensor> unique_dst_ids, const int64_t rank,\n    const int64_t world_size);\n\nc10::intrusive_ptr<Future<std::vector<\n    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>>\nUniqueAndCompactBatchedAsync(\n    const std::vector<torch::Tensor>& src_ids,\n    const std::vector<torch::Tensor>& dst_ids,\n    const std::vector<torch::Tensor> unique_dst_ids, const int64_t rank,\n    const int64_t world_size);\n\n}  // namespace sampling\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_UNIQUE_AND_COMPACT_H_\n"
  },
  {
    "path": "graphbolt/src/cache_policy.cc",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cache_policy.cc\n * @brief Cache policy implementation on the CPU.\n */\n#include \"./cache_policy.h\"\n\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace storage {\n\ntemplate <typename CachePolicy>\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nBaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {\n  auto positions = torch::empty_like(\n      keys, keys.options()\n                .dtype(torch::kInt64)\n                .pinned_memory(utils::is_pinned(keys)));\n  auto indices = torch::empty_like(\n      keys, keys.options()\n                .dtype(torch::kInt64)\n                .pinned_memory(utils::is_pinned(keys)));\n  auto found_ptr_tensor = torch::empty_like(\n      keys, keys.options()\n                .dtype(torch::kInt64)\n                .pinned_memory(utils::is_pinned(keys)));\n  auto missing_keys = torch::empty_like(\n      keys, keys.options().pinned_memory(utils::is_pinned(keys)));\n  int64_t found_cnt = 0;\n  int64_t missing_cnt = keys.size(0);\n  AT_DISPATCH_INDEX_TYPES(\n      keys.scalar_type(), \"BaseCachePolicy::Query::DispatchForKeys\", ([&] {\n        auto keys_ptr = keys.data_ptr<index_t>();\n        auto positions_ptr = positions.data_ptr<int64_t>();\n        auto 
indices_ptr = indices.data_ptr<int64_t>();\n        static_assert(\n            sizeof(CacheKey*) == sizeof(int64_t), \"You need 64 bit pointers.\");\n        auto found_ptr =\n            reinterpret_cast<CacheKey**>(found_ptr_tensor.data_ptr<int64_t>());\n        auto missing_keys_ptr = missing_keys.data_ptr<index_t>();\n        for (int64_t i = 0; i < keys.size(0); i++) {\n          const auto key = keys_ptr[i];\n          auto cache_key_ptr = policy.Read(key);\n          if (cache_key_ptr) {\n            positions_ptr[found_cnt] = cache_key_ptr->getPos();\n            found_ptr[found_cnt] = cache_key_ptr;\n            indices_ptr[found_cnt++] = i;\n          } else {\n            indices_ptr[--missing_cnt] = i;\n            missing_keys_ptr[missing_cnt] = key;\n          }\n        }\n      }));\n  return {\n      positions.slice(0, 0, found_cnt), indices,\n      missing_keys.slice(0, found_cnt),\n      found_ptr_tensor.slice(0, 0, found_cnt)};\n}\n\ntemplate <typename CachePolicy>\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nBaseCachePolicy::QueryAndReplaceImpl(CachePolicy& policy, torch::Tensor keys) {\n  auto positions = torch::empty_like(\n      keys, keys.options()\n                .dtype(torch::kInt64)\n                .pinned_memory(utils::is_pinned(keys)));\n  auto indices = torch::empty_like(\n      keys, keys.options()\n                .dtype(torch::kInt64)\n                .pinned_memory(utils::is_pinned(keys)));\n  auto pointers = torch::empty_like(keys, keys.options().dtype(torch::kInt64));\n  auto missing_keys = torch::empty_like(\n      keys, keys.options().pinned_memory(utils::is_pinned(keys)));\n  int64_t found_cnt = 0;\n  int64_t missing_cnt = keys.size(0);\n  AT_DISPATCH_INDEX_TYPES(\n      keys.scalar_type(), \"BaseCachePolicy::Replace\", ([&] {\n        auto keys_ptr = keys.data_ptr<index_t>();\n        auto positions_ptr = positions.data_ptr<int64_t>();\n        auto indices_ptr = indices.data_ptr<int64_t>();\n  
      static_assert(\n            sizeof(CacheKey*) == sizeof(int64_t), \"You need 64 bit pointers.\");\n        auto pointers_ptr =\n            reinterpret_cast<CacheKey**>(pointers.data_ptr<int64_t>());\n        auto missing_keys_ptr = missing_keys.data_ptr<index_t>();\n        set_t<int64_t> position_set;\n        position_set.reserve(keys.size(0));\n        // Query and Replace combined.\n        for (int64_t i = 0; i < keys.size(0); i++) {\n          const auto key = keys_ptr[i];\n          const auto [it, can_read] = policy.Emplace(key);\n          if (can_read) {\n            auto& cache_key = *it->second;\n            positions_ptr[found_cnt] = cache_key.getPos();\n            pointers_ptr[found_cnt] = &cache_key;\n            indices_ptr[found_cnt++] = i;\n          } else {\n            indices_ptr[--missing_cnt] = i;\n            missing_keys_ptr[missing_cnt] = key;\n            // Ensure that even if an offset is added, it stays negative.\n            auto position = std::numeric_limits<int64_t>::min();\n            CacheKey* cache_key_ptr = nullptr;\n            if (it->second == policy.getMapSentinelValue()) {\n              cache_key_ptr = policy.Insert(it);\n              position = cache_key_ptr->getPos();\n              TORCH_CHECK(\n                  // We check for the uniqueness of the positions.\n                  std::get<1>(position_set.insert(position)),\n                  \"Can't insert all, larger cache capacity is needed.\");\n            }\n            positions_ptr[missing_cnt] = position;\n            pointers_ptr[missing_cnt] = cache_key_ptr;\n          }\n        }\n      }));\n  return {positions, indices, pointers, missing_keys.slice(0, found_cnt)};\n}\n\ntemplate <typename CachePolicy>\nstd::tuple<torch::Tensor, torch::Tensor> BaseCachePolicy::ReplaceImpl(\n    CachePolicy& policy, torch::Tensor keys) {\n  auto positions = torch::empty_like(\n      keys, keys.options()\n                .dtype(torch::kInt64)\n                
.pinned_memory(utils::is_pinned(keys)));\n  auto pointers = torch::empty_like(\n      keys, keys.options()\n                .dtype(torch::kInt64)\n                .pinned_memory(utils::is_pinned(keys)));\n  AT_DISPATCH_INDEX_TYPES(\n      keys.scalar_type(), \"BaseCachePolicy::Replace\", ([&] {\n        auto keys_ptr = keys.data_ptr<index_t>();\n        auto positions_ptr = positions.data_ptr<int64_t>();\n        static_assert(\n            sizeof(CacheKey*) == sizeof(int64_t), \"You need 64 bit pointers.\");\n        auto pointers_ptr =\n            reinterpret_cast<CacheKey**>(pointers.data_ptr<int64_t>());\n        set_t<int64_t> position_set;\n        position_set.reserve(keys.size(0));\n        for (int64_t i = 0; i < keys.size(0); i++) {\n          const auto key = keys_ptr[i];\n          // Ensure that even if an offset is added, it stays negative.\n          auto position = std::numeric_limits<int64_t>::min();\n          CacheKey* cache_key_ptr = nullptr;\n          const auto [it, _] = policy.Emplace(key);\n          if (it->second == policy.getMapSentinelValue()) {\n            cache_key_ptr = policy.Insert(it);\n            position = cache_key_ptr->getPos();\n            TORCH_CHECK(\n                // We check for the uniqueness of the positions.\n                std::get<1>(position_set.insert(position)),\n                \"Can't insert all, larger cache capacity is needed.\");\n          }\n          positions_ptr[i] = position;\n          pointers_ptr[i] = cache_key_ptr;\n        }\n      }));\n  return {positions, pointers};\n}\n\ntemplate <bool write>\nvoid BaseCachePolicy::ReadingWritingCompletedImpl(torch::Tensor pointers) {\n  static_assert(\n      sizeof(CacheKey*) == sizeof(int64_t), \"You need 64 bit pointers.\");\n  auto pointers_ptr =\n      reinterpret_cast<CacheKey**>(pointers.data_ptr<int64_t>());\n  for (int64_t i = 0; i < pointers.size(0); i++) {\n    const auto pointer = pointers_ptr[i];\n    if (!write || pointer) {\n      
pointer->EndUse<write>();\n    }\n  }\n}\n\nvoid BaseCachePolicy::ReadingCompleted(torch::Tensor pointers) {\n  ReadingWritingCompletedImpl<false>(pointers);\n}\n\nvoid BaseCachePolicy::WritingCompleted(torch::Tensor pointers) {\n  ReadingWritingCompletedImpl<true>(pointers);\n}\n\nS3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity)\n    : BaseCachePolicy(capacity),\n      ghost_queue_(capacity - capacity / 10),\n      small_queue_size_target_(capacity / 10),\n      small_queue_size_(0) {\n  TORCH_CHECK(small_queue_size_target_ > 0, \"Capacity is not large enough.\");\n  ghost_set_.reserve(ghost_queue_.Capacity());\n  key_to_cache_key_.reserve(kCapacityFactor * (capacity + 1));\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nS3FifoCachePolicy::Query(torch::Tensor keys) {\n  return QueryImpl(*this, keys);\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nS3FifoCachePolicy::QueryAndReplace(torch::Tensor keys) {\n  return QueryAndReplaceImpl(*this, keys);\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> S3FifoCachePolicy::Replace(\n    torch::Tensor keys) {\n  return ReplaceImpl(*this, keys);\n}\n\nSieveCachePolicy::SieveCachePolicy(int64_t capacity)\n    // Ensure that queue_ is constructed first before accessing its `.end()`.\n    : BaseCachePolicy(capacity), queue_(), hand_(queue_.end()) {\n  TORCH_CHECK(capacity > 0, \"Capacity needs to be positive.\");\n  key_to_cache_key_.reserve(kCapacityFactor * (capacity + 1));\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nSieveCachePolicy::Query(torch::Tensor keys) {\n  return QueryImpl(*this, keys);\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nSieveCachePolicy::QueryAndReplace(torch::Tensor keys) {\n  return QueryAndReplaceImpl(*this, keys);\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> SieveCachePolicy::Replace(\n    torch::Tensor keys) {\n  return ReplaceImpl(*this, 
keys);\n}\n\nLruCachePolicy::LruCachePolicy(int64_t capacity) : BaseCachePolicy(capacity) {\n  TORCH_CHECK(capacity > 0, \"Capacity needs to be positive.\");\n  key_to_cache_key_.reserve(kCapacityFactor * (capacity + 1));\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nLruCachePolicy::Query(torch::Tensor keys) {\n  return QueryImpl(*this, keys);\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nLruCachePolicy::QueryAndReplace(torch::Tensor keys) {\n  return QueryAndReplaceImpl(*this, keys);\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> LruCachePolicy::Replace(\n    torch::Tensor keys) {\n  return ReplaceImpl(*this, keys);\n}\n\nClockCachePolicy::ClockCachePolicy(int64_t capacity)\n    : BaseCachePolicy(capacity) {\n  TORCH_CHECK(capacity > 0, \"Capacity needs to be positive.\");\n  key_to_cache_key_.reserve(kCapacityFactor * (capacity + 1));\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nClockCachePolicy::Query(torch::Tensor keys) {\n  return QueryImpl(*this, keys);\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nClockCachePolicy::QueryAndReplace(torch::Tensor keys) {\n  return QueryAndReplaceImpl(*this, keys);\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> ClockCachePolicy::Replace(\n    torch::Tensor keys) {\n  return ReplaceImpl(*this, keys);\n}\n\n}  // namespace storage\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cache_policy.h",
    "content": "/**\n *   Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cache_policy.h\n * @brief Cache policy implementation on the CPU.\n */\n#ifndef GRAPHBOLT_CACHE_POLICY_H_\n#define GRAPHBOLT_CACHE_POLICY_H_\n\n#include <torch/custom_class.h>\n#include <torch/torch.h>\n#include <tsl/robin_map.h>\n#include <tsl/robin_set.h>\n\n#include <cuda/std/atomic>\n#include <limits>\n\n#include \"./circular_queue.h\"\n\nnamespace graphbolt {\nnamespace storage {\n\nstruct CacheKey {\n  auto getKey() const {\n    return (static_cast<int64_t>(key_higher_16_bits_) << 32) +\n           key_lower_32_bits_;\n  }\n\n  CacheKey(int64_t key) : CacheKey(key, std::numeric_limits<int64_t>::min()) {}\n\n  CacheKey(int64_t key, int64_t position)\n      : freq_(0),\n        // EndUse<true>() should be called to reset the reference count.\n        reference_count_(-1),\n        key_higher_16_bits_(key >> 32),\n        key_lower_32_bits_(key),\n        position_in_cache_(position) {\n    TORCH_CHECK(key == getKey());\n    static_assert(sizeof(CacheKey) == 2 * sizeof(int64_t));\n  }\n\n  CacheKey() = default;\n\n  auto getFreq() const { return freq_; }\n\n  auto getPos() const { return position_in_cache_; }\n\n  CacheKey& setPos(int64_t pos) {\n    position_in_cache_ = pos;\n    return *this;\n  }\n\n  CacheKey& Increment() {\n    freq_ = 
std::min(3, static_cast<int>(freq_ + 1));\n    return *this;\n  }\n\n  CacheKey& Decrement() {\n    freq_ = std::max(0, static_cast<int>(freq_ - 1));\n    return *this;\n  }\n\n  CacheKey& SetFreq() {\n    freq_ = 1;\n    return *this;\n  }\n\n  CacheKey& ResetFreq() {\n    freq_ = 0;\n    return *this;\n  }\n\n  CacheKey& StartRead() {\n    ::cuda::std::atomic_ref ref(reference_count_);\n    // StartRead runs concurrently only with EndUse. EndUse does not need to see\n    // this modification at all. So we can use the relaxed memory order.\n    const auto old_val = ref.fetch_add(1, ::cuda::std::memory_order_relaxed);\n    TORCH_CHECK(\n        old_val < std::numeric_limits<int8_t>::max(),\n        \"There are too many in-flight read requests to the same cache entry!\");\n    return *this;\n  }\n\n  template <bool write>\n  CacheKey& EndUse() {\n    ::cuda::std::atomic_ref ref(reference_count_);\n    // The EndUse operation needs to synchronize with the InUse operation. So we\n    // have an release-acquire ordering between the two.\n    // https://en.cppreference.com/w/cpp/atomic/memory_order#Release-Acquire_ordering\n    if constexpr (write) {\n      ref.fetch_add(1, ::cuda::std::memory_order_release);\n    } else {\n      ref.fetch_add(-1, ::cuda::std::memory_order_release);\n    }\n    return *this;\n  }\n\n  bool InUse() const {\n    ::cuda::std::atomic_ref ref(reference_count_);\n    // The operations after a call to this function need to happen after the\n    // load operation. Hence the acquire order.\n    return ref.load(::cuda::std::memory_order_acquire);\n  }\n\n  bool BeingWritten() const {\n    ::cuda::std::atomic_ref ref(reference_count_);\n    // The only operation coming after this op is the StartRead operation. 
Since\n    // StartRead is a refcount increment operation, it is fine if we don't\n    // synchronize with EndUse ops.\n    return ref.load(::cuda::std::memory_order_relaxed) < 0;\n  }\n\n  friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) {\n    ::cuda::std::atomic_ref ref(key_ref.reference_count_);\n    return os << '(' << key_ref.getKey() << \", \" << key_ref.freq_ << \", \"\n              << key_ref.position_in_cache_ << \", \" << ref.load() << \")\";\n  }\n\n private:\n  int8_t freq_;\n  // Negative values indicate writing while positive values indicate reading.\n  // Access only through an std::atomic_ref instance atomically.\n  int8_t reference_count_;\n  // Keys are restricted to be 48-bit unsigned integers.\n  uint16_t key_higher_16_bits_;\n  uint32_t key_lower_32_bits_;\n  int64_t position_in_cache_;\n};\n\nclass BaseCachePolicy {\n public:\n  BaseCachePolicy(int64_t capacity) : capacity_(capacity), cache_usage_(0) {}\n\n  BaseCachePolicy() = default;\n\n  /**\n   * @brief A virtual base class constructor ensures that the derived class\n   * destructor gets called.\n   */\n  virtual ~BaseCachePolicy() = default;\n\n  /**\n   * @brief The policy query function.\n   * @param keys The keys to query the cache.\n   *\n   * @return (positions, indices, missing_keys, found_ptrs), where positions has\n   * the locations of the keys which were found in the cache, missing_keys has\n   * the keys that were not found and indices is defined such that\n   * keys[indices[:positions.size(0)]] gives us the keys for the found pointers\n   * and keys[indices[positions.size(0):]] is identical to missing_keys.\n   */\n  virtual std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\n  Query(torch::Tensor keys) = 0;\n\n  /**\n   * @brief The policy query function.\n   * @param keys The keys to query the cache.\n   *\n   * @return (positions, indices, pointers, missing_keys), where positions has\n   * the locations of the keys which were 
emplaced into the cache, pointers\n   * point to the emplaced CacheKey pointers in the cache, missing_keys has the\n   * keys that were not found and just inserted and indices is defined such that\n   * keys[indices[:keys.size(0) - missing_keys.size(0)]] gives us the keys for\n   * the found keys and keys[indices[keys.size(0) - missing_keys.size(0):]] is\n   * identical to missing_keys.\n   */\n  virtual std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\n  QueryAndReplace(torch::Tensor keys) = 0;\n\n  /**\n   * @brief The policy replace function.\n   * @param keys The keys to query the cache.\n   *\n   * @return (positions, pointers), where positions has the locations of the\n   * replaced entries and pointers point to their CacheKey pointers in the\n   * cache.\n   */\n  virtual std::tuple<torch::Tensor, torch::Tensor> Replace(\n      torch::Tensor keys) = 0;\n\n  /**\n   * @brief A reader has finished reading these keys, so they can be evicted.\n   * @param pointers The CacheKey pointers in the cache to unmark.\n   */\n  static void ReadingCompleted(torch::Tensor pointers);\n\n  /**\n   * @brief A writer has finished writing these keys, so they can be evicted.\n   * @param pointers The CacheKey pointers in the cache to unmark.\n   */\n  static void WritingCompleted(torch::Tensor pointers);\n\n protected:\n  template <typename K, typename V>\n  using map_t = tsl::robin_map<K, V>;\n  template <typename K>\n  using set_t = tsl::robin_set<K>;\n  template <typename iterator>\n  static auto& mutable_value_ref(iterator it) {\n    return it.value();\n  }\n  static constexpr int kCapacityFactor = 2;\n\n  template <typename CachePolicy>\n  static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\n  QueryImpl(CachePolicy& policy, torch::Tensor keys);\n\n  template <typename CachePolicy>\n  static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\n  QueryAndReplaceImpl(CachePolicy& policy, torch::Tensor 
keys);\n\n  template <typename CachePolicy>\n  static std::tuple<torch::Tensor, torch::Tensor> ReplaceImpl(\n      CachePolicy& policy, torch::Tensor keys);\n\n  template <typename T>\n  static void MoveToFront(\n      std::list<T>& from, std::list<T>& to,\n      typename std::list<T>::iterator it) {\n    std::list<T> temp;\n    // Transfer the element to temp to keep references valid.\n    auto next_it = it;\n    std::advance(next_it, 1);\n    temp.splice(temp.begin(), from, it, next_it);\n    // Move the element to the beginning of the queue.\n    to.splice(to.begin(), temp);\n    // The iterators and references are not invalidated.\n    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(it == to.begin());\n  }\n\n  int64_t capacity_;\n  int64_t cache_usage_;\n\n private:\n  template <bool write>\n  static void ReadingWritingCompletedImpl(torch::Tensor pointers);\n};\n\n/**\n * @brief S3FIFO is a simple, scalable FIFObased algorithm with three static\n * queues (S3-FIFO). https://dl.acm.org/doi/pdf/10.1145/3600006.3613147\n **/\nclass S3FifoCachePolicy : public BaseCachePolicy {\n public:\n  using map_iterator = map_t<int64_t, CacheKey*>::iterator;\n  /**\n   * @brief Constructor for the S3FifoCachePolicy class.\n   *\n   * @param capacity The capacity of the cache in terms of # elements.\n   */\n  S3FifoCachePolicy(int64_t capacity);\n\n  S3FifoCachePolicy() = default;\n\n  S3FifoCachePolicy(S3FifoCachePolicy&&) = default;\n\n  virtual ~S3FifoCachePolicy() = default;\n\n  /**\n   * @brief See BaseCachePolicy::Query.\n   */\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(\n      torch::Tensor keys);\n\n  /**\n   * @brief See BaseCachePolicy::QueryAndReplace.\n   */\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\n  QueryAndReplace(torch::Tensor keys);\n\n  /**\n   * @brief See BaseCachePolicy::Replace.\n   */\n  std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);\n\n  CacheKey* Read(int64_t key) {\n    
auto it = key_to_cache_key_.find(key);\n    if (it != key_to_cache_key_.end()) {\n      auto& cache_key = it->second->Increment();\n      if (!cache_key.BeingWritten()) {\n        return &cache_key.StartRead();\n      }\n    }\n    return nullptr;\n  }\n\n  auto getMapSentinelValue() const { return nullptr; }\n\n  std::pair<map_iterator, bool> Emplace(int64_t key) {\n    auto [it, inserted] = key_to_cache_key_.emplace(key, getMapSentinelValue());\n    bool readable = false;\n    if (!inserted) {\n      auto& cache_key = it->second->Increment();\n      if (!cache_key.BeingWritten()) {\n        cache_key.StartRead();\n        readable = true;\n      }\n    }\n    return {it, readable};\n  }\n\n  CacheKey* Insert(map_iterator it) {\n    const auto key = it->first;\n    const auto in_ghost_queue = ghost_set_.erase(key);\n    auto& queue = in_ghost_queue ? main_queue_ : small_queue_;\n    queue.push_front(CacheKey(key));\n    small_queue_size_ += 1 - in_ghost_queue;\n    auto cache_key_ptr = &queue.front();\n    mutable_value_ref(it) = cache_key_ptr;\n    return &cache_key_ptr->setPos(Evict());\n  }\n\n private:\n  int64_t EvictMainQueue() {\n    while (true) {\n      auto& evicted = main_queue_.back();\n      if (evicted.getFreq() > 0 || evicted.InUse()) {\n        evicted.Decrement();\n        auto it = main_queue_.end();\n        std::advance(it, -1);\n        MoveToFront(main_queue_, main_queue_, it);\n      } else {\n        key_to_cache_key_.erase(evicted.getKey());\n        const auto evicted_pos = evicted.getPos();\n        main_queue_.pop_back();\n        return evicted_pos;\n      }\n    }\n  }\n\n  int64_t EvictSmallQueue() {\n    while (small_queue_size_ > small_queue_size_target_) {\n      --small_queue_size_;\n      auto& evicted = small_queue_.back();\n      if (evicted.getFreq() > 0 || evicted.InUse()) {\n        evicted.ResetFreq();\n        auto it = small_queue_.end();\n        std::advance(it, -1);\n        MoveToFront(small_queue_, main_queue_, 
it);\n      } else {\n        const auto evicted_key = evicted.getKey();\n        key_to_cache_key_.erase(evicted_key);\n        const auto evicted_pos = evicted.getPos();\n        small_queue_.pop_back();\n        if (ghost_queue_.IsFull()) {\n          ghost_set_.erase(ghost_queue_.Pop());\n        }\n        ghost_set_.insert(evicted_key);\n        ghost_queue_.Push(evicted_key);\n        return evicted_pos;\n      }\n    }\n    return -1;\n  }\n\n  int64_t Evict() {\n    // If the cache has space, get an unused slot otherwise perform eviction.\n    if (cache_usage_ < capacity_) return cache_usage_++;\n    const auto pos = EvictSmallQueue();\n    return pos >= 0 ? pos : EvictMainQueue();\n  }\n\n  std::list<CacheKey> small_queue_, main_queue_;\n  CircularQueue<int64_t> ghost_queue_;\n  size_t small_queue_size_target_;\n  // std::list<>::size() is O(N) before the CXX11 ABI which torch enforces.\n  size_t small_queue_size_;\n  set_t<int64_t> ghost_set_;\n  map_t<int64_t, CacheKey*> key_to_cache_key_;\n};\n\n/**\n * @brief SIEVE is a simple, scalable FIFO-based algorithm with a single static\n * queue. 
https://www.usenix.org/system/files/nsdi24-zhang-yazhuo.pdf\n **/\nclass SieveCachePolicy : public BaseCachePolicy {\n public:\n  using map_iterator = map_t<int64_t, CacheKey*>::iterator;\n  /**\n   * @brief Constructor for the SieveCachePolicy class.\n   *\n   * @param capacity The capacity of the cache in terms of # elements.\n   */\n  SieveCachePolicy(int64_t capacity);\n\n  SieveCachePolicy() = default;\n\n  virtual ~SieveCachePolicy() = default;\n\n  /**\n   * @brief See BaseCachePolicy::Query.\n   */\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(\n      torch::Tensor keys);\n\n  /**\n   * @brief See BaseCachePolicy::QueryAndReplace.\n   */\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\n  QueryAndReplace(torch::Tensor keys);\n\n  /**\n   * @brief See BaseCachePolicy::Replace.\n   */\n  std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);\n\n  CacheKey* Read(int64_t key) {\n    auto it = key_to_cache_key_.find(key);\n    if (it != key_to_cache_key_.end()) {\n      auto& cache_key = it->second->SetFreq();\n      if (!cache_key.BeingWritten()) {\n        return &cache_key.StartRead();\n      }\n    }\n    return nullptr;\n  }\n\n  auto getMapSentinelValue() const { return nullptr; }\n\n  std::pair<map_iterator, bool> Emplace(int64_t key) {\n    auto [it, inserted] = key_to_cache_key_.emplace(key, getMapSentinelValue());\n    bool readable = false;\n    if (!inserted) {\n      auto& cache_key = it->second->SetFreq();\n      if (!cache_key.BeingWritten()) {\n        cache_key.StartRead();\n        readable = true;\n      }\n    }\n    return {it, readable};\n  }\n\n  CacheKey* Insert(map_iterator it) {\n    const auto key = it->first;\n    queue_.push_front(CacheKey(key));\n    auto cache_key_ptr = &queue_.front();\n    mutable_value_ref(it) = cache_key_ptr;\n    return &cache_key_ptr->setPos(Evict());\n  }\n\n private:\n  int64_t Evict() {\n    // If the cache has space, get an unused 
slot otherwise perform eviction.\n    if (cache_usage_ < capacity_) return cache_usage_++;\n    --hand_;\n    while (hand_->getFreq() || hand_->InUse()) {\n      hand_->ResetFreq();\n      if (hand_ == queue_.begin()) hand_ = queue_.end();\n      --hand_;\n    }\n    key_to_cache_key_.erase(hand_->getKey());\n    const auto pos = hand_->getPos();\n    const auto temp = hand_;\n    if (hand_ == queue_.begin()) {\n      hand_ = queue_.end();\n    } else {\n      ++hand_;\n    }\n    queue_.erase(temp);\n    return pos;\n  }\n\n  std::list<CacheKey> queue_;\n  decltype(queue_)::iterator hand_;\n  map_t<int64_t, CacheKey*> key_to_cache_key_;\n};\n\n/**\n * @brief LeastRecentlyUsed is a simple, scalable FIFO-based algorithm with a\n * single static queue.\n **/\nclass LruCachePolicy : public BaseCachePolicy {\n public:\n  using map_iterator = map_t<int64_t, std::list<CacheKey>::iterator>::iterator;\n  /**\n   * @brief Constructor for the LruCachePolicy class.\n   *\n   * @param capacity The capacity of the cache in terms of # elements.\n   */\n  LruCachePolicy(int64_t capacity);\n\n  LruCachePolicy() = default;\n\n  virtual ~LruCachePolicy() = default;\n\n  /**\n   * @brief See BaseCachePolicy::Query.\n   */\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(\n      torch::Tensor keys);\n\n  /**\n   * @brief See BaseCachePolicy::QueryAndReplace.\n   */\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\n  QueryAndReplace(torch::Tensor keys);\n\n  /**\n   * @brief See BaseCachePolicy::Replace.\n   */\n  std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);\n\n  CacheKey* Read(int64_t key) {\n    auto it = key_to_cache_key_.find(key);\n    if (it != key_to_cache_key_.end()) {\n      auto& cache_key = *it->second;\n      MoveToFront(queue_, queue_, it->second);\n      if (!cache_key.BeingWritten()) {\n        return &cache_key.StartRead();\n      }\n    }\n    return nullptr;\n  }\n\n  auto 
getMapSentinelValue() { return queue_.end(); }\n\n  std::pair<map_iterator, bool> Emplace(int64_t key) {\n    auto [it, inserted] = key_to_cache_key_.emplace(key, getMapSentinelValue());\n    bool readable = false;\n    if (!inserted) {\n      auto& cache_key = *it->second;\n      MoveToFront(queue_, queue_, it->second);\n      if (!cache_key.BeingWritten()) {\n        cache_key.StartRead();\n        readable = true;\n      }\n    }\n    return {it, readable};\n  }\n\n  CacheKey* Insert(map_iterator it) {\n    const auto key = it->first;\n    queue_.push_front(CacheKey(key));\n    mutable_value_ref(it) = queue_.begin();\n    auto cache_key_ptr = &queue_.front();\n    return &cache_key_ptr->setPos(Evict());\n  }\n\n private:\n  int64_t Evict() {\n    // If the cache has space, get an unused slot otherwise perform eviction.\n    if (cache_usage_ < capacity_) return cache_usage_++;\n    // Do not evict items that are still in use.\n    while (queue_.back().InUse()) {\n      auto it = queue_.end();\n      std::advance(it, -1);\n      // Move the last element to the front without invalidating references.\n      MoveToFront(queue_, queue_, it);\n    }\n    const auto& cache_key = queue_.back();\n    key_to_cache_key_.erase(cache_key.getKey());\n    const auto pos = cache_key.getPos();\n    queue_.pop_back();\n    return pos;\n  }\n\n  std::list<CacheKey> queue_;\n  map_t<int64_t, decltype(queue_)::iterator> key_to_cache_key_;\n};\n\n/**\n * @brief Clock (FIFO-Reinsertion) is a simple, scalable FIFO-based algorithm\n * with a single static queue.\n * https://people.csail.mit.edu/saltzer/Multics/MHP-Saltzer-060508/bookcases/M00s/M0104%20074-12).PDF\n **/\nclass ClockCachePolicy : public BaseCachePolicy {\n public:\n  using map_iterator = map_t<int64_t, CacheKey*>::iterator;\n  /**\n   * @brief Constructor for the ClockCachePolicy class.\n   *\n   * @param capacity The capacity of the cache in terms of # elements.\n   */\n  ClockCachePolicy(int64_t capacity);\n\n  
ClockCachePolicy() = default;\n\n  ClockCachePolicy(ClockCachePolicy&&) = default;\n\n  virtual ~ClockCachePolicy() = default;\n\n  /**\n   * @brief See BaseCachePolicy::Query.\n   */\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(\n      torch::Tensor keys);\n\n  /**\n   * @brief See BaseCachePolicy::QueryAndReplace.\n   */\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\n  QueryAndReplace(torch::Tensor keys);\n\n  /**\n   * @brief See BaseCachePolicy::Replace.\n   */\n  std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);\n\n  CacheKey* Read(int64_t key) {\n    auto it = key_to_cache_key_.find(key);\n    if (it != key_to_cache_key_.end()) {\n      auto& cache_key = it->second->SetFreq();\n      if (!cache_key.BeingWritten()) {\n        return &cache_key.StartRead();\n      }\n    }\n    return nullptr;\n  }\n\n  auto getMapSentinelValue() const { return nullptr; }\n\n  std::pair<map_iterator, bool> Emplace(int64_t key) {\n    auto [it, inserted] = key_to_cache_key_.emplace(key, getMapSentinelValue());\n    bool readable = false;\n    if (!inserted) {\n      auto& cache_key = it->second->SetFreq();\n      if (!cache_key.BeingWritten()) {\n        cache_key.StartRead();\n        readable = true;\n      }\n    }\n    return {it, readable};\n  }\n\n  CacheKey* Insert(map_iterator it) {\n    const auto key = it->first;\n    queue_.push_front(CacheKey(key));\n    auto cache_key_ptr = &queue_.front();\n    mutable_value_ref(it) = cache_key_ptr;\n    return &cache_key_ptr->setPos(Evict());\n  }\n\n private:\n  int64_t Evict() {\n    // If the cache has space, get an unused slot otherwise perform eviction.\n    if (cache_usage_ < capacity_) return cache_usage_++;\n    while (true) {\n      auto& cache_key = queue_.back();\n      if (cache_key.getFreq() || cache_key.InUse()) {\n        cache_key.ResetFreq();\n        auto it = queue_.end();\n        std::advance(it, -1);\n        
MoveToFront(queue_, queue_, it);\n      } else {\n        key_to_cache_key_.erase(cache_key.getKey());\n        const auto evicted_pos = cache_key.getPos();\n        queue_.pop_back();\n        return evicted_pos;\n      }\n    }\n  }\n\n  std::list<CacheKey> queue_;\n  map_t<int64_t, CacheKey*> key_to_cache_key_;\n};\n\n}  // namespace storage\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_CACHE_POLICY_H_\n"
  },
  {
    "path": "graphbolt/src/circular_queue.h",
    "content": "/**\n *   Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file circular_queue.h\n * @brief Circular queue implementation.\n */\n#ifndef GRAPHBOLT_CIRCULAR_QUEUE_H_\n#define GRAPHBOLT_CIRCULAR_QUEUE_H_\n\n#include <memory>\n\nnamespace graphbolt {\n\ntemplate <typename T>\nstruct CircularQueue {\n  CircularQueue(const int64_t capacity)\n      : tail_(0),\n        head_(0),\n        // + 1 is needed to be able to differentiate empty and full states.\n        capacity_(capacity + 1),\n        data_{new T[capacity + 1]} {}\n\n  CircularQueue() = default;\n\n  T* Push(const T& x) {\n    auto insert_ptr = &data_[PostIncrement(tail_)];\n    *insert_ptr = x;\n    return insert_ptr;\n  }\n\n  T Pop() { return data_[PostIncrement(head_)]; }\n\n  void PopN(int64_t N) {\n    head_ += N;\n    if (head_ >= capacity_) head_ -= capacity_;\n  }\n\n  auto Clear() { head_ = tail_; }\n\n  T& Front() const { return data_[head_]; }\n\n  bool IsFull() const {\n    const auto diff = tail_ + 1 - head_;\n    return diff == 0 || diff == capacity_;\n  }\n\n  auto Size() const {\n    auto diff = tail_ - head_;\n    if (diff < 0) diff += capacity_;\n    return diff;\n  }\n\n  friend std::ostream& operator<<(\n      std::ostream& os, const CircularQueue& queue) {\n    for (auto i = queue.head_; i != queue.tail_; queue.PostIncrement(i)) 
{\n      os << queue.data_[i] << \", \";\n    }\n    return os << \"\\n\";\n  }\n\n  bool IsEmpty() const { return tail_ == head_; }\n\n  auto Capacity() const { return capacity_ - 1; }\n\n private:\n  int64_t PostIncrement(int64_t& i) const {\n    const auto ret = i++;\n    if (i >= capacity_) i -= capacity_;\n    return ret;\n  }\n\n  int64_t tail_;\n  int64_t head_;\n  int64_t capacity_;\n  std::unique_ptr<T[]> data_;\n};\n\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_CIRCULAR_QUEUE_H_\n"
  },
  {
    "path": "graphbolt/src/cnumpy.cc",
    "content": "/**\n *  Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *  Copyright (c) 2023 by Contributors\n * @file cnumpy.cc\n * @brief Numpy File Fetcher class.\n */\n\n#include \"./cnumpy.h\"\n\n#include \"./io_uring.h\"\n\n#ifdef HAVE_LIBRARY_LIBURING\n#include <fcntl.h>\n#include <sys/stat.h>\n#include <unistd.h>\n#endif\n\n#include <graphbolt/async.h>\n#include <torch/torch.h>\n\n#include <atomic>\n#include <cstring>\n#include <fstream>\n#include <memory>\n#include <numeric>\n#include <stdexcept>\n#include <vector>\n\n#include \"./circular_queue.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace storage {\n\nOnDiskNpyArray::OnDiskNpyArray(\n    std::string filename, torch::ScalarType dtype,\n    const std::vector<int64_t> &shape, torch::optional<int64_t> num_threads)\n    : filename_(filename),\n      feature_dim_(shape),\n      dtype_(dtype),\n      feature_size_(std::accumulate(\n          shape.begin() + 1, shape.end(), c10::elementSize(dtype),\n          std::multiplies<int64_t>())) {\n#ifndef __linux__\n  throw std::runtime_error(\n      \"OnDiskNpyArray is not supported on non-Linux systems.\");\n#endif\n#ifdef HAVE_LIBRARY_LIBURING\n  ParseNumpyHeader();\n  file_description_ = ::open(filename.c_str(), O_RDONLY | O_DIRECT);\n  if (file_description_ < 0) {\n    throw std::runtime_error(\"npy_load: Unable to open file \" + filename);\n  }\n  struct stat st;\n  TORCH_CHECK(::fstat(file_description_, &st) == 0);\n  const auto file_size = st.st_size;\n  block_size_ = st.st_blksize;\n  TORCH_CHECK(file_size - prefix_len_ >= feature_dim_[0] * feature_size_);\n\n  // The minimum page size to contain one feature.\n  aligned_length_ = (feature_size_ + block_size_ - 1) & ~(block_size_ - 1);\n\n  std::call_once(call_once_flag_, [&] {\n    // Get system max interop thread count.\n    num_queues_ =\n        io_uring::num_threads.value_or(torch::get_num_interop_threads());\n    TORCH_CHECK(num_queues_ > 0, \"A positive # 
queues is required.\");\n    io_uring_queue_ = std::unique_ptr<::io_uring[], io_uring_queue_destroyer>(\n        new ::io_uring[num_queues_], io_uring_queue_destroyer{num_queues_});\n    TORCH_CHECK(num_queues_ <= counting_semaphore_t::max());\n    semaphore_.release(num_queues_);\n    available_queues_.reserve(num_queues_);\n    // Init io_uring queue.\n    for (int64_t t = 0; t < num_queues_; t++) {\n      available_queues_.push_back(t);\n      TORCH_CHECK(\n          ::io_uring_queue_init(2 * kGroupSize, &io_uring_queue_[t], 0) == 0);\n      // We have allocated 2 * kGroupSize submission queue entries and\n      // 4 * kGroupSize completion queue entries after this call.\n    }\n  });\n\n  num_thread_ = std::min(\n      static_cast<int64_t>(num_queues_), num_threads.value_or(num_queues_));\n  TORCH_CHECK(num_thread_ > 0, \"A positive # threads is required.\");\n\n  // We allocate buffers for each existing queue because we might get assigned\n  // any queue in range [0, num_queues_).\n  read_tensor_ = torch::empty(\n      ReadBufferSizePerThread() * num_queues_ + block_size_ - 1,\n      torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU));\n#else\n  throw std::runtime_error(\"DiskBasedFeature is not available now.\");\n#endif  // HAVE_LIBRARY_LIBURING\n}\n\nc10::intrusive_ptr<OnDiskNpyArray> OnDiskNpyArray::Create(\n    std::string path, torch::ScalarType dtype,\n    const std::vector<int64_t> &shape, torch::optional<int64_t> num_threads) {\n  return c10::make_intrusive<OnDiskNpyArray>(path, dtype, shape, num_threads);\n}\n\nOnDiskNpyArray::~OnDiskNpyArray() {\n#ifdef HAVE_LIBRARY_LIBURING\n  TORCH_CHECK(::close(file_description_) == 0);\n#endif  // HAVE_LIBRARY_LIBURING\n}\n\nvoid OnDiskNpyArray::ParseNumpyHeader() {\n  // Parse numpy file header to get basic info of feature.\n  // Get file prefix length.\n  std::ifstream file(filename_);\n  if (!file.is_open()) {\n    throw std::runtime_error(\n        \"ParseNumpyHeader: Unable to open file \" + 
filename_);\n  }\n  std::string header;\n  std::getline(file, header);\n  // Get prefix length for computing feature offset,\n  // add one for new-line character.\n  prefix_len_ = header.size() + 1;\n}\n\nc10::intrusive_ptr<Future<torch::Tensor>> OnDiskNpyArray::IndexSelect(\n    torch::Tensor index) {\n#ifdef HAVE_LIBRARY_LIBURING\n  return IndexSelectIOUring(index);\n#else\n  TORCH_CHECK(false, \"OnDiskNpyArray is not supported on non-Linux systems.\");\n  return {};\n#endif  // HAVE_LIBRARY_LIBURING\n}\n\nclass ReadRequest {\n public:\n  char *destination_;\n  int64_t read_len_;\n  int64_t offset_;\n  int64_t block_size_;\n  char *aligned_read_buffer_;\n\n  auto AlignedOffset() const { return offset_ & ~(block_size_ - 1); }\n\n  auto ReadBuffer() const {\n    return aligned_read_buffer_ + offset_ - AlignedOffset();\n  }\n\n  auto AlignedReadSize() const {\n    const int64_t end_offset = offset_ + read_len_;\n    const int64_t aligned_end_offset =\n        (end_offset + block_size_ - 1) & ~(block_size_ - 1);\n    return aligned_end_offset - AlignedOffset();\n  }\n\n  auto MinimumReadSize() const { return offset_ + read_len_ - AlignedOffset(); }\n};\n\n#ifdef HAVE_LIBRARY_LIBURING\ntorch::Tensor OnDiskNpyArray::IndexSelectIOUringImpl(torch::Tensor index) {\n  std::vector<int64_t> shape(index.sizes().begin(), index.sizes().end());\n  shape.insert(shape.end(), feature_dim_.begin() + 1, feature_dim_.end());\n  auto result = torch::empty(\n      shape, index.options()\n                 .dtype(dtype_)\n                 .layout(torch::kStrided)\n                 .pinned_memory(utils::is_pinned(index))\n                 .requires_grad(false));\n  auto result_buffer = reinterpret_cast<char *>(result.data_ptr());\n\n  // Indicator for index error.\n  std::atomic<int> error_flag{};\n  std::atomic<int64_t> work_queue{};\n  // Construct a QueueAndBufferAcquirer object so that the worker threads can\n  // share the available queues and buffers.\n  QueueAndBufferAcquirer 
queue_source(this);\n  graphbolt::parallel_for_each_interop(0, num_thread_, 1, [&](int) {\n    // The completion queue might contain 4 * kGroupSize while we may submit\n    // 4 * kGroupSize more. No harm in overallocation here.\n    CircularQueue<ReadRequest> read_queue(8 * kGroupSize);\n    int64_t num_submitted = 0;\n    int64_t num_completed = 0;\n    auto [acquired_queue_handle, read_buffer_source2] = queue_source.get();\n    auto &io_uring_queue = acquired_queue_handle.get();\n    // Capturing structured binding is available only in C++20, so we rename.\n    auto read_buffer_source = read_buffer_source2;\n    auto submit_fn = [&](int64_t submission_minimum_batch_size) {\n      if (read_queue.Size() < submission_minimum_batch_size) return;\n      TORCH_CHECK(  // Check for sqe overflow.\n          read_queue.Size() <= 2 * kGroupSize);\n      TORCH_CHECK(  // Check for cqe overflow.\n          read_queue.Size() + num_submitted - num_completed <= 4 * kGroupSize);\n      // Submit and wait for the reads.\n      while (!read_queue.IsEmpty()) {\n        const auto submitted = ::io_uring_submit(&io_uring_queue);\n        TORCH_CHECK(submitted >= 0);\n        num_submitted += submitted;\n        // Pop the submitted entries from the queue.\n        read_queue.PopN(submitted);\n      }\n    };\n    for (int64_t read_buffer_slot = 0; true;) {\n      auto request_read_buffer = [&]() {\n        return read_buffer_source + (aligned_length_ + block_size_) *\n                                        (read_buffer_slot++ % (8 * kGroupSize));\n      };\n      const auto num_requested_items = std::max(\n          std::min(\n              // The condition not to overflow the completion queue.\n              2 * kGroupSize -\n                  (read_queue.Size() + num_submitted - num_completed),\n              // The condition not to overflow the submission queue.\n              kGroupSize - read_queue.Size()),\n          int64_t{});\n      const auto begin =\n          
work_queue.fetch_add(num_requested_items, std::memory_order_relaxed);\n      if ((begin >= index.numel() && read_queue.IsEmpty() &&\n           num_completed >= num_submitted) ||\n          // Even when we encounter out of bounds index (error_flag == 1), we\n          // continue. We want to ensure the reads in flight successfully\n          // complete to avoid the instability due to incompleted reads.\n          error_flag.load(std::memory_order_relaxed) > 1)\n        break;\n      const auto end = std::min(begin + num_requested_items, index.numel());\n      AT_DISPATCH_INDEX_TYPES(\n          index.scalar_type(), \"IndexSelectIOUring\", ([&] {\n            auto index_data = index.data_ptr<index_t>();\n            for (int64_t i = begin; i < end; ++i) {\n              int64_t feature_id = index_data[i];\n              if (feature_id < 0) feature_id += feature_dim_[0];\n              if (feature_id < 0 || feature_id >= feature_dim_[0]) {\n                error_flag.store(1, std::memory_order_relaxed);\n                // Simply skip the out of bounds index.\n                continue;\n              }\n              // calculate offset of the feature.\n              const int64_t offset = feature_id * feature_size_ + prefix_len_;\n\n              ReadRequest req{\n                  result_buffer + feature_size_ * i, feature_size_, offset,\n                  block_size_, request_read_buffer()};\n\n              // Put requests into io_uring queue.\n              struct io_uring_sqe *sqe = io_uring_get_sqe(&io_uring_queue);\n              TORCH_CHECK(sqe);\n              io_uring_sqe_set_data(sqe, read_queue.Push(req));\n              io_uring_prep_read(\n                  sqe, file_description_, req.aligned_read_buffer_,\n                  req.AlignedReadSize(), req.AlignedOffset());\n              submit_fn(kGroupSize);\n            }\n          }));\n\n      submit_fn(1);  // Submit all sqes.\n      // Wait for the reads; completion queue entries.\n      struct 
io_uring_cqe *cqe;\n      TORCH_CHECK(num_submitted - num_completed <= 2 * kGroupSize);\n      TORCH_CHECK(\n          ::io_uring_wait_cqe_nr(\n              &io_uring_queue, &cqe, num_submitted - num_completed) == 0);\n      // Check the reads and abort on failure.\n      int num_cqes_seen = 0;\n      unsigned head;\n      io_uring_for_each_cqe(&io_uring_queue, head, cqe) {\n        const auto &req =\n            *reinterpret_cast<ReadRequest *>(io_uring_cqe_get_data(cqe));\n        auto actual_read_len = cqe->res;\n        if (actual_read_len < 0) {\n          error_flag.store(actual_read_len, std::memory_order_relaxed);\n          break;\n        }\n        const auto remaining_read_len =\n            std::max(req.MinimumReadSize() - actual_read_len, int64_t{});\n        const auto remaining_useful_read_len =\n            std::min(remaining_read_len, req.read_len_);\n        const auto useful_read_len = req.read_len_ - remaining_useful_read_len;\n        if (remaining_read_len) {\n          // Remaining portion will be read as part of the next batch.\n          ReadRequest rest{\n              req.destination_ + useful_read_len, remaining_useful_read_len,\n              req.offset_ + useful_read_len, block_size_,\n              request_read_buffer()};\n          // Put requests into io_uring queue.\n          struct io_uring_sqe *sqe = io_uring_get_sqe(&io_uring_queue);\n          TORCH_CHECK(sqe);\n          io_uring_sqe_set_data(sqe, read_queue.Push(rest));\n          io_uring_prep_read(\n              sqe, file_description_, rest.aligned_read_buffer_,\n              rest.AlignedReadSize(), rest.AlignedOffset());\n          submit_fn(kGroupSize);\n        }\n        // Copy results into result_buffer.\n        std::memcpy(req.destination_, req.ReadBuffer(), useful_read_len);\n        num_cqes_seen++;\n      }\n\n      // Move the head pointer of completion queue.\n      io_uring_cq_advance(&io_uring_queue, num_cqes_seen);\n      num_completed += 
num_cqes_seen;\n    }\n  });\n  const auto ret_val = error_flag.load(std::memory_order_relaxed);\n  switch (ret_val) {\n    case 0:  // Successful.\n      return result;\n    case 1:\n      throw std::out_of_range(\"IndexError: Index out of range.\");\n    default:\n      throw std::runtime_error(\n          \"io_uring error with errno: \" + std::to_string(-ret_val));\n  }\n}\n\nc10::intrusive_ptr<Future<torch::Tensor>> OnDiskNpyArray::IndexSelectIOUring(\n    torch::Tensor index) {\n  return async([=, this] { return IndexSelectIOUringImpl(index); });\n}\n\n#endif  // HAVE_LIBRARY_LIBURING\n}  // namespace storage\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cnumpy.h",
    "content": "/**\n *  Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *  Copyright (c) 2023 by Contributors\n * @file cnumpy.h\n * @brief Numpy File Fetcher class.\n */\n\n#ifdef HAVE_LIBRARY_LIBURING\n#include <liburing.h>\n#endif  // HAVE_LIBRARY_LIBURING\n\n#include <graphbolt/async.h>\n#include <torch/script.h>\n\n#include <cassert>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n#include <cuda/std/semaphore>\n#include <memory>\n#include <mutex>\n#include <string>\n#include <utility>\n#include <vector>\n\nnamespace graphbolt {\nnamespace storage {\n\nnamespace {\n#ifdef HAVE_LIBRARY_LIBURING\nstruct io_uring_queue_destroyer {\n  int num_thread_;\n  void operator()(::io_uring* queues) {\n    if (!queues) return;\n    for (int t = 0; t < num_thread_; t++) {\n      // IO queue exit.\n      ::io_uring_queue_exit(&queues[t]);\n    }\n    delete[] queues;\n  }\n};\n#endif  // HAVE_LIBRARY_LIBURING\n}  // namespace\n\n/**\n * @brief Disk Numpy Fetcher class.\n */\nclass OnDiskNpyArray : public torch::CustomClassHolder {\n  // No user will need more than 1024 io_uring queues.\n  using counting_semaphore_t = ::cuda::std::counting_semaphore<1024>;\n\n public:\n  static constexpr int kGroupSize = 256;\n\n  /** @brief Default constructor. */\n  OnDiskNpyArray() = default;\n\n  /**\n   * @brief Constructor with given file path and data type.\n   * @param path Path to the on disk numpy file.\n   * @param dtype Data type of numpy array.\n   *\n   * @return OnDiskNpyArray\n   */\n  OnDiskNpyArray(\n      std::string filename, torch::ScalarType dtype,\n      const std::vector<int64_t>& shape, torch::optional<int64_t> num_threads);\n\n  /** @brief Create a disk feature fetcher from numpy file. */\n  static c10::intrusive_ptr<OnDiskNpyArray> Create(\n      std::string path, torch::ScalarType dtype,\n      const std::vector<int64_t>& shape, torch::optional<int64_t> num_threads);\n\n  /** @brief Destructor. 
*/\n  ~OnDiskNpyArray();\n\n  /**\n   * @brief Parses the header of a numpy file to extract feature information.\n   **/\n  void ParseNumpyHeader();\n\n  /**\n   * @brief Read disk numpy file based on given index and transform to\n   * tensor.\n   */\n  c10::intrusive_ptr<Future<torch::Tensor>> IndexSelect(torch::Tensor index);\n\n#ifdef HAVE_LIBRARY_LIBURING\n  /**\n   * @brief Index-select operation on an on-disk numpy array using IO Uring for\n   * asynchronous I/O.\n   *\n   * This function performs index-select operation on an on-disk numpy array. It\n   * uses IO Uring for asynchronous I/O to efficiently read data from disk. The\n   * input tensor 'index' specifies the indices of features to select. The\n   * function reads features corresponding to the indices from the disk and\n   * returns a new tensor containing the selected features.\n   *\n   * @param index A 1D tensor containing the indices of features to select.\n   * @return A tensor containing the selected features.\n   * @throws std::runtime_error If index is out of range.\n   */\n  c10::intrusive_ptr<Future<torch::Tensor>> IndexSelectIOUring(\n      torch::Tensor index);\n\n  torch::Tensor IndexSelectIOUringImpl(torch::Tensor index);\n\n#endif  // HAVE_LIBRARY_LIBURING\n private:\n  int64_t ReadBufferSizePerThread() const {\n    return (aligned_length_ + block_size_) * kGroupSize * 8;\n  }\n\n  char* ReadBuffer(int thread_id) const {\n    auto read_buffer_void_ptr = read_tensor_.data_ptr();\n    size_t read_buffer_size = read_tensor_.numel();\n    auto read_buffer = reinterpret_cast<char*>(std::align(\n        block_size_, ReadBufferSizePerThread() * num_thread_,\n        read_buffer_void_ptr, read_buffer_size));\n    TORCH_CHECK(read_buffer, \"read_buffer allocation failed!\");\n    return read_buffer + ReadBufferSizePerThread() * thread_id;\n  }\n\n  const std::string filename_;  // Path to numpy file.\n  int file_description_;        // File description.\n  int64_t block_size_;          // 
Block size of the opened file.\n  int64_t prefix_len_;          // Length of head data in numpy file.\n  const std::vector<int64_t>\n      feature_dim_;                // Shape of features, e.g. {N,M,K,L}.\n  const torch::ScalarType dtype_;  // Feature data type.\n  const int64_t feature_size_;     // Number of bytes of feature size.\n  int64_t aligned_length_;         // Aligned feature_size.\n  int num_thread_;                 // Default thread number.\n  torch::Tensor read_tensor_;      // Provides temporary read buffer.\n\n#ifdef HAVE_LIBRARY_LIBURING\n\n  static inline std::once_flag\n      call_once_flag_;            // Protect initialization of below.\n  static inline int num_queues_;  // Number of queues.\n  static inline std::unique_ptr<::io_uring[], io_uring_queue_destroyer>\n      io_uring_queue_;  // io_uring queue.\n  static inline counting_semaphore_t semaphore_{\n      0};  // Control access to the io_uring queues.\n  static inline std::mutex available_queues_mtx_;  // available_queues_ mutex.\n  static inline std::vector<int> available_queues_;\n\n  /**\n   * @brief This class is meant to distribute the available read buffers and the\n   * statically declared io_uring queues among the worker threads.\n   */\n  class QueueAndBufferAcquirer {\n   public:\n    class UniqueQueue {\n     public:\n      UniqueQueue(int thread_id) : thread_id_(thread_id) {}\n      UniqueQueue(const UniqueQueue&) = delete;\n      UniqueQueue& operator=(const UniqueQueue&) = delete;\n\n      /**\n       * @brief Returns the queue back to the pool.\n       */\n      ~UniqueQueue() {\n        {\n          // We give back the slot we used.\n          std::lock_guard lock(available_queues_mtx_);\n          available_queues_.push_back(thread_id_);\n        }\n        semaphore_.release();\n      }\n\n      /**\n       * @brief Returns the raw io_uring queue.\n       */\n      ::io_uring& get() const { return io_uring_queue_[thread_id_]; }\n\n     private:\n      int thread_id_;\n 
   };\n\n    QueueAndBufferAcquirer(OnDiskNpyArray* array) : array_(array) {\n      semaphore_.acquire();\n    }\n\n    ~QueueAndBufferAcquirer() {\n      // If none of the worker threads acquire the semaphore, we make sure to\n      // release the ticket taken in the constructor.\n      if (!entering_first_.test_and_set(std::memory_order_relaxed)) {\n        semaphore_.release();\n      }\n    }\n\n    /**\n     * @brief Returns the secured io_uring queue and the read buffer as a pair.\n     * The raw io_uring queue can be accessed by calling `.get()` on the\n     * returned UniqueQueue object.\n     *\n     * @note The returned UniqueQueue object manages the lifetime of the\n     * io_uring queue. Its destructor returns the queue back to the pool.\n     */\n    std::pair<UniqueQueue, char*> get() {\n      // We consume a slot from the semaphore to use a queue.\n      if (entering_first_.test_and_set(std::memory_order_relaxed)) {\n        semaphore_.acquire();\n      }\n      const auto thread_id = [&] {\n        std::lock_guard lock(available_queues_mtx_);\n        TORCH_CHECK(!available_queues_.empty());\n        const auto thread_id = available_queues_.back();\n        available_queues_.pop_back();\n        return thread_id;\n      }();\n      return {\n          std::piecewise_construct, std::make_tuple(thread_id),\n          std::make_tuple(array_->ReadBuffer(thread_id))};\n    }\n\n   private:\n    const OnDiskNpyArray* array_;\n    std::atomic_flag entering_first_ = ATOMIC_FLAG_INIT;\n  };\n\n#endif  // HAVE_LIBRARY_LIBURING\n};\n\n}  // namespace storage\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/concurrent_id_hash_map.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file concurrent_id_hash_map.cc\n * @brief Class about id hash map.\n */\n\n#include \"concurrent_id_hash_map.h\"\n\n#ifdef _MSC_VER\n#include <intrin.h>\n#endif  // _MSC_VER\n\n#include <cmath>\n#include <cuda/std/atomic>\n#include <numeric>\n\nnamespace {\nstatic constexpr int64_t kEmptyKey = -1;\nstatic constexpr int kGrainSize = 256;\n\n// The formula is established from experience which is used to get the hashmap\n// size from the input array size.\ninline size_t GetMapSize(size_t num) {\n  size_t capacity = 1;\n  return capacity << static_cast<size_t>(1 + std::log2(num * 3));\n}\n}  // namespace\n\nnamespace graphbolt {\nnamespace sampling {\n\ntemplate <typename IdType>\nConcurrentIdHashMap<IdType>::ConcurrentIdHashMap(\n    const torch::Tensor& ids, size_t num_seeds) {\n  const IdType* ids_data = ids.data_ptr<IdType>();\n  const size_t num_ids = static_cast<size_t>(ids.size(0));\n  size_t capacity = GetMapSize(num_ids);\n  mask_ = static_cast<IdType>(capacity - 1);\n\n  hash_map_ =\n      torch::full({static_cast<int64_t>(capacity * 2)}, -1, ids.options());\n\n  // This code block is to fill the ids into hash_map_.\n  unique_ids_ = torch::empty_like(ids);\n  IdType* unique_ids_data = unique_ids_.data_ptr<IdType>();\n  // Insert all ids into the hash map.\n  torch::parallel_for(0, num_ids, kGrainSize, [&](int64_t s, int64_t e) {\n    for (int64_t i = s; i < e; i++) {\n      InsertAndSetMin(ids_data[i], static_cast<IdType>(i));\n    }\n  });\n  // Place the first `num_seeds` ids.\n  unique_ids_.slice(0, 0, num_seeds) = ids.slice(0, 0, num_seeds);\n\n  auto valid_tensor = torch::empty(num_ids, ids.options().dtype(torch::kInt8));\n  auto valid = valid_tensor.data_ptr<int8_t>();\n\n  const int64_t num_threads = torch::get_num_threads();\n  std::vector<size_t> block_offset(num_threads + 1, 0);\n\n  // Count the valid numbers in each thread.\n  torch::parallel_for(\n      num_seeds, num_ids, 
kGrainSize, [&](int64_t s, int64_t e) {\n        size_t count = 0;\n        for (int64_t i = s; i < e; i++) {\n          if (MapId(ids_data[i]) == i) {\n            count++;\n            valid[i] = 1;\n          } else {\n            valid[i] = 0;\n          }\n        }\n        auto thread_id = torch::get_thread_num();\n        block_offset[thread_id + 1] = count;\n      });\n\n  // Get ExclusiveSum of each block.\n  std::partial_sum(\n      block_offset.begin() + 1, block_offset.end(), block_offset.begin() + 1);\n  unique_ids_ = unique_ids_.slice(0, 0, num_seeds + block_offset.back());\n\n  // Get unique array from ids and set value for hash map.\n  torch::parallel_for(\n      num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {\n        auto thread_id = torch::get_thread_num();\n        auto pos = block_offset[thread_id] + num_seeds;\n        for (int64_t i = s; i < e; i++) {\n          if (valid[i]) {\n            unique_ids_data[pos] = ids_data[i];\n            Set(ids_data[i], pos);\n            pos = pos + 1;\n          }\n        }\n      });\n}\n\ntemplate <typename IdType>\ntorch::Tensor ConcurrentIdHashMap<IdType>::MapIds(\n    const torch::Tensor& ids) const {\n  const IdType* ids_data = ids.data_ptr<IdType>();\n\n  torch::Tensor new_ids = torch::empty_like(ids);\n  auto num_ids = new_ids.size(0);\n  IdType* values_data = new_ids.data_ptr<IdType>();\n\n  torch::parallel_for(0, num_ids, kGrainSize, [&](int64_t s, int64_t e) {\n    for (int64_t i = s; i < e; i++) {\n      values_data[i] = MapId(ids_data[i]);\n    }\n  });\n  return new_ids;\n}\n\ntemplate <typename IdType>\nconstexpr IdType getKeyIndex(IdType pos) {\n  return 2 * pos;\n}\n\ntemplate <typename IdType>\nconstexpr IdType getValueIndex(IdType pos) {\n  return 2 * pos + 1;\n}\n\ntemplate <typename IdType>\ninline void ConcurrentIdHashMap<IdType>::Next(\n    IdType* pos, IdType* delta) const {\n  // Use quadratic probing.\n  *pos = (*pos + (*delta) * (*delta)) & mask_;\n  *delta = *delta 
+ 1;\n}\n\ntemplate <typename IdType>\ninline IdType ConcurrentIdHashMap<IdType>::MapId(IdType id) const {\n  IdType pos = (id & mask_), delta = 1;\n  IdType empty_key = static_cast<IdType>(kEmptyKey);\n  IdType* hash_map_data = hash_map_.data_ptr<IdType>();\n  IdType key = hash_map_data[getKeyIndex(pos)];\n  while (key != empty_key && key != id) {\n    Next(&pos, &delta);\n    key = hash_map_data[getKeyIndex(pos)];\n  }\n  if (key == empty_key) {\n    throw std::out_of_range(\"Id not found: \" + std::to_string(id));\n  }\n  return hash_map_data[getValueIndex(pos)];\n}\n\ntemplate <typename IdType>\nbool ConcurrentIdHashMap<IdType>::Insert(IdType id) {\n  IdType pos = (id & mask_), delta = 1;\n  InsertState state = AttemptInsertAt(pos, id);\n  while (state == InsertState::OCCUPIED) {\n    Next(&pos, &delta);\n    state = AttemptInsertAt(pos, id);\n  }\n\n  return state == InsertState::INSERTED;\n}\n\ntemplate <typename IdType>\ninline void ConcurrentIdHashMap<IdType>::Set(IdType key, IdType value) {\n  IdType pos = (key & mask_), delta = 1;\n  IdType* hash_map_data = hash_map_.data_ptr<IdType>();\n  while (hash_map_data[getKeyIndex(pos)] != key) {\n    Next(&pos, &delta);\n  }\n\n  hash_map_data[getValueIndex(pos)] = value;\n}\n\ntemplate <typename IdType>\ninline void ConcurrentIdHashMap<IdType>::InsertAndSet(IdType id, IdType value) {\n  IdType pos = (id & mask_), delta = 1;\n  while (AttemptInsertAt(pos, id) == InsertState::OCCUPIED) {\n    Next(&pos, &delta);\n  }\n\n  hash_map_.data_ptr<IdType>()[getValueIndex(pos)] = value;\n}\n\ntemplate <typename IdType>\nvoid ConcurrentIdHashMap<IdType>::InsertAndSetMin(IdType id, IdType value) {\n  IdType pos = (id & mask_), delta = 1;\n  InsertState state = AttemptInsertAt(pos, id);\n  while (state == InsertState::OCCUPIED) {\n    Next(&pos, &delta);\n    state = AttemptInsertAt(pos, id);\n  }\n\n  IdType empty_key = static_cast<IdType>(kEmptyKey);\n  IdType val_pos = getValueIndex(pos);\n  ::cuda::std::atomic_ref 
value_ref(\n      reinterpret_cast<IdType*>(hash_map_.data_ptr())[val_pos]);\n  for (auto old_val = empty_key; old_val == empty_key || old_val > value;) {\n    // It is more efficient to use weak variant in a loop.\n    if (value_ref.compare_exchange_weak(old_val, value)) break;\n  }\n}\n\ntemplate <typename IdType>\ninline typename ConcurrentIdHashMap<IdType>::InsertState\nConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {\n  auto expected = static_cast<IdType>(kEmptyKey);\n  ::cuda::std::atomic_ref key_ref(\n      reinterpret_cast<IdType*>(hash_map_.data_ptr())[getKeyIndex(pos)]);\n  if (key_ref.compare_exchange_strong(expected, key)) {\n    return InsertState::INSERTED;\n  } else if (expected == key) {\n    return InsertState::EXISTED;\n  } else {\n    return InsertState::OCCUPIED;\n  }\n}\n\ntemplate class ConcurrentIdHashMap<int32_t>;\ntemplate class ConcurrentIdHashMap<int64_t>;\n\n}  // namespace sampling\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/concurrent_id_hash_map.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file concurrent_id_hash_map.h\n * @brief Class about concurrent id hash map.\n */\n\n#ifndef GRAPHBOLT_CONCURRENT_ID_HASH_MAP_H_\n#define GRAPHBOLT_CONCURRENT_ID_HASH_MAP_H_\n\n#include <torch/torch.h>\n\n#include <functional>\n#include <memory>\n#include <vector>\n\nnamespace graphbolt {\nnamespace sampling {\n\n/**\n * @brief A CPU targeted hashmap for mapping duplicate and non-consecutive ids\n * in the provided array to unique and consecutive ones. It utilizes\n * multi-threading to accelerate the insert and search speed. Currently it is\n * only designed to be used in `ToBlockCpu` for optimizing, so it only support\n * key insertions once with Init function, and it does not support key deletion.\n *\n * The hash map should be prepared in two phases before using. With the first\n * being creating the hashmap, and then initialize it with an id array which is\n * divided into 2 parts: [`seed ids`, `sampled ids`]. `Seed ids` refer to\n * a set ids chosen as the input for sampling process and `sampled ids` are the\n * ids new sampled from the process (note the the `seed ids` might also be\n * sampled in the process and included in the `sampled ids`). In result `seed\n * ids` are mapped to [0, num_seed_ids) and `sampled ids` to [num_seed_ids,\n * num_unique_ids). 
Notice that mapping order is stable for `seed ids` while not\n * for the `sampled ids`.\n *\n * For example, for an array `A` having 4 seed ids with the following entries:\n * [99, 98, 100, 97, 97, 101, 101, 102, 101]\n * Create the hashmap `H` with:\n * `H = ConcurrentIdHashMap()` (1)\n * And Init it with:\n * `U = H.Init(A)` (2)  (U is an id array used to store the unique\n * ids in A).\n * Then `U` should be (U is not exclusive as the overall mapping is not stable):\n * [99, 98, 100, 97, 102, 101]\n * And the hashmap should generate the following mappings:\n * [\n *   {key: 99, value: 0},\n *   {key: 98, value: 1},\n *   {key: 100, value: 2},\n *   {key: 97, value: 3},\n *   {key: 102, value: 4},\n *   {key: 101, value: 5}\n * ]\n * Search the hashmap with array `I`=[98, 99, 102]:\n * R = H.Map(I) (3)\n * R should be:\n * [1, 0, 4]\n **/\ntemplate <typename IdType>\nclass ConcurrentIdHashMap {\n private:\n  /**\n   * @brief The result state of an attempt to insert.\n   */\n  enum class InsertState {\n    OCCUPIED,  // Indicates that the space where an insertion is being\n               // attempted is already occupied by another element.\n    EXISTED,  // Indicates that the element being inserted already exists in the\n              // map, and thus no insertion is performed.\n    INSERTED  // Indicates that the insertion was successful and a new element\n              // was added to the map.\n  };\n\n public:\n  /**\n   * @brief Initialize the hashmap with an array of ids. The first `num_seeds`\n   * ids are unique and must be mapped to a contiguous array starting\n   * from 0. 
The rest can be duplicated and the mapping result is not stable.\n   * The unique'ified ids can be accessed through calling `GetUniqueIds()`.\n   *\n   * @param ids The array of the ids to be inserted.\n   * @param num_seeds The number of seed ids.\n   */\n  ConcurrentIdHashMap(const torch::Tensor& ids, size_t num_seeds);\n\n  ConcurrentIdHashMap(const ConcurrentIdHashMap& other) = delete;\n  ConcurrentIdHashMap& operator=(const ConcurrentIdHashMap& other) = delete;\n\n  /**\n   * @brief Get the unique ids for the keys given in the constructor.\n   */\n  const torch::Tensor& GetUniqueIds() const { return unique_ids_; }\n\n  /**\n   * @brief Find mappings of given keys.\n   *\n   * @param ids The keys to map for.\n   *\n   * @return Mapping results corresponding to `ids`.\n   */\n  torch::Tensor MapIds(const torch::Tensor& ids) const;\n\n private:\n  /**\n   * @brief Get the next position and delta for probing.\n   *\n   * @param[in,out] pos Calculate the next position with quadratic probing.\n   * @param[in,out] delta Calculate the next delta by adding 1.\n   */\n  inline void Next(IdType* pos, IdType* delta) const;\n\n  /**\n   * @brief Find the mapping of a given key.\n   *\n   * @param id The key to map for.\n   *\n   * @return Mapping result corresponding to `id`.\n   */\n  inline IdType MapId(const IdType id) const;\n\n  /**\n   * @brief Insert an id into the hash map.\n   *\n   * @param id The id to be inserted.\n   *\n   * @return Whether the `id` is inserted or not.\n   */\n  inline bool Insert(IdType id);\n\n  /**\n   * @brief Set the value for the key in the hash map.\n   *\n   * @param key The key to set for.\n   * @param value The value to be set for the `key`.\n   *\n   * @warning Key must exist.\n   */\n  inline void Set(IdType key, IdType value);\n\n  /**\n   * @brief Insert a key into the hash map.\n   *\n   * @param key The key to be inserted.\n   * @param value The value to be set for the `key`.\n   *\n   */\n  inline void InsertAndSet(IdType key, 
IdType value);\n\n  /**\n   * @brief Insert a key into the hash map. If the key exists, set the value\n   * with the smaller value.\n   *\n   * @param id The key to be inserted.\n   * @param value The value to be set for the `key`.\n   *\n   */\n  inline void InsertAndSetMin(IdType id, IdType value);\n\n  /**\n   * @brief Attempt to insert the key into the hash map at the given position.\n   *\n   * @param pos The position in the hash map to be inserted at.\n   * @param key The key to be inserted.\n   *\n   * @return The state of the insertion.\n   */\n  inline InsertState AttemptInsertAt(int64_t pos, IdType key);\n\n private:\n  /**\n   * @brief Hash maps which is used to store all elements.\n   */\n  torch::Tensor hash_map_;\n\n  /**\n   * @brief Holds the ids that are made unique in the constructor.\n   */\n  torch::Tensor unique_ids_;\n\n  /**\n   * @brief Mask which is assisted to get the position in the table\n   * for a key by performing `&` operation with it.\n   */\n  IdType mask_;\n};\n\n}  // namespace sampling\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_CONCURRENT_ID_HASH_MAP_H_\n"
  },
  {
    "path": "graphbolt/src/cuda/common.h",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/common.h\n * @brief Common utilities for CUDA\n */\n#ifndef GRAPHBOLT_CUDA_COMMON_H_\n#define GRAPHBOLT_CUDA_COMMON_H_\n\n#include <ATen/cuda/CUDAEvent.h>\n#include <c10/cuda/CUDACachingAllocator.h>\n#include <c10/cuda/CUDAException.h>\n#include <c10/cuda/CUDAStream.h>\n#include <cuda_runtime.h>\n#include <thrust/execution_policy.h>\n#include <torch/script.h>\n\n#include <memory>\n#include <unordered_map>\n\nnamespace graphbolt {\nnamespace cuda {\n\n/**\n * @brief This class is designed to allocate workspace storage\n * and to get a nonblocking thrust execution policy\n * that uses torch's CUDA memory pool and the current cuda stream:\n *\n * cuda::CUDAWorkspaceAllocator allocator;\n * const auto stream = torch::cuda::getDefaultCUDAStream();\n * const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);\n *\n * Now, one can pass exec_policy to thrust functions\n *\n * To get an integer array of size 1000 whose lifetime is managed by unique_ptr,\n * use:\n *\n * auto int_array = allocator.AllocateStorage<int>(1000);\n *\n * int_array.get() gives the raw pointer.\n */\ntemplate <typename value_t = char>\nstruct CUDAWorkspaceAllocator {\n  static_assert(sizeof(char) == 1, \"sizeof(char) == 1 should hold.\");\n  // Required by 
thrust to satisfy allocator requirements.\n  using value_type = value_t;\n\n  explicit CUDAWorkspaceAllocator() {\n    at::globalContext().lazyInitDevice(at::kCUDA);\n  }\n\n  template <class U>\n  CUDAWorkspaceAllocator(CUDAWorkspaceAllocator<U> const&) noexcept {}\n\n  CUDAWorkspaceAllocator& operator=(const CUDAWorkspaceAllocator&) = default;\n\n  void operator()(void* ptr) const {\n    c10::cuda::CUDACachingAllocator::raw_delete(ptr);\n  }\n\n  // Required by thrust to satisfy allocator requirements.\n  value_type* allocate(std::ptrdiff_t size) const {\n    return reinterpret_cast<value_type*>(\n        c10::cuda::CUDACachingAllocator::raw_alloc(size * sizeof(value_type)));\n  }\n\n  // Required by thrust to satisfy allocator requirements.\n  void deallocate(value_type* ptr, std::size_t) const { operator()(ptr); }\n\n  template <typename T>\n  std::unique_ptr<T, CUDAWorkspaceAllocator> AllocateStorage(\n      std::size_t size) const {\n    return std::unique_ptr<T, CUDAWorkspaceAllocator>(\n        reinterpret_cast<T*>(\n            c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(T) * size)),\n        *this);\n  }\n};\n\ninline auto GetAllocator() { return CUDAWorkspaceAllocator{}; }\n\ninline auto GetCurrentStream() { return c10::cuda::getCurrentCUDAStream(); }\n\ntemplate <typename T>\ninline bool is_zero(T size) {\n  return size == 0;\n}\n\ntemplate <>\ninline bool is_zero<dim3>(dim3 size) {\n  return size.x == 0 || size.y == 0 || size.z == 0;\n}\n\n#define CUDA_RUNTIME_CHECK(EXPR)                           \\\n  do {                                                     \\\n    cudaError_t __err = EXPR;                              \\\n    if (__err != cudaSuccess) {                            \\\n      auto get_error_str_err = cudaGetErrorString(__err);  \\\n      AT_ERROR(\"CUDA runtime error: \", get_error_str_err); \\\n    }                                                      \\\n  } while (0)\n\n#define CUDA_CALL(func) 
C10_CUDA_CHECK((func))\n\n#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, ...)          \\\n  {                                                                 \\\n    if (!graphbolt::cuda::is_zero((nblks)) &&                       \\\n        !graphbolt::cuda::is_zero((nthrs))) {                       \\\n      auto stream = graphbolt::cuda::GetCurrentStream();            \\\n      (kernel)<<<(nblks), (nthrs), (shmem), stream>>>(__VA_ARGS__); \\\n      C10_CUDA_KERNEL_LAUNCH_CHECK();                               \\\n    }                                                               \\\n  }\n\n#define CUB_CALL(fn, ...)                                                     \\\n  {                                                                           \\\n    auto allocator = graphbolt::cuda::GetAllocator();                         \\\n    auto stream = graphbolt::cuda::GetCurrentStream();                        \\\n    size_t workspace_size = 0;                                                \\\n    CUDA_CALL(cub::fn(nullptr, workspace_size, __VA_ARGS__, stream));         \\\n    auto workspace = allocator.AllocateStorage<char>(workspace_size);         \\\n    CUDA_CALL(cub::fn(workspace.get(), workspace_size, __VA_ARGS__, stream)); \\\n  }\n\n#define THRUST_CALL(fn, ...)                                                 \\\n  [&] {                                                                      \\\n    auto allocator = graphbolt::cuda::GetAllocator();                        \\\n    auto stream = graphbolt::cuda::GetCurrentStream();                       \\\n    const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream); \\\n    return thrust::fn(exec_policy, __VA_ARGS__);                             \\\n  }()\n\n/**\n * @brief This class is designed to handle the copy operation of a single\n * scalar_t item from a given CUDA device pointer. 
Later, if the object is cast\n * into scalar_t, the value can be read.\n *\n * auto num_edges = cuda::CopyScalar(indptr.data_ptr<scalar_t>() +\n *     indptr.size(0) - 1);\n * // Perform many operations here, they will run as normal.\n * // We finally need to read num_edges.\n * auto indices = torch::empty(static_cast<scalar_t>(num_edges));\n */\ntemplate <typename scalar_t>\nstruct CopyScalar {\n  CopyScalar() : is_ready_(true) { init_pinned_storage(); }\n\n  void record(at::cuda::CUDAStream stream = GetCurrentStream()) {\n    copy_event_.record(stream);\n    is_ready_ = false;\n  }\n\n  scalar_t* get() {\n    return reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr());\n  }\n\n  CopyScalar(const scalar_t* device_ptr) {\n    init_pinned_storage();\n    auto stream = GetCurrentStream();\n    CUDA_CALL(cudaMemcpyAsync(\n        reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr()), device_ptr,\n        sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));\n    record(stream);\n  }\n\n  operator scalar_t() {\n    if (!is_ready_) {\n      copy_event_.synchronize();\n      is_ready_ = true;\n    }\n    return *get();\n  }\n\n private:\n  void init_pinned_storage() {\n    pinned_scalar_ = torch::empty(\n        sizeof(scalar_t),\n        c10::TensorOptions().dtype(torch::kBool).pinned_memory(true));\n  }\n\n  torch::Tensor pinned_scalar_;\n  at::cuda::CUDAEvent copy_event_;\n  bool is_ready_;\n};\n\n#define GRAPHBOLT_DISPATCH_ELEMENT_SIZES(element_size, name, ...)             
\\\n  [&] {                                                                       \\\n    switch (element_size) {                                                   \\\n      case 1: {                                                               \\\n        using element_size_t = uint8_t;                                       \\\n        return __VA_ARGS__();                                                 \\\n      }                                                                       \\\n      case 2: {                                                               \\\n        using element_size_t = uint16_t;                                      \\\n        return __VA_ARGS__();                                                 \\\n      }                                                                       \\\n      case 4: {                                                               \\\n        using element_size_t = uint32_t;                                      \\\n        return __VA_ARGS__();                                                 \\\n      }                                                                       \\\n      case 8: {                                                               \\\n        using element_size_t = uint64_t;                                      \\\n        return __VA_ARGS__();                                                 \\\n      }                                                                       \\\n      case 16: {                                                              \\\n        using element_size_t = float4;                                        \\\n        return __VA_ARGS__();                                                 \\\n      }                                                                       \\\n      default:                                                                \\\n        TORCH_CHECK(false, name, \" with the element_size is not supported!\"); \\\n        using 
element_size_t = uint8_t;                                       \\\n        return __VA_ARGS__();                                                 \\\n    }                                                                         \\\n  }()\n\n}  // namespace cuda\n}  // namespace graphbolt\n#endif  // GRAPHBOLT_CUDA_COMMON_H_\n"
  },
  {
    "path": "graphbolt/src/cuda/cooperative_minibatching_utils.cu",
    "content": "/**\n *   Copyright (c) 2024, mfbalin (Muhammed Fatih Balin)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/cooperative_minibatching_utils.cu\n * @brief Cooperative Minibatching (arXiv:2310.12403) utility function\n * implementations in CUDA.\n */\n#include <graphbolt/cuda_ops.h>\n#include <thrust/scatter.h>\n#include <thrust/transform.h>\n\n#include <cub/cub.cuh>\n#include <cuda/functional>\n\n#include \"../utils.h\"\n#include \"./common.h\"\n#include \"./cooperative_minibatching_utils.cuh\"\n#include \"./cooperative_minibatching_utils.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace cuda {\n\ntorch::Tensor RankAssignment(\n    torch::Tensor nodes, const int64_t rank, const int64_t world_size) {\n  auto part_ids = torch::empty_like(nodes, nodes.options().dtype(kPartDType));\n  auto part_ids_ptr = part_ids.data_ptr<part_t>();\n  AT_DISPATCH_INDEX_TYPES(\n      nodes.scalar_type(), \"RankAssignment\", ([&] {\n        auto nodes_ptr = nodes.data_ptr<index_t>();\n        THRUST_CALL(\n            transform, nodes_ptr, nodes_ptr + nodes.numel(), part_ids_ptr,\n            ::cuda::proclaim_return_type<part_t>(\n                [rank = static_cast<uint32_t>(rank),\n                 world_size = static_cast<uint32_t>(\n                     world_size)] __device__(index_t id) -> part_t {\n                  return rank_assignment(id, rank, world_size);\n           
     }));\n      }));\n  return part_ids;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, at::cuda::CUDAEvent>\nRankSortImpl(\n    torch::Tensor nodes, torch::Tensor part_ids, torch::Tensor offsets_dev,\n    const int64_t world_size) {\n  const int num_bits = cuda::NumberOfBits(world_size);\n  const auto num_batches = offsets_dev.numel() - 1;\n  auto offsets_dev_ptr = offsets_dev.data_ptr<int64_t>();\n  auto part_ids_sorted = torch::empty_like(part_ids);\n  auto part_ids2 = part_ids.clone();\n  auto part_ids2_sorted = torch::empty_like(part_ids2);\n  auto nodes_sorted = torch::empty_like(nodes);\n  auto index = torch::arange(nodes.numel(), nodes.options());\n  auto index_sorted = torch::empty_like(index);\n  return AT_DISPATCH_INDEX_TYPES(\n      nodes.scalar_type(), \"RankSortImpl\", ([&] {\n        CUB_CALL(\n            DeviceSegmentedRadixSort::SortPairs,\n            part_ids.data_ptr<cuda::part_t>(),\n            part_ids_sorted.data_ptr<cuda::part_t>(), nodes.data_ptr<index_t>(),\n            nodes_sorted.data_ptr<index_t>(), nodes.numel(), num_batches,\n            offsets_dev_ptr, offsets_dev_ptr + 1, 0, num_bits);\n        auto offsets = torch::empty(\n            num_batches * world_size + 1, c10::TensorOptions()\n                                              .dtype(offsets_dev.scalar_type())\n                                              .pinned_memory(true));\n        CUB_CALL(\n            DeviceFor::Bulk, num_batches * world_size + 1,\n            [=, part_ids = part_ids_sorted.data_ptr<cuda::part_t>(),\n             offsets = offsets.data_ptr<int64_t>()] __device__(int64_t i) {\n              const auto batch_id = i / world_size;\n              const auto rank = i % world_size;\n              const auto offset_begin = offsets_dev_ptr[batch_id];\n              const auto offset_end =\n                  offsets_dev_ptr[::cuda::std::min(batch_id + 1, num_batches)];\n              offsets[i] = cub::LowerBound(\n                             
  part_ids + offset_begin,\n                               offset_end - offset_begin, rank) +\n                           offset_begin;\n            });\n        at::cuda::CUDAEvent offsets_event;\n        offsets_event.record();\n        CUB_CALL(\n            DeviceSegmentedRadixSort::SortPairs,\n            part_ids2.data_ptr<cuda::part_t>(),\n            part_ids2_sorted.data_ptr<cuda::part_t>(),\n            index.data_ptr<index_t>(), index_sorted.data_ptr<index_t>(),\n            nodes.numel(), num_batches, offsets_dev_ptr, offsets_dev_ptr + 1, 0,\n            num_bits);\n        auto values = ops::IndptrEdgeIdsImpl(\n            offsets_dev, nodes.scalar_type(), torch::nullopt, nodes.numel());\n        THRUST_CALL(\n            scatter, values.data_ptr<index_t>(),\n            values.data_ptr<index_t>() + values.numel(),\n            index_sorted.data_ptr<index_t>(), index.data_ptr<index_t>());\n        return std::make_tuple(\n            nodes_sorted, index, offsets, std::move(offsets_event));\n      }));\n}\n\nstd::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(\n    const std::vector<torch::Tensor>& nodes_list, const int64_t rank,\n    const int64_t world_size) {\n  const auto num_batches = nodes_list.size();\n  auto nodes = torch::cat(nodes_list, 0);\n  auto offsets = torch::empty(\n      num_batches + 1,\n      c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true));\n  auto offsets_ptr = offsets.data_ptr<int64_t>();\n  offsets_ptr[0] = 0;\n  for (int64_t i = 0; i < num_batches; i++) {\n    offsets_ptr[i + 1] = offsets_ptr[i] + nodes_list[i].numel();\n  }\n  auto part_ids = RankAssignment(nodes, rank, world_size);\n  auto offsets_dev =\n      torch::empty_like(offsets, nodes.options().dtype(offsets.scalar_type()));\n  CUDA_CALL(cudaMemcpyAsync(\n      offsets_dev.data_ptr<int64_t>(), offsets_ptr,\n      sizeof(int64_t) * offsets.numel(), cudaMemcpyHostToDevice,\n      cuda::GetCurrentStream()));\n  auto [nodes_sorted, 
index_sorted, rank_offsets, rank_offsets_event] =\n      RankSortImpl(nodes, part_ids, offsets_dev, world_size);\n  std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> results;\n  rank_offsets_event.synchronize();\n  for (int64_t i = 0; i < num_batches; i++) {\n    results.emplace_back(\n        nodes_sorted.slice(0, offsets_ptr[i], offsets_ptr[i + 1]),\n        index_sorted.slice(0, offsets_ptr[i], offsets_ptr[i + 1]),\n        rank_offsets.slice(0, i * world_size, (i + 1) * world_size + 1));\n  }\n  return results;\n}\n\nc10::intrusive_ptr<Future<\n    std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>\nRankSortAsync(\n    const std::vector<torch::Tensor>& nodes_list, const int64_t rank,\n    const int64_t world_size) {\n  return async(\n      [=] { return RankSort(nodes_list, rank, world_size); },\n      utils::is_on_gpu(nodes_list.at(0)));\n}\n\n}  // namespace cuda\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/cooperative_minibatching_utils.cuh",
    "content": "/**\n *   Copyright (c) 2024, mfbalin (Muhammed Fatih Balin)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/cooperative_minibatching_utils.cuh\n * @brief Cooperative Minibatching (arXiv:2310.12403) utility device functions\n * in CUDA.\n */\n#ifndef GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_CUH_\n#define GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_CUH_\n\n#include <curand_kernel.h>\n\nnamespace graphbolt {\nnamespace cuda {\n\nusing part_t = uint8_t;\nconstexpr auto kPartDType = torch::kUInt8;\n\n/**\n * @brief Given a vertex id, the rank of current GPU and the world size, returns\n * the rank that this id belongs in a deterministic manner.\n *\n * @param id         The node id that will mapped to a rank in [0, world_size).\n * @param rank       The rank of the current GPU.\n * @param world_size The world size, the total number of cooperating GPUs.\n *\n * @return The rank of the GPU the given id is mapped to.\n */\ntemplate <typename index_t>\n__device__ inline auto rank_assignment(\n    index_t id, uint32_t rank, uint32_t world_size) {\n  // Consider using a faster implementation in the future.\n  constexpr uint64_t kCurandSeed = 999961;  // Any random number.\n  curandStatePhilox4_32_10_t rng;\n  curand_init(kCurandSeed, 0, id, &rng);\n  return (curand(&rng) - rank) % world_size;\n}\n\n}  // namespace cuda\n}  // namespace graphbolt\n\n#endif  // 
GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_CUH_\n"
  },
  {
    "path": "graphbolt/src/cuda/cooperative_minibatching_utils.h",
    "content": "/**\n *   Copyright (c) 2024, mfbalin (Muhammed Fatih Balin)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/cooperative_minibatching_utils.h\n * @brief Cooperative Minibatching (arXiv:2310.12403) utility function headers\n * in CUDA.\n */\n#ifndef GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_H_\n#define GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_H_\n\n#include <ATen/cuda/CUDAEvent.h>\n#include <graphbolt/async.h>\n#include <torch/script.h>\n\nnamespace graphbolt {\nnamespace cuda {\n\n/**\n * @brief Given node ids, the rank of current GPU and the world size, returns\n * the ranks that the given ids belong in a deterministic manner.\n *\n * @param nodes      Node id tensor to be mapped to a rank in [0, world_size).\n * @param rank       Rank of the current GPU.\n * @param world_size World size, the total number of cooperating GPUs.\n *\n * @return The rank tensor of the GPU the given id tensor is mapped to.\n */\ntorch::Tensor RankAssignment(\n    torch::Tensor nodes, int64_t rank, int64_t world_size);\n\n/**\n * @brief Given node ids, the ranks they belong, the offsets to separate\n * different node types and world size, returns node ids sorted w.r.t. 
the ranks\n * that the given ids belong along with their new positions.\n *\n * @param nodes        Node id tensor to be mapped to a rank in [0, world_size).\n * @param part_ids     Rank tensor the nodes belong to.\n * @param offsets_dev  Offsets to separate different node types.\n * @param world_size   World size, the total number of cooperating GPUs.\n *\n * @return (sorted_nodes, new_positions, rank_offsets, rank_offsets_event),\n * where the first one includes sorted nodes, the second contains new positions\n * of the given nodes, so that sorted_nodes[new_positions] == nodes, and the\n * third contains the offsets of the sorted_nodes indicating\n * sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]] contains nodes that\n * belongs to the `i`th rank. Before accessing rank_offsets on the CPU,\n * `rank_offsets_event.synchronize()` is required.\n */\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, at::cuda::CUDAEvent>\nRankSortImpl(\n    torch::Tensor nodes, torch::Tensor part_ids, torch::Tensor offsets_dev,\n    int64_t world_size);\n\n/**\n * @brief Given a vector of node ids, the rank of current GPU and the world\n * size, returns node ids sorted w.r.t. 
the ranks that the given ids belong to\n * along with their new positions.\n *\n * @param nodes_list   Vector of node id tensors to be mapped to ranks in\n * [0, world_size).\n * @param rank         Rank of the current GPU.\n * @param world_size   World size, the total number of cooperating GPUs.\n *\n * @return vector of (sorted_nodes, new_positions, rank_offsets), where the\n * first one includes sorted nodes, the second contains new positions of the\n * given nodes, so that sorted_nodes[new_positions] == nodes, and the third\n * contains the offsets of the sorted_nodes indicating\n * sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]] contains nodes that\n * belong to the `i`th rank.\n */\nstd::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(\n    const std::vector<torch::Tensor>& nodes_list, int64_t rank,\n    int64_t world_size);\n\nc10::intrusive_ptr<Future<\n    std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>\nRankSortAsync(\n    const std::vector<torch::Tensor>& nodes_list, const int64_t rank,\n    const int64_t world_size);\n\n}  // namespace cuda\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_H_\n"
  },
  {
    "path": "graphbolt/src/cuda/cumsum.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/cumsum.cu\n * @brief Cumsum operators implementation on CUDA.\n */\n#include <cub/cub.cuh>\n\n#include \"./common.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\ntorch::Tensor ExclusiveCumSum(torch::Tensor input) {\n  auto result = torch::empty_like(input);\n\n  AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), \"ExclusiveCumSum\", ([&] {\n                               CUB_CALL(\n                                   DeviceScan::ExclusiveSum,\n                                   input.data_ptr<scalar_t>(),\n                                   result.data_ptr<scalar_t>(), input.size(0));\n                             }));\n  return result;\n}\n\n}  // namespace ops\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/expand_indptr.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/expand_indptr.cu\n * @brief ExpandIndptr operator implementation on CUDA.\n */\n#include <thrust/iterator/constant_iterator.h>\n#include <thrust/iterator/counting_iterator.h>\n#include <thrust/iterator/transform_iterator.h>\n\n#include <cub/cub.cuh>\n#include <limits>\n\n#include \"./common.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\ntemplate <typename indices_t, typename nodes_t>\nstruct RepeatIndex {\n  const nodes_t* nodes;\n  __host__ __device__ auto operator()(indices_t i) {\n    return thrust::make_constant_iterator(nodes ? nodes[i] : i);\n  }\n};\n\ntemplate <typename indices_t, typename nodes_t>\nstruct IotaIndex {\n  const nodes_t* nodes;\n  __host__ __device__ auto operator()(indices_t i) {\n    return thrust::make_counting_iterator(nodes ? 
nodes[i] : 0);\n  }\n};\n\ntemplate <typename indptr_t, typename indices_t>\nstruct OutputBufferIndexer {\n  const indptr_t* indptr;\n  indices_t* buffer;\n  __host__ __device__ auto operator()(int64_t i) { return buffer + indptr[i]; }\n};\n\ntemplate <typename indptr_t>\nstruct AdjacentDifference {\n  const indptr_t* indptr;\n  __host__ __device__ auto operator()(int64_t i) {\n    return indptr[i + 1] - indptr[i];\n  }\n};\n\ntorch::Tensor ExpandIndptrImpl(\n    torch::Tensor indptr, torch::ScalarType dtype,\n    torch::optional<torch::Tensor> nodes, torch::optional<int64_t> output_size,\n    const bool is_edge_ids_variant) {\n  if (!output_size.has_value()) {\n    output_size = AT_DISPATCH_INTEGRAL_TYPES(\n        indptr.scalar_type(), \"ExpandIndptrIndptr[-1]\", ([&]() -> int64_t {\n          auto indptr_ptr = indptr.data_ptr<scalar_t>();\n          auto output_size = cuda::CopyScalar{indptr_ptr + indptr.size(0) - 1};\n          return static_cast<scalar_t>(output_size);\n        }));\n  }\n  auto csc_rows =\n      torch::empty(output_size.value(), indptr.options().dtype(dtype));\n\n  AT_DISPATCH_INTEGRAL_TYPES(\n      indptr.scalar_type(), \"ExpandIndptrIndptr\", ([&] {\n        using indptr_t = scalar_t;\n        auto indptr_ptr = indptr.data_ptr<indptr_t>();\n        AT_DISPATCH_INTEGRAL_TYPES(\n            dtype, \"ExpandIndptrIndices\", ([&] {\n              using indices_t = scalar_t;\n              auto csc_rows_ptr = csc_rows.data_ptr<indices_t>();\n\n              auto nodes_dtype = nodes ? nodes.value().scalar_type() : dtype;\n              AT_DISPATCH_INTEGRAL_TYPES(\n                  nodes_dtype, \"ExpandIndptrNodes\", ([&] {\n                    using nodes_t = scalar_t;\n                    auto nodes_ptr =\n                        nodes ? 
nodes.value().data_ptr<nodes_t>() : nullptr;\n\n                    thrust::counting_iterator<int64_t> iota(0);\n                    auto output_buffer = thrust::make_transform_iterator(\n                        iota, OutputBufferIndexer<indptr_t, indices_t>{\n                                  indptr_ptr, csc_rows_ptr});\n                    auto buffer_sizes = thrust::make_transform_iterator(\n                        iota, AdjacentDifference<indptr_t>{indptr_ptr});\n\n                    const auto num_rows = indptr.size(0) - 1;\n                    constexpr int64_t max_copy_at_once =\n                        std::numeric_limits<int32_t>::max();\n\n                    if (is_edge_ids_variant) {\n                      auto input_buffer = thrust::make_transform_iterator(\n                          iota, IotaIndex<indices_t, nodes_t>{nodes_ptr});\n                      for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {\n                        CUB_CALL(\n                            DeviceCopy::Batched, input_buffer + i,\n                            output_buffer + i, buffer_sizes + i,\n                            std::min(num_rows - i, max_copy_at_once));\n                      }\n                    } else {\n                      auto input_buffer = thrust::make_transform_iterator(\n                          iota, RepeatIndex<indices_t, nodes_t>{nodes_ptr});\n                      for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {\n                        CUB_CALL(\n                            DeviceCopy::Batched, input_buffer + i,\n                            output_buffer + i, buffer_sizes + i,\n                            std::min(num_rows - i, max_copy_at_once));\n                      }\n                    }\n                  }));\n            }));\n      }));\n  return csc_rows;\n}\n\ntorch::Tensor ExpandIndptrImpl(\n    torch::Tensor indptr, torch::ScalarType dtype,\n    torch::optional<torch::Tensor> nodes,\n    torch::optional<int64_t> 
output_size) {\n  return ExpandIndptrImpl(indptr, dtype, nodes, output_size, false);\n}\n\ntorch::Tensor IndptrEdgeIdsImpl(\n    torch::Tensor indptr, torch::ScalarType dtype,\n    torch::optional<torch::Tensor> offset,\n    torch::optional<int64_t> output_size) {\n  return ExpandIndptrImpl(indptr, dtype, offset, output_size, true);\n}\n\n}  // namespace ops\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/extension/gpu_cache.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/gpu_cache.cu\n * @brief GPUCache implementation on CUDA.\n */\n#include <numeric>\n\n#include \"../common.h\"\n#include \"./gpu_cache.h\"\n\nnamespace graphbolt {\nnamespace cuda {\n\nGpuCache::GpuCache(const std::vector<int64_t> &shape, torch::ScalarType dtype) {\n  TORCH_CHECK(shape.size() >= 2, \"Shape must at least have 2 dimensions.\");\n  const auto num_items = shape[0];\n  TORCH_CHECK(\n      num_items > 0, \"The capacity of GpuCache needs to be a positive.\");\n  const int64_t num_feats =\n      std::accumulate(shape.begin() + 1, shape.end(), 1ll, std::multiplies<>());\n  const int element_size =\n      torch::empty(1, torch::TensorOptions().dtype(dtype)).element_size();\n  num_bytes_ = num_feats * element_size;\n  num_float_feats_ = (num_bytes_ + sizeof(float) - 1) / sizeof(float);\n  cache_ = std::make_unique<gpu_cache_t>(\n      (num_items + bucket_size - 1) / bucket_size, num_float_feats_);\n  shape_ = shape;\n  shape_[0] = -1;\n  dtype_ = dtype;\n  device_id_ = cuda::GetCurrentStream().device_index();\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(\n    torch::Tensor keys) {\n  TORCH_CHECK(keys.device().is_cuda(), \"Keys should be on a CUDA device.\");\n  TORCH_CHECK(\n      keys.device().index() 
== device_id_,\n      \"Keys should be on the correct CUDA device.\");\n  TORCH_CHECK(keys.sizes().size() == 1, \"Keys should be a 1D tensor.\");\n  keys = keys.to(torch::kLong);\n  auto values = torch::empty(\n      {keys.size(0), num_float_feats_}, keys.options().dtype(torch::kFloat));\n  auto missing_index =\n      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));\n  auto missing_keys =\n      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));\n  auto allocator = cuda::GetAllocator();\n  auto missing_len_device = allocator.AllocateStorage<size_t>(1);\n  cache_->Query(\n      reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),\n      values.data_ptr<float>(),\n      reinterpret_cast<uint64_t *>(missing_index.data_ptr()),\n      reinterpret_cast<key_t *>(missing_keys.data_ptr()),\n      missing_len_device.get(), cuda::GetCurrentStream());\n  values = values.view(torch::kByte)\n               .slice(1, 0, num_bytes_)\n               .view(dtype_)\n               .view(shape_);\n  cuda::CopyScalar<size_t> missing_len(missing_len_device.get());\n  missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));\n  missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));\n  return std::make_tuple(values, missing_index, missing_keys);\n}\n\nc10::intrusive_ptr<Future<std::vector<torch::Tensor>>> GpuCache::QueryAsync(\n    torch::Tensor keys) {\n  return async(\n      [=] {\n        auto [values, missing_index, missing_keys] = Query(keys);\n        return std::vector{values, missing_index, missing_keys};\n      },\n      true);\n}\n\nvoid GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {\n  TORCH_CHECK(keys.device().is_cuda(), \"Keys should be on a CUDA device.\");\n  TORCH_CHECK(\n      keys.device().index() == device_id_,\n      \"Keys should be on the correct CUDA device.\");\n  TORCH_CHECK(values.device().is_cuda(), \"Values should be on a CUDA device.\");\n  TORCH_CHECK(\n      
values.device().index() == device_id_,\n      \"Values should be on the correct CUDA device.\");\n  TORCH_CHECK(\n      keys.size(0) == values.size(0),\n      \"The first dimensions of keys and values must match.\");\n  TORCH_CHECK(\n      std::equal(shape_.begin() + 1, shape_.end(), values.sizes().begin() + 1),\n      \"Values should have the correct dimensions.\");\n  TORCH_CHECK(\n      values.scalar_type() == dtype_, \"Values should have the correct dtype.\");\n  if (keys.numel() == 0) return;\n  keys = keys.to(torch::kLong);\n  torch::Tensor float_values;\n  if (num_bytes_ % sizeof(float) != 0) {\n    float_values = torch::empty(\n        {values.size(0), num_float_feats_},\n        values.options().dtype(torch::kFloat));\n    float_values.view(torch::kByte)\n        .slice(1, 0, num_bytes_)\n        .copy_(values.view(torch::kByte).view({values.size(0), -1}));\n  } else {\n    float_values = values.view(torch::kByte)\n                       .view({values.size(0), -1})\n                       .view(torch::kFloat)\n                       .contiguous();\n  }\n  cache_->Replace(\n      reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),\n      float_values.data_ptr<float>(), cuda::GetCurrentStream());\n}\n\nc10::intrusive_ptr<GpuCache> GpuCache::Create(\n    const std::vector<int64_t> &shape, torch::ScalarType dtype) {\n  return c10::make_intrusive<GpuCache>(shape, dtype);\n}\n\n}  // namespace cuda\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/extension/gpu_cache.h",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/gpu_cache.h\n * @brief Header file of HugeCTR gpu_cache wrapper.\n */\n\n#ifndef GRAPHBOLT_GPU_CACHE_H_\n#define GRAPHBOLT_GPU_CACHE_H_\n\n#include <graphbolt/async.h>\n#include <torch/custom_class.h>\n#include <torch/torch.h>\n\n#include <limits>\n#include <nv_gpu_cache.hpp>\n\nnamespace graphbolt {\nnamespace cuda {\n\nclass GpuCache : public torch::CustomClassHolder {\n  using key_t = long long;\n  constexpr static int set_associativity = 2;\n  constexpr static int WARP_SIZE = 32;\n  constexpr static int bucket_size = WARP_SIZE * set_associativity;\n  using gpu_cache_t = ::gpu_cache::gpu_cache<\n      key_t, uint64_t, std::numeric_limits<key_t>::max(), set_associativity,\n      WARP_SIZE>;\n\n public:\n  /**\n   * @brief Constructor for the GpuCache struct.\n   *\n   * @param shape The shape of the GPU cache.\n   * @param dtype The datatype of items to be stored.\n   */\n  GpuCache(const std::vector<int64_t>& shape, torch::ScalarType dtype);\n\n  GpuCache() = default;\n\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Query(\n      torch::Tensor keys);\n\n  c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> QueryAsync(\n      torch::Tensor keys);\n\n  void Replace(torch::Tensor keys, torch::Tensor values);\n\n  static 
c10::intrusive_ptr<GpuCache> Create(\n      const std::vector<int64_t>& shape, torch::ScalarType dtype);\n\n private:\n  std::vector<int64_t> shape_;\n  torch::ScalarType dtype_;\n  std::unique_ptr<gpu_cache_t> cache_;\n  int64_t num_bytes_;\n  int64_t num_float_feats_;\n  torch::DeviceIndex device_id_;\n};\n\n// The cu file in HugeCTR gpu cache uses unsigned int and long long.\n// Changing to int64_t results in a mismatch of template arguments.\nstatic_assert(\n    sizeof(long long) == sizeof(int64_t),\n    \"long long and int64_t need to have the same size.\");  // NOLINT\n\n}  // namespace cuda\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_GPU_CACHE_H_\n"
  },
  {
    "path": "graphbolt/src/cuda/extension/gpu_graph_cache.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/gpu_graph_cache.cu\n * @brief GPU graph cache implementation on CUDA.\n */\n#include <graphbolt/cuda_ops.h>\n#include <thrust/gather.h>\n#include <thrust/transform.h>\n\n#include <cstddef>\n#include <cub/cub.cuh>\n#include <cuco/static_map.cuh>\n#include <cuda/std/atomic>\n#include <cuda/stream_ref>\n#include <limits>\n#include <numeric>\n#include <type_traits>\n\n#include \"../common.h\"\n#include \"../utils.h\"\n#include \"./gpu_graph_cache.h\"\n\nnamespace graphbolt {\nnamespace cuda {\n\nnamespace {\n\nconstexpr int cg_size = 1;\ntemplate <typename index_t>\nusing probing_t =\n    cuco::linear_probing<cg_size, cuco::default_hash_function<index_t>>;\ntemplate <typename index_t>\nusing allocator_t = cuda::CUDAWorkspaceAllocator<cuco::pair<index_t, index_t>>;\ntemplate <typename index_t>\nusing map_t = cuco::static_map<\n    index_t, index_t, cuco::extent<int64_t>, ::cuda::thread_scope_device,\n    thrust::equal_to<index_t>, probing_t<index_t>, allocator_t<index_t>>;\n\ntemplate <typename index_t, typename map_t>\n__global__ void _Insert(\n    const int64_t num_nodes, const index_t num_existing, const index_t* seeds,\n    const index_t* missing_indices, const index_t* indices, map_t map) {\n  int64_t i = blockIdx.x * blockDim.x + 
threadIdx.x;\n  const int stride = gridDim.x * blockDim.x;\n\n  while (i < num_nodes) {\n    const auto key = seeds[missing_indices[indices[i]]];\n\n    auto slot = map.find(key);\n    slot->second = num_existing + i;\n\n    i += stride;\n  }\n}\n\n/**\n * @brief For node ids not in the cache, it keeps their access count inside\n * a hash table as (v, -c) where v is the node id and c is the access count.\n * When c == -threshold, it means that v will be inserted into the cache\n * during the call to the replace method. Once v is inserted into the cache,\n * c is assigned to a nonnegative value and indicates the local id of vertex\n * v in the cache.\n *\n * @param num_nodes The number of node ids.\n * @param seeds The node ids the cache is being queried with.\n * @param positions Holds the values found in the hash table.\n * @param map The hash table holding (v, -c) or (v, local_id).\n *\n */\ntemplate <typename index_t, typename map_t>\n__global__ void _QueryAndIncrement(\n    const int64_t num_nodes, const index_t* seeds, index_t* positions,\n    map_t map) {\n  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride = gridDim.x * blockDim.x;\n\n  while (i < num_nodes) {\n    const auto key = seeds[i];\n\n    constexpr index_t minusONE = -1;\n    auto [slot, is_new_key] = map.insert_and_find(cuco::pair{key, minusONE});\n\n    int64_t position = -1;\n\n    if (!is_new_key) {\n      auto ref = ::cuda::atomic_ref<index_t, ::cuda::thread_scope_device>{\n          slot->second};\n      position = ref.load(::cuda::memory_order_relaxed);\n      if (position < 0) {\n        position = ref.fetch_add(-1, ::cuda::memory_order_relaxed) - 1;\n      }\n    }\n\n    positions[i] = position;\n\n    i += stride;\n  }\n}\n\nconstexpr int kIntBlockSize = 512;\n}  // namespace\n\nc10::intrusive_ptr<GpuGraphCache> GpuGraphCache::Create(\n    const int64_t num_edges, const int64_t threshold,\n    torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes,\n   
 bool has_original_edge_ids) {\n  return c10::make_intrusive<GpuGraphCache>(\n      num_edges, threshold, indptr_dtype, dtypes, has_original_edge_ids);\n}\n\nGpuGraphCache::GpuGraphCache(\n    const int64_t num_edges, const int64_t threshold,\n    torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes,\n    bool has_original_edge_ids) {\n  const int64_t initial_node_capacity = 1024;\n  AT_DISPATCH_INDEX_TYPES(\n      dtypes.at(0), \"GpuGraphCache::GpuGraphCache\", ([&] {\n        auto map_temp = map_t<index_t>{\n            initial_node_capacity,\n            kDoubleLoadFactor,\n            cuco::empty_key{static_cast<index_t>(-1)},\n            cuco::empty_value{std::numeric_limits<index_t>::lowest()},\n            {},\n            probing_t<index_t>{},\n            {},\n            {},\n            allocator_t<index_t>{},\n            ::cuda::stream_ref{cuda::GetCurrentStream()}};\n        map_ = new map_t<index_t>{std::move(map_temp)};\n      }));\n  C10_CUDA_KERNEL_LAUNCH_CHECK();  // Check the map constructor's success.\n  const auto options = torch::TensorOptions().device(c10::DeviceType::CUDA);\n  TORCH_CHECK(threshold > 0, \"Threshold should be a positive integer.\");\n  threshold_ = threshold;\n  device_id_ = cuda::GetCurrentStream().device_index();\n  map_size_ = 0;\n  num_nodes_ = 0;\n  num_edges_ = 0;\n  indptr_ =\n      torch::zeros(initial_node_capacity + 1, options.dtype(indptr_dtype));\n  if (!has_original_edge_ids) {\n    offset_ = torch::empty(indptr_.size(0) - 1, indptr_.options());\n  }\n  for (auto dtype : dtypes) {\n    cached_edge_tensors_.push_back(\n        torch::empty(num_edges, options.dtype(dtype)));\n  }\n}\n\nGpuGraphCache::~GpuGraphCache() {\n  AT_DISPATCH_INDEX_TYPES(\n      cached_edge_tensors_.at(0).scalar_type(), \"GpuGraphCache::GpuGraphCache\",\n      ([&] { delete reinterpret_cast<map_t<index_t>*>(map_); }));\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(\n    
torch::Tensor seeds) {\n  TORCH_CHECK(seeds.device().is_cuda(), \"Seeds should be on a CUDA device.\");\n  TORCH_CHECK(\n      seeds.device().index() == device_id_,\n      \"Seeds should be on the correct CUDA device.\");\n  TORCH_CHECK(seeds.sizes().size() == 1, \"Keys should be a 1D tensor.\");\n  std::lock_guard lock(mtx_);\n  auto allocator = cuda::GetAllocator();\n  auto index_dtype = cached_edge_tensors_.at(0).scalar_type();\n  const dim3 block(kIntBlockSize);\n  const dim3 grid((seeds.size(0) + kIntBlockSize - 1) / kIntBlockSize);\n  return AT_DISPATCH_INDEX_TYPES(\n      index_dtype, \"GpuGraphCache::Query\", ([&] {\n        auto map = reinterpret_cast<map_t<index_t>*>(map_);\n        while ((\n            map_size_ + seeds.size(0) >= map->capacity() * kDoubleLoadFactor)) {\n          map->rehash_async(\n              map->capacity() * kIntGrowthFactor,\n              ::cuda::stream_ref{cuda::GetCurrentStream()});\n        }\n        auto positions = torch::empty_like(seeds);\n        CUDA_KERNEL_CALL(\n            _QueryAndIncrement, grid, block, 0,\n            static_cast<int64_t>(seeds.size(0)), seeds.data_ptr<index_t>(),\n            positions.data_ptr<index_t>(), map->ref(cuco::insert_and_find));\n        auto num_threshold_new_hit =\n            allocator.AllocateStorage<thrust::tuple<int64_t, int64_t, int64_t>>(\n                1);\n        // Since threshold_ is a class member, we want the lambda functions\n        // below to only capture this particular variable by reassigning it to a\n        // local variable.\n        const auto threshold = -threshold_;\n        auto is_threshold_new_hit = thrust::make_transform_iterator(\n            positions.data_ptr<index_t>(), [=] __host__ __device__(index_t x) {\n              int64_t is_threshold = x == threshold;\n              int64_t is_new = x == -1;\n              int64_t is_hit = x >= 0;\n              return thrust::make_tuple(is_threshold, is_new, is_hit);\n            });\n        CUB_CALL(\n  
          DeviceReduce::Reduce, is_threshold_new_hit,\n            num_threshold_new_hit.get(), positions.size(0),\n            [] __host__ __device__(\n                const thrust::tuple<int64_t, int64_t, int64_t>& a,\n                const thrust::tuple<int64_t, int64_t, int64_t>& b) {\n              return thrust::make_tuple(\n                  thrust::get<0>(a) + thrust::get<0>(b),\n                  thrust::get<1>(a) + thrust::get<1>(b),\n                  thrust::get<2>(a) + thrust::get<2>(b));\n            },\n            thrust::tuple<int64_t, int64_t, int64_t>{});\n        CopyScalar num_threshold_new_hit_cpu{num_threshold_new_hit.get()};\n        thrust::counting_iterator<index_t> iota{0};\n        auto position_and_index =\n            thrust::make_zip_iterator(positions.data_ptr<index_t>(), iota);\n        auto output_positions = torch::empty_like(seeds);\n        auto output_indices = torch::empty_like(seeds);\n        auto output_position_and_index = thrust::make_zip_iterator(\n            output_positions.data_ptr<index_t>(),\n            output_indices.data_ptr<index_t>());\n        CUB_CALL(\n            DevicePartition::If, position_and_index, output_position_and_index,\n            cub::DiscardOutputIterator{}, seeds.size(0),\n            [] __device__(thrust::tuple<index_t, index_t> & x) {\n              return thrust::get<0>(x) >= 0;\n            });\n        const auto [num_threshold, num_new, num_hit] =\n            static_cast<thrust::tuple<int64_t, int64_t, int64_t>>(\n                num_threshold_new_hit_cpu);\n        map_size_ += num_new;\n\n        return std::make_tuple(\n            output_indices, output_positions, num_hit, num_threshold);\n      }));\n}\n\nc10::intrusive_ptr<\n    Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>\nGpuGraphCache::QueryAsync(torch::Tensor seeds) {\n  return async([=] { return Query(seeds); }, true);\n}\n\nstd::tuple<torch::Tensor, std::vector<torch::Tensor>> 
GpuGraphCache::Replace(\n    torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,\n    int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,\n    std::vector<torch::Tensor> edge_tensors) {\n  const auto with_edge_ids = offset_.has_value();\n  // The last element of edge_tensors has the edge ids.\n  const auto num_tensors = edge_tensors.size() - with_edge_ids;\n  TORCH_CHECK(\n      num_tensors == cached_edge_tensors_.size(),\n      \"Same number of tensors need to be passed!\");\n  const auto num_nodes = seeds.size(0);\n  TORCH_CHECK(\n      indptr.size(0) == num_nodes - num_hit + 1,\n      \"(indptr.size(0) == seeds.size(0) - num_hit + 1) failed.\");\n  std::lock_guard lock(mtx_);\n  const int64_t num_buffers = num_nodes * num_tensors;\n  auto allocator = cuda::GetAllocator();\n  auto index_dtype = cached_edge_tensors_.at(0).scalar_type();\n  return AT_DISPATCH_INDEX_TYPES(\n      index_dtype, \"GpuGraphCache::Replace\", ([&] {\n        using indices_t = index_t;\n        return AT_DISPATCH_INDEX_TYPES(\n            indptr_.scalar_type(), \"GpuGraphCache::Replace::copy_prep\", ([&] {\n              using indptr_t = index_t;\n              static_assert(\n                  sizeof(int64_t) == sizeof(void*),\n                  \"Pointers have to be 64-bit.\");\n              static_assert(\n                  sizeof(std::byte) == 1, \"Byte needs to have a size of 1.\");\n              auto cache_missing_dtype = torch::empty(\n                  // Below, we use this storage to store a tuple of 4 elements,\n                  // since each element is 64-bit, we need 4x int64 storage.\n                  4 * num_tensors, c10::TensorOptions()\n                                       .dtype(torch::kInt64)\n                                       .pinned_memory(true));\n              auto cache_missing_dtype_ptr =\n                  reinterpret_cast<::cuda::std::tuple<\n                      std::byte*, std::byte*, int64_t, int64_t>*>(\n                  
    cache_missing_dtype.data_ptr());\n              int64_t total_size = 0;\n              for (size_t i = 0; i < num_tensors; i++) {\n                TORCH_CHECK(\n                    cached_edge_tensors_[i].scalar_type() ==\n                        edge_tensors[i].scalar_type(),\n                    \"The dtypes of edge tensors must match.\");\n                if (i > 0) {\n                  TORCH_CHECK(\n                      edge_tensors[i - 1].size(0) == edge_tensors[i].size(0),\n                      \"The missing edge tensors should have identical size.\");\n                }\n                const int64_t element_size = edge_tensors[i].element_size();\n                cache_missing_dtype_ptr[i] = {\n                    reinterpret_cast<std::byte*>(\n                        cached_edge_tensors_[i].data_ptr()),\n                    reinterpret_cast<std::byte*>(edge_tensors[i].data_ptr()),\n                    element_size, total_size};\n                total_size += element_size;\n              }\n              auto cache_missing_dtype_dev = allocator.AllocateStorage<\n                  ::cuda::std::tuple<std::byte*, std::byte*, int64_t, int64_t>>(\n                  num_tensors);\n              THRUST_CALL(\n                  copy_n, cache_missing_dtype_ptr, num_tensors,\n                  cache_missing_dtype_dev.get());\n\n              auto input = allocator.AllocateStorage<std::byte*>(num_buffers);\n              auto input_size =\n                  allocator.AllocateStorage<size_t>(num_buffers + 1);\n              torch::optional<torch::Tensor> edge_id_offsets;\n              if (with_edge_ids) {\n                edge_id_offsets = torch::empty(\n                    num_nodes,\n                    seeds.options().dtype(offset_.value().scalar_type()));\n              }\n              const auto cache_missing_dtype_dev_ptr =\n                  cache_missing_dtype_dev.get();\n              const auto indices_ptr = indices.data_ptr<indices_t>();\n             
 const auto positions_ptr = positions.data_ptr<indices_t>();\n              const auto input_ptr = input.get();\n              const auto input_size_ptr = input_size.get();\n              const auto edge_id_offsets_ptr =\n                  edge_id_offsets ? edge_id_offsets->data_ptr<indptr_t>()\n                                  : nullptr;\n              const auto cache_indptr = indptr_.data_ptr<indptr_t>();\n              const auto missing_indptr = indptr.data_ptr<indptr_t>();\n              const auto cache_offset =\n                  offset_ ? offset_->data_ptr<indptr_t>() : nullptr;\n              const auto missing_edge_ids =\n                  edge_id_offsets ? edge_tensors.back().data_ptr<indptr_t>()\n                                  : nullptr;\n              CUB_CALL(DeviceFor::Bulk, num_buffers, [=] __device__(int64_t i) {\n                const auto tensor_idx = i / num_nodes;\n                const auto idx = i % num_nodes;\n                const auto pos = positions_ptr[idx];\n                const auto original_idx = indices_ptr[idx];\n                const auto [cache_ptr, missing_ptr, size, cum_size] =\n                    cache_missing_dtype_dev_ptr[tensor_idx];\n                const auto is_cached = pos >= 0;\n                const auto offset = is_cached ? cache_indptr[pos]\n                                              : missing_indptr[idx - num_hit];\n                const auto offset_end = is_cached\n                                            ? cache_indptr[pos + 1]\n                                            : missing_indptr[idx - num_hit + 1];\n                const auto out_idx = tensor_idx * num_nodes + original_idx;\n\n                input_ptr[out_idx] =\n                    (is_cached ? 
cache_ptr : missing_ptr) + offset * size;\n                input_size_ptr[out_idx] = size * (offset_end - offset);\n                if (edge_id_offsets_ptr && i < num_nodes) {\n                  const auto edge_id =\n                      is_cached ? cache_offset[pos] : missing_edge_ids[offset];\n                  edge_id_offsets_ptr[out_idx] = edge_id;\n                }\n              });\n              auto output_indptr = torch::empty(\n                  num_nodes + 1, seeds.options().dtype(indptr_.scalar_type()));\n              auto output_indptr_ptr = output_indptr.data_ptr<indptr_t>();\n              const auto element_size =\n                  ::cuda::std::get<2>(cache_missing_dtype_ptr[0]);\n              auto input_indegree = thrust::make_transform_iterator(\n                  input_size_ptr, [=] __host__ __device__(size_t x) {\n                    return x / element_size;\n                  });\n              CUB_CALL(\n                  DeviceScan::ExclusiveSum, input_indegree, output_indptr_ptr,\n                  num_nodes + 1);\n              CopyScalar output_size{output_indptr_ptr + num_nodes};\n\n              if (num_threshold > 0) {\n                // Insert the vertices whose access count equal threshold.\n                auto missing_positions = positions.slice(0, num_hit);\n                auto missing_indices = indices.slice(0, num_hit);\n\n                thrust::counting_iterator<indices_t> iota{0};\n                auto threshold = -threshold_;\n                auto is_threshold = thrust::make_transform_iterator(\n                    missing_positions.data_ptr<indices_t>(),\n                    [=] __host__ __device__(indices_t x) {\n                      return x == threshold;\n                    });\n                auto output_indices =\n                    torch::empty(num_threshold, seeds.options());\n                CUB_CALL(\n                    DeviceSelect::Flagged, iota, is_threshold,\n                    
output_indices.data_ptr<indices_t>(),\n                    cub::DiscardOutputIterator{}, missing_positions.size(0));\n                auto [in_degree, sliced_indptr] =\n                    ops::SliceCSCIndptr(indptr, output_indices);\n                while (num_nodes_ + num_threshold >= indptr_.size(0)) {\n                  auto new_indptr = torch::empty(\n                      indptr_.size(0) * kIntGrowthFactor, indptr_.options());\n                  new_indptr.slice(0, 0, indptr_.size(0)) = indptr_;\n                  indptr_ = new_indptr;\n                  if (offset_) {\n                    auto new_offset =\n                        torch::empty(indptr_.size(0) - 1, offset_->options());\n                    new_offset.slice(0, 0, offset_->size(0)) = *offset_;\n                    offset_ = new_offset;\n                  }\n                }\n                torch::Tensor sindptr;\n                bool enough_space;\n                torch::optional<int64_t> cached_output_size;\n                for (size_t i = 0; i < num_tensors; i++) {\n                  torch::Tensor sindices;\n                  std::tie(sindptr, sindices) = ops::IndexSelectCSCImpl(\n                      in_degree, sliced_indptr, edge_tensors[i], output_indices,\n                      indptr.size(0) - 2, cached_output_size);\n                  cached_output_size = sindices.size(0);\n                  enough_space = num_edges_ + *cached_output_size <=\n                                 cached_edge_tensors_[i].size(0);\n                  if (enough_space) {\n                    cached_edge_tensors_[i].slice(\n                        0, num_edges_, num_edges_ + *cached_output_size) =\n                        sindices;\n                  } else\n                    break;\n                }\n                if (enough_space) {\n                  auto num_edges = num_edges_;\n                  if (offset_) {\n                    auto transform_input_it = thrust::make_zip_iterator(\n                 
       sindptr.data_ptr<indptr_t>() + 1,\n                        sliced_indptr.data_ptr<indptr_t>());\n                    auto transform_output_it = thrust::make_zip_iterator(\n                        indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,\n                        offset_->data_ptr<indptr_t>() + num_nodes_);\n                    THRUST_CALL(\n                        transform, transform_input_it,\n                        transform_input_it + sindptr.size(0) - 1,\n                        transform_output_it,\n                        [=] __host__ __device__(\n                            const thrust::tuple<indptr_t, indptr_t>& x) {\n                          return thrust::make_tuple(\n                              thrust::get<0>(x) + num_edges,\n                              missing_edge_ids[thrust::get<1>(x)]);\n                        });\n                  } else {\n                    THRUST_CALL(\n                        transform, sindptr.data_ptr<indptr_t>() + 1,\n                        sindptr.data_ptr<indptr_t>() + sindptr.size(0),\n                        indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,\n                        [=] __host__ __device__(const indptr_t& x) {\n                          return x + num_edges;\n                        });\n                  }\n                  auto map = reinterpret_cast<map_t<indices_t>*>(map_);\n                  const dim3 block(kIntBlockSize);\n                  const dim3 grid(\n                      (num_threshold + kIntBlockSize - 1) / kIntBlockSize);\n                  CUDA_KERNEL_CALL(\n                      _Insert, grid, block, 0, output_indices.size(0),\n                      static_cast<indices_t>(num_nodes_),\n                      seeds.data_ptr<indices_t>(),\n                      missing_indices.data_ptr<indices_t>(),\n                      output_indices.data_ptr<indices_t>(),\n                      map->ref(cuco::find));\n                  num_edges_ += *cached_output_size;\n          
        num_nodes_ += num_threshold;\n                }\n              }\n\n              constexpr int alignment = 128;\n              const auto output_allocation_count =\n                  (static_cast<indptr_t>(output_size) + alignment - 1) /\n                  alignment * alignment;\n              auto output_allocation = torch::empty(\n                  output_allocation_count * total_size,\n                  seeds.options().dtype(torch::kInt8));\n              const auto output_allocation_ptr =\n                  output_allocation.data_ptr<int8_t>();\n\n              std::vector<torch::Tensor> output_edge_tensors;\n              for (size_t i = 0; i < num_tensors; i++) {\n                const auto cum_size =\n                    ::cuda::std::get<3>(cache_missing_dtype_ptr[i]);\n                output_edge_tensors.push_back(\n                    output_allocation\n                        .slice(0, cum_size * output_allocation_count)\n                        .view(edge_tensors[i].scalar_type())\n                        .slice(0, 0, static_cast<indptr_t>(output_size)));\n              }\n              if (edge_id_offsets) {\n                // Append the edge ids as the last element of the output.\n                output_edge_tensors.push_back(ops::IndptrEdgeIdsImpl(\n                    output_indptr, output_indptr.scalar_type(),\n                    *edge_id_offsets,\n                    static_cast<int64_t>(static_cast<indptr_t>(output_size))));\n              }\n\n              {\n                thrust::counting_iterator<int64_t> iota{0};\n                auto output_buffer_it = thrust::make_transform_iterator(\n                    iota, [=] __host__ __device__(int64_t i) {\n                      const auto tensor_idx = i / num_nodes;\n                      const auto idx = i % num_nodes;\n                      const auto offset = output_indptr_ptr[idx];\n                      const auto [_0, _1, size, cum_size] =\n                          
cache_missing_dtype_dev_ptr[tensor_idx];\n                      return output_allocation_ptr +\n                             cum_size * output_allocation_count + offset * size;\n                    });\n                constexpr int64_t max_copy_at_once =\n                    std::numeric_limits<int32_t>::max();\n                for (int64_t i = 0; i < num_buffers; i += max_copy_at_once) {\n                  CUB_CALL(\n                      DeviceMemcpy::Batched, input.get() + i,\n                      output_buffer_it + i, input_size_ptr + i,\n                      std::min(num_buffers - i, max_copy_at_once));\n                }\n              }\n\n              return std::make_tuple(output_indptr, output_edge_tensors);\n            }));\n      }));\n}\n\nc10::intrusive_ptr<\n    Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>\nGpuGraphCache::ReplaceAsync(\n    torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,\n    int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,\n    std::vector<torch::Tensor> edge_tensors) {\n  return async(\n      [=] {\n        return Replace(\n            seeds, indices, positions, num_hit, num_threshold, indptr,\n            edge_tensors);\n      },\n      true);\n}\n\n}  // namespace cuda\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/extension/gpu_graph_cache.h",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/gpu_graph_cache.h\n * @brief Header file of GPU graph cache.\n */\n\n#ifndef GRAPHBOLT_GPU_GRAPH_CACHE_H_\n#define GRAPHBOLT_GPU_GRAPH_CACHE_H_\n\n#include <graphbolt/async.h>\n#include <torch/custom_class.h>\n#include <torch/torch.h>\n\n#include <mutex>\n\nnamespace graphbolt {\nnamespace cuda {\n\nclass GpuGraphCache : public torch::CustomClassHolder {\n  // The load factor of the constructed hash table.\n  static constexpr double kDoubleLoadFactor = 0.8;\n  // The growth factor of the hash table and the dynamically sized indptr\n  // tensor.\n  static constexpr int kIntGrowthFactor = 2;\n\n public:\n  /**\n   * @brief Constructor for the GpuGraphCache struct.\n   *\n   * @param num_edges The edge capacity of GPU cache.\n   * @param threshold The access threshold before a vertex neighborhood is\n   * cached.\n   * @param indptr_dtype The node id datatype.\n   * @param dtypes The dtypes of the edge tensors to be cached. 
dtypes[0] is\n   * reserved for the indices edge tensor holding node ids.\n   * @param has_original_edge_ids Whether the graph to be cached has original\n   * edge ids.\n   */\n  GpuGraphCache(\n      const int64_t num_edges, const int64_t threshold,\n      torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes,\n      bool has_original_edge_ids);\n\n  GpuGraphCache() = default;\n\n  ~GpuGraphCache();\n\n  /**\n   * @brief Queries the cache. Returns tensors indicating which elements are\n   * missing.\n   *\n   * @param seeds The node ids to query the cache with.\n   *\n   * @return\n   * (torch::Tensor, torch::Tensor, int64_t, int64_t) index, position,\n   * number of cache hits and number of ids that will enter the cache.\n   */\n  std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> Query(\n      torch::Tensor seeds);\n\n  c10::intrusive_ptr<\n      Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>\n  QueryAsync(torch::Tensor seeds);\n\n  /**\n   * @brief After the graph structure for the missing node ids are fetched, it\n   * inserts the node ids which passes the threshold and returns the final\n   * output graph structure, combining the information in the cache with the\n   * graph structure for the missing node ids.\n   *\n   * @param seeds The node ids that the cache was queried with.\n   * @param indices seeds[indices[:num_hit]] gives us the node ids that were\n   * found in the cache\n   * @param positions positions[:num_hit] gives where the node ids can be found\n   * in the cache.\n   * @param num_hit The number of seeds that are already in the cache.\n   * @param num_threshold The number of seeds among the missing node ids that\n   * will be inserted into the cache.\n   * @param indptr The indptr for the missing seeds fetched from remote.\n   * @param edge_tensors The edge tensors for the missing seeds. 
The last\n   * element of edge_tensors is treated as the edge ids tensor with\n   * indptr_dtype.\n   *\n   * @return (torch::Tensor, std::vector<torch::Tensor>) The final indptr and\n   * edge_tensors, directly corresponding to the seeds tensor.\n   */\n  std::tuple<torch::Tensor, std::vector<torch::Tensor>> Replace(\n      torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,\n      int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,\n      std::vector<torch::Tensor> edge_tensors);\n\n  c10::intrusive_ptr<\n      Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>\n  ReplaceAsync(\n      torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,\n      int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,\n      std::vector<torch::Tensor> edge_tensors);\n\n  static c10::intrusive_ptr<GpuGraphCache> Create(\n      const int64_t num_edges, const int64_t threshold,\n      torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes,\n      bool has_original_edge_ids);\n\n private:\n  void* map_;                     // pointer to the hash table.\n  int64_t threshold_;             // A positive threshold value.\n  torch::DeviceIndex device_id_;  // Which GPU the cache resides in.\n  int64_t map_size_;              // The number of nodes inside the hash table.\n  int64_t num_nodes_;             // The number of cached nodes in the cache.\n  int64_t num_edges_;             // The number of cached edges in the cache.\n  torch::Tensor indptr_;          // The cached graph structure indptr tensor.\n  torch::optional<torch::Tensor>\n      offset_;  // The original graph's sliced_indptr tensor.\n  std::vector<torch::Tensor> cached_edge_tensors_;  // The cached graph\n                                                    // structure edge tensors.\n  std::mutex mtx_;  // Protects the data structure and makes it threadsafe.\n};\n\n}  // namespace cuda\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_GPU_GRAPH_CACHE_H_\n"
  },
  {
    "path": "graphbolt/src/cuda/extension/unique_and_compact.h",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/unique_and_compact.h\n * @brief Unique and compact operator utilities on CUDA using hash table.\n */\n\n#ifndef GRAPHBOLT_CUDA_UNIQUE_AND_COMPACT_H_\n#define GRAPHBOLT_CUDA_UNIQUE_AND_COMPACT_H_\n\n#include <torch/script.h>\n\n#include <vector>\n\nnamespace graphbolt {\nnamespace ops {\n\nstd::vector<\n    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>\nUniqueAndCompactBatchedHashMapBased(\n    const std::vector<torch::Tensor>& src_ids,\n    const std::vector<torch::Tensor>& dst_ids,\n    const std::vector<torch::Tensor>& unique_dst_ids, const int64_t rank,\n    const int64_t world_size);\n\n}  // namespace ops\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_CUDA_UNIQUE_AND_COMPACT_H_\n"
  },
  {
    "path": "graphbolt/src/cuda/extension/unique_and_compact_map.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/unique_and_compact_map.cu\n * @brief Unique and compact operator implementation on CUDA using hash table.\n */\n#include <graphbolt/cuda_ops.h>\n#include <thrust/iterator/reverse_iterator.h>\n#include <thrust/iterator/tabulate_output_iterator.h>\n#include <thrust/iterator/transform_iterator.h>\n#include <thrust/iterator/transform_output_iterator.h>\n\n#include <cub/cub.cuh>\n#include <cuco/static_map.cuh>\n#include <cuda/functional>\n#include <cuda/std/atomic>\n#include <cuda/std/utility>\n#include <cuda/stream_ref>\n#include <limits>\n#include <numeric>\n\n#include \"../common.h\"\n#include \"../cooperative_minibatching_utils.cuh\"\n#include \"../cooperative_minibatching_utils.h\"\n#include \"../utils.h\"\n#include \"./unique_and_compact.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\n// Support graphs with up to 2^kNodeIdBits nodes.\nconstexpr int kNodeIdBits = 40;\n\ntemplate <typename index_t, typename map_t>\n__global__ void _InsertAndSetMinBatched(\n    const int64_t num_edges, const int32_t* const indexes, index_t** pointers,\n    const int64_t* const offsets, map_t map) {\n  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride = gridDim.x * blockDim.x;\n\n  while (i < num_edges) {\n    const auto tensor_index = 
indexes[i];\n    const auto tensor_offset = i - offsets[tensor_index];\n    const int64_t node_id = pointers[tensor_index][tensor_offset];\n    const int64_t batch_index = tensor_index / 2;\n    const int64_t key = node_id | (batch_index << kNodeIdBits);\n\n    auto [slot, is_new_key] = map.insert_and_find(cuco::pair{key, i});\n\n    if (!is_new_key) {\n      auto ref = ::cuda::atomic_ref<int64_t, ::cuda::thread_scope_device>{\n          slot->second};\n      ref.fetch_min(i, ::cuda::memory_order_relaxed);\n    }\n\n    i += stride;\n  }\n}\n\ntemplate <typename index_t, typename map_t>\n__global__ void _MapIdsBatched(\n    const int num_batches, const int64_t num_edges,\n    const int32_t* const indexes, index_t** pointers,\n    const int64_t* const offsets, const int64_t* const unique_ids_offsets,\n    const index_t* const index, map_t map, index_t* mapped_ids) {\n  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride = gridDim.x * blockDim.x;\n\n  while (i < num_edges) {\n    const auto tensor_index = indexes[i];\n    int64_t batch_index;\n\n    if (tensor_index >= 2 * num_batches) {\n      batch_index = tensor_index - 2 * num_batches;\n    } else if (tensor_index & 1) {\n      batch_index = tensor_index / 2;\n    } else {\n      batch_index = -1;\n    }\n\n    // Only map src or dst ids.\n    if (batch_index >= 0) {\n      const auto tensor_offset = i - offsets[tensor_index];\n      const int64_t node_id = pointers[tensor_index][tensor_offset];\n      const int64_t key = node_id | (batch_index << kNodeIdBits);\n\n      auto slot = map.find(key);\n      auto new_id = slot->second;\n      if (index) {\n        new_id = index[new_id];\n      } else {\n        new_id -= unique_ids_offsets[batch_index];\n      }\n      mapped_ids[i] = new_id;\n    }\n\n    i += stride;\n  }\n}\n\nstd::vector<\n    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>\nUniqueAndCompactBatchedHashMapBased(\n    const std::vector<torch::Tensor>& 
src_ids,\n    const std::vector<torch::Tensor>& dst_ids,\n    const std::vector<torch::Tensor>& unique_dst_ids, const int64_t rank,\n    const int64_t world_size) {\n  TORCH_CHECK(\n      rank < world_size, \"rank needs to be smaller than the world_size.\");\n  TORCH_CHECK(world_size <= std::numeric_limits<uint32_t>::max());\n  auto allocator = cuda::GetAllocator();\n  auto stream = cuda::GetCurrentStream();\n  auto scalar_type = src_ids.at(0).scalar_type();\n  constexpr int BLOCK_SIZE = 512;\n  const auto num_batches = src_ids.size();\n  static_assert(\n      sizeof(std::ptrdiff_t) == sizeof(int64_t),\n      \"Need to be compiled on a 64-bit system.\");\n  constexpr int batch_id_bits = sizeof(int64_t) * 8 - 1 - kNodeIdBits;\n  TORCH_CHECK(\n      num_batches <= (1 << batch_id_bits),\n      \"UniqueAndCompactBatched supports a batch size of up to \",\n      1 << batch_id_bits);\n  return AT_DISPATCH_INDEX_TYPES(\n      scalar_type, \"unique_and_compact\", ([&] {\n        // For 2 batches of inputs, stores the input tensor pointers in the\n        // unique_dst, src, unique_dst, src, dst, dst order. Since there are\n        // 3 * num_batches input tensors, we need the first 3 * num_batches to\n        // store the input tensor pointers. 
Then, we store offsets in the rest\n        // of the 3 * num_batches + 1 space as if they were stored contiguously.\n        auto pointers_and_offsets = torch::empty(\n            6 * num_batches + 1,\n            c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true));\n        // Points to the input tensor pointers.\n        auto pointers_ptr =\n            reinterpret_cast<index_t**>(pointers_and_offsets.data_ptr());\n        // Points to the input tensor storage logical offsets.\n        auto offsets_ptr =\n            pointers_and_offsets.data_ptr<int64_t>() + 3 * num_batches;\n        for (std::size_t i = 0; i < num_batches; i++) {\n          pointers_ptr[2 * i] = unique_dst_ids.at(i).data_ptr<index_t>();\n          offsets_ptr[2 * i] = unique_dst_ids[i].size(0);\n          pointers_ptr[2 * i + 1] = src_ids.at(i).data_ptr<index_t>();\n          offsets_ptr[2 * i + 1] = src_ids[i].size(0);\n          pointers_ptr[2 * num_batches + i] = dst_ids.at(i).data_ptr<index_t>();\n          offsets_ptr[2 * num_batches + i] = dst_ids[i].size(0);\n        }\n        // Finish computing the offsets by taking a cumulative sum.\n        std::exclusive_scan(\n            offsets_ptr, offsets_ptr + 3 * num_batches + 1, offsets_ptr, 0ll);\n        // Device version of the tensors defined above. 
We store the information\n        // initially on the CPU, which are later copied to the device.\n        auto pointers_and_offsets_dev = torch::empty(\n            pointers_and_offsets.size(0),\n            src_ids[0].options().dtype(pointers_and_offsets.scalar_type()));\n        auto offsets_dev = pointers_and_offsets_dev.slice(0, 3 * num_batches);\n        auto pointers_dev_ptr =\n            reinterpret_cast<index_t**>(pointers_and_offsets_dev.data_ptr());\n        auto offsets_dev_ptr = offsets_dev.data_ptr<int64_t>();\n        CUDA_CALL(cudaMemcpyAsync(\n            pointers_dev_ptr, pointers_ptr,\n            sizeof(int64_t) * pointers_and_offsets.size(0),\n            cudaMemcpyHostToDevice, stream));\n        auto indexes = ExpandIndptrImpl(\n            offsets_dev, torch::kInt32, torch::nullopt,\n            offsets_ptr[3 * num_batches]);\n        cuco::static_map map{\n            offsets_ptr[2 * num_batches],\n            0.5,  // load_factor\n            cuco::empty_key{static_cast<int64_t>(-1)},\n            cuco::empty_value{static_cast<int64_t>(-1)},\n            {},\n            cuco::linear_probing<1, cuco::default_hash_function<int64_t>>{},\n            {},\n            {},\n            cuda::CUDAWorkspaceAllocator<cuco::pair<int64_t, int64_t>>{},\n            ::cuda::stream_ref{stream},\n        };\n        C10_CUDA_KERNEL_LAUNCH_CHECK();  // Check the map constructor's success.\n        const dim3 block(BLOCK_SIZE);\n        const dim3 grid(\n            (offsets_ptr[2 * num_batches] + BLOCK_SIZE - 1) / BLOCK_SIZE);\n        CUDA_KERNEL_CALL(\n            _InsertAndSetMinBatched, grid, block, 0,\n            offsets_ptr[2 * num_batches], indexes.data_ptr<int32_t>(),\n            pointers_dev_ptr, offsets_dev_ptr, map.ref(cuco::insert_and_find));\n        cub::ArgIndexInputIterator index_it(indexes.data_ptr<int32_t>());\n        auto input_it = thrust::make_transform_iterator(\n            index_it,\n            ::cuda::proclaim_return_type<\n  
              ::cuda::std::tuple<int64_t*, index_t, int32_t, bool>>(\n                [=, map = map.ref(cuco::find)] __device__(auto it)\n                    -> ::cuda::std::tuple<int64_t*, index_t, int32_t, bool> {\n                  const auto i = it.key;\n                  const auto tensor_index = it.value;\n                  const auto tensor_offset = i - offsets_dev_ptr[tensor_index];\n                  const int64_t node_id =\n                      pointers_dev_ptr[tensor_index][tensor_offset];\n                  const auto batch_index = tensor_index / 2;\n                  const int64_t key =\n                      node_id |\n                      (static_cast<int64_t>(batch_index) << kNodeIdBits);\n                  const auto batch_offset = offsets_dev_ptr[batch_index * 2];\n\n                  auto slot = map.find(key);\n                  const auto valid = slot->second == i;\n\n                  return {&slot->second, node_id, batch_index, valid};\n                }));\n        torch::optional<torch::Tensor> part_ids;\n        if (world_size > 1) {\n          part_ids = torch::empty(\n              offsets_ptr[2 * num_batches],\n              src_ids[0].options().dtype(cuda::kPartDType));\n        }\n        auto unique_ids =\n            torch::empty(offsets_ptr[2 * num_batches], src_ids[0].options());\n        auto unique_ids_offsets_dev = torch::full(\n            num_batches + 1, std::numeric_limits<int64_t>::max(),\n            src_ids[0].options().dtype(torch::kInt64));\n        auto unique_ids_offsets_dev_ptr =\n            unique_ids_offsets_dev.data_ptr<int64_t>();\n        auto output_it = thrust::make_tabulate_output_iterator(\n            ::cuda::proclaim_return_type<void>(\n                [=, unique_ids_ptr = unique_ids.data_ptr<index_t>(),\n                 part_ids_ptr =\n                     part_ids ? 
part_ids->data_ptr<cuda::part_t>() : nullptr,\n                 rank = static_cast<uint32_t>(rank),\n                 world_size = static_cast<uint32_t>(\n                     world_size)] __device__(const int64_t i, const auto& t) {\n                  *::cuda::std::get<0>(t) = i;\n                  const auto node_id = ::cuda::std::get<1>(t);\n                  unique_ids_ptr[i] = node_id;\n                  if (part_ids_ptr) {\n                    part_ids_ptr[i] =\n                        cuda::rank_assignment(node_id, rank, world_size);\n                  }\n                  const auto batch_index = ::cuda::std::get<2>(t);\n                  auto ref =\n                      ::cuda::atomic_ref<int64_t, ::cuda::thread_scope_device>{\n                          unique_ids_offsets_dev_ptr[batch_index]};\n                  ref.fetch_min(i, ::cuda::memory_order_relaxed);\n                }));\n        CUB_CALL(\n            DeviceSelect::If, input_it, output_it,\n            unique_ids_offsets_dev_ptr + num_batches,\n            offsets_ptr[2 * num_batches],\n            ::cuda::proclaim_return_type<bool>([] __device__(const auto& t) {\n              return ::cuda::std::get<3>(t);\n            }));\n        auto unique_ids_offsets = torch::empty(\n            num_batches + 1,\n            c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true));\n        {\n          auto unique_ids_offsets_dev2 =\n              torch::empty_like(unique_ids_offsets_dev);\n          CUB_CALL(\n              DeviceScan::InclusiveScan,\n              thrust::make_reverse_iterator(\n                  num_batches + 1 + unique_ids_offsets_dev_ptr),\n              thrust::make_reverse_iterator(\n                  num_batches + 1 +\n                  thrust::make_transform_output_iterator(\n                      thrust::make_zip_iterator(\n                          unique_ids_offsets_dev2.data_ptr<int64_t>(),\n                          unique_ids_offsets.data_ptr<int64_t>()),\n        
              ::cuda::proclaim_return_type<\n                          thrust::tuple<int64_t, int64_t>>(\n                          [=] __device__(const auto x) {\n                            return thrust::make_tuple(x, x);\n                          }))),\n              cub::Min{}, num_batches + 1);\n          unique_ids_offsets_dev = unique_ids_offsets_dev2;\n          unique_ids_offsets_dev_ptr =\n              unique_ids_offsets_dev.data_ptr<int64_t>();\n        }\n        at::cuda::CUDAEvent unique_ids_offsets_event;\n        unique_ids_offsets_event.record();\n        torch::optional<torch::Tensor> index;\n        if (part_ids) {\n          unique_ids_offsets_event.synchronize();\n          const auto num_unique =\n              unique_ids_offsets.data_ptr<int64_t>()[num_batches];\n          unique_ids = unique_ids.slice(0, 0, num_unique);\n          part_ids = part_ids->slice(0, 0, num_unique);\n          std::tie(\n              unique_ids, index, unique_ids_offsets, unique_ids_offsets_event) =\n              cuda::RankSortImpl(\n                  unique_ids, *part_ids, unique_ids_offsets_dev, world_size);\n        }\n        auto mapped_ids =\n            torch::empty(offsets_ptr[3 * num_batches], unique_ids.options());\n        CUDA_KERNEL_CALL(\n            _MapIdsBatched, grid, block, 0, num_batches,\n            offsets_ptr[3 * num_batches], indexes.data_ptr<int32_t>(),\n            pointers_dev_ptr, offsets_dev_ptr, unique_ids_offsets_dev_ptr,\n            index ? 
index->data_ptr<index_t>() : nullptr, map.ref(cuco::find),\n            mapped_ids.data_ptr<index_t>());\n        std::vector<std::tuple<\n            torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>\n            results;\n        unique_ids_offsets_event.synchronize();\n        auto unique_ids_offsets_ptr = unique_ids_offsets.data_ptr<int64_t>();\n        for (int64_t i = 0; i < num_batches; i++) {\n          results.emplace_back(\n              unique_ids.slice(\n                  0, unique_ids_offsets_ptr[i * world_size],\n                  unique_ids_offsets_ptr[(i + 1) * world_size]),\n              mapped_ids.slice(\n                  0, offsets_ptr[2 * i + 1], offsets_ptr[2 * i + 2]),\n              mapped_ids.slice(\n                  0, offsets_ptr[2 * num_batches + i],\n                  offsets_ptr[2 * num_batches + i + 1]),\n              unique_ids_offsets.slice(\n                  0, i * world_size, (i + 1) * world_size + 1));\n        }\n        return results;\n      }));\n}\n\n}  // namespace ops\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/gather.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/gather.cu\n * @brief Gather operators implementation on CUDA.\n */\n#include <thrust/gather.h>\n\n#include \"./common.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\ntorch::Tensor Gather(\n    torch::Tensor input, torch::Tensor index,\n    torch::optional<torch::ScalarType> dtype) {\n  if (!dtype.has_value()) dtype = input.scalar_type();\n  auto output = torch::empty(index.sizes(), index.options().dtype(*dtype));\n  AT_DISPATCH_INDEX_TYPES(\n      index.scalar_type(), \"GatherIndexType\", ([&] {\n        AT_DISPATCH_INTEGRAL_TYPES(\n            input.scalar_type(), \"GatherInputType\", ([&] {\n              using input_t = scalar_t;\n              AT_DISPATCH_INTEGRAL_TYPES(*dtype, \"GatherOutputType\", ([&] {\n                using output_t = scalar_t;\n                THRUST_CALL(\n                    gather, index.data_ptr<index_t>(),\n                    index.data_ptr<index_t>() + index.size(0),\n                    input.data_ptr<input_t>(), output.data_ptr<output_t>());\n              }));\n            }));\n      }));\n  return output;\n}\n\n}  // namespace ops\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/index_select_csc_impl.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/index_select_csc_impl.cu\n * @brief Index select csc operator implementation on CUDA.\n */\n#include <c10/core/ScalarType.h>\n#include <graphbolt/cuda_ops.h>\n#include <thrust/iterator/counting_iterator.h>\n#include <thrust/iterator/transform_iterator.h>\n#include <thrust/iterator/zip_iterator.h>\n\n#include <cstdint>\n#include <cub/cub.cuh>\n#include <numeric>\n\n#include \"./common.h\"\n#include \"./max_uva_threads.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\nconstexpr int BLOCK_SIZE = CUDA_MAX_NUM_THREADS;\n\n// Given the in_degree array and a permutation, returns in_degree of the output\n// and the permuted and modified in_degree of the input. 
The modified in_degree\n// is modified so that there is slack to be able to align as needed.\ntemplate <typename indptr_t, typename indices_t>\nstruct AlignmentFunc {\n  static_assert(GPU_CACHE_LINE_SIZE % sizeof(indices_t) == 0);\n  const indptr_t* in_degree;\n  const int64_t* perm;\n  int64_t num_nodes;\n  __host__ __device__ auto operator()(int64_t row) {\n    constexpr int num_elements = GPU_CACHE_LINE_SIZE / sizeof(indices_t);\n    return thrust::make_tuple(\n        in_degree[row],\n        // A single cache line has num_elements items, we add num_elements - 1\n        // to ensure there is enough slack to move forward or backward by\n        // num_elements - 1 items if the performed access is not aligned.\n        static_cast<indptr_t>(\n            in_degree[perm ? perm[row % num_nodes] : row] + num_elements - 1));\n  }\n};\n\ntemplate <typename indptr_t, typename indices_t, typename coo_rows_t>\n__global__ void _CopyIndicesAlignedKernel(\n    const indptr_t edge_count, const indptr_t* const indptr,\n    const indptr_t* const output_indptr,\n    const indptr_t* const output_indptr_aligned, const indices_t* const indices,\n    const coo_rows_t* const coo_aligned_rows, indices_t* const output_indices,\n    const int64_t* const perm) {\n  indptr_t idx = static_cast<indptr_t>(blockIdx.x) * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n\n  while (idx < edge_count) {\n    const auto permuted_row_pos = coo_aligned_rows[idx];\n    const auto row_pos = perm ? 
perm[permuted_row_pos] : permuted_row_pos;\n    const auto out_row = output_indptr[row_pos];\n    const auto d = output_indptr[row_pos + 1] - out_row;\n    const int offset = (reinterpret_cast<std::uintptr_t>(\n                            indices + indptr[row_pos] -\n                            output_indptr_aligned[permuted_row_pos]) %\n                        GPU_CACHE_LINE_SIZE) /\n                       sizeof(indices_t);\n    const auto rofs = idx - output_indptr_aligned[permuted_row_pos] - offset;\n    if (rofs >= 0 && rofs < d) {\n      const auto in_idx = indptr[row_pos] + rofs;\n      assert(\n          reinterpret_cast<std::uintptr_t>(indices + in_idx - idx) %\n              GPU_CACHE_LINE_SIZE ==\n          0);\n      const auto u = indices[in_idx];\n      output_indices[out_row + rofs] = u;\n    }\n    idx += stride_x;\n  }\n}\n\nstruct PairSum {\n  template <typename indptr_t>\n  __host__ __device__ auto operator()(\n      const thrust::tuple<indptr_t, indptr_t> a,\n      const thrust::tuple<indptr_t, indptr_t> b) {\n    return thrust::make_tuple(\n        thrust::get<0>(a) + thrust::get<0>(b),\n        thrust::get<1>(a) + thrust::get<1>(b));\n  };\n};\n\ntemplate <typename indptr_t, typename indices_t>\nstd::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(\n    torch::Tensor indices, const int64_t num_nodes,\n    const indptr_t* const in_degree, const indptr_t* const sliced_indptr,\n    const int64_t* const perm, torch::TensorOptions options,\n    torch::ScalarType indptr_scalar_type,\n    torch::optional<int64_t> output_size) {\n  auto allocator = cuda::GetAllocator();\n  thrust::counting_iterator<int64_t> iota(0);\n\n  // Output indptr for the slice indexed by nodes.\n  auto output_indptr =\n      torch::empty(num_nodes + 1, options.dtype(indptr_scalar_type));\n\n  auto output_indptr_aligned =\n      torch::empty(num_nodes + 1, options.dtype(indptr_scalar_type));\n  auto output_indptr_aligned_ptr = 
output_indptr_aligned.data_ptr<indptr_t>();\n\n  {\n    // Returns the actual and modified_indegree as a pair, the\n    // latter overestimates the actual indegree for alignment\n    // purposes.\n    auto modified_in_degree = thrust::make_transform_iterator(\n        iota, AlignmentFunc<indptr_t, indices_t>{in_degree, perm, num_nodes});\n    auto output_indptr_pair = thrust::make_zip_iterator(\n        output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr);\n    thrust::tuple<indptr_t, indptr_t> zero_value{};\n    // Compute the prefix sum over actual and modified indegrees.\n    CUB_CALL(\n        DeviceScan::ExclusiveScan, modified_in_degree, output_indptr_pair,\n        PairSum{}, zero_value, num_nodes + 1);\n  }\n\n  // Copy the actual total number of edges.\n  if (!output_size.has_value()) {\n    auto edge_count =\n        cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};\n    output_size = static_cast<indptr_t>(edge_count);\n  }\n  // Copy the modified number of edges.\n  auto edge_count_aligned_ =\n      cuda::CopyScalar{output_indptr_aligned_ptr + num_nodes};\n  const int64_t edge_count_aligned = static_cast<indptr_t>(edge_count_aligned_);\n\n  // Allocate output array with actual number of edges.\n  torch::Tensor output_indices =\n      torch::empty(output_size.value(), options.dtype(indices.scalar_type()));\n  const dim3 block(BLOCK_SIZE);\n  const dim3 grid(\n      (std::min(edge_count_aligned, cuda::max_uva_threads.value_or(1 << 20)) +\n       BLOCK_SIZE - 1) /\n      BLOCK_SIZE);\n\n  // Find the smallest integer type to store the coo_aligned_rows tensor.\n  const int num_bits = cuda::NumberOfBits(num_nodes);\n  std::array<int, 4> type_bits = {8, 15, 31, 63};\n  const auto type_index =\n      std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -\n      type_bits.begin();\n  std::array<torch::ScalarType, 5> types = {\n      torch::kByte, torch::kInt16, torch::kInt32, torch::kLong, torch::kLong};\n  auto coo_dtype = 
types[type_index];\n\n  auto coo_aligned_rows = ExpandIndptrImpl(\n      output_indptr_aligned, coo_dtype, torch::nullopt, edge_count_aligned);\n\n  AT_DISPATCH_INTEGRAL_TYPES(\n      coo_dtype, \"UVAIndexSelectCSCCopyIndicesCOO\", ([&] {\n        using coo_rows_t = scalar_t;\n        // Perform the actual copying, of the indices array into\n        // output_indices in an aligned manner.\n        CUDA_KERNEL_CALL(\n            _CopyIndicesAlignedKernel, grid, block, 0,\n            static_cast<indptr_t>(edge_count_aligned_), sliced_indptr,\n            output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr,\n            reinterpret_cast<indices_t*>(indices.data_ptr()),\n            coo_aligned_rows.data_ptr<coo_rows_t>(),\n            reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);\n      }));\n  return {output_indptr, output_indices};\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(\n    torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,\n    torch::Tensor nodes, int num_bits, torch::optional<int64_t> output_size) {\n  // Sorting nodes so that accesses over PCI-e are more regular.\n  const auto sorted_idx = Sort(nodes, num_bits).second;\n  const int64_t num_nodes = nodes.size(0);\n\n  return AT_DISPATCH_INTEGRAL_TYPES(\n      sliced_indptr.scalar_type(), \"UVAIndexSelectCSCIndptr\", ([&] {\n        using indptr_t = scalar_t;\n        return GRAPHBOLT_DISPATCH_ELEMENT_SIZES(\n            indices.element_size(), \"UVAIndexSelectCSCCopyIndices\", ([&] {\n              return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>(\n                  indices, num_nodes, in_degree.data_ptr<indptr_t>(),\n                  sliced_indptr.data_ptr<indptr_t>(),\n                  sorted_idx.data_ptr<int64_t>(), nodes.options(),\n                  sliced_indptr.scalar_type(), output_size);\n            }));\n      }));\n}\n\ntemplate <typename indptr_t, typename indices_t>\nstruct IteratorFunc {\n  
indptr_t* indptr;\n  indices_t* indices;\n  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }\n};\n\ntemplate <typename indptr_t, typename indices_t>\nstruct ConvertToBytes {\n  const indptr_t* in_degree;\n  __host__ __device__ indptr_t operator()(int64_t i) {\n    return in_degree[i] * sizeof(indices_t);\n  }\n};\n\ntemplate <typename indptr_t, typename indices_t>\nvoid IndexSelectCSCCopyIndices(\n    const int64_t num_nodes, indices_t* const indices,\n    indptr_t* const sliced_indptr, const indptr_t* const in_degree,\n    indptr_t* const output_indptr, indices_t* const output_indices) {\n  thrust::counting_iterator<int64_t> iota(0);\n\n  auto input_buffer_it = thrust::make_transform_iterator(\n      iota, IteratorFunc<indptr_t, indices_t>{sliced_indptr, indices});\n  auto output_buffer_it = thrust::make_transform_iterator(\n      iota, IteratorFunc<indptr_t, indices_t>{output_indptr, output_indices});\n  auto buffer_sizes = thrust::make_transform_iterator(\n      iota, ConvertToBytes<indptr_t, indices_t>{in_degree});\n  constexpr int64_t max_copy_at_once = std::numeric_limits<int32_t>::max();\n\n  // Performs the copy from indices into output_indices.\n  for (int64_t i = 0; i < num_nodes; i += max_copy_at_once) {\n    CUB_CALL(\n        DeviceMemcpy::Batched, input_buffer_it + i, output_buffer_it + i,\n        buffer_sizes + i, std::min(num_nodes - i, max_copy_at_once));\n  }\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(\n    torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,\n    torch::TensorOptions options, torch::optional<int64_t> output_size) {\n  const int64_t num_nodes = sliced_indptr.size(0);\n  return AT_DISPATCH_INTEGRAL_TYPES(\n      sliced_indptr.scalar_type(), \"IndexSelectCSCIndptr\", ([&] {\n        using indptr_t = scalar_t;\n        auto in_degree_ptr = in_degree.data_ptr<indptr_t>();\n        auto sliced_indptr_ptr = sliced_indptr.data_ptr<indptr_t>();\n        // 
Output indptr for the slice indexed by nodes.\n        torch::Tensor output_indptr = torch::empty(\n            num_nodes + 1, options.dtype(sliced_indptr.scalar_type()));\n\n        // Compute the output indptr, output_indptr.\n        CUB_CALL(\n            DeviceScan::ExclusiveSum, in_degree_ptr,\n            output_indptr.data_ptr<indptr_t>(), num_nodes + 1);\n\n        // Number of edges being copied.\n        if (!output_size.has_value()) {\n          auto edge_count =\n              cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};\n          output_size = static_cast<indptr_t>(edge_count);\n        }\n        // Allocate output array of size number of copied edges.\n        torch::Tensor output_indices = torch::empty(\n            output_size.value(), options.dtype(indices.scalar_type()));\n        GRAPHBOLT_DISPATCH_ELEMENT_SIZES(\n            indices.element_size(), \"IndexSelectCSCCopyIndices\", ([&] {\n              using indices_t = element_size_t;\n              IndexSelectCSCCopyIndices<indptr_t, indices_t>(\n                  num_nodes, reinterpret_cast<indices_t*>(indices.data_ptr()),\n                  sliced_indptr_ptr, in_degree_ptr,\n                  output_indptr.data_ptr<indptr_t>(),\n                  reinterpret_cast<indices_t*>(output_indices.data_ptr()));\n            }));\n        return std::make_tuple(output_indptr, output_indices);\n      }));\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(\n    torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,\n    torch::Tensor nodes, int64_t nodes_max,\n    torch::optional<int64_t> output_size) {\n  if (indices.is_pinned()) {\n    int num_bits = cuda::NumberOfBits(nodes_max + 1);\n    return UVAIndexSelectCSCImpl(\n        in_degree, sliced_indptr, indices, nodes, num_bits, output_size);\n  } else {\n    return DeviceIndexSelectCSCImpl(\n        in_degree, sliced_indptr, indices, nodes.options(), output_size);\n  
}\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(\n    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,\n    torch::optional<int64_t> output_size) {\n  auto [in_degree, sliced_indptr] = SliceCSCIndptr(indptr, nodes);\n  return IndexSelectCSCImpl(\n      in_degree, sliced_indptr, indices, nodes, indptr.size(0) - 2,\n      output_size);\n}\n\nstd::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatchedImpl(\n    torch::Tensor indptr, std::vector<torch::Tensor> indices_list,\n    torch::Tensor nodes, bool with_edge_ids,\n    torch::optional<int64_t> output_size) {\n  auto [in_degree, sliced_indptr] = SliceCSCIndptr(indptr, nodes);\n  std::vector<torch::Tensor> results;\n  results.reserve(indices_list.size());\n  torch::Tensor output_indptr;\n  for (auto& indices : indices_list) {\n    torch::Tensor output_indices;\n    std::tie(output_indptr, output_indices) = IndexSelectCSCImpl(\n        in_degree, sliced_indptr, indices, nodes, indptr.size(0) - 2,\n        output_size);\n    if (!output_size.has_value()) output_size = output_indices.size(0);\n    TORCH_CHECK(*output_size == output_indices.size(0));\n    results.push_back(output_indices);\n  }\n  if (with_edge_ids) {\n    results.push_back(IndptrEdgeIdsImpl(\n        output_indptr, sliced_indptr.scalar_type(), sliced_indptr,\n        output_size));\n  }\n  return {output_indptr, results};\n}\n\n}  //  namespace ops\n}  //  namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/index_select_impl.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/index_select_impl.cu\n * @brief Index select operator implementation on CUDA.\n */\n#include <c10/core/ScalarType.h>\n#include <graphbolt/cuda_ops.h>\n\n#include <numeric>\n\n#include \"./common.h\"\n#include \"./max_uva_threads.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\n/** @brief Index select operator implementation for feature size 1. */\ntemplate <typename DType, typename IdType>\n__global__ void IndexSelectSingleKernel(\n    const DType* input, const int64_t input_len, const IdType* index,\n    const int64_t output_len, DType* output,\n    const int64_t* permutation = nullptr) {\n  int64_t out_row_index = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride = gridDim.x * blockDim.x;\n  while (out_row_index < output_len) {\n    assert(index[out_row_index] >= 0 && index[out_row_index] < input_len);\n    const auto out_row =\n        permutation ? 
permutation[out_row_index] : out_row_index;\n    output[out_row] = input[index[out_row_index]];\n    out_row_index += stride;\n  }\n}\n\n/**\n * @brief Index select operator implementation for feature size > 1.\n */\ntemplate <typename DType, typename IdType>\n__global__ void IndexSelectMultiKernel(\n    const DType* const input, const int64_t input_len,\n    const int64_t feature_size, const IdType* const index,\n    const int64_t output_len, DType* const output,\n    const int64_t* permutation = nullptr) {\n  int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;\n\n  const int64_t stride = blockDim.y * gridDim.x;\n\n  while (out_row_index < output_len) {\n    int64_t column = threadIdx.x;\n    const int64_t in_row = index[out_row_index];\n    assert(in_row >= 0 && in_row < input_len);\n    const auto out_row =\n        permutation ? permutation[out_row_index] : out_row_index;\n    while (column < feature_size) {\n      output[out_row * feature_size + column] =\n          input[in_row * feature_size + column];\n      column += blockDim.x;\n    }\n    out_row_index += stride;\n  }\n}\n\n/**\n * @brief Index select operator implementation for feature size > 1.\n *\n * @note This is a cross-device access version of IndexSelectMultiKernel. 
Since\n * the memory access over PCIe is more sensitive to the data access aligment\n * (cacheline), we need a separate version here.\n */\ntemplate <typename DType, typename IdType>\n__global__ void IndexSelectMultiKernelAligned(\n    const DType* const input, const int64_t input_len,\n    const int64_t feature_size, const IdType* const index,\n    const int64_t output_len, DType* const output,\n    const int64_t* permutation = nullptr) {\n  int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;\n\n  const int64_t stride = blockDim.y * gridDim.x;\n\n  while (out_row_index < output_len) {\n    int64_t col = threadIdx.x;\n    const int64_t in_row = index[out_row_index];\n    assert(in_row >= 0 && in_row < input_len);\n    const int64_t idx_offset =\n        ((uint64_t)(&input[in_row * feature_size]) % GPU_CACHE_LINE_SIZE) /\n        sizeof(DType);\n    col = col - idx_offset;\n    const auto out_row =\n        permutation ? permutation[out_row_index] : out_row_index;\n    while (col < feature_size) {\n      if (col >= 0)\n        output[out_row * feature_size + col] =\n            input[in_row * feature_size + col];\n      col += blockDim.x;\n    }\n    out_row_index += stride;\n  }\n}\n\ntemplate <typename DType, typename IdType>\ntorch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {\n  const int64_t input_len = input.size(0);\n  const int64_t return_len = index.size(0);\n  const int64_t original_feature_size = std::accumulate(\n      input.sizes().begin() + 1, input.sizes().end(), 1ll, std::multiplies<>());\n  const auto aligned_feature_size =\n      input.element_size() * original_feature_size / sizeof(DType);\n  torch::Tensor ret = torch::empty(\n      {return_len, original_feature_size}, torch::TensorOptions()\n                                               .dtype(input.dtype())\n                                               .device(c10::DeviceType::CUDA));\n  DType* input_ptr = reinterpret_cast<DType*>(input.data_ptr());\n  
DType* ret_ptr = reinterpret_cast<DType*>(ret.data_ptr());\n\n  // Sort the index to improve the memory access pattern.\n  torch::Tensor sorted_index, permutation;\n  std::tie(sorted_index, permutation) =\n      Sort(index, cuda::NumberOfBits(input_len));\n  const IdType* index_sorted_ptr = sorted_index.data_ptr<IdType>();\n  const int64_t* permutation_ptr = permutation.data_ptr<int64_t>();\n\n  if (aligned_feature_size == 1) {\n    // Use a single thread to process each output row to avoid wasting threads.\n    const int num_threads = cuda::FindNumThreads(return_len);\n    const int num_blocks =\n        (std::min(return_len, cuda::max_uva_threads.value_or(1 << 20)) +\n         num_threads - 1) /\n        num_threads;\n    CUDA_KERNEL_CALL(\n        IndexSelectSingleKernel, num_blocks, num_threads, 0, input_ptr,\n        input_len, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);\n  } else {\n    constexpr int BLOCK_SIZE = CUDA_MAX_NUM_THREADS;\n    dim3 block(BLOCK_SIZE, 1);\n    while (static_cast<int64_t>(block.x) >= 2 * aligned_feature_size) {\n      block.x >>= 1;\n      block.y <<= 1;\n    }\n    const dim3 grid(std::min(\n        (return_len + block.y - 1) / block.y,\n        cuda::max_uva_threads.value_or(1 << 20) / BLOCK_SIZE));\n    if (aligned_feature_size * sizeof(DType) <= GPU_CACHE_LINE_SIZE) {\n      // When feature size is smaller than GPU cache line size, use unaligned\n      // version for less SM usage, which is more resource efficient.\n      CUDA_KERNEL_CALL(\n          IndexSelectMultiKernel, grid, block, 0, input_ptr, input_len,\n          aligned_feature_size, index_sorted_ptr, return_len, ret_ptr,\n          permutation_ptr);\n    } else {\n      // Use aligned version to improve the memory access pattern.\n      CUDA_KERNEL_CALL(\n          IndexSelectMultiKernelAligned, grid, block, 0, input_ptr, input_len,\n          aligned_feature_size, index_sorted_ptr, return_len, ret_ptr,\n          permutation_ptr);\n    }\n  }\n\n  auto 
return_shape = std::vector<int64_t>({return_len});\n  return_shape.insert(\n      return_shape.end(), input.sizes().begin() + 1, input.sizes().end());\n  ret = ret.reshape(return_shape);\n  return ret;\n}\n\n/**\n * @brief UVA index select operator implementation on CUDA.\n *\n * All basic torch types are supported for input.\n * The supporting index types are: int, int64_t.\n */\ntorch::Tensor UVAIndexSelectImpl(torch::Tensor input, torch::Tensor index) {\n  return AT_DISPATCH_INDEX_TYPES(\n      index.scalar_type(), \"UVAIndexSelectImpl\", ([&] {\n        const auto ptr = (size_t)input.data_ptr();\n        const int64_t feature_size = std::accumulate(\n            input.sizes().begin() + 1, input.sizes().end(), 1ll,\n            std::multiplies<>());\n        // We perform the copy with datatype of size powers of 2, and the\n        // maximum data type we use has 16 bytes. We check the alignment of the\n        // pointer and the feature dimensionality to determine the largest\n        // type to use for the copy to minimize the number of CUDA threads used.\n        // Alignment denotes the maximum suitable alignment and datatype size\n        // for the copies.\n        const int aligned_access_size =\n            std::gcd(16, std::gcd(ptr, input.element_size() * feature_size));\n        return GRAPHBOLT_DISPATCH_ELEMENT_SIZES(\n            aligned_access_size, \"UVAIndexSelectImplElementSize\", ([&] {\n              return UVAIndexSelectImpl_<element_size_t, index_t>(input, index);\n            }));\n      }));\n}\n\n}  //  namespace ops\n}  //  namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/insubgraph.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/insubgraph.cu\n * @brief InSubgraph operator implementation on CUDA.\n */\n\n#include <graphbolt/cuda_ops.h>\n#include <graphbolt/cuda_sampling_ops.h>\n\n#include \"./common.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\nc10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(\n    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,\n    torch::optional<torch::Tensor> type_per_edge) {\n  auto [in_degree, sliced_indptr] = SliceCSCIndptr(indptr, nodes);\n  auto [output_indptr, output_indices] = IndexSelectCSCImpl(\n      in_degree, sliced_indptr, indices, nodes, indptr.size(0) - 2);\n  const int64_t num_edges = output_indices.size(0);\n  torch::optional<torch::Tensor> output_type_per_edge;\n  if (type_per_edge) {\n    output_type_per_edge = std::get<1>(IndexSelectCSCImpl(\n        in_degree, sliced_indptr, type_per_edge.value(), nodes,\n        indptr.size(0) - 2, num_edges));\n  }\n  auto edge_ids = IndptrEdgeIdsImpl(\n      output_indptr, sliced_indptr.scalar_type(), sliced_indptr, num_edges);\n\n  return c10::make_intrusive<sampling::FusedSampledSubgraph>(\n      output_indptr, output_indices, edge_ids, nodes, torch::nullopt,\n      output_type_per_edge);\n}\n\n}  // namespace ops\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/isin.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/isin.cu\n * @brief IsIn operator implementation on CUDA.\n */\n#include <graphbolt/cuda_ops.h>\n#include <thrust/binary_search.h>\n\n#include <cub/cub.cuh>\n\n#include \"./common.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\ntorch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements) {\n  auto sorted_test_elements = Sort<false>(test_elements);\n  auto result = torch::empty_like(elements, torch::kBool);\n\n  AT_DISPATCH_INTEGRAL_TYPES(\n      elements.scalar_type(), \"IsInOperation\", ([&] {\n        THRUST_CALL(\n            binary_search, sorted_test_elements.data_ptr<scalar_t>(),\n            sorted_test_elements.data_ptr<scalar_t>() +\n                sorted_test_elements.size(0),\n            elements.data_ptr<scalar_t>(),\n            elements.data_ptr<scalar_t>() + elements.size(0),\n            result.data_ptr<bool>());\n      }));\n  return result;\n}\n\ntorch::Tensor Nonzero(torch::Tensor mask, bool logical_not) {\n  thrust::counting_iterator<int64_t> iota(0);\n  auto result = torch::empty_like(mask, torch::kInt64);\n  auto mask_ptr = mask.data_ptr<bool>();\n  auto result_ptr = result.data_ptr<int64_t>();\n  auto allocator = cuda::GetAllocator();\n  auto num_copied = allocator.AllocateStorage<int64_t>(1);\n  if 
(logical_not) {\n    CUB_CALL(\n        DeviceSelect::FlaggedIf, iota, mask_ptr, result_ptr, num_copied.get(),\n        mask.numel(), thrust::logical_not<bool>{});\n  } else {\n    CUB_CALL(\n        DeviceSelect::Flagged, iota, mask_ptr, result_ptr, num_copied.get(),\n        mask.numel());\n  }\n  cuda::CopyScalar num_copied_cpu(num_copied.get());\n  return result.slice(0, 0, static_cast<int64_t>(num_copied_cpu));\n}\n\n}  // namespace ops\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/max_uva_threads.cc",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/max_uva_threads.cc\n * @brief Max uva threads variable setter function.\n */\n#include \"./max_uva_threads.h\"\n\nnamespace graphbolt {\nnamespace cuda {\n\nvoid set_max_uva_threads(int64_t count) { max_uva_threads = count; }\n\n}  // namespace cuda\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/max_uva_threads.h",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/max_uva_threads.h\n * @brief Max uva threads variable declaration.\n */\n#ifndef GRAPHBOLT_MAX_UVA_THREADS_H_\n#define GRAPHBOLT_MAX_UVA_THREADS_H_\n\n#include <cstdint>\n#include <optional>\n\nnamespace graphbolt {\nnamespace cuda {\n\n/** @brief Set a limit on the number of CUDA threads for UVA accesses. */\ninline std::optional<int64_t> max_uva_threads;\n\nvoid set_max_uva_threads(int64_t count);\n\n}  // namespace cuda\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_MAX_UVA_THREADS_H_\n"
  },
  {
    "path": "graphbolt/src/cuda/neighbor_sampler.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/index_select_impl.cu\n * @brief Index select operator implementation on CUDA.\n */\n#include <c10/core/ScalarType.h>\n#include <curand_kernel.h>\n#include <graphbolt/continuous_seed.h>\n#include <graphbolt/cuda_ops.h>\n#include <graphbolt/cuda_sampling_ops.h>\n#include <thrust/copy.h>\n#include <thrust/gather.h>\n#include <thrust/iterator/counting_iterator.h>\n#include <thrust/iterator/transform_iterator.h>\n#include <thrust/iterator/transform_output_iterator.h>\n\n#include <algorithm>\n#include <array>\n#include <cub/cub.cuh>\n#if __CUDA_ARCH__ >= 700\n#include <cuda/atomic>\n#endif  // __CUDA_ARCH__ >= 700\n#include <limits>\n#include <numeric>\n#include <type_traits>\n\n#include \"../macro.h\"\n#include \"../random.h\"\n#include \"../utils.h\"\n#include \"./common.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\nconstexpr int BLOCK_SIZE = 128;\n\ninline __device__ int64_t AtomicMax(int64_t* const address, const int64_t val) {\n  // To match the type of \"::atomicCAS\", ignore lint warning.\n  using Type = unsigned long long int;  // NOLINT\n\n  static_assert(sizeof(Type) == sizeof(*address), \"Type width must match\");\n\n  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));\n}\n\ninline 
__device__ int32_t AtomicMax(int32_t* const address, const int32_t val) {\n  // To match the type of \"::atomicCAS\", ignore lint warning.\n  using Type = int;  // NOLINT\n\n  static_assert(sizeof(Type) == sizeof(*address), \"Type width must match\");\n\n  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));\n}\n\n/**\n * @brief Performs neighbor sampling and fills the edge_ids array with\n * original edge ids if sliced_indptr is valid. If not, then it fills the edge\n * ids array with numbers upto the node degree.\n */\ntemplate <typename indptr_t, typename indices_t>\n__global__ void _ComputeRandomsNS(\n    const int64_t num_edges, const indptr_t* const sliced_indptr,\n    const indptr_t* const sub_indptr, const indptr_t* const output_indptr,\n    const indices_t* const csr_rows, const uint64_t random_seed,\n    indptr_t* edge_ids) {\n  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride = gridDim.x * blockDim.x;\n\n  curandStatePhilox4_32_10_t rng;\n  curand_init(random_seed, i, 0, &rng);\n\n  while (i < num_edges) {\n    const auto row_position = csr_rows[i];\n    const auto row_offset = i - sub_indptr[row_position];\n    const auto output_offset = output_indptr[row_position];\n    const auto fanout = output_indptr[row_position + 1] - output_offset;\n    const auto rnd =\n        row_offset < fanout ? row_offset : curand(&rng) % (row_offset + 1);\n    if (rnd < fanout) {\n      const indptr_t edge_id =\n          row_offset + (sliced_indptr ? sliced_indptr[row_position] : 0);\n#if __CUDA_ARCH__ >= 700\n      ::cuda::atomic_ref<indptr_t, ::cuda::thread_scope_device> a(\n          edge_ids[output_offset + rnd]);\n      a.fetch_max(edge_id, ::cuda::std::memory_order_relaxed);\n#else\n      AtomicMax(edge_ids + output_offset + rnd, edge_id);\n#endif  // __CUDA_ARCH__\n    }\n\n    i += stride;\n  }\n}\n\n/**\n * @brief Fills the random_arr with random numbers and the edge_ids array with\n * original edge ids. 
When random_arr is sorted along with edge_ids, the first\n * fanout elements of each row gives us the sampled edges.\n */\ntemplate <\n    typename float_t, typename indptr_t, typename indices_t, typename weights_t,\n    typename edge_id_t>\n__global__ void _ComputeRandoms(\n    const int64_t num_edges, const indptr_t* const sliced_indptr,\n    const indptr_t* const sub_indptr, const indices_t* const csr_rows,\n    const weights_t* const sliced_weights, const indices_t* const indices,\n    const continuous_seed random_seed, float_t* random_arr,\n    edge_id_t* edge_ids) {\n  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride = gridDim.x * blockDim.x;\n  const auto labor = indices != nullptr;\n  const float_t inf =\n      static_cast<float_t>(std::numeric_limits<float>::infinity());\n\n  while (i < num_edges) {\n    const auto row_position = csr_rows[i];\n    const auto row_offset = i - sub_indptr[row_position];\n    const auto in_idx = sliced_indptr[row_position] + row_offset;\n    const auto rnd = random_seed.uniform(labor ? indices[in_idx] : i);\n    const auto prob =\n        sliced_weights ? sliced_weights[i] : static_cast<weights_t>(1);\n    const auto exp_rnd = -__logf(rnd);\n    const float_t adjusted_rnd =\n        prob > 0 ? 
static_cast<float_t>(exp_rnd / prob) : inf;\n    random_arr[i] = adjusted_rnd;\n    edge_ids[i] = row_offset;\n\n    i += stride;\n  }\n}\n\nstruct IsPositive {\n  template <typename probs_t>\n  __host__ __device__ auto operator()(probs_t x) {\n    return x > 0;\n  }\n};\n\ntemplate <typename indptr_t>\nstruct MinInDegreeFanout {\n  const indptr_t* in_degree;\n  const int64_t* fanouts;\n  size_t num_fanouts;\n  __host__ __device__ auto operator()(int64_t i) {\n    return static_cast<indptr_t>(\n        min(static_cast<int64_t>(in_degree[i]), fanouts[i % num_fanouts]));\n  }\n};\n\ntemplate <typename indptr_t, typename indices_t>\nstruct IteratorFunc {\n  indptr_t* indptr;\n  indices_t* indices;\n  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }\n};\n\ntemplate <typename indptr_t>\nstruct AddOffset {\n  indptr_t offset;\n  template <typename edge_id_t>\n  __host__ __device__ indptr_t operator()(edge_id_t x) {\n    return x + offset;\n  }\n};\n\ntemplate <typename indptr_t, typename indices_t>\nstruct IteratorFuncAddOffset {\n  indptr_t* indptr;\n  indptr_t* sliced_indptr;\n  indices_t* indices;\n  __host__ __device__ auto operator()(int64_t i) {\n    return thrust::transform_output_iterator{\n        indices + indptr[i], AddOffset<indptr_t>{sliced_indptr[i]}};\n  }\n};\n\ntemplate <typename indptr_t, typename in_degree_iterator_t>\nstruct SegmentEndFunc {\n  indptr_t* indptr;\n  in_degree_iterator_t in_degree;\n  __host__ __device__ auto operator()(int64_t i) {\n    return indptr[i] + in_degree[i];\n  }\n};\n\nc10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(\n    torch::Tensor indptr, torch::Tensor indices,\n    torch::optional<torch::Tensor> seeds,\n    torch::optional<std::vector<int64_t>> seed_offsets,\n    const std::vector<int64_t>& fanouts, bool replace, bool layer,\n    bool returning_indices_is_optional,\n    torch::optional<torch::Tensor> type_per_edge,\n    torch::optional<torch::Tensor> probs_or_mask,\n    
torch::optional<torch::Tensor> node_type_offset,\n    torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id,\n    torch::optional<torch::Dict<std::string, int64_t>> edge_type_to_id,\n    torch::optional<torch::Tensor> random_seed_tensor, float seed2_contribution,\n    // Optional temporal sampling arguments begin.\n    torch::optional<torch::Tensor> seeds_timestamp,\n    torch::optional<torch::Tensor> seeds_pre_time_window,\n    torch::optional<torch::Tensor> node_timestamp,\n    torch::optional<torch::Tensor> edge_timestamp\n    // Optional temporal sampling arguments end.\n) {\n  // When seed_offsets.has_value() in the hetero case, we compute the output of\n  // sample_neighbors _convert_to_sampled_subgraph in a fused manner so that\n  // _convert_to_sampled_subgraph only has to perform slices over the returned\n  // indptr and indices tensors to form CSC outputs for each edge type.\n  TORCH_CHECK(!replace, \"Sampling with replacement is not supported yet!\");\n  // Assume that indptr, indices, seeds, type_per_edge and probs_or_mask\n  // are all resident on the GPU. If not, it is better to first extract them\n  // before calling this function.\n  auto allocator = cuda::GetAllocator();\n  auto num_rows =\n      seeds.has_value() ? seeds.value().size(0) : indptr.size(0) - 1;\n  auto fanouts_pinned = torch::empty(\n      fanouts.size(),\n      c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));\n  auto fanouts_pinned_ptr = fanouts_pinned.data_ptr<int64_t>();\n  for (size_t i = 0; i < fanouts.size(); i++) {\n    fanouts_pinned_ptr[i] =\n        fanouts[i] >= 0 ? 
fanouts[i] : std::numeric_limits<int64_t>::max();\n  }\n  // Finally, copy the adjusted fanout values to the device memory.\n  auto fanouts_device = allocator.AllocateStorage<int64_t>(fanouts.size());\n  CUDA_CALL(cudaMemcpyAsync(\n      fanouts_device.get(), fanouts_pinned_ptr,\n      sizeof(int64_t) * fanouts.size(), cudaMemcpyHostToDevice,\n      cuda::GetCurrentStream()));\n  auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, seeds);\n  auto in_degree = std::get<0>(in_degree_and_sliced_indptr);\n  auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);\n  const auto homo_in_degree = in_degree;\n  const auto homo_sliced_indptr = sliced_indptr;\n  auto max_in_degree = torch::empty(\n      1,\n      c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));\n  AT_DISPATCH_INDEX_TYPES(\n      indptr.scalar_type(), \"SampleNeighborsMaxInDegree\", ([&] {\n        CUB_CALL(\n            DeviceReduce::Max, in_degree.data_ptr<index_t>(),\n            max_in_degree.data_ptr<index_t>(), num_rows);\n      }));\n  // Protect access to max_in_degree with a CUDAEvent\n  at::cuda::CUDAEvent max_in_degree_event;\n  max_in_degree_event.record();\n  torch::optional<int64_t> num_edges;\n  torch::Tensor sub_indptr;\n  if (!seeds.has_value()) {\n    num_edges = indices.size(0);\n    sub_indptr = indptr;\n  }\n  torch::optional<torch::Tensor> sliced_probs_or_mask;\n  if (probs_or_mask.has_value()) {\n    if (seeds.has_value()) {\n      torch::Tensor sliced_probs_or_mask_tensor;\n      std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(\n          in_degree, sliced_indptr, probs_or_mask.value(), seeds.value(),\n          indptr.size(0) - 2, num_edges);\n      sliced_probs_or_mask = sliced_probs_or_mask_tensor;\n      num_edges = sliced_probs_or_mask_tensor.size(0);\n    } else {\n      sliced_probs_or_mask = probs_or_mask;\n    }\n  }\n  if (fanouts.size() > 1) {\n    torch::Tensor sliced_type_per_edge;\n    if (seeds.has_value()) {\n   
   std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(\n          in_degree, sliced_indptr, type_per_edge.value(), seeds.value(),\n          indptr.size(0) - 2, num_edges);\n    } else {\n      sliced_type_per_edge = type_per_edge.value();\n    }\n    std::tie(sub_indptr, in_degree, sliced_indptr) = SliceCSCIndptrHetero(\n        sub_indptr, sliced_type_per_edge, sliced_indptr, fanouts.size());\n    num_rows = sliced_indptr.size(0);\n    num_edges = sliced_type_per_edge.size(0);\n  }\n  // If sub_indptr was not computed in the two code blocks above:\n  if (seeds.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {\n    sub_indptr = ExclusiveCumSum(in_degree);\n  }\n  torch::optional<torch::Tensor> homo_coo_rows;\n  if (seeds_timestamp.has_value()) {\n    // Temporal sampling is enabled.\n    const auto homo_sub_indptr =\n        fanouts.size() > 1 ? ExclusiveCumSum(homo_in_degree) : sub_indptr;\n    homo_coo_rows = ExpandIndptrImpl(\n        homo_sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);\n    num_edges = homo_coo_rows->size(0);\n    const auto is_probs_initialized = sliced_probs_or_mask.has_value();\n    if (!is_probs_initialized) {\n      sliced_probs_or_mask =\n          torch::empty(*num_edges, sub_indptr.options().dtype(torch::kBool));\n    }\n    GRAPHBOLT_DISPATCH_ALL_TYPES(\n        sliced_probs_or_mask->scalar_type(),\n        \"SampleNeighborsTemporalProbsOrMask\", ([&] {\n          const scalar_t* input_probs_ptr =\n              is_probs_initialized ? sliced_probs_or_mask->data_ptr<scalar_t>()\n                                   : nullptr;\n          auto output_probs_ptr = sliced_probs_or_mask->data_ptr<scalar_t>();\n          using timestamp_t = int64_t;\n          const auto seeds_timestamp_ptr =\n              seeds_timestamp->data_ptr<timestamp_t>();\n          const timestamp_t* seeds_pre_time_window_ptr =\n              seeds_pre_time_window.has_value()\n                  ? 
seeds_pre_time_window->data_ptr<timestamp_t>()\n                  : nullptr;\n          const timestamp_t* node_timestamp_ptr =\n              node_timestamp.has_value()\n                  ? node_timestamp->data_ptr<timestamp_t>()\n                  : nullptr;\n          const timestamp_t* edge_timestamp_ptr =\n              edge_timestamp.has_value()\n                  ? edge_timestamp->data_ptr<timestamp_t>()\n                  : nullptr;\n          AT_DISPATCH_INDEX_TYPES(\n              homo_coo_rows->scalar_type(),\n              \"SampleNeighborsTemporalMaskIndices\", ([&] {\n                const auto coo_rows_ptr = homo_coo_rows->data_ptr<index_t>();\n                const auto indices_ptr = indices.data_ptr<index_t>();\n                AT_DISPATCH_INDEX_TYPES(\n                    homo_sliced_indptr.scalar_type(),\n                    \"SampleNeighborsTemporalMaskIndptr\", ([&] {\n                      const auto sliced_indptr_data =\n                          homo_sliced_indptr.data_ptr<index_t>();\n                      const auto sub_indptr_data =\n                          homo_sub_indptr.data_ptr<index_t>();\n                      CUB_CALL(\n                          DeviceFor::Bulk, *num_edges,\n                          [=] __device__(int64_t i) {\n                            const auto row = coo_rows_ptr[i];\n                            const auto seed_timestamp =\n                                seeds_timestamp_ptr[row];\n                            const auto row_offset = i - sub_indptr_data[row];\n                            const auto in_idx =\n                                sliced_indptr_data[row] + row_offset;\n                            bool mask = true;\n                            if (node_timestamp_ptr) {\n                              const auto index = indices_ptr[in_idx];\n                              const auto neighbor_timestamp =\n                                  node_timestamp_ptr[index];\n                              mask &= 
neighbor_timestamp < seed_timestamp;\n                              if (seeds_pre_time_window_ptr) {\n                                mask &= neighbor_timestamp >\n                                        seed_timestamp -\n                                            seeds_pre_time_window_ptr[row];\n                              }\n                            }\n                            if (edge_timestamp_ptr) {\n                              const auto edge_timestamp =\n                                  edge_timestamp_ptr[in_idx];\n                              mask &= edge_timestamp < seed_timestamp;\n                              if (seeds_pre_time_window_ptr) {\n                                mask &= edge_timestamp >\n                                        seed_timestamp -\n                                            seeds_pre_time_window_ptr[row];\n                              }\n                            }\n                            const scalar_t prob = input_probs_ptr\n                                                      ? 
input_probs_ptr[i]\n                                                      : scalar_t{1};\n                            output_probs_ptr[i] =\n                                prob * static_cast<scalar_t>(mask);\n                          });\n                    }));\n              }));\n        }));\n  }\n  const continuous_seed random_seed = [&] {\n    if (random_seed_tensor.has_value()) {\n      return continuous_seed(random_seed_tensor.value(), seed2_contribution);\n    } else {\n      return continuous_seed{RandomEngine::ThreadLocal()->RandInt(\n          static_cast<int64_t>(0), std::numeric_limits<int64_t>::max())};\n    }\n  }();\n  auto output_indptr = torch::empty_like(sub_indptr);\n  torch::Tensor picked_eids;\n  torch::optional<torch::Tensor> output_indices;\n\n  AT_DISPATCH_INDEX_TYPES(\n      indptr.scalar_type(), \"SampleNeighborsIndptr\", ([&] {\n        using indptr_t = index_t;\n        if (sliced_probs_or_mask.has_value()) {\n          // Count nonzero probs into in_degree.\n          GRAPHBOLT_DISPATCH_ALL_TYPES(\n              sliced_probs_or_mask->scalar_type(),\n              \"SampleNeighborsPositiveProbs\", ([&] {\n                using probs_t = scalar_t;\n                auto is_nonzero = thrust::make_transform_iterator(\n                    sliced_probs_or_mask->data_ptr<probs_t>(), IsPositive{});\n                CUB_CALL(\n                    DeviceSegmentedReduce::Sum, is_nonzero,\n                    in_degree.data_ptr<indptr_t>(), num_rows,\n                    sub_indptr.data_ptr<indptr_t>(),\n                    sub_indptr.data_ptr<indptr_t>() + 1);\n              }));\n        }\n        thrust::counting_iterator<int64_t> iota(0);\n        auto sampled_degree = thrust::make_transform_iterator(\n            iota, MinInDegreeFanout<indptr_t>{\n                      in_degree.data_ptr<indptr_t>(), fanouts_device.get(),\n                      fanouts.size()});\n\n        // Compute output_indptr.\n        CUB_CALL(\n            
DeviceScan::ExclusiveSum, sampled_degree,\n            output_indptr.data_ptr<indptr_t>(), num_rows + 1);\n\n        auto num_sampled_edges =\n            cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};\n\n        // This operation is placed after num_sampled_edges copy is started to\n        // hide the latency of copy synchronization later.\n        torch::Tensor coo_rows;\n        if (!homo_coo_rows.has_value() || fanouts.size() > 1) {\n          coo_rows = ExpandIndptrImpl(\n              sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);\n          num_edges = coo_rows.size(0);\n        } else {\n          coo_rows = *homo_coo_rows;\n        }\n\n        // Find the smallest integer type to store the edge id offsets. We synch\n        // the CUDAEvent so that the access is safe.\n        auto compute_num_bits = [&] {\n          max_in_degree_event.synchronize();\n          return cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);\n        };\n        if (layer || sliced_probs_or_mask.has_value()) {\n          const int num_bits = compute_num_bits();\n          std::array<int, 4> type_bits = {8, 16, 32, 64};\n          const auto type_index =\n              std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -\n              type_bits.begin();\n          std::array<torch::ScalarType, 5> types = {\n              torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,\n              torch::kLong};\n          auto edge_id_dtype = types[type_index];\n          AT_DISPATCH_INTEGRAL_TYPES(\n              edge_id_dtype, \"SampleNeighborsEdgeIDs\", ([&] {\n                using edge_id_t = std::make_unsigned_t<scalar_t>;\n                TORCH_CHECK(\n                    num_bits <= sizeof(edge_id_t) * 8,\n                    \"Selected edge_id_t must be capable of storing edge_ids.\");\n                // Using bfloat16 for random numbers works just as reliably as\n                // float32 and provides around 30% 
speedup.\n                using rnd_t = nv_bfloat16;\n                auto randoms =\n                    allocator.AllocateStorage<rnd_t>(num_edges.value());\n                auto randoms_sorted =\n                    allocator.AllocateStorage<rnd_t>(num_edges.value());\n                auto edge_id_segments =\n                    allocator.AllocateStorage<edge_id_t>(num_edges.value());\n                auto sorted_edge_id_segments =\n                    allocator.AllocateStorage<edge_id_t>(num_edges.value());\n                AT_DISPATCH_INDEX_TYPES(\n                    indices.scalar_type(), \"SampleNeighborsIndices\", ([&] {\n                      using indices_t = index_t;\n                      auto probs_or_mask_scalar_type = torch::kFloat32;\n                      if (sliced_probs_or_mask.has_value()) {\n                        probs_or_mask_scalar_type =\n                            sliced_probs_or_mask->scalar_type();\n                      }\n                      GRAPHBOLT_DISPATCH_ALL_TYPES(\n                          probs_or_mask_scalar_type, \"SampleNeighborsProbs\",\n                          ([&] {\n                            using probs_t = scalar_t;\n                            probs_t* sliced_probs_ptr = nullptr;\n                            if (sliced_probs_or_mask.has_value()) {\n                              sliced_probs_ptr =\n                                  sliced_probs_or_mask->data_ptr<probs_t>();\n                            }\n                            const indices_t* indices_ptr =\n                                layer ? 
indices.data_ptr<indices_t>() : nullptr;\n                            const dim3 block(BLOCK_SIZE);\n                            const dim3 grid(\n                                (num_edges.value() + BLOCK_SIZE - 1) /\n                                BLOCK_SIZE);\n                            // Compute row and random number pairs.\n                            CUDA_KERNEL_CALL(\n                                _ComputeRandoms, grid, block, 0,\n                                num_edges.value(),\n                                sliced_indptr.data_ptr<indptr_t>(),\n                                sub_indptr.data_ptr<indptr_t>(),\n                                coo_rows.data_ptr<indices_t>(),\n                                sliced_probs_ptr, indices_ptr, random_seed,\n                                randoms.get(), edge_id_segments.get());\n                          }));\n                    }));\n\n                // Sort the random numbers along with edge ids, after\n                // sorting the first fanout elements of each row will\n                // give us the sampled edges.\n                CUB_CALL(\n                    DeviceSegmentedSort::SortPairs, randoms.get(),\n                    randoms_sorted.get(), edge_id_segments.get(),\n                    sorted_edge_id_segments.get(), num_edges.value(), num_rows,\n                    sub_indptr.data_ptr<indptr_t>(),\n                    sub_indptr.data_ptr<indptr_t>() + 1);\n\n                picked_eids = torch::empty(\n                    static_cast<indptr_t>(num_sampled_edges),\n                    sub_indptr.options());\n\n                // Need to sort the sampled edges only when fanouts.size() == 1\n                // since multiple fanout sampling case is automatically going to\n                // be sorted.\n                if (type_per_edge && fanouts.size() == 1) {\n                  // Ensuring sort result still ends up in\n                  // sorted_edge_id_segments\n                  
std::swap(edge_id_segments, sorted_edge_id_segments);\n                  auto sampled_segment_end_it = thrust::make_transform_iterator(\n                      iota,\n                      SegmentEndFunc<indptr_t, decltype(sampled_degree)>{\n                          sub_indptr.data_ptr<indptr_t>(), sampled_degree});\n                  CUB_CALL(\n                      DeviceSegmentedSort::SortKeys, edge_id_segments.get(),\n                      sorted_edge_id_segments.get(), picked_eids.size(0),\n                      num_rows, sub_indptr.data_ptr<indptr_t>(),\n                      sampled_segment_end_it);\n                }\n\n                auto input_buffer_it = thrust::make_transform_iterator(\n                    iota, IteratorFunc<indptr_t, edge_id_t>{\n                              sub_indptr.data_ptr<indptr_t>(),\n                              sorted_edge_id_segments.get()});\n                auto output_buffer_it = thrust::make_transform_iterator(\n                    iota, IteratorFuncAddOffset<indptr_t, indptr_t>{\n                              output_indptr.data_ptr<indptr_t>(),\n                              sliced_indptr.data_ptr<indptr_t>(),\n                              picked_eids.data_ptr<indptr_t>()});\n                constexpr int64_t max_copy_at_once =\n                    std::numeric_limits<int32_t>::max();\n\n                // Copy the sampled edge ids into picked_eids tensor.\n                for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {\n                  CUB_CALL(\n                      DeviceCopy::Batched, input_buffer_it + i,\n                      output_buffer_it + i, sampled_degree + i,\n                      std::min(num_rows - i, max_copy_at_once));\n                }\n              }));\n        } else {  // Non-weighted neighbor sampling.\n          picked_eids = torch::zeros(num_edges.value(), sub_indptr.options());\n          const auto sort_needed = type_per_edge && fanouts.size() == 1;\n          const auto 
sliced_indptr_ptr =\n              sort_needed ? nullptr : sliced_indptr.data_ptr<indptr_t>();\n\n          const dim3 block(BLOCK_SIZE);\n          const dim3 grid(\n              (std::min(num_edges.value(), static_cast<int64_t>(1 << 20)) +\n               BLOCK_SIZE - 1) /\n              BLOCK_SIZE);\n          AT_DISPATCH_INDEX_TYPES(\n              indices.scalar_type(), \"SampleNeighborsIndices\", ([&] {\n                using indices_t = index_t;\n                // Compute row and random number pairs.\n                CUDA_KERNEL_CALL(\n                    _ComputeRandomsNS, grid, block, 0, num_edges.value(),\n                    sliced_indptr_ptr, sub_indptr.data_ptr<indptr_t>(),\n                    output_indptr.data_ptr<indptr_t>(),\n                    coo_rows.data_ptr<indices_t>(), random_seed.get_seed(0),\n                    picked_eids.data_ptr<indptr_t>());\n              }));\n\n          picked_eids =\n              picked_eids.slice(0, 0, static_cast<indptr_t>(num_sampled_edges));\n\n          // Need to sort the sampled edges only when fanouts.size() == 1\n          // since multiple fanout sampling case is automatically going to\n          // be sorted.\n          if (sort_needed) {\n            const int num_bits = compute_num_bits();\n            std::array<int, 4> type_bits = {8, 15, 31, 63};\n            const auto type_index =\n                std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -\n                type_bits.begin();\n            std::array<torch::ScalarType, 5> types = {\n                torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,\n                torch::kLong};\n            auto edge_id_dtype = types[type_index];\n            AT_DISPATCH_INTEGRAL_TYPES(\n                edge_id_dtype, \"SampleNeighborsEdgeIDs\", ([&] {\n                  using edge_id_t = scalar_t;\n                  TORCH_CHECK(\n                      num_bits <= sizeof(edge_id_t) * 8,\n                      \"Selected 
edge_id_t must be capable of storing \"\n                      \"edge_ids.\");\n                  auto picked_offsets = picked_eids.to(edge_id_dtype);\n                  auto sorted_offsets = torch::empty_like(picked_offsets);\n                  CUB_CALL(\n                      DeviceSegmentedSort::SortKeys,\n                      picked_offsets.data_ptr<edge_id_t>(),\n                      sorted_offsets.data_ptr<edge_id_t>(), picked_eids.size(0),\n                      num_rows, output_indptr.data_ptr<indptr_t>(),\n                      output_indptr.data_ptr<indptr_t>() + 1);\n                  auto edge_id_offsets = ExpandIndptrImpl(\n                      output_indptr, picked_eids.scalar_type(), sliced_indptr,\n                      picked_eids.size(0));\n                  picked_eids = sorted_offsets.to(picked_eids.scalar_type()) +\n                                edge_id_offsets;\n                }));\n          }\n        }\n\n        if (!returning_indices_is_optional || utils::is_on_gpu(indices)) {\n          output_indices = Gather(indices, picked_eids);\n        }\n      }));\n\n  torch::optional<torch::Tensor> output_type_per_edge;\n  torch::optional<torch::Tensor> edge_offsets;\n  if (type_per_edge && seed_offsets) {\n    const int64_t num_etypes =\n        edge_type_to_id.has_value() ? 
edge_type_to_id->size() : 1;\n    // If we performed homogenous sampling on hetero graph, we have to look at\n    // type_per_edge of sampled edges and determine the offsets of different\n    // sampled etypes and convert to fused hetero indptr representation.\n    if (fanouts.size() == 1) {\n      output_type_per_edge = Gather(*type_per_edge, picked_eids);\n      torch::Tensor output_in_degree, sliced_output_indptr;\n      sliced_output_indptr =\n          output_indptr.slice(0, 0, output_indptr.size(0) - 1);\n      std::tie(output_indptr, output_in_degree, sliced_output_indptr) =\n          SliceCSCIndptrHetero(\n              output_indptr, output_type_per_edge.value(), sliced_output_indptr,\n              num_etypes);\n      // We use num_rows to hold num_seeds * num_etypes. So, it needs to be\n      // updated when sampling with a single fanout value when the graph is\n      // heterogenous.\n      num_rows = sliced_output_indptr.size(0);\n    }\n    // Here, we check what are the dst node types for the given seeds so that\n    // we can compute the output indptr space later.\n    std::vector<int64_t> etype_id_to_dst_ntype_id(num_etypes);\n    // Here, we check what are the src node types for the given seeds so that\n    // we can subtract source node offset from indices later.\n    auto etype_id_to_src_ntype_id = torch::empty(\n        2 * num_etypes,\n        c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));\n    auto etype_id_to_src_ntype_id_ptr =\n        etype_id_to_src_ntype_id.data_ptr<int64_t>();\n    for (auto& etype_and_id : edge_type_to_id.value()) {\n      auto etype = etype_and_id.key();\n      auto id = etype_and_id.value();\n      auto [src_type, dst_type] = utils::parse_src_dst_ntype_from_etype(etype);\n      etype_id_to_dst_ntype_id[id] = node_type_to_id->at(dst_type);\n      etype_id_to_src_ntype_id_ptr[2 * id] =\n          etype_id_to_src_ntype_id_ptr[2 * id + 1] =\n              node_type_to_id->at(src_type);\n    }\n    auto 
indices_offsets_device = torch::empty(\n        etype_id_to_src_ntype_id.size(0),\n        picked_eids.options().dtype(torch::kLong));\n    AT_DISPATCH_INDEX_TYPES(\n        node_type_offset->scalar_type(), \"SampleNeighborsNodeTypeOffset\", ([&] {\n          THRUST_CALL(\n              gather, etype_id_to_src_ntype_id_ptr,\n              etype_id_to_src_ntype_id_ptr + etype_id_to_src_ntype_id.size(0),\n              node_type_offset->data_ptr<index_t>(),\n              indices_offsets_device.data_ptr<int64_t>());\n        }));\n    // For each edge type, we compute the start and end offsets to index into\n    // indptr to form the final output_indptr.\n    auto indptr_offsets = torch::empty(\n        num_etypes * 2,\n        c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));\n    auto indptr_offsets_ptr = indptr_offsets.data_ptr<int64_t>();\n    // We compute the indptr offsets here, right now, output_indptr is of size\n    // # seeds * num_etypes + 1. We can simply take slices to get correct output\n    // indptr. 
The final output_indptr is same as current indptr except that\n    // some intermediate values are removed to change the node ids space from\n    // all of the seed vertices to the node id space of the dst node type of\n    // each edge type.\n    for (int i = 0; i < num_etypes; i++) {\n      indptr_offsets_ptr[2 * i] = num_rows / num_etypes * i +\n                                  seed_offsets->at(etype_id_to_dst_ntype_id[i]);\n      indptr_offsets_ptr[2 * i + 1] =\n          num_rows / num_etypes * i +\n          seed_offsets->at(etype_id_to_dst_ntype_id[i] + 1);\n    }\n    auto permutation = torch::arange(\n        0, num_rows * num_etypes, num_etypes, output_indptr.options());\n    permutation =\n        permutation.remainder(num_rows) + permutation.div(num_rows, \"floor\");\n    // This permutation, when applied sorts the sampled edges with respect to\n    // edge types.\n    auto [output_in_degree, sliced_output_indptr] =\n        SliceCSCIndptr(output_indptr, permutation);\n    std::tie(output_indptr, picked_eids) = IndexSelectCSCImpl(\n        output_in_degree, sliced_output_indptr, picked_eids, permutation,\n        num_rows - 1, picked_eids.size(0));\n    edge_offsets = torch::empty(\n        num_etypes * 2, c10::TensorOptions()\n                            .dtype(output_indptr.scalar_type())\n                            .pinned_memory(true));\n    auto edge_offsets_device =\n        torch::empty(num_etypes * 2, output_indptr.options());\n    at::cuda::CUDAEvent edge_offsets_event;\n    AT_DISPATCH_INDEX_TYPES(\n        indptr.scalar_type(), \"SampleNeighborsEdgeOffsets\", ([&] {\n          auto edge_offsets_pinned_device_pair =\n              thrust::make_transform_output_iterator(\n                  thrust::make_zip_iterator(\n                      edge_offsets->data_ptr<index_t>(),\n                      edge_offsets_device.data_ptr<index_t>()),\n                  [=] __device__(index_t x) {\n                    return thrust::make_tuple(x, x);\n      
            });\n          THRUST_CALL(\n              gather, indptr_offsets_ptr,\n              indptr_offsets_ptr + indptr_offsets.size(0),\n              output_indptr.data_ptr<index_t>(),\n              edge_offsets_pinned_device_pair);\n        }));\n    edge_offsets_event.record();\n    if (output_indices.has_value()) {\n      auto indices_offset_subtract = ExpandIndptrImpl(\n          edge_offsets_device, indices.scalar_type(), indices_offsets_device,\n          output_indices->size(0));\n      // The output_indices is permuted here.\n      std::tie(output_indptr, output_indices) = IndexSelectCSCImpl(\n          output_in_degree, sliced_output_indptr, *output_indices, permutation,\n          num_rows - 1, output_indices->size(0));\n      *output_indices -= indices_offset_subtract;\n    }\n    auto output_indptr_offsets = torch::empty(\n        num_etypes * 2,\n        c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));\n    auto output_indptr_offsets_ptr = output_indptr_offsets.data_ptr<int64_t>();\n    std::vector<torch::Tensor> indptr_list;\n    for (int i = 0; i < num_etypes; i++) {\n      indptr_list.push_back(output_indptr.slice(\n          0, indptr_offsets_ptr[2 * i], indptr_offsets_ptr[2 * i + 1] + 1));\n      output_indptr_offsets_ptr[2 * i] =\n          i == 0 ? 
0 : output_indptr_offsets_ptr[2 * i - 1];\n      output_indptr_offsets_ptr[2 * i + 1] =\n          output_indptr_offsets_ptr[2 * i] + indptr_list.back().size(0);\n    }\n    auto output_indptr_offsets_device = torch::empty(\n        output_indptr_offsets.size(0),\n        output_indptr.options().dtype(torch::kLong));\n    THRUST_CALL(\n        copy_n, output_indptr_offsets_ptr, output_indptr_offsets.size(0),\n        output_indptr_offsets_device.data_ptr<int64_t>());\n    // We form the final output indptr by concatenating pieces for different\n    // edge types.\n    output_indptr = torch::cat(indptr_list);\n    auto indptr_offset_subtract = ExpandIndptrImpl(\n        output_indptr_offsets_device, indptr.scalar_type(), edge_offsets_device,\n        output_indptr.size(0));\n    output_indptr -= indptr_offset_subtract;\n    edge_offsets_event.synchronize();\n    // We read the edge_offsets here, they are in pairs but we don't need it to\n    // be in pairs. So we remove the duplicate information from it and turn it\n    // into a real offsets array.\n    AT_DISPATCH_INDEX_TYPES(\n        indptr.scalar_type(), \"SampleNeighborsEdgeOffsetsCheck\", ([&] {\n          auto edge_offsets_ptr = edge_offsets->data_ptr<index_t>();\n          TORCH_CHECK(edge_offsets_ptr[0] == 0, \"edge_offsets is incorrect.\");\n          for (int i = 1; i < num_etypes; i++) {\n            TORCH_CHECK(\n                edge_offsets_ptr[2 * i - 1] == edge_offsets_ptr[2 * i],\n                \"edge_offsets is incorrect.\");\n          }\n          TORCH_CHECK(\n              edge_offsets_ptr[2 * num_etypes - 1] == picked_eids.size(0),\n              \"edge_offsets is incorrect.\");\n          for (int i = 0; i < num_etypes; i++) {\n            edge_offsets_ptr[i + 1] = edge_offsets_ptr[2 * i + 1];\n          }\n        }));\n    edge_offsets = edge_offsets->slice(0, 0, num_etypes + 1);\n  } else {\n    // Convert output_indptr back to homo by discarding intermediate offsets.\n    output_indptr 
=\n        output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());\n    if (type_per_edge)\n      output_type_per_edge = Gather(*type_per_edge, picked_eids);\n  }\n\n  return c10::make_intrusive<sampling::FusedSampledSubgraph>(\n      output_indptr, output_indices, picked_eids, seeds, torch::nullopt,\n      output_type_per_edge, edge_offsets);\n}\n\n}  //  namespace ops\n}  //  namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/sampling_utils.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/sampling_utils.cu\n * @brief Sampling utility function implementations on CUDA.\n */\n#include <thrust/for_each.h>\n#include <thrust/iterator/counting_iterator.h>\n\n#include <cub/cub.cuh>\n\n#include \"./common.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\n// Given rows and indptr, computes:\n// inrow_indptr[i] = indptr[rows[i]];\n// in_degree[i] = indptr[rows[i] + 1] - indptr[rows[i]];\ntemplate <typename indptr_t, typename nodes_t>\nstruct SliceFunc {\n  const nodes_t* rows;\n  const indptr_t* indptr;\n  indptr_t* in_degree;\n  indptr_t* inrow_indptr;\n  __host__ __device__ auto operator()(int64_t tIdx) {\n    const auto out_row = rows[tIdx];\n    const auto indptr_val = indptr[out_row];\n    const auto degree = indptr[out_row + 1] - indptr_val;\n    in_degree[tIdx] = degree;\n    inrow_indptr[tIdx] = indptr_val;\n  }\n};\n\n// Returns (indptr[nodes + 1] - indptr[nodes], indptr[nodes])\nstd::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(\n    torch::Tensor indptr, torch::optional<torch::Tensor> nodes_optional) {\n  if (nodes_optional.has_value()) {\n    auto nodes = nodes_optional.value();\n    const int64_t num_nodes = nodes.size(0);\n    // Read indptr only once in case it is pinned and access is slow.\n  
  auto sliced_indptr =\n        torch::empty(num_nodes, nodes.options().dtype(indptr.scalar_type()));\n    // compute in-degrees\n    auto in_degree = torch::empty(\n        num_nodes + 1, nodes.options().dtype(indptr.scalar_type()));\n    thrust::counting_iterator<int64_t> iota(0);\n    AT_DISPATCH_INTEGRAL_TYPES(\n        indptr.scalar_type(), \"IndexSelectCSCIndptr\", ([&] {\n          using indptr_t = scalar_t;\n          AT_DISPATCH_INDEX_TYPES(\n              nodes.scalar_type(), \"IndexSelectCSCNodes\", ([&] {\n                using nodes_t = index_t;\n                THRUST_CALL(\n                    for_each, iota, iota + num_nodes,\n                    SliceFunc<indptr_t, nodes_t>{\n                        nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(),\n                        in_degree.data_ptr<indptr_t>(),\n                        sliced_indptr.data_ptr<indptr_t>()});\n              }));\n        }));\n    return {in_degree, sliced_indptr};\n  } else {\n    const int64_t num_nodes = indptr.size(0) - 1;\n    auto sliced_indptr = indptr.slice(0, 0, num_nodes);\n    auto in_degree = torch::empty(\n        num_nodes + 2, indptr.options().dtype(indptr.scalar_type()));\n    AT_DISPATCH_INTEGRAL_TYPES(\n        indptr.scalar_type(), \"IndexSelectCSCIndptr\", ([&] {\n          using indptr_t = scalar_t;\n          CUB_CALL(\n              DeviceAdjacentDifference::SubtractLeftCopy,\n              indptr.data_ptr<indptr_t>(), in_degree.data_ptr<indptr_t>(),\n              num_nodes + 1, cub::Difference{});\n        }));\n    in_degree = in_degree.slice(0, 1);\n    return {in_degree, sliced_indptr};\n  }\n}\n\ntemplate <typename indptr_t, typename etype_t>\nstruct EdgeTypeSearch {\n  const indptr_t* sub_indptr;\n  const indptr_t* sliced_indptr;\n  const etype_t* etypes;\n  int64_t num_fanouts;\n  int64_t num_rows;\n  indptr_t* new_sub_indptr;\n  indptr_t* new_sliced_indptr;\n  __host__ __device__ auto operator()(int64_t i) {\n    const auto homo_i = i / 
num_fanouts;\n    const auto indptr_i = sub_indptr[homo_i];\n    const auto degree = sub_indptr[homo_i + 1] - indptr_i;\n    const etype_t etype = i % num_fanouts;\n    auto offset = cub::LowerBound(etypes + indptr_i, degree, etype);\n    new_sub_indptr[i] = indptr_i + offset;\n    new_sliced_indptr[i] = sliced_indptr[homo_i] + offset;\n    if (i == num_rows - 1) new_sub_indptr[num_rows] = indptr_i + degree;\n  }\n};\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SliceCSCIndptrHetero(\n    torch::Tensor sub_indptr, torch::Tensor etypes, torch::Tensor sliced_indptr,\n    int64_t num_fanouts) {\n  auto num_rows = (sub_indptr.size(0) - 1) * num_fanouts;\n  auto new_sub_indptr = torch::empty(num_rows + 1, sub_indptr.options());\n  auto new_indegree = torch::empty(num_rows + 2, sub_indptr.options());\n  auto new_sliced_indptr = torch::empty(num_rows, sliced_indptr.options());\n  thrust::counting_iterator<int64_t> iota(0);\n  AT_DISPATCH_INTEGRAL_TYPES(\n      sub_indptr.scalar_type(), \"SliceCSCIndptrHeteroIndptr\", ([&] {\n        using indptr_t = scalar_t;\n        AT_DISPATCH_INTEGRAL_TYPES(\n            etypes.scalar_type(), \"SliceCSCIndptrHeteroTypePerEdge\", ([&] {\n              using etype_t = scalar_t;\n              THRUST_CALL(\n                  for_each, iota, iota + num_rows,\n                  EdgeTypeSearch<indptr_t, etype_t>{\n                      sub_indptr.data_ptr<indptr_t>(),\n                      sliced_indptr.data_ptr<indptr_t>(),\n                      etypes.data_ptr<etype_t>(), num_fanouts, num_rows,\n                      new_sub_indptr.data_ptr<indptr_t>(),\n                      new_sliced_indptr.data_ptr<indptr_t>()});\n            }));\n        CUB_CALL(\n            DeviceAdjacentDifference::SubtractLeftCopy,\n            new_sub_indptr.data_ptr<indptr_t>(),\n            new_indegree.data_ptr<indptr_t>(), num_rows + 1, cub::Difference{});\n      }));\n  // Discard the first element of the SubtractLeftCopy result and ensure 
that\n  // new_indegree tensor has size num_rows + 1 so that its ExclusiveCumSum is\n  // directly equivalent to new_sub_indptr.\n  // Equivalent to new_indegree = new_indegree[1:] in Python.\n  new_indegree = new_indegree.slice(0, 1);\n  return {new_sub_indptr, new_indegree, new_sliced_indptr};\n}\n\n}  //  namespace ops\n}  //  namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/sort_impl.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/sort_impl.cu\n * @brief Sort implementation on CUDA.\n */\n#include <c10/core/ScalarType.h>\n\n#include <cub/cub.cuh>\n\n#include \"./common.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\ntemplate <bool return_original_positions, typename scalar_t>\nstd::conditional_t<\n    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,\n    torch::Tensor>\nSort(const scalar_t* input_keys, int64_t num_items, int num_bits) {\n  const auto options = torch::TensorOptions().device(c10::DeviceType::CUDA);\n  constexpr c10::ScalarType dtype = c10::CppTypeToScalarType<scalar_t>::value;\n  auto sorted_array = torch::empty(num_items, options.dtype(dtype));\n  auto sorted_keys = sorted_array.data_ptr<scalar_t>();\n  if (num_bits == 0) {\n    num_bits = sizeof(scalar_t) * 8;\n  }\n\n  if constexpr (return_original_positions) {\n    // We utilize int64_t for the values array. 
(torch::kLong == int64_t)\n    auto original_idx = torch::arange(num_items, options.dtype(torch::kLong));\n    auto sorted_idx = torch::empty_like(original_idx);\n    const int64_t* input_values = original_idx.data_ptr<int64_t>();\n    int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();\n    CUB_CALL(\n        DeviceRadixSort::SortPairs, input_keys, sorted_keys, input_values,\n        sorted_values, num_items, 0, num_bits);\n    return std::make_pair(sorted_array, sorted_idx);\n  } else {\n    CUB_CALL(\n        DeviceRadixSort::SortKeys, input_keys, sorted_keys, num_items, 0,\n        num_bits);\n    return sorted_array;\n  }\n}\n\ntemplate <bool return_original_positions>\nstd::conditional_t<\n    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,\n    torch::Tensor>\nSort(torch::Tensor input, int num_bits) {\n  return AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), \"SortImpl\", ([&] {\n                                      return Sort<return_original_positions>(\n                                          input.data_ptr<scalar_t>(),\n                                          input.size(0), num_bits);\n                                    }));\n}\n\ntemplate torch::Tensor Sort<false>(torch::Tensor input, int num_bits);\ntemplate std::pair<torch::Tensor, torch::Tensor> Sort<true>(\n    torch::Tensor input, int num_bits);\n\n}  //  namespace ops\n}  //  namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/unique_and_compact_impl.cu",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file cuda/unique_and_compact_impl.cu\n * @brief Unique and compact operator implementation on CUDA.\n */\n#include <graphbolt/cuda_ops.h>\n#include <thrust/binary_search.h>\n#include <thrust/functional.h>\n#include <thrust/gather.h>\n#include <thrust/logical.h>\n\n#include <cub/cub.cuh>\n#include <mutex>\n#include <type_traits>\n#include <unordered_map>\n\n#include \"./common.h\"\n#include \"./extension/unique_and_compact.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\ntemplate <typename scalar_t>\nstruct EqualityFunc {\n  const scalar_t* sorted_order;\n  const scalar_t* found_locations;\n  const scalar_t* searched_items;\n  __host__ __device__ auto operator()(int64_t i) {\n    return sorted_order[found_locations[i]] == searched_items[i];\n  }\n};\n\n#define DefineCubReductionFunction(cub_reduce_fn, name)           \\\n  template <typename scalar_iterator_t>                           \\\n  auto name(const scalar_iterator_t input, int64_t size) {        \\\n    using scalar_t = std::remove_reference_t<decltype(input[0])>; \\\n    cuda::CopyScalar<scalar_t> result;                            \\\n    CUB_CALL(cub_reduce_fn, input, result.get(), size);           \\\n    return result;                                                
\\\n  }\n\nDefineCubReductionFunction(DeviceReduce::Max, Max);\nDefineCubReductionFunction(DeviceReduce::Min, Min);\n\nstd::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>\nUniqueAndCompactBatchedSortBased(\n    const std::vector<torch::Tensor>& src_ids,\n    const std::vector<torch::Tensor>& dst_ids,\n    const std::vector<torch::Tensor>& unique_dst_ids, int num_bits = 0) {\n  auto allocator = cuda::GetAllocator();\n  auto stream = cuda::GetCurrentStream();\n  auto scalar_type = src_ids.at(0).scalar_type();\n  return AT_DISPATCH_INDEX_TYPES(\n      scalar_type, \"unique_and_compact\", ([&] {\n        std::vector<index_t*> src_ids_ptr, dst_ids_ptr, unique_dst_ids_ptr;\n        for (std::size_t i = 0; i < src_ids.size(); i++) {\n          src_ids_ptr.emplace_back(src_ids[i].data_ptr<index_t>());\n          dst_ids_ptr.emplace_back(dst_ids[i].data_ptr<index_t>());\n          unique_dst_ids_ptr.emplace_back(\n              unique_dst_ids[i].data_ptr<index_t>());\n        }\n\n        // If num_bits is not given, compute maximum vertex ids to compute\n        // num_bits later to speedup the expensive sort operations.\n        std::vector<cuda::CopyScalar<index_t>> max_id_src;\n        std::vector<cuda::CopyScalar<index_t>> max_id_dst;\n        for (std::size_t i = 0; num_bits == 0 && i < src_ids.size(); i++) {\n          max_id_src.emplace_back(Max(src_ids_ptr[i], src_ids[i].size(0)));\n          max_id_dst.emplace_back(\n              Max(unique_dst_ids_ptr[i], unique_dst_ids[i].size(0)));\n        }\n\n        // Sort the unique_dst_ids tensor.\n        std::vector<torch::Tensor> sorted_unique_dst_ids;\n        std::vector<index_t*> sorted_unique_dst_ids_ptr;\n        for (std::size_t i = 0; i < unique_dst_ids.size(); i++) {\n          sorted_unique_dst_ids.emplace_back(Sort<false>(\n              unique_dst_ids_ptr[i], unique_dst_ids[i].size(0), num_bits));\n          sorted_unique_dst_ids_ptr.emplace_back(\n              
sorted_unique_dst_ids[i].data_ptr<index_t>());\n        }\n\n        // Mark dst nodes in the src_ids tensor.\n        std::vector<decltype(allocator.AllocateStorage<bool>(0))> is_dst;\n        for (std::size_t i = 0; i < src_ids.size(); i++) {\n          is_dst.emplace_back(\n              allocator.AllocateStorage<bool>(src_ids[i].size(0)));\n          THRUST_CALL(\n              binary_search, sorted_unique_dst_ids_ptr[i],\n              sorted_unique_dst_ids_ptr[i] + unique_dst_ids[i].size(0),\n              src_ids_ptr[i], src_ids_ptr[i] + src_ids[i].size(0),\n              is_dst[i].get());\n        }\n\n        // Filter the non-dst nodes in the src_ids tensor, hence only_src.\n        std::vector<torch::Tensor> only_src;\n        {\n          std::vector<cuda::CopyScalar<int64_t>> only_src_size;\n          for (std::size_t i = 0; i < src_ids.size(); i++) {\n            only_src.emplace_back(torch::empty(\n                src_ids[i].size(0), sorted_unique_dst_ids[i].options()));\n            auto is_src = thrust::make_transform_iterator(\n                is_dst[i].get(), thrust::logical_not<bool>{});\n            only_src_size.emplace_back(cuda::CopyScalar<int64_t>{});\n            CUB_CALL(\n                DeviceSelect::Flagged, src_ids_ptr[i], is_src,\n                only_src[i].data_ptr<index_t>(), only_src_size[i].get(),\n                src_ids[i].size(0));\n          }\n          stream.synchronize();\n          for (std::size_t i = 0; i < only_src.size(); i++) {\n            only_src[i] =\n                only_src[i].slice(0, 0, static_cast<int64_t>(only_src_size[i]));\n          }\n        }\n\n        // The code block above synchronizes, ensuring safe access to\n        // max_id_src and max_id_dst.\n        if (num_bits == 0) {\n          index_t max_id = 0;\n          for (std::size_t i = 0; i < max_id_src.size(); i++) {\n            max_id = std::max(max_id, static_cast<index_t>(max_id_src[i]));\n            max_id = std::max(max_id, 
static_cast<index_t>(max_id_dst[i]));\n          }\n          num_bits = cuda::NumberOfBits(1ll + max_id);\n        }\n\n        // Sort the only_src tensor so that we can unique it later.\n        std::vector<torch::Tensor> sorted_only_src;\n        for (auto& only_src_i : only_src) {\n          sorted_only_src.emplace_back(Sort<false>(\n              only_src_i.data_ptr<index_t>(), only_src_i.size(0), num_bits));\n        }\n\n        std::vector<torch::Tensor> unique_only_src;\n        std::vector<index_t*> unique_only_src_ptr;\n\n        std::vector<cuda::CopyScalar<int64_t>> unique_only_src_size;\n        for (std::size_t i = 0; i < src_ids.size(); i++) {\n          // Compute the unique operation on the only_src tensor.\n          unique_only_src.emplace_back(\n              torch::empty(only_src[i].size(0), src_ids[i].options()));\n          unique_only_src_ptr.emplace_back(\n              unique_only_src[i].data_ptr<index_t>());\n          unique_only_src_size.emplace_back(cuda::CopyScalar<int64_t>{});\n          CUB_CALL(\n              DeviceSelect::Unique, sorted_only_src[i].data_ptr<index_t>(),\n              unique_only_src_ptr[i], unique_only_src_size[i].get(),\n              only_src[i].size(0));\n        }\n        stream.synchronize();\n        for (std::size_t i = 0; i < unique_only_src.size(); i++) {\n          unique_only_src[i] = unique_only_src[i].slice(\n              0, 0, static_cast<int64_t>(unique_only_src_size[i]));\n        }\n\n        std::vector<torch::Tensor> real_order;\n        for (std::size_t i = 0; i < unique_dst_ids.size(); i++) {\n          real_order.emplace_back(\n              torch::cat({unique_dst_ids[i], unique_only_src[i]}));\n        }\n        // Sort here so that binary search can be used to lookup new_ids.\n        std::vector<torch::Tensor> sorted_order, new_ids;\n        std::vector<index_t*> sorted_order_ptr;\n        std::vector<int64_t*> new_ids_ptr;\n        for (std::size_t i = 0; i < real_order.size(); i++) 
{\n          auto [sorted_order_i, new_ids_i] = Sort(real_order[i], num_bits);\n          sorted_order_ptr.emplace_back(sorted_order_i.data_ptr<index_t>());\n          new_ids_ptr.emplace_back(new_ids_i.data_ptr<int64_t>());\n          sorted_order.emplace_back(std::move(sorted_order_i));\n          new_ids.emplace_back(std::move(new_ids_i));\n        }\n        // Holds the found locations of the src and dst ids in the\n        // sorted_order. Later is used to lookup the new ids of the src_ids\n        // and dst_ids tensors.\n        std::vector<decltype(allocator.AllocateStorage<index_t>(0))>\n            new_dst_ids_loc;\n        for (std::size_t i = 0; i < sorted_order.size(); i++) {\n          new_dst_ids_loc.emplace_back(\n              allocator.AllocateStorage<index_t>(dst_ids[i].size(0)));\n          THRUST_CALL(\n              lower_bound, sorted_order_ptr[i],\n              sorted_order_ptr[i] + sorted_order[i].size(0), dst_ids_ptr[i],\n              dst_ids_ptr[i] + dst_ids[i].size(0), new_dst_ids_loc[i].get());\n        }\n\n        std::vector<cuda::CopyScalar<bool>> all_exist;\n        at::cuda::CUDAEvent all_exist_event;\n        bool should_record = false;\n        // Check if unique_dst_ids includes all dst_ids.\n        for (std::size_t i = 0; i < dst_ids.size(); i++) {\n          if (dst_ids[i].size(0) > 0) {\n            thrust::counting_iterator<int64_t> iota(0);\n            auto equal_it = thrust::make_transform_iterator(\n                iota, EqualityFunc<index_t>{\n                          sorted_order_ptr[i], new_dst_ids_loc[i].get(),\n                          dst_ids_ptr[i]});\n            all_exist.emplace_back(Min(equal_it, dst_ids[i].size(0)));\n            should_record = true;\n          } else {\n            all_exist.emplace_back(cuda::CopyScalar<bool>{});\n          }\n        }\n        if (should_record) all_exist_event.record();\n\n        std::vector<decltype(allocator.AllocateStorage<index_t>(0))>\n            
new_src_ids_loc;\n        for (std::size_t i = 0; i < sorted_order.size(); i++) {\n          new_src_ids_loc.emplace_back(\n              allocator.AllocateStorage<index_t>(src_ids[i].size(0)));\n          THRUST_CALL(\n              lower_bound, sorted_order_ptr[i],\n              sorted_order_ptr[i] + sorted_order[i].size(0), src_ids_ptr[i],\n              src_ids_ptr[i] + src_ids[i].size(0), new_src_ids_loc[i].get());\n        }\n\n        // Finally, lookup the new compact ids of the src and dst tensors\n        // via gather operations.\n        std::vector<torch::Tensor> new_src_ids;\n        for (std::size_t i = 0; i < src_ids.size(); i++) {\n          new_src_ids.emplace_back(torch::empty_like(src_ids[i]));\n          THRUST_CALL(\n              gather, new_src_ids_loc[i].get(),\n              new_src_ids_loc[i].get() + src_ids[i].size(0),\n              new_ids[i].data_ptr<int64_t>(),\n              new_src_ids[i].data_ptr<index_t>());\n        }\n        // Perform check before we gather for the dst indices.\n        for (std::size_t i = 0; i < dst_ids.size(); i++) {\n          if (dst_ids[i].size(0) > 0) {\n            if (should_record) {\n              all_exist_event.synchronize();\n              should_record = false;\n            }\n            if (!static_cast<bool>(all_exist[i])) {\n              throw std::out_of_range(\"Some ids not found.\");\n            }\n          }\n        }\n        std::vector<torch::Tensor> new_dst_ids;\n        for (std::size_t i = 0; i < dst_ids.size(); i++) {\n          new_dst_ids.emplace_back(torch::empty_like(dst_ids[i]));\n          THRUST_CALL(\n              gather, new_dst_ids_loc[i].get(),\n              new_dst_ids_loc[i].get() + dst_ids[i].size(0),\n              new_ids[i].data_ptr<int64_t>(),\n              new_dst_ids[i].data_ptr<index_t>());\n        }\n        std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>\n            results;\n        for (std::size_t i = 0; i < src_ids.size(); 
i++) {\n          results.emplace_back(\n              std::move(real_order[i]), std::move(new_src_ids[i]),\n              std::move(new_dst_ids[i]));\n        }\n        return results;\n      }));\n}\n\nstd::vector<\n    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>\nUniqueAndCompactBatched(\n    const std::vector<torch::Tensor>& src_ids,\n    const std::vector<torch::Tensor>& dst_ids,\n    const std::vector<torch::Tensor>& unique_dst_ids, const int64_t rank,\n    const int64_t world_size) {\n  if (cuda::compute_capability() >= 70) {\n    // Utilizes a hash table based implementation, the mapped id of a vertex\n    // will be monotonically increasing as the first occurrence index of it in\n    // torch.cat([unique_dst_ids, src_ids]). Thus, it is deterministic.\n    return UniqueAndCompactBatchedHashMapBased(\n        src_ids, dst_ids, unique_dst_ids, rank, world_size);\n  }\n  TORCH_CHECK(\n      world_size <= 1,\n      \"Cooperative Minibatching (arXiv:2310.12403) is not supported on \"\n      \"pre-Volta generation GPUs.\");\n  // Utilizes a sort based algorithm, the mapped id of a vertex part of the\n  // src_ids but not part of the unique_dst_ids will be monotonically increasing\n  // as the actual vertex id increases. 
Thus, it is deterministic.\n  auto results3 =\n      UniqueAndCompactBatchedSortBased(src_ids, dst_ids, unique_dst_ids);\n  std::vector<\n      std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>\n      results4;\n  auto offsets = torch::zeros(\n      2 * results3.size(),\n      c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true));\n  for (const auto& [a, b, c] : results3) {\n    auto d = offsets.slice(0, 0, 2);\n    d.data_ptr<int64_t>()[1] = a.size(0);\n    results4.emplace_back(a, b, c, d);\n    offsets = offsets.slice(0, 2);\n  }\n  return results4;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nUniqueAndCompact(\n    const torch::Tensor src_ids, const torch::Tensor dst_ids,\n    const torch::Tensor unique_dst_ids, const int64_t rank,\n    const int64_t world_size) {\n  return UniqueAndCompactBatched(\n      {src_ids}, {dst_ids}, {unique_dst_ids}, rank, world_size)[0];\n}\n\n}  // namespace ops\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/cuda/utils.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n *\n * @file utils.h\n * @brief CUDA utilities.\n */\n\n#ifndef GRAPHBOLT_CUDA_UTILS_H_\n#define GRAPHBOLT_CUDA_UTILS_H_\n\n// The cache line size of GPU.\nconstexpr int GPU_CACHE_LINE_SIZE = 128;\n// The max number of threads per block.\nconstexpr int CUDA_MAX_NUM_THREADS = 1024;\n\nnamespace graphbolt {\nnamespace cuda {\n\n/**\n * @brief Returns the compute capability of the cuda device, e.g. 70 for Volta.\n */\ninline int compute_capability(\n    int device = cuda::GetCurrentStream().device_index()) {\n  int sm_version;\n  CUDA_RUNTIME_CHECK(cub::SmVersion(sm_version, device));\n  return sm_version / 10;\n};\n\n/**\n * @brief Calculate the number of threads needed given the size of the dimension\n * to be processed.\n *\n * It finds the largest power of two that is less than or equal to the minimum\n * of size and CUDA_MAX_NUM_THREADS.\n */\ninline int FindNumThreads(int size) {\n  int ret = 1;\n  while ((ret << 1) <= std::min(size, CUDA_MAX_NUM_THREADS)) {\n    ret <<= 1;\n  }\n  return ret;\n}\n\n/**\n * @brief Calculate the smallest number of bits needed to represent a given\n * range of integers [0, range).\n */\ntemplate <typename T>\nint NumberOfBits(const T& range) {\n  if (range <= 1) {\n    // ranges of 0 or 1 require no bits to store\n    return 0;\n  }\n\n  int bits = 1;\n  const auto urange = static_cast<std::make_unsigned_t<T>>(range);\n  while (bits < static_cast<int>(sizeof(T) * 8) && (1ull << bits) < urange) {\n    ++bits;\n  }\n\n  return bits;\n}\n\n}  // namespace cuda\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_CUDA_UTILS_H_\n"
  },
  {
    "path": "graphbolt/src/expand_indptr.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n * @file expand_indptr.cc\n * @brief ExpandIndptr operators.\n */\n#include <graphbolt/cuda_ops.h>\n#include <torch/autograd.h>\n\n#include \"./macro.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\ntorch::Tensor ExpandIndptr(\n    torch::Tensor indptr, torch::ScalarType dtype,\n    torch::optional<torch::Tensor> node_ids,\n    torch::optional<int64_t> output_size) {\n  if (utils::is_on_gpu(indptr) &&\n      (!node_ids.has_value() || utils::is_on_gpu(node_ids.value()))) {\n    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, \"ExpandIndptr\", {\n      return ExpandIndptrImpl(indptr, dtype, node_ids, output_size);\n    });\n  }\n  if (!node_ids.has_value()) {\n    return torch::repeat_interleave(indptr.diff(), output_size).to(dtype);\n  }\n  return node_ids.value().to(dtype).repeat_interleave(\n      indptr.diff(), 0, output_size);\n}\n\ntorch::Tensor IndptrEdgeIds(\n    torch::Tensor indptr, torch::ScalarType dtype,\n    torch::optional<torch::Tensor> offset,\n    torch::optional<int64_t> output_size) {\n  if (utils::is_on_gpu(indptr) &&\n      (!offset.has_value() || utils::is_on_gpu(offset.value()))) {\n    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(\n        c10::DeviceType::CUDA, \"IndptrEdgeIds\",\n        { return IndptrEdgeIdsImpl(indptr, dtype, offset, output_size); });\n  }\n  TORCH_CHECK(false, \"CPU implementation of IndptrEdgeIds is not available.\");\n}\n\nTORCH_LIBRARY_IMPL(graphbolt, CPU, m) {\n  m.impl(\"expand_indptr\", &ExpandIndptr);\n}\n\n#ifdef GRAPHBOLT_USE_CUDA\nTORCH_LIBRARY_IMPL(graphbolt, CUDA, m) {\n  m.impl(\"expand_indptr\", &ExpandIndptrImpl);\n}\n#endif\n\nTORCH_LIBRARY_IMPL(graphbolt, Autograd, m) {\n  m.impl(\"expand_indptr\", torch::autograd::autogradNotImplementedFallback());\n}\n\nTORCH_LIBRARY_IMPL(graphbolt, CPU, m) {\n  m.impl(\"indptr_edge_ids\", 
&IndptrEdgeIds);\n}\n\n#ifdef GRAPHBOLT_USE_CUDA\nTORCH_LIBRARY_IMPL(graphbolt, CUDA, m) {\n  m.impl(\"indptr_edge_ids\", &IndptrEdgeIdsImpl);\n}\n#endif\n\nTORCH_LIBRARY_IMPL(graphbolt, Autograd, m) {\n  m.impl(\"indptr_edge_ids\", torch::autograd::autogradNotImplementedFallback());\n}\n\n}  // namespace ops\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/expand_indptr.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n * @file expand_indptr.h\n * @brief ExpandIndptr operators.\n */\n#ifndef GRAPHBOLT_EXPAND_INDPTR_H_\n#define GRAPHBOLT_EXPAND_INDPTR_H_\n\n#include <torch/script.h>\n\nnamespace graphbolt {\nnamespace ops {\n\n/**\n * @brief ExpandIndptr implements conversion from a given indptr offset\n * tensor to a COO format tensor. If node_ids is not given, it is assumed to be\n * equal to torch::arange(indptr.size(0) - 1, dtype=dtype).\n *\n * @param indptr       The indptr offset tensor.\n * @param dtype        The dtype of the returned output tensor.\n * @param node_ids     1D tensor representing the node ids.\n * @param output_size  Optional, value of indptr[-1]. Passing it eliminates CPU\n * GPU synchronization.\n *\n * @return The resulting tensor.\n */\ntorch::Tensor ExpandIndptr(\n    torch::Tensor indptr, torch::ScalarType dtype,\n    torch::optional<torch::Tensor> node_ids = torch::nullopt,\n    torch::optional<int64_t> output_size = torch::nullopt);\n\n/**\n * @brief IndptrEdgeIds implements conversion from a given indptr offset\n * tensor to a COO edge ids tensor. For a given indptr [0, 2, 5, 7] and offset\n * tensor [0, 100, 200], the output will be [0, 1, 100, 101, 102, 200, 201]. If\n * offset was not provided, the output would be [0, 1, 0, 1, 2, 0, 1].\n *\n * @param indptr       The indptr offset tensor.\n * @param dtype        The dtype of the returned output tensor.\n * @param offset       The offset tensor.\n * @param output_size  Optional value of indptr[-1]. Passing it eliminates CPU\n * GPU synchronization.\n *\n * @return The resulting tensor.\n */\ntorch::Tensor IndptrEdgeIds(\n    torch::Tensor indptr, torch::ScalarType dtype,\n    torch::optional<torch::Tensor> offset,\n    torch::optional<int64_t> output_size);\n\n}  // namespace ops\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_EXPAND_INDPTR_H_\n"
  },
  {
    "path": "graphbolt/src/feature_cache.cc",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file feature_cache.cc\n * @brief Feature cache implementation on the CPU.\n */\n#include \"./feature_cache.h\"\n\n#include \"./index_select.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace storage {\n\nconstexpr int kIntGrainSize = 64;\n\nFeatureCache::FeatureCache(\n    const std::vector<int64_t>& shape, torch::ScalarType dtype, bool pin_memory)\n    : tensor_(torch::empty(\n          shape, c10::TensorOptions().dtype(dtype).pinned_memory(pin_memory))) {\n}\n\ntorch::Tensor FeatureCache::Query(\n    torch::Tensor positions, torch::Tensor indices, int64_t size) {\n  const bool pin_memory =\n      utils::is_pinned(positions) || utils::is_pinned(indices);\n  std::vector<int64_t> output_shape{\n      tensor_.sizes().begin(), tensor_.sizes().end()};\n  output_shape[0] = size;\n  auto values =\n      torch::empty(output_shape, tensor_.options().pinned_memory(pin_memory));\n  const auto row_bytes = values.slice(0, 0, 1).numel() * values.element_size();\n  auto values_ptr = reinterpret_cast<std::byte*>(values.data_ptr());\n  const auto tensor_ptr = reinterpret_cast<std::byte*>(tensor_.data_ptr());\n  const auto positions_ptr = positions.data_ptr<int64_t>();\n  const auto indices_ptr = indices.data_ptr<int64_t>();\n  
graphbolt::parallel_for_each(\n      0, positions.size(0), kIntGrainSize, [&](const int64_t i) {\n        std::memcpy(\n            values_ptr + indices_ptr[i] * row_bytes,\n            tensor_ptr + positions_ptr[i] * row_bytes, row_bytes);\n      });\n  return values;\n}\n\nc10::intrusive_ptr<Future<torch::Tensor>> FeatureCache::QueryAsync(\n    torch::Tensor positions, torch::Tensor indices, int64_t size) {\n  return async([=] { return Query(positions, indices, size); });\n}\n\ntorch::Tensor FeatureCache::IndexSelect(torch::Tensor positions) {\n  return ops::IndexSelect(tensor_, positions);\n}\n\nvoid FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {\n  TORCH_CHECK(positions.size(0) == values.size(0));\n  if (values.numel() == 0) return;\n  const auto row_bytes = values.slice(0, 0, 1).numel() * values.element_size();\n  TORCH_CHECK(\n      row_bytes == tensor_.slice(0, 0, 1).numel() * tensor_.element_size(),\n      \"The # bytes of a single row should match the cache's.\");\n  auto values_ptr = reinterpret_cast<std::byte*>(values.data_ptr());\n  const auto tensor_ptr = reinterpret_cast<std::byte*>(tensor_.data_ptr());\n  const auto positions_ptr = positions.data_ptr<int64_t>();\n  graphbolt::parallel_for_each(\n      0, positions.size(0), kIntGrainSize, [&](const int64_t i) {\n        const auto position = positions_ptr[i];\n        if (position >= 0) {\n          std::memcpy(\n              tensor_ptr + position * row_bytes, values_ptr + i * row_bytes,\n              row_bytes);\n        }\n      });\n}\n\nc10::intrusive_ptr<Future<void>> FeatureCache::ReplaceAsync(\n    torch::Tensor positions, torch::Tensor values) {\n  return async([=] { return Replace(positions, values); });\n}\n\nc10::intrusive_ptr<FeatureCache> FeatureCache::Create(\n    const std::vector<int64_t>& shape, torch::ScalarType dtype,\n    bool pin_memory) {\n  return c10::make_intrusive<FeatureCache>(shape, dtype, pin_memory);\n}\n\n}  // namespace storage\n}  // namespace 
graphbolt\n"
  },
  {
    "path": "graphbolt/src/feature_cache.h",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file feature_cache.h\n * @brief Feature cache implementation on the CPU.\n */\n#ifndef GRAPHBOLT_FEATURE_CACHE_H_\n#define GRAPHBOLT_FEATURE_CACHE_H_\n\n#include <graphbolt/async.h>\n#include <torch/custom_class.h>\n#include <torch/torch.h>\n\n#include <vector>\n\nnamespace graphbolt {\nnamespace storage {\n\nstruct FeatureCache : public torch::CustomClassHolder {\n  /**\n   * @brief Constructor for the FeatureCache struct.\n   *\n   * @param shape The shape of the cache.\n   * @param dtype The dtype of elements stored in the cache.\n   * @param pin_memory Whether to pin the memory of the cache storage tensor.\n   */\n  FeatureCache(\n      const std::vector<int64_t>& shape, torch::ScalarType dtype,\n      bool pin_memory);\n\n  bool IsPinned() const { return tensor_.is_pinned(); }\n\n  int64_t NumBytes() const { return tensor_.numel() * tensor_.element_size(); }\n\n  /**\n   * @brief The cache query function. 
Allocates an empty tensor `values` with\n   * size as the first dimension and runs\n   * values[indices[:positions.size(0)]] = cache_tensor[positions] before\n   * returning it.\n   *\n   * @param positions The positions of the queried items.\n   * @param indices The indices of the queried items among the original keys.\n   * Only the first portion corresponding to the provided positions tensor is\n   * used, e.g. indices[:positions.size(0)].\n   * @param size The size of the original keys, hence the first dimension of\n   * the output shape.\n   *\n   * @return The values tensor is returned. Its memory is pinned if positions\n   * or indices is pinned.\n   */\n  torch::Tensor Query(\n      torch::Tensor positions, torch::Tensor indices, int64_t size);\n\n  c10::intrusive_ptr<Future<torch::Tensor>> QueryAsync(\n      torch::Tensor positions, torch::Tensor indices, int64_t size);\n\n  /**\n   * @brief The cache tensor index_select returns cache_tensor[positions].\n   *\n   * @param positions The positions of the queried items.\n   *\n   * @return The values tensor is returned on the same device as positions.\n   */\n  torch::Tensor IndexSelect(torch::Tensor positions);\n\n  /**\n   * @brief The cache replace function.\n   *\n   * @param positions The positions to replace in the cache.\n   * @param values The values to be inserted into the cache.\n   */\n  void Replace(torch::Tensor positions, torch::Tensor values);\n\n  c10::intrusive_ptr<Future<void>> ReplaceAsync(\n      torch::Tensor positions, torch::Tensor values);\n\n  static c10::intrusive_ptr<FeatureCache> Create(\n      const std::vector<int64_t>& shape, torch::ScalarType dtype,\n      bool pin_memory);\n\n private:\n  torch::Tensor tensor_;\n};\n\n}  // namespace storage\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_FEATURE_CACHE_H_\n"
  },
  {
    "path": "graphbolt/src/fused_csc_sampling_graph.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file fused_csc_sampling_graph.cc\n * @brief Source file of sampling graph.\n */\n\n#include <graphbolt/cuda_sampling_ops.h>\n#include <graphbolt/fused_csc_sampling_graph.h>\n#include <graphbolt/serialize.h>\n#include <torch/torch.h>\n\n#include <algorithm>\n#include <array>\n#include <cmath>\n#include <limits>\n#include <numeric>\n#include <tuple>\n#include <type_traits>\n#include <vector>\n\n#include \"./expand_indptr.h\"\n#include \"./index_select.h\"\n#include \"./macro.h\"\n#include \"./random.h\"\n#include \"./shared_memory_helper.h\"\n#include \"./utils.h\"\n\nnamespace {\ntorch::optional<torch::Dict<std::string, torch::Tensor>> TensorizeDict(\n    const torch::optional<torch::Dict<std::string, int64_t>>& dict) {\n  if (!dict.has_value()) {\n    return torch::nullopt;\n  }\n  torch::Dict<std::string, torch::Tensor> result;\n  for (const auto& pair : dict.value()) {\n    result.insert(pair.key(), torch::tensor(pair.value(), torch::kInt64));\n  }\n  return result;\n}\n\ntorch::optional<torch::Dict<std::string, int64_t>> DetensorizeDict(\n    const torch::optional<torch::Dict<std::string, torch::Tensor>>& dict) {\n  if (!dict.has_value()) {\n    return torch::nullopt;\n  }\n  torch::Dict<std::string, int64_t> result;\n  for (const auto& pair : dict.value()) {\n    result.insert(pair.key(), pair.value().item<int64_t>());\n  }\n  return result;\n}\n}  // namespace\n\nnamespace graphbolt {\nnamespace sampling {\n\nstatic const int kPickleVersion = 6199;\n\nFusedCSCSamplingGraph::FusedCSCSamplingGraph(\n    const torch::Tensor& indptr, const torch::Tensor& indices,\n    const torch::optional<torch::Tensor>& node_type_offset,\n    const torch::optional<torch::Tensor>& type_per_edge,\n    const torch::optional<NodeTypeToIDMap>& node_type_to_id,\n    const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,\n    const torch::optional<NodeAttrMap>& node_attributes,\n    const torch::optional<EdgeAttrMap>& 
edge_attributes)\n    : indptr_(indptr),\n      indices_(indices),\n      node_type_offset_(node_type_offset),\n      type_per_edge_(type_per_edge),\n      node_type_to_id_(node_type_to_id),\n      edge_type_to_id_(edge_type_to_id),\n      node_attributes_(node_attributes),\n      edge_attributes_(edge_attributes) {\n  TORCH_CHECK(indptr.dim() == 1);\n  TORCH_CHECK(indices.dim() == 1);\n  TORCH_CHECK(indptr.device() == indices.device());\n}\n\nc10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(\n    const torch::Tensor& indptr, const torch::Tensor& indices,\n    const torch::optional<torch::Tensor>& node_type_offset,\n    const torch::optional<torch::Tensor>& type_per_edge,\n    const torch::optional<NodeTypeToIDMap>& node_type_to_id,\n    const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,\n    const torch::optional<NodeAttrMap>& node_attributes,\n    const torch::optional<EdgeAttrMap>& edge_attributes) {\n  if (node_type_offset.has_value()) {\n    auto& offset = node_type_offset.value();\n    TORCH_CHECK(offset.dim() == 1);\n    TORCH_CHECK(node_type_to_id.has_value());\n    TORCH_CHECK(\n        offset.size(0) ==\n        static_cast<int64_t>(node_type_to_id.value().size() + 1));\n  }\n  if (type_per_edge.has_value()) {\n    TORCH_CHECK(type_per_edge.value().dim() == 1);\n    TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0));\n    TORCH_CHECK(edge_type_to_id.has_value());\n  }\n  if (node_attributes.has_value()) {\n    for (const auto& pair : node_attributes.value()) {\n      TORCH_CHECK(\n          pair.value().size(0) == indptr.size(0) - 1,\n          \"Expected node_attribute.size(0) and num_nodes to be equal, \"\n          \"but node_attribute.size(0) was \",\n          pair.value().size(0), \", and num_nodes was \", indptr.size(0) - 1,\n          \".\");\n    }\n  }\n  if (edge_attributes.has_value()) {\n    for (const auto& pair : edge_attributes.value()) {\n      TORCH_CHECK(\n          pair.value().size(0) == 
indices.size(0),\n          \"Expected edge_attribute.size(0) and num_edges to be equal, \"\n          \"but edge_attribute.size(0) was \",\n          pair.value().size(0), \", and num_edges was \", indices.size(0), \".\");\n    }\n  }\n  return c10::make_intrusive<FusedCSCSamplingGraph>(\n      indptr, indices, node_type_offset, type_per_edge, node_type_to_id,\n      edge_type_to_id, node_attributes, edge_attributes);\n}\n\nvoid FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {\n  const int64_t magic_num =\n      read_from_archive<int64_t>(archive, \"FusedCSCSamplingGraph/magic_num\");\n  TORCH_CHECK(\n      magic_num == kCSCSamplingGraphSerializeMagic,\n      \"Magic numbers mismatch when loading FusedCSCSamplingGraph.\");\n  indptr_ =\n      read_from_archive<torch::Tensor>(archive, \"FusedCSCSamplingGraph/indptr\");\n  indices_ = read_from_archive<torch::Tensor>(\n      archive, \"FusedCSCSamplingGraph/indices\");\n  if (read_from_archive<bool>(\n          archive, \"FusedCSCSamplingGraph/has_node_type_offset\")) {\n    node_type_offset_ = read_from_archive<torch::Tensor>(\n        archive, \"FusedCSCSamplingGraph/node_type_offset\");\n  }\n  if (read_from_archive<bool>(\n          archive, \"FusedCSCSamplingGraph/has_type_per_edge\")) {\n    type_per_edge_ = read_from_archive<torch::Tensor>(\n        archive, \"FusedCSCSamplingGraph/type_per_edge\");\n  }\n\n  if (read_from_archive<bool>(\n          archive, \"FusedCSCSamplingGraph/has_node_type_to_id\")) {\n    node_type_to_id_ = read_from_archive<NodeTypeToIDMap>(\n        archive, \"FusedCSCSamplingGraph/node_type_to_id\");\n  }\n\n  if (read_from_archive<bool>(\n          archive, \"FusedCSCSamplingGraph/has_edge_type_to_id\")) {\n    edge_type_to_id_ = read_from_archive<EdgeTypeToIDMap>(\n        archive, \"FusedCSCSamplingGraph/edge_type_to_id\");\n  }\n\n  if (read_from_archive<bool>(\n          archive, \"FusedCSCSamplingGraph/has_node_attributes\")) {\n    node_attributes_ = 
read_from_archive<NodeAttrMap>(\n        archive, \"FusedCSCSamplingGraph/node_attributes\");\n  }\n  if (read_from_archive<bool>(\n          archive, \"FusedCSCSamplingGraph/has_edge_attributes\")) {\n    edge_attributes_ = read_from_archive<EdgeAttrMap>(\n        archive, \"FusedCSCSamplingGraph/edge_attributes\");\n  }\n}\n\nvoid FusedCSCSamplingGraph::Save(\n    torch::serialize::OutputArchive& archive) const {\n  archive.write(\n      \"FusedCSCSamplingGraph/magic_num\", kCSCSamplingGraphSerializeMagic);\n  archive.write(\"FusedCSCSamplingGraph/indptr\", indptr_);\n  archive.write(\"FusedCSCSamplingGraph/indices\", indices_);\n  archive.write(\n      \"FusedCSCSamplingGraph/has_node_type_offset\",\n      node_type_offset_.has_value());\n  if (node_type_offset_) {\n    archive.write(\n        \"FusedCSCSamplingGraph/node_type_offset\", node_type_offset_.value());\n  }\n  archive.write(\n      \"FusedCSCSamplingGraph/has_type_per_edge\", type_per_edge_.has_value());\n  if (type_per_edge_) {\n    archive.write(\n        \"FusedCSCSamplingGraph/type_per_edge\", type_per_edge_.value());\n  }\n  archive.write(\n      \"FusedCSCSamplingGraph/has_node_type_to_id\",\n      node_type_to_id_.has_value());\n  if (node_type_to_id_) {\n    archive.write(\n        \"FusedCSCSamplingGraph/node_type_to_id\", node_type_to_id_.value());\n  }\n  archive.write(\n      \"FusedCSCSamplingGraph/has_edge_type_to_id\",\n      edge_type_to_id_.has_value());\n  if (edge_type_to_id_) {\n    archive.write(\n        \"FusedCSCSamplingGraph/edge_type_to_id\", edge_type_to_id_.value());\n  }\n  archive.write(\n      \"FusedCSCSamplingGraph/has_node_attributes\",\n      node_attributes_.has_value());\n  if (node_attributes_) {\n    archive.write(\n        \"FusedCSCSamplingGraph/node_attributes\", node_attributes_.value());\n  }\n  archive.write(\n      \"FusedCSCSamplingGraph/has_edge_attributes\",\n      edge_attributes_.has_value());\n  if (edge_attributes_) {\n    archive.write(\n        
\"FusedCSCSamplingGraph/edge_attributes\", edge_attributes_.value());\n  }\n}\n\nvoid FusedCSCSamplingGraph::SetState(\n    const torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>&\n        state) {\n  // State is a dict of dicts. The tensor-type attributes are stored in the dict\n  // with key \"independent_tensors\". The dict-type attributes (edge_attributes)\n  // are stored directly with the their name as the key.\n  const auto& independent_tensors = state.at(\"independent_tensors\");\n  TORCH_CHECK(\n      independent_tensors.at(\"version_number\")\n          .equal(torch::tensor({kPickleVersion})),\n      \"Version number mismatches when loading pickled FusedCSCSamplingGraph.\")\n  indptr_ = independent_tensors.at(\"indptr\");\n  indices_ = independent_tensors.at(\"indices\");\n  if (independent_tensors.find(\"node_type_offset\") !=\n      independent_tensors.end()) {\n    node_type_offset_ = independent_tensors.at(\"node_type_offset\");\n  }\n  if (independent_tensors.find(\"type_per_edge\") != independent_tensors.end()) {\n    type_per_edge_ = independent_tensors.at(\"type_per_edge\");\n  }\n  if (state.find(\"node_type_to_id\") != state.end()) {\n    node_type_to_id_ = DetensorizeDict(state.at(\"node_type_to_id\"));\n  }\n  if (state.find(\"edge_type_to_id\") != state.end()) {\n    edge_type_to_id_ = DetensorizeDict(state.at(\"edge_type_to_id\"));\n  }\n  if (state.find(\"node_attributes\") != state.end()) {\n    node_attributes_ = state.at(\"node_attributes\");\n  }\n  if (state.find(\"edge_attributes\") != state.end()) {\n    edge_attributes_ = state.at(\"edge_attributes\");\n  }\n}\n\ntorch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>\nFusedCSCSamplingGraph::GetState() const {\n  // State is a dict of dicts. The tensor-type attributes are stored in the dict\n  // with key \"independent_tensors\". 
The dict-type attributes (edge_attributes)\n  // are stored directly with the their name as the key.\n  torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>> state;\n  torch::Dict<std::string, torch::Tensor> independent_tensors;\n  // Serialization version number. It indicates the serialization method of the\n  // whole state.\n  independent_tensors.insert(\"version_number\", torch::tensor({kPickleVersion}));\n  independent_tensors.insert(\"indptr\", indptr_);\n  independent_tensors.insert(\"indices\", indices_);\n  if (node_type_offset_.has_value()) {\n    independent_tensors.insert(\"node_type_offset\", node_type_offset_.value());\n  }\n  if (type_per_edge_.has_value()) {\n    independent_tensors.insert(\"type_per_edge\", type_per_edge_.value());\n  }\n  state.insert(\"independent_tensors\", independent_tensors);\n  if (node_type_to_id_.has_value()) {\n    state.insert(\"node_type_to_id\", TensorizeDict(node_type_to_id_).value());\n  }\n  if (edge_type_to_id_.has_value()) {\n    state.insert(\"edge_type_to_id\", TensorizeDict(edge_type_to_id_).value());\n  }\n  if (node_attributes_.has_value()) {\n    state.insert(\"node_attributes\", node_attributes_.value());\n  }\n  if (edge_attributes_.has_value()) {\n    state.insert(\"edge_attributes\", edge_attributes_.value());\n  }\n  return state;\n}\n\nc10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(\n    const torch::Tensor& nodes) const {\n  if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr_) &&\n      utils::is_accessible_from_gpu(indices_) &&\n      (!type_per_edge_.has_value() ||\n       utils::is_accessible_from_gpu(type_per_edge_.value()))) {\n    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, \"InSubgraph\", {\n      return ops::InSubgraph(indptr_, indices_, nodes, type_per_edge_);\n    });\n  }\n  std::vector<torch::Tensor> tensors{indices_};\n  if (type_per_edge_.has_value()) {\n    tensors.push_back(*type_per_edge_);\n  }\n\n  auto 
[output_indptr, results] =\n      ops::IndexSelectCSCBatched(indptr_, tensors, nodes, true, torch::nullopt);\n  torch::optional<torch::Tensor> type_per_edge;\n  if (type_per_edge_.has_value()) {\n    type_per_edge = results.at(1);\n  }\n\n  return c10::make_intrusive<FusedSampledSubgraph>(\n      // original_row_node_ids is not computed here and is unused.\n      output_indptr, results.at(0), results.back(), nodes, torch::nullopt,\n      type_per_edge);\n}\n\n/**\n * @brief Get a lambda function which counts the number of the neighbors to be\n * sampled.\n *\n * @param fanouts The number of edges to be sampled for each node with or\n * without considering edge types.\n * @param replace Boolean indicating whether the sample is performed with or\n * without replacement. If True, a value can be selected multiple times.\n * Otherwise, each value can be selected only once.\n * @param type_per_edge A tensor representing the type of each edge, if\n * present.\n * @param probs_or_mask Optional tensor containing the (unnormalized)\n * probabilities associated with each neighboring edge of a node in the original\n * graph. 
It must be a 1D floating-point tensor with the number of elements\n * equal to the number of edges in the graph.\n *\n * @return A lambda function (int64_t seed_offset, int64_t offset, int64_t\n * num_neighbors) -> torch::Tensor, which takes seed offset (the offset of the\n * seed to sample), offset (the starting edge ID of the given node) and\n * num_neighbors (number of neighbors) as params and returns the pick number of\n * the given node.\n */\nauto GetNumPickFn(\n    const std::vector<int64_t>& fanouts, bool replace,\n    const torch::optional<torch::Tensor>& type_per_edge,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    bool with_seed_offsets) {\n  // If fanouts.size() > 1, returns the total number of all edge types of the\n  // given node.\n  return [&fanouts, replace, &probs_or_mask, &type_per_edge, with_seed_offsets](\n             int64_t offset, int64_t num_neighbors, auto num_picked_ptr,\n             int64_t seed_index,\n             const std::vector<int64_t>& etype_id_to_num_picked_offset) {\n    if (fanouts.size() > 1) {\n      NumPickByEtype(\n          with_seed_offsets, fanouts, replace, type_per_edge.value(),\n          probs_or_mask, offset, num_neighbors, num_picked_ptr, seed_index,\n          etype_id_to_num_picked_offset);\n    } else {\n      NumPick(\n          fanouts[0], replace, probs_or_mask, offset, num_neighbors,\n          num_picked_ptr + seed_index);\n    }\n  };\n}\n\nauto GetTemporalNumPickFn(\n    torch::Tensor seed_timestamp, torch::Tensor csc_indices,\n    const std::vector<int64_t>& fanouts, bool replace,\n    const torch::optional<torch::Tensor>& type_per_edge,\n    const torch::optional<torch::Tensor>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp) {\n  // If fanouts.size() > 1, returns the total number of all edge types of the\n  // given node.\n  return 
[&seed_timestamp, &csc_indices, &fanouts, replace,\n          &seed_pre_time_window, &probs_or_mask, &type_per_edge,\n          &node_timestamp, &edge_timestamp](\n             int64_t seed_offset, int64_t offset, int64_t num_neighbors) {\n    if (fanouts.size() > 1) {\n      return TemporalNumPickByEtype(\n          seed_timestamp, csc_indices, fanouts, replace, type_per_edge.value(),\n          seed_pre_time_window, probs_or_mask, node_timestamp, edge_timestamp,\n          seed_offset, offset, num_neighbors);\n    } else {\n      return TemporalNumPick(\n          seed_timestamp, csc_indices, fanouts[0], replace,\n          seed_pre_time_window, probs_or_mask, node_timestamp, edge_timestamp,\n          seed_offset, offset, num_neighbors);\n    }\n  };\n}\n\n/**\n * @brief Get a lambda function which contains the sampling process.\n *\n * @param fanouts The number of edges to be sampled for each node with or\n * without considering edge types.\n * @param replace Boolean indicating whether the sample is performed with or\n * without replacement. If True, a value can be selected multiple times.\n * Otherwise, each value can be selected only once.\n * @param options Tensor options specifying the desired data type of the result.\n * @param type_per_edge A tensor representing the type of each edge, if\n * present.\n * @param probs_or_mask Optional tensor containing the (unnormalized)\n * probabilities associated with each neighboring edge of a node in the original\n * graph. 
It must be a 1D floating-point tensor with the number of elements\n * equal to the number of edges in the graph.\n * @param args Contains sampling algorithm specific arguments.\n *\n * @return A lambda function: (int64_t seed_offset, int64_t offset, int64_t\n * num_neighbors, PickedType* picked_data_ptr) -> torch::Tensor, which takes\n * seed_offset (the offset of the seed to sample), offset (the starting edge ID\n * of the given node) and num_neighbors (number of neighbors) as params and puts\n * the picked neighbors at the address specified by picked_data_ptr.\n */\ntemplate <SamplerType S>\nauto GetPickFn(\n    const std::vector<int64_t>& fanouts, bool replace,\n    const torch::TensorOptions& options,\n    const torch::optional<torch::Tensor>& type_per_edge,\n    const torch::optional<torch::Tensor>& probs_or_mask, bool with_seed_offsets,\n    SamplerArgs<S> args) {\n  return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args,\n          with_seed_offsets](\n             int64_t offset, int64_t num_neighbors, auto picked_data_ptr,\n             int64_t seed_offset, auto subgraph_indptr_ptr,\n             const std::vector<int64_t>& etype_id_to_num_picked_offset) {\n    // If fanouts.size() > 1, perform sampling for each edge type of each\n    // node; otherwise just sample once for each node with no regard of edge\n    // types.\n    if (fanouts.size() > 1) {\n      return PickByEtype(\n          with_seed_offsets, offset, num_neighbors, fanouts, replace, options,\n          type_per_edge.value(), probs_or_mask, args, picked_data_ptr,\n          seed_offset, subgraph_indptr_ptr, etype_id_to_num_picked_offset);\n    } else {\n      picked_data_ptr += subgraph_indptr_ptr[seed_offset];\n      int64_t num_sampled = Pick(\n          offset, num_neighbors, fanouts[0], replace, options, probs_or_mask,\n          args, picked_data_ptr);\n      if (type_per_edge) {\n        std::sort(picked_data_ptr, picked_data_ptr + num_sampled);\n      }\n      return 
num_sampled;\n    }\n  };\n}\n\ntemplate <SamplerType S>\nauto GetTemporalPickFn(\n    torch::Tensor seed_timestamp, torch::Tensor csc_indices,\n    const std::vector<int64_t>& fanouts, bool replace,\n    const torch::TensorOptions& options,\n    const torch::optional<torch::Tensor>& type_per_edge,\n    const torch::optional<torch::Tensor>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args) {\n  return [&seed_timestamp, &csc_indices, &fanouts, replace, &options,\n          &type_per_edge, &seed_pre_time_window, &probs_or_mask,\n          &node_timestamp, &edge_timestamp, args](\n             int64_t seed_offset, int64_t offset, int64_t num_neighbors,\n             auto picked_data_ptr) {\n    // If fanouts.size() > 1, perform sampling for each edge type of each\n    // node; otherwise just sample once for each node with no regard of edge\n    // types.\n    if (fanouts.size() > 1) {\n      return TemporalPickByEtype(\n          seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,\n          fanouts, replace, options, type_per_edge.value(),\n          seed_pre_time_window, probs_or_mask, node_timestamp, edge_timestamp,\n          args, picked_data_ptr);\n    } else {\n      int64_t num_sampled = TemporalPick(\n          seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,\n          fanouts[0], replace, options, seed_pre_time_window, probs_or_mask,\n          node_timestamp, edge_timestamp, args, picked_data_ptr);\n      if (type_per_edge.has_value()) {\n        std::sort(picked_data_ptr, picked_data_ptr + num_sampled);\n      }\n      return num_sampled;\n    }\n  };\n}\n\ntemplate <TemporalOption Temporal, typename NumPickFn, typename PickFn>\nc10::intrusive_ptr<FusedSampledSubgraph>\nFusedCSCSamplingGraph::SampleNeighborsImpl(\n    const torch::Tensor& seeds,\n    const 
torch::optional<std::vector<int64_t>>& seed_offsets,\n    const std::vector<int64_t>& fanouts, NumPickFn num_pick_fn,\n    PickFn pick_fn) const {\n  const int64_t num_seeds = seeds.size(0);\n  const auto indptr_options = indptr_.options();\n\n  // Calculate GrainSize for parallel_for.\n  // Set the default grain size to 64.\n  const int64_t grain_size = 64;\n  torch::Tensor picked_eids;\n  torch::Tensor subgraph_indptr;\n  torch::Tensor subgraph_indices;\n  torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;\n  torch::optional<torch::Tensor> edge_offsets = torch::nullopt;\n\n  bool with_seed_offsets = seed_offsets.has_value();\n  bool hetero_with_seed_offsets = with_seed_offsets && fanouts.size() > 1 &&\n                                  Temporal == TemporalOption::NOT_TEMPORAL;\n\n  // Get the number of edge types. If it's homo or if the size of fanouts is 1\n  // (hetero graph but sampled as a homo graph), set num_etypes as 1.\n  // In temporal sampling, this will not be used for now since the logic hasn't\n  // been adopted for temporal sampling.\n  const int64_t num_etypes =\n      (edge_type_to_id_.has_value() && hetero_with_seed_offsets)\n          ? edge_type_to_id_->size()\n          : 1;\n  std::vector<int64_t> etype_id_to_src_ntype_id(num_etypes);\n  std::vector<int64_t> etype_id_to_dst_ntype_id(num_etypes);\n  torch::optional<torch::Tensor> subgraph_indptr_substract = torch::nullopt;\n  // The pick numbers are stored in a single tensor by the order of etype. Each\n  // etype corresponds to a group of seeds whose ntype are the same as the\n  // dst_type. 
`etype_id_to_num_picked_offset` indicates the beginning offset\n  // where each etype's corresponding seeds' pick numbers are stored in the pick\n  // number tensor.\n  std::vector<int64_t> etype_id_to_num_picked_offset(num_etypes + 1);\n  if (hetero_with_seed_offsets) {\n    for (auto& etype_and_id : edge_type_to_id_.value()) {\n      auto etype = etype_and_id.key();\n      auto id = etype_and_id.value();\n      auto [src_type, dst_type] = utils::parse_src_dst_ntype_from_etype(etype);\n      auto dst_ntype_id = node_type_to_id_->at(dst_type);\n      etype_id_to_src_ntype_id[id] = node_type_to_id_->at(src_type);\n      etype_id_to_dst_ntype_id[id] = dst_ntype_id;\n      etype_id_to_num_picked_offset[id + 1] =\n          seed_offsets->at(dst_ntype_id + 1) - seed_offsets->at(dst_ntype_id) +\n          1;\n    }\n    std::partial_sum(\n        etype_id_to_num_picked_offset.begin(),\n        etype_id_to_num_picked_offset.end(),\n        etype_id_to_num_picked_offset.begin());\n  } else {\n    etype_id_to_dst_ntype_id[0] = 0;\n    etype_id_to_num_picked_offset[1] = num_seeds + 1;\n  }\n  // `num_rows` indicates the length of `num_picked_neighbors_per_node`, which\n  // is used for storing pick numbers. In non-temporal hetero sampling, it\n  // equals to sum_{etype} #seeds with ntype=dst_type(etype). 
In homo sampling,\n  // it equals to `num_seeds`.\n  const int64_t num_rows = etype_id_to_num_picked_offset[num_etypes];\n  torch::Tensor num_picked_neighbors_per_node =\n      // Need to use zeros because all nodes don't have all etypes.\n      torch::zeros({num_rows}, indptr_options);\n\n  AT_DISPATCH_INDEX_TYPES(\n      indptr_.scalar_type(), \"SampleNeighborsImplWrappedWithIndptr\", ([&] {\n        using indptr_t = index_t;\n        AT_DISPATCH_INDEX_TYPES(\n            seeds.scalar_type(), \"SampleNeighborsImplWrappedWithSeeds\", ([&] {\n              using seeds_t = index_t;\n              const auto indptr_data = indptr_.data_ptr<indptr_t>();\n              const auto num_picked_neighbors_data_ptr =\n                  num_picked_neighbors_per_node.data_ptr<indptr_t>();\n              num_picked_neighbors_data_ptr[0] = 0;\n              const auto seeds_data_ptr = seeds.data_ptr<seeds_t>();\n\n              // Step 1. Calculate pick number of each node.\n              torch::parallel_for(\n                  0, num_seeds, grain_size, [&](int64_t begin, int64_t end) {\n                    for (int64_t i = begin; i < end; ++i) {\n                      const auto nid = seeds_data_ptr[i];\n                      TORCH_CHECK(\n                          nid >= 0 && nid < NumNodes(),\n                          \"The seed nodes' IDs should fall within the range of \"\n                          \"the graph's node IDs.\");\n                      const auto offset = indptr_data[nid];\n                      const auto num_neighbors = indptr_data[nid + 1] - offset;\n\n                      if constexpr (Temporal == TemporalOption::TEMPORAL) {\n                        num_picked_neighbors_data_ptr[i + 1] =\n                            num_neighbors == 0\n                                ? 
0\n                                : num_pick_fn(i, offset, num_neighbors);\n                      } else {\n                        const auto seed_type_id =\n                            (hetero_with_seed_offsets)\n                                ? std::upper_bound(\n                                      seed_offsets->begin(),\n                                      seed_offsets->end(), i) -\n                                      seed_offsets->begin() - 1\n                                : 0;\n                        // `seed_index` indicates the index of the current\n                        // seed within the group of seeds which have the same\n                        // node type.\n                        const auto seed_index =\n                            (hetero_with_seed_offsets)\n                                ? i - seed_offsets->at(seed_type_id)\n                                : i;\n                        num_pick_fn(\n                            offset, num_neighbors,\n                            num_picked_neighbors_data_ptr + 1, seed_index,\n                            etype_id_to_num_picked_offset);\n                      }\n                    }\n                  });\n\n              // Step 2. Calculate prefix sum to get total length and offsets of\n              // each node. 
It's also the indptr of the generated subgraph.\n              subgraph_indptr = num_picked_neighbors_per_node.cumsum(\n                  0, indptr_.scalar_type());\n              auto subgraph_indptr_data_ptr =\n                  subgraph_indptr.data_ptr<indptr_t>();\n\n              if (hetero_with_seed_offsets) {\n                torch::Tensor num_picked_offset_tensor =\n                    torch::empty({num_etypes + 1}, indptr_options);\n                const auto num_picked_offset_data_ptr =\n                    num_picked_offset_tensor.data_ptr<indptr_t>();\n                std::copy(\n                    etype_id_to_num_picked_offset.begin(),\n                    etype_id_to_num_picked_offset.end(),\n                    num_picked_offset_data_ptr);\n                torch::Tensor substract_offset =\n                    torch::empty({num_etypes}, indptr_options);\n                const auto substract_offset_data_ptr =\n                    substract_offset.data_ptr<indptr_t>();\n                for (auto i = 0; i < num_etypes; ++i) {\n                  // Collect the total pick number subtract offsets.\n                  substract_offset_data_ptr[i] = subgraph_indptr_data_ptr\n                      [etype_id_to_num_picked_offset[i]];\n                }\n                subgraph_indptr_substract = ops::ExpandIndptr(\n                    num_picked_offset_tensor, indptr_.scalar_type(),\n                    substract_offset);\n              }\n\n              // When doing non-temporal hetero sampling, we generate an\n              // edge_offsets tensor.\n              if (hetero_with_seed_offsets) {\n                edge_offsets = torch::empty({num_etypes + 1}, indptr_options);\n                auto edge_offsets_data_ptr =\n                    edge_offsets.value().data_ptr<indptr_t>();\n                edge_offsets_data_ptr[0] = 0;\n                for (auto i = 0; i < num_etypes; ++i) {\n                  edge_offsets_data_ptr[i + 1] = 
subgraph_indptr_data_ptr\n                      [etype_id_to_num_picked_offset[i + 1] - 1];\n                }\n              }\n\n              // Step 3. Allocate the tensor for picked neighbors.\n              const auto total_length =\n                  subgraph_indptr.data_ptr<indptr_t>()[num_rows - 1];\n              picked_eids = torch::empty({total_length}, indptr_options);\n              subgraph_indices =\n                  torch::empty({total_length}, indices_.options());\n              if (!hetero_with_seed_offsets && type_per_edge_.has_value()) {\n                subgraph_type_per_edge = torch::empty(\n                    {total_length}, type_per_edge_.value().options());\n              }\n\n              auto picked_eids_data_ptr = picked_eids.data_ptr<indptr_t>();\n              torch::parallel_for(\n                  0, num_seeds, grain_size, [&](int64_t begin, int64_t end) {\n                    for (int64_t i = begin; i < end; ++i) {\n                      const auto nid = seeds_data_ptr[i];\n                      const auto offset = indptr_data[nid];\n                      const auto num_neighbors = indptr_data[nid + 1] - offset;\n                      auto picked_number = 0;\n                      const auto seed_type_id =\n                          (hetero_with_seed_offsets)\n                              ? std::upper_bound(\n                                    seed_offsets->begin(), seed_offsets->end(),\n                                    i) -\n                                    seed_offsets->begin() - 1\n                              : 0;\n                      const auto seed_index =\n                          (hetero_with_seed_offsets)\n                              ? i - seed_offsets->at(seed_type_id)\n                              : i;\n\n                      // Step 4. 
Pick neighbors for each node.\n                      if constexpr (Temporal == TemporalOption::TEMPORAL) {\n                        picked_number = num_picked_neighbors_data_ptr[i + 1];\n                        auto picked_offset = subgraph_indptr_data_ptr[i];\n                        if (picked_number > 0) {\n                          auto actual_picked_count = pick_fn(\n                              i, offset, num_neighbors,\n                              picked_eids_data_ptr + picked_offset);\n                          TORCH_CHECK(\n                              actual_picked_count == picked_number,\n                              \"Actual picked count doesn't match the calculated\"\n                              \" pick number.\");\n                        }\n                      } else {\n                        picked_number = pick_fn(\n                            offset, num_neighbors, picked_eids_data_ptr,\n                            seed_index, subgraph_indptr_data_ptr,\n                            etype_id_to_num_picked_offset);\n                        if (!hetero_with_seed_offsets) {\n                          TORCH_CHECK(\n                              num_picked_neighbors_data_ptr[i + 1] ==\n                                  picked_number,\n                              \"Actual picked count doesn't match the calculated\"\n                              \" pick number.\");\n                        }\n                      }\n\n                      // Step 5. 
Calculate other attributes and return the\n                      // subgraph.\n                      if (picked_number > 0) {\n                        // indices dtype and seeds dtype is required to be same.\n                        using index_t = seeds_t;\n                        auto subgraph_indices_data_ptr =\n                            subgraph_indices.data_ptr<index_t>();\n                        auto indices_data_ptr = indices_.data_ptr<index_t>();\n                        for (auto i = 0; i < num_etypes; ++i) {\n                          if (etype_id_to_dst_ntype_id[i] != seed_type_id)\n                            continue;\n                          const auto indptr_offset =\n                              with_seed_offsets\n                                  ? etype_id_to_num_picked_offset[i] +\n                                        seed_index\n                                  : seed_index;\n                          const auto picked_begin =\n                              subgraph_indptr_data_ptr[indptr_offset];\n                          const auto picked_end =\n                              subgraph_indptr_data_ptr[indptr_offset + 1];\n                          for (auto j = picked_begin; j < picked_end; ++j) {\n                            subgraph_indices_data_ptr[j] =\n                                indices_data_ptr[picked_eids_data_ptr[j]];\n                            if (hetero_with_seed_offsets &&\n                                node_type_offset_.has_value()) {\n                              // Substract the node type offset from\n                              // subgraph indices. 
Assuming\n                              // node_type_offset has the same dtype as\n                              // indices.\n                              auto node_type_offset_data =\n                                  node_type_offset_.value().data_ptr<index_t>();\n                              subgraph_indices_data_ptr[j] -=\n                                  node_type_offset_data\n                                      [etype_id_to_src_ntype_id[i]];\n                            }\n                          }\n                        }\n\n                        if (!hetero_with_seed_offsets &&\n                            type_per_edge_.has_value()) {\n                          // When hetero graph is sampled as a homo graph, we\n                          // still generate type_per_edge tensor for this\n                          // situation.\n                          AT_DISPATCH_INTEGRAL_TYPES(\n                              subgraph_type_per_edge.value().scalar_type(),\n                              \"IndexSelectTypePerEdge\", ([&] {\n                                auto subgraph_type_per_edge_data_ptr =\n                                    subgraph_type_per_edge.value()\n                                        .data_ptr<scalar_t>();\n                                auto type_per_edge_data_ptr =\n                                    type_per_edge_.value().data_ptr<scalar_t>();\n                                const auto picked_offset =\n                                    subgraph_indptr_data_ptr[seed_index];\n                                for (auto j = picked_offset;\n                                     j < picked_offset + picked_number; ++j)\n                                  subgraph_type_per_edge_data_ptr[j] =\n                                      type_per_edge_data_ptr\n                                          [picked_eids_data_ptr[j]];\n                              }));\n                        }\n                      }\n                    }\n    
              });\n            }));\n      }));\n\n  if (subgraph_indptr_substract.has_value()) {\n    subgraph_indptr -= subgraph_indptr_substract.value();\n  }\n\n  return c10::make_intrusive<FusedSampledSubgraph>(\n      subgraph_indptr, subgraph_indices, picked_eids, seeds, torch::nullopt,\n      subgraph_type_per_edge, edge_offsets);\n}\n\nc10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(\n    torch::optional<torch::Tensor> seeds,\n    torch::optional<std::vector<int64_t>> seed_offsets,\n    const std::vector<int64_t>& fanouts, bool replace, bool layer,\n    bool returning_indices_is_optional,\n    torch::optional<torch::Tensor> probs_or_mask,\n    torch::optional<torch::Tensor> random_seed,\n    double seed2_contribution) const {\n  // If seeds does not have a value, then we expect all arguments to be resident\n  // on the GPU. If seeds has a value, then we expect them to be accessible from\n  // GPU. This is required for the dispatch to work when CUDA is not available.\n  if (((!seeds.has_value() && utils::is_on_gpu(indptr_) &&\n        utils::is_on_gpu(indices_) &&\n        (!probs_or_mask.has_value() ||\n         utils::is_on_gpu(probs_or_mask.value())) &&\n        (!type_per_edge_.has_value() ||\n         utils::is_on_gpu(type_per_edge_.value()))) ||\n       (seeds.has_value() && utils::is_on_gpu(seeds.value()) &&\n        utils::is_accessible_from_gpu(indptr_) &&\n        utils::is_accessible_from_gpu(indices_) &&\n        (!probs_or_mask.has_value() ||\n         utils::is_accessible_from_gpu(probs_or_mask.value())) &&\n        (!type_per_edge_.has_value() ||\n         utils::is_accessible_from_gpu(type_per_edge_.value())))) &&\n      !replace) {\n    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(\n        c10::DeviceType::CUDA, \"SampleNeighbors\", {\n          return ops::SampleNeighbors(\n              indptr_, indices_, seeds, seed_offsets, fanouts, replace, layer,\n              returning_indices_is_optional, type_per_edge_, 
probs_or_mask,\n              node_type_offset_, node_type_to_id_, edge_type_to_id_,\n              random_seed, seed2_contribution);\n        });\n  }\n  TORCH_CHECK(seeds.has_value(), \"Nodes can not be None on the CPU.\");\n\n  if (probs_or_mask.has_value()) {\n    // Note probs will be passed as input for 'torch.multinomial' in deeper\n    // stack, which doesn't support 'torch.half' and 'torch.bool' data types. To\n    // avoid crashes, convert 'probs_or_mask' to 'float32' data type.\n    if (probs_or_mask.value().dtype() == torch::kBool ||\n        probs_or_mask.value().dtype() == torch::kFloat16) {\n      probs_or_mask = probs_or_mask.value().to(torch::kFloat32);\n    }\n  }\n\n  bool with_seed_offsets = seed_offsets.has_value();\n\n  if (layer) {\n    if (random_seed.has_value() && random_seed->numel() >= 2) {\n      SamplerArgs<SamplerType::LABOR_DEPENDENT> args{\n          indices_,\n          {random_seed.value(), static_cast<float>(seed2_contribution)},\n          NumNodes()};\n      return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(\n          seeds.value(), seed_offsets, fanouts,\n          GetNumPickFn(\n              fanouts, replace, type_per_edge_, probs_or_mask,\n              with_seed_offsets),\n          GetPickFn(\n              fanouts, replace, indptr_.options(), type_per_edge_,\n              probs_or_mask, with_seed_offsets, args));\n    } else {\n      auto args = [&] {\n        if (random_seed.has_value() && random_seed->numel() == 1) {\n          return SamplerArgs<SamplerType::LABOR>{\n              indices_, random_seed.value(), NumNodes()};\n        } else {\n          return SamplerArgs<SamplerType::LABOR>{\n              indices_,\n              RandomEngine::ThreadLocal()->RandInt(\n                  static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),\n              NumNodes()};\n        }\n      }();\n      return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(\n          seeds.value(), seed_offsets, 
fanouts,\n          GetNumPickFn(\n              fanouts, replace, type_per_edge_, probs_or_mask,\n              with_seed_offsets),\n          GetPickFn(\n              fanouts, replace, indptr_.options(), type_per_edge_,\n              probs_or_mask, with_seed_offsets, args));\n    }\n  } else {\n    SamplerArgs<SamplerType::NEIGHBOR> args;\n    return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(\n        seeds.value(), seed_offsets, fanouts,\n        GetNumPickFn(\n            fanouts, replace, type_per_edge_, probs_or_mask, with_seed_offsets),\n        GetPickFn(\n            fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,\n            with_seed_offsets, args));\n  }\n}\n\nc10::intrusive_ptr<Future<c10::intrusive_ptr<FusedSampledSubgraph>>>\nFusedCSCSamplingGraph::SampleNeighborsAsync(\n    torch::optional<torch::Tensor> seeds,\n    torch::optional<std::vector<int64_t>> seed_offsets,\n    const std::vector<int64_t>& fanouts, bool replace, bool layer,\n    bool returning_indices_is_optional,\n    torch::optional<torch::Tensor> probs_or_mask,\n    torch::optional<torch::Tensor> random_seed,\n    double seed2_contribution) const {\n  return async(\n      [=] {\n        return this->SampleNeighbors(\n            seeds, seed_offsets, fanouts, replace, layer,\n            returning_indices_is_optional, probs_or_mask, random_seed,\n            seed2_contribution);\n      },\n      (seeds.has_value() && utils::is_on_gpu(*seeds)) ||\n          utils::is_on_gpu(indptr_));\n}\n\nc10::intrusive_ptr<FusedSampledSubgraph>\nFusedCSCSamplingGraph::TemporalSampleNeighbors(\n    const torch::optional<torch::Tensor>& seeds,\n    const torch::optional<std::vector<int64_t>>& seed_offsets,\n    const torch::Tensor& seeds_timestamp, const std::vector<int64_t>& fanouts,\n    bool replace, bool layer, bool returning_indices_is_optional,\n    torch::optional<torch::Tensor> seeds_pre_time_window,\n    torch::optional<torch::Tensor> probs_or_mask,\n    
torch::optional<std::string> node_timestamp_attr_name,\n    torch::optional<std::string> edge_timestamp_attr_name,\n    torch::optional<torch::Tensor> random_seed,\n    double seed2_contribution) const {\n  // 1. Get the timestamp attribute for nodes of the graph\n  const auto node_timestamp = this->NodeAttribute(node_timestamp_attr_name);\n  // 2. Get the timestamp attribute for edges of the graph\n  const auto edge_timestamp = this->EdgeAttribute(edge_timestamp_attr_name);\n  // If seeds does not have a value, then we expect all arguments to be resident\n  // on the GPU. If seeds has a value, then we expect them to be accessible from\n  // GPU. This is required for the dispatch to work when CUDA is not available.\n  if (((!seeds.has_value() && utils::is_on_gpu(indptr_) &&\n        utils::is_on_gpu(indices_) &&\n        (!probs_or_mask.has_value() ||\n         utils::is_on_gpu(probs_or_mask.value())) &&\n        (!type_per_edge_.has_value() ||\n         utils::is_on_gpu(type_per_edge_.value()))) ||\n       (seeds.has_value() && utils::is_on_gpu(seeds.value()) &&\n        utils::is_accessible_from_gpu(indptr_) &&\n        utils::is_accessible_from_gpu(indices_) &&\n        (!probs_or_mask.has_value() ||\n         utils::is_accessible_from_gpu(probs_or_mask.value())) &&\n        (!type_per_edge_.has_value() ||\n         utils::is_accessible_from_gpu(type_per_edge_.value())))) &&\n      utils::is_accessible_from_gpu(seeds_timestamp) &&\n      (!seeds_pre_time_window.has_value() ||\n       utils::is_accessible_from_gpu(*seeds_pre_time_window)) &&\n      (!node_timestamp.has_value() ||\n       utils::is_accessible_from_gpu(*node_timestamp)) &&\n      (!edge_timestamp.has_value() ||\n       utils::is_accessible_from_gpu(*edge_timestamp)) &&\n      !replace) {\n    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(\n        c10::DeviceType::CUDA, \"SampleNeighbors\", {\n          return ops::SampleNeighbors(\n              indptr_, indices_, seeds, seed_offsets, fanouts, replace, 
layer,\n              returning_indices_is_optional, type_per_edge_, probs_or_mask,\n              node_type_offset_, node_type_to_id_, edge_type_to_id_,\n              random_seed, seed2_contribution, seeds_timestamp,\n              seeds_pre_time_window, node_timestamp, edge_timestamp);\n        });\n  }\n  TORCH_CHECK(seeds.has_value(), \"Nodes can not be None for CPU.\");\n  // 3. Get probs_or_mask.\n  if (probs_or_mask.has_value()) {\n    // Note probs will be passed as input for 'torch.multinomial' in deeper\n    // stack, which doesn't support 'torch.half' and 'torch.bool' data types. To\n    // avoid crashes, convert 'probs_or_mask' to 'float32' data type.\n    if (probs_or_mask.value().dtype() == torch::kBool ||\n        probs_or_mask.value().dtype() == torch::kFloat16) {\n      probs_or_mask = probs_or_mask.value().to(torch::kFloat32);\n    }\n  }\n  // 4. Call SampleNeighborsImpl\n  if (layer) {\n    if (random_seed.has_value() && random_seed->numel() >= 2) {\n      SamplerArgs<SamplerType::LABOR_DEPENDENT> args{\n          indices_,\n          {random_seed.value(), static_cast<float>(seed2_contribution)},\n          NumNodes()};\n      return SampleNeighborsImpl<TemporalOption::TEMPORAL>(\n          *seeds, seed_offsets, fanouts,\n          GetTemporalNumPickFn(\n              seeds_timestamp, indices_, fanouts, replace, type_per_edge_,\n              seeds_pre_time_window, probs_or_mask, node_timestamp,\n              edge_timestamp),\n          GetTemporalPickFn(\n              seeds_timestamp, indices_, fanouts, replace, indptr_.options(),\n              type_per_edge_, seeds_pre_time_window, probs_or_mask,\n              node_timestamp, edge_timestamp, args));\n    } else {\n      auto args = [&] {\n        if (random_seed.has_value() && random_seed->numel() == 1) {\n          return SamplerArgs<SamplerType::LABOR>{\n              indices_, random_seed.value(), NumNodes()};\n        } else {\n          return SamplerArgs<SamplerType::LABOR>{\n       
       indices_,\n              RandomEngine::ThreadLocal()->RandInt(\n                  static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),\n              NumNodes()};\n        }\n      }();\n      return SampleNeighborsImpl<TemporalOption::TEMPORAL>(\n          *seeds, seed_offsets, fanouts,\n          GetTemporalNumPickFn(\n              seeds_timestamp, indices_, fanouts, replace, type_per_edge_,\n              seeds_pre_time_window, probs_or_mask, node_timestamp,\n              edge_timestamp),\n          GetTemporalPickFn(\n              seeds_timestamp, indices_, fanouts, replace, indptr_.options(),\n              type_per_edge_, seeds_pre_time_window, probs_or_mask,\n              node_timestamp, edge_timestamp, args));\n    }\n  } else {\n    SamplerArgs<SamplerType::NEIGHBOR> args;\n    return SampleNeighborsImpl<TemporalOption::TEMPORAL>(\n        *seeds, seed_offsets, fanouts,\n        GetTemporalNumPickFn(\n            seeds_timestamp, this->indices_, fanouts, replace, type_per_edge_,\n            seeds_pre_time_window, probs_or_mask, node_timestamp,\n            edge_timestamp),\n        GetTemporalPickFn(\n            seeds_timestamp, this->indices_, fanouts, replace,\n            indptr_.options(), type_per_edge_, seeds_pre_time_window,\n            probs_or_mask, node_timestamp, edge_timestamp, args));\n  }\n}\n\nstatic c10::intrusive_ptr<FusedCSCSamplingGraph>\nBuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {\n  helper.InitializeRead();\n  auto indptr = helper.ReadTorchTensor();\n  auto indices = helper.ReadTorchTensor();\n  auto node_type_offset = helper.ReadTorchTensor();\n  auto type_per_edge = helper.ReadTorchTensor();\n  auto node_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());\n  auto edge_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());\n  auto node_attributes = helper.ReadTorchTensorDict();\n  auto edge_attributes = helper.ReadTorchTensorDict();\n  auto graph = 
c10::make_intrusive<FusedCSCSamplingGraph>(\n      indptr.value(), indices.value(), node_type_offset, type_per_edge,\n      node_type_to_id, edge_type_to_id, node_attributes, edge_attributes);\n  auto shared_memory = helper.ReleaseSharedMemory();\n  graph->HoldSharedMemoryObject(\n      std::move(shared_memory.first), std::move(shared_memory.second));\n  return graph;\n}\n\nc10::intrusive_ptr<FusedCSCSamplingGraph>\nFusedCSCSamplingGraph::CopyToSharedMemory(\n    const std::string& shared_memory_name) {\n  SharedMemoryHelper helper(shared_memory_name);\n  helper.WriteTorchTensor(indptr_);\n  helper.WriteTorchTensor(indices_);\n  helper.WriteTorchTensor(node_type_offset_);\n  helper.WriteTorchTensor(type_per_edge_);\n  helper.WriteTorchTensorDict(TensorizeDict(node_type_to_id_));\n  helper.WriteTorchTensorDict(TensorizeDict(edge_type_to_id_));\n  helper.WriteTorchTensorDict(node_attributes_);\n  helper.WriteTorchTensorDict(edge_attributes_);\n  helper.Flush();\n  return BuildGraphFromSharedMemoryHelper(std::move(helper));\n}\n\nc10::intrusive_ptr<FusedCSCSamplingGraph>\nFusedCSCSamplingGraph::LoadFromSharedMemory(\n    const std::string& shared_memory_name) {\n  SharedMemoryHelper helper(shared_memory_name);\n  return BuildGraphFromSharedMemoryHelper(std::move(helper));\n}\n\nvoid FusedCSCSamplingGraph::HoldSharedMemoryObject(\n    SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm) {\n  tensor_metadata_shm_ = std::move(tensor_metadata_shm);\n  tensor_data_shm_ = std::move(tensor_data_shm);\n}\n\ntemplate <typename PickedNumType>\nvoid NumPick(\n    int64_t fanout, bool replace,\n    const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,\n    int64_t num_neighbors, PickedNumType* picked_num_ptr) {\n  int64_t num_valid_neighbors = num_neighbors;\n  if (probs_or_mask.has_value() && num_neighbors > 0) {\n    // Subtract the count of zeros in probs_or_mask.\n    AT_DISPATCH_ALL_TYPES(\n        probs_or_mask.value().scalar_type(), 
\"CountZero\", ([&] {\n          scalar_t* probs_data_ptr = probs_or_mask.value().data_ptr<scalar_t>();\n          num_valid_neighbors -= std::count(\n              probs_data_ptr + offset, probs_data_ptr + offset + num_neighbors,\n              0);\n        }));\n  }\n  if (num_valid_neighbors == 0 || fanout == -1) {\n    *picked_num_ptr = num_valid_neighbors;\n  } else {\n    *picked_num_ptr = replace ? fanout : std::min(fanout, num_valid_neighbors);\n  }\n}\n\ntorch::Tensor TemporalMask(\n    int64_t seed_timestamp, torch::Tensor csc_indices,\n    const torch::optional<int64_t>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp,\n    std::pair<int64_t, int64_t> edge_range) {\n  auto [l, r] = edge_range;\n  torch::Tensor mask = torch::ones({r - l}, torch::kBool);\n  if (node_timestamp.has_value()) {\n    auto neighbor_timestamp =\n        node_timestamp.value().index_select(0, csc_indices.slice(0, l, r));\n    mask &= neighbor_timestamp < seed_timestamp;\n    if (seed_pre_time_window.has_value())\n      mask &=\n          neighbor_timestamp > seed_timestamp - seed_pre_time_window.value();\n  }\n  if (edge_timestamp.has_value()) {\n    auto edge_ts = edge_timestamp.value().slice(0, l, r);\n    mask &= edge_ts < seed_timestamp;\n    if (seed_pre_time_window.has_value())\n      mask &= edge_ts > seed_timestamp - seed_pre_time_window.value();\n  }\n  if (probs_or_mask.has_value()) {\n    mask &= probs_or_mask.value().slice(0, l, r) != 0;\n  }\n  return mask;\n}\n\n/**\n * @brief Fast path for temporal sampling without probability. It is used when\n * the number of neighbors is large. It randomly samples neighbors and checks\n * the timestamp of the neighbors. 
It is successful if the number of sampled\n * neighbors in kTriedThreshold trials is equal to the fanout.\n */\nstd::pair<bool, std::vector<int64_t>> FastTemporalPick(\n    torch::Tensor seed_timestamp, torch::Tensor csc_indices, int64_t fanout,\n    bool replace, const torch::optional<torch::Tensor>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,\n    int64_t offset, int64_t num_neighbors) {\n  constexpr int64_t kTriedThreshold = 1000;\n  auto timestamp = utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset);\n  torch::optional<int64_t> time_window = torch::nullopt;\n  if (seed_pre_time_window.has_value()) {\n    time_window = utils::GetValueByIndex<int64_t>(\n        seed_pre_time_window.value(), seed_offset);\n  }\n  std::vector<int64_t> sampled_edges;\n  sampled_edges.reserve(fanout);\n  std::set<int64_t> sampled_edge_set;\n  int64_t sample_count = 0;\n  int64_t tried = 0;\n  while (sample_count < fanout && tried < kTriedThreshold) {\n    int64_t edge_id =\n        RandomEngine::ThreadLocal()->RandInt(offset, offset + num_neighbors);\n    ++tried;\n    if (!replace && sampled_edge_set.count(edge_id) > 0) {\n      continue;\n    }\n    if (node_timestamp.has_value()) {\n      bool flag = true;\n      AT_DISPATCH_INDEX_TYPES(\n          csc_indices.scalar_type(), \"CheckNodeTimeStamp\", ([&] {\n            int64_t neighbor_id =\n                utils::GetValueByIndex<index_t>(csc_indices, edge_id);\n            auto neighbor_ts = utils::GetValueByIndex<int64_t>(\n                node_timestamp.value(), neighbor_id);\n            if (neighbor_ts >= timestamp ||\n                (time_window.has_value() &&\n                 neighbor_ts <= (timestamp - time_window.value())))\n              flag = false;\n          }));\n      if (!flag) continue;\n    }\n    if (edge_timestamp.has_value()) {\n      auto edge_ts =\n          
utils::GetValueByIndex<int64_t>(edge_timestamp.value(), edge_id);\n      if (edge_ts >= timestamp ||\n          (time_window.has_value() &&\n           edge_ts <= (timestamp - time_window.value())))\n        continue;\n      continue;\n    }\n    if (!replace) {\n      sampled_edge_set.insert(edge_id);\n    }\n    sampled_edges.push_back(edge_id);\n    sample_count++;\n  }\n  if (sample_count < fanout) {\n    return {false, {}};\n  }\n  return {true, sampled_edges};\n}\n\nint64_t TemporalNumPick(\n    torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout,\n    bool replace, const torch::optional<torch::Tensor>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,\n    int64_t offset, int64_t num_neighbors) {\n  constexpr int64_t kFastPathThreshold = 1000;\n  if (num_neighbors > kFastPathThreshold && !probs_or_mask.has_value()) {\n    // TODO: Currently we use the fast path both in TemporalNumPick and\n    // TemporalPick. 
We may only sample once in TemporalNumPick and use the\n    // sampled edges in TemporalPick to avoid sampling twice.\n    auto [success, sampled_edges] = FastTemporalPick(\n        seed_timestamp, csc_indics, fanout, replace, seed_pre_time_window,\n        node_timestamp, edge_timestamp, seed_offset, offset, num_neighbors);\n    if (success) return sampled_edges.size();\n  }\n  torch::optional<int64_t> time_window = torch::nullopt;\n  if (seed_pre_time_window.has_value()) {\n    time_window = utils::GetValueByIndex<int64_t>(\n        seed_pre_time_window.value(), seed_offset);\n  }\n  auto mask = TemporalMask(\n      utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indics,\n      time_window, probs_or_mask, node_timestamp, edge_timestamp,\n      {offset, offset + num_neighbors});\n  int64_t num_valid_neighbors = utils::GetValueByIndex<int64_t>(mask.sum(), 0);\n  if (num_valid_neighbors == 0 || fanout == -1) return num_valid_neighbors;\n  return replace ? fanout : std::min(fanout, num_valid_neighbors);\n}\n\ntemplate <typename PickedNumType>\nvoid NumPickByEtype(\n    bool with_seed_offsets, const std::vector<int64_t>& fanouts, bool replace,\n    const torch::Tensor& type_per_edge,\n    const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,\n    int64_t num_neighbors, PickedNumType* num_picked_ptr, int64_t seed_index,\n    const std::vector<int64_t>& etype_id_to_num_picked_offset) {\n  int64_t etype_begin = offset;\n  const int64_t end = offset + num_neighbors;\n  PickedNumType total_count = 0;\n  AT_DISPATCH_INTEGRAL_TYPES(\n      type_per_edge.scalar_type(), \"NumPickFnByEtype\", ([&] {\n        const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();\n        while (etype_begin < end) {\n          scalar_t etype = type_per_edge_data[etype_begin];\n          TORCH_CHECK(\n              etype >= 0 && etype < (int64_t)fanouts.size(),\n              \"Etype values exceed the number of fanouts.\");\n          auto 
etype_end_it = std::upper_bound(\n              type_per_edge_data + etype_begin, type_per_edge_data + end,\n              etype);\n          int64_t etype_end = etype_end_it - type_per_edge_data;\n          // Do sampling for one etype.\n          if (with_seed_offsets) {\n            // The pick numbers aren't stored continuously, but separately for\n            // each different etype.\n            const auto offset =\n                etype_id_to_num_picked_offset[etype] + seed_index;\n            NumPick(\n                fanouts[etype], replace, probs_or_mask, etype_begin,\n                etype_end - etype_begin, num_picked_ptr + offset);\n          } else {\n            PickedNumType picked_count = 0;\n            NumPick(\n                fanouts[etype], replace, probs_or_mask, etype_begin,\n                etype_end - etype_begin, &picked_count);\n            total_count += picked_count;\n          }\n          etype_begin = etype_end;\n        }\n      }));\n  if (!with_seed_offsets) {\n    num_picked_ptr[seed_index] = total_count;\n  }\n}\n\nint64_t TemporalNumPickByEtype(\n    torch::Tensor seed_timestamp, torch::Tensor csc_indices,\n    const std::vector<int64_t>& fanouts, bool replace,\n    const torch::Tensor& type_per_edge,\n    const torch::optional<torch::Tensor>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,\n    int64_t offset, int64_t num_neighbors) {\n  int64_t etype_begin = offset;\n  const int64_t end = offset + num_neighbors;\n  int64_t total_count = 0;\n  AT_DISPATCH_INTEGRAL_TYPES(\n      type_per_edge.scalar_type(), \"TemporalNumPickFnByEtype\", ([&] {\n        const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();\n        while (etype_begin < end) {\n          scalar_t etype = type_per_edge_data[etype_begin];\n          TORCH_CHECK(\n              
etype >= 0 && etype < (int64_t)fanouts.size(),\n              \"Etype values exceed the number of fanouts.\");\n          auto etype_end_it = std::upper_bound(\n              type_per_edge_data + etype_begin, type_per_edge_data + end,\n              etype);\n          int64_t etype_end = etype_end_it - type_per_edge_data;\n          // Do sampling for one etype.\n          total_count += TemporalNumPick(\n              seed_timestamp, csc_indices, fanouts[etype], replace,\n              seed_pre_time_window, probs_or_mask, node_timestamp,\n              edge_timestamp, seed_offset, etype_begin,\n              etype_end - etype_begin);\n          etype_begin = etype_end;\n        }\n      }));\n  return total_count;\n}\n\n/**\n * @brief Perform uniform sampling of elements and return the sampled indices.\n *\n * @param offset The starting edge ID for the connected neighbors of the sampled\n * node.\n * @param num_neighbors The number of neighbors to pick.\n * @param fanout The number of edges to be sampled for each node. It should be\n * >= 0 or -1.\n *  - When the value is -1, all neighbors will be sampled once regardless of\n * replacement. It is equivalent to selecting all neighbors when the fanout is\n * >= the number of neighbors (and replacement is set to false).\n *  - When the value is a non-negative integer, it serves as a minimum\n * threshold for selecting neighbors.\n * @param replace Boolean indicating whether the sample is performed with or\n * without replacement. If True, a value can be selected multiple times.\n * Otherwise, each value can be selected only once.\n * @param options Tensor options specifying the desired data type of the result.\n * @param picked_data_ptr The destination address where the picked neighbors\n * should be put. 
Enough memory space should be allocated in advance.\n */\ntemplate <typename PickedType>\ninline int64_t UniformPick(\n    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,\n    const torch::TensorOptions& options, PickedType* picked_data_ptr) {\n  if ((fanout == -1) || (num_neighbors <= fanout && !replace)) {\n    std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);\n    return num_neighbors;\n  } else if (replace) {\n    std::memcpy(\n        picked_data_ptr,\n        torch::randint(offset, offset + num_neighbors, {fanout}, options)\n            .data_ptr<PickedType>(),\n        fanout * sizeof(PickedType));\n    return fanout;\n  } else {\n    // We use different sampling strategies for different sampling case.\n    if (fanout >= num_neighbors / 10) {\n      // [Algorithm]\n      // This algorithm is conceptually related to the Fisher-Yates\n      // shuffle.\n      //\n      // [Complexity Analysis]\n      // This algorithm's memory complexity is O(num_neighbors), but\n      // it generates fewer random numbers (O(fanout)).\n      //\n      // (Compare) Reservoir algorithm is one of the most classical\n      // sampling algorithms. Both the reservoir algorithm and our\n      // algorithm offer distinct advantages, we need to compare to\n      // illustrate our trade-offs.\n      // The reservoir algorithm is memory-efficient (O(fanout)) but\n      // creates many random numbers (O(num_neighbors)), which is\n      // costly.\n      //\n      // [Practical Consideration]\n      // Use this algorithm when `fanout >= num_neighbors / 10` to\n      // reduce computation.\n      // In this scenarios above, memory complexity is not a concern due\n      // to the small size of both `fanout` and `num_neighbors`. And it\n      // is efficient to allocate a small amount of memory. 
So the\n      // algorithm performence is great in this case.\n      std::vector<PickedType> seq(num_neighbors);\n      // Assign the seq with [offset, offset + num_neighbors].\n      std::iota(seq.begin(), seq.end(), offset);\n      for (int64_t i = 0; i < fanout; ++i) {\n        auto j = RandomEngine::ThreadLocal()->RandInt(i, num_neighbors);\n        std::swap(seq[i], seq[j]);\n      }\n      // Save the randomly sampled fanout elements to the output tensor.\n      std::copy(seq.begin(), seq.begin() + fanout, picked_data_ptr);\n      return fanout;\n    } else if (fanout < 64) {\n      // [Algorithm]\n      // Use linear search to verify uniqueness.\n      //\n      // [Complexity Analysis]\n      // Since the set of numbers is small (up to 64), so it is more\n      // cost-effective for the CPU to use this algorithm.\n      auto begin = picked_data_ptr;\n      auto end = picked_data_ptr + fanout;\n\n      while (begin != end) {\n        // Put the new random number in the last position.\n        *begin = RandomEngine::ThreadLocal()->RandInt(\n            offset, offset + num_neighbors);\n        // Check if a new value doesn't exist in current\n        // range(picked_data_ptr, begin). Otherwise get a new\n        // value until we haven't unique range of elements.\n        auto it = std::find(picked_data_ptr, begin, *begin);\n        if (it == begin) ++begin;\n      }\n      return fanout;\n    } else {\n      // [Algorithm]\n      // Use hash-set to verify uniqueness. In the best scenario, the\n      // time complexity is O(fanout), assuming no conflicts occur.\n      //\n      // [Complexity Analysis]\n      // Let K = (fanout / num_neighbors), the expected number of extra\n      // sampling steps is roughly K^2 / (1-K) * num_neighbors, which\n      // means in the worst case scenario, the time complexity is\n      // O(num_neighbors^2).\n      //\n      // [Practical Consideration]\n      // In practice, we set the threshold K to 1/10. 
This trade-off is\n      // due to the slower performance of std::unordered_set, which\n      // would otherwise increase the sampling cost. By doing so, we\n      // achieve a balance between theoretical efficiency and practical\n      // performance.\n      std::unordered_set<PickedType> picked_set;\n      while (static_cast<int64_t>(picked_set.size()) < fanout) {\n        picked_set.insert(RandomEngine::ThreadLocal()->RandInt(\n            offset, offset + num_neighbors));\n      }\n      std::copy(picked_set.begin(), picked_set.end(), picked_data_ptr);\n      return picked_set.size();\n    }\n  }\n}\n\n/** @brief An operator to perform non-uniform sampling. */\nstatic torch::Tensor NonUniformPickOp(\n    torch::Tensor probs, int64_t fanout, bool replace) {\n  auto positive_probs_indices = probs.nonzero().squeeze(1);\n  auto num_positive_probs = positive_probs_indices.size(0);\n  if (num_positive_probs == 0) return torch::empty({0}, torch::kLong);\n  if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {\n    return positive_probs_indices;\n  }\n  if (!replace) fanout = std::min(fanout, num_positive_probs);\n  if (fanout == 0) return torch::empty({0}, torch::kLong);\n  auto ret_tensor = torch::empty({fanout}, torch::kLong);\n  auto ret_ptr = ret_tensor.data_ptr<int64_t>();\n  AT_DISPATCH_FLOATING_TYPES(\n      probs.scalar_type(), \"MultinomialSampling\", ([&] {\n        auto probs_data_ptr = probs.data_ptr<scalar_t>();\n        auto positive_probs_indices_ptr =\n            positive_probs_indices.data_ptr<int64_t>();\n\n        if (!replace) {\n          // The algorithm is from gumbel softmax.\n          // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1).\n          // Here we can apply exp to the formula which will not affect result\n          // of argmax or topk. 
Then we have\n          // s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).\n          // We can also simplify the formula above by\n          // s = argmax( p / q ) where q ~ Exp(1).\n          if (fanout == 1) {\n            // Return argmax(p / q).\n            scalar_t max_prob = 0;\n            int64_t max_prob_index = -1;\n            // We only care about the neighbors with non-zero probability.\n            for (auto i = 0; i < num_positive_probs; ++i) {\n              // Calculate (p / q) for the current neighbor.\n              scalar_t current_prob =\n                  probs_data_ptr[positive_probs_indices_ptr[i]] /\n                  RandomEngine::ThreadLocal()->Exponential(1.);\n              if (current_prob > max_prob) {\n                max_prob = current_prob;\n                max_prob_index = positive_probs_indices_ptr[i];\n              }\n            }\n            ret_ptr[0] = max_prob_index;\n          } else {\n            // Return topk(p / q).\n            std::vector<std::pair<scalar_t, int64_t>> q(num_positive_probs);\n            for (auto i = 0; i < num_positive_probs; ++i) {\n              q[i].first = probs_data_ptr[positive_probs_indices_ptr[i]] /\n                           RandomEngine::ThreadLocal()->Exponential(1.);\n              q[i].second = positive_probs_indices_ptr[i];\n            }\n            if (fanout < num_positive_probs / 64) {\n              // Use partial_sort.\n              std::partial_sort(\n                  q.begin(), q.begin() + fanout, q.end(), std::greater{});\n              for (auto i = 0; i < fanout; ++i) {\n                ret_ptr[i] = q[i].second;\n              }\n            } else {\n              // Use nth_element.\n              std::nth_element(\n                  q.begin(), q.begin() + fanout - 1, q.end(), std::greater{});\n              for (auto i = 0; i < fanout; ++i) {\n                ret_ptr[i] = q[i].second;\n              }\n            }\n          }\n        } else {\n          // 
Calculate cumulative sum of probabilities.\n          std::vector<scalar_t> prefix_sum_probs(num_positive_probs);\n          scalar_t sum_probs = 0;\n          for (auto i = 0; i < num_positive_probs; ++i) {\n            sum_probs += probs_data_ptr[positive_probs_indices_ptr[i]];\n            prefix_sum_probs[i] = sum_probs;\n          }\n          // Normalize.\n          if ((sum_probs > 1.00001) || (sum_probs < 0.99999)) {\n            for (auto i = 0; i < num_positive_probs; ++i) {\n              prefix_sum_probs[i] /= sum_probs;\n            }\n          }\n          for (auto i = 0; i < fanout; ++i) {\n            // Sample a probability mass from a uniform distribution.\n            double uniform_sample =\n                RandomEngine::ThreadLocal()->Uniform(0., 1.);\n            // Use a binary search to find the index.\n            int sampled_index = std::lower_bound(\n                                    prefix_sum_probs.begin(),\n                                    prefix_sum_probs.end(), uniform_sample) -\n                                prefix_sum_probs.begin();\n            ret_ptr[i] = positive_probs_indices_ptr[sampled_index];\n          }\n        }\n      }));\n  return ret_tensor;\n}\n\n/**\n * @brief Perform non-uniform sampling of elements based on probabilities and\n * return the sampled indices.\n *\n * If 'probs_or_mask' is provided, it indicates that the sampling is\n * non-uniform. In such cases:\n * - When the number of neighbors with non-zero probability is less than or\n * equal to fanout, all neighbors with non-zero probability will be selected.\n * - When the number of neighbors with non-zero probability exceeds fanout, the\n * sampling process will select 'fanout' elements based on their respective\n * probabilities. 
Higher probabilities will increase the chances of being chosen\n * during the sampling process.\n *\n * @param offset The starting edge ID for the connected neighbors of the sampled\n * node.\n * @param num_neighbors The number of neighbors to pick.\n * @param fanout The number of edges to be sampled for each node. It should be\n * >= 0 or -1.\n *  - When the value is -1, all neighbors with non-zero probability will be\n * sampled once regardless of replacement. It is equivalent to selecting all\n * neighbors with non-zero probability when the fanout is >= the number of\n * neighbors (and replacement is set to false).\n *  - When the value is a non-negative integer, it serves as a minimum\n * threshold for selecting neighbors.\n * @param replace Boolean indicating whether the sample is performed with or\n * without replacement. If True, a value can be selected multiple times.\n * Otherwise, each value can be selected only once.\n * @param options Tensor options specifying the desired data type of the result.\n * @param probs_or_mask Optional tensor containing the (unnormalized)\n * probabilities associated with each neighboring edge of a node in the original\n * graph. It must be a 1D floating-point tensor with the number of elements\n * equal to the number of edges in the graph.\n * @param picked_data_ptr The destination address where the picked neighbors\n * should be put. Enough memory space should be allocated in advance.\n */\ntemplate <typename PickedType>\ninline int64_t NonUniformPick(\n    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,\n    const torch::TensorOptions& options, const torch::Tensor& probs_or_mask,\n    PickedType* picked_data_ptr) {\n  auto local_probs =\n      probs_or_mask.size(0) > num_neighbors\n          ? 
probs_or_mask.slice(0, offset, offset + num_neighbors)\n          : probs_or_mask;\n  auto picked_indices = NonUniformPickOp(local_probs, fanout, replace);\n  auto picked_indices_ptr = picked_indices.data_ptr<int64_t>();\n  for (int i = 0; i < picked_indices.numel(); ++i) {\n    picked_data_ptr[i] =\n        static_cast<PickedType>(picked_indices_ptr[i]) + offset;\n  }\n  return picked_indices.numel();\n}\n\ntemplate <typename PickedType>\nint64_t Pick(\n    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,\n    const torch::TensorOptions& options,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr) {\n  if (fanout == 0 || num_neighbors == 0) return 0;\n  if (probs_or_mask.has_value()) {\n    return NonUniformPick(\n        offset, num_neighbors, fanout, replace, options, probs_or_mask.value(),\n        picked_data_ptr);\n  } else {\n    return UniformPick(\n        offset, num_neighbors, fanout, replace, options, picked_data_ptr);\n  }\n}\n\ntemplate <SamplerType S, typename PickedType>\nint64_t TemporalPick(\n    torch::Tensor seed_timestamp, torch::Tensor csc_indices,\n    int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout,\n    bool replace, const torch::TensorOptions& options,\n    const torch::optional<torch::Tensor>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args,\n    PickedType* picked_data_ptr) {\n  constexpr int64_t kFastPathThreshold = 1000;\n  if (S == SamplerType::NEIGHBOR && num_neighbors > kFastPathThreshold &&\n      !probs_or_mask.has_value()) {\n    auto [success, sampled_edges] = FastTemporalPick(\n        seed_timestamp, csc_indices, fanout, replace, seed_pre_time_window,\n        node_timestamp, edge_timestamp, seed_offset, offset, num_neighbors);\n    
if (success) {\n      for (size_t i = 0; i < sampled_edges.size(); ++i) {\n        picked_data_ptr[i] = static_cast<PickedType>(sampled_edges[i]);\n      }\n      return sampled_edges.size();\n    }\n  }\n  torch::optional<int64_t> time_window = torch::nullopt;\n  if (seed_pre_time_window.has_value()) {\n    time_window = utils::GetValueByIndex<int64_t>(\n        seed_pre_time_window.value(), seed_offset);\n  }\n  auto mask = TemporalMask(\n      utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indices,\n      time_window, probs_or_mask, node_timestamp, edge_timestamp,\n      {offset, offset + num_neighbors});\n  torch::Tensor masked_prob;\n  if (probs_or_mask.has_value()) {\n    masked_prob =\n        probs_or_mask.value().slice(0, offset, offset + num_neighbors) * mask;\n  } else {\n    masked_prob = S == SamplerType::NEIGHBOR ? mask.to(torch::kFloat32) : mask;\n  }\n  if constexpr (S == SamplerType::NEIGHBOR) {\n    auto picked_indices = NonUniformPickOp(masked_prob, fanout, replace);\n    auto picked_indices_ptr = picked_indices.data_ptr<int64_t>();\n    for (int i = 0; i < picked_indices.numel(); ++i) {\n      picked_data_ptr[i] =\n          static_cast<PickedType>(picked_indices_ptr[i]) + offset;\n    }\n    return picked_indices.numel();\n  }\n  if constexpr (is_labor(S)) {\n    return Pick(\n        offset, num_neighbors, fanout, replace, options, masked_prob, args,\n        picked_data_ptr);\n  }\n}\n\ntemplate <SamplerType S, typename PickedType>\nint64_t PickByEtype(\n    bool with_seed_offsets, int64_t offset, int64_t num_neighbors,\n    const std::vector<int64_t>& fanouts, bool replace,\n    const torch::TensorOptions& options, const torch::Tensor& type_per_edge,\n    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,\n    PickedType* picked_data_ptr, int64_t seed_index,\n    PickedType* subgraph_indptr_ptr,\n    const std::vector<int64_t>& etype_id_to_num_picked_offset) {\n  int64_t etype_begin = offset;\n  
int64_t etype_end = offset;\n  int64_t picked_total_count = 0;\n  AT_DISPATCH_INTEGRAL_TYPES(\n      type_per_edge.scalar_type(), \"PickByEtype\", ([&] {\n        const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();\n        const auto end = offset + num_neighbors;\n        while (etype_begin < end) {\n          scalar_t etype = type_per_edge_data[etype_begin];\n          TORCH_CHECK(\n              etype >= 0 && etype < (int64_t)fanouts.size(),\n              \"Etype values exceed the number of fanouts.\");\n          int64_t fanout = fanouts[etype];\n          auto etype_end_it = std::upper_bound(\n              type_per_edge_data + etype_begin, type_per_edge_data + end,\n              etype);\n          etype_end = etype_end_it - type_per_edge_data;\n          // Do sampling for one etype. The picked nodes aren't stored\n          // continuously, but separately for each different etype.\n          if (fanout != 0) {\n            auto picked_count = 0;\n            if (with_seed_offsets) {\n              const auto indptr_offset =\n                  etype_id_to_num_picked_offset[etype] + seed_index;\n              picked_count = Pick(\n                  etype_begin, etype_end - etype_begin, fanout, replace,\n                  options, probs_or_mask, args,\n                  picked_data_ptr + subgraph_indptr_ptr[indptr_offset]);\n              TORCH_CHECK(\n                  subgraph_indptr_ptr[indptr_offset + 1] -\n                          subgraph_indptr_ptr[indptr_offset] ==\n                      picked_count,\n                  \"Actual picked count doesn't match the calculated \"\n                  \"pick number.\");\n            } else {\n              picked_count = Pick(\n                  etype_begin, etype_end - etype_begin, fanout, replace,\n                  options, probs_or_mask, args,\n                  picked_data_ptr + subgraph_indptr_ptr[seed_index] +\n                      picked_total_count);\n            }\n            
picked_total_count += picked_count;\n          }\n          etype_begin = etype_end;\n        }\n      }));\n  return picked_total_count;\n}\n\ntemplate <SamplerType S, typename PickedType>\nint64_t TemporalPickByEtype(\n    torch::Tensor seed_timestamp, torch::Tensor csc_indices,\n    int64_t seed_offset, int64_t offset, int64_t num_neighbors,\n    const std::vector<int64_t>& fanouts, bool replace,\n    const torch::TensorOptions& options, const torch::Tensor& type_per_edge,\n    const torch::optional<torch::Tensor>& seed_pre_time_window,\n    const torch::optional<torch::Tensor>& probs_or_mask,\n    const torch::optional<torch::Tensor>& node_timestamp,\n    const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args,\n    PickedType* picked_data_ptr) {\n  int64_t etype_begin = offset;\n  int64_t etype_end = offset;\n  int64_t pick_offset = 0;\n  AT_DISPATCH_INTEGRAL_TYPES(\n      type_per_edge.scalar_type(), \"TemporalPickByEtype\", ([&] {\n        const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();\n        const auto end = offset + num_neighbors;\n        while (etype_begin < end) {\n          scalar_t etype = type_per_edge_data[etype_begin];\n          TORCH_CHECK(\n              etype >= 0 && etype < (int64_t)fanouts.size(),\n              \"Etype values exceed the number of fanouts.\");\n          int64_t fanout = fanouts[etype];\n          auto etype_end_it = std::upper_bound(\n              type_per_edge_data + etype_begin, type_per_edge_data + end,\n              etype);\n          etype_end = etype_end_it - type_per_edge_data;\n          // Do sampling for one etype.\n          if (fanout != 0) {\n            int64_t picked_count = TemporalPick(\n                seed_timestamp, csc_indices, seed_offset, etype_begin,\n                etype_end - etype_begin, fanout, replace, options,\n                seed_pre_time_window, probs_or_mask, node_timestamp,\n                edge_timestamp, args, picked_data_ptr + 
pick_offset);\n            pick_offset += picked_count;\n          }\n          etype_begin = etype_end;\n        }\n      }));\n  return pick_offset;\n}\n\ntemplate <SamplerType S, typename PickedType>\nstd::enable_if_t<is_labor(S), int64_t> Pick(\n    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,\n    const torch::TensorOptions& options,\n    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,\n    PickedType* picked_data_ptr) {\n  if (fanout == 0 || num_neighbors == 0) return 0;\n  if (probs_or_mask.has_value()) {\n    if (fanout < 0) {\n      return NonUniformPick(\n          offset, num_neighbors, fanout, replace, options,\n          probs_or_mask.value(), picked_data_ptr);\n    } else {\n      int64_t picked_count;\n      GRAPHBOLT_DISPATCH_ALL_TYPES(\n          probs_or_mask.value().scalar_type(), \"LaborPickFloatType\", ([&] {\n            if (replace) {\n              picked_count = LaborPick<true, true, scalar_t>(\n                  offset, num_neighbors, fanout, options, probs_or_mask, args,\n                  picked_data_ptr);\n            } else {\n              picked_count = LaborPick<true, false, scalar_t>(\n                  offset, num_neighbors, fanout, options, probs_or_mask, args,\n                  picked_data_ptr);\n            }\n          }));\n      return picked_count;\n    }\n  } else if (fanout < 0) {\n    return UniformPick(\n        offset, num_neighbors, fanout, replace, options, picked_data_ptr);\n  } else if (replace) {\n    return LaborPick<false, true, float>(\n        offset, num_neighbors, fanout, options,\n        /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);\n  } else {  // replace = false\n    return LaborPick<false, false, float>(\n        offset, num_neighbors, fanout, options,\n        /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);\n  }\n}\n\ntemplate <typename T, typename U>\ninline void safe_divide(T& a, U b) {\n  a = b > 0 ? 
(T)(a / b) : std::numeric_limits<T>::infinity();\n}\n\nnamespace labor {\n\ntemplate <typename T>\ninline T invcdf(T u, int64_t n, T rem) {\n  constexpr T one = 1;\n  return rem * (one - std::pow(one - u, one / n));\n}\n\ntemplate <typename T, typename seed_t>\ninline T jth_sorted_uniform_random(\n    seed_t seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {\n  const T u = seed.uniform(t + j * c);\n  // https://mathematica.stackexchange.com/a/256707\n  rem -= invcdf(u, n, rem);\n  return 1 - rem;\n}\n\n};  // namespace labor\n\n/**\n * @brief Perform uniform-nonuniform sampling of elements depending on the\n * template parameter NonUniform and return the sampled indices.\n *\n * @param offset The starting edge ID for the connected neighbors of the sampled\n * node.\n * @param num_neighbors The number of neighbors to pick.\n * @param fanout The number of edges to be sampled for each node. It should be\n * >= 0 or -1.\n *  - When the value is -1, all neighbors (with non-zero probability, if\n * weighted) will be sampled once regardless of replacement. It is equivalent to\n * selecting all neighbors with non-zero probability when the fanout is >= the\n * number of neighbors (and replacement is set to false).\n *  - When the value is a non-negative integer, it serves as a minimum\n * threshold for selecting neighbors.\n * @param options Tensor options specifying the desired data type of the result.\n * @param probs_or_mask Optional tensor containing the (unnormalized)\n * probabilities associated with each neighboring edge of a node in the original\n * graph. It must be a 1D floating-point tensor with the number of elements\n * equal to the number of edges in the graph.\n * @param args Contains labor specific arguments.\n * @param picked_data_ptr The destination address where the picked neighbors\n * should be put. 
Enough memory space should be allocated in advance.\n */\ntemplate <\n    bool NonUniform, bool Replace, typename ProbsType, SamplerType S,\n    typename PickedType, int StackSize>\ninline std::enable_if_t<is_labor(S), int64_t> LaborPick(\n    int64_t offset, int64_t num_neighbors, int64_t fanout,\n    const torch::TensorOptions& options,\n    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,\n    PickedType* picked_data_ptr) {\n  fanout = Replace ? fanout : std::min(fanout, num_neighbors);\n  if (!NonUniform && !Replace && fanout >= num_neighbors) {\n    std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);\n    return num_neighbors;\n  }\n  // Assuming max_degree of a vertex is <= 4 billion.\n  std::array<std::pair<float, uint32_t>, StackSize> heap;\n  auto heap_data = heap.data();\n  torch::Tensor heap_tensor;\n  if (fanout > StackSize) {\n    constexpr int factor = sizeof(heap_data[0]) / sizeof(int32_t);\n    heap_tensor = torch::empty({fanout * factor}, torch::kInt32);\n    heap_data = reinterpret_cast<std::pair<float, uint32_t>*>(\n        heap_tensor.data_ptr<int32_t>());\n  }\n  const ProbsType* local_probs_data =\n      NonUniform ? probs_or_mask.value().data_ptr<ProbsType>() + offset\n                 : nullptr;\n  if (NonUniform && probs_or_mask.value().size(0) <= num_neighbors) {\n    local_probs_data -= offset;\n  }\n  AT_DISPATCH_INDEX_TYPES(\n      args.indices.scalar_type(), \"LaborPickMain\", ([&] {\n        const auto local_indices_data =\n            reinterpret_cast<index_t*>(args.indices.data_ptr()) + offset;\n        if constexpr (Replace) {\n          // [Algorithm] @mfbalin\n          // Use a max-heap to get rid of the big random numbers and filter the\n          // smallest fanout of them. Implements arXiv:2210.13339 Section A.3.\n          // Unlike sampling without replacement below, the same item can be\n          // included fanout times in our sample. 
Thus, we sort and pick the\n          // smallest fanout random numbers out of num_neighbors * fanout of\n          // them. Each item has fanout many random numbers in the race and the\n          // smallest fanout of them get picked. Instead of generating\n          // fanout * num_neighbors random numbers and increase the complexity,\n          // I devised an algorithm to generate the fanout numbers for an item\n          // in a sorted manner on demand, meaning we continue generating random\n          // numbers for an item only if it has been sampled that many times\n          // already.\n          // https://gist.github.com/mfbalin/096dcad5e3b1f6a59ff7ff2f9f541618\n          //\n          // [Complexity Analysis]\n          // Will modify the heap at most linear in O(num_neighbors + fanout)\n          // and each modification takes O(log(fanout)). So the total complexity\n          // is O((fanout + num_neighbors) log(fanout)). It is possible to\n          // decrease the logarithmic factor down to\n          // O(log(min(fanout, num_neighbors))).\n          std::array<float, StackSize> remaining;\n          auto remaining_data = remaining.data();\n          torch::Tensor remaining_tensor;\n          if (num_neighbors > StackSize) {\n            remaining_tensor = torch::empty({num_neighbors}, torch::kFloat32);\n            remaining_data = remaining_tensor.data_ptr<float>();\n          }\n          std::fill_n(remaining_data, num_neighbors, 1.f);\n          auto heap_end = heap_data;\n          const auto init_count = (num_neighbors + fanout - 1) / num_neighbors;\n          auto sample_neighbor_i_with_index_t_jth_time =\n              [&](index_t t, int64_t j, uint32_t i) {\n                auto rnd = labor::jth_sorted_uniform_random(\n                    args.random_seed, t, args.num_nodes, j, remaining_data[i],\n                    fanout - j);  // r_t\n                if constexpr (NonUniform) {\n                  safe_divide(rnd, 
local_probs_data[i]);\n                }  // r_t / \\pi_t\n                if (heap_end < heap_data + fanout) {\n                  heap_end[0] = std::make_pair(rnd, i);\n                  if (++heap_end >= heap_data + fanout) {\n                    std::make_heap(heap_data, heap_data + fanout);\n                  }\n                  return false;\n                } else if (rnd < heap_data[0].first) {\n                  std::pop_heap(heap_data, heap_data + fanout);\n                  heap_data[fanout - 1] = std::make_pair(rnd, i);\n                  std::push_heap(heap_data, heap_data + fanout);\n                  return false;\n                } else {\n                  remaining_data[i] = -1;\n                  return true;\n                }\n              };\n          for (uint32_t i = 0; i < num_neighbors; ++i) {\n            const auto t = local_indices_data[i];\n            for (int64_t j = 0; j < init_count; j++) {\n              sample_neighbor_i_with_index_t_jth_time(t, j, i);\n            }\n          }\n          for (uint32_t i = 0; i < num_neighbors; ++i) {\n            if (remaining_data[i] == -1) continue;\n            const auto t = local_indices_data[i];\n            for (int64_t j = init_count; j < fanout; ++j) {\n              if (sample_neighbor_i_with_index_t_jth_time(t, j, i)) break;\n            }\n          }\n        } else {\n          // [Algorithm]\n          // Use a max-heap to get rid of the big random numbers and filter the\n          // smallest fanout of them. Implements arXiv:2210.13339 Section A.3.\n          //\n          // [Complexity Analysis]\n          // the first for loop and std::make_heap runs in time O(fanouts).\n          // The next for loop compares each random number to the current\n          // minimum fanout numbers. For any given i, the probability that the\n          // current random number will replace any number in the heap is fanout\n          // / i. 
Summing from i=fanout to num_neighbors, we get f * (H_n -\n          // H_f), where n is num_neighbors and f is fanout, H_f is \\sum_j=1^f\n          // 1/j. In the end H_n - H_f = O(log n/f), there are n - f iterations,\n          // each heap operation takes time log f, so the total complexity is\n          // O(f + (n - f)\n          // + f log(n/f) log f) = O(n + f log(f) log(n/f)). If f << n (f is a\n          // constant in almost all cases), then the average complexity is\n          // O(num_neighbors).\n          for (uint32_t i = 0; i < fanout; ++i) {\n            const auto t = local_indices_data[i];\n            auto rnd = args.random_seed.uniform(t);  // r_t\n            if constexpr (NonUniform) {\n              safe_divide(rnd, local_probs_data[i]);\n            }  // r_t / \\pi_t\n            heap_data[i] = std::make_pair(rnd, i);\n          }\n          if (!NonUniform || fanout < num_neighbors) {\n            std::make_heap(heap_data, heap_data + fanout);\n          }\n          for (uint32_t i = fanout; i < num_neighbors; ++i) {\n            const auto t = local_indices_data[i];\n            auto rnd = args.random_seed.uniform(t);  // r_t\n            if constexpr (NonUniform) {\n              safe_divide(rnd, local_probs_data[i]);\n            }  // r_t / \\pi_t\n            if (rnd < heap_data[0].first) {\n              std::pop_heap(heap_data, heap_data + fanout);\n              heap_data[fanout - 1] = std::make_pair(rnd, i);\n              std::push_heap(heap_data, heap_data + fanout);\n            }\n          }\n        }\n      }));\n  int64_t num_sampled = 0;\n  for (int64_t i = 0; i < fanout; ++i) {\n    const auto [rnd, j] = heap_data[i];\n    if (!NonUniform || rnd < std::numeric_limits<float>::infinity()) {\n      picked_data_ptr[num_sampled++] = offset + j;\n    }\n  }\n  return num_sampled;\n}\n\n}  // namespace sampling\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/index_select.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file index_select.cc\n * @brief Index select operators.\n */\n#include \"./index_select.h\"\n\n#include <graphbolt/cuda_ops.h>\n#include <graphbolt/fused_csc_sampling_graph.h>\n\n#include <cstring>\n#include <numeric>\n\n#include \"./macro.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace ops {\n\nconstexpr int kIntGrainSize = 64;\n\ntorch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {\n  if (utils::is_on_gpu(index)) {\n    if (input.is_pinned()) {\n      GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(\n          c10::DeviceType::CUDA, \"UVAIndexSelect\",\n          { return UVAIndexSelectImpl(input, index); });\n    } else {\n      return torch::index_select(input, 0, index);\n    }\n  }\n  auto output_shape = input.sizes().vec();\n  output_shape[0] = index.numel();\n  auto result = torch::empty(\n      output_shape, index.options()\n                        .dtype(input.dtype())\n                        .pinned_memory(utils::is_pinned(index)));\n  auto result_ptr = reinterpret_cast<std::byte*>(result.data_ptr());\n  const auto input_ptr = reinterpret_cast<std::byte*>(input.data_ptr());\n  const auto row_bytes = input.slice(0, 0, 1).numel() * input.element_size();\n  const auto stride = input.stride(0) * input.element_size();\n  const auto num_input_rows = input.size(0);\n  AT_DISPATCH_INDEX_TYPES(\n      index.scalar_type(), \"IndexSelect::index::scalar_type()\", ([&] {\n        const auto index_ptr = index.data_ptr<index_t>();\n        graphbolt::parallel_for(\n            0, index.size(0), kIntGrainSize, [&](int64_t begin, int64_t end) {\n              for (int64_t i = begin; i < end; i++) {\n                auto idx = index_ptr[i];\n                if (idx < 0) idx += num_input_rows;\n                if (idx < 0 || idx >= num_input_rows) {\n                  // Throw IndexError via torch.\n                  idx += input[num_input_rows].item<index_t>();\n                }\n      
          std::memcpy(\n                    result_ptr + i * row_bytes, input_ptr + idx * stride,\n                    row_bytes);\n              }\n            });\n      }));\n  return result;\n}\n\nc10::intrusive_ptr<Future<torch::Tensor>> IndexSelectAsync(\n    torch::Tensor input, torch::Tensor index) {\n  TORCH_CHECK(!utils::is_on_gpu(index) && !utils::is_on_gpu(input));\n  return async([=] { return IndexSelect(input, index); });\n}\n\nc10::intrusive_ptr<Future<torch::Tensor>> ScatterAsync(\n    torch::Tensor input, torch::Tensor index, torch::Tensor src) {\n  TORCH_CHECK(\n      !utils::is_on_gpu(input) && !utils::is_on_gpu(index) &&\n      !utils::is_on_gpu(src));\n  TORCH_CHECK(index.sizes().size() == 1, \"index tensor needs to be 1d.\");\n  for (size_t i = 1; i < input.sizes().size(); i++) {\n    TORCH_CHECK(\n        input.size(i) == src.size(i),\n        \"dimension mismatch between input and src at \", i,\n        \"th dimension: \", input.size(i), \" != \", src.size(i), \".\");\n  }\n  return async([=] {\n    const auto row_bytes = src.slice(0, 0, 1).numel() * src.element_size();\n    const auto src_ptr = reinterpret_cast<std::byte*>(src.data_ptr());\n    auto input_ptr = reinterpret_cast<std::byte*>(input.data_ptr());\n    AT_DISPATCH_INDEX_TYPES(\n        index.scalar_type(), \"ScatterAsync::index::scalar_type()\", ([&] {\n          const auto index_ptr = index.data_ptr<index_t>();\n          graphbolt::parallel_for(\n              0, index.size(0), kIntGrainSize, [&](int64_t begin, int64_t end) {\n                for (int64_t i = begin; i < end; i++) {\n                  std::memcpy(\n                      input_ptr + index_ptr[i] * row_bytes,\n                      src_ptr + i * row_bytes, row_bytes);\n                }\n              });\n        }));\n    return input;\n  });\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(\n    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,\n    torch::optional<int64_t> 
output_size) {\n  TORCH_CHECK(\n      indices.sizes().size() == 1, \"IndexSelectCSC only supports 1d tensors\");\n  if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr) &&\n      utils::is_accessible_from_gpu(indices)) {\n    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(\n        c10::DeviceType::CUDA, \"IndexSelectCSCImpl\",\n        { return IndexSelectCSCImpl(indptr, indices, nodes, output_size); });\n  }\n  auto [output_indptr, results] = IndexSelectCSCBatched(\n      indptr, std::vector{indices}, nodes, false, output_size);\n  return std::make_tuple(output_indptr, results.at(0));\n}\n\nstd::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(\n    torch::Tensor indptr, std::vector<torch::Tensor> indices_list,\n    torch::Tensor nodes, bool with_edge_ids,\n    torch::optional<int64_t> output_size) {\n  for (auto& indices : indices_list) {\n    TORCH_CHECK(\n        indices.sizes().size() == 1,\n        \"IndexSelectCSCBatched only supports 1d tensors\");\n  }\n  if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr) &&\n      utils::are_accessible_from_gpu(indices_list)) {\n    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(\n        c10::DeviceType::CUDA, \"IndexSelectCSCImpl\", {\n          return IndexSelectCSCBatchedImpl(\n              indptr, indices_list, nodes, with_edge_ids, output_size);\n        });\n  }\n  constexpr int kDefaultGrainSize = 128;\n  const auto num_nodes = nodes.size(0);\n  torch::Tensor output_indptr = torch::empty(\n      {num_nodes + 1}, nodes.options().dtype(indptr.scalar_type()));\n  std::vector<torch::Tensor> results;\n  torch::optional<torch::Tensor> edge_ids;\n  AT_DISPATCH_INDEX_TYPES(\n      indptr.scalar_type(), \"IndexSelectCSCBatched::indptr\", ([&] {\n        using indptr_t = index_t;\n        const auto indptr_data = indptr.data_ptr<indptr_t>();\n        auto out_indptr_data = output_indptr.data_ptr<indptr_t>();\n        out_indptr_data[0] = 0;\n        AT_DISPATCH_INDEX_TYPES(\n            
nodes.scalar_type(), \"IndexSelectCSCBatched::nodes\", ([&] {\n              const auto nodes_data = nodes.data_ptr<index_t>();\n              torch::parallel_for(\n                  0, num_nodes, kDefaultGrainSize,\n                  [&](int64_t begin, int64_t end) {\n                    for (int64_t i = begin; i < end; i++) {\n                      const auto node_id = nodes_data[i];\n                      const auto degree =\n                          indptr_data[node_id + 1] - indptr_data[node_id];\n                      out_indptr_data[i + 1] = degree;\n                    }\n                  });\n              output_indptr = output_indptr.cumsum(0, indptr.scalar_type());\n              out_indptr_data = output_indptr.data_ptr<indptr_t>();\n              TORCH_CHECK(\n                  !output_size.has_value() ||\n                      out_indptr_data[num_nodes] == *output_size,\n                  \"An incorrect output_size argument was provided.\");\n              output_size = out_indptr_data[num_nodes];\n              for (const auto& indices : indices_list) {\n                results.push_back(torch::empty(\n                    *output_size,\n                    nodes.options().dtype(indices.scalar_type())));\n              }\n              if (with_edge_ids) {\n                edge_ids = torch::empty(\n                    *output_size, nodes.options().dtype(indptr.scalar_type()));\n              }\n              torch::parallel_for(\n                  0, num_nodes, kDefaultGrainSize,\n                  [&](int64_t begin, int64_t end) {\n                    for (int64_t i = begin; i < end; i++) {\n                      const auto output_offset = out_indptr_data[i];\n                      const auto numel = out_indptr_data[i + 1] - output_offset;\n                      const auto input_offset = indptr_data[nodes_data[i]];\n                      for (size_t tensor_id = 0;\n                           tensor_id < indices_list.size(); tensor_id++) {\n         
               auto output = reinterpret_cast<std::byte*>(\n                            results[tensor_id].data_ptr());\n                        const auto input = reinterpret_cast<std::byte*>(\n                            indices_list[tensor_id].data_ptr());\n                        const auto element_size =\n                            indices_list[tensor_id].element_size();\n                        std::memcpy(\n                            output + output_offset * element_size,\n                            input + input_offset * element_size,\n                            element_size * numel);\n                      }\n                      if (edge_ids.has_value()) {\n                        auto output = edge_ids->data_ptr<indptr_t>();\n                        std::iota(\n                            output + output_offset,\n                            output + output_offset + numel, input_offset);\n                      }\n                    }\n                  });\n            }));\n      }));\n  if (edge_ids) results.push_back(*edge_ids);\n  return std::make_tuple(output_indptr, results);\n}\n\nc10::intrusive_ptr<\n    Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>\nIndexSelectCSCBatchedAsync(\n    torch::Tensor indptr, std::vector<torch::Tensor> indices_list,\n    torch::Tensor nodes, bool with_edge_ids,\n    torch::optional<int64_t> output_size) {\n  return async(\n      [=] {\n        return IndexSelectCSCBatched(\n            indptr, indices_list, nodes, with_edge_ids, output_size);\n      },\n      utils::is_on_gpu(nodes));\n}\n\n}  // namespace ops\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/index_select.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file index_select.h\n * @brief Index select operators.\n */\n#ifndef GRAPHBOLT_INDEX_SELECT_H_\n#define GRAPHBOLT_INDEX_SELECT_H_\n\n#include <graphbolt/async.h>\n#include <torch/script.h>\n\nnamespace graphbolt {\nnamespace ops {\n\n/**\n * @brief Select columns for a sparse matrix in a CSC format according to nodes\n * tensor.\n *\n * NOTE:\n * 1. The shape of all tensors must be 1-D.\n * 2. If indices is on pinned memory and nodes is on pinned memory or GPU\n * memory, then UVAIndexSelectCSCImpl will be called. If indices is on GPU\n * memory, then IndexSelectCSCImpl will be called. Otherwise,\n * FusedCSCSamplingGraph::InSubgraph will be called.\n *\n * @param indptr Indptr tensor containing offsets with shape (N,).\n * @param indices Indices tensor with edge information of shape (indptr[N],).\n * @param nodes Nodes tensor with shape (M,).\n * @param output_size The total number of edges being copied.\n * @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of\n * shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).\n */\nstd::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(\n    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,\n    torch::optional<int64_t> output_size = torch::nullopt);\n\n/**\n * @brief Select rows from input tensor according to index tensor.\n *\n * NOTE:\n * 1. The shape of input tensor can be multi-dimensional, but the index tensor\n * must be 1-D.\n * 2. If input is on pinned memory and index is on pinned memory or GPU memory,\n * then UVAIndexSelectImpl will be called. 
Otherwise, torch::index_select will\n * be called.\n *\n * @param input Input tensor with shape (N, ...).\n * @param index Index tensor with shape (M,).\n * @return torch::Tensor Output tensor with shape (M, ...).\n */\ntorch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index);\n\n/**\n * @brief The async version of IndexSelect, available for only CPU tensors.\n *\n * @return Returns a future containing a torch::Tensor.\n */\nc10::intrusive_ptr<Future<torch::Tensor>> IndexSelectAsync(\n    torch::Tensor input, torch::Tensor index);\n\n/**\n * @brief The async version of operation input[index] = src.\n * @param input The input tensor.\n * @param index The index tensor into input.\n * @param src The src tensor being assigned into input.\n *\n * @return Returns a future containing input, a torch::Tensor.\n */\nc10::intrusive_ptr<Future<torch::Tensor>> ScatterAsync(\n    torch::Tensor input, torch::Tensor index, torch::Tensor src);\n\n/**\n * @brief Select columns for a sparse matrix in a CSC format according to nodes\n * tensor.\n *\n * NOTE: The shape of all tensors must be 1-D.\n *\n * @param indptr Indptr tensor containing offsets with shape (N,).\n * @param indices_list Vector of indices tensor with edge information of shape\n * (indptr[N],).\n * @param nodes Nodes tensor with shape (M,).\n * @param with_edge_ids Whether to return edge ids tensor corresponding to\n * sliced edges as the last element of the output.\n * @param output_size The total number of edges being copied.\n *\n * @return (torch::Tensor, std::vector<torch::Tensor>) Output indptr and vector\n * of indices tensors of shapes (M + 1,) and ((indptr[nodes + 1] -\n * indptr[nodes]).sum(),).\n */\nstd::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(\n    torch::Tensor indptr, std::vector<torch::Tensor> indices_list,\n    torch::Tensor nodes, bool with_edge_ids,\n    torch::optional<int64_t> output_size);\n\nc10::intrusive_ptr<\n    Future<std::tuple<torch::Tensor, 
std::vector<torch::Tensor>>>>\nIndexSelectCSCBatchedAsync(\n    torch::Tensor indptr, std::vector<torch::Tensor> indices_list,\n    torch::Tensor nodes, bool with_edge_ids,\n    torch::optional<int64_t> output_size);\n\n}  // namespace ops\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_INDEX_SELECT_H_\n"
  },
  {
    "path": "graphbolt/src/io_uring.cc",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file io_uring.cc\n * @brief io_uring related functions.\n */\n#include \"./io_uring.h\"\n\n#ifdef HAVE_LIBRARY_LIBURING\n\n#include <errno.h>\n#include <liburing.h>\n#include <liburing/io_uring.h>\n#include <stddef.h>\n#include <sys/syscall.h>\n#include <unistd.h>\n\n#include <memory>\n#include <mutex>\n\nstruct io_uring_probe_destroyer {\n  void operator()(struct io_uring_probe* p) {\n    if (p) io_uring_free_probe(p);\n  }\n};\n#endif\n\nnamespace graphbolt {\nnamespace io_uring {\n\nbool IsAvailable() {\n#ifdef HAVE_LIBRARY_LIBURING\n  /** @brief The cached value of whether io_uring is available. */\n  static bool cached_is_available;\n\n  /** @brief Ensure cached_is_available is initialized once and thread-safe. 
*/\n  static std::once_flag initialization_flag;\n\n  std::call_once(initialization_flag, []() {\n    // https://unix.stackexchange.com/a/596284/314554\n    cached_is_available =\n        !(syscall(\n              __NR_io_uring_register, 0, IORING_UNREGISTER_BUFFERS, NULL, 0) &&\n          errno == ENOSYS);\n\n    std::unique_ptr<struct io_uring_probe, io_uring_probe_destroyer> probe(\n        io_uring_get_probe(), io_uring_probe_destroyer());\n    if (probe.get()) {\n      cached_is_available =\n          cached_is_available &&\n          io_uring_opcode_supported(probe.get(), IORING_OP_READ);\n      cached_is_available =\n          cached_is_available &&\n          io_uring_opcode_supported(probe.get(), IORING_OP_READV);\n    } else {\n      cached_is_available = false;\n    }\n  });\n\n  return cached_is_available;\n#else\n  return false;\n#endif\n}\n\nvoid SetNumThreads(int64_t count) { num_threads = count; }\n\n}  // namespace io_uring\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/io_uring.h",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file io_uring.h\n * @brief io_uring related functions.\n */\n#ifndef GRAPHBOLT_IO_URING_H_\n#define GRAPHBOLT_IO_URING_H_\n\n#include <cstdint>\n#include <optional>\n\nnamespace graphbolt {\nnamespace io_uring {\n\nbool IsAvailable();\n\n/** @brief Set a limit on # background io_uring threads. */\ninline std::optional<int64_t> num_threads;\n\n/**\n * @brief Set the number of background io_uring threads.\n */\nvoid SetNumThreads(int64_t count);\n\n}  // namespace io_uring\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_IO_URING_H_\n"
  },
  {
    "path": "graphbolt/src/isin.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n *\n * @file isin.cc\n * @brief Isin op.\n */\n\n#include <graphbolt/cuda_ops.h>\n#include <graphbolt/isin.h>\n\n#include \"./macro.h\"\n#include \"./utils.h\"\n\nnamespace {\nstatic constexpr int kSearchGrainSize = 4096;\n}  // namespace\n\nnamespace graphbolt {\nnamespace sampling {\n\ntorch::Tensor IsInCPU(\n    const torch::Tensor& elements, const torch::Tensor& test_elements) {\n  torch::Tensor sorted_test_elements;\n  std::tie(sorted_test_elements, std::ignore) = test_elements.sort(\n      /*stable=*/false, /*dim=*/0, /*descending=*/false);\n  torch::Tensor result = torch::empty_like(elements, torch::kBool);\n  size_t num_test_elements = test_elements.size(0);\n  size_t num_elements = elements.size(0);\n\n  AT_DISPATCH_INTEGRAL_TYPES(\n      elements.scalar_type(), \"IsInOperation\", ([&] {\n        const scalar_t* elements_ptr = elements.data_ptr<scalar_t>();\n        const scalar_t* sorted_test_elements_ptr =\n            sorted_test_elements.data_ptr<scalar_t>();\n        bool* result_ptr = result.data_ptr<bool>();\n        torch::parallel_for(\n            0, num_elements, kSearchGrainSize, [&](size_t start, size_t end) {\n              for (auto i = start; i < end; i++) {\n                result_ptr[i] = std::binary_search(\n                    sorted_test_elements_ptr,\n                    sorted_test_elements_ptr + num_test_elements,\n                    elements_ptr[i]);\n              }\n            });\n      }));\n  return result;\n}\n\ntorch::Tensor IsIn(\n    const torch::Tensor& elements, const torch::Tensor& test_elements) {\n  if (utils::is_on_gpu(elements) && utils::is_on_gpu(test_elements)) {\n    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(\n        c10::DeviceType::CUDA, \"IsInOperation\",\n        { return ops::IsIn(elements, test_elements); });\n  } else {\n    return IsInCPU(elements, test_elements);\n  }\n}\n\ntorch::Tensor IsNotInIndex(\n    const torch::Tensor& elements, const 
torch::Tensor& test_elements) {\n  auto mask = IsIn(elements, test_elements);\n  if (utils::is_on_gpu(mask)) {\n    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(\n        c10::DeviceType::CUDA, \"NonzeroOperation\",\n        { return ops::Nonzero(mask, true); });\n  }\n  return torch::nonzero(torch::logical_not(mask)).squeeze(1);\n}\n\nc10::intrusive_ptr<Future<torch::Tensor>> IsNotInIndexAsync(\n    const torch::Tensor& elements, const torch::Tensor& test_elements) {\n  return async([=] { return IsNotInIndex(elements, test_elements); });\n}\n\n}  // namespace sampling\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/macro.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file macro.h\n * @brief Graphbolt macros.\n */\n\n#ifndef GRAPHBOLT_MACRO_H_\n#define GRAPHBOLT_MACRO_H_\n\n#include <torch/script.h>\n\nnamespace graphbolt {\n\n// Dispatch operator implementation function to CUDA device only.\n#ifdef GRAPHBOLT_USE_CUDA\n#define GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(device_type, name, ...) \\\n  if (device_type == c10::DeviceType::CUDA) {                       \\\n    [[maybe_unused]] auto XPU = c10::DeviceType::CUDA;              \\\n    __VA_ARGS__                                                     \\\n  } else {                                                          \\\n    TORCH_CHECK(false, name, \" is only available on CUDA device.\"); \\\n  }\n#else\n#define GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(device_type, name, ...) \\\n  TORCH_CHECK(false, name, \" is only available on CUDA device.\");\n#endif\n\n// This includes all integer, float and boolean types.\n#define GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(...)            \\\n  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)                 \\\n  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \\\n  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \\\n  AT_DISPATCH_CASE(at::ScalarType::Bool, __VA_ARGS__)\n\n#define GRAPHBOLT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \\\n  AT_DISPATCH_SWITCH(TYPE, NAME, GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))\n\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_MACRO_H_\n"
  },
  {
    "path": "graphbolt/src/partitioned_cache_policy.cc",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file partitioned_cache_policy.cc\n * @brief Partitioned cache policy implementation on the CPU.\n */\n#include \"./partitioned_cache_policy.h\"\n\n#include <algorithm>\n#include <limits>\n#include <numeric>\n\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace storage {\n\nconstexpr int kIntGrainSize = 256;\n\ntorch::Tensor AddOffset(torch::Tensor keys, int64_t offset) {\n  if (offset == 0) return keys;\n  auto output = torch::empty_like(\n      keys, keys.options().pinned_memory(utils::is_pinned(keys)));\n  AT_DISPATCH_INDEX_TYPES(\n      keys.scalar_type(), \"AddOffset\", ([&] {\n        auto keys_ptr = keys.data_ptr<index_t>();\n        auto output_ptr = output.data_ptr<index_t>();\n        graphbolt::parallel_for_each(\n            0, keys.numel(), kIntGrainSize, [&](int64_t i) {\n              const auto result = keys_ptr[i] + offset;\n              if constexpr (!std::is_same_v<index_t, int64_t>) {\n                TORCH_CHECK(\n                    std::numeric_limits<index_t>::min() <= result &&\n                    result <= std::numeric_limits<index_t>::max());\n              }\n              output_ptr[i] = static_cast<index_t>(result);\n            });\n      }));\n  return output;\n}\n\ntemplate <typename 
CachePolicy>\nPartitionedCachePolicy::PartitionedCachePolicy(\n    CachePolicy, int64_t capacity, int64_t num_partitions)\n    : capacity_(capacity) {\n  TORCH_CHECK(num_partitions >= 1, \"# partitions need to be positive.\");\n  for (int64_t i = 0; i < num_partitions; i++) {\n    const auto begin = i * capacity / num_partitions;\n    const auto end = (i + 1) * capacity / num_partitions;\n    policies_.emplace_back(std::make_unique<CachePolicy>(end - begin));\n  }\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor>\nPartitionedCachePolicy::Partition(torch::Tensor keys) {\n  const int64_t num_parts = policies_.size();\n  torch::Tensor offsets = torch::empty(\n      num_parts * num_parts + 1, keys.options().dtype(torch::kInt64));\n  auto offsets_ptr = offsets.data_ptr<int64_t>();\n  std::fill_n(offsets_ptr, offsets.size(0), int64_t{});\n  auto indices = torch::empty_like(keys, keys.options().dtype(torch::kInt64));\n  auto part_id = torch::empty_like(keys, keys.options().dtype(torch::kInt32));\n  const auto num_keys = keys.size(0);\n  auto part_id_ptr = part_id.data_ptr<int32_t>();\n  AT_DISPATCH_INDEX_TYPES(\n      keys.scalar_type(), \"PartitionedCachePolicy::partition\", ([&] {\n        auto keys_ptr = keys.data_ptr<index_t>();\n        namespace gb = graphbolt;\n        gb::parallel_for_each(0, num_parts, 1, [&](int64_t tid) {\n          const auto begin = tid * num_keys / num_parts;\n          const auto end = (tid + 1) * num_keys / num_parts;\n          for (int64_t i = begin; i < end; i++) {\n            const auto part_id = PartAssignment(keys_ptr[i]);\n            offsets_ptr[tid * num_parts + part_id]++;\n            part_id_ptr[i] = part_id;\n          }\n        });\n      }));\n\n  // Transpose the offsets tensor, take cumsum and transpose back.\n  auto offsets_permuted = torch::empty_like(offsets);\n  auto offsets_permuted_ptr = offsets_permuted.data_ptr<int64_t>();\n  graphbolt::parallel_for_each(\n      0, num_parts * num_parts, 
kIntGrainSize, [&](int64_t i) {\n        const auto part_id = i % num_parts;\n        const auto tid = i / num_parts;\n        // + 1 so that we have exclusive_scan after torch.cumsum().\n        offsets_permuted_ptr[part_id * num_parts + tid + 1] = offsets_ptr[i];\n      });\n  offsets_permuted_ptr[0] = 0;\n  // offsets = offsets_permuted.cumsum(0); @TODO implement this in parallel.\n  std::inclusive_scan(\n      offsets_permuted_ptr, offsets_permuted_ptr + num_parts * num_parts + 1,\n      offsets_ptr);\n  offsets_ptr = offsets.data_ptr<int64_t>();\n  graphbolt::parallel_for_each(\n      0, num_parts * num_parts, kIntGrainSize, [&](int64_t i) {\n        const auto part_id = i % num_parts;\n        const auto tid = i / num_parts;\n        offsets_permuted_ptr[i] = offsets_ptr[part_id * num_parts + tid];\n      });\n  auto indices_ptr = indices.data_ptr<int64_t>();\n  auto permuted_keys = torch::empty_like(keys);\n  auto offsets_sliced = torch::empty(num_parts + 1, offsets.options());\n  auto offsets_sliced_ptr = offsets_sliced.data_ptr<int64_t>();\n  offsets_sliced_ptr[0] = 0;\n  AT_DISPATCH_INDEX_TYPES(\n      keys.scalar_type(), \"PartitionedCachePolicy::partition\", ([&] {\n        auto keys_ptr = keys.data_ptr<index_t>();\n        auto permuted_keys_ptr = permuted_keys.data_ptr<index_t>();\n        namespace gb = graphbolt;\n        gb::parallel_for_each(0, num_parts, 1, [&](int64_t tid) {\n          const auto begin = tid * num_keys / num_parts;\n          const auto end = (tid + 1) * num_keys / num_parts;\n          for (int64_t i = begin; i < end; i++) {\n            const auto part_id = part_id_ptr[i];\n            auto& offset = offsets_permuted_ptr[tid * num_parts + part_id];\n            indices_ptr[offset] = i;\n            permuted_keys_ptr[offset++] = keys_ptr[i];\n          }\n          offsets_sliced_ptr[tid + 1] = offsets_ptr[(tid + 1) * num_parts];\n        });\n      }));\n  return {offsets_sliced, indices, permuted_keys};\n}\n\nstd::tuple<\n    
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,\n    torch::Tensor>\nPartitionedCachePolicy::Query(torch::Tensor keys, const int64_t offset) {\n  keys = AddOffset(keys, offset);\n  if (policies_.size() == 1) {\n    std::lock_guard lock(mtx_);\n    auto [positions, output_indices, missing_keys, found_pointers] =\n        policies_[0]->Query(keys);\n    auto found_and_missing_offsets = torch::empty(4, found_pointers.options());\n    auto found_and_missing_offsets_ptr =\n        found_and_missing_offsets.data_ptr<int64_t>();\n    // Found offsets part.\n    found_and_missing_offsets_ptr[0] = 0;\n    found_and_missing_offsets_ptr[1] = found_pointers.size(0);\n    // Missing offsets part.\n    found_and_missing_offsets_ptr[2] = 0;\n    found_and_missing_offsets_ptr[3] = missing_keys.size(0);\n    auto found_offsets = found_and_missing_offsets.slice(0, 0, 2);\n    auto missing_offsets = found_and_missing_offsets.slice(0, 2);\n    missing_keys = AddOffset(missing_keys, -offset);\n    return {positions,      output_indices, missing_keys,\n            found_pointers, found_offsets,  missing_offsets};\n  };\n  torch::Tensor offsets, indices, permuted_keys;\n  std::tie(offsets, indices, permuted_keys) = Partition(keys);\n  auto offsets_ptr = offsets.data_ptr<int64_t>();\n  auto indices_ptr = indices.data_ptr<int64_t>();\n  std::vector<\n      std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>\n      results(policies_.size());\n  torch::Tensor result_offsets_tensor =\n      torch::empty(policies_.size() * 2 + 1, offsets.options());\n  auto result_offsets = result_offsets_tensor.data_ptr<int64_t>();\n  namespace gb = graphbolt;\n  {\n    std::lock_guard lock(mtx_);\n    gb::parallel_for_each(0, policies_.size(), 1, [&](int64_t tid) {\n      const auto begin = offsets_ptr[tid];\n      const auto end = offsets_ptr[tid + 1];\n      results[tid] =\n          policies_.at(tid)->Query(permuted_keys.slice(0, begin, end));\n      
result_offsets[tid] = std::get<0>(results[tid]).size(0);\n      result_offsets[tid + policies_.size()] =\n          std::get<2>(results[tid]).size(0);\n    });\n  }\n  std::exclusive_scan(\n      result_offsets, result_offsets + result_offsets_tensor.size(0),\n      result_offsets, 0);\n  torch::Tensor positions = torch::empty(\n      result_offsets[policies_.size()],\n      std::get<0>(results[0]).options().pinned_memory(utils::is_pinned(keys)));\n  torch::Tensor output_indices = torch::empty_like(\n      indices, indices.options().pinned_memory(utils::is_pinned(keys)));\n  torch::Tensor missing_keys = torch::empty(\n      indices.size(0) - positions.size(0),\n      std::get<2>(results[0]).options().pinned_memory(utils::is_pinned(keys)));\n  torch::Tensor found_pointers = torch::empty(\n      positions.size(0),\n      std::get<3>(results[0]).options().pinned_memory(utils::is_pinned(keys)));\n  auto missing_offsets =\n      torch::empty(policies_.size() + 1, result_offsets_tensor.options());\n  auto output_indices_ptr = output_indices.data_ptr<int64_t>();\n  auto missing_offsets_ptr = missing_offsets.data_ptr<int64_t>();\n  missing_offsets_ptr[0] = 0;\n  gb::parallel_for_each(0, policies_.size(), 1, [&](int64_t tid) {\n    auto out_index_ptr = indices_ptr + offsets_ptr[tid];\n    auto begin = result_offsets[tid];\n    auto end = result_offsets[tid + 1];\n    const auto num_selected = end - begin;\n    auto indices_ptr = std::get<1>(results[tid]).data_ptr<int64_t>();\n    for (int64_t i = 0; i < num_selected; i++) {\n      output_indices_ptr[begin + i] = out_index_ptr[indices_ptr[i]];\n    }\n    auto selected_positions_ptr = std::get<0>(results[tid]).data_ptr<int64_t>();\n    std::transform(\n        selected_positions_ptr, selected_positions_ptr + num_selected,\n        positions.data_ptr<int64_t>() + begin,\n        [off = tid * capacity_ / policies_.size()](auto x) { return x + off; });\n    auto selected_pointers_ptr = 
std::get<3>(results[tid]).data_ptr<int64_t>();\n    std::copy(\n        selected_pointers_ptr, selected_pointers_ptr + num_selected,\n        found_pointers.data_ptr<int64_t>() + begin);\n    begin = result_offsets[policies_.size() + tid];\n    end = result_offsets[policies_.size() + tid + 1];\n    missing_offsets[tid + 1] = end - result_offsets[policies_.size()];\n    const auto num_missing = end - begin;\n    for (int64_t i = 0; i < num_missing; i++) {\n      output_indices_ptr[begin + i] =\n          out_index_ptr[indices_ptr[i + num_selected]];\n    }\n    std::memcpy(\n        reinterpret_cast<std::byte*>(missing_keys.data_ptr()) +\n            (begin - positions.size(0)) * missing_keys.element_size(),\n        std::get<2>(results[tid]).data_ptr(),\n        num_missing * missing_keys.element_size());\n  });\n  auto found_offsets = result_offsets_tensor.slice(0, 0, policies_.size() + 1);\n  missing_keys = AddOffset(missing_keys, -offset);\n  return std::make_tuple(\n      positions, output_indices, missing_keys, found_pointers, found_offsets,\n      missing_offsets);\n}\n\nc10::intrusive_ptr<Future<std::vector<torch::Tensor>>>\nPartitionedCachePolicy::QueryAsync(torch::Tensor keys, const int64_t offset) {\n  return async([=] {\n    auto\n        [positions, output_indices, missing_keys, found_pointers, found_offsets,\n         missing_offsets] = Query(keys, offset);\n    return std::vector{positions,      output_indices, missing_keys,\n                       found_pointers, found_offsets,  missing_offsets};\n  });\n}\n\nstd::tuple<\n    torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,\n    torch::Tensor>\nPartitionedCachePolicy::QueryAndReplace(\n    torch::Tensor keys, const int64_t offset) {\n  keys = AddOffset(keys, offset);\n  if (policies_.size() == 1) {\n    std::lock_guard lock(mtx_);\n    auto [positions, output_indices, pointers, missing_keys] =\n        policies_[0]->QueryAndReplace(keys);\n    auto found_and_missing_offsets 
= torch::empty(4, pointers.options());\n    auto found_and_missing_offsets_ptr =\n        found_and_missing_offsets.data_ptr<int64_t>();\n    // Found offsets part.\n    found_and_missing_offsets_ptr[0] = 0;\n    found_and_missing_offsets_ptr[1] = keys.size(0) - missing_keys.size(0);\n    // Missing offsets part.\n    found_and_missing_offsets_ptr[2] = 0;\n    found_and_missing_offsets_ptr[3] = missing_keys.size(0);\n    auto found_offsets = found_and_missing_offsets.slice(0, 0, 2);\n    auto missing_offsets = found_and_missing_offsets.slice(0, 2);\n    missing_keys = AddOffset(missing_keys, -offset);\n    return {positions,    output_indices, pointers,\n            missing_keys, found_offsets,  missing_offsets};\n  }\n  torch::Tensor offsets, indices, permuted_keys;\n  std::tie(offsets, indices, permuted_keys) = Partition(keys);\n  auto offsets_ptr = offsets.data_ptr<int64_t>();\n  auto indices_ptr = indices.data_ptr<int64_t>();\n  std::vector<\n      std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>\n      results(policies_.size());\n  torch::Tensor result_offsets_tensor =\n      torch::empty(policies_.size() * 2 + 1, offsets.options());\n  auto result_offsets = result_offsets_tensor.data_ptr<int64_t>();\n  namespace gb = graphbolt;\n  {\n    std::lock_guard lock(mtx_);\n    gb::parallel_for_each(0, policies_.size(), 1, [&](int64_t tid) {\n      const auto begin = offsets_ptr[tid];\n      const auto end = offsets_ptr[tid + 1];\n      results[tid] = policies_.at(tid)->QueryAndReplace(\n          permuted_keys.slice(0, begin, end));\n      const auto missing_cnt = std::get<3>(results[tid]).size(0);\n      result_offsets[tid] = end - begin - missing_cnt;\n      result_offsets[tid + policies_.size()] = missing_cnt;\n    });\n  }\n  std::exclusive_scan(\n      result_offsets, result_offsets + result_offsets_tensor.size(0),\n      result_offsets, 0);\n  torch::Tensor positions = torch::empty(\n      keys.size(0),\n      
std::get<0>(results[0]).options().pinned_memory(utils::is_pinned(keys)));\n  torch::Tensor output_indices = torch::empty_like(\n      indices, indices.options().pinned_memory(utils::is_pinned(keys)));\n  torch::Tensor pointers = torch::empty(\n      keys.size(0),\n      std::get<2>(results[0]).options().pinned_memory(utils::is_pinned(keys)));\n  torch::Tensor missing_keys = torch::empty(\n      result_offsets[2 * policies_.size()] - result_offsets[policies_.size()],\n      std::get<3>(results[0]).options().pinned_memory(utils::is_pinned(keys)));\n  auto missing_offsets =\n      torch::empty(policies_.size() + 1, result_offsets_tensor.options());\n  auto positions_ptr = positions.data_ptr<int64_t>();\n  auto output_indices_ptr = output_indices.data_ptr<int64_t>();\n  auto pointers_ptr = pointers.data_ptr<int64_t>();\n  auto missing_offsets_ptr = missing_offsets.data_ptr<int64_t>();\n  missing_offsets_ptr[0] = 0;\n  gb::parallel_for_each(0, policies_.size(), 1, [&](int64_t tid) {\n    auto out_index_ptr = indices_ptr + offsets_ptr[tid];\n    auto begin = result_offsets[tid];\n    auto end = result_offsets[tid + 1];\n    const auto num_selected = end - begin;\n    auto indices_ptr = std::get<1>(results[tid]).data_ptr<int64_t>();\n    for (int64_t i = 0; i < num_selected; i++) {\n      output_indices_ptr[begin + i] = out_index_ptr[indices_ptr[i]];\n    }\n    auto selected_positions_ptr = std::get<0>(results[tid]).data_ptr<int64_t>();\n    std::transform(\n        selected_positions_ptr, selected_positions_ptr + num_selected,\n        positions_ptr + begin,\n        [off = tid * capacity_ / policies_.size()](auto x) { return x + off; });\n    auto selected_pointers_ptr = std::get<2>(results[tid]).data_ptr<int64_t>();\n    std::copy(\n        selected_pointers_ptr, selected_pointers_ptr + num_selected,\n        pointers_ptr + begin);\n    begin = result_offsets[policies_.size() + tid];\n    end = result_offsets[policies_.size() + tid + 1];\n    missing_offsets[tid + 1] 
= end - result_offsets[policies_.size()];\n    const auto num_missing = end - begin;\n    for (int64_t i = 0; i < num_missing; i++) {\n      output_indices_ptr[begin + i] =\n          out_index_ptr[indices_ptr[i + num_selected]];\n    }\n    auto missing_positions_ptr = selected_positions_ptr + num_selected;\n    std::transform(\n        missing_positions_ptr, missing_positions_ptr + num_missing,\n        positions_ptr + begin,\n        [off = tid * capacity_ / policies_.size()](auto x) { return x + off; });\n    auto missing_pointers_ptr = selected_pointers_ptr + num_selected;\n    std::copy(\n        missing_pointers_ptr, missing_pointers_ptr + num_missing,\n        pointers_ptr + begin);\n    std::memcpy(\n        reinterpret_cast<std::byte*>(missing_keys.data_ptr()) +\n            (begin - result_offsets[policies_.size()]) *\n                missing_keys.element_size(),\n        std::get<3>(results[tid]).data_ptr(),\n        num_missing * missing_keys.element_size());\n  });\n  auto found_offsets = result_offsets_tensor.slice(0, 0, policies_.size() + 1);\n  missing_keys = AddOffset(missing_keys, -offset);\n  return std::make_tuple(\n      positions, output_indices, pointers, missing_keys, found_offsets,\n      missing_offsets);\n}\n\nc10::intrusive_ptr<Future<std::vector<torch::Tensor>>>\nPartitionedCachePolicy::QueryAndReplaceAsync(\n    torch::Tensor keys, const int64_t offset) {\n  return async([=] {\n    auto\n        [positions, output_indices, pointers, missing_keys, found_offsets,\n         missing_offsets] = QueryAndReplace(keys, offset);\n    return std::vector{positions,    output_indices, pointers,\n                       missing_keys, found_offsets,  missing_offsets};\n  });\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor>\nPartitionedCachePolicy::Replace(\n    torch::Tensor keys, torch::optional<torch::Tensor> offsets,\n    const int64_t offset) {\n  keys = AddOffset(keys, offset);\n  if (policies_.size() == 1) {\n    std::lock_guard 
lock(mtx_);\n    auto [positions, pointers] = policies_[0]->Replace(keys);\n    if (!offsets.has_value()) {\n      offsets = torch::empty(2, pointers.options());\n      auto offsets_ptr = offsets->data_ptr<int64_t>();\n      offsets_ptr[0] = 0;\n      offsets_ptr[1] = pointers.size(0);\n    }\n    return {positions, pointers, *offsets};\n  }\n  const auto offsets_provided = offsets.has_value();\n  torch::Tensor indices, permuted_keys;\n  if (!offsets_provided) {\n    std::tie(offsets, indices, permuted_keys) = Partition(keys);\n  } else {\n    permuted_keys = keys;\n  }\n  auto output_positions = torch::empty_like(\n      keys, keys.options()\n                .dtype(torch::kInt64)\n                .pinned_memory(utils::is_pinned(keys)));\n  auto output_pointers = torch::empty_like(\n      keys, keys.options()\n                .dtype(torch::kInt64)\n                .pinned_memory(utils::is_pinned(keys)));\n  auto offsets_ptr = offsets->data_ptr<int64_t>();\n  auto indices_ptr = offsets_provided ? 
nullptr : indices.data_ptr<int64_t>();\n  auto output_positions_ptr = output_positions.data_ptr<int64_t>();\n  auto output_pointers_ptr = output_pointers.data_ptr<int64_t>();\n  namespace gb = graphbolt;\n  std::unique_lock lock(mtx_);\n  std::atomic<size_t> semaphore = policies_.size();\n  gb::parallel_for_each(0, policies_.size(), 1, [&](int64_t tid) {\n    const auto begin = offsets_ptr[tid];\n    const auto end = offsets_ptr[tid + 1];\n    auto [positions, pointers] =\n        policies_.at(tid)->Replace(permuted_keys.slice(0, begin, end));\n    const auto ticket = semaphore.fetch_add(-1, std::memory_order_release) - 1;\n    if (ticket == 0) {\n      // This thread was the last thread in the critical region.\n      lock.unlock();\n    }\n    auto positions_ptr = positions.data_ptr<int64_t>();\n    const auto off = tid * capacity_ / policies_.size();\n    if (indices_ptr) {\n      for (int64_t i = 0; i < positions.size(0); i++) {\n        output_positions_ptr[indices_ptr[begin + i]] = positions_ptr[i] + off;\n      }\n    } else {\n      std::transform(\n          positions_ptr, positions_ptr + positions.size(0),\n          output_positions_ptr + begin, [off](auto x) { return x + off; });\n    }\n    auto pointers_ptr = pointers.data_ptr<int64_t>();\n    std::copy(\n        pointers_ptr, pointers_ptr + pointers.size(0),\n        output_pointers_ptr + begin);\n  });\n  return {output_positions, output_pointers, *offsets};\n}\n\nc10::intrusive_ptr<Future<std::vector<torch::Tensor>>>\nPartitionedCachePolicy::ReplaceAsync(\n    torch::Tensor keys, torch::optional<torch::Tensor> offsets,\n    const int64_t offset) {\n  return async([=] {\n    auto [positions, pointers, offsets_out] = Replace(keys, offsets, offset);\n    return std::vector{positions, pointers, offsets_out};\n  });\n}\n\ntemplate <bool write>\nvoid PartitionedCachePolicy::ReadingWritingCompletedImpl(\n    torch::Tensor pointers, torch::Tensor offsets) {\n  if (policies_.size() == 1) {\n    if constexpr 
(write)\n      policies_[0]->WritingCompleted(pointers);\n    else\n      policies_[0]->ReadingCompleted(pointers);\n    return;\n  }\n  auto offsets_ptr = offsets.data_ptr<int64_t>();\n  namespace gb = graphbolt;\n  gb::parallel_for_each(0, policies_.size(), 1, [&](int64_t tid) {\n    const auto begin = offsets_ptr[tid];\n    const auto end = offsets_ptr[tid + 1];\n    if constexpr (write)\n      policies_.at(tid)->WritingCompleted(pointers.slice(0, begin, end));\n    else\n      policies_.at(tid)->ReadingCompleted(pointers.slice(0, begin, end));\n  });\n}\n\nvoid PartitionedCachePolicy::ReadingCompleted(\n    torch::Tensor pointers, torch::Tensor offsets) {\n  ReadingWritingCompletedImpl<false>(pointers, offsets);\n}\n\nvoid PartitionedCachePolicy::WritingCompleted(\n    torch::Tensor pointers, torch::Tensor offsets) {\n  ReadingWritingCompletedImpl<true>(pointers, offsets);\n}\n\nc10::intrusive_ptr<Future<void>> PartitionedCachePolicy::ReadingCompletedAsync(\n    torch::Tensor pointers, torch::Tensor offsets) {\n  return async([=] { return ReadingCompleted(pointers, offsets); });\n}\n\nc10::intrusive_ptr<Future<void>> PartitionedCachePolicy::WritingCompletedAsync(\n    torch::Tensor pointers, torch::Tensor offsets) {\n  return async([=] { return WritingCompleted(pointers, offsets); });\n}\n\ntemplate <typename CachePolicy>\nc10::intrusive_ptr<PartitionedCachePolicy> PartitionedCachePolicy::Create(\n    int64_t capacity, int64_t num_partitions) {\n  static_assert(std::is_base_of_v<BaseCachePolicy, CachePolicy>);\n  return c10::make_intrusive<PartitionedCachePolicy>(\n      CachePolicy(), capacity, num_partitions);\n}\n\ntemplate c10::intrusive_ptr<PartitionedCachePolicy>\n    PartitionedCachePolicy::Create<S3FifoCachePolicy>(int64_t, int64_t);\ntemplate c10::intrusive_ptr<PartitionedCachePolicy>\n    PartitionedCachePolicy::Create<SieveCachePolicy>(int64_t, int64_t);\ntemplate c10::intrusive_ptr<PartitionedCachePolicy>\n    
PartitionedCachePolicy::Create<LruCachePolicy>(int64_t, int64_t);\ntemplate c10::intrusive_ptr<PartitionedCachePolicy>\n    PartitionedCachePolicy::Create<ClockCachePolicy>(int64_t, int64_t);\n\n}  // namespace storage\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/partitioned_cache_policy.h",
    "content": "/**\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file partitioned_cache_policy.h\n * @brief Partitioned cache policy implementation on the CPU.\n */\n#ifndef GRAPHBOLT_PARTITIONED_CACHE_H_\n#define GRAPHBOLT_PARTITIONED_CACHE_H_\n\n#include <graphbolt/async.h>\n#include <torch/custom_class.h>\n#include <torch/torch.h>\n\n#include <mutex>\n#include <pcg_random.hpp>\n#include <random>\n#include <type_traits>\n#include <vector>\n\n#include \"./cache_policy.h\"\n\nnamespace graphbolt {\nnamespace storage {\n\n/**\n * @brief PartitionedCachePolicy works by partitioning the key space to a set\n * number of partitions that is provided as the second argument of its\n * constructor. 
Since the partitioning is random but deterministic, the caching\n * policy performance is not affected as the key distribution stays the same in\n * each partition.\n **/\nclass PartitionedCachePolicy : public torch::CustomClassHolder {\n public:\n  /**\n   * @brief The policy query function.\n   * @param capacity The capacity of the cache.\n   * @param num_partitions The number of caching policies instantiated in a\n   * one-to-one mapping to each partition.\n   */\n  template <typename CachePolicy>\n  PartitionedCachePolicy(CachePolicy, int64_t capacity, int64_t num_partitions);\n\n  /**\n   * @brief The policy query function.\n   * @param keys The keys to query the cache.\n   * @param offset The offset to be added to the keys.\n   *\n   * @return (positions, indices, missing_keys, found_ptrs, found_offsets,\n   * missing_offsets), where positions has the locations of the keys which were\n   * found in the cache, missing_keys has the keys that were not found and\n   * indices is defined such that keys[indices[:positions.size(0)]] gives us the\n   * keys for the found pointers and keys[indices[positions.size(0):]] is\n   * identical to missing_keys. The found_offsets tensor holds the partition\n   * offsets for the found pointers. 
The missing_offsets holds the partition\n   * offsets for the missing_keys.\n   */\n  std::tuple<\n      torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,\n      torch::Tensor>\n  Query(torch::Tensor keys, int64_t offset);\n\n  c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> QueryAsync(\n      torch::Tensor keys, int64_t offset);\n\n  /**\n   * @brief The policy query and then replace function.\n   * @param keys The keys to query the cache.\n   * @param offset The offset to be added to the keys.\n   *\n   * @return (positions, indices, pointers, missing_keys, found_offsets,\n   * missing_offsets), where positions has the locations of the keys which were\n   * emplaced into the cache, pointers point to the emplaced CacheKey pointers\n   * in the cache, missing_keys has the keys that were not found and just\n   * inserted and indices is defined such that keys[indices[:keys.size(0) -\n   * missing_keys.size(0)]] gives us the keys for the found keys and\n   * keys[indices[keys.size(0) - missing_keys.size(0):]] is identical to\n   * missing_keys. The found_offsets tensor holds the partition offsets for the\n   * found pointers. 
The missing_offsets holds the partition offsets for the\n   * missing_keys and missing pointers.\n   */\n  std::tuple<\n      torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,\n      torch::Tensor>\n  QueryAndReplace(torch::Tensor keys, int64_t offset);\n\n  c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> QueryAndReplaceAsync(\n      torch::Tensor keys, int64_t offset);\n\n  /**\n   * @brief The policy replace function.\n   * @param keys The keys to query the cache.\n   * @param offsets The partition offsets for the keys.\n   * @param offset The offset to be added to the keys.\n   *\n   * @return (positions, pointers, offsets), where positions holds the locations\n   * of the replaced entries in the cache, pointers holds the CacheKey pointers\n   * for the inserted keys and offsets holds the partition offsets for pointers.\n   */\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Replace(\n      torch::Tensor keys, torch::optional<torch::Tensor> offsets,\n      int64_t offset);\n\n  c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> ReplaceAsync(\n      torch::Tensor keys, torch::optional<torch::Tensor> offsets,\n      int64_t offset);\n\n  template <bool write>\n  void ReadingWritingCompletedImpl(\n      torch::Tensor pointers, torch::Tensor offsets);\n\n  /**\n   * @brief A reader has finished reading these keys, so they can be\n   * evicted.\n   * @param pointers The CacheKey pointers in the cache to unmark.\n   * @param offsets The partition offsets for the pointers.\n   */\n  void ReadingCompleted(torch::Tensor pointers, torch::Tensor offsets);\n\n  /**\n   * @brief A writer has finished writing these keys, so they can be evicted.\n   * @param pointers The CacheKey pointers in the cache to unmark.\n   * @param offsets The partition offsets for the pointers.\n   */\n  void WritingCompleted(torch::Tensor pointers, torch::Tensor offsets);\n\n  c10::intrusive_ptr<Future<void>> ReadingCompletedAsync(\n      torch::Tensor 
pointers, torch::Tensor offsets);\n\n  c10::intrusive_ptr<Future<void>> WritingCompletedAsync(\n      torch::Tensor pointers, torch::Tensor offsets);\n\n  template <typename CachePolicy>\n  static c10::intrusive_ptr<PartitionedCachePolicy> Create(\n      int64_t capacity, int64_t num_partitions);\n\n private:\n  static constexpr uint64_t seed = 1e9 + 7;\n\n  /**\n   * @brief Deterministic assignment of keys to different parts.\n   */\n  int32_t PartAssignment(int64_t key) {\n    pcg32 rng(seed, key);\n    std::uniform_int_distribution<int32_t> dist(0, policies_.size() - 1);\n    return dist(rng);\n  }\n\n  /**\n   * @brief The partition function for a given keys tensor.\n   * @param keys The keys to query the cache.\n   *\n   * @return (offsets, indices, permuted_keys), the returned tensors have the\n   * following properties:\n   * permuted_keys[offsets[i]: offsets[i + 1]] belong to part i and\n   * keys[indices] == permuted_keys\n   */\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Partition(\n      torch::Tensor keys);\n\n  int64_t capacity_;\n  std::vector<std::unique_ptr<BaseCachePolicy>> policies_;\n  std::mutex mtx_;\n};\n\n}  // namespace storage\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_PARTITIONED_CACHE_H_\n"
  },
  {
    "path": "graphbolt/src/python_binding.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file python_binding.cc\n * @brief Graph bolt library Python binding.\n */\n\n#include <graphbolt/fused_csc_sampling_graph.h>\n#include <graphbolt/isin.h>\n#include <graphbolt/serialize.h>\n#include <graphbolt/unique_and_compact.h>\n\n#ifdef GRAPHBOLT_USE_CUDA\n#include \"./cuda/cooperative_minibatching_utils.h\"\n#include \"./cuda/max_uva_threads.h\"\n#endif\n#include \"./cnumpy.h\"\n#include \"./feature_cache.h\"\n#include \"./index_select.h\"\n#include \"./io_uring.h\"\n#include \"./partitioned_cache_policy.h\"\n#include \"./random.h\"\n#include \"./utils.h\"\n\n#ifdef GRAPHBOLT_USE_CUDA\n#include \"./cuda/extension/gpu_cache.h\"\n#include \"./cuda/extension/gpu_graph_cache.h\"\n#endif\n\nnamespace graphbolt {\nnamespace sampling {\n\nTORCH_LIBRARY(graphbolt, m) {\n  m.class_<FusedSampledSubgraph>(\"FusedSampledSubgraph\")\n      .def(torch::init<>())\n      .def_readwrite(\"indptr\", &FusedSampledSubgraph::indptr)\n      .def_readwrite(\"indices\", &FusedSampledSubgraph::indices)\n      .def_readwrite(\n          \"original_row_node_ids\", &FusedSampledSubgraph::original_row_node_ids)\n      .def_readwrite(\n          \"original_column_node_ids\",\n          &FusedSampledSubgraph::original_column_node_ids)\n      .def_readwrite(\n          \"original_edge_ids\", &FusedSampledSubgraph::original_edge_ids)\n      .def_readwrite(\"type_per_edge\", &FusedSampledSubgraph::type_per_edge)\n      .def_readwrite(\"etype_offsets\", &FusedSampledSubgraph::etype_offsets);\n  m.class_<Future<void>>(\"VoidFuture\").def(\"wait\", &Future<void>::Wait);\n  m.class_<Future<torch::Tensor>>(\"TensorFuture\")\n      .def(\"wait\", &Future<torch::Tensor>::Wait);\n  m.class_<Future<std::vector<torch::Tensor>>>(\"TensorListFuture\")\n      .def(\"wait\", &Future<std::vector<torch::Tensor>>::Wait);\n  m.class_<Future<c10::intrusive_ptr<FusedSampledSubgraph>>>(\n       \"FusedSampledSubgraphFuture\")\n      .def(\"wait\", 
&Future<c10::intrusive_ptr<FusedSampledSubgraph>>::Wait);\n  m.class_<Future<std::vector<\n      std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>>(\n       \"UniqueAndCompactBatchedFuture\")\n      .def(\n          \"wait\",\n          &Future<std::vector<std::tuple<\n              torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>::\n              Wait);\n  m.class_<Future<\n      std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>(\n       \"RankSortFuture\")\n      .def(\n          \"wait\",\n          &Future<std::vector<\n              std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>::Wait);\n  m.class_<Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>(\n       \"GpuGraphCacheQueryFuture\")\n      .def(\n          \"wait\",\n          &Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>::\n              Wait);\n  m.class_<Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>(\n       \"GpuGraphCacheReplaceFuture\")\n      .def(\n          \"wait\",\n          &Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>::Wait);\n  m.class_<storage::OnDiskNpyArray>(\"OnDiskNpyArray\")\n      .def(\"index_select\", &storage::OnDiskNpyArray::IndexSelect);\n  m.class_<FusedCSCSamplingGraph>(\"FusedCSCSamplingGraph\")\n      .def(\"num_nodes\", &FusedCSCSamplingGraph::NumNodes)\n      .def(\"num_edges\", &FusedCSCSamplingGraph::NumEdges)\n      .def(\"csc_indptr\", &FusedCSCSamplingGraph::CSCIndptr)\n      .def(\"indices\", &FusedCSCSamplingGraph::Indices)\n      .def(\"node_type_offset\", &FusedCSCSamplingGraph::NodeTypeOffset)\n      .def(\"type_per_edge\", &FusedCSCSamplingGraph::TypePerEdge)\n      .def(\"node_type_to_id\", &FusedCSCSamplingGraph::NodeTypeToID)\n      .def(\"edge_type_to_id\", &FusedCSCSamplingGraph::EdgeTypeToID)\n      .def(\"node_attributes\", &FusedCSCSamplingGraph::NodeAttributes)\n      .def(\"edge_attributes\", 
&FusedCSCSamplingGraph::EdgeAttributes)\n      .def(\"node_attribute\", &FusedCSCSamplingGraph::NodeAttribute)\n      .def(\"edge_attribute\", &FusedCSCSamplingGraph::EdgeAttribute)\n      .def(\"set_csc_indptr\", &FusedCSCSamplingGraph::SetCSCIndptr)\n      .def(\"set_indices\", &FusedCSCSamplingGraph::SetIndices)\n      .def(\"set_node_type_offset\", &FusedCSCSamplingGraph::SetNodeTypeOffset)\n      .def(\"set_type_per_edge\", &FusedCSCSamplingGraph::SetTypePerEdge)\n      .def(\"set_node_type_to_id\", &FusedCSCSamplingGraph::SetNodeTypeToID)\n      .def(\"set_edge_type_to_id\", &FusedCSCSamplingGraph::SetEdgeTypeToID)\n      .def(\"set_node_attributes\", &FusedCSCSamplingGraph::SetNodeAttributes)\n      .def(\"set_edge_attributes\", &FusedCSCSamplingGraph::SetEdgeAttributes)\n      .def(\"add_node_attribute\", &FusedCSCSamplingGraph::AddNodeAttribute)\n      .def(\"add_edge_attribute\", &FusedCSCSamplingGraph::AddEdgeAttribute)\n      .def(\"in_subgraph\", &FusedCSCSamplingGraph::InSubgraph)\n      .def(\"sample_neighbors\", &FusedCSCSamplingGraph::SampleNeighbors)\n      .def(\n          \"sample_neighbors_async\",\n          &FusedCSCSamplingGraph::SampleNeighborsAsync)\n      .def(\n          \"temporal_sample_neighbors\",\n          &FusedCSCSamplingGraph::TemporalSampleNeighbors)\n      .def(\"copy_to_shared_memory\", &FusedCSCSamplingGraph::CopyToSharedMemory)\n      .def_pickle(\n          // __getstate__\n          [](const c10::intrusive_ptr<FusedCSCSamplingGraph>& self)\n              -> torch::Dict<\n                  std::string, torch::Dict<std::string, torch::Tensor>> {\n            return self->GetState();\n          },\n          // __setstate__\n          [](torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>\n                 state) -> c10::intrusive_ptr<FusedCSCSamplingGraph> {\n            auto g = c10::make_intrusive<FusedCSCSamplingGraph>();\n            g->SetState(state);\n            return g;\n          });\n#ifdef 
GRAPHBOLT_USE_CUDA\n  m.class_<cuda::GpuCache>(\"GpuCache\")\n      .def(\"query\", &cuda::GpuCache::Query)\n      .def(\"query_async\", &cuda::GpuCache::QueryAsync)\n      .def(\"replace\", &cuda::GpuCache::Replace);\n  m.def(\"gpu_cache\", &cuda::GpuCache::Create);\n  m.class_<cuda::GpuGraphCache>(\"GpuGraphCache\")\n      .def(\"query\", &cuda::GpuGraphCache::Query)\n      .def(\"query_async\", &cuda::GpuGraphCache::QueryAsync)\n      .def(\"replace\", &cuda::GpuGraphCache::Replace)\n      .def(\"replace_async\", &cuda::GpuGraphCache::ReplaceAsync);\n  m.def(\"gpu_graph_cache\", &cuda::GpuGraphCache::Create);\n#endif\n  m.def(\"fused_csc_sampling_graph\", &FusedCSCSamplingGraph::Create);\n  m.class_<storage::PartitionedCachePolicy>(\"PartitionedCachePolicy\")\n      .def(\"query\", &storage::PartitionedCachePolicy::Query)\n      .def(\"query_async\", &storage::PartitionedCachePolicy::QueryAsync)\n      .def(\n          \"query_and_replace\",\n          &storage::PartitionedCachePolicy::QueryAndReplace)\n      .def(\n          \"query_and_replace_async\",\n          &storage::PartitionedCachePolicy::QueryAndReplaceAsync)\n      .def(\"replace\", &storage::PartitionedCachePolicy::Replace)\n      .def(\"replace_async\", &storage::PartitionedCachePolicy::ReplaceAsync)\n      .def(\n          \"reading_completed\",\n          &storage::PartitionedCachePolicy::ReadingCompleted)\n      .def(\n          \"reading_completed_async\",\n          &storage::PartitionedCachePolicy::ReadingCompletedAsync)\n      .def(\n          \"writing_completed\",\n          &storage::PartitionedCachePolicy::WritingCompleted)\n      .def(\n          \"writing_completed_async\",\n          &storage::PartitionedCachePolicy::WritingCompletedAsync);\n  m.def(\n      \"s3_fifo_cache_policy\",\n      &storage::PartitionedCachePolicy::Create<storage::S3FifoCachePolicy>);\n  m.def(\n      \"sieve_cache_policy\",\n      &storage::PartitionedCachePolicy::Create<storage::SieveCachePolicy>);\n  
m.def(\n      \"lru_cache_policy\",\n      &storage::PartitionedCachePolicy::Create<storage::LruCachePolicy>);\n  m.def(\n      \"clock_cache_policy\",\n      &storage::PartitionedCachePolicy::Create<storage::ClockCachePolicy>);\n  m.class_<storage::FeatureCache>(\"FeatureCache\")\n      .def(\"is_pinned\", &storage::FeatureCache::IsPinned)\n      .def_property(\"nbytes\", &storage::FeatureCache::NumBytes)\n      .def(\"index_select\", &storage::FeatureCache::IndexSelect)\n      .def(\"query\", &storage::FeatureCache::Query)\n      .def(\"query_async\", &storage::FeatureCache::QueryAsync)\n      .def(\"replace\", &storage::FeatureCache::Replace)\n      .def(\"replace_async\", &storage::FeatureCache::ReplaceAsync);\n  m.def(\"feature_cache\", &storage::FeatureCache::Create);\n  m.def(\n      \"load_from_shared_memory\", &FusedCSCSamplingGraph::LoadFromSharedMemory);\n  m.def(\"unique_and_compact\", &UniqueAndCompact);\n  m.def(\"unique_and_compact_batched\", &UniqueAndCompactBatched);\n  m.def(\"unique_and_compact_batched_async\", &UniqueAndCompactBatchedAsync);\n  m.def(\"isin\", &IsIn);\n  m.def(\"is_not_in_index\", &IsNotInIndex);\n  m.def(\"is_not_in_index_async\", &IsNotInIndexAsync);\n  m.def(\"index_select\", &ops::IndexSelect);\n  m.def(\"index_select_async\", &ops::IndexSelectAsync);\n  m.def(\"scatter_async\", &ops::ScatterAsync);\n  m.def(\"index_select_csc\", &ops::IndexSelectCSC);\n  m.def(\"index_select_csc_batched\", &ops::IndexSelectCSCBatched);\n  m.def(\"index_select_csc_batched_async\", &ops::IndexSelectCSCBatchedAsync);\n  m.def(\"ondisk_npy_array\", &storage::OnDiskNpyArray::Create);\n  m.def(\"detect_io_uring\", &io_uring::IsAvailable);\n  m.def(\"set_num_io_uring_threads\", &io_uring::SetNumThreads);\n  m.def(\"set_worker_id\", &utils::SetWorkerId);\n  m.def(\"set_seed\", &RandomEngine::SetManualSeed);\n#ifdef GRAPHBOLT_USE_CUDA\n  m.def(\"set_max_uva_threads\", &cuda::set_max_uva_threads);\n  m.def(\"rank_sort\", &cuda::RankSort);\n  
m.def(\"rank_sort_async\", &cuda::RankSortAsync);\n#endif\n#ifdef HAS_IMPL_ABSTRACT_PYSTUB\n  m.impl_abstract_pystub(\"dgl.graphbolt.base\", \"//dgl.graphbolt.base\");\n#endif\n  m.def(\n      \"expand_indptr(Tensor indptr, ScalarType dtype, Tensor? node_ids, \"\n      \"SymInt? output_size) -> Tensor\"\n#ifdef HAS_PT2_COMPLIANT_TAG\n      ,\n      {at::Tag::pt2_compliant_tag}\n#endif\n  );\n  m.def(\n      \"indptr_edge_ids(Tensor indptr, ScalarType dtype, Tensor? offset, \"\n      \"SymInt? output_size) -> \"\n      \"Tensor\"\n#ifdef HAS_PT2_COMPLIANT_TAG\n      ,\n      {at::Tag::pt2_compliant_tag}\n#endif\n  );\n}\n\n}  // namespace sampling\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/random.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file random.cc\n * @brief Random Engine.\n */\n\n#include \"./random.h\"\n\n#include <torch/torch.h>\n\nnamespace graphbolt {\n\nnamespace {\n\n// Get a unique integer ID representing this thread.\ninline uint32_t GetThreadId() {\n  static int num_threads = 0;\n  static std::mutex mutex;\n  static thread_local int id = -1;\n\n  if (id == -1) {\n    std::lock_guard<std::mutex> guard(mutex);\n    id = num_threads;\n    num_threads++;\n  }\n  return id;\n}\n\n};  // namespace\n\nstd::mutex RandomEngine::manual_seed_mutex;\nstd::optional<uint64_t> RandomEngine::manual_seed;\n\n/** @brief Constructor with default seed. */\nRandomEngine::RandomEngine() {\n  std::random_device rd;\n  std::lock_guard lock(manual_seed_mutex);\n  if (!manual_seed.has_value()) manual_seed = rd();\n  SetSeed(manual_seed.value());\n}\n\n/** @brief Constructor with given seed. */\nRandomEngine::RandomEngine(uint64_t seed) : RandomEngine(seed, GetThreadId()) {}\n\n/** @brief Constructor with given seed. */\nRandomEngine::RandomEngine(uint64_t seed, uint64_t stream) {\n  SetSeed(seed, stream);\n}\n\n/** @brief Get the thread-local random number generator instance. */\nRandomEngine* RandomEngine::ThreadLocal() {\n  static thread_local RandomEngine engine;\n  return &engine;\n}\n\n/** @brief Set the seed. */\nvoid RandomEngine::SetSeed(uint64_t seed) { SetSeed(seed, GetThreadId()); }\n\n/** @brief Set the seed. */\nvoid RandomEngine::SetSeed(uint64_t seed, uint64_t stream) {\n  rng_.seed(seed, stream);\n}\n\n/** @brief Manually fix the seed. */\nvoid RandomEngine::SetManualSeed(int64_t seed) {\n  // Intentionally set the seed for current thread also.\n  RandomEngine::ThreadLocal()->SetSeed(seed);\n  std::lock_guard lock(manual_seed_mutex);\n  manual_seed = seed;\n}\n\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/random.h",
    "content": "\n/**\n *  Copyright (c) 2023 by Contributors\n *\n * @file random.h\n * @brief Random Engine class.\n */\n#ifndef GRAPHBOLT_RANDOM_H_\n#define GRAPHBOLT_RANDOM_H_\n\n#include <mutex>\n#include <optional>\n#include <pcg_random.hpp>\n#include <random>\n#include <thread>\n\nnamespace graphbolt {\n\n/**\n * @brief Thread-local Random Number Generator class.\n */\nclass RandomEngine {\n public:\n  /** @brief Constructor with default seed. */\n  RandomEngine();\n\n  /** @brief Constructor with given seed. */\n  explicit RandomEngine(uint64_t seed);\n  explicit RandomEngine(uint64_t seed, uint64_t stream);\n\n  /** @brief Get the thread-local random number generator instance. */\n  static RandomEngine* ThreadLocal();\n\n  /** @brief Set the seed. */\n  void SetSeed(uint64_t seed);\n  void SetSeed(uint64_t seed, uint64_t stream);\n\n  /** @brief Protect manual seed accesses. */\n  static std::mutex manual_seed_mutex;\n\n  /** @brief Manually fix the seed. */\n  static std::optional<uint64_t> manual_seed;\n  static void SetManualSeed(int64_t seed);\n\n  /**\n   * @brief Generate a uniform random integer in [low, high).\n   */\n  template <typename T>\n  T RandInt(T lower, T upper) {\n    std::uniform_int_distribution<T> dist(lower, upper - 1);\n    return dist(rng_);\n  }\n\n  /**\n   * @brief Generate a uniform random real number in [low, high).\n   */\n  template <typename T>\n  T Uniform(T lower, T upper) {\n    std::uniform_real_distribution<T> dist(lower, upper);\n    return dist(rng_);\n  }\n\n  /**\n   * @brief Generate random non-negative floating-point values according to\n   * exponential distribution. Probability density function: P(x|λ) = λe^(-λx).\n   */\n  template <typename T>\n  T Exponential(T lambda) {\n    std::exponential_distribution<T> dist(lambda);\n    return dist(rng_);\n  }\n\n private:\n  pcg32 rng_;\n};\n\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_RANDOM_H_\n"
  },
  {
    "path": "graphbolt/src/serialize.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file graphbolt/src/serialize.cc\n * @brief Source file of serialize.\n */\n\n#include <graphbolt/serialize.h>\n#include <torch/torch.h>\n\nnamespace torch {\n\nserialize::InputArchive& operator>>(\n    serialize::InputArchive& archive,\n    graphbolt::sampling::FusedCSCSamplingGraph& graph) {\n  graph.Load(archive);\n  return archive;\n}\n\nserialize::OutputArchive& operator<<(\n    serialize::OutputArchive& archive,\n    const graphbolt::sampling::FusedCSCSamplingGraph& graph) {\n  graph.Save(archive);\n  return archive;\n}\n\n}  // namespace torch\n"
  },
  {
    "path": "graphbolt/src/shared_memory.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file shared_memory.cc\n * @brief Source file of graphbolt shared memory.\n */\n#ifndef _WIN32\n#include <fcntl.h>\n#include <sys/mman.h>\n#include <sys/stat.h>\n#include <unistd.h>\n#endif  // !_WIN32\n\n#include <graphbolt/shared_memory.h>\n#include <stdio.h>\n#include <string.h>\n#include <torch/torch.h>\n\nnamespace graphbolt {\nnamespace sampling {\n\n// Two processes opening the same path are guaranteed to access the same shared\n// memory object if and only if path begins with a slash ('/') character.\nconstexpr char kSharedMemNamePrefix[] = \"/dgl.graphbolt.\";\nconstexpr char kSharedMemNameSuffix[] = \".lock\";\n\n// A prefix and a suffix are added to the name of the shared memory to create\n// the name of the shared memory object.\ninline std::string DecorateName(const std::string& name) {\n  return kSharedMemNamePrefix + name + kSharedMemNameSuffix;\n}\n\nSharedMemory::SharedMemory(const std::string& name)\n    : name_(name), size_(0), ptr_(nullptr) {\n#ifdef _WIN32\n  this->handle_ = nullptr;\n#else   // _WIN32\n  this->file_descriptor_ = -1;\n  this->is_creator_ = false;\n#endif  // _WIN32\n}\n\n#ifdef _WIN32\n\nSharedMemory::~SharedMemory() {\n  if (ptr_) CHECK(UnmapViewOfFile(ptr_)) << \"Win32 Error: \" << GetLastError();\n  if (handle_) CloseHandle(handle_);\n}\n\nvoid* SharedMemory::Create(size_t size) {\n  size_ = size;\n\n  std::string decorated_name = DecorateName(name_);\n  handle_ = CreateFileMapping(\n      INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE,\n      static_cast<DWORD>(size >> 32), static_cast<DWORD>(size & 0xFFFFFFFF),\n      decorated_name.c_str());\n  TORCH_CHECK(\n      handle_ != nullptr, \"Failed to open \", decorated_name,\n      \", Win32 error: \", GetLastError());\n\n  ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, size);\n  TORCH_CHECK(\n      ptr_ != nullptr, \"Memory mapping failed, Win32 error: \", GetLastError());\n  return ptr_;\n}\n\nvoid* 
SharedMemory::Open() {\n  std::string decorated_name = DecorateName(name_);\n  handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, decorated_name.c_str());\n  TORCH_CHECK(\n      handle_ != nullptr, \"Failed to open \", decorated_name,\n      \", Win32 Error: \", GetLastError());\n\n  ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, 0);\n  TORCH_CHECK(\n      ptr_ != nullptr, \"Memory mapping failed, Win32 error: \", GetLastError());\n\n  // Obtain the size of the memory-mapped file.\n  MEMORY_BASIC_INFORMATION memInfo;\n  TORCH_CHECK(\n      VirtualQuery(ptr_, &memInfo, sizeof(memInfo)) != 0,\n      \"Failed to get the size of shared memory: \", GetLastError());\n  size_ = static_cast<size_t>(memInfo.RegionSize);\n\n  return ptr_;\n}\n\nbool SharedMemory::Exists(const std::string& name) {\n  std::string decorated_name = DecorateName(name);\n  HANDLE handle =\n      OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, decorated_name.c_str());\n  bool exists = handle != nullptr;\n  if (exists) {\n    CloseHandle(handle);\n  }\n  return exists;\n}\n\n#else  // _WIN32\n\nSharedMemory::~SharedMemory() {\n  if (ptr_ && size_ != 0) CHECK(munmap(ptr_, size_) != -1) << strerror(errno);\n  if (file_descriptor_ != -1) close(file_descriptor_);\n\n  std::string decorated_name = DecorateName(name_);\n  if (is_creator_ && decorated_name != \"\") shm_unlink(decorated_name.c_str());\n}\n\nvoid *SharedMemory::Create(size_t size) {\n  size_ = size;\n  is_creator_ = true;\n\n  // TODO(zhenkun): handle the error properly if the shared memory object\n  // already exists.\n  std::string decorated_name = DecorateName(name_);\n  file_descriptor_ =\n      shm_open(decorated_name.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);\n  TORCH_CHECK(file_descriptor_ != -1, \"Failed to open: \", strerror(errno));\n\n  auto status = ftruncate(file_descriptor_, size);\n  TORCH_CHECK(status != -1, \"Failed to truncate the file: \", strerror(errno));\n\n  ptr_ =\n      mmap(NULL, size, PROT_READ | 
PROT_WRITE, MAP_SHARED, file_descriptor_, 0);\n  TORCH_CHECK(\n      ptr_ != MAP_FAILED,\n      \"Failed to map shared memory, mmap failed with error: \", strerror(errno));\n  return ptr_;\n}\n\nvoid *SharedMemory::Open() {\n  std::string decorated_name = DecorateName(name_);\n  file_descriptor_ =\n      shm_open(decorated_name.c_str(), O_RDWR, S_IRUSR | S_IWUSR);\n  TORCH_CHECK(\n      file_descriptor_ != -1, \"Failed to open \", decorated_name, \": \",\n      strerror(errno));\n\n  struct stat shm_stat;\n  TORCH_CHECK(\n      fstat(file_descriptor_, &shm_stat) == 0,\n      \"Failed to get the size of shared memory: \", strerror(errno));\n  size_ = shm_stat.st_size;\n\n  ptr_ = mmap(\n      NULL, size_, PROT_READ | PROT_WRITE, MAP_SHARED, file_descriptor_, 0);\n  TORCH_CHECK(\n      ptr_ != MAP_FAILED,\n      \"Failed to map shared memory, mmap failed with error: \", strerror(errno));\n  return ptr_;\n}\n\nbool SharedMemory::Exists(const std::string &name) {\n  std::string decorated_name = DecorateName(name);\n  int file_descriptor =\n      shm_open(decorated_name.c_str(), O_RDONLY, S_IRUSR | S_IWUSR);\n  bool exists = file_descriptor > 0;\n  if (exists) {\n    close(file_descriptor);\n  }\n  return exists;\n}\n\n#endif  // _WIN32\n\n}  // namespace sampling\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/shared_memory_helper.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n *\n * @file shared_memory_helper.cc\n * @brief Share memory helper implementation.\n */\n#include \"./shared_memory_helper.h\"\n\n#include <graphbolt/serialize.h>\n#include <graphbolt/shared_memory.h>\n#include <torch/torch.h>\n\n#include <cstring>\n#include <string>\n#include <tuple>\n#include <vector>\n\nnamespace graphbolt {\nnamespace sampling {\n\nstatic std::string GetSharedMemoryMetadataName(const std::string& name) {\n  return name + \"_metadata\";\n}\n\nstatic std::string GetSharedMemoryDataName(const std::string& name) {\n  return name + \"_data\";\n}\n\n// To avoid unaligned memory access, we round the size of the binary buffer to\n// the nearest multiple of 8 bytes.\ninline static int64_t GetRoundedSize(int64_t size) {\n  constexpr int64_t ALIGNED_SIZE = 8;\n  return (size + ALIGNED_SIZE - 1) / ALIGNED_SIZE * ALIGNED_SIZE;\n}\n\nSharedMemoryHelper::SharedMemoryHelper(const std::string& name)\n    : name_(name),\n      metadata_size_(0),\n      data_size_(0),\n      metadata_shared_memory_(nullptr),\n      data_shared_memory_(nullptr),\n      metadata_offset_(0),\n      data_offset_(0) {}\n\nvoid SharedMemoryHelper::InitializeRead() {\n  metadata_offset_ = 0;\n  data_offset_ = 0;\n  if (metadata_shared_memory_ == nullptr) {\n    // Reader process opens the shared memory.\n    metadata_shared_memory_ =\n        std::make_unique<SharedMemory>(GetSharedMemoryMetadataName(name_));\n    metadata_shared_memory_->Open();\n    metadata_size_ = metadata_shared_memory_->GetSize();\n    data_shared_memory_ =\n        std::make_unique<SharedMemory>(GetSharedMemoryDataName(name_));\n    data_shared_memory_->Open();\n    data_size_ = data_shared_memory_->GetSize();\n  }\n}\n\nvoid SharedMemoryHelper::WriteTorchArchive(\n    torch::serialize::OutputArchive&& archive) {\n  metadata_to_write_.emplace_back(std::move(archive));\n}\n\ntorch::serialize::InputArchive SharedMemoryHelper::ReadTorchArchive() {\n  auto 
metadata_ptr = this->GetCurrentMetadataPtr();\n  int64_t metadata_size = static_cast<int64_t*>(metadata_ptr)[0];\n  torch::serialize::InputArchive archive;\n  archive.load_from(\n      static_cast<const char*>(metadata_ptr) + sizeof(int64_t), metadata_size);\n  auto rounded_size = GetRoundedSize(metadata_size);\n  this->MoveMetadataPtr(sizeof(int64_t) + rounded_size);\n  return archive;\n}\n\nvoid SharedMemoryHelper::WriteTorchTensor(\n    torch::optional<torch::Tensor> tensor) {\n  torch::serialize::OutputArchive archive;\n  archive.write(\"has_value\", tensor.has_value());\n  if (tensor.has_value()) {\n    archive.write(\"shape\", tensor.value().sizes());\n    archive.write(\"dtype\", tensor.value().scalar_type());\n  }\n  this->WriteTorchArchive(std::move(archive));\n  tensors_to_write_.push_back(tensor);\n}\n\ntorch::optional<torch::Tensor> SharedMemoryHelper::ReadTorchTensor() {\n  auto archive = this->ReadTorchArchive();\n  bool has_value = read_from_archive<bool>(archive, \"has_value\");\n  if (has_value) {\n    auto shape = read_from_archive<std::vector<int64_t>>(archive, \"shape\");\n    auto dtype = read_from_archive<torch::ScalarType>(archive, \"dtype\");\n    auto data_ptr = this->GetCurrentDataPtr();\n    auto tensor = torch::from_blob(data_ptr, shape, dtype);\n    auto rounded_size = GetRoundedSize(tensor.numel() * tensor.element_size());\n    this->MoveDataPtr(rounded_size);\n    return tensor;\n  } else {\n    return torch::nullopt;\n  }\n}\n\nvoid SharedMemoryHelper::WriteTorchTensorDict(\n    torch::optional<torch::Dict<std::string, torch::Tensor>> tensor_dict) {\n  torch::serialize::OutputArchive archive;\n  if (!tensor_dict.has_value()) {\n    archive.write(\"has_value\", false);\n    this->WriteTorchArchive(std::move(archive));\n    return;\n  }\n  archive.write(\"has_value\", true);\n  auto dict_value = tensor_dict.value();\n  archive.write(\"num_tensors\", static_cast<int64_t>(dict_value.size()));\n  int counter = 0;\n  for (auto it = 
dict_value.begin(); it != dict_value.end(); ++it) {\n    archive.write(std::string(\"key_\") + std::to_string(counter), it->key());\n    counter++;\n  }\n  this->WriteTorchArchive(std::move(archive));\n  for (auto it = dict_value.begin(); it != dict_value.end(); ++it) {\n    this->WriteTorchTensor(it->value());\n  }\n}\n\ntorch::optional<torch::Dict<std::string, torch::Tensor>>\nSharedMemoryHelper::ReadTorchTensorDict() {\n  auto archive = this->ReadTorchArchive();\n  if (!read_from_archive<bool>(archive, \"has_value\")) {\n    return torch::nullopt;\n  }\n  int64_t num_tensors = read_from_archive<int64_t>(archive, \"num_tensors\");\n  torch::Dict<std::string, torch::Tensor> tensor_dict;\n  for (int64_t i = 0; i < num_tensors; ++i) {\n    auto key = read_from_archive<std::string>(\n        archive, std::string(\"key_\") + std::to_string(i));\n    auto tensor = this->ReadTorchTensor();\n    tensor_dict.insert(key, tensor.value());\n  }\n  return tensor_dict;\n}\n\nvoid SharedMemoryHelper::SerializeMetadata() {\n  for (auto& archive : metadata_to_write_) {\n    std::stringstream serialized;\n    archive.save_to(serialized);\n    metadata_strings_to_write_.push_back(std::move(serialized.str()));\n  }\n  metadata_to_write_.clear();\n}\n\nvoid SharedMemoryHelper::WriteMetadataToSharedMemory() {\n  metadata_offset_ = 0;\n  for (const auto& str : metadata_strings_to_write_) {\n    auto metadata_ptr = this->GetCurrentMetadataPtr();\n    static_cast<int64_t*>(metadata_ptr)[0] = str.size();\n    memcpy(\n        static_cast<char*>(metadata_ptr) + sizeof(int64_t), str.data(),\n        str.size());\n    int64_t rounded_size = GetRoundedSize(str.size());\n    this->MoveMetadataPtr(sizeof(int64_t) + rounded_size);\n  }\n  metadata_strings_to_write_.clear();\n}\n\nvoid SharedMemoryHelper::WriteTorchTensorInternal(\n    torch::optional<torch::Tensor> tensor) {\n  if (tensor.has_value()) {\n    size_t memory_size = tensor.value().numel() * tensor.value().element_size();\n    auto 
data_ptr = this->GetCurrentDataPtr();\n    auto contiguous_tensor = tensor.value().contiguous();\n    memcpy(data_ptr, contiguous_tensor.data_ptr(), memory_size);\n    this->MoveDataPtr(GetRoundedSize(memory_size));\n  }\n}\n\nvoid SharedMemoryHelper::Flush() {\n  size_t data_size = 0;\n  for (auto tensor : tensors_to_write_) {\n    if (tensor.has_value()) {\n      auto tensor_size = tensor.value().numel() * tensor.value().element_size();\n      data_size += GetRoundedSize(tensor_size);\n    }\n  }\n\n  // Serialize the metadata archives.\n  SerializeMetadata();\n\n  // Create the shared memory objects.\n  const size_t metadata_size = std::accumulate(\n      metadata_strings_to_write_.begin(), metadata_strings_to_write_.end(), 0,\n      [](size_t sum, const std::string& str) {\n        return sum + sizeof(int64_t) + GetRoundedSize(str.size());\n      });\n  metadata_shared_memory_ =\n      std::make_unique<SharedMemory>(GetSharedMemoryMetadataName(name_));\n  metadata_shared_memory_->Create(metadata_size);\n  metadata_size_ = metadata_size;\n\n  // Write the metadata and tensor data to the shared memory.\n  WriteMetadataToSharedMemory();\n  data_shared_memory_ =\n      std::make_unique<SharedMemory>(GetSharedMemoryDataName(name_));\n  data_shared_memory_->Create(data_size);\n  data_size_ = data_size;\n  data_offset_ = 0;\n  for (auto tensor : tensors_to_write_) {\n    this->WriteTorchTensorInternal(tensor);\n  }\n\n  metadata_to_write_.clear();\n  tensors_to_write_.clear();\n}\n\nstd::pair<SharedMemoryPtr, SharedMemoryPtr>\nSharedMemoryHelper::ReleaseSharedMemory() {\n  return std::make_pair(\n      std::move(metadata_shared_memory_), std::move(data_shared_memory_));\n}\n\n}  // namespace sampling\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/shared_memory_helper.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n *\n * @file shared_memory_helper.h\n * @brief Shared memory helper.\n */\n#ifndef GRAPHBOLT_SHARED_MEMORY_HELPER_H_\n#define GRAPHBOLT_SHARED_MEMORY_HELPER_H_\n\n#include <graphbolt/shared_memory.h>\n#include <torch/torch.h>\n\n#include <memory>\n#include <sstream>\n#include <string>\n#include <tuple>\n#include <vector>\n\nnamespace graphbolt {\nnamespace sampling {\n\n/**\n * @brief SharedMemoryHelper is a helper class to write/read data structures\n * to/from shared memory.\n *\n * In order to write a data structure to shared memory, we need to serialize\n * the data structure to a binary buffer and then write the buffer to the shared\n * memory. However, the size of the binary buffer is not known in advance. To\n * solve this problem, we use two shared memory objects: one for storing the\n * metadata and the other for storing the binary buffer. The metadata includes\n * the metadata of data structures such as size and shape. The size of the\n * metadata shared memory is decided by the total size of the serialized\n * metadata. The size of the binary buffer is decided by the size of the data\n * structures.\n *\n * To avoid repeated shared memory allocation, this helper class uses lazy data\n * structure writing. The data structures are written to the shared memory only\n * when `Flush` is called. 
The data structures are written in the order of\n * calling `WriteTorchArchive`, `WriteTorchTensor` and `WriteTorchTensorDict`,\n * and also read in the same order.\n *\n * The usage of this class as a writer is as follows:\n * @code{.cpp}\n * SharedMemoryHelper shm_helper(\"shm_name\");\n * shm_helper.WriteTorchArchive(std::move(archive));\n * shm_helper.WriteTorchTensor(tensor);\n * shm_helper.WriteTorchTensorDict(tensor_dict);\n * shm_helper.Flush();\n * // After `Flush`, the data structures are written to the shared memory.\n * // Then the helper class can be used as a reader.\n * shm_helper.InitializeRead();\n * auto archive = shm_helper.ReadTorchArchive();\n * auto tensor = shm_helper.ReadTorchTensor();\n * auto tensor_dict = shm_helper.ReadTorchTensorDict();\n * @endcode\n *\n * The usage of this class as a reader is as follows:\n * @code{.cpp}\n * SharedMemoryHelper shm_helper(\"shm_name\");\n * shm_helper.InitializeRead();\n * auto archive = shm_helper.ReadTorchArchive();\n * auto tensor = shm_helper.ReadTorchTensor();\n * auto tensor_dict = shm_helper.ReadTorchTensorDict();\n * @endcode\n *\n *\n */\nclass SharedMemoryHelper {\n public:\n  /**\n   * @brief Constructor of the shared memory helper.\n   * @param name The name of the shared memory.\n   */\n  SharedMemoryHelper(const std::string& name);\n\n  /** @brief Initialize this helper class before reading. */\n  void InitializeRead();\n\n  void WriteTorchArchive(torch::serialize::OutputArchive&& archive);\n  torch::serialize::InputArchive ReadTorchArchive();\n\n  void WriteTorchTensor(torch::optional<torch::Tensor> tensor);\n  torch::optional<torch::Tensor> ReadTorchTensor();\n\n  void WriteTorchTensorDict(\n      torch::optional<torch::Dict<std::string, torch::Tensor>> tensor_dict);\n  torch::optional<torch::Dict<std::string, torch::Tensor>>\n  ReadTorchTensorDict();\n\n  /** @brief Flush the data structures to the shared memory. 
*/\n  void Flush();\n\n  /** @brief Release the shared memory and return their left values. */\n  std::pair<SharedMemoryPtr, SharedMemoryPtr> ReleaseSharedMemory();\n\n private:\n  /**\n   * @brief Serialize metadata to string.\n   */\n  void SerializeMetadata();\n  /**\n   * @brief Write the metadata to the shared memory. This function is\n   * called by `Flush`.\n   */\n  void WriteMetadataToSharedMemory();\n  /**\n   * @brief Write the tensor data to the shared memory. This function is\n   * called by `Flush`.\n   */\n  void WriteTorchTensorInternal(torch::optional<torch::Tensor> tensor);\n\n  inline void* GetCurrentMetadataPtr() const {\n    return static_cast<char*>(metadata_shared_memory_->GetMemory()) +\n           metadata_offset_;\n  }\n  inline void* GetCurrentDataPtr() const {\n    return static_cast<char*>(data_shared_memory_->GetMemory()) + data_offset_;\n  }\n  inline void MoveMetadataPtr(int64_t offset) {\n    TORCH_CHECK(\n        metadata_offset_ + offset <= metadata_size_,\n        \"The size of metadata exceeds the maximum size of shared memory.\");\n    metadata_offset_ += offset;\n  }\n  inline void MoveDataPtr(int64_t offset) {\n    TORCH_CHECK(\n        data_offset_ + offset <= data_size_,\n        \"The size of data exceeds the maximum size of shared memory.\");\n    data_offset_ += offset;\n  }\n\n  std::string name_;\n  bool is_creator_;\n\n  size_t metadata_size_;\n  size_t data_size_;\n\n  // The shared memory objects for storing metadata and tensor data.\n  SharedMemoryPtr metadata_shared_memory_, data_shared_memory_;\n\n  // The read/write offsets of the metadata and tensor data.\n  size_t metadata_offset_, data_offset_;\n\n  // The data structures to write to the shared memory. 
They are written to the\n  // shared memory only when `Flush` is called.\n  std::vector<torch::serialize::OutputArchive> metadata_to_write_;\n  std::vector<std::string> metadata_strings_to_write_;\n  std::vector<torch::optional<torch::Tensor>> tensors_to_write_;\n};\n\n}  // namespace sampling\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_SHARED_MEMORY_HELPER_H_\n"
  },
  {
    "path": "graphbolt/src/unique_and_compact.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n *\n * @file unique_and_compact.cc\n * @brief Unique and compact op.\n */\n\n#include <graphbolt/cuda_ops.h>\n#include <graphbolt/unique_and_compact.h>\n\n#include \"./concurrent_id_hash_map.h\"\n#include \"./macro.h\"\n#include \"./utils.h\"\n\nnamespace graphbolt {\nnamespace sampling {\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nUniqueAndCompact(\n    const torch::Tensor& src_ids, const torch::Tensor& dst_ids,\n    const torch::Tensor unique_dst_ids, const int64_t rank,\n    const int64_t world_size) {\n  if (utils::is_on_gpu(src_ids) && utils::is_on_gpu(dst_ids) &&\n      utils::is_on_gpu(unique_dst_ids)) {\n    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(\n        c10::DeviceType::CUDA, \"unique_and_compact\", {\n          return ops::UniqueAndCompact(\n              src_ids, dst_ids, unique_dst_ids, rank, world_size);\n        });\n  }\n  TORCH_CHECK(\n      world_size <= 1,\n      \"Cooperative Minibatching (arXiv:2310.12403) is supported only on GPUs.\");\n  auto num_dst = unique_dst_ids.size(0);\n  torch::Tensor ids = torch::cat({unique_dst_ids, src_ids});\n  auto [unique_ids, compacted_src, compacted_dst] = AT_DISPATCH_INDEX_TYPES(\n      ids.scalar_type(), \"unique_and_compact\", ([&] {\n        ConcurrentIdHashMap<index_t> id_map(ids, num_dst);\n        return std::make_tuple(\n            id_map.GetUniqueIds(), id_map.MapIds(src_ids),\n            id_map.MapIds(dst_ids));\n      }));\n  auto offsets = torch::zeros(2, c10::TensorOptions().dtype(torch::kInt64));\n  offsets.data_ptr<int64_t>()[1] = unique_ids.size(0);\n  return {unique_ids, compacted_src, compacted_dst, offsets};\n}\n\nstd::vector<\n    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>\nUniqueAndCompactBatched(\n    const std::vector<torch::Tensor>& src_ids,\n    const std::vector<torch::Tensor>& dst_ids,\n    const std::vector<torch::Tensor> unique_dst_ids, const int64_t rank,\n    
const int64_t world_size) {\n  TORCH_CHECK(\n      src_ids.size() == dst_ids.size() &&\n          dst_ids.size() == unique_dst_ids.size(),\n      \"The batch dimension of the parameters need to be identical.\");\n  bool all_on_gpu = true;\n  for (std::size_t i = 0; i < src_ids.size(); i++) {\n    all_on_gpu = all_on_gpu && utils::is_on_gpu(src_ids[i]) &&\n                 utils::is_on_gpu(dst_ids[i]) &&\n                 utils::is_on_gpu(unique_dst_ids[i]);\n    if (!all_on_gpu) break;\n  }\n  if (all_on_gpu) {\n    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(\n        c10::DeviceType::CUDA, \"unique_and_compact\", {\n          return ops::UniqueAndCompactBatched(\n              src_ids, dst_ids, unique_dst_ids, rank, world_size);\n        });\n  }\n  std::vector<\n      std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>\n      results;\n  results.reserve(src_ids.size());\n  for (std::size_t i = 0; i < src_ids.size(); i++) {\n    results.emplace_back(UniqueAndCompact(\n        src_ids[i], dst_ids[i], unique_dst_ids[i], rank, world_size));\n  }\n  return results;\n}\n\nc10::intrusive_ptr<Future<std::vector<\n    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>>\nUniqueAndCompactBatchedAsync(\n    const std::vector<torch::Tensor>& src_ids,\n    const std::vector<torch::Tensor>& dst_ids,\n    const std::vector<torch::Tensor> unique_dst_ids, const int64_t rank,\n    const int64_t world_size) {\n  return async(\n      [=] {\n        return UniqueAndCompactBatched(\n            src_ids, dst_ids, unique_dst_ids, rank, world_size);\n      },\n      utils::is_on_gpu(src_ids.at(0)));\n}\n\n}  // namespace sampling\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/utils.cc",
    "content": "/**\n *   Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file utils.cc\n * @brief Graphbolt utils implementations.\n */\n#include \"./utils.h\"\n\n#include <optional>\n\nnamespace graphbolt {\nnamespace utils {\n\nnamespace {\nstd::optional<int64_t> worker_id;\n}\n\nstd::optional<int64_t> GetWorkerId() { return worker_id; }\n\nvoid SetWorkerId(int64_t worker_id_value) { worker_id = worker_id_value; }\n\n}  // namespace utils\n}  // namespace graphbolt\n"
  },
  {
    "path": "graphbolt/src/utils.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file utils.h\n * @brief Graphbolt utils.\n */\n\n#ifndef GRAPHBOLT_UTILS_H_\n#define GRAPHBOLT_UTILS_H_\n\n#include <torch/script.h>\n\n#include <optional>\n\nnamespace graphbolt {\nnamespace utils {\n\n/**\n * @brief If this process is a worker as part of a DataLoader, then returns\n * the assigned worker id less than the # workers.\n */\nstd::optional<int64_t> GetWorkerId();\n\n/**\n * @brief If this process is a worker as part of a DataLoader, then this\n * function is called to initialize its worker id to be less than the # workers.\n */\nvoid SetWorkerId(int64_t worker_id_value);\n\n/**\n * @brief Checks whether the tensor is stored on the GPU.\n */\ninline bool is_on_gpu(const torch::Tensor& tensor) {\n  return tensor.device().is_cuda();\n}\n\n/**\n * @brief Checks whether the tensor is stored on the GPU or the pinned memory.\n */\ninline bool is_accessible_from_gpu(const torch::Tensor& tensor) {\n  return is_on_gpu(tensor) || tensor.is_pinned();\n}\n\n/**\n * @brief Checks whether the tensor is stored on the pinned memory.\n */\ninline bool is_pinned(const torch::Tensor& tensor) {\n  // If this process is a worker, we should avoid initializing the CUDA context.\n  return !GetWorkerId() && tensor.is_pinned();\n}\n\n/**\n * @brief Checks whether the tensors are all stored on the GPU or the pinned\n * memory.\n */\ntemplate <typename TensorContainer>\ninline bool are_accessible_from_gpu(const TensorContainer& tensors) {\n  for (auto& tensor : tensors) {\n    if (!is_accessible_from_gpu(tensor)) return false;\n  }\n  return true;\n}\n\n/**\n * @brief Parses the source and destination node type from a given edge type\n * triple separated by \":\".\n */\ninline std::pair<std::string, std::string> parse_src_dst_ntype_from_etype(\n    std::string etype) {\n  auto first_seperator_it = std::find(etype.begin(), etype.end(), ':');\n  auto second_seperator_pos =\n      
std::find(first_seperator_it + 1, etype.end(), ':') - etype.begin();\n  return {\n      etype.substr(0, first_seperator_it - etype.begin()),\n      etype.substr(second_seperator_pos + 1)};\n}\n\n/**\n * @brief Retrieves the value of the tensor at the given index.\n *\n * @note If the tensor is not contiguous, it will be copied to a contiguous\n * tensor.\n *\n * @tparam T The type of the tensor.\n * @param tensor The tensor.\n * @param index The index.\n *\n * @return T The value of the tensor at the given index.\n */\ntemplate <typename T>\nT GetValueByIndex(const torch::Tensor& tensor, int64_t index) {\n  TORCH_CHECK(\n      index >= 0 && index < tensor.numel(),\n      \"The index should be within the range of the tensor, but got index \",\n      index, \" and tensor size \", tensor.numel());\n  auto contiguous_tensor = tensor.contiguous();\n  auto data_ptr = contiguous_tensor.data_ptr<T>();\n  return data_ptr[index];\n}\n\n}  // namespace utils\n}  // namespace graphbolt\n\n#endif  // GRAPHBOLT_UTILS_H_\n"
  },
  {
    "path": "include/dgl/array.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file dgl/array.h\n * @brief Common array operations required by DGL.\n *\n * Note that this is not meant for a full support of array library such as ATen.\n * Only a limited set of operators required by DGL are implemented.\n */\n#ifndef DGL_ARRAY_H_\n#define DGL_ARRAY_H_\n#include \"./aten/array_ops.h\"\n#include \"./aten/coo.h\"\n#include \"./aten/csr.h\"\n#include \"./aten/macro.h\"\n#include \"./aten/spmat.h\"\n#include \"./aten/types.h\"\n#endif  // DGL_ARRAY_H_\n"
  },
  {
    "path": "include/dgl/array_iterator.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file dgl/array_iterator.h\n * @brief Various iterators.\n */\n#ifndef DGL_ARRAY_ITERATOR_H_\n#define DGL_ARRAY_ITERATOR_H_\n\n#ifdef __CUDA_ARCH__\n#define CUB_INLINE __host__ __device__ __forceinline__\n#else\n#define CUB_INLINE inline\n#endif  // __CUDA_ARCH__\n\n#include <algorithm>\n#include <iterator>\n#include <utility>\n\nnamespace dgl {\nnamespace aten {\n\nusing std::swap;\n\n// Make std::pair work on both host and device\ntemplate <typename DType>\nstruct Pair {\n  Pair() = default;\n  Pair(const Pair& other) = default;\n  Pair(Pair&& other) = default;\n  CUB_INLINE Pair(DType a, DType b) : first(a), second(b) {}\n  CUB_INLINE Pair& operator=(const Pair& other) {\n    first = other.first;\n    second = other.second;\n    return *this;\n  }\n  CUB_INLINE operator std::pair<DType, DType>() const {\n    return std::make_pair(first, second);\n  }\n  CUB_INLINE bool operator==(const Pair& other) const {\n    return (first == other.first) && (second == other.second);\n  }\n  CUB_INLINE void swap(const Pair& other) const {\n    std::swap(first, other.first);\n    std::swap(second, other.second);\n  }\n  DType first, second;\n};\n\ntemplate <typename DType>\nCUB_INLINE void swap(const Pair<DType>& r1, const Pair<DType>& r2) {\n  r1.swap(r2);\n}\n\n// PairRef and PairIterator that serves as an iterator over a pair of arrays in\n// a zipped fashion like zip(a, b).\ntemplate <typename DType>\nstruct PairRef {\n  PairRef() = delete;\n  PairRef(const PairRef& other) = default;\n  PairRef(PairRef&& other) = default;\n  CUB_INLINE PairRef(DType* const r, DType* const c) : a(r), b(c) {}\n  CUB_INLINE PairRef& operator=(const PairRef& other) {\n    *a = *other.a;\n    *b = *other.b;\n    return *this;\n  }\n  CUB_INLINE PairRef& operator=(const Pair<DType>& val) {\n    *a = val.first;\n    *b = val.second;\n    return *this;\n  }\n  CUB_INLINE operator Pair<DType>() const { return Pair<DType>(*a, *b); }\n  
CUB_INLINE operator std::pair<DType, DType>() const {\n    return std::make_pair(*a, *b);\n  }\n  CUB_INLINE bool operator==(const PairRef& other) const {\n    return (*a == *(other.a)) && (*b == *(other.b));\n  }\n  CUB_INLINE void swap(const PairRef& other) const {\n    std::swap(*a, *other.a);\n    std::swap(*b, *other.b);\n  }\n  DType *a, *b;\n};\n\ntemplate <typename DType>\nCUB_INLINE void swap(const PairRef<DType>& r1, const PairRef<DType>& r2) {\n  r1.swap(r2);\n}\n\ntemplate <typename DType>\nstruct PairIterator : public std::iterator<\n                          std::random_access_iterator_tag, Pair<DType>,\n                          std::ptrdiff_t, Pair<DType*>, PairRef<DType>> {\n  PairIterator() = default;\n  PairIterator(const PairIterator& other) = default;\n  PairIterator(PairIterator&& other) = default;\n  CUB_INLINE PairIterator(DType* x, DType* y) : a(x), b(y) {}\n  PairIterator& operator=(const PairIterator& other) = default;\n  PairIterator& operator=(PairIterator&& other) = default;\n  ~PairIterator() = default;\n  CUB_INLINE bool operator==(const PairIterator& other) const {\n    return a == other.a;\n  }\n  CUB_INLINE bool operator!=(const PairIterator& other) const {\n    return a != other.a;\n  }\n  CUB_INLINE bool operator<(const PairIterator& other) const {\n    return a < other.a;\n  }\n  CUB_INLINE bool operator>(const PairIterator& other) const {\n    return a > other.a;\n  }\n  CUB_INLINE bool operator<=(const PairIterator& other) const {\n    return a <= other.a;\n  }\n  CUB_INLINE bool operator>=(const PairIterator& other) const {\n    return a >= other.a;\n  }\n  CUB_INLINE PairIterator& operator+=(const std::ptrdiff_t& movement) {\n    a += movement;\n    b += movement;\n    return *this;\n  }\n  CUB_INLINE PairIterator& operator-=(const std::ptrdiff_t& movement) {\n    a -= movement;\n    b -= movement;\n    return *this;\n  }\n  CUB_INLINE PairIterator& operator++() {\n    ++a;\n    ++b;\n    return *this;\n  }\n  CUB_INLINE 
PairIterator& operator--() {\n    --a;\n    --b;\n    return *this;\n  }\n  CUB_INLINE PairIterator operator++(int) {\n    PairIterator ret(*this);\n    operator++();\n    return ret;\n  }\n  CUB_INLINE PairIterator operator--(int) {\n    PairIterator ret(*this);\n    operator--();\n    return ret;\n  }\n  CUB_INLINE PairIterator operator+(const std::ptrdiff_t& movement) const {\n    return PairIterator(a + movement, b + movement);\n  }\n  CUB_INLINE PairIterator operator-(const std::ptrdiff_t& movement) const {\n    return PairIterator(a - movement, b - movement);\n  }\n  CUB_INLINE std::ptrdiff_t operator-(const PairIterator& other) const {\n    return a - other.a;\n  }\n  CUB_INLINE PairRef<DType> operator*() const { return PairRef<DType>(a, b); }\n  CUB_INLINE PairRef<DType> operator*() { return PairRef<DType>(a, b); }\n  CUB_INLINE PairRef<DType> operator[](size_t offset) const {\n    return PairRef<DType>(a + offset, b + offset);\n  }\n  CUB_INLINE PairRef<DType> operator[](size_t offset) {\n    return PairRef<DType>(a + offset, b + offset);\n  }\n  DType *a, *b;\n};\n\n};  // namespace aten\n};  // namespace dgl\n\n#endif  // DGL_ARRAY_ITERATOR_H_\n"
  },
  {
    "path": "include/dgl/aten/array_ops.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file dgl/aten/array_ops.h\n * @brief Common array operations required by DGL.\n *\n * Note that this is not meant for a full support of array library such as ATen.\n * Only a limited set of operators required by DGL are implemented.\n */\n#ifndef DGL_ATEN_ARRAY_OPS_H_\n#define DGL_ATEN_ARRAY_OPS_H_\n\n#include <algorithm>\n#include <string>\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"./types.h\"\n\nnamespace dgl {\nnamespace aten {\n\n//////////////////////////////////////////////////////////////////////\n// ID array\n//////////////////////////////////////////////////////////////////////\n\n/** @return A special array to represent null. */\ninline NDArray NullArray(\n    const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},\n    const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {\n  return NDArray::Empty({0}, dtype, ctx);\n}\n\n/**\n * @return Whether the input array is a null array.\n */\ninline bool IsNullArray(NDArray array) { return array->shape[0] == 0; }\n\n/**\n * @brief Create a new id array with given length\n * @param length The array length\n * @param ctx The array context\n * @param nbits The number of integer bits\n * @return id array\n */\nIdArray NewIdArray(\n    int64_t length, DGLContext ctx = DGLContext{kDGLCPU, 0},\n    uint8_t nbits = 64);\n\n/**\n * @brief Create a new float array with given length\n * @param length The array length\n * @param ctx The array context\n * @param nbits The number of integer bits\n * @return float array\n */\nFloatArray NewFloatArray(int64_t length,\n                   DGLContext ctx = DGLContext{kDGLCPU, 0},\n                   uint8_t nbits = 32);\n\n/**\n * @brief Create a new id array using the given vector data\n * @param vec The vector data\n * @param nbits The integer bits of the returned array\n * @param ctx The array context\n * @return the id array\n */\ntemplate <typename T>\nIdArray VecToIdArray(\n    const 
std::vector<T>& vec, uint8_t nbits = 64,\n    DGLContext ctx = DGLContext{kDGLCPU, 0});\n\n/**\n * @brief Return an array representing a 1D range.\n * @param low Lower bound (inclusive).\n * @param high Higher bound (exclusive).\n * @param nbits result array's bits (32 or 64)\n * @param ctx Device context\n * @return range array\n */\nIdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx);\n\n/**\n * @brief Return an array full of the given value\n * @param val The value to fill.\n * @param length Number of elements.\n * @param nbits result array's bits (32 or 64)\n * @param ctx Device context\n * @return the result array\n */\nIdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx);\n\n/**\n * @brief Return an array full of the given value with the given type.\n * @param val The value to fill.\n * @param length Number of elements.\n * @param ctx Device context\n * @return the result array\n */\ntemplate <typename DType>\nNDArray Full(DType val, int64_t length, DGLContext ctx);\n\n/** @brief Create a deep copy of the given array */\nIdArray Clone(IdArray arr);\n\n/** @brief Convert the idarray to the given bit width */\nIdArray AsNumBits(IdArray arr, uint8_t bits);\n\n/** @brief Arithmetic functions */\nIdArray Add(IdArray lhs, IdArray rhs);\nIdArray Sub(IdArray lhs, IdArray rhs);\nIdArray Mul(IdArray lhs, IdArray rhs);\nIdArray Div(IdArray lhs, IdArray rhs);\nIdArray Mod(IdArray lhs, IdArray rhs);\n\nIdArray Add(IdArray lhs, int64_t rhs);\nIdArray Sub(IdArray lhs, int64_t rhs);\nIdArray Mul(IdArray lhs, int64_t rhs);\nIdArray Div(IdArray lhs, int64_t rhs);\nIdArray Mod(IdArray lhs, int64_t rhs);\n\nIdArray Add(int64_t lhs, IdArray rhs);\nIdArray Sub(int64_t lhs, IdArray rhs);\nIdArray Mul(int64_t lhs, IdArray rhs);\nIdArray Div(int64_t lhs, IdArray rhs);\nIdArray Mod(int64_t lhs, IdArray rhs);\n\nIdArray Neg(IdArray array);\n\n// XXX(minjie): currently using integer array for bool type\nIdArray GT(IdArray lhs, IdArray 
rhs);\nIdArray LT(IdArray lhs, IdArray rhs);\nIdArray GE(IdArray lhs, IdArray rhs);\nIdArray LE(IdArray lhs, IdArray rhs);\nIdArray EQ(IdArray lhs, IdArray rhs);\nIdArray NE(IdArray lhs, IdArray rhs);\n\nIdArray GT(IdArray lhs, int64_t rhs);\nIdArray LT(IdArray lhs, int64_t rhs);\nIdArray GE(IdArray lhs, int64_t rhs);\nIdArray LE(IdArray lhs, int64_t rhs);\nIdArray EQ(IdArray lhs, int64_t rhs);\nIdArray NE(IdArray lhs, int64_t rhs);\n\nIdArray GT(int64_t lhs, IdArray rhs);\nIdArray LT(int64_t lhs, IdArray rhs);\nIdArray GE(int64_t lhs, IdArray rhs);\nIdArray LE(int64_t lhs, IdArray rhs);\nIdArray EQ(int64_t lhs, IdArray rhs);\nIdArray NE(int64_t lhs, IdArray rhs);\n\n/** @brief Stack two arrays (of len L) into a 2*L length array */\nIdArray HStack(IdArray arr1, IdArray arr2);\n\n/** @brief Return the indices of the elements that are non-zero. */\nIdArray NonZero(BoolArray bool_arr);\n\n/**\n * @brief Return the data under the index. In numpy notation, A[I]\n * @tparam ValueType The type of return value.\n */\ntemplate <typename ValueType>\nValueType IndexSelect(NDArray array, int64_t index);\n\n/**\n * @brief Return the data under the index. In numpy notation, A[I]\n */\nNDArray IndexSelect(NDArray array, IdArray index);\n\n/**\n * @brief Return the data from `start` (inclusive) to `end` (exclusive).\n */\nNDArray IndexSelect(NDArray array, int64_t start, int64_t end);\n\n/**\n * @brief Permute the elements of an array according to given indices.\n *\n * Only support 1D arrays.\n *\n * Equivalent to:\n *\n * <code>\n *     result = np.zeros_like(array)\n *     result[indices] = array\n * </code>\n */\nNDArray Scatter(NDArray array, IdArray indices);\n\n/**\n * @brief Scatter data into the output array.\n *\n * Equivalent to:\n *\n * <code>\n *     out[index] = value\n * </code>\n */\nvoid Scatter_(IdArray index, NDArray value, NDArray out);\n\n/**\n * @brief Repeat each element a number of times.  
Equivalent to np.repeat(array,\n * repeats)\n * @param array A 1D vector\n * @param repeats A 1D integer vector for number of times to repeat for each\n * element in \\c array.  Must have the same shape as \\c array.\n */\nNDArray Repeat(NDArray array, IdArray repeats);\n\n/**\n * @brief Relabel the given ids to consecutive ids.\n *\n * Relabeling is done inplace. The mapping is created from the union\n * of the give arrays.\n *\n * Example:\n *\n * Given two IdArrays [2, 3, 10, 0, 2] and [4, 10, 5], one possible return\n * mapping is [2, 3, 10, 4, 0, 5], meaning the new ID 0 maps to the old ID\n * 2, 1 maps to 3, so on and so forth.\n *\n * @param arrays The id arrays to relabel.\n * @return mapping array M from new id to old id.\n */\nIdArray Relabel_(const std::vector<IdArray>& arrays);\n\n/**\n * @brief concatenate the given id arrays to one array\n *\n * Example:\n *\n * Given two IdArrays [2, 3, 10, 0, 2] and [4, 10, 5]\n * Return [2, 3, 10, 0, 2, 4, 10, 5]\n *\n * @param arrays The id arrays to concatenate.\n * @return concatenated array.\n */\nNDArray Concat(const std::vector<IdArray>& arrays);\n\n/** @brief Return whether the array is a valid 1D int array*/\ninline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {\n  return arr->ndim == 1 && arr->dtype.code == kDGLInt;\n}\n\n/**\n * @brief Packs a tensor containing padded sequences of variable length.\n *\n * Similar to \\c pack_padded_sequence in PyTorch, except that\n *\n * 1. The length for each sequence (before padding) is inferred as the number\n *    of elements before the first occurrence of \\c pad_value.\n * 2. It does not sort the sequences by length.\n * 3. 
Along with the tensor containing the packed sequence, it returns both the\n *    length, as well as the offsets to the packed tensor, of each sequence.\n *\n * @param array The tensor containing sequences padded to the same length\n * @param pad_value The padding value\n * @return A triplet of packed tensor, the length tensor, and the offset tensor\n *\n * @note Example: consider the following array with padding value -1:\n *\n * <code>\n *     [[1, 2, -1, -1],\n *      [3, 4,  5, -1]]\n * </code>\n *\n * The packed tensor would be [1, 2, 3, 4, 5].\n *\n * The length tensor would be [2, 3], i.e. the length of each sequence before\n * padding.\n *\n * The offset tensor would be [0, 2], i.e. the offset to the packed tensor for\n * each sequence (before padding)\n */\ntemplate <typename ValueType>\nstd::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);\n\n/**\n * @brief Batch-slice a 1D or 2D array, and then pack the list of sliced arrays\n * by concatenation.\n *\n * If a 2D array is given, then the function is equivalent to:\n *\n * <code>\n *     def ConcatSlices(array, lengths):\n *         slices = [array[i, :l] for i, l in enumerate(lengths)]\n *         packed = np.concatenate(slices)\n *         offsets = np.cumsum([0] + lengths[:-1])\n *         return packed, offsets\n * </code>\n *\n * If a 1D array is given, then the function is equivalent to\n *\n * <code>\n *     def ConcatSlices(array, lengths):\n *         slices = [array[:l] for l in lengths]\n *         packed = np.concatenate(slices)\n *         offsets = np.cumsum([0] + lengths[:-1])\n *         return packed, offsets\n * </code>\n *\n * @param array A 1D or 2D tensor for slicing\n * @param lengths A 1D tensor indicating the number of elements to slice\n * @return The tensor with packed slices along with the offsets.\n */\nstd::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);\n\n/**\n * @brief Return the cumulative summation (or inclusive sum) of the 
input array.\n *\n * The first element out[0] is equal to the first element of the input array\n * array[0]. The rest elements are defined recursively, out[i] = out[i-1] +\n * array[i]. Hence, the result array length is the same as the input array\n * length.\n *\n * If prepend_zero is true, then the first element is zero and the result array\n * length is the input array length plus one. This is useful for creating\n * an indptr array over a count array.\n *\n * @param array The 1D input array.\n * @return Array after cumsum.\n */\nIdArray CumSum(IdArray array, bool prepend_zero = false);\n\n/**\n * @brief Return the nonzero index.\n *\n * Only support 1D array. The result index array is in int64.\n *\n * @param array The input array.\n * @return A 1D index array storing the positions of the non zero values.\n */\nIdArray NonZero(NDArray array);\n\n/**\n * @brief Sort the ID vector in ascending order.\n *\n * It performs both sort and arg_sort (returning the sorted index). The sorted\n * index is always in int64.\n *\n * @param array Input array.\n * @param num_bits The number of bits used in key comparison. For example, if\n * the data type of the input array is int32_t and `num_bits = 8`, it only uses\n * bits in index range [0, 8) for sorting. Setting it to a small value could\n *                 speed up the sorting if the underlying sorting algorithm is\n * radix sort (e.g., on GPU). Setting it to zero (default value) means using all\n * the bits for comparison. 
On CPU, it currently has no effect.\n * @return A pair of arrays: sorted values and sorted index to the original\n * position.\n */\nstd::pair<IdArray, IdArray> Sort(IdArray array, int num_bits = 0);\n\n/**\n * @brief Return a string that prints out some debug information.\n */\nstd::string ToDebugString(NDArray array);\n\n// inline implementations\ntemplate <typename T>\nIdArray VecToIdArray(const std::vector<T>& vec, uint8_t nbits, DGLContext ctx) {\n  IdArray ret = NewIdArray(vec.size(), DGLContext{kDGLCPU, 0}, nbits);\n  if (nbits == 32) {\n    std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));\n  } else if (nbits == 64) {\n    std::copy(vec.begin(), vec.end(), static_cast<int64_t*>(ret->data));\n  } else {\n    LOG(FATAL) << \"Only int32 or int64 is supported.\";\n  }\n  return ret.CopyTo(ctx);\n}\n\n/**\n * @brief Get the context of the first array, and check if the non-null arrays'\n * contexts are the same.\n */\ninline DGLContext GetContextOf(const std::vector<IdArray>& arrays) {\n  bool first = true;\n  DGLContext result;\n  for (auto& array : arrays) {\n    if (first) {\n      first = false;\n      result = array->ctx;\n    } else {\n      CHECK_EQ(array->ctx, result)\n          << \"Context of the input arrays are different\";\n    }\n  }\n  return result;\n}\n\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ATEN_ARRAY_OPS_H_\n"
  },
  {
    "path": "include/dgl/aten/coo.h",
    "content": "\n/**\n *  Copyright (c) 2020-2022 by Contributors\n * @file dgl/aten/coo.h\n * @brief Common COO operations required by DGL.\n */\n#ifndef DGL_ATEN_COO_H_\n#define DGL_ATEN_COO_H_\n\n#include <dmlc/io.h>\n#include <dmlc/serializer.h>\n\n#include <string>\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"./array_ops.h\"\n#include \"./macro.h\"\n#include \"./spmat.h\"\n#include \"./types.h\"\n\nnamespace dgl {\nnamespace aten {\n\nstruct CSRMatrix;\n\n/**\n * @brief Plain COO structure\n *\n * The data array stores integer ids for reading edge features.\n * Note that we do allow duplicate non-zero entries -- multiple non-zero entries\n * that have the same row, col indices. It corresponds to multigraph in\n * graph terminology.\n */\n\nconstexpr uint64_t kDGLSerialize_AtenCooMatrixMagic = 0xDD61ffd305dff127;\n\n// TODO(BarclayII): Graph queries on COO formats should support the case where\n// data ordered by rows/columns instead of EID.\nstruct COOMatrix {\n  /** @brief the dense shape of the matrix */\n  int64_t num_rows = 0, num_cols = 0;\n  /** @brief COO index arrays */\n  IdArray row, col;\n  /** @brief data index array. When is null, assume it is from 0 to NNZ - 1. 
*/\n  IdArray data;\n  /** @brief whether the row indices are sorted */\n  bool row_sorted = false;\n  /** @brief whether the column indices per row are sorted */\n  bool col_sorted = false;\n  /** @brief whether the matrix is in pinned memory */\n  bool is_pinned = false;\n  /** @brief default constructor */\n  COOMatrix() = default;\n  /** @brief constructor */\n  COOMatrix(\n      int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,\n      IdArray darr = NullArray(), bool rsorted = false, bool csorted = false)\n      : num_rows(nrows),\n        num_cols(ncols),\n        row(rarr),\n        col(carr),\n        data(darr),\n        row_sorted(rsorted),\n        col_sorted(csorted) {\n    CheckValidity();\n  }\n\n  /** @brief constructor from SparseMatrix object */\n  explicit COOMatrix(const SparseMatrix& spmat)\n      : num_rows(spmat.num_rows),\n        num_cols(spmat.num_cols),\n        row(spmat.indices[0]),\n        col(spmat.indices[1]),\n        data(spmat.indices[2]),\n        row_sorted(spmat.flags[0]),\n        col_sorted(spmat.flags[1]) {\n    CheckValidity();\n  }\n\n  // Convert to a SparseMatrix object that can return to python.\n  SparseMatrix ToSparseMatrix() const {\n    return SparseMatrix(\n        static_cast<int32_t>(SparseFormat::kCOO), num_rows, num_cols,\n        {row, col, data}, {row_sorted, col_sorted});\n  }\n\n  bool Load(dmlc::Stream* fs) {\n    uint64_t magicNum;\n    CHECK(fs->Read(&magicNum)) << \"Invalid Magic Number\";\n    CHECK_EQ(magicNum, kDGLSerialize_AtenCooMatrixMagic)\n        << \"Invalid COOMatrix Data\";\n    CHECK(fs->Read(&num_cols)) << \"Invalid num_cols\";\n    CHECK(fs->Read(&num_rows)) << \"Invalid num_rows\";\n    CHECK(fs->Read(&row)) << \"Invalid row\";\n    CHECK(fs->Read(&col)) << \"Invalid col\";\n    CHECK(fs->Read(&data)) << \"Invalid data\";\n    CHECK(fs->Read(&row_sorted)) << \"Invalid row_sorted\";\n    CHECK(fs->Read(&col_sorted)) << \"Invalid col_sorted\";\n    CheckValidity();\n    return 
true;\n  }\n\n  void Save(dmlc::Stream* fs) const {\n    fs->Write(kDGLSerialize_AtenCooMatrixMagic);\n    fs->Write(num_cols);\n    fs->Write(num_rows);\n    fs->Write(row);\n    fs->Write(col);\n    fs->Write(data);\n    fs->Write(row_sorted);\n    fs->Write(col_sorted);\n  }\n\n  inline void CheckValidity() const {\n    CHECK_SAME_DTYPE(row, col);\n    CHECK_SAME_CONTEXT(row, col);\n    if (!aten::IsNullArray(data)) {\n      CHECK_SAME_DTYPE(row, data);\n      CHECK_SAME_CONTEXT(row, data);\n    }\n    CHECK_NO_OVERFLOW(row->dtype, num_rows);\n    CHECK_NO_OVERFLOW(row->dtype, num_cols);\n  }\n\n  inline bool IsEmpty() const {\n    return aten::IsNullArray(row) && aten::IsNullArray(col) &&\n           aten::IsNullArray(data);\n  }\n\n  // Check and update the internal flag is_pinned.\n  // This function will initialize a cuda context.\n  inline bool CheckIfPinnedInCUDA() {\n    is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&\n                (aten::IsNullArray(col) || col.IsPinned()) &&\n                (aten::IsNullArray(data) || data.IsPinned());\n    return is_pinned;\n  }\n\n  /** @brief Return a copy of this matrix on the give device context. */\n  inline COOMatrix CopyTo(const DGLContext& ctx) const {\n    if (ctx == row->ctx) return *this;\n    return COOMatrix(\n        num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx),\n        aten::IsNullArray(data) ? data : data.CopyTo(ctx), row_sorted,\n        col_sorted);\n  }\n\n  /** @brief Return a copy of this matrix in pinned (page-locked) memory. */\n  inline COOMatrix PinMemory() {\n    if (!IsEmpty()) {\n      if (is_pinned) return *this;\n      auto new_coo = COOMatrix(\n          num_rows, num_cols, row.PinMemory(), col.PinMemory(),\n          aten::IsNullArray(data) ? data : data.PinMemory(), row_sorted,\n          col_sorted);\n      CHECK(new_coo.CheckIfPinnedInCUDA())\n          << \"An internal DGL error has occured while trying to pin a COO \"\n             \"matrix. 
Please file a bug at \"\n             \"'https://github.com/dmlc/dgl/issues' \"\n             \"with the above stacktrace.\";\n      return new_coo;\n    }\n    is_pinned = true;\n    return *this;\n  }\n\n  /**\n   * @brief Pin the row, col and data (if not Null) of the matrix.\n   * @note This is an in-place method. Behavior depends on the current context,\n   *       kDGLCPU: will be pinned;\n   *       IsPinned: directly return;\n   *       kDGLCUDA: invalid, will throw an error.\n   *       The context check is deferred to pinning the NDArray.\n   */\n  inline void PinMemory_() {\n    if (!IsEmpty()) {\n      if (is_pinned) return;\n      row.PinMemory_();\n      col.PinMemory_();\n      if (!aten::IsNullArray(data)) {\n        data.PinMemory_();\n      }\n      is_pinned = true;\n    }\n    is_pinned = true;\n    return;\n  }\n\n  /**\n   * @brief Unpin the row, col and data (if not Null) of the matrix.\n   * @note This is an in-place method. Behavior depends on the current context,\n   *       IsPinned: will be unpinned;\n   *       others: directly return.\n   *       The context check is deferred to unpinning the NDArray.\n   */\n  inline void UnpinMemory_() {\n    if (!IsEmpty()) {\n      if (!is_pinned) return;\n      row.UnpinMemory_();\n      col.UnpinMemory_();\n      if (!aten::IsNullArray(data)) {\n        data.UnpinMemory_();\n      }\n      is_pinned = false;\n    }\n    is_pinned = false;\n    return;\n  }\n\n  /**\n   * @brief Record stream for the row, col and data (if not Null) of the matrix.\n   * @param stream The stream that is using the graph\n   */\n  inline void RecordStream(DGLStreamHandle stream) const {\n    row.RecordStream(stream);\n    col.RecordStream(stream);\n    if (!aten::IsNullArray(data)) {\n      data.RecordStream(stream);\n    }\n  }\n};\n\n///////////////////////// COO routines //////////////////////////\n\n/** @brief Return true if the value (row, col) is non-zero */\nbool COOIsNonZero(COOMatrix, int64_t row, int64_t 
col);\n/**\n * @brief Batched implementation of COOIsNonZero.\n * @note This operator allows broadcasting (i.e, either row or col can be of\n * length 1).\n */\nruntime::NDArray COOIsNonZero(\n    COOMatrix, runtime::NDArray row, runtime::NDArray col);\n\n/** @brief Return the nnz of the given row */\nint64_t COOGetRowNNZ(COOMatrix, int64_t row);\nruntime::NDArray COOGetRowNNZ(COOMatrix, runtime::NDArray row);\n\n/** @brief Return the data array of the given row */\nstd::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(\n    COOMatrix, int64_t row);\n\n/** @brief Whether the COO matrix contains data */\ninline bool COOHasData(COOMatrix csr) { return !IsNullArray(csr.data); }\n\n/**\n * @brief Check whether the COO is sorted.\n *\n * It returns two flags: one for whether the row is sorted;\n * the other for whether the columns of each row is sorted\n * if the first flag is true.\n *\n * Complexity: O(NNZ)\n */\nstd::pair<bool, bool> COOIsSorted(COOMatrix coo);\n\n/**\n * @brief Get the data and the row,col indices for each returned entries.\n *\n * The operator supports matrix with duplicate entries and all the matched\n * entries will be returned. The operator assumes there is NO duplicate (row,\n * col) pair in the given input. Otherwise, the returned result is undefined.\n *\n * @note This operator allows broadcasting (i.e, either row or col can be of\n * length 1).\n * @param mat Sparse matrix\n * @param rows Row index\n * @param cols Column index\n * @return Three arrays {rows, cols, data}\n */\nstd::vector<runtime::NDArray> COOGetDataAndIndices(\n    COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);\n\n/**\n * @brief Get data. 
The return type is an ndarray due to possible duplicate\n * entries.\n */\ninline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {\n  IdArray rows =\n      VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx);\n  IdArray cols =\n      VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx);\n  const auto& rst = COOGetDataAndIndices(mat, rows, cols);\n  return rst[2];\n}\n\n/**\n * @brief Get the data for each (row, col) pair.\n *\n * The operator supports matrix with duplicate entries but only one matched\n * entry will be returned for each (row, col) pair. Support duplicate input\n * (row, col) pairs.\n *\n * @note This operator allows broadcasting (i.e, either row or col can be of\n * length 1).\n *\n * @param mat Sparse matrix.\n * @param rows Row index.\n * @param cols Column index.\n * @return Data array. The i^th element is the data of (rows[i], cols[i])\n */\nruntime::NDArray COOGetData(\n    COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);\n\n/** @brief Return a transposed COO matrix */\nCOOMatrix COOTranspose(COOMatrix coo);\n\n/**\n * @brief Convert COO matrix to CSR matrix.\n *\n * If the input COO matrix does not have data array, the data array of\n * the result CSR matrix stores a shuffle index for how the entries\n * will be reordered in CSR. The i^th entry in the result CSR corresponds\n * to the CSR.data[i] th entry in the input COO.\n *\n * Conversion complexity: O(nnz)\n *\n * - The function first check whether the input COO matrix is sorted\n *   using a linear scan.\n * - If the COO matrix is row sorted, the conversion can be done very\n *   efficiently in a sequential scan. 
The result indices and data arrays\n *   are directly equal to the column and data arrays from the input.\n * - If the COO matrix is further column sorted, the result CSR is\n *   also column sorted.\n * - Otherwise, the conversion is more costly but still is O(nnz).\n *\n * @param coo Input COO matrix.\n * @return CSR matrix.\n */\nCSRMatrix COOToCSR(COOMatrix coo);\n\n/**\n * @brief Slice rows of the given matrix and return.\n * @param coo COO matrix\n * @param start Start row id (inclusive)\n * @param end End row id (exclusive)\n */\nCOOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);\nCOOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);\n\n/**\n * @brief Get the submatrix specified by the row and col ids.\n *\n * In numpy notation, given matrix M, row index array I, col index array J\n * This function returns the submatrix M[I, J].\n *\n * @param coo The input coo matrix\n * @param rows The row index to select\n * @param cols The col index to select\n * @return submatrix\n */\nCOOMatrix COOSliceMatrix(\n    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);\n\n/** @return True if the matrix has duplicate entries */\nbool COOHasDuplicate(COOMatrix coo);\n\n/**\n * @brief Deduplicate the entries of a sorted COO matrix, replacing the data\n * with the number of occurrences of the row-col coordinates.\n */\nstd::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);\n\n/**\n * @brief Sort the indices of a COO matrix in-place.\n *\n * The function sorts row indices in ascending order. If sort_column is true,\n * col indices are sorted in ascending order too. 
The data array of the returned\n * COOMatrix stores the shuffled index which could be used to fetch edge data.\n *\n * Complexity: O(N*log(N)) time and O(1) space, where N is the number of\n * nonzeros.\n * TODO(minjie): The time complexity could be improved to O(N) by using a O(N)\n * space.\n *\n * @param mat The coo matrix to sort.\n * @param sort_column True if column index should be sorted too.\n */\nvoid COOSort_(COOMatrix* mat, bool sort_column = false);\n\n/**\n * @brief Sort the indices of a COO matrix.\n *\n * The function sorts row indices in ascending order. If sort_column is true,\n * col indices are sorted in ascending order too. The data array of the returned\n * COOMatrix stores the shuffled index which could be used to fetch edge data.\n *\n * Complexity: O(N*log(N)) time and O(1) space, where N is the number of\n * nonzeros.\n * TODO(minjie): The time complexity could be improved to O(N) by using a O(N)\n * space.\n *\n * @param mat The input coo matrix\n * @param sort_column True if column index should be sorted too.\n * @return COO matrix with index sorted.\n */\ninline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) {\n  if ((mat.row_sorted && !sort_column) || mat.col_sorted) return mat;\n  COOMatrix ret(\n      mat.num_rows, mat.num_cols, mat.row.Clone(), mat.col.Clone(),\n      COOHasData(mat) ? 
mat.data.Clone() : mat.data, mat.row_sorted,\n      mat.col_sorted);\n  COOSort_(&ret, sort_column);\n  return ret;\n}\n\n/**\n * @brief Remove entries from COO matrix by entry indices (data indices)\n * @return A new COO matrix as well as a mapping from the new COO entries to the\n * old COO entries.\n */\nCOOMatrix COORemove(COOMatrix coo, IdArray entries);\n\n/**\n * @brief Reorder the rows and colmns according to the new row and column order.\n * @param csr The input coo matrix.\n * @param new_row_ids the new row Ids (the index is the old row Id)\n * @param new_col_ids the new column Ids (the index is the old col Id).\n */\nCOOMatrix COOReorder(\n    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);\n\n/**\n * @brief Randomly select a fixed number of non-zero entries along each given\n * row using arXiv:2210.13339, Labor sampling.\n *\n * The picked indices are returned in the form of a COO matrix.\n *\n * The passed random_seed makes it so that for any seed vertex s and its\n * neighbor t, the rolled random variate r_t is the same for any call to this\n * function with the same random seed. When sampling as part of the same batch,\n * one would want identical seeds so that LABOR can globally sample. One example\n * is that for heterogenous graphs, there is a single random seed passed for\n * each edge type. This will sample much fewer vertices compared to having\n * unique random seeds for each edge type. 
If one called this function\n * individually for each edge type for a heterogenous graph with different\n * random seeds, then it would run LABOR locally for each edge type, resulting\n * into a larger number of vertices being sampled.\n *\n * If this function is called without a random_seed, we get the random seed by\n * getting a random number from DGL.\n *\n *\n * Examples:\n *\n * // coo.num_rows = 4;\n * // coo.num_cols = 4;\n * // coo.rows = [0, 0, 1, 3, 3]\n * // coo.cols = [0, 1, 1, 2, 3]\n * // coo.data = [2, 3, 0, 1, 4]\n * COOMatrix coo = ...;\n * IdArray rows = ... ; // [1, 3]\n * COOMatrix sampled = COOLaborSampling(coo, rows, 2, NullArray(), 0 \\\n *     , NullArray(), NullArray());\n * // possible sampled coo matrix:\n * // sampled.num_rows = 4\n * // sampled.num_cols = 4\n * // sampled.rows = [1, 3, 3]\n * // sampled.cols = [1, 2, 3]\n * // sampled.data = [3, 0, 4]\n *\n * @param mat Input coo matrix.\n * @param rows Rows to sample from.\n * @param num_samples Number of samples using labor sampling\n * @param prob Probability array for nonuniform sampling\n * @param importance_sampling Whether to enable importance sampling\n * @param random_seed The random seed for the sampler\n * @param seed2_contribution The contribution of the second random seed, [0, 1)\n * @param NIDs global nids if sampling from a subgraph\n * @return A pair of COOMatrix storing the picked row and col indices and edge\n *         weights if importance_sampling != 0 or prob argument was passed.\n *         Its data field stores the the index of the picked elements in the\n *         value array.\n */\nstd::pair<COOMatrix, FloatArray> COOLaborSampling(\n    COOMatrix mat, IdArray rows, int64_t num_samples,\n    FloatArray prob = NullArray(), int importance_sampling = 0,\n    IdArray random_seed = NullArray(), float seed2_contribution = 0,\n    IdArray NIDs = NullArray());\n\n/**\n * @brief Randomly select a fixed number of non-zero entries along each given\n * row 
independently.\n *\n * The function performs random choices along each row independently.\n * The picked indices are returned in the form of a COO matrix.\n *\n * If replace is false and a row has fewer non-zero values than num_samples,\n * all the values are picked.\n *\n * Examples:\n *\n * // coo.num_rows = 4;\n * // coo.num_cols = 4;\n * // coo.rows = [0, 0, 1, 3, 3]\n * // coo.cols = [0, 1, 1, 2, 3]\n * // coo.data = [2, 3, 0, 1, 4]\n * COOMatrix coo = ...;\n * IdArray rows = ... ; // [1, 3]\n * COOMatrix sampled = COORowWiseSampling(coo, rows, 2, FloatArray(), false);\n * // possible sampled coo matrix:\n * // sampled.num_rows = 4\n * // sampled.num_cols = 4\n * // sampled.rows = [1, 3, 3]\n * // sampled.cols = [1, 2, 3]\n * // sampled.data = [3, 0, 4]\n *\n * @param mat Input coo matrix.\n * @param rows Rows to sample from.\n * @param num_samples Number of samples\n * @param prob_or_mask Unnormalized probability array or mask array.\n *                     Should be of the same length as the data array.\n *                     If an empty array is provided, assume uniform.\n * @param replace True if sample with replacement\n * @return A COOMatrix storing the picked row and col indices. Its data field\n * stores the the index of the picked elements in the value array.\n */\nCOOMatrix COORowWiseSampling(\n    COOMatrix mat, IdArray rows, int64_t num_samples,\n    NDArray prob_or_mask = NDArray(), bool replace = true);\n\n/**\n * @brief Randomly select a fixed number of non-zero entries for each edge type\n *        along each given row independently.\n *\n * The function performs random choices along each row independently.\n * In each row, num_samples samples is picked for each edge type. 
(The edge\n * type is stored in etypes)\n * The picked indices are returned in the form of a COO matrix.\n *\n * If replace is false and a row has fewer non-zero values than num_samples,\n * all the values are picked.\n *\n * Examples:\n *\n * // coo.num_rows = 4;\n * // coo.num_cols = 4;\n * // coo.rows = [0, 0, 0, 0, 3]\n * // coo.cols = [0, 1, 3, 2, 3]\n * // coo.data = [2, 3, 0, 1, 4]\n * // eid2etype_offset = [0, 3, 4, 5]\n * COOMatrix coo = ...;\n * IdArray rows = ... ; // [0, 3]\n * std::vector<int64_t> num_samples = {2, 2, 2};\n * COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, eid2etype_offset,\n * num_samples, FloatArray(), false);\n * // possible sampled coo matrix:\n * // sampled.num_rows = 4\n * // sampled.num_cols = 4\n * // sampled.rows = [0, 0, 0, 3]\n * // sampled.cols = [0, 3, 2, 3]\n * // sampled.data = [2, 0, 1, 4]\n *\n * @param mat Input coo matrix.\n * @param rows Rows to sample from.\n * @param eid2etype_offset The offset to each edge type.\n * @param num_samples Number of samples\n * @param prob_or_mask Unnormalized probability array or mask array.\n *                     Should be of the same length as the data array.\n *                     If an empty array is provided, assume uniform.\n * @param replace True if sample with replacement\n * @return A COOMatrix storing the picked row and col indices. 
Its data field\n * stores the index of the picked elements in the value array.\n * @note The edges of the entire graph must be ordered by their edge types.\n */\nCOOMatrix COORowWisePerEtypeSampling(\n    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_samples,\n    const std::vector<NDArray>& prob_or_mask, bool replace = true);\n\n/**\n * @brief Select K non-zero entries with the largest weights along each given\n * row.\n *\n * The function performs top-k selection along each row independently.\n * The picked indices are returned in the form of a COO matrix.\n *\n * If a row has fewer non-zero values than k, all the values are picked.\n *\n * Examples:\n *\n * // coo.num_rows = 4;\n * // coo.num_cols = 4;\n * // coo.rows = [0, 0, 1, 3, 3]\n * // coo.cols = [0, 1, 1, 2, 3]\n * // coo.data = [2, 3, 0, 1, 4]\n * COOMatrix coo = ...;\n * IdArray rows = ... ;  // [0, 1, 3]\n * FloatArray weight = ... ;  // [1., 0., -1., 10., 20.]\n * COOMatrix sampled = COORowWiseTopk(coo, rows, 1, weight);\n * // possible sampled coo matrix:\n * // sampled.num_rows = 4\n * // sampled.num_cols = 4\n * // sampled.rows = [0, 1, 3]\n * // sampled.cols = [1, 1, 2]\n * // sampled.data = [3, 0, 1]\n *\n * @param mat Input COO matrix.\n * @param rows Rows to sample from.\n * @param k The K value.\n * @param weight Weight associated with each entry. Should be of the same length\n * as the data array. If an empty array is provided, assume uniform.\n * @param ascending If true, elements are sorted by ascending order, equivalent\n * to finding the K smallest values. Otherwise, find the K largest values.\n * @return A COOMatrix storing the picked row and col indices. 
Its data field\n * stores the index of the picked elements in the value array.\n */\nCOOMatrix COORowWiseTopk(\n    COOMatrix mat, IdArray rows, int64_t k, NDArray weight,\n    bool ascending = false);\n\n/**\n * @brief Union two COOMatrix into one COOMatrix.\n *\n * Two Matrix must have the same shape.\n *\n * Example:\n *\n * A = [[0, 0, 1, 0],\n *      [1, 0, 1, 1],\n *      [0, 1, 0, 0]]\n *\n * B = [[0, 1, 1, 0],\n *      [0, 0, 0, 1],\n *      [0, 0, 1, 0]]\n *\n * COOMatrix_A.num_rows : 3\n * COOMatrix_A.num_cols : 4\n * COOMatrix_B.num_rows : 3\n * COOMatrix_B.num_cols : 4\n *\n * C = UnionCoo({A, B});\n *\n * C = [[0, 1, 2, 0],\n *      [1, 0, 1, 2],\n *      [0, 1, 1, 0]]\n *\n * COOMatrix_C.num_rows : 3\n * COOMatrix_C.num_cols : 4\n */\nCOOMatrix UnionCoo(const std::vector<COOMatrix>& coos);\n\n/**\n * @brief DisjointUnion a list of COOMatrix into one COOMatrix.\n *\n * Examples:\n *\n * A = [[0, 0, 1],\n *      [1, 0, 1],\n *      [0, 1, 0]]\n *\n * B = [[0, 0],\n *      [1, 0]]\n *\n * COOMatrix_A.num_rows : 3\n * COOMatrix_A.num_cols : 3\n * COOMatrix_B.num_rows : 2\n * COOMatrix_B.num_cols : 2\n *\n * C = DisjointUnionCoo({A, B});\n *\n * C = [[0, 0, 1, 0, 0],\n *      [1, 0, 1, 0, 0],\n *      [0, 1, 0, 0, 0],\n *      [0, 0, 0, 0, 0],\n *      [0, 0, 0, 1, 0]]\n * COOMatrix_C.num_rows : 5\n * COOMatrix_C.num_cols : 5\n *\n * @param coos The input list of coo matrix.\n * @param src_offset A list of integers recording src vertex id offset of each\n * Matrix in coos\n * @param dst_offset A list of integers recording dst vertex id offset of each\n * Matrix in coos\n * @return The combined COOMatrix.\n */\nCOOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);\n\n/**\n * @brief COOMatrix toSimple.\n *\n * A = [[0, 0, 0],\n *      [3, 0, 2],\n *      [1, 1, 0],\n *      [0, 0, 4]]\n *\n * B, cnt, edge_map = COOToSimple(A)\n *\n * B = [[0, 0, 0],\n *      [1, 0, 1],\n *      [1, 1, 0],\n *      [0, 0, 1]]\n * cnt = [3, 2, 1, 1, 4]\n * edge_map 
= [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4]\n *\n * @return The simplified COOMatrix\n *         The count recording the number of duplicated edges from the original\n * graph. The edge mapping from the edge IDs of original graph to those of the\n *         returned graph.\n */\nstd::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);\n\n/**\n * @brief Split a COOMatrix into multiple disjoin components.\n *\n * Examples:\n *\n * C = [[0, 0, 1, 0, 0],\n *      [1, 0, 1, 0, 0],\n *      [0, 1, 0, 0, 0],\n *      [0, 0, 0, 0, 0],\n *      [0, 0, 0, 1, 0],\n *      [0, 0, 0, 0, 1]]\n * COOMatrix_C.num_rows : 6\n * COOMatrix_C.num_cols : 5\n *\n * batch_size : 2\n * edge_cumsum : [0, 4, 6]\n * src_vertex_cumsum : [0, 3, 6]\n * dst_vertex_cumsum : [0, 3, 5]\n *\n * ret = DisjointPartitionCooBySizes(C,\n *                                   batch_size,\n *                                   edge_cumsum,\n *                                   src_vertex_cumsum,\n *                                   dst_vertex_cumsum)\n *\n * A = [[0, 0, 1],\n *      [1, 0, 1],\n *      [0, 1, 0]]\n * COOMatrix_A.num_rows : 3\n * COOMatrix_A.num_cols : 3\n *\n * B = [[0, 0],\n *      [1, 0],\n *      [0, 1]]\n * COOMatrix_B.num_rows : 3\n * COOMatrix_B.num_cols : 2\n *\n * @param coo COOMatrix to split.\n * @param batch_size Number of disjoin components (Sub COOMatrix)\n * @param edge_cumsum Number of edges of each components\n * @param src_vertex_cumsum Number of src vertices of each component.\n * @param dst_vertex_cumsum Number of dst vertices of each component.\n * @return A list of COOMatrixes representing each disjoint components.\n */\nstd::vector<COOMatrix> DisjointPartitionCooBySizes(\n    const COOMatrix& coo, const uint64_t batch_size,\n    const std::vector<uint64_t>& edge_cumsum,\n    const std::vector<uint64_t>& src_vertex_cumsum,\n    const std::vector<uint64_t>& dst_vertex_cumsum);\n\n/**\n * @brief Slice a contiguous chunk from a COOMatrix\n *\n * Examples:\n *\n * C 
= [[0, 0, 1, 0, 0],\n *      [1, 0, 1, 0, 0],\n *      [0, 1, 0, 0, 0],\n *      [0, 0, 0, 0, 0],\n *      [0, 0, 0, 1, 0],\n *      [0, 0, 0, 0, 1]]\n * COOMatrix_C.num_rows : 6\n * COOMatrix_C.num_cols : 5\n *\n * edge_range : [4, 6]\n * src_vertex_range : [3, 6]\n * dst_vertex_range : [3, 5]\n *\n * ret = COOSliceContiguousChunk(C,\n *                               edge_range,\n *                               src_vertex_range,\n *                               dst_vertex_range)\n *\n * ret = [[0, 0],\n *        [1, 0],\n *        [0, 1]]\n * COOMatrix_ret.num_rows : 3\n * COOMatrix_ret.num_cols : 2\n *\n * @param coo COOMatrix to slice.\n * @param edge_range ID range of the edges in the chunk\n * @param src_vertex_range ID range of the src vertices in the chunk.\n * @param dst_vertex_range ID range of the dst vertices in the chunk.\n * @return COOMatrix representing the chunk.\n */\nCOOMatrix COOSliceContiguousChunk(\n    const COOMatrix& coo, const std::vector<uint64_t>& edge_range,\n    const std::vector<uint64_t>& src_vertex_range,\n    const std::vector<uint64_t>& dst_vertex_range);\n\n/**\n * @brief Create a LineGraph of input coo\n *\n * A = [[0, 0, 1],\n *      [1, 0, 1],\n *      [1, 1, 0]]\n * A.row = [0, 1, 1, 2, 2]\n * A.col = [2, 0, 2, 0, 1]\n * A.eid = [0, 1, 2, 3, 4]\n *\n * B = COOLineGraph(A, backtracking=False)\n *\n * B = [[0, 0, 0, 0, 1],\n *      [1, 0, 0, 0, 0],\n *      [0, 0, 0, 1, 0],\n *      [0, 0, 0, 0, 0],\n *      [0, 1, 0, 0, 0]]\n *\n * C = COOLineGraph(A, backtracking=True)\n *\n * C = [[0, 0, 0, 1, 1],\n *      [1, 0, 0, 0, 0],\n *      [0, 0, 0, 1, 1],\n *      [1, 0, 0, 0, 0],\n *      [0, 1, 1, 0, 0]]\n *\n * @param coo COOMatrix to create the LineGraph\n * @param backtracking whether the pair of (v, u) (u, v) edges are treated as\n * linked\n * @return LineGraph in COO format\n */\nCOOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);\n\n/**\n * @brief Generalized Sparse Matrix-Matrix Multiplication on COO.\n * 
@param op The binary operator, could be `add`, `sub`, `mul`, `div`,\n *        `copy_u`, `copy_e`.\n * @param reduce The reduce operator, could be `sum`, `min`, `max`.\n * @param coo The COO we apply SpMM on.\n * @param ufeat The source node feature.\n * @param efeat The edge feature.\n * @param out The output feature on destination nodes.\n * @param out_aux A list of NDArrays that contain auxiliary information such\n *        as the argmax on source nodes and edges for reduce operators such as\n *        `min` and `max`.\n */\nvoid COOSpMM(\n    const std::string& op, const std::string& reduce, const COOMatrix& coo,\n    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);\n\n/** @brief COOSpMM C interface without std::string. */\nvoid COOSpMM(\n    const char* op, const char* reduce, const COOMatrix& coo, NDArray ufeat,\n    NDArray efeat, NDArray out, std::vector<NDArray> out_aux);\n\n/**\n * @brief Generalized Sampled Dense-Dense Matrix Multiplication on COO.\n * @param op The binary operator, could be `add`, `sub`, `mul`, `div`,\n *        `dot`, `copy_u`, `copy_e`.\n * @param coo The COO we apply SDDMM on.\n * @param ufeat The left-hand-side feature; its type is given by lhs_target.\n * @param efeat The right-hand-side feature; its type is given by rhs_target.\n * @param out The output feature on edge.\n * @param lhs_target Type of `ufeat` (0: source, 1: edge, 2: destination).\n * @param rhs_target Type of `efeat` (0: source, 1: edge, 2: destination).\n */\nvoid COOSDDMM(\n    const std::string& op, const COOMatrix& coo, NDArray ufeat, NDArray efeat,\n    NDArray out, int lhs_target, int rhs_target);\n\n/** @brief COOSDDMM C interface without std::string. */\nvoid COOSDDMM(\n    const char* op, const COOMatrix& coo, NDArray ufeat, NDArray efeat,\n    NDArray out, int lhs_target, int rhs_target);\n\n}  // namespace aten\n}  // namespace dgl\n\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, dgl::aten::COOMatrix, true);\n}  // namespace dmlc\n\n#endif  // DGL_ATEN_COO_H_\n"
  },
  {
    "path": "include/dgl/aten/csr.h",
    "content": "/**\n *  Copyright (c) 2020-2022 by Contributors\n * @file dgl/aten/csr.h\n * @brief Common CSR operations required by DGL.\n */\n#ifndef DGL_ATEN_CSR_H_\n#define DGL_ATEN_CSR_H_\n\n#include <dmlc/io.h>\n#include <dmlc/serializer.h>\n\n#include <string>\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"./array_ops.h\"\n#include \"./macro.h\"\n#include \"./spmat.h\"\n#include \"./types.h\"\n\nnamespace dgl {\nnamespace aten {\n\nstruct COOMatrix;\n\n/**\n * @brief Plain CSR matrix\n *\n * The column indices are 0-based and are not necessarily sorted. The data array\n * stores integer ids for reading edge features.\n *\n * Note that we do allow duplicate non-zero entries -- multiple non-zero entries\n * that have the same row, col indices. It corresponds to multigraph in\n * graph terminology.\n */\n\nconstexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;\n\nstruct CSRMatrix {\n  /** @brief the dense shape of the matrix */\n  int64_t num_rows = 0, num_cols = 0;\n  /** @brief CSR index arrays */\n  IdArray indptr, indices;\n  /** @brief data index array. When is null, assume it is from 0 to NNZ - 1. 
*/\n  IdArray data;\n  /** @brief whether the column indices per row are sorted */\n  bool sorted = false;\n  /** @brief whether the matrix is in pinned memory */\n  bool is_pinned = false;\n  /** @brief default constructor */\n  CSRMatrix() = default;\n  /** @brief constructor */\n  CSRMatrix(\n      int64_t nrows, int64_t ncols, IdArray parr, IdArray iarr,\n      IdArray darr = NullArray(), bool sorted_flag = false)\n      : num_rows(nrows),\n        num_cols(ncols),\n        indptr(parr),\n        indices(iarr),\n        data(darr),\n        sorted(sorted_flag) {\n    CheckValidity();\n  }\n\n  /** @brief constructor from SparseMatrix object */\n  explicit CSRMatrix(const SparseMatrix& spmat)\n      : num_rows(spmat.num_rows),\n        num_cols(spmat.num_cols),\n        indptr(spmat.indices[0]),\n        indices(spmat.indices[1]),\n        data(spmat.indices[2]),\n        sorted(spmat.flags[0]) {\n    CheckValidity();\n  }\n\n  // Convert to a SparseMatrix object that can return to python.\n  SparseMatrix ToSparseMatrix() const {\n    return SparseMatrix(\n        static_cast<int32_t>(SparseFormat::kCSR), num_rows, num_cols,\n        {indptr, indices, data}, {sorted});\n  }\n\n  bool Load(dmlc::Stream* fs) {\n    uint64_t magicNum;\n    CHECK(fs->Read(&magicNum)) << \"Invalid Magic Number\";\n    CHECK_EQ(magicNum, kDGLSerialize_AtenCsrMatrixMagic)\n        << \"Invalid CSRMatrix Data\";\n    CHECK(fs->Read(&num_cols)) << \"Invalid num_cols\";\n    CHECK(fs->Read(&num_rows)) << \"Invalid num_rows\";\n    CHECK(fs->Read(&indptr)) << \"Invalid indptr\";\n    CHECK(fs->Read(&indices)) << \"Invalid indices\";\n    CHECK(fs->Read(&data)) << \"Invalid data\";\n    CHECK(fs->Read(&sorted)) << \"Invalid sorted\";\n    CheckValidity();\n    return true;\n  }\n\n  void Save(dmlc::Stream* fs) const {\n    fs->Write(kDGLSerialize_AtenCsrMatrixMagic);\n    fs->Write(num_cols);\n    fs->Write(num_rows);\n    fs->Write(indptr);\n    fs->Write(indices);\n    fs->Write(data);\n  
  fs->Write(sorted);\n  }\n\n  inline void CheckValidity() const {\n    CHECK_SAME_DTYPE(indptr, indices);\n    CHECK_SAME_CONTEXT(indptr, indices);\n    if (!aten::IsNullArray(data)) {\n      CHECK_SAME_DTYPE(indptr, data);\n      CHECK_SAME_CONTEXT(indptr, data);\n    }\n    CHECK_NO_OVERFLOW(indptr->dtype, num_rows);\n    CHECK_NO_OVERFLOW(indptr->dtype, num_cols);\n    CHECK_EQ(indptr->shape[0], num_rows + 1);\n  }\n\n  inline bool IsEmpty() const {\n    return aten::IsNullArray(indptr) && aten::IsNullArray(indices) &&\n           aten::IsNullArray(data);\n  }\n\n  // Check and update the internal flag is_pinned.\n  // This function will initialize a cuda context.\n  inline bool CheckIfPinnedInCUDA() {\n    is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&\n                (aten::IsNullArray(indices) || indices.IsPinned()) &&\n                (aten::IsNullArray(data) || data.IsPinned());\n    return is_pinned;\n  }\n\n  /** @brief Return a copy of this matrix on the give device context. */\n  inline CSRMatrix CopyTo(const DGLContext& ctx) const {\n    if (ctx == indptr->ctx) return *this;\n    return CSRMatrix(\n        num_rows, num_cols, indptr.CopyTo(ctx), indices.CopyTo(ctx),\n        aten::IsNullArray(data) ? data : data.CopyTo(ctx), sorted);\n  }\n\n  /** @brief Return a copy of this matrix in pinned (page-locked) memory. */\n  inline CSRMatrix PinMemory() {\n    if (!IsEmpty()) {\n      if (is_pinned) return *this;\n      auto new_csr = CSRMatrix(\n          num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(),\n          aten::IsNullArray(data) ? data : data.PinMemory(), sorted);\n      CHECK(new_csr.CheckIfPinnedInCUDA())\n          << \"An internal DGL error has occured while trying to pin a CSR \"\n             \"matrix. 
Please file a bug at \"\n             \"'https://github.com/dmlc/dgl/issues' \"\n             \"with the above stacktrace.\";\n      return new_csr;\n    }\n    is_pinned = true;\n    return *this;\n  }\n\n  /**\n   * @brief Pin the indptr, indices and data (if not Null) of the matrix.\n   * @note This is an in-place method. Behavior depends on the current context,\n   *       kDGLCPU: will be pinned;\n   *       IsPinned: directly return;\n   *       kDGLCUDA: invalid, will throw an error.\n   *       The context check is deferred to pinning the NDArray.\n   */\n  inline void PinMemory_() {\n    if (!IsEmpty()) {\n      if (is_pinned) return;\n      indptr.PinMemory_();\n      indices.PinMemory_();\n      if (!aten::IsNullArray(data)) {\n        data.PinMemory_();\n      }\n      is_pinned = true;\n    }\n    is_pinned = true;\n    return;\n  }\n\n  /**\n   * @brief Unpin the indptr, indices and data (if not Null) of the matrix.\n   * @note This is an in-place method. Behavior depends on the current context,\n   *       IsPinned: will be unpinned;\n   *       others: directly return.\n   *       The context check is deferred to unpinning the NDArray.\n   */\n  inline void UnpinMemory_() {\n    if (!IsEmpty()) {\n      if (!is_pinned) return;\n      indptr.UnpinMemory_();\n      indices.UnpinMemory_();\n      if (!aten::IsNullArray(data)) {\n        data.UnpinMemory_();\n      }\n      is_pinned = false;\n    }\n    is_pinned = false;\n    return;\n  }\n\n  /**\n   * @brief Record stream for the indptr, indices and data (if not Null) of the\n   * matrix.\n   * @param stream The stream that is using the graph\n   */\n  inline void RecordStream(DGLStreamHandle stream) const {\n    indptr.RecordStream(stream);\n    indices.RecordStream(stream);\n    if (!aten::IsNullArray(data)) {\n      data.RecordStream(stream);\n    }\n  }\n};\n\n///////////////////////// CSR routines //////////////////////////\n\n/** @brief Return true if the value (row, col) is non-zero */\nbool 
CSRIsNonZero(CSRMatrix, int64_t row, int64_t col);\n/**\n * @brief Batched implementation of CSRIsNonZero.\n * @note This operator allows broadcasting (i.e, either row or col can be of\n * length 1).\n */\nruntime::NDArray CSRIsNonZero(\n    CSRMatrix, runtime::NDArray row, runtime::NDArray col);\n\n/** @brief Return the nnz of the given row */\nint64_t CSRGetRowNNZ(CSRMatrix, int64_t row);\nruntime::NDArray CSRGetRowNNZ(CSRMatrix, runtime::NDArray row);\n\n/** @brief Return the column index array of the given row */\nruntime::NDArray CSRGetRowColumnIndices(CSRMatrix, int64_t row);\n\n/** @brief Return the data array of the given row */\nruntime::NDArray CSRGetRowData(CSRMatrix, int64_t row);\n\n/** @brief Whether the CSR matrix contains data */\ninline bool CSRHasData(CSRMatrix csr) { return !IsNullArray(csr.data); }\n\n/** @brief Whether the column indices of each row is sorted. */\nbool CSRIsSorted(CSRMatrix csr);\n\n/**\n * @brief Get the data and the row,col indices for each returned entries.\n *\n * The operator supports matrix with duplicate entries and all the matched\n * entries will be returned. The operator assumes there is NO duplicate (row,\n * col) pair in the given input. Otherwise, the returned result is undefined.\n *\n * If some (row, col) pairs do not contain a valid non-zero elements,\n * they will not be included in the return arrays.\n *\n * @note This operator allows broadcasting (i.e, either row or col can be of\n * length 1).\n * @param mat Sparse matrix\n * @param rows Row index\n * @param cols Column index\n * @return Three arrays {rows, cols, data}\n */\nstd::vector<runtime::NDArray> CSRGetDataAndIndices(\n    CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);\n\n/**\n * @brief Get data. 
The return type is an ndarray due to possible duplicate\n * entries.\n */\ninline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {\n  const auto& nbits = mat.indptr->dtype.bits;\n  const auto& ctx = mat.indptr->ctx;\n  IdArray rows = VecToIdArray<int64_t>({row}, nbits, ctx);\n  IdArray cols = VecToIdArray<int64_t>({col}, nbits, ctx);\n  const auto& rst = CSRGetDataAndIndices(mat, rows, cols);\n  return rst[2];\n}\n\n/**\n * @brief Get the data for each (row, col) pair.\n *\n * The operator supports matrix with duplicate entries but only one matched\n * entry will be returned for each (row, col) pair. Support duplicate input\n * (row, col) pairs.\n *\n * If some (row, col) pairs do not contain a valid non-zero elements,\n * their data values are filled with -1.\n *\n * @note This operator allows broadcasting (i.e, either row or col can be of\n * length 1).\n *\n * @param mat Sparse matrix.\n * @param rows Row index.\n * @param cols Column index.\n * @return Data array. The i^th element is the data of (rows[i], cols[i])\n */\nruntime::NDArray CSRGetData(\n    CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);\n\n/**\n * @brief Get the data for each (row, col) pair, then index into the weights\n * array.\n *\n * The operator supports matrix with duplicate entries but only one matched\n * entry will be returned for each (row, col) pair. Support duplicate input\n * (row, col) pairs.\n *\n * If some (row, col) pairs do not contain a valid non-zero elements to index\n * into the weights array, DGL returns the value \\a filler for that pair\n * instead.\n *\n * @note This operator allows broadcasting (i.e, either row or col can be of\n * length 1).\n *\n * @tparam DType the data type of the weights array.\n * @param mat Sparse matrix.\n * @param rows Row index.\n * @param cols Column index.\n * @param weights The weights array.\n * @param filler The value to return for row-column pairs not existent in the\n * matrix.\n * @return Data array. 
The i^th element is the data of (rows[i], cols[i])\n */\ntemplate <typename DType>\nruntime::NDArray CSRGetData(\n    CSRMatrix, runtime::NDArray rows, runtime::NDArray cols,\n    runtime::NDArray weights, DType filler);\n\n/**\n * @brief Get the data for each (row, col) pair, then index into the weights\n * array.\n *\n * The operator supports matrix with duplicate entries but only one matched\n * entry will be returned for each (row, col) pair. Support duplicate input\n * (row, col) pairs.\n *\n * If some (row, col) pairs do not contain a valid non-zero elements to index\n * into the weights array, DGL returns the value \\a filler for that pair\n * instead.\n *\n * @note This operator allows broadcasting (i.e, either row or col can be of\n * length 1).\n\n * @note This is the floating point number version of `CSRGetData`, which\n removes the dtype template.\n *\n * @param mat Sparse matrix.\n * @param rows Row index.\n * @param cols Column index.\n * @param weights The weights array.\n * @param filler The value to return for row-column pairs not existent in the\n * matrix.\n * @return Data array. The i^th element is the data of (rows[i], cols[i])\n */\nruntime::NDArray CSRGetFloatingData(\n    CSRMatrix, runtime::NDArray rows, runtime::NDArray cols,\n    runtime::NDArray weights, double filler);\n\n/** @brief Return a transposed CSR matrix */\nCSRMatrix CSRTranspose(CSRMatrix csr);\n\n/**\n * @brief Convert CSR matrix to COO matrix.\n *\n * Complexity: O(nnz)\n *\n * - If data_as_order is false, the column and data arrays of the\n *   result COO are equal to the indices and data arrays of the\n *   input CSR. The result COO is also row sorted.\n * - If the input CSR is further sorted, the result COO is also\n *   column sorted.\n *\n * @param csr Input csr matrix\n * @param data_as_order If true, the data array in the input csr matrix contains\n * the order by which the resulting COO tuples are stored. 
In this case, the\n *                      data array of the resulting COO matrix will be empty\n * because it is essentially a consecutive range.\n * @return a coo matrix\n */\nCOOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order);\n\n/**\n * @brief Slice rows of the given matrix and return.\n *\n * The sliced row IDs are relabeled to starting from zero.\n *\n * Examples:\n * num_rows = 4\n * num_cols = 4\n * indptr = [0, 2, 3, 3, 5]\n * indices = [1, 0, 2, 3, 1]\n *\n *  After CSRSliceRows(csr, 1, 3)\n *\n * num_rows = 2\n * num_cols = 4\n * indptr = [0, 1, 1]\n * indices = [2]\n *\n * @param csr CSR matrix\n * @param start Start row id (inclusive)\n * @param end End row id (exclusive)\n * @return sliced rows stored in a CSR matrix\n */\nCSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);\nCSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);\n\n/**\n * @brief Get the submatrix specified by the row and col ids.\n *\n * In numpy notation, given matrix M, row index array I, col index array J\n * This function returns the submatrix M[I, J]. It assumes that there is no\n * duplicate (row, col) pair in the given indices. M could have duplicate\n * entries.\n *\n * The sliced row and column IDs are relabeled according to the given\n * rows and cols (i.e., row #0 in the new matrix corresponds to rows[0] in\n * the original matrix).\n *\n * @param csr The input csr matrix\n * @param rows The row index to select\n * @param cols The col index to select\n * @return submatrix\n */\nCSRMatrix CSRSliceMatrix(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);\n\n/** @return True if the matrix has duplicate entries */\nbool CSRHasDuplicate(CSRMatrix csr);\n\n/**\n * @brief Sort the column index at each row in ascending order in-place.\n *\n * Only the indices and data arrays (if available) will be mutated. 
The indptr\n * array stays the same.\n *\n * Examples:\n * num_rows = 4\n * num_cols = 4\n * indptr = [0, 2, 3, 3, 5]\n * indices = [1, 0, 2, 3, 1]\n *\n *  After CSRSort_(&csr)\n *\n * indptr = [0, 2, 3, 3, 5]\n * indices = [0, 1, 2, 1, 3]\n */\nvoid CSRSort_(CSRMatrix* csr);\n\n/**\n * @brief Sort the column index at each row in ascending order.\n *\n * Return a new CSR matrix with sorted column indices and data arrays.\n */\ninline CSRMatrix CSRSort(CSRMatrix csr) {\n  if (csr.sorted) return csr;\n  CSRMatrix ret(\n      csr.num_rows, csr.num_cols, csr.indptr, csr.indices.Clone(),\n      CSRHasData(csr) ? csr.data.Clone() : csr.data, csr.sorted);\n  CSRSort_(&ret);\n  return ret;\n}\n\n/**\n * @brief Reorder the rows and columns according to the new row and column\n * order.\n * @param csr The input csr matrix.\n * @param new_row_ids the new row Ids (the index is the old row Id)\n * @param new_col_ids the new column Ids (the index is the old col Id).\n */\nCSRMatrix CSRReorder(\n    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);\n\n/**\n * @brief Remove entries from CSR matrix by entry indices (data indices)\n * @return A new CSR matrix as well as a mapping from the new CSR entries to the\n * old CSR entries.\n */\nCSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);\n\n/**\n * @brief Randomly select a fixed number of non-zero entries along each given\n * row using arXiv:2210.13339, Labor sampling.\n *\n * The picked indices are returned in the form of a COO matrix.\n *\n * The passed random_seed makes it so that for any seed vertex s and its\n * neighbor t, the rolled random variate r_t is the same for any call to this\n * function with the same random seed. When sampling as part of the same batch,\n * one would want identical seeds so that LABOR can globally sample. One example\n * is that for heterogeneous graphs, there is a single random seed passed for\n * each edge type. 
This will sample much fewer vertices compared to having\n * unique random seeds for each edge type. If one called this function\n * individually for each edge type for a heterogenous graph with different\n * random seeds, then it would run LABOR locally for each edge type, resulting\n * into a larger number of vertices being sampled.\n *\n * If this function is called without a random_seed, we get the random seed by\n * getting a random number from DGL.\n *\n *\n * Examples:\n *\n * // csr.num_rows = 4;\n * // csr.num_cols = 4;\n * // csr.indptr = [0, 2, 3, 3, 5]\n * // csr.indices = [0, 1, 1, 2, 3]\n * // csr.data = [2, 3, 0, 1, 4]\n * CSRMatrix csr = ...;\n * IdArray rows = ... ; // [1, 3]\n * COOMatrix sampled = CSRLaborSampling(csr, rows, 2, NullArray(), 0, \\\n *     NullArray(), NullArray());\n * // possible sampled coo matrix:\n * // sampled.num_rows = 4\n * // sampled.num_cols = 4\n * // sampled.rows = [1, 3, 3]\n * // sampled.cols = [1, 2, 3]\n * // sampled.data = [3, 0, 4]\n *\n * @param mat Input CSR matrix.\n * @param rows Rows to sample from.\n * @param num_samples Number of samples using labor sampling\n * @param prob Probability array for nonuniform sampling\n * @param importance_sampling Whether to enable importance sampling\n * @param random_seed The random seed for the sampler\n * @param seed2_contribution The contribution of the second random seed, [0, 1)\n * @param NIDs global nids if sampling from a subgraph\n * @return A pair of COOMatrix storing the picked row and col indices and edge\n *         weights if importance_sampling != 0 or prob argument was passed. 
Its\n *         data field stores the the index of the picked elements in the value\n *         array.\n */\nstd::pair<COOMatrix, FloatArray> CSRLaborSampling(\n    CSRMatrix mat, IdArray rows, int64_t num_samples,\n    FloatArray prob = NullArray(), int importance_sampling = 0,\n    IdArray random_seed = NullArray(), float seed2_contribution = 0,\n    IdArray NIDs = NullArray());\n\n/*!\n * @brief Randomly select a fixed number of non-zero entries along each given\n * row independently.\n *\n * The function performs random choices along each row independently.\n * The picked indices are returned in the form of a COO matrix.\n *\n * If replace is false and a row has fewer non-zero values than num_samples,\n * all the values are picked.\n *\n * Examples:\n *\n * // csr.num_rows = 4;\n * // csr.num_cols = 4;\n * // csr.indptr = [0, 2, 3, 3, 5]\n * // csr.indices = [0, 1, 1, 2, 3]\n * // csr.data = [2, 3, 0, 1, 4]\n * CSRMatrix csr = ...;\n * IdArray rows = ... ; // [1, 3]\n * COOMatrix sampled = CSRRowWiseSampling(csr, rows, 2, FloatArray(), false);\n * // possible sampled coo matrix:\n * // sampled.num_rows = 4\n * // sampled.num_cols = 4\n * // sampled.rows = [1, 3, 3]\n * // sampled.cols = [1, 2, 3]\n * // sampled.data = [3, 0, 4]\n *\n * @param mat Input CSR matrix.\n * @param rows Rows to sample from.\n * @param num_samples Number of samples\n * @param prob_or_mask Unnormalized probability array or mask array.\n *                     Should be of the same length as the data array.\n *                     If an empty array is provided, assume uniform.\n * @param replace True if sample with replacement\n * @return A COOMatrix storing the picked row, col and data indices.\n * @note The edges of the entire graph must be ordered by their edge types.\n */\nCOOMatrix CSRRowWiseSampling(\n    CSRMatrix mat, IdArray rows, int64_t num_samples,\n    NDArray prob_or_mask = NDArray(), bool replace = true);\n\n/*!\n * @brief Randomly select a fixed number of non-zero entries 
along each given\n * row independently.\n *\n * The function performs random choices along each row independently.\n * The picked indices are returned in the form of a CSR matrix, with\n * additional IdArray that is an extended version of CSR's index pointers.\n *\n * With template parameter set to True rows are also saved as new seed nodes and\n * mapped\n *\n * If replace is false and a row has fewer non-zero values than num_samples,\n * all the values are picked.\n *\n * Examples:\n *\n * // csr.num_rows = 4;\n * // csr.num_cols = 4;\n * // csr.indptr = [0, 2, 3, 3, 5]\n * // csr.indices = [0, 1, 1, 2, 3]\n * // csr.data = [2, 3, 0, 1, 4]\n * CSRMatrix csr = ...;\n * IdArray rows = ... ; // [1, 3]\n * IdArray seed_mapping = [-1, -1, -1, -1];\n * std::vector<IdType> new_seed_nodes = {};\n *\n * std::pair<CSRMatrix, IdArray> sampled = CSRRowWiseSamplingFused<\n *                                         typename IdType, True>(\n *                                         csr, rows, seed_mapping,\n *                                         new_seed_nodes, 2,\n *                                         FloatArray(), false);\n * // possible sampled csr matrix:\n * // sampled.first.num_rows = 2\n * // sampled.first.num_cols = 3\n * // sampled.first.indptr = [0, 1, 3]\n * // sampled.first.indices = [1, 2, 3]\n * // sampled.first.data = [0, 1, 4]\n * // sampled.second = [0, 1, 1]\n * // seed_mapping = [-1, 0, -1, 1];\n * // new_seed_nodes = {1, 3};\n *\n * @tparam IdType Graph's index data type, can be int32_t or int64_t\n * @tparam map_seed_nodes If set for true we map and copy rows to new_seed_nodes\n * @param mat Input CSR matrix.\n * @param rows Rows to sample from.\n * @param seed_mapping Mapping array used if map_seed_nodes=true. If so each row\n * from rows will be set to its position e.g. mapping[rows[i]] = i.\n * @param new_seed_nodes Vector used if map_seed_nodes=true. 
If so it will\n * contain rows.\n * @param rows Rows to sample from.\n * @param num_samples Number of samples\n * @param prob_or_mask Unnormalized probability array or mask array.\n *                     Should be of the same length as the data array.\n *                     If an empty array is provided, assume uniform.\n * @param replace True if sample with replacement\n * @return A CSRMatrix storing the picked row, col and data indices,\n *         COO version of picked rows\n * @note The edges of the entire graph must be ordered by their edge types,\n *       rows must be unique\n */\ntemplate <typename IdType, bool map_seed_nodes>\nstd::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(\n    CSRMatrix mat, IdArray rows, IdArray seed_mapping,\n    std::vector<IdType>* new_seed_nodes, int64_t num_samples,\n    NDArray prob_or_mask = NDArray(), bool replace = true);\n\n/**\n * @brief Randomly select a fixed number of non-zero entries for each edge type\n *        along each given row independently.\n *\n * The function performs random choices along each row independently.\n * In each row, num_samples samples is picked for each edge type. (The edge\n * type is stored in etypes)\n * The picked indices are returned in the form of a COO matrix.\n *\n * If replace is false and a row has fewer non-zero values than num_samples,\n * all the values are picked.\n *\n * Examples: TODO\n *\n * // csr.num_rows = 4;\n * // csr.num_cols = 4;\n * // csr.indptr = [0, 4, 4, 4, 5]\n * // csr.cols = [0, 1, 3, 2, 3]\n * // csr.data = [2, 3, 0, 1, 4]\n * // eid2etype_offset = [0, 3, 4, 5]\n * CSRMatrix csr = ...;\n * IdArray rows = ... 
; // [0, 3]\n * std::vector<int64_t> num_samples = {2, 2, 2};\n * COOMatrix sampled = CSRRowWisePerEtypeSampling(csr, rows, eid2etype_offset,\n * num_samples, FloatArray(), false);\n * // possible sampled coo matrix:\n * // sampled.num_rows = 4\n * // sampled.num_cols = 4\n * // sampled.rows = [0, 0, 0, 3]\n * // sampled.cols = [0, 3, 2, 3]\n * // sampled.data = [2, 0, 1, 4]\n *\n * @param mat Input CSR matrix.\n * @param rows Rows to sample from.\n * @param eid2etype_offset The offset to each edge type.\n * @param num_samples Number of samples to choose per edge type.\n * @param prob_or_mask Unnormalized probability array or mask array.\n *                     Should be of the same length as the data array.\n *                     If an empty array is provided, assume uniform.\n * @param replace True if sample with replacement\n * @param rowwise_etype_sorted whether the CSR column indices per row are\n * ordered by edge type.\n * @return A COOMatrix storing the picked row, col and data indices.\n * @note The edges must be ordered by their edge types.\n */\nCOOMatrix CSRRowWisePerEtypeSampling(\n    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_samples,\n    const std::vector<NDArray>& prob_or_mask, bool replace = true,\n    bool rowwise_etype_sorted = false);\n\n/**\n * @brief Select K non-zero entries with the largest weights along each given\n * row.\n *\n * The function performs top-k selection along each row independently.\n * The picked indices are returned in the form of a COO matrix.\n *\n * If replace is false and a row has fewer non-zero values than k,\n * all the values are picked.\n *\n * Examples:\n *\n * // csr.num_rows = 4;\n * // csr.num_cols = 4;\n * // csr.indptr = [0, 2, 3, 3, 5]\n * // csr.indices = [0, 1, 1, 2, 3]\n * // csr.data = [2, 3, 0, 1, 4]\n * CSRMatrix csr = ...;\n * IdArray rows = ... ;  // [0, 1, 3]\n * FloatArray weight = ... 
;  // [1., 0., -1., 10., 20.]\n * COOMatrix sampled = CSRRowWiseTopk(csr, rows, 1, weight);\n * // possible sampled coo matrix:\n * // sampled.num_rows = 4\n * // sampled.num_cols = 4\n * // sampled.rows = [0, 1, 3]\n * // sampled.cols = [1, 1, 2]\n * // sampled.data = [3, 0, 1]\n *\n * @param mat Input CSR matrix.\n * @param rows Rows to sample from.\n * @param k The K value.\n * @param weight Weight associated with each entry. Should be of the same length\n * as the data array. If an empty array is provided, assume uniform.\n * @param ascending If true, elements are sorted by ascending order, equivalent\n * to finding the K smallest values. Otherwise, find the K largest values.\n * @return A COOMatrix storing the picked row and col indices. Its data field\n * stores the index of the picked elements in the value array.\n */\nCOOMatrix CSRRowWiseTopk(\n    CSRMatrix mat, IdArray rows, int64_t k, FloatArray weight,\n    bool ascending = false);\n\n/**\n * @brief Randomly select a fixed number of non-zero entries along each given\n * row independently, where the probability of columns to be picked can be\n * biased according to its tag.\n *\n * Each column is assigned an integer tag which determines its probability to be\n * sampled. Users can assign different probability to different tags.\n *\n * This function only works with a CSR matrix sorted according to the tag so\n * that entries with the same column tag are arranged in a consecutive range,\n * and the input `tag_offset` represents the boundaries of these ranges.\n * However, the function itself will not check if the input matrix has been\n * sorted. 
It's the caller's responsibility to ensure the input matrix has been\n * sorted by `CSRSortByTag` (it will also return a NDArray `tag_offset` which\n * should be used as an input of this function).\n *\n * The picked indices are returned in the form of a COO matrix.\n *\n * If replace is false and a row has fewer non-zero values than num_samples,\n * all the values are picked.\n *\n * Examples:\n *\n * // csr.num_rows = 4;\n * // csr.num_cols = 4;\n * // csr.indptr = [0, 2, 4, 5, 5]\n * // csr.indices =                [1, 2, 2, 3, 3]\n * // tag of each element's column: 0, 0, 0, 1, 1\n * // tag_offset = [[0, 2, 2], [0, 1, 2], [0, 0, 1]]\n * // csr.data = [2, 3, 0, 1, 4]\n * // bias = [1.0, 0.0]\n * CSRMatrix mat = ...;\n * IdArray rows = ...; //[0, 1]\n * NDArray tag_offset = ...;\n * FloatArray bias = ...;\n * COOMatrix sampled = CSRRowWiseSamplingBiased(mat, rows, 1, bias);\n * // possible sampled coo matrix:\n * // sampled.num_rows = 4\n * // sampled.num_cols = 4\n * // sampled.rows = [0, 1]\n * // sampled.cols = [1, 2]\n * // sampled.data = [2, 0]\n * // Note that in this case, for row 1, the column 3 will never be picked as it\n * has tag 1 and the\n * // probability of tag 1 is 0.\n *\n *\n * @param mat Input CSR matrix.\n * @param rows Rows to sample from.\n * @param num_samples Number of samples.\n * @param tag_offset The boundaries of tags. Should be of the shape [num_row,\n * num_tags+1]\n * @param bias Unnormalized probability array. Should be of length num_tags\n * @param replace True if sample with replacement\n * @return A COOMatrix storing the picked row and col indices. 
Its data field\n * stores the index of the picked elements in the value array.\n *\n */\nCOOMatrix CSRRowWiseSamplingBiased(\n    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,\n    FloatArray bias, bool replace = true);\n\n/**\n * @brief Uniformly sample row-column pairs whose entries do not exist in the\n * given sparse matrix using rejection sampling.\n *\n * @note The number of samples returned may not necessarily be the number of\n * samples given.\n *\n * @param csr The CSR matrix.\n * @param num_samples The number of samples.\n * @param num_trials The number of trials.\n * @param exclude_self_loops Do not include the examples where the row equals\n * the column.\n * @param replace Whether to sample with replacement.\n * @param redundancy How many redundant negative examples to take in case of\n * duplicate examples.\n * @return A pair of row and column tensors.\n */\nstd::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(\n    const CSRMatrix& csr, int64_t num_samples, int num_trials,\n    bool exclude_self_loops, bool replace, double redundancy);\n\n/**\n * @brief Sort the column index according to the tag of each column.\n *\n * Example:\n * indptr  = [0, 5, 8]\n * indices = [0, 1, 2, 3, 4, 0, 1, 2]\n *\n * tag     = [1, 1, 0, 2, 0]\n *\n *  After CSRSortByTag\n *\n * indptr  = [0, 5, 8]\n * indices = [2, 4, 0, 1, 3, 2, 0, 1]\n * (tag)   = [0, 0, 1, 1, 2, 0, 1, 1]\n *           ^    ^     ^  ^\n *                         ^  ^     ^^\n * (the tag array itself is unchanged.)\n *\n * Return:\n * [[0, 2, 4, 5], [0, 1, 3, 3]] (marked with ^)\n *\n * @param csr The csr matrix to be sorted\n * @param tag_array Tag of each column. IdArray with length num_cols\n * @param num_tags Number of tags. It should be equal to max(tag_array)+1.\n * @return 1. A sorted copy of the given CSR matrix\n *         2. The split positions of different tags. 
NDArray of shape (num_rows,\n * num_tags + 1)\n */\nstd::pair<CSRMatrix, NDArray> CSRSortByTag(\n    const CSRMatrix& csr, const IdArray tag_array, int64_t num_tags);\n\n/**\n * @brief Union two CSRMatrix into one CSRMatrix.\n *\n * Two Matrix must have the same shape.\n *\n * Example:\n *\n * A = [[0, 0, 1, 0],\n *      [1, 0, 1, 1],\n *      [0, 1, 0, 0]]\n *\n * B = [[0, 1, 1, 0],\n *      [0, 0, 0, 1],\n *      [0, 0, 1, 0]]\n *\n * CSRMatrix_A.num_rows : 3\n * CSRMatrix_A.num_cols : 4\n * CSRMatrix_B.num_rows : 3\n * CSRMatrix_B.num_cols : 4\n *\n * C = UnionCsr({A, B});\n *\n * C = [[0, 1, 2, 0],\n *      [1, 0, 1, 2],\n *      [0, 1, 1, 0]]\n *\n * CSRMatrix_C.num_rows : 3\n * CSRMatrix_C.num_cols : 4\n */\nCSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);\n\n/**\n * @brief Union a list of CSRMatrix into one CSRMatrix.\n *\n * Examples:\n *\n * A = [[0, 0, 1],\n *      [1, 0, 1],\n *      [0, 1, 0]]\n *\n * B = [[0, 0],\n *      [1, 0]]\n *\n * CSRMatrix_A.num_rows : 3\n * CSRMatrix_A.num_cols : 3\n * CSRMatrix_B.num_rows : 2\n * CSRMatrix_B.num_cols : 2\n *\n * C = DisjointUnionCsr({A, B});\n *\n * C = [[0, 0, 1, 0, 0],\n *      [1, 0, 1, 0, 0],\n *      [0, 1, 0, 0, 0],\n *      [0, 0, 0, 0, 0],\n *      [0, 0, 0, 1, 0]]\n * CSRMatrix_C.num_rows : 5\n * CSRMatrix_C.num_cols : 5\n *\n * @param csrs The input list of csr matrix.\n * @param src_offset A list of integers recording src vertex id offset of each\n * Matrix in csrs\n * @param dst_offset A list of integers recording dst vertex id offset of each\n * Matrix in csrs\n * @return The combined CSRMatrix.\n */\nCSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs);\n\n/**\n * @brief CSRMatrix toSimple.\n *\n * A = [[0, 0, 0],\n *      [3, 0, 2],\n *      [1, 1, 0],\n *      [0, 0, 4]]\n *\n * B, cnt, edge_map = CSRToSimple(A)\n *\n * B = [[0, 0, 0],\n *      [1, 0, 1],\n *      [1, 1, 0],\n *      [0, 0, 1]]\n * cnt = [3, 2, 1, 1, 4]\n * edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4]\n *\n * 
@return The simplified CSRMatrix\n *         The count recording the number of duplicated edges from the original\n * graph. The edge mapping from the edge IDs of original graph to those of the\n *         returned graph.\n */\nstd::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr);\n\n/**\n * @brief Split a CSRMatrix into multiple disjoint components.\n *\n * Examples:\n *\n * C = [[0, 0, 1, 0, 0],\n *      [1, 0, 1, 0, 0],\n *      [0, 1, 0, 0, 0],\n *      [0, 0, 0, 0, 0],\n *      [0, 0, 0, 1, 0],\n *      [0, 0, 0, 0, 1]]\n * CSRMatrix_C.num_rows : 6\n * CSRMatrix_C.num_cols : 5\n *\n * batch_size : 2\n * edge_cumsum : [0, 4, 6]\n * src_vertex_cumsum : [0, 3, 6]\n * dst_vertex_cumsum : [0, 3, 5]\n *\n * ret = DisjointPartitionCsrBySizes(C,\n *                                   batch_size,\n *                                   edge_cumsum,\n *                                   src_vertex_cumsum,\n *                                   dst_vertex_cumsum)\n *\n * A = [[0, 0, 1],\n *      [1, 0, 1],\n *      [0, 1, 0]]\n * CSRMatrix_A.num_rows : 3\n * CSRMatrix_A.num_cols : 3\n *\n * B = [[0, 0],\n *      [1, 0],\n *      [0, 1]]\n * CSRMatrix_B.num_rows : 3\n * CSRMatrix_B.num_cols : 2\n *\n * @param csr CSRMatrix to split.\n * @param batch_size Number of disjoin components (Sub CSRMatrix)\n * @param edge_cumsum Number of edges of each components\n * @param src_vertex_cumsum Number of src vertices of each component.\n * @param dst_vertex_cumsum Number of dst vertices of each component.\n * @return A list of CSRMatrixes representing each disjoint components.\n */\nstd::vector<CSRMatrix> DisjointPartitionCsrBySizes(\n    const CSRMatrix& csrs, const uint64_t batch_size,\n    const std::vector<uint64_t>& edge_cumsum,\n    const std::vector<uint64_t>& src_vertex_cumsum,\n    const std::vector<uint64_t>& dst_vertex_cumsum);\n\n/**\n * @brief Slice a contiguous chunk from a CSRMatrix\n *\n * Examples:\n *\n * C = [[0, 0, 1, 0, 0],\n *      [1, 0, 1, 0, 
0],\n *      [0, 1, 0, 0, 0],\n *      [0, 0, 0, 0, 0],\n *      [0, 0, 0, 1, 0],\n *      [0, 0, 0, 0, 1]]\n * CSRMatrix_C.num_rows : 6\n * CSRMatrix_C.num_cols : 5\n *\n * edge_range : [4, 6]\n * src_vertex_range : [3, 6]\n * dst_vertex_range : [3, 5]\n *\n * ret = CSRSliceContiguousChunk(C,\n *                               edge_range,\n *                               src_vertex_range,\n *                               dst_vertex_range)\n *\n * ret = [[0, 0],\n *        [1, 0],\n *        [0, 1]]\n * CSRMatrix_ret.num_rows : 3\n * CSRMatrix_ret.num_cols : 2\n *\n * @param csr CSRMatrix to slice.\n * @param edge_range ID range of the edges in the chunk\n * @param src_vertex_range ID range of the src vertices in the chunk.\n * @param dst_vertex_range ID range of the dst vertices in the chunk.\n * @return CSRMatrix representing the chunk.\n */\nCSRMatrix CSRSliceContiguousChunk(\n    const CSRMatrix& csr, const std::vector<uint64_t>& edge_range,\n    const std::vector<uint64_t>& src_vertex_range,\n    const std::vector<uint64_t>& dst_vertex_range);\n\n/**\n * @brief Generalized Sparse Matrix-Matrix Multiplication on CSR.\n * @param op The binary operator, could be `add`, `sub`, `mul`, `div`,\n *        `copy_u`, `copy_e`.\n * @param reduce The reduce operator, could be `sum`, `min`, `max`.\n * @param csr The CSR we apply SpMM on.\n * @param ufeat The source node feature.\n * @param efeat The edge feature.\n * @param out The output feature on destination nodes.\n * @param out_aux A list of NDArray's that contains auxiliary information such\n *        as the argmax on source nodes and edges for reduce operators such as\n *        `min` and `max`.\n */\nvoid CSRSpMM(\n    const std::string& op, const std::string& reduce, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);\n\n/** @brief CSRSpMM C interface without std::string. 
*/\nvoid CSRSpMM(\n    const char* op, const char* reduce, const CSRMatrix& csr, NDArray ufeat,\n    NDArray efeat, NDArray out, std::vector<NDArray> out_aux);\n\n/**\n * @brief Generalized Sampled Dense-Dense Matrix Multiplication on CSR.\n * @param op The binary operator, could be `add`, `sub`, `mul`, `div`,\n *        `dot`, `copy_u`, `copy_e`.\n * @param csr The CSR we apply SDDMM on.\n * @param ufeat The source node feature.\n * @param efeat The right-hand operand feature (see rhs_target).\n * @param out The output feature on edge.\n * @param lhs_target Type of `ufeat` (0: source, 1: edge, 2: destination).\n * @param rhs_target Type of `efeat` (0: source, 1: edge, 2: destination).\n */\nvoid CSRSDDMM(\n    const std::string& op, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,\n    NDArray out, int lhs_target, int rhs_target);\n\n/** @brief CSRSDDMM C interface without std::string. */\nvoid CSRSDDMM(\n    const char* op, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,\n    NDArray out, int lhs_target, int rhs_target);\n\n}  // namespace aten\n}  // namespace dgl\n\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, dgl::aten::CSRMatrix, true);\n}  // namespace dmlc\n\n#endif  // DGL_ATEN_CSR_H_\n"
  },
  {
    "path": "include/dgl/aten/macro.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file dgl/aten/macro.h\n * @brief Common macros for aten package.\n */\n\n#ifndef DGL_ATEN_MACRO_H_\n#define DGL_ATEN_MACRO_H_\n\n///////////////////////// Dispatchers //////////////////////////\n\n/**\n * Dispatch according to device:\n *\n * ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {\n *   // Now XPU is a placeholder for array->ctx.device_type\n *   DeviceSpecificImplementation<XPU>(...);\n * });\n */\n#define ATEN_XPU_SWITCH(val, XPU, op, ...)                               \\\n  do {                                                                   \\\n    if ((val) == kDGLCPU) {                                              \\\n      constexpr auto XPU = kDGLCPU;                                      \\\n      { __VA_ARGS__ }                                                    \\\n    } else {                                                             \\\n      LOG(FATAL) << \"Operator \" << (op) << \" does not support \"          \\\n                 << dgl::runtime::DeviceTypeCode2Str(val) << \" device.\"; \\\n    }                                                                    \\\n  } while (0)\n\n/**\n * Dispatch according to device:\n *\n * XXX(minjie): temporary macro that allows CUDA operator\n *\n * ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {\n *   // Now XPU is a placeholder for array->ctx.device_type\n *   DeviceSpecificImplementation<XPU>(...);\n * });\n *\n * We treat pinned memory as normal host memory if we don't want\n * to enable CUDA UVA access for this operator\n */\n#ifdef DGL_USE_CUDA\n#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...)                          
\\\n  do {                                                                   \\\n    if ((val) == kDGLCPU) {                                              \\\n      constexpr auto XPU = kDGLCPU;                                      \\\n      { __VA_ARGS__ }                                                    \\\n    } else if ((val) == kDGLCUDA) {                                      \\\n      constexpr auto XPU = kDGLCUDA;                                     \\\n      { __VA_ARGS__ }                                                    \\\n    } else {                                                             \\\n      LOG(FATAL) << \"Operator \" << (op) << \" does not support \"          \\\n                 << dgl::runtime::DeviceTypeCode2Str(val) << \" device.\"; \\\n    }                                                                    \\\n  } while (0)\n#else  // DGL_USE_CUDA\n#define ATEN_XPU_SWITCH_CUDA ATEN_XPU_SWITCH\n#endif  // DGL_USE_CUDA\n\n/**\n * Dispatch according to integral type (either int32 or int64):\n *\n * ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {\n *   // Now IdType is the type corresponding to data type in array.\n *   // For instance, one can do this for a CPU array:\n *   DType *data = static_cast<DType *>(array->data);\n * });\n */\n#define ATEN_ID_TYPE_SWITCH(val, IdType, ...)                   
\\\n  do {                                                          \\\n    CHECK_EQ((val).code, kDGLInt) << \"ID must be integer type\"; \\\n    if ((val).bits == 32) {                                     \\\n      typedef int32_t IdType;                                   \\\n      { __VA_ARGS__ }                                           \\\n    } else if ((val).bits == 64) {                              \\\n      typedef int64_t IdType;                                   \\\n      { __VA_ARGS__ }                                           \\\n    } else {                                                    \\\n      LOG(FATAL) << \"ID can only be int32 or int64\";            \\\n    }                                                           \\\n  } while (0)\n\n/**\n * Dispatch according to bits (either int32 or int64):\n *\n * ATEN_ID_BITS_SWITCH(bits, IdType, {\n *   // Now IdType is the type corresponding to data type in array.\n *   // For instance, one can do this for a CPU array:\n *   DType *data = static_cast<DType *>(array->data);\n * });\n */\n#define ATEN_ID_BITS_SWITCH(bits, IdType, ...)                      
\\\n  do {                                                              \\\n    CHECK((bits) == 32 || (bits) == 64) << \"bits must be 32 or 64\"; \\\n    if ((bits) == 32) {                                             \\\n      typedef int32_t IdType;                                       \\\n      { __VA_ARGS__ }                                               \\\n    } else if ((bits) == 64) {                                      \\\n      typedef int64_t IdType;                                       \\\n      { __VA_ARGS__ }                                               \\\n    } else {                                                        \\\n      LOG(FATAL) << \"ID can only be int32 or int64\";                \\\n    }                                                               \\\n  } while (0)\n\n/**\n * Dispatch according to float type (either float32 or float64):\n *\n * ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, {\n *   // Now FloatType is the type corresponding to data type in array.\n *   // For instance, one can do this for a CPU array:\n *   FloatType *data = static_cast<FloatType *>(array->data);\n * });\n */\n#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...)               
\\\n  do {                                                                      \\\n    CHECK_EQ((val).code, kDGLFloat) << (val_name) << \" must be float type\"; \\\n    if ((val).bits == 32) {                                                 \\\n      typedef float FloatType;                                              \\\n      { __VA_ARGS__ }                                                       \\\n    } else if ((val).bits == 64) {                                          \\\n      typedef double FloatType;                                             \\\n      { __VA_ARGS__ }                                                       \\\n    } else {                                                                \\\n      LOG(FATAL) << (val_name) << \" can only be float32 or float64\";        \\\n    }                                                                       \\\n  } while (0)\n\n/**\n * Dispatch according to float type, including 16bits\n * (float16/bfloat16/float32/float64).\n */\n#ifdef DGL_USE_CUDA\n#if BF16_ENABLED\n#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...)   
\\\n  do {                                                                      \\\n    CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat))              \\\n        << (val_name) << \" must be float type\";                             \\\n    if ((val).bits == 32) {                                                 \\\n      typedef float FloatType;                                              \\\n      { __VA_ARGS__ }                                                       \\\n    } else if ((val).bits == 64) {                                          \\\n      typedef double FloatType;                                             \\\n      { __VA_ARGS__ }                                                       \\\n    } else if (                                                             \\\n        XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) {   \\\n      typedef __half FloatType;                                             \\\n      { __VA_ARGS__ }                                                       \\\n    } else if (                                                             \\\n        XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) {  \\\n      typedef __nv_bfloat16 FloatType;                                      \\\n      { __VA_ARGS__ }                                                       \\\n    } else if (                                                             \\\n        XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) {    \\\n      LOG(FATAL) << (val_name) << \" can't be float16 on CPU\";               \\\n    } else if (                                                             \\\n        XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) {   \\\n      typedef BFloat16 FloatType;                                           \\\n      { __VA_ARGS__ }                                                       \\\n    } else {                                                            
    \\\n      LOG(FATAL) << (val_name)                                              \\\n                 << \" can only be float16/bfloat16/float32/float64 on GPU\"; \\\n    }                                                                       \\\n  } while (0)\n#else  // BF16_ENABLED\n#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...)  \\\n  do {                                                                     \\\n    CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat))             \\\n        << (val_name) << \" must be float type\";                            \\\n    if ((val).bits == 32) {                                                \\\n      typedef float FloatType;                                             \\\n      { __VA_ARGS__ }                                                      \\\n    } else if ((val).bits == 64) {                                         \\\n      typedef double FloatType;                                            \\\n      { __VA_ARGS__ }                                                      \\\n    } else if (                                                            \\\n        XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) {  \\\n      typedef __half FloatType;                                            \\\n      { __VA_ARGS__ }                                                      \\\n    } else if (                                                            \\\n        XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \\\n      LOG(FATAL) << \"bfloat16 requires CUDA >= 11.0\";                      \\\n    } else if (                                                            \\\n        XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) {   \\\n      LOG(FATAL) << (val_name) << \" can't be float16 on CPU\";              \\\n    } else if (                                                            \\\n        XPU == kDGLCPU && (val).bits == 16 && 
(val).code == kDGLBfloat) {  \\\n      typedef BFloat16 FloatType;                                          \\\n      { __VA_ARGS__ }                                                      \\\n    } else {                                                               \\\n      LOG(FATAL) << (val_name)                                             \\\n                 << \" can only be float16/float32/float64 on GPU\";         \\\n    }                                                                      \\\n  } while (0)\n#endif  // BF16_ENABLED\n#else   // DGL_USE_CUDA\n#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \\\n  do {                                                                    \\\n    CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat))            \\\n        << (val_name) << \" must be float type\";                           \\\n    if ((val).bits == 32) {                                               \\\n      typedef float FloatType;                                            \\\n      { __VA_ARGS__ }                                                     \\\n    } else if ((val).bits == 64) {                                        \\\n      typedef double FloatType;                                           \\\n      { __VA_ARGS__ }                                                     \\\n    } else if (                                                           \\\n        XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \\\n      typedef BFloat16 FloatType;                                         \\\n      { __VA_ARGS__ }                                                     \\\n    } else {                                                              \\\n      LOG(FATAL) << (val_name)                                            \\\n                 << \" can only be bfloat16/float32/float64 on CPU\";       \\\n    }                                                                     \\\n  } while 
(0)\n#endif  // DGL_USE_CUDA\n\n/**\n * Dispatch according to data type (int32, int64, float32 or float64):\n *\n * ATEN_DTYPE_SWITCH(array->dtype, DType, {\n *   // Now DType is the type corresponding to data type in array.\n *   // For instance, one can do this for a CPU array:\n *   DType *data = static_cast<DType *>(array->data);\n * });\n */\n#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...)                 \\\n  do {                                                               \\\n    if ((val).code == kDGLInt && (val).bits == 32) {                 \\\n      typedef int32_t DType;                                         \\\n      { __VA_ARGS__ }                                                \\\n    } else if ((val).code == kDGLInt && (val).bits == 64) {          \\\n      typedef int64_t DType;                                         \\\n      { __VA_ARGS__ }                                                \\\n    } else if ((val).code == kDGLFloat && (val).bits == 32) {        \\\n      typedef float DType;                                           \\\n      { __VA_ARGS__ }                                                \\\n    } else if ((val).code == kDGLFloat && (val).bits == 64) {        \\\n      typedef double DType;                                          \\\n      { __VA_ARGS__ }                                                \\\n    } else {                                                         \\\n      LOG(FATAL) << (val_name)                                       \\\n                 << \" can only be int32, int64, float32 or float64\"; \\\n    }                                                                \\\n  } while (0)\n\n/**\n * Dispatch according to data type (int8, uint8, float32 or float64):\n *\n * ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(array->dtype, DType, {\n *   // Now DType is the type corresponding to data type in array.\n *   // For instance, one can do this for a CPU array:\n *   DType *data = static_cast<DType 
*>(array->data);\n * });\n */\n#define ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(val, DType, val_name, ...) \\\n  do {                                                               \\\n    if ((val).code == kDGLInt && (val).bits == 8) {                  \\\n      typedef int8_t DType;                                          \\\n      { __VA_ARGS__ }                                                \\\n    } else if ((val).code == kDGLUInt && (val).bits == 8) {          \\\n      typedef uint8_t DType;                                         \\\n      { __VA_ARGS__ }                                                \\\n    } else if ((val).code == kDGLFloat && (val).bits == 32) {        \\\n      typedef float DType;                                           \\\n      { __VA_ARGS__ }                                                \\\n    } else if ((val).code == kDGLFloat && (val).bits == 64) {        \\\n      typedef double DType;                                          \\\n      { __VA_ARGS__ }                                                \\\n    } else {                                                         \\\n      LOG(FATAL) << (val_name)                                       \\\n                 << \" can only be int8, uint8, float32 or float64\";  \\\n    }                                                                \\\n  } while (0)\n\n/**\n * Dispatch data type only based on bit-width (8-bit, 16-bit, 32-bit, 64-bit):\n *\n * ATEN_DTYPE_BITS_ONLY_SWITCH(array->dtype, DType, {\n *   // Now DType is the type which has the same bit-width with the\n *   // data type in array.\n *   // Do not use for computation, but only for read and write.\n *   // For instance, one can do this for a CPU array:\n *   DType *data = static_cast<DType *>(array->data);\n * });\n */\n#define ATEN_DTYPE_BITS_ONLY_SWITCH(val, DType, val_name, ...)       
\\\n  do {                                                               \\\n    if ((val).bits == 8) {                                           \\\n      typedef int8_t DType;                                          \\\n      { __VA_ARGS__ }                                                \\\n    } else if ((val).bits == 16) {                                   \\\n      typedef int16_t DType;                                         \\\n      { __VA_ARGS__ }                                                \\\n    } else if ((val).bits == 32) {                                   \\\n      typedef int32_t DType;                                         \\\n      { __VA_ARGS__ }                                                \\\n    } else if ((val).bits == 64) {                                   \\\n      typedef int64_t DType;                                         \\\n      { __VA_ARGS__ }                                                \\\n    } else {                                                         \\\n      LOG(FATAL) << (val_name)                                       \\\n                 << \" can only be 8-bit, 16-bit, 32-bit, or 64-bit\"; \\\n    }                                                                \\\n  } while (0)\n\n/**\n * Dispatch according to integral type of CSR graphs.\n * Identical to ATEN_ID_TYPE_SWITCH except for a different error message.\n */\n#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...)                    
\\\n  do {                                                            \\\n    if ((val).code == kDGLInt && (val).bits == 32) {              \\\n      typedef int32_t DType;                                      \\\n      { __VA_ARGS__ }                                             \\\n    } else if ((val).code == kDGLInt && (val).bits == 64) {       \\\n      typedef int64_t DType;                                      \\\n      { __VA_ARGS__ }                                             \\\n    } else {                                                      \\\n      LOG(FATAL) << \"CSR matrix data can only be int32 or int64\"; \\\n    }                                                             \\\n  } while (0)\n\n// Macro to dispatch according to device context and index type.\n#define ATEN_CSR_SWITCH(csr, XPU, IdType, op, ...)                     \\\n  ATEN_XPU_SWITCH((csr).indptr->ctx.device_type, XPU, op, {            \\\n    ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {{__VA_ARGS__}}); \\\n  });\n\n// Macro to dispatch according to device context and index type.\n#define ATEN_COO_SWITCH(coo, XPU, IdType, op, ...)                  \\\n  ATEN_XPU_SWITCH((coo).row->ctx.device_type, XPU, op, {            \\\n    ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, {{__VA_ARGS__}}); \\\n  });\n\n#define CHECK_VALID_CONTEXT(VAR1, VAR2)                          \\\n  CHECK(                                                         \\\n      ((VAR1)->ctx == (VAR2)->ctx) || (VAR1).IsPinned() ||       \\\n      ((VAR1).NumElements() == 0)) /* Let empty arrays pass */   \\\n      << \"Expected \" << (#VAR2) << \"(\" << (VAR2)->ctx << \")\"     \\\n      << \" to have the same device \"                             \\\n      << \"context as \" << (#VAR1) << \"(\" << (VAR1)->ctx << \"). 
\" \\\n      << \"Or \" << (#VAR1) << \"(\" << (VAR1)->ctx << \")\"           \\\n      << \" is pinned\";\n\n/**\n * Macro to dispatch according to the context of array and dtype of csr\n * to enable CUDA UVA ops.\n * Context check is covered here to avoid confusion with CHECK_SAME_CONTEXT.\n * If csr has the same context with array, same behivor as ATEN_CSR_SWITCH_CUDA.\n * If csr is pinned, array's context will conduct the actual operation.\n */\n#define ATEN_CSR_SWITCH_CUDA_UVA(csr, array, XPU, IdType, op, ...)       \\\n  do {                                                                   \\\n    CHECK_VALID_CONTEXT(csr.indices, array);                             \\\n    ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, op, {              \\\n      ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {{__VA_ARGS__}}); \\\n    });                                                                  \\\n  } while (0)\n\n// Macro to dispatch according to device context (allowing cuda)\n#ifdef DGL_USE_CUDA\n#define ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, op, ...)                \\\n  ATEN_XPU_SWITCH_CUDA((csr).indptr->ctx.device_type, XPU, op, {       \\\n    ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {{__VA_ARGS__}}); \\\n  });\n\n// Macro to dispatch according to device context and index type.\n#define ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, op, ...)             
\\\n  ATEN_XPU_SWITCH_CUDA((coo).row->ctx.device_type, XPU, op, {       \\\n    ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, {{__VA_ARGS__}}); \\\n  });\n#else  // DGL_USE_CUDA\n#define ATEN_CSR_SWITCH_CUDA ATEN_CSR_SWITCH\n#define ATEN_COO_SWITCH_CUDA ATEN_COO_SWITCH\n#endif  // DGL_USE_CUDA\n\n///////////////////////// Array checks //////////////////////////\n\n#define IS_INT32(a) ((a)->dtype.code == kDGLInt && (a)->dtype.bits == 32)\n#define IS_INT64(a) ((a)->dtype.code == kDGLInt && (a)->dtype.bits == 64)\n#define IS_FLOAT32(a) ((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 32)\n#define IS_FLOAT64(a) ((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 64)\n\n#define CHECK_IF(cond, prop, value_name, dtype_name)                           \\\n  CHECK(cond) << \"Expecting \" << (prop) << \" of \" << (value_name) << \" to be \" \\\n              << (dtype_name)\n\n#define CHECK_INT32(value, value_name) \\\n  CHECK_IF(IS_INT32(value), \"dtype\", value_name, \"int32\")\n#define CHECK_INT64(value, value_name) \\\n  CHECK_IF(IS_INT64(value), \"dtype\", value_name, \"int64\")\n#define CHECK_INT(value, value_name)                           \\\n  CHECK_IF(                                                    \\\n      IS_INT32(value) || IS_INT64(value), \"dtype\", value_name, \\\n      \"int32 or int64\")\n#define CHECK_FLOAT32(value, value_name) \\\n  CHECK_IF(IS_FLOAT32(value), \"dtype\", value_name, \"float32\")\n#define CHECK_FLOAT64(value, value_name) \\\n  CHECK_IF(IS_FLOAT64(value), \"dtype\", value_name, \"float64\")\n#define CHECK_FLOAT(value, value_name)                             \\\n  CHECK_IF(                                                        \\\n      IS_FLOAT32(value) || IS_FLOAT64(value), \"dtype\", value_name, \\\n      \"float32 or float64\")\n\n#define CHECK_NDIM(value, _ndim, value_name) \\\n  CHECK_IF((value)->ndim == (_ndim), \"ndim\", value_name, _ndim)\n\n#define CHECK_SAME_DTYPE(VAR1, VAR2)                                     \\\n  
CHECK((VAR1)->dtype == (VAR2)->dtype)                                  \\\n      << \"Expected \" << (#VAR2) << \" to be the same type as \" << (#VAR1) \\\n      << \"(\" << (VAR1)->dtype << \")\"                                     \\\n      << \". But got \" << (VAR2)->dtype << \".\";\n\n#define CHECK_SAME_CONTEXT(VAR1, VAR2)                                    \\\n  CHECK((VAR1)->ctx == (VAR2)->ctx)                                       \\\n      << \"Expected \" << (#VAR2) << \" to have the same device context as \" \\\n      << (#VAR1) << \"(\" << (VAR1)->ctx << \")\"                             \\\n      << \". But got \" << (VAR2)->ctx << \".\";\n\n#define CHECK_NO_OVERFLOW(dtype, val)                         \\\n  do {                                                        \\\n    if (sizeof(val) == 8 && (dtype).bits == 32)               \\\n      CHECK_LE((val), 0x7FFFFFFFL)                            \\\n          << \"int32 overflow for argument \" << (#val) << \".\"; \\\n  } while (0);\n\n#define CHECK_IS_ID_ARRAY(VAR)                                \\\n  CHECK((VAR)->ndim == 1 && (IS_INT32(VAR) || IS_INT64(VAR))) \\\n      << \"Expected argument \" << (#VAR) << \" to be an 1D integer array.\";\n\n#endif  // DGL_ATEN_MACRO_H_\n"
  },
  {
    "path": "include/dgl/aten/spmat.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file dgl/aten/spmat.h\n * @brief Sparse matrix definitions\n */\n#ifndef DGL_ATEN_SPMAT_H_\n#define DGL_ATEN_SPMAT_H_\n\n#include <string>\n#include <vector>\n\n#include \"../runtime/object.h\"\n#include \"./types.h\"\n\nnamespace dgl {\n\n/**\n * @brief Sparse format.\n */\nenum class SparseFormat {\n  kCOO = 1,\n  kCSR = 2,\n  kCSC = 3,\n};\n\n/**\n * @brief Sparse format codes\n */\nconst dgl_format_code_t ALL_CODE = 0x7;\nconst dgl_format_code_t ANY_CODE = 0x0;\nconst dgl_format_code_t COO_CODE = 0x1;\nconst dgl_format_code_t CSR_CODE = 0x2;\nconst dgl_format_code_t CSC_CODE = 0x4;\n\n// Parse sparse format from string.\ninline SparseFormat ParseSparseFormat(const std::string& name) {\n  if (name == \"coo\")\n    return SparseFormat::kCOO;\n  else if (name == \"csr\")\n    return SparseFormat::kCSR;\n  else if (name == \"csc\")\n    return SparseFormat::kCSC;\n  else\n    LOG(FATAL) << \"Sparse format not recognized\";\n  return SparseFormat::kCOO;\n}\n\n// Create string from sparse format.\ninline std::string ToStringSparseFormat(SparseFormat sparse_format) {\n  if (sparse_format == SparseFormat::kCOO)\n    return std::string(\"coo\");\n  else if (sparse_format == SparseFormat::kCSR)\n    return std::string(\"csr\");\n  else\n    return std::string(\"csc\");\n}\n\ninline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) {\n  std::vector<SparseFormat> ret;\n  if (code & COO_CODE) ret.push_back(SparseFormat::kCOO);\n  if (code & CSR_CODE) ret.push_back(SparseFormat::kCSR);\n  if (code & CSC_CODE) ret.push_back(SparseFormat::kCSC);\n  return ret;\n}\n\ninline dgl_format_code_t SparseFormatsToCode(\n    const std::vector<SparseFormat>& formats) {\n  dgl_format_code_t ret = 0;\n  for (auto format : formats) {\n    switch (format) {\n      case SparseFormat::kCOO:\n        ret |= COO_CODE;\n        break;\n      case SparseFormat::kCSR:\n        ret |= CSR_CODE;\n        break;\n    
  case SparseFormat::kCSC:\n        ret |= CSC_CODE;\n        break;\n      default:\n        LOG(FATAL) << \"Only support COO/CSR/CSC formats.\";\n    }\n  }\n  return ret;\n}\n\ninline std::string CodeToStr(dgl_format_code_t code) {\n  std::string ret = \"\";\n  if (code & COO_CODE) ret += \"coo \";\n  if (code & CSR_CODE) ret += \"csr \";\n  if (code & CSC_CODE) ret += \"csc \";\n  return ret;\n}\n\ninline SparseFormat DecodeFormat(dgl_format_code_t code) {\n  if (code & COO_CODE) return SparseFormat::kCOO;\n  if (code & CSC_CODE) return SparseFormat::kCSC;\n  return SparseFormat::kCSR;\n}\n\n// Sparse matrix object that is exposed to python API.\nstruct SparseMatrix : public runtime::Object {\n  // Sparse format.\n  int32_t format = 0;\n\n  // Shape of this matrix.\n  int64_t num_rows = 0, num_cols = 0;\n\n  // Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row,\n  // col, data}.\n  std::vector<IdArray> indices;\n\n  // Boolean flags.\n  // TODO(minjie): We might revisit this later to provide a more general\n  // solution. Currently, we only consider aten::COOMatrix and aten::CSRMatrix.\n  std::vector<bool> flags;\n\n  SparseMatrix() {}\n\n  SparseMatrix(\n      int32_t fmt, int64_t nrows, int64_t ncols,\n      const std::vector<IdArray>& idx, const std::vector<bool>& flg)\n      : format(fmt),\n        num_rows(nrows),\n        num_cols(ncols),\n        indices(idx),\n        flags(flg) {}\n\n  static constexpr const char* _type_key = \"aten.SparseMatrix\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(SparseMatrix, runtime::Object);\n};\n// Define SparseMatrixRef\nDGL_DEFINE_OBJECT_REF(SparseMatrixRef, SparseMatrix);\n\n}  // namespace dgl\n\n#endif  // DGL_ATEN_SPMAT_H_\n"
  },
  {
    "path": "include/dgl/aten/types.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file dgl/aten/types.h\n * @brief Array and ID types\n */\n#ifndef DGL_ATEN_TYPES_H_\n#define DGL_ATEN_TYPES_H_\n\n#include <cstdint>\n\n#include \"../runtime/ndarray.h\"\n\nnamespace dgl {\n\ntypedef uint64_t dgl_id_t;\ntypedef uint64_t dgl_type_t;\n/** @brief Type for dgl fomrat code, whose binary representation indices\n * which sparse format is in use and which is not.\n *\n * Suppose the binary representation is xyz, then\n * - x indicates whether csc is in use (1 for true and 0 for false).\n * - y indicates whether csr is in use.\n * - z indicates whether coo is in use.\n */\ntypedef uint8_t dgl_format_code_t;\n\nusing dgl::runtime::NDArray;\n\ntypedef NDArray IdArray;\ntypedef NDArray DegreeArray;\ntypedef NDArray BoolArray;\ntypedef NDArray IntArray;\ntypedef NDArray FloatArray;\ntypedef NDArray TypeArray;\n\nnamespace aten {\n\nstatic const DGLContext CPU{kDGLCPU, 0};\n\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ATEN_TYPES_H_\n"
  },
  {
    "path": "include/dgl/base_heterograph.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file dgl/heterograph_interface.h\n * @brief DGL heterogeneous graph index class.\n */\n\n#ifndef DGL_BASE_HETEROGRAPH_H_\n#define DGL_BASE_HETEROGRAPH_H_\n\n#include <algorithm>\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"./runtime/object.h\"\n#include \"array.h\"\n#include \"aten/spmat.h\"\n#include \"aten/types.h\"\n#include \"graph_interface.h\"\n\nnamespace dgl {\n\n// Forward declaration\nclass BaseHeteroGraph;\nclass HeteroPickleStates;\ntypedef std::shared_ptr<BaseHeteroGraph> HeteroGraphPtr;\n\nstruct FlattenedHeteroGraph;\ntypedef std::shared_ptr<FlattenedHeteroGraph> FlattenedHeteroGraphPtr;\n\nstruct HeteroSubgraph;\n\n/** @brief Enum class for edge direction */\nenum class EdgeDir {\n  kIn,  // in edge direction\n  kOut  // out edge direction\n};\n\n/**\n * @brief Base heterogenous graph.\n *\n * In heterograph, nodes represent entities and edges represent relations.\n * Nodes and edges are associated with types. 
The same pair of entity types\n * can have multiple relation types between them, but relation type **uniquely**\n * identifies the source and destination entity types.\n *\n * In a high-level, a heterograph is a data structure composed of:\n *  - A meta-graph that stores the entity-entity relation graph.\n *  - A dictionary of relation type to the bipartite graph representing the\n *    actual connections among entity nodes.\n */\nclass BaseHeteroGraph : public runtime::Object {\n public:\n  explicit BaseHeteroGraph(GraphPtr meta_graph) : meta_graph_(meta_graph) {}\n\n  virtual ~BaseHeteroGraph() = default;\n\n  ////////////////////// query/operations on meta graph ///////////////////////\n\n  /** @return the number of vertex types */\n  virtual uint64_t NumVertexTypes() const { return meta_graph_->NumVertices(); }\n\n  /** @return the number of edge types */\n  virtual uint64_t NumEdgeTypes() const { return meta_graph_->NumEdges(); }\n\n  /** @return given the edge type, find the source type */\n  virtual std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes(\n      dgl_type_t etype) const {\n    return meta_graph_->FindEdge(etype);\n  }\n\n  /** @return the meta graph */\n  virtual GraphPtr meta_graph() const { return meta_graph_; }\n\n  /**\n   * @brief Return the bipartite graph of the given edge type.\n   * @param etype The edge type.\n   * @return The bipartite graph.\n   */\n  virtual HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const = 0;\n\n  ///////////////////// query/operations on realized graph /////////////////////\n\n  /** @brief Add vertices to the given vertex type */\n  virtual void AddVertices(dgl_type_t vtype, uint64_t num_vertices) = 0;\n\n  /** @brief Add one edge to the given edge type */\n  virtual void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) = 0;\n\n  /** @brief Add edges to the given edge type */\n  virtual void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) = 0;\n\n  /**\n   * @brief Clear the graph. 
Remove all vertices/edges.\n   */\n  virtual void Clear() = 0;\n\n  /**\n   * @brief Get the data type of node and edge IDs of this graph.\n   */\n  virtual DGLDataType DataType() const = 0;\n\n  /**\n   * @brief Get the device context of this graph.\n   */\n  virtual DGLContext Context() const = 0;\n\n  /**\n   * @brief Pin graph.\n   */\n  virtual void PinMemory_() = 0;\n\n  /**\n   * @brief Check if this graph is pinned.\n   */\n  virtual bool IsPinned() const = 0;\n\n  /**\n   * @brief Record stream for this graph.\n   * @param stream The stream that is using the graph\n   */\n  virtual void RecordStream(DGLStreamHandle stream) = 0;\n\n  /**\n   * @brief Get the number of integer bits used to store node/edge ids (32 or\n   * 64).\n   */\n  // TODO(BarclayII) replace NumBits() calls to DataType() calls\n  virtual uint8_t NumBits() const = 0;\n\n  /**\n   * @return whether the graph is a multigraph\n   */\n  virtual bool IsMultigraph() const = 0;\n\n  /** @return whether the graph is read-only */\n  virtual bool IsReadonly() const = 0;\n\n  /** @return the number of vertices in the graph.*/\n  virtual uint64_t NumVertices(dgl_type_t vtype) const = 0;\n\n  /** @return the number of vertices for each type in the graph as a vector */\n  inline virtual std::vector<int64_t> NumVerticesPerType() const {\n    LOG(FATAL) << \"[BUG] NumVerticesPerType() not supported on this object.\";\n    return {};\n  }\n\n  /** @return the number of edges in the graph.*/\n  virtual uint64_t NumEdges(dgl_type_t etype) const = 0;\n\n  /** @return true if the given vertex is in the graph.*/\n  virtual bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const = 0;\n\n  /** @return a 0-1 array indicating whether the given vertices are in the\n   * graph.\n   */\n  virtual BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const = 0;\n\n  /** @return true if the given edge is in the graph.*/\n  virtual bool HasEdgeBetween(\n      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;\n\n  
/** @return a 0-1 array indicating whether the given edges are in the graph.*/\n  virtual BoolArray HasEdgesBetween(\n      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const = 0;\n\n  /**\n   * @brief Find the predecessors of a vertex.\n   * @note The given vertex should belong to the source vertex type\n   *       of the given edge type.\n   * @param etype The edge type\n   * @param vid The vertex id.\n   * @return the predecessor id array.\n   */\n  virtual IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const = 0;\n\n  /**\n   * @brief Find the successors of a vertex.\n   * @note The given vertex should belong to the dest vertex type\n   *       of the given edge type.\n   * @param etype The edge type\n   * @param vid The vertex id.\n   * @return the successor id array.\n   */\n  virtual IdArray Successors(dgl_type_t etype, dgl_id_t src) const = 0;\n\n  /**\n   * @brief Get all edge ids between the two given endpoints\n   * @note The given src and dst vertices should belong to the source vertex\n   * type and the dest vertex type of the given edge type, respectively.\n   * @param etype The edge type\n   * @param src The source vertex.\n   * @param dst The destination vertex.\n   * @return the edge id array.\n   */\n  virtual IdArray EdgeId(\n      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;\n\n  /**\n   * @brief Get all edge ids between the given endpoint pairs.\n   *\n   * @param etype The edge type\n   * @param src The src vertex ids.\n   * @param dst The dst vertex ids.\n   * @return EdgeArray containing all edges between all pairs.\n   */\n  virtual EdgeArray EdgeIdsAll(\n      dgl_type_t etype, IdArray src, IdArray dst) const = 0;\n\n  /**\n   * @brief Get edge ids between the given endpoint pairs.\n   *\n   * Only find one matched edge Ids even if there are multiple matches due to\n   * parallel edges. 
The i^th Id in the returned array is for edge (src[i],\n   * dst[i]).\n   *\n   * @param etype The edge type\n   * @param src The src vertex ids.\n   * @param dst The dst vertex ids.\n   * @return IdArray containing one matched edge id per (src[i], dst[i]) pair.\n   */\n  virtual IdArray EdgeIdsOne(\n      dgl_type_t etype, IdArray src, IdArray dst) const = 0;\n\n  /**\n   * @brief Find the edge ID and return the pair of endpoints\n   * @param etype The edge type\n   * @param eid The edge ID\n   * @return a pair whose first element is the source and the second the\n   * destination.\n   */\n  virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(\n      dgl_type_t etype, dgl_id_t eid) const = 0;\n\n  /**\n   * @brief Find the edge IDs and return their source and target node IDs.\n   * @param etype The edge type\n   * @param eids The edge ID array.\n   * @return EdgeArray containing all edges with id in eid.  The order is\n   * preserved.\n   */\n  virtual EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const = 0;\n\n  /**\n   * @brief Get the in edges of the vertex.\n   * @note The given vertex should belong to the dest vertex type\n   *       of the given edge type.\n   * @param etype The edge type\n   * @param vid The vertex id.\n   * @return the edges\n   */\n  virtual EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Get the in edges of the vertices.\n   * @note The given vertex should belong to the dest vertex type\n   *       of the given edge type.\n   * @param etype The edge type\n   * @param vids The vertex id array.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  virtual EdgeArray InEdges(dgl_type_t etype, IdArray vids) const = 0;\n\n  /**\n   * @brief Get the out edges of the vertex.\n   * @note The given vertex should belong to the source vertex type\n   *       of the given edge type.\n   * @param etype The edge type\n   * @param vid The vertex id.\n   * @return the id arrays of the two endpoints of the edges.\n   
*/\n  virtual EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Get the out edges of the vertices.\n   * @note The given vertex should belong to the source vertex type\n   *       of the given edge type.\n   * @param etype The edge type\n   * @param vids The vertex id array.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  virtual EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const = 0;\n\n  /**\n   * @brief Get all the edges in the graph.\n   * @note If order is \"srcdst\", the returned edges list is sorted by their src\n   * and dst ids. If order is \"eid\", they are in their edge id order. Otherwise,\n   * in the arbitrary order.\n   * @param etype The edge type\n   * @param order The order of the returned edge list.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  virtual EdgeArray Edges(\n      dgl_type_t etype, const std::string& order = \"\") const = 0;\n\n  /**\n   * @brief Get the in degree of the given vertex.\n   * @note The given vertex should belong to the dest vertex type of the given\n   * edge type.\n   * @param etype The edge type\n   * @param vid The vertex id.\n   * @return the in degree\n   */\n  virtual uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Get the in degrees of the given vertices.\n   * @note The given vertex should belong to the dest vertex type of the given\n   * edge type.\n   * @param etype The edge type\n   * @param vid The vertex id array.\n   * @return the in degree array\n   */\n  virtual DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const = 0;\n\n  /**\n   * @brief Get the out degree of the given vertex.\n   * @note The given vertex should belong to the source vertex type of the given\n   * edge type.\n   * @param etype The edge type\n   * @param vid The vertex id.\n   * @return the out degree\n   */\n  virtual uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Get 
the out degrees of the given vertices.\n   * @note The given vertex should belong to the source vertex type of the given\n   * edge type.\n   * @param etype The edge type\n   * @param vid The vertex id array.\n   * @return the out degree array\n   */\n  virtual DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const = 0;\n\n  /**\n   * @brief Return the successor vector\n   * @note The given vertex should belong to the source vertex type of the given\n   * edge type.\n   * @param vid The vertex id.\n   * @return the successor vector iterator pair.\n   */\n  virtual DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Return the out edge id vector\n   * @note The given vertex should belong to the source vertex type of the given\n   * edge type.\n   * @param vid The vertex id.\n   * @return the out edge id vector iterator pair.\n   */\n  virtual DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Return the predecessor vector\n   * @note The given vertex should belong to the dest vertex type of the given\n   * edge type.\n   * @param vid The vertex id.\n   * @return the predecessor vector iterator pair.\n   */\n  virtual DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Return the in edge id vector\n   * @note The given vertex should belong to the dest vertex type of the given\n   * edge type.\n   * @param vid The vertex id.\n   * @return the in edge id vector iterator pair.\n   */\n  virtual DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Get the adjacency matrix of the graph.\n   *\n   * TODO(minjie): deprecate this interface; replace it with GetXXXMatrix.\n   *\n   * By default, a row of returned adjacency matrix represents the destination\n   * of an edge and the column represents the source.\n   *\n   * If the fmt is 'csr', the function should return three arrays, representing\n   *  indptr, indices and edge ids\n   *\n   
* If the fmt is 'coo', the function should return one array of shape (2,\n   * nnz), representing a horizontal stack of row and col indices.\n   *\n   * @param transpose A flag to transpose the returned adjacency matrix.\n   * @param fmt the format of the returned adjacency matrix.\n   * @return a vector of IdArrays.\n   */\n  virtual std::vector<IdArray> GetAdj(\n      dgl_type_t etype, bool transpose, const std::string& fmt) const = 0;\n\n  /**\n   * @brief Determine which format to use with a preference.\n   *\n   * Otherwise, it will return whatever DGL thinks is the most appropriate given\n   * the arguments.\n   *\n   * @param etype Edge type.\n   * @param preferred_formats Preferred sparse formats.\n   * @return Available sparse format.\n   */\n  virtual SparseFormat SelectFormat(\n      dgl_type_t etype, dgl_format_code_t preferred_formats) const = 0;\n\n  /**\n   * @brief Return sparse formats already created for the graph.\n   *\n   * @return a number of type dgl_format_code_t.\n   */\n  virtual dgl_format_code_t GetCreatedFormats() const = 0;\n\n  /**\n   * @brief Return allowed sparse formats for the graph.\n   *\n   * @return a number of type dgl_format_code_t.\n   */\n  virtual dgl_format_code_t GetAllowedFormats() const = 0;\n\n  /**\n   * @brief Return the graph in specified available formats.\n   *\n   * @return The new graph.\n   */\n  virtual HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const = 0;\n\n  /**\n   * @brief Get adjacency matrix in COO format.\n   * @param etype Edge type.\n   * @return COO matrix.\n   */\n  virtual aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const = 0;\n\n  /**\n   * @brief Get adjacency matrix in CSR format.\n   *\n   * The row and column sizes are equal to the number of dsttype and srctype\n   * nodes, respectively.\n   *\n   * @param etype Edge type.\n   * @return CSR matrix.\n   */\n  virtual aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const = 0;\n\n  /**\n   * @brief Get adjacency matrix in 
CSC format.\n   *\n   * A CSC matrix is equivalent to the transpose of a CSR matrix.\n   * We reuse the CSRMatrix data structure as return value. The row and column\n   * sizes are equal to the number of dsttype and srctype nodes, respectively.\n   *\n   * @param etype Edge type.\n   * @return A CSR matrix.\n   */\n  virtual aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const = 0;\n\n  /**\n   * @brief Extract the induced subgraph by the given vertices.\n   *\n   * The length of the given vector should be equal to the number of vertex\n   * types. Empty arrays can be provided if no vertex is needed for the type.\n   * The result subgraph has the same meta graph with the parent, but some types\n   * can have no node/edge.\n   *\n   * @param vids the induced vertices per type.\n   * @return the subgraph.\n   */\n  virtual HeteroSubgraph VertexSubgraph(\n      const std::vector<IdArray>& vids) const = 0;\n\n  /**\n   * @brief Extract the induced subgraph by the given edges.\n   *\n   * The length of the given vector should be equal to the number of edge types.\n   * Empty arrays can be provided if no edge is needed for the type. 
The result\n   * subgraph has the same meta graph with the parent, but some types can have\n   * no node/edge.\n   *\n   * @param eids The edges in the subgraph.\n   * @param preserve_nodes If true, the vertices will not be relabeled, so some\n   * vertices may have no incident edges.\n   * @return the subgraph.\n   */\n  virtual HeteroSubgraph EdgeSubgraph(\n      const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0;\n\n  /**\n   * @brief Convert the list of requested unitgraph graphs into a single\n   * unitgraph graph.\n   *\n   * @param etypes The list of edge type IDs.\n   * @return The flattened graph, with induced source/edge/destination\n   * types/IDs.\n   */\n  virtual FlattenedHeteroGraphPtr Flatten(\n      const std::vector<dgl_type_t>& etypes) const {\n    LOG(FATAL) << \"Flatten operation unsupported\";\n    return nullptr;\n  }\n\n  /** @brief Cast this graph to immutable graph */\n  virtual GraphPtr AsImmutableGraph() const {\n    LOG(FATAL) << \"AsImmutableGraph not supported.\";\n    return nullptr;\n  }\n\n  static constexpr const char* _type_key = \"graph.HeteroGraph\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object);\n\n protected:\n  /** @brief meta graph */\n  GraphPtr meta_graph_;\n\n  // empty constructor\n  BaseHeteroGraph() {}\n};\n\n// Define HeteroGraphRef\nDGL_DEFINE_OBJECT_REF(HeteroGraphRef, BaseHeteroGraph);\n\n/**\n * @brief Hetero-subgraph data structure.\n *\n * This class can be used as arguments and return values of a C API.\n *\n * <code>\n *   DGL_REGISTER_GLOBAL(\"some_c_api\")\n *   .set_body([] (DGLArgs args, DGLRetValue* rv) {\n *     HeteroSubgraphRef subg = args[0];\n *     std::shared_ptr<HeteroSubgraph> ret = do_something( ... );\n *     *rv = HeteroSubgraphRef(ret);\n *   });\n * </code>\n */\nstruct HeteroSubgraph : public runtime::Object {\n  /** @brief The heterograph. 
*/\n  HeteroGraphPtr graph;\n  /**\n   * @brief The induced vertex ids of each entity type.\n   * The vector length is equal to the number of vertex types in the parent\n   * graph. Each array i has the same length as the number of vertices in type\n   * i. Empty array is allowed if the mapping is identity.\n   */\n  std::vector<IdArray> induced_vertices;\n  /**\n   * @brief The induced edge ids of each relation type.\n   * The vector length is equal to the number of edge types in the parent graph.\n   * Each array i has the same length as the number of edges in type i.\n   * Empty array is allowed if the mapping is identity.\n   */\n  std::vector<IdArray> induced_edges;\n\n  static constexpr const char* _type_key = \"graph.HeteroSubgraph\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(HeteroSubgraph, runtime::Object);\n};\n\n// Define HeteroSubgraphRef\nDGL_DEFINE_OBJECT_REF(HeteroSubgraphRef, HeteroSubgraph);\n\n/** @brief The flattened heterograph */\nstruct FlattenedHeteroGraph : public runtime::Object {\n  /** @brief The graph */\n  HeteroGraphRef graph;\n  /**\n   * @brief Mapping from source node ID to node type in parent graph\n   * @note The induced type array guarantees that the same type always appear\n   * contiguously.\n   */\n  IdArray induced_srctype;\n  /**\n   * @brief The set of node types in parent graph appearing in source nodes.\n   */\n  IdArray induced_srctype_set;\n  /** @brief Mapping from source node ID to local node ID in parent graph */\n  IdArray induced_srcid;\n  /**\n   * @brief Mapping from edge ID to edge type in parent graph\n   * @note The induced type array guarantees that the same type always appear\n   * contiguously.\n   */\n  IdArray induced_etype;\n  /**\n   * @brief The set of edge types in parent graph appearing in edges.\n   */\n  IdArray induced_etype_set;\n  /** @brief Mapping from edge ID to local edge ID in parent graph */\n  IdArray induced_eid;\n  /**\n   * @brief Mapping from destination node ID to node type in parent graph\n   
* @note The induced type array guarantees that the same type always appear\n   * contiguously.\n   */\n  IdArray induced_dsttype;\n  /**\n   * @brief The set of node types in parent graph appearing in destination\n   * nodes.\n   */\n  IdArray induced_dsttype_set;\n  /** @brief Mapping from destination node ID to local node ID in parent graph\n   */\n  IdArray induced_dstid;\n\n  void VisitAttrs(runtime::AttrVisitor* v) final {\n    v->Visit(\"graph\", &graph);\n    v->Visit(\"induced_srctype\", &induced_srctype);\n    v->Visit(\"induced_srctype_set\", &induced_srctype_set);\n    v->Visit(\"induced_srcid\", &induced_srcid);\n    v->Visit(\"induced_etype\", &induced_etype);\n    v->Visit(\"induced_etype_set\", &induced_etype_set);\n    v->Visit(\"induced_eid\", &induced_eid);\n    v->Visit(\"induced_dsttype\", &induced_dsttype);\n    v->Visit(\"induced_dsttype_set\", &induced_dsttype_set);\n    v->Visit(\"induced_dstid\", &induced_dstid);\n  }\n\n  static constexpr const char* _type_key = \"graph.FlattenedHeteroGraph\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(FlattenedHeteroGraph, runtime::Object);\n};\nDGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);\n\n// Declarations of functions and algorithms\n\n/**\n * @brief Create a heterograph from meta graph and a list of bipartite graph,\n * additionally specifying number of nodes per type.\n */\nHeteroGraphPtr CreateHeteroGraph(\n    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,\n    const std::vector<int64_t>& num_nodes_per_type = {});\n\n/**\n * @brief Create a heterograph from COO input.\n * @param num_vtypes Number of vertex types. 
Must be 1 or 2.\n * @param num_src Number of nodes in the source type.\n * @param num_dst Number of nodes in the destination type.\n * @param row Src node ids of the edges.\n * @param col Dst node ids of the edges.\n * @param row_sorted Whether the `row` array is in sorted ascending order.\n * @param col_sorted When `row_sorted` is true, whether the columns within each\n * row are also sorted. When `row_sorted` is false, this flag must also be\n * false.\n * @param formats Sparse formats used for storing this graph.\n * @return A heterograph pointer.\n */\nHeteroGraphPtr CreateFromCOO(\n    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,\n    IdArray col, bool row_sorted = false, bool col_sorted = false,\n    dgl_format_code_t formats = ALL_CODE);\n\n/**\n * @brief Create a heterograph from COO input.\n * @param num_vtypes Number of vertex types. Must be 1 or 2.\n * @param mat The COO matrix\n * @param formats Sparse formats used for storing this graph.\n * @return A heterograph pointer.\n */\nHeteroGraphPtr CreateFromCOO(\n    int64_t num_vtypes, const aten::COOMatrix& mat,\n    dgl_format_code_t formats = ALL_CODE);\n\n/**\n * @brief Create a heterograph from CSR input.\n * @param num_vtypes Number of vertex types. Must be 1 or 2.\n * @param num_src Number of nodes in the source type.\n * @param num_dst Number of nodes in the destination type.\n * @param indptr Indptr array\n * @param indices Indices array\n * @param edge_ids Edge ids\n * @param formats Sparse formats for storing this graph.\n * @return A heterograph pointer.\n */\nHeteroGraphPtr CreateFromCSR(\n    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,\n    IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);\n\n/**\n * @brief Create a heterograph from CSR input.\n * @param num_vtypes Number of vertex types. 
Must be 1 or 2.\n * @param mat The CSR matrix\n * @param formats Sparse formats for storing this graph.\n * @return A heterograph pointer.\n */\nHeteroGraphPtr CreateFromCSR(\n    int64_t num_vtypes, const aten::CSRMatrix& mat,\n    dgl_format_code_t formats = ALL_CODE);\n\n/**\n * @brief Create a heterograph from CSC input.\n * @param num_vtypes Number of vertex types. Must be 1 or 2.\n * @param num_src Number of nodes in the source type.\n * @param num_dst Number of nodes in the destination type.\n * @param indptr Indptr array\n * @param indices Indices array\n * @param edge_ids Edge ids\n * @param formats Sparse formats used for storing this graph.\n * @return A heterograph pointer.\n */\nHeteroGraphPtr CreateFromCSC(\n    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,\n    IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);\n\n/**\n * @brief Create a heterograph from CSC input.\n * @param num_vtypes Number of vertex types. Must be 1 or 2.\n * @param mat The CSC matrix\n * @param formats Sparse formats available for storing this graph.\n * @return A heterograph pointer.\n */\nHeteroGraphPtr CreateFromCSC(\n    int64_t num_vtypes, const aten::CSRMatrix& mat,\n    dgl_format_code_t formats = ALL_CODE);\n\n/**\n * @brief Extract the subgraph of the in edges of the given nodes.\n * @param graph Graph\n * @param nodes Node IDs of each type\n * @param relabel_nodes Whether to remove isolated nodes and relabel the rest\n * ones\n * @return Subgraph containing only the in edges. 
The returned graph has\n * the same schema as the original one.\n */\nHeteroSubgraph InEdgeGraph(\n    const HeteroGraphPtr graph, const std::vector<IdArray>& nodes,\n    bool relabel_nodes = false);\n\n/**\n * @brief Extract the subgraph of the out edges of the given nodes.\n * @param graph Graph\n * @param nodes Node IDs of each type\n * @param relabel_nodes Whether to remove isolated nodes and relabel the rest\n * ones\n * @return Subgraph containing only the out edges. The returned graph has\n * the same schema as the original one.\n */\nHeteroSubgraph OutEdgeGraph(\n    const HeteroGraphPtr graph, const std::vector<IdArray>& nodes,\n    bool relabel_nodes = false);\n\n/**\n * @brief Joint union multiple graphs into one graph.\n *\n * All input graphs should have the same metagraph.\n *\n * TODO(xiangsx): remove the meta_graph argument\n *\n * @param meta_graph Metagraph of the inputs and result.\n * @param component_graphs Input graphs\n * @return One graph that unions all the components\n */\nHeteroGraphPtr JointUnionHeteroGraph(\n    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);\n\n/**\n * @brief Union multiple graphs into one with each input graph as one disjoint\n * component.\n *\n * All input graphs should have the same metagraph.\n *\n * TODO(minjie): remove the meta_graph argument\n *\n * @tparam IdType Graph's index data type, can be int32_t or int64_t\n * @param meta_graph Metagraph of the inputs and result.\n * @param component_graphs Input graphs\n * @return One graph that unions all the components\n */\ntemplate <class IdType>\nHeteroGraphPtr DisjointUnionHeteroGraph(\n    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);\n\nHeteroGraphPtr DisjointUnionHeteroGraph2(\n    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);\n\n/**\n * @brief Slice a contiguous subgraph, e.g. 
retrieve a component graph from a\n * batched graph.\n *\n * TODO(mufei): remove the meta_graph argument\n *\n * @param meta_graph Metagraph of the input and result.\n * @param batched_graph Input graph.\n * @param num_nodes_per_type Number of vertices of each type in the result.\n * @param start_nid_per_type Start vertex ID of each type to slice.\n * @param num_edges_per_type Number of edges of each type in the result.\n * @param start_eid_per_type Start edge ID of each type to slice.\n * @return Sliced graph\n */\nHeteroGraphPtr SliceHeteroGraph(\n    GraphPtr meta_graph, HeteroGraphPtr batched_graph,\n    IdArray num_nodes_per_type, IdArray start_nid_per_type,\n    IdArray num_edges_per_type, IdArray start_eid_per_type);\n\n/**\n * @brief Split a graph into multiple disjoint components.\n *\n * Edges across different components are ignored. All the result graphs have the\n * same metagraph as the input one.\n *\n * The `vertex_sizes` and `edge_sizes` arrays are concatenations of arrays of\n * each node/edge type. 
Suppose there are N vertex types, then the array length\n * should be B*N, where B is the number of components to split.\n *\n * TODO(minjie): remove the meta_graph argument; use vector<IdArray> for\n * vertex_sizes and edge_sizes.\n *\n * @tparam IdType Graph's index data type, can be int32_t or int64_t\n * @param meta_graph Metagraph.\n * @param batched_graph Input graph.\n * @param vertex_sizes Number of vertices of each component.\n * @param edge_sizes Number of edges of each component.\n * @return A list of graphs representing each disjoint component.\n */\ntemplate <class IdType>\nstd::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(\n    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,\n    IdArray edge_sizes);\n\nstd::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(\n    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,\n    IdArray edge_sizes);\n\n/**\n * @brief Structure for pickle/unpickle.\n *\n * The design principle is to leverage the NDArray class as much as possible so\n * that when they are converted to backend-specific tensors, we could leverage\n * the efficient pickle/unpickle solutions from the backend framework.\n *\n * NOTE(minjie): This is a temporary solution before we support shared memory\n *   storage ourselves.\n *\n * This class can be used as arguments and return values of a C API.\n */\nstruct HeteroPickleStates : public runtime::Object {\n  /** @brief version number */\n  int64_t version = 0;\n\n  /** @brief Metainformation\n   *\n   * metagraph, number of nodes per type, format, flags\n   */\n  std::string meta;\n\n  /** @brief Arrays representing graph structure (coo or csr) */\n  std::vector<IdArray> arrays;\n\n  /* To support backward compatibility, we have to retain fields in the old\n   * version of HeteroPickleStates\n   */\n\n  /** @brief Metagraph(64bits ImmutableGraph) */\n  GraphPtr metagraph;\n\n  /** @brief Number of nodes per type */\n  
std::vector<int64_t> num_nodes_per_type;\n\n  /** @brief adjacency matrices of each relation graph */\n  std::vector<std::shared_ptr<SparseMatrix> > adjs;\n\n  static constexpr const char* _type_key = \"graph.HeteroPickleStates\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(HeteroPickleStates, runtime::Object);\n};\n\n// Define HeteroPickleStatesRef\nDGL_DEFINE_OBJECT_REF(HeteroPickleStatesRef, HeteroPickleStates);\n\n/**\n * @brief Create a heterograph from pickling states.\n *\n * @param states Pickle states\n * @return A heterograph pointer\n */\nHeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states);\n\n/**\n * @brief Get the pickling state of the relation graph structure in backend\n * tensors.\n *\n * @return a HeteroPickleStates object\n */\nHeteroPickleStates HeteroPickle(HeteroGraphPtr graph);\n\n/**\n * @brief Old version of HeteroUnpickle, for backward compatibility\n *\n * @param states Pickle states\n * @return A heterograph pointer\n */\nHeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states);\n\n/**\n * @brief Create heterograph from pickling states pickled by ForkingPickler.\n *\n * This is different from HeteroUnpickle where\n * (1) Backward compatibility is not required,\n * (2) All graph formats are pickled instead of only one.\n */\nHeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);\n\n/**\n * @brief Get the pickling states of the relation graph structure in backend\n * tensors for ForkingPickler.\n *\n * This is different from HeteroPickle where\n * (1) Backward compatibility is not required,\n * (2) All graph formats are pickled instead of only one.\n */\nHeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph);\n\n#define FORMAT_HAS_CSC(format) ((format)&CSC_CODE)\n\n#define FORMAT_HAS_CSR(format) ((format)&CSR_CODE)\n\n#define FORMAT_HAS_COO(format) ((format)&COO_CODE)\n\n}  // namespace dgl\n\n#endif  // DGL_BASE_HETEROGRAPH_H_\n"
  },
  {
    "path": "include/dgl/bcast.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file dgl/aten/bcast.h\n * @brief Broadcast related function C++ header.\n */\n#ifndef DGL_BCAST_H_\n#define DGL_BCAST_H_\n\n#include <string>\n#include <vector>\n\n#include \"./runtime/ndarray.h\"\n\nusing namespace dgl::runtime;\nnamespace dgl {\n\n/**\n * @brief Broadcast offsets and auxiliary information.\n */\nstruct BcastOff {\n  /**\n   * @brief offset vector of lhs operand and rhs operand.\n   * @note lhs_offset[i] indicates the start position of the scalar\n   *       in lhs operand that required to compute the i-th element\n   *       in the output, likewise for rhs_offset.\n   *\n   * For example, when lhs array has shape (1, 3) and rhs array\n   * has shape (5, 1), the resulting array would have shape (5, 3),\n   * then both lhs_offset and rhs_offset would contain 15 elements.\n   *\n   * lhs_offset: 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2\n   * rhs_offset: 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4\n   *\n   * in order to compute the 7-th (row 2, column 0) element in the output,\n   * we need the 0-th element in the lhs array and the 2-th element in the\n   * rhs array.\n   */\n  std::vector<int64_t> lhs_offset, rhs_offset;\n  /** @brief Whether broadcast is required or not. */\n  bool use_bcast;\n  /**\n   * @brief Auxiliary information for kernel computation\n   * @note lhs_len refers to the left hand side operand length.\n   *       e.g. 15 for shape (1, 3, 5)\n   *       rhs_len refers to the right hand side operand length.\n   *       e.g. 15 for shape (3, 1, 5)\n   *       out_len refers to the output length.\n   *       e.g. 45 for shape (3, 3, 5)\n   *       reduce_size refers to the reduction size (for op like dot).\n   *       e.g. 
1 for add, 5 for dot and lhs_shape,rhs_shape=(3,5)\n   */\n  int64_t lhs_len, rhs_len, out_len, reduce_size;\n};\n\n/**\n * @brief: Compute broadcast and auxiliary information given operator\n *         and operands for kernel computation.\n * @param op: a string indicating the operator, could be `add`, `sub`,\n *        `mul`, `div`, `dot`, `copy_u`, `copy_e`.\n * @param lhs The left hand side operand of NDArray class.\n * @param rhs The right hand side operand of NDArray class.\n * @return the broadcast information of BcastOff class.\n */\nBcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs);\n\n}  // namespace dgl\n\n#endif  // DGL_BCAST_H_\n"
  },
  {
    "path": "include/dgl/env_variable.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file dgl/env_variable.h\n * @brief Class about envrionment variables.\n */\n#ifndef DGL_ENV_VARIABLE_H_\n#define DGL_ENV_VARIABLE_H_\n\n#include <cstdlib>\n\nnamespace dgl {\n\nstatic const char* kDGLParallelForGrainSize =\n    std::getenv(\"DGL_PARALLEL_FOR_GRAIN_SIZE\");\n\n}  // namespace dgl\n\n#endif  // DGL_ENV_VARIABLE_H_\n"
  },
  {
    "path": "include/dgl/graph.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file dgl/graph.h\n * @brief DGL graph index class.\n */\n#ifndef DGL_GRAPH_H_\n#define DGL_GRAPH_H_\n\n#include <cstdint>\n#include <memory>\n#include <string>\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"graph_interface.h\"\n\nnamespace dgl {\n\nclass Graph;\nclass GraphOp;\ntypedef std::shared_ptr<Graph> MutableGraphPtr;\n\n/** @brief Mutable graph based on adjacency list. */\nclass Graph : public GraphInterface {\n public:\n  /** @brief default constructor */\n  Graph() {}\n\n  /** @brief construct a graph from the coo format. */\n  Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes);\n\n  /** @brief default copy constructor */\n  Graph(const Graph& other) = default;\n\n#ifndef _MSC_VER\n  /** @brief default move constructor */\n  Graph(Graph&& other) = default;\n#else\n  Graph(Graph&& other) {\n    adjlist_ = other.adjlist_;\n    reverse_adjlist_ = other.reverse_adjlist_;\n    all_edges_src_ = other.all_edges_src_;\n    all_edges_dst_ = other.all_edges_dst_;\n    read_only_ = other.read_only_;\n    num_edges_ = other.num_edges_;\n    other.Clear();\n  }\n#endif  // _MSC_VER\n\n  /** @brief default assign constructor */\n  Graph& operator=(const Graph& other) = default;\n\n  /** @brief default destructor */\n  ~Graph() = default;\n\n  /**\n   * @brief Add vertices to the graph.\n   * @note Since vertices are integers enumerated from zero, only the number of\n   *       vertices to be added needs to be specified.\n   * @param num_vertices The number of vertices to be added.\n   */\n  void AddVertices(uint64_t num_vertices) override;\n\n  /**\n   * @brief Add one edge to the graph.\n   * @param src The source vertex.\n   * @param dst The destination vertex.\n   */\n  void AddEdge(dgl_id_t src, dgl_id_t dst) override;\n\n  /**\n   * @brief Add edges to the graph.\n   * @param src_ids The source vertex id array.\n   * @param dst_ids The destination vertex id array.\n   
*/\n  void AddEdges(IdArray src_ids, IdArray dst_ids) override;\n\n  /**\n   * @brief Clear the graph. Remove all vertices/edges.\n   */\n  void Clear() override {\n    adjlist_.clear();\n    reverse_adjlist_.clear();\n    all_edges_src_.clear();\n    all_edges_dst_.clear();\n    read_only_ = false;\n    num_edges_ = 0;\n  }\n\n  DGLContext Context() const override { return DGLContext{kDGLCPU, 0}; }\n\n  uint8_t NumBits() const override { return 64; }\n\n  /**\n   * @note not const since we have caches\n   * @return whether the graph is a multigraph\n   */\n  bool IsMultigraph() const override;\n\n  /**\n   * @return whether the graph is read-only\n   */\n  bool IsReadonly() const override { return false; }\n\n  /** @return the number of vertices in the graph.*/\n  uint64_t NumVertices() const override { return adjlist_.size(); }\n\n  /** @return the number of edges in the graph.*/\n  uint64_t NumEdges() const override { return num_edges_; }\n\n  /** @return a 0-1 array indicating whether the given vertices are in the\n   * graph.\n   */\n  BoolArray HasVertices(IdArray vids) const override;\n\n  /** @return true if the given edge is in the graph.*/\n  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override;\n\n  /** @return a 0-1 array indicating whether the given edges are in the graph.*/\n  BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override;\n\n  /**\n   * @brief Find the predecessors of a vertex.\n   * @param vid The vertex id.\n   * @param radius The radius of the neighborhood. Default is immediate neighbor\n   *        (radius=1).\n   * @return the predecessor id array.\n   */\n  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override;\n\n  /**\n   * @brief Find the successors of a vertex.\n   * @param vid The vertex id.\n   * @param radius The radius of the neighborhood. 
Default is immediate neighbor\n   *        (radius=1).\n   * @return the successor id array.\n   */\n  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override;\n\n  /**\n   * @brief Get all edge ids between the two given endpoints\n   * @note Edges are associated with an integer id start from zero.\n   *       The id is assigned when the edge is being added to the graph.\n   * @param src The source vertex.\n   * @param dst The destination vertex.\n   * @return the edge id array.\n   */\n  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override;\n\n  /**\n   * @brief Get all edge ids between the given endpoint pairs.\n   * @note Edges are associated with an integer id start from zero.\n   *       The id is assigned when the edge is being added to the graph.\n   *       If duplicate pairs exist, the returned edge IDs will also duplicate.\n   *       The order of returned edge IDs will follow the order of src-dst pairs\n   *       first, and ties are broken by the order of edge ID.\n   * @return EdgeArray containing all edges between all pairs.\n   */\n  EdgeArray EdgeIds(IdArray src, IdArray dst) const override;\n\n  /**\n   * @brief Find the edge ID and return the pair of endpoints\n   * @param eid The edge ID\n   * @return a pair whose first element is the source and the second the\n   * destination.\n   */\n  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {\n    return std::make_pair(all_edges_src_[eid], all_edges_dst_[eid]);\n  }\n\n  /**\n   * @brief Find the edge IDs and return their source and target node IDs.\n   * @param eids The edge ID array.\n   * @return EdgeArray containing all edges with id in eid.  
The order is\n   *         preserved.\n   */\n  EdgeArray FindEdges(IdArray eids) const override;\n\n  /**\n   * @brief Get the in edges of the vertex.\n   * @note The returned dst id array is filled with vid.\n   * @param vid The vertex id.\n   * @return the edges\n   */\n  EdgeArray InEdges(dgl_id_t vid) const override;\n\n  /**\n   * @brief Get the in edges of the vertices.\n   * @param vids The vertex id array.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  EdgeArray InEdges(IdArray vids) const override;\n\n  /**\n   * @brief Get the out edges of the vertex.\n   * @note The returned src id array is filled with vid.\n   * @param vid The vertex id.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  EdgeArray OutEdges(dgl_id_t vid) const override;\n\n  /**\n   * @brief Get the out edges of the vertices.\n   * @param vids The vertex id array.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  EdgeArray OutEdges(IdArray vids) const override;\n\n  /**\n   * @brief Get all the edges in the graph.\n   * @note If sorted is true, the returned edges list is sorted by their src and\n   *       dst ids. 
Otherwise, they are in their edge id order.\n   * @param sorted Whether the returned edge list is sorted by their src and dst\n   *        ids.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  EdgeArray Edges(const std::string& order = \"\") const override;\n\n  /**\n   * @brief Get the in degree of the given vertex.\n   * @param vid The vertex id.\n   * @return the in degree\n   */\n  uint64_t InDegree(dgl_id_t vid) const override {\n    CHECK(HasVertex(vid)) << \"invalid vertex: \" << vid;\n    return reverse_adjlist_[vid].succ.size();\n  }\n\n  /**\n   * @brief Get the in degrees of the given vertices.\n   * @param vid The vertex id array.\n   * @return the in degree array\n   */\n  DegreeArray InDegrees(IdArray vids) const override;\n\n  /**\n   * @brief Get the out degree of the given vertex.\n   * @param vid The vertex id.\n   * @return the out degree\n   */\n  uint64_t OutDegree(dgl_id_t vid) const override {\n    CHECK(HasVertex(vid)) << \"invalid vertex: \" << vid;\n    return adjlist_[vid].succ.size();\n  }\n\n  /**\n   * @brief Get the out degrees of the given vertices.\n   * @param vid The vertex id array.\n   * @return the out degree array\n   */\n  DegreeArray OutDegrees(IdArray vids) const override;\n\n  /**\n   * @brief Construct the induced subgraph of the given vertices.\n   *\n   * The induced subgraph is a subgraph formed by specifying a set of vertices\n   * V' and then selecting all of the edges from the original graph that connect\n   * two vertices in V'.\n   *\n   * Vertices and edges in the original graph will be \"reindexed\" to local\n   * index. The local index of the vertices preserve the order of the given id\n   * array, while the local index of the edges preserve the index order in the\n   * original graph. 
Vertices not in the original graph are ignored.\n   *\n   * The result subgraph is read-only.\n   *\n   * @param vids The vertices in the subgraph.\n   * @return the induced subgraph\n   */\n  Subgraph VertexSubgraph(IdArray vids) const override;\n\n  /**\n   * @brief Construct the induced edge subgraph of the given edges.\n   *\n   * The induced edges subgraph is a subgraph formed by specifying a set of\n   * edges E' and then selecting all of the nodes from the original graph that\n   * are endpoints in E'.\n   *\n   * Vertices and edges in the original graph will be \"reindexed\" to local\n   * index. The local index of the edges preserve the order of the given id\n   * array, while the local index of the vertices preserve the index order in\n   * the original graph. Edges not in the original graph are ignored.\n   *\n   * The result subgraph is read-only.\n   *\n   * @param eids The edges in the subgraph.\n   * @return the induced edge subgraph\n   */\n  Subgraph EdgeSubgraph(\n      IdArray eids, bool preserve_nodes = false) const override;\n\n  /**\n   * @brief Return the successor vector\n   * @param vid The vertex id.\n   * @return the successor vector\n   */\n  DGLIdIters SuccVec(dgl_id_t vid) const override {\n    auto data = adjlist_[vid].succ.data();\n    auto size = adjlist_[vid].succ.size();\n    return DGLIdIters(data, data + size);\n  }\n\n  /**\n   * @brief Return the out edge id vector\n   * @param vid The vertex id.\n   * @return the out edge id vector\n   */\n  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {\n    auto data = adjlist_[vid].edge_id.data();\n    auto size = adjlist_[vid].edge_id.size();\n    return DGLIdIters(data, data + size);\n  }\n\n  /**\n   * @brief Return the predecessor vector\n   * @param vid The vertex id.\n   * @return the predecessor vector\n   */\n  DGLIdIters PredVec(dgl_id_t vid) const override {\n    auto data = reverse_adjlist_[vid].succ.data();\n    auto size = reverse_adjlist_[vid].succ.size();\n    return 
DGLIdIters(data, data + size);\n  }\n\n  /**\n   * @brief Return the in edge id vector\n   * @param vid The vertex id.\n   * @return the in edge id vector\n   */\n  DGLIdIters InEdgeVec(dgl_id_t vid) const override {\n    auto data = reverse_adjlist_[vid].edge_id.data();\n    auto size = reverse_adjlist_[vid].edge_id.size();\n    return DGLIdIters(data, data + size);\n  }\n\n  /**\n   * @brief Get the adjacency matrix of the graph.\n   *\n   * By default, a row of returned adjacency matrix represents the destination\n   * of an edge and the column represents the source.\n   * @param transpose A flag to transpose the returned adjacency matrix.\n   * @param fmt the format of the returned adjacency matrix.\n   * @return a vector of three IdArray.\n   */\n  std::vector<IdArray> GetAdj(\n      bool transpose, const std::string& fmt) const override;\n\n  /** @brief Create an empty graph */\n  static MutableGraphPtr Create() { return std::make_shared<Graph>(); }\n\n  /** @brief Create from coo */\n  static MutableGraphPtr CreateFromCOO(\n      int64_t num_nodes, IdArray src_ids, IdArray dst_ids) {\n    return std::make_shared<Graph>(src_ids, dst_ids, num_nodes);\n  }\n\n protected:\n  friend class GraphOp;\n  /** @brief Internal edge list type */\n  struct EdgeList {\n    /** @brief successor vertex list */\n    std::vector<dgl_id_t> succ;\n    /** @brief out edge list */\n    std::vector<dgl_id_t> edge_id;\n  };\n  typedef std::vector<EdgeList> AdjacencyList;\n\n  /** @brief adjacency list using vector storage */\n  AdjacencyList adjlist_;\n  /** @brief reverse adjacency list using vector storage */\n  AdjacencyList reverse_adjlist_;\n\n  /** @brief all edges' src endpoints in their edge id order */\n  std::vector<dgl_id_t> all_edges_src_;\n  /** @brief all edges' dst endpoints in their edge id order */\n  std::vector<dgl_id_t> all_edges_dst_;\n\n  /** @brief read only flag */\n  bool read_only_ = false;\n\n  /** @brief number of edges */\n  uint64_t num_edges_ = 
0;\n};\n\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_H_\n"
  },
  {
    "path": "include/dgl/graph_interface.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file dgl/graph_interface.h\n * @brief DGL graph index class.\n */\n#ifndef DGL_GRAPH_INTERFACE_H_\n#define DGL_GRAPH_INTERFACE_H_\n\n#include <algorithm>\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"./runtime/object.h\"\n#include \"array.h\"\n\nnamespace dgl {\n\nconst dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1);\n\n/**\n * @brief This class references data in std::vector.\n *\n * This isn't a STL-style iterator. It provides a STL data container interface.\n * but it doesn't own data itself. instead, it only references data in\n * std::vector.\n */\nclass DGLIdIters {\n public:\n  /** @brief default constructor to create an empty range */\n  DGLIdIters() {}\n  /** @brief constructor with given begin and end */\n  DGLIdIters(const dgl_id_t *begin, const dgl_id_t *end) {\n    this->begin_ = begin;\n    this->end_ = end;\n  }\n  const dgl_id_t *begin() const { return this->begin_; }\n  const dgl_id_t *end() const { return this->end_; }\n  dgl_id_t operator[](int64_t i) const { return *(this->begin_ + i); }\n  size_t size() const { return this->end_ - this->begin_; }\n\n private:\n  const dgl_id_t *begin_{nullptr}, *end_{nullptr};\n};\n\n/**\n * @brief int32 version for DGLIdIters\n *\n */\nclass DGLIdIters32 {\n public:\n  /** @brief default constructor to create an empty range */\n  DGLIdIters32() {}\n  /** @brief constructor with given begin and end */\n  DGLIdIters32(const int32_t *begin, const int32_t *end) {\n    this->begin_ = begin;\n    this->end_ = end;\n  }\n  const int32_t *begin() const { return this->begin_; }\n  const int32_t *end() const { return this->end_; }\n  int32_t operator[](int32_t i) const { return *(this->begin_ + i); }\n  size_t size() const { return this->end_ - this->begin_; }\n\n private:\n  const int32_t *begin_{nullptr}, *end_{nullptr};\n};\n\n/* @brief structure used to represent a list of edges */\ntypedef struct {\n  
/* @brief the two endpoints and the id of the edge */\n  IdArray src, dst, id;\n} EdgeArray;\n\n// forward declaration\nstruct Subgraph;\nclass GraphRef;\nclass GraphInterface;\ntypedef std::shared_ptr<GraphInterface> GraphPtr;\n\n/**\n * @brief dgl graph index interface.\n *\n * DGL's graph is directed. Vertices are integers enumerated from zero.\n *\n * When calling functions supporting multiple edges (e.g. AddEdges, HasEdges),\n * the input edges are represented by two id arrays for source and destination\n * vertex ids. In the general case, the two arrays should have the same length.\n * If the length of src id array is one, it represents one-many connections.\n * If the length of dst id array is one, it represents many-one connections.\n */\nclass GraphInterface : public runtime::Object {\n public:\n  virtual ~GraphInterface() = default;\n\n  /**\n   * @brief Add vertices to the graph.\n   * @note Since vertices are integers enumerated from zero, only the number of\n   *       vertices to be added needs to be specified.\n   * @param num_vertices The number of vertices to be added.\n   */\n  virtual void AddVertices(uint64_t num_vertices) = 0;\n\n  /**\n   * @brief Add one edge to the graph.\n   * @param src The source vertex.\n   * @param dst The destination vertex.\n   */\n  virtual void AddEdge(dgl_id_t src, dgl_id_t dst) = 0;\n\n  /**\n   * @brief Add edges to the graph.\n   * @param src_ids The source vertex id array.\n   * @param dst_ids The destination vertex id array.\n   */\n  virtual void AddEdges(IdArray src_ids, IdArray dst_ids) = 0;\n\n  /**\n   * @brief Clear the graph. 
Remove all vertices/edges.\n   */\n  virtual void Clear() = 0;\n\n  /**\n   * @brief Get the device context of this graph.\n   */\n  virtual DGLContext Context() const = 0;\n\n  /**\n   * @brief Get the number of integer bits used to store node/edge ids\n   *        (32 or 64).\n   */\n  virtual uint8_t NumBits() const = 0;\n\n  /**\n   * @return whether the graph is a multigraph\n   */\n  virtual bool IsMultigraph() const = 0;\n\n  /**\n   * @return whether the graph is unibipartite\n   */\n  virtual bool IsUniBipartite() const {\n    EdgeArray edges = Edges();\n    IdArray src = edges.src;\n    IdArray dst = edges.dst;\n\n    bool is_unibipartite = true;\n    const size_t n = edges.src.NumElements();\n    ATEN_ID_TYPE_SWITCH(src->dtype, IdType, {\n      auto src_v = src.ToVector<IdType>();\n      std::sort(src_v.begin(), src_v.end());\n      auto dst_v = dst.ToVector<IdType>();\n      std::sort(dst_v.begin(), dst_v.end());\n      // std::set_intersection() requires output, so this is better\n      for (size_t i = 0, j = 0; i < n && j < n;) {\n        if (src_v[i] < dst_v[j]) {\n          ++i;\n        } else if (src_v[i] == dst_v[j]) {\n          is_unibipartite = false;\n          break;\n        } else {\n          ++j;\n        }\n      }\n    });\n\n    return is_unibipartite;\n  }\n\n  /**\n   * @return whether the graph is read-only\n   */\n  virtual bool IsReadonly() const = 0;\n\n  /** @return the number of vertices in the graph.*/\n  virtual uint64_t NumVertices() const = 0;\n\n  /** @return the number of edges in the graph.*/\n  virtual uint64_t NumEdges() const = 0;\n\n  /** @return true if the given vertex is in the graph.*/\n  virtual bool HasVertex(dgl_id_t vid) const { return vid < NumVertices(); }\n\n  /** @return a 0-1 array indicating whether the given vertices are in the\n   *          graph.\n   */\n  virtual BoolArray HasVertices(IdArray vids) const = 0;\n\n  /** @return true if the given edge is in the graph.*/\n  virtual bool 
HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const = 0;\n\n  /** @return a 0-1 array indicating whether the given edges are in the graph.*/\n  virtual BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const = 0;\n\n  /**\n   * @brief Find the predecessors of a vertex.\n   * @param vid The vertex id.\n   * @param radius The radius of the neighborhood. Default is immediate neighbor\n   *        (radius=1).\n   * @return the predecessor id array.\n   */\n  virtual IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const = 0;\n\n  /**\n   * @brief Find the successors of a vertex.\n   * @param vid The vertex id.\n   * @param radius The radius of the neighborhood. Default is immediate neighbor\n   *        (radius=1).\n   * @return the successor id array.\n   */\n  virtual IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const = 0;\n\n  /**\n   * @brief Get all edge ids between the two given endpoints\n   * @note Edges are associated with an integer id start from zero.\n   *       The id is assigned when the edge is being added to the graph.\n   * @param src The source vertex.\n   * @param dst The destination vertex.\n   * @return the edge id array.\n   */\n  virtual IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const = 0;\n\n  /**\n   * @brief Get all edge ids between the given endpoint pairs.\n   * @note Edges are associated with an integer id start from zero.\n   *       The id is assigned when the edge is being added to the graph.\n   *       If duplicate pairs exist, the returned edge IDs will also duplicate.\n   *       The order of returned edge IDs will follow the order of src-dst pairs\n   *       first, and ties are broken by the order of edge ID.\n   * @return EdgeArray containing all edges between all pairs.\n   */\n  virtual EdgeArray EdgeIds(IdArray src, IdArray dst) const = 0;\n\n  /**\n   * @brief Find the edge ID and return the pair of endpoints\n   * @param eid The edge ID\n   * @return a pair whose first element is the source and the 
second the\n   *         destination.\n   */\n  virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const = 0;\n\n  /**\n   * @brief Find the edge IDs and return their source and target node IDs.\n   * @param eids The edge ID array.\n   * @return EdgeArray containing all edges with id in eid.  The order is\n   *         preserved.\n   */\n  virtual EdgeArray FindEdges(IdArray eids) const = 0;\n\n  /**\n   * @brief Get the in edges of the vertex.\n   * @note The returned dst id array is filled with vid.\n   * @param vid The vertex id.\n   * @return the edges\n   */\n  virtual EdgeArray InEdges(dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Get the in edges of the vertices.\n   * @param vids The vertex id array.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  virtual EdgeArray InEdges(IdArray vids) const = 0;\n\n  /**\n   * @brief Get the out edges of the vertex.\n   * @note The returned src id array is filled with vid.\n   * @param vid The vertex id.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  virtual EdgeArray OutEdges(dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Get the out edges of the vertices.\n   * @param vids The vertex id array.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  virtual EdgeArray OutEdges(IdArray vids) const = 0;\n\n  /**\n   * @brief Get all the edges in the graph.\n   * @note If order is \"srcdst\", the returned edges list is sorted by their src\n   *       and dst ids. 
If order is \"eid\", they are in their edge id order.\n   *       Otherwise, in the arbitrary order.\n   * @param order The order of the returned edge list.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  virtual EdgeArray Edges(const std::string &order = \"\") const = 0;\n\n  /**\n   * @brief Get the in degree of the given vertex.\n   * @param vid The vertex id.\n   * @return the in degree\n   */\n  virtual uint64_t InDegree(dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Get the in degrees of the given vertices.\n   * @param vid The vertex id array.\n   * @return the in degree array\n   */\n  virtual DegreeArray InDegrees(IdArray vids) const = 0;\n\n  /**\n   * @brief Get the out degree of the given vertex.\n   * @param vid The vertex id.\n   * @return the out degree\n   */\n  virtual uint64_t OutDegree(dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Get the out degrees of the given vertices.\n   * @param vid The vertex id array.\n   * @return the out degree array\n   */\n  virtual DegreeArray OutDegrees(IdArray vids) const = 0;\n\n  /**\n   * @brief Construct the induced subgraph of the given vertices.\n   *\n   * The induced subgraph is a subgraph formed by specifying a set of vertices\n   * V' and then selecting all of the edges from the original graph that connect\n   * two vertices in V'.\n   *\n   * Vertices and edges in the original graph will be \"reindexed\" to local\n   * index. The local index of the vertices preserve the order of the given id\n   * array, while the local index of the edges preserve the index order in the\n   * original graph. 
Vertices not in the original graph are ignored.\n   *\n   * The result subgraph is read-only.\n   *\n   * @param vids The vertices in the subgraph.\n   * @return the induced subgraph\n   */\n  virtual Subgraph VertexSubgraph(IdArray vids) const = 0;\n\n  /**\n   * @brief Construct the induced edge subgraph of the given edges.\n   *\n   * The induced edges subgraph is a subgraph formed by specifying a set of\n   * edges E' and then selecting all of the nodes from the original graph that\n   * are endpoints in E'.\n   *\n   * Vertices and edges in the original graph will be \"reindexed\" to local\n   * index. The local index of the edges preserve the order of the given id\n   * array, while the local index of the vertices preserve the index order in\n   * the original graph. Edges not in the original graph are ignored.\n   *\n   * The result subgraph is read-only.\n   *\n   * @param eids The edges in the subgraph.\n   * @param preserve_nodes If true, the vertices will not be relabeled, so some\n   *        vertices may have no incident edges.\n   * @return the induced edge subgraph\n   */\n  virtual Subgraph EdgeSubgraph(\n      IdArray eids, bool preserve_nodes = false) const = 0;\n\n  /**\n   * @brief Return the successor vector\n   * @param vid The vertex id.\n   * @return the successor vector iterator pair.\n   */\n  virtual DGLIdIters SuccVec(dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Return the out edge id vector\n   * @param vid The vertex id.\n   * @return the out edge id vector iterator pair.\n   */\n  virtual DGLIdIters OutEdgeVec(dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Return the predecessor vector\n   * @param vid The vertex id.\n   * @return the predecessor vector iterator pair.\n   */\n  virtual DGLIdIters PredVec(dgl_id_t vid) const = 0;\n\n  /**\n   * @brief Return the in edge id vector\n   * @param vid The vertex id.\n   * @return the in edge id vector iterator pair.\n   */\n  virtual DGLIdIters InEdgeVec(dgl_id_t vid) const = 0;\n\n  
/**\n   * @brief Get the adjacency matrix of the graph.\n   *\n   * By default, a row of returned adjacency matrix represents the destination\n   * of an edge and the column represents the source.\n   *\n   * If the fmt is 'csr', the function should return three arrays, representing\n   *  indptr, indices and edge ids\n   *\n   * If the fmt is 'coo', the function should return one array of shape (2,\n   * nnz), representing a horitonzal stack of row and col indices.\n   *\n   * @param transpose A flag to transpose the returned adjacency matrix.\n   * @param fmt the format of the returned adjacency matrix.\n   * @return a vector of IdArrays.\n   */\n  virtual std::vector<IdArray> GetAdj(\n      bool transpose, const std::string &fmt) const = 0;\n\n  /**\n   * @brief Sort the columns in CSR.\n   *\n   * This sorts the columns in each row based on the column Ids.\n   * The edge ids should be sorted accordingly.\n   */\n  virtual void SortCSR() {}\n\n  static constexpr const char *_type_key = \"graph.Graph\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(GraphInterface, runtime::Object);\n};\n\n// Define GraphRef\nDGL_DEFINE_OBJECT_REF(GraphRef, GraphInterface);\n\n/** @brief Subgraph data structure */\nstruct Subgraph : public runtime::Object {\n  /** @brief The graph. */\n  GraphPtr graph;\n  /**\n   * @brief The induced vertex ids.\n   * @note This is also a map from the new vertex id to the vertex id in the\n   *       parent graph.\n   */\n  IdArray induced_vertices;\n  /**\n   * @brief The induced edge ids.\n   * @note This is also a map from the new edge id to the edge id in the parent\n   *       graph.\n   */\n  IdArray induced_edges;\n\n  static constexpr const char *_type_key = \"graph.Subgraph\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object);\n};\n\n/** @brief Subgraph data structure for negative subgraph */\nstruct NegSubgraph : public Subgraph {\n  /** @brief The existence of the negative edges in the parent graph. 
*/\n  IdArray exist;\n\n  /** @brief The Ids of head nodes */\n  IdArray head_nid;\n\n  /** @brief The Ids of tail nodes */\n  IdArray tail_nid;\n};\n\n/** @brief Subgraph data structure for halo subgraph */\nstruct HaloSubgraph : public Subgraph {\n  /** @brief Indicate if a node belongs to the partition. */\n  IdArray inner_nodes;\n};\n\n// Define SubgraphRef\nDGL_DEFINE_OBJECT_REF(SubgraphRef, Subgraph);\n\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_INTERFACE_H_\n"
  },
  {
    "path": "include/dgl/graph_op.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file dgl/graph_op.h\n * @brief Operations on graph index.\n */\n#ifndef DGL_GRAPH_OP_H_\n#define DGL_GRAPH_OP_H_\n\n#include <vector>\n\n#include \"graph.h\"\n#include \"immutable_graph.h\"\n\nnamespace dgl {\n\nclass GraphOp {\n public:\n  /**\n   * @brief Return a new graph with all the edges reversed.\n   *\n   * The returned graph preserves the vertex and edge index in the original\n   * graph.\n   *\n   * @return the reversed graph\n   */\n  static GraphPtr Reverse(GraphPtr graph);\n\n  /**\n   * @brief Return the line graph.\n   *\n   * If i~j and j~i are two edges in original graph G, then\n   * (i,j)~(j,i) and (j,i)~(i,j) are the \"backtracking\" edges on\n   * the line graph.\n   *\n   * @param graph The input graph.\n   * @param backtracking Whether the backtracking edges are included or not\n   * @return the line graph\n   */\n  static GraphPtr LineGraph(GraphPtr graph, bool backtracking);\n\n  /**\n   * @brief Return a disjoint union of the input graphs.\n   *\n   * The new graph will include all the nodes/edges in the given graphs.\n   * Nodes/Edges will be relabled by adding the cumsum of the previous graph\n   * sizes in the given sequence order. For example, giving input [g1, g2, g3],\n   * where they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become\n   * node#7 in the result graph. Edge ids are re-assigned similarly.\n   *\n   * The input list must be either ALL mutable graphs or ALL immutable graphs.\n   * The returned graph type is also determined by the input graph type.\n   *\n   * @param graphs A list of input graphs to be unioned.\n   * @return the disjoint union of the graphs\n   */\n  static GraphPtr DisjointUnion(std::vector<GraphPtr> graphs);\n\n  /**\n   * @brief Partition the graph into several subgraphs.\n   *\n   * This is a reverse operation of DisjointUnion. The graph will be partitioned\n   * into num graphs. 
This requires the given number of partitions to evenly\n   * divides the number of nodes in the graph.\n   *\n   * If the input graph is mutable, the result graphs are mutable.\n   * If the input graph is immutable, the result graphs are immutable.\n   *\n   * @param graph The graph to be partitioned.\n   * @param num The number of partitions.\n   * @return a list of partitioned graphs\n   */\n  static std::vector<GraphPtr> DisjointPartitionByNum(\n      GraphPtr graph, int64_t num);\n\n  /**\n   * @brief Partition the graph into several subgraphs.\n   *\n   * This is a reverse operation of DisjointUnion. The graph will be partitioned\n   * based on the given sizes. This requires the sum of the given sizes is equal\n   * to the number of nodes in the graph.\n   *\n   * If the input graph is mutable, the result graphs are mutable.\n   * If the input graph is immutable, the result graphs are immutable.\n   *\n   * @param graph The graph to be partitioned.\n   * @param sizes The number of partitions.\n   * @return a list of partitioned graphs\n   */\n  static std::vector<GraphPtr> DisjointPartitionBySizes(\n      GraphPtr graph, IdArray sizes);\n\n  /**\n   * @brief Map vids in the parent graph to the vids in the subgraph.\n   *\n   * If the Id doesn't exist in the subgraph, -1 will be used.\n   *\n   * @param parent_vid_map An array that maps the vids in the parent graph to\n   * the subgraph. 
The elements store the vertex Ids in the parent graph, and\n   * the indices indicate the vertex Ids in the subgraph.\n   * @param query The vertex Ids in the parent graph.\n   * @return an Id array that contains the subgraph node Ids.\n   */\n  static IdArray MapParentIdToSubgraphId(IdArray parent_vid_map, IdArray query);\n\n  /**\n   * @brief Expand an Id array based on the offset array.\n   *\n   * For example,\n   * ids:     [0, 1, 2, 3, 4],\n   * offset:  [0, 2, 2, 5, 6, 7],\n   * result:  [0, 0, 2, 2, 2, 3, 4].\n   * The offset array has one more element than the ids array.\n   * (offset[i], offset[i+1]) shows the location of ids[i] in the result array.\n   *\n   * @param ids An array that contains the node or edge Ids.\n   * @param offset An array that contains the offset after expansion.\n   * @return a expanded Id array.\n   */\n  static IdArray ExpandIds(IdArray ids, IdArray offset);\n\n  /**\n   * @brief Convert the graph to a simple graph.\n   * @param graph The input graph.\n   * @return a new immutable simple graph with no multi-edge.\n   */\n  static GraphPtr ToSimpleGraph(GraphPtr graph);\n\n  /**\n   * @brief Convert the graph to a mutable bidirected graph.\n   *\n   * If the original graph has m edges for i -> j and n edges for\n   * j -> i, the new graph will have max(m, n) edges for both\n   * i -> j and j -> i.\n   *\n   * @param graph The input graph.\n   * @return a new mutable bidirected graph.\n   */\n  static GraphPtr ToBidirectedMutableGraph(GraphPtr graph);\n\n  /**\n   * @brief Same as BidirectedMutableGraph except that the returned graph is\n   *        immutable.\n   * @param graph The input graph.\n   * @return a new immutable bidirected\n   * graph.\n   */\n  static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph);\n  /**\n   * @brief Same as BidirectedMutableGraph except that the returned graph is\n   * immutable and call gk_csr_MakeSymmetric in GKlib. This is more efficient\n   * than ToBidirectedImmutableGraph. 
It returns a null pointer if the\n   * conversion fails.\n   *\n   * @param ig The input graph.\n   * @return a new immutable bidirected graph.\n   */\n  static GraphPtr ToBidirectedSimpleImmutableGraph(ImmutableGraphPtr ig);\n\n  /**\n   * @brief Get an induced subgraph with HALO nodes.\n   * The HALO nodes are the ones that can be reached from `nodes` within\n   * `num_hops`.\n   * @param graph The input graph.\n   * @param nodes The input nodes that form the core of the induced subgraph.\n   * @param num_hops The number of hops to reach.\n   * @return the induced subgraph with HALO nodes.\n   */\n  static HaloSubgraph GetSubgraphWithHalo(\n      GraphPtr graph, IdArray nodes, int num_hops);\n\n  /**\n   * @brief Reorder the nodes in the immutable graph.\n   * @param ig The input graph.\n   * @param new_order The node Ids in the new graph. The index in `new_order` is\n   *        old node Ids.\n   * @return the graph with reordered node Ids\n   */\n  static GraphPtr ReorderImmutableGraph(\n      ImmutableGraphPtr ig, IdArray new_order);\n};\n\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_OP_H_\n"
  },
  {
    "path": "include/dgl/graph_serializer.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file dgl/graph_serializer.h\n * @brief DGL serializer APIs\n */\n\n#ifndef DGL_GRAPH_SERIALIZER_H_\n#define DGL_GRAPH_SERIALIZER_H_\n\n#include <memory>\nnamespace dgl {\n\n// Util class to call the private/public empty constructor, which is needed for\n// serialization\nclass Serializer {\n public:\n  template <typename T>\n  static T* new_object() {\n    return new T();\n  }\n\n  template <typename T>\n  static std::shared_ptr<T> make_shared() {\n    return std::shared_ptr<T>(new T());\n  }\n};\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_SERIALIZER_H_\n"
  },
  {
    "path": "include/dgl/graph_traversal.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file dgl/graph_traversal.h\n * @brief common graph traversal operations\n */\n#ifndef DGL_GRAPH_TRAVERSAL_H_\n#define DGL_GRAPH_TRAVERSAL_H_\n\n#include \"array.h\"\n#include \"base_heterograph.h\"\n\nnamespace dgl {\n\n///////////////////////// Graph Traverse routines //////////////////////////\n/**\n * @brief Class for representing frontiers.\n *\n * Each frontier is a list of nodes/edges (specified by their ids).\n * An optional tag can be specified on each node/edge (represented by an int\n * value).\n */\nstruct Frontiers {\n  /** @brief a vector store for the nodes/edges in all the frontiers */\n  IdArray ids;\n\n  /**\n   * @brief a vector store for node/edge tags. Dtype is int64.\n   * Empty if no tags are requested\n   */\n  IdArray tags;\n\n  /** @brief a section vector to indicate each frontier Dtype is int64. */\n  IdArray sections;\n};\n\nnamespace aten {\n\n/**\n * @brief Traverse the graph in a breadth-first-search (BFS) order.\n *\n * @param csr The input csr matrix.\n * @param sources Source nodes.\n * @return A Frontiers object containing the search result\n */\nFrontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);\n\n/**\n * @brief Traverse the graph in a breadth-first-search (BFS) order, returning\n *        the edges of the BFS tree.\n *\n * @param csr The input csr matrix.\n * @param sources Source nodes.\n * @return A Frontiers object containing the search result\n */\nFrontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);\n\n/**\n * @brief Traverse the graph in topological order.\n *\n * @param csr The input csr matrix.\n * @return A Frontiers object containing the search result\n */\nFrontiers TopologicalNodesFrontiers(const CSRMatrix& csr);\n\n/**\n * @brief Traverse the graph in a depth-first-search (DFS) order.\n *\n * @param csr The input csr matrix.\n * @param sources Source nodes.\n * @return A Frontiers object containing the search result\n 
*/\nFrontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);\n\n/**\n * @brief Traverse the graph in a depth-first-search (DFS) order and return the\n *        recorded edge tag if return_labels is specified.\n *\n * The traversal visits edges in DFS order. Edges have three tags:\n * FORWARD(0), REVERSE(1), NONTREE(2)\n *\n * A FORWARD edge is one in which `u` has been visited but `v` has not.\n * A REVERSE edge is one in which both `u` and `v` have been visited and the\n * edge is in the DFS tree.\n * A NONTREE edge is one in which both `u` and `v` have been visited but the\n * edge is NOT in the DFS tree.\n *\n * @param csr The input csr matrix.\n * @param source Source nodes.\n * @param has_reverse_edge If true, REVERSE edges are included\n * @param has_nontree_edge If true, NONTREE edges are included\n * @param return_labels If true, return the recorded edge tags.\n * @return A Frontiers object containing the search result\n */\nFrontiers DGLDFSLabeledEdges(\n    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,\n    const bool has_nontree_edge, const bool return_labels);\n\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_TRAVERSAL_H_\n"
  },
  {
    "path": "include/dgl/immutable_graph.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file dgl/immutable_graph.h\n * @brief DGL immutable graph index class.\n */\n#ifndef DGL_IMMUTABLE_GRAPH_H_\n#define DGL_IMMUTABLE_GRAPH_H_\n\n#include <algorithm>\n#include <cstdint>\n#include <memory>\n#include <string>\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"base_heterograph.h\"\n#include \"graph_interface.h\"\n#include \"lazy.h\"\n#include \"runtime/ndarray.h\"\n\nnamespace dgl {\n\nclass CSR;\nclass COO;\ntypedef std::shared_ptr<CSR> CSRPtr;\ntypedef std::shared_ptr<COO> COOPtr;\n\nclass ImmutableGraph;\ntypedef std::shared_ptr<ImmutableGraph> ImmutableGraphPtr;\n\n/**\n * @brief Graph class stored using CSR structure.\n */\nclass CSR : public GraphInterface {\n public:\n  // Create a csr graph that has the given number of verts and edges.\n  CSR(int64_t num_vertices, int64_t num_edges);\n  // Create a csr graph whose memory is stored in the shared memory\n  //   that has the given number of verts and edges.\n  CSR(const std::string &shared_mem_name, int64_t num_vertices,\n      int64_t num_edges);\n\n  // Create a csr graph that shares the given indptr and indices.\n  CSR(IdArray indptr, IdArray indices, IdArray edge_ids);\n\n  // Create a csr graph by data iterator\n  template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>\n  CSR(int64_t num_vertices, int64_t num_edges, IndptrIter indptr_begin,\n      IndicesIter indices_begin, EdgeIdIter edge_ids_begin);\n\n  // Create a csr graph whose memory is stored in the shared memory\n  //   and the structure is given by the indptr and indcies.\n  CSR(IdArray indptr, IdArray indices, IdArray edge_ids,\n      const std::string &shared_mem_name);\n\n  void AddVertices(uint64_t num_vertices) override {\n    LOG(FATAL) << \"CSR graph does not allow mutation.\";\n  }\n\n  void AddEdge(dgl_id_t src, dgl_id_t dst) override {\n    LOG(FATAL) << \"CSR graph does not allow mutation.\";\n  }\n\n  void 
AddEdges(IdArray src_ids, IdArray dst_ids) override {\n    LOG(FATAL) << \"CSR graph does not allow mutation.\";\n  }\n\n  void Clear() override { LOG(FATAL) << \"CSR graph does not allow mutation.\"; }\n\n  DGLContext Context() const override { return adj_.indptr->ctx; }\n\n  uint8_t NumBits() const override { return adj_.indices->dtype.bits; }\n\n  bool IsMultigraph() const override;\n\n  bool IsReadonly() const override { return true; }\n\n  uint64_t NumVertices() const override { return adj_.indptr->shape[0] - 1; }\n\n  uint64_t NumEdges() const override { return adj_.indices->shape[0]; }\n\n  BoolArray HasVertices(IdArray vids) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph\";\n    return {};\n  }\n\n  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override;\n\n  BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override;\n\n  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {\n    LOG(FATAL) << \"CSR graph does not support efficient predecessor query.\"\n               << \" Please use successors on the reverse CSR graph.\";\n    return {};\n  }\n\n  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override;\n\n  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override;\n\n  EdgeArray EdgeIds(IdArray src, IdArray dst) const override;\n\n  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {\n    LOG(FATAL) << \"CSR graph does not support efficient FindEdge.\"\n               << \" Please use COO graph.\";\n    return {};\n  }\n\n  EdgeArray FindEdges(IdArray eids) const override {\n    LOG(FATAL) << \"CSR graph does not support efficient FindEdges.\"\n               << \" Please use COO graph.\";\n    return {};\n  }\n\n  EdgeArray InEdges(dgl_id_t vid) const override {\n    LOG(FATAL) << \"CSR graph does not support efficient inedges query.\"\n               << \" Please use outedges on the reverse CSR graph.\";\n    return {};\n  }\n\n  EdgeArray InEdges(IdArray vids) 
const override {\n    LOG(FATAL) << \"CSR graph does not support efficient inedges query.\"\n               << \" Please use outedges on the reverse CSR graph.\";\n    return {};\n  }\n\n  EdgeArray OutEdges(dgl_id_t vid) const override;\n\n  EdgeArray OutEdges(IdArray vids) const override;\n\n  EdgeArray Edges(const std::string &order = \"\") const override;\n\n  uint64_t InDegree(dgl_id_t vid) const override {\n    LOG(FATAL) << \"CSR graph does not support efficient indegree query.\"\n               << \" Please use outdegree on the reverse CSR graph.\";\n    return 0;\n  }\n\n  DegreeArray InDegrees(IdArray vids) const override {\n    LOG(FATAL) << \"CSR graph does not support efficient indegree query.\"\n               << \" Please use outdegree on the reverse CSR graph.\";\n    return {};\n  }\n\n  uint64_t OutDegree(dgl_id_t vid) const override {\n    return aten::CSRGetRowNNZ(adj_, vid);\n  }\n\n  DegreeArray OutDegrees(IdArray vids) const override;\n\n  Subgraph VertexSubgraph(IdArray vids) const override;\n\n  Subgraph EdgeSubgraph(\n      IdArray eids, bool preserve_nodes = false) const override {\n    LOG(FATAL) << \"CSR graph does not support efficient EdgeSubgraph.\"\n               << \" Please use COO graph instead.\";\n    return {};\n  }\n\n  DGLIdIters SuccVec(dgl_id_t vid) const override;\n\n  DGLIdIters OutEdgeVec(dgl_id_t vid) const override;\n\n  DGLIdIters PredVec(dgl_id_t vid) const override {\n    LOG(FATAL) << \"CSR graph does not support efficient PredVec.\"\n               << \" Please use SuccVec on the reverse CSR graph.\";\n    return DGLIdIters(nullptr, nullptr);\n  }\n\n  DGLIdIters InEdgeVec(dgl_id_t vid) const override {\n    LOG(FATAL) << \"CSR graph does not support efficient InEdgeVec.\"\n               << \" Please use OutEdgeVec on the reverse CSR graph.\";\n    return DGLIdIters(nullptr, nullptr);\n  }\n\n  std::vector<IdArray> GetAdj(\n      bool transpose, const std::string &fmt) const override {\n    CHECK(!transpose && 
fmt == \"csr\") << \"Not valid adj format request.\";\n    return {adj_.indptr, adj_.indices, adj_.data};\n  }\n\n  /** @brief Indicate whether this uses shared memory. */\n  bool IsSharedMem() const { return !shared_mem_name_.empty(); }\n\n  /** @brief Return the reverse of this CSR graph (i.e, a CSC graph) */\n  CSRPtr Transpose() const;\n\n  /** @brief Convert this CSR to COO */\n  COOPtr ToCOO() const;\n\n  /**\n   * @return the csr matrix that represents this graph.\n   * @note The csr matrix shares the storage with this graph.\n   *       The data field of the CSR matrix stores the edge ids.\n   */\n  aten::CSRMatrix ToCSRMatrix() const { return adj_; }\n\n  /**\n   * @brief Copy the data to another context.\n   * @param ctx The target context.\n   * @return The graph under another context.\n   */\n  CSR CopyTo(const DGLContext &ctx) const;\n\n  /**\n   * @brief Copy data to shared memory.\n   * @param name The name of the shared memory.\n   * @return The graph in the shared memory\n   */\n  CSR CopyToSharedMem(const std::string &name) const;\n\n  /**\n   * @brief Convert the graph to use the given number of bits for storage.\n   * @param bits The new number of integer bits (32 or 64).\n   * @return The graph with new bit size storage.\n   */\n  CSR AsNumBits(uint8_t bits) const;\n\n  // member getters\n\n  IdArray indptr() const { return adj_.indptr; }\n\n  IdArray indices() const { return adj_.indices; }\n\n  IdArray edge_ids() const { return adj_.data; }\n\n  /** @return Load CSR from stream */\n  bool Load(dmlc::Stream *fs);\n\n  /** @return Save CSR to stream */\n  void Save(dmlc::Stream *fs) const;\n\n  void SortCSR() override {\n    if (adj_.sorted) return;\n    aten::CSRSort_(&adj_);\n  }\n\n private:\n  friend class Serializer;\n\n  /** @brief private default constructor */\n  CSR() { adj_.sorted = false; }\n  // The internal CSR adjacency matrix.\n  // The data field stores edge ids.\n  aten::CSRMatrix adj_;\n\n  // The name of the shared memory to 
store data.\n  // If it's empty, data isn't stored in shared memory.\n  std::string shared_mem_name_;\n};\n\nclass COO : public GraphInterface {\n public:\n  // Create a coo graph that shares the given src and dst\n  COO(int64_t num_vertices, IdArray src, IdArray dst, bool row_sorted = false,\n      bool col_sorted = false);\n\n  // TODO(da): add constructor for creating COO from shared memory\n\n  void AddVertices(uint64_t num_vertices) override {\n    LOG(FATAL) << \"COO graph does not allow mutation.\";\n  }\n\n  void AddEdge(dgl_id_t src, dgl_id_t dst) override {\n    LOG(FATAL) << \"COO graph does not allow mutation.\";\n  }\n\n  void AddEdges(IdArray src_ids, IdArray dst_ids) override {\n    LOG(FATAL) << \"COO graph does not allow mutation.\";\n  }\n\n  void Clear() override { LOG(FATAL) << \"COO graph does not allow mutation.\"; }\n\n  DGLContext Context() const override { return adj_.row->ctx; }\n\n  uint8_t NumBits() const override { return adj_.row->dtype.bits; }\n\n  bool IsMultigraph() const override;\n\n  bool IsReadonly() const override { return true; }\n\n  uint64_t NumVertices() const override { return adj_.num_rows; }\n\n  uint64_t NumEdges() const override { return adj_.row->shape[0]; }\n\n  bool HasVertex(dgl_id_t vid) const override { return vid < NumVertices(); }\n\n  BoolArray HasVertices(IdArray vids) const override {\n    LOG(FATAL) << \"Not enabled for COO graph\";\n    return {};\n  }\n\n  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {\n    LOG(FATAL) << \"COO graph does not support efficient HasEdgeBetween.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return false;\n  }\n\n  BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override {\n    LOG(FATAL) << \"COO graph does not support efficient HasEdgeBetween.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const 
override {\n    LOG(FATAL) << \"COO graph does not support efficient Predecessors.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {\n    LOG(FATAL) << \"COO graph does not support efficient Successors.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {\n    LOG(FATAL) << \"COO graph does not support efficient EdgeId.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  EdgeArray EdgeIds(IdArray src, IdArray dst) const override {\n    LOG(FATAL) << \"COO graph does not support efficient EdgeId.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override;\n\n  EdgeArray FindEdges(IdArray eids) const override;\n\n  EdgeArray InEdges(dgl_id_t vid) const override {\n    LOG(FATAL) << \"COO graph does not support efficient InEdges.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  EdgeArray InEdges(IdArray vids) const override {\n    LOG(FATAL) << \"COO graph does not support efficient InEdges.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  EdgeArray OutEdges(dgl_id_t vid) const override {\n    LOG(FATAL) << \"COO graph does not support efficient OutEdges.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  EdgeArray OutEdges(IdArray vids) const override {\n    LOG(FATAL) << \"COO graph does not support efficient OutEdges.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  EdgeArray Edges(const std::string &order = \"\") const override;\n\n  uint64_t InDegree(dgl_id_t vid) const override {\n    
LOG(FATAL) << \"COO graph does not support efficient InDegree.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return 0;\n  }\n\n  DegreeArray InDegrees(IdArray vids) const override {\n    LOG(FATAL) << \"COO graph does not support efficient InDegrees.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  uint64_t OutDegree(dgl_id_t vid) const override {\n    LOG(FATAL) << \"COO graph does not support efficient OutDegree.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return 0;\n  }\n\n  DegreeArray OutDegrees(IdArray vids) const override {\n    LOG(FATAL) << \"COO graph does not support efficient OutDegrees.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  Subgraph VertexSubgraph(IdArray vids) const override {\n    LOG(FATAL) << \"COO graph does not support efficient VertexSubgraph.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return {};\n  }\n\n  Subgraph EdgeSubgraph(\n      IdArray eids, bool preserve_nodes = false) const override;\n\n  DGLIdIters SuccVec(dgl_id_t vid) const override {\n    LOG(FATAL) << \"COO graph does not support efficient SuccVec.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return DGLIdIters(nullptr, nullptr);\n  }\n\n  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {\n    LOG(FATAL) << \"COO graph does not support efficient OutEdgeVec.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return DGLIdIters(nullptr, nullptr);\n  }\n\n  DGLIdIters PredVec(dgl_id_t vid) const override {\n    LOG(FATAL) << \"COO graph does not support efficient PredVec.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return DGLIdIters(nullptr, nullptr);\n  }\n\n  DGLIdIters InEdgeVec(dgl_id_t vid) const override {\n    LOG(FATAL) << \"COO graph does not support efficient 
InEdgeVec.\"\n               << \" Please use CSR graph or AdjList graph instead.\";\n    return DGLIdIters(nullptr, nullptr);\n  }\n\n  std::vector<IdArray> GetAdj(\n      bool transpose, const std::string &fmt) const override {\n    CHECK(fmt == \"coo\") << \"Not valid adj format request.\";\n    if (transpose) {\n      return {aten::HStack(adj_.col, adj_.row)};\n    } else {\n      return {aten::HStack(adj_.row, adj_.col)};\n    }\n  }\n\n  /** @brief Return the transpose of this COO */\n  COOPtr Transpose() const {\n    return COOPtr(new COO(adj_.num_rows, adj_.col, adj_.row));\n  }\n\n  /** @brief Convert this COO to CSR */\n  CSRPtr ToCSR() const;\n\n  /**\n   * @brief Get the coo matrix that represents this graph.\n   * @note The coo matrix shares the storage with this graph.\n   *       The data field of the coo matrix is none.\n   */\n  aten::COOMatrix ToCOOMatrix() const { return adj_; }\n\n  /**\n   * @brief Copy the data to another context.\n   * @param ctx The target context.\n   * @return The graph under another context.\n   */\n  COO CopyTo(const DGLContext &ctx) const;\n\n  /**\n   * @brief Copy data to shared memory.\n   * @param name The name of the shared memory.\n   * @return The graph in the shared memory\n   */\n  COO CopyToSharedMem(const std::string &name) const;\n\n  /**\n   * @brief Convert the graph to use the given number of bits for storage.\n   * @param bits The new number of integer bits (32 or 64).\n   * @return The graph with new bit size storage.\n   */\n  COO AsNumBits(uint8_t bits) const;\n\n  /** @brief Indicate whether this uses shared memory. 
*/\n  bool IsSharedMem() const { return false; }\n\n  // member getters\n\n  IdArray src() const { return adj_.row; }\n\n  IdArray dst() const { return adj_.col; }\n\n private:\n  /** @brief private default constructor */\n  COO() {}\n\n  // The internal COO adjacency matrix.\n  // The data field is empty\n  aten::COOMatrix adj_;\n};\n\n/**\n * @brief DGL immutable graph index class.\n *\n * DGL's graph is directed. Vertices are integers enumerated from zero.\n */\nclass ImmutableGraph : public GraphInterface {\n public:\n  /** @brief Construct an immutable graph from the COO format. */\n  explicit ImmutableGraph(COOPtr coo) : coo_(coo) {}\n\n  /**\n   * @brief Construct an immutable graph from the CSR format.\n   *\n   * For a single graph, we need two CSRs, one stores the in-edges of vertices\n   * and the other stores the out-edges of vertices. These two CSRs store the\n   * same edges. The reason we need both is that some operators are faster on\n   * in-edge CSR and the other operators are faster on out-edge CSR.\n   *\n   * However, not both CSRs are required. Technically, one CSR contains all\n   * information. Thus, when we construct a temporary graph (e.g., the sampled\n   * subgraphs), we only construct one of the CSRs that runs fast for some\n   * operations we expect and construct the other CSR on demand.\n   */\n  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr)\n      : in_csr_(in_csr), out_csr_(out_csr) {\n    CHECK(in_csr_ || out_csr_) << \"Both CSR are missing.\";\n  }\n\n  /** @brief Construct an immutable graph from one CSR. 
*/\n  explicit ImmutableGraph(CSRPtr csr) : out_csr_(csr) {}\n\n  /** @brief default copy constructor */\n  ImmutableGraph(const ImmutableGraph &other) = default;\n\n#ifndef _MSC_VER\n  /** @brief default move constructor */\n  ImmutableGraph(ImmutableGraph &&other) = default;\n#else\n  ImmutableGraph(ImmutableGraph &&other) {\n    this->in_csr_ = other.in_csr_;\n    this->out_csr_ = other.out_csr_;\n    this->coo_ = other.coo_;\n    other.in_csr_ = nullptr;\n    other.out_csr_ = nullptr;\n    other.coo_ = nullptr;\n  }\n#endif  // _MSC_VER\n\n  /** @brief default assign constructor */\n  ImmutableGraph &operator=(const ImmutableGraph &other) = default;\n\n  /** @brief default destructor */\n  ~ImmutableGraph() = default;\n\n  void AddVertices(uint64_t num_vertices) override {\n    LOG(FATAL) << \"AddVertices isn't supported in ImmutableGraph\";\n  }\n\n  void AddEdge(dgl_id_t src, dgl_id_t dst) override {\n    LOG(FATAL) << \"AddEdge isn't supported in ImmutableGraph\";\n  }\n\n  void AddEdges(IdArray src_ids, IdArray dst_ids) override {\n    LOG(FATAL) << \"AddEdges isn't supported in ImmutableGraph\";\n  }\n\n  void Clear() override {\n    LOG(FATAL) << \"Clear isn't supported in ImmutableGraph\";\n  }\n\n  DGLContext Context() const override { return AnyGraph()->Context(); }\n\n  uint8_t NumBits() const override { return AnyGraph()->NumBits(); }\n\n  /**\n   * @note not const since we have caches\n   * @return whether the graph is a multigraph\n   */\n  bool IsMultigraph() const override { return AnyGraph()->IsMultigraph(); }\n\n  /**\n   * @return whether the graph is read-only\n   */\n  bool IsReadonly() const override { return true; }\n\n  /**\n   * @brief Check if the graph is unibipartite.\n   *\n   * @return True if the graph is unibipartite.\n   */\n  bool IsUniBipartite() const override {\n    if (!is_unibipartite_set_) {\n      is_unibipartite_ = GraphInterface::IsUniBipartite();\n      is_unibipartite_set_ = true;\n    }\n\n    return 
is_unibipartite_;\n  }\n\n  /** @return the number of vertices in the graph.*/\n  uint64_t NumVertices() const override { return AnyGraph()->NumVertices(); }\n\n  /** @return the number of edges in the graph.*/\n  uint64_t NumEdges() const override { return AnyGraph()->NumEdges(); }\n\n  /** @return true if the given vertex is in the graph.*/\n  bool HasVertex(dgl_id_t vid) const override { return vid < NumVertices(); }\n\n  BoolArray HasVertices(IdArray vids) const override;\n\n  /** @return true if the given edge is in the graph.*/\n  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {\n    if (in_csr_) {\n      return in_csr_->HasEdgeBetween(dst, src);\n    } else {\n      return GetOutCSR()->HasEdgeBetween(src, dst);\n    }\n  }\n\n  BoolArray HasEdgesBetween(IdArray src, IdArray dst) const override {\n    if (in_csr_) {\n      return in_csr_->HasEdgesBetween(dst, src);\n    } else {\n      return GetOutCSR()->HasEdgesBetween(src, dst);\n    }\n  }\n\n  /**\n   * @brief Find the predecessors of a vertex.\n   * @param vid The vertex id.\n   * @param radius The radius of the neighborhood. Default is immediate neighbor\n   *        (radius=1).\n   * @return the predecessor id array.\n   */\n  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {\n    return GetInCSR()->Successors(vid, radius);\n  }\n\n  /**\n   * @brief Find the successors of a vertex.\n   * @param vid The vertex id.\n   * @param radius The radius of the neighborhood. 
Default is immediate neighbor\n   *        (radius=1).\n   * @return the successor id array.\n   */\n  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {\n    return GetOutCSR()->Successors(vid, radius);\n  }\n\n  /**\n   * @brief Get all edge ids between the two given endpoints\n   * @note Edges are associated with an integer id start from zero.\n   *       The id is assigned when the edge is being added to the graph.\n   * @param src The source vertex.\n   * @param dst The destination vertex.\n   * @return the edge id array.\n   */\n  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {\n    if (in_csr_) {\n      return in_csr_->EdgeId(dst, src);\n    } else {\n      return GetOutCSR()->EdgeId(src, dst);\n    }\n  }\n\n  /**\n   * @brief Get all edge ids between the given endpoint pairs.\n   * @note Edges are associated with an integer id start from zero.\n   *       The id is assigned when the edge is being added to the graph.\n   *       If duplicate pairs exist, the returned edge IDs will also duplicate.\n   *       The order of returned edge IDs will follow the order of src-dst pairs\n   *       first, and ties are broken by the order of edge ID.\n   * @return EdgeArray containing all edges between all pairs.\n   */\n  EdgeArray EdgeIds(IdArray src, IdArray dst) const override {\n    if (in_csr_) {\n      EdgeArray edges = in_csr_->EdgeIds(dst, src);\n      return EdgeArray{edges.dst, edges.src, edges.id};\n    } else {\n      return GetOutCSR()->EdgeIds(src, dst);\n    }\n  }\n\n  /**\n   * @brief Find the edge ID and return the pair of endpoints\n   * @param eid The edge ID\n   * @return a pair whose first element is the source and the second the\n   *         destination.\n   */\n  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {\n    return GetCOO()->FindEdge(eid);\n  }\n\n  /**\n   * @brief Find the edge IDs and return their source and target node IDs.\n   * @param eids The edge ID array.\n   * @return 
EdgeArray containing all edges with id in eid.  The order is\n   *         preserved.\n   */\n  EdgeArray FindEdges(IdArray eids) const override {\n    return GetCOO()->FindEdges(eids);\n  }\n\n  /**\n   * @brief Get the in edges of the vertex.\n   * @note The returned dst id array is filled with vid.\n   * @param vid The vertex id.\n   * @return the edges\n   */\n  EdgeArray InEdges(dgl_id_t vid) const override {\n    const EdgeArray &ret = GetInCSR()->OutEdges(vid);\n    return {ret.dst, ret.src, ret.id};\n  }\n\n  /**\n   * @brief Get the in edges of the vertices.\n   * @param vids The vertex id array.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  EdgeArray InEdges(IdArray vids) const override {\n    const EdgeArray &ret = GetInCSR()->OutEdges(vids);\n    return {ret.dst, ret.src, ret.id};\n  }\n\n  /**\n   * @brief Get the out edges of the vertex.\n   * @note The returned src id array is filled with vid.\n   * @param vid The vertex id.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  EdgeArray OutEdges(dgl_id_t vid) const override {\n    return GetOutCSR()->OutEdges(vid);\n  }\n\n  /**\n   * @brief Get the out edges of the vertices.\n   * @param vids The vertex id array.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  EdgeArray OutEdges(IdArray vids) const override {\n    return GetOutCSR()->OutEdges(vids);\n  }\n\n  /**\n   * @brief Get all the edges in the graph.\n   * @note If sorted is true, the returned edges list is sorted by their src and\n   *       dst ids. 
Otherwise, they are in their edge id order.\n   * @param sorted Whether the returned edge list is sorted by their src and dst\n   *        ids.\n   * @return the id arrays of the two endpoints of the edges.\n   */\n  EdgeArray Edges(const std::string &order = \"\") const override;\n\n  /**\n   * @brief Get the in degree of the given vertex.\n   * @param vid The vertex id.\n   * @return the in degree\n   */\n  uint64_t InDegree(dgl_id_t vid) const override {\n    return GetInCSR()->OutDegree(vid);\n  }\n\n  /**\n   * @brief Get the in degrees of the given vertices.\n   * @param vid The vertex id array.\n   * @return the in degree array\n   */\n  DegreeArray InDegrees(IdArray vids) const override {\n    return GetInCSR()->OutDegrees(vids);\n  }\n\n  /**\n   * @brief Get the out degree of the given vertex.\n   * @param vid The vertex id.\n   * @return the out degree\n   */\n  uint64_t OutDegree(dgl_id_t vid) const override {\n    return GetOutCSR()->OutDegree(vid);\n  }\n\n  /**\n   * @brief Get the out degrees of the given vertices.\n   * @param vid The vertex id array.\n   * @return the out degree array\n   */\n  DegreeArray OutDegrees(IdArray vids) const override {\n    return GetOutCSR()->OutDegrees(vids);\n  }\n\n  /**\n   * @brief Construct the induced subgraph of the given vertices.\n   *\n   * The induced subgraph is a subgraph formed by specifying a set of vertices\n   * V' and then selecting all of the edges from the original graph that connect\n   * two vertices in V'.\n   *\n   * Vertices and edges in the original graph will be \"reindexed\" to local\n   * index. The local index of the vertices preserve the order of the given id\n   * array, while the local index of the edges preserve the index order in the\n   * original graph. 
Vertices not in the original graph are ignored.\n   *\n   * The result subgraph is read-only.\n   *\n   * @param vids The vertices in the subgraph.\n   * @return the induced subgraph\n   */\n  Subgraph VertexSubgraph(IdArray vids) const override;\n\n  /**\n   * @brief Construct the induced edge subgraph of the given edges.\n   *\n   * The induced edges subgraph is a subgraph formed by specifying a set of\n   * edges E' and then selecting all of the nodes from the original graph that\n   * are endpoints in E'.\n   *\n   * Vertices and edges in the original graph will be \"reindexed\" to local\n   * index. The local index of the edges preserve the order of the given id\n   * array, while the local index of the vertices preserve the index order in\n   * the original graph. Edges not in the original graph are ignored.\n   *\n   * The result subgraph is read-only.\n   *\n   * @param eids The edges in the subgraph.\n   * @return the induced edge subgraph\n   */\n  Subgraph EdgeSubgraph(\n      IdArray eids, bool preserve_nodes = false) const override;\n\n  /**\n   * @brief Return the successor vector\n   * @param vid The vertex id.\n   * @return the successor vector\n   */\n  DGLIdIters SuccVec(dgl_id_t vid) const override {\n    return GetOutCSR()->SuccVec(vid);\n  }\n\n  /**\n   * @brief Return the out edge id vector\n   * @param vid The vertex id.\n   * @return the out edge id vector\n   */\n  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {\n    return GetOutCSR()->OutEdgeVec(vid);\n  }\n\n  /**\n   * @brief Return the predecessor vector\n   * @param vid The vertex id.\n   * @return the predecessor vector\n   */\n  DGLIdIters PredVec(dgl_id_t vid) const override {\n    return GetInCSR()->SuccVec(vid);\n  }\n\n  /**\n   * @brief Return the in edge id vector\n   * @param vid The vertex id.\n   * @return the in edge id vector\n   */\n  DGLIdIters InEdgeVec(dgl_id_t vid) const override {\n    return GetInCSR()->OutEdgeVec(vid);\n  }\n\n  /**\n   * @brief Get the 
adjacency matrix of the graph.\n   *\n   * By default, a row of returned adjacency matrix represents the destination\n   * of an edge and the column represents the source.\n   * @param transpose A flag to transpose the returned adjacency matrix.\n   * @param fmt the format of the returned adjacency matrix.\n   * @return a vector of three IdArray.\n   */\n  std::vector<IdArray> GetAdj(\n      bool transpose, const std::string &fmt) const override;\n\n  /** @brief Return in csr. If not exist, transpose the other one.*/\n  CSRPtr GetInCSR() const;\n\n  /** @brief Return out csr. If not exist, transpose the other one.*/\n  CSRPtr GetOutCSR() const;\n\n  /** @brief Return coo. If not exist, create from csr.*/\n  COOPtr GetCOO() const;\n\n  /** @brief Create an immutable graph from CSR. */\n  static ImmutableGraphPtr CreateFromCSR(\n      IdArray indptr, IdArray indices, IdArray edge_ids,\n      const std::string &edge_dir);\n\n  static ImmutableGraphPtr CreateFromCSR(const std::string &shared_mem_name);\n\n  /** @brief Create an immutable graph from COO. */\n  static ImmutableGraphPtr CreateFromCOO(\n      int64_t num_vertices, IdArray src, IdArray dst, bool row_osrted = false,\n      bool col_sorted = false);\n\n  /**\n   * @brief Convert the given graph to an immutable graph.\n   *\n   * If the graph is already an immutable graph. 
The result graph will share\n   * the storage with the given one.\n   *\n   * @param graph The input graph.\n   * @return an immutable graph object.\n   */\n  static ImmutableGraphPtr ToImmutable(GraphPtr graph);\n\n  /**\n   * @brief Copy the data to another context.\n   * @param ctx The target context.\n   * @return The graph under another context.\n   */\n  static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DGLContext &ctx);\n\n  /**\n   * @brief Copy data to shared memory.\n   * @param name The name of the shared memory.\n   * @return The graph in the shared memory\n   */\n  static ImmutableGraphPtr CopyToSharedMem(\n      ImmutableGraphPtr g, const std::string &name);\n\n  /**\n   * @brief Convert the graph to use the given number of bits for storage.\n   * @param bits The new number of integer bits (32 or 64).\n   * @return The graph with new bit size storage.\n   */\n  static ImmutableGraphPtr AsNumBits(ImmutableGraphPtr g, uint8_t bits);\n\n  /**\n   * @brief Return a new graph with all the edges reversed.\n   *\n   * The returned graph preserves the vertex and edge index in the original\n   * graph.\n   *\n   * @return the reversed graph\n   */\n  ImmutableGraphPtr Reverse() const;\n\n  /** @return Load ImmutableGraph from stream, using out csr */\n  bool Load(dmlc::Stream *fs);\n\n  /** @return Save ImmutableGraph to stream, using out csr */\n  void Save(dmlc::Stream *fs) const;\n\n  void SortCSR() override {\n    GetInCSR()->SortCSR();\n    GetOutCSR()->SortCSR();\n  }\n\n  bool HasInCSR() const { return in_csr_ != NULL; }\n\n  bool HasOutCSR() const { return out_csr_ != NULL; }\n\n  /** @brief Cast this graph to a heterograph */\n  HeteroGraphPtr AsHeteroGraph() const;\n\n protected:\n  friend class Serializer;\n  friend class UnitGraph;\n\n  /** @brief internal default constructor */\n  ImmutableGraph() {}\n\n  /** @brief internal constructor for all the members */\n  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo)\n      : 
in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {\n    CHECK(AnyGraph()) << \"At least one graph structure should exist.\";\n  }\n\n  ImmutableGraph(\n      CSRPtr in_csr, CSRPtr out_csr, const std::string shared_mem_name)\n      : in_csr_(in_csr), out_csr_(out_csr) {\n    CHECK(in_csr_ || out_csr_) << \"Both CSR are missing.\";\n    this->shared_mem_name_ = shared_mem_name;\n  }\n\n  /** @brief return pointer to any available graph structure */\n  GraphPtr AnyGraph() const {\n    if (in_csr_) {\n      return in_csr_;\n    } else if (out_csr_) {\n      return out_csr_;\n    } else {\n      return coo_;\n    }\n  }\n\n  // Store the in csr (i.e, the reverse csr)\n  CSRPtr in_csr_;\n  // Store the out csr (i.e, the normal csr)\n  CSRPtr out_csr_;\n  // Store the edge list indexed by edge id (COO)\n  COOPtr coo_;\n\n  // The name of shared memory for this graph.\n  // If it's empty, the graph isn't stored in shared memory.\n  std::string shared_mem_name_;\n  // We serialize the metadata of the graph index here for shared memory.\n  NDArray serialized_shared_meta_;\n\n  // Whether or not the `is_unibipartite_` property has been set.\n  mutable bool is_unibipartite_set_ = false;\n  // Whether this graph is unibipartite. 
If `is_unibipartite_set_` is false,\n  // then this flag should be considered in an uninitialized state.\n  mutable bool is_unibipartite_ = false;\n};\n\n// inline implementations\n\ntemplate <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>\nCSR::CSR(\n    int64_t num_vertices, int64_t num_edges, IndptrIter indptr_begin,\n    IndicesIter indices_begin, EdgeIdIter edge_ids_begin) {\n  // TODO(minjie): this should be changed to a device-agnostic implementation\n  // in the future.\n  adj_.num_rows = num_vertices;\n  adj_.num_cols = num_vertices;\n  adj_.indptr = aten::NewIdArray(num_vertices + 1);\n  adj_.indices = aten::NewIdArray(num_edges);\n  adj_.data = aten::NewIdArray(num_edges);\n  dgl_id_t *indptr_data = static_cast<dgl_id_t *>(adj_.indptr->data);\n  dgl_id_t *indices_data = static_cast<dgl_id_t *>(adj_.indices->data);\n  dgl_id_t *edge_ids_data = static_cast<dgl_id_t *>(adj_.data->data);\n  for (int64_t i = 0; i < num_vertices + 1; ++i)\n    *(indptr_data++) = *(indptr_begin++);\n  for (int64_t i = 0; i < num_edges; ++i) {\n    *(indices_data++) = *(indices_begin++);\n    *(edge_ids_data++) = *(edge_ids_begin++);\n  }\n}\n\n}  // namespace dgl\n\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, dgl::CSR, true);\nDMLC_DECLARE_TRAITS(has_saveload, dgl::ImmutableGraph, true);\n}  // namespace dmlc\n\n#endif  // DGL_IMMUTABLE_GRAPH_H_\n"
  },
  {
    "path": "include/dgl/kernel.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file dgl/kernel.h\n * @brief Sparse matrix operators.\n */\n#ifndef DGL_KERNEL_H_\n#define DGL_KERNEL_H_\n\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"./base_heterograph.h\"\n#include \"./bcast.h\"\n#include \"array.h\"\n\nnamespace dgl {\nnamespace aten {\n\n/**\n * @brief Generalized Sparse Matrix-Matrix Multiplication.\n * @param op The binary operator, could be `add`, `sub`, `mul`, `div`,\n *        `copy_u`, `copy_e`.\n * @param reduce The reduce operator, could be `sum`, `min`, `max`.\n * @param graph The graph we apply SpMM on.\n * @param ufeat The source node feature.\n * @param efeat The edge feature.\n * @param out The output feature on destination nodes.\n * @param out_aux A list of NDArray's that contains auxiliary information such\n *        as the argmax on source nodes and edges for reduce operators such as\n *        `min` and `max`.\n */\nvoid SpMM(\n    const std::string& op, const std::string& reduce, HeteroGraphPtr graph,\n    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);\n\n/**\n * @brief Generalized Sampled Dense-Dense Matrix Multiplication.\n * @param op The binary operator, could be `add`, `sub`, `mul`, `div`,\n *        `dot`, `copy_u`, `copy_e`.\n * @param graph The graph we apply SDDMM on.\n * @param ufeat The source node feature.\n * @param vfeat The destination node feature.\n * @param out The output feature on edge.\n */\nvoid SDDMM(\n    const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat,\n    NDArray out);\n\n/**\n * @brief Sparse-sparse matrix multiplication.\n *\n * The sparse matrices must have scalar weights (i.e. 
\\a A_weights and \\a\n * B_weights are 1D vectors.)\n */\nstd::pair<CSRMatrix, NDArray> CSRMM(\n    CSRMatrix A, NDArray A_weights, CSRMatrix B, NDArray B_weights);\n\n/**\n * @brief Summing up a list of sparse matrices.\n *\n * The sparse matrices must have scalar weights (i.e. the arrays in \\a A_weights\n * are 1D vectors.)\n */\nstd::pair<CSRMatrix, NDArray> CSRSum(\n    const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights);\n\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_KERNEL_H_\n"
  },
  {
    "path": "include/dgl/lazy.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file dgl/lazy.h\n * @brief Lazy object that will be materialized only when being queried.\n */\n#ifndef DGL_LAZY_H_\n#define DGL_LAZY_H_\n\n#include <memory>\n\nnamespace dgl {\n\n/**\n * @brief Lazy object that will be materialized only when being queried.\n *\n * The object should be immutable -- no mutation once materialized.\n * The object is currently not thread safe.\n */\ntemplate <typename T>\nclass Lazy {\n public:\n  /** @brief default constructor to construct a lazy object */\n  Lazy() {}\n\n  /**\n   * @brief constructor to construct an object with given value (non-lazy case)\n   */\n  explicit Lazy(const T& val) : ptr_(new T(val)) {}\n\n  /** @brief destructor */\n  ~Lazy() = default;\n\n  /**\n   * @brief Get the value of this object. If the object has not been\n   *        instantiated, use the provided function to create it.\n   * @param fn The creator function.\n   * @return the object value.\n   */\n  template <typename Fn>\n  const T& Get(Fn fn) {\n    if (!ptr_) {\n      ptr_.reset(new T(fn()));\n    }\n    return *ptr_;\n  }\n\n private:\n  /** @brief the internal data pointer */\n  std::shared_ptr<T> ptr_{nullptr};\n};\n\n}  // namespace dgl\n\n#endif  // DGL_LAZY_H_\n"
  },
  {
    "path": "include/dgl/nodeflow.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file dgl/nodeflow.h\n * @brief DGL NodeFlow class.\n */\n#ifndef DGL_NODEFLOW_H_\n#define DGL_NODEFLOW_H_\n\n#include <memory>\n#include <string>\n#include <vector>\n\n#include \"./runtime/object.h\"\n#include \"graph_interface.h\"\n\nnamespace dgl {\n\nclass ImmutableGraph;\n\n/**\n * @brief A NodeFlow graph stores the sampling results for a sampler that\n * samples nodes/edges in layers.\n *\n * We store multiple layers of the sampling results in a single graph, which\n * results in a more compact format. We store extra information, such as the\n * node and edge mapping from the NodeFlow graph to the parent graph.\n */\nstruct NodeFlowObject : public runtime::Object {\n  /** @brief The graph. */\n  GraphPtr graph;\n  /**\n   * @brief the offsets of each layer.\n   */\n  IdArray layer_offsets;\n  /**\n   * @brief the offsets of each flow.\n   */\n  IdArray flow_offsets;\n  /**\n   * @brief The node mapping from the NodeFlow graph to the parent graph.\n   */\n  IdArray node_mapping;\n  /**\n   * @brief The edge mapping from the NodeFlow graph to the parent graph.\n   */\n  IdArray edge_mapping;\n\n  static constexpr const char *_type_key = \"graph.NodeFlow\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(NodeFlowObject, runtime::Object);\n};\n\n// Define NodeFlow as the reference class of NodeFlowObject\nclass NodeFlow : public runtime::ObjectRef {\n public:\n  DGL_DEFINE_OBJECT_REF_METHODS(NodeFlow, runtime::ObjectRef, NodeFlowObject);\n\n  /** @brief create a new nodeflow reference */\n  static NodeFlow Create() {\n    return NodeFlow(std::make_shared<NodeFlowObject>());\n  }\n};\n\n/**\n * @brief Get a slice on a graph that represents a NodeFlow.\n *\n * The entire block has to be taken as a slice. Users have to specify the\n * correct starting and ending location of a layer.\n *\n * If remap is false, the returned arrays can be viewed as a sub-matrix slice\n * of the adjmat of the input graph. 
Let the adjmat of the input graph be A,\n * then the slice is equal to (in numpy syntax):\n *   A[layer1_start:layer1_end, layer0_start:layer0_end]\n *\n * If remap is true, the returned arrays represent an adjacency matrix\n * of shape NxM, where N is the number of nodes in layer1 and M is\n * the number of nodes in layer0. Nodes in layer0 will be remapped to\n * [0, M) and nodes in layer1 will be remapped to [0, N).\n *\n * A row of the returned adjacency matrix represents the destination\n * of an edge and the column represents the source.\n *\n * If fmt == \"csr\", the function returns three arrays: indptr, indices, eid.\n * If fmt == \"coo\", the function returns two arrays: idx, eid. Here, the idx\n * array is the concatenation of src and dst node id arrays.\n *\n * @param graph An immutable graph.\n * @param fmt the format of the returned adjacency matrix.\n * @param layer0_size the size of the first layer in the block.\n * @param layer1_start the location where the second layer starts.\n * @param layer1_end the location where the second layer ends.\n * @param remap Indicates to remap all vertex ids and edge Ids to local Id\n * space.\n * @return a vector of IdArrays.\n */\nstd::vector<IdArray> GetNodeFlowSlice(\n    const ImmutableGraph &graph, const std::string &fmt, size_t layer0_size,\n    size_t layer1_start, size_t layer1_end, bool remap);\n\n}  // namespace dgl\n\n#endif  // DGL_NODEFLOW_H_\n"
  },
  {
    "path": "include/dgl/packed_func_ext.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file packed_func_ext.h\n * @brief Extension package to PackedFunc\n *   This enables pass ObjectRef types into/from PackedFunc.\n */\n#ifndef DGL_PACKED_FUNC_EXT_H_\n#define DGL_PACKED_FUNC_EXT_H_\n\n#include <memory>\n#include <sstream>\n#include <string>\n#include <type_traits>\n\n#include \"./runtime/container.h\"\n#include \"./runtime/object.h\"\n#include \"./runtime/packed_func.h\"\n\nnamespace dgl {\nnamespace runtime {\n/**\n * @brief Runtime type checker for node type.\n * @tparam T the type to be checked.\n */\ntemplate <typename T>\nstruct ObjectTypeChecker {\n  static inline bool Check(Object* sptr) {\n    // This is the only place in the project where RTTI is used\n    // It can be turned off, but will make non strict checking.\n    // TODO(tqchen) possibly find alternative to turn of RTTI\n    using ContainerType = typename T::ContainerType;\n    return sptr->derived_from<ContainerType>();\n  }\n  static inline void PrintName(std::ostringstream& os) {  // NOLINT(*)\n    using ContainerType = typename T::ContainerType;\n    os << ContainerType::_type_key;\n  }\n};\n\ntemplate <typename T>\nstruct ObjectTypeChecker<List<T> > {\n  static inline bool Check(Object* sptr) {\n    if (sptr == nullptr) return false;\n    if (!sptr->is_type<ListObject>()) return false;\n    ListObject* n = static_cast<ListObject*>(sptr);\n    for (const auto& p : n->data) {\n      if (!ObjectTypeChecker<T>::Check(p.get())) return false;\n    }\n    return true;\n  }\n  static inline void PrintName(std::ostringstream& os) {  // NOLINT(*)\n    os << \"list<\";\n    ObjectTypeChecker<T>::PrintName(os);\n    os << \">\";\n  }\n};\n\ntemplate <typename V>\nstruct ObjectTypeChecker<Map<std::string, V> > {\n  static inline bool Check(Object* sptr) {\n    if (sptr == nullptr) return false;\n    if (!sptr->is_type<StrMapObject>()) return false;\n    StrMapObject* n = static_cast<StrMapObject*>(sptr);\n    for (const auto& 
kv : n->data) {\n      if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;\n    }\n    return true;\n  }\n  static inline void PrintName(std::ostringstream& os) {  // NOLINT(*)\n    os << \"map<string\";\n    os << ',';\n    ObjectTypeChecker<V>::PrintName(os);\n    os << '>';\n  }\n};\n\ntemplate <typename K, typename V>\nstruct ObjectTypeChecker<Map<K, V> > {\n  static inline bool Check(Object* sptr) {\n    if (sptr == nullptr) return false;\n    if (!sptr->is_type<MapObject>()) return false;\n    MapObject* n = static_cast<MapObject*>(sptr);\n    for (const auto& kv : n->data) {\n      if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;\n      if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;\n    }\n    return true;\n  }\n  static inline void PrintName(std::ostringstream& os) {  // NOLINT(*)\n    os << \"map<\";\n    ObjectTypeChecker<K>::PrintName(os);\n    os << ',';\n    ObjectTypeChecker<V>::PrintName(os);\n    os << '>';\n  }\n};\n\ntemplate <typename T>\ninline std::string NodeTypeName() {\n  std::ostringstream os;\n  ObjectTypeChecker<T>::PrintName(os);\n  return os.str();\n}\n\n// extensions for DGLArgValue\n\ntemplate <typename TObjectRef>\ninline TObjectRef DGLArgValue::AsObjectRef() const {\n  static_assert(\n      std::is_base_of<ObjectRef, TObjectRef>::value,\n      \"Conversion only works for ObjectRef derived class\");\n  if (type_code_ == kNull) return TObjectRef();\n  DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);\n  std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >();\n  CHECK(ObjectTypeChecker<TObjectRef>::Check(sptr.get()))\n      << \"Expected type \" << NodeTypeName<TObjectRef>() << \" but get \"\n      << sptr->type_key();\n  return TObjectRef(sptr);\n}\n\ninline std::shared_ptr<Object>& DGLArgValue::obj_sptr() {\n  DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);\n  return *ptr<std::shared_ptr<Object> >();\n}\n\ntemplate <typename TObjectRef, typename>\ninline bool 
DGLArgValue::IsObjectType() const {\n  DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);\n  std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >();\n  return ObjectTypeChecker<TObjectRef>::Check(sptr.get());\n}\n\n// extensions for DGLRetValue\n\ninline DGLRetValue& DGLRetValue::operator=(\n    const std::shared_ptr<Object>& other) {\n  if (other.get() == nullptr) {\n    SwitchToPOD(kNull);\n  } else {\n    SwitchToClass<std::shared_ptr<Object> >(kObjectHandle, other);\n  }\n  return *this;\n}\n\ninline DGLRetValue& DGLRetValue::operator=(const ObjectRef& other) {\n  if (!other.defined()) {\n    SwitchToPOD(kNull);\n  } else {\n    SwitchToClass<std::shared_ptr<Object> >(kObjectHandle, other.obj_);\n  }\n  return *this;\n}\n\ntemplate <typename TObjectRef>\ninline TObjectRef DGLRetValue::AsObjectRef() const {\n  static_assert(\n      std::is_base_of<ObjectRef, TObjectRef>::value,\n      \"Conversion only works for ObjectRef\");\n  if (type_code_ == kNull) return TObjectRef();\n  DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);\n  return TObjectRef(*ptr<std::shared_ptr<Object> >());\n}\n\ninline void DGLArgsSetter::operator()(\n    size_t i, const ObjectRef& other) const {  // NOLINT(*)\n  if (other.defined()) {\n    values_[i].v_handle = const_cast<std::shared_ptr<Object>*>(&(other.obj_));\n    type_codes_[i] = kObjectHandle;\n  } else {\n    type_codes_[i] = kNull;\n  }\n}\n\n}  // namespace runtime\n}  // namespace dgl\n#endif  // DGL_PACKED_FUNC_EXT_H_\n"
  },
  {
    "path": "include/dgl/random.h",
    "content": "/**\n * Copyright (c) 2017 by Contributors\n * @file dgl/random.h\n * @brief Random number generators\n */\n\n#ifndef DGL_RANDOM_H_\n#define DGL_RANDOM_H_\n\n#include <dgl/array.h>\n#include <dmlc/logging.h>\n#include <dmlc/thread_local.h>\n\n#include <random>\n#include <thread>\n#include <vector>\n\n#include <pcg_random.hpp>\n\nnamespace dgl {\n\nnamespace {\n\n// Get a unique integer ID representing this thread.\ninline uint32_t GetThreadId() {\n  static int num_threads = 0;\n  static std::mutex mutex;\n  static thread_local int id = -1;\n\n  if (id == -1) {\n    std::lock_guard<std::mutex> guard(mutex);\n    id = num_threads;\n    num_threads++;\n  }\n  return id;\n}\n\n};  // namespace\n\n/**\n * @brief Thread-local Random Number Generator class\n */\nclass RandomEngine {\n public:\n  /** @brief Constructor with default seed */\n  RandomEngine() {\n    std::random_device rd;\n    SetSeed(rd());\n  }\n\n  /** @brief Constructor with given seed */\n  explicit RandomEngine(uint64_t seed, uint64_t stream = GetThreadId()) {\n    SetSeed(seed, stream);\n  }\n\n  /** @brief Get the thread-local random number generator instance */\n  static RandomEngine* ThreadLocal() {\n    return dmlc::ThreadLocalStore<RandomEngine>::Get();\n  }\n\n  /**\n   * @brief Set the seed of this random number generator\n   */\n  void SetSeed(uint64_t seed, uint64_t stream = GetThreadId()) {\n    rng_.seed(seed, stream);\n  }\n\n  /**\n   * @brief Generate an arbitrary random 32-bit integer.\n   */\n  int32_t RandInt32() { return static_cast<int32_t>(rng_()); }\n\n  /**\n   * @brief Generate a uniform random integer in [0, upper)\n   */\n  template <typename T>\n  T RandInt(T upper) {\n    return RandInt<T>(0, upper);\n  }\n\n  /**\n   * @brief Generate a uniform random integer in [lower, upper)\n   */\n  template <typename T>\n  T RandInt(T lower, T upper) {\n    CHECK_LT(lower, upper);\n    std::uniform_int_distribution<T> dist(lower, upper - 1);\n    return dist(rng_);\n  
}\n\n  /**\n   * @brief Generate a uniform random float in [0, 1)\n   */\n  template <typename T>\n  T Uniform() {\n    return Uniform<T>(0., 1.);\n  }\n\n  /**\n   * @brief Generate a uniform random float in [lower, upper)\n   */\n  template <typename T>\n  T Uniform(T lower, T upper) {\n    // Although the result is in [lower, upper), we allow lower == upper as in\n    // www.cplusplus.com/reference/random/uniform_real_distribution/uniform_real_distribution/\n    CHECK_LE(lower, upper);\n    std::uniform_real_distribution<T> dist(lower, upper);\n    return dist(rng_);\n  }\n\n  /**\n   * @brief Pick a random integer between 0 to N-1 according to given\n   *        probabilities.\n   * @tparam IdxType Return integer type.\n   * @param prob Array of N unnormalized probability of each element. Must be\n   *        non-negative.\n   * @return An integer randomly picked from 0 to N-1.\n   */\n  template <typename IdxType>\n  IdxType Choice(FloatArray prob);\n\n  /**\n   * @brief Pick random integers between 0 to N-1 according to given\n   * probabilities\n   *\n   * If replace is false, the number of picked integers must not larger than N.\n   *\n   * @tparam IdxType Id type\n   * @tparam FloatType Probability value type\n   * @param num Number of integers to choose\n   * @param prob Array of N unnormalized probability of each element.  
Must be\n   *        non-negative.\n   * @param out The output buffer to write selected indices.\n   * @param replace If true, choose with replacement.\n   */\n  template <typename IdxType, typename FloatType>\n  void Choice(IdxType num, FloatArray prob, IdxType* out, bool replace = true);\n\n  /**\n   * @brief Pick random integers between 0 to N-1 according to given\n   * probabilities\n   *\n   * If replace is false, the number of picked integers must not larger than N.\n   *\n   * @tparam IdxType Id type\n   * @tparam FloatType Probability value type\n   * @param num Number of integers to choose\n   * @param prob Array of N unnormalized probability of each element.  Must be\n   *        non-negative.\n   * @param replace If true, choose with replacement.\n   * @return Picked indices\n   */\n  template <typename IdxType, typename FloatType>\n  IdArray Choice(IdxType num, FloatArray prob, bool replace = true) {\n    const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};\n    IdArray ret = IdArray::Empty({num}, dtype, prob->ctx);\n    Choice<IdxType, FloatType>(\n        num, prob, static_cast<IdxType*>(ret->data), replace);\n    return ret;\n  }\n\n  /**\n   * @brief Pick random integers from population by uniform distribution.\n   *\n   * If replace is false, num must not be larger than population.\n   *\n   * @tparam IdxType Return integer type\n   * @param num Number of integers to choose\n   * @param population Total number of elements to choose from.\n   * @param out The output buffer to write selected indices.\n   * @param replace If true, choose with replacement.\n   */\n  template <typename IdxType>\n  void UniformChoice(\n      IdxType num, IdxType population, IdxType* out, bool replace = true);\n\n  /**\n   * @brief Pick random integers from population by uniform distribution.\n   *\n   * If replace is false, num must not be larger than population.\n   *\n   * @tparam IdxType Return integer type\n   * @param num Number of integers to choose\n   * 
@param population Total number of elements to choose from.\n   * @param replace If true, choose with replacement.\n   * @return Picked indices\n   */\n  template <typename IdxType>\n  IdArray UniformChoice(IdxType num, IdxType population, bool replace = true) {\n    const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};\n    // TODO(minjie): only CPU implementation right now\n    IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});\n    UniformChoice<IdxType>(\n        num, population, static_cast<IdxType*>(ret->data), replace);\n    return ret;\n  }\n\n  /**\n   * @brief Pick random integers with different probability for different\n   * segments.\n   *\n   * For example, if split=[0, 4, 10] and bias=[1.5, 1], it means to pick some\n   * integers from 0 to 9, which is divided into two segments. 0-3 are in the\n   * first segment and the rest belongs to the second. The weight(bias) of each\n   * candidate in the first segment is upweighted to 1.5.\n   *\n   *  candidate | 0 1 2 3 | 4 5 6 7 8 9 |\n   *  split       ^         ^            ^\n   *  bias      |   1.5   |      1      |\n   *\n   *\n   * The complexity of this operator is O(k * log(T)) where k is the number of\n   * integers we want to pick, and T is the number of segments. 
It is much\n   * faster compared with assigning probability for each candidate, of which the\n   * complexity is O(k * log(N)) where N is the number of all candidates.\n   *\n   * If replace is false, num must not be larger than population.\n   *\n   * @tparam IdxType Return integer type\n   * @param num Number of integers to choose\n   * @param split Array of T+1 split positions of different segments(including\n   *        start and end)\n   * @param bias Array of T weight of each segments.\n   * @param out The output buffer to write selected indices.\n   * @param replace If true, choose with replacement.\n   */\n  template <typename IdxType, typename FloatType>\n  void BiasedChoice(\n      IdxType num, const IdxType* split, FloatArray bias, IdxType* out,\n      bool replace = true);\n\n  /**\n   * @brief Pick random integers with different probability for different\n   * segments.\n   *\n   * If replace is false, num must not be larger than population.\n   *\n   * @tparam IdxType Return integer type\n   * @param num Number of integers to choose\n   * @param split Split positions of different segments\n   * @param bias Weights of different segments\n   * @param replace If true, choose with replacement.\n   */\n  template <typename IdxType, typename FloatType>\n  IdArray BiasedChoice(\n      IdxType num, const IdxType* split, FloatArray bias, bool replace = true) {\n    const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};\n    IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});\n    BiasedChoice<IdxType, FloatType>(\n        num, split, bias, static_cast<IdxType*>(ret->data), replace);\n    return ret;\n  }\n\n private:\n  pcg32 rng_;\n};\n\n};  // namespace dgl\n\n#endif  // DGL_RANDOM_H_\n"
  },
  {
    "path": "include/dgl/runtime/bfloat16.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file dgl/runtime/bfloat16.h\n * @brief BFloat16 CPU header\n */\n#ifndef DGL_RUNTIME_BFLOAT16_H_\n#define DGL_RUNTIME_BFLOAT16_H_\n\n#include <cmath>\n\nclass BFloat16 {\n  uint16_t val;\n\n public:\n  constexpr BFloat16() : val(0) {}\n  // Disable lint \"explicit\" warning, since implicit usage on constructor is\n  // expected.\n  BFloat16(float f) {  // NOLINT\n    if (std::isnan(f)) {\n      val = 0x7FC0;\n    } else {\n      union {\n        uint16_t iraw16[2];\n        uint32_t iraw32;\n        float f32;\n      };\n\n      f32 = f;\n      const uint32_t rounding_bias = 0x00007FFF + (iraw16[1] & 0x1);\n      val = static_cast<uint16_t>((iraw32 + rounding_bias) >> 16);\n    }\n  }\n  static constexpr BFloat16 Min() {\n    BFloat16 min;\n    min.val = 0xFF80;\n    return min;\n  }\n\n  static constexpr BFloat16 Max() {\n    BFloat16 max;\n    max.val = 0x7F80;\n    return max;\n  }\n\n  BFloat16& operator-=(const float& rhs) {\n    float lhs = (*this);\n    (*this) = lhs - rhs;\n    return *this;\n  }\n\n  BFloat16& operator+=(const float& rhs) {\n    float lhs = (*this);\n    (*this) = lhs + rhs;\n    return *this;\n  }\n\n  operator float() const {\n    union {\n      float f;\n      uint16_t raw[2];\n    };\n    raw[0] = 0;\n    raw[1] = val;\n    return f;\n  }\n};\n\n#endif  // DGL_RUNTIME_BFLOAT16_H_\n"
  },
  {
    "path": "include/dgl/runtime/c_backend_api.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file dgl/runtime/c_backend_api.h\n * @brief DGL runtime backend API.\n *\n *  The functions defined in this header are intended to be\n *  used by compiled dgl operators, usually user do not need to use these\n *  function directly.\n */\n#ifndef DGL_RUNTIME_C_BACKEND_API_H_\n#define DGL_RUNTIME_C_BACKEND_API_H_\n\n#include \"c_runtime_api.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n// Backend related functions.\n/**\n * @brief Backend function for modules to get function\n *  from its environment mod_node (its imports and global function).\n *  The user do should not call DGLFuncFree on func.\n *\n * @param mod_node The module handle.\n * @param func_name The name of the function.\n * @param out The result function.\n * @return 0 when no error is thrown, -1 when failure happens\n */\nDGL_DLL int DGLBackendGetFuncFromEnv(\n    void* mod_node, const char* func_name, DGLFunctionHandle* out);\n/**\n * @brief Backend function to register system-wide library symbol.\n *\n * @param name The name of the symbol\n * @param ptr The symbol address.\n * @return 0 when no error is thrown, -1 when failure happens\n */\nDGL_DLL int DGLBackendRegisterSystemLibSymbol(const char* name, void* ptr);\n\n/**\n * @brief Backend function to allocate temporal workspace.\n *\n * @note The result allocate spaced is ensured to be aligned to\n *       kTempAllocaAlignment.\n *\n * @param nbytes The size of the space requested.\n * @param device_type The device type which the space will be allocated.\n * @param device_id The device id which the space will be allocated.\n * @param dtype_code_hint The type code of the array elements. Only used in\n *        certain backends such as OpenGL.\n * @param dtype_bits_hint The type bits of the array elements. 
Only used in\n *        certain backends such as OpenGL.\n * @return nullptr when error is thrown, a valid ptr if success\n */\nDGL_DLL void* DGLBackendAllocWorkspace(\n    int device_type, int device_id, uint64_t nbytes, int dtype_code_hint,\n    int dtype_bits_hint);\n\n/**\n * @brief Backend function to free temporal workspace.\n *\n * @param ptr The result allocated space pointer.\n * @param device_type The device type which the space will be allocated.\n * @param device_id The device id which the space will be allocated.\n * @return 0 when no error is thrown, -1 when failure happens\n *\n * @sa DGLBackendAllocWorkspace\n */\nDGL_DLL int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr);\n\n/**\n * @brief Environment for DGL parallel task.\n */\ntypedef struct {\n  /**\n   * @brief Auxiliary used for synchronization\n   */\n  void* sync_handle;\n  /** @brief total amount of task */\n  int32_t num_task;\n} DGLParallelGroupEnv;\n\n/**\n * @brief The callback function to execute a parallel lambda\n * @param task_id the task id of the function.\n * @param penv The parallel environment backs the execution.\n * @param cdata The supporting closure data.\n */\ntypedef int (*FDGLParallelLambda)(\n    int task_id, DGLParallelGroupEnv* penv, void* cdata);\n\n/**\n * @brief Backend function for running parallel jobs.\n *\n * @param flambda The parallel function to be launched.\n * @param cdata The closure data.\n * @param num_task Number of tasks to launch, can be 0, means launch\n *        with all available threads.\n *\n * @return 0 when no error is thrown, -1 when failure happens\n */\nDGL_DLL int DGLBackendParallelLaunch(\n    FDGLParallelLambda flambda, void* cdata, int num_task);\n\n/**\n * @brief BSP barrrier between parallel threads\n * @param task_id the task id of the function.\n * @param penv The parallel environment backs the execution.\n * @return 0 when no error is thrown, -1 when failure happens\n */\nDGL_DLL int 
DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv);\n\n/**\n * @brief Simple static initialization function.\n *  Run f once and set handle to be not null.\n *  This function is mainly used for test purpose.\n *\n * @param handle A global address to indicate f\n * @param f The function to be run\n * @param cdata The closure data to pass to the function.\n * @param nbytes Number of bytes in the closure data.\n * @return 0 when no error is thrown, -1 when failure happens\n */\nDGL_DLL int DGLBackendRunOnce(\n    void** handle, int (*f)(void*), void* cdata, int nbytes);\n\n#ifdef __cplusplus\n}  // DGL_EXTERN_C\n#endif\n#endif  // DGL_RUNTIME_C_BACKEND_API_H_\n"
  },
  {
    "path": "include/dgl/runtime/c_object_api.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file dgl/runtime/c_object_api.h\n *\n * @brief DGL Object C API, used to extend and prototype new CAPIs.\n *\n * @note Most API functions are registerd as PackedFunc and\n *  can be grabbed via DGLFuncGetGlobal\n */\n#ifndef DGL_RUNTIME_C_OBJECT_API_H_\n#define DGL_RUNTIME_C_OBJECT_API_H_\n\n#include \"./c_runtime_api.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/** @brief handle to object */\ntypedef void* ObjectHandle;\n\n/**\n * @brief free the object handle\n * @param handle The object handle to be freed.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLObjectFree(ObjectHandle handle);\n\n/**\n * @brief Convert type key to type index.\n * @param type_key The key of the type.\n * @param out_index the corresponding type index.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLObjectTypeKey2Index(const char* type_key, int* out_index);\n\n/**\n * @brief Get runtime type index of the object.\n * @param handle the object handle.\n * @param out_index the corresponding type index.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle, int* out_index);\n\n/**\n * @brief get attributes given key\n * @param handle The object handle\n * @param key The attribute name\n * @param out_value The attribute value\n * @param out_type_code The type code of the attribute.\n * @param out_success Whether get is successful.\n * @return 0 when success, -1 when failure happens\n * @note API calls always exchanges with type bits=64, lanes=1\n */\nDGL_DLL int DGLObjectGetAttr(\n    ObjectHandle handle, const char* key, DGLValue* out_value,\n    int* out_type_code, int* out_success);\n\n/**\n * @brief get attributes names in the object.\n * @param handle The object handle\n * @param out_size The number of functions\n * @param out_array The array of function names.\n * @return 0 when success, -1 when failure 
happens\n */\nDGL_DLL int DGLObjectListAttrNames(\n    ObjectHandle handle, int* out_size, const char*** out_array);\n#ifdef __cplusplus\n}  // DGL_EXTERN_C\n#endif\n#endif  // DGL_RUNTIME_C_OBJECT_API_H_\n"
  },
  {
    "path": "include/dgl/runtime/c_runtime_api.h",
    "content": "/**\n *  Copyright (c) 2016-2022 by Contributors\n * @file dgl/runtime/c_runtime_api.h\n * @brief DGL runtime library.\n *\n * This runtime is adapted from TVM project (commit: 2ce5277)\n */\n#ifndef DGL_RUNTIME_C_RUNTIME_API_H_\n#define DGL_RUNTIME_C_RUNTIME_API_H_\n\n// Macros to do weak linking\n#ifdef _MSC_VER\n#define DGL_WEAK __declspec(selectany)\n#else\n#define DGL_WEAK __attribute__((weak))\n#endif\n\n#ifdef __EMSCRIPTEN__\n#include <emscripten/emscripten.h>\n#define DGL_DLL EMSCRIPTEN_KEEPALIVE\n#endif\n\n#ifndef DGL_DLL\n#ifdef _WIN32\n#ifdef DGL_EXPORTS\n#define DGL_DLL __declspec(dllexport)\n#else\n#define DGL_DLL __declspec(dllimport)\n#endif\n#else\n#define DGL_DLL\n#endif\n#endif\n\n// DGL version\n#define DGL_VERSION \"2.5\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n#include <stddef.h>\n#include <stdint.h>\n\n/** @brief type of array index. */\ntypedef int64_t dgl_index_t;\n\n/**\n * @brief The device type in DGLContext.\n */\n#ifdef __cplusplus\ntypedef enum : int32_t {\n#else\ntypedef enum {\n#endif\n  /** @brief CPU device */\n  kDGLCPU = 1,\n  /** @brief CUDA GPU device */\n  kDGLCUDA = 2,\n  // add more devices once supported\n} DGLDeviceType;\n\n/**\n * @brief The object type code is used in DGL FFI to indicate the types of\n *        objects passed between C and Python.\n */\ntypedef enum {\n  kObjectInt = 0U,\n  kObjectUInt = 1U,\n  kObjectFloat = 2U,\n  kHandle = 3U,\n  kNull = 4U,\n  kDGLDataType = 5U,\n  kDGLContext = 6U,\n  kArrayHandle = 7U,\n  kObjectHandle = 8U,\n  kModuleHandle = 9U,\n  kFuncHandle = 10U,\n  kStr = 11U,\n  kBytes = 12U,\n  kNDArrayContainer = 13U,\n  // Extension codes for other frameworks to integrate DGL PackedFunc.\n  // To make sure each framework's id do not conflict, use first and\n  // last sections to mark ranges.\n  // Open an issue at the repo if you need a section of code.\n  kExtBegin = 15U,\n  kNNVMFirst = 16U,\n  kNNVMLast = 20U,\n  // The following section of code is used for 
non-reserved types.\n  kExtReserveEnd = 64U,\n  kExtEnd = 128U\n} DGLObjectTypeCode;\n\n/**\n * @brief The type code options DGLDataType.\n */\ntypedef enum {\n  /** @brief signed integer */\n  kDGLInt = 0U,\n  /** @brief unsigned integer */\n  kDGLUInt = 1U,\n  /** @brief IEEE floating point */\n  kDGLFloat = 2U,\n  /** @brief bfloat16 */\n  kDGLBfloat = 4U,\n  // add more data types if we are going to support them\n} DGLDataTypeCode;\n\n/**\n * @brief The data type the tensor can hold. The data type is assumed to follow\n * the native endian-ness. An explicit error message should be raised when\n * attempting to export an array with non-native endianness\n *\n *  Examples\n *   - float: type_code = 2, bits = 32, lanes=1\n *   - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4\n *   - int8: type_code = 0, bits = 8, lanes=1\n */\ntypedef struct {\n  /**\n   * @brief Type code of base types.\n   * We keep it uint8_t instead of DGLDataTypeCode for minimal memory\n   * footprint, but the value should be one of DGLDataTypeCode enum values.\n   * */\n  uint8_t code;\n  /**\n   * @brief Number of bits, common choices are 8, 16, 32.\n   */\n  uint8_t bits;\n  /** @brief Number of lanes in the type, used for vector types. */\n  uint16_t lanes;\n} DGLDataType;\n\n/**\n * @brief The Device information, abstract away common device types.\n */\ntypedef struct {\n  /** @brief The device type used in the device. 
*/\n  DGLDeviceType device_type;\n  /**\n   * @brief The device index.\n   * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.\n   */\n  int32_t device_id;\n} DGLContext;\n\n/**\n * @brief The tensor array stucture to DGL API.\n * The structure is heavily inspired by DLTensor from DLPack.\n */\ntypedef struct {\n  /**\n   * @brief The data pointer points to the allocated data.\n   *\n   * Depending on the device context, it can be a CPU pointer, or a CUDA\n   * device pointer or  acl_mem handle in OpenCL.\n   * This pointer is always aligned to 256 bytes as in CUDA. Use the\n   * `byte_offset` field to mark the beginning of the actual data (if the\n   * address is not 256 byte aligned).\n   *\n   * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow,\n   * TVM, perhaps others) do not adhere to this 256 byte alignment requirement\n   * on CPU/CUDA/ROCm, and always use `byte_offset=0`.  This is likely to be\n   * fixed in the future; at the moment it is recommended\n   * to not rely on the data pointer being correctly aligned.\n   *\n   * For a DGLArray, the size of memory required to store the contents of\n   * data can be calculated as follows:\n   *\n   * @code{.c}\n   * static inline size_t GetDataSize(const DGLArray* t) {\n   *   size_t size = 1;\n   *   for (int32_t i = 0; i < t->ndim; ++i) {\n   *     size *= t->shape[i];\n   *   }\n   *   size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;\n   *   return size;\n   * }\n   * @endcode\n   */\n  void* data;\n  /** @brief The device of the tensor */\n  DGLContext ctx;\n  /** @brief Number of dimensions */\n  int32_t ndim;\n  /** @brief The data type of the pointer*/\n  DGLDataType dtype;\n  /** @brief The shape of the tensor */\n  int64_t* shape;\n  /**\n   * @brief strides of the tensor (in number of elements, not bytes)\n   *  can be NULL, indicating tensor is compact and row-majored.\n   */\n  int64_t* strides;\n  /** @brief The offset in bytes to the beginning 
pointer to data */\n  uint64_t byte_offset;\n} DGLArray;\n\n/** @brief the array handle */\ntypedef DGLArray* DGLArrayHandle;\n\n/**\n * @brief Union type of values\n *  being passed through API and function calls.\n */\ntypedef union {\n  int64_t v_int64;\n  double v_float64;\n  void* v_handle;\n  const char* v_str;\n  DGLDataType v_type;\n  DGLContext v_ctx;\n} DGLValue;\n\n/**\n * @brief Byte array type used to pass in byte array\n *  When kBytes is used as data type.\n */\ntypedef struct {\n  const char* data;\n  size_t size;\n} DGLByteArray;\n\n/** @brief Handle to DGL runtime modules. */\ntypedef void* DGLModuleHandle;\n/** @brief Handle to packed function handle. */\ntypedef void* DGLFunctionHandle;\n/** @brief Handle to hold return value. */\ntypedef void* DGLRetValueHandle;\n/**\n * @brief The stream that is specific to device\n * can be NULL, which indicates the default one.\n */\ntypedef void* DGLStreamHandle;\n\n/**\n * @brief Used for implementing C API function.\n *  Set last error message before return.\n * @param msg The error message to be set.\n */\nDGL_DLL void DGLAPISetLastError(const char* msg);\n\n/**\n * @brief return str message of the last error\n *  all function in this file will return 0 when success\n *  and -1 when an error occured,\n *  DGLGetLastError can be called to retrieve the error\n *\n *  this function is threadsafe and can be called by different thread\n *\n * @return error info\n */\nDGL_DLL const char* DGLGetLastError(void);\n/**\n * @brief Load module from file.\n * @param file_name The file name to load the module from.\n * @param format The format of the module.\n * @param out The result module\n *\n * @return 0 when success, -1 when failure happens\n * @note The resulting module do not contain import relation.\n *  It can be reconstructed by DGLModImport.\n */\nDGL_DLL int DGLModLoadFromFile(\n    const char* file_name, const char* format, DGLModuleHandle* out);\n\n/**\n * @brief Add dep to mod's dependency.\n *  This 
allows functions in this module to use modules.\n *\n * @param mod The module handle.\n * @param dep The dependent module to be imported.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep);\n\n/**\n * @brief Get function from the module.\n * @param mod The module handle.\n * @param func_name The name of the function.\n * @param query_imports Whether to query imported modules\n * @param out The result function, can be NULL if it is not available.\n * @return 0 when no error is thrown, -1 when failure happens\n */\nDGL_DLL int DGLModGetFunction(\n    DGLModuleHandle mod, const char* func_name, int query_imports,\n    DGLFunctionHandle* out);\n\n/**\n * @brief Free front-end extension type resource.\n * @param handle The extension handle.\n * @param type_code The type of of the extension type.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLExtTypeFree(void* handle, int type_code);\n\n/**\n * @brief Free the Module\n * @param mod The module to be freed.\n *\n * @note This may not free up the module's resources.\n *  If there is active DGLFunctionHandle uses the module\n *  Or if this module is imported by another active module.\n *\n *  The all functions remains valid until DGLFuncFree is called.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLModFree(DGLModuleHandle mod);\n\n/**\n * @brief Free the function when it is no longer needed.\n * @param func The function handle\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLFuncFree(DGLFunctionHandle func);\n\n/**\n * @brief Call a Packed DGL Function.\n *\n * @param func node handle of the function.\n * @param arg_values The arguments\n * @param type_codes The type codes of the arguments\n * @param num_args Number of arguments.\n *\n * @param ret_val The return value.\n * @param ret_type_code the type code of return value.\n *\n * @return 0 when success, -1 when failure 
happens\n * @note DGL calls always exchanges with type bits=64, lanes=1\n *\n * @note API calls always exchanges with type bits=64, lanes=1\n *   If API call returns container handles (e.g. FunctionHandle)\n *   these handles should be managed by the front-end.\n *   The front-end need to call free function (e.g. DGLFuncFree)\n *   to free these handles.\n */\nDGL_DLL int DGLFuncCall(\n    DGLFunctionHandle func, DGLValue* arg_values, int* type_codes, int num_args,\n    DGLValue* ret_val, int* ret_type_code);\n\n/**\n * @brief Set the return value of DGLPackedCFunc.\n *\n *  This function is called by DGLPackedCFunc to set the return value.\n *  When this function is not called, the function returns null by default.\n *\n * @param ret The return value handle, pass by ret in DGLPackedCFunc\n * @param value The value to be returned.\n * @param type_code The type of the value to be returned.\n * @param num_ret Number of return values, for now only 1 is supported.\n */\nDGL_DLL int DGLCFuncSetReturn(\n    DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret);\n\n/**\n * @brief Inplace translate callback argument value to return value.\n *  This is only needed for non-POD arguments.\n *\n * @param value The value to be translated.\n * @param code The type code to be translated.\n * @note This function will do a shallow copy when necessary.\n *\n * @return 0 when success, -1 when failure happens.\n */\nDGL_DLL int DGLCbArgToReturn(DGLValue* value, int code);\n\n/**\n * @brief C type of packed function.\n *\n * @param args The arguments\n * @param type_codes The type codes of the arguments\n * @param num_args Number of arguments.\n * @param ret The return value handle.\n * @param resource_handle The handle additional resouce handle from fron-end.\n * @return 0 if success, -1 if failure happens, set error via\n *         DGLAPISetLastError.\n * @sa DGLCFuncSetReturn\n */\ntypedef int (*DGLPackedCFunc)(\n    DGLValue* args, int* type_codes, int num_args, 
DGLRetValueHandle ret,\n    void* resource_handle);\n\n/**\n * @brief C callback to free the resource handle in C packed function.\n * @param resource_handle The additional resource handle from the front-end.\n */\ntypedef void (*DGLPackedCFuncFinalizer)(void* resource_handle);\n\n/**\n * @brief Signature for extension function declarer.\n *\n *  DGL call this function to get the extension functions\n *  The declarer will call register_func to register function and their name.\n *\n * @param register_func_handle The register function\n * @return 0 if success, -1 if failure happens\n */\ntypedef int (*DGLExtensionFuncDeclarer)(DGLFunctionHandle register_func_handle);\n\n/**\n * @brief Wrap a DGLPackedCFunc to become a FunctionHandle.\n *\n * The resource_handle will be managed by DGL API, until the function is no\n * longer used.\n *\n * @param func The packed C function.\n * @param resource_handle The resource handle from front-end, can be NULL.\n * @param fin The finalizer on resource handle when the FunctionHandle get\n *        freed, can be NULL.\n * @param out the result function handle.\n * @return 0 when success, -1 when failure happens.\n */\nDGL_DLL int DGLFuncCreateFromCFunc(\n    DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,\n    DGLFunctionHandle* out);\n\n/**\n * @brief Register the function to runtime's global table.\n *\n * The registered function then can be pulled by the backend by the name.\n *\n * @param name The name of the function.\n * @param f The function to be registered.\n * @param override Whether allow override already registered function.\n */\nDGL_DLL int DGLFuncRegisterGlobal(\n    const char* name, DGLFunctionHandle f, int override);\n\n/**\n * @brief Get a global function.\n *\n * @param name The name of the function.\n * @param out the result function pointer, NULL if it does not exist.\n *\n * @note The function handle of global function is managed by DGL runtime,\n *  So DGLFuncFree should not be 
called when it get deleted.\n */\nDGL_DLL int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out);\n\n/**\n * @brief List all the globally registered function name\n * @param out_size The number of functions\n * @param out_array The array of function names.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLFuncListGlobalNames(int* out_size, const char*** out_array);\n\n// Array related apis for quick proptyping\n/**\n * @brief Allocate a nd-array's memory,\n *  including space of shape, of given spec.\n *\n * @param shape The shape of the array, the data content will be copied to out\n * @param ndim The number of dimension of the array.\n * @param dtype_code The type code of the dtype\n * @param dtype_bits The number of bits of dtype\n * @param dtype_lanes The number of lanes in the dtype.\n * @param device_type The device type of context\n * @param device_id The device id of context.\n * @param out The output handle.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLArrayAlloc(\n    const dgl_index_t* shape, int ndim, int dtype_code, int dtype_bits,\n    int dtype_lanes, int device_type, int device_id, DGLArrayHandle* out);\n\n/**\n * @brief Allocate a nd-array's with shared memory,\n *  including space of shape, of given spec.\n *\n * @param the name of the shared memory\n * @param shape The shape of the array, the data content will be copied to out\n * @param ndim The number of dimension of the array.\n * @param dtype_code The type code of the dtype\n * @param dtype_bits The number of bits of dtype\n * @param dtype_lanes The number of lanes in the dtype.\n * @param is_create whether the shared memory is created\n * @param out The output handle.\n * @return 0 when success, -1 when failure happens\n */\nint DGLArrayAllocSharedMem(\n    const char* mem_name, const dgl_index_t* shape, int ndim, int dtype_code,\n    int dtype_bits, int dtype_lanes, bool is_create, DGLArrayHandle* out);\n\n/**\n * @brief Free the 
DGL Array.\n * @param handle The array handle to be freed.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLArrayFree(DGLArrayHandle handle);\n\n/**\n * @brief Copy array data from CPU byte array.\n * @param handle The array handle.\n * @param data the data pointer\n * @param nbytes The number of bytes to copy.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLArrayCopyFromBytes(\n    DGLArrayHandle handle, void* data, size_t nbytes);\n\n/**\n * @brief Copy array data to CPU byte array.\n * @param handle The array handle.\n * @param data the data pointer\n * @param nbytes The number of bytes to copy.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLArrayCopyToBytes(\n    DGLArrayHandle handle, void* data, size_t nbytes);\n\n/**\n * @brief Copy the array, both from and to must be valid during the copy.\n * @param from The array to be copied from.\n * @param to The target space.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from, DGLArrayHandle to);\n\n/**\n * @brief Create a new runtime stream.\n *\n * @param device_type The device type of context\n * @param device_id The device id of context\n * @param out The new stream handle\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLStreamCreate(\n    int device_type, int device_id, DGLStreamHandle* out);\n\n/**\n * @brief Free a created stream handle.\n *\n * @param device_type The device type of context\n * @param device_id The device id of context\n * @param stream The stream to be freed\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLStreamFree(\n    int device_type, int device_id, DGLStreamHandle stream);\n\n/**\n * @brief Set the runtime stream of current thread to be stream.\n *  The subsequent calls to the same device_type\n *  will use the setted stream handle.\n *  The specific type of stream is runtime device dependent.\n *\n * @param 
device_type The device type of context\n * @param device_id The device id of context.\n * @param handle The stream handle.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLSetStream(\n    int device_type, int device_id, DGLStreamHandle handle);\n\n/**\n * @brief Get the runtime stream of current thread.\n *\n * @param device_type The device type of context\n * @param device_id The device id of context.\n * @param handle The stream handle.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLGetStream(\n    int device_type, int device_id, DGLStreamHandle* handle);\n\n/**\n * @brief Wait until all computations on stream completes.\n *\n * @param device_type The device type of context\n * @param device_id The device id of context.\n * @param stream The stream to be synchronized.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLSynchronize(\n    int device_type, int device_id, DGLStreamHandle stream);\n\n/**\n * @brief Synchronize two streams of execution.\n *\n * @param device_type The device type of context\n * @param device_id The device id of context\n * @param src The source stream to synchronize.\n * @param dst The destination stream to synchronize.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLStreamStreamSynchronize(\n    int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst);\n\n/**\n * @brief Load tensor adapter.\n * @return 0 when success, -1 when failure happens.\n */\nDGL_DLL int DGLLoadTensorAdapter(const char* path);\n\n/**\n * @brief Pin host memory.\n */\nint DGLArrayPinData(DGLArrayHandle handle, DGLContext ctx);\n\n/**\n * @brief Unpin host memory.\n */\nint DGLArrayUnpinData(DGLArrayHandle handle, DGLContext ctx);\n\n/**\n * @brief Record the stream that's using this tensor.\n */\nint DGLArrayRecordStream(DGLArrayHandle handle, DGLStreamHandle stream);\n\n/**\n * @brief Bug report macro.\n *\n * This serves as a sanity check on system 
side to make sure the code is correct\n * by checking whether a condition always holds for complex reasons.  Failing\n * the condition signifies a system bug instead of users giving invalid inputs\n * or using the functionality incorrectly.\n *\n * Hints the user to file a bug report if the condition fails.\n */\n#define BUG_IF_FAIL(cond)                                                    \\\n  CHECK(cond)                                                                \\\n      << \"A bug has been occurred.  \"                                        \\\n         \"Please file a bug report at https://github.com/dmlc/dgl/issues.  \" \\\n         \"Message: \"\n\n#ifdef __cplusplus\n}  // DGL_EXTERN_C\n#endif\n#endif  // DGL_RUNTIME_C_RUNTIME_API_H_\n"
  },
  {
    "path": "include/dgl/runtime/config.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file runtime/config.h\n * @brief DGL runtime config\n */\n\n#ifndef DGL_RUNTIME_CONFIG_H_\n#define DGL_RUNTIME_CONFIG_H_\n\nnamespace dgl {\nnamespace runtime {\n\nclass Config {\n public:\n  static Config* Global() {\n    static Config config;\n    return &config;\n  }\n\n  // Enable or disable the use of libxsmm for SpMM\n  void EnableLibxsmm(bool);\n  bool IsLibxsmmAvailable() const;\n\n private:\n  Config();\n  bool libxsmm_;\n};\n\n}  // namespace runtime\n}  // namespace dgl\n\n#endif  // DGL_RUNTIME_CONFIG_H_\n"
  },
  {
    "path": "include/dgl/runtime/container.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file runtime/container.h\n * @brief Defines the container object data structures.\n */\n#ifndef DGL_RUNTIME_CONTAINER_H_\n#define DGL_RUNTIME_CONTAINER_H_\n\n#include <memory>\n#include <string>\n#include <unordered_map>\n#include <utility>\n#include <vector>\n\n#include \"object.h\"\n#include \"packed_func.h\"\n\nnamespace dgl {\nnamespace runtime {\n\n/**\n * @brief value object.\n *\n * It is typically used to wrap a non-Object type to Object type.\n * Any type that is supported by DGLRetValue is supported by this.\n */\nclass ValueObject : public Object {\n public:\n  /** @brief the value data */\n  DGLRetValue data;\n\n  static constexpr const char* _type_key = \"Value\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(ValueObject, Object);\n};\n\n/** @brief Construct a value object. */\ntemplate <typename T>\ninline std::shared_ptr<ValueObject> MakeValue(T&& val) {\n  auto obj = std::make_shared<ValueObject>();\n  obj->data = val;\n  return obj;\n}\n\n/** @brief Vallue reference type */\nclass Value : public ObjectRef {\n public:\n  Value() {}\n  explicit Value(std::shared_ptr<Object> o) : ObjectRef(o) {}\n\n  const ValueObject* operator->() const {\n    return static_cast<const ValueObject*>(obj_.get());\n  }\n\n  using ContainerType = ValueObject;\n};\n\n/** @brief list obj content in list */\nclass ListObject : public Object {\n public:\n  /** @brief the data content */\n  std::vector<std::shared_ptr<Object> > data;\n\n  void VisitAttrs(AttrVisitor* visitor) final {\n    // Visitor to list have no effect.\n  }\n\n  static constexpr const char* _type_key = \"List\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(ListObject, Object);\n};\n\n/** @brief map obj content */\nclass MapObject : public Object {\n public:\n  void VisitAttrs(AttrVisitor* visitor) final {\n    // Visitor to map have no effect.\n  }\n  // hash function\n  struct Hash {\n    size_t operator()(const std::shared_ptr<Object>& n) const {\n      return 
std::hash<Object*>()(n.get());\n    }\n  };\n  // comparator\n  struct Equal {\n    bool operator()(\n        const std::shared_ptr<Object>& a,\n        const std::shared_ptr<Object>& b) const {\n      return a.get() == b.get();\n    }\n  };\n\n  /** @brief The corresponding conatiner type */\n  using ContainerType = std::unordered_map<\n      std::shared_ptr<Object>, std::shared_ptr<Object>, Hash, Equal>;\n\n  /** @brief the data content */\n  ContainerType data;\n\n  static constexpr const char* _type_key = \"Map\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(MapObject, Object);\n};\n\n/** @brief specialized map obj with string as key */\nclass StrMapObject : public Object {\n public:\n  void VisitAttrs(AttrVisitor* visitor) final {\n    // Visitor to map have no effect.\n  }\n  /** @brief The corresponding conatiner type */\n  using ContainerType =\n      std::unordered_map<std::string, std::shared_ptr<Object> >;\n\n  /** @brief the data content */\n  ContainerType data;\n\n  static constexpr const char* _type_key = \"StrMap\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(StrMapObject, Object);\n};\n\n/**\n * @brief iterator adapter that adapts TIter to return another type.\n * @tparam Converter a struct that contains converting function\n * @tparam TIter the content iterator type.\n */\ntemplate <typename Converter, typename TIter>\nclass IterAdapter {\n public:\n  explicit IterAdapter(TIter iter) : iter_(iter) {}\n  inline IterAdapter& operator++() {  // NOLINT(*)\n    ++iter_;\n    return *this;\n  }\n  inline IterAdapter& operator++(int) {  // NOLINT(*)\n    ++iter_;\n    return *this;\n  }\n  inline IterAdapter operator+(int offset) const {  // NOLINT(*)\n    return IterAdapter(iter_ + offset);\n  }\n  inline bool operator==(IterAdapter other) const {\n    return iter_ == other.iter_;\n  }\n  inline bool operator!=(IterAdapter other) const { return !(*this == other); }\n  inline const typename Converter::ResultType operator*() const {\n    return Converter::convert(*iter_);\n  
}\n\n private:\n  TIter iter_;\n};\n\n/**\n * @brief List container of ObjectRef.\n *\n * List implements copy on write semantics, which means list is mutable\n * but copy will happen when list is referenced in more than two places.\n *\n * That is said when using this container for runtime arguments or return\n * values, try use the constructor to create the list at once (for example\n * from an existing vector).\n *\n * operator[] only provide const access, use Set to mutate the content.\n *\n * @tparam T The content ObjectRef type.\n *\n * @note The element type must subclass \\c ObjectRef.  Otherwise, the\n * compiler would throw an error:\n *\n * <code>\n *      error: no type named 'type' in 'struct std::enable_if<false, void>'\n * </code>\n *\n * Example:\n *\n * <code>\n *     // List<int> list;          // fails\n *     // List<NDArray> list2;     // fails\n *     List<Value> list;           // works\n *     list.push_back(Value(MakeValue(1)));  // works\n *     list.push_back(Value(MakeValue(NDArray::Empty(shape, dtype, ctx))));  //\n * works\n * </code>\n */\ntemplate <\n    typename T,\n    typename =\n        typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>\nclass List : public ObjectRef {\n public:\n  /**\n   * @brief default constructor\n   */\n  List() { obj_ = std::make_shared<ListObject>(); }\n  /**\n   * @brief move constructor\n   * @param other source\n   */\n  List(List<T>&& other) {  // NOLINT(*)\n    obj_ = std::move(other.obj_);\n  }\n  /**\n   * @brief copy constructor\n   * @param other source\n   */\n  List(const List<T>& other) : ObjectRef(other.obj_) {  // NOLINT(*)\n  }\n  /**\n   * @brief constructor from pointer\n   * @param n the container pointer\n   */\n  explicit List(std::shared_ptr<Object> n) : ObjectRef(n) {}\n  /**\n   * @brief constructor from iterator\n   * @param begin begin of iterator\n   * @param end end of iterator\n   * @tparam IterType The type of iterator\n   */\n  template <typename IterType>\n 
 List(IterType begin, IterType end) {\n    assign(begin, end);\n  }\n  /**\n   * @brief constructor from initializer list\n   * @param init The initalizer list\n   */\n  List(std::initializer_list<T> init) {  // NOLINT(*)\n    assign(init.begin(), init.end());\n  }\n  /**\n   * @brief constructor from vector\n   * @param init The vector\n   */\n  List(const std::vector<T>& init) {  // NOLINT(*)\n    assign(init.begin(), init.end());\n  }\n  /**\n   * @brief Constructs a container with n elements. Each element is a copy of\n   * val\n   * @param n The size of the container\n   * @param val The init value\n   */\n  explicit List(size_t n, const T& val) {\n    auto tmp_obj = std::make_shared<ListObject>();\n    for (size_t i = 0; i < n; ++i) {\n      tmp_obj->data.push_back(val.obj_);\n    }\n    obj_ = std::move(tmp_obj);\n  }\n  /**\n   * @brief move assign operator\n   * @param other The source of assignment\n   * @return reference to self.\n   */\n  List<T>& operator=(List<T>&& other) {\n    obj_ = std::move(other.obj_);\n    return *this;\n  }\n  /**\n   * @brief copy assign operator\n   * @param other The source of assignment\n   * @return reference to self.\n   */\n  List<T>& operator=(const List<T>& other) {\n    obj_ = other.obj_;\n    return *this;\n  }\n  /**\n   * @brief reset the list to content from iterator.\n   * @param begin begin of iterator\n   * @param end end of iterator\n   * @tparam IterType The type of iterator\n   */\n  template <typename IterType>\n  void assign(IterType begin, IterType end) {\n    auto n = std::make_shared<ListObject>();\n    for (IterType it = begin; it != end; ++it) {\n      n->data.push_back((*it).obj_);\n    }\n    obj_ = std::move(n);\n  }\n  /**\n   * @brief Read i-th element from list.\n   * @param i The index\n   * @return the i-th element.\n   */\n  inline const T operator[](size_t i) const {\n    return T(static_cast<const ListObject*>(obj_.get())->data[i]);\n  }\n  /** @return The size of the list */\n  inline 
size_t size() const {\n    if (obj_.get() == nullptr) return 0;\n    return static_cast<const ListObject*>(obj_.get())->data.size();\n  }\n  /**\n   * @brief copy on write semantics\n   *  Do nothing if current handle is the unique copy of the list.\n   *  Otherwise make a new copy of the list to ensure the current handle\n   *  hold a unique copy.\n   *\n   * @return Handle to the internal obj container(which ganrantees to be unique)\n   */\n  inline ListObject* CopyOnWrite() {\n    if (obj_.get() == nullptr || !obj_.unique()) {\n      obj_ = std::make_shared<ListObject>(\n          *static_cast<const ListObject*>(obj_.get()));\n    }\n    return static_cast<ListObject*>(obj_.get());\n  }\n  /**\n   * @brief push a new item to the back of the list\n   * @param item The item to be pushed.\n   */\n  inline void push_back(const T& item) {\n    ListObject* n = this->CopyOnWrite();\n    n->data.push_back(item.obj_);\n  }\n  /**\n   * @brief set i-th element of the list.\n   * @param i The index\n   * @param value The value to be setted.\n   */\n  inline void Set(size_t i, const T& value) {\n    ListObject* n = this->CopyOnWrite();\n    n->data[i] = value.obj_;\n  }\n  /** @return whether list is empty */\n  inline bool empty() const { return size() == 0; }\n  /** @brief Copy the content to a vector */\n  inline std::vector<T> ToVector() const {\n    return std::vector<T>(begin(), end());\n  }\n  /** @brief specify container obj */\n  using ContainerType = ListObject;\n\n  struct Ptr2ObjectRef {\n    using ResultType = T;\n    static inline T convert(const std::shared_ptr<Object>& n) { return T(n); }\n  };\n  using iterator = IterAdapter<\n      Ptr2ObjectRef, std::vector<std::shared_ptr<Object> >::const_iterator>;\n\n  using reverse_iterator = IterAdapter<\n      Ptr2ObjectRef,\n      std::vector<std::shared_ptr<Object> >::const_reverse_iterator>;\n\n  /** @return begin iterator */\n  inline iterator begin() const {\n    return iterator(static_cast<const 
ListObject*>(obj_.get())->data.begin());\n  }\n  /** @return end iterator */\n  inline iterator end() const {\n    return iterator(static_cast<const ListObject*>(obj_.get())->data.end());\n  }\n  /** @return rbegin iterator */\n  inline reverse_iterator rbegin() const {\n    return reverse_iterator(\n        static_cast<const ListObject*>(obj_.get())->data.rbegin());\n  }\n  /** @return rend iterator */\n  inline reverse_iterator rend() const {\n    return reverse_iterator(\n        static_cast<const ListObject*>(obj_.get())->data.rend());\n  }\n};\n\n/**\n * @brief Map container of ObjectRef->ObjectRef.\n *\n * Map implements copy on write semantics, which means map is mutable\n * but copy will happen when list is referenced in more than two places.\n *\n * That is said when using this container for runtime arguments or return\n * values, try use the constructor to create it at once (for example\n * from an existing std::map).\n *\n * operator[] only provide const acces, use Set to mutate the content.\n *\n * @tparam K The key ObjectRef type.\n * @tparam V The value ObjectRef type.\n *\n * @note The element type must subclass \\c ObjectRef.  
Otherwise, the\n * compiler would throw an error:\n *\n * <code>\n *      error: no type named 'type' in 'struct std::enable_if<false, void>'\n * </code>\n *\n * Example:\n *\n * <code>\n *     // Map<std::string, int> map;          // fails\n *     // Map<std::string, NDArray> map2;     // fails\n *     Map<std::string, Value> map;           // works\n *     map.Set(\"key1\", Value(MakeValue(1)));  // works\n *     map.Set(\"key2\", Value(MakeValue(NDArray::Empty(shape, dtype, ctx))));  //\n * works\n * </code>\n */\ntemplate <\n    typename K, typename V,\n    typename = typename std::enable_if<\n        std::is_base_of<ObjectRef, K>::value ||\n        std::is_base_of<std::string, K>::value>::type,\n    typename =\n        typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>\nclass Map : public ObjectRef {\n public:\n  /**\n   * @brief default constructor\n   */\n  Map() { obj_ = std::make_shared<MapObject>(); }\n  /**\n   * @brief move constructor\n   * @param other source\n   */\n  Map(Map<K, V>&& other) {  // NOLINT(*)\n    obj_ = std::move(other.obj_);\n  }\n  /**\n   * @brief copy constructor\n   * @param other source\n   */\n  Map(const Map<K, V>& other) : ObjectRef(other.obj_) {  // NOLINT(*)\n  }\n  /**\n   * @brief constructor from pointer\n   * @param n the container pointer\n   */\n  explicit Map(std::shared_ptr<Object> n) : ObjectRef(n) {}\n  /**\n   * @brief constructor from iterator\n   * @param begin begin of iterator\n   * @param end end of iterator\n   * @tparam IterType The type of iterator\n   */\n  template <typename IterType>\n  Map(IterType begin, IterType end) {\n    assign(begin, end);\n  }\n  /**\n   * @brief constructor from initializer list\n   * @param init The initalizer list\n   */\n  Map(std::initializer_list<std::pair<K, V> > init) {  // NOLINT(*)\n    assign(init.begin(), init.end());\n  }\n  /**\n   * @brief constructor from vector\n   * @param init The vector\n   */\n  template <typename Hash, typename Equal>\n  
Map(const std::unordered_map<K, V, Hash, Equal>& init) {  // NOLINT(*)\n    assign(init.begin(), init.end());\n  }\n  /**\n   * @brief move assign operator\n   * @param other The source of assignment\n   * @return reference to self.\n   */\n  Map<K, V>& operator=(Map<K, V>&& other) {\n    obj_ = std::move(other.obj_);\n    return *this;\n  }\n  /**\n   * @brief copy assign operator\n   * @param other The source of assignment\n   * @return reference to self.\n   */\n  Map<K, V>& operator=(const Map<K, V>& other) {\n    obj_ = other.obj_;\n    return *this;\n  }\n  /**\n   * @brief reset the list to content from iterator.\n   * @param begin begin of iterator\n   * @param end end of iterator\n   * @tparam IterType The type of iterator\n   */\n  template <typename IterType>\n  void assign(IterType begin, IterType end) {\n    auto n = std::shared_ptr<MapObject>();\n    for (IterType i = begin; i != end; ++i) {\n      n->data.emplace(std::make_pair(i->first.obj_, i->second.obj_));\n    }\n    obj_ = std::move(n);\n  }\n  /**\n   * @brief Read element from map.\n   * @param key The key\n   * @return the corresonding element.\n   */\n  inline const V operator[](const K& key) const {\n    return V(static_cast<const MapObject*>(obj_.get())->data.at(key.obj_));\n  }\n  /**\n   * @brief Read element from map.\n   * @param key The key\n   * @return the corresonding element.\n   */\n  inline const V at(const K& key) const {\n    return V(static_cast<const MapObject*>(obj_.get())->data.at(key.obj_));\n  }\n  /** @return The size of the list */\n  inline size_t size() const {\n    if (obj_.get() == nullptr) return 0;\n    return static_cast<const MapObject*>(obj_.get())->data.size();\n  }\n  /** @return The size of the list */\n  inline size_t count(const K& key) const {\n    if (obj_.get() == nullptr) return 0;\n    return static_cast<const MapObject*>(obj_.get())->data.count(key.obj_);\n  }\n  /**\n   * @brief copy on write semantics\n   *  Do nothing if current handle is the 
unique copy of the list.\n   *  Otherwise make a new copy of the list to ensure the current handle\n   *  hold a unique copy.\n   *\n   * @return Handle to the internal obj container(which ganrantees to be unique)\n   */\n  inline MapObject* CopyOnWrite() {\n    if (obj_.get() == nullptr || !obj_.unique()) {\n      obj_ = std::make_shared<MapObject>(\n          *static_cast<const MapObject*>(obj_.get()));\n    }\n    return static_cast<MapObject*>(obj_.get());\n  }\n  /**\n   * @brief set the Map.\n   * @param key The index key.\n   * @param value The value to be setted.\n   */\n  inline void Set(const K& key, const V& value) {\n    MapObject* n = this->CopyOnWrite();\n    n->data[key.obj_] = value.obj_;\n  }\n\n  /** @return whether list is empty */\n  inline bool empty() const { return size() == 0; }\n  /** @brief specify container obj */\n  using ContainerType = MapObject;\n\n  struct Ptr2ObjectRef {\n    using ResultType = std::pair<K, V>;\n    static inline ResultType convert(\n        const std::pair<std::shared_ptr<Object>, std::shared_ptr<Object> >& n) {\n      return std::make_pair(K(n.first), V(n.second));\n    }\n  };\n\n  using iterator =\n      IterAdapter<Ptr2ObjectRef, MapObject::ContainerType::const_iterator>;\n\n  /** @return begin iterator */\n  inline iterator begin() const {\n    return iterator(static_cast<const MapObject*>(obj_.get())->data.begin());\n  }\n  /** @return end iterator */\n  inline iterator end() const {\n    return iterator(static_cast<const MapObject*>(obj_.get())->data.end());\n  }\n  /** @return begin iterator */\n  inline iterator find(const K& key) const {\n    return iterator(\n        static_cast<const MapObject*>(obj_.get())->data.find(key.obj_));\n  }\n};\n\n// specialize of string map\ntemplate <typename V, typename T1, typename T2>\nclass Map<std::string, V, T1, T2> : public ObjectRef {\n public:\n  // for code reuse\n  Map() { obj_ = std::make_shared<StrMapObject>(); }\n  Map(Map<std::string, V>&& other) {  // 
NOLINT(*)\n    obj_ = std::move(other.obj_);\n  }\n  Map(const Map<std::string, V>& other) : ObjectRef(other.obj_) {  // NOLINT(*)\n  }\n  explicit Map(std::shared_ptr<Object> n) : ObjectRef(n) {}\n  template <typename IterType>\n  Map(IterType begin, IterType end) {\n    assign(begin, end);\n  }\n  Map(std::initializer_list<std::pair<std::string, V> > init) {  // NOLINT(*)\n    assign(init.begin(), init.end());\n  }\n\n  template <typename Hash, typename Equal>\n  Map(const std::unordered_map<std::string, V, Hash, Equal>&\n          init) {  // NOLINT(*)\n    assign(init.begin(), init.end());\n  }\n  Map<std::string, V>& operator=(Map<std::string, V>&& other) {\n    obj_ = std::move(other.obj_);\n    return *this;\n  }\n  Map<std::string, V>& operator=(const Map<std::string, V>& other) {\n    obj_ = other.obj_;\n    return *this;\n  }\n  template <typename IterType>\n  void assign(IterType begin, IterType end) {\n    auto n = std::make_shared<StrMapObject>();\n    for (IterType i = begin; i != end; ++i) {\n      n->data.emplace(std::make_pair(i->first, i->second.obj_));\n    }\n    obj_ = std::move(n);\n  }\n  inline const V operator[](const std::string& key) const {\n    return V(static_cast<const StrMapObject*>(obj_.get())->data.at(key));\n  }\n  inline const V at(const std::string& key) const {\n    return V(static_cast<const StrMapObject*>(obj_.get())->data.at(key));\n  }\n  inline size_t size() const {\n    if (obj_.get() == nullptr) return 0;\n    return static_cast<const StrMapObject*>(obj_.get())->data.size();\n  }\n  inline size_t count(const std::string& key) const {\n    if (obj_.get() == nullptr) return 0;\n    return static_cast<const StrMapObject*>(obj_.get())->data.count(key);\n  }\n  /**\n   * @brief copy on write semantics\n   *  Keep the handle if it is the unique owner, otherwise replace it with a\n   *  private copy. A null handle is replaced by a new, empty container.\n   *\n   * @return Handle to the uniquely owned internal container.\n   */\n  inline StrMapObject* CopyOnWrite() {\n    if (obj_.get() == nullptr) {\n      obj_ = std::make_shared<StrMapObject>();\n    } else if (!obj_.unique()) {\n      // Copy as StrMapObject: this specialization stores StrMapObject, so\n      // copying through MapObject would produce an object of the wrong type.\n      obj_ = std::make_shared<StrMapObject>(\n          *static_cast<const StrMapObject*>(obj_.get()));\n    }\n    return static_cast<StrMapObject*>(obj_.get());\n  }\n  inline void Set(const std::string& key, const V& value) {\n    StrMapObject* n = this->CopyOnWrite();\n    n->data[key] = value.obj_;\n  }\n  inline bool empty() const { return size() == 0; }\n  using ContainerType = StrMapObject;\n\n  struct Ptr2ObjectRef {\n    using ResultType = std::pair<std::string, V>;\n    static inline ResultType convert(\n        const std::pair<std::string, std::shared_ptr<Object> >& n) {\n      return std::make_pair(n.first, V(n.second));\n    }\n  };\n\n  using iterator =\n      IterAdapter<Ptr2ObjectRef, StrMapObject::ContainerType::const_iterator>;\n\n  /** @return begin iterator */\n  inline iterator begin() const {\n    return iterator(static_cast<const StrMapObject*>(obj_.get())->data.begin());\n  }\n  /** @return end iterator */\n  inline iterator end() const {\n    return iterator(static_cast<const StrMapObject*>(obj_.get())->data.end());\n  }\n  /** @return iterator from looking up key in the underlying container */\n  inline iterator find(const std::string& key) const {\n    return iterator(\n        static_cast<const StrMapObject*>(obj_.get())->data.find(key));\n  }\n};\n\n/**\n * @brief Helper function to convert a List<Value> object to a vector.\n * @tparam T element type\n * @param list Input list object.\n * @return std vector\n */\ntemplate <typename T>\ninline std::vector<T> ListValueToVector(const List<Value>& list) {\n  std::vector<T> ret;\n  ret.reserve(list.size());\n  for (Value val : list)\n    // (BarclayII) apparently MSVC 2017 CL 19.10 had trouble parsing\n    //     ret.push_back(val->data)\n    // So I kindly tell it how to properly parse it.\n    ret.push_back(val->data.operator T());\n  return ret;\n}\n\n}  // namespace runtime\n}  // namespace dgl\n\n#endif  // DGL_RUNTIME_CONTAINER_H_\n"
  },
  {
    "path": "include/dgl/runtime/device_api.h",
    "content": "/**\n *  Copyright (c) 2016 by Contributors\n * @file dgl/runtime/device_api.h\n * @brief Abstract device memory management API\n */\n#ifndef DGL_RUNTIME_DEVICE_API_H_\n#define DGL_RUNTIME_DEVICE_API_H_\n\n#include <string>\n\n#include \"c_runtime_api.h\"\n#include \"packed_func.h\"\n\nnamespace dgl {\nnamespace runtime {\n/**\n * @brief the query type into GetAttr\n */\nenum DeviceAttrKind : int {\n  kExist = 0,\n  kMaxThreadsPerBlock = 1,\n  kWarpSize = 2,\n  kMaxSharedMemoryPerBlock = 3,\n  kComputeVersion = 4,\n  kDeviceName = 5,\n  kMaxClockRate = 6,\n  kMultiProcessorCount = 7,\n  kMaxThreadDimensions = 8\n};\n\n/** @brief Number of bytes each allocation must align to */\nconstexpr int kAllocAlignment = 64;\n\n/** @brief Number of bytes each allocation must align to in temporary allocation\n */\nconstexpr int kTempAllocaAlignment = 64;\n\n/** @brief Maximum size that can be allocated on stack */\nconstexpr int kMaxStackAlloca = 1024;\n\n/**\n * @brief DGL Runtime Device API, abstracts the device\n *  specific interface for memory management.\n */\nclass DeviceAPI {\n public:\n  /** @brief virtual destructor */\n  virtual ~DeviceAPI() {}\n  /**\n   * @brief Check whether the device is available.\n   */\n  virtual bool IsAvailable() { return true; }\n\n  /**\n   * @brief Set the environment device id to ctx\n   * @param ctx The context to be set.\n   */\n  virtual void SetDevice(DGLContext ctx) = 0;\n\n  /**\n   * @brief Get attribute of specified device.\n   * @param ctx The device context\n   * @param kind The result kind\n   * @param rv The return value.\n   * @sa DeviceAttrKind\n   */\n  virtual void GetAttr(\n      DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) = 0;\n\n  /**\n   * @brief Allocate a data space on device.\n   * @param ctx The device context to perform operation.\n   * @param nbytes The number of bytes in memory.\n   * @param alignment The alignment of the memory.\n   * @param type_hint The type of elements. 
Only needed by certain backends such\n   * as OpenGL, as nbytes & alignment are sufficient for most backends.\n   * @return The allocated device pointer.\n   */\n  virtual void* AllocDataSpace(\n      DGLContext ctx, size_t nbytes, size_t alignment,\n      DGLDataType type_hint) = 0;\n\n  /**\n   * @brief Free a data space on device.\n   * @param ctx The device context to perform operation.\n   * @param ptr The data space.\n   */\n  virtual void FreeDataSpace(DGLContext ctx, void* ptr) = 0;\n\n  /**\n   * @brief copy data from one place to another\n   * @param from The source array.\n   * @param from_offset The byte offeset in the from.\n   * @param to The target array.\n   * @param to_offset The byte offset in the to.\n   * @param num_bytes The size of the memory in bytes.\n   * @param ctx_from The source context.\n   * @param ctx_to The target context.\n   * @param type_hint The type of elements, only needed by certain backends,\n   *     can be useful for cross device endian converison.\n   */\n  virtual void CopyDataFromTo(\n      const void* from, size_t from_offset, void* to, size_t to_offset,\n      size_t num_bytes, DGLContext ctx_from, DGLContext ctx_to,\n      DGLDataType type_hint) = 0;\n\n  /**\n   * @brief copy data between device and CPU while recording the event.\n   * @param from The source array.\n   * @param from_offset The byte offeset in the from.\n   * @param to The target array.\n   * @param to_offset The byte offset in the to.\n   * @param num_bytes The size of the memory in bytes.\n   * @param ctx_from The source context.\n   * @param ctx_to The target context.\n   * @param type_hint The type of elements, only needed by certain backends,\n   *     can be useful for cross device endian converison.\n   * @param pytorch_ctx The context pointer from PyTorch's CachingHostAllocator.\n   * @note This function only works when PyTorch CachingHostAllocator is\n   *     available.\n   */\n  virtual void RecordedCopyDataFromTo(\n      void* from, size_t 
from_offset, void* to, size_t to_offset,\n      size_t num_bytes, DGLContext ctx_from, DGLContext ctx_to,\n      DGLDataType type_hint, void* pytorch_ctx) = 0;\n\n  /**\n   * @brief Create a new stream of execution.\n   *\n   * @param ctx The context of allocation.\n   */\n  DGL_DLL virtual DGLStreamHandle CreateStream(DGLContext ctx);\n\n  /**\n   * @brief Free a stream of execution\n   *\n   * @param ctx The context of the stream\n   * @param stream The pointer to be freed.\n   */\n  DGL_DLL virtual void FreeStream(DGLContext ctx, DGLStreamHandle stream);\n\n  /**\n   * @brief Synchronize the stream\n   * @param ctx The context to perform operation.\n   * @param stream The stream to be sync.\n   */\n  virtual void StreamSync(DGLContext ctx, DGLStreamHandle stream) = 0;\n\n  /**\n   * @brief Set the stream\n   * @param ctx The context to set stream.\n   * @param stream The stream to be set.\n   */\n  virtual void SetStream(DGLContext ctx, DGLStreamHandle stream) {}\n\n  /**\n   * @brief Get the stream\n   */\n  virtual DGLStreamHandle GetStream() const { return nullptr; }\n\n  /**\n   * @brief Synchronize 2 streams of execution.\n   *\n   * An event is created in event_src stream that the second then\n   * stream waits on.  Neither event_src or event_dst need to be of\n   * the same device ID as the context, but they must be of the same\n   * device type.\n   *\n   * @param ctx The context of the streams.\n   * @param event_src The source stream to synchronize.\n   * @param event_dst The destination stream to synchronize.\n   */\n  DGL_DLL virtual void SyncStreamFromTo(\n      DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst);\n\n  /**\n   * @brief Pin host memory using cudaHostRegister().\n   *\n   * @param ptr The host memory pointer to be pinned.\n   * @param nbytes The size to be pinned.\n   * @return false when pinning an empty tensor. 
true otherwise.\n   */\n  DGL_DLL virtual bool PinData(void* ptr, size_t nbytes);\n\n  /**\n   * @brief Unpin host memory using cudaHostUnregister().\n   *\n   * @param ptr The host memory pointer to be unpinned.\n   */\n  DGL_DLL virtual void UnpinData(void* ptr);\n\n  /**\n   * @brief Allocate the pinned memory using PyTorch CachingHostAllocator.\n   *\n   * @param nbytes The size to be pinned.\n   * @param ctx Pointer to the context pointer from PyTorch's\n   *     CachingHostAllocator.\n   * @param deleter Pointer to the deleter function from PyTorch's\n   *     CachingHostAllocator.\n   */\n  DGL_DLL virtual void* AllocPinnedDataSpace(\n      size_t nbytes, void** ctx, void** deleter);\n\n  /**\n   * @brief 'Deallocate' the pinned memory from PyTorch CachingHostAllocator.\n   * @note It avoids unnecessary cudaFreeHost calls and puts the memory\n   *     block into CachingHostAllocator's free list.\n   * @param deleter Pointer to the deleter function from PyTorch's\n   *     CachingHostAllocator.\n   */\n  DGL_DLL virtual void FreePinnedDataSpace(void** deleter);\n\n  /**\n   * @brief Check whether the memory is in pinned memory.\n   */\n  DGL_DLL virtual bool IsPinned(const void* ptr) { return false; }\n\n  /**\n   * @brief Allocate temporal workspace for backend execution.\n   *\n   *  \\note We have the following assumption about backend temporal\n   *   workspace allocation, and backend will optimize for such assumption:\n   *\n   *  - Only a few allocation will happen, and space will be released after use.\n   *  - The release order is usually in reverse order of allocate (stack style).\n   *  - Repeative pattern of same allocations over different runs.\n   *  - Workspace should not overlap between different threads(i.e. be\n   * threadlocal)\n   *\n   * @param ctx The context of allocation.\n   * @param nbytes The size to be allocated.\n   * @param type_hint The type of elements. 
Only needed by certain backends such\n   * as OpenGL, as nbytes is sufficient for most backends.\n   */\n  DGL_DLL virtual void* AllocWorkspace(\n      DGLContext ctx, size_t nbytes, DGLDataType type_hint = {});\n\n  /**\n   * @brief Free temporal workspace in backend execution.\n   *\n   * @param ctx The context of allocation.\n   * @param ptr The pointer to be freed.\n   */\n  DGL_DLL virtual void FreeWorkspace(DGLContext ctx, void* ptr);\n\n  /**\n   * @brief Get device API based on context.\n   * @param ctx The context\n   * @param allow_missing Whether allow missing\n   * @return The corresponding device API.\n   */\n  DGL_DLL static DeviceAPI* Get(DGLContext ctx, bool allow_missing = false);\n\n  /**\n   * @brief Get device API based on device type.\n   * @param dev_type The device type\n   * @param allow_missing Whether allow missing\n   * @return The corresponding device API.\n   */\n  DGL_DLL static DeviceAPI* Get(\n      DGLDeviceType dev_type, bool allow_missing = false);\n};\n\n/** @brief The device type bigger than this is RPC device */\nconstexpr int kRPCSessMask = 128;\n}  // namespace runtime\n}  // namespace dgl\n#endif  // DGL_RUNTIME_DEVICE_API_H_\n"
  },
  {
    "path": "include/dgl/runtime/dlpack_convert.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file include/dgl/runtime/dlpack_convert.h\n * @brief Conversion between NDArray and DLPack.\n */\n#ifndef DGL_RUNTIME_DLPACK_CONVERT_H_\n#define DGL_RUNTIME_DLPACK_CONVERT_H_\n\n#include \"c_runtime_api.h\"\n#include \"ndarray.h\"\n\nstruct DLManagedTensor;\n\nnamespace dgl {\nnamespace runtime {\n\nstruct DLPackConvert {\n  /**\n   * @brief Create a DGL NDArray from a DLPack tensor.\n   *\n   * This allows us to create a NDArray using the memory\n   * allocated by an external deep learning framework\n   * that is DLPack compatible.\n   *\n   * The memory is retained until the NDArray went out of scope.\n   * @param tensor The DLPack tensor to copy from.\n   * @return The created NDArray view.\n   */\n  static NDArray FromDLPack(DLManagedTensor* tensor);\n\n  /**\n   * @brief Deleter for NDArray converted from DLPack.\n   *\n   * This is used from data which is passed from external\n   * DLPack(DLManagedTensor) that are not allocated inside of DGL. 
This enables\n   * us to create NDArray from memory allocated by other frameworks that are\n   * DLPack compatible\n   */\n  static void DLPackDeleter(NDArray::Container* ptr);\n\n  /** @brief Convert a DGL NDArray to a DLPack tensor.\n   *\n   * @param from The DGL NDArray.\n   * @return A DLPack tensor.\n   */\n  static DLManagedTensor* ToDLPack(const NDArray& from);\n};\n\n}  // namespace runtime\n}  // namespace dgl\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/**\n * @brief Delete (free) a DLManagedTensor's data.\n * @param dltensor Pointer to the DLManagedTensor.\n */\nDGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor);\n\n/**\n * @brief Produce an array from the DLManagedTensor that shares data memory\n * with the DLManagedTensor.\n * @param from The source DLManagedTensor.\n * @param out The output array handle.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from, DGLArrayHandle* out);\n\n/**\n * @brief Produce a DLManagedTensor from the array that shares data memory with\n * the array.\n * @param from The source array.\n * @param out The DLManagedTensor handle.\n * @return 0 when success, -1 when failure happens\n */\nDGL_DLL int DGLArrayToDLPack(\n    DGLArrayHandle from, DLManagedTensor** out, int alignment = 0);\n\n#ifdef __cplusplus\n}  // DGL_EXTERN_C\n#endif\n#endif  // DGL_RUNTIME_DLPACK_CONVERT_H_\n"
  },
  {
    "path": "include/dgl/runtime/module.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file dgl/runtime/module.h\n * @brief Runtime container of the functions generated by DGL,\n *  This is used to support dynamically link, load and save\n *  functions from different convention under unified API.\n */\n#ifndef DGL_RUNTIME_MODULE_H_\n#define DGL_RUNTIME_MODULE_H_\n\n#include <dmlc/io.h>\n\n#include <memory>\n#include <string>\n#include <unordered_map>\n#include <vector>\n\n#include \"c_runtime_api.h\"\n\nnamespace dgl {\nnamespace runtime {\n\n// The internal container of module.\nclass ModuleNode;\nclass PackedFunc;\n\n/**\n * @brief Module container of DGL.\n */\nclass Module {\n public:\n  Module() {}\n  // constructor from container.\n  explicit Module(std::shared_ptr<ModuleNode> n) : node_(n) {}\n  /**\n   * @brief Get packed function from current module by name.\n   *\n   * @param name The name of the function.\n   * @param query_imports Whether also query dependency modules.\n   * @return The result function.\n   *  This function will return PackedFunc(nullptr) if function do not exist.\n   * @note Implemented in packed_func.cc\n   */\n  inline PackedFunc GetFunction(\n      const std::string& name, bool query_imports = false);\n  /** @return internal container */\n  inline ModuleNode* operator->();\n  /** @return internal container */\n  inline const ModuleNode* operator->() const;\n  // The following functions requires link with runtime.\n  /**\n   * @brief Import another module into this module.\n   * @param other The module to be imported.\n   *\n   * @note Cyclic dependency is not allowed among modules,\n   *  An error will be thrown when cyclic dependency is detected.\n   */\n  DGL_DLL void Import(Module other);\n  /**\n   * @brief Load a module from file.\n   * @param file_name The name of the host function module.\n   * @param format The format of the file.\n   * @note This function won't load the import relationship.\n   *  Re-create import relationship by calling 
Import.\n   */\n  DGL_DLL static Module LoadFromFile(\n      const std::string& file_name, const std::string& format = \"\");\n\n private:\n  std::shared_ptr<ModuleNode> node_;\n};\n\n/**\n * @brief Base node container of module.\n *  Do not create this directly, instead use Module.\n */\nclass ModuleNode {\n public:\n  /** @brief virtual destructor */\n  virtual ~ModuleNode() {}\n  /** @return The module type key */\n  virtual const char* type_key() const = 0;\n  /**\n   * @brief Get a PackedFunc from module.\n   *\n   *  The PackedFunc may not be fully initialized,\n   *  there might still be first time running overhead when\n   *  executing the function on certain devices.\n   *  For benchmarking, use prepare to eliminate\n   *\n   * @param name the name of the function.\n   * @param sptr_to_self The shared_ptr that points to this module node.\n   *\n   * @return PackedFunc(nullptr) when it is not available.\n   *\n   * @note The function will always remain valid.\n   *   If the function need resource from the module(e.g. 
late linking),\n   *   it should capture sptr_to_self.\n   */\n  virtual PackedFunc GetFunction(\n      const std::string& name,\n      const std::shared_ptr<ModuleNode>& sptr_to_self) = 0;\n  /**\n   * @brief Save the module to file.\n   * @param file_name The file to be saved to.\n   * @param format The format of the file.\n   */\n  virtual void SaveToFile(\n      const std::string& file_name, const std::string& format);\n  /**\n   * @brief Save the module to binary stream.\n   * @param stream The binary stream to save to.\n   * @note It is recommended to implement this for device modules,\n   *   but not necessarily host modules.\n   *   We can use this to do AOT loading of bundled device functions.\n   */\n  DGL_DLL virtual void SaveToBinary(dmlc::Stream* stream);\n  /**\n   * @brief Get the source code of module, when available.\n   * @param format Format of the source code, can be empty by default.\n   * @return Possible source code when available.\n   */\n  DGL_DLL virtual std::string GetSource(const std::string& format = \"\");\n  /**\n   * @brief Get a function from current environment\n   *  The environment includes all the imports as well as Global functions.\n   *\n   * @param name name of the function.\n   * @return The corresponding function.\n   */\n  DGL_DLL const PackedFunc* GetFuncFromEnv(const std::string& name);\n  /** @return The module it imports from */\n  const std::vector<Module>& imports() const { return imports_; }\n\n protected:\n  friend class Module;\n  /** @brief The modules this module depend on */\n  std::vector<Module> imports_;\n\n private:\n  /** @brief Cache used by GetImport */\n  std::unordered_map<std::string, std::unique_ptr<PackedFunc> > import_cache_;\n};\n\n/** @brief namespace for constant symbols */\nnamespace symbol {\n/** @brief Global variable to store module context. 
*/\nconstexpr const char* dgl_module_ctx = \"__dgl_module_ctx\";\n/** @brief Global variable to store device module blob */\nconstexpr const char* dgl_dev_mblob = \"__dgl_dev_mblob\";\n/** @brief Number of bytes of device module blob. */\nconstexpr const char* dgl_dev_mblob_nbytes = \"__dgl_dev_mblob_nbytes\";\n/** @brief global function to set device */\nconstexpr const char* dgl_set_device = \"__dgl_set_device\";\n/** @brief Auxiliary counter to global barrier. */\nconstexpr const char* dgl_global_barrier_state = \"__dgl_global_barrier_state\";\n/**\n * @brief Prepare the global barrier before kernels that uses global barrier.\n */\nconstexpr const char* dgl_prepare_global_barrier =\n    \"__dgl_prepare_global_barrier\";\n/** @brief Placeholder for the module's entry function. */\nconstexpr const char* dgl_module_main = \"__dgl_main__\";\n}  // namespace symbol\n\n// implementations of inline functions.\ninline ModuleNode* Module::operator->() { return node_.get(); }\n\ninline const ModuleNode* Module::operator->() const { return node_.get(); }\n\n}  // namespace runtime\n}  // namespace dgl\n\n#include \"packed_func.h\"\n#endif  // DGL_RUNTIME_MODULE_H_\n"
  },
  {
    "path": "include/dgl/runtime/ndarray.h",
    "content": "/**\n *  Copyright (c) 2017-2022 by Contributors\n * @file dgl/runtime/ndarray.h\n * @brief Abstract device memory management API\n */\n#ifndef DGL_RUNTIME_NDARRAY_H_\n#define DGL_RUNTIME_NDARRAY_H_\n\n#include <atomic>\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"bfloat16.h\"\n#include \"c_runtime_api.h\"\n#include \"serializer.h\"\n#include \"shared_mem.h\"\n\n#ifdef DGL_USE_CUDA\n#include <cuda_runtime.h>\n\n#define BF16_ENABLED (defined(CUDART_VERSION) && CUDART_VERSION >= 11000)\n\n#include <cuda_fp16.h>\n#if BF16_ENABLED\n#include <cuda_bf16.h>\n#endif  // BF16_ENABLED\n#endif  // DGL_USE_CUDA\n\n// forward declaration\ninline std::ostream& operator<<(std::ostream& os, DGLDataType t);\n\nnamespace dgl {\n\n/**\n * @brief Type traits that converts a C type to a DGLDataType.\n *\n * Usage:\n * DGLDataTypeTraits<int>::dtype == dtype\n */\ntemplate <typename T>\nstruct DGLDataTypeTraits {\n  static constexpr DGLDataType dtype{0, 0, 0};  // dummy\n};\n#define GEN_DGLDATATYPETRAITS_FOR(T, code, bits)       \\\n  template <>                                          \\\n  struct DGLDataTypeTraits<T> {                        \\\n    static constexpr DGLDataType dtype{code, bits, 1}; \\\n  }\nGEN_DGLDATATYPETRAITS_FOR(int8_t, kDGLInt, 8);\nGEN_DGLDATATYPETRAITS_FOR(uint8_t, kDGLUInt, 8);\nGEN_DGLDATATYPETRAITS_FOR(int16_t, kDGLInt, 16);\nGEN_DGLDATATYPETRAITS_FOR(int32_t, kDGLInt, 32);\nGEN_DGLDATATYPETRAITS_FOR(int64_t, kDGLInt, 64);\n// XXX(BarclayII) most DL frameworks do not support unsigned int and long\n// arrays, so I'm just converting uints to signed DTypes.\nGEN_DGLDATATYPETRAITS_FOR(uint32_t, kDGLInt, 32);\nGEN_DGLDATATYPETRAITS_FOR(uint64_t, kDGLInt, 64);\n#ifdef DGL_USE_CUDA\nGEN_DGLDATATYPETRAITS_FOR(__half, kDGLFloat, 16);\n#if BF16_ENABLED\nGEN_DGLDATATYPETRAITS_FOR(__nv_bfloat16, kDGLBfloat, 16);\n#endif  // BF16_ENABLED\n#endif  // DGL_USE_CUDA\nGEN_DGLDATATYPETRAITS_FOR(float, kDGLFloat, 
32);\nGEN_DGLDATATYPETRAITS_FOR(double, kDGLFloat, 64);\n#undef GEN_DGLDATATYPETRAITS_FOR\n\nnamespace runtime {\n\n/**\n * @brief DLPack converter.\n */\nstruct DLPackConvert;\n\n/**\n * @brief Managed NDArray.\n *  The array is backed by reference counted blocks.\n */\nclass NDArray {\n public:\n  // internal container type\n  struct Container;\n  /** @brief default constructor */\n  NDArray() {}\n  /**\n   * @brief construct a NDArray that refers to data\n   * @param data The data this NDArray refers to\n   */\n  explicit inline NDArray(Container* data);\n  /**\n   * @brief copy constructor\n   * @param other The value to be copied\n   */\n  inline NDArray(const NDArray& other);  // NOLINT(*)\n  /**\n   * @brief move constructor\n   * @param other The value to be moved\n   */\n  NDArray(NDArray&& other)  // NOLINT(*)\n      : data_(other.data_) {\n    other.data_ = nullptr;\n  }\n  /** @brief destructor */\n  ~NDArray() { this->reset(); }\n  /**\n   * @brief Swap this array with another NDArray\n   * @param other The other NDArray\n   */\n  void swap(NDArray& other) {  // NOLINT(*)\n    std::swap(data_, other.data_);\n  }\n  /**\n   * @brief copy assignment\n   * @param other The value to be assigned.\n   * @return reference to self.\n   */\n  NDArray& operator=(const NDArray& other) {  // NOLINT(*)\n    // copy-and-swap idiom\n    NDArray(other).swap(*this);  // NOLINT(*)\n    return *this;\n  }\n  /**\n   * @brief move assignment\n   * @param other The value to be assigned.\n   * @return reference to self.\n   */\n  NDArray& operator=(NDArray&& other) {  // NOLINT(*)\n    // copy-and-swap idiom\n    NDArray(std::move(other)).swap(*this);  // NOLINT(*)\n    return *this;\n  }\n  /** @return If NDArray is defined */\n  bool defined() const { return data_ != nullptr; }\n  /** @return If both NDArray reference the same container */\n  bool same_as(const NDArray& other) const { return data_ == other.data_; }\n  /** @brief reset the content of NDArray to be nullptr 
*/\n  inline void reset();\n  /**\n   * @return the reference counter\n   * @note this number is approximate in multi-threaded setting.\n   */\n  inline int use_count() const;\n  /** @return Pointer to content of DGLArray */\n  inline const DGLArray* operator->() const;\n  /** @return True if the ndarray is contiguous. */\n  bool IsContiguous() const;\n  /** @return the data pointer with type. */\n  template <typename T>\n  inline T* Ptr() const {\n    if (!defined())\n      return nullptr;\n    else\n      return static_cast<T*>(operator->()->data);\n  }\n\n  /**\n   * @brief Copy data content from/into another array.\n   * @param other The source array to be copied from.\n   * @note The copy runs on the dgl internal stream if it involves a GPU\n   * context.\n   */\n  inline void CopyFrom(DGLArray* other);\n  inline void CopyFrom(const NDArray& other);\n  inline void CopyTo(DGLArray* other) const;\n  inline void CopyTo(const NDArray& other) const;\n\n  /**\n   * @brief Copy the data to another context.\n   * @param ctx The target context.\n   * @return The array under another context.\n   */\n  inline NDArray CopyTo(const DGLContext& ctx) const;\n\n  /**\n   * @brief Return a new array with a copy of the content.\n   */\n  inline NDArray Clone() const;\n\n  /**\n   * @brief Return a copy of the current instance of NDArray in pinned\n   *     (page-locked) memory.\n   * @note This is an out-of-place method, which utilizes PyTorch's\n   *     CachingHostAllocator for allocating pinned memory and copying data\n   *     from the current NDAarray. As a result, PyTorch is responsible for\n   *     managing the lifecycle of the returned NDArray, including deciding\n   *     when to flush the data for reuse or call cudaFreeHost. 
The current\n   *     context must be kDGLCPU, otherwise, an error will be thrown.\n   */\n  inline NDArray PinMemory();\n\n  /**\n   * @brief In-place method to pin the current array by calling PinContainer\n   *        on the underlying NDArray:Container.\n   * @note This is an in-place method that flags the memory as page-locked by\n   *     utilizing cudaHostRegister at the underlying level to pin the current\n   *     instance of NDArray. The current context must be kDGLCPU, otherwise,\n   *     an error will be thrown.\n   */\n  inline void PinMemory_();\n\n  /**\n   * @brief In-place method to unpin the current array by calling UnpinContainer\n   *        on the underlying NDArray:Container.\n   * @note This is an in-place method. Behavior depends on the current context,\n   *       IsPinned: will be unpinned;\n   *       others: directly return.\n   */\n  inline void UnpinMemory_();\n\n  /**\n   * @brief Check if the array is pinned.\n   */\n  inline bool IsPinned() const;\n\n  /**\n   * @brief Record streams that are using the underlying tensor.\n   * @param stream The stream that is using the underlying tensor.\n   */\n  inline void RecordStream(DGLStreamHandle stream) const;\n\n  /**\n   * @brief Load NDArray from stream\n   * @param stream The input data stream\n   * @return Whether load is successful\n   */\n  bool Load(dmlc::Stream* stream);\n\n  /**\n   * @brief Save NDArray to stream\n   * @param stream The output data stream\n   */\n  void Save(dmlc::Stream* stream) const;\n\n  /**\n   * @brief Create a NDArray that shares the data memory with the current one.\n   * @param shape The shape of the new array.\n   * @param dtype The data type of the new array.\n   * @param offset The offset (in bytes) of the starting pointer.\n   * @note The memory size of new array must be smaller than the current one.\n   */\n  DGL_DLL NDArray\n  CreateView(std::vector<int64_t> shape, DGLDataType dtype, int64_t offset = 0);\n\n  /**\n   * @brief Create an empty 
NDArray.\n   * @param shape The shape of the new array.\n   * @param dtype The data type of the new array.\n   * @param ctx The context of the array.\n   * @return The created Array\n   */\n  DGL_DLL static NDArray Empty(\n      std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx);\n\n  /**\n   * @brief Create an empty NDArray in pinned memory.\n   * @param shape The shape of the new array.\n   * @param dtype The data type of the new array.\n   * @param ctx The context of the array.\n   * @return The created array.\n   */\n  DGL_DLL static NDArray PinnedEmpty(\n      std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx);\n\n  /**\n   * @brief Create an empty NDArray with shared memory.\n   * @param name The name of shared memory.\n   * @param shape The shape of the new array.\n   * @param dtype The data type of the new array.\n   * @param ctx The context of the array.\n   * @param is_create whether to create shared memory.\n   * @return The created Array\n   */\n  DGL_DLL static NDArray EmptyShared(\n      const std::string& name, std::vector<int64_t> shape, DGLDataType dtype,\n      DGLContext ctx, bool is_create);\n\n  /**\n   * @brief Get the size of the array in the number of bytes.\n   */\n  size_t GetSize() const;\n\n  /**\n   * @brief Get the number of elements in this array.\n   */\n  int64_t NumElements() const;\n\n  /**\n   * @brief Create a NDArray by copying from std::vector.\n   * @tparam T Type of vector data.  
Determines the dtype of returned array.\n   */\n  template <typename T>\n  DGL_DLL static NDArray FromVector(\n      const std::vector<T>& vec, DGLContext ctx = DGLContext{kDGLCPU, 0});\n\n  /**\n   * @brief Create a NDArray from a raw pointer.\n   */\n  DGL_DLL static NDArray CreateFromRaw(\n      const std::vector<int64_t>& shape, DGLDataType dtype, DGLContext ctx,\n      void* raw, bool auto_free);\n\n  /**\n   * @brief Create a std::vector from a 1D NDArray.\n   * @tparam T Type of vector data.\n   * @note Type casting is NOT performed.  The caller has to make sure that the\n   * vector type matches the dtype of NDArray.\n   */\n  template <typename T>\n  std::vector<T> ToVector() const;\n\n  std::shared_ptr<SharedMemory> GetSharedMem() const;\n\n  /**\n   * @brief Function to copy data from one array to another.\n   * @param from The source array.\n   * @param to The target array.\n   * @param (optional) stream The stream used in copy.\n   */\n  DGL_DLL static void CopyFromTo(DGLArray* from, DGLArray* to);\n  DGL_DLL static void CopyFromTo(\n      DGLArray* from, DGLArray* to, DGLStreamHandle stream);\n\n  /**\n   * @brief Function to copy data between device and CPU while recording the\n   *     event.\n   * @param from The source array.\n   * @param to The target array.\n   * @param pytorch_ctx The context pointer from PyTorch's CachingHostAllocator.\n   * @note This function fuses data-copy and event recording to ensure\n   *     CachingHostAllocator works properly.\n   */\n  DGL_DLL static void RecordedCopyFromTo(\n      DGLArray* from, DGLArray* to, void* pytorch_ctx);\n\n  /**\n   * @brief Function to pin the DGLArray of a Container.\n   * @param ptr The container to be pinned.\n   * @note Data of the given array will be pinned inplace.\n   *       Behavior depends on the current context,\n   *       kDGLCPU: will be pinned;\n   *       IsPinned: directly return;\n   *       kDGLCUDA: invalid, will throw an error.\n   */\n  DGL_DLL static void 
PinContainer(Container* ptr);\n\n  /**\n   * @brief Function to unpin the DGLArray of a Container.\n   * @param ptr The container to be unpinned.\n   * @note Data of the given array will be unpinned inplace.\n   *       Behavior depends on the current context,\n   *       IsPinned: will be unpinned;\n   *       others: directly return.\n   */\n  DGL_DLL static void UnpinContainer(Container* ptr);\n\n  /**\n   * @brief Function check if the DGLArray of a Container is pinned.\n   * @param ptr The container to be checked.\n   * @return true if pinned.\n   */\n  DGL_DLL static bool IsContainerPinned(Container* ptr);\n\n  /**\n   * @brief Record streams that are using this tensor.\n   * @param ptr Pointer of the tensor to be recorded.\n   * @param stream The stream that is using this tensor.\n   */\n  DGL_DLL static void RecordStream(DGLArray* tensor, DGLStreamHandle stream);\n\n  // internal namespace\n  struct Internal {\n    // Default deleter for the container\n    static void DefaultDeleter(NDArray::Container* ptr);\n    // Local create function which allocates tensor metadata\n    // but does not allocate space for the data.\n    static NDArray Create(\n        std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx);\n    // Implementation of API function\n    static DGLArray* MoveAsDGLArray(NDArray arr);\n  };\n\n private:\n  /** @brief Internal Data content */\n  Container* data_{nullptr};\n  // enable internal functions\n  friend struct Internal;\n  friend struct DLPackConvert;\n  friend class DGLRetValue;\n  friend class DGLArgsSetter;\n};\n\n/**\n * @brief Save a DGLArray to stream\n * @param strm The outpu stream\n * @param tensor The tensor to be saved.\n */\ninline bool SaveDGLArray(dmlc::Stream* strm, const DGLArray* tensor);\n\n/**\n * @brief Reference counted Container object used to back NDArray.\n *\n *  This object is DGLArray compatible:\n *    the pointer to the NDArrayContainer can be directly\n *    interpreted as a DGLArray*\n *\n * 
@note: do not use this function directly, use NDArray.\n */\nstruct NDArray::Container {\n public:\n  /** NOTE: the first part of this structure is the same as\n   * DLManagedTensor, note that, however, the deleter\n   * is only called when the reference counter goes to 0\n   */\n  /**\n   * @brief Tensor structure.\n   * @note it is important that the first field is DGLArray\n   *  So that this data structure is DGLArray compatible.\n   *  The head ptr of this struct can be viewed as DGLArray*.\n   */\n  DGLArray dl_tensor;\n  /**\n   * @brief additional context, reserved for recycling\n   * @note We can attach additional content here\n   *  which the current container depends on\n   *  (e.g. reference to original memory when creating views).\n   */\n  void* manager_ctx{nullptr};\n  /**\n   * @brief Customized deleter\n   *\n   * @note The customized deleter is helpful to enable\n   *  different ways of memory allocator that are not\n   *  currently defined by the system.\n   */\n  void (*deleter)(Container* self) = nullptr;\n  /** @brief default constructor */\n  Container() {\n    dl_tensor.data = nullptr;\n    dl_tensor.ndim = 0;\n    dl_tensor.shape = nullptr;\n    dl_tensor.strides = nullptr;\n    dl_tensor.byte_offset = 0;\n  }\n  /** @brief pointer to shared memory */\n  std::shared_ptr<SharedMemory> mem;\n  /** @brief developer function, increases reference counter */\n  void IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); }\n  /** @brief developer function, decrease reference counter */\n  void DecRef() {\n    if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {\n      std::atomic_thread_fence(std::memory_order_acquire);\n      if (this->deleter != nullptr) {\n        (*this->deleter)(this);\n      }\n    }\n  }\n\n private:\n  friend struct DLPackConvert;\n  friend class NDArray;\n  friend class RPCWrappedFunc;\n  /**\n   * @brief The shape container,\n   *  can be used for shape data.\n   */\n  std::vector<int64_t> shape_;\n  
/**\n   * @brief The stride container,\n   *  can be used for stride data.\n   */\n  std::vector<int64_t> stride_;\n  /** @brief The internal array object */\n  std::atomic<int> ref_counter_{0};\n\n  /** @brief Whether underlying dl_tensor is pinned by DGL. */\n  bool pinned_by_dgl_{false};\n\n  /** @brief Whether underlying dl_tensor is pinned by PyTorch\n   *    (CachingHostAllocator). */\n  bool pinned_by_pytorch_{false};\n\n  /** @brief The PyTorch storage ctx ptr if pinned_by_pytorch_ = True. */\n  void* pytorch_ctx_{nullptr};\n\n  /** @brief Pointer to the corresp. PyTorch deleter if pinned_by_pytorch_ =\n   *    True.\n   */\n  void* pytorch_raw_deleter_{nullptr};\n};\n\n// implementations of inline functions\n// the usages of functions are documented in place.\ninline NDArray::NDArray(Container* data) : data_(data) {\n  if (data_) data_->IncRef();\n}\n\ninline NDArray::NDArray(const NDArray& other) : data_(other.data_) {\n  if (data_) data_->IncRef();\n}\n\ninline void NDArray::reset() {\n  if (data_) {\n    data_->DecRef();\n    data_ = nullptr;\n  }\n}\n\ninline void NDArray::CopyFrom(DGLArray* other) {\n  CHECK(data_ != nullptr);\n  CopyFromTo(other, &(data_->dl_tensor));\n}\n\ninline void NDArray::CopyFrom(const NDArray& other) {\n  CHECK(other.data_ != nullptr);\n  // Copy between two devices\n  if (data_->dl_tensor.ctx.device_type !=\n      other.data_->dl_tensor.ctx.device_type) {\n    CHECK(data_ != nullptr);\n    auto to_ctx_type = data_->dl_tensor.ctx.device_type;\n    auto cpu_data = (to_ctx_type == kDGLCPU ? 
data_ : other.data_);\n    // Pinned by PyTorch\n    if (cpu_data->pinned_by_pytorch_) {\n      // To ensure correct behavior, the event must be recorded after\n      // cudaMemcpyAsync as long as the memory is pinned by PyTorch.\n      void* pytorch_ctx = cpu_data->pytorch_ctx_;\n      RecordedCopyFromTo(\n          &(other.data_->dl_tensor), &(data_->dl_tensor), pytorch_ctx);\n      return;\n    }\n  }\n  CopyFrom(&(other.data_->dl_tensor));\n}\n\ninline void NDArray::CopyTo(DGLArray* other) const {\n  CHECK(data_ != nullptr);\n  CopyFromTo(&(data_->dl_tensor), other);\n}\n\ninline void NDArray::CopyTo(const NDArray& other) const {\n  CHECK(other.data_ != nullptr);\n  // copy between two devices\n  if (data_->dl_tensor.ctx.device_type !=\n      other.data_->dl_tensor.ctx.device_type) {\n    CHECK(data_ != nullptr);\n    auto from_ctx_type = data_->dl_tensor.ctx.device_type;\n    auto cpu_data = (from_ctx_type == kDGLCPU ? data_ : other.data_);\n    // pinned by PyTorch\n    if (cpu_data->pinned_by_pytorch_) {\n      // To ensure correct behavior, the event must be recorded after\n      // cudaMemcpyAsync as long as the memory is pinned by PyTorch.\n      void* pytorch_ctx = cpu_data->pytorch_ctx_;\n      RecordedCopyFromTo(\n          &(data_->dl_tensor), &(other.data_->dl_tensor), pytorch_ctx);\n      return;\n    }\n  }\n  CopyTo(&(other.data_->dl_tensor));\n}\n\ninline NDArray NDArray::CopyTo(const DGLContext& ctx) const {\n  CHECK(data_ != nullptr);\n  const DGLArray* array = operator->();\n  NDArray ret = Empty(\n      std::vector<int64_t>(array->shape, array->shape + array->ndim),\n      array->dtype, ctx);\n  this->CopyTo(ret);\n  return ret;\n}\n\ninline NDArray NDArray::Clone() const {\n  CHECK(data_ != nullptr);\n  const DGLArray* array = operator->();\n  return this->CopyTo(array->ctx);\n}\n\ninline NDArray NDArray::PinMemory() {\n  CHECK(data_ != nullptr);\n  const DGLArray* array = operator->();\n  auto ctx = array->ctx;\n  NDArray ret = 
PinnedEmpty(\n      std::vector<int64_t>(array->shape, array->shape + array->ndim),\n      array->dtype, ctx);\n  this->CopyTo(ret);\n  return ret;\n}\n\ninline void NDArray::PinMemory_() {\n  CHECK(data_ != nullptr);\n  PinContainer(data_);\n}\n\ninline void NDArray::UnpinMemory_() {\n  CHECK(data_ != nullptr);\n  UnpinContainer(data_);\n}\n\ninline bool NDArray::IsPinned() const {\n  CHECK(data_ != nullptr);\n  return IsContainerPinned(data_);\n}\n\ninline void NDArray::RecordStream(DGLStreamHandle stream) const {\n  CHECK(data_ != nullptr);\n  RecordStream(&(data_->dl_tensor), stream);\n}\n\ninline int NDArray::use_count() const {\n  if (data_ == nullptr) return 0;\n  return data_->ref_counter_.load(std::memory_order_relaxed);\n}\n\ninline const DGLArray* NDArray::operator->() const {\n  return &(data_->dl_tensor);\n}\n\n/** @brief Magic number for NDArray file */\nconstexpr uint64_t kDGLNDArrayMagic = 0xDD5E40F096B4A13F;\n\ninline bool SaveDGLArray(dmlc::Stream* strm, DGLArray* tensor) {\n  uint64_t header = kDGLNDArrayMagic, reserved = 0;\n  strm->Write(header);\n  strm->Write(reserved);\n  // Always save data as CPU context\n  //\n  // Parameters that get serialized should be in CPU by default.\n  // So even the array's context is GPU, it will be stored as CPU array.\n  // This is used to prevent case when another user loads the parameters\n  // back on machine that do not have GPU or related context.\n  //\n  // We can always do array.CopyTo(target_ctx) to get a corresponding\n  // array in the target context.\n  DGLContext cpu_ctx;\n  cpu_ctx.device_type = kDGLCPU;\n  cpu_ctx.device_id = 0;\n  strm->Write(cpu_ctx);\n  strm->Write(tensor->ndim);\n  strm->Write(tensor->dtype);\n  int ndim = tensor->ndim;\n  strm->WriteArray(tensor->shape, ndim);\n  int type_bytes = tensor->dtype.bits / 8;\n  int64_t num_elems = 1;\n  for (int i = 0; i < ndim; ++i) {\n    num_elems *= tensor->shape[i];\n  }\n  int64_t data_byte_size = type_bytes * num_elems;\n  
strm->Write(data_byte_size);\n\n  if (DMLC_IO_NO_ENDIAN_SWAP && tensor->ctx.device_type == kDGLCPU &&\n      tensor->strides == nullptr && tensor->byte_offset == 0) {\n    // quick path\n    strm->Write(tensor->data, data_byte_size);\n  } else {\n    std::vector<uint8_t> bytes(data_byte_size);\n    CHECK_EQ(\n        DGLArrayCopyToBytes(tensor, dmlc::BeginPtr(bytes), data_byte_size), 0)\n        << DGLGetLastError();\n    if (!DMLC_IO_NO_ENDIAN_SWAP) {\n      dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);\n    }\n    strm->Write(dmlc::BeginPtr(bytes), data_byte_size);\n  }\n  return true;\n}\n\n/**\n * @brief Convert type code to its name\n * @param type_code The type code .\n * @return The name of type code.\n */\ninline const char* TypeCode2Str(int type_code) {\n  switch (type_code) {\n    case kDGLInt:\n      return \"int\";\n    case kDGLUInt:\n      return \"uint\";\n    case kDGLFloat:\n      return \"float\";\n    case kStr:\n      return \"str\";\n    case kBytes:\n      return \"bytes\";\n    case kHandle:\n      return \"handle\";\n    case kNull:\n      return \"NULL\";\n    case kObjectHandle:\n      return \"ObjectHandle\";\n    case kArrayHandle:\n      return \"ArrayHandle\";\n    case kDGLDataType:\n      return \"DGLDataType\";\n    case kDGLContext:\n      return \"DGLContext\";\n    case kFuncHandle:\n      return \"FunctionHandle\";\n    case kModuleHandle:\n      return \"ModuleHandle\";\n    case kNDArrayContainer:\n      return \"NDArrayContainer\";\n    default:\n      LOG(FATAL) << \"unknown type_code=\" << static_cast<int>(type_code);\n      return \"\";\n  }\n}\n\n/**\n * @brief Convert device type code to its name\n * @param device_type The device type code.\n * @return The name of the device.\n */\ninline const char* DeviceTypeCode2Str(DGLDeviceType device_type) {\n  switch (device_type) {\n    case kDGLCPU:\n      return \"cpu\";\n    case kDGLCUDA:\n      return \"cuda\";\n    default:\n      LOG(FATAL) << \"Unsupported 
device type code=\"\n                 << static_cast<int>(device_type);\n      return \"\";\n  }\n}\n\n/**\n * @brief convert a string to DGL type.\n * @param s The string to be converted.\n * @return The corresponding dgl type.\n */\ninline DGLDataType String2DGLDataType(std::string s) {\n  DGLDataType t;\n  t.bits = 32;\n  t.lanes = 1;\n  const char* scan;\n  if (s.substr(0, 3) == \"int\") {\n    t.code = kDGLInt;\n    scan = s.c_str() + 3;\n  } else if (s.substr(0, 4) == \"uint\") {\n    t.code = kDGLUInt;\n    scan = s.c_str() + 4;\n  } else if (s.substr(0, 5) == \"float\") {\n    t.code = kDGLFloat;\n    scan = s.c_str() + 5;\n  } else if (s.substr(0, 6) == \"handle\") {\n    t.code = kHandle;\n    t.bits = 64;  // handle uses 64 bit by default.\n    scan = s.c_str() + 6;\n  } else {\n    scan = s.c_str();\n    LOG(FATAL) << \"unknown type \" << s;\n  }\n  char* xdelim;  // emulate sscanf(\"%ux%u\", bits, lanes)\n  uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));\n  if (bits != 0) t.bits = bits;\n  if (*xdelim == 'x') {\n    t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, nullptr, 10));\n  }\n  return t;\n}\n\n/**\n * @brief convert a DGL type to string.\n * @param t The type to be converted.\n * @return The corresponding dgl type in string.\n */\ninline std::string DGLDataType2String(DGLDataType t) {\n#ifndef _LIBCPP_SGX_NO_IOSTREAMS\n  std::ostringstream os;\n  os << t;\n  return os.str();\n#else\n  std::string repr = \"\";\n  repr += TypeCode2Str(t.code);\n  if (t.code == kHandle) return repr;\n  repr += std::to_string(static_cast<int>(t.bits));\n  if (t.lanes != 1) {\n    repr += \"x\" + std::to_string(static_cast<int>(t.lanes));\n  }\n  return repr;\n#endif\n}\n\n// macro to check type code.\n#define DGL_CHECK_TYPE_CODE(CODE, T)                                  \\\n  CHECK_EQ(CODE, T) << \" expected \" << TypeCode2Str(T) << \" but get \" \\\n                    << TypeCode2Str(CODE)\n\n}  // namespace runtime\n}  // namespace 
dgl\n\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, dgl::runtime::NDArray, true);\n}  // namespace dmlc\n\n///////////////// Operator overloading for NDArray /////////////////\ndgl::runtime::NDArray operator+(\n    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator-(\n    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator*(\n    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator/(\n    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator%(\n    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator+(const dgl::runtime::NDArray& a1, int64_t rhs);\ndgl::runtime::NDArray operator-(const dgl::runtime::NDArray& a1, int64_t rhs);\ndgl::runtime::NDArray operator*(const dgl::runtime::NDArray& a1, int64_t rhs);\ndgl::runtime::NDArray operator/(const dgl::runtime::NDArray& a1, int64_t rhs);\ndgl::runtime::NDArray operator%(const dgl::runtime::NDArray& a1, int64_t rhs);\ndgl::runtime::NDArray operator+(int64_t lhs, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator-(int64_t lhs, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator*(int64_t lhs, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator/(int64_t lhs, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator%(int64_t lhs, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator-(const dgl::runtime::NDArray& array);\n\ndgl::runtime::NDArray operator>(\n    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator<(\n    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator>=(\n    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator<=(\n    const dgl::runtime::NDArray& a1, const 
dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator==(\n    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator!=(\n    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator>(const dgl::runtime::NDArray& a1, int64_t rhs);\ndgl::runtime::NDArray operator<(const dgl::runtime::NDArray& a1, int64_t rhs);\ndgl::runtime::NDArray operator>=(const dgl::runtime::NDArray& a1, int64_t rhs);\ndgl::runtime::NDArray operator<=(const dgl::runtime::NDArray& a1, int64_t rhs);\ndgl::runtime::NDArray operator==(const dgl::runtime::NDArray& a1, int64_t rhs);\ndgl::runtime::NDArray operator!=(const dgl::runtime::NDArray& a1, int64_t rhs);\ndgl::runtime::NDArray operator>(int64_t lhs, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator<(int64_t lhs, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator>=(int64_t lhs, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator<=(int64_t lhs, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator==(int64_t lhs, const dgl::runtime::NDArray& a2);\ndgl::runtime::NDArray operator!=(int64_t lhs, const dgl::runtime::NDArray& a2);\n\nstd::ostream& operator<<(std::ostream& os, dgl::runtime::NDArray array);\n\n///////////////// Operator overloading for DGLDataType /////////////////\n\n/** @brief Check whether two data types are the same.*/\ninline bool operator==(const DGLDataType& ty1, const DGLDataType& ty2) {\n  return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;\n}\n\n/** @brief Check whether two data types are different.*/\ninline bool operator!=(const DGLDataType& ty1, const DGLDataType& ty2) {\n  return !(ty1 == ty2);\n}\n\n#ifndef _LIBCPP_SGX_NO_IOSTREAMS\ninline std::ostream& operator<<(std::ostream& os, DGLDataType t) {\n  os << dgl::runtime::TypeCode2Str(t.code);\n  if (t.code == kHandle) return os;\n  os << static_cast<int>(t.bits);\n  if (t.lanes != 1) {\n    os << 
'x' << static_cast<int>(t.lanes);\n  }\n  return os;\n}\n#endif\n\n///////////////// Operator overloading for DGLContext /////////////////\n\n/** @brief Check whether two device contexts are the same.*/\ninline bool operator==(const DGLContext& ctx1, const DGLContext& ctx2) {\n  return ctx1.device_type == ctx2.device_type &&\n         ctx1.device_id == ctx2.device_id;\n}\n\n/** @brief Check whether two device contexts are different.*/\ninline bool operator!=(const DGLContext& ctx1, const DGLContext& ctx2) {\n  return !(ctx1 == ctx2);\n}\n\n#ifndef _LIBCPP_SGX_NO_IOSTREAMS\ninline std::ostream& operator<<(std::ostream& os, const DGLContext& ctx) {\n  return os << dgl::runtime::DeviceTypeCode2Str(ctx.device_type) << \":\"\n            << ctx.device_id;\n}\n#endif\n\n#endif  // DGL_RUNTIME_NDARRAY_H_\n"
  },
  {
    "path": "include/dgl/runtime/object.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file runtime/object.h\n * @brief Defines the Object data structures.\n */\n#ifndef DGL_RUNTIME_OBJECT_H_\n#define DGL_RUNTIME_OBJECT_H_\n\n#include <dmlc/logging.h>\n\n#include <memory>\n#include <string>\n#include <type_traits>\n#include <vector>\n\nnamespace dgl {\nnamespace runtime {\n\n// forward declaration\nclass Object;\nclass ObjectRef;\nclass NDArray;\n\n/**\n * @brief Visitor class to each object attribute.\n *  The content is going to be called for each field.\n */\nclass AttrVisitor {\n public:\n  //! \\cond Doxygen_Suppress\n  virtual void Visit(const char* key, double* value) = 0;\n  virtual void Visit(const char* key, int64_t* value) = 0;\n  virtual void Visit(const char* key, uint64_t* value) = 0;\n  virtual void Visit(const char* key, int* value) = 0;\n  virtual void Visit(const char* key, bool* value) = 0;\n  virtual void Visit(const char* key, std::string* value) = 0;\n  virtual void Visit(const char* key, ObjectRef* value) = 0;\n  virtual void Visit(const char* key, NDArray* value) = 0;\n  template <\n      typename ENum,\n      typename = typename std::enable_if<std::is_enum<ENum>::value>::type>\n  void Visit(const char* key, ENum* ptr) {\n    static_assert(\n        std::is_same<int, typename std::underlying_type<ENum>::type>::value,\n        \"declare enum to be enum int to use visitor\");\n    this->Visit(key, reinterpret_cast<int*>(ptr));\n  }\n  //! 
\\endcond\n};\n\n/**\n * @brief base class of object container.\n *  All object's internal is stored as std::shared_ptr<Object>\n */\nclass Object {\n public:\n  /** @brief virtual destructor */\n  virtual ~Object() {}\n  /** @return The unique type key of the object */\n  virtual const char* type_key() const = 0;\n  /**\n   * @brief Apply visitor to each field of the Object\n   *  Visitor could mutate the content of the object.\n   *  override if Object contains attribute fields.\n   * @param visitor The visitor\n   */\n  virtual void VisitAttrs(AttrVisitor* visitor) {}\n  /** @return the type index of the object */\n  virtual uint32_t type_index() const = 0;\n  /**\n   * @brief Whether this object derives from object with type_index=tid.\n   *  Implemented by DGL_DECLARE_OBJECT_TYPE_INFO\n   *\n   * @param tid The type index.\n   * @return the check result.\n   */\n  virtual bool _DerivedFrom(uint32_t tid) const;\n  /**\n   * @brief get a runtime unique type index given a type key\n   * @param type_key Type key of a type.\n   * @return the corresponding type index.\n   */\n  static uint32_t TypeKey2Index(const char* type_key);\n  /**\n   * @brief get type key from type index.\n   * @param index The type index\n   * @return the corresponding type key.\n   */\n  static const char* TypeIndex2Key(uint32_t index);\n  /**\n   * @return whether the type is derived from\n   */\n  template <typename T>\n  inline bool derived_from() const;\n  /**\n   * @return whether the object is of type T\n   * @tparam The type to be checked.\n   */\n  template <typename T>\n  inline bool is_type() const;\n  // object ref can see this\n  friend class ObjectRef;\n  static constexpr const char* _type_key = \"Object\";\n};\n\n/** @brief base class of all reference object */\nclass ObjectRef {\n public:\n  /** @brief type indicate the container type */\n  using ContainerType = Object;\n  /**\n   * @brief Comparator\n   *\n   * Compare with the two are referencing to the same object (compare 
by\n   * address).\n   *\n   * @param other Another object ref.\n   * @return the compare result.\n   * @sa same_as\n   */\n  inline bool operator==(const ObjectRef& other) const;\n  /**\n   * @brief Comparator\n   *\n   * Compare with the two are referencing to the same object (compare by\n   * address).\n   *\n   * @param other Another object ref.\n   * @return the compare result.\n   */\n  inline bool same_as(const ObjectRef& other) const;\n  /**\n   * @brief Comparator\n   *\n   * The operator overload allows ObjectRef be used in std::map.\n   *\n   * @param other Another object ref.\n   * @return the compare result.\n   */\n  inline bool operator<(const ObjectRef& other) const;\n  /**\n   * @brief Comparator\n   * @param other Another object ref.\n   * @return the compare result.\n   * @sa same_as\n   */\n  inline bool operator!=(const ObjectRef& other) const;\n  /** @return the hash function for ObjectRef */\n  inline size_t hash() const;\n  /** @return whether the expression is null */\n  inline bool defined() const;\n  /** @return the internal type index of Object */\n  inline uint32_t type_index() const;\n  /** @return the internal object pointer */\n  inline const Object* get() const;\n  /** @return the internal object pointer */\n  inline const Object* operator->() const;\n  /**\n   * @brief Downcast this object to its actual type.\n   * This returns nullptr if the object is not of the requested type.\n   * Example usage:\n   *\n   * if (const Banana *banana = obj->as<Banana>()) {\n   *   // This is a Banana!\n   * }\n   * @tparam T the target type, must be subtype of Object\n   */\n  template <typename T>\n  inline const T* as() const;\n\n  /** @brief default constructor */\n  ObjectRef() = default;\n  explicit ObjectRef(std::shared_ptr<Object> obj) : obj_(obj) {}\n\n  /** @brief the internal object, do not touch */\n  std::shared_ptr<Object> obj_;\n};\n\n/**\n * @brief helper macro to declare type information in a base object.\n *\n * This is macro 
should be used in abstract base class definition\n * because it does not define type_key and type_index.\n */\n#define DGL_DECLARE_BASE_OBJECT_INFO(TypeName, Parent)         \\\n  const bool _DerivedFrom(uint32_t tid) const override {       \\\n    static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \\\n    if (tidx == tid) return true;                              \\\n    return Parent::_DerivedFrom(tid);                          \\\n  }\n\n/**\n * @brief helper macro to declare type information in a terminal class\n *\n * This macro should be used in terminal class definition.\n *\n * For example:\n *\n * // This class is an abstract class and cannot create instances\n * class SomeBaseClass : public Object {\n *  public:\n *   static constexpr const char* _type_key = \"some_base\";\n *   DGL_DECLARE_BASE_OBJECT_INFO(SomeBaseClass, Object);\n * };\n *\n * // Child class that allows instantiation\n * class SomeChildClass : public SomeBaseClass {\n *  public:\n *   static constexpr const char* _type_key = \"some_child\";\n *   DGL_DECLARE_OBJECT_TYPE_INFO(SomeChildClass, SomeBaseClass);\n * };\n */\n#define DGL_DECLARE_OBJECT_TYPE_INFO(TypeName, Parent)               \\\n  const char* type_key() const final { return TypeName::_type_key; } \\\n  uint32_t type_index() const final {                                \\\n    static uint32_t tidx = TypeKey2Index(TypeName::_type_key);       \\\n    return tidx;                                                     \\\n  }                                                                  \\\n  bool _DerivedFrom(uint32_t tid) const final {                      \\\n    static uint32_t tidx = TypeKey2Index(TypeName::_type_key);       \\\n    if (tidx == tid) return true;                                    \\\n    return Parent::_DerivedFrom(tid);                                \\\n  }\n\n/** @brief Macro to generate common object reference class method definition */\n#define DGL_DEFINE_OBJECT_REF_METHODS(TypeName, 
BaseTypeName, ObjectName)   \\\n  TypeName() {}                                                             \\\n  explicit TypeName(std::shared_ptr<runtime::Object> obj)                   \\\n      : BaseTypeName(obj) {}                                                \\\n  const ObjectName* operator->() const {                                    \\\n    return static_cast<const ObjectName*>(obj_.get());                      \\\n  }                                                                         \\\n  ObjectName* operator->() { return static_cast<ObjectName*>(obj_.get()); } \\\n  std::shared_ptr<ObjectName> sptr() const {                                \\\n    return CHECK_NOTNULL(std::dynamic_pointer_cast<ObjectName>(obj_));      \\\n  }                                                                         \\\n  operator bool() const { return this->defined(); }                         \\\n  using ContainerType = ObjectName\n\n/** @brief Macro to generate object reference class definition */\n#define DGL_DEFINE_OBJECT_REF(TypeName, ObjectName)       \\\n  class TypeName : public ::dgl::runtime::ObjectRef {     \\\n   public:                                                \\\n    DGL_DEFINE_OBJECT_REF_METHODS(                        \\\n        TypeName, ::dgl::runtime::ObjectRef, ObjectName); \\\n  }\n\n// implementations of inline functions after this\ntemplate <typename T>\ninline bool Object::is_type() const {\n  // use static field so query only happens once.\n  static uint32_t type_id = Object::TypeKey2Index(T::_type_key);\n  return type_id == this->type_index();\n}\n\ntemplate <typename T>\ninline bool Object::derived_from() const {\n  // use static field so query only happens once.\n  static uint32_t type_id = Object::TypeKey2Index(T::_type_key);\n  return this->_DerivedFrom(type_id);\n}\n\ninline const Object* ObjectRef::get() const { return obj_.get(); }\n\ninline const Object* ObjectRef::operator->() const { return obj_.get(); }\n\ninline bool 
ObjectRef::defined() const { return obj_.get() != nullptr; }\n\ninline bool ObjectRef::operator==(const ObjectRef& other) const {\n  return obj_.get() == other.obj_.get();\n}\n\ninline bool ObjectRef::same_as(const ObjectRef& other) const {\n  return obj_.get() == other.obj_.get();\n}\n\ninline bool ObjectRef::operator<(const ObjectRef& other) const {\n  return obj_.get() < other.obj_.get();\n}\n\ninline bool ObjectRef::operator!=(const ObjectRef& other) const {\n  return obj_.get() != other.obj_.get();\n}\n\ninline size_t ObjectRef::hash() const {\n  return std::hash<Object*>()(obj_.get());\n}\n\ninline uint32_t ObjectRef::type_index() const {\n  CHECK(obj_.get() != nullptr) << \"null type\";\n  return get()->type_index();\n}\n\ntemplate <typename T>\ninline const T* ObjectRef::as() const {\n  const Object* ptr = get();\n  if (ptr && ptr->is_type<T>()) {\n    return static_cast<const T*>(ptr);\n  }\n  return nullptr;\n}\n\n/** @brief The hash function for nodes */\nstruct ObjectHash {\n  size_t operator()(const ObjectRef& a) const { return a.hash(); }\n};\n\n/** @brief The equal comparator for nodes */\nstruct ObjectEqual {\n  bool operator()(const ObjectRef& a, const ObjectRef& b) const {\n    return a.get() == b.get();\n  }\n};\n\n}  // namespace runtime\n}  // namespace dgl\n\nnamespace std {\ntemplate <>\nstruct hash<::dgl::runtime::ObjectRef> {\n  std::size_t operator()(const ::dgl::runtime::ObjectRef& k) const {\n    return k.hash();\n  }\n};\n\n}  // namespace std\n\n#endif  // DGL_RUNTIME_OBJECT_H_\n"
  },
  {
    "path": "include/dgl/runtime/packed_func.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file dgl/runtime/packed_func.h\n * @brief Type-erased function used across DGL API.\n */\n#ifndef DGL_RUNTIME_PACKED_FUNC_H_\n#define DGL_RUNTIME_PACKED_FUNC_H_\n\n#include <dmlc/logging.h>\n\n#include <functional>\n#include <limits>\n#include <memory>\n#include <string>\n#include <tuple>\n#include <type_traits>\n#include <utility>\n#include <vector>\n\n#include \"c_runtime_api.h\"\n#include \"module.h\"\n#include \"ndarray.h\"\n\n// Whether use DGL runtime in header only mode.\n#ifndef DGL_RUNTIME_HEADER_ONLY\n#define DGL_RUNTIME_HEADER_ONLY 0\n#endif\n\nnamespace dgl {\nnamespace runtime {\n\n// Forward declare ObjectRef and Object for extensions.\n// This header works fine without depend on ObjectRef\n// as long as it is not used.\nclass Object;\nclass ObjectRef;\n\n// forward declarations\nclass DGLArgs;\nclass DGLArgValue;\nclass DGLRetValue;\nclass DGLArgsSetter;\n\n/**\n * @brief Packed function is a type-erased function.\n *  The arguments are passed by packed format.\n *\n *  This is an useful unified interface to call generated functions,\n *  It is the unified function function type of DGL.\n *  It corresponds to DGLFunctionHandle in C runtime API.\n */\nclass PackedFunc {\n public:\n  /**\n   * @brief The internal std::function\n   * @param args The arguments to the function.\n   * @param rv The return value.\n   *\n   * @code\n   *   // Example code on how to implemented FType\n   *   void MyPackedFunc(DGLArgs args, DGLRetValue* rv) {\n   *     // automatically convert arguments to desired type.\n   *     int a0 = args[0];\n   *     float a1 = args[1];\n   *     ...\n   *     // automatically assign values to rv\n   *     std::string my_return_value = \"x\";\n   *     *rv = my_return_value;\n   *   }\n   * @endcode\n   */\n  using FType = std::function<void(DGLArgs args, DGLRetValue* rv)>;\n  /** @brief default constructor */\n  PackedFunc() {}\n  /**\n   * @brief constructing a packed 
function from a std::function.\n   * @param body the internal container of packed function.\n   */\n  explicit PackedFunc(FType body) : body_(body) {}\n  /**\n   * @brief Call packed function by directly passing in unpacked format.\n   * @param args Arguments to be passed.\n   * @tparam Args arguments to be passed.\n   *\n   * @code\n   *   // Example code on how to call packed function\n   *   void CallPacked(PackedFunc f) {\n   *     // call like normal functions by pass in arguments\n   *     // return value is automatically converted back\n   *     int rvalue = f(1, 2.0);\n   *   }\n   * @endcode\n   */\n  template <typename... Args>\n  inline DGLRetValue operator()(Args&&... args) const;\n  /**\n   * @brief Call the function in packed format.\n   * @param args The arguments\n   * @param rv The return value.\n   */\n  inline void CallPacked(DGLArgs args, DGLRetValue* rv) const;\n  /** @return the internal body function */\n  inline FType body() const;\n  /** @return Whether the packed function is nullptr */\n  bool operator==(std::nullptr_t null) const { return body_ == nullptr; }\n  /** @return Whether the packed function is not nullptr */\n  bool operator!=(std::nullptr_t null) const { return body_ != nullptr; }\n\n private:\n  /** @brief internal container of packed function */\n  FType body_;\n};\n\n/**\n * @brief Please refer to \\ref TypedPackedFuncAnchor\n * \"TypedPackedFunc<R(Args..)>\"\n */\ntemplate <typename FType>\nclass TypedPackedFunc;\n\n/**\n * @anchor TypedPackedFuncAnchor\n * @brief A PackedFunc wrapper to provide typed function signature.\n * It is backed by a PackedFunc internally.\n *\n * TypedPackedFunc enables compile time type checking.\n * TypedPackedFunc works with the runtime system:\n * - It can be passed as an argument of PackedFunc.\n * - It can be assigned to DGLRetValue.\n * - It can be directly converted to a type-erased PackedFunc.\n *\n * Developers should prefer TypedPackedFunc over PackedFunc in C++ code\n * as it enables 
compile time checking.\n * We can construct a TypedPackedFunc from a lambda function\n * with the same signature.\n *\n * @code\n *  // user defined lambda function.\n *  auto addone = [](int x)->int {\n *    return x + 1;\n *  };\n *  // We can directly convert\n *  // lambda function to TypedPackedFunc\n *  TypedPackedFunc<int(int)> ftyped(addone);\n *  // invoke the function.\n *  int y = ftyped(1);\n *  // Can be directly converted to PackedFunc\n *  PackedFunc packed = ftype;\n * @endcode\n * @tparam R The return value of the function.\n * @tparam Args The argument signature of the function.\n */\ntemplate <typename R, typename... Args>\nclass TypedPackedFunc<R(Args...)> {\n public:\n  /** @brief short hand for this function type */\n  using TSelf = TypedPackedFunc<R(Args...)>;\n  /** @brief default constructor */\n  TypedPackedFunc() {}\n  /**\n   * @brief construct by wrap a PackedFunc\n   *\n   * Example usage:\n   * @code\n   * PackedFunc packed([](DGLArgs args, DGLRetValue *rv) {\n   *   int x = args[0];\n   *   *rv = x + 1;\n   *  });\n   * // construct from packed function\n   * TypedPackedFunc<int(int)> ftyped(packed);\n   * // call the typed version.\n   * CHECK_EQ(ftyped(1), 2);\n   * @endcode\n   *\n   * @param packed The packed function\n   */\n  inline explicit TypedPackedFunc(PackedFunc packed);\n  /**\n   * @brief construct from a lambda function with the same signature.\n   *\n   * Example usage:\n   * @code\n   * auto typed_lambda = [](int x)->int { return x + 1; }\n   * // construct from packed function\n   * TypedPackedFunc<int(int)> ftyped(typed_lambda);\n   * // call the typed version.\n   * CHECK_EQ(ftyped(1), 2);\n   * @endcode\n   *\n   * @param typed_lambda typed lambda function.\n   * @tparam FLambda the type of the lambda function.\n   */\n  template <\n      typename FLambda, typename = typename std::enable_if<std::is_convertible<\n                            FLambda, std::function<R(Args...)> >::value>::type>\n  explicit 
TypedPackedFunc(const FLambda& typed_lambda) {\n    this->AssignTypedLambda(typed_lambda);\n  }\n  /**\n   * @brief copy assignment operator from typed lambda\n   *\n   * Example usage:\n   * @code\n   * // construct from packed function\n   * TypedPackedFunc<int(int)> ftyped;\n   * ftyped = [](int x) { return x + 1; }\n   * // call the typed version.\n   * CHECK_EQ(ftyped(1), 2);\n   * @endcode\n   *\n   * @param typed_lambda typed lambda function.\n   * @tparam FLambda the type of the lambda function.\n   * @returns reference to self.\n   */\n  template <\n      typename FLambda, typename = typename std::enable_if<std::is_convertible<\n                            FLambda,\n                            std::function<R(Args...)> >::value>::type>\n  TSelf& operator=(FLambda typed_lambda) {  // NOLINT(*)\n    this->AssignTypedLambda(typed_lambda);\n    return *this;\n  }\n  /**\n   * @brief copy assignment operator from PackedFunc.\n   * @param packed The packed function.\n   * @returns reference to self.\n   */\n  TSelf& operator=(PackedFunc packed) {\n    packed_ = packed;\n    return *this;\n  }\n  /**\n   * @brief Invoke the operator.\n   * @param args The arguments\n   * @returns The return value.\n   */\n  inline R operator()(Args... 
args) const;\n  /**\n   * @brief convert to PackedFunc\n   * @return the internal PackedFunc\n   */\n  operator PackedFunc() const { return packed(); }\n  /**\n   * @return reference the internal PackedFunc\n   */\n  const PackedFunc& packed() const { return packed_; }\n\n private:\n  friend class DGLRetValue;\n  /** @brief The internal packed function */\n  PackedFunc packed_;\n  /**\n   * @brief Assign the packed field using a typed lambda function.\n   *\n   * @param flambda The lambda function.\n   * @tparam FLambda The lambda function type.\n   * @note We capture the lambda when possible for maximum efficiency.\n   */\n  template <typename FLambda>\n  inline void AssignTypedLambda(FLambda flambda);\n};\n\n/** @brief Arguments into DGL functions. */\nclass DGLArgs {\n public:\n  const DGLValue* values;\n  const int* type_codes;\n  int num_args;\n  /**\n   * @brief constructor\n   * @param values The argument values\n   * @param type_codes The argument type codes\n   * @param num_args number of arguments.\n   */\n  DGLArgs(const DGLValue* values, const int* type_codes, int num_args)\n      : values(values), type_codes(type_codes), num_args(num_args) {}\n  /** @return size of the arguments */\n  inline int size() const;\n  /**\n   * @brief Get i-th argument\n   * @param i the index.\n   * @return the ith argument.\n   */\n  inline DGLArgValue operator[](int i) const;\n};\n\n/**\n * @brief Type traits to mark if a class is dgl extension type.\n *\n * To enable extension type in C++ must be register () ed via marco.\n * DGL_REGISTER_EXT_TYPE(TypeName) after defining this with this traits.\n *\n * Extension class can be passed and returned via PackedFunc in all dgl runtime.\n * Internally extension class is stored as T*.\n *\n * @tparam T the typename\n */\ntemplate <typename T>\nstruct extension_class_info {\n  static const int code = 0;\n};\n\n/**\n * @brief Runtime function table about extension type.\n */\nclass ExtTypeVTable {\n public:\n  /** @brief function 
to be called to delete a handle */\n  void (*destroy)(void* handle);\n  /** @brief function to be called when clone a handle */\n  void* (*clone)(void* handle);\n  /**\n   * @brief Register type\n   * @tparam T The type to be register.\n   * @return The registered vtable.\n   */\n  template <typename T>\n  static inline ExtTypeVTable* Register_();\n  /**\n   * @brief Get a vtable based on type code.\n   * @param type_code The type code\n   * @return The registered vtable.\n   */\n  DGL_DLL static ExtTypeVTable* Get(int type_code);\n\n private:\n  // Internal registration function.\n  DGL_DLL static ExtTypeVTable* RegisterInternal(\n      int type_code, const ExtTypeVTable& vt);\n};\n\n/**\n * @brief Internal base class to\n *  handle conversion to POD values.\n */\nclass DGLPODValue_ {\n public:\n  operator double() const {\n    // Allow automatic conversion from int to float\n    // This avoids errors when user pass in int from\n    // the frontend while the API expects a float.\n    if (type_code_ == kDGLInt) {\n      return static_cast<double>(value_.v_int64);\n    }\n    DGL_CHECK_TYPE_CODE(type_code_, kDGLFloat);\n    return value_.v_float64;\n  }\n  operator int64_t() const {\n    DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);\n    return value_.v_int64;\n  }\n  operator uint64_t() const {\n    DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);\n    return value_.v_int64;\n  }\n  operator int() const {\n    DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);\n    CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());\n    return static_cast<int>(value_.v_int64);\n  }\n  operator bool() const {\n    DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);\n    return value_.v_int64 != 0;\n  }\n  operator void*() const {\n    if (type_code_ == kNull) return nullptr;\n    if (type_code_ == kArrayHandle) return value_.v_handle;\n    DGL_CHECK_TYPE_CODE(type_code_, kHandle);\n    return value_.v_handle;\n  }\n  operator DGLArray*() const {\n    if (type_code_ == kArrayHandle || type_code_ == 
kNDArrayContainer) {\n      return static_cast<DGLArray*>(value_.v_handle);\n    } else {\n      if (type_code_ == kNull) return nullptr;\n      LOG(FATAL) << \"Expected \"\n                 << \"DGLArray* or NDArray but get \" << TypeCode2Str(type_code_);\n      return nullptr;\n    }\n  }\n  operator NDArray() const {\n    if (type_code_ == kNull) return NDArray();\n    DGL_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);\n    return NDArray(static_cast<NDArray::Container*>(value_.v_handle));\n  }\n  operator DGLContext() const {\n    DGL_CHECK_TYPE_CODE(type_code_, kDGLContext);\n    return value_.v_ctx;\n  }\n  template <typename TExtension>\n  const TExtension& AsExtension() const {\n    CHECK_LT(type_code_, kExtEnd);\n    return static_cast<TExtension*>(value_.v_handle)[0];\n  }\n  int type_code() const { return type_code_; }\n  /**\n   * @brief return handle as specific pointer type.\n   * @tparam T the data type.\n   * @return The pointer type.\n   */\n  template <typename T>\n  T* ptr() const {\n    return static_cast<T*>(value_.v_handle);\n  }\n\n protected:\n  friend class DGLArgsSetter;\n  friend class DGLRetValue;\n  DGLPODValue_() : type_code_(kNull) {}\n  DGLPODValue_(DGLValue value, int type_code)\n      : value_(value), type_code_(type_code) {}\n\n  /** @brief The value */\n  DGLValue value_;\n  /** @brief the type code */\n  int type_code_;\n};\n\n/**\n * @brief A single argument value to PackedFunc.\n *  Containing both type_code and DGLValue\n *\n *  Provides utilities to do type cast into other types.\n */\nclass DGLArgValue : public DGLPODValue_ {\n public:\n  /** @brief default constructor */\n  DGLArgValue() {}\n  /**\n   * @brief constructor\n   * @param value of the function\n   * @param type_code The type code.\n   */\n  DGLArgValue(DGLValue value, int type_code) : DGLPODValue_(value, type_code) {}\n  // reuse converter from parent\n  using DGLPODValue_::operator double;\n  using DGLPODValue_::operator int64_t;\n  using 
DGLPODValue_::operator uint64_t;\n  using DGLPODValue_::operator int;\n  using DGLPODValue_::operator bool;\n  using DGLPODValue_::operator void*;\n  using DGLPODValue_::operator DGLArray*;\n  using DGLPODValue_::operator NDArray;\n  using DGLPODValue_::operator DGLContext;\n\n  // conversion operator.\n  operator std::string() const {\n    if (type_code_ == kDGLDataType) {\n      return DGLDataType2String(operator DGLDataType());\n    } else if (type_code_ == kBytes) {\n      DGLByteArray* arr = static_cast<DGLByteArray*>(value_.v_handle);\n      return std::string(arr->data, arr->size);\n    } else {\n      DGL_CHECK_TYPE_CODE(type_code_, kStr);\n      return std::string(value_.v_str);\n    }\n  }\n  operator DGLDataType() const {\n    if (type_code_ == kStr) {\n      return String2DGLDataType(operator std::string());\n    }\n    DGL_CHECK_TYPE_CODE(type_code_, kDGLDataType);\n    return value_.v_type;\n  }\n  operator PackedFunc() const {\n    if (type_code_ == kNull) return PackedFunc();\n    DGL_CHECK_TYPE_CODE(type_code_, kFuncHandle);\n    return *ptr<PackedFunc>();\n  }\n  template <typename FType>\n  operator TypedPackedFunc<FType>() const {\n    return TypedPackedFunc<FType>(operator PackedFunc());\n  }\n  operator Module() const {\n    DGL_CHECK_TYPE_CODE(type_code_, kModuleHandle);\n    return *ptr<Module>();\n  }\n  const DGLValue& value() const { return value_; }\n\n  // Deferred extension handler.\n  template <typename TObjectRef>\n  inline TObjectRef AsObjectRef() const;\n\n  // Convert this value to arbitrary class type\n  template <\n      typename T,\n      typename = typename std::enable_if<std::is_class<T>::value>::type>\n  inline operator T() const;\n\n  // Return true if the value is of TObjectRef type\n  template <\n      typename TObjectRef, typename = typename std::enable_if<\n                               std::is_class<TObjectRef>::value>::type>\n  inline bool IsObjectType() const;\n\n  // get internal node ptr, if it is node\n  inline 
std::shared_ptr<Object>& obj_sptr();\n};\n\n/**\n * @brief Return Value container,\n *  Unlike DGLArgValue, which only holds a reference and does not delete\n *  the underlying container during destruction.\n *\n *  DGLRetValue holds value and will manage the underlying containers\n *  when it stores a complicated data type.\n */\nclass DGLRetValue : public DGLPODValue_ {\n public:\n  /** @brief default constructor */\n  DGLRetValue() {}\n  /**\n   * @brief move constructor from another return value.\n   * @param other The other return value.\n   */\n  DGLRetValue(DGLRetValue&& other)\n      : DGLPODValue_(other.value_, other.type_code_) {\n    other.value_.v_handle = nullptr;\n    other.type_code_ = kNull;\n  }\n  /** @brief destructor */\n  ~DGLRetValue() { this->Clear(); }\n  // reuse converter from parent\n  using DGLPODValue_::operator double;\n  using DGLPODValue_::operator int64_t;\n  using DGLPODValue_::operator uint64_t;\n  using DGLPODValue_::operator int;\n  using DGLPODValue_::operator bool;\n  using DGLPODValue_::operator void*;\n  using DGLPODValue_::operator DGLArray*;\n  using DGLPODValue_::operator DGLContext;\n  using DGLPODValue_::operator NDArray;\n  // Copy constructor: copies the value held by another return value.\n  DGLRetValue(const DGLRetValue& other) { this->Assign(other); }\n  // conversion operators\n  operator std::string() const {\n    if (type_code_ == kDGLDataType) {\n      return DGLDataType2String(operator DGLDataType());\n    } else if (type_code_ == kBytes) {\n      return *ptr<std::string>();\n    }\n    DGL_CHECK_TYPE_CODE(type_code_, kStr);\n    return *ptr<std::string>();\n  }\n  operator DGLDataType() const {\n    if (type_code_ == kStr) {\n      return String2DGLDataType(operator std::string());\n    }\n    DGL_CHECK_TYPE_CODE(type_code_, kDGLDataType);\n    return value_.v_type;\n  }\n  operator PackedFunc() const {\n    if (type_code_ == kNull) return PackedFunc();\n    DGL_CHECK_TYPE_CODE(type_code_, kFuncHandle);\n    return 
*ptr<PackedFunc>();\n  }\n  template <typename FType>\n  operator TypedPackedFunc<FType>() const {\n    return TypedPackedFunc<FType>(operator PackedFunc());\n  }\n  operator Module() const {\n    DGL_CHECK_TYPE_CODE(type_code_, kModuleHandle);\n    return *ptr<Module>();\n  }\n  // Assign operators\n  DGLRetValue& operator=(DGLRetValue&& other) {\n    this->Clear();\n    value_ = other.value_;\n    type_code_ = other.type_code_;\n    other.type_code_ = kNull;\n    return *this;\n  }\n  DGLRetValue& operator=(double value) {\n    this->SwitchToPOD(kDGLFloat);\n    value_.v_float64 = value;\n    return *this;\n  }\n  DGLRetValue& operator=(std::nullptr_t value) {\n    this->SwitchToPOD(kNull);\n    value_.v_handle = value;\n    return *this;\n  }\n  DGLRetValue& operator=(void* value) {\n    this->SwitchToPOD(kHandle);\n    value_.v_handle = value;\n    return *this;\n  }\n  DGLRetValue& operator=(int64_t value) {\n    this->SwitchToPOD(kDGLInt);\n    value_.v_int64 = value;\n    return *this;\n  }\n  DGLRetValue& operator=(int value) {\n    this->SwitchToPOD(kDGLInt);\n    value_.v_int64 = value;\n    return *this;\n  }\n  DGLRetValue& operator=(DGLDataType t) {\n    this->SwitchToPOD(kDGLDataType);\n    value_.v_type = t;\n    return *this;\n  }\n  DGLRetValue& operator=(DGLContext ctx) {\n    this->SwitchToPOD(kDGLContext);\n    value_.v_ctx = ctx;\n    return *this;\n  }\n  DGLRetValue& operator=(bool value) {\n    this->SwitchToPOD(kDGLInt);\n    value_.v_int64 = value;\n    return *this;\n  }\n  DGLRetValue& operator=(std::string value) {\n    this->SwitchToClass(kStr, value);\n    return *this;\n  }\n  DGLRetValue& operator=(DGLByteArray value) {\n    this->SwitchToClass(kBytes, std::string(value.data, value.size));\n    return *this;\n  }\n  DGLRetValue& operator=(NDArray other) {\n    this->Clear();\n    type_code_ = kNDArrayContainer;\n    value_.v_handle = other.data_;\n    other.data_ = nullptr;\n    return *this;\n  }\n  DGLRetValue& 
operator=(PackedFunc f) {\n    this->SwitchToClass(kFuncHandle, f);\n    return *this;\n  }\n  template <typename FType>\n  DGLRetValue& operator=(const TypedPackedFunc<FType>& f) {\n    return operator=(f.packed());\n  }\n  DGLRetValue& operator=(Module m) {\n    this->SwitchToClass(kModuleHandle, m);\n    return *this;\n  }\n  DGLRetValue& operator=(const DGLRetValue& other) {  // NOLINT(*)\n    this->Assign(other);\n    return *this;\n  }\n  DGLRetValue& operator=(const DGLArgValue& other) {\n    this->Assign(other);\n    return *this;\n  }\n  template <\n      typename T, typename = typename std::enable_if<\n                      extension_class_info<T>::code != 0>::type>\n  DGLRetValue& operator=(const T& other) {\n    this->SwitchToClass<T>(extension_class_info<T>::code, other);\n    return *this;\n  }\n  /**\n   * @brief Move the value back to front-end via C API.\n   *  This marks the current container as null.\n   *  The managed resources are moved to front-end and\n   *  the front end should take charge in managing them.\n   *\n   * @param ret_value The return value.\n   * @param ret_type_code The return type code.\n   */\n  void MoveToCHost(DGLValue* ret_value, int* ret_type_code) {\n    // cannot move str or bytes; they need special handling.\n    CHECK(type_code_ != kStr && type_code_ != kBytes);\n    *ret_value = value_;\n    *ret_type_code = type_code_;\n    type_code_ = kNull;\n  }\n  /** @return The value field, if the data is POD */\n  const DGLValue& value() const {\n    CHECK(\n        type_code_ != kObjectHandle && type_code_ != kFuncHandle &&\n        type_code_ != kModuleHandle && type_code_ != kStr)\n        << \"DGLRetValue.value can only be used for POD data\";\n    return value_;\n  }\n  // ObjectRef related extensions: in dgl/packed_func_ext.h\n  template <\n      typename T,\n      typename = typename std::enable_if<std::is_class<T>::value>::type>\n  inline operator T() const;\n  template <typename TObjectRef>\n  inline TObjectRef AsObjectRef() 
const;\n  inline DGLRetValue& operator=(const ObjectRef& other);\n  inline DGLRetValue& operator=(const std::shared_ptr<Object>& other);\n\n private:\n  template <typename T>\n  void Assign(const T& other) {\n    switch (other.type_code()) {\n      case kStr: {\n        SwitchToClass<std::string>(kStr, other);\n        break;\n      }\n      case kBytes: {\n        SwitchToClass<std::string>(kBytes, other);\n        break;\n      }\n      case kFuncHandle: {\n        SwitchToClass<PackedFunc>(kFuncHandle, other);\n        break;\n      }\n      case kModuleHandle: {\n        SwitchToClass<Module>(kModuleHandle, other);\n        break;\n      }\n      case kNDArrayContainer: {\n        *this = other.operator NDArray();\n        break;\n      }\n      case kObjectHandle: {\n        SwitchToClass<std::shared_ptr<Object> >(\n            kObjectHandle, *other.template ptr<std::shared_ptr<Object> >());\n        break;\n      }\n      default: {\n        if (other.type_code() < kExtBegin) {\n          SwitchToPOD(other.type_code());\n          value_ = other.value_;\n        } else {\n#if DGL_RUNTIME_HEADER_ONLY\n          LOG(FATAL) << \"Header only mode do not support ext type\";\n#else\n          this->Clear();\n          type_code_ = other.type_code();\n          value_.v_handle = (*(ExtTypeVTable::Get(other.type_code())->clone))(\n              other.value().v_handle);\n#endif\n        }\n        break;\n      }\n    }\n  }\n  // get the internal container.\n  void SwitchToPOD(int type_code) {\n    if (type_code_ != type_code) {\n      this->Clear();\n      type_code_ = type_code;\n    }\n  }\n  template <typename T>\n  void SwitchToClass(int type_code, T v) {\n    if (type_code_ != type_code) {\n      this->Clear();\n      type_code_ = type_code;\n      value_.v_handle = new T(v);\n    } else {\n      *static_cast<T*>(value_.v_handle) = v;\n    }\n  }\n  void Clear() {\n    if (type_code_ == kNull) return;\n    switch (type_code_) {\n      case kStr:\n      case 
kBytes:\n        delete ptr<std::string>();\n        break;\n      case kFuncHandle:\n        delete ptr<PackedFunc>();\n        break;\n      case kModuleHandle:\n        delete ptr<Module>();\n        break;\n      case kObjectHandle:\n        delete ptr<std::shared_ptr<Object> >();\n        break;\n      case kNDArrayContainer: {\n        static_cast<NDArray::Container*>(value_.v_handle)->DecRef();\n        break;\n      }\n    }\n    if (type_code_ > kExtBegin) {\n#if DGL_RUNTIME_HEADER_ONLY\n      LOG(FATAL) << \"Header only mode do not support ext type\";\n#else\n      (*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);\n#endif\n    }\n    type_code_ = kNull;\n  }\n};\n\n// implementation details\ninline DGLArgValue DGLArgs::operator[](int i) const {\n  CHECK_LT(i, num_args) << \"not enough argument passed, \" << num_args\n                        << \" passed\"\n                        << \" but request arg[\" << i << \"].\";\n  return DGLArgValue(values[i], type_codes[i]);\n}\n\ninline int DGLArgs::size() const { return num_args; }\n\ninline void PackedFunc::CallPacked(DGLArgs args, DGLRetValue* rv) const {\n  body_(args, rv);\n}\n\ninline PackedFunc::FType PackedFunc::body() const { return body_; }\n\n// internal namespace\nnamespace detail {\n\ntemplate <bool stop, std::size_t I, typename F>\nstruct for_each_dispatcher {\n  template <typename T, typename... Args>\n  static void run(const F& f, T&& value, Args&&... args) {  // NOLINT(*)\n    f(I, std::forward<T>(value));\n    for_each_dispatcher<sizeof...(Args) == 0, (I + 1), F>::run(\n        f, std::forward<Args>(args)...);\n  }\n};\n\ntemplate <std::size_t I, typename F>\nstruct for_each_dispatcher<true, I, F> {\n  static void run(const F& f) {}  // NOLINT(*)\n};\n\ntemplate <typename F, typename... Args>\ninline void for_each(const F& f, Args&&... 
args) {  // NOLINT(*)\n  for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(\n      f, std::forward<Args>(args)...);\n}\n}  // namespace detail\n\n/* @brief argument settter to PackedFunc */\nclass DGLArgsSetter {\n public:\n  DGLArgsSetter(DGLValue* values, int* type_codes)\n      : values_(values), type_codes_(type_codes) {}\n  // setters for POD types\n  template <\n      typename T,\n      typename = typename std::enable_if<std::is_integral<T>::value>::type>\n  void operator()(size_t i, T value) const {\n    values_[i].v_int64 = static_cast<int64_t>(value);\n    type_codes_[i] = kDGLInt;\n  }\n  void operator()(size_t i, uint64_t value) const {\n    values_[i].v_int64 = static_cast<int64_t>(value);\n    CHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));\n    type_codes_[i] = kDGLInt;\n  }\n  void operator()(size_t i, double value) const {\n    values_[i].v_float64 = value;\n    type_codes_[i] = kDGLFloat;\n  }\n  void operator()(size_t i, std::nullptr_t value) const {\n    values_[i].v_handle = value;\n    type_codes_[i] = kNull;\n  }\n  void operator()(size_t i, const DGLArgValue& value) const {\n    values_[i] = value.value_;\n    type_codes_[i] = value.type_code_;\n  }\n  void operator()(size_t i, void* value) const {\n    values_[i].v_handle = value;\n    type_codes_[i] = kHandle;\n  }\n  void operator()(size_t i, DGLArray* value) const {\n    values_[i].v_handle = value;\n    type_codes_[i] = kArrayHandle;\n  }\n  void operator()(size_t i, DGLContext value) const {\n    values_[i].v_ctx = value;\n    type_codes_[i] = kDGLContext;\n  }\n  void operator()(size_t i, DGLDataType value) const {\n    values_[i].v_type = value;\n    type_codes_[i] = kDGLDataType;\n  }\n  void operator()(size_t i, const char* value) const {\n    values_[i].v_str = value;\n    type_codes_[i] = kStr;\n  }\n  // setters for container type\n  // They must be reference(instead of const ref)\n  // to make sure they are alive in the tuple(instead of getting 
converted)\n  void operator()(size_t i, const std::string& value) const {  // NOLINT(*)\n    values_[i].v_str = value.c_str();\n    type_codes_[i] = kStr;\n  }\n  void operator()(size_t i, const DGLByteArray& value) const {  // NOLINT(*)\n    values_[i].v_handle = const_cast<DGLByteArray*>(&value);\n    type_codes_[i] = kBytes;\n  }\n  void operator()(size_t i, const PackedFunc& value) const {  // NOLINT(*)\n    values_[i].v_handle = const_cast<PackedFunc*>(&value);\n    type_codes_[i] = kFuncHandle;\n  }\n  template <typename FType>\n  void operator()(\n      size_t i, const TypedPackedFunc<FType>& value) const {  // NOLINT(*)\n    operator()(i, value.packed());\n  }\n  void operator()(size_t i, const Module& value) const {  // NOLINT(*)\n    values_[i].v_handle = const_cast<Module*>(&value);\n    type_codes_[i] = kModuleHandle;\n  }\n  void operator()(size_t i, const NDArray& value) const {  // NOLINT(*)\n    values_[i].v_handle = value.data_;\n    type_codes_[i] = kNDArrayContainer;\n  }\n  void operator()(size_t i, const DGLRetValue& value) const {  // NOLINT(*)\n    if (value.type_code() == kStr) {\n      values_[i].v_str = value.ptr<std::string>()->c_str();\n      type_codes_[i] = kStr;\n    } else {\n      CHECK_NE(value.type_code(), kBytes) << \"not handled.\";\n      values_[i] = value.value_;\n      type_codes_[i] = value.type_code();\n    }\n  }\n  // extension\n  template <\n      typename T, typename = typename std::enable_if<\n                      extension_class_info<T>::code != 0>::type>\n  inline void operator()(size_t i, const T& value) const;\n  // ObjectRef related extenstions: in dgl/packed_func_ext.h\n  inline void operator()(size_t i, const ObjectRef& other) const;  // NOLINT(*)\n\n private:\n  /** @brief The values fields */\n  DGLValue* values_;\n  /** @brief The type code fields */\n  int* type_codes_;\n};\n\ntemplate <typename... Args>\ninline DGLRetValue PackedFunc::operator()(Args&&... 
args) const {\n  const int kNumArgs = sizeof...(Args);\n  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;\n  DGLValue values[kArraySize];\n  int type_codes[kArraySize];\n  detail::for_each(\n      DGLArgsSetter(values, type_codes), std::forward<Args>(args)...);\n  DGLRetValue rv;\n  body_(DGLArgs(values, type_codes, kNumArgs), &rv);\n  return rv;\n}\n\nnamespace detail {\ntemplate <typename R, int nleft, int index, typename F>\nstruct unpack_call_dispatcher {\n  template <typename... Args>\n  static void run(\n      const F& f, const DGLArgs& args_pack, DGLRetValue* rv,\n      Args&&... unpacked_args) {\n    unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(\n        f, args_pack, rv, std::forward<Args>(unpacked_args)...,\n        args_pack[index]);\n  }\n};\n\ntemplate <typename R, int index, typename F>\nstruct unpack_call_dispatcher<R, 0, index, F> {\n  template <typename... Args>\n  static void run(\n      const F& f, const DGLArgs& args_pack, DGLRetValue* rv,\n      Args&&... unpacked_args) {\n    *rv = R(f(std::forward<Args>(unpacked_args)...));\n  }\n};\n\ntemplate <int index, typename F>\nstruct unpack_call_dispatcher<void, 0, index, F> {\n  template <typename... Args>\n  static void run(\n      const F& f, const DGLArgs& args_pack, DGLRetValue* rv,\n      Args&&... unpacked_args) {\n    f(std::forward<Args>(unpacked_args)...);\n  }\n};\n\ntemplate <typename R, int nargs, typename F>\ninline void unpack_call(const F& f, const DGLArgs& args, DGLRetValue* rv) {\n  unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);\n}\n\ntemplate <typename R, typename... Args>\ninline R call_packed(const PackedFunc& pf, Args&&... args) {\n  return R(pf(std::forward<Args>(args)...));\n}\n\ntemplate <typename R>\nstruct typed_packed_call_dispatcher {\n  template <typename... Args>\n  static inline R run(const PackedFunc& pf, Args&&... 
args) {\n    return pf(std::forward<Args>(args)...);\n  }\n};\n\ntemplate <>\nstruct typed_packed_call_dispatcher<void> {\n  template <typename... Args>\n  static inline void run(const PackedFunc& pf, Args&&... args) {\n    pf(std::forward<Args>(args)...);\n  }\n};\n}  // namespace detail\n\ntemplate <typename R, typename... Args>\nTypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed)\n    : packed_(packed) {}\n\ntemplate <typename R, typename... Args>\ntemplate <typename FType>\ninline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {\n  packed_ = PackedFunc([flambda](const DGLArgs& args, DGLRetValue* rv) {\n    detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);\n  });\n}\n\ntemplate <typename R, typename... Args>\ninline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {\n  return detail::typed_packed_call_dispatcher<R>::run(\n      packed_, std::forward<Args>(args)...);\n}\n\n// extension and node type handling\nnamespace detail {\ntemplate <typename T, typename TSrc, bool is_ext>\nstruct DGLValueCast {\n  static T Apply(const TSrc* self) { return self->template AsObjectRef<T>(); }\n};\n\ntemplate <typename T, typename TSrc>\nstruct DGLValueCast<T, TSrc, true> {\n  static T Apply(const TSrc* self) { return self->template AsExtension<T>(); }\n};\n}  // namespace detail\n\ntemplate <typename T, typename>\ninline DGLArgValue::operator T() const {\n  return detail::DGLValueCast<\n      T, DGLArgValue, extension_class_info<T>::code != 0>::Apply(this);\n}\n\ntemplate <typename T, typename>\ninline DGLRetValue::operator T() const {\n  return detail::DGLValueCast<\n      T, DGLRetValue, extension_class_info<T>::code != 0>::Apply(this);\n}\n\ntemplate <typename T, typename>\ninline void DGLArgsSetter::operator()(size_t i, const T& value) const {\n  static_assert(\n      extension_class_info<T>::code != 0, \"Need to have extesion code\");\n  type_codes_[i] = extension_class_info<T>::code;\n  values_[i].v_handle = 
const_cast<T*>(&value);\n}\n\n// extension type handling\ntemplate <typename T>\nstruct ExtTypeInfo {\n  static void destroy(void* handle) { delete static_cast<T*>(handle); }\n  static void* clone(void* handle) { return new T(*static_cast<T*>(handle)); }\n};\n\ntemplate <typename T>\ninline ExtTypeVTable* ExtTypeVTable::Register_() {\n  const int code = extension_class_info<T>::code;\n  static_assert(\n      code != 0,\n      \"require extension_class_info traits to be declared with non-zero code\");\n  ExtTypeVTable vt;\n  vt.clone = ExtTypeInfo<T>::clone;\n  vt.destroy = ExtTypeInfo<T>::destroy;\n  return ExtTypeVTable::RegisterInternal(code, vt);\n}\n\n// Implement Module::GetFunction\n// Put implementation in this file so we have seen the PackedFunc\ninline PackedFunc Module::GetFunction(\n    const std::string& name, bool query_imports) {\n  PackedFunc pf = node_->GetFunction(name, node_);\n  if (pf != nullptr) return pf;\n  if (query_imports) {\n    for (const Module& m : node_->imports_) {\n      pf = m.node_->GetFunction(name, m.node_);\n      if (pf != nullptr) return pf;\n    }\n  }\n  return pf;\n}\n}  // namespace runtime\n}  // namespace dgl\n#endif  // DGL_RUNTIME_PACKED_FUNC_H_\n"
  },
  {
    "path": "include/dgl/runtime/parallel_for.h",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file dgl/runtime/parallel_for.h\n * @brief OpenMP-based parallel for loop and parallel reduction utilities.\n */\n#ifndef DGL_RUNTIME_PARALLEL_FOR_H_\n#define DGL_RUNTIME_PARALLEL_FOR_H_\n\n#include <dgl/env_variable.h>\n#include <dmlc/omp.h>\n\n#include <algorithm>\n#include <atomic>\n#include <cstdlib>\n#include <exception>\n#include <string>\n#include <utility>\n#include <vector>\n\nnamespace {\nint64_t divup(int64_t x, int64_t y) { return (x + y - 1) / y; }\n}  // namespace\n\nnamespace dgl {\nnamespace runtime {\nnamespace {\nstruct DefaultGrainSizeT {\n  size_t grain_size;\n\n  DefaultGrainSizeT() : DefaultGrainSizeT(1) {}\n\n  explicit DefaultGrainSizeT(size_t default_grain_size) {\n    auto var = dgl::kDGLParallelForGrainSize;\n\n    if (var) {\n      grain_size = std::stoul(var);\n    } else {\n      grain_size = default_grain_size;\n    }\n  }\n\n  size_t operator()() { return grain_size; }\n};\n}  // namespace\n\ninline size_t compute_num_threads(size_t begin, size_t end, size_t grain_size) {\n#ifdef _OPENMP\n  if (omp_in_parallel() || end - begin <= grain_size || end - begin == 1)\n    return 1;\n\n  return std::min(\n      static_cast<int64_t>(omp_get_max_threads()),\n      divup(end - begin, grain_size));\n#else\n  return 1;\n#endif\n}\n\nstatic DefaultGrainSizeT default_grain_size;\n\n/**\n * @brief OpenMP-based parallel for loop.\n *\n * It requires each thread's workload to have at least \\a grain_size elements.\n * The loop body will be a function that takes in two arguments \\a begin and \\a\n * end, which stands for the starting (inclusive) and ending index (exclusive)\n * of the workload.\n */\ntemplate <typename F>\nvoid parallel_for(\n    const size_t begin, const size_t end, const size_t grain_size, F&& f) {\n  if (begin >= end) {\n    return;\n  }\n\n#ifdef _OPENMP\n  auto num_threads = compute_num_threads(begin, end, grain_size);\n  // (BarclayII) the exception code is borrowed from PyTorch.\n  std::atomic_flag err_flag = ATOMIC_FLAG_INIT;\n  std::exception_ptr eptr;\n\n#pragma omp parallel num_threads(num_threads)\n  {\n    auto tid = omp_get_thread_num();\n    auto chunk_size = divup((end - begin), num_threads);\n    auto begin_tid = begin + tid * chunk_size;\n    if (begin_tid < end) {\n      auto end_tid = std::min(end, static_cast<size_t>(chunk_size + begin_tid));\n      try {\n        f(begin_tid, end_tid);\n      } catch (...) {\n        if (!err_flag.test_and_set()) eptr = std::current_exception();\n      }\n    }\n  }\n  if (eptr) std::rethrow_exception(eptr);\n#else\n  f(begin, end);\n#endif\n}\n\n/**\n * @brief OpenMP-based parallel for loop with default grain size.\n *\n * parallel_for with grain size to default value, either 1 or controlled through\n * environment variable DGL_PARALLEL_FOR_GRAIN_SIZE.\n * If grain size is set to 1, the function behaves the same way as OpenMP\n * parallel for pragma with static scheduling.\n */\ntemplate <typename F>\nvoid parallel_for(const size_t begin, const size_t end, F&& f) {\n  parallel_for(begin, end, default_grain_size(), std::forward<F>(f));\n}\n\n/**\n * @brief OpenMP-based two-stage parallel reduction.\n *\n * The first-stage reduction function \\a f works in parallel.  Each thread's\n * workload has at least \\a grain_size elements.  The loop body will be a\n * function that takes in the starting index (inclusive), the ending index\n * (exclusive), and the reduction identity.\n *\n * The second-stage reduction function \\a sf is a binary function working in the\n * main thread. It aggregates the partially reduced result computed from each\n * thread.\n *\n * Example to compute a parallelized max reduction of an array \\c a:\n *\n *     parallel_reduce(\n *       0,        // starting index\n *       100,      // ending index\n *       1,        // grain size\n *       -std::numeric_limits<float>::infinity(),   // identity\n *       // first-stage partial reducer\n *       [&a] (int begin, int end, float ident) {\n *         float result = ident;\n *         for (int i = begin; i < end; ++i)\n *           result = std::max(result, a[i]);\n *         return result;\n *       },\n *       // second-stage reducer\n *       [] (float result, float partial_result) {\n *         return std::max(result, partial_result);\n *       });\n */\ntemplate <typename DType, typename F, typename SF>\nDType parallel_reduce(\n    const size_t begin, const size_t end, const size_t grain_size,\n    const DType ident, const F& f, const SF& sf) {\n  if (begin >= end) {\n    return ident;\n  }\n\n  int num_threads = compute_num_threads(begin, end, grain_size);\n  if (num_threads == 1) {\n    return f(begin, end, ident);\n  }\n\n  std::vector<DType> results(num_threads, ident);\n  std::atomic_flag err_flag = ATOMIC_FLAG_INIT;\n  std::exception_ptr eptr;\n#pragma omp parallel num_threads(num_threads)\n  {\n    auto tid = omp_get_thread_num();\n    auto chunk_size = divup((end - begin), num_threads);\n    auto begin_tid = begin + tid * chunk_size;\n    if (begin_tid < end) {\n      auto end_tid = std::min(end, static_cast<size_t>(chunk_size + begin_tid));\n      try {\n        results[tid] = f(begin_tid, end_tid, ident);\n      } catch (...) {\n        if (!err_flag.test_and_set()) eptr = std::current_exception();\n      }\n    }\n  }\n  if (eptr) std::rethrow_exception(eptr);\n\n  DType out = ident;\n  for (int64_t i = 0; i < num_threads; ++i) out = sf(out, results[i]);\n  return out;\n}\n\n}  // namespace runtime\n}  // namespace dgl\n\n#endif  // DGL_RUNTIME_PARALLEL_FOR_H_\n"
  },
  {
    "path": "include/dgl/runtime/registry.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file dgl/runtime/registry.h\n * @brief This file defines the DGL global function registry.\n *\n *  The registered functions will be made available to front-end\n *  as well as backend users.\n *\n *  The registry stores type-erased functions.\n *  Each registered function is automatically exposed\n *  to front-end languages (e.g. Python).\n *\n *  Front-end can also pass callbacks as PackedFunc, or register\n *  them into the same global registry in C++.\n *  The goal is to mix the front-end language and the DGL back-end.\n *\n * @code\n *   // register the function as MyAPIFuncName\n *   DGL_REGISTER_GLOBAL(\"MyAPIFuncName\")\n *   .set_body([](DGLArgs args, DGLRetValue* rv) {\n *     // my code.\n *   });\n * @endcode\n */\n#ifndef DGL_RUNTIME_REGISTRY_H_\n#define DGL_RUNTIME_REGISTRY_H_\n\n#include <string>\n#include <vector>\n\n#include \"packed_func.h\"\n\nnamespace dgl {\nnamespace runtime {\n\n/** @brief Registry for global function */\nclass Registry {\n public:\n  /**\n   * @brief set the body of the function to be f\n   * @param f The body of the function.\n   */\n  DGL_DLL Registry& set_body(PackedFunc f);  // NOLINT(*)\n  /**\n   * @brief set the body of the function to be f\n   * @param f The body of the function.\n   */\n  Registry& set_body(PackedFunc::FType f) {  // NOLINT(*)\n    return set_body(PackedFunc(f));\n  }\n  /**\n   * @brief set the body of the function to be TypedPackedFunc.\n   *\n   * @code\n   *\n   * DGL_REGISTER_GLOBAL(\"addone\")\n   * .set_body_typed<int(int)>([](int x) { return x + 1; });\n   *\n   * @endcode\n   *\n   * @param f The body of the function.\n   * @tparam FType the signature of the function.\n   * @tparam FLambda The type of f.\n   */\n  template <typename FType, typename FLambda>\n  Registry& set_body_typed(FLambda f) {\n    return set_body(TypedPackedFunc<FType>(f).packed());\n  }\n  /**\n   * @brief Register a function with given name\n   * @param name The name of the function.\n   * @param override Whether to allow overriding an existing function.\n   * @return Reference to the registry.\n   */\n  DGL_DLL static Registry& Register(\n      const std::string& name, bool override = false);  // NOLINT(*)\n  /**\n   * @brief Erase global function from registry, if it exists.\n   * @param name The name of the function.\n   * @return Whether the function exists.\n   */\n  DGL_DLL static bool Remove(const std::string& name);\n  /**\n   * @brief Get the global function by name.\n   * @param name The name of the function.\n   * @return pointer to the registered function,\n   *   nullptr if it does not exist.\n   */\n  DGL_DLL static const PackedFunc* Get(const std::string& name);  // NOLINT(*)\n  /**\n   * @brief Get the names of currently registered global function.\n   * @return The names\n   */\n  DGL_DLL static std::vector<std::string> ListNames();\n\n  // Internal class.\n  struct Manager;\n\n protected:\n  /** @brief name of the function */\n  std::string name_;\n  /** @brief internal packed function */\n  PackedFunc func_;\n  friend struct Manager;\n};\n\n/** @brief helper macro to suppress unused warning */\n#if defined(__GNUC__)\n#define DGL_ATTRIBUTE_UNUSED __attribute__((unused))\n#else\n#define DGL_ATTRIBUTE_UNUSED\n#endif\n\n#define DGL_STR_CONCAT_(__x, __y) __x##__y\n#define DGL_STR_CONCAT(__x, __y) DGL_STR_CONCAT_(__x, __y)\n\n#define DGL_FUNC_REG_VAR_DEF \\\n  static DGL_ATTRIBUTE_UNUSED ::dgl::runtime::Registry& __mk_##DGL\n\n#define DGL_TYPE_REG_VAR_DEF \\\n  static DGL_ATTRIBUTE_UNUSED ::dgl::runtime::ExtTypeVTable* __mk_##DGLT\n\n/**\n * @brief Register a function globally.\n * @code\n *   DGL_REGISTER_GLOBAL(\"MyPrint\")\n *   .set_body([](DGLArgs args, DGLRetValue* rv) {\n *   });\n * @endcode\n */\n#define DGL_REGISTER_GLOBAL(OpName)                   \\\n  DGL_STR_CONCAT(DGL_FUNC_REG_VAR_DEF, __COUNTER__) = \\\n      ::dgl::runtime::Registry::Register(OpName)\n\n/**\n * @brief Macro to register extension type.\n *  This must be registered in a cc file\n *  after the trait extension_class_info is defined.\n */\n#define DGL_REGISTER_EXT_TYPE(T)                      \\\n  DGL_STR_CONCAT(DGL_TYPE_REG_VAR_DEF, __COUNTER__) = \\\n      ::dgl::runtime::ExtTypeVTable::Register_<T>()\n\n}  // namespace runtime\n}  // namespace dgl\n#endif  // DGL_RUNTIME_REGISTRY_H_\n"
  },
  {
    "path": "include/dgl/runtime/serializer.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file dgl/runtime/serializer.h\n * @brief Serializer extension to support DGL data types\n *  Include this file to enable serialization of DGLDataType, DGLContext\n */\n#ifndef DGL_RUNTIME_SERIALIZER_H_\n#define DGL_RUNTIME_SERIALIZER_H_\n\n#include <dmlc/io.h>\n#include <dmlc/serializer.h>\n\n#include \"c_runtime_api.h\"\n#include \"smart_ptr_serializer.h\"\n\nnamespace dmlc {\nnamespace serializer {\n\ntemplate <>\nstruct Handler<DGLDataType> {\n  inline static void Write(Stream *strm, const DGLDataType &dtype) {\n    Handler<uint8_t>::Write(strm, dtype.code);\n    Handler<uint8_t>::Write(strm, dtype.bits);\n    Handler<uint16_t>::Write(strm, dtype.lanes);\n  }\n  inline static bool Read(Stream *strm, DGLDataType *dtype) {\n    if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false;\n    if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false;\n    if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false;\n    return true;\n  }\n};\n\ntemplate <>\nstruct Handler<DGLContext> {\n  inline static void Write(Stream *strm, const DGLContext &ctx) {\n    int32_t device_type = static_cast<int32_t>(ctx.device_type);\n    Handler<int32_t>::Write(strm, device_type);\n    Handler<int32_t>::Write(strm, ctx.device_id);\n  }\n  inline static bool Read(Stream *strm, DGLContext *ctx) {\n    int32_t device_type = 0;\n    if (!Handler<int32_t>::Read(strm, &(device_type))) return false;\n    ctx->device_type = static_cast<DGLDeviceType>(device_type);\n    if (!Handler<int32_t>::Read(strm, &(ctx->device_id))) return false;\n    return true;\n  }\n};\n\n}  // namespace serializer\n}  // namespace dmlc\n#endif  // DGL_RUNTIME_SERIALIZER_H_\n"
  },
  {
    "path": "include/dgl/runtime/shared_mem.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file dgl/runtime/shared_mem.h\n * @brief shared memory management.\n */\n#ifndef DGL_RUNTIME_SHARED_MEM_H_\n#define DGL_RUNTIME_SHARED_MEM_H_\n\n#ifdef _WIN32\n#include <windows.h>\n#endif  // _WIN32\n#include <string>\n\nnamespace dgl {\nnamespace runtime {\n\n/**\n * @brief This class owns shared memory.\n *\n * When the object is gone, the shared memory will also be destroyed.\n * When the shared memory is destroyed, the file corresponding to\n * the shared memory is removed.\n */\nclass SharedMemory {\n  /**\n   * @brief whether the shared memory is owned by the object.\n   *\n   * If the shared memory is created by the object, the object owns it\n   * and is responsible for deleting it when the object is destroyed.\n   */\n  bool own_;\n\n  /* @brief the file descriptor of the shared memory. */\n#ifndef _WIN32\n  int fd_;\n#else   // !_WIN32\n  HANDLE handle_;\n#endif  // _WIN32\n  /* @brief the address of the shared memory. */\n  void *ptr_;\n  /* @brief the size of the shared memory. */\n  size_t size_;\n\n  /**\n   * @brief the name of the object.\n   *\n   * In Unix, shared memory is identified by a file. Thus, `name` is actually\n   * the file name that identifies the shared memory.\n   */\n  std::string name;\n\n public:\n  /* @brief Get the filename of shared memory file\n   */\n  std::string GetName() const { return name; }\n\n  /**\n   * @brief constructor of the shared memory.\n   * @param name The file corresponding to the shared memory.\n   */\n  explicit SharedMemory(const std::string &name);\n  /**\n   * @brief destructor of the shared memory.\n   * It deallocates the shared memory and removes the corresponding file.\n   */\n  ~SharedMemory();\n  /**\n   * @brief create shared memory.\n   * It creates the file and shared memory.\n   * @param sz the size of the shared memory.\n   * @return the address of the shared memory\n   */\n  void *CreateNew(size_t sz);\n  /**\n   * @brief allocate shared memory that has been created.\n   * @param sz the size of the shared memory.\n   * @return the address of the shared memory\n   */\n  void *Open(size_t sz);\n\n  /**\n   * @brief check if the shared memory exist.\n   * @param name the name of the shared memory.\n   * @return a boolean value to indicate if the shared memory exists.\n   */\n  static bool Exist(const std::string &name);\n};\n\n}  // namespace runtime\n}  // namespace dgl\n#endif  // DGL_RUNTIME_SHARED_MEM_H_\n"
  },
  {
    "path": "include/dgl/runtime/smart_ptr_serializer.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file dgl/runtime/smart_ptr_serializer.h\n * @brief Serializer extension to support smart pointers\n *  Include this file to enable serialization of std::shared_ptr and\n *  std::unique_ptr\n */\n#ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_\n#define DGL_RUNTIME_SMART_PTR_SERIALIZER_H_\n\n#include <dgl/graph_serializer.h>\n#include <dmlc/io.h>\n#include <dmlc/serializer.h>\n\n#include <memory>\n\nnamespace dmlc {\nnamespace serializer {\n\n//! \\cond Doxygen_Suppress\ntemplate <typename T>\nstruct Handler<std::shared_ptr<T>> {\n  inline static void Write(Stream *strm, const std::shared_ptr<T> &data) {\n    Handler<T>::Write(strm, *data.get());\n  }\n  inline static bool Read(Stream *strm, std::shared_ptr<T> *data) {\n    // When read, the default initialization behavior of shared_ptr is\n    // shared_ptr<T>(), which is holding a nullptr. Here we need to manually\n    // reset to a real object for further loading\n    if (!(*data)) {\n      data->reset(dgl::Serializer::new_object<T>());\n    }\n    return Handler<T>::Read(strm, data->get());\n  }\n};\n\ntemplate <typename T>\nstruct Handler<std::unique_ptr<T>> {\n  inline static void Write(Stream *strm, const std::unique_ptr<T> &data) {\n    Handler<T>::Write(strm, *data.get());\n  }\n  inline static bool Read(Stream *strm, std::unique_ptr<T> *data) {\n    // When read, the default initialization behavior of unique_ptr is\n    // unique_ptr<T>(), which is holding a nullptr. Here we need to manually\n    // reset to a real object for further loading\n    if (!(*data)) {\n      data->reset(dgl::Serializer::new_object<T>());\n    }\n    return Handler<T>::Read(strm, data->get());\n  }\n};\n\n}  // namespace serializer\n}  // namespace dmlc\n#endif  // DGL_RUNTIME_SMART_PTR_SERIALIZER_H_\n"
  },
  {
    "path": "include/dgl/runtime/tensordispatch.h",
    "content": "/**\n *  Copyright (c) 2020-2022 by Contributors\n * @file array/tensordispatch.h\n * @brief This file defines the dispatcher of tensor operators to\n * framework-specific implementations.\n *\n *  The dispatcher consists of a TensorDispatcher singleton in DGL C library and\n *  one separately-built shared library per supported backend.\n *\n *  Those shared libraries contain wrappers of the framework-specific operators.\n *  The wrappers are defined with extern \"C\", meaning that the C++ compiler will\n *  not do name mangling for those functions so that DGL can conveniently locate\n *  them using dlsym(3) (or GetProcAddress in Windows).\n *\n *  The TensorDispatcher singleton maintains a mapping from an array operator to\n *  the address of the corresponding symbol in the shared library.  During\n *  initialization, the TensorDispatcher checks which backend DGL is using.\n *  It then locates and opens the corresponding shared library using dlopen(3)\n * (or LoadLibrary in Windows), and populates the said mapping above with\n * dlsym(3) (or GetProcAddress in Windows).\n *\n *  A tensor operator in TensorDispatcher first checks whether the corresponding\n * symbol address is found in the mapping.  If so, it calls the function located\n * at the symbol address instead, allocate/free pieces of memory on CPU/GPU. 
If\n * not, it falls back to DeviceAPI::AllocWorkspace/FreeWorkspace.\n */\n\n#ifndef DGL_RUNTIME_TENSORDISPATCH_H_\n#define DGL_RUNTIME_TENSORDISPATCH_H_\n\n#include <stddef.h>\n#include <tensoradapter.h>\n#if defined(WIN32) || defined(_WIN32)\n#include <windows.h>\n#endif  // WIN32\n#ifdef DGL_USE_CUDA\n#include <cuda_runtime.h>\n#endif  // DGL_USE_CUDA\n#include \"ndarray.h\"\n\n/**\n * @brief Casts a pointer \\c entry to a function pointer with signature of \\c\n * func.\n */\n#define FUNCCAST(func, entry) (*reinterpret_cast<decltype(&(func))>(entry))\n\nnamespace dgl {\nnamespace runtime {\n\n/**\n * @brief Dispatcher that delegates the function calls to framework-specific C++\n * APIs.\n *\n * This class is not thread-safe.\n */\nclass TensorDispatcher {\n public:\n  /** @brief Get the singleton instance. */\n  static TensorDispatcher* Global() {\n    static TensorDispatcher inst;\n    return &inst;\n  }\n\n  /** @brief Whether an adapter library is available. */\n  inline bool IsAvailable() { return available_; }\n\n  /** @brief Load symbols from the given tensor adapter library path. 
*/\n  bool Load(const char* path_cstr);\n\n  /**\n   * @brief Allocate a piece of CPU memory via PyTorch's CPUAllocator.\n   * Used in CPUDeviceAPI::AllocWorkspace().\n   *\n   * @param nbytes The size to be allocated.\n   * @return Pointer to the allocated memory.\n   */\n  inline void* CPUAllocWorkspace(size_t nbytes) {\n    auto entry = entrypoints_[Op::kCPURawAlloc];\n    return FUNCCAST(tensoradapter::CPURawAlloc, entry)(nbytes);\n  }\n\n  /**\n   * @brief Free the CPU memory.\n   * Used in CPUDeviceAPI::FreeWorkspace().\n   *\n   * @param ptr Pointer to the memory to be freed.\n   */\n  inline void CPUFreeWorkspace(void* ptr) {\n    auto entry = entrypoints_[Op::kCPURawDelete];\n    FUNCCAST(tensoradapter::CPURawDelete, entry)(ptr);\n  }\n\n#ifdef DGL_USE_CUDA\n  /**\n   * @brief Allocate a piece of GPU memory via\n   * PyTorch's THCCachingAllocator.\n   * Used in CUDADeviceAPI::AllocWorkspace().\n   *\n   * @note THCCachingAllocator specify the device to allocate on\n   * via cudaGetDevice(). Make sure to call cudaSetDevice()\n   * before invoking this function.\n   *\n   * @param nbytes The size to be allocated.\n   * @param stream The stream to be allocated on.\n   * @return Pointer to the allocated memory.\n   */\n  inline void* CUDAAllocWorkspace(size_t nbytes, cudaStream_t stream) {\n    auto entry = entrypoints_[Op::kCUDARawAlloc];\n    return FUNCCAST(tensoradapter::CUDARawAlloc, entry)(nbytes, stream);\n  }\n\n  /**\n   * @brief Free the GPU memory.\n   * Used in CUDADeviceAPI::FreeWorkspace().\n   *\n   * @param ptr Pointer to the memory to be freed.\n   */\n  inline void CUDAFreeWorkspace(void* ptr) {\n    auto entry = entrypoints_[Op::kCUDARawDelete];\n    FUNCCAST(tensoradapter::CUDARawDelete, entry)(ptr);\n  }\n\n  /**\n   * @brief Find the current PyTorch CUDA stream\n   * Used in runtime::getCurrentCUDAStream().\n   *\n   * @note PyTorch pre-allocates/sets the current CUDA stream\n   * on current device via cudaGetDevice(). 
Make sure to call cudaSetDevice()\n   * before invoking this function.\n   *\n   * @return cudaStream_t stream handle\n   */\n  inline cudaStream_t CUDAGetCurrentStream() {\n    auto entry = entrypoints_[Op::kCUDACurrentStream];\n    return FUNCCAST(tensoradapter::CUDACurrentStream, entry)();\n  }\n\n  /**\n   * @brief Allocate a piece of pinned CPU memory via PyTorch\n   *     CachingHostAllocator.\n   * @note Used in CUDADeviceAPI::AllocPinnedDataSpace().\n   * @param nbytes The size to be allocated.\n   * @param ctx Pointer to the PyTorch storage ctx ptr returned from the\n   *     allocator.\n   * @param deleter Pointer to the delete function ptr returned from the\n   *     allocator.\n   * @return Raw pointer to the allocated memory.\n   */\n  inline void* CUDAAllocHostWorkspace(\n      size_t nbytes, void** ctx, void** deleter) {\n    auto entry = entrypoints_[Op::kCUDARawHostAlloc];\n\n    auto alloc_func = FUNCCAST(tensoradapter::CUDARawHostAlloc, entry);\n    return alloc_func(nbytes, ctx, deleter);\n  }\n\n  /**\n   * @brief Insert the pinned memory block (allocated via PyTorch\n   *     CachingHostAllocator) back to the free list for future usage.(ref:\n   *     pytorch/pytorch/blob/master/aten/src/ATen/cuda/CachingHostAllocator.cpp).\n   * @note Used in CUDADeviceAPI::FreePinnedDataSpace().\n   * @param deleter Pointer to the delete function ptr returned from the\n   *     allocator.\n   */\n  inline void CUDAFreeHostWorkspace(void** deleter) {\n    auto entry = entrypoints_[Op::kCUDARawHostDelete];\n    FUNCCAST(tensoradapter::CUDARawHostDelete, entry)(deleter);\n  }\n\n  /**\n   * @brief Invoke the record_event function call from PyTorch\n   *     CachingHostAllocator.\n   * @note This function associates a CUDA stream (used by a copy kernel) to the\n   *     pinned data. In the free path of this data, which is achieved by\n   *     calling CUDAFreeHostWorkspace, the set of associated streams is then\n   *     consumed to ensure proper functionality. 
(ref:\n   *     pytorch/pytorch/blob/master/aten/src/ATen/cuda/CachingHostAllocator.cpp).\n   *     Used in CUDADeviceAPI::RecordedCopyDataFromTo().\n   *\n   * @param data Pointer of the tensor to be recorded.\n   * @param ctx PyTorch storage ctx ptr returned from the allocator.\n   * @param stream The stream that currently consumes this tensor.\n   * @param device_id Device of the tensor.\n   */\n  inline void CUDARecordHostAlloc(\n      void* data, void* ctx, cudaStream_t stream, int device_id) {\n    auto entry = entrypoints_[Op::kCUDARecordHostAlloc];\n    auto recorded_alloc = FUNCCAST(tensoradapter::CUDARecordHostAlloc, entry);\n    recorded_alloc(data, ctx, stream, device_id);\n  }\n\n  /**\n   * @brief Release cached pinned memory allocations via cudaHostFree.\n   * @note Used in CUDADeviceAPI::PinData() before pinning any host memory by\n   *     DGL.\n   */\n  inline void CUDAHostAllocatorEmptyCache() {\n    auto entry = entrypoints_[Op::kCUDAHostAllocatorEmptyCache];\n    FUNCCAST(tensoradapter::CUDAHostAllocatorEmptyCache, entry)();\n  }\n#endif  // DGL_USE_CUDA\n\n  /**\n   * @brief Record streams that are using this tensor.\n   * Used in NDArray::RecordStream().\n   *\n   * @param ptr Pointer of the tensor to be recorded.\n   * @param stream The stream that is using this tensor.\n   * @param device_id Device of the tensor.\n   */\n  inline void RecordStream(void* ptr, DGLStreamHandle stream, int device_id) {\n#ifdef DGL_USE_CUDA\n    auto entry = entrypoints_[Op::kRecordStream];\n    FUNCCAST(tensoradapter::RecordStream, entry)\n    (ptr, static_cast<cudaStream_t>(stream), device_id);\n#endif\n  }\n\n private:\n  /** @brief ctor */\n  TensorDispatcher() = default;\n  /** @brief dtor */\n  ~TensorDispatcher();\n\n  /**\n   * @brief List of symbols in the adapter library.\n   *\n   * Must match the functions in tensoradapter/include/tensoradapter.h.\n   */\n  static constexpr const char* names_[] = {\n      \"CPURawAlloc\",         
\"CPURawDelete\",\n#ifdef DGL_USE_CUDA\n      \"CUDARawAlloc\",        \"CUDARawDelete\",\n      \"CUDACurrentStream\",   \"RecordStream\",\n      \"CUDARawHostAlloc\",    \"CUDARawHostDelete\",\n      \"CUDARecordHostAlloc\", \"CUDAHostAllocatorEmptyCache\",\n#endif  // DGL_USE_CUDA\n  };\n\n  /** @brief Index of each function to the symbol list */\n  class Op {\n   public:\n    static constexpr int kCPURawAlloc = 0;\n    static constexpr int kCPURawDelete = 1;\n#ifdef DGL_USE_CUDA\n    static constexpr int kCUDARawAlloc = 2;\n    static constexpr int kCUDARawDelete = 3;\n    static constexpr int kCUDACurrentStream = 4;\n    static constexpr int kRecordStream = 5;\n    static constexpr int kCUDARawHostAlloc = 6;\n    static constexpr int kCUDARawHostDelete = 7;\n    static constexpr int kCUDARecordHostAlloc = 8;\n    static constexpr int kCUDAHostAllocatorEmptyCache = 9;\n#endif  // DGL_USE_CUDA\n  };\n\n  /** @brief Number of functions */\n  static constexpr int num_entries_ = sizeof(names_) / sizeof(names_[0]);\n\n  /** @brief Entrypoints of each function */\n  void* entrypoints_[num_entries_] = {\n      nullptr, nullptr,\n#ifdef DGL_USE_CUDA\n      nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,\n#endif  // DGL_USE_CUDA\n  };\n\n  bool available_ = false;\n#if defined(WIN32) || defined(_WIN32)\n  HINSTANCE handle_;\n#else   // !WIN32\n  void* handle_;\n#endif  // WIN32\n};\n\n};  // namespace runtime\n};  // namespace dgl\n\n#undef FUNCCAST\n\n#endif  // DGL_RUNTIME_TENSORDISPATCH_H_\n"
  },
  {
    "path": "include/dgl/runtime/threading_backend.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file dgl/runtime/threading_backend.h\n * @brief Utilities for manipulating thread pool threads.\n */\n#ifndef DGL_RUNTIME_THREADING_BACKEND_H_\n#define DGL_RUNTIME_THREADING_BACKEND_H_\n\n#include <functional>\n#include <memory>\n#include <vector>\n\nnamespace dgl {\nnamespace runtime {\nnamespace threading {\n\n/**\n * @brief A platform-agnostic abstraction for managing a collection of\n *        thread pool threads.\n */\nclass ThreadGroup {\n public:\n  class Impl;\n\n  /**\n   * @brief Creates a collection of threads which run a provided function.\n   *\n   * @param num_workers The total number of worker threads in this group.\n            Includes main thread if `exclude_worker0 = true`\n   * @param worker_callback A callback which is run in its own thread.\n            Receives the worker_id as an argument.\n   * @param exclude_worker0 Whether to use the main thread as a worker.\n   *        If  `true`, worker0 will not be launched in a new thread and\n   *        `worker_callback` will only be called for values >= 1. This\n   *        allows use of the main thread as a worker.\n   */\n  ThreadGroup(\n      int num_workers, std::function<void(int)> worker_callback,\n      bool exclude_worker0 = false);\n  ~ThreadGroup();\n\n  /**\n   * @brief Blocks until all non-main threads in the pool finish.\n   */\n  void Join();\n\n  enum AffinityMode : int {\n    kBig = 1,\n    kLittle = -1,\n  };\n\n  /**\n   * @brief configure the CPU id affinity\n   *\n   * @param mode The preferred CPU type (1 = big, -1 = little).\n   * @param nthreads The number of threads to use (0 = use all).\n   * @param exclude_worker0 Whether to use the main thread as a worker.\n   *        If  `true`, worker0 will not be launched in a new thread and\n   *        `worker_callback` will only be called for values >= 1. 
This\n   *        allows use of the main thread as a worker.\n   *\n   * @return The number of workers to use.\n   */\n  int Configure(AffinityMode mode, int nthreads, bool exclude_worker0);\n\n private:\n  Impl* impl_;\n};\n\n/**\n * @brief Platform-agnostic no-op.\n */\n// This used to be Yield(), renaming to YieldThread() because windows.h defined\n// it as a macro in later SDKs.\nvoid YieldThread();\n\n/**\n * @return the maximum number of effective workers for this system.\n */\nint MaxConcurrency();\n\n}  // namespace threading\n}  // namespace runtime\n}  // namespace dgl\n\n#endif  // DGL_RUNTIME_THREADING_BACKEND_H_\n"
  },
  {
    "path": "include/dgl/runtime/util.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file dgl/runtime/util.h\n * @brief Useful runtime util.\n */\n#ifndef DGL_RUNTIME_UTIL_H_\n#define DGL_RUNTIME_UTIL_H_\n\n#include \"c_runtime_api.h\"\n\nnamespace dgl {\nnamespace runtime {\n\n/**\n * @brief Check whether type matches the given spec.\n * @param t The type\n * @param code The type code.\n * @param bits The number of bits to be matched.\n * @param lanes The number of lanes in the type.\n */\ninline bool TypeMatch(DGLDataType t, int code, int bits, int lanes = 1) {\n  return t.code == code && t.bits == bits && t.lanes == lanes;\n}\n}  // namespace runtime\n}  // namespace dgl\n// Forward declare the intrinsic id we need\n// in structure fetch to enable stackvm in runtime\nnamespace dgl {\nnamespace ir {\nnamespace intrinsic {\n/** @brief The kind of structure field info used in intrinsic */\nenum DGLStructFieldKind : int {\n  // array head address\n  kArrAddr,\n  kArrData,\n  kArrShape,\n  kArrStrides,\n  kArrNDim,\n  kArrTypeCode,\n  kArrTypeBits,\n  kArrTypeLanes,\n  kArrByteOffset,\n  kArrDeviceId,\n  kArrDeviceType,\n  kArrKindBound_,\n  // DGLValue field\n  kDGLValueContent,\n  kDGLValueKindBound_\n};\n}  // namespace intrinsic\n}  // namespace ir\n}  // namespace dgl\n#endif  // DGL_RUNTIME_UTIL_H_\n"
  },
  {
    "path": "include/dgl/sampler.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file dgl/sampler.h\n * @brief DGL sampler header.\n */\n#ifndef DGL_SAMPLER_H_\n#define DGL_SAMPLER_H_\n\n#include <cstdlib>\n#include <ctime>\n#include <string>\n#include <vector>\n\n#include \"graph_interface.h\"\n#include \"nodeflow.h\"\n\nnamespace dgl {\n\nclass ImmutableGraph;\n\nclass SamplerOp {\n public:\n  /**\n   * @brief Sample a graph from the seed vertices with neighbor sampling.\n   * The neighbors are sampled with a uniform distribution.\n   *\n   * @param graph A graph for sampling.\n   * @param seeds the nodes where we should start to sample.\n   * @param edge_type the type of edges we should sample neighbors.\n   * @param num_hops the number of hops to sample neighbors.\n   * @param expand_factor the max number of neighbors to sample.\n   * @param add_self_loop whether to add self loop to the sampled subgraph\n   * @param probability the transition probability (float/double).\n   * @return a NodeFlow graph.\n   */\n  template <typename ValueType>\n  static NodeFlow NeighborSample(\n      const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,\n      const std::string &edge_type, int num_hops, int expand_factor,\n      const bool add_self_loop, const ValueType *probability);\n\n  /**\n   * @brief Sample a graph from the seed vertices with layer sampling.\n   * The layers are sampled with a uniform distribution.\n   *\n   * @param graph A graph for sampling.\n   * @param seeds the nodes where we should start to sample.\n   * @param neigh_type the type of edges we should sample neighbors.\n   * @param layer_sizes The size of layers.\n   * @return a NodeFlow graph.\n   */\n  static NodeFlow LayerUniformSample(\n      const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,\n      const std::string &neigh_type, IdArray layer_sizes);\n};\n\n}  // namespace dgl\n\n#endif  // DGL_SAMPLER_H_\n"
  },
  {
    "path": "include/dgl/sampling/negative.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file dgl/sampling/negative.h\n * @brief Negative sampling.\n */\n#ifndef DGL_SAMPLING_NEGATIVE_H_\n#define DGL_SAMPLING_NEGATIVE_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n\n#include <utility>\n\nnamespace dgl {\nnamespace sampling {\n\n/**\n * @brief Given an edge type, uniformly sample source-destination pairs that do\n * not have an edge in between using rejection sampling.\n *\n * @note This function may not return the same number of elements as the given\n * number of samples.\n * @note This function requires sorting the CSR or CSC matrix of the graph\n * in-place.  It prefers CSC over CSR.\n *\n * @param hg The graph.\n * @param etype The edge type.\n * @param num_samples The number of negative examples to sample.\n * @param num_trials The number of rejection sampling trials.\n * @param exclude_self_loops Do not include the examples where the source equals\n * the destination.\n * @param replace Whether to sample with replacement.\n * @param redundancy How much redundant negative examples to take in case of\n * duplicate examples.\n * @return The pair of source and destination tensors.\n */\nstd::pair<IdArray, IdArray> GlobalUniformNegativeSampling(\n    HeteroGraphPtr hg, dgl_type_t etype, int64_t num_samples, int num_trials,\n    bool exclude_self_loops, bool replace, double redundancy);\n\n};  // namespace sampling\n};  // namespace dgl\n\n#endif  // DGL_SAMPLING_NEGATIVE_H_\n"
  },
  {
    "path": "include/dgl/sampling/neighbor.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file dgl/sampling/neighbor.h\n * @brief Neighborhood-based sampling.\n */\n#ifndef DGL_SAMPLING_NEIGHBOR_H_\n#define DGL_SAMPLING_NEIGHBOR_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n\n#include <tuple>\n#include <vector>\n\nnamespace dgl {\nnamespace sampling {\n\n/**\n * @brief Sample from the neighbors of the given nodes and return the sampled\n * edges as a graph.\n *\n * When sampling with replacement, the sampled subgraph could have parallel\n * edges.\n *\n * For sampling without replace, if fanout > the number of neighbors, all the\n * neighbors will be sampled.\n *\n * @param hg The input graph.\n * @param nodes Node IDs of each type. The vector length must be equal to the\n * number of node types. Empty array is allowed.\n * @param fanouts Number of sampled neighbors for each edge type. The vector\n * length should be equal to the number of edge types, or one if they all have\n * the same fanout.\n * @param dir Edge direction.\n * @param probability A vector of 1D float arrays, indicating the transition\n * probability of each edge by edge type.  An empty float array assumes uniform\n * transition.\n * @param exclude_edges Edges IDs of each type which will be excluded during\n * sampling. The vector length must be equal to the number of edges types. Empty\n * array is allowed.\n * @param replace If true, sample with replacement.\n * @return Sampled neighborhoods as a graph. 
The return graph has the same\n * schema as the original one.\n */\nHeteroSubgraph SampleNeighbors(\n    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,\n    const std::vector<int64_t>& fanouts, EdgeDir dir,\n    const std::vector<FloatArray>& probability,\n    const std::vector<IdArray>& exclude_edges, bool replace = true);\n\n/**\n * @brief Sample from the neighbors of the given nodes and convert a graph into\n * a bipartite-structured graph for message passing.\n *\n * Specifically, we create one node type \\c ntype_l on the \"left\" side and\n * another node type \\c ntype_r on the \"right\" side for each node type \\c ntype.\n * The nodes of type \\c ntype_r would contain the nodes designated by the\n * caller, and node type \\c ntype_l would contain the nodes that has an edge\n * connecting to one of the designated nodes.\n *\n * The nodes of \\c ntype_l would also contain the nodes in node type \\c ntype_r.\n * When sampling with replacement, the sampled subgraph could have parallel\n * edges.\n *\n * For sampling without replace, if fanout > the number of neighbors, all the\n * neighbors will be sampled.\n *\n * Non-deterministic algorithm, requires nodes parameter to store unique Node\n * IDs.\n *\n * @tparam IdType Graph's index data type, can be int32_t or int64_t\n * @param hg The input graph.\n * @param nodes Node IDs of each type. The vector length must be equal to the\n * number of node types. Empty array is allowed.\n * @param mapping External parameter that should be set to a vector of IdArrays\n *                filled with -1, required for mapping of nodes in returned\n *                graph\n * @param fanouts Number of sampled neighbors for each edge type. The vector\n * length should be equal to the number of edge types, or one if they all have\n * the same fanout.\n * @param dir Edge direction.\n * @param probability A vector of 1D float arrays, indicating the transition\n * probability of each edge by edge type.  
An empty float array assumes uniform\n * transition.\n * @param exclude_edges Edges IDs of each type which will be excluded during\n * sampling. The vector length must be equal to the number of edges types. Empty\n * array is allowed.\n * @param replace If true, sample with replacement.\n * @return Sampled neighborhoods as a graph. The return graph has the same\n * schema as the original one.\n */\ntemplate <typename IdType>\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>\nSampleNeighborsFused(\n    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,\n    const std::vector<IdArray>& mapping, const std::vector<int64_t>& fanouts,\n    EdgeDir dir, const std::vector<NDArray>& prob_or_mask,\n    const std::vector<IdArray>& exclude_edges, bool replace = true);\n\n/**\n * Select the neighbors with k-largest weights on the connecting edges for each\n * given node.\n *\n * If k > the number of neighbors, all the neighbors are sampled.\n *\n * @param hg The input graph.\n * @param nodes Node IDs of each type. The vector length must be equal to the\n * number of node types. Empty array is allowed.\n * @param k The k value for each edge type. The vector length should be equal to\n * the number of edge types, or one if they all have the same fanout.\n * @param dir Edge direction.\n * @param weight A vector of 1D float arrays, indicating the weights associated\n * witheach edge.\n * @param ascending If true, elements are sorted by ascending order, equivalent\n * to find the K smallest values. Otherwise, find K largest values.\n * @return Sampled neighborhoods as a graph. 
The return graph has the same\n * schema as the original one.\n */\nHeteroSubgraph SampleNeighborsTopk(\n    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,\n    const std::vector<int64_t>& k, EdgeDir dir,\n    const std::vector<FloatArray>& weight, bool ascending = false);\n\nHeteroSubgraph SampleNeighborsBiased(\n    const HeteroGraphPtr hg, const IdArray& nodes, const int64_t fanouts,\n    const NDArray& bias, const NDArray& tag_offset, const EdgeDir dir,\n    const bool replace);\n}  // namespace sampling\n}  // namespace dgl\n\n#endif  // DGL_SAMPLING_NEIGHBOR_H_\n"
  },
  {
    "path": "include/dgl/sampling/randomwalks.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file dgl/samplinig/randomwalks.h\n * @brief Random walk functions.\n */\n#ifndef DGL_SAMPLING_RANDOMWALKS_H_\n#define DGL_SAMPLING_RANDOMWALKS_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n\n#include <tuple>\n#include <utility>\n#include <vector>\n\nnamespace dgl {\n\nnamespace sampling {\n\n/**\n * @brief Metapath-based random walk.\n * @param hg The heterograph.\n * @param seeds A 1D array of seed nodes, with the type the source type of the\n * first edge type in the metapath.\n * @param metapath A 1D array of edge types representing the metapath.\n * @param prob A vector of 1D float arrays, indicating the transition\n * probability of each edge by edge type. An empty float array assumes uniform\n * transition.\n * @return A pair of\n *         1. One 2D array of shape (len(seeds), len(metapath) + 1) with node\n *            IDs. The paths that terminated early are padded with -1.\n *         2. One 2D array of shape (len(seeds), len(metapath)) with edge IDs.\n *            The paths that terminated early are padded with -1.\n *         3. One 1D array of shape (len(metapath) + 1) with node type IDs.\n */\nstd::tuple<IdArray, IdArray, TypeArray> RandomWalk(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob);\n\n/**\n * @brief Metapath-based random walk with restart probability.\n * @param hg The heterograph.\n * @param seeds A 1D array of seed nodes, with the type the source type of the\n * first edge type in the metapath.\n * @param metapath A 1D array of edge types representing the metapath.\n * @param prob A vector of 1D float arrays, indicating the transition\n * probability of each edge by edge type. An empty float array assumes uniform\n * transition.\n * @param restart_prob Restart probability.\n * @return A pair of\n *         1. 
One 2D array of shape (len(seeds), len(metapath) + 1) with node\n *            IDs. The paths that terminated early are padded with -1.\n *         2. One 2D array of shape (len(seeds), len(metapath)) with edge IDs.\n *            The paths that terminated early are padded with -1.\n *         3. One 1D array of shape (len(metapath) + 1) with node type IDs.\n */\nstd::tuple<IdArray, IdArray, TypeArray> RandomWalkWithRestart(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, double restart_prob);\n\n/**\n * @brief Metapath-based random walk with stepwise restart probability. Useful\n *        for PinSAGE-like models.\n * @param hg The heterograph.\n * @param seeds A 1D array of seed nodes, with the type the source type of the\n * first edge type in the metapath.\n * @param metapath A 1D array of edge types representing the metapath.\n * @param prob A vector of 1D float arrays, indicating the transition\n * probability of each edge by edge type. An empty float array assumes uniform\n * transition.\n * @param restart_prob Restart probability array which has the same number of\n * elements as \\c metapath, indicating the probability to terminate after\n * transition.\n * @return A pair of\n *         1. One 2D array of shape (len(seeds), len(metapath) + 1) with node\n *            IDs. The paths that terminated early are padded with -1.\n *         2. One 2D array of shape (len(seeds), len(metapath)) with edge IDs.\n *            The paths that terminated early are padded with -1.\n *         3. One 1D array of shape (len(metapath) + 1) with node type IDs.\n */\nstd::tuple<IdArray, IdArray, TypeArray> RandomWalkWithStepwiseRestart(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, FloatArray restart_prob);\n\n};  // namespace sampling\n\n};  // namespace dgl\n\n#endif  // DGL_SAMPLING_RANDOMWALKS_H_\n"
  },
  {
    "path": "include/dgl/scheduler.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file dgl/scheduler.h\n * @brief Operations on graph index.\n */\n#ifndef DGL_SCHEDULER_H_\n#define DGL_SCHEDULER_H_\n\n#include <vector>\n\n#include \"runtime/ndarray.h\"\n\nnamespace dgl {\n\ntypedef dgl::runtime::NDArray IdArray;\n\nnamespace sched {\n\n/**\n * @brief Generate degree bucketing schedule\n * @tparam IdType Graph's index data type, can be int32_t or int64_t\n * @param msg_ids The edge id for each message\n * @param vids The destination vertex for each message\n * @param recv_ids The recv nodes (for checking zero degree nodes)\n * @note If there are multiple messages going into the same destination vertex,\n *       then there will be multiple copies of the destination vertex in vids.\n * @return a vector of 5 IdArrays for degree bucketing. The 5 arrays are:\n *         degrees: degrees for each bucket\n *         nids: destination node ids\n *         nid_section: number of nodes in each bucket (used to split nids)\n *         mids: message ids\n *         mid_section: number of messages in each bucket (used to split mids)\n */\ntemplate <class IdType>\nstd::vector<IdArray> DegreeBucketing(\n    const IdArray& msg_ids, const IdArray& vids, const IdArray& recv_ids);\n\n/**\n * @brief Generate degree bucketing schedule for group_apply edge\n * @tparam IdType Graph's index data type, can be int32_t or int64_t\n * @param uids One end vertex of edge by which edges are grouped\n * @param vids The other end vertex of edge\n * @param eids Edge ids\n * @note This function always generate group_apply schedule based on degrees of\n *       nodes in uids. Therefore, if group_apply by source nodes, then uids\n *       should be source. If group_apply by destination nodes, then uids\n *       should be destination.\n * @return a vector of 5 IdArrays for degree bucketing. 
The 5 arrays are:\n *         degrees: degrees for each bucket\n *         new_uids: uids reordered by degree bucket\n *         new_vids: vids reordered by degree bucket\n *         new_eids: eids reordered by degree bucket\n *         sections: number of edges in each degree bucket (used to partition\n *                   new_uids, new_vids, and new_eids)\n */\ntemplate <class IdType>\nstd::vector<IdArray> GroupEdgeByNodeDegree(\n    const IdArray& uids, const IdArray& vids, const IdArray& eids);\n\n}  // namespace sched\n\n}  // namespace dgl\n\n#endif  // DGL_SCHEDULER_H_\n"
  },
  {
    "path": "include/dgl/transform.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file dgl/transform.h\n * @brief DGL graph transformations\n */\n\n#ifndef DGL_TRANSFORM_H_\n#define DGL_TRANSFORM_H_\n\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"array.h\"\n#include \"base_heterograph.h\"\n\nnamespace dgl {\n\nnamespace transform {\n\n/**\n * @brief Given a list of graphs, remove the common nodes that do not have\n * inbound and outbound edges.\n *\n * The graphs should have identical node ID space (i.e. should have the same set\n * of nodes, including types and IDs).\n *\n * @param graphs The list of graphs.\n * @param always_preserve The list of nodes to preserve regardless of whether\n * the inbound or outbound edges exist.\n *\n * @return A pair.  The first element is the list of compacted graphs, and the\n * second element is the mapping from the compacted graphs and the original\n * graph.\n */\nstd::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphs(\n    const std::vector<HeteroGraphPtr> &graphs,\n    const std::vector<IdArray> &always_preserve);\n\n/**\n * @brief Convert a graph into a bipartite-structured graph for message passing.\n *\n * Specifically, we create one node type \\c ntype_l on the \"left\" side and\n * another node type \\c ntype_r on the \"right\" side for each node type \\c ntype.\n * The nodes of type \\c ntype_r would contain the nodes designated by the\n * caller, and node type \\c ntype_l would contain the nodes that has an edge\n * connecting to one of the designated nodes.\n *\n * The nodes of \\c ntype_l would also contain the nodes in node type \\c ntype_r.\n *\n * This function is often used for constructing a series of dependency graphs\n * for multi-layer message passing, where we first construct a series of\n * frontier graphs on the original node space, and run the following to get the\n * bipartite graph needed for message passing with each GNN layer:\n *\n * <code>\n *     bipartites = [None] * 
len(num_layers)\n *     for l in reversed(range(len(layers))):\n *         bipartites[l], seeds = to_bipartite(frontier[l], seeds)\n *     x = graph.ndata[\"h\"][seeds]\n *     for g, layer in zip(bipartites, layers):\n *         x_src = x\n *         x_dst = x[:len(g.dsttype)]\n *         x = sageconv(g, (x_src, x_dst))\n *     output = x\n * </code>\n *\n * @param graph The graph.\n * @param rhs_nodes Designated nodes that would appear on the right side.\n * @param include_rhs_in_lhs If false, do not include the nodes of node type \\c\n * ntype_r in \\c ntype_l.\n *\n * @return A triplet containing\n *         * The bipartite-structured graph,\n *         * The induced node from the left side for each graph,\n *         * The induced edges.\n *\n * @note If include_rhs_in_lhs is true, then for each node type \\c ntype, the\n * nodes in rhs_nodes[ntype] would always appear first in the nodes of type \\c\n * ntype_l in the new graph.\n */\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> ToBlock(\n    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,\n    bool include_rhs_in_lhs);\n\n/**\n * @brief Convert a multigraph to a simple graph.\n *\n * @return A triplet of\n * * @c hg : The said simple graph.\n * * @c count : The array of edge occurrences per edge type.\n * * @c edge_map : The mapping from original edge IDs to new edge IDs per edge\n * type.\n *\n * @note Example: consider a graph with the following edges\n *\n *     [(0, 1), (1, 3), (2, 2), (1, 3), (1, 4), (1, 4)]\n *\n * Then ToSimpleGraph(g) would yield the following elements:\n *\n * * The first element would be the simple graph itself with the following edges\n *\n *       [(0, 1), (1, 3), (1, 4), (2, 2)]\n *\n * * The second element is an array \\c count.  
\\c count[i] stands for the number\n * of edges connecting simple_g.src[i] and simple_g.dst[i] in the original\n * graph.\n *\n *       count[0] = [1, 2, 2, 1]\n *\n * * One can find the mapping between edges from the original graph to the new\n * simple graph.\n *\n *       edge_map[0] = [0, 1, 3, 1, 2, 2]\n */\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>\nToSimpleGraph(const HeteroGraphPtr graph);\n\n/**\n * @brief Remove edges from a graph.\n *\n * @param graph The graph.\n * @param eids The edge IDs to remove per edge type.\n *\n * @return A pair of the graph with edges removed, as well as the edge ID\n * mapping from the original graph to the new graph per edge type.\n */\nstd::pair<HeteroGraphPtr, std::vector<IdArray>> RemoveEdges(\n    const HeteroGraphPtr graph, const std::vector<IdArray> &eids);\n\n};  // namespace transform\n\n};  // namespace dgl\n\n#endif  // DGL_TRANSFORM_H_\n"
  },
  {
    "path": "include/dgl/zerocopy_serializer.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file rpc/shared_mem_serializer.h\n * @brief headers for serializer.\n */\n#ifndef DGL_ZEROCOPY_SERIALIZER_H_\n#define DGL_ZEROCOPY_SERIALIZER_H_\n\n#include <dgl/runtime/ndarray.h>\n#include <dmlc/io.h>\n#include <dmlc/memory_io.h>\n#include <dmlc/serializer.h>\n\n#include <deque>\n#include <memory>\n#include <queue>\n#include <string>\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"dmlc/logging.h\"\n\nnamespace dgl {\n\n/**\n *\n * StreamWithBuffer is backed up by dmlc::MemoryFixedSizeStream or\n * dmlc::MemoryStringStream. This class supports serializing and deserializing\n * NDArrays stored in shared memory. If the stream is created for\n * sending/recving data through network, the data pointer of the NDArray will be\n * transmitted directly without and copy. Otherwise, the stream is for\n * sending/recving data to another process on the same machine, so if an NDArray\n * is stored in shared memory, it will just record the shared memory name\n * instead of the actual data buffer.\n *\n * For example:\n *\n * std::string blob;\n * // Send to local\n * StreamWithBuffer strm(&blob, false);\n * // Send to remote\n * StreamWithBuffer strm(&blob, true);\n * // Receive from local\n * StreamWithBuffer strm(&blob, false);\n * // Receive from remote\n * std::vector<void*> ptr_list\n * StreamWithBuffer strm(&blob, ptr_list);\n */\nclass StreamWithBuffer : public dmlc::SeekStream {\n public:\n  // Buffer type. 
Storing NDArray to maintain the reference counting to ensure\n  // the liveness of data pointer\n  struct Buffer {\n    dgl::runtime::NDArray tensor = dgl::runtime::NDArray();\n    void* data = nullptr;\n    int64_t size = 0;\n\n    Buffer(const dgl::runtime::NDArray& tensor, void* data, int64_t data_size)\n        : tensor(tensor), data(data), size(data_size) {}\n\n    explicit Buffer(void* data) : data(data) {}\n  };\n\n  /**\n   * @brief This constructor is for writing scenario or reading from local\n   * machine\n   * @param strm The backup stream to write/load from\n   * @param send_to_remote Whether this stream will be deserialized at remote\n   * machine or the local machine. If true, will record the data pointer into\n   * buffer list.\n   */\n  StreamWithBuffer(std::unique_ptr<dmlc::SeekStream> strm, bool send_to_remote)\n      : strm_(std::move(strm)),\n        buffer_list_(),\n        send_to_remote_(send_to_remote) {}\n  /**\n   * @brief This constructor is for reading from remote\n   * @param strm The stream to write/load from zerocopy write/load\n   * @param data_ptr_list list of pointer to reconstruct NDArray\n   *\n   * For example:\n   * std::string blob;\n   * std::vector<void*> data_ptr_list;\n   * // Read from remote sended pointer list\n   * StreamWithBuffer buf_strm(&blob, data_ptr_list)\n   */\n  StreamWithBuffer(\n      std::unique_ptr<dmlc::SeekStream> strm,\n      const std::vector<void*>& data_ptr_list)\n      : strm_(std::move(strm)), send_to_remote_(true) {\n    for (void* data : data_ptr_list) {\n      buffer_list_.emplace_back(data);\n    }\n  }\n\n  /**\n   * @brief Construct stream backed up by string\n   * @param blob The string to write/load from zerocopy write/load\n   * @param send_to_remote Whether this stream will be deserialized at remote\n   * machine or the local machine. 
If true, will record the data pointer into\n   * buffer list.\n   */\n  StreamWithBuffer(std::string* blob, bool send_to_remote)\n      : strm_(new dmlc::MemoryStringStream(blob)),\n        send_to_remote_(send_to_remote) {}\n\n  /**\n   * @brief Construct stream backed up by string\n   * @param p_buffer buffer pointer\n   * @param size buffer size\n   * @param send_to_remote Whether this stream will be deserialized at remote\n   * machine or the local machine. If true, will record the data pointer into\n   * buffer list.\n   */\n  StreamWithBuffer(char* p_buffer, size_t size, bool send_to_remote)\n      : strm_(new dmlc::MemoryFixedSizeStream(p_buffer, size)),\n        send_to_remote_(send_to_remote) {}\n\n  /**\n   * @brief Construct stream backed up by string, and reconstruct NDArray\n   * from data_ptr_list\n   * @param blob The string to write/load from zerocopy write/load\n   * @param data_ptr_list pointer list for NDArrays to deconstruct from\n   */\n  StreamWithBuffer(std::string* blob, const std::vector<void*>& data_ptr_list)\n      : strm_(new dmlc::MemoryStringStream(blob)), send_to_remote_(true) {\n    for (void* data : data_ptr_list) {\n      buffer_list_.emplace_back(data);\n    }\n  }\n\n  /**\n   * @brief Construct stream backed up by string, and reconstruct NDArray\n   * from data_ptr_list\n   * @param p_buffer buffer pointer\n   * @param size buffer size\n   * @param data_ptr_list pointer list for NDArrays to deconstruct from\n   */\n  StreamWithBuffer(\n      char* p_buffer, size_t size, const std::vector<void*>& data_ptr_list)\n      : strm_(new dmlc::MemoryFixedSizeStream(p_buffer, size)),\n        send_to_remote_(true) {\n    for (void* data : data_ptr_list) {\n      buffer_list_.emplace_back(data);\n    }\n  }\n\n  // delegate methods to strm_\n  virtual size_t Read(void* ptr, size_t size) { return strm_->Read(ptr, size); }\n  virtual void Write(const void* ptr, size_t size) { strm_->Write(ptr, size); }\n  virtual void Seek(size_t pos) { 
strm_->Seek(pos); }\n  virtual size_t Tell(void) { return strm_->Tell(); }\n\n  using dmlc::Stream::Read;\n  using dmlc::Stream::Write;\n\n  /**\n   * @brief push NDArray into stream\n   * If send_to_remote=true, the NDArray will be saved to the buffer list\n   * If send_to_remote=false, the NDArray will be saved to the backing string\n   */\n  void PushNDArray(const runtime::NDArray& tensor);\n\n  /**\n   * @brief pop NDArray from stream\n   * If send_to_remote=true, the NDArray will be reconstructed from buffer list\n   * If send_to_remote=false, the NDArray will be reconstructed from shared\n   * memory\n   */\n  dgl::runtime::NDArray PopNDArray();\n\n  /**\n   * @brief Get whether this stream is for remote usage\n   */\n  bool send_to_remote() { return send_to_remote_; }\n\n  /**\n   * @brief Get underlying buffer list\n   */\n  const std::deque<Buffer>& buffer_list() const { return buffer_list_; }\n\n private:\n  std::unique_ptr<dmlc::SeekStream> strm_;\n  std::deque<Buffer> buffer_list_;\n  bool send_to_remote_;\n};  // class StreamWithBuffer\n\n}  // namespace dgl\n\n#endif  // DGL_ZEROCOPY_SERIALIZER_H_\n"
  },
  {
    "path": "notebooks/graphbolt/walkthrough.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"e1qfiZMOJYYv\"\n      },\n      \"source\": [\n        \"# Graphbolt Quick Walkthrough\\n\",\n        \"\\n\",\n        \"The tutorial provides a quick walkthrough of operators provided by the `dgl.graphbolt` package, and illustrates how to create a GNN datapipe with the package. To learn more details about Stochastic Training of GNNs, please read the [materials](https://docs.dgl.ai/tutorials/large/index.html) provided by DGL.\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"fWiaC1WaDE-W\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Install required packages.\\n\",\n        \"import os\\n\",\n        \"import torch\\n\",\n        \"os.environ['TORCH'] = torch.__version__\\n\",\n        \"os.environ['DGLBACKEND'] = \\\"pytorch\\\"\\n\",\n        \"\\n\",\n        \"# Install the CPU version.\\n\",\n        \"device = torch.device(\\\"cpu\\\")\\n\",\n        \"!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\\n\",\n        \"\\n\",\n        \"try:\\n\",\n        \"    import dgl.graphbolt as gb\\n\",\n        \"    installed = True\\n\",\n        \"except ImportError as error:\\n\",\n        \"    installed = False\\n\",\n        \"    print(error)\\n\",\n        \"print(\\\"DGL installed!\\\" if installed else \\\"DGL not found!\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": 
\"8O7PfsY4sPoN\"\n      },\n      \"source\": [\n        \"## Dataset\\n\",\n        \"\\n\",\n        \"The dataset has three primary components. *1*. An itemset, which can be iterated over as the training target. *2*. A sampling graph, which is used by the subgraph sampling algorithm to generate a subgraph. *3*. A feature store, which stores node, edge, and graph features.\\n\",\n        \"\\n\",\n        \"* The **Itemset** is created from iterable data or tuple of iterable data.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"g73ZAbMQsSgV\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"seeds = torch.tensor(\\n\",\n        \"    [[7, 0], [6, 0], [1, 3], [3, 3], [2, 4], [8, 4], [1, 4], [2, 4], [1, 5],\\n\",\n        \"     [9, 6], [0, 6], [8, 6], [7, 7], [7, 7], [4, 7], [6, 8], [5, 8], [9, 9],\\n\",\n        \"     [4, 9], [4, 9], [5, 9], [9, 9], [5, 9], [9, 9], [7, 9]]\\n\",\n        \")\\n\",\n        \"item_set = gb.ItemSet(seeds, names=\\\"seeds\\\")\\n\",\n        \"print(list(item_set))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Lqty9p4cs0OR\"\n      },\n      \"source\": [\n        \"* The **SamplingGraph** is used by the subgraph sampling algorithm to generate a subgraph. In graphbolt, we provide a canonical solution, the FusedCSCSamplingGraph, which achieves state-of-the-art time and space efficiency on CPU sampling. 
However, this requires enough CPU memory to host all FusedCSCSamplingGraph objects in memory.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"jDjY149xs3PI\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\\n\",\n        \"indices = torch.tensor(\\n\",\n        \"    [7, 6, 1, 3, 2, 8, 1, 2, 1, 9, 0, 8, 7, 7, 4, 6, 5, 9, 4, 4, 5, 9, 5, 9, 7]\\n\",\n        \")\\n\",\n        \"num_edges = 25\\n\",\n        \"eid = torch.arange(num_edges)\\n\",\n        \"edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\\n\",\n        \"graph = gb.fused_csc_sampling_graph(indptr, indices, edge_attributes=edge_attributes)\\n\",\n        \"print(graph)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"mNp2S2_Vs8af\"\n      },\n      \"source\": [\n        \"* The **FeatureStore** is used to store node, edge, and graph features. 
In graphbolt, we provide the TorchBasedFeature and related optimizations, such as the GPUCachedFeature, for different use cases.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"zIU6KWe1Sm2g\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"num_nodes = 10\\n\",\n        \"num_edges = 25\\n\",\n        \"node_feature_data = torch.rand((num_nodes, 2))\\n\",\n        \"edge_feature_data = torch.rand((num_edges, 3))\\n\",\n        \"node_feature = gb.TorchBasedFeature(node_feature_data)\\n\",\n        \"edge_feature = gb.TorchBasedFeature(edge_feature_data)\\n\",\n        \"features = {\\n\",\n        \"    (\\\"node\\\", None, \\\"feat\\\") : node_feature,\\n\",\n        \"    (\\\"edge\\\", None, \\\"feat\\\") : edge_feature,\\n\",\n        \"}\\n\",\n        \"feature_store = gb.BasicFeatureStore(features)\\n\",\n        \"print(feature_store)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Oh2ockWWoXQ0\"\n      },\n      \"source\": [\n        \"## DataPipe\\n\",\n        \"\\n\",\n        \"The DataPipe in Graphbolt is an extension of the PyTorch DataPipe, but it is specifically designed to address the challenges of training graph neural networks (GNNs). Each stage of the data pipeline loads data from different sources and can be combined with other stages to create more complex data pipelines. 
The intermediate data will be stored in **MiniBatch** data packs.\\n\",\n        \"\\n\",\n        \"* **ItemSampler** iterates over input **Itemset** and create subsets.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"XtqPDprrogR7\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"datapipe = gb.ItemSampler(item_set, batch_size=3, shuffle=False)\\n\",\n        \"print(next(iter(datapipe)))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"BjkAK37xopp1\"\n      },\n      \"source\": [\n        \"* **NegativeSampler** generate negative samples and return a mix of positive and negative samples.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"PrFpGoOGopJy\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"datapipe = datapipe.sample_uniform_negative(graph, 1)\\n\",\n        \"print(next(iter(datapipe)))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fYO_oIwkpmb3\"\n      },\n      \"source\": [\n        \"* **SubgraphSampler** samples a subgraph from a given set of nodes from a larger graph.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"4UsY3PL3ppYV\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"fanouts = torch.tensor([1])\\n\",\n        \"datapipe = datapipe.sample_neighbor(graph, [fanouts])\\n\",\n        \"print(next(iter(datapipe)))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"0uIydsjUqMA0\"\n      },\n      \"source\": [\n        \"* **FeatureFetcher** fetchs features for node/edge in graphbolt.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n   
   \"metadata\": {\n        \"id\": \"YAj8G7YBqO6G\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"datapipe = datapipe.fetch_feature(feature_store, node_feature_keys=[\\\"feat\\\"], edge_feature_keys=[\\\"feat\\\"])\\n\",\n        \"print(next(iter(datapipe)))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"hjBSLPRPrsD2\"\n      },\n      \"source\": [\n        \"* Copy the data to the GPU for training on the GPU.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"RofiZOUMqt_u\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"datapipe = datapipe.copy_to(device=device)\\n\",\n        \"print(next(iter(datapipe)))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"xm9HnyHRvxXj\"\n      },\n      \"source\": [\n        \"## Exercise: Node classification\\n\",\n        \"\\n\",\n        \"Similarly, the following Dataset is created for node classification, can you implement the data pipeline for the dataset?\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"YV-mk-xAv78v\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Dataset for node classification.\\n\",\n        \"num_nodes = 10\\n\",\n        \"nodes = torch.arange(num_nodes)\\n\",\n        \"labels = torch.tensor([1, 2, 0, 2, 2, 0, 2, 2, 2, 2])\\n\",\n        \"item_set = gb.ItemSet((nodes, labels), names=(\\\"seeds\\\", \\\"labels\\\"))\\n\",\n        \"\\n\",\n        \"indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\\n\",\n        \"indices = torch.tensor(\\n\",\n        \"    [7, 6, 1, 3, 2, 8, 1, 2, 1, 9, 0, 8, 7, 7, 4, 6, 5, 9, 4, 4, 5, 9, 5, 9, 7]\\n\",\n        \")\\n\",\n        \"eid = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\\n\",\n        \"  
                  14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])\\n\",\n        \"edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\\n\",\n        \"graph = gb.from_fused_csc(indptr, indices, None, None, edge_attributes, None)\\n\",\n        \"\\n\",\n        \"num_nodes = 10\\n\",\n        \"num_edges = 25\\n\",\n        \"node_feature_data = torch.rand((num_nodes, 2))\\n\",\n        \"edge_feature_data = torch.rand((num_edges, 3))\\n\",\n        \"node_feature = gb.TorchBasedFeature(node_feature_data)\\n\",\n        \"edge_feature = gb.TorchBasedFeature(edge_feature_data)\\n\",\n        \"features = {\\n\",\n        \"    (\\\"node\\\", None, \\\"feat\\\") : node_feature,\\n\",\n        \"    (\\\"edge\\\", None, \\\"feat\\\") : edge_feature,\\n\",\n        \"}\\n\",\n        \"feature_store = gb.BasicFeatureStore(features)\\n\",\n        \"\\n\",\n        \"# Datapipe.\\n\",\n        \"...\\n\",\n        \"print(next(iter(datapipe)))\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"accelerator\": \"GPU\",\n    \"colab\": {\n      \"collapsed_sections\": [\n        \"BjkAK37xopp1\"\n      ],\n      \"gpuType\": \"T4\",\n      \"private_outputs\": true,\n      \"provenance\": []\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "notebooks/sparse/gcn.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"provenance\": []\n    },\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    },\n    \"accelerator\": \"GPU\",\n    \"gpuClass\": \"standard\"\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"# Building a Graph Convolutional Network Using Sparse Matrices\\n\",\n        \"\\n\",\n        \"This tutorial illustrates step-by-step how to write and train a Graph Convolutional Network ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)) using DGL's sparse matrix APIs.\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/sparse/gcn.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/sparse/gcn.ipynb)\"\n      ],\n      \"metadata\": {\n        \"id\": \"_iqWrPwxtZr6\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# Install required packages.\\n\",\n        \"import os\\n\",\n        \"import torch\\n\",\n        \"os.environ['TORCH'] = torch.__version__\\n\",\n        \"os.environ['DGLBACKEND'] = \\\"pytorch\\\"\\n\",\n        \"\\n\",\n        \"# Uncomment below to install required packages. 
If the CUDA version is not 11.8,\\n\",\n        \"# check the https://www.dgl.ai/pages/start.html to find the supported CUDA\\n\",\n        \"# version and corresponding command to install DGL.\\n\",\n        \"#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null\\n\",\n        \"\\n\",\n        \"try:\\n\",\n        \"    import dgl\\n\",\n        \"    installed = True\\n\",\n        \"except ImportError:\\n\",\n        \"    installed = False\\n\",\n        \"print(\\\"DGL installed!\\\" if installed else \\\"DGL not found!\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"FTqB360eRvya\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Graph Convolutional Layer\\n\",\n        \"\\n\",\n        \"Mathematically, the graph convolutional layer is defined as:\\n\",\n        \"\\n\",\n        \"$$f(X^{(l)}, A) = \\\\sigma(\\\\bar{D}^{-\\\\frac{1}{2}}\\\\bar{A}\\\\bar{D}^{-\\\\frac{1}{2}}X^{(l)}W^{(l)})$$\\n\",\n        \"\\n\",\n        \"with $\\\\bar{A} = A + I$, where $A$ denotes the adjacency matrix and $I$ denotes the identity matrix, $\\\\bar{D}$ refers to the diagonal node degree matrix of $\\\\bar{A}$ and $W^{(l)}$ denotes a trainable weight matrix. $\\\\sigma$ refers to a non-linear activation (e.g. relu).\\n\",\n        \"\\n\",\n        \"The code below shows how to implement it using the `dgl.sparse` package. The core operations are:\\n\",\n        \"\\n\",\n        \"* `dgl.sparse.identity` creates the identity matrix $I$.\\n\",\n        \"* The augmented adjacency matrix $\\\\bar{A}$ is then computed by adding the identity matrix to the adjacency matrix $A$.\\n\",\n        \"* `A_hat.sum(0)` aggregates the augmented adjacency matrix $\\\\bar{A}$ along the first dimension which gives the degree vector of the augmented graph. 
The diagonal degree matrix $\\\\bar{D}$ is then created by `dgl.sparse.diag`.\\n\",\n        \"* Compute $\\\\bar{D}^{-\\\\frac{1}{2}}$.\\n\",\n        \"* `D_hat_invsqrt @ A_hat @ D_hat_invsqrt` computes the convolution matrix which is then multiplied by the linearly transformed node features.\"\n      ],\n      \"metadata\": {\n        \"id\": \"r3qB1atg_ld0\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"import torch\\n\",\n        \"import torch.nn as nn\\n\",\n        \"import torch.nn.functional as F\\n\",\n        \"\\n\",\n        \"import dgl.sparse as dglsp\\n\",\n        \"\\n\",\n        \"class GCNLayer(nn.Module):\\n\",\n        \"    def __init__(self, in_size, out_size):\\n\",\n        \"        super(GCNLayer, self).__init__()\\n\",\n        \"        self.W = nn.Linear(in_size, out_size)\\n\",\n        \"\\n\",\n        \"    def forward(self, A, X):\\n\",\n        \"        ########################################################################\\n\",\n        \"        # (HIGHLIGHT) Compute the symmetrically normalized adjacency matrix with\\n\",\n        \"        # Sparse Matrix API\\n\",\n        \"        ########################################################################\\n\",\n        \"        I = dglsp.identity(A.shape)\\n\",\n        \"        A_hat = A + I\\n\",\n        \"        D_hat = dglsp.diag(A_hat.sum(0))\\n\",\n        \"        D_hat_invsqrt = D_hat ** -0.5\\n\",\n        \"        return D_hat_invsqrt @ A_hat @ D_hat_invsqrt @ self.W(X)\"\n      ],\n      \"metadata\": {\n        \"id\": \"Y4I4EhHQ_kKb\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"A Graph Convolutional Network is then defined by stacking this layer.\"\n      ],\n      \"metadata\": {\n        \"id\": \"bvP7O2IwV_c7\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# 
Create a GCN with the GCN layer.\\n\",\n        \"class GCN(nn.Module):\\n\",\n        \"    def __init__(self, in_size, out_size, hidden_size):\\n\",\n        \"        super(GCN, self).__init__()\\n\",\n        \"        self.conv1 = GCNLayer(in_size, hidden_size)\\n\",\n        \"        self.conv2 = GCNLayer(hidden_size, out_size)\\n\",\n        \"\\n\",\n        \"    def forward(self, A, X):\\n\",\n        \"        X = self.conv1(A, X)\\n\",\n        \"        X = F.relu(X)\\n\",\n        \"        return self.conv2(A, X)\"\n      ],\n      \"metadata\": {\n        \"id\": \"BHX3vRjDWJTO\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Training the GCN\\n\",\n        \"\\n\",\n        \"We then train the GCN model on the Cora dataset for node classification. Note that since the model expects an adjacency matrix as the first argument, we first construct the adjacency matrix from the graph using the `dgl.sparse.spmatrix` API, which returns a DGL `SparseMatrix` object.\"\n      ],\n      \"metadata\": {\n        \"id\": \"2Qw7fTdGNnEp\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"def evaluate(g, pred):\\n\",\n        \"    label = g.ndata[\\\"label\\\"]\\n\",\n        \"    val_mask = g.ndata[\\\"val_mask\\\"]\\n\",\n        \"    test_mask = g.ndata[\\\"test_mask\\\"]\\n\",\n        \"\\n\",\n        \"    # Compute accuracy on validation/test set.\\n\",\n        \"    val_acc = (pred[val_mask] == label[val_mask]).float().mean()\\n\",\n        \"    test_acc = (pred[test_mask] == label[test_mask]).float().mean()\\n\",\n        \"    return val_acc, test_acc\\n\",\n        \"\\n\",\n        \"def train(model, g):\\n\",\n        \"    features = g.ndata[\\\"feat\\\"]\\n\",\n        \"    label = g.ndata[\\\"label\\\"]\\n\",\n        \"    train_mask = g.ndata[\\\"train_mask\\\"]\\n\",\n        \"    optimizer = 
torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)\\n\",\n        \"    loss_fcn = nn.CrossEntropyLoss()\\n\",\n        \"\\n\",\n        \"    # Preprocess to get the adjacency matrix of the graph.\\n\",\n        \"    indices = torch.stack(g.edges())\\n\",\n        \"    N = g.num_nodes()\\n\",\n        \"    A = dglsp.spmatrix(indices, shape=(N, N))\\n\",\n        \"\\n\",\n        \"    for epoch in range(100):\\n\",\n        \"        model.train()\\n\",\n        \"\\n\",\n        \"        # Forward.\\n\",\n        \"        logits = model(A, features)\\n\",\n        \"\\n\",\n        \"        # Compute loss with nodes in the training set.\\n\",\n        \"        loss = loss_fcn(logits[train_mask], label[train_mask])\\n\",\n        \"\\n\",\n        \"        # Backward.\\n\",\n        \"        optimizer.zero_grad()\\n\",\n        \"        loss.backward()\\n\",\n        \"        optimizer.step()\\n\",\n        \"\\n\",\n        \"        # Compute prediction.\\n\",\n        \"        pred = logits.argmax(dim=1)\\n\",\n        \"\\n\",\n        \"        # Evaluate the prediction.\\n\",\n        \"        val_acc, test_acc = evaluate(g, pred)\\n\",\n        \"        if epoch % 5 == 0:\\n\",\n        \"            print(\\n\",\n        \"                f\\\"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}\\\"\\n\",\n        \"                f\\\", test acc: {test_acc:.3f}\\\"\\n\",\n        \"            )\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"# Load graph from the existing dataset.\\n\",\n        \"dataset = dgl.data.CoraGraphDataset()\\n\",\n        \"g = dataset[0]\\n\",\n        \"\\n\",\n        \"# Create model.\\n\",\n        \"feature = g.ndata['feat']\\n\",\n        \"in_size = feature.shape[1]\\n\",\n        \"out_size = dataset.num_classes\\n\",\n        \"gcn_model = GCN(in_size, out_size, 16)\\n\",\n        \"\\n\",\n        \"# Kick off training.\\n\",\n        \"train(gcn_model, g)\"\n      ],\n      
\"metadata\": {\n        \"id\": \"5Sp1B1_QHgC2\",\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"outputId\": \"552e2c22-44f4-4495-c7f9-a57f13484270\"\n      },\n      \"execution_count\": null,\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stdout\",\n          \"text\": [\n            \"Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...\\n\",\n            \"Extracting file to /root/.dgl/cora_v2\\n\",\n            \"Finished data loading and preprocessing.\\n\",\n            \"  NumNodes: 2708\\n\",\n            \"  NumEdges: 10556\\n\",\n            \"  NumFeats: 1433\\n\",\n            \"  NumClasses: 7\\n\",\n            \"  NumTrainingSamples: 140\\n\",\n            \"  NumValidationSamples: 500\\n\",\n            \"  NumTestSamples: 1000\\n\",\n            \"Done saving data into cached files.\\n\",\n            \"In epoch 0, loss: 1.954, val acc: 0.114, test acc: 0.103\\n\",\n            \"In epoch 5, loss: 1.921, val acc: 0.158, test acc: 0.147\\n\",\n            \"In epoch 10, loss: 1.878, val acc: 0.288, test acc: 0.283\\n\",\n            \"In epoch 15, loss: 1.822, val acc: 0.344, test acc: 0.353\\n\",\n            \"In epoch 20, loss: 1.751, val acc: 0.388, test acc: 0.389\\n\",\n            \"In epoch 25, loss: 1.663, val acc: 0.406, test acc: 0.410\\n\",\n            \"In epoch 30, loss: 1.562, val acc: 0.472, test acc: 0.481\\n\",\n            \"In epoch 35, loss: 1.450, val acc: 0.558, test acc: 0.573\\n\",\n            \"In epoch 40, loss: 1.333, val acc: 0.636, test acc: 0.641\\n\",\n            \"In epoch 45, loss: 1.216, val acc: 0.684, test acc: 0.683\\n\",\n            \"In epoch 50, loss: 1.102, val acc: 0.726, test acc: 0.713\\n\",\n            \"In epoch 55, loss: 0.996, val acc: 0.740, test acc: 0.740\\n\",\n            \"In epoch 60, loss: 0.899, val acc: 0.754, test acc: 0.760\\n\",\n            \"In epoch 
65, loss: 0.813, val acc: 0.762, test acc: 0.771\\n\",\n            \"In epoch 70, loss: 0.737, val acc: 0.768, test acc: 0.781\\n\",\n            \"In epoch 75, loss: 0.671, val acc: 0.776, test acc: 0.786\\n\",\n            \"In epoch 80, loss: 0.614, val acc: 0.784, test acc: 0.790\\n\",\n            \"In epoch 85, loss: 0.566, val acc: 0.780, test acc: 0.788\\n\",\n            \"In epoch 90, loss: 0.524, val acc: 0.780, test acc: 0.791\\n\",\n            \"In epoch 95, loss: 0.489, val acc: 0.772, test acc: 0.795\\n\"\n          ]\n        }\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"*Check out the full example script* [here](https://github.com/dmlc/dgl/blob/master/examples/sparse/gcn.py).\"\n      ],\n      \"metadata\": {\n        \"id\": \"yQnJZvE9ZduM\"\n      }\n    }\n  ]\n}\n"
  },
  {
    "path": "notebooks/sparse/graph_diffusion.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"provenance\": [],\n      \"toc_visible\": true\n    },\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    },\n    \"gpuClass\": \"standard\",\n    \"accelerator\": \"GPU\"\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"# Graph Diffusion in Graph Neural Networks\\n\",\n        \"\\n\",\n        \"This tutorial first briefly introduces the diffusion process on graphs. It then illustrates how Graph Neural Networks can utilize this concept to enhance prediction power.\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/sparse/graph_diffusion.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/sparse/graph_diffusion.ipynb)\"\n      ],\n      \"metadata\": {\n        \"id\": \"SfdsDpOK7yOT\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"F6eQWmWn7lqh\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Install required packages.\\n\",\n        \"import os\\n\",\n        \"import torch\\n\",\n        \"os.environ['TORCH'] = torch.__version__\\n\",\n        \"os.environ['DGLBACKEND'] = \\\"pytorch\\\"\\n\",\n        \"\\n\",\n        \"# Uncomment below to install required packages. 
If the CUDA version is not 11.8,\\n\",\n        \"# check the https://www.dgl.ai/pages/start.html to find the supported CUDA\\n\",\n        \"# version and corresponding command to install DGL.\\n\",\n        \"#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null\\n\",\n        \"#!pip install --upgrade scipy networkx > /dev/null\\n\",\n        \"\\n\",\n        \"try:\\n\",\n        \"    import dgl\\n\",\n        \"    installed = True\\n\",\n        \"except ImportError:\\n\",\n        \"    installed = False\\n\",\n        \"print(\\\"DGL installed!\\\" if installed else \\\"Failed to install DGL!\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Graph Diffusion\\n\",\n        \"\\n\",\n        \"Diffusion describes the process of substances moving from one region to another. In the context of graphs, the diffusing substances (e.g., real-valued signals) travel along edges from node to node.\\n\",\n        \"\\n\",\n        \"Mathematically, let $\\\\vec x$ be the vector of node signals. Then a graph diffusion operation can be defined as:\\n\",\n        \"\\n\",\n        \"$$\\n\",\n        \"\\\\vec{y} = \\\\tilde{A} \\\\vec{x}\\n\",\n        \"$$\\n\",\n        \"\\n\",\n        \"where $\\\\tilde{A}$ is the **diffusion matrix** that is typically derived from the adjacency matrix of the graph. Although the selection of diffusion matrices may vary, the diffusion matrix is typically sparse and $\\\\tilde{A} \\\\vec{x}$ is thus a sparse-dense matrix multiplication.\\n\",\n        \"\\n\",\n        \"Let us understand it more with a simple example. 
First, we obtain the adjacency matrix of the famous [Karate Club Network](https://en.wikipedia.org/wiki/Zachary%27s_karate_club).\"\n      ],\n      \"metadata\": {\n        \"id\": \"iH6os3oFcyze\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"import dgl\\n\",\n        \"import dgl.sparse as dglsp\\n\",\n        \"from dgl.data import KarateClubDataset\\n\",\n        \"\\n\",\n        \"# Get the graph from DGL's builtin dataset.\\n\",\n        \"dataset = KarateClubDataset()\\n\",\n        \"dgl_g = dataset[0]\\n\",\n        \"\\n\",\n        \"# Get its adjacency matrix.\\n\",\n        \"indices = torch.stack(dgl_g.edges())\\n\",\n        \"N = dgl_g.num_nodes()\\n\",\n        \"A = dglsp.spmatrix(indices, shape=(N, N))\\n\",\n        \"print(A.to_dense())\"\n      ],\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"_TnCECJmBKJE\",\n        \"outputId\": \"d8b78f0b-3a1c-4a9e-bcc9-ed4df7b7b5b7\"\n      },\n      \"execution_count\": null,\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stdout\",\n          \"text\": [\n            \"tensor([[0., 1., 1.,  ..., 1., 0., 0.],\\n\",\n            \"        [1., 0., 1.,  ..., 0., 0., 0.],\\n\",\n            \"        [1., 1., 0.,  ..., 0., 1., 0.],\\n\",\n            \"        ...,\\n\",\n            \"        [1., 0., 0.,  ..., 0., 1., 1.],\\n\",\n            \"        [0., 0., 1.,  ..., 1., 0., 1.],\\n\",\n            \"        [0., 0., 0.,  ..., 1., 1., 0.]])\\n\"\n          ]\n        }\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"We use the graph convolution matrix from Graph Convolution Networks as the diffusion matrix in this example. 
The graph convolution matrix is defined as:\\n\",\n        \"\\n\",\n        \"$$\\\\tilde{A} = \\\\bar{D}^{-\\\\frac{1}{2}}\\\\bar{A}\\\\bar{D}^{-\\\\frac{1}{2}}$$\\n\",\n        \"\\n\",\n        \"with $\\\\bar{A} = A + I$, where $A$ denotes the adjacency matrix and $I$ denotes the identity matrix, $\\\\bar{D}$ refers to the diagonal node degree matrix of $\\\\bar{A}$.\"\n      ],\n      \"metadata\": {\n        \"id\": \"wJMT4oHOCCqJ\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# Compute graph convolution matrix.\\n\",\n        \"I = dglsp.identity(A.shape)\\n\",\n        \"A_hat = A + I\\n\",\n        \"D_hat = dglsp.diag(A_hat.sum(dim=1))\\n\",\n        \"D_hat_invsqrt = D_hat ** -0.5\\n\",\n        \"A_tilde = D_hat_invsqrt @ A_hat @ D_hat_invsqrt\\n\",\n        \"print(A_tilde.to_dense())\"\n      ],\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"JyzctBGaC_O5\",\n        \"outputId\": \"b03ef3dc-dcf5-494e-9191-30591d09f138\"\n      },\n      \"execution_count\": null,\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stdout\",\n          \"text\": [\n            \"tensor([[0.0588, 0.0767, 0.0731,  ..., 0.0917, 0.0000, 0.0000],\\n\",\n            \"        [0.0767, 0.1000, 0.0953,  ..., 0.0000, 0.0000, 0.0000],\\n\",\n            \"        [0.0731, 0.0953, 0.0909,  ..., 0.0000, 0.0836, 0.0000],\\n\",\n            \"        ...,\\n\",\n            \"        [0.0917, 0.0000, 0.0000,  ..., 0.1429, 0.1048, 0.0891],\\n\",\n            \"        [0.0000, 0.0000, 0.0836,  ..., 0.1048, 0.0769, 0.0654],\\n\",\n            \"        [0.0000, 0.0000, 0.0000,  ..., 0.0891, 0.0654, 0.0556]])\\n\"\n          ]\n        }\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"For node signals, we set all nodes but one to be zero.\"\n      ],\n      \"metadata\": {\n  
      \"id\": \"geYvWuUkDbiL\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# Initial node signals. All nodes except one are set to zero.\\n\",\n        \"X = torch.zeros(N)\\n\",\n        \"X[0] = 5.\\n\",\n        \"\\n\",\n        \"# Number of diffusion steps.\\n\",\n        \"r = 8\\n\",\n        \"\\n\",\n        \"# Record the signals after each diffusion step.\\n\",\n        \"results = [X]\\n\",\n        \"for _ in range(r):\\n\",\n        \"    X = A_tilde @ X\\n\",\n        \"    results.append(X)\"\n      ],\n      \"metadata\": {\n        \"id\": \"DXb0uKqXDZKb\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"The program below visualizes the diffusion process with animation. To play the animation, click the \\\"play\\\" icon. You will see how node features converge over time.\"\n      ],\n      \"metadata\": {\n        \"id\": \"TpqMz4muF2aO\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"import matplotlib.pyplot as plt\\n\",\n        \"import networkx as nx\\n\",\n        \"from IPython.display import HTML\\n\",\n        \"from matplotlib import animation\\n\",\n        \"\\n\",\n        \"nx_g = dgl_g.to_networkx().to_undirected()\\n\",\n        \"pos = nx.spring_layout(nx_g)\\n\",\n        \"\\n\",\n        \"fig, ax = plt.subplots()\\n\",\n        \"plt.close()\\n\",\n        \"\\n\",\n        \"def animate(i):\\n\",\n        \"    ax.cla()\\n\",\n        \"    # Color nodes based on their features.\\n\",\n        \"    nodes = nx.draw_networkx_nodes(nx_g, pos, ax=ax, node_size=200, node_color=results[i].tolist(), cmap=plt.cm.Blues)\\n\",\n        \"    # Set boundary color of the nodes.\\n\",\n        \"    nodes.set_edgecolor(\\\"#000000\\\")\\n\",\n        \"    nx.draw_networkx_edges(nx_g, pos, ax=ax)\\n\",\n        \"\\n\",\n        \"ani = 
animation.FuncAnimation(fig, animate, frames=len(results), interval=1000)\\n\",\n        \"HTML(ani.to_jshtml())\"\n      ],\n      \"metadata\": {\n        \"id\": \"eN3kmJ8nl7_z\",\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\",\n          \"height\": 386\n        },\n        \"outputId\": \"be93263e-2283-4db7-caff-2e15e75ceb02\"\n      },\n      \"execution_count\": null,\n      \"outputs\": [\n        {\n          \"output_type\": \"execute_result\",\n          \"data\": {\n            \"text/plain\": [\n              \"<IPython.core.display.HTML object>\"\n            ],\n            \"text/html\": [\n              \"\\n\",\n              \"<link rel=\\\"stylesheet\\\"\\n\",\n              \"href=\\\"https://maxcdn.bootstrapcdn.com/font-awesome/4.4.0/\\n\",\n              \"css/font-awesome.min.css\\\">\\n\",\n              \"<script language=\\\"javascript\\\">\\n\",\n              \"  function isInternetExplorer() {\\n\",\n              \"    ua = navigator.userAgent;\\n\",\n              \"    /* MSIE used to detect old browsers and Trident used to newer ones*/\\n\",\n              \"    return ua.indexOf(\\\"MSIE \\\") > -1 || ua.indexOf(\\\"Trident/\\\") > -1;\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  /* Define the Animation class */\\n\",\n              \"  function Animation(frames, img_id, slider_id, interval, loop_select_id){\\n\",\n              \"    this.img_id = img_id;\\n\",\n              \"    this.slider_id = slider_id;\\n\",\n              \"    this.loop_select_id = loop_select_id;\\n\",\n              \"    this.interval = interval;\\n\",\n              \"    this.current_frame = 0;\\n\",\n              \"    this.direction = 0;\\n\",\n              \"    this.timer = null;\\n\",\n              \"    this.frames = new Array(frames.length);\\n\",\n              \"\\n\",\n              \"    for (var i=0; i<frames.length; i++)\\n\",\n              \"    {\\n\",\n              \"    
 this.frames[i] = new Image();\\n\",\n              \"     this.frames[i].src = frames[i];\\n\",\n              \"    }\\n\",\n              \"    var slider = document.getElementById(this.slider_id);\\n\",\n              \"    slider.max = this.frames.length - 1;\\n\",\n              \"    if (isInternetExplorer()) {\\n\",\n              \"        // switch from oninput to onchange because IE <= 11 does not conform\\n\",\n              \"        // with W3C specification. It ignores oninput and onchange behaves\\n\",\n              \"        // like oninput. In contrast, Mircosoft Edge behaves correctly.\\n\",\n              \"        slider.setAttribute('onchange', slider.getAttribute('oninput'));\\n\",\n              \"        slider.setAttribute('oninput', null);\\n\",\n              \"    }\\n\",\n              \"    this.set_frame(this.current_frame);\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.get_loop_state = function(){\\n\",\n              \"    var button_group = document[this.loop_select_id].state;\\n\",\n              \"    for (var i = 0; i < button_group.length; i++) {\\n\",\n              \"        var button = button_group[i];\\n\",\n              \"        if (button.checked) {\\n\",\n              \"            return button.value;\\n\",\n              \"        }\\n\",\n              \"    }\\n\",\n              \"    return undefined;\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.set_frame = function(frame){\\n\",\n              \"    this.current_frame = frame;\\n\",\n              \"    document.getElementById(this.img_id).src =\\n\",\n              \"            this.frames[this.current_frame].src;\\n\",\n              \"    document.getElementById(this.slider_id).value = this.current_frame;\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.next_frame = function()\\n\",\n              \"  
{\\n\",\n              \"    this.set_frame(Math.min(this.frames.length - 1, this.current_frame + 1));\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.previous_frame = function()\\n\",\n              \"  {\\n\",\n              \"    this.set_frame(Math.max(0, this.current_frame - 1));\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.first_frame = function()\\n\",\n              \"  {\\n\",\n              \"    this.set_frame(0);\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.last_frame = function()\\n\",\n              \"  {\\n\",\n              \"    this.set_frame(this.frames.length - 1);\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.slower = function()\\n\",\n              \"  {\\n\",\n              \"    this.interval /= 0.7;\\n\",\n              \"    if(this.direction > 0){this.play_animation();}\\n\",\n              \"    else if(this.direction < 0){this.reverse_animation();}\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.faster = function()\\n\",\n              \"  {\\n\",\n              \"    this.interval *= 0.7;\\n\",\n              \"    if(this.direction > 0){this.play_animation();}\\n\",\n              \"    else if(this.direction < 0){this.reverse_animation();}\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.anim_step_forward = function()\\n\",\n              \"  {\\n\",\n              \"    this.current_frame += 1;\\n\",\n              \"    if(this.current_frame < this.frames.length){\\n\",\n              \"      this.set_frame(this.current_frame);\\n\",\n              \"    }else{\\n\",\n              \"      var loop_state = this.get_loop_state();\\n\",\n              \"      if(loop_state == \\\"loop\\\"){\\n\",\n              \"        
this.first_frame();\\n\",\n              \"      }else if(loop_state == \\\"reflect\\\"){\\n\",\n              \"        this.last_frame();\\n\",\n              \"        this.reverse_animation();\\n\",\n              \"      }else{\\n\",\n              \"        this.pause_animation();\\n\",\n              \"        this.last_frame();\\n\",\n              \"      }\\n\",\n              \"    }\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.anim_step_reverse = function()\\n\",\n              \"  {\\n\",\n              \"    this.current_frame -= 1;\\n\",\n              \"    if(this.current_frame >= 0){\\n\",\n              \"      this.set_frame(this.current_frame);\\n\",\n              \"    }else{\\n\",\n              \"      var loop_state = this.get_loop_state();\\n\",\n              \"      if(loop_state == \\\"loop\\\"){\\n\",\n              \"        this.last_frame();\\n\",\n              \"      }else if(loop_state == \\\"reflect\\\"){\\n\",\n              \"        this.first_frame();\\n\",\n              \"        this.play_animation();\\n\",\n              \"      }else{\\n\",\n              \"        this.pause_animation();\\n\",\n              \"        this.first_frame();\\n\",\n              \"      }\\n\",\n              \"    }\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.pause_animation = function()\\n\",\n              \"  {\\n\",\n              \"    this.direction = 0;\\n\",\n              \"    if (this.timer){\\n\",\n              \"      clearInterval(this.timer);\\n\",\n              \"      this.timer = null;\\n\",\n              \"    }\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.play_animation = function()\\n\",\n              \"  {\\n\",\n              \"    this.pause_animation();\\n\",\n              \"    this.direction = 1;\\n\",\n              \"    var t = this;\\n\",\n           
   \"    if (!this.timer) this.timer = setInterval(function() {\\n\",\n              \"        t.anim_step_forward();\\n\",\n              \"    }, this.interval);\\n\",\n              \"  }\\n\",\n              \"\\n\",\n              \"  Animation.prototype.reverse_animation = function()\\n\",\n              \"  {\\n\",\n              \"    this.pause_animation();\\n\",\n              \"    this.direction = -1;\\n\",\n              \"    var t = this;\\n\",\n              \"    if (!this.timer) this.timer = setInterval(function() {\\n\",\n              \"        t.anim_step_reverse();\\n\",\n              \"    }, this.interval);\\n\",\n              \"  }\\n\",\n              \"</script>\\n\",\n              \"\\n\",\n              \"<style>\\n\",\n              \".animation {\\n\",\n              \"    display: inline-block;\\n\",\n              \"    text-align: center;\\n\",\n              \"}\\n\",\n              \"input[type=range].anim-slider {\\n\",\n              \"    width: 374px;\\n\",\n              \"    margin-left: auto;\\n\",\n              \"    margin-right: auto;\\n\",\n              \"}\\n\",\n              \".anim-buttons {\\n\",\n              \"    margin: 8px 0px;\\n\",\n              \"}\\n\",\n              \".anim-buttons button {\\n\",\n              \"    padding: 0;\\n\",\n              \"    width: 36px;\\n\",\n              \"}\\n\",\n              \".anim-state label {\\n\",\n              \"    margin-right: 8px;\\n\",\n              \"}\\n\",\n              \".anim-state input {\\n\",\n              \"    margin: 0;\\n\",\n              \"    vertical-align: middle;\\n\",\n              \"}\\n\",\n              \"</style>\\n\",\n              \"\\n\",\n              \"<div class=\\\"animation\\\">\\n\",\n              \"  <img id=\\\"_anim_imgb11f2637772a40ca98bbdbbd7669890b\\\">\\n\",\n              \"  <div class=\\\"anim-controls\\\">\\n\",\n              \"    <input id=\\\"_anim_sliderb11f2637772a40ca98bbdbbd7669890b\\\" 
type=\\\"range\\\" class=\\\"anim-slider\\\"\\n\",\n              \"           name=\\\"points\\\" min=\\\"0\\\" max=\\\"1\\\" step=\\\"1\\\" value=\\\"0\\\"\\n\",\n              \"           oninput=\\\"animb11f2637772a40ca98bbdbbd7669890b.set_frame(parseInt(this.value));\\\"></input>\\n\",\n              \"    <div class=\\\"anim-buttons\\\">\\n\",\n              \"      <button onclick=\\\"animb11f2637772a40ca98bbdbbd7669890b.slower()\\\"><i class=\\\"fa fa-minus\\\"></i></button>\\n\",\n              \"      <button onclick=\\\"animb11f2637772a40ca98bbdbbd7669890b.first_frame()\\\"><i class=\\\"fa fa-fast-backward\\\">\\n\",\n              \"          </i></button>\\n\",\n              \"      <button onclick=\\\"animb11f2637772a40ca98bbdbbd7669890b.previous_frame()\\\">\\n\",\n              \"          <i class=\\\"fa fa-step-backward\\\"></i></button>\\n\",\n              \"      <button onclick=\\\"animb11f2637772a40ca98bbdbbd7669890b.reverse_animation()\\\">\\n\",\n              \"          <i class=\\\"fa fa-play fa-flip-horizontal\\\"></i></button>\\n\",\n              \"      <button onclick=\\\"animb11f2637772a40ca98bbdbbd7669890b.pause_animation()\\\"><i class=\\\"fa fa-pause\\\">\\n\",\n              \"          </i></button>\\n\",\n              \"      <button onclick=\\\"animb11f2637772a40ca98bbdbbd7669890b.play_animation()\\\"><i class=\\\"fa fa-play\\\"></i>\\n\",\n              \"          </button>\\n\",\n              \"      <button onclick=\\\"animb11f2637772a40ca98bbdbbd7669890b.next_frame()\\\"><i class=\\\"fa fa-step-forward\\\">\\n\",\n              \"          </i></button>\\n\",\n              \"      <button onclick=\\\"animb11f2637772a40ca98bbdbbd7669890b.last_frame()\\\"><i class=\\\"fa fa-fast-forward\\\">\\n\",\n              \"          </i></button>\\n\",\n              \"      <button onclick=\\\"animb11f2637772a40ca98bbdbbd7669890b.faster()\\\"><i class=\\\"fa fa-plus\\\"></i></button>\\n\",\n              \"    </div>\\n\",\n 
             \"    <form action=\\\"#n\\\" name=\\\"_anim_loop_selectb11f2637772a40ca98bbdbbd7669890b\\\" class=\\\"anim-state\\\">\\n\",\n              \"      <input type=\\\"radio\\\" name=\\\"state\\\" value=\\\"once\\\" id=\\\"_anim_radio1_b11f2637772a40ca98bbdbbd7669890b\\\"\\n\",\n              \"             >\\n\",\n              \"      <label for=\\\"_anim_radio1_b11f2637772a40ca98bbdbbd7669890b\\\">Once</label>\\n\",\n              \"      <input type=\\\"radio\\\" name=\\\"state\\\" value=\\\"loop\\\" id=\\\"_anim_radio2_b11f2637772a40ca98bbdbbd7669890b\\\"\\n\",\n              \"             checked>\\n\",\n              \"      <label for=\\\"_anim_radio2_b11f2637772a40ca98bbdbbd7669890b\\\">Loop</label>\\n\",\n              \"      <input type=\\\"radio\\\" name=\\\"state\\\" value=\\\"reflect\\\" id=\\\"_anim_radio3_b11f2637772a40ca98bbdbbd7669890b\\\"\\n\",\n              \"             >\\n\",\n              \"      <label for=\\\"_anim_radio3_b11f2637772a40ca98bbdbbd7669890b\\\">Reflect</label>\\n\",\n              \"    </form>\\n\",\n              \"  </div>\\n\",\n              \"</div>\\n\",\n              \"\\n\",\n              \"\\n\",\n              \"<script language=\\\"javascript\\\">\\n\",\n              \"  /* Instantiate the Animation class. */\\n\",\n              \"  /* The IDs given should match those used in the template above. 
*/\\n\",\n              \"  (function() {\\n\",\n              \"    var img_id = \\\"_anim_imgb11f2637772a40ca98bbdbbd7669890b\\\";\\n\",\n              \"    var slider_id = \\\"_anim_sliderb11f2637772a40ca98bbdbbd7669890b\\\";\\n\",\n              \"    var loop_select_id = \\\"_anim_loop_selectb11f2637772a40ca98bbdbbd7669890b\\\";\\n\",\n              \"    var frames = new Array(9);\\n\",\n              \"    \\n\",\n              \"  frames[0] = \\\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAbAAAAEgCAYAAADVKCZpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\\\\\\n\",\n              \"AAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0\\\\\\n\",\n              \"dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeVhN2/8H8Hepm6J5RpQyRyWRUDJT\\\\\\n\",\n              \"moTMhDK7uOaZa55FmZIhQ4ZoRIVEpiIlY8Y0T5rH01m/P+7X+d1uqXOQOnxez7Ofup299v6ck7vf\\\\\\n\",\n              \"7b3XXkuEMcZACCGECBnRui6AEEII+RYUYIQQQoQSBRghhBChRAFGCCFEKFGAEUIIEUoUYIQQQoQS\\\\\\n\",\n              \"BRghhBChRAFGCCFEKFGAEUIIEUoUYIQQQoQSBRghhBChRAFGCCFEKFGAEUIIEUoUYIQQQoQSBRgh\\\\\\n\",\n              \"hBChRAFGCCFEKFGAEUIIEUoUYIQQQoQSBRghhBChRAFGCCFEKFGAEUIIEUoUYIQQQoQSBRghhBCh\\\\\\n\",\n              \"RAFGCCFEKFGAEUIIEUoUYIQQQoQSBRghhBChRAFGCCFEKFGAEUIIEUoUYIQQQoQSBRghhBChRAFG\\\\\\n\",\n              \"CCFEKFGAEUIIEUoUYIQQQoSSWF0XUFuUlJSgqalZ12UQQohQ+fDhAzIyMuq6DL78sgGmqamJyMjI\\\\\\n\",\n              \"ui6DEEKESpcuXeq6BL7RJURCCCFCiQKMEEKIUKIAI4QQIpQowAghhAglCjBCCCFCiQKMEEKIUKIA\\\\\\n\",\n              \"I4QQIpQowAghhAilX/ZBZkII+VmKiop4o1coKyujYcOGdVzR74HOwAgh5BswxhAeHo4xY8ZAWVkZ\\\\\\n\",\n              \"3bt3h7GxMZSVlTFhwgQ8fPgQjLG6LvOXRgFGCCECKiwsxLBhwzBx0iToGXRB3PtPePMhAW8/JuJF\\\\\\n\",\n              \"3Hu0aa+LkQ4OGD16NIqLi+u63F8WBRghhAigrKwMVtbWkGgohUdPYjHnz3mQl5fnva6kpIT5Cxbi\\\\\\n\",\n              \"cfQzFJeWYfiIESgvL6/Din9dFGCEECKALVu2QkxMHIePHsMff/zx1fUkJSVxwvMM8vLysXevy0+s\\\\\\n\",\n              
\"8PdBAUYIIXwqKyvDgQNuWL9hM8TEau4DJy4ujrXrN2L//n3gcrk/ocLfCwUYIYTwyc/PDy00tdCx\\\\\\n\",\n              \"Uye+23Tt1g3SMjIIDg6uxcp+TxRghBDCpwcPHmLgoMECtREREcGAgYPx4MGDWqrq90UBRgghfMrP\\\\\\n\",\n              \"z0ejRo0FbtdYWhr5+QW1UNHvjQKMEEL4JCsri+zszwK3+5yVBVlZmVqo6PdGAUYIIXzq27cPLl+6\\\\\\n\",\n              \"KNADylwuFz6XvdGnT59arOz3RAFGCCF86tOnD8pKS3E3PJzvNtdDgiErIwNjY+NarOz3RAFGCCF8\\\\\\n\",\n              \"EhERwejRozF75jQUFNR8TysnJwfLFi/EvHnzICIi8hMq/L1QgBFCCJ98fX3h4uICZSVFWFsORmZm\\\\\\n\",\n              \"5lfXTUtLg9WQgTAzM8PYsWN/YpW/DwowQgipAZfLxerVqzFz5kz4+fnh5s2bMOneHR3btcLMaU6I\\\\\\n\",\n              \"fvIERUVFKCwsxONHjzBzmhP0OrTBwIED4eKyl86+aglNp0IIIdX4/Pkzxo4di/z8fERGRkJVVRUA\\\\\\n\",\n              \"sHXrFsydOwe9e/fG1auByPzfdCoaGhqYMGEiXrx4ATU1tbos/ZdHAUYIIV/x9OlT2NraYujQodi6\\\\\\n\",\n              \"dSvExcUrvN6oUSOkpqYiOTkZjRo1qqMqf190CZEQQqpw9uxZ9OnTB2vXrsWuXbsqhRcABAYGonfv\\\\\\n\",\n              \"3hRedYTOwAgh5F84HA4WL16MS5cuITg4GPr6+l9d18fHB9bW1j+xOvJvFGCEEPI/aWlpGDlyJCQk\\\\\\n\",\n              \"JBAZGQkFBYWvrltSUoJr165h7969P7FC8m90CZEQQgBERETAyMgIPXr0QEBAQLXhBQA3b96Erq4u\\\\\\n\",\n              \"r1MH+fnoDIwQ8ttzd3fH0qVLcfDgQdja2vLVhi4f1j0KMELIb6ukpARz587FrVu3EBYWhrZt2/LV\\\\\\n\",\n              \"jsvlwtfXFzdv3qzlCkl1KMAIIb+lxMRE2Nvbo0mTJnj48CGkpaX5bhsZGQkZGRm0bt26FiskNaF7\\\\\\n\",\n              \"YISQ305YWBiMjIxgZWWFCxcuCBReAF0+rC8owAghvw3GGPbu3YsRI0bg2LFjWLp06TcN8+Tj4wMb\\\\\\n\",\n              \"G5taqJAIgi4hEkKEFofDgb+/PzxPnUJKSgpERUWh2aIFHB0dYWZmViGcCgsL4eTkhGfPnuHevXvQ\\\\\\n\",\n              \"0tL6pn2+ffsWmZmZ6Nq16496G+QbUYARQoTS8eMnsGLFcjTTaI5JjlPQUlsbXC4XsU+fYsbMmQBj\\\\\\n\",\n              \"2LNnD/r37493797Bzs4OHTt2RHh4OKSkpL55vz4+Phg6dChERekCVl2jACOECJ1Nmzbj8JHDOHv+\\\\\\n\",\n              \"Egy7dKnwmqlZb0yfOQtB165i3LhxmDRpEo4ePYoVK1Zg1qxZAl0yLCoqwrlz53DnTjjyC/Ih3Vga\\\\\\n\",\n              \"YWG3sHnz5h/9lsg3oAAjhAiV8+cv4OChg7gZdhfq6upVriMiIoKBgwbD/0owzE1NsGPHDjg5OfG9\\\\\\n\",\n              
\"j5KSEqxevQbu7kdg2MUIg4dYQkZWFjnZ2YiOjoazszOiop5gxYrlVY6RSH4OEcYYq+siakOXLl0Q\\\\\\n\",\n              \"GRlZ12UQQn4gxhj09fWxYfM29Os/gK82rvv24t7dOzh/7hxf6xcUFMDS0hLSsnLYvGU7WmprV1rn\\\\\\n\",\n              \"9atXWLjgT4iKiuDypUuQkJAQ6H3UZ8J07KSLuIQQoXH37l0UFhWhT99+fLcZM24CQoKDkZSUVOO6\\\\\\n\",\n              \"jDGMGTsWTZs1x9lzF6sMLwBo3aYNLl72g6RUI0yZMoXvWsiPRQFGCBEaFy96Y8zY8QJ1oJCVlcXg\\\\\\n\",\n              \"IZbw8/Orcd2IiAhER0fD7dCRGvchJiYGd48TuH79Op4/f853PeTHoQAjhAiNzMwMNGnSVOB2TZs2\\\\\\n\",\n              \"RUpKSo3rubq6YqrzdL7va0lKSmKi4xS4uroJXBP5ftSJgxAiNP74QwKlpaUCt8vPz8eunduxfft2\\\\\\n\",\n              \"qKqq8hY1NTXe90pKSjh37hxev4sXaNsTHaegu5EB9u1zEbgu8n0owAghQqNNmzZ4cP8epjg5C9Tu\\\\\\n\",\n              \"SdRjXLp0CaampkhNTeUtKSkpSE1NxePHj/Hp0yeIiIhCSUlJoG1raGggNzcXpaWl+OOPPwRqS74P\\\\\\n\",\n              \"BRghRGhMmDAerVu3RmZmJhQVFflq8yQqComJCRg8eDDExMQgKytb5SC8OTk50NDQELgmxhgYY2jQ\\\\\\n\",\n              \"oIHAbcn3oXtghNRD5eXlCAoKwsGDB+Hq6oqLFy+ioKCgrsuqc8rKyhg6dCj27t7J1/qMMfy9bjW4\\\\\\n\",\n              \"XG6NXcOlpaUhJiaGD+/fC1TT82fP0KRJEwqwOkABRkg9kpubi40bN6Fly5ZYvnwFHkY8wpPopzhw\\\\\\n\",\n              \"8BCaN2+OP/+ch/h4we7R/Go2bdqEs6c94eF+pNr1GGNYsWwJUpKTsHbtWtjZ2cHJyQmZmZkV1svO\\\\\\n\",\n              \"zoa7uzv69u0LDoeDA277BarnyKEDcHScLPD7IN+PLiESUk8kJiZi0KBBaNdBF2fOeaOzoWGF1z9+\\\\\\n\",\n              \"/IhDB1zRrVs3XL58Gd26daujSn++58+f4+hRD7x7/w7l5eXo3t0E69asRNC1q1iwcDG6GBnx1uVy\\\\\\n\",\n              \"ubh54zr27NqB3JxsXLlyBcrKyrC3t8eqVavQvn17rF27Fmpqajh9+jSCgoLQt29fzJ07F61atYK5\\\\\\n\",\n              \"uTkWLVkGBQWFGutKTU3FOa8ziI2Nrc23T76CAoyQeiAnJweDBg2C/QgHLFqyrMrx+lq0aIENm7ag\\\\\\n\",\n              \"R89esLKyQmhoKNq1a1cH1f48UVFRmD9/Pl6+fInxEx0xbLgD7zJfZGQkQm9eR/idMKipq6Nly38G\\\\\\n\",\n              \"8332LBaNGzXCjBkzMX78OEhKSgIAZGRk4ODggOTkZMyePRuSkpL4888/cfDgQcjLywP456ytXbt2\\\\\\n\",\n              \"sBzcH9dCQqudJ+zz588YZm2JOXPmokmTJj/l8yD/wX5RhoaGdV0CIXxbu3Ydcxg1hhWWcllRGatx\\\\\\n\",\n              
\"2bZjN7McOrSuy65VN27cYEpKSszt4BGWU1BS6TMoLOWyq8E3mEbz5mz27Nns/PnzzNvbmz169Ihx\\\\\\n\",\n              \"uVzedl6/fs1WrVrFWrZsydq2bcs2bNjA3r59y44cOcJUVFTY7NmzWXZ2NisrK2NOTk5MT0+PjRs3\\\\\\n\",\n              \"junqdmTnLl5meUVlFfabW1jKTp09z1q3acPmz19QYV+/AmE6dtJYiITUMQ6HA01NTXj7BKCTnh5f\\\\\\n\",\n              \"bQoLC9FKSwOPHj2CpqZm7RZYB+Li4tCjRw94njkHU7Pe1a6bkJCAvmY9sHPnLgwbZgcASE9Ph5eX\\\\\\n\",\n              \"Fzw9PfHhwwc4ODhg3Lhx6Ny5c4Wz28zMTCxduhR+fn5QU1ODsrIyLl68iMaNG8PLywt79+5FQkIC\\\\\\n\",\n              \"Bg4aAmkZGXzOykKAvy+kZWTQ2aAzRo4cASsrq1+q+7wwHTupEwchdezatWvQ0GjOd3gBgJSUFEaN\\\\\\n\",\n              \"GYejRz1qsbK6s337DjhNm1FjeAFAs2bN4HrwCNasWY2zZ89i6NChaNWqFe7fv481a9YgISEBu3fv\\\\\\n\",\n              \"hqGhYaVLs4qKivj7778hLy+P+Ph4lJaW/u95MBE4ODjg7t278PHxgQi4uHDuLC5fuogBg4bAcYoT\\\\\\n\",\n              \"2rbvgL0u+9CiRQusWrUaRUVFtfRpkK+hACOkjr19+xZ6Bp0Fbqevb4B379/VQkV1KycnB+fOecFx\\\\\\n\",\n              \"Cv/Tn/Tp2w+fP2dj586dGDFiBD59+gRPT08MGjQIYmJfv9UfFxcHExMT2NvbIzU1Ffb29jAzM8Pi\\\\\\n\",\n              \"xYuRn58PAPDz80dwSAi279qLT8npOHL0GBb8tQjLV65G0PVQBFwNwdNnz9CnTx98/vz5u98/4R8F\\\\\\n\",\n              \"GCF1jMPhVHuQ/RoxcXFwOJxaqKhuXblyBT169hKoY4SIiAimzZgJI6OuGDduXLWdL764f/8+TE1N\\\\\\n\",\n              \"sXjxYqxbtw5iYmKYNWsWYmNjkZycjPbt28PJyQmnTp9C6O17sLK2qfL31L5DB5w+ex76nbvAzs7u\\\\\\n\",\n              \"l/yd1FcUYITUMVVVVcR//CBwu48f3kNFWeXHF1THUlNTodG8hcDtWmhqIS09ja91fXx8MHToUBw5\\\\\\n\",\n              \"cgRTp06t8JqqqipOnDiBw4cPw/PUKZy7cBmqqqrVbk9ERATbd+5GYWERfH19Ba6dfBvqRk9IHbOw\\\\\\n\",\n              \"sMCsWbOQkpICNTU1vtpwuVycOO4Bz5Mna7m6n09cXBzl33AWwykrQ3BQEHR0dKCmpvbV5dq1a9i3\\\\\\n\",\n              \"bx8CAwNh9K/nx/4rMTERvXubo03btnztv0GDBpgxey5cXV1hZ2cncP1EcBRghNQxOTk52NsPx9Ej\\\\\\n\",\n              \"h7BsxSq+2gQHXUOjRo1gbGxcy9X9fNra2jh27LjA7aKiHsHZ2RlTpkxBSkpKheXevXtISkpCVFQU\\\\\\n\",\n              \"MjIyICIigsGDB0NNTQ3q6upVBt2BAwexaOlygWqwsbXDX/PmICEhAc2aNRP4PRDBUIARUg8sWrQQ\\\\\\n\",\n              
\"PXr0QI+evWDW27zadT9+/IipjhMBMNy5cwe9evX6KTX+LP369YOTkxMeP3pUaTSSryksLMRpz5OI\\\\\\n\",\n              \"iIiAlpYWWrVqVeH10tJSODo6QktLC7GxsZCXl0dmZmaloEtMTMSjR4+QkpKC169fo21bwR4Ul5CQ\\\\\\n\",\n              \"gKamFhITEynAfgIKMELqgVatWsHLywsjR47E2vUbMWbc+ErPFjHGcD0kGNOdJmPVqlXQ0tLEyJEj\\\\\\n\",\n              \"4ejoiNWrV/M9CWN916BBAzg5OWPnjq04eepslaOS/JeH+xF07dYNWlpalV7LycmBnZ0dZGRkcP36\\\\\\n\",\n              \"dUhJSQEAVFRUoKKigk6dOlW5TW1tbb72XZVf9PHaeoc6cRBST5ibm+PatWu4eP4sWrdsjpXLl+L8\\\\\\n\",\n              \"OS94X7yAHdu3Qr9jOyxdtAAuLvswa9ZMWFhYICoqCo8fP0aPHj0QFxdX12/hh5kzZzbevH6FDevX\\\\\\n\",\n              \"1hgGVwIDsHXzBuzYvr3SawkJCejVqxfatm2LCxcu8MKLH02bNsWrVy8Fqru0tBQfPrxH06aCzxpN\\\\\\n\",\n              \"BEcBRkg9YmBggJCQEISGhgLccvhevojzXqeR8PEDjhw+jJiYGNjYWPPWV1VVRUBAAMaPHw8TExMc\\\\\\n\",\n              \"PXr0l/jrX1paGleuXMG5s6dha22Jhw8eVHpfH96/x9LFCzHDeQp8fHwqjQsZGxsLExMTjB07Fvv2\\\\\\n\",\n              \"7RN4upPx48fD/fBBgdr4XL6Ejh07ftO8YkRwNJQUIb+IZ8+eYfTo0WjVqhUOHjzI94SP9RWHw0GL\\\\\\n\",\n              \"Fi1QWloKaWlpyMnLo2NHvX8G8/3wHtFPojB+/AT8+edcNG/evELbmzdvYuTIkdi9ezdGjx4t8L5z\\\\\\n\",\n              \"c3Oxbt067N+/HxGPY6Dzn3tqVeFyuehr1hMLFsyHvb29wPusL4Tp2ElnYIT8Ijp06ICHDx+iRYsW\\\\\\n\",\n              \"0NfXx/Xr1+u6pO+yf/9+ZGRk4PLly3jz5g22b9sG01490NXIELNnzUR8fDx27txRKbzOnDmDkSNH\\\\\\n\",\n              \"wsvLS+Dw4nA4OHToENq0aYOMjAwsW7YMw4dZIz09vdp2jDEsXrgAYmINYG1tXe265Aeqq1GEa5sw\\\\\\n\",\n              \"jahMyI8WFBTEmjZtyhYuXMiKi4vruhyBZWdnMykpKWZhYcF3Gy6Xy7Zu3co0NDRYTEyMwPsMCgpi\\\\\\n\",\n              \"urq6zMzMjEVGRvJ+vmLFSqato8O8ffxZfjGn0qj40bEvmf2IkczIyIhlZGQIvN/6RpiOnXQJkZBf\\\\\\n\",\n              \"VEZGBqZMmYKPHz/i9OnTQjV3mJOTE06cOIF3797xNaRUeXk5/vzzT4SGhuLKlSsCdWF//vw5Fi5c\\\\\\n\",\n              \"iNevX2Pbtm2wtrau1PvwwoUL2LJlCzIyM2E/fCSUlVVQXFyEW6E3Efs0Bo6Ok7FixXI0atRI4Pda\\\\\\n\",\n              \"3wjTsZMuIRLyi1JSUsKlS5cwY8YMmJqaws3NTSg6eLx79w7Hjh3DX3/9xVd4FRUVwd7eHs+ePcOd\\\\\\n\",\n              
\"O3f4Dq/09HTMnDkTZmZm6NevH549ewYbG5squ87b29sjIiIC57y8INVQAgnxH1CYn4cpkx0RHx+P\\\\\\n\",\n              \"TZs2/hLhJWzoDIyQ38CrV68wZswYqKurw93dHSoq9XcMxZ49e+Lly5dITEyEhIREtetmZGTAysoK\\\\\\n\",\n              \"WlpaOHr0aI3rA0BJSQlcXFywZcsWjB49GqtWrRL6Di8/kjAdO+kMjJDfQJs2bXD37l3o6upCX18f\\\\\\n\",\n              \"V69ereuSqnTjxg08ePAAhw8frjGM3r17hx49esDMzAwnT56scX3GGC5cuIB27drh9u3buHPnDvbs\\\\\\n\",\n              \"2UPhJczq8gZcbRKmG5GE/Ew3b95kGhoabM6cOaywsLCuy+EpLy9nTZs2ZQYGBjWuGxERwdTV1dm+\\\\\\n\",\n              \"ffv42vaDBw9Yjx49mJ6eHgsJCfneUn9pwnTspDMwQn4zvXv3RnR0NFJSUtC1a1fExMTUi3tjLi4u\\\\\\n\",\n              \"SE1NxdmzZ6tdLzAwEIMHD4arqytmzpxZ7bqfPn3C2LFjYWNjA0dHRzx69Ah9+/b9kWWTOkQBRshv\\\\\\n\",\n              \"KDs7G5qamkhJTYWBgQHExcXRsmVLrF27DklJST+9noKCAixbtgwODg5o3br1V9c7cuQIHB0d4evr\\\\\\n\",\n              \"Cxsbm6+ul5+fj5UrV0JfXx8tW7bE69ev4ejoKPBoHKR+owAj5DfC4XAwc+YsGBkZoaikDEEhocjO\\\\\\n\",\n              \"L0ZWbiHOnPNGQmISOnTogFWrVv/Us7I5c+aAMQZXV9cqX2eMYfXq1di0aRPCwsLQvXv3KtcrLy+H\\\\\\n\",\n              \"u7s7WrdujQ8fPuDJkydYt24dGjduXJvlkzpCo9ET8pvgcrkYO3YsMrM+49mrt5CVla3wup6+Plxc\\\\\\n\",\n              \"D2DlmnWwtxmKz58/Y+/ePd88Iju/Pn78iOPHj2P79u2Qlpau9HpZWRmcnJwQGxuLu3fvfnV25OvX\\\\\\n\",\n              \"r2PBggWQlpaGj49PtZNVCgMOh4MGDRrU+ucvzOgMjJDfxO7dexD/KQHnvX0qhde/qaiowO9KEG7c\\\\\\n\",\n              \"uFHj/agfwcHBAerq6pgzZ06l1/Ly8mBpaYn09HSEhoZWGV6vXr2ClZUVnJycsHLlSoSFhQlleDHG\\\\\\n\",\n              \"8ODBA0yYMAGysrKQkJCAuLg4OnbsCBeXfcjNza3rEusdCjBCfgPl5eXYu3cPtu3YjYYNG9a4vqys\\\\\\n\",\n              \"LNZv3IxNmzbVal03btzAw4cPcerUKYiKVjwcJScnw8zMDC1atMDly5crPSicmZmJOXPmoGfPnjA1\\\\\\n\",\n              \"NcXz588xbNgwoTxjycrKQv/+/TFq9Gi07dARMc9fI7+Yg6zcQmzftRe3wsKgqamJ06fP1HWp9Uud\\\\\\n\",\n              \"9oGsRcLUFZSQ2ubv78+6dDGqNI5fdUt+MYcpKyuz2bNnMw6H88Nr4nK5TF1dnZmamlZ67fnz56xF\\\\\\n\",\n              \"ixZs/fr1jMvlVnitpKSE7dy5kykpKbGZM2eytLS0H17bz/T582fWsWNHNnfeAlZQUv7V30fE4xim\\\\\\n\",\n              
\"oaHBjhxxr9V6hOnYSffACPkN3LhxE1Y2tgK1adCgAUY4jIafnw+ePn2KEydO/NB5rnbt2oX09HQ8\\\\\\n\",\n              \"fvy4ws/v3LmDYcOGYcuWLZg4cSLv54wxXL58GYsWLULr1q0RFhYmVOM7fs20adPQo5cZNm3ZVu3Z\\\\\\n\",\n              \"o27HjvC/Eox+5r1gbNwNHTp0+IlV1k8UYIT8BnJzc6HZUkfgdgoKChg1ahQaN26MLl26YN++fRg+\\\\\\n\",\n              \"fPh311NUVITly5dj+vTpUFNT4/38woULmDFjBjw9PTFgwADezx89eoT58+cjKysL+/fvr/CaMEtI\\\\\\n\",\n              \"SEBQUBBev4vn69Jn6zZt4DRtBvbt2w83t6p7bP5O6B4YIb+Bxo0bIy8/T+B2eXl5kJWVxbJly+Dv\\\\\\n\",\n              \"74/ly5dj0qRJyMsTfFv/5uzsDHFxcezYsYP3s927d+PPP/9EUFAQL6ASExMxceJEWFpaYuzYsYiK\\\\\\n\",\n              \"ivplwgsADh48hJGjxgjUzX/S5Kk4e/YMdeoABRghv4Xu3Y1x7UqgQG0YYwj094OxsTEAwMjICI8f\\\\\\n\",\n              \"P4aYmBgMDAxw//79b6rlw4cPOHXqFPbt2wdxcXFwuVzMnz8fhw4dQnh4OPT19VFQUIA1a9agU6dO\\\\\\n\",\n              \"aNKkCV69eoWpU6dCTOzXumgUeisU1gJe2m3atCnatWtf6dLr74gCjJDfgI2NDV6/eokXz5/z3Sbs\\\\\\n\",\n              \"VihSU1OwbNkyeHl5oaysDI0bN8bhw4exdetWWFtbY/369eBwOJXaFhcXw9PTE2Z9+qNth07oqGeI\\\\\\n\",\n              \"EQ6jefe3tLW1MX78eBQXF2PUqFGIjIzEnTt3oKGhgWPHjqFNmzZ4/fo1Hj9+jI0bN0JGRuZHfhx1\\\\\\n\",\n              \"qrS0FAkJCYiMjERqSgpkZeUE3oacvDydgYHugRHyW/jjjz/g5OSMFcsW49zFyzUOqVRcXIzVK5dj\\\\\\n\",\n              \"w4YNUFdXx969e7FgwQJMmzYNTk5OsLOzQ7du3TBhwgRcu3YNJ0+ehJaWFhhj2LZ9B/7+ewNEJJVR\\\\\\n\",\n              \"1FALIn+0Bkq5eB2eAP9AGxQV5MLD/RCysrJgY2MDNTU1BAUF4f79+5g/fz4aNmyICxcu8M78hEF5\\\\\\n\",\n              \"eTkyMjKQmpqKlJSUCst/f5abmwsVFRWoqqqisLAI+d9waTc3N7fKh75/NzQfGCG/idLSUgwZMgQq\\\\\\n\",\n              \"quo4cNgdf/zxR5Xr5efnY8QwG7yJi8Pz589492diYmLg4uKCCxcuwMrKCrNnz0bnzp2xa9cubN68\\\\\\n\",\n              \"GTt37kT43fs4de4ySlXMIdqw8pkFYwzcnPcQS78DJQU52Nvbw9nZGUuWLEFUVBS2bNmCESNG1Itn\\\\\\n\",\n              \"uRhjyM7OrjaMvvx3RkYG5OTkoKamxltUVVWr/G9FRUWIioqCMfbPQ8vyiti2YxffdaWnp6Nju1Z4\\\\\\n\",\n              \"9+4dFBQUfvj7FqZjJwUYIb+RwsJCjBk7FjExMZjqPB3jJ0ziHQRTUlJw7OgRuB8+iP79+6O0tBSJ\\\\\\n\",\n              
\"iYnw8/Or0MkgKysL7u7u2L9/P5o0aYLZs2dDR0cHQ4daITOvFA207SDSoPq5ubgFKUD8FTiMtEdA\\\\\\n\",\n              \"QAAWLlyIuXPn8vWQ9ffKz8+vMoz++7PU1FRISkpWG0ZfFmVlZYiLi/O1/4yMDHh6esLd3R05OTnI\\\\\\n\",\n              \"y8vDu/gkSEpK8tV++7YtePv6FTw8jn7Px/BVwnTspAAj5DfD/jdkkaurK7y9vdGoUSMwxlBSUgIH\\\\\\n\",\n              \"h1GYMWM69PT0UF5eDmdnZ7x48QKBgYGVhp/icDjw8/ODi4sLXr58ieycfJQ3GwxRKWW+6uAkP4C2\\\\\\n\",\n              \"fCHCbt387hmiS0pKvnp29N//5nK5FcLnawGlqqrKd6jUpLy8HMHBwXB3d0dwcDCGDh2KyZMnw9TU\\\\\\n\",\n              \"FNY2NtDtqIfVa9fXuJ34+Hj07mkMPz8/GBoa/pDa/kuYjp0UYIT8xkpKSpCVlQVRUVEoKChUOovg\\\\\\n\",\n              \"crmYNWsWIiMjce3aNcjLy1e5nT179mDRys0Q1R7G975ZWQFE319ASnJilZ00OBwO0tPT+TpbKigo\\\\\\n\",\n              \"4IVQTWdLjRs3/mmXKN+9ewcPDw8cO3YMampqmDx5MkaNGlXhj4HU1FSYmJhg4uSp+Gvh4q/W9uH9\\\\\\n\",\n              \"e1hZDoKzkzMWLJhfazUL07GTAowQUi3GGObPn49bt24hKCgISkpKldaxHzEKPveSIaakK9C2RROu\\\\\\n\",\n              \"wH5IT6iqqlYKp6ysLCgoKPB1tiQvL19pLMW6UlRUBG9vb7i7u+Pp06cYM2YMHB0d0alTp6+2SUxM\\\\\\n\",\n              \"hJWVFUREROE0bQaGj3Tgnf3FPn2KQwdcceG8F9auXYfZs2fVav3CdOykXoiEkGqJiIhg586dWLZs\\\\\\n\",\n              \"GczNzRESElJpVPiU1FSIiDf6yha+rgx/ICEhAe3bt0eHDh0qhJOSkpLQPPfFGMPjx4/h7u4OLy8v\\\\\\n\",\n              \"GBkZYfr06bCysoKERPX3A4F/nu16+PAhAgICsGPHDsya4YzGjRujrKwMsrKymDr1n+lkmjRp8hPe\\\\\\n\",\n              \"jfAQjn8dhJA6JSIigo0bN0JCQgK9e/fG9evXKxxMG0o0BFi5wNvllpXi1q1biIuLg7q6OtTV1dGk\\\\\\n\",\n              \"SZMqv1dWVq53MypnZmbi1KlTOHr0KHJycjBp0iRERUWhefPmAm0nMTERhw4dxpEjhyEpJYWOnfSQ\\\\\\n\",\n              \"m5uDnOxsjB07DhMnTqDwqgIFGCGELyIiIlizZg0kJCRgZmaGGzduQFlZGTdv3kRWRhpYQSkgx/94\\\\\\n\",\n              \"i4wxSIrkI+DGDTRv3hzJyclISkpCcnIykpOTER4ezvs+OTkZ2dnZUFZWrjHoVFVVa/XMjcvl4vr1\\\\\\n\",\n              \"63B3d8fVq1cxZMgQ7NixA+bm5t90GdPHxweTJ0+G/QgH+AUGof2/Bul9/eoVDh86AENDQ+zZsxdj\\\\\\n\",\n              \"xoz+kW9F6FGAEUIEMmHCBPuDz38AACAASURBVERFRaFt27YQFRWFgYEB+vfvi+curmBcE4iI8ndY\\\\\\n\",\n              
\"4eYnoCAvG+fPn8esWbPQvXv3atcvLS1FampqpaCLiIjgfZ+UlITMzEwoKirWGHRqampffRauKh8+\\\\\\n\",\n              \"fMCxY8fg4eEBJSUlODo6ws3N7asdW/gREBAAZ2dn+PhfhWGXLpVeb92mDbbt2IWJkybD2nIQxMTE\\\\\\n\",\n              \"MHLkiG/e36+GOnEQQqrF5XIRGRkJf39/BAQE4MOHDxg0aBDExcUREhKCmzdvolWrVjAz74u7r4sh\\\\\\n\",\n              \"pmJQ4zYZ40I0PgCyEqUQFRVFYWEhjIyMMGvWLFhYWHzXpUIOh4O0tLRKQffv75OTk5Gamgo5OblK\\\\\\n\",\n              \"wfbv/1ZQUMDDhw9x8uRJREVFYdSoUZg8eTL09fW/ub4vCgoK0KJFC1y87I9ufIw6EhMdjUH9zfHm\\\\\\n\",\n              \"zZtaeYD5C2E6dlKAEUIqycvLQ3BwMPz9/REYGAhFRUVYWlrC0tIS3bt3512iO3z4MNauXYvLly9j\\\\\\n\",\n              \"1qxZePDwEcQ1zNFAofVXt80YFw1S78CglQKuB1+Dt7c3Vq5cCQkJCYiIiCA/Px/Tp0/H5MmTq+zx\\\\\\n\",\n              \"WJWSkhJ4e3vjyJEjePPmDcrKyqCqqorhw0dgypTJVT5n9mX4p6qC7sWLF3jx4gXS09PBGIOUlBQ0\\\\\\n\",\n              \"NDTQrFmzKoPuy/LfWaOrc+jQYfgHBODcxct8t3GcMA5dDDtj/vx5fLcRlDAdOynACCEAgLdv38Lf\\\\\\n\",\n              \"3x/+/v548OABTExMYGlpCQsLC2hpaX213Y4dO7B48WKIiopCWloaJWXlQOMWKJNuW+GhZsa44Oa8\\\\\\n\",\n              \"R8P859Bvr40Afx/eeH5lZWXw8PDAunXr0KpVK8jKyiI0NBQ2NjaYNWsWulRxee2LCxcuYtasmWjf\\\\\\n\",\n              \"QRdTnaejc2dDNBATQ/zHjzhx7CguX7oIR8fJ2LZta7Vndp8/f8bp06fh7u6OzMxMTJo0CRMnTkTz\\\\\\n\",\n              \"5s2RlZVV4xldUlISJCQkarx02aRJE0hLS6Nz585Yt2Ez+vXnf3qY+/fuwWnyBLx69arWnmUTpmMn\\\\\\n\",\n              \"BRghv6mysjKEh4fzLg1mZ2fDwsIClpaW6NevH19zVIWFhcHKygrFxcXgcDiYOnUq1q1bhwMHDmKP\\\\\\n\",\n              \"yz6UccXAEWkIbnkZUJoDDqcMRw8fwKhRo6rsaFFUVARXV1ds2bIFvXv3hpaWFry8vKCqqoqZM2di\\\\\\n\",\n              \"xIgRFYabOnrUA6tWrcTZ85fQxcioyhozMzMxfowDFBXkcebMmQohxuVycfPmTRw9ehQBAQEYNGgQ\\\\\\n\",\n              \"HB0d0bdvX4EvY34ZO5GfoAP+GTA5O7+Y7yGovuxDUbYR0tPTBTrbE4RQHTvZL8rQ0LCuSyCk3klP\\\\\\n\",\n              \"T2cnT55kI0eOZPLy8qxLly5szZo1LDIykpWXlwu0rYMHD7LGjRszZWVlZmxszPr3789UVFTYgwcP\\\\\\n\",\n              \"GGOMlZWVsfDwcDZjxgzWv39/9vz5c9alSxd29erVGredk5PD1qxZwxQUFNi0adOYh4cHGzBgAFNR\\\\\\n\",\n              
\"UWFLly5lHz9+ZBEREUxFRYVFx75kRWWs2iU7v5j1MjVj69atZ4wxFh8fz9atW8e0tLRYp06d2J49\\\\\\n\",\n              \"e1hGRobgH+g34HK5LDU1lYmLi9dYd1WLoqIiS0tLq7X6hOnYSQFGyC+My+WymJgYtnHjRmZiYsJk\\\\\\n\",\n              \"ZGSYra0tc3d3Z0lJSd+0zdLSUjZjxgymoKDAtLW1mY2NDbO2tmYcDof5+fkxZWVldufOHd76np6e\\\\\\n\",\n              \"bNSoUYwxxvbt28ccHBz43ld6ejpbsGABk5eXZ4sWLWL3799nc+fOZQoKCqyltjbbsm0H3wf+6NiX\\\\\\n\",\n              \"TF5BgfXv358pKCiw6dOns8jISMblcr/pcxAEl8tl8fHxzNfXl61fv57Z2dkxMTExlpKRLVB4ZecX\\\\\\n\",\n              \"M3FxcVZSUlJrtQrTsbN+jL1CCPlhioqKEBgYiJkzZ0JTUxPW1tZITk7G6tWrkZaWBm9vbzg6OkJd\\\\\\n\",\n              \"XV3gbWdmZqJ///7w8fGBjo4OBg8ejLS0NN6lOUtLS3h6esLW1hahoaEAAElJSRQVFQEAHBwcEBgY\\\\\\n\",\n              \"iOzsbL72p6SkhO3btyMmJgY5OTmwsLCAkpISQkNDkZ6WhnETJvFde+s2bdCmbVvo6OggISEBrq6u\\\\\\n\",\n              \"MDQ0/OH3kkpLSxETE4MTJ05g/vz56NOnD5SUlNClSxfs378feXl5sLe3h7m5Oc57nRFo25e8L8Ks\\\\\\n\",\n              \"d2+Buv//yug5MEJ+AYmJiQgICIC/vz9CQ0NhYGAAS0tLXL16FW3btv0hB+nY2FgMHToUoqKi0NPT\\\\\\n\",\n              \"g5mZGTw8PBAeHl5h1PYBAwbAy8sLI0aMwKlTp9CwYUNegCkqKvJed3Z25nvfzZo1w4EDB/DXX39h\\\\\\n\",\n              \"9erV2L59O3r2MhX4GawJExxxO+zmDxtl/vPnz4iOjkZ0dDSePHmCJ0+e4OXLl9DU1IS+vj709fWx\\\\\\n\",\n              \"aNEi6OvrQ01NrUJbZWVlzF+wAJOnOvP9+zl0wBWLFv71Q2r/FVCAESKEuFwuIiIieL0G4+PjMXjw\\\\\\n\",\n              \"YIwePRrHjx//rodrq+Lr6wtHR0fIysrC1NQU/fr1w5IlS3Dnzp0qn0kyNzeHt7c37OzsMG/ePBQX\\\\\\n\",\n              \"F/NemzhxIv7++2+BAuwLHR0dnDp1CitWrMDH+ASB2yurqPB99vdvjDF8+PCBF1JfAiszMxOdOnWC\\\\\\n\",\n              \"np4eunfvjunTp0NXVxdSUlI1brNPnz4QFxPDrp3bMX/BwhrXP3LoILIyM2BpaSlw/b8qCjBChERu\\\\\\n\",\n              \"bi6CgoIQEBCAwMBAKCsrw9LSEi4uLjA2Nq6V4ZMYY9i4cSNcXFzQuHFjjBw5Ev369YODgwNCQkLQ\\\\\\n\",\n              \"okWLr7bt2bMn/Pz8MHjwYCgqKvJ+PnDgQEyZMgV3795FVlYW8vLyIC0tja5du/I9L1jz5s0R9+at\\\\\\n\",\n              \"wO+nsLCwxnApLi7Gs2fPKpxVRUdHQ1pamndWNWbMGGzbtg3a2trfPAq+qKgoLl++jJ49e4JTVob5\\\\\\n\",\n              
\"fy2q8ndYXl4Ot/37sGvHVoSFhQnNAMc/A30ShNRjb9684Z1lPXz4ED169IClpSVWr14NTU3NWt13\\\\\\n\",\n              \"YWEhHB0dERsbC1FRUcybNw+9e/dG//794eXlVe30IF9069YNrq6uGDduHLy8vDBy5Eg8evQIKioq\\\\\\n\",\n              \"GDhwILqb9ICcnBxycnLw8ME4DB48GLNnz640rFRWVhbCw8MRFhaGsLAwxMTEoFGjRuBwOAId0G9e\\\\\\n\",\n              \"D4G+3v+PopGens4Lqi9f37x5Ax0dHV5YWVtbQ09Pj++HqgWhoaGBu3fvYvTo0Thy6AAcpzjBytoW\\\\\\n\",\n              \"snJyyMvNRUCAH44cOgBVFRXcuXOn2ufxfkcUYITUI2VlZbhz5w4vtPLy8mBhYYHZs2ejb9++fD2b\\\\\\n\",\n              \"9SN8+vQJNjY2UFJSQnp6Onbv3g0TExP07NkT+/btg7m5Od/b6ty5M9TV1TFv3jz4+voiNDQUf87/\\\\\\n\",\n              \"C2PHT6xwqTM7OxueJ45j+PDhmDRpEnR1dREWFobbt2/jw4cPMDY2hqGhIW9U+tLSMgQG+MPK2oav\\\\\\n\",\n              \"OrKzs3HhwjlMmjgRFhYWiI6ORn5+PvT09KCvrw9zc3PMmzcP7du3r/CsWW1r2rQpbt26haioKOzf\\\\\\n\",\n              \"74pRI+yQm5sLaWlpmJiY4Nz/pmchVajrbpC1RZi6gpLfW1paGjt+/DgbMWIEk5OTY0ZGRmzt2rXs\\\\\\n\",\n              \"0aNHAj+b9SOEh4czdXV1Nn78eKakpMSCgoJYRkYGa9u2Ldu9e7fA2/v48SNr2rQpW7lyJWvatBl7\\\\\\n\",\n              \"/S6+2q7ice8/sWYaGqxTp05s+/bt7P79+8zb25sZGBgwERERpqioyJYtW8Y8PDyYgUFnlpVbyFcX\\\\\\n\",\n              \"9Nlz5zFVVVW2atUqdunSJfb+/XteF/qMjAy2bdt21rdvX9a5c2fWo2dPNnPmLBYbG/ujP956T5iO\\\\\\n\",\n              \"nRRghPxkXC6XRUdHsw0bNrDu3bszGRkZZmdnx44ePcqSk5PrtLajR48yZWVlNnv2bKampsYiIiJY\\\\\\n\",\n              \"QUEB6969O1u4cOE3bTMtLY3Jy8szOTk59vzVW77C5mXceyYnJ8emTJnCZGVlmaioKOvatSsLCQnh\\\\\\n\",\n              \"bbe8vJzZ2tqy/gMGsozs/K9uq7CUyzZs2sKaNm3KOnXqxCwtLXkPAhcXF7MZM2YyOTk5NnrMOHbJ\\\\\\n\",\n              \"N4DdvvuQBV0PZUuXr2RqamrM3NycvX379od8vsJAmI6ddAmRkJ+gqKgIN27c4HV1FxcXh6WlJdau\\\\\\n\",\n              \"XQtTU1O+Zu39XsXFxbh79y4yMzPRsGFDtGvXDjo6/8zfxeFwsGjRIvj6+sLBwQH+/v4ICwuDlpYW\\\\\\n\",\n              \"hg0bhpYtW2Lz5s3ftF8xMTHk5+djxEgHaLVsyVebFpqaGDzEEqdOnYKjoyNWrVqFjIwMPHnyBIsX\\\\\\n\",\n              \"L+Z1rigrK4O0tAwMOrbH3D/nY8z4CZCTk+O9pyuBATiw3wWvXr1CYWEh+vXrB0lJSejr6+PgwYPY\\\\\\n\",\n              
\"tWsXpGXl8PRFXKV7XL1MzbB0+Uq47d+Hnj174vr162jXrt03fQakdtBYiHUgPz8fWVlZkJCQgKKi\\\\\\n\",\n              \"IvUq+kUlJCTwAuvWrVvo3Lkzb0T3Nm3a1NpgrP8VHx+Pffv249gxD2hr60C9SRMUFxfjUWQEOnbs\\\\\\n\",\n              \"iAkTJuDEiRMAAE1NTURERODq1atQVVXFtGnT8P79e/j7+/P98GxRUREePHjA63Dx4MEDMMZwLSS0\\\\\\n\",\n              \"yjmvvib6yRMMHtAHOjo6ePbsGZo0aQJ9fX3ePSt9fX00bdoUAHD79m24ubnhypUraKahAXFxcSQm\\\\\\n\",\n              \"JEBbWxszZszA8OHDUVxcjO3bt8PNzQ1mZma4desWevYyw2mv8zWOe+h54jg2rF+Dp0+f/rT7kHWl\\\\\\n\",\n              \"Ph87K6njM8BaU99Og8vKytjFixdZ3759maSkJGvWrBlTUlJiKioqbMmSpezDhw91XSL5ThwOh927\\\\\\n\",\n              \"d48tX76c6enpMUVFRTZ27Fh25swZlpWVVSc1Xb9+nSkrK7M5f85nsS/iKg1L5HHiFNPRacU6dOjA\\\\\\n\",\n              \"bG1tWe/evVl2djZjjLG1a9eyzp07s9zc3Gr3kZ2dzQICAtjixYuZiYkJk5KSYsbGxmzhwoXMz8+P\\\\\\n\",\n              \"paamMlFRUVZYyhVo2KTCUi4TFxdn169fr7GGLzIzM1l0dDSLjIxk8fHxVa6TkpLCJk2axKSkpFj6\\\\\\n\",\n              \"5zy+67GytmEHDhwU7BcghOrbsbM6dAb2E8THx8PCwgKNpWXgPH0mbO2G8S4ZfZky/LTnCSxdugwL\\\\\\n\",\n              \"Fsz/aX+Zk++Xk5NT4dksVVVV3lmWsbHxd03M+L0ePnwICwsLnDp7HqZmvb+6XlFREeztrPDx/Xs8\\\\\\n\",\n              \"ffoUkpKSOHLkCDZt2oTw8PBKI0ikpqbi9u3bvB6Cb968gZGREUxNTWFqaopu3brxRkovLCzE06dP\\\\\\n\",\n              \"0atXL+QWlgr8HuSlJWFsbAwVFRUoKCjUuPB7KXbFipXI+pyNnXtc+K7lekgwli3+C0+ePPml/x+t\\\\\\n\",\n              \"T8fOmlCA1bLk5GSYmJhg2ozZmDtv/lfX+/TpE6wtB2Hc2HFYunTJT6yQCCouLo7XzT0iIgI9e/bk\\\\\\n\",\n              \"zZtV3YO9PxNj7J9hjJauwDD74TWuX1RUhF7du2LLls3gcrlwcnJCWFgYdHR08PHjR15YhYWFIS0t\\\\\\n\",\n              \"DT169ICpqSl69eqFVq1aIT4+Hm/evMGbN2/w9u1b3vdZWVnQ1NRE3Js3iE9MFWgm4ZycHDRvogIf\\\\\\n\",\n              \"Hx9kZ2cjKyurxkVcXBwKCgpQVFSsNugWLVqEYyfPfHUKlqpwuVzoaDZDeHj4L/08Vn05dvKDAqyW\\\\\\n\",\n              \"DbGwgGGXrli+cnWN6yYlJaFXdyNcunQJXbt2/QnVEX6UlpZWeDaroKCAN29W3759a21epu8RHh4O\\\\\\n\",\n              \"x8mT8eTpC77PFk4eP4YTx9wRGxsLZ2dnfPr0CWFhYSgrK4OxsTG0tbWhqKiIkpISvH//nhdSBQUF\\\\\\n\",\n              
\"0NHRgY6ODrS1tXnf6+joQFJSElu2bIGbmxtWrVmHufMW8P0eXPe54OH9cJw9e5av9RljKCgo4Cvo\\\\\\n\",\n              \"QkJCEPX0Be8eGr9Muhni0MGD1U6wKezqy7GTHxRgtejNmzcwMTHB63fxfD8YuWvndrx89hTHjx+v\\\\\\n\",\n              \"5epIddLT03HlyhX4+/sjODgYrVu35l0a1NfXr/eXkMaPH4+Oep0xe+6ffLcpKipCM3VliIqIoGXL\\\\\\n\",\n              \"lpCQkEBRURE+ffoExhhatWpVKaB0dHSgpqZW4fNgjOHatWtYvXo1Hj16BAkJCYiIiEBFVRWxL+L4\\\\\\n\",\n              \"GnqJy+Wis14HHDp4EKampt/0GfxXaWkpXr9+jdjYWMyeMwdhd+7z3Svyi856HeB19iw6duz4Q2qq\\\\\\n\",\n              \"j+rDsZNf1P2tFrm5HcC4CZMEeqp//IRJ6NBGGxkZGbUydI0wYozxxsxr1KgRlJSUfniAMMYQExPD\\\\\\n\",\n              \"O8t68eIF+vXrxxtrUFVV9Yfur7a9ePkSU5xnCtRGUlISbdq0hYK8HHr27FkhpBQVFWv8zDMzM7Fu\\\\\\n\",\n              \"3TocP34ceXl50NXVxZkzZ2BnZ4cuXbpAREQE8/+cg117XKrdFmMMK5YtgZqqKnr16iXQewD+Cb/3\\\\\\n\",\n              \"/7ufFxsby1vevn0LTU1N6OrqQkVZBffv3RUowNLS0pCclFRvLhMTCrBa9eDhA6xe+7dAbRQVFdFJ\\\\\\n\",\n              \"Tx/R0dHo27dvLVUmHPLy8uDpeQpubq74+PEj5OTkkJeXB0VFRUybNh2TJk0U6J7KfxUWFuLGjRvw\\\\\\n\",\n              \"9/dHQEAAJCQkYGlpifXr18PU1FSo51wqLSn5pvqVlZUxd85sDBkyhK/1GWPw9/fHmjVr8OTJEzRq\\\\\\n\",\n              \"1Ahjx47F6tWrK4Q+YwwJCQnIyvJHOYeD9Rs3857X+rfs7GysWbkcd8Nv4+bNmzUGXXJycqWgevHi\\\\\\n\",\n              \"BZSUlKCrqwtdXV1YWlpiyZIlaNu2Le+PSR8fH2zeshWjxozl+7M5fuwobG3tICMjw3cbUrsowGpR\\\\\\n\",\n              \"QX4+GjUS/JkRcXFxnDx5Eh8/foSSklKFRU5O7ptHvxYmd+/ehZ2dHbqb9MCW7bvQ27wPREREwBjD\\\\\\n\",\n              \"g/v3ceiAKzZu3IATJ07AwsKC7+1++vSJ92xWWFgYDA0NYWlpiZCQELRu3breXxrkl5KSEpKSEqFv\\\\\\n\",\n              \"YCBQu8TEhAojx39Neno6Vq5cidOnT6OgoACdO3eGr68vhgwZUukzDA8Px/Pnz7FgwQIsXboUc+bM\\\\\\n\",\n              \"QRvtFrAcaoVhw0dCRlYWuTk5CPDzhffF87C0tMTt27chKyvL20ZWVlaFkPqyiIuL84LKxMQETk5O\\\\\\n\",\n              \"6NChQ40h82V8yZs3rsO8T81/KGZmZuKQ235cvny5xnXJz0MBVotkZGSQnf1Z4HbZ2dlQVlLEnTt3\\\\\\n\",\n              \"kJGRUWHJy8uDvLx8pWD7sigqKlb6mYyMjFAdmO/fvw9ra2u4HzuJAQMHVXhNREQExt27w7h7dzy4\\\\\\n\",\n              
\"fx8j7W3g7u7+1RArLy/Hw4cPeZcGk5KSMHjwYIwfPx6enp5VngX8CoYMGQL3w4cwxIL/uaOiHj9G\\\\\\n\",\n              \"Xm7uVzsoMMZw8eJFrF+/Hk+fPoWcnBymTJmClStXfnX+MT8/Pzg6OsLAwAAmJiaQkZHBsWPHUF5e\\\\\\n\",\n              \"joyMDBw6sJ83cG3PHj3x6NEjZGZmwtvbu0JQfbkk+WUZPnw4OnTowPf0K/8lJiYGDw8PjB49Ghcu\\\\\\n\",\n              \"+cGomk5Tnz9/hr3NUDg4jIKhoeE37Y/UDgqwWmRu3gc+l7zRp28/vtskJCTg+bNYtGvbBra2tujf\\\\\\n\",\n              \"v3+Fe2gcDgdZWVmVgi0jIwMJCQl48uRJpZ8XFxdXGWzVLVJSUnUSeqWlpbC3t8ch92OVwuu/uhkb\\\\\\n\",\n              \"4+z5SxhmY4m3b9/ywignJwfXrl2Dv78/rly5AnV1dVhaWsLNzQ3dunWr02ezfoa4uDi4ubkhKSkZ\\\\\\n\",\n              \"SUlJaNKkCV/tDh1whZOTc6XPJyUlBcuXL8e5c+dQWFgIY2NjhISEoE+fPtVuz8PDA0uXLkVAQAC2\\\\\\n\",\n              \"bdvGm5W5tLQUgYGBOHPmTIUzq2PHPLBp00a0adOGF1Rz586Frq4uNDQ0fvi/x759++LIkSOws7bA\\\\\\n\",\n              \"+ImOmOo0DZr/6h6fl5eHM6c8sWf3DlgNtcKWLd82lBapPdQLsRYlJSVBV1cXL9984Pu6+fq1q5EQ\\\\\\n\",\n              \"/xGdOxvg0qVLiIqKwsCBA2FrawsLC4tvuv5eUlKCzMzMKkOvqiU9PR0ABAo8JSWlHzKe37lz5+Dq\\\\\\n\",\n              \"dgBXg2/w3Wb82FForaMNBQUF+Pv7IzIyEr169eI9m9W8efPvrktYBAQEYMyYMWAAVFVUoaysgoBr\\\\\\n\",\n              \"wTV2JPL1uYx5c2biyZMnUFZWBmMMZ8+exYYNG/D8+XMoKSlhypQpWLZsWY1DKTHGsHXrVri5uSEw\\\\\\n\",\n              \"MBASEhKYMmUKFBQU8Mcff+Du3btISEhAq1atoKuri44dO/ICS1tb+6cPrfbu3Tvs3++KEyeOQ1tb\\\\\\n\",\n              \"B8oqKigsLETU40cw79MHM2fMqDGsfyX14djJLwqwWjZu3DhINJSCi+uBGv+CfPH8OQb2641bt27x\\\\\\n\",\n              \"Bg1NT0+Hr68vvL29cfv2bfTq1Qu2trawsrL65ssn/CgsLOQ78L4sEhISAgWegoICxMXFK+y3d+/e\\\\\\n\",\n              \"cJo+C3bD7Pmu9c6d2xhmbQkHBwdYWlqiT58+9fLZrNrEGMP27duxYcMG9Oxliu0790CjeXNMmjAW\\\\\\n\",\n              \"KcnJOOJxosog53A4OHbUHevXrkJAQADU1dWxZMkSeHt7o6SkBL169cLGjRsrTTBZ1f6Tk5MRExOD\\\\\\n\",\n              \"TZs2ISYmBi1atEBcXBykpaVRVlYGdXV1DBgwADExMbC2tsbs2bNr6+P4Jl/GcMzOzoaUlBR0dXX5\\\\\\n\",\n              \"Pnv9ldSXYyc/KMBqWV5eHnr06AGjrsbYs8/1q39dPomKgr3tUGzatBnjxlXdMyo3NxeBgYG4dOkS\\\\\\n\",\n              
\"rl27Bj09PdjZ2cHW1rbOzzIYY8jLyxMo8LKysiAtLV0h1K5du4bk9M8CBRBjDCoKMkhISKhw4/93\\\\\\n\",\n              \"UVJSgunTpyMwMBBmvfvg6PGTvMuAXC4XG/9eB7f9LjDp0ROjxoyDmpo6SkqKced2GI57uEOrZUsM\\\\\\n\",\n              \"HDAAnp6eeP36NdTU1DB9+nQsXLiwyjO3z58/V7g/9aUXYIMGDdCgQQOIiopi/vz5SE1NRWBgIDjl\\\\\\n\",\n              \"5ejQQRcSEg2RlJyIh/cfYNKkSVi8eBF1Sa+H6suxkx8UYLXsy1+xBQUFyMvLw+Spzhg1Zhya/G9E\\\\\\n\",\n              \"8Pv37uLQAVeE37mNAwcOYvhw/s48iouLERwcjEuXLsHX1xeampq8MBOWKR+4XC7S0tLw+vVrxMXF\\\\\\n\",\n              \"4e3bt9i8eTMKSsoFvt+hpaGOR48e/XZ/MaekpMDOzg6SkpJ48eIFXsS9r/JSbkFBAc6dPQM/38u8\\\\\\n\",\n              \"mRBiY5+idatWePLkCTgcDvr06YPNmzfD4H89FwsKCvDixYsKIVVVhwpdXV1oaWlh+vTp+OOPP+Dm\\\\\\n\",\n              \"5oZRo0ZBoqEk/lq0BKZmvSv8Pj+8f49DB91w6uRxXLx4ET179vxpnxepWX05dvKDAuw7vX37FseO\\\\\\n\",\n              \"HcenT/Eo53LRRL0Jxo4dg44dO4IxBicnJ2RmZuLChQuIjo6Gq6sbfHwuIzMzExISEmjfvj2cnadh\\\\\\n\",\n              \"9OhR33zZi8Ph4Pbt2/D29salS5cgLS0NW1tb2NnZwdDQ8Kd3xmCMITc3FykpKUhOTq72a05ODpSV\\\\\\n\",\n              \"laGmpgZ1dXWEhITgQ0KKQL0Dy8vLoSwvjbS0tF9+qot/e/z4MWxsbDBp0iSEhoail5k5Vqxaw3f7\\\\\\n\",\n              \"zRs3wGXPLvz11wIMHjwYL1++rHBmlZSUVKFDxZelefPmFf5NZWRkwMLCAh06dMDevXsxePBg6HbS\\\\\\n\",\n              \"x45de6p95CMkOAiOE8YiODgYenp63/NRkB+IAqweqO1fwpMnT7Bk6VI8iozE6LHj0aGDLkRERPDm\\\\\\n\",\n              \"TRw8TxyDtrY2DA0NERwcjHv37kFaWrpCe8ZYrQQLl8tFZGQkvL294e3tjeLiYl6Y9ezZ87t64HE4\\\\\\n\",\n              \"HKSnp1cZRv/9WYMGDXih9N+v//5eUVGxQk1DraxgYWmNiY6T+a4rMMAfm/5ei4iIiG9+b/VBfHw8\\\\\\n\",\n              \"Dhw4iIsXLyAjIwPi4uJo3bo1nJycYG9vX+FynpeXF2bNmoW9e/ciJCQEnp6eePnmA9TV1fneX1pa\\\\\\n\",\n              \"GnQ0m0FUVJQ3QsW/Fx0dnRo7VHz8+JHXyWjjxo1wczuAyz4+uOwXyNfzikePHMa5s6cQGhrKd92k\\\\\\n\",\n              \"dglTgNF8YN8gJCSEKSkpMZf9B1hWbmGleYNyC0vZcc8zTFpamu3evbvW6qgJl8tlsbGxbP369czA\\\\\\n\",\n              \"wIApKyuzyZMnM39/f1ZcXMxbLz8/n8XFxbGwsDB27tw5tmfPHrZ06VI2ceJENmjQIKanp8dUVVWZ\\\\\\n\",\n              
\"mJgYU1FRYXp6emzgwIFs4sSJbOnSpWzPnj3s3LlzLCwsjMXFxbG8vLxvrvnKlStM38BAoLmj+vbr\\\\\\n\",\n              \"zzw8PH7AJ1Y3iouLmaPjZKagoMBmzp7L7j18zD4lp7O3HxOZ14VLrF//AUxFRYV5e3uz8vJytnz5\\\\\\n\",\n              \"ctaiRQt27do1ZmxszKysrFjDhg0FmmvryyIlJcVSU1O/qe6nT5+yZs2asV27djHG/vn3pqury64G\\\\\\n\",\n              \"3+B7/zkFJUxNTY09e/bsR36k5DvQfGD1QG39FfHixQuYmZnh1Nnz6GVqVu26sU+fwmJQvzq5zs/l\\\\\\n\",\n              \"cpGRkVHhzOjFixe4d+8eXrx4gc+fP0NSUhIcDgcAKp0ZVfVVRUWlxr/Ic3Jy8ODBA+Tm5qJRo0Yw\\\\\\n\",\n              \"MDCoNJ9UTXV37NgRM2bNxeSpTjWu73P5EiZPHIcBAwbA1dVVoH3VB6WlpbAcOhRSjRrjyNHjX70E\\\\\\n\",\n              \"GhkRATtrSzRoIArGGFq0aIGoqChISEigpKQEXC4XeUVlAp/Vy0tLIjMzE1JSUgK1Cw8Ph52dHXbt\\\\\\n\",\n              \"2oXRo0cD+OcB9AkTJwo0Aj4ArF29EsWFBdi1a6dANZDaIUxnYPQgs4A2bdqEOfMW1BheAKDbsSM2\\\\\\n\",\n              \"bdmONWvWICQk5Ifsv6ioCCkpKV+9r/Tl+7S0NMjKylYKIVtbW0yfPv1/N/FjeVO+d+jQgdc9/1sG\\\\\\n\",\n              \"EY6NjYWLyz6cO+eFTnr6UFBQQH5+PiIjHqJf//6YPWsWX6OKi4qKwsfHB6ampuByuZji5PzVg+GF\\\\\\n\",\n              \"816Y4TwVV69eRWBgxiVVaQAAIABJREFUIDp16oStW7diwoQJldowxhAREQFXV1fcu3cP+fn5kJGR\\\\\\n\",\n              \"gZlZb8ycOaPWRxcvLy9HZmYm0tLSkJ6ezvt69qwXJBs1wqkz56q9vNvFyAjXQm7CtIcxGOPi4cOH\\\\\\n\",\n              \"kJOTQ/v27aGqqorr16/jxfPnaN+hA981vX71CjIyMpCUlBTovXwZXcPT0xMDBw7k/TwuLg6dO3cR\\\\\\n\",\n              \"OEQ7G3bBiWPuArUhBKAAE0hGRgZ8fX3xbNsuvtsMGz4Cy5YsxMuXL9G2bdsq12H/G239a/eT/v21\\\\\\n\",\n              \"qKgIampqle4nGRkZVQgqVVXVGgdztbOzA/DPWVNAQAC8vb0xb948GBoawtbWFjY2NtDQ0KjxPR4/\\\\\\n\",\n              \"fgILF/6F6TNn43HM8wr3YXJycnDa8yTGjRuHcePGY/36dTUe4HR0dBAWFgZbW1scPuQGJ+cZsLa1\\\\\\n\",\n              \"4w3mezUwAHv37EROdja0tLQQEBCATZs2Yfjw4Zg8eTJOnz6NQ4cOQVNTE8A/D6qOHj0aqWlpmOo8\\\\\\n\",\n              \"HXPnL0JjaWnkZGfD57I3Bg0ahHbt2sHT05PvMzgul4vPnz9XCqS0tLQqf5adnQ05OTkoKytDRUUF\\\\\\n\",\n              \"KioqkJWVRXT0EzyJfcnXvcl27dtj/l8L4bpvL3bu3Ilnz57B19cXycnJaNeuHdz2u8DF9QBf9QPA\\\\\\n\",\n              
\"4UMH4Og4WaDA8fDwwLJlyxAQEFBpzrqysrJKz/XxQ1xcHKWlgs/WTAhdQhTAoUOHcP1mKI6fPC1Q\\\\\\n\",\n              \"uyWLFiAlKRHm5uZVhlJqaiqkpKRqvISnrq4OeXn5Wu1VWFRUhKCgIHh7e8Pf3x/a2tqws7ODnZ0d\\\\\\n\",\n              \"WrduXWn98+cvYN68P+F/JRhtq+m+n5aWBmuLQbC1tcWqVSv5qoXL5eLGjRvYt38/boWGIjc3F40b\\\\\\n\",\n              \"N0YnPT08f/YMiYmJyM/Ph4mJCebPn49p06ahrKwMO3fuxLZt27By5UoMGDAAffv2xV+LlmLajJlV\\\\\\n\",\n              \"diwoKyvDlk1/49TJkzhx4jhERUVrDKTMzExIS0tXCKQv3//3q4qKCho1aoT4+Hi8fv36/9q787ia\\\\\\n\",\n              \"8v8P4O/KrZTq1r23ui3aVFSkspSiGEkk61jLliX7rjHGMCg0ljF2RoiZMdZskWTJEsLI2kibJakk\\\\\\n\",\n              \"dNvvff3+mK9+07Tdqxoun+fjcR+j0/mcz+cc4/O655zP+Zzyz/nz58ncohkdOX5S6r+frKwssrYw\\\\\\n\",\n              \"IRsbG+rXrx/5+PjQn3/+SfPnz6c3b9/S/cTHUr36JTs7mxxatqCbN29K9SwW/jG7RlRUFFlbW1da\\\\\\n\",\n              \"5/jx4/TjqlUUFX1O6v0hItq2ZTPFX7tC4eHhMpVjGga7hPiZevEik8zNLWQuZ2FhSSeOHSUNDQ0S\\\\\\n\",\n              \"CoXUqlUr8vLyqhBOsrwzrCE1btyYevfuTb1796bS0lK6cOECHT58mDw8PEhbW7v8WTMHBwcqLCyk\\\\\\n\",\n              \"CRMC6Vjk6RrDi4hIV1eXDh09Qe0cW9GQIYPJ0tKy1rYoKipS165dqWvXynNJurq6UlRUFPXu3ZtO\\\\\\n\",\n              \"njxJbm5uZGRkRD4+PhQUFES9e/emUaNG0cKFCylkeSiNHlP9/TQOh0Pfff8DAQrk27s3Nbe2Jj09\\\\\\n\",\n              \"vfIAMjc3J2dn5wqBxOfzK51tSCQSevr0aXlAXblypfzPz58/p6ZNm5KVlRVZWVmRk5MTPXr0iPxH\\\\\\n\",\n              \"jKr1OPz7OLq5daTJkyeRkpIS+fn5EZfLpf3791N09Bnq59uTjp08XeNrZvLy8si3Z3cKDJwgVXhJ\\\\\\n\",\n              \"JBKaPXs2RUdH0+XLlyu9xbikpIROnTpF4eHhdP3qVUpLTa0wp2Btdu/aQQsWfCf1+gzzHgswGSgp\\\\\\n\",\n              \"KZJYLJa5nFgspq5dPWnTpo0N0KqGw+FwygNk3bp1dP36dTp06BANHDiQysrKyNLSktq2a08Ojo5S\\\\\\n\",\n              \"bU8oFNLwkaNp8+YttGrVSpnbU1BQUH4W1L59e/r+++/pr7/+oqysLHJwcKB+/fqRhYUFiUQiysrK\\\\\\n\",\n              \"IgBka9eyxvD6p+++X0gH9u+lFStWVHu/DgDl5OTQ9evXK5xNPXr0iJKTk0lHR6c8pKysrMjT05Os\\\\\\n\",\n              \"rKzIzMysUuDt3bv3g+436vB4NHfuXOJwOLRs2TLy8fEhBQUFcnFxoXfv3pFHRxdasHAx9e7Tt8Jl\\\\\\n\",\n              
\"5JKSEjp6JIIWL1pA2VlZ1LZt1bPO/1NJSQmNHj2a0tLSKDY2tnzWeYlEQhcvXqTffvuNDh48SDY2\\\\\\n\",\n              \"NjR06FDS1dWjrVs2UcjyUKn25UZ8PL18mSn1+8cY5p9YgMnA1NSUfvt9r8zl7t29Q5bNZD9z+5Qo\\\\\\n\",\n              \"KiqSs7MzOTs704oVK+jevXvUu3cfWvPzepm2M3ZcILm5tKVly0JIIpFQdna2VPeQsrKySCKRlJ8F\\\\\\n\",\n              \"8Xg8SkxMpOTkZLKwsCA7OztycnKizZs30+HDh8nR0ZF69OhBY8ZPlLptCgoKNG78RNq4cSM5OTlR\\\\\\n\",\n              \"UlJShYD666+/6NGjRwSArK2ty0Nq4MCBZGVlRZaWljI9SK2mrk4ikUim40f09+U/T09PWrt2bYV7\\\\\\n\",\n              \"ZwoKCrRy5Y/UqVNHWrt2Lc2dNZ26eXmXv28r+vQpsra2phXLl5ORkRH17NmTeDxetSNk8/PzacCA\\\\\\n\",\n              \"AaSsrEynT5+mxo0b0+3bt+m3336j33//nXR0dGjo0KEVLkOmp6dTu3btqHOXr8izm1eV230vJyeH\\\\\\n\",\n              \"Akb504IF33/2bwhgGga7ByYDkUhETZs2pSvXbpLJ/wYI1Obt27dkaWZMDx48qHTpRd5pampSUupT\\\\\\n\",\n              \"mecfFAq0qbS0lEpKSmq8h/TvZU2aNKlw/2/q1KnE5XJp8eLF5cvWrVtHGzdupOjoaLK0tKSs3Lcy\\\\\\n\",\n              \"DSzIyckhCxNDUlRUpGbNmlU4m3r/4fP59XIf8rvvFtDrN29p1Zq1UpcpLCwkK/OmdO3aNTI3N69x\\\\\\n\",\n              \"3YcPH9KlS5f+/31bbm5kY2NT/vvo6Gjy8/OjM2fOVBqF+X52DTs7O/rmm29o37595S+vHDp0KA0d\\\\\\n\",\n              \"OpTs7OyqrPfSpUvUr18/WrQ4mPyGj6hyMNHNGzdo9Eg/GtB/AAUHy/bWcqZhydM9MPYgs4ymTp2G\\\\\\n\",\n              \"SVOmSf2g5tKQFdDT04OzszNiYmIapE0fi7KycpUPctf2adq0Kf78809IJJI61X/nzh0YGBigtLS0\\\\\\n\",\n              \"wvJZs2ahffv20NPTk7ltBSUSKCgooKSkpE5tk0Z6ejq0dXSQk5cvdfu2bd+J7t7e9daG33//HUZG\\\\\\n\",\n              \"RkhLSytflpaWBgsLC3h6esLFxQV8Ph8TJ07EpUuXIBaLpdpuQkICPDw8oK+vj2++/Q77Dkbg8NET\\\\\\n\",\n              \"+Hn9JrRt2w4mJib45Zft9bYfTP2RpweZWYDJ6MWLFzAxMcGmLb/U2tkcOHwUenp6ePDgAX799VdY\\\\\\n\",\n              \"WFigS5cuiIuLa5C2/dcMDAxw72GSTAGRl18EdXV1vH79ul7a4OLigiNHjlRYJhaL4evrCzU1NZkD\\\\\\n\",\n              \"LPdtAZSVlescrrWRSCTYv38/tLW1MWfuN1K1LTMnD0bGxli2bFm9tuXnn3+GlZUVUlJSEBISAhUV\\\\\\n\",\n              \"FaiqqmLo0KE4ceJEncL8/v37mDFjJnx69UI3Ly/4+/vj2LFjKCsrq8c9YOoTC7BPQEP+JSQmJsLU\\\\\\n\",\n              
\"1BTjJ0zC/cTHlTqax2nPMPebb6Gvr4+rV6+WlyspKcHWrVthZGSEXr164fbt2w3Wxv/CpEmTMW/+\\\\\\n\",\n              \"ApkCYufu39CmTZt6C4gdO3agZ8+elZYXFBRAS0sLl6/ekKl9h44cR9t27eqlbdV58uQJfH190bx5\\\\\\n\",\n              \"c/j6+qKJhgaWha6ssV0ZWblwdeuIXr16QV9fH/PmzUNxcXGd21JcXIyjR4+iRYsWUFRUBIfDwaRJ\\\\\\n\",\n              \"k5Cfn18Pe8rII3kKsNpn22Qqsba2pqtXr5KGuhq5uzmTj3c3mjt7JgXNmUVf9+tNbR1a0rs3eRQX\\\\\\n\",\n              \"F0ft27cvL8fhcGjs2LGUlJREXbp0IS8vLxo8eDA9evToI+7Nh5s4cQLt2L5N6odQAdCGdWspIyOD\\\\\\n\",\n              \"HB0dKTw8vM4PsA4cOJDi4uLoyZMnFZY3btyYpk6dSj+vlW16oq2bN9LECRPq1KbqiMViWr9+PTk6\\\\\\n\",\n              \"OlKrVq2oRYsWlJWVRVaWlrQiJJjc3VzoSMTh8um9iIhevnxJoctDqJ1jK2rbpg0dPnyYEhIS6M6d\\\\\\n\",\n              \"O+Tq6vpB/+9IJBKKjY2l8ePHk4GBAYWGhpK7uzs1atSIWrZsSatXr/7iXgjKyKmPnaAN5b/6FlFY\\\\\\n\",\n              \"WIh9+/Zh5cqVCA0NxZ49e/D27Vupyr579w7BwcHg8/kYPXp0hfsQ8qK9szMGDx0m1eS7y0NXwc7O\\\\\\n\",\n              \"DkVFRYiMjISnpyeEQiGWLl2K7OzsD27D5MmT8f3331da/vLlS2hpaeHm7XtSnX3FnL8EPp+PgoKC\\\\\\n\",\n              \"uhySKt29exfOzs5wc3PDpUuX4OLiAi8vL+jq6sLIyAgtWrTA7Nmz0cHVFTo6OmhhYwOLZs3A5XIR\\\\\\n\",\n              \"EDAGN2/erLA9iUSCjRs3gs/nY9u2bVWe0cbHx2PkyFEwNzcHn8+HkZERWtnbQ1dXF7a2tli2bBlS\\\\\\n\",\n              \"U1MRFhYGfX19XL58GX369MGQIUOkvtfFfH7k6QyMBdgnIDc3F99++y10dHQwefJkvHjx4mM3qVZi\\\\\\n\",\n              \"sRjz5s2DiYkJWrdujaHD/PEsM6fKYHj1RoR58xfA1NQU6enpFbZz584djB49GlwuF+PHj8fDhw9l\\\\\\n\",\n              \"bsudO3dgaGhYaTAHAISH74ahkRH+vPOgxvCKvXwNenp6OHny5Acfk6oUFhZi/vz54PP52Lx5MxIT\\\\\\n\",\n              \"E9GsWTO4u7tDWVkZfD4fu3btqnBPKCMjA3fu3EFiYmKtM/vfv38f9vb26Nu3L3JycgAAT58+haur\\\\\\n\",\n              \"K0xMTbEkeBkS7iUi7Vkm7j54hBU/roK5hQWcnJzw6NEjLF++HCYmJkhMTATw96XXjh07YurUqQ1+\\\\\\n\",\n              \"H5D5NMlT38kC7BPy8uVLTJ8+HTo6OggKCsKrV68+dpOq9O7dO/Tp0wcdO3ZEVlYWRCIRAgLGQEtL\\\\\\n\",\n              \"C/7DR+LQkeM4F3sFx0+expRpM8Dj8eDTq1eNwZyZmYmFCxdCV1cXPXr0QHR0tEwdqIuLC44ePVrl\\\\\\n\",\n              
\"73bs2Akul4sJk6ZUGnRy48+7GDFqNHR4vGrLf6jz58/DysoK/fv3x/Pnz3Hp0t9neEKhEEpKSpgz\\\\\\n\",\n              \"Zw4KCwvrXE9RURFmzZoFQ0ND7NmzB4aGRli8dBnyi8qqDGtRsRg//bweWlpcWFpa4tmzZxW29/r1\\\\\\n\",\n              \"a7Rq1QrBwcF1bhsjf+Sp72QB9gl68uQJxo4dCx6Phx9++EHqS5L/hfT0dNjb22PUqFGVBhFkZWVh\\\\\\n\",\n              \"2bLl6OrpiXbt26NLly6YOzcIKSkpUm+/oKAA27Ztg42NDVq2bImwsLAK7y6rTnWDOd5LTk6Gp2c3\\\\\\n\",\n              \"qKmpwbp5c7Rt2w7NLC0hFAphb98ac+bMkbqNtcnNzUVAQACMjIwQEREBAAgLC0Pjxo2hqKgIc3Nz\\\\\\n\",\n              \"pKam1lt9wN9fKubOnQt19SZYumyFVJdMN2zaAnNz8ypHGWZkZMDMzAzbtm2r13Yynz556jtZgH3C\\\\\\n\",\n              \"kpKSMGzYMOjq6mLlypUNcm9GFleuXIFQKMTKlSv/k2Hmp06dgpeXF/T19bF48WJkZWVVu75IJIKO\\\\\\n\",\n              \"jk6lS5T/NmPGDDg4OODcuXO4d+8eSkpK8OjRI/D5/DoP7ZdIJNi7dy+EQiEmTZqEN2/eQCQSwdvb\\\\\\n\",\n              \"GwoKClBTU8PkyZPrbQj5+xGEgwcPhqamJtq1a4cWLWxkehmoW8dO2L9/f5Xbf/ToEfT19ctDmPky\\\\\\n\",\n              \"yFPfyQJMDty9exd9+vSBoaEhNm3aVC/Dp2W1e/du8Pl8HDt27D+v+969exgzZgy4XC7Gjh1b7dt7\\\\\\n\",\n              \"J0+ejIULF9a4LbFYjIEDB2LgwIEVBiqMHDmy1rI1SU9PR8+ePWFra4srV66gtLQUW7Zsgbq6Oho3\\\\\\n\",\n              \"bgwtLa1qg0IWYrEYsbGxGD9+PHg8HlxdXbFhwwZkZWWhT9++WLdhs0yPDeza8zs6d+5cbX3x8fEQ\\\\\\n\",\n              \"CASIjY2tc9sZ+SBPfScLMDly/fp1dOvWDWZmZpVu/DeU94M1zMzMcPfu3QavryYvX77EDz/8AD09\\\\\\n\",\n              \"PXh5eSEqKqrCmeCdO3dgZGRU5WCOfyosLISbm1uFy4bJycng8Xgy33csKyvDTz/9BB6PhyVLlqCo\\\\\\n\",\n              \"qAiHDx+GtbU1tLW1oaenV+djJ5FIkJCQgLlz58LY2Bi2trYICQmpdBlSVVUVL7JfyxRgb0TFUFZW\\\\\\n\",\n              \"rvEybXR0NHR1dZGQkPDB+8DID3nqO1mAyaHz58/D1dUVLVq0wP79+xtsyPO/B2t8KgoLCxEWFgY7\\\\\\n\",\n              \"OzvY2tril19+KR8M4ezsLNVgjFevXsHa2hrr168vXzZmzBjMnz9f6nYkJCSgXbt26NSpExITExEb\\\\\\n\",\n              \"GwsXFxc0b94cZmZmEAqF8PT0/ODBOKmpqQgJCYGtrS2MjY0RFBRUbYgUFRWhUaNGMl0+fP/h8/l4\\\\\\n\",\n              \"+fJljW3Zu3cvDA0NZbqfycgneeo7WYDJKYlEgsjISDg4OMDR0RGRkZFS3Ze6fv06xo8PhFf37vDs\\\\\\n\",\n              
\"1g0jR45EdHR0pRCsabDGp0IikSA6Ohre3t7Q09PDwoULsWbNGvj4+EhVPiUlBUKhsHwqqtTUVOjo\\\\\\n\",\n              \"6NT6TFpBQQHmzZsHgUCAbdu2ISEhAT4+PjAxMUFISAiEQiG4XC7mzp0r81lyVlYWNmzYgA4dOoDH\\\\\\n\",\n              \"4yEwMBCxsbG1fklJSkqCoqIi3oiKZQqvghIJGjduXOtwfQBYt24dLC0taw07Rr7JU9/JAkzOicVi\\\\\\n\",\n              \"7N+/Hy1atICbmxsuXLhQ5XqxsbFwcnKCqZkZFgcvw6EjxxFxLBKrf1qHlq1awdLSEvv2/X2P5r8c\\\\\\n\",\n              \"rFFfHjx4gHHjxkFLSwsqKiqIjo6Wqlx8fDz4fD6uXbsGAAgMDERQUFC168fExKBZs2YYOHDg/x4U\\\\\\n\",\n              \"HgldXV2sWbMGJ0+ehIaGBjQ0NLB3716p2/7u3Tvs2bMHPXr0gKamJoYMGYJjx45V+8WhqKgIx48f\\\\\\n\",\n              \"x4gRI2BlZQVlZWUQEbS0tBBxLFKmADt74TKaNWsm9d/zd999Bycnp09qZCxTv+Sp72QB9pkoKyvD\\\\\\n\",\n              \"rl27YGZmhm7duuH69evlvzt8OAICgQC//XEAomJxld/Co89egLGxMYYPHw6BQPBRBmvUh+zsbLRv\\\\\\n\",\n              \"3x5NmjSBp6cnTp48WWvnfPToUQiFQiQnJ+PJkyfQ0dGpdJaRk5ODUaNGwdjYGHv27MGsWbOgo6OD\\\\\\n\",\n              \"+fPnIy8vr3yYvFAolGqOy+LiYhw7dgxDhgyBpqYmvL29sWfPnirPhNLT07Fy5Up07doVfD4fCgoK\\\\\\n\",\n              \"UFJSQtOmTfH111/j119/hUgkwvbt29Gjp49MATZ4yDCsWrVa6uMrkUgwbtw4dO3aVarHGxj5I099\\\\\\n\",\n              \"Jwuwz0xxcTE2bdoEAwMD9OnTB/v37wefz5dqUtu/ktPB4/Gxbt26j70bdZKQkABDQ0Ns374d9vb2\\\\\\n\",\n              \"aNGiBbZu3VrjYwgbN26ElZUVcnJyMHnyZMycORPnzp3D1KlT4e7uAW0dHXTo0AFz5swpv7SXkZEB\\\\\\n\",\n              \"iUSCoKAgqKqqokOHDuWzYVTl/QjCwMBA8Pl8dOjQoXwE4XuFhYWIiYnBhAkTYGtrCxUVlfIh+I6O\\\\\\n\",\n              \"jpg5cybi4+OrvKQoEokgEAgQc/6SVOEVd/0WtLW1kZubK9PxLSsrQ9++fTFo0CA25dRnSJ76ThZg\\\\\\n\",\n              \"n6mCggKsXLkSXK42loeukvob+cGIY2jTps3Hbn6dOTs749ixY5BIJIiJiYGPjw8EAgEWLFhQ7Ywg\\\\\\n\",\n              \"QUFBcHV1xZo1a6CpqQlr6+b4YUkwtm3fifUbt6DfgK+hpq6Ovn374sWLFyguLoavry+UlZUxYcKE\\\\\\n\",\n              \"akc/JiQkICgoCE2bNoWNjQ2Cg4ORkpICiUSCtLQ0bNq0CT179oSenh4UFRWhqKgIPT09eHt7Y8OG\\\\\\n\",\n              \"DcjIyJB6v6OioqCrq4vzF+Nq/Hu+duM2DAwMyi8by6qwsBCdOnXC5MmTK53hSiQSvH37Fq9evWKv\\\\\\n\",\n              
\"TZFD8tR3sgD7jGVmZoLL5co0tDq/qAwmpqYVLkHKo7CwMPTq1avCssTERAQGBoLL5WLkyJGVRvSJ\\\\\\n\",\n              \"xWLY29vDuGlTRJ05V+WIvmeZOZg99xuYmJiU338KCwurVP/7EYR2dnYwNjbG3Llzce3aNZw7dw4z\\\\\\n\",\n              \"ZsxA69atoaqqCiUlJXA4HDRv3hxjxozBqVOnIBKJ6rTvx48f/3uC6ICxuBr/Z4X230q4j8CJk8Hj\\\\\\n\",\n              \"8fD779Lfp6vK+ymnli5dCuDvS53z5n0LXV1dqKurQ0tLC6qqqhgyZAguXrwoN/dTv3Ty1HeyAPuM\\\\\\n\",\n              \"bd26FYOHDJN5WPW3332P2bPrb2qljyE/Px9aWlqYPXs2vLp3h6ubG7p7e2PZsuVITExEcHAwDAwM\\\\\\n\",\n              \"8NVXX+HEiRMQi8XYtGkzLC2t8CQjq9ZjtGrNz1BXV8f58+fL68zOzsaGDRvg+r8Z5YcMGYKgoCD0\\\\\\n\",\n              \"7dsXBgYGUFRUhJKSEjQ1NeHm5obg4GDcuXOnQS7DvXjxAkuWLIWxsTFMTE1h37o1zMzNIRQKsWDB\\\\\\n\",\n              \"93j69Gm91JORkQFTU1P06dMHOjo6mDRlGm7ffVh+nDKycvHjqp9gaWWFrp6eyMvLq5d6mYYjT32n\\\\\\n\",\n              \"AgB87Fe6NIQ2bdrQjRs3PnYzPqqQkBB6/eYdLQleJlO5nWHb6Y/f99DEiRNJXV2d1NXVSU1NrcJ/\\\\\\n\",\n              \"1dXVSVVVlRQVP71Xyr1+/ZqmTJlCEUeOUP8BX5Nv776kpaVFeXl5dDTiMB09cpj69etPK1f+SJGR\\\\\\n\",\n              \"kbR69WrKz8+nFy8y6fzFK2RrZydVPX5DBlK7tm3IyMiIdu3aRZcuXSJTU1MqKSmh58+fU1lZGUkk\\\\\\n\",\n              \"EjIwMCAXFxfq3bs3de7cmYRCYQMfgf9XVlZGqamp9PbtW9LQ0CAzMzPicDj1WseECRPo7NlzFBVz\\\\\\n\",\n              \"nvT19atcRywW06wZ0+hm/DU6f/48e9/YJ0ye+s5GH7sBTMPhcDhUWloqc7mSkhK6Hh9PNwMCSFlZ\\\\\\n\",\n              \"mZSUlEhRUZEAkFgsptLSUioqKqKSkhJq3LhxhVCrKuhkXfb+z2pqajIHZHZ2Nnl4eFAnjy6UnPaM\\\\\\n\",\n              \"tLS0Kvzep5cvLf9xFc2bO5u6detGMTExpKenRwMGDCDr5s2lDi8iogmTplJvH28qLS2hsrIy4nA4\\\\\\n\",\n              \"lJycTC1atKBJkyaRt7c3tW/f/qN21o0aNSJLS8sG235UVBRFnT5NsZevEZ/Pr3Y9JSUlWrN2HY0L\\\\\\n\",\n              \"GEWzZ8+hTZs2NlibmC8HC7DPmKWlJR09dlzmcrduxtOc2bOpf//+9OLFC8rIyCj//PPnFy9ekLKy\\\\\\n\",\n              \"MvF4PBIIBKSjo0NcLpc0NTVJQ0ODGjduTKqqqqSsrEzFxcUkEokoOzub0tLSqKCggEQiEYlEovI/\\\\\\n\",\n              \"/3tZYWEhqaioSB2EampqtHfvXurVuy+FLA+tdv90dHRo87btNGPqZLK3t6fXr1+TgoICTZ46Xabj\\\\\\n\",\n              
\"1MHVlbS4WtTU2JiGDh1KHh4eZGNjQ0pKSjIfc3n109q1NO/bBTWG13sKCgq0dNkKam3XnJYtCyEu\\\\\\n\",\n              \"l/sftJD5nLFLiJ+x0tJSMjExoeMno8nG1laqMq9fvyYbK3P666+/SFdXt8Z1JRIJ5ebmVhlu//w5\\\\\\n\",\n              \"MzOTNDU1SSgUkoGBQfnnnz8LhUISCoWkrKxcYftFRUW1Bt37/969e5euxMVRwr1Eqc7cysrKyLyp\\\\\\n\",\n              \"AeXl5ZGGhgZFHDtJbdu1k+o4vdfDqys1szAnBwcH4nA4pKysTBwOp9JH1uUcDueTvDz7TykpKdSu\\\\\\n\",\n              \"XTtKSn1KjRs3lrrccL8h5NK+Pc2YIdsXBua/IU99JzsD+4xxOBwaM2YsrQxdTtt3hpOCgkKtZdat\\\\\\n\",\n              \"XUM8Hp8kEkmt6yoqKhKfzyc+n0+tWrWqdj2JREKvXr2qFG4PHjygM2fOlAfdy5cvSUtLq1K4/ftn\\\\\\n\",\n              \"fX39Ku/j9Ovfn6ZMnSF1x9+oUSOaMn0mHT6wj54/zyCxWCxVuX8qLS2lzMxMSkhIoNLSUiopKaHS\\\\\\n\",\n              \"0tJKH1mXl5aWkpKSUp2DsCGXnz59mtw9OssUXkREPX186fjRwyzAmDpjAfaZmzVrJrm5udGy4CU0\\\\\\n\",\n              \"b/6CGkPs91/30M6w7dS3bx9q1aoVhYSEUEBAgFTBVxNFRUUSCAQkEAjI3t6+2vUkEgllZ2dXOpO7\\\\\\n\",\n              \"e/cuRUVFlQffy5cvSVtbu0K46evr0/Fjx2jLLztlatvwEaNoyaLvSVVVle7euU3OLi5SlxWLxZSe\\\\\\n\",\n              \"lkpbTp0iGxsbmeqtDQAqKyurcxDWtlwkEn3wdl6+fEnunbvIvG+ampr07t27ej1ezJeJBdhnTktL\\\\\\n\",\n              \"i6Kiosjb25vu3EmgadNnkbOLS4VQun/vHm3asI5OR52k06ejyM7OjsaNG0djx46lX3/9lbZu3dqg\\\\\\n\",\n              \"AwHeU1RUJD09PdLT06PWrVtXu55YLKbs7OwKlykfP35MKqqqpKGhIVOdurq6JJFICAD9vHYNjRkX\\\\\\n\",\n              \"KHVgnzoZSQYGBvUeXkR/3y96f6bzqfrtt9/o0OEImcvl5eWRpqZmA7SI+dKwAPsCGBgY0OXLl2nr\\\\\\n\",\n              \"1m00LmAEqTZuTDa2dqSkqEQpyY/pyZN0Gjt2HN24caP8vpe9vT3FxcXRunXryMXFhWbNmkWzZ8/+\\\\\\n\",\n              \"JDrUt2/fUkpKCiUkJNDVq1fp7t27lJqaSiXFxTJv6314WVhYUFpaOl04f448pDirkEgktDxkKXXp\\\\\\n\",\n              \"7PEBe/B5cHV1pSlTppBIJJJppOXxoxHU2cOj4RrGfDHYII4vjEQiocuXL1N6ejpJJBISCoXk4eFR\\\\\\n\",\n              \"YzClpaXRhAkTKCMjg7Zt20btZBzo8CEAlN8nu3r1KsXHx9PDhw/p+fPnVFxcTEpKSlRWVkZ8Pp/M\\\\\\n\",\n              \"zMyoVatWdOjQIYo6c16mofDXrl4lX5/utHHDBlJXV6fRowMo5vxFalHDWZVEIqHZM6fTlUux9OrV\\\\\\n\",\n              
\"K/L29qYff/yx0pD9L4Fv797k3aMXjQoYI9X6GRkZ1Ka1HaWlpbGzsE+UPPWdn/YwJ+aDAaArV66Q\\\\\\n\",\n              \"n58fGRoakoaGBgmFQurbrx8VFhbS0KFDafjw4eTp6VnrWZWpqSlFRkZSUFAQ+fr60owZMyg/P79e\\\\\\n\",\n              \"2ikWi+nx48dkP06oAAAZ1UlEQVR04MABmj59Onl4eJCRkRGpqKiQiYkJeXl5UUhICN2/f58sLS1p\\\\\\n\",\n              \"1qxZFBERQQ8ePKDi4mJ6+fIlnT17ltq1a0cqKiq0Yd1amepfu2YlNVZVpczMTNq0aRNZWJhTd8/O\\\\\\n\",\n              \"tHnjBnr79m2l9W/Ex9OgAX0p4c+bdO7cObp//z4pKiqSnZ0dnThxol6OiTyZPm0arVi2lDIzM2td\\\\\\n\",\n              \"VyKRUNCcmeTn58/Ci6kfH2cCkIYnT9Oh1LcnT56gffv2sGjWDMtDVyExKRWZOXl4nPYMGzZtRSt7\\\\\\n\",\n              \"e1hbW3/Qa+6zs7MxfPhwmJqa4uTJk1KXKywsxO3bt7F+/Xr4+/vD0dERPB6vfHolRUVF8Hg8ODo6\\\\\\n\",\n              \"wt/fHxs2bMDVq1fx+vXrarf5+PFjzJo1CzweD7169cKePXugpaWFRylPpJoy697DJGhpaeHAgQPQ\\\\\\n\",\n              \"09MDh8NBQEAAwsPD0X/AAHC5XAwZ6odpM2Zh/ISJaNHCBqampli+fEWlme1jYmJgZmYGf3//Gmek\\\\\\n\",\n              \"/xz98MNi2Nja4q/k9GqP9duCEowYORqurq41vhWA+fjkqe9kAfaZSU9Ph7GxMUJW/Fjlu7/ev/8r\\\\\\n\",\n              \"bOdu6OrqSvXuqqpERUXBzMwMw4YNq/A6kLy8PJw9exYLFy6Ej48PrKys0KRJEygoKEBBQQEqKipo\\\\\\n\",\n              \"2rQpunTpgpkzZ+LgwYNITk6WetZysViMyMhI9OjRA3w+H3PmzEFKSgoyMjLg5eUFZWVlmJqaIeVJ\\\\\\n\",\n              \"Ro3hlfg4DWZm5li7di3c3d3h7++PZ8+eITg4GMbGxnB2dsZPP/2E9evXw8zMDAEBAVBRUUFhYWG1\\\\\\n\",\n              \"bcvPz8e0adMgFApx4MCBDzqu8kgikWDlylXgcrkYNXoMrly7CVGxGAUlEqQ8ycAPS4LRtGlT+Pbu\\\\\\n\",\n              \"LdWbn5mPS576ThZgnxGJRAJHR0epX5+y+7c/YGxsLPM3YolEgmfPniEsLAytW7cGh8OBtrY2OBxO\\\\\\n\",\n              \"eVBpaWnB1tYW/fv3x/Lly3HlypU6TeSam5uL1atXw8LCAg4ODggLC0Nubi727dsHT09PNGrUCHp6\\\\\\n\",\n              \"eti7dy8WL14MXV09rFrzMzJz8irsc0ZWLlb8uBp6enrQ1NSEg4MDxowZU2FC3bKyMhw5cgTdu3eH\\\\\\n\",\n              \"QCAof4lly5YtpZql//Lly7C2tsaAAQOQmZn5wfssbzIzM7F0aTBMTU3LZ9nX0NBAQMAY3Lx582M3\\\\\\n\",\n              \"j5GSPPWdLMA+I2fPnoWNrW2VrwGp7tPNqzt27dpV5fbKysqQkJCANWvWYMiQIWjdunX5ZT8igrKy\\\\\\n\",\n              
\"MgwMDNCqVSvweDzY2tri3Llz9foOqNu3b2Ps2LHgcrkYOnQoLl++jEuXLmH8+PHQ0dFBmzZtoKOj\\\\\\n\",\n              \"g5kzZ5a/j+v8+fNo0qQJdHR4UFdXh2c3L3w9aDC6eXUHl8uFn58fTp8+DRMTE2hpaVU4g/y3pKQk\\\\\\n\",\n              \"GBsbQ1NTE8bGxhg7dqxU+1dYWIhvvvkGurq62LNnzxf3KpHS0tIaz1aZT5c89Z1sFOJnZMDXX1PH\\\\\\n\",\n              \"Tp1p/ISJUpc5cfwYrQhZSkuWLKbY2Fi6desWPX78mDIzM0kkEhERkbq6OgmFQrK0tCRHR0dyd3en\\\\\\n\",\n              \"du3aVbgRX1paSqtXr6Yff/yR5s2bR9OmTaNGjT7sKY3S0lI6fPgwrV+/nlJSUigwMJC6detGp06d\\\\\\n\",\n              \"ovDwcOJwODR8+HAqLi6mTZs20S+//EK9evUiAHTo0CEaNmwYmZqa0vbt28nc3Jzi4+Pp3bt3pKmp\\\\\\n\",\n              \"Sc7OzkRE5OnpSZ6enqSoqEhXrlyh6OhoUlVVrbI9nTt3pqCgIDp06BAdPHiQNDQ0aPz48RQQEFDr\\\\\\n\",\n              \"dFs3btyg0aNHk4mJCW3evJkMDQ0/6JgwzH9FrvrOjxygDUaevkXUF01NTaneZfXvF1hyOBwoKSlB\\\\\\n\",\n              \"IBDA0dERfn5++Pnnn3Hv3j2Z31WVlJSELl26wMnJCbdu3ZKpbEZGBhYtWgShUAh3d3fs3LkTmzdv\\\\\\n\",\n              \"RseOHSEQCDBlyhTEx8cjLy8PAwYMgKOjI1JSUgAAFy5cgIuLC/T19dG6detq252RkQEbGxt89913\\\\\\n\",\n              \"kEgkEIvFGDhwIIYMGVJtGQ8PD5w9exZJSUkwMjLCjRs3EBAQUH5WWNvLGouLi7Fo0SLw+Xxs27bt\\\\\\n\",\n              \"izsbY+SLPPWdLMA+ExKJBERU7cCNmj4GBgZ48uRJvbYlLCwMAoEAc+fOrfENwxKJBDExMfDw8ICa\\\\\\n\",\n              \"mhq8vLywePFiDBo0CJqamujXrx8iIiJQXFwMAEhISIClpSUCAwNRWFiIhIQE9OjRA6amplizZg20\\\\\\n\",\n              \"tbXx6NGjKut6+vQpLC0tsWTJkgrLCwoK4OLigvnz51dZ7n2ASSQSCASC8mOVm5uLn376CVZWVmjZ\\\\\\n\",\n              \"siU2bdqEt2/fVruvCQkJcHJyQteuXZGamlrTIWSYj0ae+k4WYJ8RNTU1ZOW+lSm8Ckok4HK5ePXq\\\\\\n\",\n              \"Vb23JzMzE4MGDYKFhQXOnDlT4XcikQjLli2DgYEB1NTU4OjohO7ePdDB1Q2amppo1aoVwsPDK5yt\\\\\\n\",\n              \"7NixA3w+H+Hh4UhJSYGfnx/09PSwdu1aFBUVoU+fPli0aFGVbUlNTYWZmRlWrlxZ5e+zsrJgYWGB\\\\\\n\",\n              \"7du3V/rd+wADgN69e2Pv3r0Vfi+RSHDmzBn069cP2tramDhxYrWPKJSWlmLFihXg8XhYt25dg7yN\\\\\\n\",\n              \"mWHqQp76ThZgnxF3d3f89scBmQIs9vI1mJmZNWhHeuzYMRgbG2PUqFG4ceMGZs2aBQ0NDTRpooEp\\\\\\n\",\n              
\"02bgwV/JFdr0+l0hwnbuhl3Llhg+fDjevHmDMWPGwNraGhcuXMCUKVOgo6ODRYsWlZ/xHDt2DM2a\\\\\\n\",\n              \"Naty4MCjR4/QtGlTrF+/vsZ2JiYmQldXF9HR0RWW/zPAVqxYgalTp1a7jWfPnmHhwoUQCoXo1KkT\\\\\\n\",\n              \"9u7dW372+O+6OnToADc3N/z111+1HkOG+a/IU9/JAuwzsnfvXnh07iJTgPn5j8CKFaEN2i6xWIwD\\\\\\n\",\n              \"Bw7A1NQUCgoKsLS0RBMNDZw5d7HGtuXk5cPT0wv6+vro27cvgoKCoKOjg2nTpuHly5fl2xeJRDA1\\\\\\n\",\n              \"NcXp06cr1f3gwQMYGhpi69atUrX1woULEAgEFc6g/hlgFy9eRJs2bWrdTklJCfbv34/OnTtDX18f\\\\\\n\",\n              \"8+fPR3p6eoV1ysrKsHbtWvB4PISGhpaPomSYj0me+k4WYJ+R4uJiGBgY4OiJU1KFV9z1W9DW1kZ2\\\\\\n\",\n              \"dnaF7RQUFCA5ORmJiYl1mlXin89uWVpaomPHjmjSpAmaNGmCfQcjpGpjTl4+TExNweVy4e/vX+W9\\\\\\n\",\n              \"o3nz5mHQoEGVlickJEAoFCI8PFymdu/ZswcmJibIyMgAUDHACgoKoKamhvz8fKm39+DBA0ydOhU6\\\\\\n\",\n              \"Ojrw9fXFqVOnKpzxJicno3Pnzmjbtu0HzY7CMPVJnvpOFmCfmYsXL0IgEODk6Zgag+Fq/J8wMDDA\\\\\\n\",\n              \"/v3/P2NEQkICxo0bDy6XCxMTE1g0awZNTU106dIFBw4cQElJiVRt+Hs746ChoQE7OzsIhULY2toi\\\\\\n\",\n              \"NDQU+/btg62tnUzPqq3bsAlffdW1yroePHgAHo+H58+fV1h+48YN6Onp4Y8//vig47h48WI4OTkh\\\\\\n\",\n              \"Pz+/QoABQPv27XH+/HmZt5mfn49t27ahdevWsLCwwMqVK8u/IEgkEmzZsgV8Ph+LFy+u9lhnZ2dj\\\\\\n\",\n              \"+fIV6Ny5MxwcHNDB1RVTpkzFgwcPPmg/Gebf5KnvZAH2GTp37hwEAgEGDh6CmPOXKoTFtRu3ETBm\\\\\\n\",\n              \"HHR0dPDHH/sA/H2Jb8aMmTAwMMD3ixZXmIbpjagYu/b8DpcOrmjdunWloHivpKQEf/zxB5ydncHl\\\\\\n\",\n              \"ctG0aVPw+XxMnz4dt27dKh+MMXDQIPz08waZLnNmv34HbW3tSnVLJBJ4eHhg7dq1FZbHxcVBV1cX\\\\\\n\",\n              \"hw8f/uBjKJFIMHLkSPj6+sLd3b1CgM2YMQMhISF12nZcXBz8/f3B5XIxYsQIXLt2DRKJBE+ePIG3\\\\\\n\",\n              \"tzfs7e0rzF5RUFCAcePGQ0tLC/7DR+LI8ZO4FBeP0zHn8c2330FPTw9dunRhoxuZOpOnvpMF2Gfq\\\\\\n\",\n              \"78t3a2BpaQk9PT1YWVvD0NAQRkZGWLx4CV68eAHg78504sRJ6ODqhoys3BpHK/6wJBjNmjWrcMkx\\\\\\n\",\n              \"IyMD33//PXR0dMDn86GmpoYBAwbg+PHjVZ5F2NvbI+76LZmH+nfo4IrY2NgK2woPD4eDg0OFe0ex\\\\\\n\",\n              
\"sbEQCAQ4ceJEnY9hcXExunTpAiMjowoBtn//fvj4+NR5+8DfZ1ShoaEwMzODk5MTfvnlF+Tn5yM8\\\\\\n\",\n              \"PBwCgQDz5s3Dq1ev0LFjRwwYOAjPMnOqPD5vRMUIXh4KAwMDNiiEqRN56jtZgH3mxGIxnj59ivv3\\\\\\n\",\n              \"7yM9Pb3SQIGjR4/CunnzSnMGVveZPnM2Bg0ajIsXL6Jbt25QVlaGqqoqHBwcsHXr1hpnjweA5s2b\\\\\\n\",\n              \"4+btezIHmEfnLhVGB+bm5kJfXx/Xrl0rX3bmzBkIBIJKowjr4vXr11BTU8PkyZMB/B34+/fvB5fL\\\\\\n\",\n              \"RSd3d7h17IiBgwbh6NGjdZpC6/0kxT4+PtDR0cH06dNx8eJF9O3bF7q6uhg4aIhUz/ht3LwN5ubm\\\\\\n\",\n              \"NT57xzA1kae+kwXYF86zWzeE7dwtdZBk5uShcePG4HA44PF4CAoKwuPHj6Wuz83NDccio2QKr4IS\\\\\\n\",\n              \"CSytrCrM7BEYGIjAwMDyn0+cOAGBQIALFy7U6/EBAGdnZ/B4PCxatAg2NjZoYWOD0FVrcPJ0DKLO\\\\\\n\",\n              \"nMPGzdvQrl17mJiYYPfuPXWuLzU1FfPmzYOuri46dOgATS0tvHojkvp4dffugbCwsHrYc+ZLJE99\\\\\\n\",\n              \"JwuwL1hSUhIEAgFevyuUKVD8h49AQEDAB02JtGrVagweMkym+i5cugpzc/PykXvXrl2Dvr4+cnNz\\\\\\n\",\n              \"AQAREREQCASIi4ur1+PznoeHB0aPHg1NTU0cOX6y2gEoFy5dhamZGUJDf6yXeouKitDL1xcTJ0+V\\\\\\n\",\n              \"6XgdOnJcqqH+DFMVeeo72RuZv2A3b96kjp3cq53Etjo9fHwpOyeHFBQUZK5z1KiRdOrkCcrKypK6\\\\\\n\",\n              \"zE+rfyQej0cZGRlUVlZGgYGBFBoaStra2rRv3z4aP348nTx5snyi3vr29u1bOnLkKJ29cJm6eXWv\\\\\\n\",\n              \"dr/btW9PMecv0YYN6ykiIqLO9aqoqFDiw4c0YuRomcp18+pOaWlp9Pz58zq3gWE+ZSzAvmAikYjU\\\\\\n\",\n              \"1NRlLtekSZPymeplpa2tTWPGjKVRw4dRSUlJresf3L+PLsZeoJYtW5K9vT15e3uTuro6+fn50e7d\\\\\\n\",\n              \"u2n69Ol0+vRpcnJy+qD2SCM7O5u+W7iIbO3sal3XwMCAflq3kZYuXUqohxc95Obmkr5QKFMZJSUl\\\\\\n\",\n              \"0tPXp1evXtW5fob5lLEA+4JpaWlRbq7sndyrV69IS0vrg+tdtiyEtLla1KdXD3r27FmV65SWltLG\\\\\\n\",\n              \"DetoYuA4amXfmiIiIsje3p4uXLhADx8+pGHDhtE333xDMTEx1KpVqw9uS22eP39Oubm5NMxvuNRl\\\\\\n\",\n              \"unl1p1e5uRQfH1/n+lVUVKQK+n8rKioiFRWVOtfPMJ8yFmBfsE6dOtGVy5fo9evXMpU7fHA/eXb1\\\\\\n\",\n              \"/OB6GzVqRH/88Qe5ODtTeyd7Gvx1Pzp86CBduXyZzp2NoR8WLiArcxOKOHSQLl65RidORdNfyenU\\\\\\n\",\n              
\"/+tBxOPxqWXLlnT48GFSVlamBw8e1MuZTnWOHj1K3j19SENDQ+oyioqKNMxvOB04cLDO9bdo0YLi\\\\\\n\",\n              \"rlyWqUxGRga9yskhIyOjOtfPMJ8yFmBfMIFAQD4+PrR7106pyzx79oxiL5wnP79hdapbSUmJgoOX\\\\\\n\",\n              \"Unp6OnX38qLgxYtoyKD+FLzkB8rLy6PjJ0/T6ZjzZGVtTUR/X7YMGDuerly/SampaTR37lzasmUL\\\\\\n\",\n              \"LV68mNzc3OjKlSt1ak91srOzyczMXOZyhoZGlJOTU+f6x48fT9u2bJKpzI7t22jQoMGkri775WGG\\\\\\n\",\n              \"kScswL5w06ZNozWrQin58eNa1xWLxTR9ykQaPTqAmjRpUi/1N2nShLy8ulFm5gu6fjOBzpyLpTVr\\\\\\n\",\n              \"11V7v0koFFLUmXO0du1acnBwoFu3btG4ceNo0KBB1L9/f0pKSqqXdr2noqJCRUVFMpcrLi6ul0t4\\\\\\n\",\n              \"ffr0ocdJj+hi7AWp1s/OzqZftm6miRMn1LluhvnUsQD7wrVp04YWLfqBenh9RXfv3Kl2PZFIRP7D\\\\\\n\",\n              \"BlNxcRGFhATXaxs2b95Cw/xHkJ6enlTrm5iakm/vvhQWtoOUlJRoxIgR9OjRI2rbti25uLjQlClT\\\\\\n\",\n              \"KDs7u17a1rx5c7p0MVbmcvHXr5GVlVWd6+dwOLRjxw7yHzqIbt28WeO6r169on6+PWnUqNENel+Q\\\\\\n\",\n              \"YT4ZH3scf0ORp2cZPgV79vwKHo+Hnj69EHEsEk8yspCZk4f4W3cwdfpM8Hg8jBw5EkVFRfVar1gs\\\\\\n\",\n              \"/vv1JQ8eyfSs08Ur19GsWbNK28vKysKUKVPA4/EQHBxc5xkpSkpKoKGhiRt/3pW6bU9fZIPL5dZp\\\\\\n\",\n              \"Jv9/O3ToEPh8PmbNCUJiUmqlh8tX/7QOJqammD17zgc9n8cw78lT36kANOAd8I+oTZs2dOPGjY/d\\\\\\n\",\n              \"DLkiEolo7969tHXbNnqclEQlJSUkEAhowICvacKEQDIzM6v3OvPy8sjExIRevnojU7mysjLiNlGl\\\\\\n\",\n              \"kpISUlSsfCEhKSmJvv32W7p69SotWbKE/P39SUlJ6YPaaGFhQY5ObSj8171SPfu2YP48ysnKpB07\\\\\\n\",\n              \"dnxQfdV5/PgxbdiwkXbvDidLSyvi8fkkEono9p+3qKunJ02eNInc3d3rtU7myyNPfScLMOajys7O\\\\\\n\",\n              \"pubNm9Pzl7IN5wdAGo05VFRURI0aNap2vbi4OJo9ezbl5+dTaGgoeXl5ydzGjh07UnZ2NvXu258W\\\\\\n\",\n              \"LV5aY4j9snULrQxdRnFxcSSU8fktaRUUFND169fpzZs3pKamRi1btiR9ff0GqYv58shT31n9v3yG\\\\\\n\",\n              \"+Q9wuVwqKCigN2/eyPRs2fPnz0lTU7PG8CIicnFxoUuXLlFERARNmTKFTE1NKTQ0lFq3bl1rHbm5\\\\\\n\",\n              \"uRQVFUWZmZn01Vdf0ZHDB+lOwm2aNmMWuXt0rhBkN+LjaeP6n+n6tTiKjo5usPAiIlJTUyMPD48G\\\\\\n\",\n              
\"2z7DyAs2iIP5qDgcDvXy9aXf9uyWqVz4zjAaOHCQVOsqKChQ37596f79+9S7d2/q3r07jRgxgp4+\\\\\\n\",\n              \"fVrl+nfv3qVRo0aRubk5/b73D3Lt2IlEhUVEpED3792jgJH+1NLGigYN6EtDBvandk72NHzYILJv\\\\\\n\",\n              \"1ZLi4+PJ0tJSpn1hGObDsDMw5qObOGECTZg4kcYFTpDqPlVJSQmF/bKVIiMjZaqHw+HQpEmTyN/f\\\\\\n\",\n              \"v/wsbNy4cfTNN9+Un/1FRETQ2LFjaeqMWXT3YRIJBILy8gDoYuwFWrEsmPLfvaMB/fuRqqoqCYVC\\\\\\n\",\n              \"cnFx+eB7bAzDfBh2D4z56ABQ165dqYVtS/px1Zoa7zFJJBIaP2Y0FRTk08EDB+pU77Nnz2jhwoV0\\\\\\n\",\n              \"/Phxmj9/PjVv3pz8/f3p8NFIcqxhbkWxWEyTJ4ynZ8+eUOSJE8ThcOrUDob5lMhT38kuITIfnYKC\\\\\\n\",\n              \"Ah04cIDiLl+kwLEB1c5U//z5c/IfNpjS01IofNeuOtdrZGRE27dvpzNnzlBkZCQNHDiQNm3dXmN4\\\\\\n\",\n              \"Ef09i8i6jZspP19EB+oYogzDfDgWYMwnQVtbmy5cuEAqyhyyt7Wmkf7D6I+9v9PJyBP0+2+/0tBB\\\\\\n\",\n              \"A6itQ0syNjSk06dP1+s0SS1btqSgoCDS09Mn7x49pSrTqFEjmjZjFm3cuLHe2sEwjGzYJUTmk/P6\\\\\\n\",\n              \"9WvasWMnxV2No3fv3pGmpia5d3Kn4cP9ZZpUVxZ+fn7k2KY9TZw8ReoyZWVlZG1hQjExMdS8efMG\\\\\\n\",\n              \"aRfD/Nfkqe9kgziYT462tjbNnDmDiGb8Z3U+Tk6m0WNlmz+wUaNGZGvXklJSUliAMcxHwC4hMgwR\\\\\\n\",\n              \"icvKan2mrCocDofKysoaoEUMw9SGBRjDEJGunh6lp6fJVAYApaWlSj0JMcMw9YsFGMMQ0eBBgyh8\\\\\\n\",\n              \"Z5hMZW7Ex1NhQQG1bdu2gVrFMExNWIAxDBF9/fXXlHD7T3r0119Sl9m8cT0FBk6ocjJhhmEaHvuX\\\\\\n\",\n              \"xzBEpKqqSt9+O5/8hgykvLy8Wtf/dXc4Xb4US2PGBPwHrWMYpioswBjmf6ZNm0pdu3alrzzc6EZ8\\\\\\n\",\n              \"fJXrvHv3jlYsC6aFC76lyMhI0tHR+Y9byTDMe2wYPcP8j4KCAq1atZKst1qT/9CBxOcLaMgwfzIw\\\\\\n\",\n              \"NKSioiK6GneF9u39jdw9POjy5ctkYmLysZvMMF80FmAM8w8KCgo0fvw4GjMmgE6dOkWHDh2m2Atn\\\\\\n\",\n              \"SUVFhWxa2NCdO3fIyMjoYzeTYRhiAcYwVVJSUqKePXtSz57STS3FMMx/j90DYxiGYeQSCzCGYRhG\\\\\\n\",\n              \"LrEAYxiGYeQSCzCGYRhGLrEAYxiGYeQSCzCGYRhGLrEAYxiGYeQSCzCGYRhGLrEAYxiGYeQSCzCG\\\\\\n\",\n              \"YRhGLikAwMduREPg8/lkamr6sZvBMAwjV9LS0ignJ+djN0Mqn22AMQzDMJ83dgmRYRiGkUsswBiG\\\\\\n\",\n              
\"YRi5xAKMYRiGkUsswBiGYRi5xAKMYRiGkUsswBiGYRi5xAKMYRiGkUsswBiGYRi5xAKMYRiGkUss\\\\\\n\",\n              \"wBiGYRi5xAKMYRiGkUsswBiGYRi5xAKMYRiGkUsswBiGYRi5xAKMYRiGkUsswBiGYRi5xAKMYRiG\\\\\\n\",\n              \"kUsswBiGYRi5xAKMYRiGkUsswBiGYRi5xAKMYRiGkUsswBiGYRi5xAKMYRiGkUsswBiGYRi5xAKM\\\\\\n\",\n              \"YRiGkUsswBiGYRi5xAKMYRiGkUsswBiGYRi5xAKMYRiGkUsswBiGYRi5xAKMYRiGkUsswBiGYRi5\\\\\\n\",\n              \"xAKMYRiGkUsswBiGYRi59H88rW+29ahrjQAAAABJRU5ErkJggg==\\\\\\n\",\n              \"\\\"\\n\",\n              \"  frames[1] = \\\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAbAAAAEgCAYAAADVKCZpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\\\\\\n\",\n              \"AAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0\\\\\\n\",\n              \"dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd1RUx8PG8e9SRZrSwViwFxTsDbFX\\\\\\n\",\n              \"1gL23lvsmsSSGI0mthg1Yu9GLFggItgVxa7Ya8QOiEiRXpe97x955RcCyK4RcXU+5+yh7J17ZzG5\\\\\\n\",\n              \"z87sFJkkSRKCIAiCoGG0CrsCgiAIgvA+RIAJgiAIGkkEmCAIgqCRRIAJgiAIGkkEmCAIgqCRRIAJ\\\\\\n\",\n              \"giAIGkkEmCAIgqCRRIAJgiAIGkkEmCAIgqCRRIAJgiAIGkkEmCAIgqCRRIAJgiAIGkkEmCAIgqCR\\\\\\n\",\n              \"RIAJgiAIGkkEmCAIgqCRRIAJgiAIGkkEmCAIgqCRRIAJgiAIGkkEmCAIgqCRRIAJgiAIGkkEmCAI\\\\\\n\",\n              \"gqCRRIAJgiAIGkkEmCAIgqCRRIAJgiAIGkkEmCAIgqCRRIAJgiAIGkkEmCAIgqCRRIAJgiAIGkkE\\\\\\n\",\n              \"mCAIgqCRRIAJgiAIGkkEmCAIgqCRRIAJgiAIGkkEmCAIgqCRRIAJgiAIGkmnsCtQUCwsLChTpkxh\\\\\\n\",\n              \"V0MQBEGjPHv2jKioqMKuhko+2wArU6YMQUFBhV0NQRAEjVKnTp3CroLKRBeiIAiCoJFEgAmCIAga\\\\\\n\",\n              \"SQSYIAiCoJFEgAmCIAgaSQSYIAiCoJFEgAmCIAgaSQSYIAiCoJFEgAmCIAga6bOdyCwIgvCxpKSk\\\\\\n\",\n              \"ZK1eYWlpSZEiRQq5Rl8G0QITBEF4D5Ikce7cOfr27YulpSUNGzakQYMGWFpaMnDgQC5fvowkSYVd\\\\\\n\",\n              \"zc+aCDBBEAQ1JScn07VrVwYNHoxjzToEPw3h0bNQHj8P437wUypVdaBnr1706dOH1NTUwq7uZ0sE\\\\\\n\",\n              
\"mCAIghoyMjLo1Lkz+kWKcvXGHcZPnETx4sWznrewsGDylG+5dvMuqekZdO/Rg8zMzEKs8edLBJgg\\\\\\n\",\n              \"CIIaFi5chI6OLus3bUFPTy/P4wwMDPjDcycJCYksX+7xEWv45RABJgiCoKKMjAzWrFnN3F8WoKOT\\\\\\n\",\n              \"/xg4XV1dfpo7j5UrV6BUKj9CDb8sIsAEQRBUdODAAUqXsad6jRoql6lXvz7GJiYcO3asAGv2ZRIB\\\\\\n\",\n              \"JgiCoKJLly7Ttl17tcrIZDLatG3PpUuXCqhWXy4RYIIgCCpKTEzE0NBI7XJGxsYkJiYVQI2+bCLA\\\\\\n\",\n              \"BEEQVGRqakps7Bu1y72JicHU1KQAavRlEwEmCIKgopYtW/Cnzz61JigrlUr2/+lNixYtCrBmXyYR\\\\\\n\",\n              \"YIIgCCpq0aIFGenpnD93TuUyJ44fw9TEhAYNGhRgzb5MIsAEQRBUJJPJ6NOnD+PGjCIpKf/PtOLi\\\\\\n\",\n              \"4pgx9VsmTZqETCb7CDX8sogAEwRBUJGvry8eHh5YWpjTWd6e6OjoPI99/fo1nTq0pWnTpvTr1+8j\\\\\\n\",\n              \"1vLLIQJMEAQhH0qlklmzZjFmzBgOHDhAQEAAjRo2pHqVCowZNYKbN26QkpJCcnIy165eZcyoEThW\\\\\\n\",\n              \"q0Tbtm3x8FguWl8FRGynIgiC8A5v3ryhX79+JCYmEhQUhLW1NQCLFi1kwoTxNGvWjMOHDxL9/9up\\\\\\n\",\n              \"lCxZkoEDB3H//n1sbGwKs+qfPRFggiAIebh9+zZubm507NiRRYsWoaurm+15Q0NDIiIiCA8Px9DQ\\\\\\n\",\n              \"sJBq+eUSXYiCIAi52LVrFy1atOCnn35i6dKlOcIL4ODBgzRr1kyEVyERLTBBEIR/UCgUTJ06FR8f\\\\\\n\",\n              \"H44dO4aTk1Oex+7fv5/OnTt/xNoJ/yQCTBAE4f+9fv2anj17oq+vT1BQEGZmZnkem5aWxpEjR1i+\\\\\\n\",\n              \"fPlHrKHwT6ILURAEAbhy5Qp169alcePG+Pv7vzO8AAICAnBwcMga1CF8fKIFJgjCF2/jxo1Mnz6d\\\\\\n\",\n              \"tWvX4ubmplIZ0X1Y+ESACYLwxUpLS2PChAmcPn2awMBAKleurFI5pVKJr68vAQEBBVxD4V1EgAmC\\\\\\n\",\n              \"8EUKCwujW7du2NnZcfnyZYyNjVUuGxQUhImJCRUrVizAGgr5EZ+BCYLwxQkMDKRu3bp06tSJvXv3\\\\\\n\",\n              \"qhVeILoPPxUiwARB+GJIksTy5cvp0aMHW7ZsYfr06e+1zNP+/fvp0qVLAdRQUIfoQhQEQWMpFAr8\\\\\\n\",\n              \"/PzYsPkPwsLD0dLSopx9Gb4eOZymTZtmC6fk5GRGjBjB3bt3uXDhAvb29u91zcePHxMdHU29evU+\\\\\\n\",\n              \"0KsQ3pdogQmCoJG2bNmKbcnSjJ76E48Mq0CD/mTW6c2NDCu6DRhG2YqVOXbsGABPnjyhUaNGyGQy\\\\\\n\",\n              \"zp07997hBX+3vjp27IiWlrh9FjbRAhMEQeP8PG8evy1fQ9m+czApVSXbc8Ur1MLOuRsx9y/StWcf\\\\\\n\",\n              
\"vh4xjM2bN/HDDz8wduxYtboMU1JS2L17NydPBRKXkEgxU2MunjvLggULPvRLEt6DCDBBEDTKnj17\\\\\\n\",\n              \"WPz7KqqOWYW+qUWux8hkMsyrNkR/+BIWLxvFquVLGTFihMrXSEtL4/uZP7J+wwaMS1ahaMUG6Bh8\\\\\\n\",\n              \"hSIukYhUHQYNHc7Yq9eY9ePMXNdIFD4OmSRJUmFXoiDUqVOHoKCgwq6GIAgfkCRJVKxanaLNhmJW\\\\\\n\",\n              \"ub5KZcIC91BZ+QJfn30qHZ+UlESrth0ISdbiK9fRGFh8leOY5IjnhPitoLK1EYf8fNHX11frdXzK\\\\\\n\",\n              \"NOneKTpxBUHQGOfPnycmPpHiFeuqXMa6bntOHD/Oy5cv8z1WkiS69exNWKYR5fvPzTW8AIpal6bC\\\\\\n\",\n              \"oPkEv8mk/6AhKtdF+LBEgAmCoDF279mLqVMbZGoMoNAxMMLSoTEHDhzI99grV65wMeg69t2n5nsN\\\\\\n\",\n              \"LW0d7Ht+z+Gjx7l3757K9RE+HBFggiBojPDXkeiZWqpdTmZkzqtXr/I9bunyFVg06IyWtmrDA7T1\\\\\\n\",\n              \"9LGs58rvHivVrpPw34lBHIIgaIwi+vpISRlql8tITWHu3LksXrwYa2vrrIeNjU3W9xYWFuzbu4e6\\\\\\n\",\n              \"M1X7rOwtq/pydi4bxtrVIsQ+NhFggiBojOrVqhDofQpQbxWMzNeP8PHxwcXFhYiIiKzHq1eviIiI\\\\\\n\",\n              \"4Nq1a4SEhKCUQM+omFrn1i9mTXJiAunp6ejp6alVVvhvRIAJgqAxBg0cyOw5c/lKHoeuoalKZRJC\\\\\\n\",\n              \"H5IZH0X79u3R0dHB1NQ010V44+LisLHLfdDGO0kSEhLa2trqlxX+ExFggvAJyszM5MSJEzx9+pTM\\\\\\n\",\n              \"zEysra1p164dhoaGhV21QmVpaYmrXE7QaS9Kd8h/XpckSTw/vBFdLYmgoCAaNGiQ57HGxsZo62iT\\\\\\n\",\n              \"Ev0SA3M7leuU9Oop5lY2IsAKgQgwQfiExMfHs3y5B78vX0E6eij1zAAZ2pmJZCQOYeCAAUyb+i2l\\\\\\n\",\n              \"SpUq7KoWmqW/LqJ2vQa8MrPBpkGnPI+TJIkXB9dgrZXIpDlzcHd3Ry6XM3/+fMzNzbOOi42NZd++\\\\\\n\",\n              \"fXh6eqLMzCT8rDdlO49VuT5Rl/Yzaviw//SahPcjJjILwiciLCwMl6YtiEjUJsPUAa2iVtmel9Lj\\\\\\n\",\n              \"kcXeRz/5KUcO+1O/vmoTeT8H9+7dY92Gjdx/+IhMRSYmRkUJDAxEZluFUq36ZVtOSlIqeRMcRNSZ\\\\\\n\",\n              \"3VjqZXDs8EEsLS2Ji4vjxx9/ZNeuXfz000/Y2NiwY8cOjh49SsuWLenfvz8VKlSggbMLjt9uR9fQ\\\\\\n\",\n              \"JN96pSfEcGNRP4If3MPOTvVW26dMk+6dogUmCJ+AuLg4mjRtwcs0c7CpiVYu6/XJ9EzAqj7Jcda0\\\\\\n\",\n              \"btOeSxfPUaVKlVzO9vm4fv06I8dM4P79+xhVb4OuZQ1kBlpkxL0iVdIl49FVnoTeQdfE7O9Jx5KS\\\\\\n\",\n              
\"hJdPMDc15seJ4xgwYAAGBgYAmJiY0KtXL8LDwxk3bhwGBgZMnDiRtWvXUrx4ceDvVlsNh2rcWjMR\\\\\\n\",\n              \"xzEe6BTJu8s2Izme4M3TmDxxwmcTXppGBJggfAKWLF1KRKIu2NTMd7FZbdMypGbEM2bcRE4eP/KR\\\\\\n\",\n              \"avjxBQQE0MmtGybOAyk1ehpa2tnXHDSr15Wk5zeJPrSEPvI2NG/aFG1tbUqXLk3Nmv/7OwYHB+Pp\\\\\\n\",\n              \"6Ymnpyd6enr079+fBQsWEBAQwIwZM4iNjWXu3LkYGhoyZswYkhLi6eBch+OrxmDbeihmVRtmmxem\\\\\\n\",\n              \"zFQQdfsMr45tpG/3Lsz5afbH/LMI/yACTBAKmUKhwGPFKhQWLXJteeVGy6wKFy7s4NmzZ5QpU6Zg\\\\\\n\",\n              \"K1gIgoOD6ezeHYuO0zEq45jrMTKZDKMyTuj3/Y1tO76hRbNmuLm5ARAZGYmXlxeenp48e/aMXr16\\\\\\n\",\n              \"sXv3bmrVqpUVbGXLlqVLly5Mnz6dypUrY2Njg6WlJWfOnMHIyAgvLy8WLF7KTd/fKV6lAegVRZmS\\\\\\n\",\n              \"QOTts5iamtC8YR1cGjciIyNDDJ8vJGIlDkEoZEeOHCFTyxAtg9xXVs+NTEsXrWIVWb9+QwHWrPDM\\\\\\n\",\n              \"W7AIQ8cOeYbXP+maWGLeZjzfTp/Jrl276NixIxUqVODixYvMnj2b0NBQli1bRu3atXO0bs3Nzfn5\\\\\\n\",\n              \"558pXrw4L168ID09nZCQEGQyGb169eJG0CUCjvjTupIFKXdPE3HjFEXt6yBVaMWZ1wZ8PWMe1nYl\\\\\\n\",\n              \"+f6HmaSkpBTUn0PIgwgwQShkjx8/JlO3uNrlMnSKc+/BwwKoUeGKi4tj9+7dmDp1ULmMoX0tQiMi\\\\\\n\",\n              \"WbJkCT169CAkJARPT0/atWuHjk7eHU3BwcE0atSIbt26ERERQbdu3WjatClTp04lMTERAJ/9vnj9\\\\\\n\",\n              \"eRB956GUn+CFbcdvsGzUAyuX/lj3XIhFt59Zv/8MDZ2b8ubNm//8+gXViQAThEKmUCiQUH2TxSwy\\\\\\n\",\n              \"LTIy1F9W6VN36NAhTMpUR9fYPP+D/59MJsO8TmecatWhf//+GBsb51vm4sWLuLi4MHXqVObMmYOO\\\\\\n\",\n              \"jg5jx47lzp07hIeHU7VqVYYNH8HytZux67sEk0qNkWnlnOtVxLIMVp1nEKFXgvbyzigUCrVer/D+\\\\\\n\",\n              \"RIAJQiGztrZGR0pWv2BGAl+V+PxGv0VERCAzssr/wH/RNbUmLDz/BXsB9u/fT8eOHdmwYQPDhw/P\\\\\\n\",\n              \"9py1tTV//PEH69evZ8vWP7DqMhMdo3e3kGUyGeYtRvIoPAZfX1+16y68HzGIQxAKmaurKxkjRiGZ\\\\\\n\",\n              \"JSHTVW2lDUmS0E96zMABSwq4dh+frq4uSJlql5OUCo4dO0b58uWxsbHJ83HkyBFWrFjBwYMHqVs3\\\\\\n\",\n              \"733FwsLCMK9QG30L1SaNy7S0MXDsxK9Ll+Pu7q52/QX1iQAThEJWrFgxunXrxu7jD8CqtkpllAkv\\\\\\n\",\n              
\"sLQo/s6lkTRVuXLlyHj9WO1yisjHjB8zmpEjhvPq1atsjwsXLvDy5UuuX79OVFQUMpmM9u3bY2Nj\\\\\\n\",\n              \"g62tba5B99vylRRx6KxWHUwqN+bWqjWEhoby1Vfvsa6ioBYRYILwCZj5wwx8fOqRZmCNtvG7b3xS\\\\\\n\",\n              \"ejxSWACJJkU5e/YsTZo0+Ui1/DhatWqFLDWWlPCHGNjmXHQ3N8qMVOJvn2CM56/Y29tToUKFbM+n\\\\\\n\",\n              \"p6czZMgQ7O3tuXPnDsWLFyc6OjpH0IWFhXH16lVevXpF8MOHlGms3pJdWjp6GJrbERYWJgLsIxAB\\\\\\n\",\n              \"JgifgAoVKuC735uOnd1IT6+FVvFKOQYMSJKEMiEE3cizLFw0j/LlytGzZ0+GDBnCrFmz/u56+wxo\\\\\\n\",\n              \"a2sz7uvRrNizjyIdp+U7sRsg9sYh6tWvh729fY7n4uLicHd3x8TEhBMnTlC0aFEArKyssLKyokaN\\\\\\n\",\n              \"Grme07Zkmfd+DZ/pCn2fHDGIQxA+Ec2bN+fM6QBqlkhH6/FOeH2ZzDfBZMY+IvP1dXRf7KMED9j+\\\\\\n\",\n              \"xyYmjB+Pq6sr169f59q1azRu3Jjg4ODCfgkfzMQJ4ymmiCTmnGe+YZAQfInEy3tYtXxZjudCQ0Np\\\\\\n\",\n              \"0qQJlStXZu/evVnhpQpbOzvSokPUqrcyM4Ok6JeUKFFCrXLC+xEBJgifkJo1a3LpwlmuBV1kZI8m\\\\\\n\",\n              \"tKymj0sFLfq1q8rB/Xt48ugvunT532aO1tbW+Pv7M2DAABo1asSmTZs+i3f/xsbGnDp+FO2n53jh\\\\\\n\",\n              \"NZPksPs5Xlf6m3CiAjYQd9yDIwcP5FgX8s6dOzRq1Ih+/fqxYsUKtbc7+Xr4EFJvq7dUV8KDczg4\\\\\\n\",\n              \"OFCyZEm1ygnvR3QhCsInqHLlyixb+ptKx8pkMsaOHUvz5s3p06cPBw8eZO3atdm2DNFElpaWKDNS\\\\\\n\",\n              \"0U18QvLR34jXKYquVVmQaaGMjyAp/BGDBg7k2x1XcmwvExAQQM+ePVm2bBl9+vRR+9rx8fE8ePCA\\\\\\n\",\n              \"qIdXMI0JQ98s/xaVJClJvnmAbxf+qPb1hPcjWmCC8JmoVq0aly9fpnTp0jg5OXHixInCrtJ/snLl\\\\\\n\",\n              \"SqKiovDd/ydhL57itWklc0e6M2uInDXzZxDxMhSP35fmCK+dO3fSs2dPvLy81A4vhULBunXrqFSp\\\\\\n\",\n              \"ElFRUcz84QcifeagSIp9ZzlJkog+uZ4yFkZ07qzeyEXh/YkWmCB8RvT19fntt99o164dAwcOpE+f\\\\\\n\",\n              \"PsydOxd9ff3Crppa4uLimDFjBq1bt6Zx48YAtGjR4p1lJEli8eLFeHh4cOLECapXr67WNY8dO8bk\\\\\\n\",\n              \"yZMxNzfHz8+P2rX/ntKQlpbG6k1TKNZsBEbl6uQYXJMWHULsOU+sZXEcOXH0sxlMoxGkz1Tt2rUL\\\\\\n\",\n              \"uwqCUKgiIyOlzp07S05OTtK9e/cKuzpqGT58uKSvry+FhYWpdLxCoZDGjh0rOTg4SCEhIWpd6+7d\\\\\\n\",\n              
\"u1KHDh2k8uXLSz4+PpJSqcxxzJ49e6Qq1Z0kU6sSklXjXpJN65GSVbPBklWlupKpmYX0zXdTpcTE\\\\\\n\",\n              \"RLWu+6nSpHunaIEJwmfKwsICHx8fNmzYgIuLC3PmzGHUqFEqDUsvTE+ePGHLli189913Km0UmZKS\\\\\\n\",\n              \"Qp8+fYiLi+Ps2bOYmpqqdJ3IyEhmz57N7t27mTFjBj4+Pnlui9KtWze6devGlStX8PPz43VUDMZG\\\\\\n\",\n              \"xajpNAl3d3eNa+F+LmSS9BkMWcqFJm2LLQgF7a+//qJv377Y2tqyceNGrKzUX2vwY3F2dubBgweE\\\\\\n\",\n              \"hYXlGwxRUVF06tQJe3t7Nm3apFKQpKWl4eHhwcKFC+nTpw8//vijxg94+ZA06d4pBnEIwhegUqVK\\\\\\n\",\n              \"nD9/HgcHB5ycnDh8+HBhVylXJ0+e5NKlS6xfvz7fMHry5AmNGzemadOmbNu2Ld/jJUli7969VKlS\\\\\\n\",\n              \"hTNnznD27Fl+//13EV6arJC7MAuMJvXjCsLHFBAQIJUsWVIaP368lJycXNjVyZKZmSmVKFFCqlmz\\\\\\n\",\n              \"Zr7HXrlyRbK1tZVWrFih0rkvXbokNW7cWHJ0dJSOHz/+X6v6WdOke6dogQnCF6ZZs2bcvHmTV69e\\\\\\n\",\n              \"Ua9ePW7duvVJTH728PAgIiKCXbt2vfO4gwcP0r59e1atWsWYMWPeeWxISAj9+vWjS5cuDBkyhKtX\\\\\\n\",\n              \"r9KyZcsPWW2hEIkAE4QvUGxsLCVKluLpizAcHZ3Q1tHFukQpfpw1m5cvX370+iQlJTFjxgx69epF\\\\\\n\",\n              \"xYp5L+C7YcMGhgwZgq+vb7YVSf4tMTGRmTNn4uTkRNmyZXn48CFDhgxRezUO4dMmAkwQviAKhYJh\\\\\\n\",\n              \"I0dTrUZNNh57gE7zGZj1/4PifbaQVvtrft93gXIVKzP9+5kftVU2fvx4JEli1apVuT4vSRKzZs1i\\\\\\n\",\n              \"/vz5BAYG0rBhw1yPy8zMZOPGjVSsWJFnz55x48YN5syZg5GRUUFWXygkYhi9IHwhlEolXXv05tT1\\\\\\n\",\n              \"xxTpuAQtvewL2+qYlQGzwWhXc2fV1qXExMSwZtWKAh92//z5c7Zu3crixYsxNjbO8XxGRgYjRozg\\\\\\n\",\n              \"zp07nD9/Hmtr61zPc+LECaZMmYKxsTH79+9/52aVmkChUKCtrf3JT3soTKIFJghfiCVLlnHqyj10\\\\\\n\",\n              \"nCfmCK9/0jIwRdvlO3Z4++f7edSH0KtXL2xtbRk/fnyO5xISEpDL5URGRnLq1Klcw+uvv/6iU6dO\\\\\\n\",\n              \"jBgxgpkzZxIYGKiR4SVJEpcuXaJH774YGBqjp6+Pjo4u9hWq4OHhQXx8fGFX8ZMjAkwQvgCZmZks\\\\\\n\",\n              \"WrIUmWMfZNq5T9b9Jy29osgcevD9rLkFWq+TJ09y+fJltm/fjpZW9ttReHg4TZs2pXTp0vz5558Y\\\\\\n\",\n              \"Ghpmez46Oprx48fj7OyMi4sL9+7do2vXrhrZYomJiaGRS3Naubpz6LEMA/mvFO/nSbE+W4gp240f\\\\\\n\",\n              
\"V3hhW6IUO3bsKOyqflJEgAnCF+Dw4cOkaxuiY1FO5TK6JZx4HhLG+PHjyczM/OB1kiSJfv36ZQXQ\\\\\\n\",\n              \"P92/f5+GDRvi7u7O2rVr0dH536cd6enpLF26lMqVK6NUKrl37x7ffPONxq6GERsbS72GztxLMEHP\\\\\\n\",\n              \"9Vf0q8rRMiiGTCZDpq2Drm01tBuORbfFDEaMncyGDRsKu8qfDBFggvAFOHr8BBlWTmqVkWlpUbS8\\\\\\n\",\n              \"MwcOHKBVq1aEhKi3uWN+li5dSmRkJF5eXtl+f/bsWZo1a8bs2bP54YcfslpUkiTh4+NDtWrVOH78\\\\\\n\",\n              \"OIGBgaxYsQJLS8sPWq+PbcDgYUTqlUHHsTcyWd63ZJ3ipdBt+h0TpnzH3bt3P2INP10iwAThCxD9\\\\\\n\",\n              \"JhaZnmH+B/5Lpk5RevXqTevWralTpw579uz5IPVJSUnh+++/Z/To0djY2GT9fu/evbi7u7Nt2zYG\\\\\\n\",\n              \"DRqU9furV6/SrFkzfvzxR1auXIm/v3+ODSw1UWhoKMeOHUOnRg+Vuj61Te3QLteS35Yt/wi1+/SJ\\\\\\n\",\n              \"ABOEL0AxExOkjFS1y2kr0yhWzJQZM2bg5+fH999/z+DBg0lISPhP9Rk5ciS6urr89tv/Nu1ctmwZ\\\\\\n\",\n              \"EydO5OjRo7Rp0waAsLAwBg0ahFwup1+/fly/fj3ruc/BytVr0CvbGJluEZXLaJdrzq6du8SgDkSA\\\\\\n\",\n              \"CcIXwcW5EXrRd9QqI0kSvLxBgwYNAKhbty7Xrl1DR0eHmjVrcvHixfeqy7Nnz9i+fTsrVqxAV1cX\\\\\\n\",\n              \"pVLJ5MmTWbduHefOncPJyYmkpCRmz55NjRo1sLOz46+//mL48OHZPgv7HBw6cgJsa6lVRtvQDAOL\\\\\\n\",\n              \"kly7dq2AaqU5RIAJwhegS5cuSPHhKGJDVS6jeHWP1MQYZsyYgZeXFxkZGRgZGbF+/XoWLVpE586d\\\\\\n\",\n              \"mTt3LgqFIkfZ1NRUPD09adfBlVp16lK/YSP6DRjI2bNn6dq1K+XKlWPAgAGkpqbSu3dvgoKCOHv2\\\\\\n\",\n              \"LCVLlmTLli1UqlSJhw8fcu3aNebNm4eJicmH/HMUqvT0dEJDQwkKCiIi4tV7de3K9AxFCwwxkVkQ\\\\\\n\",\n              \"vgh6enqM+XoUK3bsRmo8EZnWu9+7SpnpcG8fS39diJ2dLcuXL2fKlCmMGjWKESNG4O7uTv369Rk4\\\\\\n\",\n              \"cCBHjhxh27Zt2NvbI0kSS5YuZd78+dhXro6zvDvONiXIzMzk6f1b9Ozbn5jI16xetZKYmBi6dOmC\\\\\\n\",\n              \"jY0NR48e5eLFi0yePJkiRYqwd+/erJafJsjMzCQqKoqIiAhevXqV7fHv38XHx2NlZYW1tTVJycnI\\\\\\n\",\n              \"MlLUvp6UkZLrpO8vjdgPTBC+EOnp6TRr2Ya7kaBTdxgy7dzfv0oZqSSfXoJxZjRPgv/KWobp1q1b\\\\\\n\",\n              \"eHh4sHfvXjp16sS4ceOoVasWS5cuZcGCBSxZsoTLV4I4dPwkY+evwq50ziH7kiRxJeAwG3+ZirFh\\\\\\n\",\n              
\"Ubp168bIkSOZNm0a169fZ+HChfToodqAhoImSRKxsbHvDKO3P0dFRVGsWDFsbGyyHtbW1rn+bG5u\\\\\\n\",\n              \"jpaWFpIk0adff3xvvsGgTn+V66VMjSflwDeEhTzDzMzsg79uTbp3igAThC9IcnIy7t17cf7yNST7\\\\\\n\",\n              \"5uiWa4qW/t8BpUyJRfEoAOXTU3RybYeWlElYWBgHDhzItpZgTEwMGzduZOXKldjZ2TFu3DjKly9P\\\\\\n\",\n              \"585d0NI34Oc//Chq/O4uv4e3rrJwXH/cu3TG39+fb7/9lgkTJlCkiOqDGd5XYmJirmH0799FRERg\\\\\\n\",\n              \"YGDwzjB6+7C0tERXV1el60dFReHp6cnGjRuJi4vjVeQbjLuuRKaT/wRzgLS7B5BX0man5x//5c+Q\\\\\\n\",\n              \"J026d4oAE4QvzNsli35bthzf/X+io2eAJElImRn07NmLSRPG4ujoSGZmJiNHjuT+/fscPHgQU1PT\\\\\\n\",\n              \"bOdRKBQcOHAADw8PHjx4QHJqGjNW76RMJQeV6rF71a/8dfEkJ44f+887RKelpeXZOvr3z0qlMlv4\\\\\\n\",\n              \"5BVQ1tbWGBgY/Kd6vZWZmcmxY8fYuHEjx44do2PHjgwdOhQXFxdat5cTFFkU3Rrd8j9PYhQZJ37i\\\\\\n\",\n              \"9PHD1K5d+4PU7d806d4pAkwQvmBpaWnExMSgpaWFmZlZjlaEUqlk7NixBAUFceTIEYoXL57reX7/\\\\\\n\",\n              \"/Xc81m5k3g7Vd3qOef2KqT1aEvLiea6DNBQKBZGRkSq1lpKSkrJCKL/WkpGR0Ufronzy5AmbN29m\\\\\\n\",\n              \"y5Yt2NjYMHToUHr37p3tzUBERAQ169QnybYxOlXkedYtM+E1ijOLmfnteL779psCq7Mm3TvFIA5B\\\\\\n\",\n              \"+ILp6+tja2ub5/NaWlqsXLmSyZMn07JlS44ePYqFhUWO4y4HXaW5Wx+1rm1mZUPZqjUYN24c1tbW\\\\\\n\",\n              \"OcIpJiYGMzOzHK2lUqVKUa9evWwBVbx48RxrKRaWlJQUvL292bhxI7dv36Zv3774+/tTo0aNXI+3\\\\\\n\",\n              \"trbmysVztGrbgfATQShKN0ffvlFWl6LizQt4fIL05xeZ//PPTJgw7mO+nE+aCDBBEN5JJpOxZMkS\\\\\\n\",\n              \"ZsyYQfPmzTl+/HiOVeEjIyOpWccmjzPkzcTMktDQUKpWrUq1atWytZYsLCw0Zt6XJElcu3aNjRs3\\\\\\n\",\n              \"4uXlRd26dRk9ejSdOnVSaY3GEiVKcOfmNfz9/Zk7fxHXd21ER78oSqUCY2MTxoweyaiRm7Gzs/sI\\\\\\n\",\n              \"r0ZzaMZ/HYIgFCqZTMa8efPQ19enWbNmnDhxItvNVF9fn4z0dLXPm5KcxLXA0wQHB2Nra4utrS12\\\\\\n\",\n              \"dna5fm9pafnJ7agcHR3N9u3b2bRpE3FxcQwePJjr169TqlQptc4TFhbG6jVrWbVmHZkyXUxsy6FI\\\\\\n\",\n              \"S4KMZAYPHMDQIYNFeOVCBJggCCqRyWTMnj0bfX19mjZtysmTJ7G0tCQgIIBXL1+ScfUiDVrLVT6f\\\\\\n\",\n              
\"Uqkk/OlDTp48SalSpQgPD04kHYgAACAASURBVOfly5eEh4cTHh7OuXPnsr4PDw8nNjYWS0vLfIPO\\\\\\n\",\n              \"2tq6QFtuSqWSEydOsHHjRg4fPkyHDh347bffaN68+Xt1Y+7fv5++AwahU6oBNJqCbvGSAOgBmXEv\\\\\\n\",\n              \"WX/kJKvXOrFu9Qr69u37gV+NZhMBJgiCWgYOHMj169epXLkyWlpa1KxZkxYtmrN2/Qb6TfwevSKq\\\\\\n\",\n              \"jdy7c/ksMZGR7Nmzh7Fjx9KwYcN3Hp+enk5ERESOoLty5UrW9y9fviQ6Ohpzc/N8g87GxgY9PdWG\\\\\\n\",\n              \"rsPfS2Bt2bKFzZs3Y2FhwZAhQ1i9enWeA1tU4e/vT58BQ9Bt8k2uW91om9qhXasfMvumjBw7CR0d\\\\\\n\",\n              \"HXr27Pne1/vciFGIgiC8k1KpJCgoCD8/P/z9/Xn27Bnt2rVDV1eX48ePExAQQIUKFWjTrj1WVerQ\\\\\\n\",\n              \"aeDofM+ZqVDwy6geRIc9R0tLi+TkZOrWrcvYsWNxdXX9T12FCoWC169f5wi6f34fHh5OREQExYoV\\\\\\n\",\n              \"yxFs//zZzMyMy5cvs23bNq5fv07v3r0ZOnQoTk7qbU2Tm6SkJGzsSqLlPAldywr5v66Y52QEzCfk\\\\\\n\",\n              \"+ZMCmcD8libdO0ULTBCEHBISEjh27Bh+fn4cPHgQc3Nz5HI5y5Yto2HDhllddOvXr6d58+b8+eef\\\\\\n\",\n              \"xMe+4fTqxZhb29K4XZc8z52pULDxl6lYFTPm6rlneHt7M3PmTMLCwpgxYwbjx49n9OjRDB06NNcR\\\\\\n\",\n              \"j7lJS0vD29sbj9VrefL4MQpFBhaWVvTv04vhw4blOs/s7fJP/w63+/fv4+3tzf3794mMjESSJIoW\\\\\\n\",\n              \"LUrJkiV58OABS5YsybVFZ2trm2PX6HfZvn07OtaV0FYhvAB0zEoj+8qJzVu2MGXyZJWv8zkTLTBB\\\\\\n\",\n              \"EAB4/Pgxfn5++Pn5cenSJRo1aoRcLsfV1RV7e/s8y/32229MnToVLS0tjI2NkWRa1HJpTavuAylb\\\\\\n\",\n              \"pXrWcZkKBUGnj3LYcy22lmb47NubtZ5fRkYGmzdvZs6cOVSoUAFTU1NOnTpFly5dGDt2LHXq1Mnz\\\\\\n\",\n              \"+nv37mXk6DGYflWO0s26YV62Klpa2iRGhfP8zH6eXT7O0CFDWPrb4ne27N68ecOOHTvYuHEj0dHR\\\\\\n\",\n              \"DB48mEGDBlGqVCliYmLybdG9fPkya1rCu7ou7ezsMDY2pmLV6rwq0RG9ErkPr89NxuuHGN3dQuiz\\\\\\n\",\n              \"xwU2l02T7p0iwAThC5WRkcG5c+eyugZjY2NxdXVFLpfTqlWrbMtH5SUwMJBOnTqRmpqKQqFg+PDh\\\\\\n\",\n              \"zJkzh7Xr1rF69RqKmhTDxNyS9LQ0wp89Jj09jZUey+ndu3euAy1SUlJYtWoVCxcupFmzZtjb2+Pl\\\\\\n\",\n              \"5YW1tTVjxoyhR48e2Zab2rhxE99M/x7nCUuwKJf7CiCpCbFcXDUNR3tb9u3xyhZiSqWSgIAANm3a\\\\\\n\",\n              
\"hL+/P+3atWPIkCG0bNlS7W7Mt2snqhJ0AMnJKZj1/wOZluodYZIkEb9rCLExUWq19tShSfdOEWCC\\\\\\n\",\n              \"8AWJiori8OHD+Pn5cfToUcqVK4dcLkcul1OzZk21RtGtW7eOKVOmYGBgQLly5TA2NubmzZscOHCA\\\\\\n\",\n              \"evXqoVAouHz5Mtu3byc4OJjff/+dAQMG8PPPP9O2bdt3njs+Pp6lS5eyfPlyevToQf369dm5cyc3\\\\\\n\",\n              \"btxg6NChjBo1itevX9OybXtafr8RU7u8W4gAmRnpnP71awZ3kzP7x5mEhIRkDcgwNjZm6NCh9O3b\\\\\\n\",\n              \"F3Nzc5Vf//uSJInIyEhs7UpQvN82tcsn7xvF8yfBWFpaFkDtNOveKT4DE4TPmCRJ3LlzJ6tr8M6d\\\\\\n\",\n              \"O7Rs2RK5XM7SpUvfuQpHXjIyMpg4cSK7du3C2tqa6tWrI0kS+/bt49ChQ8jlcnx8fGjcuDGNGjXi\\\\\\n\",\n              \"6dOnvHnzhipVqjBo0CC2bNmSb4CZmJgwa9YsxowZw4IFC5g8eXJW627nzp3UrFmTosamVHYdlG94\\\\\\n\",\n              \"AWjr6lF70Pf8Nmcg58+e4erVq/Ts2ZM9e/ZQq1atAl9aSpIkQkNDuXHjBjdv3uT69esoJQllejJa\\\\\\n\",\n              \"ekVVP09mBumpSTnWpfxSiQAThM9MSkoKAQEB+Pv74+fnh7a2NnK5nFmzZtG0aVOVVobIS3R0NF27\\\\\\n\",\n              \"duXRo0eUL1+eevXqce3aNY4fP551HU9PT9zc3Ni9ezfNmjXDwMCAlJS/97zq1asXM2bMIDY2lmLF\\\\\\n\",\n              \"iuV7PQsLCxYvXszEiRP5+eefcXV1ZeLEiZw6dYp6DRrSxCXvwSL/Zmpnj5FNGcqXL8/+/fs/2EK9\\\\\\n\",\n              \"/5aens6DBw+4ceNG1uPmzZtZO1k7OjrSrVs3IqJiufHsPEUqtlL93M8vU6+hs1rD/z9nIsAE4TMQ\\\\\\n\",\n              \"FhaWFVinTp2iZs2ayOVyDh8+TOXKlT9IC+POnTt07NgRLS0tHB0dadq0KZs3b+bcuXPZwqBNmzZ4\\\\\\n\",\n              \"eXnRo0cPtm/fTpEiRbICzNzcPOv5kSNHqnztr776ijVr1vDNN98wa9YsFi9ejE2VOugbqbdTc9mm\\\\\\n\",\n              \"XXgd89cHC683b95w8+ZNbt68mRVWDx48oEyZMjg5OeHk5MR3332Hk5MTNjbZl9qytLSk24CRSBVa\\\\\\n\",\n              \"qvzvo/M8gO9+//mD1P1zIAJMEDSQUqnkypUrWV2DL168oH379vTp04etW7f+p8m1ufH19WXIkCGY\\\\\\n\",\n              \"mpri4uJCq1atmDZtGmfPns11TlLz5s3x9vbG3d2dSZMmkZqamvXcoEGD+Pnnn9UKsLfKly/P9u3b\\\\\\n\",\n              \"+eGHH9h99o7a5Q1MzIh59kbtcpIk8ezZs2wtqhs3bhAdHU2NGjVwdHSkYcOGjB49GgcHB4oWzb9b\\\\\\n\",\n              \"sEWLFtiaGfHyvj+6VfNfwSQ9+ATFtdOQy1Vf7eRzJwJMEDREfHw8R48exd/fn4MHD2JpaYlcLsfD\\\\\\n\",\n              
\"w4MGDRoUyPJJkiQxb948PDw8MDIyomfPnrRq1YpevXpx/PhxSpcunWdZZ2dnDhw4QPv27bMNjmjb\\\\\\n\",\n              \"ti3Dhg3j/PnzxMTEkJCQgLGxMfXq1VN5X7BSpUqhpbyu9utRpKdinE+4pKamcvfu3Wytqps3b2Js\\\\\\n\",\n              \"bJzVqurbty+//vor5cqVe+9V8LW0tDh6yI/a9RqSJmX+vZWKVs6Rj5JSScbDo+g8PszJi+c0ZoHj\\\\\\n\",\n              \"j0H8JQThE/bo0aOsVtbly5dp3Lhx1udZZcqUKdBrJycnM2TIEO7cuYOWlhaTJk2iWbNmtG7dGi8v\\\\\\n\",\n              \"rzy3B/mn+vXrs2rVKvr374+Xlxc9e/bk6tWrFLewpGXrNpSrUQd9Q2PSkxJ5du867dq1Y/LECTmW\\\\\\n\",\n              \"lYqJieHcuXMEBgYSGBjIrVu3QLcI9TMVaGmrfhuLvHeJVk1r/e/nyMisoHr79e3ne2/DqnPnzjg6\\\\\\n\",\n              \"Oqo8qVodJUuW5HrQJTp37cEDvynI7Jui/VUdZHqGSBkpZIZdg6cBlClph+/lC++cj/clEgEmCJ+Q\\\\\\n\",\n              \"jIwMzp49mxVaCQkJuLq6Mm7cOFq2bKnS3KwPISQkhC5dumBhYUFkZCTLli2jUaNGODs7s2LFCpo3\\\\\\n\",\n              \"b67yuWrVqoWtrS2TJk1i/35fjpw4ST33IXT+yR0D4/+NpktJjOfmMW86unVl5LAh1KhencDAQM6c\\\\\\n\",\n              \"OcOzZ89o0KABtWvXzlqVPl2RQci1QErXbaFSPdKS4nly4QiRVUvg6urKzZs3SUxMxNHREScnJ5o3\\\\\\n\",\n              \"b86kSZOoWrVqtrlmBa1EiRIEXTzH9evXWbLMg5On1pCUmEBRQyOaODfmm1X7qVu37kerjyYRASYI\\\\\\n\",\n              \"hSwyMpJDhw7h7+/P0aNHqVChAnK5nJ07d+Lk5PTRN2o8f/483bp1o3Xr1hw8eJAdO3ZQq1YtnJ2d\\\\\\n\",\n              \"+eabb+jRo4da5ytSpAhKpZKhQ4exfM16hi3bjalVzuH7BkYmNHAbRBXntqyc2JOSNhYMGjiQAQMG\\\\\\n\",\n              \"8PLlS+bOncvChQsxMzNjwoQJVKhQge9/WUwJx0bo6OUfODe912JgUJRixYoxfPhwnJycKF26NDKZ\\\\\\n\",\n              \"jOjoaDZt3szXEyYTGxuLgUFR6tWuyfixX1OtWjW1Xu/7qlmzJtu2bvoo1/pciInMgvCRSZLE7du3\\\\\\n\",\n              \"s1pZd+/epVWrVsjlctq3b59jtNrHtHnzZqZOnUqvXr3Ys2cPBw4coGrVqrRq1QpnZ2cWLVqk9jkj\\\\\\n\",\n              \"IyOpVKkSGZlKhnp4Y2ab/15Zb16FsmFMF3r17M6ePXtISEigTp06zJs3j5YtWwJ/D2Tp7ObO7ZBo\\\\\\n\",\n              \"nMcvRrdI7p9tSZLEXb8tPD+5C1srS0qVKsWmTZuwtLQkLS2NsRMmsWPHdiwcnDGt3hxdo+Io01OJ\\\\\\n\",\n              \"f3SVyMt+VKtahe1bN1O2bFm1X7sm0qR7p2iBCcJHkJKSwsmTJ7OGuuvq6iKXy/npp59wcXH5T3Oz\\\\\\n\",\n              
\"VJWamsr58+eJjo6mSJEiVKlShfLlywN/r+D+3Xff4evrS69evfDz8yMwMBB7e3u6du1K2bJlWbBg\\\\\\n\",\n              \"wXtdV0dHh4TERKo3k6sUXgDFbb6ibL1mbPP0ZNjQofz4449ERUVx48YNpk6dmjW4IiMjgyKGRhz4\\\\\\n\",\n              \"zo3K7ftTzqUT+oZ/D61XZioIvX6Gpye8SHz1DEVaKjVr1sTAwAAnJyfWrl3L/EWLeZ6khdPUnegZ\\\\\\n\",\n              \"ZZ+XVqx8Tb5qPYjws/uo26ARZ08HUKVKlff6GwgFQwRYIUhMTCQmJgZ9fX3Mzc3FqKLPVGhoaFZg\\\\\\n\",\n              \"nT59mlq1aiGXyzl69CiVKlUq8NUf3nrx4gW/L/dg4+YtGFt9RZFiligV6UQ+vkv16g6MGjaEP/74\\\\\\n\",\n              \"A/h7+HtgYCDnz5/H2tqaUaNGkZKSwp49e1TuykxJSeHSpUtZAy4uXbqETFuHup37q1XvRt2G8OTK\\\\\\n\",\n              \"aS5fvoy9vT12dnY4OTnh6OjIuHHjcHJyokSJEgCcOXOGZR4r2T+pA6aWtmhp6xAf9YpyZcvx0+Sx\\\\\\n\",\n              \"dO/endTUVBYvXszq1atp2rQpvfr2x8jekcoDf8p19B+AlrYOJZr2RKeoMS3btOPh/bsf7XNIIX+i\\\\\\n\",\n              \"C/EjUSgU+Pr6smrVKs6fP4+5uTmpqaloaWkxZMhQRo0a+c4hycKnLzMzM9vcrNDQUNq3b4+rqytt\\\\\\n\",\n              \"27b94HOzVHHy5Encu/fAtl57Srl0xci65P/qm5HOy6sneeC7DpvihjhUqcybN2/4888/MTU1Zc6c\\\\\\n\",\n              \"Oezfv59Tp05lrRqfm7i4uKwRgmfOnOHGjRvUqFGDJk2a4OLiQr169bCxsWXW4QdqhbYkScx1rcbR\\\\\\n\",\n              \"I4epW7fuO+vwVkxMDKGhoWRkZGBlZUXJkiVzHBMREcH06dP5Y/tOGs49gI6+aks5Pf7je74f0eu9\\\\\\n\",\n              \"5q9pkk/t3vku4q3/R/DixQtcXV0xMjZh5Ogx7Nvvn9Vl9PCvv1i/bg21atVi+vQZTJky+aO9Mxf+\\\\\\n\",\n              \"u7i4uGxzs6ytrZHL5axcuZIGDRr8p40Z/6vLly/TpWt3nIbPw7Jyzu1ItHX1KNmgHXa1mnFxxRSC\\\\\\n\",\n              \"rt/kr3t3MDAwYMOGDWzdupVz587lCI6IiAjOnDmTFViPHj2ibt26uLi4MHfuXOrXr5+1UnpycjK3\\\\\\n\",\n              \"b99Gpq2l9n/XMpkMLS0t5s6di5WVFWZmZio/3sXa2hprW1tKNpSrHF4Axet35rdlHowYMUL8P/qJ\\\\\\n\",\n              \"EC2wAhYeHk6jRo0Y9fU4JkzKexO6kJAQOsvb0b9ff6ZPn/YRayioKzg4OKuVdeXKFZydnbP2zfpU\\\\\\n\",\n              \"WtGSJFG5Wg2KNe3HV3XzX2svMz2VCwsGs2nlUpRKJSNGjCAwMJDy5cvz/PnzrLAKDAzk9evXNG7c\\\\\\n\",\n              \"GBcXF5o0aUKFChV48eIFjx494tGjRzx+/Djr+5iYGMqUKcPD4GC+9bpIUZP81z98KzUpgV971OeA\\\\\\n\",\n              
\"ry+xsbHExMTk+9DV1cXMzAxzc/N3BtzYSVOw7fYDJqVU/0xLUiq59nNXbgZd/KznY30q905ViAAr\\\\\\n\",\n              \"YB1cXaldpx7fz5yV77EvX76kScO6+Pj4UK9evY9QO0EV6enp2eZmJSUlZe2b1bJlywLbl+m/OHfu\\\\\\n\",\n              \"HG69B+A8y0vl1sLzsweQ3TvKo7/uM3LkSEJCQggMDCQjI4MGDRpQrlw5zM3NSUtL4+nTp1khlZSU\\\\\\n\",\n              \"RPny5SlfvjzlypXL+r58+fIYGBiwcOFClq9YSYuBE2nUbajKr+HSn3+g//oB3nt2q3S8JEkkJSWp\\\\\\n\",\n              \"FHT+h49Se+p29IuptyXJA4/hHNi19Z0bbGq6T+XeqQrRhViAHj16RNCVK+zw2qfS8XZ2doydMImV\\\\\\n\",\n              \"K1eKACtkb+dm+fn5cezYMSpWrIhcLsfLywsnJ6dPvgtp+crV2DXuolY9v6rXGr8dizDQ1ebQoUPo\\\\\\n\",\n              \"6+tTrFgxQkJCCAgIIDQ0NCugmjdvzvDhwylfvjw2NjbZriNJEkeOHKF79+5cvXoVfX199HS0uea/\\\\\\n\",\n              \"gwbug1UaDKJUKrlxcCc7t25Uuf4ymQwjIyOMjIwoVSrnaMf09HQePnzInTt3OBZwGqUiXeVzv5WZ\\\\\\n\",\n              \"kf5RRowKqhEBVoBWr15D/4GD1ZrVP2DgYKpVKkdUVFSBLF2jiSRJylozz9DQEAsLiw8eIJIkcevW\\\\\\n\",\n              \"raxW1v3797PmZnl4eGBtbf1Br1fQ7t67j6W8tVpltPWKYGJbhlrl7HB2ds7WkjI3N8/3bx4dHc2c\\\\\\n\",\n              \"OXPYunUrCQkJODg4sHPnTtzd3alTpw4KJRxaNYcOY2a981ySJBGweTGlv7KlSZMmar0G+Dv8nj59\\\\\\n\",\n              \"yu3bt7lz507W4/Hjx5QpUwYHBwcsLa2Ie3YHA4sSKp83PeENSW8iP5luYkEEWIG6dPkSs35Sb+sD\\\\\\n\",\n              \"c3Nzajg6cfPmzawJm1+qhIQEPD09WeqxktAXLzAwNiEtOenvlRjGfs3gwYPz/cD+XZKTkzl58iR+\\\\\\n\",\n              \"fn74+/89sEYulzN37lxcXFw0es+ltLQ0tdYIfKuYmTmTJ0+mQ4cOKh0vSRJ+fn7Mnj2bGzduYGho\\\\\\n\",\n              \"SL9+/Zg1a1a20JckiVcvw9CNiebQilk0H/wNBrlshZKSGM/prUt5E3yDs4Gn8g268PDwHEF1//59\\\\\\n\",\n              \"LCwscHBwwMHBAblczrRp06hcuXLWm8n9+/cz6rvZ2NR598aa//T6sh9ubm6YmKi3hYtQcESAFaCk\\\\\\n\",\n              \"xEQMDdWfM6Krq8u2bdt4/vw5FhYW2R7FihX76EsLFYbz58/TsbMbFhUcsXefQKNq9ZHJZH9vxx58\\\\\\n\",\n              \"iw0HdvPTz7+w03Mbrq6uKp83JCQka25WYGAgtWvXRi6Xc/z4cSpWrPjJdw2qysLCgpTYSIqVrqxW\\\\\\n\",\n              \"ueSY19lWjs9LZGQkM2fOZMeOHSQlJVGrVi18fX3p0KFDjr/huXPnuHfvHlOmTGH69OmMHjOW5QOa\\\\\\n\",\n              
\"UaFBC6q5dKCIoTGpSQk8vhzA3cBDuLq6cuj82Wy7DsfExGQLqbcPXV3drKBq1KgRI0aMoFq1avmG\\\\\\n\",\n              \"jKurK8rRY3jzMIjiFfP/PCsjKY7IC38y6bBfvscKH48IsAJkYmJCbKz6ew/FxsZiaWHO2bNniYqK\\\\\\n\",\n              \"yvZISEigePHiOYLt7cPc3DzH70xMTDTqxnzx4kXayztSb8TPfOXknO05mUyGVUVHrCo68jr4Jn0H\\\\\\n\",\n              \"Dmb71s15hlhmZiaXL1/O6hp8+fIl7du3Z8CAAXh6eqq0K7Am6tpZztLtPtg6qt4F9+b5A0hPznOA\\\\\\n\",\n              \"giRJ7Nu3j7lz53L79m2KFSvGsGHDmDlzZp5z3A4cOMCQIUOoWbMmjRo1wsTEhO3b/qB///68jozk\\\\\\n\",\n              \"ZeA+4uPjMTY2pkPTJuz4/Reio6Px9vbOFlRvuyTfPrp37061atVU3n7l33R0dNixbStu3XtSYdAC\\\\\\n\",\n              \"TEpXzfPYjOR4grdMY/CAvtSuXfu9ricUDBFgBah58xbs9/GmRUvVtwwPDQ3l3t07VKlcCTc3N1q3\\\\\\n\",\n              \"bp3tMzSFQkFMTEyOYIuKiiI0NJQbN27k+H1qamquwfauR9GiRQsl9NLT0+ns1pW6w+bkCK9/s6rg\\\\\\n\",\n              \"SOPxS+jdrz8vnj7JCqO4uDiOHDmCn58fhw4dwtbWFrlczurVq6lfv36hzs36GIKDg1mzZg2vQ8NI\\\\\\n\",\n              \"eROJQXHVRtqFBe5jzOhROf4+r1694vvvv2f37t0kJyfToEEDjh8/TosW714FfvPmzUyfPh1/f39+\\\\\\n\",\n              \"/fXXrF2Z09PTOXjwIDt37szWstq6ZQsL5s+nUqVKWUE1YcIEHBwcKFmy5Af/77Fly5Zs37qZvgMG\\\\\\n\",\n              \"YVG3A1YNOmNgbpf1vCI1iddBR4k8t5s+3d347Vf114EUCpYYRl+AXr58iYODAw8ePVO533zuT7MI\\\\\\n\",\n              \"ffGcWrVq4uPjw/Xr12nbti1ubm64urq+V/97Wloa0dHRuYZebo/IyEgAtQLPwsLig4zO2r17N1Pn\\\\\\n\",\n              \"LaXZtHUql7mwahpdmzhhbm6On58fQUFBNGnSJGtuVm4j0j5X/v7+9OrTB0kJBsXMwKAYjaesRFv3\\\\\\n\",\n              \"3f82L6+d4tHexdy7fQtLS0skSWLXrl388ssv3Lt3DwsLC4YNG8aMGTPyXUpJkiQWLVrE6tWrOXjw\\\\\\n\",\n              \"IPr6+gwbNgwzMzP09PQ4f/48oaGhVKhQAQcHB6pXr54VWOXKlfvoS6s9efKE3z1WsGXrVgwtS/69\\\\\\n\",\n              \"mG9GKjHP7tOseXOmTBiXb1h/Tj6Fe6eqRIAVsP79+6NfpCgeq9bk+w7y/r17tG3VjNOnT2ctGhoZ\\\\\\n\",\n              \"GYmvry/e3t6cOXOGJk2a4ObmRqdOnd67+0QVycnJKgfe24e+vr5agWdmZoaurm626zZo7IJBvc6U\\\\\\n\",\n              \"qd9G5bpGPLjKyV/HMbBfH+RyOS1atPgk52YVJEmSWLx4MbPn/kzZGnVx/foHilnZseOXyUS8ekWt\\\\\\n\",\n              
\"IT9R1DznKvfKTAUvzu7nif96jh3+u7U6bdo0vL29SUtLo0mTJsybNy/HBpO5XT88PJxbt24xf/58\\\\\\n\",\n              \"bt26RenSpQkODsbY2JiMjAxsbW1p06YNt27donPnzowbN66g/hzv5e0ajrGxsRQtWhQHBwfs7Ozy\\\\\\n\",\n              \"L/iZ+VTunaoQAVbAEhISaNy4MXXrNeD3FavyfHd54/p1url1ZP78BfTv3y/XY+Lj4zl48CA+Pj4c\\\\\\n\",\n              \"OXIER0dH3N3dcXNzK/RWhiRJJCQkqBV4MTExGBsbZwu1Q4eP0Gv9mTy3xsjr2l7DGvHqZVi2D/6/\\\\\\n\",\n              \"FGlpaYwePZr9B/wp41Sf7lMXo/X/3YBKpZJjWz047/MH5hWc+KpBe4qYWpCZkU508HVenj9AhfJl\\\\\\n\",\n              \"6ezaHk9PTx4+fIiNjQ2jR4/m22+/zXUKyJs3b7J9PvV2FKC2tjba2tpoaWkxefJkIiIi2PvnAWKT\\\\\\n\",\n              \"0tAyK4WWjh5aKTFEPbnDkCFD+GHGNDEk/RP0qdw7VSECrIC9fReblJREQkICQ4ePpHff/tjZ2ZGa\\\\\\n\",\n              \"msrFC+dZt2YV586eYc2atXTv3k2l86ampnLs2DF8fHzw9fWlTJkyWWGmKVs+KJVKXr9+zcOHDwkO\\\\\\n\",\n              \"Dubx48fMX7CAgdtvqP15h8/YVty7df2Le8f86tUr3N3dMTAw4Prtu3yz7SQ6ejm7C9NTkrlx8gC3\\\\\\n\",\n              \"zhwlOT4ObV09op4/pEbVyty4cQOFQkGLFi1YsGABNWvWBCApKYn79+9nC6ncBlQ4ODhgb2/P6NGj\\\\\\n\",\n              \"0dPTY/Xq1XTp2oPnsQoM63bDsLRjtn/P9DfhxN/wJ+XeSfx8fXB2fvdnncLH9ancO1UhAuw/evz4\\\\\\n\",\n              \"MVu2bCUk5AWZSiV2tnb069eX6tWrI0kSI0aMIDo6mr1793Lz5k1WrVrN/v1/Eh0djb6+PlWrVmXk\\\\\\n\",\n              \"yFH06dP7vbu9FAoFZ86cwdvbGx8fH4yNjXFzc8Pd3Z3atWt/9MEYkiQRHx/Pq1evCA8Pf+fXuLg4\\\\\\n\",\n              \"LC0tsbGxwdbWliPHjtF91cmsPZ1UoVRmsmtoQ6IjI7+orS6uXbtGly5dGDx4MMdPBlC0XE1aDxyv\\\\\\n\",\n              \"cvkTniu56L2Fad99S/v27Xnw4EG2ltXLly+zDah4+yhVqlS2/6aioqJwdXWlWrVqLF++nGYt2xCm\\\\\\n\",\n              \"ZY15y1HIZHlP+Uh4HETMod84c+oEjo6O/+lvIXw4IsA+AQX9j3Djxg2mTZ/O1aAg+vQbQLVqDshk\\\\\\n\",\n              \"Mh49Csbzjy2UK1eO2rVrc+zYMS5cuJBjRW9JkgokWJRKJUFBQXh7e+Pt7U1qampWmDk7O/+nEXgK\\\\\\n\",\n              \"hYLIyMhcw+jfv9PW1s4KpX9//ef35ubm2erUpoOcpJJ1qdjcXeV6hVw7TcSxzdy+fvW9X9un4MWL\\\\\\n\",\n              \"F6xcvYZdXnt4ExONjq4O5cpXYMKY0XTr1i1bd56Xlxdjx45l+fLlHD9+nD+2bWPqjtOYmKv+uWji\\\\\\n\",\n              
\"m2jm9XRGR1sra4WKfz7Kly+f74CK58+fZw0ymjdvHqtXr2a2x1asus55Z3i99eb6Qexigrh8/ozK\\\\\\n\",\n              \"9RYKlgiwT0BB/iOcOHGCXr16Meunn+nbfwAGBgbZns/IyMDHex9jR49g7ty5TJgwoUDqkR9Jkrh3\\\\\\n\",\n              \"7x4+Pj54e3sTGhpKp06dcHNzo1WrVlmjBpOSknINoX9/jY6OxszMLM8w+ufX920JHT58mKHjvqH1\\\\\\n\",\n              \"nB0qB3zAwlHMnjCcQYMGvdc1C1taWhrDRo7Gx8eHkg07YFe/AwZmNigVGbx5epfwc97EhQSzaf1a\\\\\\n\",\n              \"OnfuzI8//oinpyfr1q1j1qxZWFlZcfjIEeYeuqv2tX90rc6LZ8/ea0DQnTt3aN++PVOmTGHixIlI\\\\\\n\",\n              \"kkS5ytWQ6g3GqIyTSudQZmbwYvVALp07TdWqec/FEj4eTQowMQ9MTffv36d3797s8NpLE5emuR6j\\\\\\n\",\n              \"q6tLj569qFq1Gq7tWlG7du2P3s+vVCqJiopCoVBQt25dvvrqK+7fv8+FCxcYNGgQb968wcDAAIVC\\\\\\n\",\n              \"AZBrGDk7O2f72crKKt935HFxcZw/f574+HgMDQ2pWbMmNjY5R7/lpU2bNhhoZfLo5D4qtMz/88Dn\\\\\\n\",\n              \"V04Q/uAGvr6+tGvXTq1rfQrS09Np20HO80SJlvN90fnX4BWD4s2wq9WMmKd36Td4GMZFvgYkSpcu\\\\\\n\",\n              \"jVwuR19fn7S0NBSZme/VqpeU0nu92Th37hzu7u4sXbqUPn36AHDp0iXeJKRQorTq3YFa2roY12jL\\\\\\n\",\n              \"qjXrWLF8mdr1EL5sIsDUNH/+fMZPmpJneP2TQ/XqzF+4mNmzZ3P8+PEPcv2UlBRevXqVZ0vp7fev\\\\\\n\",\n              \"X7/G1NQ0Ryi5ubkxknXqzAAAIABJREFUevRo9PX1uXPnTtaW79WqVcsanv8+iwjfuXMHD48V7N7t\\\\\\n\",\n              \"RQ1HJ8zMzEhMTCToymVatW7NuLFjcXFxyfc8WlpaHPY/QIPGTZAkJRVads/zpvz0whEubviJE8eO\\\\\\n\",\n              \"cPDgQWrUqMGiRYsYOHBgjjKSJHHlyhVWrVrFhQsXSExMxMTEhKZNmzFmzNdUr15d7desjszMTKKj\\\\\\n\",\n              \"o3n9+jWRkZFZXz137ORFokSD8cvy3NYewMy+Go2+WcOpXwahry3j8uXLFCtWjKpVq2Jtbc3ho8eJ\\\\\\n\",\n              \"eBaMjX1FlesU+eIJRsbGOXoQ8vN2dQ1PT0/atv3fWoLBwcEUta2gdojqWlfg3oNLapURBBABppao\\\\\\n\",\n              \"qCh8fX25++tSlct07d6DGdO+5cGDB1SunPu6dG9XW8/r86R/fk1JScHGxiZHF17dunWzBZW1tXW+\\\\\\n\",\n              \"i9G6u//9OVNcXBz+/v54e3szadIkateujZubG126dMl1S/Z/27r1D7799htGjxnHtVv3sLW1zXou\\\\\\n\",\n              \"Li6OHZ7b6N+/P/37D2Du3Dn53uDKly/PxXNn6NCxM08D9lKmeXdK12uJXlFjMlKTCb0eyF+HPdHJ\\\\\\n\",\n              
\"SKZS+bL4+/szf/58unfvztChQ9mxYwfr1q2jTJkywN8TVfv06UPE69cMHzmaCZO/w+j/2rvzuBqz\\\\\\n\",\n              \"xw/gn0pFab+3uok2lRQRgywjS7KNfSdblhj71hgzw9hlG2MdRoiZMRiyRbJmF0b2KVsoKSW0L/fz\\\\\\n\",\n              \"+2O++k3Tdq9quJz363VfdHvOc859zJzPfZ7nPOfo6eFVcjL2Bu1G27Zt4eTkhG3btil8BieXy/Hy\\\\\\n\",\n              \"5csCgRQfH1/oe8nJyTA0NIRUKoWpqSlMTU1hYGCAa9euoeWcXcWG11v6FrZwaDsQT0/twE/LluHW\\\\\\n\",\n              \"rVvYt28fnj17BhdnJ5wP2oquE+co1H4ACD/4G4b5+CgVOJs2bcLXX3+NgwcPFljyJzs7G1BXvktR\\\\\\n\",\n              \"06iAzCzllzYRBHEPTAnr16/HsRMnsWXrr0qV+2raZMTFxqBFixaFhtLz58+ho6NT5P2kf/5pZGRU\\\\\\n\",\n              \"rqMK09PTceTIEezevRsHDhyAnZ0dunXrhm7dusHBoeC3+507d2HixAk4cCgUNYoZvh8fH4/OHdqi\\\\\\n\",\n              \"a9eu+O67bxVqi1wux/Hjx7FsxUqcDjuFtJQ3qKijC2eXWrgfeRcxMTFISUlB48aNMWnSJPj6+iI7\\\\\\n\",\n              \"OxvLli3D4sWL8e2336JNmzZo1aoVpkybDt/RXxY6EXJ2djYWLZiLX7ZuRWDgFqirq5cYSImJidDT\\\\\\n\",\n              \"08sXSG///u8/TU1Noauri8ePHyMyMjLvdfLkSaRUMkXjCT8q/O+T8ToJIdO+gGstZ3Tr1g0dO3bE\\\\\\n\",\n              \"n3/+iRkzZiDxZTKmbj0OPeOSz6BTkhOxYmg7XL92VaFnsf45u0ZISAgcHR0LbHPgwAH4TJ4J014L\\\\\\n\",\n              \"Ff48AJB05QAa67/Azu2/KFVOKB/iHthH6tmzONja2ildzs7OHgf374Oenh5kMhlq164NLy+vfOGk\\\\\\n\",\n              \"zJph5alSpUro3LkzOnfujOzsbJw6dQp79uyBh4cHjIyM8p41q1u3LtLT0zFqlC/2Bx8pNrwAwNTU\\\\\\n\",\n              \"FLv3HUQDt9ro27cP7O3tS2yLuro6WrdujdatC84l2aRJE4SEhKBz5844dOgQmjZtCktLS3Ts2BF+\\\\\\n\",\n              \"fn7o3LkzhgwZgpkzZ2L+Qn8MHTaiyHo0NTXxzXffg1RDp86dUcPREWZmZnkBZGtri0aNGuULJIlE\\\\\\n\",\n              \"UmAWEblcjidPnuQF1Llz5/L+HhMTg2rVqsHBwQEODg6oV68ert26A2PXjiUeh3+qqG8My5r18N13\\\\\\n\",\n              \"06ChoYEBAwbA0NAQO3fuRMiRUAR+OxyDFmyCjn7RkxSnp7zGluk++HL0KIXCSy6XY8qUKQgNDcXZ\\\\\\n\",\n              \"s2dRpUr+NbSysrJw+PBhBAYG4sWDmzB8+QxaRrIi9lZQ1t1jGLhigcLbC8JbIsCUoKGhjtzcXKXL\\\\\\n\",\n              \"5ebmonVrT6xdu6YcWlV+NDU18wJk5cqVuHTpEnbv3o1evXohJycH9vb2+KxBQ9R1c1NofzKZDAMH\\\\\\n\",\n              
\"D8W6dT9h6dIlSrcnLS0t7yyoYcOG+O677/DXX38hPj4edevWRbdu3WBnZ4fU1FTEx8eDJJxdahUb\\\\\\n\",\n              \"Xv/0zXczsWvndixatKjI+3Uk8eLFC1y6dCnf2VRkZCTu378PY2PjvJBycHCAp6cnHBwcYGNjUyDw\\\\\\n\",\n              \"Ngb+Aq3Kys+GX0FXH9OmTYOmpiYWLFiAjh07Qk1NDe7u7njz5jXWj++FFgPHw7mpJypo/v9l5Jzs\\\\\\n\",\n              \"LNw6G4oTW1Yg5WUiGjb4rMS6srKyMHToUDx69AhhYWF5s87L5XKcPn0av/76K/744w/UrFkT/fr1\\\\\\n\",\n              \"g5GJFAeuHYSkxTCFPkta7F9Qy0hWeP0xQfgnEWBKsLa2xq+/bVe63M0b12FfXfkztw+Juro6GjVq\\\\\\n\",\n              \"hEaNGmHRokW4efMmOnfuguU/rlJqP8NH+KKp+2dYsGA+5HI5EhISFLqHFB8fD7lcnncWZGJigrt3\\\\\\n\",\n              \"7+L+/fuws7ODi4sL6tWrh3Xr1mHPnj1wc3ND+/btMWzkaIXbpqamhhEjR2PNmjWoV68eoqKi8gXU\\\\\\n\",\n              \"X3/9hcjISJCEo6NjXkj16tULDg4OsLe3V2pEn66uLrIy05U6fgDw5mUienp6YsWKFfmeoVNTU8Oy\\\\\\n\",\n              \"pUvh0bw5Fi9bjkNr56FGw+bQ1KmMrLQURF4KQ40ajli9fAksLS3RoUMHmJiYFDlCNiUlBT169ICW\\\\\\n\",\n              \"lhaOHDmCSpUq4dq1a/j111/x22+/wdjYGP369cOVK1fyzuSio6Oxs259vKlWB3p2xa+zlZP2ConB\\\\\\n\",\n              \"S+A/e+ZHv0KAUD7EPTAlpKamolq1ajh38Qqs/jdAoCSvX7+GvU1V3L59u8ClF1Wnr6+PqIdPlJ5/\\\\\\n\",\n              \"UCY1QnZ2NrKysoq9h/Tv9ypXrpzv/t+4ceNgaGiI2bNn5723cuVKrFmzBqGhobC3t0d80usCZz7F\\\\\\n\",\n              \"efHiBeysqkBdXR3Vq1fPdzb19iWRSMrkPuT0Gd8g6PID1Ow9WeEyuVkZOPZVJ0RcDYetrW2x2965\\\\\\n\",\n              \"cwdnzpzJW2+radOm+Z61Cg0NxYABA3D06NECozDfzq7h4uKCr776Cjt27MhbvLJfv37o168fXFxc\\\\\\n\",\n              \"Cq33zJkzaP9FF+g38YZBbU+oaxQ8/umxkXgRvBgjBvWD/8L5Cn9+ofyp0j0wEWBKGj9+AnIJLFmm\\\\\\n\",\n              \"2DMrSxf7Y+WKZbCxscG8efM+qmUZtLW1EfciWelh2I52Vti7dy9cXV1LFQQ3btxA27ZtER0dne/5\\\\\\n\",\n              \"tClTpuDMmTN49OgRHj2NU2qfJKGrrYHMzEylgu9dPH78GDVruaLVwv2ooK3YMYw+ewB6j8/hRGhI\\\\\\n\",\n              \"mbRh+/btmDp1Ks6cOZPvLKpVq1awtbVFSkoKoqKi0KtXL/Tr1w/u7u4KrQh+/fp1DPMdgzt37kCv\\\\\\n\",\n              \"thc0zeyhplEB2a/ikXX3GNQzX2Hu9zMxfJhilxqF/44qBdjHvzZ9GZs+/Ssc2BeEzQEbS9z24IH9\\\\\\n\",\n              
\"WLliGU6cOIGxY8dixIgRaNWqFS5cuPAftLT8SSQSxMbEKFXm7dpk1tbWpT6LqVWrFqysrBAcHJzv\\\\\\n\",\n              \"fX9/f5iZmeHNmzdK7zMjIwOamprlviYVSVy6dAkk8deBkv9bAoDstBTcCVoLr1Ytyqwdffr0wbRp\\\\\\n\",\n              \"09CmTRs8fPgQC/63oGRMTAykUim++eYbxMbGYvXq1WjSpIlC4QUAtWvXxqVzYbh49hS615HCLvki\\\\\\n\",\n              \"qjw7gcb6L7Dpx4WIffxIhJdQauIemJLMzc0REhKCtm3b4tq1PzFu/ETY2uW/vxUTE4P169YgcHMA\\\\\\n\",\n              \"9u7dCycnJzg5OaFnz57YvHkzevbsibp162LOnDkqPYlp167d8Mu2QHw3a3bJG/9P0J7dcHJyKrNl\\\\\\n\",\n              \"T0aMGIH169ejU6dOee+pq6tj+/btkMlkuHrlCtyUWAb+5InjcK1Tp1wfVXjy5AnGjBmDyMhItG7R\\\\\\n\",\n              \"HIdDd0JbzxDV2xS+jA4AZKW+xp/rpuLzRvWxYsUKvH79GrNmzSrxWb+SZGVlwdraGhoaGqhevTo0\\\\\\n\",\n              \"NDQwYsQILFq0qEzWVKtZsyZWrlD8uUlBUIY4A3sHjo6OuHDhAvR0ddC8aSN0bNcG06ZMgt/UyejZ\\\\\\n\",\n              \"rTM+q1sLb14l4/z582jYsGFeOU1NTQwfPhxRUVFo2bIlvLy80KdPH0RGRr7HT/PuRo8ehU0bNyBL\\\\\\n\",\n              \"wYdQSWL1yhWIjY2Fm5sbAgMDFS5blF69euH8+fN4/PhxvvcrVaqEcePG4ccVy5Ta3/p1azB61KhS\\\\\\n\",\n              \"takoubm5WLVqFdzc3FC7dm04OTkhPj4eLk6OiArehFPzhyDm6gnIc3PyymS8SkRk8CacmdMfXVs3\\\\\\n\",\n              \"xYF9exEREYHr16+jSZMm7/TfjlwuR1hYGEaOHAkLCwv4+/ujefPmqFChAmrVqoVly5Z9cguCCiqK\\\\\\n\",\n              \"H6l69er9J/Wkp6dzx44dXLJkCf39/blt2za+fv1aobJv3rzhvHnzKJFIOHToUD569KicW1v2GjZq\\\\\\n\",\n              \"xD79+jMtS870bBb7Wui/lC4uLszIyGBwcDA9PT0pk8k4d+5cJiQkvHMbxowZw++++67A+8+fP6eB\\\\\\n\",\n              \"gQGvXLtZYtvSs8ljJ89QIpEwLS2tNIekUDdu3GCjRo3YtGlTnjlzhu7u7vTy8qKpqSktLS3p5OTE\\\\\\n\",\n              \"KVOm0K1BI+rqG9LMqjolVayoq2fAgYOH8sqVK/n2J5fLuWbNGkokEm7YsIFyubxAneHh4Rw8eAht\\\\\\n\",\n              \"bW0pkUhoaWnJ2q6uNDU1pbOzMxcsWMCHDx8yICCA5ubmPHv2LLt06cK+ffsyNze3zI+BoBr+q76z\\\\\\n\",\n              \"LIgA+wAkJSXx66+/prGxMceMGcNnz5697yaVKDc3l9OnT6eVlRXr1KnDfv29+TTuRaHBkPgqldNn\\\\\\n\",\n              \"fEtra2tGR0fn28/169c5dOhQGhoacuTIkbxz547Sbbl+/TqrVKnC7OzsAr8LDNzKKpaW/PP67WLD\\\\\\n\",\n              
\"K+zsRZqZmfHQoUPvfEwKk56ezhkzZlAikXDdunW8e/cuq1evzubNm1NLS4sSiYRbtmxhTk5OXpnY\\\\\\n\",\n              \"2Fhev36dd+/e5Zs3b4rd/61bt+jq6squXbvyxYsXJMknT56wSZMmtLK25px5Cxhx8y4fPY3jjduR\\\\\\n\",\n              \"XLR4KW3t7FivXj1GRkZy4cKFtLKy4t27d0mSaWlpbNasGceNG1doKAofP1XqO0WAfUCeP3/OCRMm\\\\\\n\",\n              \"0NjYmH5+fkxMTHzfTSrUmzdv2KVLFzZr1ozx8fFMTU2lj88wGhgY0HvgYO7ee4Anws7xwKEjHDt+\\\\\\n\",\n              \"Ik1MTNjxiy+KDea4uDjOnDmTpqambN++PUNDQ5XqQN3d3blv375Cf7dp02YaGhpy1JdjefNOVL7g\\\\\\n\",\n              \"uvznDQ4aMpTGJiZFln9XJ0+epIODA7t3786YmBieOfP3GZ5MJqOGhganTp3K9PT0UteTkZHByZMn\\\\\\n\",\n              \"s0qVKty2bRurVLHk7LkLmJKRU2hYp2bm8ocfV9HAwJD29vZ8+vRpvv29fPmStWvX5rx580rdNkH1\\\\\\n\",\n              \"qFLfKQLsA/T48WMOHz6cJiYm/P777xW+JPlfiI6OpqurK4cMGcLMzMx8v4uPj+eCBQvZ2tOTDRo2\\\\\\n\",\n              \"ZMuWLTltmh8fPHig8P7T0tK4YcMG1qxZk7Vq1WJAQAAzMjJKLLdp0yZ26NChyN/fv3+fnp5tqKOj\\\\\\n\",\n              \"Q8caNfjZZw1Y3d6eMpmMrq51OHXqVIXbWJKkpCT6+PjQ0tKSQUFBJMmAgABWqlSJ6urqtLW15cOH\\\\\\n\",\n              \"D8usPvLvLxXTpk2jrm5lzl2wSKFLpqvX/kRbW1tmZWUV2F9sbCxtbGy4YcOGMm2n8OFTpb5TBNgH\\\\\\n\",\n              \"LCoqiv3796epqSmXLFlSLvdmlHHu3DnKZDIuWbKk3C8vyeVyHj58mF5eXjQ3N+fs2bMZHx9f5Pap\\\\\\n\",\n              \"qak0NjYucIny3yZOnMi6devyxIkTvHnzJrOyshgZGUmJRMKXL1+Wus3bt2+nTCbjl19+yVevXjE1\\\\\\n\",\n              \"NZXt2rWjmpoadXR0OGbMmHyXC0sjMzOT+/btY58+faivr88GDRrQyammQvcj376aNvucO3fuLHT/\\\\\\n\",\n              \"kZGRNDc3zwth4dOgSn2nCDAVcOPGDXbp0oVVqlTh2rVrC5z5/Be2bt1KiUTC/fv3/+d137x5k8OG\\\\\\n\",\n              \"DaOhoSGHDx/OW7duFbrdmDFjOHPmzGL3lZuby169erFXr175BioMHjy4xLLFiY6OZocOHejs7Mxz\\\\\\n\",\n              \"584xOzubP/30E3V1dVmpUiUaGBgUGRTKyM3NZVhYGEeOHEkTExM2adKEq1evZnx8PLt07cqVq9cp\\\\\\n\",\n              \"HF7p2eSWbb+xRYsWRdYXHh5OqVTKsLCwUrddUA2q1HeKAFMhly5dYps2bWhjY1Pgxn95eTtYw8bG\\\\\\n\",\n              \"hjdu3Cj3+orz/Plzfv/99zQzM6OXlxdDQkLynQlev36dlpaWhQ7m+Kf09HQ2bdo032XD+/fv08TE\\\\\\n\",\n              
\"ROn7jjk5Ofzhhx9oYmLCOXPmMCMjg3v27KGjoyONjIxoZmZW6mMnl8sZERHBadOmsWrVqnR2dub8\\\\\\n\",\n              \"+fMLXIasWLEinyW8VCrAXqVmUktLq9jLtKGhoTQ1NWVERMQ7fwZBdahS3ykCTAWdPHmSTZo0oZOT\\\\\\n\",\n              \"E3fu3FluQ57/PVjjQ5Gens6AgAC6uLjQ2dmZP//8c95giEaNGik0GCMxMZGOjo5ctWpV3nvDhg3j\\\\\\n\",\n              \"jBkzFG5HREQEGzRowM8//5x3795lWFgY3d3dWaNGDdrY2FAmk9HT0/OdB+M8fPiQ8+fPp7OzM6tW\\\\\\n\",\n              \"rUo/P78iQyQjI4MVKlRQ6vLh25dEIuHz58+Lbcv27dtZpUoVpe5nCqpJlfpOEWAqSi6XMzg4mHXr\\\\\\n\",\n              \"1qWbmxuDg4MVui916dIljhzpS6+2benZpg0HDx7M0NDQAiFY3GCND4VcLmdoaCjbtWtHMzMzzpw5\\\\\\n\",\n              \"k8uXL2fHjh0VKv/gwQPKZDLu3buX5N+BYWxsXOIzaWlpaZw+fTqlUik3bNjAiIgIduzYkVZWVpw/\\\\\\n\",\n              \"fz5lMhkNDQ05bdo0pc+S4+PjuXr1ajZu3JgmJib09fVlWFhYiV9SoqKiqK6uzlepmUqFV1qWnJUq\\\\\\n\",\n              \"VSpxuD5Jrly5kvb29iWGnaDaVKnvFAGm4nJzc7lz5046OTmxadOmPHXqVKHbhYWFsV69erS2seHs\\\\\\n\",\n              \"eQu4e+8BBu0P5rIfVrJW7dq0t7fnjh1/36P5LwdrlJXbt29zxIgRNDAwoLa2NkNDQxUqFx4eTolE\\\\\\n\",\n              \"wosXL5IkfX196efnV+T2x44dY/Xq1dmrV6//PSg8mKamply+fDkPHTpEPT096unpcfv27Qq3/c2b\\\\\\n\",\n              \"N9y2bRvbt29PfX199u3bl/v37y/yi0NGRgYPHDjAQYMG0cHBgVpaWgRAAwMDBu0PVirAjp86y+rV\\\\\\n\",\n              \"qyv87/zNN9+wXr16H9TIWKFsqVLfKQLsI5GTk8MtW7bQxsaGbdq04aVLl/J+t2dPEKVSKX/9fRdT\\\\\\n\",\n              \"M3ML/RYeevwUq1atyoEDB1Iqlb6XwRplISEhgQ0bNmTlypXp6enJQ4cOldg579u3jzKZjPfv3+fj\\\\\\n\",\n              \"x49pbGxc4CzjxYsXHDJkCKtWrcpt27Zx8uTJNDY25owZM5icnJw3TF4mk/HatWsltjMzM5P79+9n\\\\\\n\",\n              \"3759qa+vz3bt2nHbtm2FnglFR0dzyZIlbN26NSUSCdXU1KihocFq1aqxZ8+e/OWXX5iamsqNGzey\\\\\\n\",\n              \"fYeOSgVYn779uXTpMoWPr1wu54gRI9i6dWuFHm8QVI8q9Z0iwD4ymZmZXLt2LS0sLNilSxfu3LmT\\\\\\n\",\n              \"EomEZy9cLrEz++t+NE1MJFy5cuX7/hilEhERwSpVqnDjxo10dXWlk5MT169fX+xjCGvWrKGDgwNf\\\\\\n\",\n              \"vHjBMWPGcNKkSTxx4gTHjRvH5s09aGRszMaNG3Pq1Kl5l/ZiY2Mpl8vp5+fHihUrsnHjxnmzYRTm\\\\\\n\",\n              
\"7QhCX19fSiQSNm7cOG8E4Vvp6ek8duwYR40aRWdnZ2pra+cNwXdzc+OkSZMYHh5e6CXF1NRUSqVS\\\\\\n\",\n              \"Hjt5RqHwOn/pKo2MjJiUlKTU8c3JyWHXrl3Zu3dvMeXUR0iV+k4RYB+ptLQ0LlmyhIaGRlzov1Th\\\\\\n\",\n              \"b+R/BO1n/fr133fzS61Ro0bcv38/5XI5jx07xo4dO1IqlfLbb78tckYQPz8/NmnShMuXL6e+vj4d\\\\\\n\",\n              \"HWvw+znzuGHjZq5a8xO79ehJHV1ddu3alc+ePWNmZiY7depELS0tjho1qsjRjxEREfTz82O1atVY\\\\\\n\",\n              \"s2ZNzps3jw8ePKBcLuejR4+4du1adujQgWZmZlRXV6e6ujrNzMzYrl07rl69mrGxsQp/7pCQEJqa\\\\\\n\",\n              \"mvLk6fPF/jtfvHyNFhYWeZeNlZWens7PP/+cY8aMKXCGK5fL+fr1ayYmJv4nI2WFsqVKfacIsI9Y\\\\\\n\",\n              \"XFwcDQ0NlRpanZKRQytr63yXIFVRQEAAv/jii3zv3b17l76+vjQ0NOTgwYMLjOjLzc2lq6srq1ar\\\\\\n\",\n              \"xpCjJwod0fc07gWnTPuKVlZWefefAgICCtT/dgShi4sLq1atymnTpvHixYs8ceIEJ06cyDp16rBi\\\\\\n\",\n              \"xYrU0NCgpqYma9SowWHDhvHw4cNMTU0t1Wc/cODA3xNE+wznhfA/87X/asQt+o4eQxMTE/72m+L3\\\\\\n\",\n              \"6QrzdsqpuXPnkvz7Uuf06V/T1NSUurq6NDAwYMWKFdm3b1+ePn1aZe6nfupUqe8UAfYRW79+Pfv0\\\\\\n\",\n              \"7a/0sOqvv/mOU6aU3dRK70NKSgoNDAw4ZcoUerVtyyZNm7Jtu3ZcsGAh7969y3nz5tHCwoKtWrXi\\\\\\n\",\n              \"wYMHmZuby7Vr19He3oGPY+NLPEZLl/9IXV1dnjx5Mq/OhIQErl69mk2aNKGxsTH79u1LPz8/du3a\\\\\\n\",\n              \"lRYWFlRXV6eGhgb19fXZtGlTzps3j9evXy+Xy3DPnj3jnDlzWbVqVVpZW9O1Th3a2NpSJpPx22+/\\\\\\n\",\n              \"45MnT8qkntjYWFpbW7NLly40Njbml2PH89qNO3nHKTY+iYuX/kB7Bwe29vRkcnJymdQrlB9V6jvV\\\\\\n\",\n              \"SPJ9L+lSHlRpWezyMn/+fLx89QZz5i1QqtzmgI34/bdtGD16NHR1daGrqwsdHZ18f+rq6qJixYoK\\\\\\n\",\n              \"r9D7X3r58iXGjh2LoL170b1HT3Tq3BUGBgZITk7GvqA92Ld3D7p1644lSxYjODgYy5YtQ0pKCp49\\\\\\n\",\n              \"i8PJ0+fg7OKiUD0D+vZCg8/qw9LSElu2bMGZM2dgbW2NrKwsxMTEICcnB3K5HBYWFnB3d0fnzp3R\\\\\\n\",\n              \"okULyGSycj4C/y8nJwcPHz7E69evoaenBxsbG2hqapZpHaNGjcLx4ycQcuwkzM3NC90mNzcXkyeO\\\\\\n\",\n              \"x5Xwizh58qRYb+wDpkp9p1iR+SOmqamJ7OxspctlZWXhUng4rvj4QEtLCxoaGlBXVwdJ5ObmIjs7\\\\\\n\",\n              
\"GxkZGcjKykKlSpXyhVphQafse2//rqOjo3RAJiQkwMPDA597tMT9R08LrPzc8YtOWLh4KaZPm4I2\\\\\\n\",\n              \"bdrg2LFjMDMzQ48ePeBYo4bC4QUAo74ch84d2yE7Ows5OTnQ1NTE/fv34eTkhC+//BLt2rVDw4YN\\\\\\n\",\n              \"32tnXaFCBdjb25fb/kNCQhBy5AjCzl6ERCIpcjsNDQ0sX7ESI3yGYMqUqVi7dk25tUn4dIgA+4jZ\\\\\\n\",\n              \"29tj3/4DSpe7eiUcU6dMQffu3fHs2TPExsbmvf7587Nnz6ClpQUTExNIpVIYGxvD0NAQ+vr60NPT\\\\\\n\",\n              \"Q6VKlVCxYkVoaWkhMzMTqampSEhIwKNHj5CWlobU1FSkpqbm/f3f76Wnp0NbW1vhINTR0cH27dvx\\\\\\n\",\n              \"ReeumL/Qv8jPZ2xsjHUbNmLiuDFwdXXFy5cvoaamhjHjJih1nBo3aQIDQwNUq1oV/fr1g4eHB2rW\\\\\\n\",\n              \"rAkNDQ2lj7mq+mHFCkz/+ttiw+stNTU1zF2wCHVcamDBgvkwNDT8D1oofMzEJcSPWHZ2NqysrHDg\\\\\\n\",\n              \"UChqOjsrVObly5eo6WCLv/76C6ampsVuK5fLkZSUVGi4/fPnuLg46OvrQyaTwcLCIu/1z59lMhlk\\\\\\n\",\n              \"Mhm0tLTy7T8jI6PEoHv7540bN3Du/HlE3Lyr0JlbTk4ObKtZIDk5GXp6egjafwifNWig0HF6q71X\\\\\\n\",\n              \"a1S3s0XdunWhqakJLS0taGpqFngp+76mpuYHeXn2nx48eIAGDRog6uETVKpUSeFyAwf0hXvDhpg4\\\\\\n\",\n              \"UbkvDMJ/Q5X6TnEG9hHT1NTEsGHDscR/ITZuDoSamlqJZVauWA4TEwnkcnmJ26qrq0MikUAikaB2\\\\\\n\",\n              \"7dpFbieXy5GYmFgg3G7fvo2jR4/mBd3z589hYGBQINz+/bO5uXmh93G6de+OseMmKtzxV6hQAWMn\\\\\\n\",\n              \"TMKeXTsQExOL3Nxchcr9U3Z2NuLi4hAREYHs7GxkZWUhOzu7wEvZ97Ozs6GhoVHqICzP948cOYLm\\\\\\n\",\n              \"Hi2UCi8A6NCxEw7s2yMCTCg1EWAfucmTJ6Fp06ZYMG8Ops/4ttgQ++2XbdgcsBFdu3ZB7dq1MX/+\\\\\\n\",\n              \"fPj4+CgUfMVRV1eHVCqFVCqFq6trkdvJ5XIkJCQUOJO7ceMGQkJC8oLv+fPnMDIyyhdu5ubmOLB/\\\\\\n\",\n              \"P376ebNSbRs4aAjmzPoOFStWxI3r19DI3V3hsrm5uYh+9BA/HT6MmjVrKlVvSUgiJyen1EFY0vup\\\\\\n\",\n              \"qanvvJ/nz5+jeYuWSn82fX19vHnzpkyPl/BpEgH2kTMwMEBISAjatWuH69cjMH7CZDRyd88XSrdu\\\\\\n\",\n              \"3sTa1StxJOQQjhwJgYuLC0aMGIHhw4fjl19+wfr168t1IMBb6urqMDMzg5mZGerUqVPkdrm5uUhI\\\\\\n\",\n              \"SMh3mfLevXvQrlgRenp6StVpamoKuVwOkvhxxXIMG+GrcGAfPhQMCwuLMg8v4O/7RW/PdD5Uv/76\\\\\\n\",\n              
\"K3bvCVK6XHJyMvT19cuhRcKnRgTYJ8DCwgJnz57F+vUbMMJnECpWqoSazi7QUNfAg/v38PhxNIYP\\\\\\n\",\n              \"H4HLly/n3fdydXXF+fPnsXLlSri7u2Py5MmYMmXKB9Ghvn79Gg8ePEBERAQuXLiAGzdu4OHDh8jK\\\\\\n\",\n              \"zFR6X2/Dy87ODo8eRePUyRPwUOCsQi6XY+H8uWjZwuMdPsHHoUmTJhg7dixSU1OVGml5YF8QWnh4\\\\\\n\",\n              \"lF/DhE+GGMQSilL4AAAa00lEQVTxiZHL5Th79iyio6Mhl8shk8ng4eFRbDA9evQIo0aNQmxsLDZs\\\\\\n\",\n              \"2IAGSg50eBck8+6TXbhwAeHh4bhz5w5iYmKQmZkJDQ0N5OTkQCKRwMbGBrVr18bu3bsRcvSkUkPh\\\\\\n\",\n              \"L164gE4d22LN6tXQ1dXF0KE+OHbyNJyKOauSy+WYMmkCzp0JQ2JiItq1a4fFixcXGLL/KejUuTPa\\\\\\n\",\n              \"tf8CQ3yGKbR9bGws6tdxwaNHj8RZ2AdKlfrOD3uYk/DOSOLcuXMYMGAAqlSpAj09PchkMnTt1g3p\\\\\\n\",\n              \"6eno168fBg4cCE9PzxLPqqytrREcHAw/Pz906tQJEydOREpKSpm0Mzc3F/fu3cOuXbswYcIEeHh4\\\\\\n\",\n              \"wNLSEtra2rCysoKXlxfmz5+PW7duwd7eHpMnT0ZQUBBu376NzMxMPH/+HMePH0eDBg2gra2N1StX\\\\\\n\",\n              \"KFX/iuVLUKliRcTFxWHt2rWws7NFW88WWLdmNV6/fl1g+8vh4ejdoysi/ryCEydO4NatW1BXV4eL\\\\\\n\",\n              \"iwsOHjxYJsdElUwYPx6LFsxFXFxcidvK5XL4TZ2EAQO8RXgJZeP9TABS/lRpOpSy9vjxYzZs2JB2\\\\\\n\",\n              \"1atzof9S3o16yLgXybz36ClXr13P2q6udHR0fKdl7hMSEjhw4EBaW1vz0KFDCpdLT0/ntWvXuGrV\\\\\\n\",\n              \"Knp7e9PNzY0mJiZ50yupq6vTxMSEbm5u9Pb25urVq3nhwgW+fPmyyH3eu3ePkydPpomJCb/44gtu\\\\\\n\",\n              \"27aNBgYGjHzwWKEps27eiaKBgQF37dpFMzMzampq0sfHh4GBgezeowcNDQ3Zt98Ajp84mSNHjaaT\\\\\\n\",\n              \"U01aW1tz4cJFBWa2P3bsGG1sbOjt7V3sjPQfo++/n82azs786350kcf6dVoWBw0eyiZNmhS7KoDw\\\\\\n\",\n              \"/qlS3ykC7CMTHR3NqlWrcv6ixYWu/fV2/a+AzVtpamqq0NpVhQkJCaGNjQ379++fbzmQ5ORkHj9+\\\\\\n\",\n              \"nDNnzmTHjh3p4ODAypUrU01NjWpqatTW1ma1atXYsmVLTpo0iX/88Qfv37+v8Kzlubm5DA4OZvv2\\\\\\n\",\n              \"7SmRSDh16lQ+ePCAsbGx9PLyopaWFq2tbfjgcWyx4XX33iPa2NhyxYoVbN68Ob29vfn06VPOmzeP\\\\\\n\",\n              \"VatWZaNGjfjDDz9w1apVtLGxoY+PD7W1tZmenl5k21JSUjh+/HjKZDLu2rXrnY6rKpLL5VyyZCkN\\\\\\n\",\n              
\"DQ05ZOgwnrt4hamZuUzLkvPB41h+P2ceq1Wrxk6dOyu08rPwfqlS3ykC7CMil8vp5uam8PIpW3/9\\\\\\n\",\n              \"nVWrVlX6G7FcLufTp08ZEBDAOnXqUFNTk0ZGRtTU1MwLKgMDAzo7O7N79+5cuHAhz507V6qJXJOS\\\\\\n\",\n              \"krhs2TLa2dmxbt26DAgIYFJSEnfs2EFPT09WqFCBZmZm3L59O2fPnk1TUzMuXf4j414k5/vMsfFJ\\\\\\n\",\n              \"XLR4Gc3MzKivr8+6dety2LBh+SbUzcnJ4d69e9m2bVtKpdK8RSxr1aql0Cz9Z8+epaOjI3v06MG4\\\\\\n\",\n              \"uLh3/syqJi4ujnPnzqO1tXXeLPt6enr08RnGK1euvO/mCQpSpb5TBNhH5Pjx46zp7FzoMiBFvdp4\\\\\\n\",\n              \"teWWLVsK3V9OTg4jIiK4fPly9u3bl3Xq1Mm77AeAWlpatLCwYO3atWliYkJnZ2eeOHGiTNeAunbt\\\\\\n\",\n              \"GocPH05DQ0P269ePZ8+e5ZkzZzhy5EgaGxuzfv36NDY25qRJk/LW4zp58iQrV65MY2MT6urq0rON\\\\\\n\",\n              \"F3v27sM2Xm1paGjIAQMG8MiRI7SysqKBgUG+M8h/i4qKYtWqVamvr8+qVaty+PDhCn2+9PR0fvXV\\\\\\n\",\n              \"VzQ1NeW2bds+uaVEsrOziz1bFT5cqtR3ilGIH5EePXui2ectMHLUaIXLHDywH4vmz8WcObMRFhaG\\\\\\n\",\n              \"q1ev4t69e4iLi0NqaioAQFdXFzKZDPb29nBzc0Pz5s3RoEGDfDfis7OzsWzZMixevBjTp0/H+PHj\\\\\\n\",\n              \"UaHCuz2lkZ2djT179mDVqlV48OABfH190aZNGxw+fBiBgYHQ1NTEwIEDkZmZibVr1+Lnn3/GF198\\\\\\n\",\n              \"AZLYvXs3+vfvD2tra2zcuBG2trYIDw/HmzdvoK+vj0aNGgEAPD094enpCXV1dZw7dw6hoaGoWLFi\\\\\\n\",\n              \"oe1p0aIF/Pz8sHv3bvzxxx/Q09PDyJEj4ePjU+J0W5cvX8bQoUNhZWWFdevWoUqVKu90TAThv6JS\\\\\\n\",\n              \"fed7DtByo0rfIsqKvr6+QmtZ/XsBS01NTWpoaFAqldLNzY0DBgzgjz/+yJs3byq9VlVUVBRbtmzJ\\\\\\n\",\n              \"evXq8erVq0qVjY2N5axZsyiTydi8eXNu3ryZ69atY7NmzSiVSjl27FiGh4czOTmZPXr0oJubGx88\\\\\\n\",\n              \"eECSPHXqFN3d3Wlubs46deoU2e7Y2FjWrFmT33zzDeVyOXNzc9mrVy/27du3yDIeHh48fvw4o6Ki\\\\\\n\",\n              \"aGlpycuXL9PHxyfvrLCkxRozMzM5a9YsSiQSbtiw4ZM7GxNUiyr1nSLAPhJyuZwAihy4UdzLwsKC\\\\\\n\",\n              \"jx8/LtO2BAQEUCqVctq0acWuMCyXy3ns2DF6eHhQR0eHXl5enD17Nnv37k19fX1269aNQUFBzMzM\\\\\\n\",\n              \"JElGRETQ3t6evr6+TE9PZ0REBNu3b09ra2suX76cRkZGjIyMLLSuJ0+e0N7ennPmzMn3flpaGt3d\\\\\\n\",\n              
\"3TljxoxCy70NMLlcTqlUmneskpKS+MMPP9DBwYG1atXi2rVr+fr16yI/a0REBOvVq8fWrVvz4cOH\\\\\\n\",\n              \"xR1CQXhvVKnvFAH2EdHR0WF80mulwistS05DQ0MmJiaWeXvi4uLYu3dv2tnZ8ejRo/l+l5qaygUL\\\\\\n\",\n              \"FtDCwoI6Ojp0c6vHtu3as3GTptTX12ft2rUZGBiY72xl06ZNlEgkDAwM5IMHDzhgwACamZlxxYoV\\\\\\n\",\n              \"zMjIYJcuXThr1qxC2/Lw4UPa2NhwyZIlhf4+Pj6ednZ23LhxY4HfvQ0wkuzcuTO3b9+e7/dyuZxH\\\\\\n\",\n              \"jx5lt27daGRkxNGjRxf5iEJ2djYXLVpEExMTrly5slxWYxaE0lClvlME2EekefPm/PX3XUoFWNjZ\\\\\\n\",\n              \"i7SxsSnXjnT//v2sWrUqhwwZwsuXL3Py5MnU09Nj5cp6HDt+Im//dT9fm16+SWfA5q10qVWLAwcO\\\\\\n\",\n              \"5KtXrzhs2DA6Ojry1KlTHDt2LI2NjTlr1qy8M579+/ezevXqhQ4ciIyMZLVq1bhq1api23n37l2a\\\\\\n\",\n              \"mpoyNDQ03/v/DLBFixZx3LhxRe7j6dOnnDlzJmUyGT///HNu37497+zx33U1btyYTZs25V9//VXi\\\\\\n\",\n              \"MRSE/4oq9Z0iwD4i27dvp0eLlkoF2ADvQVy0yL9c25Wbm8tdu3bR2tqaampqtLe3Z2U9PR49cbrY\\\\\\n\",\n              \"tr1ITqGnpxfNzc3ZtWtX+vn50djYmOPHj+fz58/z9p+amkpra2seOXKkQN23b99mlSpVuH79eoXa\\\\\\n\",\n              \"eurUKUql0nxnUP8MsNOnT7N+/fol7icrK4s7d+5kixYtaG5uzhkzZjA6OjrfNjk5OVyxYgVNTEzo\\\\\\n\",\n              \"7++fN4pSEN4nVeo7RYB9RDIzM2lhYcF9Bw8rFF7nL12lkZERExIS8u0nLS2N9+/f5927d0s1q8Q/\\\\\\n\",\n              \"n92yt7dns2bNWLlyZVauXJk7/ghSqI0vklNoZW1NQ0NDent7F3rvaPr06ezdu3eB9yMiIiiTyRgY\\\\\\n\",\n              \"GKhUu7dt20YrKyvGxsaSzB9gaWlp1NHRYUpKisL7u337NseNG0djY2N26tSJhw8fznfGe//+fbZo\\\\\\n\",\n              \"0YKfffbZO82OIghlSZX6ThFgH5nTp09TKpXy0JFjxQbDhfA/aWFhwZ07/3/GiIiICI4YMZKGhoa0\\\\\\n\",\n              \"srKiXfXq1NfXZ8uWLblr1y5mZWUp1Ia/9zOCenp6dHFxoUwmo7OzM/39/bljxw46O7so9azaytVr\\\\\\n\",\n              \"2apV60Lrun37Nk1MTBgTE5Pv/cuXL9PMzIy///77Ox3H2bNns169ekxJSckXYCTZsGFDnjx5Uul9\\\\\\n\",\n              \"pqSkcMOGDaxTpw7t7Oy4ZMmSvC8IcrmcP/30EyUSCWfPnl3ksU5ISODChYvYokUL1q1bl42bNOHY\\\\\\n\",\n              \"seN4+/btd/qcgvBvqtR3igD7CJ04cYJSqZS9+vTlsZNn8oXFxcvX6DNsBI2Njfn77ztI/n2Jb+LE\\\\\\n\",\n              
\"SbSwsOB3s2bnm4bpVWomt2z7je6Nm7BOnToFguKtrKws/v7772zUqBENDQ1ZrVo1SiQSTpgwgVev\\\\\\n\",\n              \"Xs0bjNGrd2/+8ONqpS5zJrx8QyMjowJ1y+Vyenh4cMWKFfneP3/+PE1NTblnz553PoZyuZyDBw9m\\\\\\n\",\n              \"p06d2Lx583wBNnHiRM6fP79U+z5//jy9vb1paGjIQYMG8eLFi5TL5Xz8+DHbtWtHV1fXfLNXpKWl\\\\\\n\",\n              \"ccSIkTQwMKD3wMHce+AQz5wP55FjJ/nV19/QzMyMLVu2FKMbhVJTpb5TBNhH6u/Ld8tpb29PMzMz\\\\\\n\",\n              \"Ojg6skqVKrS0tOTs2XP47Nkzkn93pqNHf8nGTZoyNj6p2NGK38+Zx+rVq+e75BgbG8vvvvuOxsbG\\\\\\n\",\n              \"lEgk1NHRYY8ePXjgwIFCzyJcXV15/tJVpYf6N27chGFhYfn2FRgYyLp16+a7dxQWFkapVMqDBw+W\\\\\\n\",\n              \"+hhmZmayZcuWtLS0zBdgO3fuZMeOHUu9f/LvMyp/f3/a2NiwXr16/Pnnn5mSksLAwEBKpVJOnz6d\\\\\\n\",\n              \"iYmJbNasGXv06s2ncS8KPT6vUjM5b6E/LSwsxKAQoVRUqe8UAfaRy83N5ZMnT3jr1i1GR0cXGCiw\\\\\\n\",\n              \"b98+OtaoUWDOwKJeEyZNYe/efXj69Gm2adOGWlparFixIuvWrcv169cXO3s8SdaoUYNXrt1UOsA8\\\\\\n\",\n              \"WrTMNzowKSmJ5ubmvHjxYt57R48epVQqLTCKsDRevnxJHR0djhkzhuTfgb9z504aGhry8+bN2bRZ\\\\\\n\",\n              \"M/bq3Zv79u0r1RRabycp7tixI42NjTlhwgSePn2aXbt2pampKXv17qvQM35r1m2gra1tsc/eCUJx\\\\\\n\",\n              \"VKnvFAH2ifNs04YBm7cqHCRxL5JZqVIlampq0sTEhH5+frx3757C9TVt2pT7g0OUCq+0LDntHRzy\\\\\\n\",\n              \"zezh6+tLX1/fvJ8PHjxIqVTKU6dOlenxIclGjRrRxMSEs2bNYs2aNelUsyb9ly7noSPHGHL0BNes\\\\\\n\",\n              \"28AGDRrSysqKW7duK3V9Dx8+5PTp02lqasrGjRtT38CAia9SFT5ebdu1Z0BAQBl8cuFTpEp9pwiw\\\\\\n\",\n              \"T1hUVBSlUilfvklXKlC8Bw6ij4/PO02JtHTpMvbp21+p+k6duUBbW9u8kXsXL16kubk5k5KSSJJB\\\\\\n\",\n              \"QUGUSqU8f/58mR6ftzw8PDh06FDq6+tz74FDRQ5AOXXmAq1tbOjvv7hM6s3IyOAXnTpx9JhxSh2v\\\\\\n\",\n              \"3XsPKDTUXxAKo0p9p1iR+RN25coVNPu8eZGT2BalfcdOSHjxAmpqakrXOWTIYBw+dBDx8fEKl/lh\\\\\\n\",\n              \"2WKYmJggNjYWOTk58PX1hb+/P4yMjLBjxw6MHDkShw4dypuot6y9fv0ae/fuw/FTZ9HGq22Rn7tB\\\\\\n\",\n              \"w4Y4dvIMVq9ehaCgoFLXq62tjbt37mDQ4KFKlWvj1RaPHj1CTExMqdsgCB8yEWCfsNTUVOjo6Cpd\\\\\\n\",\n              
\"rnLlynkz1SvLyMgIw4YNx5CB/ZGVlVXi9n/s3IHTYadQq1YtuLq6ol27dtDV1cWAAQOwdetWTJgw\\\\\\n\",\n              \"AUeOHEG9evXeqT2KSEhIwDczZ8HZxaXEbS0sLPDDyjWYO3cuWAYLPSQlJcFcJlOqjIaGBszMzZGY\\\\\\n\",\n              \"mFjq+gXhQyYC7BNmYGCApCTlO7nExEQYGBi8c70LFsyHkaEBunzRHk+fPi10m+zsbKxZvRKjfUeg\\\\\\n\",\n              \"tmsdBAUFwdXVFadOncKdO3fQv39/fPXVVzh27Bhq1679zm0pSUxMDJKSktB/wECFy7TxaovEpCSE\\\\\\n\",\n              \"h4eXun5tbW2Fgv7fMjIyoK2tXer6BeFDJgLsE/b555/j3NkzePnypVLl9vyxE56tPd+53goVKuD3\\\\\\n\",\n              \"33+He6NGaFjPFX16dsOe3X/g3NmzOHH8GL6f+S0cbK0QtPsPnD53EQcPh+Kv+9Ho3rM3TEwkqFWr\\\\\\n\",\n              \"Fvbs2QMtLS3cvn27TM50irJv3z6069ARenp6CpdRV1dH/wEDsWvXH6Wu38nJCefPnVWqTGxsLBJf\\\\\\n\",\n              \"vIClpWWp6xeED5kIsE+YVCpFx44dsXXLZoXLPH36FGGnTmLAgP6lqltDQwPz5s1FdHQ02np5Yd7s\\\\\\n\",\n              \"WejbuzvmzfkeycnJOHDoCI4cOwkHR0cAf1+29Bk+EucuXcHDh48wbdo0/PTTT5g9ezaaNm2Kc+fO\\\\\\n\",\n              \"lao9RUlISICNja3S5apUscSLFy9KXf/IkSOx4ae1SpXZtHEDevfuA11d5S8PC4IqEQH2iRs/fjyW\\\\\\n\",\n              \"L/XH/Xv3Stw2NzcXE8aOxtChPqhcuXKZ1F+5cmV4ebVBXNwzXLoSgaMnwrB8xcoi7zfJZDKEHD2B\\\\\\n\",\n              \"FStWoG7durh69SpGjBiB3r17o3v37oiKiiqTdr2lra2NjIwMpctlZmaWySW8Ll264F5UJE6HnVJo\\\\\\n\",\n              \"+4SEBPy8fh1Gjx5V6roF4UMnAuwTV79+fcya9T3ae7XCjevXi9wuNTUV3v37IDMzA/PnzyvTNqxb\\\\\\n\",\n              \"9xP6ew+CmZmZQttbWVujU+euCAjYBA0NDQwaNAiRkZH47LPP4O7ujrFjxyIhIaFM2lajRg2cOR2m\\\\\\n\",\n              \"dLnwSxfh4OBQ6vo1NTWxadMmePfrjatXrhS7bWJiIrp16oAhQ4aW631BQfhgvO9x/OVFlZ5l+BBs\\\\\\n\",\n              \"2/YLTUxM2KHjFwzaH8zHsfGMe5HM8KvXOW7CJJqYmHDw4MHMyMgo03pzc3P/Xr7kdqRSzzqdPneJ\\\\\\n\",\n              \"1atXL7C/+Ph4jh07liYmJpw3b16pZ6TIysqinp4+L/95Q+G2PXmWQENDw1LN5P9vu3fvpkQi4eSp\\\\\\n\",\n              \"frwb9bDAw+XLflhJK2trTpky9Z2ezxOEt1Sp71Qjy/EO+HtUv359XL58+X03Q6WkpqZi+/btWL9h\\\\\\n\",\n              \"A+5FRSErKwtSqRQ9evTEqFG+sLGxKfM6k5OTYWVlheeJr5Qql5OTA8PKFZGVlQV19YIXEqKiovD1\\\\\\n\",\n              
\"11/jwoULmDNnDry9vaGhofFObbSzs4NbvfoI/GW7Qs++fTtjOl7Ex2HTpk3vVF9R7t27h9Wr12Dr\\\\\\n\",\n              \"1kDY2zvARCJBamoqrv15Fa09PTHmyy/RvHnzMq1T+PSoUt8pAkx4rxISElCjRg3EPFduOD9J6FXS\\\\\\n\",\n              \"REZGBipUqFDkdufPn8eUKVOQkpICf39/eHl5Kd3GZs2aISEhAZ27dses2XOLDbGf1/+EJf4LcP78\\\\\\n\",\n              \"eciUfH5LUWlpabh06RJevXoFHR0d1KpVC+bm5uVSl/DpUaW+s+j/8wXhP2BoaIi0tDS8evVKqWfL\\\\\\n\",\n              \"YmJioK+vX2x4AYC7uzvOnDmDoKAgjB07FtbW1vD390edOnVKrCMpKQkhISGIi4tDq1atsHfPH7ge\\\\\\n\",\n              \"cQ3jJ05Gc48W+YLscng41qz6EZcunkdoaGi5hRcA6OjowMPDo9z2LwiqQgziEN4rTU1NfNGpE37d\\\\\\n\",\n              \"tlWpcoGbA9CrV2+FtlVTU0PXrl1x69YtdO7cGW3btsWgQYPw5MmTQre/ceMGhgwZAltbW/y2/Xc0\\\\\\n\",\n              \"afY5UtMzAKjh1s2b8BnsjVo1HdC7R1f07dUdDeq5YmD/3nCtXQvh4eGwt7dX6rMIgvBuxBmY8N6N\\\\\\n\",\n              \"HjUKo0aPxgjfUQrdp8rKykLAz+sRHBysVD2ampr48ssv4e3tnXcWNmLECHz11Vd5Z39BQUEYPnw4\\\\\\n\",\n              \"xk2cjBt3oiCVSvPKk8TpsFNYtGAeUt68QY/u3VCxYkXIZDK4u7u/8z02QRDejbgHJrx3JNG6dWs4\\\\\\n\",\n              \"OdfC4qXLi73HJJfLMXLYUKSlpeCPXbtKVe/Tp08xc+ZMHDhwADNmzECNGjXg7e2NPfuC4VbM3Iq5\\\\\\n\",\n              \"ubkYM2oknj59jOCDB6GpqVmqdgjCh0SV+k5xCVF479TU1LBr1y6cP3savsN9ipypPiYmBt79+yD6\\\\\\n\",\n              \"0QMEbtlS6notLS2xceNGHD16FMHBwejVqxfWrt9YbHgBf88isnLNOqSkpGJXKUNUEIR3JwJM+CAY\\\\\\n\",\n              \"GRnh1KlT0NbShKuzIwZ798fv23/DoeCD+O3XX9Cvdw98VrcWqlapgiNHjpTpNEm1atWCn58fzMzM\\\\\\n\",\n              \"0a59B4XKVKhQAeMnTsaaNWvKrB2CIChHXEIUPjgvX77Epk2bcf7Cebx58wb6+vpo/nlzDBzordSk\\\\\\n\",\n              \"usoYMGAA3Oo3xOgxYxUuk5OTA0c7Kxw7dgw1atQol3YJwn9NlfpOMYhD+OAYGRlh0qSJACb+Z3Xe\\\\\\n\",\n              \"u38fQ4crN39ghQoV4OxSCw8ePBABJgjvgbiEKAgAcnNySnymrDCamprIyckphxYJglASEWCCAMDU\\\\\\n\",\n              \"zAzR0Y+UKkMSjx49VHgSYkEQypYIMEEA0Kd3bwRuDlCqzOXwcKSnpeGzzz4rp1YJglAcEWCCAKBn\\\\\\n\",\n              \"z56IuPYnIv/6S+Ey69asgq/vqEInExYEofyJ//MEAUDFihXx9dczMKBvLyQnJ5e4/S9bA3H2TBiG\\\\\\n\",\n              
\"DfP5D1onCEJhRIAJwv+MHz8OrVu3RiuPprgcHl7oNm/evMGiBfMw89uvERwcDGNj4/+4lYIgvCWG\\\\\\n\",\n              \"0QvC/6ipqWHp0iVwXO8I7369IJFI0be/NyyqVEFGRgYunD+HHdt/RXMPD5w9exZWVlbvu8mC8EkT\\\\\\n\",\n              \"ASYI/6CmpoaRI0dg2DAfHD58GLt370HYqePQ1tZGTaeauH79OiwtLd93MwVBgAgwQSiUhoYGOnTo\\\\\\n\",\n              \"gA4dFJtaShCE/564ByYIgiCoJBFggiAIgkoSASYIgiCoJBFggiAIgkoSASYIgiCoJBFggiAIgkoS\\\\\\n\",\n              \"ASYIgiCoJBFggiAIgkoSASYIgiCoJBFggiAIgkpSI8n33YjyIJFIYG1t/b6bIQiCoFIePXqEFy9e\\\\\\n\",\n              \"vO9mKOSjDTBBEATh4yYuIQqCIAgqSQSYIAiCoJJEgAmCIAgqSQSYIAiCoJJEgAmCIAgqSQSYIAiC\\\\\\n\",\n              \"oJJEgAmCIAgqSQSYIAiCoJJEgAmCIAgqSQSYIAiCoJJEgAmCIAgqSQSYIAiCoJJEgAmCIAgqSQSY\\\\\\n\",\n              \"IAiCoJJEgAmCIAgqSQSYIAiCoJJEgAmCIAgqSQSYIAiCoJJEgAmCIAgqSQSYIAiCoJJEgAmCIAgq\\\\\\n\",\n              \"SQSYIAiCoJJEgAmCIAgqSQSYIAiCoJJEgAmCIAgqSQSYIAiCoJJEgAmCIAgqSQSYIAiCoJJEgAmC\\\\\\n\",\n              \"IAgqSQSYIAiCoJJEgAmCIAgqSQSYIAiCoJJEgAmCIAgqSQSYIAiCoJL+Dwmtm6inx7CKAAAAAElF\\\\\\n\",\n              \"TkSuQmCC\\\\\\n\",\n              \"\\\"\\n\",\n              \"  frames[2] = \\\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAbAAAAEgCAYAAADVKCZpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\\\\\\n\",\n              \"AAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0\\\\\\n\",\n              \"dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd1gUx+PH8TdNREWUIsVKwI4KKvaO\\\\\\n\",\n              \"RhFUwIaKDRU19hZLbNHE3rEbrFiIioWiiBXFih1LbDEKIlWKdLj9/ZGffENAuTMins7ree4BuZ3d\\\\\\n\",\n              \"2QP3czM3M6siSZKEIAiCICgZ1aKugCAIgiB8DBFggiAIglISASYIgiAoJRFggiAIglISASYIgiAo\\\\\\n\",\n              \"JRFggiAIglISASYIgiAoJRFggiAIglISASYIgiAoJRFggiAIglISASYIgiAoJRFggiAIglISASYI\\\\\\n\",\n              \"giAoJRFggiAIglISASYIgiAoJRFggiAIglISASYIgiAoJRFggiAIglISASYIgiAoJRFggiAIglIS\\\\\\n\",\n              \"ASYIgiAoJRFggiAIglISASYIgiAoJRFggiAIglISASYIgiAoJRFggiAIglISASYIgiAoJRFggiAI\\\\\\n\",\n              
\"glISASYIgiAoJRFggiAIglISASYIgiAoJRFggiAIglISASYIgiAoJRFggiAIglJSL+oKFBZ9fX2q\\\\\\n\",\n              \"VKlS1NUQBEFQKs+fPycmJqaoqyGXrzbAqlSpQkhISFFXQxAEQak0bNiwqKsgN9GFKAiCICglEWCC\\\\\\n\",\n              \"IAiCUhIBJgiCICglEWCCIAiCUhIBJgiCICglEWCCIAiCUhIBJgiCICglEWCCIAiCUvpqJzILgiB8\\\\\\n\",\n              \"LqmpqTmrVxgYGFC8ePEirtG3QbTABEEQPoIkSQQHB+Pcpy/6BgZYN25Cw0aN0dc3wKX/AK5evYok\\\\\\n\",\n              \"SUVdza+aCDBBEAQFpaSk4ODoRL/+AylnWpsDZ+9y8Fwo3kH38Dp1Ex0TM7r37E3vPn1IS0sr6up+\\\\\\n\",\n              \"tUSACYIgKCAzMxP7Ll3JQIPtPsH0HvwDpXXK5DxfRlePvkPHssM3mNiEVHr06El2dnYR1vjrJQJM\\\\\\n\",\n              \"EARBAYsXLyY9W5VpC9ehUazYe7fTLK7F7OVbiIxLYM2aNZ+xht8OEWCCIAhyyszMZP2GjbhNmo26\\\\\\n\",\n              \"esFj4NQ1NBg6YSbua9cik8k+Qw2/LSLABEEQ5OTj44NR+UqY16gtd5na9RpSvIQ2gYGBhVizb5MI\\\\\\n\",\n              \"MEEQBDlduXIF6xY2CpVRUVHBuoUNV65cKaRafbtEgAmCIMgp6e1btEqWVLicVslSvH37thBq9G0T\\\\\\n\",\n              \"ASYIgiCnMjo6JCXEK1wuKeENOjo6hVCjb5sIMEEQBDnZ2NhwPtBXoQnKMpmMCyf9aNeuXSHW7Nsk\\\\\\n\",\n              \"AkwQBEFO7dq1Q8rO5M51+T/PuhZ8hjI6pWnSpEkh1uzbJAJMEARBTioqKri49GPZnImkpiQXuP3b\\\\\\n\",\n              \"pEQ2Lp3L5EkTUVFR+Qw1/LaIABMEQZDT0aNHWevuTnlDA35060XCm7j3bvsmNpopQ3vQwaYNLi4u\\\\\\n\",\n              \"n7GW3w4RYIIgCAWQyWTMmTOHUaNG4ePjw7lzZ7Fp1YK+Ha1ZOmsCj+/fJT0tlbTUFB6G3mLprAn0\\\\\\n\",\n              \"69SYbva2rHV3F62vQiJupyIIgvABb968wcXFhbdv3xISEoKhoSEAS5cuYfz4cbRu3YYrQYG8iYsF\\\\\\n\",\n              \"wKR8BVwHD2LjqgcYGRkVZdW/eiLABEEQ3uPu3bs4OjrSpUsXlixZgoaGRq7nS5YsSVRUJBEREZT8\\\\\\n\",\n              \"iPlhwn8juhAFQRDysW/fPtq1a8fPP//MypUr84QXgL+/P23atBHhVUREC0wQBOEfsrKymDp1KocO\\\\\\n\",\n              \"HSIwMBBLS8v3bnvkyBG6dev2GWsn/JMIMEEQhP8XFRVF79690dTUJCQkBF1d3fdum56eTkBAgLhV\\\\\\n\",\n              \"ShESXYiCIAjAtWvXsLa2pnnz5vj5+X0wvADOnDmDhYVFzqAO4fMTLTBBEL55Hh4eTJ8+nU2bNuHo\\\\\\n\",\n              \"6ChXGdF9WPREgAmC8M1KT09n3LhxnDt3jqCgIGrUqCFXOZlMxtGjRzlz5kwh11D4EBFggiB8k8LD\\\\\\n\",\n              
\"w+nRowcmJiZcvXoVbW1tucuGhIRQunRpqlWrVog1FAoiPgMTBOGbExQUhLW1NV27duXAgQMKhReI\\\\\\n\",\n              \"7sMvhQgwQRC+GZIksWbNGnr16sX27duZPn36Ry3zdOTIERwcHAqhhoIiRBeiIAhKKysrC19fX3bs\\\\\\n\",\n              \"9OTV6whUVVUxrVIFt6FDaN26da5wSklJwc3NjXv37nHp0iVMTU0/6phPnz4lNjaWRo0afarTED6S\\\\\\n\",\n              \"aIEJgqCUtu/YQaXKVfhp3gLK1W1GR9eJtB84lmLlqzFo2Aiq16xFYGAgAM+ePaNZs2aoqKgQHBz8\\\\\\n\",\n              \"0eEFf7e+unTpgqqquHwWNdECEwRB6SxYuJB1GzYzaslmvqtVL9dztRo24/veg7l98QzO/VwYNsSV\\\\\\n\",\n              \"bVu3MnPmTEaPHq1Ql2Fqaiq///47585fICnpLTqltblwPohFixZ96lMSPoJ4CyEIglLZv38/7us3\\\\\\n\",\n              \"8tNvB/OE1zsqKipYNm/Hj2t3s3qNO/Pnz2fMmDFyh1d6ejo/Tp1G+YoVWeuxEwxMKVe3Gdl6lcku\\\\\\n\",\n              \"VoKhw9yYPWcOmZmZn/LUBAWpSJIkFXUlCkPDhg0JCQkp6moIgvAJSZKERd16dB05jbpNW8tVJmDf\\\\\\n\",\n              \"VhKf3sH7wH65tk9OTqZTZzsy1bVwHjcTw4pV8mzz6vlT9qyYi15JTXyOHEZTU1OR0/iiKdO1U7TA\\\\\\n\",\n              \"BEFQGhcvXiTxbTIWjVvKXaalXQ9Ongzk1atXBW4rSRLOffuhrlOOMUs25xteACZVzJiwYhvJMnUG\\\\\\n\",\n              \"uQ6Ruy7CpyUCTBAEpXHgwEGadu6u0ACKEtqlqd+yPT4+PgVue+3aNa7fvIXrzCUFHkNNXR23n1dy\\\\\\n\",\n              \"8uQp7t+/L3d9hE9HBJggCEojKiaasgaK3+VYx8CI169fF7id+9p1tHVyQT2fe3/lp1hxLVp3c2bt\\\\\\n\",\n              \"uvUK10n478QoREEQlIamZnGyP2LgRFpKMvPnb2TZsmUYGhrmPIyMjHK+19fXZ//+/az2u6LQvls7\\\\\\n\",\n              \"9GG2iy3r161VuF7CfyMCTBAEpVGrRnX8z1/DpoeLQuXCHoVy6NAhWrVqRWRkZM7j9evXREZGcuPG\\\\\\n\",\n              \"DV6+fAkqKpQu++HbqPybnpEJb5MSycjIoFixYgqVFf4bEWCCICiNgQMHMu+XX3COf4N2mbJylXn+\\\\\\n\",\n              \"MJQ3Ua+xtbVFXV0dHR2dfBfhTUhIoHyFigrXSZIkJElCTU1N4bLCfyM+AxOEL1B2djYnTpxg06ZN\\\\\\n\",\n              \"rF+/noMHD5KcnFzU1SpyBgYG2Nt34fjuzXJtL0kSBzctR0WSChwarq2tjZq6GlHhLxSqU9jTPzA0\\\\\\n\",\n              \"MhYBVgREgAnCFyQxMZFfFyygYmVTRk78kR1+Z9kdEMycpWswqVCRMePG8eKFYhfYr83SxYu4GnCY\\\\\\n\",\n              \"M4f2fHA7SZLwcl9IVmIs8+b9jJOTE25ubsTGxubaLj4+Hg8PD2xsbJBlZRP4+w6F6nPG25OhQ8VQ\\\\\\n\",\n              
\"+qIguhAF4QsRHh6OTYeOlDAxxeGnNZSvZpHr+TeR4YT47KGBdSN8jx6hcePGRVTTz+/+/fv85rGV\\\\\\n\",\n              \"R0+ekp2dTdOmzTiyeTm3L56ly6AfMKttmbOtTCbj3tULBOzegkpGMoEBxzEwMKBHjx7Mnj2bWrVq\\\\\\n\",\n              \"8fPPP2NkZMSePXs4ceIENjY2jBs3jqpVq9KydRu6uY6mlE7BXZQJsdFcOn6ErSvuFeLZC+8jAkwQ\\\\\\n\",\n              \"vgAJCQnYdOhI5aYdadlnRL5LHpU1LE+HoVOoaNGAzvZduBB0jpo1axZBbT+fmzdvMnrceB48fEjD\\\\\\n\",\n              \"Tj0wqm+DqpoabyLCQOMK964G8+T2VcroG2JYoTKSJOPlkz/Q0S7F2NGjGDBgAFpaWgCULl0aZ2dn\\\\\\n\",\n              \"IiIiGDNmDFpaWowfP55NmzZRtuzfYSVJErVr12LhD32ZuXk/WiVLvbdubxPjWTFhMOPHj8PExOSz\\\\\\n\",\n              \"vB5CbiLABOELsGr1akpVMH9veP1TjSbteNPLjQmTp3Dcz/cz1fDzO3PmDE49etJhyGS6zt6Eukbu\\\\\\n\",\n              \"EX4te7ry9NZlDi6ZSuf27WjduhVqampUrlwZKyurnNfx8ePHeHp64unpSbFixejfvz+LFi3izJkz\\\\\\n\",\n              \"zJgxg/j4eObPn0/JkiUZNWoUSQkJNG1Qj1+GOuI4fDJWLWxQU//fpTIrM5Pr5wLw3riM7t26MnfO\\\\\\n\",\n              \"nM/6ugj/I9ZCFIQilpWVRflKlek5dyPGZvK1qDLSUlnZrzV3bt2gSpUqhVvBIvD48WMaN21G75mr\\\\\\n\",\n              \"Mbdq8sFt46Mi2DSuNxvcV9O9e3cAoqOj8fLywtPTk+fPn+Ps7Ez//v2pX79+rjcIsbGxTJ8+HR8f\\\\\\n\",\n              \"H4yMjDAwMODgwYOUKlUKLy8vlq9czcuwl9Rr3g7NEqVISYzn+rlASpfWplHDBvTu3ZuuXbt+VcPn\\\\\\n\",\n              \"lenaKQZxCEIRCwgIoLS+sdzhBX+vAFGvfTc8PLYWYs2KzuKly2jUpW+B4QVQppwxjpMWMGPWbPbt\\\\\\n\",\n              \"20eXLl2oWrUqly9fZu7cuYSFhbFq1SoaNGiQp3Wrp6fHL7/8QtmyZXnx4gUZGRm8fPkSFRUVnJ2d\\\\\\n\",\n              \"uXblEsd8fahcpjjXT/lw5ZQ/Va1bUadjTxJLlWfWouWUr1iJWbNnk5qaWlgvh/AeogtREIrY06dP\\\\\\n\",\n              \"MTSvpXC5cmY1efT0diHUqGglJCTwu5cX47cek7tM1QbN8Vocx4oVKxgzZgx79uxBW1u7wHKPHz/G\\\\\\n\",\n              \"1taWvn37Mnv2bDZu3Ejr1q1xdXVl1qxZlCpViiNHfTjqH4DtiJ+o2bRdru5EgNd/PsJvxxqOHW9L\\\\\\n\",\n              \"YMCxnM/ThMInWmCCUMSysrJQUVV8DpGamjqZmVmFUKOidezYMczqWaOjbyh3GRUVFVo4DqB+Q2v6\\\\\\n\",\n              \"9+8vV3hdvnyZVq1aMXXqVObNm4e6ujqjR48mNDSUiIgIatWqxTA3N7Zs38kI9/1YtPw+T3gBGJlW\\\\\\n\",\n              
\"o+8cd0pWrE6Xbo5kZX19v5MvlQgwQShihoaGJEWFK1zuTWQYxkbyX+SVRWRkJKXLKT6qr6xReV6/\\\\\\n\",\n              \"jpRr2yNHjtClSxd+++03hg0blus5Q0NDdu7cyZYtW9i5yxOXnzegrav/wf2pqKhgP2omr98kcvTo\\\\\\n\",\n              \"UYXrLnwc0YUoCEXMzs6OET+MIikuGm1dA7nKyGQy7gYeYp7XhyfzKiMNDQ1k2dkKl8vOyiIwMBBz\\\\\\n\",\n              \"c3OMjIze+wgICGDt2rX4+/tjbW393v2Fh4dTo0FTylU2k+v4qmpqNHEcwCr3tTg5OSlcf0FxIsAE\\\\\\n\",\n              \"oYiVKVOGHj16ct1vH236j5GrzJOQ85TVLkWTJgUPclA2ZmZmRGz4TeFyEU/u8cPIEbgNG8rr169z\\\\\\n\",\n              \"PS5dusSrV6+4efMmMTExqKioYGtri5GREcbGxvkG3Zp1G2jQfVjBB/6HOi074rv2F8LCwqhQoYLC\\\\\\n\",\n              \"5yAoRgSYIHwBZkz7kcZNm1GpjjXfWX44lN5EhnNo6VSKq6ty4cIFWraU/+7EyqB9+/akxA8j7I+7\\\\\\n\",\n              \"VKheR64yGWmp3DhxmK03QjA1NaVq1aq5n8/IwNXVFVNTU0JDQylbtiyxsbF5gi48PJzr16/z+vVr\\\\\\n\",\n              \"Hj96RCc5W1/vqBfTxKB8RcLDw0WAfQYiwAThC1C1alUO7v8dp569aDNwApYdHPJM3JUkiSfXL+C7\\\\\\n\",\n              \"cia/zpvLd6am9O7dG1dXV+bMmYOGnDdh/NKpqanxw4jhHP59C84zVxc4sRvgqp8XjRo3wtTUNM9z\\\\\\n\",\n              \"CQkJODk5Ubp0aU6dOkWJEiUAKFeuHOXKlaNu3br57rNSFVNUKPjYeUh//66EwicGcQjCF6Jt27ac\\\\\\n\",\n              \"DjxB7I2TrHJpw0mPZdw940do0DHOe21mwzBbLu1Yxm8b1zFm9Gjs7Oy4efMmN27coHnz5jx+/Lio\\\\\\n\",\n              \"T+GTGTd2LKmRLzm5Y02BYXD/0hmC9m5kzcoVeZ4LCwujZcuW1KhRgwMHDuSElzxMTMoT9eKZQvXO\\\\\\n\",\n              \"yswg+tVLypcvr1A54eOIABOEL4iVlRVBZ05z6UIQ9U20Sb1/nsSbp6ikmsS+ndt4eC8UBweHnO0N\\\\\\n\",\n              \"DQ3x8/NjwIABNGvWjK1bt34V7/61tbU5eeI4fwT54zFtKH/dv5XnvGIjXuK/cRFHV/6En8/RPOtC\\\\\\n\",\n              \"hoaG0qxZM1xcXFi7dq3CtzsZ6jqI6/5eCpUJPX8CizoWVKyo+H3FBMWJLkRB+ALVqFGD5cuWyrWt\\\\\\n\",\n              \"iooKo0ePpm3btvTt2xd/f382bdqEnp5eIdeycBkYGJCRlkLS01AOLZqIRsnSGH9XA1U1dd5EhhH2\\\\\\n\",\n              \"6B6DBg5k27WrVKpUKVfZM2fO0Lt3b1atWkXfvn0VPnZiYiIPHz7k/tXzRIf9iUGFvF2T/yaTybhy\\\\\\n\",\n              \"eCcLZk1T+HjCxxEtMEH4StSuXZurV69SuXJlLC0tOXXqVFFX6T9Zt24dMTExHDl8mL/+fMZva1fh\\\\\\n\",\n              
\"6tgRl86t+WXaRF6FvWTVyhV5wmvv3r307t0bLy8vhcMrKyuLzZs3U716dWJiYvhp5k94zh7J2/jY\\\\\\n\",\n              \"D5aTJAn/DQvQLalJt27dFD5X4eOIFpggfEU0NTVZvnw5nTp1YuDAgfTt25f58+ejqalZ1FVTSEJC\\\\\\n\",\n              \"AjNmzKBDhw40b94cgHbt2n2wjCRJLFu2DHd3d06dOkWdOvKNYHwnMDCQiRMnoqenh6+vLw0aNAAg\\\\\\n\",\n              \"Iz2DjWN703nkT9Ro1ArVf3VFRr14xqkdq8l+E8HJEwFfzWAaZSBWoxeEr1RMTAxDhw7lr7/+Ys+e\\\\\\n\",\n              \"PUp17zA3Nzd27tzJs2fP5LrXVnZ2NuPHj+fs2bMcO3ZMoSHs9+/fZ8qUKTx69IilS5fSrVu3PCMf\\\\\\n\",\n              \"Dxw4wPwFC4mMjqF2686UKqNHZnoaL+5cJeLZQ4YOGcLsWTMpWbKkwuf6pVGma6foQhSEr5S+vj6H\\\\\\n\",\n              \"Dh3ihx9+oFWrVmzYsEEpBng8e/aM7du3M3nyZLnCKzU1lR49enDv3j0uXLggd3hFR0czatQoWrdu\\\\\\n\",\n              \"Tfv27bl37x4ODg75Dtvv0aMHt29cx8f7AC3MylFRJZE6eur8NH4kr8JesnjRwq8ivJSNaIEJwjfg\\\\\\n\",\n              \"jz/+oF+/fhgbG+Ph4UG5cuWKukrv1aJFCx4+fEh4eHiBXZ8xMTF07doVU1NTtm7dKldXaXp6Ou7u\\\\\\n\",\n              \"7ixevDhnFXplH/DyKSnTtVO0wAThG1C9enUuXryIhYUFlpaWHD9+vKirlK/Tp09z5coVtmzZUmAY\\\\\\n\",\n              \"PXv2jObNm9O6dWt27dpV4PaSJHHgwAFq1qzJ+fPnuXDhAqtXrxbhpcykr1SDBg2KugqC8EU6c+aM\\\\\\n\",\n              \"VLFiRWns2LFSSkpKUVcnR3Z2tlS+fHnJysqqwG2vXbsmGRsbS2vXrpVr31euXJGaN28u1atXTzp5\\\\\\n\",\n              \"8uR/repXTZmunaIFJgjfmDZt2nD79m1ev35No0aNuHPnzhfx2Zi7uzuRkZHs27fvg9v5+/tja2vL\\\\\\n\",\n              \"+vXrGTVq1Ae3ffnyJS4uLjg4OODq6sr169exsbH5lNUWipAIMEH4BsXHx1OpchVeRbzG0tISDQ0N\\\\\\n\",\n              \"KlcxZe7PP/Pq1avPXp/k5GRmzJiBs7Mz1apVe+92v/32G66urhw9ejTXiiT/9vbtW2bNmoWlpSXf\\\\\\n\",\n              \"ffcdjx49wtXVVeHVOIQvmwgwQfiGZGVlMfKHUdRv0JAnkQnM++0gB0Ne8PvVP5m49Deu3X9GzVq1\\\\\\n\",\n              \"mTV79mdtlY0dOxZJkli/fn2+z0uSxJw5c1i4cCFBQUE0bdo03+2ys7Px8PCgWrVqPH/+nFu3bjFv\\\\\\n\",\n              \"3jxKlSpVmNUXioiYyCwI3wiZTEafvv14/iqKDb6XKaldOtfz39WwYMTMJTiPnMLCcQOJe/OGtWvW\\\\\\n\",\n              \"yLUa/H/x119/sWPHDpYtW4a2tnae5zMzM3FzcyM0NJSLFy9iaJj/XahPnTrFpEmT0NbW5siRIx+8\\\\\\n\",\n              
\"WaUyyMrKQk1NrdBff2UmWmCC8I1YtWoVD5/+xbRV2/OE1z+V0TNg9oZ9HAs4WeDnUZ+Cs7MzxsbG\\\\\\n\",\n              \"jB07Ns9zSUlJ2NvbEx0dzdmzZ/MNrz/++IOuXbvi5ubGrFmzCAoKUsrwkiSJK1eu0K//ALRL66Cp\\\\\\n\",\n              \"qYmGhgY1a1vg7u5OYmJiUVfxiyMCTBC+AdnZ2axctZrBU36mmGbxArcvqV0al7EzmP/rgkKt1+nT\\\\\\n\",\n              \"p7l69Sq7d+9GVTX35SgiIoLWrVtTuXJlDh8+nGeicGxsLGPHjqVFixa0atWK+/fv0717d6VsscTF\\\\\\n\",\n              \"xdHWpj1OvZxRNajC8kNB7Al5wc5LT+g5fg6/+wVSqXIV9uzZU9RV/aKIABOEb8Dx48fRLqtH1dqW\\\\\\n\",\n              \"cpep38KGVxERjB07luzs7E9eJ0mScHFxyQmgf3rw4AFNmzbFycmJTZs2oa7+v087MjIyWLlyJTVq\\\\\\n\",\n              \"1EAmk3H//n0mT56sdOs9vhMfH0+LVq3RqViV5YfO02XACMroGaCiooK6RjEsrJszZtFGZm7ez8Qp\\\\\\n\",\n              \"P/Kbh0dRV/mLIQJMEL4Bp06fpmEbW4XKqKmp0aazEz4+PrRv356XL19+0jqtXLmS6OhovLxy33Pr\\\\\\n\",\n              \"woULtGnThrlz5zJz5sycFpUkSRw6dIjatWtz8uRJgoKCWLt2LQYGBp+0Xp/bULfhVKlrTd/xs/K0\\\\\\n\",\n              \"Qv+pUtWaTFu3hx+nTuPevXufsYZfLhFggvANSEhIpFRpHYXLlSxdhj59+tChQwcaNmzI/v37P0l9\\\\\\n\",\n              \"UlNT+emnnxg5ciRGRkY5Pz9w4ABOTk7s2rWLQYMG5fz8+vXrtGnThtmzZ7Nu3Tr8/PyUanHi9wkL\\\\\\n\",\n              \"CyPwxAmcR0+Xq+vTpIo57XsMYI372s9Quy+fCDBB+AZolypFavJbhculpbxFR0eHGTNm4Ovry08/\\\\\\n\",\n              \"/cTgwYNJSkr6T/UZPnw4GhoaLF++POdnq1atYvz48Zw4cYLvv/8egPDwcAYNGoS9vT0uLi7cvHkz\\\\\\n\",\n              \"57mvwcZNm2jR2ZHiJeRfCLidU1/27dsnBnUgAkwQvgnNmjXl1sXTCpWRJImQoECaNGkCgLW1NTdu\\\\\\n\",\n              \"3EBdXR0rKysuX778UXV5/vw5u3fvZu3atWhoaCCTyZg4cSKbN28mODgYS0tLkpOTmTt3LnXr1sXE\\\\\\n\",\n              \"xIQ//viDYcOG5fos7Gtw6vQZGrTppFAZ3XLGVDSryo0bNwqpVspDBJggfAMcHBwIf/6EF0//kLvM\\\\\\n\",\n              \"3WvBxEVHMWPGDLy8vMjMzKRUqVJs2bKFJUuW0K1bN+bPn09WVlaesmlpaXh6etK6XQdq1K5LnXoN\\\\\\n\",\n              \"6OXclwsXLtC9e3fMzMwYMGAAaWlp9OnTh5CQEC5cuEDFihXZvn071atX59GjR9y4cYMFCxZQuvT7\\\\\\n\",\n              \"h/0rm4yMDMLCwggJCeH160hKaCvetVtKu4xogSEmMgvCN6FYsWKMGD4cz9W/MnXltgKXVMpIT2Pv\\\\\\n\",\n              
\"ukUsWrgAY2Nj1qxZw6RJkxgxYgRubm44OTnRuHFjBg4cSEBAALt27cLU1BRJkli6bDm//PIrKloG\\\\\\n\",\n              \"pBY3RaVYNciQ8Sg4DF9/B1KTE9nmsZm4uDgcHBwwMjLixIkTXL58mYkTJ1K8eHEOHDiQ0/JTBtnZ\\\\\\n\",\n              \"2cTExBAZGcnr169zPf79s8TERMqVK4ehoSEpqSmkfUTXbkpyUr6Tvr814n5ggvCNyMjIoGMnW9RK\\\\\\n\",\n              \"6fLD3BVoaBTLd7vUlGQWjh9EdNhzHj64n7MM0507d3B3d+fAgQN07dqVMWPGUL9+fVauXMmiRYtY\\\\\\n\",\n              \"sWIFwRcvs/v3w2SUa4tq8TJ59i1JErKEP1GPvoC+bhl69OjB8OHDmTZtGjdv3mTx4sX06tXri5jL\\\\\\n\",\n              \"JUkS8fHxHwyjd/+OiYmhTJkyGBkZ5TwMDQ3z/beenh6qqqpIkkT/AQOIl7QYMHmu3PVKjItlomNL\\\\\\n\",\n              \"nv/5DF1d3U9+3sp07RQBJgjfkJSUFPr07ceNW7f5vscAbByc0dYpC8CbmCgCvXcTeNCTTh2/Jysz\\\\\\n\",\n              \"g/DwcHx8fHKtJRgXF4eHhwfr1q3DxMSEMWPGYG5uTpcuXYlNykDNzAkVtQ/PyZIlv4YXx3Du3QM/\\\\\\n\",\n              \"Pz+mTJnCuHHjKF684EnW/9Xbt2/zDaN//ywyMhItLa0PhtG7h4GBARoaGnIdPyYmBk9PTzw8PEhI\\\\\\n\",\n              \"SOBNQiIbAq5TrLiWXOWPbl+Havwrdm7f9l9ehvdSpmunCDBB+Ma8W7LIfe06Dh86RPESJZAkicyM\\\\\\n\",\n              \"DJydnRk96gfq1atHdnY2w4cP58GDB/j7+6Ojk/uzmqysLHx8fHB3d+fhw4fEJ7wlu4ItqiXkm5eV\\\\\\n\",\n              \"FXEFs7IpBJ0785/vEJ2env7e1tG//y2TyXKFz/sCytDQEC0t+UKlINnZ2QQGBuLh4UFgYCBdunRh\\\\\\n\",\n              \"yJAhtGrVCvuu3ShpYk7PH6YUuJ+YiHDmDOrKcX9fGjRo8Enq9m/KdO0UASYI37D09HTi4uJQVVVF\\\\\\n\",\n              \"V1c3TytCJpMxevRoQkJCCAgIoGzZsvnuZ/Xq1fw4axGqZt3lPraUmYzqnwd4HRGe7yCNrKwsoqOj\\\\\\n\",\n              \"5WotJScn54RQQa2lUqVKfbYuymfPnrFt2za2b9+OkZERQ4YMoU+fPrneDERGRtK4SVNadHWmy6BR\\\\\\n\",\n              \"761bVPgLloxxYeyokUyeNKnQ6qxM104xiEMQvmGampoYGxu/93lVVVXWrVvHxIkTsbGx4cSJE+jr\\\\\\n\",\n              \"6+fZ7nzwZWQ61RUa1qyiURKVEuUYM2YMhoaGecIpLi4OXV3dPK2lSpUq0ahRo1wBVbZs2Q+uYvE5\\\\\\n\",\n              \"paam4u3tjYeHB3fv3qVfv374+flRt27dfLc3NDQk+MJ5Ott3IeT0Mdr2GEDzjl1zuhRfPH7Ayf07\\\\\\n\",\n              \"uRR4lF/mzWPMmDGf83S+aCLABEH4IBUVFVasWMGMGTNo27YtJ0+ezLMq/OvISFQ05J+M+04mxQgL\\\\\\n\",\n              
\"C6NWrVrUrl07V2tJX19faeZ9SZLEjRs38PDwwMvLC2tra0aOHEnXrl3lWqOxfPny3Ai5hp+fH0uW\\\\\\n\",\n              \"Lee3X6dSokRJsrKyKF26NMOHu7F11T1MTEw+w9koD+X46xAEoUipqKiwYMECNDU1adOmDadOncp1\\\\\\n\",\n              \"MS2uWRwkxRf8lWVmcO7cOR4/foyxsTHGxsaYmJjk+72BgcEXd0fl2NhYdu/ezdatW0lISGDw4MHc\\\\\\n\",\n              \"vHmTSpUqKbSf8PBwNm3ezOYtWyimqYV5jdokJyWR8jaRAQP6M3jQIBFe+RABJgiCXFRUVJg7dy6a\\\\\\n\",\n              \"mpq0bt2a06dPY2BgwJkzZ4iLiUJKzoAy5nLvT5IktFTe4nf6NJUqVSIiIoJXr14RERFBREQEwcHB\\\\\\n\",\n              \"Od9HREQQHx+PgYFBgUFnaAtS6+kAACAASURBVGhYqC03mUzGqVOn8PDw4Pjx43Tu3Jnly5fTtm3b\\\\\\n\",\n              \"j+rGPHLkCINdXWnR0YFZ6/ZSybxGznNhfz7hxIGdWNavz9o1a+jXr9+nPBWlJwJMEASFDBw4kJs3\\\\\\n\",\n              \"b1KjRg1UVVWxsrKiQwcb7ruvR5I1Q0VVvsuK7G0YyUnx7N+/n9GjR9O0adMPbp+RkUFkZGSeoLt2\\\\\\n\",\n              \"7VrO969evSI2NhY9Pb0Cg87IyIhixfKfC5ef58+fs337drZt24a+vj6urq5s2LDhvQNb5OHn58eQ\\\\\\n\",\n              \"YW7MXLcn31vdVDA1x3XKPGwc+jBxVD/U1dXp3bv3Rx/vayNGIQqC8EEymYyQkBB8fX3x8/Pj+fPn\\\\\\n\",\n              \"dOrUCQ0NDU6ePMmZM2eoWrUqrdvacPFRGurlrArcpyTJUH3hh45mBqqqqqSkpGBtbc3o0aOxs7P7\\\\\\n\",\n              \"T12FWVlZREVF5Qm6f34fERFBZGQkZcqUyRNs//y3rq4uV69eZdeuXdy8eZM+ffowZMgQLC3lv6/a\\\\\\n\",\n              \"+yQnJ1OxUmWmr95JjXoFD4n/8497zB7Wg2dPnxTKBOZ3lOnaKVpggiDkkZSURGBgIL6+vvj7+6On\\\\\\n\",\n              \"p4e9vT2rVq2iadOmOV10W7ZsoW3bthw+fJj01GSyIq6jol4SNd1q7923JMlQi7yAlYUppwID8Pb2\\\\\\n\",\n              \"ZtasWYSHhzNjxgzGjh3LyJEjGTJkSL4jHvOTnp6Ot7c36zZu5unTp2RlZmJQrhz9+vRm2NCh+c4z\\\\\\n\",\n              \"e7f807/D7cGDB3h7e/PgwQOio6ORJIkSJUpQsWJFHj58yIoVK/Jt0RkbG+e5a/SH7N69m5pWjeQK\\\\\\n\",\n              \"LwDT6rVp2NKG7du3M3HiRLmP8zUTLTBBEAB4+vQpvr6++Pr6cuXKFZo1a4a9vT12dnaYmpq+t9zy\\\\\\n\",\n              \"5cuZOnUqqqqqaGtrk56ZDaUqk6ldI9ekZkmSIUv4k+Jv72NZyww/3yM56/llZmaybds25s2bR9Wq\\\\\\n\",\n              \"VdHR0eHs2bM4ODgwevRoGjZs+N7jHzhwgBE/jEK/clUsOjpjXLU2qqrqJESFc++kNw+CAxgyxJUV\\\\\\n\",\n              
\"y5Z9sGX35s0b9uzZg4eHB7GxsQwePJhBgwZRqVIl4uLiCmzRvXr1Kmdawoe6Lk1MTNDW1qaupRU9\\\\\\n\",\n              \"Rk7DqlkbuX9HD2+HsHHueJ48flRoc9mU6dopAkwQvlGZmZkEBwfndA3Gx8djZ2eHvb097du3z7V8\\\\\\n\",\n              \"1PsEBQXRtWtX0tLSyMrKYtiwYcybN4+NGzex2n0tmTJ1slSKI8vOhIwEsrIy2bplI3369Ml3oEVq\\\\\\n\",\n              \"airr169n8eLFtGnTBlNTU7y8vDA0NGTUqFH06tUr13JTHlu38uP0mTjMcMekev7zrFIS3+C7dCK1\\\\\\n\",\n              \"Khlx4HevXCEmk8k4c+YMW7duxc/Pj06dOuHq6oqNjY3C3Zjv1k6UJ+j+Ptc0Dob8hbqcS1C9O0bv\\\\\\n\",\n              \"xt8RExOtUGtPEcp07RQBJgjfkJiYGI4fP46vry8nTpzAzMwMe3t77O3tsbKyUmgU3ebNm5k0aRJa\\\\\\n\",\n              \"WlqYmZmhra3N7du38fHxoVGjRmRlZXH16lV2797N48ePWb16NQMGDOCXX36hY8eOH9x3YmIiK1eu\\\\\\n\",\n              \"ZM2aNfTq1YvGjRuzd+9ebt26xZAhQxgxYgRRUVF06GRLn0We6FX47oP7y8rI4MCcIbg42jFn9ixe\\\\\\n\",\n              \"vnyZMyBDW1ubIUOG0K9fP/T09OQ+/48lSRLR0dGUL18B7xsvFS4/oHUtHv3xEAMD+ZbsUpQyXTvF\\\\\\n\",\n              \"Z2CC8BWTJInQ0NCcrsHQ0FBsbGywt7dn5cqVH1yF430yMzMZP348+/btw9DQkDp16iBJEgcPHuTY\\\\\\n\",\n              \"sWPY29tz6NAhmjdvTrNmzfjzzz958+YNNWvWZNCgQWzfvr3AACtdujRz5sxh1KhRLFq0iIkTJ+a0\\\\\\n\",\n              \"7vbu3YuVlRWlSutg3X1YgeEFoF6sGB1G/czyyc4EXzjP9evX6d27N/v376d+/fqFvrSUJEmEhYVx\\\\\\n\",\n              \"69Ytbt++zc2bN5GQSE5KpKS2/Pc6y8xIJ/ltUp51Kb9VX8baK4IgfDKpqan4+/szatQoqlSpQrdu\\\\\\n\",\n              \"3YiIiGDOnDlERUXh7e2Nq6vrR4VXbGwsHTp04MiRI5ibm2Nra0tUVBR79+5FTU0Ne3t7PD09cXR0\\\\\\n\",\n              \"5OzZswBoaWmRmpoKgLOzM/7+/sTHx8t1PH19fZYtW8adO3dISEjAzs4OfX19zp49S3R0NHXbO8ld\\\\\\n\",\n              \"d70K31G2wneYm5sTFhbG+vXradCgwScPr4yMDO7cucPOnTuZOHEi7dq1Q19fn4YNG7Ju3TqSkpLo\\\\\\n\",\n              \"0aMHrdu05fzxwwrt+2KgHy1atVZo+P/XTLTABOErEB4ejp+fH76+vpw9exYrKyvs7e05fvw4NWrU\\\\\\n\",\n              \"+CQX6dDQULp06YKqqir16tWjdevWbNu2jeDg4Fyrtn///fd4eXnRq1cvdu/eTfHixXMCTE9PL+f5\\\\\\n\",\n              \"4cOHy33sChUqsHHjRiZPnsycOXNYtmwZletYo6Xg3YzrdOhOTEToJ1tl/s2bN9y+fZvbt29z69Yt\\\\\\n\",\n              
\"bt26xcOHD6lSpQqWlpZYWlry448/YmlpiZGRUa6yBgYGjBwzno49+sv9+zlxYAfzZk77JHX/GogA\\\\\\n\",\n              \"EwQlJJPJuHbtWk7X4IsXL7C1taVv377s2LHjP02uzc/Ro0dxdXVFR0eHVq1a0b59e6ZNm8aFCxfy\\\\\\n\",\n              \"nZPUtm1bvL29cXJyYsKECaSlpeU8N2jQIH755ReFAuwdc3Nzdu/ezcyZM/G9el/h8iV19Ih9IF/r\\\\\\n\",\n              \"758kSeL58+c5IfUusGJjY6lbty716tWjadOmjBw5EgsLC0qUKFHgPtu1a0fJ4sU4vGM9joNGFbj9\\\\\\n\",\n              \"8f07SU18g729vcL1/1qJABMEJZGYmMiJEyfw8/PD398fAwMD7O3tcXd3p0mTJoWyfJIkSSxYsAB3\\\\\\n\",\n              \"d3dKlSpF7969ad++Pc7Ozpw8eZLKlSu/t2yLFi3w8fHB1tY21+CIjh07MnToUC5evEhcXBxJSUlo\\\\\\n\",\n              \"a2vTqFEjue8LVqlSJbh0R+HzyUxPpWQB4ZKWlsa9e/dytapu376NtrZ2TquqX79+LF26FDMzs49e\\\\\\n\",\n              \"BV9VVRWfo0do2rw52VlZOA4ahVo+v8Ps7Gz8923l6M71BJ8/rzQLHH8O4pUQhC/YkydPclpZV69e\\\\\\n\",\n              \"pXnz5tjb2zNnzhyqVKlSqMdOSUnB1dWV0NBQVFVVmTBhAm3atKFDhw54eXm99/Yg/9S4cWPWr19P\\\\\\n\",\n              \"//798fLyonfv3ly/fh0dPX3atu+ASQ0r1Itrk5X2lqgnd+nUqSOTJ4zPs6xUXFwcwcHBBAUFERQU\\\\\\n\",\n              \"xJ07d1AtpoUsOwtVNfkvY2F3LtG12f9WComOjs4Jqndfnzx5grm5eU5YdevWjXr16sk9qVoRFStW\\\\\\n\",\n              \"5MqlS/Ry7sOIA7vo4NSPJu1sKamtQ0pyEtfOnuDEwV0YGxlyKTj4g/PxvkUiwAThC5KZmcmFCxdy\\\\\\n\",\n              \"QispKQk7OzvGjBmDjY2NXHOzPoWXL1/i4OCAvr4+0dHRrFq1imbNmtGiRQvWrl1L27Zt5d5X/fr1\\\\\\n\",\n              \"MTY2ZsKECRw+cpRjgaeo2K4v7YbaU6zk/0bgZaYk8fiSL7ZdnfjBbQj16tYhKCiI8+fP8/z5c5o0\\\\\\n\",\n              \"aUKDBg1yVqXPzMrg8ZUzVG/WQa56pL1NJPScP03NjbCzs+P27du8ffuWevXqYWlpSdu2bZkwYQK1\\\\\\n\",\n              \"atXKNdessJUvX57g80HcvHkT97XrWPHjMJKSEilVSpvmzZpx+OB+rK2tP1t9lIkIMEEoYtHR0Rw7\\\\\\n\",\n              \"dgw/Pz9OnDhB1apVsbe3Z+/evVhaWn72GzVevHiRHj160KFDB/z9/dmzZw/169enRYsWTJ48mV69\\\\\\n\",\n              \"eim0v+LFiyOTyXAdMoQV6zbTYqoHWrpGebbTKKGNqU0fjOq3Y9XiIZiVL8eggQMZMGAAr169Yv78\\\\\\n\",\n              \"+SxevBhdXV3GjRtH1apVmbNoOd81aImGZsGBc37PWrRKlKBMmTIMGzYMS0tLKleujIqKCrGxsWzb\\\\\\n\",\n              
\"to2Jk6bwJv4NJbRKYFXfklEjR1K7dm2FzvdjWVlZsdXjt89yrK+FmMgsCJ+ZJEncvXs3p5V17949\\\\\\n\",\n              \"2rdvj729Pba2tnlGq31O27ZtY+rUqTg7O7N//358fHyoVasW7du3p0WLFixZskThfUZHR1O9enUy\\\\\\n\",\n              \"smQ0mbadkgYVCiyTEvOKiwv609e5F/v37ycpKYmGDRuyYMECbGxsgL8Hsjg4deePV3E4zFhDseL5\\\\\\n\",\n              \"f7YlSRKXD/zGXf/dGJXTp1KlSmzduhUDAwPS09MZP2Eie/bsoXGb72nZyQGdsnqkp6Vy+0oQAQd3\\\\\\n\",\n              \"U7NmTbZv9eC77wqeb/Y1UKZrp2iBCcJnkJqayunTp3OGumtoaGBvb8/PP/9Mq1at5Lpr73+VlpbG\\\\\\n\",\n              \"xYsXiY2NpXjx4tSsWRNz87/v35WVlcWPP/7I0aNHcXZ2xtfXl6CgIExNTenevTvfffcdixYt+qjj\\\\\\n\",\n              \"qqurk5T0lgqNvpcrvABK6JugW7s5uzw9GTpkCLNnzyYmJoZbt24xderUnMEVmZmZaJUsxZYRnbF2\\\\\\n\",\n              \"GEzd9o4UL/V3t6QsO4vHV89y23838eF/kpmeipWVFVpaWlhaWrJp0yaWLV+BVKwkW/wuoVM29yoc\\\\\\n\",\n              \"da2b0Wf4JHz2eNCseXPOnD5NzZo1P+o1EAqHCLAi8PbtW+Li4tDU1ERPT0+MKvpKhYWF5QTWuXPn\\\\\\n\",\n              \"qF+/Pvb29pw4cYLq1asX+uoP77x48QL3tWvZtm07FSqbol/OiIz0dO7fuUGdOnVwHTyInTt3An8P\\\\\\n\",\n              \"fw8KCuLixYsYGhoyYsQIUlNT2b9/v9xdmampqVy5ciVnwMWVK1eQVNWp1Eaxrkez7/sRd/8iV69e\\\\\\n\",\n              \"xdTUFBMTEywtLalXrx5jxozB0tKS8uXLA3D+/HlWr13HBtd2lDU0QU1NnfjoCL77zoyZY0fTs2dP\\\\\\n\",\n              \"0tLSWLZsGRs2bKB169a49B+ARcOmzFi25b3rHqpraOA4cASldMrQ0daW+6Ghn+1zSKFgogvxM8nK\\\\\\n\",\n              \"yuLo0aOsW7eeS5cuoqurR1p6GqqqqgxxdWXEiBEfHJIsfPmys7Nzzc0KCwvD1tYWOzs7Onbs+Mnn\\\\\\n\",\n              \"Zsnj9OnT9OrtTMeuvXDo60qFyv/rBstIT+dswFE83BdTqoQmNapX582bNxw+fBgdHR3mzZvHkSNH\\\\\\n\",\n              \"OHv2bM6q8flJSEjIGSF4/vx5bt26Rd26dWnZsiWtWrWiUaNGGBoZ02XjZYVCW5Ik/Ec140RAANbW\\\\\\n\",\n              \"1h+swztxcXGEhYWRmZlJuXLlqFixYp5tIiMjmT59Onv3ebH3XChaJeULpF/HD6Jf964fNX9NmXxp\\\\\\n\",\n              \"184PEW/9P4MXL17QubMdJbW1GTh0BFv3Hc7pMnry+A92bd2CVf36zJg+nUmTJn22d+bCf5eQkJBr\\\\\\n\",\n              \"bpahoSH29vasW7eOJk2a/KcbM/5XV69epWev3sxbvY36jVvkeb6Ypibfd+1J6+/t+XFEX27fvkNo\\\\\\n\",\n              
\"6F20tLT47bff2LFjB8HBwXmCIzIykvPnz+cE1pMnT7C2tqZVq1bMnz+fxo0b56yUnpKSwt27d1FR\\\\\\n\",\n              \"VVX471pFRQVU1Jg/fz7lypVDV1dX7seHGBoaYmxiQkenPnKHF4Btr0G4r56Pm5ub+D/6hRAtsEIW\\\\\\n\",\n              \"ERFB02bNGDTsB0aMHv/e7cLDXtKvexcGDnBh+vTpn7GGgqIeP36c08q6du0aLVq0yLlv1pfSipYk\\\\\\n\",\n              \"iTp169HHbSLtbB0K3D49LRW3Hu1ZuXwJMpkMNzc3goKCMDc356+//soJq6CgIKKiomjevDmtWrWi\\\\\\n\",\n              \"ZcuWVK1alRcvXvDkyROePHnC06dPc76Pi4ujSpUq/PHoMR2XB1CspPxLP2WmvOXE5O/x9TlKfHw8\\\\\\n\",\n              \"cXFxBT40NDTQ1dVFT0/vgwE3afIUJizcQPU6Bd89+h2ZTMbA9pZcuXTxq56P9aVcO+UhWmCFbMiQ\\\\\\n\",\n              \"ofTs0/+D4QVQvkJF9h32p3PbZtjY2NCoUaPPVEOhIBkZGbnmZiUnJ2NnZ8e4ceOwsbEptPsy/RcX\\\\\\n\",\n              \"L14kJTWNtp26ybW9ZnEteg0exa8LFvLg/j2GDx/O3LlzCQoKIjMzkyZNmmBmZsbAgQNJT0/nzz//\\\\\\n\",\n              \"5PDhwyxbtozk5GTMzc0xNzfHzMyMxo0b069fP8zNzdHS0mLx4sU8fraOF8E+mH/vIvc5hF32w76r\\\\\\n\",\n              \"A506dZJre0mSSE5Ofm+4xcTE8OjRI+Li4oiKikKvnGKjPVVVVdEzMCQ2NvarDjBlIgKsED158oRr\\\\\\n\",\n              \"IddYv32fXNsbGZsw7IexrF27jp07RYAVpXdzs3x9fQkMDKRatWrY29vj5eWFpaXlF9+FtH7DRrr2\\\\\\n\",\n              \"HqxQPdvbObL85ymoq6ly7NgxNDU1KVOmDC9fvuTMmTOEhYVhZmaGubk5bdu2ZdiwYZibm2NkZJTr\\\\\\n\",\n              \"OJIkERAQQM+ePbl+/TqampoU11Dj1QVvzNr3RUWOwSCSTEZEsDdrPbfJXX8VFRVKlSpFqVKl/l5q\\\\\\n\",\n              \"6l8yMjJ49OgRoaGhnA06T2ZGutz7ztlHevpnGTEqyEcEWCHasGEDvfsOUGhWf2+XgTSzrEFMTEyh\\\\\\n\",\n              \"LF2jjCRJylkzr2TJkujr63/yAJEkiTt37uS0sh48eJAzN8vd3R1DQ8NPerzC9vDhQ9o6DFSojGZx\\\\\\n\",\n              \"LaqYVaO8oR4tWrTIaVWZm5ujp6dX4GseGxvLvHnz2LFjB0lJSVhYWLB3716cnJxo2LAhGdkSofuW\\\\\\n\",\n              \"YdFnygf3JUkSj4+sw6xSeVq2bKnQOcDfXX1//vknd+/eJTQ0NOfx9OlTqlSpgoWFBeUMynH/1jWM\\\\\\n\",\n              \"K1aRe7/xsdHERL3+YrqJBRFgherylatMmjFXoTK6unrUrlOP27dv50zY/FYlJSXh6enJ+vUbePHi\\\\\\n\",\n              \"L0rrlOHt2yT09PQYOWIEgwcPLvAD+w9JSUnh9OnT+Pr64ufnh6amJvb29syfP59WrVop9T2X0tPT\\\\\\n\",\n              
\"UddQvP56+vpMnDiBzp07y7W9JEn4+voyd+5cbt26RcmSJXFxcWHOnDm5Ql+SJKIiwlGLi+PBvsVU\\\\\\n\",\n              \"7TYKjRJ5RxVmpiTxxGcDvLqP3/lzBQZdREREnqB68OAB+vr6WFhYYGFhgb29PdOmTaNGjRo5byaP\\\\\\n\",\n              \"HDnCrPkLsOnSU+7X5oT3XhwdHSldWv4bUAqFSwRYIUp++5aSCoxyekddXZ1du3bx119/oa+vn+tR\\\\\\n\",\n              \"pkyZz760UFG4ePEijk5OWDduyuxfl9KiVVtUVFSQJInr166ww2MjCxYsYOfOndjZ2cm935cvX+bM\\\\\\n\",\n              \"zQoKCqJBgwbY29tz8uRJqlWr9sV3DcpLX1+fmMhXVK9d8IK7/xT1+lWulePfJzo6mlmzZrFnzx6S\\\\\\n\",\n              \"k5OpX78+R48epXPnznlew+DgYO7fv8+kSZOYPn06w38YzeGfumFQpwXGDTugoVWKzNS3xIVe4NX1\\\\\\n\",\n              \"U9jZ2bHlYHCuuw7HxcXlCql3Dw0NjZygatasGW5ubtSuXbvAkLGzs+OH0aO5eTkIqyatCjzfxPg4\\\\\\n\",\n              \"/Ly24edzpMBthc9HBFghKl26NAnxbxQul5iQQFqaARcuXCAmJibXIykpibJly+YJtncPPT29PD8r\\\\\\n\",\n              \"Xbq0Ul2YL1++TNdu3VizcRvt2ue+9byKigoNGzWhYaMmXL92Bdd+Pdi61eO9IZadnc3Vq1dzugZf\\\\\\n\",\n              \"vXqFra0tAwYMwNPTkzJlynyOU/rs7O06c8RrB83byTcAAuCPe7dJTX5Lw4YN831ekiQOHjzI/Pnz\\\\\\n\",\n              \"uXv3LmXKlGHo0KHMmjXrvXPcfHx8cHV1xcrKimbNmlG6dGn2eu6kf//+REVHk3r/GImJiWhra+PS\\\\\\n\",\n              \"thX9PZYSGxuLt7d3rqB61yX57tGzZ09q164t9+1X/k1dXZ2d27fT27kPs913Ub1u/fdum5QQz7zR\\\\\\n\",\n              \"/XHp24cGDRp81PGEwiECrBC1a9cWf5/DtGorf1fgq/AwHjy4R+1aNXB0dKRDhw65PkPLysrKGVH1\\\\\\n\",\n              \"70dYWBi3bt3K8/O0tLR8g+1DjxIlShRJ6GVkZNC9Rw9WrfstT3j9WwPrxnh47meAswNPnz7NCaOE\\\\\\n\",\n              \"hAQCAgLw9fXl2LFjGBsbY29vz4YNG2jcuHGRzs36HB4/fszGjRsJD39FdGQEBobGcpU7vMeDEcPd\\\\\\n\",\n              \"8rw+r1+/5qeffuL3338nJSWFJk2acPLkSdq1a/fB/W3bto3p06fj5+fH0qVLc+7KnJGRgb+/P3v3\\\\\\n\",\n              \"7s3Vstq5YzuLFy2kevXqOUE1btw4LCwsqFix4if/e7SxsWHbVg8GDe5He4e+dO41AKMK//t8KyX5\\\\\\n\",\n              \"Lad99nN450a6O3RjyZLFn/T4wn8n5oEVolevXlHbwoKrdx6jLWe/+dIFPxMT8ZL69etz6NAhbt68\\\\\\n\",\n              \"SceOHXF0dMTOzu6j+t/T09OJjY3NN/Tye0RHRwMoFHj6+vqfZHTW77//jvu6Dew/ekLuMiOHuFC7\\\\\\n\",\n              
\"RlV0dXXx9fUlJCSEli1b5szNym9E2tfKz8+Pvn37IpOgjJ4Bunr6rN5xGM0CVms/F+jH6vk/cuf2\\\\\\n\",\n              \"LQwMDJAkiX379vHrr79y//599PX1GTp0KDNmzChwKSVJkliyZAkbNmzA398fTU1Nhg4diq6uLsWK\\\\\\n\",\n              \"FePixYuEhYVRtWpVLCwsqFOnTk5gmZmZffal1Z49e8batevYsXMnJpVMKaOnT3pqCo/u3aFt27aM\\\\\\n\",\n              \"GT2qwLD+mnwJ1055iQArZC4u/VEtVpzFK9cV+A7yj4f36WHfgaBz53IWDY2Ojubo0aN4e3tz/vx5\\\\\\n\",\n              \"WrZsiaOjI127dv3o7hN5pKSkyB147x6ampoKBZ6uri4aGhq5jtu6dRv6DxmBfbfuctf18sULDHB2\\\\\\n\",\n              \"oI+zM/b29n/fqv0LnJtVmCRJYtmyZcyb/wsWDZsy5Md5GBhXYMX0USTGRjN76QaMTPIuq5SVlYXv\\\\\\n\",\n              \"/l1sdV/EMX8/jI2NmTZtGt7e3qSnp9OyZUsWLFiQ5waT+R0/IiKCO3fusHDhQu7cuUPlypV5/Pgx\\\\\\n\",\n              \"2traZGZmYmxszPfff8+dO3fo1q0bY8aMKayX46O8W8MxPj6eEiVKYGFhgYmJSVFX67P7Uq6d8hAB\\\\\\n\",\n              \"VsiSkpJo1rw59Ro0YtFy9/e+u7x7+yaDnJ1YtGgh/fv3z3ebxMRE/P39OXToEAEBAdSrVw8nJycc\\\\\\n\",\n              \"HR2LvJUhSRJJSUkKBV5cXBza2tq5Qi0gIICHz6MooUAASZJEtYp6hIWF5frg/1uRnp7OyJEj8fH1\\\\\\n\",\n              \"w8K6OeMXrM3pBpTJZOzbuBzfPVup16AxnRyc0S9nSEZ6OreuXcTvoCdm332HbaeOeHp68ujRI4yM\\\\\\n\",\n              \"jBg5ciRTpkzJdwrImzdvcn0+9W4UoJqaGmpqaqiqqjJx4kQiIyPx9fMnPTOT76rWopimJjFREdy5\\\\\\n\",\n              \"fo3BroOZNnWqGJL+BfpSrp3yEAFWyN69i01OTiYhMQmXQUPp4dwPI2MT0tPSuHblEjs8NnL1UjAb\\\\\\n\",\n              \"N26kZ0/5hvWmpaURGBjIoUOHOHr0KFWqVMkJM2W55YNMJiMqKopHjx7x+PFjnj59yqJFiwiPS1P4\\\\\\n\",\n              \"8w7L6pW4ceP6N/eO+fXr1zg5OaGlpcWd0PtsPnYFjWJ5u3LTUpI5d+wwV04f421iPBoaxfjr8UNq\\\\\\n\",\n              \"1qjGrVu3yMrKol27dixatAgrq7+XV0pOTubBgwe5Qiq/ARUWFhaYmpoycuRIihUr9vf8R2dnJNVi\\\\\\n\",\n              \"uAwfT4MmLXP9PsNfPsd791aOHdrHIe+DtGiRd51Goeh8KddOeYgA+4+ePn3K9u3befHiJdkyGeVN\\\\\\n\",\n              \"jHFxcaFOnTpIkoSbmxuxsbEcOHCA27dvs379eg4fOUJcbCyamprUrFmLESOG07dv34/u9srKyuL8\\\\\\n\",\n              \"+fN4e3tz6NAhtLW1cXR0xMnJiQYNGnz2wRiSJJGYmMjr16+JiIj44NeEhAQMDAwwMjLC2NiYkydP\\\\\\n\",\n              
\"cvuPl+goMDowOzubqhV0iYqK+qZudXHjxg0cHBwYPHgwp06foVKdxvT9YbLc5b02r8TXcws/TpmM\\\\\\n\",\n              \"ra0tDx8+zNWyevXqVa4BFe8elSpVyvU3FRMTg52dHbVr12bNmjV07NSJiua1mDBr8QenfFw+f5qf\\\\\\n\",\n              \"Jw3n1MlA6tWr959eC+HTEQH2BSjsX8KtW7eYNm06IddD6OHsQo2atVFRUeHZ08fs37MLMzMzGjZs\\\\\\n\",\n              \"QGBgIJcuXcqzorckSYUSLDKZjJCQELy9vfH29iYtLS0nzFq0aPGfRuBlZWURHR2dbxj9+2dqamo5\\\\\\n\",\n              \"ofTvr//8Xk9PL1edunTpStvv7ek7YLDc9Qo87sfqpb8SEnLto8/tS/DixQs2btzI/gMHiY2NQUNd\\\\\\n\",\n              \"g6rVqjFiuBs9evTI1Z3n5eXF6NGjWbNmDSdPnmTXLk9+C7iGroH8K4bEx0YzuEN91FRVc1ao+OfD\\\\\\n\",\n              \"3Ny8wAEVf/31V84gowULFrBhwwZ2/+7NCg/57h92eN8Ogo4d5HzQObnrLRQuEWBfgML8JZw6dYre\\\\\\n\",\n              \"zs78+NPP9OzjgpaWVq7nMzMz8TvizZTxP/DL/PmMGzeuUOpREEmSuH//PocOHcLb25uwsDC6du2K\\\\\\n\",\n              \"o6Mj7du3zxk1mJycnG8I/ftrbGwsurq67w2jf3792JbQ8ePHmTptOsfPXpE74J0dOzNoQD8GDRr0\\\\\\n\",\n              \"Uccsaunp6Yz84QcOHzpMt559cOjVD2OTCmRmZnDnRgj7dv7Gg3t32LRxI926dWP27Nl4enqyefNm\\\\\\n\",\n              \"5syZQ7ly5QgIOMGBkOcKH7tno+/46/mfHzUgKDQ0FFtbWyZNmsT48eORJInaFnUY89NCGjYteHIw\\\\\\n\",\n              \"QGZGBg6t6nD2zGlq1aqlcB2ET08E2BegsH4JDx48oFXr1mzavpdmLT78n/TBvbv0drDF++Dn7+eX\\\\\\n\",\n              \"yWTExMTkCqEHDx5w6dIlHjx4wJs3b9DS0iIrKwvgg2H07mu5cuUKfEeekJDAlStXSExMpGTJklhZ\\\\\\n\",\n              \"WWFkJP+q3zKZDIs6dRjsNpr+g4YWuL2/z2HGjhjM999/z/r16xU61pcgIyMDO3t71IuVYJH7lveu\\\\\\n\",\n              \"3HLnZghuLt3RUFdDkiQqV67MzZs30dTUJD09nezsbA7fCle4Vd/DugpxsbGUKFFCoXLBwcE4OTmx\\\\\\n\",\n              \"cuVK+vbtC/w9Ab1v/wHsC7iqUD02rfiFUupZrFq5UqE6CIVDmQJMTGRW0IIFCxk+anyB4QVQs3Yd\\\\\\n\",\n              \"Zs1fxJw5czl16uQnqhL93gAAIABJREFUOX5qaiqvX79+b0vp3fdRUVHo6OjkCSFHR0dGjhyJpqYm\\\\\\n\",\n              \"oaGhObd8r127ds7w/I9ZRDg0NBR3d3d+//136tStR5myuiQnv+VGyDXat+/AmDGjadWq4NdMVVWV\\\\\\n\",\n              \"o0eO0LJVK2QyGQMGD3vvxfCI936mjBvB8ePH8ff3p27duixZsoSBAwfmKSNJEteuXWPtuvVcunSJ\\\\\\n\",\n              
\"5Ldv0S5dmrZtWjNq1Cjq1Kmj8DkrIjs7m9jYWKKiooiOjs75unffPtSKabF2h+cHu3frWjVkl/dx\\\\\\n\",\n              \"etq2BiSuXr1KmTJlqFWrFoaGhpwIPMmLp39Q2byG3HUK+/MJ2tql8/QgFOTd6hqenp507Pi/yeaP\\\\\\n\",\n              \"Hz+mZh0rhUO0Rh0rzhzdq1AZQQARYAqJiYnBx+coF28tkbtMV8ee/DJrOg8fPqRGjfwvLu9WW3/f\\\\\\n\",\n              \"50n//JqamoqRkVGeLjxra+tcQWVoaFjgYrROTk7A360mPz8/vL29mTBhAg0aNMDR0REHB4d8b8n+\\\\\\n\",\n              \"bzt27GDylCkMGzGKiyF3MTL+38oPiQkJeO31xMWlPwMG9Gf+/PkFXuDMzc05HxSEg4Mju7ZuYsCQ\\\\\\n\",\n              \"EXTu4oDO/y/mezLAn83rV5OUmICpqSl+fn4sXLiQnj17MmTI/7V353E1pv//wF+VU9qXc051WrSp\\\\\\n\",\n              \"VCrJoEJZkm3siuwiMfatMczYBiP7xxIhxJA9IZJQsoaRXSmyJO20b+f9+8NXv2lanKPC4Xo+HudB\\\\\\n\",\n              \"59zXfV3nnnG9uu/7uq/LC/v27UNAQAAMDQ0BfHhQdfAQT7x9+xZDRo7DxnFToaikhNx37xB+6ji6\\\\\\n\",\n              \"unWDhYUF9v29V+QzOKFQiOzs7CqBlJaWVu17OTk5UFNTA5/Ph6amJjQ1NaGqqoq4O3E4c/kfke5N\\\\\\n\",\n              \"mppbYOyk6di73R9r1qzBgwcPEBoaijdv3sDK0hKn9gdi4u+i/78ZfigIXl5jxAqcnTt34rfffsOp\\\\\\n\",\n              \"U6eqrFlXWlqKRo04NZSsWaNGHJSUlIhdjmHYJUQxBAQE4EzEeWzesUescovmzUF2+ht07Nix2lB6\\\\\\n\",\n              \"+/YtFBQUPnkJTyAQQF1dvUFHFRYWFuLs2bM4evQoTp48CRMTE/Tv3x/9+/eHmZlZle0PHTqEadOn\\\\\\n\",\n              \"40joGZg3q3n4fnpaGgb164kB/frijz/+EKktQqEQ58+fx8aNmxAVdRHv37+HkpISbGxs8fDhA7x+\\\\\\n\",\n              \"/Rp5eXlwdHTEjBkz4OPjg9LSUqxZswYrV67E77//jq5du6JTp87wnjILw8f4VDuwoLS0FJvXrsDx\\\\\\n\",\n              \"w39jT1AQpKWlPxlImZmZUFZWrhRIH//+3z81NTWhqKiIFy9eID4+vuJ18eJF6DQxxo79ISL/98lM\\\\\\n\",\n              \"T4OzfTNYWVmif//+6NWrF/755x/MmzcP2TnvsPXUNajz+J/cz7usDEzq2wF3/rkt0rNY/55dIzw8\\\\\\n\",\n              \"HObm5lW2OXnyJBYv98OmvSdE/j4AcOTvQLx8fBt/7xXv3xXTMNglxO/UmzdvYGBkLHY5Q+OmOHfm\\\\\\n\",\n              \"JJSVlSEQCGBjYwM3N7dK4STOmmENSV5eHn369EGfPn1QWlqKqKgoHDt2DC4uLlBXV6941szOzg6F\\\\\\n\",\n              \"hYWYMGECDh8/XWt4AQBfUxPBh0PRoa0dhgwZAlNT00+2RVpaGl26dEGXLl2qfObk5ITw8HD06dMH\\\\\\n\",\n              
\"p0+fRrt27aCnp4devXrB19cXffr0wejRo/HHHwvgu2AZBg8fU2M9HA4HU+fMB4Hwc+/esGjWDFpa\\\\\\n\",\n              \"WhUBZGxsjLZt21YKJB6PV2UWEaFQiJcvX1YE1JUrVyr+/vr1azRp0gRmZmYwMzODvb09Hj+Jx4DB\\\\\\n\",\n              \"1T+0XhMuXxOtHZ0wc+pkyMjIYNiwYVBTU8OhQ4dwNiICSycPw4ItwVBWrX5yXQDIe/8OiyZ4YuLE\\\\\\n\",\n              \"CSKFl1AoxKxZsxAREYHLly9DV1e30uclJSU4c+YMgoKCcPdWLF6/fA5dMdbZCju6D0sXifZLDcP8\\\\\\n\",\n              \"GwswMcjIyKC8vFzsckJhObp06QJ/f/8GaFXD4XA4FQGyYcMG3LhxA0ePHoW7uzvKyspgamqKlq1a\\\\\\n\",\n              \"w9au5pm8/01bIIDn8FHYsmULVq9eLXZ7CgoKKs6C2rRpgz/++ANPnjxBWloa7Ozs0L9/f5iYmCA/\\\\\\n\",\n              \"Px9paWkfZuiwsKo1vP5t6uz5CDt+GCtWrKjxfh0RISMjAzdu3Kh0NhUfH4/ExERoaGhUhJSZmRlc\\\\\\n\",\n              \"XV1hZmYGIyOjKoG3b/9+qGt8eumS/1JT08CcOXPA4XCwfPly9OrVC1JSUnBwcEDu+1z8OvxnDJ44\\\\\\n\",\n              \"C2079wDnX2uClZaW4FrkaQRvXomcrAy0/umnT9ZVUlKCMWPG4Pnz54iOjq6YdV4oFOLSpUvYt28f\\\\\\n\",\n              \"jhw5AktLS3h6eoKvqYmjfwdi8q+LRfouD+NuIzsjTeT1xxjm31iAicHQ0BDRMeLfbH704B6szJs2\\\\\\n\",\n              \"QIu+HGlpabRt2xZt27bFihUrcP/+ffTp0xd/rV4v1n5Gjx2PLh3aYvny5RAKhUhPTxfpHlJaWhqE\\\\\\n\",\n              \"QmHFWRCXy8Xjx4+RmJgIExMTNG/eHPb29tiyZQuOHTuGli1bolv3Hhg0/NMjGT+SkpKC5yhvbNy0\\\\\\n\",\n              \"Gfb29khISKgUUE+ePEF8fDyICObm5hUh5e7uDjMzM5iamor1+ICioiIKCvLFOn4AkJmRAVdXV6xf\\\\\\n\",\n              \"v77SvTMpKSmsXr0Kzs4dsHrtOgSuXAD7dp0gr6iMwvxc3Iq5gGbNzLFutR/09PTQs2dPcLncGkfI\\\\\\n\",\n              \"5uXlYeDAgZCVlcXZs2chLy+PO3fuYN++fdi/fz80NDTg6emJW7duVZzJJScno9VPrdHK0RkOHWpf\\\\\\n\",\n              \"hSEnKxOLZo3Hgj9+/+5XCGAaBrsHJob8/Hw0adIEZy5eg76BoUhlct+/h72VMR49fFjl0oukU1FR\\\\\\n\",\n              \"wb0nz6Ei5vyDRro8lJWWoqSkpNZ7SP99T0lJqdL9vylTpkBNTQ2LF///3/Y3bNiAzZs3IyIiAqam\\\\\\n\",\n              \"priTmFblzKc2WZkZcLI1gbS0NJo2bVrpbOrji8fj1ct9yPnz5+NVWjbm/7lK5DJFhYVwtjdH7I3r\\\\\\n\",\n              \"MDau/XL2o0ePEBMTU7HeVrt27So9axUREYFhw4bh3LlzVUZhfpxdo3nz5vj1119x8ODBisUrPT09\\\\\\n\",\n              
\"4enpiebNm1dbb0xMDPr264/xM+ajZ/8h4FQzmOjh3X+waKY3BnsMwrKlS0X+/kzDk6R7YCzAxDR1\\\\\\n\",\n              \"6lQUlhIW/yXaJbBN61YhYPN6GBsZYenSpd/VsgxycnJ4lpIp9jBsWwtjhB4/Dltb2zoFwb1799Ct\\\\\\n\",\n              \"WzckJydXej5t1qxZiImJQdKz57h2/7lY+yQimGorori4WKzg+xwvXryAbYsWiLoVL/LkxUcP7MX5\\\\\\n\",\n              \"sKM4c/p0vbQhODgYs2fPRkxMTKWzqM6dO8PY2Bh5eXlISEiAu7s7PD094eDgINIMG3fv3sWkyVPw\\\\\\n\",\n              \"+PFj9HYfjmbWdmjUiIPUlFc4fXQfsjPTseCP3+Hl5VUv34OpP5IUYN//2vT1bO7cuQgPC8W+oJ2f\\\\\\n\",\n              \"3Pbs6ZMI2LweFy9cwOTJk+Ht7Y3OnTvj2rVrX6ClDY/H4+FNymuxyhQXFyMrMxOGhoZ1PouxtraG\\\\\\n\",\n              \"gYEBwsLCKr3v5+cHLS0t5OXmir3P4qIicDicBl+TiujDs1wA4L9OtIUSc9+/w7oVi+Hi7Fxv7Rg8\\\\\\n\",\n              \"eDDmzJmDrl274tmzZ1i+/MOCkq9fvwafz8f8+fORkpKCTZs2wcnJSaTwAgAbGxtER13ExQvnocwp\\\\\\n\",\n              \"x4XQ/TgVvB0vH9/G0sUL8CwpkYUXU2fsHpiYtLW1cTY8HG7duuH+3TvwnjgFhsYmlbZ5k/Iau7Zv\\\\\\n\",\n              \"wYG/dyP0+HFYWFjAwsICgwYNwq5duzBo0CDY2dlhyZIlEj2Jab9+/RC8bw9++32RyGVOHj+GZhYW\\\\\\n\",\n              \"9bbsibe3NwICAtC7d++K96SlpREcHAxtbQHuxd2Gta1og0wA4GrMRdjU8czwU16+fIlJkyYhPj4e\\\\\\n\",\n              \"Hdq3x57ALVDT4MJrQs1Tjr3LycbEke6wt7PD+vXr8f79eyxcuPCTz/p9SklJCQwNDSEjI4OmTZtC\\\\\\n\",\n              \"RkYG3t7eWLFiRb2sqWZpaYm1a9bUeT8MUx12BvYZzM3Ncf3aNXBVFfGza3sM6dcDC3+bjYXz5mD0\\\\\\n\",\n              \"kP7o7NgSpQXvce3qVbRp06aiHIfDwbhx45CQkIBOnTrBzc0NgwcPRnx8/Ff8Np9v4sSJ2LNrh8gP\\\\\\n\",\n              \"oRIRtm7+H96kpKBly5YICgqq8wOs7u7uuHr1Kl68eFHpfXl5eUyZMgWB/v8Ta3/7dgXgl4kT69Sm\\\\\\n\",\n              \"mpSXl2Pjxo1o2bIlbGxsYGFhgbS0NJibmcF/3QoM6uGCs2HHK6b3AoCMtLfwX+eH3p3awKHtTwgJ\\\\\\n\",\n              \"OYa4uDjcvXsXTk5On/X/jlAoRHR0NMaPHw8dHR34+fnB2dkZjRo1grW1NdasWfPDLQjKSCYWYJ9J\\\\\\n\",\n              \"S0sLK1euxMuXLzHRxxumRvowaaKDYZ6DkZycjM2bN1fMBPFfjRs3xrRp0/D06VPY2NjAyckJXl5e\\\\\\n\",\n              \"SE5O/rJfoo4sLS1haGCIyT5jIcqt1M0b1qGwsABJSUlYtmwZ9u7dC0NDQyxduhQZGRmf1QYFBQV4\\\\\\n\",\n              
\"enpix44dVT6bPHkSLkaeQfzjhyLt6+b1K4j75yY8PDw+qy21uX//Ptq1a4cDBw4gJCQEkZGRKCj4\\\\\\n\",\n              \"cCzS0tKgIxCgS8cO2LttAxybG6Knsz26OlijW3s75KS9xInQ41i3di1kZGSgqalZMZ2Tk5MTtm/f\\\\\\n\",\n              \"Xu3xv3nzJkaNHg0jYxPw+Hzo6enDxtYWAoEAEydOhJGREW7evIkxY8YgJCQEFy5cQJMmTTBq1CgI\\\\\\n\",\n              \"hcJ6PwYMU+/oO2Vvb/+1myCyrKws+u2330hDQ4MmTZpEb968+dpN+qTy8nKaO3cuGRgYkG2LFuQx\\\\\\n\",\n              \"ZBg9ffGWsvLLqrxepb+nWb/OIwNDQ0pOTq60n7t379KYMWNITU2Nxo8fT48ePRK7LXfv3iVdXV0q\\\\\\n\",\n              \"LS2t8llQUBAJdHQpPOYfSkwrrPF19Ew08TW16PTp0599TKpTWFhI8+bNIx6PR1u2bKHHjx9T06ZN\\\\\\n\",\n              \"ydnZmWRlZYnH49Hu3buprKysokxKSgrdvXuXHj9+TLm5ubXu/8GDB2Rra0v9+vWjjIwMIiJ6+fIl\\\\\\n\",\n              \"OTg6kX4TA5ozfwlFXImj6/ef07lr9+i3xSvI0MiYWra0p/j4ePrrr7/IwMCAHj9+TEREBQUF1L59\\\\\\n\",\n              \"e5oyZQoJhcJ6PRaMZJCkvpMF2Dfk7du3NG3aNNLQ0CBfX1/KzMz82k2qVm5uLvXt25fat29PaWlp\\\\\\n\",\n              \"lJ+fT15eXqSqqkqew0ZS8JHjdOb8JToSepomTJpKXC6XevX6udZgTk1NpQULFpCmpib16NGDIiIi\\\\\\n\",\n              \"xOpAHRwcKDQ0tNrPdu7cSapqajRi7ASKvHa/UnCFRd2kQZ4jSV1Do8byn+vixYtkZmZGAwYMoNev\\\\\\n\",\n              \"X1NMTAzxeDwSCAQkIyNDs2fPpsLCwjrXU1RURDNnziRdXV3au3cv6ejq0uz5iyn+TV61YZ2Qmk+L\\\\\\n\",\n              \"/lpHKqqqZGpqSq9evaq0v+zsbLKxsaGlS5fWuW2M5JGkvpMF2DfoxYsXNG7cOOJyubRo0SJ6//79\\\\\\n\",\n              \"125SheTkZLK1taXRo0dTcXFxpc/S0tJo+fLl1KWLK7Vu3YY6duxEc+bMoaSkJJH3X1BQQNu2bSNL\\\\\\n\",\n              \"S0uytramwMBAKioq+mS5nTt3Us+ePWv8PDExkVxdXUleXoGampmTnf1PZGzSlLS1BWRra0uzZ88W\\\\\\n\",\n              \"uY2fkpWVRV5eXqSnp0chISFERBQYGEjy8vIkLS1NxsbG9OzZs3qrj+jDLxVz5swhBUVF8v39z1rP\\\\\\n\",\n              \"Nj++lq7eREZGxlRSUlJlfykpKWRkZETbtm2r13Yy3z5J6jtZgH3DEhISaOjQoaSpqUmrVq2igoKC\\\\\\n\",\n              \"r9qeK1eukEAgoFWrVjX45SWhUEhnzpwhNzc30tbWpsWLF1NaWlqN2+fn55OGhkaVS5T/NX36dLKz\\\\\\n\",\n              \"s6MLFy7Q/fv3qaSkhOLj44nH41F2dnad2xwcHEwCgYB++eUXevfuHeXn51P37t1JSkqKFBQUaNKk\\\\\\n\",\n              
\"SZUuF9ZFcXExhYaG0uDBg0lFRYVat25NpuYW9PRtgUgBlphWSG0c29GhQ4eq3X98fDxpa2tXhDDz\\\\\\n\",\n              \"Y5CkvpMFmAS4d+8e9e3bl3R1dcnf37/Kmc+XsGfPHuLxeHTixIkvXvf9+/dp7NixpKamRuPGjaMH\\\\\\n\",\n              \"Dx5Uu92kSZNowYIFte6rvLyc3N3dyd3dncrLyyveHzVq1CfL1iY5OZl69uxJVlZWdOXKFSotLaWt\\\\\\n\",\n              \"W7eSoqIiycvLk6qqao1BIY7y8nKKjo6m8ePHE5fLJScnJ9q0aROlpaVRn759acnKDSKHV2JaIa3b\\\\\\n\",\n              \"upucXTrWWF9sbCzx+XyKjo6uc9sZySBJfScLMAly48YN6tq1KxkZGVW58d9QPg7WMDIyonv37jV4\\\\\\n\",\n              \"fbV5+/YtLVq0iLS0tMjNzY3Cw8MrnQnevXuX9PT0qh3M8W+FhYXUrl27SpcNExMTicvlin3fsays\\\\\\n\",\n              \"jNatW0dcLpeWLFlCRUVFdOzYMTI3Nyd1dXXS0tKq87ETCoUUFxdHc+bMIX19fbKysqJly5ZVuQzZ\\\\\\n\",\n              \"uHFj+ifhjVgB9ujVO5KVla31Mm1ERARpampSXFzcZ38HRnJIUt/JAkwCXbx4kZycnMjCwoIOHTpU\\\\\\n\",\n              \"6UyiPv13sMa3orCwkAIDA6l58+ZkZWVF27dvrxgM0bZtW5EGY2RmZpK5uTlt3Lix4r2xY8fSvHnz\\\\\\n\",\n              \"RG5HXFwctW7dmjp06ECPHz+m6OhocnBwoGbNmpGRkREJBAJydXX97ME4z549o2XLlpGVlRXp6+uT\\\\\\n\",\n              \"r69vjSFSVFREjRo1Euvy4ccXl8ujt2/f1tqW4OBg0tXVFet+JiOZJKnvZAEmoYRCIYWFhZGdnR21\\\\\\n\",\n              \"bNmSwsLCRLovdePGDRo/3ofcunUj165dadSoURQREVElBGsbrPGtEAqFFBERQd27dyctLS1asGAB\\\\\\n\",\n              \"rV27lnr16iVS+aSkJBIIBHT8+HEi+hAYGhoalJ6eXmu5goICmjt3LvH5fNq2bRvFxcVRr169yMDA\\\\\\n\",\n              \"gJYtW0YCgYDU1NRozpw5Yp8lp6Wl0aZNm8jR0ZG4XC75+PhQdHT0J39JSUhIIGlpaXr06p1Y4fX0\\\\\\n\",\n              \"bQHJy8t/crg+EdGGDRvI1NT0k2HHSDZJ6jtZgEm48vJyOnToEFlYWFC7du0oKiqq2u2io6PJ3t6e\\\\\\n\",\n              \"DI2MaPHS5XT0+EkKORFGa9ZtIGsbGzI1NaWDBz/co/mSgzXqy8OHD8nb25tUVVVJTk6OIiIiRCoX\\\\\\n\",\n              \"GxtLPB6Prl+/TkREPj4+5OvrW+P2kZGR1LRpU3J3d6fY2FgaNWoUaWpq0tq1a+n06dOkrKxMysrK\\\\\\n\",\n              \"FBwcLHLbc3Nzae/evdSjRw9SUVGhIUOG0IkTJ2r8xaGoqIhOnjxJI0eOJDMzM5KVlSUApKyiSjv2\\\\\\n\",\n              \"h4gVYAdPnicTk6Yi/3eeP38+2dvbf1MjY5n6JUl9Jwuw70RZWRnt3r2bjIyMqGvXrnTjxo2Kz44d\\\\\\n\",\n              
\"CyE+n0/7Dhym/OJyKiylSq+CEiFFnI8ifX19GjFiBPH5/K8yWKM+pKenU5s2bUhJSYlcXV3p9OnT\\\\\\n\",\n              \"n+ycQ0NDSSAQUGJiIr148YI0NDSqnGVkZGTQ6NGjSV9fn/bu3UszZ84kDQ0NmjdvHuXk5FQMkxcI\\\\\\n\",\n              \"BHTnzp1PtrO4uJhOnDhBQ4YMIRUVFerevTvt3bu32jOh5ORkWrVqFXXp0oV4PB5JSUmRjIwMNWnS\\\\\\n\",\n              \"hAYNGkR///035efn044dO6izWw+xAqzvwMG0evVqkY+vUCgkb29v6tKli0iPNzCSR5L6ThZg35ni\\\\\\n\",\n              \"4mLy9/cnHR0d6tu3Lx06dIh4PB5dvnazSnD99/UkMZm4XB5t2LDha3+NOomLiyNdXV3asWMH2dra\\\\\\n\",\n              \"koWFBQUEBNT6GMLmzZvJzMyMMjIyaNKkSTRjxgy6cOECTZ4yhdp3cCY1dXVydHSk2bNnV1zaS0lJ\\\\\\n\",\n              \"IaFQSL6+vtS4cWNydHSsmA2jOh9HEPr4+BCPxyNHR8eKEYQfFRYWUmRkJE2YMIGsrKxITk6uYgh+\\\\\\n\",\n              \"y5YtacaMGRQbG1vtJcX8/Hzi8fl04ESkSOEVGnmV1NTVKSsrS6zjW1ZWRv369SMPD48Gu//KfD2S\\\\\\n\",\n              \"1HeyAPtOFRQU0KpVq0hNTZ3+8lv9yfD6+DoScoJatWr1tZtfZ23btqUTJ06QUCikyMhI6tWrF/H5\\\\\\n\",\n              \"fPr9999rnBHE19eXnJycaM3ataSkrEzGTc1oiu8CWrJmC/3x13rq2rMfySsoUJ++/ejNmzdUXFxM\\\\\\n\",\n              \"vXv3JllZWZowYUKNox/j4uLI19eXmjRpQpaWlrR06VJKSkoioVBIz58/J39/f+rZsydpaWmRtLQ0\\\\\\n\",\n              \"SUtLk5aWFnXv3p02bdpEKSkpIn/v8PBw4vM16XDYxVrD6+T566Qt0KGDBw9+1vEtLCykDh060KRJ\\\\\\n\",\n              \"k6qc4QqFQnr//j1lZmZ+kZGyTP2SpL6TBdh3LDU1ldTU1OhNerbIAZZXVEYGhoaVLkFKosDAQPr5\\\\\\n\",\n              \"558rvff48WPy8fEhNTU1GjVqVJURfeXl5WRjY0sCXX3acTCM4l68p7svcyu9ou8+J69fZpB+kyYV\\\\\\n\",\n              \"958CAwOr1P9xBGHz5s1JX1+f5syZQ9evX6cLFy7Q9OnTqUWLFtS4cWOSkZEhDodDzZo1o7Fjx9KZ\\\\\\n\",\n              \"M2coPz+/Tt/95MmTxOXxaMjwMXQi8lql4Dpz6TaN8PIhDQ0u7d+/v071fJxy6s8//ySiD5c6586d\\\\\\n\",\n              \"S5qamqSoqEiqqqrUuHFjGjx4CF26dEli7qf+6CSp72QrMn/Htm3bhvMXorAzaK9Y5ZYsWoCSokKs\\\\\\n\",\n              \"XOnXQC1rePn5+dDV1cXYceNw79495OXlQVlZGS7OzujXrx+OHDmCTZs2wcLCAjNmzEC3bt0QEBAA\\\\\\n\",\n              \"v1WrsfPIWWhw+bXuf1/gFqxfsRCnw07B+f8WmMzIyMDBgwexb98+PHr0CG5ubmjSpAni4+Nx/fp1\\\\\\n\",\n              
\"pKamQkpKCoqKirCxsUH37t3x888/w8rKSuSFIkWVmpqK7du3Y+vWAEhJS0NVVQ25ue9RXFSEcePG\\\\\\n\",\n              \"wtvbG3p6enWu582bN3B0dIRtixa4FB2NQYOHYpSXN0zNmwEAcrKzcWDfHuzcvhWGBgY4fPhQva0F\\\\\\n\",\n              \"xzQMSeo7WYB9x5YtW4bsd7lYsnS5WOV2Be7Agf17MXHiRCgqKkJRUREKCgqV/lRUVETjxo3rveOt\\\\\\n\",\n              \"D9nZ2fhl0mSEHg9Bjz4D0LVHbygrq+L9+xxEhIUiPCwU/fv3x+pVqxAWFoY1a9YgLy8PKW/eIOhY\\\\\\n\",\n              \"JEybWYpUz+wJI9CxXVvo6+th9+7diImJgaGhIUpKSvD69WuUlZVBKBRCR0cHDg4O6NOnDzp27AiB\\\\\\n\",\n              \"QNDAR+D/Kysrw7Nnz/D+/XsoKyvDyMgIHA6nXuvwmTABkZHncfxMJLS0tKvdpry8HL/Nno47t2MR\\\\\\n\",\n              \"dfEiW2/sGyZJfSdbkfk7xuFwUFpaKna5kpIS3IiNxS0vL8jKykJGRgbS0tIgIpSXl6O0tBRFRUUo\\\\\\n\",\n              \"KSmBvLx8pVCrLujEfe/j3xUUFMQOyPT0dDg7u6C1UwdcuZsIFZXKv+27duuF3xb9hWUL58LVtSvO\\\\\\n\",\n              \"n4+ElpYWBg4cCCMTc5HDCwCGjPbBxOH9UVZWirKyMnA4HCQmJsLCwgK//PILunfvjjZt2nzVzrpR\\\\\\n\",\n              \"o0YwNTVtsP2Hh4cjPPwswi9cBpfHq3E7GRkZ/LV6PSb7eGHWrFnw9/dvsDYxPw4WYN8xU1NThJ44\\\\\\n\",\n              \"KXa527diMXvWLAwYMABv3rxBSkpKxevfP7958waysrLgcrng8/nQ0NCAmpoaVFRUoKysDHl5eTRu\\\\\\n\",\n              \"3BiysrIoLi5Gfn4+0tPT8fz5cxQUFCA/Px/5+fkVf//ve4WFhZCTkxM5CBUUFLB/fzC6dP8Zvy5Y\\\\\\n\",\n              \"VuP3U1PXwIp1W7Dg1+mwsbFFTk42CFIY6iXeSsx2PzlAWUUVRoZN4OnpCRcXF1haWkJGRkbsYy6p\\\\\\n\",\n              \"1q1bj5m+v9UaXh9JSUnhj8XL4WjfHMuXL4eamtoXaCHzPWOXEL9jpaWlMDAwwMnTEbC0shKpTHZ2\\\\\\n\",\n              \"NizNjPHkyRNoamrWuq1QKERWVla14fbvn1NTU6GiogKBQAAdHZ2K179/FggEEAgEkJWVrbT/oqKi\\\\\\n\",\n              \"Twbdxz/v3buHS5ev4NyVOJHO3MrKytDW2hjv3+VAUUkZm3YfgbVdK5GO00fenj/Dupkp7OzswOFw\\\\\\n\",\n              \"ICsrCw6HU+Ul7vscDuebvDz7b0lJSfipdWvEPX4GeXl5kct5jxqGdo5tMH369AZsHfO5JKnvZGdg\\\\\\n\",\n              \"3zEOh4OxY8dhld9f2LErCFJSUp8ss2H9WnC5PJGWlJeWlgaPxwOPx4ONjU2N2wmFQmRmZlYJt4cP\\\\\\n\",\n              \"H+LcuXMVQff27VuoqqpWCbf//qytrV3tfZx+/ftjzPjJInf8jRo1gpfPZJw9dQyvXr2GUFguUrl/\\\\\\n\",\n              
\"KystQ2pqKuLi4lBaWoqSkhKUlpZWeYn7fmlpKWRkZOochA35/tmzZ9Gug4tY4QUAbj16ISIslAUY\\\\\\n\",\n              \"U2cswL5zM2fOQLt27bB86RLMnfd7rSG2/++92BW4A/369YWNjQ2WLVsGLy8vkYKvNtLS0uDz+eDz\\\\\\n\",\n              \"+bC1ta1xO6FQiPT09Cpncvfu3UN4eHhF8L19+xbq6uqVwk1bWxunTp7EkjVbxWrbwCEjsGbFEsjJ\\\\\\n\",\n              \"NcaTh/dga99G5LLl5eVIefkcu3dshaWl6PfOREFEKCsrq3MQfur9/Pz8z97P27dv4dSho9jfTVlF\\\\\\n\",\n              \"Bbm5ufV6vJgfEwuw75yqqirCw8PRvXt33L0bh6nTZqKtg0OlUHpw/z78N23A2fDTOHs2HM2bN4e3\\\\\\n\",\n              \"tzfGjRuHv//+GwEBAQ06EOAjaWlpaGlpQUtLCy1atKhxu/LycqSnp1e6TPn06VPIyslBSUlZrDp5\\\\\\n\",\n              \"fE2QUAgpEPZs24RBw0QP7Evnw6Grq1Pv4QV8uF/08UznW7Vv3z4cPHJM7HLv3uVARUWlAVrE/GhY\\\\\\n\",\n              \"gP0AdHR0cPnyZQQEbIO310g0lpeHpVVzyEjLICnxKV68SMa4cd64efNmxX0vW1tbXL16FRs2bICD\\\\\\n\",\n              \"gwNmzpyJWbNTchnJAAAbcUlEQVRmfRMd6vv375GUlIS4uDhcu3YN9+7dw7Nnz1BSUiL2voRCIYRC\\\\\\n\",\n              \"gomJCZ49e44bV6LRxslZpHIB//NDj66dP+crfBecnJwwefJk5OfnizXS8vTJULh2cmmwdjE/DjaI\\\\\\n\",\n              \"4wcjFApx+fJlJCcnQygUQiAQwMXFpdZgev78OSZMmICUlBRs27YNrVu3bvB2ElHFfbJr164hNjYW\\\\\\n\",\n              \"jx49wuvXr1FcXAwZGRmUlZWBx+PByMgINjY2OHLkKPaFhMPcQrQBKwDwz83rGOnRG/6bN0NRURGj\\\\\\n\",\n              \"Ro/BriNnYWLWrMYyQqEQfgt9ce/WVWRnZaJ79+5YuXLlD/mAbu/efdDJrQeGj/ISafs3b1LQoXUL\\\\\\n\",\n              \"PH/+nJ2FfaMkqe/8toc5MZ+NiHDlyhUMGzYMurq6UFZWhkAgQL/+/VFYWAhPT0+MGDECrq6unzyr\\\\\\n\",\n              \"MjQ0RFhYGHx9fdG7d29Mnz4deXl59dLO8vJyPH36FIcPH8a0adPg4uICPT09yMnJwcDAAG5ubli2\\\\\\n\",\n              \"bBkePHgAU1NTzJw5EyEhIXj48CGKi4vx9u1bnD9/Hq1bt0bjxnLYGbBJrPq3bV4PeXl5pKamwt/f\\\\\\n\",\n              \"H6ZNTTBucA8E7wpAXu77Ktvfv3MLM8Z5IulxHKIuXsCDBw8gLS2N5s2b49SpU/VyTCTJtGlTsdZv\\\\\\n\",\n              \"Od6+Tf3ktkKhEH/MnY1hw4ax8GLqBTsD+w69fPkSgwYNQkZmJsZ5T0Dffv2hpq6OvLw8hJ8Ow9Yt\\\\\\n\",\n              \"m1BcVITDhw+jefPmYu07IyMDM2fORHR0NPz9/dGtWzeRyhUVFeHJkyeIiYnB9evX8eDBAyQnJyM7\\\\\\n\",\n              
\"OxtSUlIgIqirq8PAwABWVlZo27Yt7O3tYW5uXuPzQomJifD398euXbvg6OgIDw8PTJz4C05Hx0JH\\\\\\n\",\n              \"V/+TbXqelIg+XZ0QuGMHfvnlF2RlZWHEiBFwdnZGyPHjiDwXiQ5dukGdy0dJcTFir0SjtKQIEyf4\\\\\\n\",\n              \"YMqUKZVG350/fx5jx45Fu3btsHbtWnC5XNEO6Hdg8eLFCD5wEPuPhEJPv0m125SWlmLW1Il49jQB\\\\\\n\",\n              \"585FiD1ykflyJKrv/ArzL34RkjQhZX1KTk4mfX19WrZiZbVrf31c/ytw1x7S1NQUae2q6oSHh5OR\\\\\\n\",\n              \"kRENHTq00nIgOTk5dP78eVqwYAH16tWLzMzMSElJiaSkpEhKSork5OSoSZMm1KlTJ5oxYwYdOXKE\\\\\\n\",\n              \"EhMTRZ61vLy8nMLCwqhHjx7E4/Fo9uzZlJSURCkpKeTm5kYcWVnSb2JI1+8l0bP0whpfl24/piaG\\\\\\n\",\n              \"RrRu3Xpydnam4cOH06tXr2jp0qWkr69Pbdu2pXXr1tHGjRvJyMiIvLy8SE5OjgoLC2tsW15eHk2d\\\\\\n\",\n              \"OpUEAgEdPnz4s46rJBIKhf+38oEaDR81hiJjrlPa+2JKzy2h+09f0LwFS0hfvwn17t1HpJWfma9L\\\\\\n\",\n              \"kvpOFmDfEaFQSC1bthR5+ZQ9+w6Qvr5+retk1VTPq1evKDAwkFq0aEEcDofU1dWJw+FUBJWqqipZ\\\\\\n\",\n              \"WVnRgAED6K+//qIrV65QTk7OZ3+3rKwsWrNmDZmYmJCdnR0FBgZSVlYWHTx4kFxdXalRo0akpaVF\\\\\\n\",\n              \"wcHBtGjxYuLxNWnBstUUl5haKbjuJKTQ/CUriK+pRSoqKmRnZ0djx46ttK5VWVkZHT9+nLp160Z8\\\\\\n\",\n              \"Pr9iEUtra2uRZum/fPkymZub08CBAyk1NfWzv7OkSU1NpT///JMMDA0rZtlXVlYmLy8vunXr1tdu\\\\\\n\",\n              \"HiMiSeo7WYB9R86fP0+WVlZUUCIUefmUrm7daPfu3dXur6ysjOLi4mjt2rU0ZMgQatGiBXG5XJKW\\\\\\n\",\n              \"liYAJCsrSzo6OmRjY0NcLpesrKzowoUL9boG1J07d2jcuHGkpqZGnp6edPnyZYqJiaHx48eThoYG\\\\\\n\",\n              \"tWrVijQ0NGjGjBkV63FdvHiRlJSUSF1DgxQUFMmlsyv17j+IOnbuSqpqauQ5dBidPXuWDAwMSFVV\\\\\\n\",\n              \"tdIZ5H8lJCSQvr4+qaiokL6+Po0bN06k71dYWEi//voraWpq0t69e3+4pURKS0trPVtlvl2S1Hey\\\\\\n\",\n              \"e2DfkYGDBqF9h44YP0H0Of1OnTyBFcv+xJIlixEdHY3bt2/j6dOnSE1NRX5+PgBAUVERAoEApqam\\\\\\n\",\n              \"aNmyJZydndG6detKN+JLS0uxZs0arFy5EnPnzsXUqVPRqNHnPaVRWlqKY8eOYePGjUhKSoKPjw+6\\\\\\n\",\n              \"du2KM2fOICgoCBwOByNGjEBxcTH8/f2xfft2/PzzzyAiHD16FEOHDoWhoSF27NgBY2NjxMbGIjc3\\\\\\n\",\n              
\"FyoqKmjbti0AwNXVFa6urpCWlsaVK1cQERGBxo0bV9uejh07wtfXF0ePHsWRI0egrKyM8ePHw8vL\\\\\\n\",\n              \"65PTbd28eRNjxoyBgYEBtmzZAl1d3c86JgzzpUhU3/mVA7TBSNJvEfVFRUWFXqSkiXz29XEBSw6H\\\\\\n\",\n              \"QzIyMsTn86lly5Y0bNgw+t///kf3798Xe8n4hIQE6tSpE9nb29Pt27fFKpuSkkILFy4kgUBAzs7O\\\\\\n\",\n              \"tGvXLtqyZQu1b9+e+Hw+TZ48mWJjYyknJ4cGDhxILVu2pKSkJCIiioqKIgcHB9LW1qYWLVrU2O6U\\\\\\n\",\n              \"lBSytLSk+fPnk1AopPLycnJ3d6chQ4bUWMbFxYXOnz9PCQkJpKenRzdv3iQvL6+Ks8JPLdZYXFxM\\\\\\n\",\n              \"CxcuJB6PR9u2bfvhzsYYySJJfScLsO+EUCgkADUO3KjtpaOjQy9evKjXtgQGBhKfz6c5c+bUusKw\\\\\\n\",\n              \"UCikyMhIcnFxIQUFBXJzc6PFixeTh4cHqaioUP/+/SkkJISKi4uJiCguLo5MTU3Jx8eHCgsLKS4u\\\\\\n\",\n              \"jnr06EGGhoa0du1aUldXp/j4+GrrevnyJZmamtKSJUsqvV9QUEAODg40b968ast9DDChUEh8Pr/i\\\\\\n\",\n              \"WGVlZdG6devIzMyMrK2tyd/fn96/f1/jd42LiyN7e3vq0qULPXv2rLZDyDBfjST1nSzAviMKCgqU\\\\\\n\",\n              \"lvVerPAqKBGSmpoaZWZm1nt7UlNTycPDg0xMTOjcuXOVPsvPz6fly5eTjo4OKSgoUMuW9tStew9y\\\\\\n\",\n              \"dGpHKioqZGNjQ0FBQZXOVnbu3Ek8Ho+CgoIoKSmJhg0bRlpaWrR+/XoqKiqivn370sKFC6tty7Nn\\\\\\n\",\n              \"z8jIyIhWrVpV7edpaWlkYmJCO3bsqPLZxwAjIurTpw8FBwdX+lwoFNK5c+eof//+pK6uThMnTqR7\\\\\\n\",\n              \"9+5VW09paSmtWLGCuFwubdiwQewzXIZpaJLUd7IA+444OzvTvgOHxQqw6MvXycjIqEE70hMnTpC+\\\\\\n\",\n              \"vj6NHj2abt68STNnziRlZWVSUlKmyVOn08MniZXalJ1bSIG79lBza2saMWIEvXv3jsaOHUvm5uYU\\\\\\n\",\n              \"FRVFkydPJg0NDVq4cGHFGc+JEyeoadOm1Q4ciI+PpyZNmtDGjRtrbefjx49JU1OTIiIiKr3/7wBb\\\\\\n\",\n              \"sWIFTZkypcZ9vHr1ihYsWEACgYA6dOhAwcHBFWeP/63L0dGR2rVrR0+ePPnkMWSYL0WS+k4WYN+R\\\\\\n\",\n              \"4OBgcunYSawAGzZ8JK1Y4deg7SovL6fDhw+ToaEhSUlJkampKSkpK9O5C5dqbVtGTh65urqRtrY2\\\\\\n\",\n              \"9evXj3x9fUlDQ4OmTp1Kb9++rdh/fn4+GRoa0tmzZ6vU/fDhQ9LV1aWAgACR2hoVFUV8Pr/SGdS/\\\\\\n\",\n              \"A+zSpUvUqlWrT+6npKSEDh06RB07diRtbW2aN28eJScnV9qmrKyM1q9fT1wul/z8/CpGUTLM1yRJ\\\\\\n\",\n              
\"fScLsO9IcXEx6ejoUOipMyKF19Ubt0ldXZ3S09Mr7aegoIASExPp8ePHlJGR8dnt+fezW6amptS+\\\\\\n\",\n              \"fXtSUlIiJSUlOngkRKQ2ZuTkkYGh4YeHZIcPr/be0dy5c8nDw6PK+3FxcSQQCCgoKEisdu/du5cM\\\\\\n\",\n              \"DAwoJSWFiCoHWEFBASkoKFBeXp7I+3v48CFNmTKFNDQ0qHfv3nTmzJlKZ7yJiYnUsWNH+umnn2q8\\\\\\n\",\n              \"9MgwX4ok9Z0swL4zly5dIj6fT6fPRtYaDNdi/yEdHR06dOj/zxgRFxdH3t7jSU1NjQwMDMikaVNS\\\\\\n\",\n              \"UVGhTp060eHDh6mkpESkNnzYjzcpKytT8+bNSSAQkJWVFfn5+dHBgwfJyqq5WM+qbdjkT507d6m2\\\\\\n\",\n              \"rocPHxKXy6XXr19Xev/mzZukpaVFBw4c+KzjuHjxYrK3t6e8vLxKAUZE1KZNG7p48aLY+8zLy6Nt\\\\\\n\",\n              \"27ZRixYtyMTEhFatWlXxC4JQKKStW7cSj8ejxYsX13is09PT6a+/VlDHjh3Jzs6OHJ2caPLkKfTw\\\\\\n\",\n              \"4cPP+p4M81+S1HeyAPsOXbhwgfh8PrkPHkKRF2MqhcX1m3fIa6w3aWho0IEDB4nowyW+6dNnkI6O\\\\\\n\",\n              \"Dv2xcDElvUip2P5dfjHt3rufHBydqEWLFlWC4qOSkhI6cOAAtW3bltTU1KhJkybE4/Fo2rRpdPv2\\\\\\n\",\n              \"7YrBGO4eHrTuf5vEusyZnp1L6urqVeoWCoXk4uJC69evr/T+1atXSVNTk44dO/bZx1AoFNKoUaOo\\\\\\n\",\n              \"d+/e5OzsXCnApk+fTsuWLavTvq9evUrDhw8nNTU1GjlyJF2/fp2EQiG9ePGCunfvTra2tpVmrygo\\\\\\n\",\n              \"KCBv7/GkqqpKw0eMouMnT1PM1Vg6G3mRfv1tPmlpaVGnTp3Y6EamziSp72QB9p36cPluLZmampKW\\\\\\n\",\n              \"lhaZmZuTrq4u6enp0eLFS+jNmzdE9KEznTjxF3J0akcpaVm1jlZctGQpNW3atNIlx5SUFPrjjz9I\\\\\\n\",\n              \"Q0ODeDweKSgo0MCBA+nkyZPVnkXY2trS1Ru3xR7q7+joRNHR0ZX2FRQURHZ2dpXuHUVHRxOfz6dT\\\\\\n\",\n              \"p07V+RgWFxdTp06dSE9Pr1KAHTp0iHr16lXn/RN9OKPy8/MjIyMjsre3p+3bt1NeXh4FBQURn8+n\\\\\\n\",\n              \"uXPnUmZmJrVv354GunvQq9SMao/Pu/xiWvqXH+no6LBBIUydSFLfyQLsO1deXk4vX76kBw8eUHJy\\\\\\n\",\n              \"cpWBAqGhoWTerBmlZuSIFCTTZswiD4/BdOnSJeratSvJyspS48aNyc7OjgICAig7O7vW9jRr1oxu\\\\\\n\",\n              \"3bkvdoC5dOxUaXRgVlYWaWtr0/Xr1yveO3fuHPH5/CqjCOsiOzubFBQUaNKkSUT0IfAPHTpEampq\\\\\\n\",\n              \"1MHZmdq1b0/uHh4UGhpapym0Pk5S3KtXL9LQ0KBp06bRpUuXqF+/fqSpqUnuHkNEesZv85ZtZGxs\\\\\\n\",\n              
\"XOuzdwxTG0nqO1mA/eBcu3alwF17RA6S1IwckpeXJw6HQ1wul3x9fenp06ci19euXTs6ERYuVngV\\\\\\n\",\n              \"lAjJ1Mys0swePj4+5OPjU/HzqVOniM/nU1RUVL0eHyKitm3bEpfLpYULF5KlpSVZWFqS3+q1dPps\\\\\\n\",\n              \"JIWfu0Cbt2yj1q3bkIGBAe3Zs7fO9T179ozmzp1Lmpqa5OjoSCqqqpT5Ll/k49Wtew8KDAysh2/O\\\\\\n\",\n              \"/Igkqe9kAfYDS0hIID6fT9m5hWIFyvARI8nLy+uzpkRavXoNDR4yVKz6omKukbGxccXIvevXr5O2\\\\\\n\",\n              \"tjZlZWUREVFISAjx+Xy6evVqvR6fj1xcXGjMmDGkoqJCx0+ernEASlTMNTI0MiI/v5X1Um9RURH9\\\\\\n\",\n              \"3Ls3TZw0RazjdfT4SZGG+jNMdSSp72QrMv/Abt26hfYdnGucxLYmPXr1RnpGBqSkpMSuc/ToUThz\\\\\\n\",\n              \"+hTS0tJELrNuzUpwuVykpKSgrKwMPj4+8PPzg7q6Og4ePIjx48fj9OnTFRP11rf379/j+PFQnI+6\\\\\\n\",\n              \"jK5u3Wr83q3btEHkxRhs2rQRISEhda5XTk4Ojx89wshRY8Qq19WtG54/f47Xr1/XuQ0M8y1jAfYD\\\\\\n\",\n              \"y8/Ph4KCotjllJSUKmaqF5e6ujrGjh2H0SOGoqSk5JPbHzl0EJeio2BtbQ1bW1t0794dioqKGDZs\\\\\\n\",\n              \"GPbs2YNp06bh7NmzsLe3/6z2iCI9PR3zFyyElQirV+vo6GDdhs34888/QfWw0ENWVha0BQKxysjI\\\\\\n\",\n              \"yEBLWxuZmZl1rp9hvmUswH5gqqqqyMoSv5PLzMyEqqrqZ9e7fPkyqKupou/PPfDq1atqtyktLcXm\\\\\\n\",\n              \"TRsw0ccbNrYtEBISAltbW0RFReHRo0cYOnQofv31V0RGRsLGxuaz2/Ipr1+/RlZWFoYOGyFyma5u\\\\\\n\",\n              \"3ZCZlYXY2Ng61y8nJydS0P9XUVER5OTk6lw/w3zLWID9wDp06IArl2OQnZ0tVrljRw7BtYvrZ9fb\\\\\\n\",\n              \"qFEjHDhwAA5t26KNvS0GD+qPY0eP4Mrly7hwPhKLFvwOM2MDhBw9gktXruPUmQg8SUzGgEEe4HJ5\\\\\\n\",\n              \"sLa2xrFjxyArK4uHDx/Wy5lOTUJDQ9G9Zy8oKyuLXEZaWhpDh43A4cNH6ly/hYUFrl65LFaZlJQU\\\\\\n\",\n              \"ZGZkQE9Pr871M8y3jAXYD4zP56NXr17Ys3uXyGVevXqF6KiLGDZsaJ3qlpGRwdKlfyI5ORnd3Nyw\\\\\\n\",\n              \"dPFCDPEYgKVLFiEnJwcnT5/F2ciLMDM3B/DhsqXXuPG4cuMWnj17jjlz5mDr1q1YvHgx2rVrhytX\\\\\\n\",\n              \"rtSpPTVJT0+HkZGx2OV0dfWQkZFR5/rHjx+PbVv9xSqzc8c2eHgMhqKi+JeHGUaSsAD7wU2dOhVr\\\\\\n\",\n              \"V/sh8enTT25bXl6OaZMnYswYLygpKdVL/UpKSnBz64rU1De4cSsO5y5EY+36DTXebxIIBAg/dwHr\\\\\\n\",\n              
\"16+HnZ0dbt++DW9vb3h4eGDAgAFISEiol3Z9JCcnh6KiIrHLFRcX18slvL59++JpQjwuRUeJtH16\\\\\\n\",\n              \"ejq2B2zBxIkT6lw3w3zrWID94Fq1aoWFCxehh1tn3Lt7t8bt8vPzMXzoYBQXF2HZsqX12oYtW7Zi\\\\\\n\",\n              \"6PCR0NLSEml7A0ND9O7TD4GBOyEjI4ORI0ciPj4eP/30ExwcHDB58mSkp6fXS9uaNWuGmEvRYpeL\\\\\\n\",\n              \"vXEdZmZmda6fw+Fg586dGO7pgdu3btW6bWZmJvr37onRo8c06H1BhvlmfO1x/A1Fkp5l+Bbs3fs3\\\\\\n\",\n              \"cblc6tnrZwo5EUYvUtIoNSOHYm/fpSnTZhCXy6VRo0ZRUVFRvdZbXl7+YfmSh/FiPet06coNatq0\\\\\\n\",\n              \"aZX9paWl0eTJk4nL5dLSpUvrPCNFSUkJKSur0M1/7onctpdv0klNTa1OM/n/19GjR4nH49HM2b70\\\\\\n\",\n              \"OOFZlYfL16zbQAaGhjRr1uzPej6PYT6SpL5TiqgB74B/Ra1atcLNmze/djMkSn5+PoKDgxGwbRue\\\\\\n\",\n              \"JiSgpKQEfD4fAwcOwoQJPjAyMqr3OnNycmBgYIC3me/EKldWVgY1pcYoKSmBtHTVCwkJCQn47bff\\\\\\n\",\n              \"cO3aNSxZsgTDhw+HjIzMZ7XRxMQELe1bIejvYJGefft93lxkpKVi586dn1VfTZ4+fYpNmzZjz54g\\\\\\n\",\n              \"mJqagcvjIT8/H3f+uY0urq6Y9MsvcHZ2rtc6mR+PJPWdLMCYryo9PR3NmjXD67fiDecnIijLc1BU\\\\\\n\",\n              \"VIRGjRrVuN3Vq1cxa9Ys5OXlwc/PD25ubmK3sX379khPT0effgOwcPGftYbY9oCtWOW3HFevXoVA\\\\\\n\",\n              \"zOe3RFVQUIAbN27g3bt3UFBQgLW1NbS1tRukLubHI0l9Z83/8hnmC1BTU0NBQQHevXsn1rNlr1+/\\\\\\n\",\n              \"hoqKSq3hBQAODg6IiYlBSEgIJk+eDENDQ/j5+aFFixafrCMrKwvh4eFITU1F586dcfzYEdyNu4Op\\\\\\n\",\n              \"02fC2aVjpSC7GRuLzRv/hxvXryIiIqLBwgsAFBQU4OLi0mD7ZxhJwQZxMF8Vh8PBz717Y9/ePWKV\\\\\\n\",\n              \"C9oVCHd3D5G2lZKSQr9+/fDgwQP06dMH3bp1w8iRI/Hy5ctqt7937x5Gjx4NY2Nj7A8+AKf2HZBf\\\\\\n\",\n              \"WARACg/u34fXqOGwtjSDx8B+GOI+AK3tbTFiqAdsbawRGxsLU1NTsb4LwzCfh52BMV/dxAkTMGHi\\\\\\n\",\n              \"RHj7TBDpPlVJSQkCtwcgLCxMrHo4HA5++eUXDB8+vOIszNvbG7/++mvF2V9ISAjGjRuHKdNn4t6j\\\\\\n\",\n              \"BPD5/IryRIRL0VFYsXwp8nJzMXBAfzRu3BgCgQAODg6ffY+NYZjPw+6BMV8dEaFLly6wsLLGytVr\\\\\\n\",\n              \"a73HJBQKMX7sGBQU5OHI4cN1qvfVq1dYsGABTp48iXnz5qFZs2YYPnw4joWGoWUtcyuWl5dj0oTx\\\\\\n\",\n              
\"ePXqBcJOnQKHw6lTOxjmWyJJfSe7hMh8dVJSUjh8+DCuXr4En3FeNc5U//r1awwfOhjJz5MQtHt3\\\\\\n\",\n              \"nevV09PDjh07cO7cOYSFhcHd3R3+ATtqDS/gwywiGzZvQV5ePg7XMUQZhvl8LMCYb4K6ujqioqIg\\\\\\n\",\n              \"J8uBrZU5Rg0figPB+3E67BT27/sbnh4D8ZOdNfR1dXH27Nl6nSbJ2toavr6+0NLSRvcePUUq06hR\\\\\\n\",\n              \"I0ydPhObN2+ut3YwDCMedgmR+eZkZ2dj585duHrtKnJzc6GiogLnDs4YMWK4WJPqimPYsGFo2aoN\\\\\\n\",\n              \"Jk6aLHKZsrIymJsYIDIyEs2aNWuQdjHMlyZJfScbxMF8c9TV1TFjxnQA079YnU8TEzFmnHjzBzZq\\\\\\n\",\n              \"1AhWza2RlJTEAoxhvgJ2CZFhAJSXlX3ymbLqcDgclJWVNUCLGIb5FBZgDANAU0sLycnPxSpDRHj+\\\\\\n\",\n              \"/JnIkxAzDFO/WIAxDIDBHh4I2hUoVpmbsbEoLCjATz/91ECtYhimNizAGAbAoEGDEHfnH8Q/eSJy\\\\\\n\",\n              \"mS2bN8LHZ0K1kwkzDNPw2L88hgHQuHFj/PbbPAwb4o6cnJxPbv/3niBcjonG2LFeX6B1DMNUhwUY\\\\\\n\",\n              \"w/yfqVOnoEuXLujs0g43Y2Or3SY3Nxcrli/Fgt9/Q1hYGDQ0NL5wKxmG+YgNo2eY/yMlJYXVq1fB\\\\\\n\",\n              \"PMAcwz3dwePxMWTocOjo6qKoqAjXrl7BweB9cHZxweXLl2FgYPC1m8wwPzQWYAzzL1JSUhg/3htj\\\\\\n\",\n              \"x3rhzJkzOHr0GKKjzkNOTg6WFpa4e/cu9PT0vnYzGYYBCzCGqZaMjAx69uyJnj1Fm1qKYZgvj90D\\\\\\n\",\n              \"YxiGYSQSCzCGYRhGIrEAYxiGYSQSCzCGYRhGIrEAYxiGYSQSCzCGYRhGIrEAYxiGYSQSCzCGYRhG\\\\\\n\",\n              \"IrEAYxiGYSQSCzCGYRhGIkkREX3tRjQEHo8HQ0PDr90MhmEYifL8+XNkZGR87WaI5LsNMIZhGOb7\\\\\\n\",\n              \"xi4hMgzDMBKJBRjDMAwjkViAMQzDMBKJBRjDMAwjkViAMQzDMBKJBRjDMAwjkViAMQzDMBKJBRjD\\\\\\n\",\n              \"MAwjkViAMQzDMBKJBRjDMAwjkViAMQzDMBKJBRjDMAwjkViAMQzDMBKJBRjDMAwjkViAMQzDMBKJ\\\\\\n\",\n              \"BRjDMAwjkViAMQzDMBKJBRjDMAwjkViAMQzDMBKJBRjDMAwjkViAMQzDMBKJBRjDMAwjkViAMQzD\\\\\\n\",\n              \"MBKJBRjDMAwjkViAMQzDMBKJBRjDMAwjkViAMQzDMBKJBRjDMAwjkViAMQzDMBKJBRjDMAwjkViA\\\\\\n\",\n              \"MQzDMBKJBRjDMAwjkViAMQzDMBKJBRjDMAwjkf4fqOrD7CaPBGEAAAAASUVORK5CYII=\\\\\\n\",\n              \"\\\"\\n\",\n              \"  frames[3] = 
\\\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAbAAAAEgCAYAAADVKCZpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\\\\\\n\",\n              \"AAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0\\\\\\n\",\n              \"dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd1RUx+P38fdSBKSIFClWBDsq2Cv2\\\\\\n\",\n              \"WEAFbNgVe+yaxGhiiRqjicaCvff+BQugCJag2LvYS4yCgBTpddn7/JFHfiGA7BoRV+d1zh6QvXPv\\\\\\n\",\n              \"3AXvZ2d2Zq5MkiQJQRAEQVAzGsVdAUEQBEF4HyLABEEQBLUkAkwQBEFQSyLABEEQBLUkAkwQBEFQ\\\\\\n\",\n              \"SyLABEEQBLUkAkwQBEFQSyLABEEQBLUkAkwQBEFQSyLABEEQBLUkAkwQBEFQSyLABEEQBLUkAkwQ\\\\\\n\",\n              \"BEFQSyLABEEQBLUkAkwQBEFQSyLABEEQBLUkAkwQBEFQSyLABEEQBLUkAkwQBEFQSyLABEEQBLUk\\\\\\n\",\n              \"AkwQBEFQSyLABEEQBLUkAkwQBEFQSyLABEEQBLUkAkwQBEFQSyLABEEQBLUkAkwQBEFQSyLABEEQ\\\\\\n\",\n              \"BLUkAkwQBEFQSyLABEEQBLUkAkwQBEFQSyLABEEQBLUkAkwQBEFQSyLABEEQBLWkVdwVKCpmZmZU\\\\\\n\",\n              \"qlSpuKshCIKgVp4/f05MTExxV0Mpn22AVapUiatXrxZ3NQRBENRKgwYNirsKShNdiIIgCIJaEgEm\\\\\\n\",\n              \"CIIgqCURYIIgCIJaEgEmCIIgqCURYIIgCIJaEgEmCIIgqCURYIIgCIJaEgEmCIIgqKXPdiKzIAjC\\\\\\n\",\n              \"x5KWlpazeoW5uTm6urrFXKMvg2iBCYIgvAdJkggJCaGPR19Mzcyp36gx9Ro2wtTMnP4DB3L58mUk\\\\\\n\",\n              \"SSruan7WRAtMEARBRampqXj068/1W7dp6dqfRUcuoG9UCoCk+DhCfPfj2qMXzZs1Zce2raJFVkRE\\\\\\n\",\n              \"C0wQBEEFWVlZOHftRlyWjFm7AujQb3hOeAEYGpvQacBoZu0O5EVcEu49e5GdnV2MNf58iQATBEFQ\\\\\\n\",\n              \"wcJFi4jPVDD4x8VoaZcocLsSuroMm7uCl6/jWLFixUes4ZdDBJggCIKSsrKyWLV6Da5jpqGpVfgn\\\\\\n\",\n              \"MFpa2nQb/S3LvVaiUCg+Qg2/LCLABEEQlHT06FFMrcpRvkoNpctUtndES0+fwMDAIqzZl0kEmCAI\\\\\\n\",\n              \"gpIuXrpE9catVCojk8mo3tiJS5cuFVGtvlwiwARBEJSUlJSMTkl9lcvpljQgKTm5CGr0ZRMBJgiC\\\\\\n\",\n              \"oKTSxqVISUxQuVxqYjzGpUoVvqGgEhFggiAISmrXrh23/ziu0gRlhULBrT8CaNu2bRHW7MskAkwQ\\\\\\n\",\n              \"BEFJbdu2RVPK5smtq0qXuXfpLCbGxjRp0qQIa/ZlEgEmCIKgJJlMxsD+/dmxcAYZaamFbp+anIjP\\\\\\n\",\n              
\"ql/47pspyGSyj1DDL4sIMEEQBCUdOXKElSu9qGRtwcrJQ0iOf1PgtolxMXhNGkzn9m0ZMGDAR6zl\\\\\\n\",\n              \"l0MEmCAIQiEUCgWzZ89m7NixHD16lOA/ztC5rRMze7Vmxy/TefHoLpnp6WSkp/H8/h12/jKdWX3a\\\\\\n\",\n              \"0qubM6tWeonWVxGRSZ/pcskNGjTg6lXl+6kFQRDy8+bNGwYMGEBycjL79+/HwsIi57nw8HCcWrUi\\\\\\n\",\n              \"OSWN+DexAFiVLcewoUMYMXw4lpaWxVXt96ZO106xGr0gCEIB7ty5g5ubG127duXXX39FW1s71/P6\\\\\\n\",\n              \"+vpEv35NREQE+vqqzw8T/hvRhSgIgpCPvXv30rZtW3766SeWLl2aJ7wA/P39ad26tQivYiJaYIIg\\\\\\n\",\n              \"CP8gl8uZNm0aPj4+BAYG4uDgUOC2hw8fpnv37h+xdsI/iQATBEH4/16/fk2fPn3Q0dHh6tWrmJiY\\\\\\n\",\n              \"FLhtRkYGAQEB4lYpxUh0IQqCIABXrlyhYcOGNG/eHD8/v3eGF8Dp06ext7fPNahD+LhEC0wQhC/e\\\\\\n\",\n              \"pk2bmD59OuvWrcPNzU2pMqL7sPiJABME4YuVkZHBxIkT+eOPPwgODqZ69epKlVMoFBw5coTTp08X\\\\\\n\",\n              \"cQ2FdxEBJgjCFyk8PJyePXtibW3N5cuXMTQ0VLrs1atXMTIyomrVqkVYQ6Ew4jMwQRC+OMHBwTRs\\\\\\n\",\n              \"2JBu3bpx8OBBlcILRPfhp0IEmCAIXwxJklixYgW9e/dm69atTJ8+/b2WeTp8+DCurq5FUENBFaIL\\\\\\n\",\n              \"URAEtSWXy/H19WXj1u28ehWJhqYGtjaVGDNyOK1atcoVTqmpqYwcOZK7d+9y4cIFbGxs3uuYT58+\\\\\\n\",\n              \"JTY2lkaNGn2gsxDel2iBCYKglrZu3YZ1+YqMnf4TL0rVQsdpMFpN+3NXYUnvwSOwrVqDwMBAAJ49\\\\\\n\",\n              \"e0azZs2QyWSEhIS8d3jB362vrl27oqEhLp/FTbTABEFQO/MX/MLvXmuoMXQBxhVr5HrOtFp9KrTu\\\\\\n\",\n              \"RfTdC/Tw6MfXI4azZctmfvzxR8aNG6dSl2FaWhr79+/n1B/BJCYlU8rIiIshZ1m4cOGHPiXhPYgA\\\\\\n\",\n              \"EwRBrRw4cIAlK1ZRb8p6dEuZ5buNTCajjH0zdMcuZ8lvI1i1fCkjR45U+hgZGRn8MHMWGzZuxLhS\\\\\\n\",\n              \"TQyrNUXLsCJZqUlEZWozZNgIxl27zuxZM/NdI1H4OMTtVARBUBuSJFGtVm2MO4zEvGYTpco8P70f\\\\\\n\",\n              \"24znHPE5qNT2KSkptO/UhfA0TSq7jkPfvFyebZIj/+KZz3KqlDHg2NHD6OjoqHQenzJ1unaKTlxB\\\\\\n\",\n              \"ENTG+fPniUtMway68gMoyjbpwsmTgbx69arQbSVJoqdHP6IkQ+yHLcg3vAAMLCtiP/JXnsVnM3CI\\\\\\n\",\n              \"p9J1ET4sEWCCIKiN/QcPYla/IzIVBlBo6xlgWaclR48eLXTbK1eucOnqdar1m1HoMTQ0tag+cCYB\\\\\\n\",\n              
\"gUHcu3dP6foIH44IMEEQ1EZkVDQljM1VLqdhaEpkZGSh2y1dsRLL5m5oaCo3PECzhC5WTbuyfOUq\\\\\\n\",\n              \"lesk/HdiEIcgCGpDV0cHKTNL5XJZ6anMmzePxYsXY2FhkfOwtLTM+d7MzAzvgwdw+vmwSvu2btaN\\\\\\n\",\n              \"PYuGsG61CLGPTQSYIAhqw75WDc4dDgbcVSqXFfEEHx8fnJyciIqKynlERkYSFRXF9evXefnyJQpk\\\\\\n\",\n              \"lDAwVmnfuqUtSE1OIjMzkxIlSqhUVvhvRIAJgqA2hgwezJy586nsmkAJg1JKlUl48ZCsxGg6d+6M\\\\\\n\",\n              \"lpYWpUqVyncR3oSEBCyt8x+08U6ShISEpqam6mWF/0QEmCB8grKzszl58iR//vkn2dnZWFhY0KlT\\\\\\n\",\n              \"J/T19Yu7asXK3NwcFxcXbp7ajV23MYVuL0kST/w2oCWTuHr1Kk2aFDz03tDQEC0tTVJjXlHSzFrp\\\\\\n\",\n              \"OiVFPMOsjKUIsGIgBnEIwickMTGRnxcsoGyFSgwd/w1L9wfi5X2Gb+YvxapsecZNmMiLFy+Ku5rF\\\\\\n\",\n              \"6vffFpFwM4iwkHd/ViVJEk8Pr8ZMSmbe3Lm4u7szcuRIYmNjc20XHx/Ppk2baNeuHdnZ2bz8Q7n5\\\\\\n\",\n              \"Ym9Fhhxi1IjhKp+H8N+JFpggfCLCw8Np0/4rsktXoNbwhZSulHuJpJSYCE7/cZA9DRrif/QIjRs3\\\\\\n\",\n              \"Lqaafnz37t1j/cZNPHz8FHl2Ns2aNSP4+CZeh16gcsdBGFeqmbOtpFAQ8+AKEWf2YqqdRWBQAObm\\\\\\n\",\n              \"5vTs2ZNZs2ZRs2ZNfvrpJywtLdm9ezcnTpygXbt2TJw4kSpVqtCkRStsOg2mhH7hXZQZibG8unKC\\\\\\n\",\n              \"0TuWFeXpCwUQASYIn4CEhATatP8K3VqtqdLFM9/1+vTNrKjRYzzGtnXp5OzC+bPB1KhRI5+9fT5u\\\\\\n\",\n              \"3LjBmPETuXf/AVZNXChp3RiZpibhsa/I0ixB8qNr3PvrDjqlTClpXg4kBQnhTzEtZciPE8czaNAg\\\\\\n\",\n              \"9PT0ADAyMsLDw4OIiAjGjx+Pnp4ekyZNYt26dZQuXRr4u9VWx74mV5ZPoPGU1WjpFtxlm5WSSOi6\\\\\\n\",\n              \"b5kyaSLW1sp3OQofjggwQfgELF22HIVppQLD65+sHJxIjQ5j4pRvOXHM9yPV8OM7ffo03d17UtFl\\\\\\n\",\n              \"DM08fkFDK/eagxXb9iXu0TUe7v6Zns5f0bqVE5qamlSsWBFHR8ec1/Hx48fs3LmTnTt3UqJECQYO\\\\\\n\",\n              \"HMjChQs5ffo0M2bMID4+nnnz5qGvr8/YsWNJSUqkU4v6nFo6mgqdh2Neu3mueWGKbDlRN//gxbEN\\\\\\n\",\n              \"9O/pytyf5nzMl0X4BxFgglDM5HI5q9aspe6Y35VeKb2ikxtB01x4/vw5lSpVKtoKFoPHjx/j2qMX\\\\\\n\",\n              \"NYbOx7Rq/Xy3kclkmFZrgOOktexaNpo2rVvh5uYGQHR0NPv27WPnzp08f/4cDw8P9u/fT7169XJe\\\\\\n\",\n              
\"48qVK+Pq6sr06dOpXr06lpaWmJubc/bsWQwMDNi3bx+Llizl4v+WYmbfDFmJkmSnJRJ56yyljIxo\\\\\\n\",\n              \"3aQBLZs3IysrSwyfLyZiEIcgFLOAgAB0SltgXCHv0O6CaOnoUb5pFzZu2lyENSs+vyz6DcvmbgWG\\\\\\n\",\n              \"1z/plbagisf3TPthJnv37qVr165UqVKFixcvMmfOHMLCwli2bBn169fP8wbB1NSU+fPnU7p0aV68\\\\\\n\",\n              \"eEFmZiYvX75EJpPh4eHBjSuXOBXgRxs7E5JCT/Pq+mlMajbFqIELt9NKM2n2IqzKlufHmbNIS0sr\\\\\\n\",\n              \"qpdDKIBogQlCMXv69CmG5aupXM6gXFUePn5YBDUqXgkJCezfv59GM3YpXca0eiPubJ/P77//zvjx\\\\\\n\",\n              \"49m9ezeGhoaFlnv8+DGdO3emX79+zJo1i7Vr19KqVSs8PT2ZOXMmBgYGHDp8hINHj2HTfQJ1/9Wd\\\\\\n\",\n              \"CJD06hk7jm3E73hrTp04nvN5mlD0RAtMEIqZXC4HDdXnEMk0tciSy4ugRsXr2LFjmFV1QFeFNQ9l\\\\\\n\",\n              \"MhkVW/fCoV4DBg4cqFR4Xbx4EScnJ6ZNm8bcuXPR0tJi3LhxhIaGEhERQc2aNRkxYiQrN27DcfJ6\\\\\\n\",\n              \"LBxa5btGoqF1ZWp6/kySUSW6dHP9+/cpfBQiwAShmFlYWJAZV/hCs/+WFvMKa0vLIqhR8YqKikLb\\\\\\n\",\n              \"2ELlcnqmVryKjFJq28OHD9O1a1c2btzIiBEjcj1nYWHB9u3b2bBhA1u376DW8IXoGJm+c38ymYwq\\\\\\n\",\n              \"PSbxPCqeI0eOqFx34f2ILkRBKGbOzs6MGjOW9ISYAu8w/G+SQkHERV8GHtxbxLX7+LS1tUGRrXI5\\\\\\n\",\n              \"RbacwMBA7OzssLS0LPAREBDAypUr8ff3p2HDhgXuLzw8HKtaDTGwrKTU8WUamli07MmSZV64u6u2\\\\\\n\",\n              \"VqPwfkSACUIxMzY2pmevnlwL9qFa1xGFFwCiQi9gUsrwnUsjqStbW1tSwtaqXC41/BHjvx7NqJEj\\\\\\n\",\n              \"iIyMzPW4cOECr1694saNG8TExCCTyejcuTOWlpZYWVnlG3RLvVZj3sRDpTpYOLTh/P+WERYWRrly\\\\\\n\",\n              \"77GuoqASEWCC8An44ftpNGzSlNJ2DpSpUXCrAP5ekeP6ljkYlNDk3LlztGzZ8iPV8uNo3749ipQ3\\\\\\n\",\n              \"JPx1n1IVlZuonZ2ZTsTlY4xddw0bGxuqVKmS6/nMzEw8PT2xsbEhNDSU0qVLExsbmyfowsPDuXbt\\\\\\n\",\n              \"GpGRkTx+9IjGbpVUqrumdgmMypQlPDxcBNhHIAJMED4BVapUwefgAVx79KJK9zFUaOaSZ+KuJEm8\\\\\\n\",\n              \"vnuR0B0/s3DeT9hWtqFPnz54enoye/bsv7vePgOampqMHTOajUd3YTRknlJz48JCDtOoUSNsbGzy\\\\\\n\",\n              \"PJeQkIC7uztGRkacPHmSkiVLAlCmTBnKlClDnTp18t2ndQUbQLl5ef8mSdJ7lRNUIwZxCMInok2b\\\\\\n\",\n              
\"Npw5GYjOs3Oc/N6F+94reXn5BGFXg3h0bBtn5/Qm0m81W9avZsL4cTg7O3Pjxg2uX79O8+bNefz4\\\\\\n\",\n              \"cXGfwgczaeIE9FIi+fPYpkLD4PWdc4QHbWfV8qV5ngsLC6Nly5ZUr16dgwcP5oSXMqytrUmJ+kul\\\\\\n\",\n              \"eivkWSS+Dqds2bIqlRPejwgwQfiEODo6EvLHaS6FnKVjVRPKRF7B+EUIjUpn8L9dW3l0PxRXV9ec\\\\\\n\",\n              \"7S0sLPDz82PQoEE0a9aMzZs3fxbv/g0NDTkdGEBq6CmurZ5C/J+hec4rNeYVj328eLZ/Ecf9juZZ\\\\\\n\",\n              \"FzI0NJRmzZoxYMAAVq5cqfLtTkYPH8rrC6rdnTnq5hnsa9tTvnx5lcoJ70d0IQrCJ6h69er8vmSx\\\\\\n\",\n              \"UtvKZDLGjRtHmzZt6NevH/7+/qxbtw5T03cP/f7UmZubk52RRnbiI57vnouko49+2SrIZBpkvIkg\\\\\\n\",\n              \"/sVDhgwezDcbrlChQoVcZU+fPk2fPn1YtmwZ/fr1U/nYiYmJPHjwgFehl7B5/QL9MhUKLSMpFESd\\\\\\n\",\n              \"PcCKn39U+XjC+xEtMEH4TNSqVYvLly9TsWJFHBwcOHnyZHFX6T9ZtWoVMTExHDl8iLC/nrFrvRfT\\\\\\n\",\n              \"B3Xlm74dWfHTNCLDw1ixbGme8NqzZw99+vRh3759KoeXXC5n/fr1VKtWjZiYGGb++AN3N0wjI+nN\\\\\\n\",\n              \"O8tJksSTQysoa6JP9+7dVT5X4f2IFpggfEZ0dHRYsmQJnTp1YvDgwfTr14958+aho6NT3FVTSUJC\\\\\\n\",\n              \"AjNmzKBDhw40b94cgLZt276zjCRJLF68GC8vL06ePEnt2rVVOmZgYCBTpkzB1NQUX19f6tf/ex3G\\\\\\n\",\n              \"9IwM1i8bhY3rBMxrNUX2r1VTkiP/4q9jGymVGcvxoIDPZjCNWpA+U/Xr1y/uKghCsYqOjpa6d+8u\\\\\\n\",\n              \"OTg4SPfu3Svu6qhkxIgRko6OjhQeHq7U9nK5XBo3bpxkb28vvXz5UqVj3b17V+rSpYtkZ2cn+fj4\\\\\\n\",\n              \"SAqFIs82Bw4ckGrVdZRKW5aT7DoNlqr3nCRV7TZGKle7iWRsaiZ9+900KTk5WaXjfqrU6dopuhAF\\\\\\n\",\n              \"4TNlZmaGj48PX3/9NU5OTqxZs0YtBng8e/aMrVu38s033yh1o8i0tDR69uzJ3bt3OXfunNLzr6Kj\\\\\\n\",\n              \"oxk7diytWrWiffv23L17F1dX13yH7ffs2ZPQm9cJOOJNv0YVaWMhp3t1I36dPoHI8DB+XbQQff2C\\\\\\n\",\n              \"b34pFA2ZpA5/0e+hQYMGXL16tbirIQifhIcPH9K/f3+srKzYtGkTZcqUKe4qFahFixY8ePCA8PDw\\\\\\n\",\n              \"Qrs+Y2Ji6NatGzY2NmzevFmprtKMjAy8vLxYtGhRzir06j7g5UNSp2unaIEJwhegWrVqnD9/Hnt7\\\\\\n\",\n              \"exwcHDh+/HhxVylfp06d4tKlS2zYsKHQMHr27BnNmzenVatW7Nixo9DtJUni4MGD1KhRg7Nnz3Lu\\\\\\n\",\n              
\"3DmWL18uwkudFW8PZtFRp35cQfiYTp8+LZUvX16aMGGClJqaWtzVyZGdnS2VLVtWcnR0LHTbK1eu\\\\\\n\",\n              \"SFZWVtLKlSuV2velS5ek5s2bS3Xr1pWCgoL+a1U/a+p07RQtMEH4wrRu3Zpbt24RGRlJo0aNuH37\\\\\\n\",\n              \"9ifx2ZiXlxdRUVHs3fvuFfb9/f3p3Lkzq1evZuzYse/c9uXLlwwYMABXV1c8PT25du0a7dq1+5DV\\\\\\n\",\n              \"FoqRCDBB+ALFx8dTvkJFXoZHUNfBAS1tbcpWqMTsOT/x6tWrj16flJQUZsyYgYeHB1WrVi1wu40b\\\\\\n\",\n              \"N+Lp6cmRI0dyrUjyb8nJycycORMHBwcqV67Mo0eP8PT0VHk1DuHTJuaBCcIXRC6XM3b8BPbs3Yud\\\\\\n\",\n              \"Uzc6ztpMKauKIEm8CXvKoZMHWPJ7TSZOnMD8uT8ptZDuhzBhwgQkSWL16tX5Pi9JEnPmzGHnzp0E\\\\\\n\",\n              \"BwcXGHLZ2dls3bqVmTNn0q5dO27evCmWdfqMiQAThC+EQqGgd99+3HoWQc/lx9DRN8z1vJlNDcyG\\\\\\n\",\n              \"z8Kx11i2/zaeuLg3rF65oshD7K+//mLbtm0sXrwYQ0PDPM9nZWUxcuRIQkNDOX/+PBYW+d+t+eTJ\\\\\\n\",\n              \"k0ydOhVDQ0MOHz78zptVqgO5XI6mpuZHexOhjkQXoiB8IZYuW8a1+89o+82KPOH1T3qlTOkwfR0+\\\\\\n\",\n              \"/gGFfh71IXh4eGBlZcWECRPyPJeUlISLiwvR0dGcOXMm3/B6+PAh3bp1Y+TIkcycOZPg4GC1DC9J\\\\\\n\",\n              \"krh06RJ9BwzEwNAIHR0dtLS1qVKjFl5eXiQmJhZ3FT85IsAE4QuQnZ3NkqXLaDDwO7RKFD5XSkff\\\\\\n\",\n              \"kHoek5g97+cirdepU6e4fPkyu3btQkMj9+UoIiKCVq1aUbFiRQ4dOpRnonBsbCwTJkygRYsWODk5\\\\\\n\",\n              \"ce/ePXr06KGWLZa4uDic2rTDxb03zzUscFt6lGF7bzF0x1Wq9/mG1fv9KVehIrt37y7uqn5SRIAJ\\\\\\n\",\n              \"whfg+PHjaBmUpoydvdJlyju2JCw8ggkTJpCdnf3B6yRJEgMGDMgJoH+6f/8+TZs2xd3dnXXr1qGl\\\\\\n\",\n              \"9X+fdmRmZrJ06VKqV6+OQqHg3r17fPPNN2q33uNb8fHxNGvhRIpxRdyW+lKn21BKGpshk8nQ1NbG\\\\\\n\",\n              \"2r4RThMX03HmZsZN+ZaNmzYVd5U/GSLABOELEHTyFFb13r0Y7r9paGpSzcmZo0eP0r59e16+fPlB\\\\\\n\",\n              \"67R06VKio6PZt29frp+fO3eO1q1bM2fOHH788cecFpUkSfj4+FCrVi2CgoIIDg5m5cqVmJubf9B6\\\\\\n\",\n              \"fWxDh49E16YuDfpPRaZR8CXZpGJVOkxfx9Rvp3H37t2PWMNPlwgwQfgCvIlPQEffSOVyWvpGePTt\\\\\\n\",\n              \"S4cOHWjQoAEHDhz4IPVJS0vjhx9+YMyYMVhaWub8/ODBg7i7u7Njxw6GDBmS8/Nr167RunVrZs2a\\\\\\n\",\n              
\"xapVq/Dz88tzA0t1FBYWRmBgIPU8JinV9Wlc1oZqHfqwbIXXR6jdp08EmCB8AYyMDMhKT1G5XHZ6\\\\\\n\",\n              \"KsalSjFjxgx8fX354YcfGDp0KElJSf+pPqNGjUJbW5slS5bk/GzZsmVMmjSJEydO8NVXXwEQHh7O\\\\\\n\",\n              \"kCFDcHFxYcCAAdy4cSPnuc/BmrXrsGvhjLZeSaXLVG3bg71794pBHYgAE4QvQotmzXh9+5xKZSRJ\\\\\\n\",\n              \"Iuz6HzRp0gSAhg0bcv36dbS0tHB0dOTixYvvVZfnz5+za9cuVq5ciba2NgqFgilTprB+/XpCQkJw\\\\\\n\",\n              \"cHAgJSWFOXPmUKdOHaytrXn48CEjRozI9VnY5+DEyVOUa6Ba166+qQVmFey4fv16EdVKfYgAE4Qv\\\\\\n\",\n              \"gKurKwmvnhP38onSZV7dvUxSXDQzZsxg3759ZGVlYWBgwIYNG/j111/p3r078+bNQy6X5ymbnp7O\\\\\\n\",\n              \"zp07adW2A9Vr1aF23fr09ujHuXPn6NGjB7a2tgwaNIj09HT69u3L1atXOXfuHOXLl2fr1q1Uq1aN\\\\\\n\",\n              \"R48ecf36dRYsWICRkerdn5+qzMxMwsLCuHr1KpGRUe+c0lAQHX0j0QJDTGQWhC9CiRIlGDN6FHv3\\\\\\n\",\n              \"LKPt1OVoFLKkkjwzg5v7vfht0S9YW1mxYsUKpk6dyujRoxk5ciTu7u40btyYwYMHExAQwI4dO7Cx\\\\\\n\",\n              \"sUGSJH5bvIT5839GpmdOmq4NshJVIVPBo5AwfP1dSUtJZMum9cTFxeHq6oqlpSUnTpzg4sWLTJky\\\\\\n\",\n              \"BV1dXQ4ePJjT8lMH2dnZxMTEEBUVRWRkZK7Hv3+WmJhImTJlsLCwIDU1lcw01bt2M9OS8530/aUR\\\\\\n\",\n              \"9wMThC9EZmYm7b/qRKyGIc1H/oSmdol8t8tKS+HEbxPIig3j8YP7GBgYAHD79m28vLw4ePAg3bp1\\\\\\n\",\n              \"Y/z48dSrV4+lS5eycOFCfv/9d0LOX2TX/kNklmmDhq5xnn1LkoQi4U+0os9hZmJMz549GTVqFN9/\\\\\\n\",\n              \"/z03btxg0aJF9O7d+5OYyyVJEvHx8e8Mo7f/jomJwdjYGEtLy5yHhYVFvv82NTVFQ0Pj72kEAwcR\\\\\\n\",\n              \"miCjyeBpStcrLSEO78kuvHj+JyYmJh/8vNXp2ikCTBC+IKmpqfT26Mfl6zep0r43VVu7omv4d9Ck\\\\\\n\",\n              \"vonm4SlvHp08gHPnjkjyLMLDwzl69GhOiMHfk243bdrEqlWrsLa2Zvz48djZ2dG1azdikzLRtHVH\\\\\\n\",\n              \"pvnuOVmKlEh4cQyPPj3x8/Pj22+/ZeLEiejq6hbp+cPfC/3mF0b//llUVBR6enrvDKO3D3Nzc7S1\\\\\\n\",\n              \"tZU6fkxMDDt37mTTpk0kJCQQ8yaBvmtPoaWj3LnfPrQJW804dm7b+h9ehYKp07VTBJggfGHeLlm0\\\\\\n\",\n              \"zGslhw/5UEK3JEgS8qxMPDw8mDBuLHXr1iU7O5tRo0Zx//59/P39KVWqVK79yOVyjh49ipeXFw8e\\\\\\n\",\n              
\"PCA+IZnscp3RKKncvCx5xCVsS6cS/Mfp/3yH6IyMjAJbR//+t0KhyBU+BQWUhYUFenp6/6leb2Vn\\\\\\n\",\n              \"ZxMYGMimTZsIDAyka9euDBs2DCcnJzq7dOONfnkc+4wrdD/J0RH4zxpA0HE/6tev/0Hq9m/qdO0U\\\\\\n\",\n              \"ASYIX7CMjAzi4uLQ0NDAxMQkTytCoVAwbtw4rl69SkBAAKVLl853P8uXL+e7mQvRsO2h9LGlrBQ0\\\\\\n\",\n              \"/jxIZER4voM05HI50dHRSrWWUlJSckKosNaSgYHBR+uifPbsGVu2bGHr1q1YWloybNgw+vbtm+vN\\\\\\n\",\n              \"QFRUFA0aN6FcC1dqdx9WYN0So8I4uWgM304cy7dTpxZZndXp2ikGcQjCF0xHRwcrK6sCn9fQ0GDV\\\\\\n\",\n              \"qlVMmTKFdu3aceLECczMzPJsdzbkIopS1VQa1izT1kdWsgzjx4/HwsIiTzjFxcVhYmKSp7VUoUIF\\\\\\n\",\n              \"GjVqlCugSpcunWctxeKSlpaGt7c3mzZt4s6dO/Tv3x8/Pz/q1KmT7/YWFhZcDDlHxy4uHLt6Ett2\\\\\\n\",\n              \"vbFt1jmnSzHur0c8CtrHs/MB/Dx/LhPGj/+Yp/NJEwEmCMI7yWQyfv/9d2bMmEGbNm0ICgrKsyp8\\\\\\n\",\n              \"ZFQUMm39AvZQsCxKEBYWRs2aNalVq1au1pKZmZnazPuSJInr16+zadMm9u3bR8OGDRkzZgzdunVT\\\\\\n\",\n              \"ao3GsmXLcuv6Vfz8/Pjl18VsWz8X3ZIlyZbLMTQqxZhRIxm16Xesra0/wtmoD/X46xAEoVjJZDIW\\\\\\n\",\n              \"LFiAjo4OrVu35uTJk7kupro6uiCpvuCvIiuTP/74g8ePH2NlZYWVlRXW1tb5fm9ubv7J3VE5NjaW\\\\\\n\",\n              \"Xbt2sXnzZhISEhg6dCg3btygQoUKKu0nPDyctevWs3b9BjS0dShrV4P0lCQyU5IYMmggnkOHiPDK\\\\\\n\",\n              \"hwgwQRCUIpPJmDNnDjo6OrRq1YpTp05hbm7O6dOniYt5jZSSCcZ2Su9PkiT0ZMn4nTpFhQoViIiI\\\\\\n\",\n              \"4NWrV0RERBAREUFISEjO94FRC8YAACAASURBVBEREcTHx2Nubl5o0FlYWBRpy02hUHDy5Ek2bdrE\\\\\\n\",\n              \"8ePH6dKlC0uWLKFNmzbv1Y15+PBhBg3xpHKzTrT9fi0m5f/vNYwP/5OgoP2sdajHmpUr6N+//4c8\\\\\\n\",\n              \"FbUnAkwQBJUMHjyYGzduUL16dTQ0NHB0dKRDh3bc81qNpGiGTEO5y4oiOYyUpHgOHDjAuHHjaNq0\\\\\\n\",\n              \"6Tu3z8zMJCoqKk/QXblyJef7V69eERsbi6mpaaFBZ2lpSYkS+c+Fy8/z58/ZunUrW7ZswczMDE9P\\\\\\n\",\n              \"T9asWVPgwBZl+Pn5MXjYCNp/vybfW90Yl7Wh8eBpVGnjzvjJo9HS0qJPnz7vfbzPjRiFKAjCOykU\\\\\\n\",\n              \"Cq5evYqvry9+fn48f/6cTp06oa2tTVBQEKdPn6ZKlSq0atOO84/S0SrjWOg+JUmBxgs/SulkoqGh\\\\\\n\",\n              
\"QWpqKg0bNmTcuHE4Ozv/p65CuVzO69ev8wTdP7+PiIggKioKY2PjPMH2z3+bmJhw+fJlduzYwY0b\\\\\\n\",\n              \"N+jbty/Dhg3DwcHhvev3VkpKCmXLV6DttyuxqFq30O1jnj8gYN4wnj97WiQTmN9Sp2unaIEJgpBH\\\\\\n\",\n              \"UlISgYGB+Pr64u/vj6mpKS4uLixbtoymTZvmdNFt2LCBNm3acOjQITLSUpBHXEOmpY+mSdUC9y1J\\\\\\n\",\n              \"CjSjzuFob8PJwAC8vb2ZOXMm4eHhzJgxgwkTJjBmzBiGDRuW74jH/GRkZODt7c2K1et49vQpcnkW\\\\\\n\",\n              \"puZlGNzPgxEjhuc7z+zt8k//Drf79+/j7e3N/fv3iY6ORpIkSpYsSfny5Xnw4AG///57vi06Kyur\\\\\\n\",\n              \"PHeNfpddu3ZhWb2eUuEFYFapOhXqObF161amTJmi9HE+Z6IFJggCAE+fPsXX1xdfX18uXbpEs2bN\\\\\\n\",\n              \"cHFxwdnZGRsbmwLLLVmyhGnTpqGhoYGhoSEZWdlgUJEsw+q5JjVLkgJFwp/oJt/DoaYtfr6Hc9bz\\\\\\n\",\n              \"y8rKYsuWLcydO5cqVapQqlQpzpw5g6urK+PGjaNBgwYFHv/gwYOMHP01epaVKd2oG4blqyPT1CQ9\\\\\\n\",\n              \"LpK4q/68vnWGYZ6eLPt98Ttbdm/evGH37t1s2rSJ2NhYhg4dypAhQ6hQoQJxcXGFtuhevXqVMy3h\\\\\\n\",\n              \"XV2X1tbWGBoaUrOOA5W7j6W8Q3Olf0eRD29yfeMsnj99XGRz2dTp2ikCTBC+UFlZWYSEhOR0DcbH\\\\\\n\",\n              \"x+Ps7IyLiwvt27fPtXxUQYKDg+nWrRvp6enI5XJGjBjB3LlzWbt2Hcu9VpKl0EIu00WRnQWZCcjl\\\\\\n\",\n              \"WWzesJa+ffvmO9AiLS2N1atXs2jRIlq3bo2NjQ379u3DwsKCsWPH0rt371zLTW3atJkp38+gyqCf\\\\\\n\",\n              \"MapYM//zTEng6a45NKxaDu8D+3KFmEKh4PTp02zevBk/Pz86deqEp6cn7dq1U7kb8+3aicoE3dtz\\\\\\n\",\n              \"Hb7nBppayi1B9fYYWwY2JC4mWqXWnirU6dopAkwQviAxMTEcP34cX19fTpw4ga2tLS4uLri4uODo\\\\\\n\",\n              \"6KjSKLr169czdepU9PT0sLW1xdDQkFu3bnH06FEaNWqEXC7n8uXL7Nq1i8ePH7N8+XIGDRrE/Pnz\\\\\\n\",\n              \"6dix4zv3nZiYyNKlS1mxYgW9e/emcePG7Nmzh5s3bzJs2DBGjx7N69evaftVJ2qOWUlJi4rv3J9C\\\\\\n\",\n              \"nsnDDVMZ5dGNObNn8fLly5wBGYaGhgwbNoz+/ftjamqq9Pm/L0mSiI6OxrpsWUbsvaVy+Z3DW/Ds\\\\\\n\",\n              \"0UPMzZVbsktV6nTtFJ+BCcJnTJIkQkNDc7oGQ0NDadeuHS4uLixduvSdq3AUJCsri0mTJrF3714s\\\\\\n\",\n              \"LCyoXbs2kiTxv//9j2PHjuHi4oKPjw/NmzenWbNm/Pnnn7x584YaNWowZMgQtm7dWmiAGRkZMXv2\\\\\\n\",\n              
\"bMaOHcvChQuZMmVKTutuz549ODo6omdghEWrfoWGF4CGVgkqun/D4t/HcD7kHNeuXaNPnz4cOHCA\\\\\\n\",\n              \"evXqFfnSUpIkERYWxs2bN7l16xY3btxAkiAjJUml+4FlZ2WSnpycZ13KL5UIMEH4zKSlpXH69Gn8\\\\\\n\",\n              \"/Pzw9fVFU1MTFxcXZs+eTatWrZRaGaIgsbGx9OjRgydPnmBnZ0ejRo24fv06QUFBOcfZuXMnbm5u\\\\\\n\",\n              \"7N+/n9atW6Onp0daWhoAHh4ezJgxg/j4eIyN895u5d/MzMxYvHgxkyZNYv78+Tg7OzNp0iTOnDlD\\\\\\n\",\n              \"w8ZNadjYWem6l7SoSAmzCtjZ2XH48OEPtlDvv2VmZvLgwQNu3ryZ87h161bOnazr1q1Lz549iY1P\\\\\\n\",\n              \"5GnIMWp+1VvpfT+7eIKmLZ1UGv7/ORMBJgifgfDw8JzAOnPmDI6Ojri4uHD8+HGqV6/+QVoYoaGh\\\\\\n\",\n              \"dO3aFQ0NDerWrUurVq3YsmULISEhucLgq6++Yt++ffTu3Ztdu3ahq6ubE2CmpqY5z48aNUrpY5cr\\\\\\n\",\n              \"V461a9fyzTffMHv2bBYvXkzpKo5ol1TtTs1lGjkTFfvnBwuvN2/ecOvWLW7dupUTVg8ePKBSpUo4\\\\\\n\",\n              \"ODjg4ODAd999h4ODA5aWlrnKmpubM3j0eGp06KX07+fpyf0snffjB6n750AEmCCoIYVCwZUrV3K6\\\\\\n\",\n              \"Bl+8eEHnzp3p168f27Zt+0+Ta/Nz5MgRPD09KVWqFE5OTrRv357vv/+ec+fO5TsnqU2bNnh7e+Pu\\\\\\n\",\n              \"7s7kyZNJT0/PeW7IkCHMnz9fpQB7y87Ojl27dvHjjz+yJeiGyuVLGJYm9s/rKpeTJInnz5/nalHd\\\\\\n\",\n              \"vHmT2NhY6tSpQ926dWnatCljxozB3t6ekiVLFrrPtm3bUlpflztHt1Cnm2eh298P3I8sPREXFxeV\\\\\\n\",\n              \"6/+5EgEmCGoiMTGREydO4Ofnh7+/P+bm5ri4uODl5UWTJk2KZPkkSZJYsGABXl5eGBgY0KdPH9q3\\\\\\n\",\n              \"b4+HhwdBQUFUrFjw508tWrTg6NGjdO7cOdfgiI4dOzJ8+HDOnz9PXFwcSUlJGBoa0qhRI6XvC1ah\\\\\\n\",\n              \"QgU0JdUHGmRnphc6ei89PZ27d+/malXdunULQ0PDnFZV//79+e2337C1tX3vVfA1NDQ45nuExk2b\\\\\\n\",\n              \"o5DLqdPdEw3NvL9DRXY2947v5qH/Ni6EnFWbBY4/BvFKCMIn7MmTJzmtrMuXL9O8efOcz7MqVapU\\\\\\n\",\n              \"pMdOTU3F09OT0NBQNDQ0mDx5Mq1bt6ZDhw7s27evwNuD/FPjxo1ZvXo1AwcOZN++ffTp04dr165h\\\\\\n\",\n              \"VNqM1u06YG5bB5muAVJGCm+e36Vjp058N2VSnmWl4uLiCAkJITg4mODgYG7fvo1CS4fK2fJ8L/oF\\\\\\n\",\n              \"SXl6jSYdG+b8Ozo6Oieo3n59+/ne27Dq3r07devWVXpStSrKly/PlUsX6NHbg4MTDlClbU8qNmpH\\\\\\n\",\n              
\"iZKGZKUl8+LaGR6fPEA5K0suXQh553y8L5EIMEH4hGRlZXHu3Lmc0EpKSsLZ2Znx48fTrl07peZm\\\\\\n\",\n              \"fQgvX77E1dUVMzMzoqOjWbZsGc2aNaNFixasXLmSNm3aKL2vevXqYWVlxeTJk/E5fBj/gJMYN+5J\\\\\\n\",\n              \"Ddef0dL7vxF4ZmnJXLsVQMeubowbNZy6dWoTHBzM2bNnef78OU2aNKF+/fo5q9JnyeXE3g3BvE4r\\\\\\n\",\n              \"peqRlZpE5PWTxNSriLOzM7du3SI5OZm6devi4OBAmzZtmDx5MjVr1sw116yolS1bloshZ7lx4wbL\\\\\\n\",\n              \"vFZyxmsqyUlJ6BsY0KJ5M5Yf+h8NGzYsfEdfIBFgglDMoqOjOXbsGH5+fpw4cYIqVarg4uLCnj17\\\\\\n\",\n              \"cHBw+Og3ajx//jw9e/akQ4cO+Pv7s3v3burVq0eLFi345ptv6N1b+VFzALq6uigUCjyHDeO3Feuo\\\\\\n\",\n              \"OsKLEqXydhVq6Rlg3qQHpWo4sWTVeKpWsGDI4MEMGjSIV69eMW/ePBYtWoSJiQkTJ06kSpUqfDf3\\\\\\n\",\n              \"V0yqN0GzROEjK/8K2Iyenh7GxsaMGDECBwcHKlasiEwmIzY2ls1btjBu0lTi49+gp1eShvUdGT/2\\\\\\n\",\n              \"a2rVqqXS+b4vR0dHtm3e9FGO9bkQE5kF4SOTJIk7d+7ktLLu3r1L+/btcXFxoXPnznlGq31MW7Zs\\\\\\n\",\n              \"Ydq0aXh4eHDgwAGOHj1KzZo1ad++PS1atODXX39VeZ/R0dFUq1aNjKxsbIavRsek8PtaZbyJ5NmG\\\\\\n\",\n              \"0fT36M2BAwdISkqiQYMGLFiwgHbt2gF/D2Tp5urO1WdRVBv8M5o6+Y8slCSJl6d2kXDJB2sLcypU\\\\\\n\",\n              \"qMDmzZsxNzcnIyOD8RMns3v3birUb4V1w6/QNSqNPDOd1/cu8+yMNzVr1mTn1s1UrlxZ5XNXR+p0\\\\\\n\",\n              \"7RQBJggfQVpaGqdOncoZ6q6trZ2zAoaTk9N/mpulrPT0dM6fP09sbCy6urrUqFEDO7u/7z0ll8v5\\\\\\n\",\n              \"7rvvOHLkCF26dMHX15eAgABsbGzo0aMHhoaGbN++/b1ag2/evMG8jAWmddpSznWa0uX+PLiAlIfn\\\\\\n\",\n              \"GDF8GLNmzSImJibXZ1U3b94kKysLnZL6JKXLsW7lgUXDzmiX/LtbUpEtJ/beeeIu+JAZ84LszHS6\\\\\\n\",\n              \"deuGnp4eR44cYd26dSz8bQlRmdrUH/IDukZ5R24q5Fk8DNzH04AdnD1ziho1aqh8/upGna6dogux\\\\\\n\",\n              \"GCQnJxMXF4eOjg6mpqZiVNFnKiwsLCew/vjjD+rVq4eLiwsnTpygWrVqRb76w1svXrxghddKNm3e\\\\\\n\",\n              \"grFVBfRNypCdlcGrR3eoU7s2I4cNZfv27cDfw9+Dg4M5f/48FhYWjB49mrS0NA4cOKB0eKWlpXHp\\\\\\n\",\n              \"0qWcAReXLl1CkmlSuqGrSvW2aN6bv55d4fLly9jY2GBtbY2DgwN169Zl/PjxODg4ULZsWQDOnj3L\\\\\\n\",\n              
\"0hUrCfi5FwamFmhoapESF0XlyrYsnDaeXr16kZ6ezuLFi1mzZg2tWrWi74CBmFWrT8sJi9DQyH/d\\\\\\n\",\n              \"Qw0tbWp0HkAJfSPad+zEw3t3P9rnkELhRAvsI5HL5Rw5coRVq1Zz4cJ5TExMSc9IR1NDA09PT0aP\\\\\\n\",\n              \"Hv3OIcnCpy87OzvX3KywsDA6d+6Ms7MzHTt2/OBzs5Rx6tQpevTqg21LF2p81Qdjq//7G8vOyuTJ\\\\\\n\",\n              \"hRNc2beKMqX0qVWjGm/evOHQoUOUKlWKuXPncvjwYc6cOZOzanx+EhISckYInj17lps3b1KnTh1a\\\\\\n\",\n              \"tmyJk5MTjRo1wsLSEsc5QSqFtiRJ3JrbkcATATRs2PCddXgrLi6OsLAwsrKyKFOmDOXLl8+zTVRU\\\\\\n\",\n              \"FNOnT2fH7r30XnMKbd3C52wBnF8xhalDer3X/DV18qldO99FvPX/CF68eEEXZ2f0DQzxHDGG7fsP\\\\\\n\",\n              \"53QZPXn0kK2b11OvXj2mT5/O1KlTP9o7c+G/S0hIyDU3y8LCAhcXF1atWkWTJk3+040Z/6vLly/j\\\\\\n\",\n              \"1rM3bScvoZx9ozzPa2qXoJqTC7aN2+O/cBzXbtziwb1Q9PT02LhxI9u2bSMkJCRPcERFRXH27Nmc\\\\\\n\",\n              \"wHry5AkNGzbEycmJefPm0bhx45y5Vqmpqdy5cweZTFPlv2uZTAYaGsybN48yZcpgYmKi9ONdLCws\\\\\\n\",\n              \"sLSypnobN6XDC6BSm178vmIlI0eOFP9HPxGiBVbEIiIiaNqsGcNGfs2Y8ZML3C487CUe7i4MGjiA\\\\\\n\",\n              \"6dOnf8QaCqp6/PhxTivrypUrtGjRIue+WZ9KK1qSJGrY18Gmiyd2zToVur08I53DM/qywWsJCoWC\\\\\\n\",\n              \"kSNHEhwcjJ2dHX/99VdOWAUHB/P69WuaN2+Ok5MTLVu2pEqVKrx48YInT57w5MkTnj59mvN9XFwc\\\\\\n\",\n              \"lSpV4sGjx9T57n9oqbD0U3Z6MqG/9sD36BHi4+OJi4sr9KGtrY2JiQmmpqbvDLgJU77BYfh8zGzt\\\\\\n\",\n              \"lX9NFQoOT+rIjcsXP+v5WJ/KtVMZogVWxIYNG06ffoPeGV4AZcuV58DhY3zVqint2rWjUaO875iF\\\\\\n\",\n              \"4pGZmZlrblZKSgrOzs5MnDiRdu3aFdl9mf6L8+fPE5+chm3Td6/6/paWji41nQfx0/wFPH5wj1Gj\\\\\\n\",\n              \"RjFnzhyCg4PJysqiSZMm2NraMnjwYDIyMvjzzz85dOgQixcvJiUlBTs7O+zs7LC1taVx48b0798f\\\\\\n\",\n              \"Ozs79PT0WLRoEY+eriT2RgAWzXspfQ5xt07g3LU7nToVHsDwd2inpKQUGG4xMTE8evTo7++jX6NX\\\\\\n\",\n              \"WrXbkcg0NDA0KUNsbOxnHWDqRARYEXry5AlXrl5h/fZ9Sm1vaWXNqLETWblyFdu3iwArTm/nZvn6\\\\\\n\",\n              \"+hIYGEjVqlVxcXFh3759ODg4fPJdSF6r11C1vfKLxAJUad6ZjRvmo6utybFjx9DR0cHY2JiXL19y\\\\\\n\",\n              
\"+vRpwsLCsLW1xc7OjjZt2jBixAjs7OywtLTMdRxJkggICKBXr15cu3YNHR0d9EpokXj9CGWa9kCm\\\\\\n\",\n              \"xGAQSaEg+YYfU/dsVbr+MpkMAwMDDAwMqFChQp7nMzMzefToEaGhoQSdCUaRlaX0vt+SZ2Z8lBGj\\\\\\n\",\n              \"gnJEgBWhNWvW0Lf/YJVm9fcdMJhGdasRExNTJEvXqCNJknLWzNPX18fMzOyDB4gkSdy+fTunlXX/\\\\\\n\",\n              \"/v2cuVleXl5YWFh80OMVtbv3HlDdQ7VFX7V0dDEpV5nalSxp0aJFTqvKzs4OU1PTQl/z2NhY5s6d\\\\\\n\",\n              \"y7Zt20hKSsLe3p49e/bg7u5OgwYNSJcrCDvmRbkuE965L0mSiD69kSoVy9KyZUuVzgH+nh/2559/\\\\\\n\",\n              \"cufOHUJDQ3MeT58+pVKlStjb21PG3JzXj29iaFFO6f2mJcSSFPf6k+kmFkSAFalLly7z7Y8/qVTG\\\\\\n\",\n              \"xNQU+zp1uXXrVs6EzS9VUlISO3fuZOWq1bx48QIjo1KkpCRjYmrC2DFjGDp0aKEf2L9Lamoqp06d\\\\\\n\",\n              \"wtfXFz8/P3R0dHBxcWHevHk4Oan3PZcyMzJUulX9W6VNzZgyZQpdunRRantJkvD19WXOnDncvHkT\\\\\\n\",\n              \"fX19BgwYwOzZs3OFviRJxES+QqNEHFHHlmPWZjhaenmHo8vTkok5s5mSbx7hfy640KCLiIjIE1T3\\\\\\n\",\n              \"79/HzMwMe3t77O3tcXFx4fvvv6d69eo5byYPHz7MhBlzsW2hfMg/DT6Mm6sbRkaq3cJFKDoiwIpQ\\\\\\n\",\n              \"cnIyBu/x+YiWljY7duzgr7/+wszMLNfD2Nj4oy8tVBzOnz+Pq5s79vUaMfLbedRv6oRMJkOSJO7e\\\\\\n\",\n              \"vMLhPZuZ//MCdu7YjrOz8jc1fPnyZc7crODgYOrXr4+LiwtBQUFUrVr1k+8aVJapmRnJca8xr1xT\\\\\\n\",\n              \"pXJJMZG5Vo4vSHR0NDNnzmT37t2kpKRQr169nEnQ/34NQ0JCuHfvHlOnTmX69OmMHDOWQ8v7YVC1\\\\\\n\",\n              \"Kca1WqOpo092RgrpTy8SF/oHzs7ObDx2Ptddh+Pi4nKF1NuHtrZ2TlA1a9aMkSNHUqtWrUJDxtnZ\\\\\\n\",\n              \"mVFfjyUi9BJW9o0LPd/0pHiendzPOv+jhW4rfDwiwIqQkZER8fHxKpdLSIgnPd2Mc+fOERMTk+uR\\\\\\n\",\n              \"lJRE6dKl8wTb24epqWmenxkZGanVhfnixYt07dadGYtW08Spfa7nZDIZ9o6NsHdsROjNKwwZOpCt\\\\\\n\",\n              \"WzYXGGLZ2dlcvnw5p2vw1atXdO7cmUGDBrFz506l7gqsjty7ObNm/wFsGrRWuszrp/dQZKTQoEGD\\\\\\n\",\n              \"fJ+XJIn//e9/zJs3jzt37mBsbMzw4cOZOXNmgXPcjh49iqenJ46OjjRr1gwjIyP27trBwIEDiXod\\\\\\n\",\n              \"Tdqr0yQmJmJoYEiHzq0YuGsZsbGxeHt75wqqt12Sbx+9evWiVq1aSt9+5d+0tLTYtX0b7r09aDFp\\\\\\n\",\n              
\"GeZ2tQvcNiM5kZBlkxgyoB/169d/r+MJRUMEWBFq27YNfkd8aNVG+a7AV+FhPLh3l1o1quPm5kaH\\\\\\n\",\n              \"Dh1yfYYml8tzRlT9+xEWFsbNmzfz/Dw9PT3fYHvXo2TJksUSepmZmbj36Mn3C1bmCa9/s3doyPyV\\\\\\n\",\n              \"2xk4sB/Pnj3NCaOEhAQCAgLw9fXl2LFjWFlZ4eLiwpo1a2jcuHGxzs36GB4/fszatWt5GRZOctxr\\\\\\n\",\n              \"DEyUu8g/DNzL2DGj87w+kZGR/PDDD+zfv5/U1FSaNGlCUFAQbdu2fef+tmzZwvTp0/Hz8+O3337L\\\\\\n\",\n              \"uStzZmYm/v7+7NmzJ1fLavu2rSxa+AvVqlXLCaqJEydib29P+fLlP/jfY7t27di5dTMDBg/FpmU3\\\\\\n\",\n              \"bNv2xLDM/30mlpWWwrNzfjw5sROPHq4s/k31dSCFoiXmgRWhV69eYW9vz7XQJxgq2W++6OefeP3q\\\\\\n\",\n              \"BfXq1cPHx4cbN27QsWNH3NzccHZ2fq/+94yMDGJjY/MNvfwe0dHRACoFnpmZ2QcZnbV//34WL1vJ\\\\\\n\",\n              \"0m2HlS4zd8oIGtStjqmJCb6+vly9epWWLVvmzM3Kb0Ta58rPzw+Pvv1QSBI6pUzRNTKh++xNaBWy\\\\\\n\",\n              \"WvuzS0Fc2foLd+/cwtzcHEmS2Lt3Lz///DP37t3DzMyM4cOHM2PGjEKXUpIkiV9//ZU1a9bg7++P\\\\\\n\",\n              \"jo4Ow4cPx8TEhBIlSnD+/HnCwsKoUqUK9vb21K5dOyewbG1tP/rSas+ePWPFypVs3bqdUpYV0C1V\\\\\\n\",\n              \"muyMdCKf3vv7FisTxhUa1p+TT+HaqSwRYEVswICBaJbQZfHy1YW+g3z44B6uXdoT/McfOYuGRkdH\\\\\\n\",\n              \"c+TIEby9vTl79iwtW7bEzc2Nbt26vXf3iTJSU1OVDry3Dx0dHZUCz8TEBG3t3AMNWjq1omOvobTp\\\\\\n\",\n              \"1F3put66coFpo/vSr68HLi4utG3b9pOcm1WUJEli8eLFzJ47n7K1GtBsyPcYmFkRuHwaKW+i6TDh\\\\\\n\",\n              \"FwzN864Cr8iWcz/of9w8uJoTx/2xsrLi+++/x9vbm4yMDFq2bMmCBQvy3GAyv+NHRERw+/Ztfvnl\\\\\\n\",\n              \"F27fvk3FihV5/PgxhoaGZGVlYWVlxVdffcXt27fp3r0748ePL6qX4728XcMxPj6ekiVLYm9vj7V1\\\\\\n\",\n              \"4Svnf24+lWunMkSAFbGkpCSaNW+OY/1G/Lp0ZYHvLm/fusHA3m4sXPgLAwcOzHebxMRE/P398fHx\\\\\\n\",\n              \"ISAggLp16+Lu7o6bm1uxtzIkSSIpKUmlwIuLi8PQ0DBXqB0PCODYlWfolVQ+gCRJolO9ioSHh+X6\\\\\\n\",\n              \"4P9LkZGRwZgxYzh01BfLWo1oM24hGv+/G1BSKLi8fzW3/XdhXaMe1Vp1Q7+0OfKsDF7du8qTMz5U\\\\\\n\",\n              \"sbWla5dO7Ny5k0ePHmFpacmYMWP49ttv850C8ubNm1yfT70dBaipqYmmpiYaGhpMmTKFqKgoDh31\\\\\\n\",\n              
\"Iy0jE/NKVdEqUYKU2Nc8uXMdT09Ppn8/TQxJ/wR9KtdOZYgAK2Jv38WmpKSQmJjEQM8R9Pboj6WV\\\\\\n\",\n              \"Nenp6Vy5dIGtG9dy8fw51q5dS69eyq1UkJ6eTmBgID4+Phw5coRKlSrlhJm63PJBoVDw+vVrHj16\\\\\\n\",\n              \"xOPHj3n69CkLFy4k+EGMyp93uLWowc0b17+4d8yRkZG4u7ujp6fHtVuh9Ft9Ak3tvMP/s9JTeXTW\\\\\\n\",\n              \"n2eXT5KRnICGljbxLx9Tp2Z1bt68iVwup23btixcuBBHR0cAUlJSuH//fq6Qym9Ahb29PTY2NowZ\\\\\\n\",\n              \"M4YSJUqwZs0aevbuQ1K2Ji09RmLr0CTX7zM24iWXjuzmZqAPh328adGixUd7vYTCfSrXTmWIAPuP\\\\\\n\",\n              \"nj59ytatW3nx4iUKhQJraysGDBhA7dq1kSSJkSNHEhsby8GDB7l16xarV6/m8OHDxMbGoqOjQ40a\\\\\\n\",\n              \"NRk9ehT9+vV7724vuVzO2bNn8fb2xsfHB0NDQ9zc3HB3d6d+/foffTCGJEkkJiYSGRlJRETEO78m\\\\\\n\",\n              \"JCRgbm6OpaUlVlZWBAYFcTjkAYZGyreksrOz6VivAtGvX39Rt7q4fv06rq6uDB06lMCTp8guW5dG\\\\\\n\",\n              \"HmOVLn/lwFruHdvB9O++pXPnzjx48CBXy+rVq1e5BlS8fVSoUCHX31RMTAzOzs7UqlWLFStW0P6r\\\\\\n\",\n              \"juhY2eEybuY7p3w8vHKWgwu/4fTJIOrWrfufXgvhwxEB9gko6l/CzZs3+f776Vy9dpU+fQdQo6Y9\\\\\\n\",\n              \"MpmMZ08fs2fXdmwr29KgQX0CAwO5cOFCnhW9JUkqkmBRKBRcvXoVb29vvL29SU9PzwmzFi1a/KcR\\\\\\n\",\n              \"eHK5nOjo6HzD6N8/09TUzAmlf3/95/empqa56uTs0pU6zTrg0iv/btT8hJwOYO+6xVy/ph7/6Qry\\\\\\n\",\n              \"4sUL1qxdy4EDB4mNjUVbWwu7KlX5evQoevbsmas7b9++fYwbN44VK1YQFBTEtu07GLQ2CH0T5df3\\\\\\n\",\n              \"S42PZdvItmhpauSsUPHPh52dXaEDKv7666+cQUYLFixgzZo1rNmxn0ELNio1X/Gi717CLx7n/Nlg\\\\\\n\",\n              \"pestFC0RYJ+AovwlnDx5Eg8PD2bMmkuffgPR08t9K/OsrCyOHPJm8vjRzJ83j4kTJxZJPQojSRL3\\\\\\n\",\n              \"7t3Dx8cHb29vwsLC6NatG25ubrRv3z5n1GBKSkq+IfTvr7GxsZiYmBQYRv/8+r4toePHjzPl2+9Z\\\\\\n\",\n              \"/79TSgf8VM8ejB42iCFDhrzXMYtbRkYGo8d8jc8hH9q69KJttz6YW1ojz8ri4Z3rBBzYxrOHd1m/\\\\\\n\",\n              \"bi3du3dn1qxZ7Ny5k/Xr1zN79mzKlCnDsYAARu25ofKxN/Srz4u/nr/XgKDQ0FA6d+7M1KlTmTRp\\\\\\n\",\n              \"EpIkUb2WPW1HzsDO8d2DPt6SZ2WyqF8rzp05Tc2aqk26FoqGOgWYmAemovv379O3b1+27NxH85at\\\\\\n\",\n              
\"8t1GW1ubHr36UKNGTdy6dqR+/fofvZ9foVAQExODXC6nYcOGlCtXjvv373PhwgWGDBnCmzdv0NPT\\\\\\n\",\n              \"Qy6XA+QbRi1atMj17zJlyhT6jjwhIYHz58+TmJiIvr4+jo6OWFpaKl3vr776CmnKFI7u20Y3jyGF\\\\\\n\",\n              \"bv/HCV9uX7/MEXNjOnXqpNKxPgWZmZk4u3QlU0OHLQHX0CuZO/ibtetCs3ZdeHjnBsNG9Ofrr79G\\\\\\n\",\n              \"kiQqVqyIi4sLOjo6ZGRkIM/Ofq9WvSRJ7/VmIyQkBHd3d5YuXUq/fv0AuHTpEilpGdg6NFF6P1ra\\\\\\n\",\n              \"JWjYuTfr1m9g+bKlKtdD+LKJAFPRggW/8PX4yQWG1z/VtK/N3J9/ZfacOZwMCvogx09LSyMyMrLA\\\\\\n\",\n              \"ltLb71+/fk2pUqXyTcUnfQAAIABJREFUhJKbmxtjxoxBR0eH0NDQnFu+16pVK2d4/vssIhwaGoqX\\\\\\n\",\n              \"lxf79u+nVu26GBuXJjUlhRvXr9ChfQfGjx+Hk5NTofvR0NDg6JEjtGjphEKS6O4xpMCL8il/Hxb9\\\\\\n\",\n              \"OIkTAcfx9/enTp06/PrrrwwePDhPGUmSuHLlCitXrebChQukJCdjaGREm9atGDt2LLVrF7wSw4eQ\\\\\\n\",\n              \"nZ1NbGwsr1+/Jjo6Oufr7j17yZKVYM7q7e/s3q1W25FfNvswuW8nNGR/36zS2NiYmjVrYmFhwbET\\\\\\n\",\n              \"gcS9fIJphSpK1+lN+J/oGxjm6UEozNvVNXbu3EnHjv93u5bHjx9TrlptlUPUuqo9D0LEEk2C6kSA\\\\\\n\",\n              \"qSAmJoajR49w7c5ipcu49ujF7B+n8eDBA6pXr57vNm9XWy/o86R/fk1LS8PS0jJPF17Dhg1zBZWF\\\\\\n\",\n              \"hUWhi9G6u7sDf7ea/Pz88Pb2ZvLkydSvXx83NzdcXV3zvSX7v23bto1vvvmWwSO+5uT5m1hYWuU8\\\\\\n\",\n              \"l5iYgPe+XfQfMJDBgwYyb968Qi9wdnZ2nDsbTPfurhzZu4VufT1p9ZULhoalSE1J5vyZExzYuoa0\\\\\\n\",\n              \"lERsK9vg5+fHL7/8Qq9evRg2bBi7d+9m/fr1VKpUCfh7oqpH335ERUXhMXgEK4ZPwMDAkMSEeE74\\\\\\n\",\n              \"H+arjp2oUaMGu3ftVLoFp/h/7d15XI3p/z/wV6V975xTnRaVVFJKZUtRliSyU2SXkn0XYwbD2I1l\\\\\\n\",\n              \"LFmzkz1bJFuJkPGRLE1UQmnP0r6c9++P+eo3Tds5quFwPR+P86Bz7uu+rnObuV5d933d1y0QIDc3\\\\\\n\",\n              \"t0ogZWRkVPve+/fvoaamBh6PB01NTWhqakJVVRUxjx5hx/k7Ql2bNDA2w+BxU3Dh6G6sX78eT58+\\\\\\n\",\n              \"xblz5/Du3Tu0smiJJ5eOwGnCYqHaDwBxYccxfry3SIGzd+9e/PTTT7h48WKVZ9aVlpZC8gsWEJZq\\\\\\n\",\n              \"0gQlJSUil2MYdg1MBDt37kTYtRvYufeQSOV+WTAXWemp6NKlS7WhlJ6eDgUFhRqvJ/3zT3V19Uad\\\\\\n\",\n              
\"VVhYWIgrV67g9OnTuHDhAoyNjTFw4EAMHDgQpqamVbY/ceIEps+YicOnLsLErObp+1mZGRg1pA8G\\\\\\n\",\n              \"DxqARYsWCdUWgUCA69ev44/NWxARHo5Pnz5CUVEJraysEPf8GVJSUpCXl4eOHTti1qxZ8PPzQ2lp\\\\\\n\",\n              \"KdavX4+1a9fil19+QY8ePdC1azf4TJ2DEeMmVDuxoLS0FAEbV+PcySM4eOAAJCUl6wyk7OxsKCsr\\\\\\n\",\n              \"Vwqkz3//95+amppQVFTE69evER8fX/G6efMm1LQNsGz7UaH/fd5nZ2J0D1tYWlhg4MCBcHd3x//+\\\\\\n\",\n              \"9z8sXLgQWbnvMWLrZSio1T2CLvyQg+Mz+yL20UOh7sX65+oaoaGhMDMzq7LNhQsXMO/XFfBeJ9r/\\\\\\n\",\n              \"H3fOHoFMehyOHhatHNM42DWw79S7d+9gYNRM5HLNjJsjNOQ8lJWVwefzYWVlBVdX10rhJMozwxqT\\\\\\n\",\n              \"vLw8+vXrh379+qG0tBTh4eE4c+YMnJ2doa6uXnGvmY2NDQoLC+E3cSIOnaw9vACAy9PE3qBguHZu\\\\\\n\",\n              \"g2HDhsHEpO5TXZKSkujevTu6d6+6JqKDgwNCQ0PRr18/XLp0CY6OjtDT04O7uzv8/f3Rr18/jB07\\\\\\n\",\n              \"Fr8sWgz/RcvhOXJcjfVIS0tj2tyfAQB9+vaFeYsW0NLSqgigZs2aoUOHDpUCicvlVllFRCAQ4M2b\\\\\\n\",\n              \"NxUBdefOnYq/p6SkoGnTpjA1NYWpqSns7OzwLO4vOPUfWudx+Cc1Dg/WbTti4ZzpkJKSwogRI6Cm\\\\\\n\",\n              \"poYTJ04g9EoY9qyaBLeFOyGnXPMixcX5HxGyfAImT5ooVHgJBALMmTMHYWFhuH37NnR1dSt9XlJS\\\\\\n\",\n              \"gsuXL+PAgQN4+fghst+9AYdf98j9s5iw09iwQrTHDjEMwAJMJFJSUhCUl4tcrry8HN27d0dAQEAj\\\\\\n\",\n              \"tKrxSEtLVwTI5s2bcf/+fZw+fRoeHh4oKyuDiYkJbGzbopW1jVD709Lmw8NrNLZv347ff/9d5PYU\\\\\\n\",\n              \"FBRUjILat2+PRYsW4a+//kJGRgZsbGwwcOBAGBsbIz8/HxkZGSAimLawqDW8/mnqnIUIOXsSq1ev\\\\\\n\",\n              \"rvF6HREhKysL9+/frzSaio+PR0JCAjQ0NCpCytTUFC4uLjA1NYWRkVGVwDt0+ChU1EV/npmyqhrm\\\\\\n\",\n              \"zZsHaWlprFy5Eu7u7pCQkIC9vT0+fvqIIz+PgM2QyWjWvlulm5rLS0uQeP8a/nd8Kwo+ZKNDu7Z1\\\\\\n\",\n              \"1lVSUoJx48bh1atXiIiIqFh1XiAQ4NatWzhy5AhOnTqFli1bwsvLC1yeJu6dO4JeE/yF+i6v4x6j\\\\\\n\",\n              \"4H2W0M8fY5h/YgEmAkNDQ0QeFv50z2dPn8TC3NS4EVr035GUlESHDh3QoUMHrF69Gk+ePEHffv2w\\\\\\n\",\n              \"dNVGkfYzYqwP+nR3wMqVKyEQCJCZmSnUNaSMjAwIBIKKURCHw0FcXBwSEhJgbGwMS0tL2NnZYfv2\\\\\\n\",\n              
\"7Thz5gxsbW3R060XBo8YL3TbJCQkMGy0L7Zs3Qo7Ozu8ePGiUkD99ddfiI+PBxHBzMysIqQ8PDxg\\\\\\n\",\n              \"amoKExMTkWb0KSoqoqiwQKTjBwA52VlwcXHBpk2bKl07k5CQwIbff0cXJyes/n0DjuxfDUPbTmgi\\\\\\n\",\n              \"r4Sywjwk/y8SLVqYYfumddDT00Pv3r3B4XBqnCGbl5eHwYMHQ0ZGBleuXIG8vDwePXqEI0eO4OjR\\\\\\n\",\n              \"o9DQ0ICXlxf+/PPPipFccnIybNu0hbGNPcza1T5pJ/99Dk6umoNfFy/67p8QwDQOdg1MBPn5+Wja\\\\\\n\",\n              \"tCmuR95HUwNDocp8/PgRrcwM8fzZsyqnXsSdiooK7sYmQEWEVTMAwNJIC2VlpSgpKan1GtK/31NS\\\\\\n\",\n              \"Uqp0/W/atGlQU1PD0qVLK97bvHkztm3bhrCwMJiYmODhy/QqI5/a5GRnwbF1c0hJSqJ58+aVRlOf\\\\\\n\",\n              \"X1wut0GuQy78+Wc8S87AhPnLhS5TXFSIsT1s8SD6Ppo1q/109vPnzxEZGfn387aUleHo6FjpXquw\\\\\\n\",\n              \"sDCMGDECV69erTIL8/PqGpaWlpg/fz6OHz9e8fBKLy8veHl5wdLSstp6IyMj0bf/AHQbOxNtXAei\\\\\\n\",\n              \"STVLW72Je4wTq+Zg1DBPrFwh/PdnGp84XQNjASai6dOno7gMWLl2vVDb/7F+LQK2bISRkRGWL1/+\\\\\\n\",\n              \"XT2WQVZWFk+SMkSehm1vbYLz587C2tq6XkEQGxuLnj17Ijk5udL9aXPmzEFkZCQSk5JwJ/aVSPsk\\\\\\n\",\n              \"IpjxlVBcXCxS8H2J169fw8q6NfZdeQg5IRcvDjsbhNiIEFy5fKlB2hAUFIS5c+ciMjKy0iiqW7du\\\\\\n\",\n              \"aNasGfLy8vDixQt4eHjAy8sL9vb2Qq2w8fjxY0ycMhXP4+LQ1s0DOqaWkGrSBLnp7xATdhqFH7Lx\\\\\\n\",\n              \"6+JFGO/t3SDfg2k44hRg3/+z6RvYggULcPniWRzcH1jntpdDziNgy0bcuHEDU6dOha+vL7p164a7\\\\\\n\",\n              \"d+/+By1tfBwuF2nvUkQqU1xcjNycbBgaGtZ7FNOqVSsYGBggJCSk0vtr1qyBlpYW8j7libzP4qIi\\\\\\n\",\n              \"SEtLN/ozqYgI9+/fBwE4ulO407D5nz7i4ObV6Opc9z2Iwho6dCjmzZuHHj16ICkpCStX/v1AyZSU\\\\\\n\",\n              \"FPB4PPz8889ITU3F1q1b4eDgIFR4AYCVlRVuR4Qj8uYNmGtII/X2eSSEHYNMehw2rlyK10mJLLyY\\\\\\n\",\n              \"emPXwESkra2N0NBQuPbsidiYR5g4ZTqMmlW+vpWamoLAnQE4cnAfzp49C3Nzc5ibm2PIkCHYt28f\\\\\\n\",\n              \"hgwZAhsbGyxbtkysFzEdOGAATgUdxpyfhL/36NKFYLRoYd5gjz3x9fXFzp070bdv34r3JCUlERQU\\\\\\n\",\n              \"BG1tPp7EPISlta3Q+4uKvAmreo4M6/LmzRtMmTIF8fHxcO7cCReO7oGqmgYGjplYY5lPH95j+fTR\\\\\\n\",\n              
\"aNfGBps2bcLHjx+xZMmSOu/1q0tJSQkMDQ0hJSWF5s2bQ0pKCr6+vli9enWDPFOtZcuW2LhBuLMV\\\\\\n\",\n              \"DCMqNgL7AmZmZrh39y7UlBXg2sUBg/r2xEL/2fh5/hwM9xiATu1aozDvA6KiotC+ffuKctLS0vDx\\\\\\n\",\n              \"8cGLFy/QtWtXuLq6YujQoYiPj/+K3+bLTZo0CUEHA4W+CZWIELhjC969S4WtrS0OHDhQ7xtYPTw8\\\\\\n\",\n              \"EBUVhdevX1d6X15eHtOmTUPg9s0i7e/ovl2YPGlSvdpUk/LycmzZsgW2trawsrKCubk5MjIy0MLM\\\\\\n\",\n              \"FMd2bcCs4W64ffUiyv9veS8AyM3KwLFdGzF1cBc4O7TH2eBgxMTE4PHjx3BwcPii/3YEAgEiIiIw\\\\\\n\",\n              \"YcIE6OjoYM2aNXByckKTJk3QqlUrrF+//od7ICgjnliAfSEtLS2sXbsWb968wcQJvjA21Iehvg6G\\\\\\n\",\n              \"D/NEcnIytm3bVrESxL/JyclhxowZePnyJaysrODg4ABvb28kJyf/t1+inlq2bAlDQ0PMmeoLYS6l\\\\\\n\",\n              \"7tq2CSWFBUhMTMSKFStw6NAhGBoaYvny5cjKyvqiNigoKMDLywt79uyp8tnUqVMQfvUyXsQ9E2pf\\\\\\n\",\n              \"f96PwuNHD+Dp6flFbanNkydP4OjoiGPHjiE4OBjXrl1DQcHfxyIjIwO6Onz0dumCsGO7MLyLBSYP\\\\\\n\",\n              \"6Axf9w7w6+cIiU/pCLlwDhs3boCUlBQ0NTUrlnNycHDA7t27qz3+Dx48wJgxY2HUrBk4XB509fRh\\\\\\n\",\n              \"ZWUNPp+PSZMmwcjICA8ePMC4ceMQHByMGzduoGnTphgzZgwEAkGDHwOGaXD0nbKzs/vaTRBaTk4O\\\\\\n\",\n              \"/fTTT6ShoUFTpkyhd+/efe0m1am8vJwWLFhABgYGZG3dmgZ5DqfHL1PpTU5xlVf821yaPucnMjAw\\\\\\n\",\n              \"pOTk5Er7efz4MY0bN47U1NRowoQJ9Pz5c5Hb8vjxY9LV1aXS0tIqnx04cID4Orp06dZDepFeUOPr\\\\\\n\",\n              \"5KVw4mlq0aVLl774mFSnsLCQFi5cSFwul7Zv305xcXHUvHlzcnJyIhkZGeJyubR//34qKyurKJOa\\\\\\n\",\n              \"mkqPHz+muLg4+vTpU637f/r0KVlbW9OAAQMoKyuLiIjevHlD9vYdSU/fgCbPW0wnrkbTpXvxdPLa\\\\\\n\",\n              \"nzRj4XJqatiMbGxtKT4+nlatWkUGBgYUFxdHREQFBQXUqVMnmjZtGgkEggY9Fox4EKe+kwXYNyQ9\\\\\\n\",\n              \"PZ1mzJhBGhoa5O/vT9nZ2V+7SdX69OkT9e/fnzp16kQZGRmUn59P3t7epKKqSh5eo2hf0BkKvhxO\\\\\\n\",\n              \"h09dJJ+J00iDw6He7n1qDea0tDRavHgxaWpqUq9evSgsLEykDtTe3p7OnTtX7Wd79+4lVTU1GuU9\\\\\\n\",\n              \"ka7eja0UXBdvRtNgr9GkrqFRY/kvdfPmTTI1NaVBgwZRSkoKRUZGEpfLJT6fT1JSUjR37lwqLCys\\\\\\n\",\n              
\"dz1FRUU0e/Zs0tXVpUOHDpGOji5NnruYol5k0/3E91Ved1/m0Lxf15GKiiqZmJjQ27dvK+0vNzeX\\\\\\n\",\n              \"rKysaPny5fVuGyN+xKnvZAH2DXr9+jX5+PgQh8OhX3/9lT5+/Pi1m1QhOTmZrK2taezYsVRcXFzp\\\\\\n\",\n              \"s4yMDFq5ciV1696d2rZrT85dutK8efMoMTFR6P0XFBTQrl27qGXLltSqVSsKDAykoqKiOsvt3buX\\\\\\n\",\n              \"evfuXePnCQkJ5OLiQvLyCtTc1Ixs7NpSM+PmpK3NJ2tra5o7d67QbaxLTk4OeXt7k56eHgUHBxMR\\\\\\n\",\n              \"UWBgIMnLy5OkpCQ1a9aMkpKSGqw+or9/qZg3bx7JKyjSFP9fqw2uf78WrNhEhkbNqKSkpMr+UlNT\\\\\\n\",\n              \"ycjIiHbt2tWg7WS+feLUd7IA+4a9ePGChg8fTpqamrRu3ToqKCj4qu25c+cO8fl8WrduXaOfXhII\\\\\\n\",\n              \"BHT58mVydXUlbW1tWrp0KWVkZNS4fX5+PmloaFQ5RflvM2fOJBsbG7px4wY9efKESkpKKD4+nrhc\\\\\\n\",\n              \"LuXm5ta7zUFBQcTn82ny5Mn04cMHys/PJzc3N5KQkCAFBQWaMmVKpdOF9VFcXEznzp2joUOHkoqK\\\\\\n\",\n              \"CrVr146MTVrQvYRcoQLsfuJ7atPBgU6cOFHt/uPj40lbW7sihJkfgzj1nSzAxEBsbCz179+fdHV1\\\\\\n\",\n              \"KSAgoMrI579w8OBB4nK5dP78+f+87idPntD48eNJTU2NfHx86OnTp9VuN2XKFFq8eHGt+yovLycP\\\\\\n\",\n              \"Dw/y8PCg8vLyivfHjBlTZ9naJCcnU+/evcnCwoLu3LlDpaWltGPHDlJUVCR5eXlSVVWtMShEUV5e\\\\\\n\",\n              \"ThERETRhwgTicDjk4OBAW7dupYyMDOrbrz/N/22D0OF1P/E9/bZpD3V2cq6xvujoaOLxeBQREVHv\\\\\\n\",\n              \"tjPiQZz6ThZgYuT+/fvUo0cPMjIyqnLhv7F8nqxhZGREsbGxjV5fbdLT0+nXX38lLS0tcnV1pdDQ\\\\\\n\",\n              \"0EojwcePH5Oenl61kzn+qbCwkBwdHSudNkxISCAOhyPydceysjLauHEjcTgcWrZsGRUVFdGZM2fI\\\\\\n\",\n              \"zMyM1NXVSUtLq97HTiAQUExMDM2bN4/09fXJwsKCVqxYUeU0pJycHF179EqkALsdl0EyMjK1nqYN\\\\\\n\",\n              \"CwsjTU1NiomJ+eLvwIgPceo7WYCJoZs3b5KDgwOZm5vTiRMnKo0kGtK/J2t8KwoLCykwMJAsLS3J\\\\\\n\",\n              \"wsKCdu/eXTEZokOHDkJNxsjOziYzMzPasmVLxXvjx4+nhQsXCt2OmJgYateuHXXu3Jni4uIoIiKC\\\\\\n\",\n              \"7O3tqUWLFmRkZER8Pp9cXFy+eDJOUlISrVixgiwsLEhfX5/8/f1rDJGioiJq0qSJSKcPP780OFxK\\\\\\n\",\n              \"T0+vtS1BQUGkq6sr0vVMRjyJU9/JAkxMCQQCCgkJIRsbG7K1taWQkBChrkvdv3+fJkyYQK6uPcnF\\\\\\n\",\n              
\"pQeNHj2GwsLCqoRgbZM1vhUCgYDCwsLIzc2NtLS0aPHixbRhwwZyd3cXqnxiYiLx+Xw6e/YsEf0d\\\\\\n\",\n              \"GBoaGpSZmVlruYKCAlqwYAHxeDzatWsXxcTEkLu7OxkYGNCKFSuIz+eTmpoazZs3T+RRckZGBm3d\\\\\\n\",\n              \"upU6duxIHA6H/Pz8KCIios5fUl68eEGSkpJ0Oy5DpPC6l5BLcvLydU7XJyLavHkzmZiY1Bl2jHgT\\\\\\n\",\n              \"p76TBZiYKy8vpxMnTpC5uTk5OjpSeHh4tdtFRESQrZ0dGRoa0aJfl1PQqbN07PR5Wv37JrJsZUXN\\\\\\n\",\n              \"TUzo+PHjRPTfTtZoKM+ePSNfX19SVVUlWVlZCgsLE6pcdHQ0cblcunfvHhER+fn5kb+/f43bX7t2\\\\\\n\",\n              \"jZo3b04eHh4UHR1NY8aMIU1NTdqwYQNdunSJlJWVSVlZmYKCgoRu+6dPn+jQoUPUq1cvUlFRoWHD\\\\\\n\",\n              \"htH58+dr/MWhqKiILly4QKNHjyZTU1OSkZEhAKSkrEIbA0+IFGC7T4RSM+PmQv87//zzz2RnZ/dN\\\\\\n\",\n              \"zYxlGpY49Z0swL4TZWVltH//fjIyMqIePXrQ/fv3Kz47c+YMcXk82nf4OGV9KqGc/LJKr+y8UroQ\\\\\\n\",\n              \"ep309PRp1KhRxOPxvspkjYaQmZlJ7du3JyUlJXJxcaFLly7V2TmfO3eO+Hw+JSQk0OvXr0lDQ6PK\\\\\\n\",\n              \"KCMrK4vGjh1L+vr6dOjQIZo9ezZpaGjQwoUL6f379xXT5Pl8Pj169KjOdhYXF9P58+dp2LBhpKKi\\\\\\n\",\n              \"Qm5ubnTo0KFqR0LJycm0bt066t69O3G5XJKQkCApKSlq2rQpDRkyhA4fPkz5+fm0Z88e6ty9p0gB\\\\\\n\",\n              \"1qu/B/3+++9CH1+BQEC+vr7UvXt3oW5vYMSPOPWdLMC+M8XFxRQQEEA6OjrUv39/OnHiBHG5XLoe\\\\\\n\",\n              \"ea9KcP379TgukTQ4HNq8efPX/hr1EhMTQ7q6urRnzx6ytrYmc3Nz2rlzZ623IWzbto1MTU0pKyuL\\\\\\n\",\n              \"pkyZQrNmzaIbN27Q1GnTqLOTE6mpq1PHjh1p7ty5Faf2UlNTSSAQkL+/P8nJyVHHjh0rVsOozucZ\\\\\\n\",\n              \"hH5+fsTlcqljx44VMwg/KywspGvXrtHEiRPJwsKCZGVlK6bg29ra0qxZsyg6OrraU4r5+fnE5fJo\\\\\\n\",\n              \"1/HLQoXXwfPhpKamTjk5OSId37KyMhowYAB5eno22vVX5usRp76TBdh3qqCggNatW0dqamq0bOXa\\\\\\n\",\n              \"OsPr8+voyWCya9Pmaze/3jp06EDnz58ngUBA165dI3d3d+LxePTLL7/UuCKIv78/OTg40PoNG0hJ\\\\\\n\",\n              \"WZmMTUxpxvwltGLjDlqy5g9ydR9A8goK1K//AHr37h0VFxdT3759SUZGhiZOnFjj7MeYmBjy9/en\\\\\\n\",\n              \"pk2bUsuWLWn58uWUmJhIAoGAXr16RQEBAdS7d2/S0tIiSUlJkpSUJC0tLXJzc6OtW7dSamqq0N87\\\\\\n\",\n              
\"NDSUuDxN2nMqrNbwOnTxFmlp8ytOG4uqsLCQOnfuTFOmTKkywhUIBPTx40fKzs7+T2bKMg1LnPpO\\\\\\n\",\n              \"FmDfsbS0NFJVU6OklCyhAyzzYzEZGBhWOgUpjgIDA6lPnz6V3ouLiyM/Pz9SU1OjMWPGVJnRV15e\\\\\\n\",\n              \"TlZW1qSjq0/7TobQ05RP9Cw1r9LrzpNk8pkym/SbNq24/hQYGFil/s8zCC0tLUlfX5/mzZtH9+7d\\\\\\n\",\n              \"oxs3btDMmTOpdevWJCcnR1JSUiQtLU0tWrSg8ePH0+XLlyk/P79e3/3ChQukweHSgGGj6dCFiErB\\\\\\n\",\n              \"FRR6lzxH+5K6BoeOHj1ar3o+Lzn122+/EdHfpzoXLFhAPE1NUlRUJBVVVZKTk6OhQ4fRrVu3xOZ6\\\\\\n\",\n              \"6o9OnPpO9kTm79iuXbtw9fpNbN9zQKRyK39bAiotwtq1axunYf+B/Px86OrqwsfHB49jY5GXlw9l\\\\\\n\",\n              \"ZSU4OzlhwIABOHXqFLZu3Qpzc3PMmjULPXv2xM6dO7Fm3XocDL4CDQ6v1v0f2hOAjSuXICTkIpyc\\\\\\n\",\n              \"/n7AZFZWFo4fP44jR47g+fPncHV1RdOmTREfH4979+4hLS0NEhISUFRUhJWVFdzc3NCnTx9YWFgI\\\\\\n\",\n              \"/aBIYaWlpWH37t3YvmMnJCQkoayqirxPn1BSXARfn/Hw9fWFnp5evet59+4dOnbsCGvr1oi4FYEB\\\\\\n\",\n              \"Q7zgNcYHzU3MAAAf3ufi1LHDOLx3B4wMDXDy5MkGexYc0zjEqe9kAfYdW7FiBTJzPmDR0hUilTu4\\\\\\n\",\n              \"PxCnjx3B5MmToKioCEVFRSgoKFT6U1FREXJycg3e8TaE3NxcTJ4yFefOnoV7/0Fw7d0Xyiqq+Pjh\\\\\\n\",\n              \"PUIvnsPli2cxaOBArFu3DiEhIVi/fj3y8vKQ+u4djpy7BpMWFkLVM3vCKDg5tIe+vh7279+PyMhI\\\\\\n\",\n              \"GBoaoqSkBCkpKSgrK4NAIICOjg7s7e3Rr18/dOnSBXw+v5GPwP9XVlaGpKQkfPz4EcrKyjAyMoK0\\\\\\n\",\n              \"tHSD1uHnNxFh164j6OwVaGppV7tNeXk5lv40G08ePUB4+E32vLFvmDj1neyJzN8xaWlplJaWilyu\\\\\\n\",\n              \"tKQE0dHR8Pb2hoyMDKSkpCApKQkiQnl5OUpLS1FUVISSkhLIy8tXCrXqgk7U9z7/XUFBQeSAzMzM\\\\\\n\",\n              \"hJOTM9o7OuH+00SoqFT+bb+HWx/8vGw1li+aDxeXHrh+/Rq0tLQwePBgNGtuJnR4AYDXOD9MGDEA\\\\\\n\",\n              \"ZaWlKCsrg7S0NBISEmBubo7JkyfDzc0N7du3/6qddZMmTWBiYtJo+w8NDcXl0Cs4ExoBDQ63xu2k\\\\\\n\",\n              \"pKSwZNUGzJ3qgzlz5iAgIKDR2sT8OFiAfcdMTExw9tx5kcs9evgn5s6dg0GDBuHdu3dITU2teP3z\\\\\\n\",\n              \"53fv3kFGRgYcDgc8Hg8aGhpQU1ODiooKlJWVIS8vDzk5OcjIyKC4uBj5+fnIzMzEq1evUFBQgPz8\\\\\\n\",\n              
\"fOTn51f8/d/vFRYWQlZWVuggVFBQwNGjQXDp1RcLf11Z4/dTV9fA2j924Jd502FlZY3373NBkMDI\\\\\\n\",\n              \"8ZNFOk627eyhrKIKI4Om8PLygrOzM1q2bAkpKSmRj7m42rBxI6bMnl9reH0mISEB/0XL0b2jNVau\\\\\\n\",\n              \"XAk1NbX/oIXM94ydQvyOlZaWwsDAACfPXYZ5S+FGFu9zc2FraYK//voLmpqatW4rEAiQk5NTbbj9\\\\\\n\",\n              \"8+e0tDSoqKiAz+dDR0en4vXPn/l8Pvh8PmRkZCrtv6ioqM6g+/xnbGwsIm/fwY17sUKN3MrKytCm\\\\\\n\",\n              \"pSE+fngPRSVlbD90GlY2bYQ6Tp95e7rDsoUJbGxsIC0tDRkZGUhLS1d5ifq+tLT0N3l69p8SExPR\\\\\\n\",\n              \"tm073Il5CTl5eaHLTfMdCWdHe8ycObMRW8d8KXHqO9kI7DsmLS2N8ePHY9PvaxCwex8kJCTqLLNt\\\\\\n\",\n              \"y0ZwOFyhHikvKSkJLpcLLpcLKyurGrcTCATIzs6uEm7Pnj3D1atXK4IuPT0dqqqqVcLt3z9ra2tX\\\\\\n\",\n              \"ex1n4MBB8J44TeiOv0mTJvCZNA2Xz5/G27cpEJSXC1Xun8rKypCWloaYmBiUlpaipKQEpaWlVV6i\\\\\\n\",\n              \"vl9aWgopKal6B2Fjvn/lyhXYOzqJFF4A0M3VHRFh51mAMfXGAuw7N3v2bDg4OmLtqt8wd/7PtYbY\\\\\\n\",\n              \"8aOHcXj/XgwY0B9WVlZYsWIFvL29hQq+2khKSoLH44HH48Ha2rrG7QQCATIzM6uM5GJjYxEaGloR\\\\\\n\",\n              \"fOnp6VBXV68Ubtra2rhw4TxWbNwpUts8vEZj3cqlkJOVQ9yzWLRu017osuXl5Uh5/Qr7du9Ay5Yt\\\\\\n\",\n              \"Raq3LkSEsrKyegdhXe/n5+d/8X7S09PR3sFJ5O+mrKyCT58+NejxYn5MLMC+c6qqqrgSGoqebm54\\\\\\n\",\n              \"GvsYk6bNRLv29pVC6dnTJ9i1fSuuh13GlSuhsLS0hK+vL3x8fHD48GHs3LmzUScCfCYpKQktLS1o\\\\\\n\",\n              \"aWmhdevWNW5XXl6OzMzMSqcpX758CVlZWSgpK4tUJ5enCRIIABAO7NwCz5HCB3bEtVDo6uo0eHgB\\\\\\n\",\n              \"f18v+jzS+VYdOXIER4+fFrncxw/voaKi0ggtYn40LMB+ADo6Orhz+zZ27tyJKRPGQU5OHuYtLSAp\\\\\\n\",\n              \"JYWkhJd4++Y1fHx88ODBg4rrXtbW1oiKisLmzZthb2+P2bNnY86cOd9Eh/rx40ckJiYiJiYGd+/e\\\\\\n\",\n              \"RWxsLJKSklB9OwtDAAAbhElEQVRcUiLyvgQCAQQCgrGxMZKSXuHe7XB0cHQWqtz2javh1qPbF3yD\\\\\\n\",\n              \"74ODgwOmTJ2Kgvx8KIgw0zLs0nn0dOnSiC1jfhRsEscPRiAQ4Pbt20hOToZAIACfz4ezs3OtwfTq\\\\\\n\",\n              \"1StMnDgRqamp2LVrF9q1a9fo7SSiiutkd+/eRXR0NJ4/f46UlBQUFxdDSkoKZWVl4HK5MDIygpWV\\\\\\n\",\n              
\"FU6dOo1j58LQQsgJKwDwMPoeRgxxR8C2bVBUVMSYseNwMPgKmpua11hGIBBg1eJ5iImOQm5ONtzc\\\\\\n\",\n              \"3LB27dof8gbdPn37wrGrG4aOHCfU9unvUuHa2Q7Jr16xUdg3Spz6zm97mhPzxYgId+7cwfARI6Cr\\\\\\n\",\n              \"qwtlZWXw+XwMGDAQhYWF8PLywqhRo+Di4lLnqMrQ0BAhISHw9/dH3759MXPmTOTl5TVIO8vLy/Hy\\\\\\n\",\n              \"5UucPHkSM2bMgLOzM/T09CArKwsDAwO4urpixYoVePr0KUxMTDB79mwEBwfj2bNnKC4uRnp6Oq5f\\\\\\n\",\n              \"v4527dpBTk4WgTu2iFT/jq0bIS8vj7S0NAQEBMCkuTHGDemFI3t3IO/Txyrbxz76E9O9h+HlsxiE\\\\\\n\",\n              \"37yBp0+fQlJSEpaWlrh48WKDHBNxMnPGDGzdsBoZ6Wl1bisQCPDbIn+MHDGChRfTINgI7Dv05s0b\\\\\\n\",\n              \"DB4yBNlZ2fD28UOf/gOgpqaO/Lw8XAkNwa4dASgpLsLJkydhaWkp0r6zsrIwe/ZsREREICAgAD17\\\\\\n\",\n              \"9hSqXFFREf766y9ERkbi3r17ePr0KZKTk5GbmwsJCQkQEdTV1WFgYAALCwt06NABdnZ2MDMzq/F+\\\\\\n\",\n              \"oYSEBAQEBGDfvn3o2LEjPD09MWnSZIRF/gkdPf0625SU+BJ9unXEnj17MHnyZOTk5GDUqFFwcnJC\\\\\\n\",\n              \"8NmzuHb1GpxcekKDw0NxcTHu34lAaXERJk30w7Rp0yD/j9l3169fx/jx4+Ho6IgNGzaAw+EId0C/\\\\\\n\",\n              \"A0uXLsWRoGMIPBoMXb2m1W5TWlqKn+dMwZukl7h6NazSsWO+LWLVd36F9Rf/E+K0IGVDSk5OJn19\\\\\\n\",\n              \"ffptxRp6n19KHwvLq7w+FJTRzsD9pKmpKdSzq6oTGhpKRkZGNHz48EqPA3n//j1dv36dFi9eTO7u\\\\\\n\",\n              \"7mRqakpKSkokISFBEhISJCsrS02bNqWuXbvSrFmz6NSpU5SQkCD0quXl5eUUEhJCvXr1Ii6XS3Pn\\\\\\n\",\n              \"zqXExERKTU0lV1dXkpGRIX0DQ3rw7BW9ySmu8RUVE08Ghka0ceMmcnJyopEjR9Lbt29p+fLlpK+v\\\\\\n\",\n              \"Tx06dKCNGzfSli1byMjIiLy9vUlWVpYKCwtrbFteXh5Nnz6d+Hw+nTx58ouOqzgSCAS0bt06UlVT\\\\\\n\",\n              \"o6Ejx9KFa1GUmFFASZmFdP9JEs1d+Cvp6TelPn37CfXkZ+brEqe+kwXYd0QgEJCtrS2tWLWu2uD6\\\\\\n\",\n              \"92vfwaOkr69f63Oyaqrn7du3FBgYSK1btyZpaWlSV1cnaWnpiqBSVVUlCwsLGjRoEK1atYru3LlD\\\\\\n\",\n              \"79+//+LvlpOTQ+vXrydjY2OysbGhwMBAysnJoePHj5OLiws1adKEtLS0KCgoiJYuXUo8nib9umo9\\\\\\n\",\n              \"PX2VUSm4YhPTaNFva0hTU4tUVFTIxsaGxo8fX+m5VmVlZXT27Fnq2bMn8Xi8iodYtmrVSqhV+m/f\\\\\\n\",\n              
\"vk1mZmY0ePBgSktL++LvLG7S0tLot99+IwMDw4pV9pWVlcnb25v+/PPPr908Rkji1HeyAPuOXL9+\\\\\\n\",\n              \"ncxbWtCHgjKhAuxjYTm59HCl/fv3V7u/srIyiomJoQ0bNtCwYcOodevWxOFwSFJSkgCQjIwM6ejo\\\\\\n\",\n              \"kJWVFXE4HLKwsKAbN2406DOgHj16RD4+PqSmpkZeXl50+/ZtioyMpAkTJpCGhga1adOGNDQ0aNas\\\\\\n\",\n              \"WRXP47p58yYpKSmRuoYGKSgoUpduPaj/IA/q0t2VVNXUyGv4CLpy5QoZGBiQqqpqpRHkv7148YL0\\\\\\n\",\n              \"9fVJRUWF9PX1ycfHR6jvV1hYSPPnzydNTU06dOjQD/cokdLS0lpHq8y3S5z6TnYN7DsyePAQdOzk\\\\\\n\",\n              \"DJ8JE4Uuc+nieaxbvRzLli1DREQEHj58iJcvXyItLQ35+fkAAEVFRfD5fJiYmMDW1hZOTk5o165d\\\\\\n\",\n              \"pQvxpaWlWL9+PdauXYsFCxZg+vTpaNLky+7SKC0txZkzZ7BlyxYkJibCz88PPXr0wOXLl3HgwAFI\\\\\\n\",\n              \"S0tj1KhRKC4uRkBAAHbv3o0+ffqAiHD69GkMHz4choaG2LNnD5o1a4bo6Gh8+vQJKioq6NChAwDA\\\\\\n\",\n              \"xcUFLi4ukJSUxJ07dxAWFgY5Oblq29OlSxf4+/vj9OnTOHXqFJSVlTFhwgR4e3vXudzWgwcPMG7c\\\\\\n\",\n              \"OBgYGGD79u3Q1dX9omPCMP8Vseo7v3KANhpx+i2ioaioqFDi6zShR18fC8spN6+EpKWlSUpKing8\\\\\\n\",\n              \"Htna2tKIESPojz/+oCdPnoj8yPgXL15Q165dyc7Ojh4+fChS2dTUVFqyZAnx+XxycnKiffv20fbt\\\\\\n\",\n              \"26lTp07E4/Fo6tSpFB0dTe/fv6fBgweTra0tJSYmEhFReHg42dvbk7a2NrVu3brGdqemplLLli3p\\\\\\n\",\n              \"559/JoFAQOXl5eTh4UHDhg2rsYyzszNdv36dXrx4QXp6evTgwQPy9vauGBXW9bDG4uJiWrJkCXG5\\\\\\n\",\n              \"XNq1a9cPNxpjxIs49Z0swL4TAoGAANQ4caO2l46ODr1+/bpB2xIYGEg8Ho/mzZtX6xOGBQIBXbt2\\\\\\n\",\n              \"jZydnUlBQYFcXV1p6dKl5OnpSSoqKjRw4EAKDg6m4uJiIiKKiYkhExMT8vPzo8LCQoqJiaFevXqR\\\\\\n\",\n              \"oaEhbdiwgdTV1Sk+Pr7aut68eUMmJia0bNmySu8XFBSQvb09LVy4sNpynwNMIBAQj8erOFY5OTm0\\\\\\n\",\n              \"ceNGMjU1pVatWlFAQAB9/Pixxu8aExNDdnZ21L17d0pKSqrtEDLMVyNOfScLsO+IgoICpWS8Fym8\\\\\\n\",\n              \"PhSUkZqaGmVnZzd4e9LS0sjT05OMjY3p6tWrlT7Lz8+nlStXko6ODikoKJCNrR31dOtFHR0cSUVF\\\\\\n\",\n              \"haysrOjAgQOVRit79+4lLpdLBw4coMTERBoxYgRpaWnRpk2bqKioiPr3709Lliypti1JSUlkZGRE\\\\\\n\",\n              
\"69atq/bzjIwMMjY2pj179lT57HOAERH169ePgoKCKn0uEAjo6tWrNHDgQFJXV6dJkyZRbGxstfWU\\\\\\n\",\n              \"lpbS6tWricPh0ObNm0Ue4TJMYxOnvpMF2Heks5MTHTxyXKQAux4RRUZGRo3akZ4/f5709fVp7Nix\\\\\\n\",\n              \"9ODBA5o9ezYpKyuTkpIyTZ02k548f0n5xYKKV/aHAtq99wBZWraiUaNG0YcPH2j8+PFkZmZG4eHh\\\\\\n\",\n              \"NHXqVNLQ0KAlS5ZUjHjOnz9PzZs3r3biQHx8PDVt2pS2bNlSazvj4uJIU1OTwsLCKr3/zwBbvXo1\\\\\\n\",\n              \"TZs2rcZ9vH37lhYvXkx8Pp86d+5MQUFBFaPHf9fVsWNHcnR0pL/++qvOY8gw/xVx6jtZgH1HgoKC\\\\\\n\",\n              \"yMm5q0gB5jViFK1evbpR21VeXk4nT54kQ0NDkpCQIBMTE1JSUqaw6xGVguvfr4ycT9TdpQdpa2vT\\\\\\n\",\n              \"gAEDyN/fnzQ0NGj69OmUnp5esf/8/HwyNDSkK1euVKn72bNnpKurSzt37hSqreHh4cTj8SqNoP4Z\\\\\\n\",\n              \"YLdu3aI2bdrUuZ+SkhI6ceIEdenShbS1tWnhwoWUnJxcaZuysjLatGkTcTgcWrNmTcUsSob5msSp\\\\\\n\",\n              \"72QB9h0pLi4mHR0dOn0uRKjwuhX1gNTV1SkzM7PSfgoKCighIYHi4uIoKyvri9vzz3u3TExMqFOn\\\\\\n\",\n              \"TqSkpERKSkp07MSZWsPrnyFmYGBIampqNHLkyGqvHS1YsIA8PT2rvB8TE0N8Pp8OHDggUrsPHTpE\\\\\\n\",\n              \"BgYGlJqaSkSVA6ygoIAUFBQoLy9P6P09e/aMpk2bRhoaGtS3b1+6fPlypRFvQkICdenShdq2bVvj\\\\\\n\",\n              \"qUeG+a+IU9/JAuw7c+vWLeLxeHT+Ulit4RV590/i6+jQiRMnKsrGxMSQr68vqampUVMDAzI2bk4q\\\\\\n\",\n              \"KirUpWtXOnnyJJWUlAjVhs/7UVZWJktLS+Lz+WRhYUFr1qyh48ePk4WFJeUVlQsVYPnFAtq0JYC6\\\\\\n\",\n              \"detebV3Pnj0jDodDKSkpld5/8OABaWlp0bFjx77oOC5dupTs7OwoLy+vUoAREbVv355u3rwp8j7z\\\\\\n\",\n              \"8vJo165d1Lp1azI2NqZ169ZV/IIgEAhox44dxOVyaenSpTUe68zMTFq1ahU5d+lCrW1sqGNHB5o6\\\\\\n\",\n              \"dSo9e/bsi74nw/ybOPWdLMC+Qzdu3CAej0dDPIbSlWsRlW5svn3vIY319iENDY2Kzr28vJxmzpxJ\\\\\\n\",\n              \"fB0dWrjoV4pPfFuxfdaHQgrcf5js7R2odevWVYLis5KSEjp27Bh16NDh7wBs2pS4XC7NmDGDHj58\\\\\\n\",\n              \"WDEZw8PTkzZs2iJ0eOUXCyg9+yOpq6tXqVsgEJCzszNt2rSp0vtRUVGkqalJZ86c+eJjKBAIaMyY\\\\\\n\",\n              \"MdS3b19ycnKqFGAzZ86kFStW1GvfUVFRNHLkSFJTU6PRo0fTvXv3SCAQ0OvXr8nNzY2sra0rrV5R\\\\\\n\",\n              
\"UFBAvr6+pKqqSsNHjqZTZy/Szch7FHLlOs2dv5C0tLSoS9eubHYjU2/i1HeyAPtOfT5919zEhLS0\\\\\\n\",\n              \"tMjU1Ix0dHVJT0+Pli5dSu/evSOivzvTSZMmkX1HR0pOzap1tuKiX3+j5s2bVzrlmJqaSosWLSIN\\\\\\n\",\n              \"DQ3icrmkoKBAgwcPpgsXLlQ7irCytqbb9/4UKcDyiwVk39GBIiIiKu3rwIEDZGNjU+naUUREBPF4\\\\\\n\",\n              \"PLp48WK9j2FxcTF17dqV9PT0KgXYiRMnyN3dvd77J/p7RLVmzRoyMjIiOzs72r17N+Xl5dGBAweI\\\\\\n\",\n              \"x+PRggULKDs7mxw7daJBgz0o6W1Gtf8+WR8Kadny1aSjo8MmhTD1Ik59Jwuw71x5eTm9efOGnj59\\\\\\n\",\n              \"SsnJyVUmCpw7d47MzFrQm7Qcoa6bTZsxmzw9PenWrVvUo0cPkpGRITk5ObKxsaGdO3dSbm5ure1p\\\\\\n\",\n              \"0aIFRf8vVuQAc+7StdLswJycHNLW1qZ79+5VvHf16lXi8XhVZhHWR25uLikoKNCUKVOI6O/AP3Hi\\\\\\n\",\n              \"BKmqqVHnzk7k6NiJPDw86dy5c/VaQuvzIsXu7u6koaFBM2bMoFu3btGAAQNIU1OTBnsMFeoev83b\\\\\\n\",\n              \"dlCzZs1qvfeOYWojTn0nC7AfnItLD9oZuF/oWYtv0nJIXl6epKWlicPhkL+/P718+VLo+hwcHens\\\\\\n\",\n              \"hcsihVdeUTmZmJhWWtnDz8+P/Pz8Kn6+ePEi8Xg8Cg8Pb9DjQ0TUoUMH4nA4tGTJEmrZsiWZm7ek\\\\\\n\",\n              \"VWvW0/lLYXQx9Bpt3raD2rZrTwYGBnTw4MF615eUlEQLFiwgTU1N6tixI6moqFJa9ieh/41ce7pR\\\\\\n\",\n              \"YGBgA3xz5kckTn0nC7Af2IsXL4jH41FGbr5IU++HjxxN3t7eX7Qk0u+//05Dhw0XKcBu3oqiZs2a\\\\\\n\",\n              \"Vczcu3fvHmlra1NOTg4REQUHBxOPx6OoqKgGPT6fOTs707hx40hFRYVOnb1Y42LJ18LvkKGhEa1Z\\\\\\n\",\n              \"s6ZB6i0qKqI+ffqS36SpIv37HD99juyEmOrPMNURp76TPZH5B/bnn3/CoVPnGhexrYlbL3dkZmZB\\\\\\n\",\n              \"QkJC5DrHjh2Ly5cuIiMjQ+gyGzesA4fDQWpqKsrKyuDn54c1a9ZAXV0dx48fx4QJE3Dp0qWKhXob\\\\\\n\",\n              \"2sePH3H27DlcuX4LLj161vi927Zrj9BrEdiydSuCg4PrXa+srCyexz3HyNFjRSrn0qMnkl+9QkpK\\\\\\n\",\n              \"Sr3bwDDfMhZgP7D8/HwoKCiKXE5RSalipXpRqaurY/z48fAePQIlJSV1bn/q5HFERoSjVatWsLa2\\\\\\n\",\n              \"hpubGxQVFTFixAgcPHgQM2bMwJUrV2BnZ/dF7RFGRmYmFvy8GC0t6n56NV9HB79v3IJlv/0GaoAH\\\\\\n\",\n              \"PeTm5EBbmy9SGSkpKWhpaSM7O7ve9TPMt4wF2A9MVVUVuTmid3I52dlQVVX94npXrlwJNXVVDOzb\\\\\\n\",\n              
\"Gylv31a7TWlpKQK2bsaUiRNgaWWN4OBgWFlbIzw8HM+fP8fw4cMxf/58XLt2DVZWVl/clrqkpKQg\\\\\\n\",\n              \"NycHw4aPFLqMS4+eyMnOQXR0dL3rl5WVFSro/62oqAiysrL1rp9hvmUswH5gnTt3RtSd28jNzRWp\\\\\\n\",\n              \"XPCZk3Bx6f7F9TZp0gTHjx2DvX17dGjbGsM8BiH4zClE3bmNG9evYemSX9CiuSGCz5zGjVtROHfx\\\\\\n\",\n              \"Cp78lYQBgzzA4XDQqlUrnDlzBjIyMnj27FmDjHRqcu7cOfR0c4eysrLQZSQlJTF0+EicPHmy3vW3\\\\\\n\",\n              \"MDfH3ajbIpV5l5qK7Ows6Onp1bt+hvmWsQD7gfF4PPR2d8fhg/uFLpPy9i0iI8IxYsSIetUtJSWF\\\\\\n\",\n              \"5cuXIzk5GW49e2DFsiXwGjoYK35bitzc9wi+cBkhV67DxNQMAKCkpISx3r6IuPMAiUlJmDdvHnbs\\\\\\n\",\n              \"2IGlS5fC0dERd+7cqVd7apKZmQlDIyORy+nq6iIrK6ve9ftNmIA9u3aIVGbf3t3w9PSEoqLop4cZ\\\\\\n\",\n              \"RpywAPvBzZg+HX9sWIuEhJd1blteXo5ZMyZj3LhxUFJSapD6lZSU4OrqirS0NNy59z9cvnoT6zb8\\\\\\n\",\n              \"UeP1Jm0+HxdDr2PTpk2wsbHBw4cP4evrC09PTwwaNAgvXrxokHZ9Jisri+KiIpHLFReXNMgpvP79\\\\\\n\",\n              \"+yPhZTwib4ULtX1WZib27t6BSZMm1btuhvnWsQD7wbVp0wZLlixB314ueBL7uMbt8vPzMXbkMJQW\\\\\\n\",\n              \"F2PFihUN2obt27dj2PBR0NTSEmp7AwNDuPftj8DAQEhJSWH06NGIj49H27ZtYW9vj6lTpyIzM7NB\\\\\\n\",\n              \"2taiRQtERkaIXO7P6HswNTWtd/3S0tLYu3cvxo4chv89/LPWbbOzszFkYB+MHTu2Ua8LMsw342vP\\\\\\n\",\n              \"428s4nQvw7fg0KFDxOFwqFdvdzoZfIESX6fRm7Qciop+RFOmzSQOh0Ojx4yhoqKiBq23vLyceDwe\\\\\\n\",\n              \"PYyNE+lepxu37lLz5s2r7C8jI4OmTp1KHA6Hli9fXu8VKUpKSkhZWYXuPogRum1Jb9JJTU2tXiv5\\\\\\n\",\n              \"/9vp06eJy+XSzNnzKDYuocrN5WvX/0EGBoY0Z86cL7o/j2E+E6e+U4KoEa+Af0Vt2rTBgwcPvnYz\\\\\\n\",\n              \"xEp+fj6CgoKwc+cuvHz5AiUlJeDxeBg8eDAmTpwIoy+4FlSX9+/fw8DAAG/TRZtIUlZWBp6aAkpK\\\\\\n\",\n              \"SiApWfVEwosXL/DTTz/h7t27WLZsGUaOHAkpKakvaqOxsTFsbNsg8MARoe59W/LLT8jOSse+vXu/\\\\\\n\",\n              \"qL6avHz5Elu3bsXBgwfR3MQUHA4H+fkFiHn0EN27u2DKlMlwcnJq0DqZH4849Z0swJivKjMzEy1a\\\\\\n\",\n              \"tMCrFNFO+RERNJRlUVRUhCZNmtS4XVRUFObMmYO8vDysWbMGrq6uIrexU6dOyMzMRJ9+A/HLkmW1\\\\\\n\",\n              
\"hljg7h3YsG41oqKiwOeLdv+WsAoKCnD//n18+PABCgoKaNWqFbS1tRulLubHI059Z83/5zPMf0BN\\\\\\n\",\n              \"TQ0FBQX48OGDSPeWpaakQEVFpdbwAgB7e3tERkYiODgYU6dOhaGhIdasWYPWrVvXWUdOTg5CQ0OR\\\\\\n\",\n              \"lpaGbt264fzZ03j8+BGmTp+Fzk5dKgXZnw+isWPbZkTfv4uwsLBGCy8AUFBQgLOzc6Ptn2HEBZvE\\\\\\n\",\n              \"wXxV0tLS6NOnL4KOHBKp3MEDe+Hh4SHUthISEhgwYACePn2Kfv36oWfPnhg9ejTevHlT7faxsbEY\\\\\\n\",\n              \"O3YsmjVrhqNBx+DQqTPyC4sASOD506fw9R4N21Yt4OU5ECOGDUbHdjYYN2oYWlu3QnR0NExMTET6\\\\\\n\",\n              \"LgzDfBk2AmO+ukmTJmLixEkY7+sn1HWqkpIS7AvchUshISLVIy0tjcmTJ2PkyJEVozBfX1/Mnz+/\\\\\\n\",\n              \"YvQXHBwMHx8fTJs5G7HPX4DH41WUJyLcigjH6pXLkffpEzwGD4KcnBz4fD7s7e2/+BobwzBfhl0D\\\\\\n\",\n              \"Y746IkK37t3RwtwSq9aur/Uak0AgwKQJ3igqyMepU/Vb6eLt27dYvHgxLly4gIULF6JFixYYOXIk\\\\\\n\",\n              \"zpwLgW0tayuWl5djysQJePv2NUIuXoS0tHS92sEw3xJx6jvZKUTmq5OQkMCpkydxLyoSk/3GI7OG\\\\\\n\",\n              \"lepTU1IwduQwvE5OwoEDwq8eUhM9PT3s2bMHV69eRUhICDw8PBCwc0+t4QX8vYrI5m3bkZeX3yDL\\\\\\n\",\n              \"RTEM82VYgDHfBHV1dYSHh0NeVhp21uYYP2YEThw7isuXLuLY0cMY6TUE9m2t0VRfF2FXrjToMkmt\\\\\\n\",\n              \"WrWCv78/tLS04dart1BlmjRpgukzZ2Pbtm0N1g6GYUTDTiEy35zc3Fzs3bsXUVF38enTJ6ioqMDJ\\\\\\n\",\n              \"qTNGjRol0qK6ohgxYgRs27THpClThS5TVlYGM2MDXLt2DS1atGiUdjHMf02c+k42iYP55qirq2PW\\\\\\n\",\n              \"rFn/aZ0vExIwzmeiSGWaNGkCC8tWSExMZAHGMF8BO4XIMADKy8rqvKesOtLS0igrK2uEFjEMUxcW\\\\\\n\",\n              \"YAwDQFNLC8nJr0QqQ0R49SoJWkIuQswwTMNiAcYwAIZ6euLAvkCRyjyIjkZhQQHatm3bSK1iGKY2\\\\\\n\",\n              \"LMAYBsCQIUMQ8+h/iP/rL6HLbN+2BX5+E6tdTJhhmMbH/s9jGABycnL46aeFGDHMA+/fv69z+8MH\\\\\\n\",\n              \"D+B2ZATGj/f+D1rHMEx1WIAxzP+ZPn0aunfvjm7OjngQHV3tNp8+fcLqlcux+JefEBISAg0Njf+4\\\\\\n\",\n              \"lQzDfMam0TPM/5GQkMDvv6+D2U4zjPTyAJfLw7DhI6Gjq4uioiLcjbqD40FH4OTsjNu3b8PAwOBr\\\\\\n\",\n              \"N5lhfmgswBjmHyQkJDBhgi/Gj/fG5cuXcfr0GUSEX4esrCxamrfE48ePoaen97WbyTAMWIAxTLWk\\\\\\n\",\n              
\"pKTQu3dv9O4t3NJSDMP899g1MIZhGEYssQBjGIZhxBILMIZhGEYssQBjGIZhxBILMIZhGEYssQBj\\\\\\n\",\n              \"GIZhxBILMIZhGEYssQBjGIZhxBILMIZhGEYssQBjGIZhxJIEEdHXbkRj4HK5MDQ0/NrNYBiGESuv\\\\\\n\",\n              \"Xr1CVlbW126GUL7bAGMYhmG+b+wUIsMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokFGMMwDCOW\\\\\\n\",\n              \"WIAxDMMwYokFGMMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokFGMMw\\\\\\n\",\n              \"DCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokF\\\\\\n\",\n              \"GMMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMw\\\\\\n\",\n              \"YokFGMMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYun/AWBd0twf7eOr\\\\\\n\",\n              \"AAAAAElFTkSuQmCC\\\\\\n\",\n              \"\\\"\\n\",\n              \"  frames[4] = \\\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAbAAAAEgCAYAAADVKCZpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\\\\\\n\",\n              \"AAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0\\\\\\n\",\n              \"dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd1RUx+P38TdNRKVIEbAjGAuoYK9Y\\\\\\n\",\n              \"0FjAhg2xY481ahI1sUT9GjUaC3aDvWtQEVBELNgVRRRLxBYFASnS67L3+SOP/EJA2TUCrs7rnD2U\\\\\\n\",\n              \"vXPv7KL3szN3Zq6aJEkSgiAIgqBi1Eu6AoIgCILwIUSACYIgCCpJBJggCIKgkkSACYIgCCpJBJgg\\\\\\n\",\n              \"CIKgkkSACYIgCCpJBJggCIKgkkSACYIgCCpJBJggCIKgkkSACYIgCCpJBJggCIKgkkSACYIgCCpJ\\\\\\n\",\n              \"BJggCIKgkkSACYIgCCpJBJggCIKgkkSACYIgCCpJBJggCIKgkkSACYIgCCpJBJggCIKgkkSACYIg\\\\\\n\",\n              \"CCpJBJggCIKgkkSACYIgCCpJBJggCIKgkkSACYIgCCpJBJggCIKgkkSACYIgCCpJBJggCIKgkkSA\\\\\\n\",\n              \"CYIgCCpJBJggCIKgkkSACYIgCCpJBJggCIKgkkSACYIgCCpJBJggCIKgkkSACYIgCCpJs6QrUFSM\\\\\\n\",\n              \"jY2pXr16SVdDEARBpTx//pzY2NiSroZCPtsAq169OkFBQSVdDUEQBJXSuHHjkq6CwkQXoiAIgqCS\\\\\\n\",\n              \"RIAJgiAIKkkEmCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAIKumzncgs\\\\\\n\",\n              
\"CIJQXNLT03NXrzAxMaF06dIlXKMvg2iBCYIgfABJkrh06RL9XQZiZGxCwybNsGvcFCNjE1wHD+H6\\\\\\n\",\n              \"9etIklTS1fysiRaYIAiCktLS0hgw0JVbIXdp1mMg8w5fpIyePgApCfFc9z1Mzz79aN2iBbt2bhct\\\\\\n\",\n              \"siIiWmCCIAhKyM7OpptTd2Ky1Jix3Zf2A0bmhhdAOQNDOriOYcaOkzyLS6J3377k5OSUYI0/XyLA\\\\\\n\",\n              \"BEEQlLBk6VLeZEq4zFyGplapd25XSrs0g+eu4q/oN6xZs6YYa/jlEAEmCIKgoOzsbNat30DXMTPQ\\\\\\n\",\n              \"0Cz8CoyGphZdRk1nlfta5HJ5MdTwyyICTBAEQUHHjx+nvFllKlnVUbhMdWtbNEqXwd/fvwhr9mUS\\\\\\n\",\n              \"ASYIgqCgq9euYdWkjVJl1NTUqNnEnmvXrhVRrb5cIsAEQRAUlJScgrZOWaXLaeuUJTklpQhq9GUT\\\\\\n\",\n              \"ASYIgqCg8gb6pCUnKl0uPTkRA339wjcUlCICTBAEQUEdHRy4f+GUUhOU5XI59y6eokOHDkVYsy+T\\\\\\n\",\n              \"CDBBEAQFdejQAQ1JxtM7QQqX+fPGRYwM9GnevHkR1uzLJAJMEARBQWpqagwdPJiDy38iMz2t0O3T\\\\\\n\",\n              \"U5Lw2biU72dMR01NrRhq+GURASYIgqAgLy8v1q51x6KSGb9/70Zq4pt3bpv8JpYt342gW6cODB48\\\\\\n\",\n              \"uBhr+eUQASYIglAIuVzOvHnzmDBhAsePH+fC+XM4ObRlsasDB3/9kfCw+2RlZpCVkc6Lh3c59OuP\\\\\\n\",\n              \"LBnUiQE9nVi/1l20voqImvSZLpfcuHFjgoIU76cWBEEoyJs3bxg8eDApKSkcPHgQU1PT3OciIiKw\\\\\\n\",\n              \"b9uWlNQ0Et7EA2BesTIj3YYzetQozMzMSqraH0yVzp1iNXpBEIR3uHv3Lr1796Z79+4sW7YMLS2t\\\\\\n\",\n              \"PM+XLVuWmNeviYyMpGxZ5eeHCf+N6EIUBEEowP79++nQoQM///wzK1euzBdeAL6+vrRr106EVwkR\\\\\\n\",\n              \"LTBBEIR/kMlk/PDDDxw5cgR/f39sbW3fue2xY8fo2bNnMdZO+CcRYIIgCP/f69evGTBgANra2gQF\\\\\\n\",\n              \"BWFoaPjObTMzM/Hz8xO3SilBogtREAQBuHHjBk2aNKFVq1b4+Pi8N7wAzp49i42NTZ5BHULxEi0w\\\\\\n\",\n              \"QRC+eB4eHsyaNYtNmzbRu3dvhcqI7sOSJwJMEIQvVmZmJlOmTOH8+fMEBgZSu3ZthcrJ5XK8vLw4\\\\\\n\",\n              \"e/ZsEddQeB8RYIIgfJEiIiLo27cvFStW5Pr16+jq6ipcNigoCD09Pb766qsirKFQGHENTBCEL05g\\\\\\n\",\n              \"YCBNmjShR48eHD58WKnwAtF9+KkQASYIwhdDkiTWrFlD//792b59O7NmzfqgZZ6OHTtGr169iqCG\\\\\\n\",\n              \"gjJEF6IgCCpLJpPh7e2Nx/ZdvIqMRF1dnRoW1Rk/ZhRt27bNE05paWmMGTOGe/fuceXKFSwsLD7o\\\\\\n\",\n              
\"mE+ePCEuLo6mTZt+pFchfCjRAhMEQSVt376DSlWqMfnHhcRWqI9hp1EYdBjOU63KuAwfg2WtOvj7\\\\\\n\",\n              \"+wPw9OlTWrZsiZqaGpcuXfrg8IK/W1/du3dHXV2cPkuaaIEJgqBy/rf4F35bu5Gm43/FqIZ1nudM\\\\\\n\",\n              \"6zThq44uvLpzib4ugxg/eiTbtm3lp59+YuLEiUp1Gaanp3Pw4EHOnr9AUnIy+vp6XLl4gSVLlnzs\\\\\\n\",\n              \"lyR8ABFggiColEOHDvGb+3ra/bgNHQOTArdRU1OjUoPWlJmxnt8WDGPt6pWMGTNG4WNkZmby05y5\\\\\\n\",\n              \"bPndA6Ma1hjZtELLuAYv05KJk5VixKjRBN0KZt6cnwpcI1EoHuJ2KoIgqAxJkqhtXZ/K3b/BvF5L\\\\\\n\",\n              \"hcr8eWofFVMec8zzsELbp6am0qlLN6KztLDpPxVd0yr5tkmKfM7dfcupYVwW3+PH0NbWVuZlfNJU\\\\\\n\",\n              \"6dwpOnEFQVAZly9f5k1yCmbWzRUuY9HaiYDTp3n16lWh20qSRD8XV+I0DGg+4dcCwwtAz7w6Laas\\\\\\n\",\n              \"4kWKxNDhbgrXRfi4RIAJgqAyDh4+TMXmjqgpMYCiVBldKtvZc/z48UK3vXHjBtdvBtNo+JxCj6Gu\\\\\\n\",\n              \"oUmjUQvwO32a+/fvK1wf4eMRASYIgsqIio5Fp3zB173eR0vfhKioqEK3W7VmLdXa9kFdU7HrWpql\\\\\\n\",\n              \"SlO9TS/WrF2ndJ2E/04M4hAEQWWULq2NXCZTulxWeioLFy5k+fLlmJqa5j7MzMxyvzc2NuaPw4fo\\\\\\n\",\n              \"vvKEUvu2sO/N3p9d2bhehFhxEwEmCILKqFe3Ntd9LgJ9lSqXHvGII0eOYG9vT3R0dO4jKiqK6Oho\\\\\\n\",\n              \"bt26xcuXL5GjRmnd8krtu4yRGWkpyWRlZVGqVCmlygr/jQgwQRBUxrBhw5i/YBE2/RPQ1jVQqEz8\\\\\\n\",\n              \"84dkJryma9euaGpqoq+vX+AivImJiZhVqqx8pSQJSZLQ0NBQvqzwn4gAE4RPUE5ODgEBATx79oyc\\\\\\n\",\n              \"nBxMTU3p0qULZcuWLemqlSgTExMcnZx46Leben0nFrq9JEmEHtmAhppEUFAQzZu/e/Sirq4umhqa\\\\\\n\",\n              \"pLyOoFyFSgrXKSHiCcYVzESAlQAxiEMQPiFJSUn8b/FiqlS34JtpP7D7xHn2n77Mz8vXUKlyFSZP\\\\\\n\",\n              \"mcqLFy9Kupol6rdfl/I66CRPznm+dztJkrh7aA36OUksXLAAZ2dnxowZQ1xcXJ7tEhIS8PDwwMHB\\\\\\n\",\n              \"gZwcGWEBB5Sqz/Nzhxk7epTSr0P470QLTBA+ERERETh06oxu5RoMmOtOlVr18jwfHxXB1WO7adSk\\\\\\n\",\n              \"Kd5ex2jWrFkJ1bT43b9/ny0eHvwZ9oScnBxatmxJ4PFNRIRcwrr7CIxq2ORuK8nlRN27xlP/PRio\\\\\\n\",\n              \"Z3LK3w8TExP69u3L3LlzqVu3Lj///DNmZmbs3buXU6dO4eDgwJQpU6hZsyYt2rSlbveRaJfTL7Re\\\\\\n\",\n              
\"6Ylx/HXVj3HbfivKly+8gwgwQfgEJCYm4tCpM5atutBh8PgC1+szNKtEt7E/UL1eE7o5dedi4Hnq\\\\\\n\",\n              \"1KlTArUtPsHBwXwzeSoPHjzEsm1PDCzboK6hSURMODmapYl9eIOrT0PQMTBG17QKkiTnzcvHGOrp\\\\\\n\",\n              \"MnPKRIYOHYqOjg4Aenp6uLi4EBkZyaRJk9DR0WHq1Kls2rSJ8uX/HrghSRL1rK05u2wcDrN+R0vn\\\\\\n\",\n              \"3V22malJXFk9hW+nTqFixYrF8n4IeYkAE4RPwKrVq9GvUvOd4fVPdVt2IO7VWKbN+I4TPt7FVMPi\\\\\\n\",\n              \"d/bsWXr16UeD/pNxHr8KjX/NzbLpNpTIeze4snkuzt060a6tPRoaGlSrVg07O7vc9zEsLIzdu3ez\\\\\\n\",\n              \"e/duSpUqxZAhQ1iyZAlnz55l9uzZJCQksHDhQsqWLcuECRNITU6kU4tGnP3FjVo9x1HJ9u/QfEsu\\\\\\n\",\n              \"yyb81jkeHt3AwD49WTB/XrG+L8L/EWshCkIJk8lkVK5ajcGLNlPRSrEWVVZGOosHtOFO8C2qV69e\\\\\\n\",\n              \"tBUsAWFhYTRp3pJWE5dR0brJe7dNjYvC7+fhbF63mj59+gAQExPDgQMH2L17N8+fP8fFxYUhQ4bQ\\\\\\n\",\n              \"sGHDPB8Q4uLimDVrFsePH8fMzAwTExP++OMPypUrx4EDB1i2YhXPX7zEvEEr1LXLIktL4uXNc+jr\\\\\\n\",\n              \"69GyaWMGDBhAjx49Pqvh86p07hSDOAShhPn5+aFnYq5weAGUKq1Do0698Ni6tQhrVnKWLPsVK4d+\\\\\\n\",\n              \"hYYXQFkjM5qOmsvMn+ayf/9+unfvTs2aNbl69Srz588nPDycVatW0ahRo3ytWyMjIxYtWkT58uV5\\\\\\n\",\n              \"8eIFWVlZvHz5EjU1NVxcXLh14yoBJ71pVd2AuODT/HUjgEq2ranUujdPqcB3C3+lYuWq/DRnLunp\\\\\\n\",\n              \"6UX1dgjvILoQBaGEPXnyBHOrukqXM7Oqw6PHt4ugRiUrMTGRgwcP0n3p+0cZ/lOlei24sGEuv/32\\\\\\n\",\n              \"G5MmTWLv3r3o6uoWWi4sLIyuXbvi6urK3Llz2bhxI23btsXNzY05c+ZQrlw5jnodx9P7JHaDZlC1\\\\\\n\",\n              \"Yds83YkAb14+5qDnBnxPtiPg1Mnc62lC0RMtMEEoYTKZLN9JUREaGlrIspVfVulTd+LECSrWaURZ\\\\\\n\",\n              \"wwoKl1FTU6Nul4HYNWrMkCFDFAqvq1evYm9vzw8//MCCBQvQ1NRk4sSJhIaGEhkZSd26dRk9egwb\\\\\\n\",\n              \"PHbQef5OqjdxKPDvVL6KFfaTl5NjYoVjj17IPmCpK+HDiAAThBJmampK4usIpcvFR73E3My0CGpU\\\\\\n\",\n              \"sqKjo9ExMlO6XDmTSkRGRSu07bFjx+jevTu///47o0ePzvOcqakpO3fuZMuWLezYtYu2366kjIHx\\\\\\n\",\n              \"e/enpqZG46HfEx6biJeXl9J1Fz6M6EIUhBLm6OjI+G8mkBQXg56RYiuty+Vygv08mXNgbxHXrvhp\\\\\\n\",\n              
\"aWkhyXOULifPycbf3x8rKyvMzMze+fDz82Pt2rX4+vrSpMm7r7FFRERQrX5zDCrVUOj46uoaWH3t\\\\\\n\",\n              \"ym9r1uLs7Kx0/QXliQAThBJmYGBA3779uOa9n07DJilU5s/rgRjolXvv0kiqytLSkjfPNytdLuH5\\\\\\n\",\n              \"QyZ8M46xo0cTFRWV53HlyhVevXpFcHAwsbGxqKmp0bVrV8zMzDA3Ny8w6FatXY9Fx2FK1aF6044c\\\\\\n\",\n              \"3rmU8PBwKlf+gHUVBaWIABOET8Csmd/TrEVLLOo1xqphi/duGx8VwcEl36Otqc7Fixdp06ZNMdWy\\\\\\n\",\n              \"eHTs2JHs5Dhin97DuIa1QmVkmek8uXCcI2tuYmFhQc2aNfM8n5WVhZubGxYWFoSGhlK+fHni4uLy\\\\\\n\",\n              \"BV1ERAQ3b94kKiqKsEePqD1MsdbXWxpapShvVpmIiAgRYMVABJggfAJq1qzJH4cO0qdffzq6TaNx\\\\\\n\",\n              \"595oauWdWyRJEo+CLuK5fDaLFsynhoUFAwYMwM3NjXnz5qGlpdhNGD91GhoaTBg3jl3e22kzaVmh\\\\\\n\",\n              \"E7sB/jzjSdNmzbCwsMj3XGJiIs7Ozujp6REQEECZMmUAqFChAhUqVKB+/foF7rNyNQtQ4NgF+Uyn\\\\\\n\",\n              \"135yxCAOQfhEtG/fngD/U0Td8GeJiz0nNv9K8BlvQs6d4OzeTawc3plzHsvYsmEdkyZOxNHRkeDg\\\\\\n\",\n              \"YG7dukWrVq0ICwsr6Zfw0UyZMhn1hAhC/thQaBi8uBXIg+MerF2Vfz3C8PBw2rRpQ+3atTl8+HBu\\\\\\n\",\n              \"eCnCvGJFEiOeKVXvHFk2b6LCqVRJ8dXshQ8nAkwQPiF2dnYEnj3D5QuB1DcrR/LdQOKC/DEniX07\\\\\\n\",\n              \"t/HwXii9evXK3d7U1BQfHx+GDh1Ky5Yt2bp162fx6V9XV5cz/n7E3DzFqaUTeB12J9/rSn4dTtCe\\\\\\n\",\n              \"3wjauoAT3l751oUMDQ2lZcuWDB48mLVr1yp9u5OxI0fw9NwfSpV5fj0AG5t6VKlSRalywocRXYiC\\\\\\n\",\n              \"8AmqXbs2K5b/qtC2ampqTJw4kfbt2+Pq6oqvry+bNm3CyMioiGtZtExMTMjOSCcz6QFBm2ajrqNL\\\\\\n\",\n              \"+apfgboGaTGviHn2gOHDhzLd/TpVq1bNU/bs2bMMGDCAVatW4erqqvSxk5KSePjwIc9uX8Yu8i/0\\\\\\n\",\n              \"zasVWkaSy3nsv5fffp6t9PGEDyNaYILwmbC2tub69etUq1YNW1tbAgICSrpK/8m6deuIjY3l2NGj\\\\\\n\",\n              \"vHz+lB0bVjPFpRsTnB1YPncGkREvWb1yZb7w2rdvHwMGDODAgQNKh5dMJmPz5s3UqlWL2NhY5vz4\\\\\\n\",\n              \"I+dXTiU9Kf695SRJ4uae5Zjp6dCzZ0+lX6vwYUQLTBA+I9ra2qxYsYIuXbowbNgwXF1dWbhwIdra\\\\\\n\",\n              \"2iVdNaUkJiYye/ZsOnXqRKtWrQDo0KHDe8tIksTy5ctxd3cnICCAevXqvXf7f/P392fatGkYGRnh\\\\\\n\",\n              
\"7e1No0aNAMjMzGTL/GE0HPwdlW1boa6etysy4dUz7v6xAe3U15zw9/tsBtOoBOkz1ahRo5KugiCU\\\\\\n\",\n              \"qJiYGKlnz56Sra2tdP/+/ZKujlJGjx4taWtrSxEREQptL5PJpIkTJ0o2NjbSy5cvlTrWvXv3pG7d\\\\\\n\",\n              \"uklWVlbSkSNHJLlcnm+bQ4cOSTa2DSWjilUk214jpWZDv5Mau0yWati1lMobmUjf/TBTSklJUeq4\\\\\\n\",\n              \"nypVOneKLkRB+EwZGxtz5MgRvvnmG+zt7dmwofARfZ+Cp0+fsn37dmbMmKHQjSLT09Pp27cv9+7d\\\\\\n\",\n              \"4+LFiwrPv4qJiWHChAm0bduWjh07cu/ePXr16lXgsP2+fftyN/gmJ47+QY/6lWiom0GHqjos+m4i\\\\\\n\",\n              \"kREvWbbkF8qWfffNL4WiIe4HJghfgD///JNBgwZhbm6Oh4cHFSoovlBucWvdujUPHz4kIiKi0K7P\\\\\\n\",\n              \"2NhYevTogYWFBVu3blWoqzQzMxN3d3eWLl2auwq9qg94+ZhU6dwpWmCC8AWoVasWly9fxsbGBltb\\\\\\n\",\n              \"W06ePFnSVSrQmTNnuHbtGlu2bCk0jJ4+fUqrVq1o27Ytu3btKnR7SZI4fPgwderU4cKFC1y8eJHV\\\\\\n\",\n              \"q1eL8FJlJduDWXRUqR9XEIrT2bNnpSpVqkiTJ0+W0tLSSro6uXJycqRKlSpJdnZ2hW5748YNydzc\\\\\\n\",\n              \"XFq7dq1C+7527ZrUqlUrqUGDBtLp06f/a1U/a6p07hQtMEH4wrRr146QkBCioqJo2rQpd+7knyRc\\\\\\n\",\n              \"Etzd3YmOjmb//v3v3c7X15euXbuyfv16JkyY8N5tX758yeDBg+nVqxdubm7cvHkTBweHj1ltoQSJ\\\\\\n\",\n              \"ABOEL1BCQgJVq1Uj4lUktra2aGppUaWaBfN//plXr14Ve31SU1OZPXs2Li4ufPXVV+/c7vfff8fN\\\\\\n\",\n              \"zQ0vL688K5L8W0pKCnPmzMHW1pYaNWrw6NEj3NzclF6NQ/i0iQAThC+ITCZj3DcTsGvUmNsv45mw\\\\\\n\",\n              \"Zh+rz/3JyoD7DP55HedCwqhdpy4/zZlbrK2yyZMnI0kS69evL/B5SZKYN28ev/zyC4GBgbRoUfCK\\\\\\n\",\n              \"/Tk5OXh4ePDVV1/x/Plzbt++zYIFCyhXrlxRVl8oIWIisyB8IeRyOS6ug3j0Moo5+8+iU04vz/NV\\\\\\n\",\n              \"vqpLlRmL6DpyKr/PHEP8mzesc1+j0Grw/8Vff/3Fjh07WL58Obq6uvmez87OZsyYMYSGhnL58mVM\\\\\\n\",\n              \"TQu+C3VAQADTp09HV1eXY8eOvfdmlapAJpOhoaFR5O+/KhMtMEH4QqxctYrQsGeMXLw5X3j9k255\\\\\\n\",\n              \"Y8at2I73Sf9Cr0d9DC4uLpibmzN58uR8zyUnJ+Pk5ERMTAznzp0rMLz+/PNPevTowZgxY5gzZw6B\\\\\\n\",\n              \"gYEqGV6SJHHt2jVcBw9BV08fbW1ttLS0qF3XGnd3d5KSkkq6ip8cEWCC8AXIyclh5arV9Jz4E1oK\\\\\\n\",\n              
\"zJXSKaeH49jvWLDof0VarzNnznD9+nX27NmDunre01FkZCRt27alWrVqHD16NN9E4bi4OCZPnkzr\\\\\\n\",\n              \"1q2xt7fn/v379OnTRyVbLPHx8bTr4ECvfgNI16/CD7v9+e1cGMv879Fp3I/sOuZHlarV2Lt3b0lX\\\\\\n\",\n              \"9ZMiAkwQvgAnT55ER9+QanUKvnljQaybtyMiMorJkyeTk5Pz0eskSRKDBw/ODaB/evDgAS1atMDZ\\\\\\n\",\n              \"2ZlNmzahqfl/VzuysrJYuXIltWvXRi6Xc//+fWbMmKFy6z2+lZCQQKs2bSllbskPuwNoP3A0eoYm\\\\\\n\",\n              \"qKmpoalVipoNWzB4vjvjV+9l6ozv+d3Do6Sr/MkQASYIX4CAgDPUbdVJqTLqGho0+bonx48fp2PH\\\\\\n\",\n              \"jrx8+fKj1mnlypXExMRw4MCBPL+/ePEi7dq1Y/78+fz000+5LSpJkjhy5AjW1tacPn2awMBA1q5d\\\\\\n\",\n              \"i4mJyUetV3EbOXoMZnUb4TR+Vr5W6D9VtKzN6OU7+O6Hmdy7d68Ya/jpEgEmCF+AN4mJlNHVV7qc\\\\\\n\",\n              \"Tjl9Bg4cSKdOnWjcuDGHDh36KPVJT0/nxx9/ZPz48ZiZmeX+/vDhwzg7O7Nr1y6GDx+e+/ubN2/S\\\\\\n\",\n              \"rl075s6dy7p16/Dx8cl3A0tVFB4ejr+/P91Gz1Co69O0ag1a9BzEave1xVC7T58IMEH4Aujp6pKR\\\\\\n\",\n              \"lqJ0uaz0VPT19Zk9ezbe3t78+OOPjBgxguTk5P9Un7Fjx6KlpcWKFStyf7dq1SqmTp3KqVOn+Prr\\\\\\n\",\n              \"rwGIiIhg+PDhODk5MXjwYIKDg3Of+xxs3LiJRp16ol1G8YWAmzsNYP/+fWJQByLABOGL0KplC8Ku\\\\\\n\",\n              \"n1eqjCRJ3LscQPPmzQFo0qQJt27dQlNTEzs7O65evfpBdXn+/Dl79uxh7dq1aGlpIZfLmTZtGps3\\\\\\n\",\n              \"b+bSpUvY2tqSmprK/PnzqV+/PhUrVuTPP/9k9OjRea6FfQ78z5zFurVyXbsGJmZUtKjJrVu3iqhW\\\\\\n\",\n              \"qkMEmCB8AXr16kX0i6dEPgtTuEzYraskxMYwe/ZsDhw4QHZ2NuXKlWPLli0sW7aMnj17snDhQmQy\\\\\\n\",\n              \"Wb6yGRkZ7N69m7YdOlHbuj71GjSiv4srFy9epE+fPlhaWjJ06FAyMjIYOHAgQUFBXLx4kSpVqrB9\\\\\\n\",\n              \"+3Zq1arFo0ePuHXrFosXL0ZP793D/lVNVlYW4eHhBAUFER0d9d4pDe9SRldftMAQE5kF4YtQqlQp\\\\\\n\",\n              \"xo8dy/FNyxj5v42oF7KkUnZmJr6/r2DZksWYm5uzZs0apk+fzrhx4xgzZgzOzs40a9aMYcOG4efn\\\\\\n\",\n              \"x65du7CwsECSJH5dvoJFi/6Hmo4J6aUtUCv1FWTJeXQpHG/fXqSnJrHNYzPx8fH06tULMzMzTp06\\\\\\n\",\n              \"xdWrV5k2bRqlS5fm8OHDuS0/VZCTk0NsbCzR0dFERUXlefz7d0lJSVSoUAFTU1PS0tLITEtV+ngZ\\\\\\n\",\n              
\"qSkFTvr+0oj7gQnCFyIrK4uvO3chS8cAlx+WoKlVqsDtMtJS2TJrLElRf/Hngwe5yzDduXMHd3d3\\\\\\n\",\n              \"Dh8+TI8ePZg0aRINGzZk5cqVLFmyhN9++41Ll6+y5+BRsiq0R720Qb59S5KEPPEZmjEXMTY0oG/f\\\\\\n\",\n              \"vowdO5aZM2cSHBzM0qVL6d+//ycxl0uSJBISEt4bRm9/jo2NxcDAADMzs9yHqalpgT8bGRmhrq6O\\\\\\n\",\n              \"JEkMGTKUiOxS9J70k8L1SnkTxy+DHfjr2TMMDQ0/+utWpXOnCDBB+IKkpaUxYKArN2+H0KKHK80d\\\\\\n\",\n              \"+1FW7++gSYyL4erx/Vw5vp+unb8mR5ZNREQEx48fz7OWYHx8PB4eHqxbt46KFSsyadIkrKys6N69\\\\\\n\",\n              \"B3HJWWhYOqOm8f45WfLUKHhxApcBffHx8eG7775jypQplC5dukhfP/y90G9BYfTv30VHR6Ojo/Pe\\\\\\n\",\n              \"MHr7MDExQUtLS6Hjx8bGsnv3bjw8PEhMTCQ+IYn5R65QSlux1x6wZyP66dHs3L79P7wL76ZK504R\\\\\\n\",\n              \"YILwhXm7ZNEa97UcPXqU0jplkCQJWXYWA1xcmDThGxo0aEBOTg5jx47lwYMH+Pr6oq+fdxi+TCbj\\\\\\n\",\n              \"+PHjuLu78/DhQxISU8ip3BX1MorNy5JFXsOyfBqB58/+5ztEZ2ZmvrN19O+f5XJ5nvB5V0CZmpqi\\\\\\n\",\n              \"o6Pzn+r1Vk5ODv7+/nh4eODv70/37t0ZOXIk9vb2OHbvgZqJBV1GTSt0P/HREaz9pi9+vj40atTo\\\\\\n\",\n              \"o9Tt31Tp3CkCTBC+YJmZmcTHx6Ouro6hoWG+VoRcLmfixIkEBQXh5+dH+fLlC9zP6tWr+X7OEtQt\\\\\\n\",\n              \"+yh8bCk7FfVnh4mKjChwkIZMJiMmJkah1lJqampuCBXWWipXrlyxdVE+ffqUbdu2sX37dszMzBg5\\\\\\n\",\n              \"ciQDBw7M82EgOjqaps1bYNulHx1cx76zbnGvXvL7D258O3E8M6ZPL7I6q9K5UwSYIAjvJUkS06ZN\\\\\\n\",\n              \"4/z585w6dQpjY+N82/TtP5BjVyLRNLZRat/q4Sfo2601pqam+cIpPj4eQ0NDhVpL5cuXf+8qFsUp\\\\\\n\",\n              \"PT0dT09PPDw8uHv3LoMGDcLNzY369d+9jFdERARdHbuTJpPTtLsrDTt2z+1SfPXkIVeO7SH4jA+L\\\\\\n\",\n              \"Fi5g8qRJRVp/VTp3ilGIgiC8l5qaGr/99huzZ8+mffv2nD59Ot+q8FHR0ahpKT4Z961sShEeHk7d\\\\\\n\",\n              \"unWxtrbOE07GxsYqM+9LkiRu3bqFh4cHBw4coEmTJowfP54ePXootEZjpUqVCL55Ax8fH5b8upxD\\\\\\n\",\n              \"y39Cp0xZZDnZ6OnpM27sGPatvUfFihWL4dWoDtX41yEIQolSU1Nj8eLFaGtr065dOwICAvKcTEtr\\\\\\n\",\n              \"lwZJ+QV/5dlZnD9/nrCwMMzNzTE3N6dixYoFfm9iYvLJ3VE5Li6OPXv2sHXrVhITExkxYgTBwcFU\\\\\\n\",\n              
\"rVpVqf1ERESwafNmNm3egqZ2aarXqktaSjIZKUkMHTKEEcOHi/AqgAgwQRAUoqamxvz589HW1qZt\\\\\\n\",\n              \"27acOXMGExMTzp49S3zsa6TULDCwUnh/kiSho5aCz5kzVK1alcjISF69ekVkZCSRkZFcunQp9/vI\\\\\\n\",\n              \"yEgSEhIwMTEpNOhMTU2LtIEN8uoAACAASURBVOUml8sJCAjAw8ODkydP0q1bN1asWEH79u0/qBvz\\\\\\n\",\n              \"2LFjDBvhRkMHJ0b/up2KNb7KfS76rydc9tpLA7uGrHNfw6BBgz7mS1F5IsAEQVDKsGHDCA4Opnbt\\\\\\n\",\n              \"2qirq2NnZ0enTg7cd1+PJG+JmrpipxV5SjipyQkcOnSIiRMn0qJFi/dun5WVRXR0dL6gu3HjRu73\\\\\\n\",\n              \"r169Ii4uDiMjo0KDzszMjFKlCp4LV5Dnz5+zfft2tm3bhrGxMW5ubmzYsOGdA1sU4ePjw4hRYxj7\\\\\\n\",\n              \"67YCb3VjWs2S3pPm0MyxP1Onj0BTU5MBAwZ88PE+N2IQhyAI7yWXywkKCsLb2xsfHx+eP39Oly5d\\\\\\n\",\n              \"0NLS4vTp05w9e5aaNWvStr0Dlx9loFnBrtB9SpIc9Rc+6Gtnoa6uTlpaGk2aNGHixIk4Ojr+p65C\\\\\\n\",\n              \"mUzG69ev8wXdP7+PjIwkOjoaAwODfMH2z58NDQ25fv06u3btIjg4mIEDBzJy5EhsbW0/uH5vpaam\\\\\\n\",\n              \"UrlKVUYt+R0Lm8Lfs/Cw+6ybMohnT58UyQTmt1Tp3ClaYIIg5JOcnIy/vz/e3t74+vpiZGSEk5MT\\\\\\n\",\n              \"q1atokWLFrlddFu2bKF9+/YcPXqUzPRUZJE3UdMsi4bhV+/ctyTJ0Yi+iJ2NBQH+fnh6ejJnzhwi\\\\\\n\",\n              \"IiKYPXs2kydPZvz48YwcObLAEY8FyczMxNPTE/cNm3jy5Ak52dkYV6jAEFcXRo8aVeA8s7fLP/07\\\\\\n\",\n              \"3B48eICnpycPHjwgJiYGSZIoU6YMVapU4eHDh/z2228FtujMzc3z3TX6ffbs2YNl/cYKhRdA5Zp1\\\\\\n\",\n              \"sW7Znu3btzNtWuFzxr4EogUmCAIAT548wdvbG29vb65du0bLli1xcnLC0dERCwuLd5ZbsWIFP/zw\\\\\\n\",\n              \"A+rq6ujq6pKZnQPlqpGtWzvPpGZJkiNPfEbplPvY1rXEx/tY7np+2dnZbNu2jQULFlCzZk309fU5\\\\\\n\",\n              \"d+4cvXr1YuLEiTRu3Pidxz98+DBjxn+DbkVLKrZxpnz1Oqipa5AWF8mry8cJDzqD20g3Vq1Y/t6W\\\\\\n\",\n              \"3Zs3b9i7dy8eHh7ExcUxYsQIhg8fTtWqVYmPjy+0Rffq1Su0tbUL7bqsWLEiurq61GtgS7vh06jT\\\\\\n\",\n              \"zP6ddfq3p6G3+GPpDzx9/KjI5rKp0rlTBJggfKGys7O5dOlSbtdgQkICjo6OODk50bFjxzzLR71L\\\\\\n\",\n              \"YGAgPXr0ICMjA5lMxujRo1mwYAEbN25itftasuWayNRKI8/JhqxEZLJstm7ZyMCBAwscaJGens76\\\\\\n\",\n              
\"9etZunQp7dq1w8LCggMHDmBqasqECRPo379/nuWmPDy2Mn3WjzQatwzDGtYF1jEzJYHbW37EzrIS\\\\\\n\",\n              \"nocO5AkxuVzO2bNn2bp1Kz4+PnTp0gU3NzccHByU7sZ8u3aiIkH39rWuPvcnGpqKLUH19hjTO1oT\\\\\\n\",\n              \"FxujVGtPGap07hQBJghfkNjYWE6ePIm3tzenTp3C0tISJycnnJycsLOzU2oU3ebNm5k+fTo6OjpY\\\\\\n\",\n              \"Wlqiq6tLSEgIx48fp2nTpshkMq5fv86ePXsICwtj9erVDB06lEWLFtG5c+f37jspKYmVK1eyZs0a\\\\\\n\",\n              \"+vfvT7Nmzdi3bx+3b99m5MiRjBs3jtevX+PwdVdafLcJXfPq791fTnYWN1ZPZmT/7syfO4eXL1/m\\\\\\n\",\n              \"DsjQ1dVl5MiRDBo0CCMjI4Vf/4eSJImYmBgqVarM6vOPlC4/26kRYX8+xMREsSW7lKVK505xDUwQ\\\\\\n\",\n              \"PmOSJBEaGprbNRgaGoqDgwNOTk6sXLkSc3NzpfeZnZ3N1KlT2b9/P6amptSrVw9Jkvjjjz84ceIE\\\\\\n\",\n              \"Tk5OHDlyhFatWtGyZUuePXvGmzdvqFOnDsOHD2f79u2FBpienh7z5s1jwoQJLFmyhGnTpuW27vbt\\\\\\n\",\n              \"24ednR1ldPWx6Dyk0PAC0NAqhfWgmaxYMpLLFy9w8+ZNBgwYwKFDh2jYsGGRLy0lSRLh4eHcvn2b\\\\\\n\",\n              \"kJAQgoODkSSJ9JQkpe4Hlp2VSWpKcr51Kb9UIsAE4TOTnp7O2bNn8fHxwdvbGw0NDZycnJg3bx5t\\\\\\n\",\n              \"27ZVaGWId4mLi6NPnz48fvwYKysrmjZtyq1btzh9+nTucXbv3k3v3r05ePAg7dq1Q0dHh/T0dABc\\\\\\n\",\n              \"XFyYPXs2CQkJGBjkv93KvxkbG7N8+XKmTp3KokWLcHR0ZOrUqZw7d44mzVvQsHUPheuua14dHdNq\\\\\\n\",\n              \"WFlZcezYsY+2UO+/ZWVl8fDhQ27fvp37CAkJyb2TdYMGDejbty9vEpMIOu1Nm16uCu/79rmTtG5j\\\\\\n\",\n              \"r9Tw/8+ZCDBB+AxERETkBta5c+ews7PDycmJkydPUrt27Y/SwggNDaV79+6oq6vToEED2rZty7Zt\\\\\\n\",\n              \"27h06VKeMPj66685cOAA/fv3Z8+ePZQuXTo3wIyMjHKfHzt2rMLHrly5Mhs3bmTGjBnMmzeP5cuX\\\\\\n\",\n              \"U6FWQ0qVVe5uxlVb9+B1XNhHC683b94QEhJCSEhIblg9fPiQ6tWrY2tri62tLd9//z22traYmZnl\\\\\\n\",\n              \"KWtiYsLoCVNo3XOgwn+fq157+GXu7I9S98+BCDBBUEFyuZwbN27kdg2+ePGCrl274urqyo4dO/7T\\\\\\n\",\n              \"5NqCeHl54ebmhr6+Pvb29nTs2JGZM2dy8eLFAucktW/fHk9PT5ydnfn222/JyMjIfW748OEsWrRI\\\\\\n\",\n              \"qQB7y8rKij179vDTTz+x9/wdpctr6xkSH56gdDlJknj+/HmeFtXt27eJi4ujfv36NGjQgBYtWjB+\\\\\\n\",\n              
\"/HhsbGwoU6ZMofvs0KEDumW0ObNvMw6uhb8Xl47tJTslEScnJ6Xr/7kSASYIKiIpKYlTp07h4+OD\\\\\\n\",\n              \"r68vJiYmODk54e7uTvPmzYtk+SRJkli8eDHu7u6UK1eOAQMG0LFjR1xcXDh9+jTVqlV7Z9nWrVtz\\\\\\n\",\n              \"/PhxunbtmmdwROfOnRk1ahSXL18mPj6e5ORkdHV1adq0qcL3BatatSrq8ltKv56czAzKlH1/uGRk\\\\\\n\",\n              \"ZHDv3r08raqQkBB0dXVzW1WDBg3i119/xdLS8oNXwVdXV8fH6xjNW7YiR5aDg+sYNAr4G8pzcgj8\\\\\\n\",\n              \"YyfnD2zh8sULKrPAcXEQ74QgfMIeP36c28q6fv06rVq1yr2eVb169SI9dlpaGm5uboSGhqKurs63\\\\\\n\",\n              \"335Lu3bt6NSpEwcOHHjv7UHeatasGevXr2fIkCEcOHCAAQMGcPPmTfTKG9POoRMmVg1Q1y6LlJlG\\\\\\n\",\n              \"/PNQunTpwnfTpuZbVio+Pp5Lly4RGBhIYGAgd+7cQdIsTYMcGeoaip/G3vx5A8cODXN/jomJyQ2q\\\\\\n\",\n              \"t1/fXt97G1Y9e/akQYMGCk+qVkaVKlW4fvUK/QYMZKHLXpo7DaC+fWfK6OqRkZpC6KXTXPHaR0Uz\\\\\\n\",\n              \"M65evvTe+XhfIhFggvAJyc7O5uLFi7mhlZycjKOjI5MmTcLBwUGhuVkfw8uXL+nVqxfGxsbExMSw\\\\\\n\",\n              \"atUqWrZsSevWrVm7di3t27dXeF8NGzbE3Nycb7/9lqPHjuHjF0D5Fv2w6bMYTR3d3O0qpKdwK/gk\\\\\\n\",\n              \"Xbr3ZsLYUTSoX4/AwEAuXLjA8+fPad68OY0aNcpdlT5Tlk1kyAUqNVSsLllpyby4dorYepVxdHQk\\\\\\n\",\n              \"JCSElJQUGjRogK2tLe3bt+fbb7+lbt26eeaaFbVKlSpx+WIgwcHBrFm7jr3zJ5CcnEy5cuVo1aol\\\\\\n\",\n              \"Xp6HadKkSbHVR5WIABOEEhYTE8OJEyfw8fHh1KlT1KxZEycnJ/bt24etrW2x36jx8uXL9O3bl06d\\\\\\n\",\n              \"OuHr68vevXtp2LAhrVu3ZsaMGfTv31+p/ZUuXRq5XI7byJH86r6J2mPWom2Qv6tQU6ccFVr2xcC6\\\\\\n\",\n              \"LSvWTeSrqqYMHzaMoUOH8urVKxYuXMjSpUsxNDRkypQp1KxZk1mLfsXMpgUapQoPnAfHtlC6jA4G\\\\\\n\",\n              \"BgaMHj0aW1tbqlWrhpqaGnFxcWzbto3J02aQ8OYNOmXK0LihHRO/GY+1dcETpD82Ozs7tnn8XizH\\\\\\n\",\n              \"+lyIicyCUMwkSeLu3bu5rax79+7RsWNHnJyc6Nq1a77RasVp27Zt/PDDD7i4uHDo0CGOHz9O3bp1\\\\\\n\",\n              \"6dixI61bt2bZsmVK7zMmJoZatWqRmZ2D5dgNlDYs/L5WmW+iCNs4lkEu/Tl06BDJyck0btyYxYsX\\\\\\n\",\n              \"4+DgAPw9kKVHb2dC/oql8filaGoXPLJQkiQendhJZOAhKpqaULVqVbZu3YqJiQmZmZlMnjqNvXv3\\\\\\n\",\n              
\"ULt5B2rZd6OsviHZmek8v32V2ycPUrduHXZu20qNGjWUfu2qSJXOnSLABKEYpKenc+bMmdyh7lpa\\\\\\n\",\n              \"WrkrYNjb2/+nuVmKysjI4PLly8TFxVG6dGnq1KmDldXf9++SyWR8//33eHl50a1bN7y9vfHz88PC\\\\\\n\",\n              \"woI+ffqgq6vLzp07P6g1+ObNG0wqmGLcoANVnWcqXO7pocUkP7jA6FEjmTt3LrGxsXmuVd2+fZvs\\\\\\n\",\n              \"7GxKlylHSqaMGl8PolorJ0qV+btbUp4jIyrkIuHnD5EW/YKcrHR69OiBjo4OXl5ebNq0iWXLV5Ao\\\\\\n\",\n              \"labLpJ8pq59/NGWOLJsbXrsJOrKV82fPUKdOHaVfv6pRpXOn6EIsASkpKcTHx6OtrY2RkZEYVfSZ\\\\\\n\",\n              \"Cg8Pzw2s8+fP07BhQ5ycnDh16hS1atUq8tUf3nrx4gVr3Neydds2TCpXQ9/YFFlmJs8e3KFePRtG\\\\\\n\",\n              \"u41g586dwN/D3wMDA7l8+TKmpqaMGzeO9PR0Dh06pHB4paenc+3atdwBF9euXUNS08CoaW+l6m3W\\\\\\n\",\n              \"qj9pj69z/fp1LCwsqFixIra2tjRo0IBJkyZha2tLpUqVALhw4QKr3Nfi90MP9EzMUNfQJDk2mhqW\\\\\\n\",\n              \"NVg0YxL9+vUjIyOD5cuXs2HDBtq2bYvr4CFUrdeUvj+uQv0d6x5qaGrR3HkEpcvp06lLVx7eCy22\\\\\\n\",\n              \"65BC4UQLrJjIZDK8vLxYt249V65cxtDQiIzMDNTV1Rnp5sa4cePeOyRZ+PTl5OTkmZsVHh5O165d\\\\\\n\",\n              \"cXR0pHPnzh99bpYizpw5Q9/+A2jYqScteg7CpHL13OdkWZncPn8Sv62rKa9bhrq1a/HmzRuOHj2K\\\\\\n\",\n              \"vr4+CxYs4NixY5w7dy531fiCJCYm5o4QvHDhArdv36Z+/fq0adMGe3t7mjZtiqmZGY0XBigV2pIk\\\\\\n\",\n              \"cWve1/if8qNJkybvrcNb8fHxhIeHk52dTYUKFahSpUq+baKjo5k1axa79+3nuwNX0NZRbFHcPxZO\\\\\\n\",\n              \"4JtBzh80f02VfGrnzvcRH/2LwYsXL+jWzZGyuroMHzWObfuP5nYZPQ77k51bt2DXsCGzZ81i+vTp\\\\\\n\",\n              \"xfbJXPjvEhMT88zNMjU1xcnJiXXr1tG8efP/dGPG/+r69ev06defwfPdqWnXPN/zmqW0adypJ/Xt\\\\\\n\",\n              \"O+Mxawy3bofw4F4oOjo6/P777+zYsYNLly7lC47o6GguXLiQG1iPHz+mSZMm2Nvbs3DhQpo1a5a7\\\\\\n\",\n              \"UnpaWhp3795FTU1D6X/XampqoK7OwoULqVChAoaGhgo/3sfU1BQz84o07tJX4fACsO02kFXuKxgz\\\\\\n\",\n              \"Zoz4P/qJEC2wIhYZGUmLli0ZMfobxk2c+s7tIsJf4tqnO8OGDmbWrFnFWENBWWFhYbmtrBs3btC6\\\\\\n\",\n              \"devc+2Z9Kq1oSZKwrteAZgPGYte+W6HbZ2Vm4D7OmfWrliOXyxkzZgyBgYFYWVnx119/5YZVYGAg\\\\\\n\",\n              
\"r1+/plWrVtjb29OmTRtq1qzJixcvePz4MY8fP+bJkye538fHx1O9enUePgrDbpYnmmUUX/pJlpHC\\\\\\n\",\n              \"nV+c8T7uRUJCAvHx8YU+tLS0MDQ0xMjI6L0BN3X6d3Sd/iuVahU+l+0tuVzOmiH23Lx25bOej/Wp\\\\\\n\",\n              \"nDsVIVpgRWzkyFH0GzjkveEFUKlyFQ4c9aVr+5Y4ODjQtGnTYqqhUJisrKw8c7NSU1NxdHRkypQp\\\\\\n\",\n              \"ODg4FNl9mf6Ly5cvk5SWjm27rgptX0q7NK37ubHwf4v588F9xo4dy/z58wkMDCQ7O5vmzZtjaWnJ\\\\\\n\",\n              \"sGHDyMzM5NmzZxw9epTly5eTmpqKlZUVVlZWWFpa0qxZMwYNGoSVlRU6OjosXbqUR0/WEnvzJGZt\\\\\\n\",\n              \"FB+CHx98CscePenSpYtC20uSRGpq6jvDLTY2lkePHhEfH0/M69foGim26sdb6urq6BtVIC4u7rMO\\\\\\n\",\n              \"MFUiAqwIPX78mBtBN9iwfb9C25uZV2T0N5NZu3YdO3eKACtJb+dmeXt74+/vz1dffYWTkxMHDhzA\\\\\\n\",\n              \"1tb2k+9CWrt+A02dXJSqp52DE4dXzqOUpgYnTpxAW1sbAwMDXr58ydmzZwkPD8fS0hIrKyvat2/P\\\\\\n\",\n              \"6NGjsbKywszMLM9xJEnCz8+Pfv36cfPmTbS1tdEppUnCLS9MW/VFTYHBIJJcTvItb6bv265w/dXU\\\\\\n\",\n              \"1ChXrhzlypWjatWq+Z7Pysri0aNHhIaGcuZ8ILLsLIX3/ZYsK7NYRowKihEBVoQ2bNiAi+tQpWb1\\\\\\n\",\n              \"uwweRgvb2sTGxhbJ0jWqSJKk3DXzypYti7Gx8UcPEEmSuHPnTm4r68GDB7lzs9zd3TE1Nf2oxytq\\\\\\n\",\n              \"9x88xGFsH6XKlNIujVk1S2pWNqV169a5rSorKyuMjIwKfc/j4uJYsGABO3bsIDk5GRsbG/bt24ez\\\\\\n\",\n              \"szONGzcmUybnpY87VZwmv3dfkiQRHfA7VtUr0aZNG6VeA/zd1ffs2TPu3r1LaGho7uPJkydUr14d\\\\\\n\",\n              \"GxsbTEwq8PJeMIbm+YPuXVLexJEQG/3JdBMLIsCK1LVr15k+e75SZQwNjbCp14CQkJDcCZtfquTk\\\\\\n\",\n              \"ZHbv3s269et58dcL9PT1SU1JwdDIiG/Gj2PEiBGFXrB/n7S0NM6cOYO3tzc+Pj5oa2vj5OTEwoUL\\\\\\n\",\n              \"sbdX7XsuZWVloqml+K3q3ypvZMK0adPo1q3w62bwd9h4e3szf/58bt++TdmyZRk8eDDz5s3LE/qS\\\\\\n\",\n              \"JBET9Qr1UvFE+qymgsMoNHXyD0eXpafw+sxWdOL/xPdiYKFBFxkZmS+oHjx4gLGxMTY2NtjY2ODk\\\\\\n\",\n              \"5MTMmTOpXbt27ofJY8eOMWPuIhp07Knwe3P71GF69+6Nnp5yt3ARio4IsCKUkpJC2bLKzxnR0NRk\\\\\\n\",\n              \"165d/PXXXxgbG+d5GBgYFPvSQiXh8uXL9O7tjG3jZkz96ReatWqLmpoakiQRcus6B3f+zv/+t5hd\\\\\\n\",\n              
\"u3bi6Oio8H5fvnyZOzcrMDCQRo0a4eTkxOnTp/nqq68++a5BRRkbGZMQE0Xlr5RbBikhJjLPyvHv\\\\\\n\",\n              \"EhMTw5w5c9i7dy+pqak0bNgwdxL0v9/DS5cucf/+faZPn86sWbMYM34CR38biG6tlpS3aYdG6bLk\\\\\\n\",\n              \"ZKSSHnaFuNDzODo68rvv5Tx3HY6Pj88TUm8fWlpauUHVsmVLxowZg7W1daEh4+joyLgJE3kafJka\\\\\\n\",\n              \"di0Lfb1pSW+45b0XPx+vQrcVio8IsCKkp6dHYsIbpcslJSaSkWHCxYsXiY2NzfNITk6mfPny+YLt\\\\\\n\",\n              \"7cPIyCjf7/T09FTqxHz16lV69OjJolWbad2+U57n1NTUsG3UDNtGzQi5eZ0Rbq5s2+rxzhDLycnh\\\\\\n\",\n              \"+vXruV2Dr169omvXrgwdOpTdu3crdFdgVdSzuyM7j+7HppXirfiXf4aSnZ5K48aNC3xekiT++OMP\\\\\\n\",\n              \"Fi5cyN27dzEwMGDUqFHMmTPnnXPcjh8/jpubG3Z2drRs2RI9PT3279nFkCFDeP06hrTwMyQlJaGr\\\\\\n\",\n              \"q0unrm0ZsmcVcXFxeHp65gmqt12Sbx/9+vXD2tpa4duv/Jumpia7d2yn7wAX+s3bSOXaDd65bXpy\\\\\\n\",\n              \"Iofmj2PYYFcaNWr0QccTioYIsCLUoUN7fI8fxb694ieRVxHhPHhwD+u6tenduzedOnXKcw1NJpPl\\\\\\n\",\n              \"jqj69yM8PJzbt2/n+31GRkaBwfa+R5kyZUok9LKysujTpy8LftuQL7z+rUGjpvy2ZQ9DhvTn6dMn\\\\\\n\",\n              \"uWGUmJiIn58f3t7enDhxAnNzc5ycnNiwYQPNmjUr0blZxSEsLIyNGzfyMuIVibHR6Bsrdv3u6rE9\\\\\\n\",\n              \"jB83Nt/7ExUVxY8//sjBgwdJS0ujefPmnD59mg4dOrx3f9u2bWPWrFn4+Pjw66+/5t6VOSsrC19f\\\\\\n\",\n              \"X/bt25enZbVzx3aWLvmFWrVq5QbVlClTsLGxoUqVKh/936ODgwM7t21lyPARNOjUh4bdXChv/n8T\\\\\\n\",\n              \"nzPTUggJOEbQ0e30d+7Jr8uWftTjC/+dmAdWhF69eoW1jQ037oShq2C/+bLFPxMb+ZKGDRty5MgR\\\\\\n\",\n              \"goOD6dy5M71798bR0fGD+t8zMzOJi4srMPQKesTExAAoFXjGxsYfZXTWwYMHWblmHVsO+Chc5oeJ\\\\\\n\",\n              \"I7C1roWRoSHe3t4EBQXRpk2b3LlZBY1I+1z5+Pjg4uqKXC6hY2CEXnljJqzchVYhf5s7gac47v4z\\\\\\n\",\n              \"oXdCMDExQZIk9u/fz//+9z/u37+PsbExo0aNYvbs2YUupSRJEsuWLWPDhg34+vqira3NqFGjMDQ0\\\\\\n\",\n              \"pFSpUly+fJnw8HBq1qyJjY0N9erVyw0sS0vLYl9a7enTp7ivXcf2HTswqlSNcgZGZGdm8OLPu7Rv\\\\\\n\",\n              \"356pkyYWGtafk0/h3KkoEWBFbPDgIaiXKs2ylesK/QT558P79HHqROD587mLhsbExODl5YWnpycX\\\\\\n\",\n              
\"LlygTZs29O7dmx49enxw94ki0tLSFA68tw9tbW2lAs/Q0BCtfw00sG/bll6uo/naqZfCdb157RIT\\\\\\n\",\n              \"h/fHdaALTk5OdOjQ4ZOcm1WUJEli+fLlzF+wiCo2jbEfNRs9E3N8V3xHZmIsQ+aswNC0Ur5yOTIZ\\\\\\n\",\n              \"13wO4r99NX4nfDE3N2fmzJl4enqSmZlJmzZtWLx4cb4bTBZ0/MjISO7cucMvv/zCnTt3qFatGmFh\\\\\\n\",\n              \"Yejq6pKdnY25uTlff/01d+7coWfPnkyaNKmo3o4P8nYNx4SEBMqUKYONjQ0VKxa+cv7n5lM5dypC\\\\\\n\",\n              \"BFgRS05OpmWrVtg2asqSFe7v/HR5NySYYS7OLFnyC0OGDClwm6SkJHx9fTly5Ah+fn40aNAAZ2dn\\\\\\n\",\n              \"evfuXeKtDEmSSE5OVirw4uPj0dXVzRNqfn5+XLz3kjJlFA8gSZJoUaciEeHheS78fykyMzMZP348\\\\\\n\",\n              \"R497U8mmGZ2/XZq7OK0kl3N53zqCvfdQo35jmnbuhZ5RBbKzMnkScp2bJw5jaVkDp65d2L17N48e\\\\\\n\",\n              \"PcLMzIzx48fz3XffFTgF5M2bN3muT70dBaihoYGGhgbq6upMmzaN6OhovLx9yMySUcWyFlqltEmI\\\\\\n\",\n              \"jebe7SBGjBjBrJk/iCHpn6BP5dypCBFgReztp9jU1FSSkpIZPHwUfV0GYWZekcyMDG5cu8IOj41c\\\\\\n\",\n              \"u3KJjRs30q9fP4X2m5GRgb+/P0eOHMHLy4vq1avnhpmq3PJBLpfz+vVrHj16RFhYGE+ePGHJkiWE\\\\\\n\",\n              \"vEhU+nqHQ6OaBN+6+cV9Yo6KisLZ2RkdHR1u3rmH2xZ/NLXyD//Pykjj4XkfnlwNICMlEQ0tLWL/\\\\\\n\",\n              \"CqNenVrcvn0bmUxGhw4dWLJkCXZ2dgCkpqby4MGDPCFV0IAKGxsbLCwsGD9+PKVKlWLDhg30H+BC\\\\\\n\",\n              \"Fpr0GjGRek1a5fl7RoX/hd+hnZzzPshRT09at25dbO+XULhP5dypCBFg/9GTJ0/Yvn07L168RC6X\\\\\\n\",\n              \"U7GiOYMHD6ZevXpIksSYMWOIi4vj8OHDhISEsH79eo4eO0Z8XBza2trUqVOXcePG4urq+sHdXjKZ\\\\\\n\",\n              \"jAsXLuDp6cmRI0fQ1dWld+/eODs706hRo2IfjCFJEklJSURFRREZGfner4mJiZiYmGBmZoa5uTn+\\\\\\n\",\n              \"p09z9tZj9PQVHx2Yk5ND89rmxLx+/UXd6uLWrVv06tWLESNG4B9wBrVqtrRynahw+Sv7N3DHeycz\\\\\\n\",\n              \"v/+Orl278vDhwzwtq1evXuUZUPH2UbVq1Tz/pmJjY3F0dMTa2po1a9bwdecuGFevxcjvF713ykfw\\\\\\n\",\n              \"5XOs/mkiZ06fpkGDd48CFIqXCLBPQFH/EW7fvs3MmbMIuhlEf5fB1K5rjZqaGk+fPObAnp1YWlrS\\\\\\n\",\n              \"uHEj/P39uXLlSr4VvSVJKpJgkcvlBAUF4enpiaenJxkZGblh1rp16/80Ak8mkxETE1NgGP37dxoa\\\\\\n\",\n              
\"Grmh9O+v//zeyMgoT52cnLrTtF0XnAcOU7he50+fYKv7Mm4G3fjg1/YpePHiBRs2buTQocPEx8Wh\\\\\\n\",\n              \"qaVFzZo1GT9uLH379s3TnXfgwAEmTpzImjVrOH36NDt27WKMRwDlDBW/LpqaEMem4e3Q1FDPXaHi\\\\\\n\",\n              \"nw8rK6tCB1T89ddfuYOMFi9ezIYNG9i+7zA/rt2j0HxFv8O7uX3Wi4uB5xWut1C0RIB9AoryjxAQ\\\\\\n\",\n              \"EMAAFxdmzvmZ/gOHoKOT91bm2dnZeB/zZPrk8SxauJApU6YUST0KI0kS9+/f58iRI3h6ehIeHk6P\\\\\\n\",\n              \"Hj3o3bs3HTt2zB01mJqaWmAI/ftrXFwchoaG7wyjf3790JbQyZMnmfH9TPb5XlA44Me69mT0iCEM\\\\\\n\",\n              \"Hz78g45Z0v6+hvUNR44epUvP/nTpNRBT80pkZ2dx/84tju3fStiDUDZv2kjPnj2ZO3cuu3fvZvPm\\\\\\n\",\n              \"zcybN48KFSpwws+PqX+EKH3sNX3tePHX8w8aEBQaGkrXrl2ZPn06U6dORZIk6trUY/C0n6nfVLFu\\\\\\n\",\n              \"wezsLMZ2bcL5s2eoW7eu0nUQPj5VCjAxD0xJDx48wGXgQDx2HaBla/sCt9HS0qJ33wHUqlOXfj26\\\\\\n\",\n              \"0KhRo2Lv55fL5cTGxiKTyWjSpAmVK1fmwYMHXLlyheHDh/PmzRt0dHSQyWQABYZR69at8/xcoUKF\\\\\\n\",\n              \"Qj+RJyYm/r0SelISZcuWxc7ODjMzM4Xr/fXXX5MzbTqH92yj32C3QrcPOOHF7ZvX8DLSp0uXLkod\\\\\\n\",\n              \"61OQlZWFk1N35Jql+eNsCGX+tXJL206OtO3kyP07txg12oVvvvkGSZKoVq0aTk5OaGtrk5mZiSwn\\\\\\n\",\n              \"54Na9ZIkfdCHjUuXLuHs7MzKlStxdXUF4Nq1a6RnZFGvSSuF96OlVQqHXq5s3ryFVatWKl0P4csm\\\\\\n\",\n              \"AkxJixf/wriJ374zvP6prnU95i1ayrx58wkIOP1Rjp+enk5UVNQ7W0pvv3/9+jX6+vr5Qql3796M\\\\\\n\",\n              \"Hz8ebW1tQkNDc2/5vUQe8wAAIABJREFUbm1tnTs8/0MWEQ4NDcXd3Z0DBw9iXa8BBgblSUtNIfhW\\\\\\n\",\n              \"EJ06dmLSpInY2xf+nqmrq3Pc6xht2tgjSRL9Bru986Ts5/UH87+fhN/Jk/j6+lK/fn2WLVvGsGHD\\\\\\n\",\n              \"8pWRJIkbN26wdt16rly5QmpKCrp6erRv15YJEyZQr149pV+zMnJycoiLi+P169fExMTkft23bz9y\\\\\\n\",\n              \"DW2Wr9323u7duvUbsnaXF6P6dUINievXr2NgYEDdunUxNTXl5KnTxL14jHG1mgrXKT78KeV0dfP1\\\\\\n\",\n              \"IBTm7eoau3fvpnPnzrm/DwsLw8q6gdIhamndgFt+h5QqIwggAkwpsbGxHD/uxbWQXxUu09O5Hwvm\\\\\\n\",\n              \"zOThw4fUrl27wG3errb+rutJ//yanp6OmZlZvi68Jk2a5AkqU1PTQhejdXZ2Bv5uNfn4+ODp6cm3\\\\\\n\",\n              
\"335Lo0aN6N27N7169Srwluz/tmPHDmbM+I5ho8cTcDkYUzPz3OeSkhLxPLCXQYOHMGzoEBYuXFjo\\\\\\n\",\n              \"Cc7KyooLFwLp2as3h3Z70G/IKDp27Y6ungFpKckEBvix6/e1pCUnUaOGBT4+Pvzyyy/069ePkSNH\\\\\\n\",\n              \"snfvXjZv3kz16tWBvyequgx0JTo6Gpeho1kzcjJly+mSlJSAv88xvu7chTp16rB3z26FW3Byufz/\\\\\\n\",\n              \"tXfncTWm///AX5XSvp1zqtOiQoVSKluKYkqW7BTJ2mLflxgGYx07H0uGEbJlzxbJViKEkaxFCSWt\\\\\\n\",\n              \"tK/n/ftjvvpp2s5RDYfr+XicB51zX/d1ndvM9eq67+u+bmRlZVUKpNTU1Crf+/jxI1RVVcHj8aCh\\\\\\n\",\n              \"oQENDQ2oqKjgYfRDHL54V6hrk4ZGLTDcaypOHPgTGzZswJMnT3DmzBm8f/8eZqYt8fDcQThOWiJU\\\\\\n\",\n              \"+wHgcchReHl6ihQ4e/bswa+//orz589XemZdSUkJpL7iJuRGjRqhuLhE5HIMw66BiWDnzp0IuXwN\\\\\\n\",\n              \"O/z3i1Ru8a9zkZmajK5du1YZSh8+fIC8vHy115O+/FNNTa1BZxUWFBTg0qVLOHnyJM6dO4dmzZph\\\\\\n\",\n              \"4MCBGDhwIIyNjSttf+zYMUybPgMHT5yHkUnVAQ0A6WmpGDmkLwYPGoBFixYJ1RaBQICrV69iy9at\\\\\\n\",\n              \"CAsLQ052NhQUFWFubo5nT58iKSkJubm56NSpE2bOnInx48ejpKQEGzZswNq1a/Hbb7+he/fu6Nbt\\\\\\n\",\n              \"F3hNng2PseOqnFhQUlKCHZtW48yJQ9gfEABJSclaAykjIwNKSkoVAunz3//9p4aGBhQUFPDmzRvE\\\\\\n\",\n              \"xsaWv65fvw6etgE27BZ+9JGZkYYB9uYwM22FgQMHwsXFBX///TcWLFiA9KyP8Np5CQpqtY+g8z9l\\\\\\n\",\n              \"ImCSC2IePhDqXqwvV9cICQmBiYlJpW3OnTuHRctX4/ddx4X+PgBw4eg+fEyIwaEDov1/xTQMdg3s\\\\\\n\",\n              \"B/X+/XsYGDYVuZxhs2YIvXAWSkpK4PP5MDc3h7Ozc4VwEuWZYQ1JTk4O/fr1Q79+/VBSUoKwsDCc\\\\\\n\",\n              \"OnUKDg4OUFNTK7/XzNLSEgUFBRg/YQIOHD9XY3gBAJengT2Bp+DcpR2GDRsGI6PaT3VJSkrC0dER\\\\\\n\",\n              \"jo6OlT6ztbVFSEgI+vXrhwsXLsDOzg66urpwcXGBr68v+vXrhzFjxuC3RYvh+9sKuI6o/nqatLQ0\\\\\\n\",\n              \"psxZCALQp29ftGzRApqamuUB1LRpU3Ts2LFCIHG53EqriAgEArx9+7Y8oG7dulX+96SkJDRp0gTG\\\\\\n\",\n              \"xsYwNjaGtbU1nj1/gR6D3Gs9Dl9S5/Bg3b4TfGdNhZSUFDw8PKCqqopjx44h5FIo9iyfgH5LdkFO\\\\\\n\",\n              \"qfrbEApzs3FqiTcmT5wgVHgJBALMnj0boaGhuHnzJnR0Kq7oUVxcjIsXLyIgIABPHt5DyrtEaOkK\\\\\\n\",\n              
\"f4Py9bNHsHrZEqG3Z5jPWICJQEpKCmVlZSKXKysrg6OjI/z8/BqgVQ1HWlq6PEC2bNmCu3fv4uTJ\\\\\\n\",\n              \"k3B1dUVpaSmMjIxgadUOrS0shdqfphYfru6jsGPHDqxfv17k9uTn55ePgjp06IBFixbhxYsXSE1N\\\\\\n\",\n              \"haWlJQYOHIhmzZohLy8PqampICIYtTCtMby+NGX2Alw4cxyrV6+u9nodESE9PR13796tMJqKjY3F\\\\\\n\",\n              \"q1evoK6uXh5SxsbGcHJygrGxMQwNDSsF3sFDh6GqVvujS/5NSVUNc+fOhbS0NFatWgUXFxdISEjA\\\\\\n\",\n              \"xsYG2TnZOOzrjvZDJ8PIxhFSX9zUXFZSjLjbV3Dn8Bbkf8xAh/btaq2ruLgYY8eOxevXrxEeHl6+\\\\\\n\",\n              \"6rxAIMCNGzdw6NAhnDhxAq1atYK7uzu4PA2EHAvAqBm/CfVd4h7/jeyMNKGfP8YwX2IBJgIDAwPc\\\\\\n\",\n              \"OHhY5HLPHseglUnzBmjRf0dSUhIdO3ZEx44dsXr1ajx+/Bh9+/XD0j82ibQfjzFe6ONoh1WrVkEg\\\\\\n\",\n              \"ECAtLU2oa0ipqakQCATloyAOh4Pnz5/j1atXaNasGczMzGBtbY0dO3bg1KlTsLKyQo+evTBouJfQ\\\\\\n\",\n              \"bZOQkMCwUT7Yum0brK2tERcXVyGgXrx4gdjYWBARTExMykPK1dUVxsbGMDIyEmlGn7yCAgry80Q6\\\\\\n\",\n              \"fgCQmZEOJycnbN68ucK1MwkJCWxcvx5d7e2xZv1G+Pv/gabWXSAtr4iS/FwkPLiBFiYm2LFpHXR1\\\\\\n\",\n              \"ddG7d29wOJxqZ8jm5uZi8ODBkJGRwaVLlyAnJ4eHDx/i0KFDOHz4MNTV1eHu7o779++Xj+QSExNh\\\\\\n\",\n              \"3a4dWre3g5Vt1xq/R3ZWBjYvnIrFixf98E8IYBoGuwYmgry8PDRp0gSXwu+gib6BUGVysrPRpqUh\\\\\\n\",\n              \"nj19WunUi7hTVlbG7ZiXUFYWbf1BM0MtlJaWoLi4uMZrSP9+T1FRscL1v6lTp0JVVRVLly4tf2/L\\\\\\n\",\n              \"li3Yvn07QkNDYWRkhPtxHyqNfGqSmZGOzpbNISUpiebNm1cYTX1+cbncerkOuWDhQrx6l47pC/8Q\\\\\\n\",\n              \"ukxhYQEG2rfGvai7aNq05tPZz549Q0RERPnztuzs7CrcaxUaGgoPDw9cvny50izMz6trmJmZYd68\\\\\\n\",\n              \"eTh69Gj5wyvd3d3h7u4OMzOzKuuNiIhAvwEDMGziPHTr5wrpKpa2invyEJsXTMFwtyFYuXKF0N+f\\\\\\n\",\n              \"aXjidA2MBZiIpk2bhsJSwvLVG4TafuvGtdixbTOaGhpixYoVP9RjGRo3bozHCR9EnoZtY2GMs2dO\\\\\\n\",\n              \"w8JC9CnXX4qJiUGPHj2QmJhY4f602bNnIyIiAvEJCbj56LVI+yQitNBWRFFRkUjB9zXevHkDc4s2\\\\\\n\",\n              \"OBUWAzkhFy8+f/Iw7l49g5CLF+qlDYGBgZgzZw4iIiIqjKJ++eUXNG3aFLm5uYiLi4Orqyvc3d1h\\\\\\n\",\n              
\"Y2Mj1Aobjx49wqQpU/H8+XP80t8dzUwt0KhRI6S9T8L1s0eQk5mOxYsXwcvTs16+B1N/xCnAfvxn\\\\\\n\",\n              \"09ez+fPn4+L5Mzi4z7/WbUOCz2HHts24fu0apkyZAh8fH/zyyy+4ffv2f9DShsfhcpHyPlmkMkVF\\\\\\n\",\n              \"RcjKzICBgUGdRzGtW7eGvr4+goODK7y/Zs0aaGpqIjcnV+R9FhUWQlpausGfSUX0z71cALDXT7jr\\\\\\n\",\n              \"gbk5n7Br0wp0dbCvt3YMHToUc+fORffu3ZGQkIBVq/55oGRSUhJ4PB4WLlyI5ORkbNu2Dba2tkKF\\\\\\n\",\n              \"FwCYm5vjRth1hF27Cr488CDkGCJOBeBTQgzWLP8drxPiWXgxdcaugYlIS0sLl0JC4NyjB2IePcS4\\\\\\n\",\n              \"SdNg2LRZhW3eJydhz64dOHxgL86cPo2WLVuiZcuWGDJkCPbu3YshQ4bA0tISy5YtE+tFTAcOGIAT\\\\\\n\",\n              \"gQcw+9fFQpe5cC4ILVq0rLfHnvj4+GDnzp3o27dv+XuSkpIIDAyElhYfj6MfwMzCSuj9RUZch3kd\\\\\\n\",\n              \"R4a1efv2LSZPnozY2FjYd+mME/t3QVVVHcM8q1+IN/vTR8yfOBxtra2wefNmZGdnY8mSJbXe61eb\\\\\\n\",\n              \"4uJiGBgYQEpKCs2bN4eUlBR8fHywevXqenmmWqtWrbBxo3BnKxhGVGwE9hVMTExw5/ZtqCsroPcv\\\\\\n\",\n              \"dnDt1xOL5s3G4vlzMHLoQDh0tERR3ifcjoxEhw4dystJS0vD29sbcXFx6NatG5ydnTF06FDExsZ+\\\\\\n\",\n              \"w2/z9SZOnIjA/XtQXFws1PZEBP8/t+L9+2RYWVkhICBA6LLVcXV1RWRkJN68eVPhfTk5OUydOhV7\\\\\\n\",\n              \"/twi0v4C9+3CpIkT69Sm6pSVlWHr1q2wsrKCubk5WrZsidTUVJiYGGOv33p4D3HC9Uvnypf3AoDM\\\\\\n\",\n              \"9FTs81uPUX3t0NmmPU4HnUJ0dDQePXoEW1vbr/pvRyAQIDw8HOPGjYO2tjbWrFkDe3t7NGrUCK1b\\\\\\n\",\n              \"t8aGDRt+ugeCMuKJBdhX0tTUxNq1a/H27VtMHO+D5oZ6MGyiDY9hbkhMTMT27dvLV4L4N1lZWUyf\\\\\\n\",\n              \"Ph0vX76Eubk5bG1t4enpicTExP/2S9RRq1atYGCgj9lTfCDMpdRd2/+H4oICxMfHY+XKlThw4AAM\\\\\\n\",\n              \"DAywYsUKpKenf1Ub5OXl4e7ujt27d1f6bMqUyQi7fBFxz58Kta/7dyPx6OE9uLm5fVVbavL48WPY\\\\\\n\",\n              \"2dnhyJEjCAoKwpUrV5Cfn4/4+HikpqZCR5uPHo4OOH3AD306mWBEbxsM7d4W7j06oPjje5w/ewab\\\\\\n\",\n              \"Nm2ElJQUNDQ0ypdzsrW1xV9//VXl8b937x5GjxkDw6bNwOXyoKOrB3MLC/D5fEycOBGGhoa4d+8e\\\\\\n\",\n              \"xo4di6CgIFy7dg1NmjTB6NGjIRAI6v0YMEy9ox+UtbX1t26C0DIzM+nXX38ldXV1mjx5Mr1///5b\\\\\\n\",\n              
\"N6lWZWVlNH/+fNLX1ycLizY0yM2dHr1MoreZhZVese8yadrs+aSvb0CJiYkV9vPo0SMaO3Ysqaqq\\\\\\n\",\n              \"0rhx4+jZs2cit+XRo0eko6NDJSUllT4LCAggvrYOBYc/oNiU/Gpfx4LDiKehSRcuXPjqY1KVgoIC\\\\\\n\",\n              \"WrBgAXG5XNqxYwc9f/6cmjdvTvb29iQjI0NcLpf27dtHpaWl5WWSk5Pp0aNH9Pz5c8rJyalx/0+e\\\\\\n\",\n              \"PCELCwsaMGAApaenExHR27dvyaZTJ9Jtok/T5i2hoGv36cr9l3Q67AHN+m0lNTFoSpZW1hQbG0t/\\\\\\n\",\n              \"/PEH6evr0/Pnz4mIKD8/nzp37kxTp04lgUBQr8eCEQ/i1HeyAPuOfPjwgaZPn07q6urk6+tLGRkZ\\\\\\n\",\n              \"37pJVcrJyaH+/ftT586dKTU1lfLy8sjT05OUVVTI1X0E7Q08RUEXr9PBE+fIe8JUUudwqLeLS43B\\\\\\n\",\n              \"nJKSQosXLyYNDQ3q1asXhYaGitSB2tjY0JkzZ6r8bM+ePaSiokojPCdQaGRMheA6dy2KBg8bRWrq\\\\\\n\",\n              \"6tWW/1rXr18nY2NjGjRoECUlJVFERARxuVzi8/kkJSVFc+bMoYKCgjrXU1hYSLNmzSIdHR06cOAA\\\\\\n\",\n              \"aWvr0FTfJXQ/IYsevsmu9Hrw+iPNX76elJVVyMjIiN69e1dhf1lZWWRubk4rVqyoc9sY8SNOfScL\\\\\\n\",\n              \"sO/QmzdvyNvbmzgcDv3++++UnZ39rZtULjExkSwsLGjMmDFUVFRU4bPU1FRatWoV/eLoSO3adyCH\\\\\\n\",\n              \"rt1o7ty5FB8fL/T+8/PzadeuXdSqVStq3bo1+fv7U2FhYa3l9uzZQ717967281evXpGTkxPJyclT\\\\\\n\",\n              \"c2MTamPVjpo2a05aWnyysLCgOXPmCN3G2mRmZpKnpyfp6upSUFAQERH5+/uTnJwcSUpKUtOmTSkh\\\\\\n\",\n              \"IaHe6iP655eKuXPnkpy8Ak2bv7TK4Pr367c//kcGhk2puLi40v6Sk5PJ0NCQdu3aVa/tZL5/4tR3\\\\\\n\",\n              \"sgD7jsXFxdHw4cNJQ0OD1q1bR/n5+d+0Pbdu3SI+n0/r1q1r8NNLAoGALl68SM7OzqSlpUVLly6l\\\\\\n\",\n              \"1NTUarfPy8sjdXX1Sqco/23GjBlkaWlJ165do8ePH1NxcTHFxsYSl8ulrKysOrc5MDCQ+Hw+TZo0\\\\\\n\",\n              \"iT59+kR5eXnUs2dPkpCQIHl5eZo8eXKF04V1UVRURGfOnKGhQ4eSsrIytW/fnpoZt6C/Ez8JFWAP\\\\\\n\",\n              \"32RTu462dOzYsSr3HxsbS1paWuUhzPwcxKnvZAEmBmJiYqh///6ko6NDfn5+lUY+/4X9+/cTl8ul\\\\\\n\",\n              \"s2fP/ud1P378mLy8vEhVVZW8vb3pyZMnVW43efJkWrx4cY37KisrI1dXV3J1daWysrLy90ePHl1r\\\\\\n\",\n              \"2ZokJiZS7969ydTUlG7dukUlJSX0559/koKCAsnJyZGKikq1QSGKsrIyCg8Pp3HjxhGHwyFbW1va\\\\\\n\",\n              
\"tm0bpaamUr9+/Wnhyk1Ch9fDN9n0x1Z/6uLgUG19UVFRxOPxKDw8vM5tZ8SDOPWdLMDEyN27d6l7\\\\\\n\",\n              \"9+5kaGhY6cJ/Q/k8WcPQ0JBiYmIavL6afPjwgX7//XfS1NQkZ2dnCgkJqTASfPToEenq6lY5meNL\\\\\\n\",\n              \"BQUFZGdnV+G04atXr4jD4Yh83bG0tJQ2bdpEHA6Hli1bRoWFhXTq1CkyMTEhNTU10tTUrPOxEwgE\\\\\\n\",\n              \"FB0dTXPnziU9PT0yNTWllStXVjoNKSsrS+Exb0QKsKiX6SQjI1PjadrQ0FDS0NCg6Ojor/4OjPgQ\\\\\\n\",\n              \"p76TBZgYun79Otna2lLLli3p2LFjFUYS9enfkzW+FwUFBeTv709mZmZkampKf/31V/lkiI4dOwo1\\\\\\n\",\n              \"GSMjI4NMTExo69at5e95eXnRggULhG5HdHQ0tW/fnrp06ULPnz+n8PBwsrGxoRYtWpChoSHx+Xxy\\\\\\n\",\n              \"cnL66sk4CQkJtHLlSjI1NSU9PT3y9fWtNkQKCwupUaNGIp0+/PxS53Dpw4cPNbYlMDCQdHR0RLqe\\\\\\n\",\n              \"yYgnceo7WYCJKYFAQMHBwWRpaUlWVlYUHBws1HWpu3fv0rhx48jZuQc5OXWnUaNGU2hoaKUQrGmy\\\\\\n\",\n              \"xvdCIBBQaGgo9ezZkzQ1NWnx4sW0ceNGcnFxEap8fHw88fl8On36NBH9Exjq6uqUlpZWY7n8/Hya\\\\\\n\",\n              \"P38+8Xg82rVrF0VHR5OLiwvp6+vTypUric/nk6qqKs2dO1fkUXJqaipt27aNOnXqRBwOh8aPH0/h\\\\\\n\",\n              \"4eG1/pISFxdHkpKSFPUyXaTw+jvxE8nKydU6XZ+IaMuWLWRkZFRr2DHiTZz6ThZgYq6srIyOHTtG\\\\\\n\",\n              \"LVu2JDs7OwoLC6tyu/DwcLKytiYDA0Na9PsKCjxxmo6cPEur128ms9bm1NzIiI4ePUpE/+1kjfry\\\\\\n\",\n              \"9OlT8vHxIRUVFWrcuDGFhoYKVS4qKoq4XC7duXOHiIjGjx9Pvr6+1W5/5coVat68Obm6ulJUVBSN\\\\\\n\",\n              \"Hj2aNDQ0aOPGjXThwgVSUlIiJSUlCgwMFLrtOTk5dODAAerVqxcpKyvTsGHD6OzZs9X+4lBYWEjn\\\\\\n\",\n              \"zp2jUaNGkbGxMcnIyBAAUlJWpq37josUYHtPhlLTZs2F/ndeuHAhWVtbf1czY5n6JU59JwuwH0Rp\\\\\\n\",\n              \"aSnt27ePDA0NqXv37nT37t3yz06dOkVcHo/2HjxK6TnFlJlXWuGVkVtC50Kukq6uHo0cOZJ4PN43\\\\\\n\",\n              \"maxRH9LS0qhDhw6kqKhITk5OdOHChVo75zNnzhCfz6dXr17RmzdvSF1dvdIoIz09ncaMGUN6enp0\\\\\\n\",\n              \"4MABmjVrFqmrq9OCBQvo48eP5dPk+Xw+PXz4sNZ2FhUV0dmzZ2nYsGGkrKxMPXv2pAMHDlQ5EkpM\\\\\\n\",\n              \"TKR169aRo6MjcblckpCQICkpKWrSpAkNGTKEDh48SHl5ebR7925ycOwpUoC5DHSj9evXC318BQIB\\\\\\n\",\n              
\"+fj4kKOjo1C3NzDiR5z6ThZgP5iioiLy8/MjbW1t6t+/Px07doy4XC5djbhTKbj+/Xr0PJ7UORza\\\\\\n\",\n              \"smXLt/4adRIdHU06Ojq0e/dusrCwoJYtW9LOnTtrvA1h+/btZGxsTOnp6TR58mSaOXMmXbt2jaZM\\\\\\n\",\n              \"mUqdu9iTiqoaderUiebMmVN+ai85OZkEAgH5+vqSrKwsderUqXw1jKp8nkE4fvx44nK51KlTp/IZ\\\\\\n\",\n              \"hJ8VFBTQlStXaMKECWRqakqNGzcun4JvZWVFM2fOpKioqCpPKebl5RGXy6M9Jy4JFV6BwTdIVVWN\\\\\\n\",\n              \"MjMzRTq+paWlNGDAAHJzc2uw66/MtyNOfScLsB9Ufn4+rVu3jlRVVWnZqrW1htfn1+HjQWTdtu23\\\\\\n\",\n              \"bn6ddezYkc6ePUsCgYCuXLlCLi4uxOPx6Lfffqt2RRBfX1+ytbWlDRs2koKiEhk0MyLPGQvJd9VW\\\\\\n\",\n              \"mvn7erLv0Y/k5OSpX7/+9P79eyoqKqK+ffuSjIwMTZgwodrZj9HR0eTr60tNmjShVq1a0YoVKyg+\\\\\\n\",\n              \"Pp4EAgG9fv2a/Pz8qHfv3qSpqUmSkpIkKSlJmpqa1LNnT9q2bRslJycL/b1DQkKIy9OggKDLNYbX\\\\\\n\",\n              \"kYs3SVOLX37aWFQFBQXUpUsXmjx5cqURrkAgoOzsbMrIyPhPZsoy9Uuc+k4WYD+wlJQUUlFVpYSk\\\\\\n\",\n              \"dKEDLC27iPT1DSqcghRH/v7+1KdPnwrvPX/+nMaPH0+qqqo0evToSjP6ysrKqLW5OWlq69KGfUF0\\\\\\n\",\n              \"5VkaXX2eXuF1KjKW3H2mka5ek/LrT/7+/pXq/zyD0MzMjPT09Gju3Ll0584dunbtGs2YMYPatGlD\\\\\\n\",\n              \"srKyJCUlRdLS0tSiRQvy8vKiixcvUl5eXp2++7lz54jD4dJg99EUeCGiQnCduHyXho32IXV1Dh0+\\\\\\n\",\n              \"fLhO9Xxecmr58uVE9M+pzvnz5xNPQ4MUFBRIWUWFZGVlaejQYXTjxg2xuZ76sxOnvpM9kfkHtmvX\\\\\\n\",\n              \"Lly+eh07dgeIVG7V8iWgkkKsXbu2YRr2H8jLy4OOjg68vL3x6FEMcnNzoaSkhK4O9hgwYABOnDiB\\\\\\n\",\n              \"bdu2oWXLlpg5cyZ69OiBP3fuxB9r1mPTwXNQVefWuP+TB3Zh94bluBB8Hvb2/zxgMj09HUePHsWh\\\\\\n\",\n              \"Q4fw7NkzODs7o0mTJoiNjcWdO3eQkpICCQkJKCgowNzcHD179kSfPn1gamoq9IMihZWSkoK//voL\\\\\\n\",\n              \"O/7cCQlJSSgrqyA3NwdFhYXw8faCj48PdHV161zP+/fv0alTJ1hYtEH4jXAMGOIO99HeaG5kAgD4\\\\\\n\",\n              \"9DELJ44cxME9f8LQQB/Hjx+vt2fBMQ1DnPpOFmA/sJUrVyIt8xMWLV0pUrn9+/xx8sghTJo0EQoK\\\\\\n\",\n              \"ClBQUIC8vHyFPxUUFCArK1vvHW99yMrKwqRJk3H6zGk4uwxEN2cXKCopIyf7E66FnMOVi2cxcNBA\\\\\\n\",\n              
\"rF+3DsHBwdiwYQNyc3ORlPweWw5fgKFxS6HqWT7DE93tbaCnp4d9+/YhIiICBgYGKC4uRlJSEkpL\\\\\\n\",\n              \"SyEQCKCtrQ0bGxv069cPXbt2BZ/Pb+Aj8P+VlpYiISEB2dnZUFJSgqGhIaSlpeu1jvHjJyD0ylUE\\\\\\n\",\n              \"nr4EDU2tKrcpKyvD0l9n4fHDewgLu86eN/YdE6e+kz2R+QcmLS2NkpISkcuVFBcjKioKnp6ekJGR\\\\\\n\",\n              \"gZSUFCQlJUFEKCsrQ0lJCQoLC1FcXAw5ObkKoVZV0In63ue/y8vLixyQaWlp6GLvAKsOnXH57gso\\\\\\n\",\n              \"KVf8bb9r996YtXAF1q9YCEen7rh29Qo0NTUxePBgNGlqJHR4AUA/D2/M83ZDWWkJSktLIS0tjVev\\\\\\n\",\n              \"XqFly5aYNGkSevbsiQ4dOnzTzrpRo0YwMjJqsP2HhITgYsglnAoJhzqn+lGrlJQUlvyxEXOmeGP2\\\\\\n\",\n              \"7Nnw8/NrsDYxPw8WYD8wIyMjnD5zVuRyDx/cx5w5szFo0CC8f/8eycnJ5a8vf37//j1kZGTA4XDA\\\\\\n\",\n              \"4/Ggrq4OVVVVKCsrQ0lJCXJycpCVlYWMjAyKioqQl5eHtLQ0vH79Gvn5+cjLy0NeXl753//9XkFB\\\\\\n\",\n              \"ARo3bix0EMrLy+PQ4UDYO7lg5oJl1X4/FTV1/L52G1b9Ngutzc3x6eNHCAAMGjlOpONkZtUBCkrK\\\\\\n\",\n              \"aG6oD3d3dzg4OKBVq1aQkpIS+ZiLq42bNmHyrHk1htdnEhIS8F20Ao6dLLBq1Sqoqqr+By1kfmTs\\\\\\n\",\n              \"FOIPrKSkBPr6+jh+5iJatjIVqszHrCxYmRnhxYsX0NDQqHFbgUCAzMzMKsPty59TUlKgrKwMPp8P\\\\\\n\",\n              \"bW3t8teXP/P5fPD5fMjIyFTYf2FhYa1B9/nPmJgYhEfcwulr94UauZWWlqJbWyPkZn+CvIIiVu48\\\\\\n\",\n              \"gpbmVkIdp8/mjh0ES1NjWFpaQlpaGjIyMpCWlq70EvV9aWnp7/L07Jfi4+PRrl173Ip+CVk5OaHL\\\\\\n\",\n              \"TfUZAQc7G8yYMaMBW8d8LXHqO9kI7AcmLS0NLy8vbF6/Bn5/7YWEhEStZbZv3QQOhyvUI+UlJSXB\\\\\\n\",\n              \"5XLB5XJhbm5e7XYCgQAZGRmVwu3p06e4fPlyedB9+PABKioqlcLt3z9raWlVeR1nwMCB8PCcJHTH\\\\\\n\",\n              \"36hRI4zynozrF0/j7bskCMrKhCr3pdLSEqSkpCA6OholJSUoLi5GSUlJpZeo75eUlEBKSqrOQdiQ\\\\\\n\",\n              \"71+6dAk2dvYihRcA/OLsgvDQsyzAmDpjAfaDmzVrFmzt7LD2j+WYM29hjSF29PBBHNy3BwMG9Ie5\\\\\\n\",\n              \"uTlWrlwJT09PoYKvJpKSkuDxeODxeLCwsKh2O4FAgLS0tEojuZiYGISEhJQH34cPH6CmplYh3LS0\\\\\\n\",\n              \"tHDu3DnMW7lNpLb1cx2BretWoLGsLF49fwxTy3ZCly0rK0PKuzc4uGcXWrVqJVK9tSEilJaW1jkI\\\\\\n\",\n              
\"a3s/Ly/vq/fz4cMHdLC1F/m7KSkpIycnp16PF/NzYgH2g1NRUcGlkBD06NkTT2IeYeLUGWjfwaZC\\\\\\n\",\n              \"KD198hi7dmzD1dCLuHQpBGZmZvDx8YG3tzcOHjyInTt3NuhEgM8kJSWhqakJTU1NtGnTptrtysrK\\\\\\n\",\n              \"kJaWVuE05cuXL/+5XqaoJFKdHC4PRAJIgHB83w70GTpa6MC+ExYKXR3teg8v4J/rRZ9HOt+rQ4cO\\\\\\n\",\n              \"4fDRkyKXy/70EcrKyg3QIuZnwwLsJ6CtrY1bN29i586dmDxuLGRl5dCylSkkpaSQ8Ool3r19A29v\\\\\\n\",\n              \"b9y7d6/8upeFhQUiIyOxZcsW2NjYYNasWZg9e/Z30aFmZ2cjPj4e0dHRuH37NmJiYpCQkIDiomKR\\\\\\n\",\n              \"9yUQCCAQCNBaJ9AuAAAbb0lEQVSsWTPEJ7zG33duwKpjF6HKHdixAX17OH7NV/gh2NraYvKUKcjP\\\\\\n\",\n              \"y4O8CDMtQy+cRQ+nrg3YMuZnwSZx/GQEAgFu3ryJxMRECAQC8Pl8ODg41BhMr1+/xoQJE5CcnIxd\\\\\\n\",\n              \"u3ahffv2Dd5OIiq/Tnb79m1ERUXh2bNnSEpKQlFREaSkpFBaWgoulwtDQ0OYm5vj+ImT2H0kGEYt\\\\\\n\",\n              \"hB8RRT+4iwkeA+Dntx0KCgoYNXoMNh08D4PmJtWWEQgE8Fu1AC8e3sHHrEz07NkTa9eu/Slv0O3T\\\\\\n\",\n              \"ty/suvXE0BFjhdr+w/tkOHexRuLr12wU9p0Sp77z+57mxHw1IsKtW7cw3MMDOjo6UFJSAp/Px4CB\\\\\\n\",\n              \"A1FQUAB3d3eMHDkSTk5OtY6qDAwMEBwcDF9fX/Tt2xczZsxAbm5uvbSzrKwML1++xPHjxzF9+nQ4\\\\\\n\",\n              \"ODhAV1cXjRs3hr6+PpydnbFy5Uo8efIERkZGmDVrFoKCgvD06VMUFRXhw4cPuHr1Ktq3bw9Z2cY4\\\\\\n\",\n              \"6L9dpPr3/bkFcnJySElJgZ+fH4yMmmPO6P4IOrgbebmVr9M8j3mAJVNG4V3cY4SHXceTJ08gKSkJ\\\\\\n\",\n              \"MzMznD9/vl6OiTiZMX06tm1cjdQPKbVuKxAIsHyRL0Z4eLDwYuoFG4H9gN6+fYvBQ4YgIyMDXt7j\\\\\\n\",\n              \"0bf/QKiqqSEvNxchF4Oxc8d2FBcV4vjx4zAzMxNp3+np6Zg1axbCw8Ph5+eHHj16CFWusLAQL168\\\\\\n\",\n              \"QEREBO7cuYMnT54gMTERWVlZkJCQABFBTU0N+vr6MDU1RceOHWFtbQ0TE5Nq7xd69eoV/Pz8sHfv\\\\\\n\",\n              \"XnTq1Alubm6YMHEiTobegZZ27cskvUl4hWF97OG/ezcmTZqEzMxMjBw5Evb29jgVdBpXrlyBjUN3\\\\\\n\",\n              \"KKtzUVJUhId3I1BWUohJEyZg6tSpkPti9t3Vq1fh5eUFOzs7bNy4ERwOR7gD+gNYunQpDgUegf/h\\\\\\n\",\n              \"IOjoNqlym5KSEiycPRlvE17i8uXQCseO+b6IVd/5DdZf/E+I04KU9SkxMZH09PRoxao1lF1QSrlF\\\\\\n\",\n              
\"gkqvnMIy+mtPAGloaAj17KqqhISEkKGhIQ0fPrzC40A+fvxIV69epcWLF5OLiwsZGxuToqIiSUhI\\\\\\n\",\n              \"kISEBDVu3JiaNGlC3bp1o5kzZ9KJEyfo1atXQq9aXlZWRsHBwdSrVy/icrk0Z84cio+Pp+TkZHJ2\\\\\\n\",\n              \"diZpGRnS0dOnq/fiKOZtTrWvkMgnpKdvQJs2bSZ7e3saMWIEvXv3jlasWEF6enrUsWNH2rRpE23d\\\\\\n\",\n              \"upUMDQ3J09OTGjduTAUFBdW2LTc3l6ZNm0Z8Pp+OHz/+VcdVHAkEAlq3bh2pqKrS0BFj6NyVSIpP\\\\\\n\",\n              \"zaeEtAK6+ziB5iz4nXT1mlCfvv2EevIz822JU9/JAuwHIhAIyMrKilatXldlcP37te9gIOnp6dX4\\\\\\n\",\n              \"nKzq6nn37h35+/tTmzZtSFpamtTU1EhaWro8qFRUVMjU1JQGDRpEf/zxB926dYs+fvz41d8tMzOT\\\\\\n\",\n              \"NmzYQM2aNSNLS0vy9/enzMxMOnr0KDk5OVGjRo1IU1OTAgMD6felS4nD06B5v6+lW0/eVQiuiJg3\\\\\\n\",\n              \"NGfRKuJpaJKysjJZWlqSl5dXhedalZaW0unTp6lHjx7E4/HKH2LZunVroVbpv3nzJpmYmNDgwYMp\\\\\\n\",\n              \"JSXlq7+zuElJSaHly5eTvr5B+Sr7SkpK5OnpSffv3//WzWOEJE59JwuwH8jVq1epVStTyiksEyrA\\\\\\n\",\n              \"cosE5OTcg/bt21fl/kpLSyk6Opo2btxIw4YNozZt2hCHwyFJSUkCQDIyMqStrU3m5ubE4XDI1NSU\\\\\\n\",\n              \"rl27Vq/PgHr48CF5e3uTqqoqubu7082bNykiIoLGjRtH6urq1LZtW1JXV6eZM2eWP4/r+vXrpKio\\\\\\n\",\n              \"SGpq6iQnL0+duzpS7/6DqUtXJ1JRUSV39+F06dIl0tfXJxUVlQojyH+Li4sjPT09UlZWJj09PfL2\\\\\\n\",\n              \"9hbq+xUUFNC8efNIQ0ODDhw48NM9SqSkpKTG0Srz/RKnvpNdA/uBDB4yBLadHeAzfqLQZYLPncXa\\\\\\n\",\n              \"P5Zj2bJlCA8Px4MHD/Dy5UukpKQgLy8PAKCgoAA+nw8jIyNYWVnB3t4e7du3r3AhvqSkBBs2bMDa\\\\\\n\",\n              \"tWsxf/58TJs2DY0afd1dGiUlJTh16hS2bt2K+Ph4jB8/Ht27d8fFixcREBAAaWlpjBw5EkVFRfDz\\\\\\n\",\n              \"88Nff/2FPn36gIhw8uRJDB8+HAYGBti9ezeaNm2KqKgo5OTkQFlZGR07dgQAODk5wcnJCZKSkrh1\\\\\\n\",\n              \"6xZCQ0MhKytbZXu6du0KX19fnDx5EidOnICSkhLGjRsHT0/PWpfbunfvHsaOHQt9fX3s2LEDOjo6\\\\\\n\",\n              \"X3VMGOa/IlZ95zcO0AYjTr9F1BdlZWVKePdB6NFXbpGAPuWXkLS0NElJSRGPxyMrKyvy8PCg//3v\\\\\\n\",\n              \"f/T48WORHxkfFxdH3bp1I2tra3rw4IFIZZOTk2nJkiXE5/PJ3t6e9u7dSzt27KDOnTsTj8ejKVOm\\\\\\n\",\n              
\"UFRUFH38+JEGDx5MVlZWFB8fT0REYWFhZGNjQ1paWtSmTZtq252cnEytWrWihQsXkkAgoLKyMnJ1\\\\\\n\",\n              \"daVhw4ZVW8bBwYGuXr1KcXFxpKurS/fu3SNPT8/yUWFtD2ssKiqiJUuWEJfLpV27dv10ozFGvIhT\\\\\\n\",\n              \"38kC7AchEAgIQLUTN2p6aWtr05s3b+q1Lf7+/sTj8Wju3Lk1PmFYIBDQlStXyMHBgeTl5cnZ2ZmW\\\\\\n\",\n              \"Ll1Kbm5upKysTAMHDqSgoCAqKioiIqLo6GgyMjKi8ePHU0FBAUVHR1OvXr3IwMCANm7cSGpqahQb\\\\\\n\",\n              \"G1tlXW/fviUjIyNatmxZhffz8/PJxsaGFixYUGW5zwEmEAiIx+OVH6vMzEzatGkTGRsbU+vWrcnP\\\\\\n\",\n              \"z4+ys7Or/a7R0dFkbW1Njo6OlJCQUNMhZJhvRpz6ThZgPxB5eXl6n/5JpPDKKSwjVVVVysjIqPf2\\\\\\n\",\n              \"pKSkkJubGzVr1owuX75c4bO8vDxatWoV8bW1SV5eniytrMm5Ry+ysbUjZWVlam1uTgEBARVGK3v2\\\\\\n\",\n              \"7CEul0sBAQEUHx9PHh4epKmpSZs3b6bCwkLq378/LVmypMq2JCQkkKGhIa1bt67Kz1NTU6lZs2a0\\\\\\n\",\n              \"e/fuSp99DjAion79+lFgYGCFzwUCAV2+fJkGDhxIampqNHHiRIqJiamynpKSElq9ejVxOBzasmWL\\\\\\n\",\n              \"yCNchmlo4tR3sgD7gXSxt6cDgcdECrDrEbfJ0NCwQTvSs2fPkp6eHo0ZM4bu3btHs2bNIiUlJVJU\\\\\\n\",\n              \"VKLJU2fQo6cvKadQUP5K+5hPu/wDyMysNY0YOZI+ffpEXl5eZGJiQmFhYTRlyhRSV1enJUuWlI94\\\\\\n\",\n              \"zp49S82bN69y4kBsbCw1adKEtm7dWmM7nz9/ThoaGhQaGlrh/S8DbPXq1TR16tRq9/Hu3TtavHgx\\\\\\n\",\n              \"8fl86tKlCwUGBpaPHv9dV6dOncjOzo5evHhR6zFkmP+KOPWdLMB+IIGBgeTQtZtIATZ8xChavXp1\\\\\\n\",\n              \"g7arrKyMjh8/TgYGBiQhIUFGRkakqKhEIVfCKwTXv18pGTn0i1N30tLSogEDBpCvry+pq6vTtGnT\\\\\\n\",\n              \"6MOHD+X7z8vLIwMDA7p06VKlup8+fUo6Ojq0c+dOodoaFhZGPB6vwgjqywC7ceMGtW3bttb9FBcX\\\\\\n\",\n              \"07Fjx6hr166kpaVFCxYsoMTExArblJaW0ubNm4nD4dCaNWvKZ1EyzLckTn0nC7AfSFFREWlra1PQ\\\\\\n\",\n              \"2QtChdfNO/dJTU2N0tLSKuwnPz+fXr16Rc+fP6f09PSvbs+X924ZGRlR586dSVFRkRQVFenwsVM1\\\\\\n\",\n              \"hteXIaavb0Cqqqo0YsSIKq8dzZ8/n9zc3Cq9Hx0dTXw+nwICAkRq94EDB0hfX5+Sk5OJqGKA5efn\\\\\\n\",\n              \"k7y8POXm5gq9v6dPn9LUqVNJXV2d+vbtSxcvXqww4n316hV17dqV2rVrV+2pR4b5r4hT38kC7Adz\\\\\\n\",\n              
\"48YN4vF4dO7i5RrD69bdB6StrU3Hjh0rLxsdHU0+Pj6kqqpKTfT1qVnz5qSsrExdu3Wj48ePU3Fx\\\\\\n\",\n              \"sVBt+LwfJSUlMjMzIz6fT6amprRmzRo6evQotTI1o+yCMqECLKdQQJu2+NEvvzhWWdfTp0+Jw+FQ\\\\\\n\",\n              \"UlJShffv3btHmpqadOTIka86jkuXLiVra2vKzc2tEGBERB06dKDr16+LvM/c3FzatWsXtWnThpo1\\\\\\n\",\n              \"a0br1q0r/wVBIBDQn3/+SVwul5YuXVrtsU5LS6M//viDHLp2pTaWltTJ1pamTJlCT58+/arvyTD/\\\\\\n\",\n              \"Jk59JwuwH9C1a9eIx+PRELdhFHrtRoUbmyOj/qYxnt6krq5e3rmXlZXRjBkzSFtbmxYu/p1evk4q\\\\\\n\",\n              \"3z4zp5D27D9ENp1sqU2bNpWC4rPi4mI6cuQIdezY8Z8AbNKEuFwuTZ8+nR48eFA+GcPVzY02bN4q\\\\\\n\",\n              \"dHjlFArofXo2qampVapbIBCQg4MDbd68ucL7kZGRpKGhQadOnfrqYygQCGj06NHUt29fsre3rxBg\\\\\\n\",\n              \"M2bMoJUrV9Zp35GRkTRixAhSVVWlUaNG0Z07d0ggENCbN2+oZ8+eZGFhUWH1ivz8fPLx8SEVFRXy\\\\\\n\",\n              \"GDmaTp0JpvBbd+lC6DWaO38BaWpqUrdu3djsRqbOxKnvZAH2g/p8+s7IyIg0NTXJ2MSEdHR0SFdX\\\\\\n\",\n              \"l5YuXUrv378non8604kTJ1InWzt6m5JR42zFxUuXU/PmzSucckxOTqZFixaRuro6cblckpeXp8GD\\\\\\n\",\n              \"B9O5c+eqHEWYW1hQxO37IgVYTqGAbDrZUnh4eIV9BQQEkKWlZYVrR+Hh4cTj8ej8+fN1PoZFRUXU\\\\\\n\",\n              \"rVs30tXVrRBgx44dIxcXlzrvn+ifEdWaNWvI0NCQrK2t6a+//qLc3FwKCAggHo9H8+fPp4yMDOrc\\\\\\n\",\n              \"uTMNHuJGiclpVf77ZOYU0vKVq0lbW5tNCmHqRJz6ThZgP7iysjJ6+/YtPXnyhBITEytNFDhz5gyZ\\\\\\n\",\n              \"mLSgpNQsoa6bTZsxi9zchtKNGzeoe/fuJCMjQ7KysmRpaUk7d+6krKysGtvTokULuvsgRuQAs3fo\\\\\\n\",\n              \"VmF2YGZmJmlpadGdO3fK37t8+TLxeLxKswjrIisri+Tl5Wny5MlE9E/gHzt2jFRUValLF3uy69yZ\\\\\\n\",\n              \"XN3c6MyZM3VaQuvzIsUuLi6krq5O06dPpxs3btCAAQNIQ0ODhrgNFeoev61+O6lp06Y13nvHMDUR\\\\\\n\",\n              \"p76TBdhPzsmpO/21J0DoWYtJqVkkJydH0tLSxOFwyNfXl16+fCl0fbZ2dhR09qJI4ZVdUEZGRsYV\\\\\\n\",\n              \"VvYYP348jR8/vvzn8+fPE4/Ho7CwsHo9PkREHTt2JA6HQ0uWLKFWrVpRy5ataPW6DXTu4mUKvnSV\\\\\\n\",\n              \"tvrtpHbtO5C+vj7t37+/zvUlJCTQ/PnzSUNDgzp16kTKKiqUmpUr9L+Rc89e5O/vXw/fnPkZiVPf\\\\\\n\",\n              
\"yQLsJxYXF0c8Ho/SP+WLPPXe09Pzq5ZEWr9+PbkNHS5SgF0Nj6SmTZuWz9y7c+cOaWlpUWZmJhER\\\\\\n\",\n              \"BQUFEY/Ho8jIyHo9Pp85ODjQ2LFjSVlZmU6dCa52seRrNyLJwNCQ1qxZUy/1FhYWUp8+fWnCpCki\\\\\\n\",\n              \"/fscP3VWqKn+DFMVceo72ROZf2L379+HXWf7ahexrU4vlz5IS0uHhISEyHWOGTMGIRfPIy01Vegy\\\\\\n\",\n              \"mzeuA4fDQXJyMkpLSzF+/HisWbMGampqOHr0KMaNG4cLFy6UL9Rb37Kzs3H69BmEXo+Ak3OPar93\\\\\\n\",\n              \"u/YdEHr1BrZu24agoKA619u4cWM8f/4MI0ePFamck3MPvH79GklJSXVuA8N8z1iA/cTy8vIgr6Ag\\\\\\n\",\n              \"cjlFBcXylepFpaamBi8vL3iO9kBxcXGt2588fhQ3b4ShdevWsLCwQM+ePaGgoAAPDw/s378f06dP\\\\\\n\",\n              \"x6VLl2Btbf1V7RFGaloaFixaAlPT2p9ezdfWxsbN27B8+XJQPTzoITMzE5pafJHKSElJQVNTCxkZ\\\\\\n\",\n              \"GXWun2G+ZyzAfmIqKirI/IpOLiMzAyqqKl9d76pVq6CmpoJB/Xoj6d27KrcpKSnBju1bMGXiOLQ2\\\\\\n\",\n              \"b4OgoCCYW1ggLCwMz549w/DhwzFv3jxcuXIF5ubmX92W2iQlJSErMxPDho8QuoyTcw9kZGYiKiqq\\\\\\n\",\n              \"zvU3btwYJUIE/b8VFhWicePGda6fYb5nLMB+Yl26dEHkrQhkZWWJVC7oxHE4OTp+db2NGjXC0SNH\\\\\\n\",\n              \"0MmmAzq1bwN3t0E4feoEIm/dxPWrV7BsyW9oZWSA06dO4nrEbZy9cAlP415j4GA3cDgctG7dGqdO\\\\\\n\",\n              \"nYKMjAyePn1aLyOd6pw5cwY9e7lASUlJ6DKSkpIYNnwEjh8/Xuf6W7RsicjImyKVeZ+cjIz0dOjq\\\\\\n\",\n              \"6ta5fob5nrEA+4nxeDz0dnHBwYC9QpdJevcON8Kvw8PDo051S0lJYcWKFUhMTESvHt2xavkSeAwb\\\\\\n\",\n              \"jFUrl+LTp484HRyCi5evwdjEBACgqKiIsV4+iLh9H/EJrzF37lz8+eefWLp0Kezs7HDr1q06tac6\\\\\\n\",\n              \"aWlpMDBsKnI5HR1dpKen17n+8ePGYffOHSKV2eO/C25ublD4itPDDCNOWID95KZPm4ZNG9bi1cuX\\\\\\n\",\n              \"tW5bVlaGGVMnYezYsVBUVKyX+hUVFeHs7IyUlBRERj3EpSthWL9pS7XXm7T4fFy4dBWbN2+GpaUl\\\\\\n\",\n              \"Hjx4AB8fH7i5uWHQoEGIi4url3Z91rhxYxQWFopcrqioqF5O4fXv3x8v42JxIzxMqO3T0tLgv+tP\\\\\\n\",\n              \"TJwo/FO5GUZcsQD7ybVt2xZLliyBS09HPI55VO12eXl5GO0xDMXFhVi5cmW9tmHHjh1w9xgJTU1N\\\\\\n\",\n              \"obbXNzBAn34D4O/vDykpKYwaNQqxsbFo164dbGxsMGXKFKSlpdVL21q0aIGbN4QLjy/di7oLY2Pj\\\\\\n\",\n              
\"OtcvLS2NPXv2YLTHUPz94H6N22ZkZGBwfxeMGTOmQa8LMsx341vP428o4nQvw/fgwIEDxOFwqFfv\\\\\\n\",\n              \"PnTy9HlKePeBklKz6Pb9aJoybQZxOBwaNXo0FRYW1mu9ZWVlxOPx6OHjFyLd6xR28w41b9680v5S\\\\\\n\",\n              \"U1NpypQpxOFwaMWKFXVekaK4uJiUlJTpzoNHQrftdVIqqaqq1mkl/387efIkcblcmjl7Lj15EV/p\\\\\\n\",\n              \"5vJ1G/9H+gYGNHv27K+6P49hPhOnvlOCqAGvgH9Dbdu2xb179751M8RKXl4eAgMDsXPXLryMi0Nx\\\\\\n\",\n              \"cTF4PB4GDx6MCRMmwNDQsN7r/PjxI/T19ZGc9lGkcqWlpeAoy6G4uBiSkpVPJMTFxeHXX3/F7du3\\\\\\n\",\n              \"sWzZMowYMQJSUlJf1cZmzZrB0rot9u4/LNS9b4sXzkd62gfs3bPnq+qrzsuXL7Ft2zbs378fzY2M\\\\\\n\",\n              \"weFwkZ+fh4d/P4CjkxMmT5oEe3v7eq2T+fmIU9/JAoz5ptLS0tCiRQu8eS/ahAcigqqCDAoLC9Go\\\\\\n\",\n              \"UaNqt4uMjMTs2bORm5uLNWvWwNnZWeQ2du7cGWlpaejTfyAW/768xhDbvetPbFj7ByIjI8Hni3b/\\\\\\n\",\n              \"lrDy8/Nx9+5dfPr0CfLy8mjdujW0tLQapC7m5yNOfWf1/+czzH9AVVUV+fn5+PTpE1RUhL+3LDkp\\\\\\n\",\n              \"CcrKyjWGFwDY2NggIiICQUFBmDJlCgwMDLBmzRq0adOm1joyMzMREhKClJQU/PLLLzgbdBIx0dGY\\\\\\n\",\n              \"Mn0m7B26Vgiy+/ei4LdtC6LuRCI0NLTBwgsA5OXl4eDg0GD7ZxhxwSZxMN+UtLQ0+vTti8MH94tU\\\\\\n\",\n              \"LmCfP1xdXYXaVkJCAgMGDMCTJ0/Qr18/9OjRA6NGjcLbt2+r3D4mJgZjxoxB06ZNcTjwCGw7d0Fe\\\\\\n\",\n              \"QSEACTx98hjeY0eijakJhg0ZiOFug2HTtg1GewxFG3MzREVFwcjISKTvwjDM12EjMOabmzhhAiZM\\\\\\n\",\n              \"mAjvcROEuk5VXFyMPbt34UJwsEj1SEtLY9KkSRgxYkT5KMzHxwfz5s0rH/0FBQXB29sbU2fMQsyz\\\\\\n\",\n              \"OPB4vPLyRIQb4WFYvWoFcnNyMGTwQMjKyoLP58PGxuarr7ExDPN12DUw5psjIjg6OqJFKzOsXrex\\\\\\n\",\n              \"xmtMAoEAE3w8UZCfixN1XOni3bt3WLx4Mc6dO4cFCxagRYsWGDFiBE6dCYZVDWsrlpWVYfKEcXj3\\\\\\n\",\n              \"7g2Cz5+HtLR0ndrBMN8Tceo72SlE5puTkJDA8ePHcftWBCaO80RqNSvVJyclYbTHMLx5HY+Affvq\\\\\\n\",\n              \"XK+uri52796Ny5cvIzg4GK6urvDbubvG8AL+WUVky/YdyM3Nq5flohiG+ToswJjvgpqaGsLCwiAr\\\\\\n\",\n              \"Iw2r1i0wdpQHjh45jIvB53Hk8EF4DB2CDtbm0NPVxqVLl+p1maTWrVvD19cXmppa6Nmrt1BlGjVq\\\\\\n\",\n              
\"hGkzZmH79u311g6GYUTDTiEy352srCzs2bMHkbdvIyc7B8oqyrDv0gUjR44UaVFdUXh4eMCqbQdM\\\\\\n\",\n              \"nDxF6DKlpaUwaaaPK1euoEWLFg3SLob5r4lT38kmcTDfHTU1NcycOfM/rfPlq1cY6z1BpDKNGjWC\\\\\\n\",\n              \"qVlrxMfHswBjmG+AnUJkGABlpaW13lNWFWlpaZSWljZAixiGqQ0LMIYBoKGpicTE1yKVISK8fp0g\\\\\\n\",\n              \"9CLEDMPULxZgDANgqJsbAvb6i1TmXlQUCvLz0a5duwZqFcMwNWEBxjAAhgwZguiHfyP2xQuhy+zY\\\\\\n\",\n              \"vhXjx0+ocjFhhmEaHvs/j2EAyMrK4tdfF8BjmCs+fqx9ZfyD+wNwMyIcXl6e/0HrGIapCgswhvk/\\\\\\n\",\n              \"06ZNhaOjI35xsMO9qKgqt8nJycHqVSuw+LdfERwcDHV19f+4lQzDfMam0TPM/5GQkMD69etgstME\\\\\\n\",\n              \"I9xdweXyMGz4CGjr6KCwsBC3I2/haOAh2Ds44ObNm9DX1//WTWaYnxoLMIb5goSEBMaN84GXlycu\\\\\\n\",\n              \"XryIkydPITzsKho3boxWLVvh0aNH0NXV/dbNZBgGLMAYpkpSUlLo3bs3evcWbmkphmH+e+waGMMw\\\\\\n\",\n              \"DCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokFGMMwDCOWWIAxDMMwYokF\\\\\\n\",\n              \"GMMwDCOWWIAxDMMwYkmCiOhbN6IhcLlcGBgYfOtmMAzDiJXXr18jPT39WzdDKD9sgDEMwzA/NnYK\\\\\\n\",\n              \"kWEYhhFLLMAYhmEYscQCjGEYhhFLLMAYhmEYscQCjGEYhhFLLMAYhmEYscQCjGEYhhFLLMAYhmEY\\\\\\n\",\n              \"scQCjGEYhhFLLMAYhmEYscQCjGEYhhFLLMAYhmEYscQCjGEYhhFLLMAYhmEYscQCjGEYhhFLLMAY\\\\\\n\",\n              \"hmEYscQCjGEYhhFLLMAYhmEYscQCjGEYhhFLLMAYhmEYscQCjGEYhhFLLMAYhmEYscQCjGEYhhFL\\\\\\n\",\n              \"LMAYhmEYscQCjGEYhhFLLMAYhmEYscQCjGEYhhFLLMAYhmEYscQCjGEYhhFLLMAYhmEYscQCjGEY\\\\\\n\",\n              \"hhFLLMAYhmEYscQCjGEYhhFLLMAYhmEYsfT/ABs9BDmEtprEAAAAAElFTkSuQmCC\\\\\\n\",\n              \"\\\"\\n\",\n              \"  frames[5] = \\\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAbAAAAEgCAYAAADVKCZpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\\\\\\n\",\n              \"AAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0\\\\\\n\",\n              \"dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd1RUx+P38TdNRKVIkWIlYENUsIu9\\\\\\n\",\n              
\"F1ABG3ZFRY295GtJLFFj1NixG+xdg4qgIlbsFVQsETUWEJAivS57nz/yyC8ElF0j4uq8zrkHZO/c\\\\\\n\",\n              \"O7vg/ezMzsxVkyRJQhAEQRBUjHpRV0AQBEEQPoYIMEEQBEEliQATBEEQVJIIMEEQBEEliQATBEEQ\\\\\\n\",\n              \"VJIIMEEQBEEliQATBEEQVJIIMEEQBEEliQATBEEQVJIIMEEQBEEliQATBEEQVJIIMEEQBEEliQAT\\\\\\n\",\n              \"BEEQVJIIMEEQBEEliQATBEEQVJIIMEEQBEEliQATBEEQVJIIMEEQBEEliQATBEEQVJIIMEEQBEEl\\\\\\n\",\n              \"iQATBEEQVJIIMEEQBEEliQATBEEQVJIIMEEQBEEliQATBEEQVJIIMEEQBEEliQATBEEQVJIIMEEQ\\\\\\n\",\n              \"BEEliQATBEEQVJIIMEEQBEEliQATBEEQVJIIMEEQBEEliQATBEEQVJIIMEEQBEElaRZ1BQqLsbEx\\\\\\n\",\n              \"lSpVKupqCIIgqJTnz58TExNT1NVQyFcbYJUqVeLmzZtFXQ1BEASVUq9evaKugsJEF6IgCIKgkkSA\\\\\\n\",\n              \"CYIgCCpJBJggCIKgkkSACYIgCCpJBJggCIKgkkSACYIgCCpJBJggCIKgkkSACYIgCCrpq53ILAiC\\\\\\n\",\n              \"8LmkpaXlrF5hYmJC8eLFi7hG3wbRAhMEQfgIkiRx6dIlern1wcjYBPv6DbGr1wBDY2P69h/A9evX\\\\\\n\",\n              \"kSSpqKv5VRMtMEEQBCWlpqbSu09fbgXfw97RjUm7zqOjqw9ASkIcQf5/0MW1J80cGrNz+1bRIisk\\\\\\n\",\n              \"ogUmCIKghKysLDo7dSEiDUZsOIpD9yE54QVQUt+Qpr2GM2qjH0/eJODSvQfZ2dlFWOOvlwgwQRAE\\\\\\n\",\n              \"JSxctIjoNDldpyxEU6vYe/fT0i6O6/RlPIuMY9WqVZ+xht8OEWCCIAgKysrKYvXadbR2n4yGRsGf\\\\\\n\",\n              \"wGhoatFq8ERWrFqNXC7/DDX8togAEwRBUNDRo0fRL1MWs++qKVymXHU71LRLEBAQUIg1+zaJABME\\\\\\n\",\n              \"QVDQ1WvXqFi3mVJl1NTUsKzbjGvXrhVSrb5dIsAEQRAUlJiUjLZOCaXLaemUJCkpuRBq9G0TASYI\\\\\\n\",\n              \"gqCg0gb6pCUlKl0uIzkBAwP9gncUlCICTBAEQUFt27Qh9HKAUhOU5XI5f14OoHXr1oVYs2+TCDBB\\\\\\n\",\n              \"EAQFtW7dGk1kvAi5pXCZp7cuYVxan0aNGhVizb5NIsAEQRAUpKamxqD+/Tm6chaZaakF7p+enMQZ\\\\\\n\",\n              \"r8VMnTIZNTW1z1DDb4sIMEEQBAX5+PiwerUnVuVM2T1zOKkJb9+7b/LbWHb/OBTH9m3o37//Z6zl\\\\\\n\",\n              \"t0MEmCAIQgHkcjmzZ89m9OjRHD16lAvnz9O1bXM83dvjs+InIp48ICsjncz0NMIfh+C7Yiarh3ag\\\\\\n\",\n              \"j4sTa1d7itZXIVGTvtLlkuvVq8fNmzeLuhqCIKi4t2/f0r9/f5KTk9m/fz+mpqY5j4WHh9O8RQuS\\\\\\n\",\n              
\"UlJJiIsDwKxsOYa5D2b4sGGYmZkVVbU/mipdO8Vq9IIgCO9x7949XFxc6NKlC4sXL0ZLSyvX4yVL\\\\\\n\",\n              \"liT6zRsiIiIoWbJkEdXy2yW6EAVBEPKxd+9eWrduzc8//8zy5cvzhBfAsWPHaNmypQivIiJaYIIg\\\\\\n\",\n              \"CP8gk8mYOnUqhw4dIiAgADs7u/fue+TIEbp16/YZayf8kwgwQRCE/+/Nmzf07t0bbW1tbt68iaGh\\\\\\n\",\n              \"4Xv3zcjIwN/fX9wqpQiJLkRBEATgxo0b1K9fnyZNmuDn5/fB8AI4e/Ystra2uQZ1CJ+XaIEJgvDN\\\\\\n\",\n              \"8/LyYvr06WzYsAEXFxeFyojuw6InAkwQhG9WRkYG48eP5/z58wQGBlKtmmL3+ZLL5fj4+HD27NlC\\\\\\n\",\n              \"rqHwISLABEH4JoWHh9OjRw8sLCy4fv06urq6Cpe9efMmenp6VKlSpRBrKBREfAYmCMI3JzAwkPr1\\\\\\n\",\n              \"69O1a1cOHjyoVHiB6D78UogAEwThmyFJEqtWraJXr15s3bqV6dOnf9QyT0eOHMHZ2bkQaigoQ3Qh\\\\\\n\",\n              \"CoKgsmQyGb6+vvy+dTuvX0eirqGOlWUlRnkMo0WLFrnCKTU1FQ8PD+7fv8+VK1ewtLT8qHM+ffqU\\\\\\n\",\n              \"2NhYGjRo8ImehfCxRAtMEASVtHXrNizKV2T09J95qV8D7eaD0Gzcj/tyM3oNGo5VleoEBAQA8OzZ\\\\\\n\",\n              \"MxwcHFBTU+PSpUsfHV7wd+urS5cuqKuLy2dREy0wQRBUzvwFv7LMcx3VhyzAoGL1XI8ZVa1LhZY9\\\\\\n\",\n              \"ib5/he5uffl++DC2bNnMTz/9xJgxY5TqMkxLS2P//v2cOR9IYlIy+np6XL10gYULF37qpyR8BBFg\\\\\\n\",\n              \"giColAMHDrB01RrqTNpIcX3jfPdRU1OjjK0DxUevZOlvw1mzcjkeHh4KnyMjI4MfZ85i0++/Y1DJ\\\\\\n\",\n              \"Bt2qjdHUrUhWahJRmVoMHjqcMbduM3vWzHzXSBQ+D3E7FUEQVIYkSVStURODdh6Y2DRSqMzzs/ux\\\\\\n\",\n              \"yniOz6GDCu2fkpJC246dCU/T4DvnMZQ0KZdnn+TIFzw7tJLKZUpx/OgRtLW1lXoeXzJVunaKTlxB\\\\\\n\",\n              \"EFTG5cuXiUtMwbia4gMoyjbqzOnTAbx+/brAfSVJoodbX6IkXWyHLsg3vABKmVXE1mMxz+KzGTDY\\\\\\n\",\n              \"XeG6CJ+WCDBBEFTG/oMHMa7bATUlBlBo6ZTCrFYzjh49WuC+N27c4NrN21TtO6PAc6hraFJtwEz8\\\\\\n\",\n              \"A07x4MEDhesjfDoiwARBUBmRUdEUMzBRupy6rhGRkZEF7rd81WrMmrigrqHY8ACNYsUxb9yFlavX\\\\\\n\",\n              \"KF0n4b8TgzgEQVAZxbW1kTKzlC6XlZ7KvHnzWLJkCaampjmbmZlZzvfGxsZ4HzxA81+OKHVsC4eu\\\\\\n\",\n              \"7Fk0mA1rRYh9biLABEFQGbY1qnPxSCDgqlS5rIgnHDp0iObNmxMVFZWzRUZGEhUVxe3bt3n16hVy\\\\\\n\",\n              
\"1ChWykCpYxcvbUpqchKZmZkUK1ZMqbLCfyMCTBAElTF40CDmzJ3Pd84JFCulr1CZhJd/kpUYTadO\\\\\\n\",\n              \"ndDU1ERfXz/fRXgTEhIws8h/0MYHSRISEhoaGsqXFf4TEWCC8AXKzs7m9OnT/PXXX2RnZ2NqakrH\\\\\\n\",\n              \"jh0pWbJkUVetSJmYmODk5ETwmd1Ydx1V4P6SJPHEbxOaahI3b96kUaP3D73X1dVFU1OD1JjXlDC2\\\\\\n\",\n              \"ULhOSRHPMC5jJgKsCIhBHILwBUlMTGTBggVUqGTJmClT2ecfyB9nrjB/mSdly5dn3PgJvHz5sqir\\\\\\n\",\n              \"WaSW/baIhOBThF368GdVkiTx9MhajKVk5s2di6urKx4eHsTGxubaLz4+Hi8vL9q0aUN2djavzis2\\\\\\n\",\n              \"X+ydyEuHGTF8mNLPQ/jvRAtMEL4Q4eHhtG3fAYPyVgyZv5aK1Wrmejw2IoxA753Urd8AX58jNGzY\\\\\\n\",\n              \"sIhq+vk9ePCATV5e/Bn6FFl2Nk2aOHD+hBdvQq7wXYeBGFSyydlXksuJeXSDiHN7MdLKIuCUPyYm\\\\\\n\",\n              \"JvTo0YNZs2ZhY2PDzz//jJmZGbt37+bkyZO0adOG8ePHU7lyZRo1bYFlx0EUK1lwF2VGYiyvb5xk\\\\\\n\",\n              \"5I4Vhfn0hfcQASYIX4CEhATatu9A9ead6DhodL7r9RmZl8Nl9DSsatens1MXLgaep3r16vkc7esR\\\\\\n\",\n              \"FBTE9+Mm8PDhQyybd0PfsilqGhq8ig5HrqlNQugtHry4h7a+ESVMyoEkJyH8KUb6uvw0fiwDBw5E\\\\\\n\",\n              \"R0cHAD09Pdzc3IiIiGDs2LHo6OgwYcIENmzYQOnSpYG/W221bG24sXIcDSetRbP4+7tss1ISCdnw\\\\\\n\",\n              \"A5MmjMfCQvEuR+HTEQEmCF+AFStXYlixynvD659qNW1D9IBRTJryA8f9fD9TDT+/s2fP4ty9J7Y9\\\\\\n\",\n              \"xtDFYxkamrnXHKzeaQCRD25w4/fZ9HBsT8sWzdHQ0KBixYrY29vnvI6hoaHs3LmTnTt3UqxYMQYM\\\\\\n\",\n              \"GMDChQs5e/YsM2bMID4+nnnz5lGyZElGjx5NSlIiHZvW5czykVToNAyTmk1yzQuTZ8uICj7Py+Ob\\\\\\n\",\n              \"6NfDmbk/z/mcL4vwDyLABKGIyWQy1q3fgMfi3xVeKb1ptz786OzJ8+fPqVSpUuFWsAiEhobi0qMX\\\\\\n\",\n              \"jUcvwsymfr77qKmpYV6jAW1+2sru+UNo1bIFLi4uAERHR7Nv3z527tzJ8+fPcXNzY//+/dSpUyfn\\\\\\n\",\n              \"Nf7uu+9wdnZm+vTpVKtWDTMzM0xMTLhw4QKlSpVi3759LFq6nKt/LMfY1gG1YiXITksk8s4F9PX0\\\\\\n\",\n              \"aNWoHs2aOJCVlSWGzxcRMYhDEIqYv78/BqYWlK9sU/DO/592cR0adnTBa/PmQqxZ0Vm4+De+a9Xj\\\\\\n\",\n              \"veH1TyWNTKnnPpNpP85i7969dOnShcqVK3P16lXmzJlDWFgYK1asoG7dunneIBgZGTF//nxKly7N\\\\\\n\",\n              
\"y5cvyczM5NWrV6ipqeHm5kbQjWuc8fejlbUhySFniQg6R9naTbBo6swTyYQpc3/DvFwFfpo5i7S0\\\\\\n\",\n              \"tMJ6OYT3EC0wQShiT58+pVzlGkqXK1vZhtAntwuhRkUrISGBffv30/lXxUcDmts24tKG2Sxbtoyx\\\\\\n\",\n              \"Y8eye/dudHV1CywXGhpKp06d6Nu3L7NmzWL9+vW0aNECd3d3Zs6cSalSpTjscxRv3xPUcptMuTrN\\\\\\n\",\n              \"8ywz9TbsCfsOrefYiZacPnki5/M0ofCJFpggFDGZTIb6R8wh0tDUIksmK4QaFa3jx49jUa0OJUqX\\\\\\n\",\n              \"UbiMmpoa1du7YV+3HgMGDFAovK5evUrz5s2ZOnUqc+fORVNTkzFjxhASEkJERAQ2NjYMH+7BOq9t\\\\\\n\",\n              \"tJm1jQr1W+e7RmLpctY4jPmNLGMrHLs6I/sKfydfKhFgglDETE1NeRsZrnS52NevMDc1LYQaFa2o\\\\\\n\",\n              \"qCiKG5krXa6kiQWvI6MU2vfIkSN06dKF33//neHDh+d6zNTUlO3bt7Np0ya27thBk3HL0NE3+uDx\\\\\\n\",\n              \"1NTUsO//A69iEvDx8VG67sLHEV2IglDEHB0dGTV6NAmxb9A3UqzVIZfLuX78IAf37i7k2n1+Wlpa\\\\\\n\",\n              \"SNnKt2Lk2TICAgKwtrbGzMzsvZu/vz+rV6/m2LFj1K///s/YwsPDqVizIfplLRU6v7q6BpZt+7Bs\\\\\\n\",\n              \"pSeursqt1Sh8HBFgglDEDAwM6NGjJxcP78Fx6HiFyty/eh593VIfXBpJVVlZWRH/cqPS5RJfPGLM\\\\\\n\",\n              \"9yMZMXw4kZGRubYrV67w+vVrgoKCiImJQU1NjU6dOmFmZoa5uXm+Qbdi9VoqtB6oVB0q1m/DkZ2L\\\\\\n\",\n              \"CQsLo1y5j1hXUVCKCDBB+AJMn/o/GjZ2wMquPtXqOnxw39iIMLbNn4K2hjoXL16kWbNmn6mWn0fb\\\\\\n\",\n              \"tm2RJcYS++wBRt8pNjJTlpHGs4u+jPa8haWlJZUrV871eGZmJu7u7lhaWhISEkLp0qWJjY3NE3Th\\\\\\n\",\n              \"4eHcunWLyMhIQh8/xrq/Yq2vdzS0imFgVo7w8HARYJ+BCDBB+AJUrlyZPw7sp3vPXjh5TKFxZ1c0\\\\\\n\",\n              \"tXLPLZIkiQfXL7D712nM/3kO31la0rt3b9zd3Zk9ezZaWlrvObpq0dDQYPSokWw/vo3G3y9UaG5c\\\\\\n\",\n              \"6DlvGjRsiKVl3sBJSEjA1dUVPT09Tp8+TYkSJQAoU6YMZcqUoVatWvkes2xFS1BwXt6/SZL0UeUE\\\\\\n\",\n              \"5YgAE4QvRKtWrTgdcJIJkybjt2kZjRx7UNa6OuoaGsSEv+TasYOULK7NxnVrcHZ2Bv5eamnIkCE0\\\\\\n\",\n              \"adKEXbt25Wl5qKrx48exa+8+Qg6tx9Zl5AdDLCwokMe+W7h84Xzex8LC6Ny5M82aNWPVqlVKrRhv\\\\\\n\",\n              \"YWFBwuu/0DOroHCZbFkW8ZFhlC1bVuEywscToxAF4Qtib2/P+bNnuHThPNWMS/Am6BxhV/0xksWz\\\\\\n\",\n              
\"e9tmHt6/lxNe8PeIOT8/PwYOHIiDgwObN2/+Kt796+rqcibAn5hbAZz+bSzRT+7meV5Jb8II2rOc\\\\\\n\",\n              \"oK3zOO7rk2ddyJCQEBwcHOjfvz+rV69W+nYnI4YO4cV5b6XKvLxxGltbW8qXL69UOeHjiBaYIHyB\\\\\\n\",\n              \"qlWrxtIlvym0r5qaGmPGjKFVq1b07duXY8eOsWHDBoyMPjz0+0tnYmKCLCON9LCHBG/6CbXipTCo\\\\\\n\",\n              \"UAU19b/v2RXz1yMGDx7I5NU3qFAhdyvp7Nmz9O7dmxUrVtC3b1+lz52YmMijR494HnyFWhEv0DOv\\\\\\n\",\n              \"WGAZSS7n2am9LJ87Q+nzCR9HtMAE4StRo0YNrl+/TsWKFbGzs+P06dNFXaX/ZM2aNcTExOBz+DCv\\\\\\n\",\n              \"nj9j+/pVTOzjyJjubVk66wciwl+xcvnyPOG1Z88eevfuzb59+5QOL5lMxsaNG6latSoxMTHM/OlH\\\\\\n\",\n              \"Lq2cSHpi3AfLSZJE0J6lmOvr0K1bN6Wfq/BxRAtMEL4i2traLF26lI4dOzJo0CD69u3LvHnz0NbW\\\\\\n\",\n              \"LuqqKSUhIYEZM2bQrl07mjRpAkDr1q0/WEaSJJYsWYKnpyenT5+mZs2aH9z/3wICApg0aRJGRkb4\\\\\\n\",\n              \"+vpSt25dADIyMtg0bzC1+07BonYT1NVzd0UmvH7O/UPr0Ul9w/EA/69mMI1KkL5SdevWLeoqCEKR\\\\\\n\",\n              \"io6Olrp16ybZ2dlJDx48KOrqKGX48OGStra2FB4ertD+MplMGjNmjGRrayu9evVKqXPdv39f6ty5\\\\\\n\",\n              \"s2RtbS0dOnRIksvlefY5cOCAZFu7jmRkXl6q1W2oVK//FMm+11jJ0s5BKm1kIv0wdZqUnJys1Hm/\\\\\\n\",\n              \"VKp07RRdiILwlTI2NubQoUN8//33NG/enHXr1qnEAI9nz56xdetWpkyZotCNItPS0ujRowf379/n\\\\\\n\",\n              \"4sWLCs+/io6OZvTo0bRo0YK2bdty//59nJ2d8x3x2KNHD+4F3+L4kT9wrl2W+voZtKtUgl/+N4aI\\\\\\n\",\n              \"8FcsXvgrJUu+/+aXQuFQk1ThL/oj1KtXj5s3bxZ1NQThi/Dnn3/Sr18/zM3N8fLyokwZxRfK/dya\\\\\\n\",\n              \"Nm3Ko0ePCA8PL7DrMyYmhq5du2JpacnmzZsV6irNyMjA09OTRYsW5axCr+oDXj4lVbp2ihaYIHwD\\\\\\n\",\n              \"qlatyuXLl7G1tcXOzo4TJ04UdZXydebMGa5du8amTZsKDKNnz57RpEkTWrRowY4dOwrcX5IkDh48\\\\\\n\",\n              \"SPXq1blw4QIXL15k5cqVIrxUWdH2YBYeVerHFYTP6ezZs1L58uWlcePGSampqUVdnRzZ2dlS2bJl\\\\\\n\",\n              \"JXt7+wL3vXHjhmRubi6tXr1aoWNfu3ZNatKkiVS7dm3p1KlT/7WqXzVVunaKFpggfGNatmzJnTt3\\\\\\n\",\n              \"iIyMpEGDBty9m3eScFHw9PQkKiqKvXv3fnC/Y8eO0alTJ9auXcvo0aM/uO+rV6/o378/zs7OuLu7\\\\\\n\",\n              
\"c+vWLdq0afMpqy0UIRFggvANio+Pp0LFioS/jsDOzg5NLS3KV7Rkzs8/8/r1689en5SUFGbMmIGb\\\\\\n\",\n              \"mxtVqlR5736///477u7u+Pj45FqR5N+Sk5OZOXMmdnZ2fPfddzx+/Bh3d3elV+MQvmwiwAThGyKT\\\\\\n\",\n              \"yRj5/Wjs69Yj+FUc36/cw4qzf7Ls1AP6/7yGc3dCqVbdhp9mzvqsrbJx48YhSRJr167N93FJkpg9\\\\\\n\",\n              \"eza//vorgYGBNG7cON/9srOz8fLyokqVKjx//pzg4GDmzp1LqVKlCrP6QhERE5kF4Rshl8tx69uP\\\\\\n\",\n              \"x68i+WnPWXRK6eV6vFxlG3pOnk9H9wl4Tfcg7u1b1niuUmg1+P/ixYsXbNu2jSVLlqCrq5vn8ays\\\\\\n\",\n              \"LDw8PAgJCeHy5cuYvucu1KdPn2by5Mno6upy5MiRD96sUhXIZDI0NDQK/fVXZaIFJgjfiOUrVhAS\\\\\\n\",\n              \"+hfuv2zME17/pFvamBFLtuJ7IqDAz6M+BTc3N8zNzRk3blyex5KSknByciI6Oppz587lG15//vkn\\\\\\n\",\n              \"Xbt2xcPDg5kzZxIYGKiS4SVJEteuXaNv/wHo6umjra2NlpYW1Wxq4OnpSWJiYlFX8YsjAkwQvgHZ\\\\\\n\",\n              \"2dksX7GSrmN+QkuBuVI6pfTo7PEDP8//pVDrdebMGa5fv86uXbtQV899OYqIiKBFixZUrFiRw4cP\\\\\\n\",\n              \"55koHBsby7hx42jatCnNmzfnwYMHdO/eXSVbLHFxcbRs3Qbnnr1J0y/P1J0BLDsXyuKA+7Qb+SM7\\\\\\n\",\n              \"jvhTvkJFdu/eXdRV/aKIABOEb8CJEyfQ0TekYrX8b96YnxqNWhL+OpJx48aRnZ39yeskSRL9+/fP\\\\\\n\",\n              \"CaB/evjwIY0bN8bV1ZUNGzagqfl/n3ZkZmayfPlyqlWrhlwu58GDB0yZMkXl1nt8Jz4+nibNWlDM\\\\\\n\",\n              \"3IqpO0/Tqs9w9AxNUFNTQ1OrGJXrNKb/HE9GrdzNhCn/43cvr6Ku8hdDBJggfANOnz5D9SbtlCqj\\\\\\n\",\n              \"rqFBg/bdOHr0KG3btuXVq1eftE7Lly8nOjqaffv25fr5xYsXadmyJXPmzOGnn37KaVFJksShQ4eo\\\\\\n\",\n              \"UaMGp06dIjAwkNWrV2NiYvJJ6/W5DR3ugZlNXZxGTc/TCv0nC6tqDF+yjR+mTuP+/fufsYZfLhFg\\\\\\n\",\n              \"gvANeJuQQAldfaXL6ejq06dPH9q1a0e9evU4cODAJ6lPWloaP/74I6NGjcLMzCzn5wcPHsTV1ZUd\\\\\\n\",\n              \"O3YwePDgnJ/funWLli1bMmvWLNasWYOfn1+eG1iqorCwMAICAug8fIpCXZ+mFb6jcbd+rPRc/Rlq\\\\\\n\",\n              \"9+UTASYI3wA9XV0yUpOVLpeZloK+vj4zZszA19eXH3/8kSFDhpCUlPSf6jNixAi0tLRYunRpzs9W\\\\\\n\",\n              \"rFjBhAkTOHnyJO3btwcgPDycwYMH4+TkRP/+/QkKCsp57Guwfv0G6rbrhnYJxRcCbuTUm71794hB\\\\\\n\",\n              
\"HYgAE4RvQhOHxoReP69UGUmSuH/5NI0aNQKgfv363L59G01NTezt7bl69epH1eX58+fs2rWL1atX\\\\\\n\",\n              \"o6WlhVwuZ9KkSWzcuJFLly5hZ2dHSkoKc+bMoVatWlhYWPDnn38yfPjwXJ+FfQ0CzpylRlPlunYN\\\\\\n\",\n              \"TMywsKzM7du3C6lWqkMEmCB8A5ydnYl6+YyIv0IVLhMadJX4mGhmzJjBvn37yMrKolSpUmzatInF\\\\\\n\",\n              \"ixfTrVs35s2bh0wmy1M2PT2dnTt30qJ1O6rVqEXN2nXp5daXixcv0r17d6ysrBg4cCDp6en06dOH\\\\\\n\",\n              \"mzdvcvHiRcqXL8/WrVupWrUqjx8/5vbt2yxYsAA9vfcP+1c1mZmZhIWFcfPmTaKiIj84peF9Sujq\\\\\\n\",\n              \"ixYYYiKzIHwTihUrxqgRIzi6YTHuv6xHvYAllbIyMjj++1IWL1yAubk5q1atYvLkyYwcORIPDw9c\\\\\\n\",\n              \"XV1p2LAhgwYNwt/fnx07dmBpaYkkSfy2ZCnz5/+Cmo4JacUtUStWBTLlPL4Uhu8xZ9JSEtnitZG4\\\\\\n\",\n              \"uDicnZ0xMzPj5MmTXL16lUmTJlG8eHEOHjyY0/JTBdnZ2cTExBAVFUVkZGSu7d8/S0xMpEyZMpia\\\\\\n\",\n              \"mpKamkpGaorS50tPSc530ve3RtwPTBC+EZmZmbTv0JHM4gb0nroQTa1i+e6XkZrCphkjSIx4wZ+P\\\\\\n\",\n              \"HuYsw3T37l08PT05ePAgXbt2ZezYsdSpU4fly5ezcOFCli1bxqXLV9m1/zCZZVqhXtwgz7ElSUKe\\\\\\n\",\n              \"8Bea0RcxNjSgR48ejBgxgmnTphEUFMSiRYvo1avXFzGXS5Ik4uPjPxhG7/4dExODgYEBZmZmOZup\\\\\\n\",\n              \"qWm+/zYyMkJdXR1JkhgwYCDhWcVwGfuTwvVKfhvLr/3b8OKvvzA0NPzkz1uVrp0iwAThG5Kamkrv\\\\\\n\",\n              \"Pn25FXSHRl370sixJyX1/g6axNhorvju5crRvXTu0J5sWRbh4eEcPXo011qCcXFxeHl5sWbNGiws\\\\\\n\",\n              \"LBg7dizW1tZ06dKV2KRMNKxcUdP48JwseUokvDyOW+8e+Pn58cMPPzB+/HiKFy9eqM8f/l7oN78w\\\\\\n\",\n              \"+vfPoqKi0NHR+WAYvdtMTEzQ0tJS6PwxMTHs3LkTLy8vEhISiItPZM6hKxTTVuy5n961Hv20KLZv\\\\\\n\",\n              \"3fofXoX3U6VrpwgwQfjGvFuyaJXnag4fPkxxnRJIkoQsK5Pebm6MHf09tWvXJjs7mxEjRvDw4UOO\\\\\\n\",\n              \"HTuGvn7uYfgymYyjR4/i6enJo0ePiE9IJrtcJ9RLKDYvSxZxDavSqQSeP/uf7xCdkZHx3tbRv/8t\\\\\\n\",\n              \"l8tzhc/7AsrU1BQdHZ3/VK93srOzCQgIwMvLi4CAALp06cLQoUNp3rw5jl26omZiScdhkwo8TlxU\\\\\\n\",\n              \"OKu/74H/MT/q1q37Ser2b6p07RQBJgjfsIyMDOLi4lBXV8fQ0DBPK0IulzNmzBhu3ryJv78/pUuX\\\\\\n\",\n              
\"zvc4K1eu5H8zF6Ju1V3hc0tZKaj/dZDIiPB8B2nIZDKio6MVai2lpKTkhFBBraVSpUp9ti7KZ8+e\\\\\\n\",\n              \"sWXLFrZu3YqZmRlDhw6lT58+ud4MREVF0aBRY+w69qR13xHvrVvs61f8PtWdiWNGMWXy5EKrsypd\\\\\\n\",\n              \"O0WACYLwQZIkMWnSJM6fP8/JkycxNjbOs0+PXn04ciUCTWNbpY6tHnacHp2bYmpqmiec4uLiMDQ0\\\\\\n\",\n              \"VKi1VLp06Q+uYvE5paWl4e3tjZeXF/fu3aNfv364u7tTq9b7l/EKDw+nk2MXUmVyGnTpS522XXK6\\\\\\n\",\n              \"FF8/fcSVI7sIOuPH/HlzGTd2bKHWX5WunWIUoiAIH6SmpsayZcuYMWMGrVq14tSpU3lWhY+MikJN\\\\\\n\",\n              \"S/HJuO9kUYywsDBsbGyoUaNGrnAyNjZWmXlfkiRx+/ZtvLy82LdvH/Xr12fUqFF07dpVoTUay5Yt\\\\\\n\",\n              \"S9CtG/j5+bHwtyUcWPITOiVKIsvOQk9Pn5EjPNiz+j4WFhaf4dmoDtX46xAEoUipqamxYMECtLW1\\\\\\n\",\n              \"admyJadPn851MS2uXRwk5Rf8lWdlcv78eUJDQzE3N8fc3BwLC4t8vzcxMfni7qgcGxvLrl272Lx5\\\\\\n\",\n              \"MwkJCQwZMoSgoCAqVKig1HHCw8PZsHEjGzZuQlO7OJWq2pCanER6ciIDBwxgyODBIrzyIQJMEASF\\\\\\n\",\n              \"qKmpMWfOHLS1tWnRogVnzpzBxMSEs2fPEhfzBiklEwysFT6eJEnoqCXjd+YMFSpUICIigtevXxMR\\\\\\n\",\n              \"EUFERASXLl3K+T4iIoL4+FHfm/wAACAASURBVHhMTEwKDDpTU9NCbbnJ5XJOnz6Nl5cXJ06coHPn\\\\\\n\",\n              \"zixdupRWrVp9VDfmkSNHGDTEnTptnBi+eCvm31XJeSzq5VMuH9lNbfs6rPFcRb9+/T7lU1F5IsAE\\\\\\n\",\n              \"QVDKoEGDCAoKolq1aqirq2Nvb0+7dm144LkWSe6AmrpilxV5chgpSfEcOHCAMWPG0Lhx4w/un5mZ\\\\\\n\",\n              \"SVRUVJ6gu3HjRs73r1+/JjY2FiMjowKDzszMjGLF8p8Ll5/nz5+zdetWtmzZgrGxMe7u7qxbt+69\\\\\\n\",\n              \"A1sU4efnx5BhHnj8tiXfW92YVrDCZexMGjr2YsLkIWhqatK7d++PPt/XRgziEAThg+RyOTdv3sTX\\\\\\n\",\n              \"1xc/Pz+eP39Ox44d0dLS4tSpU5w9e5bKlSvTolUbLj9OR7OMfYHHlCQ56i/90NfORF1dndTUVOrX\\\\\\n\",\n              \"r8+YMWNwdHT8T12FMpmMN2/e5Am6f34fERFBVFQUBgYGeYLtn/82NDTk+vXr7Nixg6CgIPr06cPQ\\\\\\n\",\n              \"oUOxs7P76Pq9k5KSQrnyFRj66+9Y2hb8moWFPmDthH789expoUxgfkeVrp2iBSYIQh5JSUkEBATg\\\\\\n\",\n              \"6+vLsWPHMDIywsnJiRUrVtC4ceOcLrpNmzbRqlUrDh8+TEZaCrKIW6hplkTDsMp7jy1JcjSiLmJv\\\\\\n\",\n              
\"a8npAH+8vb2ZOXMm4eHhzJgxg3HjxjFq1CiGDh2a74jH/GRkZODt7c2qdRt49vQpsqwsjMuUYWBf\\\\\\n\",\n              \"N4YPG5bvPLN3yz/9O9wePnyIt7c3Dx8+JDo6GkmSKFGiBOXLl+fRo0csW7Ys3xadubl5nrtGf8iu\\\\\\n\",\n              \"XbuwqlVPofACKFfZhhoOrdi6dSuTJhU8Z+xbIFpggiAA8PTpU3x9ffH19eXatWs4ODjg5OSEo6Mj\\\\\\n\",\n              \"lpaW7y23dOlSpk6dirq6Orq6umRkZUOpimTpVss1qVmS5MgT/qJ48gPsbKzw8z2Ss55fVlYWW7Zs\\\\\\n\",\n              \"Ye7cuVSuXBl9fX3OnTuHs7MzY8aMoV69eu89/8GDB/EY9T2lLKwwc3DBoGI11NQ1SI2LJPLqUV7f\\\\\\n\",\n              \"Oov7UHdWLF3ywZbd27dv2b17N15eXsTGxjJkyBAGDx5MhQoViIuLK7BF9/r1a7S1tQvsurSwsEBX\\\\\\n\",\n              \"V5eate1oMXgS1Rs0f2+d/u1ZyG28F0/l2ZPHhTaXTZWunSLABOEblZWVxaVLl3K6BuPj43F0dMTJ\\\\\\n\",\n              \"yYm2bdvmWj7qfQIDA+natSvp6enIZDKGDx/O3LlzWb9+Ays9V5Ml10SmVhx5dhZkJiCTZbF503r6\\\\\\n\",\n              \"9OmT70CLtLQ01q5dy6JFi2jZsiWWlpbs27cPU1NTRo8eTa9evXItN+XltZnJ036ktsdCSlvWyLeO\\\\\\n\",\n              \"mcnx3Ns8kzrWZfE+sC9XiMnlcs6ePcvmzZvx8/OjY8eOuLu706ZNG6W7Md+tnahI0L17rivO/omG\\\\\\n\",\n              \"pmJLUL07x5R2NYiNiVaqtacMVbp2igAThG9ITEwMJ06cwNfXl5MnT2JlZYWTkxNOTk7Y29srNYpu\\\\\\n\",\n              \"48aNTJ48GR0dHaysrNDV1eXOnTscPXqUBg0aIJPJuH79Ort27SI0NJSVK1cycOBA5s+fT4cOHT54\\\\\\n\",\n              \"7MTERJYvX86qVavo1asXDRs2ZM+ePQQHBzN06FBGjhzJmzdvaN2+E/UnrUPXrNIHj5edlUnQ6vEM\\\\\\n\",\n              \"692FObNm8erVq5wBGbq6ugwdOpR+/fphZGSk8PP/WJIkER0dTdmy5Vhx7rHS5X/sUpfQPx9hYqLY\\\\\\n\",\n              \"kl3KUqVrp/gMTBC+YpIkERISktM1GBISQps2bXBycmL58uWYm5srfcysrCwmTJjA3r17MTU1pWbN\\\\\\n\",\n              \"mkiSxB9//MHx48dxcnLi0KFDNGnSBAcHB/766y/evn1L9erVGTx4MFu3bi0wwPT09Jg9ezajR49m\\\\\\n\",\n              \"4cKFTJo0Kad1t2fPHuzt7Smhq0fFdv0LDC8ADa1iVO0zlSW/DefyxYvcunWL3r17c+DAAerUqVPo\\\\\\n\",\n              \"S0tJkkRYWBjBwcHcuXOHoKAgJEkiLTlRqfuBZWVmkJKclGddym+VCDBB+MqkpaVx9uxZ/Pz88PX1\\\\\\n\",\n              \"RUNDAycnJ2bPnk2LFi0UWhnifWJjY+nevTtPnjzB2tqaBg0acPv2bU6dOpVznp07d+Li4sL+/ftp\\\\\\n\",\n              
\"2bIlOjo6pKWlAeDm5saMGTOIj4/HwCDv7Vb+zdjYmCVLljBhwgTmz5+Po6MjEyZM4Ny5c9Rv1Jia\\\\\\n\",\n              \"Dl0UrruuWSV0TCtibW3NkSNHPtlCvf+WmZnJo0ePCA4Oztnu3LmTcyfr2rVr06NHD94mJHLrlC9N\\\\\\n\",\n              \"nfsqfOzgcydo2qy5UsP/v2YiwAThKxAeHp4TWOfOncPe3h4nJydOnDhBtWrVPkkLIyQkhC5duqCu\\\\\\n\",\n              \"rk7t2rVp0aIFW7Zs4dKlS7nCoH379uzbt49evXqxa9cuihcvnhNgRkZGOY+PGDFC4XOXK1eO9evX\\\\\\n\",\n              \"M2XKFGbPns2SJUswqVKHYiWVu5txucZdeBP75JOF19u3b7lz5w537tzJCatHjx5RqVIl7OzssLOz\\\\\\n\",\n              \"43//+x92dnaYmZnlKmtiYsLw0eNp0q2Pwr+faz67+HX2jE9S96+BCDBBUEFyuZwbN27kdA2+fPmS\\\\\\n\",\n              \"Tp060bdvX7Zt2/afJtfmx8fHB3d3d/T19WnevDlt27Zl2rRpXLx4Md85Sa1atcLb2xtXV1cmTpxI\\\\\\n\",\n              \"enp6zmODBw9m/vz5SgXYO9bW1uzatYuffvqJnefuKF1eW680sRHxSpeTJInnz5/nalEFBwcTGxtL\\\\\\n\",\n              \"rVq1qF27No0bN2bUqFHY2tpSokSJAo/ZunVrdEtoc2bPRtr0Lfi1uHRkN1kpCTg5OSld/6+VCDBB\\\\\\n\",\n              \"UBGJiYmcPHkSPz8/jh07homJCU5OTnh6etKoUaNCWT5JkiQWLFiAp6cnpUqVonfv3rRt2xY3NzdO\\\\\\n\",\n              \"nTpFxYoV31u2adOmHD16lE6dOuUaHNGhQweGDRvG5cuXiYuLIykpCV1dXRo0aKDwfcEqVKiAuvyW\\\\\\n\",\n              \"0s8nOzOdkgWES3p6Ovfv38/Vqrpz5w66uro5rap+/frx22+/YWVl9dGr4Kurq+Pnc4RGDk3Izs6m\\\\\\n\",\n              \"TR8PNPL5Hcqzswn03k7gvk1cvnhBZRY4/hzEKyEIX7AnT57ktLKuX79OkyZNcj7PqlSpUqGeOzU1\\\\\\n\",\n              \"FXd3d0JCQlBXV2fixIm0bNmSdu3asW/fvg/eHuSdhg0bsnbtWgYMGMC+ffvo3bs3t27dQq+0MS3b\\\\\\n\",\n              \"tMPEujbq2iWRMlKJex5Cx44d+WHShDzLSsXFxXHp0iUCAwMJDAzk7t27SJrFsc2Woa6h+GUs/vFN\\\\\\n\",\n              \"urb5vxtBRkdH5wTVu6/vPt97F1bdunWjdu3aCk+qVkb58uW5fvUKPXv3YX6f3TRy6k2tZh3Q0dUj\\\\\\n\",\n              \"PSWZkMunuOqzBwszM65evvTB+XjfIhFggvAFycrK4uLFizmhlZSUhKOjI2PHjqVNmzYKzc36FF69\\\\\\n\",\n              \"eoWzszPGxsZER0ezYsUKHBwcaNq0KatXr6ZVq1YKH6tOnTqYm5szceJEDh85gp//aUo37olt9wVo\\\\\\n\",\n              \"6ujm7FcmLZnbQSfo2MWF0SOGUbtWTQIDA7lw4QLPnz+nUaNG1K1bN2dV+gxZFlF3L2Ju31KhemSl\\\\\\n\",\n              
\"JhF2PYCYWuVxdHTkzp07JCcnU7t2bezs7GjVqhUTJ07ExsYm11yzwla2bFkuXwwkKCiIVavXsPvn\\\\\\n\",\n              \"0SQlJVGqVCmaNHHAx/sg9evX/2z1USUiwAShiEVHR3P8+HH8/Pw4efIklStXxsnJiT179mBnZ/fZ\\\\\\n\",\n              \"b9R4+fJlevToQbt27Th27Bi7d++mTp06NG3alClTptCrVy+ljle8eHHkcjnuQ4fym+cGqnmsRtsg\\\\\\n\",\n              \"b1ehpk4pyjj0wKBGC5auGUOVCqYMHjSIgQMH8vr1a+bNm8eiRYswNDRk/PjxVK5cmWnzf6NMjUZo\\\\\\n\",\n              \"FCs4cP48+jvFdXQwMDBg+PDh2NnZUbFiRdTU1IiNjWXzli2MmziFt/FvKaGjQ706dRgzehQ1auQ/\\\\\\n\",\n              \"QfpTs7e3Z4vX75/lXF8LMZFZED4zSZK4d+9eTivr/v37tG3bFicnJzp16pRntNrntGXLFqZOnYqb\\\\\\n\",\n              \"mxsHDhzg6NGj2NjY0LZtW5o2bcrixYuVPmZ0dDRVq1YlIysbqxHrKG5Y8H2tMt5GErp+BP3cenHg\\\\\\n\",\n              \"wAGSkpKoV68eCxYsoE2bNsDfA1m6urgS/DwaO4+FaGrnP7JQkiSe+u8g6uJBLExNqFChAps3b8bE\\\\\\n\",\n              \"xISMjAzGTZjI7t27sW7QCqsmndDRN0SWkcaru1e5H3AQm+rV2bF1M999953Sz10VqdK1UwSYIHwG\\\\\\n\",\n              \"aWlpnDlzJmeou5aWVs4KGM2bN/9Pc7MUlZ6ezuXLl4mNjaV48eJUr14da+u/798lk8n43//+h4+P\\\\\\n\",\n              \"D507d8bX1xd/f38sLS3p3r07urq6bN++/aNag2/fvsWkjCnGtVtTwXWawuWeHVhA0sMLDB82lFmz\\\\\\n\",\n              \"ZhETE5Prs6rg4GCysrIoXqIUSRkyKrXtS4XGjmiV+LtbUp4tI+reJSICD5L25gXZmel07doVHR0d\\\\\\n\",\n              \"fHx82LBhA4uWLOVttjYtR82hhH7ekZvZsiyCfXdx12crgefOUL16daWfv6pRpWun6EIsAsnJycTF\\\\\\n\",\n              \"xaGtrY2RkZEYVfSVCgsLywms8+fPU6dOHZycnDh58iRVq1Yt9NUf3nn58iWrPFezecsWjMpWRNew\\\\\\n\",\n              \"DLKsDF49ukvNmjUZ7j6E7du3A38Pfw8MDOTy5cuYmpoycuRI0tLSOHDggMLhlZaWxrVr13IGXFy7\\\\\\n\",\n              \"dg1JTQOjBi5K1dusSS9Sn1zn+vXrWFpaYmFhgZ2dHbVr12bs2LHY2dlRtmxZAC5cuMDyVas5+aMz\\\\\\n\",\n              \"usZmqGtokhQbhZXVd/zyw1h69uxJeno6S5YsYd26dbRo0YK+/QdgUaM+TlOXof6edQ81NLWo6zyY\\\\\\n\",\n              \"4rr6tOvQiUcPQj7b55BCwUQL7DORyWT4+PiwZs1arly5jKGhEekZ6Wioq+Pu7s7IkSM/OCRZ+PJl\\\\\\n\",\n              \"Z2fnmpsVFhZGp06dcHR0pEOHDp98bpYizpw5Q49evanRuit1HftiVPb//sZkmRk8uOjP2e2eGOuV\\\\\\n\",\n              
\"wKZaVd6+fcvhw4fR19dn7ty5HDlyhHPnzuWsGp+fhISEnBGCFy5cIDg4mFq1atGsWTOaN29OgwYN\\\\\\n\",\n              \"MDUzo96800qFtiRJ3J7dnoCT/tSvX/+DdXgnLi6OsLAwsrKyKFOmDOXLl8+zT1RUFNOnT2fn7r18\\\\\\n\",\n              \"v/MixXQUWxT32MKxjO3f/aPmr6mSL+3a+SHirf9n8PLlSzp3dqSkri5Dho1i274jOV1GT0L/ZJvX\\\\\\n\",\n              \"Ruzr1GHG9OlMnjz5s70zF/67hISEXHOzTE1NcXJyYs2aNTRq1Og/3Zjxv7p+/TquPXvhOmMllrUb\\\\\\n\",\n              \"5nlcs5g2tVp3pXqT9uyZPZLbwXd4eD8EHR0dfv/9d7Zt28alS5fyBEdUVBQXLlzICawnT55Qv359\\\\\\n\",\n              \"mjdvzrx582jYsGHOSumpqancu3cPNTUNpf+u1dTUQF2defPmUaZMGQwNDRXePsTU1BQzcwvsOnRX\\\\\\n\",\n              \"OLwAbDq4sXzVcjw8PMT/0S+EaIEVsoiICBo7OODu8T2jxkx8737hYa/o092JgQP6M3369M9YQ0FZ\\\\\\n\",\n              \"oaGhOa2sGzdu0LRp05z7Zn0prWhJkrCpWYtaLh7YtuhU4P5ZGelsGd+D9SuXIpfL8fDwIDAwEGtr\\\\\\n\",\n              \"a168eJETVoGBgbx584YmTZrQvHlzmjVrRuXKlXn58iVPnjzhyZMnPH36NOf7uLg4KlWqxKPHodhP\\\\\\n\",\n              \"90azhOJLP8nSk7n7qyu+R32Ij48nLi6uwE1LSwtDQ0OMjIw+GHDjJ0+h5bjFmFepqfhrKpfjNbQV\\\\\\n\",\n              \"t65d+arnY30p105FiBZYIRs6dBi9+gz8YHgBlC1Xnv2Hj9OhZWPatGlDgwYNPlMNhYJkZmbmmpuV\\\\\\n\",\n              \"kpKCo6Mj48ePp02bNoV2X6b/4vLlyySmpFGjeUeF9tfSLk59lyHM/WUBjx8+YMSIEcyZM4fAwECy\\\\\\n\",\n              \"srJo1KgRVlZWDBo0iIyMDP766y8OHz7MkiVLSElJwdraGmtra6ysrGjYsCH9+vXD2toaHR0dFi1a\\\\\\n\",\n              \"xOOnq4m5dQKzZooPwY8LOolj12507KjYc5AkiZSUlPeGW0xMDI8fP/77+zdvKGWo2Kof76ipq6Nn\\\\\\n\",\n              \"VIbY2NivOsBUiQiwQvTkyRNu3LzBhm37FNrfzNwCj9HjWb16Ddu3iwArSu/mZvn6+hIQEECVKlVw\\\\\\n\",\n              \"cnJi37592NnZffFdSKvXrqN2p95K1dO2pSO+q39GW1OD48ePo62tjYGBAa9eveLs2bOEhYVhZWWF\\\\\\n\",\n              \"tbU1rVq1Yvjw4VhbW2NmZpbrPJIk4e/vT8+ePbl16xba2troFNMk/rYPpk16oKbAYBBJLifpti+T\\\\\\n\",\n              \"92xVuP5qamqUKlWKUqVKUaFChTyPZ2Zm8vjxY0JCQjh9LpBsWabCx35HlpnxWUaMCooRAVaI1q1b\\\\\\n\",\n              \"R59+g5Sa1d+n3yAa2lUlJiamUJauUUWSJOWsmVeyZEmMjY0/eYBIksTdu3dzWlkPHz7MmZvl6emJ\\\\\\n\",\n              
\"qanpJz1fYXvw8BENhzgrVUZLuzgmFayoXt6Upk2b5rSqrK2tMTIyKvA1j42NZe7cuWzbto2kpCRs\\\\\\n\",\n              \"bW3Zs2cPrq6u1KtXjwyZnFd+npR3GvfBY0mSRNTp37GuVJZmzZop9Rzg7/lhf/31F/fu3SMkJCRn\\\\\\n\",\n              \"e/r0KZUqVcLW1pYyZcoQ/jAIA7O8gzzeJyU+loSYqC+mm1gQAVaorl27zg8//qxUGUMjI2xr1ebO\\\\\\n\",\n              \"nTs5Eza/VUlJSezcuZM1a9fy8sVLdPX1SUlOxsjIiO9HjWTIkCEFfmD/IampqZw5cwZfX1/8/PzQ\\\\\\n\",\n              \"1tbGycmJefPm0by5at9zKSMjA00txW9V/46hkTGTJk2ic+fOCu0vSRK+vr7MmTOH4OBgSpYsSf/+\\\\\\n\",\n              \"/Zk9e3au0JckiejI16gXiyPCbyVl2gxDUyfvcHRZWjJvzmxGJ+5Pjl0MLDDoIiIi8gTVw4cPMTY2\\\\\\n\",\n              \"xtbWFltbW5ycnJg2bRrVqlXLeTN55MgRJs2cR41WXRV+be4H/IGLiwt6esrdwkUoPCLAClFycvJH\\\\\\n\",\n              \"fT6ioanFjh07ePHiBcbGxrk2AwODz760UFG4fPkyLi6u1KrbkDHTfqG+QwvU1NT+XsUi6AZ/7Pyd\\\\\\n\",\n              \"X35ZwI4d23F0dFT4uK9evcqZmxUYGEjdunVxcnLi1KlTVKlS5YvvGlSUsbExiTFRmFsrtwxSQnRk\\\\\\n\",\n              \"rpXj3yc6OpqZM2eye/duUlJSqFOnTs4k6H+/hpcuXeLBgwdMnjyZ6dOn4zFqNIeX9UG3qgOlbVui\\\\\\n\",\n              \"Ubwk2ekppIVeITbkPI6Ojvx+7HKuuw7HxcXlCql3m5aWVk5QOTg44OHhQY0aNQoMGUdHR0Z+P4YX\\\\\\n\",\n              \"wVeoaNf4g/sCpCW+5d6JPSz1O1rgvsLnIwKsEOnp6REfr/y9hxIT4klPN+bixYvExMTk2pKSkihd\\\\\\n\",\n              \"unSeYHu3GRkZ5fmZnp6eSl2Yr169Speu3fh56QYcWrTN9Ziamhq16jSgVp0G3Au6wRD3vmzZvPm9\\\\\\n\",\n              \"IZadnc3169dzugZfv35Np06dGDhwIDt37lTorsCqyLmLI17e+6jaqLXCZV6H3keWlkK9evXyfVyS\\\\\\n\",\n              \"JP744w/mzZvHvXv3MDAwYNiwYcycOfO9c9yOHj2Ku7s79vb2ODg4oKenx95dOxgwYABv3kSTGnaG\\\\\\n\",\n              \"xMREdHV1adepBQN2rSA2NhZvb+9cQfWuS/Ld1rNnT2rUqKHw7Vf+TVNTk53bt9K9lxtdflyLedX3\\\\\\n\",\n              \"r6yfnpyA7y+jGdSvL3Xr1n3vfsLnJwKsELVu3YpjRw/RopXiXYGvw8N49OA+NapXw8XFhXbt2uX6\\\\\\n\",\n              \"DE0mk+WMqPr3FhYWRnBwcJ6fp6en5xtsH9pKlChRJKGXmZmJa/cezFq8Nk94/VtN+/osXreLAQN6\\\\\\n\",\n              \"8+zZ05wwSkhIwN/fH19fX44fP465uTlOTk6sW7eOhg0bFuncrM8hNDSU9evX8yo8nMSYKPSMFfv8\\\\\\n\",\n              
\"7rbfbr4fNSLP6xMZGcmPP/7I/v37SU1NpVGjRpw6dYrWrT8cjlu2bGH69On4+fnx22+/5dyVOTMz\\\\\\n\",\n              \"k2PHjrFnz55cLavt27ayaOGvVK1aNSeoxo8fj62tLeXLl//kf49t2rRhx9bNDBg0BJs2Lth2dMPA\\\\\\n\",\n              \"rFzO45mpyTw4d5Q7Plvp3d2ZJb8pvw6kULjEPLBC9Pr1a2xtbbl57wm6CvabL/rlZ6IjXlKnTh0O\\\\\\n\",\n              \"HTpEUFAQHTp0wMXFBUdHx4/qf8/IyCA2Njbf0Mtvi46OBlAq8IyNjT/J6Kz9+/ezbOUa1u5SvKvm\\\\\\n\",\n              \"p/Hu2NeshpGhIb6+vty8eZNmzZrlzM3Kb0Ta18rPzw+3Pn2RSxLa+kaUMjBi8OJtaBX78O/m4aUA\\\\\\n\",\n              \"Tq2fx/27dzAxMUGSJPbu3csvv/zCgwcPMDY2ZtiwYcyYMaPApZQkSWLx4sWsW7eOY8eOoa2tzbBh\\\\\\n\",\n              \"wzA0NKRYsWJcvnyZsLAwKleujK2tLTVr1swJLCsrq8++tNqzZ89YtXo127Ztp7RFRUrqG5KVkU74\\\\\\n\",\n              \"4xBatWrFhHFjCgzrr8mXcO1UlAiwQta//wA0ihXntxVrC3wH+eejB7g4tiXw/PmcRUOjo6Px8fHB\\\\\\n\",\n              \"29ubCxcu0KxZM1xcXOjatetHd58oIjU1VeHAe7dpa2srFXiGhoZo/WugQbPmLejiNpS2nRUfQRd0\\\\\\n\",\n              \"/TIThvWmbx83nJycaN269Rc5N6swSZLEkiVLmD13HhY29ag/cCqljM055zmNrMRYuk/7DYMyeVeB\\\\\\n\",\n              \"z86WEXT8ABd3eeJ//Bjm5uZMmzYNb29vMjIyaNasGQsWLMhzg8n8zh8REcHdu3f59ddfuXv3LhUr\\\\\\n\",\n              \"ViQ0NBRdXV2ysrIwNzenffv23L17l27dujF27NjCejk+yrs1HOPj4ylRogS2trZYWBS8cv7X5ku5\\\\\\n\",\n              \"dipCBFghS0pKwqFJE+zrNmDRstXvfXd5704QA3q7sHDhrwwYMCDffRITEzl27BiHDh3C39+f2rVr\\\\\\n\",\n              \"4+rqiouLS5G3MiRJIikpSanAi4uLQ1dXN1eo+fv7czb4BTolFA8gSZJoUbMc4eFhuT74/1ZkZGQw\\\\\\n\",\n              \"atQoDh/1pYxNA5p9vyBncVpJLuf2wXXcP7GbCjXqYt/OmVKGZZBlZvDi3g3unfwDa6vvcOrckZ07\\\\\\n\",\n              \"d/L48WPMzMwYNWoUP/zwQ75TQN6+fZvr86l3owA1NDTQ0NBAXV2dSZMmERUVhY+vH+mZWZSzqoZW\\\\\\n\",\n              \"sWLEx0TxMPgWQ4YMYfq0qWJI+hfoS7l2KkIEWCF79y42JSWFxMQkBgwZTk+3fpiZW5Cens6Na1fY\\\\\\n\",\n              \"+vt6rl25yPr16+nZs6dCx01PTycgIIBDhw7h4+NDpUqVcsJMVW75IJfLefPmDY8fPyY0NJSnT5+y\\\\\\n\",\n              \"cOFCbjx9q/TnHR0bViE46PY39445MjISV1dXdHR0uHknhJ6eJ9DQyjv8Pys9laeXjvHixlkykhPQ\\\\\\n\",\n              
\"0NTi7atQatlUIzg4GJlMRuvWrVm4cCH29vYApKSk8PDhw1whld+ACltbWywtLRk1ahTFihVj3bp1\\\\\\n\",\n              \"9OztRgaadBk8mhr1HHL9PqPCX3LqwHYu+B3gsLc3TZs2/Wyvl1CwL+XaqQgRYP/R06dP2bp1Ky9f\\\\\\n\",\n              \"vkIul2NhYU7//v2pWbMmkiTh4eFBbGwsBw8e5M6dO6xdu5YjR44QGxuLtrY21avbMHLkCPr27fvR\\\\\\n\",\n              \"3V4ymYwLFy7g7e3NoUOH0NXVxcXFBVdXV+rWrfvZB2NIkkRiYiKRkZFERER88GtCQgImJiaYmZlh\\\\\\n\",\n              \"bm5OwKlTnLz+GF09xUcHZmdn07xmWaLfvPmmbnVx+/ZtnJ2dGTJkCAGnz5BhUYt6vb5XvPwfG/jz\\\\\\n\",\n              \"xE6m/+8HOnXqxKNHj3K1rF6/fp1rQMW7rUKFCrn+pmJiYnB0dKRGjRqsWrWKdh06YlixKoN+mPvB\\\\\\n\",\n              \"KR93Lp9j7cxxnDl9itq1a/+Xl0L4hESAfQEK+5cQHBzMtGnTuXnrJr379KeajS1qamo8exLK3t3b\\\\\\n\",\n              \"sfrOinr16hIQEMCVK1fyrOgtSVKhBItcLufmzZt4e3vj7e1Nenp6Tpg1bdr0P43Ak8lkREdH5xtG\\\\\\n\",\n              \"//6ZhoZGTij9++s/vzcyMspVJyenLtRp1gHn3gMVrlfg6RNsX7uY27dU4z/d+7x8+ZJ169dz4MBB\\\\\\n\",\n              \"YmNj0dLSwrpyZb4fOYIePXrk6s7bt28fY8aMYdWqVZw6dYpt23fQZ+1JSpQ2Ufh8aQmx7BrZFi0N\\\\\\n\",\n              \"9ZwVKv65WVtbFzig4sWLFzmDjBYsWMC6devYvOcg/1u1Q6H5iqf+2EXIuaNcDDyvcL2FwiUC7AtQ\\\\\\n\",\n              \"mL+E06dP4+bmxvSZc+nddwA6OrlvZZ6VlcXRw95MHDeS+fPmMX78+EKpR0EkSeLBgwccOnQIb29v\\\\\\n\",\n              \"wsLC6Nq1Ky4uLrRt2zZn1GBKSkq+IfTvr7GxsRgaGr43jP759WNbQidOnGDyD9PY7nNe4YAfM9AF\\\\\\n\",\n              \"D/cBDB48+KPOWdQyMjIYOep7Dh8+TNsuPWnbrTdlzMuSlZXJo7tBHNu/laePQti4YT3dunVj1qxZ\\\\\\n\",\n              \"7Ny5k40bNzJ79mzKlCnDcX9/huy8pfS5tw6oz6sXzz9qQFBISAidOnVi8uTJTJgw4e8V8G1r0nvi\\\\\\n\",\n              \"HGzrN1HoGFlZmYzt3JDAc2ewsbFRug7Cp6dKASbmgSnp4cOH9OnTh8079tGkWYt899HS0sK1Z2+q\\\\\\n\",\n              \"VbfBtWsH6tat+9n7+eVyOTExMchkMurXr0+5cuV4+PAhV65cYfDgwbx9+xYdHR1kMhlAvmHUtGnT\\\\\\n\",\n              \"XP8uU6ZMge/IExIS/l4JPTGRkiVLYm9vj5mZmcL1bt++PfJJkzm0ZyuufYcUuP8Z/6PcuXUNH2N9\\\\\\n\",\n              \"OnbsqNS5vgSZmZk4OnVBpq7NzoDb6JTMHfxN25rTtG1nHt27zdDhffn++++RJImKFSvi5OSEtrY2\\\\\\n\",\n              
\"GRkZyLKzP65VL0kf9Wbj0qVLuLq6snz5cvr27QvAtWvXSE3PoEY9B4WPo6VVjFbOfdi4cRMrVixX\\\\\\n\",\n              \"uh7Ct00EmJIWLPiVUWMnvje8/snGtiY//7KY2bPncPr0qU9y/rS0NCIjI9/bUnr3/Zs3b9DX188T\\\\\\n\",\n              \"Si4uLowaNQptbW1CCfrISgAAIABJREFUQkJybvleo0aNnOH5H7OIcEhICJ6enuzbv58atrXQL21I\\\\\\n\",\n              \"akoywbdv0rZtW8aNHUvz5s0LPI66ujpHfY7QtFlz5JKc7n3d33tRDvD1Zt70cfj7n+DYsWPUqlWL\\\\\\n\",\n              \"xYsXM2jQoDxlJEnixo0brF6zlitXrpCSnIyunh6tWrZg9OjR1Kyp+H2hPkZ2djaxsbG8efOG6Ojo\\\\\\n\",\n              \"nK979uwlS12bX9Z7fbB7t1rNOizZeoQxbh1QR+L69esYGBhgY2ODqakpx08G8DbsKYblrRWuU3z4\\\\\\n\",\n              \"X5QspZunB6Eg71bX2LlzJx06dMj5eWhoKFY2tZUO0e9sanE34A+lyggCiABTSkxMDEeP+nDz7hKF\\\\\\n\",\n              \"yzi79mTOT1N59OgR1apVy3efd6utv+/zpH9+TUtLw8zMLE8XXv369XMFlampaYGL0bq6ugJ/t5r8\\\\\\n\",\n              \"/Pzw9vZm4sSJ1K1bFxcXF5ydnfO9Jfu/bdu2jSlTfmDAsFGcvHibMmbmOY8lJiZwaP9u+vUfwKCB\\\\\\n\",\n              \"A5g3b16BFzhra2suXgikm7Mz3rs3073fUFp36EopPX1SU5K4eOYkuzevITU5EavvLPHz8+PXX3+l\\\\\\n\",\n              \"Z8+eDB06lN27d7Nx40YqVaoE/D1R1a1PH6Ki3tB74DBWDh1LyVK6JCbGE+DnQ/sOHalevTq7d+1U\\\\\\n\",\n              \"uAUnl8t5+/ZtnkB68+ZNvj+Lj4/HwMDg/7V353E1pv//wF+V9r1zTnVaVKgoSjWWFGXJmp0iskWy\\\\\\n\",\n              \"ZKexDIaJsYxl7IyQXZZsUdlK1vCRXamE0h5pr3Pevz/mq980beeohsP1fDzOg865r/u6zm3menXd\\\\\\n\",\n              \"93VfN3g8HrS1taGtrQ11dXU8jHmIgPO3Rbo2adzMHO7jpuLMoV1Yt24dnj59ijNnzuD9+/doaWmB\\\\\\n\",\n              \"5xcPwWHCYpHaDwCxl4MwfryXWIGzZ88eLFiwAOfPn6/0zLrS0lLINBJ/AWGZRrIoKRH/0SYMw66B\\\\\\n\",\n              \"iWHnzp0Iu3QVO/ccEKvc4gVzkZmWgs6dO1cZSmlpaVBSUqr2etI//9TU1GzQWYWFhYUICwvDyZMn\\\\\\n\",\n              \"ce7cOTRt2hSDBg3CoEGDYGZmVmn7oKAgTJ8xE4HHz6GZWdUBDQCZGekY694fQwcPxOLFonWyQqEQ\\\\\\n\",\n              \"V65cwZ+bNiMyIgKfPuVCWUUFrVpZ4cXzZ0hOTkZeXh46dOiAWbNmwcfHB6WlpVi3bh3WrFmDX375\\\\\\n\",\n              \"Bd27d0eXLl0xfupsjBg3scqJBaWlpdi+YTXOnjiE/YGBkJaWrjWQsrKyoKqqWiGQPv/9339qa2tD\\\\\\n\",\n              
\"WVkZb968QWxsbPnr2rVr0OIbYeVO0Z4XBwA5WRnw6GqDlpYWGDRoEFxdXfG///0PCxcuRGbOB7j9\\\\\\n\",\n              \"GQIljdpH0IUfsxE8pz8eP/yfSPdi/XN1jdDQUJibm1fa5ty5c1j02+9YtCNI5O8DAGFBgch7/QSH\\\\\\n\",\n              \"DuwXqxzTMNg1sO/U+/fvYWzSROxyJk2aITTkLFRVVcHn82FlZYUePXpUCCdxnhnWkBQVFdG/f3/0\\\\\\n\",\n              \"798fpaWliIiIwKlTp+Ds7AxNTc3ye81sbGxQWFgIn0mTsO/Y2RrDCwC4PG3sPnQSvZ3aYvjw4TA1\\\\\\n\",\n              \"Na21LdLS0ujWrRu6dau8JqKDgwNCQ0PRv39/XLhwAY6OjjAwMICrqyv8/PzQv39/jB07Fr8sXoJ5\\\\\\n\",\n              \"v/wGN89x1dYjKysL37kLAQB9+/VDi+bNoaOjUx5ATZo0Qfv27SsEEpfLrbSKiFAoxNu3b8sD6ubN\\\\\\n\",\n              \"m+V/T05ORuPGjWFmZgYzMzPY2dnh2YuX6DpweK3H4Z80OTzYtLHH/DnTISMjg5EjR0JDQwNBQUEI\\\\\\n\",\n              \"DQvH7jVT0fXnHVBQrf6G7uL8XISv9MGUyZNFCi+hUIg5c+YgPDwcN27cgL6+foXPS0pKcPHiRQQG\\\\\\n\",\n              \"BuL5w/tIS34DHX3Rb6y/fvYYVv+2VOTtGeYzFmBikJGRgUAoELucQCBAt27dsG3btgZoVcORlZUt\\\\\\n\",\n              \"D5BNmzbh7t27OHnyJNzc3FBWVgZTU1O0tm2DltY2Iu1PW5ePIR6jsH37dvzxxx9it6egoKB8FNSu\\\\\\n\",\n              \"XTssXrwYL1++RHp6OmxsbDBo0CA0bdoU+fn5SE9PBxHBtLlljeH1T1PnLMCFM8exatWqaq/XEREy\\\\\\n\",\n              \"MzNx9+7dCqOp2NhYxMfHQ0tLqzykzMzM4OLiAjMzM5iYmFQKvAOHDkNdU/znmamqa2LevHmQlZXF\\\\\\n\",\n              \"ypUr4erqCikpKdjb2yP3Uy4OLfFEq8GTYNy2G2T+UaegtBSvoy/j0fGtKPyYhfZt29RaV0lJCcaN\\\\\\n\",\n              \"G4fXr18jMjKyfNV5oVCI69ev49ChQzhx4gQsLCzg4eEBLk8bl4ICMWLGIpG+y6snD/EpO0Pk548x\\\\\\n\",\n              \"zD+xABODsbExog4eFrvcsyeP0cK8aQO06L8jLS2N9u3bo3379li1ahWePHmCfv37Y8lK8WaOjRgz\\\\\\n\",\n              \"AQO6O2LlypUQCoXIyMgQ6RpSeno6hEJh+SiIw+HgxYsXiI+PR9OmTdGyZUvY2dlh+/btOHXqFGxt\\\\\\n\",\n              \"bdGzVy8MGjFe5LZJSUlh2OgJ2LxlC+zs7BAXF1choF6+fInY2FgQEczNzctDys3NDWZmZjA1NRVr\\\\\\n\",\n              \"Rp+ykjIKCwvEOn4AkJ2VCRcXF2zcuLHCtTMpKSms/+MPdHZywqo/1uP4/jVobOMIGUUVCArz8Pbh\\\\\\n\",\n              \"DTRvbo4dG9fCwMAAffr0AYfDqXaGbF5eHoYMGQI5OTmEhYVBUVERDx8+xKFDh3D48GFoaWnBw8MD\\\\\\n\",\n              
\"9+/fLx/JJSUlwa5NG1i2dUTrDs41fo/cnCxsWzwNS5Ys/u6fEMA0DHYNTAz5+flo3LgxLl+/i8ZG\\\\\\n\",\n              \"xiKVyc3NhVVzYzx/9qzSqRdJp6amhqiYOKipibf+oHVTPgRlpSgpKanxGtK/31NRUalw/W/atGnQ\\\\\\n\",\n              \"0NDAsmXLyt/btGkTtm7divDwcJiamuJeXGqlkU9NsrMy0cnGFDLS0mjWrFmF0dTnF5fLrZfrkAsX\\\\\\n\",\n              \"LcLLNxmYvGCFyGWKiwoxomtr3Iu+iyZNaj6d/fz5c0RFRZU/b8vR0bHCvVbh4eEYOXIkLl26VGkW\\\\\\n\",\n              \"5ufVNVq2bImff/4Zx44dK394pYeHBzw8PNCyZcsq642KikL/AQMxdLIfnPoNhWwVS1vFP43B1l98\\\\\\n\",\n              \"4eHuhpUr/EX+/kzDk6RrYCzAxDR9+nQUlwErVq8Tafs/16/Bts0b0MTEBP7+/t/VYxnk5eURE58K\\\\\\n\",\n              \"BTGnYXe0McfZM6dhbS3+lOt/evz4MXr27ImkpKQK96fNmTMHUVFRSEhMRNSjRLH2SURooaeK4uJi\\\\\\n\",\n              \"sYLvS7x58wZW1q1x8PJDkRcvDg0+gofXziH04oV6acORI0cwd+5cREVFVRhFde3aFU2aNEFeXh7i\\\\\\n\",\n              \"4uLg5uYGDw8P2Nvbi7TCxqNHjzDFdxqeP3+BzgOGo4mFFWQaySIzNRnXzx5DXk4mlixZjPFeXvXy\\\\\\n\",\n              \"PZj6I0kB9v0/m76ezZ8/HxfPn8aBfQG1bnsx5Cy2b96Aa1evwtfXF97e3ujatStu3779H7S04XE4\\\\\\n\",\n              \"XKS+TxGrTHFxMXKys2BsbFznUUyrVq1gZGSEkJCQCu+vXr0aOjo6yPuUJ/Y+i4uKICsr2+DPpCL6\\\\\\n\",\n              \"+14uADi4Q7RfhvI+5WLvnyvR2bn2exBFNWzYMMybNw/du3dHYmIiVq78+4GSycnJ4PF4WLRoEVJS\\\\\\n\",\n              \"UrBlyxY4ODiIFF4AYGVlhesR1xB57QoMlKXwKPwE7pz+e7bhGv9f8ToxgYUXU2fsGpiYdHV1ERoa\\\\\\n\",\n              \"ih49e+JRzENMmjodJk0qXt9KSUlGwM5tOHxgL06fPo0WLVqgRYsWGDp0KPbu3YuhQ4fCxsYGy5cv\\\\\\n\",\n              \"l+hFTAcNGohTRw9i5nzR7z0KPReM5s1b1NtjT7y9vbFz507069ev/D1paWkcOXIEOrp8PIn5n8iT\\\\\\n\",\n              \"TADgdlQErOo4MqzN27dvMXXqVMTGxsKpU0ecPrgb6hocDB1b/UK8nz5+wFJfT7Sxs8HGjRuRm5uL\\\\\\n\",\n              \"pUuX1nqvX21KSkpgbGwMGRkZNGvWDDIyMvD29saqVavq5ZlqFhYWWL9etIBmGHGxEdgXMDc3x53b\\\\\\n\",\n              \"t6GppoSeXRwwpH9PLPp5Nn6ZPwcj3QeiU7vWKMr/iFu3bqFdu3bl5WRlZTFhwgTExcWhS5cu6NGj\\\\\\n\",\n              \"B4YNG4bY2Niv+G2+3OTJk3HkwB6Rb0IlIuzduQXv36fA1tYWgYGBdb6B1c3NDbdu3cKbN28qvK+o\\\\\\n\",\n              
\"qIjp06Zh744/xdrfkX07MWWy6Cu6i0MgEGDz5s2wtbWFlZUVWrRogfT0dDQ3N8OhHevgO6wnosLP\\\\\\n\",\n              \"Q/B/y3sBQHZmOg7tWA+fgU5w6tAOp4ODERMTg0ePHsHBweGL/tsRCoWIjIzExIkToaenh9WrV8PJ\\\\\\n\",\n              \"yQmNGjVCq1atsG7duh/ugaCMZGIB9oV0dHSwZs0avH37FpMmeqOpsSGMDfUwYrg7kpKSsHXr1vKV\\\\\\n\",\n              \"IP5NQUEBM2bMwKtXr2BlZQUHBwd4eXkhKSnpv/0SdWRhYQFjYyPMmzYRolxK3b3tT5QUFSIhIQEr\\\\\\n\",\n              \"VqzAgQMHYGxsDH9/f2RmZn5RG5SUlODh4YHdu3dX+szXdyoiLoUi7sUzkfZ1/+4tPHp4H+7u7l/U\\\\\\n\",\n              \"lpo8efIEjo6OOHr0KIKDg3H58mUUFBQgISEB6enp0Nfjo7eLM0IO74BbJwt493fE2F5t4eXaAZSb\\\\\\n\",\n              \"ivPnzmDDhvWQkZGBtrZ2+XJODg4O+Ouvv6o8/vfu3cOYMWNh0qQJuFwe9A0MYWVlDT6fj8mTJ8PE\\\\\\n\",\n              \"xAT37t3DuHHjEBwcjKtXr6Jx48YYM2YMhEJhvR8Dhql39J2ys7P72k0QWXZ2Ni1YsIC0tLRo6tSp\\\\\\n\",\n              \"9P79+6/dpFoJBAKaP38+GRkZkbV1axrk5kEPYt9RYmZhpdezN1k0bfZ8MjIypqSkpAr7efToEY0b\\\\\\n\",\n              \"N440NDRo4sSJ9Pz5c7Hb8ujRI9LX16fS0tJKnwUGBhJfT59CIu/Ty9T8al9BIRHE09ahCxcufPEx\\\\\\n\",\n              \"qUphYSEtXLiQuFwubd++nV68eEHNmjUjJycnkpOTIy6XS/v27aOysrLyMikpKfTo0SN68eIFffr0\\\\\\n\",\n              \"qcb9P336lKytrWngwIGUmZlJRERv374le/sOZGBoRL7zltKJS/co7G4cnbrygGYuWkGNjZuQja0t\\\\\\n\",\n              \"xcbG0u+//05GRkb04sULIiIqKCigjh070rRp00goFNbrsWAkgyT1nSzAviFpaWk0Y8YM0tLSIj8/\\\\\\n\",\n              \"P8rKyvraTarSp0+faMCAAdSxY0dKT0+n/Px88vLyIjV1dRo63JN2HzpJJy5cpcDj58hrki9paXGo\\\\\\n\",\n              \"j6trjcGcmppKS5YsIW1tberduzeFh4eL1YHa29vTmTNnqvxsz549pK6uQZ5ekyjs1qMKwXX26l0a\\\\\\n\",\n              \"PHwUaWppVVv+S127do3MzMxo8ODBlJycTFFRUcTlconP55OMjAzNnTuXCgsL61xPUVERzZ49m/T1\\\\\\n\",\n              \"9enAgQOkp6dPvvOW0t1X2XQ/8WOlV3R8DvktW0tqaupkampK7969q7C/nJwcsrKyIn9//zq3jZE8\\\\\\n\",\n              \"ktR3sgD7Br1584YmTJhAHA6Hfv31V8rNzf3aTSqXlJRE1tbWNHbsWCouLq7wWXp6Oq1cuZK6dutG\\\\\\n\",\n              \"bdq2JefOXWjevHmUkJAg8v4LCgpo165dZGFhQa1ataKAgAAqKiqqtdyePXuoT58+1X4eHx9PLi4u\\\\\\n\",\n              
\"pKioSM3MzKm17U/UpGkz0tXlk7W1Nc2dO1fkNtYmOzubvLy8yMDAgIKDg4mIKCAggBQVFUlaWpqa\\\\\\n\",\n              \"NGlCiYmJ9VYf0d+/VMybN48UlZRp2s+/Vhlc/34tXLGRjE2aUElJSaX9paSkkImJCe3atate28l8\\\\\\n\",\n              \"+ySp72QB9g2Li4ujESNGkLa2Nq1du5YKCgq+antu3rxJfD6f1q5d2+Cnl4RCIV28eJF69OhBurq6\\\\\\n\",\n              \"tGzZMkpPT692+/z8fNLS0qp0ivLfZs6cSTY2NnT16lV68uQJlZSUUGxsLHG5XMrJyalzm48cOUJ8\\\\\\n\",\n              \"Pp+mTJlCHz9+pPz8fOrVqxdJSUmRkpISTZ06tcLpwrooLi6mM2fO0LBhw0hNTY3atm1LTU2b072E\\\\\\n\",\n              \"DyIF2P3Ej/RTewcKCgqqcv+xsbGkq6tbHsLMj0GS+k4WYBLg8ePHNGDAANLX16dt27ZVGvn8F/bv\\\\\\n\",\n              \"309cLpfOnj37n9f95MkTGj9+PGloaNCECRPo6dOnVW43depUWrJkSY37EggE5ObmRm5ubiQQCMrf\\\\\\n\",\n              \"HzNmTK1la5KUlER9+vQhS0tLunnzJpWWltKOHTtIWVmZFBUVSV1dvdqgEIdAIKDIyEiaOHEicTgc\\\\\\n\",\n              \"cnBwoC1btlB6ejr16z+AFvivFzm87id+pBV/7qZOTs7V1hcdHU08Ho8iIyPr3HZGMkhS38kCTILc\\\\\\n\",\n              \"vXuXunfvTiYmJpUu/DeUz5M1TExM6PHjxw1eX03S0tLo119/JR0dHerRoweFhoZWGAk+evSIDAwM\\\\\\n\",\n              \"qpzM8U+FhYXk6OhY4bRhfHw8cTgcsa87lpWV0YYNG4jD4dDy5cupqKiITp06Rebm5qSpqUk6Ojp1\\\\\\n\",\n              \"PnZCoZBiYmJo3rx5ZGhoSJaWlrRixYpKpyEVFBToWkySWAF2+2UGycnJ1XiaNjw8nLS1tSkmJuaL\\\\\\n\",\n              \"vwMjOSSp72QBJoGuXbtGDg4O1KJFCwoKCqowkqhP/56s8a0oLCykgIAAatmyJVlaWtJff/1VPhmi\\\\\\n\",\n              \"ffv2Ik3GyMrKInNzc9q8eXP5e+PHj6eFCxeK3I6YmBhq27YtderUiV68eEGRkZFkb29PzZs3JxMT\\\\\\n\",\n              \"E+Lz+eTi4vLFk3ESExNpxYoVZGlpSYaGhuTn51dtiBQVFVGjRo3EOn34+aXF4VJaWlqNbTly5Ajp\\\\\\n\",\n              \"6+uLdT2TkUyS1HeyAJNQQqGQQkJCyMbGhmxtbSkkJESk61J3796liRMnUo8ePcnFpTuNHj2GwsPD\\\\\\n\",\n              \"K4VgTZM1vhVCoZDCw8OpV69epKOjQ0uWLKH169eTq6urSOUTEhKIz+fT6dOniejvwNDS0qKMjIwa\\\\\\n\",\n              \"yxUUFND8+fOJx+PRrl27KCYmhlxdXcnIyIhWrFhBfD6fNDQ0aN68eWKPktPT02nLli3UoUMH4nA4\\\\\\n\",\n              \"5OPjQ5GRkbX+khIXF0fS0tJ0+2WGWOF1L+EDKSgq1jpdn4ho06ZNZGpqWmvYMZJNkvpOFmASTiAQ\\\\\\n\",\n              
\"UFBQELVo0YIcHR0pIiKiyu0iIyPJ1taOjIxNaNFSfzp47DQdPn6WVq7ZQJYtrahZM1M6duwYEf23\\\\\\n\",\n              \"kzXqy7Nnz8jb25vU1dVJXl6ewsPDRSoXHR1NXC6X7ty5Q0REPj4+5OfnV+32ly9fpmbNmpGbmxtF\\\\\\n\",\n              \"R0fTmDFjSFtbm9avX08XLlwgVVVVUlVVpSNHjojc9k+fPtGBAweod+/epKamRsOHD6ezZ89W+4tD\\\\\\n\",\n              \"UVERnTt3jkaPHk1mZmYkJydHAEhVVY3+3HNcrAALOB5GTZo2E/nfedGiRWRnZ/dNzYxl6pck9Z0s\\\\\\n\",\n              \"wL4TZWVltG/fPjIxMaHu3bvT3bt3yz87deoUcXk82r3/KKV+KKb03NIKr7SPJXT6whUyMDAkz1Gj\\\\\\n\",\n              \"iMfjfZXJGvUhIyOD2rVrRyoqKuTi4kIXLlyotXM+c+YM8fl8io+Ppzdv3pCWllalUUZmZiaNHTuW\\\\\\n\",\n              \"DA0N6cCBAzR79mzS0tKihQsX0ocPH8qnyfP5fHr48GGt7SwuLqazZ8/S8OHDSU1NjXr16kUHDhyo\\\\\\n\",\n              \"ciSUlJREa9eupW7duhGXyyUpKSmSkZGhxo0b09ChQ+ngwYOUn59Pu3fvJqduPcUKsD4D3eiPP/4Q\\\\\\n\",\n              \"+fgKhULy9vambt26iXR7AyN5JKnvZAH2nSkuLqZt27aRnp4eDRgwgIKCgojD5VJ4xO1KwfXv14On\\\\\\n\",\n              \"8aTF4dCmTZu+9teok5iYGNLX16fdu3eTtbU1tWjRgnbu3FnjbQhbt24lMzMzyszMpKlTp9KsWbPo\\\\\\n\",\n              \"6tWr5Os7jRw7OZG6hiZ16NCB5s6dW35qLyUlhYRCIfn5+ZGCggJ16NChfDWMqnyeQejj40NcLpc6\\\\\\n\",\n              \"dOhQPoPws8LCQrp8+TJNmjSJLC0tSV5evnwKvq2tLc2aNYuio6OrPKWYn59PXC6PdgeFihReB89F\\\\\\n\",\n              \"koaGJmVnZ4t1fMvKymjgwIHk7u7eYNdfma9HkvpOFmDfqYKCAlq7di1paGjQr/6raw2vz68DR4PJ\\\\\\n\",\n              \"zu6nr938Omvfvj2dPXuWhEIhXb58mVxdXYnH49Evv/xS7Yogfn5+5ODgQOvWrSdlFVUyampKY2cs\\\\\\n\",\n              \"pDkr/qTpS9dSpx79SEFRifr1H0Dv37+n4uJi6tevH8nJydGkSZOqnf0YExNDfn5+1LhxY7KwsCB/\\\\\\n\",\n              \"f39KSEggoVBIr1+/pm3btlGfPn1IR0eHpKWlSVpamnR0dKhXr160ZcsWSklJEfl7h4aGEpenTXtP\\\\\\n\",\n              \"XqoxvA6fjyIdXX75aWNxFRYWUqdOnWjq1KmVRrhCoZByc3MpKyvrP5kpy9QvSeo7WYB9x1JTU0ld\\\\\\n\",\n              \"Q4Pi3mSIHGDvc4qosZFxhVOQkiggIID69u1b4b0XL16Qj48PaWho0JgxYyrN6BMIBNTKyop09Axo\\\\\\n\",\n              \"9d5TdPFpGoU+S6/wCrr5goZNmEYGho3Lrz8FBARUqv/zDMKWLVuSoaEhzZs3j+7cuUNXr16lmTNn\\\\\\n\",\n              
\"UuvWrUlBQYFkZGRIVlaWmjdvTuPHj6eLFy9Sfn5+nb77uXPnSIvDpcHDx9Ch89crBFdQ2B0aNtqb\\\\\\n\",\n              \"NLU4dPjw4TrV83nJqd9++42I/j7VOX/+fOJpa5OysjKpqauTgoICDRs2nK5fvy4x11N/dJLUd7In\\\\\\n\",\n              \"Mn/Hdu3ahbDLV7F1V6BY5Vb5/wopQRHWrFnTQC1rePn5+dDX18f4CRPw6NFj5OXlQVVVFZ2dnTBw\\\\\\n\",\n              \"4ECcOHECW7ZsQYsWLTBr1iz07NkTO3buxO+r12LN/rPQ0OLWuP/gA7uwZ4M/LoaEwMnp7wdMZmZm\\\\\\n\",\n              \"4tixYzh06BCeP3+OHj16oHHjxoiNjcWdO3eQmpoKKSkpKCsrw8rKCr169ULfvn1haWkp8oMiRZWa\\\\\\n\",\n              \"moq//voL23fshJS0NFTV1JH/6ROKi4vgPWE8vL29YWBgUOd63r9/jw4dOsDaujUiIyMxwG04PEZP\\\\\\n\",\n              \"QFNTcwDAxw85OHn0IA7t3QljYyOcOH683p4FxzQMSeo7WYB9x1asWIG0rI9YtNRfrHIHAwMQHHQY\\\\\\n\",\n              \"U6ZMhrKyMpSVlaGkpFThT2VlZSgoKNR7x1sfcnJyMGXKVJw+cxrdXQehc3dXqKip4dPHj7gWdhZX\\\\\\n\",\n              \"Lp7DoMGD8MfatQgJCcG6deuQl5eH5JT3WH/oPIxNW4hUz8rZE9DTyR6GhobYt28foqKiYGxsjJKS\\\\\\n\",\n              \"EiQnJ6OsrAxCoRB6enqwt7dH//790blzZ/D5/AY+Av9fWVkZEhMTkZubC1VVVZiYmEBWVrZe6/Dx\\\\\\n\",\n              \"mYTwy1dwODgUPB3dKrcRCARYvnA2nj68j4iIa+x5Y98wSeo72ROZv2OysrIoLS0Vu1xJSQmio6Ph\\\\\\n\",\n              \"5eUFOTk5yMjIQFpaGkQEgUCA0tJSFBUVoaSkBIqKihVCraqgE/e9z39XUlISOyAzMjLQyckZNu0c\\\\\\n\",\n              \"EXrnBVTVKv6279y9N2Yu8sd6/0Xo5tIdV69cho6ODoYMGQLDJqYihxcA9PXwwsKJwyEsK0VZWRlk\\\\\\n\",\n              \"ZWURHx+PFi1aYMqUKejVqxfatWv3VTvrRo0awdTUtMH2HxoaiouhoThxMRJanOpHrTIyMliycj3m\\\\\\n\",\n              \"TfPGnDlzsG3btgZrE/PjYAH2HTM1NUXw6bNil4v5333MnTsHgwcPxvv375GSklL++ufP79+/h5yc\\\\\\n\",\n              \"HDgcDng8HrS0tKChoQE1NTWoqqpCUVERCgoKkJOTQ3FxMfLz85GRkYHXr1+joKAA+fn5yM/PL//7\\\\\\n\",\n              \"v98rLCyEvLy8yEGopKSEQ4ePoJNLH8xYsLza76euoYUlq7fg98Vz0MrKCh8/fICQgAGe3mIdJ0vb\\\\\\n\",\n              \"dlBRVUMzEyN4eHjA2dkZFhYWkJGREfuYS6r1GzZgyqz5NYbXZ1JSUpj3y2/o7tAaK1euhIaGxn/Q\\\\\\n\",\n              \"QuZ7xk4hfsdKS0vR2MgIx4IvoHkLS5HKfMjJQVtrM7x8+RLa2to1bisUCpGdnV1luP3z59TUVKip\\\\\\n\",\n              
\"qYHP50NPT6/89c+f+Xw++Hw+5OTkKuy/qKio1qD7/Ofjx48REXUTp67cE2nkVlZWBpc2ZsjL/QhF\\\\\\n\",\n              \"ZRUs334Yza1sRTpOny3wGgLbluawsbGBrKws5OTkICsrW+kl7vuysrLf5OnZf0pISECbNm0R9TAO\\\\\\n\",\n              \"CoqKIpebMXEUnB3tMXPmzAZsHfOlJKnvZCOw75isrCwmjB+PTetXY/OOvZCSkqq1zPYtG8HhcER6\\\\\\n\",\n              \"pLy0tDS4XC64XC6srKyq3U4oFCIrK6tSuD179gyXLl0qD7q0tDSoq6tXCrd//6yrq1vldZyBgwZh\\\\\\n\",\n              \"hNdkkTv+Ro0awXP8VESGnsabd+8gFApEKvdPZWVlSE1NRUxMDEpLS1FSUoLS0tJKL3HfLy0thYyM\\\\\\n\",\n              \"TJ2DsCHfDwsLQ3tHJ7HCCwC69OiD6+HnWIAxdcYC7Ds3e/ZsODg44o9V/pjtt7DGEAs6chCH9wdg\\\\\\n\",\n              \"4MCBsLKywooVK+Dl5SVS8NVEWloaPB4PPB4P1tbW1W4nFAqRkZFRaST3+PFjhIaGlgdfWloaNDU1\\\\\\n\",\n              \"K4Sbrq4uzp07h7n+m8VqW3+3kdj6hz/kFOSR8OIpLFq3EbmsQCBAWvIbHNq7CxYWFmLVWxsiQllZ\\\\\\n\",\n              \"WZ2DsLb38/Pzv3g/aWlpaOvgJPZ3U1FRw6dPn+r1eDE/JhZg3zl1dXWEhYWiZ89eePbkEXx8Z6BN\\\\\\n\",\n              \"W/sKofT82RPs3rEFVy6FIiwsDC1btoS3tzcmTJiAgwcPYufOnQ06EeAzaWlp6OjoQEdHB61bt652\\\\\\n\",\n              \"O4FAgIyMjAqnKV+9egV5OXkoq6iKVacWlwciIaQBnArcjj7uo0UO7OjIS9DX16v38AL+vl70eaTz\\\\\\n\",\n              \"rTp06BAOHzspdrnc3A9QU1NrgBYxPxoWYD8APT093Lx5Azt37sT0SV6Ql1dECwtLSMtI43VCPN6+\\\\\\n\",\n              \"fQPvCRNw/9698ute1tbWuHXrFjZt2gR7e3vMnj0bc+bM+SY61NzcXCQkJCAmJga3b9/G48ePkZiY\\\\\\n\",\n              \"iJKSErH3JRQYdUeoAAAbe0lEQVQKIRQK0bRpU8QnvkbMnSi0bt9RpHKHtq9D/14uX/IVvgsODg6Y\\\\\\n\",\n              \"OtUXBfn5UBJjpuWlC+fQy6VzA7aM+VGwSRw/GKFQiBs3biApKQlCoRB8Ph/Ozs41BtPr168xadIk\\\\\\n\",\n              \"pKSkYNeuXWjbtm2Dt5OIyq+T3b59G9HR0Xj+/DmSk5NRXFwMGRkZlJWVgcvlwsTEBFZWVjh+4gR2\\\\\\n\",\n              \"HQ1BM3PRR0SPHtzFFM9B2LZtK5SVlTF6zFis3X8WRs3Mqy0jFAqx8/dFiIu5iw852ejVqxfWrFnz\\\\\\n\",\n              \"Q96g27dfPzh07gV3z7EibZ+WmoJenX5C0uvXbBT2jZKkvvPbnubEfDEiws2bNzFi5Ejo6+tDVVUV\\\\\\n\",\n              \"fD4fAwcOQmFhITw8PDBq1Ci4uLjUOqoyNjZGSEgI/Pz80K9fP8ycORN5eXn10k6BQIBXr17h+PHj\\\\\\n\",\n              
\"mDFjBpydnWFgYAB5eXkYGRmhR48eWLFiBZ4+fQpTU1PMnj0bwcHBePbsGYqLi5GWloYrV66gbdu2\\\\\\n\",\n              \"UFBQwKEA8e4vCty5CYqKikhNTcW2bdtgatoMP48bhDOHdiM/r/J1mpeP/4ffpo9ByquniIy4hqdP\\\\\\n\",\n              \"n0JaWhotW7bE+fPn6+WYSJKZM2Zg64ZVyEhLrXVboVCIFYv94DlyJAsvpl6wEdh36O3btxgydCiy\\\\\\n\",\n              \"MrPgNcEHfQcMhIaGJvLz8hAWGoJdO7ahpLgIx48fR8uWLcXad2ZmJmbPno3IyEhs27YNPXv2FKlc\\\\\\n\",\n              \"UVERXr58iaioKNy5cwdPnz5FUlIScnJyICUlBSKCpqYmjIyMYGlpifbt28POzg7m5ubV3i8UHx+P\\\\\\n\",\n              \"bdu2Ye/evejQoQPc3d0xafJkBIXdhq5e7cskvXkdj5GuzggI2I0pU6YgOzsbo0aNgpOTE04Fn8bl\\\\\\n\",\n              \"y5fRzskFalpclBYXI+buDQhLizBl8iRMmzYNiv+YfXflyhWMHz8ejo6OWL9+PTgcjmgH9DuwbNky\\\\\\n\",\n              \"HDp8FH8dPgV9g8ZVblNaWopf5vribeIrXL4UXuHYMd8Wieo7v8L6i/8JSVqQsj4lJSWRoaEh/bZi\\\\\\n\",\n              \"NX3IL6XcQkGl18eCMtoZsI+0tbVFenZVVUJDQ8nExIRGjBhR4XEgHz58oCtXrtCSJUvI1dWVzMzM\\\\\\n\",\n              \"SEVFhaSkpEhKSork5eWpcePG1KVLF5o1axadOHGC4uPjRV61XCAQUEhICPXu3Zu4XC7NnTuXEhIS\\\\\\n\",\n              \"KCUlhXr06EGycnKkb2hE4dGx9PBNbrWvkJtPyNDImDZs2EhOTk7k6elJ7969I39/fzI0NKT27dvT\\\\\\n\",\n              \"hg0baPPmzWRiYkJeXl4kLy9PhYWF1bYtLy+Ppk+fTnw+n44fP/5Fx1USCYVCWrt2LalraNCwkWPp\\\\\\n\",\n              \"zKWb9Cotn+LTC+j24wSaveBXMjBsTH379RPpyc/M1yVJfScLsO+IUCgkW1tbWvH72iqD69+vvfsP\\\\\\n\",\n              \"k6GhYY3Pyaqunnfv3lFAQAC1bt2aZGVlSVNTk2RlZcuDSl1dnSwtLWnw4MH0+++/082bN+nDhw9f\\\\\\n\",\n              \"/N2ys7Np3bp11LRpU7KxsaGAgADKzs6mY8eOkYuLCzVq1Ih0dHToyJEj9OuyZcThatO8X1fT9Sdv\\\\\\n\",\n              \"KwRX5KMkmr14BXF5OqSmpkY2NjY0fvz4Cs+1Kisro9OnT1PPnj2Jx+OVP8SyVatWIq3Sf+PGDTI3\\\\\\n\",\n              \"N6chQ4ZQamrqF39nSZOamkq//fYbGRkZl6+yr6qqSl5eXnT//v2v3TxGRJLUd7IA+45cuXKFWlhY\\\\\\n\",\n              \"0seCMpECLLdQQC7de9C+ffuq3F9ZWRnFxMTQ+vXrafjw4dS6dWvicDgkLS1NAEhOTo709PTIysqK\\\\\\n\",\n              \"OBwOWVpa0tWrV+v1GVAPHz6kCRMmkIaGBnl4eNCNGzcoKiqKJk6cSFpaWvTTTz+RlpYWzZo1q/x5\\\\\\n\",\n              
\"XNeuXSMVFRXS1NQiRSUlcnTuRr37D6GOnV1IXV2DPDxGUFhYGBkZGZG6unqFEeS/xcXFkaGhIamp\\\\\\n\",\n              \"qZGhoSFNmDBBpO9XWFhIP//8M2lra9OBAwd+uEeJlJaW1jhaZb5dktR3smtg35EhQ4aiQ0dnTJg4\\\\\\n\",\n              \"SeQyF86fxdpV/li+fDkiIyPx4MEDvHr1CqmpqcjPzwcAKCsrg8/nw9TUFLa2tnByckLbtm0rXIgv\\\\\\n\",\n              \"LS3FunXrsGbNGsyfPx/Tp09Ho0ZfdpdGaWkpTp06hc2bNyMhIQE+Pj7o3r07Ll68iMDAQMjKymLU\\\\\\n\",\n              \"qFEoLi7Gtm3b8Ndff6Fv374gIpw8eRIjRoyAsbExdu/ejSZNmiA6OhqfPn2Cmpoa2rdvDwBwcXGB\\\\\\n\",\n              \"i4sLpKWlcfPmTYSHh0NBQaHK9nTu3Bl+fn44efIkTpw4AVVVVUycOBFeXl61Lrd17949jBs3DkZG\\\\\\n\",\n              \"Rti+fTv09fW/6JgwzH9FovrOrxygDUaSfouoL2pqapTwJlXk0VduoYBy8kpIVlaWZGRkiMfjka2t\\\\\\n\",\n              \"LY0cOZL+/PNPevLkidiPjI+Li6MuXbqQnZ0dPXjwQKyyKSkptHTpUuLz+eTk5ER79+6l7du3U8eO\\\\\\n\",\n              \"HYnH45Gvry9FR0fThw8faMiQIWRra0sJCQlERBQREUH29vakq6tLrVu3rrbdKSkpZGFhQYsWLSKh\\\\\\n\",\n              \"UEgCgYDc3Nxo+PDh1ZZxdnamK1euUFxcHBkYGNC9e/fIy8urfFRY28Mai4uLaenSpcTlcmnXrl0/\\\\\\n\",\n              \"3GiMkSyS1HeyAPtOCIVCAlDtxI2aXnp6evTmzZt6bUtAQADxeDyaN29ejU8YFgqFdPnyZXJ2diYl\\\\\\n\",\n              \"JSXq0aMHLVu2jNzd3UlNTY0GDRpEwcHBVFxcTEREMTExZGpqSj4+PlRYWEgxMTHUu3dvMjY2pvXr\\\\\\n\",\n              \"15OmpibFxsZWWdfbt2/J1NSUli9fXuH9goICsre3p4ULF1ZZ7nOACYVC4vF45ccqOzubNmzYQGZm\\\\\\n\",\n              \"ZtSqVSvatm0b5ebmVvtdY2JiyM7Ojrp160aJiYk1HUKG+Wokqe9kAfYdUVJSouT0D2KF18eCMtLQ\\\\\\n\",\n              \"0KCsrKx6b09qaiq5u7tT06ZN6dKlSxU+y8/Pp5UrVxJfT48UlZSota0t9ejZm+w7OJKamhq1srKi\\\\\\n\",\n              \"wMDACqOVPXv2EJfLpcDAQEpISKCRI0eSjo4Obdy4kYqKimjAgAG0dOnSKtuSmJhIJiYmtHbt2io/\\\\\\n\",\n              \"T09Pp6ZNm9Lu3bsrffY5wIiI+vfvT0eOHKnwuVAopEuXLtGgQYNIU1OTJk+eTI8fP66yntLSUlq1\\\\\\n\",\n              \"ahVxOBzatGmT2CNchmloktR3sgD7jnRycqL9h46JFWBXIm+RiYlJg3akZ8+eJUNDQxo7dizdu3eP\\\\\\n\",\n              \"Zs+eTaqqqqSiokJTfGfQw2dx9LFQUP5Ky8mnnQH7yLJlK/IcNYo+fvxI48ePJ3Nzc4qIiCBfX1/S\\\\\\n\",\n              
\"0tKipUuXlo94zp49S82aNaty4kBsbCw1btyYNm/eXGM7X7x4Qdra2hQeHl7h/X8G2KpVq2jatGnV\\\\\\n\",\n              \"7uPdu3e0ZMkS4vP51KlTJzpy5Ej56PHfdXXo0IEcHR3p5cuXtR5DhvmvSFLfyQLsO3LkyBFycu4i\\\\\\n\",\n              \"VoB5jBxFq1atatB2CQQCOn78OBkbG5OUlBSZmpqSiooqXbwUUSG4/v1Kycylri7dSVdXlwYOHEh+\\\\\\n\",\n              \"fn6kpaVF06dPp7S0tPL95+fnk7GxMYWFhVWq+9mzZ6Svr087d+4Uqa0RERHE4/EqjKD+GWDXr1+n\\\\\\n\",\n              \"n376qdb9lJSUUFBQEHXu3Jl0dXVp4cKFlJSUVGGbsrIy2rhxI3E4HFq9enX5LEqG+Zokqe9kAfYd\\\\\\n\",\n              \"KS4uJj09PTp5JkSk8Lp+6x5pampSRkZGhf0UFBRQfHw8vXjxgjIzM7+4Pf+8d8vU1JQ6duxIKioq\\\\\\n\",\n              \"pKKiQoeOnawxvP4ZYo2NjElDQ4M8PT2rvHY0f/58cnd3r/R+TEwM8fl8CgwMFKvdBw4cICMjI0pJ\\\\\\n\",\n              \"SSGiigFWUFBASkpKlJeXJ/L+nj17RtOmTSMtLS3q168fXbx4scKINz4+njp37kxt2rSp9tQjw/xX\\\\\\n\",\n              \"JKnvZAH2nbl+/TrxeDw6eyG8xvCKun2f+Hp6FBQUVF42JiaGvL29SUNDgxobGVHTps1ITU2NOnfp\\\\\\n\",\n              \"QsePH6eSkhKR2vB5P6qqqtSyZUvi8/lkaWlJq1evpmPHjpGFpSV9KCgTKcA+Fgpo/aat1LVrtyrr\\\\\\n\",\n              \"evbsGXE4HEpOTq7w/r1790hHR4eOHj36Rcdx2bJlZGdnR3l5eRUCjIioXbt2dO3aNbH3mZeXR7t2\\\\\\n\",\n              \"7aLWrVtT06ZNae3ateW/IAiFQtqxYwdxuVxatmxZtcc6IyODfv/9d3Lu3Jla29hQhw4O5OvrS8+e\\\\\\n\",\n              \"Pfui78kw/yZJfScLsO/Q1atXicfj0VC3YRR2ObLCjc037jygsV4TSEtLq7xzFwgENHPmTOLr6dHC\\\\\\n\",\n              \"xb9SbMK78u0zPxZSwL6DZG/vQK1bt64UFJ+VlJTQ0aNHqX379n8HYOPGxOVyacaMGfTgwYPyyRhu\\\\\\n\",\n              \"bu60dsNmkcPrY6GAkjM+kqamZqW6hUIhOTs708aNGyu8f+vWLdLW1qZTp0598TEUCoU0ZswY6tev\\\\\\n\",\n              \"Hzk5OVUIsJkzZ9KKFSvqtO9bt26Rp6cnaWho0OjRo+nOnTskFArpzZs31KtXL7K2tq6wekVBQQF5\\\\\\n\",\n              \"e3uTuro6jfAcTSdOn6drUXcoJOwKzf15Ieno6FDnLl3Y7EamziSp72QB9p36fPqumakp6ejokJmZ\\\\\\n\",\n              \"Oenp65OBgQEtW7aM3r9/T0R/d6aTJ08m+w6OlJSSWeNsxcW//kbNmjWrcMoxJSWFFi9eTFpaWsTl\\\\\\n\",\n              \"cklJSYmGDBlC586dq3IUYWVtTZG37okVYB8LBWTfwYEiIyMr7CswMJBsbGwqXDuKjIwkHo9H58+f\\\\\\n\",\n              
\"r/MxLC4upi5dupCBgUGFAAsKCiJXV9c675/o7xHV6tWrycTEhOzs7Oivv/6ivLw8CgwMJB6PR/Pn\\\\\\n\",\n              \"z6esrCxy7NiRBg9xo8R36VX++2R+LKTl/qtIT0+PTQph6kSS+k4WYN85gUBAb9++padPn1JSUlKl\\\\\\n\",\n              \"iQJnzpwhc/Pm9DY1W6TrZtNmzCZ3d3e6fv06de/eneTk5EhBQYFsbGxo586dlJOTU2N7zJs3p9v3\\\\\\n\",\n              \"H4kdYE7OXSrMDszOziZdXV26c+dO+XuXLl0iHo9XaRZhXeTk5JCSkhJNnTqViP4O/KCgIFLX0KBO\\\\\\n\",\n              \"nZzI0bEjubm505kzZ+q0hNbnRYpdXV1JS0uLZsyYQdevX6eBAweStrY2DXEbJtI9fpu27qAmTZrU\\\\\\n\",\n              \"eO8dw9REkvpOFmA/OBeX7rQzYJ/IsxbfpmaToqIiycrKEofDIT8/P3r16pXI9Tk4OtLJsxfECq8P\\\\\\n\",\n              \"BWVkampWYWUPHx8f8vHxKf/5/PnzxOPxKCIiol6PDxFR+/bticPh0NKlS8nCwoJatLCg31evo7MX\\\\\\n\",\n              \"wul86GXatHUHtWnbjoyMjGj//v11ri8xMZHmz59P2tra1KFDB1JTU6fUrE8i/xv16NmLAgIC6uGb\\\\\\n\",\n              \"Mz8iSeo7WYD9wOLi4ojH41F6Tr5YU+9HeI4mLy+vL1oS6Y8//iC3YR5iBdjliJvUpEmT8pl7d+7c\\\\\\n\",\n              \"IV1dXcrOziYiouDgYOLxeHTr1q16PT6fOTs707hx40hNTY1OnD5f7WLJlyNukrGxCa1evbpe6i0q\\\\\\n\",\n              \"KqK+ffuRz2Rfsf59jp08Q3YiTPVnmKpIUt/Jnsj8A7t//z4cOnaqdhHb6vTq7YqMjExISUmJXefY\\\\\\n\",\n              \"sWMRdjEEGenpIpf5c8Mf4HA4SElJQVlZGXx8fLB69Wpoamri2LFjmDhxIi5cuFC+UG99y83NxenT\\\\\\n\",\n              \"ZxB25Tpcuves9nu3adsOoZcjsXnLFgQHB9e5Xnl5eTx/8Ryeo8eKVc6le08kvX6N5OTkOreBYb5l\\\\\\n\",\n              \"LMB+YPn5+VBSUha7nLKKSvlK9eLS1NTE+PHjMWHsSJSUlNS6/akTQbhxPQKtWrWCtbU1evXqBWVl\\\\\\n\",\n              \"ZYwcORL79+/HjBkzEBYWBjs7uy9qjyjSMzIwf9ESWFjW/vRqvp4e/tiwGct/+w1UDw96yMnOhq4u\\\\\\n\",\n              \"X6wyMjIy0NHRRVZWVp3rZ5hvGQuwH5i6ujpyssXv5LKzsqCurv7F9a5cuRKamhoYOsAVye/eVblN\\\\\\n\",\n              \"aWkpdmzdjGmTJ6KllTWCg4NhZW2NiIgIPH/+HCNGjMDPP/+My5cvw8rK6ovbUpvk5GTkZGdj+AhP\\\\\\n\",\n              \"kcu4dO+J7KxsREdH17l+eXl5kYL+34qKiiAvL1/n+hnmW8YC7AfWqVMn3Lp5Azk5OWKVCz51HC4u\\\\\\n\",\n              \"3b643kaNGuHY0aPoYN8Oju1sMHLYYJw+dQK3b97AtauX8duvi2FpZoIzwSdx9fotnDkfhicvEzFw\\\\\\n\",\n              
\"sBs4HA5atWqFU6dOQU5ODs+ePauXkU51zpw5g569XKGqqipyGWlpaQwb4Ynjx4/Xuf7mLVrg9q0b\\\\\\n\",\n              \"YpV5n5KCrKxMGBgY1Ll+hvmWsQD7gfF4PPRxdcXB/ftELpP87h2iIiMwcuTIOtUtIyMDf39/JCUl\\\\\\n\",\n              \"oXfPHljl/ys8hw/FKv/l+PjhA06fu4iQsCswNTMHAKioqGCslzcib95DQmIi5s2bhx07dmDZsmVw\\\\\\n\",\n              \"dHTEzZs369Se6mRkZMDYxETscvr6+sjMzKxz/T4TJ2L3rh1ildm75y+4u7tDWVn808MMI0lYgP3g\\\\\\n\",\n              \"Zkyfjj/Xr0F8/KtatxUIBJg1YwrGjRsHFRWVeqlfRUUFPXr0QGpqKm7e/R8uXrqGtev/rPZ6ky6f\\\\\\n\",\n              \"j/OhV7Bx40bY2NjgwYMH8Pb2hru7OwYPHoy4uLh6addn8vLyKC4qErtccXFJvZzCGzBgAOJfxSLq\\\\\\n\",\n              \"eoRI22dmZGDPXzswefLkOtfNMN86FmA/uJ9++glLly5Fv94uePL4UbXb5efnY6zncJQWF2PFihX1\\\\\\n\",\n              \"2obt27dj+IhR0NbREWl7IyNjuPYbgICAAMjIyGD06NGIjY1FmzZtYG9vD19fX2RkZNRL25o3b46o\\\\\\n\",\n              \"qEixy92PvgMzM7M61y8rK4s9e/ZgrOdw/O/B/Rq3zcrKwtBBfTF27NgGvS7IMN+Mrz2Pv6FI0r0M\\\\\\n\",\n              \"34IDBw4Qh8Oh3n1c6XjwOUp4k0pvU7PpVvRDmjptJnE4HBo9ZgwVFRXVa70CgYB4PB49ePxCrHud\\\\\\n\",\n              \"rl6/Tc2aNau0v/T0dPL19SUOh0P+/v51XpGipKSEVFXV6Pa9GJHblvg2jTQ0NOq0kv+/nTx5krhc\\\\\\n\",\n              \"Ls2cPY8ev4ivdHP5mnV/kpGRMc2ZM+eL7s9jmM8kqe+UImrAK+Bf0U8//YR79+597WZIlPz8fBw5\\\\\\n\",\n              \"cgQ7d+7Cq1dxKCkpAY/Hw5AhQzBp0iSYfMG1oNp8+PABRkZGeJcm3kSSsrIy8DSUUFJSAmnpyicS\\\\\\n\",\n              \"4uLisGDBAty+fRvLly+Hp6cnZGRkvqiNTZs2hY3tTwgIPCTSvW9Lf1mArMw07N2z54vqq86rV6+w\\\\\\n\",\n              \"ZcsW7N+/H81MzcDhcJCfX4CYhw/QrZsLpk6dAicnp3qtk/nxSFLfyQKM+aoyMjLQvHlzvE4W75Qf\\\\\\n\",\n              \"EUFLVR5FRUVo1KhRtdvdunULc+bMQV5eHlavXo0ePXqI3caOHTsiIyMDffsPwi9Ll9cYYgF/7cD6\\\\\\n\",\n              \"tatw69Yt8Pni3b8lqoKCAty9excfP36EkpISWrVqBV1d3Qapi/nxSFLfWf3/+QzzH9DQ0EBBQQE+\\\\\\n\",\n              \"fvwo1r1lKcnJUFNTqzG8AMDe3h5RUVEIDg6Gr68vjI2NsXr1arRu3brWOrKzsxEaGorU1FR07doV\\\\\\n\",\n              \"Z0+fxKNHD+E7fRY6OXWuEGT370Vjx9ZNiL57G+Hh4Q0WXgCgpKQEZ2fnBts/w0gKNomD+apkZWXR\\\\\\n\",\n              
\"t28/HDl0QKxy+wP3wM3NTaRtpaSkMHDgQDx9+hT9+/dHz549MXr0aLx9+7bK7R8/foyxY8eiSZMm\\\\\\n\",\n              \"OHzkKBw6dkJ+YREAKTx/+hTeXqNh26o5PNwHYeTwIejQ1gbjRg1Ha+tWiI6OhqmpqVjfhWGYL8NG\\\\\\n\",\n              \"YMxXN3nyJEyaNBnjvX1Euk5VUlKCvQG7cCEkRKx6ZGVlMWXKFHh6epaPwry9vfHzzz+Xj/6Cg4Mx\\\\\\n\",\n              \"YcIETJs5G4+fx4HH45WXJyJcj4zAqpX+yPv0CW5DBkNBQQF8Ph/29vZffI2NYZgvw66BMV8dEaFr\\\\\\n\",\n              \"t25o3qIlfl+zrsZrTEKhEJMneqGoIB8nTtRtpYt3795hyZIlOHfuHBYuXIjmzZvD09MTp86EwLaG\\\\\\n\",\n              \"tRUFAgGmTpqId+/eIOT8ecjKytapHQzzLZGkvpOdQmS+OikpKZw4fhx3bkVhis/4aleqT0lOxljP\\\\\\n\",\n              \"4XiTlIjAQNFXD6mOgYEBdu/ejUuXLiEkJARubm7YtnN3jeEF/L2KyKat25GXl18vy0UxDPNlWIAx\\\\\\n\",\n              \"3wRNTU1ERERAUV4WdtYtMH7MSAQdPYyLF87j6OGD8PQYCvs21mhsqI/wsLB6XSapVatW8PPzg46O\\\\\\n\",\n              \"Lnr17iNSmUaNGmH6zNnYunVrvbWDYRjxsFOIzDcnJycHe/bswa1bt/Hp0yeoqanByakTRo0aJdai\\\\\\n\",\n              \"uuIYOXIkbH9qh8lTfUUuU1ZWBvOmRrh8+TKaN2/eIO1imP+aJPWdbBIH883R1NTErFmz/tM6X8XH\\\\\\n\",\n              \"Y9yESWKVadSoESxbtkJCQgILMIb5CtgpRIYBICgrq/WesqrIysqirKysAVrEMExtWIAxDABtHR0k\\\\\\n\",\n              \"Jb0WqwwR4fXrROiIuAgxwzD1iwUYwwAY5u6OwL0BYpW5Fx2NwoICtGnTpoFaxTBMTViAMQyAoUOH\\\\\\n\",\n              \"Iubh/xD78qXIZbZv3Qwfn0lVLibMMEzDY//nMQwABQUFLFiwECOHu+HDhw+1bn9wfyBuREVi/Hiv\\\\\\n\",\n              \"/6B1DMNUhQUYw/yf6dOnoVu3bujq7Ih70dFVbvPp0yesWumPJb8sQEhICLS0tP7jVjIM8xmbRs8w\\\\\\n\",\n              \"/0dKSgp//LEW5jvN4enhBi6Xh+EjPKGnr4+ioiLcvnUTx44cgpOzM27cuAEjI6Ov3WSG+aGxAGOY\\\\\\n\",\n              \"f5CSksLEid4YP94LFy9exMmTpxAZcQXy8vKwaGGBR48ewcDA4Gs3k2EYsABjmCrJyMigT58+6NNH\\\\\\n\",\n              \"tKWlGIb577FrYAzDMIxEYgHGMAzDSCQWYAzDMIxEYgHGMAzDSCQWYAzDMIxEYgHGMAzDSCQWYAzD\\\\\\n\",\n              \"MIxEYgHGMAzDSCQWYAzDMIxEYgHGMAzDSCQpIqKv3YiGwOVyYWxs/LWbwTAMI1Fev36NzMzMr90M\\\\\\n\",\n              \"kXy3AcYwDMN839gpRIZhGEYisQBjGIZhJBILMIZhGEYisQBjGIZhJBILMIZhGEYisQBjGIZhJBIL\\\\\\n\",\n              
\"MIZhGEYisQBjGIZhJBILMIZhGEYisQBjGIZhJBILMIZhGEYisQBjGIZhJBILMIZhGEYisQBjGIZh\\\\\\n\",\n              \"JBILMIZhGEYisQBjGIZhJBILMIZhGEYisQBjGIZhJBILMIZhGEYisQBjGIZhJBILMIZhGEYisQBj\\\\\\n\",\n              \"GIZhJBILMIZhGEYisQBjGIZhJBILMIZhGEYisQBjGIZhJBILMIZhGEYisQBjGIZhJBILMIZhGEYi\\\\\\n\",\n              \"sQBjGIZhJBILMIZhGEYisQBjGIZhJBILMIZhGEYisQBjGIZhJNL/A3js3VliGNxAAAAAAElFTkSu\\\\\\n\",\n              \"QmCC\\\\\\n\",\n              \"\\\"\\n\",\n              \"  frames[6] = \\\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAbAAAAEgCAYAAADVKCZpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\\\\\\n\",\n              \"AAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0\\\\\\n\",\n              \"dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd1RUx+P38TdNQClKEbAbsaOAWFGx\\\\\\n\",\n              \"VxAFG3bFir0lscQS9WussdcEK2osQaUpYonYFRUVS2yxgIAU6XXZ+/yRR34hoOwaEVfndc4eyr1z\\\\\\n\",\n              \"7+yi97MzOzNXTZIkCUEQBEFQMerFXQFBEARB+BAiwARBEASVJAJMEARBUEkiwARBEASVJAJMEARB\\\\\\n\",\n              \"UEkiwARBEASVJAJMEARBUEkiwARBEASVJAJMEARBUEkiwARBEASVJAJMEARBUEkiwARBEASVJAJM\\\\\\n\",\n              \"EARBUEkiwARBEASVJAJMEARBUEkiwARBEASVJAJMEARBUEkiwARBEASVJAJMEARBUEkiwARBEASV\\\\\\n\",\n              \"JAJMEARBUEkiwARBEASVJAJMEARBUEkiwARBEASVJAJMEARBUEkiwARBEASVJAJMEARBUEkiwARB\\\\\\n\",\n              \"EASVJAJMEARBUEkiwARBEASVJAJMEARBUEkiwARBEASVJAJMEARBUEkiwARBEASVpFncFSgqJiYm\\\\\\n\",\n              \"VKlSpbirIQiCoFKePXtGbGxscVdDIV9sgFWpUoWQkJDiroYgCIJKadiwYXFXQWGiC1EQBEFQSSLA\\\\\\n\",\n              \"BEEQBJUkAkwQBEFQSSLABEEQBJUkAkwQBEFQSSLABEEQBJUkAkwQBEFQSSLABEEQBJX0xU5kFgRB\\\\\\n\",\n              \"+FTS09NzV68wNTVFR0enmGv0dRAtMEEQhA8gSRIXLlygj1s/jE1MsW3UBJuGjTEyMaH/wEFcvXoV\\\\\\n\",\n              \"SZKKu5pfNNECEwRBUFJaWhp9+/XneugdbLq6MdnrLLr6hgCkJsYTGvg73Vx709K+GV67dogWWRER\\\\\\n\",\n              \"LTBBEAQlZGdn09WpG6/SYeRmX5r1HJYbXgClDI1o3mcko7f48/h1Ii49e5GTk1OMNf5yiQATBEFQ\\\\\\n\",\n              
\"wpKlS4lJl9Nt2hI0tUq8cz8tbR16zPiZp1HxrF279hPW8OshAkwQBEFB2dnZrN+4idbDpqGhUfgn\\\\\\n\",\n              \"MBqaWrQaMoVVa9cjl8s/QQ2/LiLABEEQFOTr64th2fKYf1NL4TIVatugrl2SoKCgIqzZ10kEmCAI\\\\\\n\",\n              \"goIuX7lCpQYtlSqjpqZGFbuWXLlypYhq9fUSASYIgqCgpOQUtHVLKl1OS7cUyckpRVCjr5sIMEEQ\\\\\\n\",\n              \"BAWVKW1IekqS0uUykxMpXdqw8B0FpYgAEwRBUFD7du14fDFIqQnKcrmch5eCaNu2bRHW7OskAkwQ\\\\\\n\",\n              \"BEFBbdu2RRMZL+5eV7jMk+sXMCljSNOmTYuwZl8nEWCCIAgKUlNTY8jAgfiumUtWelqh+2ekJnNm\\\\\\n\",\n              \"2zK+nz4NNTW1T1DDr4sIMEEQBAX5+Piwfv06qlUw47c5I0lLevPOfVPexLFv9nCcOrZj4MCBn7CW\\\\\\n\",\n              \"Xw8RYIIgCIWQy+XMmzePcePG4evry/mzZ3Hu4MAG9474rf6ByMf3yM7MICsjnVcPw/BfM4eNIzrR\\\\\\n\",\n              \"38WJjevXidZXEVGTvtDlkhs2bEhISEhxV0MQBBX35s0bBg4cSEpKCgcOHMDMzCx3W0REBA6tWpGc\\\\\\n\",\n              \"mkZifDwA5uUrMMJ9KCNHjMDc3Ly4qv3BVOnaKVajFwRBeIc7d+7g4uJCt27dWLZsGVpaWnm2lypV\\\\\\n\",\n              \"ipjXr4mMjKRUqVLFVMuvl+hCFARBKMBvv/1G27Zt+fHHH1m1alW+8AIICAigdevWIryKiWiBCYIg\\\\\\n\",\n              \"/INMJuP777/n8OHDBAUFYWNj8859jx49Svfu3T9h7YR/EgEmCILw/71+/Zq+ffuira1NSEgIRkZG\\\\\\n\",\n              \"79w3MzOTwMBAcauUYiS6EAVBEIBr167RqFEjmjdvjr+//3vDC+DMmTNYWVnlGdQhfFqiBSYIwlfP\\\\\\n\",\n              \"09OTmTNnsmXLFlxcXBQqI7oPi58IMEEQvlqZmZlMmjSJs2fPEhwcTK1ait3nSy6X4+Pjw5kzZ4q4\\\\\\n\",\n              \"hsL7iAATBOGrFBERQa9evShXrhxXr15FX19f4bIhISEYGBhQo0aNIqyhUBjxGZggCF+d4OBgGjVq\\\\\\n\",\n              \"hLOzM4cOHVIqvEB0H34uRIAJgvDVkCSJtWvX0qdPH3bs2MHMmTM/aJmno0eP0qNHjyKooaAM0YUo\\\\\\n\",\n              \"CILKkslk+Pn58euOXbyKjEJdXZ1qVavgMWoErVq1yhNOaWlpjBo1irt373Lp0iWqVq36Qed88uQJ\\\\\\n\",\n              \"cXFxNG7c+CM9C+FDiRaYIAgqaceOnZSrWJlxM38k3NAK3ZZDKGE/kPuSOX2GjKRajdoEBQUB8PTp\\\\\\n\",\n              \"U+zt7VFTU+PChQsfHF7wd+urW7duqKuLy2dxEy0wQRBUzqLFP/Hzuk1YDVtM6Sp18mwzqWlH5dZ9\\\\\\n\",\n              \"eH33Er3c+uMxcgTbt2/jhx9+YPz48Up1Gaanp3PgwAFOnz1HUnIyhgYGXL5wjiVLlnzspyR8ABFg\\\\\\n\",\n              
\"giColIMHD7Jy7UYaTduKjqFJgfuoqalhZmWP7vi1rFw2gg1rVjFq1CiFz5GZmcnsOXP55VdPylSp\\\\\\n\",\n              \"jUEtezQNKpOdlszrLC2GDR9JyI0bzJszp8A1EoVPQ9xORRAElSFJEjXr1sOo42jK1mmqUJm/Tu+n\\\\\\n\",\n              \"auYzfA4fUmj/1NRU2nfuyqt0DSxdJlDKtEK+fVKinvPIezU1yuoR4HsUbW1tpZ7H50yVrp2iE1cQ\\\\\\n\",\n              \"BJVx8eJF4pNSMa2l+ACKCs0cOXUqiFevXhW6ryRJ9HLrz2vJgPojfiowvAD0zCtjPXo5TxNzGDTU\\\\\\n\",\n              \"XeG6CB+XCDBBEFTGgUOHMG3YCTUlBlBo6ephUb8lvr6+he577do1roTcoPaAWYWeQ11DkzqD5hIY\\\\\\n\",\n              \"dJJ79+4pXB/h4xEBJgiCyoiKjkHbsKzS5dT1TYiKiip0v1Vr11OuuSvqGooND9AooUO5Zs6sXb9B\\\\\\n\",\n              \"6ToJ/50YxCEIgsrQ0dFGnpWtdLnsjDQWLlzIihUrMDMzy32Ym5vnfm9iYoL3oYO0Xuyj1LHLN3dm\\\\\\n\",\n              \"75IhbN4oQuxTEwEmCILKsKpTmwtHgwFXpcplRT7i8OHDODg4EB0dnfuIiooiOjqaGzdu8PLlS+So\\\\\\n\",\n              \"oa1XWqlj65YxIy0lmaysLEqUKKFUWeG/EQEmCILKGDpkCPMXLMLSJYESCgZN4os/kSXF0KVLFzQ1\\\\\\n\",\n              \"NTE0NCxwEd7ExETMyxU8aOO9JAkJCQ0NDeXLCv+JCDBB+Azl5ORw6tQp/vrrL3JycjAzM6Nz586U\\\\\\n\",\n              \"KlWquKtWrExNTXFycuLWqX3U6O5R6P6SJPHQbysaahIhISE0bfruoff6+vpoamqQFhtBSZPyCtcp\\\\\\n\",\n              \"OfIpJmXNRYAVAzGIQxA+I0lJSSxevJjKVaoy6dsZHD55Hr+zV1iyej0VKlZi0qTJvHjxorirWax+\\\\\\n\",\n              \"Xr6UhNAgXp4/8t79JEni0ZENGEspLFywAFdXV0aNGkVcXFye/RISEvD09KRdu3bk5OTw/Kxi88Xe\\\\\\n\",\n              \"enX+MKNHjlD6eQj/nWiBCcJnIiIigg4dO2FauTqTlv1CtTr182x//SqcEwd30rBRY3x9jtKkSZNi\\\\\\n\",\n              \"qumnd+/ePX7x9OTPR0/IycmheXN7zh73JDrsEpadh+RZTkqSy4l9cI2IP/ZhpJlN0MlATE1N6dWr\\\\\\n\",\n              \"F3PnzqVOnTr8+OOPmJubs3fvXk6cOEG7du2YNGkS1atXp2mLVlTrPJQSpQwLrVdmUhyvrp1gzO7V\\\\\\n\",\n              \"Rfn0hXcQASYIn4HExEQ6dOyEbVsnXIdPLHC9vrLlKjBw0mxq2TbGsVs3zp09S+3atYuhtp/OzZs3\\\\\\n\",\n              \"GTtxMvfvP6Baqx6UsXRATUODV68jkLR0SHh0nTvP76BtaPz3pGNJTkLEE4wN9Zk9aQKDBw9GV1cX\\\\\\n\",\n              \"AAMDA9zc3IiMjGTChAno6uoyefJktmzZQpkyZYC/W231repwZc0Emk3dhKbOu7tss1KTuLV5OlOm\\\\\\n\",\n              
\"TKJcuXKf5PUQ8hIBJgifgdVr1mBWtdY7w+ufGjp0IGroBKZ9+x0BfoVPzlVVZ86coUfP3lj3nUjP\\\\\\n\",\n              \"sWvQ0My75qCV42Ai717j0pY59HLsSOtWDmhoaFC5cmVsbW1zX8dHjx7h5eWFl5cXJUqUYNCgQSxZ\\\\\\n\",\n              \"soQzZ84wa9YsEhISWLhwIaVKlWLcuHGkJifRqbkdp38eTZWuIylbr3meeWHyHBlRoWd5FrCVAb16\\\\\\n\",\n              \"sGD+/E/5sgj/INZCFIRiJpPJqFipMt+u3kmVmnUVKpOZnoZHl0aE3rxBlSpViraCxeDRo0c0ampP\\\\\\n\",\n              \"iwnLKFf3/ctGpcRFETh/CFs3rKFnz54AxMTEsH//fry8vHj27Blubm4MGjSIBg0a5HmDEBcXx8yZ\\\\\\n\",\n              \"M/H19cXc3BxTU1N+//139PT02L9/P0tXruLZi5eY1rVHrURJctKTibwVjKGBAc2bNKRv3744Ozt/\\\\\\n\",\n              \"UcPnVenaKQZxCEIxCwwMxNi8vMLhBaCtWxIHx554bttWhDUrPkuWLad6uz6FhheAnrE5TUbOY8YP\\\\\\n\",\n              \"c/ntt9/o1q0b1atX5/Lly8yfP5/w8HBWr16NnZ1dvtatsbExixYtokyZMrx48YKsrCxevnyJmpoa\\\\\\n\",\n              \"bm5u3Lx2hdPH/WltaURy2Bkib56hgk1LKjq48pdaWb5buIJyFSrxw5y5pKenF9XLIbyD6EIUhGL2\\\\\\n\",\n              \"5MkTqtS0Urpc5ZpWPL5/tQhqVLwSExM5cOAAzssOK1ymfL1mBG+aw88//8yECRPYu3cv+vr6hZZ7\\\\\\n\",\n              \"9OgRXbp0oX///sydO5fNmzfTqlUr3N3dmTNnDnp6ehzx8cXb7zgNBnxLJbtW+ZaZin/5mIO/byTg\\\\\\n\",\n              \"eGtOnTie+3maUPREC0wQiplMJkP9A+YQaWpqki3LKYIaFa9jx45RrrYdpYwUX/NQTU2Nup36Y2vX\\\\\\n\",\n              \"kEGDBikUXpcvX8bBwYHvv/+eBQsWoKmpyfjx4wkLCyMyMpI6deowcuQoNnvupPOPu6nSuF2BayQa\\\\\\n\",\n              \"VbTEYdJKcspa4ujcA5lMptTzFT6cCDBBKGZmZmbERoUrXe51xEsszJVf2PZzFx0dja6xhdLl9MqW\\\\\\n\",\n              \"IzLqtUL7Hj16lG7duvHrr78ycuTIPNvMzMzYtWsXv/zyCzt376bV1NWULF3wjTPfUlNTo9Hg7wmP\\\\\\n\",\n              \"TcLHR7m1FIUPJ7oQBaGYOTo6MnbceN7EvqaMiWKBJJfLOet3gIP79hRx7T49LS0tJLnyrRi5TEZQ\\\\\\n\",\n              \"0AksLS0xNzd/5yMwMJD169cTEBBAo0aN3nm8iIgIqlg3pXT5bxQ6v7q6BpYd+/Pz2vW4uiq3VqPw\\\\\\n\",\n              \"YUSACUIxK126NL169eLk7170Hj1VoTKhF//AUE/vvUsjqapq1arx5tlWpcslPH/A+LFjGDVyJFFR\\\\\\n\",\n              \"UXkely5d4tWrV9y8eZPY2FjU1NTo0qUL5ubmWFhYFBh0q9dvomqHIUrVoWqT9hzctYTw8HAqVPiA\\\\\\n\",\n              
\"dRUFpYgAE4TPwIzvv6OpvT21GjShXqPm79339atwNsydjJaGGufPn6dly5afqJafRvv27clOjiPm\\\\\\n\",\n              \"6V1Mv1FsZKYsM53HwT54r7lO1apVqV69ep7tWVlZuLu7U7VqVcLCwihTpgxxcXH5gi4iIoLr168T\\\\\\n\",\n              \"FRXFo4d/UnuoYq2vtzS0SlDGvCIREREiwD4BEWCC8BmoXr06hw4coFfvPvQd9z2tuvVGSyvv3CJJ\\\\\\n\",\n              \"krh16SxbFk5n4Y/zqFq1Kn379sXd3Z158+ahpaX1jqOrFg0NDcZ5jMHLbwcmE5YVOrEb4M9Tv9O4\\\\\\n\",\n              \"SROqVq2ab1tiYiKurq4YGBhw6tQpSpYsCUDZsmUpW7Ys9evXz1cGoELlqqDAuQvyhU6v/eyIABOE\\\\\\n\",\n              \"z0SbNm04GXSCKVOnc2DTclo796Fyjbqoq2sQHf6cs777KamjzZaNG+jRowfw91JLw4YNo3nz5uzZ\\\\\\n\",\n              \"sydfy0NVTZo4kb379hP6+yZsenq8N8Re3DjLPV9PLp47m29beHg4Xbt2pWXLlqxdu1apFeMtypUn\\\\\\n\",\n              \"IeIphuaVFC6TI8vmTdRLypdXfDV74cOJUYiC8BmxtbXljzOnOB98liqldXh65SQPzvmjmxGH145t\\\\\\n\",\n              \"3Au7kxte8PeIOX9/fwYPHoy9vT3btm37It796+vrcyookJiQQAKXjOX1o9v5nlfS63BCvFZyzXMB\\\\\\n\",\n              \"x/x88q0LGRYWhr29PQMHDmT9+vVK3+5k9PChPD3zu1Jlnl09iVW9elSsWFGpcsKHES0wQfgM1apV\\\\\\n\",\n              \"ixUrliu0r5qaGuPHj6dNmzb079+fgIAAtmzZgrGxcRHXsmiZmpqSnZFOZtJ9rm2eibquPmUq1UBN\\\\\\n\",\n              \"XYPU2FfEPL3P0KGDObLuKpUq5W0lnTlzhr59+7J69Wr69++v9LmTkpJ48OABf4VepEHkcwwtKhda\\\\\\n\",\n              \"RpLLeXxiLyt/nKX0+YQPI1pggvCFqFu3LlevXqVy5crY2Nhw6tSp4q7Sf7JhwwZiY2M5euQIL589\\\\\\n\",\n              \"ZeemNUxy68pY13asmDOdyIiXrFm1Kl947du3j759+7J//36lw0smk7F161Zq1qxJbGwsc2bP5o+f\\\\\\n\",\n              \"J5GeFP/ecpIkEeK1AjMDXbp37670cxU+jGiBCcIXRFtbm5UrV9K5c2eGDBlC//79WbhwIdra2sVd\\\\\\n\",\n              \"NaUkJiYya9YsOnToQPPmf4/KbNu27XvLSJLEihUrWLduHadOnaJevXpKnTMoKIipU6dibGyMn58f\\\\\\n\",\n              \"dnZ2AGRmZfLLvMHYDfqWCjYtUFfP2xWZ8Oovbh/ahHZqNMeCAr+YwTQqQfpC2dnZFXcVBKFYxcTE\\\\\\n\",\n              \"SN27d5dsbGyke/fuFXd1lDJy5EhJW1tbioiIUGh/mUwmjR8/XrKyspJevnyp1Lnu3r0rde3aVbK0\\\\\\n\",\n              \"tJQOHz4syeXyfPscPHhQsrJpIJmUqyjZ9hghNR38ndTQbaL0jW1zqYyxqfTt9zOklJQUpc77uVKl\\\\\\n\",\n              
\"a6foQhSEL5SJiQmHDx9m7NixODg4sGnTJpUY4PH06VN27NjB9OnTFbpRZHp6Or169eLu3bucP39e\\\\\\n\",\n              \"4flXMTExjBs3jlatWtG+fXvu3r1Ljx49Chzx2KtXL+7cvE7Akd/pVr8ctgbptK2sy6JvxxEZ8ZJl\\\\\\n\",\n              \"S36iVKl33/xSKBrifmCC8BX4888/GTBgABYWFnh6elK27Oe7hmKLFi148OABERERhXZ9xsbG4uzs\\\\\\n\",\n              \"TNWqVdm2bZtCXaWZmZmsW7eOpUuX5q5Cr+oDXj4mVbp2ihaYIHwFatasycWLF7GyssLGxobjx48X\\\\\\n\",\n              \"d5UKdPr0aa5cucIvv/xSaBg9ffqU5s2b06pVK3bv3l3o/pIkcejQIWrXrs25c+c4f/48a9asEeGl\\\\\\n\",\n              \"yoq3B7PoqFI/riB8SmfOnJEqVqwoTZw4UUpLSyvu6uTKycmRypcvL9na2ha677Vr1yQLCwtp/fr1\\\\\\n\",\n              \"Ch37ypUrUvPmzSVra2vp5MmT/7WqXzRVunaKFpggfGVat27NrVu3iIqKonHjxty+nX+ScHFYt24d\\\\\\n\",\n              \"0dHR/Pbbb+/dLyAggC5durBx40bGjRv33n1fvnzJwIED6dGjB+7u7ly/fp127dp9zGoLxUgEmCB8\\\\\\n\",\n              \"hRISEqhUuTIRkVHY2NigqaVFpSpVmf/jj7x69eqT1yc1NZVZs2bh5uZGjRo13rnfr7/+iru7Oz4+\\\\\\n\",\n              \"PnlWJPm3lJQU5syZg42NDd988w0PHz7E3d1d6dU4hM+bCDBB+IrIZDI8xo7D1q4h914l8P2m/Wy7\\\\\\n\",\n              \"9ATP8w8ZvXgzF8IeU6tOHebMmftJW2UTJ05EkiQ2btxY4HZJkpg3bx4//fQTwcHBNGvWrMD9cnJy\\\\\\n\",\n              \"8PT0pEaNGjx79ozQ0FAWLFiAnp5eUVZfKCZiIrMgfCXkcjn9+g/gcUQ0yw6fo6SeQZ7tlWvWZciM\\\\\\n\",\n              \"n+gxahprp7kT/+YN69etVWg1+P/i+fPn7Ny5kxUrVqCvr59ve3Z2NqNGjSIsLIyLFy9iZmZW4HFO\\\\\\n\",\n              \"nTrFtGnT0NfX5+jRo++9WaUqkMlkaGhoFPnrr8pEC0wQvhKrV6/m7uNnTFj+a77w+idDIxOmrfPC\\\\\\n\",\n              \"PzCo0M+jPgY3NzcsLCyYOHFivm3Jyck4OTkRExPDH3/8UWB4/fnnnzg7OzNq1CjmzJlDcHCwSoaX\\\\\\n\",\n              \"JElcuXKFAYMGoW9giLa2NlpaWtSqY8W6detISkoq7ip+dkSACcJXICcnh59Xr8FtylxKaOsUun9J\\\\\\n\",\n              \"PQN6jpvBgkWLi7Rep0+f5urVq+zZswd19byXo8jISFq1akXlypU5cuRIvonCcXFxTJw4kRYtWuDg\\\\\\n\",\n              \"4MC9e/fo2bOnSrZY4uPjadOuHS69+5JjVIlFB07z66WnbA5+QPcJP7DP7wQVK1dm7969xV3Vz4oI\\\\\\n\",\n              \"MEH4Chw/fpxSpY34po61wmWs7dvwKjKSiRMnkpOT89HrJEkSAwcOzA2gf7p//z7NmjXD1dWVLVu2\\\\\\n\",\n              
\"oKn5f592ZGVlsWrVKmrVqoVcLufevXtMnz5d5dZ7fCshIYEWDq0oVb46iw7+QeeBozE0NkVNTQ1N\\\\\\n\",\n              \"rRLUbmjPqEUbmL5xP1Omf8evnp7FXeXPhggwQfgKnDp1GmuHTkqVUdfQoHmXHvj6+tK+fXtevnz5\\\\\\n\",\n              \"Ueu0atUqYmJi2L9/f57fnz9/ntatWzN//nx++OGH3BaVJEkcPnyYunXrcvLkSYKDg1m/fj2mpqYf\\\\\\n\",\n              \"tV6f2ohRo6lo1ZBeE2bla4X+U0XLWkxeu5vvvp/B3bt3P2ENP18iwAThK5CQlEQpfUOly+nql6Zf\\\\\\n\",\n              \"v3506NCBhg0bcvDgwY9Sn/T0dGbPno2Hhwfm5ua5vz906BCurq7s3r2boUOH5v7++vXrtG7dmrlz\\\\\\n\",\n              \"57Jhwwb8/f3z3cBSFYWHhxN04gQuHt8p1PVpUbkarVwHsnbd+k9Qu8+fCDBB+AoY6OmRkZaqdLnM\\\\\\n\",\n              \"tBQMDQ2ZNWsWfn5+zJ49m2HDhpGcnPyf6jN69Gi0tLRYuXJl7u9Wr17N5MmTOXHiBB07dgQgIiKC\\\\\\n\",\n              \"oUOH4uTkxMCBA7l582buti/B5i1baNq5BzolFV8I2KF7P/b9tk8M6kAEmCB8Feztm3H38hmlykiS\\\\\\n\",\n              \"xK3zJ2natCkAjRo14saNG2hqamJra8vly5c/qC7Pnj1jz549rF+/Hi0tLeRyOVOnTmXr1q1cuHAB\\\\\\n\",\n              \"GxsbUlNTmT9/PvXr16dcuXL8+eefjBw5Ms9nYV+Ck6fPYO2gXCCXKWtOxW9qcOPGjSKqleoQASYI\\\\\\n\",\n              \"X4EePXoQ9ewJEU8fKlzm/vVLvImNYdasWezfv5/s7Gz09PT45ZdfWLZsGd27d2fhwoXIZLJ8ZTMy\\\\\\n\",\n              \"MvDy8qJV2w7UqlufetZ29HHrz/nz5+nZsyfVqlVj8ODBZGRk0K9fP0JCQjh//jwVK1Zkx44d1KxZ\\\\\\n\",\n              \"k4cPH3Ljxg0WL16MgcG7h/2rmqysLMLDwwkJCSE6KpqS+so/t5L6hqIFhpjILAhfhRIlSjBmzGgO\\\\\\n\",\n              \"rf+JCct/Rb2QJZWyMjM4vGkZS39ajIWFBWvXrmXatGmMGTOGUaNG4erqSpMmTRgyZAiBgYHs3r2b\\\\\\n\",\n              \"qlWrIkkSy1esZNGi/6Gma0q6TlXUStSALDkPL4TjF9CD9NQktntuJT4+nh49emBubs6JEye4fPky\\\\\\n\",\n              \"U6dORUdHh0OHDuW2/FRBTk4OsbGxREdHExUVlefx798lJSVRtmxZzMzMSEtL+6Cu3fTUlAInfX9t\\\\\\n\",\n              \"xP3ABOErkZWVRcfOXZCXLI37D8vR1CpR4H4ZaamsmT6C+FfP+fP+vdxlmG7fvs26des4dOgQzs7O\\\\\\n\",\n              \"TJgwgQYNGrBq1SqWLFnCzz//zIWLl9lz4AhZZdugrlM637ElSUKe+BeaMecxMSpNr169GD16NDNm\\\\\\n\",\n              \"zODmzZssXbqUPn36fBZzuSRJIiEh4b1h9Pbn2NhYSpcujbm5ee7DzMyswJ+NjY1RV1dHkiQGDR5M\\\\\\n\",\n              
\"TI42/abMVbheSW/imN27Nc//+gsjI6OP/rxV6dopAkwQviJpaWm49R/A9ZuhtHYdSMtufdAzLANA\\\\\\n\",\n              \"Quxrzh7dR/CRvXTu2JEcWTYRERH4+vrmWUswPj4eT09PNmzYQLly5ZgwYQKWlpZ06+ZMXHIWGtVc\\\\\\n\",\n              \"UdN4/5wseWoUvDiGW99e+Pv78+233zJp0iR0dAqfZP1fpaSkFBhG//5ddHQ0urq67w2jtw9TU1O0\\\\\\n\",\n              \"tLQUOn9sbCxeXl54enqSmJjIm8QkVvpdpYSCz/3Yrk1oJ0exa8f2//IyvJMqXTtFgAnCV+btkkVr\\\\\\n\",\n              \"12/gyOHD6JQsCZJEdlYWbm5ujB83Fmtra3Jychg9ejT3798nICAAQ8O8w/BlMhm+vr6sW7eOBw8e\\\\\\n\",\n              \"kJCYQk6FLqiXVGxelizyCtXKpBF89sx/vkN0ZmbmO1tH//5ZLpfnCZ93BZSZmRm6urr/qV5v5eTk\\\\\\n\",\n              \"EBQUhKenJ0FBQXTr1o3hw4fj4OCAo7Mz2ubV6DF6WqHHiYuK4KcRLgQG+GNnZ/dR6vZvqnTtFAEm\\\\\\n\",\n              \"CF+xzMxM4uPjUVdXx8jIKF8rQi6XM378eEJCQggMDKRMmTIFHmfNmjV8N2cJ6tV6KnxuKTsV9b8O\\\\\\n\",\n              \"ERUZUeAgDZlMRkxMjEKtpdTU1NwQKqy1pKen98m6KJ8+fcr27dvZsWMH5ubmDB8+nH79+uV5MxAd\\\\\\n\",\n              \"HU3jps1o4tiHLoM93lm3mIgXrJ0ylEnjPZg+rfCw+1CqdO0UgzgE4Sumra2NhYXFO7erq6uzYcMG\\\\\\n\",\n              \"pk6dSrt27Thx4gQmJib59jt34TJyw5pKDWtW0yqFWsmyTJgwATMzs3zhFB8fj5GRUb7WUqVKlWjc\\\\\\n\",\n              \"uHGegCpTpsx7V7H4lNLT0/H29sbT05M7d+4wYMAA/P39qV+/foH7m5mZcfH8Obo6dSP07HFaugyk\\\\\\n\",\n              \"SQfn3C7Fl48fcNZ7N9eC/Fi0cAETJkz4lE/nsyYCTBCE91JTU+Pnn39m1qxZtGnThpMnT+ZbFT4q\\\\\\n\",\n              \"Oho1LcUn476VTQnCw8OpU6cOdevWzdNaMjExUZl5X5IkcePGDTw9Pdm/fz+NGjXCw8MDZ2dnhdZo\\\\\\n\",\n              \"LF++PDdCruHv78/S5SvZtWQWuiVLIZNlY2BgyJjRo9i15i7lypX7BM9GdajGvw5BEIqVmpoaixcv\\\\\\n\",\n              \"Rltbm9atW3Pq1Kk8F1MdbR2QlF/wV56dxdmzZ3n06BEWFhZYWFhQrly5Ar83NTX97O6oHBcXx549\\\\\\n\",\n              \"e9i2bRuJiYkMGzaMmzdvUqlSJaWOExERwZatW9m69Rc0tXX4pmZd0lKSSEtJYsjgQQwbOlSEVwFE\\\\\\n\",\n              \"gAmCoBA1NTXmz5+PtrY2rVq14vTp05iamnLmzBniY18jpWZBaUuFjydJErpqKfifPk2lSpWIjIzk\\\\\\n\",\n              \"1atXREZGEhkZyYULF3K/j0C+LkEAACAASURBVIyMJCEhAVNT00KDzszMrEhbbnK5nFOnTuHp6cnx\\\\\\n\",\n              
\"48fp2rUrK1eupE2bNh/UjXn06FGGDnOnSUdnJq/ZTYVqNXO3RT57wpnDXtjYNmD9urUMGDDgYz4V\\\\\\n\",\n              \"lScCTBAEpQwZMoSbN29Sq1Yt1NXVsbW1pUOHdtxbtxFJbo+aumKXFXlKOKnJCRw8eJDx48fTrFmz\\\\\\n\",\n              \"9+6flZVFdHR0vqC7du1a7vevXr0iLi4OY2PjQoPO3NycEiUKngtXkGfPnrFjxw62b9+OiYkJ7u7u\\\\\\n\",\n              \"bNq06Z0DWxTh7++P+4hRTFmzq8Bb3VhUqUb/KfNo2a0vUyYPRlNTk759+37w+b40YhSiIAjvJZfL\\\\\\n\",\n              \"CQkJwc/PD39/f549e0bnzp3R0tLi5MmTnDlzhurVq9OqTTsuPsxAs6xtoceUJDnqL/wx1M5CXV2d\\\\\\n\",\n              \"tLQ0GjVqxPjx43F0dPxPXYUymYzXr1/nC7p/fh8ZGUl0dDSlS5fOF2z//NnIyIirV6+ye/dubt68\\\\\\n\",\n              \"Sb9+/Rg+fDg2NjYfXL+3UlNTqVCpEpNWbseyXoNC93/+8B7LxvblrydPimQC81uqdO0ULTBBEPJJ\\\\\\n\",\n              \"Tk4mKCgIPz8/AgICMDY2xsnJidWrV9OsWbPcLrpffvmFNm3acOTIETLTU5FFXkdNsxQaRjXeeWxJ\\\\\\n\",\n              \"kqMRfR5bq6qcCgrE29ubOXPmEBERwaxZs5g4cSIeHh4MHz68wBGPBcnMzMTb25t1m7bw9MkTZNnZ\\\\\\n\",\n              \"mJQty6D+bowcMaLAeWZvl3/6d7jdv38fb29v7t+/T0xMDJIkUbJkSSpWrMiDBw/4+eefC2zRWVhY\\\\\\n\",\n              \"5Ltr9Pvs2bOHGtaNFAovgMo16mDTvB07duxg6tSpCp/nSyZaYIIgAPDkyRP8/Pzw8/PjypUr2Nvb\\\\\\n\",\n              \"4+TkhKOjI1WrVn1nuZUrV/L999+jrq6Ovr4+mdk5oFeZbP1aeSY1S5IceeJf6KTcw6ZONfz9juau\\\\\\n\",\n              \"55ednc327dtZsGAB1atXx9DQkD/++IMePXowfvx4GjZs+M7zHzp0iFEe4zAo/w3lW/SkTJXaqGlo\\\\\\n\",\n              \"kBYXSfgFX8JDTuE+3J3VK1e8t2X35s0b9u7di6enJ3FxcQwbNoyhQ4dSqVIl4uPjC23RvXr1Knda\\\\\\n\",\n              \"wvu6LsuVK4e+vj71rG3pOmo69Zq2Uvhv9Oj2dXb/bzpPHj0ssrlsqnTtFAEmCF+p7OxsLly4kNs1\\\\\\n\",\n              \"mJCQgKOjI05OTrRv3z7P8lHvEhwcjLOzMxkZGchkMkaOHMmCBQvYvHkLa9atJ1uuiUxNB3lONmQl\\\\\\n\",\n              \"IpNls+2XzfTr16/AgRbp6els3LiRpUuX0rp1a6pWrcr+/fsxMzNj3Lhx9OnTJ89yU56e25g2czaN\\\\\\n\",\n              \"PJZj9E3dAuuYmZLAza2zsalWDu+D+/OEmFwu58yZM2zbtg1/f386d+6Mu7s77dq1U7ob8+3aiYoE\\\\\\n\",\n              \"3d/PNYNtlx6jqanYElRvzzGyZU3iYmOUau0pQ5WunSLABOErEhsby/Hjx/Hz8+PEiRNUq1YNJycn\\\\\\n\",\n              
\"nJycsLW1VWoU3datW5k2bRq6urpUq1YNfX19bt26ha+vL40bN0Ymk3H16lX27NnDo0ePWLNmDYMH\\\\\\n\",\n              \"D2bRokV06tTpvcdOSkpi1apVrF27lj59+tCkSRP27dtHaGgow4cPZ8yYMbx+/Zp2Hbtg/91WDCyq\\\\\\n\",\n              \"vPd4OdlZXFk9geF9ujF/7hxevnyZOyBDX1+f4cOHM2DAAIyNjRV+/h9KkiRiYmIoX74C2y8/Vbr8\\\\\\n\",\n              \"+A7WPPrzAaamii3ZpSxVunaKz8AE4QsmSRJhYWG5XYNhYWG0a9cOJycnVq1a9d5VON4lOzubyZMn\\\\\\n\",\n              \"89tvv2FmZka9evWQJInff/+dY8eO4eTkxOHDh2nevDn29vb89ddfvHnzhtq1azN06FB27NhRaIAZ\\\\\\n\",\n              \"GBgwb948xo0bx5IlS5g6dWpu627fvn3Y2tpSUt+QbzoPLjS8ADS0SmA1cCYrf3Ln4vlzXL9+nb59\\\\\\n\",\n              \"+3Lw4EEaNGhQ5EtLSZJEeHg4oaGh3Lp1i5s3byJJEmkpSZTUU/x+YNlZmaSmJOdbl/JrJQJMEL4w\\\\\\n\",\n              \"6enpnDlzBn9/f/z8/NDQ0MDJyYl58+bRqlUrhVaGeJe4uDh69uzJ48ePsbS0pHHjxty4cYOTJ0/m\\\\\\n\",\n              \"nsfLywsXFxcOHDhA69at0dXVJT09HQA3NzdmzZpFQkICpUvnv93Kv5mYmLBixQomT57MokWLcHR0\\\\\\n\",\n              \"ZPLkyfzxxx80atqMhi2cFa67gUUVSppXwdLSkqNHj360hXr/LSsriwcPHhAaGpr7uHXrVu6drK2t\\\\\\n\",\n              \"renVqxdvEpO4HOhD254DFT72tdMBtGjpoNTw/y+ZCDBB+AJERETkBtYff/yBra0tTk5OHD9+nFq1\\\\\\n\",\n              \"an2UFkZYWBjdunVDXV0da2trWrVqxfbt27lw4UKeMOjYsSP79++nT58+7NmzBx0dndwAMzY2zt0+\\\\\\n\",\n              \"evRohc9doUIFNm/ezPTp05k3bx4rVqzArJYdJUopdzfjSs2deR338KOF15s3b7h16xa3bt3KDasH\\\\\\n\",\n              \"Dx5QpUoVbGxssLGx4bvvvsPGxgZzc/M8ZU1NTRk9fhJtXAco/Pc56+3F/+bM/Ch1/xKIABMEFSSX\\\\\\n\",\n              \"y7l27Vpu1+CLFy/o0qUL/fv3Z+fOnf9pcm1BfHx8cHd3x9DQEAcHB9q3b8+MGTM4f/58gXOS2rRp\\\\\\n\",\n              \"g7e3N66urkyZMoWMjIzcbUOHDmXRokVKBdhblpaW7Nmzhx9++IF9Z28rXV7HoAzx4QlKl5MkiWfP\\\\\\n\",\n              \"nuVpUYWGhhIXF0f9+vWxtramWbNmeHh4YGVlRcmSJQs9Ztu2bdHT1eaY1xa6DhpT6P6nf/ciMzkB\\\\\\n\",\n              \"Jycnpev/pRIBJggqIikpiRMnTuDv709AQACmpqY4OTmxbt06mjZtWiTLJ0mSxOLFi1m3bh16enr0\\\\\\n\",\n              \"7duX9u3b4+bmxsmTJ6lcufI7y7Zo0QJfX1+6dOmSZ3BEp06dGDFiBBcvXiQ+Pp7k5GT09fVp3Lix\\\\\\n\",\n              
\"wvcFq1SpEmrSDaWfjywrg9Kl3h8uGRkZ3L17N0+r6tatW+jr6+e2qgYMGMDy5cupVq3aB6+Cr66u\\\\\\n\",\n              \"jp/PUZrZNydHlk3XQR5oFPA3lOfkcPLADk7s2cKF8+dUZoHjT0G8EoLwGXv8+HFuK+vq1as0b948\\\\\\n\",\n              \"9/OsKlWqFOm509LScHd3JywsDHV1daZMmULr1q3p0KED+/fvf+ftQf6pSZMmbNy4kUGDBrF//376\\\\\\n\",\n              \"9u3L9evXMTQyoXW7DpS1tEZdpxTyzDTi/wqjc+fOfDt1cr5lpeLj47lw4QLBwcEEBwdz+/ZtJC0d\\\\\\n\",\n              \"bHNkqGsofhl78+c1HNv838ThmJiY3KB6+/Xt53tvw6p79+5YW1srPKlaGRUrVuTK5Uv0cevHd957\\\\\\n\",\n              \"cOjRD7vWnSmpb0B6agqhwSc5e2QP5czNuHTxwnvn432NRIAJwmckOzub8+fP54ZWcnIyjo6OTJgw\\\\\\n\",\n              \"gXbt2ik0N+tjePnyJT169MDExISYmBhWr16Nvb09LVq0YP369bRp00bhYzVo0AALCwumTJnCkaNH\\\\\\n\",\n              \"8T9xCuNmvanf6yc0dfVz9zNPT+bmjUA6d3Nh3OgRWNevR3BwMOfOnePZs2c0bdoUOzu73FXps7Kz\\\\\\n\",\n              \"iQw9R3k7xeqSlZrM88sniLGqgKOjI7du3SIlJQVra2tsbGxo06YNU6ZMoU6dOnnmmhW18uXLc+Fc\\\\\\n\",\n              \"MDdv3mTd+g1sneVBSnISpfT1aW5vz9HfD9KoUaNPVh9VIgJMEIpZTEwMx44dw9/fnxMnTlC9enWc\\\\\\n\",\n              \"nJzYt28fNjY2n/xGjRcvXqRXr1506NCBgIAA9u7dS4MGDWjRogXTp0+nT58+Sh1PR0cHuVyO+/Dh\\\\\\n\",\n              \"LF+3hTpjNqBdOn9XoaauPmbNe1HayoGfN46nekUzhg4ZwuDBg3n16hULFy5k6dKlGBkZMWnSJKpX\\\\\\n\",\n              \"r87MRcsxq9cMzRKFB849n63olNSldOnSjBw5EhsbGypXroyamhpxcXFs376diVOnk/DmDbolS9Kw\\\\\\n\",\n              \"gS3jx3pQt27BE6Q/NltbW7Z5/vpJzvWlEBOZBeETkySJO3fu5Lay7t69S/v27XFycqJLly75Rqt9\\\\\\n\",\n              \"Stu3b+f777/Hzc2NgwcP4uvrS506dWjfvj0tWrRg2bJlSh8zJiaGmjVrkpmdQ3WPzegYFX5fq8w3\\\\\\n\",\n              \"kfy5cTQD3Ppw8OBBkpOTadiwIYsXL6Zdu3bA3wNZnF1cufU8lsZjl6GpXfDIQkmS+PPYLl6dPUA5\\\\\\n\",\n              \"M1MqVarEtm3bMDU1JTMzk4mTp7J37x5qNW1LTYeulDI0IjsznWehlwk9foA6dWqza/s2vvnmG6Wf\\\\\\n\",\n              \"uypSpWunCDBB+ATS09M5ffp07lB3LS2t3BUwHBwc/tPcLEVlZGRw8eJF4uLi0NHRoXbt2lha/n3/\\\\\\n\",\n              \"LplMxnfffYePjw9du3bFz8+PwMBAqlatSs+ePdHX12fXrl0f1Bp88+YNpmXNMLVpS+Weig8Bf7z/\\\\\\n\",\n              
\"fyTfP8fIEcOZO3cusbGxeT6rCg0NJTs7G52SeqRkyqjWaQBV7LtRotTf3ZLyHBmRt87z8uwB0qNf\\\\\\n\",\n              \"IMtMx9nZGV1dXXx8fNiyZQvLVqwkUdKh84QfKWWYfzRljiybaz5ehBzextkzp6ldu7bSz1/VqNK1\\\\\\n\",\n              \"U3QhFoOUlBTi4+PR1tbG2NhYjCr6QoWHh+cG1tmzZ2nQoAFOTk6cOHGCmjVrFvnqD2+9ePGCtevW\\\\\\n\",\n              \"s237dkzLV8bApCzZWVk8v3+LevXqMdJ9GLt27QL+Hv4eHBzMxYsXMTMzY8yYMaSnp3Pw4EGFwys9\\\\\\n\",\n              \"PZ0rV67kDri4cuUKkpoGxk1clKq3Rcs+pD2+ytWrV6latSrlypXDxsYGa2trJkyYgI2NDeXLlwfg\\\\\\n\",\n              \"3LlzrF63gcDvu2Fgao66hibJsdF8U+0bFk2bQO/evcnIyGDFihVs2rSJVq1a0X/gICrVa0yv2atR\\\\\\n\",\n              \"f8e6hxqaWjR1HYaOniEdOnfhwd2wT/Y5pFA40QL7RGQyGT4+PmzYsJFLly5iZGRMRmYGGurquLu7\\\\\\n\",\n              \"M2bMmPcOSRY+fzk5OXnmZoWHh9OlSxccHR3p1KnTR5+bpYjTp0/Tq09fbNp3p0m3/phUqJK7TZaV\\\\\\n\",\n              \"ye3gQIJ2rMFIvyR1atXkzZs3HDlyBENDQxYsWMDRo0f5448/cleNL0hiYmLuCMFz584RGhpK/fr1\\\\\\n\",\n              \"admyJQ4ODjRu3Bgzc3Ma/++0UqEtSRIhczoQdCKQRo0avbcOb8XHxxMeHk52djZly5alYsWK+faJ\\\\\\n\",\n              \"jo5m5syZeO37jW/3X0JbV7FFcX9fOI6xA1w/aP6aKvncrp3vI976fwIvXryga1dHSunrM2yEBzv3\\\\\\n\",\n              \"H83tMnr86E92em7FtkEDZs2cybRp0z7ZO3Phv0tMTMwzN8vMzAwnJyc2bNhA06ZN/9ONGf+rq1ev\\\\\\n\",\n              \"0rN3H9zmrMXStmm+7ZoltGnQ3pl6LTuyffYoboTe4v7dMHR1dfn111/ZuXMnFy5cyBcc0dHRnDt3\\\\\\n\",\n              \"LjewHj9+TKNGjXBwcGDhwoU0adIkd6X0tLQ07ty5g5qahtL/rtXU1EBdnYULF1K2bFmMjIwUfryP\\\\\\n\",\n              \"mZkZ5hblaNi5l8LhBWDTtR+r161k1KhR4v/oZ0K0wIpYZGQkzeztcR81Fo/xU965X0T4S/r1dGLw\\\\\\n\",\n              \"oIHMnCmWivmcPXr0KLeVde3aNVq0aJF736zPpRUtSRJ161nTsPcorFt3LXT/7MwMNo7ryabVK5DL\\\\\\n\",\n              \"5YwaNYrg4GAsLS15/vx5blgFBwfz+vVrmjdvjoODAy1btqR69eq8ePGCx48f8/jxY548eZL7fXx8\\\\\\n\",\n              \"PFWqVOHBw0fYzT6MZknFl36SZaQQ+j8X/Hx9SEhIID4+vtCHlpYWRkZGGBsbvzfgJk/7li7TllO+\\\\\\n\",\n              \"ZuFz2d6Sy+WsHeTA9SuXvuj5WJ/LtVMRogVWxIYPH0GffoPfG14A5StU5MCRY3Rq3Yx27drRuHHj\\\\\\n\",\n              
\"T1RDoTBZWVl55malpqbi6OjIpEmTaNeuXZHdl+m/uHjxIkmpadRv1UWh/bW0dbDvOYwF/1vMw/v3\\\\\\n\",\n              \"GD16NPPnzyc4OJjs7GyaNm1KtWrVGDJkCJmZmfz1118cOXKEFStWkJqaiqWlJZaWllSrVo0mTZow\\\\\\n\",\n              \"YMAALC0t0dXVZenSpTx8sp6Y68exaKn4EPy4GydwdO5O586dFdpfkiRSU1PfGW6xsbE8fPiQ+Ph4\\\\\\n\",\n              \"Yl6/Rt9YsVU/3lJXV8fQuCxxcXFfdICpEhFgRejx48dcC7nGlp37Fdrf3KIco8ZNYv36DezaJQKs\\\\\\n\",\n              \"OL2dm+Xn50dQUBA1atTAycmJ/fv3Y2Nj89l3Ia3fuImGjv2UqqdNWycOr5lPCU0Njh07hra2NqVL\\\\\\n\",\n              \"l+bly5ecOXOG8PBwqlWrhqWlJW3atGHkyJFYWlpibm6e5zySJBEYGEjv3r25fv062tra6JbQJCHE\\\\\\n\",\n              \"B/PmvVBTYDCIJJeTdMOXaXt3KFx/NTU19PT00NPTo1KlSvm2Z2Vl8fDhQ8LCwjh9NhhZdpbCx35L\\\\\\n\",\n              \"lpX5SUaMCooRAVaENm3aRL8BQ5Sa1d9vwBCa2NQkNja2SJauUUWSJOWumVeqVClMTEw+eoBIksTt\\\\\\n\",\n              \"27dzW1n379/PnZu1bt06zMzMPur5itq9+w9oNdJVqTJa2jqYVa5GjQpmtGjRIrdVZWlpibGxcaGv\\\\\\n\",\n              \"eVxcHAsWLGDnzp0kJydjZWXFvn37cHV1pWHDhmTKJJ77raVyt0nvPZYkSUSd/AXLyuVp2bKlUs8B\\\\\\n\",\n              \"/u7q++uvv7hz5w5hYWG5jydPnlClShWsrKwwNS3Ly7s3MbLIH3TvkvImjoTY6M+mm1gQAVakrly5\\\\\\n\",\n              \"yrezf1SqjJGxMVb1rbl161buhM2vVXJyMl5eXmzYuJEXz19gYGhISkoKxsbGjPUYw7Bhwwr9wP59\\\\\\n\",\n              \"0tLSOH36NH5+fvj7+6OtrY2TkxMLFy7EwUG177mUlZWJppbit6p/q4yxCVOnTqVr18I/N4O/w8bP\\\\\\n\",\n              \"z4/58+cTGhpKqVKlGDhwIPPmzcsT+pIkERMVgXp8HBG+qzHvMCLPMlJvydKTiT61Dd24Pwk4H1xo\\\\\\n\",\n              \"0EVGRuYLqvv372NiYoKVlRVWVlY4OTkxY8YMatWqlftm8ujRo0yfuwjr9t0Vfm1CTxzCxcUFAwPl\\\\\\n\",\n              \"buEiFB0RYEUoJSXlgz4f0dDUYvfu3Tx//hwTE5M8j9KlS3/ypYWKw8WLF3FxcaW+XRPGz1xMY/tW\\\\\\n\",\n              \"qKmp/d1SunGV3/d48r//LWb37l04OjoqfNyXL1/mzs0KDg7Gzs4OJycnTp48SY0aNT77rkFFGRub\\\\\\n\",\n              \"kBgTTfnqyi2DlPA6Ks/K8e8SExPDnDlz2Lt3L6mpqTRo0CB3EvS/X8MLFy5w7949pk2bxsyZMxnl\\\\\\n\",\n              \"MY4jK/phUNOeMvVao6GjR05GCmmPLhN35w8cHR351f9inrsOx8fH5wmptw8tLa3coLK3t2fUqFHU\\\\\\n\",\n              
\"rVu30JBxdHRkzLjxPL15kW9s7Qt9vmlJb7jht5dAf59C9xU+HRFgRcjAwICEBOXvPZSUmEBGhgnn\\\\\\n\",\n              \"z58nNjY2zyM5OZkyZcrkC7a3D2Nj43y/MzAwUKkL8+XLl+nm3J0fV26heesOebapqalhbdcEa7sm\\\\\\n\",\n              \"3L5xlWHuA9i+zfOdIZaTk8PVq1dzuwZfvXpFly5dGDx4MF5eXgrdFVgV9ejmyI7Dv1HHvq3CZcIf\\\\\\n\",\n              \"hpGdkUrDhg0L3C5JEr///jsLFy7kzp07lC5dmhEjRjBnzpx3znHz9fXF3d0dW1tb7O3tMTAw4Lc9\\\\\\n\",\n              \"uxk0aBCvX8eQ9vI0SUlJGOnr06FzKwZ5rSIuLg5vb+88QfW2S/Lto3fv3tStW1fh26/8m6amJl47\\\\\\n\",\n              \"d9Crrxu9522mQi3rd+6bnpzIwfljGDKwP3Z2dh90PqFoiAArQm3btiHA9zCt2ijeFfgqIpwH9+5S\\\\\\n\",\n              \"t3YtXFxc6NChQ57P0GQyWe6Iqn8/wsPDCQ0Nzff7jIyMAoPtfY+SJUsWS+hlZWXh2rMX85ZvzBde\\\\\\n\",\n              \"/1a/QWOWbfJi0KC+PH36JDeMEhMTCQwMxM/Pj2PHjmFhYYGTkxObNm2iSZMmxTo361N49OgRmzdv\\\\\\n\",\n              \"5mXEKxJjozE0Uezzu6s+exk7ZnS+1ycqKorZs2dz4MAB0tLSaNq0KSdPnqRt2/eH4/bt25k5cyb+\\\\\\n\",\n              \"/v4sX748967MWVlZBAQEsG/fvjwtq107d7B0yU/UrFkzN6gmTZqElZUVFStW/Oj/Htu1a8eu7dsY\\\\\\n\",\n              \"NHQY1h160qCrG2Us/m/ic2ZaCrdOHSXkyA76uHZn+bKlH/X8wn8n5oEVoVevXmFlZUXIncfoK9hv\\\\\\n\",\n              \"vvR/PxIT+YIGDRpw+PBhbt68SadOnXBxccHR0fGD+t8zMzOJi4srMPQKesTExAAoFXgmJiYfZXTW\\\\\\n\",\n              \"gQMH+HntBjbt8VO4zOyJ7tjWq4mxkRF+fn6EhITQsmXL3LlZBY1I+1L5+/vj1q8/cklC29AYgzLG\\\\\\n\",\n              \"jFm5G61C/jZh504QsGEBYbdvYWpqiiRJ/Pbbb/zvf//j3r17mJiYMGLECGbNmlXoUkqSJLFs2TI2\\\\\\n\",\n              \"bdpEQEAA2trajBgxAiMjI0qUKMHFixcJDw+nevXqWFlZUa9evdzAqlat2idfWu3p06esW7+BHTt3\\\\\\n\",\n              \"Yly+MnqljcnOzODFn3do06YNkyeMLzSsvySfw7VTUSLAitjAgYPQKKHD8tUbC30H+eeDe7g4tif4\\\\\\n\",\n              \"7NncRUNjYmLw8fHB29ubc+fO0bJlS1xcXHB2dv7g7hNFpKWlKRx4bx/a2tpKBZ6RkRFa/xpo0NKh\\\\\\n\",\n              \"Fc79RtC+aw+F63rj6kUmD+9D/35uODk50bZt289yblZRkiSJFStWMG/BQsrVbkjDId+jZ2LB2XUz\\\\\\n\",\n              \"yEmOo//slZQxy78KfI5MxrWAA5zetZbAYwFYWFgwY8YMvL29yczMpGXLlixevDjfDSYLOn9kZCS3\\\\\\n\",\n              
\"b9/mp59+4vbt21SuXJlHjx6hr69PdnY2FhYWdOzYkdu3b9O9e3cmTJhQVC/HB3m7hmNCQgIlS5bE\\\\\\n\",\n              \"ysqKcuUKXzn/S/O5XDsVIQKsiCUnJ2PfvDm2do1Z+vP6d767vHPrJoP6urBkyU8MGjSowH2SkpII\\\\\\n\",\n              \"CAjg8OHDBAYGYm1tjaurKy4uLsXeypAkieTkZKUCLz4+Hn19/TyhFhgYyB+3XqBbUvEAkiQJB6vy\\\\\\n\",\n              \"RESE5/ng/2uRmZmJh4cHR3z9KFu7MS3GLUZd/e9uQEku58bvm7h3fC9VrBrSqJMLBsamyLIyeXr7\\\\\\n\",\n              \"GjeOH8Ky2jc4demMl5cXDx8+xNzcHA8PD7799tsCp4C8efMmz+dTb0cBamhooKGhgbq6OlOnTiU6\\\\\\n\",\n              \"OhofP38ys2RUtKxFiRIleBMTzd3QEIYNG8bMGd+LIemfoc/l2qkIEWBF7O272NTUVJKSkhk0bCS9\\\\\\n\",\n              \"3QZgblGOjIwMrl25xI5fN3Pl0nk2b95M7969FTpuRkYGQUFBHD58GB8fH6pUqZIbZqpyywe5XM7r\\\\\\n\",\n              \"1695+PAhjx494smTJyxZsoSQpwlKf97RqXENQm9e/+reMUdFReHq6oquri4ht8LotfY4Glr5h/9n\\\\\\n\",\n              \"Z6Tx5EIAL0LOkJWSiLqWFm9ePKJ+nVqEhoYik8lo27YtS5YswdbWFoDU1FTu37+fJ6QKGlBhZWVF\\\\\\n\",\n              \"1apV8fDwoESJEmzatIk+fd3IUtPCddh46jVqnufvGRX+nOMHdvKH3wEOe3vTokWLT/Z6CYX7XK6d\\\\\\n\",\n              \"ihAB9h89efKEHTt28OLFS+RyOeXKWTBw4EDq1auHJEmMGjWKuLg4Dh06xK1bt9i4cSNHjx4lLi4O\\\\\\n\",\n              \"bW1tateuw5gxo+nfv/8Hd3vJZDLOnTuHt7c3hw8fRl9fHxcXF1xdXbGzs/vkgzEkSSIpKYmoqCgi\\\\\\n\",\n              \"IyPf+zUxMRFTU1PMzc2xsLAg6ORJgq4+Qt9Q8dGBOTk5tLQqR8zr11/VrS5u3LhBjx49GDZsGCdO\\\\\\n\",\n              \"nSbLoj52fcYqXt57Cw+PezHzu2/p0qULDx48yNOyevXqVZ4BFW8flSpVyvNvKjY2FkdHR+rWrcva\\\\\\n\",\n              \"tWvp2KkzplVqMeL7Re+d8nHj4h+snj2O0ydPYm397lGAwqclAuwzUNR/hNDQUGbMmEnI9RD69BtI\\\\\\n\",\n              \"7Tp1UVNT4+njx/y2dxfVvqlGw4Z2BAUFcenSpXwrekuSVCTBIpfLCQkJwdvbG29vbzIyMnLDrEWL\\\\\\n\",\n              \"Fv9pBJ5MJiMmJqbAMPr37zQ0NHJD6d9f//m9sbFxnjo5OXXDzqEzPfoOVrhewaeOsXPDMm5cV43/\\\\\\n\",\n              \"dO/y4sULNm3ezMGDh4iLi0NLSwvL6tUZO2Y0vXr1ytOdt3//fsaPH8/atWs5efIkO3ftxm3DCUqW\\\\\\n\",\n              \"MVX4fOmJcez1aI+WhnruChX/fFhaWhY6oOL58+e5g4wWL17Mpk2b2LHvd+Zs2KPQfMXjh3Zz87QP\\\\\\n\",\n              
\"54PPKlxvoWiJAPsMFOUf4dSpU7i5uTFjzo/07T8IXd28tzLPzs7G94g3Uyd6sGjhQiZNmlQk9SiM\\\\\\n\",\n              \"JEncu3ePw4cP4+3tTXh4OM7Ozri4uNC+ffvcUYOpqakFhtC/v8bFxWFkZPTOMPrn1w9tCR0/fpxp\\\\\\n\",\n              \"381gt8/7V2H4p3GDejDKfRBDhw79oHMWt8zMTMZ4jOXIkSN0cO5Nh+5ulLUoT3ZWFg/u3MBv/w6e\\\\\\n\",\n              \"PAhj65bNdO/enblz5+Ll5cXWrVuZN28eZcuW5VhgIEN3X1f63DsHN+Ll82cfNCAoLCyMLl26MG3a\\\\\\n\",\n              \"NCZPnowkSdSxqsfgaQuo31ixbsHs7CxGdm7I2TOnqVOnjtJ1ED4+VQowMQ9MSffv36dfv3547t5P\\\\\\n\",\n              \"85YOBe6jpaWFa+++1Kpdh57OnbGzs/vk/fxyuZzY2FhkMhmNGjWiQoUK3L9/n0uXLjF06FDevHmD\\\\\\n\",\n              \"rq4uMpkMoMAwatGiRZ6fy5YtW+g78sTExL9XQk9KolSpUtja2mJubq5wvTt27Ih86jS89+2gZ/9h\\\\\\n\",\n              \"he5/+rgvt65fwcfEkM6dOyt1rs9BVlYWjk7dyNHQYe/Jm+iWyhv8LcwcadHekQd3bjB8ZH/Gjh2L\\\\\\n\",\n              \"JElUrlwZJycntLW1yczMRJaT80GtekmSPujNxoULF3B1dWXVqlX0798fgCtXrpCekUW9Rs0VPo6W\\\\\\n\",\n              \"Vgnau/Rn69ZfWL16ldL1EL5uIsCUtHjxT3hMmPLO8PqnOlb1+PF/S5k3bz6nTp38KOdPT08nKirq\\\\\\n\",\n              \"nS2lt9+/fv0aQ0PD1ErKJAAAIABJREFUfKHk4uKCh4cH2trahIWF5d7yvW7durnD8z9kEeGwsDDW\\\\\\n\",\n              \"rVvH/gMHqGtVH8PSRqSlpRB6I4T27dszccIEHBwKf83U1dXx9TlKi5YOSHI5PQe4v/OifMLPm0Uz\\\\\\n\",\n              \"JhAYeJyAgADq16/PsmXLGDJkSL4ykiRx7do11m/YyKVLl0hNSUHfwIA2rVsxbtw46tWrp/RzVkZO\\\\\\n\",\n              \"Tg5xcXG8fv2amJiY3K/79v2GTEObxZs939u9W6teA37eeYSxfTuhjsTVq1cpXbo0derUwczMjGMn\\\\\\n\",\n              \"gngT/gSjipYK1ynh1V/o6enn60EozNvVNby8vOjUqVPu7x89ekT1utZKh6hlHWtCAg8qVUYQQASY\\\\\\n\",\n              \"UmJjY/H19eHa7eUKl+nu2pv5P8zgwYMH1KpVq8B93q62/q7Pk/75NT09HXNz83xdeI0aNcoTVGZm\\\\\\n\",\n              \"ZoUuRuvq+vdq5YmJifj7++Pt7c2UKVOws7PDxcWFHj16FHhL9n/buXMn06d/y6ARHpw4d4Oy5ha5\\\\\\n\",\n              \"25KSEjl8YC8DBg5iyOBBLFy4sNALnKWlJefPBdO9Rw+893riOmAE7To7o2dgSFpqMudOB7LXcwNp\\\\\\n\",\n              \"KUl8801V/P39+emnn+jduzfDhw9n7969bN26lSpVqgB/T1R169eP6OjX9B08gjXDJ1BKT5+kpASC\\\\\\n\",\n              
\"/H3o2KkztWvXZu8eL4VbcHK5/P+1d99hUR3v28BvQHpnd4GloxQVAQELiAoqiC12QVFsIKJiL0RN\\\\\\n\",\n              \"rFFjiSX2hooNOzYUsIFYwUTsgoBYEKmK9LLP+8f3J28IbVcgujqf69pL2T1zZvaYzM2cM2cOcnJy\\\\\\n\",\n              \"qgRSenp6te99+PABampq4PF40NTUhKamJlRVVXE/7j72ht4W6tqkkUlzDPWejNMHd2Dt2rV4/Pgx\\\\\\n\",\n              \"zpw5g3fv3qGVRUs8DTsER58FQrUfABIuH4OPj7dIgbNnzx7MmzcP58+fr/LMutLSUkh9wU3IUk2k\\\\\\n\",\n              \"UVpSKnI5hmHXwESwY8cOhF+6iu179otUbsG8Och6n4ouXbpUG0rv37+HgoJCjdeT/vmnurp6o84q\\\\\\n\",\n              \"LCwsRHh4OE6ePIlz586hWbNmGDhwIAYOHAgzM7Mq2x87dgxTp01H0LFzMDGrPqABIDMjHWOG9sOQ\\\\\\n\",\n              \"QQOwYIFwnaxAIMCVK1ewceMmREZG4tOnXCgqKcHS0grPnj7B27dvkZeXhw4dOmDGjBnw8/NDaWkp\\\\\\n\",\n              \"1q5di9WrV+PXX39F9+7d0bVrN/j4z8TwseOrnVhQWlqKbetX4eyJQ9gfFARJSck6AykrKwvKysqV\\\\\\n\",\n              \"Aunz3//9p6amJhQVFfHq1SvEx8dXvK5duwYNviFW7jwq9L9PTlYGhnZtjVYWLTFw4ED06dMHf//9\\\\\\n\",\n              \"N+bPn4/M7A8Y8mcoFNTqHkEX5mbj9Kx+eHj/b6Huxfrn6hphYWEwNzevss25c+ew4LeVWLrrhNDf\\\\\\n\",\n              \"BwBCj+7Dh6QHOHhAtP+vmMbBroF9p969ewcj46YilzNu2gzhoWehrKwMPp8PKysruLm5VQonUZ4Z\\\\\\n\",\n              \"1pjk5eXRr18/9OvXD6WlpYiMjMSpU6fg7OwMdXX1invNbGxsUFhYCL8JE7DvyNlawwsAuDxN7D54\\\\\\n\",\n              \"Er2c22HYsGEwNTWtsy2SkpJwcXGBi4tLlc8cHR0RFhaGfv364cKFC+jYsSP09PTQp08fBAQEoF+/\\\\\\n\",\n              \"fhgzZgx+XbAQc379De5eY2usR1paGpNnzwcA/NS3L1o0bw4tLa2KAGratCns7e0rBRKXy62yiohA\\\\\\n\",\n              \"IMDr168rAurmzZsVf3/79i0MDAxgZmYGMzMz2NnZ4cmz53AZ6FnncfgndQ4PNu0cMHfmVEhJSWHE\\\\\\n\",\n              \"iBFQU1PDsWPHcDE8ArtX+6Pb3O2QU6r5hu7i/FxErPDDpIkThQovgUCAWbNmISIiAjdu3ICurm6l\\\\\\n\",\n              \"z0tKSnDx4kUEBQXh8f1YpL1Jgbae8DcoXzsTjN+XLhJ6e4b5jAWYCKSkpFAuKBe5nKC8HC4uLti6\\\\\\n\",\n              \"dWsjtKrxSEtLVwTIxo0bcffuXZw8eRLu7u4oKyuDqakpWtu2RStrG6H2p6nNx2DPkdi2bRv++OMP\\\\\\n\",\n              \"kdtTUFBQMQpq3749FixYgOfPnyM9PR02NjYYOHAgmjVrhvz8fKSnp4OIYNrcotbw+if/WfNw4cxx\\\\\\n\",\n              
\"rFy5ssbrdUSEzMxM3L17t9JoKj4+HomJidDQ0KgIKTMzM7i6usLMzAzGxsZVAu/AocNQVRf9eWbK\\\\\\n\",\n              \"quqYM2cOpKWlsWLFCvTp0wcSEhJwcHBA7qdcHF7ohVaDJsConQukmvz/OsvLSvHy7mU8PLEFhR+z\\\\\\n\",\n              \"YN+ubZ11lZSUYOzYsXj58iWioqIqVp0XCAS4fv06Dh06hBMnTqBly5bw9PQEl6eJi0f3YfQM4UbZ\\\\\\n\",\n              \"8Y/+xsesDKGfP8Yw/8QCTARGRka4fvCwyOWePHqIFubCX1z/FklKSsLe3h729vZYuXIlHj16hL59\\\\\\n\",\n              \"+2HhCtFmjg0fNQ793TpixYoVEAgEyMjIEOoaUnp6OgQCQcUoiMPh4NmzZ0hMTESzZs3QqlUr2NnZ\\\\\\n\",\n              \"Ydu2bTh16hRsbW3Ro2dPDBzuI3TbJCQkMHTUOGzavBl2dnZISEioFFDPnz9HfHw8iAjm5uYVIeXu\\\\\\n\",\n              \"7g4zMzOYmpqKNKNPUUERRQUFIh0/AMjOyoSrqys2bNhQ6dqZhIQE1v/xB7o6OWHlH+twYv9qGNh0\\\\\\n\",\n              \"hJScEsqK8vDm/g00b26O7RvWQE9PD7179waHw6lxhmxeXh4GDx4MGRkZhIeHQ15eHvfv38ehQ4dw\\\\\\n\",\n              \"+PBhaGhowNPTE/fu3asYyaWkpKBN27awbN8Jdo5dav0eH3OysGH+ZCxcuOC7f0IA0zjYNTAR5Ofn\\\\\\n\",\n              \"w8DAAJeu34GBoZFQZT7l5sKquTGePnlS5dSLuFNRUUH0/QSoqIi2/qC1CR/lZaUoKSmp9RrSv99T\\\\\\n\",\n              \"UlKqdP1vypQpUFNTw5IlSyre27hxI7Zs2YKIiAiYmpoiNiGtysinNtlZmehsYwopSUmYmJhUGk19\\\\\\n\",\n              \"fnG53Aa5Djn/l18Q/zoDk+atELpMcVEhhnW1RmzMXTRtWvvp7KdPnyI6Ohq5ublQVlZGx44dK91r\\\\\\n\",\n              \"FRERgREjRuDSpUtVZmF+Xl2jVatW+Pnnn3H06NGKh1d6enrC09MTrVq1qrbe6Oho9BswAJ6Tfka3\\\\\\n\",\n              \"fh6QrmZpq4TH97F+nj88PYZg+fJlQn9/pvGJ0zUwFmAimjp1KorLCMtWrRVq+43rVmPrpg1oamyM\\\\\\n\",\n              \"ZcuWfVePZZCVlUXcizTIiTgNu5OtOc6eOQ1ra9GnXP/Tw4cP0aNHD6SkpFS6P23WrFmIjo5GUnIy\\\\\\n\",\n              \"oh8ki7RPIkILHWUUFxeLFHxf4tWrV7Cybo3DV+KEXrz44qnD+PvaOYRdvNAgbQgODsbs2bMRHR1d\\\\\\n\",\n              \"aRTVrVs3NG3aFHl5eUhISIC7uzs8PT3h4OAg1AobDx48wKTJU/Ds2TO4DPCESUtrSDWRRkbaW1w7\\\\\\n\",\n              \"E4zc7EwsXLgAPt7eDfI9mIYjTgH2/T+bvoHNnTsXF86fwYF9gXVuezH0HLZt2oBrV69i8uTJ8PX1\\\\\\n\",\n              \"Rbdu3XD79u3/oKWNj8PhIu1dqkhliouLkZOdBSMjo3qPYiwtLWFoaIjQ0NBK769atQpaWlrI+5Qn\\\\\\n\",\n              
\"8j6Li4ogLS3d6M+kIvrfvVwAcGCbcL8M5X3KxZ4/V6CLs1ODtWPo0KGYM2cOunfvjuTkZKxY8b8H\\\\\\n\",\n              \"Sr59+xY8Hg+//PILUlNTsXnzZjg6OgoVXgBgZWWF65HXEHn1CrTlgdiwY4g+9b/Zhit/W4yXyUks\\\\\\n\",\n              \"vJh6Y9fARKStrY3wsDC49eiBh3H34ec/FcZNm1Xa5l3qWwTu2IbDB/bi9OnTaNGiBVq0aIEhQ4Zg\\\\\\n\",\n              \"7969GDJkCGxsbLB06VKxXsR04MABOHX0IKb/LPy9R2HnQ9C8eYsGe+yJr68vduzYgb59+1a8Jykp\\\\\\n\",\n              \"ieDgYGhp8/Eo7m+hJ5kAwO3oSFjVc2RYl9evX8Pf3x/x8fFw6twJIQd3QVVdA+5jJtVY5tPHD1jg\\\\\\n\",\n              \"PwJt7WywYcMG5ObmYtGiRXXe61eXkpISGBkZQUpKCiYmJpCSkoKvry9WrlzZIM9Ua9myJdatEy6g\\\\\\n\",\n              \"GUZUbAT2BczNzXHn9m2oqyiiZ9eOGNKvJ379eRZ+nTsbXh4D0bm9DYryP+LWrVto3759RTlpaWmM\\\\\\n\",\n              \"GzcOCQkJ6Nq1K9zc3DB06FDEx8d/xW/z5SZOnIjgA3tQUlIi1PZEhL07NuPdu1TY2toiKChI6LI1\\\\\\n\",\n              \"cXd3x61bt/Dq1atK78vLy2PqlCnYu/1PkfYXvG8HJk0UfkV3UZSXl2PTpk2wtbWFlZUVWrRogfT0\\\\\\n\",\n              \"dDQ3N8PBbWsxaagbrkecQ/n/Le8FANmZ6Ti4fS18+3eGU4f2OB0Sgri4ODx48ACOjo5f9N+OQCBA\\\\\\n\",\n              \"VFQUxo8fDx0dHaxatQpOTk5o0qQJLC0tsXbt2h/ugaCMeGIB9oW0tLSwevVqvH79GhPG+6KZkT6M\\\\\\n\",\n              \"9XUwfJgHUlJSsGXLloqVIP5NTk4O06ZNw4sXL2BlZQVHR0d4e3sjJSXlv/0S9dSyZUsYGRliztTx\\\\\\n\",\n              \"EOZS6u5tf6KkqBBJSUlYvnw5Dhw4ACMjIyxbtgyZmZlf1AYFBQV4enpi9+7dVT6bPNkfkZfCkPDs\\\\\\n\",\n              \"iVD7unf3Fh7cvwcPD48vakttHj16hI4dO+LIkSMICQnB5cuXUVBQgKSkJKSnp0NXh49eLs44f2g7\\\\\\n\",\n              \"BndqAZ++jhjVox3G9HaA4GMazp87g/Xr10FKSgqampoVyzk5Ojpi165d1R7/2NhYjB4zBsZNm4LL\\\\\\n\",\n              \"5UFXTx9W1tbg8/mYOHEijI2NERsbi7FjxyIkJARXr16FgYEBRo8eDYFA0ODHgGEaHH2n7OzsvnYT\\\\\\n\",\n              \"hJadnU3z5s0jDQ0N8vf3p3fv3n3tJtWpvLyc5s6dS4aGhmRt3ZoGunvSX8/fUHJGYZXXk5QsmjJz\\\\\\n\",\n              \"LhkaGlFKSkql/Tx48IDGjh1LampqNH78eHr69KnIbXnw4AHp6upSaWlplc+CgoKIr6NLoVH36Hla\\\\\\n\",\n              \"fo2vY6GRxNPUogsXLnzxMalOYWEhzZ8/n7hcLm3bto2ePXtGJiYm5OTkRDIyMsTlcmnfvn1UVlZW\\\\\\n\",\n              
\"USY1NZUePHhAz549o0+fPtW6/8ePH5O1tTUNGDCAMjMziYjo9evX5ODQgfQMDGlywCI6cTmWwmMS\\\\\\n\",\n              \"6NTVv2j6L8vJwKgp2djaUnx8PP3+++9kaGhIz549IyKigoIC6tSpE02ZMoUEAkGDHgtGPIhT38kC\\\\\\n\",\n              \"7Bvy/v17mjZtGmloaFBAQABlZWV97SZV69OnT9S/f3/q1KkTpaenU35+Pnl7e5OKqioNGeZFuw+e\\\\\\n\",\n              \"pBOhVyno2Dny9ptMGhoc6t2nT63BnJaWRgsXLiRNTU3q1asXRUREiNSBOjg40JkzZ6r9bM+ePaSq\\\\\\n\",\n              \"qkZe3hMo/NaDSsF19updGjRsJKlraNRY/ktdu3aNzMzMaNCgQfT27VuKjo4mLpdLfD6fpKSkaPbs\\\\\\n\",\n              \"2VRYWFjveoqKimjmzJmkq6tLBw4cIB0dXZo8ZxHdTcymey8/VnnFJOVQwNI1pKKiSqampvTmzZtK\\\\\\n\",\n              \"+8vJySErKytatmxZvdvGiB9x6jtZgH2DXr16RePGjSMOh0OLFy+m3Nzcr92kCikpKWRtbU1jxoyh\\\\\\n\",\n              \"4uLiSp+lp6fTihUrqJuLC7Vt146cu3SlOXPmUFJSktD7LygooJ07d1LLli3J0tKSAgMDqaioqM5y\\\\\\n\",\n              \"e/bsod69e9f4eWJiIrm6upK8vDyZmJlTa9s21LSZCWlr88na2ppmz54tdBvrkp2dTd7e3qSnp0ch\\\\\\n\",\n              \"ISFERBQYGEjy8vIkKSlJTZs2peTk5Aarj+h/v1TMmTOH5BUUacrPi6sNrn+/5q/YQEbGTamkpKTK\\\\\\n\",\n              \"/lJTU8nY2Jh27tzZoO1kvn3i1HeyAPuGJSQk0PDhw0lTU5PWrFlDBQUFX7U9N2/eJD6fT2vWrGn0\\\\\\n\",\n              \"00sCgYAuXrxIbm5upK2tTUuWLKH09PQat8/PzycNDY0qpyj/bfr06WRjY0NXr16lR48eUUlJCcXH\\\\\\n\",\n              \"xxOXy6WcnJx6tzk4OJj4fD5NmjSJPn78SPn5+dSzZ0+SkJAgBQUF8vf3r3S6sD6Ki4vpzJkzNHTo\\\\\\n\",\n              \"UFJRUaF27dpRM7PmFJv8QagAu/fyI7Wxd6Rjx45Vu//4+HjS1tauCGHmxyBOfScLMDHw8OFD6t+/\\\\\\n\",\n              \"P+nq6tLWrVurjHz+C/v37ycul0tnz579z+t+9OgR+fj4kJqaGo0bN44eP35c7Xb+/v60cOHCWvdV\\\\\\n\",\n              \"Xl5O7u7u5O7uTuXl5RXvjx49us6ytUlJSaHevXuThYUF3bx5k0pLS2n79u2kqKhI8vLypKqqWmNQ\\\\\\n\",\n              \"iKK8vJyioqJo/PjxxOFwyNHRkTZv3kzp6enUt19/mrdsndDhde/lR1r+527q7ORcY30xMTHE4/Eo\\\\\\n\",\n              \"Kiqq3m1nxIM49Z0swMTI3bt3qXv37mRsbFzlwn9j+TxZw9jYmB4+fNjo9dXm/fv3tHjxYtLS0iI3\\\\\\n\",\n              \"NzcKCwurNBJ88OAB6enpVTuZ458KCwupY8eOlU4bJiYmEofDEfm6Y1lZGa1fv544HA4tXbqUioqK\\\\\\n\",\n              
\"6NSpU2Rubk7q6uqkpaVV72MnEAgoLi6O5syZQ/r6+mRhYUHLly+vchpSTk6OrsWliBRgt+MzSEZG\\\\\\n\",\n              \"ptbTtBEREaSpqUlxcXFf/B0Y8SFOfScLMDF07do1cnR0pBYtWtCxY8cqjSQa0r8na3wrCgsLKTAw\\\\\\n\",\n              \"kFq1akUWFha0a9euiskQ9vb2Qk3GyMrKInNzc9q0aVPFez4+PjR//nyh2xEXF0ft2rWjzp0707Nn\\\\\\n\",\n              \"zygqKoocHByoefPmZGxsTHw+n1xdXb94Mk5ycjItX76cLCwsSF9fnwICAmoMkaKiImrSpIlIpw8/\\\\\\n\",\n              \"vzQ4XHr//n2tbQkODiZdXV2Rrmcy4kmc+k4WYGJKIBBQaGgo2djYkK2tLYWGhgp1Xeru3bs0fvx4\\\\\\n\",\n              \"cnNzI1dXVxo1ajRFRERUCcHaJmt8KwQCAUVERFDPnj1JS0uLFi5cSOvWraM+ffoIVT4pKYn4fD6d\\\\\\n\",\n              \"Pn2aiP4XGBoaGpSRkVFruYKCApo7dy7xeDzauXMnxcXFUZ8+fcjQ0JCWL19OfD6f1NTUaM6cOSKP\\\\\\n\",\n              \"ktPT02nz5s3UoUMH4nA45OfnR1FRUXX+kpKQkECSkpJ0Oz5DpPCKTf5AcvLydU7XJyLauHEjmZqa\\\\\\n\",\n              \"1hl2jHgTp76TBZiYKy8vp2PHjlGLFi2oY8eOFBkZWe12UVFRZGtrR4ZGxjR/4W+0/2gIHTx+hpat\\\\\\n\",\n              \"Xk8WrSzJxMSUjh49SkT/7WSNhvLkyRPy9fUlVVVVkpWVpYiICKHKxcTEEJfLpTt37hARkZ+fHwUE\\\\\\n\",\n              \"BNS4/eXLl8nExITc3d0pJiaGRo8eTZqamrRu3Tq6cOECKSsrk7KyMgUHBwvd9k+fPtGBAweoV69e\\\\\\n\",\n              \"pKKiQsOGDaOzZ8/W+ItDUVERnTt3jkaNGkVmZmYkIyNDAEhZRYX+3HNcpAALPBFOTZuZCP3v/Msv\\\\\\n\",\n              \"v5Cdnd03NTOWaVji1HeyAPtOlJWV0b59+8jY2Ji6d+9Od+/erfjs1KlTxOXxaFdQMKXmFFHax5JK\\\\\\n\",\n              \"r3cfiulU6GXS1dMnr5EjicfjfZXJGg0hIyOD2rdvT0pKSuTq6koXLlyos3M+c+YM8fl8SkxMpFev\\\\\\n\",\n              \"XpGGhkaVUUZmZiaNGTOG9PX16cCBAzRz5kzS0NCg+fPn04cPHyqmyfP5fLp//36d7SwuLqazZ8/S\\\\\\n\",\n              \"sGHDSEVFhXr27EkHDhyodiSUkpJCa9asIRcXF+JyuSQhIUFSUlJkYGBAQ4YMoYMHD1J+fj7t3r2b\\\\\\n\",\n              \"nFx6iBRgvQe40x9//CH08RUIBOTr60suLi5C3d7AiB9x6jtZgH1niouLaevWraSjo0P9+/enY8eO\\\\\\n\",\n              \"EYfLpbBrt6sE179fsY9ekIYGhzZu3Pi1v0a9xMXFka6uLu3evZusra2pRYsWtGPHjlpvQ9iyZQuZ\\\\\\n\",\n              \"mZlRZmYm+fv704wZM+jq1avkP3kKdezkRCpq6tShQweaPXt2xam91NRUEggEFBAQQHJyctShQ4eK\\\\\\n\",\n              
\"1TCq83kGoZ+fH3G5XOrQoUPFDMLPCgsL6fLlyzRhwgSysLAgWVnZiin4tra2NGPGDIqJian2lGJ+\\\\\\n\",\n              \"fj5xuTzafTxMqPA6eC6K1NTUKTs7W6TjW1ZWRgMGDCAPD49Gu/7KfD3i1HeyAPtOFRQU0Jo1a0hN\\\\\\n\",\n              \"TY0WLVtVZ3h9fgUdOUV2dm2+dvPrzd7ens6ePUsCgYAuX75Mffr0IR6PR7/++muNK4IEBASQo6Mj\\\\\\n\",\n              \"rV27jhSUlEm/qQm5+weQ3+J15D3/d2rv2ofk5BWob7/+9O7dOyouLqa+ffuSjIwMTZgwocbZj3Fx\\\\\\n\",\n              \"cRQQEEAGBgbUsmVLWrZsGSUlJZFAIKCXL1/S1q1bqXfv3qSlpUWSkpIkKSlJWlpa1LNnT9q8eTOl\\\\\\n\",\n              \"pqYK/b3DwsKIy9OkvScv1Rpeh0OjSUubX3HaWFSFhYXUuXNn8vf3rzLCFQgElJubS1lZWf/JTFmm\\\\\\n\",\n              \"YYlT38kC7DuWlpZGqqpq9DwlXegAe5tdSAaGRpVOQYqjwMBA+umnnyq99+zZM/Lz8yM1NTUaPXp0\\\\\\n\",\n              \"lRl95eXlZGllRTy+Lv2y4ygdvPeaDv31ptJr+5WH1G+MP+nqG1RcfwoMDKxS/+cZhK1atSJ9fX2a\\\\\\n\",\n              \"M2cO3blzh65evUrTp0+n1q1bk5ycHElJSZG0tDQ1b96cfHx86OLFi5Sfn1+v737u3DnS4HBp0LDR\\\\\\n\",\n              \"dOj89UrBdSziDg0d7UvqGhw6fPhwver5vOTUb7/9RkT/O9U5d+5c4mlqkqKiIqmoqpKcnBwNHTqM\\\\\\n\",\n              \"rl+/LjbXU3904tR3sicyf8d27tyJsEtXsXnnPpHKrVq+GJLlxVi9enUjtazx5efnQ1dXFz4+4/Dg\\\\\\n\",\n              \"4UPk5eVBWVkZXZydMGDAAJw4cQKbN29GixYtMGPGDPTo0QPbd+zAspVr8Ovuk1BR59S6/7DgQBzZ\\\\\\n\",\n              \"+DsuXgiFk9P/HjCZmZmJo0eP4tChQ3j69Cnc3NxgYGCA+Ph43LlzB2lpaZCQkICioiKsrKzQs2dP\\\\\\n\",\n              \"/PTTT7CwsBD6QZHCSktLw65du7Bt+w5ISEpCWUUV+XmfUFxUBN9xPvD19YWenl6963n37h06dOgA\\\\\\n\",\n              \"a+vWiIqKwgD3YfAcNQ7NTM0BAB8/5ODkkYM4uHcHjIwMceL48QZ7FhzTOMSp72QB9h1bvnw50jI/\\\\\\n\",\n              \"YP6iZSKVOxS0B6ePH8akSROhqKgIRUVFKCgoVPpTUVERcnJyDd7xNoScnBxMnOSPM6dPo1vv/ujk\\\\\\n\",\n              \"0htKyirIy/2I65fOIzL8PAYOHIg//liD0NBQrF27Fnl5eXiT+g6L9oRA36S5UPVs+nkC+nZ1hL6+\\\\\\n\",\n              \"Pvbt24fo6GgYGRmhpKQEb9++RVlZGQQCAXR0dODg4IB+/fqhS5cu4PP5jXwE/r+ysjIkJycjNzcX\\\\\\n\",\n              \"ysrKMDY2hrS0dIPW4ec3ARGXryA4JAw8Le1qtykvL8eS+TPx+P49REZeY88b+4aJU9/Jnsj8HZOW\\\\\\n\",\n              
\"lkZZaVndG/5LSUkJYmJi4O3tDRkZGUhJSUFSUhJEhPLycpSWlqKoqAglJSWQl5evFGrVBZ2o733+\\\\\\n\",\n              \"u4KCgsgBmZGRgc5OzrBs64jTNx5DSbnyb/udXHph0s+/YfPvv8LF1RVXr1yBlpYWBg8eDL6RidDh\\\\\\n\",\n              \"BQAuHmOwyH8EBOVlKCsrg7S0NBITE9GiRQtMmjQJPXv2RPv27b9qZ92kSROYmpo22v7DwsJwMSwM\\\\\\n\",\n              \"Jy9GQYPDrXE7KSkpLFqxDnOm+GLWrFnYunVro7WJ+XGwAPuOmZqaIuT0WZHLxd2/h9mzZ2HQoEF4\\\\\\n\",\n              \"9+4dUlNTK17//Pndu3eQkZEBh8MBj8eDhoYG1NTUoKKiAmVlZcjLy0NOTg4yMjIoLi5Gfn4+MjIy\\\\\\n\",\n              \"8PLlSxQUFCA/Px/5+fkVf//3e4WFhZCVlRU6CBUUFHDo8GF06NobkwIW1/j9VNXUMXfFRqxdPAeW\\\\\\n\",\n              \"Vlb4+OEDygno6ekt0nEyb90WCsoqMGtqBE9PTzg7O6Nly5aQkpIS+ZiLq3Xr18N/xtxaw+szCQkJ\\\\\\n\",\n              \"zPn1N7g6tsaKFSugpqb2H7SQ+Z6xU4jfsdLSUhgYGiL4VCiat7AQqsyHnBzYtzbH8+fPoampWeu2\\\\\\n\",\n              \"AoEA2dnZ1YbbP39OS0uDiooK+Hw+dHR0Kl7//JnP54PP50NGRqbS/ouKiuoMus9/Pnz4ENeu38Dh\\\\\\n\",\n              \"sLtCjdzKysrwU4cWyP/0EXKKSpj9ZxBMWtkIdZw++33iMLS3bA4bGxtIS0tDRkYG0tLSVV6ivi8t\\\\\\n\",\n              \"Lf1Nnp79p6SkJLRt2w437idATl5e6HJTx4+Ec0cHTJ8+vRFbx3wpceo72QjsOyYtLY1xPj7YtG41\\\\\\n\",\n              \"Nm7fAwkJiTrLbN+yARwOR6hHyktKSoLL5YLL5cLKyqrG7QQCAbKysqqE25MnT3Dp0qWKoHv//j1U\\\\\\n\",\n              \"VVWrhNu/f9bW1q72Ok7/AQPhMXqC0B1/kyZNMGzsRNy8dBav3ryBoLzu7/xv5aWlSEtLQ1xcHEpL\\\\\\n\",\n              \"S1FSUoLS0tIqL1HfLy0thZSUVL2DsDHfDw8Ph31HJ5HCCwC6ufVGVMQ5FmBMvbEA+87NnDkTjo4d\\\\\\n\",\n              \"sXbVMsyYM7/WEDt+5CCC9+/BgAEDYGVlheXLl8Pb21uo4KuNpKQkeDweeDwerK2ta9xOIBAgIyOj\\\\\\n\",\n              \"ykju4cOHCAsLqwi+9+/fQ11dvVK4aWtr4/y5c5i8aINIbeszeDh2rl8BaVlZpMQ/hpm1ndBlBeXl\\\\\\n\",\n              \"yHz3GkeDdqNly5Yi1VsXIkJZWVm9g7Cu9/Pz8794P+/fv0d7RyeRv5uSkgo+ffrUoMeL+TGxAPvO\\\\\\n\",\n              \"qaqqIjw8DD169MSTRw/g5z8dbdrZVwqlp08eIXDHFly9FIbw8HC0atUKvr6+GDduHA4ePIgdO3Y0\\\\\\n\",\n              \"6kSAzyQlJaGlpQUtLS20bt26xu3Ky8uRkZFR6TTlixcvICMrC0UlZZHqVOfwICABpCSAiwd3wWWw\\\\\\n\",\n              
\"l9CB/Xf0Fejq6DZ4eAH/u170eaTzrTp06BAOHz0pcrnc3A9QUVFphBYxPxoWYD8AHR0d3Lx5Azt2\\\\\\n\",\n              \"7MC0id6QlZVD85YWkJKUQnJyIt68fgXfceOwJja24rqXtbU1bt26hY0bN8LBwQEzZ87ErFmzvokO\\\\\\n\",\n              \"NTc3F0lJSYhMeNmbAAAbjElEQVSLi8Pt27fx8OFDJCcno7SkROR9CQQCkECAZs2aITEpGU9ibsKi\\\\\\n\",\n              \"naNQ5UJ2b8CgXt2/5Ct8FxwdHeHvPxkF+flQEGGm5aUL59DDtUsjtoz5UbBJHD8YgUCAGzduICUl\\\\\\n\",\n              \"BQKBAHw+H87OzrUG08uXLzFhwgSkpqZi586daNeuXaO3k4gqrpPdvn0bMTExePr0Kd6+fYvi4mJI\\\\\\n\",\n              \"SUmhrKwMXC4XxsbGsLKywvETJ7DxwFk0MxN+RPTo7xhMHzsY27ZugaKiIrxGj8GC3Seh19SsxjIC\\\\\\n\",\n              \"gQAH/1iEV49i8TEnGz179sTq1at/yBt0f+rbF45demKo1xihtn+flooendsg5eVLNgr7RolT3/lt\\\\\\n\",\n              \"T3NivhgR4ebNmxg+YgR0dXWhrKwMPp+PAQMGorCwEJ6enhg5ciRcXV3rHFUZGRkhNDQUAQEB6Nu3\\\\\\n\",\n              \"L6ZPn468vLwGaWd5eTlevHiB48ePY9q0aXB2doaenh5kZWVhaGgINzc3LF++HI8fP4apqSlmzpyJ\\\\\\n\",\n              \"kJAQPHnyBMXFxXj//j2uXLmCdu3aQU5ODkf3bhep/kO7N0FBXh5paWnYunUrzE1M8LufByKO7EVB\\\\\\n\",\n              \"XtXrNImP7+PPWeOQlfwU1yOv4fHjx5CUlESrVq1w/vz5Bjkm4mT6tGnYsn4lMt6n1bmtQCDAsgUB\\\\\\n\",\n              \"8BoxgoUX0yDYCOw79Pr1awweMgRZWVnwHueHn/oNgJq6OvLz8hB+MRQ7t29FSXERjh8/jlatWom0\\\\\\n\",\n              \"78zMTMycORNRUVHYunUrevToIVS5oqIiPH/+HNHR0bhz5w4eP36MlJQU5OTkQEJCAkQEdXV1GBoa\\\\\\n\",\n              \"wsLCAvb29rCzs4O5uXmN9wslJiZi69at2Lt3Lzp06AAPDw9MmDgR+8/dgJZO3cskvX6ZBJ+BXREY\\\\\\n\",\n              \"uBuTJk1CdnY2Ro4cCScnJ5wMOY3Lly/BtpMLlNQ5KC0uwZPYm0BZMfwnTsCUKVMg/4/Zd1euXIGP\\\\\\n\",\n              \"jw86duyIdevWgcOpfSmq78mSJUtw6PAR7D58Crp6BtVuU1pail9mT8ab5Be4dCmi0rFjvi1i1Xd+\\\\\\n\",\n              \"hfUX/xPitCBlQ0pJSSF9fX36bcUq+lBQSrlF5VVeHwvLaEfgPtLU1BTq2VXVCQsLI2NjYxo+fHil\\\\\\n\",\n              \"x4F8+PCBrly5QgsXLqQ+ffqQmZkZKSkpkYSEBElISJCsrCwZGBhQ165dacaMGXTixAlKTEwUetXy\\\\\\n\",\n              \"8vJyCg0NpV69ehGXy6XZs2dTUlISpaamkpubG0lLy5COniGdvfmUbr3IqfF1MjKO9AyMaP369eTk\\\\\\n\",\n              
\"5EReXl705s0bWrZsGenr65O9vT2tX7+eNm3aRMbGxuTt7U2ysrJUWFhYY9vy8vJo6tSpxOfz6fjx\\\\\\n\",\n              \"4190XMWRQCCgNWvWkKqaGg0dMYbOXrpJie/zKSm9gO48TKJZ8xaTnr4B/dS3r1BPfma+LnHqO1mA\\\\\\n\",\n              \"fUcEAgHZ2trS8pVrqg2uf7/2HjhM+vr6tT4nq6Z63rx5Q4GBgdS6dWuSlpYmdXV1kpaWrggqVVVV\\\\\\n\",\n              \"srCwoEGDBtHvv/9ON2/epA8fPnzxd8vOzqa1a9dSs2bNyMbGhgIDAyk7O5uOHj1Krq6u1KRJE9LS\\\\\\n\",\n              \"0qLg4GBavHgJaXB5NP3X3yni75eVgivsXjJNmbeMuDxNUlFRIRsbG/Lx8an0XKuysjI6ffo09ejR\\\\\\n\",\n              \"g3g8XsVDLC0tLYVapf/GjRtkbm5OgwcPprS0tC/+zuImLS2NfvvtNzI0NKpYZV9ZWZm8vb3p3r17\\\\\\n\",\n              \"X7t5jJDEqe9kAfYduXLlCrVoaUEfC8uECrDconJy7e5G+/btq3Z/ZWVlFBcXR+vWraNhw4ZR69at\\\\\\n\",\n              \"icPhkKSkJAEgGRkZ0tHRISsrK+JwOGRhYUFXr15t0GdA3b9/n8aNG0dqamrk6elJN27coOjoaBo/\\\\\\n\",\n              \"fjxpaGhQmzZtSENDg2bMmFHxPK5r166RkpISqalrkLyCAnVwciG3voPI0dmFVFTVaJjncAoPDydD\\\\\\n\",\n              \"Q0NSVVWtNIL8t4SEBNLX1ycVFRXS19encePGCfX9CgsL6eeffyZNTU06cODAD/cokdLS0lpHq8y3\\\\\\n\",\n              \"S5z6TnYN7DsyeMgQdOjkjHHjJwhd5sL5s1jz+zIsXboUUVFR+Ouvv/DixQukpaUhPz8fAKCoqAg+\\\\\\n\",\n              \"nw9TU1PY2trCyckJ7dq1q3QhvrS0FGvXrsXq1asxd+5cTJ06FU2afNldGqWlpTh16hQ2bdqEpKQk\\\\\\n\",\n              \"+Pn5oXv37rh48SKCgoIgLS2NkSNHori4GFu3bsWuXbvw008/gYhw8uRJDB8+HEZGRti9ezeaNm2K\\\\\\n\",\n              \"mJgYfPr0CSoqKrC3twcAuLq6wtXVFZKSkrh58yYiIiIgJydXbXu6dOmCgIAAnDx5EidOnICysjLG\\\\\\n\",\n              \"jx8Pb2/vOpfbio2NxdixY2FoaIht27ZBV1f3i44Jw/xXxKrv/MoB2mjE6beIhqKiokJJr9OEHn3l\\\\\\n\",\n              \"FpVTTn4JSUtLk5SUFPF4PLK1taURI0bQn3/+SY8ePRL5kfEJCQnUtWtXsrOzo7/++kuksqmpqbRo\\\\\\n\",\n              \"0SLi8/nk5OREe/fupW3btlGnTp2Ix+PR5MmTKSYmhj58+ECDBw8mW1tbSkpKIiKiyMhIcnBwIG1t\\\\\\n\",\n              \"bWrdunWN7U5NTaWWLVvSL7/8QgKBgMrLy8nd3Z2GDRtWYxlnZ2e6cuUKJSQkkJ6eHsXGxpK3t3fF\\\\\\n\",\n              \"qLCuhzUWFxfTokWLiMvl0s6dO3+40RgjXsSp72QB9p0QCAQEoMaJG7W9dHR06NWrVw3alsDAQOLx\\\\\\n\",\n              
\"eDRnzpxanzAsEAjo8uXL5OzsTAoKCuTm5kZLliwhDw8PUlFRoYEDB1JISAgVFxcTEVFcXByZmpqS\\\\\\n\",\n              \"n58fFRYWUlxcHPXq1YuMjIxo3bp1pK6uTvHx8dXW9fr1azI1NaWlS5dWer+goIAcHBxo/vz51Zb7\\\\\\n\",\n              \"HGACgYB4PF7FscrOzqb169eTmZkZWVpa0tatWyk3N7fG7xoXF0d2dnbk4uJCycnJtR1ChvlqxKnv\\\\\\n\",\n              \"ZAH2HVFQUKC3GR9ECq+PhWWkpqZGWVlZDd6etLQ08vDwoGbNmtGlS5cqfZafn08rVqwgvo4OySso\\\\\\n\",\n              \"UGsbW+reoxfZd3AkFRUVsrSyoqCgoEqjlT179hCXy6WgoCBKSkqiESNGkJaWFm3YsIGKioqof//+\\\\\\n\",\n              \"tGjRomrbkpycTMbGxrRmzZpqP09PT6dmzZrR7t27q3z2OcCIiPr160fBwcGVPhcIBHTp0iUaOHAg\\\\\\n\",\n              \"qaur08SJE+nhw4fV1lNaWkorV64kDodDGzduFHmEyzCNTZz6ThZg35HOTk60//BRkQLsyvVbZGxs\\\\\\n\",\n              \"3Kgd6dmzZ0lfX5/GjBlDsbGxNHPmTFJWViYlJSWaOHka/f04nnIKyipe77LzaNvufWTRypK8Ro6k\\\\\\n\",\n              \"jx8/ko+PD5mbm1NkZCRNnjyZNDQ0aNGiRRUjnrNnz5KJiUm1Ewfi4+PJwMCANm3aVGs7nz17Rpqa\\\\\\n\",\n              \"mhQREVHp/X8G2MqVK2nKlCk17uPNmze0cOFC4vP51LlzZwoODq4YPf67rg4dOlDHjh3p+fPndR5D\\\\\\n\",\n              \"hvmviFPfyQLsOxIcHExOzl1FCjDPESNp5cqVjdqu8vJyOn78OBkZGZGEhASZmpqSkpIyhV66Vim4\\\\\\n\",\n              \"/v16k/GRurl0J21tbRowYAAFBASQhoYGTZ06ld6/f1+x//z8fDIyMqLw8PAqdT958oR0dXVpx44d\\\\\\n\",\n              \"QrU1MjKSeDxepRHUPwPs+vXr1KZNmzr3U1JSQseOHaMuXbqQtrY2zZ8/n1JSUiptU1ZWRhs2bCAO\\\\\\n\",\n              \"h0OrVq2qmEXJMF+TOPWdLMC+I8XFxaSjo0Mnz4YKFV7Xb8eSuro6ZWRkVNpPQUEBJSYm0rNnzygz\\\\\\n\",\n              \"M/OL2/PPe7dMTU2pU6dOpKSkREpKSnTwyMlaw+ufIWZgaERqamrk5eVV7bWjuXPnkoeHR5X34+Li\\\\\\n\",\n              \"iM/nU1BQkEjtPnDgABkaGlJqaioRVQ6wgoICUlBQoLy8PKH39+TJE5oyZQppaGhQ37596eLFi5VG\\\\\\n\",\n              \"vImJidSlSxdq27ZtjaceGea/Ik59Jwuw78z169eJx+PR2QsRtYZX9J17xNfRoWPHjlWUjYuLI19f\\\\\\n\",\n              \"X1JTUyMDQ0NqZmJCKioq1KVrVzp+/DiVlJQI1YbP+1FWVqZWrVoRn88nCwsLWrVqFR09epRaWlhQ\\\\\\n\",\n              \"dn6pUAGWU1BGa//cQt26uVRb15MnT4jD4dDbt28rvR8bG0taWlp05MiRLzqOS5YsITs7O8rLy6sU\\\\\\n\",\n              
\"YERE7du3p2vXrom8z7y8PNq5cye1bt2amjVrRmvWrKn4BUEgEND27duJy+XSkiVLajzWGRkZ9Pvv\\\\\\n\",\n              \"v5Nzly7U2saGOjg60uTJk+nJkydf9D0Z5t/Eqe9kAfYdunr1KvF4PBriPpTCr0RVurH5xt2/aIz3\\\\\\n\",\n              \"ONLQ0Kjo3MvLy2n69OnE19Gh+QsWU3zym4rtM3MLKTDoIDl0cKTWrVtXCYrPSkpK6MiRI2Rvb/+/\\\\\\n\",\n              \"ADQwIC6XS9OmTaO//vqrYjKGu7sHrV63Uejwyikoo9fpH0hdXb1K3QKBgJydnWnDhg2V3r916xZp\\\\\\n\",\n              \"amrSqVOnvvgYCgQCGj16NPXt25ecnJwqBdj06dNp+fLl9dr3rVu3yMvLi9TU1GjUqFF0584dEggE\\\\\\n\",\n              \"9OrVK+rZsydZW1tXWr2ioKCAfH19SVVVlYZ7jaITZ87TtRt3KDT8Cs3+eT5paWlRl65d2exGpt7E\\\\\\n\",\n              \"qe9kAfad+nz6zsTUlLS0tMjMzJx0dHVJT0+PlixZQu/evSOi/3WmEydOJAfHjpTyLrPW2YoLFv9G\\\\\\n\",\n              \"JiYmlU45pqam0oIFC0hDQ4O4XC4pKCjQ4MGD6dy5c9WOIqysrSnyZoxIAZZTUEb2Dh0oKiqq0r6C\\\\\\n\",\n              \"goLIxsam0rWjqKgo4vF4dP78+Xofw+LiYuratSvp6elVCrBjx45Rnz596r1/ov+NqFatWkXGxsZk\\\\\\n\",\n              \"Z2dHu3btory8PAoKCiIej0dz586lrKws6tipEw0a4k7Jb9Or/ffJzC2kpctXko6ODpsUwtSLOPWd\\\\\\n\",\n              \"LMC+c+Xl5fT69Wt6/PgxpaSkVJkocObMGTI3b06v32cLdd1syvSZ5OHhQdevX6fu3buTjIwMycnJ\\\\\\n\",\n              \"kY2NDe3YsYNycnJqbY958+Z0MzZO5ABzcu5SaXZgdnY2aWtr0507dyreu3TpEvF4vCqzCOsjJyeH\\\\\\n\",\n              \"FBQUyN/fn4j+F/jHjh0jVTU16tzZiTp26kTuHh505syZei2h9XmR4j59+pCGhgZNmzaNrl+/TgMG\\\\\\n\",\n              \"DCBNTU0a7DFUqHv8Nm7ZTk2bNq313juGqY049Z0swH5wrq7daUfgPqFnLb5+n03y8vIkLS1NHA6H\\\\\\n\",\n              \"AgIC6MWLF0LX5+jYkU6cCRUpvLLzS8nU1KzSyh5+fn7k5+dX8fP58+eJx+NRZGRkgx4fIiJ7e3vi\\\\\\n\",\n              \"cDi0aNEiatmyJbVo0ZJ+X72Wzl6IoPNhl2njlu3Utl17MjQ0pP3799e7vuTkZJo7dy5pampShw4d\\\\\\n\",\n              \"SEVFldKyPwn9b+TWoycFBgY2wDdnfkTi1HeyAPuBJSQkEI/Ho/QP+SJNvR/uNYq8vb2/aEmkP/74\\\\\\n\",\n              \"g9yHeooUYBHXbpBx06YVM/fu3LlD2tralJ2dTUREISEhxOPx6NatWw16fD5zdnamsWPHkoqKCp04\\\\\\n\",\n              \"c77GxZIvR90kI2NjWrVqVYPUW1RURD/91Jf8Jk0W6d/n6KkzZCfEVH+GqY449Z3sicw/sHv37sGx\\\\\\n\",\n              
\"U+caF7GtSc/efZCRkQkJCQmR6xwzZgzCL4YiIz1d6DIb1/8BLoeD1NRUlJWVwc/PD6tWrYK6ujqO\\\\\\n\",\n              \"Hj2K8ePH48KFCxUL9Ta03NxcnD59BuFXr8O1e48av3fbdu0RdjkKmzZvRkhISL3rlZWVxdNnT+E1\\\\\\n\",\n              \"aoxI5Vy790DKy5d4+/ZtvdvAMN8yFmA/sPz8fCgoKIpcTlFRqWKlelGpq6vDx8cHvmO9UFJSUuf2\\\\\\n\",\n              \"p04cw83oKFhaWsLa2ho9e/aEoqIiRowYgf3792PatGkIDw+HnZ3dF7VHGOkZGZj760K0tKj76dV8\\\\\\n\",\n              \"HR38sWETlv72G6gBHvSQk50NbW2+SGWkpKSgpaWNrKysetfPMN8yFmA/MFVVVeRki97JZWdnQVVV\\\\\\n\",\n              \"9YvrXbFiBTTU1OA+oA/evnlT7TalpaXYsXUTpk0aj1ZW1ggJCYGVtTUiIyPx9OlTDB8+HD///DMu\\\\\\n\",\n              \"X74MKyurL25LXd6+fYuc7GwMG+4ldBnX7j2QnZ2NmJiYetcvKysrVND/W1FxEWRlZetdP8N8y1iA\\\\\\n\",\n              \"/cA6d+6MWzdvICcnR6RyISePw9XV5YvrbdKkCY4ePQJHB3t0treF19DBOBNyErdv3UDk1ctYtngB\\\\\\n\",\n              \"LM2b4mzISVyNvo0zoeF4FJ+MAYPcweFwYGlpiVOnTkFGRgZPnjxpkJFOTc6cOYMevfpAWVlZ6DKS\\\\\\n\",\n              \"kpIY6umF48eP17v+5i1a4PatGyKVeZeaiqzMTOjp6dW7fob5lrEA+4HxeDz07tMHB/fvE7rM2zdv\\\\\\n\",\n              \"EB0ViREjRtSrbikpKSxbtgwpKSno3dMNK5ctxqhh7li1/DfkfvyI0+cvIjTiKkzNzAEASkpKGOPj\\\\\\n\",\n              \"i6hbsUhKTsacOXOwfft2LFmyBB07dsTNmzfr1Z6aZGRkwMjYWORyurq6yMzMrHf9fuPHY/fO7SKV\\\\\\n\",\n              \"2Ru4Cx4eHlBUFP30MMOIExZgP7hpU6fiz7WrkZj4os5ty8vLMWPqJIwdOxZKSkoNUr+SkhLc3Nzw\\\\\\n\",\n              \"Pi0NN2P+xsXL17Bm/Z81Xm/S5vNxPuwKNmzYABsbG/z111/w9fWFh4cHBg0ahISEhAZp12eysrIo\\\\\\n\",\n              \"LioSuVxxSUmDnMLr378/EhPiER0VKdT2mRkZ2LNrOyZOnFjvuhnmW8cC7AfXpk0bLFq0CH17uuLR\\\\\\n\",\n              \"wwc1bpefn48xXsNQWlyM5cuXN2gbtm3bhmEjRkJTS0uo7Q2NjNCnb38EBgZCSkoKo0aNQnx8PNq2\\\\\\n\",\n              \"bQsHBwdMnjwZGRkZDdK25s2bI/p6lMjl7t29AzMzs3rXLy0tjT179mCM1zD8/de9WrfNysrCkAE/\\\\\\n\",\n              \"YcyYMY16XZBhvhlfex5/YxGnexm+BQcOHCAOh0O9eveh46fPUdLrNHr9Pptuxd4n/6nTicPh0KjR\\\\\\n\",\n              \"o6moqKhB6y0vLycej0d/PXom0r1OV6Nvk4mJSZX9paen0+TJk4nD4dCyZcvqvSJFSUkJKSur0O17\\\\\\n\",\n              
\"cUK3LfnNe1JTU6vXSv7/dvLkSeJyuTR91hx6+Cyxys3lq9f9SYZGRjRr1qwvuj+PYT4Tp75TgqgR\\\\\\n\",\n              \"r4B/RW3atEFsbOzXboZYyc/PR3BwMHbs3IkXCQkoKSkBj8fD4MGDMWHCBBh/wbWgunz48AGGhoZ4\\\\\\n\",\n              \"ky7aRJKysjLwVBVQUlICScmqJxISEhIwb9483L59G0uXLoWXlxekpKS+qI3NmjWDjW0bBO4/JNS9\\\\\\n\",\n              \"b4t+mYeszPfYu2fPF9VXkxcvXmDz5s3Yv38/TEzNwOFwkJ9fgLj7f8HF1RX+kybBycmpQetkfjzi\\\\\\n\",\n              \"1HeyAGO+qoyMDDRv3hwvU0U75UdE0FCSRVFREZo0aVLjdrdu3cKsWbOQl5eHVatWwc3NTeQ2durU\\\\\\n\",\n              \"CRkZGfip30D8unhprSEWuHM71q1ZiVu3boHPF+3+LWEVFBTg7t27+PjxIxQUFGBpaQltbe1GqYv5\\\\\\n\",\n              \"8YhT31nz//kM8x9QU1NDQUEBPn78KNK9Zalv30JFRaXW8AIABwcHREdHIyQkBJMnT4aRkRFWrVqF\\\\\\n\",\n              \"1q1b11lHdnY2wsLCkJaWhm7duuHs6ZN48OA+Jk+dgc7OXSoF2b3YGGzfvBExd28jIiKi0cILABQU\\\\\\n\",\n              \"FODs7Nxo+2cYccEmcTBflbS0NH7q2xfBhw6IVG7/vj1wd3cXalsJCQkMGDAAjx8/Rr9+/dCjRw+M\\\\\\n\",\n              \"GjUKr1+/rnb7hw8fYsyYMWjatCkOBx+BY6fOyC8sAiCBp48fw9d7FGxbNYen+0CMGDoYHdraYKzX\\\\\\n\",\n              \"MLS2tkRMTAxMTU1F+i4Mw3wZNgJjvrqJEyZgwoSJ8PH1E+o6VUlJCfYG7sSF0FCR6pGWlsakSZPg\\\\\\n\",\n              \"5eVVMQrz9fXFzz//XDH6CwkJwbhx4zBl+kw8fJoAHo9XUZ6IcD0qEitXLEPep09wHzwIcnJy4PP5\\\\\\n\",\n              \"cHBw+OJrbAzDfBl2DYz56ogI3Vxc0LxlK/y+em2t15gEAgEm+nqjqDAfJ+q50sWbN2+wcOFCnDt3\\\\\\n\",\n              \"DvPnz0fz5s3h5eWFU2dCYVvL2orl5eXwnzAeb968Quj585CWlq5XOxjmWyJOfSc7hch8dRISEjhx\\\\\\n\",\n              \"/Dju3IzGpPE+Na5Un/r2LcZ4DcOrlGQE7RN+9ZCa6OnpYffu3bh06RJCQ0Ph7u6OrTt21xpewP9W\\\\\\n\",\n              \"Edm4ZRvy8vIbZLkohmG+DAsw5pugrq6OyMhIyMtKw86qBXxGjcCxI4dx8cJ5HDl8EF7DhsChjTUM\\\\\\n\",\n              \"9HQRER7eoMskWVpaIiAgAFpa2ujZq7dQZZo0aYKp02diy5YtDdYOhmFEw04hMt+cnJwc7NmzB7du\\\\\\n\",\n              \"3canT5+goqICJ6fOGDlypEiL6opixIgRsG3THhP9JwtdpqysDObNDHH58mU0b968UdrFMP81ceo7\\\\\\n\",\n              \"2SQO5pujrq6OGTNm/Kd1vkhMxNhxE0Qq06RJE1i0skRSUhILMIb5CtgpRIYBUF5WVuc9ZdWRlpZG\\\\\\n\",\n              
\"WVlZI7SIYZi6sABjGACaWlpISXkpUhkiwsuXydASchFihmEaFgswhgEw1MMDQXsDRSoTGxODwoIC\\\\\\n\",\n              \"tG3btpFaxTBMbViAMQyAIUOGIO7+34h//lzoMtu2bIKf34RqFxNmGKbxsf/zGAaAnJwc5s2bjxHD\\\\\\n\",\n              \"3PHhw4c6tz+4Pwg3oqPg4+P9H7SOYZjqsABjmP8zdeoUuLi4oJtzR8TGxFS7zadPn7ByxTIs/HUe\\\\\\n\",\n              \"QkNDoaGh8R+3kmGYz9g0eob5PxISEvjjjzUw32EOL093cLk8DBvuBR1dXRQVFeH2rZs4GnwITs7O\\\\\\n\",\n              \"uHHjBgwNDb92kxnmh8YCjGH+QUJCAuPH+8LHxxsXL17EyZOnEBV5BbKysmjZoiUePHgAPT29r91M\\\\\\n\",\n              \"hmHAAoxhqiUlJYXevXujd2/hlpZiGOa/x66BMQzDMGKJBRjDMAwjlliAMQzDMGKJBRjDMAwjlliA\\\\\\n\",\n              \"MQzDMGKJBRjDMAwjlliAMQzDMGKJBRjDMAwjlliAMQzDMGKJBRjDMAwjliSIiL52IxoDl8uFkZHR\\\\\\n\",\n              \"124GwzCMWHn58iUyMzO/djOE8t0GGMMwDPN9Y6cQGYZhGLHEAoxhGIYRSyzAGIZhGLHEAoxhGIYR\\\\\\n\",\n              \"SyzAGIZhGLHEAoxhGIYRSyzAGIZhGLHEAoxhGIYRSyzAGIZhGLHEAoxhGIYRSyzAGIZhGLHEAoxh\\\\\\n\",\n              \"GIYRSyzAGIZhGLHEAoxhGIYRSyzAGIZhGLHEAoxhGIYRSyzAGIZhGLHEAoxhGIYRSyzAGIZhGLHE\\\\\\n\",\n              \"AoxhGIYRSyzAGIZhGLHEAoxhGIYRSyzAGIZhGLHEAoxhGIYRSyzAGIZhGLHEAoxhGIYRSyzAGIZh\\\\\\n\",\n              \"GLHEAoxhGIYRSyzAGIZhGLHEAoxhGIYRSyzAGIZhGLHEAoxhGIYRSyzAGIZhGLHEAoxhGIYRS/8P\\\\\\n\",\n              \"fAbeq2LxvXIAAAAASUVORK5CYII=\\\\\\n\",\n              \"\\\"\\n\",\n              \"  frames[7] = \\\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAbAAAAEgCAYAAADVKCZpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\\\\\\n\",\n              \"AAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0\\\\\\n\",\n              \"dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeVyN6eP/8VebChUtKntT9lB2WbJv\\\\\\n\",\n              \"JcqWnRDGvswMZiwfZgyGsWWdyZrdIC0kRMQghCxjX0qlRft6Ovfvj/npO03ROUZyuJ6Px3m03Pd1\\\\\\n\",\n              \"39c5cb/PdZ3rum41SZIkBEEQBEHFqJd0BQRBEAThfYgAEwRBEFSSCDBBEARBJYkAEwRBEFSSCDBB\\\\\\n\",\n              \"EARBJYkAEwRBEFSSCDBBEARBJYkAEwRBEFSSCDBBEARBJYkAEwRBEFSSCDBBEARBJYkAEwRBEFSS\\\\\\n\",\n              
\"CDBBEARBJYkAEwRBEFSSCDBBEARBJYkAEwRBEFSSCDBBEARBJYkAEwRBEFSSCDBBEARBJYkAEwRB\\\\\\n\",\n              \"EFSSCDBBEARBJYkAEwRBEFSSCDBBEARBJYkAEwRBEFSSCDBBEARBJYkAEwRBEFSSCDBBEARBJYkA\\\\\\n\",\n              \"EwRBEFSSCDBBEARBJYkAEwRBEFSSCDBBEARBJYkAEwRBEFSSCDBBEARBJYkAEwRBEFSSZklXoLgY\\\\\\n\",\n              \"GxtTvXr1kq6GIAiCSnn69ClxcXElXQ2FfLYBVr16dUJDQ0u6GoIgCCqlSZMmJV0FhYkuREEQBEEl\\\\\\n\",\n              \"iQATBEEQVJIIMEEQBEEliQATBEEQVJIIMEEQBEEliQATBEEQVJIIMEEQBEEliQATBEEQVNJnO5FZ\\\\\\n\",\n              \"EAThY8nIyMhbvcLExAQdHZ0SrtGXQbTABEEQ3oMkSYSEhNDfdSBGxibYNm2OTZNmGBobM2jIUC5f\\\\\\n\",\n              \"vowkSSVdzc+aaIEJgiAoKT09nQEDBxEadgubHq5M2XkWXT0DANKSErhx4g96uvSjjV1LvHZsEy2y\\\\\\n\",\n              \"YiJaYIIgCErIycmhh2NPXqbDmA0+tHQZmRdeAGUMDLHrNwb3jX48eJWEc5++5ObmlmCNP18iwARB\\\\\\n\",\n              \"EJSwZOlSXmXIcZyxBE2tUm/dT0tbh97f/crj6ATWrFnzEWv45RABJgiCoKCcnBw81m2g3YgZaGgU\\\\\\n\",\n              \"/QmMhqYWbYdPY+UaD+Ry+Ueo4ZdFBJggCIKCfHx8MDCthNlXtRUuU7m2DerapQkMDCzGmn2ZRIAJ\\\\\\n\",\n              \"giAo6M9Ll6jaqI1SZdTU1KjWqA2XLl0qplp9uUSACYIgKCg5JZVSOqWVLldKtwwpKanFUKMvmwgw\\\\\\n\",\n              \"QRAEBZUvZ0BGarLS5TJTkyhXzqDoHQWliAATBEFQUKeOHXl0MVCpCcpyuZwHFwPp0KFDMdbsyyQC\\\\\\n\",\n              \"TBAEQUEdOnRAExnPb19VuMzjayEYlzOgRYsWxVizL5MIMEEQBAWpqakxfMgQfFfPIzsjvcj9M9NS\\\\\\n\",\n              \"CNqyjO++mYGamtpHqOGXRQSYIAiCgo4ePYqHx1osK5uyb94Y0pNfv3Xf1Nfx7P1hFI5dOjJkyJCP\\\\\\n\",\n              \"WMsvhwgwQRCEIsjlcubPn8+ECRPw8fHhfPBZnDq1Zf2oLviu/oGoR3fIycokOzODl/fD8V89lw1j\\\\\\n\",\n              \"ujLI2ZH1HmtF66uYqEmf6XLJTZo0ITQ0tKSrIQiCinv9+jVDhgwhNTWV/fv3Y2pqmrctMjKStvb2\\\\\\n\",\n              \"pKSlk5SQAIBZpcqMdhvBmNGjMTMzK6lqvzdVunaK1egFQRDe4tatWzg7O9OzZ0+WLVuGlpZWvu1l\\\\\\n\",\n              \"ypQh9tUroqKiKFOmTAnV8ssluhAFQRAKsXfvXjp06MD//vc/Vq5cWSC8APz9/WnXrp0IrxIiWmCC\\\\\\n\",\n              \"IAj/IJPJ+O677zh8+DCBgYHY2Ni8dV9vb2969er1EWsn/JMIMEEQhP/v1atXDBgwAG1tbUJDQzE0\\\\\\n\",\n              
\"NHzrvllZWQQEBIhbpZQg0YUoCIIAXLlyhaZNm9KqVSv8/PzeGV4AQUFBWFtb5xvUIXxcogUmCMIX\\\\\\n\",\n              \"z9PTk9mzZ7Np0yacnZ0VKiO6D0ueCDBBEL5YWVlZTJkyhbNnzxIcHEzt2ord50sul3P06FGCgoKK\\\\\\n\",\n              \"uYbCu4gAEwThixQZGUnfvn2pWLEily9fRk9PT+GyoaGh6OvrU7NmzWKsoVAU8RmYIAhfnODgYJo2\\\\\\n\",\n              \"bYqTkxMHDx5UKrxAdB9+KkSACYLwxZAkiTVr1tC/f3+2bdvG7Nmz32uZJ29vb3r37l0MNRSUIboQ\\\\\\n\",\n              \"BUFQWTKZDF9fX37ftoOXL6NR11DH0qI6491HY29vny+c0tPTcXd35/bt21y8eBELC4v3OuejR4+I\\\\\\n\",\n              \"j4+nWbNmH+hZCO9LtMAEQVBJ27Ztp2KVakyY/T9eGFij3XY4Wi2HcEduRv/hY7CsWYfAwEAAHj9+\\\\\\n\",\n              \"jJ2dHWpqaoSEhLx3eMHfra+ePXuiri4unyVNtMAEQVA5Py7+mV/XbqDuyMWUq1Y33zajWo2p2q4/\\\\\\n\",\n              \"sbcv0sd1EF+PGc3WrVv44YcfmDhxolJdhhkZGezfv5/TZ8+RnJKCgb4+f4acY8mSJR/6KQnvQQSY\\\\\\n\",\n              \"IAgq5cCBA6xYs57G0zejY2Bc6D5qampUsLZDZ8IaVvwymnWrV+Lu7q7wObKysvh+7jx++92TctXr\\\\\\n\",\n              \"oFfLDi29auSkpxCTrcWIUWOYePUa8+fNLXSNROHjELdTEQRBZUiSRK169SnfeSwmdVsoVOZp0D6+\\\\\\n\",\n              \"ynrK0cMHFdo/LS2NTt168DJDg696T6KMSeUC+6RGP+PR4VXUqFCWYz7eaGtrK/U8PmWqdO0UnbiC\\\\\\n\",\n              \"IKiMCxcukJCchnFtxQdQVGrhwKlTgbx8+bLIfSVJoq/rIGIkfaxH/VxoeAGUNatGffdfeJyYy9AR\\\\\\n\",\n              \"bgrXRfiwRIAJgqAy9h88iHHjrqgpMYBCS7csZg3a4OPjU+S+V65c4VLoNWoPmlPkOdQ1NKkzdB4B\\\\\\n\",\n              \"gSe5c+eOwvURPhwRYIIgqIzomFi0y1VQupyGnjHR0dFF7rdyjQfmrVxQ11BseIBGKR0qtnRijcc6\\\\\\n\",\n              \"pesk/HdiEIcgCCpDR1sbeXaO0uWyM9NZtGgRy5cvx9TUNO9hZmaW972xsTGHDh7A/qejSh27op0T\\\\\\n\",\n              \"u5cOZ+N6EWIfmwgwQRBUhnW9Opz3DgZclCqXE/WAw4cP07ZtW2JiYvIe0dHRxMTEcO3aNV68eIEc\\\\\\n\",\n              \"NUqVLafUsXXKm5KemkJ2djalSpVSqqzw34gAEwRBZYwYPpwFC3/EsneiwkGT9PwvcpJj6d69O5qa\\\\\\n\",\n              \"mhgYGBS6CG9SUhJmFQsftPFOkoSEhIaGhvJlhf9EBJggfIJyc3M5deoUT548ITc3F1NTU7p160aZ\\\\\\n\",\n              \"MmVKumolysTEBEdHR8JO76GG0/gi95ckiYd+m9FUkwgNDaVFi7cPvdfT00NTU4P0uEhKG1dSuE4p\\\\\\n\",\n              
\"UY8xrmAmAqwEiEEcgvAJSU5OZvHixVSrbsG0b2bjczqE4+eusGz1OipXqcqUKVN5/vx5SVezRP36\\\\\\n\",\n              \"y1KSwgJ5EXLknftJksRD73UYSaksWrgQFxcX3N3diY+Pz7dfYmIinp6edOzYkdzcXJ6fVWy+2BvR\\\\\\n\",\n              \"IYcZO2a00s9D+O9EC0wQPhGRkZF07tIVc4uazPp1CzXqNcy3PeblC/z2bqVJs2b4eHvTvHnzEqrp\\\\\\n\",\n              \"x3fnzh1+8/TkrwePyM3NpXUrO84c9+RV+EUsuw6nXPX/W05KksuJu3eFl2f2YKSVQ+DJAExMTOjb\\\\\\n\",\n              \"ty/z5s2jbt26/O9//8PMzIzdu3dz4sQJOnbsyJQpU6hRowYtWtvzVbcRlCpjUGS9spLjibxygnE7\\\\\\n\",\n              \"VxXn0xfeQgSYIHwCkpKS6NylK807OTHAfWqh6/WZVqyC2/R51GvUAseeTgSfPUOdOnVKoLYfz/Xr\\\\\\n\",\n              \"1/l68lTu3r1HjXa9KV/DHnUNDaJeRSBp6ZD04Cq3n91C28CI0iaVQZKTFPkIIwM9fpgyiWHDhqGr\\\\\\n\",\n              \"qwuAvr4+rq6uREVFMWnSJHR1dZk6dSqbNm2ifPnywN+ttgbWdbm8ehItpm9AU+ftXbY5acnc2jST\\\\\\n\",\n              \"6VOnULFixY/yegj5iQAThE/A6tWrqWRZ563h9U/N23UhatRkZnzzLf6+RU/OVVVBQUH07tMPW9cp\\\\\\n\",\n              \"9J+wGg2t/CP86jsO5+Xty4RsmEtfhy60s2+LhoYG1apVw9bWNu91fPDgAV5eXnh5eVGqVCmGDh3K\\\\\\n\",\n              \"kiVLCAoKYs6cOSQmJrJo0SLKlCnDhAkTSEtJplvrxpxeOZaq3cdQoX6rfPPC5LkyYsLO8uzYZgb3\\\\\\n\",\n              \"7c3C/y34mC+L8A9iLURBKGEymYwq1aoxb60XX9W2VqhMZkY6wzvbEnbtGtWrVy/eCpaABw8e0KyF\\\\\\n\",\n              \"HW0m/0Il63cvG5UaH43//GFs9lhNnz59AIiNjWXfvn14eXnx9OlTXF1dGTp0KI0aNcr3BiE+Pp7Z\\\\\\n\",\n              \"s2fj4+ODmZkZJiYm/PHHH5QtW5Z9+/axdMVKnj57gbG1HWqlSpObkUL0jWAM9PVp1bwJAwYMwMnJ\\\\\\n\",\n              \"6bMaPq9K104xiEMQSlhAQAAmZpUUDi8AHd3SdHTsh+eWLcVYs5KzZNkv1OjUv8jwAihrZEbLMQuY\\\\\\n\",\n              \"9cM89u7dS8+ePalRowZ//vknCxYsICIiglWrVtG4ceMCrVsjIyN+/PFHypcvz/Pnz8nOzubFixeo\\\\\\n\",\n              \"qanh6urK9SuXOB3gR3srQ1LDg4i+HkRV2zZUtXfhqbop3/64nIqVq/LD3HlkZGQU18shvIXoQhSE\\\\\\n\",\n              \"Evbo0SO+ql1f6XIWta15ePtSMdSoZCUlJbF//36cl797lOE/VW7QkjPrf+DXX39l0qRJ7N69Gz09\\\\\\n\",\n              \"vSLLPXjwgO7duzNo0CDmzZvHxo0bsbe3x83Njblz51K2bFm8j/pwyO84TYZ8S7XG9gWWmUp48ZAD\\\\\\n\",\n              
\"B9ZzLKAdJwOO532eJhQ/0QIThBImk8nQUHDtvX/S1NRCliMrhhqVrGPHjlGpbmPKGCq+5qGamhr1\\\\\\n\",\n              \"uw3CtnFThg4dqlB4/fnnn7Rt25bvvvuOhQsXoqmpycSJEwkPDycqKoq6desyxt2dDVu247DQC4tm\\\\\\n\",\n              \"HQtdI9GwihXtp61AXqEGDr16I5N9fn+TT5UIMEEoYaampryKeqF0uejI55iZKb+w7acuJiYGXSNz\\\\\\n\",\n              \"pcuVrVCJqOgYhfb19vamZ8+e/P7774wZMybfNlNTU3bs2MFvv/3G9h07aT9jNaXLFX7jzDfU1NRo\\\\\\n\",\n              \"Nuw7IuKSOXpUubUUhfcnuhAFoYQ5ODjw9YSJJMTGYGhiqlAZuVzOKe+97N+zq5hr9/FpaWkhyXOV\\\\\\n\",\n              \"LifPlREYeAIrKyvMzMze+ggICMDDwwN/f3+aNm361uNFRkZSvWELylf6SqHzq2toULPrIH5d44GL\\\\\\n\",\n              \"i3JrNQrvRwSYIJSwcuXK0bdvX44d3Mng8TMVKnM1JAh9vbLvXBpJVVlaWvL66Waly71+epeJX4/D\\\\\\n\",\n              \"fcwYoqOj8z0uXrzIy5cvuX79OnFxcaipqdG9e3fMzMwwNzcvNOhWeWzAsstwperwVfPO7N22hIiI\\\\\\n\",\n              \"CCpXfo91FQWliAAThE/ArO++paWdHdaNWtCweet37hvz8gW/fj8JTXU1zp8/T5s2bT5SLT+OTp06\\\\\\n\",\n              \"kZ0cT+yj25hY1lOoTE5WBg+Cj/LHqqtYWFhQo0aNfNuzs7Nxc3PDwsKC8PBwypcvT3x8fIGgi4yM\\\\\\n\",\n              \"5OrVq0RHR/Pg/l/Uc7NUqu4aWqUob16FyMhIEWAfgQgwQfgE1KhRgwP799O3f3+GTppDJ6f+aP1r\\\\\\n\",\n              \"bpEkSVy7cIY186ezcMF8LCwsGDBgAG5ubsyfPx8tLa2SqfwHpqGhwYTx49jluxX7yb8UObEb4N7J\\\\\\n\",\n              \"P2jerDkWFhYFtiUlJeHi4oK+vj6nTp2idOnSAFSoUIEKFSrQoEGDQo9ZuZoFRZ+5cJ/p9NpPjggw\\\\\\n\",\n              \"QfhEtG/fnpMnTjBtxky8PJbQuZcrFrWtUVfXIDriKSe991JauxQb13vQu3dv4O+llkaOHEmrVq3Y\\\\\\n\",\n              \"tWtXgZaHqpoyeTK79+7j+sH12Pb9+p0h9uzqWe4c/Z2Qc2cLbIuIiKBHjx60adOGNWvWKLVifMWK\\\\\\n\",\n              \"lUiMfIKBeTWFy+TmZPM66gWVKim+mr3w/sQoREH4hNja2nLm9CnOB5+lUrlS3Lt4gptnfNBIi2Xn\\\\\\n\",\n              \"Vk9uh9/KCy/4e8Scn58fw4YNw87Oji1btnwW7/719PQ4dSKAmCsBHPv5a2Lu3yzwvJJjIri8czmX\\\\\\n\",\n              \"fv8f/r5HC6wLGR4ejp2dHUOGDMHDw0Pp2524jxrBwyDlVqZ/cvkU1vXrU6VKFaXKCe9HtMAE4RNU\\\\\\n\",\n              \"u3Ztlv/yi0L7qqmpMXHiRNq3b8+gQYPw9/dn06ZNGBkZFXMti5eJiQk5mRlkPb/DpQ2zUNfVw7Ba\\\\\\n\",\n              
\"TdTUNUmLjeTVk7uMGD6Mw2suU7Vq1Xxlg4KCGDBgAKtWrWLQoEFKnzs5OZl79+7x5PoFEqOeUU6B\\\\\\n\",\n              \"Vpgkl3M/YBcr/jdH6fMJ70e0wAThM1GvXj0uX75MtWrVsLGx4dSpUyVdpf9k3bp1xMXF4X3kCC+e\\\\\\n\",\n              \"Pmb7htVMHtCD8c4d+GXuTKIiXrB65coC4bVnzx4GDBjAvn37lA4vmUzG5s2bqVWrFnFxcfzw/fcE\\\\\\n\",\n              \"LZ9MRlLCO8tJksTlnb9gaqBLr169lH6uwvsRLTBB+Ixoa2uzYsUKunXrxvDhwxk0aBCLFi1CW1u7\\\\\\n\",\n              \"pKumlKSkJObMmUPnzp1p1aoVAB06dHhnGUmSWL58OWvXruXUqVPUr6/c8lyBgYFMnz4dIyMjfH19\\\\\\n\",\n              \"ady4MQBZ2dn8Pn8oTYZ+SxWb1qj/qysyMfIJYQfXUyo1hsDAgM9mMI1KkD5TjRs3LukqCEKJio2N\\\\\\n\",\n              \"lXr16iXZ2NhId+7cKenqKGXMmDGStra2FBkZqdD+MplMmjhxomRtbS29ePFCqXPdvn1b6tGjh2Rl\\\\\\n\",\n              \"ZSUdPnxYksvlBfY5cOCAZG3TSDKuWFVq5Dxaajn8W6nZwCmSZaNWUnkjE+mb72ZJqampSp33U6VK\\\\\\n\",\n              \"107RhSgInyljY2MOHz7M119/Tdu2bdmwYYNKDPB4/Pgx27ZtY+bMmQrdKDIjI4O+ffty+/Ztzp8/\\\\\\n\",\n              \"r/D8q9jYWCZMmIC9vT2dOnXi9u3b9O7du9ARj3379uXW9av4HzmIY/2K2Opl0K6qDotmTiAq8gXL\\\\\\n\",\n              \"lvxMmTJvv/mlUDzE/cAE4Qvw119/MXjwYMzNzfH09KRChU93DcXWrVtz7949IiMji+z6jIuLw8nJ\\\\\\n\",\n              \"CQsLC7Zs2aJQV2lWVhZr165l6dKleavQq/qAlw9Jla6dogUmCF+AWrVqceHCBaytrbGxseH48eMl\\\\\\n\",\n              \"XaVCnT59mkuXLvHbb78VGUaPHz+mVatW2Nvbs3PnziL3lySJgwcPUqdOHc6dO8f58+dZvXq1CC9V\\\\\\n\",\n              \"VrI9mMVHlfpxBeFjCgoKkqpUqSJNnjxZSk9PL+nq5MnNzZUqVaok2draFrnvlStXJHNzc8nDw0Oh\\\\\\n\",\n              \"Y1+6dElq1aqV1LBhQ+nkyZP/taqfNVW6dooWmCB8Ydq1a8eNGzeIjo6mWbNm3LxZcJJwSVi7di0x\\\\\\n\",\n              \"MTHs3bv3nfv5+/vTvXt31q9fz4QJE96574sXLxgyZAi9e/fGzc2Nq1ev0rFjxw9ZbaEEiQAThC9Q\\\\\\n\",\n              \"YmIiVatVIzIqGhsbG7S0tKha3YIF//sfL1++/Oj1SUtLY86cObi6ulKzZs237vf777/j5ubG0aNH\\\\\\n\",\n              \"861I8m+pqanMnTsXGxsbvvrqK+7fv4+bm5vSq3EInzYRYILwBZHJZIz/egK2jZvwV1QSP2w6gNel\\\\\\n\",\n              \"J+y4+JCJSzdx8fYjatety9y58z5qq2zy5MlIksT69esL3S5JEvPnz+fnn38mODiYli1bFrpfbm4u\\\\\\n\",\n              
\"np6e1KxZk6dPnxIWFsbChQspW7ZscVZfKCFiIrMgfCHkcjkDBw3m8ctXrPIOobSefr7t1WtZM2rO\\\\\\n\",\n              \"EvqOm8mKaSNJeP0aj7VrFFoN/r949uwZ27dvZ/ny5ejp6RXYnpOTg7u7O+Hh4Vy4cAFT08Jv+nnq\\\\\\n\",\n              \"1ClmzJiBnp4e3t7e77xZpSqQyWRoaGgU++uvykQLTBC+EKtWreLOo6dMW+FZILz+ycDQmFkeu/A/\\\\\\n\",\n              \"EVjk51EfgqurK+bm5kyePLnAtpSUFBwdHYmNjeXMmTOFhtdff/2Fk5MT7u7uzJ07l+DgYJUML0mS\\\\\\n\",\n              \"uHTpEoOHDkNP3wBtbW20tLSoXc+atWvXkpycXNJV/OSIABOEL0Bubi4rV61myPQFlNLWKXL/0nr6\\\\\\n\",\n              \"9J84m0U/LS7Wep0+fZrLly+za9cu1NXzX46ioqKwt7enWrVqHDlypMBE4fj4eCZPnkzr1q1p27Yt\\\\\\n\",\n              \"d+7coU+fPirZYklISKB9x0649HcFo6os++MMOy4/ZUvIffpMnsde30CqVqvG7t27S7qqnxQRYILw\\\\\\n\",\n              \"BTh+/DhlyhlhWa+hwmVsW3Ug8mUUkydPJjc394PXSZIkhgwZkhdA/3T37l1atmyJi4sLmzZtQlPz\\\\\\n\",\n              \"/z7tyM7OZuXKldSuXRu5XM6dO3eYOXOmyq33+EZiYiKt29qjV7kGS/84i8OwcRgYmaCmpoamVinq\\\\\\n\",\n              \"NbVjws/rmbVxP9Nmfsvvnp4lXeVPhggwQfgCnDp1Glv7rkqVUdfQoE0PZ3x8fOjUqRMvXrz4oHVa\\\\\\n\",\n              \"uXIlsbGx7Nu3L9/vz58/T7t27ViwYAE//PBDXotKkiQOHz5MvXr1OHnyJMHBwXh4eGBiYvJB6/Wx\\\\\\n\",\n              \"jXYfS7X6TXCd8n2BVug/Va1Rh288dvHtd7O4ffv2R6zhp0sEmCB8ARKTkymrb6B0udJ65Rg4cCCd\\\\\\n\",\n              \"O3emSZMmHDhw4IPUJyMjg++//57x48djZmaW9/uDBw/i4uLCzp07GTFiRN7vr169Srt27Zg3bx7r\\\\\\n\",\n              \"1q3Dz8+vwA0sVVFERASBJ07Qb8Ishbo+K1a3pEOfoaxZ6/ERavfpEwEmCF8A/bJlyUhPU7pcVnoq\\\\\\n\",\n              \"BgYGzJkzB19fX77//ntGjhxJSkrKf6rP2LFj0dLSYsWKFXm/W7VqFVOnTuXEiRN06dIFgMjISEaM\\\\\\n\",\n              \"GIGjoyNDhgzh+vXreds+Bxs3bcKue290Siu+EHA754Hs3btXDOpABJggfBHs7FoSfiFIqTKSJHH9\\\\\\n\",\n              \"3ElatGgBQNOmTbl27RqamprY2try559/vlddnj59yq5du/Dw8EBLSwu5XM706dPZvHkzISEh2NjY\\\\\\n\",\n              \"kJaWxoIFC2jQoAEVK1bkr7/+YsyYMfk+C/scnDodRCP7bkqVMaxgTpWvanDt2rViqpXqEAEmCF+A\\\\\\n\",\n              \"3r178/LpQyIe3Ve4zO3QCyTEvWLOnDns27ePnJwcypYty2+//cayZcvo1asXixYtQiaTFSibmZmJ\\\\\\n\",\n              
\"l5cX9h06U7teA+o3bEx/10GcP3+ePn36YGlpybBhw8jMzGTgwIGEhoZy/vx5qlSpwrZt26hVqxb3\\\\\\n\",\n              \"79/n2rVrLF68GH39tw/7VzXZ2dlEREQQGhpKdHTMO6c0vE0ZfQPRAkNMZBaEL0KpUqUYN24se9cu\\\\\\n\",\n              \"ZvoKzwJ3Ff637KxMDq5fxtKfF2Nubs6aNWuYMWMG48aNw93dHRcXF5o3b87w4cMJCAhg586dWFhY\\\\\\n\",\n              \"IEkSvyxfwY8//oSargkZOhaolaoJ2XLuh0Tg69+bjLRktnpuJiEhgd69e2NmZsaJEyf4888/mT59\\\\\\n\",\n              \"Ojo6Ohw8eDCv5acKcnNziYuLIyYmhujo6HyPf/8uOTmZChUqYGpqSnpGOpnpqUqfLyMttdBJ318a\\\\\\n\",\n              \"cT8wQfhCZGdn06Vbd9TKlGfMvOVoapUqdL/M9DSWTx9FfORT/rp7J28Zpps3b7J27VoOHjyIk5MT\\\\\\n\",\n              \"kyZNolGjRqxcuZIlS5bw66+/EnLhT3btP0J2hfao65QrcGxJkpAnPUEz9jzGhuXo27cvY8eOZdas\\\\\\n\",\n              \"WVy/fp2lS5fSv3//T2IulyRJJCYmvjOM3vwcFxdHuXLlMDMzy3uYmpoW+rORkRHq6upIksTQYcNI\\\\\\n\",\n              \"kOswZMZ8heuV/Dqeb13sefrkMYaGhh/8eavStVMEmCB8QdLT03EdNJhrYTfo4DKEdr0GUNagPACJ\\\\\\n\",\n              \"ca8IOrKb04d20a1rV3JzsomMjMTHxyffWoIJCQl4enqybt06KlasyKRJk7CysqJnTyfiU7LRsHRB\\\\\\n\",\n              \"TePdc7LkadHw/BiuA/ri5+fHN998w5QpU9DRKXqS9X+VmppaaBj9+3cxMTHo6uq+M4zePExMTNDS\\\\\\n\",\n              \"0lLo/HFxcXh5eeHp6UlSUhKvk5JZe+wKpXR0FSrvu309mklR7Ni29b+8DG+lStdOEWCC8IV5s2TR\\\\\\n\",\n              \"Go91HDl8GN3SpZEkiZzsbFxdXZk44WsaNmxIbm4uY8eO5e7du/j7+2NgkH8Yvkwmw8fHh7Vr13Lv\\\\\\n\",\n              \"3j0Sk1LJrdwd9dKKzcuSRV3Csnw6wWeD/vMdorOyst7aOvr3z3K5PF/4vC2gTE1N0dVVLFSKkpub\\\\\\n\",\n              \"S2BgIJ6engQGBtKzZ09GjRpF27ZtcXRyQtfcij7jZxZ5nLioSBa59ea4vy+NGzf+IHX7N1W6dooA\\\\\\n\",\n              \"E4QvWFZWFgkJCairq2NoaFigFSGXy5k4cSKhoaEEBARQvnz5Qo+zevVqvp27BHXLPgqfW8pJQ/3J\\\\\\n\",\n              \"QaKjIgsdpCGTyYiNjVWotZSWlpYXQkW1lsqWLfvRuigfP37M1q1b2bZtG2ZmZowaNYqBAwfmezMQ\\\\\\n\",\n              \"ExND8xYtses5AIfhX7+1bq8in7NiyjCmTBjPzBkziq3OqnTtFIM4BOELpq2tjbm5+Vu3q6urs27d\\\\\\n\",\n              \"OqZPn07Hjh05ceIExsbGBfY7F/IncoNaSg1rVtMqg1rpCkyaNAlTU9MC4ZSQkIChoWGB1lLVqlVp\\\\\\n\",\n              
\"1qxZvoAqX778O1ex+JgyMjI4dOgQnp6e3Lp1i8GDB+Pn50eDBg0K3d/U1JSQ8+fo4diTq0HHaddn\\\\\\n\",\n              \"CC27OOV1KT5/cJegP3byZ6APPy5cyKRJkz7m0/mkiQATBOGd1NTU+PXXX5kzZw7t27fn5MmTBVaF\\\\\\n\",\n              \"j46JQU1L8cm4b+RQioiICOrWrUu9evXytZaMjY1VZt6XJElcu3YNT09P9u3bR9OmTRk/fjxOTk4K\\\\\\n\",\n              \"rdFYqVIlroVewc/Pj2W/rGDL4tmULl0GmUyGvr4+Y8e6s3XVbSpWrPgRno3qUI1/HYIglCg1NTUW\\\\\\n\",\n              \"L16MtrY27dq149SpU/kupjraOiApv+CvPCebs2fP8uDBA8zNzTE3N6dixYqFfm9iYvLJ3VE5Pj6e\\\\\\n\",\n              \"Xbt2sWXLFpKSkhg5ciTXr1+natWqSh0nMjKSTZs3s3nzb2jp6GJZqx5pqcmkp6YwbNhQRo4YIcKr\\\\\\n\",\n              \"ECLABEFQiJqaGgsWLEBbWxt7e3tOnz6NiYkJQUFBJMS9QkrLhnJWCh9PkiR01VLxO32aqlWrEhUV\\\\\\n\",\n              \"xcuXL4mKiiIqKoqQkJC87whcyIIAACAASURBVKOiokhMTMTExKTIoDM1NS3WlptcLufUqVN4enpy\\\\\\n\",\n              \"/PhxevTowYoVK2jfvv17dWN6e3szws2Nll2c+MZjF1Usa+Vte/n0Iaf+8MLGthEea9cwePDgD/lU\\\\\\n\",\n              \"VJ4IMEEQlDJ8+HCuX79O7dq1UVdXx9bWls6dO3Jn7XokuR1q6opdVuSpEaSlJHLgwAEmTpxIy5Yt\\\\\\n\",\n              \"37l/dnY2MTExBYLuypUred+/fPmS+Ph4jIyMigw6MzMzSpUqfC5cYZ4+fcq2bdvYunUrxsbGuLm5\\\\\\n\",\n              \"sWHDhrcObFGEn58fo8a48+0ar0JvdVOxuhVDZyzAvpcr0yYNQVNTkwEDBrz3+T43YhSiIAjvJJfL\\\\\\n\",\n              \"CQ0NxdfXFz8/P54+fUq3bt3Q0tLi5MmTBAUFUaNGDezbd+TC/Uw0K9gWeUxJkqP+3A8D7WzU1dVJ\\\\\\n\",\n              \"T0+nadOmTJw4EQcHh//UVSiTyXj16lWBoPvn91FRUcTExFCuXLkCwfbPnw0NDbl8+TI7d+7k+vXr\\\\\\n\",\n              \"DBw4kFGjRmFjY/Pe9XsjLS2NKlWrMn3lNmo2KHpI/NO/brN43AAeP3pYLBOY31Cla6dogQmCUEBK\\\\\\n\",\n              \"SgqBgYH4+vri7++PkZERjo6OrFq1ipYtW+Z10f3222+0b9+eI0eOkJWRhizqKmqaZdAwrPnWY0uS\\\\\\n\",\n              \"HI2Y89haW3AqMIBDhw4xd+5cIiMjmTNnDpMnT2b8+PGMGjWq0BGPhcnKyuLQoUOs3bCJx48eIZPl\\\\\\n\",\n              \"YGxSgaGDXBkzenSh88zeLP/073C7e/cuhw4d4u7du8TGxiJJEqVLl6ZKlSrcu3ePX3/9tdAWnbm5\\\\\\n\",\n              \"eYG7Rr/Lrl27qGXTTKHwAqheqx62rTuwbds2pk+frvB5PmeiBSYIAgCPHj3C19cXX19fLl26hJ2d\\\\\\n\",\n              
\"HY6Ojjg4OGBhYfHWcitWrOC7775DXV0dPT09snJyoWw1cvRq55vULEly5ElP0Em9g01dS/x8vfPW\\\\\\n\",\n              \"88vJyWHr1q0sXLiQGjVqYGBgwJkzZ+jduzcTJ06kSZMmbz3/wYMHcR8/AYNKllRu64Jh9bqoa2iQ\\\\\\n\",\n              \"FhfFi5CjPL9yilGj3Fi5Yvk7W3avX79m9+7deHp6Eh8fz8iRIxkxYgRVq1YlISGhyBbdy5cv86Yl\\\\\\n\",\n              \"vKvrsmLFiujp6dHAxpaeY7+lYUt7hf9G92+EsmXRDB49uF9sc9lU6dopAkwQvlA5OTmEhITkdQ0m\\\\\\n\",\n              \"Jibi4OCAo6MjnTp1yrd81NsEBwfj5OREZmYmMpmMMWPGsHDhQjZu3MTqtR7kyDWRqekgz82B7CRk\\\\\\n\",\n              \"shy2/LaRgQMHFjrQIiMjg/Xr17N06VLatWuHhYUF+/btw9TUlAkTJtC/f/98y015em5h5uzvaT5h\\\\\\n\",\n              \"OUZf1Su0jlkpiVzZPAfbryryx4F9+UJMLpcTFBTEli1b8PPzo1u3bri5udGxY0eluzHfrJ2oSND9\\\\\\n\",\n              \"/Vwz8br0GE0Fl6B6c44RdjWIi4tVqrWnDFW6dooAE4QvSFxcHMePH8fX15cTJ05gaWmJo6Mjjo6O\\\\\\n\",\n              \"2NraKjWKbvPmzcyYMQNdXV0sLS3R09Pjxo0b+Pj40KxZM2QyGZcvX2bXrl08ePCA1atXM2zYMH78\\\\\\n\",\n              \"8Ue6du36zmMnJyezcuVK1qxZQ//+/WnevDl79uwhLCyMUaNGMW7cOF69ekXHLt1pO+s39M2rv/N4\\\\\\n\",\n              \"uTnZXFg5Ebd+PVkwby4vXrzIG5Chp6fHqFGjGDx4MEZGRgo///clSRKxsbFUqlSZXVeeKl3evUN9\\\\\\n\",\n              \"Hvx1DxMTxZbsUpYqXTvFZ2CC8BmTJInw8PC8rsHw8HA6duyIo6MjK1eufOcqHG+Tk5PD1KlT2bt3\\\\\\n\",\n              \"L6amptSvXx9Jkvjjjz84duwYjo6OHD58mFatWmFnZ8eTJ094/fo1derUYcSIEWzbtq3IANPX12f+\\\\\\n\",\n              \"/PlMmDCBJUuWMH369LzW3Z49e7C1taW0ngFW3YcVGV4AGlqlaDh0Dit+GsmF8+e4evUqAwYM4MCB\\\\\\n\",\n              \"AzRq1KjYl5aSJImIiAjCwsK4ceMG169fR5Ik0lOSlbofWE52FmmpKQXWpfxSiQAThM9MRkYGQUFB\\\\\\n\",\n              \"+Pn54evri4aGBo6OjsyfPx97e3uFVoZ4m/j4ePr06cPDhw+xsrKiWbNmXLt2jZMnT+adx8vLC2dn\\\\\\n\",\n              \"Z/bv30+7du3Q1dUlIyMDAFdXV+bMmUNiYiLlyhW83cq/GRsbs3z5cqZOncqPP/6Ig4MDU6dO5cyZ\\\\\\n\",\n              \"MzRt0ZIWbXopXHd98+qUNquOlZUV3t7eH2yh3n/Lzs7m3r17hIWF5T1u3LiRdyfrhg0b0rdvXxKT\\\\\\n\",\n              \"UggJ8KZz36EKH/vSKX/atLVXavj/50wEmCB8BiIjI/MC68yZM9ja2uLo6Mjx48epXbv2B2lhhIeH\\\\\\n\",\n              
\"07NnT9TV1WnYsCH29vZs3bqVkJCQfGHQpUsX9u3bR//+/dm1axc6Ojp5AWZkZJS3fezYsQqfu3Ll\\\\\\n\",\n              \"ymzcuJGZM2cyf/58li9fjlntxpQqo9zdjC3aOPEq/v4HC6/Xr19z48YNbty4kRdW9+7do3r16tjY\\\\\\n\",\n              \"2GBjY8O3336LjY0NZmZm+cqamJgwbtJUOvUZovDf5/TBHfw4d/YHqfvnQASYIKgguVzOlStX8roG\\\\\\n\",\n              \"nz9/Tvfu3Rk0aBDbt2//T5NrC3P06FHc3NwwMDCgbdu2dOrUiVmzZnH+/PlC5yS1b9+eQ4cO4eLi\\\\\\n\",\n              \"wrRp08jMzMzbNmLECH788UelAuwNKysrdu3axQ8//MDe4JtKl9fWNyTheaLS5SRJ4unTp/laVGFh\\\\\\n\",\n              \"YcTHx9OgQQMaNmxIy5YtGT9+PNbW1pQuXbrIY3bo0IEyOqXw3bGRnsPHF7n/yYM7yUpJxNHRUen6\\\\\\n\",\n              \"f65EgAmCikhOTubEiRP4+fnh7++PiYkJjo6OrF27lhYtWhTL8kmSJLF48WLWrl1L2bJlGTBgAJ06\\\\\\n\",\n              \"dcLV1ZWTJ09SrVq1t5Zt3bo1Pj4+dO/ePd/giK5duzJ69GguXLhAQkICKSkp6Onp0axZM4XvC1a1\\\\\\n\",\n              \"alXU5NeUfj65WZnol3l3uGRmZnL79u18raobN26gp6eX16oaPHgwv/zyC5aWlu+9Cr66ujq+R71p\\\\\\n\",\n              \"adeKXFkOPYd/jUYhf0N5bi4B+7ZyzGsjIefOqcwCxx+DeCUE4RP28OHDvFbW5cuXadWqVd7nWdWr\\\\\\n\",\n              \"Vy/Wc6enp+Pm5kZ4eDjq6upMmzaNdu3a0blzZ/bt2/fW24P8U/PmzVm/fj1Dhw5l3759DBgwgKtX\\\\\\n\",\n              \"r2JgaEy7jp2pUKMhGjplkWelkfA4nK7duvHN9KkFlpVKSEggJCSE4OBggoODuXnzJmjp0CRXhrqG\\\\\\n\",\n              \"4pex+HtX6Na+Ud7PsbGxeUH15uubz/fehFWvXr1o2LChwpOqlVGlShUu/XmR/q4DmXbIi/bOg2jS\\\\\\n\",\n              \"vjtl9PRJT03lWnAgQYe8MDcz5WJIyDvn432JRIAJwickJyeH8+fP54VWSkoKDg4OTJo0iY4dOyo0\\\\\\n\",\n              \"N+tDePHiBb1798bY2JjY2FhWrVqFnZ0drVu3xsPDg/bt2yt8rEaNGmFubs60adM44u2N/4lTGLXq\\\\\\n\",\n              \"j23/pWiW1svbr2JGCjeuHqdbT2cmjB1Nwwb1CQ4O5ty5czx9+pQWLVrQuHHjvFXps3NyiAw7R5XG\\\\\\n\",\n              \"itUlOy2Fp38GEGtdCQcHB27cuEFqaioNGzbExsaG9u3bM23aNOrWrZtvrllxq1SpEiHngrl+/Tpr\\\\\\n\",\n              \"PdaxftZYUpKTKaunRys7O478cYCmTZt+tPqoEhFgglDCYmNjOXbsGH5+fpw4cYIaNWrg6OjInj17\\\\\\n\",\n              \"sLGx+eg3arxw4QJ9+/alc+fO+Pv7s3v3bho1akTr1q2ZOXMm/fv3V+p4Ojo6yOVy3EaNYvnaTVhP\\\\\\n\",\n              
\"WI92OdMC+2nq6mHWuh/lre35df3X1Khiyojhwxk2bBgvX75k0aJFLF26FENDQ6ZMmUKNGjWY89Ny\\\\\\n\",\n              \"zOu3RLNU0YFzy3szuqVLU65cOcaMGYONjQ3VqlVDTU2N+Ph4tm7dypTpM3n9+jW6pUvTpJEtE78e\\\\\\n\",\n              \"T716hU+Q/tBsbW3Z4vn7RznX50JMZBaEj0ySJG7dupXXyrp9+zadOnXC0dGR7t27Fxit9jFt3bqV\\\\\\n\",\n              \"7777DldXVw4cOICPjw9169alU6dOtG7dmmXLlil9zNjYWGrVqkWWLJfaX29Gx6jo+1plJkRx12MM\\\\\\n\",\n              \"g137c+DAAVJSUmjSpAmLFy+mY8eOwN8DWZycXbj5LI6WE39BU7vwkYWSJHHXfzuRZ/ZjbmpC1apV\\\\\\n\",\n              \"2bJlCyYmJmRlZTF56nR2795FnZYdqNPWgTIGhuRkZfAk7CLXju2nbt067Ni6ha+++krp566KVOna\\\\\\n\",\n              \"KQJMED6CjIwMTp8+nTfUXUtLK28FjLZt2/6nuVmKyszM5MKFC8THx6Ojo0OdOnWwsvr7/l0ymYxv\\\\\\n\",\n              \"v/2Wo0eP0qNHD3x9fQkICMDCwoI+ffqgp6fHjh073qs1+Pr1a0wqmFLBtiMW/eYoXO7+nkWk3DnH\\\\\\n\",\n              \"6NGjmDdvHnFxcfk+qwoLCyMnJwed0mVJzZZRs+sQLFr1pFSZv7sl5bkyXoad49mZA6THPEOWlYGT\\\\\\n\",\n              \"kxO6urocPXqUTZs2sWz5CpLRwWHyQsoYFBxNmSvL4ZL3Ti4f2sLZoNPUqVNH6eevalTp2im6EEtA\\\\\\n\",\n              \"amoqCQkJaGtrY2RkJEYVfaYiIiLyAuvs2bM0atQIR0dHTpw4Qa1atYp99Yc3nj9/zpq1HmzZuhWT\\\\\\n\",\n              \"StXQN65ATnY2z+7eoH79+oxxG8mOHTuAv4e/BwcHc+HCBUxNTRk3bhwZGRkcOHBA4fDKyMjg0qVL\\\\\\n\",\n              \"eQMuLl26hKSugUkLF6XqXcnelfsPL3P58mUsLCyoWLEiNjY2NGzYkEmTJmFjY0OlSpUAOHfuHKvW\\\\\\n\",\n              \"rsP/G0f0TczQ0NAiOT6ar776ikUzJtGvXz8yMzNZvnw5GzZswN7enkFDhlGtflMGzF2N+lvWPdTQ\\\\\\n\",\n              \"1MKujxu6euXo0q07d2+Hf7TPIYWiiRbYRyKTyTh69Cjr1q3n4sULGBoakZmViYa6Om5ubowbN+6d\\\\\\n\",\n              \"Q5KFT19ubm6+uVkRERF0794dBwcHunbt+sHnZini9OnT9O0/AJtOvWjecxDGlavnbZNlZ3EzOIDA\\\\\\n\",\n              \"basx1CtN3dq1eP36NUeOHMHAwICFCxfi7e3NmTNn8laNL0xSUlLeCMFz584RFhZGgwYNaNOmDW3b\\\\\\n\",\n              \"tqVZs2aYmpnRcskZpUJbkiQuzelI4IkAmjZt+s46vJGQkEBERAQ5OTlUqFCBKlWqFNgnJiaG2bNn\\\\\\n\",\n              \"47VnL7MP/Im2rmKL4u5fOIHxg5zfa/6aKvnUrp3vIt76fwTPnz+nRw8HyujpMXL0eLbv887rMnr4\\\\\\n\",\n              
\"4C+2e27GtlEj5syezYwZMz7aO3Phv0tKSso3N8vU1BRHR0fWrVtHixYt/tONGf+ry5cv06dff1zn\\\\\\n\",\n              \"rsHKtkWB7ZqltGnUyYn6bbqw9Xt3roXd4O7tcHR1dfn999/Zvn07ISEhBYIjJiaGc+fO5QXWw4cP\\\\\\n\",\n              \"adq0KW3btmXRokU0b948b6X09PR0bt26hZq6htL/rtXU1EBdg0WLFlGhQgUMDQ0VfryLqakpZuYV\\\\\\n\",\n              \"ada9n8LhBdDIYSCr1i7H3d1d/B/9RIgWWDGLioqipZ0dbu5fM37itLfuFxnxgoF9HBk2dAizZ4ul\\\\\\n\",\n              \"Yj5lDx48yGtlXblyhdatW+fdN+tTaUVLkkS9+g1p0s+dhu16FLl/TlYm6yf0YcOq5cjlctzd3QkO\\\\\\n\",\n              \"DsbKyopnz57lhVVwcDCvXr2iVatWtG3bljZt2lCjRg2eP3/Ow4cPefjwIY8ePcr7PiEhgerVq3Pv\\\\\\n\",\n              \"/gOazvNGq7TiSz/JMlK5uqgXvj5HSUxMJCEhociHlpYWhoaGGBkZvTPgps74BseZy6lcu+i5bG/I\\\\\\n\",\n              \"5XJWDm5D6KWLn/V8rE/l2qkI0QIrZqNGjab/wGHvDC+ASpWrsP/IMbq2a0nHjh1p1qzZR6qhUJTs\\\\\\n\",\n              \"7Ox8c7PS0tJwcHBgypQpdOzYsdjuy/RfXLhwgeS0dBrYd1dofy1tHez6jGThT4u5f/cOY8eOZcGC\\\\\\n\",\n              \"BQQHB5OTk0OLFi2wtLRk+PDhZGVl8eTJE44cOcLy5ctJS0vDysoKKysrLC0tad68OYMHD8bKygpd\\\\\\n\",\n              \"XV2WLl3K/UcexF45RkX7AQo/h7hrATg69aJbt24K7S9JEmlpaW8Nt7i4OO7fv09CQgKxr16hZ6zY\\\\\\n\",\n              \"qh9vqKurY2Bcgfj4+M86wFSJCLBi9PDhQ66EXmHT9n0K7W9mXhH3CVPw8FjHjh0iwErSm7lZvr6+\\\\\\n\",\n              \"BAYGUrNmTRwdHdm3bx82NjaffBeSx/oNNHEYqFQ9bTo4cnj1AkppanDs2DG0tbUpV64cL168ICgo\\\\\\n\",\n              \"iIiICCwtLbGysqJ9+/aMGTMGKysrzMzM8p1HkiQCAgLo168fV69eRVtbG91SmiRc8ca8TT/UFBgM\\\\\\n\",\n              \"IsnlJIYeZfrubQrXX01NjbJly1K2bFmqVq1aYHt2djb3798nPDyc02eDyc3OVvjYb+RkZ32UEaOC\\\\\\n\",\n              \"YkSAFaMNGzYwcPBwpWb1Dxw8nOY2tYiLiyuWpWtUkSRJeWvmlSlTBmNj4w8eIJIkcfPmzbxW1t27\\\\\\n\",\n              \"d/PmZq1duxZT04ITbz9ld+7ew36McqP+tLR1MK1mSc3KprRu3TqvVWVlZYWRkVGRr3l8fDwLFy5k\\\\\\n\",\n              \"+/btpKSkYG1tzZ49e3BxcaFJkyZkySSeHF2NRa+p7zyWJEm8PLEZq2qVaNOmjVLPAf7u6nvy5Am3\\\\\\n\",\n              \"bt0iPDw87/Ho0SOqV6+OtbU1JiYVeH7nGoYVCwbd26S+jicxNuaT6SYWRIAVq0uXLvPN9/9Tqoyh\\\\\\n\",\n              
\"kRHWDRpy48aNvAmbX6qUlBS8vLzwWL+eF8+eo2dgQFpqKkZGRnw9fhwjR44s8gP7d0lPT+f06dP4\\\\\\n\",\n              \"+vri5+eHtrY2jo6OLFq0iLZt26r0PZeys7OUulX9G+WNjJk+fTo9ehT9uRn8HTa+vr4sWLCAsLAw\\\\\\n\",\n              \"ypQpw5AhQ5g/f36+0JckidjoSNQT4ok4uhLzLmPQ1C04qlCWkUJUoCc6cffwPx9cZNBFRUUVCKq7\\\\\\n\",\n              \"d+9ibGyMtbU11tbWODo6MmvWLGrXrp33ZtLb25uZ837EplNvhV+bawEHcHZ2Rl9fuVu4CMVHBFgx\\\\\\n\",\n              \"Sk1Nfa/PRzQ0tdi5cyfPnj3D2Ng436NcuXIffWmhknDhwgWcnV1o0Lg5E2f9RFM7e9TU1P5exeL6\\\\\\n\",\n              \"Ff7w+p2fflrMzp07cHBwUPi4L168yJubFRwcTOPGjXF0dOTkyZPUrFnzk+8aVJSRkTFJsTFUqqHc\\\\\\n\",\n              \"MkiJr6LzrRz/NrGxscydO5fdu3eTlpZGo0aN8iZB//s1DAkJ4c6dO8yYMYPZs2fjPn4C3ksHoF/b\\\\\\n\",\n              \"DsMGHdDULYMsI420+xeJuxmEg4MDv/teyHfX4YSEhHwh9eahpaWVF1R2dna4u7tTr169IkPGwcGB\\\\\\n\",\n              \"cRMm8ujaBSwb2RX5fNOTXhPqs5sAv6NF7it8PCLAipG+vj6Jicrfeyg5KZHMTGPOnz9PXFxcvkdK\\\\\\n\",\n              \"Sgrly5cvEGxvHkZGRgV+p6+vr1IX5j///JOeTr1YsGITdvad8m1TU1OjQaNmNGjUjFvXrzDSbRBb\\\\\\n\",\n              \"t2x5a4jl5uZy+fLlvK7Bly9f0r17d4YNG4aXl5dCdwVWRb17OrDt8F7q2nVQuEzE/XByMtNo0qRJ\\\\\\n\",\n              \"odslSeKPP/5g0aJF3Lp1i3LlyjF69Gjmzp371jluPj4+uLm5YWtri52dHfr6+uzdtZOhQ4fyKjaW\\\\\\n\",\n              \"9GcnSU5OxkhPD9eu9gzd+Svx8fEcOnQoX1C96ZJ88+jXrx/16tVT+PYr/6apqYnX9m30HTAQ1/9t\\\\\\n\",\n              \"pErthm/dNyMlib0LxjF8yCAaN278XucTiocIsGLUoUN7/H0OY99e8a7Al5ER3Ltzm3p1auPs7Ezn\\\\\\n\",\n              \"zp3zfYYmk8nyRlT9+xEREUFYWFiB32dmZhYabO96lC5dukRCLzs7G5c+fZm3bH2B8Pq3+rZNWbph\\\\\\n\",\n              \"F0OHDuDx40d5YZSUlERAQAC+vr4cO3YMc3NzHB0d2bBhA82bNy/RuVkfw4MHD9i4cSMvIl+SFBeD\\\\\\n\",\n              \"gbFin99dPrqbr8eNLfD6REdH8/3337N//37S09Np0aIFJ0+epEOHd4fj1q1bmT17Nn5+fvzyyy95\\\\\\n\",\n              \"d2XOzs7G39+fPXv25GtZ7di+jaVLfqZWrVp5QTVlyhSsra2pUqXKB//32LFjR3Zs9WTYiJHYdOlD\\\\\\n\",\n              \"Y4eBGJr/38TnrPRUwk56c+nwVvq79OKXZUs/6PmF/07MAytGL1++xNramtBbD9FTsN986U//Izbq\\\\\\n\",\n              
\"OY0aNeLw4cNcv36drl274uzsjIODw3v1v2dlZREfH19o6BX2iI2NBVAq8IyNjT/I6Kz9+/fz6+p1\\\\\\n\",\n              \"rNvlo3CZH6a40ah+bYwMDfH19SU0NJQ2bdrkzc0qbETa58rPz48BAwchlyR09A3RNzRm3IqdaBXx\\\\\\n\",\n              \"twk/dwL/dQsJv3kDExMTJEli7969/PTTT9y5cwdjY2NGjx7NnDlzilxKSZIkli1bxoYNG/D390db\\\\\\n\",\n              \"W5vRo0djaGhIqVKluHDhAhEREdSoUQNra2vq16+fF1iWlpYffWm1x48fs9ZjHdu2b8e4UjXKljMi\\\\\\n\",\n              \"JyuT53/dol379kydNLHIsP6cfArXTkWJACtmQ4YMRaOUDr+sWl/kO8i/7t3B2aETwWfP5i0aGhsb\\\\\\n\",\n              \"y9GjRzl06BDnzp2jTZs2ODs74+Tk9N7dJ4pIT09XOPDePLS1tZUKPENDQ7T+NdCgTVt7HF1H0amH\\\\\\n\",\n              \"4h+uX798gWmjBzBooCuOjo5/36r9E5ybVZwkSWL58uXMX7gIszpNaDzkG8oYm3N+3RxyU+MZ/P0K\\\\\\n\",\n              \"ypsWXAU+Vybjiv9+Tu9YQ8Axf8zNzZk1axaHDh0iKyuLNm3asHjx4gI3mCzs/FFRUdy8eZOff/6Z\\\\\\n\",\n              \"mzdvUq1aNR48eICenh45OTmYm5vTpUsXbt68Sa9evZg0aVJxvRzv5c0ajomJiZQuXRpra2sqVix6\\\\\\n\",\n              \"5fzPzady7VSECLBilpKSgl2rVtg2bsbSXz3e+u7y1o3rDB3gzJIlPzN06NBC90lOTsbf35/Dhw8T\\\\\\n\",\n              \"EBBAw4YNcXFxwdnZucRbGZIkkZKSolTgJSQkoKenly/UAgICOB32DN3SigeQJEm0q1+ZyMiIfB/8\\\\\\n\",\n              \"fymysrIYP348R476YlynKXbjf0Jd/e9uQEku58ahjdw9sYfq1k1o2tUZfSMTZNlZPL55hWvHD2Jl\\\\\\n\",\n              \"+RWO3bvh5eXF/fv3MTMzY/z48XzzzTeFTgF5/fp1vs+n3owC1NDQQENDA3V1daZPn05MTAw+vn5k\\\\\\n\",\n              \"5sioalmbUqW0eR0XTfj1UNxGjmTWrO/EkPRP0Kdy7VSECLBi9uZdbFpaGsnJKQwdOYZ+roMxM69I\\\\\\n\",\n              \"ZmYmVy5dZNvvG7l08TwbN26kX79+Ch03MzOTwMBADh8+zNGjR6levXpemKnKLR/kcjmvXr3i/v37\\\\\\n\",\n              \"PHjwgEePHrFkyRIuP3qt9Ocd3ZvXJOz6tS/uHXN0dDQuLi7o6uoSeiOc3iv90dAqOPw/JzOdJxeO\\\\\\n\",\n              \"EXE1iKy0JNQ1S5H44gEN6tYmLCwMmUxGhw4dWLJkCba2tgCkpaVx9+7dfCFV2IAKa2trLCwsGD9+\\\\\\n\",\n              \"PKVKlWLDhg30H+BKjpoWfdwm0aBpq3x/z+iIZ/jv306Qz34OH/qD1q1bf7TXSyjap3LtVIQIsP/o\\\\\\n\",\n              \"0aNHbNu2jefPXyCXy6lY0ZwhQ4ZQv359JEnC3d2d+Ph4Dh48yI0bN1i/fj3e3t7Ex8ejra1NnTp1\\\\\\n\",\n              
\"GTduLIMGDXrvbi+ZTMa5c+c4dOgQhw8fRk9PD2dnZ1xcXGjcuPFHH4whSRLJyclER0cTFRX1zq9J\\\\\\n\",\n              \"SUmYmJhgZmaGubk5gSdPEnD5Pnr6io8OzM3Nxb5+JWJfvfqibnVx7do1evfuzciRIzlx6jSZZtbY\\\\\\n\",\n              \"9v1a4fJhhzfzMGAXs7/7hu7du3Pv3r18LauXL1/mG1Dx5lG1atV8/6bi4uJwcHCgXr16rFmzhi5d\\\\\\n\",\n              \"u1Ghem3cZ/30zikfV0OCWPn9BE6dPEnDhm8fBSh8XCLAPgHF/UcICwtj1qzZhF4NZcDAIdSua42a\\\\\\n\",\n              \"mhqPHz5g7+4dWH5lSZMmjQkMDOTixYsFVvSWJKlYgkUulxMaGsqhQ4c4dOgQmZmZeWHWunXr/zQC\\\\\\n\",\n              \"TyaTERsbW2gY/ft3GhoaeaH076///N7IyChfnRwde2Lbpiu9BwxTuF7nTh1nx/plXLuqGv/p3ub5\\\\\\n\",\n              \"8+ds2LiRAwcOEh8fj5aWFlY1avD1uLH07ds3X3fevn37mDhxImvWrOHkyZNs37GTvmuOU7q8icLn\\\\\\n\",\n              \"y0iKZ//ELmhpqOetUPHPh5WVVZEDKp49e5Y3yGjx4sVs2LCB7Xv/YP663QrNVzx2cCdXT3lzPvis\\\\\\n\",\n              \"wvUWipcIsE9Acf4RTp06haurK7PnLmTAoKHo6ua/lXlOTg4+Rw4xbfI4fly0iClTphRLPYoiSRJ3\\\\\\n\",\n              \"7tzh8OHDHDp0iIiICJycnHB2dqZTp055owbT0tIKDaF/f42Pj8fQ0PCtYfTPr+/bEjp+/DgzvpnF\\\\\\n\",\n              \"9qNnFQ74ScOccXcbyogRI97rnCUtKyuLceO/5siRI3Tq2Y9OvQZQwbwSOTnZ3Lt5Hf/923h0L5zN\\\\\\n\",\n              \"mzbSq1cv5s2bh5eXF5s3b2b+/PlUqFCBY8cDGLL9itLn3jWyOS+ePX2vAUHh4eF0796dGTNmMHXq\\\\\\n\",\n              \"1L9XwLeuz/CZi2jYTLFuwZycbNy6NuZs0Gnq1q2rdB2ED0+VAkzMA1PS3bt3GThwIFt27qNVG/tC\\\\\\n\",\n              \"99HS0sKl3wBq16mLi1NXGjdu/NH7+eVyOXFxcchkMpo2bUrlypW5e/cuFy9eZMSIEbx+/RpdXV1k\\\\\\n\",\n              \"MhlAoWHUunXrfD9XqFChyHfkSUlJf6+EnpxMmTJlsLW1xczMTOF6d+nSBfn0GRzesw2XQSOL3D8o\\\\\\n\",\n              \"wIcbVy9x1NiAbt26KXWuT0F2djYOjj2RqWvjFXgN3TL5g791J3Nad+rBvVvXGDVmEF9//TWSJFGt\\\\\\n\",\n              \"WjUcHR3R1tYmKysLWW7ue7XqJUl6rzcbISEhuLi4sHLlSgYNGgTApUuXSM/KpkHTVgofR0urFF2c\\\\\\n\",\n              \"B7N582+sWrVS6XoIXzYRYEpavPhnxk+a9tbw+qe61vX530/LmD9/AadOnfwg58/IyCA6OvqtLaU3\\\\\\n\",\n              \"37969QoDA4MCoeTsUy+HmgAAIABJREFU7Mz48ePR1tYmPDw875bv9erVyxue/z6LCIeHh7N27Vr2\\\\\\n\",\n              
\"7d9PXesGlCtnSHp6KmHXQunUqROTJ02ibdu2RR5HXV0dn6PetG7TFrkkp88gt7delAN9D/Hj7MkE\\\\\\n\",\n              \"BBzH39+fBg0asGzZMoYPH16gjCRJXLlyBY9167l48SJpqano6evTvp09EyZMoH79+ko/Z2Xk5uYS\\\\\\n\",\n              \"Hx/Pq1eviI2Nzfu6Z89ectS1+Wmj5zu7d2vXb8Tybd5MdO2KOhKXL1+mXLly1K1bF1NTU46dCCQx\\\\\\n\",\n              \"4hHlq1gpXKekl08oW1avQA9CUd6sruHl5UXXrl3zfv/gwQNq1muodIha1WvIleP7lSojCCACTClx\\\\\\n\",\n              \"cXH4+Bwl9OZyhcv0dunHgh++4969e9SuXbvQfd6stv62z5P++TUjIwMzM7MCXXhNmzbNF1SmpqZF\\\\\\n\",\n              \"Lkbr4vL3auVJSUn4+flx6NAhpk2bRuPGjXF2dqZ3796F3pL937Zv387Mmd8wdPR4As5do4KZed62\\\\\\n\",\n              \"5OQkjuzfzeAhQxk+bCiLFi0q8gJnZWXF+XPB9Ordm8O7t+AyeBQdujqhp29AWloK50+fYM+WdaSn\\\\\\n\",\n              \"JmP5lQV+fn78/PPP9OvXj1GjRrF79242b95M9erVgb8nqroOHERMTAyuw8awZtRkypTVIzk5kUA/\\\\\\n\",\n              \"b7p07UadOnXYvctL4RacXC7n9evXBQLp/7V353E1Z/8fwF+V9r3urW5FpQ2pVLYUZUmWBlmKyFYS\\\\\\n\",\n              \"yk62wVjHMpaxRMgylsi+RLImQvjKrlRCaUfal/v+/TE/PaZpu1c1XM7z8bgPufdzPufcj5nz6nw+\\\\\\n\",\n              \"53M+GRkZ1b738eNHqKiogMvlQkNDAxoaGlBWVsbD2IcIPndboGuT+kamcB/rh9MHd2DdunV4+vQp\\\\\\n\",\n              \"Tp8+jffv36O1WSu8vHgIHb1+Faj9APDqylF4e3sJFTi7d+/GvHnzcO7cuSrPrCstLYVEE+EXEJZo\\\\\\n\",\n              \"IomSklKhyzEMuwYmhKCgIFy8dBVBu/cLVW7hvFnISk9F165dqw2l9PR0yMnJ1Xg96Z9/qqqqNuqs\\\\\\n\",\n              \"wsLCQly8eBHHjx/H2bNnYWhoiIEDB2LgwIEwMTGpsn1oaCimTJ2GvaFnYWRSfUADQFZmBsYO7Y8h\\\\\\n\",\n              \"g1yxcOFCgdrC5/Nx5coV/LlpMyKvX8fnz7mQV1CAubkFXjx/hpSUFOTl5aFTp06YPn06fH19UVpa\\\\\\n\",\n              \"inXr1mHNmjX49ddf0bNnT3Tr1h3efjMxYuz4aicWlJaWYtuGVTh97CD+2rcP4uLidQZSdnY2FBUV\\\\\\n\",\n              \"KwXSl5///aeGhgbk5eXx5s0bxMXFVbyuXbsGNZ4eVgYJ9rw4APiQnQmP7lZobdYKAwcOhIuLC/73\\\\\\n\",\n              \"v/9h/vz5yMr5iIHrz0FWpe7FeItyc3BmtiuexP5PoHux/rm6Rnh4OExNTatsc/bsWSxavgrLdx4X\\\\\\n\",\n              \"+PsAwLnDe5CT+AgH9v8lVDmmcbBrYD+o9+/fQ9+gudDlDJobITzsDBQVFcHj8WBhYQFnZ+dK4STM\\\\\\n\",\n              
\"M8Mak6ysLPr374/+/fujtLQU169fx4kTJ+Do6AhVVdWKe82srKxQWFgI3wkTsOfwmVrDCwA4XA3s\\\\\\n\",\n              \"PHAcfR3bY9iwYTA2Nq6zLeLi4ujRowd69Ki6JqKdnR3Cw8PRv39/nD9/Hvb29tDV1YWLiwsCAgLQ\\\\\\n\",\n              \"v39/jBkzBr8uXISAX5fDzXNsjfVISkrCf9YCEIBf+vVDyxYtoKmpWRFAzZs3R8eOHSsFEofDqbKK\\\\\\n\",\n              \"CJ/Px9u3bysC6tatWxU/p6SkoFmzZjAxMYGJiQlsbGzw7MVLdHcdVudx+CdVdS6s2tli7swpkJCQ\\\\\\n\",\n              \"wIgRI6CiooLQ0FBcuBiBXev84Tg7ENIKNd/QXZyfiyurJ8Jv0kSBwovP52PmzJmIiIjAzZs3oaOj\\\\\\n\",\n              \"U+nzkpISXLhwAfv27cOT/91D2rtkaOkKfoPyldOH8fvSRQJvzzBfsAATgoSEBMr55UKXKy8vR48e\\\\\\n\",\n              \"PRAYGNgIrWo8kpKSFQGyadMm3L17F8ePH4ebmxvKyspgbGwMS+t2aG1pJdD+NLR4GOwxEtu2bcMf\\\\\\n\",\n              \"f/whdHsKCgoqRkEdOnTAwoUL8fLlS2RkZMDKygoDBw6EoaEh8vPzkZGRASKCcQuzWsPrn/xnzsf5\\\\\\n\",\n              \"00exatWqGq/XERGysrJw9+7dSqOpuLg4JCQkQE1NrSKkTExM4OTkBBMTExgYGFQJvP0HD0FZVfjn\\\\\\n\",\n              \"mSkqq2L27NmQlJTEypUr4eLiAjExMdja2iI3NxcHl4xEK9cJ0GvXvdIpvfKyUryJuYInx7eiKDcH\\\\\\n\",\n              \"Hdu3q7OukpISjB07Fq9fv0ZkZGTFqvN8Ph83btzAwYMHcezYMbRq1QoeHh7gcjUQdmQvxk4XbJT9\\\\\\n\",\n              \"8vH/8Ck7Q+DnjzHMP7EAE4K+vj6iDhwSutyzJ4/R0tSwEVr03xEXF0fHjh3RsWNHrFq1Ck+ePEG/\\\\\\n\",\n              \"fv2xcKVwM8c8Ro2Dq7M9Vq5cCT6fj8zMTIGuIWVkZIDP51eMgtTV1fHixQskJCTA0NAQrVu3ho2N\\\\\\n\",\n              \"DbZt24YTJ07A2toavXr3waDh3gK3TUxMDMNG+WDzli2wsbFBfHx8pYB6+fIl4uLiQEQwNTWtCCk3\\\\\\n\",\n              \"NzeYmJjA2NhYqBl98nLyKCwsEOr4AUBOdhacnJywcePGStfOxMTEsGHdH+jm6IBVa9fj1IG10G1j\\\\\\n\",\n              \"DwkZeZQX5eFd7E20NG2BoD//gK6uLvr27Qt1dfUaZ8jm5eVh8ODBkJKSwsWLFyErK4uHDx/i4MGD\\\\\\n\",\n              \"OHToENTU1ODh4YH79+9XjOSSk5PRtl17WHawh41d7QvgfvqQjfXz/bBo0cIf/gkBTONg18CEkJ+f\\\\\\n\",\n              \"j2bNmuHyjbtopqcvUJnc3FxYtNDH82fPqpx6EXVKSkq48TAeSkrCrT/YxoiH8rJSlJSU1HoN6d/v\\\\\\n\",\n              \"KSgoVLr+N3nyZKioqGDJkiUV723atAlbt25FREQEjI2NcT8+vcrIpzY52VnobGUECXFxGBkZVRpN\\\\\\n\",\n              
\"fXlxOJwGuQ45f8ECvHyTiYnzVghcprioEMO7t8G9mLto3rz209nPnz9HVFQUcnNzoaioCHt7+0r3\\\\\\n\",\n              \"WkVERGDEiBG4dOlSlVmYX1bXaN26NebMmYMjR45UPLzSw8MDHh4eaN26dbX1RkVFYYCrK4b7zUWP\\\\\\n\",\n              \"/u6QrGZpq7gnD7Fu3iR4uA/BihXLBf7+TOMTpWtgLMCENGXKFBSXAStWrxNo+z/Xr0Hg5g1obmCA\\\\\\n\",\n              \"5cuX/1CPZZCWlsbDV2mQEXIadhdrU5w5fQqWlsJPuf6nx48fo1evXkhOTq50f9rMmTMRFRWFxKQk\\\\\\n\",\n              \"3Hz0Wqh9EhFaaCuguLhYqOD7Gm/evIGFZRscuPxQ4MWLw0+G4OG1swi/cL5B2hASEoJZs2YhKiqq\\\\\\n\",\n              \"0iiqe/fuaN68OfLy8hAfHw83Nzd4eHjA1tZWoBU2Hj16hEn+k/HixQv0dB0OIzNLSDSRROb7d7hy\\\\\\n\",\n              \"+jByczKxaNFCeHt5Ncj3YBqOKAXYj/9s+gY2d+5cXDh3Cvv3Bte57YWwM9i2eQOuXb0Kf39/+Pj4\\\\\\n\",\n              \"oHv37rh9+/Z/0NLGp67OQdr7VKHKFBcX40NONvT19es9ijE3N4eenh7CwsIqvb969Wpoamoi73Oe\\\\\\n\",\n              \"0PssLiqCpKRkoz+Tiujve7kA4MB2wX4Zyvuciz1/rkRXx7rvQRTU0KFDMXv2bPTs2RNJSUlYufLv\\\\\\n\",\n              \"B0qmpKSAy+ViwYIFSE1NxZYtW2BnZydQeAGAhYUFbly/hutXr0BDlhBz4Qgij/8923DVssV4nZTI\\\\\\n\",\n              \"woupN3YNTEhaWloIDw+Hc69eeBT7EBP8psCgeeXrW6mpKQgOCsSh/Xtw6tQptGzZEi1btsSQIUOw\\\\\\n\",\n              \"Z88eDBkyBFZWVli6dKlIL2I6cKArThw5gGlzBLtgDwDh506iRYuWDfbYEx8fHwQFBaFfv34V74mL\\\\\\n\",\n              \"iyMkJARaWjw8iX2A1pbWAu8vOuoaLOo5MqzL27dv4efnh7i4ODh06YxTB3ZBWUUdQ8bUvBDv508f\\\\\\n\",\n              \"sdjfE+1srLBx40bk5uZi8eLFdd7rV5eSkhLo6+tDQkICRkZGkJCQgI+PD1atWtUgz1Rr1aoV1q8X\\\\\\n\",\n              \"LKAZRlhsBPYVTE1Ncef2bagqyaFXNzsM7t8LC+bMwK9zZ2KEuyu6dGiDovxPiI6ORocOHSrKSUpK\\\\\\n\",\n              \"Yty4cYiPj0e3bt3g7OyMoUOHIi4u7ht+m683ceJEHN6/GyUlJQJtT0TYE7QF79+nwtraGvv27RO4\\\\\\n\",\n              \"bE3c3NwQHR2NN2/eVHpfVlYWkydPxu7tm4TaX8jeHZg0UfAV3YVRXl6OzZs3w9raGhYWFmjZsiUy\\\\\\n\",\n              \"MjLQwtQEB7evg//QXoiKOIfy/1/eCwBysjJwcPt6+Lo6wKFTB5w6eRKxsbF49OgR7Ozsvuq/HT6f\\\\\\n\",\n              \"j8jISIwfPx7a2tpYvXo1HBwc0KRJE5ibm2PdunU/3QNBGdHEAuwraWpqYs2aNXj79i0mjPeBoX5T\\\\\\n\",\n              
\"6DfVxvBh7khOTsbWrVsrVoL4NxkZGUydOhWvXr2ChYUF7Ozs4OXlheTk5P/2S9RTq1atoK+vh4Ap\\\\\\n\",\n              \"4yHIpdTgbX+itKgQiYmJWLFiBfbv3w99fX0sX74cWVlZX9UGOTk5eHh4YNeuXVU+8/f3w/VLFxD/\\\\\\n\",\n              \"4plA+7p/NxqPHt6Du7v7V7WlNk+ePIG9vT0OHz6MkydP4vLlyygoKEBiYiIyMjKgo81DHydHhB3a\\\\\\n\",\n              \"DrcureDT3x5jereHl0snUG4azp09jQ0b1kNCQgIaGhoVyznZ2dlh586d1R7/e/fuYfSYMTBo3hwc\\\\\\n\",\n              \"Dhc6uk1hYWkJHo+HiRMnwsDAAPfu3cPYsWNx8uRJXL16Fc2aNcPo0aPB5/Mb/BgwTIOjH5SNjc23\\\\\\n\",\n              \"boLAcnJyaN68eaSmpkZ+fn70/v37b92kOpWXl9PcuXNJT0+PLC3b0EA3D7r/8h0lZhZWeT1Nzib/\\\\\\n\",\n              \"GXNJT0+fkpOTK+3n0aNHNHbsWFJRUaHx48fT8+fPhW7Lo0ePSEdHh0pLS6t8tm/fPuJp61BY5AOK\\\\\\n\",\n              \"Syuo8RUadp24Gpp0/vz5rz4m1SksLKT58+cTh8Ohbdu20YsXL8jIyIgcHBxISkqKOBwO7d27l8rK\\\\\\n\",\n              \"yirKpKam0qNHj+jFixf0+fPnWvf/9OlTsrS0JFdXV8rKyiIiordv35KtbSfSbaZH/gGL6djle3Qx\\\\\\n\",\n              \"Jp5OXH1A0xasoGb6zcnK2pri4uLo999/Jz09PXrx4gURERUUFFDnzp1p8uTJxOfzG/RYMKJBlPpO\\\\\\n\",\n              \"FmDfkfT0dJo6dSqpqalRQEAAZWdnf+smVevz5880YMAA6ty5M2VkZFB+fj55eXmRkrIyDRnmSTsP\\\\\\n\",\n              \"HKejYVdpb+hZ8vL1JzU1derr4lJrMKelpdGiRYtIQ0OD+vTpQxEREUJ1oLa2tnT69OlqP9u9ezcp\\\\\\n\",\n              \"K6uQp9cEioh+XCm4zl6NocHDRpGqmlqN5b/WtWvXyMTEhAYNGkQpKSkUFRVFHA6HeDweSUhI0KxZ\\\\\\n\",\n              \"s6iwsLDe9RQVFdGMGTNIR0eH9u/fT9raOuQ/ezHdTcih+68/VXnFJH6ggKVrSUlJmYyNjendu3eV\\\\\\n\",\n              \"9vfhwweysLCg5cuX17ttjOgRpb6TBdh36M2bNzRu3DhSV1en3377jXJzc791kyokJyeTpaUljRkz\\\\\\n\",\n              \"hoqLiyt9lpGRQStXrqTuPXpQu/btybFrN5o9ezYlJiYKvP+CggLasWMHtWrViszNzSk4OJiKiorq\\\\\\n\",\n              \"LLd7927q27dvjZ8nJCSQk5MTycrKkZGJKbWxbkfNDY1IS4tHlpaWNGvWLIHbWJecnBzy8vIiXV1d\\\\\\n\",\n              \"OnnyJBERBQcHk6ysLImLi1Pz5s0pKSmpweoj+vuXitmzZ5OsnDxNnvNbtcH179f8lRtJ36A5lZSU\\\\\\n\",\n              \"VNlfamoqGRgY0I4dOxq0ncz3T5T6ThZg37H4+HgaPnw4aWho0Nq1a6mgoOCbtufWrVvE4/Fo7dq1\\\\\\n\",\n              
\"jX56ic/n04ULF8jZ2Zm0tLRoyZIllJGRUeP2+fn5pKamVuUU5b9NmzaNrKys6OrVq/TkyRMqKSmh\\\\\\n\",\n              \"uLg44nA49OHDh3q3OSQkhHg8Hk2aNIk+ffpE+fn51Lt3bxITEyM5OTny8/OrdLqwPoqLi+n06dM0\\\\\\n\",\n              \"dOhQUlJSovbt25OhSQu6l/RRoAC7//oTte1oR6GhodXuPy4ujrS0tCpCmPk5iFLfyQJMBDx+/JgG\\\\\\n\",\n              \"DBhAOjo6FBgYWGXk81/466+/iMPh0JkzZ/7zup88eULe3t6koqJC48aNo6dPn1a7nZ+fHy1atKjW\\\\\\n\",\n              \"fZWXl5Obmxu5ublReXl5xfujR4+us2xtkpOTqW/fvmRmZka3bt2i0tJS2r59O8nLy5OsrCwpKyvX\\\\\\n\",\n              \"GBTCKC8vp8jISBo/fjypq6uTnZ0dbdmyhTIyMqhf/wE0b/l6gcPr/utPtOLPXdTFwbHG+mJiYojL\\\\\\n\",\n              \"5VJkZGS9286IBlHqO1mAiZC7d+9Sz549ycDAoMqF/8byZbKGgYEBPX78uNHrq016ejr99ttvpKmp\\\\\\n\",\n              \"Sc7OzhQeHl5pJPjo0SPS1dWtdjLHPxUWFpK9vX2l04YJCQmkrq4u9HXHsrIy2rBhA6mrq9PSpUup\\\\\\n\",\n              \"qKiITpw4QaampqSqqkqampr1PnZ8Pp9iY2Np9uzZ1LRpUzIzM6MVK1ZUOQ0pIyND12KThQqw23GZ\\\\\\n\",\n              \"JCUlVetp2oiICNLQ0KDY2Niv/g6M6BClvpMFmAi6du0a2dnZUcuWLSk0NLTSSKIh/XuyxveisLCQ\\\\\\n\",\n              \"goODqXXr1mRmZkY7d+6smAzRsWNHgSZjZGdnk6mpKW3evLniPW9vb5o/f77A7YiNjaX27dtTly5d\\\\\\n\",\n              \"6MWLFxQZGUm2trbUokULMjAwIB6PR05OTl89GScpKYlWrFhBZmZm1LRpUwoICKgxRIqKiqhJkyZC\\\\\\n\",\n              \"nT788lJT51B6enqtbQkJCSEdHR2hrmcyokmU+k4WYCKKz+dTWFgYWVlZkbW1NYWFhQl0Xeru3bs0\\\\\\n\",\n              \"fvx46unsTD2cnGjUqNEUERFRJQRrm6zxveDz+RQREUG9e/cmTU1NWrRoEa1fv55cXFwEKp+YmEg8\\\\\\n\",\n              \"Ho9OnTpFRH8HhpqaGmVmZtZarqCggObOnUtcLpd27NhBsbGx5OLiQnp6erRixQri8XikoqJCs2fP\\\\\\n\",\n              \"FnqUnJGRQVu2bKFOnTqRuro6+fr6UmRkZJ2/pMTHx5O4uDjdjssUKrzuJX0kGVnZOqfrExFt2rSJ\\\\\\n\",\n              \"jI2N6ww7RrSJUt/JAkzElZeXU2hoKLVs2ZLs7e3p+vXr1W4XGRlJ1tY2pKdvQHMXLqO9ISforyOn\\\\\\n\",\n              \"aNmq9WTW2pyMjIzpyJEjRPTfTtZoKM+ePSMfHx9SVlYmaWlpioiIEKhcTEwMcTgcunPnDhER+fr6\\\\\\n\",\n              \"UkBAQI3bX758mYyMjMjNzY1iYmJo9OjRpKGhQevXr6fz58+ToqIiKSoqUkhIiMBt//z5M+3fv5/6\\\\\\n\",\n              
\"9OlDSkpKNGzYMDpz5kyNvzgUFRXR2bNnadSoUWRiYkJSUlIEgBSVlOjP3UeFCrDgYxepuaGRwP/O\\\\\\n\",\n              \"CxYsIBsbm+9qZizTsESp72QB9oMoKyujvXv3koGBAfXs2ZPu3r1b8dmJEyeIw+VS0N4QeptdSCkf\\\\\\n\",\n              \"iiu93uUU0bGzl0hHtyl5jhxJXC73m0zWaAiZmZnUoUMHUlBQICcnJzp//nydnfPp06eJx+NRQkIC\\\\\\n\",\n              \"vXnzhtTU1KqMMrKysmjMmDHUtGlT2r9/P82YMYPU1NRo/vz59PHjx4pp8jwejx4+fFhnO4uLi+nM\\\\\\n\",\n              \"mTM0bNgwUlJSot69e9P+/furHQklJyfT2rVrqUePHsThcEhMTIwkJCSoWbNmNGTIEDpw4ADl5+fT\\\\\\n\",\n              \"rl27yKFHL6ECrK+rG/3xxx8CH18+n08+Pj7Uo0cPgW5vYESPKPWdLMB+MMXFxRQYGEja2to0YMAA\\\\\\n\",\n              \"Cg0NJXUOh85fja4SXP9+3X0UT2pq6rRp06Zv/TXqJTY2lnR0dGjXrl1kaWlJLVu2pKCgoFpvQ9i6\\\\\\n\",\n              \"dSuZmJhQVlYW+fn50fTp0+nq1avk5z+Z7Dp3ISVlVerUqRPNmjWr4tReamoq8fl8CggIIBkZGerU\\\\\\n\",\n              \"qVPFahjV+TKD0NfXlzgcDnXq1KliBuEXhYWFdPnyZZowYQKZmZmRtLR0xRR8a2trmj59OsXExFR7\\\\\\n\",\n              \"SjE/P584HC7tOhouUHgdOBtJKiqqlJOTI9TxLSsrI1dXV3J3d2+066/MtyNKfScLsB9UQUEBrV27\\\\\\n\",\n              \"llRUVGjhslV1hteX155Dx8nGpu23bn69dezYkc6cOUN8Pp8uX75MLi4uxOVy6ddff61xRZCAgACy\\\\\\n\",\n              \"s7OjdevWk6y8AukYGJGr7ywa/etaGjFnBdl070vSsnL0S/8B9P79eyouLqZ+/fqRlJQUTZgwocbZ\\\\\\n\",\n              \"j7GxsRQQEEDNmjWjVq1a0fLlyykxMZH4fD69fv2aAgMDqW/fvqSpqUni4uIkLi5Ompqa1Lt3b9qy\\\\\\n\",\n              \"ZQulpqYK/L3Dw8OJw9WgPccv1Rpeh8KiSFOLV3HaWFiFhYXUpUsX8vPzqzLC5fP5lJubS9nZ2f/J\\\\\\n\",\n              \"TFmmYYlS38kC7AeWlpZGysoq9Ox1usAB9iargJrp6Vc6BSmKgoOD6Zdffqn03osXL8jX15dUVFRo\\\\\\n\",\n              \"9OjRVWb0lZeXk7mFBalr6dCMLYdoe3QSBd1+Xem1Lvx/1GfURNJp2qzi+lNwcHCV+r/MIGzdujU1\\\\\\n\",\n              \"bdqUZs+eTXfu3KGrV6/StGnTqE2bNiQjI0MSEhIkKSlJLVq0IG9vb7pw4QLl5+fX67ufPXuW1NQ5\\\\\\n\",\n              \"NGjYaDp47kal4AqNuENDR/uQqpo6HTp0qF71fFlyatmyZUT096nOuXPnEldDg+Tl5UlJWZlkZGRo\\\\\\n\",\n              \"6NBhdOPGDZG5nvqzE6W+kz2R+Qe2Y8cOhF+6ij+37xGq3NqVSyDBL8aaNWsap2H/gfz8fOjo6MDb\\\\\\n\",\n              
\"exxiHz1GXn4eFBUV0c3RAa6urjh27Bi2bNmCli1bYvr06ejVqxe2bw/CslVrMGNbKBRV1Wvd/5Uj\\\\\\n\",\n              \"e3AicDXCz4fBweHvB0xmZWXhyJEjOHjwIJ4/fw5nZ2c0a9YMcXFxuHPnDtLS0iAmJgZ5eXlYWFig\\\\\\n\",\n              \"d+/e+OWXX2BmZibwgyIFlZaWhp07d2Lb9iCIiYtDUUkZ+XmfUVxUBJ9x3vDx8YGurm6963n//j06\\\\\\n\",\n              \"deoES8s2iLwRCdchw+AxehwMjU0BAJ8+fsDxwwdwYHcQ9PX1cOzo0QZ7FhzTOESp72QPtPyBZWZm\\\\\\n\",\n              \"gqctfCelo9sUZ46FIDQ0FPLy8pCXl4ecnFylP+Xl5SEjI9PgHW9D+PDhAyZO8kNpaRlevH6PbgNH\\\\\\n\",\n              \"QkFRCXm5nxB1+RxWrPwdAwcOxOPHjxEWFoYFCxZg6tSpeJf6HrODjtUZXgDQzW00kh7fw+3bt5GS\\\\\\n\",\n              \"koK9e/ciKioK+vr6KCkpQWFhIY4ePQo+nw9tbW3Y29ujf//+6Nq1K3g8XqMfAy0tLSxYsABz5sxB\\\\\\n\",\n              \"UlIScnNzoaioCAMDA0hKSjZYPTweD87OvRBx+QrCb9wHV1Or0ufKKqoYM94PI70nYMm8GejevQeu\\\\\\n\",\n              \"X7/GnjfGNAgWYD8wSUlJlJaVCl2upKQEMTF34eV1H1JSUpCQkIC4uDiICOXl5SgtLUVRURFKSkog\\\\\\n\",\n              \"KytbKdSqCzph3/vys5ycnNABmZmZiS4ODjBra4/QyCdQUFSq9Lld997wnf0B29YsRPceTrh29Qo0\\\\\\n\",\n              \"NTUxePBgaOoZQsfQVOC6ugwahUXTRoHKy1BWVgZJSUkkJCSgZcuWmDRpEnr37o0OHTp80866SZMm\\\\\\n\",\n              \"MDY2brT9h4eH40J4OI6HR0JNnVPjdhISElj8+3rM9vfBzJkzERgY2GhtYn4eLMB+YMbGxjhx6ozQ\\\\\\n\",\n              \"5R4/fIBZs2Zh0KBBeP/+PVJTUyte//z7+/fvISUlBXV1dXC5XKipqUFFRQVKSkpQVFSErKwsZGRk\\\\\\n\",\n              \"ICUlheLiYuTn5yMzMxOvX79GQUEB8vPzkZ+fX/Hzv98rLCyEtLS0wEEoJyeHg4cOoUPXPvCdtbjG\\\\\\n\",\n              \"76ekoopZy/7En0sDYG5hgU8fP6KMgO7uY4U6TkaWbSGroIQWhvrw8PCAo6MjWrVqBQkJCaGPuaha\\\\\\n\",\n              \"v2ED/GbMrTW8vhATE8Pshcvg1KkNVq5cCRUVlf+ghcyPjF0D+4GVlpaimZ4eDh4Lg2nLVgKV+fjx\\\\\\n\",\n              \"AzpZtUDcy5fQ0NCodVs+n4+cnJxqw+2ff09LS4OSkhJ4PB60tbUrXv/8O4/HA4/Hg5SUVKX9FxUV\\\\\\n\",\n              \"1Rl0X/58/PgxrkbexL7ztwUauZWVlWFQZzPkf/4EGTkF+K3bDQOzNgIdpy/+nDwCtpYtYWVlBUlJ\\\\\\n\",\n              \"SUhJSUFSUrLKS9j3JSUlv8vTs/+UmJiIdu3a42ZsPGRkZQUuN8VnJBztbTFt2rRGbB3ztUSp72Qj\\\\\\n\",\n              
\"sB+YpKQkxnl7Y8uGNdi4LRhiYmJ1ltmx9U9w1NUFeqS8uLg4OBwOOBwOLCwsatyOz+cjOzu7Srg9\\\\\\n\",\n              \"e/YMly5dqgi69PR0KCsrVwm3f/9dS0ur2us4A1wHYvAoX4E7/iZNmsBtzATcuXwWb96+A59fLlC5\\\\\\n\",\n              \"fyorK0VaWhpiY2NRWlqKkpISlJaWVnkJ+35paSkkJCTqHYSN+f7FixfR0d5BqPACgO7OfREZcZYF\\\\\\n\",\n              \"GFNvLMB+cDNmzICdnT02rFmBqbPm1Rpixw4fxOH9u+Hq6goLCwusWLECXl5eAgVfbcTFxcHlcsHl\\\\\\n\",\n              \"cmFpaVnjdnw+H5mZmVVGco8fP0Z4eHhF8KWnp0NVVbVSuGlpaeHc2bM48+t6odrWe6AHgjf+Dklp\\\\\\n\",\n              \"abyLfw5DcxuBy/LLy5Hz/h1+/ysYrVoJNsIVFBGhrKys3kFY1/v5+flfvZ/09HR0sHMQ+rspKCrh\\\\\\n\",\n              \"8+fPDXq8mJ8TC7AfnLKyMi5eDEevXr3x/MkjjJs0FW3bd6wUSi+ePcXuHVtw7fJFXLx4Ea1bt4aP\\\\\\n\",\n              \"jw/GjRuHAwcOICgoqFEnAnwhLi4OTU1NaGpqok2bmk/llZeXIzMzs9JpylevXkFKWhpyCopC1amq\\\\\\n\",\n              \"zgWf+JAQAy6F7EIX1+ECB/bjW1ehq6Pd4OEF/H296MtI53t18OBBHDpyXOhyuZ8+QklJqe4NGaYO\\\\\\n\",\n              \"LMB+Atra2rh16yaCgoIwY5I3pGVkYNrSDBISEnidmIB3797AZ9w4/HHvXsV1L0tLS0RHR2PTpk2w\\\\\\n\",\n              \"tbXFjBkzMHN/zqwpAAAbqUlEQVTmzO+iQ83NzUViYiJiY2Nx+/ZtPH78GElJSSgpKRF6X3w+H8Tn\\\\\\n\",\n              \"w9DQEK8Sk/DyfjRatO0kULmw3ZswxMX5a77CD8HOzg5+/v4oyM+HnBAzLS+dP4teTl0bsWXMz4JN\\\\\\n\",\n              \"4vjJ8Pl83Lx5E8nJyeDz+eDxeHB0dKw1mF6/fo0JEyYgNTUVO3bsQPv27Ru9nURUcZ3s9u3biImJ\\\\\\n\",\n              \"wfPnz5GSkoLi4mJISEigrKwMHA4HBgYGsLCwwNFjx7Bu72k0N2kpcD1PH8ZgtrcbtgVuhby8PDxH\\\\\\n\",\n              \"jcHM7aHQNqh5xMnn83F04xK8f/YAnz7koHfv3lizZs1PeYPuL/36wa5bbwz1HCPQ9unvU9GrS1sk\\\\\\n\",\n              \"v37NRmHfKVHqO7/vaU7MVyMi3Lp1C8NHjICOjg4UFRXB4/Hg6joQhYWF8PDwwMiRI+Hk5FTnqEpf\\\\\\n\",\n              \"Xx9hYWEICAhAv379MG3aNOTl5TVIO8vLy/Hq1SscPXoUU6dOhaOjI3R1dSEtLQ09PT04OztjxYoV\\\\\\n\",\n              \"ePr0KYyNjTFjxgycPHkSz549Q3FxMdLT03HlyhW0b98eMtIyOLZvu1D1Hw7eCjlZWaSlpSEwMBCm\\\\\\n\",\n              \"xkbY6DcM147uQ2F+1es0r5/FYsfc8chNfokb16/h6dOnEBcXR+vWrXHu3LkGOSaiZNrUqdi6fhUy\\\\\\n\",\n              
\"09Pq3JbP52P5wgB4jhjBwotpEGwE9gN6+/YtBg8ZguzsbHiN88Uv/V2hoqqK/Lw8XLwQhh3bA1FS\\\\\\n\",\n              \"XISjR4+idevWQu07KysLM2bMQGRkJAIDA9GrVy+ByhUVFeHly5eIiorCnTt38PTpUyQnJ+PDhw8Q\\\\\\n\",\n              \"ExMDEUFVVRV6enowMzNDx44dYWNjA1NT0xrvF0pISEBgYCD27NmDTp06wd3dHRMmTkTw6Sho8HTq\\\\\\n\",\n              \"bNO75ET4Du6B3cG7MGnSJOTk5GDkyJFwcHDAsRMnceXyZVjad4esijrKSorx8n40xMtL4DdxAiZP\\\\\\n\",\n              \"ngzZf8y+u3LlCry9vWFvb4/169dDXb3u1Tx+FEuWLMHBkMPYdegEdHSbVbtNaWkpFsz0x7ukV7h0\\\\\\n\",\n              \"KaLSsWO+LyLVd36TFRj/A6K0IGVDSk5OpqZNm9KylavpY0Ep5RaVV3l9KiyjoOC9pKGhIdCzq6oT\\\\\\n\",\n              \"Hh5OBgYGNHz48EqPA/n48SNduXKFFi1aRC4uLmRiYkIKCgokJiZGYmJiJC0tTc2aNaNu3brR9OnT\\\\\\n\",\n              \"6dixY5SQkCDwquXl5eUUFhZGffr0IQ6HQ7NmzaLExERKTU0lZ2dnkpSUIp6uHh2/8ZSuv8yu8XX4\\\\\\n\",\n              \"ykPSaaZPGzZsIAcHB/L09KR3797R8uXLqWnTptSxY0fasGEDbd68mQwMDMjLy4ukpaWpsLCwxrbl\\\\\\n\",\n              \"5eXRlClTiMfj0dGjR7/quIoiPp9Pa9euJWUVFRrqOYbOXL5FCRn5lJhZQHeeJNLM+b+RbtNm9Eu/\\\\\\n\",\n              \"fgI9+Zn5tkSp72QB9gPh8/lkbW1NK1atrTa4/v3as/8QNW3atNbnZNVUz7t37yg4OJjatGlDkpKS\\\\\\n\",\n              \"pKqqSpKSkhVBpaysTGZmZjRo0CD6/fff6datW/Tx48ev/m45OTm0bt06MjQ0JCsrKwoODqacnBw6\\\\\\n\",\n              \"cuQIOTk5UZMmTUhTU5NCQkLot9+WkBqHS5MXrKRz95IqBdfZuwk0ae4yUudqkJKSEllZWZG3t3el\\\\\\n\",\n              \"51qVlZXRqVOnqFevXsTlciseYmlubi7QKv03b94kU1NTGjx4MKWlpX31dxY1aWlptGzZMtLT069Y\\\\\\n\",\n              \"ZV9RUZG8vLzo/v3737p5jIBEqe9kAfYDuXLlCrVsZUafCssECrDconJy6ulMe/furXZ/ZWVlFBsb\\\\\\n\",\n              \"S+vXr6dhw4ZRmzZtSF1dncTFxQkASUlJkba2NllYWJC6ujqZmZnR1atXG/QZUA8fPqRx48aRiooK\\\\\\n\",\n              \"eXh40M2bNykqKorGjx9Pampq1LZtW1JTU6Pp06dXPI/r2rVrpKCgQCqqaiQrK0e2XbpTT5eB1Mmh\\\\\\n\",\n              \"Bykpq9Awj+F08eJF0tPTI2Vl5UojyH+Lj4+npk2bkpKSEjVt2pTGjRsn0PcrLCykOXPmkIaGBu3f\\\\\\n\",\n              \"v/+ne5RIaWlpraNV5vslSn0nuwb2Axk8ZAg6dXbEuPETBC5z/twZrP19OZYuXYrIyEg8ePAAr169\\\\\\n\",\n              
\"QlpaGvLz8wEA8vLy4PF4MDY2hrW1NRwcHNC+fftKF+JLS0uxbt06rFmzBnPnzsWUKVPQpMnX3aVR\\\\\\n\",\n              \"WlqKEydOYPPmzUhMTISvry969uyJCxcuYN++fZCUlMTIkSNRXFyMwMBA7Ny5E7/88guICMePH8fw\\\\\\n\",\n              \"4cOhr6+PXbt2oXnz5oiJicHnz5+hpKSEjh07AgCcnJzg5OQEcXFx3Lp1CxEREZCRkam2PV27dkVA\\\\\\n\",\n              \"QACOHz+OY8eOQVFREePHj4eXl1edy23du3cPY8eOhZ6eHrZt2wYdnbqvzTHMtyRSfec3DtBGI0q/\\\\\\n\",\n              \"RTQUJSUlSnybJvDoK7eonD7kl5CkpCRJSEgQl8sla2trGjFiBP3555/05MkToR8ZHx8fT926dSMb\\\\\\n\",\n              \"Gxt68OCBUGVTU1Np8eLFxOPxyMHBgfbs2UPbtm2jzp07E5fLJX9/f4qJiaGPHz/S4MGDydramhIT\\\\\\n\",\n              \"E4mI6Pr162Rra0taWlrUpk2bGtudmppKrVq1ogULFhCfz6fy8nJyc3OjYcOG1VjG0dGRrly5QvHx\\\\\\n\",\n              \"8aSrq0v37t0jLy+vilFhXQ9rLC4upsWLFxOHw6EdO3b8dKMxRrSIUt/JAuwHwefzCUCNEzdqe2lr\\\\\\n\",\n              \"a9ObN28atC3BwcHE5XJp9uzZtT5hmM/n0+XLl8nR0ZHk5OTI2dmZlixZQu7u7qSkpEQDBw6kkydP\\\\\\n\",\n              \"UnFxMRERxcbGkrGxMfn6+lJhYSHFxsZSnz59SF9fn9avX0+qqqoUFxdXbV1v374lY2NjWrp0aaX3\\\\\\n\",\n              \"CwoKyNbWlubPn19tuS8BxufzicvlVhyrnJwc2rBhA5mYmJC5uTkFBgZSbm5ujd81NjaWbGxsqEeP\\\\\\n\",\n              \"HpSUlFTbIWSYb0aU+k4WYD8QOTk5Ssn8KFR4fSosIxUVFcrOzm7w9qSlpZG7uzsZGhrSpUuXKn2W\\\\\\n\",\n              \"n59PK1euJB5Pm2Tl5KiNlTX1dO5DHTvZkaKSEpmbW9C+ffsqjVZ2795NHA6H9u3bR4mJiTRixAjS\\\\\\n\",\n              \"1NSkjRs3UlFREQ0YMIAWL15cbVuSkpLIwMCA1q5dW+3nGRkZZGhoSLt27ary2ZcAIyLq378/hYSE\\\\\\n\",\n              \"VPqcz+fTpUuXaODAgaSqqkoTJ06kx48fV1tPaWkprVq1itTV1WnTpk1Cj3AZprGJUt/JAuwH0sXB\\\\\\n\",\n              \"gf46dESoALtyI5oMDAwatSM9c+YMNW3alMaMGUP37t2jGTNmkKKiIikoKNAEv6l0/3EcZeeVVbxS\\\\\\n\",\n              \"svIocOdeMmttTp4jR9KnT5/I29ubTE1N6fr16+Tv709qamq0ePHiihHPmTNnyMjIqNqJA3FxcdSs\\\\\\n\",\n              \"WTPavHlzre188eIFaWhoUERERKX3/xlgq1atosmTJ9e4j3fv3tGiRYuIx+NRly5dKCQkpGL0+O+6\\\\\\n\",\n              \"OnXqRPb29vTy5cs6jyHD/FdEqe9kAfYDCQkJIQfHbkIFmMeIkbRq1apGbVd5eTkdPXqU9PX1SUxM\\\\\\n\",\n              
\"jIyNjUlBQZHOXbxWKbj+/XqT/om69ehJWlpa5OrqSgEBAaSmpkZTpkyh9PT0iv3n5+eTvr4+Xbx4\\\\\\n\",\n              \"sUrdz549Ix0dHQoKChKordevXycul1tpBPXPALtx4wa1bdu2zv2UlJRQaGgode3albS0tGj+/PmU\\\\\\n\",\n              \"nJxcaZuysjLauHEjqaur0+rVqytmUTLMtyRKfScLsB9IcXExaWtr0/EzYQKF143b90hVVZUyMzMr\\\\\\n\",\n              \"7aegoIASEhLoxYsXlJWV9dXt+ee9W8bGxtS5c2dSUFAgeQUF+ivkeK3h9c8Qa6anTyoqKuTp6Vnt\\\\\\n\",\n              \"taO5c+eSu7t7lfdjY2OJx+PRvn37hGr3/v37SU9Pj1JTU4mocoAVFBSQnJwc5eXlCby/Z8+e0eTJ\\\\\\n\",\n              \"k0lNTY369etHFy5cqDTiTUhIoK5du1K7du1qPPXIMP8VUeo7WYD9YG7cuEFcLpfOnI+oNbyi7twn\\\\\\n\",\n              \"nrY2hYaGVpSNjY0lHx8fUlFRoWZ6emRoZERKSkrUtVs3Onr0KJWUlAjUhi/7UVRUpNatWxOPxyMz\\\\\\n\",\n              \"MzNavXo1HTlyhFq1MqOsz6UCBVh2Xhn9sXErde/eo9q6nj17Rurq6pSSklLp/Xv37pGmpiYdPnz4\\\\\\n\",\n              \"q47jkiVLyMbGhvLy8ioFGBFRhw4d6Nq1a0LvMy8vj3bs2EFt2rQhQ0NDWrt2bcUvCHw+n7Zv304c\\\\\\n\",\n              \"DoeWLFlS47HOzMyk33//nRy7dqU2VlbUyc6O/P396dmzZ1/1PRnm30Sp72QB9gO6evUqcblcGuI2\\\\\\n\",\n              \"lC5eiax0Y/PNuw9ojNc4UlNTq+jcy8vLadq0acTT1qb5C3+juKR3Fdtn5RZS8L4DZNvJjtq0aVMl\\\\\\n\",\n              \"KL4oKSmhw4cPU8eOHf8OwGbNiMPh0NSpU+nBgwcVkzHc3Nxp9bpNAodXdl4ZJad9JFVV1Sp18/l8\\\\\\n\",\n              \"cnR0pI0bN1Z6Pzo6mjQ0NOjEiRNffQz5fD6NHj2a+vXrRw4ODpUCbNq0abRixYp67Ts6Opo8PT1J\\\\\\n\",\n              \"RUWFRo0aRXfu3CE+n09v3ryh3r17k6WlZaXVKwoKCsjHx4eUlZVpuOcoOnb6HF27eYfCLl6hWXPm\\\\\\n\",\n              \"k6amJnXt1o3NbmTqTZT6ThZgP6gvp++MjI1JU1OTTExMSVtHh3R1dWnJkiX0/v17Ivq7M504cSLZ\\\\\\n\",\n              \"2tlT8vusWmcrLvxtGRkZGVU65ZiamkoLFy4kNTU14nA4JCcnR4MHD6azZ89WO4qwsLCkqzdjhAqw\\\\\\n\",\n              \"7Lwy6mjbiSIjIyvta9++fWRlZVXp2lFkZCRxuVw6d+5cvY9hcXExdevWjXR1dSsFWGhoKLm4uNR7\\\\\\n\",\n              \"/0R/j6hWr15NBgYGZGNjQzt37qS8vDzat28fcblcmjt3LmVnZ5N95840aIgbJaVkVPvvk5VbSEtX\\\\\\n\",\n              \"rCJtbW02KYSpF1HqO1mA/eDKy8vp7du39PTpU0pOTq4yUeD06dNkatqC3qbnCHTdbPK0GeTu7k43\\\\\\n\",\n              
\"btygnj17kpSUFMnIyJCVlRUFBQXRhw8fam2PaYsWdPNurNAB1sWxa6XZgTk5OaSlpUV37typeO/S\\\\\\n\",\n              \"pUvE5XKrzCKsjw8fPpCcnBz5+fkR0d+BHxoaSsoqKtSliwPZd+5Mbu7udPr06XotofVlkWIXFxdS\\\\\\n\",\n              \"U1OjqVOn0o0bN8jV1ZU0NDRosPtQge7x27R1OzVv3rzWe+8Ypjai1HeypaR+cj17OmPIsOEY6jFC\\\\\\n\",\n              \"oO0/ffoEE30dlJWVQUlJCd7e3hg3bhwMDQ0FKm9v3xnTZs9D1+49BW4jEaGjtRmOHA6BlZUVAGDC\\\\\\n\",\n              \"hL+XywoMDAQAhIWFYfTo0Th69Ci6dOki8L4FYWtri/j4ePj7++PIkSMgAkaN9YZZa3OIi4sjMeEV\\\\\\n\",\n              \"9u0JRkZ6GpYtW4YRIwQ7ljV5/fo1goKCsGvXLhgZGeHJk6eIe/0OcnJyApUfMsAF7m5DMGaMYA+Z\\\\\\n\",\n              \"ZJh/Eqm+8xsHaKMRpd8ivpX4+HjicrmU8TFfqKn3wz1HkZeX11ctifTHH3/QEHcPoUZf4VdvkkHz\\\\\\n\",\n              \"5hUz9+7cuUNaWlqUk5NDREQnT54kLpdL0dHRDXp8vnB0dKSxY8eSkpISHTt9rsbFki9H3iJ9AwNa\\\\\\n\",\n              \"vXp1g9RbVFREv/zSj3wn+Qv173PkxGmyEWCqP8NUR5T6TvZE5p/Y/fv3Yde5S42L2Nakd18XZGZm\\\\\\n\",\n              \"QUxMTOg6x4wZg4jwMGRmZAhcZsvGP8BRV0dqairKysrg6+uL1atXQ1VVFUeOHMH48eNx/vz5ioV6\\\\\\n\",\n              \"G1pubi5OnTqNi1dvwKlnrxq/d7v2HRB+ORKbt2zByZMn612vtLQ0nr94Ds9Rwo2knHr2QvLr10hJ\\\\\\n\",\n              \"Sal3Gxjme8YC7CeWn58POTl5ocvJyytUrFQvLFVVVXh7e8PXyxMlJSV1bn/yeChuRkXC3NwclpaW\\\\\\n\",\n              \"6N27N+Tl5TFixAj89ddfmDp1Ki5evAgbG5uvao8gMjIzMffXRWhlVvfTq3na2vhj42YsXbYM1ABn\\\\\\n\",\n              \"5z/k5EBLiydUGQkJCWhqaiE7O7ve9TPM94wF2E9MWVkZH3KE7+RycrKhrKz81fWuXLkSaqoqGDrQ\\\\\\n\",\n              \"BSkp76rdprS0FEHbNmOa33iYW1ji5MmTsLC0xPXr1/H8+XMMHz4cc+bMweXLl2FhYfHVbalLSkoK\\\\\\n\",\n              \"PuTkYNhwT4HLOPXshZycHMTExNS7fmlpaYGC/t+KiosgLS1d7/oZ5nvGAuwn1qVLF0TfuokPHz4I\\\\\\n\",\n              \"Ve7k8aNwcurx1fU2adIER44chl2njnC0tcYoj8E4ffI47kTfxPWrl7FiyUJYtmyOc6eO42rUbZwO\\\\\\n\",\n              \"u4gncUlwHeQGdXV1mJub48SJE5CSksKzZ88aZKRTk9OnT6NXHxcoKioKXEZcXBxDPTxx9OjRetff\\\\\\n\",\n              \"omVL3I6+KVSZ96mpyM7Kgq6ubr3rZ5jvGQuwnxiXy0VfFxcc+GuvwGVS3r1DVOT1es+0k5CQwPLl\\\\\\n\",\n              
\"y5GcnIy+vZ2xZuVvGD3cDWt/X4bPuZ9w6twFhEVchbGJKQBAQUEBY7x9EBl9D4lJSZg9eza2b9+O\\\\\\n\",\n              \"JUuWwN7eHrdu3apXe2qSmZkJfQMDocvp6OggKyur3vX7jh+PXTu2C1VmT/BOuLu7Q15e+NPDDCNK\\\\\\n\",\n              \"WID95KZOmYI/161BQsKrOrctLy/H9CmTMHbsWCgoKDRI/QoKCnB2dkZ6WhpuxfwPFy5fw9oNf9Z4\\\\\\n\",\n              \"vUmLx8O58CvYuHEjrKys8ODBA/j4+MDd3R2DBg1CfHx8g7TrC2lpaRQXFQldrrikpEFO4Q0YMAAJ\\\\\\n\",\n              \"8XGIirwu0PZZmZnYvXM7Jk6cWO+6GeZ7xwLsJ9e2bVssXrwY/Xo74cnjRzVul5+fjzGew1BaXIwV\\\\\\n\",\n              \"K1Y0aBu2bduGYSNGQkNTU6Dt9fT14dJvAIKDgyEhIYFRo0YhLi4O7dq1g62tLfz9/ZGZmdkgbWvR\\\\\\n\",\n              \"ogWibkQKXe7+3TswMTGpd/2SkpLYvXs3xngOw/8e3K912+zsbAxx/QVjxoxp1OuCDPPd+Nbz+BuL\\\\\\n\",\n              \"KN3L8D3Yv38/qaurU5++LnT01FlKfJtGb9NzKPreQ/KbMo3U1dVp1OjRVFRU1KD1lpeXE5fLpQdP\\\\\\n\",\n              \"Xgh1r9PVqNtkZGRUZX8ZGRnk7+9P6urqtHz58nqvSFFSUkKKikp0+36swG1LepdOKioq9VrJ/9+O\\\\\\n\",\n              \"Hz9OHA6Hps2cTY9fJFSq7216Dq1Z/yfp6evTzJkzv+r+PIb5QpT6TrYSB1MhPz8fISEhCNqxA6/i\\\\\\n\",\n              \"41FSUgIul4vBgwdjwoQJMPiKa0F1+fjxI/T09PAuQ7iJJGVlZeAqy6GkpATi4lVPJMTHx2PevHm4\\\\\\n\",\n              \"ffs2li5dCk9PT0hISHxVGw0NDWFl3RbBfx0U6N63xQvmITsrHXt27/6q+mry6tUrbNmyBX/99ReM\\\\\\n\",\n              \"jE2grq6O/PwCxD58gB5OTvCbNAkODg4NWifz8xGlvpMFGPNNZWZmokWLFnidKtwpPyKCmoI0ioqK\\\\\\n\",\n              \"0KRJkxq3i46OxsyZM5GXl4fVq1fD2dlZ6DZ27twZmZmZ+KX/QPz629JaQyx4x3asX7sK0dHR4PGE\\\\\\n\",\n              \"u39LUAUFBbh79y4+ffoEOTk5mJubQ0tLq1HqYn4+otR31vx/PsP8B1RUVFBQUIBPnz4JdW9ZakoK\\\\\\n\",\n              \"lJSUag0v4O91DKOionDy5En4+/tDX18fq1evRps2beqsIycnB+Hh4UhLS0P37t1x5tRxPHr0EP5T\\\\\\n\",\n              \"pqOLY9dKQXb/Xgy2b9mEmLu3ERER0WjhBQBycnJwdHRstP0zjKhgkziYb0pSUhK/9OuHkIP7hSr3\\\\\\n\",\n              \"197dcHNzE2hbMTExuLq64unTp+jfvz969eqFUaNG4e3bt9Vu//jxY4wZMwbNmzfHoZDDsOvcBfmF\\\\\\n\",\n              \"RQDE8PzpU/h4jYJ16xbwcBuIEUMHo1M7K4z1HIY2luaIiYmBsbGxUN+FYZivw0ZgzDc3ccIETJgw\\\\\\n\",\n              
\"Ed4+vgJdpyopKcGe4B04HxYmVD2SkpKYNGkSPD09K0ZhPj4+mDNnTsXo7+TJkxg3bhwmT5uBx8/j\\\\\\n\",\n              \"weVyK8oTEW5EXseqlcuR9/kz3AYPgoyMDHg8Hmxtbb/6GhvDMF+HXQNjvjkiQvcePdCiVWv8vmZd\\\\\\n\",\n              \"rdeY+Hw+Jvp4oagwH8fqudLFu3fvsGjRIpw9exbz589HixYt4OnpiROnw2Bdy9qK5eXl8JswHu/e\\\\\\n\",\n              \"vUHYuXOQlJSsVzsY5nsiSn0nO4XIfHNiYmI4dvQo7tyKwqTx3jWuVJ+akoIxnsPwJjkJ+/YKvnpI\\\\\\n\",\n              \"TXR1dbFr1y5cunQJYWFhcHNzQ2DQrlrDC/h7FZFNW7chLy+/QZaLYhjm67AAY74LqqqquH79OmSl\\\\\\n\",\n              \"JWFj0RLeo0Yg9PAhXDh/DocPHYDnsCGwbWuJZro6iLh4sUGXSTI3N0dAQAA0NbXQu09fgco0adIE\\\\\\n\",\n              \"U6bNwNatWxusHQzDCIedQmS+Ox8+fMDu3bsRHX0bnz9/hpKSEhwcumDkyJFCLaorjBEjRsC6bQdM\\\\\\n\",\n              \"9PMXuExZWRlMDfVw+fJltGjRolHaxTD/NVHqO9kkDua7o6qqiunTp/+ndb5KSMDYcROEKtOkSROY\\\\\\n\",\n              \"tTZHYmIiCzCG+QbYKUSGAVBeVlbnPWXVkZSURFlZWSO0iGGYurAAYxgAGpqaSE5+LVQZIsLr10nQ\\\\\\n\",\n              \"FHARYoZhGhYLMIYBMNTdHfv2BAtV5l5MDAoLCtCuXbtGahXDMLVhAcYwAIYMGYLYh/9D3MuXApfZ\\\\\\n\",\n              \"tnUzfH0nVLuYMMMwjY/9n8cwAGRkZDBv3nyMGOaGjx8/1rn9gb/24WZUJLy9vf6D1jEMUx0WYAzz\\\\\\n\",\n              \"/6ZMmYwePXqgu6M97sXEVLvN58+fsWrlciz6dR7CwsKgpqb2H7eSYZgv2DR6hvl/YmJi+OOPtTAN\\\\\\n\",\n              \"MoWnhxs4HC6GDfeEto4OioqKcDv6Fo6EHISDoyNu3rwJPT29b91khvmpsQBjmH8QExPD+PE+8Pb2\\\\\\n\",\n              \"woULF3D8+AlEXr8CaWlptGrZCo8ePYKuru63bibDMGABxjDVkpCQQN++fdG3r2BLSzEM899j18AY\\\\\\n\",\n              \"hmEYkcQCjGEYhhFJLMAYhmEYkcQCjGEYhhFJLMAYhmEYkcQCjGEYhhFJLMAYhmEYkcQCjGEYhhFJ\\\\\\n\",\n              \"LMAYhmEYkcQCjGEYhhFJYkRE37oRjYHD4UBfX/9bN4NhGEakvH79GllZWd+6GQL5YQOMYRiG+bGx\\\\\\n\",\n              \"U4gMwzCMSGIBxjAMw4gkFmAMwzCMSGIBxjAMw4gkFmAMwzCMSGIBxjAMw4gkFmAMwzCMSGIBxjAM\\\\\\n\",\n              \"w4gkFmAMwzCMSGIBxjAMw4gkFmAMwzCMSGIBxjAMw4gkFmAMwzCMSGIBxjAMw4gkFmAMwzCMSGIB\\\\\\n\",\n              \"xjAMw4gkFmAMwzCMSGIBxjAMw4gkFmAMwzCMSGIBxjAMw4gkFmAMwzCMSGIBxjAMw4gkFmAMwzCM\\\\\\n\",\n              
\"SGIBxjAMw4gkFmAMwzCMSGIBxjAMw4gkFmAMwzCMSGIBxjAMw4gkFmAMwzCMSGIBxjAMw4gkFmAM\\\\\\n\",\n              \"wzCMSGIBxjAMw4gkFmAMwzCMSGIBxjAMw4ik/wN6AKtyQAMhHgAAAABJRU5ErkJggg==\\\\\\n\",\n              \"\\\"\\n\",\n              \"  frames[8] = \\\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAbAAAAEgCAYAAADVKCZpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\\\\\\n\",\n              \"AAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0\\\\\\n\",\n              \"dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd1RUx+P38TdNRASUjhUCdlSwF8SC\\\\\\n\",\n              \"sYEFbNgVK7FGk68lsUQTo0Zjw5pgb6hRqYpYIooVxYIl9gICUpTelr3PH3nkFwLKrhFxdV7n7KHs\\\\\\n\",\n              \"zN7ZRe9nZ3ZmrpokSRKCIAiCoGLUS7sBgiAIgvAuRIAJgiAIKkkEmCAIgqCSRIAJgiAIKkkEmCAI\\\\\\n\",\n              \"gqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAIKkkE\\\\\\n\",\n              \"mCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAI\\\\\\n\",\n              \"KkkEmCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJ\\\\\\n\",\n              \"giAIKkkEmCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAIKkkEmCAIgqCSRIAJgiAIKkmztBtQUoyN\\\\\\n\",\n              \"jbG0tCztZgiCIKiUx48fk5CQUNrNUMgnG2CWlpaEh4eXdjMEQRBUSpMmTUq7CQoTQ4iCIAiCShIB\\\\\\n\",\n              \"JgiCIKgkEWCCIAiCShIBJgiCIKgkEWCCIAiCShIBJgiCIKgkEWCCIAiCShIBJgiCIKikT3YhsyAI\\\\\\n\",\n              \"woeSmZmZv3uFiYkJZcuWLeUWfR5ED0wQBOEdSJJEWFgY/dwHYGRsgn3T5tg1aYahsTEDBw/h4sWL\\\\\\n\",\n              \"SJJU2s38pIkemCAIgpIyMjLoP2Agl6/ewN7Znak7T6GjZwBAenISEcF/0N2tL21atWTHti2iR1ZC\\\\\\n\",\n              \"RA9MEARBCbm5uXRz6U5MJozd4E+r3iPywwtA18AQh36j8dwYyP0Xybj27kNeXl4ptvjTJQJMEARB\\\\\\n\",\n              \"CYsWLyY+U06PbxahqVXmjeW0tMviNvNXHsYmsWrVqg/Yws+HCDBBEAQF5ebm4rV2HR08pqGhUfwn\\\\\\n\",\n              \"MBqaWrQf/jUrVnkhl8s/QAs/LyLABEEQFOTv74+BaWXMv6itcJ0qdexQ0y5HSEhICbbs8yQCTBAE\\\\\\n\",\n              
\"QUHnL1ygeuM2StVRU1PDqnEbLly4UEKt+nyJABMEQVBQSmoa2jrllK6npaNLampaCbTo8yYCTBAE\\\\\\n\",\n              \"QUEVKxiQmZqidL3stGQqVDAovqCgFBFggiAICuro5MS9syFKLVCWy+X8dTaEDh06lGDLPk8iwARB\\\\\\n\",\n              \"EBTUoUMHNJHxJPKywnUeXA7DuKIBLVq0KMGWfZ5EgAmCIChITU2NYYMH479yDjmZGcWWz0pL5YT3\\\\\\n\",\n              \"EqZ/Mw01NbUP0MLPiwgwQRAEBfn5+eHltRrrKmbsmj2ajOSXbyyb9jKRXd+NxLmTE4MHD/6Arfx8\\\\\\n\",\n              \"iAATBEEohlwuZ+7cuYwfPx5/f39OnzpFj46OrPbohN+K74m5f4vc7CxysjKJvhtJwIrZeI3szABX\\\\\\n\",\n              \"F9Z6rRa9rxKiJn2i2yU3adKE8PDw0m6GIAgq7uXLlwwePJi0tDT27t2LmZlZ/n3R0dE4tm1LanoG\\\\\\n\",\n              \"yUlJAJhXrsIoj+GMHjUKc3Pz0mr2O1Olc6fYjV4QBOENbty4gaurK927d2fJkiVoaWkVuF9XV5f4\\\\\\n\",\n              \"Fy+IiYlBV1e3lFr5+RJDiIIgCEXYs2cPHTp04IcffmD58uWFwgsgKCiIdu3aifAqJaIHJgiC8A8y\\\\\\n\",\n              \"mYzp06dz8OBBQkJCsLOze2NZX19fevbs+QFbJ/yTCDBBEIT/78WLF/Tv3x9tbW3Cw8MxNDR8Y9ns\\\\\\n\",\n              \"7GyCg4PFpVJKkRhCFARBAC5dukTTpk1p3bo1gYGBbw0vgJMnT2Jra1tgUofwYYkemCAInz1vb29m\\\\\\n\",\n              \"zpzJhg0bcHV1VaiOGD4sfSLABEH4bGVnZzN58mROnTpFaGgotWsrdp0vuVyOn58fJ0+eLOEWCm8j\\\\\\n\",\n              \"AkwQhM9SdHQ0ffr0oVKlSly8eBE9PT2F64aHh6Ovr0/NmjVLsIVCccRnYIIgfHZCQ0Np2rQpPXr0\\\\\\n\",\n              \"YP/+/UqFF4jhw4+FCDBBED4bkiSxatUq+vXrx5YtW5g5c+Y7bfPk6+tLr169SqCFgjLEEKIgCCpL\\\\\\n\",\n              \"JpMREBDA71u28TwmFnV1daytLPEcM4q2bdsWCKeMjAzGjBnDzZs3OXfuHFZWVu90zAcPHpCYmEiz\\\\\\n\",\n              \"Zs3e07MQ3pXogQmCoJK2bNlKparVGT9zPtEVbCnnOBzt1kO4LVnQb9horGvWISQkBICHDx/SqlUr\\\\\\n\",\n              \"1NTUCAsLe+fwgr97X927d0ddXZw+S5vogQmCoHJ+XPgzv65eR32Pn6lgWbfAfca1GmPZvh8vIs/S\\\\\\n\",\n              \"x30gnqNHsXnzJr7//nsmTJig1JBhZmYme/fu5cSp06SkpmKgr8/5sNMsWrTofT8l4R2IABMEQaXs\\\\\\n\",\n              \"27ePZavW0mzab5StYFxkGTU1Nczqt0ZnwmqWLRnJmpXLGTNmjMLHyM7O5rvZc/jtd28qWtZFv05L\\\\\\n\",\n              \"tAyqk5uRyoscLUaMHE34lSvMnT27yD0ShQ9DXE5FEASVIUkSterVx6jTOEzrtVCozqMTPlhmPcLv\\\\\\n\",\n              
\"4H6Fyqenp9OxSzeeZ2lQo9ckdE2rFCqTFvuYe3+soIZpeYL8fdHW1lbqeXzMVOncKQZxBUFQGWfP\\\\\\n\",\n              \"niUpJR2TOopPoKjSwpnjx4/x/PnzYstKkkRf94G8QJ+GoxYVGV4A5c0taThuKQ+T8xgy3EPhtgjv\\\\\\n\",\n              \"lwgwQRBUxt79+zFt0gU1JSZQaJUrj0UDB/z9/Yste+nSJS6ER1B34HfFHkNdQ5N6Q+cSHHKMW7du\\\\\\n\",\n              \"Kdwe4f0RASYIgsqIjYtHu4KJ0vU09I2JjY0tttzyVV5YOLiirqnY9ACNMmWp3KoHq7zWKN0m4b8T\\\\\\n\",\n              \"kzgEQVAZZctqI8+RKV0vNyuDBQsWsHTpUszMzPJv5ubm+d8bGxtzYP8+2i8svqf2T5Vb9WTXoqGs\\\\\\n\",\n              \"XytC7EMTASYIgsqwrVuHML9QwE2pejnP73Hw4EEcHR2Ji4vLv8XGxhIXF8eVK1d49uwZctTQ1qug\\\\\\n\",\n              \"1GPrGJqRkZZKTk4OZcqUUaqu8N+IABMEQWUMHzaMefN/pIbrK8qUVyxokp/+RW5KPF27dkVTUxMD\\\\\\n\",\n              \"A4MiN+FNTk7GvFLRkzbeSpKQkNDQ0FC+rvCfiAAThI9QXl4ex48f59GjR+Tl5WFmZkaXLl3Q1dUt\\\\\\n\",\n              \"7aaVKhMTE1xcXLh+fBc1e35VbHlJkrjrvxENNYnw8HBatHjz1Hs9PT00NTVIT4hG17iywm1Kff4Q\\\\\\n\",\n              \"Y1NzEWClQEziEISPSEpKCgsXLsTS6gum/W8mh0+e41hYOEtXraVK1WpMnjyFp0+flnYzS9Wvvyzm\\\\\\n\",\n              \"ZUQIT88cems5SZK4e2gNhlIqC+bPx83NjTFjxpCYmFig3KtXr/D29sbJyYm8vDye/KnYerHXosMO\\\\\\n\",\n              \"Mnb0KKWfh/DfiR6YIHwkoqOj6dS5M5WtajFv5RZq2doVuD82+hmHdnnTtFlz/HwP0bx581Jq6Yd3\\\\\\n\",\n              \"69YtfvP25u69h8jyZDi0bsWfh38nLvIcNl2GUtGyXn5ZSS4n/s4lok/uwlAzl5BjwZiYmNCnTx/m\\\\\\n\",\n              \"zJlD3bp1+eGHHzA3N2fXrl0cPXoUJycnJk+eTI0aNWjh0BabrsMpo2tQbLuyUhJ5fimYcduWl+TT\\\\\\n\",\n              \"F95ABJggfASSk5Pp1KkzrTv3Ysi4qUXu12deuSrjvp1H/SYtceneg9BTf1KnTp1SaO2HExERwfhJ\\\\\\n\",\n              \"U7h1+w61OrhiWKstmhoaxMVFg9YFXt29zI3H19E2MELXpApIEsnR9zE00OO7yRMZOnQoOjo6AOjr\\\\\\n\",\n              \"6+Pu7k5MTAwTJ05ER0eHKVOmsGHDBipWrAj83WtrYFuX8ysm0GraejTLvnnINic9hWvrp/H1lMlU\\\\\\n\",\n              \"qlTpg7weQkEiwAThI7By5Uqq1qj7xvD6p9btO/N8zBS++fZ/BAYoN+VblZw8eRLX3n1pPHAKgyat\\\\\\n\",\n              \"RkOr4Aw/ux7DiI68yKm139HHuRPt2jqioaFB9erVsbe3z38d7927x44dO9ixYwdlypRhyJAhLFq0\\\\\\n\",\n              
\"iJMnTzJr1ixevXrFggUL0NXVZfz48aSnptC5dRNO/DoGy25jMKvfGnWN/ztVymUyYq+d4lHgBgb1\\\\\\n\",\n              \"6cX8efM+5Msi/IPYC1EQSplMJqNadUsWrtuNTR1bhepkZWbQt10DIq5cxtLSsmQbWAru3btHsxat\\\\\\n\",\n              \"aP/1MqrYvn3bqLSEWHxnD2aj10p69+4NQHx8PD4+PuzYsYPHjx/j7u7OkCFDaNSoUYE3CImJicyc\\\\\\n\",\n              \"ORN/f3/Mzc0xMTHhjz/+oHz58vj4+LB42XIeP32GiW0r1MrokpeRQsy1UAz09WndvAn9+/enR48e\\\\\\n\",\n              \"n9T0eVU6d4pJHIJQyoKDgzG1qKxweAGU1SlHp5792LRpUwm2rPQsWvILtTv1Lza8AMobm9Nm7A/M\\\\\\n\",\n              \"/H4Oe/bsoXv37tSoUYPz588zb948oqKiWLFiBY0bNy7UuzUyMuLHH3+kYsWKPH36lJycHJ49e4aa\\\\\\n\",\n              \"mhru7u5EXLrAiSOBtLM2JPXGCWIiTmLZ2BGr9r15pmnO9B+XUqlKNb6fPYfMzMySejmENxBDiIJQ\\\\\\n\",\n              \"yh48eIBNnQZK17OpXZ/718+WQItKV3JyMnv37qXvcj+F61Rt2IrjXt/x66+/MnHiRHbt2oWenl6x\\\\\\n\",\n              \"9e7du0fXrl0ZOHAgc+bMYf369bRt2xYPDw9mz55N+fLl8fXz52BgMM2GTseqSbsCw4kAiU/v8cfe\\\\\\n\",\n              \"NRwObsex4CP5n6cJJU/0wAShlMlkMjQU3HvvnzS1tJDJlN9W6WN3+PBhqtRrQnlDU4XrqKmp0aDb\\\\\\n\",\n              \"IOwbN2XIkCEKhdf58+dxdHRk+vTpzJ8/H01NTSZMmEBkZCQxMTHUrVuX0WPGsH7TNnr+tBPr5h0L\\\\\\n\",\n              \"hReAUbUadJy2HMxq4tKz1yf5N/lYiQAThFJmZmZGXLTya7tiop5gZqb4SV5VxMXFUc7IQul6eqaV\\\\\\n\",\n              \"iYmLU6isr68v3bt35/fff2f06NEF7jMzM2Pbtm389ttvbN22nY7frqLcGy6c+Zqamhoth88gKiEF\\\\\\n\",\n              \"Pz/Fe47CfyOGEAWhlDk7O/PV+AkkxsdhZGKmUB25XM6Rg7vx2bWjhFv34WlpaSHJ85SuJ5fJCAk5\\\\\\n\",\n              \"io2NDebm5m+8BQcH4+XlRVBQEE2bNn3j40VHR/OFXUsMq3yh0PHVNTSo03Uwv67yws1Nub0ahXcj\\\\\\n\",\n              \"AkwQSlmFChXo26cP/j5bGT7hfwrVuXj6OPrldd+6NZKqsra2JunRRqXrJT2+zQRPT8aMHkVsbGyB\\\\\\n\",\n              \"27lz53j+/DkREREkJCSgpqZG165dMTc3x8LCosigW+m1DpsuI5Rre4sv2b55IVFRUVSp8g77KgpK\\\\\\n\",\n              \"EQEmCB+B6dP/R8tWrWnQpCWNWrR5a9nY6GcsnDEeTXU1zpw5Q5s2by+vajp27Eh2ShIvHtzE1Lpe\\\\\\n\",\n              \"8RWA3OxM7p7yZd/yy1hZWVGjRo0C9+fk5ODh4YGVlRWRkZFUrFiRxMTEQkEXHR3N5cuXiY2N5e69\\\\\\n\",\n              
\"uzQcpVjv6zUNrTIYmlclOjpaBNgHIAJMED4CNWrUYN9eH/r268/IKd/RpZc7Wv9aWyRJEpfCTvLL\\\\\\n\",\n              \"d5OZP28uVlZW9O/fHw8PD+bOnYuWllYptf790tDQYLznWHb5bcJpytJiF3YD3ArZT7NmzbGysip0\\\\\\n\",\n              \"X3JyMm5ubujr63P8+HHKlSsHgKmpKaampjRoUPQM0CrVrUCBYxflE11e+9ERASYIH4n27dsTcjSY\\\\\\n\",\n              \"qdO+YdPKhXRxG4hNbVvUNTR4/uwJwQd2Ula7DOvWetGrVy/g762WRowYQevWrdm5c2ehnoeqmjxp\\\\\\n\",\n              \"Erv2+BC+dw1N+o1/a4g9Cv+TG4d+I+z0qUL3RUVF0a1bN9q0acOqVauU2jG+UqXKvIx6SAWL6grX\\\\\\n\",\n              \"ycvNISn2GZUrK76bvfDuxCxEQfiI2Nvbc/LEcU6HnsJMT5PrZ45w6bgvUmosWzd7czPyRn54wd8z\\\\\\n\",\n              \"5gIDAxk6dCitWrVi06ZNn8S7fz09PY4fDSbm4hH8fxpH7N1rhZ5XclwU57b9wrmN8wgK8Cu0L2Rk\\\\\\n\",\n              \"ZCStWrVi8ODBeHl5KX25kzEjh3PvhHI70z+4cAxb2/pUrVpVqXrCuxE9MEH4CNWuXZulv/yiUFk1\\\\\\n\",\n              \"NTUmTJhA+/btGThwIEFBQWzYsAEjI6MSbmXJMjExITcrk+wntzjrNR11HT2MLGuBuibp8VHEPbzN\\\\\\n\",\n              \"8GFD+WPFRapVq1ag7smTJ+nfvz8rVqxg4MCBSh87JSWFO3fu8OBKGK+eP6ZCJcti60hyOXeO7GTp\\\\\\n\",\n              \"vJlKH094N6IHJgifiHr16nHx4kWqV6+OnZ0dx48fL+0m/Sdr1qwhISEB30OHePr4IVvWrWRCv66M\\\\\\n\",\n              \"69WeJd9/Q0zUM1YuX14ovHbv3k3//v3x8fFROrxkMhkbN26kVq1aJCQk8P1333F0yUQyk5PeWk+S\\\\\\n\",\n              \"JM5tXYKZfll69uyp9HMV3o3ogQnCJ0RbW5tly5bRpUsXhg0bxsCBA1mwYAHa2tql3TSlJCcnM2vW\\\\\\n\",\n              \"LL788ktat24NQIcOHd5aR5Ikli5dyurVqzl+/Dj169dX6pghISFMnToVIyMjAgICaNy4MQDZOTn8\\\\\\n\",\n              \"/v0gmg+bTnX7Nqj/ayjyZfQjLu/1okxaHMFHgz+ZyTQqQfpENW7cuLSbIAilKj4+XurZs6dkZ2cn\\\\\\n\",\n              \"3bp1q7Sbo5TRo0dL2traUnR0tELlZTKZNGHCBMnW1lZ69uyZUse6efOm1K1bN8nGxkY6ePCgJJfL\\\\\\n\",\n              \"C5XZt2+fZGvXSDKuVE1q0nu05DBiutRi0BTJpnFrqaKRifTt9BlSWlqaUsf9WKnSuVMMIQrCJ8rY\\\\\\n\",\n              \"2JiDBw/y1Vdf4ejoyLp161RigsfDhw/ZsmUL33zzjUIXiszMzKRPnz7cvHmTM2fOKLz+Kj4+nvHj\\\\\\n\",\n              \"x9O2bVs6duzIzZs36dWrV5EzHvv06cONiMsEHdpPt3oWNNDNwLGKNvOnjScm+hlLFv2Mru6bL34p\\\\\\n\",\n              
\"lAxxPTBB+Az89ddfDBo0CAsLC7y9vTE1/Xj3UHRwcODOnTtER0cXO/SZkJBAjx49sLKyYtOmTQoN\\\\\\n\",\n              \"lWZnZ7N69WoWL16cvwu9qk94eZ9U6dwpemCC8BmoVasWZ8+exdbWFjs7O44cOVLaTSrSiRMnuHDh\\\\\\n\",\n              \"Ar/99luxYfTw4UNat25N27Zt2b59e7HlJUli//791KlTh9OnT3PmzBlWrlwpwkuVle4IZslRpXFc\\\\\\n\",\n              \"QfiQTp48KVWtWlWaNGmSlJGRUdrNyZeXlydVrlxZsre3L7bspUuXJAsLC8nLy0uhx75w4YLUunVr\\\\\\n\",\n              \"qWHDhtKxY8f+a1M/aap07hQ9MEH4zLRr145r164RGxtLs2bNuH79+kfx2djq1auJi4tjz549by0X\\\\\\n\",\n              \"FBRE165dWbt2LePHj39r2WfPnjF48GB69eqFh4cHly9fxsnJ6X02WyhFIsAE4TP06tUrqlW35HlM\\\\\\n\",\n              \"LHZ2dmhpaVHd0op5P/zA8+fPP3h70tPTmTVrFu7u7tSsWfON5X7//Xc8PDzw8/MrsCPJv6WlpTF7\\\\\\n\",\n              \"9mzs7Oz44osvuHv3Lh4eHkrvxiF83ESACcJnRCaT4fnVeBo1bsL9uGTm//4Hf4Q/Ze/FR0z95Xcu\\\\\\n\",\n              \"3XpInbr1mD1nzgftlU2aNAlJkli7dm2R90uSxNy5c/n5558JDQ2lZcuWRZbLy8vD29ubmjVr8vjx\\\\\\n\",\n              \"Y65evcr8+fMpX758STZfKCViIbMgfCbkcjkDBg7i8fMXrAs4j66efoH7v6hty7jvl+Du+S0/Tx5G\\\\\\n\",\n              \"0suXeK1apdBu8P/FkydP2Lp1K0uXLkVPT6/Q/bm5uYwZM4bIyEjOnj2LmVnRF/08fvw406ZNQ09P\\\\\\n\",\n              \"D19f37derFIVyGQyNDQ0Svz1V2WiByYIn4kVK1Zw58ETZqzYUii8/qmCkQlz1u3hcPCxYj+Peh/c\\\\\\n\",\n              \"3d2xsLBg0qRJhe5LTU3FxcWF+Ph4/vzzzyLD66+//qJHjx6MGTOG2bNnExoaqpLhJUkSFy5cYPCQ\\\\\\n\",\n              \"oejpG6CtrY2WlhZ16tmyevVqUlJSSruJHx0RYILwGcjLy2P5ipWM+PYHymiXLba8rp4+gyfNYsFP\\\\\\n\",\n              \"C0u0XSdOnODixYvs3LkTdfWCp6OYmBjatm1L9erVOXToUKGFwomJiUyaNAkHBwccHR25desWvXv3\\\\\\n\",\n              \"VskeS1JSEu2dOtK7nzuaplas8j3DvitR7LrwkAFf/8C+oGNUq27Jrl27SrupHxURYILwGThy5Ah6\\\\\\n\",\n              \"FY2oUc9O4TqNHJx4HhPDpEmTyMvLe+9tkiSJwYMH5wfQP92+fZuWLVvi5ubGhg0b0NT8v087cnJy\\\\\\n\",\n              \"WL58ObVr10Yul3Pr1i2++eYbldvv8bVXr17RxrEtFavVZKVvGD2HeVLByAQ1NTW0tMpQv1lrvl68\\\\\\n\",\n              \"gXm//8HUb/+Ht7d3aTf5oyECTBA+A8dPnKBJu65K1dHQ0KBdNzf8/f3p2LEjz549e69tWr58OfHx\\\\\\n\",\n              
\"8fj4+BT4/ZkzZ2jXrh3z5s3j+++/z+9RSZLEwYMHqVevHseOHSM0NBQvLy9MTEzea7s+tNFjxvJF\\\\\\n\",\n              \"w2YM+XpOoV7oP1WvUYfv1+3h2+kzuHnz5gds4cdLBJggfAaSk1Mor2+gdD1d/QoMGDCAL7/8kiZN\\\\\\n\",\n              \"mrBv37730p7MzEy+++47PD09MTc3z//9/v37cXNzY/v27QwfPjz/95cvX6Zdu3bMmTOHNWvWEBgY\\\\\\n\",\n              \"WOgClqooKiqKoyFHGTBxlkJDn5UtbejUdxirVnt9gNZ9/ESACcJnQK98eTLT05Sul5WRhoGBAbNm\\\\\\n\",\n              \"zSIgIIDvvvuOESNGkJqa+p/aM3bsWLS0tFi2bFn+71asWMGUKVM4evQonTp1AiA6Oprhw4fj4uLC\\\\\\n\",\n              \"4MGDiYiIyL/vU7B+wwYcu7mhU07xjYA7ug1iz549YlIHIsAE4bPQqlVLrp49oVQdSZIIDw2hRYsW\\\\\\n\",\n              \"ADRt2pQrV66gqamJvb0958+ff6e2PH78mJ07d+Ll5YWWlhZyuZypU6eyceNGwsLCsLOzIz09nXnz\\\\\\n\",\n              \"5tGgQQMqVarEX3/9xejRowt8FvYpOH7iJE2VHNo1MrOgmnVNrly5UkKtUh0iwAThM9CrVy+iH9/n\\\\\\n\",\n              \"6YO/FK5z41IYSfEvmDVrFj4+PuTm5lK+fHl+++03lixZQs+ePVmwYAEymaxQ3aysLHbs2EHbDl9S\\\\\\n\",\n              \"u14D6jdsTD/3gZw5c4bevXtjbW3N0KFDycrKYsCAAYSHh3PmzBmqVq3Kli1bqFWrFnfv3uXKlSss\\\\\\n\",\n              \"XLgQff03T/tXNTk5OURFRREeHk5cXBzl3uG5ldc3ED0wxEJmQfgslClThnFjx7Jj5U9MX7652C2V\\\\\\n\",\n              \"crKz2L1mEYt+XoiFhQWrVq1i2rRpjBs3jjFjxuDm5kbz5s0ZNmwYwcHBbN++HSsrKyRJ4pely/jx\\\\\\n\",\n              \"x59Q0zEhs6wVamVqQo6cu2FRBAT1IjM9hc3eG0lKSqJXr16Ym5tz9OhRzp8/z9SpUylbtiz79+/P\\\\\\n\",\n              \"7/mpgry8PBISEoiLiyM2NrbA7d+/S0lJwdTUFDMzMzIyMsh6h6HdjPTUIhd9f27E9cAE4TORk5ND\\\\\\n\",\n              \"5y5d0ShvyFfzfkVLq0yR5TIz0vl5ynDiox5z5/at/G2Yrl+/zurVq9m/fz89evRg4sSJNGrUiOXL\\\\\\n\",\n              \"l7No0SJ+/fVXws6eZ+feQ+SYtke9bIVCjy1JEvLkR2jGn8HYsAJ9+vRh7NixzJgxg4iICBYvXky/\\\\\\n\",\n              \"fv0+irVckiTx6tWrt4bR658TEhKoUKEC5ubm+TczM7MifzYyMkJdXR1JkhgydCgplGPEtz8o3K7k\\\\\\n\",\n              \"pAQm9XTg8aOHGBoavvfnrUrnThFggvAZycjIYMDAQVy5eo1OfYbi1MsdPYOKALxMeEHIgZ2E/LGD\\\\\\n\",\n              \"Lp07IcvNITo6Gn9//wJ7CSYlJeHt7c2aNWuoVKkSEydOxMbGhu7de5CYmoOGtRtqGm9fkyVPj4Wn\\\\\\n\",\n              
\"h3Hv34fAwEC+/fZbJk+eTNmyxS+y/q/S0tKKDKN//y4uLg4dHZ23htHrm4mJCVpaWgodPyEhgR07\\\\\\n\",\n              \"duDt7U1ycjIvk1P4LSQC7bI6CtU/uNkLtZfRbN2y+b+8DG+kSudOEWCC8Jl5vWXRaq81HDp4kLLl\\\\\\n\",\n              \"yiFJErk5Obi7uzNh/Fc0bNiQvLw8xo4dy+3btwkKCsLAoOA0fJlMhr+/P6tXr+bOnTu8Sk4jr0pX\\\\\\n\",\n              \"1Mspti5LFnMB64oZhJ46+Z+vEJ2dnf3G3tG/f5bL5QXC500BZWZmho6OYqFSnLy8PEJCQvD29iYk\\\\\\n\",\n              \"JITu3bszcuRIHB0dcenRE73KNXAf/79iHyc+JorvhnbnSFAAjRs3fi9t+zdVOneKABOEz1h2djZJ\\\\\\n\",\n              \"SUmoq6tjaGhYqBchl8uZMGEC4eHhBAcHU7FixSIfZ+XKlfxv9iLUrXsrfGwpNx31R/uJjYkucpKG\\\\\\n\",\n              \"TCYjPj5eod5Senp6fggV11sqX778BxuifPjwIZs3b2bLli2Ym5szcuRIBgwYUODNQFxcHM1btMSx\\\\\\n\",\n              \"5wB6jZjwxrbFRT1l4YSBTB7vybRp00qszap07hSTOAThM6atrY2FhcUb71dXV2fNmjVMnToVJycn\\\\\\n\",\n              \"jh49irGxcaFyp8POIzeopdS0ZjUtXdTKmTJx4kTMzMwKhVNSUhKGhoaFekvVqlWjWbNmBQKqYsWK\\\\\\n\",\n              \"b93F4kPKzMzkwIEDeHt7c+PGDQYNGkRgYCANGjQosryZmRlhZ07j7NKdSyeCcOozDIcuPfOHFJ/c\\\\\\n\",\n              \"u83RvVsIO+rHj/PnM3HixA/5dD5qIsAEQXgrNTU1fv31V2bNmkX79u05duxYoV3hY+PiUNNSfDHu\\\\\\n\",\n              \"a7mUISoqirp161KvXr0CvSVjY2OVWfclSRJXrlzB29sbHx8fmjZtiqenJz169FBoj8bKlStzOfwS\\\\\\n\",\n              \"gYGBLFm6jA0//o9y5XSRyWTo6+szduwYfl9+k0qVKn2AZ6M6VONfhyAIpUpNTY2FCxeira1Nu3bt\\\\\\n\",\n              \"OH78eIGTaVntsiApv+GvPDeHU6dOce/ePSwsLLCwsKBSpUpFfm9iYvLRXVE5MTGRnTt3smnTJpKT\\\\\\n\",\n              \"kxkxYgQRERFUq1ZNqceJjo5mw8aNbPztN8po62BTux7pqalkpKUwdOgQRgwfLsKrCCLABEFQiJqa\\\\\\n\",\n              \"GvPmzUNbW5u2bdty4sQJTExMOHnyJEkJL5DSc6CCjcKPJ0kSOmppBJ44QbVq1YiJieH58+fExMQQ\\\\\\n\",\n              \"ExNDWFhY/vcxMTG8evUKE7yc0cMAACAASURBVBOTYoPOzMysRHtucrmc48eP4+3tzZEjR+jWrRvL\\\\\\n\",\n              \"li2jffv27zSM6evrywgPDxw692L2mt1Us6mdf1/Uo/sc3b8Nu0aN8Fq1ikGDBr3Pp6LyRIAJgqCU\\\\\\n\",\n              \"YcOGERERQe3atVFXV8fe3p4vv3Ti1uq1SPJWqKkrdlqRp0WRnvqKffv2MWHCBFq2bPnW8jk5OcTF\\\\\\n\",\n              
\"xRUKukuXLuV///z5cxITEzEyMio26MzNzSlTpui1cEV5/PgxW7ZsYfPmzRgbG+Ph4cG6deveOLFF\\\\\\n\",\n              \"EYGBgYwcPYbv1+wq8lI3Vaxs8Ph2Pk69BjB1/CA0NTXp37//Ox/vUyNmIQqC8FZyuZzw8HACAgII\\\\\\n\",\n              \"DAzk8ePHdOnSBS0tLY4dO8bJkyepUaMGbds7cfZuFpqm9sU+piTJUX8aiIF2Durq6mRkZNC0aVMm\\\\\\n\",\n              \"TJiAs7PzfxoqlMlkvHjxolDQ/fP7mJgY4uLiqFChQqFg++fPhoaGXLx4ke3btxMREcGAAQMYOXIk\\\\\\n\",\n              \"dnaKX1ftTdLT06larTozV26jdsPip8Q/+usmc0b34eGD+yWygPk1VTp3ih6YIAiFpKamEhISQkBA\\\\\\n\",\n              \"AEFBQRgZGeHi4sKKFSto2bJl/hDdb7/9Rvv27Tl06BDZmenIYi6jpqmLhmHNNz62JMnRiDuDva0V\\\\\\n\",\n              \"x0OCOXDgALNnzyY6OppZs2YxadIkPD09GTlyZJEzHouSnZ3NgQMHWL1uAw8fPEAmy8XYxJQhA90Z\\\\\\n\",\n              \"PWpUkevMXm//9O9wu337NgcOHOD27dvEx8cjSRLlypWjatWq3Llzh19//bXIHp2FhUWhq0a/zc6d\\\\\\n\",\n              \"O6lj30yh8AKwqlWPJm2c2LJlC1OnTlX4OJ8y0QMTBAGABw8eEBAQQEBAABcuXKBVq1a4uLjg7OyM\\\\\\n\",\n              \"lZXVG+stW7aM6dOno66ujp6eHtm5eVC+Orl6tQssapYkOfLkR5RNu4VdXWsCA3zz9/PLzc1l8+bN\\\\\\n\",\n              \"zJ8/nxo1amBgYMCff/5Jr169mDBhAk2aNHnj8ffv389Yz/EYVLGmers+GFnVQU1dk/SE5zw548fj\\\\\\n\",\n              \"i8cY6eHB8mVL39qze/nyJbt27cLb25vExERGjBjB8OHDqVatGklJScX26J4/f56/LOFtQ5eVKlVC\\\\\\n\",\n              \"T0+PBnb29PGcgX2rdgr/je5cC2f9vCncv3e3xNayqdK5UwSYIHymcnNzCQsLyx8afPXqFc7Ozri4\\\\\\n\",\n              \"uNCxY8cC20e9SWhoKD169CArKwuZTMbo0aOZP38+69dvYOVqL3LlmsjUyiLPy4WcZGSyXDb9tp4B\\\\\\n\",\n              \"AwYUOdEiMzOTtWvXsnjxYtq1a4eVlRU+Pj6YmZkxfvx4+vXrV2C7KW/vTXwz8ztaT1qGsbVtkW3M\\\\\\n\",\n              \"Sn3F+XUzsbOy4I99PgVCTC6Xc/LkSTZt2kRgYCBdunTBw8MDJycnpYcxX++dqEjQ/f1cs/gj/Ama\\\\\\n\",\n              \"Cm5B9foY/Zt/QUJCvFK9PWWo0rlTBJggfEYSEhI4cuQIAQEBHD16FGtra1xcXHBxccHe3l6pWXQb\\\\\\n\",\n              \"N25k2rRp6OjoYG1tjZ6eHteuXcPf359mzZohk8m4ePEiO3fu5N69e6xcuZKhQ4fy448/0rlz57c+\\\\\\n\",\n              \"dkpKCsuXL2fVqlX069eP5s2bs3v3bq5evcrIkSMZN24cL168wKlzVzrM8sagkuVbHy8vN4fQpeMZ\\\\\\n\",\n              
\"0ceFeXNm8+zZs/wJGXp6eowcOZJBgwZhZGSk8PN/V5IkER8fT+XKVThw5ZnS9Ye2rcvdv+5gYqLY\\\\\\n\",\n              \"ll3KUqVzp/gMTBA+YZIkERkZmT80GBkZiZOTEy4uLixfvvytu3C8SW5uLlOmTGHPnj2YmZlRv359\\\\\\n\",\n              \"JEnijz/+4PDhw7i4uHDw4EFat25Nq1atePToES9fvqROnToMHz6cLVu2FBtg+vr6zJ07l/Hjx7No\\\\\\n\",\n              \"0SKmTp2a37vbvXs39vb2lNMzoFa3YcWGF4CGVhkaDfuOZQuGcfbMaS5fvkz//v3Zt28fjRo1KvGt\\\\\\n\",\n              \"pSRJIioqiqtXr3Lt2jUiIiKQkEhPTUFXT/HrgeXmZJOellpoX8rPlQgwQfjEZGZmcvLkSQIDAwkI\\\\\\n\",\n              \"CEBDQwMXFxfmzp1L27ZtFdoZ4k0SExPp3bs39+/fx8bGhmbNmnHlyhWOHTuWf5wdO3bg6urK3r17\\\\\\n\",\n              \"adeuHTo6OmRmZgLg7u7OrFmzePXqFRUqFL7cyr8ZGxuzdOlSpkyZwo8//oizszNTpkzhzz//pFmL\\\\\\n\",\n              \"ljg49lK47QaVLNE1t8TGxgZfX9/3tlHvv+Xk5HDnzh2uXr2af7t27Vr+lawbNmxInz59eJWSyukj\\\\\\n\",\n              \"h+jSd6jCj302JBAHx7ZKTf//lIkAE4RPQHR0dH5g/fnnn9jb2+Pi4sKRI0eoXbv2e+lhREZG0r17\\\\\\n\",\n              \"d9TV1WnYsCFt27Zl8+bNhIWFFQiDTp064ePjQ79+/di5cydly5bNDzAjI6P8+8eOHavwsatUqcL6\\\\\\n\",\n              \"9ev55ptvmDt3LkuXLsW8TmO0yyt3NWNrx568SLr73sLr5cuXXLt2jWvXruWH1Z07d7C0tMTOzg47\\\\\\n\",\n              \"Ozv+97//YWdnh7m5eYG6JiYmeE6cQuc+QxT++xzdv5X53894L23/FIgAEwQVJJfLuXTpUv7Q4NOn\\\\\\n\",\n              \"T+natSsDBw5k69at/2lxbVH8/Pzw8PDAwMAAR0dHOnbsyIwZMzhz5kyRa5Lat2/PgQMHcHNz4+uv\\\\\\n\",\n              \"vyYrKyv/vuHDh/Pjjz8qFWCv2djYsHPnTr7//nv2nrmhdP2y+oYkPX6pdD1Jknj8+HGBHtXVq1dJ\\\\\\n\",\n              \"TEykQYMGNGzYkJYtW+Lp6YmtrS3lypUr9jE7dOiAbtkyHNq6Ftfh44stf2TfNjJTXuLi4qJ0+z9V\\\\\\n\",\n              \"IsAEQUWkpKRw9OhRAgMDCQoKwsTEBBcXF1avXk2LFi1KZPskSZJYuHAhq1evpnz58vTv35+OHTvi\\\\\\n\",\n              \"7u7OsWPHqF69+hvrOjg44O/vT9euXQtMjujcuTOjRo3i7NmzJCUlkZqaip6eHs2aNVP4umDVqlVD\\\\\\n\",\n              \"LS9C6ecjy8lCX/ft4ZKVlcXNmzcL9KquXbuGnp5efq9q0KBB/PLLL1hbW7/zLvjq6ur4+/nSsnVr\\\\\\n\",\n              \"8mQyXIePR6OIv2FeXh5Bezbht20tYadPq8wGxx+CeCUE4SN2//79/F7WxYsXad26df7nWZaWliV6\\\\\\n\",\n              
\"7IyMDDw8PIiMjERdXZ2vv/6adu3a8eWXX+Lj4/PGy4P8U/PmzVm7di1DhgzBx8eH/v37c/nyZQwM\\\\\\n\",\n              \"jWnn9CVmNexQL6uLPDudpIeRdO7ShW+nTim0rVRSUhJhYWGEhoYSGhrK9evXQasszfNkqGsofhpL\\\\\\n\",\n              \"uH2RL9s2yv85Pj4+P6hef339+d7rsOrZsycNGzZUeFG1MqpWrcqFc+fo5z6Acfu386XbIFp06Iqu\\\\\\n\",\n              \"ngEZ6alc+vMoR//YjoW5GefCwt66Hu9zJAJMED4iubm5nDlzJj+0UlNTcXZ2ZuLEiTg5OSm0Nut9\\\\\\n\",\n              \"ePbsGb169cLY2Jj4+HhWrFhBq1atcHBwwMvLi/bt2yv8WI0aNcLCwoKvv/6aQ76+BB09jrFDfxq7\\\\\\n\",\n              \"L0GznF5+uSqZqVwPP0KX7q6MHzuKhg3qExoayunTp3n8+DEtWrSgcePG+bvS58hyiboSSrWmHRRq\\\\\\n\",\n              \"R056Cg/PBRNftzLOzs5cu3aNtLQ0GjZsiJ2dHe3bt+frr7+mbt26BdaalbTKlSsTdjqUiIgIVnut\\\\\\n\",\n              \"4df/jSY1NYXy5fVo3aoVh/7YR9OmTT9Ye1SJCDBBKGXx8fEcPnyYwMBAjh49So0aNXBxcWH37t3Y\\\\\\n\",\n              \"2dl98As1nj17lj59+vDll18SFBTErl27aNSoEQ4ODnzzzTf069dPqccrW7Yscrkcj5EjWbp6Aw0m\\\\\\n\",\n              \"rkO7glmhcpo6eli06Yth/bb86uVJjapmDB82jKFDh/L8+XMWLFjA4sWLMTQ0ZPLkydSoUYNZPy2l\\\\\\n\",\n              \"UsNWaJYpPnCuHdiIjk45KlSowOjRo7Gzs6N69eqoqamRmJjI5s2bmTz1G16+fEm5cuVo3MieCV95\\\\\\n\",\n              \"Uq9ePaWe77uyt7dnk/fvH+RYnwqxkFkQPjBJkrhx40Z+L+vmzZt07NgRFxcXunbtWmi22oe0efNm\\\\\\n\",\n              \"pk+fjru7O/v27cPf35+6devSsWNHHBwcWLJkidKPGR8fT61atciW5VF3wm+UNSr+ulZZSTHcXDWK\\\\\\n\",\n              \"Qe792LdvH6mpqTRp0oSFCxfi5OQE/D2RpaerG9efJeAwcSlaZYv+bEuSJG4GbOHpSR8sTE2oVq0a\\\\\\n\",\n              \"mzZtwsTEhOzsbCZPmcquXTup18oJ2/bd0DUwIjc7kwdXznExyIe6deqwbfMmvvjiC6WfuypSpXOn\\\\\\n\",\n              \"CDBB+AAyMzM5ceJE/lR3LS2t/B0wHB0d/9PaLEVlZWVx9uxZEhMTKVu2LHXq1MHG5u/rd8lkMv73\\\\\\n\",\n              \"v//h5+dHt27dCAgIIDg4GCsrK3r37o2enh7btm17p97gy5cvMTU1w8TeCev+3ylc769dC0i5Gcqo\\\\\\n\",\n              \"USOZM2cOCQkJBT6runr1Krm5uZTVLU96tozaXYZg7didMrp/T62X58mIijjNoxM+pMc+QZadSY8e\\\\\\n\",\n              \"PdDR0cHPz48NGzawZOky0tTK0uvrBZSvUHg2ZZ4sl7CD2zm7z5tTJ09Qp04dpZ+/qlGlc6cYQiwF\\\\\\n\",\n              
\"aWlpJCUloa2tjZGRkZhV9ImKiorKD6xTp07RqFEjXFxcOHr0KLVq1Srx3R9ee/r0KatWe7Fp82ZM\\\\\\n\",\n              \"q1RH39gMWU42j25dp359W0Z7jGDbtm3A39PfQ0NDOXv2LGZmZowbN47MzEz27duncHhlZmZy4cKF\\\\\\n\",\n              \"/AkXFy5cQK6ugWmr3kq1u3Jbd9LuXeDixYtYWVlRqVIl7OzsaNiwIRMnTsTOzo7KlSsDcPr0aVas\\\\\\n\",\n              \"XoPvVGcMjC1Q19QkJSEW6y+smT91An379iUrK4ulS5eybt062rZty6AhQ7Fs0JQhc1eh/oZ9DzU0\\\\\\n\",\n              \"tXDs60E5PQM6denK7ZuRH+xzSKF4ogf2gchkMvz8/FizZi3nzp3F0NCIrOwsNNTV8fDwYNy4cW+d\\\\\\n\",\n              \"kix8/PLy8gqszYqKiqJr1644OzvTuXPn9742SxEnTpygT7/+NOrUk1Y9B2FSxTL/vtycbK79eYQj\\\\\\n\",\n              \"m1dSsXw56tauxcuXLzl06BAGBgbMnz8fX19f/vzzz/xd44uSnJycP0Pw9OnTXL16lQYNGtCmTRsc\\\\\\n\",\n              \"HR1p1qwZZuYWtF7yp1KhLUkS52Z0IORoME2bNn1rG15LSkoiKiqK3NxcTE1NqVq1aqEycXFxzJw5\\\\\\n\",\n              \"k5279zDv4AW0yym2Ke6OOV8xdoDrO61fUyUf27nzbcRb/w/g6dOndOvmjK6eHiNGebLVxzd/yOj+\\\\\\n\",\n              \"vb/Y6r0R+0aNmDVzJtOmTftg78yF/y45ObnA2iwzMzNcXFxYs2YNLVq0+E8XZvyvLl68SO++/Rgy\\\\\\n\",\n              \"bzU1GrUodL9WGW2adOpJg7ad+X3GGCKuXuPWzUh0dHT4/fff2bp1K2FhYYWCIy4ujtOnT+cH1v37\\\\\\n\",\n              \"92natCmOjo4sWLCA5s2b5++UnpGRwY0bN1BTV1f637Wamhqoa7BgwQJMTU0xNDRU+PY2ZmZmmFtU\\\\\\n\",\n              \"ooVzX4XDC6Bp94GsWP0LY8aMEf9HPxKiB1bCYmJiaNmqFR5jvsJzwtdvLBcd9YwBvV0YOmQwM2fO\\\\\\n\",\n              \"/IAtFJR17969/F7WpUuXcHBwyL9u1sfSi5YkiXr1G9Ki/1jsO3QrtnxOdharxrqxdsVS5HI5Y8aM\\\\\\n\",\n              \"ITQ0FBsbG548eZIfVqGhobx48YLWrVvj6OhImzZtqFGjBk+fPuX+/fvcv3+fBw8e5H+flJSEpaUl\\\\\\n\",\n              \"f929R7N5fmiVU3zrJ1lmGpd+6EGAvx+vXr0iKSmp2JuWlhaGhoYYGRm9NeCmTPsWtxnLqFan+LVs\\\\\\n\",\n              \"r8nlcn7u50D4hXOf9Hqsj+XcqQjRAythI0eOot+AoW8NL4DKVaqy99BhOrdriZOTE82aNftALRSK\\\\\\n\",\n              \"k5OTU2BtVnp6Os7OzkyePBknJ6cSuy7Tf3H27FlSMjKxa99VofJltMvSpp8HC35ayF+3bzF27Fjm\\\\\\n\",\n              \"zZtHaGgoubm5tGjRAmtra4YNG0Z2djaPHj3i0KFDLF26lPT0dGxsbLCxscHa2prmzZszaNAgbGxs\\\\\\n\",\n              
\"0NHRYfHixdx9sIYXF4Oo3M5d4ecQf/kILj160qVLF4XKS5JEenr6G8MtISGBu3fv/v19/AsMjBXb\\\\\\n\",\n              \"9eM1dXV1KhibkpiY+EkHmCoRAVaC7t+/z6XwS2zY6qNQeXOLSowZPxkvrzVs2yYCrDS9XpsVEBBA\\\\\\n\",\n              \"SEgINWvWxMXFBR8fH+zs7D76IaQ1a9fRvLu7Uu20d3Jh369zKaOpweHDh9HW1qZChQo8e/aMkydP\\\\\\n\",\n              \"EhUVhbW1NTY2NrRv357Ro0djY2ODubl5geNIkkRwcDB9+/bl8uXLaGtro1NGg8SLvlRy7IeaApNB\\\\\\n\",\n              \"JLmcl5f8mLpri8LtV1NTo3z58pQvX55q1aoVuj8nJ4e7d+8SGRnJiVOhyHJzFH7s13Jzsj/IjFFB\\\\\\n\",\n              \"MSLAStC6desYMGiYUqv6BwwaRnO7WiQkJJTI1jWqSJKk/D3zdHV1MTY2fu8BIkkS169fz+9l3b59\\\\\\n\",\n              \"O39t1urVqzEzK7zw9mN2884dOo5VbtZfGe2ymFe3pkYVMxwcHPJ7VTY2NhgZGRX7micmJjJ//ny2\\\\\\n\",\n              \"bt1Kamoqtra27N69Gzc3N5o0aUK2TOKh70q+6DXlrY8lSRLRwRuxqV6ZNm3aKPUc4O+hvkePHnHj\\\\\\n\",\n              \"xg0iIyPzbw8ePMDS0hJbW1tMTU15HHkFo0qFg+5NUl8m8Co+7qMZJhZEgJWoCxcu8u13PyhVx9DI\\\\\\n\",\n              \"CNsGDbl27Vr+gs3PVWpqKjt27MBrzVqePX2KnoEB6WlpGBkZ8ZXnOEaMGFHsB/Zvk5GRwYkTJwgI\\\\\\n\",\n              \"CCAwMBBtbW1cXFxYsGABjo6OKn3NpZzsbDSUuFT9a4bGJkydOpVu3Yr/3Az+DpuAgADmzZvH1atX\\\\\\n\",\n              \"0dXVZfDgwcydO7dA6EuSRHxsNBpJiTw99CuVu4xBU6fwrEJZZirPj/6Odvwdgs6EFht0MTExhYLq\\\\\\n\",\n              \"9u3bGBsbY2tri62tLS4uLsyYMYPatWvnv5n09fXlf3N/pHEnxa8ndiloP65urujrK3cJF6HkiAAr\\\\\\n\",\n              \"QWlpae/0+YiGphbbt2/nyZMnGBsbF7hVqFDhg28tVBrOnj1LL1c3GjRuzlczfqJpK0fU1NT+3sUi\\\\\\n\",\n              \"4hIHd3rz008L2b59G87Ozgo/7rNnz/LXZoWGhtK4cWNcXFw4duwYNWvW/OiHBhVlbGxMcnwsVWsq\\\\\\n\",\n              \"tw3SyxcxBXaOf5P4+Hhmz57Nrl27SE9Pp1GjRvmLoP/9GoaFhXHr1i2mTZvGzJkzGeM5Ht+f+6Ff\\\\\\n\",\n              \"pzXGDdujUbY8eVlppP11joRrJ+nm7Mzv/mcLXHU4KSmpQEi9vmlpaeUHVatWrRgzZgz16tUrNmSc\\\\\\n\",\n              \"nZ3xHD+Bu5fDqNm4dbHPNz35JRf8dnIkwK/YssKHIwKsBOnr6/Pq1Sul66UkvyIry5gzZ86QkJBQ\\\\\\n\",\n              \"4JaamkrFihULBdvrm5GRUaHf6evrq9SJ+fz583Tv0ZO5S9fTsm3HAvepqanRoFEzGjRqxo2ISwwf\\\\\\n\",\n              
\"MYgtmze9McTy8vK4ePFi/tDg8+fP6dq1K0OHDmXHjh0KXRVYFfV0cWbboT3Ytla8F//sr0hyM9Np\\\\\\n\",\n              \"0qRJkfdLksQff/zBggULuHHjBhUqVGDUqFHMnj37jWvc/P398fDwwN7enlatWqGvr8+endsZMmQI\\\\\\n\",\n              \"L+LjyXh0jJSUFCrq6dGvU1uGbPuVxMREDhw4UCCoXg9Jvr717duXevXqKXz5lX/T1NRk+9Yt9O0/\\\\\\n\",\n              \"gKE/baBanYZvLJuRmsy278cydNBAGjdu/E7HE0qGCLAS1KFDe4L8D9K2veInkefRUdy5dZN6dWrj\\\\\\n\",\n              \"6urKl19+WeAzNJlMlj+j6t+3qKgorl69Wuj3WVlZRQbb227lypUrldDLycnBrXcfvl+8plB4/Vt9\\\\\\n\",\n              \"+6YsWreDIUPcefjwQX4YJScnExwcTEBAAIcPH8bCwgIXFxfWrVtH8+bNS3Vt1odw79491q9fz7Po\\\\\\n\",\n              \"5yQnxGFgrNjnd+d8d+I5bmyh1yc2NpbvvvuOvXv3kpGRQYsWLTh27BgdOrx9F/jNmzczc+ZMAgMD\\\\\\n\",\n              \"+eWXX/KvypyTk0NQUBC7d+8u0LPatnULixf9TK1atfKDavLkydja2lK1atX3/u/RycmJrZu9GTp8\\\\\\n\",\n              \"BI279KZ5j4EYWfzfwuesjDSuHPUlbP8m+rn14pcli9/r8YX/TqwDK0HPnz/H1taW8Bv30VNw3Hzx\\\\\\n\",\n              \"Tz8QH/OURo0acfDgQSIiIujcuTOurq44Ozu/0/h7dnY2iYmJRYZeUbf4+HgApQLP2Nj4vczO2rt3\\\\\\n\",\n              \"L8tWrsFrh+JDNXOmjKRR/doYGRoSEBBAeHg4bdq0yV+bVdSMtE9VYGAg/QcMRC5JlDUwwsDQiIkr\\\\\\n\",\n              \"dqBVzN/meuhR/Fb9QOT1a5iYmCBJEnv27OGnn37i1q1bGBsbM2rUKGbNmlXsVkqSJLFkyRLWrVtH\\\\\\n\",\n              \"UFAQ2trajBo1CkNDQ8qUKcPZs2eJioqiRo0a2NraUr9+/fzAsra2/uBbqz18+JDVXmvYsnUrJlUs\\\\\\n\",\n              \"KV/BkNzsLJ7cuUH79u2ZPHFCsWH9KfkYzp2KEgFWwgYPHoJGmbL8smJtse8g/7pzC1fnjoSeOpW/\\\\\\n\",\n              \"aWh8fDx+fn4cOHCA06dP06ZNG1xdXenRo8c7D58oIiMjQ+HAe33T1tZWKvAMDQ3R+tdEgzaObXHu\\\\\\n\",\n              \"PxKnbj0VbmvExbNMG+3OwAHuuLi4/H2p9o9wbVZJkiSJpUuXMmf+AizqNKHR4G/RNbYgbM0spPRE\\\\\\n\",\n              \"hs5ehqF55UL18mQyzgfsJWTLSoIPB2FhYcGMGTM4cOAA2dnZtGnThoULFxa6wGRRx4+JieH69ev8\\\\\\n\",\n              \"/PPPXL9+nerVq3Pv3j309PTIzc3FwsKCTp06cf36dXr27MnEiRNL6uV4J6/3cHz16hXlypXD1taW\\\\\\n\",\n              \"SpWK3zn/U/OxnDsVIQKshKWmptKqdWvsGzdj8a9eb3x3eeNaBEP6u7Jo0c8MGTKkyDIpKSkEBQVx\\\\\\n\",\n              
\"8OBBgoODadiwIW5ubri6upZ6L0OSJFJTU5UKvKSkJPT09AqEWnBwMMciHqOjxBY/kiTRoUFVoqOj\\\\\\n\",\n              \"Cnzw/7nIzs7G09OTQ34BGNdpSkvPn1BX/3sYUJLLuXZgPXeO7uaL+k1o3tUVAyNTcnOyeXDtIuFB\\\\\\n\",\n              \"+7G2/gKXrl3YsWMHd+/exdzcHE9PT7799tsil4C8fPmywOdTr2cBamhooKGhgbq6OlOnTiUuLg7/\\\\\\n\",\n              \"gECyc2VUt6mNVhltXsbHciMiHI8RI5gxY7qYkv4R+ljOnYoQAVbCXr+LTU9PJyUllSEjRtPXfRDm\\\\\\n\",\n              \"FpXIysri0oVzbPl9PRfOnWH9+vX07dtXocfNysoiJCSEgwcP4ufnh6WlZX6YqcolH+RyOS9evODu\\\\\\n\",\n              \"3bvcu3ePBw8esGjRIs7fT1L68w7nFrW4GnHls3vHHBsbi5ubGzo6OoRfi6Tn8iA0tApP/8/NyuDR\\\\\\n\",\n              \"2cNEXz5JdnoK6ppavHp2jwZ1a3P16lVkMhkdOnRg0aJF2NvbA5Cens7t27cLhFRREypsbW2xsrLC\\\\\\n\",\n              \"09OTMmXKsG7dOvr1d0emrkW/kZNo2Kx1gb9nTNQTAvZs4bjfXg4e+AMHB4cP9noJxftYzp2KEAH2\\\\\\n\",\n              \"Hz148IAtW7bw9Okz5HI5lSpZMHjwYOrXr48kSYwZM4bExET279/PtWvXWLt2Lb6+viQmJqKtrU2d\\\\\\n\",\n              \"OnUZN24sAwcOfOdhL5lMxunTpzlw4AAHDx5ET08PV1dX3NzcaNy48QefjCFJEikpKcTGxhITE/PW\\\\\\n\",\n              \"r8nJyZiYmGBubo6FhQUhx45x+MJf6OkrPjswLy+P9g2qEP/ixWd1qYsrV67Qq1cvRowYwdHjJ8g0\\\\\\n\",\n              \"t8W+z1cK1792cCP3g3cyc/q3dO3alTt37hToWT1//rzAhIrXt2rVqhX4N5WQkICzszP16tVj1apV\\\\\\n\",\n              \"dOrcBfMv6uA586e3LvkIDzvJLzO+4vixYzRs+OZZgMKHJQLsI1DSf4SrV68yY8ZMwi+H02/AYOrU\\\\\\n\",\n              \"rYeamhoP799nz65tWH9hTZMmjQkJCeHcuXOFdvSWJKlEgkUulxMeHs6BAwc4cOAAWVlZ+WHm4ODw\\\\\\n\",\n              \"n2bgyWQy4uPjiwyjf/9OQ0MjP5T+/fWf3xsZGRVok7NLd+wdOtGz/1CF23X6+BF2rPuFK5dV4z/d\\\\\\n\",\n              \"mzx9+pR169ezb99+EhMT0dLSwqZGDb4aN5Y+ffoUGM7z8fFhwoQJrFq1imPHjrF123Z6rzpCuYom\\\\\\n\",\n              \"Ch8vMzmRfRM6oaWhnr9DxT9vNjY2xU6oePLkSf4ko4ULF7Ju3Tq27fmDBet3K7ReMXDvdi4dO8Tp\\\\\\n\",\n              \"0FMKt1soWSLAPgIl+Uc4fvw47u7uzJj9A/0HDkFHR6fA/bm5ufgfOsDUSZ78uGABkydPLpF2FEeS\\\\\\n\",\n              \"JG7dusXBgwc5cOAAUVFR9OjRA1dXVzp27Jg/azA9Pb3IEPr318TERAwNDd8YRv/8+q49oSNHjjD1\\\\\\n\",\n              
\"2xls8VX82lGThrky1mMow4cPf6djlrbs7GzGeX7FoUOH6NijL1/27I+pRWVyc3K4cyOCQJ8tPLgT\\\\\\n\",\n              \"ycYN6+nZsydz5sxhx44dbNy4kblz52JqasrhI8EM2npJ6WPvGtGcZ08ev9OEoMjISLp27cq0adOY\\\\\\n\",\n              \"MmXK3zvg29Zn5P9+xK65YsOCuTk5DPmyEadOnqBu3bpKt0F4/1QpwMQ6MCXdvn2bAQMG4L3dh9Zt\\\\\\n\",\n              \"HIsso6WlhVvf/tSuU5fePbrQuHHjDz7OL5fLSUhIQCaT0bRpU6pUqcLt27c5d+4cw4cP5+XLl+jo\\\\\\n\",\n              \"6CCTyQCKDCMHB4cCP5uamhb7jjw5OfnvndBTUtDV1cXe3h5zc3OF292pUyekqdM4tGcrrgOGF1v+\\\\\\n\",\n              \"ZLA/1y9fxM+4Al26dFHqWB+DnJwcnF26I9PQZuexK+joFgx+BzMLHDp2486NK4wcPZCvvvoKSZKo\\\\\\n\",\n              \"Xr06Li4uaGtrk52djSwv75169ZIkvdObjbCwMNzc3Fi+fDkDBw4E4MKFC2Rm59CwWfE7W7ymVaYM\\\\\\n\",\n              \"XXoPYuPG31ixYrnS7RA+byLAlLRw4c94Tvz6jeH1T3Vt6/PDT4uZO3cex48fey/Hz8zMJDY29o09\\\\\\n\",\n              \"pdffv3jxAgMDg0KhYtHgIwAAIABJREFU5OrqiqenJ9ra2kRGRuZf8r1evXr50/PfZRPhyMhIVq9e\\\\\\n\",\n              \"jc/evdS1bUCFCoZkpKdxNSKcjh07MmniRBwdi3/N1NXV8ffzxaGNI5JcjuvAEW88KR8LOMBPsyYT\\\\\\n\",\n              \"HHyEoKAgGjRowJIlSxg2bFihOpIkcenSJbzWrOXcuXOkp6Whp69P+3ZtGT9+PPXr11f6OSsjLy+P\\\\\\n\",\n              \"xMREXrx4QXx8fP7X3bv3INPQ5qf13m8d3q1dvxHLtvoyvn9n1JG4ePEiFSpUoG7dupiZmXH4aAiv\\\\\\n\",\n              \"oh5QsaqNwm1Kfv6I8uX1Co0gFOf17ho7duygc+fO+b+/d+8eNes1VDpEa9Sz4/zhvUrVEQQQAaaU\\\\\\n\",\n              \"hIQE/P39uHT9F4Xr9HTry7zvZ3Dnzh1q165dZJnXu62/6fOkf37NzMzE3Ny80BBe06ZNCwSVmZlZ\\\\\\n\",\n              \"sZvRurm5AX/3mgIDAzlw4ABff/01jRs3xtXVlV69ehV5SfZ/27p1K9988y1DRnpy5PQVTM0s8u9L\\\\\\n\",\n              \"TUnm4L5dDBw8hOFDh7BgwYJiT3A2NjacOR1Kz169OLh7E64DR9K+c3f09A1IT08j7GQwuzetJTMt\\\\\\n\",\n              \"BesvrAgMDOTnn3+mb9++jBw5kl27drFx40YsLS2Bvxequg8YSFxcHO7DRrNq1CTKl9cjJfkVR4N8\\\\\\n\",\n              \"6dS5C3Xq1GHX/2vvzuNqzv4/gL/atW/3Vrei0oYo1VhSlCXLaOxKEdJizb6bwTAYy1jGkrXssieE\\\\\\n\",\n              \"siZCGFlT2lDaI+3Lff/+mK9+07Tdqxou5/l43Afu/ZzPOfeD8+p8PudzPocOCjyC4/P5yM3NrRZI\\\\\\n\",\n              
\"GRkZNb73/v17qKiogMvlQkNDAxoaGlBWVsaj6EcICLkj0LVJfSNTjPCcgjOHdmH9+vV49uwZgoOD\\\\\\n\",\n              \"8e7dO7Q1a4OXoUfQ2fMXgdoPAPFXT8DLy1OowAkICMDChQtx/vz5as+sKysrg6Sk8AsIS0pKorRE\\\\\\n\",\n              \"+EebMAy7BiaEnTt3IvTyNewIOCBUucUL5yI7PRXdu3evMZTS09MhJydX6/Wkf/6qqqrapLMKi4qK\\\\\\n\",\n              \"EBoailOnTuHcuXMwNDTEkCFDMGTIEJiYmFTb/vjx45g2fQb2HjsHI5OaAxoAsjIz4Ok6EMOHDsbi\\\\\\n\",\n              \"xYsFagufz8fVq1fx5+YtCL9xAx8/5kFeQQHt2pkj5sVzpKSkID8/H126dMHMmTMxYcIElJWVYf36\\\\\\n\",\n              \"9Vi7di1++eUX9O7dGz169IS372yMGje+xokFZWVl8Nu4GsEnDuPA/v0QFxevN5Cys7OhqKhYJZA+\\\\\\n\",\n              \"/f7fv2poaEBeXh6vX79GbGxs5ev69etQ4+nh912CPS8OAHKzM+HawxJtzdpgyJAhcHJywl9//YVF\\\\\\n\",\n              \"ixYhK+c9Bm84D1mV+hfjLc7Lwbm5g/E0+i+B7sX65+oaly5dgqmpabVtzp07h6UrVmN1wGmBvw8A\\\\\\n\",\n              \"nA3ci8xX0Th0ULj/V0zTYNfAvlHv3r2DvkFLocsZtDREaMhZKCoqgsfjwdzcHH369KkSTsI8M6wp\\\\\\n\",\n              \"ycrKYuDAgRg4cCDKyspw48YNnD59Gg4ODlBVVa2818zS0hJFRUWYMHEiAo6erTO8AIDD1cCuQ6fg\\\\\\n\",\n              \"5NARrq6uMDY2rrct4uLi6NWrF3r1qr4moq2tLS5duoSBAwfiwoULsLOzg66uLpycnDBv3jwMHDgQ\\\\\\n\",\n              \"Hh4e+GXxEsxbvAIu7uNqrUdKSgpT5/wMAPhpwAC0btUKmpqalQHUsmVLdO7cuUogcTicaquI8Pl8\\\\\\n\",\n              \"vHnzpjKgbt++Xfn7lJQUtGjRAiYmJjAxMYG1tTWex7xEryGu9R6Hf1JV58Kyow0WzJoGCQkJjBo1\\\\\\n\",\n              \"CioqKjh+/DguhoZhz3pf2M/1g4xC7Td0lxTk4dqaSZgyeZJA4cXn8zF79myEhYXh1q1b0NGpuqJH\\\\\\n\",\n              \"aWkpLl68iP379+PJX/fx7m0yeLqC36B8OSgQq5YvEXh7hvmEBZgQJCQkUMGvELocv6ICvXr1gp+f\\\\\\n\",\n              \"XxO0qulISUlVBsjmzZtx7949nDp1Cs7OzigvL4exsTEsLDugrbmlQPvT0ORhqNtobN++HX/88YfQ\\\\\\n\",\n              \"7SksLKwcBXXq1AmLFy/Gy5cvkZGRAUtLSwwZMgSGhoYoKChARkYGiAgmrczqDK9/8p29CCFnTmD1\\\\\\n\",\n              \"6tW1Xq8jImRlZeHevXtVRlOxsbGIj4+HmppaZUiZmJjA0dERJiYmMDAwqBZ4Bw8fgZKq8M8zU1RW\\\\\\n\",\n              \"xdy5cyElJYVVq1bByckJYmJisLGxQV5eHg4vG402gyeiRYeekPjHKb2K8jK8jrqKZ6e2oTgvB507\\\\\\n\",\n              
\"dqi3rtLSUowbNw5JSUkIDw+vXHWez+fj5s2bOHz4ME6ePIk2bdrAzc0NXK4GzgXuhfdswQIp5slf\\\\\\n\",\n              \"eJ+TIfDzxxjmn1iACUFfXx83Dx0Rutzzp0/Q2lTwi+tfI3FxcXTu3BmdO3fG6tWr8fTpUwwYMBC/\\\\\\n\",\n              \"rBJu5pjbaG8M6WuHVatWgc/nIzMzU6BrSBkZGeDz+ZWjIHV1dcTExCA+Ph6GhoZo27YtrK2tsX37\\\\\\n\",\n              \"dpw+fRpWVlbo2+9HDBvlJXDbxMTE4DrGB1u2boW1tTXi4uKqBNTLly8RGxsLIoKpqWllSDk7O8PE\\\\\\n\",\n              \"xATGxsZCzeiTl5NHcWGhUMcPAHKys+Do6IhNmzZVuXYmJiaGjev/QA8He6xetwHBh9ZBt70dJJrJ\\\\\\n\",\n              \"o6I4H2+jb6GVaSvs/PMP6Orqon///lBXV691hmx+fj6GDRsGaWlphIaGQlZWFo8ePcLhw4dx5MgR\\\\\\n\",\n              \"qKmpwc3NDQ8ePKgcySUnJ+OHDh3RvnNXdLCrewHcD7nZWLdgMpYsXvzNPyGAaRrsGpgQCgoK0KJF\\\\\\n\",\n              \"C1y+eRct9PQFKvMxLw/mrQzw4vnzaqdeRJ2SkhJu/hUHRSXh1h+0NOahorwMpaWldV5D+vd7CgoK\\\\\\n\",\n              \"Va7/TZ06FSoqKli2bFnle5s3b8a2bdsQFhYGY2NjPHyVXm3kU5ec7CzYtTeChLg4jIyMqoymPr04\\\\\\n\",\n              \"HE6jXIdc9PPPePkmE5MXrhS4TElxEdx6tMf9qHto2bLu09kvXrxAREQE8vLyoKioCDs7uyr3WoWF\\\\\\n\",\n              \"hWHUqFG4fPlytVmYn1bXaNu2LebPn49jx45VPrzSzc0Nbm5uaNu2bY31RkREYNDgwRjtuxC9B7lA\\\\\\n\",\n              \"qobJRC+fPsLa+ZPg6jwcK1euEPj7M01PlK6BsQAT0rRp01BSTlixZr1A22/esBZ+WzahpYEBVqxY\\\\\\n\",\n              \"8U09lkFGRgZ/xaWhmZDTsO2tTXE2+AwsLISfcv1PT548Qd++fZGcnFzl/rTZs2cjIiICCYmJuP0k\\\\\\n\",\n              \"Sah9EhFMeQooKSkRKvg+x+vXr2Fu0R6Hrz4SePHiS6cD8df1c7h08UKjtCEwMBBz5sxBRERElVFU\\\\\\n\",\n              \"z5490bJlS+Tn5yMuLg7Ozs5wc3ODjY2NQCtsPH78GFN8p+JFTAz6Dh0JY7P2kJSURMa7FFwOCsSH\\\\\\n\",\n              \"3EwsWbwYXp6ejfI9mMYjSgH27T+bvpEtWLAAF84H4+A+/3q3vRhyDtu3bML1a9fg6+sLHx8f9OzZ\\\\\\n\",\n              \"E3fu3PkPWtr01NU5SHuXKlSZkpIS5OZkQ19fv8GjmHbt2kFPTw8hISFV3l+zZg00NTWR/zFf6H2W\\\\\\n\",\n              \"FBdDSkqqyZ9JRfT3vVwAcGi7YD8M5X/MQ8Cfq9Ddwb7R2jFixAjMnTsXvXv3RmJiIlat+vuBkikp\\\\\\n\",\n              \"KeByufj555+RmpqKrVu3wtbWVqDwAgBzc3OE37iOG9eugiNDuHPhGK6e+Hu24e+/LUVSQgILL6bB\\\\\\n\",\n              
\"2DUwIWlpaSH00iX06dsXT6IfYcKUaTBoaVhlm3epKfDfuR1HDu7FmTNn0Lp1a7Ru3RrDhw/H3r17\\\\\\n\",\n              \"MXz4cFhaWmL58uUivYjpkCGDEXTsEKbPF2xaPABcOh+EVq1aN9pjT3x8fLBz504MGDCg8j1xcXEE\\\\\\n\",\n              \"BgZCS4uHp9EP0dbCSuD9RUZch3kDR4b1efPmDaZMmYLY2FjYd+uKoEN7oKyqjuEetS/E+/HDeyyZ\\\\\\n\",\n              \"4o4O1pbYtGkT8vLysHTp0nrv9atPaWkp9PX1ISEhASMjI0hISMDHxwerV69ulGeqtWnTBhs2CBbQ\\\\\\n\",\n              \"DCMsNgL7DKamprh75w5UleTRr4cdhg/sh1/mz8YvC+bA3WUIunWyRHHBB0RGRqJTp06V5aSkpODt\\\\\\n\",\n              \"7Y24uDj06NEDffr0wYgRIxAbG/sFv83nmzRpEo4eCkBpqWA3oRIR9u3ainfvUmFlZYX9+/cLXLY2\\\\\\n\",\n              \"zs7OiIyMxOvXr6u8Lysri6lTp8J/+2ah9ndk7y5MniT4iu7CqKiowJYtW2BlZQVzc3O0bt0aGRkZ\\\\\\n\",\n              \"aGVqgkPb12PKiL64GXYeFf9b3gsAcrIycGjHBowfZA/7Lp1wJigI0dHRePz4MWxtbT/r3w6fz0d4\\\\\\n\",\n              \"eDjGjx8PbW1trFmzBvb29pCUlES7du2wfv367+6BoIxoYgH2mTQ1NbF27Vq8efMGE8f7wFC/OQya\\\\\\n\",\n              \"a2OkqwuSk5Oxbdu2ypUg/q1Zs2aYPn06Xr16BXNzc9ja2sLT0xPJycn/7ZdooDZt2kBfXw/zp42H\\\\\\n\",\n              \"IJdS/bf/ibLiIiQkJGDlypU4ePAg9PX1sWLFCmRlZX1WG+Tk5ODm5oY9e/ZU+8zXdwpuXL6IuJjn\\\\\\n\",\n              \"Au3rwb1IPH50Hy4uLp/Vlro8ffoUdnZ2OHr0KIKCgnDlyhUUFhYiISEBGRkZ0NHm4cdeDgg5vAPD\\\\\\n\",\n              \"u7aB9wA7jO3bEeP6dwF9SMP5c8HYuHEDJCQkoKGhUbmck62tLXbv3l3j8b9//z7GenjAoGVLcDhc\\\\\\n\",\n              \"6Og2h7mFBXg8HiZNmgQDAwPcv38f48aNQ1BQEK5du4YWLVpg7Nix4PP5jX4MGKbR0TfK2tr6SzdB\\\\\\n\",\n              \"YDk5ObRw4UJSU1OjKVOm0Lt37750k+pVUVFBCxYsID09PbKwaE9DnN3ofsxbis8oqvZ6mpRNvjMX\\\\\\n\",\n              \"kJ6ePiUnJ1fZz+PHj2ncuHGkoqJC48ePpxcvXgjdlsePH5OOjg6VlZVV+2z//v3E09ahCzcfUlx6\\\\\\n\",\n              \"Ya2vExduEFdDky5cuPDZx6QmRUVFtGjRIuJwOLR9+3aKiYkhIyMjsre3J2lpaeJwOLRv3z4qLy+v\\\\\\n\",\n              \"LJOamkqPHz+mmJgY+vjxY537f/bsGVlYWNDgwYMpKyuLiIjevHlDNjZdSLeFHvnOW0onr9yn0Kg4\\\\\\n\",\n              \"On3tIc34eSW10G9JllZWFBsbS7///jvp6elRTEwMEREVFhZS165daerUqcTn8xv1WDCiQZT6ThZg\\\\\\n\",\n              
\"X5H09HSaPn06qamp0bx58yg7O/tLN6lGHz9+pEGDBlHXrl0pIyODCgoKaJynJykpK9MwV3fafegU\\\\\\n\",\n              \"HT9/jfYdO0fjJviSmpo69XdyqjOY09LSaMmSJaShoUE//vgjhYWFCdWB2tjYUHBwcI2fBQQEkLKK\\\\\\n\",\n              \"Co32nEiX7zypElznr0fRMLcxpKqmVmv5z3X9+nUyMTGhoUOHUkpKCkVERBCHwyEej0cSEhI0Z84c\\\\\\n\",\n              \"KioqanA9xcXFNGvWLNLR0aGDBw+StrYO+c5dSvfic+hB0odqr6iEXJq3fB0pKSmTsbExvX37tsr+\\\\\\n\",\n              \"cnNzydzcnFasWNHgtjGiR5T6ThZgX6HXr1+Tt7c3qaur06+//kp5eXlfukmVkpOTycLCgjw8PKik\\\\\\n\",\n              \"pKTKZxkZGbRq1Srq2asXdejYkRy696C5c+dSQkKCwPsvLCykXbt2UZs2bahdu3bk7+9PxcXF9ZYL\\\\\\n\",\n              \"CAig/v371/p5fHw8OTo6kqysHBmZmJKldQdqaWhEWlo8srCwoDlz5gjcxvrk5OSQp6cn6erqUlBQ\\\\\\n\",\n              \"EBER+fv7k6ysLImLi1PLli0pMTGx0eoj+vuHirlz55KsnDxNnf9rjcH179eiVZtI36AllZaWVttf\\\\\\n\",\n              \"amoqGRgY0K5duxq1nczXT5T6ThZgX7G4uDgaOXIkaWho0Lp166iwsPCLtuf27dvE4/Fo3bp1TX56\\\\\\n\",\n              \"ic/n08WLF6lPnz6kpaVFy5Yto4yMjFq3LygoIDU1tWqnKP9txowZZGlpSdeuXaOnT59SaWkpxcbG\\\\\\n\",\n              \"EofDodzc3Aa3OTAwkHg8Hk2ePJk+fPhABQUF1K9fPxITEyM5OTmaMmVKldOFDVFSUkLBwcE0YsQI\\\\\\n\",\n              \"UlJSoo4dO5KhSSu6n/heoAB7kPSBfuhsS8ePH69x/7GxsaSlpVUZwsz3QZT6ThZgIuDJkyc0aNAg\\\\\\n\",\n              \"0tHRIT8/v2ojn//CgQMHiMPh0NmzZ//zup8+fUpeXl6koqJC3t7e9OzZsxq3mzJlCi1ZsqTOfVVU\\\\\\n\",\n              \"VJCzszM5OztTRUVF5ftjx46tt2xdkpOTqX///mRmZka3b9+msrIy2rFjB8nLy5OsrCwpKyvXGhTC\\\\\\n\",\n              \"qKiooPDwcBo/fjypq6uTra0tbd26lTIyMmjAwEG0cMUGgcPrQdIHWvnnHupm71BrfVFRUcTlcik8\\\\\\n\",\n              \"PLzBbWdEgyj1nSzARMi9e/eod+/eZGBgUO3Cf1P5NFnDwMCAnjx50uT11SU9PZ1+/fVX0tTUpD59\\\\\\n\",\n              \"+tClS5eqjAQfP35Murq6NU7m+KeioiKys7OrctowPj6e1NXVhb7uWF5eThs3biR1dXVavnw5FRcX\\\\\\n\",\n              \"0+nTp8nU1JRUVVVJU1OzwceOz+dTdHQ0zZ07l5o3b05mZma0cuXKaqchmzVrRtejk4UKsDuxmSQt\\\\\\n\",\n              \"LV3nadqwsDDS0NCg6Ojoz/4OjOgQpb6TBZgIun79Otna2lLr1q3p+PHjVUYSjenfkzW+FkVFReTv\\\\\\n\",\n              
\"709t27YlMzMz2r17d+VkiM6dOws0GSM7O5tMTU1py5Ytle95eXnRokWLBG5HdHQ0dezYkbp160Yx\\\\\\n\",\n              \"MTEUHh5ONjY21KpVKzIwMCAej0eOjo6fPRknMTGRVq5cSWZmZtS8eXOaN29erSFSXFxMkpKSQp0+\\\\\\n\",\n              \"/PRSU+dQenp6nW0JDAwkHR0doa5nMqJJlPpOFmAiis/nU0hICFlaWpKVlRWFhIQIdF3q3r17NH78\\\\\\n\",\n              \"eOrdpw/1cnSkMWPGUlhYWLUQrGuyxteCz+dTWFgY9evXjzQ1NWnJkiW0YcMGcnJyEqh8QkIC8Xg8\\\\\\n\",\n              \"OnPmDBH9HRhqamqUmZlZZ7nCwkJasGABcblc2rVrF0VHR5OTkxPp6enRypUricfjkYqKCs2dO1fo\\\\\\n\",\n              \"UXJGRgZt3bqVunTpQurq6jRhwgQKDw+v94eUuLg4EhcXpzuxmUKF1/3E99RMVrbe6fpERJs3byZj\\\\\\n\",\n              \"Y+N6w44RbaLUd7IAE3EVFRV0/Phxat26NdnZ2dGNGzdq3C48PJysrKxJT9+A5i9eTnsDT9P+o2do\\\\\\n\",\n              \"+eoN1KZtOzIyMqZjx44R0X87WaOxPH/+nHx8fEhZWZlkZGQoLCxMoHJRUVHE4XDo7t27REQ0YcIE\\\\\\n\",\n              \"mjdvXq3bX7lyhYyMjMjZ2ZmioqJo7NixpKGhQRs2bKALFy6QoqIiKSoqUmBgoMBt//jxIx08eJB+\\\\\\n\",\n              \"/PFHUlJSIldXVzp79mytPzgUFxfTuXPnaMyYMWRiYkLS0tIEgBSVlOjPgBNCBZj/yVBqaWgk8N/z\\\\\\n\",\n              \"zz//TNbW1l/VzFimcYlS38kC7BtRXl5O+/btIwMDA+rduzfdu3ev8rPTp08Th8OlHXuPUHJWIb3J\\\\\\n\",\n              \"Ka7yep1dRMfPhZGObnNyHz2auFzuF5ms0RgyMzOpU6dOpKCgQI6OjnThwoV6O+fg4GDi8XgUHx9P\\\\\\n\",\n              \"r1+/JjU1tWqjjKysLPLw8KDmzZvTwYMHadasWaSmpkaLFi2i9+/fV06T5/F49OjRo3rbWVJSQmfP\\\\\\n\",\n              \"niVXV1dSUlKifv360cGDB2scCSUnJ9O6deuoV69exOFwSExMjCQkJKhFixY0fPhwOnToEBUUFNCe\\\\\\n\",\n              \"PXvIvldfoQKs/2Bn+uOPPwQ+vnw+n3x8fKhXr14C3d7AiB5R6jtZgH1jSkpKyM/Pj7S1tWnQoEF0\\\\\\n\",\n              \"/PhxUudw6PzV29WC69+vO49jSVVNnTZv3vylv0aDREdHk46ODu3Zs4csLCyodevWtHPnzjpvQ9i2\\\\\\n\",\n              \"bRuZmJhQVlYWTZkyhWbOnEnXrl2jyb5TqUvXbqSorEJdunShOXPmVJ7aS01NJT6fT/PmzaNmzZpR\\\\\\n\",\n              \"ly5dKlfDqMmnGYQTJkwgDodDXbp0qZxB+ElRURFduXKFJk6cSGZmZiQjI1M5Bd/KyopmzpxJUVFR\\\\\\n\",\n              \"NZ5SLCgoIA6HS3tOXBIovA6dCycVFVXKyckR6viWl5fT4MGDycXFpcmuvzJfjij1nSzAvlGFhYW0\\\\\\n\",\n              
\"bt06UlZRoV+Wr643vD69Ao6cImvrH7508xusc+fOdPbsWeLz+XTlyhVycnIiLpdLv/zyS60rgsyb\\\\\\n\",\n              \"N49sbW1p/fr11ExegXj6htTXcxY5z1tDQ2b+Rub2/UhGVo6cBg6id+/eUUlJCQ0YMICkpaVp4sSJ\\\\\\n\",\n              \"tc5+jI6Opnnz5lGLFi2oTZs2tGLFCkpISCA+n09JSUnk5+dH/fv3J01NTRIXFydxcXHS1NSkfv36\\\\\\n\",\n              \"0datWyk1NVXg733p0iXicDVo76nLdYbXkZAI0tTiVZ42FlZRURF169aNpkyZUm2Ey+fzKS8vj7Kz\\\\\\n\",\n              \"s/+TmbJM4xKlvpM90PIblp6eDlPTVrj1KAbKyioClamoqEC3H8xw4vgxdOjQoYlb2HQCAgJw+vRp\\\\\\n\",\n              \"BAcHV7738uVLbNy4EYGBgRg0aBBmzJgBc3Pzys/5fD4s2lsiNSMLw+atQUuLTtUeq1LwIRcRx/cg\\\\\\n\",\n              \"Jvw8FORkkZSUhO3bt8PDw6PKdklJSThy5AgOHz6MDx8+wNXVFUOHDkVhYSGCg4Nx7do1xMTEoKys\\\\\\n\",\n              \"DOLi4jA0NISdnR2GDRuGrl27Qk5O7rO/+/nz5zF6zFh07+2EoaM8YWr2/98xIS4GJw/twaXgk9i2\\\\\\n\",\n              \"dQtGjBjx2fW8f/8e9vb2cHZ2xqJFi/D69Wts374du/fsQWFBASQkJVFaUoJBgwZj8uRJsLW1bdLH\\\\\\n\",\n              \"1DCNQ5T6ThZg37Bdu3bh4uWr2LR9r1Dl/li1DJJUirVr1zZNw/4DBQUF0NHRgZeXNx49foL8/Hwo\\\\\\n\",\n              \"KiqiZ3d7DB48GCdPnsTWrVvRunVrzJw5E3379sWOHTux7Pe18N54BAoq6nXu/9ap/bi4Zx1CL4TA\\\\\\n\",\n              \"3v7vB0xmZWXh2LFjOHz4MF68eIE+ffqgRYsWiI2Nxd27d5GWlgYxMTHIy8vD3Nwc/fr1w08//QQz\\\\\\n\",\n              \"MzOBHxQpqLS0NOzevRvbd+yEmLg4FJWUUZD/ESXFxfDx9oKPjw90dXUbXM+7d+/QpUsXWFi0R/jN\\\\\\n\",\n              \"cAwe7oaRY71hZGIKAPjwPhcnjx7CQf8dMNDXw4kTJxrtWXBM0xClvpMF2Dds5cqVSM3MxfzFvwlV\\\\\\n\",\n              \"LvBAAM6dOorJkydBXl4e8vLykJOTq/KrvLw8mjVr1ugdb2PIzc3FpMlTcObMGXTtMwCde/aDvIIS\\\\\\n\",\n              \"8j9+wL2rF3D7ygUMGTIE6/9Yh5CQEKxfvx75+fl4k5KKiZuPQ8vARKB6jv42Fc69u6F58+bYt28f\\\\\\n\",\n              \"IiIioK+vj9LSUqSkpKC8vBx8Ph/a2tqwsbHBwIED0b17d/B4vCY+Av+vvLwciYmJyMvLg6KiIgwM\\\\\\n\",\n              \"DCAlJdWodUyYMBGXr1xFYHAoNDS1atymoqICvy6chad/3ceNG9fZ88a+YqLUd7InMn/DpKSkUF5W\\\\\\n\",\n              \"Xv+G/1JaVoaoqHvw9HwAaWlpSEhIQFxcHESEiooKlJWVobi4GKWlpZCVla0SajUFnbDvffq9nJyc\\\\\\n\",\n              
\"0AGZmZmJrt3sYWrVBfuuPIK8olKVzzt37wuP2UsR8Mev6NnLEdevXYWmpiaGDRsGTgtDgcMLADoO\\\\\\n\",\n              \"HI3F88cBFeUoLy+HlJQU4uPj0bp1a0yePBn9+vVDp06dvmhnLSkpCWNj4ybb/6VLl3DxUiiCQsOh\\\\\\n\",\n              \"ps6pdTsJCQn8+vsGzPb1xuzZs+Hn59dkbWK+HyzAvmHGxsY4feas0OWePHqAOXPmYOjQoXj37h1S\\\\\\n\",\n              \"U1MrX//887t37yAtLQ11dXVwuVyoqalBRUUFSkpKUFRUhKysLJo1awZpaWmUlJSgoKAAmZmZSEpK\\\\\\n\",\n              \"QmFhIQoKClBQUFD5+3+/V1RUBBkZGYGDUE5ODocOH4G1Q194zFxc6/dTVFaF768bsGPlArRtZ468\\\\\\n\",\n              \"D+9Rxgfsho4V6jjpt7WGrLwiWhsZwM3NDQ4ODmjTpg0kJCSEPuaiasPGjfCdNb/O8PpETEwM8xev\\\\\\n\",\n              \"QE8bC6xatQoqKoJdl2WY2rBTiN+wsrIytNDTw8ET52Hauo1AZd6/z4WdVWvEvnwJDQ2NOrfl8/nI\\\\\\n\",\n              \"ycmpMdz++ee0tDQoKSmBx+NBW1u78vXPP/N4PPB4PEhLS1fZf3Fxcb1B9+nXJ0+e4Gr4LfgFRwg0\\\\\\n\",\n              \"cqsoL4d7d3MU5udBWk4eHiv3oHlrC4GO0ycBc8egq2UbWFpaQkpKCtLS0pCSkqr2EvZ9KSmpr/L0\\\\\\n\",\n              \"7D8lJCSgQ4eOiHz8Cs1kZQUuN9XbHfZ2NpgxY0YTto75XKLUd7IR2DdMSkoK3l5e2LZpLTb6+Qs0\\\\\\n\",\n              \"A2z3tj/BUVcX6JHy4uLi4HA44HA4VWbz/Rufz0d2dna1cHv+/DkuX75cGXTp6elQVlauFm7//rOW\\\\\\n\",\n              \"llaN13EGDh6CAaN8BO74JSQlMXjMBPx1PQRJb96Az68QqNw/lZeXIS0tDdHR0SgrK0NpaSnKysqq\\\\\\n\",\n              \"vYR9v6ysDBISEg0OwqZ8PzQ0FDZ29kKFFwD07OuEG6FnWYAxDcYC7Bs3a9Ys2NraYdPalZg2Z2Gd\\\\\\n\",\n              \"IXbq2GEcO7QXgwcPhrm5OVauXAlPT88GT30WFxcHl8sFl8uFhUXtIxw+n4/MzMxqI7knT57g0qVL\\\\\\n\",\n              \"lcGXnp4OVVXVKuGmpaWFkHPncGTBOqHa1mvQCBzYsgZSMjJ4Fx8DPTMrgcvyKyrwIS0Fvx8KQJs2\\\\\\n\",\n              \"go1wBUVEKC8vb3AQ1vd+QUHBZ+8nPT0dnWzthf5uCopK+PjxY6MeL+b7xALsG6esrIzQ0Evo27cf\\\\\\n\",\n              \"Xjx7Au9J02DdsXOVUIp5/gx7d23DjauhCA0NRdu2beHj4wNvb28cOnQIO3fubNKJAJ+Ii4tDU1MT\\\\\\n\",\n              \"mpqaaN++fa3bVVRUIDMzs8ppylevXkFaRgZy8gpC1amizgURH5JiQMSJAHT6yVXgwI65ex26utqN\\\\\\n\",\n              \"Hl7A39eLPo10vlaHDx/GkeOnhC6X9+E9lJSU6t+QYerBAuw7oK2tjdu3b2Hnzp2YPcUbMs2awbT1\\\\\\n\",\n              
\"3/ceJSUmIOXA/Q6fAAAbzElEQVTta/h4e2P9/fuV170sLCwQGRmJzZs3w8bGBrNmzcLs2bO/ig41\\\\\\n\",\n              \"Ly8PCQkJiI6Oxp07d/DkyRMkJiaitLRU6H3x+XwQnw9DQ0PEJSQi/q87MLKyEajc1YNb4Tqg7+d8\\\\\\n\",\n              \"hW+Cra0tpvj6orCgAHJCzLQMu3AWfXt1b8KWMd8LNonjO8Pn83Hr1i0kJyeDz+eDx+PBwcGhzmBK\\\\\\n\",\n              \"SkrCxIkTkZqail27dqFjx45N3k4iqrxOdufOHURFReHFixdISUlBSUkJJCQkUF5eDg6HAwMDA5ib\\\\\\n\",\n              \"m+P4iZNY4X8K+satBa4nJvo+Fk9wxQ6/bZCXl8eoMR4Yv/EINPVrH3Hy+XyE+K1A9su/kPc+B/36\\\\\\n\",\n              \"9cPatWu/yxt0fxowAHY9+8HVfZxA26e/S0XvrtZITkpio7CvlCj1nV/3NCfmsxERbt++jZGjRkFH\\\\\\n\",\n              \"RweKiorg8XgYPHgIioqK4ObmhtGjR8PR0bHeUZW+vj5CQkIwb948DBgwADNmzEB+fn6jtLOiogKv\\\\\\n\",\n              \"Xr3CiRMnMH36dDg4OEBXVxcyMjLQ09NDnz59sHLlSjx79gzGxsaYNWsWgoKC8Pz5c5SUlCA9PR1X\\\\\\n\",\n              \"r15Fx44dIdusGYIP7hKq/lN7/SAvK4u0tDT4+fmhlbER/Oe4IzLoAIoLql+neRPzGEeWTkLx25eI\\\\\\n\",\n              \"CL+OZ8+eQVxcHG3btsX58+cb5ZiIkhnTp2Pr+tXISE+rd1s+n4/ffpkH91GjWHgxjYKNwL5Bb968\\\\\\n\",\n              \"wbDhw5GdnQ1P7wn4aeBgqKiqoiA/H6EXQ7Brhx9KS4px4sQJtG3bVqh9Z2VlYdasWQgPD4efnx/6\\\\\\n\",\n              \"9hXsFFpxcTFevnyJiIgI3L17F8+ePUNycjJyc3MhJiYGIoKqqir09PRgZmaGzp07w9raGqamprXe\\\\\\n\",\n              \"LxQfHw8/Pz/s3bsXXbp0gYuLCyZMnIQtp66Dy9Opt02prxMxc0QfBPjvweTJk5GTk4PRo0fD3t4e\\\\\\n\",\n              \"J08H4cqVK2hj0wPNlNRQUVaC+Ed3IFFRBt/JEzF16lTI/mP23dWrV+Hl5QU7Ozts2LAB6up1L0X1\\\\\\n\",\n              \"LVm2bBkOBx5FQGAQdHRb1LhNWVkZFs2agjeJr3D5cliVY8d8XUSq7/zPlw/+j4jSisqNKTk5mZo3\\\\\\n\",\n              \"b06/rVpD7wvLKK+4otrrQ1E57fTfRxoaGgI9u6omly5dIgMDAxo5cmSVx4G8f/+erl69SkuWLCEn\\\\\\n\",\n              \"JycyMTEhBQUFEhMTIzExMZKRkaEWLVpQjx49aObMmXTy5EmKj48XeNXyiooKCgkJoR9//JE4HA7N\\\\\\n\",\n              \"mTOHEhISKDU1lfr06UOSUtKkqdOCDlx7TOefptf6Cgi9T9rN9Wjjxo1kb29P7u7u9PbtW1qxYgU1\\\\\\n\",\n              \"b96cOnfuTBs3bqQtW7aQgYEBeXp6koyMDBUVFdXatvz8fJo2bRrxeDw6ceLEZx1XUcTn8yuffODq\\\\\\n\",\n              
\"7kHnrkZSYmYhJWUV0b1niTRn0a+k27wFDRgwUKAnPzNflij1nSzAviF8Pp+srKxo5ep1NQbXv197\\\\\\n\",\n              \"Dx6h5s2b1/mcrNrqefv2Lfn7+1P79u1JSkqKVFVVSUpKqjKolJWVyczMjIYOHUq///473b59m96/\\\\\\n\",\n              \"f//Z3y0nJ4fWr19PhoaGZGlpSf7+/pSTk0PHjh0jR0dHkpSUJE1NTQoMDKRff11GqupcGr9gBR2L\\\\\\n\",\n              \"jKsSXIG3XpL33F9JnatBikpKZGlpSV5eXlWea1VeXk5nzpyhvn37EpfLrXyIZbt27ao8KLQ2t27d\\\\\\n\",\n              \"IlNTUxo2bBilpaV99ncWNWlpafTbb7+Rnp4+SUhIkJSUFCkqKpKnpyc9ePDgSzePEZAo9Z0swL4h\\\\\\n\",\n              \"V69epdZtzOhDUblAAZZXXEGOvfvQvn37atxfeXk5RUdH04YNG8jV1ZXat29P6urqJC4uTgBIWlqa\\\\\\n\",\n              \"tLW1ydzcnNTV1cnMzIyuXbvWqM+AevToEXl7e5OKigq5ubnRrVu3KCIigsaPH09qamr0ww8/kJqa\\\\\\n\",\n              \"Gs2cObPyeVzXr18nBQUFUlFVo2ayctSxaw/q0X8wderWk5SUVcjVbSSFhoaSnp4eKSsrVxlB/ltc\\\\\\n\",\n              \"XBw1b96clJSUqHnz5uTt7S3Q9ysqKqL58+eThoYGHTx4sN6nQn9rysrK6hytMl8vUeo72TWwb8iw\\\\\\n\",\n              \"4cPRpasDvMdPFLjMhfNnse73FVi+fDnCw8Px8OFDvHr1CmlpaSgoKAAAyMvLg8fjwdjYGFZWVrC3\\\\\\n\",\n              \"t0fHjh2rXIgvKyvD+vXrsXbtWixYsADTpk2DpOTn3aVRVlaG06dPY8uWLUhISMCECRPQu3dvXLx4\\\\\\n\",\n              \"Efv374eUlBRGjx6NkpIS+Pn5Yffu3fjpp59ARDh16hRGjhwJfX197NmzBy1btkRUVBQ+fvwIJSUl\\\\\\n\",\n              \"dO7cGQDg6OgIR0dHiIuL4/bt2wgLC0OzZs1qbE/37t0xb948nDp1CidPnoSioiLGjx8PT0/Pepfb\\\\\\n\",\n              \"un//PsaNGwc9PT1s374dOjr1X5tjmC9JpPrOLxygTUaUfopoLEpKSpTwJk3g0VdecQXlFpSSlJQU\\\\\\n\",\n              \"SUhIEJfLJSsrKxo1ahT9+eef9PTpU6EfGR8XF0c9evQga2trevjwoVBlU1NTaenSpcTj8cje3p72\\\\\\n\",\n              \"7t1L27dvp65duxKXyyVfX1+Kioqi9+/f07Bhw8jKyooSEhKIiOjGjRtkY2NDWlpa1L59+1rbnZqa\\\\\\n\",\n              \"Sm3atKGff/6Z+Hw+VVRUkLOzM7m6utZaxsHBga5evUpxcXGkq6tL9+/fJ09Pz8pR4c2bN+scYZWU\\\\\\n\",\n              \"lNDSpUuJw+HQrl27vrvRGCNaRKnvZAH2jeDz+QSg1okbdb20tbXp9evXjdoWf39/4nK5NHfuXCoo\\\\\\n\",\n              \"KKhz2ytXrpCDgwPJyclRnz59aNmyZeTi4kJKSko0ZMgQCgoKopKSEiIiio6OJmNjY5owYQIVFRVR\\\\\\n\",\n              
\"dHQ0/fjjj6Svr08bNmwgVVVVio2NrbGuN2/ekLGxMS1fvrzK+4WFhWRjY0OLFi2qsdynAOPz+cTl\\\\\\n\",\n              \"ciuPVU5ODm3cuJFMTEyoXbt25OfnR3l5ebV+1+joaLK2tqZevXpRYmJiXYeQYb4YUeo7WYB9Q+Tk\\\\\\n\",\n              \"5Cgl871Q4fWhqJxUVFQoOzu70duTlpZGLi4uZGhoSJcvX67yWUFBAa1atYp4PG2SlZMjC0srcuzT\\\\\\n\",\n              \"jzrb2JKikhK1a2dO+/fvrzJaCQgIIA6HQ/v376eEhAQaNWoUaWpq0qZNm6i4uJgGDRpES5curbEt\\\\\\n\",\n              \"iYmJZGBgQOvWravx84yMDDI0NKQ9e/ZU++xTgBERDRw4kAIDA6t8zufz6fLlyzRkyBBSVVWlSZMm\\\\\\n\",\n              \"0ZMnT2qsp6ysjFavXk3q6uq0efNmoUe4DNPURKnvZAH2Delmb08HjhwTKsCu3owkAwODJu1Iz549\\\\\\n\",\n              \"S82bNycPDw+6f/8+zZo1ixQVFUlBQYEmTJlGUY9fUubHssrXm8yPtG3XXmrTth25u4+mDx8+kJeX\\\\\\n\",\n              \"F5mamtKNGzfI19eX1NTUaOnSpZUjnrNnz5KRkVGNEwdiY2OpRYsWtGXLljrbGRMTQxoaGhQWFlbl\\\\\\n\",\n              \"/X8G2OrVq2nq1Km17uPt27e0ZMkS4vF41K1bNwoMDKwcPf67ri5dupCdnR29fPmy3mPIMP8VUeo7\\\\\\n\",\n              \"WYB9QwIDA8neoYdQAeY2ajStXr26SdtVUVFBJ06cIH19fRITEyNjY2NSUFCks5euVQmuf7+S0t5T\\\\\\n\",\n              \"956OpKWlRYMHD6Z58+aRmpoaTZs2jdLT0yv3X1BQQPr6+hQaGlqt7ufPn5OOjg7t3LlToLbeuHGD\\\\\\n\",\n              \"uFxulRHUPwPs5s2b9MMPP9S7n9LSUjp+/Dh1796dtLS0aNGiRZScnFxlm/Lyctq0aROpq6vTmjVr\\\\\\n\",\n              \"KmdRMsyXJEp9Jwuwb0hJSQlpa2vTqbMhAoXXzTv3SVVVlTIzM6vsp7CwkOLj4ykmJoaysrI+uz3/\\\\\\n\",\n              \"vHfL2NiYunbtSgoKCiSvoED7j5ysM7z+GWIt9PRJRUWF3N3da7x2tGDBAnJxcan2fnR0NPF4PNq/\\\\\\n\",\n              \"f79Q7T548CDp6elRamoqEVUNsMLCQpKTk6P8/HyB9/f8+XOaOnUqqamp0YABA+jixYtVRrzx8fHU\\\\\\n\",\n              \"vXt36tChQ62nHhnmvyJKfScLsG/MzZs3icvl0tkLYXWGV8TdB8TT1qbjx49Xlo2OjiYfHx9SUVGh\\\\\\n\",\n              \"Fnp6ZGhkREpKStS9Rw86ceIElZaWCtSGT/tRVFSktm3bEo/HIzMzM1qzZg0dO3aMWrcxo4y8UoEC\\\\\\n\",\n              \"LPNjGa3duJV69uxZY13Pnz8ndXV1SklJqfL+/fv3SVNTk44ePfpZx3HZsmVkbW1N+fn5VQKMiKhT\\\\\\n\",\n              \"p050/fp1ofeZn59Pu3btovbt25OhoSGtW7eu8gcEPp9PO3bsIA6HQ8uWLav1WGdmZtLvv/9ODt27\\\\\\n\",\n              
\"U3tLS+pia0u+vr70/Pnzz/qeDPNvotR3sgD7Bl27do24XC4Ndx5BoVfDq9zYfOveQ/Lw9CY1NbXK\\\\\\n\",\n              \"zr2iooJmzJhBPG1tWrT4V4pNfFu5fVZeEfnvP0Q2XWypffv21YLik9LSUjp69Ch17tz57wBs0YI4\\\\\\n\",\n              \"HA5Nnz6dHj58WDkZw9nZhVav/1Pg8Mr8WEaJ73JJVVW1Wt18Pp8cHBxo06ZNVd6PjIwkDQ0NOn36\\\\\\n\",\n              \"9GcfQz6fT2PHjqUBAwaQvb19lQCbMWMGrVy5skH7joyMJHd3d1JRUaExY8bQ3bt3ic/n0+vXr6lf\\\\\\n\",\n              \"v35kYWFRZfWKwsJC8vHxIWVlZRrpPoZOBp+n67fuUkjoVZozfxFpampS9x492OxGpsFEqe9kAfaN\\\\\\n\",\n              \"+nT6zsjYmDQ1NcnExJS0dXRIV1eXli1bRu/evSOivzvTSZMmkY2tHSW/y6pztuLiX38jIyOjKqcc\\\\\\n\",\n              \"U1NTafHixaSmpkYcDofk5ORo2LBhdO7cuRpHEebmFnQl4p5QAZb5sYw623Sh8PDwKvvav38/WVpa\\\\\\n\",\n              \"Vrl2FB4eTlwul86fP9/gY1hSUkI9evQgXV3dKgF2/PhxcnJyavD+if4eUa1Zs4YMDAzI2tqadu/e\\\\\\n\",\n              \"Tfn5+bR//37icrm0YMECys7OJruuXWnocGdKTMmo8e8nK6+Ilq9cTdra2mxSCNMgotR3sgD7xlVU\\\\\\n\",\n              \"VNCbN2/o2bNnlJycXG2iQHBwMJmatqI36TkCXTebOmMWubi40M2bN6l3794kLS1NzZo1I0tLS9q5\\\\\\n\",\n              \"cyfl5ubW2R7TVq3o5r1HQgdYN/vuVWYH5uTkkJaWFt29e7fyvcuXLxOXy602i7AhcnNzSU5OjqZM\\\\\\n\",\n              \"mUJEfwf+8ePHSVlFhbp1sye7rl3J2cWFgoODG7SE1qdFip2cnEhNTY2mT59ON2/epMGDB5OGhgYN\\\\\\n\",\n              \"cxkh0D1+m7ftoJYtW9Z57x3D1EWU+k62lNR3rnfvPhjuOhIj3EYJtP2HDx9goq+D8vJyKCkpwcvL\\\\\\n\",\n              \"C97e3jA0NBSovJ1dV0ybsxDdezoK3EYiQhfrtjh2NBCWlpYAgIkT/14uy8/PDwAQEhKCsWPH4sSJ\\\\\\n\",\n              \"E+jWrZvA+xaEjY0N4uLi4Ovri2PHjoEIGDPOC2Zt20FcXBwJ8a+wf68/MtLT8Ntvv2HUKMGOZW2S\\\\\\n\",\n              \"kpKwc+dO7NmzB0ZGRnj69Blik95CTk5OoPLDBznBxXk4PDw8GtQO5vskUn3nFw7QJiNKP0V8KXFx\\\\\\n\",\n              \"ccTlcinjfYFQU+9Huo8hT0/Pz1oS6Y8//qBhLq5Cjb4uXo0gg5YtK2fu3b17l7S0tCgnJ4eIiIKC\\\\\\n\",\n              \"gojL5VJkZGSjHp9PHBwcaNy4caSkpEQng8/XuljylfDbpG9gQGvWrGmUeouLi+mnnwbQhMm+Qv39\\\\\\n\",\n              \"HDsdTNYCTPVnmJqIUt/Jnsj8HXvw4AFsu3ardRHb2vTr74TMzCyIiYkJXaeHhwcuX7qAzMwMgcts\\\\\\n\",\n              
\"3bQeHHV1pKamory8HBMmTMCaNWugqqqKY8eOYfz48bhw4ULlQr2NLS8vD2fOBCP02k049u5b6/fu\\\\\\n\",\n              \"0LETLl0Jx5atWxEUFNTgemVkZPAi5gXcxwg3knLs3RfJSUlISUlpcBsY5mvGAuw7VlBQADk5eaHL\\\\\\n\",\n              \"ycsrVK5ULyxVVVV4eXlhoudolJaW1rv9mVPHcftWONq1awcLCwv069cP8vLyGDVqFA4cOIDp06cj\\\\\\n\",\n              \"NDQU1tbWn9UeQWRkZmLBL0vQxqz+p1fztLXxx6YtWP7bb6BGODufm5MDLS2eUGUkJCSgqamF7Ozs\\\\\\n\",\n              \"BtfPMF8zFmDfMWVlZeTmCN/J5eRkQ1lZ+bPrXbVqFdRVVeA29CekprytcZuysjLs2r4VM30noJ25\\\\\\n\",\n              \"BYKCgmBuYYEbN27gxYsXGDlyJObPn48rV67A3Nz8s9tSn5SUFOTm5MB1pLvAZRx790VOTg6ioqIa\\\\\\n\",\n              \"XL+MjIxAQf9vxSXFkJGRaXD9DPM1YwH2HevWrRsib99Cbm6uUOWCTp2Ao2Ovz65XUlISx44dhW2X\\\\\\n\",\n              \"zujexRpj3Ybj7JlTuBt5C+HXr2LV8iWwbGOI88GncC3iDoJDQvE0NhGDhzpDXV0d7dq1w+nTpyEt\\\\\\n\",\n              \"LY3nz583ykinNsHBwej7oxMUFRUFLiMuLo4Rbu44ceJEg+tv1bo17kTeEqrMu9RUZGdlQVdXt8H1\\\\\\n\",\n              \"M8zXjAXYd4zL5aK/kxMOHdgncJmUt28REX6jwTPtJCQksGLFCiQnJ8Ppxz5Yt2oZxo1ywbrff0N+\\\\\\n\",\n              \"3gecOX8RF8KuwdjEFACgoKAADy8fhEfeR0JiIubOnYsdO3Zg2bJlsLOzw+3btxvUntpkZmZC38BA\\\\\\n\",\n              \"6HI6OjrIyspqcP0Txo/Hnl07hCqz1383XFxcIC8v/OlhhhElLMC+c9OnTcOf69ciPv5VvdtWVFRg\\\\\\n\",\n              \"5rTJGDduHBQUFBqlfgUFBfTp0wfpaWm4HfUXLl25jnUb/6z1epMWj4fzl65i06ZNsLS0xMOHD+Hj\\\\\\n\",\n              \"4wMXFxcMHToUcXFxjdKuT2RkZFBSXCx0uZLS0kY5hTdo0CDEx8UiIvyGQNtnZWYiYPcOTJo0qcF1\\\\\\n\",\n              \"M8zXjgXYd+6HH37A0qVLMaCfI54+eVzrdgUFBfBwd0VZSQlWrlzZqG3Yvn07XEeNhoampkDb6+nr\\\\\\n\",\n              \"w2nAIPj7+0NCQgJjxoxBbGwsOnToABsbG/j6+iIzM7NR2taqVStE3AwXutyDe3dhYmLS4PqlpKQQ\\\\\\n\",\n              \"EBAAD3dX/PXwQZ3bZmdnY/jgn+Dh4dGk1wUZ5qvxpefxNxVRupfha3Dw4EFSV1enH/s70Ykz5yjh\\\\\\n\",\n              \"TRq9Sc+hyPuPaMq0GaSurk5jxo6l4uLiRq23oqKCuFwuPXwaI9S9Ttci7pCRkVG1/WVkZJCvry+p\\\\\\n\",\n              \"q6vTihUrGrwiRWlpKSkqKtGdB9ECty3xbTqpqKg0aCX/fzt16hRxOByaMXsuPYmJr1Lfm/QcWrvh\\\\\\n\",\n              
\"T9LT16fZs2d/1v15DPOJKPWdbCUOplJBQQECAwOxc9cuvIqLQ2lpKbhcLoYNG4aJEyfC4DOuBdXn\\\\\\n\",\n              \"/fv30NPTw9sM4SaSlJeXg6ssh9LSUoiLVz+REBcXh4ULF+LOnTtYvnw53N3dISEh8VltNDQ0hKXV\\\\\\n\",\n              \"D/A/cFige9+W/rwQ2Vnp2BsQ8Fn11ebVq1fYunUrDhw4ACNjE6irq6OgoBDRjx6il6MjpkyeDHt7\\\\\\n\",\n              \"+0atk/n+iFLfyQKM+aIyMzPRqlUrJKUKd8qPiKCmIIPi4mJISkrWul1kZCRmz56N/Px8rFmzBn36\\\\\\n\",\n              \"9BG6jV27dkVmZiZ+GjgEv/y6vM4Q89+1AxvWrUZkZCR4POHu3xJUYWEh7t27hw8fPkBOTg7t2rWD\\\\\\n\",\n              \"lpZWk9TFfH9Eqe+s/X8+w/wHVFRUUFhYiA8fPgh1b1lqSgqUlJTqDC/g73UMIyIiEBQUBF9fX+jr\\\\\\n\",\n              \"62PNmjVo3759vXXk5OTg0qVLSEtLQ8+ePXH2zCk8fvwIvtNmoptD9ypB9uB+FHZs3Yyoe3cQFhbW\\\\\\n\",\n              \"ZOEFAHJycnBwcGiy/TOMqGCTOJgvSkpKCj8NGIDAwweFKndgXwCcnZ0F2lZMTAyDBw/Gs2fPMHDg\\\\\\n\",\n              \"QPTt2xdjxozBmzdvatz+yZMn8PDwQMuWLXEk8Chsu3ZDQVExADG8ePYMPp5jYNW2Fdych2DUiGHo\\\\\\n\",\n              \"0sES49xd0d6iHaKiomBsbCzUd2EY5vOwERjzxU2aOBETJ06Cl88Ega5TlZaWYq//LlwICRGqHikp\\\\\\n\",\n              \"KUyePBnu7u6VozAfHx/Mnz+/cvQXFBQEb29vTJ0xC09exIHL5VaWJyLcDL+B1atWIP/jRzgPG4pm\\\\\\n\",\n              \"zZqBx+PBxsbms6+xMQzzedg1MOaLIyL07NULrdq0xe9r19d5jYnP52OSjyeKiwpwsoErXbx9+xZL\\\\\\n\",\n              \"lizBuXPnsGjRIrRq1Qru7u44HRwCqzrWVqyoqMCUiePx9u1rhJw/DykpqQa1g2G+JqLUd7JTiMwX\\\\\\n\",\n              \"JyYmhpMnTuDu7QhMHu+FzIyaV6pPTUmBh7srXicnYv8+wVcPqY2uri727NmDy5cvIyQkBM7OzvDb\\\\\\n\",\n              \"uafO8AL+XkVk87btyM8vaJTlohiG+TwswJivgqqqKm7cuAFZGSlYm7eG15hROH70CC5eOI+jRw7B\\\\\\n\",\n              \"3XU4bH6wQAtdHYSFhjbqMknt2rXDvHnzoKmphX4/9heojKSkJKbNmIVt27Y1WjsYhhEOO4XIfHVy\\\\\\n\",\n              \"c3MREBCAyMg7+PjxI5SUlGBv3w2jR48WalFdYYwaNQpWP3TCpCm+ApcpLy+HqaEerly5glatWjVJ\\\\\\n\",\n              \"uxjmvyZKfSebxMF8dVRVVTFz5sz/tM5X8fEY5z1RqDKSkpIwa9sOCQkJLMAY5gtgpxAZBkBFeXm9\\\\\\n\",\n              \"95TVREpKCuXl5U3QIoZh6sMCjGEAaGhqIjk5SagyRISkpERoCrgIMcMwjYsFGMMAGOHigv17/YUq\\\\\\n\",\n              
\"cz8qCkWFhejQoUMTtYphmLqwAGMYAMOHD0f0o78Q+/KlwGW2b9uCCRMm1riYMMMwTY/9z2MYAM2a\\\\\\n\",\n              \"NcPChYswytUZ79+/r3f7Qwf241ZEOLy8PP+D1jEMUxMWYAzzP9OmTUWvXr3Q08EO96Oiatzm48eP\\\\\\n\",\n              \"WL1qBZb8shAhISFQU1P7j1vJMMwnbBo9w/yPmJgY/vhjHUx3msLdzRkcDheuI92hraOD4uJi3Im8\\\\\\n\",\n              \"jWOBh2Hv4IBbt25BT0/vSzeZYb5rLMAY5h/ExMQwfrwPvLw8cfHiRZw6dRrhN65CRkYGbVq3wePH\\\\\\n\",\n              \"j6Grq/ulm8kwDFiAMUyNJCQk0L9/f/TvL9jSUgzD/PfYNTCGYRhGJLEAYxiGYUQSCzCGYRhGJLEA\\\\\\n\",\n              \"YxiGYUQSCzCGYRhGJLEAYxiGYUQSCzCGYRhGJLEAYxiGYUQSCzCGYRhGJLEAYxiGYUSSGBHRl25E\\\\\\n\",\n              \"U+BwONDX1//SzWAYhhEpSUlJyMrK+tLNEMg3G2AMwzDMt42dQmQYhmFEEgswhmEYRiSxAGMYhmFE\\\\\\n\",\n              \"EgswhmEYRiSxAGMYhmFEEgswhmEYRiSxAGMYhmFEEgswhmEYRiSxAGMYhmFEEgswhmEYRiSxAGMY\\\\\\n\",\n              \"hmFEEgswhmEYRiSxAGMYhmFEEgswhmEYRiSxAGMYhmFEEgswhmEYRiSxAGMYhmFEEgswhmEYRiSx\\\\\\n\",\n              \"AGMYhmFEEgswhmEYRiSxAGMYhmFEEgswhmEYRiSxAGMYhmFEEgswhmEYRiSxAGMYhmFEEgswhmEY\\\\\\n\",\n              \"RiSxAGMYhmFEEgswhmEYRiSxAGMYhmFEEgswhmEYRiSxAGMYhmFEEgswhmEYRiSxAGMYhmFEEgsw\\\\\\n\",\n              \"hmEYRiT9HztB2uXuIZ0oAAAAAElFTkSuQmCC\\\\\\n\",\n              \"\\\"\\n\",\n              \"\\n\",\n              \"\\n\",\n              \"    /* set a timeout to make sure all the above elements are created before\\n\",\n              \"       the object is initialized. 
*/\\n\",\n              \"    setTimeout(function() {\\n\",\n              \"        animb11f2637772a40ca98bbdbbd7669890b = new Animation(frames, img_id, slider_id, 1000.0,\\n\",\n              \"                                 loop_select_id);\\n\",\n              \"    }, 0);\\n\",\n              \"  })()\\n\",\n              \"</script>\\n\"\n            ]\n          },\n          \"metadata\": {},\n          \"execution_count\": 5\n        }\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Graph Diffusion in GNNs\\n\",\n        \"\\n\",\n        \"[Scalable Inception Graph Neural Networks (SIGN)](https://arxiv.org/abs/2004.11198) leverages multiple diffusion operators simultaneously. Formally, it is defined as:\\n\",\n        \"\\n\",\n        \"$$\\n\",\n        \"Z=\\\\sigma([X\\\\Theta_{0},A_1X\\\\Theta_{1},\\\\cdots,A_rX\\\\Theta_{r}])\\\\\\\\\\n\",\n        \"Y=\\\\xi(Z\\\\Omega)\\n\",\n        \"$$\\n\",\n        \"\\n\",\n        \"where:\\n\",\n        \"* $\\\\sigma$ and $\\\\xi$ are nonlinear activation functions.\\n\",\n        \"* $[\\\\cdot,\\\\cdots,\\\\cdot]$ is the concatenation operation.\\n\",\n        \"* $X\\\\in\\\\mathbb{R}^{n\\\\times d}$ is the input node feature matrix with $n$ nodes and $d$-dimensional feature vector per node.\\n\",\n        \"* $\\\\Theta_0,\\\\cdots,\\\\Theta_r\\\\in\\\\mathbb{R}^{d\\\\times d'}$ are learnable weight matrices.\\n\",\n        \"* $A_1,\\\\cdots, A_r\\\\in\\\\mathbb{R}^{n\\\\times n}$ are linear diffusion operators. 
In the example below, we use $A_i = A^i$, the $i$-th power of $A$, where $A$ is the convolution matrix of the graph.\\n\",\n        \"* $\\\\Omega\\\\in\\\\mathbb{R}^{d'(r+1)\\\\times c}$ is a learnable weight matrix and $c$ is the number of classes.\\n\",\n        \"\\n\",\n        \"The code below implements the diffusion function to compute $A_1X, A_2X, \\\\cdots, A_rX$ and the module that combines all the diffused node features.\"\n      ],\n      \"metadata\": {\n        \"id\": \"unL_mAj-TqC6\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"import torch\\n\",\n        \"import torch.nn as nn\\n\",\n        \"import torch.nn.functional as F\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"################################################################################\\n\",\n        \"# (HIGHLIGHT) Take advantage of DGL sparse APIs to implement the feature\\n\",\n        \"# diffusion in SIGN laconically.\\n\",\n        \"################################################################################\\n\",\n        \"def sign_diffusion(A, X, r):\\n\",\n        \"    # Perform the r-hop diffusion operation.\\n\",\n        \"    X_sign = [X]\\n\",\n        \"    for i in range(r):\\n\",\n        \"        # A^i X\\n\",\n        \"        X = A @ X\\n\",\n        \"        X_sign.append(X)\\n\",\n        \"    return X_sign\\n\",\n        \"\\n\",\n        \"class SIGN(nn.Module):\\n\",\n        \"    def __init__(self, in_size, out_size, r, hidden_size=256):\\n\",\n        \"        super().__init__()\\n\",\n        \"        self.theta = nn.ModuleList(\\n\",\n        \"            [nn.Linear(in_size, hidden_size) for _ in range(r + 1)]\\n\",\n        \"        )\\n\",\n        \"        self.omega = nn.Linear(hidden_size * (r + 1), out_size)\\n\",\n        \"\\n\",\n        \"    def forward(self, X_sign):\\n\",\n        \"        results = []\\n\",\n        \"        for i in range(len(X_sign)):\\n\",\n        \"            
results.append(self.theta[i](X_sign[i]))\\n\",\n        \"        Z = F.relu(torch.cat(results, dim=1))\\n\",\n        \"        return self.omega(Z)\"\n      ],\n      \"metadata\": {\n        \"id\": \"__U3Hsp_S0SR\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Training\\n\",\n        \"\\n\",\n        \"We train the SIGN model on the [Cora dataset](https://docs.dgl.ai/en/latest/generated/dgl.data.CoraGraphDataset.html). The node features are diffused in the pre-processing stage.\"\n      ],\n      \"metadata\": {\n        \"id\": \"ngyh4-YZTkNY\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"from dgl.data import CoraGraphDataset\\n\",\n        \"from torch.optim import Adam\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def evaluate(g, pred):\\n\",\n        \"    label = g.ndata[\\\"label\\\"]\\n\",\n        \"    val_mask = g.ndata[\\\"val_mask\\\"]\\n\",\n        \"    test_mask = g.ndata[\\\"test_mask\\\"]\\n\",\n        \"\\n\",\n        \"    # Compute accuracy on validation/test set.\\n\",\n        \"    val_acc = (pred[val_mask] == label[val_mask]).float().mean()\\n\",\n        \"    test_acc = (pred[test_mask] == label[test_mask]).float().mean()\\n\",\n        \"    return val_acc, test_acc\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def train(model, g, X_sign):\\n\",\n        \"    label = g.ndata[\\\"label\\\"]\\n\",\n        \"    train_mask = g.ndata[\\\"train_mask\\\"]\\n\",\n        \"    optimizer = Adam(model.parameters(), lr=3e-3)\\n\",\n        \"\\n\",\n        \"    for epoch in range(10):\\n\",\n        \"        # Switch the model to training mode.\\n\",\n        \"        model.train()\\n\",\n        \"\\n\",\n        \"        # Forward.\\n\",\n        \"        logits = model(X_sign)\\n\",\n        \"\\n\",\n        \"        # Compute loss with nodes in training set.\\n\",\n     
   \"        loss = F.cross_entropy(logits[train_mask], label[train_mask])\\n\",\n        \"\\n\",\n        \"        # Backward.\\n\",\n        \"        optimizer.zero_grad()\\n\",\n        \"        loss.backward()\\n\",\n        \"        optimizer.step()\\n\",\n        \"\\n\",\n        \"        # Switch the model to evaluating mode.\\n\",\n        \"        model.eval()\\n\",\n        \"\\n\",\n        \"        # Compute prediction.\\n\",\n        \"        logits = model(X_sign)\\n\",\n        \"        pred = logits.argmax(1)\\n\",\n        \"\\n\",\n        \"        # Evaluate the prediction.\\n\",\n        \"        val_acc, test_acc = evaluate(g, pred)\\n\",\n        \"        print(\\n\",\n        \"            f\\\"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test\\\"\\n\",\n        \"            f\\\" acc: {test_acc:.3f}\\\"\\n\",\n        \"        )\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"# If CUDA is available, use GPU to accelerate the training, use CPU\\n\",\n        \"# otherwise.\\n\",\n        \"dev = torch.device(\\\"cuda:0\\\" if torch.cuda.is_available() else \\\"cpu\\\")\\n\",\n        \"\\n\",\n        \"# Load graph from the existing dataset.\\n\",\n        \"dataset = CoraGraphDataset()\\n\",\n        \"g = dataset[0].to(dev)\\n\",\n        \"\\n\",\n        \"# Create the sparse adjacency matrix A (note that W was used as the notation\\n\",\n        \"# for adjacency matrix in the original paper).\\n\",\n        \"indices = torch.stack(g.edges())\\n\",\n        \"N = g.num_nodes()\\n\",\n        \"A = dglsp.spmatrix(indices, shape=(N, N))\\n\",\n        \"\\n\",\n        \"# Calculate the graph convolution matrix.\\n\",\n        \"I = dglsp.identity(A.shape, device=dev)\\n\",\n        \"A_hat = A + I\\n\",\n        \"D_hat_invsqrt = dglsp.diag(A_hat.sum(dim=1)) ** -0.5\\n\",\n        \"A_hat = D_hat_invsqrt @ A_hat @ D_hat_invsqrt\\n\",\n        \"\\n\",\n        \"# 2-hop diffusion.\\n\",\n        \"r = 
2\\n\",\n        \"X = g.ndata[\\\"feat\\\"]\\n\",\n        \"X_sign = sign_diffusion(A_hat, X, r)\\n\",\n        \"\\n\",\n        \"# Create SIGN model.\\n\",\n        \"in_size = X.shape[1]\\n\",\n        \"out_size = dataset.num_classes\\n\",\n        \"model = SIGN(in_size, out_size, r).to(dev)\\n\",\n        \"\\n\",\n        \"# Kick off training.\\n\",\n        \"train(model, g, X_sign)\"\n      ],\n      \"metadata\": {\n        \"id\": \"58WnPtPvT2mx\",\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"outputId\": \"19e86f6a-c7f1-4b40-8cfc-58a181fc30d7\"\n      },\n      \"execution_count\": null,\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stdout\",\n          \"text\": [\n            \"Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...\\n\",\n            \"Extracting file to /root/.dgl/cora_v2\\n\",\n            \"Finished data loading and preprocessing.\\n\",\n            \"  NumNodes: 2708\\n\",\n            \"  NumEdges: 10556\\n\",\n            \"  NumFeats: 1433\\n\",\n            \"  NumClasses: 7\\n\",\n            \"  NumTrainingSamples: 140\\n\",\n            \"  NumValidationSamples: 500\\n\",\n            \"  NumTestSamples: 1000\\n\",\n            \"Done saving data into cached files.\\n\",\n            \"In epoch 0, loss: 1.946, val acc: 0.164, test acc: 0.200\\n\",\n            \"In epoch 1, loss: 1.937, val acc: 0.712, test acc: 0.690\\n\",\n            \"In epoch 2, loss: 1.926, val acc: 0.610, test acc: 0.595\\n\",\n            \"In epoch 3, loss: 1.914, val acc: 0.656, test acc: 0.640\\n\",\n            \"In epoch 4, loss: 1.898, val acc: 0.724, test acc: 0.726\\n\",\n            \"In epoch 5, loss: 1.880, val acc: 0.734, test acc: 0.753\\n\",\n            \"In epoch 6, loss: 1.859, val acc: 0.730, test acc: 0.746\\n\",\n            \"In epoch 7, loss: 1.834, val acc: 0.732, test acc: 0.743\\n\",\n         
   \"In epoch 8, loss: 1.807, val acc: 0.734, test acc: 0.746\\n\",\n            \"In epoch 9, loss: 1.776, val acc: 0.734, test acc: 0.745\\n\"\n          ]\n        }\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"Check out the full example script [here](https://github.com/dmlc/dgl/blob/master/examples/sparse/sign.py). Learn more about how graph diffusion is used in other GNN models:\\n\",\n        \"\\n\",\n        \"* *Predict then Propagate: Graph Neural Networks meet Personalized PageRank* [paper](https://arxiv.org/abs/1810.05997) [code](https://github.com/dmlc/dgl/blob/master/examples/sparse/appnp.py)\\n\",\n        \"* *Combining Label Propagation and Simple Models Out-performs Graph Neural Networks* [paper](https://arxiv.org/abs/2010.13993) [code](https://github.com/dmlc/dgl/blob/master/examples/sparse/c_and_s.py)\\n\",\n        \"* *Simplifying Graph Convolutional Networks* [paper](https://arxiv.org/abs/1902.07153) [code](https://github.com/dmlc/dgl/blob/master/examples/sparse/sgc.py)\\n\",\n        \"* *Graph Neural Networks Inspired by Classical Iterative Algorithms* [paper](https://arxiv.org/pdf/2103.06064.pdf) [code](https://github.com/dmlc/dgl/blob/master/examples/sparse/twirls.py)\"\n      ],\n      \"metadata\": {\n        \"id\": \"lI2Nms8PXq-y\"\n      }\n    }\n  ]\n}\n"
  },
  {
    "path": "notebooks/sparse/graph_transformer.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Jv-tHPvR-JKa\"\n      },\n      \"source\": [\n        \"# Graph Transformer in a Nutshell\\n\",\n        \"\\n\",\n        \"The **Transformer** [(Vaswani et al. 2017)](https://proceedings.neurips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html) has been proven an effective learning architecture in natural language processing and computer vision.\\n\",\n        \"Recently, researchers have turned to exploring the application of transformers in graph learning. They have achieved initial success on many practical tasks, e.g., graph property prediction.\\n\",\n        \"[Dwivedi et al. (2020)](https://arxiv.org/abs/2012.09699) first generalized the transformer neural architecture to graph-structured data. Here, we present how to build such a graph transformer with DGL's sparse matrix APIs.\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/sparse/graph_transformer.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/sparse/graph_transformer.ipynb)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# Install required packages.\\n\",\n        \"import os\\n\",\n        \"import torch\\n\",\n        \"os.environ['TORCH'] = torch.__version__\\n\",\n        \"os.environ['DGLBACKEND'] = \\\"pytorch\\\"\\n\",\n        \"\\n\",\n        \"# Uncomment below to install required packages. 
If the CUDA version is not 11.8,\\n\",\n        \"# check the https://www.dgl.ai/pages/start.html to find the supported CUDA\\n\",\n        \"# version and corresponding command to install DGL.\\n\",\n        \"#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null\\n\",\n        \"#!pip install ogb >/dev/null\\n\",\n        \"\\n\",\n        \"try:\\n\",\n        \"    import dgl\\n\",\n        \"    installed = True\\n\",\n        \"except ImportError:\\n\",\n        \"    installed = False\\n\",\n        \"print(\\\"DGL installed!\\\" if installed else \\\"Failed to install DGL!\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"8wIJZQqODy-7\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"nOpFdtLI-JKb\"\n      },\n      \"source\": [\n        \"## Sparse Multi-head Attention\\n\",\n        \"\\n\",\n        \"Recall the all-pairs scaled-dot-product attention mechanism in the vanilla Transformer:\\n\",\n        \"\\n\",\n        \"$$\\\\text{Attn}=\\\\text{softmax}(\\\\dfrac{QK^T} {\\\\sqrt{d}})V,$$\\n\",\n        \"\\n\",\n        \"The graph transformer (GT) model employs a Sparse Multi-head Attention block:\\n\",\n        \"\\n\",\n        \"$$\\\\text{SparseAttn}(Q, K, V, A) = \\\\text{softmax}(\\\\frac{(QK^T) \\\\circ A}{\\\\sqrt{d}})V,$$\\n\",\n        \"\\n\",\n        \"where $Q, K, V ∈\\\\mathbb{R}^{N\\\\times d}$ are query feature, key feature, and value feature, respectively. $A\\\\in[0,1]^{N\\\\times N}$ is the adjacency matrix of the input graph. 
$(QK^T)\\\\circ A$ means that the multiplication of query matrix and key matrix is followed by a Hadamard product (or element-wise multiplication) with the sparse adjacency matrix as illustrated in the figure below:\\n\",\n        \"\\n\",\n        \"<img src=\\\"https://drive.google.com/uc?id=1OgMAewLR3Z1vz5y4J8aPRSeaU3g8iQfX\\\" width=\\\"500\\\">\\n\",\n        \"\\n\",\n        \"Essentially, only the attention scores between connected nodes are computed according to the sparsity of $A$. This operation is also called *Sampled Dense Dense Matrix Multiplication (SDDMM)*.\\n\",\n        \"\\n\",\n        \"Enjoying the [batched SDDMM API](https://docs.dgl.ai/en/latest/generated/dgl.sparse.bsddmm.html) in DGL, we can parallel the computation on multiple attention heads (different representation subspaces).\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"dh7zc5v0-JKb\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import dgl\\n\",\n        \"import dgl.nn as dglnn\\n\",\n        \"import dgl.sparse as dglsp\\n\",\n        \"import torch\\n\",\n        \"import torch.nn as nn\\n\",\n        \"import torch.nn.functional as F\\n\",\n        \"import torch.optim as optim\\n\",\n        \"\\n\",\n        \"from dgl.data import AsGraphPredDataset\\n\",\n        \"from dgl.dataloading import GraphDataLoader\\n\",\n        \"from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator\\n\",\n        \"from ogb.graphproppred.mol_encoder import AtomEncoder\\n\",\n        \"from tqdm import tqdm\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"class SparseMHA(nn.Module):\\n\",\n        \"    \\\"\\\"\\\"Sparse Multi-head Attention Module\\\"\\\"\\\"\\n\",\n        \"\\n\",\n        \"    def __init__(self, hidden_size=80, num_heads=8):\\n\",\n        \"        super().__init__()\\n\",\n        \"        self.hidden_size = 
hidden_size\\n\",\n        \"        self.num_heads = num_heads\\n\",\n        \"        self.head_dim = hidden_size // num_heads\\n\",\n        \"        self.scaling = self.head_dim**-0.5\\n\",\n        \"\\n\",\n        \"        self.q_proj = nn.Linear(hidden_size, hidden_size)\\n\",\n        \"        self.k_proj = nn.Linear(hidden_size, hidden_size)\\n\",\n        \"        self.v_proj = nn.Linear(hidden_size, hidden_size)\\n\",\n        \"        self.out_proj = nn.Linear(hidden_size, hidden_size)\\n\",\n        \"\\n\",\n        \"    def forward(self, A, h):\\n\",\n        \"        N = len(h)\\n\",\n        \"        # [N, dh, nh]\\n\",\n        \"        q = self.q_proj(h).reshape(N, self.head_dim, self.num_heads)\\n\",\n        \"        q *= self.scaling\\n\",\n        \"        # [N, dh, nh]\\n\",\n        \"        k = self.k_proj(h).reshape(N, self.head_dim, self.num_heads)\\n\",\n        \"        # [N, dh, nh]\\n\",\n        \"        v = self.v_proj(h).reshape(N, self.head_dim, self.num_heads)\\n\",\n        \"\\n\",\n        \"        ######################################################################\\n\",\n        \"        # (HIGHLIGHT) Compute the multi-head attention with Sparse Matrix API\\n\",\n        \"        ######################################################################\\n\",\n        \"        attn = dglsp.bsddmm(A, q, k.transpose(1, 0))  # (sparse) [N, N, nh]\\n\",\n        \"        # Sparse softmax by default applies on the last sparse dimension.\\n\",\n        \"        attn = attn.softmax()  # (sparse) [N, N, nh]\\n\",\n        \"        out = dglsp.bspmm(attn, v)  # [N, dh, nh]\\n\",\n        \"\\n\",\n        \"        return self.out_proj(out.reshape(N, -1))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"3_Fm6Lrx-JKc\"\n      },\n      \"source\": [\n        \"## Graph Transformer Layer\\n\",\n        \"\\n\",\n        \"The GT layer is composed of Multi-head 
Attention, Batch Norm, and Feed-forward Network, connected by residual links as in vanilla transformer.\\n\",\n        \"\\n\",\n        \"<img src=\\\"https://drive.google.com/uc?id=1cm-Ijw7bUQIOkoTKn5MQ3m4-66JqCsMz\\\" width=\\\"300\\\">\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"M6h7JVWT-JKd\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"class GTLayer(nn.Module):\\n\",\n        \"    \\\"\\\"\\\"Graph Transformer Layer\\\"\\\"\\\"\\n\",\n        \"\\n\",\n        \"    def __init__(self, hidden_size=80, num_heads=8):\\n\",\n        \"        super().__init__()\\n\",\n        \"        self.MHA = SparseMHA(hidden_size=hidden_size, num_heads=num_heads)\\n\",\n        \"        self.batchnorm1 = nn.BatchNorm1d(hidden_size)\\n\",\n        \"        self.batchnorm2 = nn.BatchNorm1d(hidden_size)\\n\",\n        \"        self.FFN1 = nn.Linear(hidden_size, hidden_size * 2)\\n\",\n        \"        self.FFN2 = nn.Linear(hidden_size * 2, hidden_size)\\n\",\n        \"\\n\",\n        \"    def forward(self, A, h):\\n\",\n        \"        h1 = h\\n\",\n        \"        h = self.MHA(A, h)\\n\",\n        \"        h = self.batchnorm1(h + h1)\\n\",\n        \"\\n\",\n        \"        h2 = h\\n\",\n        \"        h = self.FFN2(F.relu(self.FFN1(h)))\\n\",\n        \"        h = h2 + h\\n\",\n        \"\\n\",\n        \"        return self.batchnorm2(h)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"t40DhVjI-JKd\"\n      },\n      \"source\": [\n        \"## Graph Transformer Model\\n\",\n        \"\\n\",\n        \"The GT model is constructed by stacking GT layers. The input positional encoding of vanilla transformer is replaced with Laplacian positional encoding [(Dwivedi et al. 2020)](https://arxiv.org/abs/2003.00982). 
For the graph-level prediction task, an extra pooler is stacked on top of GT layers to aggregate node feature of the same graph.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"UrjvEBrF-JKe\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"class GTModel(nn.Module):\\n\",\n        \"    def __init__(\\n\",\n        \"        self,\\n\",\n        \"        out_size,\\n\",\n        \"        hidden_size=80,\\n\",\n        \"        pos_enc_size=2,\\n\",\n        \"        num_layers=8,\\n\",\n        \"        num_heads=8,\\n\",\n        \"    ):\\n\",\n        \"        super().__init__()\\n\",\n        \"        self.atom_encoder = AtomEncoder(hidden_size)\\n\",\n        \"        self.pos_linear = nn.Linear(pos_enc_size, hidden_size)\\n\",\n        \"        self.layers = nn.ModuleList(\\n\",\n        \"            [GTLayer(hidden_size, num_heads) for _ in range(num_layers)]\\n\",\n        \"        )\\n\",\n        \"        self.pooler = dglnn.SumPooling()\\n\",\n        \"        self.predictor = nn.Sequential(\\n\",\n        \"            nn.Linear(hidden_size, hidden_size // 2),\\n\",\n        \"            nn.ReLU(),\\n\",\n        \"            nn.Linear(hidden_size // 2, hidden_size // 4),\\n\",\n        \"            nn.ReLU(),\\n\",\n        \"            nn.Linear(hidden_size // 4, out_size),\\n\",\n        \"        )\\n\",\n        \"\\n\",\n        \"    def forward(self, g, X, pos_enc):\\n\",\n        \"        indices = torch.stack(g.edges())\\n\",\n        \"        N = g.num_nodes()\\n\",\n        \"        A = dglsp.spmatrix(indices, shape=(N, N))\\n\",\n        \"        h = self.atom_encoder(X) + self.pos_linear(pos_enc)\\n\",\n        \"        for layer in self.layers:\\n\",\n        \"            h = layer(A, h)\\n\",\n        \"        h = self.pooler(g, h)\\n\",\n        \"\\n\",\n        \"        return self.predictor(h)\"\n      ]\n   
 },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"RdrPU18I-JKe\"\n      },\n      \"source\": [\n        \"## Training\\n\",\n        \"\\n\",\n        \"We train the GT model on the [ogbg-molhiv](https://ogb.stanford.edu/docs/graphprop/#ogbg-mol) benchmark. The Laplacian positional encoding of each graph is pre-computed (with the API [here](https://docs.dgl.ai/en/latest/generated/dgl.laplacian_pe.html)) as part of the input to the model.\\n\",\n        \"\\n\",\n        \"*Note that we down-sample the dataset to make this demo run faster. See the* [*example script*](https://github.com/dmlc/dgl/blob/master/examples/sparse/graph_transformer.py) *for the performance on the full dataset.*\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"V41i0w-9-JKe\",\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"outputId\": \"15343d1a-a32d-4677-d053-d9da96910f43\"\n      },\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stderr\",\n          \"text\": [\n            \"Computing Laplacian PE:   1%|          | 25/4000 [00:00<00:16, 244.77it/s]/usr/local/lib/python3.8/dist-packages/dgl/backend/pytorch/tensor.py:52: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:250.)\\n\",\n            \"  return th.as_tensor(data, dtype=dtype)\\n\",\n            \"Computing Laplacian PE: 100%|██████████| 4000/4000 [00:13<00:00, 296.04it/s]\\n\"\n          ]\n        },\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stdout\",\n          \"text\": [\n            \"Epoch: 000, Loss: 0.2486, Val: 0.3082, Test: 0.3068\\n\",\n            \"Epoch: 001, Loss: 0.1695, Val: 0.4684, Test: 0.4572\\n\",\n            \"Epoch: 002, Loss: 0.1428, Val: 0.5887, Test: 0.4721\\n\",\n            
\"Epoch: 003, Loss: 0.1237, Val: 0.6375, Test: 0.5010\\n\",\n            \"Epoch: 004, Loss: 0.1127, Val: 0.6628, Test: 0.4854\\n\",\n            \"Epoch: 005, Loss: 0.1047, Val: 0.6811, Test: 0.4983\\n\",\n            \"Epoch: 006, Loss: 0.0949, Val: 0.6751, Test: 0.5409\\n\",\n            \"Epoch: 007, Loss: 0.0901, Val: 0.6340, Test: 0.5357\\n\",\n            \"Epoch: 008, Loss: 0.0811, Val: 0.6717, Test: 0.5543\\n\",\n            \"Epoch: 009, Loss: 0.0643, Val: 0.7861, Test: 0.5628\\n\",\n            \"Epoch: 010, Loss: 0.0489, Val: 0.7319, Test: 0.5341\\n\",\n            \"Epoch: 011, Loss: 0.0340, Val: 0.7884, Test: 0.5299\\n\",\n            \"Epoch: 012, Loss: 0.0285, Val: 0.5887, Test: 0.4293\\n\",\n            \"Epoch: 013, Loss: 0.0361, Val: 0.5514, Test: 0.3419\\n\",\n            \"Epoch: 014, Loss: 0.0451, Val: 0.6795, Test: 0.4964\\n\",\n            \"Epoch: 015, Loss: 0.0429, Val: 0.7405, Test: 0.5527\\n\",\n            \"Epoch: 016, Loss: 0.0331, Val: 0.7859, Test: 0.4994\\n\",\n            \"Epoch: 017, Loss: 0.0177, Val: 0.6544, Test: 0.4457\\n\",\n            \"Epoch: 018, Loss: 0.0201, Val: 0.8250, Test: 0.6073\\n\",\n            \"Epoch: 019, Loss: 0.0093, Val: 0.7356, Test: 0.5561\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"@torch.no_grad()\\n\",\n        \"def evaluate(model, dataloader, evaluator, device):\\n\",\n        \"    model.eval()\\n\",\n        \"    y_true = []\\n\",\n        \"    y_pred = []\\n\",\n        \"    for batched_g, labels in dataloader:\\n\",\n        \"        batched_g, labels = batched_g.to(device), labels.to(device)\\n\",\n        \"        y_hat = model(batched_g, batched_g.ndata[\\\"feat\\\"], batched_g.ndata[\\\"PE\\\"])\\n\",\n        \"        y_true.append(labels.view(y_hat.shape).detach().cpu())\\n\",\n        \"        y_pred.append(y_hat.detach().cpu())\\n\",\n        \"    y_true = torch.cat(y_true, dim=0).numpy()\\n\",\n        \"    y_pred = torch.cat(y_pred, 
dim=0).numpy()\\n\",\n        \"    input_dict = {\\\"y_true\\\": y_true, \\\"y_pred\\\": y_pred}\\n\",\n        \"    return evaluator.eval(input_dict)[\\\"rocauc\\\"]\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def train(model, dataset, evaluator, device):\\n\",\n        \"    train_dataloader = GraphDataLoader(\\n\",\n        \"        dataset[dataset.train_idx],\\n\",\n        \"        batch_size=256,\\n\",\n        \"        shuffle=True,\\n\",\n        \"        collate_fn=collate_dgl,\\n\",\n        \"    )\\n\",\n        \"    valid_dataloader = GraphDataLoader(\\n\",\n        \"        dataset[dataset.val_idx], batch_size=256, collate_fn=collate_dgl\\n\",\n        \"    )\\n\",\n        \"    test_dataloader = GraphDataLoader(\\n\",\n        \"        dataset[dataset.test_idx], batch_size=256, collate_fn=collate_dgl\\n\",\n        \"    )\\n\",\n        \"    optimizer = optim.Adam(model.parameters(), lr=0.001)\\n\",\n        \"    num_epochs = 20\\n\",\n        \"    scheduler = optim.lr_scheduler.StepLR(\\n\",\n        \"        optimizer, step_size=num_epochs, gamma=0.5\\n\",\n        \"    )\\n\",\n        \"    loss_fcn = nn.BCEWithLogitsLoss()\\n\",\n        \"\\n\",\n        \"    for epoch in range(num_epochs):\\n\",\n        \"        model.train()\\n\",\n        \"        total_loss = 0.0\\n\",\n        \"        for batched_g, labels in train_dataloader:\\n\",\n        \"            batched_g, labels = batched_g.to(device), labels.to(device)\\n\",\n        \"            logits = model(\\n\",\n        \"                batched_g, batched_g.ndata[\\\"feat\\\"], batched_g.ndata[\\\"PE\\\"]\\n\",\n        \"            )\\n\",\n        \"            loss = loss_fcn(logits, labels.float())\\n\",\n        \"            total_loss += loss.item()\\n\",\n        \"            optimizer.zero_grad()\\n\",\n        \"            loss.backward()\\n\",\n        \"            optimizer.step()\\n\",\n        \"        scheduler.step()\\n\",\n        
\"        avg_loss = total_loss / len(train_dataloader)\\n\",\n        \"        val_metric = evaluate(model, valid_dataloader, evaluator, device)\\n\",\n        \"        test_metric = evaluate(model, test_dataloader, evaluator, device)\\n\",\n        \"        print(\\n\",\n        \"            f\\\"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, \\\"\\n\",\n        \"            f\\\"Val: {val_metric:.4f}, Test: {test_metric:.4f}\\\"\\n\",\n        \"        )\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"# Training device.\\n\",\n        \"dev = torch.device(\\\"cpu\\\")\\n\",\n        \"# Uncomment the code below to train on GPU. Be sure to install DGL with CUDA support.\\n\",\n        \"#dev = torch.device(\\\"cuda:0\\\")\\n\",\n        \"\\n\",\n        \"# Load dataset.\\n\",\n        \"pos_enc_size = 8\\n\",\n        \"dataset = AsGraphPredDataset(\\n\",\n        \"    DglGraphPropPredDataset(\\\"ogbg-molhiv\\\", \\\"./data/OGB\\\")\\n\",\n        \")\\n\",\n        \"evaluator = Evaluator(\\\"ogbg-molhiv\\\")\\n\",\n        \"\\n\",\n        \"# Down sample the dataset to make the tutorial run faster.\\n\",\n        \"import random\\n\",\n        \"random.seed(42)\\n\",\n        \"train_size = len(dataset.train_idx)\\n\",\n        \"val_size = len(dataset.val_idx)\\n\",\n        \"test_size = len(dataset.test_idx)\\n\",\n        \"dataset.train_idx = dataset.train_idx[\\n\",\n        \"    torch.LongTensor(random.sample(range(train_size), 2000))\\n\",\n        \"]\\n\",\n        \"dataset.val_idx = dataset.val_idx[\\n\",\n        \"    torch.LongTensor(random.sample(range(val_size), 1000))\\n\",\n        \"]\\n\",\n        \"dataset.test_idx = dataset.test_idx[\\n\",\n        \"    torch.LongTensor(random.sample(range(test_size), 1000))\\n\",\n        \"]\\n\",\n        \"\\n\",\n        \"# Laplacian positional encoding.\\n\",\n        \"indices = torch.cat([dataset.train_idx, dataset.val_idx, dataset.test_idx])\\n\",\n        \"for idx in 
tqdm(indices, desc=\\\"Computing Laplacian PE\\\"):\\n\",\n        \"    g, _ = dataset[idx]\\n\",\n        \"    g.ndata[\\\"PE\\\"] = dgl.laplacian_pe(g, k=pos_enc_size, padding=True)\\n\",\n        \"\\n\",\n        \"# Create model.\\n\",\n        \"out_size = dataset.num_tasks\\n\",\n        \"model = GTModel(out_size=out_size, pos_enc_size=pos_enc_size).to(dev)\\n\",\n        \"\\n\",\n        \"# Kick off training.\\n\",\n        \"train(model, dataset, evaluator, dev)\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"language_info\": {\n      \"name\": \"python\"\n    },\n    \"orig_nbformat\": 4,\n    \"colab\": {\n      \"provenance\": []\n    },\n    \"gpuClass\": \"standard\",\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    },\n    \"accelerator\": \"GPU\"\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "notebooks/sparse/hgnn.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"provenance\": [],\n      \"toc_visible\": true\n    },\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    },\n    \"gpuClass\": \"standard\",\n    \"accelerator\": \"GPU\"\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"# Hypergraph Neural Networks\\n\",\n        \"\\n\",\n        \"This tutorial illustrates what a hypergraph is and how to build a Hypergraph Neural Network using DGL's sparse matrix APIs.\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/sparse/hgnn.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/sparse/hgnn.ipynb)\"\n      ],\n      \"metadata\": {\n        \"id\": \"eiDu3XgReCt4\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# Install required packages.\\n\",\n        \"import os\\n\",\n        \"import torch\\n\",\n        \"os.environ['TORCH'] = torch.__version__\\n\",\n        \"os.environ['DGLBACKEND'] = \\\"pytorch\\\"\\n\",\n        \"\\n\",\n        \"# Uncomment below to install required packages. 
If the CUDA version is not 11.8,\\n\",\n        \"# check the https://www.dgl.ai/pages/start.html to find the supported CUDA\\n\",\n        \"# version and corresponding command to install DGL.\\n\",\n        \"#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null\\n\",\n        \"#!pip install torchmetrics > /dev/null\\n\",\n        \"\\n\",\n        \"try:\\n\",\n        \"    import dgl\\n\",\n        \"    installed = True\\n\",\n        \"except ImportError:\\n\",\n        \"    installed = False\\n\",\n        \"print(\\\"DGL installed!\\\" if installed else \\\"Failed to install DGL!\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"__2tKqL0eaB0\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Hypergraphs\\n\",\n        \"\\n\",\n        \"A [hypergraph](https://en.wikipedia.org/wiki/Hypergraph) consists of *nodes* and *hyperedges*.  Contrary to edges in graphs, a *hyperedge* can connect an arbitrary number of nodes.  For instance, the following figure shows a hypergraph with 11 nodes and 5 hyperedges drawn in different colors.\\n\",\n        \"![](https://data.dgl.ai/tutorial/img/hgnn/hypergraph4.PNG)\\n\",\n        \"\\n\",\n        \"Hypergraphs are particularly useful when the relationships between data points within the dataset are not binary.  For instance, more than two products can be co-purchased together in an e-commerce system, so the relationship of co-purchase is $n$-ary rather than binary, and therefore it is better described as a hypergraph rather than a normal graph.\\n\",\n        \"\\n\",\n        \"A hypergraph is usually characterized by its *incidence matrix* $H$, whose rows represent nodes and columns represent hyperedges.  An entry $H_{ij}$ is 1 if hyperedge $j$ includes node $i$, or 0 otherwise.  
For example, the hypergraph in the figure above can be characterized by a $11 \\\\times 5$ matrix as follows:\\n\",\n        \"\\n\",\n        \"$$\\n\",\n        \"H = \\\\begin{bmatrix}\\n\",\n        \"1 & 0 & 0 & 0 & 0 \\\\\\\\\\n\",\n        \"1 & 0 & 0 & 0 & 0 \\\\\\\\\\n\",\n        \"1 & 1 & 0 & 1 & 1 \\\\\\\\\\n\",\n        \"0 & 0 & 1 & 0 & 0 \\\\\\\\\\n\",\n        \"0 & 1 & 0 & 0 & 0 \\\\\\\\\\n\",\n        \"1 & 0 & 1 & 1 & 1 \\\\\\\\\\n\",\n        \"0 & 0 & 1 & 0 & 0 \\\\\\\\\\n\",\n        \"0 & 1 & 0 & 1 & 0 \\\\\\\\\\n\",\n        \"0 & 1 & 0 & 1 & 0 \\\\\\\\\\n\",\n        \"0 & 0 & 1 & 0 & 1 \\\\\\\\\\n\",\n        \"0 & 0 & 0 & 0 & 1 \\\\\\\\\\n\",\n        \"\\\\end{bmatrix}\\n\",\n        \"$$\\n\",\n        \"\\n\",\n        \"One can construct the hypergraph incidence matrix by specifying two tensors `nodes` and `hyperedges`, where the node ID `nodes[i]` belongs to the hyperedge ID `hyperedges[i]` for all `i`.  In the case above, the incidence matrix can be constructed below.\\n\"\n      ],\n      \"metadata\": {\n        \"id\": \"unL_mAj-TqC6\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"import dgl.sparse as dglsp\\n\",\n        \"import torch\\n\",\n        \"\\n\",\n        \"H = dglsp.spmatrix(\\n\",\n        \"    torch.LongTensor([[0, 1, 2, 2, 2, 2, 3, 4, 5, 5, 5, 5, 6, 7, 7, 8, 8, 9, 9, 10],\\n\",\n        \"                      [0, 0, 0, 1, 3, 4, 2, 1, 0, 2, 3, 4, 2, 1, 3, 1, 3, 2, 4, 4]])\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"print(H.to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"I_cExvtIJD1F\",\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"outputId\": \"a1a576f6-1559-479c-9f3e-93e41a56833d\"\n      },\n      \"execution_count\": null,\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stdout\",\n          \"text\": [\n            \"tensor([[1., 0., 0., 0., 
0.],\\n\",\n            \"        [1., 0., 0., 0., 0.],\\n\",\n            \"        [1., 1., 0., 1., 1.],\\n\",\n            \"        [0., 0., 1., 0., 0.],\\n\",\n            \"        [0., 1., 0., 0., 0.],\\n\",\n            \"        [1., 0., 1., 1., 1.],\\n\",\n            \"        [0., 0., 1., 0., 0.],\\n\",\n            \"        [0., 1., 0., 1., 0.],\\n\",\n            \"        [0., 1., 0., 1., 0.],\\n\",\n            \"        [0., 0., 1., 0., 1.],\\n\",\n            \"        [0., 0., 0., 0., 1.]])\\n\"\n          ]\n        }\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"The degree of a node in a hypergraph is defined as the number of hyperedges including the node.  Similarly, the degree of a hyperedge in a hypergraph is defined as the number of nodes included by the hyperedge.  In the example above, the hyperedge degrees can be computed by the sum of row vectors (i.e. all 4), while the node degree can be computed by the sum of column vectors.\"\n      ],\n      \"metadata\": {\n        \"id\": \"p-shCPQPHvBB\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"node_degrees = H.sum(1)\\n\",\n        \"print(\\\"Node degrees\\\", node_degrees)\\n\",\n        \"\\n\",\n        \"hyperedge_degrees = H.sum(0)\\n\",\n        \"print(\\\"Hyperedge degrees\\\", hyperedge_degrees)\"\n      ],\n      \"metadata\": {\n        \"id\": \"wjKm9gkTOnU9\",\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"outputId\": \"ffe2c441-8c2c-48a7-cef2-4ef6e96548ec\"\n      },\n      \"execution_count\": null,\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stdout\",\n          \"text\": [\n            \"Node degrees tensor([1., 1., 4., 1., 1., 4., 1., 2., 2., 2., 1.])\\n\",\n            \"Hyperedge degrees tensor([4., 4., 4., 4., 4.])\\n\"\n          ]\n        }\n      ]\n    },\n    {\n      
\"cell_type\": \"markdown\",\n      \"source\": [\n        \"\\n\",\n        \"## Hypergraph Neural Network (HGNN) Layer\\n\",\n        \"\\n\",\n        \"The [HGNN layer](https://arxiv.org/pdf/1809.09401.pdf) is defined as:\\n\",\n        \"\\n\",\n        \"$$f(X^{(l)}, H; W^{(l)}) = \\\\sigma(L X^{(l)} W^{(l)})$$$$L = D_v^{-1/2} H B D_e^{-1} H^\\\\top D_v^{-1/2}$$\\n\",\n        \"\\n\",\n        \"where\\n\",\n        \"\\n\",\n        \"* $H \\\\in \\\\mathbb{R}^{N \\\\times M}$ is the incidence matrix of hypergraph with $N$ nodes and $M$ hyperedges.\\n\",\n        \"* $D_v \\\\in \\\\mathbb{R}^{N \\\\times N}$ is a diagonal matrix representing node degrees, whose $i$-th diagonal element is $\\\\sum_{j=1}^M H_{ij}$.\\n\",\n        \"* $D_e \\\\in \\\\mathbb{R}^{M \\\\times M}$ is a diagonal matrix representing hyperedge degrees, whose $j$-th diagonal element is $\\\\sum_{i=1}^N H_{ij}$.\\n\",\n        \"* $B \\\\in \\\\mathbb{R}^{M \\\\times M}$ is a diagonal matrix representing the hyperedge weights, whose $j$-th diagonal element is the weight of $j$-th hyperedge.  
In our example, $B$ is an identity matrix.\\n\",\n        \"\\n\",\n        \"The following code builds a two-layer HGNN.\"\n      ],\n      \"metadata\": {\n        \"id\": \"7kxrINkVHrAi\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"import dgl.sparse as dglsp\\n\",\n        \"import torch\\n\",\n        \"import torch.nn as nn\\n\",\n        \"import torch.nn.functional as F\\n\",\n        \"import tqdm\\n\",\n        \"from dgl.data import CoraGraphDataset\\n\",\n        \"from torchmetrics.functional import accuracy\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"class HGNN(nn.Module):\\n\",\n        \"    def __init__(self, H, in_size, out_size, hidden_dims=16):\\n\",\n        \"        super().__init__()\\n\",\n        \"\\n\",\n        \"        self.W1 = nn.Linear(in_size, hidden_dims)\\n\",\n        \"        self.W2 = nn.Linear(hidden_dims, out_size)\\n\",\n        \"        self.dropout = nn.Dropout(0.5)\\n\",\n        \"\\n\",\n        \"        ###########################################################\\n\",\n        \"        # (HIGHLIGHT) Compute the Laplacian with Sparse Matrix API\\n\",\n        \"        ###########################################################\\n\",\n        \"        # Compute node degree.\\n\",\n        \"        d_V = H.sum(1)\\n\",\n        \"        # Compute edge degree.\\n\",\n        \"        d_E = H.sum(0)\\n\",\n        \"        # Compute the inverse of the square root of the diagonal D_v.\\n\",\n        \"        D_v_invsqrt = dglsp.diag(d_V**-0.5)\\n\",\n        \"        # Compute the inverse of the diagonal D_e.\\n\",\n        \"        D_e_inv = dglsp.diag(d_E**-1)\\n\",\n        \"        # In our example, B is an identity matrix.\\n\",\n        \"        n_edges = d_E.shape[0]\\n\",\n        \"        B = dglsp.identity((n_edges, n_edges))\\n\",\n        \"        # Compute Laplacian from the equation above.\\n\",\n        \"        self.L = D_v_invsqrt @ H @ 
B @ D_e_inv @ H.T @ D_v_invsqrt\\n\",\n        \"\\n\",\n        \"    def forward(self, X):\\n\",\n        \"        X = self.L @ self.W1(self.dropout(X))\\n\",\n        \"        X = F.relu(X)\\n\",\n        \"        X = self.L @ self.W2(self.dropout(X))\\n\",\n        \"        return X\"\n      ],\n      \"metadata\": {\n        \"id\": \"58WnPtPvT2mx\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Loading Data\\n\",\n        \"\\n\",\n        \"We use Cora citation network in our example.  But instead of using the original \\\"cite\\\" relationship between papers, we consider the \\\"co-cite\\\" relationship between papers.  We build a hypergraph from the original citation network where for each paper we construct a hyperedge that includes all the other papers it cited, as well as the paper itself.\\n\",\n        \"\\n\",\n        \"![](https://data.dgl.ai/tutorial/img/hgnn/equiv.PNG)\\n\",\n        \"\\n\",\n        \"Note that a hypergraph constructed this way has an incidence matrix exactly identical to the adjacency matrix of the original graph (plus an identity matrix for self-loops).  This is because each hyperedge has a one-to-one correspondence to each paper.  
So we can directly take the graph's adjacency matrix and add an identity matrix to it, and we use it as the hypergraph's incidence matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"bPrOHVaGwUD0\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"def load_data():\\n\",\n        \"    dataset = CoraGraphDataset()\\n\",\n        \"\\n\",\n        \"    graph = dataset[0]\\n\",\n        \"    indices = torch.stack(graph.edges())\\n\",\n        \"    H = dglsp.spmatrix(indices)\\n\",\n        \"    H = H + dglsp.identity(H.shape)\\n\",\n        \"\\n\",\n        \"    X = graph.ndata[\\\"feat\\\"]\\n\",\n        \"    Y = graph.ndata[\\\"label\\\"]\\n\",\n        \"    train_mask = graph.ndata[\\\"train_mask\\\"]\\n\",\n        \"    val_mask = graph.ndata[\\\"val_mask\\\"]\\n\",\n        \"    test_mask = graph.ndata[\\\"test_mask\\\"]\\n\",\n        \"    return H, X, Y, dataset.num_classes, train_mask, val_mask, test_mask\"\n      ],\n      \"metadata\": {\n        \"id\": \"qI0j1J9pwTFg\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Training and Evaluation\\n\",\n        \"\\n\",\n        \"Now we can write the training and evaluation functions as follows.\"\n      ],\n      \"metadata\": {\n        \"id\": \"--rq1-r7wMST\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"def train(model, optimizer, X, Y, train_mask):\\n\",\n        \"    model.train()\\n\",\n        \"    Y_hat = model(X)\\n\",\n        \"    loss = F.cross_entropy(Y_hat[train_mask], Y[train_mask])\\n\",\n        \"    optimizer.zero_grad()\\n\",\n        \"    loss.backward()\\n\",\n        \"    optimizer.step()\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def evaluate(model, X, Y, val_mask, test_mask, num_classes):\\n\",\n        \"    model.eval()\\n\",\n        \"    Y_hat = model(X)\\n\",\n  
      \"    val_acc = accuracy(\\n\",\n        \"        Y_hat[val_mask], Y[val_mask], task=\\\"multiclass\\\", num_classes=num_classes\\n\",\n        \"    )\\n\",\n        \"    test_acc = accuracy(\\n\",\n        \"        Y_hat[test_mask],\\n\",\n        \"        Y[test_mask],\\n\",\n        \"        task=\\\"multiclass\\\",\\n\",\n        \"        num_classes=num_classes,\\n\",\n        \"    )\\n\",\n        \"    return val_acc, test_acc\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()\\n\",\n        \"model = HGNN(H, X.shape[1], num_classes)\\n\",\n        \"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\\n\",\n        \"\\n\",\n        \"with tqdm.trange(500) as tq:\\n\",\n        \"    for epoch in tq:\\n\",\n        \"        train(model, optimizer, X, Y, train_mask)\\n\",\n        \"        val_acc, test_acc = evaluate(\\n\",\n        \"            model, X, Y, val_mask, test_mask, num_classes\\n\",\n        \"        )\\n\",\n        \"        tq.set_postfix(\\n\",\n        \"            {\\n\",\n        \"                \\\"Val acc\\\": f\\\"{val_acc:.5f}\\\",\\n\",\n        \"                \\\"Test acc\\\": f\\\"{test_acc:.5f}\\\",\\n\",\n        \"            },\\n\",\n        \"            refresh=False,\\n\",\n        \"        )\\n\",\n        \"\\n\",\n        \"print(f\\\"Test acc: {test_acc:.3f}\\\")\"\n      ],\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"IfEc6JRXwHPt\",\n        \"outputId\": \"0172578a-6a1b-49eb-adcb-77ee1a949186\"\n      },\n      \"execution_count\": null,\n      \"outputs\": [\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stdout\",\n          \"text\": [\n            \"Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...\\n\",\n            \"Extracting file to /root/.dgl/cora_v2\\n\",\n        
    \"Finished data loading and preprocessing.\\n\",\n            \"  NumNodes: 2708\\n\",\n            \"  NumEdges: 10556\\n\",\n            \"  NumFeats: 1433\\n\",\n            \"  NumClasses: 7\\n\",\n            \"  NumTrainingSamples: 140\\n\",\n            \"  NumValidationSamples: 500\\n\",\n            \"  NumTestSamples: 1000\\n\",\n            \"Done saving data into cached files.\\n\"\n          ]\n        },\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stderr\",\n          \"text\": [\n            \"100%|██████████| 500/500 [00:57<00:00,  8.70it/s, Val acc=0.77800, Test acc=0.78100]\"\n          ]\n        },\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stdout\",\n          \"text\": [\n            \"Test acc: 0.781\\n\"\n          ]\n        },\n        {\n          \"output_type\": \"stream\",\n          \"name\": \"stderr\",\n          \"text\": [\n            \"\\n\"\n          ]\n        }\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"For the complete example of HGNN, please refer to [here](https://github.com/dmlc/dgl/blob/master/examples/sparse/hgnn.py).\"\n      ],\n      \"metadata\": {\n        \"id\": \"59pCzjpBOyEW\"\n      }\n    }\n  ]\n}\n"
  },
  {
    "path": "notebooks/sparse/quickstart.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"provenance\": [],\n      \"private_outputs\": true,\n      \"toc_visible\": true,\n      \"gpuType\": \"T4\"\n    },\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    }\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"# Quickstart\\n\",\n        \"\\n\",\n        \"The tutorial provides a quick walkthrough of the classes and operators provided by the `dgl.sparse` package.\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/sparse/quickstart.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/sparse/quickstart.ipynb)\"\n      ],\n      \"metadata\": {\n        \"id\": \"E0DAKDMuWz7I\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# Install the required packages.\\n\",\n        \"\\n\",\n        \"import os\\n\",\n        \"# Uncomment following commands to download Pytorch and DGL\\n\",\n        \"# !pip install torch==2.0.0+cpu torchvision==0.15.1+cpu torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cpu > /dev/null\\n\",\n        \"# !pip install  dgl==1.1.0 -f https://data.dgl.ai/wheels/repo.html > /dev/null\\n\",\n        \"import torch\\n\",\n        \"os.environ['TORCH'] = torch.__version__\\n\",\n        \"os.environ['DGLBACKEND'] = \\\"pytorch\\\"\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"try:\\n\",\n        \"    import dgl.sparse as dglsp\\n\",\n        \"    installed = True\\n\",\n        \"except ImportError:\\n\",\n        \"    installed = False\\n\",\n        \"print(\\\"DGL installed!\\\" if 
installed else \\\"DGL not found!\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"19UZd7wyWzpT\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Sparse Matrix\\n\",\n        \"\\n\",\n        \"The core abstraction of DGL's sparse package is the `SparseMatrix` class. Compared with other sparse matrix libraries (such as `scipy.sparse` and `torch.sparse`), DGL's `SparseMatrix` is specialized for deep learning workloads on structured data (e.g., Graph Neural Networks), with the following features:\\n\",\n        \"\\n\",\n        \"* **Auto sparse format.** Don't bother choosing between different sparse formats. There is only one `SparseMatrix` and it will select the best format for the operation to be performed.\\n\",\n        \"* **Non-zero elements can be scalar or vector.** Easy for modeling relations (e.g., edges) by vector representation.\\n\",\n        \"* **Fully PyTorch compatible.** The package is built upon PyTorch and is natively compatible with other tools in the PyTorch ecosystem.\\n\"\n      ],\n      \"metadata\": {\n        \"id\": \"GsWoAGC4RpHw\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"### Creating a DGL Sparse Matrix\\n\",\n        \"\\n\",\n        \"The simplest way to create a sparse matrix is using the `spmatrix` API by providing the indices of the non-zero elements. The indices are stored in a tensor of shape `(2, nnz)`, where the `i`-th non-zero element is stored at position `(indices[0][i], indices[1][i])`. 
The code below creates a 3x3 sparse matrix.\\n\"\n      ],\n      \"metadata\": {\n        \"id\": \"_q4HYodcWenB\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"h-ryVEs1PuIP\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import torch\\n\",\n        \"import dgl.sparse as dglsp\\n\",\n        \"\\n\",\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"A = dglsp.spmatrix(i)  # 1.0 is default value for nnz elements.\\n\",\n        \"\\n\",\n        \"print(A)\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"print(\\\"In dense format:\\\")\\n\",\n        \"print(A.to_dense())\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"If not specified, the shape is inferred automatically from the indices but you can specify it explicitly too.\"\n      ],\n      \"metadata\": {\n        \"id\": \"W1JJg-eZ7K3t\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[0, 0, 1],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"\\n\",\n        \"A1 = dglsp.spmatrix(i)\\n\",\n        \"print(f\\\"Implicit Shape: {A1.shape}\\\")\\n\",\n        \"print(A1.to_dense())\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"\\n\",\n        \"A2 = dglsp.spmatrix(i, shape=(3, 3))\\n\",\n        \"print(f\\\"Explicit Shape: {A2.shape}\\\")\\n\",\n        \"print(A2.to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"80NNSQfd7L5V\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"Both scalar values and vector values can be set for nnz elements in Sparse Matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"zdNgUf0ShfCe\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = 
torch.tensor([[1, 1, 2],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"# The length of the value should match the nnz elements represented by the\\n\",\n        \"# sparse matrix format.\\n\",\n        \"scalar_val = torch.tensor([1., 2., 3.])\\n\",\n        \"vector_val = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])\\n\",\n        \"\\n\",\n        \"print(\\\"-----Scalar Values-----\\\")\\n\",\n        \"A = dglsp.spmatrix(i, scalar_val)\\n\",\n        \"print(A)\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"print(\\\"In dense format:\\\")\\n\",\n        \"print(A.to_dense())\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"\\n\",\n        \"print(\\\"-----Vector Values-----\\\")\\n\",\n        \"A = dglsp.spmatrix(i, vector_val)\\n\",\n        \"print(A)\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"print(\\\"In dense format:\\\")\\n\",\n        \"print(A.to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"buE9ZkKvhp1f\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"*Duplicated indices*\"\n      ],\n      \"metadata\": {\n        \"id\": \"7ufTCDAVsrmP\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[0, 0, 0, 1],\\n\",\n        \"                  [0, 2, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3., 4])\\n\",\n        \"A = dglsp.spmatrix(i, val)\\n\",\n        \"print(A)\\n\",\n        \"print(f\\\"Whether A contains duplicate indices: {A.has_duplicate()}\\\")\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"\\n\",\n        \"B = A.coalesce()\\n\",\n        \"print(B)\\n\",\n        \"print(f\\\"Whether B contains duplicate indices: {B.has_duplicate()}\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"ilSAlFLOs0o8\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      
\"source\": [\n        \"**val_like**\\n\",\n        \"\\n\",\n        \"You can create a new sparse matrix by retaining the non-zero indices of a given sparse matrix but with different non-zero values.\"\n      ],\n      \"metadata\": {\n        \"id\": \"ZJ09qM5NaxuI\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3.])\\n\",\n        \"A = dglsp.spmatrix(i, val)\\n\",\n        \"\\n\",\n        \"new_val = torch.tensor([4., 5., 6.])\\n\",\n        \"B = dglsp.val_like(A, new_val)\\n\",\n        \"print(B)\"\n      ],\n      \"metadata\": {\n        \"id\": \"UB3lKJVBbsUD\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"**Create a sparse matrix from various sparse formats**\\n\",\n        \"\\n\",\n        \"*   `from_coo()`: Create a sparse matrix from [COO](https://en.wikipedia.org/wiki/Sparse_matrix#Coordinate_list_(COO)) format.\\n\",\n        \"*   `from_csr()`: Create a sparse matrix from [CSR](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)) format.\\n\",\n        \"*   `from_csc()`: Create a sparse matrix from [CSC](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_(CSC_or_CCS)) format.\"\n      ],\n      \"metadata\": {\n        \"id\": \"nWjBSFDBXDPJ\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"row = torch.tensor([0, 1, 2, 2, 2])\\n\",\n        \"col = torch.tensor([1, 2, 0, 1, 2])\\n\",\n        \"\\n\",\n        \"print(\\\"-----Create from COO format-----\\\")\\n\",\n        \"A = dglsp.from_coo(row, col)\\n\",\n        \"print(A)\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"print(\\\"In dense format:\\\")\\n\",\n        \"print(A.to_dense())\\n\",\n        
\"print(\\\"\\\")\\n\",\n        \"\\n\",\n        \"indptr = torch.tensor([0, 1, 2, 5])\\n\",\n        \"indices = torch.tensor([1, 2, 0, 1, 2])\\n\",\n        \"\\n\",\n        \"print(\\\"-----Create from CSR format-----\\\")\\n\",\n        \"A = dglsp.from_csr(indptr, indices)\\n\",\n        \"print(A)\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"print(\\\"In dense format:\\\")\\n\",\n        \"print(A.to_dense())\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"\\n\",\n        \"print(\\\"-----Create from CSC format-----\\\")\\n\",\n        \"B = dglsp.from_csc(indptr, indices)\\n\",\n        \"print(B)\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"print(\\\"In dense format:\\\")\\n\",\n        \"print(B.to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"3puXyMFsvdlj\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"### Attributes and methods of a DGL Sparse Matrix\"\n      ],\n      \"metadata\": {\n        \"id\": \"nd4hJ9ysd4St\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[0, 1, 1, 2],\\n\",\n        \"                  [1, 0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3., 4.])\\n\",\n        \"A = dglsp.spmatrix(i, val)\\n\",\n        \"\\n\",\n        \"print(f\\\"Shape of sparse matrix: {A.shape}\\\")\\n\",\n        \"print(f\\\"The number of nonzero elements of sparse matrix: {A.nnz}\\\")\\n\",\n        \"print(f\\\"Datatype of sparse matrix: {A.dtype}\\\")\\n\",\n        \"print(f\\\"Device sparse matrix is stored on: {A.device}\\\")\\n\",\n        \"print(f\\\"Get the values of the nonzero elements: {A.val}\\\")\\n\",\n        \"print(f\\\"Get the row indices of the nonzero elements: {A.row}\\\")\\n\",\n        \"print(f\\\"Get the column indices of the nonzero elements: {A.col}\\\")\\n\",\n        \"print(f\\\"Get the coordinate (COO) representation: 
{A.coo()}\\\")\\n\",\n        \"print(f\\\"Get the compressed sparse row (CSR) representation: {A.csr()}\\\")\\n\",\n        \"print(f\\\"Get the compressed sparse column (CSC) representation: {A.csc()}\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"OKbFiWKIzZVe\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"**dtype and/or device conversion**\"\n      ],\n      \"metadata\": {\n        \"id\": \"VzosM7i3yQPK\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[0, 1, 1, 2],\\n\",\n        \"                  [1, 0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3., 4.])\\n\",\n        \"A = dglsp.spmatrix(i, val)\\n\",\n        \"\\n\",\n        \"B = A.to(device='cpu', dtype=torch.int32)\\n\",\n        \"print(f\\\"Device sparse matrix is stored on: {B.device}\\\")\\n\",\n        \"print(f\\\"Datatype of sparse matrix: {B.dtype}\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"y_RJihw-ypXp\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"Similar to pytorch, we also provide various fine-grained APIs ([Doc](https://docs.dgl.ai/en/latest/api/python/dgl.sparse_v0.html)) for dtype and/or device conversion.\"\n      ],\n      \"metadata\": {\n        \"id\": \"U26arLlJzfkN\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Diagonal Matrix\\n\",\n        \"\\n\",\n        \"Diagonal Matrix is a special type of Sparse Matrix, in which the entries outside the main diagonal are all zero.\\n\",\n        \"\\n\",\n        \"\\n\"\n      ],\n      \"metadata\": {\n        \"id\": \"EFe9ABRuWHqf\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"### Initializing a DGL Diagonal Sparse Matrix\\n\",\n        
\"A DGL Diagonal Sparse Matrix can be initiate by `dglsp.diag()`.\\n\",\n        \"\\n\",\n        \"Identity Matrix is a special type of Diagonal Sparse Matrix, in which all the value on the diagonal are 1.0. Use `dglsp.identity()` to initiate a Diagonal Sparse Matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"1CeCoE2Fgl_x\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"val = torch.tensor([1., 2., 3., 4.])\\n\",\n        \"D = dglsp.diag(val)\\n\",\n        \"print(D)\\n\",\n        \"\\n\",\n        \"I = dglsp.identity(shape=(3, 3))\\n\",\n        \"print(I)\"\n      ],\n      \"metadata\": {\n        \"id\": \"9wzJNApahXAR\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Operations on Sparse Matrix\\n\",\n        \"*   Elementwise operations\\n\",\n        \"    *   `A + B`\\n\",\n        \"    *   `A - B`\\n\",\n        \"    *   `A * B`\\n\",\n        \"    *   `A / B`\\n\",\n        \"    *   `A ** scalar`\\n\",\n        \"*   Broadcast operations\\n\",\n        \"    *   `sp_<op>_v()`\\n\",\n        \"*   Reduce operations\\n\",\n        \"    *   `reduce()`\\n\",\n        \"    *   `sum()`\\n\",\n        \"    *   `smax()`\\n\",\n        \"    *   `smin()`\\n\",\n        \"    *   `smean()`\\n\",\n        \"*   Matrix transformations\\n\",\n        \"    *   `SparseMatrix.transpose()` or `SparseMatrix.T`\\n\",\n        \"    *   `SparseMatrix.neg()`\\n\",\n        \"    *   `SparseMatrix.inv()`\\n\",\n        \"*   Matrix multiplication\\n\",\n        \"    *   `matmul()`\\n\",\n        \"    *   `sddmm()`\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"*We are using dense format to print sparse matrix in this tutorial since it is more intuitive to read.*\"\n      ],\n      \"metadata\": {\n        \"id\": \"Tjsapqp6zSFR\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      
\"source\": [\n        \"### *Elementwise operations*\"\n      ],\n      \"metadata\": {\n        \"id\": \"psvGwcIqYvC2\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"**add(A, B), equivalent to A + B**\\n\",\n        \"\\n\",\n        \"Element-wise addition on two sparse matrices, returning a sparse matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"39YJitpW-K9v\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3.])\\n\",\n        \"A1 = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"print(\\\"A1:\\\")\\n\",\n        \"print(A1.to_dense())\\n\",\n        \"\\n\",\n        \"i = torch.tensor([[0, 1, 2],\\n\",\n        \"                  [0, 2, 1]])\\n\",\n        \"val = torch.tensor([4., 5., 6.])\\n\",\n        \"A2 = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"print(\\\"A2:\\\")\\n\",\n        \"print(A2.to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-1., -2., -3.])\\n\",\n        \"D1 = dglsp.diag(val)\\n\",\n        \"print(\\\"D1:\\\")\\n\",\n        \"print(D1.to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-4., -5., -6.])\\n\",\n        \"D2 = dglsp.diag(val)\\n\",\n        \"print(\\\"D2:\\\")\\n\",\n        \"print(D2.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"A1 + A2:\\\")\\n\",\n        \"print((A1 + A2).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"A1 + D1:\\\")\\n\",\n        \"print((A1 + D1).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"D1 + D2:\\\")\\n\",\n        \"print((D1 + D2).to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"pj3Ckx41-BSu\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"**sub(A, B), equivalent to A - 
B**\\n\",\n        \"\\n\",\n        \"Element-wise subtraction on two sparse matrices, returning a sparse matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"i25N0JHUTUX9\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3.])\\n\",\n        \"A1 = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"print(\\\"A1:\\\")\\n\",\n        \"print(A1.to_dense())\\n\",\n        \"\\n\",\n        \"i = torch.tensor([[0, 1, 2],\\n\",\n        \"                  [0, 2, 1]])\\n\",\n        \"val = torch.tensor([4., 5., 6.])\\n\",\n        \"A2 = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"print(\\\"A2:\\\")\\n\",\n        \"print(A2.to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-1., -2., -3.])\\n\",\n        \"D1 = dglsp.diag(val)\\n\",\n        \"print(\\\"D1:\\\")\\n\",\n        \"print(D1.to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-4., -5., -6.])\\n\",\n        \"D2 = dglsp.diag(val)\\n\",\n        \"print(\\\"D2:\\\")\\n\",\n        \"print(D2.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"A1 - A2:\\\")\\n\",\n        \"print((A1 - A2).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"A1 - D1:\\\")\\n\",\n        \"print((A1 - D1).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"D1 - A1:\\\")\\n\",\n        \"print((D1 - A1).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"D1 - D2:\\\")\\n\",\n        \"print((D1 - D2).to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"GMxfz-cyT129\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"**mul(A, B), equivalent to A * B**\\n\",\n        \"\\n\",\n        \"Element-wise multiplication on two sparse matrices or on a sparse matrix and a scalar, returning a 
sparse matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"bg45jnq8T9EJ\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3.])\\n\",\n        \"A1 = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"print(\\\"A1:\\\")\\n\",\n        \"print(A1.to_dense())\\n\",\n        \"\\n\",\n        \"i = torch.tensor([[0, 1, 2, 2],\\n\",\n        \"                  [0, 2, 0, 1]])\\n\",\n        \"val = torch.tensor([1., 2., 3., 4.])\\n\",\n        \"A2 = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"\\n\",\n        \"print(\\\"A2:\\\")\\n\",\n        \"print(A2.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"A1 * 3:\\\")\\n\",\n        \"print((A1 * 3).to_dense())\\n\",\n        \"print(\\\"3 * A1:\\\")\\n\",\n        \"print((3 * A1).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"A1 * A2\\\")\\n\",\n        \"print((A1 * A2).to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-1., -2., -3.])\\n\",\n        \"D1 = dglsp.diag(val)\\n\",\n        \"print(\\\"D1:\\\")\\n\",\n        \"print(D1.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"D1 * A2\\\")\\n\",\n        \"print((D1 * A2).to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-4., -5., -6.])\\n\",\n        \"D2 = dglsp.diag(val)\\n\",\n        \"print(\\\"D2:\\\")\\n\",\n        \"print(D2.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"D1 * -2:\\\")\\n\",\n        \"print((D1 * -2).to_dense())\\n\",\n        \"print(\\\"-2 * D1:\\\")\\n\",\n        \"print((-2 * D1).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"D1 * D2:\\\")\\n\",\n        \"print((D1 * D2).to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"4PAITJqHUB8J\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      
\"source\": [\n        \"**div(A, B), equivalent to A / B**\\n\",\n        \"\\n\",\n        \"Element-wise multiplication on two sparse matrices or on a sparse matrix and a scalar, returning a sparse matrix. If both `A` and `B` are sparse matrices, both of them must have the same sparsity. And the returned matrix has the same order of non-zero entries as `A`.\"\n      ],\n      \"metadata\": {\n        \"id\": \"Xb2RU6H4UBCs\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3.])\\n\",\n        \"A1 = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"print(\\\"A1:\\\")\\n\",\n        \"print(A1.to_dense())\\n\",\n        \"\\n\",\n        \"i = torch.tensor([[1, 2, 1],\\n\",\n        \"                  [0, 0, 2]])\\n\",\n        \"val = torch.tensor([1., 3., 2.])\\n\",\n        \"A2 = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"\\n\",\n        \"print(\\\"A1 / 2:\\\")\\n\",\n        \"print((A1 / 2).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"A1 / A2\\\")\\n\",\n        \"print((A1 / A2).to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-1., -2., -3.])\\n\",\n        \"D1 = dglsp.diag(val)\\n\",\n        \"print(\\\"D1:\\\")\\n\",\n        \"print(D1.to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-4., -5., -6.])\\n\",\n        \"D2 = dglsp.diag(val)\\n\",\n        \"print(\\\"D2:\\\")\\n\",\n        \"print(D2.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"D1 / D2:\\\")\\n\",\n        \"print((D1 / D2).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"D1 / 2:\\\")\\n\",\n        \"print((D1 / 2).to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"TFB_UcmEUdr3\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        
\"**power(A, B), equivalent to A \\\\*\\\\* B**\\n\",\n        \"\\n\",\n        \"Element-wise power of a sparse matrix and a scalar, returning a sparse matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"2lZbyTYUUgSi\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3.])\\n\",\n        \"A = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"print(\\\"A:\\\")\\n\",\n        \"print(A.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"A ** 3:\\\")\\n\",\n        \"print((A ** 3).to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-1., -2., -3.])\\n\",\n        \"D = dglsp.diag(val)\\n\",\n        \"print(\\\"D:\\\")\\n\",\n        \"print(D.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"D1 ** 2:\\\")\\n\",\n        \"print((D1 ** 2).to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"ox-XxCnuUqAy\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"### *Broadcast operations*\"\n      ],\n      \"metadata\": {\n        \"id\": \"VXBz4j5x_wQ4\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"**sp_\\\\<op\\\\>_v(A, v)**\\n\",\n        \"\\n\",\n        \"Broadcast operations on a sparse matrix and a vector, returning a sparse matrix. `v` is broadcasted to the shape of `A` and then the operator is applied on the non-zero values of `A`. `<op>` can be add, sub, mul, and div. \\n\",\n        \"\\n\",\n        \"There are two cases regarding the shape of `v`:\\n\",\n        \"\\n\",\n        \"1. `v` is a vector of shape `(1, A.shape[1])` or `(A.shape[1])`. In this case, `v` is broadcasted on the row dimension of `A`.\\n\",\n        \"\\n\",\n        \"2. `v` is a vector of shape `(A.shape[0], 1)`. 
In this case, `v` is broadcasted on the column dimension of `A`.\"\n      ],\n      \"metadata\": {\n        \"id\": \"PtnyZdXHAZ6Z\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 0, 2], [0, 3, 2]])\\n\",\n        \"val = torch.tensor([10, 20, 30])\\n\",\n        \"A = dglsp.spmatrix(i, val, shape=(3, 4))\\n\",\n        \"\\n\",\n        \"v1 = torch.tensor([1, 2, 3, 4])\\n\",\n        \"print(\\\"A:\\\")\\n\",\n        \"print(A.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"v1:\\\")\\n\",\n        \"print(v1)\\n\",\n        \"\\n\",\n        \"print(\\\"sp_add_v(A, v1)\\\")\\n\",\n        \"print(dglsp.sp_add_v(A, v1).to_dense())\\n\",\n        \"\\n\",\n        \"v2 = v1.reshape(1, -1)\\n\",\n        \"print(\\\"v2:\\\")\\n\",\n        \"print(v2)\\n\",\n        \"\\n\",\n        \"print(\\\"sp_add_v(A, v2)\\\")\\n\",\n        \"print(dglsp.sp_add_v(A, v2).to_dense())\\n\",\n        \"\\n\",\n        \"v3 = torch.tensor([1, 2, 3]).reshape(-1, 1)\\n\",\n        \"print(\\\"v3:\\\")\\n\",\n        \"print(v3)\\n\",\n        \"\\n\",\n        \"print(\\\"sp_add_v(A, v3)\\\")\\n\",\n        \"print(dglsp.sp_add_v(A, v3).to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"xxf3s-uWBRR7\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"### *Reduce operations*\\n\",\n        \"\\n\",\n        \"All DGL sparse reduce operations only consider non-zero elements. 
To distinguish them from dense PyTorch reduce operations that consider zero elements, we use name `smax`, `smin` and `smean` (`s` stands for sparse).\"\n      ],\n      \"metadata\": {\n        \"id\": \"TQJJlctZjYPv\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[0, 1, 1, 2],\\n\",\n        \"                  [1, 0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3., 4.])\\n\",\n        \"A = dglsp.spmatrix(i, val)\\n\",\n        \"print(A.T.to_dense())\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"\\n\",\n        \"# O1, O2 will have the same value.\\n\",\n        \"O1 = A.reduce(0, 'sum')\\n\",\n        \"O2 = A.sum(0)\\n\",\n        \"print(\\\"Reduce with reducer:sum along dim = 0:\\\")\\n\",\n        \"print(O1)\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"\\n\",\n        \"# O3, O4 will have the same value.\\n\",\n        \"O3 = A.reduce(0, 'smax')\\n\",\n        \"O4 = A.smax(0)\\n\",\n        \"print(\\\"Reduce with reducer:max along dim = 0:\\\")\\n\",\n        \"print(O3)\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"\\n\",\n        \"# O5, O6 will have the same value.\\n\",\n        \"O5 = A.reduce(0, 'smin')\\n\",\n        \"O6 = A.smin(0)\\n\",\n        \"print(\\\"Reduce with reducer:min along dim = 0:\\\")\\n\",\n        \"print(O5)\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"\\n\",\n        \"# O7, O8 will have the same value.\\n\",\n        \"O7 = A.reduce(0, 'smean')\\n\",\n        \"O8 = A.smean(0)\\n\",\n        \"print(\\\"Reduce with reducer:smean along dim = 0:\\\")\\n\",\n        \"print(O7)\\n\",\n        \"print(\\\"\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"GhS49Js1jW4b\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"### *Matrix transformations*\"\n      ],\n      \"metadata\": {\n        \"id\": \"kanwnB7LOQui\"\n      }\n    },\n    
{\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"*Sparse Matrix*\"\n      ],\n      \"metadata\": {\n        \"id\": \"NiiXso9elM2p\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[0, 1, 1, 2],\\n\",\n        \"                  [1, 0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3., 4.])\\n\",\n        \"A = dglsp.spmatrix(i, val)\\n\",\n        \"print(A.to_dense())\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"\\n\",\n        \"print(\\\"Get transpose of sparse matrix.\\\")\\n\",\n        \"print(A.T.to_dense())\\n\",\n        \"# Alias\\n\",\n        \"# A.transpose()\\n\",\n        \"# A.t()\\n\",\n        \"print(\\\"\\\")\\n\",\n        \"\\n\",\n        \"print(\\\"Get a sparse matrix with the negation of the original nonzero values.\\\")\\n\",\n        \"print(A.neg().to_dense())\\n\",\n        \"print(\\\"\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"qJcmZHmf-oTY\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"### *Matrix multiplication*\"\n      ],\n      \"metadata\": {\n        \"id\": \"4uQlDFb0Uzto\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"**matmul(A, B), equivalent to A @ B**\\n\",\n        \"\\n\",\n        \"Matrix multiplication on sparse matrices and/or dense matrix. 
There are two cases as follows.\"\n      ],\n      \"metadata\": {\n        \"id\": \"THWE30v6WpAk\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"**SparseMatrix @ SparseMatrix -> SparseMatrix:**\\n\",\n        \"\\n\",\n        \"For a $L \\\\times M$ sparse matrix A and a $M \\\\times N$ sparse matrix B, the shape of `A @ B` will be $L \\\\times N$ sparse matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"VxyykR-vX7lF\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3.])\\n\",\n        \"A1 = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"print(\\\"A1:\\\")\\n\",\n        \"print(A1.to_dense())\\n\",\n        \"\\n\",\n        \"i = torch.tensor([[0, 1, 2],\\n\",\n        \"                  [0, 2, 1]])\\n\",\n        \"val = torch.tensor([4., 5., 6.])\\n\",\n        \"A2 = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"print(\\\"A2:\\\")\\n\",\n        \"print(A2.to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-1., -2., -3.])\\n\",\n        \"D1 = dglsp.diag(val)\\n\",\n        \"print(\\\"D1:\\\")\\n\",\n        \"print(D1.to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-4., -5., -6.])\\n\",\n        \"D2 = dglsp.diag(val)\\n\",\n        \"print(\\\"D2:\\\")\\n\",\n        \"print(D2.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"A1 @ A2:\\\")\\n\",\n        \"print((A1 @ A2).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"A1 @ D1:\\\")\\n\",\n        \"print((A1 @ D1).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"D1 @ A1:\\\")\\n\",\n        \"print((D1 @ A1).to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"D1 @ D2:\\\")\\n\",\n        \"print((D1 @ D2).to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"XRDFC2rOYQM4\"\n      },\n      
\"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"**SparseMatrix @ Tensor -> Tensor:**\\n\",\n        \"\\n\",\n        \"For a $L \\\\times M$ sparse matrix A and a $M \\\\times N$ dense matrix B, the shape of `A @ B` will be $L \\\\times N$ dense matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"g13fG8nvaVOt\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3.])\\n\",\n        \"A = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"print(\\\"A:\\\")\\n\",\n        \"print(A.to_dense())\\n\",\n        \"\\n\",\n        \"val = torch.tensor([-1., -2., -3.])\\n\",\n        \"D = dglsp.diag(val)\\n\",\n        \"print(\\\"D:\\\")\\n\",\n        \"print(D.to_dense())\\n\",\n        \"\\n\",\n        \"X = torch.tensor([[11., 22.], [33., 44.], [55., 66.]])\\n\",\n        \"print(\\\"X:\\\")\\n\",\n        \"print(X)\\n\",\n        \"\\n\",\n        \"print(\\\"A @ X:\\\")\\n\",\n        \"print(A @ X)\\n\",\n        \"\\n\",\n        \"print(\\\"D @ X:\\\")\\n\",\n        \"print(D @ X)\"\n      ],\n      \"metadata\": {\n        \"id\": \"FcQ-CnqdlgWF\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"This operator also supports batched sparse-dense matrix multiplication. The sparse matrix A should have shape $L \\\\times M$, where the non-zero values are vectors of length $K$. The dense matrix B should have shape $M \\\\times N \\\\times K$. 
The output is a dense matrix of shape $L \\\\times N \\\\times K$.\"\n      ],\n      \"metadata\": {\n        \"id\": \"_KZiULLbmEZE\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [0, 2, 0]])\\n\",\n        \"val = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])\\n\",\n        \"A = dglsp.spmatrix(i, val, shape=(3, 3))\\n\",\n        \"print(\\\"A:\\\")\\n\",\n        \"print(A.to_dense())\\n\",\n        \"\\n\",\n        \"X = torch.tensor([[[1., 1.], [1., 2.]],\\n\",\n        \"                  [[1., 3.], [1., 4.]],\\n\",\n        \"                  [[1., 5.], [1., 6.]]])\\n\",\n        \"print(\\\"X:\\\")\\n\",\n        \"print(X)\\n\",\n        \"\\n\",\n        \"print(\\\"A @ X:\\\")\\n\",\n        \"print(A @ X)\"\n      ],\n      \"metadata\": {\n        \"id\": \"ZUzXQk7Ab2wG\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"**Sampled-Dense-Dense Matrix Multiplication (SDDMM)**\\n\",\n        \"\\n\",\n        \"``sddmm`` matrix-multiplies two dense matrices X1 and X2, then elementwise-multiplies the result with sparse matrix A at the nonzero locations. 
This is designed for sparse matrix with scalar values.\\n\",\n        \"\\n\",\n        \"$$out = (X_1 @ X_2) * A$$\\n\",\n        \"\\n\",\n        \"For a $L \\\\times N$ sparse matrix A, a $L \\\\times M$ dense matrix X1 and a $M \\\\times N$ dense matrix X2, `sddmm(A, X1, X2)` will be a $L \\\\times N$ sparse matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"qO_8f_vhPKtf\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [2, 3, 3]])\\n\",\n        \"val = torch.tensor([1., 2., 3.])\\n\",\n        \"A = dglsp.spmatrix(i, val, (3, 4))\\n\",\n        \"print(\\\"A:\\\")\\n\",\n        \"print(A.to_dense())\\n\",\n        \"\\n\",\n        \"X1 = torch.randn(3, 5)\\n\",\n        \"X2 = torch.randn(5, 4)\\n\",\n        \"print(\\\"X1:\\\")\\n\",\n        \"print(X1)\\n\",\n        \"print(\\\"X2:\\\")\\n\",\n        \"print(X2)\\n\",\n        \"\\n\",\n        \"O = dglsp.sddmm(A, X1, X2)\\n\",\n        \"print(\\\"dglsp.sddmm(A, X1, X2):\\\")\\n\",\n        \"print(O.to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"3ZIFV0TgPhwH\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"This operator also supports batched sampled-dense-dense matrix multiplication. 
For a $L \\times N$ sparse matrix A with non-zero vector values of length $K$, a $L \\times M \\times K$ dense matrix X1 and a $M \\times N \\times K$ dense matrix X2, `sddmm(A, X1, X2)` will be a $L \\times N \\times K$ sparse matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"RmNmXU_ZqyF7\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[1, 1, 2],\\n\",\n        \"                  [2, 3, 3]])\\n\",\n        \"val = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])\\n\",\n        \"A = dglsp.spmatrix(i, val, (3, 4))\\n\",\n        \"print(\\\"A:\\\")\\n\",\n        \"print(A.to_dense())\\n\",\n        \"\\n\",\n        \"X1 = torch.randn(3, 5, 2)\\n\",\n        \"X2 = torch.randn(5, 4, 2)\\n\",\n        \"print(\\\"X1:\\\")\\n\",\n        \"print(X1)\\n\",\n        \"print(\\\"X2:\\\")\\n\",\n        \"print(X2)\\n\",\n        \"\\n\",\n        \"O = dglsp.sddmm(A, X1, X2)\\n\",\n        \"print(\\\"dglsp.sddmm(A, X1, X2):\\\")\\n\",\n        \"print(O.to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"DuSAjamyrIO_\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Non-linear activation functions\"\n      ],\n      \"metadata\": {\n        \"id\": \"fVkbTT28ZzPr\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"### Element-wise functions\\n\",\n        \"\\n\",\n        \"Most activation functions are element-wise and can be further grouped into two categories:\\n\",\n        \"\\n\",\n        \"**Sparse-preserving functions** such as `sin()`, `tanh()`, `sigmoid()`, `relu()`, etc. 
You can directly apply them on the `val` tensor of the sparse matrix and then recreate a new matrix of the same sparsity using `val_like`.\"\n      ],\n      \"metadata\": {\n        \"id\": \"XuaNdFO7XG2r\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[0, 1, 1, 2],\\n\",\n        \"                  [1, 0, 2, 0]])\\n\",\n        \"val = torch.randn(4)\\n\",\n        \"A = dglsp.spmatrix(i, val)\\n\",\n        \"print(A.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"Apply tanh.\\\")\\n\",\n        \"A_new = dglsp.val_like(A, torch.tanh(A.val))\\n\",\n        \"print(A_new.to_dense())\"\n      ],\n      \"metadata\": {\n        \"id\": \"GZkCJJ0TX0cI\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"**Non-sparse-preserving functions** such as `exp()`, `cos()`, etc. You can first convert the sparse matrix to dense before applying the functions.\"\n      ],\n      \"metadata\": {\n        \"id\": \"i92lhMEnYas3\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[0, 1, 1, 2],\\n\",\n        \"                  [1, 0, 2, 0]])\\n\",\n        \"val = torch.randn(4)\\n\",\n        \"A = dglsp.spmatrix(i, val)\\n\",\n        \"print(A.to_dense())\\n\",\n        \"\\n\",\n        \"print(\\\"Apply exp.\\\")\\n\",\n        \"A_new = A.to_dense().exp()\\n\",\n        \"print(A_new)\"\n      ],\n      \"metadata\": {\n        \"id\": \"sroJpzRNYZq5\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"### Softmax\\n\",\n        \"\\n\",\n        \"Apply row-wise softmax to the nonzero entries of the sparse matrix.\"\n      ],\n      \"metadata\": {\n        \"id\": \"y8OQZReVXpo3\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": 
[\n        \"i = torch.tensor([[0, 1, 1, 2],\\n\",\n        \"                  [1, 0, 2, 0]])\\n\",\n        \"val = torch.tensor([1., 2., 3., 4.])\\n\",\n        \"A = dglsp.spmatrix(i, val)\\n\",\n        \"\\n\",\n        \"print(A.softmax())\\n\",\n        \"print(\\\"In dense format:\\\")\\n\",\n        \"print(A.softmax().to_dense())\\n\",\n        \"print(\\\"\\\\n\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"CQaKgzCJULjt\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Exercise \\\\#1\\n\",\n        \"\\n\",\n        \"*Let's test what you've learned. Feel free to [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/sparse/quickstart.ipynb).*\\n\",\n        \"\\n\",\n        \"Given a sparse symmetrical adjacency matrix $A$, calculate its symmetrically normalized adjacency matrix: $$norm = \\\\bar{D}^{-\\\\frac{1}{2}}\\\\bar{A}\\\\bar{D}^{-\\\\frac{1}{2}}$$\\n\",\n        \"\\n\",\n        \"Where $\\\\bar{A} = A + I$, $I$ is the identity matrix, and $\\\\bar{D}$ is the diagonal node degree matrix of $\\\\bar{A}$.\"\n      ],\n      \"metadata\": {\n        \"id\": \"1iBNlJVYz3zi\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"i = torch.tensor([[0, 0, 1, 1, 2, 2, 3],\\n\",\n        \"                  [1, 3, 2, 5, 3, 5, 4]])\\n\",\n        \"asym_A = dglsp.spmatrix(i, shape=(6, 6))\\n\",\n        \"# Step 1: create symmetrical adjacency matrix A from asym_A.\\n\",\n        \"# A =\\n\",\n        \"\\n\",\n        \"# Step 2: calculate A_hat from A.\\n\",\n        \"# A_hat =\\n\",\n        \"\\n\",\n        \"# Step 3: diagonal node degree matrix of A_hat\\n\",\n        \"# D_hat =\\n\",\n        \"\\n\",\n        \"# Step 4: calculate the norm from D_hat and A_hat.\\n\",\n        \"# norm = \"\n      
],\n      \"metadata\": {\n        \"id\": \"0dDhfbJo0ByV\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Exercise \\\\#2\\n\",\n        \"\\n\",\n        \"Let's implement a simplified version of the Graph Attention Network (GAT) layer.\\n\",\n        \"\\n\",\n        \"A GAT layer has two inputs: the adjacency matrix $A$ and the node input features $X$.  The idea of the GAT layer is to update each node's representation with a weighted average of the node's own representation and its neighbors' representations.  In particular, when computing the output for node $i$, the GAT layer does the following:\\n\",\n        \"1. Compute the scores $S_{ij}$ representing the attention logit from neighbor $j$ to node $i$.  $S_{ij}$ is a function of $i$ and $j$'s input features $X_i$ and $X_j$: $$S_{ij} = LeakyReLU(X_i^\\\\top v_1 + X_j^\\\\top v_2)$$, where $v_1$ and $v_2$ are trainable vectors.\\n\",\n        \"2. Compute a softmax attention $R_{ij} = \\\\exp S_{ij} / \\\\left( \\\\sum_{j' \\\\in \\\\mathcal{N}_i} \\\\exp S_{ij'} \\\\right)$, where $\\\\mathcal{N}_i$ means the neighbors of $i$.  This means that $R$ is a row-wise softmax attention of $S$.\\n\",\n        \"3. Compute the weighted average $H_i = \\\\sum_{j' : j' \\\\in \\\\mathcal{N}_i} R_{ij'} X_{j'} W$, where $W$ is a trainable matrix.\\n\",\n        \"\\n\",\n        \"The following code defines all the parameters you need but only completes step 1.  
Could you implement step 2 and step 3?\"\n      ],\n      \"metadata\": {\n        \"id\": \"yfEVQBUuI-cE\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"import torch.nn as nn\\n\",\n        \"import torch.nn.functional as F\\n\",\n        \"\\n\",\n        \"class SimplifiedGAT(nn.Module):\\n\",\n        \"    def __init__(self, in_size, out_size):\\n\",\n        \"        super().__init__()\\n\",\n        \"\\n\",\n        \"        self.W = nn.Parameter(torch.randn(in_size, out_size))\\n\",\n        \"        self.v1 = nn.Parameter(torch.randn(in_size))\\n\",\n        \"        self.v2 = nn.Parameter(torch.randn(in_size))\\n\",\n        \"\\n\",\n        \"    def forward(self, A, X):\\n\",\n        \"        # A: A sparse matrix with size (N, N).  A[i, j] represent the edge from j to i.\\n\",\n        \"        # X: A dense matrix with size (N, D)\\n\",\n        \"        # Step 1: compute S[i, j]\\n\",\n        \"        Xv1 = X @ self.v1\\n\",\n        \"        Xv2 = X @ self.v2\\n\",\n        \"        s = F.leaky_relu(Xv1[A.col] + Xv2[A.row])\\n\",\n        \"        S = dglsp.val_like(A, s)\\n\",\n        \"\\n\",\n        \"        # Step 2: compute R[i, j] which is the row-wise attention of $S$.\\n\",\n        \"        # EXERCISE: replace the statement below.\\n\",\n        \"        R = S\\n\",\n        \"\\n\",\n        \"        # Step 3: compute H.\\n\",\n        \"        # EXERCISE: replace the statement below.\\n\",\n        \"        H = X\\n\",\n        \"\\n\",\n        \"        return H\"\n      ],\n      \"metadata\": {\n        \"id\": \"pYrgSxq6La5c\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# Test:\\n\",\n        \"# Let's use the symmetric A created above.\\n\",\n        \"X = torch.randn(6, 20)\\n\",\n        \"module = SimplifiedGAT(20, 10)\\n\",\n        \"Y = module(A, X)\"\n      ],\n   
   \"metadata\": {\n        \"id\": \"qjcXiidYCqGK\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    }\n  ]\n}\n"
  },
  {
    "path": "notebooks/stochastic_training/link_prediction.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Ow8CQmZIV8Yn\"\n      },\n      \"source\": [\n        \"# Link Prediction\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/stochastic_training/link_prediction.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/stochastic_training/link_prediction.ipynb)\\n\",\n        \"\\n\",\n        \"This tutorial will show how to train a multi-layer GraphSAGE for link\\n\",\n        \"prediction on [CoraGraphDataset](https://data.dgl.ai/dataset/cora_v2.zip).\\n\",\n        \"The dataset contains 2708 nodes and 10556 edges.\\n\",\n        \"\\n\",\n        \"By the end of this tutorial, you will be able to\\n\",\n        \"\\n\",\n        \"-  Train a GNN model for link prediction on target device with DGL's\\n\",\n        \"   neighbor sampling components.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"onVijYWpWlMj\"\n      },\n      \"source\": [\n        \"## Install DGL package\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"QcpjTazg6hEo\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Install required packages.\\n\",\n        \"import os\\n\",\n        \"import torch\\n\",\n        \"os.environ['TORCH'] = torch.__version__\\n\",\n        \"os.environ['DGLBACKEND'] = \\\"pytorch\\\"\\n\",\n        \"\\n\",\n        \"# Install the CPU version in default. 
If you want to install CUDA version,\\n\",\n        \"# please refer to https://www.dgl.ai/pages/start.html and change runtime type\\n\",\n        \"# accordingly.\\n\",\n        \"device = torch.device(\\\"cpu\\\")\\n\",\n        \"!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\\n\",\n        \"\\n\",\n        \"try:\\n\",\n        \"    import dgl\\n\",\n        \"    import dgl.graphbolt as gb\\n\",\n        \"    installed = True\\n\",\n        \"except ImportError as error:\\n\",\n        \"    installed = False\\n\",\n        \"    print(error)\\n\",\n        \"print(\\\"DGL installed!\\\" if installed else \\\"DGL not found!\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"OOKZxxT7W1Rz\"\n      },\n      \"source\": [\n        \"## Loading Dataset\\n\",\n        \"`cora` is already prepared as `BuiltinDataset` in **GraphBolt**.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"RnJkkSKhWiUG\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"dataset = gb.BuiltinDataset(\\\"cora-seeds\\\").load()\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"WxnTMEQXXKsM\"\n      },\n      \"source\": [\n        \"Dataset consists of graph, feature and tasks. You can get the training-validation-test set from the tasks. Seed nodes and corresponding labels are already stored in each training-validation-test set. This dataset contains 2 tasks, one for node classification and the other for link prediction. 
We will use the link prediction task.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"YCm8CGkOX9lK\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"graph = dataset.graph.to(device)\\n\",\n        \"feature = dataset.feature.to(device)\\n\",\n        \"train_set = dataset.tasks[1].train_set\\n\",\n        \"test_set = dataset.tasks[1].test_set\\n\",\n        \"task_name = dataset.tasks[1].metadata[\\\"name\\\"]\\n\",\n        \"print(f\\\"Task: {task_name}.\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"2y-P5omQYP00\"\n      },\n      \"source\": [\n        \"## Defining Neighbor Sampler and Data Loader in DGL\\n\",\n        \"Different from the link prediction tutorial for full graph, a common practice to train GNN on large graphs is to iterate over the edges in minibatches, since computing the probability of all edges is usually impossible. For each minibatch of edges, you compute the output representation of their incident nodes using neighbor sampling and GNN, in a similar fashion introduced in the node classification tutorial.\\n\",\n        \"\\n\",\n        \"To perform link prediction, you need to specify a negative sampler. DGL provides builtin negative samplers such as `dgl.graphbolt.UniformNegativeSampler`. 
Here this tutorial uniformly draws 5 negative examples per positive example.\\n\",\n        \"\\n\",\n        \"Except for the negative sampler, the rest of the code is identical to the node classification tutorial.\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"LZgXGfBvYijJ\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from functools import partial\\n\",\n        \"datapipe = gb.ItemSampler(train_set, batch_size=256, shuffle=True)\\n\",\n        \"datapipe = datapipe.copy_to(device)\\n\",\n        \"datapipe = datapipe.sample_uniform_negative(graph, 5)\\n\",\n        \"datapipe = datapipe.sample_neighbor(graph, [5, 5])\\n\",\n        \"datapipe = datapipe.transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))\\n\",\n        \"datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\\\"feat\\\"])\\n\",\n        \"train_dataloader = gb.DataLoader(datapipe)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"5sU_aulqYkwK\"\n      },\n      \"source\": [\n        \"You can peek one minibatch from train_dataloader and see what it will give you.\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"euEdzmerYmZi\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"data = next(iter(train_dataloader))\\n\",\n        \"print(f\\\"MiniBatch: {data}\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"WYQqfrDWYtU0\"\n      },\n      \"source\": [\n        \"## Defining Model for Node Representation\\n\",\n        \"Let’s consider training a 2-layer GraphSAGE with neighbor sampling. 
The model can be written as follows:\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"0qQbBwO7Y3-Q\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import dgl.nn as dglnn\\n\",\n        \"import torch.nn as nn\\n\",\n        \"import torch.nn.functional as F\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"class SAGE(nn.Module):\\n\",\n        \"    def __init__(self, in_size, hidden_size):\\n\",\n        \"        super().__init__()\\n\",\n        \"        self.layers = nn.ModuleList()\\n\",\n        \"        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, \\\"mean\\\"))\\n\",\n        \"        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, \\\"mean\\\"))\\n\",\n        \"        self.hidden_size = hidden_size\\n\",\n        \"        self.predictor = nn.Sequential(\\n\",\n        \"            nn.Linear(hidden_size, hidden_size),\\n\",\n        \"            nn.ReLU(),\\n\",\n        \"            nn.Linear(hidden_size, 1),\\n\",\n        \"        )\\n\",\n        \"\\n\",\n        \"    def forward(self, blocks, x):\\n\",\n        \"        hidden_x = x\\n\",\n        \"        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\\n\",\n        \"            hidden_x = layer(block, hidden_x)\\n\",\n        \"            is_last_layer = layer_idx == len(self.layers) - 1\\n\",\n        \"            if not is_last_layer:\\n\",\n        \"                hidden_x = F.relu(hidden_x)\\n\",\n        \"        return hidden_x\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"y23JppwHY5MC\"\n      },\n      \"source\": [\n        \"## Defining Training Loop\\n\",\n        \"The following initializes the model and defines the optimizer.\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": 
{\n        \"id\": \"omSIB_ePZACg\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"in_size = feature.size(\\\"node\\\", None, \\\"feat\\\")[0]\\n\",\n        \"model = SAGE(in_size, 128).to(device)\\n\",\n        \"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"QyWtzNZcZRgp\"\n      },\n      \"source\": [\n        \"The following is the training loop for link prediction and evaluation.\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"SccLVrjSZSkd\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from tqdm.auto import tqdm\\n\",\n        \"for epoch in range(3):\\n\",\n        \"    model.train()\\n\",\n        \"    total_loss = 0\\n\",\n        \"    for step, data in tqdm(enumerate(train_dataloader)):\\n\",\n        \"        # Get node pairs with labels for loss calculation.\\n\",\n        \"        compacted_seeds = data.compacted_seeds.T\\n\",\n        \"        labels = data.labels\\n\",\n        \"        node_feature = data.node_features[\\\"feat\\\"]\\n\",\n        \"        # Convert sampled subgraphs to DGL blocks.\\n\",\n        \"        blocks = data.blocks\\n\",\n        \"\\n\",\n        \"        # Get the embeddings of the input nodes.\\n\",\n        \"        y = model(blocks, node_feature)\\n\",\n        \"        logits = model.predictor(\\n\",\n        \"            y[compacted_seeds[0]] * y[compacted_seeds[1]]\\n\",\n        \"        ).squeeze()\\n\",\n        \"\\n\",\n        \"        # Compute loss.\\n\",\n        \"        loss = F.binary_cross_entropy_with_logits(logits, labels)\\n\",\n        \"        optimizer.zero_grad()\\n\",\n        \"        loss.backward()\\n\",\n        \"        optimizer.step()\\n\",\n        \"\\n\",\n        \"        total_loss += 
loss.item()\\n\",\n        \"\\n\",\n        \"    print(f\\\"Epoch {epoch:03d} | Loss {total_loss / (step + 1):.3f}\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"pxow2XSkZXoO\"\n      },\n      \"source\": [\n        \"## Evaluating Performance with Link Prediction\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"IMulfsnIZZVh\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"model.eval()\\n\",\n        \"\\n\",\n        \"datapipe = gb.ItemSampler(test_set, batch_size=256, shuffle=False)\\n\",\n        \"datapipe = datapipe.copy_to(device)\\n\",\n        \"# Since we need to use all neighborhoods for evaluation, we set the fanout\\n\",\n        \"# to -1.\\n\",\n        \"datapipe = datapipe.sample_neighbor(graph, [-1, -1])\\n\",\n        \"datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\\\"feat\\\"])\\n\",\n        \"eval_dataloader = gb.DataLoader(datapipe, num_workers=0)\\n\",\n        \"\\n\",\n        \"logits = []\\n\",\n        \"labels = []\\n\",\n        \"for step, data in tqdm(enumerate(eval_dataloader)):\\n\",\n        \"    # Get node pairs with labels for loss calculation.\\n\",\n        \"    compacted_seeds = data.compacted_seeds.T\\n\",\n        \"    label = data.labels\\n\",\n        \"\\n\",\n        \"    # The features of sampled nodes.\\n\",\n        \"    x = data.node_features[\\\"feat\\\"]\\n\",\n        \"\\n\",\n        \"    # Forward.\\n\",\n        \"    y = model(data.blocks, x)\\n\",\n        \"    logit = (\\n\",\n        \"        model.predictor(y[compacted_seeds[0]] * y[compacted_seeds[1]])\\n\",\n        \"        .squeeze()\\n\",\n        \"        .detach()\\n\",\n        \"    )\\n\",\n        \"\\n\",\n        \"    logits.append(logit)\\n\",\n        \"    labels.append(label)\\n\",\n        \"\\n\",\n        \"logits = torch.cat(logits, 
dim=0)\\n\",\n        \"labels = torch.cat(labels, dim=0)\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"# Compute the AUROC score.\\n\",\n        \"from sklearn.metrics import roc_auc_score\\n\",\n        \"\\n\",\n        \"auc = roc_auc_score(labels.cpu(), logits.cpu())\\n\",\n        \"print(\\\"Link Prediction AUC:\\\", auc)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"KoCoIvqAZeCS\"\n      },\n      \"source\": [\n        \"## Conclusion\\n\",\n        \"In this tutorial, you have learned how to train a multi-layer GraphSAGE for link prediction with neighbor sampling.\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"private_outputs\": true,\n      \"provenance\": []\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      \"file_extension\": \".py\",\n      \"mimetype\": \"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.10.12\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "notebooks/stochastic_training/multigpu_node_classification.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"2ppSJal9At7-\"\n      },\n      \"source\": [\n        \"# Multi-GPU Node Classification\\n\",\n        \"\\n\",\n        \"This tutorial shows how to train a multi-layer GraphSAGE for node classification on the `ogbn-products` dataset provided by [Open Graph\\n\",\n        \"Benchmark (OGB)](https://ogb.stanford.edu/). The dataset contains around 2.4 million nodes and 62 million edges.\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/stochastic_training/multigpu_node_classification.ipynb)\\n\",\n        \"[![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/stochastic_training/multigpu_node_classification.ipynb)\\n\",\n        \"\\n\",\n        \"By the end of this tutorial, you will be able to\\n\",\n        \"\\n\",\n        \"- Train a GNN model for node classification on multiple GPUs with DGL's neighbor sampling components. 
After learning how to use multiple GPUs, you will\\n\",\n        \"be able to extend it to other scenarios such as link prediction.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"mzZKrVVk6Y_8\"\n      },\n      \"source\": [\n        \"## Install DGL package and other dependencies\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 1,\n      \"metadata\": {\n        \"id\": \"QTCc1RrD_5Id\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\\n\",\n            \"Looking in links: https://data.dgl.ai/wheels-test/cu121/repo.html\\n\",\n            \"Requirement already satisfied: dgl in /localscratch/dgl-3/python (2.1)\\n\",\n            \"Requirement already satisfied: numpy>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (1.24.4)\\n\",\n            \"Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (1.11.4)\\n\",\n            \"Requirement already satisfied: networkx>=2.1 in /usr/local/lib/python3.10/dist-packages (from dgl) (2.6.3)\\n\",\n            \"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (2.31.0)\\n\",\n            \"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from dgl) (4.66.1)\\n\",\n            \"Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (5.9.4)\\n\",\n            \"Requirement already satisfied: torchdata>=0.5.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (0.7.0a0)\\n\",\n            \"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (3.3.2)\\n\",\n            \"Requirement already satisfied: 
idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (3.6)\\n\",\n            \"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (1.26.18)\\n\",\n            \"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (2023.11.17)\\n\",\n            \"Requirement already satisfied: torch>=2 in /usr/local/lib/python3.10/dist-packages (from torchdata>=0.5.0->dgl) (2.2.0a0+81ea7a4)\\n\",\n            \"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.13.1)\\n\",\n            \"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (4.8.0)\\n\",\n            \"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (1.12)\\n\",\n            \"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.1.2)\\n\",\n            \"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (2023.12.0)\\n\",\n            \"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2->torchdata>=0.5.0->dgl) (2.1.3)\\n\",\n            \"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2->torchdata>=0.5.0->dgl) (1.3.0)\\n\",\n            \"\\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\\u001b[0m\\u001b[33m\\n\",\n            \"\\u001b[0m\\n\",\n            \"\\u001b[1m[\\u001b[0m\\u001b[34;49mnotice\\u001b[0m\\u001b[1;39;49m]\\u001b[0m\\u001b[39;49m A new release of pip is available: \\u001b[0m\\u001b[31;49m23.3.1\\u001b[0m\\u001b[39;49m -> \\u001b[0m\\u001b[32;49m24.0\\u001b[0m\\n\",\n            \"\\u001b[1m[\\u001b[0m\\u001b[34;49mnotice\\u001b[0m\\u001b[1;39;49m]\\u001b[0m\\u001b[39;49m To update, run: \\u001b[0m\\u001b[32;49mpython -m pip install --upgrade pip\\u001b[0m\\n\",\n            \"Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\\n\",\n            \"Requirement already satisfied: torchmetrics in /usr/local/lib/python3.10/dist-packages (1.3.0.post0)\\n\",\n            \"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (0.70.16)\\n\",\n            \"Requirement already satisfied: numpy>1.20.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (1.24.4)\\n\",\n            \"Requirement already satisfied: packaging>17.1 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (23.2)\\n\",\n            \"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (2.2.0a0+81ea7a4)\\n\",\n            \"Requirement already satisfied: lightning-utilities>=0.8.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (0.10.1)\\n\",\n            \"Requirement already satisfied: dill>=0.3.8 in /usr/local/lib/python3.10/dist-packages (from multiprocess) (0.3.8)\\n\",\n            \"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (68.2.2)\\n\",\n            \"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (4.8.0)\\n\",\n            \"Requirement 
already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (3.13.1)\\n\",\n            \"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (1.12)\\n\",\n            \"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (2.6.3)\\n\",\n            \"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (3.1.2)\\n\",\n            \"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (2023.12.0)\\n\",\n            \"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->torchmetrics) (2.1.3)\\n\",\n            \"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->torchmetrics) (1.3.0)\\n\",\n            \"\\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\\u001b[0m\\u001b[33m\\n\",\n            \"\\u001b[0m\\n\",\n            \"\\u001b[1m[\\u001b[0m\\u001b[34;49mnotice\\u001b[0m\\u001b[1;39;49m]\\u001b[0m\\u001b[39;49m A new release of pip is available: \\u001b[0m\\u001b[31;49m23.3.1\\u001b[0m\\u001b[39;49m -> \\u001b[0m\\u001b[32;49m24.0\\u001b[0m\\n\",\n            \"\\u001b[1m[\\u001b[0m\\u001b[34;49mnotice\\u001b[0m\\u001b[1;39;49m]\\u001b[0m\\u001b[39;49m To update, run: \\u001b[0m\\u001b[32;49mpython -m pip install --upgrade pip\\u001b[0m\\n\",\n            \"DGL installed!\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"# Install required packages.\\n\",\n        \"import os\\n\",\n        \"import torch\\n\",\n        \"os.environ['TORCH'] = torch.__version__\\n\",\n        \"os.environ['DGLBACKEND'] = \\\"pytorch\\\"\\n\",\n        \"\\n\",\n        \"# Install the CUDA version. If you want to install CPU version, please\\n\",\n        \"# refer to https://www.dgl.ai/pages/start.html.\\n\",\n        \"!pip install --pre dgl -f https://data.dgl.ai/wheels-test/cu121/repo.html\\n\",\n        \"!pip install torchmetrics multiprocess\\n\",\n        \"\\n\",\n        \"try:\\n\",\n        \"    import dgl\\n\",\n        \"    import dgl.graphbolt as gb\\n\",\n        \"    installed = True\\n\",\n        \"except ImportError as error:\\n\",\n        \"    installed = False\\n\",\n        \"    print(error)\\n\",\n        \"print(\\\"DGL installed!\\\" if installed else \\\"DGL not found!\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"q7GrcJTnZQjt\"\n      },\n      \"source\": [\n        \"## Defining Neighbor Sampler and Data Loader in DGL\\n\",\n        \"\\n\",\n        \"The major difference from the previous tutorial is that we will use `DistributedItemSampler` instead of `ItemSampler` to sample mini-batches of nodes. 
`DistributedItemSampler` is a distributed version of `ItemSampler` that works with `DistributedDataParallel`. It is implemented as a wrapper around `ItemSampler` and will sample the same minibatch on all replicas. It also supports dropping the last non-full minibatch to avoid the need for padding.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 2,\n      \"metadata\": {\n        \"id\": \"eel0Wn_aEYAd\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def create_dataloader(graph, features, itemset, device, is_train):\\n\",\n        \"    datapipe = gb.DistributedItemSampler(\\n\",\n        \"        item_set=itemset,\\n\",\n        \"        batch_size=1024,\\n\",\n        \"        drop_last=is_train,\\n\",\n        \"        shuffle=is_train,\\n\",\n        \"        drop_uneven_inputs=is_train,\\n\",\n        \"    )\\n\",\n        \"    datapipe = datapipe.copy_to(device)\\n\",\n        \"    # Now that we have moved to device, sample_neighbor and fetch_feature steps\\n\",\n        \"    # will be executed on GPUs.\\n\",\n        \"    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])\\n\",\n        \"    datapipe = datapipe.fetch_feature(features, node_feature_keys=[\\\"feat\\\"])\\n\",\n        \"    return gb.DataLoader(datapipe)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"uswPlvOLF1IX\"\n      },\n      \"source\": [\n        \"## Weighted reduction across GPUs\\n\",\n        \"\\n\",\n        \"As the different GPUs might process differing numbers of data points, we define a function to compute the exact average of values such as loss or accuracy in a\\n\",\n        \"weighted manner.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 3,\n      \"metadata\": {\n        \"id\": \"VXP0hmzVGKnp\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import torch.distributed as 
dist\\n\",\n        \"\\n\",\n        \"def weighted_reduce(tensor, weight, dst=0):\\n\",\n        \"    ########################################################################\\n\",\n        \"    # (HIGHLIGHT) Collect accuracy and loss values from sub-processes and\\n\",\n        \"    # obtain overall average values.\\n\",\n        \"    #\\n\",\n        \"    # `torch.distributed.reduce` is used to reduce tensors from all the\\n\",\n        \"    # sub-processes to a specified process, ReduceOp.SUM is used by default.\\n\",\n        \"    #\\n\",\n        \"    # Because the GPUs may have differing numbers of processed items, we\\n\",\n        \"    # perform a weighted mean to calculate the exact loss and accuracy.\\n\",\n        \"    ########################################################################\\n\",\n        \"    dist.reduce(tensor=tensor, dst=dst)\\n\",\n        \"    weight = torch.tensor(weight, device=tensor.device)\\n\",\n        \"    dist.reduce(tensor=weight, dst=dst)\\n\",\n        \"    return tensor / weight\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fV6epnRxbZl4\"\n      },\n      \"source\": [\n        \"## Defining Model\\n\",\n        \"Let’s consider training a 3-layer GraphSAGE with neighbor sampling. 
The model can be written as follows:\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 4,\n      \"metadata\": {\n        \"id\": \"ft9Ldg-yEsa5\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from torch import nn\\n\",\n        \"import torch.nn.functional as F\\n\",\n        \"from dgl.nn import SAGEConv\\n\",\n        \"\\n\",\n        \"class SAGE(nn.Module):\\n\",\n        \"    def __init__(self, in_size, hidden_size, out_size):\\n\",\n        \"        super().__init__()\\n\",\n        \"        self.layers = nn.ModuleList()\\n\",\n        \"        # Three-layer GraphSAGE-mean.\\n\",\n        \"        self.layers.append(SAGEConv(in_size, hidden_size, \\\"mean\\\"))\\n\",\n        \"        self.layers.append(SAGEConv(hidden_size, hidden_size, \\\"mean\\\"))\\n\",\n        \"        self.layers.append(SAGEConv(hidden_size, out_size, \\\"mean\\\"))\\n\",\n        \"        self.dropout = nn.Dropout(0.5)\\n\",\n        \"        self.hidden_size = hidden_size\\n\",\n        \"        self.out_size = out_size\\n\",\n        \"        # Set the dtype for the layers manually.\\n\",\n        \"        self.float()\\n\",\n        \"\\n\",\n        \"    def forward(self, blocks, x):\\n\",\n        \"        hidden_x = x\\n\",\n        \"        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):\\n\",\n        \"            hidden_x = layer(block, hidden_x)\\n\",\n        \"            is_last_layer = layer_idx == len(self.layers) - 1\\n\",\n        \"            if not is_last_layer:\\n\",\n        \"                hidden_x = F.relu(hidden_x)\\n\",\n        \"                hidden_x = self.dropout(hidden_x)\\n\",\n        \"        return hidden_x\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"CjuvDKDVGbPW\"\n      },\n      \"source\": [\n        \"## Evaluation function\\n\",\n        \"\\n\",\n        \"The evaluation function 
can be used to calculate the validation accuracy during training or the testing accuracy at the end of the training. The difference from\\n\",\n        \"the previous tutorial is that we need to return the number of items processed\\n\",\n        \"by each GPU to take a weighted average.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 5,\n      \"metadata\": {\n        \"id\": \"j4djoX9tG7Ib\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import torchmetrics.functional as MF\\n\",\n        \"import tqdm\\n\",\n        \"\\n\",\n        \"@torch.no_grad()\\n\",\n        \"def evaluate(rank, model, graph, features, itemset, num_classes, device):\\n\",\n        \"    model.eval()\\n\",\n        \"    y = []\\n\",\n        \"    y_hats = []\\n\",\n        \"    dataloader = create_dataloader(\\n\",\n        \"        graph,\\n\",\n        \"        features,\\n\",\n        \"        itemset,\\n\",\n        \"        device,\\n\",\n        \"        is_train=False,\\n\",\n        \"    )\\n\",\n        \"\\n\",\n        \"    for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader:\\n\",\n        \"        blocks = data.blocks\\n\",\n        \"        x = data.node_features[\\\"feat\\\"]\\n\",\n        \"        y.append(data.labels)\\n\",\n        \"        y_hats.append(model.module(blocks, x))\\n\",\n        \"\\n\",\n        \"    res = MF.accuracy(\\n\",\n        \"        torch.cat(y_hats),\\n\",\n        \"        torch.cat(y),\\n\",\n        \"        task=\\\"multiclass\\\",\\n\",\n        \"        num_classes=num_classes,\\n\",\n        \"    )\\n\",\n        \"\\n\",\n        \"    return res.to(device), sum(y_i.size(0) for y_i in y)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"kN5BbnR4HSU2\"\n      },\n      \"source\": [\n        \"## Training Loop\\n\",\n        \"\\n\",\n        \"The training loop is almost identical to the 
previous tutorial. In this tutorial, we explicitly disable uneven inputs coming from the dataloader; however, the Join Context Manager could be used to train with possibly incomplete batches at the end of epochs. Please refer to [this tutorial](https://pytorch.org/tutorials/advanced/generic_join.html) for more information.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 6,\n      \"metadata\": {\n        \"id\": \"bdOceP3yH-eI\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import time\\n\",\n        \"\\n\",\n        \"def train(\\n\",\n        \"    rank,\\n\",\n        \"    graph,\\n\",\n        \"    features,\\n\",\n        \"    train_set,\\n\",\n        \"    valid_set,\\n\",\n        \"    num_classes,\\n\",\n        \"    model,\\n\",\n        \"    device,\\n\",\n        \"):\\n\",\n        \"    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\\n\",\n        \"    # Create training data loader.\\n\",\n        \"    dataloader = create_dataloader(\\n\",\n        \"        graph,\\n\",\n        \"        features,\\n\",\n        \"        train_set,\\n\",\n        \"        device,\\n\",\n        \"        is_train=True,\\n\",\n        \"    )\\n\",\n        \"\\n\",\n        \"    for epoch in range(5):\\n\",\n        \"        epoch_start = time.time()\\n\",\n        \"\\n\",\n        \"        model.train()\\n\",\n        \"        total_loss = torch.tensor(0, dtype=torch.float, device=device)\\n\",\n        \"        num_train_items = 0\\n\",\n        \"        for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader:\\n\",\n        \"            # The input features are from the source nodes in the first\\n\",\n        \"            # layer's computation graph.\\n\",\n        \"            x = data.node_features[\\\"feat\\\"]\\n\",\n        \"\\n\",\n        \"            # The ground truth labels are from the destination nodes\\n\",\n        \"            # in the last layer's 
computation graph.\\n\",\n        \"            y = data.labels\\n\",\n        \"\\n\",\n        \"            blocks = data.blocks\\n\",\n        \"\\n\",\n        \"            y_hat = model(blocks, x)\\n\",\n        \"\\n\",\n        \"            # Compute loss.\\n\",\n        \"            loss = F.cross_entropy(y_hat, y)\\n\",\n        \"\\n\",\n        \"            optimizer.zero_grad()\\n\",\n        \"            loss.backward()\\n\",\n        \"            optimizer.step()\\n\",\n        \"\\n\",\n        \"            total_loss += loss.detach() * y.size(0)\\n\",\n        \"            num_train_items += y.size(0)\\n\",\n        \"\\n\",\n        \"        # Evaluate the model.\\n\",\n        \"        if rank == 0:\\n\",\n        \"            print(\\\"Validating...\\\")\\n\",\n        \"        acc, num_val_items = evaluate(\\n\",\n        \"            rank,\\n\",\n        \"            model,\\n\",\n        \"            graph,\\n\",\n        \"            features,\\n\",\n        \"            valid_set,\\n\",\n        \"            num_classes,\\n\",\n        \"            device,\\n\",\n        \"        )\\n\",\n        \"        total_loss = weighted_reduce(total_loss, num_train_items)\\n\",\n        \"        acc = weighted_reduce(acc * num_val_items, num_val_items)\\n\",\n        \"\\n\",\n        \"        # We synchronize before measuring the epoch time.\\n\",\n        \"        torch.cuda.synchronize()\\n\",\n        \"        epoch_end = time.time()\\n\",\n        \"        if rank == 0:\\n\",\n        \"            print(\\n\",\n        \"                f\\\"Epoch {epoch:05d} | \\\"\\n\",\n        \"                f\\\"Average Loss {total_loss.item():.4f} | \\\"\\n\",\n        \"                f\\\"Accuracy {acc.item():.4f} | \\\"\\n\",\n        \"                f\\\"Time {epoch_end - epoch_start:.4f}\\\"\\n\",\n        \"            )\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        
\"id\": \"mA-Xu37uIHc4\"\n      },\n      \"source\": [\n        \"## Defining Training and Evaluation Procedures\\n\",\n        \"\\n\",\n        \"The following code defines the main function for each process. It is similar to the previous tutorial except that we need to initialize a distributed training context with `torch.distributed` and wrap the model with `torch.nn.parallel.DistributedDataParallel`.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 7,\n      \"metadata\": {\n        \"id\": \"sW__HeslIMTT\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def run(rank, world_size, devices, dataset):\\n\",\n        \"    # Set up multiprocessing environment.\\n\",\n        \"    device = devices[rank]\\n\",\n        \"    torch.cuda.set_device(device)\\n\",\n        \"    dist.init_process_group(\\n\",\n        \"        backend=\\\"nccl\\\",  # Use NCCL backend for distributed GPU training\\n\",\n        \"        init_method=\\\"tcp://127.0.0.1:12345\\\",\\n\",\n        \"        world_size=world_size,\\n\",\n        \"        rank=rank,\\n\",\n        \"    )\\n\",\n        \"\\n\",\n        \"    # Pin the graph and features in-place to enable GPU access.\\n\",\n        \"    graph = dataset.graph.pin_memory_()\\n\",\n        \"    features = dataset.feature.pin_memory_()\\n\",\n        \"    train_set = dataset.tasks[0].train_set\\n\",\n        \"    valid_set = dataset.tasks[0].validation_set\\n\",\n        \"    num_classes = dataset.tasks[0].metadata[\\\"num_classes\\\"]\\n\",\n        \"\\n\",\n        \"    in_size = features.size(\\\"node\\\", None, \\\"feat\\\")[0]\\n\",\n        \"    hidden_size = 256\\n\",\n        \"    out_size = num_classes\\n\",\n        \"\\n\",\n        \"    # Create GraphSAGE model. 
It should be copied onto a GPU as a replica.\\n\",\n        \"    model = SAGE(in_size, hidden_size, out_size).to(device)\\n\",\n        \"    model = nn.parallel.DistributedDataParallel(model)\\n\",\n        \"\\n\",\n        \"    # Model training.\\n\",\n        \"    if rank == 0:\\n\",\n        \"        print(\\\"Training...\\\")\\n\",\n        \"    train(\\n\",\n        \"        rank,\\n\",\n        \"        graph,\\n\",\n        \"        features,\\n\",\n        \"        train_set,\\n\",\n        \"        valid_set,\\n\",\n        \"        num_classes,\\n\",\n        \"        model,\\n\",\n        \"        device,\\n\",\n        \"    )\\n\",\n        \"\\n\",\n        \"    # Test the model.\\n\",\n        \"    if rank == 0:\\n\",\n        \"        print(\\\"Testing...\\\")\\n\",\n        \"    test_set = dataset.tasks[0].test_set\\n\",\n        \"    test_acc, num_test_items = evaluate(\\n\",\n        \"        rank,\\n\",\n        \"        model,\\n\",\n        \"        graph,\\n\",\n        \"        features,\\n\",\n        \"        itemset=test_set,\\n\",\n        \"        num_classes=num_classes,\\n\",\n        \"        device=device,\\n\",\n        \"    )\\n\",\n        \"    test_acc = weighted_reduce(test_acc * num_test_items, num_test_items)\\n\",\n        \"\\n\",\n        \"    if rank == 0:\\n\",\n        \"        print(f\\\"Test Accuracy {test_acc.item():.4f}\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"qMzt0aBFIfbS\"\n      },\n      \"source\": [\n        \"## Spawning Trainer Processes\\n\",\n        \"\\n\",\n        \"The following code spawns a process for each GPU and calls the run function defined above.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 8,\n      \"metadata\": {\n        \"id\": \"5Dt95eSVIiyM\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": 
\"stream\",\n          \"text\": [\n            \"Training with 1 gpus.\\n\",\n            \"The dataset is already preprocessed.\\n\",\n            \"Training...\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"192it [00:09, 21.32it/s]\\n\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Validating...\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"39it [00:00, 78.32it/s]\\n\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Epoch 00000 | Average Loss 1.2953 | Accuracy 0.8556 | Time 9.5520\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"192it [00:03, 61.08it/s]\\n\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Validating...\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"39it [00:00, 79.10it/s]\\n\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Epoch 00001 | Average Loss 0.5859 | Accuracy 0.8788 | Time 3.6609\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"192it [00:03, 62.82it/s]\\n\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Validating...\\n\"\n          ]\n 
       },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"39it [00:00, 80.55it/s]\\n\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Epoch 00002 | Average Loss 0.4858 | Accuracy 0.8852 | Time 3.5646\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"192it [00:03, 60.34it/s]\\n\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Validating...\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"39it [00:00, 44.41it/s]\\n\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Epoch 00003 | Average Loss 0.4407 | Accuracy 0.8920 | Time 4.0852\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"192it [00:03, 58.87it/s]\\n\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Validating...\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"39it [00:00, 78.52it/s]\\n\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Epoch 00004 | Average Loss 0.4122 | Accuracy 0.8943 | Time 3.7938\\n\",\n            \"Testing...\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n     
     \"text\": [\n            \"2162it [00:24, 89.75it/s]\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Test Accuracy 0.7514\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"import torch.multiprocessing as mp\\n\",\n        \"\\n\",\n        \"def main():\\n\",\n        \"    if not torch.cuda.is_available():\\n\",\n        \"        print(\\\"No GPU found!\\\")\\n\",\n        \"        return\\n\",\n        \"\\n\",\n        \"    devices = [\\n\",\n        \"        torch.device(f\\\"cuda:{i}\\\") for i in range(torch.cuda.device_count())\\n\",\n        \"    ][:1]\\n\",\n        \"    world_size = len(devices)\\n\",\n        \"\\n\",\n        \"    print(f\\\"Training with {world_size} gpus.\\\")\\n\",\n        \"\\n\",\n        \"    # Load and preprocess dataset.\\n\",\n        \"    dataset = gb.BuiltinDataset(\\\"ogbn-products\\\").load()\\n\",\n        \"\\n\",\n        \"    # Thread limiting to avoid resource competition.\\n\",\n        \"    os.environ[\\\"OMP_NUM_THREADS\\\"] = str(mp.cpu_count() // 2 // world_size)\\n\",\n        \"\\n\",\n        \"    if world_size > 1:\\n\",\n        \"        # The following launch method is not supported in a notebook.\\n\",\n        \"        mp.set_sharing_strategy(\\\"file_system\\\")\\n\",\n        \"        mp.spawn(\\n\",\n        \"            run,\\n\",\n        \"            args=(world_size, devices, dataset),\\n\",\n        \"            nprocs=world_size,\\n\",\n        \"            join=True,\\n\",\n        \"        )\\n\",\n        \"    else:\\n\",\n        \"        run(0, 1, devices, dataset)\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"if __name__ == \\\"__main__\\\":\\n\",\n        \"    main()\"\n      ]\n    
}\n  ],\n  \"metadata\": {\n    \"accelerator\": \"GPU\",\n    \"colab\": {\n      \"gpuType\": \"T4\",\n      \"provenance\": []\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "notebooks/stochastic_training/neighbor_sampling_overview.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"private_outputs\": true,\n      \"provenance\": [],\n      \"authorship_tag\": \"ABX9TyMxpiQDo/pG6bIgkfWOPqXY\"\n    },\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    }\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"# Neighbor Sampling Overview\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/stochastic_training/neighbor_sampling_overview.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/stochastic_training/neighbor_sampling_overview.ipynb)\\n\",\n        \"\\n\",\n        \"In previous tutorials you have learned how to train GNNs by computing the representations of all nodes on a graph. 
However, sometimes your graph is too large to fit the computation of all nodes in a single GPU.\\n\",\n        \"\\n\",\n        \"By the end of this tutorial, you will be able to\\n\",\n        \"\\n\",\n        \"- Understand the pipeline of stochastic GNN training.\\n\",\n        \"\\n\",\n        \"- Understand what neighbor sampling is and why it yields a bipartite graph for each GNN layer.\"\n      ],\n      \"metadata\": {\n        \"id\": \"p7tTmsjh3dEy\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Message Passing Review\\n\",\n        \"Recall that in [Gilmer et al.](https://arxiv.org/abs/1704.01212), the message passing formulation is as follows:\\n\",\n        \"\\n\",\n        \"$$m_{u \\\\to v}^{(l)} = M^{(l)}\\\\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u \\\\to v}^{(l-1)}\\\\right)$$\\n\",\n        \"\\n\",\n        \"$$m_{v}^{(l)} = \\\\sum_{u \\\\in \\\\mathcal{N}(v)} m_{u \\\\to v}^{(l)}$$\\n\",\n        \"\\n\",\n        \"$$h_v^{(l)} = U^{(l)}\\\\left(h_v^{(l-1)}, m_v^{(l)}\\\\right)$$\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"where DGL calls\\n\",\n        \"- message function: $M^{(l)}$\\n\",\n        \"- reduce function: $\\\\sum$\\n\",\n        \"- update function: $U^{(l)}$\\n\",\n        \"\\n\",\n        \"Note that $\\\\sum$ here can represent any function and is not necessarily a summation.\\n\",\n        \"\\n\",\n        \"Essentially, the $l$-th layer representation of a single node depends on the $(l-1)$-th layer representation of the same node, as well as the $(l-1)$-th layer representation of the neighboring nodes. 
Those $(l-1)$-th layer representations then depend on the $(l-2)$-th layer representation of those nodes, as well as their neighbors.\\n\",\n        \"\\n\",\n        \"The following animation shows how a 2-layer GNN is supposed to compute the output of node 5:\\n\",\n        \"\\n\",\n        \"![image1](https://data.dgl.ai/tutorial/img/sampling.gif)\\n\",\n        \"\\n\",\n        \"You can see that to compute node 5 from the second layer, you will need its direct neighbors’ first layer representations (colored in yellow), which in turn need their direct neighbors’ (i.e. node 5’s second-hop neighbors’) representations (colored in green).\"\n      ],\n      \"metadata\": {\n        \"id\": \"eJs-O2Vz88Kd\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Neighbor Sampling Overview\\n\",\n        \"You can also see from the previous example that computing representation for a small number of nodes often requires input features of a significantly larger number of nodes. Taking all neighbors for message aggregation is often too costly since the nodes needed for input features would easily cover a large portion of the graph, especially for real-world graphs which are often [scale-free](https://en.wikipedia.org/wiki/Scale-free_network).\\n\",\n        \"\\n\",\n        \"Neighbor sampling addresses this issue by selecting a subset of the neighbors to perform aggregation. For instance, to compute ${h}_5^{(2)}$, you can choose two of the neighbors instead of all of them to aggregate, as in the following animation:\\n\",\n        \"\\n\",\n        \"![image2](https://data.dgl.ai/tutorial/img/bipartite.gif)\\n\",\n        \"\\n\",\n        \"You can see that this method uses far fewer nodes in message passing for a single minibatch.\\n\",\n        \"\\n\",\n        \"You can also notice in the animation above that the computation dependencies can be described as a series of bipartite graphs. 
The output nodes (called destination nodes) are on one side and all the nodes necessary for inputs (called source nodes) are on the other side. The arrows indicate how the sampled neighbors propagate messages to the nodes. DGL calls such graphs **message flow graphs (MFG)**.\\n\",\n        \"\\n\",\n        \"Note that some GNN modules, such as `SAGEConv`, need to use the destination nodes’ features on the previous layer to compute the outputs. Without loss of generality, DGL always includes the destination nodes themselves in the source nodes.\"\n      ],\n      \"metadata\": {\n        \"id\": \"0yYSBM8s9M_P\"\n      }\n    }\n  ]\n}\n"
  },
  {
    "path": "notebooks/stochastic_training/node_classification.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"OxbY2KlG4ZfJ\"\n      },\n      \"source\": [\n        \"# Node Classification\\n\",\n        \"This tutorial shows how to train a multi-layer GraphSAGE for node\\n\",\n        \"classification on ``ogbn-arxiv`` provided by [Open Graph\\n\",\n        \"Benchmark (OGB)](https://ogb.stanford.edu/). The dataset contains around\\n\",\n        \"170 thousand nodes and 1 million edges.\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/stochastic_training/node_classification.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/stochastic_training/node_classification.ipynb)\\n\",\n        \"\\n\",\n        \"By the end of this tutorial, you will be able to\\n\",\n        \"\\n\",\n        \"-  Train a GNN model for node classification on a single GPU with DGL's\\n\",\n        \"   neighbor sampling components.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"mzZKrVVk6Y_8\"\n      },\n      \"source\": [\n        \"## Install DGL package\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"QcpjTazg6hEo\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Install required packages.\\n\",\n        \"import os\\n\",\n        \"import torch\\n\",\n        \"import numpy as np\\n\",\n        \"os.environ['TORCH'] = torch.__version__\\n\",\n        \"os.environ['DGLBACKEND'] = \\\"pytorch\\\"\\n\",\n        \"\\n\",\n        \"# Install the CPU version in default. 
If you want to install CUDA version,\\n\",\n        \"# please refer to https://www.dgl.ai/pages/start.html and change runtime type\\n\",\n        \"# accordingly.\\n\",\n        \"device = torch.device(\\\"cpu\\\")\\n\",\n        \"!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\\n\",\n        \"\\n\",\n        \"try:\\n\",\n        \"    import dgl\\n\",\n        \"    import dgl.graphbolt as gb\\n\",\n        \"    installed = True\\n\",\n        \"except ImportError as error:\\n\",\n        \"    installed = False\\n\",\n        \"    print(error)\\n\",\n        \"print(\\\"DGL installed!\\\" if installed else \\\"DGL not found!\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"XWdRZAM-51Cb\"\n      },\n      \"source\": [\n        \"## Loading Dataset\\n\",\n        \"`ogbn-arxiv` is already prepared as ``BuiltinDataset`` in **GraphBolt**.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"RnJkkSKhWiUG\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"dataset = gb.BuiltinDataset(\\\"ogbn-arxiv-seeds\\\").load()\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"S8avoKBiXA9j\"\n      },\n      \"source\": [\n        \"Dataset consists of graph, feature and tasks. You can get the training-validation-test set from the tasks. Seed nodes and corresponding labels are already stored in each training-validation-test set. Other metadata such as number of classes are also stored in the tasks. 
In this dataset, there is only one task: `node classification`.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"IXGZmgIaXJWQ\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"graph = dataset.graph.to(device)\\n\",\n        \"feature = dataset.feature.to(device)\\n\",\n        \"train_set = dataset.tasks[0].train_set\\n\",\n        \"valid_set = dataset.tasks[0].validation_set\\n\",\n        \"test_set = dataset.tasks[0].test_set\\n\",\n        \"task_name = dataset.tasks[0].metadata[\\\"name\\\"]\\n\",\n        \"num_classes = dataset.tasks[0].metadata[\\\"num_classes\\\"]\\n\",\n        \"print(f\\\"Task: {task_name}. Number of classes: {num_classes}\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"y8yn77Kg6HkW\"\n      },\n      \"source\": [\n        \"## How DGL Handles Computation Dependency\\n\",\n        \"The computation dependency for message passing of a single node can be described as a series of message flow graphs (MFG).\\n\",\n        \"\\n\",\n        \"![DGL Computation](https://data.dgl.ai/tutorial/img/bipartite.gif)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"q7GrcJTnZQjt\"\n      },\n      \"source\": [\n        \"## Defining Neighbor Sampler and Data Loader in DGL\\n\",\n        \"\\n\",\n        \"DGL provides tools to iterate over the dataset in minibatches while generating the computation dependencies to compute their outputs with the MFGs above. For node classification, you can use `dgl.graphbolt.DataLoader` for iterating over the dataset. It accepts a data pipe that generates minibatches of nodes and their labels, samples neighbors for each node, and generates the computation dependencies in the form of MFGs. Feature fetching, block creation and copying to target device are also supported. 
All these operations are split into separate stages in the data pipe, so that you can customize the data pipeline by inserting your own operations.\\n\",\n        \"\\n\",\n        \"Let’s say that each node will gather messages from 4 neighbors on each layer. The code defining the data loader and neighbor sampler will look like the following.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"yQVYDO0ZbBvi\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def create_dataloader(itemset, shuffle):\\n\",\n        \"    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)\\n\",\n        \"    datapipe = datapipe.copy_to(device)\\n\",\n        \"    datapipe = datapipe.sample_neighbor(graph, [4, 4])\\n\",\n        \"    datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\\\"feat\\\"])\\n\",\n        \"    return gb.DataLoader(datapipe)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"7Rp12SUhbEV1\"\n      },\n      \"source\": [\n        \"You can iterate over the data loader and a `MiniBatch` object is yielded.\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"V7vQiKj2bL_o\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"data = next(iter(create_dataloader(train_set, shuffle=True)))\\n\",\n        \"print(data)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"-eBuPnT-bS-o\"\n      },\n      \"source\": [\n        \"You can get the input node IDs from MFGs.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"bN4sgZqFbUvd\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"mfgs = data.blocks\\n\",\n        
\"input_nodes = mfgs[0].srcdata[dgl.NID]\\n\",\n        \"print(f\\\"Input nodes: {input_nodes}.\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fV6epnRxbZl4\"\n      },\n      \"source\": [\n        \"## Defining Model\\n\",\n        \"Let’s consider training a 2-layer GraphSAGE with neighbor sampling. The model can be written as follows:\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"iKhEIL0Ccmwx\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import torch.nn as nn\\n\",\n        \"import torch.nn.functional as F\\n\",\n        \"from dgl.nn import SAGEConv\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"class Model(nn.Module):\\n\",\n        \"    def __init__(self, in_feats, h_feats, num_classes):\\n\",\n        \"        super(Model, self).__init__()\\n\",\n        \"        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type=\\\"mean\\\")\\n\",\n        \"        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type=\\\"mean\\\")\\n\",\n        \"        self.h_feats = h_feats\\n\",\n        \"\\n\",\n        \"    def forward(self, mfgs, x):\\n\",\n        \"        h = self.conv1(mfgs[0], x)\\n\",\n        \"        h = F.relu(h)\\n\",\n        \"        h = self.conv2(mfgs[1], h)\\n\",\n        \"        return h\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"in_size = feature.size(\\\"node\\\", None, \\\"feat\\\")[0]\\n\",\n        \"model = Model(in_size, 64, num_classes).to(device)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"OGLN3kCcwCA8\"\n      },\n      \"source\": [\n        \"## Defining Training Loop\\n\",\n        \"\\n\",\n        \"The following initializes the model and defines the optimizer.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      
\"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"dET8i_hewLUi\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"opt = torch.optim.Adam(model.parameters())\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"leZvFP4GwMcq\"\n      },\n      \"source\": [\n        \"When computing the validation score for model selection, usually you can also do neighbor sampling. We can just reuse our create_dataloader function to create two separate dataloaders for training and validation.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Gvd7vFWZwQI5\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"train_dataloader = create_dataloader(train_set, shuffle=True)\\n\",\n        \"valid_dataloader = create_dataloader(valid_set, shuffle=False)\\n\",\n        \"\\n\",\n        \"import sklearn.metrics\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"nTIIfVMDwXqX\"\n      },\n      \"source\": [\n        \"The following is a training loop that performs validation every epoch. 
It prints the validation accuracy at the end of each epoch.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"wsfqhKUvwZEj\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from tqdm.auto import tqdm\\n\",\n        \"\\n\",\n        \"for epoch in range(10):\\n\",\n        \"    model.train()\\n\",\n        \"\\n\",\n        \"    with tqdm(train_dataloader) as tq:\\n\",\n        \"        for step, data in enumerate(tq):\\n\",\n        \"            x = data.node_features[\\\"feat\\\"]\\n\",\n        \"            labels = data.labels\\n\",\n        \"\\n\",\n        \"            predictions = model(data.blocks, x)\\n\",\n        \"\\n\",\n        \"            loss = F.cross_entropy(predictions, labels)\\n\",\n        \"            opt.zero_grad()\\n\",\n        \"            loss.backward()\\n\",\n        \"            opt.step()\\n\",\n        \"\\n\",\n        \"            accuracy = sklearn.metrics.accuracy_score(\\n\",\n        \"                labels.cpu().numpy(),\\n\",\n        \"                predictions.argmax(1).detach().cpu().numpy(),\\n\",\n        \"            )\\n\",\n        \"\\n\",\n        \"            tq.set_postfix(\\n\",\n        \"                {\\\"loss\\\": \\\"%.03f\\\" % loss.item(), \\\"acc\\\": \\\"%.03f\\\" % accuracy},\\n\",\n        \"                refresh=False,\\n\",\n        \"            )\\n\",\n        \"\\n\",\n        \"    model.eval()\\n\",\n        \"\\n\",\n        \"    predictions = []\\n\",\n        \"    labels = []\\n\",\n        \"    with tqdm(valid_dataloader) as tq, torch.no_grad():\\n\",\n        \"        for data in tq:\\n\",\n        \"            x = data.node_features[\\\"feat\\\"]\\n\",\n        \"            labels.append(data.labels.cpu().numpy())\\n\",\n        \"            predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())\\n\",\n        \"        predictions = 
np.concatenate(predictions)\\n\",\n        \"        labels = np.concatenate(labels)\\n\",\n        \"        accuracy = sklearn.metrics.accuracy_score(labels, predictions)\\n\",\n        \"        print(\\\"Epoch {} Validation Accuracy {}\\\".format(epoch, accuracy))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"kmHnUI0QwfJ4\"\n      },\n      \"source\": [\n        \"## Conclusion\\n\",\n        \"\\n\",\n        \"In this tutorial, you have learned how to train a multi-layer GraphSAGE with neighbor sampling.\\n\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"private_outputs\": true,\n      \"provenance\": []\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      \"file_extension\": \".py\",\n      \"mimetype\": \"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.10.12\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "notebooks/stochastic_training/ondisk_dataset_heterograph.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"FnFhPMaAfLtJ\"\n      },\n      \"source\": [\n        \"# OnDiskDataset for Heterogeneous Graph\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/stochastic_training/ondisk_dataset_heterograph.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/stochastic_training/ondisk_dataset_heterograph.ipynb)\\n\",\n        \"\\n\",\n        \"This tutorial shows how to create `OnDiskDataset` for heterogeneous graph that could be used in **GraphBolt** framework. The major difference from creating dataset for homogeneous graph is that we need to specify node/edge types for edges, feature data, training/validation/test sets.\\n\",\n        \"\\n\",\n        \"By the end of this tutorial, you will be able to\\n\",\n        \"\\n\",\n        \"- organize graph structure data.\\n\",\n        \"- organize feature data.\\n\",\n        \"- organize training/validation/test set for specific tasks.\\n\",\n        \"\\n\",\n        \"To create an ``OnDiskDataset`` object, you need to organize all the data including graph structure, feature data and tasks into a directory. 
The directory should contain a ``metadata.yaml`` file that describes the metadata of the dataset.\\n\",\n        \"\\n\",\n        \"Now let's generate various data step by step and organize them together to instantiate `OnDiskDataset` finally.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Wlb19DtWgtzq\"\n      },\n      \"source\": [\n        \"## Install DGL package\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"UojlT9ZGgyr9\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Install required packages.\\n\",\n        \"import os\\n\",\n        \"import torch\\n\",\n        \"import numpy as np\\n\",\n        \"os.environ['TORCH'] = torch.__version__\\n\",\n        \"os.environ['DGLBACKEND'] = \\\"pytorch\\\"\\n\",\n        \"\\n\",\n        \"# Install the CPU version.\\n\",\n        \"device = torch.device(\\\"cpu\\\")\\n\",\n        \"!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\\n\",\n        \"\\n\",\n        \"try:\\n\",\n        \"    import dgl\\n\",\n        \"    import dgl.graphbolt as gb\\n\",\n        \"    installed = True\\n\",\n        \"except ImportError as error:\\n\",\n        \"    installed = False\\n\",\n        \"    print(error)\\n\",\n        \"print(\\\"DGL installed!\\\" if installed else \\\"DGL not found!\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"2R7WnSbjsfbr\"\n      },\n      \"source\": [\n        \"## Data preparation\\n\",\n        \"In order to demonstrate how to organize various data, let's create a base directory first.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"SZipbzyltLfO\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"base_dir = 
'./ondisk_dataset_heterograph'\\n\",\n        \"os.makedirs(base_dir, exist_ok=True)\\n\",\n        \"print(f\\\"Created base directory: {base_dir}\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"qhNtIn_xhlnl\"\n      },\n      \"source\": [\n        \"### Generate graph structure data\\n\",\n        \"For a heterogeneous graph, we need to save the edges (namely seeds) of each edge type into separate **Numpy** or **CSV** files.\\n\",\n        \"\\n\",\n        \"Note:\\n\",\n        \"- when saving to **Numpy**, the array is required to have shape `(2, N)`. This format is recommended as constructing a graph from it is much faster than from a **CSV** file.\\n\",\n        \"- when saving to **CSV** file, do not save index and header.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"HcBt4G5BmSjr\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import numpy as np\\n\",\n        \"import pandas as pd\\n\",\n        \"\\n\",\n        \"# For simplicity, we create a heterogeneous graph with\\n\",\n        \"# 2 node types: `user`, `item`\\n\",\n        \"# 2 edge types: `user:like:item`, `user:follow:user`\\n\",\n        \"# And each node/edge type has the same number of nodes/edges.\\n\",\n        \"num_nodes = 1000\\n\",\n        \"num_edges = 10 * num_nodes\\n\",\n        \"\\n\",\n        \"# Edge type: \\\"user:like:item\\\"\\n\",\n        \"like_edges_path = os.path.join(base_dir, \\\"like-edges.csv\\\")\\n\",\n        \"like_edges = np.random.randint(0, num_nodes, size=(num_edges, 2))\\n\",\n        \"print(f\\\"Part of [user:like:item] edges: {like_edges[:5, :]}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"df = pd.DataFrame(like_edges)\\n\",\n        \"df.to_csv(like_edges_path, index=False, header=False)\\n\",\n        \"print(f\\\"[user:like:item] edges are saved into {like_edges_path}\\\\n\\\")\\n\",\n        
\"\\n\",\n        \"# Edge type: \\\"user:follow:user\\\"\\n\",\n        \"follow_edges_path = os.path.join(base_dir, \\\"follow-edges.csv\\\")\\n\",\n        \"follow_edges = np.random.randint(0, num_nodes, size=(num_edges, 2))\\n\",\n        \"print(f\\\"Part of [user:follow:user] edges: {follow_edges[:5, :]}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"df = pd.DataFrame(follow_edges)\\n\",\n        \"df.to_csv(follow_edges_path, index=False, header=False)\\n\",\n        \"print(f\\\"[user:follow:user] edges are saved into {follow_edges_path}\\\\n\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"kh-4cPtzpcaH\"\n      },\n      \"source\": [\n        \"### Generate feature data for graph\\n\",\n        \"For feature data, numpy arrays and torch tensors are supported for now. Let's generate feature data for each node/edge type.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"_PVu1u5brBhF\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Generate node[user] feature in numpy array.\\n\",\n        \"node_user_feat_0_path = os.path.join(base_dir, \\\"node-user-feat-0.npy\\\")\\n\",\n        \"node_user_feat_0 = np.random.rand(num_nodes, 5)\\n\",\n        \"print(f\\\"Part of node[user] feature [feat_0]: {node_user_feat_0[:3, :]}\\\")\\n\",\n        \"np.save(node_user_feat_0_path, node_user_feat_0)\\n\",\n        \"print(f\\\"Node[user] feature [feat_0] is saved to {node_user_feat_0_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Generate another node[user] feature in torch tensor\\n\",\n        \"node_user_feat_1_path = os.path.join(base_dir, \\\"node-user-feat-1.pt\\\")\\n\",\n        \"node_user_feat_1 = torch.rand(num_nodes, 5)\\n\",\n        \"print(f\\\"Part of node[user] feature [feat_1]: {node_user_feat_1[:3, :]}\\\")\\n\",\n        \"torch.save(node_user_feat_1, node_user_feat_1_path)\\n\",\n 
       \"print(f\\\"Node[user] feature [feat_1] is saved to {node_user_feat_1_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Generate node[item] feature in numpy array.\\n\",\n        \"node_item_feat_0_path = os.path.join(base_dir, \\\"node-item-feat-0.npy\\\")\\n\",\n        \"node_item_feat_0 = np.random.rand(num_nodes, 5)\\n\",\n        \"print(f\\\"Part of node[item] feature [feat_0]: {node_item_feat_0[:3, :]}\\\")\\n\",\n        \"np.save(node_item_feat_0_path, node_item_feat_0)\\n\",\n        \"print(f\\\"Node[item] feature [feat_0] is saved to {node_item_feat_0_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Generate another node[item] feature in torch tensor\\n\",\n        \"node_item_feat_1_path = os.path.join(base_dir, \\\"node-item-feat-1.pt\\\")\\n\",\n        \"node_item_feat_1 = torch.rand(num_nodes, 5)\\n\",\n        \"print(f\\\"Part of node[item] feature [feat_1]: {node_item_feat_1[:3, :]}\\\")\\n\",\n        \"torch.save(node_item_feat_1, node_item_feat_1_path)\\n\",\n        \"print(f\\\"Node[item] feature [feat_1] is saved to {node_item_feat_1_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Generate edge[user:like:item] feature in numpy array.\\n\",\n        \"edge_like_feat_0_path = os.path.join(base_dir, \\\"edge-like-feat-0.npy\\\")\\n\",\n        \"edge_like_feat_0 = np.random.rand(num_edges, 5)\\n\",\n        \"print(f\\\"Part of edge[user:like:item] feature [feat_0]: {edge_like_feat_0[:3, :]}\\\")\\n\",\n        \"np.save(edge_like_feat_0_path, edge_like_feat_0)\\n\",\n        \"print(f\\\"Edge[user:like:item] feature [feat_0] is saved to {edge_like_feat_0_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Generate another edge[user:like:item] feature in torch tensor\\n\",\n        \"edge_like_feat_1_path = os.path.join(base_dir, \\\"edge-like-feat-1.pt\\\")\\n\",\n        \"edge_like_feat_1 = torch.rand(num_edges, 5)\\n\",\n        \"print(f\\\"Part of edge[user:like:item] feature [feat_1]: {edge_like_feat_1[:3, 
:]}\\\")\\n\",\n        \"torch.save(edge_like_feat_1, edge_like_feat_1_path)\\n\",\n        \"print(f\\\"Edge[user:like:item] feature [feat_1] is saved to {edge_like_feat_1_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Generate edge[user:follow:user] feature in numpy array.\\n\",\n        \"edge_follow_feat_0_path = os.path.join(base_dir, \\\"edge-follow-feat-0.npy\\\")\\n\",\n        \"edge_follow_feat_0 = np.random.rand(num_edges, 5)\\n\",\n        \"print(f\\\"Part of edge[user:follow:user] feature [feat_0]: {edge_follow_feat_0[:3, :]}\\\")\\n\",\n        \"np.save(edge_follow_feat_0_path, edge_follow_feat_0)\\n\",\n        \"print(f\\\"Edge[user:follow:user] feature [feat_0] is saved to {edge_follow_feat_0_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Generate another edge[user:follow:user] feature in torch tensor\\n\",\n        \"edge_follow_feat_1_path = os.path.join(base_dir, \\\"edge-follow-feat-1.pt\\\")\\n\",\n        \"edge_follow_feat_1 = torch.rand(num_edges, 5)\\n\",\n        \"print(f\\\"Part of edge[user:follow:user] feature [feat_1]: {edge_follow_feat_1[:3, :]}\\\")\\n\",\n        \"torch.save(edge_follow_feat_1, edge_follow_feat_1_path)\\n\",\n        \"print(f\\\"Edge[user:follow:user] feature [feat_1] is saved to {edge_follow_feat_1_path}\\\\n\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ZyqgOtsIwzh_\"\n      },\n      \"source\": [\n        \"### Generate tasks\\n\",\n        \"`OnDiskDataset` supports multiple tasks. For each task, we need to prepare training/validation/test sets respectively. Such sets usually vary among different tasks. 
In this tutorial, let's create a **Node Classification** task and **Link Prediction** task.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"hVxHaDIfzCkr\"\n      },\n      \"source\": [\n        \"#### Node Classification Task\\n\",\n        \"For node classification task, we need **node IDs** and corresponding **labels** for each training/validation/test set. Like feature data, numpy arrays and torch tensors are supported for these sets.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"S5-fyBbHzTCO\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# For illustration, let's generate item sets for each node type.\\n\",\n        \"num_trains = int(num_nodes * 0.6)\\n\",\n        \"num_vals = int(num_nodes * 0.2)\\n\",\n        \"num_tests = num_nodes - num_trains - num_vals\\n\",\n        \"\\n\",\n        \"user_ids = np.arange(num_nodes)\\n\",\n        \"np.random.shuffle(user_ids)\\n\",\n        \"\\n\",\n        \"item_ids = np.arange(num_nodes)\\n\",\n        \"np.random.shuffle(item_ids)\\n\",\n        \"\\n\",\n        \"# Train IDs for user.\\n\",\n        \"nc_train_user_ids_path = os.path.join(base_dir, \\\"nc-train-user-ids.npy\\\")\\n\",\n        \"nc_train_user_ids = user_ids[:num_trains]\\n\",\n        \"print(f\\\"Part of train ids[user] for node classification: {nc_train_user_ids[:3]}\\\")\\n\",\n        \"np.save(nc_train_user_ids_path, nc_train_user_ids)\\n\",\n        \"print(f\\\"NC train ids[user] are saved to {nc_train_user_ids_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Train labels for user.\\n\",\n        \"nc_train_user_labels_path = os.path.join(base_dir, \\\"nc-train-user-labels.pt\\\")\\n\",\n        \"nc_train_user_labels = torch.randint(0, 10, (num_trains,))\\n\",\n        \"print(f\\\"Part of train labels[user] for node classification: 
{nc_train_user_labels[:3]}\\\")\\n\",\n        \"torch.save(nc_train_user_labels, nc_train_user_labels_path)\\n\",\n        \"print(f\\\"NC train labels[user] are saved to {nc_train_user_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Train IDs for item.\\n\",\n        \"nc_train_item_ids_path = os.path.join(base_dir, \\\"nc-train-item-ids.npy\\\")\\n\",\n        \"nc_train_item_ids = item_ids[:num_trains]\\n\",\n        \"print(f\\\"Part of train ids[item] for node classification: {nc_train_item_ids[:3]}\\\")\\n\",\n        \"np.save(nc_train_item_ids_path, nc_train_item_ids)\\n\",\n        \"print(f\\\"NC train ids[item] are saved to {nc_train_item_ids_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Train labels for item.\\n\",\n        \"nc_train_item_labels_path = os.path.join(base_dir, \\\"nc-train-item-labels.pt\\\")\\n\",\n        \"nc_train_item_labels = torch.randint(0, 10, (num_trains,))\\n\",\n        \"print(f\\\"Part of train labels[item] for node classification: {nc_train_item_labels[:3]}\\\")\\n\",\n        \"torch.save(nc_train_item_labels, nc_train_item_labels_path)\\n\",\n        \"print(f\\\"NC train labels[item] are saved to {nc_train_item_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Val IDs for user.\\n\",\n        \"nc_val_user_ids_path = os.path.join(base_dir, \\\"nc-val-user-ids.npy\\\")\\n\",\n        \"nc_val_user_ids = user_ids[num_trains:num_trains+num_vals]\\n\",\n        \"print(f\\\"Part of val ids[user] for node classification: {nc_val_user_ids[:3]}\\\")\\n\",\n        \"np.save(nc_val_user_ids_path, nc_val_user_ids)\\n\",\n        \"print(f\\\"NC val ids[user] are saved to {nc_val_user_ids_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Val labels for user.\\n\",\n        \"nc_val_user_labels_path = os.path.join(base_dir, \\\"nc-val-user-labels.pt\\\")\\n\",\n        \"nc_val_user_labels = torch.randint(0, 10, (num_vals,))\\n\",\n        \"print(f\\\"Part of val labels[user] for node 
classification: {nc_val_user_labels[:3]}\\\")\\n\",\n        \"torch.save(nc_val_user_labels, nc_val_user_labels_path)\\n\",\n        \"print(f\\\"NC val labels[user] are saved to {nc_val_user_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Val IDs for item.\\n\",\n        \"nc_val_item_ids_path = os.path.join(base_dir, \\\"nc-val-item-ids.npy\\\")\\n\",\n        \"nc_val_item_ids = item_ids[num_trains:num_trains+num_vals]\\n\",\n        \"print(f\\\"Part of val ids[item] for node classification: {nc_val_item_ids[:3]}\\\")\\n\",\n        \"np.save(nc_val_item_ids_path, nc_val_item_ids)\\n\",\n        \"print(f\\\"NC val ids[item] are saved to {nc_val_item_ids_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Val labels for item.\\n\",\n        \"nc_val_item_labels_path = os.path.join(base_dir, \\\"nc-val-item-labels.pt\\\")\\n\",\n        \"nc_val_item_labels = torch.randint(0, 10, (num_vals,))\\n\",\n        \"print(f\\\"Part of val labels[item] for node classification: {nc_val_item_labels[:3]}\\\")\\n\",\n        \"torch.save(nc_val_item_labels, nc_val_item_labels_path)\\n\",\n        \"print(f\\\"NC val labels[item] are saved to {nc_val_item_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Test IDs for user.\\n\",\n        \"nc_test_user_ids_path = os.path.join(base_dir, \\\"nc-test-user-ids.npy\\\")\\n\",\n        \"nc_test_user_ids = user_ids[-num_tests:]\\n\",\n        \"print(f\\\"Part of test ids[user] for node classification: {nc_test_user_ids[:3]}\\\")\\n\",\n        \"np.save(nc_test_user_ids_path, nc_test_user_ids)\\n\",\n        \"print(f\\\"NC test ids[user] are saved to {nc_test_user_ids_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Test labels for user.\\n\",\n        \"nc_test_user_labels_path = os.path.join(base_dir, \\\"nc-test-user-labels.pt\\\")\\n\",\n        \"nc_test_user_labels = torch.randint(0, 10, (num_tests,))\\n\",\n        \"print(f\\\"Part of test labels[user] for node classification: 
{nc_test_user_labels[:3]}\\\")\\n\",\n        \"torch.save(nc_test_user_labels, nc_test_user_labels_path)\\n\",\n        \"print(f\\\"NC test labels[user] are saved to {nc_test_user_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Test IDs for item.\\n\",\n        \"nc_test_item_ids_path = os.path.join(base_dir, \\\"nc-test-item-ids.npy\\\")\\n\",\n        \"nc_test_item_ids = item_ids[-num_tests:]\\n\",\n        \"print(f\\\"Part of test ids[item] for node classification: {nc_test_item_ids[:3]}\\\")\\n\",\n        \"np.save(nc_test_item_ids_path, nc_test_item_ids)\\n\",\n        \"print(f\\\"NC test ids[item] are saved to {nc_test_item_ids_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Test labels for item.\\n\",\n        \"nc_test_item_labels_path = os.path.join(base_dir, \\\"nc-test-item-labels.pt\\\")\\n\",\n        \"nc_test_item_labels = torch.randint(0, 10, (num_tests,))\\n\",\n        \"print(f\\\"Part of test labels[item] for node classification: {nc_test_item_labels[:3]}\\\")\\n\",\n        \"torch.save(nc_test_item_labels, nc_test_item_labels_path)\\n\",\n        \"print(f\\\"NC test labels[item] are saved to {nc_test_item_labels_path}\\\\n\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"LhAcDCHQ_KJ0\"\n      },\n      \"source\": [\n        \"#### Link Prediction Task\\n\",\n        \"For link prediction task, we need **seeds** or **corresponding labels and indexes** which represent the pos/neg property and group of the seeds for each training/validation/test set. 
Like feature data, numpy arrays and torch tensors are supported for these sets.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"u0jCnXIcAQy4\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# For illustration, let's generate item sets for each edge type.\\n\",\n        \"num_trains = int(num_edges * 0.6)\\n\",\n        \"num_vals = int(num_edges * 0.2)\\n\",\n        \"num_tests = num_edges - num_trains - num_vals\\n\",\n        \"\\n\",\n        \"# Train seeds for user:like:item.\\n\",\n        \"lp_train_like_seeds_path = os.path.join(base_dir, \\\"lp-train-like-seeds.npy\\\")\\n\",\n        \"lp_train_like_seeds = like_edges[:num_trains, :]\\n\",\n        \"print(f\\\"Part of train seeds[user:like:item] for link prediction: {lp_train_like_seeds[:3]}\\\")\\n\",\n        \"np.save(lp_train_like_seeds_path, lp_train_like_seeds)\\n\",\n        \"print(f\\\"LP train seeds[user:like:item] are saved to {lp_train_like_seeds_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Train seeds for user:follow:user.\\n\",\n        \"lp_train_follow_seeds_path = os.path.join(base_dir, \\\"lp-train-follow-seeds.npy\\\")\\n\",\n        \"lp_train_follow_seeds = follow_edges[:num_trains, :]\\n\",\n        \"print(f\\\"Part of train seeds[user:follow:user] for link prediction: {lp_train_follow_seeds[:3]}\\\")\\n\",\n        \"np.save(lp_train_follow_seeds_path, lp_train_follow_seeds)\\n\",\n        \"print(f\\\"LP train seeds[user:follow:user] are saved to {lp_train_follow_seeds_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Val seeds for user:like:item.\\n\",\n        \"lp_val_like_seeds_path = os.path.join(base_dir, \\\"lp-val-like-seeds.npy\\\")\\n\",\n        \"lp_val_like_seeds = like_edges[num_trains:num_trains+num_vals, :]\\n\",\n        \"lp_val_like_neg_dsts = np.random.randint(0, num_nodes, (num_vals, 10)).reshape(-1)\\n\",\n        \"lp_val_like_neg_srcs = 
np.repeat(lp_val_like_seeds[:,0], 10)\\n\",\n        \"lp_val_like_neg_seeds = np.concatenate((lp_val_like_neg_srcs, lp_val_like_neg_dsts)).reshape(2,-1).T\\n\",\n        \"lp_val_like_seeds = np.concatenate((lp_val_like_seeds, lp_val_like_neg_seeds))\\n\",\n        \"print(f\\\"Part of val seeds[user:like:item] for link prediction: {lp_val_like_seeds[:3]}\\\")\\n\",\n        \"np.save(lp_val_like_seeds_path, lp_val_like_seeds)\\n\",\n        \"print(f\\\"LP val seeds[user:like:item] are saved to {lp_val_like_seeds_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Val labels for user:like:item.\\n\",\n        \"lp_val_like_labels_path = os.path.join(base_dir, \\\"lp-val-like-labels.npy\\\")\\n\",\n        \"lp_val_like_labels = np.empty(num_vals * (10 + 1))\\n\",\n        \"lp_val_like_labels[:num_vals] = 1\\n\",\n        \"lp_val_like_labels[num_vals:] = 0\\n\",\n        \"print(f\\\"Part of val labels[user:like:item] for link prediction: {lp_val_like_labels[:3]}\\\")\\n\",\n        \"np.save(lp_val_like_labels_path, lp_val_like_labels)\\n\",\n        \"print(f\\\"LP val labels[user:like:item] are saved to {lp_val_like_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Val indexes for user:like:item.\\n\",\n        \"lp_val_like_indexes_path = os.path.join(base_dir, \\\"lp-val-like-indexes.npy\\\")\\n\",\n        \"lp_val_like_indexes = np.arange(0, num_vals)\\n\",\n        \"lp_val_like_neg_indexes = np.repeat(lp_val_like_indexes, 10)\\n\",\n        \"lp_val_like_indexes = np.concatenate([lp_val_like_indexes, lp_val_like_neg_indexes])\\n\",\n        \"print(f\\\"Part of val indexes[user:like:item] for link prediction: {lp_val_like_indexes[:3]}\\\")\\n\",\n        \"np.save(lp_val_like_indexes_path, lp_val_like_indexes)\\n\",\n        \"print(f\\\"LP val indexes[user:like:item] are saved to {lp_val_like_indexes_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Val seeds for user:follow:item.\\n\",\n        \"lp_val_follow_seeds_path = 
os.path.join(base_dir, \\\"lp-val-follow-seeds.npy\\\")\\n\",\n        \"lp_val_follow_seeds = follow_edges[num_trains:num_trains+num_vals, :]\\n\",\n        \"lp_val_follow_neg_dsts = np.random.randint(0, num_nodes, (num_vals, 10)).reshape(-1)\\n\",\n        \"lp_val_follow_neg_srcs = np.repeat(lp_val_follow_seeds[:,0], 10)\\n\",\n        \"lp_val_follow_neg_seeds = np.concatenate((lp_val_follow_neg_srcs, lp_val_follow_neg_dsts)).reshape(2,-1).T\\n\",\n        \"lp_val_follow_seeds = np.concatenate((lp_val_follow_seeds, lp_val_follow_neg_seeds))\\n\",\n        \"print(f\\\"Part of val seeds[user:follow:item] for link prediction: {lp_val_follow_seeds[:3]}\\\")\\n\",\n        \"np.save(lp_val_follow_seeds_path, lp_val_follow_seeds)\\n\",\n        \"print(f\\\"LP val seeds[user:follow:item] are saved to {lp_val_follow_seeds_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Val labels for user:follow:item.\\n\",\n        \"lp_val_follow_labels_path = os.path.join(base_dir, \\\"lp-val-follow-labels.npy\\\")\\n\",\n        \"lp_val_follow_labels = np.empty(num_vals * (10 + 1))\\n\",\n        \"lp_val_follow_labels[:num_vals] = 1\\n\",\n        \"lp_val_follow_labels[num_vals:] = 0\\n\",\n        \"print(f\\\"Part of val labels[user:follow:item] for link prediction: {lp_val_follow_labels[:3]}\\\")\\n\",\n        \"np.save(lp_val_follow_labels_path, lp_val_follow_labels)\\n\",\n        \"print(f\\\"LP val labels[user:follow:item] are saved to {lp_val_follow_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Val indexes for user:follow:item.\\n\",\n        \"lp_val_follow_indexes_path = os.path.join(base_dir, \\\"lp-val-follow-indexes.npy\\\")\\n\",\n        \"lp_val_follow_indexes = np.arange(0, num_vals)\\n\",\n        \"lp_val_follow_neg_indexes = np.repeat(lp_val_follow_indexes, 10)\\n\",\n        \"lp_val_follow_indexes = np.concatenate([lp_val_follow_indexes, lp_val_follow_neg_indexes])\\n\",\n        \"print(f\\\"Part of val indexes[user:follow:item] for 
link prediction: {lp_val_follow_indexes[:3]}\\\")\\n\",\n        \"np.save(lp_val_follow_indexes_path, lp_val_follow_indexes)\\n\",\n        \"print(f\\\"LP val indexes[user:follow:item] are saved to {lp_val_follow_indexes_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Test seeds for user:like:item.\\n\",\n        \"lp_test_like_seeds_path = os.path.join(base_dir, \\\"lp-test-like-seeds.npy\\\")\\n\",\n        \"lp_test_like_seeds = like_edges[-num_tests:, :]\\n\",\n        \"lp_test_like_neg_dsts = np.random.randint(0, num_nodes, (num_tests, 10)).reshape(-1)\\n\",\n        \"lp_test_like_neg_srcs = np.repeat(lp_test_like_seeds[:,0], 10)\\n\",\n        \"lp_test_like_neg_seeds = np.concatenate((lp_test_like_neg_srcs, lp_test_like_neg_dsts)).reshape(2,-1).T\\n\",\n        \"lp_test_like_seeds = np.concatenate((lp_test_like_seeds, lp_test_like_neg_seeds))\\n\",\n        \"print(f\\\"Part of test seeds[user:like:item] for link prediction: {lp_test_like_seeds[:3]}\\\")\\n\",\n        \"np.save(lp_test_like_seeds_path, lp_test_like_seeds)\\n\",\n        \"print(f\\\"LP test seeds[user:like:item] are saved to {lp_test_like_seeds_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Test labels for user:like:item.\\n\",\n        \"lp_test_like_labels_path = os.path.join(base_dir, \\\"lp-test-like-labels.npy\\\")\\n\",\n        \"lp_test_like_labels = np.empty(num_tests * (10 + 1))\\n\",\n        \"lp_test_like_labels[:num_tests] = 1\\n\",\n        \"lp_test_like_labels[num_tests:] = 0\\n\",\n        \"print(f\\\"Part of test labels[user:like:item] for link prediction: {lp_test_like_labels[:3]}\\\")\\n\",\n        \"np.save(lp_test_like_labels_path, lp_test_like_labels)\\n\",\n        \"print(f\\\"LP test labels[user:like:item] are saved to {lp_test_like_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Test indexes for user:like:item.\\n\",\n        \"lp_test_like_indexes_path = os.path.join(base_dir, \\\"lp-test-like-indexes.npy\\\")\\n\",\n        
\"lp_test_like_indexes = np.arange(0, num_tests)\\n\",\n        \"lp_test_like_neg_indexes = np.repeat(lp_test_like_indexes, 10)\\n\",\n        \"lp_test_like_indexes = np.concatenate([lp_test_like_indexes, lp_test_like_neg_indexes])\\n\",\n        \"print(f\\\"Part of test indexes[user:like:item] for link prediction: {lp_test_like_indexes[:3]}\\\")\\n\",\n        \"np.save(lp_test_like_indexes_path, lp_test_like_indexes)\\n\",\n        \"print(f\\\"LP test indexes[user:like:item] are saved to {lp_test_like_indexes_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Test seeds for user:follow:item.\\n\",\n        \"lp_test_follow_seeds_path = os.path.join(base_dir, \\\"lp-test-follow-seeds.npy\\\")\\n\",\n        \"lp_test_follow_seeds = follow_edges[-num_tests:, :]\\n\",\n        \"lp_test_follow_neg_dsts = np.random.randint(0, num_nodes, (num_tests, 10)).reshape(-1)\\n\",\n        \"lp_test_follow_neg_srcs = np.repeat(lp_test_follow_seeds[:,0], 10)\\n\",\n        \"lp_test_follow_neg_seeds = np.concatenate((lp_test_follow_neg_srcs, lp_test_follow_neg_dsts)).reshape(2,-1).T\\n\",\n        \"lp_test_follow_seeds = np.concatenate((lp_test_follow_seeds, lp_test_follow_neg_seeds))\\n\",\n        \"print(f\\\"Part of test seeds[user:follow:item] for link prediction: {lp_test_follow_seeds[:3]}\\\")\\n\",\n        \"np.save(lp_test_follow_seeds_path, lp_test_follow_seeds)\\n\",\n        \"print(f\\\"LP test seeds[user:follow:item] are saved to {lp_test_follow_seeds_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Test labels for user:follow:item.\\n\",\n        \"lp_test_follow_labels_path = os.path.join(base_dir, \\\"lp-test-follow-labels.npy\\\")\\n\",\n        \"lp_test_follow_labels = np.empty(num_tests * (10 + 1))\\n\",\n        \"lp_test_follow_labels[:num_tests] = 1\\n\",\n        \"lp_test_follow_labels[num_tests:] = 0\\n\",\n        \"print(f\\\"Part of test labels[user:follow:item] for link prediction: {lp_test_follow_labels[:3]}\\\")\\n\",\n        
\"np.save(lp_test_follow_labels_path, lp_test_follow_labels)\\n\",\n        \"print(f\\\"LP test labels[user:follow:item] are saved to {lp_test_follow_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Test indexes for user:follow:item.\\n\",\n        \"lp_test_follow_indexes_path = os.path.join(base_dir, \\\"lp-test-follow-indexes.npy\\\")\\n\",\n        \"lp_test_follow_indexes = np.arange(0, num_tests)\\n\",\n        \"lp_test_follow_neg_indexes = np.repeat(lp_test_follow_indexes, 10)\\n\",\n        \"lp_test_follow_indexes = np.concatenate([lp_test_follow_indexes, lp_test_follow_neg_indexes])\\n\",\n        \"print(f\\\"Part of test indexes[user:follow:item] for link prediction: {lp_test_follow_indexes[:3]}\\\")\\n\",\n        \"np.save(lp_test_follow_indexes_path, lp_test_follow_indexes)\\n\",\n        \"print(f\\\"LP test indexes[user:follow:item] are saved to {lp_test_follow_indexes_path}\\\\n\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"wbk6-wxRK-6S\"\n      },\n      \"source\": [\n        \"## Organize Data into YAML File\\n\",\n        \"Now we need to create a `metadata.yaml` file which contains the paths, data types of graph structure, feature data, training/validation/test sets. Please note that all paths should be relative to `metadata.yaml`.\\n\",\n        \"\\n\",\n        \"For heterogeneous graph, we need to specify the node/edge type in **type** fields. For edge type, canonical etype is required which is a string that's concatenated by source node type, etype, and destination node type together with `:`.\\n\",\n        \"\\n\",\n        \"Notes:\\n\",\n        \"- all paths should be relative to `metadata.yaml`.\\n\",\n        \"- Below fields are optional and not specified in below example.\\n\",\n        \"  - `in_memory`: indicates whether to load data into memory or `mmap`. 
Default is `True`.\\n\",\n        \"\\n\",\n        \"Please refer to [YAML specification](https://github.com/dmlc/dgl/blob/master/docs/source/stochastic_training/ondisk-dataset-specification.rst) for more details.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"ddGTWW61Lpwp\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"yaml_content = f\\\"\\\"\\\"\\n\",\n        \"    dataset_name: heterogeneous_graph_nc_lp\\n\",\n        \"    graph:\\n\",\n        \"      nodes:\\n\",\n        \"        - type: user\\n\",\n        \"          num: {num_nodes}\\n\",\n        \"        - type: item\\n\",\n        \"          num: {num_nodes}\\n\",\n        \"      edges:\\n\",\n        \"        - type: \\\"user:like:item\\\"\\n\",\n        \"          format: csv\\n\",\n        \"          path: {os.path.basename(like_edges_path)}\\n\",\n        \"        - type: \\\"user:follow:user\\\"\\n\",\n        \"          format: csv\\n\",\n        \"          path: {os.path.basename(follow_edges_path)}\\n\",\n        \"    feature_data:\\n\",\n        \"      - domain: node\\n\",\n        \"        type: user\\n\",\n        \"        name: feat_0\\n\",\n        \"        format: numpy\\n\",\n        \"        path: {os.path.basename(node_user_feat_0_path)}\\n\",\n        \"      - domain: node\\n\",\n        \"        type: user\\n\",\n        \"        name: feat_1\\n\",\n        \"        format: torch\\n\",\n        \"        path: {os.path.basename(node_user_feat_1_path)}\\n\",\n        \"      - domain: node\\n\",\n        \"        type: item\\n\",\n        \"        name: feat_0\\n\",\n        \"        format: numpy\\n\",\n        \"        path: {os.path.basename(node_item_feat_0_path)}\\n\",\n        \"      - domain: node\\n\",\n        \"        type: item\\n\",\n        \"        name: feat_1\\n\",\n        \"        format: torch\\n\",\n        \"        path: 
{os.path.basename(node_item_feat_1_path)}\\n\",\n        \"      - domain: edge\\n\",\n        \"        type: \\\"user:like:item\\\"\\n\",\n        \"        name: feat_0\\n\",\n        \"        format: numpy\\n\",\n        \"        path: {os.path.basename(edge_like_feat_0_path)}\\n\",\n        \"      - domain: edge\\n\",\n        \"        type: \\\"user:like:item\\\"\\n\",\n        \"        name: feat_1\\n\",\n        \"        format: torch\\n\",\n        \"        path: {os.path.basename(edge_like_feat_1_path)}\\n\",\n        \"      - domain: edge\\n\",\n        \"        type: \\\"user:follow:user\\\"\\n\",\n        \"        name: feat_0\\n\",\n        \"        format: numpy\\n\",\n        \"        path: {os.path.basename(edge_follow_feat_0_path)}\\n\",\n        \"      - domain: edge\\n\",\n        \"        type: \\\"user:follow:user\\\"\\n\",\n        \"        name: feat_1\\n\",\n        \"        format: torch\\n\",\n        \"        path: {os.path.basename(edge_follow_feat_1_path)}\\n\",\n        \"    tasks:\\n\",\n        \"      - name: node_classification\\n\",\n        \"        num_classes: 10\\n\",\n        \"        train_set:\\n\",\n        \"          - type: user\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(nc_train_user_ids_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: torch\\n\",\n        \"                path: {os.path.basename(nc_train_user_labels_path)}\\n\",\n        \"          - type: item\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(nc_train_item_ids_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: torch\\n\",\n        \"                path: 
{os.path.basename(nc_train_item_labels_path)}\\n\",\n        \"        validation_set:\\n\",\n        \"          - type: user\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(nc_val_user_ids_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: torch\\n\",\n        \"                path: {os.path.basename(nc_val_user_labels_path)}\\n\",\n        \"          - type: item\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(nc_val_item_ids_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: torch\\n\",\n        \"                path: {os.path.basename(nc_val_item_labels_path)}\\n\",\n        \"        test_set:\\n\",\n        \"          - type: user\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(nc_test_user_ids_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: torch\\n\",\n        \"                path: {os.path.basename(nc_test_user_labels_path)}\\n\",\n        \"          - type: item\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(nc_test_item_ids_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: torch\\n\",\n        \"                path: {os.path.basename(nc_test_item_labels_path)}\\n\",\n        \"      - name: link_prediction\\n\",\n        \"        num_classes: 10\\n\",\n        \"        train_set:\\n\",\n        \"          - type: 
\\\"user:like:item\\\"\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_train_like_seeds_path)}\\n\",\n        \"          - type: \\\"user:follow:user\\\"\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_train_follow_seeds_path)}\\n\",\n        \"        validation_set:\\n\",\n        \"          - type: \\\"user:like:item\\\"\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_val_like_seeds_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_val_like_labels_path)}\\n\",\n        \"              - name: indexes\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_val_like_indexes_path)}\\n\",\n        \"          - type: \\\"user:follow:user\\\"\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_val_follow_seeds_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_val_follow_labels_path)}\\n\",\n        \"              - name: indexes\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_val_follow_indexes_path)}\\n\",\n        \"        test_set:\\n\",\n        \"          - type: \\\"user:like:item\\\"\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"            
    format: numpy\\n\",\n        \"                path: {os.path.basename(lp_test_like_seeds_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_test_like_labels_path)}\\n\",\n        \"              - name: indexes\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_test_like_indexes_path)}\\n\",\n        \"          - type: \\\"user:follow:user\\\"\\n\",\n        \"            data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_test_follow_seeds_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_test_follow_labels_path)}\\n\",\n        \"              - name: indexes\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_test_follow_indexes_path)}\\n\",\n        \"\\\"\\\"\\\"\\n\",\n        \"metadata_path = os.path.join(base_dir, \\\"metadata.yaml\\\")\\n\",\n        \"with open(metadata_path, \\\"w\\\") as f:\\n\",\n        \"  f.write(yaml_content)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"kEfybHGhOW7O\"\n      },\n      \"source\": [\n        \"## Instantiate `OnDiskDataset`\\n\",\n        \"Now we're ready to load dataset via `dgl.graphbolt.OnDiskDataset`. When instantiating, we just pass in the base directory where `metadata.yaml` file lies.\\n\",\n        \"\\n\",\n        \"During first instantiation, GraphBolt preprocesses the raw data such as constructing `FusedCSCSamplingGraph` from edges. All data including graph, feature data, training/validation/test sets are put into `preprocessed` directory after preprocessing. 
Any following dataset loading will skip the preprocess stage.\\n\",\n        \"\\n\",\n        \"After preprocessing, `load()` is required to be called explicitly in order to load graph, feature data and tasks.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"W58CZoSzOiyo\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"dataset = gb.OnDiskDataset(base_dir).load()\\n\",\n        \"graph = dataset.graph\\n\",\n        \"print(f\\\"Loaded graph: {graph}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"feature = dataset.feature\\n\",\n        \"print(f\\\"Loaded feature store: {feature}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"tasks = dataset.tasks\\n\",\n        \"nc_task = tasks[0]\\n\",\n        \"print(f\\\"Loaded node classification task: {nc_task}\\\\n\\\")\\n\",\n        \"lp_task = tasks[1]\\n\",\n        \"print(f\\\"Loaded link prediction task: {lp_task}\\\\n\\\")\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"private_outputs\": true,\n      \"provenance\": []\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      \"file_extension\": \".py\",\n      \"mimetype\": \"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.10.12\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "notebooks/stochastic_training/ondisk_dataset_homograph.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"FnFhPMaAfLtJ\"\n      },\n      \"source\": [\n        \"# OnDiskDataset for Homogeneous Graph\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/stochastic_training/ondisk_dataset_homograph.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/stochastic_training/ondisk_dataset_homograph.ipynb)\\n\",\n        \"\\n\",\n        \"This tutorial shows how to create `OnDiskDataset` for homogeneous graph that could be used in **GraphBolt** framework.\\n\",\n        \"\\n\",\n        \"By the end of this tutorial, you will be able to\\n\",\n        \"\\n\",\n        \"- organize graph structure data.\\n\",\n        \"- organize feature data.\\n\",\n        \"- organize training/validation/test set for specific tasks.\\n\",\n        \"\\n\",\n        \"To create an ``OnDiskDataset`` object, you need to organize all the data including graph structure, feature data and tasks into a directory. 
The directory should contain a ``metadata.yaml`` file that describes the metadata of the dataset.\\n\",\n        \"\\n\",\n        \"Now let's generate various data step by step and organize them together to instantiate `OnDiskDataset` finally.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Wlb19DtWgtzq\"\n      },\n      \"source\": [\n        \"## Install DGL package\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"UojlT9ZGgyr9\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Install required packages.\\n\",\n        \"import os\\n\",\n        \"import torch\\n\",\n        \"import numpy as np\\n\",\n        \"os.environ['TORCH'] = torch.__version__\\n\",\n        \"os.environ['DGLBACKEND'] = \\\"pytorch\\\"\\n\",\n        \"\\n\",\n        \"# Install the CPU version.\\n\",\n        \"device = torch.device(\\\"cpu\\\")\\n\",\n        \"!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\\n\",\n        \"\\n\",\n        \"try:\\n\",\n        \"    import dgl\\n\",\n        \"    import dgl.graphbolt as gb\\n\",\n        \"    installed = True\\n\",\n        \"except ImportError as error:\\n\",\n        \"    installed = False\\n\",\n        \"    print(error)\\n\",\n        \"print(\\\"DGL installed!\\\" if installed else \\\"DGL not found!\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"2R7WnSbjsfbr\"\n      },\n      \"source\": [\n        \"## Data preparation\\n\",\n        \"In order to demonstrate how to organize various data, let's create a base directory first.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"SZipbzyltLfO\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"base_dir = './ondisk_dataset_homograph'\\n\",\n  
      \"os.makedirs(base_dir, exist_ok=True)\\n\",\n        \"print(f\\\"Created base directory: {base_dir}\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"qhNtIn_xhlnl\"\n      },\n      \"source\": [\n        \"### Generate graph structure data\\n\",\n        \"For homogeneous graph, we just need to save edges(namely seeds) into  **Numpy** or **CSV** file.\\n\",\n        \"\\n\",\n        \"Note:\\n\",\n        \"- when saving to **Numpy**, the array requires to be in shape of `(2, N)`. This format is recommended as constructing graph from it is much faster than **CSV** file.\\n\",\n        \"- when saving to **CSV** file, do not save index and header.\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"HcBt4G5BmSjr\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import numpy as np\\n\",\n        \"import pandas as pd\\n\",\n        \"num_nodes = 1000\\n\",\n        \"num_edges = 10 * num_nodes\\n\",\n        \"edges_path = os.path.join(base_dir, \\\"edges.csv\\\")\\n\",\n        \"edges = np.random.randint(0, num_nodes, size=(num_edges, 2))\\n\",\n        \"\\n\",\n        \"print(f\\\"Part of edges: {edges[:5, :]}\\\")\\n\",\n        \"\\n\",\n        \"df = pd.DataFrame(edges)\\n\",\n        \"df.to_csv(edges_path, index=False, header=False)\\n\",\n        \"\\n\",\n        \"print(f\\\"Edges are saved into {edges_path}\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"kh-4cPtzpcaH\"\n      },\n      \"source\": [\n        \"### Generate feature data for graph\\n\",\n        \"For feature data, numpy arrays and torch tensors are supported for now.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"_PVu1u5brBhF\"\n      },\n      \"outputs\": [],\n      
\"source\": [\n        \"# Generate node feature in numpy array.\\n\",\n        \"node_feat_0_path = os.path.join(base_dir, \\\"node-feat-0.npy\\\")\\n\",\n        \"node_feat_0 = np.random.rand(num_nodes, 5)\\n\",\n        \"print(f\\\"Part of node feature [feat_0]: {node_feat_0[:3, :]}\\\")\\n\",\n        \"np.save(node_feat_0_path, node_feat_0)\\n\",\n        \"print(f\\\"Node feature [feat_0] is saved to {node_feat_0_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Generate another node feature in torch tensor\\n\",\n        \"node_feat_1_path = os.path.join(base_dir, \\\"node-feat-1.pt\\\")\\n\",\n        \"node_feat_1 = torch.rand(num_nodes, 5)\\n\",\n        \"print(f\\\"Part of node feature [feat_1]: {node_feat_1[:3, :]}\\\")\\n\",\n        \"torch.save(node_feat_1, node_feat_1_path)\\n\",\n        \"print(f\\\"Node feature [feat_1] is saved to {node_feat_1_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Generate edge feature in numpy array.\\n\",\n        \"edge_feat_0_path = os.path.join(base_dir, \\\"edge-feat-0.npy\\\")\\n\",\n        \"edge_feat_0 = np.random.rand(num_edges, 5)\\n\",\n        \"print(f\\\"Part of edge feature [feat_0]: {edge_feat_0[:3, :]}\\\")\\n\",\n        \"np.save(edge_feat_0_path, edge_feat_0)\\n\",\n        \"print(f\\\"Edge feature [feat_0] is saved to {edge_feat_0_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"# Generate another edge feature in torch tensor\\n\",\n        \"edge_feat_1_path = os.path.join(base_dir, \\\"edge-feat-1.pt\\\")\\n\",\n        \"edge_feat_1 = torch.rand(num_edges, 5)\\n\",\n        \"print(f\\\"Part of edge feature [feat_1]: {edge_feat_1[:3, :]}\\\")\\n\",\n        \"torch.save(edge_feat_1, edge_feat_1_path)\\n\",\n        \"print(f\\\"Edge feature [feat_1] is saved to {edge_feat_1_path}\\\\n\\\")\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"ZyqgOtsIwzh_\"\n      },\n      \"source\": [\n        \"### Generate tasks\\n\",\n    
    \"`OnDiskDataset` supports multiple tasks. For each task, we need to prepare training/validation/test sets respectively. Such sets usually vary among different tasks. In this tutorial, let's create a **Node Classification** task and **Link Prediction** task.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"hVxHaDIfzCkr\"\n      },\n      \"source\": [\n        \"#### Node Classification Task\\n\",\n        \"For node classification task, we need **node IDs** and corresponding **labels** for each training/validation/test set. Like feature data, numpy arrays and torch tensors are supported for these sets.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"S5-fyBbHzTCO\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"num_trains = int(num_nodes * 0.6)\\n\",\n        \"num_vals = int(num_nodes * 0.2)\\n\",\n        \"num_tests = num_nodes - num_trains - num_vals\\n\",\n        \"\\n\",\n        \"ids = np.arange(num_nodes)\\n\",\n        \"np.random.shuffle(ids)\\n\",\n        \"\\n\",\n        \"nc_train_ids_path = os.path.join(base_dir, \\\"nc-train-ids.npy\\\")\\n\",\n        \"nc_train_ids = ids[:num_trains]\\n\",\n        \"print(f\\\"Part of train ids for node classification: {nc_train_ids[:3]}\\\")\\n\",\n        \"np.save(nc_train_ids_path, nc_train_ids)\\n\",\n        \"print(f\\\"NC train ids are saved to {nc_train_ids_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"nc_train_labels_path = os.path.join(base_dir, \\\"nc-train-labels.pt\\\")\\n\",\n        \"nc_train_labels = torch.randint(0, 10, (num_trains,))\\n\",\n        \"print(f\\\"Part of train labels for node classification: {nc_train_labels[:3]}\\\")\\n\",\n        \"torch.save(nc_train_labels, nc_train_labels_path)\\n\",\n        \"print(f\\\"NC train labels are saved to {nc_train_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        
\"nc_val_ids_path = os.path.join(base_dir, \\\"nc-val-ids.npy\\\")\\n\",\n        \"nc_val_ids = ids[num_trains:num_trains+num_vals]\\n\",\n        \"print(f\\\"Part of val ids for node classification: {nc_val_ids[:3]}\\\")\\n\",\n        \"np.save(nc_val_ids_path, nc_val_ids)\\n\",\n        \"print(f\\\"NC val ids are saved to {nc_val_ids_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"nc_val_labels_path = os.path.join(base_dir, \\\"nc-val-labels.pt\\\")\\n\",\n        \"nc_val_labels = torch.randint(0, 10, (num_vals,))\\n\",\n        \"print(f\\\"Part of val labels for node classification: {nc_val_labels[:3]}\\\")\\n\",\n        \"torch.save(nc_val_labels, nc_val_labels_path)\\n\",\n        \"print(f\\\"NC val labels are saved to {nc_val_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"nc_test_ids_path = os.path.join(base_dir, \\\"nc-test-ids.npy\\\")\\n\",\n        \"nc_test_ids = ids[-num_tests:]\\n\",\n        \"print(f\\\"Part of test ids for node classification: {nc_test_ids[:3]}\\\")\\n\",\n        \"np.save(nc_test_ids_path, nc_test_ids)\\n\",\n        \"print(f\\\"NC test ids are saved to {nc_test_ids_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"nc_test_labels_path = os.path.join(base_dir, \\\"nc-test-labels.pt\\\")\\n\",\n        \"nc_test_labels = torch.randint(0, 10, (num_tests,))\\n\",\n        \"print(f\\\"Part of test labels for node classification: {nc_test_labels[:3]}\\\")\\n\",\n        \"torch.save(nc_test_labels, nc_test_labels_path)\\n\",\n        \"print(f\\\"NC test labels are saved to {nc_test_labels_path}\\\\n\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"LhAcDCHQ_KJ0\"\n      },\n      \"source\": [\n        \"#### Link Prediction Task\\n\",\n        \"For link prediction task, we need **seeds** or **corresponding labels and indexes** which represent the pos/neg property and group of the seeds for each training/validation/test set. 
Like feature data, numpy arrays and torch tensors are supported for these sets.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"u0jCnXIcAQy4\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"num_trains = int(num_edges * 0.6)\\n\",\n        \"num_vals = int(num_edges * 0.2)\\n\",\n        \"num_tests = num_edges - num_trains - num_vals\\n\",\n        \"\\n\",\n        \"lp_train_seeds_path = os.path.join(base_dir, \\\"lp-train-seeds.npy\\\")\\n\",\n        \"lp_train_seeds = edges[:num_trains, :]\\n\",\n        \"print(f\\\"Part of train seeds for link prediction: {lp_train_seeds[:3]}\\\")\\n\",\n        \"np.save(lp_train_seeds_path, lp_train_seeds)\\n\",\n        \"print(f\\\"LP train seeds are saved to {lp_train_seeds_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"lp_val_seeds_path = os.path.join(base_dir, \\\"lp-val-seeds.npy\\\")\\n\",\n        \"lp_val_seeds = edges[num_trains:num_trains+num_vals, :]\\n\",\n        \"lp_val_neg_dsts = np.random.randint(0, num_nodes, (num_vals, 10)).reshape(-1)\\n\",\n        \"lp_val_neg_srcs = np.repeat(lp_val_seeds[:,0], 10)\\n\",\n        \"lp_val_neg_seeds = np.concatenate((lp_val_neg_srcs, lp_val_neg_dsts)).reshape(2,-1).T\\n\",\n        \"lp_val_seeds = np.concatenate((lp_val_seeds, lp_val_neg_seeds))\\n\",\n        \"print(f\\\"Part of val seeds for link prediction: {lp_val_seeds[:3]}\\\")\\n\",\n        \"np.save(lp_val_seeds_path, lp_val_seeds)\\n\",\n        \"print(f\\\"LP val seeds are saved to {lp_val_seeds_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"lp_val_labels_path = os.path.join(base_dir, \\\"lp-val-labels.npy\\\")\\n\",\n        \"lp_val_labels = np.empty(num_vals * (10 + 1))\\n\",\n        \"lp_val_labels[:num_vals] = 1\\n\",\n        \"lp_val_labels[num_vals:] = 0\\n\",\n        \"print(f\\\"Part of val labels for link prediction: {lp_val_labels[:3]}\\\")\\n\",\n        
\"np.save(lp_val_labels_path, lp_val_labels)\\n\",\n        \"print(f\\\"LP val labels are saved to {lp_val_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"lp_val_indexes_path = os.path.join(base_dir, \\\"lp-val-indexes.npy\\\")\\n\",\n        \"lp_val_indexes = np.arange(0, num_vals)\\n\",\n        \"lp_val_neg_indexes = np.repeat(lp_val_indexes, 10)\\n\",\n        \"lp_val_indexes = np.concatenate([lp_val_indexes, lp_val_neg_indexes])\\n\",\n        \"print(f\\\"Part of val indexes for link prediction: {lp_val_indexes[:3]}\\\")\\n\",\n        \"np.save(lp_val_indexes_path, lp_val_indexes)\\n\",\n        \"print(f\\\"LP val indexes are saved to {lp_val_indexes_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"lp_test_seeds_path = os.path.join(base_dir, \\\"lp-test-seeds.npy\\\")\\n\",\n        \"lp_test_seeds = edges[-num_tests:, :]\\n\",\n        \"lp_test_neg_dsts = np.random.randint(0, num_nodes, (num_tests, 10)).reshape(-1)\\n\",\n        \"lp_test_neg_srcs = np.repeat(lp_test_seeds[:,0], 10)\\n\",\n        \"lp_test_neg_seeds = np.concatenate((lp_test_neg_srcs, lp_test_neg_dsts)).reshape(2,-1).T\\n\",\n        \"lp_test_seeds = np.concatenate((lp_test_seeds, lp_test_neg_seeds))\\n\",\n        \"print(f\\\"Part of test seeds for link prediction: {lp_test_seeds[:3]}\\\")\\n\",\n        \"np.save(lp_test_seeds_path, lp_test_seeds)\\n\",\n        \"print(f\\\"LP test seeds are saved to {lp_test_seeds_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"lp_test_labels_path = os.path.join(base_dir, \\\"lp-test-labels.npy\\\")\\n\",\n        \"lp_test_labels = np.empty(num_tests * (10 + 1))\\n\",\n        \"lp_test_labels[:num_tests] = 1\\n\",\n        \"lp_test_labels[num_tests:] = 0\\n\",\n        \"print(f\\\"Part of val labels for link prediction: {lp_test_labels[:3]}\\\")\\n\",\n        \"np.save(lp_test_labels_path, lp_test_labels)\\n\",\n        \"print(f\\\"LP test labels are saved to {lp_test_labels_path}\\\\n\\\")\\n\",\n        \"\\n\",\n        
\"lp_test_indexes_path = os.path.join(base_dir, \\\"lp-test-indexes.npy\\\")\\n\",\n        \"lp_test_indexes = np.arange(0, num_tests)\\n\",\n        \"lp_test_neg_indexes = np.repeat(lp_test_indexes, 10)\\n\",\n        \"lp_test_indexes = np.concatenate([lp_test_indexes, lp_test_neg_indexes])\\n\",\n        \"print(f\\\"Part of test indexes for link prediction: {lp_test_indexes[:3]}\\\")\\n\",\n        \"np.save(lp_test_indexes_path, lp_test_indexes)\\n\",\n        \"print(f\\\"LP test indexes are saved to {lp_test_indexes_path}\\\\n\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"wbk6-wxRK-6S\"\n      },\n      \"source\": [\n        \"## Organize Data into YAML File\\n\",\n        \"Now we need to create a `metadata.yaml` file which contains the paths, data types of graph structure, feature data, training/validation/test sets.\\n\",\n        \"\\n\",\n        \"Notes:\\n\",\n        \"- all paths should be relative to `metadata.yaml`.\\n\",\n        \"- Below fields are optional and not specified in below example.\\n\",\n        \"  - `in_memory`: indicates whether to load data into memory or `mmap`. 
Default is `True`.\\n\",\n        \"\\n\",\n        \"Please refer to [YAML specification](https://github.com/dmlc/dgl/blob/master/docs/source/stochastic_training/ondisk-dataset-specification.rst) for more details.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"ddGTWW61Lpwp\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"yaml_content = f\\\"\\\"\\\"\\n\",\n        \"    dataset_name: homogeneous_graph_nc_lp\\n\",\n        \"    graph:\\n\",\n        \"      nodes:\\n\",\n        \"        - num: {num_nodes}\\n\",\n        \"      edges:\\n\",\n        \"        - format: csv\\n\",\n        \"          path: {os.path.basename(edges_path)}\\n\",\n        \"    feature_data:\\n\",\n        \"      - domain: node\\n\",\n        \"        name: feat_0\\n\",\n        \"        format: numpy\\n\",\n        \"        path: {os.path.basename(node_feat_0_path)}\\n\",\n        \"      - domain: node\\n\",\n        \"        name: feat_1\\n\",\n        \"        format: torch\\n\",\n        \"        path: {os.path.basename(node_feat_1_path)}\\n\",\n        \"      - domain: edge\\n\",\n        \"        name: feat_0\\n\",\n        \"        format: numpy\\n\",\n        \"        path: {os.path.basename(edge_feat_0_path)}\\n\",\n        \"      - domain: edge\\n\",\n        \"        name: feat_1\\n\",\n        \"        format: torch\\n\",\n        \"        path: {os.path.basename(edge_feat_1_path)}\\n\",\n        \"    tasks:\\n\",\n        \"      - name: node_classification\\n\",\n        \"        num_classes: 10\\n\",\n        \"        train_set:\\n\",\n        \"          - data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(nc_train_ids_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: torch\\n\",\n        \"          
      path: {os.path.basename(nc_train_labels_path)}\\n\",\n        \"        validation_set:\\n\",\n        \"          - data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(nc_val_ids_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: torch\\n\",\n        \"                path: {os.path.basename(nc_val_labels_path)}\\n\",\n        \"        test_set:\\n\",\n        \"          - data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(nc_test_ids_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: torch\\n\",\n        \"                path: {os.path.basename(nc_test_labels_path)}\\n\",\n        \"      - name: link_prediction\\n\",\n        \"        num_classes: 10\\n\",\n        \"        train_set:\\n\",\n        \"          - data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_train_seeds_path)}\\n\",\n        \"        validation_set:\\n\",\n        \"          - data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_val_seeds_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_val_labels_path)}\\n\",\n        \"              - name: indexes\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_val_indexes_path)}\\n\",\n        \"        test_set:\\n\",\n        \"          - data:\\n\",\n        \"              - name: seeds\\n\",\n        \"                format: numpy\\n\",\n        \"                path: 
{os.path.basename(lp_test_seeds_path)}\\n\",\n        \"              - name: labels\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_test_labels_path)}\\n\",\n        \"              - name: indexes\\n\",\n        \"                format: numpy\\n\",\n        \"                path: {os.path.basename(lp_test_indexes_path)}\\n\",\n        \"\\\"\\\"\\\"\\n\",\n        \"metadata_path = os.path.join(base_dir, \\\"metadata.yaml\\\")\\n\",\n        \"with open(metadata_path, \\\"w\\\") as f:\\n\",\n        \"  f.write(yaml_content)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"kEfybHGhOW7O\"\n      },\n      \"source\": [\n        \"## Instantiate `OnDiskDataset`\\n\",\n        \"Now we're ready to load dataset via `dgl.graphbolt.OnDiskDataset`. When instantiating, we just pass in the base directory where `metadata.yaml` file lies.\\n\",\n        \"\\n\",\n        \"During first instantiation, GraphBolt preprocesses the raw data such as constructing `FusedCSCSamplingGraph` from edges. All data including graph, feature data, training/validation/test sets are put into `preprocessed` directory after preprocessing. 
Any following dataset loading will skip the preprocess stage.\\n\",\n        \"\\n\",\n        \"After preprocessing, `load()` is required to be called explicitly in order to load graph, feature data and tasks.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"W58CZoSzOiyo\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"dataset = gb.OnDiskDataset(base_dir).load()\\n\",\n        \"graph = dataset.graph\\n\",\n        \"print(f\\\"Loaded graph: {graph}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"feature = dataset.feature\\n\",\n        \"print(f\\\"Loaded feature store: {feature}\\\\n\\\")\\n\",\n        \"\\n\",\n        \"tasks = dataset.tasks\\n\",\n        \"nc_task = tasks[0]\\n\",\n        \"print(f\\\"Loaded node classification task: {nc_task}\\\\n\\\")\\n\",\n        \"lp_task = tasks[1]\\n\",\n        \"print(f\\\"Loaded link prediction task: {lp_task}\\\\n\\\")\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"private_outputs\": true,\n      \"provenance\": []\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      \"file_extension\": \".py\",\n      \"mimetype\": \"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.10.12\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.black]\n\nline-length = 80\n"
  },
  {
    "path": "python/dgl/__init__.py",
    "content": "\"\"\"\nThe ``dgl`` package contains data structure for storing structural and feature data\n(i.e., the :class:`DGLGraph` class) and also utilities for generating, manipulating\nand transforming graphs.\n\"\"\"\n\n\n# Windows compatibility\n# This initializes Winsock and performs cleanup at termination as required\nimport socket\n\n# Backend and logging should be imported before other modules.\nfrom .logging import enable_verbose_logging  # usort: skip\nfrom .backend import backend_name, load_backend  # usort: skip\n\nfrom . import (\n    container,\n    cuda,\n    dataloading,\n    function,\n    ops,\n    random,\n    sampling,\n    storages,\n)\nfrom ._ffi.base import __version__, DGLError\nfrom ._ffi.function import (\n    extract_ext_funcs,\n    get_global_func,\n    list_global_func_names,\n    register_func,\n)\n\nfrom ._ffi.runtime_ctypes import TypeCode\n\nfrom .base import ALL, EID, ETYPE, NID, NTYPE\nfrom .readout import *\nfrom .batch import *\nfrom .convert import *\nfrom .generators import *\nfrom .dataloading import (\n    set_dst_lazy_features,\n    set_edge_lazy_features,\n    set_node_lazy_features,\n    set_src_lazy_features,\n)\nfrom .heterograph import (  # pylint: disable=reimported\n    DGLGraph,\n    DGLGraph as DGLHeteroGraph,\n)\nfrom .merge import *\nfrom .subgraph import *\nfrom .traversal import *\nfrom .transforms import *\nfrom .propagate import *\nfrom .random import *\nfrom . import optim\nfrom .data.utils import load_graphs, save_graphs\nfrom .frame import LazyFeature\nfrom .global_config import is_libxsmm_enabled, use_libxsmm\nfrom .utils import apply_each\nfrom .mpops import *\nfrom .homophily import *\nfrom .label_informativeness import *\n"
  },
  {
    "path": "python/dgl/_api_internal.py",
    "content": "\"\"\"Namespace for internal apis.\"\"\"\n"
  },
  {
    "path": "python/dgl/_ffi/README.md",
    "content": "# C API and runtime\n\nBorrowed and adapted from TVM project. (commit: 2ce5277)\n"
  },
  {
    "path": "python/dgl/_ffi/__init__.py",
    "content": "\"\"\"C interfacing code.\n\nThis namespace contains everything that interacts with C code.\nMost C related object are ctypes compatible, which means\nthey contains a handle field that is ctypes.c_void_p and can\nbe used via ctypes function calls.\n\nSome performance critical functions are implemented by cython\nand have a ctypes fallback implementation.\n\"\"\"\n"
  },
  {
    "path": "python/dgl/_ffi/_ctypes/__init__.py",
    "content": "\"\"\"ctypes specific implementation of FFI\"\"\"\n"
  },
  {
    "path": "python/dgl/_ffi/_ctypes/function.py",
    "content": "# coding: utf-8\n# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement\n\"\"\"Function configuration API.\"\"\"\nfrom __future__ import absolute_import\n\nimport ctypes\nimport traceback\nfrom numbers import Integral, Number\n\nfrom ..base import _LIB, c_str, check_call, string_types\nfrom ..object_generic import convert_to_object, ObjectGeneric\nfrom ..runtime_ctypes import DGLByteArray, DGLContext, DGLDataType\nfrom . import ndarray as _nd, object as _object\nfrom .ndarray import _make_array, NDArrayBase\nfrom .object import ObjectBase\nfrom .types import (\n    _wrap_arg_func,\n    C_TO_PY_ARG_SWITCH,\n    DGLCFuncFinalizer,\n    DGLPackedCFunc,\n    DGLValue,\n    RETURN_SWITCH,\n    TypeCode,\n)\n\nFunctionHandle = ctypes.c_void_p\nModuleHandle = ctypes.c_void_p\nDGLRetValueHandle = ctypes.c_void_p\n\n\ndef _ctypes_free_resource(rhandle):\n    \"\"\"callback to free resources when it it not needed.\"\"\"\n    pyobj = ctypes.cast(rhandle, ctypes.py_object)\n    ctypes.pythonapi.Py_DecRef(pyobj)\n\n\n# Global callback that is always alive\nDGL_FREE_PYOBJ = DGLCFuncFinalizer(_ctypes_free_resource)\nctypes.pythonapi.Py_IncRef(ctypes.py_object(DGL_FREE_PYOBJ))\n\n\ndef convert_to_dgl_func(pyfunc):\n    \"\"\"Convert a python function to DGL function\n\n    Parameters\n    ----------\n    pyfunc : python function\n        The python function to be converted.\n\n    Returns\n    -------\n    dglfunc: dgl.nd.Function\n        The converted dgl function.\n    \"\"\"\n    local_pyfunc = pyfunc\n\n    def cfun(args, type_codes, num_args, ret, _):\n        \"\"\"ctypes function\"\"\"\n        num_args = (\n            num_args.value if isinstance(num_args, ctypes.c_int) else num_args\n        )\n        pyargs = (\n            C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)\n        )\n        # pylint: disable=broad-except\n        try:\n            rv = local_pyfunc(*pyargs)\n        except Exception:\n 
           msg = traceback.format_exc()\n            _LIB.DGLAPISetLastError(c_str(msg))\n            return -1\n\n        if rv is not None:\n            if isinstance(rv, tuple):\n                raise ValueError(\n                    \"PackedFunction can only support one return value\"\n                )\n            temp_args = []\n            values, tcodes, _ = _make_dgl_args((rv,), temp_args)\n            if not isinstance(ret, DGLRetValueHandle):\n                ret = DGLRetValueHandle(ret)\n            check_call(\n                _LIB.DGLCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1))\n            )\n            _ = temp_args\n            _ = rv\n        return 0\n\n    handle = FunctionHandle()\n    f = DGLPackedCFunc(cfun)\n    # NOTE: We will need to use python-api to increase ref count of the f\n    # DGL_FREE_PYOBJ will be called after it is no longer needed.\n    pyobj = ctypes.py_object(f)\n    ctypes.pythonapi.Py_IncRef(pyobj)\n    check_call(\n        _LIB.DGLFuncCreateFromCFunc(\n            f, pyobj, DGL_FREE_PYOBJ, ctypes.byref(handle)\n        )\n    )\n    return _CLASS_FUNCTION(handle, False)\n\n\ndef _make_dgl_args(args, temp_args):\n    \"\"\"Pack arguments into c args dgl call accept.\n\n    temp_args is used to temporarily save the arguments so they will not be\n    freed during C API function call.\n    \"\"\"\n    num_args = len(args)\n    values = (DGLValue * num_args)()\n    type_codes = (ctypes.c_int * num_args)()\n    for i, arg in enumerate(args):\n        if arg is None:\n            values[i].v_handle = None\n            type_codes[i] = TypeCode.NULL\n        elif isinstance(arg, ObjectBase):\n            values[i].v_handle = arg.handle\n            type_codes[i] = TypeCode.OBJECT_HANDLE\n        elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):\n            arg = convert_to_object(arg)\n            values[i].v_handle = arg.handle\n            type_codes[i] = TypeCode.OBJECT_HANDLE\n            
temp_args.append(arg)\n        elif isinstance(arg, NDArrayBase):\n            values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)\n            type_codes[i] = (\n                TypeCode.NDARRAY_CONTAINER\n                if not arg.is_view\n                else TypeCode.ARRAY_HANDLE\n            )\n        elif isinstance(arg, _nd._DGL_COMPATS):\n            values[i].v_handle = ctypes.c_void_p(arg._dgl_handle)\n            type_codes[i] = arg.__class__._dgl_tcode\n        elif isinstance(arg, Integral):\n            values[i].v_int64 = arg\n            type_codes[i] = TypeCode.INT\n        elif isinstance(arg, Number):\n            values[i].v_float64 = arg\n            type_codes[i] = TypeCode.FLOAT\n        elif isinstance(arg, DGLDataType):\n            values[i].v_str = c_str(str(arg))\n            type_codes[i] = TypeCode.STR\n        elif isinstance(arg, DGLContext):\n            values[i].v_ctx = arg\n            type_codes[i] = TypeCode.DGL_CONTEXT\n        elif isinstance(arg, bytearray):\n            arr = DGLByteArray()\n            arr.data = ctypes.cast(\n                (ctypes.c_byte * len(arg)).from_buffer(arg),\n                ctypes.POINTER(ctypes.c_byte),\n            )\n            arr.size = len(arg)\n            values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))\n            temp_args.append(arr)\n            type_codes[i] = TypeCode.BYTES\n        elif isinstance(arg, string_types):\n            values[i].v_str = c_str(arg)\n            type_codes[i] = TypeCode.STR\n        # NOTE(minjie): module is not used in DGL\n        # elif isinstance(arg, _CLASS_MODULE):\n        #    values[i].v_handle = arg.handle\n        #    type_codes[i] = TypeCode.MODULE_HANDLE\n        elif isinstance(arg, FunctionBase):\n            values[i].v_handle = arg.handle\n            type_codes[i] = TypeCode.FUNC_HANDLE\n        elif isinstance(arg, ctypes.c_void_p):\n            values[i].v_handle = arg\n            type_codes[i] = 
TypeCode.HANDLE\n        elif callable(arg):\n            arg = convert_to_dgl_func(arg)\n            values[i].v_handle = arg.handle\n            type_codes[i] = TypeCode.FUNC_HANDLE\n            temp_args.append(arg)\n        else:\n            raise TypeError(\"Don't know how to handle type %s\" % type(arg))\n    return values, type_codes, num_args\n\n\nclass FunctionBase(object):\n    \"\"\"Function base.\"\"\"\n\n    __slots__ = [\"handle\", \"is_global\"]\n    # pylint: disable=no-member\n    def __init__(self, handle, is_global):\n        \"\"\"Initialize the function with handle\n\n        Parameters\n        ----------\n        handle : FunctionHandle\n            the handle to the underlying function.\n\n        is_global : bool\n            Whether this is a global function in python\n        \"\"\"\n        self.handle = handle\n        self.is_global = is_global\n\n    def __del__(self):\n        if not self.is_global and _LIB is not None:\n            check_call(_LIB.DGLFuncFree(self.handle))\n\n    def __call__(self, *args):\n        \"\"\"Call the function with positional arguments\n\n        args : list\n           The positional arguments to the function call.\n        \"\"\"\n        temp_args = []\n        values, tcodes, num_args = _make_dgl_args(args, temp_args)\n        ret_val = DGLValue()\n        ret_tcode = ctypes.c_int()\n        check_call(\n            _LIB.DGLFuncCall(\n                self.handle,\n                values,\n                tcodes,\n                ctypes.c_int(num_args),\n                ctypes.byref(ret_val),\n                ctypes.byref(ret_tcode),\n            )\n        )\n        _ = temp_args\n        _ = args\n        return RETURN_SWITCH[ret_tcode.value](ret_val)\n\n\ndef __init_handle_by_constructor__(fconstructor, args):\n    \"\"\"Initialize handle by constructor\"\"\"\n    temp_args = []\n    values, tcodes, num_args = _make_dgl_args(args, temp_args)\n    ret_val = DGLValue()\n    ret_tcode = 
ctypes.c_int()\n    check_call(\n        _LIB.DGLFuncCall(\n            fconstructor.handle,\n            values,\n            tcodes,\n            ctypes.c_int(num_args),\n            ctypes.byref(ret_val),\n            ctypes.byref(ret_tcode),\n        )\n    )\n    _ = temp_args\n    _ = args\n    assert ret_tcode.value == TypeCode.OBJECT_HANDLE\n    handle = ret_val.v_handle\n    return handle\n\n\ndef _return_module(x):\n    \"\"\"Return function\"\"\"\n    handle = x.v_handle\n    if not isinstance(handle, ModuleHandle):\n        handle = ModuleHandle(handle)\n    return _CLASS_MODULE(handle)\n\n\ndef _handle_return_func(x):\n    \"\"\"Return function\"\"\"\n    handle = x.v_handle\n    if not isinstance(handle, FunctionHandle):\n        handle = FunctionHandle(handle)\n    return _CLASS_FUNCTION(handle, False)\n\n\n# setup return handle for function type\n_object.__init_by_constructor__ = __init_handle_by_constructor__\nRETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func\nRETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module\nRETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(\n    x.v_handle, False\n)\nC_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(\n    _handle_return_func, TypeCode.FUNC_HANDLE\n)\nC_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(\n    _return_module, TypeCode.MODULE_HANDLE\n)\nC_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(\n    x.v_handle, True\n)\nC_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(\n    x.v_handle, False\n)\n\n_CLASS_MODULE = None\n_CLASS_FUNCTION = None\n\n\ndef _set_class_module(module_class):\n    \"\"\"Initialize the module.\"\"\"\n    global _CLASS_MODULE\n    _CLASS_MODULE = module_class\n\n\ndef _set_class_function(func_class):\n    global _CLASS_FUNCTION\n    _CLASS_FUNCTION = func_class\n"
  },
  {
    "path": "python/dgl/_ffi/_ctypes/ndarray.py",
    "content": "# pylint: disable=invalid-name\n\"\"\"Runtime NDArray api\"\"\"\nfrom __future__ import absolute_import\n\nimport ctypes\n\nfrom ..base import _LIB, c_str, check_call\nfrom ..runtime_ctypes import DGLArrayHandle\nfrom .types import (\n    _return_handle,\n    _wrap_arg_func,\n    C_TO_PY_ARG_SWITCH,\n    RETURN_SWITCH,\n)\n\nDGLPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)\n_c_str_dltensor = c_str(\"dltensor\")\n_c_str_used_dltensor = c_str(\"used_dltensor\")\n\n\n# used for PyCapsule manipulation\nif hasattr(ctypes, \"pythonapi\"):\n    ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p\n    ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p\n    ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object\n\n\ndef _from_dlpack(dltensor):\n    dltensor = ctypes.py_object(dltensor)\n    if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor):\n        ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor)\n        # XXX(minjie): The below cast should be unnecessary given the code to\n        #   set restype of PyCapsule calls. 
But weirdly, this does not\n        #   work out always.\n        ptr = ctypes.cast(ptr, ctypes.c_void_p)\n        handle = DGLArrayHandle()\n        check_call(_LIB.DGLArrayFromDLPack(ptr, ctypes.byref(handle)))\n        ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)\n        ctypes.pythonapi.PyCapsule_SetDestructor(\n            dltensor, DGLPyCapsuleDestructor(0)\n        )\n        return _make_array(handle, False)\n    raise ValueError(\n        \"Expect a dltensor field, PyCapsule can only be consumed once\"\n    )\n\n\ndef _dlpack_deleter(pycapsule):\n    pycapsule = ctypes.cast(pycapsule, ctypes.py_object)\n    if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):\n        ptr = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)\n        # XXX(minjie): The below cast should be unnecessary given the code to\n        #   set restype of PyCapsule calls. But weirdly, this does not\n        #   work out always.\n        ptr = ctypes.cast(ptr, ctypes.c_void_p)\n        _LIB.DGLDLManagedTensorCallDeleter(ptr)\n        ctypes.pythonapi.PyCapsule_SetDestructor(\n            pycapsule, DGLPyCapsuleDestructor(0)\n        )\n\n\n_c_dlpack_deleter = DGLPyCapsuleDestructor(_dlpack_deleter)\n\n\nclass NDArrayBase(object):\n    \"\"\"A simple Device/CPU Array object in runtime.\"\"\"\n\n    __slots__ = [\"handle\", \"is_view\"]\n    # pylint: disable=no-member\n    def __init__(self, handle, is_view=False):\n        \"\"\"Initialize the function with handle\n\n        Parameters\n        ----------\n        handle : DGLArrayHandle\n            the handle to the underlying C++ DGLArray\n        \"\"\"\n        self.handle = handle\n        self.is_view = is_view\n\n    def __del__(self):\n        if not self.is_view and _LIB:\n            check_call(_LIB.DGLArrayFree(self.handle))\n\n    @property\n    def _dgl_handle(self):\n        return ctypes.cast(self.handle, ctypes.c_void_p).value\n\n    def to_dlpack(self, 
alignment=0):\n        \"\"\"Produce an array from a DLPack Tensor without copying memory\n\n        Args\n        -------\n        alignment: int, default to be 0\n        Indicates the alignment requirement when converting to dlpack. Will copy to a\n        new tensor if the alignment requirement is not satisfied.\n        0 means no alignment requirement.\n\n\n        Returns\n        -------\n        dlpack : DLPack tensor view of the array data\n        \"\"\"\n        ptr = ctypes.c_void_p()\n        check_call(\n            _LIB.DGLArrayToDLPack(self.handle, ctypes.byref(ptr), alignment)\n        )\n        return ctypes.pythonapi.PyCapsule_New(\n            ptr, _c_str_dltensor, _c_dlpack_deleter\n        )\n\n\ndef _make_array(handle, is_view):\n    handle = ctypes.cast(handle, DGLArrayHandle)\n    return _CLASS_NDARRAY(handle, is_view)\n\n\n_DGL_COMPATS = ()\n\n\ndef _reg_extension(cls, fcreate):\n    global _DGL_COMPATS\n    _DGL_COMPATS += (cls,)\n    if fcreate:\n        fret = lambda x: fcreate(_return_handle(x))\n        RETURN_SWITCH[cls._dgl_tcode] = fret\n        C_TO_PY_ARG_SWITCH[cls._dgl_tcode] = _wrap_arg_func(\n            fret, cls._dgl_tcode\n        )\n\n\n_CLASS_NDARRAY = None\n\n\ndef _set_class_ndarray(cls):\n    global _CLASS_NDARRAY\n    _CLASS_NDARRAY = cls\n"
  },
  {
    "path": "python/dgl/_ffi/_ctypes/object.py",
    "content": "\"\"\"ctypes object API.\"\"\"\nfrom __future__ import absolute_import\n\nimport ctypes\n\nfrom ..base import _LIB, c_str, check_call\nfrom ..object_generic import _set_class_object_base\nfrom .types import (\n    _wrap_arg_func,\n    C_TO_PY_ARG_SWITCH,\n    DGLValue,\n    RETURN_SWITCH,\n    TypeCode,\n)\n\nObjectHandle = ctypes.c_void_p\n__init_by_constructor__ = None\n\n\"\"\"Maps object type to its constructor\"\"\"\nOBJECT_TYPE = {}\n\n\ndef _register_object(index, cls):\n    \"\"\"register object class in python\"\"\"\n    OBJECT_TYPE[index] = cls\n\n\ndef _return_object(x):\n    \"\"\"Construct a object object from the given DGLValue object\"\"\"\n    handle = x.v_handle\n    if not isinstance(handle, ObjectHandle):\n        handle = ObjectHandle(handle)\n    tindex = ctypes.c_int()\n    check_call(_LIB.DGLObjectGetTypeIndex(handle, ctypes.byref(tindex)))\n    cls = OBJECT_TYPE.get(tindex.value, ObjectBase)\n    # Avoid calling __init__ of cls, instead directly call __new__\n    # This allows child class to implement their own __init__\n    obj = cls.__new__(cls)\n    obj.handle = handle\n    return obj\n\n\nRETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object\nC_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(\n    _return_object, TypeCode.OBJECT_HANDLE\n)\n\n\nclass ObjectBase(object):\n    \"\"\"Object base class\"\"\"\n\n    __slots__ = [\"handle\"]\n\n    # pylint: disable=no-member\n    def __del__(self):\n        if _LIB is not None and hasattr(self, \"handle\"):\n            check_call(_LIB.DGLObjectFree(self.handle))\n\n    def __getattr__(self, name):\n        if name == \"handle\":\n            raise AttributeError(\n                \"'handle' is a reserved attribute name that should not be used\"\n            )\n        ret_val = DGLValue()\n        ret_type_code = ctypes.c_int()\n        ret_success = ctypes.c_int()\n        check_call(\n            _LIB.DGLObjectGetAttr(\n                self.handle,\n               
 c_str(name),\n                ctypes.byref(ret_val),\n                ctypes.byref(ret_type_code),\n                ctypes.byref(ret_success),\n            )\n        )\n        if not ret_success.value:\n            raise AttributeError(\n                \"'%s' object has no attribute '%s'\" % (str(type(self)), name)\n            )\n        return RETURN_SWITCH[ret_type_code.value](ret_val)\n\n    def __init_handle_by_constructor__(self, fconstructor, *args):\n        \"\"\"Initialize the handle by calling constructor function.\n\n        Parameters\n        ----------\n        fconstructor : Function\n            Constructor function.\n\n        args: list of objects\n            The arguments to the constructor\n\n        Note\n        ----\n        We have a special calling convention to call constructor functions.\n        So the return handle is directly set into the Object object\n        instead of creating a new Object.\n        \"\"\"\n        # assign handle first to avoid error raising\n        self.handle = None\n        handle = __init_by_constructor__(\n            fconstructor, args\n        )  # pylint: disable=not-callable\n        if not isinstance(handle, ObjectHandle):\n            handle = ObjectHandle(handle)\n        self.handle = handle\n\n\n_set_class_object_base(ObjectBase)\n"
  },
  {
    "path": "python/dgl/_ffi/_ctypes/types.py",
    "content": "\"\"\"The C Types used in API.\"\"\"\n# pylint: disable=invalid-name\nfrom __future__ import absolute_import as _abs\n\nimport ctypes\n\nfrom ..base import _LIB, check_call, py_str\nfrom ..runtime_ctypes import DGLByteArray, DGLContext, DGLDataType, TypeCode\n\n\nclass DGLValue(ctypes.Union):\n    \"\"\"DGLValue in C API\"\"\"\n\n    _fields_ = [\n        (\"v_int64\", ctypes.c_int64),\n        (\"v_float64\", ctypes.c_double),\n        (\"v_handle\", ctypes.c_void_p),\n        (\"v_str\", ctypes.c_char_p),\n        (\"v_type\", DGLDataType),\n        (\"v_ctx\", DGLContext),\n    ]\n\n\nDGLPackedCFunc = ctypes.CFUNCTYPE(\n    ctypes.c_int,\n    ctypes.POINTER(DGLValue),\n    ctypes.POINTER(ctypes.c_int),\n    ctypes.c_int,\n    ctypes.c_void_p,\n    ctypes.c_void_p,\n)\n\n\nDGLCFuncFinalizer = ctypes.CFUNCTYPE(None, ctypes.c_void_p)\n\n\ndef _return_handle(x):\n    \"\"\"return handle\"\"\"\n    handle = x.v_handle\n    if not isinstance(handle, ctypes.c_void_p):\n        handle = ctypes.c_void_p(handle)\n    return handle\n\n\ndef _return_bytes(x):\n    \"\"\"return handle\"\"\"\n    handle = x.v_handle\n    if not isinstance(handle, ctypes.c_void_p):\n        handle = ctypes.c_void_p(handle)\n    arr = ctypes.cast(handle, ctypes.POINTER(DGLByteArray))[0]\n    size = arr.size\n    res = bytearray(size)\n    rptr = (ctypes.c_byte * size).from_buffer(res)\n    if not ctypes.memmove(rptr, arr.data, size):\n        raise RuntimeError(\"memmove failed\")\n    return res\n\n\ndef _wrap_arg_func(return_f, type_code):\n    tcode = ctypes.c_int(type_code)\n\n    def _wrap_func(x):\n        check_call(_LIB.DGLCbArgToReturn(ctypes.byref(x), tcode))\n        return return_f(x)\n\n    return _wrap_func\n\n\nRETURN_SWITCH = {\n    TypeCode.INT: lambda x: x.v_int64,\n    TypeCode.FLOAT: lambda x: x.v_float64,\n    TypeCode.HANDLE: _return_handle,\n    TypeCode.NULL: lambda x: None,\n    TypeCode.STR: lambda x: py_str(x.v_str),\n    TypeCode.BYTES: 
_return_bytes,\n    TypeCode.DGL_CONTEXT: lambda x: DGLContext(\n        x.v_ctx.device_type, x.v_ctx.device_id\n    ),\n}\n\nC_TO_PY_ARG_SWITCH = {\n    TypeCode.INT: lambda x: x.v_int64,\n    TypeCode.FLOAT: lambda x: x.v_float64,\n    TypeCode.HANDLE: _return_handle,\n    TypeCode.NULL: lambda x: None,\n    TypeCode.STR: lambda x: py_str(x.v_str),\n    TypeCode.BYTES: _return_bytes,\n    TypeCode.DGL_CONTEXT: lambda x: DGLContext(\n        x.v_ctx.device_type, x.v_ctx.device_id\n    ),\n}\n"
  },
  {
    "path": "python/dgl/_ffi/_cy2/__init__.py",
    "content": "\"\"\"cython2 namespace\"\"\"\n"
  },
  {
    "path": "python/dgl/_ffi/_cy3/__init__.py",
    "content": "\"\"\"cython3 namespace\"\"\"\n"
  },
  {
    "path": "python/dgl/_ffi/_cython/.gitignore",
    "content": "*.cpp\n"
  },
  {
    "path": "python/dgl/_ffi/_cython/base.pxi",
    "content": "from ..base import DGLError\nfrom libcpp.vector cimport vector\nfrom libcpp cimport bool\nfrom cpython.version cimport PY_MAJOR_VERSION\nfrom cpython cimport pycapsule\nfrom libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t\nimport ctypes\n\ncdef enum DGLObjectTypeCode:\n    kObjectInt = 0\n    kObjectUInt = 1\n    kObjectFloat = 2\n    kHandle = 3\n    kNull = 4\n    kDGLDataType = 5\n    kDGLContext = 6\n    kArrayHandle = 7\n    kObjectHandle = 8\n    kModuleHandle = 9\n    kFuncHandle = 10\n    kStr = 11\n    kBytes = 12\n    kNDArrayContainer = 13\n    kExtBegin = 15\n\ncdef extern from \"dgl/runtime/c_runtime_api.h\":\n    ctypedef struct DGLDataType:\n        uint8_t code\n        uint8_t bits\n        uint16_t lanes\n\n    ctypedef struct DGLContext:\n        int32_t device_type\n        int32_t device_id\n\n    ctypedef struct DGLArray:\n        void* data\n        DGLContext ctx\n        int32_t ndim\n        DGLDataType dtype\n        int64_t* shape\n        int64_t* strides\n        uint64_t byte_offset\n\n    ctypedef struct DLManagedTensor:\n        DGLArray dl_tensor\n        void* manager_ctx\n        void (*deleter)(DLManagedTensor* self)\n\n    ctypedef struct DGLValue:\n        int64_t v_int64\n        double v_float64\n        void* v_handle\n        const char* v_str\n        DGLDataType v_type\n        DGLContext v_ctx\n\nctypedef int64_t dgl_index_t\nctypedef DGLArray* DGLArrayHandle\nctypedef void* DGLStreamHandle\nctypedef void* DGLRetValueHandle\nctypedef void* DGLFunctionHandle\nctypedef void* ObjectHandle\n\nctypedef int (*DGLPackedCFunc)(\n    DGLValue* args,\n    int* type_codes,\n    int num_args,\n    DGLRetValueHandle ret,\n    void* resource_handle)\n\nctypedef void (*DGLPackedCFuncFinalizer)(void* resource_handle)\n\ncdef extern from \"dgl/runtime/c_runtime_api.h\":\n    void DGLAPISetLastError(const char* msg)\n    const char *DGLGetLastError()\n    int DGLFuncCall(DGLFunctionHandle func,\n         
           DGLValue* arg_values,\n                    int* type_codes,\n                    int num_args,\n                    DGLValue* ret_val,\n                    int* ret_type_code) nogil\n    int DGLFuncFree(DGLFunctionHandle func)\n    int DGLCFuncSetReturn(DGLRetValueHandle ret,\n                          DGLValue* value,\n                          int* type_code,\n                          int num_ret)\n    int DGLFuncCreateFromCFunc(DGLPackedCFunc func,\n                               void* resource_handle,\n                               DGLPackedCFuncFinalizer fin,\n                               DGLFunctionHandle *out)\n    int DGLCbArgToReturn(DGLValue* value, int code)\n    int DGLArrayAlloc(dgl_index_t* shape,\n                      dgl_index_t ndim,\n                      DGLDataType dtype,\n                      DGLContext ctx,\n                      DGLArrayHandle* out)\n    int DGLArrayAllocSharedMem(const char *mem_name,\n                               const dgl_index_t *shape,\n                               int ndim,\n                               int dtype_code,\n                               int dtype_bits,\n                               int dtype_lanes,\n                               bool is_create,\n                               DGLArrayHandle* out)\n    int DGLArrayFree(DGLArrayHandle handle)\n    int DGLArrayCopyFromTo(DGLArrayHandle src,\n                           DGLArrayHandle to)\n\ncdef extern from \"dgl/runtime/c_object_api.h\":\n    int DGLObjectFree(ObjectHandle handle)\n    int DGLObjectTypeKey2Index(const char* type_key,\n                               int* out_index)\n    int DGLObjectGetTypeIndex(ObjectHandle handle,\n                              int* out_index)\n    int DGLObjectGetAttr(ObjectHandle handle,\n                         const char* key,\n                         DGLValue* out_value,\n                         int* out_type_code,\n                         int* out_success)\n\ncdef extern from 
\"dgl/runtime/dlpack_convert.h\":\n    int DGLArrayFromDLPack(DLManagedTensor* arr_from,\n                           DGLArrayHandle* out)\n    int DGLArrayToDLPack(DGLArrayHandle arr_from,\n                         DLManagedTensor** out,\n                         int alignment)\n    void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor)\n\ncdef inline py_str(const char* x):\n    if PY_MAJOR_VERSION < 3:\n        return x\n    else:\n        return x.decode(\"utf-8\")\n\n\ncdef inline c_str(pystr):\n    \"\"\"Create ctypes char * from a python string\n    Parameters\n    ----------\n    string : string type\n        python string\n\n    Returns\n    -------\n    str : c_char_p\n        A char pointer that can be passed to C API\n    \"\"\"\n    return pystr.encode(\"utf-8\")\n\n\ncdef inline CALL(int ret):\n    if ret != 0:\n        raise DGLError(py_str(DGLGetLastError()))\n\n\ncdef inline object ctypes_handle(void* chandle):\n    \"\"\"Cast C handle to ctypes handle.\"\"\"\n    return ctypes.cast(<unsigned long long>chandle, ctypes.c_void_p)\n\n\ncdef inline void* c_handle(object handle):\n    \"\"\"Cast C types handle to c handle.\"\"\"\n    cdef unsigned long long v_ptr\n    if handle.value is None:\n        return NULL\n    else:\n        v_ptr = handle.value\n        return <void*>(v_ptr)\n"
  },
  {
    "path": "python/dgl/_ffi/_cython/core.pyx",
    "content": "include \"./base.pxi\"\ninclude \"./object.pxi\"\ninclude \"./function.pxi\"\ninclude \"./ndarray.pxi\"\n"
  },
  {
    "path": "python/dgl/_ffi/_cython/function.pxi",
    "content": "import ctypes\nimport traceback\nfrom cpython cimport Py_INCREF, Py_DECREF\nfrom numbers import Number, Integral\nfrom ..base import string_types\nfrom ..object_generic import convert_to_object, ObjectGeneric\nfrom ..runtime_ctypes import DGLDataType as CTypesDGLDataType, \\\n                             DGLContext as CTypesDGLContext, \\\n                             DGLByteArray\n\n\ncdef void dgl_callback_finalize(void* fhandle):\n    local_pyfunc = <object>(fhandle)\n    Py_DECREF(local_pyfunc)\n\ncdef int dgl_callback(DGLValue* args,\n                      int* type_codes,\n                      int num_args,\n                      DGLRetValueHandle ret,\n                      void* fhandle) with gil:\n    cdef list pyargs\n    cdef DGLValue value\n    cdef int tcode\n    local_pyfunc = <object>(fhandle)\n    pyargs = []\n    for i in range(num_args):\n        value = args[i]\n        tcode = type_codes[i]\n        if (tcode == kObjectHandle or\n            tcode == kFuncHandle or\n            tcode == kModuleHandle or\n            tcode > kExtBegin):\n            CALL(DGLCbArgToReturn(&value, tcode))\n\n        if tcode != kArrayHandle:\n            pyargs.append(make_ret(value, tcode))\n        else:\n            pyargs.append(c_make_array(value.v_handle, True))\n    try:\n        rv = local_pyfunc(*pyargs)\n    except Exception:\n        msg = traceback.format_exc()\n        DGLAPISetLastError(c_str(msg))\n        return -1\n    if rv is not None:\n        if isinstance(rv, tuple):\n            raise ValueError(\"PackedFunction can only support one return value\")\n        temp_args = []\n        make_arg(rv, &value, &tcode, temp_args)\n        CALL(DGLCFuncSetReturn(ret, &value, &tcode, 1))\n    return 0\n\n\ndef convert_to_dgl_func(object pyfunc):\n    \"\"\"Convert a python function to DGL function\n\n    Parameters\n    ----------\n    pyfunc : python function\n        The python function to be converted.\n\n    Returns\n    -------\n    
dglfunc: dgl.Function\n        The converted dgl function.\n    \"\"\"\n    cdef DGLFunctionHandle chandle\n    Py_INCREF(pyfunc)\n    CALL(DGLFuncCreateFromCFunc(dgl_callback,\n                                <void*>(pyfunc),\n                                dgl_callback_finalize,\n                                &chandle))\n    ret = _CLASS_FUNCTION(None, False)\n    (<FunctionBase>ret).chandle = chandle\n    return ret\n\n\ncdef inline int make_arg(object arg,\n                         DGLValue* value,\n                         int* tcode,\n                         list temp_args) except -1:\n    \"\"\"Pack arguments into c args dgl call accept\"\"\"\n    cdef unsigned long long ptr\n    if isinstance(arg, ObjectBase):\n        value[0].v_handle = (<ObjectBase>arg).chandle\n        tcode[0] = kObjectHandle\n    elif isinstance(arg, NDArrayBase):\n        value[0].v_handle = (<NDArrayBase>arg).chandle\n        tcode[0] = (kNDArrayContainer if\n                    not (<NDArrayBase>arg).c_is_view else kArrayHandle)\n    elif isinstance(arg, _DGL_COMPATS):\n        ptr = arg._dgl_handle\n        value[0].v_handle = (<void*>ptr)\n        tcode[0] = arg.__class__._dgl_tcode\n    elif isinstance(arg, (int, long)):\n        value[0].v_int64 = arg\n        tcode[0] = kObjectInt\n    elif isinstance(arg, float):\n        value[0].v_float64 = arg\n        tcode[0] = kObjectFloat\n    elif isinstance(arg, str):\n        tstr = c_str(arg)\n        value[0].v_str = tstr\n        tcode[0] = kStr\n        temp_args.append(tstr)\n    elif arg is None:\n        value[0].v_handle = NULL\n        tcode[0] = kNull\n    elif isinstance(arg, Number):\n        value[0].v_float64 = arg\n        tcode[0] = kObjectFloat\n    elif isinstance(arg, CTypesDGLDataType):\n        tstr = c_str(str(arg))\n        value[0].v_str = tstr\n        tcode[0] = kStr\n        temp_args.append(tstr)\n    elif isinstance(arg, CTypesDGLContext):\n        value[0].v_ctx = (<DGLContext*>(\n            
<unsigned long long>ctypes.addressof(arg)))[0]\n        tcode[0] = kDGLContext\n    elif isinstance(arg, bytearray):\n        arr = DGLByteArray()\n        arr.data = ctypes.cast(\n            (ctypes.c_byte * len(arg)).from_buffer(arg),\n            ctypes.POINTER(ctypes.c_byte))\n        arr.size = len(arg)\n        value[0].v_handle = <void*>(\n            <unsigned long long>ctypes.addressof(arr))\n        tcode[0] = kBytes\n        temp_args.append(arr)\n    elif isinstance(arg, string_types):\n        tstr = c_str(arg)\n        value[0].v_str = tstr\n        tcode[0] = kStr\n        temp_args.append(tstr)\n    elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):\n        arg = convert_to_object(arg)\n        value[0].v_handle = (<ObjectBase>arg).chandle\n        tcode[0] = kObjectHandle\n        temp_args.append(arg)\n    #elif isinstance(arg, _CLASS_MODULE):\n    #    value[0].v_handle = c_handle(arg.handle)\n    #    tcode[0] = kModuleHandle\n    elif isinstance(arg, FunctionBase):\n        value[0].v_handle = (<FunctionBase>arg).chandle\n        tcode[0] = kFuncHandle\n    elif isinstance(arg, ctypes.c_void_p):\n        value[0].v_handle = c_handle(arg)\n        tcode[0] = kHandle\n    elif callable(arg):\n        arg = convert_to_dgl_func(arg)\n        value[0].v_handle = (<FunctionBase>arg).chandle\n        tcode[0] = kFuncHandle\n        temp_args.append(arg)\n    else:\n        raise TypeError(\"Don't know how to handle type %s\" % type(arg))\n    return 0\n\ncdef inline bytearray make_ret_bytes(void* chandle):\n    handle = ctypes_handle(chandle)\n    arr = ctypes.cast(handle, ctypes.POINTER(DGLByteArray))[0]\n    size = arr.size\n    res = bytearray(size)\n    rptr = (ctypes.c_byte * size).from_buffer(res)\n    if not ctypes.memmove(rptr, arr.data, size):\n        raise RuntimeError('memmove failed')\n    return res\n\ncdef inline object make_ret(DGLValue value, int tcode):\n    \"\"\"convert result to return value.\"\"\"\n    if tcode == 
kObjectHandle:\n        return make_ret_object(value.v_handle)\n    elif tcode == kNull:\n        return None\n    elif tcode == kObjectInt:\n        return value.v_int64\n    elif tcode == kObjectFloat:\n        return value.v_float64\n    elif tcode == kNDArrayContainer:\n        return c_make_array(value.v_handle, False)\n    elif tcode == kStr:\n        return py_str(value.v_str)\n    elif tcode == kBytes:\n        return make_ret_bytes(value.v_handle)\n    elif tcode == kHandle:\n        return ctypes_handle(value.v_handle)\n    elif tcode == kDGLContext:\n        return CTypesDGLContext(value.v_ctx.device_type, value.v_ctx.device_id)\n    # (minjie): class module are not used in DGL.\n    #elif tcode == kModuleHandle:\n    #    return _CLASS_MODULE(ctypes_handle(value.v_handle))\n    elif tcode == kFuncHandle:\n        fobj = _CLASS_FUNCTION(None, False)\n        (<FunctionBase>fobj).chandle = value.v_handle\n        return fobj\n    elif tcode in _DGL_EXT_RET:\n        return _DGL_EXT_RET[tcode](ctypes_handle(value.v_handle))\n\n    raise ValueError(\"Unhandled type code %d\" % tcode)\n\n\ncdef inline int FuncCall3(void* chandle,\n                          tuple args,\n                          int nargs,\n                          DGLValue* ret_val,\n                          int* ret_tcode) except -1:\n    cdef DGLValue[3] values\n    cdef int[3] tcodes\n    nargs = len(args)\n    temp_args = []\n    for i in range(nargs):\n        make_arg(args[i], &values[i], &tcodes[i], temp_args)\n    with nogil:\n        ret = DGLFuncCall(chandle, &values[0], &tcodes[0],\n                          nargs, ret_val, ret_tcode)\n    if ret != 0:\n        raise DGLError(py_str(DGLGetLastError()))\n    return 0\n\ncdef inline int FuncCall(void* chandle,\n                         tuple args,\n                         DGLValue* ret_val,\n                         int* ret_tcode) except -1:\n    cdef int nargs\n    nargs = len(args)\n    if nargs <= 3:\n        
FuncCall3(chandle, args, nargs, ret_val, ret_tcode)\n        return 0\n\n    cdef vector[DGLValue] values\n    cdef vector[int] tcodes\n    values.resize(max(nargs, 1))\n    tcodes.resize(max(nargs, 1))\n    temp_args = []\n    for i in range(nargs):\n        make_arg(args[i], &values[i], &tcodes[i], temp_args)\n    with nogil:\n        ret = DGLFuncCall(chandle, &values[0], &tcodes[0],\n                          nargs, ret_val, ret_tcode)\n    if ret != 0:\n        raise DGLError(py_str(DGLGetLastError()))\n    return 0\n\n\ncdef inline int ConstructorCall(void* constructor_handle,\n                                int type_code,\n                                tuple args,\n                                void** handle) except -1:\n    \"\"\"Call contructor of a handle function\"\"\"\n    cdef DGLValue ret_val\n    cdef int ret_tcode\n    FuncCall(constructor_handle, args, &ret_val, &ret_tcode)\n    assert ret_tcode == type_code\n    handle[0] = ret_val.v_handle\n    return 0\n\n\ncdef class FunctionBase:\n    cdef DGLFunctionHandle chandle\n    cdef int is_global\n\n    cdef inline _set_handle(self, handle):\n        if handle is None:\n            self.chandle = NULL\n        else:\n            self.chandle = c_handle(handle)\n\n    property is_global:\n        def __get__(self):\n            return self.c_is_global != 0\n\n        def __set__(self, value):\n            self.c_is_global = value\n\n    property handle:\n        def __get__(self):\n            if self.chandle == NULL:\n                return None\n            else:\n                return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)\n        def __set__(self, value):\n            self._set_handle(value)\n\n    def __init__(self, handle, is_global):\n        self._set_handle(handle)\n        self.c_is_global = is_global\n\n    def __dealloc__(self):\n        if self.is_global == 0:\n            CALL(DGLFuncFree(self.chandle))\n\n    def __call__(self, *args):\n        cdef 
DGLValue ret_val\n        cdef int ret_tcode\n        FuncCall(self.chandle, args, &ret_val, &ret_tcode)\n        return make_ret(ret_val, ret_tcode)\n\n_CLASS_FUNCTION = None\n_CLASS_MODULE = None\n\ndef _set_class_module(module_class):\n    \"\"\"Initialize the module.\"\"\"\n    global _CLASS_MODULE\n    _CLASS_MODULE = module_class\n\ndef _set_class_function(func_class):\n    global _CLASS_FUNCTION\n    _CLASS_FUNCTION = func_class\n"
  },
  {
    "path": "python/dgl/_ffi/_cython/ndarray.pxi",
    "content": "from ..runtime_ctypes import DGLArrayHandle as PyDGLArrayHandle\nfrom cpython cimport PyCapsule_Destructor\n\ncdef const char* _c_str_dltensor = \"dltensor\"\ncdef const char* _c_str_used_dltensor = \"used_dltensor\"\n\n\ncdef _c_dlpack_deleter(object pycaps):\n    cdef DLManagedTensor* dltensor\n    if pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor):\n        dltensor = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(pycaps, _c_str_dltensor)\n        DGLDLManagedTensorCallDeleter(dltensor)\n\n\ndef _from_dlpack(object dltensor):\n    cdef DLManagedTensor* ptr\n    cdef DGLArrayHandle chandle\n    if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):\n        ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)\n        CALL(DGLArrayFromDLPack(ptr, &chandle))\n        # set name and destructor to be empty\n        pycapsule.PyCapsule_SetDestructor(dltensor, NULL)\n        pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor)\n        return c_make_array(chandle, 0)\n    raise ValueError(\"Expect a dltensor field, pycapsule.PyCapsule can only be consumed once\")\n\n\ncdef class NDArrayBase:\n    cdef DGLArray* chandle\n    cdef int c_is_view\n\n    cdef inline _set_handle(self, handle):\n        cdef unsigned long long ptr\n        if handle is None:\n            self.chandle = NULL\n        else:\n            ptr = ctypes.cast(handle, ctypes.c_void_p).value\n            self.chandle = <DGLArray*>(ptr)\n\n    property _dgl_handle:\n        def __get__(self):\n            return <unsigned long long>self.chandle\n\n    property handle:\n        def __get__(self):\n            if self.chandle == NULL:\n                return None\n            else:\n                return ctypes.cast(\n                    <unsigned long long>self.chandle, PyDGLArrayHandle)\n\n        def __set__(self, value):\n            self._set_handle(value)\n\n    def __init__(self, handle, is_view):\n        self._set_handle(handle)\n  
      self.c_is_view = is_view\n\n    def __dealloc__(self):\n        if self.c_is_view == 0:\n            CALL(DGLArrayFree(self.chandle))\n\n    def to_dlpack(self, alignment=0):\n        \"\"\"Produce an array from a DLPack Tensor without copying memory\n\n        Args\n        -------\n        alignment: int, default to be 0\n        Indicates the alignment requirement when converting to dlpack. Will copy to a \n        new tensor if the alignment requirement is not satisfied. \n        0 means no alignment requirement.\n        \n        Returns\n        -------\n        dlpack : DLPack tensor view of the array data\n        \"\"\"\n        cdef DLManagedTensor* dltensor\n        if self.c_is_view != 0:\n            raise ValueError(\"to_dlpack do not work with memory views\")\n        CALL(DGLArrayToDLPack(self.chandle, &dltensor, alignment))\n        return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, <PyCapsule_Destructor>_c_dlpack_deleter)\n\n\ncdef c_make_array(void* chandle, is_view):\n    ret = _CLASS_NDARRAY(None, is_view)\n    (<NDArrayBase>ret).chandle = <DGLArray*>chandle\n    return ret\n\n\ncdef _DGL_COMPATS = ()\n\ncdef _DGL_EXT_RET = {}\n\ndef _reg_extension(cls, fcreate):\n    global _DGL_COMPATS\n    _DGL_COMPATS += (cls,)\n    if fcreate:\n        _DGL_EXT_RET[cls._dgl_tcode] = fcreate\n\n\ndef _make_array(handle, is_view):\n    cdef unsigned long long ptr\n    ptr = ctypes.cast(handle, ctypes.c_void_p).value\n    return c_make_array(<void*>ptr, is_view)\n\ncdef object _CLASS_NDARRAY = None\n\ndef _set_class_ndarray(cls):\n    global _CLASS_NDARRAY\n    _CLASS_NDARRAY = cls\n"
  },
  {
    "path": "python/dgl/_ffi/_cython/object.pxi",
    "content": "from ... import _api_internal\nfrom ..base import string_types\nfrom ..object_generic import _set_class_object_base\n\n\"\"\"Maps object type to its constructor\"\"\"\nOBJECT_TYPE = []\n\ndef _register_object(int index, object cls):\n    \"\"\"register object class\"\"\"\n    while len(OBJECT_TYPE) <= index:\n        OBJECT_TYPE.append(None)\n    OBJECT_TYPE[index] = cls\n\n\ncdef inline object make_ret_object(void* chandle):\n    global OBJECT_TYPE\n    cdef int tindex\n    cdef list object_type\n    cdef object cls\n    object_type = OBJECT_TYPE\n    CALL(DGLObjectGetTypeIndex(chandle, &tindex))\n    if tindex < len(object_type):\n        cls = object_type[tindex]\n        if cls is not None:\n            obj = cls.__new__(cls)\n        else:\n            obj = ObjectBase.__new__(ObjectBase)\n    else:\n        obj = ObjectBase.__new__(ObjectBase)\n    (<ObjectBase>obj).chandle = chandle\n    return obj\n\n\ncdef class ObjectBase:\n    cdef void* chandle\n\n    cdef _set_handle(self, handle):\n        cdef unsigned long long ptr\n        if handle is None:\n            self.chandle = NULL\n        else:\n            ptr = handle.value\n            self.chandle = <void*>(ptr)\n\n    property handle:\n        def __get__(self):\n            if self.chandle == NULL:\n                return None\n            else:\n                return ctypes_handle(self.chandle)\n\n        def __set__(self, value):\n            self._set_handle(value)\n\n    def __dealloc__(self):\n        CALL(DGLObjectFree(self.chandle))\n\n    def __getattr__(self, name):\n        cdef DGLValue ret_val\n        cdef int ret_type_code, ret_succ\n        CALL(DGLObjectGetAttr(self.chandle, c_str(name),\n                            &ret_val, &ret_type_code, &ret_succ))\n        if ret_succ == 0:\n            raise AttributeError(\n                \"'%s' object has no attribute '%s'\" % (type(self), name))\n        return make_ret(ret_val, ret_type_code)\n\n    def 
__init_handle_by_constructor__(self, fconstructor, *args):\n        \"\"\"Initialize the handle by calling constructor function.\n\n        Parameters\n        ----------\n        fconstructor : Function\n            Constructor function.\n\n        args: list of objects\n            The arguments to the constructor\n\n        Note\n        ----\n        We have a special calling convention to call constructor functions.\n        So the return handle is directly set into the Object object\n        instead of creating a new Object.\n        \"\"\"\n        cdef void* chandle\n        ConstructorCall(\n            (<FunctionBase>fconstructor).chandle,\n            kObjectHandle, args, &chandle)\n        self.chandle = chandle\n\n_set_class_object_base(ObjectBase)\n"
  },
  {
    "path": "python/dgl/_ffi/base.py",
    "content": "# coding: utf-8\n# pylint: disable=invalid-name\n\"\"\"ctypes library and helper functions \"\"\"\nfrom __future__ import absolute_import\n\nimport ctypes\nimport logging\nimport os\nimport sys\n\nimport numpy as np\n\nfrom . import libinfo\n\n# ----------------------------\n# library loading\n# ----------------------------\nif sys.version_info[0] == 3:\n    string_types = (str,)\n    numeric_types = (float, int, np.float32, np.int32)\n    # this function is needed for python3\n    # to convert ctypes.char_p .value back to python str\n    py_str = lambda x: x.decode(\"utf-8\")\nelse:\n    string_types = (basestring,)\n    numeric_types = (float, int, long, np.float32, np.int32)\n    py_str = lambda x: x\n\n\nclass DGLError(Exception):\n    \"\"\"Error thrown by DGL function\"\"\"\n\n    pass  # pylint: disable=unnecessary-pass\n\n\ndef _load_lib():\n    \"\"\"Load library by searching possible paths.\"\"\"\n    lib_path = libinfo.find_lib_path()\n    lib = ctypes.CDLL(lib_path[0])\n    dirname = os.path.dirname(lib_path[0])\n    basename = os.path.basename(lib_path[0])\n    # DGLGetLastError returns a C string (char *)\n    lib.DGLGetLastError.restype = ctypes.c_char_p\n    return lib, basename, dirname\n\n\n# version number\n__version__ = libinfo.__version__\n# library instance of dgl\n_LIB, _LIB_NAME, _DIR_NAME = _load_lib()\n\n# The FFI mode of DGL\n_FFI_MODE = os.environ.get(\"DGL_FFI\", \"auto\")\n\n# ----------------------------\n# helper function in ctypes.\n# ----------------------------\ndef check_call(ret):\n    \"\"\"Check the return value of C API call\n\n    This function will raise exception when error occurs.\n    Wrap every API call with this function\n\n    Parameters\n    ----------\n    ret : int\n        return value from API calls\n    \"\"\"\n    if ret != 0:\n        raise DGLError(py_str(_LIB.DGLGetLastError()))\n\n\ndef c_str(string):\n    \"\"\"Create ctypes char * from a python string\n    Parameters\n    ----------\n    string : string type\n       
 python string\n\n    Returns\n    -------\n    str : c_char_p\n        A char pointer that can be passed to C API\n    \"\"\"\n    return ctypes.c_char_p(string.encode(\"utf-8\"))\n\n\ndef c_array(ctype, values):\n    \"\"\"Create ctypes array from a python array\n\n    Parameters\n    ----------\n    ctype : ctypes data type\n        data type of the array we want to convert to\n\n    values : tuple or list\n        data content\n\n    Returns\n    -------\n    out : ctypes array\n        Created ctypes array\n    \"\"\"\n    return (ctype * len(values))(*values)\n\n\ndef decorate(func, fwrapped):\n    \"\"\"A wrapper call of decorator package, deferring its import to call time\n\n    Parameters\n    ----------\n    func : function\n        The original function\n\n    fwrapped : function\n        The wrapped function\n    \"\"\"\n    import decorator\n\n    return decorator.decorate(func, fwrapped)\n\n\ntensor_adapter_loaded = False\n\n\ndef load_tensor_adapter(backend, version):\n    \"\"\"Tell DGL to load a tensoradapter library for given backend and version.\n\n    Parameters\n    ----------\n    backend : str\n        The backend (currently ``pytorch``, ``mxnet`` or ``tensorflow``).\n    version : str\n        The version number of the backend.\n    \"\"\"\n    global tensor_adapter_loaded\n    version = version.split(\"+\")[0]\n    if sys.platform.startswith(\"linux\"):\n        basename = \"libtensoradapter_%s_%s.so\" % (backend, version)\n    elif sys.platform.startswith(\"darwin\"):\n        basename = \"libtensoradapter_%s_%s.dylib\" % (backend, version)\n    elif sys.platform.startswith(\"win\"):\n        basename = \"tensoradapter_%s_%s.dll\" % (backend, version)\n    else:\n        raise NotImplementedError(\"Unsupported system: %s\" % sys.platform)\n    path = os.path.join(_DIR_NAME, \"tensoradapter\", backend, basename)\n    tensor_adapter_loaded = _LIB.DGLLoadTensorAdapter(path.encode(\"utf-8\")) == 0\n    if not tensor_adapter_loaded:\n        logger = 
logging.getLogger(\"dgl-core\")\n        logger.debug(\"Memory optimization with PyTorch is not enabled.\")\n\n\ndef is_tensor_adaptor_enabled() -> bool:\n    \"\"\"Check whether TensorAdaptor is enabled.\"\"\"\n    return tensor_adapter_loaded\n"
  },
  {
    "path": "python/dgl/_ffi/capi.py",
    "content": "\"\"\"Init all C APIs in the default namespace.\"\"\"\nfrom .function import _init_api\n\n__all__ = _init_api(\"dgl.capi\", __name__)\n"
  },
  {
    "path": "python/dgl/_ffi/function.py",
    "content": "# pylint: disable=invalid-name, unused-import\n\"\"\"Function namespace.\"\"\"\nfrom __future__ import absolute_import\n\nimport ctypes\nimport sys\n\nfrom .base import _FFI_MODE, _LIB, c_str, check_call, py_str, string_types\n\nIMPORT_EXCEPT = RuntimeError if _FFI_MODE == \"cython\" else ImportError\n\ntry:\n    # pylint: disable=wrong-import-position\n    if _FFI_MODE == \"ctypes\":\n        raise ImportError()\n    if sys.version_info >= (3, 0):\n        from ._cy3.core import (\n            _set_class_function,\n            _set_class_module,\n            convert_to_dgl_func,\n            FunctionBase as _FunctionBase,\n        )\n    else:\n        from ._cy2.core import (\n            _set_class_function,\n            _set_class_module,\n            convert_to_dgl_func,\n            FunctionBase as _FunctionBase,\n        )\nexcept IMPORT_EXCEPT:\n    # pylint: disable=wrong-import-position\n    from ._ctypes.function import (\n        _set_class_function,\n        _set_class_module,\n        convert_to_dgl_func,\n        FunctionBase as _FunctionBase,\n    )\n\nFunctionHandle = ctypes.c_void_p\n\n\nclass Function(_FunctionBase):\n    \"\"\"The PackedFunc object.\n\n    Function plays a key role in bridging the frontend and backend in DGL.\n    Function provides a type-erased interface; you can call a function with positional arguments.\n\n    The compiled module returns Function.\n    DGL backend also registers and exposes its API as Functions.\n    For example, the developer functions exposed in dgl.ir_pass are actually\n    C++ functions that are registered as PackedFunc\n\n    The following is a list of common usage scenarios of dgl.Function.\n\n    - Automatic exposure of C++ API into python\n    - To call PackedFunc from python side\n    - To call python callbacks to inspect results in generated code\n    - Bring python hook into C++ backend\n\n    See Also\n    --------\n    dgl.register_func: How to register global function.\n    
dgl.get_global_func: How to get global function.\n    \"\"\"\n\n    pass  # pylint: disable=unnecessary-pass\n\n\nclass ModuleBase(object):\n    \"\"\"Base class for module\"\"\"\n\n    __slots__ = [\"handle\", \"_entry\", \"entry_name\"]\n\n    def __init__(self, handle):\n        self.handle = handle\n        self._entry = None\n        self.entry_name = \"__dgl_main__\"\n\n    def __del__(self):\n        check_call(_LIB.DGLModFree(self.handle))\n\n    @property\n    def entry_func(self):\n        \"\"\"Get the entry function\n\n        Returns\n        -------\n        f : Function\n            The entry function if exist\n        \"\"\"\n        if self._entry:\n            return self._entry\n        self._entry = self.get_function(self.entry_name)\n        return self._entry\n\n    def get_function(self, name, query_imports=False):\n        \"\"\"Get function from the module.\n\n        Parameters\n        ----------\n        name : str\n            The name of the function\n\n        query_imports : bool\n            Whether also query modules imported by this module.\n\n        Returns\n        -------\n        f : Function\n            The result function.\n        \"\"\"\n        ret_handle = FunctionHandle()\n        check_call(\n            _LIB.DGLModGetFunction(\n                self.handle,\n                c_str(name),\n                ctypes.c_int(query_imports),\n                ctypes.byref(ret_handle),\n            )\n        )\n        if not ret_handle.value:\n            raise AttributeError(\"Module has no function '%s'\" % name)\n        return Function(ret_handle, False)\n\n    def import_module(self, module):\n        \"\"\"Add module to the import list of current one.\n\n        Parameters\n        ----------\n        module : Module\n            The other module.\n        \"\"\"\n        check_call(_LIB.DGLModImport(self.handle, module.handle))\n\n    def __getitem__(self, name):\n        if not isinstance(name, string_types):\n         
   raise ValueError(\"Can only take string as function name\")\n        return self.get_function(name)\n\n    def __call__(self, *args):\n        if self._entry:\n            return self._entry(*args)\n        f = self.entry_func\n        return f(*args)\n\n\ndef register_func(func_name, f=None, override=False):\n    \"\"\"Register global function\n\n    Parameters\n    ----------\n    func_name : str or function\n        The function name\n\n    f : function, optional\n        The function to be registered.\n\n    override: boolean optional\n        Whether override existing entry.\n\n    Returns\n    -------\n    fregister : function\n        Register function if f is not specified.\n\n    Examples\n    --------\n    The following code registers my_packed_func as global function.\n    Note that we simply get it back from global function table to invoke\n    it from python side. However, we can also invoke the same function\n    from C++ backend, or in the compiled DGL code.\n\n    .. code-block:: python\n\n      targs = (10, 10.0, \"hello\")\n      @dgl.register_func\n      def my_packed_func(*args):\n          assert(tuple(args) == targs)\n          return 10\n      # Get it out from global function table\n      f = dgl.get_global_func(\"my_packed_func\")\n      assert isinstance(f, dgl.nd.Function)\n      y = f(*targs)\n      assert y == 10\n    \"\"\"\n    if callable(func_name):\n        f = func_name\n        func_name = f.__name__\n\n    if not isinstance(func_name, str):\n        raise ValueError(\"expect string function name\")\n\n    ioverride = ctypes.c_int(override)\n\n    def register(myf):\n        \"\"\"internal register function\"\"\"\n        if not isinstance(myf, Function):\n            myf = convert_to_dgl_func(myf)\n        check_call(\n            _LIB.DGLFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride)\n        )\n        return myf\n\n    if f:\n        return register(f)\n    return register\n\n\ndef get_global_func(name, 
allow_missing=False):\n    \"\"\"Get a global function by name\n\n    Parameters\n    ----------\n    name : str\n        The name of the global function\n\n    allow_missing : bool\n        Whether allow missing function or raise an error.\n\n    Returns\n    -------\n    func : dgl.Function\n        The function to be returned, None if function is missing.\n    \"\"\"\n    handle = FunctionHandle()\n    check_call(_LIB.DGLFuncGetGlobal(c_str(name), ctypes.byref(handle)))\n    if handle.value:\n        return Function(handle, False)\n    else:\n        if allow_missing:\n            return None\n        else:\n            raise ValueError(\"Cannot find global function %s\" % name)\n\n\ndef list_global_func_names():\n    \"\"\"Get list of global functions registered.\n\n    Returns\n    -------\n    names : list\n       List of global functions names.\n    \"\"\"\n    plist = ctypes.POINTER(ctypes.c_char_p)()\n    size = ctypes.c_uint()\n\n    check_call(\n        _LIB.DGLFuncListGlobalNames(ctypes.byref(size), ctypes.byref(plist))\n    )\n    fnames = []\n    for i in range(size.value):\n        fnames.append(py_str(plist[i]))\n    return fnames\n\n\ndef extract_ext_funcs(finit):\n    \"\"\"\n    Extract the extension PackedFuncs from a C module.\n\n    Parameters\n    ----------\n    finit : ctypes function\n        a ctypes that takes signature of DGLExtensionDeclarer\n\n    Returns\n    -------\n    fdict : dict of str to Function\n        The extracted functions\n    \"\"\"\n    fdict = {}\n\n    def _list(name, func):\n        fdict[name] = func\n\n    myf = convert_to_dgl_func(_list)\n    ret = finit(myf.handle)\n    _ = myf\n    if ret != 0:\n        raise RuntimeError(\"cannot initialize with %s\" % finit)\n    return fdict\n\n\ndef _get_api(f):\n    flocal = f\n    flocal.is_global = True\n    return flocal\n\n\ndef _init_api(namespace, target_module_name=None):\n    \"\"\"Initialize api for a given module name\n\n    namespace : str\n       The namespace 
of the source registry\n\n    target_module_name : str\n       The target module name if different from namespace\n    \"\"\"\n    target_module_name = target_module_name if target_module_name else namespace\n    if namespace.startswith(\"dgl.\"):\n        return _init_api_prefix(target_module_name, namespace[4:])\n    else:\n        return _init_api_prefix(target_module_name, namespace)\n\n\ndef _init_api_prefix(module_name, prefix):\n    module = sys.modules[module_name]\n    name_list = []\n\n    for name in list_global_func_names():\n        if name.startswith(\"_\") and not name.startswith(\"_deprecate\"):\n            # internal APIs are ignored\n            continue\n        name_split = name.rsplit(\".\", 1)\n        if name_split[0] != prefix:\n            continue\n\n        if len(name_split) == 1:\n            print('Warning: invalid API name \"%s\".' % name)\n            continue\n        fname = name_split[1]\n        target_module = module\n\n        f = get_global_func(name)\n        ff = _get_api(f)\n        ff.__name__ = fname\n        ff.__doc__ = \"DGL PackedFunc %s. \" % fname\n        setattr(target_module, ff.__name__, ff)\n        name_list.append(fname)\n\n    return name_list\n\n\ndef _init_internal_api():\n    for name in list_global_func_names():\n        if not name.startswith(\"_\") or name.startswith(\"_deprecate\"):\n            # normal APIs are ignored\n            continue\n        target_module = sys.modules[\"dgl._api_internal\"]\n        fname = name\n        if fname.find(\".\") != -1:\n            print('Warning: invalid API name \"%s\".' % fname)\n            continue\n        f = get_global_func(name)\n        ff = _get_api(f)\n        ff.__name__ = fname\n        ff.__doc__ = \"DGL PackedFunc %s. \" % fname\n        setattr(target_module, ff.__name__, ff)\n\n\n_set_class_function(Function)\n"
  },
  {
    "path": "python/dgl/_ffi/libinfo.py",
    "content": "\"\"\"Library information.\"\"\"\nfrom __future__ import absolute_import\n\nimport os\nimport pathlib\nimport sys\n\n\ndef find_lib_path(name=None, search_path=None, optional=False):\n    \"\"\"Find dynamic library files.\n\n    Parameters\n    ----------\n    name : list of str\n        List of names to be found.\n\n    Returns\n    -------\n    lib_path : list(string)\n        List of all found path to the libraries\n    \"\"\"\n    # See https://github.com/dmlc/tvm/issues/281 for some background.\n\n    # NB: This will either be the source directory (if DGL is run\n    # inplace) or the install directory (if DGL is installed).\n    # An installed DGL's curr_path will look something like:\n    #   $PREFIX/lib/python3.6/site-packages/dgl/_ffi\n    ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))\n    source_dir = os.path.join(ffi_dir, \"..\", \"..\", \"..\")\n    install_lib_dir = os.path.join(ffi_dir, \"..\", \"..\", \"..\", \"..\")\n\n    dll_path = []\n\n    if os.environ.get(\"DGL_LIBRARY_PATH\", None):\n        dll_path.append(os.environ[\"DGL_LIBRARY_PATH\"])\n\n    if sys.platform.startswith(\"linux\") and os.environ.get(\n        \"LD_LIBRARY_PATH\", None\n    ):\n        dll_path.extend(\n            [p.strip() for p in os.environ[\"LD_LIBRARY_PATH\"].split(\":\")]\n        )\n    elif sys.platform.startswith(\"darwin\") and os.environ.get(\n        \"DYLD_LIBRARY_PATH\", None\n    ):\n        dll_path.extend(\n            [p.strip() for p in os.environ[\"DYLD_LIBRARY_PATH\"].split(\":\")]\n        )\n\n    # Pip lib directory\n    dll_path.append(os.path.join(ffi_dir, \"..\"))\n    # Default cmake build directory\n    dll_path.append(os.path.join(source_dir, \"build\"))\n    dll_path.append(os.path.join(source_dir, \"build\", \"Release\"))\n    # Default make build directory\n    dll_path.append(os.path.join(source_dir, \"lib\"))\n\n    dll_path.append(install_lib_dir)\n\n    if search_path is not None:\n        if 
isinstance(search_path, (list, tuple, set)):\n            dll_path = dll_path + list(search_path)\n        elif isinstance(search_path, str):\n            dll_path.append(search_path)\n        else:\n            raise ValueError(\n                \"type(search_path)={} is invalid\".format(type(search_path))\n            )\n    dll_path = [\n        str(x.absolute()) if isinstance(x, pathlib.Path) else os.path.abspath(x)\n        for x in dll_path\n    ]\n\n    if name is None:\n        if sys.platform.startswith(\"win32\"):\n            name = [\"libdgl.dll\", \"dgl.dll\"]\n        elif sys.platform.startswith(\"darwin\"):\n            name = \"libdgl.dylib\"\n        else:\n            name = \"libdgl.so\"\n\n    if isinstance(name, str):\n        name = [name]\n\n    lib_dll_path = []\n    for n in name:\n        lib_dll_path += [os.path.join(p, n) for p in dll_path]\n\n    lib_found = [p for p in lib_dll_path if os.path.isfile(p)]\n\n    if not lib_found:\n        message = (\n            \"Cannot find the files.\\n\"\n            + \"List of candidates:\\n\"\n            + str(\"\\n\".join(lib_dll_path))\n        )\n        if not optional:\n            raise RuntimeError(message)\n        return None\n\n    return lib_found\n\n\n# current version\n# We use the version of the incoming release for code\n# that is under development.\n# The following line is set by dgl/python/update_version.py\n__version__ = \"2.5\"\n"
  },
  {
    "path": "python/dgl/_ffi/ndarray.py",
    "content": "# pylint: disable=invalid-name, unused-import\n\"\"\"Runtime NDArray api\"\"\"\nfrom __future__ import absolute_import\n\nimport ctypes\nimport sys\n\nimport numpy as np\n\nfrom .base import _FFI_MODE, _LIB, c_array, c_str, check_call, string_types\nfrom .runtime_ctypes import (\n    dgl_shape_index_t,\n    DGLArray,\n    DGLArrayHandle,\n    DGLContext,\n    DGLDataType,\n    TypeCode,\n)\n\nIMPORT_EXCEPT = RuntimeError if _FFI_MODE == \"cython\" else ImportError\n\ntry:\n    # pylint: disable=wrong-import-position\n    if _FFI_MODE == \"ctypes\":\n        raise ImportError()\n    if sys.version_info >= (3, 0):\n        from ._cy3.core import (\n            _from_dlpack,\n            _make_array,\n            _reg_extension,\n            _set_class_ndarray,\n            NDArrayBase as _NDArrayBase,\n        )\n    else:\n        from ._cy2.core import (\n            _from_dlpack,\n            _make_array,\n            _reg_extension,\n            _set_class_ndarray,\n            NDArrayBase as _NDArrayBase,\n        )\nexcept IMPORT_EXCEPT:\n    # pylint: disable=wrong-import-position\n    from ._ctypes.ndarray import (\n        _from_dlpack,\n        _make_array,\n        _reg_extension,\n        _set_class_ndarray,\n        NDArrayBase as _NDArrayBase,\n    )\n\n\ndef context(dev_type, dev_id=0):\n    \"\"\"Construct a DGL context with given device type and id.\n\n    Parameters\n    ----------\n    dev_type: int or str\n        The device type mask or name of the device.\n\n    dev_id : int, optional\n        The integer device id\n\n    Returns\n    -------\n    ctx: DGLContext\n        The corresponding context.\n\n    Examples\n    --------\n    Context can be used to create reflection of context by\n    string representation of the device type.\n\n    .. 
code-block:: python\n\n      assert dgl.context(\"cpu\", 1) == dgl.cpu(1)\n      assert dgl.context(\"gpu\", 0) == dgl.gpu(0)\n      assert dgl.context(\"cuda\", 0) == dgl.gpu(0)\n    \"\"\"\n    if isinstance(dev_type, string_types):\n        dev_type = dev_type.split()[0]\n        if dev_type not in DGLContext.STR2MASK:\n            raise ValueError(\"Unknown device type %s\" % dev_type)\n        dev_type = DGLContext.STR2MASK[dev_type]\n    return DGLContext(dev_type, dev_id)\n\n\ndef numpyasarray(np_data):\n    \"\"\"Return a DGLArray representation of a numpy array.\"\"\"\n    data = np_data\n    assert data.flags[\"C_CONTIGUOUS\"]\n    arr = DGLArray()\n    shape = c_array(dgl_shape_index_t, data.shape)\n    arr.data = data.ctypes.data_as(ctypes.c_void_p)\n    arr.shape = shape\n    arr.strides = None\n    arr.dtype = DGLDataType(np.dtype(data.dtype).name)\n    arr.ndim = data.ndim\n    # CPU device\n    arr.ctx = context(1, 0)\n    return arr, shape\n\n\ndef empty(shape, dtype=\"float32\", ctx=context(1, 0)):\n    \"\"\"Create an empty array given shape and device\n\n    Parameters\n    ----------\n    shape : tuple of int\n        The shape of the array\n\n    dtype : type or str\n        The data type of the array.\n\n    ctx : DGLContext\n        The context of the array\n\n    Returns\n    -------\n    arr : dgl.nd.NDArray\n        The array dgl supported.\n    \"\"\"\n    shape = c_array(dgl_shape_index_t, shape)\n    ndim = ctypes.c_int(len(shape))\n    handle = DGLArrayHandle()\n    dtype = DGLDataType(dtype)\n    check_call(\n        _LIB.DGLArrayAlloc(\n            shape,\n            ndim,\n            ctypes.c_int(dtype.type_code),\n            ctypes.c_int(dtype.bits),\n            ctypes.c_int(dtype.lanes),\n            ctx.device_type,\n            ctx.device_id,\n            ctypes.byref(handle),\n        )\n    )\n    return _make_array(handle, False)\n\n\ndef empty_shared_mem(name, is_create, shape, dtype=\"float32\"):\n    \"\"\"Create an 
empty array with shared memory given shape and dtype\n\n    Parameters\n    ----------\n    name : string\n        The name of the shared memory. It's a file name in Unix.\n\n    is_create : bool\n        Whether to create the shared memory or use the one created by somewhere else.\n\n    shape : tuple of int\n        The shape of the array\n\n    dtype : type or str\n        The data type of the array.\n\n    Returns\n    -------\n    arr : dgl.nd.NDArray\n        The array dgl supported.\n    \"\"\"\n    name = ctypes.c_char_p(name.encode(\"utf-8\"))\n    shape = c_array(dgl_shape_index_t, shape)\n    ndim = ctypes.c_int(len(shape))\n    handle = DGLArrayHandle()\n    dtype = DGLDataType(dtype)\n    check_call(\n        _LIB.DGLArrayAllocSharedMem(\n            name,\n            shape,\n            ndim,\n            ctypes.c_int(dtype.type_code),\n            ctypes.c_int(dtype.bits),\n            ctypes.c_int(dtype.lanes),\n            is_create,\n            ctypes.byref(handle),\n        )\n    )\n    return _make_array(handle, False)\n\n\ndef from_dlpack(dltensor):\n    \"\"\"Produce an array from a DLPack tensor without memory copy.\n    Retrieves the underlying DLPack tensor's pointer to create an array from the\n    data. 
Removes the original DLPack tensor's destructor as now the array is\n    responsible for destruction.\n\n    Parameters\n    ----------\n    dltensor : DLPack tensor\n        Input DLManagedTensor, can only be consumed once.\n\n    Returns\n    -------\n    arr: dgl.nd.NDArray\n        The array view of the tensor data.\n    \"\"\"\n    return _from_dlpack(dltensor)\n\n\nclass NDArrayBase(_NDArrayBase):\n    \"\"\"A simple Device/CPU Array object in runtime.\"\"\"\n\n    @property\n    def shape(self):\n        \"\"\"Shape of this array\"\"\"\n        return tuple(\n            self.handle.contents.shape[i]\n            for i in range(self.handle.contents.ndim)\n        )\n\n    @property\n    def dtype(self):\n        \"\"\"Type of this array\"\"\"\n        return str(self.handle.contents.dtype)\n\n    @property\n    def ctx(self):\n        \"\"\"context of this array\"\"\"\n        return self.handle.contents.ctx\n\n    @property\n    def context(self):\n        \"\"\"context of this array\"\"\"\n        return self.ctx\n\n    def __hash__(self):\n        return ctypes.cast(self.handle, ctypes.c_void_p).value\n\n    def __eq__(self, other):\n        return self.same_as(other)\n\n    def __ne__(self, other):\n        return not self.__eq__(other)\n\n    def same_as(self, other):\n        \"\"\"Check object identity equality\n\n        Parameters\n        ----------\n        other : object\n            The other object to compare to\n\n        Returns\n        -------\n        same : bool\n            Whether other is same as self.\n        \"\"\"\n        if not isinstance(other, NDArrayBase):\n            return False\n        return self.__hash__() == other.__hash__()\n\n    def __setitem__(self, in_slice, value):\n        \"\"\"Set ndarray value\"\"\"\n        if (\n            not isinstance(in_slice, slice)\n            or in_slice.start is not None\n            or in_slice.stop is not None\n        ):\n            raise ValueError(\"Array only support set 
from numpy array\")\n        if isinstance(value, NDArrayBase):\n            if value.handle is not self.handle:\n                value.copyto(self)\n        elif isinstance(value, (np.ndarray, np.generic)):\n            self.copyfrom(value)\n        else:\n            raise TypeError(\"type %s not supported\" % str(type(value)))\n\n    def copyfrom(self, source_array):\n        \"\"\"Perform a synchronized copy from the array.\n\n        Parameters\n        ----------\n        source_array : array_like\n            The data source we should like to copy from.\n\n        Returns\n        -------\n        arr : NDArray\n            Reference to self.\n        \"\"\"\n        if isinstance(source_array, NDArrayBase):\n            source_array.copyto(self)\n            return self\n\n        if not isinstance(source_array, np.ndarray):\n            try:\n                source_array = np.asarray(source_array, dtype=self.dtype)\n            except:\n                raise TypeError(\n                    \"array must be an array_like data,\"\n                    + \"type %s is not supported\" % str(type(source_array))\n                )\n        t = DGLDataType(self.dtype)\n        shape, dtype = self.shape, self.dtype\n        if t.lanes > 1:\n            shape = shape + (t.lanes,)\n            t.lanes = 1\n            dtype = str(t)\n\n        if source_array.shape != shape:\n            raise ValueError(\n                \"array shape do not match the shape of NDArray {0} vs {1}\".format(\n                    source_array.shape, shape\n                )\n            )\n        source_array = np.ascontiguousarray(source_array, dtype=dtype)\n        assert source_array.flags[\"C_CONTIGUOUS\"]\n        data = source_array.ctypes.data_as(ctypes.c_void_p)\n        nbytes = ctypes.c_size_t(\n            source_array.size * source_array.dtype.itemsize\n        )\n        check_call(_LIB.DGLArrayCopyFromBytes(self.handle, data, nbytes))\n        return self\n\n    def 
__repr__(self):\n        res = \"dgl.{0}@{1}\".format(self.asnumpy().__repr__(), self.context)\n        return res\n\n    def __str__(self):\n        return str(self.asnumpy())\n\n    def asnumpy(self):\n        \"\"\"Convert this array to numpy array\n\n        Returns\n        -------\n        np_arr : numpy.ndarray\n            The corresponding numpy array.\n        \"\"\"\n        t = DGLDataType(self.dtype)\n        shape, dtype = self.shape, self.dtype\n        if t.lanes > 1:\n            shape = shape + (t.lanes,)\n            t.lanes = 1\n            dtype = str(t)\n        np_arr = np.empty(shape, dtype=dtype)\n        assert np_arr.flags[\"C_CONTIGUOUS\"]\n        data = np_arr.ctypes.data_as(ctypes.c_void_p)\n        nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)\n        check_call(_LIB.DGLArrayCopyToBytes(self.handle, data, nbytes))\n        return np_arr\n\n    def copyto(self, target):\n        \"\"\"Copy array to target\n\n        Parameters\n        ----------\n        target : NDArray\n            The target array to be copied, must have same shape as this array.\n        \"\"\"\n        if isinstance(target, DGLContext):\n            target = empty(self.shape, self.dtype, target)\n        if isinstance(target, NDArrayBase):\n            check_call(_LIB.DGLArrayCopyFromTo(self.handle, target.handle))\n        else:\n            raise ValueError(\"Unsupported target type %s\" % str(type(target)))\n        return target\n\n    def pin_memory_(self):\n        \"\"\"Pin host memory and map into GPU address space (in-place)\"\"\"\n        check_call(_LIB.DGLArrayPinData(self.handle))\n\n    def unpin_memory_(self):\n        \"\"\"Unpin host memory pinned by pin_memory_()\"\"\"\n        check_call(_LIB.DGLArrayUnpinData(self.handle))\n\n    def record_stream(self, stream):\n        \"\"\"Record the stream that is using this tensor.\n\n        Note\n        ----\n        This API is more for testing. 
Users should call ``record_stream``\n        on torch.Tensor or dgl.graph directly.\n\n        Parameters\n        ----------\n        stream : DGLStreamHandle\n        \"\"\"\n        check_call(_LIB.DGLArrayRecordStream(self.handle, stream))\n\n\ndef free_extension_handle(handle, type_code):\n    \"\"\"Free c++ extension type handle\n\n    Parameters\n    ----------\n    handle : ctypes.c_void_p\n        The handle to the extension type.\n\n    type_code : int\n         The type code\n    \"\"\"\n    check_call(_LIB.DGLExtTypeFree(handle, ctypes.c_int(type_code)))\n\n\ndef register_extension(cls, fcreate=None):\n    \"\"\"Register an extension class to DGL.\n\n    After the class is registered, the class will be able\n    to directly pass as Function argument generated by DGL.\n\n    Parameters\n    ----------\n    cls : class\n        The class object to be registered as extension.\n\n    Note\n    ----\n    The registered class requires one property: _dgl_handle and a class attribute _dgl_tcode.\n\n    - ```_dgl_handle``` returns an integer representing the address of the handle.\n    - ```_dgl_tcode``` gives an integer representing the type code of the class.\n\n    Returns\n    -------\n    cls : class\n        The class being registered.\n\n    fcreate : function, optional\n        The creation function to create a class object given handle value.\n\n    Example\n    -------\n    The following code registers user defined class\n    MyTensor to be DLTensor compatible.\n\n    .. 
code-block:: python\n\n       @dgl.register_extension\n       class MyTensor(object):\n           _dgl_tcode = dgl.TypeCode.ARRAY_HANDLE\n\n           def __init__(self):\n               self.handle = _LIB.NewDLTensor()\n\n           @property\n           def _dgl_handle(self):\n               return self.handle.value\n    \"\"\"\n    if fcreate and cls._dgl_tcode < TypeCode.EXT_BEGIN:\n        raise ValueError(\n            \"Cannot register create when extension tcode is same as buildin\"\n        )\n    _reg_extension(cls, fcreate)\n    return cls\n"
  },
  {
    "path": "python/dgl/_ffi/object.py",
    "content": "\"\"\"Object namespace\"\"\"\n# pylint: disable=unused-import\nfrom __future__ import absolute_import\n\nimport ctypes\nimport sys\n\nfrom .. import _api_internal\nfrom .base import _FFI_MODE, _LIB, c_str, check_call, py_str\nfrom .object_generic import convert_to_object, ObjectGeneric\n\n# pylint: disable=invalid-name\nIMPORT_EXCEPT = RuntimeError if _FFI_MODE == \"cython\" else ImportError\ntry:\n    # pylint: disable=wrong-import-position\n    if _FFI_MODE == \"ctypes\":\n        raise ImportError()\n    if sys.version_info >= (3, 0):\n        from ._cy3.core import _register_object, ObjectBase as _ObjectBase\n    else:\n        from ._cy2.core import _register_object, ObjectBase as _ObjectBase\nexcept IMPORT_EXCEPT:\n    # pylint: disable=wrong-import-position\n    from ._ctypes.object import _register_object, ObjectBase as _ObjectBase\n\n\ndef _new_object(cls):\n    \"\"\"Helper function for pickle\"\"\"\n    return cls.__new__(cls)\n\n\nclass ObjectBase(_ObjectBase):\n    \"\"\"ObjectBase is the base class of all DGL CAPI object.\n\n    The core attribute is ``handle``, which is a C raw pointer.  
It must be initialized\n    via ``__init_handle_by_constructor__``.\n\n    Note that the same handle **CANNOT** be shared across multiple ObjectBase instances.\n    \"\"\"\n\n    def __dir__(self):\n        plist = ctypes.POINTER(ctypes.c_char_p)()\n        size = ctypes.c_uint()\n        check_call(\n            _LIB.DGLObjectListAttrNames(\n                self.handle, ctypes.byref(size), ctypes.byref(plist)\n            )\n        )\n        names = []\n        for i in range(size.value):\n            names.append(py_str(plist[i]))\n        return names\n\n    def __hash__(self):\n        return _api_internal._raw_ptr(self)\n\n    def __eq__(self, other):\n        return self.same_as(other)\n\n    def __ne__(self, other):\n        return not self.__eq__(other)\n\n    def __reduce__(self):\n        cls = type(self)\n        return (_new_object, (cls,), self.__getstate__())\n\n    def __getstate__(self):\n        # TODO(minjie): TVM assumes that a Node (Object in DGL) can be serialized\n        #   to json. However, this is not true in DGL because DGL Object is meant\n        #   for runtime API, so it could contain binary data such as NDArray.\n        #   If this feature is required, please raise a RFC to DGL issue.\n        raise RuntimeError(\"__getstate__ is not supported for object type\")\n\n    def __setstate__(self, state):\n        # pylint: disable=assigning-non-slot\n        # TODO(minjie): TVM assumes that a Node (Object in DGL) can be serialized\n        #   to json. 
However, this is not true in DGL because DGL Object is meant\n        #   for runtime API, so it could contain binary data such as NDArray.\n        #   If this feature is required, please raise a RFC to DGL issue.\n        raise RuntimeError(\"__setstate__ is not supported for object type\")\n\n    def same_as(self, other):\n        \"\"\"check object identity equality\"\"\"\n        if not isinstance(other, ObjectBase):\n            return False\n        return self.__hash__() == other.__hash__()\n\n\ndef register_object(type_key=None):\n    \"\"\"Decorator used to register object type\n\n    Examples\n    --------\n    >>> @register_object\n    >>> class MyObject:\n    >>> ... pass\n\n    Parameters\n    ----------\n    type_key : str or cls\n        The type key of the object\n    \"\"\"\n    object_name = type_key if isinstance(type_key, str) else type_key.__name__\n\n    def register(cls):\n        \"\"\"internal register function\"\"\"\n        tindex = ctypes.c_int()\n        ret = _LIB.DGLObjectTypeKey2Index(\n            c_str(object_name), ctypes.byref(tindex)\n        )\n        if ret == 0:\n            _register_object(tindex.value, cls)\n        return cls\n\n    if isinstance(type_key, str):\n        return register\n    return register(type_key)\n"
  },
  {
    "path": "python/dgl/_ffi/object_generic.py",
    "content": "\"\"\"Common implementation of Object generic related logic\"\"\"\n# pylint: disable=unused-import\nfrom __future__ import absolute_import\n\nfrom numbers import Integral, Number\n\nfrom .. import _api_internal\nfrom .base import string_types\n\n# Object base class\n_CLASS_OBJECT_BASE = None\n\n\ndef _set_class_object_base(cls):\n    global _CLASS_OBJECT_BASE\n    _CLASS_OBJECT_BASE = cls\n\n\nclass ObjectGeneric(object):\n    \"\"\"Base class for all classes that can be converted to object.\"\"\"\n\n    def asobject(self):\n        \"\"\"Convert value to object\"\"\"\n        raise NotImplementedError()\n\n\ndef convert_to_object(value):\n    \"\"\"Convert a python value to corresponding object type.\n\n    Parameters\n    ----------\n    value : str\n        The value to be inspected.\n\n    Returns\n    -------\n    object : Object\n        The corresponding object value.\n    \"\"\"\n    if isinstance(value, _CLASS_OBJECT_BASE):\n        return value\n    if isinstance(value, (list, tuple)):\n        value = [convert_to_object(x) for x in value]\n        return _api_internal._List(*value)\n    if isinstance(value, dict):\n        vlist = []\n        for item in value.items():\n            if not isinstance(item[0], _CLASS_OBJECT_BASE) and not isinstance(\n                item[0], string_types\n            ):\n                raise ValueError(\n                    \"key of map must already been a container type\"\n                )\n            vlist.append(item[0])\n            vlist.append(convert_to_object(item[1]))\n        return _api_internal._Map(*vlist)\n    if isinstance(value, ObjectGeneric):\n        return value.asobject()\n    return _api_internal._Value(value)\n"
  },
  {
    "path": "python/dgl/_ffi/runtime_ctypes.py",
    "content": "\"\"\"Common runtime ctypes.\"\"\"\n# pylint: disable=invalid-name, super-init-not-called\nfrom __future__ import absolute_import\n\nimport ctypes\nimport json\n\nimport numpy as np\n\nfrom .. import _api_internal\nfrom .base import _LIB, check_call\n\ndgl_shape_index_t = ctypes.c_int64\n\n\nclass TypeCode(object):\n    \"\"\"Type code used in API calls\"\"\"\n\n    INT = 0\n    UINT = 1\n    FLOAT = 2\n    HANDLE = 3\n    NULL = 4\n    DGL_DATA_TYPE = 5\n    DGL_CONTEXT = 6\n    ARRAY_HANDLE = 7\n    OBJECT_HANDLE = 8\n    MODULE_HANDLE = 9\n    FUNC_HANDLE = 10\n    STR = 11\n    BYTES = 12\n    NDARRAY_CONTAINER = 13\n    EXT_BEGIN = 15\n\n\nclass DGLByteArray(ctypes.Structure):\n    \"\"\"Temp data structure for byte array.\"\"\"\n\n    _fields_ = [\n        (\"data\", ctypes.POINTER(ctypes.c_byte)),\n        (\"size\", ctypes.c_size_t),\n    ]\n\n\nclass DGLDataType(ctypes.Structure):\n    \"\"\"DGL datatype structure\"\"\"\n\n    _fields_ = [\n        (\"type_code\", ctypes.c_uint8),\n        (\"bits\", ctypes.c_uint8),\n        (\"lanes\", ctypes.c_uint16),\n    ]\n    CODE2STR = {0: \"int\", 1: \"uint\", 2: \"float\", 4: \"handle\"}\n    _cache = {}\n\n    def __new__(cls, type_str):\n        if type_str in cls._cache:\n            return cls._cache[type_str]\n\n        inst = super(DGLDataType, cls).__new__(DGLDataType)\n\n        if isinstance(type_str, np.dtype):\n            type_str = str(type_str)\n        arr = type_str.split(\"x\")\n        head = arr[0]\n        inst.lanes = int(arr[1]) if len(arr) > 1 else 1\n        bits = 32\n\n        if head.startswith(\"int\"):\n            inst.type_code = 0\n            head = head[3:]\n        elif head.startswith(\"uint\"):\n            inst.type_code = 1\n            head = head[4:]\n        elif head.startswith(\"float\"):\n            inst.type_code = 2\n            head = head[5:]\n        elif head.startswith(\"handle\"):\n            inst.type_code = 4\n            bits = 64\n        
    head = \"\"\n        else:\n            raise ValueError(\"Do not know how to handle type %s\" % type_str)\n        bits = int(head) if head else bits\n        inst.bits = bits\n\n        cls._cache[type_str] = inst\n        return inst\n\n    def __init__(self, type_str):\n        pass\n\n    def __repr__(self):\n        x = \"%s%d\" % (DGLDataType.CODE2STR[self.type_code], self.bits)\n        if self.lanes != 1:\n            x += \"x%d\" % self.lanes\n        return x\n\n    def __eq__(self, other):\n        return (\n            self.bits == other.bits\n            and self.type_code == other.type_code\n            and self.lanes == other.lanes\n        )\n\n    def __ne__(self, other):\n        return not self.__eq__(other)\n\n\nRPC_SESS_MASK = 128\n\n\nclass DGLContext(ctypes.Structure):\n    \"\"\"DGL context strucure.\"\"\"\n\n    _fields_ = [(\"device_type\", ctypes.c_int), (\"device_id\", ctypes.c_int)]\n    MASK2STR = {\n        1: \"cpu\",\n        2: \"gpu\",\n        4: \"opencl\",\n        5: \"aocl\",\n        6: \"sdaccel\",\n        7: \"vulkan\",\n        8: \"metal\",\n        9: \"vpi\",\n        10: \"rocm\",\n        11: \"opengl\",\n        12: \"ext_dev\",\n    }\n    STR2MASK = {\n        \"llvm\": 1,\n        \"stackvm\": 1,\n        \"cpu\": 1,\n        \"gpu\": 2,\n        \"cuda\": 2,\n        \"nvptx\": 2,\n        \"cl\": 4,\n        \"opencl\": 4,\n        \"aocl\": 5,\n        \"aocl_sw_emu\": 5,\n        \"sdaccel\": 6,\n        \"vulkan\": 7,\n        \"metal\": 8,\n        \"vpi\": 9,\n        \"rocm\": 10,\n        \"opengl\": 11,\n        \"ext_dev\": 12,\n    }\n    _cache = {}\n\n    def __new__(cls, device_type, device_id):\n        if (device_type, device_id) in cls._cache:\n            return cls._cache[(device_type, device_id)]\n\n        inst = super(DGLContext, cls).__new__(DGLContext)\n\n        inst.device_type = device_type\n        inst.device_id = device_id\n\n        cls._cache[(device_type, device_id)] = 
inst\n        return inst\n\n    def __init__(self, device_type, device_id):\n        pass\n\n    @property\n    def exist(self):\n        \"\"\"Whether this device exists.\"\"\"\n        return (\n            _api_internal._GetDeviceAttr(self.device_type, self.device_id, 0)\n            != 0\n        )\n\n    @property\n    def max_threads_per_block(self):\n        \"\"\"Maximum number of threads on each block.\"\"\"\n        return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 1)\n\n    @property\n    def warp_size(self):\n        \"\"\"Number of threads that execute concurrently.\"\"\"\n        return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 2)\n\n    @property\n    def max_shared_memory_per_block(self):\n        \"\"\"Total amount of shared memory per block in bytes.\"\"\"\n        return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 3)\n\n    @property\n    def compute_version(self):\n        \"\"\"Get compute version number in string.\n\n        Currently used to get compute capability of CUDA device.\n\n        Returns\n        -------\n        version : str\n            The version string in `major.minor` format.\n        \"\"\"\n        return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 4)\n\n    @property\n    def device_name(self):\n        \"\"\"Return the string name of device.\"\"\"\n        return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 5)\n\n    @property\n    def max_clock_rate(self):\n        \"\"\"Return the max clock frequency of device.\"\"\"\n        return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 6)\n\n    @property\n    def multi_processor_count(self):\n        \"\"\"Return the number of compute units of device.\"\"\"\n        return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 7)\n\n    @property\n    def max_thread_dimensions(self):\n        \"\"\"Return the maximum size of each thread axis\n\n        
Returns\n        -------\n        dims: List of int\n            The maximum length of threadIdx.x, threadIdx.y, threadIdx.z\n        \"\"\"\n        return json.loads(\n            _api_internal._GetDeviceAttr(self.device_type, self.device_id, 8)\n        )\n\n    def sync(self):\n        \"\"\"Synchronize until jobs finished at the context.\"\"\"\n        check_call(_LIB.DGLSynchronize(self.device_type, self.device_id, None))\n\n    def __eq__(self, other):\n        return (\n            isinstance(other, DGLContext)\n            and self.device_id == other.device_id\n            and self.device_type == other.device_type\n        )\n\n    def __ne__(self, other):\n        return not self.__eq__(other)\n\n    def __repr__(self):\n        if self.device_type >= RPC_SESS_MASK:\n            tbl_id = self.device_type / RPC_SESS_MASK - 1\n            dev_type = self.device_type % RPC_SESS_MASK\n            return \"remote[%d]:%s(%d)\" % (\n                tbl_id,\n                DGLContext.MASK2STR[dev_type],\n                self.device_id,\n            )\n        return \"%s(%d)\" % (\n            DGLContext.MASK2STR[self.device_type],\n            self.device_id,\n        )\n\n    def __hash__(self):\n        return hash((self.device_type, self.device_id))\n\n\nclass DGLArray(ctypes.Structure):\n    \"\"\"DGLValue in C API\"\"\"\n\n    _fields_ = [\n        (\"data\", ctypes.c_void_p),\n        (\"ctx\", DGLContext),\n        (\"ndim\", ctypes.c_int),\n        (\"dtype\", DGLDataType),\n        (\"shape\", ctypes.POINTER(dgl_shape_index_t)),\n        (\"strides\", ctypes.POINTER(dgl_shape_index_t)),\n        (\"byte_offset\", ctypes.c_uint64),\n    ]\n\n\nDGLArrayHandle = ctypes.POINTER(DGLArray)\n\nDGLStreamHandle = ctypes.c_void_p\n"
  },
  {
    "path": "python/dgl/_ffi/streams.py",
    "content": "# pylint: disable=invalid-name, unused-import\n\"\"\"Runtime stream APIs which are mainly for internal test use only.\nFor applications, please use PyTorch's stream management, of which DGL is aware.\n\"\"\"\nfrom __future__ import absolute_import\n\nimport ctypes\n\nfrom .base import _FFI_MODE, _LIB, check_call\nfrom .runtime_ctypes import DGLStreamHandle\n\n\ndef to_dgl_stream_handle(cuda_stream):\n    \"\"\"Convert torch.cuda.Stream to DGL stream handle\n\n    Parameters\n    ----------\n    cuda_stream : torch.cuda.Stream.\n\n    Returns\n    -------\n    DGLStreamHandle\n        DGLStreamHandle of the input ``cuda_stream``.\n    \"\"\"\n    return ctypes.c_void_p(cuda_stream.cuda_stream)\n\n\ndef _dgl_get_stream(ctx):\n    \"\"\"Get the current CUDA stream of the given DGL context.\n\n    Parameters\n    ----------\n    ctx : DGL context.\n\n    Returns\n    -------\n    DGLStreamHandle\n        DGLStreamHandle of the current CUDA stream.\n    \"\"\"\n    current_cuda_stream = DGLStreamHandle()\n    check_call(\n        _LIB.DGLGetStream(\n            ctx.device_type, ctx.device_id, ctypes.byref(current_cuda_stream)\n        )\n    )\n    return current_cuda_stream\n"
  },
  {
    "path": "python/dgl/_sparse_ops.py",
    "content": "\"\"\"Module for sparse matrix operators.\"\"\"\n# pylint: disable= invalid-name\nfrom __future__ import absolute_import\n\nfrom . import backend as F, ndarray as nd\nfrom ._ffi.function import _init_api\nfrom .base import DGLError\n\n\ndef infer_broadcast_shape(op, shp1, shp2):\n    r\"\"\"Check the shape validity, and infer the output shape given input shape and operator.\n    Note the both :attr:`shp1`, :attr:`shp2` and the returned shape are feature\n    shapes (i.e. we remove the first dimension, which correspond to graph statistics\n    such as number of nodes, number of edges, etc.).\n\n    We allow applying op on operands with different shapes, according to the\n    broadcasting semantics of Numpy/Scipy:\n    https://numpy.org/doc/stable/user/basics.broadcasting.html\n\n    Parameters\n    ----------\n    op : str\n        The binary op's name, could be `add`, `sub`, `mul`, `div`, `dot`, `copy_lhs`, `copy_rhs`.\n    shp1 : tuple[int]\n        The shape of lhs operand.\n    shp2 : tuple[int]\n        The shape of rhs operand.\n\n    Returns\n    -------\n    tuple[int]\n        shape after broadcasting\n    \"\"\"\n    pad_shp1, pad_shp2 = shp1, shp2\n    if op == \"dot\":\n        if shp1[-1] != shp2[-1]:\n            raise DGLError(\n                \"Dot operator is only available for arrays with the \"\n                \"same size on last dimension, but got {} and {}.\".format(\n                    shp1, shp2\n                )\n            )\n    if op == \"copy_lhs\":\n        return shp1\n    if op == \"copy_rhs\":\n        return shp2\n    # operands are padded to have the same dimensionality with leading 1's.\n    if len(shp1) > len(shp2):\n        pad_shp2 = (1,) * (len(shp1) - len(shp2)) + shp2\n    elif len(shp1) < len(shp2):\n        pad_shp1 = (1,) * (len(shp2) - len(shp1)) + shp1\n    for d1, d2 in zip(pad_shp1, pad_shp2):\n        if d1 != d2 and d1 != 1 and d2 != 1:\n            raise DGLError(\n                \"Feature 
shapes {} and {} are not valid for broadcasting.\".format(\n                    shp1, shp2\n                )\n            )\n    rst = tuple(max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2))\n    return rst[:-1] + (1,) if op == \"dot\" else rst\n\n\ndef to_dgl_nd(x):\n    \"\"\"Convert framework-specific tensor/None to dgl ndarray.\"\"\"\n    return nd.NULL[\"int64\"] if x is None else F.zerocopy_to_dgl_ndarray(x)\n\n\ndef to_dgl_nd_for_write(x):\n    \"\"\"Convert framework-specific tensor/None to dgl ndarray for write.\"\"\"\n    return (\n        nd.NULL[\"int64\"]\n        if x is None\n        else F.zerocopy_to_dgl_ndarray_for_write(x)\n    )\n\n\ndef get_typeid_by_target(gidx, etid, target):\n    \"\"\"Find the src/dst/etype id based on the target 'u', 'v' or 'e'.\"\"\"\n    src_id, dst_id = gidx.metagraph.find_edge(etid)\n    if target in [0, \"u\"]:\n        return src_id\n    if target in [2, \"v\"]:\n        return dst_id\n    return etid\n\n\ntarget_mapping = {\"u\": 0, \"e\": 1, \"v\": 2, \"src\": 0, \"edge\": 1, \"dst\": 2}\n\n\ndef _edge_softmax_backward(gidx, out, sds):\n    r\"\"\"Edge_softmax backward interface.\n\n    Parameters\n    ----------\n    gidx : HeteroGraphIndex\n        The input graph index.\n    out : tensor\n        The result of Edge_softmax during forward.\n    sds : tensor\n        The result of out * gradient.\n\n    Returns\n    -------\n    The result of Edge_softmax during backward\n\n    Notes\n    -----\n    This function does not support gpu op.\n    \"\"\"\n    op = \"copy_rhs\"\n    back_out = F.zeros_like(out)\n    _CAPI_DGLKernelEdge_softmax_backward(\n        gidx,\n        op,\n        to_dgl_nd(out),\n        to_dgl_nd(sds),\n        to_dgl_nd_for_write(back_out),\n        to_dgl_nd(None),\n    )\n    return back_out\n\n\ndef _edge_softmax_forward(gidx, e, op):\n    r\"\"\"Edge_softmax forward interface.\n\n    Parameters\n    ----------\n    gidx : HeteroGraphIndex\n        The input graph index.\n    op : str\n 
       The binary op's name, default as ``copy_rhs``.\n    e : tensor or None\n        The feature on edges.\n\n    Returns\n    -------\n    The result of Edge_softmax during forward\n\n    Notes\n    -----\n    This function does not support gpu op.\n    \"\"\"\n    if F.ndim(e) == 1:\n        e = F.unsqueeze(e, -1)\n        expand = True\n    else:\n        expand = False\n    myout = F.zeros_like(e)\n    _CAPI_DGLKernelEdge_softmax_forward(\n        gidx, op, to_dgl_nd(None), to_dgl_nd(e), to_dgl_nd_for_write(myout)\n    )\n    myout = F.squeeze(myout, -1) if expand else myout\n    return myout\n\n\ndef _gspmm(gidx, op, reduce_op, u, e):\n    r\"\"\"Generalized Sparse Matrix Multiplication interface. It takes the result of\n    :attr:`op` on source node feature and edge feature, leads to a message on edge.\n    Then aggregates the message by :attr:`reduce_op` on destination nodes.\n\n    .. math::\n        x_v = \\psi_{(u, v, e)\\in \\mathcal{G}}(\\rho(x_u, x_e))\n\n    where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,\n    :math:`x_e` refer to :attr:`u`, :attr:`e` respectively. 
:math:`\\rho` means binary\n    operator :attr:`op` and :math:`\\psi` means reduce operator :attr:`reduce_op`,\n    :math:`\\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.\n\n    Note that this function does not handle gradients.\n\n    Parameters\n    ----------\n    gidx : HeteroGraphIndex\n        The input graph index.\n    op : str\n        The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``, ``copy_lhs``,\n        ``copy_rhs``.\n    reduce_op : str\n        Reduce operator, could be ``sum``, ``max``, ``min``.\n    u : tensor or None\n        The feature on source nodes, could be None if op is ``copy_rhs``.\n    e : tensor or None\n        The feature on edges, could be None if op is ``copy_lhs``.\n\n    Returns\n    -------\n    tuple\n        The returned tuple is composed of two elements:\n        - The first element refers to the result tensor.\n        - The second element refers to a tuple composed of arg_u and arg_e\n          (which is useful when reducer is `min`/`max`).\n\n    Notes\n    -----\n    This function does not handle gradients.\n    \"\"\"\n    if gidx.number_of_etypes() != 1:\n        raise DGLError(\"We only support gspmm on graph with one edge type\")\n    use_u = op != \"copy_rhs\"\n    use_e = op != \"copy_lhs\"\n    if use_u and use_e:\n        if F.dtype(u) != F.dtype(e):\n            raise DGLError(\n                \"The node features' data type {} doesn't match edge\"\n                \" features' data type {}, please convert them to the\"\n                \" same type.\".format(F.dtype(u), F.dtype(e))\n            )\n    # deal with scalar features.\n    expand_u, expand_e = False, False\n    if use_u:\n        if F.ndim(u) == 1:\n            u = F.unsqueeze(u, -1)\n            expand_u = True\n    if use_e:\n        if F.ndim(e) == 1:\n            e = F.unsqueeze(e, -1)\n            expand_e = True\n\n    ctx = F.context(u) if use_u else F.context(e)\n    dtype = F.dtype(u) if use_u else F.dtype(e)\n    
u_shp = F.shape(u) if use_u else (0,)\n    e_shp = F.shape(e) if use_e else (0,)\n    _, dsttype = gidx.metagraph.find_edge(0)\n    v_shp = (gidx.num_nodes(dsttype),) + infer_broadcast_shape(\n        op, u_shp[1:], e_shp[1:]\n    )\n    v = F.zeros(v_shp, dtype, ctx)\n    use_cmp = reduce_op in [\"max\", \"min\"]\n    arg_u, arg_e = None, None\n    idtype = getattr(F, gidx.dtype)\n    if use_cmp:\n        if use_u:\n            arg_u = F.zeros(v_shp, idtype, ctx)\n        if use_e:\n            arg_e = F.zeros(v_shp, idtype, ctx)\n    arg_u_nd = to_dgl_nd_for_write(arg_u)\n    arg_e_nd = to_dgl_nd_for_write(arg_e)\n    if gidx.num_edges(0) > 0:\n        _CAPI_DGLKernelSpMM(\n            gidx,\n            op,\n            reduce_op,\n            to_dgl_nd(u if use_u else None),\n            to_dgl_nd(e if use_e else None),\n            to_dgl_nd_for_write(v),\n            arg_u_nd,\n            arg_e_nd,\n        )\n    # NOTE(zihao): actually we can avoid the following step, because arg_*_nd\n    # refers to the data that stores arg_*. After we call _CAPI_DGLKernelSpMM,\n    # arg_* should have already been changed. But we found this doesn't work\n    # under Tensorflow when index type is int32. 
(arg_u and arg_e would be\n    # all zero).\n    # The workaround is proposed by Jinjing, and we still need to investigate\n    # where the problem is.\n    arg_u = None if arg_u is None else F.zerocopy_from_dgl_ndarray(arg_u_nd)\n    arg_e = None if arg_e is None else F.zerocopy_from_dgl_ndarray(arg_e_nd)\n    # To deal with scalar node/edge features.\n    if (expand_u or not use_u) and (expand_e or not use_e):\n        v = F.squeeze(v, -1)\n    if expand_u and use_cmp:\n        arg_u = F.squeeze(arg_u, -1)\n    if expand_e and use_cmp:\n        arg_e = F.squeeze(arg_e, -1)\n    return v, (arg_u, arg_e)\n\n\ndef _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):\n    r\"\"\"Generalized Sparse Matrix Multiplication interface on heterogeneous graphs.\n    It handles multiple node and edge types of the graph. For each edge type, it takes\n    the result of :attr:`op` on source node feature and edge feature, and leads to a\n    message on edge. Then it aggregates the message by :attr:`reduce_op` on the destination\n    nodes of the etype.\n\n    .. math::\n        x_v = \\psi_{(u, v, e)\\in \\mathcal{G}}(\\rho(x_u, x_e))\n\n    where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,\n    :math:`x_e` refer to :attr:`u`, :attr:`e` respectively. 
:math:`\\rho` means binary\n    operator :attr:`op` and :math:`\\psi` means reduce operator :attr:`reduce_op`,\n    :math:`\\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.\n\n    Note that this function does not handle gradients.\n\n    Parameters\n    ----------\n    gidx : HeteroGraphIndex\n        The input graph index.\n    op : str\n        The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``, ``copy_lhs``,\n        ``copy_rhs``.\n    reduce_op : str\n        Reduce operator, could be ``sum``, ``max``, ``min``.\n    u_len : int\n        The number of tensors in ``u`` (source node features)\n    u_and_e_tuple : Tuple of tensors\n        Tuple of source nodes' features and edges' features. ``u_and_e_tuple[:u_len]``\n        stores the source nodes' features of all source node types. ``u_and_e_tuple[u_len:]``\n        stores the edges' features of all the edge types.\n        The source nodes' features of the source node types could be None if op is ``copy_rhs``.\n        The edges' features of the edge types could be None if op is ``copy_lhs``.\n\n    Returns\n    -------\n    tuple\n        The returned tuple is composed of two elements:\n        - The first element refers to the tuple of result tensors.\n        - The second element refers to a tuple composed of arg_u and arg_e\n          (which is useful when reducer is `min`/`max`).\n\n    Notes\n    -----\n    This function does not handle gradients.\n    \"\"\"\n    u_tuple, e_tuple = u_and_e_tuple[:u_len], u_and_e_tuple[u_len:]\n    use_u = op != \"copy_rhs\"\n    use_e = op != \"copy_lhs\"\n    # TODO (Israt): Add check - F.dtype(u) != F.dtype(e):\n\n    # deal with scalar features.\n    expand_u, expand_e = False, False\n    num_ntypes = gidx.number_of_ntypes()\n    num_etypes = gidx.number_of_etypes()\n    list_u = [None] * num_ntypes\n    list_v = [None] * num_ntypes\n    list_e = [None] * num_etypes\n    list_arg_u_nd = [None] * num_ntypes\n    list_arg_u = [None] * 
num_ntypes\n    list_arg_u_ntype_nd = [None] * num_ntypes\n    list_arg_u_ntype = [None] * num_ntypes\n    # TODO(Israt): double check ntype or etype\n    list_arg_e_nd = [None] * num_ntypes\n    list_arg_e = [None] * num_ntypes\n    list_arg_e_etype_nd = [None] * num_ntypes\n    list_arg_e_etype = [None] * num_ntypes\n\n    use_cmp = reduce_op in [\"max\", \"min\"]\n    idtype = getattr(F, gidx.dtype)\n\n    for etid in range(num_etypes):\n        src_id, dst_id = gidx.metagraph.find_edge(etid)\n        u = u_tuple[src_id] if use_u else None\n        e = e_tuple[etid] if use_e else None\n        if use_u:\n            if u is not None and F.ndim(u) == 1:\n                u = F.unsqueeze(u, -1)\n                expand_u = True\n            list_u[src_id] = u if use_u else None\n        if use_e:\n            if e is not None and F.ndim(e) == 1:\n                e = F.unsqueeze(e, -1)\n                expand_e = True\n            list_e[etid] = e if use_e else None\n        ctx = (\n            F.context(u) if use_u else F.context(e)\n        )  # TODO(Israt): Put outside of loop\n        dtype = (\n            F.dtype(u) if use_u else F.dtype(e)\n        )  # TODO(Israt): Put outside of loop\n        u_shp = F.shape(u) if use_u else (0,)\n        e_shp = F.shape(e) if use_e else (0,)\n        v_shp = (gidx.num_nodes(dst_id),) + infer_broadcast_shape(\n            op, u_shp[1:], e_shp[1:]\n        )\n        list_v[dst_id] = F.zeros(v_shp, dtype, ctx)\n        if use_cmp:\n            if use_u:\n                list_arg_u[dst_id] = F.zeros(v_shp, idtype, ctx)\n                list_arg_u_ntype[dst_id] = F.zeros(v_shp, idtype, ctx)\n            if use_e:\n                list_arg_e[dst_id] = F.zeros(v_shp, idtype, ctx)\n                list_arg_e_etype[dst_id] = F.zeros(v_shp, idtype, ctx)\n        list_arg_u_nd[dst_id] = to_dgl_nd_for_write(list_arg_u[dst_id])\n        list_arg_u_ntype_nd[dst_id] = to_dgl_nd_for_write(\n            list_arg_u_ntype[dst_id]\n        
)\n        list_arg_e_nd[dst_id] = to_dgl_nd_for_write(list_arg_e[dst_id])\n        list_arg_e_etype_nd[dst_id] = to_dgl_nd_for_write(\n            list_arg_e_etype[dst_id]\n        )\n\n    if gidx.num_edges(0) > 0:\n        _CAPI_DGLKernelSpMMHetero(\n            gidx,\n            op,\n            reduce_op,\n            [to_dgl_nd(u_i) for u_i in list_u],\n            [to_dgl_nd(e_i) for e_i in list_e],\n            [to_dgl_nd_for_write(v_i) for v_i in list_v],\n            list_arg_u_nd,\n            list_arg_e_nd,\n            list_arg_u_ntype_nd,\n            list_arg_e_etype_nd,\n        )\n    for l, arg_u_nd in enumerate(list_arg_u_nd):\n        # TODO(Israt): l or src_id as index of lhs\n        list_arg_u[l] = (\n            None\n            if list_arg_u[l] is None\n            else F.zerocopy_from_dgl_ndarray(arg_u_nd)\n        )\n        if list_arg_u[l] is not None and expand_u and use_cmp:\n            list_arg_u[l] = F.squeeze(list_arg_u[l], -1)\n    for l, arg_e_nd in enumerate(list_arg_e_nd):\n        list_arg_e[l] = (\n            None\n            if list_arg_e[l] is None\n            else F.zerocopy_from_dgl_ndarray(arg_e_nd)\n        )\n        if list_arg_e[l] is not None and expand_e and use_cmp:\n            list_arg_e[l] = F.squeeze(list_arg_e[l], -1)\n    for l, arg_u_ntype_nd in enumerate(list_arg_u_ntype_nd):\n        list_arg_u_ntype[l] = (\n            None\n            if arg_u_ntype_nd is None\n            else F.zerocopy_from_dgl_ndarray(arg_u_ntype_nd)\n        )\n    for l, arg_e_etype_nd in enumerate(list_arg_e_etype_nd):\n        list_arg_e_etype[l] = (\n            None\n            if arg_e_etype_nd is None\n            else F.zerocopy_from_dgl_ndarray(arg_e_etype_nd)\n        )\n    # To deal with scalar node/edge features.\n    for l in range(num_ntypes):\n        # replace None by empty tensor. 
Forward func doesn't accept None in tuple.\n        v = list_v[l]\n        v = F.tensor([]) if v is None else v\n        if (expand_u or not use_u) and (expand_e or not use_e):\n            v = F.squeeze(v, -1)  # To deal with scalar node/edge features.\n        list_v[l] = v\n    out = tuple(list_v)\n    return out, (list_arg_u, list_arg_e, list_arg_u_ntype, list_arg_e_etype)\n\n\ndef _segment_mm(A, B, out, seglen_A, b_trans=False):\n    \"\"\"Invoke the C API of segment_mm.\"\"\"\n    _CAPI_DGLKernelSEGMENTMM(\n        to_dgl_nd(A),\n        to_dgl_nd(B),\n        to_dgl_nd_for_write(out),\n        to_dgl_nd(seglen_A),\n        False,\n        b_trans,\n    )\n    return out\n\n\ndef _segment_mm_backward_B(A, dC, dB, seglen):\n    \"\"\"Invoke the C API of the backward of segment_mm on B.\"\"\"\n    _CAPI_DGLKernelSEGMENTMMBackwardB(\n        to_dgl_nd(A), to_dgl_nd(dC), to_dgl_nd_for_write(dB), to_dgl_nd(seglen)\n    )\n    return dB\n\n\ndef _gather_mm(A, B, out, idx_a=None, idx_b=None):\n    r\"\"\"Invoke the C API of the gather_mm operator.\"\"\"\n    _CAPI_DGLKernelGATHERMM(\n        to_dgl_nd(A),\n        to_dgl_nd(B),\n        to_dgl_nd_for_write(out),\n        to_dgl_nd(idx_a),\n        to_dgl_nd(idx_b),\n    )\n    return out\n\n\ndef _gather_mm_scatter(A, B, out, idx_a=None, idx_b=None, idx_c=None):\n    r\"\"\"Invoke the C API of the gather_mm_scatter operator.\"\"\"\n    _CAPI_DGLKernelGATHERMMSCATTER(\n        to_dgl_nd(A),\n        to_dgl_nd(B),\n        to_dgl_nd_for_write(out),\n        to_dgl_nd(idx_a),\n        to_dgl_nd(idx_b),\n        to_dgl_nd(idx_c),\n    )\n    return out\n\n\ndef _gsddmm(gidx, op, lhs, rhs, lhs_target=\"u\", rhs_target=\"v\"):\n    r\"\"\"Generalized Sampled-Dense-Dense Matrix Multiplication interface. It\n    takes the result of :attr:`op` on source node feature and destination node\n    feature, leads to a feature on edge.\n\n    .. 
math::\n        x_{e} = \\phi(x_u, x_e, x_v), \\forall (u,e,v)\\in \\mathcal{G}\n\n    where :math:`x_{e}` is the returned feature on edges and :math:`x_u`,\n    :math:`x_v` refers to :attr:`u`, :attr:`v` respectively. :math:`\\phi`\n    is the binary operator :attr:`op`, and :math:`\\mathcal{G}` is the graph\n    we apply gsddmm on: :attr:`g`.\n\n    Parameters\n    ----------\n    gidx : HeteroGraphIndex\n        The input graph index.\n    op : str\n        Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,\n        ``copy_lhs``, ``copy_rhs``.\n    lhs : tensor or None\n        Left hand operand.\n    rhs : tensor or None\n        Right hand operand.\n    lhs_target : str\n        The target of left hand operand, could be ``src``, ``edge``, ``dst``\n        or their alias ``u``, ``e``, ``v``.\n    rhs_target : str\n        The target of right hand operand, could be ``src``, ``edge``, ``dst``\n        or their alias ``u``, ``e``, ``v``.\n\n    Returns\n    -------\n    tensor\n        The result tensor.\n\n    Notes\n    -----\n    This function does not handle gradients.\n    \"\"\"\n    if gidx.number_of_etypes() != 1:\n        raise DGLError(\"We only support gsddmm on graph with one edge type\")\n    use_lhs = op != \"copy_rhs\"\n    use_rhs = op != \"copy_lhs\"\n    if use_lhs and use_rhs:\n        if F.dtype(lhs) != F.dtype(rhs):\n            raise DGLError(\n                \"The operands data type don't match: {} and {}, please convert them\"\n                \" to the same type.\".format(F.dtype(lhs), F.dtype(rhs))\n            )\n    # deal with scalar features.\n    expand_lhs, expand_rhs = False, False\n    if use_lhs:\n        if F.ndim(lhs) == 1:\n            lhs = F.unsqueeze(lhs, -1)\n            expand_lhs = True\n    if use_rhs:\n        if F.ndim(rhs) == 1:\n            rhs = F.unsqueeze(rhs, -1)\n            expand_rhs = True\n    lhs_target = target_mapping[lhs_target]\n    rhs_target = target_mapping[rhs_target]\n\n    
ctx = F.context(lhs) if use_lhs else F.context(rhs)\n    dtype = F.dtype(lhs) if use_lhs else F.dtype(rhs)\n    lhs_shp = F.shape(lhs) if use_lhs else (0,)\n    rhs_shp = F.shape(rhs) if use_rhs else (0,)\n    out_shp = (gidx.num_edges(0),) + infer_broadcast_shape(\n        op, lhs_shp[1:], rhs_shp[1:]\n    )\n    out = F.empty(out_shp, dtype, ctx)\n    if gidx.num_edges(0) > 0:\n        _CAPI_DGLKernelSDDMM(\n            gidx,\n            op,\n            to_dgl_nd(lhs if use_lhs else None),\n            to_dgl_nd(rhs if use_rhs else None),\n            to_dgl_nd_for_write(out),\n            lhs_target,\n            rhs_target,\n        )\n    if (expand_lhs or not use_lhs) and (expand_rhs or not use_rhs):\n        out = F.squeeze(out, -1)\n    return out\n\n\ndef _gsddmm_hetero(\n    gidx, op, lhs_len, lhs_target=\"u\", rhs_target=\"v\", lhs_and_rhs_tuple=None\n):\n    r\"\"\"Generalized Sampled-Dense-Dense Matrix Multiplication interface.\"\"\"\n    lhs_tuple, rhs_tuple = (\n        lhs_and_rhs_tuple[:lhs_len],\n        lhs_and_rhs_tuple[lhs_len:],\n    )\n\n    use_lhs = op != \"copy_rhs\"\n    use_rhs = op != \"copy_lhs\"\n\n    # TODO (Israt): Add check - F.dtype(u) != F.dtype(e):\n    # deal with scalar features.\n    expand_lhs, expand_rhs = False, False\n    num_ntype = gidx.number_of_ntypes()\n    num_etype = gidx.number_of_etypes()\n    lhs_list = (\n        [None] * num_ntype if lhs_target in [\"u\", \"v\"] else [None] * num_etype\n    )\n    rhs_list = (\n        [None] * num_ntype if rhs_target in [\"u\", \"v\"] else [None] * num_etype\n    )\n    out_list = [None] * gidx.number_of_etypes()\n\n    lhs_target = target_mapping[lhs_target]\n    rhs_target = target_mapping[rhs_target]\n\n    for etid in range(gidx.number_of_etypes()):\n        lhs_id = get_typeid_by_target(gidx, etid, lhs_target)\n        rhs_id = get_typeid_by_target(gidx, etid, rhs_target)\n        lhs = lhs_tuple[lhs_id]\n        rhs = rhs_tuple[rhs_id]\n        if use_lhs:\n          
  if lhs is not None and F.ndim(lhs) == 1:\n                lhs = F.unsqueeze(lhs, -1)\n                expand_lhs = True\n        if use_rhs:\n            if rhs is not None and F.ndim(rhs) == 1:\n                rhs = F.unsqueeze(rhs, -1)\n                expand_rhs = True\n        ctx = F.context(lhs) if use_lhs else F.context(rhs)\n        dtype = F.dtype(lhs) if use_lhs else F.dtype(rhs)\n        lhs_shp = F.shape(lhs) if use_lhs else (0,)\n        rhs_shp = F.shape(rhs) if use_rhs else (0,)\n        lhs_list[lhs_id] = lhs if use_lhs else None\n        rhs_list[rhs_id] = rhs if use_rhs else None\n        out_shp = (gidx.num_edges(etid),) + infer_broadcast_shape(\n            op, lhs_shp[1:], rhs_shp[1:]\n        )\n        out_list[etid] = F.empty(out_shp, dtype, ctx)\n    if gidx.num_edges(0) > 0:\n        _CAPI_DGLKernelSDDMMHetero(\n            gidx,\n            op,\n            [to_dgl_nd(lhs) for lhs in lhs_list],\n            [to_dgl_nd(rhs) for rhs in rhs_list],\n            [to_dgl_nd_for_write(out) for out in out_list],\n            lhs_target,\n            rhs_target,\n        )\n\n    for l in range(gidx.number_of_etypes()):\n        # Replace None by empty tensor. Forward func doesn't accept None in tuple.\n        e = out_list[l]\n        e = F.tensor([]) if e is None else e\n        if (expand_lhs or not use_lhs) and (expand_rhs or not use_rhs):\n            e = F.squeeze(e, -1)\n        out_list[l] = e\n    out = tuple(out_list)\n    return out\n\n\ndef _segment_reduce(op, feat, offsets):\n    r\"\"\"Segment reduction operator.\n\n    It aggregates the value tensor along the first dimension by segments.\n    The argument ``offsets`` specifies the start offset of each segment (and\n    the upper bound of the last segment). Zero-length segments are allowed.\n\n    .. 
math::\n      y_i = \\Phi_{j=\\mathrm{offsets}_i}^{\\mathrm{offsets}_{i+1}-1} x_j\n\n    where :math:`\\Phi` is the reduce operator.\n\n    Parameters\n    ----------\n    op : str\n        Aggregation method. Can be ``sum``, ``max``, ``min``.\n    x : Tensor\n        Value to aggregate.\n    offsets : Tensor\n        The start offsets of segments.\n\n    Returns\n    -------\n    tuple(Tensor)\n        The first tensor correspond to aggregated tensor of shape\n        ``(len(seglen), value.shape[1:])``, and the second tensor records\n        the argmin/max at each position for computing gradients.\n\n    Notes\n    -----\n    This function does not handle gradients.\n    \"\"\"\n    n = F.shape(offsets)[0] - 1\n    out_shp = (n,) + F.shape(feat)[1:]\n    ctx = F.context(feat)\n    dtype = F.dtype(feat)\n    idtype = F.dtype(offsets)\n    out = F.zeros(out_shp, dtype, ctx)\n    arg = None\n    if op in [\"min\", \"max\"]:\n        arg = F.zeros(out_shp, idtype, ctx)\n    arg_nd = to_dgl_nd_for_write(arg)\n    _CAPI_DGLKernelSegmentReduce(\n        op,\n        to_dgl_nd(feat),\n        to_dgl_nd(offsets),\n        to_dgl_nd_for_write(out),\n        arg_nd,\n    )\n    arg = None if arg is None else F.zerocopy_from_dgl_ndarray(arg_nd)\n    return out, arg\n\n\ndef _scatter_add(x, idx, m):\n    r\"\"\"Scatter add operator (on first dimension) implementation.\n\n    Math: y[idx[i], *] += x[i, *]\n\n    Parameters\n    ----------\n    x : Tensor\n        The input feature.\n    idx : Tensor\n        The indices array.\n    m : int\n        The length of output.\n\n    Returns\n    -------\n    Tensor\n        The output tensor.\n    \"\"\"\n    out_shp = (m,) + F.shape(x)[1:]\n    ctx = F.context(x)\n    dtype = F.dtype(x)\n    out = F.zeros(out_shp, dtype, ctx)\n    _CAPI_DGLKernelScatterAdd(\n        to_dgl_nd(x), to_dgl_nd(idx), to_dgl_nd_for_write(out)\n    )\n    return out\n\n\ndef _update_grad_minmax_hetero(\n    gidx, op, list_x, list_idx, list_idx_etype, 
list_dX\n):\n    r\"\"\"Update gradients for reduce operator max and min (on first dimension) implementation.\n\n    Parameters\n    ----------\n    gidx : HeteroGraphIndex\n        The input graph index.\n    list_x : List of tensors\n        List of the input features.\n    list_idx : List of tensors\n        List of the indices array.\n    list_idx_etype : List of tensors\n        List of the node- or edge-type array.\n    list_dX : List of tensors\n        List of gradients.\n\n    Returns\n    -------\n    Tensor\n        The output tensor.\n    \"\"\"\n    use_u = op != \"copy_rhs\"\n    use_e = op != \"copy_lhs\"\n    list_out = [None] * len(list_dX)\n    for etid in range(gidx.number_of_etypes()):\n        src_id, dst_id = gidx.metagraph.find_edge(etid)  # gidx is reversed\n        x = list_x[src_id]\n        ctx = F.context(x)\n        dtype = F.dtype(x)\n        if use_u:\n            out_shp = (len(list_dX[dst_id]),) + F.shape(x)[1:]\n            list_out[dst_id] = F.zeros(out_shp, dtype, ctx)\n        if use_e:\n            out_shp = (len(list_dX[etid]),) + F.shape(x)[1:]\n            list_out[etid] = F.zeros(out_shp, dtype, ctx)\n\n    _CAPI_DGLKernelUpdateGradMinMaxHetero(\n        gidx,\n        op,\n        [to_dgl_nd(x) for x in list_x],\n        [to_dgl_nd(idx) for idx in list_idx],\n        [to_dgl_nd(idx_etype) for idx_etype in list_idx_etype],\n        [to_dgl_nd_for_write(out) for out in list_out],\n    )\n    return tuple(list_out)\n\n\ndef _bwd_segment_cmp(feat, arg, m):\n    r\"\"\"Backward phase of segment reduction (for 'min'/'max' reduction).\n\n    It computes the gradient of input feature given output gradient of\n    the segment reduction result.\n\n    Parameters\n    ----------\n    feat : Tensor\n        The output gradient\n    arg : Tensor\n        The ArgMin/Max tensor produced by segment_reduce op.\n    m : int\n        The length of input gradients' first dimension.\n\n    Returns\n    -------\n    Tensor\n        The input 
gradient.\n    \"\"\"\n    out_shp = (m,) + F.shape(feat)[1:]\n    ctx = F.context(feat)\n    dtype = F.dtype(feat)\n    out = F.zeros(out_shp, dtype, ctx)\n    _CAPI_DGLKernelBwdSegmentCmp(\n        to_dgl_nd(feat), to_dgl_nd(arg), to_dgl_nd_for_write(out)\n    )\n    return out\n\n\ndef _csrmm(A, A_weights, B, B_weights, num_vtypes):\n    \"\"\"Return a graph whose adjacency matrix is the sparse matrix multiplication\n    of those of two given graphs.\n\n    Note that the edge weights of both graphs must be scalar, i.e. :attr:`A_weights`\n    and :attr:`B_weights` must be 1D vectors.\n\n    Parameters\n    ----------\n    A : HeteroGraphIndex\n        The input graph index as left operand.\n    A_weights : Tensor\n        The edge weights of graph A as 1D tensor.\n    B : HeteroGraphIndex\n        The input graph index as right operand.\n    B_weights : Tensor\n        The edge weights of graph B as 1D tensor.\n    num_vtypes : int\n        The number of node types for the returned graph (must be either 1 or 2).\n\n    Returns\n    -------\n    C : HeteroGraphIndex\n        The output graph index.\n    C_weights : Tensor\n        The edge weights of the output graph.\n    \"\"\"\n    C, C_weights = _CAPI_DGLCSRMM(\n        A, F.to_dgl_nd(A_weights), B, F.to_dgl_nd(B_weights), num_vtypes\n    )\n    return C, F.from_dgl_nd(C_weights)\n\n\ndef _csrsum(As, A_weights):\n    \"\"\"Return a graph whose adjacency matrix is the sparse matrix summation\n    of the given list of graphs.\n\n    Note that the edge weights of all graphs must be scalar, i.e. 
the arrays in\n    :attr:`A_weights` must be 1D vectors.\n\n    Parameters\n    ----------\n    As : list[HeteroGraphIndex]\n        The input graph indices.\n    A_weights : list[Tensor]\n        The edge weights of graph A as 1D tensor.\n\n    Returns\n    -------\n    C : HeteroGraphIndex\n        The output graph index.\n    C_weights : Tensor\n        The edge weights of the output graph.\n    \"\"\"\n    C, C_weights = _CAPI_DGLCSRSum(As, [F.to_dgl_nd(w) for w in A_weights])\n    return C, F.from_dgl_nd(C_weights)\n\n\ndef _csrmask(A, A_weights, B):\n    \"\"\"Return the weights of A at the locations identical to the sparsity pattern\n    of B.\n\n    If a non-zero entry in B does not exist in A, DGL returns 0 for that location\n    instead.\n\n    Note that the edge weights of the graph must be scalar, i.e. :attr:`A_weights`\n    must be a 1D vector.\n\n    In scipy notation this is identical to ``A[B != 0]``.\n\n    Parameters\n    ----------\n    A : HeteroGraphIndex\n        The input graph index as left operand.\n    A_weights : Tensor\n        The edge weights of graph A as 1D tensor.\n    B : HeteroGraphIndex\n        The input graph index as right operand.\n\n    Returns\n    -------\n    B_weights : Tensor\n        The output weights.\n    \"\"\"\n    return F.from_dgl_nd(_CAPI_DGLCSRMask(A, F.to_dgl_nd(A_weights), B))\n\n\n###################################################################################################\n## Libra Graph Partition\ndef libra_vertex_cut(\n    nc,\n    node_degree,\n    edgenum_unassigned,\n    community_weights,\n    u,\n    v,\n    w,\n    out,\n    N,\n    N_e,\n    dataset,\n):\n    \"\"\"\n    This function invokes C/C++ code for Libra based graph partitioning.\n    Parameter details are present in dgl/src/array/libra_partition.cc\n    \"\"\"\n    _CAPI_DGLLibraVertexCut(\n        nc,\n        to_dgl_nd_for_write(node_degree),\n        to_dgl_nd_for_write(edgenum_unassigned),\n        
to_dgl_nd_for_write(community_weights),\n        to_dgl_nd(u),\n        to_dgl_nd(v),\n        to_dgl_nd(w),\n        to_dgl_nd_for_write(out),\n        N,\n        N_e,\n        dataset,\n    )\n\n\ndef libra2dgl_build_dict(\n    a,\n    b,\n    indices,\n    ldt_key,\n    gdt_key,\n    gdt_value,\n    node_map,\n    offset,\n    nc,\n    c,\n    fsize,\n    dataset,\n):\n    \"\"\"\n    This function invokes C/C++ code for pre-processing Libra output.\n    After graph partitioning using Libra, during conversion from Libra output to DGL/DistGNN input,\n    this function creates dictionaries to assign local node ids to the partitioned nodes\n    and also to create a database of the split nodes.\n    Parameter details are present in dgl/src/array/libra_partition.cc\n    \"\"\"\n    ret = _CAPI_DGLLibra2dglBuildDict(\n        to_dgl_nd_for_write(a),\n        to_dgl_nd_for_write(b),\n        to_dgl_nd_for_write(indices),\n        to_dgl_nd_for_write(ldt_key),\n        to_dgl_nd_for_write(gdt_key),\n        to_dgl_nd_for_write(gdt_value),\n        to_dgl_nd_for_write(node_map),\n        to_dgl_nd_for_write(offset),\n        nc,\n        c,\n        fsize,\n        dataset,\n    )\n    return ret\n\n\ndef libra2dgl_build_adjlist(\n    feat,\n    gfeat,\n    adj,\n    inner_node,\n    ldt,\n    gdt_key,\n    gdt_value,\n    node_map,\n    lr,\n    lrtensor,\n    num_nodes,\n    nc,\n    c,\n    feat_size,\n    labels,\n    trainm,\n    testm,\n    valm,\n    glabels,\n    gtrainm,\n    gtestm,\n    gvalm,\n    feat_shape,\n):\n    \"\"\"\n    This function invokes C/C++ code for pre-processing Libra output.\n    After graph partitioning using Libra, once the local and global dictionaries are built,\n    for each node in each partition, this function copies the split node details from the\n    global dictionary. 
It also copies features, label, train, test, and validation information\n    for each node from the input graph to the corresponding partitions.\n    Parameter details are present in dgl/src/array/libra_partition.cc\n    \"\"\"\n    _CAPI_DGLLibra2dglBuildAdjlist(\n        to_dgl_nd(feat),\n        to_dgl_nd_for_write(gfeat),\n        to_dgl_nd_for_write(adj),\n        to_dgl_nd_for_write(inner_node),\n        to_dgl_nd(ldt),\n        to_dgl_nd(gdt_key),\n        to_dgl_nd(gdt_value),\n        to_dgl_nd(node_map),\n        to_dgl_nd_for_write(lr),\n        to_dgl_nd(lrtensor),\n        num_nodes,\n        nc,\n        c,\n        feat_size,\n        to_dgl_nd(labels),\n        to_dgl_nd(trainm),\n        to_dgl_nd(testm),\n        to_dgl_nd(valm),\n        to_dgl_nd_for_write(glabels),\n        to_dgl_nd_for_write(gtrainm),\n        to_dgl_nd_for_write(gtestm),\n        to_dgl_nd_for_write(gvalm),\n        feat_shape,\n    )\n\n\ndef libra2dgl_set_lr(gdt_key, gdt_value, lrtensor, nc, Nn):\n    \"\"\"\n    This function invokes C/C++ code for pre-processing Libra output.\n    To prepare the graph partitions for DistGNN input, this function sets the leaf\n    and root (1-level tree) among the split copies (across different partitions)\n    of a node from input graph.\n    Parameter details are present in dgl/src/array/libra_partition.cc\n    \"\"\"\n    _CAPI_DGLLibra2dglSetLR(\n        to_dgl_nd(gdt_key),\n        to_dgl_nd(gdt_value),\n        to_dgl_nd_for_write(lrtensor),\n        nc,\n        Nn,\n    )\n\n\n_init_api(\"dgl.sparse\", __name__)\n"
  },
  {
    "path": "python/dgl/backend/__init__.py",
    "content": "from __future__ import absolute_import\n\nimport importlib\nimport json\nimport logging\nimport os\nimport sys\n\nfrom . import backend\nfrom .set_default_backend import set_default_backend\n\n_enabled_apis = set()\n\nlogger = logging.getLogger(\"dgl-core\")\n\n\ndef _gen_missing_api(api, mod_name):\n    def _missing_api(*args, **kwargs):\n        raise ImportError(\n            'API \"%s\" is not supported by backend \"%s\".'\n            \" You can switch to other backends by setting\"\n            \" the DGLBACKEND environment.\" % (api, mod_name)\n        )\n\n    return _missing_api\n\n\ndef load_backend(mod_name):\n    # Load backend does four things:\n    # (1) Import backend framework (PyTorch, MXNet, Tensorflow, etc.)\n    # (2) Import DGL C library.  DGL imports it *after* PyTorch/MXNet/Tensorflow.  Otherwise\n    #     DGL will crash with errors like `munmap_chunk(): invalid pointer`.\n    # (3) Sets up the tensoradapter library path.\n    # (4) Import the Python wrappers of the backend framework.  
DGL does this last because\n    #     it already depends on both the backend framework and the DGL C library.\n    if mod_name == \"pytorch\":\n        import torch\n\n        mod = torch\n    elif mod_name == \"mxnet\":\n        import mxnet\n\n        mod = mxnet\n    elif mod_name == \"tensorflow\":\n        import tensorflow\n\n        mod = tensorflow\n    else:\n        raise NotImplementedError(\"Unsupported backend: %s\" % mod_name)\n\n    from .._ffi.base import load_tensor_adapter  # imports DGL C library\n\n    version = mod.__version__\n    load_tensor_adapter(mod_name, version)\n\n    logger.debug(\"Using backend: %s\" % mod_name)\n    mod = importlib.import_module(\".%s\" % mod_name, __name__)\n    thismod = sys.modules[__name__]\n    for api in backend.__dict__.keys():\n        if api.startswith(\"__\"):\n            # ignore python builtin attributes\n            continue\n        if api == \"data_type_dict\":\n            # load data type\n            if api not in mod.__dict__:\n                raise ImportError(\n                    'API \"data_type_dict\" is required but missing for'\n                    ' backend \"%s\".' 
% (mod_name)\n                )\n            data_type_dict = mod.__dict__[api]()\n            for name, dtype in data_type_dict.items():\n                setattr(thismod, name, dtype)\n\n            # override data type dict function\n            setattr(thismod, \"data_type_dict\", data_type_dict)\n\n            # for data types with aliases, treat the first listed type as\n            # the true one\n            rev_data_type_dict = {}\n            for k, v in data_type_dict.items():\n                if not v in rev_data_type_dict.keys():\n                    rev_data_type_dict[v] = k\n            setattr(thismod, \"reverse_data_type_dict\", rev_data_type_dict)\n            # log backend name\n            setattr(thismod, \"backend_name\", mod_name)\n        else:\n            # load functions\n            if api in mod.__dict__:\n                _enabled_apis.add(api)\n                setattr(thismod, api, mod.__dict__[api])\n            else:\n                setattr(thismod, api, _gen_missing_api(api, mod_name))\n\n\ndef get_preferred_backend():\n    default_dir = None\n    if \"DGLDEFAULTDIR\" in os.environ:\n        default_dir = os.getenv(\"DGLDEFAULTDIR\")\n    else:\n        default_dir = os.path.join(os.path.expanduser(\"~\"), \".dgl\")\n    config_path = os.path.join(default_dir, \"config.json\")\n    backend_name = None\n    if \"DGLBACKEND\" in os.environ:\n        backend_name = os.getenv(\"DGLBACKEND\")\n    elif os.path.exists(config_path):\n        with open(config_path, \"r\") as config_file:\n            config_dict = json.load(config_file)\n            backend_name = config_dict.get(\"backend\", \"\").lower()\n\n    if backend_name in [\"tensorflow\", \"mxnet\", \"pytorch\"]:\n        return backend_name\n    else:\n        print(\n            \"DGL backend not selected or invalid.  
\"\n            \"Assuming PyTorch for now.\",\n            file=sys.stderr,\n        )\n        set_default_backend(default_dir, \"pytorch\")\n        return \"pytorch\"\n\n\nload_backend(get_preferred_backend())\n\n\ndef is_enabled(api):\n    \"\"\"Return true if the api is enabled by the current backend.\n\n    Parameters\n    ----------\n    api : str\n        The api name.\n\n    Returns\n    -------\n    bool\n        True if the API is enabled by the current backend.\n    \"\"\"\n    return api in _enabled_apis\n\n\ndef to_dgl_nd(data):\n    return zerocopy_to_dgl_ndarray(data)\n\n\ndef from_dgl_nd(data):\n    return zerocopy_from_dgl_ndarray(data)\n"
  },
  {
    "path": "python/dgl/backend/backend.py",
    "content": "\"\"\"This file defines the unified tensor framework interface required by DGL.\n\nThe principles of this interface:\n* There should be as few interfaces as possible.\n* The interface is used by DGL system so it is more important to have\n  clean definition rather than convenient usage.\n* Default arguments should be avoided.\n* Keyword or positional arguments should be avoided.\n* Argument type should be easier to understand.\n\nIt is recommended the frameworks implement all the interfaces. However, it is\nalso OK to skip some. The generated backend module has an ``is_enabled`` function\nthat returns whether the interface is supported by the framework or not.\n\"\"\"\n\n###############################################################################\n# Tensor, data type and context interfaces\n\n\ndef data_type_dict():\n    \"\"\"Returns a dictionary from data type string to the data type.\n\n    The dictionary should include at least:\n    bfloat16\n    float16\n    float32\n    float64\n    uint8\n    int8\n    int16\n    int32\n    int64\n    bool\n\n    This function will be called only *once* during the initialization fo the\n    backend module. The returned dictionary will become the attributes of the\n    backend module.\n\n    Examples\n    --------\n    >>> import torch as th\n    >>> def data_type_dict():\n    >>>   return { 'float16' : th.float16, 'float32' : th.float32, ... 
}\n\n    After the module is initialized.\n\n    >>> import backend as F\n    >>> F.float16  # this will point to torch.float16\n\n    Returns\n    -------\n    dict of str to data type\n        The data type dict.\n    \"\"\"\n    pass\n\n\ndef cpu():\n    \"\"\"Return a context object for CPU device.\"\"\"\n    pass\n\n\ndef tensor(data, dtype=None):\n    \"\"\"Create a tensor given the data and data type.\n\n    If the input is already a tensor and has the same dtype,\n    directly return.\n\n    Scalar input is converted to a array of one element instead of\n    a 0-dim tensor to avoid certain issues with some backends.\n\n    Parameters\n    ----------\n    data : int, iterable, Tensor\n        The interface should at least support list and numpy array.\n        The data is copied to a newly-allocated tensor.\n    dtype : data type, optional\n        It should be one of the values in the data type dict.\n        If is none, the type should be inferred from data.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n    pass\n\n\ndef as_scalar(data):\n    \"\"\"Returns a scalar whose value is copied from this array.\n\n    Parameters\n    ----------\n    data : Tensor\n        The input data\n\n    Returns\n    -------\n    scalar\n        The scalar value in the tensor.\n    \"\"\"\n    pass\n\n\ndef get_preferred_sparse_format():\n    \"\"\"Get the preferred sparse matrix format supported by the backend.\n\n    Different backends have their preferred backend. This info is useful when\n    constructing a sparse matrix.\n\n    Returns\n    -------\n    string\n        the name of the preferred sparse matrix format.\n    \"\"\"\n    pass\n\n\ndef sparse_matrix(data, index, shape, force_format=False):\n    \"\"\"Create a sparse matrix.\n\n    NOTE: Please make sure that the data and index tensors are not\n    copied. This is critical to the performance.\n\n    Parameters\n    ----------\n    data : Tensor\n        Data tensor. 
It should be of shape (nnz,).\n    index : tuple\n        This is used to support different sparse formats.\n        For COO format:\n          index=('coo', coord), where coord is of shape (2, nnz).\n          coord[0,:] should be the row index and coord[1,:] should be\n          the column index.\n        For CSR format:\n          index=('csr', indices, indptr), where indices is of shape (nnz,)\n          and indptr is of shape (nrows+1,). See ``scipy.sparse.csr_matrix``\n          for more documents on what each array means.\n    shape : tuple of int\n        The shape.\n    force_format : bool\n        If true, the returned sparse matrix must be stored in the same\n        format as the given index.\n\n    Returns\n    -------\n    SparseMatrix\n        The framework-specific sparse matrix. It can be stored in any format\n        unless force_format is True.\n    Tensor\n        The data convert index due to sparse format change.\n        None if no conversion is needed.\n    \"\"\"\n    pass\n\n\ndef sparse_matrix_indices(spmat):\n    \"\"\"Return the indices of the given sparse matrix.\n\n    Parameters\n    ----------\n    spmat : SparseMatrix\n        The framework-specific sparse matrix.\n\n    Returns\n    -------\n    index : tuple\n        This is used to support different sparse formats.\n        For COO format:\n          index=('coo', coord), where coord is of shape (2, nnz).\n          coord[0,:] should be the row index and coord[1,:] should be\n          the column index.\n        For CSR format:\n          index=('csr', indices, indptr), where indices is of shape (nnz,)\n          and indptr is of shape (nrows+1,). 
See ``scipy.sparse.csr_matrix``\n          for more documents on what each array means.\n    \"\"\"\n    pass\n\n\ndef is_tensor(obj):\n    \"\"\"Returns true if the given object is a framework-specific tensor.\"\"\"\n    pass\n\n\ndef shape(input):\n    \"\"\"Return the shape of the tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n\n    Returns\n    -------\n    tuple of int\n        The tensor shape.\n    \"\"\"\n    pass\n\n\ndef dtype(input):\n    \"\"\"Return the data type of the tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n\n    Returns\n    -------\n    data type\n        It should be one of the values in the data type dict.\n    \"\"\"\n    pass\n\n\ndef ndim(input):\n    \"\"\"Return the number of dimensions of the tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n\n    Returns\n    -------\n    int\n        The number of dimensions\n    \"\"\"\n    pass\n\n\ndef context(input):\n    \"\"\"Return the context/device of the input tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n\n    Returns\n    -------\n    Context object\n        A framework-specific context object.\n    \"\"\"\n    pass\n\n\ndef device_type(ctx):\n    \"\"\"Return a str representing device type.\n\n    Parameters\n    ----------\n    ctx : Device context object.\n        Device context.\n\n    Returns\n    -------\n    str\n    \"\"\"\n    pass\n\n\ndef device_id(ctx):\n    \"\"\"Return device index.\n\n    For CPU, the index does not matter. 
For GPU, the index means which GPU\n    device on the machine.\n\n    Parameters\n    ----------\n    ctx : Device context object.\n        Device context.\n\n    Returns\n    -------\n    int\n        The device index.\n    \"\"\"\n    pass\n\n\ndef to_backend_ctx(dglctx):\n    \"\"\"Convert a DGL context object to a backend context.\n\n    Parameters\n    ----------\n    dglctx : dgl.ndarray.DGLContext\n        DGL context object. See _ffi.runtime_types for definition.\n\n    Returns\n    -------\n    ctx : framework-specific context object.\n    \"\"\"\n    pass\n\n\ndef astype(input, ty):\n    \"\"\"Convert the input tensor to the given data type.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    ty : data type\n        It should be one of the values in the data type dict.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n    pass\n\n\ndef asnumpy(input):\n    \"\"\"Convert the input tensor to numpy array.\n\n    The data is copied.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n\n    Returns\n    -------\n    numpy.ndarray\n        Numpy array.\n    \"\"\"\n    pass\n\n\ndef copy_to(input, ctx, **kwargs):\n    \"\"\"Copy the given tensor to the context.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor\n    ctx :\n        A framework-specific context object.\n\n    Returns\n    -------\n    Tensor\n        The tensor on the given context.\n    \"\"\"\n    pass\n\n\ndef is_pinned(input):\n    \"\"\"Check whether the tensor is in pinned memory.\n\n    Parameters\n    ----------\n    input : Tensor\n        The tensor.\n\n    Returns\n    -------\n    bool\n        Whether the tensor is in pinned memory.\n    \"\"\"\n    pass\n\n\n###############################################################################\n# Tensor functions on feature data\n# --------------------------------\n# These functions are performance critical, 
so it's better to have efficient\n# implementation in each framework.\n\n\ndef sum(input, dim, keepdims=False):\n    \"\"\"Reduce sum the input tensor along the given dim.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    dim : int\n        The reduce dim.\n    keepdims : bool\n        Whether to keep the summed dimension.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n    pass\n\n\ndef floor_div(in1, in2):\n    \"\"\"Element-wise integer division and rounds each quotient towards zero.\n\n    Parameters\n    ----------\n    in1 : Tensor\n        The input tensor\n    in2 : Tensor or integer\n        The input\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n\n\ndef reduce_sum(input):\n    \"\"\"Returns the sum of all elements in the input tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor with shape (1,)\n    \"\"\"\n    pass\n\n\ndef cumsum(input, dim):\n    \"\"\"Return the cumulative sum of the elements along a given axis.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    dim : int\n        The cumulative dimension.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n    pass\n\n\ndef mean(input, dim):\n    \"\"\"Reduce average the input tensor along the given dim.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    dim : int\n        The reduce dim.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n    pass\n\n\ndef reduce_mean(input):\n    \"\"\"Returns the average of all elements in the input tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor with shape (1,)\n    
\"\"\"\n    pass\n\n\ndef max(input, dim):\n    \"\"\"Reduce max the input tensor along the given dim.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    dim : int\n        The reduce dim.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n    pass\n\n\ndef reduce_max(input):\n    \"\"\"Returns the max of all elements in the input tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor with shape (1,)\n    \"\"\"\n    pass\n\n\ndef min(input, dim):\n    \"\"\"Reduce min the input tensor along the given dim.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    dim : int\n        The reduce dim.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n    pass\n\n\ndef reduce_min(input):\n    \"\"\"Returns the min of all elements in the input tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor with shape (1,)\n    \"\"\"\n    pass\n\n\ndef argsort(input, dim, descending):\n    \"\"\"Return the indices that would sort the input along the given dim.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    dim : int\n        The dim to sort along.\n    descending : bool\n        Controls the sorting order (False: ascending, True: descending)\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n\n\ndef topk(input, k, dim, descending=True):\n    \"\"\"Return the k largest elements of the given input tensor along the given dimension.\n\n    If descending is False then the k smallest elements are returned.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    k : int\n        The number of elements.\n    dim : int\n   
     The dim to sort along.\n    descending : bool\n        Controls whether to return largest/smallest elements.\n    \"\"\"\n    pass\n\n\ndef argtopk(input, k, dim, descending=True):\n    \"\"\"Return the indices of the k largest elements of the given input tensor\n    along the given dimension.\n\n    If descending is False then the k smallest elements are returned.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    k : int\n        The number of elements.\n    dim : int\n        The dimension to sort along.\n    descending : bool\n        Controls whether to return largest/smallest elements.\n    \"\"\"\n    pass\n\n\ndef exp(input):\n    \"\"\"Returns a new tensor with the exponential of the elements of the input tensor `input`.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n\n    Returns\n    -------\n    Tensor\n        The output tensor.\n    \"\"\"\n    pass\n\n\ndef inverse(input):\n    \"\"\"Returns the inverse matrix of a square matrix if it exists.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input square matrix.\n\n    Returns\n    -------\n    Tensor\n        The output tensor.\n    \"\"\"\n    pass\n\n\ndef sqrt(input):\n    \"\"\"Returns a new tensor with the square root of the elements of the input tensor `input`.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n\n    Returns\n    -------\n    Tensor\n        The output tensor.\n    \"\"\"\n    pass\n\n\ndef softmax(input, dim=-1):\n    \"\"\"Apply the softmax function on given dimension.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    dim : int\n        The dimension along which to compute softmax.\n\n    Returns\n    -------\n    Tensor\n        The output tensor.\n    \"\"\"\n    pass\n\n\ndef cat(seq, dim):\n    \"\"\"Concat the sequence of tensors in the given dimension.\n\n    Parameters\n    ----------\n    seq : list of Tensor\n 
       The tensor sequence.\n    dim : int\n        The concat dim.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n    pass\n\n\ndef stack(seq, dim):\n    \"\"\"Stack the sequence of tensors along the given dimension.\n\n    Parameters\n    ----------\n    seq : list of Tensor\n        The tensor sequence.\n    dim : int\n        The concat dim.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n    pass\n\n\ndef split(input, sizes_or_sections, dim):\n    \"\"\"Split the input tensor into chunks.\n\n    If ``sizes_or_sections`` is an integer, then the tensor will\n    be splitted into equal pieces.\n\n    If ``sizes_or_sections`` is a list, then the tensor will be\n    splitted into segments.\n\n    Parameters\n    ----------\n    input : Tensor\n        Tensor to split.\n    sizes_or_sections : int, list[int]\n        Split sizes or sections.\n    dim : int\n        The dimension to split on.\n\n    Returns\n    -------\n    list of Tensor\n        The splitted tensors.\n    \"\"\"\n    pass\n\n\ndef repeat(input, repeats, dim):\n    \"\"\"Repeats elements of an array.\n\n    Parameters\n    ----------\n    input : Tensor\n        Input data array\n    repeats : int, Tensor\n        The number of repetitions for each element\n    dim : int\n        The dim along which to repeat values.\n\n    Returns\n    -------\n    Tensor\n        The obtained tensor.\n    \"\"\"\n    pass\n\n\ndef gather_row(data, row_index):\n    \"\"\"Slice out the data given the row index.\n\n    Parameters\n    ----------\n    data : Tensor\n        The data tensor\n    row_index : Tensor\n        A 1-D integer tensor containing which rows to be sliced out.\n\n    Returns\n    -------\n    Tensor\n        The sliced data. 
The first dimension should be equal to ``len(row_index)``.\n    \"\"\"\n    pass\n\n\ndef slice_axis(data, axis, begin, end):\n    \"\"\"Slice along a given axis.\n    Returns an array slice along a given axis starting from :attr:`begin` index to :attr:`end` index.\n\n    Parameters\n    ----------\n    data : Tensor\n        The data tensor.\n    axis : int\n        The axis along which to slice the tensor.\n    begin : int\n        Indicates the begin index.\n    end : int\n        Indicates the end index.\n\n    Returns\n    -------\n    Tensor\n        The sliced tensor.\n    \"\"\"\n    pass\n\n\ndef take(data, indices, dim):\n    \"\"\"Takes elements from an input array along the given dim.\n\n    Parameters\n    ----------\n    data : Tensor\n        The data tensor.\n    indices : Tensor\n        The indices tensor.\n    dim : int\n        The dimension to gather along.\n    \"\"\"\n    pass\n\n\ndef narrow_row(x, start, stop):\n    \"\"\"Narrow down the tensor along the first dimension.\n\n    Parameters\n    ----------\n    x : Tensor\n        The input tensor.\n    start : int\n        The start index (inclusive).\n    stop : int\n        The stop index (exclusive).\n\n    Returns\n    -------\n    Tensor\n        The narrowed tensor\n\n    Notes\n    -----\n    The returned tensor could be a view of the original tensor.\n    \"\"\"\n    pass\n\n\ndef scatter_row(data, row_index, value):\n    \"\"\"Write the value into the data tensor using the row index.\n\n    This is an out-of-place write so it can work with autograd.\n\n    Parameters\n    ----------\n    data : Tensor\n        The data tensor to be updated.\n    row_index : Tensor\n        A 1-D integer tensor containing which rows to be updated.\n    value : Tensor\n        The new value.\n\n    Returns\n    -------\n    Tensor\n        The new data.\n    \"\"\"\n    pass\n\n\ndef index_add_inplace(data, row_idx, value):\n    \"\"\"Add the values into the data tensor using the row index inplace.\n\n    If 
two row indices are the same, the corresponding values are summed up before\n    adding to the data tensor.\n\n    Examples\n    --------\n    >>> import torch as th\n    >>> arr = th.zeros((10))\n    >>> F.index_add_inplace(arr, th.tensor([0, 1, 1]), th.tensor([1.0, 1.0, 1.0]))\n    >>> arr\n    tensor([1., 2., 0., 0., 0., 0., 0., 0., 0., 0.])\n\n    Parameters\n    ----------\n    data : Tensor\n        The data tensor to be updated.\n    row_idx : Tensor\n        A 1-D integer tensor containing which rows to be updated.\n    value : Tensor\n        The values to add.\n    \"\"\"\n    pass\n\n\ndef scatter_row_inplace(data, row_index, value):\n    \"\"\"Write the value into the data tensor using the row index inplace.\n\n    This is an inplace write so it will break the autograd.\n\n    Parameters\n    ----------\n    data : Tensor\n        The data tensor to be updated.\n    row_index : Tensor\n        A 1-D integer tensor containing which rows to be updated.\n    value : Tensor\n        The new value.\n    \"\"\"\n    pass\n\n\ndef squeeze(input, dim):\n    \"\"\"Remove the given dimension of size 1.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    dim : int\n        The dimension to be squeezed.\n\n    Returns\n    -------\n    Tensor\n        The result tensor.\n    \"\"\"\n    pass\n\n\ndef unsqueeze(input, dim):\n    \"\"\"Add the given dimension of size 1.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    dim : int\n        The dimension to be unsqueezed.\n\n    Returns\n    -------\n    Tensor\n        The result tensor.\n    \"\"\"\n    pass\n\n\ndef reshape(input, shape):\n    \"\"\"Reshape the tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    shape : tuple of int\n        The new shape.\n\n    Returns\n    -------\n    Tensor\n        The reshaped tensor.\n    \"\"\"\n    pass\n\n\ndef swapaxes(input, axis1, axis2):\n    \"\"\"Interchange 
the two given axes of a tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor.\n    axis1, axis2 : int\n        The two axes.\n\n    Returns\n    -------\n    Tensor\n        The transposed tensor.\n    \"\"\"\n    pass\n\n\ndef empty(shape, dtype, ctx):\n    \"\"\"Create a tensor filled with uninitialized data.\n\n    Parameters\n    ----------\n    shape : tuple of int\n        The tensor shape.\n    dtype : data type\n        It should be one of the values in the data type dict.\n    ctx : context\n        The device of the result tensor.\n\n    Returns\n    -------\n    Tensor\n        The empty tensor.\n    \"\"\"\n    pass\n\n\ndef zeros(shape, dtype, ctx):\n    \"\"\"Create a zero tensor.\n\n    Parameters\n    ----------\n    shape : tuple of int\n        The tensor shape.\n    dtype : data type\n        It should be one of the values in the data type dict.\n    ctx : context\n        The device of the result tensor.\n\n    Returns\n    -------\n    Tensor\n        The zero tensor.\n    \"\"\"\n    pass\n\n\ndef zeros_like(input):\n    \"\"\"Create a zero tensor with the same shape, dtype and context of the\n    given tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input\n\n    Returns\n    -------\n    Tensor\n        The result\n    \"\"\"\n    pass\n\n\ndef ones(shape, dtype, ctx):\n    \"\"\"Create a one tensor.\n\n    Parameters\n    ----------\n    shape : tuple of int\n        The tensor shape.\n    dtype : data type\n        It should be one of the values in the data type dict.\n    ctx : context\n        The device of the result tensor.\n\n    Returns\n    -------\n    Tensor\n        The one tensor.\n    \"\"\"\n    pass\n\n\ndef uniform(shape, dtype, ctx, low, high):\n    \"\"\"Create a tensor with random value in a uniform\n    distribution between low (inclusive) and high (exclusive).\n\n    Parameters\n    ----------\n    shape : tuple of int\n        The tensor shape.\n    dtype : 
data type\n        It should be one of the values in the data type dict.\n    ctx : context\n        The device of the result tensor.\n\n    Returns\n    -------\n    Tensor\n        The random tensor.\n    \"\"\"\n    pass\n\n\ndef randint(shape, dtype, ctx, low, high):\n    \"\"\"Create a tensor with random value in a uniform integer\n    distribution between low (inclusive) and high (exclusive)\n\n    Parameters\n    ----------\n    shape : tuple of int\n        The tensor shape.\n    dtype : data type\n        It should be one of the values in the data type dict.\n    ctx : context\n        The device of the result tensor.\n\n    Returns\n    -------\n    Tensor\n        The random tensor.\n    \"\"\"\n    pass\n\n\ndef pad_packed_tensor(input, lengths, value, l_min=None):\n    r\"\"\"Pads a packed batch of variable length tensors with given value.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor with shape :math:`(N, *)`\n    lengths : list or tensor\n        The array of tensor lengths (of the first dimension) :math:`L`.\n        It should satisfy :math:`\\sum_{i=1}^{B}L_i = N`,\n        where :math:`B` is the length of :math:`L`.\n    value : float\n        The value to fill in the tensor.\n    l_min : int or None, defaults to None.\n        The minimum length each tensor need to be padded to, if set to None,\n        then there is no minimum length requirement.\n\n    Returns\n    -------\n    Tensor\n        The obtained tensor with shape :math:`(B, \\max(\\max_i(L_i), l_{min}), *)`\n    \"\"\"\n    pass\n\n\ndef pack_padded_tensor(input, lengths):\n    r\"\"\"Packs a tensor containing padded sequence of variable length.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor with shape :math:`(B, L, *)`, where :math:`B` is\n        the batch size and :math:`L` is the maximum length of the batch.\n    lengths : list or tensor\n        The array of tensor lengths (of the first dimension) :math:`L`.\n    
    :math:`\\max_i(L_i)` should equal :math:`L`.\n\n    Returns\n    -------\n    Tensor\n        The obtained tensor with shape :math:`(N, *)` where\n        :math:`N = \\sum_{i=1}^{B}L_i`\n    \"\"\"\n    pass\n\n\ndef boolean_mask(input, mask):\n    \"\"\"Selects elements in x according to the given mask from the first\n    dimension.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor\n    mask : Boolean Tensor\n        The mask\n\n    Returns\n    -------\n    Tensor\n        The result\n    \"\"\"\n    pass\n\n\ndef equal(x, y):\n    \"\"\"Compares whether the elements are equal.\n\n    Parameters\n    ----------\n    x, y : Tensor\n        The two tensors\n\n    Returns\n    -------\n    Boolean or integer tensor\n        The result, with the same shape as input.\n    \"\"\"\n    pass\n\n\ndef allclose(x, y, rtol=1e-4, atol=1e-4):\n    \"\"\"Compares whether all elements are close.\n\n    Parameters\n    ----------\n    x : Tensor\n        First tensor\n    y : Tensor\n        Second tensor\n    rtol : float, optional\n        Relative tolerance\n    atol : float, optional\n        Absolute tolerance\n    \"\"\"\n\n\ndef logical_not(input):\n    \"\"\"Perform a logical not operation.  
Equivalent to np.logical_not\n\n    Parameters\n    ----------\n    input : Tensor\n        The input\n\n    Returns\n    -------\n    Tensor\n        The result\n    \"\"\"\n    pass\n\n\ndef logical_and(input1, input2):\n    pass\n\n\ndef clone(input):\n    \"\"\"Return a clone of the input tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        Input tensor.\n\n    Returns\n    -------\n    Tensor\n        A clone tensor.\n    \"\"\"\n    pass\n\n\ndef clamp(data, min_val, max_val):\n    \"\"\"Clamp all elements in :attr:`input` into the range [min_val, max_val]\n    and return a resulting tensor.\n\n    Parameters\n    ----------\n    data : Tensor\n        Input tensor\n    min_val : Scalar\n        Min value.\n    max_val : Scalar\n        Max value.\n\n    Returns\n    -------\n    Tensor\n        The result.\n    \"\"\"\n    pass\n\n\ndef replace_inf_with_zero(x):\n    \"\"\"Returns a new tensor replacing infinity and negative infinity with zeros.\n\n    Parameters\n    ----------\n    x : Tensor\n        The input\n\n    Returns\n    -------\n    Tensor\n        The result\n    \"\"\"\n    pass\n\n\ndef count_nonzero(input):\n    \"\"\"Return the count of non-zero values in the tensor input.\n\n    Parameters\n    ----------\n    input : Tensor\n        The tensor to be counted\n\n    Returns\n    -------\n    Integer\n        The result\n    \"\"\"\n    pass\n\n\n###############################################################################\n# Tensor functions used *only* on index tensor\n# ----------------\n# These operators are light-weighted, so it is acceptable to fallback to\n# numpy operators if currently missing in the framework. 
Ideally in the future,\n# DGL should contain all the operations on index, so this set of operators\n# should be gradually removed.\n\n\ndef unique(input, return_inverse=False, return_counts=False):\n    \"\"\"Returns the unique scalar elements in a tensor.\n\n    Parameters\n    ----------\n    input : Tensor\n        Must be a 1-D tensor.\n    return_inverse : bool, optional\n        Whether to also return the indices for where elements in the original\n        input ended up in the returned unique list.\n    return_counts : bool, optional\n        Whether to also return the counts for each unique element.\n\n    Returns\n    -------\n    Tensor\n        A 1-D tensor containing unique elements.\n    Tensor, optional\n        A 1-D tensor containing the new positions of the elements in the input.\n        It is returned if return_inverse is True.\n    Tensor, optional\n        A 1-D tensor containing the number of occurrences for each unique value or tensor.\n        It is returned if return_counts is True.\n    \"\"\"\n    pass\n\n\ndef full_1d(length, fill_value, dtype, ctx):\n    \"\"\"Create a 1D tensor full of the fill_value.\n\n    Parameters\n    ----------\n    length : int\n        The length of the vector.\n    fill_value : int\n        The filled value.\n    dtype : data type\n        It should be one of the values in the data type dict.\n    ctx : context\n        The device of the result tensor.\n\n    Returns\n    -------\n    Tensor\n        A result 1D tensor\n    \"\"\"\n    pass\n\n\ndef nonzero_1d(input):\n    \"\"\"Return the nonzero index of the given 1D input.\n\n    Parameters\n    ----------\n    input : Tensor\n        Must be a 1D tensor.\n\n    Returns\n    -------\n    Tensor\n        A 1D integer tensor containing the nonzero indices.\n    \"\"\"\n    pass\n\n\ndef sort_1d(input):\n    \"\"\"Sort a 1D tensor (in ascending order) and also return the original index.\n\n    Parameters\n    ----------\n    input : Tensor\n        The tensor 
to be sorted.\n\n    Returns\n    -------\n    Tensor\n        Sorted tensor.\n    Tensor\n        Index tensor of the elements in the original input.\n    \"\"\"\n    pass\n\n\ndef arange(start, stop, dtype, ctx):\n    \"\"\"Create a 1D range int64 tensor.\n\n    Parameters\n    ----------\n    start : int\n        The range start.\n    stop : int\n        The range stop.\n    dtype: str\n        The dtype of result tensor.\n    ctx : Device context object.\n        Device context.\n\n    Returns\n    -------\n    Tensor\n        The result tensor.\n    \"\"\"\n    pass\n\n\ndef rand_shuffle(arr):\n    \"\"\"Random shuffle the data in the first dimension of the array.\n\n    The shuffled data is stored in a new array.\n\n    Parameters\n    ----------\n    arr : Tensor\n        The data tensor\n\n    Returns\n    -------\n    Tensor\n        The result tensor\n    \"\"\"\n    pass\n\n\ndef zerocopy_to_dlpack(input):\n    \"\"\"Create a dlpack tensor that shares the input memory.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor\n\n    Returns\n    -------\n    dlpack capsule\n        A dlpack capsule that can be used by other framework.\n    \"\"\"\n    pass\n\n\ndef zerocopy_from_dlpack(dlpack_tensor):\n    \"\"\"Create a tensor that shares the dlpack_tensor.\n\n    Parameters\n    ----------\n    dlpack_tensor : dlpack capsule\n        The dlpack tensor.\n\n    Returns\n    -------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n    pass\n\n\ndef zerocopy_to_numpy(input):\n    \"\"\"Create a numpy ndarray that shares the input memory.\n\n    Parameters\n    ----------\n    input : Tensor\n        The input tensor\n\n    Returns\n    -------\n    numpy.ndarray\n        A numpy ndarray.\n    \"\"\"\n    pass\n\n\ndef zerocopy_from_numpy(np_array):\n    \"\"\"Create a tensor that shares the numpy array.\n\n    Parameters\n    ----------\n    np_array : numpy.ndarray\n        The numpy ndarray.\n\n    Returns\n    
-------\n    Tensor\n        A framework-specific tensor.\n    \"\"\"\n    pass\n\n\ndef zerocopy_to_dgl_ndarray(input):\n    \"\"\"Zerocopy a framework-specific Tensor to dgl.ndarray.NDArray\n\n    Parameters\n    ----------\n    input : Tensor\n\n    Returns\n    -------\n    dgl.ndarray.NDArray\n    \"\"\"\n    pass\n\n\ndef zerocopy_to_dgl_ndarray_for_write(input):\n    \"\"\"Zerocopy a framework-specific Tensor to dgl.ndarray.NDArray\n    that is ready for write (required in MXNet).\n\n    Parameters\n    ----------\n    input : Tensor\n\n    Returns\n    -------\n    dgl.ndarray.NDArray\n    \"\"\"\n    pass\n\n\ndef zerocopy_from_dgl_ndarray(input):\n    \"\"\"Zerocopy a dgl.ndarray.NDArray to framework-specific Tensor\n\n    Parameters\n    ----------\n    input : dgl.ndarray.NDArray\n\n    Returns\n    -------\n    Tensor\n    \"\"\"\n    pass\n\n\n###############################################################################\n# Custom Operators for graph level computations.\n\n# Note: These operators are supposed to be implemented using DGL-provided\n# kernels (see kernel.py), and plug into tensor framework using custom op\n# extensions.\n\n\ndef binary_reduce(\n    reducer,\n    binary_op,\n    graph,\n    lhs,\n    rhs,\n    lhs_data,\n    rhs_data,\n    out_size,\n    lhs_map,\n    rhs_map,\n    out_map,\n):\n    \"\"\"Perform binary operation between given data and reduce based on graph\n    structure.\n\n    Parameters\n    ----------\n    reducer : str\n        Type of reduction: 'sum', 'max', 'min', 'mean', 'prod', 'none' (no\n        reduction)\n    binary_op : str\n        Binary operation to perform, can be 'add', 'mul', 'sub', 'div'\n    graph : GraphIndex\n        The graph\n    lhs : int\n        The lhs target (src, dst, edge)\n    rhs : int\n        The rhs target (src, dst, edge)\n    lhs_data : Tensor\n        The lhs data\n    rhs_data : Tensor\n        The rhs data\n    out_size : int\n        Size of first dimension of output data\n   
 lhs_map : tuple\n        Two lhs id mapping arrays, one for forward pass, the other for backward\n    rhs_map : tuple\n        Two rhs id mapping arrays, one for forward pass, the other for backward\n    out_map : tuple\n        Two out id mapping arrays, one for forward pass, the other for backward\n\n    Returns\n    -------\n    Tensor\n        The result.\n    \"\"\"\n    pass\n\n\ndef copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):\n    \"\"\"Copy target data and perform reduce based on graph structure.\n\n    Parameters\n    ----------\n    reducer : str\n        Type of reduction: 'sum', 'max', 'min', 'mean', 'prod', 'none' (no\n        reduction)\n    graph : GraphIndex\n        The graph\n    target : int\n        The input target (src, dst, edge)\n    in_data : Tensor\n        The input data\n    out_size : int\n        Size of first dimension of output data\n    in_map : tuple\n        Two input id mapping arrays, one for forward, the other for backward\n    out_map : tuple\n        Two output id mapping arrays, one for forward, the other for backward\n\n    Returns\n    -------\n    Tensor\n        The result.\n    \"\"\"\n    pass\n\n\ndef gspmm(gidx, op, reduce_op, lhs_data, rhs_data):\n    r\"\"\"Generalized Sparse Matrix Multiplication interface.\n    It fuses two steps into one kernel.\n    (1) Computes messages by :attr:`op` source node and edge features.\n    (2) Aggregate the messages by :attr:`reduce_op` as the features on destination nodes.\n\n    .. math::\n        x_v = \\psi_{(u, v, e)\\in \\mathcal{G}}(\\rho(x_u, x_e))\n\n    where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,\n    :math:`x_e` refers to :attr:`u`, :attr:`e` respectively. 
:math:`\\rho` means binary\n    operator :attr:`op` and :math:`\\psi` means reduce operator :attr:`reduce_op`,\n    :math:`\\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.\n\n    Note that this function does not handle gradients.\n\n    Parameters\n    ----------\n    gidx : HeteroGraphIndex\n        The input graph.\n    op : str\n        The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,\n        ``copy_lhs``, ``copy_rhs``.\n    reduce_op : str\n        Reduce operator, could be ``sum``, ``max``, ``min``.\n    lhs_data : tensor or None\n        The left operand, could be None if it's not required by the op.\n    rhs_data : tensor or None\n        The right operand, could be None if it's not required by the op.\n\n    Returns\n    -------\n    tensor\n        The result tensor.\n    \"\"\"\n    pass\n\n\ndef gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):\n    r\"\"\"Generalized Sparse Matrix Multiplication interface on heterogeneous graph.\n    All the relation types of the heterogeneous graph will be processed together.\n    It fuses two steps into one kernel.\n    (1) Computes messages by :attr:`op` source node and edge features.\n    (2) Aggregate the messages by :attr:`reduce_op` as the features on destination nodes.\n\n    .. math::\n        x_v = \\psi_{(u, v, e)\\in \\mathcal{G}}(\\rho(x_u, x_e))\n\n    where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,\n    :math:`x_e` refers to :attr:`u`, :attr:`e` respectively. 
:math:`\\rho` means binary\n    operator :attr:`op` and :math:`\\psi` means reduce operator :attr:`reduce_op`,\n    :math:`\\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.\n\n    Note that this function does not handle gradients.\n\n    Parameters\n    ----------\n    g : HeteroGraph\n        The input graph.\n    op : str\n        The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,\n        ``copy_lhs``, ``copy_rhs``.\n    reduce_op : str\n        Reduce operator, could be ``sum``, ``max``, ``min``.\n    lhs_len : int\n        Length of the lhs data\n    lhs_and_rhs_tuple : tuple of tensors\n        lhs_data and rhs_data are concatenated to one tuple. lhs_data is\n        also a tuple of tensors of size number of ntypes. Same is true for\n        rhs_data.\n        The tensor(s) in the tuple could be None\n\n    Returns\n    -------\n    tuple of tensor\n        The resulting tuple of tensor.\n    \"\"\"\n    pass\n\n\ndef gsddmm(gidx, op, lhs_data, rhs_data, lhs_target=\"u\", rhs_target=\"v\"):\n    r\"\"\"Generalized Sampled-Dense-Dense Matrix Multiplication interface.\n    It computes edge features by :attr:`op` lhs features and rhs features.\n\n    .. math::\n        x_{e} = \\phi(x_{lhs}, x_{rhs}), \\forall (u,e,v)\\in \\mathcal{G}\n\n    where :math:`x_{e}` is the returned feature on edges and :math:`x_u`,\n    :math:`x_v` refers to :attr:`u`, :attr:`v` respectively. :math:`\\phi`\n    is the binary operator :attr:`op`, and :math:`\\mathcal{G}` is the graph\n    we apply gsddmm on: :attr:`g`. 
$lhs$ and $rhs$ are one of $u,v,e$'s.\n\n    Parameters\n    ----------\n    gidx : HeteroGraphIndex\n        The input graph.\n    op : str\n        Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,\n        ``copy_lhs``, ``copy_rhs``.\n    lhs_data : tensor or None\n        The left operand, could be None if it's not required by op.\n    rhs_data : tensor or None\n        The right operand, could be None if it's not required by op.\n    lhs_target: str\n        Choice of `u`(source), `e`(edge) or `v`(destination) for left operand.\n    rhs_target: str\n        Choice of `u`(source), `e`(edge) or `v`(destination) for right operand.\n\n    Returns\n    -------\n    tensor\n        The result tensor.\n    \"\"\"\n    pass\n\n\ndef gsddmm_hetero(\n    g, op, lhs_len, lhs_target=\"u\", rhs_target=\"v\", *lhs_and_rhs_tuple\n):\n    r\"\"\"Generalized Sampled-Dense-Dense Matrix Multiplication interface on\n    heterogeneous graph. All the relation types of the heterogeneous graph\n    will be processed together.\n    It computes edge features by :attr:`op` lhs features and rhs features.\n\n    .. math::\n        x_{e} = \\phi(x_{lhs}, x_{rhs}), \\forall (u,e,v)\\in \\mathcal{G}\n\n    where :math:`x_{e}` is the returned feature on edges and :math:`x_u`,\n    :math:`x_v` refers to :attr:`u`, :attr:`v` respectively. :math:`\\phi`\n    is the binary operator :attr:`op`, and :math:`\\mathcal{G}` is the graph\n    we apply gsddmm on: :attr:`g`. 
$lhs$ and $rhs$ are one of $u,v,e$'s.\n\n    Parameters\n    ----------\n    g : HeteroGraph\n        The input graph.\n    op : str\n        Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,\n        ``copy_lhs``, ``copy_rhs``.\n    lhs_len : int\n        Length of the lhs data\n    lhs_target: str\n        Choice of `u`(source), `e`(edge) or `v`(destination) for left operand.\n    rhs_target: str\n        Choice of `u`(source), `e`(edge) or `v`(destination) for right operand.\n    lhs_and_rhs_tuple : tuple of tensors\n        lhs_data and rhs_data are concatenated to one tuple. lhs_data is\n        also a tuple of tensors of size number of ntypes. Same is true for\n        rhs_data.\n        The tensor(s) in the tuple could be None\n\n    Returns\n    -------\n    tuple of tensor\n        The resulting tuple of tensor.\n    \"\"\"\n    pass\n\n\ndef edge_softmax(gidx, logits, eids, norm_by):\n    r\"\"\"Compute edge softmax.\n\n    For a node :math:`i`, edge softmax is an operation of computing\n\n    .. math::\n      a_{ij} = \\frac{\\exp(z_{ij})}{\\sum_{j\\in\\mathcal{N}(i)}\\exp(z_{ij})}\n\n    where :math:`z_{ij}` is a signal of edge :math:`j\\rightarrow i`, also\n    called logits in the context of softmax. :math:`\\mathcal{N}(i)` is\n    the set of nodes that have an edge to :math:`i`.\n\n    By default edge softmax is normalized by destination nodes(i.e. :math:`ij`\n    are incoming edges of `i` in the formula above). We also support edge\n    softmax normalized by source nodes(i.e. :math:`ij` are outgoing edges of\n    `i` in the formula). The former case corresponds to softmax in GAT and\n    Transformer, and the latter case corresponds to softmax in Capsule network.\n\n    Parameters\n    ----------\n    gidx : HeteroGraphIndex\n        The graph to perform edge softmax on.\n    logits : torch.Tensor\n        The input edge feature\n    eids : torch.Tensor or ALL, optional\n        Edges on which to apply edge softmax. 
If ALL, apply edge\n        softmax on all edges in the graph. Default: ALL.\n    norm_by : str, could be `src` or `dst`\n        Normalized by source nodes or destination nodes. Default: `dst`.\n\n    Returns\n    -------\n    Tensor\n        Softmax value\n    \"\"\"\n    pass\n\n\ndef edge_softmax_hetero(gidx, eids, norm_by, *logits):\n    r\"\"\"Compute edge softmax.\n\n    For a node :math:`i`, edge softmax is an operation of computing\n\n    .. math::\n      a_{ij} = \\frac{\\exp(z_{ij})}{\\sum_{j\\in\\mathcal{N}(i)}\\exp(z_{ij})}\n\n    where :math:`z_{ij}` is a signal of edge :math:`j\\rightarrow i`, also\n    called logits in the context of softmax. :math:`\\mathcal{N}(i)` is\n    the set of nodes that have an edge to :math:`i`.\n\n    By default edge softmax is normalized by destination nodes(i.e. :math:`ij`\n    are incoming edges of `i` in the formula above). We also support edge\n    softmax normalized by source nodes(i.e. :math:`ij` are outgoing edges of\n    `i` in the formula). The former case corresponds to softmax in GAT and\n    Transformer, and the latter case corresponds to softmax in Capsule network.\n\n    Parameters\n    ----------\n    gidx : HeteroGraphIndex\n        The graph to perform edge softmax on.\n    eids : dict of tensors\n        Each tensor has the edges on which to apply edge softmax for a\n        corresponding relation type.\n    logits : tuple of tensors\n        The input edge features of different relation types.\n    norm_by : str, could be `src` or `dst`\n        Normalized by source nodes or destination nodes. Default: `dst`.\n\n    Returns\n    -------\n    Tensor\n        Softmax value\n    \"\"\"\n    pass\n\n\ndef segment_reduce(op, x, offsets):\n    \"\"\"Segment reduction operator.\n\n    It aggregates the value tensor along the first dimension by segments.\n    The argument ``offsets`` specifies the start offset of each segment (and\n    the upper bound of the last segment). 
Zero-length segments are allowed.\n\n    .. math::\n      y_i = \\Phi_{j=\\mathrm{offsets}_i}^{\\mathrm{offsets}_{i+1}-1} x_j\n\n    where :math:`\\Phi` is the reduce operator.\n\n    Parameters\n    ----------\n    op : str\n        Aggregation method. Can be ``sum``, ``max``, ``min``.\n    x : Tensor\n        Value to aggregate.\n    offsets : Tensor\n        The start offsets of segments.\n\n    Returns\n    -------\n    Tensor\n        Aggregated tensor of shape ``(len(offsets) - 1, value.shape[1:])``.\n    \"\"\"\n    pass\n\n\ndef scatter_add(x, idx, m):\n\n    \"\"\"Scatter add (on first dimension) operator.\n\n    Math: y[idx[i], *] += x[i, *]\n\n    Parameters\n    ----------\n    x : Tensor\n        The input feature.\n    idx : Tensor\n        The indices array.\n    m : int\n        The length of output.\n\n    Returns\n    -------\n    Tensor\n        The output tensor.\n    \"\"\"\n    pass\n\n\ndef csrmm(A, A_weights, B, B_weights, num_vtypes):\n    \"\"\"Compute weighted adjacency matrix multiplication.\n\n    Notes\n    -----\n    Both A and B must allow creation of CSR representations, and must be simple graphs\n    (i.e. having at most one edge between two nodes).\n\n    The output unit graph has no format restriction.\n\n    Parameters\n    ----------\n    A : HeteroGraphIndex\n        The unit graph as left operand.\n    A_weights : Tensor\n        The edge weights of A.  Must be a 1D vector.\n    B : HeteroGraphIndex\n        The unit graph as right operand.\n    B_weights : Tensor\n        The edge weights of B.  Must be a 1D vector.\n    num_vtypes : int\n        The number of node types of the output graph.  
Must be either 1 or 2.\n\n    Returns\n    -------\n    HeteroGraphIndex\n        The output unit graph.\n    Tensor\n        The output edge weights.\n    \"\"\"\n    pass\n\n\ndef csrsum(gidxs, weights):\n    \"\"\"Compute weighted adjacency matrix summation.\n\n    Notes\n    -----\n    All unit graphs must allow creation of CSR representations, and must be simple graphs\n    (i.e. having at most one edge between two nodes).\n\n    The output unit graph has no format restriction.\n\n    Parameters\n    ----------\n    gidxs : list[HeteroGraphIndex]\n        The unit graphs.\n    weights : list[Tensor]\n        The edge weights of each graph.  Must be 1D vectors.\n\n    Returns\n    -------\n    HeteroGraphIndex\n        The output unit graph.\n    Tensor\n        The output edge weights.\n    \"\"\"\n    pass\n\n\ndef csrmask(A, A_weights, B):\n    \"\"\"Retrieve the values in the weighted adjacency matrix of graph :attr:`A` at the\n    non-zero positions of graph :attr:`B`'s adjacency matrix.\n\n    In scipy, this is equivalent to ``A[B != 0]``.\n\n    Notes\n    -----\n    Both A and B must allow creation of CSR representations, and must be simple graphs\n    (i.e. having at most one edge between two nodes).\n\n    Parameters\n    ----------\n    A : HeteroGraphIndex\n        The unit graph as left operand.\n    A_weights : Tensor\n        The edge weights of A.  Must be a 1D vector.\n    B : HeteroGraphIndex\n        The unit graph as right operand.\n\n    Returns\n    -------\n    Tensor\n        The output tensor.\n    \"\"\"\n    pass\n\n\ndef gather_mm(A, B, idx_a, idx_b):\n    r\"\"\"Dense Matrix Multiplication interface. It multiplies 2D dense tensor A\n    and 3D dense tensor B according to their relation types. 
A is unsorted and\n    the relation type is fetched from idx_b.\n\n    Parameters\n    ----------\n    A : tensor\n        2-D tensor of shape (N, D1)\n    B : tensor\n        3-D tensor of shape (R, D1, D2)\n    idx_a : Tensor, optional\n        If specified, must be a 1-D integer tensor of shape (K,).\n    idx_b : Tensor, optional\n        If specified, must be a 1-D integer tensor of shape (K,).\n\n    Returns\n    -------\n    Tensor\n        The output dense matrix of shape (N, D2)\n    \"\"\"\n    pass\n\n\ndef segment_mm(A, B, seglen_A):\n    r\"\"\"Dense Matrix Multiplication interface. It multiplies dense tensor A\n    and dense tensor B according to relation types. A is sorted and concatenated\n    according to relation types.\n\n    Parameters\n    ----------\n    A : tensor\n        2-D tensor of shape (N, D1)\n    B : tensor\n        3-D tensor of shape (R, D1, D2)\n    seglen_A : Tensor\n        An integer tensor of shape (R,). Each element is the length of segments\n        of input ``A``. The summation of all elements must be equal to N.\n\n    Returns\n    -------\n    Tensor\n        The output dense matrix of shape (N, D2)\n    \"\"\"\n    pass\n\n\n###############################################################################\n# Other interfaces\n# ----------------\n# These are not related to tensors. Some of them are temporary workarounds that\n# should be included in DGL in the future.\n\n\ndef sync():\n    \"\"\"Synchronize computation.\n\n    In DL frameworks such as MXNet and TensorFlow, the computation in operators\n    are done asynchronously. 
This is to synchronize computation and makes sure\n    that all computation is complete after this function call.\n    \"\"\"\n    pass\n\n\ndef attach_grad(tensor):\n    \"\"\"Attach gradients to the input tensor\"\"\"\n    pass\n\n\ndef backward(x, head_gradient=None):\n    \"\"\"Invoke backward computation with an optional head gradient.\"\"\"\n    pass\n\n\ndef grad(x):\n    \"\"\"Fetches the gradient from the tensor after backward computation.\"\"\"\n    pass\n\n\ndef is_no_grad(x):\n    \"\"\"Test if the input tensor has gradient\"\"\"\n    pass\n\n\ndef is_recording():\n    \"\"\"Test if the execution is recording gradients.\"\"\"\n    pass\n\n\nclass record_grad(object):\n    \"\"\"Context manager that records the gradients\"\"\"\n\n    def __init__(self):\n        pass\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        pass\n\n\nclass no_grad(object):\n    \"\"\"Context manager that explicitly disables gradient computation\"\"\"\n\n    def __init__(self):\n        pass\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        pass\n\n\nclass NodeEmbedding(object):\n    \"\"\"Sparse node embeddings\"\"\"\n\n    def __init__(self):\n        pass\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        pass\n"
  },
  {
    "path": "python/dgl/backend/mxnet/__init__.py",
    "content": "from .sparse import *\nfrom .tensor import *\n"
  },
  {
    "path": "python/dgl/backend/mxnet/sparse.py",
    "content": "import mxnet as mx\nimport numpy as np\nfrom mxnet import nd\n\nfrom ..._sparse_ops import (\n    _bwd_segment_cmp,\n    _csrmask,\n    _csrmm,\n    _csrsum,\n    _gsddmm,\n    _gspmm,\n    _scatter_add,\n    _segment_reduce,\n)\n\nfrom ...base import ALL, dgl_warning, is_all\nfrom ...heterograph_index import create_unitgraph_from_csr\nfrom .tensor import (\n    asnumpy,\n    context,\n    copy_to,\n    to_backend_ctx,\n    zerocopy_from_numpy,\n)\n\n__all__ = [\n    \"gspmm\",\n    \"gsddmm\",\n    \"edge_softmax\",\n    \"segment_reduce\",\n    \"scatter_add\",\n    \"csrmm\",\n    \"csrsum\",\n    \"csrmask\",\n]\n\n\ndef _scatter_nd(index, src, n_rows):\n    \"\"\"Similar to PyTorch's scatter nd on first dimension.\"\"\"\n    assert index.shape == src.shape\n    dgl_warning(\"MXNet do not support scatter_add, fallback to numpy.\")\n    ctx = context(src)\n    index = asnumpy(index)\n    src = asnumpy(src)\n    shp = index.shape\n    ndim = src.ndim\n    offsets = []\n    stride = 1\n    for i in reversed(range(1, ndim)):\n        di = shp[i]\n        offset_i = np.arange(di, dtype=index.dtype)\n        offsets.append(\n            (stride * offset_i).reshape(\n                (1,) * i + (di,) + (1,) * (ndim - 1 - i)\n            )\n        )\n        stride *= di\n    if ndim > 1:\n        new_idx = index * stride + sum(offsets)\n    else:\n        new_idx = index\n    src = src.reshape(-1)\n    new_idx = new_idx.reshape(-1)\n    rst = np.zeros((stride * n_rows,), dtype=src.dtype)\n    np.add.at(rst, new_idx, src)\n    rst = rst.reshape(n_rows, *shp[1:])\n    rst = copy_to(zerocopy_from_numpy(rst), ctx)\n    return rst\n\n\ndef _gather_nd(index, src):\n    \"\"\"Similar to PyTorch's gather nd on first dimension.\"\"\"\n    ctx = context(src)\n    shp = index.shape\n    ndim = src.ndim\n    offsets = []\n    stride = 1\n    for i in reversed(range(1, ndim)):\n        di = shp[i]\n        offset_i = nd.arange(di, dtype=index.dtype)\n        
offsets.append(\n            (stride * offset_i).reshape(\n                (1,) * i + (di,) + (1,) * (ndim - 1 - i)\n            )\n        )\n        stride *= di\n    if ndim > 1:\n        new_idx = index * stride + copy_to(sum(offsets), ctx)\n    else:\n        new_idx = index\n    src = src.reshape(-1)\n    new_idx = new_idx.reshape(-1)\n    rst = nd.take(src, new_idx).reshape(shp)\n    return rst\n\n\ndef _reduce_grad(grad, shape):\n    \"\"\"Reduce gradient on the broadcast dimension\n    If there is broadcast in forward pass, gradients need to be reduced on\n    broadcast dimension. This function checks the input tensor shape and\n    gradient shape and perform the reduction.\n\n    Parameters\n    ----------\n    grad: Tensor\n        Gradient tensor\n    shape: tuple\n        Shape of input tensor\n\n    Returns\n    -------\n    Tensor\n    \"\"\"\n    grad_shape = grad.shape[1:]\n    in_shape = shape[1:]\n    if in_shape == grad_shape:\n        # no need to reduce\n        return grad\n    num_to_squeeze = len(grad_shape) - len(in_shape)\n    # pad inshape\n    in_shape = (1,) * num_to_squeeze + in_shape\n    # pad in_shape\n    in_shape = (1,) * num_to_squeeze + in_shape\n    reduce_idx = np.nonzero(np.asarray(grad_shape) - np.asarray(in_shape))[0]\n    reduce_idx += 1  # skip batch dim\n    grad = grad.sum(axis=tuple(reduce_idx), keepdims=True)\n    return grad.reshape(shape)\n\n\ndef _need_reduce_last_dim(ufeat, efeat):\n    \"\"\"Indicates whether to reduce the last dimension on edges\n    in the backward pass of spmm,\n    if so, use dot instead of mul.\"\"\"\n    ushp = ufeat.shape\n    eshp = efeat.shape\n    return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1\n\n\ndef _muldiv(op, x):\n    return 1.0 / x if op == \"div\" else x\n\n\ndef _addsub(op, x):\n    return -x if op == \"sub\" else x\n\n\ndef _expand(x, shape):\n    return x.broadcast_to((x.shape[0], *shape))\n\n\nclass GSpMM(mx.autograd.Function):\n    def __init__(self, 
gidx, op, reduce_op):\n        super(GSpMM, self).__init__()\n        self.gidx = gidx\n        self.op = op\n        self.reduce_op = reduce_op\n\n    def forward(self, X, Y):\n        out, (argX, argY) = _gspmm(self.gidx, self.op, self.reduce_op, X, Y)\n        self.save_for_backward(X, Y, argX, argY)\n        return out\n\n    def backward(self, dZ):\n        ctx = context(dZ)\n        X, Y, argX, argY = self.saved_tensors\n        gidx, op, reduce_op = self.gidx, self.op, self.reduce_op\n        if op != \"copy_rhs\":\n            g_rev = gidx.reverse()\n            if reduce_op == \"sum\":\n                if op in [\"mul\", \"div\"]:\n                    dX = _gspmm(g_rev, \"mul\", \"sum\", dZ, _muldiv(op, Y))[0]\n                elif op in [\"add\", \"sub\"]:\n                    dX = _gspmm(g_rev, \"copy_lhs\", \"sum\", dZ, Y)[0]\n                elif op == \"copy_lhs\":\n                    dX = _gspmm(g_rev, \"copy_lhs\", \"sum\", dZ, None)[0]\n            else:\n                if op in [\"mul\", \"div\"]:\n                    dX = _scatter_nd(\n                        argX,\n                        _muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:])))\n                        * dZ,\n                        X.shape[0],\n                    )\n                elif op in [\"add\", \"sub\", \"copy_lhs\"]:\n                    dX = _scatter_nd(argX, dZ, X.shape[0])\n            dX = _reduce_grad(dX, X.shape)\n        else:\n            dX = nd.zeros_like(X)\n        if op != \"copy_lhs\":\n            if reduce_op == \"sum\":\n                if op == \"mul\" and _need_reduce_last_dim(X, Y):\n                    dY = _gsddmm(gidx, \"dot\", X, dZ)\n                elif op in [\"mul\", \"div\"]:\n                    dY = _gsddmm(gidx, \"mul\", X, dZ)\n                    if op == \"div\":\n                        dY = -dY / (Y**2)\n                elif op in [\"add\", \"sub\", \"copy_rhs\"]:\n                    dY = _gsddmm(gidx, \"copy_rhs\", X, 
_addsub(op, dZ))\n            else:\n                if op in [\"mul\", \"div\"]:\n                    dY = _scatter_nd(\n                        argY,\n                        _gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ,\n                        Y.shape[0],\n                    )\n                    if op == \"div\":\n                        dY = -dY / (Y**2)\n                elif op in [\"add\", \"sub\", \"copy_rhs\"]:\n                    dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])\n            dY = _reduce_grad(dY, Y.shape)\n        else:\n            dY = nd.zeros_like(Y)\n        self.saved_tensors = None\n        return dX, dY\n\n\ndef gspmm(gidx, op, reduce_op, lhs_data, rhs_data):\n    func = GSpMM(gidx, op, reduce_op)\n    ctx = to_backend_ctx(gidx.ctx)\n    # XXX(minjie): There is a bug in MXNet's autograd system when one of the inputs\n    #   does not require gradient. Although it still invokes the backward function,\n    #   it does not set the gradient value to the correct buffer, resulting all the\n    #   input gradients to be zero. 
Fix this by enforcing all the inputs to require\n    #   gradients.\n    if lhs_data is None:\n        lhs_data = nd.zeros((1,), ctx=ctx)\n        lhs_data.attach_grad()\n    if rhs_data is None:\n        rhs_data = nd.zeros((1,), ctx=ctx)\n        rhs_data.attach_grad()\n    return func(lhs_data, rhs_data)\n\n\nclass GSDDMM(mx.autograd.Function):\n    def __init__(self, gidx, op, lhs_target, rhs_target):\n        super(GSDDMM, self).__init__()\n        self.gidx = gidx\n        self.op = op\n        self.lhs_target = lhs_target\n        self.rhs_target = rhs_target\n\n    def forward(self, X, Y):\n        out = _gsddmm(\n            self.gidx, self.op, X, Y, self.lhs_target, self.rhs_target\n        )\n        self.save_for_backward(X, Y)\n        return out\n\n    def backward(self, dZ):\n        ctx = context(dZ)\n        X, Y = self.saved_tensors\n        gidx, op = self.gidx, self.op\n        lhs_target, rhs_target = self.lhs_target, self.rhs_target\n        if op != \"copy_rhs\":\n            if lhs_target in [\"u\", \"v\"]:\n                _gidx = gidx if self.lhs_target == \"v\" else gidx.reverse()\n                if op in [\"add\", \"sub\", \"copy_lhs\"]:\n                    dX = _gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ)[0]\n                else:  # mul, div, dot\n                    if rhs_target == lhs_target:\n                        dX = _gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ)[\n                            0\n                        ] * _muldiv(op, Y)\n                    elif self.rhs_target == \"e\":\n                        dX = _gspmm(\n                            _gidx, \"copy_rhs\", \"sum\", None, dZ * _muldiv(op, Y)\n                        )[0]\n                    else:  # rhs_target = !lhs_target\n                        dX = _gspmm(_gidx, \"mul\", \"sum\", _muldiv(op, Y), dZ)[0]\n            else:  # lhs_target == 'e'\n                if op in [\"add\", \"sub\", \"copy_lhs\"]:\n                    dX = dZ\n                
else:  # mul, div, dot\n                    dX = _gsddmm(\n                        gidx, \"mul\", dZ, _muldiv(op, Y), \"e\", rhs_target\n                    )\n            dX = _reduce_grad(dX, X.shape)\n        else:\n            dX = nd.zeros_like(X)\n        if op != \"copy_lhs\":\n            if self.rhs_target in [\"u\", \"v\"]:\n                _gidx = gidx if rhs_target == \"v\" else gidx.reverse()\n                if op in [\"add\", \"sub\", \"copy_rhs\"]:\n                    dY = _gspmm(\n                        _gidx, \"copy_rhs\", \"sum\", None, _addsub(op, dZ)\n                    )[0]\n                else:  # mul, div, dot\n                    if lhs_target == rhs_target:\n                        dY = _gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ)[0] * X\n                    elif self.lhs_target == \"e\":\n                        dY = _gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ * X)[0]\n                    else:  # rhs_target = !lhs_target\n                        dY = _gspmm(_gidx, \"mul\", \"sum\", X, dZ)[0]\n                    if op == \"div\":\n                        dY = -dY / (Y**2)\n            else:\n                if op in [\"add\", \"sub\", \"copy_rhs\"]:\n                    dY = _addsub(op, dZ)\n                else:  # mul, div, dot\n                    dY = _gsddmm(gidx, \"mul\", dZ, X, \"e\", lhs_target)\n                    if op == \"div\":\n                        dY = -dY / (Y**2)\n            dY = _reduce_grad(dY, Y.shape)\n        else:\n            dY = nd.zeros_like(Y)\n        self.saved_tensors = None\n        return dX, dY\n\n\ndef gsddmm(gidx, op, lhs_data, rhs_data, lhs_target=\"u\", rhs_target=\"v\"):\n    func = GSDDMM(gidx, op, lhs_target, rhs_target)\n    ctx = to_backend_ctx(gidx.ctx)\n    if lhs_data is None:\n        lhs_data = nd.zeros((1,), ctx=ctx)\n    if rhs_data is None:\n        rhs_data = nd.zeros((1,), ctx=ctx)\n    return func(lhs_data, rhs_data)\n\n\nclass EdgeSoftmax(mx.autograd.Function):\n    def 
__init__(self, gidx, eids, norm_by):\n        super(EdgeSoftmax, self).__init__()\n        if not is_all(eids):\n            gidx = gidx.edge_subgraph([eids], True).graph\n        if norm_by == \"src\":\n            gidx = gidx.reverse()\n        self.gidx = gidx\n\n    def forward(self, score):\n        \"\"\"Forward function.\n\n        Pseudo-code:\n\n        .. code:: python\n\n            score = dgl.EData(g, score)\n            score_max = score.dst_max()  # of type dgl.NData\n            score = score - score_max  # edge_sub_dst, ret dgl.EData\n            score_sum = score.dst_sum()  # of type dgl.NData\n            out = score / score_sum    # edge_div_dst, ret dgl.EData\n            return out.data\n        \"\"\"\n        gidx = self.gidx\n        score_max = _gspmm(gidx, \"copy_rhs\", \"max\", None, score)[0]\n        score = mx.nd.exp(_gsddmm(gidx, \"sub\", score, score_max, \"e\", \"v\"))\n        score_sum = _gspmm(gidx, \"copy_rhs\", \"sum\", None, score)[0]\n        out = _gsddmm(gidx, \"div\", score, score_sum, \"e\", \"v\")\n        self.save_for_backward(out)\n        return out\n\n    def backward(self, grad_out):\n        \"\"\"Backward function.\n\n        Pseudo-code:\n\n        .. 
code:: python\n\n            g, out = ctx.backward_cache\n            grad_out = dgl.EData(g, grad_out)\n            out = dgl.EData(g, out)\n            sds = out * grad_out  # type dgl.EData\n            sds_sum = sds.dst_sum()  # type dgl.NData\n            grad_score = sds - sds * sds_sum  # multiple expressions\n        \"\"\"\n        (out,) = self.saved_tensors\n        gidx = self.gidx\n        sds = out * grad_out\n        accum = gspmm(gidx, \"copy_rhs\", \"sum\", None, sds)\n        grad_score = sds - gsddmm(gidx, \"mul\", out, accum, \"e\", \"v\")\n        self.save_tensors = None\n        return grad_score\n\n\ndef edge_softmax(gidx, logits, eids=ALL, norm_by=\"dst\"):\n    softmax_op = EdgeSoftmax(gidx, eids, norm_by)\n    return softmax_op(logits)\n\n\nclass SegmentReduce(mx.autograd.Function):\n    def __init__(self, op, offsets):\n        super(SegmentReduce, self).__init__()\n        self.op = op\n        self.offsets = offsets\n\n    def forward(self, x):\n        y, arg = _segment_reduce(self.op, x, self.offsets)\n        self.save_for_backward(arg)\n        return y\n\n    def backward(self, dy):\n        (arg,) = self.saved_tensors\n        offsets = self.offsets\n        m = offsets[-1].asscalar()\n        if self.op == \"sum\":\n            offsets_np = asnumpy(offsets[1:])\n            indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype)\n            np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))\n            indices_np = np.cumsum(indices_np, -1)[:-1]\n            indices = zerocopy_from_numpy(indices_np)\n            dx = dy[indices]\n        else:\n            dx = _bwd_segment_cmp(dy, arg, m)\n        return dx\n\n\ndef segment_reduce(op, x, offsets):\n    segment_reduce_op = SegmentReduce(op, offsets)\n    return segment_reduce_op(x)\n\n\nclass ScatterAdd(mx.autograd.Function):\n    def __init__(self, idx, m):\n        super(ScatterAdd, self).__init__()\n        self.idx = idx\n        self.m = m\n\n    def 
forward(self, x):\n        y = _scatter_add(x, self.idx, self.m)\n        return y\n\n    def backward(self, dy):\n        return dy[self.idx]\n\n\ndef scatter_add(x, idx, m):\n    scatter_add_op = ScatterAdd(idx, m)\n    return scatter_add_op(x)\n\n\nclass CSRMM(mx.autograd.Function):\n    def __init__(self, gidxA, gidxB, num_vtypes):\n        super().__init__()\n        self.gidxA = gidxA\n        self.gidxB = gidxB\n        self.num_vtypes = num_vtypes\n\n    def forward(self, A_weights, B_weights):\n        gidxC, C_weights = _csrmm(\n            self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes\n        )\n        (\n            nrows,\n            ncols,\n            C_indptr,\n            C_indices,\n            C_eids,\n        ) = gidxC.adjacency_matrix_tensors(0, False, \"csr\")\n        # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same\n        # as the underlying tensors of the created graph gidxC.\n        self.backward_cache = gidxC\n        self.save_for_backward(A_weights, B_weights)\n        nrows = nd.array([nrows], dtype=\"int64\")\n        ncols = nd.array([ncols], dtype=\"int64\")\n        return nrows, ncols, C_indptr, C_indices, C_eids, C_weights\n\n    def backward(\n        self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights\n    ):\n        # Only the last argument is meaningful.\n        gidxC = self.backward_cache\n        A_weights, B_weights = self.saved_tensors\n        dgidxA, dA_weights = _csrmm(\n            gidxC,\n            dC_weights,\n            self.gidxB.reverse(),\n            B_weights,\n            self.gidxA.number_of_ntypes(),\n        )\n        dgidxB, dB_weights = _csrmm(\n            self.gidxA.reverse(),\n            A_weights,\n            gidxC,\n            dC_weights,\n            self.gidxB.number_of_ntypes(),\n        )\n        dA_weights = _csrmask(dgidxA, dA_weights, self.gidxA)\n        dB_weights = _csrmask(dgidxB, dB_weights, self.gidxB)\n        
return dA_weights, dB_weights\n\n\ndef csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):\n    op = CSRMM(gidxA, gidxB, num_vtypes)\n    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(\n        A_weights, B_weights\n    )\n    gidxC = create_unitgraph_from_csr(\n        num_vtypes,\n        nrows.asscalar(),\n        ncols.asscalar(),\n        C_indptr,\n        C_indices,\n        C_eids,\n        [\"coo\", \"csr\", \"csc\"],\n    )\n    return gidxC, C_weights\n\n\nclass CSRSum(mx.autograd.Function):\n    def __init__(self, gidxs):\n        super().__init__()\n        self.gidxs = gidxs\n\n    def forward(self, *weights):\n        gidxC, C_weights = _csrsum(self.gidxs, weights)\n        (\n            nrows,\n            ncols,\n            C_indptr,\n            C_indices,\n            C_eids,\n        ) = gidxC.adjacency_matrix_tensors(0, False, \"csr\")\n        # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same\n        # as the underlying tensors of the created graph gidxC.\n        self.backward_cache = gidxC\n        nrows = nd.array([nrows], dtype=\"int64\")\n        ncols = nd.array([ncols], dtype=\"int64\")\n        return nrows, ncols, C_indptr, C_indices, C_eids, C_weights\n\n    def backward(\n        self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights\n    ):\n        # Only the last argument is meaningful.\n        gidxC = self.backward_cache\n        return tuple(csrmask(gidxC, dC_weights, gidx) for gidx in self.gidxs)\n\n\ndef csrsum(gidxs, weights):\n    op = CSRSum(gidxs)\n    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(*weights)\n    num_vtypes = gidxs[0].number_of_ntypes()\n    gidxC = create_unitgraph_from_csr(\n        num_vtypes,\n        nrows.asscalar(),\n        ncols.asscalar(),\n        C_indptr,\n        C_indices,\n        C_eids,\n        [\"coo\", \"csr\", \"csc\"],\n    )\n    return gidxC, C_weights\n\n\nclass CSRMask(mx.autograd.Function):\n    def 
__init__(self, gidxA, gidxB):\n        super().__init__()\n        self.gidxA = gidxA\n        self.gidxB = gidxB\n\n    def forward(self, A_weights):\n        return _csrmask(self.gidxA, A_weights, self.gidxB)\n\n    def backward(self, dB_weights):\n        return _csrmask(self.gidxB, dB_weights, self.gidxA)\n\n\ndef csrmask(gidxA, A_weights, gidxB):\n    op = CSRMask(gidxA, gidxB)\n    return op(A_weights)\n"
  },
  {
    "path": "python/dgl/backend/mxnet/sparse_optim.py",
    "content": "\"\"\"Sparse optimizer is not supported for mxnet\"\"\"\n"
  },
  {
    "path": "python/dgl/backend/mxnet/tensor.py",
    "content": "from __future__ import absolute_import\n\nimport builtins\nimport numbers\nimport os\n\nimport mxnet as mx\nimport mxnet.ndarray as nd\nimport numpy as np\n\nfrom ... import ndarray as dglnd\nfrom ...function.base import TargetCode\nfrom ...utils import version\n\nif version.parse(mx.__version__) < version.parse(\"1.6.0\"):\n    raise RuntimeError(\"DGL requires MXNet >= 1.6\")\n\n# After MXNet 1.5, empty tensors aren't supprted by default.\n# After we turn on the numpy compatible flag, MXNet supports empty NDArray.\nmx.set_np_shape(bool(os.environ.get(\"DGL_MXNET_SET_NP_SHAPE\", True)))\n\n\ndef data_type_dict():\n    return {\n        \"float16\": np.float16,\n        \"float32\": np.float32,\n        \"float64\": np.float64,\n        \"uint8\": np.uint8,\n        \"int8\": np.int8,\n        \"int16\": np.int16,\n        \"int32\": np.int32,\n        \"int64\": np.int64,\n        \"bool\": np.bool_,\n    }  # mxnet does not support bool\n\n\ndef cpu():\n    return mx.cpu()\n\n\ndef tensor(data, dtype=None):\n    if dtype == np.bool_:\n        # mxnet doesn't support bool\n        dtype = np.int32\n    if isinstance(data, nd.NDArray):\n        if dtype is None or data.dtype == dtype:\n            return data\n        else:\n            return data.astype(dtype)\n    else:\n        if isinstance(data, numbers.Number):\n            data = [data]\n        if dtype is None:\n            if isinstance(data, np.ndarray):\n                dtype = np.int32 if data.dtype == np.bool_ else data.dtype\n            elif len(data) == 0:\n                dtype = np.int64\n            else:\n                dtype = (\n                    np.int64\n                    if isinstance(data[0], numbers.Integral)\n                    else np.float32\n                )\n        return nd.array(data, dtype=dtype)\n\n\ndef as_scalar(data):\n    if data.size != 1:\n        raise ValueError(\"The current array is not a scalar\")\n    if data.shape != (1,):\n        data = 
data.expand_dims(axis=0)\n    return data.asscalar()\n\n\ndef get_preferred_sparse_format():\n    \"\"\"Get the preferred sparse matrix format supported by the backend.\n\n    Different backends have their preferred backend. This info is useful when\n    constructing a sparse matrix.\n    \"\"\"\n    return \"csr\"\n\n\ndef sparse_matrix(data, index, shape, force_format=False):\n    fmt = index[0]\n    if fmt == \"coo\":\n        if force_format:\n            raise TypeError(\n                \"MXNet backend only supports CSR format,\"\n                \" but COO format is forced.\"\n            )\n        coord = index[1]\n        # generate convert idx\n        # FIXME: cannot use int64\n        tmp_data = nd.arange(\n            len(coord[0]), dtype=data.dtype, ctx=coord[0].context\n        )\n        tmp_spmat = nd.sparse.csr_matrix(\n            (tmp_data, (coord[0], coord[1])), tuple(shape), ctx=data.context\n        )\n        convert_idx = nd.cast(tmp_spmat.data, dtype=\"int64\")\n        # shuffle the data\n        data = data[convert_idx]\n        spmat = nd.sparse.csr_matrix(\n            (data, tmp_spmat.indices, tmp_spmat.indptr),\n            tuple(shape),\n            ctx=data.context,\n        )\n        return spmat, convert_idx\n    elif fmt == \"csr\":\n        indices = index[1]\n        indptr = index[2]\n        spmat = nd.sparse.csr_matrix(\n            (data, indices, indptr), tuple(shape), ctx=data.context\n        )\n        # No conversion is required.\n        return spmat, None\n    else:\n        raise TypeError(\"Invalid format: %s.\" % fmt)\n\n\ndef sparse_matrix_indices(spmat):\n    return (\"csr\", spmat.indices, spmat.indptr)\n\n\ndef is_tensor(obj):\n    return isinstance(obj, nd.NDArray)\n\n\ndef shape(input):\n    # NOTE: the input cannot be a symbol\n    return input.shape\n\n\ndef dtype(input):\n    # NOTE: the input cannot be a symbol\n    return input.dtype\n\n\ndef ndim(input):\n    return input.ndim\n\n\ndef 
context(input):\n    return input.context\n\n\ndef device_type(ctx):\n    return ctx.device_type\n\n\ndef device_id(ctx):\n    return ctx.device_id\n\n\ndef to_backend_ctx(dglctx):\n    dev_type = dglctx.device_type\n    if dev_type == 1:\n        return mx.cpu()\n    elif dev_type == 2:\n        return mx.gpu(dglctx.device_id)\n    else:\n        raise ValueError(\"Unsupported DGL device context:\", dglctx)\n\n\ndef astype(input, ty):\n    if ty == np.bool_:\n        ty = np.int32\n    return input.astype(ty)\n\n\ndef asnumpy(input):\n    return input.asnumpy()\n\n\ndef copy_to(input, ctx, **kwargs):\n    return input.as_in_context(ctx)\n\n\ndef is_pinned(input):\n    return input.context == mx.cpu_pinned()\n\n\ndef sum(input, dim, keepdims=False):\n    if len(input) == 0:\n        return nd.array([0.0], dtype=input.dtype, ctx=input.context)\n    return nd.sum(input, axis=dim, keepdims=keepdims)\n\n\ndef floor_div(in1, in2):\n    return in1 / in2\n\n\ndef reduce_sum(input):\n    return input.sum()\n\n\ndef cumsum(input, dim):\n    return nd.cumsum(input, axis=dim)\n\n\ndef mean(input, dim):\n    return nd.mean(input, axis=dim)\n\n\ndef reduce_mean(input):\n    return input.mean()\n\n\ndef max(input, dim):\n    return nd.max(input, axis=dim)\n\n\ndef reduce_max(input):\n    return input.max()\n\n\ndef min(input, dim):\n    return nd.min(input, axis=dim)\n\n\ndef reduce_min(input):\n    return input.min()\n\n\ndef topk(input, k, dim, descending=True):\n    return nd.topk(\n        input, axis=dim, k=k, ret_typ=\"value\", is_ascend=not descending\n    )\n\n\ndef argtopk(input, k, dim, descending=True):\n    idx = nd.argsort(input, dim, is_ascend=not descending)\n    return nd.slice_axis(input, dim, 0, k)\n\n\ndef argsort(input, dim, descending):\n    idx = nd.argsort(input, dim, is_ascend=not descending)\n    idx = nd.cast(idx, dtype=\"int64\")\n    return idx\n\n\ndef exp(input):\n    return nd.exp(input)\n\n\ndef inverse(input):\n    return 
nd.linalg_inverse(input)\n\n\ndef sqrt(input):\n    return nd.sqrt(input)\n\n\ndef softmax(input, dim=-1):\n    return nd.softmax(input, axis=dim)\n\n\ndef cat(seq, dim):\n    return nd.concat(*seq, dim=dim)\n\n\ndef stack(seq, dim):\n    return nd.stack(*seq, axis=dim)\n\n\ndef split(x, sizes_or_sections, dim):\n    if isinstance(sizes_or_sections, list) and len(sizes_or_sections) == 1:\n        assert len(x) == sizes_or_sections[0]\n        return [x]\n\n    if isinstance(sizes_or_sections, (np.ndarray, list)):\n        sizes_or_sections1 = tuple(np.cumsum(sizes_or_sections)[:-1])\n    return nd.split_v2(x, sizes_or_sections1, axis=dim)\n\n\ndef repeat(input, repeats, dim):\n    if isinstance(repeats, nd.NDArray):\n        return nd.array(\n            np.repeat(input.asnumpy(), repeats.asnumpy(), axis=dim),\n            ctx=input.context,\n            dtype=input.dtype,\n        )\n    else:\n        return nd.repeat(input, repeats, axis=dim)\n\n\ndef gather_row(data, row_index):\n    # MXNet workaround for empty row index\n    if len(row_index) == 0:\n        if data.shape[0] == 0:\n            return data\n        else:\n            return data[0:0]\n\n    if isinstance(row_index, nd.NDArray):\n        return nd.take(data, row_index)\n    else:\n        return data[\n            row_index,\n        ]\n\n\ndef slice_axis(data, axis, begin, end):\n    dim = data.shape[axis]\n    if begin < 0:\n        begin += dim\n    if end <= 0:\n        end += dim\n    return nd.slice_axis(data, axis, begin, end)\n\n\ndef take(data, indices, dim):\n    return nd.take(data, indices, dim)\n\n\ndef narrow_row(data, start, stop):\n    return data[start:stop]\n\n\ndef index_add_inplace(data, row_idx, value):\n    raise NotImplementedError(\"MXNet doesn't support inplace index_add\")\n\n\ndef scatter_row(data, row_index, value):\n    return mx.nd.contrib.index_copy(data, row_index, value)\n\n\ndef scatter_row_inplace(data, row_index, value):\n    data[row_index] = value\n\n\ndef 
squeeze(input, dim):\n    return nd.squeeze(input, axis=dim)\n\n\ndef unsqueeze(input, dim):\n    return nd.expand_dims(input, axis=dim)\n\n\ndef reshape(input, shape):\n    # NOTE: the input cannot be a symbol\n    return nd.reshape(input, shape)\n\n\ndef swapaxes(input, axis1, axis2):\n    return nd.swapaxes(input, axis1, axis2)\n\n\ndef empty(shape, dtype, ctx):\n    return nd.empty(shape, dtype=dtype, ctx=ctx)\n\n\ndef zeros(shape, dtype, ctx):\n    return nd.zeros(shape, dtype=dtype, ctx=ctx)\n\n\ndef zeros_like(input):\n    return nd.zeros_like(input)\n\n\ndef ones(shape, dtype, ctx):\n    return nd.ones(shape, dtype=dtype, ctx=ctx)\n\n\ndef uniform(shape, dtype, ctx, low, high):\n    return nd.random.uniform(low, high, ctx=ctx, dtype=dtype, shape=shape)\n\n\ndef randint(shape, dtype, ctx, low, high):\n    return nd.random.randint(low, high, ctx=ctx, dtype=dtype, shape=shape)\n\n\ndef pad_packed_tensor(input, lengths, value, l_min=None):\n    old_shape = input.shape\n    if isinstance(lengths, nd.NDArray):\n        lengths = list(lengths.asnumpy())\n    max_len = builtins.max(lengths)\n\n    if l_min is not None:\n        max_len = builtins.max(max_len, l_min)\n\n    batch_size = len(lengths)\n    ctx = input.context\n    dtype = input.dtype\n    x = nd.full(\n        (batch_size * max_len, *old_shape[1:]), value, ctx=ctx, dtype=dtype\n    )\n    index = []\n    for i, l in enumerate(lengths):\n        index.extend(range(i * max_len, i * max_len + l))\n    index = nd.array(index, ctx=ctx)\n    return scatter_row(x, index, input).reshape(\n        batch_size, max_len, *old_shape[1:]\n    )\n\n\ndef pack_padded_tensor(input, lengths):\n    batch_size, max_len = input.shape[:2]\n    ctx = input.context\n    index = []\n    for i, l in enumerate(lengths):\n        index.extend(range(i * max_len, i * max_len + l))\n    index = nd.array(index, ctx=ctx)\n    return gather_row(input.reshape(batch_size * max_len, -1), index)\n\n\ndef boolean_mask(input, mask):\n    
return mx.contrib.nd.boolean_mask(input, mask)\n\n\ndef equal(x, y):\n    return x == y\n\n\ndef allclose(x, y, rtol=1e-4, atol=1e-4):\n    return np.allclose(x.asnumpy(), y.asnumpy(), rtol=rtol, atol=atol)\n\n\ndef logical_not(input):\n    return nd.logical_not(input)\n\n\ndef logical_and(input1, input2):\n    return nd.logical_and(input1, input2)\n\n\ndef clone(input):\n    return input.copy()\n\n\ndef clamp(data, min_val, max_val):\n    return nd.clip(data, min_val, max_val)\n\n\ndef replace_inf_with_zero(x):\n    return nd.where(nd.abs(x) == np.inf, nd.zeros_like(x), x)\n\n\ndef count_nonzero(input):\n    # TODO: fallback to numpy is unfortunate\n    tmp = input.asnumpy()\n    return np.count_nonzero(tmp)\n\n\ndef unique(input, return_inverse=False, return_counts=False):\n    # TODO: fallback to numpy is unfortunate\n    tmp = input.asnumpy()\n    if return_inverse and return_counts:\n        tmp, inv, count = np.unique(\n            tmp, return_inverse=True, return_counts=True\n        )\n        tmp = nd.array(tmp, ctx=input.context, dtype=input.dtype)\n        inv = nd.array(inv, ctx=input.context)\n        count = nd.array(count, ctx=input.context)\n        return tmp, inv, count\n    elif return_inverse or return_counts:\n        tmp, tmp2 = np.unique(\n            tmp, return_inverse=return_inverse, return_counts=return_counts\n        )\n        tmp = nd.array(tmp, ctx=input.context, dtype=input.dtype)\n        tmp2 = nd.array(tmp2, ctx=input.context)\n        return tmp, tmp2\n    else:\n        tmp = np.unique(tmp)\n        return nd.array(tmp, ctx=input.context, dtype=input.dtype)\n\n\ndef full_1d(length, fill_value, dtype, ctx):\n    return nd.full((length,), fill_value, dtype=dtype, ctx=ctx)\n\n\ndef nonzero_1d(input):\n    # TODO: fallback to numpy is unfortunate\n    tmp = input.asnumpy()\n    tmp = np.nonzero(tmp)[0]\n    r = nd.array(tmp, ctx=input.context, dtype=tmp.dtype)\n    return r\n\n\ndef sort_1d(input):\n    # TODO: this isn't an ideal 
implementation.\n    val = nd.sort(input, axis=None, is_ascend=True)\n    idx = nd.argsort(input, is_ascend=True)\n    idx = nd.cast(idx, dtype=\"int64\")\n    return val, idx\n\n\ndef arange(start, stop, dtype=np.int64, ctx=None):\n    if start >= stop:\n        return nd.array([], dtype=dtype, ctx=ctx)\n    else:\n        return nd.arange(start, stop, dtype=dtype, ctx=ctx)\n\n\ndef rand_shuffle(arr):\n    return mx.nd.random.shuffle(arr)\n\n\ndef zerocopy_to_dlpack(arr):\n    return arr.to_dlpack_for_read()\n\n\ndef zerocopy_from_dlpack(dlpack_arr):\n    return nd.from_dlpack(dlpack_arr)\n\n\ndef zerocopy_to_numpy(arr):\n    # NOTE: not zerocopy\n    return arr.asnumpy()\n\n\ndef zerocopy_from_numpy(np_data):\n    np_data = np.asarray(np_data, order=\"C\")\n    return mx.nd.from_numpy(np_data, zero_copy=True)\n\n\ndef zerocopy_to_dgl_ndarray(arr):\n    arr.to_dlpack_for_read()\n    return dglnd.from_dlpack(arr.to_dlpack_for_read())\n\n\ndef zerocopy_to_dgl_ndarray_for_write(arr):\n    return dglnd.from_dlpack(arr.to_dlpack_for_write())\n\n\ndef zerocopy_from_dgl_ndarray(arr):\n    return nd.from_dlpack(arr.to_dlpack())\n\n\ndef sync():\n    \"\"\"Synchronize computation.\n\n    In DL frameworks such as MXNet and TensorFlow, the computation in operators\n    are done asynchronously. This is to synchronize computation and makes sure\n    that all computation is complete after this function call.\n    \"\"\"\n    mx.nd.waitall()\n\n\ndef attach_grad(tensor):\n    tensor.attach_grad()\n    return tensor\n\n\ndef backward(x, head_gradient=None):\n    x.backward(head_gradient)\n\n\ndef grad(x):\n    return x.grad\n\n\ndef is_no_grad(x):\n    return (x != 0).sum() == 0\n\n\ndef is_recording():\n    return mx.autograd.is_recording()\n\n\nrecord_grad = mx.autograd.record\n\n\nclass no_grad(object):\n    def __init__(self):\n        pass\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        pass\n"
  },
  {
    "path": "python/dgl/backend/pytorch/__init__.py",
    "content": "from .sparse import *\nfrom .tensor import *\n"
  },
  {
    "path": "python/dgl/backend/pytorch/sparse.py",
    "content": "import torch as th\n\nfrom ..._sparse_ops import (\n    _bwd_segment_cmp,\n    _csrmask,\n    _csrmm,\n    _csrsum,\n    _edge_softmax_backward,\n    _edge_softmax_forward,\n    _gather_mm,\n    _gather_mm_scatter,\n    _gsddmm,\n    _gsddmm_hetero,\n    _gspmm,\n    _gspmm_hetero,\n    _scatter_add,\n    _segment_mm,\n    _segment_mm_backward_B,\n    _segment_reduce,\n    _update_grad_minmax_hetero,\n)\n\nfrom ...base import ALL, is_all\nfrom ...heterograph_index import create_unitgraph_from_csr\n\n__all__ = [\n    \"gspmm\",\n    \"gsddmm\",\n    \"gspmm_hetero\",\n    \"gsddmm_hetero\",\n    \"edge_softmax\",\n    \"edge_softmax_hetero\",\n    \"segment_reduce\",\n    \"scatter_add\",\n    \"csrmm\",\n    \"csrsum\",\n    \"csrmask\",\n    \"gather_mm\",\n    \"segment_mm\",\n]\n\n\ndef _reduce_grad(grad, shape):\n    \"\"\"Reduce gradient on the broadcast dimension\n    If there is broadcast in forward pass, gradients need to be reduced on\n    broadcast dimension. This function checks the input tensor shape and\n    gradient shape and perform the reduction.\n\n    Parameters\n    ----------\n    grad: Tensor\n        Gradient tensor\n    shape: tuple\n        Shape of input tensor\n\n    Returns\n    -------\n    Tensor\n    \"\"\"\n    grad_shape = grad.shape[1:]\n    in_shape = shape[1:]\n    if in_shape == grad_shape:\n        # no need to reduce\n        return grad\n    num_to_squeeze = len(grad_shape) - len(in_shape)\n    # pad inshape\n    in_shape = (1,) * num_to_squeeze + in_shape\n    reduce_idx = th.nonzero(\n        th.tensor(grad_shape) - th.tensor(in_shape), as_tuple=False\n    )\n    reduce_idx += 1  # skip batch dim\n    if len(reduce_idx) > 0:\n        grad = grad.sum(dim=tuple(reduce_idx), keepdim=True)\n    return grad.view(-1, *shape[1:])\n\n\ndef _need_reduce_last_dim(ufeat, efeat):\n    \"\"\"Indicates whether to reduce the last dimension on edges\n    in the backward pass of spmm,\n    if so, use dot instead of 
mul.\"\"\"\n    if ufeat is None or efeat is None:\n        return False\n    ushp = ufeat.shape\n    eshp = efeat.shape\n    return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1\n\n\ndef _expand(x, shape):\n    return x.expand(-1, *shape)\n\n\ndef spmm_cache_X(binary_op, reduce_op, req_grad_X, req_grad_Y):\n    \"\"\"Rules to identify whether to cache X in SpMM forward stage.\"\"\"\n    if binary_op != \"copy_lhs\" and req_grad_Y:\n        if reduce_op == \"sum\":\n            return True\n        else:\n            if binary_op == \"mul\":\n                return True\n    return False\n\n\ndef spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):\n    \"\"\"Rules to identify whether to cache Y in SpMM forward stage.\"\"\"\n    if binary_op != \"copy_rhs\" and req_grad_X:\n        if reduce_op == \"sum\":\n            if binary_op in [\"mul\", \"add\"]:\n                return True\n        else:\n            if binary_op == \"mul\":\n                return True\n    return False\n\n\ndef spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):\n    \"\"\"Rules to identify whether to cache argX in SpMM forward stage.\"\"\"\n    if req_grad_X or req_grad_Y:\n        if reduce_op in [\"min\", \"max\"]:\n            return True\n    return False\n\n\ndef spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):\n    \"\"\"Rules to identify whether to cache argY in SpMM forward stage.\"\"\"\n    if req_grad_X or req_grad_Y:\n        if reduce_op in [\"min\", \"max\"]:\n            return True\n    return False\n\n\nclass empty_context:\n    \"\"\"Empty context that does nothing\"\"\"\n\n    def __init__(self, *args, **kargs):\n        return\n\n    def __enter__(self, *args, **kargs):\n        return self\n\n    def __exit__(self, *args, **kargs):\n        return\n\n\n# Disable CUDA autocast since we have casted args manually,\n# and do it only in a nested autocast context.\ndef _disable_autocast_if_enabled():\n    if 
th.is_autocast_enabled():\n        return th.cuda.amp.autocast(enabled=False)\n    else:\n        return empty_context()\n\n\ndef _cast_if_autocast_enabled(*args):\n    if not th.is_autocast_enabled():\n        return args\n    else:\n        return th.cuda.amp.autocast_mode._cast(\n            args, th.get_autocast_gpu_dtype()\n        )\n\n\nclass GSpMM(th.autograd.Function):\n    @staticmethod\n    def forward(ctx, gidx, op, reduce_op, X, Y):\n        out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)\n        reduce_last = _need_reduce_last_dim(X, Y)\n        X_shape = X.shape if X is not None else None\n        Y_shape = Y.shape if Y is not None else None\n        dtype = X.dtype if X is not None else Y.dtype\n        device = X.device if X is not None else Y.device\n        ctx.backward_cache = (\n            gidx,\n            op,\n            reduce_op,\n            X_shape,\n            Y_shape,\n            dtype,\n            device,\n            reduce_last,\n        )\n        req_grad_X = X.requires_grad if X is not None else False\n        req_grad_Y = Y.requires_grad if Y is not None else False\n        if not spmm_cache_X(op, reduce_op, req_grad_X, req_grad_Y):\n            X = None\n        if not spmm_cache_Y(op, reduce_op, req_grad_X, req_grad_Y):\n            Y = None\n        if not spmm_cache_argX(op, reduce_op, req_grad_X, req_grad_Y):\n            argX = None\n        if not spmm_cache_argY(op, reduce_op, req_grad_X, req_grad_Y):\n            argY = None\n        ctx.save_for_backward(X, Y, argX, argY)\n        return out\n\n    @staticmethod\n    def backward(ctx, dZ):\n        (\n            gidx,\n            op,\n            reduce_op,\n            X_shape,\n            Y_shape,\n            dtype,\n            device,\n            reduce_last,\n        ) = ctx.backward_cache\n        X, Y, argX, argY = ctx.saved_tensors\n        if op != \"copy_rhs\" and ctx.needs_input_grad[3]:\n            g_rev = gidx.reverse()\n            if 
reduce_op == \"sum\":\n                if op == \"mul\":\n                    dX = gspmm(g_rev, \"mul\", \"sum\", dZ, Y)\n                elif op == \"add\":\n                    dX = gspmm(g_rev, \"copy_lhs\", \"sum\", dZ, Y)\n                elif op == \"copy_lhs\":\n                    dX = gspmm(g_rev, \"copy_lhs\", \"sum\", dZ, None)\n            else:  # max/min\n                dX = th.zeros(\n                    (X_shape[0],) + dZ.shape[1:], dtype=dtype, device=device\n                )\n                if op == \"mul\":\n                    grad = _expand(Y, dZ.shape[1:]).gather(0, argY.long()) * dZ\n                    dX.scatter_add_(0, argX.long(), grad)\n                elif op in [\"add\", \"copy_lhs\"]:\n                    dX.scatter_add_(0, argX.long(), dZ)\n            dX = _reduce_grad(dX, X_shape)\n        else:  # X has not gradient\n            dX = None\n        if op != \"copy_lhs\" and ctx.needs_input_grad[4]:\n            if reduce_op == \"sum\":\n                if op == \"mul\" and reduce_last:\n                    dY = gsddmm(gidx, \"dot\", X, dZ)\n                elif op == \"mul\":\n                    dY = gsddmm(gidx, \"mul\", X, dZ)\n                elif op in [\"add\", \"copy_rhs\"]:\n                    dY = gsddmm(gidx, \"copy_rhs\", X, dZ)\n            else:  # max/min\n                dY = th.zeros(\n                    (Y_shape[0],) + dZ.shape[1:], dtype=dtype, device=device\n                )\n                if op == \"mul\":\n                    grad = _expand(X, dZ.shape[1:]).gather(0, argX.long()) * dZ\n                    dY.scatter_add_(0, argY.long(), grad)\n                elif op in [\"add\", \"copy_rhs\"]:\n                    dY.scatter_add_(0, argY.long(), dZ)\n            dY = _reduce_grad(dY, Y_shape)\n        else:  # Y has no gradient\n            dY = None\n        return None, None, None, dX, dY\n\n\nclass GSpMM_hetero(th.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx, gidx, op, 
reduce_op, X_len, *feats\n    ):  # feats = lhs_data + rhs_data\n        out, (argX, argY, argX_ntype, argY_etype) = _gspmm_hetero(\n            gidx, op, reduce_op, X_len, feats\n        )\n        X, Y = feats[:X_len], feats[X_len:]\n        # TODO (Israt): check target to decide src_id/dst_id?\n        src_id, dst_id = gidx.metagraph.find_edge(0)\n        reduce_last = _need_reduce_last_dim(X[src_id], Y[dst_id])\n        X_shape = tuple(\n            [X[i].shape if X[i] is not None else None for i in range(X_len)]\n        )\n        Y_shape = tuple(\n            [Y[i].shape if Y[i] is not None else None for i in range(len(Y))]\n        )\n        dtype = X[src_id].dtype if X[src_id] is not None else Y[dst_id].dtype\n        device = X[src_id].device if X[src_id] is not None else Y[dst_id].device\n        ctx.backward_cache = (\n            gidx,\n            op,\n            reduce_op,\n            X_shape,\n            Y_shape,\n            dtype,\n            device,\n            reduce_last,\n            X_len,\n        )\n        req_grad_X = tuple(\n            [\n                X[i].requires_grad if X[i] is not None else False\n                for i in range(X_len)\n            ]\n        )\n        req_grad_Y = tuple(\n            [\n                Y[i].requires_grad if Y[i] is not None else False\n                for i in range(len(Y))\n            ]\n        )\n\n        # checking the first relation to decide for all the relations\n        if not spmm_cache_argX(\n            op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]\n        ):\n            argX = tuple([None] * len(X))\n        if not spmm_cache_argY(\n            op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]\n        ):\n            argY = tuple([None] * len(X))\n\n        ctx.save_for_backward(*feats, *argX, *argX_ntype, *argY, *argY_etype)\n        return out\n\n    @staticmethod\n    def backward(ctx, *dZ):\n        (\n            gidx,\n            op,\n            
reduce_op,\n            X_shape,\n            Y_shape,\n            dtype,\n            device,\n            reduce_last,\n            X_len,\n        ) = ctx.backward_cache\n        num_ntypes = gidx.number_of_ntypes()\n        feats = ctx.saved_tensors[: -(4 * num_ntypes)]\n        argX = ctx.saved_tensors[-(4 * num_ntypes) : -(3 * num_ntypes)]\n        argX_ntype = ctx.saved_tensors[-(3 * num_ntypes) : -(2 * num_ntypes)]\n        argY = ctx.saved_tensors[-(2 * num_ntypes) : -num_ntypes]\n        argY_etype = ctx.saved_tensors[-num_ntypes:]\n        X, Y = feats[:X_len], feats[X_len:]\n\n        if op != \"copy_rhs\" and any([x is not None for x in X]):\n            g_rev = gidx.reverse()\n            if reduce_op == \"sum\":\n                if op == \"mul\":\n                    dX = gspmm_hetero(\n                        g_rev, \"mul\", \"sum\", len(X), *tuple(dZ + Y)\n                    )\n                elif op == \"add\":\n                    dX = gspmm_hetero(\n                        g_rev, \"copy_lhs\", \"sum\", len(X), *tuple(dZ + Y)\n                    )\n                elif op == \"copy_lhs\":\n                    tpl_None = tuple([None] * len(Y))\n                    dX = gspmm_hetero(\n                        g_rev, \"copy_lhs\", \"sum\", len(X), *tuple(dZ + tpl_None)\n                    )\n            else:  # max/min\n                # Assuming that the features are of the same dimension (enforced by the forward function)\n                src_id, dst_id = gidx.metagraph.find_edge(0)\n                dX = tuple(\n                    [\n                        th.zeros(\n                            (X_shape[i][0],) + dZ[dst_id].shape[1:],\n                            dtype=dtype,\n                            device=device,\n                        )\n                        if X[i] is not None\n                        else None\n                        for i in range(len(X))\n                    ]\n                )\n                if op == 
\"mul\":\n                    grad = _expand(Y, dZ.shape[1:]).gather(0, argY.long()) * dZ\n                    dX.scatter_add_(0, argX.long(), grad)\n                elif op in [\"add\", \"copy_lhs\"]:\n                    dX = _update_grad_minmax_hetero(\n                        g_rev, op, dZ, argX, argX_ntype, dX\n                    )\n            dX = tuple(\n                [\n                    _reduce_grad(dX[i], X_shape[i])\n                    if X[i] is not None\n                    else None\n                    for i in range(len(X))\n                ]\n            )\n        else:  # X has not gradient\n            dX = tuple([None] * len(X))\n        if op != \"copy_lhs\" and any([y is not None for y in Y]):\n            # TODO(Israt): implement other combinations of reduce functions\n            if reduce_op == \"sum\":\n                tpl_dZ = tuple(\n                    [\n                        dZ[i] if dZ[i] is not None else None\n                        for i in range(len(dZ))\n                    ]\n                )\n                tpl_X_dZ = tuple(X + tpl_dZ)\n                if op == \"mul\" and reduce_last:\n                    dY = gsddmm_hetero(gidx, \"dot\", X_len, \"u\", \"v\", *tpl_X_dZ)\n                elif op == \"mul\":\n                    dY = gsddmm_hetero(gidx, \"mul\", X_len, \"u\", \"v\", *tpl_X_dZ)\n                elif op in [\"add\", \"copy_rhs\"]:\n                    dY = gsddmm_hetero(\n                        gidx, \"copy_rhs\", X_len, \"u\", \"v\", *tpl_X_dZ\n                    )\n            else:  # max/min\n                src_id, dst_id = gidx.metagraph.find_edge(0)\n                dY = tuple(\n                    [\n                        th.zeros(\n                            (Y_shape[i][0],) + dZ[dst_id].shape[1:],\n                            dtype=dtype,\n                            device=device,\n                        )\n                        if Y[i] is not None\n                        else 
None\n                        for i in range(len(Y))\n                    ]\n                )\n                if op == \"mul\":\n                    grad = _expand(X, dZ.shape[1:]).gather(0, argX.long()) * dZ\n                    dY.scatter_add_(0, argY.long(), grad)\n                elif op in [\"add\", \"copy_rhs\"]:\n                    dY = _update_grad_minmax_hetero(\n                        gidx.reverse(), op, dZ, argY, argY_etype, dY\n                    )\n            dY = tuple(\n                [\n                    _reduce_grad(dY[i], Y_shape[i])\n                    if dY[i] is not None\n                    else None\n                    for i in range(len(dY))\n                ]\n            )\n        else:  # Y has no gradient\n            dY = tuple([None] * len(Y))\n        return (None, None, None, None) + dX + dY\n\n\ndef sddmm_cache_X(op, req_grad_X, req_grad_Y):\n    \"\"\"Rules to identify whether to cache X in SDDMM forward stage.\"\"\"\n    if op in [\"mul\", \"dot\"] and req_grad_Y:\n        return True\n    return False\n\n\ndef sddmm_cache_Y(op, req_grad_X, req_grad_Y):\n    \"\"\"Rules to identify whether to cache Y in SDDMM forward stage.\"\"\"\n    if op in [\"mul\", \"dot\"] and req_grad_X:\n        return True\n    return False\n\n\nclass GSDDMM(th.autograd.Function):\n    @staticmethod\n    def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):\n        out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)\n        X_shape = X.shape if X is not None else None\n        Y_shape = Y.shape if Y is not None else None\n        ctx.backward_cache = gidx, op, lhs_target, rhs_target, X_shape, Y_shape\n        req_grad_X = X.requires_grad if X is not None else False\n        req_grad_Y = Y.requires_grad if Y is not None else False\n        if not sddmm_cache_X(op, req_grad_X, req_grad_Y):\n            X = None\n        if not sddmm_cache_Y(op, req_grad_X, req_grad_Y):\n            Y = None\n        ctx.save_for_backward(X, Y)\n        
return out\n\n    @staticmethod\n    def backward(ctx, dZ):\n        gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache\n        X, Y = ctx.saved_tensors\n        if op != \"copy_rhs\" and ctx.needs_input_grad[2]:\n            if lhs_target in [\"u\", \"v\"]:\n                _gidx = gidx if lhs_target == \"v\" else gidx.reverse()\n                if op in [\"add\", \"copy_lhs\"]:\n                    dX = gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ)\n                else:  # mul, dot\n                    if rhs_target == lhs_target:\n                        dX = gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ) * Y\n                    elif rhs_target == \"e\":\n                        dX = gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ * Y)\n                    else:  # rhs_target = !lhs_target\n                        dX = gspmm(_gidx, \"mul\", \"sum\", Y, dZ)\n            else:  # lhs_target == 'e'\n                if op in [\"add\", \"copy_lhs\"]:\n                    dX = dZ\n                else:  # mul, dot\n                    dX = gsddmm(gidx, \"mul\", dZ, Y, \"e\", rhs_target)\n            dX = _reduce_grad(dX, X_shape)\n        else:\n            dX = None\n        if op != \"copy_lhs\" and ctx.needs_input_grad[3]:\n            if rhs_target in [\"u\", \"v\"]:\n                _gidx = gidx if rhs_target == \"v\" else gidx.reverse()\n                if op in [\"add\", \"copy_rhs\"]:\n                    dY = gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ)\n                else:  # mul, dot\n                    if lhs_target == rhs_target:\n                        dY = gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ) * X\n                    elif lhs_target == \"e\":\n                        dY = gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ * X)\n                    else:  # rhs_target = !lhs_target\n                        dY = gspmm(_gidx, \"mul\", \"sum\", X, dZ)\n            else:\n                if op in [\"add\", \"copy_rhs\"]:\n         
           dY = dZ\n                else:  # mul, dot\n                    dY = gsddmm(gidx, \"mul\", dZ, X, \"e\", lhs_target)\n            dY = _reduce_grad(dY, Y_shape)\n        else:\n            dY = None\n        return None, None, dX, dY, None, None\n\n\nclass GSDDMM_hetero(th.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx, gidx, op, X_len, lhs_target, rhs_target, *feats\n    ):  # feats = X+Y\n        out = _gsddmm_hetero(gidx, op, X_len, lhs_target, rhs_target, feats)\n        X, Y = feats[:X_len], feats[X_len:]\n        X_shape = tuple(\n            [X[i].shape if X[i] is not None else None for i in range(len(X))]\n        )\n        Y_shape = tuple(\n            [Y[i].shape if Y[i] is not None else None for i in range(len(Y))]\n        )\n        ctx.backward_cache = (\n            gidx,\n            op,\n            lhs_target,\n            rhs_target,\n            X_shape,\n            Y_shape,\n            X_len,\n        )\n        req_grad_X = tuple(\n            [\n                X[i].requires_grad if X[i] is not None else False\n                for i in range(len(X))\n            ]\n        )\n        req_grad_Y = tuple(\n            [\n                Y[i].requires_grad if Y[i] is not None else False\n                for i in range(len(Y))\n            ]\n        )\n        ctx.save_for_backward(*feats)\n        return out\n\n    @staticmethod\n    # TODO(Israt): Implement the complete backward operator\n    def backward(ctx, *dZ):\n        (\n            gidx,\n            op,\n            lhs_target,\n            rhs_target,\n            X_shape,\n            Y_shape,\n            X_len,\n        ) = ctx.backward_cache\n        feats = ctx.saved_tensors\n        X, Y = feats[:X_len], feats[X_len:]\n        if op != \"copy_rhs\" and any([x is not None for x in X]):\n            if lhs_target in [\"u\", \"v\"]:\n                _gidx = gidx if lhs_target == \"v\" else gidx.reverse()\n                tpl_of_None = 
tuple([None] * len(X))\n                if op in [\"add\", \"copy_lhs\"]:\n                    dX = gspmm_hetero(\n                        _gidx,\n                        \"copy_rhs\",\n                        \"sum\",\n                        len(X),\n                        *(tuple(tpl_of_None + dZ))\n                    )\n                else:  # mul, dot\n                    if rhs_target == lhs_target:\n                        dX = (\n                            gspmm_hetero(\n                                _gidx,\n                                \"copy_rhs\",\n                                \"sum\",\n                                len(X),\n                                *(tuple(tpl_of_None + dZ))\n                            )\n                            * Y\n                        )\n                    elif rhs_target == \"e\":\n                        dZ_mul_Y = tuple(\n                            [\n                                dZ[i] * Y[i] if dZ[i] is not None else None\n                                for i in range(len(Y))\n                            ]\n                        )\n                        dX = gspmm_hetero(\n                            _gidx,\n                            \"copy_rhs\",\n                            \"sum\",\n                            len(X),\n                            *(tuple(tpl_of_None + dZ_mul_Y))\n                        )\n                    else:  # rhs_target = !lhs_target\n                        dX = gspmm_hetero(\n                            _gidx, \"mul\", \"sum\", len(X), *tuple(Y + dZ)\n                        )\n            else:  # lhs_target == 'e'\n                if op in [\"add\", \"copy_lhs\"]:\n                    dX = dZ\n                else:  # mul, dot\n                    num_etype = gidx.number_of_etypes()\n                    dX = gsddmm_hetero(\n                        gidx, \"mul\", num_etype, \"e\", rhs_target, *tuple(dZ + Y)\n                    )\n            dX = tuple(\n   
             [\n                    _reduce_grad(dX[i], X_shape[i])\n                    if X[i] is not None\n                    else None\n                    for i in range(len(X))\n                ]\n            )\n        else:\n            dX = tuple([None] * len(X))\n        if op != \"copy_lhs\" and any([y is not None for y in Y]):\n            if rhs_target in [\"u\", \"v\"]:\n                _gidx = gidx if rhs_target == \"v\" else gidx.reverse()\n                tpl_of_None = tuple([None] * len(X))\n                if op in [\"add\", \"copy_rhs\"]:\n                    dY = gspmm_hetero(\n                        _gidx,\n                        \"copy_rhs\",\n                        \"sum\",\n                        len(X),\n                        *(tuple(tpl_of_None + dZ))\n                    )\n                else:  # mul, dot\n                    if lhs_target == rhs_target:\n                        dY = (\n                            gspmm_hetero(\n                                _gidx,\n                                \"copy_rhs\",\n                                \"sum\",\n                                len(X),\n                                *(tuple(tpl_of_None + dZ))\n                            )\n                            * X\n                        )\n                    elif lhs_target == \"e\":\n                        dZ_mul_X = tuple(\n                            [\n                                dZ[i] * X[i] if dZ[i] is not None else None\n                                for i in range(len(X))\n                            ]\n                        )\n                        dY = gspmm_hetero(\n                            _gidx,\n                            \"copy_rhs\",\n                            \"sum\",\n                            len(X),\n                            *(tuple(tpl_of_None + dZ_mul_X))\n                        )\n                    else:  # rhs_target = !lhs_target\n                        dY = gspmm_hetero(\n 
                           _gidx, \"mul\", \"sum\", len(X), *tuple(X + dZ)\n                        )\n            else:\n                if op in [\"add\", \"copy_rhs\"]:\n                    dY = tuple(\n                        [\n                            dZ[i] if dZ[i] is not None else None\n                            for i in range(len(dZ))\n                        ]\n                    )\n                else:  # mul, dot\n                    num_etype = gidx.number_of_etypes()\n                    dY = gsddmm_hetero(\n                        gidx, \"mul\", num_etype, \"e\", lhs_target, *tuple(dZ + X)\n                    )\n            dY = tuple(\n                [\n                    _reduce_grad(dY[i], Y_shape[i])\n                    if Y[i] is not None\n                    else None\n                    for i in range(len(Y))\n                ]\n            )\n        else:\n            dY = tuple([None] * len(Y))\n        return (None, None, None, None, None) + dX + dY\n\n\nclass EdgeSoftmax(th.autograd.Function):\n    @staticmethod\n    def forward(ctx, gidx, score, eids, norm_by):\n        \"\"\"Forward function.\n\n        Pseudo-code:\n\n        .. 
code:: python\n\n            score = dgl.EData(g, score)\n            score_max = score.dst_max()  # of type dgl.NData\n            score = score - score_max  # edge_sub_dst, ret dgl.EData\n            score_sum = score.dst_sum()  # of type dgl.NData\n            out = score / score_sum    # edge_div_dst, ret dgl.EData\n            return out.data\n        \"\"\"\n        # remember to save the graph to backward cache before making it\n        # a local variable\n        if not is_all(eids):\n            gidx = gidx.edge_subgraph([eids], True).graph\n        if norm_by == \"src\":\n            gidx = gidx.reverse()\n        # Note: Now _edge_softmax_forward op only supports CPU\n        # TODO(Zhejiang): We will support GPU in the future\n        if score.is_cuda:\n            score_max = _gspmm(gidx, \"copy_rhs\", \"max\", None, score)[0]\n            score = th.exp(_gsddmm(gidx, \"sub\", score, score_max, \"e\", \"v\"))\n            score_sum = _gspmm(gidx, \"copy_rhs\", \"sum\", None, score)[0]\n            out = _gsddmm(gidx, \"div\", score, score_sum, \"e\", \"v\")\n        else:\n            out = _edge_softmax_forward(gidx, score, \"copy_rhs\")\n        ctx.backward_cache = gidx\n        ctx.save_for_backward(out)\n        return out\n\n    @staticmethod\n    def backward(ctx, grad_out):\n        \"\"\"Backward function.\n\n        Pseudo-code:\n\n        .. 
code:: python\n\n            g, out = ctx.backward_cache\n            grad_out = dgl.EData(g, grad_out)\n            out = dgl.EData(g, out)\n            sds = out * grad_out  # type dgl.EData\n            sds_sum = sds.dst_sum()  # type dgl.NData\n            grad_score = sds - out * sds_sum  # multiple expressions\n            return grad_score.data\n        \"\"\"\n        gidx = ctx.backward_cache\n        (out,) = ctx.saved_tensors\n        sds = out * grad_out\n        # Note: Now _edge_softmax_backward op only supports CPU\n        # TODO(Zhejiang): We will support GPU in the future\n        if out.is_cuda:\n            accum = gspmm(gidx, \"copy_rhs\", \"sum\", None, sds)\n\n            grad_score = sds - gsddmm(gidx, \"mul\", out, accum, \"e\", \"v\")\n        else:\n            grad_score = _edge_softmax_backward(gidx, out, sds)\n        return None, grad_score, None, None\n\n\nclass EdgeSoftmax_hetero(th.autograd.Function):\n    @staticmethod\n    def forward(ctx, gidx, eids, norm_by, *score):\n        \"\"\"Forward function.\n\n        Pseudo-code:\n\n        .. 
code:: python\n\n            score = dgl.EData(g, score)\n            score_max = score.dst_max()  # of type dgl.NData\n            score = score - score_max  # edge_sub_dst, ret dgl.EData\n            score_sum = score.dst_sum()  # of type dgl.NData\n            out = score / score_sum    # edge_div_dst, ret dgl.EData\n            return out.data\n        \"\"\"\n        # remember to save the graph to backward cache before making it\n        # a local variable\n        if not is_all(eids):\n            gidx = gidx.edge_subgraph([eids], True).graph\n        if norm_by == \"src\":\n            gidx = gidx.reverse()\n        u_len = gidx.number_of_ntypes()\n        e_len = gidx.number_of_etypes()\n        lhs = [None] * u_len\n        feats = tuple(lhs + list(score))\n        score_max = _gspmm_hetero(gidx, \"copy_rhs\", \"max\", u_len, feats)[0]\n        out_tmp = _gsddmm_hetero(\n            gidx, \"sub\", e_len, \"e\", \"v\", tuple(list(score) + list(score_max))\n        )\n        score = tuple(\n            [\n                th.exp(out_tmp[i]) if out_tmp[i] is not None else None\n                for i in range(len(out_tmp))\n            ]\n        )\n        score_sum = _gspmm_hetero(\n            gidx, \"copy_rhs\", \"sum\", u_len, tuple(lhs + list(score))\n        )[0]\n        out = _gsddmm_hetero(\n            gidx, \"div\", e_len, \"e\", \"v\", tuple(list(score) + list(score_sum))\n        )\n        ctx.backward_cache = gidx\n        ctx.save_for_backward(*out)\n        return out\n\n    @staticmethod\n    def backward(ctx, *grad_out):\n        \"\"\"Backward function.\n\n        Pseudo-code:\n\n        .. 
code:: python\n\n            g, out = ctx.backward_cache\n            grad_out = dgl.EData(g, grad_out)\n            out = dgl.EData(g, out)\n            sds = out * grad_out  # type dgl.EData\n            sds_sum = sds.dst_sum()  # type dgl.NData\n            grad_score = sds - out * sds_sum  # multiple expressions\n            return grad_score.data\n        \"\"\"\n        gidx = ctx.backward_cache\n        u_len = gidx.number_of_ntypes()\n        e_len = gidx.number_of_etypes()\n        lhs = [None] * u_len\n        out = ctx.saved_tensors\n        sds = tuple([out[i] * grad_out[i] for i in range(len(out))])\n        accum = _gspmm_hetero(\n            gidx, \"copy_rhs\", \"sum\", u_len, tuple(lhs + list(sds))\n        )[0]\n        out_sddmm = _gsddmm_hetero(\n            gidx, \"mul\", e_len, \"e\", \"v\", tuple(list(out) + list(accum))\n        )\n        grad_score = tuple([sds[i] - out_sddmm[i] for i in range(len(sds))])\n        return (None, None, None) + grad_score\n\n\nclass SegmentReduce(th.autograd.Function):\n    @staticmethod\n    def forward(ctx, op, x, offsets):\n        y, arg = _segment_reduce(op, x, offsets)\n        ctx.save_for_backward(arg, offsets)\n        ctx.backward_cache = op\n        return y\n\n    @staticmethod\n    def backward(ctx, dy):\n        op = ctx.backward_cache\n        arg, offsets = ctx.saved_tensors\n        m = offsets[-1].item()\n        if op == \"sum\":\n            offsets = offsets[1:]\n            # To address the issue of trailing zeros, related issue:\n            # https://github.com/dmlc/dgl/pull/2610\n            indices = th.zeros(\n                (m + 1,), device=offsets.device, dtype=offsets.dtype\n            )\n            indices.scatter_add_(0, offsets, th.ones_like(offsets))\n            indices = th.cumsum(indices, -1)[:-1]\n            dx = dy[indices]\n        else:\n            dx = _bwd_segment_cmp(dy, arg, m)\n        return None, dx, None\n\n\nclass ScatterAdd(th.autograd.Function):\n    
@staticmethod\n    def forward(ctx, x, idx, m):\n        y = _scatter_add(x, idx, m)\n        ctx.save_for_backward(idx)\n        return y\n\n    @staticmethod\n    def backward(ctx, dy):\n        idx = ctx.saved_tensors\n        return dy[idx], None, None\n\n\nclass CSRMM(th.autograd.Function):\n    @staticmethod\n    def forward(ctx, gidxA, A_weights, gidxB, B_weights, num_vtypes):\n        gidxC, C_weights = _csrmm(\n            gidxA, A_weights, gidxB, B_weights, num_vtypes\n        )\n        (\n            nrows,\n            ncols,\n            C_indptr,\n            C_indices,\n            C_eids,\n        ) = gidxC.adjacency_matrix_tensors(0, False, \"csr\")\n        # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same\n        # as the underlying tensors of the created graph gidxC.\n        ctx.backward_cache = gidxA, gidxB, gidxC\n        ctx.save_for_backward(A_weights, B_weights)\n        return (\n            th.tensor(nrows),\n            th.tensor(ncols),\n            C_indptr,\n            C_indices,\n            C_eids,\n            C_weights,\n        )\n\n    @staticmethod\n    def backward(\n        ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights\n    ):\n        # Only the last argument is meaningful.\n        gidxA, gidxB, gidxC = ctx.backward_cache\n        A_weights, B_weights = ctx.saved_tensors\n        dgidxA, dA_weights = csrmm(\n            gidxC,\n            dC_weights,\n            gidxB.reverse(),\n            B_weights,\n            gidxA.number_of_ntypes(),\n        )\n        dgidxB, dB_weights = csrmm(\n            gidxA.reverse(),\n            A_weights,\n            gidxC,\n            dC_weights,\n            gidxB.number_of_ntypes(),\n        )\n        dA_weights = csrmask(dgidxA, dA_weights, gidxA)\n        dB_weights = csrmask(dgidxB, dB_weights, gidxB)\n        return None, dA_weights, None, dB_weights, None\n\n\nclass CSRSum(th.autograd.Function):\n    @staticmethod\n    def 
forward(ctx, gidxs, *weights):\n        # PyTorch tensors must be explicit arguments of the forward function\n        gidxC, C_weights = _csrsum(gidxs, weights)\n        (\n            nrows,\n            ncols,\n            C_indptr,\n            C_indices,\n            C_eids,\n        ) = gidxC.adjacency_matrix_tensors(0, False, \"csr\")\n        # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same\n        # as the underlying tensors of the created graph gidxC.\n        ctx.backward_cache = gidxs, gidxC\n        return (\n            th.tensor(nrows),\n            th.tensor(ncols),\n            C_indptr,\n            C_indices,\n            C_eids,\n            C_weights,\n        )\n\n    @staticmethod\n    def backward(\n        ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights\n    ):\n        # Only the last argument is meaningful.\n        gidxs, gidxC = ctx.backward_cache\n        return (None,) + tuple(\n            csrmask(gidxC, dC_weights, gidx) for gidx in gidxs\n        )\n\n\nclass CSRMask(th.autograd.Function):\n    @staticmethod\n    def forward(ctx, gidxA, A_weights, gidxB):\n        ctx.backward_cache = gidxA, gidxB\n        return _csrmask(gidxA, A_weights, gidxB)\n\n    @staticmethod\n    def backward(ctx, dB_weights):\n        gidxA, gidxB = ctx.backward_cache\n        return None, csrmask(gidxB, dB_weights, gidxA), None\n\n\nclass SEGMENTMM(th.autograd.Function):\n    @staticmethod\n    def forward(ctx, A, B, seglen_A):\n        if B.dim() != 3:\n            raise ValueError(\"segment_mm expects B to be a 3D tensor.\")\n        C = th.empty((A.shape[0], B.shape[2]), device=A.device, dtype=A.dtype)\n        C = _segment_mm(A, B, C, seglen_A)\n        ctx.backward_cache = A, B, seglen_A\n        return C\n\n    @staticmethod\n    def backward(ctx, dZ):\n        A, B, seglen_A = ctx.backward_cache\n        A_grad = B_grad = None\n        if ctx.needs_input_grad[0]:\n            #  Compute A_grad = Out_grad 
* B^T\n            A_grad = th.empty(A.shape, device=A.device, dtype=A.dtype)\n            A_grad = _segment_mm(dZ, B, A_grad, seglen_A, b_trans=True)\n        if ctx.needs_input_grad[1]:\n            #  Compute B_grad = A^T * Out_grad\n            B_grad = th.empty(B.shape, device=B.device, dtype=B.dtype)\n            B_grad = _segment_mm_backward_B(A, dZ, B_grad, seglen_A)\n        return A_grad, B_grad, None\n\n\nclass GATHERMM(th.autograd.Function):\n    @staticmethod\n    def forward(ctx, A, B, idx_a, idx_b):\n        if B.dim() != 3:\n            raise ValueError(\n                \"Expected dimension of B is 3. Got \" + str(B.dim())\n            )\n        N = len(idx_b) if idx_a is None else len(idx_a)\n        C = th.zeros((N, B.shape[2]), device=A.device, dtype=A.dtype)\n        C = _gather_mm(A, B, C, idx_a, idx_b)\n        ctx.backward_cache = A, B, idx_a, idx_b\n        return C\n\n    @staticmethod\n    def backward(ctx, dZ):\n        A, B, idx_a, idx_b = ctx.backward_cache\n        A_grad = B_grad = None\n        if ctx.needs_input_grad[0]:\n            #  Compute A_grad = Out_grad * B^T\n            A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)\n            A_grad = _gather_mm_scatter(\n                dZ, B.transpose(1, 2), A_grad, idx_b=idx_b, idx_c=idx_a\n            )\n        if ctx.needs_input_grad[1]:\n            #  Compute B_grad = A^T * Out_grad\n            B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)\n            B_grad = _gather_mm_scatter(A, dZ, B_grad, idx_a=idx_a, idx_c=idx_b)\n        return A_grad, B_grad, None, None\n\n\ndef gspmm(gidx, op, reduce_op, lhs_data, rhs_data):\n    if op == \"sub\":\n        op = \"add\"\n        rhs_data = -rhs_data\n    if op == \"div\":\n        op = \"mul\"\n        rhs_data = 1.0 / rhs_data\n    args = _cast_if_autocast_enabled(gidx, op, reduce_op, lhs_data, rhs_data)\n    with _disable_autocast_if_enabled():\n        return GSpMM.apply(*args)\n\n\ndef gsddmm(gidx, op, 
lhs_data, rhs_data, lhs_target=\"u\", rhs_target=\"v\"):\n    if op == \"sub\":\n        op = \"add\"\n        rhs_data = -rhs_data\n    if op == \"div\":\n        op = \"mul\"\n        rhs_data = 1.0 / rhs_data\n    args = _cast_if_autocast_enabled(\n        gidx, op, lhs_data, rhs_data, lhs_target, rhs_target\n    )\n    with _disable_autocast_if_enabled():\n        return GSDDMM.apply(*args)\n\n\ndef gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):\n    lhs_tuple, rhs_tuple = (\n        lhs_and_rhs_tuple[:lhs_len],\n        lhs_and_rhs_tuple[lhs_len:],\n    )\n    if op == \"sub\":\n        op = \"add\"\n        rhs_tuple = tuple(\n            [\n                -rhs_tuple[i] if rhs_tuple[i] is not None else None\n                for i in range(len(rhs_tuple))\n            ]\n        )\n    if op == \"div\":\n        op = \"mul\"\n        rhs_tuple = tuple(\n            [\n                (1.0 / rhs_tuple[i]) if rhs_tuple[i] is not None else None\n                for i in range(len(rhs_tuple))\n            ]\n        )\n    if op in [\"add\", \"mul\"]:\n        lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))\n\n    args = _cast_if_autocast_enabled(\n        g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple\n    )\n    with _disable_autocast_if_enabled():\n        return GSpMM_hetero.apply(*args)\n\n\ndef gsddmm_hetero(\n    g, op, lhs_len, lhs_target=\"u\", rhs_target=\"v\", *lhs_and_rhs_tuple\n):\n    lhs_tuple, rhs_tuple = (\n        lhs_and_rhs_tuple[:lhs_len],\n        lhs_and_rhs_tuple[lhs_len:],\n    )\n    if op == \"sub\":\n        op = \"add\"\n        rhs_tuple = tuple(\n            [\n                -rhs_tuple[i] if rhs_tuple[i] is not None else None\n                for i in range(len(rhs_tuple))\n            ]\n        )\n    if op == \"div\":\n        op = \"mul\"\n        rhs_tuple = tuple(\n            [\n                (1.0 / rhs_tuple[i]) if rhs_tuple[i] is not None else None\n                for i in 
range(len(rhs_tuple))\n            ]\n        )\n    if op in [\"add\", \"mul\"]:\n        lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))\n\n    args = _cast_if_autocast_enabled(\n        g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple\n    )\n    with _disable_autocast_if_enabled():\n        return GSDDMM_hetero.apply(*args)\n\n\ndef edge_softmax(gidx, logits, eids=ALL, norm_by=\"dst\"):\n    args = _cast_if_autocast_enabled(gidx, logits, eids, norm_by)\n    with _disable_autocast_if_enabled():\n        return EdgeSoftmax.apply(*args)\n\n\ndef edge_softmax_hetero(gidx, eids=ALL, norm_by=\"dst\", *logits):\n    args = _cast_if_autocast_enabled(gidx, eids, norm_by, *logits)\n    with _disable_autocast_if_enabled():\n        return EdgeSoftmax_hetero.apply(*args)\n\n\ndef segment_reduce(op, x, offsets):\n    args = _cast_if_autocast_enabled(op, x, offsets)\n    with _disable_autocast_if_enabled():\n        return SegmentReduce.apply(*args)\n\n\ndef scatter_add(x, idx, m):\n    args = _cast_if_autocast_enabled(x, idx, m)\n    with _disable_autocast_if_enabled():\n        return ScatterAdd.apply(*args)\n\n\ndef csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):\n    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = CSRMM.apply(\n        gidxA, A_weights, gidxB, B_weights, num_vtypes\n    )\n    gidxC = create_unitgraph_from_csr(\n        num_vtypes,\n        nrows.item(),\n        ncols.item(),\n        C_indptr,\n        C_indices,\n        C_eids,\n        [\"coo\", \"csr\", \"csc\"],\n    )\n    return gidxC, C_weights\n\n\ndef csrsum(gidxs, weights):\n    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = CSRSum.apply(\n        gidxs, *weights\n    )\n    gidxC = create_unitgraph_from_csr(\n        gidxs[0].number_of_ntypes(),\n        nrows.item(),\n        ncols.item(),\n        C_indptr,\n        C_indices,\n        C_eids,\n        [\"coo\", \"csr\", \"csc\"],\n    )\n    return gidxC, C_weights\n\n\ndef csrmask(gidxA, 
A_weights, gidxB):\n    return CSRMask.apply(gidxA, A_weights, gidxB)\n\n\ndef segment_mm(A, B, seglen_A):\n    if A.device.type == \"cpu\":\n        C = []\n        off = 0\n        for i in range(B.shape[0]):\n            C.append(A[off : off + seglen_A[i]] @ B[i])\n            off += seglen_A[i]\n        return th.cat(C)\n    else:\n        args = _cast_if_autocast_enabled(A, B, seglen_A)\n        with _disable_autocast_if_enabled():\n            return SEGMENTMM.apply(*args)\n\n\ndef gather_mm(A, B, idx_A=None, idx_B=None):\n    if A.device.type == \"cpu\":\n        A = A[idx_A] if idx_A is not None else A\n        B = B[idx_B] if idx_B is not None else B\n        return th.bmm(A.unsqueeze(1), B).squeeze(1)\n    else:\n        args = _cast_if_autocast_enabled(A, B, idx_A, idx_B)\n        with _disable_autocast_if_enabled():\n            return GATHERMM.apply(*args)\n"
  },
  {
    "path": "python/dgl/backend/pytorch/tensor.py",
    "content": "from __future__ import absolute_import\n\nimport builtins\nimport numbers\n\nimport numpy as np\nimport scipy  # Weird bug in new pytorch when import scipy after import torch\nimport torch as th\nfrom torch.utils import dlpack\n\nfrom ... import ndarray as nd\nfrom ...function.base import TargetCode\nfrom ...utils import version\n\nif version.parse(th.__version__) < version.parse(\"2.1.0\"):\n    raise RuntimeError(\"DGL requires PyTorch >= 2.1.0\")\n\n\ndef data_type_dict():\n    return {\n        \"bfloat16\": th.bfloat16,\n        \"float16\": th.float16,\n        \"float32\": th.float32,\n        \"float64\": th.float64,\n        \"uint8\": th.uint8,\n        \"int8\": th.int8,\n        \"int16\": th.int16,\n        \"int32\": th.int32,\n        \"int64\": th.int64,\n        \"bool\": th.bool,\n    }\n\n\ndef cpu():\n    return th.device(\"cpu\")\n\n\ndef tensor(data, dtype=None):\n    if isinstance(data, numbers.Number):\n        data = [data]\n    if (\n        isinstance(data, list)\n        and len(data) > 0\n        and isinstance(data[0], th.Tensor)\n    ):\n        # prevent GPU->CPU->GPU copies\n        if data[0].ndim == 0:\n            # zero dimenion scalar tensors\n            return th.stack(data)\n    if isinstance(data, th.Tensor):\n        return th.as_tensor(data, dtype=dtype, device=data.device)\n    else:\n        return th.as_tensor(data, dtype=dtype)\n\n\ndef as_scalar(data):\n    return data.item()\n\n\ndef get_preferred_sparse_format():\n    \"\"\"Get the preferred sparse matrix format supported by the backend.\n\n    Different backends have their preferred backend. This info is useful when\n    constructing a sparse matrix.\n    \"\"\"\n    return \"coo\"\n\n\ndef sparse_matrix(data, index, shape, force_format=False):\n    fmt = index[0]\n    if fmt != \"coo\":\n        raise TypeError(\n            \"Pytorch backend only supports COO format. 
But got %s.\" % fmt\n        )\n    spmat = th.sparse_coo_tensor(index[1], data, shape)\n    return spmat, None\n\n\ndef sparse_matrix_indices(spmat):\n    return (\"coo\", spmat._indices())\n\n\ndef is_tensor(obj):\n    return isinstance(obj, th.Tensor)\n\n\ndef shape(input):\n    return input.shape\n\n\ndef dtype(input):\n    return input.dtype\n\n\ndef ndim(input):\n    return input.dim()\n\n\ndef context(input):\n    return input.device\n\n\ndef device_type(ctx):\n    return th.device(ctx).type\n\n\ndef device_id(ctx):\n    ctx = th.device(ctx)\n    if ctx.index is None:\n        return 0 if ctx.type == \"cpu\" else th.cuda.current_device()\n    else:\n        return ctx.index\n\n\ndef to_backend_ctx(dglctx):\n    dev_type = dglctx.device_type\n    if dev_type == 1:\n        return th.device(\"cpu\")\n    elif dev_type == 2:\n        return th.device(\"cuda\", dglctx.device_id)\n    else:\n        raise ValueError(\"Unsupported DGL device context:\", dglctx)\n\n\ndef astype(input, ty):\n    return input.type(ty)\n\n\ndef asnumpy(input):\n    if isinstance(input, th.sparse.FloatTensor):\n        return input.to_dense().cpu().detach().numpy()\n    else:\n        return input.cpu().detach().numpy()\n\n\ndef copy_to(input, ctx, **kwargs):\n    ctx = th.device(ctx)\n    if ctx.type == \"cpu\":\n        return input.cpu()\n    elif ctx.type == \"cuda\":\n        if ctx.index is not None:\n            th.cuda.set_device(ctx.index)\n        return input.cuda(**kwargs)\n    else:\n        raise RuntimeError(\"Invalid context\", ctx)\n\n\ndef is_pinned(input):\n    return input.is_pinned()\n\n\ndef sum(input, dim, keepdims=False):\n    return th.sum(input, dim=dim, keepdim=keepdims)\n\n\ndef floor_div(in1, in2):\n    return in1 // in2\n\n\ndef reduce_sum(input):\n    return input.sum()\n\n\ndef cumsum(input, dim):\n    return th.cumsum(input, dim=dim)\n\n\ndef mean(input, dim):\n    return th.mean(input, dim=dim)\n\n\ndef reduce_mean(input):\n    return 
input.mean()\n\n\ndef max(input, dim):\n    # NOTE: the second argmax array is not returned\n    return th.max(input, dim=dim)[0]\n\n\ndef reduce_max(input):\n    return input.max()\n\n\ndef min(input, dim):\n    # NOTE: the second argmin array is not returned\n    return th.min(input, dim=dim)[0]\n\n\ndef reduce_min(input):\n    return input.min()\n\n\ndef argsort(input, dim, descending):\n    return th.argsort(input, dim=dim, descending=descending)\n\n\ndef topk(input, k, dim, descending=True):\n    return th.topk(input, k, dim, largest=descending)[0]\n\n\ndef argtopk(input, k, dim, descending=True):\n    return th.topk(input, k, dim, largest=descending)[1]\n\n\ndef exp(input):\n    return th.exp(input)\n\n\ndef inverse(input):\n    return th.inverse(input)\n\n\ndef sqrt(input):\n    return th.sqrt(input)\n\n\ndef softmax(input, dim=-1):\n    return th.softmax(input, dim=dim)\n\n\ndef cat(seq, dim):\n    return th.cat(seq, dim=dim)\n\n\ndef stack(seq, dim):\n    return th.stack(seq, dim=dim)\n\n\ndef split(input, sizes_or_sections, dim):\n    return th.split(input, sizes_or_sections, dim)\n\n\ndef repeat(input, repeats, dim):\n    return th.repeat_interleave(input, repeats, dim)  # PyTorch 1.1\n\n\ndef gather_row(data, row_index):\n    return th.index_select(data, 0, row_index.long())\n\n\ndef slice_axis(data, axis, begin, end):\n    return th.narrow(data, axis, begin, end - begin)\n\n\ndef take(data, indices, dim):\n    new_shape = data.shape[:dim] + indices.shape + data.shape[dim + 1 :]\n    return th.index_select(data, dim, indices.view(-1)).view(new_shape)\n\n\ndef narrow_row(x, start, stop):\n    return x[start:stop]\n\n\ndef index_add_inplace(data, row_idx, value):\n    data.index_add_(0, row_idx, value)\n\n\ndef scatter_row(data, row_index, value):\n    return data.index_copy(0, row_index.long(), value)\n\n\ndef scatter_row_inplace(data, row_index, value):\n    data[row_index.long()] = value\n\n\ndef squeeze(input, dim):\n    return th.squeeze(input, 
dim)\n\n\ndef unsqueeze(input, dim):\n    return th.unsqueeze(input, dim)\n\n\ndef reshape(input, shape):\n    return th.reshape(input, shape)\n\n\ndef swapaxes(input, axis1, axis2):\n    return th.transpose(input, axis1, axis2)\n\n\ndef empty(shape, dtype, ctx):\n    return th.empty(shape, dtype=dtype, device=ctx)\n\n\ndef zeros(shape, dtype, ctx):\n    return th.zeros(shape, dtype=dtype, device=ctx)\n\n\ndef zeros_like(input):\n    return th.zeros_like(input)\n\n\ndef ones(shape, dtype, ctx):\n    return th.ones(shape, dtype=dtype, device=ctx)\n\n\ndef uniform(shape, dtype, ctx, low, high):\n    return th.empty(shape, dtype=dtype, device=ctx).uniform_(low, high)\n\n\ndef randint(shape, dtype, ctx, low, high):\n    return th.randint(low, high, shape, dtype=dtype, device=ctx)\n\n\ndef pad_packed_tensor(input, lengths, value, l_min=None):\n    old_shape = input.shape\n    device = input.device\n    if not is_tensor(lengths):\n        lengths = th.tensor(lengths, dtype=th.int64, device=device)\n    else:\n        lengths = lengths.to(device)\n    max_len = as_scalar(lengths.max())\n\n    if l_min is not None:\n        max_len = builtins.max(max_len, l_min)\n\n    batch_size = len(lengths)\n    x = input.new(batch_size * max_len, *old_shape[1:])\n    x.fill_(value)\n    index = th.ones(len(input), dtype=th.int64, device=device)\n    cum_lengths = th.cumsum(lengths, 0)\n    index[cum_lengths[:-1]] += max_len - lengths[:-1]\n    index = th.cumsum(index, 0) - 1\n    x[index] = input\n    return x.view(batch_size, max_len, *old_shape[1:])\n\n\ndef pack_padded_tensor(input, lengths):\n    max_len = input.shape[1]\n    device = input.device\n    if not is_tensor(lengths):\n        lengths = th.tensor(lengths, dtype=th.int64, device=device)\n    else:\n        lengths = lengths.to(device)\n    input = input.view(-1, *input.shape[2:])\n    out_len = lengths.sum().item()\n    index = th.ones(out_len, dtype=th.int64, device=device)\n    cum_lengths = th.cumsum(lengths, 0)\n    
index[cum_lengths[:-1]] += max_len - lengths[:-1]\n    index = th.cumsum(index, 0) - 1\n    return input[index]\n\n\ndef boolean_mask(input, mask):\n    if \"bool\" not in str(mask.dtype):\n        mask = th.as_tensor(mask, dtype=th.bool)\n    return input[mask]\n\n\ndef equal(x, y):\n    return x == y\n\n\ndef allclose(x, y, rtol=1e-4, atol=1e-4):\n    return th.allclose(x, y, rtol=rtol, atol=atol)\n\n\ndef logical_not(input):\n    return ~input\n\n\ndef logical_and(input1, input2):\n    return input1 & input2\n\n\ndef clone(input):\n    return input.clone()\n\n\ndef clamp(data, min_val, max_val):\n    return th.clamp(data, min_val, max_val)\n\n\ndef replace_inf_with_zero(x):\n    return th.masked_fill(x, th.isinf(x), 0)\n\n\ndef count_nonzero(input):\n    # TODO: fallback to numpy for backward compatibility\n    return np.count_nonzero(input)\n\n\ndef unique(input, return_inverse=False, return_counts=False):\n    if input.dtype == th.bool:\n        input = input.type(th.int8)\n    return th.unique(\n        input, return_inverse=return_inverse, return_counts=return_counts\n    )\n\n\ndef full_1d(length, fill_value, dtype, ctx):\n    return th.full((length,), fill_value, dtype=dtype, device=ctx)\n\n\ndef nonzero_1d(input):\n    x = th.nonzero(input, as_tuple=False).squeeze()\n    return x if x.dim() == 1 else x.view(-1)\n\n\ndef sort_1d(input):\n    return th.sort(input)\n\n\ndef arange(start, stop, dtype=th.int64, ctx=None):\n    return th.arange(start, stop, dtype=dtype, device=ctx)\n\n\ndef rand_shuffle(arr):\n    idx = th.randperm(len(arr))\n    return arr[idx]\n\n\ndef zerocopy_to_dlpack(input):\n    return dlpack.to_dlpack(input.contiguous())\n\n\ndef zerocopy_from_dlpack(dlpack_tensor):\n    return dlpack.from_dlpack(dlpack_tensor)\n\n\ndef zerocopy_to_numpy(input):\n    # NOTE: not zerocopy\n    return asnumpy(input)\n\n\ndef zerocopy_from_numpy(np_array):\n    return th.as_tensor(np_array)\n\n\ndef zerocopy_to_dgl_ndarray(data):\n    if data.dtype == 
th.bool:\n        data = data.byte()\n    return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))\n\n\n# NGC PyTorch containers are shipping alpha version PyTorch.\nif version.parse(th.__version__) >= version.parse(\"2.0.0a0\"):\n\n    def check_is_view(input):\n        assert (\n            input.data_ptr() == input.untyped_storage().data_ptr()\n        ), \"Cannot convert view tensors to dgl ndarray for write.\"\n\nelse:\n\n    def check_is_view(input):\n        assert (\n            input.data_ptr() == input._storage().data_ptr()\n        ), \"Cannot convert view tensors to dgl ndarray for write.\"\n\n\ndef zerocopy_to_dgl_ndarray_for_write(input):\n    if input.numel() > 0:\n        # only check non-empty tensors\n        assert input.is_contiguous(), (\n            \"Cannot convert non-contiguous tensors \"\n            \"to dgl ndarray for write. Call .to_contiguous() first.\"\n        )\n        check_is_view(input)\n    return zerocopy_to_dgl_ndarray(input)\n\n\ndef zerocopy_from_dgl_ndarray(data):\n    if data.shape == (0,):\n        # NOTE: PyTorch v1.5 does not accept DLPack object representing empty CUDA tensor.\n        #  Related issue: https://github.com/pytorch/pytorch/issues/41182\n        #  The issue will be fixed in v1.6 and later.\n        return th.tensor(\n            [], dtype=getattr(th, data.dtype), device=to_backend_ctx(data.ctx)\n        )\n    elif len(data.shape) == 0 or builtins.min(data.shape) == 0:\n        # Workaround the same issue as above, but preserve the shape of the\n        # empty tensor. 
This is needed by the sparse optimizer when one of\n        # processors may receive no gradients to update, but we want to keep\n        # the dimension of the embedding.\n        return th.empty(\n            data.shape,\n            dtype=getattr(th, data.dtype),\n            device=to_backend_ctx(data.ctx),\n        )\n    else:\n        return dlpack.from_dlpack(data.to_dlpack())\n\n\ndef sync():\n    # Pytorch performs computation synchronously, so no need for synchronization.\n    pass\n\n\ndef attach_grad(x):\n    if x.grad is not None:\n        x.grad.zero_()\n        return x\n    else:\n        return x.requires_grad_()\n\n\ndef backward(x, head_gradient=None):\n    if (\n        head_gradient is not None\n        and head_gradient.shape[0] == 1\n        and len(head_gradient.shape) == 1\n    ):\n        # Fix for torch 1.3.1\n        head_gradient = th.tensor(head_gradient.item()).to(head_gradient.device)\n    x.backward(head_gradient)\n\n\ndef grad(x):\n    x.retain_grad()\n    return x.grad\n\n\ndef is_no_grad(x):\n    return x.grad is None or (x.grad == 0).all()\n\n\ndef is_recording():\n    return th.is_grad_enabled()\n\n\nclass record_grad(object):\n    def __init__(self):\n        pass\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        pass\n\n\nno_grad = th.no_grad\n"
  },
  {
    "path": "python/dgl/backend/set_default_backend.py",
    "content": "import argparse\nimport json\nimport os\n\n\ndef set_default_backend(default_dir, backend_name):\n    os.makedirs(default_dir, exist_ok=True)\n    config_path = os.path.join(default_dir, \"config.json\")\n    with open(config_path, \"w\") as config_file:\n        json.dump({\"backend\": backend_name.lower()}, config_file)\n    print(\n        'Setting the default backend to \"{}\". You can change it in the '\n        \"~/.dgl/config.json file or export the DGLBACKEND environment variable.  \"\n        \"Valid options are: pytorch, mxnet, tensorflow (all lowercase)\".format(\n            backend_name\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"default_dir\",\n        type=str,\n        default=os.path.join(os.path.expanduser(\"~\"), \".dgl\"),\n    )\n    parser.add_argument(\n        \"backend\",\n        nargs=1,\n        type=str,\n        choices=[\"pytorch\", \"tensorflow\", \"mxnet\"],\n        help=\"Set default backend\",\n    )\n    args = parser.parse_args()\n    set_default_backend(args.default_dir, args.backend[0])\n"
  },
  {
    "path": "python/dgl/backend/tensorflow/__init__.py",
    "content": "import os\n\nos.environ[\"TF_FORCE_GPU_ALLOW_GROWTH\"] = \"true\"\n\nfrom .sparse import *\nfrom .tensor import *\n"
  },
  {
    "path": "python/dgl/backend/tensorflow/sparse.py",
    "content": "import numpy as np\nimport tensorflow as tf\n\nfrom ..._sparse_ops import (\n    _bwd_segment_cmp,\n    _csrmask,\n    _csrmm,\n    _csrsum,\n    _gsddmm,\n    _gspmm,\n    _scatter_add,\n    _segment_reduce,\n)\n\nfrom ...base import ALL, is_all\nfrom ...heterograph_index import create_unitgraph_from_csr\nfrom .tensor import asnumpy, context, copy_to, tensor, zerocopy_from_numpy\n\n__all__ = [\n    \"gspmm\",\n    \"gsddmm\",\n    \"edge_softmax\",\n    \"segment_reduce\",\n    \"scatter_add\",\n    \"csrmm\",\n    \"csrsum\",\n    \"csrmask\",\n]\n\n\ndef _scatter_nd(index, src, n_rows):\n    assert index.shape == src.shape\n    shp = index.shape\n    ctx = context(src)\n    ndim = index.ndim\n    offsets = []\n    stride = 1\n    for i in reversed(range(1, ndim)):\n        di = shp[i]\n        offset_i = tf.range(di, dtype=index.dtype)\n        offsets.append(\n            tf.reshape(\n                (stride * offset_i), (1,) * i + (di,) + (1,) * (ndim - 1 - i)\n            )\n        )\n        stride *= di\n    if ndim > 1:\n        new_idx = index * stride + copy_to(sum(offsets), ctx)\n    else:\n        new_idx = index\n    src = tf.reshape(src, (-1,))\n    new_idx = tf.reshape(new_idx, (-1, 1))\n    rst = tf.reshape(\n        tf.scatter_nd(new_idx, src, (stride * n_rows,)), (n_rows, *shp[1:])\n    )\n    return rst\n\n\ndef _gather_nd(index, src):\n    shp = index.shape\n    ctx = context(src)\n    ndim = index.ndim\n    offsets = []\n    stride = 1\n    for i in reversed(range(1, ndim)):\n        di = shp[i]\n        offset_i = tf.range(di, dtype=index.dtype)\n        offsets.append(\n            tf.reshape(\n                (stride * offset_i), (1,) * i + (di,) + (1,) * (ndim - 1 - i)\n            )\n        )\n        stride *= di\n    if ndim > 1:\n        new_idx = index * stride + copy_to(sum(offsets), ctx)\n    else:\n        new_idx = index\n    src = tf.reshape(src, (-1,))\n    new_idx = tf.reshape(new_idx, (-1))\n    rst = 
tf.reshape(tf.gather(src, new_idx), shp)\n    return rst\n\n\ndef _reduce_grad(grad, shape):\n    \"\"\"Reduce gradient on the broadcast dimension\n    If there is broadcast in forward pass, gradients need to be reduced on\n    broadcast dimension. This function checks the input tensor shape and\n    gradient shape and perform the reduction.\n    Parameters\n    ----------\n    grad: Tensor\n        Gradient tensor\n    shape: tuple\n        Shape of input tensor\n    Returns\n    -------\n    Tensor\n    \"\"\"\n    grad_shape = grad.shape[1:]\n    in_shape = shape[1:]\n    if in_shape == grad_shape:\n        # no need to reduce\n        return grad\n    num_to_squeeze = len(grad_shape) - len(in_shape)\n    # pad inshape\n    in_shape = (1,) * num_to_squeeze + in_shape\n    reduce_idx = np.asarray(\n        np.nonzero(np.asarray(grad_shape) - np.asarray(in_shape))\n    )\n    reduce_idx += 1  # skip batch dim\n    reduce_idx_tensor = tf.constant(\n        tuple(reduce_idx.flatten().tolist()), dtype=tf.int32\n    )\n    grad = tf.reduce_sum(grad, axis=reduce_idx_tensor, keepdims=True)\n    return tf.reshape(grad, shape)\n\n\ndef _need_reduce_last_dim(ufeat, efeat):\n    \"\"\"Indicates whether to reduce the last dimension on edges\n    in the backward pass of spmm,\n    if so, use dot instead of mul.\"\"\"\n    ushp = ufeat.shape\n    eshp = efeat.shape\n    return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1\n\n\ndef _muldiv(op, x):\n    return 1.0 / x if op == \"div\" else x\n\n\ndef _addsub(op, x):\n    return -x if op == \"sub\" else x\n\n\ndef _expand(x, shape):\n    return tf.broadcast_to(x, (x.shape[0], *shape))\n\n\ndef gspmm_real(gidx, op, reduce_op, X, Y):\n    out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)\n\n    def grad(dZ):\n        dZ = tensor(dZ)\n        if op != \"copy_rhs\":\n            g_rev = gidx.reverse()\n            if reduce_op == \"sum\":\n                if op in [\"mul\", \"div\"]:\n                    dX = 
_gspmm(g_rev, \"mul\", \"sum\", dZ, _muldiv(op, Y))[0]\n                elif op in [\"add\", \"sub\"]:\n                    dX = _gspmm(g_rev, \"copy_lhs\", \"sum\", dZ, Y)[0]\n                elif op == \"copy_lhs\":\n                    dX = _gspmm(g_rev, \"copy_lhs\", \"sum\", dZ, None)[0]\n            else:\n                if op in [\"mul\", \"div\"]:\n                    dX = _scatter_nd(\n                        argX,\n                        _muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:])))\n                        * dZ,\n                        X.shape[0],\n                    )\n                elif op in [\"add\", \"sub\", \"copy_lhs\"]:\n                    dX = _scatter_nd(argX, dZ, X.shape[0])\n            dX = _reduce_grad(dX, X.shape)\n        else:\n            dX = tf.zeros_like(X)\n        if op != \"copy_lhs\":\n            if reduce_op == \"sum\":\n                if op == \"mul\" and _need_reduce_last_dim(X, Y):\n                    dY = _gsddmm(gidx, \"dot\", X, dZ)\n                elif op in [\"mul\", \"div\"]:\n                    dY = _gsddmm(gidx, \"mul\", X, dZ)\n                    if op == \"div\":\n                        dY = -dY / (Y**2)\n                elif op in [\"add\", \"sub\", \"copy_rhs\"]:\n                    dY = _gsddmm(gidx, \"copy_rhs\", X, _addsub(op, dZ))\n            else:\n                out_shp = (Y.shape[0],) + dZ.shape[1:]\n                if op in [\"mul\", \"div\"]:\n                    dY = _scatter_nd(\n                        argY,\n                        _gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ,\n                        Y.shape[0],\n                    )\n                    if op == \"div\":\n                        dY = -dY / (Y**2)\n                elif op in [\"add\", \"sub\", \"copy_rhs\"]:\n                    dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])\n            dY = _reduce_grad(dY, Y.shape)\n        else:\n            dY = tf.zeros_like(Y)\n        return dX, dY\n\n    
return out, grad\n\n\ndef gspmm(gidx, op, reduce_op, X, Y):\n    @tf.custom_gradient\n    def _lambda(X, Y):\n        return gspmm_real(gidx, op, reduce_op, X, Y)\n\n    if X is None:\n        X = tf.zeros(())\n    if Y is None:\n        Y = tf.zeros(())\n    return _lambda(X, Y)\n\n\ndef gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):\n    out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)\n\n    def grad(dZ):\n        if op != \"copy_rhs\":\n            if lhs_target in [\"u\", \"v\"]:\n                _gidx = gidx if lhs_target == \"v\" else gidx.reverse()\n                if op in [\"add\", \"sub\", \"copy_lhs\"]:\n                    dX = _gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ)[0]\n                else:  # mul, div, dot\n                    if rhs_target == lhs_target:\n                        dX = _gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ)[\n                            0\n                        ] * _muldiv(op, Y)\n                    elif rhs_target == \"e\":\n                        dX = _gspmm(\n                            _gidx, \"copy_rhs\", \"sum\", None, dZ * _muldiv(op, Y)\n                        )[0]\n                    else:  # rhs_target = !lhs_target\n                        dX = _gspmm(_gidx, \"mul\", \"sum\", _muldiv(op, Y), dZ)[0]\n            else:  # lhs_target == 'e'\n                if op in [\"add\", \"sub\", \"copy_lhs\"]:\n                    dX = dZ\n                else:  # mul, div, dot\n                    dX = _gsddmm(\n                        gidx, \"mul\", dZ, _muldiv(op, Y), \"e\", rhs_target\n                    )\n            dX = _reduce_grad(dX, X.shape)\n        else:\n            dX = tf.zeros_like(X)\n        if op != \"copy_lhs\":\n            if rhs_target in [\"u\", \"v\"]:\n                _gidx = gidx if rhs_target == \"v\" else gidx.reverse()\n                if op in [\"add\", \"sub\", \"copy_rhs\"]:\n                    dY = _gspmm(\n                        _gidx, \"copy_rhs\", \"sum\", None, 
_addsub(op, dZ)\n                    )[0]\n                else:  # mul, div, dot\n                    if lhs_target == rhs_target:\n                        dY = _gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ)[0] * X\n                    elif lhs_target == \"e\":\n                        dY = _gspmm(_gidx, \"copy_rhs\", \"sum\", None, dZ * X)[0]\n                    else:  # rhs_target = !lhs_target\n                        dY = _gspmm(_gidx, \"mul\", \"sum\", X, dZ)[0]\n                    if op == \"div\":\n                        dY = -dY / (Y**2)\n            else:\n                if op in [\"add\", \"sub\", \"copy_rhs\"]:\n                    dY = _addsub(op, dZ)\n                else:  # mul, div, dot\n                    dY = _gsddmm(gidx, \"mul\", dZ, X, \"e\", lhs_target)\n                    if op == \"div\":\n                        dY = -dY / (Y**2)\n            dY = _reduce_grad(dY, Y.shape)\n        else:\n            dY = tf.zeros_like(Y)\n        return dX, dY\n\n    return out, grad\n\n\ndef gsddmm(gidx, op, X, Y, lhs_target=\"u\", rhs_target=\"v\"):\n    @tf.custom_gradient\n    def _lambda(X, Y):\n        return gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target)\n\n    if X is None:\n        X = tf.zeros(())\n    if Y is None:\n        Y = tf.zeros(())\n    return _lambda(X, Y)\n\n\ndef edge_softmax_real(gidx, score, eids=ALL, norm_by=\"dst\"):\n    if not is_all(eids):\n        gidx = gidx.edge_subgraph([eids], True).graph\n    if norm_by == \"src\":\n        gidx = gidx.reverse()\n    score_max = _gspmm(gidx, \"copy_rhs\", \"max\", None, score)[0]\n    score = tf.math.exp(_gsddmm(gidx, \"sub\", score, score_max, \"e\", \"v\"))\n    score_sum = _gspmm(gidx, \"copy_rhs\", \"sum\", None, score)[0]\n    out = _gsddmm(gidx, \"div\", score, score_sum, \"e\", \"v\")\n\n    def edge_softmax_backward(grad_out):\n        sds = out * grad_out\n        accum = gspmm(gidx, \"copy_rhs\", \"sum\", None, sds)\n        grad_score = sds - gsddmm(gidx, \"mul\", 
out, accum, \"e\", \"v\")\n        return grad_score\n\n    return out, edge_softmax_backward\n\n\ndef edge_softmax(gidx, logits, eids=ALL, norm_by=\"dst\"):\n    @tf.custom_gradient\n    def _lambda(logits):\n        return edge_softmax_real(gidx, logits, eids, norm_by)\n\n    return _lambda(logits)\n\n\ndef segment_reduce_real(op, x, offsets):\n    y, arg = _segment_reduce(op, x, offsets)\n\n    def segment_reduce_backward(dy):\n        m = x.shape[0]\n        if op == \"sum\":\n            offsets_np = asnumpy(offsets[1:])\n            indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype)\n            np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))\n            indices_np = np.cumsum(indices_np, -1)[:-1]\n            indices = zerocopy_from_numpy(indices_np)\n            dx = tf.gather(dy, indices)\n        else:\n            dx = _bwd_segment_cmp(dy, arg, m)\n        return dx\n\n    return y, segment_reduce_backward\n\n\ndef segment_reduce(op, x, offsets):\n    @tf.custom_gradient\n    def _lambda(x):\n        return segment_reduce_real(op, x, offsets)\n\n    return _lambda(x)\n\n\ndef scatter_add_real(x, idx, m):\n    y = _scatter_add(x, idx, m)\n\n    def scatter_add_backward(dy):\n        return tf.gather(dy, idx)\n\n    return y, scatter_add_backward\n\n\ndef scatter_add(x, idx, m):\n    @tf.custom_gradient\n    def _lambda(x):\n        return scatter_add_real(x, idx, m)\n\n    return _lambda(x)\n\n\ndef csrmm_real(gidxA, A_weights, gidxB, B_weights, num_vtypes):\n    gidxC, C_weights = _csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes)\n    nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(\n        0, False, \"csr\"\n    )\n\n    def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):\n        # Only the last argument is meaningful.\n        dgidxA, dA_weights = _csrmm(\n            gidxC,\n            dC_weights,\n            gidxB.reverse(),\n            B_weights,\n            
gidxA.number_of_ntypes(),\n        )\n        dgidxB, dB_weights = _csrmm(\n            gidxA.reverse(),\n            A_weights,\n            gidxC,\n            dC_weights,\n            gidxB.number_of_ntypes(),\n        )\n        dA_weights = _csrmask(dgidxA, dA_weights, gidxA)\n        dB_weights = _csrmask(dgidxB, dB_weights, gidxB)\n        return dA_weights, dB_weights\n\n    return (\n        tf.constant(nrows),\n        tf.constant(ncols),\n        C_indptr,\n        C_indices,\n        C_eids,\n        C_weights,\n    ), grad\n\n\ndef csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):\n    @tf.custom_gradient\n    def _lambda(A_weights, B_weights):\n        return csrmm_real(gidxA, A_weights, gidxB, B_weights, num_vtypes)\n\n    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = _lambda(\n        A_weights, B_weights\n    )\n    gidxC = create_unitgraph_from_csr(\n        num_vtypes,\n        nrows.numpy(),\n        ncols.numpy(),\n        C_indptr,\n        C_indices,\n        C_eids,\n        [\"coo\", \"csr\", \"csc\"],\n    )\n    return gidxC, C_weights\n\n\ndef csrsum_real(gidxs, weights):\n    gidxC, C_weights = _csrsum(gidxs, weights)\n    nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(\n        0, False, \"csr\"\n    )\n\n    def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):\n        # Only the last argument is meaningful.\n        return tuple(_csrmask(gidxC, dC_weights, gidx) for gidx in gidxs)\n\n    return (\n        tf.constant(nrows),\n        tf.constant(ncols),\n        C_indptr,\n        C_indices,\n        C_eids,\n        C_weights,\n    ), grad\n\n\ndef csrsum(gidxs, weights):\n    @tf.custom_gradient\n    def _lambda(*weights):\n        return csrsum_real(gidxs, weights)\n\n    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = _lambda(*weights)\n    num_vtypes = gidxs[0].number_of_ntypes()\n    gidxC = create_unitgraph_from_csr(\n        num_vtypes,\n        
nrows.numpy(),\n        ncols.numpy(),\n        C_indptr,\n        C_indices,\n        C_eids,\n        [\"coo\", \"csr\", \"csc\"],\n    )\n    return gidxC, C_weights\n\n\ndef csrmask_real(gidxA, A_weights, gidxB):\n    B_weights = _csrmask(gidxA, A_weights, gidxB)\n\n    def grad(dB_weights):\n        return _csrmask(gidxB, dB_weights, gidxA)\n\n    return B_weights, grad\n\n\ndef csrmask(gidxA, A_weights, gidxB):\n    @tf.custom_gradient\n    def _lambda(A_weights):\n        return csrmask_real(gidxA, A_weights, gidxB)\n\n    return _lambda(A_weights)\n"
  },
  {
    "path": "python/dgl/backend/tensorflow/sparse_optim.py",
    "content": "\"\"\"Sparse optimizer is not supported for tensorflow\"\"\"\n"
  },
  {
    "path": "python/dgl/backend/tensorflow/tensor.py",
    "content": "\"\"\"Tensorflow backend implementation\"\"\"\nfrom __future__ import absolute_import\n\nimport builtins\nimport numbers\n\nimport numpy as np\nimport tensorflow as tf\n\nfrom ... import ndarray as nd\nfrom ...function.base import TargetCode\nfrom ...utils import version\n\nif version.parse(tf.__version__) < version.parse(\"2.3.0\"):\n    raise RuntimeError(\n        \"DGL requires TensorFlow>=2.3.0 for the official DLPack support.\"\n    )\n\n\ndef zerocopy_to_dlpack(data):\n    return tf.experimental.dlpack.to_dlpack(data)\n\n\ndef zerocopy_from_dlpack(dlpack_tensor):\n    # TODO(Jinjing): Tensorflow requires memory to be 64-bytes aligned. We check the\n    #   alignment and make a copy if needed. The functionality is better in TF's main repo.\n    aligned = nd.from_dlpack(dlpack_tensor).to_dlpack(64)\n    return tf.experimental.dlpack.from_dlpack(aligned)\n\n\ndef data_type_dict():\n    return {\n        \"bfloat16\": tf.bfloat16,\n        \"float16\": tf.float16,\n        \"float32\": tf.float32,\n        \"float64\": tf.float64,\n        \"uint8\": tf.uint8,\n        \"int8\": tf.int8,\n        \"int16\": tf.int16,\n        \"int32\": tf.int32,\n        \"int64\": tf.int64,\n        \"bool\": tf.bool,\n    }\n\n\ndef cpu():\n    return \"/cpu:0\"\n\n\ndef tensor(data, dtype=None):\n    if isinstance(data, tf.Tensor):\n        if dtype is None or data.dtype == dtype:\n            return data\n        else:\n            return tf.cast(data, dtype=dtype)\n    else:\n        if isinstance(data, numbers.Number):\n            data = [data]\n        return tf.convert_to_tensor(data, dtype=dtype)\n\n\ndef initialize_context():\n    tf.zeros(1)\n\n\ndef as_scalar(data):\n    data = data.numpy()\n    return data if np.isscalar(data) else data.item()\n\n\ndef get_preferred_sparse_format():\n    \"\"\"Get the preferred sparse matrix format supported by the backend.\n\n    Different backends have their preferred backend. 
This info is useful when\n    constructing a sparse matrix.\n    \"\"\"\n    return \"coo\"\n\n\ndef sparse_matrix(data, index, shape, force_format=False):\n    fmt = index[0]\n    if fmt != \"coo\":\n        raise TypeError(\n            \"Tensorflow backend only supports COO format. But got %s.\" % fmt\n        )\n    # tf.SparseTensor only supports int64 indexing,\n    # therefore manually casting to int64 when input in int32\n    spmat = tf.SparseTensor(\n        indices=tf.cast(tf.transpose(index[1], (1, 0)), tf.int64),\n        values=data,\n        dense_shape=shape,\n    )\n    return spmat, None\n\n\ndef sparse_matrix_indices(spmat):\n    return (\"coo\", spmat.indices)\n\n\ndef is_tensor(obj):\n    return isinstance(obj, tf.Tensor)\n\n\ndef shape(input):\n    return input.shape\n\n\ndef dtype(input):\n    return input.dtype\n\n\ndef ndim(input):\n    return input.ndim\n\n\ndef context(input):\n    spec = tf.DeviceSpec.from_string(input.device)\n    return \"/{}:{}\".format(spec.device_type.lower(), spec.device_index)\n\n\ndef device_type(ctx):\n    return tf.DeviceSpec.from_string(ctx).device_type.lower()\n\n\ndef device_id(ctx):\n    return tf.DeviceSpec.from_string(ctx).device_index\n\n\ndef to_backend_ctx(dglctx):\n    dev_type = dglctx.device_type\n    if dev_type == 1:\n        return \"/cpu:0\"\n    elif dev_type == 2:\n        return \"/gpu:%d\" % (dglctx.device_id)\n    else:\n        raise ValueError(\"Unsupported DGL device context:\", dglctx)\n\n\ndef astype(input, ty):\n    with tf.device(input.device):\n        return tf.cast(input, dtype=ty)\n\n\ndef asnumpy(input):\n    if isinstance(input, tf.SparseTensor):\n        # tf.sparse.to_dense assume sorted indices, need to turn off validate_indices in our cases\n        return tf.sparse.to_dense(input, validate_indices=False).numpy()\n    else:\n        return input.numpy()\n\n\ndef copy_to(input, ctx, **kwargs):\n    with tf.device(ctx):\n        new_tensor = tf.identity(input)\n    return 
new_tensor\n\n\ndef is_pinned(input):\n    return False  # not sure how to do this\n\n\ndef sum(input, dim, keepdims=False):\n    if input.dtype == tf.bool:\n        input = tf.cast(input, tf.int32)\n    return tf.reduce_sum(input, axis=dim, keepdims=keepdims)\n\n\ndef floor_div(in1, in2):\n    return astype(in1 / in2, dtype(in1))\n\n\ndef reduce_sum(input):\n    if input.dtype == tf.bool:\n        input = tf.cast(input, tf.int32)\n    return tf.reduce_sum(input)\n\n\ndef cumsum(input, dim):\n    if input.dtype == tf.bool:\n        input = tf.cast(input, tf.int32)\n    return tf.cumsum(input, axis=dim)\n\n\ndef mean(input, dim):\n    return tf.reduce_mean(input, axis=dim)\n\n\ndef reduce_mean(input):\n    return tf.reduce_mean(input)\n\n\ndef max(input, dim):\n    return tf.reduce_max(input, axis=dim)\n\n\ndef reduce_max(input):\n    return tf.reduce_max(input)\n\n\ndef min(input, dim):\n    return tf.reduce_min(input, axis=dim)\n\n\ndef reduce_min(input):\n    return tf.reduce_min(input)\n\n\ndef argsort(input, dim, descending):\n    if descending:\n        return tf.cast(\n            tf.argsort(input, axis=dim, direction=\"DESCENDING\"), dtype=tf.int64\n        )\n    else:\n        return tf.cast(\n            tf.argsort(input, axis=dim, direction=\"ASCENDING\"), dtype=tf.int64\n        )\n\n\ndef topk(input, k, dim, descending=True):\n    if not descending:\n        input = -input\n    shape = np.arange(input.ndim)\n    shape[dim], shape[-1] = shape[-1], shape[dim]\n    out1 = tf.transpose(input, perm=shape)\n    out2 = tf.math.top_k(out1, k=k, sorted=True)\n    out = tf.transpose(out2[0], shape)\n    if not descending:\n        out = -out\n    return out\n\n\ndef argtopk(input, k, dim, descending=True):\n    if not descending:\n        input = -input\n    shape = np.arange(input.ndim)\n    shape[dim], shape[-1] = shape[-1], shape[dim]\n    out1 = tf.transpose(input, perm=shape)\n    out2 = tf.math.top_k(out1, k=k, sorted=True)\n    out = tf.transpose(out2[1], 
shape)\n    if not descending:\n        out = -out\n    return out\n\n\ndef exp(input):\n    return tf.exp(input)\n\n\ndef inverse(input):\n    return tf.linalg.inv(input)\n\n\ndef sqrt(input):\n    return tf.sqrt(input)\n\n\ndef softmax(input, dim=-1):\n    return tf.math.softmax(input, axis=dim)\n\n\ndef cat(seq, dim):\n    return tf.concat(seq, axis=dim)\n\n\ndef stack(seq, dim):\n    return tf.stack(seq, axis=dim)\n\n\ndef split(input, sizes_or_sections, dim):\n    return [\n        copy_to(_, input.device)\n        for _ in tf.split(input, sizes_or_sections, axis=dim)\n    ]\n\n\ndef repeat(input, repeats, dim):\n    return tf.repeat(input, repeats, dim)\n\n\ndef gather_row(data, row_index):\n    return tf.gather(data, row_index)\n\n\ndef slice_axis(data, axis, begin, end):\n    # assert axis == 0\n    # tf doesn't behave well with negative\n    s = [slice(None) for i in range(data.ndim)]\n    if end == 0:\n        end = data.shape[axis]\n    s[axis] = slice(begin, end, None)\n    return data[tuple(s)]\n\n\ndef take(data, indices, dim):\n    return tf.gather_nd(data, indices, dim)\n\n\ndef narrow_row(x, start, stop):\n    return x[start:stop]\n\n\ndef scatter_row(data, row_index, value):\n    row_index = tf.expand_dims(row_index, 1)\n    # XXX(minjie): Normally, the copy_to here is unnecessary. 
However, TF has this\n    #   notorious legacy issue that int32 type data is always on CPU, which will\n    #   crash the program since DGL requires feature data to be on the same device\n    #   as graph structure.\n    return copy_to(\n        tf.tensor_scatter_nd_update(data, row_index, value), data.device\n    )\n\n\ndef index_add_inplace(data, row_idx, value):\n    raise NotImplementedError(\"Tensorflow doesn't support inplace index_add\")\n\n\ndef scatter_row_inplace(data, row_index, value):\n    raise NotImplementedError(\"Tensorflow doesn't support inplace update\")\n\n\ndef squeeze(input, dim):\n    return tf.squeeze(input, axis=dim)\n\n\ndef unsqueeze(input, dim):\n    return tf.expand_dims(input, axis=dim)\n\n\ndef reshape(input, shape):\n    return tf.reshape(input, shape)\n\n\ndef swapaxes(input, axis1, axis2):\n    ndim = input.ndim\n    t = list(range(ndim))\n    t[axis1], t[axis2] = axis2 % ndim, axis1 % ndim\n    return tf.transpose(input, perm=t)\n\n\ndef empty(shape, dtype, ctx):\n    # tf doesn't have tf.empty(), use zeros() as a workaround\n    return zeros(shape, dtype, ctx)\n\n\ndef zeros(shape, dtype, ctx):\n    with tf.device(ctx):\n        t = tf.zeros(shape, dtype=dtype)\n    return t\n\n\ndef zeros_like(input):\n    return tf.zeros_like(input)\n\n\ndef ones(shape, dtype, ctx):\n    with tf.device(ctx):\n        t = tf.ones(shape, dtype=dtype)\n    return t\n\n\ndef uniform(shape, dtype, ctx, low, high):\n    with tf.device(ctx):\n        t = tf.random.uniform(shape, dtype=dtype, minval=low, maxval=high)\n    return t\n\n\ndef randint(shape, dtype, ctx, low, high):\n    with tf.device(ctx):\n        t = tf.random.uniform(shape, dtype=dtype, minval=low, maxval=high)\n    return t\n\n\ndef pad_packed_tensor(input, lengths, value, l_min=None):\n    old_shape = input.shape\n    if isinstance(lengths, tf.Tensor):\n        max_len = as_scalar(tf.reduce_max(lengths))\n    else:\n        max_len = builtins.max(lengths)\n\n    if l_min is not 
None:\n        max_len = builtins.max(max_len, l_min)\n\n    batch_size = len(lengths)\n    ndim = input.ndim\n    tensor_list = []\n    cum_row = 0\n    pad_nparray = np.zeros((ndim, 2), dtype=np.int32)\n    for l in lengths:\n        t = input[cum_row : cum_row + l]\n        pad_nparray[0, 1] = max_len - l\n        t = tf.pad(\n            t, tf.constant(pad_nparray), mode=\"CONSTANT\", constant_values=value\n        )\n        tensor_list.append(t)\n        cum_row += l\n    return tf.stack(tensor_list, axis=0)\n\n\ndef pack_padded_tensor(input, lengths):\n    out_list = []\n    for i, l in enumerate(lengths):\n        t = input[i]\n        out = t[:l]\n        out_list.append(out)\n    return tf.concat(out_list, axis=0)\n\n\ndef boolean_mask(input, mask):\n    return tf.boolean_mask(input, mask)\n\n\ndef equal(x, y):\n    return x == y\n\n\ndef allclose(x, y, rtol=1e-4, atol=1e-4):\n    return np.allclose(\n        tf.convert_to_tensor(x).numpy(),\n        tf.convert_to_tensor(y).numpy(),\n        rtol=rtol,\n        atol=atol,\n    )\n\n\ndef logical_not(input):\n    return ~input\n\n\ndef logical_and(input1, input2):\n    return tf.math.logical_and(input1, input2)\n\n\ndef clone(input):\n    # TF tensor is always immutable so returning the input is safe.\n    return input\n\n\ndef clamp(data, min_val, max_val):\n    return tf.clip_by_value(data, min_val, max_val)\n\n\ndef replace_inf_with_zero(x):\n    return tf.where(tf.abs(x) == np.inf, 0, x)\n\n\ndef count_nonzero(input):\n    return int(tf.math.count_nonzero(input))\n\n\ndef unique(input, return_inverse=False, return_counts=False):\n    if return_inverse and return_counts:\n        return tf.unique_with_counts(input)\n    elif return_counts:\n        result = tf.unique_with_counts(input)\n        return result.y, result.count\n    elif return_inverse:\n        return tf.unique(input)\n    else:\n        return tf.unique(input).y\n\n\ndef full_1d(length, fill_value, dtype, ctx):\n    with tf.device(ctx):\n 
       t = tf.fill([length], value=fill_value)\n        t = tf.cast(t, dtype=dtype)\n    return t\n\n\ndef nonzero_1d(input):\n    nonzero_bool = tf.cast(input, tf.bool)\n    return tf.reshape(tf.where(nonzero_bool), (-1,))\n\n\ndef sort_1d(input):\n    return tf.sort(input), tf.cast(tf.argsort(input), dtype=tf.int64)\n\n\ndef arange(start, stop, dtype=tf.int64, ctx=None):\n    if not ctx:\n        ctx = \"/cpu:0\"\n    with tf.device(ctx):\n        t = tf.range(start, stop, dtype=dtype)\n    return t\n\n\ndef rand_shuffle(arr):\n    return tf.random.shuffle(arr)\n\n\ndef zerocopy_to_numpy(input):\n    return np.asarray(memoryview(input))\n\n\ndef zerocopy_from_numpy(np_array):\n    # NOTE: not zerocopy\n    # This assumes tensor should be on cpu\n    with tf.device(\"/cpu:0\"):\n        t = tf.convert_to_tensor(np_array)\n    return t\n\n\ndef zerocopy_to_dgl_ndarray(data):\n    if device_type(data.device) == \"gpu\" and data.dtype in (tf.int32, tf.int64):\n        # NOTE: TF doesn't keep signed tensors on GPU due to legacy issues with\n        #   shape inference. 
Convert it to unsigned and cast it back afterwards.\n        if data.dtype == tf.int32:\n            data = tf.cast(data, tf.uint32)\n        elif data.dtype == tf.int64:\n            data = tf.cast(data, tf.uint64)\n        return nd.cast_to_signed(nd.from_dlpack(zerocopy_to_dlpack(data)))\n    else:\n        return nd.from_dlpack(zerocopy_to_dlpack(data))\n\n\ndef zerocopy_to_dgl_ndarray_for_write(input):\n    return zerocopy_to_dgl_ndarray(input)\n\n\ndef zerocopy_from_dgl_ndarray(input):\n    return zerocopy_from_dlpack(input.to_dlpack())\n\n\ndef sync():\n    context = context().context()\n    context.async_wait()\n\n\nclass GradContext:\n    def __init__(self):\n        self.tensor_for_grad = []\n        self.grad_list = []\n        self.tape = None\n\n    def set_tape(self, tape):\n        self.tape = tape\n\n    def add_tensor(self, x):\n        idx_pop = []\n        for idx, ele in enumerate(self.tensor_for_grad):\n            if ele._id == x._id:\n                idx_pop.append(idx)\n        if len(idx_pop) > 0:\n            self.tensor_for_grad.pop(idx_pop[0])\n        if self.tape is not None:\n            self.tape.watch(x)\n        self.tensor_for_grad.append(x)\n\n    def backward(self, x, head_gradient=None):\n        if head_gradient is not None:\n            x = x * head_gradient\n        self.grad_list = self.tape.gradient(x, self.tensor_for_grad)\n\n    def is_no_grad(self, x):\n        idx_pop = []\n        for idx, ele in enumerate(self.tensor_for_grad):\n            if ele._id == x._id:\n                idx_pop.append(idx)\n        if len(idx_pop) == 0:\n            return True\n        else:\n            return self.grad_list[idx_pop[0]] is None\n\n    def grad(self, x):\n        idx_pop = []\n        for idx, ele in enumerate(self.tensor_for_grad):\n            if ele._id == x._id:\n                idx_pop.append(idx)\n        assert len(idx_pop) == 1\n        t = self.grad_list[idx_pop[0]]\n        return tf.convert_to_tensor(t)\n\n\ncgrad 
= GradContext()\n\n\ndef get_cgrad():\n    return cgrad\n\n\nclass record_grad:\n    def __init__(self):\n        self.tape = tf.GradientTape()\n\n    def __enter__(self):\n        cgrad.set_tape(self.tape)\n        self.tape.__enter__()\n        for x in cgrad.tensor_for_grad:\n            self.tape.watch(x)\n\n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        # pass\n        self.tape.__exit__(exc_type, exc_value, exc_traceback)\n        cgrad.tape = None\n\n\ndef attach_grad(x):\n    cgrad.add_tensor(x)\n    return x\n\n\ndef backward(x, head_gradient=None):\n    cgrad.backward(x, head_gradient)\n\n\ndef grad(x):\n    return cgrad.grad(x)\n\n\ndef is_no_grad(x):\n    return cgrad.is_no_grad(x)\n\n\ndef is_recording():\n    raise NotImplementedError(\"Tensorflow doesn't support is_recording\")\n\n\nno_grad = None\n\ninitialize_context()\n"
  },
  {
    "path": "python/dgl/base.py",
    "content": "\"\"\"Module for base types and utilities.\"\"\"\nfrom __future__ import absolute_import\n\nimport warnings\n\nfrom ._ffi.base import DGLError  # pylint: disable=unused-import\nfrom ._ffi.function import _init_internal_api\n\n# A special symbol for selecting all nodes or edges.\nALL = \"__ALL__\"\n# An alias for [:]\nSLICE_FULL = slice(None, None, None)\n# Reserved column names for storing parent node/edge types and IDs in flattened heterographs\nNTYPE = \"_TYPE\"\nNID = \"_ID\"\nETYPE = \"_TYPE\"\nEID = \"_ID\"\n\n_INTERNAL_COLUMNS = {NTYPE, NID, ETYPE, EID}\n\n\ndef is_internal_column(name):\n    \"\"\"Return true if the column name is reversed by DGL.\"\"\"\n    return name in _INTERNAL_COLUMNS\n\n\ndef is_all(arg):\n    \"\"\"Return true if the argument is a special symbol for all nodes or edges.\"\"\"\n    return isinstance(arg, str) and arg == ALL\n\n\n# pylint: disable=invalid-name\n_default_formatwarning = warnings.formatwarning\n\n\nclass DGLWarning(UserWarning):\n    \"\"\"DGL Warning class.\"\"\"\n\n\n# pylint: disable=unused-argument\ndef dgl_warning_format(message, category, filename, lineno, line=None):\n    \"\"\"Format DGL warnings.\"\"\"\n    if isinstance(category, DGLWarning):\n        return \"DGL Warning: {}\\n\".format(message)\n    else:\n        return _default_formatwarning(\n            message, category, filename, lineno, line=None\n        )\n\n\ndef dgl_warning(message, category=DGLWarning, stacklevel=2):\n    \"\"\"DGL warning wrapper that defaults to ``DGLWarning`` instead of ``UserWarning`` category.\"\"\"\n    return warnings.warn(message, category=category, stacklevel=stacklevel)\n\n\nwarnings.formatwarning = dgl_warning_format\n\n_init_internal_api()\n"
  },
  {
    "path": "python/dgl/batch.py",
    "content": "\"\"\"Utilities for batching/unbatching graphs.\"\"\"\nfrom collections.abc import Mapping\n\nfrom . import backend as F, convert, utils\nfrom .base import ALL, DGLError, EID, is_all, NID\nfrom .heterograph import DGLGraph\nfrom .heterograph_index import disjoint_union, slice_gidx\n\n\n__all__ = [\"batch\", \"unbatch\", \"slice_batch\"]\n\n\ndef batch(graphs, ndata=ALL, edata=ALL):\n    r\"\"\"Batch a collection of :class:`DGLGraph` s into one graph for more efficient\n    graph computation.\n\n    Each input graph becomes one disjoint component of the batched graph. The nodes\n    and edges are relabeled to be disjoint segments:\n\n    =================  =========  =================  ===  =========\n                       graphs[0]  graphs[1]          ...  graphs[k]\n    =================  =========  =================  ===  =========\n    Original node ID   0 ~ N_0    0 ~ N_1            ...  0 ~ N_k\n    New node ID        0 ~ N_0    N_0 ~ N_0+N_1      ...  \\sum_{i=0}^{k-1} N_i ~\n                                                          \\sum_{i=0}^k N_i\n    =================  =========  =================  ===  =========\n\n    Because of this, many of the computations on a batched graph are the same as if\n    performed on each graph individually, but become much more efficient\n    since they can be parallelized easily. This makes ``dgl.batch`` very useful\n    for tasks dealing with many graph samples such as graph classification tasks.\n\n    For heterograph inputs, they must share the same set of relations (i.e., node types\n    and edge types) and the function will perform batching on each relation one by one.\n    Thus, the result is also a heterograph and has the same set of relations as the inputs.\n\n    The numbers of nodes and edges of the input graphs are accessible via the\n    :func:`DGLGraph.batch_num_nodes` and :func:`DGLGraph.batch_num_edges` attributes\n    of the resulting graph. 
For homogeneous graphs, they are 1D integer tensors,\n    with each element being the number of nodes/edges of the corresponding input graph. For\n    heterographs, they are dictionaries of 1D integer tensors, with node\n    type or edge type as the keys.\n\n    The function supports batching batched graphs. The batch size of the result\n    graph is the sum of the batch sizes of all the input graphs.\n\n    By default, node/edge features are batched by concatenating the feature tensors\n    of all input graphs. This thus requires features of the same name to have\n    the same data type and feature size. One can pass ``None`` to the ``ndata``\n    or ``edata`` argument to prevent feature batching, or pass a list of strings\n    to specify which features to batch.\n\n    To unbatch the graph back to a list, use the :func:`dgl.unbatch` function.\n\n    Parameters\n    ----------\n    graphs : list[DGLGraph]\n        Input graphs.\n    ndata : list[str], None, optional\n        Node features to batch.\n    edata : list[str], None, optional\n        Edge features to batch.\n\n    Returns\n    -------\n    DGLGraph\n        Batched graph.\n\n    Examples\n    --------\n\n    Batch homogeneous graphs\n\n    >>> import dgl\n    >>> import torch as th\n    >>> # 4 nodes, 3 edges\n    >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))\n    >>> # 3 nodes, 4 edges\n    >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))\n    >>> bg = dgl.batch([g1, g2])\n    >>> bg\n    Graph(num_nodes=7, num_edges=7,\n          ndata_schemes={}\n          edata_schemes={})\n    >>> bg.batch_size\n    2\n    >>> bg.batch_num_nodes()\n    tensor([4, 3])\n    >>> bg.batch_num_edges()\n    tensor([3, 4])\n    >>> bg.edges()\n    (tensor([0, 1, 2, 4, 4, 4, 5]), tensor([1, 2, 3, 4, 5, 6, 4]))\n\n    Batch batched graphs\n\n    >>> bbg = dgl.batch([bg, bg])\n    >>> bbg.batch_size\n    4\n    >>> bbg.batch_num_nodes()\n    tensor([4, 3, 4, 3])\n    >>> 
bbg.batch_num_edges()\n    tensor([3, 4, 3, 4])\n\n    Batch graphs with feature data\n\n    >>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)\n    >>> g1.edata['w'] = th.ones(g1.num_edges(), 2)\n    >>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)\n    >>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)\n    >>> bg = dgl.batch([g1, g2])\n    >>> bg.ndata['x']\n    tensor([[0, 0, 0],\n            [0, 0, 0],\n            [0, 0, 0],\n            [0, 0, 0],\n            [1, 1, 1],\n            [1, 1, 1],\n            [1, 1, 1]])\n    >>> bg.edata['w']\n    tensor([[1, 1],\n            [1, 1],\n            [1, 1],\n            [0, 0],\n            [0, 0],\n            [0, 0],\n            [0, 0]])\n\n    Batch heterographs\n\n    >>> hg1 = dgl.heterograph({\n    ...     ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})\n    >>> hg2 = dgl.heterograph({\n    ...     ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})\n    >>> bhg = dgl.batch([hg1, hg2])\n    >>> bhg\n    Graph(num_nodes={'user': 3, 'game': 4},\n          num_edges={('user', 'plays', 'game'): 5},\n          metagraph=[('drug', 'game')])\n    >>> bhg.batch_size\n    2\n    >>> bhg.batch_num_nodes()\n    {'user' : tensor([2, 1]), 'game' : tensor([1, 3])}\n    >>> bhg.batch_num_edges()\n    {('user', 'plays', 'game') : tensor([2, 3])}\n\n    See Also\n    --------\n    unbatch\n    \"\"\"\n    if len(graphs) == 0:\n        raise DGLError(\"The input list of graphs cannot be empty.\")\n    if not (is_all(ndata) or isinstance(ndata, list) or ndata is None):\n        raise DGLError(\n            \"Invalid argument ndata: must be a string list but got {}.\".format(\n                type(ndata)\n            )\n        )\n    if not (is_all(edata) or isinstance(edata, list) or edata is None):\n        raise DGLError(\n            \"Invalid argument edata: must be a string list but got {}.\".format(\n                type(edata)\n            )\n        )\n    if any(g.is_block for g 
in graphs):\n        raise DGLError(\"Batching a MFG is not supported.\")\n\n    relations = list(graphs[0].canonical_etypes)\n    relation_ids = [graphs[0].get_etype_id(r) for r in relations]\n    ntypes = list(graphs[0].ntypes)\n    ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes]\n    etypes = [etype for _, etype, _ in relations]\n\n    gidx = disjoint_union(\n        graphs[0]._graph.metagraph, [g._graph for g in graphs]\n    )\n    retg = DGLGraph(gidx, ntypes, etypes)\n\n    # Compute batch num nodes\n    bnn = {}\n    for ntype in ntypes:\n        bnn[ntype] = F.cat([g.batch_num_nodes(ntype) for g in graphs], 0)\n    retg.set_batch_num_nodes(bnn)\n\n    # Compute batch num edges\n    bne = {}\n    for etype in relations:\n        bne[etype] = F.cat([g.batch_num_edges(etype) for g in graphs], 0)\n    retg.set_batch_num_edges(bne)\n\n    # Batch node feature\n    if ndata is not None:\n        for ntype_id, ntype in zip(ntype_ids, ntypes):\n            all_empty = all(g._graph.num_nodes(ntype_id) == 0 for g in graphs)\n            frames = [\n                g._node_frames[ntype_id]\n                for g in graphs\n                if g._graph.num_nodes(ntype_id) > 0 or all_empty\n            ]\n            # TODO: do we require graphs with no nodes/edges to have the same schema?  
Currently\n            # we allow empty graphs to have no features during batching.\n            ret_feat = _batch_feat_dicts(\n                frames, ndata, 'nodes[\"{}\"].data'.format(ntype)\n            )\n            retg.nodes[ntype].data.update(ret_feat)\n\n    # Batch edge feature\n    if edata is not None:\n        for etype_id, etype in zip(relation_ids, relations):\n            all_empty = all(g._graph.num_edges(etype_id) == 0 for g in graphs)\n            frames = [\n                g._edge_frames[etype_id]\n                for g in graphs\n                if g._graph.num_edges(etype_id) > 0 or all_empty\n            ]\n            # TODO: do we require graphs with no nodes/edges to have the same schema?  Currently\n            # we allow empty graphs to have no features during batching.\n            ret_feat = _batch_feat_dicts(\n                frames, edata, \"edges[{}].data\".format(etype)\n            )\n            retg.edges[etype].data.update(ret_feat)\n\n    return retg\n\n\ndef _batch_feat_dicts(frames, keys, feat_dict_name):\n    \"\"\"Internal function to batch feature dictionaries.\n\n    Parameters\n    ----------\n    frames : list[Frame]\n        List of frames\n    keys : list[str]\n        Feature keys. 
Can be '__ALL__', meaning batching all features.\n    feat_dict_name : str\n        Name of the feature dictionary for reporting errors.\n\n    Returns\n    -------\n    dict[str, Tensor]\n        New feature dict.\n    \"\"\"\n    if len(frames) == 0:\n        return {}\n    schemas = [frame.schemes for frame in frames]\n    # sanity checks\n    if is_all(keys):\n        utils.check_all_same_schema(schemas, feat_dict_name)\n        keys = schemas[0].keys()\n    else:\n        utils.check_all_same_schema_for_keys(schemas, keys, feat_dict_name)\n    # concat features\n    ret_feat = {k: F.cat([fd[k] for fd in frames], 0) for k in keys}\n    return ret_feat\n\n\ndef unbatch(g, node_split=None, edge_split=None):\n    \"\"\"Revert the batch operation by splitting the given graph into a list of small ones.\n\n    This is the reverse operation of :func:`dgl.batch`. If the ``node_split``\n    or the ``edge_split`` is not given, it calls :func:`DGLGraph.batch_num_nodes`\n    and :func:`DGLGraph.batch_num_edges` of the input graph to get the information.\n\n    If the ``node_split`` or the ``edge_split`` arguments are given,\n    it will partition the graph according to the given segments. One must ensure\n    that the partition is valid -- edges of the i^th graph only connect nodes\n    belonging to the i^th graph. 
Otherwise, DGL will throw an error.\n\n    The function supports heterograph input, in which case the two split\n    section arguments shall be of dictionary type -- similar to the\n    :func:`DGLGraph.batch_num_nodes`\n    and :func:`DGLGraph.batch_num_edges` attributes of a heterograph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        Input graph to unbatch.\n    node_split : Tensor, dict[str, Tensor], optional\n        Number of nodes of each result graph.\n    edge_split : Tensor, dict[str, Tensor], optional\n        Number of edges of each result graph.\n\n    Returns\n    -------\n    list[DGLGraph]\n        Unbatched list of graphs.\n\n    Examples\n    --------\n\n    Unbatch a batched graph\n\n    >>> import dgl\n    >>> import torch as th\n    >>> # 4 nodes, 3 edges\n    >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))\n    >>> # 3 nodes, 4 edges\n    >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))\n    >>> # add features\n    >>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)\n    >>> g1.edata['w'] = th.ones(g1.num_edges(), 2)\n    >>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)\n    >>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)\n    >>> bg = dgl.batch([g1, g2])\n    >>> f1, f2 = dgl.unbatch(bg)\n    >>> f1\n    Graph(num_nodes=4, num_edges=3,\n          ndata_schemes={‘x’ : Scheme(shape=(3,), dtype=torch.float32)}\n          edata_schemes={‘w’ : Scheme(shape=(2,), dtype=torch.float32)})\n    >>> f2\n    Graph(num_nodes=3, num_edges=4,\n          ndata_schemes={‘x’ : Scheme(shape=(3,), dtype=torch.float32)}\n          edata_schemes={‘w’ : Scheme(shape=(2,), dtype=torch.float32)})\n\n    With provided split arguments:\n\n    >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))\n    >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))\n    >>> g3 = dgl.graph((th.tensor([0]), th.tensor([1])))\n    >>> bg = dgl.batch([g1, g2, g3])\n    >>> bg.batch_num_nodes()\n    tensor([4, 3, 
2])\n    >>> bg.batch_num_edges()\n    tensor([3, 4, 1])\n    >>> # unbatch but merge g2 and g3\n    >>> f1, f2 = dgl.unbatch(bg, th.tensor([4, 5]), th.tensor([3, 5]))\n    >>> f1\n    Graph(num_nodes=4, num_edges=3,\n          ndata_schemes={}\n          edata_schemes={})\n    >>> f2\n    Graph(num_nodes=5, num_edges=5,\n          ndata_schemes={}\n          edata_schemes={})\n\n    Heterograph input\n\n    >>> hg1 = dgl.heterograph({\n    ...     ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})\n    >>> hg2 = dgl.heterograph({\n    ...     ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})\n    >>> bhg = dgl.batch([hg1, hg2])\n    >>> f1, f2 = dgl.unbatch(bhg)\n    >>> f1\n    Graph(num_nodes={'user': 2, 'game': 1},\n          num_edges={('user', 'plays', 'game'): 2},\n          metagraph=[('drug', 'game')])\n    >>> f2\n    Graph(num_nodes={'user': 1, 'game': 3},\n          num_edges={('user', 'plays', 'game'): 3},\n          metagraph=[('drug', 'game')])\n\n    See Also\n    --------\n    batch\n    \"\"\"\n    num_split = None\n    # Parse node_split\n    if node_split is None:\n        node_split = {ntype: g.batch_num_nodes(ntype) for ntype in g.ntypes}\n    elif not isinstance(node_split, Mapping):\n        if len(g.ntypes) != 1:\n            raise DGLError(\n                \"Must provide a dictionary for argument node_split when\"\n                \" there are multiple node types.\"\n            )\n        node_split = {g.ntypes[0]: node_split}\n    if node_split.keys() != set(g.ntypes):\n        raise DGLError(\"Must specify node_split for each node type.\")\n    for split in node_split.values():\n        if num_split is not None and num_split != len(split):\n            raise DGLError(\n                \"All node_split and edge_split must specify the same number\"\n                \" of split sizes.\"\n            )\n        num_split = len(split)\n\n    # Parse edge_split\n    if edge_split is None:\n        
edge_split = {\n            etype: g.batch_num_edges(etype) for etype in g.canonical_etypes\n        }\n    elif not isinstance(edge_split, Mapping):\n        if len(g.etypes) != 1:\n            raise DGLError(\n                \"Must provide a dictionary for argument edge_split when\"\n                \" there are multiple edge types.\"\n            )\n        edge_split = {g.canonical_etypes[0]: edge_split}\n    if edge_split.keys() != set(g.canonical_etypes):\n        raise DGLError(\"Must specify edge_split for each canonical edge type.\")\n    for split in edge_split.values():\n        if num_split is not None and num_split != len(split):\n            raise DGLError(\n                \"All edge_split and edge_split must specify the same number\"\n                \" of split sizes.\"\n            )\n        num_split = len(split)\n\n    node_split = {\n        k: F.asnumpy(split).tolist() for k, split in node_split.items()\n    }\n    edge_split = {\n        k: F.asnumpy(split).tolist() for k, split in edge_split.items()\n    }\n\n    # Split edges for each relation\n    edge_dict_per = [{} for i in range(num_split)]\n    for rel in g.canonical_etypes:\n        srctype, etype, dsttype = rel\n        srcnid_off = dstnid_off = 0\n        u, v = g.edges(order=\"eid\", etype=rel)\n        us = F.split(u, edge_split[rel], 0)\n        vs = F.split(v, edge_split[rel], 0)\n        for i, (subu, subv) in enumerate(zip(us, vs)):\n            edge_dict_per[i][rel] = (subu - srcnid_off, subv - dstnid_off)\n            srcnid_off += node_split[srctype][i]\n            dstnid_off += node_split[dsttype][i]\n    num_nodes_dict_per = [\n        {k: split[i] for k, split in node_split.items()}\n        for i in range(num_split)\n    ]\n\n    # Create graphs\n    gs = [\n        convert.heterograph(edge_dict, num_nodes_dict, idtype=g.idtype)\n        for edge_dict, num_nodes_dict in zip(edge_dict_per, num_nodes_dict_per)\n    ]\n\n    # Unbatch node features\n    for ntype in 
g.ntypes:\n        for key, feat in g.nodes[ntype].data.items():\n            subfeats = F.split(feat, node_split[ntype], 0)\n            for subg, subf in zip(gs, subfeats):\n                subg.nodes[ntype].data[key] = subf\n\n    # Unbatch edge features\n    for etype in g.canonical_etypes:\n        for key, feat in g.edges[etype].data.items():\n            subfeats = F.split(feat, edge_split[etype], 0)\n            for subg, subf in zip(gs, subfeats):\n                subg.edges[etype].data[key] = subf\n\n    return gs\n\n\ndef slice_batch(g, gid, store_ids=False):\n    \"\"\"Get a particular graph from a batch of graphs.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        Input batched graph.\n    gid : int\n        The ID of the graph to retrieve.\n    store_ids : bool\n        If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata`` and\n        ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.\n\n    Returns\n    -------\n    DGLGraph\n        Retrieved graph.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    Create a batched graph.\n\n    >>> g1 = dgl.graph(([0, 1], [2, 3]))\n    >>> g2 = dgl.graph(([1], [2]))\n    >>> bg = dgl.batch([g1, g2])\n\n    Get the second component graph.\n\n    >>> g = dgl.slice_batch(bg, 1)\n    >>> print(g)\n    Graph(num_nodes=3, num_edges=1,\n          ndata_schemes={}\n          edata_schemes={})\n    \"\"\"\n    start_nid = []\n    num_nodes = []\n    for ntype in g.ntypes:\n        batch_num_nodes = g.batch_num_nodes(ntype)\n        num_nodes.append(F.as_scalar(batch_num_nodes[gid]))\n        if gid == 0:\n            start_nid.append(0)\n        else:\n            start_nid.append(\n                F.as_scalar(F.sum(F.slice_axis(batch_num_nodes, 0, 0, gid), 0))\n            )\n\n    start_eid = []\n    num_edges = []\n    for etype in g.canonical_etypes:\n        
batch_num_edges = g.batch_num_edges(etype)\n        num_edges.append(F.as_scalar(batch_num_edges[gid]))\n        if gid == 0:\n            start_eid.append(0)\n        else:\n            start_eid.append(\n                F.as_scalar(F.sum(F.slice_axis(batch_num_edges, 0, 0, gid), 0))\n            )\n\n    # Slice graph structure\n    gidx = slice_gidx(\n        g._graph,\n        utils.toindex(num_nodes),\n        utils.toindex(start_nid),\n        utils.toindex(num_edges),\n        utils.toindex(start_eid),\n    )\n    retg = DGLGraph(gidx, g.ntypes, g.etypes)\n\n    # Slice node features\n    for ntid, ntype in enumerate(g.ntypes):\n        stnid = start_nid[ntid]\n        for key, feat in g.nodes[ntype].data.items():\n            subfeats = F.slice_axis(feat, 0, stnid, stnid + num_nodes[ntid])\n            retg.nodes[ntype].data[key] = subfeats\n\n        if store_ids:\n            retg.nodes[ntype].data[NID] = F.arange(\n                stnid, stnid + num_nodes[ntid], retg.idtype, retg.device\n            )\n\n    # Slice edge features\n    for etid, etype in enumerate(g.canonical_etypes):\n        steid = start_eid[etid]\n        for key, feat in g.edges[etype].data.items():\n            subfeats = F.slice_axis(feat, 0, steid, steid + num_edges[etid])\n            retg.edges[etype].data[key] = subfeats\n\n        if store_ids:\n            retg.edges[etype].data[EID] = F.arange(\n                steid, steid + num_edges[etid], retg.idtype, retg.device\n            )\n\n    return retg\n"
  },
  {
    "path": "python/dgl/container.py",
    "content": "\"\"\"Container data structures used in DGL runtime.\nreference: tvm/python/tvm/collections.py\n\"\"\"\nfrom __future__ import absolute_import as _abs\n\nfrom . import _api_internal\nfrom ._ffi.object import ObjectBase, register_object\nfrom ._ffi.object_generic import convert_to_object\n\n\n@register_object\nclass List(ObjectBase):\n    \"\"\"List container of DGL.\n\n    You do not need to create List explicitly.\n    Normally python list and tuple will be converted automatically\n    to List during dgl function call.\n    You may get List in return values of DGL function call.\n    \"\"\"\n\n    def __getitem__(self, i):\n        if isinstance(i, slice):\n            start = i.start if i.start is not None else 0\n            stop = i.stop if i.stop is not None else len(self)\n            step = i.step if i.step is not None else 1\n            if start < 0:\n                start += len(self)\n            if stop < 0:\n                stop += len(self)\n            return [self[idx] for idx in range(start, stop, step)]\n\n        if i < -len(self) or i >= len(self):\n            raise IndexError(\n                \"List index out of range. 
List size: {}, got index {}\".format(\n                    len(self), i\n                )\n            )\n        if i < 0:\n            i += len(self)\n        ret = _api_internal._ListGetItem(self, i)\n        if isinstance(ret, Value):\n            ret = ret.data\n        return ret\n\n    def __len__(self):\n        return _api_internal._ListSize(self)\n\n\n@register_object\nclass Map(ObjectBase):\n    \"\"\"Map container of DGL.\n\n    You do not need to create Map explicitly.\n    Normally python dict will be converted automatically to Map during dgl function call.\n    You can use convert to create a dict[ObjectBase-> ObjectBase] into a Map\n    \"\"\"\n\n    def __getitem__(self, k):\n        return _api_internal._MapGetItem(self, k)\n\n    def __contains__(self, k):\n        return _api_internal._MapCount(self, k) != 0\n\n    def items(self):\n        \"\"\"Get the items from the map\"\"\"\n        akvs = _api_internal._MapItems(self)\n        return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)]\n\n    def __len__(self):\n        return _api_internal._MapSize(self)\n\n\n@register_object\nclass StrMap(Map):\n    \"\"\"A special map container that has str as key.\n\n    You can use convert to create a dict[str->ObjectBase] into a Map.\n    \"\"\"\n\n    def items(self):\n        \"\"\"Get the items from the map\"\"\"\n        akvs = _api_internal._MapItems(self)\n        return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)]\n\n\n@register_object\nclass Value(ObjectBase):\n    \"\"\"Object wrapper for various values.\"\"\"\n\n    @property\n    def data(self):\n        \"\"\"Return the value data.\"\"\"\n        return _api_internal._ValueGet(self)\n\n\ndef convert_to_strmap(value):\n    \"\"\"Convert a python dictionary to a dgl.container.StrMap\"\"\"\n    assert isinstance(value, dict), \"Only support dict\"\n    if len(value) == 0:\n        return _api_internal._EmptyStrMap()\n    else:\n        return convert_to_object(value)\n"
  },
  {
    "path": "python/dgl/convert.py",
    "content": "\"\"\"Module for converting graph from/to other object.\"\"\"\n\nfrom collections import defaultdict\nfrom collections.abc import Mapping\n\nimport networkx as nx\nimport numpy as np\nfrom scipy.sparse import spmatrix\n\nfrom . import backend as F, graph_index, heterograph_index, utils\nfrom .base import DGLError, EID, ETYPE, NID, NTYPE\nfrom .heterograph import combine_frames, DGLBlock, DGLGraph\n\n__all__ = [\n    \"graph\",\n    \"hetero_from_shared_memory\",\n    \"heterograph\",\n    \"create_block\",\n    \"block_to_graph\",\n    \"to_heterogeneous\",\n    \"to_homogeneous\",\n    \"from_scipy\",\n    \"bipartite_from_scipy\",\n    \"from_networkx\",\n    \"bipartite_from_networkx\",\n    \"to_networkx\",\n    \"from_cugraph\",\n    \"to_cugraph\",\n]\n\n\ndef graph(\n    data,\n    *,\n    num_nodes=None,\n    idtype=None,\n    device=None,\n    row_sorted=False,\n    col_sorted=False,\n):\n    \"\"\"Create a graph and return.\n\n    Parameters\n    ----------\n    data : graph data\n        The data for constructing a graph, which takes the form of :math:`(U, V)`.\n        :math:`(U[i], V[i])` forms the edge with ID :math:`i` in the graph.\n        The allowed data formats are:\n\n        - ``(Tensor, Tensor)``: Each tensor must be a 1D tensor containing node IDs.\n          DGL calls this format \"tuple of node-tensors\". The tensors should have the same\n          data type of int32/int64 and device context (see below the descriptions of\n          :attr:`idtype` and :attr:`device`).\n        - ``('coo', (Tensor, Tensor))``: Same as ``(Tensor, Tensor)``.\n        - ``('csr', (Tensor, Tensor, Tensor))``: The three tensors form the CSR representation\n          of the graph's adjacency matrix.  The first one is the row index pointer.  The\n          second one is the column indices.  
The third one is the edge IDs, which can be empty\n          to represent consecutive integer IDs starting from 0.\n        - ``('csc', (Tensor, Tensor, Tensor))``: The three tensors form the CSC representation\n          of the graph's adjacency matrix.  The first one is the column index pointer.  The\n          second one is the row indices.  The third one is the edge IDs, which can be empty\n          to represent consecutive integer IDs starting from 0.\n\n        The tensors can be replaced with any iterable of integers (e.g. list, tuple,\n        numpy.ndarray).\n    num_nodes : int, optional\n        The number of nodes in the graph. If not given, this will be the largest node ID\n        plus 1 from the :attr:`data` argument. If given and the value is no greater than\n        the largest node ID from the :attr:`data` argument, DGL will raise an error.\n    idtype : int32 or int64, optional\n        The data type for storing the structure-related graph information such as node and\n        edge IDs. It should be a framework-specific data type object (e.g., ``torch.int32``).\n        If ``None`` (default), DGL infers the ID type from the :attr:`data` argument.\n        See \"Notes\" for more details.\n    device : device context, optional\n        The device of the returned graph, which should be a framework-specific device object\n        (e.g., ``torch.device``). If ``None`` (default), DGL uses the device of the tensors of\n        the :attr:`data` argument. If :attr:`data` is not a tuple of node-tensors, the\n        returned graph is on CPU.  If the specified :attr:`device` differs from that of the\n        provided tensors, it casts the given tensors to the specified device first.\n    row_sorted : bool, optional\n        Whether or not the rows of the COO are in ascending order.\n    col_sorted : bool, optional\n        Whether or not the columns of the COO are in ascending order within\n        each row. 
This only has an effect when ``row_sorted`` is True.\n\n    Returns\n    -------\n    DGLGraph\n        The created graph.\n\n    Notes\n    -----\n    1. If the :attr:`idtype` argument is not given then:\n\n       - in the case of the tuple of node-tensor format, DGL uses the\n         data type of the given ID tensors.\n       - in the case of the tuple of sequence format, DGL uses int64.\n\n       Once the graph has been created, you can change the data type by using\n       :func:`dgl.DGLGraph.long` or :func:`dgl.DGLGraph.int`.\n\n       If the specified :attr:`idtype` argument differs from the data type of the provided\n       tensors, it casts the given tensors to the specified data type first.\n    2. The most efficient construction approach is to provide a tuple of node tensors without\n       specifying :attr:`idtype` and :attr:`device`. This is because the returned graph shares\n       the storage with the input node-tensors in this case.\n    3. DGL internally maintains multiple copies of the graph structure in different\n       `sparse formats <https://en.wikipedia.org/wiki/Sparse_matrix>`_ and chooses the most\n       efficient one depending on the computation invoked. 
If memory usage becomes an issue\n       in the case of large graphs, use :func:`dgl.DGLGraph.formats` to restrict the allowed\n       formats.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    Create a small three-edge graph.\n\n    >>> # Source nodes for edges (2, 1), (3, 2), (4, 3)\n    >>> src_ids = torch.tensor([2, 3, 4])\n    >>> # Destination nodes for edges (2, 1), (3, 2), (4, 3)\n    >>> dst_ids = torch.tensor([1, 2, 3])\n    >>> g = dgl.graph((src_ids, dst_ids))\n\n    Explicitly specify the number of nodes in the graph.\n\n    >>> g = dgl.graph((src_ids, dst_ids), num_nodes=100)\n\n    Create a graph on the first GPU with data type int32.\n\n    >>> g = dgl.graph((src_ids, dst_ids), idtype=torch.int32, device='cuda:0')\n\n    Creating a graph with CSR representation:\n\n    >>> g = dgl.graph(('csr', ([0, 0, 0, 1, 2, 3], [1, 2, 3], [])))\n\n    Create the same graph with CSR representation and edge IDs.\n\n    >>> g = dgl.graph(('csr', ([0, 0, 0, 1, 2, 3], [1, 2, 3], [0, 1, 2])))\n\n    See Also\n    --------\n    from_scipy\n    from_networkx\n    \"\"\"\n    if isinstance(data, spmatrix):\n        raise DGLError(\n            \"dgl.graph no longer supports graph construction from a SciPy \"\n            \"sparse matrix, use dgl.from_scipy instead.\"\n        )\n\n    if isinstance(data, nx.Graph):\n        raise DGLError(\n            \"dgl.graph no longer supports graph construction from a NetworkX \"\n            \"graph, use dgl.from_networkx instead.\"\n        )\n\n    (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(data, idtype)\n    if num_nodes is not None:  # override the number of nodes\n        if num_nodes < max(urange, vrange):\n            raise DGLError(\n                \"The num_nodes argument must be larger than the max ID in the data,\"\n                \" but got {} and {}.\".format(num_nodes, max(urange, vrange) - 1)\n            )\n     
   urange, vrange = num_nodes, num_nodes\n\n    g = create_from_edges(\n        sparse_fmt,\n        arrays,\n        \"_N\",\n        \"_E\",\n        \"_N\",\n        urange,\n        vrange,\n        row_sorted=row_sorted,\n        col_sorted=col_sorted,\n    )\n\n    return g.to(device)\n\n\ndef hetero_from_shared_memory(name):\n    \"\"\"Create a heterograph from shared memory with the given name.\n\n    The newly created graph will have the same node types and edge types as the original graph.\n    But it does not have node features or edge features.\n\n    Parameters\n    ----------\n    name : str\n        The name of the shared memory\n\n    Returns\n    -------\n    HeteroGraph (in shared memory)\n    \"\"\"\n    g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(\n        name\n    )\n    return DGLGraph(g, ntypes, etypes)\n\n\ndef heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None):\n    \"\"\"Create a heterogeneous graph and return.\n\n    Parameters\n    ----------\n    data_dict : graph data\n        The dictionary data for constructing a heterogeneous graph. The keys are in the form of\n        string triplets (src_type, edge_type, dst_type), specifying the source node,\n        edge, and destination node types. The values are graph data in the form of\n        :math:`(U, V)`, where :math:`(U[i], V[i])` forms the edge with ID :math:`i`.\n        The allowed graph data formats are:\n\n        - ``(Tensor, Tensor)``: Each tensor must be a 1D tensor containing node IDs. DGL calls\n          this format \"tuple of node-tensors\". The tensors should have the same data type,\n          which must be either int32 or int64. 
They should also have the same device context\n          (see below the descriptions of :attr:`idtype` and :attr:`device`).\n        - ``('coo', (Tensor, Tensor))``: Same as ``(Tensor, Tensor)``.\n        - ``('csr', (Tensor, Tensor, Tensor))``: The three tensors form the CSR representation\n          of the graph's adjacency matrix.  The first one is the row index pointer.  The\n          second one is the column indices.  The third one is the edge IDs, which can be empty\n          (i.e. with 0 elements) to represent consecutive integer IDs starting from 0.\n        - ``('csc', (Tensor, Tensor, Tensor))``: The three tensors form the CSC representation\n          of the graph's adjacency matrix.  The first one is the column index pointer.  The\n          second one is the row indices.  The third one is the edge IDs, which can be empty\n          to represent consecutive integer IDs starting from 0.\n\n        The tensors can be replaced with any iterable of integers (e.g. list, tuple,\n        numpy.ndarray).\n    num_nodes_dict : dict[str, int], optional\n        The number of nodes for some node types, which is a dictionary mapping a node type\n        :math:`T` to the number of :math:`T`-typed nodes. If not given for a node type\n        :math:`T`, DGL finds the largest ID appearing in *every* graph data whose source\n        or destination node type is :math:`T`, and sets the number of nodes to be that ID\n        plus one. If given and the value is no greater than the largest ID for some node type,\n        DGL will raise an error. By default, DGL infers the number of nodes for all node types.\n    idtype : int32 or int64, optional\n        The data type for storing the structure-related graph information such as node and\n        edge IDs. 
It should be a framework-specific data type object (e.g., ``torch.int32``).\n        If ``None`` (default), DGL infers the ID type from the :attr:`data_dict` argument.\n    device : device context, optional\n        The device of the returned graph, which should be a framework-specific device object\n        (e.g., ``torch.device``). If ``None`` (default), DGL uses the device of the tensors of\n        the :attr:`data` argument. If :attr:`data` is not a tuple of node-tensors, the\n        returned graph is on CPU.  If the specified :attr:`device` differs from that of the\n        provided tensors, it casts the given tensors to the specified device first.\n\n    Returns\n    -------\n    DGLGraph\n        The created graph.\n\n    Notes\n    -----\n    1. If the :attr:`idtype` argument is not given then:\n\n       - in the case of the tuple of node-tensor format, DGL uses\n         the data type of the given ID tensors.\n       - in the case of the tuple of sequence format, DGL uses int64.\n\n       Once the graph has been created, you can change the data type by using\n       :func:`dgl.DGLGraph.long` or :func:`dgl.DGLGraph.int`.\n\n       If the specified :attr:`idtype` argument differs from the data type of the provided\n       tensors, it casts the given tensors to the specified data type first.\n    2. The most efficient construction approach is to provide a tuple of node tensors without\n       specifying :attr:`idtype` and :attr:`device`. This is because the returned graph shares\n       the storage with the input node-tensors in this case.\n    3. DGL internally maintains multiple copies of the graph structure in different sparse\n       formats and chooses the most efficient one depending on the computation invoked.\n       If memory usage becomes an issue in the case of large graphs, use\n       :func:`dgl.DGLGraph.formats` to restrict the allowed formats.\n    4. 
DGL internally decides a deterministic order for the same set of node types and canonical\n       edge types, which does not necessarily follow the order in :attr:`data_dict`.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    Create a heterograph with three canonical edge types.\n\n    >>> data_dict = {\n    ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n    ...     ('user', 'follows', 'topic'): (torch.tensor([1, 1]), torch.tensor([1, 2])),\n    ...     ('user', 'plays', 'game'): (torch.tensor([0, 3]), torch.tensor([3, 4]))\n    ... }\n    >>> g = dgl.heterograph(data_dict)\n    >>> g\n    Graph(num_nodes={'game': 5, 'topic': 3, 'user': 4},\n          num_edges={('user', 'follows', 'topic'): 2, ('user', 'follows', 'user'): 2,\n                     ('user', 'plays', 'game'): 2},\n          metagraph=[('user', 'topic', 'follows'), ('user', 'user', 'follows'),\n                     ('user', 'game', 'plays')])\n\n    Explicitly specify the number of nodes for each node type in the graph.\n\n    >>> num_nodes_dict = {'user': 4, 'topic': 4, 'game': 6}\n    >>> g = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)\n\n    Create a graph on the first GPU with data type int32.\n\n    >>> g = dgl.heterograph(data_dict, idtype=torch.int32, device='cuda:0')\n    \"\"\"\n    # Convert all data to node tensors first\n    node_tensor_dict = {}\n    need_infer = num_nodes_dict is None\n    if num_nodes_dict is None:\n        num_nodes_dict = defaultdict(int)\n    for (sty, ety, dty), data in data_dict.items():\n        if isinstance(data, spmatrix):\n            raise DGLError(\n                \"dgl.heterograph no longer supports graph construction from a SciPy \"\n                \"sparse matrix, use dgl.from_scipy instead.\"\n            )\n\n        if isinstance(data, nx.Graph):\n            raise DGLError(\n                \"dgl.heterograph no longer 
supports graph construction from a NetworkX \"\n                \"graph, use dgl.from_networkx instead.\"\n            )\n        is_bipartite = sty != dty\n        (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(\n            data, idtype, bipartite=is_bipartite\n        )\n        node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)\n        if need_infer:\n            num_nodes_dict[sty] = max(num_nodes_dict[sty], urange)\n            num_nodes_dict[dty] = max(num_nodes_dict[dty], vrange)\n        else:  # sanity check\n            if num_nodes_dict[sty] < urange:\n                raise DGLError(\n                    \"The given number of nodes of node type {} must be larger than\"\n                    \" the max ID in the data, but got {} and {}.\".format(\n                        sty, num_nodes_dict[sty], urange - 1\n                    )\n                )\n            if num_nodes_dict[dty] < vrange:\n                raise DGLError(\n                    \"The given number of nodes of node type {} must be larger than\"\n                    \" the max ID in the data, but got {} and {}.\".format(\n                        dty, num_nodes_dict[dty], vrange - 1\n                    )\n                )\n    # Create the graph\n    (\n        metagraph,\n        ntypes,\n        etypes,\n        relations,\n    ) = heterograph_index.create_metagraph_index(\n        num_nodes_dict.keys(), node_tensor_dict.keys()\n    )\n    num_nodes_per_type = utils.toindex(\n        [num_nodes_dict[ntype] for ntype in ntypes], \"int64\"\n    )\n    rel_graphs = []\n    for srctype, etype, dsttype in relations:\n        sparse_fmt, arrays = node_tensor_dict[(srctype, etype, dsttype)]\n        g = create_from_edges(\n            sparse_fmt,\n            arrays,\n            srctype,\n            etype,\n            dsttype,\n            num_nodes_dict[srctype],\n            num_nodes_dict[dsttype],\n        )\n        rel_graphs.append(g)\n\n    # create graph 
index\n    hgidx = heterograph_index.create_heterograph_from_relations(\n        metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type\n    )\n    retg = DGLGraph(hgidx, ntypes, etypes)\n\n    return retg.to(device)\n\n\ndef create_block(\n    data_dict,\n    num_src_nodes=None,\n    num_dst_nodes=None,\n    idtype=None,\n    device=None,\n    node_count_check=True,\n):\n    \"\"\"Create a message flow graph (MFG) as a :class:`DGLBlock` object.\n\n    Parameters\n    ----------\n    data_dict : graph data\n        The dictionary data for constructing a MFG. The keys are in the form of\n        string triplets (src_type, edge_type, dst_type), specifying the source node type,\n        edge type, and destination node type. The values are graph data in the form of\n        :math:`(U, V)`, where :math:`(U[i], V[i])` forms the edge with ID :math:`i`.\n        The allowed graph data formats are:\n\n        - ``(Tensor, Tensor)``: Each tensor must be a 1D tensor containing node IDs. DGL calls\n          this format \"tuple of node-tensors\". The tensors should have the same data type,\n          which must be either int32 or int64. They should also have the same device context\n          (see below the descriptions of :attr:`idtype` and :attr:`device`).\n        - ``('coo', (Tensor, Tensor))``: Same as ``(Tensor, Tensor)``.\n        - ``('csr', (Tensor, Tensor, Tensor))``: The three tensors form the CSR representation\n          of the graph's adjacency matrix.  The first one is the row index pointer.  The\n          second one is the column indices.  The third one is the edge IDs, which can be empty\n          to represent consecutive integer IDs starting from 0.\n        - ``('csc', (Tensor, Tensor, Tensor))``: The three tensors form the CSC representation\n          of the graph's adjacency matrix.  The first one is the column index pointer.  The\n          second one is the row indices.  
The third one is the edge IDs, which can be empty\n          to represent consecutive integer IDs starting from 0.\n\n        The tensors can be replaced with any iterable of integers (e.g. list, tuple,\n        numpy.ndarray).\n\n        If you would like to create a MFG with a single source node type, a single destination\n        node type, and a single edge type, then you can pass in the graph data directly\n        without wrapping it as a dictionary.\n    num_src_nodes : dict[str, int] or int, optional\n        The number of nodes for each source node type, which is a dictionary mapping a node type\n        :math:`T` to the number of :math:`T`-typed source nodes.\n\n        If not given for a node type :math:`T`, DGL finds the largest ID appearing in *every*\n        graph data whose source node type is :math:`T`, and sets the number of nodes to\n        be that ID plus one. If given and the value is no greater than the largest ID for some\n        source node type, DGL will raise an error. By default, DGL infers the number of nodes for\n        all source node types.\n\n        If you would like to create a MFG with a single source node type, a single destination\n        node type, and a single edge type, then you can pass in an integer to directly\n        represent the number of source nodes.\n    num_dst_nodes : dict[str, int] or int, optional\n        The number of nodes for each destination node type, which is a dictionary mapping a node\n        type :math:`T` to the number of :math:`T`-typed destination nodes.\n\n        If not given for a node type :math:`T`, DGL finds the largest ID appearing in *every*\n        graph data whose destination node type is :math:`T`, and sets the number of nodes to\n        be that ID plus one. If given and the value is no greater than the largest ID for some\n        destination node type, DGL will raise an error. 
By default, DGL infers the number of nodes\n        for all destination node types.\n\n        If you would like to create a MFG with a single source node type, a single\n        destination node type, and a single edge type, then you can pass in an integer to directly\n        represent the number of destination nodes.\n    idtype : int32 or int64, optional\n        The data type for storing the structure-related graph information such as node and\n        edge IDs. It should be a framework-specific data type object (e.g., ``torch.int32``).\n        If ``None`` (default), DGL infers the ID type from the :attr:`data_dict` argument.\n    device : device context, optional\n        The device of the returned graph, which should be a framework-specific device object\n        (e.g., ``torch.device``). If ``None`` (default), DGL uses the device of the tensors of\n        the :attr:`data` argument. If :attr:`data` is not a tuple of node-tensors, the\n        returned graph is on CPU.  If the specified :attr:`device` differs from that of the\n        provided tensors, it casts the given tensors to the specified device first.\n    node_count_check : bool, optional\n        When num_src_nodes and num_dst_nodes are passed, whether we should perform\n        sanity checks to ensure they are valid.\n\n    Returns\n    -------\n    DGLBlock\n        The created MFG.\n\n    Notes\n    -----\n    1. If the :attr:`idtype` argument is not given then:\n\n       - in the case of the tuple of node-tensor format, DGL uses\n         the data type of the given ID tensors.\n       - in the case of the tuple of sequence format, DGL uses int64.\n\n       Once the graph has been created, you can change the data type by using\n       :func:`dgl.DGLGraph.long` or :func:`dgl.DGLGraph.int`.\n\n       If the specified :attr:`idtype` argument differs from the data type of the provided\n       tensors, it casts the given tensors to the specified data type first.\n    2. 
The most efficient construction approach is to provide a tuple of node tensors without\n       specifying :attr:`idtype` and :attr:`device`. This is because the returned graph shares\n       the storage with the input node-tensors in this case.\n    3. DGL internally maintains multiple copies of the graph structure in different sparse\n       formats and chooses the most efficient one depending on the computation invoked.\n       If memory usage becomes an issue in the case of large graphs, use\n       :func:`dgl.DGLGraph.formats` to restrict the allowed formats.\n    4. DGL internally decides a deterministic order for the same set of node types and canonical\n       edge types, which does not necessarily follow the order in :attr:`data_dict`.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> block = dgl.create_block(([0, 1, 2], [1, 2, 3]), num_src_nodes=3, num_dst_nodes=4)\n    >>> block\n    Block(num_src_nodes=3, num_dst_nodes=4, num_edges=3)\n\n    >>> block = dgl.create_block({\n    ...     ('A', 'AB', 'B'): ([1, 2, 3], [2, 1, 0]),\n    ...     ('B', 'BA', 'A'): ([2, 1], [2, 3])},\n    ...     num_src_nodes={'A': 6, 'B': 5},\n    ...     
num_dst_nodes={'A': 4, 'B': 3})\n    >>> block\n    Block(num_src_nodes={'A': 6, 'B': 5},\n          num_dst_nodes={'A': 4, 'B': 3},\n          num_edges={('A', 'AB', 'B'): 3, ('B', 'BA', 'A'): 2},\n          metagraph=[('A', 'B', 'AB'), ('B', 'A', 'BA')])\n\n    See also\n    --------\n    to_block\n    \"\"\"\n    need_infer = num_src_nodes is None and num_dst_nodes is None\n    if not isinstance(data_dict, Mapping):\n        data_dict = {(\"_N\", \"_E\", \"_N\"): data_dict}\n\n        if not need_infer:\n            assert isinstance(\n                num_src_nodes, int\n            ), \"num_src_nodes must be a pair of integers if data_dict is not a dict\"\n            assert isinstance(\n                num_dst_nodes, int\n            ), \"num_dst_nodes must be a pair of integers if data_dict is not a dict\"\n            num_src_nodes = {\"_N\": num_src_nodes}\n            num_dst_nodes = {\"_N\": num_dst_nodes}\n    else:\n        if not need_infer:\n            assert isinstance(\n                num_src_nodes, Mapping\n            ), \"num_src_nodes must be a dict if data_dict is a dict\"\n            assert isinstance(\n                num_dst_nodes, Mapping\n            ), \"num_dst_nodes must be a dict if data_dict is a dict\"\n\n    if need_infer:\n        num_src_nodes = defaultdict(int)\n        num_dst_nodes = defaultdict(int)\n\n    # Convert all data to node tensors first\n    node_tensor_dict = {}\n    for (sty, ety, dty), data in data_dict.items():\n        (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(\n            data,\n            idtype,\n            bipartite=True,\n            infer_node_count=need_infer or node_count_check,\n        )\n        node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)\n        if need_infer:\n            num_src_nodes[sty] = max(num_src_nodes[sty], urange)\n            num_dst_nodes[dty] = max(num_dst_nodes[dty], vrange)\n        elif node_count_check:  # sanity check\n            if 
num_src_nodes[sty] < urange:\n                raise DGLError(\n                    \"The given number of nodes of source node type {} must be larger\"\n                    \" than the max ID in the data, but got {} and {}.\".format(\n                        sty, num_src_nodes[sty], urange - 1\n                    )\n                )\n            if num_dst_nodes[dty] < vrange:\n                raise DGLError(\n                    \"The given number of nodes of destination node type {} must be\"\n                    \" larger than the max ID in the data, but got {} and {}.\".format(\n                        dty, num_dst_nodes[dty], vrange - 1\n                    )\n                )\n    # Create the graph\n\n    # Sort the ntypes and relation tuples to have a deterministic order for the same set\n    # of type names.\n    srctypes = list(sorted(num_src_nodes.keys()))\n    dsttypes = list(sorted(num_dst_nodes.keys()))\n    relations = list(sorted(node_tensor_dict.keys()))\n\n    num_nodes_per_type = utils.toindex(\n        [num_src_nodes[ntype] for ntype in srctypes]\n        + [num_dst_nodes[ntype] for ntype in dsttypes],\n        \"int64\",\n    )\n    srctype_dict = {ntype: i for i, ntype in enumerate(srctypes)}\n    dsttype_dict = {\n        ntype: i + len(srctypes) for i, ntype in enumerate(dsttypes)\n    }\n\n    meta_edges_src = []\n    meta_edges_dst = []\n    etypes = []\n    rel_graphs = []\n    for srctype, etype, dsttype in relations:\n        meta_edges_src.append(srctype_dict[srctype])\n        meta_edges_dst.append(dsttype_dict[dsttype])\n        etypes.append(etype)\n        sparse_fmt, arrays = node_tensor_dict[(srctype, etype, dsttype)]\n        g = create_from_edges(\n            sparse_fmt,\n            arrays,\n            \"SRC/\" + srctype,\n            etype,\n            \"DST/\" + dsttype,\n            num_src_nodes[srctype],\n            num_dst_nodes[dsttype],\n        )\n        rel_graphs.append(g)\n\n    # metagraph is DGLGraph, 
currently still using int64 as index dtype\n    metagraph = graph_index.from_coo(\n        len(srctypes) + len(dsttypes), meta_edges_src, meta_edges_dst, True\n    )\n    # create graph index\n    hgidx = heterograph_index.create_heterograph_from_relations(\n        metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type\n    )\n    retg = DGLBlock(hgidx, (srctypes, dsttypes), etypes)\n\n    return retg.to(device)\n\n\ndef block_to_graph(block):\n    \"\"\"Convert a message flow graph (MFG) as a :class:`DGLBlock` object to a :class:`DGLGraph`.\n\n    DGL will rename all the source node types by suffixing with ``_src``, and\n    all the destination node types by suffixing with ``_dst``.\n\n    Features on the returned graph will be preserved.\n\n    Parameters\n    ----------\n    block : DGLBlock\n        The MFG.\n\n    Returns\n    -------\n    DGLGraph\n        The graph.\n\n    Examples\n    --------\n    >>> block = dgl.create_block({\n    ...     ('A', 'AB', 'B'): ([1, 2, 3], [2, 1, 0]),\n    ...     
('B', 'BA', 'A'): ([2, 1], [2, 3])})\n    >>> g = dgl.block_to_graph(block)\n    >>> g\n    Graph(num_nodes={'A_src': 4, 'B_src': 3, 'A_dst': 4, 'B_dst': 3},\n          num_edges={('A_src', 'AB', 'B_dst'): 3, ('B_src', 'BA', 'A_dst'): 2},\n          metagraph=[('A_src', 'B_dst', 'AB'), ('B_src', 'A_dst', 'BA')])\n    \"\"\"\n    new_types = [ntype + \"_src\" for ntype in block.srctypes] + [\n        ntype + \"_dst\" for ntype in block.dsttypes\n    ]\n    retg = DGLGraph(block._graph, new_types, block.etypes)\n\n    for srctype in block.srctypes:\n        retg.nodes[srctype + \"_src\"].data.update(block.srcnodes[srctype].data)\n    for dsttype in block.dsttypes:\n        retg.nodes[dsttype + \"_dst\"].data.update(block.dstnodes[dsttype].data)\n    for srctype, etype, dsttype in block.canonical_etypes:\n        retg.edges[srctype + \"_src\", etype, dsttype + \"_dst\"].data.update(\n            block.edges[srctype, etype, dsttype].data\n        )\n\n    return retg\n\n\ndef to_heterogeneous(\n    G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None\n):\n    \"\"\"Convert a homogeneous graph to a heterogeneous graph and return.\n\n    The input graph should have only one type of nodes and edges. Each node and edge\n    stores an integer feature as its type ID\n    (specified by :attr:`ntype_field` and :attr:`etype_field`).\n    DGL uses it to retrieve the type names stored in the given\n    :attr:`ntypes` and :attr:`etypes` arguments.\n\n    The function will automatically distinguish edge types that have the same given\n    type IDs but different src and dst type IDs. For example, it allows both edges A and B\n    to have the same type ID 0, but one has (0, 1) and the other has (2, 3) as the\n    (src, dst) type IDs. In this case, the function will \"split\" edge type 0 into two types:\n    (0, ty_A, 1) and (2, ty_B, 3). 
In other words, these two edges share the same edge\n    type name, but can be distinguished by an edge type triplet.\n\n    The function stores the node and edge IDs in the input graph using the ``dgl.NID``\n    and ``dgl.EID`` names in the ``ndata`` and ``edata`` of the resulting graph.\n    It also copies any node/edge features from :attr:`G` to the returned heterogeneous\n    graph, except for reserved fields for storing type IDs (``dgl.NTYPE`` and ``dgl.ETYPE``)\n    and node/edge IDs (``dgl.NID`` and ``dgl.EID``).\n\n    Parameters\n    ----------\n    G : DGLGraph\n        The homogeneous graph.\n    ntypes : list[str]\n        The node type names.\n    etypes : list[str]\n        The edge type names.\n    ntype_field : str, optional\n        The feature field used to store node type. (Default: ``dgl.NTYPE``)\n    etype_field : str, optional\n        The feature field used to store edge type. (Default: ``dgl.ETYPE``)\n    metagraph : networkx MultiDiGraph, optional\n        Metagraph of the returned heterograph.\n        If provided, DGL assumes that G can indeed be described with the given metagraph.\n        If None, DGL will infer the metagraph from the given inputs, which could be\n        costly for large graphs.\n\n    Returns\n    -------\n    DGLGraph\n        A heterogeneous graph.\n\n    Notes\n    -----\n    * The returned node and edge types may not necessarily be in the same order as\n      ``ntypes`` and ``etypes``.\n    * Calling :func:`~dgl.to_homogeneous` then calling :func:`~dgl.to_heterogeneous` again\n      yields the same result.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    >>> hg = dgl.heterograph({\n    ...     ('user', 'develops', 'activity'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n    ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]), torch.tensor([0, 1]))\n    ... 
})\n    >>> print(hg)\n    Graph(num_nodes={'activity': 3, 'developer': 2, 'game': 2, 'user': 2},\n          num_edges={('developer', 'develops', 'game'): 2, ('user', 'develops', 'activity'): 2},\n          metagraph=[('developer', 'game', 'develops'), ('user', 'activity', 'develops')])\n\n    We first convert the heterogeneous graph to a homogeneous graph.\n\n    >>> g = dgl.to_homogeneous(hg)\n    >>> print(g)\n    Graph(num_nodes=9, num_edges=4,\n          ndata_schemes={'_TYPE': Scheme(shape=(), dtype=torch.int64),\n                         '_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'_TYPE': Scheme(shape=(), dtype=torch.int64),\n                         '_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> g.ndata\n    {'_TYPE': tensor([0, 0, 0, 1, 1, 2, 2, 3, 3]), '_ID': tensor([0, 1, 2, 0, 1, 0, 1, 0, 1])}\n    Nodes 0, 1, 2 for 'activity', 3, 4 for 'developer', 5, 6 for 'game', 7, 8 for 'user'\n    >>> g.edata\n    {'_TYPE': tensor([0, 0, 1, 1]), '_ID': tensor([0, 1, 0, 1])}\n    Edges 0, 1 for ('developer', 'develops', 'game'), 2, 3 for ('user', 'develops', 'activity')\n\n    Now convert the homogeneous graph back to a heterogeneous graph.\n\n    >>> hg_2 = dgl.to_heterogeneous(g, hg.ntypes, hg.etypes)\n    >>> print(hg_2)\n    Graph(num_nodes={'activity': 3, 'developer': 2, 'game': 2, 'user': 2},\n          num_edges={('developer', 'develops', 'game'): 2, ('user', 'develops', 'activity'): 2},\n          metagraph=[('developer', 'game', 'develops'), ('user', 'activity', 'develops')])\n\n    Retrieve the original node/edge IDs.\n\n    >>> hg_2.ndata[dgl.NID]\n    {'activity': tensor([0, 1, 2]),\n     'developer': tensor([3, 4]),\n     'game': tensor([5, 6]),\n     'user': tensor([7, 8])}\n    >>> hg_2.edata[dgl.EID]\n    {('developer', 'develops', 'game'): tensor([0, 1]),\n     ('user', 'develops', 'activity'): tensor([2, 3])}\n\n    See Also\n    --------\n    to_homogeneous\n    \"\"\"\n    if (\n        hasattr(G, \"ntypes\")\n  
      and len(G.ntypes) > 1\n        or hasattr(G, \"etypes\")\n        and len(G.etypes) > 1\n    ):\n        raise DGLError(\n            \"The input graph should be homogeneous and have only one \"\n            \" type of nodes and edges.\"\n        )\n\n    num_ntypes = len(ntypes)\n    idtype = G.idtype\n    device = G.device\n\n    ntype_ids = F.asnumpy(G.ndata[ntype_field])\n    etype_ids = F.asnumpy(G.edata[etype_field])\n\n    # relabel nodes to per-type local IDs\n    ntype_count = np.bincount(ntype_ids, minlength=num_ntypes)\n    ntype_offset = np.insert(np.cumsum(ntype_count), 0, 0)\n    ntype_ids_sortidx = np.argsort(ntype_ids, kind=\"stable\")\n    ntype_local_ids = np.zeros_like(ntype_ids)\n    node_groups = []\n    for i in range(num_ntypes):\n        node_group = ntype_ids_sortidx[ntype_offset[i] : ntype_offset[i + 1]]\n        node_groups.append(node_group)\n        ntype_local_ids[node_group] = np.arange(ntype_count[i])\n\n    src, dst = G.all_edges(order=\"eid\")\n    src = F.asnumpy(src)\n    dst = F.asnumpy(dst)\n    src_local = ntype_local_ids[src]\n    dst_local = ntype_local_ids[dst]\n    # a 2D tensor of shape (E, 3). Each row represents the (stid, etid, dtid) tuple.\n    edge_ctids = np.stack([ntype_ids[src], etype_ids, ntype_ids[dst]], 1)\n\n    # infer metagraph and canonical edge types\n    # No matter which branch it takes, the code will generate a 2D tensor of shape (E_m, 3),\n    # E_m is the set of all possible canonical edge tuples. Each row represents the\n    # (stid, dtid, dtid) tuple. We then compute a 2D tensor of shape (E, E_m) using the\n    # above ``edge_ctids`` matrix. Each element i,j indicates whether the edge i is of the\n    # canonical edge type j. 
We can then group the edges of the same type together.\n    if metagraph is None:\n        canonical_etids, _, etype_remapped = utils.make_invmap(\n            list(tuple(_) for _ in edge_ctids), False\n        )\n        etype_mask = (\n            etype_remapped[None, :] == np.arange(len(canonical_etids))[:, None]\n        )\n    else:\n        ntypes_invmap = {nt: i for i, nt in enumerate(ntypes)}\n        etypes_invmap = {et: i for i, et in enumerate(etypes)}\n        canonical_etids = []\n        for i, (srctype, dsttype, etype) in enumerate(\n            metagraph.edges(keys=True)\n        ):\n            srctype_id = ntypes_invmap[srctype]\n            etype_id = etypes_invmap[etype]\n            dsttype_id = ntypes_invmap[dsttype]\n            canonical_etids.append((srctype_id, etype_id, dsttype_id))\n        canonical_etids = np.asarray(canonical_etids)\n        etype_mask = (edge_ctids[None, :] == canonical_etids[:, None]).all(2)\n    edge_groups = [\n        etype_mask[i].nonzero()[0] for i in range(len(canonical_etids))\n    ]\n\n    data_dict = dict()\n    canonical_etypes = []\n    for i, (stid, etid, dtid) in enumerate(canonical_etids):\n        src_of_etype = src_local[edge_groups[i]]\n        dst_of_etype = dst_local[edge_groups[i]]\n        canonical_etypes.append((ntypes[stid], etypes[etid], ntypes[dtid]))\n        data_dict[canonical_etypes[-1]] = (src_of_etype, dst_of_etype)\n    hg = heterograph(\n        data_dict, dict(zip(ntypes, ntype_count)), idtype=idtype, device=device\n    )\n\n    ntype2ngrp = {ntype: node_groups[ntid] for ntid, ntype in enumerate(ntypes)}\n\n    # features\n    for key, data in G.ndata.items():\n        if key in [ntype_field, NID]:\n            continue\n        for ntid, ntype in enumerate(hg.ntypes):\n            rows = F.copy_to(F.tensor(ntype2ngrp[ntype]), F.context(data))\n            hg._node_frames[ntid][key] = F.gather_row(data, rows)\n\n    for key, data in G.edata.items():\n        if key in [etype_field, 
EID]:\n            continue\n        for etid in range(len(hg.canonical_etypes)):\n            rows = F.copy_to(F.tensor(edge_groups[etid]), F.context(data))\n            hg._edge_frames[hg.get_etype_id(canonical_etypes[etid])][\n                key\n            ] = F.gather_row(data, rows)\n\n    # Record the original IDs of the nodes/edges\n    for ntid, ntype in enumerate(hg.ntypes):\n        hg._node_frames[ntid][NID] = F.copy_to(\n            F.tensor(ntype2ngrp[ntype]), device\n        )\n    for etid in range(len(hg.canonical_etypes)):\n        hg._edge_frames[hg.get_etype_id(canonical_etypes[etid])][\n            EID\n        ] = F.copy_to(F.tensor(edge_groups[etid]), device)\n\n    return hg\n\n\ndef to_homogeneous(\n    G, ndata=None, edata=None, store_type=True, return_count=False\n):\n    \"\"\"Convert a heterogeneous graph to a homogeneous graph and return.\n\n    By default, the function stores the node and edge types of the input graph as\n    the ``dgl.NTYPE`` and ``dgl.ETYPE`` features in the returned graph.\n    Each feature is an integer representing the type id, determined by the\n    :meth:`DGLGraph.get_ntype_id` and :meth:`DGLGraph.get_etype_id` methods.\n    One can omit it by specifying ``store_type=False``.\n\n    The result graph assigns nodes and edges of the same type with IDs in continuous range\n    (i.e., nodes of the first type have IDs 0 ~ ``G.num_nodes(G.ntypes[0])``; nodes\n    of the second type come after; so on and so forth). Therefore, a more memory-efficient\n    format for type information is an integer list; the i^th corresponds to\n    the number of nodes/edges of the i^th type. One can choose this format by\n    specifying ``return_count=True``.\n\n    Parameters\n    ----------\n    G : DGLGraph\n        The heterogeneous graph.\n    ndata : list[str], optional\n        The node features to combine across all node types. 
For each feature ``feat`` in\n        :attr:`ndata`, it concatenates ``G.nodes[T].data[feat]`` across all node types ``T``.\n        As a result, the feature ``feat`` of all node types should have the same shape and\n        data type. By default, the returned graph will not have any node features.\n    edata : list[str], optional\n        The edge features to combine across all edge types. For each feature ``feat`` in\n        :attr:`edata`, it concatenates ``G.edges[T].data[feat]`` across all edge types ``T``.\n        As a result, the feature ``feat`` of all edge types should have the same shape and\n        data type. By default, the returned graph will not have any edge features.\n    store_type : bool, optional\n        If True, store type information as the ``dgl.NTYPE`` and ``dgl.ETYPE`` features\n        in the returned graph.\n    return_count : bool, optional\n        If True, return type information as an integer list; the i^th element corresponds to\n        the number of nodes/edges of the i^th type.\n\n    Returns\n    -------\n    DGLGraph\n        A homogeneous graph.\n    ntype_count : list[int], optional\n        Number of nodes of each type. Return when ``return_count`` is True.\n    etype_count : list[int], optional\n        Number of edges of each type. Return when ``return_count`` is True.\n\n    Notes\n    -----\n\n    * Calculating type information may introduce noticeable cost. 
Setting both ``store_type``\n      and ``return_count`` to False can avoid such cost if type information is not needed.\n      Otherwise, DGL recommends to use ``store_type=False`` and ``return_count=True`` due\n      to its memory efficiency.\n    * The ``ntype_count`` and ``etype_count`` lists can help speed up some operations.\n      See :class:`~dgl.nn.pytorch.conv.RelGraphConv` for such an example.\n    * Calling :func:`~dgl.to_homogeneous` then calling :func:`~dgl.to_heterogeneous` again\n      yields the same result.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    >>> hg = dgl.heterograph({\n    ...     ('user', 'follows', 'user'): ([0, 1], [1, 2]),\n    ...     ('developer', 'develops', 'game'): ([0, 1], [0, 1])\n    ...     })\n    >>> hg.nodes['user'].data['h'] = torch.ones(3, 1)\n    >>> hg.nodes['developer'].data['h'] = torch.zeros(2, 1)\n    >>> hg.nodes['game'].data['h'] = torch.ones(2, 1)\n    >>> g = dgl.to_homogeneous(hg)\n    >>> # The first three nodes are for 'user', the next two are for 'developer',\n    >>> # and the last two are for 'game'\n    >>> g.ndata\n    {'_TYPE': tensor([0, 0, 0, 1, 1, 2, 2]), '_ID': tensor([0, 1, 2, 0, 1, 0, 1])}\n    >>> # The first two edges are for 'follows', and the next two are for 'develops' edges.\n    >>> g.edata\n    {'_TYPE': tensor([0, 0, 1, 1]), '_ID': tensor([0, 1, 0, 1])}\n\n    Combine feature 'h' across all node types in the conversion.\n\n    >>> g = dgl.to_homogeneous(hg, ndata=['h'])\n    >>> g.ndata['h']\n    tensor([[1.], [1.], [1.], [0.], [0.], [1.], [1.]])\n\n    See Also\n    --------\n    to_heterogeneous\n    \"\"\"\n    num_nodes_per_ntype = [G.num_nodes(ntype) for ntype in G.ntypes]\n    offset_per_ntype = np.insert(np.cumsum(num_nodes_per_ntype), 0, 0)\n    srcs = []\n    dsts = []\n    nids = []\n    eids = []\n    if store_type:\n        ntype_ids = []\n        etype_ids = []\n    if return_count:\n      
  ntype_count = []\n        etype_count = []\n    total_num_nodes = 0\n\n    for ntype_id, ntype in enumerate(G.ntypes):\n        num_nodes = G.num_nodes(ntype)\n        total_num_nodes += num_nodes\n        if store_type:\n            # Type ID is always in int64\n            ntype_ids.append(F.full_1d(num_nodes, ntype_id, F.int64, G.device))\n        if return_count:\n            ntype_count.append(num_nodes)\n        nids.append(F.arange(0, num_nodes, G.idtype, G.device))\n\n    for etype_id, etype in enumerate(G.canonical_etypes):\n        srctype, _, dsttype = etype\n        src, dst = G.all_edges(etype=etype, order=\"eid\")\n        num_edges = len(src)\n        srcs.append(src + int(offset_per_ntype[G.get_ntype_id(srctype)]))\n        dsts.append(dst + int(offset_per_ntype[G.get_ntype_id(dsttype)]))\n        if store_type:\n            # Type ID is always in int64\n            etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, G.device))\n        if return_count:\n            etype_count.append(num_edges)\n        eids.append(F.arange(0, num_edges, G.idtype, G.device))\n\n    retg = graph(\n        (F.cat(srcs, 0), F.cat(dsts, 0)),\n        num_nodes=total_num_nodes,\n        idtype=G.idtype,\n        device=G.device,\n    )\n\n    # copy features\n    if ndata is None:\n        ndata = []\n    if edata is None:\n        edata = []\n    comb_nf = combine_frames(\n        G._node_frames, range(len(G.ntypes)), col_names=ndata\n    )\n    comb_ef = combine_frames(\n        G._edge_frames, range(len(G.etypes)), col_names=edata\n    )\n    if comb_nf is not None:\n        retg.ndata.update(comb_nf)\n    if comb_ef is not None:\n        retg.edata.update(comb_ef)\n\n    retg.ndata[NID] = F.cat(nids, 0)\n    retg.edata[EID] = F.cat(eids, 0)\n    if store_type:\n        retg.ndata[NTYPE] = F.cat(ntype_ids, 0)\n        retg.edata[ETYPE] = F.cat(etype_ids, 0)\n\n    if return_count:\n        return retg, ntype_count, etype_count\n    else:\n        return 
retg\n\n\ndef from_scipy(sp_mat, eweight_name=None, idtype=None, device=None):\n    \"\"\"Create a graph from a SciPy sparse matrix and return.\n\n    Parameters\n    ----------\n    sp_mat : scipy.sparse.spmatrix\n        The graph adjacency matrix. Each nonzero entry ``sp_mat[i, j]`` represents an edge from\n        node ``i`` to ``j``. The matrix must have square shape ``(N, N)``, where ``N`` is the\n        number of nodes in the graph.\n    eweight_name : str, optional\n        The edata name for storing the nonzero values of :attr:`sp_mat`. If given, DGL will\n        store the nonzero values of :attr:`sp_mat` in ``edata[eweight_name]`` of the returned\n        graph.\n    idtype : int32 or int64, optional\n        The data type for storing the structure-related graph information such as node and\n        edge IDs. It should be a framework-specific data type object (e.g., ``torch.int32``).\n        By default, DGL uses int64.\n    device : device context, optional\n        The device of the resulting graph. It should be a framework-specific device object\n        (e.g., ``torch.device``). By default, DGL stores the graph on CPU.\n\n    Returns\n    -------\n    DGLGraph\n        The created graph.\n\n    Notes\n    -----\n    1. The function supports all kinds of SciPy sparse matrix classes (e.g.,\n       :class:`scipy.sparse.csr.csr_matrix`). It converts the input matrix to the COOrdinate\n       format using :func:`scipy.sparse.spmatrix.tocoo` before creating a :class:`DGLGraph`.\n       Creating from a :class:`scipy.sparse.coo.coo_matrix` is hence the most efficient way.\n    2. 
DGL internally maintains multiple copies of the graph structure in different sparse\n       formats and chooses the most efficient one depending on the computation invoked.\n       If memory usage becomes an issue in the case of large graphs, use\n       :func:`dgl.DGLGraph.formats` to restrict the allowed formats.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch\n    >>> from scipy.sparse import coo_matrix\n\n    Create a small three-edge graph.\n\n    >>> # Source nodes for edges (2, 1), (3, 2), (4, 3)\n    >>> src_ids = np.array([2, 3, 4])\n    >>> # Destination nodes for edges (2, 1), (3, 2), (4, 3)\n    >>> dst_ids = np.array([1, 2, 3])\n    >>> # Weight for edges (2, 1), (3, 2), (4, 3)\n    >>> eweight = np.array([0.2, 0.3, 0.5])\n    >>> sp_mat = coo_matrix((eweight, (src_ids, dst_ids)), shape=(5, 5))\n    >>> g = dgl.from_scipy(sp_mat)\n\n    Retrieve the edge weights.\n\n    >>> g = dgl.from_scipy(sp_mat, eweight_name='w')\n    >>> g.edata['w']\n    tensor([0.2000, 0.3000, 0.5000], dtype=torch.float64)\n\n    Create a graph on the first GPU with data type int32.\n\n    >>> g = dgl.from_scipy(sp_mat, idtype=torch.int32, device='cuda:0')\n\n    See Also\n    --------\n    graph\n    from_networkx\n    \"\"\"\n    # Sanity check\n    num_rows = sp_mat.shape[0]\n    num_cols = sp_mat.shape[1]\n    if num_rows != num_cols:\n        raise DGLError(\n            \"Expect the number of rows to be the same as the number of columns for \"\n            \"sp_mat, got {:d} and {:d}.\".format(num_rows, num_cols)\n        )\n\n    (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(\n        sp_mat, idtype\n    )\n    g = create_from_edges(sparse_fmt, arrays, \"_N\", \"_E\", \"_N\", urange, vrange)\n    if eweight_name is not None:\n        g.edata[eweight_name] = F.tensor(sp_mat.data)\n    return g.to(device)\n\n\ndef bipartite_from_scipy(\n    sp_mat, 
utype, etype, vtype, eweight_name=None, idtype=None, device=None\n):\n    \"\"\"Create a uni-directional bipartite graph from a SciPy sparse matrix and return.\n\n    The created graph will have two types of nodes ``utype`` and ``vtype`` as well as one\n    edge type ``etype`` whose edges are from ``utype`` to ``vtype``.\n\n    Parameters\n    ----------\n    sp_mat : scipy.sparse.spmatrix\n        The graph adjacency matrix. Each nonzero entry ``sp_mat[i, j]``\n        represents an edge from node ``i`` of type :attr:`utype` to ``j`` of type :attr:`vtype`.\n        Let the matrix shape be ``(N, M)``. There will be ``N`` nodes of type :attr:`utype`\n        and ``M`` nodes of type ``vtype`` in the resulting graph.\n    utype : str\n        The name of the source node type.\n    etype : str\n        The name of the edge type.\n    vtype : str\n        The name of the destination node type.\n    eweight_name : str, optional\n        The edata name for storing the nonzero values of :attr:`sp_mat`.\n        If given, DGL will store the nonzero values of :attr:`sp_mat` in ``edata[eweight_name]``\n        of the returned graph.\n    idtype : int32 or int64, optional\n        The data type for storing the structure-related graph information such as node and\n        edge IDs. It should be a framework-specific data type object (e.g., ``torch.int32``).\n        By default, DGL uses int64.\n    device : device context, optional\n        The device of the resulting graph. It should be a framework-specific device object\n        (e.g., ``torch.device``). By default, DGL stores the graph on CPU.\n\n    Returns\n    -------\n    DGLGraph\n        The created graph.\n\n    Notes\n    -----\n    1. The function supports all kinds of SciPy sparse matrix classes (e.g.,\n       :class:`scipy.sparse.csr.csr_matrix`). 
It converts the input matrix to the COOrdinate\n       format using :func:`scipy.sparse.spmatrix.tocoo` before creates a :class:`DGLGraph`.\n       Creating from a :class:`scipy.sparse.coo.coo_matrix` is hence the most efficient way.\n    2. DGL internally maintains multiple copies of the graph structure in different sparse\n       formats and chooses the most efficient one depending on the computation invoked.\n       If memory usage becomes an issue in the case of large graphs, use\n       :func:`dgl.DGLGraph.formats` to restrict the allowed formats.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch\n    >>> from scipy.sparse import coo_matrix\n\n    Create a small three-edge graph.\n\n    >>> # Source nodes for edges (2, 1), (3, 2), (4, 3)\n    >>> src_ids = np.array([2, 3, 4])\n    >>> # Destination nodes for edges (2, 1), (3, 2), (4, 3)\n    >>> dst_ids = np.array([1, 2, 3])\n    >>> # Weight for edges (2, 1), (3, 2), (4, 3)\n    >>> eweight = np.array([0.2, 0.3, 0.5])\n    >>> sp_mat = coo_matrix((eweight, (src_ids, dst_ids)))\n    >>> g = dgl.bipartite_from_scipy(sp_mat, utype='_U', etype='_E', vtype='_V')\n\n    Retrieve the edge weights.\n\n    >>> g = dgl.bipartite_from_scipy(sp_mat, utype='_U', etype='_E', vtype='_V', eweight_name='w')\n    >>> g.edata['w']\n    tensor([0.2000, 0.3000, 0.5000], dtype=torch.float64)\n\n    Create a graph on the first GPU with data type int32.\n\n    >>> g = dgl.bipartite_from_scipy(sp_mat, utype='_U', etype='_E', vtype='_V',\n    ...                              
idtype=torch.int32, device='cuda:0')\n\n    See Also\n    --------\n    heterograph\n    bipartite_from_networkx\n    \"\"\"\n    (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(\n        sp_mat, idtype, bipartite=True\n    )\n    g = create_from_edges(\n        sparse_fmt, arrays, utype, etype, vtype, urange, vrange\n    )\n    if eweight_name is not None:\n        g.edata[eweight_name] = F.tensor(sp_mat.data)\n    return g.to(device)\n\n\ndef _batcher(lst):\n    if F.is_tensor(lst[0]):\n        return F.cat([F.unsqueeze(x, 0) for x in lst], dim=0)\n\n    if isinstance(lst[0], np.ndarray):\n        return F.tensor(np.array(lst))\n\n    return F.tensor(lst)\n\n\ndef from_networkx(\n    nx_graph,\n    node_attrs=None,\n    edge_attrs=None,\n    edge_id_attr_name=None,\n    idtype=None,\n    device=None,\n):\n    \"\"\"Create a graph from a NetworkX graph and return.\n\n    .. note::\n        Creating a DGLGraph from a NetworkX graph is not fast especially for large scales.\n        It is recommended to first convert a NetworkX graph into a tuple of node-tensors\n        and then construct a DGLGraph with :func:`dgl.graph`.\n\n    Parameters\n    ----------\n    nx_graph : networkx.Graph\n        The NetworkX graph holding the graph structure and the node/edge attributes.\n        DGL will relabel the nodes using consecutive integers starting from zero if it is\n        not the case. If the input graph is undirected, DGL converts it to a directed graph\n        by :func:`networkx.Graph.to_directed`.\n    node_attrs : list[str], optional\n        The names of the node attributes to retrieve from the NetworkX graph. If given, DGL\n        stores the retrieved node attributes in ``ndata`` of the returned graph using their\n        original names. 
The attribute data must be convertible to Tensor type (e.g., scalar,\n        numpy.ndarray, list, etc.).\n    edge_attrs : list[str], optional\n        The names of the edge attributes to retrieve from the NetworkX graph. If given, DGL\n        stores the retrieved edge attributes in ``edata`` of the returned graph using their\n        original names. The attribute data must be convertible to Tensor type (e.g., scalar,\n        ``numpy.ndarray``, list, etc.). It must be None if :attr:`nx_graph` is undirected.\n    edge_id_attr_name : str, optional\n        The name of the edge attribute that stores the edge IDs. If given, DGL will assign edge\n        IDs accordingly when creating the graph, so the attribute must be valid IDs, i.e.\n        consecutive integers starting from zero. By default, the edge IDs of the returned graph\n        can be arbitrary. It must be None if :attr:`nx_graph` is undirected.\n    idtype : int32 or int64, optional\n        The data type for storing the structure-related graph information such as node and\n        edge IDs. It should be a framework-specific data type object (e.g., ``torch.int32``).\n        By default, DGL uses int64.\n    device : device context, optional\n        The device of the resulting graph. It should be a framework-specific device object\n        (e.g., ``torch.device``). 
By default, DGL stores the graph on CPU.\n\n    Returns\n    -------\n    DGLGraph\n        The created graph.\n\n    Notes\n    -----\n    DGL internally maintains multiple copies of the graph structure in different sparse\n    formats and chooses the most efficient one depending on the computation invoked.\n    If memory usage becomes an issue in the case of large graphs, use\n    :func:`dgl.DGLGraph.formats` to restrict the allowed formats.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import networkx as nx\n    >>> import numpy as np\n    >>> import torch\n\n    Create a 2-edge NetworkX graph.\n\n    >>> nx_g = nx.DiGraph()\n    >>> # Add 3 nodes and two features for them\n    >>> nx_g.add_nodes_from([0, 1, 2], feat1=np.zeros((3, 1)), feat2=np.ones((3, 1)))\n    >>> # Add 2 edges (1, 2) and (2, 1) with two features, one being edge IDs\n    >>> nx_g.add_edge(1, 2, weight=np.ones((1, 1)), eid=np.array([1]))\n    >>> nx_g.add_edge(2, 1, weight=np.ones((1, 1)), eid=np.array([0]))\n\n    Convert it into a DGLGraph with structure only.\n\n    >>> g = dgl.from_networkx(nx_g)\n\n    Retrieve the node/edge features of the graph.\n\n    >>> g = dgl.from_networkx(nx_g, node_attrs=['feat1', 'feat2'], edge_attrs=['weight'])\n\n    Use a pre-specified ordering of the edges.\n\n    >>> g.edges()\n    (tensor([1, 2]), tensor([2, 1]))\n    >>> g = dgl.from_networkx(nx_g, edge_id_attr_name='eid')\n    >>> g.edges()\n    (tensor([2, 1]), tensor([1, 2]))\n\n    Create a graph on the first GPU with data type int32.\n\n    >>> g = dgl.from_networkx(nx_g, idtype=torch.int32, device='cuda:0')\n\n    See Also\n    --------\n    graph\n    from_scipy\n    \"\"\"\n    # Sanity check\n    if (\n        edge_id_attr_name is not None\n        and edge_id_attr_name not in next(iter(nx_graph.edges(data=True)))[-1]\n    ):\n        raise DGLError(\n            \"Failed to find the pre-specified edge IDs in the edge features of \"\n            
\"the NetworkX graph with name {}\".format(edge_id_attr_name)\n        )\n\n    if not nx_graph.is_directed() and not (\n        edge_id_attr_name is None and edge_attrs is None\n    ):\n        raise DGLError(\n            \"Expect edge_id_attr_name and edge_attrs to be None when nx_graph is \"\n            \"undirected, got {} and {}\".format(edge_id_attr_name, edge_attrs)\n        )\n\n    # Relabel nodes using consecutive integers starting from 0\n    nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering=\"sorted\")\n    if not nx_graph.is_directed():\n        nx_graph = nx_graph.to_directed()\n\n    (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(\n        nx_graph, idtype, edge_id_attr_name=edge_id_attr_name\n    )\n\n    g = create_from_edges(sparse_fmt, arrays, \"_N\", \"_E\", \"_N\", urange, vrange)\n\n    # nx_graph.edges(data=True) returns src, dst, attr_dict\n    has_edge_id = (\n        nx_graph.number_of_edges() > 0 and edge_id_attr_name is not None\n    )\n\n    # handle features\n    # copy attributes\n    if node_attrs is not None:\n        # mapping from feature name to a list of tensors to be concatenated\n        attr_dict = defaultdict(list)\n        for nid in range(g.num_nodes()):\n            for attr in node_attrs:\n                attr_dict[attr].append(nx_graph.nodes[nid][attr])\n        for attr in node_attrs:\n            g.ndata[attr] = F.copy_to(_batcher(attr_dict[attr]), g.device)\n\n    if edge_attrs is not None:\n        # mapping from feature name to a list of tensors to be concatenated\n        attr_dict = defaultdict(lambda: [None] * g.num_edges())\n        # each defaultdict value is initialized to be a list of None\n        # None here serves as placeholder to be replaced by feature with\n        # corresponding edge id\n        if has_edge_id:\n            num_edges = g.num_edges()\n            for _, _, attrs in nx_graph.edges(data=True):\n                if attrs[edge_id_attr_name] >= num_edges:\n  
                  raise DGLError(\n                        \"Expect the pre-specified edge ids to be\"\n                        \" smaller than the number of edges --\"\n                        \" {}, got {}.\".format(num_edges, attrs[\"id\"])\n                    )\n                for key in edge_attrs:\n                    attr_dict[key][attrs[edge_id_attr_name]] = attrs[key]\n        else:\n            # XXX: assuming networkx iteration order is deterministic\n            #      so the order is the same as graph_index.from_networkx\n            for eid, (_, _, attrs) in enumerate(nx_graph.edges(data=True)):\n                for key in edge_attrs:\n                    attr_dict[key][eid] = attrs[key]\n        for attr in edge_attrs:\n            for val in attr_dict[attr]:\n                if val is None:\n                    raise DGLError(\n                        \"Not all edges have attribute {}.\".format(attr)\n                    )\n            g.edata[attr] = F.copy_to(_batcher(attr_dict[attr]), g.device)\n\n    return g.to(device)\n\n\ndef bipartite_from_networkx(\n    nx_graph,\n    utype,\n    etype,\n    vtype,\n    u_attrs=None,\n    e_attrs=None,\n    v_attrs=None,\n    edge_id_attr_name=None,\n    idtype=None,\n    device=None,\n):\n    \"\"\"Create a unidirectional bipartite graph from a NetworkX graph and return.\n\n    The created graph will have two types of nodes ``utype`` and ``vtype`` as well as one\n    edge type ``etype`` whose edges are from ``utype`` to ``vtype``.\n\n    .. 
note::\n        Creating a DGLGraph from a NetworkX graph is slow, especially for large graphs.\n        It is recommended to first convert a NetworkX graph into a tuple of node-tensors\n        and then construct a DGLGraph with :func:`dgl.heterograph`.\n\n    Parameters\n    ----------\n    nx_graph : networkx.DiGraph\n        The NetworkX graph holding the graph structure and the node/edge attributes.\n        DGL will relabel the nodes using consecutive integers starting from zero if it is\n        not the case. The graph must follow `NetworkX's bipartite graph convention\n        <https://networkx.github.io/documentation/stable/reference/algorithms/bipartite.html>`_,\n        and furthermore the edges must be from nodes with attribute ``bipartite=0`` to nodes\n        with attribute ``bipartite=1``.\n    utype : str\n        The name of the source node type.\n    etype : str\n        The name of the edge type.\n    vtype : str\n        The name of the destination node type.\n    u_attrs : list[str], optional\n        The names of the node attributes for node type :attr:`utype` to retrieve from the\n        NetworkX graph. If given, DGL stores the retrieved node attributes in\n        ``nodes[utype].data`` of the returned graph using their original names. The attribute\n        data must be convertible to Tensor type (e.g., scalar, ``numpy.ndarray``, list, etc.).\n    e_attrs : list[str], optional\n        The names of the edge attributes to retrieve from the NetworkX graph. If given, DGL\n        stores the retrieved edge attributes in ``edata`` of the returned graph using their\n        original names. The attribute data must be convertible to Tensor type (e.g., scalar,\n        numpy.ndarray, list, etc.).\n    v_attrs : list[str], optional\n        The names of the node attributes for node type :attr:`vtype` to retrieve from the\n        NetworkX graph.  
If given, DGL stores the retrieved node attributes in\n        ``nodes[vtype].data`` of the returned graph using their original names. The attribute\n        data must be convertible to Tensor type (e.g., scalar, numpy.array, list, etc.).\n    edge_id_attr_name : str, optional\n        The name of the edge attribute that stores the edge IDs. If given, DGL will assign edge\n        IDs accordingly when creating the graph, so the attribute must be valid IDs, i.e.\n        consecutive integers starting from zero. By default, the edge IDs of the returned graph\n        can be arbitrary.\n    idtype : int32 or int64, optional\n        The data type for storing the structure-related graph information such as node and\n        edge IDs. It should be a framework-specific data type object (e.g., torch.int32).\n        By default, DGL uses int64.\n    device : device context, optional\n        The device of the resulting graph. It should be a framework-specific device object\n        (e.g., torch.device). 
By default, DGL stores the graph on CPU.\n\n    Returns\n    -------\n    DGLGraph\n        The created graph.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import networkx as nx\n    >>> import numpy as np\n    >>> import torch\n\n    Create a 2-edge unidirectional bipartite graph.\n\n    >>> nx_g = nx.DiGraph()\n    >>> # Add nodes for the source type\n    >>> nx_g.add_nodes_from([1, 3], bipartite=0, feat1=np.zeros((2, 1)), feat2=np.ones((2, 1)))\n    >>> # Add nodes for the destination type\n    >>> nx_g.add_nodes_from([2, 4, 5], bipartite=1, feat3=np.zeros((3, 1)))\n    >>> nx_g.add_edge(1, 4, weight=np.ones((1, 1)), eid=np.array([1]))\n    >>> nx_g.add_edge(3, 5, weight=np.ones((1, 1)), eid=np.array([0]))\n\n    Convert it into a DGLGraph with structure only.\n\n    >>> g = dgl.bipartite_from_networkx(nx_g, utype='_U', etype='_E', vtype='_V')\n\n    Retrieve the node/edge features of the graph.\n\n    >>> g = dgl.bipartite_from_networkx(nx_g, utype='_U', etype='_E', vtype='_V',\n    ...                                 u_attrs=['feat1', 'feat2'],\n    ...                                 e_attrs=['weight'],\n    ...                                 v_attrs=['feat3'])\n\n    Use a pre-specified ordering of the edges.\n\n    >>> g.edges()\n    (tensor([0, 1]), tensor([1, 2]))\n    >>> g = dgl.bipartite_from_networkx(nx_g,\n    ...                                 utype='_U', etype='_E', vtype='_V',\n    ...                                 edge_id_attr_name='eid')\n    (tensor([1, 0]), tensor([2, 1]))\n\n    Create a graph on the first GPU with data type int32.\n\n    >>> g = dgl.bipartite_from_networkx(nx_g, utype='_U', etype='_E', vtype='_V',\n    ...                                 
idtype=torch.int32, device='cuda:0')\n\n    See Also\n    --------\n    heterograph\n    bipartite_from_scipy\n    \"\"\"\n    if not nx_graph.is_directed():\n        raise DGLError(\"Expect nx_graph to be a directed NetworkX graph.\")\n    if (\n        edge_id_attr_name is not None\n        and not edge_id_attr_name in next(iter(nx_graph.edges(data=True)))[-1]\n    ):\n        raise DGLError(\n            \"Failed to find the pre-specified edge IDs in the edge features \"\n            \"of the NetworkX graph with name {}\".format(edge_id_attr_name)\n        )\n\n    # Get the source and destination node sets\n    top_nodes = set()\n    bottom_nodes = set()\n    for n, ndata in nx_graph.nodes(data=True):\n        if \"bipartite\" not in ndata:\n            raise DGLError(\n                \"Expect the node {} to have attribute bipartite\".format(n)\n            )\n        if ndata[\"bipartite\"] == 0:\n            top_nodes.add(n)\n        elif ndata[\"bipartite\"] == 1:\n            bottom_nodes.add(n)\n        else:\n            raise ValueError(\n                \"Expect the bipartite attribute of the node {} to be 0 or 1, \"\n                \"got {}\".format(n, ndata[\"bipartite\"])\n            )\n\n    # Separately relabel the source and destination nodes.\n    top_nodes = sorted(top_nodes)\n    bottom_nodes = sorted(bottom_nodes)\n    top_map = {n: i for i, n in enumerate(top_nodes)}\n    bottom_map = {n: i for i, n in enumerate(bottom_nodes)}\n\n    # Get the node tensors and the number of nodes\n    (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(\n        nx_graph,\n        idtype,\n        bipartite=True,\n        edge_id_attr_name=edge_id_attr_name,\n        top_map=top_map,\n        bottom_map=bottom_map,\n    )\n\n    g = create_from_edges(\n        sparse_fmt, arrays, utype, etype, vtype, urange, vrange\n    )\n\n    # nx_graph.edges(data=True) returns src, dst, attr_dict\n    has_edge_id = (\n        nx_graph.number_of_edges() > 0 
and edge_id_attr_name is not None\n    )\n\n    # handle features\n    # copy attributes\n    if u_attrs is not None:\n        # mapping from feature name to a list of tensors to be concatenated\n        src_attr_dict = defaultdict(list)\n        for nid in top_map.keys():\n            for attr in u_attrs:\n                src_attr_dict[attr].append(nx_graph.nodes[nid][attr])\n        for attr in u_attrs:\n            g.srcdata[attr] = F.copy_to(_batcher(src_attr_dict[attr]), g.device)\n\n    if v_attrs is not None:\n        # mapping from feature name to a list of tensors to be concatenated\n        dst_attr_dict = defaultdict(list)\n        for nid in bottom_map.keys():\n            for attr in v_attrs:\n                dst_attr_dict[attr].append(nx_graph.nodes[nid][attr])\n        for attr in v_attrs:\n            g.dstdata[attr] = F.copy_to(_batcher(dst_attr_dict[attr]), g.device)\n\n    if e_attrs is not None:\n        # mapping from feature name to a list of tensors to be concatenated\n        attr_dict = defaultdict(lambda: [None] * g.num_edges())\n        # each defaultdict value is initialized to be a list of None\n        # None here serves as placeholder to be replaced by feature with\n        # corresponding edge id\n        if has_edge_id:\n            for _, _, attrs in nx_graph.edges(data=True):\n                for key in e_attrs:\n                    attr_dict[key][attrs[edge_id_attr_name]] = attrs[key]\n        else:\n            # XXX: assuming networkx iteration order is deterministic\n            #      so the order is the same as graph_index.from_networkx\n            for eid, (_, _, attrs) in enumerate(nx_graph.edges(data=True)):\n                for key in e_attrs:\n                    attr_dict[key][eid] = attrs[key]\n        for attr in e_attrs:\n            for val in attr_dict[attr]:\n                if val is None:\n                    raise DGLError(\n                        \"Not all edges have attribute {}.\".format(attr)\n           
         )\n            g.edata[attr] = F.copy_to(_batcher(attr_dict[attr]), g.device)\n\n    return g.to(device)\n\n\ndef _to_networkx_homogeneous(g, node_attrs, edge_attrs):\n    # TODO: consider adding an eid_attr parameter as in\n    #  `_to_networkx_heterogeneous` when this function is properly tested\n    # (see GitHub issue #5735)\n    src, dst = g.edges()\n    src = F.asnumpy(src)\n    dst = F.asnumpy(dst)\n    # xiangsx: Always treat graph as multigraph\n    nx_graph = nx.MultiDiGraph()\n    nx_graph.add_nodes_from(range(g.num_nodes()))\n    for eid, (u, v) in enumerate(zip(src, dst)):\n        nx_graph.add_edge(u, v, id=eid)\n\n    if node_attrs is not None:\n        for nid, attr in nx_graph.nodes(data=True):\n            feat_dict = g._get_n_repr(0, nid)\n            attr.update(\n                {key: F.squeeze(feat_dict[key], 0) for key in node_attrs}\n            )\n    if edge_attrs is not None:\n        for _, _, attr in nx_graph.edges(data=True):\n            eid = attr[\"id\"]\n            feat_dict = g._get_e_repr(0, eid)\n            attr.update(\n                {key: F.squeeze(feat_dict[key], 0) for key in edge_attrs}\n            )\n    return nx_graph\n\n\ndef _to_networkx_heterogeneous(\n    g, node_attrs, edge_attrs, ntype_attr, etype_attr, eid_attr\n):\n    nx_graph = nx.MultiDiGraph()\n\n    # This implementation does not use `ndata` and `edata` in the call to\n    # `to_homogeneous` because the function expects node and edge attributes\n    # both to be defined for every type and to have the same shape.\n    # If the `to_homogeneous` function is updated to support non-uniform node\n    # and edge attributes, the implementation can be simplified.\n    hom_g = to_homogeneous(g, store_type=True, return_count=False)\n    ntypes = g.ntypes\n    etypes = g.canonical_etypes\n\n    for hom_nid, ndata in enumerate(zip(hom_g.ndata[NID], hom_g.ndata[NTYPE])):\n        orig_nid, ntype = ndata\n        attrs = {ntype_attr: ntypes[ntype]}\n\n        
if node_attrs is not None:\n            assert ntype_attr not in node_attrs, (\n                f\"'{ntype_attr}' already used as node type attribute, \"\n                f\"please provide a different value for ntype_attr\"\n            )\n\n            feat_dict = g._get_n_repr(ntype, orig_nid)\n            attrs.update(\n                {\n                    key: F.squeeze(feat_dict[key], 0)\n                    for key in node_attrs\n                    if key in feat_dict\n                }\n            )\n\n        nx_graph.add_node(hom_nid, **attrs)\n\n    for hom_eid, edata in enumerate(zip(hom_g.edata[EID], hom_g.edata[ETYPE])):\n        orig_eid, etype = edata\n        attrs = {eid_attr: hom_eid, etype_attr: etypes[etype]}\n\n        if edge_attrs is not None:\n            assert etype_attr not in edge_attrs, (\n                f\"'{etype_attr}' already used as edge type attribute, \"\n                f\"please provide a different value for etype_attr\"\n            )\n            assert eid_attr not in edge_attrs, (\n                f\"'{eid_attr}' already used as edge ID attribute, \"\n                f\"please provide a different value for eid_attr\"\n            )\n\n            feat_dict = g._get_e_repr(etype, orig_eid)\n            attrs.update(\n                {\n                    key: F.squeeze(feat_dict[key], 0)\n                    for key in edge_attrs\n                    if key in feat_dict\n                }\n            )\n\n        src, dst = hom_g.find_edges(hom_eid)\n        nx_graph.add_edge(int(src), int(dst), **attrs)\n\n    return nx_graph\n\n\ndef to_networkx(\n    g,\n    node_attrs=None,\n    edge_attrs=None,\n    ntype_attr=\"ntype\",\n    etype_attr=\"etype\",\n    eid_attr=\"id\",\n):\n    \"\"\"Convert a graph to a NetworkX graph and return.\n\n    The resulting NetworkX graph also contains the node/edge features of the input graph.\n    Additionally, DGL saves the edge IDs as the ``'id'`` edge attribute in the\n    
returned NetworkX graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        A homogeneous or heterogeneous graph.\n    node_attrs : iterable of str, optional\n        The node attributes to copy from ``g.ndata``. (Default: None)\n    edge_attrs : iterable of str, optional\n        The edge attributes to copy from ``g.edata``.\n        (Default: None)\n    ntype_attr : str, optional\n        The name of the node attribute to store the node types in the NetworkX object.\n        (Default: \"ntype\")\n    etype_attr : str, optional\n        The name of the edge attribute to store the edge canonical types in the NetworkX object.\n        (Default: \"etype\")\n    eid_attr : str, optional\n        The name of the edge attribute to store the original edge ID in the NetworkX object.\n        (Default: \"id\")\n\n    Returns\n    -------\n    networkx.MultiDiGraph\n        The converted NetworkX graph.\n\n    Notes\n    -----\n    The function only supports CPU graph input.\n\n    Examples\n    --------\n    The following examples use the PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    With a homogeneous graph:\n\n    >>> g = dgl.graph((torch.tensor([1, 2]), torch.tensor([1, 3])))\n    >>> g.ndata['h'] = torch.zeros(4, 1)\n    >>> g.edata['h1'] = torch.ones(2, 1)\n    >>> g.edata['h2'] = torch.zeros(2, 2)\n    >>> nx_g = dgl.to_networkx(g, node_attrs=['h'], edge_attrs=['h1', 'h2'])\n    >>> nx_g.nodes(data=True)\n    NodeDataView({\n        0: {'h': tensor([0.])},\n        1: {'h': tensor([0.])},\n        2: {'h': tensor([0.])},\n        3: {'h': tensor([0.])}\n    })\n    >>> nx_g.edges(data=True)\n    OutMultiEdgeDataView([\n        (1, 1, {'id': 0, 'h1': tensor([1.]), 'h2': tensor([0., 0.])}),\n        (2, 3, {'id': 1, 'h1': tensor([1.]), 'h2': tensor([0., 0.])})\n    ])\n\n    With a heterogeneous graph:\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n    ...     
('user', 'follows', 'topic'): (torch.tensor([1, 1]), torch.tensor([1, 2])),\n    ...     ('user', 'plays', 'game'): (torch.tensor([0, 3]), torch.tensor([3, 4]))\n    ... })\n    >>> g.ndata['n'] = {\n    ...     'game': torch.zeros(5, 1),\n    ...     'user': torch.ones(4, 1)\n    ... }\n    >>> g.edata['e'] = {\n    ...     ('user', 'follows', 'user'): torch.zeros(2, 1),\n    ...     'plays': torch.ones(2, 1)\n    ... }\n    >>> nx_g = dgl.to_networkx(g, node_attrs=['n'], edge_attrs=['e'])\n    >>> nx_g.nodes(data=True)\n    NodeDataView({\n        0: {'ntype': 'game', 'n': tensor([0.])},\n        1: {'ntype': 'game', 'n': tensor([0.])},\n        2: {'ntype': 'game', 'n': tensor([0.])},\n        3: {'ntype': 'game', 'n': tensor([0.])},\n        4: {'ntype': 'game', 'n': tensor([0.])},\n        5: {'ntype': 'topic'},\n        6: {'ntype': 'topic'},\n        7: {'ntype': 'topic'},\n        8: {'ntype': 'user', 'n': tensor([1.])},\n        9: {'ntype': 'user', 'n': tensor([1.])},\n        10: {'ntype': 'user', 'n': tensor([1.])},\n        11: {'ntype': 'user', 'n': tensor([1.])}\n    })\n    >>> nx_g.edges(data=True)\n    OutMultiEdgeDataView([\n        (8, 9, {'id': 2, 'etype': ('user', 'follows', 'user'), 'e': tensor([0.])}),\n        (8, 3, {'id': 4, 'etype': ('user', 'plays', 'game'), 'e': tensor([1.])}),\n        (9, 6, {'id': 0, 'etype': ('user', 'follows', 'topic')}),\n        (9, 7, {'id': 1, 'etype': ('user', 'follows', 'topic')}),\n        (9, 10, {'id': 3, 'etype': ('user', 'follows', 'user'), 'e': tensor([0.])}),\n        (11, 4, {'id': 5, 'etype': ('user', 'plays', 'game'), 'e': tensor([1.])})\n    ])\n    \"\"\"\n    if g.device != F.cpu():\n        raise DGLError(\n            \"Cannot convert a CUDA graph to networkx. 
Call g.cpu() first.\"\n        )\n    if g.is_homogeneous:\n        return _to_networkx_homogeneous(g, node_attrs, edge_attrs)\n    else:\n        return _to_networkx_heterogeneous(\n            g, node_attrs, edge_attrs, ntype_attr, etype_attr, eid_attr\n        )\n\n\nDGLGraph.to_networkx = to_networkx\n\n\ndef to_cugraph(g):\n    \"\"\"Convert a DGL graph to a :class:`cugraph.Graph` and return.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        A homogeneous graph.\n\n    Returns\n    -------\n    cugraph.Graph\n        The converted cugraph graph.\n\n    Notes\n    -----\n    The function only supports GPU graph input.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import cugraph\n    >>> import torch\n\n    >>> g = dgl.graph((torch.tensor([1, 2]), torch.tensor([1, 3]))).to('cuda')\n    >>> cugraph_g = g.to_cugraph()\n    >>> cugraph_g.edges()\n        src  dst\n    0    2    3\n    1    1    1\n    \"\"\"\n\n    if g.device.type != \"cuda\":\n        raise DGLError(\n            f\"Cannot convert a {g.device.type} graph to cugraph.\"\n            + \"Call g.to('cuda') first.\"\n        )\n    if not g.is_homogeneous:\n        raise DGLError(\"dgl.to_cugraph only supports homogeneous graphs.\")\n\n    try:\n        import cudf\n        import cugraph\n    except ModuleNotFoundError:\n        raise ModuleNotFoundError(\n            \"to_cugraph requires cugraph which could not be imported\"\n        )\n\n    edgelist = g.edges()\n    src_ser = cudf.from_dlpack(F.zerocopy_to_dlpack(edgelist[0]))\n    dst_ser = cudf.from_dlpack(F.zerocopy_to_dlpack(edgelist[1]))\n    cudf_data = cudf.DataFrame({\"source\": src_ser, \"destination\": dst_ser})\n    g_cugraph = cugraph.Graph(directed=True)\n    g_cugraph.from_cudf_edgelist(\n        cudf_data, source=\"source\", destination=\"destination\"\n    )\n    return g_cugraph\n\n\nDGLGraph.to_cugraph = to_cugraph\n\n\ndef from_cugraph(cugraph_graph):\n  
  \"\"\"Create a graph from a :class:`cugraph.Graph` object.\n\n    Parameters\n    ----------\n    cugraph_graph : cugraph.Graph\n        The cugraph graph object holding the graph structure. Node and edge attributes are\n        dropped.\n\n        If the input graph is undirected, DGL converts it to a directed graph\n        by :func:`cugraph.Graph.to_directed`.\n\n    Returns\n    -------\n    DGLGraph\n        The created graph.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import cugraph\n    >>> import cudf\n\n    Create a cugraph graph.\n    >>> cugraph_g = cugraph.Graph(directed=True)\n    >>> df = cudf.DataFrame({\"source\":[0, 1, 2, 3],\n                     \"destination\":[1, 2, 3, 0]})\n    >>> cugraph_g.from_cudf_edgelist(df)\n\n    Convert it into a DGLGraph\n    >>> g = dgl.from_cugraph(cugraph_g)\n    >>> g.edges()\n    (tensor([1, 2, 3, 0], device='cuda:0'), tensor([2, 3, 0, 1], device='cuda:0'))\n    \"\"\"\n    if not cugraph_graph.is_directed():\n        cugraph_graph = cugraph_graph.to_directed()\n\n    edges = cugraph_graph.edges()\n    src_t = F.zerocopy_from_dlpack(edges[\"src\"].to_dlpack())\n    dst_t = F.zerocopy_from_dlpack(edges[\"dst\"].to_dlpack())\n    g = graph((src_t, dst_t))\n\n    return g\n\n\n############################################################\n# Internal APIs\n############################################################\n\n\ndef create_from_edges(\n    sparse_fmt,\n    arrays,\n    utype,\n    etype,\n    vtype,\n    urange,\n    vrange,\n    row_sorted=False,\n    col_sorted=False,\n):\n    \"\"\"Internal function to create a graph from incident nodes with types.\n\n    utype could be equal to vtype\n\n    Parameters\n    ----------\n    sparse_fmt : str\n        The sparse adjacency matrix format.\n    arrays : tuple[Tensor]\n        The sparse adjacency matrix arrays.\n    utype : str\n        Source node type name.\n    etype : str\n        Edge 
type name.\n    vtype : str\n        Destination node type name.\n    urange : int, optional\n        The source node ID range. If None, the value is the maximum\n        of the source node IDs in the edge list plus 1. (Default: None)\n    vrange : int, optional\n        The destination node ID range. If None, the value is the\n        maximum of the destination node IDs in the edge list plus 1. (Default: None)\n    row_sorted : bool, optional\n        Whether or not the rows of the COO are in ascending order.\n    col_sorted : bool, optional\n        Whether or not the columns of the COO are in ascending order within\n        each row. This only has an effect when ``row_sorted`` is True.\n\n\n    Returns\n    -------\n    DGLGraph\n    \"\"\"\n    if utype == vtype:\n        num_ntypes = 1\n    else:\n        num_ntypes = 2\n\n    if sparse_fmt == \"coo\":\n        u, v = arrays\n        hgidx = heterograph_index.create_unitgraph_from_coo(\n            num_ntypes,\n            urange,\n            vrange,\n            u,\n            v,\n            [\"coo\", \"csr\", \"csc\"],\n            row_sorted,\n            col_sorted,\n        )\n    else:  # 'csr' or 'csc'\n        indptr, indices, eids = arrays\n        hgidx = heterograph_index.create_unitgraph_from_csr(\n            num_ntypes,\n            urange,\n            vrange,\n            indptr,\n            indices,\n            eids,\n            [\"coo\", \"csr\", \"csc\"],\n            sparse_fmt == \"csc\",\n        )\n    if utype == vtype:\n        return DGLGraph(hgidx, [utype], [etype])\n    else:\n        return DGLGraph(hgidx, [utype, vtype], [etype])\n"
  },
  {
    "path": "python/dgl/core.py",
    "content": "\"\"\"Implementation for core graph computation.\"\"\"\n# pylint: disable=not-callable\nimport numpy as np\n\nfrom . import backend as F, function as fn, ops\nfrom .base import ALL, dgl_warning, DGLError, EID, is_all, NID\nfrom .frame import Frame\nfrom .udf import EdgeBatch, NodeBatch\n\n\ndef is_builtin(func):\n    \"\"\"Return true if the function is a DGL builtin function.\"\"\"\n    return isinstance(func, fn.BuiltinFunction)\n\n\ndef invoke_node_udf(graph, nid, ntype, func, *, ndata=None, orig_nid=None):\n    \"\"\"Invoke user-defined node function on the given nodes.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The input graph.\n    nid : Tensor\n        The IDs of the nodes to invoke UDF on.\n    ntype : str\n        Node type.\n    func : callable\n        The user-defined function.\n    ndata : dict[str, Tensor], optional\n        If provided, apply the UDF on this ndata instead of the ndata of the graph.\n    orig_nid : Tensor, optional\n        Original node IDs. 
Useful if the input graph is an extracted subgraph.\n\n    Returns\n    -------\n    dict[str, Tensor]\n        Results from running the UDF.\n    \"\"\"\n    ntid = graph.get_ntype_id(ntype)\n    if ndata is None:\n        if is_all(nid):\n            ndata = graph._node_frames[ntid]\n            nid = graph.nodes(ntype=ntype)\n        else:\n            ndata = graph._node_frames[ntid].subframe(nid)\n    nbatch = NodeBatch(\n        graph, nid if orig_nid is None else orig_nid, ntype, ndata\n    )\n    return func(nbatch)\n\n\ndef invoke_edge_udf(graph, eid, etype, func, *, orig_eid=None):\n    \"\"\"Invoke user-defined edge function on the given edges.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The input graph.\n    eid : Tensor\n        The IDs of the edges to invoke UDF on.\n    etype : (str, str, str)\n        Edge type.\n    func : callable\n        The user-defined function.\n    orig_eid : Tensor, optional\n        Original edge IDs. Useful if the input graph is an extracted subgraph.\n\n    Returns\n    -------\n    dict[str, Tensor]\n        Results from running the UDF.\n    \"\"\"\n    etid = graph.get_etype_id(etype)\n    stid, dtid = graph._graph.metagraph.find_edge(etid)\n    if is_all(eid):\n        u, v, eid = graph.edges(form=\"all\")\n        edata = graph._edge_frames[etid]\n    else:\n        u, v = graph.find_edges(eid)\n        edata = graph._edge_frames[etid].subframe(eid)\n    if len(u) == 0:\n        dgl_warning(\n            \"The input graph for the user-defined edge function \"\n            \"does not contain valid edges\"\n        )\n    srcdata = graph._node_frames[stid].subframe(u)\n    dstdata = graph._node_frames[dtid].subframe(v)\n    ebatch = EdgeBatch(\n        graph,\n        eid if orig_eid is None else orig_eid,\n        etype,\n        srcdata,\n        edata,\n        dstdata,\n    )\n    return func(ebatch)\n\n\ndef invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):\n    \"\"\"Invoke 
user-defined reduce function on all the nodes in the graph.\n\n    It analyzes the graph, groups nodes by their degrees and applies the UDF on each\n    group -- a strategy called *degree-bucketing*.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The input graph.\n    func : callable\n        The user-defined function.\n    msgdata : dict[str, Tensor]\n        Message data.\n    orig_nid : Tensor, optional\n        Original node IDs. Useful if the input graph is an extracted subgraph.\n\n    Returns\n    -------\n    dict[str, Tensor]\n        Results from running the UDF.\n    \"\"\"\n    degs = graph.in_degrees()\n    nodes = graph.dstnodes()\n    if orig_nid is None:\n        orig_nid = nodes\n    ntype = graph.dsttypes[0]\n    ntid = graph.get_ntype_id_from_dst(ntype)\n    dstdata = graph._node_frames[ntid]\n    msgdata = Frame(msgdata)\n\n    # degree bucketing\n    unique_degs, bucketor = _bucketing(degs)\n    bkt_rsts = []\n    bkt_nodes = []\n    for deg, node_bkt, orig_nid_bkt in zip(\n        unique_degs, bucketor(nodes), bucketor(orig_nid)\n    ):\n        if deg == 0:\n            # skip reduce function for zero-degree nodes\n            continue\n        bkt_nodes.append(node_bkt)\n        ndata_bkt = dstdata.subframe(node_bkt)\n\n        # order the incoming edges per node by edge ID\n        eid_bkt = F.zerocopy_to_numpy(graph.in_edges(node_bkt, form=\"eid\"))\n        assert len(eid_bkt) == deg * len(node_bkt)\n        eid_bkt = np.sort(eid_bkt.reshape((len(node_bkt), deg)), 1)\n        eid_bkt = F.zerocopy_from_numpy(eid_bkt.flatten())\n\n        msgdata_bkt = msgdata.subframe(eid_bkt)\n        # reshape all msg tensors to (num_nodes_bkt, degree, feat_size)\n        maildata = {}\n        for k, msg in msgdata_bkt.items():\n            newshape = (len(node_bkt), deg) + F.shape(msg)[1:]\n            maildata[k] = F.reshape(msg, newshape)\n        # invoke udf\n        nbatch = NodeBatch(graph, orig_nid_bkt, ntype, ndata_bkt, 
msgs=maildata)\n        bkt_rsts.append(func(nbatch))\n\n    # prepare a result frame\n    retf = Frame(num_rows=len(nodes))\n    retf._initializers = dstdata._initializers\n    retf._default_initializer = dstdata._default_initializer\n\n    # merge bucket results and write to the result frame\n    if (\n        len(bkt_rsts) != 0\n    ):  # if all the nodes have zero degree, no need to merge results.\n        merged_rst = {}\n        for k in bkt_rsts[0].keys():\n            merged_rst[k] = F.cat([rst[k] for rst in bkt_rsts], dim=0)\n        merged_nodes = F.cat(bkt_nodes, dim=0)\n        retf.update_row(merged_nodes, merged_rst)\n\n    return retf\n\n\ndef _bucketing(val):\n    \"\"\"Internal function to create groups on the values.\n\n    Parameters\n    ----------\n    val : Tensor\n        Value tensor.\n\n    Returns\n    -------\n    unique_val : Tensor\n        Unique values.\n    bucketor : callable[Tensor -> list[Tensor]]\n        A bucketing function that splits the given tensor data as the same\n        way of how the :attr:`val` tensor is grouped.\n    \"\"\"\n    sorted_val, idx = F.sort_1d(val)\n    unique_val = F.asnumpy(F.unique(sorted_val))\n    bkt_idx = []\n    for v in unique_val:\n        eqidx = F.nonzero_1d(F.equal(sorted_val, v))\n        bkt_idx.append(F.gather_row(idx, eqidx))\n\n    def bucketor(data):\n        bkts = [F.gather_row(data, idx) for idx in bkt_idx]\n        return bkts\n\n    return unique_val, bucketor\n\n\ndef data_dict_to_list(graph, data_dict, func, target):\n    \"\"\"Get node or edge feature data of the given name for all the types.\n\n    Parameters\n    -------------\n    graph :  DGLGraph\n        The input graph.\n    data_dict : dict[str, Tensor] or dict[(str, str, str), Tensor]] or Tensor\n        Node or edge data stored in DGLGraph. The key of the dictionary\n        is the node type name or edge type name. 
If there is only single source\n        node type, data_dict is the value of feature(a Tensor) not a dict.\n    func : dgl.function.BaseMessageFunction\n        Built-in message function.\n    target : 'u', 'v' or 'e'\n        The target of the lhs or rhs data\n\n    Returns\n    --------\n    data_list : list(Tensor)\n        Feature data stored in a list of tensors. The i^th tensor stores the feature\n        data of type ``types[i]``.\n    \"\"\"\n    if isinstance(func, fn.BinaryMessageFunction):\n        if target in [\"u\", \"v\"]:\n            output_list = [None] * graph._graph.number_of_ntypes()\n            # If there is only single source node type, data_dict should be the value of\n            # feature, namely, a tensor.\n            if not isinstance(data_dict, dict):\n                src_id, dst_id = graph._graph.metagraph.find_edge(0)\n                if target == \"u\":\n                    output_list[src_id] = data_dict\n                else:\n                    output_list[dst_id] = data_dict\n            else:\n                for srctype, _, dsttype in graph.canonical_etypes:\n                    if target == \"u\":\n                        src_id = graph.get_ntype_id(srctype)\n                        output_list[src_id] = data_dict[srctype]\n                    else:\n                        dst_id = graph.get_ntype_id(dsttype)\n                        output_list[dst_id] = data_dict[dsttype]\n        else:  # target == 'e'\n            output_list = [None] * graph._graph.number_of_etypes()\n            for rel in graph.canonical_etypes:\n                etid = graph.get_etype_id(rel)\n                output_list[etid] = data_dict[rel]\n        return output_list\n    else:\n        if target == \"u\":\n            lhs_list = [None] * graph._graph.number_of_ntypes()\n            if not isinstance(data_dict, dict):\n                src_id, _ = graph._graph.metagraph.find_edge(0)\n                lhs_list[src_id] = data_dict\n            
else:\n                for srctype, _, _ in graph.canonical_etypes:\n                    src_id = graph.get_ntype_id(srctype)\n                    lhs_list[src_id] = data_dict[srctype]\n            return lhs_list\n        else:  # target == 'e':\n            rhs_list = [None] * graph._graph.number_of_etypes()\n            for rel in graph.canonical_etypes:\n                etid = graph.get_etype_id(rel)\n                rhs_list[etid] = data_dict[rel]\n            return rhs_list\n\n\ndef invoke_gsddmm(graph, func):\n    \"\"\"Invoke g-SDDMM computation on the graph.\n\n    Parameters\n    ----------\n    graph :  DGLGraph\n        The input graph.\n    func : dgl.function.BaseMessageFunction\n        Built-in message function.\n\n    Returns\n    -------\n    dict[str, Tensor]\n        Results from the g-SDDMM computation.\n    \"\"\"\n    alldata = [graph.srcdata, graph.dstdata, graph.edata]\n    if isinstance(func, fn.BinaryMessageFunction):\n        x = alldata[func.lhs][func.lhs_field]\n        y = alldata[func.rhs][func.rhs_field]\n        op = getattr(ops, func.name)\n        if graph._graph.number_of_etypes() > 1:\n            lhs_target, _, rhs_target = func.name.split(\"_\", 2)\n            x = data_dict_to_list(graph, x, func, lhs_target)\n            y = data_dict_to_list(graph, y, func, rhs_target)\n        z = op(graph, x, y)\n    else:\n        x = alldata[func.target][func.in_field]\n        op = getattr(ops, func.name)\n        if graph._graph.number_of_etypes() > 1:\n            # Convert to list as dict is unordered.\n            if func.name == \"copy_u\":\n                x = data_dict_to_list(graph, x, func, \"u\")\n            else:  # \"copy_e\"\n                x = data_dict_to_list(graph, x, func, \"e\")\n        z = op(graph, x)\n    return {func.out_field: z}\n\n\ndef invoke_gspmm(\n    graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None\n):\n    \"\"\"Invoke g-SPMM computation on the graph.\n\n    Parameters\n    
----------\n    graph :  DGLGraph\n        The input graph.\n    mfunc : dgl.function.BaseMessageFunction\n        Built-in message function.\n    rfunc : dgl.function.BaseReduceFunction\n        Built-in reduce function.\n    srcdata : dict[str, Tensor], optional\n        Source node feature data. If not provided, it use ``graph.srcdata``.\n    dstdata : dict[str, Tensor], optional\n        Destination node feature data. If not provided, it use ``graph.dstdata``.\n    edata : dict[str, Tensor], optional\n        Edge feature data. If not provided, it use ``graph.edata``.\n\n    Returns\n    -------\n    dict[str, Tensor]\n        Results from the g-SPMM computation.\n    \"\"\"\n    # sanity check\n    if mfunc.out_field != rfunc.msg_field:\n        raise DGLError(\n            \"Invalid message ({}) and reduce ({}) function pairs.\"\n            \" The output field of the message function must be equal to the\"\n            \" message field of the reduce function.\".format(mfunc, rfunc)\n        )\n    if edata is None:\n        edata = graph.edata\n    if srcdata is None:\n        srcdata = graph.srcdata\n    if dstdata is None:\n        dstdata = graph.dstdata\n    alldata = [srcdata, dstdata, edata]\n\n    if isinstance(mfunc, fn.BinaryMessageFunction):\n        x = alldata[mfunc.lhs][mfunc.lhs_field]\n        y = alldata[mfunc.rhs][mfunc.rhs_field]\n        op = getattr(ops, \"{}_{}\".format(mfunc.name, rfunc.name))\n        if graph._graph.number_of_etypes() > 1:\n            lhs_target, _, rhs_target = mfunc.name.split(\"_\", 2)\n            x = data_dict_to_list(graph, x, mfunc, lhs_target)\n            y = data_dict_to_list(graph, y, mfunc, rhs_target)\n        z = op(graph, x, y)\n    else:\n        x = alldata[mfunc.target][mfunc.in_field]\n        op = getattr(ops, \"{}_{}\".format(mfunc.name, rfunc.name))\n        if graph._graph.number_of_etypes() > 1 and not isinstance(x, tuple):\n            if mfunc.name == \"copy_u\":\n                x = 
data_dict_to_list(graph, x, mfunc, \"u\")\n            else:  # \"copy_e\"\n                x = data_dict_to_list(graph, x, mfunc, \"e\")\n        z = op(graph, x)\n    return {rfunc.out_field: z}\n\n\ndef message_passing(g, mfunc, rfunc, afunc):\n    \"\"\"Invoke message passing computation on the whole graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    mfunc : callable or dgl.function.BuiltinFunction\n        Message function.\n    rfunc : callable or dgl.function.BuiltinFunction\n        Reduce function.\n    afunc : callable or dgl.function.BuiltinFunction\n        Apply function.\n\n    Returns\n    -------\n    dict[str, Tensor]\n        Results from the message passing computation.\n    \"\"\"\n    if (\n        is_builtin(mfunc)\n        and is_builtin(rfunc)\n        and getattr(ops, \"{}_{}\".format(mfunc.name, rfunc.name), None)\n        is not None\n    ):\n        # invoke fused message passing\n        ndata = invoke_gspmm(g, mfunc, rfunc)\n    else:\n        # invoke message passing in two separate steps\n        # message phase\n        if is_builtin(mfunc):\n            msgdata = invoke_gsddmm(g, mfunc)\n        else:\n            orig_eid = g.edata.get(EID, None)\n            msgdata = invoke_edge_udf(\n                g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid\n            )\n        # reduce phase\n        if is_builtin(rfunc):\n            msg = rfunc.msg_field\n            ndata = invoke_gspmm(g, fn.copy_e(msg, msg), rfunc, edata=msgdata)\n        else:\n            orig_nid = g.dstdata.get(NID, None)\n            ndata = invoke_udf_reduce(g, rfunc, msgdata, orig_nid=orig_nid)\n    # apply phase\n    if afunc is not None:\n        for k, v in g.dstdata.items():  # include original node features\n            if k not in ndata:\n                ndata[k] = v\n        orig_nid = g.dstdata.get(NID, None)\n        ndata = invoke_node_udf(\n            g, ALL, g.dsttypes[0], afunc, ndata=ndata, 
orig_nid=orig_nid\n        )\n    return ndata\n"
  },
  {
    "path": "python/dgl/cuda/__init__.py",
    "content": "\"\"\" CUDA wrappers \"\"\"\nfrom .. import backend as F\n\nfrom .gpu_cache import GPUCache\n\nif F.get_preferred_backend() == \"pytorch\":\n    from . import nccl\n"
  },
  {
    "path": "python/dgl/cuda/gpu_cache.py",
    "content": "\"\"\"API wrapping HugeCTR gpu_cache.\"\"\"\n#    Copyright (c) 2022, NVIDIA Corporation\n#    All rights reserved.\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n#\n#  @file gpu_cache.py\n#  @brief API for managing a GPU Cache\n\nfrom .. import backend as F\nfrom .._ffi.function import _init_api\n\n\nclass GPUCache(object):\n    \"\"\"High-level wrapper for GPU embedding cache\"\"\"\n\n    def __init__(self, num_items, num_feats, idtype=F.int64):\n        assert idtype in [F.int32, F.int64]\n        self._cache = _CAPI_DGLGpuCacheCreate(\n            num_items, num_feats, 32 if idtype == F.int32 else 64\n        )\n        self.idtype = idtype\n        self.total_miss = 0\n        self.total_queries = 0\n\n    def query(self, keys):\n        \"\"\"Queries the GPU cache.\n\n        Parameters\n        ----------\n        keys : Tensor\n            The keys to query the GPU cache with.\n\n        Returns\n        -------\n        tuple(Tensor, Tensor, Tensor)\n            A tuple containing (values, missing_indices, missing_keys) where\n            values[missing_indices] corresponds to cache misses that should be\n            filled by quering another source with missing_keys.\n        \"\"\"\n        self.total_queries += keys.shape[0]\n        keys = F.astype(keys, self.idtype)\n        values, missing_index, missing_keys = _CAPI_DGLGpuCacheQuery(\n            self._cache, F.to_dgl_nd(keys)\n        )\n        
self.total_miss += missing_keys.shape[0]\n        return (\n            F.from_dgl_nd(values),\n            F.from_dgl_nd(missing_index),\n            F.from_dgl_nd(missing_keys),\n        )\n\n    def replace(self, keys, values):\n        \"\"\"Inserts key-value pairs into the GPU cache using the Least-Recently\n        Used (LRU) algorithm to remove old key-value pairs if it is full.\n\n        Parameters\n        ----------\n        keys: Tensor\n            The keys to insert to the GPU cache.\n        values: Tensor\n            The values to insert to the GPU cache.\n        \"\"\"\n        keys = F.astype(keys, self.idtype)\n        values = F.astype(values, F.float32)\n        _CAPI_DGLGpuCacheReplace(\n            self._cache, F.to_dgl_nd(keys), F.to_dgl_nd(values)\n        )\n\n    @property\n    def miss_rate(self):\n        \"\"\"Returns the cache miss rate since creation.\"\"\"\n        return self.total_miss / self.total_queries\n\n\n_init_api(\"dgl.cuda\", __name__)\n"
  },
  {
    "path": "python/dgl/cuda/nccl.py",
    "content": "\"\"\"API wrapping NCCL primitives.\"\"\"\n\nimport torch\nimport torch.distributed as dist\n\n\ndef sparse_all_to_all_push(idx, value, partition):\n    \"\"\"Perform an all-to-all-v operation, where by all processors send out\n    a set of indices and corresponding values. Indices and values,\n    corresponding to the current process, will copied into the output\n    arrays.\n\n    Note: This method requires 'torch.distributed.get_backend() == \"nccl\"'.\n\n    Parameters\n    ----------\n    idx : torch.Tensor\n        The 1D set of indices to send to other processors.\n    value : torch.Tensor\n        The multi-dimension set of values to send to other processors.\n        The first dimension must match that of `idx`.\n    partition : NDArrayPartition\n        The object containing information for assigning indices to\n        processors.\n\n    Returns\n    -------\n    torch.Tensor\n        The 1D tensor of the recieved indices.\n    torch.Tensor\n        The set of recieved values.\n\n    Examples\n    --------\n\n    To perform a sparse_all_to_all_push(), a partition object must be\n    provided. A partition of a homgeonous graph, where the vertices are\n    striped across processes can be generated via:\n\n    >>> from dgl.partition import NDArrayPartition\n    >>> part = NDArrayPartition(g.num_nodes(), world_size, mode='remainder')\n\n    With this partition, each processor can send values to be associatd\n    with vertices in the graph. 
So if we have an array `global_idxs` of all of\n    the neighbors updated during mini-batch processing, and an array\n    `global_values` containing the new values associated with the neighbors,\n    we communicate them to the owning processes via:\n\n    >>> my_idxs, my_values = nccl.sparse_all_to_all_push(global_idxs, global_values, part)\n\n    This communication pattern is common when communicating gradient\n    updates for node embeddings.\n\n    Indices the current process owns, do not need to be treated specially,\n    as internally they will be copied to the output array. If we have a\n    set of indices in process 0 '[0, 3, 8, 9, 10]' and for process 1\n    '[0, 2, 4, 5, 8, 8, 9]'. Using a remainder partition will result\n    in indices for process 0 of '[0, 8, 10, 0, 2, 4, 8, 8]', and for\n    process 1 of '[3, 9, 5, 9]'.\n    \"\"\"\n    if not dist.is_initialized() or dist.get_world_size() == 1:\n        return idx, value\n    assert (\n        dist.get_backend() == \"nccl\"\n    ), \"requires NCCL backend to communicate CUDA tensors.\"\n\n    perm, send_splits = partition.generate_permutation(idx)\n    perm = perm.long()\n\n    # Get receive splits.\n    recv_splits = torch.empty_like(send_splits)\n    dist.all_to_all_single(recv_splits, send_splits)\n\n    # Use pinned memory to speedup D2H copy.\n    recv_splits = recv_splits.to(\"cpu\", non_blocking=True)\n    send_splits = send_splits.to(\"cpu\", non_blocking=True)\n    send_idx = idx[perm]\n    send_value = value[perm]\n    # Wait D2H copy finish.\n    torch.cuda.current_stream().synchronize()\n    recv_sum = recv_splits.sum()\n    recv_splits = recv_splits.tolist()\n    send_splits = send_splits.tolist()\n\n    # Send idx.\n    recv_idx = torch.empty((recv_sum,), dtype=idx.dtype, device=idx.device)\n    dist.all_to_all_single(recv_idx, send_idx, recv_splits, send_splits)\n\n    # Send value.\n    recv_value = torch.empty(\n        (recv_sum, *value.shape[1:]), dtype=value.dtype, device=value.device\n    
)\n    dist.all_to_all_single(recv_value, send_value, recv_splits, send_splits)\n\n    return recv_idx, recv_value\n\n\ndef sparse_all_to_all_pull(req_idx, value, partition):\n    \"\"\"Perform an all-to-all-v operation, whereby all processors request\n    the values corresponding to their set of indices.\n\n    Note: This method requires 'torch.distributed.get_backend() == \"nccl\"'.\n\n    Parameters\n    ----------\n    req_idx : torch.Tensor\n        The set of indices this processor is requesting.\n    value : torch.Tensor\n        The multi-dimension set of values that can be requested from\n        this processor.\n    partition : NDArrayPartition\n        The object containing information for assigning indices to\n        processors.\n\n    Returns\n    -------\n    torch.Tensor\n        The set of received values, corresponding to `req_idx`.\n\n    Examples\n    --------\n\n    To perform a sparse_all_to_all_pull(), a partition object must be\n    provided. A partition of a homogeneous graph, where the vertices are\n    striped across processes can be generated via:\n\n    >>> from dgl.partition import NDArrayPartition\n    >>> part = NDArrayPartition(g.num_nodes(), world_size, mode='remainder')\n\n    With this partition, each processor can request values/features\n    associated with vertices in the graph. So in the case where we have\n    a set of neighbors 'nbr_idxs' we need features for, and each process\n    has a tensor 'node_feat' storing the features of nodes it owns in\n    the partition, the features can be requested via:\n\n    >>> nbr_values = nccl.sparse_all_to_all_pull(nbr_idxs, node_feat, part)\n\n    Then the two arrays 'nbr_idxs' and 'nbr_values' form the sparse\n    set of features, where 'nbr_idxs[i]' is the global node id, and\n    'nbr_values[i]' is the feature vector for that node. 
This\n    communication pattern is useful for node features or node\n    embeddings.\n    \"\"\"\n    if not dist.is_initialized() or dist.get_world_size() == 1:\n        return value[req_idx.long()]\n    assert (\n        dist.get_backend() == \"nccl\"\n    ), \"requires NCCL backend to communicate CUDA tensors.\"\n\n    perm, req_splits = partition.generate_permutation(req_idx)\n    perm = perm.long()\n\n    # Get response splits.\n    resp_splits = torch.empty_like(req_splits)\n    dist.all_to_all_single(resp_splits, req_splits)\n\n    # Use pinned memory to speedup D2H copy.\n    resp_splits = resp_splits.to(\"cpu\", non_blocking=True)\n    req_splits = req_splits.to(\"cpu\", non_blocking=True)\n    req_idx = req_idx[perm]\n    # Wait D2H copy finish.\n    torch.cuda.current_stream().synchronize()\n    resp_sum = resp_splits.sum()\n    resp_splits = resp_splits.tolist()\n    req_splits = req_splits.tolist()\n\n    # Gather requested indices.\n    resp_idx = torch.empty(\n        (resp_sum,), dtype=req_idx.dtype, device=req_idx.device\n    )\n    dist.all_to_all_single(resp_idx, req_idx, resp_splits, req_splits)\n\n    # Convert requested indices to local indices depending on partition.\n    if resp_sum > 0:\n        resp_idx = partition.map_to_local(resp_idx)\n\n    # Collect the request value.\n    req_value = torch.empty(\n        (req_idx.size(0), *value.shape[1:]),\n        dtype=value.dtype,\n        device=value.device,\n    )\n    dist.all_to_all_single(req_value, value[resp_idx], req_splits, resp_splits)\n\n    # Permute the value back into the requested order.\n    return_value = torch.empty_like(req_value)\n    return_value[perm] = req_value\n\n    return return_value\n"
  },
  {
    "path": "python/dgl/data/__init__.py",
    "content": "\"\"\"The ``dgl.data`` package contains datasets hosted by DGL and also utilities\nfor downloading, processing, saving and loading data from external resources.\n\"\"\"\n\nfrom __future__ import absolute_import\n\nfrom . import citation_graph as citegrh\nfrom .actor import ActorDataset\nfrom .movielens import MovieLensDataset\nfrom .adapter import *\nfrom .bitcoinotc import BitcoinOTC, BitcoinOTCDataset\nfrom .citation_graph import (\n    CitationGraphDataset,\n    CiteseerGraphDataset,\n    CoraBinary,\n    CoraGraphDataset,\n    PubmedGraphDataset,\n)\nfrom .csv_dataset import CSVDataset\nfrom .dgl_dataset import DGLBuiltinDataset, DGLDataset\nfrom .fakenews import FakeNewsDataset\nfrom .flickr import FlickrDataset\nfrom .fraud import FraudAmazonDataset, FraudDataset, FraudYelpDataset\nfrom .gdelt import GDELT, GDELTDataset\nfrom .gindt import GINDataset\nfrom .gnn_benchmark import (\n    AmazonCoBuy,\n    AmazonCoBuyComputerDataset,\n    AmazonCoBuyPhotoDataset,\n    Coauthor,\n    CoauthorCSDataset,\n    CoauthorPhysicsDataset,\n    CoraFull,\n    CoraFullDataset,\n)\nfrom .icews18 import ICEWS18, ICEWS18Dataset\nfrom .karate import KarateClub, KarateClubDataset\nfrom .knowledge_graph import FB15k237Dataset, FB15kDataset, WN18Dataset\nfrom .minigc import *\nfrom .ppi import LegacyPPIDataset, PPIDataset\nfrom .qm7b import QM7b, QM7bDataset\nfrom .qm9 import QM9, QM9Dataset\nfrom .qm9_edge import QM9Edge, QM9EdgeDataset\nfrom .rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset\nfrom .reddit import RedditDataset\nfrom .sbm import SBMMixture, SBMMixtureDataset\nfrom .synthetic import (\n    BA2MotifDataset,\n    BACommunityDataset,\n    BAShapeDataset,\n    TreeCycleDataset,\n    TreeGridDataset,\n)\nfrom .tree import SST, SSTDataset\nfrom .tu import LegacyTUDataset, TUDataset\nfrom .utils import *\nfrom .cluster import CLUSTERDataset\nfrom .geom_gcn import (\n    ChameleonDataset,\n    CornellDataset,\n    SquirrelDataset,\n    
TexasDataset,\n    WisconsinDataset,\n)\n\nfrom .heterophilous_graphs import (\n    AmazonRatingsDataset,\n    MinesweeperDataset,\n    QuestionsDataset,\n    RomanEmpireDataset,\n    TolokersDataset,\n)\n\n# RDKit is required for Peptides-Structural, Peptides-Functional dataset.\n# Exception handling was added to prevent crashes for users who are using other\n# datasets.\ntry:\n    from .lrgb import (\n        COCOSuperpixelsDataset,\n        PeptidesFunctionalDataset,\n        PeptidesStructuralDataset,\n        VOCSuperpixelsDataset,\n    )\nexcept ImportError:\n    pass\nfrom .pattern import PATTERNDataset\nfrom .superpixel import CIFAR10SuperPixelDataset, MNISTSuperPixelDataset\nfrom .wikics import WikiCSDataset\nfrom .yelp import YelpDataset\nfrom .zinc import ZINCDataset\n\n\ndef register_data_args(parser):\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        required=False,\n        help=\"The input dataset. Can be cora, citeseer, pubmed, syn(synthetic dataset) or reddit\",\n    )\n\n\ndef load_data(args):\n    if args.dataset == \"cora\":\n        return citegrh.load_cora()\n    elif args.dataset == \"citeseer\":\n        return citegrh.load_citeseer()\n    elif args.dataset == \"pubmed\":\n        return citegrh.load_pubmed()\n    elif args.dataset is not None and args.dataset.startswith(\"reddit\"):\n        return RedditDataset(self_loop=(\"self-loop\" in args.dataset))\n    else:\n        raise ValueError(\"Unknown dataset: {}\".format(args.dataset))\n"
  },
  {
    "path": "python/dgl/data/actor.py",
    "content": "\"\"\"\nActor-only induced subgraph of the film-directoractor-writer network.\n\"\"\"\nimport os\n\nimport numpy as np\n\nfrom ..convert import graph\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url\n\n\nclass ActorDataset(DGLBuiltinDataset):\n    r\"\"\"Actor-only induced subgraph of the film-directoractor-writer network\n    from `Social Influence Analysis in Large-scale Networks\n    <https://dl.acm.org/doi/10.1145/1557019.1557108>`, introduced by\n    `Geom-GCN: Geometric Graph Convolutional Networks\n    <https://arxiv.org/abs/2002.05287>`\n\n    Nodes represent actors, and edges represent co-occurrence on the same\n    Wikipedia page. Node features correspond to some keywords in the Wikipedia\n    pages.\n\n    Statistics:\n\n    - Nodes: 7600\n    - Edges: 33391\n    - Number of Classes: 5\n    - 10 train/val/test splits\n\n        - Train: 3648\n        - Val: 2432\n        - Test: 1520\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download the data source. Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. 
Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Notes\n    -----\n    The graph does not come with edges for both directions.\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(ActorDataset, self).__init__(\n            name=\"actor\",\n            url=_get_dgl_url(\"dataset/actor.zip\"),\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        \"\"\"Load and process the data.\"\"\"\n        try:\n            import torch\n        except ImportError:\n            raise ModuleNotFoundError(\n                \"This dataset requires PyTorch to be the backend.\"\n            )\n\n        # Process node features and labels.\n        with open(f\"{self.raw_path}/out1_node_feature_label.txt\", \"r\") as f:\n            data = [x.split(\"\\t\") for x in f.read().split(\"\\n\")[1:-1]]\n\n            rows, cols = [], []\n            labels = torch.empty(len(data), dtype=torch.long)\n            for n_id, col, label in data:\n                col = [int(x) for x in col.split(\",\")]\n                rows += [int(n_id)] * len(col)\n                cols += col\n\n                labels[int(n_id)] = int(label)\n\n            row, col = torch.tensor(rows), torch.tensor(cols)\n            features = torch.zeros(len(data), int(col.max()) + 1)\n            features[row, col] = 1.0\n\n            self._num_classes = int(labels.max().item()) + 1\n\n        # Process graph structure.\n        with open(f\"{self.raw_path}/out1_graph_edges.txt\", \"r\") as f:\n            data = f.read().split(\"\\n\")[1:-1]\n            data = [[int(v) for v in r.split(\"\\t\")] for r in data]\n        dst, src = torch.tensor(data, dtype=torch.long).t().contiguous()\n\n        self._g = graph((src, dst), num_nodes=features.size(0))\n        
self._g.ndata[\"feat\"] = features\n        self._g.ndata[\"label\"] = labels\n\n        # Process 10 train/val/test node splits.\n        train_masks, val_masks, test_masks = [], [], []\n        for i in range(10):\n            filepath = f\"{self.raw_path}/{self.name}_split_0.6_0.2_{i}.npz\"\n            f = np.load(filepath)\n            train_masks += [torch.from_numpy(f[\"train_mask\"])]\n            val_masks += [torch.from_numpy(f[\"val_mask\"])]\n            test_masks += [torch.from_numpy(f[\"test_mask\"])]\n        self._g.ndata[\"train_mask\"] = torch.stack(train_masks, dim=1).bool()\n        self._g.ndata[\"val_mask\"] = torch.stack(val_masks, dim=1).bool()\n        self._g.ndata[\"test_mask\"] = torch.stack(test_masks, dim=1).bool()\n\n    def has_cache(self):\n        return os.path.exists(self.raw_path)\n\n    def load(self):\n        self.process()\n\n    def __getitem__(self, idx):\n        assert idx == 0, \"This dataset has only one graph.\"\n        if self._transform is None:\n            return self._g\n        else:\n            return self._transform(self._g)\n\n    def __len__(self):\n        return 1\n\n    @property\n    def num_classes(self):\n        return self._num_classes\n"
  },
  {
    "path": "python/dgl/data/adapter.py",
    "content": "\"\"\"Dataset adapters for re-purposing a dataset for a different kind of training task.\"\"\"\n\nimport json\nimport os\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..base import DGLError\nfrom ..convert import graph as create_dgl_graph\nfrom ..sampling.negative import _calc_redundancy\nfrom . import utils\nfrom .dgl_dataset import DGLDataset\n\n__all__ = [\"AsNodePredDataset\", \"AsLinkPredDataset\", \"AsGraphPredDataset\"]\n\n\nclass AsNodePredDataset(DGLDataset):\n    \"\"\"Repurpose a dataset for a standard semi-supervised transductive\n    node prediction task.\n\n    The class converts a given dataset into a new dataset object such that:\n\n      - Contains only one graph, accessible from ``dataset[0]``.\n      - The graph stores:\n\n        - Node labels in ``g.ndata['label']``.\n        - Train/val/test masks in ``g.ndata['train_mask']``, ``g.ndata['val_mask']``,\n          and ``g.ndata['test_mask']`` respectively.\n      - In addition, the dataset contains the following attributes:\n\n        - ``num_classes``, the number of classes to predict.\n        - ``train_idx``, ``val_idx``, ``test_idx``, train/val/test indexes.\n\n    If the input dataset contains heterogeneous graphs, users need to specify the\n    ``target_ntype`` argument to indicate which node type to make predictions for.\n    In this case:\n\n      - Node labels are stored in ``g.nodes[target_ntype].data['label']``.\n      - Training masks are stored in ``g.nodes[target_ntype].data['train_mask']``.\n        So do validation and test masks.\n\n    The class will keep only the first graph in the provided dataset and\n    generate train/val/test masks according to the given split ratio. The generated\n    masks will be cached to disk for fast re-loading. 
If the provided split ratio\n    differs from the cached one, it will re-process the dataset properly.\n\n    Parameters\n    ----------\n    dataset : DGLDataset\n        The dataset to be converted.\n    split_ratio : (float, float, float), optional\n        Split ratios for training, validation and test sets. They must sum to one.\n    target_ntype : str, optional\n        The node type to add split mask for.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes to predict.\n    train_idx : Tensor\n        An 1-D integer tensor of training node IDs.\n    val_idx : Tensor\n        An 1-D integer tensor of validation node IDs.\n    test_idx : Tensor\n        An 1-D integer tensor of test node IDs.\n\n    Examples\n    --------\n    >>> ds = dgl.data.AmazonCoBuyComputerDataset()\n    >>> print(ds)\n    Dataset(\"amazon_co_buy_computer\", num_graphs=1, save_path=...)\n    >>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])\n    >>> print(new_ds)\n    Dataset(\"amazon_co_buy_computer-as-nodepred\", num_graphs=1, save_path=...)\n    >>> print('train_mask' in new_ds[0].ndata)\n    True\n    \"\"\"\n\n    def __init__(self, dataset, split_ratio=None, target_ntype=None, **kwargs):\n        self.dataset = dataset\n        self.split_ratio = split_ratio\n        self.target_ntype = target_ntype\n        super().__init__(\n            self.dataset.name + \"-as-nodepred\",\n            hash_key=(split_ratio, target_ntype, dataset.name, \"nodepred\"),\n            **kwargs\n        )\n\n    def process(self):\n        is_ogb = hasattr(self.dataset, \"get_idx_split\")\n        if is_ogb:\n            g, label = self.dataset[0]\n            self.g = g.clone()\n            self.g.ndata[\"label\"] = F.reshape(label, (g.num_nodes(),))\n        else:\n            self.g = self.dataset[0].clone()\n\n        if \"label\" not in self.g.nodes[self.target_ntype].data:\n            raise ValueError(\n                \"Missing node labels. 
Make sure labels are stored \"\n                \"under name 'label'.\"\n            )\n\n        if self.split_ratio is None:\n            if is_ogb:\n                split = self.dataset.get_idx_split()\n                train_idx, val_idx, test_idx = (\n                    split[\"train\"],\n                    split[\"valid\"],\n                    split[\"test\"],\n                )\n                n = self.g.num_nodes()\n                train_mask = utils.generate_mask_tensor(\n                    utils.idx2mask(train_idx, n)\n                )\n                val_mask = utils.generate_mask_tensor(\n                    utils.idx2mask(val_idx, n)\n                )\n                test_mask = utils.generate_mask_tensor(\n                    utils.idx2mask(test_idx, n)\n                )\n                self.g.ndata[\"train_mask\"] = train_mask\n                self.g.ndata[\"val_mask\"] = val_mask\n                self.g.ndata[\"test_mask\"] = test_mask\n            else:\n                assert (\n                    \"train_mask\" in self.g.nodes[self.target_ntype].data\n                ), \"train_mask is not provided, please specify split_ratio to generate the masks\"\n                assert (\n                    \"val_mask\" in self.g.nodes[self.target_ntype].data\n                ), \"val_mask is not provided, please specify split_ratio to generate the masks\"\n                assert (\n                    \"test_mask\" in self.g.nodes[self.target_ntype].data\n                ), \"test_mask is not provided, please specify split_ratio to generate the masks\"\n        else:\n            if self.verbose:\n                print(\"Generating train/val/test masks...\")\n            utils.add_nodepred_split(self, self.split_ratio, self.target_ntype)\n\n        self._set_split_index()\n\n        self.num_classes = getattr(self.dataset, \"num_classes\", None)\n        if self.num_classes is None:\n            self.num_classes = len(\n                
F.unique(self.g.nodes[self.target_ntype].data[\"label\"])\n            )\n\n    def has_cache(self):\n        return os.path.isfile(\n            os.path.join(self.save_path, \"graph_{}.bin\".format(self.hash))\n        )\n\n    def load(self):\n        with open(\n            os.path.join(self.save_path, \"info_{}.json\".format(self.hash)), \"r\"\n        ) as f:\n            info = json.load(f)\n            if (\n                info[\"split_ratio\"] != self.split_ratio\n                or info[\"target_ntype\"] != self.target_ntype\n            ):\n                raise ValueError(\n                    \"Provided split ratio is different from the cached file. \"\n                    \"Re-process the dataset.\"\n                )\n            self.split_ratio = info[\"split_ratio\"]\n            self.target_ntype = info[\"target_ntype\"]\n            self.num_classes = info[\"num_classes\"]\n        gs, _ = utils.load_graphs(\n            os.path.join(self.save_path, \"graph_{}.bin\".format(self.hash))\n        )\n        self.g = gs[0]\n        self._set_split_index()\n\n    def save(self):\n        utils.save_graphs(\n            os.path.join(self.save_path, \"graph_{}.bin\".format(self.hash)),\n            [self.g],\n        )\n        with open(\n            os.path.join(self.save_path, \"info_{}.json\".format(self.hash)), \"w\"\n        ) as f:\n            json.dump(\n                {\n                    \"split_ratio\": self.split_ratio,\n                    \"target_ntype\": self.target_ntype,\n                    \"num_classes\": self.num_classes,\n                },\n                f,\n            )\n\n    def __getitem__(self, idx):\n        return self.g\n\n    def __len__(self):\n        return 1\n\n    def _set_split_index(self):\n        \"\"\"Add train_idx/val_idx/test_idx as dataset attributes according to corresponding mask.\"\"\"\n        ndata = self.g.nodes[self.target_ntype].data\n        self.train_idx = 
F.nonzero_1d(ndata[\"train_mask\"])\n        self.val_idx = F.nonzero_1d(ndata[\"val_mask\"])\n        self.test_idx = F.nonzero_1d(ndata[\"test_mask\"])\n\n\ndef negative_sample(g, num_samples):\n    \"\"\"Random sample negative edges from graph, excluding self-loops,\n    the result samples might be less than num_samples\n    \"\"\"\n    num_nodes = g.num_nodes()\n    redundancy = _calc_redundancy(num_samples, g.num_edges(), num_nodes**2)\n    sample_size = int(num_samples * (1 + redundancy))\n    edges = np.random.randint(0, num_nodes, size=(2, sample_size))\n    edges = np.unique(edges, axis=1)\n    # remove self loop\n    mask_self_loop = edges[0] == edges[1]\n    # remove existing edges\n    has_edges = F.asnumpy(g.has_edges_between(edges[0], edges[1]))\n    mask = ~(np.logical_or(mask_self_loop, has_edges))\n    edges = edges[:, mask]\n    if edges.shape[1] >= num_samples:\n        edges = edges[:, :num_samples]\n    return edges\n\n\nclass AsLinkPredDataset(DGLDataset):\n    \"\"\"Repurpose a dataset for link prediction task.\n\n    The created dataset will include data needed for link prediction.\n    Currently it only supports homogeneous graphs.\n    It will keep only the first graph in the provided dataset and\n    generate train/val/test edges according to the given split ratio,\n    and the correspondent negative edges based on the neg_ratio. The generated\n    edges will be cached to disk for fast re-loading. If the provided split ratio\n    differs from the cached one, it will re-process the dataset properly.\n\n    Parameters\n    ----------\n    dataset : DGLDataset\n        The dataset to be converted.\n    split_ratio : (float, float, float), optional\n        Split ratios for training, validation and test sets. 
Must sum to one.\n    neg_ratio : int, optional\n        Indicates how many negative samples to draw.\n        The number of negative samples will be less than or equal to neg_ratio * num_positive_edges.\n\n    Attributes\n    ----------\n    feat_size: int\n        The size of the feature dimension in the graph\n    train_graph: DGLGraph\n        The DGLGraph for training\n    val_edges: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]\n        The validation set edges, encoded as\n        ((positive_edge_src, positive_edge_dst), (negative_edge_src, negative_edge_dst))\n    test_edges: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]\n        The test set edges, encoded as\n        ((positive_edge_src, positive_edge_dst), (negative_edge_src, negative_edge_dst))\n\n    Examples\n    --------\n    >>> ds = dgl.data.CoraGraphDataset()\n    >>> print(ds)\n    Dataset(\"cora_v2\", num_graphs=1, save_path=...)\n    >>> new_ds = dgl.data.AsLinkPredDataset(ds, [0.8, 0.1, 0.1])\n    >>> print(new_ds)\n    Dataset(\"cora_v2-as-linkpred\", num_graphs=1, save_path=/home/ubuntu/.dgl/cora_v2-as-linkpred)\n    >>> print(hasattr(new_ds, \"test_edges\"))\n    True\n    \"\"\"\n\n    def __init__(self, dataset, split_ratio=None, neg_ratio=3, **kwargs):\n        self.g = dataset[0]\n        self.num_nodes = self.g.num_nodes()\n        self.dataset = dataset\n        self.split_ratio = split_ratio\n        self.neg_ratio = neg_ratio\n        super().__init__(\n            dataset.name + \"-as-linkpred\",\n            hash_key=(neg_ratio, split_ratio, dataset.name, \"linkpred\"),\n            **kwargs\n        )\n\n    def process(self):\n        if self.split_ratio is None:\n            # Handle logic for OGB link prediction dataset\n            assert hasattr(\n                self.dataset, \"get_edge_split\"\n            ), \"dataset doesn't have get_edge_split method, please specify split_ratio and neg_ratio to generate the split\"\n            # This is likely to be an 
ogb dataset\n            self.edge_split = self.dataset.get_edge_split()\n            self._train_graph = self.g\n            if \"source_node\" in self.edge_split[\"test\"]:\n                # Probably ogbl-citation2\n                pos_e = (\n                    self.edge_split[\"valid\"][\"source_node\"],\n                    self.edge_split[\"valid\"][\"target_node\"],\n                )\n                neg_e_size = self.edge_split[\"valid\"][\"target_node_neg\"].shape[\n                    -1\n                ]\n                neg_e_src = np.repeat(\n                    self.edge_split[\"valid\"][\"source_node\"], neg_e_size\n                )\n                neg_e_dst = np.reshape(\n                    self.edge_split[\"valid\"][\"target_node_neg\"], -1\n                )\n                self._val_edges = pos_e, (neg_e_src, neg_e_dst)\n                pos_e = (\n                    self.edge_split[\"test\"][\"source_node\"],\n                    self.edge_split[\"test\"][\"target_node\"],\n                )\n                neg_e_size = self.edge_split[\"test\"][\"target_node_neg\"].shape[\n                    -1\n                ]\n                neg_e_src = np.repeat(\n                    self.edge_split[\"test\"][\"source_node\"], neg_e_size\n                )\n                neg_e_dst = np.reshape(\n                    self.edge_split[\"test\"][\"target_node_neg\"], -1\n                )\n                self._test_edges = pos_e, (neg_e_src, neg_e_dst)\n            elif \"edge\" in self.edge_split[\"test\"]:\n                # Probably ogbl-collab\n                pos_e_tensor, neg_e_tensor = (\n                    self.edge_split[\"valid\"][\"edge\"],\n                    self.edge_split[\"valid\"][\"edge_neg\"],\n                )\n                pos_e = (pos_e_tensor[:, 0], pos_e_tensor[:, 1])\n                neg_e = (neg_e_tensor[:, 0], neg_e_tensor[:, 1])\n                self._val_edges = pos_e, neg_e\n\n                pos_e_tensor, 
neg_e_tensor = (\n                    self.edge_split[\"test\"][\"edge\"],\n                    self.edge_split[\"test\"][\"edge_neg\"],\n                )\n                pos_e = (pos_e_tensor[:, 0], pos_e_tensor[:, 1])\n                neg_e = (neg_e_tensor[:, 0], neg_e_tensor[:, 1])\n                self._test_edges = pos_e, neg_e\n            # delete edge split to save memory\n            self.edge_split = None\n        else:\n            assert self.split_ratio is not None, \"Need to specify split_ratio\"\n            assert self.neg_ratio is not None, \"Need to specify neg_ratio\"\n            ratio = self.split_ratio\n            graph = self.dataset[0]\n            n = graph.num_edges()\n            src, dst = graph.edges()\n            src, dst = F.asnumpy(src), F.asnumpy(dst)\n            n_train, n_val, n_test = (\n                int(n * ratio[0]),\n                int(n * ratio[1]),\n                int(n * ratio[2]),\n            )\n\n            idx = np.random.permutation(n)\n            train_pos_idx = idx[:n_train]\n            val_pos_idx = idx[n_train : n_train + n_val]\n            test_pos_idx = idx[n_train + n_val :]\n            neg_src, neg_dst = negative_sample(\n                graph, self.neg_ratio * (n_val + n_test)\n            )\n            neg_n_val, neg_n_test = (\n                self.neg_ratio * n_val,\n                self.neg_ratio * n_test,\n            )\n            neg_val_src, neg_val_dst = neg_src[:neg_n_val], neg_dst[:neg_n_val]\n            neg_test_src, neg_test_dst = (\n                neg_src[neg_n_val:],\n                neg_dst[neg_n_val:],\n            )\n            self._val_edges = (\n                F.tensor(src[val_pos_idx]),\n                F.tensor(dst[val_pos_idx]),\n            ), (F.tensor(neg_val_src), F.tensor(neg_val_dst))\n            self._test_edges = (\n                F.tensor(src[test_pos_idx]),\n                F.tensor(dst[test_pos_idx]),\n            ), (F.tensor(neg_test_src), 
F.tensor(neg_test_dst))\n            self._train_graph = create_dgl_graph(\n                (src[train_pos_idx], dst[train_pos_idx]),\n                num_nodes=self.num_nodes,\n            )\n            self._train_graph.ndata[\"feat\"] = graph.ndata[\"feat\"]\n\n    def has_cache(self):\n        return os.path.isfile(\n            os.path.join(self.save_path, \"graph_{}.bin\".format(self.hash))\n        )\n\n    def load(self):\n        gs, tensor_dict = utils.load_graphs(\n            os.path.join(self.save_path, \"graph_{}.bin\".format(self.hash))\n        )\n        self.g = gs[0]\n        self._train_graph = self.g\n        self._val_edges = (\n            tensor_dict[\"val_pos_src\"],\n            tensor_dict[\"val_pos_dst\"],\n        ), (tensor_dict[\"val_neg_src\"], tensor_dict[\"val_neg_dst\"])\n        self._test_edges = (\n            tensor_dict[\"test_pos_src\"],\n            tensor_dict[\"test_pos_dst\"],\n        ), (tensor_dict[\"test_neg_src\"], tensor_dict[\"test_neg_dst\"])\n\n        with open(\n            os.path.join(self.save_path, \"info_{}.json\".format(self.hash)), \"r\"\n        ) as f:\n            info = json.load(f)\n            self.split_ratio = info[\"split_ratio\"]\n            self.neg_ratio = info[\"neg_ratio\"]\n\n    def save(self):\n        tensor_dict = {\n            \"val_pos_src\": self._val_edges[0][0],\n            \"val_pos_dst\": self._val_edges[0][1],\n            \"val_neg_src\": self._val_edges[1][0],\n            \"val_neg_dst\": self._val_edges[1][1],\n            \"test_pos_src\": self._test_edges[0][0],\n            \"test_pos_dst\": self._test_edges[0][1],\n            \"test_neg_src\": self._test_edges[1][0],\n            \"test_neg_dst\": self._test_edges[1][1],\n        }\n        utils.save_graphs(\n            os.path.join(self.save_path, \"graph_{}.bin\".format(self.hash)),\n            [self._train_graph],\n            tensor_dict,\n        )\n        with open(\n            
os.path.join(self.save_path, \"info_{}.json\".format(self.hash)), \"w\"\n        ) as f:\n            json.dump(\n                {\"split_ratio\": self.split_ratio, \"neg_ratio\": self.neg_ratio},\n                f,\n            )\n\n    @property\n    def feat_size(self):\n        return self._train_graph.ndata[\"feat\"].shape[-1]\n\n    @property\n    def train_graph(self):\n        return self._train_graph\n\n    @property\n    def val_edges(self):\n        return self._val_edges\n\n    @property\n    def test_edges(self):\n        return self._test_edges\n\n    def __getitem__(self, idx):\n        return self.g\n\n    def __len__(self):\n        return 1\n\n\nclass AsGraphPredDataset(DGLDataset):\n    \"\"\"Repurpose a dataset for standard graph property prediction task.\n\n    The created dataset will include data needed for graph property prediction.\n    Currently it only supports homogeneous graphs.\n\n    The class converts a given dataset into a new dataset object such that:\n\n      - It stores ``len(dataset)`` graphs.\n      - The i-th graph and its label is accessible from ``dataset[i]``.\n\n    The class will generate a train/val/test split if :attr:`split_ratio` is provided.\n    The generated split will be cached to disk for fast re-loading. If the provided split\n    ratio differs from the cached one, it will re-process the dataset properly.\n\n    Parameters\n    ----------\n    dataset : DGLDataset\n        The dataset to be converted.\n    split_ratio : (float, float, float), optional\n        Split ratios for training, validation and test sets. 
They must sum to one.\n\n    Attributes\n    ----------\n    num_tasks : int\n        Number of tasks to predict.\n    num_classes : int\n        Number of classes to predict per task, None for regression datasets.\n    train_idx : Tensor\n        An 1-D integer tensor of training node IDs.\n    val_idx : Tensor\n        An 1-D integer tensor of validation node IDs.\n    test_idx : Tensor\n        An 1-D integer tensor of test node IDs.\n    node_feat_size : int\n        Input node feature size, None if not applicable.\n    edge_feat_size : int\n        Input edge feature size, None if not applicable.\n\n    Examples\n    --------\n\n    >>> from dgl.data import AsGraphPredDataset\n    >>> from ogb.graphproppred import DglGraphPropPredDataset\n    >>> dataset = DglGraphPropPredDataset(name='ogbg-molhiv')\n    >>> new_dataset = AsGraphPredDataset(dataset)\n    >>> print(new_dataset)\n    Dataset(\"ogbg-molhiv-as-graphpred\", num_graphs=41127, save_path=...)\n    >>> print(len(new_dataset))\n    41127\n    >>> print(new_dataset[0])\n    (Graph(num_nodes=19, num_edges=40,\n           ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}\n           edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)}), tensor([0]))\n    \"\"\"\n\n    def __init__(self, dataset, split_ratio=None, **kwargs):\n        self.dataset = dataset\n        self.split_ratio = split_ratio\n        super().__init__(\n            dataset.name + \"-as-graphpred\",\n            hash_key=(split_ratio, dataset.name, \"graphpred\"),\n            **kwargs\n        )\n\n    def process(self):\n        is_ogb = hasattr(self.dataset, \"get_idx_split\")\n        if self.split_ratio is None:\n            if is_ogb:\n                split = self.dataset.get_idx_split()\n                self.train_idx = split[\"train\"]\n                self.val_idx = split[\"valid\"]\n                self.test_idx = split[\"test\"]\n            else:\n                # Handle FakeNewsDataset\n                
try:\n                    self.train_idx = F.nonzero_1d(self.dataset.train_mask)\n                    self.val_idx = F.nonzero_1d(self.dataset.val_mask)\n                    self.test_idx = F.nonzero_1d(self.dataset.test_mask)\n                except:\n                    raise DGLError(\n                        \"The input dataset does not have default train/val/test\\\n                        split. Please specify split_ratio to generate the split.\"\n                    )\n        else:\n            if self.verbose:\n                print(\"Generating train/val/test split...\")\n            train_ratio, val_ratio, _ = self.split_ratio\n            num_graphs = len(self.dataset)\n            num_train = int(num_graphs * train_ratio)\n            num_val = int(num_graphs * val_ratio)\n\n            idx = np.random.permutation(num_graphs)\n            self.train_idx = F.tensor(idx[:num_train])\n            self.val_idx = F.tensor(idx[num_train : num_train + num_val])\n            self.test_idx = F.tensor(idx[num_train + num_val :])\n\n        if hasattr(self.dataset, \"num_classes\"):\n            # GINDataset, MiniGCDataset, FakeNewsDataset, TUDataset,\n            # LegacyTUDataset, BA2MotifDataset\n            self.num_classes = self.dataset.num_classes\n        else:\n            # None for multi-label classification and regression\n            self.num_classes = None\n\n        if hasattr(self.dataset, \"num_tasks\"):\n            # OGB datasets\n            self.num_tasks = self.dataset.num_tasks\n        else:\n            self.num_tasks = 1\n\n    def has_cache(self):\n        return os.path.isfile(\n            os.path.join(self.save_path, \"info_{}.json\".format(self.hash))\n        )\n\n    def load(self):\n        with open(\n            os.path.join(self.save_path, \"info_{}.json\".format(self.hash)), \"r\"\n        ) as f:\n            info = json.load(f)\n            if info[\"split_ratio\"] != self.split_ratio:\n                raise ValueError(\n   
                 \"Provided split ratio is different from the cached file. \"\n                    \"Re-process the dataset.\"\n                )\n            self.split_ratio = info[\"split_ratio\"]\n            self.num_tasks = info[\"num_tasks\"]\n            self.num_classes = info[\"num_classes\"]\n\n        split = np.load(\n            os.path.join(self.save_path, \"split_{}.npz\".format(self.hash))\n        )\n        self.train_idx = F.zerocopy_from_numpy(split[\"train_idx\"])\n        self.val_idx = F.zerocopy_from_numpy(split[\"val_idx\"])\n        self.test_idx = F.zerocopy_from_numpy(split[\"test_idx\"])\n\n    def save(self):\n        if not os.path.exists(self.save_path):\n            os.makedirs(self.save_path)\n        with open(\n            os.path.join(self.save_path, \"info_{}.json\".format(self.hash)), \"w\"\n        ) as f:\n            json.dump(\n                {\n                    \"split_ratio\": self.split_ratio,\n                    \"num_tasks\": self.num_tasks,\n                    \"num_classes\": self.num_classes,\n                },\n                f,\n            )\n        np.savez(\n            os.path.join(self.save_path, \"split_{}.npz\".format(self.hash)),\n            train_idx=F.zerocopy_to_numpy(self.train_idx),\n            val_idx=F.zerocopy_to_numpy(self.val_idx),\n            test_idx=F.zerocopy_to_numpy(self.test_idx),\n        )\n\n    def __getitem__(self, idx):\n        return self.dataset[idx]\n\n    def __len__(self):\n        return len(self.dataset)\n\n    @property\n    def node_feat_size(self):\n        g = self[0][0]\n        return g.ndata[\"feat\"].shape[-1] if \"feat\" in g.ndata else None\n\n    @property\n    def edge_feat_size(self):\n        g = self[0][0]\n        return g.edata[\"feat\"].shape[-1] if \"feat\" in g.edata else None\n"
  },
  {
    "path": "python/dgl/data/bitcoinotc.py",
    "content": "\"\"\" BitcoinOTC dataset for fraud detection \"\"\"\nimport datetime\nimport gzip\nimport os\nimport shutil\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..convert import graph as dgl_graph\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import check_sha1, download, load_graphs, makedirs, save_graphs\n\n\nclass BitcoinOTCDataset(DGLBuiltinDataset):\n    r\"\"\"BitcoinOTC dataset for fraud detection\n\n    This is who-trusts-whom network of people who trade using Bitcoin on\n    a platform called Bitcoin OTC. Since Bitcoin users are anonymous,\n    there is a need to maintain a record of users' reputation to prevent\n    transactions with fraudulent and risky users.\n\n    Offical website: `<https://snap.stanford.edu/data/soc-sign-bitcoin-otc.html>`_\n\n    Bitcoin OTC dataset statistics:\n\n    - Nodes: 5,881\n    - Edges: 35,592\n    - Range of edge weight: -10 to +10\n    - Percentage of positive edges: 89%\n\n    Parameters\n    ----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False\n    verbose: bool\n        Whether to print out progress information.\n        Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    graphs : list\n        A list of DGLGraph objects\n    is_temporal : bool\n        Indicate whether the graphs are temporal graphs\n\n    Raises\n    ------\n    UserWarning\n        If the raw data is changed in the remote server by the author.\n\n    Examples\n    --------\n    >>> dataset = BitcoinOTCDataset()\n    >>> len(dataset)\n    136\n    >>> for g in dataset:\n    ....    # get edge feature\n    ....    
edge_weights = g.edata['h']\n    ....    # your code here\n    >>>\n    \"\"\"\n\n    _url = \"https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz\"\n    _sha1_str = \"c14281f9e252de0bd0b5f1c6e2bae03123938641\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=False, transform=None\n    ):\n        super(BitcoinOTCDataset, self).__init__(\n            name=\"bitcoinotc\",\n            url=self._url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def download(self):\n        gz_file_path = os.path.join(self.raw_dir, self.name + \".csv.gz\")\n        download(self.url, path=gz_file_path)\n        if not check_sha1(gz_file_path, self._sha1_str):\n            raise UserWarning(\n                \"File {} is downloaded but the content hash does not match.\"\n                \"The repo may be outdated or download may be incomplete. \"\n                \"Otherwise you can create an issue for it.\".format(\n                    self.name + \".csv.gz\"\n                )\n            )\n        self._extract_gz(gz_file_path, self.raw_path)\n\n    def process(self):\n        filename = os.path.join(self.save_path, self.name + \".csv\")\n        data = np.loadtxt(filename, delimiter=\",\").astype(np.int64)\n        data[:, 0:2] = data[:, 0:2] - data[:, 0:2].min()\n        delta = datetime.timedelta(days=14).total_seconds()\n        # The source code is not released, but the paper indicates there're\n        # totally 137 samples. 
The cutoff below has exactly 137 samples.\n        time_index = np.around((data[:, 3] - data[:, 3].min()) / delta).astype(\n            np.int64\n        )\n\n        self._graphs = []\n        for i in range(time_index.max()):\n            row_mask = time_index <= i\n            edges = data[row_mask][:, 0:2]\n            rate = data[row_mask][:, 2]\n            g = dgl_graph((edges[:, 0], edges[:, 1]))\n            g.edata[\"h\"] = F.tensor(\n                rate.reshape(-1, 1), dtype=F.data_type_dict[\"int64\"]\n            )\n            self._graphs.append(g)\n\n    @property\n    def graph_path(self):\n        return os.path.join(self.save_path, \"dgl_graph.bin\")\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def save(self):\n        save_graphs(self.graph_path, self.graphs)\n\n    def load(self):\n        self._graphs = load_graphs(self.graph_path)[0]\n\n    @property\n    def graphs(self):\n        return self._graphs\n\n    def __len__(self):\n        r\"\"\"Number of graphs in the dataset.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return len(self.graphs)\n\n    def __getitem__(self, item):\n        r\"\"\"Get graph by index\n\n        Parameters\n        ----------\n        item : int\n            Item index\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains:\n\n            - ``edata['h']`` : edge weights\n        \"\"\"\n        if self._transform is None:\n            return self.graphs[item]\n        else:\n            return self._transform(self.graphs[item])\n\n    @property\n    def is_temporal(self):\n        r\"\"\"Are the graphs temporal graphs\n\n        Returns\n        -------\n        bool\n        \"\"\"\n        return True\n\n    def _extract_gz(self, file, target_dir, overwrite=False):\n        if os.path.exists(target_dir) and not overwrite:\n            return\n        print(\"Extracting file to 
{}\".format(target_dir))\n        fname = os.path.basename(file)\n        makedirs(target_dir)\n        out_file_path = os.path.join(target_dir, fname[:-3])\n        with gzip.open(file, \"rb\") as f_in:\n            with open(out_file_path, \"wb\") as f_out:\n                shutil.copyfileobj(f_in, f_out)\n\n\nBitcoinOTC = BitcoinOTCDataset\n"
  },
  {
    "path": "python/dgl/data/citation_graph.py",
    "content": "\"\"\"Cora, citeseer, pubmed dataset.\n\n(lingfan): following dataset loading and preprocessing code from tkipf/gcn\nhttps://github.com/tkipf/gcn/blob/master/gcn/utils.py\n\"\"\"\n\nfrom __future__ import absolute_import\n\nimport os, sys\nimport pickle as pkl\nimport warnings\n\nimport networkx as nx\n\nimport numpy as np\nimport scipy.sparse as sp\n\nfrom .. import backend as F, convert\nfrom ..batch import batch as batch_graphs\nfrom ..convert import from_networkx, graph as dgl_graph, to_networkx\nfrom ..transforms import reorder_graph\nfrom .dgl_dataset import DGLBuiltinDataset\n\nfrom .utils import (\n    _get_dgl_url,\n    deprecate_function,\n    deprecate_property,\n    generate_mask_tensor,\n    load_graphs,\n    load_info,\n    makedirs,\n    save_graphs,\n    save_info,\n)\n\nbackend = os.environ.get(\"DGLBACKEND\", \"pytorch\")\n\n\ndef _pickle_load(pkl_file):\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=DeprecationWarning)\n        if sys.version_info > (3, 0):\n            return pkl.load(pkl_file, encoding=\"latin1\")\n        else:\n            return pkl.load(pkl_file)\n\n\nclass CitationGraphDataset(DGLBuiltinDataset):\n    r\"\"\"The citation graph dataset, including cora, citeseer and pubmeb.\n    Nodes mean authors and edges mean citation relationships.\n\n    Parameters\n    -----------\n    name: str\n      name can be 'cora', 'citeseer' or 'pubmed'.\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    reverse_edge : bool\n        Whether to add reverse edges in graph. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    reorder : bool\n        Whether to reorder the graph using :func:`~dgl.reorder_graph`. Default: False.\n    \"\"\"\n\n    _urls = {\n        \"cora_v2\": \"dataset/cora_v2.zip\",\n        \"citeseer\": \"dataset/citeseer.zip\",\n        \"pubmed\": \"dataset/pubmed.zip\",\n    }\n\n    def __init__(\n        self,\n        name,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        reverse_edge=True,\n        transform=None,\n        reorder=False,\n    ):\n        assert name.lower() in [\"cora\", \"citeseer\", \"pubmed\"]\n\n        # Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)\n        # for Cora, which is slightly different from the one used in the GCN paper\n        if name.lower() == \"cora\":\n            name = \"cora_v2\"\n\n        url = _get_dgl_url(self._urls[name])\n        self._reverse_edge = reverse_edge\n        self._reorder = reorder\n\n        super(CitationGraphDataset, self).__init__(\n            name,\n            url=url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        \"\"\"Loads input data from data directory and reorder graph for better locality\n\n        ind.name.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;\n        ind.name.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;\n        ind.name.allx => the feature vectors of both labeled and unlabeled training instances\n            (a superset of ind.name.x) as scipy.sparse.csr.csr_matrix object;\n        ind.name.y => the one-hot labels of the labeled training instances as numpy.ndarray object;\n        ind.name.ty => the one-hot labels of the test instances as numpy.ndarray object;\n        ind.name.ally => the labels for 
instances in ind.name.allx as numpy.ndarray object;\n        ind.name.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict\n            object;\n        ind.name.test.index => the indices of test instances in graph, for the inductive setting as list object.\n        \"\"\"\n        root = self.raw_path\n        objnames = [\"x\", \"y\", \"tx\", \"ty\", \"allx\", \"ally\", \"graph\"]\n        objects = []\n        for i in range(len(objnames)):\n            with open(\n                \"{}/ind.{}.{}\".format(root, self.name, objnames[i]), \"rb\"\n            ) as f:\n                objects.append(_pickle_load(f))\n\n        x, y, tx, ty, allx, ally, graph = tuple(objects)\n        test_idx_reorder = _parse_index_file(\n            \"{}/ind.{}.test.index\".format(root, self.name)\n        )\n        test_idx_range = np.sort(test_idx_reorder)\n\n        if self.name == \"citeseer\":\n            # Fix citeseer dataset (there are some isolated nodes in the graph)\n            # Find isolated nodes, add them as zero-vecs into the right position\n            test_idx_range_full = range(\n                min(test_idx_reorder), max(test_idx_reorder) + 1\n            )\n            tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))\n            tx_extended[test_idx_range - min(test_idx_range), :] = tx\n            tx = tx_extended\n            ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))\n            ty_extended[test_idx_range - min(test_idx_range), :] = ty\n            ty = ty_extended\n\n        features = sp.vstack((allx, tx)).tolil()\n        features[test_idx_reorder, :] = features[test_idx_range, :]\n\n        if self.reverse_edge:\n            graph = nx.DiGraph(nx.from_dict_of_lists(graph))\n            g = from_networkx(graph)\n        else:\n            graph = nx.Graph(nx.from_dict_of_lists(graph))\n            edges = list(graph.edges())\n            u, v = map(list, zip(*edges))\n        
    g = dgl_graph((u, v))\n\n        onehot_labels = np.vstack((ally, ty))\n        onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :]\n        labels = np.argmax(onehot_labels, 1)\n\n        idx_test = test_idx_range.tolist()\n        idx_train = range(len(y))\n        idx_val = range(len(y), len(y) + 500)\n\n        train_mask = generate_mask_tensor(\n            _sample_mask(idx_train, labels.shape[0])\n        )\n        val_mask = generate_mask_tensor(_sample_mask(idx_val, labels.shape[0]))\n        test_mask = generate_mask_tensor(\n            _sample_mask(idx_test, labels.shape[0])\n        )\n\n        g.ndata[\"train_mask\"] = train_mask\n        g.ndata[\"val_mask\"] = val_mask\n        g.ndata[\"test_mask\"] = test_mask\n        g.ndata[\"label\"] = F.tensor(labels)\n        g.ndata[\"feat\"] = F.tensor(\n            _preprocess_features(features), dtype=F.data_type_dict[\"float32\"]\n        )\n        self._num_classes = onehot_labels.shape[1]\n        self._labels = labels\n        if self._reorder:\n            self._g = reorder_graph(\n                g,\n                node_permute_algo=\"rcmk\",\n                edge_permute_algo=\"dst\",\n                store_ids=False,\n            )\n        else:\n            self._g = g\n\n        if self.verbose:\n            print(\"Finished data loading and preprocessing.\")\n            print(\"  NumNodes: {}\".format(self._g.num_nodes()))\n            print(\"  NumEdges: {}\".format(self._g.num_edges()))\n            print(\"  NumFeats: {}\".format(self._g.ndata[\"feat\"].shape[1]))\n            print(\"  NumClasses: {}\".format(self.num_classes))\n            print(\n                \"  NumTrainingSamples: {}\".format(\n                    F.nonzero_1d(self._g.ndata[\"train_mask\"]).shape[0]\n                )\n            )\n            print(\n                \"  NumValidationSamples: {}\".format(\n                    F.nonzero_1d(self._g.ndata[\"val_mask\"]).shape[0]\n           
     )\n            )\n            print(\n                \"  NumTestSamples: {}\".format(\n                    F.nonzero_1d(self._g.ndata[\"test_mask\"]).shape[0]\n                )\n            )\n\n    @property\n    def graph_path(self):\n        return os.path.join(self.save_path, self.save_name + \".bin\")\n\n    @property\n    def info_path(self):\n        return os.path.join(self.save_path, self.save_name + \".pkl\")\n\n    def has_cache(self):\n        if os.path.exists(self.graph_path) and os.path.exists(self.info_path):\n            return True\n\n        return False\n\n    def save(self):\n        \"\"\"save the graph list and the labels\"\"\"\n        save_graphs(str(self.graph_path), self._g)\n        save_info(str(self.info_path), {\"num_classes\": self.num_classes})\n\n    def load(self):\n        graphs, _ = load_graphs(str(self.graph_path))\n\n        info = load_info(str(self.info_path))\n        graph = graphs[0]\n        self._g = graph\n        # for compatability\n        graph = graph.clone()\n        graph.ndata.pop(\"train_mask\")\n        graph.ndata.pop(\"val_mask\")\n        graph.ndata.pop(\"test_mask\")\n        graph.ndata.pop(\"feat\")\n        graph.ndata.pop(\"label\")\n        graph = to_networkx(graph)\n\n        self._num_classes = info[\"num_classes\"]\n        self._g.ndata[\"train_mask\"] = generate_mask_tensor(\n            F.asnumpy(self._g.ndata[\"train_mask\"])\n        )\n        self._g.ndata[\"val_mask\"] = generate_mask_tensor(\n            F.asnumpy(self._g.ndata[\"val_mask\"])\n        )\n        self._g.ndata[\"test_mask\"] = generate_mask_tensor(\n            F.asnumpy(self._g.ndata[\"test_mask\"])\n        )\n        # hack for mxnet compatability\n\n        if self.verbose:\n            print(\"  NumNodes: {}\".format(self._g.num_nodes()))\n            print(\"  NumEdges: {}\".format(self._g.num_edges()))\n            print(\"  NumFeats: {}\".format(self._g.ndata[\"feat\"].shape[1]))\n            print(\"  
NumClasses: {}\".format(self.num_classes))\n            print(\n                \"  NumTrainingSamples: {}\".format(\n                    F.nonzero_1d(self._g.ndata[\"train_mask\"]).shape[0]\n                )\n            )\n            print(\n                \"  NumValidationSamples: {}\".format(\n                    F.nonzero_1d(self._g.ndata[\"val_mask\"]).shape[0]\n                )\n            )\n            print(\n                \"  NumTestSamples: {}\".format(\n                    F.nonzero_1d(self._g.ndata[\"test_mask\"]).shape[0]\n                )\n            )\n\n    def __getitem__(self, idx):\n        assert idx == 0, \"This dataset has only one graph\"\n        if self._transform is None:\n            return self._g\n        else:\n            return self._transform(self._g)\n\n    def __len__(self):\n        return 1\n\n    @property\n    def save_name(self):\n        return self.name + \"_dgl_graph\"\n\n    @property\n    def num_labels(self):\n        deprecate_property(\"dataset.num_labels\", \"dataset.num_classes\")\n        return self.num_classes\n\n    @property\n    def num_classes(self):\n        return self._num_classes\n\n    \"\"\" Citation graph is used in many examples\n        We preserve these properties for compatability.\n    \"\"\"\n\n    @property\n    def reverse_edge(self):\n        return self._reverse_edge\n\n\ndef _preprocess_features(features):\n    \"\"\"Row-normalize feature matrix and convert to tuple representation\"\"\"\n    features = _normalize(features)\n    return np.asarray(features.todense())\n\n\ndef _parse_index_file(filename):\n    \"\"\"Parse index file.\"\"\"\n    index = []\n    for line in open(filename):\n        index.append(int(line.strip()))\n    return index\n\n\ndef _sample_mask(idx, l):\n    \"\"\"Create mask.\"\"\"\n    mask = np.zeros(l)\n    mask[idx] = 1\n    return mask\n\n\nclass CoraGraphDataset(CitationGraphDataset):\n    r\"\"\"Cora citation network dataset.\n\n    Nodes mean paper and 
edges mean citation\n    relationships. Each node has a predefined\n    feature with 1433 dimensions. The dataset is\n    designed for the node classification task.\n    The task is to predict the category of\n    certain paper.\n\n    Statistics:\n\n    - Nodes: 2708\n    - Edges: 10556\n    - Number of Classes: 7\n    - Label split:\n\n        - Train: 140\n        - Valid: 500\n        - Test: 1000\n\n    Parameters\n    ----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    reverse_edge : bool\n        Whether to add reverse edges in graph. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    reorder : bool\n        Whether to reorder the graph using :func:`~dgl.reorder_graph`. 
Default: False.\n\n    Attributes\n    ----------\n    num_classes: int\n        Number of label classes\n\n    Notes\n    -----\n    The node feature is row-normalized.\n\n    Examples\n    --------\n    >>> dataset = CoraGraphDataset()\n    >>> g = dataset[0]\n    >>> num_class = dataset.num_classes\n    >>>\n    >>> # get node feature\n    >>> feat = g.ndata['feat']\n    >>>\n    >>> # get data split\n    >>> train_mask = g.ndata['train_mask']\n    >>> val_mask = g.ndata['val_mask']\n    >>> test_mask = g.ndata['test_mask']\n    >>>\n    >>> # get labels\n    >>> label = g.ndata['label']\n\n    \"\"\"\n\n    def __init__(\n        self,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        reverse_edge=True,\n        transform=None,\n        reorder=False,\n    ):\n        name = \"cora\"\n\n        super(CoraGraphDataset, self).__init__(\n            name,\n            raw_dir,\n            force_reload,\n            verbose,\n            reverse_edge,\n            transform,\n            reorder,\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Gets the graph object\n\n        Parameters\n        -----------\n        idx: int\n            Item index, CoraGraphDataset has only one graph object\n\n        Return\n        ------\n        :class:`dgl.DGLGraph`\n\n            graph structure, node features and labels.\n\n            - ``ndata['train_mask']``: mask for training node set\n            - ``ndata['val_mask']``: mask for validation node set\n            - ``ndata['test_mask']``: mask for test node set\n            - ``ndata['feat']``: node feature\n            - ``ndata['label']``: ground truth labels\n        \"\"\"\n        return super(CoraGraphDataset, self).__getitem__(idx)\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\"\"\"\n        return super(CoraGraphDataset, self).__len__()\n\n\nclass CiteseerGraphDataset(CitationGraphDataset):\n    r\"\"\"Citeseer citation network 
dataset.\n\n    Nodes mean scientific publications and edges\n    mean citation relationships. Each node has a\n    predefined feature with 3703 dimensions. The\n    dataset is designed for the node classification\n    task. The task is to predict the category of\n    certain publication.\n\n    Statistics:\n\n    - Nodes: 3327\n    - Edges: 9228\n    - Number of Classes: 6\n    - Label Split:\n\n        - Train: 120\n        - Valid: 500\n        - Test: 1000\n\n    Parameters\n    -----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    reverse_edge : bool\n        Whether to add reverse edges in graph. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    reorder : bool\n        Whether to reorder the graph using :func:`~dgl.reorder_graph`. 
Default: False.\n\n    Attributes\n    ----------\n    num_classes: int\n        Number of label classes\n\n    Notes\n    -----\n    The node feature is row-normalized.\n\n    In citeseer dataset, there are some isolated nodes in the graph.\n    These isolated nodes are added as zero-vecs into the right position.\n\n    Examples\n    --------\n    >>> dataset = CiteseerGraphDataset()\n    >>> g = dataset[0]\n    >>> num_class = dataset.num_classes\n    >>>\n    >>> # get node feature\n    >>> feat = g.ndata['feat']\n    >>>\n    >>> # get data split\n    >>> train_mask = g.ndata['train_mask']\n    >>> val_mask = g.ndata['val_mask']\n    >>> test_mask = g.ndata['test_mask']\n    >>>\n    >>> # get labels\n    >>> label = g.ndata['label']\n\n    \"\"\"\n\n    def __init__(\n        self,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        reverse_edge=True,\n        transform=None,\n        reorder=False,\n    ):\n        name = \"citeseer\"\n\n        super(CiteseerGraphDataset, self).__init__(\n            name,\n            raw_dir,\n            force_reload,\n            verbose,\n            reverse_edge,\n            transform,\n            reorder,\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Gets the graph object\n\n        Parameters\n        -----------\n        idx: int\n            Item index, CiteseerGraphDataset has only one graph object\n\n        Return\n        ------\n        :class:`dgl.DGLGraph`\n\n            graph structure, node features and labels.\n\n            - ``ndata['train_mask']``: mask for training node set\n            - ``ndata['val_mask']``: mask for validation node set\n            - ``ndata['test_mask']``: mask for test node set\n            - ``ndata['feat']``: node feature\n            - ``ndata['label']``: ground truth labels\n        \"\"\"\n        return super(CiteseerGraphDataset, self).__getitem__(idx)\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the 
dataset.\"\"\"\n        return super(CiteseerGraphDataset, self).__len__()\n\n\nclass PubmedGraphDataset(CitationGraphDataset):\n    r\"\"\"Pubmed citation network dataset.\n\n    Nodes mean scientific publications and edges\n    mean citation relationships. Each node has a\n    predefined feature with 500 dimensions. The\n    dataset is designed for the node classification\n    task. The task is to predict the category of\n    certain publication.\n\n    Statistics:\n\n    - Nodes: 19717\n    - Edges: 88651\n    - Number of Classes: 3\n    - Label Split:\n\n        - Train: 60\n        - Valid: 500\n        - Test: 1000\n\n    Parameters\n    -----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    reverse_edge : bool\n        Whether to add reverse edges in graph. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    reorder : bool\n        Whether to reorder the graph using :func:`~dgl.reorder_graph`. 
Default: False.\n\n    Attributes\n    ----------\n    num_classes: int\n        Number of label classes\n\n    Notes\n    -----\n    The node feature is row-normalized.\n\n    Examples\n    --------\n    >>> dataset = PubmedGraphDataset()\n    >>> g = dataset[0]\n    >>> num_class = dataset.num_of_class\n    >>>\n    >>> # get node feature\n    >>> feat = g.ndata['feat']\n    >>>\n    >>> # get data split\n    >>> train_mask = g.ndata['train_mask']\n    >>> val_mask = g.ndata['val_mask']\n    >>> test_mask = g.ndata['test_mask']\n    >>>\n    >>> # get labels\n    >>> label = g.ndata['label']\n\n    \"\"\"\n\n    def __init__(\n        self,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        reverse_edge=True,\n        transform=None,\n        reorder=False,\n    ):\n        name = \"pubmed\"\n\n        super(PubmedGraphDataset, self).__init__(\n            name,\n            raw_dir,\n            force_reload,\n            verbose,\n            reverse_edge,\n            transform,\n            reorder,\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Gets the graph object\n\n        Parameters\n        -----------\n        idx: int\n            Item index, PubmedGraphDataset has only one graph object\n\n        Return\n        ------\n        :class:`dgl.DGLGraph`\n\n            graph structure, node features and labels.\n\n            - ``ndata['train_mask']``: mask for training node set\n            - ``ndata['val_mask']``: mask for validation node set\n            - ``ndata['test_mask']``: mask for test node set\n            - ``ndata['feat']``: node feature\n            - ``ndata['label']``: ground truth labels\n        \"\"\"\n        return super(PubmedGraphDataset, self).__getitem__(idx)\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\"\"\"\n        return super(PubmedGraphDataset, self).__len__()\n\n\ndef load_cora(\n    raw_dir=None,\n    force_reload=False,\n    
verbose=True,\n    reverse_edge=True,\n    transform=None,\n):\n    \"\"\"Get CoraGraphDataset\n\n    Parameters\n    -----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    reverse_edge : bool\n        Whether to add reverse edges in graph. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Return\n    -------\n    CoraGraphDataset\n    \"\"\"\n    data = CoraGraphDataset(\n        raw_dir, force_reload, verbose, reverse_edge, transform\n    )\n    return data\n\n\ndef load_citeseer(\n    raw_dir=None,\n    force_reload=False,\n    verbose=True,\n    reverse_edge=True,\n    transform=None,\n):\n    \"\"\"Get CiteseerGraphDataset\n\n    Parameters\n    -----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    reverse_edge : bool\n        Whether to add reverse edges in graph. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Return\n    -------\n    CiteseerGraphDataset\n    \"\"\"\n    data = CiteseerGraphDataset(\n        raw_dir, force_reload, verbose, reverse_edge, transform\n    )\n    return data\n\n\ndef load_pubmed(\n    raw_dir=None,\n    force_reload=False,\n    verbose=True,\n    reverse_edge=True,\n    transform=None,\n):\n    \"\"\"Get PubmedGraphDataset\n\n    Parameters\n    -----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    reverse_edge : bool\n        Whether to add reverse edges in graph. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Return\n    -------\n    PubmedGraphDataset\n    \"\"\"\n    data = PubmedGraphDataset(\n        raw_dir, force_reload, verbose, reverse_edge, transform\n    )\n    return data\n\n\nclass CoraBinary(DGLBuiltinDataset):\n    \"\"\"A mini-dataset for binary classification task using Cora.\n\n    After loaded, it has following members:\n\n    graphs : list of :class:`~dgl.DGLGraph`\n    pmpds : list of :class:`scipy.sparse.coo_matrix`\n    labels : list of :class:`numpy.ndarray`\n\n    Parameters\n    -----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose: bool\n        Whether to print out progress information. 
Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        name = \"cora_binary\"\n        url = _get_dgl_url(\"dataset/cora_binary.zip\")\n        super(CoraBinary, self).__init__(\n            name,\n            url=url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        root = self.raw_path\n        # load graphs\n        self.graphs = []\n        with open(\"{}/graphs.txt\".format(root), \"r\") as f:\n            elist = []\n            for line in f.readlines():\n                if line.startswith(\"graph\"):\n                    if len(elist) != 0:\n                        self.graphs.append(dgl_graph(tuple(zip(*elist))))\n                    elist = []\n                else:\n                    u, v = line.strip().split(\" \")\n                    elist.append((int(u), int(v)))\n            if len(elist) != 0:\n                self.graphs.append(dgl_graph(tuple(zip(*elist))))\n        with open(\"{}/pmpds.pkl\".format(root), \"rb\") as f:\n            self.pmpds = _pickle_load(f)\n        self.labels = []\n        with open(\"{}/labels.txt\".format(root), \"r\") as f:\n            cur = []\n            for line in f.readlines():\n                if line.startswith(\"graph\"):\n                    if len(cur) != 0:\n                        self.labels.append(np.asarray(cur))\n                    cur = []\n                else:\n                    cur.append(int(line.strip()))\n            if len(cur) != 0:\n                self.labels.append(np.asarray(cur))\n        # sanity check\n        assert len(self.graphs) == 
len(self.pmpds)\n        assert len(self.graphs) == len(self.labels)\n\n    @property\n    def graph_path(self):\n        return os.path.join(self.save_path, self.save_name + \".bin\")\n\n    def has_cache(self):\n        if os.path.exists(self.graph_path):\n            return True\n\n        return False\n\n    def save(self):\n        \"\"\"save the graph list and the labels\"\"\"\n        labels = {}\n        for i, label in enumerate(self.labels):\n            labels[\"{}\".format(i)] = F.tensor(label)\n        save_graphs(str(self.graph_path), self.graphs, labels)\n        if self.verbose:\n            print(\"Done saving data into cached files.\")\n\n    def load(self):\n        self.graphs, labels = load_graphs(str(self.graph_path))\n\n        self.labels = []\n        for i in range(len(labels)):\n            self.labels.append(F.asnumpy(labels[\"{}\".format(i)]))\n        # load pmpds under self.raw_path\n        with open(\"{}/pmpds.pkl\".format(self.raw_path), \"rb\") as f:\n            self.pmpds = _pickle_load(f)\n        if self.verbose:\n            print(\"Done loading data into cached files.\")\n        # sanity check\n        assert len(self.graphs) == len(self.pmpds)\n        assert len(self.graphs) == len(self.labels)\n\n    def __len__(self):\n        return len(self.graphs)\n\n    def __getitem__(self, i):\n        r\"\"\"Gets the idx-th sample.\n\n        Parameters\n        -----------\n        idx : int\n            The sample index.\n\n        Returns\n        -------\n        (dgl.DGLGraph, scipy.sparse.coo_matrix, int)\n            The graph, scipy sparse coo_matrix and its label.\n        \"\"\"\n        if self._transform is None:\n            g = self.graphs[i]\n        else:\n            g = self._transform(self.graphs[i])\n        return (g, self.pmpds[i], self.labels[i])\n\n    @property\n    def save_name(self):\n        return self.name + \"_dgl_graph\"\n\n    @staticmethod\n    def collate_fn(cur):\n        graphs, pmpds, labels 
= zip(*cur)\n        batched_graphs = batch_graphs(graphs)\n        batched_pmpds = sp.block_diag(pmpds)\n        batched_labels = np.concatenate(labels, axis=0)\n        return batched_graphs, batched_pmpds, batched_labels\n\n\ndef _normalize(mx):\n    \"\"\"Row-normalize sparse matrix\"\"\"\n    rowsum = np.asarray(mx.sum(1))\n    mask = np.equal(rowsum, 0.0).flatten()\n    rowsum[mask] = np.nan\n    r_inv = np.power(rowsum, -1).flatten()\n    r_inv[mask] = 0.0\n    r_mat_inv = sp.diags(r_inv)\n    return r_mat_inv.dot(mx)\n\n\ndef _encode_onehot(labels):\n    classes = list(sorted(set(labels)))\n    classes_dict = {\n        c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)\n    }\n    labels_onehot = np.asarray(\n        list(map(classes_dict.get, labels)), dtype=np.int32\n    )\n    return labels_onehot\n"
  },
  {
    "path": "python/dgl/data/cluster.py",
    "content": "\"\"\" CLUSTERDataset for inductive learning. \"\"\"\nimport os\n\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, load_graphs\n\n\nclass CLUSTERDataset(DGLBuiltinDataset):\n    r\"\"\"CLUSTER dataset for semi-supervised clustering task.\n\n    Each graph contains 6 SBM clusters with sizes randomly selected between\n    [5, 35] and probabilities p = 0.55, q = 0.25. The graphs are of sizes 40\n    -190 nodes. Each node can take an input feature value in {0, 1, 2, ..., 6}\n    and values 1~6 correspond to classes 0~5 respectively, while value 0 means\n    that the class of the node is unknown. There is only one labeled node that\n    is randomly assigned to each community and most node features are set to 0.\n\n    Reference `<https://arxiv.org/pdf/2003.00982.pdf>`_\n\n    Statistics:\n\n    - Train examples: 10,000\n    - Valid examples: 1,000\n    - Test examples: 1,000\n    - Number of classes for each node: 6\n\n    Parameters\n    ----------\n    mode : str\n        Must be one of ('train', 'valid', 'test').\n        Default: 'train'\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False\n    verbose : bool\n        Whether to print out progress information.\n        Default: False\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes for each node.\n\n    Examples\n    --------\n    >>> from dgl.data import CLUSTERDataset\n    >>>\n    >>> trainset = CLUSTERDataset(mode='train')\n    >>>\n    >>> trainset.num_classes\n    6\n    >>> len(trainset)\n    10000\n    >>> trainset[0]\n    Graph(num_nodes=117, num_edges=4104,\n          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int16),\n                         'feat': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})\n    \"\"\"\n\n    def __init__(\n        self,\n        mode=\"train\",\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        self._url = _get_dgl_url(\"dataset/SBM_CLUSTER.zip\")\n        self.mode = mode\n\n        super(CLUSTERDataset, self).__init__(\n            name=\"cluster\",\n            url=self._url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        self.load()\n\n    def has_cache(self):\n        graph_path = os.path.join(\n            self.save_path, \"CLUSTER_{}.bin\".format(self.mode)\n        )\n        return os.path.exists(graph_path)\n\n    def load(self):\n        graph_path = os.path.join(\n            self.save_path, \"CLUSTER_{}.bin\".format(self.mode)\n        )\n        self._graphs, _ = load_graphs(graph_path)\n\n    @property\n    def num_classes(self):\n        r\"\"\"Number of classes for each node.\"\"\"\n        return 6\n\n    def __len__(self):\n        r\"\"\"The number of examples in the dataset.\"\"\"\n        return len(self._graphs)\n\n    def __getitem__(self, idx):\n        r\"\"\"Get the idx^th sample.\n\n        Parameters\n        ---------\n        idx : int\n      
      The sample index.\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n            graph structure, node features, node labels and edge features.\n\n            - ``ndata['feat']``: node features\n            - ``ndata['label']``: node labels\n            - ``edata['feat']``: edge features\n        \"\"\"\n        if self._transform is None:\n            return self._graphs[idx]\n        else:\n            return self._transform(self._graphs[idx])\n"
  },
  {
    "path": "python/dgl/data/csv_dataset.py",
    "content": "import os\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..base import DGLError\nfrom .dgl_dataset import DGLDataset\nfrom .utils import load_graphs, save_graphs, Subset\n\n\nclass CSVDataset(DGLDataset):\n    \"\"\"Dataset class that loads and parses graph data from CSV files.\n\n    This class requires the following additional packages:\n\n        - pyyaml >= 5.4.1\n        - pandas >= 1.1.5\n        - pydantic >= 1.9.0\n\n    The parsed graph and feature data will be cached for faster reloading. If\n    the source CSV files are modified, please specify ``force_reload=True``\n    to re-parse from them.\n\n    Parameters\n    ----------\n    data_path : str\n        Directory which contains 'meta.yaml' and CSV files\n    force_reload : bool, optional\n        Whether to reload the dataset. Default: False\n    verbose: bool, optional\n        Whether to print out progress information. Default: True.\n    ndata_parser : dict[str, callable] or callable, optional\n        Callable object which takes in the ``pandas.DataFrame`` object created from\n        CSV file, parses node data and returns a dictionary of parsed data. If given a\n        dictionary, the key is node type and the value is a callable object which is\n        used to parse data of corresponding node type. If given a single callable\n        object, such object is used to parse data of all node type data. Default: None.\n        If None, a default data parser is applied which load data directly and tries to\n        convert list into array.\n    edata_parser : dict[(str, str, str), callable], or callable, optional\n        Callable object which takes in the ``pandas.DataFrame`` object created from\n        CSV file, parses edge data and returns a dictionary of parsed data. If given a\n        dictionary, the key is edge type and the value is a callable object which is\n        used to parse data of corresponding edge type. 
If given a single callable\n        object, such object is used to parse data of all edge type data. Default: None.\n        If None, a default data parser is applied which load data directly and tries to\n        convert list into array.\n    gdata_parser : callable, optional\n        Callable object which takes in the ``pandas.DataFrame`` object created from\n        CSV file, parses graph data and returns a dictionary of parsed data. Default:\n        None. If None, a default data parser is applied which load data directly and\n        tries to convert list into array.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    graphs : :class:`dgl.DGLGraph`\n        Graphs of the dataset\n    data : dict\n        any available graph-level data such as graph-level feature, labels.\n\n    Examples\n    --------\n    Please refer to :ref:`guide-data-pipeline-loadcsv`.\n\n    \"\"\"\n\n    META_YAML_NAME = \"meta.yaml\"\n\n    def __init__(\n        self,\n        data_path,\n        force_reload=False,\n        verbose=True,\n        ndata_parser=None,\n        edata_parser=None,\n        gdata_parser=None,\n        transform=None,\n    ):\n        from .csv_dataset_base import (\n            DefaultDataParser,\n            load_yaml_with_sanity_check,\n        )\n\n        self.graphs = None\n        self.data = None\n        self.ndata_parser = {} if ndata_parser is None else ndata_parser\n        self.edata_parser = {} if edata_parser is None else edata_parser\n        self.gdata_parser = gdata_parser\n        self.default_data_parser = DefaultDataParser()\n        meta_yaml_path = os.path.join(data_path, CSVDataset.META_YAML_NAME)\n        if not os.path.exists(meta_yaml_path):\n            raise DGLError(\n                \"'{}' cannot be found under 
{}.\".format(\n                    CSVDataset.META_YAML_NAME, data_path\n                )\n            )\n        self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path)\n        ds_name = self.meta_yaml.dataset_name\n        super().__init__(\n            ds_name,\n            raw_dir=os.path.dirname(meta_yaml_path),\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        \"\"\"Parse node/edge data from CSV files and construct DGL.Graphs\"\"\"\n        from .csv_dataset_base import (\n            DGLGraphConstructor,\n            EdgeData,\n            GraphData,\n            NodeData,\n        )\n\n        meta_yaml = self.meta_yaml\n        base_dir = self.raw_dir\n        node_data = []\n        for meta_node in meta_yaml.node_data:\n            if meta_node is None:\n                continue\n            ntype = meta_node.ntype\n            data_parser = (\n                self.ndata_parser\n                if callable(self.ndata_parser)\n                else self.ndata_parser.get(ntype, self.default_data_parser)\n            )\n            ndata = NodeData.load_from_csv(\n                meta_node,\n                base_dir=base_dir,\n                separator=meta_yaml.separator,\n                data_parser=data_parser,\n            )\n            node_data.append(ndata)\n        edge_data = []\n        for meta_edge in meta_yaml.edge_data:\n            if meta_edge is None:\n                continue\n            etype = tuple(meta_edge.etype)\n            data_parser = (\n                self.edata_parser\n                if callable(self.edata_parser)\n                else self.edata_parser.get(etype, self.default_data_parser)\n            )\n            edata = EdgeData.load_from_csv(\n                meta_edge,\n                base_dir=base_dir,\n                separator=meta_yaml.separator,\n                data_parser=data_parser,\n            )\n     
       edge_data.append(edata)\n        graph_data = None\n        if meta_yaml.graph_data is not None:\n            meta_graph = meta_yaml.graph_data\n            data_parser = (\n                self.default_data_parser\n                if self.gdata_parser is None\n                else self.gdata_parser\n            )\n            graph_data = GraphData.load_from_csv(\n                meta_graph,\n                base_dir=base_dir,\n                separator=meta_yaml.separator,\n                data_parser=data_parser,\n            )\n        # construct graphs\n        self.graphs, self.data = DGLGraphConstructor.construct_graphs(\n            node_data, edge_data, graph_data\n        )\n        if len(self.data) == 1:\n            self.labels = list(self.data.values())[0]\n\n    def has_cache(self):\n        graph_path = os.path.join(self.save_path, self.name + \".bin\")\n        if os.path.exists(graph_path):\n            return True\n\n        return False\n\n    def save(self):\n        if self.graphs is None:\n            raise DGLError(\"No graphs available in dataset\")\n        graph_path = os.path.join(self.save_path, self.name + \".bin\")\n        save_graphs(graph_path, self.graphs, labels=self.data)\n\n    def load(self):\n        graph_path = os.path.join(self.save_path, self.name + \".bin\")\n        self.graphs, self.data = load_graphs(graph_path)\n        if len(self.data) == 1:\n            self.labels = list(self.data.values())[0]\n\n    def __getitem__(self, i):\n        if F.is_tensor(i) and F.ndim(i) == 1:\n            return Subset(self, F.copy_to(i, F.cpu()))\n\n        if self._transform is None:\n            g = self.graphs[i]\n        else:\n            g = self._transform(self.graphs[i])\n\n        if len(self.data) == 1:\n            return g, self.labels[i]\n        elif len(self.data) > 0:\n            data = {k: v[i] for (k, v) in self.data.items()}\n            return g, data\n        else:\n            return g\n\n    def 
__len__(self):\n        return len(self.graphs)\n"
  },
  {
    "path": "python/dgl/data/csv_dataset_base.py",
    "content": "import ast\nimport os\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport pandas as pd\nimport pydantic as dt\nimport yaml\n\nfrom .. import backend as F\nfrom ..base import dgl_warning, DGLError\nfrom ..convert import heterograph as dgl_heterograph\n\n\nclass MetaNode(dt.BaseModel):\n    \"\"\"Class of node_data in YAML. Internal use only.\"\"\"\n\n    file_name: str\n    ntype: Optional[str] = \"_V\"\n    graph_id_field: Optional[str] = \"graph_id\"\n    node_id_field: Optional[str] = \"node_id\"\n\n\nclass MetaEdge(dt.BaseModel):\n    \"\"\"Class of edge_data in YAML. Internal use only.\"\"\"\n\n    file_name: str\n    etype: Optional[List[str]] = [\"_V\", \"_E\", \"_V\"]\n    graph_id_field: Optional[str] = \"graph_id\"\n    src_id_field: Optional[str] = \"src_id\"\n    dst_id_field: Optional[str] = \"dst_id\"\n\n\nclass MetaGraph(dt.BaseModel):\n    \"\"\"Class of graph_data in YAML. Internal use only.\"\"\"\n\n    file_name: str\n    graph_id_field: Optional[str] = \"graph_id\"\n\n\nclass MetaYaml(dt.BaseModel):\n    \"\"\"Class of YAML. Internal use only.\"\"\"\n\n    version: Optional[str] = \"1.0.0\"\n    dataset_name: str\n    separator: Optional[str] = \",\"\n    node_data: List[MetaNode]\n    edge_data: List[MetaEdge]\n    graph_data: Optional[MetaGraph] = None\n\n\ndef load_yaml_with_sanity_check(yaml_file):\n    \"\"\"Load yaml and do sanity check. Internal use only.\"\"\"\n    with open(yaml_file) as f:\n        yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)\n        try:\n            meta_yaml = MetaYaml(**yaml_data)\n        except dt.ValidationError as e:\n            print(\"Details of pydantic.ValidationError:\\n{}\".format(e.json()))\n            raise DGLError(\n                \"Validation Error for YAML fields. Details are shown above.\"\n            )\n        if meta_yaml.version != \"1.0.0\":\n            raise DGLError(\n                \"Invalid CSVDataset version {}. 
Supported versions: '1.0.0'\".format(\n                    meta_yaml.version\n                )\n            )\n        ntypes = [meta.ntype for meta in meta_yaml.node_data]\n        if len(ntypes) > len(set(ntypes)):\n            raise DGLError(\n                \"Each node CSV file must have a unique node type name, but found duplicate node type: {}.\".format(\n                    ntypes\n                )\n            )\n        etypes = [tuple(meta.etype) for meta in meta_yaml.edge_data]\n        if len(etypes) > len(set(etypes)):\n            raise DGLError(\n                \"Each edge CSV file must have a unique edge type name, but found duplicate edge type: {}.\".format(\n                    etypes\n                )\n            )\n        return meta_yaml\n\n\ndef _validate_data_length(data_dict):\n    len_dict = {k: len(v) for k, v in data_dict.items()}\n    lst = list(len_dict.values())\n    res = lst.count(lst[0]) == len(lst)\n    if not res:\n        raise DGLError(\n            \"All data are required to have same length while some of them does not. Length of data={}\".format(\n                str(len_dict)\n            )\n        )\n\n\ndef _tensor(data, dtype=None):\n    \"\"\"Float32 is the default dtype for float tensor in DGL\n    so let's cast float64 into float32 to avoid dtype mismatch.\n    \"\"\"\n    ret = F.tensor(data, dtype)\n    if F.dtype(ret) == F.float64:\n        ret = F.tensor(ret, dtype=F.float32)\n    return ret\n\n\nclass BaseData:\n    \"\"\"Class of base data which is inherited by Node/Edge/GraphData. 
Internal use only.\"\"\"\n\n    @staticmethod\n    def read_csv(file_name, base_dir, separator):\n        csv_path = file_name\n        if base_dir is not None:\n            csv_path = os.path.join(base_dir, csv_path)\n        return pd.read_csv(csv_path, sep=separator)\n\n    @staticmethod\n    def pop_from_dataframe(df: pd.DataFrame, item: str):\n        ret = None\n        try:\n            ret = df.pop(item).to_numpy().squeeze()\n        except KeyError:\n            pass\n        return ret\n\n\nclass NodeData(BaseData):\n    \"\"\"Class of node data which is used for DGLGraph construction. Internal use only.\"\"\"\n\n    def __init__(self, node_id, data, type=None, graph_id=None):\n        self.id = np.array(node_id)\n        self.data = data\n        self.type = type if type is not None else \"_V\"\n        self.graph_id = (\n            np.array(graph_id)\n            if graph_id is not None\n            else np.full(len(node_id), 0)\n        )\n        _validate_data_length(\n            {**{\"id\": self.id, \"graph_id\": self.graph_id}, **self.data}\n        )\n\n    @staticmethod\n    def load_from_csv(\n        meta: MetaNode, data_parser: Callable, base_dir=None, separator=\",\"\n    ):\n        df = BaseData.read_csv(meta.file_name, base_dir, separator)\n        node_ids = BaseData.pop_from_dataframe(df, meta.node_id_field)\n        graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)\n        if node_ids is None:\n            raise DGLError(\n                \"Missing node id field [{}] in file [{}].\".format(\n                    meta.node_id_field, meta.file_name\n                )\n            )\n        ntype = meta.ntype\n        ndata = data_parser(df)\n        return NodeData(node_ids, ndata, type=ntype, graph_id=graph_ids)\n\n    @staticmethod\n    def to_dict(node_data: List[\"NodeData\"]) -> dict:\n        # node_ids could be numeric or non-numeric values, but duplication is not allowed.\n        node_dict = {}\n        for 
n_data in node_data:\n            graph_ids = np.unique(n_data.graph_id)\n            for graph_id in graph_ids:\n                idx = n_data.graph_id == graph_id\n                ids = n_data.id[idx]\n                u_ids, u_indices, u_counts = np.unique(\n                    ids, return_index=True, return_counts=True\n                )\n                if len(ids) > len(u_ids):\n                    raise DGLError(\n                        \"Node IDs are required to be unique but the following ids are duplicate: {}\".format(\n                            u_ids[u_counts > 1]\n                        )\n                    )\n                if graph_id not in node_dict:\n                    node_dict[graph_id] = {}\n                node_dict[graph_id][n_data.type] = {\n                    \"mapping\": {\n                        index: i for i, index in enumerate(ids[u_indices])\n                    },\n                    \"data\": {\n                        k: _tensor(v[idx][u_indices])\n                        for k, v in n_data.data.items()\n                    },\n                    \"dtype\": ids.dtype,\n                }\n        return node_dict\n\n\nclass EdgeData(BaseData):\n    \"\"\"Class of edge data which is used for DGLGraph construction. 
Internal use only.\"\"\"\n\n    def __init__(self, src_id, dst_id, data, type=None, graph_id=None):\n        self.src = np.array(src_id)\n        self.dst = np.array(dst_id)\n        self.data = data\n        self.type = type if type is not None else (\"_V\", \"_E\", \"_V\")\n        self.graph_id = (\n            np.array(graph_id)\n            if graph_id is not None\n            else np.full(len(src_id), 0)\n        )\n        _validate_data_length(\n            {\n                **{\"src\": self.src, \"dst\": self.dst, \"graph_id\": self.graph_id},\n                **self.data,\n            }\n        )\n\n    @staticmethod\n    def load_from_csv(\n        meta: MetaEdge, data_parser: Callable, base_dir=None, separator=\",\"\n    ):\n        df = BaseData.read_csv(meta.file_name, base_dir, separator)\n        src_ids = BaseData.pop_from_dataframe(df, meta.src_id_field)\n        if src_ids is None:\n            raise DGLError(\n                \"Missing src id field [{}] in file [{}].\".format(\n                    meta.src_id_field, meta.file_name\n                )\n            )\n        dst_ids = BaseData.pop_from_dataframe(df, meta.dst_id_field)\n        if dst_ids is None:\n            raise DGLError(\n                \"Missing dst id field [{}] in file [{}].\".format(\n                    meta.dst_id_field, meta.file_name\n                )\n            )\n        graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)\n        etype = tuple(meta.etype)\n        edata = data_parser(df)\n        return EdgeData(src_ids, dst_ids, edata, type=etype, graph_id=graph_ids)\n\n    @staticmethod\n    def to_dict(edge_data: List[\"EdgeData\"], node_dict: dict) -> dict:\n        edge_dict = {}\n        for e_data in edge_data:\n            (src_type, e_type, dst_type) = e_data.type\n            graph_ids = np.unique(e_data.graph_id)\n            for graph_id in graph_ids:\n                if graph_id in edge_dict and e_data.type in edge_dict[graph_id]:\n   
                 raise DGLError(\n                        f\"Duplicate edge type[{e_data.type}] for same graph[{graph_id}], please place the same edge_type for same graph into single EdgeData.\"\n                    )\n                idx = e_data.graph_id == graph_id\n                src_mapping = node_dict[graph_id][src_type][\"mapping\"]\n                dst_mapping = node_dict[graph_id][dst_type][\"mapping\"]\n                orig_src_ids = e_data.src[idx].astype(\n                    node_dict[graph_id][src_type][\"dtype\"]\n                )\n                orig_dst_ids = e_data.dst[idx].astype(\n                    node_dict[graph_id][dst_type][\"dtype\"]\n                )\n                src_ids = [src_mapping[index] for index in orig_src_ids]\n                dst_ids = [dst_mapping[index] for index in orig_dst_ids]\n                if graph_id not in edge_dict:\n                    edge_dict[graph_id] = {}\n                edge_dict[graph_id][e_data.type] = {\n                    \"edges\": (_tensor(src_ids), _tensor(dst_ids)),\n                    \"data\": {\n                        k: _tensor(v[idx]) for k, v in e_data.data.items()\n                    },\n                }\n        return edge_dict\n\n\nclass GraphData(BaseData):\n    \"\"\"Class of graph data which is used for DGLGraph construction. 
Internal use only.\"\"\"\n\n    def __init__(self, graph_id, data):\n        self.graph_id = np.array(graph_id)\n        self.data = data\n        _validate_data_length({**{\"graph_id\": self.graph_id}, **self.data})\n\n    @staticmethod\n    def load_from_csv(\n        meta: MetaGraph, data_parser: Callable, base_dir=None, separator=\",\"\n    ):\n        df = BaseData.read_csv(meta.file_name, base_dir, separator)\n        graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)\n        if graph_ids is None:\n            raise DGLError(\n                \"Missing graph id field [{}] in file [{}].\".format(\n                    meta.graph_id_field, meta.file_name\n                )\n            )\n        gdata = data_parser(df)\n        return GraphData(graph_ids, gdata)\n\n    @staticmethod\n    def to_dict(graph_data: \"GraphData\", graphs_dict: dict) -> dict:\n        missing_ids = np.setdiff1d(\n            np.array(list(graphs_dict.keys())), graph_data.graph_id\n        )\n        if len(missing_ids) > 0:\n            raise DGLError(\n                \"Found following graph ids in node/edge CSVs but not in graph CSV: {}.\".format(\n                    missing_ids\n                )\n            )\n        graph_ids = graph_data.graph_id\n        graphs = []\n        for graph_id in graph_ids:\n            if graph_id not in graphs_dict:\n                graphs_dict[graph_id] = dgl_heterograph(\n                    {(\"_V\", \"_E\", \"_V\"): ([], [])}\n                )\n        for graph_id in graph_ids:\n            graphs.append(graphs_dict[graph_id])\n        data = {\n            k: F.reshape(_tensor(v), (len(graphs), -1))\n            for k, v in graph_data.data.items()\n        }\n        return graphs, data\n\n\nclass DGLGraphConstructor:\n    \"\"\"Class for constructing DGLGraph from Node/Edge/Graph data. 
Internal use only.\"\"\"\n\n    @staticmethod\n    def construct_graphs(node_data, edge_data, graph_data=None):\n        if not isinstance(node_data, list):\n            node_data = [node_data]\n        if not isinstance(edge_data, list):\n            edge_data = [edge_data]\n        node_dict = NodeData.to_dict(node_data)\n        edge_dict = EdgeData.to_dict(edge_data, node_dict)\n        graph_dict = DGLGraphConstructor._construct_graphs(node_dict, edge_dict)\n        if graph_data is None:\n            graph_data = GraphData(np.full(1, 0), {})\n        graphs, data = GraphData.to_dict(graph_data, graph_dict)\n        return graphs, data\n\n    @staticmethod\n    def _construct_graphs(node_dict, edge_dict):\n        graph_dict = {}\n        for graph_id in node_dict:\n            if graph_id not in edge_dict:\n                edge_dict[graph_id][(\"_V\", \"_E\", \"_V\")] = {\"edges\": ([], [])}\n            graph = dgl_heterograph(\n                {\n                    etype: edata[\"edges\"]\n                    for etype, edata in edge_dict[graph_id].items()\n                },\n                num_nodes_dict={\n                    ntype: len(ndata[\"mapping\"])\n                    for ntype, ndata in node_dict[graph_id].items()\n                },\n            )\n\n            def assign_data(type, src_data, dst_data):\n                for key, value in src_data.items():\n                    dst_data[type].data[key] = value\n\n            for type, data in node_dict[graph_id].items():\n                assign_data(type, data[\"data\"], graph.nodes)\n            for (type), data in edge_dict[graph_id].items():\n                assign_data(type, data[\"data\"], graph.edges)\n            graph_dict[graph_id] = graph\n        return graph_dict\n\n\nclass DefaultDataParser:\n    \"\"\"Default data parser for CSVDataset. It\n    1. ignores any columns which does not have a header.\n    2. 
tries to convert to list of numeric values(generated by\n        np.array().tolist()) if cell data is a str separated by ','.\n    3. read data and infer data type directly, otherwise.\n    \"\"\"\n\n    def __call__(self, df: pd.DataFrame):\n        data = {}\n        for header in df:\n            if \"Unnamed\" in header:\n                dgl_warning(\"Unnamed column is found. Ignored...\")\n                continue\n            dt = df[header].to_numpy().squeeze()\n            if len(dt) > 0 and isinstance(dt[0], str):\n                # probably consists of list of numeric values\n                dt = np.array([ast.literal_eval(row) for row in dt])\n            data[header] = dt\n        return data\n"
  },
  {
    "path": "python/dgl/data/dgl_dataset.py",
    "content": "\"\"\"Basic DGL Dataset\n\"\"\"\n\nfrom __future__ import absolute_import\n\nimport abc\nimport hashlib\nimport os\nimport traceback\n\nfrom ..utils import retry_method_with_fix\nfrom .utils import download, extract_archive, get_download_dir, makedirs\n\n\nclass DGLDataset(object):\n    r\"\"\"The basic DGL dataset for creating graph datasets.\n    This class defines a basic template class for DGL Dataset.\n    The following steps will be executed automatically:\n\n      1. Check whether there is a dataset cache on disk\n         (already processed and stored on the disk) by\n         invoking ``has_cache()``. If true, goto 5.\n      2. Call ``download()`` to download the data if ``url`` is not None.\n      3. Call ``process()`` to process the data.\n      4. Call ``save()`` to save the processed dataset on disk and goto 6.\n      5. Call ``load()`` to load the processed dataset from disk.\n      6. Done.\n\n    Users can overwite these functions with their\n    own data processing logic.\n\n    Parameters\n    ----------\n    name : str\n        Name of the dataset\n    url : str\n        Url to download the raw dataset. Default: None\n    raw_dir : str\n        Specifying the directory that will store the\n        downloaded data or the directory that\n        already stores the input data.\n        Default: ~/.dgl/\n    save_dir : str\n        Directory to save the processed dataset.\n        Default: same as raw_dir\n    hash_key : tuple\n        A tuple of values as the input for the hash function.\n        Users can distinguish instances (and their caches on the disk)\n        from the same dataset class by comparing the hash values.\n        Default: (), the corresponding hash value is ``'f9065fa7'``.\n    force_reload : bool\n        Whether to reload the dataset. 
Default: False\n    verbose : bool\n        Whether to print out progress information\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    url : str\n        The URL to download the dataset\n    name : str\n        The dataset name\n    raw_dir : str\n        Directory to store all the downloaded raw datasets.\n    raw_path : str\n        Path to the downloaded raw dataset folder. An alias for\n        ``os.path.join(self.raw_dir, self.name)``.\n    save_dir : str\n        Directory to save all the processed datasets.\n    save_path : str\n        Path to the processed dataset folder. An alias for\n        ``os.path.join(self.save_dir, self.name)``.\n    verbose : bool\n        Whether to print more runtime information.\n    hash : str\n        Hash value for the dataset and the setting.\n    \"\"\"\n\n    def __init__(\n        self,\n        name,\n        url=None,\n        raw_dir=None,\n        save_dir=None,\n        hash_key=(),\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        self._name = name\n        self._url = url\n        self._force_reload = force_reload\n        self._verbose = verbose\n        self._hash_key = hash_key\n        self._hash = self._get_hash()\n        self._transform = transform\n\n        # if no dir is provided, the default dgl download dir is used.\n        if raw_dir is None:\n            self._raw_dir = get_download_dir()\n        else:\n            self._raw_dir = raw_dir\n\n        if save_dir is None:\n            self._save_dir = self._raw_dir\n        else:\n            self._save_dir = save_dir\n\n        self._load()\n\n    def download(self):\n        r\"\"\"Overwite to realize your own logic of downloading data.\n\n        It is recommended to download the to the 
:obj:`self.raw_dir`\n        folder. Can be ignored if the dataset is\n        already in :obj:`self.raw_dir`.\n        \"\"\"\n        pass\n\n    def save(self):\n        r\"\"\"Overwite to realize your own logic of\n        saving the processed dataset into files.\n\n        It is recommended to use ``dgl.data.utils.save_graphs``\n        to save dgl graph into files and use\n        ``dgl.data.utils.save_info`` to save extra\n        information into files.\n        \"\"\"\n        pass\n\n    def load(self):\n        r\"\"\"Overwite to realize your own logic of\n        loading the saved dataset from files.\n\n        It is recommended to use ``dgl.data.utils.load_graphs``\n        to load dgl graph from files and use\n        ``dgl.data.utils.load_info`` to load extra information\n        into python dict object.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def process(self):\n        r\"\"\"Overwrite to realize your own logic of processing the input data.\"\"\"\n        pass\n\n    def has_cache(self):\n        r\"\"\"Overwrite to realize your own logic of\n        deciding whether there exists a cached dataset.\n\n        By default False.\n        \"\"\"\n        return False\n\n    @retry_method_with_fix(download)\n    def _download(self):\n        \"\"\"Download dataset by calling ``self.download()``\n        if the dataset does not exists under ``self.raw_path``.\n\n        By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``\n        One can overwrite ``raw_path()`` function to change the path.\n        \"\"\"\n        if os.path.exists(self.raw_path):  # pragma: no cover\n            return\n\n        makedirs(self.raw_dir)\n        self.download()\n\n    def _load(self):\n        \"\"\"Entry point from __init__ to load the dataset.\n\n        If cache exists:\n\n          - Load the dataset from saved dgl graph and information files.\n          - If loadin process fails, re-download and process the dataset.\n\n        
else:\n\n          - Download the dataset if needed.\n          - Process the dataset and build the dgl graph.\n          - Save the processed dataset into files.\n        \"\"\"\n        load_flag = not self._force_reload and self.has_cache()\n\n        if load_flag:\n            try:\n                self.load()\n                if self.verbose:\n                    print(\"Done loading data from cached files.\")\n            except KeyboardInterrupt:\n                raise\n            except:\n                load_flag = False\n                if self.verbose:\n                    print(traceback.format_exc())\n                    print(\"Loading from cache failed, re-processing.\")\n\n        if not load_flag:\n            self._download()\n            self.process()\n            self.save()\n            if self.verbose:\n                print(\"Done saving data into cached files.\")\n\n    def _get_hash(self):\n        \"\"\"Compute the hash of the input tuple\n\n        Example\n        -------\n        Assume `self._hash_key = (10, False, True)`\n\n        >>> hash_value = self._get_hash()\n        >>> hash_value\n        'a770b222'\n        \"\"\"\n        hash_func = hashlib.sha1()\n        hash_func.update(str(self._hash_key).encode(\"utf-8\"))\n        return hash_func.hexdigest()[:8]\n\n    def _get_hash_url_suffix(self):\n        \"\"\"Get the suffix based on the hash value of the url.\"\"\"\n        if self._url is None:\n            return \"\"\n        else:\n            hash_func = hashlib.sha1()\n            hash_func.update(str(self._url).encode(\"utf-8\"))\n            return \"_\" + hash_func.hexdigest()[:8]\n\n    @property\n    def url(self):\n        r\"\"\"Get url to download the raw dataset.\"\"\"\n        return self._url\n\n    @property\n    def name(self):\n        r\"\"\"Name of the dataset.\"\"\"\n        return self._name\n\n    @property\n    def raw_dir(self):\n        r\"\"\"Raw file directory contains the input data 
folder.\"\"\"\n        return self._raw_dir\n\n    @property\n    def raw_path(self):\n        r\"\"\"Directory contains the input data files.\n        By default raw_path = os.path.join(self.raw_dir, self.name)\n        \"\"\"\n        return os.path.join(\n            self.raw_dir, self.name + self._get_hash_url_suffix()\n        )\n\n    @property\n    def save_dir(self):\n        r\"\"\"Directory to save the processed dataset.\"\"\"\n        return self._save_dir\n\n    @property\n    def save_path(self):\n        r\"\"\"Path to save the processed dataset.\"\"\"\n        return os.path.join(\n            self.save_dir, self.name + self._get_hash_url_suffix()\n        )\n\n    @property\n    def verbose(self):\n        r\"\"\"Whether to print information.\"\"\"\n        return self._verbose\n\n    @property\n    def hash(self):\n        r\"\"\"Hash value for the dataset and the setting.\"\"\"\n        return self._hash\n\n    @abc.abstractmethod\n    def __getitem__(self, idx):\n        r\"\"\"Gets the data object at index.\"\"\"\n        pass\n\n    @abc.abstractmethod\n    def __len__(self):\n        r\"\"\"The number of examples in the dataset.\"\"\"\n        pass\n\n    def __repr__(self):\n        return (\n            f'Dataset(\"{self.name}\", num_graphs={len(self)},'\n            + f\" save_path={self.save_path})\"\n        )\n\n\nclass DGLBuiltinDataset(DGLDataset):\n    r\"\"\"The Basic DGL Builtin Dataset.\n\n    Parameters\n    ----------\n    name : str\n        Name of the dataset.\n    url : str\n        Url to download the raw dataset.\n    raw_dir : str\n        Specifying the directory that will store the\n        downloaded data or the directory that\n        already stores the input data.\n        Default: ~/.dgl/\n    hash_key : tuple\n        A tuple of values as the input for the hash function.\n        Users can distinguish instances (and their caches on the disk)\n        from the same dataset class by comparing the hash values.\n    
force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: False\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    \"\"\"\n\n    def __init__(\n        self,\n        name,\n        url,\n        raw_dir=None,\n        hash_key=(),\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        super(DGLBuiltinDataset, self).__init__(\n            name,\n            url=url,\n            raw_dir=raw_dir,\n            save_dir=None,\n            hash_key=hash_key,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def download(self):\n        r\"\"\"Automatically download data and extract it.\"\"\"\n        if self.url is not None:\n            zip_file_path = os.path.join(self.raw_dir, self.name + \".zip\")\n            download(self.url, path=zip_file_path)\n            extract_archive(zip_file_path, self.raw_path)\n"
  },
  {
    "path": "python/dgl/data/fakenews.py",
    "content": "import os\n\nimport numpy as np\nimport scipy.sparse as sp\n\nfrom .. import backend as F\nfrom ..convert import graph\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, load_graphs, load_info, save_graphs, save_info\n\n\nclass FakeNewsDataset(DGLBuiltinDataset):\n    r\"\"\"Fake News Graph Classification dataset.\n\n    The dataset is composed of two sets of tree-structured fake/real\n    news propagation graphs extracted from Twitter. Different from\n    most of the benchmark datasets for the graph classification task,\n    the graphs in this dataset are directed tree-structured graphs where\n    the root node represents the news, the leaf nodes are Twitter users\n    who retweeted the root news. Besides, the node features are encoded\n    user historical tweets using different pretrained language models:\n\n    - bert: the 768-dimensional node feature composed of Twitter user historical tweets encoded by the bert-as-service\n    - content: the 310-dimensional node feature composed of a 300-dimensional “spacy” vector plus a 10-dimensional “profile” vector\n    - profile: the 10-dimensional node feature composed of ten Twitter user profile attributes.\n    - spacy: the 300-dimensional node feature composed of Twitter user historical tweets encoded by the spaCy word2vec encoder.\n\n    Reference: <https://github.com/safe-graph/GNN-FakeNews>\n\n    Note: this dataset is for academic use only, and commercial use is prohibited.\n\n    Statistics:\n\n        Politifact:\n\n        - Graphs: 314\n        - Nodes: 41,054\n        - Edges: 40,740\n        - Classes:\n\n            - Fake: 157\n            - Real: 157\n\n        - Node feature size:\n\n            - bert: 768\n            - content: 310\n            - profile: 10\n            - spacy: 300\n\n        Gossipcop:\n\n        - Graphs: 5,464\n        - Nodes: 314,262\n        - Edges: 308,798\n        - Classes:\n\n            - Fake: 2,732\n            - Real: 
2,732\n\n        - Node feature size:\n\n            - bert: 768\n            - content: 310\n            - profile: 10\n            - spacy: 300\n\n    Parameters\n    ----------\n    name : str\n        Name of the dataset (gossipcop, or politifact)\n    feature_name : str\n        Name of the feature (bert, content, profile, or spacy)\n    raw_dir : str\n        Specifying the directory that will store the\n        downloaded data or the directory that\n        already stores the input data.\n        Default: ~/.dgl/\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    name : str\n        Name of the dataset (gossipcop, or politifact)\n    num_classes : int\n        Number of label classes\n    num_graphs : int\n        Number of graphs\n    graphs : list\n        A list of DGLGraph objects\n    labels : Tensor\n        Graph labels\n    feature_name : str\n        Name of the feature (bert, content, profile, or spacy)\n    feature : Tensor\n        Node features\n    train_mask : Tensor\n        Mask of training set\n    val_mask : Tensor\n        Mask of validation set\n    test_mask : Tensor\n        Mask of testing set\n\n    Examples\n    --------\n    >>> dataset = FakeNewsDataset('gossipcop', 'bert')\n    >>> graph, label = dataset[0]\n    >>> num_classes = dataset.num_classes\n    >>> feat = dataset.feature\n    >>> labels = dataset.labels\n    \"\"\"\n    file_urls = {\n        \"gossipcop\": \"dataset/FakeNewsGOS.zip\",\n        \"politifact\": \"dataset/FakeNewsPOL.zip\",\n    }\n\n    def __init__(self, name, feature_name, raw_dir=None, transform=None):\n        assert name in [\n            \"gossipcop\",\n            \"politifact\",\n        ], \"Only supports 'gossipcop' or 'politifact'.\"\n        url = 
_get_dgl_url(self.file_urls[name])\n\n        assert feature_name in [\n            \"bert\",\n            \"content\",\n            \"profile\",\n            \"spacy\",\n        ], \"Only supports 'bert', 'content', 'profile', or 'spacy'\"\n        self.feature_name = feature_name\n        super(FakeNewsDataset, self).__init__(\n            name=name, url=url, raw_dir=raw_dir, transform=transform\n        )\n\n    def process(self):\n        \"\"\"process raw data to graph, labels and masks\"\"\"\n        self.labels = F.tensor(\n            np.load(os.path.join(self.raw_path, \"graph_labels.npy\"))\n        )\n        num_graphs = self.labels.shape[0]\n\n        node_graph_id = np.load(\n            os.path.join(self.raw_path, \"node_graph_id.npy\")\n        )\n        edges = np.genfromtxt(\n            os.path.join(self.raw_path, \"A.txt\"), delimiter=\",\", dtype=int\n        )\n        src = edges[:, 0]\n        dst = edges[:, 1]\n        g = graph((src, dst))\n\n        node_idx_list = []\n        for idx in range(np.max(node_graph_id) + 1):\n            node_idx = np.where(node_graph_id == idx)\n            node_idx_list.append(node_idx[0])\n\n        self.graphs = [g.subgraph(node_idx) for node_idx in node_idx_list]\n\n        train_idx = np.load(os.path.join(self.raw_path, \"train_idx.npy\"))\n        val_idx = np.load(os.path.join(self.raw_path, \"val_idx.npy\"))\n        test_idx = np.load(os.path.join(self.raw_path, \"test_idx.npy\"))\n        train_mask = np.zeros(num_graphs, dtype=np.bool_)\n        val_mask = np.zeros(num_graphs, dtype=np.bool_)\n        test_mask = np.zeros(num_graphs, dtype=np.bool_)\n        train_mask[train_idx] = True\n        val_mask[val_idx] = True\n        test_mask[test_idx] = True\n        self.train_mask = F.tensor(train_mask)\n        self.val_mask = F.tensor(val_mask)\n        self.test_mask = F.tensor(test_mask)\n\n        feature_file = \"new_\" + self.feature_name + \"_feature.npz\"\n        self.feature = 
F.tensor(\n            sp.load_npz(os.path.join(self.raw_path, feature_file)).todense()\n        )\n\n    def save(self):\n        \"\"\"save the graph list and the labels\"\"\"\n        save_graphs(str(self.graph_path), self.graphs)\n        save_info(\n            self.info_path,\n            {\n                \"label\": self.labels,\n                \"feature\": self.feature,\n                \"train_mask\": self.train_mask,\n                \"val_mask\": self.val_mask,\n                \"test_mask\": self.test_mask,\n            },\n        )\n\n    @property\n    def graph_path(self):\n        return os.path.join(self.save_path, self.name + \"_dgl_graph.bin\")\n\n    @property\n    def info_path(self):\n        return os.path.join(self.save_path, self.name + \"_dgl_graph.pkl\")\n\n    def has_cache(self):\n        \"\"\"check whether there are processed data in `self.save_path`\"\"\"\n        return os.path.exists(self.graph_path) and os.path.exists(\n            self.info_path\n        )\n\n    def load(self):\n        \"\"\"load processed data from directory `self.save_path`\"\"\"\n        graphs, _ = load_graphs(str(self.graph_path))\n        info = load_info(str(self.info_path))\n        self.graphs = graphs\n        self.labels = info[\"label\"]\n        self.feature = info[\"feature\"]\n\n        self.train_mask = info[\"train_mask\"]\n        self.val_mask = info[\"val_mask\"]\n        self.test_mask = info[\"test_mask\"]\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes for each graph, i.e. 
number of prediction tasks.\"\"\"\n        return 2\n\n    @property\n    def num_graphs(self):\n        \"\"\"Number of graphs.\"\"\"\n        return self.labels.shape[0]\n\n    def __getitem__(self, i):\n        r\"\"\"Get graph and label by index\n\n        Parameters\n        ----------\n        i : int\n            Item index\n\n        Returns\n        -------\n        (:class:`dgl.DGLGraph`, Tensor)\n        \"\"\"\n        if self._transform is None:\n            g = self.graphs[i]\n        else:\n            g = self._transform(self.graphs[i])\n        return g, self.labels[i]\n\n    def __len__(self):\n        r\"\"\"Number of graphs in the dataset.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return len(self.graphs)\n"
  },
  {
    "path": "python/dgl/data/flickr.py",
    "content": "\"\"\"Flickr Dataset\"\"\"\nimport json\nimport os\n\nimport numpy as np\nimport scipy.sparse as sp\n\nfrom .. import backend as F\nfrom ..convert import from_scipy\nfrom ..transforms import reorder_graph\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs\n\n\nclass FlickrDataset(DGLBuiltinDataset):\n    r\"\"\"Flickr dataset for node classification from `GraphSAINT: Graph Sampling Based Inductive\n    Learning Method <https://arxiv.org/abs/1907.04931>`_\n\n    The task of this dataset is categorizing types of images based on the descriptions and common\n    properties of online images.\n\n    Flickr dataset statistics:\n\n    - Nodes: 89,250\n    - Edges: 899,756\n    - Number of classes: 7\n    - Node feature size: 500\n\n    Parameters\n    ----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False\n    verbose : bool\n        Whether to print out progress information.\n        Default: False\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    reorder : bool\n        Whether to reorder the graph using :func:`~dgl.reorder_graph`.\n        Default: False.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Examples\n    --------\n    >>> from dgl.data import FlickrDataset\n    >>> dataset = FlickrDataset()\n    >>> dataset.num_classes\n    7\n    >>> g = dataset[0]\n    >>> # get node feature\n    >>> feat = g.ndata['feat']\n    >>> # get node labels\n    >>> labels = g.ndata['label']\n    >>> # get data split\n    >>> train_mask = g.ndata['train_mask']\n    >>> val_mask = g.ndata['val_mask']\n    >>> test_mask = g.ndata['test_mask']\n    \"\"\"\n\n    def __init__(\n        self,\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n        reorder=False,\n    ):\n        _url = _get_dgl_url(\"dataset/flickr.zip\")\n        self._reorder = reorder\n        super(FlickrDataset, self).__init__(\n            name=\"flickr\",\n            raw_dir=raw_dir,\n            url=_url,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        \"\"\"process raw data to graph, labels and masks\"\"\"\n        coo_adj = sp.load_npz(os.path.join(self.raw_path, \"adj_full.npz\"))\n        g = from_scipy(coo_adj)\n\n        features = np.load(os.path.join(self.raw_path, \"feats.npy\"))\n        features = F.tensor(features, dtype=F.float32)\n\n        y = [-1] * features.shape[0]\n        with open(os.path.join(self.raw_path, \"class_map.json\")) as f:\n            class_map = json.load(f)\n            for key, item in class_map.items():\n                y[int(key)] = item\n        labels = F.tensor(np.array(y), dtype=F.int64)\n\n        with open(os.path.join(self.raw_path, \"role.json\")) as f:\n            role = json.load(f)\n\n        train_mask = 
np.zeros(features.shape[0], dtype=bool)\n        train_mask[role[\"tr\"]] = True\n\n        val_mask = np.zeros(features.shape[0], dtype=bool)\n        val_mask[role[\"va\"]] = True\n\n        test_mask = np.zeros(features.shape[0], dtype=bool)\n        test_mask[role[\"te\"]] = True\n\n        g.ndata[\"feat\"] = features\n        g.ndata[\"label\"] = labels\n        g.ndata[\"train_mask\"] = generate_mask_tensor(train_mask)\n        g.ndata[\"val_mask\"] = generate_mask_tensor(val_mask)\n        g.ndata[\"test_mask\"] = generate_mask_tensor(test_mask)\n\n        if self._reorder:\n            self._graph = reorder_graph(\n                g,\n                node_permute_algo=\"rcmk\",\n                edge_permute_algo=\"dst\",\n                store_ids=False,\n            )\n        else:\n            self._graph = g\n\n    def has_cache(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        return os.path.exists(graph_path)\n\n    def save(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        save_graphs(graph_path, self._graph)\n\n    def load(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        g, _ = load_graphs(graph_path)\n        self._graph = g[0]\n\n    @property\n    def num_classes(self):\n        return 7\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\"\"\"\n        return 1\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph object\n\n        Parameters\n        ----------\n        idx : int\n            Item index, FlickrDataset has only one graph object\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains:\n\n            - ``ndata['label']``: node label\n            - ``ndata['feat']``: node feature\n            - ``ndata['train_mask']``: mask for training node set\n            - ``ndata['val_mask']``: mask for validation node set\n            - 
``ndata['test_mask']``: mask for test node set\n\n        \"\"\"\n        assert idx == 0, \"This dataset has only one graph\"\n        if self._transform is None:\n            return self._graph\n        else:\n            return self._transform(self._graph)\n"
  },
  {
    "path": "python/dgl/data/fraud.py",
    "content": "\"\"\"Fraud Dataset\n\"\"\"\nimport os\n\nimport numpy as np\nfrom scipy import io\n\nfrom .. import backend as F\nfrom ..convert import heterograph\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, load_graphs, save_graphs\n\n\nclass FraudDataset(DGLBuiltinDataset):\n    r\"\"\"Fraud node prediction dataset.\n\n    The dataset includes two multi-relational graphs extracted from Yelp and Amazon\n    where nodes represent fraudulent reviews or fraudulent reviewers.\n\n    It was first proposed in a CIKM'20 paper <https://arxiv.org/pdf/2008.08692.pdf> and\n    has been used by a recent WWW'21 paper <https://ponderly.github.io/pub/PCGNN_WWW2021.pdf>\n    as a benchmark. Another paper <https://arxiv.org/pdf/2104.01404.pdf> also takes\n    the dataset as an example to study the non-homophilous graphs. This dataset is built\n    upon industrial data and has rich relational information and unique properties like\n    class-imbalance and feature inconsistency, which makes the dataset be a good instance\n    to investigate how GNNs perform on real-world noisy graphs. These graphs are bidirected\n    and not self connected.\n\n    Reference: <https://github.com/YingtongDou/CARE-GNN>\n\n    Parameters\n    ----------\n    name : str\n        Name of the dataset\n    raw_dir : str\n        Specifying the directory that will store the\n        downloaded data or the directory that\n        already stores the input data.\n        Default: ~/.dgl/\n    random_seed : int\n        Specifying the random seed in splitting the dataset.\n        Default: 717\n    train_size : float\n        training set size of the dataset.\n        Default: 0.7\n    val_size : float\n        validation set size of the dataset, and the\n        size of testing set is (1 - train_size - val_size)\n        Default: 0.1\n    force_reload : bool\n        Whether to reload the dataset. 
Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of label classes\n    graph : dgl.DGLGraph\n        Graph structure, etc.\n    seed : int\n        Random seed in splitting the dataset.\n    train_size : float\n        Training set size of the dataset.\n    val_size : float\n        Validation set size of the dataset\n\n    Examples\n    --------\n    >>> dataset = FraudDataset('yelp')\n    >>> graph = dataset[0]\n    >>> num_classes = dataset.num_classes\n    >>> feat = graph.ndata['feature']\n    >>> label = graph.ndata['label']\n    \"\"\"\n    file_urls = {\n        \"yelp\": \"dataset/FraudYelp.zip\",\n        \"amazon\": \"dataset/FraudAmazon.zip\",\n    }\n    relations = {\n        \"yelp\": [\"net_rsr\", \"net_rtr\", \"net_rur\"],\n        \"amazon\": [\"net_upu\", \"net_usu\", \"net_uvu\"],\n    }\n    file_names = {\"yelp\": \"YelpChi.mat\", \"amazon\": \"Amazon.mat\"}\n    node_name = {\"yelp\": \"review\", \"amazon\": \"user\"}\n\n    def __init__(\n        self,\n        name,\n        raw_dir=None,\n        random_seed=717,\n        train_size=0.7,\n        val_size=0.1,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        assert name in [\"yelp\", \"amazon\"], \"only supports 'yelp', or 'amazon'\"\n        url = _get_dgl_url(self.file_urls[name])\n        self.seed = random_seed\n        self.train_size = train_size\n        self.val_size = val_size\n        super(FraudDataset, self).__init__(\n            name=name,\n            url=url,\n            raw_dir=raw_dir,\n            hash_key=(random_seed, train_size, val_size),\n            
force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        \"\"\"process raw data to graph, labels, splitting masks\"\"\"\n        file_path = os.path.join(self.raw_path, self.file_names[self.name])\n\n        data = io.loadmat(file_path)\n        node_features = data[\"features\"].todense()\n        # remove additional dimension of length 1 in raw .mat file\n        node_labels = data[\"label\"].squeeze()\n\n        graph_data = {}\n        for relation in self.relations[self.name]:\n            adj = data[relation].tocoo()\n            row, col = adj.row, adj.col\n            graph_data[\n                (self.node_name[self.name], relation, self.node_name[self.name])\n            ] = (row, col)\n        g = heterograph(graph_data)\n\n        g.ndata[\"feature\"] = F.tensor(\n            node_features, dtype=F.data_type_dict[\"float32\"]\n        )\n        g.ndata[\"label\"] = F.tensor(\n            node_labels, dtype=F.data_type_dict[\"int64\"]\n        )\n        self.graph = g\n\n        self._random_split(\n            g.ndata[\"feature\"], self.seed, self.train_size, self.val_size\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph object\n\n        Parameters\n        ----------\n        idx : int\n            Item index\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n            graph structure, node features, node labels and masks\n\n            - ``ndata['feature']``: node features\n            - ``ndata['label']``: node labels\n            - ``ndata['train_mask']``: mask of training set\n            - ``ndata['val_mask']``: mask of validation set\n            - ``ndata['test_mask']``: mask of testing set\n        \"\"\"\n        assert idx == 0, \"This dataset has only one graph\"\n        if self._transform is None:\n            return self.graph\n        else:\n            return self._transform(self.graph)\n\n    def __len__(self):\n 
       \"\"\"number of data examples\"\"\"\n        return len(self.graph)\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return 2\n\n    def save(self):\n        \"\"\"save processed data to directory `self.save_path`\"\"\"\n        graph_path = os.path.join(\n            self.save_path, self.name + \"_dgl_graph_{}.bin\".format(self.hash)\n        )\n        save_graphs(str(graph_path), self.graph)\n\n    def load(self):\n        \"\"\"load processed data from directory `self.save_path`\"\"\"\n        graph_path = os.path.join(\n            self.save_path, self.name + \"_dgl_graph_{}.bin\".format(self.hash)\n        )\n        graph_list, _ = load_graphs(str(graph_path))\n        g = graph_list[0]\n        self.graph = g\n\n    def has_cache(self):\n        \"\"\"check whether there are processed data in `self.save_path`\"\"\"\n        graph_path = os.path.join(\n            self.save_path, self.name + \"_dgl_graph_{}.bin\".format(self.hash)\n        )\n        return os.path.exists(graph_path)\n\n    def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1):\n        \"\"\"split the dataset into training set, validation set and testing set\"\"\"\n\n        assert 0 <= train_size + val_size <= 1, (\n            \"The sum of valid training set size and validation set size \"\n            \"must between 0 and 1 (inclusive).\"\n        )\n\n        N = x.shape[0]\n        index = np.arange(N)\n        if self.name == \"amazon\":\n            # 0-3304 are unlabeled nodes\n            index = np.arange(3305, N)\n\n        index = np.random.RandomState(seed).permutation(index)\n        train_idx = index[: int(train_size * len(index))]\n        val_idx = index[len(index) - int(val_size * len(index)) :]\n        test_idx = index[\n            int(train_size * len(index)) : len(index)\n            - int(val_size * len(index))\n        ]\n        train_mask = 
np.zeros(N, dtype=np.bool_)\n        val_mask = np.zeros(N, dtype=np.bool_)\n        test_mask = np.zeros(N, dtype=np.bool_)\n        train_mask[train_idx] = True\n        val_mask[val_idx] = True\n        test_mask[test_idx] = True\n        self.graph.ndata[\"train_mask\"] = F.tensor(train_mask)\n        self.graph.ndata[\"val_mask\"] = F.tensor(val_mask)\n        self.graph.ndata[\"test_mask\"] = F.tensor(test_mask)\n\n\nclass FraudYelpDataset(FraudDataset):\n    r\"\"\"Fraud Yelp Dataset\n\n    The Yelp dataset includes hotel and restaurant reviews filtered (spam) and recommended\n    (legitimate) by Yelp. A spam review detection task can be conducted, which is a binary\n    classification task. 32 handcrafted features from <http://dx.doi.org/10.1145/2783258.2783370>\n    are taken as the raw node features. Reviews are nodes in the graph, and three relations are:\n\n        1. R-U-R: it connects reviews posted by the same user\n        2. R-S-R: it connects reviews under the same product with the same star rating (1-5 stars)\n        3. 
R-T-R: it connects two reviews under the same product posted in the same month.\n\n    Statistics:\n\n    - Nodes: 45,954\n    - Edges:\n\n        - R-U-R: 98,630\n        - R-T-R: 1,147,232\n        - R-S-R: 6,805,486\n\n    - Classes:\n\n        - Positive (spam): 6,677\n        - Negative (legitimate): 39,277\n\n    - Positive-Negative ratio: 1 : 5.9\n    - Node feature size: 32\n\n    Parameters\n    ----------\n    raw_dir : str\n        Specifying the directory that will store the\n        downloaded data or the directory that\n        already stores the input data.\n        Default: ~/.dgl/\n    random_seed : int\n        Specifying the random seed in splitting the dataset.\n        Default: 717\n    train_size : float\n        training set size of the dataset.\n        Default: 0.7\n    val_size : float\n        validation set size of the dataset, and the\n        size of testing set is (1 - train_size - val_size)\n        Default: 0.1\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Examples\n    --------\n    >>> dataset = FraudYelpDataset()\n    >>> graph = dataset[0]\n    >>> num_classes = dataset.num_classes\n    >>> feat = graph.ndata['feature']\n    >>> label = graph.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self,\n        raw_dir=None,\n        random_seed=717,\n        train_size=0.7,\n        val_size=0.1,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        super(FraudYelpDataset, self).__init__(\n            name=\"yelp\",\n            raw_dir=raw_dir,\n            random_seed=random_seed,\n            train_size=train_size,\n            val_size=val_size,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n\nclass FraudAmazonDataset(FraudDataset):\n    r\"\"\"Fraud Amazon Dataset\n\n    The Amazon dataset includes product reviews under the Musical Instruments category.\n    Users with more than 80% helpful votes are labelled as benign entities and users with\n    less than 20% helpful votes are labelled as fraudulent entities. A fraudulent user\n    detection task can be conducted on the Amazon dataset, which is a binary classification\n    task. 25 handcrafted features from <https://arxiv.org/pdf/2005.10150.pdf> are taken as\n    the raw node features .\n\n    Users are nodes in the graph, and three relations are:\n    1. U-P-U : it connects users reviewing at least one same product\n    2. U-S-U : it connects users having at least one same star rating within one week\n    3. 
U-V-U : it connects users with top 5% mutual review text similarities (measured by\n    TF-IDF) among all users.\n\n    Statistics:\n\n    - Nodes: 11,944\n    - Edges:\n\n        - U-P-U: 351,216\n        - U-S-U: 7,132,958\n        - U-V-U: 2,073,474\n\n    - Classes:\n\n        - Positive (fraudulent): 821\n        - Negative (benign): 7,818\n        - Unlabeled: 3,305\n\n    - Positive-Negative ratio: 1 : 10.5\n    - Node feature size: 25\n\n    Parameters\n    ----------\n    raw_dir : str\n        Specifying the directory that will store the\n        downloaded data or the directory that\n        already stores the input data.\n        Default: ~/.dgl/\n    random_seed : int\n        Specifying the random seed in splitting the dataset.\n        Default: 717\n    train_size : float\n        training set size of the dataset.\n        Default: 0.7\n    val_size : float\n        validation set size of the dataset, and the\n        size of testing set is (1 - train_size - val_size)\n        Default: 0.1\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Examples\n    --------\n    >>> dataset = FraudAmazonDataset()\n    >>> graph = dataset[0]\n    >>> num_classes = dataset.num_classes\n    >>> feat = graph.ndata['feature']\n    >>> label = graph.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self,\n        raw_dir=None,\n        random_seed=717,\n        train_size=0.7,\n        val_size=0.1,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        super(FraudAmazonDataset, self).__init__(\n            name=\"amazon\",\n            raw_dir=raw_dir,\n            random_seed=random_seed,\n            train_size=train_size,\n            val_size=val_size,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n"
  },
  {
    "path": "python/dgl/data/gdelt.py",
    "content": "\"\"\" GDELT dataset for temporal graph \"\"\"\nimport os\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..convert import graph as dgl_graph\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, load_info, loadtxt, save_info\n\n\nclass GDELTDataset(DGLBuiltinDataset):\n    r\"\"\"GDELT dataset for event-based temporal graph\n\n    The Global Database of Events, Language, and Tone (GDELT) dataset.\n    This contains events happened all over the world (i.e., every protest held\n    anywhere in Russia on a given day is collapsed to a single entry).\n    This Dataset consists of events collected from 1/1/2018 to 1/31/2018\n    (15 minutes time granularity).\n\n    Reference:\n\n        - `Recurrent Event Network for Reasoning over Temporal Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_\n        - `The Global Database of Events, Language, and Tone (GDELT) <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_\n\n    Statistics:\n\n    - Train examples: 2,304\n    - Valid examples: 288\n    - Test examples: 384\n\n    Parameters\n    ----------\n    mode : str\n        Must be one of ('train', 'valid', 'test'). Default: 'train'\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: False.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    start_time : int\n        Start time of the temporal graph\n    end_time : int\n        End time of the temporal graph\n    is_temporal : bool\n        Does the dataset contain temporal graphs\n\n    Examples\n    ----------\n    >>> # get train, valid, test dataset\n    >>> train_data = GDELTDataset()\n    >>> valid_data = GDELTDataset(mode='valid')\n    >>> test_data = GDELTDataset(mode='test')\n    >>>\n    >>> # length of train set\n    >>> train_size = len(train_data)\n    >>>\n    >>> for g in train_data:\n    ....    e_feat = g.edata['rel_type']\n    ....    # your code here\n    ....\n    >>>\n    \"\"\"\n\n    def __init__(\n        self,\n        mode=\"train\",\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        mode = mode.lower()\n        assert mode in [\"train\", \"valid\", \"test\"], \"Mode not valid.\"\n        self.mode = mode\n        self.num_nodes = 23033\n        _url = _get_dgl_url(\"dataset/gdelt.zip\")\n        super(GDELTDataset, self).__init__(\n            name=\"GDELT\",\n            url=_url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        file_path = os.path.join(self.raw_path, self.mode + \".txt\")\n        self.data = loadtxt(file_path, delimiter=\"\\t\").astype(np.int64)\n\n        # The source code is not released, but the paper indicates there're\n        # totally 137 samples. 
The cutoff below has exactly 137 samples.\n        self.time_index = np.floor(self.data[:, 3] / 15).astype(np.int64)\n        self._start_time = self.time_index.min()\n        self._end_time = self.time_index.max()\n\n    @property\n    def info_path(self):\n        return os.path.join(self.save_path, self.mode + \"_info.pkl\")\n\n    def has_cache(self):\n        return os.path.exists(self.info_path)\n\n    def save(self):\n        save_info(\n            self.info_path,\n            {\n                \"data\": self.data,\n                \"time_index\": self.time_index,\n                \"start_time\": self.start_time,\n                \"end_time\": self.end_time,\n            },\n        )\n\n    def load(self):\n        info = load_info(self.info_path)\n        self.data, self.time_index, self._start_time, self._end_time = (\n            info[\"data\"],\n            info[\"time_index\"],\n            info[\"start_time\"],\n            info[\"end_time\"],\n        )\n\n    @property\n    def start_time(self):\n        r\"\"\"Start time of events in the temporal graph\n\n        Returns\n        -------\n        int\n        \"\"\"\n        return self._start_time\n\n    @property\n    def end_time(self):\n        r\"\"\"End time of events in the temporal graph\n\n        Returns\n        -------\n        int\n        \"\"\"\n        return self._end_time\n\n    def __getitem__(self, t):\n        r\"\"\"Get the graph with events up to time `t + self.start_time`\n\n        Parameters\n        ----------\n        t : int\n            Time, its value must be in range [0, `self.end_time` - `self.start_time`]\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains:\n\n            - ``edata['rel_type']``: edge type\n        \"\"\"\n        if t >= len(self) or t < 0:\n            raise IndexError(\"Index out of range\")\n        i = t + self.start_time\n        row_mask = self.time_index <= i\n        edges = 
self.data[row_mask][:, [0, 2]]\n        rate = self.data[row_mask][:, 1]\n        g = dgl_graph((edges[:, 0], edges[:, 1]))\n        g.edata[\"rel_type\"] = F.tensor(\n            rate.reshape(-1, 1), dtype=F.data_type_dict[\"int64\"]\n        )\n        if self._transform is not None:\n            g = self._transform(g)\n        return g\n\n    def __len__(self):\n        r\"\"\"Number of graphs in the dataset.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return self._end_time - self._start_time + 1\n\n    @property\n    def is_temporal(self):\n        r\"\"\"Does the dataset contain temporal graphs\n\n        Returns\n        -------\n        bool\n        \"\"\"\n        return True\n\n\nGDELT = GDELTDataset\n"
  },
  {
    "path": "python/dgl/data/geom_gcn.py",
    "content": "\"\"\"Datasets introduced in the Geom-GCN paper.\"\"\"\nimport os\n\nimport numpy as np\n\nfrom ..convert import graph\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url\n\n\nclass GeomGCNDataset(DGLBuiltinDataset):\n    r\"\"\"Datasets introduced in\n    `Geom-GCN: Geometric Graph Convolutional Networks\n    <https://arxiv.org/abs/2002.05287>`__\n\n    Parameters\n    ----------\n    name : str\n        Name of the dataset.\n    raw_dir : str\n        Raw file directory to store the processed data.\n    force_reload : bool\n        Whether to re-download the data source.\n    verbose : bool\n        Whether to print progress information.\n    transform : callable\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    \"\"\"\n\n    def __init__(self, name, raw_dir, force_reload, verbose, transform):\n        url = _get_dgl_url(f\"dataset/{name}.zip\")\n        super(GeomGCNDataset, self).__init__(\n            name=name,\n            url=url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        \"\"\"Load and process the data.\"\"\"\n        try:\n            import torch\n        except ImportError:\n            raise ModuleNotFoundError(\n                \"This dataset requires PyTorch to be the backend.\"\n            )\n\n        # Process node features and labels.\n        with open(f\"{self.raw_path}/out1_node_feature_label.txt\", \"r\") as f:\n            data = f.read().split(\"\\n\")[1:-1]\n        features = [\n            [float(v) for v in r.split(\"\\t\")[1].split(\",\")] for r in data\n        ]\n        features = torch.tensor(features, dtype=torch.float)\n        labels = [int(r.split(\"\\t\")[2]) for r in data]\n        self._num_classes = 
max(labels) + 1\n        labels = torch.tensor(labels, dtype=torch.long)\n\n        # Process graph structure.\n        with open(f\"{self.raw_path}/out1_graph_edges.txt\", \"r\") as f:\n            data = f.read().split(\"\\n\")[1:-1]\n            data = [[int(v) for v in r.split(\"\\t\")] for r in data]\n        dst, src = torch.tensor(data, dtype=torch.long).t().contiguous()\n\n        self._g = graph((src, dst), num_nodes=features.size(0))\n        self._g.ndata[\"feat\"] = features\n        self._g.ndata[\"label\"] = labels\n\n        # Process 10 train/val/test node splits.\n        train_masks, val_masks, test_masks = [], [], []\n        for i in range(10):\n            filepath = f\"{self.raw_path}/{self.name}_split_0.6_0.2_{i}.npz\"\n            f = np.load(filepath)\n            train_masks += [torch.from_numpy(f[\"train_mask\"])]\n            val_masks += [torch.from_numpy(f[\"val_mask\"])]\n            test_masks += [torch.from_numpy(f[\"test_mask\"])]\n        self._g.ndata[\"train_mask\"] = torch.stack(train_masks, dim=1).bool()\n        self._g.ndata[\"val_mask\"] = torch.stack(val_masks, dim=1).bool()\n        self._g.ndata[\"test_mask\"] = torch.stack(test_masks, dim=1).bool()\n\n    def has_cache(self):\n        return os.path.exists(self.raw_path)\n\n    def load(self):\n        self.process()\n\n    def __getitem__(self, idx):\n        assert idx == 0, \"This dataset has only one graph.\"\n        if self._transform is None:\n            return self._g\n        else:\n            return self._transform(self._g)\n\n    def __len__(self):\n        return 1\n\n    @property\n    def num_classes(self):\n        return self._num_classes\n\n\nclass ChameleonDataset(GeomGCNDataset):\n    r\"\"\"Wikipedia page-page network on chameleons from `Multi-scale Attributed\n    Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later modified by\n    `Geom-GCN: Geometric Graph Convolutional Networks\n    <https://arxiv.org/abs/2002.05287>`__\n\n    Nodes 
represent articles from the English Wikipedia, edges reflect mutual\n    links between them. Node features indicate the presence of particular nouns\n    in the articles. The nodes were classified into 5 classes in terms of their\n    average monthly traffic.\n\n    Statistics:\n\n    - Nodes: 2277\n    - Edges: 36101\n    - Number of Classes: 5\n    - 10 train/val/test splits\n\n        - Train: 1092\n        - Val: 729\n        - Test: 456\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download the data source. Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. 
Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Notes\n    -----\n    The graph does not come with edges for both directions.\n\n    Examples\n    --------\n\n    >>> from dgl.data import ChameleonDataset\n    >>> dataset = ChameleonDataset()\n    >>> g = dataset[0]\n    >>> num_classes = dataset.num_classes\n\n    >>> # get node features\n    >>> feat = g.ndata[\"feat\"]\n\n    >>> # get data split\n    >>> train_mask = g.ndata[\"train_mask\"]\n    >>> val_mask = g.ndata[\"val_mask\"]\n    >>> test_mask = g.ndata[\"test_mask\"]\n\n    >>> # get labels\n    >>> label = g.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(ChameleonDataset, self).__init__(\n            name=\"chameleon\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n\nclass SquirrelDataset(GeomGCNDataset):\n    r\"\"\"Wikipedia page-page network on squirrels from `Multi-scale Attributed\n    Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later modified by\n    `Geom-GCN: Geometric Graph Convolutional Networks\n    <https://arxiv.org/abs/2002.05287>`__\n\n    Nodes represent articles from the English Wikipedia, edges reflect mutual\n    links between them. Node features indicate the presence of particular nouns\n    in the articles. The nodes were classified into 5 classes in terms of their\n    average monthly traffic.\n\n    Statistics:\n\n    - Nodes: 5201\n    - Edges: 217073\n    - Number of Classes: 5\n    - 10 train/val/test splits\n\n        - Train: 2496\n        - Val: 1664\n        - Test: 1041\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download the data source. 
Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Notes\n    -----\n    The graph does not come with edges for both directions.\n\n    Examples\n    --------\n\n    >>> from dgl.data import SquirrelDataset\n    >>> dataset = SquirrelDataset()\n    >>> g = dataset[0]\n    >>> num_classes = dataset.num_classes\n\n    >>> # get node features\n    >>> feat = g.ndata[\"feat\"]\n\n    >>> # get data split\n    >>> train_mask = g.ndata[\"train_mask\"]\n    >>> val_mask = g.ndata[\"val_mask\"]\n    >>> test_mask = g.ndata[\"test_mask\"]\n\n    >>> # get labels\n    >>> label = g.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(SquirrelDataset, self).__init__(\n            name=\"squirrel\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n\nclass CornellDataset(GeomGCNDataset):\n    r\"\"\"Cornell subset of\n    `WebKB <http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-11/www/wwkb/>`__,\n    later modified by `Geom-GCN: Geometric Graph Convolutional Networks\n    <https://arxiv.org/abs/2002.05287>`__\n\n    Nodes represent web pages. Edges represent hyperlinks between them. Node\n    features are the bag-of-words representation of web pages. 
The web pages\n    are manually classified into the five categories, student, project, course,\n    staff, and faculty.\n\n    Statistics:\n\n    - Nodes: 183\n    - Edges: 298\n    - Number of Classes: 5\n    - 10 train/val/test splits\n\n        - Train: 87\n        - Val: 59\n        - Test: 37\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download the data source. Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Notes\n    -----\n    The graph does not come with edges for both directions.\n\n    Examples\n    --------\n\n    >>> from dgl.data import CornellDataset\n    >>> dataset = CornellDataset()\n    >>> g = dataset[0]\n    >>> num_classes = dataset.num_classes\n\n    >>> # get node features\n    >>> feat = g.ndata[\"feat\"]\n\n    >>> # get data split\n    >>> train_mask = g.ndata[\"train_mask\"]\n    >>> val_mask = g.ndata[\"val_mask\"]\n    >>> test_mask = g.ndata[\"test_mask\"]\n\n    >>> # get labels\n    >>> label = g.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(CornellDataset, self).__init__(\n            name=\"cornell\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n\nclass TexasDataset(GeomGCNDataset):\n    r\"\"\"Texas subset of\n    `WebKB 
<http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-11/www/wwkb/>`__,\n    later modified by `Geom-GCN: Geometric Graph Convolutional Networks\n    <https://arxiv.org/abs/2002.05287>`__\n\n    Nodes represent web pages. Edges represent hyperlinks between them. Node\n    features are the bag-of-words representation of web pages. The web pages\n    are manually classified into the five categories, student, project, course,\n    staff, and faculty.\n\n    Statistics:\n\n    - Nodes: 183\n    - Edges: 325\n    - Number of Classes: 5\n    - 10 train/val/test splits\n\n        - Train: 87\n        - Val: 59\n        - Test: 37\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download the data source. Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. 
Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Notes\n    -----\n    The graph does not come with edges for both directions.\n\n    Examples\n    --------\n\n    >>> from dgl.data import TexasDataset\n    >>> dataset = TexasDataset()\n    >>> g = dataset[0]\n    >>> num_classes = dataset.num_classes\n\n    >>> # get node features\n    >>> feat = g.ndata[\"feat\"]\n\n    >>> # get data split\n    >>> train_mask = g.ndata[\"train_mask\"]\n    >>> val_mask = g.ndata[\"val_mask\"]\n    >>> test_mask = g.ndata[\"test_mask\"]\n\n    >>> # get labels\n    >>> label = g.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(TexasDataset, self).__init__(\n            name=\"texas\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n\nclass WisconsinDataset(GeomGCNDataset):\n    r\"\"\"Wisconsin subset of\n    `WebKB <http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-11/www/wwkb/>`__,\n    later modified by `Geom-GCN: Geometric Graph Convolutional Networks\n    <https://arxiv.org/abs/2002.05287>`__\n\n    Nodes represent web pages. Edges represent hyperlinks between them. Node\n    features are the bag-of-words representation of web pages. The web pages\n    are manually classified into the five categories, student, project, course,\n    staff, and faculty.\n\n    Statistics:\n\n    - Nodes: 251\n    - Edges: 515\n    - Number of Classes: 5\n    - 10 train/val/test splits\n\n        - Train: 120\n        - Val: 80\n        - Test: 51\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download the data source. 
Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Notes\n    -----\n    The graph does not come with edges for both directions.\n\n    Examples\n    --------\n\n    >>> from dgl.data import WisconsinDataset\n    >>> dataset = WisconsinDataset()\n    >>> g = dataset[0]\n    >>> num_classes = dataset.num_classes\n\n    >>> # get node features\n    >>> feat = g.ndata[\"feat\"]\n\n    >>> # get data split\n    >>> train_mask = g.ndata[\"train_mask\"]\n    >>> val_mask = g.ndata[\"val_mask\"]\n    >>> test_mask = g.ndata[\"test_mask\"]\n\n    >>> # get labels\n    >>> label = g.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(WisconsinDataset, self).__init__(\n            name=\"wisconsin\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n"
  },
  {
    "path": "python/dgl/data/gindt.py",
    "content": "\"\"\"Datasets used in How Powerful Are Graph Neural Networks?\n(chen jun)\nDatasets include:\nMUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K\nhttps://github.com/weihua916/powerful-gnns/blob/master/dataset.zip\n\"\"\"\n\nimport os\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..convert import graph as dgl_graph\nfrom ..utils import retry_method_with_fix\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import (\n    download,\n    extract_archive,\n    load_graphs,\n    load_info,\n    loadtxt,\n    save_graphs,\n    save_info,\n)\n\n\nclass GINDataset(DGLBuiltinDataset):\n    \"\"\"Dataset Class for `How Powerful Are Graph Neural Networks? <https://arxiv.org/abs/1810.00826>`_.\n\n    This is adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.\n\n    The class provides an interface for nine datasets used in the paper along with the paper-specific\n    settings. The datasets are ``'MUTAG'``, ``'COLLAB'``, ``'IMDBBINARY'``, ``'IMDBMULTI'``,\n    ``'NCI1'``, ``'PROTEINS'``, ``'PTC'``, ``'REDDITBINARY'``, ``'REDDITMULTI5K'``.\n\n    If ``degree_as_nlabel`` is set to ``False``, then ``ndata['label']`` stores the provided node label,\n    otherwise ``ndata['label']`` stores the node in-degrees.\n\n    For graphs that have node attributes, ``ndata['attr']`` stores the node attributes.\n    For graphs that have no attribute, ``ndata['attr']`` stores the corresponding one-hot encoding\n    of ``ndata['label']``.\n\n    Parameters\n    ---------\n    name: str\n        dataset name, one of\n        (``'MUTAG'``, ``'COLLAB'``, \\\n        ``'IMDBBINARY'``, ``'IMDBMULTI'``, \\\n        ``'NCI1'``, ``'PROTEINS'``, ``'PTC'``, \\\n        ``'REDDITBINARY'``, ``'REDDITMULTI5K'``)\n    self_loop: bool\n        add self to self edge if true\n    degree_as_nlabel: bool\n        take node degree as label and feature if true\n    transform: callable, optional\n        A 
transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes for multiclass classification\n\n    Examples\n    --------\n    >>> data = GINDataset(name='MUTAG', self_loop=False)\n\n    The dataset instance is an iterable\n\n    >>> len(data)\n    188\n    >>> g, label = data[128]\n    >>> g\n    Graph(num_nodes=13, num_edges=26,\n          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(7,), dtype=torch.float32)}\n          edata_schemes={})\n    >>> label\n    tensor(1)\n\n    Batch the graphs and labels for mini-batch training\n\n    >>> graphs, labels = zip(*[data[i] for i in range(16)])\n    >>> batched_graphs = dgl.batch(graphs)\n    >>> batched_labels = torch.tensor(labels)\n    >>> batched_graphs\n    Graph(num_nodes=330, num_edges=748,\n          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(7,), dtype=torch.float32)}\n          edata_schemes={})\n    \"\"\"\n\n    def __init__(\n        self,\n        name,\n        self_loop,\n        degree_as_nlabel=False,\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        self._name = name  # MUTAG\n        gin_url = \"https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip\"\n        self.ds_name = \"nig\"\n\n        self.self_loop = self_loop\n        self.graphs = []\n        self.labels = []\n\n        # relabel\n        self.glabel_dict = {}\n        self.nlabel_dict = {}\n        self.elabel_dict = {}\n        self.ndegree_dict = {}\n\n        # global num\n        self.N = 0  # total graphs number\n        self.n = 0  # total nodes number\n        self.m = 0  # total edges number\n\n        # global num of classes\n        self.gclasses = 0\n        
self.nclasses = 0\n        self.eclasses = 0\n        self.dim_nfeats = 0\n\n        # flags\n        self.degree_as_nlabel = degree_as_nlabel\n        self.nattrs_flag = False\n        self.nlabels_flag = False\n\n        super(GINDataset, self).__init__(\n            name=name,\n            url=gin_url,\n            hash_key=(name, self_loop, degree_as_nlabel),\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    @property\n    def raw_path(self):\n        return os.path.join(self.raw_dir, \"GINDataset\")\n\n    def download(self):\n        r\"\"\"Automatically download data and extract it.\"\"\"\n        zip_file_path = os.path.join(self.raw_dir, \"GINDataset.zip\")\n        download(self.url, path=zip_file_path)\n        extract_archive(zip_file_path, self.raw_path)\n\n    def __len__(self):\n        \"\"\"Return the number of graphs in the dataset.\"\"\"\n        return len(self.graphs)\n\n    def __getitem__(self, idx):\n        \"\"\"Get the idx-th sample.\n\n        Parameters\n        ---------\n        idx : int\n            The sample index.\n\n        Returns\n        -------\n        (:class:`dgl.Graph`, Tensor)\n            The graph and its label.\n        \"\"\"\n        if self._transform is None:\n            g = self.graphs[idx]\n        else:\n            g = self._transform(self.graphs[idx])\n        return g, self.labels[idx]\n\n    def _file_path(self):\n        return os.path.join(\n            self.raw_dir,\n            \"GINDataset\",\n            \"dataset\",\n            self.name,\n            \"{}.txt\".format(self.name),\n        )\n\n    def process(self):\n        \"\"\"Loads input dataset from dataset/NAME/NAME.txt file\"\"\"\n        if self.verbose:\n            print(\"loading data...\")\n        self.file = self._file_path()\n        with open(self.file, \"r\") as f:\n            # line_1 == N, total number of graphs\n            
self.N = int(f.readline().strip())\n\n            for i in range(self.N):\n                if (i + 1) % 10 == 0 and self.verbose is True:\n                    print(\"processing graph {}...\".format(i + 1))\n\n                grow = f.readline().strip().split()\n                # line_2 == [n_nodes, l] is equal to\n                # [node number of a graph, class label of a graph]\n                n_nodes, glabel = [int(w) for w in grow]\n\n                # relabel graphs\n                if glabel not in self.glabel_dict:\n                    mapped = len(self.glabel_dict)\n                    self.glabel_dict[glabel] = mapped\n\n                self.labels.append(self.glabel_dict[glabel])\n\n                g = dgl_graph(([], []))\n                g.add_nodes(n_nodes)\n\n                nlabels = []  # node labels\n                nattrs = []  # node attributes if it has\n                m_edges = 0\n\n                for j in range(n_nodes):\n                    nrow = f.readline().strip().split()\n\n                    # handle edges and attributes(if has)\n                    tmp = int(nrow[1]) + 2  # tmp == 2 + #edges\n                    if tmp == len(nrow):\n                        # no node attributes\n                        nrow = [int(w) for w in nrow]\n                    elif tmp > len(nrow):\n                        nrow = [int(w) for w in nrow[:tmp]]\n                        nattr = [float(w) for w in nrow[tmp:]]\n                        nattrs.append(nattr)\n                    else:\n                        raise Exception(\"edge number is incorrect!\")\n\n                    # relabel nodes if it has labels\n                    # if it doesn't have node labels, then every nrow[0]==0\n                    if not nrow[0] in self.nlabel_dict:\n                        mapped = len(self.nlabel_dict)\n                        self.nlabel_dict[nrow[0]] = mapped\n\n                    nlabels.append(self.nlabel_dict[nrow[0]])\n\n                    
m_edges += nrow[1]\n                    g.add_edges(j, nrow[2:])\n\n                    # add self loop\n                    if self.self_loop:\n                        m_edges += 1\n                        g.add_edges(j, j)\n\n                    if (j + 1) % 10 == 0 and self.verbose is True:\n                        print(\n                            \"processing node {} of graph {}...\".format(\n                                j + 1, i + 1\n                            )\n                        )\n                        print(\"this node has {} edgs.\".format(nrow[1]))\n\n                if nattrs != []:\n                    nattrs = np.stack(nattrs)\n                    g.ndata[\"attr\"] = F.tensor(nattrs, F.float32)\n                    self.nattrs_flag = True\n\n                g.ndata[\"label\"] = F.tensor(nlabels)\n                if len(self.nlabel_dict) > 1:\n                    self.nlabels_flag = True\n\n                assert g.num_nodes() == n_nodes\n\n                # update statistics of graphs\n                self.n += n_nodes\n                self.m += m_edges\n\n                self.graphs.append(g)\n\n        self.labels = F.tensor(self.labels)\n        # if no attr\n        if not self.nattrs_flag:\n            if self.verbose:\n                print(\"there are no node features in this dataset!\")\n            # generate node attr by node degree\n            if self.degree_as_nlabel:\n                if self.verbose:\n                    print(\"generate node features by node degree...\")\n                for g in self.graphs:\n                    # actually this label shouldn't be updated\n                    # in case users want to keep it\n                    # but usually no features means no labels, fine.\n                    g.ndata[\"label\"] = g.in_degrees()\n                    # extracting unique node labels\n\n            # in case the labels/degrees are not continuous number\n            nlabel_set = set([])\n            for g 
in self.graphs:\n                nlabel_set = nlabel_set.union(\n                    set([F.as_scalar(nl) for nl in g.ndata[\"label\"]])\n                )\n            nlabel_set = list(nlabel_set)\n            is_label_valid = all(\n                [label in self.nlabel_dict for label in nlabel_set]\n            )\n            if (\n                is_label_valid\n                and len(nlabel_set) == np.max(nlabel_set) + 1\n                and np.min(nlabel_set) == 0\n            ):\n                # Note this is different from the author's implementation. In weihua916's implementation,\n                # the labels are relabeled anyway. But here we didn't relabel it if the labels are contiguous\n                # to make it consistent with the original dataset\n                label2idx = self.nlabel_dict\n            else:\n                label2idx = {nlabel_set[i]: i for i in range(len(nlabel_set))}\n            # generate node attr by node label\n            for g in self.graphs:\n                attr = np.zeros((g.num_nodes(), len(label2idx)))\n                attr[\n                    range(g.num_nodes()),\n                    [\n                        label2idx[nl]\n                        for nl in F.asnumpy(g.ndata[\"label\"]).tolist()\n                    ],\n                ] = 1\n                g.ndata[\"attr\"] = F.tensor(attr, F.float32)\n\n        # after load, get the #classes and #dim\n        self.gclasses = len(self.glabel_dict)\n        self.nclasses = len(self.nlabel_dict)\n        self.eclasses = len(self.elabel_dict)\n        self.dim_nfeats = len(self.graphs[0].ndata[\"attr\"][0])\n\n        if self.verbose:\n            print(\"Done.\")\n            print(\n                \"\"\"\n                -------- Data Statistics --------'\n                #Graphs: %d\n                #Graph Classes: %d\n                #Nodes: %d\n                #Node Classes: %d\n                #Node Features Dim: %d\n                #Edges: %d\n        
        #Edge Classes: %d\n                Avg. of #Nodes: %.2f\n                Avg. of #Edges: %.2f\n                Graph Relabeled: %s\n                Node Relabeled: %s\n                Degree Relabeled(If degree_as_nlabel=True): %s \\n \"\"\"\n                % (\n                    self.N,\n                    self.gclasses,\n                    self.n,\n                    self.nclasses,\n                    self.dim_nfeats,\n                    self.m,\n                    self.eclasses,\n                    self.n / self.N,\n                    self.m / self.N,\n                    self.glabel_dict,\n                    self.nlabel_dict,\n                    self.ndegree_dict,\n                )\n            )\n\n    def save(self):\n        label_dict = {\"labels\": self.labels}\n        info_dict = {\n            \"N\": self.N,\n            \"n\": self.n,\n            \"m\": self.m,\n            \"self_loop\": self.self_loop,\n            \"gclasses\": self.gclasses,\n            \"nclasses\": self.nclasses,\n            \"eclasses\": self.eclasses,\n            \"dim_nfeats\": self.dim_nfeats,\n            \"degree_as_nlabel\": self.degree_as_nlabel,\n            \"glabel_dict\": self.glabel_dict,\n            \"nlabel_dict\": self.nlabel_dict,\n            \"elabel_dict\": self.elabel_dict,\n            \"ndegree_dict\": self.ndegree_dict,\n        }\n        save_graphs(str(self.graph_path), self.graphs, label_dict)\n        save_info(str(self.info_path), info_dict)\n\n    def load(self):\n        graphs, label_dict = load_graphs(str(self.graph_path))\n        info_dict = load_info(str(self.info_path))\n\n        self.graphs = graphs\n        self.labels = label_dict[\"labels\"]\n\n        self.N = info_dict[\"N\"]\n        self.n = info_dict[\"n\"]\n        self.m = info_dict[\"m\"]\n        self.self_loop = info_dict[\"self_loop\"]\n        self.gclasses = info_dict[\"gclasses\"]\n        self.nclasses = info_dict[\"nclasses\"]\n        
self.eclasses = info_dict[\"eclasses\"]\n        self.dim_nfeats = info_dict[\"dim_nfeats\"]\n        self.glabel_dict = info_dict[\"glabel_dict\"]\n        self.nlabel_dict = info_dict[\"nlabel_dict\"]\n        self.elabel_dict = info_dict[\"elabel_dict\"]\n        self.ndegree_dict = info_dict[\"ndegree_dict\"]\n        self.degree_as_nlabel = info_dict[\"degree_as_nlabel\"]\n\n    @property\n    def graph_path(self):\n        return os.path.join(\n            self.save_path, \"gin_{}_{}.bin\".format(self.name, self.hash)\n        )\n\n    @property\n    def info_path(self):\n        return os.path.join(\n            self.save_path, \"gin_{}_{}.pkl\".format(self.name, self.hash)\n        )\n\n    def has_cache(self):\n        if os.path.exists(self.graph_path) and os.path.exists(self.info_path):\n            return True\n        return False\n\n    @property\n    def num_classes(self):\n        return self.gclasses\n"
  },
  {
    "path": "python/dgl/data/gnn_benchmark.py",
    "content": "\"\"\"GNN Benchmark datasets for node classification.\"\"\"\nimport os\n\nimport numpy as np\nimport scipy.sparse as sp\n\nfrom .. import backend as F, transforms\nfrom ..convert import graph as dgl_graph\n\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import (\n    _get_dgl_url,\n    deprecate_class,\n    deprecate_property,\n    load_graphs,\n    save_graphs,\n)\n\n__all__ = [\n    \"AmazonCoBuyComputerDataset\",\n    \"AmazonCoBuyPhotoDataset\",\n    \"CoauthorPhysicsDataset\",\n    \"CoauthorCSDataset\",\n    \"CoraFullDataset\",\n    \"AmazonCoBuy\",\n    \"Coauthor\",\n    \"CoraFull\",\n]\n\n\ndef eliminate_self_loops(A):\n    \"\"\"Remove self-loops from the adjacency matrix.\"\"\"\n    A = A.tolil()\n    A.setdiag(0)\n    A = A.tocsr()\n    A.eliminate_zeros()\n    return A\n\n\nclass GNNBenchmarkDataset(DGLBuiltinDataset):\n    r\"\"\"Base Class for GNN Benchmark dataset\n\n    Reference: https://github.com/shchur/gnn-benchmark#datasets\n    \"\"\"\n\n    def __init__(\n        self,\n        name,\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        _url = _get_dgl_url(\"dataset/\" + name + \".zip\")\n        super(GNNBenchmarkDataset, self).__init__(\n            name=name,\n            url=_url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        npz_path = os.path.join(self.raw_path, self.name + \".npz\")\n        g = self._load_npz(npz_path)\n        g = transforms.reorder_graph(\n            g,\n            node_permute_algo=\"rcmk\",\n            edge_permute_algo=\"dst\",\n            store_ids=False,\n        )\n        self._graph = g\n        self._data = [g]\n        self._print_info()\n\n    def has_cache(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph_v1.bin\")\n        if os.path.exists(graph_path):\n   
         return True\n        return False\n\n    def save(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph_v1.bin\")\n        save_graphs(graph_path, self._graph)\n\n    def load(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph_v1.bin\")\n        graphs, _ = load_graphs(graph_path)\n        self._graph = graphs[0]\n        self._data = [graphs[0]]\n        self._print_info()\n\n    def _print_info(self):\n        if self.verbose:\n            print(\"  NumNodes: {}\".format(self._graph.num_nodes()))\n            print(\"  NumEdges: {}\".format(self._graph.num_edges()))\n            print(\"  NumFeats: {}\".format(self._graph.ndata[\"feat\"].shape[-1]))\n            print(\"  NumbClasses: {}\".format(self.num_classes))\n\n    def _load_npz(self, file_name):\n        with np.load(file_name, allow_pickle=True) as loader:\n            loader = dict(loader)\n            num_nodes = loader[\"adj_shape\"][0]\n            adj_matrix = sp.csr_matrix(\n                (\n                    loader[\"adj_data\"],\n                    loader[\"adj_indices\"],\n                    loader[\"adj_indptr\"],\n                ),\n                shape=loader[\"adj_shape\"],\n            ).tocoo()\n\n            if \"attr_data\" in loader:\n                # Attributes are stored as a sparse CSR matrix\n                attr_matrix = sp.csr_matrix(\n                    (\n                        loader[\"attr_data\"],\n                        loader[\"attr_indices\"],\n                        loader[\"attr_indptr\"],\n                    ),\n                    shape=loader[\"attr_shape\"],\n                ).todense()\n            elif \"attr_matrix\" in loader:\n                # Attributes are stored as a (dense) np.ndarray\n                attr_matrix = loader[\"attr_matrix\"]\n            else:\n                attr_matrix = None\n\n            if \"labels_data\" in loader:\n                # Labels are stored as a CSR matrix\n           
     labels = sp.csr_matrix(\n                    (\n                        loader[\"labels_data\"],\n                        loader[\"labels_indices\"],\n                        loader[\"labels_indptr\"],\n                    ),\n                    shape=loader[\"labels_shape\"],\n                ).todense()\n            elif \"labels\" in loader:\n                # Labels are stored as a numpy array\n                labels = loader[\"labels\"]\n            else:\n                labels = None\n        g = dgl_graph((adj_matrix.row, adj_matrix.col))\n        g = transforms.to_bidirected(g)\n        g.ndata[\"feat\"] = F.tensor(attr_matrix, F.data_type_dict[\"float32\"])\n        g.ndata[\"label\"] = F.tensor(labels, F.data_type_dict[\"int64\"])\n        return g\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes.\"\"\"\n        raise NotImplementedError\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph by index\n\n        Parameters\n        ----------\n        idx : int\n            Item index\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains:\n\n            - ``ndata['feat']``: node features\n            - ``ndata['label']``: node labels\n        \"\"\"\n        assert idx == 0, \"This dataset has only one graph\"\n        if self._transform is None:\n            return self._graph\n        else:\n            return self._transform(self._graph)\n\n    def __len__(self):\n        r\"\"\"Number of graphs in the dataset\"\"\"\n        return 1\n\n\nclass CoraFullDataset(GNNBenchmarkDataset):\n    r\"\"\"CORA-Full dataset for node classification task.\n\n    Extended Cora dataset. 
Nodes represent paper and edges represent citations.\n\n    Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_\n\n    Statistics:\n\n    - Nodes: 19,793\n    - Edges: 126,842 (note that the original dataset has 65,311 edges but DGL adds\n      the reverse edges and remove the duplicates, hence with a different number)\n    - Number of Classes: 70\n    - Node feature size: 8,710\n\n    Parameters\n    ----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes for each node.\n\n    Examples\n    --------\n    >>> data = CoraFullDataset()\n    >>> g = data[0]\n    >>> num_class = data.num_classes\n    >>> feat = g.ndata['feat']  # get node feature\n    >>> label = g.ndata['label']  # get node labels\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=False, transform=None\n    ):\n        super(CoraFullDataset, self).__init__(\n            name=\"cora_full\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return 70\n\n\nclass CoauthorCSDataset(GNNBenchmarkDataset):\n    r\"\"\"'Computer Science (CS)' part of the Coauthor dataset for node classification task.\n\n    Coauthor CS and Coauthor Physics are co-authorship graphs based 
on the Microsoft Academic Graph\n    from the KDD Cup 2016 challenge. Here, nodes are authors, that are connected by an edge if they\n    co-authored a paper; node features represent paper keywords for each author’s papers, and class\n    labels indicate most active fields of study for each author.\n\n    Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_\n\n    Statistics:\n\n    - Nodes: 18,333\n    - Edges: 163,788 (note that the original dataset has 81,894 edges but DGL adds\n      the reverse edges and remove the duplicates, hence with a different number)\n    - Number of classes: 15\n    - Node feature size: 6,805\n\n    Parameters\n    ----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes for each node.\n\n    Examples\n    --------\n    >>> data = CoauthorCSDataset()\n    >>> g = data[0]\n    >>> num_class = data.num_classes\n    >>> feat = g.ndata['feat']  # get node feature\n    >>> label = g.ndata['label']  # get node labels\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=False, transform=None\n    ):\n        super(CoauthorCSDataset, self).__init__(\n            name=\"coauthor_cs\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return 15\n\n\nclass CoauthorPhysicsDataset(GNNBenchmarkDataset):\n    r\"\"\"'Physics' part of the Coauthor dataset for node classification task.\n\n    Coauthor CS and Coauthor Physics are co-authorship graphs based on the Microsoft Academic Graph\n    from the KDD Cup 2016 challenge. Here, nodes are authors, that are connected by an edge if they\n    co-authored a paper; node features represent paper keywords for each author’s papers, and class\n    labels indicate most active fields of study for each author.\n\n    Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_\n\n    Statistics\n\n    - Nodes: 34,493\n    - Edges: 495,924 (note that the original dataset has 247,962 edges but DGL adds\n      the reverse edges and remove the duplicates, hence with a different number)\n    - Number of classes: 5\n    - Node feature size: 8,415\n\n    Parameters\n    ----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. 
Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes for each node.\n\n    Examples\n    --------\n    >>> data = CoauthorPhysicsDataset()\n    >>> g = data[0]\n    >>> num_class = data.num_classes\n    >>> feat = g.ndata['feat']  # get node feature\n    >>> label = g.ndata['label']  # get node labels\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=False, transform=None\n    ):\n        super(CoauthorPhysicsDataset, self).__init__(\n            name=\"coauthor_physics\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return 5\n\n\nclass AmazonCoBuyComputerDataset(GNNBenchmarkDataset):\n    r\"\"\"'Computer' part of the AmazonCoBuy dataset for node classification task.\n\n    Amazon Computers and Amazon Photo are segments of the Amazon co-purchase graph [McAuley et al., 2015],\n    where nodes represent goods, edges indicate that two goods are frequently bought together, node\n    features are bag-of-words encoded product reviews, and class labels are given by the product category.\n\n    Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_\n\n    Statistics:\n\n    - Nodes: 13,752\n    - Edges: 491,722 (note that the original dataset has 245,778 edges but DGL adds\n      the reverse edges and remove the duplicates, hence with a different number)\n    - Number of classes: 10\n    - Node feature size: 767\n\n    
Parameters\n    ----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes for each node.\n\n    Examples\n    --------\n    >>> data = AmazonCoBuyComputerDataset()\n    >>> g = data[0]\n    >>> num_class = data.num_classes\n    >>> feat = g.ndata['feat']  # get node feature\n    >>> label = g.ndata['label']  # get node labels\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=False, transform=None\n    ):\n        super(AmazonCoBuyComputerDataset, self).__init__(\n            name=\"amazon_co_buy_computer\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return 10\n\n\nclass AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):\n    r\"\"\"AmazonCoBuy dataset for node classification task.\n\n    Amazon Computers and Amazon Photo are segments of the Amazon co-purchase graph [McAuley et al., 2015],\n    where nodes represent goods, edges indicate that two goods are frequently bought together, node\n    features are bag-of-words encoded product reviews, and class labels are given by the product category.\n\n    Reference: `<https://github.com/shchur/gnn-benchmark#datasets>`_\n\n    Statistics\n\n    - Nodes: 7,650\n    - Edges: 238,163 (note that 
the original dataset has 119,043 edges but DGL adds\n      the reverse edges and remove the duplicates, hence with a different number)\n    - Number of classes: 8\n    - Node feature size: 745\n\n    Parameters\n    ----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes for each node.\n\n    Examples\n    --------\n    >>> data = AmazonCoBuyPhotoDataset()\n    >>> g = data[0]\n    >>> num_class = data.num_classes\n    >>> feat = g.ndata['feat']  # get node feature\n    >>> label = g.ndata['label']  # get node labels\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=False, transform=None\n    ):\n        super(AmazonCoBuyPhotoDataset, self).__init__(\n            name=\"amazon_co_buy_photo\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return 8\n\n\nclass CoraFull(CoraFullDataset):\n    def __init__(self, **kwargs):\n        deprecate_class(\"CoraFull\", \"CoraFullDataset\")\n        super(CoraFull, self).__init__(**kwargs)\n\n\ndef AmazonCoBuy(name):\n    if name == \"computers\":\n        deprecate_class(\"AmazonCoBuy\", \"AmazonCoBuyComputerDataset\")\n        return AmazonCoBuyComputerDataset()\n    elif name == \"photo\":\n    
    deprecate_class(\"AmazonCoBuy\", \"AmazonCoBuyPhotoDataset\")\n        return AmazonCoBuyPhotoDataset()\n    else:\n        raise ValueError('Dataset name should be \"computers\" or \"photo\".')\n\n\ndef Coauthor(name):\n    if name == \"cs\":\n        deprecate_class(\"Coauthor\", \"CoauthorCSDataset\")\n        return CoauthorCSDataset()\n    elif name == \"physics\":\n        deprecate_class(\"Coauthor\", \"CoauthorPhysicsDataset\")\n        return CoauthorPhysicsDataset()\n    else:\n        raise ValueError('Dataset name should be \"cs\" or \"physics\".')\n"
  },
  {
    "path": "python/dgl/data/graph_serialize.py",
    "content": "\"\"\"For Graph Serialization\"\"\"\nfrom __future__ import absolute_import\n\nimport os\n\nfrom .. import backend as F\nfrom .._ffi.function import _init_api\nfrom .._ffi.object import ObjectBase, register_object\nfrom ..base import dgl_warning, DGLError\nfrom ..heterograph import DGLGraph\nfrom .heterograph_serialize import save_heterographs\n\n_init_api(\"dgl.data.graph_serialize\")\n\n__all__ = [\"save_graphs\", \"load_graphs\", \"load_labels\"]\n\n\n@register_object(\"graph_serialize.StorageMetaData\")\nclass StorageMetaData(ObjectBase):\n    \"\"\"StorageMetaData Object\n    attributes available:\n      num_graph [int]: return numbers of graphs\n      nodes_num_list Value of NDArray: return number of nodes for each graph\n      edges_num_list Value of NDArray: return number of edges for each graph\n      labels [dict of backend tensors]: return dict of labels\n      graph_data [list of GraphData]: return list of GraphData Object\n    \"\"\"\n\n\ndef is_local_path(filepath):\n    return not (\n        filepath.startswith(\"hdfs://\")\n        or filepath.startswith(\"viewfs://\")\n        or filepath.startswith(\"s3://\")\n    )\n\n\ndef check_local_file_exists(filename):\n    if is_local_path(filename) and not os.path.exists(filename):\n        raise DGLError(\"File {} does not exist.\".format(filename))\n\n\n@register_object(\"graph_serialize.GraphData\")\nclass GraphData(ObjectBase):\n    \"\"\"GraphData Object\"\"\"\n\n    @staticmethod\n    def create(g):\n        \"\"\"Create GraphData\"\"\"\n        # TODO(zihao): support serialize batched graph in the future.\n        assert (\n            g.batch_size == 1\n        ), \"Batched DGLGraph is not supported for serialization\"\n        ghandle = g._graph\n        if len(g.ndata) != 0:\n            node_tensors = dict()\n            for key, value in g.ndata.items():\n                node_tensors[key] = F.zerocopy_to_dgl_ndarray(value)\n        else:\n            node_tensors = None\n       
 if len(g.edata) != 0:\n            edge_tensors = dict()\n            for key, value in g.edata.items():\n                edge_tensors[key] = F.zerocopy_to_dgl_ndarray(value)\n        else:\n            edge_tensors = None\n        return _CAPI_MakeGraphData(ghandle, node_tensors, edge_tensors)\n\n    def get_graph(self):\n        \"\"\"Get DGLGraph from GraphData\"\"\"\n        ghandle = _CAPI_GDataGraphHandle(self)\n        hgi = _CAPI_DGLAsHeteroGraph(ghandle)\n        g = DGLGraph(hgi, [\"_U\"], [\"_E\"])\n        node_tensors_items = _CAPI_GDataNodeTensors(self).items()\n        edge_tensors_items = _CAPI_GDataEdgeTensors(self).items()\n        for k, v in node_tensors_items:\n            g.ndata[k] = F.zerocopy_from_dgl_ndarray(v)\n        for k, v in edge_tensors_items:\n            g.edata[k] = F.zerocopy_from_dgl_ndarray(v)\n        return g\n\n\ndef save_graphs(filename, g_list, labels=None, formats=None):\n    r\"\"\"Save graphs and optionally their labels to file.\n\n    Besides saving to local files, DGL supports writing the graphs directly\n    to S3 (by providing a ``\"s3://...\"`` path) or to HDFS (by providing\n    ``\"hdfs://...\"`` a path).\n\n    The function saves both the graph structure and node/edge features to file\n    in DGL's own binary format. For graph-level features, pass them via\n    the :attr:`labels` argument.\n\n    Parameters\n    ----------\n    filename : str\n        The file name to store the graphs and labels.\n    g_list: list\n        The graphs to be saved.\n    labels: dict[str, Tensor]\n        labels should be dict of tensors, with str as keys\n    formats: str or list[str]\n        Save graph in specified formats. It could be any combination of\n        ``coo``, ``csc`` and ``csr``. If not specified, save one format\n        only according to what format is available. 
If multiple formats\n        are available, selection priority from high to low is ``coo``,\n        ``csc``, ``csr``.\n\n    Examples\n    ----------\n    >>> import dgl\n    >>> import torch as th\n\n    Create :class:`DGLGraph` objects and initialize node\n    and edge features.\n\n    >>> g1 = dgl.graph(([0, 1, 2], [1, 2, 3]))\n    >>> g2 = dgl.graph(([0, 2], [2, 3]))\n    >>> g2.edata[\"e\"] = th.ones(2, 4)\n\n    Save Graphs into file\n\n    >>> from dgl.data.utils import save_graphs\n    >>> graph_labels = {\"glabel\": th.tensor([0, 1])}\n    >>> save_graphs(\"./data.bin\", [g1, g2], graph_labels)\n\n    See Also\n    --------\n    load_graphs\n    \"\"\"\n    # if it is local file, do some sanity check\n    if is_local_path(filename):\n        if os.path.isdir(filename):\n            raise DGLError(\n                \"Filename {} is an existing directory.\".format(filename)\n            )\n        f_path = os.path.dirname(filename)\n        if f_path and not os.path.exists(f_path):\n            os.makedirs(f_path)\n    g_sample = g_list[0] if isinstance(g_list, list) else g_list\n    if type(g_sample) == DGLGraph:  # Doesn't support DGLGraph's derived class\n        save_heterographs(filename, g_list, labels, formats)\n    else:\n        raise DGLError(\n            \"Invalid argument g_list. 
Must be a DGLGraph or a list of DGLGraphs.\"\n        )\n\n\ndef load_graphs(filename, idx_list=None):\n    \"\"\"Load graphs and optionally their labels from file saved by :func:`save_graphs`.\n\n    Besides loading from local files, DGL supports loading the graphs directly\n    from S3 (by providing a ``\"s3://...\"`` path) or from HDFS (by providing\n    a ``\"hdfs://...\"`` path).\n\n    Parameters\n    ----------\n    filename: str\n        The file name to load graphs from.\n    idx_list: list[int], optional\n        The indices of the graphs to be loaded if the file contains multiple graphs.\n        Default is loading all the graphs stored in the file.\n\n    Returns\n    --------\n    graph_list: list[DGLGraph]\n        The loaded graphs.\n    labels: dict[str, Tensor]\n        The graph labels stored in file. If no label is stored, the dictionary is empty.\n        Regardless of whether the ``idx_list`` argument is given or not,\n        the returned dictionary always contains the labels of all the graphs.\n\n    Examples\n    ----------\n    Following the example in :func:`save_graphs`.\n\n    >>> from dgl.data.utils import load_graphs\n    >>> glist, label_dict = load_graphs(\"./data.bin\") # glist will be [g1, g2]\n    >>> glist, label_dict = load_graphs(\"./data.bin\", [0]) # glist will be [g1]\n\n    See Also\n    --------\n    save_graphs\n    \"\"\"\n    # if it is local file, do some sanity check\n    check_local_file_exists(filename)\n    version = _CAPI_GetFileVersion(filename)\n    if version == 1:\n        dgl_warning(\n            \"You are loading a graph file saved by old version of dgl.  \\\r\n            Please consider saving it again with the current format.\"\n        )\n        return load_graph_v1(filename, idx_list)\n    elif version == 2:\n        return load_graph_v2(filename, idx_list)\n    else:\n        raise DGLError(\"Invalid DGL Version Number.\")\n\n\ndef load_graph_v2(filename, idx_list=None):\n    \"\"\"Internal functions for loading DGLGraphs.\"\"\"\n    if idx_list is None:\n        idx_list = []\n    assert isinstance(idx_list, list)\n    heterograph_list = _CAPI_LoadGraphFiles_V2(filename, idx_list)\n    label_dict = load_labels_v2(filename)\n    return [gdata.get_graph() for gdata in heterograph_list], label_dict\n\n\ndef load_graph_v1(filename, idx_list=None):\n    \"\"\"Internal functions for loading DGLGraphs (V0).\"\"\"\n    if idx_list is None:\n        idx_list = []\n    assert isinstance(idx_list, list)\n    metadata = _CAPI_LoadGraphFiles_V1(filename, idx_list, False)\n    label_dict = {}\n    for k, v in metadata.labels.items():\n        label_dict[k] = F.zerocopy_from_dgl_ndarray(v)\n    return [gdata.get_graph() for gdata in metadata.graph_data], label_dict\n\n\ndef load_labels(filename):\n    \"\"\"\n    Load label dict from file\n\n    Parameters\n    ----------\n    filename: str\n        filename to load DGLGraphs\n\n    Returns\n    ----------\n    labels: dict\n        dict of labels stored in file (empty dict returned if no\n        label stored)\n\n    Examples\n    ----------\n    Following the example in save_graphs.\n\n    >>> from dgl.data.utils import load_labels\n    >>> label_dict = load_labels(\"./data.bin\")\n\n    \"\"\"\n    # if it is local file, do some sanity check\n    check_local_file_exists(filename)\n\n    version = _CAPI_GetFileVersion(filename)\n    if version == 1:\n        return load_labels_v1(filename)\n    elif version == 2:\n        return load_labels_v2(filename)\n    else:\n        raise Exception(\"Invalid DGL Version Number\")\n\n\ndef load_labels_v2(filename):\n    \"\"\"Internal functions for loading labels from V2 format\"\"\"\n    label_dict = {}\n    nd_dict = _CAPI_LoadLabels_V2(filename)\n    for k, v in nd_dict.items():\n        label_dict[k] = F.zerocopy_from_dgl_ndarray(v)\n    return label_dict\n\n\ndef load_labels_v1(filename):\n    \"\"\"Internal functions for loading labels from V1 format\"\"\"\n    metadata = _CAPI_LoadGraphFiles_V1(filename, [], True)\n    label_dict = {}\n    for k, v in metadata.labels.items():\n        label_dict[k] = F.zerocopy_from_dgl_ndarray(v)\n    return label_dict\n"
  },
  {
    "path": "python/dgl/data/heterograph_serialize.py",
    "content": "\"\"\"For HeteroGraph Serialization\"\"\"\nfrom __future__ import absolute_import\n\nfrom .. import backend as F\nfrom .._ffi.function import _init_api\nfrom .._ffi.object import ObjectBase, register_object\nfrom ..container import convert_to_strmap\nfrom ..frame import Frame\nfrom ..heterograph import DGLGraph\n\n_init_api(\"dgl.data.heterograph_serialize\")\n\n\ndef tensor_dict_to_ndarray_dict(tensor_dict):\n    \"\"\"Convert dict[str, tensor] to StrMap[NDArray]\"\"\"\n    ndarray_dict = {}\n    for key, value in tensor_dict.items():\n        ndarray_dict[key] = F.zerocopy_to_dgl_ndarray(value)\n    return convert_to_strmap(ndarray_dict)\n\n\ndef save_heterographs(filename, g_list, labels, formats):\n    \"\"\"Save heterographs into file\"\"\"\n    if labels is None:\n        labels = {}\n    if isinstance(g_list, DGLGraph):\n        g_list = [g_list]\n    assert all(\n        [type(g) == DGLGraph for g in g_list]\n    ), \"Invalid DGLGraph in g_list argument\"\n    gdata_list = [HeteroGraphData.create(g) for g in g_list]\n    if formats is None:\n        formats = []\n    elif isinstance(formats, str):\n        formats = [formats]\n    _CAPI_SaveHeteroGraphData(\n        filename, gdata_list, tensor_dict_to_ndarray_dict(labels), formats\n    )\n\n\n@register_object(\"heterograph_serialize.HeteroGraphData\")\nclass HeteroGraphData(ObjectBase):\n    \"\"\"Object to hold the data to be stored for DGLGraph\"\"\"\n\n    @staticmethod\n    def create(g):\n        edata_list = []\n        ndata_list = []\n        for etype in g.canonical_etypes:\n            edata_list.append(tensor_dict_to_ndarray_dict(g.edges[etype].data))\n        for ntype in g.ntypes:\n            ndata_list.append(tensor_dict_to_ndarray_dict(g.nodes[ntype].data))\n        return _CAPI_MakeHeteroGraphData(\n            g._graph, ndata_list, edata_list, g.ntypes, g.etypes\n        )\n\n    def get_graph(self):\n        ntensor_list = list(_CAPI_GetNDataFromHeteroGraphData(self))\n   
     etensor_list = list(_CAPI_GetEDataFromHeteroGraphData(self))\n        ntype_names = list(_CAPI_GetNtypesFromHeteroGraphData(self))\n        etype_names = list(_CAPI_GetEtypesFromHeteroGraphData(self))\n        gidx = _CAPI_GetGindexFromHeteroGraphData(self)\n        nframes = []\n        eframes = []\n        for ntid, ntensor in enumerate(ntensor_list):\n            ndict = {\n                ntensor[i]: F.zerocopy_from_dgl_ndarray(ntensor[i + 1])\n                for i in range(0, len(ntensor), 2)\n            }\n            nframes.append(Frame(ndict, num_rows=gidx.num_nodes(ntid)))\n\n        for etid, etensor in enumerate(etensor_list):\n            edict = {\n                etensor[i]: F.zerocopy_from_dgl_ndarray(etensor[i + 1])\n                for i in range(0, len(etensor), 2)\n            }\n            eframes.append(Frame(edict, num_rows=gidx.num_edges(etid)))\n\n        return DGLGraph(gidx, ntype_names, etype_names, nframes, eframes)\n"
  },
  {
    "path": "python/dgl/data/heterophilous_graphs.py",
    "content": "\"\"\"\nDatasets introduced in the 'A Critical Look at the Evaluation of GNNs under Heterophily: Are We\nReally Making Progress? <https://arxiv.org/abs/2302.11640>'__ paper.\n\"\"\"\nimport os\n\nimport numpy as np\n\nfrom ..convert import graph\nfrom ..transforms.functional import to_bidirected\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import download\n\n\nclass HeterophilousGraphDataset(DGLBuiltinDataset):\n    r\"\"\"Datasets introduced in the 'A Critical Look at the Evaluation of GNNs under Heterophily:\n    Are We Really Making Progress? <https://arxiv.org/abs/2302.11640>'__ paper.\n\n    Parameters\n    ----------\n    name : str\n        Name of the dataset. One of 'roman-empire', 'amazon-ratings', 'minesweeper', 'tolokers',\n        'questions'.\n    raw_dir : str\n        Raw file directory to store the processed data.\n    force_reload : bool\n        Whether to re-download the data source.\n    verbose : bool\n        Whether to print progress information.\n    transform : callable\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    \"\"\"\n\n    def __init__(\n        self,\n        name,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        name = name.lower().replace(\"-\", \"_\")\n        url = f\"https://github.com/yandex-research/heterophilous-graphs/raw/main/data/{name}.npz\"\n        super(HeterophilousGraphDataset, self).__init__(\n            name=name,\n            url=url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def download(self):\n        download(\n            url=self.url, path=os.path.join(self.raw_path, f\"{self.name}.npz\")\n        )\n\n    def process(self):\n        \"\"\"Load and process the data.\"\"\"\n        try:\n            import torch\n        except ImportError:\n            raise ModuleNotFoundError(\n                \"This dataset requires PyTorch to be the backend.\"\n            )\n\n        data = np.load(os.path.join(self.raw_path, f\"{self.name}.npz\"))\n        src = torch.from_numpy(data[\"edges\"][:, 0])\n        dst = torch.from_numpy(data[\"edges\"][:, 1])\n        features = torch.from_numpy(data[\"node_features\"])\n        labels = torch.from_numpy(data[\"node_labels\"])\n        train_masks = torch.from_numpy(data[\"train_masks\"].T)\n        val_masks = torch.from_numpy(data[\"val_masks\"].T)\n        test_masks = torch.from_numpy(data[\"test_masks\"].T)\n        num_nodes = len(labels)\n        num_classes = len(labels.unique())\n\n        self._num_classes = num_classes\n\n        self._g = to_bidirected(graph((src, dst), num_nodes=num_nodes))\n        self._g.ndata[\"feat\"] = features\n        self._g.ndata[\"label\"] = labels\n        self._g.ndata[\"train_mask\"] = train_masks\n        self._g.ndata[\"val_mask\"] = val_masks\n        self._g.ndata[\"test_mask\"] = test_masks\n\n   
 def has_cache(self):\n        return os.path.exists(self.raw_path)\n\n    def load(self):\n        self.process()\n\n    def __getitem__(self, idx):\n        assert idx == 0, \"This dataset has only one graph.\"\n        if self._transform is None:\n            return self._g\n        else:\n            return self._transform(self._g)\n\n    def __len__(self):\n        return 1\n\n    @property\n    def num_classes(self):\n        return self._num_classes\n\n\nclass RomanEmpireDataset(HeterophilousGraphDataset):\n    r\"\"\"Roman-empire dataset from the 'A Critical Look at the Evaluation of GNNs under Heterophily:\n    Are We Really Making Progress? <https://arxiv.org/abs/2302.11640>'__ paper.\n\n    This dataset is based on the Roman Empire article from English Wikipedia, which was selected\n    since it is one of the longest articles on Wikipedia. Each node in the graph corresponds to one\n    (non-unique) word in the text. Thus, the number of nodes in the graph is equal to the article’s\n    length. Two words are connected with an edge if at least one of the following two conditions\n    holds: either these words follow each other in the text, or these words are connected in the\n    dependency tree of the sentence (one word depends on the other). Thus, the graph is a chain\n    graph with additional shortcut edges corresponding to syntactic dependencies between words. The\n    class of a node is its syntactic role (17 most frequent roles were selected as unique classes\n    and all the other roles were grouped into the 18th class). Node features are word embeddings.\n\n    Statistics:\n\n    - Nodes: 22662\n    - Edges: 65854\n    - Classes: 18\n    - Node features: 300\n    - 10 train/val/test splits\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download the data source. 
Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n\n    Examples\n    --------\n\n    >>> from dgl.data import RomanEmpireDataset\n    >>> dataset = RomanEmpireDataset()\n    >>> g = dataset[0]\n    >>> num_classes = dataset.num_classes\n\n    >>> # get node features\n    >>> feat = g.ndata[\"feat\"]\n\n    >>> # get the first data split\n    >>> train_mask = g.ndata[\"train_mask\"][:, 0]\n    >>> val_mask = g.ndata[\"val_mask\"][:, 0]\n    >>> test_mask = g.ndata[\"test_mask\"][:, 0]\n\n    >>> # get labels\n    >>> label = g.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(RomanEmpireDataset, self).__init__(\n            name=\"roman-empire\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n\nclass AmazonRatingsDataset(HeterophilousGraphDataset):\n    r\"\"\"Amazon-ratings dataset from the 'A Critical Look at the Evaluation of GNNs under\n    Heterophily: Are We Really Making Progress? <https://arxiv.org/abs/2302.11640>'__ paper.\n\n    This dataset is based on the Amazon product co-purchasing data. Nodes are products (books, music\n    CDs, DVDs, VHS video tapes), and edges connect products that are frequently bought together. The\n    task is to predict the average rating given to a product by reviewers. All possible rating\n    values were grouped into five classes. 
Node features are the mean of word embeddings for words\n    in the product description.\n\n    Statistics:\n\n    - Nodes: 24492\n    - Edges: 186100\n    - Classes: 5\n    - Node features: 300\n    - 10 train/val/test splits\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download the data source. Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n\n    Examples\n    --------\n\n    >>> from dgl.data import AmazonRatingsDataset\n    >>> dataset = AmazonRatingsDataset()\n    >>> g = dataset[0]\n    >>> num_classes = dataset.num_classes\n\n    >>> # get node features\n    >>> feat = g.ndata[\"feat\"]\n\n    >>> # get the first data split\n    >>> train_mask = g.ndata[\"train_mask\"][:, 0]\n    >>> val_mask = g.ndata[\"val_mask\"][:, 0]\n    >>> test_mask = g.ndata[\"test_mask\"][:, 0]\n\n    >>> # get labels\n    >>> label = g.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(AmazonRatingsDataset, self).__init__(\n            name=\"amazon-ratings\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n\nclass MinesweeperDataset(HeterophilousGraphDataset):\n    r\"\"\"Minesweeper dataset from the 'A Critical Look at the Evaluation of GNNs under Heterophily:\n    Are We Really Making Progress? 
<https://arxiv.org/abs/2302.11640>'__ paper.\n\n    This dataset is inspired by the Minesweeper game. The graph is a regular 100x100 grid where each\n    node (cell) is connected to eight neighboring nodes (with the exception of nodes at the edge of\n    the grid, which have fewer neighbors). 20% of the nodes are randomly selected as mines. The task\n    is to predict which nodes are mines. The node features are one-hot-encoded numbers of\n    neighboring mines. However, for randomly selected 50% of the nodes, the features are unknown,\n    which is indicated by a separate binary feature.\n\n    Statistics:\n\n    - Nodes: 10000\n    - Edges: 78804\n    - Classes: 2\n    - Node features: 7\n    - 10 train/val/test splits\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download the data source. Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. 
Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n\n    Examples\n    --------\n\n    >>> from dgl.data import MinesweeperDataset\n    >>> dataset = MinesweeperDataset()\n    >>> g = dataset[0]\n    >>> num_classes = dataset.num_classes\n\n    >>> # get node features\n    >>> feat = g.ndata[\"feat\"]\n\n    >>> # get the first data split\n    >>> train_mask = g.ndata[\"train_mask\"][:, 0]\n    >>> val_mask = g.ndata[\"val_mask\"][:, 0]\n    >>> test_mask = g.ndata[\"test_mask\"][:, 0]\n\n    >>> # get labels\n    >>> label = g.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(MinesweeperDataset, self).__init__(\n            name=\"minesweeper\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n\nclass TolokersDataset(HeterophilousGraphDataset):\n    r\"\"\"Tolokers dataset from the 'A Critical Look at the Evaluation of GNNs under Heterophily:\n    Are We Really Making Progress? <https://arxiv.org/abs/2302.11640>'__ paper.\n\n    This dataset is based on data from the Toloka crowdsourcing platform. The nodes represent\n    tolokers (workers). An edge connects two tolokers if they have worked on the same task. The goal\n    is to predict which tolokers have been banned in one of the projects. Node features are based on\n    the worker’s profile information and task performance statistics.\n\n    Statistics:\n\n    - Nodes: 11758\n    - Edges: 1038000\n    - Classes: 2\n    - Node features: 10\n    - 10 train/val/test splits\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download the data source. 
Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n\n    Examples\n    --------\n\n    >>> from dgl.data import TolokersDataset\n    >>> dataset = TolokersDataset()\n    >>> g = dataset[0]\n    >>> num_classes = dataset.num_classes\n\n    >>> # get node features\n    >>> feat = g.ndata[\"feat\"]\n\n    >>> # get the first data split\n    >>> train_mask = g.ndata[\"train_mask\"][:, 0]\n    >>> val_mask = g.ndata[\"val_mask\"][:, 0]\n    >>> test_mask = g.ndata[\"test_mask\"][:, 0]\n\n    >>> # get labels\n    >>> label = g.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(TolokersDataset, self).__init__(\n            name=\"tolokers\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n\nclass QuestionsDataset(HeterophilousGraphDataset):\n    r\"\"\"Questions dataset from the 'A Critical Look at the Evaluation of GNNs under Heterophily:\n    Are We Really Making Progress? <https://arxiv.org/abs/2302.11640>'__ paper.\n\n    This dataset is based on data from the question-answering website Yandex Q. Nodes are users, and\n    an edge connects two nodes if one user answered the other user’s question. The task is to\n    predict which users remained active on the website (were not deleted or blocked). Node features\n    are the mean of word embeddings for words in the user description. 
Users that do not have\n    description are indicated by a separate binary feature.\n\n    Statistics:\n\n    - Nodes: 48921\n    - Edges: 307080\n    - Classes: 2\n    - Node features: 301\n    - 10 train/val/test splits\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download the data source. Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n\n    Examples\n    --------\n\n    >>> from dgl.data import QuestionsDataset\n    >>> dataset = QuestionsDataset()\n    >>> g = dataset[0]\n    >>> num_classes = dataset.num_classes\n\n    >>> # get node features\n    >>> feat = g.ndata[\"feat\"]\n\n    >>> # get the first data split\n    >>> train_mask = g.ndata[\"train_mask\"][:, 0]\n    >>> val_mask = g.ndata[\"val_mask\"][:, 0]\n    >>> test_mask = g.ndata[\"test_mask\"][:, 0]\n\n    >>> # get labels\n    >>> label = g.ndata['label']\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(QuestionsDataset, self).__init__(\n            name=\"questions\",\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n"
  },
  {
    "path": "python/dgl/data/icews18.py",
    "content": "\"\"\"ICEWS18 dataset for temporal graph\"\"\"\nimport os\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..convert import graph as dgl_graph\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, load_graphs, loadtxt, save_graphs\n\n\nclass ICEWS18Dataset(DGLBuiltinDataset):\n    r\"\"\"ICEWS18 dataset for temporal graph\n\n    Integrated Crisis Early Warning System (ICEWS18)\n\n    Event data consists of coded interactions between socio-political\n    actors (i.e., cooperative or hostile actions between individuals,\n    groups, sectors and nation states). This Dataset consists of events\n    from 1/1/2018 to 10/31/2018 (24 hours time granularity).\n\n    Reference:\n\n        - `Recurrent Event Network for Reasoning over Temporal Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_\n        - `ICEWS Coded Event Data <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_\n\n    Statistics：\n\n    - Train examples: 240\n    - Valid examples: 30\n    - Test examples: 34\n    - Nodes per graph: 23033\n\n    Parameters\n    ----------\n    mode: str\n        Load train/valid/test data. Has to be one of ['train', 'valid', 'test']\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    is_temporal : bool\n        Whether the dataset contains temporal graphs\n\n    Examples\n    --------\n    >>> # get train, valid, test set\n    >>> train_data = ICEWS18Dataset()\n    >>> valid_data = ICEWS18Dataset(mode='valid')\n    >>> test_data = ICEWS18Dataset(mode='test')\n    >>>\n    >>> train_size = len(train_data)\n    >>> for g in train_data:\n    ....    e_feat = g.edata['rel_type']\n    ....    # your code here\n    ....\n    >>>\n    \"\"\"\n\n    def __init__(\n        self,\n        mode=\"train\",\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        mode = mode.lower()\n        assert mode in [\"train\", \"valid\", \"test\"], \"Mode not valid\"\n        self.mode = mode\n        _url = _get_dgl_url(\"dataset/icews18.zip\")\n        super(ICEWS18Dataset, self).__init__(\n            name=\"ICEWS18\",\n            url=_url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        data = loadtxt(\n            os.path.join(self.save_path, \"{}.txt\".format(self.mode)),\n            delimiter=\"\\t\",\n        ).astype(np.int64)\n        num_nodes = 23033\n        # The source code is not released, but the paper indicates there are\n        # 137 samples in total. The cutoff below has exactly 137 samples.\n        time_index = np.floor(data[:, 3] / 24).astype(np.int64)\n        start_time = time_index[time_index != -1].min()\n        end_time = time_index.max()\n        self._graphs = []\n        for i in range(start_time, end_time + 1):\n            row_mask = time_index <= i\n            edges = data[row_mask][:, [0, 2]]\n            rate = data[row_mask][:, 1]\n            g = dgl_graph((edges[:, 0], edges[:, 1]))\n            g.edata[\"rel_type\"] = F.tensor(\n                rate.reshape(-1, 1), dtype=F.data_type_dict[\"int64\"]\n            )\n            self._graphs.append(g)\n\n    def has_cache(self):\n        graph_path = os.path.join(\n            self.save_path, \"{}_dgl_graph.bin\".format(self.mode)\n        )\n        return os.path.exists(graph_path)\n\n    def save(self):\n        graph_path = os.path.join(\n            self.save_path, \"{}_dgl_graph.bin\".format(self.mode)\n        )\n        save_graphs(graph_path, self._graphs)\n\n    def load(self):\n        graph_path = os.path.join(\n            self.save_path, \"{}_dgl_graph.bin\".format(self.mode)\n        )\n        self._graphs = load_graphs(graph_path)[0]\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph by index\n\n        Parameters\n        ----------\n        idx : int\n            Item index\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains:\n\n            - ``edata['rel_type']``: edge type\n        \"\"\"\n        if self._transform is None:\n            return self._graphs[idx]\n        else:\n            return self._transform(self._graphs[idx])\n\n    def __len__(self):\n        r\"\"\"Number of graphs in the dataset.\n\n        Returns\n        -------\n        int\n        \"\"\"\n        return len(self._graphs)\n\n    @property\n    def is_temporal(self):\n        r\"\"\"Whether the dataset contains temporal graphs\n\n        Returns\n        -------\n        bool\n        \"\"\"\n        return True\n\n\nICEWS18 = ICEWS18Dataset\n"
  },
  {
    "path": "python/dgl/data/karate.py",
    "content": "\"\"\"KarateClub Dataset\n\"\"\"\nimport networkx as nx\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..convert import from_networkx\nfrom .dgl_dataset import DGLDataset\nfrom .utils import deprecate_property\n\n__all__ = [\"KarateClubDataset\", \"KarateClub\"]\n\n\nclass KarateClubDataset(DGLDataset):\n    r\"\"\"Karate Club dataset for Node Classification\n\n    Zachary's karate club is a social network of a university\n    karate club, described in the paper \"An Information Flow\n    Model for Conflict and Fission in Small Groups\" by Wayne W. Zachary.\n    The network became a popular example of community structure in\n    networks after its use by Michelle Girvan and Mark Newman in 2002.\n    Official website: `<http://konect.cc/networks/ucidata-zachary/>`_\n\n    Karate Club dataset statistics:\n\n    - Nodes: 34\n    - Edges: 156\n    - Number of Classes: 2\n\n    Parameters\n    ----------\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Examples\n    --------\n    >>> dataset = KarateClubDataset()\n    >>> num_classes = dataset.num_classes\n    >>> g = dataset[0]\n    >>> labels = g.ndata['label']\n    \"\"\"\n\n    def __init__(self, transform=None):\n        super(KarateClubDataset, self).__init__(\n            name=\"karate_club\", transform=transform\n        )\n\n    def process(self):\n        kc_graph = nx.karate_club_graph()\n        label = np.asarray(\n            [kc_graph.nodes[i][\"club\"] != \"Mr. 
Hi\" for i in kc_graph.nodes]\n        ).astype(np.int64)\n        label = F.tensor(label)\n        g = from_networkx(kc_graph)\n        g.ndata[\"label\"] = label\n        self._graph = g\n        self._data = [g]\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes.\"\"\"\n        return 2\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph object\n\n        Parameters\n        ----------\n        idx : int\n            Item index, KarateClubDataset has only one graph object\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n\n            graph structure and labels.\n\n            - ``ndata['label']``: ground truth labels\n        \"\"\"\n        assert idx == 0, \"This dataset has only one graph\"\n        if self._transform is None:\n            return self._graph\n        else:\n            return self._transform(self._graph)\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\"\"\"\n        return 1\n\n\nKarateClub = KarateClubDataset\n"
  },
  {
    "path": "python/dgl/data/knowledge_graph.py",
    "content": "from __future__ import absolute_import\n\nimport os, sys\nimport pickle as pkl\n\nimport networkx as nx\n\nimport numpy as np\nimport scipy.sparse as sp\n\nfrom .. import backend as F\nfrom ..convert import graph as dgl_graph\nfrom ..utils import retry_method_with_fix\n\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import (\n    _get_dgl_url,\n    deprecate_function,\n    deprecate_property,\n    download,\n    extract_archive,\n    generate_mask_tensor,\n    get_download_dir,\n    load_graphs,\n    load_info,\n    makedirs,\n    save_graphs,\n    save_info,\n)\n\n\nclass KnowledgeGraphDataset(DGLBuiltinDataset):\n    \"\"\"KnowledgeGraph link prediction dataset\n\n    The dataset contains a graph depicting the connectivity of a knowledge\n    base. Currently, the knowledge bases from the\n    `RGCN paper <https://arxiv.org/pdf/1703.06103.pdf>`_ supported are\n    FB15k-237, FB15k, wn18\n\n    Parameters\n    -----------\n    name : str\n        Name can be 'FB15k-237', 'FB15k' or 'wn18'.\n    reverse : bool\n        Whether add reverse edges. Default: True.\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    \"\"\"\n\n    def __init__(\n        self,\n        name,\n        reverse=True,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        self._name = name\n        self.reverse = reverse\n        url = _get_dgl_url(\"dataset/\") + \"{}.tgz\".format(name)\n        super(KnowledgeGraphDataset, self).__init__(\n            name,\n            url=url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def download(self):\n        r\"\"\"Automatically download data and extract it.\"\"\"\n        tgz_path = os.path.join(self.raw_dir, self.name + \".tgz\")\n        download(self.url, path=tgz_path)\n        extract_archive(tgz_path, self.raw_path)\n\n    def process(self):\n        \"\"\"\n        The original knowledge base is stored in triplets.\n        This function will parse these triplets and build the DGLGraph.\n        \"\"\"\n        root_path = self.raw_path\n        entity_path = os.path.join(root_path, \"entities.dict\")\n        relation_path = os.path.join(root_path, \"relations.dict\")\n        train_path = os.path.join(root_path, \"train.txt\")\n        valid_path = os.path.join(root_path, \"valid.txt\")\n        test_path = os.path.join(root_path, \"test.txt\")\n        entity_dict = _read_dictionary(entity_path)\n        relation_dict = _read_dictionary(relation_path)\n        train = np.asarray(\n            _read_triplets_as_list(train_path, entity_dict, relation_dict)\n        )\n        valid = np.asarray(\n            _read_triplets_as_list(valid_path, entity_dict, relation_dict)\n        )\n        test = np.asarray(\n            _read_triplets_as_list(test_path, entity_dict, relation_dict)\n        )\n        num_nodes = len(entity_dict)\n        num_rels = len(relation_dict)\n        if 
self.verbose:\n            print(\"# entities: {}\".format(num_nodes))\n            print(\"# relations: {}\".format(num_rels))\n            print(\"# training edges: {}\".format(train.shape[0]))\n            print(\"# validation edges: {}\".format(valid.shape[0]))\n            print(\"# testing edges: {}\".format(test.shape[0]))\n\n        # for compatability\n        self._train = train\n        self._valid = valid\n        self._test = test\n\n        self._num_nodes = num_nodes\n        self._num_rels = num_rels\n        # build graph\n        g, data = build_knowledge_graph(\n            num_nodes, num_rels, train, valid, test, reverse=self.reverse\n        )\n        (\n            etype,\n            ntype,\n            train_edge_mask,\n            valid_edge_mask,\n            test_edge_mask,\n            train_mask,\n            val_mask,\n            test_mask,\n        ) = data\n        g.edata[\"train_edge_mask\"] = train_edge_mask\n        g.edata[\"valid_edge_mask\"] = valid_edge_mask\n        g.edata[\"test_edge_mask\"] = test_edge_mask\n        g.edata[\"train_mask\"] = train_mask\n        g.edata[\"val_mask\"] = val_mask\n        g.edata[\"test_mask\"] = test_mask\n        g.edata[\"etype\"] = etype\n        g.ndata[\"ntype\"] = ntype\n        self._g = g\n\n    @property\n    def graph_path(self):\n        return os.path.join(self.save_path, self.save_name + \".bin\")\n\n    @property\n    def info_path(self):\n        return os.path.join(self.save_path, self.save_name + \".pkl\")\n\n    def has_cache(self):\n        if os.path.exists(self.graph_path) and os.path.exists(self.info_path):\n            return True\n\n        return False\n\n    def __getitem__(self, idx):\n        assert idx == 0, \"This dataset has only one graph\"\n        if self._transform is None:\n            return self._g\n        else:\n            return self._transform(self._g)\n\n    def __len__(self):\n        return 1\n\n    def save(self):\n        \"\"\"save the 
graph list and the labels\"\"\"\n        save_graphs(str(self.graph_path), self._g)\n        save_info(\n            str(self.info_path),\n            {\"num_nodes\": self.num_nodes, \"num_rels\": self.num_rels},\n        )\n\n    def load(self):\n        graphs, _ = load_graphs(str(self.graph_path))\n\n        info = load_info(str(self.info_path))\n        self._num_nodes = info[\"num_nodes\"]\n        self._num_rels = info[\"num_rels\"]\n        self._g = graphs[0]\n        train_mask = self._g.edata[\"train_edge_mask\"].numpy()\n        val_mask = self._g.edata[\"valid_edge_mask\"].numpy()\n        test_mask = self._g.edata[\"test_edge_mask\"].numpy()\n\n        # convert mask tensor into bool tensor if possible\n        self._g.edata[\"train_edge_mask\"] = generate_mask_tensor(\n            self._g.edata[\"train_edge_mask\"].numpy()\n        )\n        self._g.edata[\"valid_edge_mask\"] = generate_mask_tensor(\n            self._g.edata[\"valid_edge_mask\"].numpy()\n        )\n        self._g.edata[\"test_edge_mask\"] = generate_mask_tensor(\n            self._g.edata[\"test_edge_mask\"].numpy()\n        )\n        self._g.edata[\"train_mask\"] = generate_mask_tensor(\n            self._g.edata[\"train_mask\"].numpy()\n        )\n        self._g.edata[\"val_mask\"] = generate_mask_tensor(\n            self._g.edata[\"val_mask\"].numpy()\n        )\n        self._g.edata[\"test_mask\"] = generate_mask_tensor(\n            self._g.edata[\"test_mask\"].numpy()\n        )\n\n        # for compatability (with 0.4.x) generate train_idx, valid_idx and test_idx\n        etype = self._g.edata[\"etype\"].numpy()\n        self._etype = etype\n        u, v = self._g.all_edges(form=\"uv\")\n        u = u.numpy()\n        v = v.numpy()\n        train_idx = np.nonzero(train_mask == 1)\n        self._train = np.column_stack(\n            (u[train_idx], etype[train_idx], v[train_idx])\n        )\n        valid_idx = np.nonzero(val_mask == 1)\n        self._valid = 
np.column_stack(\n            (u[valid_idx], etype[valid_idx], v[valid_idx])\n        )\n        test_idx = np.nonzero(test_mask == 1)\n        self._test = np.column_stack(\n            (u[test_idx], etype[test_idx], v[test_idx])\n        )\n\n        if self.verbose:\n            print(\"# entities: {}\".format(self.num_nodes))\n            print(\"# relations: {}\".format(self.num_rels))\n            print(\"# training edges: {}\".format(self._train.shape[0]))\n            print(\"# validation edges: {}\".format(self._valid.shape[0]))\n            print(\"# testing edges: {}\".format(self._test.shape[0]))\n\n    @property\n    def num_nodes(self):\n        return self._num_nodes\n\n    @property\n    def num_rels(self):\n        return self._num_rels\n\n    @property\n    def save_name(self):\n        return self.name + \"_dgl_graph\"\n\n\ndef _read_dictionary(filename):\n    d = {}\n    with open(filename, \"r+\") as f:\n        for line in f:\n            line = line.strip().split(\"\\t\")\n            d[line[1]] = int(line[0])\n    return d\n\n\ndef _read_triplets(filename):\n    with open(filename, \"r+\") as f:\n        for line in f:\n            processed_line = line.strip().split(\"\\t\")\n            yield processed_line\n\n\ndef _read_triplets_as_list(filename, entity_dict, relation_dict):\n    l = []\n    for triplet in _read_triplets(filename):\n        s = entity_dict[triplet[0]]\n        r = relation_dict[triplet[1]]\n        o = entity_dict[triplet[2]]\n        l.append([s, r, o])\n    return l\n\n\ndef build_knowledge_graph(\n    num_nodes, num_rels, train, valid, test, reverse=True\n):\n    \"\"\"Create a DGL Homogeneous graph with heterograph info stored as node or edge features.\"\"\"\n    src = []\n    rel = []\n    dst = []\n    raw_subg = {}\n    raw_subg_eset = {}\n    raw_subg_etype = {}\n    raw_reverse_sugb = {}\n    raw_reverse_subg_eset = {}\n    raw_reverse_subg_etype = {}\n\n    # here there is noly one node type\n    s_type = 
\"node\"\n    d_type = \"node\"\n\n    def add_edge(s, r, d, reverse, edge_set):\n        r_type = str(r)\n        e_type = (s_type, r_type, d_type)\n        if raw_subg.get(e_type, None) is None:\n            raw_subg[e_type] = ([], [])\n            raw_subg_eset[e_type] = []\n            raw_subg_etype[e_type] = []\n        raw_subg[e_type][0].append(s)\n        raw_subg[e_type][1].append(d)\n        raw_subg_eset[e_type].append(edge_set)\n        raw_subg_etype[e_type].append(r)\n\n        if reverse is True:\n            r_type = str(r + num_rels)\n            re_type = (d_type, r_type, s_type)\n            if raw_reverse_sugb.get(re_type, None) is None:\n                raw_reverse_sugb[re_type] = ([], [])\n                raw_reverse_subg_etype[re_type] = []\n                raw_reverse_subg_eset[re_type] = []\n            raw_reverse_sugb[re_type][0].append(d)\n            raw_reverse_sugb[re_type][1].append(s)\n            raw_reverse_subg_eset[re_type].append(edge_set)\n            raw_reverse_subg_etype[re_type].append(r + num_rels)\n\n    for edge in train:\n        s, r, d = edge\n        assert r < num_rels\n        add_edge(s, r, d, reverse, 1)  # train set\n\n    for edge in valid:\n        s, r, d = edge\n        assert r < num_rels\n        add_edge(s, r, d, reverse, 2)  # valid set\n\n    for edge in test:\n        s, r, d = edge\n        assert r < num_rels\n        add_edge(s, r, d, reverse, 3)  # test set\n\n    subg = []\n    fg_s = []\n    fg_d = []\n    fg_etype = []\n    fg_settype = []\n    for e_type, val in raw_subg.items():\n        s, d = val\n        s = np.asarray(s)\n        d = np.asarray(d)\n        etype = raw_subg_etype[e_type]\n        etype = np.asarray(etype)\n        settype = raw_subg_eset[e_type]\n        settype = np.asarray(settype)\n\n        fg_s.append(s)\n        fg_d.append(d)\n        fg_etype.append(etype)\n        fg_settype.append(settype)\n\n    settype = np.concatenate(fg_settype)\n    if reverse is True:\n    
    settype = np.concatenate([settype, np.full((settype.shape[0]), 0)])\n    train_edge_mask = generate_mask_tensor(settype == 1)\n    valid_edge_mask = generate_mask_tensor(settype == 2)\n    test_edge_mask = generate_mask_tensor(settype == 3)\n\n    for e_type, val in raw_reverse_sugb.items():\n        s, d = val\n        s = np.asarray(s)\n        d = np.asarray(d)\n        etype = raw_reverse_subg_etype[e_type]\n        etype = np.asarray(etype)\n        settype = raw_reverse_subg_eset[e_type]\n        settype = np.asarray(settype)\n\n        fg_s.append(s)\n        fg_d.append(d)\n        fg_etype.append(etype)\n        fg_settype.append(settype)\n\n    s = np.concatenate(fg_s)\n    d = np.concatenate(fg_d)\n    g = dgl_graph((s, d), num_nodes=num_nodes)\n    etype = np.concatenate(fg_etype)\n    settype = np.concatenate(fg_settype)\n    etype = F.tensor(etype, dtype=F.data_type_dict[\"int64\"])\n    train_edge_mask = train_edge_mask\n    valid_edge_mask = valid_edge_mask\n    test_edge_mask = test_edge_mask\n    train_mask = (\n        generate_mask_tensor(settype == 1)\n        if reverse is True\n        else train_edge_mask\n    )\n    valid_mask = (\n        generate_mask_tensor(settype == 2)\n        if reverse is True\n        else valid_edge_mask\n    )\n    test_mask = (\n        generate_mask_tensor(settype == 3)\n        if reverse is True\n        else test_edge_mask\n    )\n    ntype = F.full_1d(\n        num_nodes, 0, dtype=F.data_type_dict[\"int64\"], ctx=F.cpu()\n    )\n\n    return g, (\n        etype,\n        ntype,\n        train_edge_mask,\n        valid_edge_mask,\n        test_edge_mask,\n        train_mask,\n        valid_mask,\n        test_mask,\n    )\n\n\nclass FB15k237Dataset(KnowledgeGraphDataset):\n    r\"\"\"FB15k237 link prediction dataset.\n\n    FB15k-237 is a subset of FB15k where inverse\n    relations are removed. 
When creating the dataset,\n    a reverse edge with reversed relation types are\n    created for each edge by default.\n\n    FB15k237 dataset statistics:\n\n    - Nodes: 14541\n    - Number of relation types: 237\n    - Number of reversed relation types: 237\n    - Label Split:\n\n        - Train: 272115\n        - Valid: 17535\n        - Test: 20466\n\n    Parameters\n    ----------\n    reverse : bool\n        Whether to add reverse edge. Default True.\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_nodes: int\n        Number of nodes\n    num_rels: int\n        Number of relation types\n\n    Examples\n    ----------\n    >>> dataset = FB15k237Dataset()\n    >>> g = dataset.graph\n    >>> e_type = g.edata['e_type']\n    >>>\n    >>> # get data split\n    >>> train_mask = g.edata['train_mask']\n    >>> val_mask = g.edata['val_mask']\n    >>> test_mask = g.edata['test_mask']\n    >>>\n    >>> train_set = th.arange(g.num_edges())[train_mask]\n    >>> val_set = th.arange(g.num_edges())[val_mask]\n    >>>\n    >>> # build train_g\n    >>> train_edges = train_set\n    >>> train_g = g.edge_subgraph(train_edges,\n                                  relabel_nodes=False)\n    >>> train_g.edata['e_type'] = e_type[train_edges];\n    >>>\n    >>> # build val_g\n    >>> val_edges = th.cat([train_edges, val_edges])\n    >>> val_g = g.edge_subgraph(val_edges,\n                                relabel_nodes=False)\n    >>> val_g.edata['e_type'] = e_type[val_edges];\n    >>>\n   
 >>> # Train, Validation and Test\n    \"\"\"\n\n    def __init__(\n        self,\n        reverse=True,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        name = \"FB15k-237\"\n        super(FB15k237Dataset, self).__init__(\n            name, reverse, raw_dir, force_reload, verbose, transform\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Gets the graph object\n\n        Parameters\n        -----------\n        idx: int\n            Item index, FB15k237Dataset has only one graph object\n\n        Return\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains\n\n            - ``edata['e_type']``: edge relation type\n            - ``edata['train_edge_mask']``: positive training edge mask\n            - ``edata['val_edge_mask']``: positive validation edge mask\n            - ``edata['test_edge_mask']``: positive testing edge mask\n            - ``edata['train_mask']``: training edge set mask (include reversed training edges)\n            - ``edata['val_mask']``: validation edge set mask (include reversed validation edges)\n            - ``edata['test_mask']``: testing edge set mask (include reversed testing edges)\n            - ``ndata['ntype']``: node type. 
All 0 in this dataset\n        \"\"\"\n        return super(FB15k237Dataset, self).__getitem__(idx)\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\"\"\"\n        return super(FB15k237Dataset, self).__len__()\n\n\nclass FB15kDataset(KnowledgeGraphDataset):\n    r\"\"\"FB15k link prediction dataset.\n\n    The FB15K dataset was introduced in `Translating Embeddings for Modeling\n    Multi-relational Data <http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf>`_.\n    It is a subset of Freebase which contains about\n    14,951 entities with 1,345 different relations.\n    When creating the dataset, a reverse edge with\n    reversed relation types are created for each edge\n    by default.\n\n    FB15k dataset statistics:\n\n    - Nodes: 14,951\n    - Number of relation types: 1,345\n    - Number of reversed relation types: 1,345\n    - Label Split:\n\n        - Train: 483142\n        - Valid: 50000\n        - Test: 59071\n\n    Parameters\n    ----------\n    reverse : bool\n        Whether to add reverse edge. Default True.\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_nodes: int\n        Number of nodes\n    num_rels: int\n        Number of relation types\n\n    Examples\n    ----------\n    >>> dataset = FB15kDataset()\n    >>> g = dataset.graph\n    >>> e_type = g.edata['e_type']\n    >>>\n    >>> # get data split\n    >>> train_mask = g.edata['train_mask']\n    >>> val_mask = g.edata['val_mask']\n    >>>\n    >>> train_set = th.arange(g.num_edges())[train_mask]\n    >>> val_set = th.arange(g.num_edges())[val_mask]\n    >>>\n    >>> # build train_g\n    >>> train_edges = train_set\n    >>> train_g = g.edge_subgraph(train_edges,\n                                  relabel_nodes=False)\n    >>> train_g.edata['e_type'] = e_type[train_edges];\n    >>>\n    >>> # build val_g\n    >>> val_edges = th.cat([train_edges, val_edges])\n    >>> val_g = g.edge_subgraph(val_edges,\n                                relabel_nodes=False)\n    >>> val_g.edata['e_type'] = e_type[val_edges];\n    >>>\n    >>> # Train, Validation and Test\n    >>>\n    \"\"\"\n\n    def __init__(\n        self,\n        reverse=True,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        name = \"FB15k\"\n        super(FB15kDataset, self).__init__(\n            name, reverse, raw_dir, force_reload, verbose, transform\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Gets the graph object\n\n        Parameters\n        -----------\n        idx: int\n            Item index, FB15kDataset has only one graph object\n\n        Return\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains\n\n            - ``edata['e_type']``: edge relation type\n            - ``edata['train_edge_mask']``: positive training edge mask\n            - ``edata['val_edge_mask']``: positive validation edge mask\n            - ``edata['test_edge_mask']``: positive testing 
edge mask\n            - ``edata['train_mask']``: training edge set mask (include reversed training edges)\n            - ``edata['val_mask']``: validation edge set mask (include reversed validation edges)\n            - ``edata['test_mask']``: testing edge set mask (include reversed testing edges)\n            - ``ndata['ntype']``: node type. All 0 in this dataset\n        \"\"\"\n        return super(FB15kDataset, self).__getitem__(idx)\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\"\"\"\n        return super(FB15kDataset, self).__len__()\n\n\nclass WN18Dataset(KnowledgeGraphDataset):\n    r\"\"\"WN18 link prediction dataset.\n\n    The WN18 dataset was introduced in `Translating Embeddings for Modeling\n    Multi-relational Data <http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf>`_.\n    It included the full 18 relations scraped from\n    WordNet for roughly 41,000 synsets. When creating\n    the dataset, a reverse edge with reversed relation\n    types are created for each edge by default.\n\n    WN18 dataset statistics:\n\n    - Nodes: 40943\n    - Number of relation types: 18\n    - Number of reversed relation types: 18\n    - Label Split:\n\n        - Train: 141442\n        - Valid: 5000\n        - Test: 5000\n\n    Parameters\n    ----------\n    reverse : bool\n        Whether to add reverse edge. Default True.\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_nodes: int\n        Number of nodes\n    num_rels: int\n        Number of relation types\n\n    Examples\n    ----------\n    >>> dataset = WN18Dataset()\n    >>> g = dataset.graph\n    >>> e_type = g.edata['e_type']\n    >>>\n    >>> # get data split\n    >>> train_mask = g.edata['train_mask']\n    >>> val_mask = g.edata['val_mask']\n    >>>\n    >>> train_set = th.arange(g.num_edges())[train_mask]\n    >>> val_set = th.arange(g.num_edges())[val_mask]\n    >>>\n    >>> # build train_g\n    >>> train_edges = train_set\n    >>> train_g = g.edge_subgraph(train_edges,\n                                  relabel_nodes=False)\n    >>> train_g.edata['e_type'] = e_type[train_edges];\n    >>>\n    >>> # build val_g\n    >>> val_edges = th.cat([train_edges, val_edges])\n    >>> val_g = g.edge_subgraph(val_edges,\n                                relabel_nodes=False)\n    >>> val_g.edata['e_type'] = e_type[val_edges];\n    >>>\n    >>> # Train, Validation and Test\n    >>>\n    \"\"\"\n\n    def __init__(\n        self,\n        reverse=True,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        name = \"wn18\"\n        super(WN18Dataset, self).__init__(\n            name, reverse, raw_dir, force_reload, verbose, transform\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Gets the graph object\n\n        Parameters\n        -----------\n        idx: int\n            Item index, WN18Dataset has only one graph object\n\n        Return\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains\n\n            - ``edata['e_type']``: edge relation type\n            - ``edata['train_edge_mask']``: positive training edge mask\n            - ``edata['val_edge_mask']``: positive validation edge mask\n            - ``edata['test_edge_mask']``: positive testing edge 
mask\n            - ``edata['train_mask']``: training edge set mask (include reversed training edges)\n            - ``edata['val_mask']``: validation edge set mask (include reversed validation edges)\n            - ``edata['test_mask']``: testing edge set mask (include reversed testing edges)\n            - ``ndata['ntype']``: node type. All 0 in this dataset\n        \"\"\"\n        return super(WN18Dataset, self).__getitem__(idx)\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\"\"\"\n        return super(WN18Dataset, self).__len__()\n\n\ndef load_data(dataset):\n    r\"\"\"Load knowledge graph dataset for RGCN link prediction tasks\n\n    It supports three datasets: wn18, FB15k and FB15k-237\n\n    Parameters\n    ----------\n    dataset: str\n        The name of the dataset to load.\n\n    Return\n    ------\n    The dataset object.\n    \"\"\"\n    if dataset == \"wn18\":\n        return WN18Dataset()\n    elif dataset == \"FB15k\":\n        return FB15kDataset()\n    elif dataset == \"FB15k-237\":\n        return FB15k237Dataset()\n"
  },
  {
    "path": "python/dgl/data/lrgb.py",
    "content": "import hashlib\nimport os\nimport pickle\n\nimport pandas as pd\nfrom ogb.utils import smiles2graph as smiles2graph_OGB\nfrom tqdm.auto import tqdm\n\nfrom .. import backend as F\n\nfrom ..convert import graph as dgl_graph\nfrom .dgl_dataset import DGLDataset\nfrom .utils import (\n    download,\n    extract_archive,\n    load_graphs,\n    makedirs,\n    save_graphs,\n    Subset,\n)\n\n\nclass PeptidesStructuralDataset(DGLDataset):\n    r\"\"\"Peptides structure dataset for the graph regression task.\n\n    DGL dataset of Peptides-struct in the LRGB benchmark which contains\n    15,535 small peptides represented as their molecular graph (SMILES)\n    with 11 regression targets derived from the peptide's 3D structure.\n\n    The 11 regression targets were precomputed from molecules' 3D structure:\n\n    - Inertia_mass_[a-c]: The principal component of the inertia of the\n      mass, with some normalizations. (Sorted)\n    - Inertia_valence_[a-c]: The principal component of the inertia of the\n      Hydrogen atoms. This is basically a measure of the 3D\n      distribution of hydrogens. (Sorted)\n    - length_[a-c]: The length around the 3 main geometric axis of\n      the 3D objects (without considering atom types). 
(Sorted)\n    - Spherocity: SpherocityIndex descriptor computed by\n      rdkit.Chem.rdMolDescriptors.CalcSpherocityIndex\n    - Plane_best_fit: Plane of best fit (PBF) descriptor computed by\n      rdkit.Chem.rdMolDescriptors.CalcPBF\n\n    Reference `<https://arxiv.org/abs/2206.08164.pdf>`_\n\n    Statistics:\n\n    - Train examples: 10,873\n    - Valid examples: 2,331\n    - Test examples: 2,331\n    - Average number of nodes: 150.94\n    - Average number of edges: 307.30\n    - Number of atom types: 9\n    - Number of bond types: 3\n\n    Parameters\n    ----------\n    raw_dir : str\n        Directory to store all the downloaded raw datasets.\n        Default: \"~/.dgl/\".\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False.\n    verbose : bool\n        Whether to print out progress information.\n        Default: False.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    smiles2graph : callable\n        A callable function that converts a SMILES string into a graph object.\n        * The default smiles2graph requires rdkit to be installed *\n\n    Examples\n    ---------\n    >>> from dgl.data import PeptidesStructuralDataset\n\n    >>> dataset = PeptidesStructuralDataset()\n    >>> len(dataset)\n    15535\n    >>> dataset.num_atom_types\n    9\n    >>> graph, label = dataset[0]\n    >>> graph\n    Graph(num_nodes=119, num_edges=244,\n        ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}\n        edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})\n\n\n    >>> # support tensor to be index when transform is None\n    >>> # see details in __getitem__ function\n    >>> # get train dataset\n    >>> split_dict = dataset.get_idx_split()\n    >>> trainset = dataset[split_dict[\"train\"]]\n    >>> graph, label = trainset[0]\n    >>> graph\n    Graph(num_nodes=338, num_edges=682,\n        ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}\n        edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})\n\n    >>> # get subset of dataset\n    >>> import torch\n    >>> idx = torch.tensor([0, 1, 2])\n    >>> dataset_subset = dataset[idx]\n    >>> graph, label = dataset_subset[0]\n    >>> graph\n    Graph(num_nodes=119, num_edges=244,\n        ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}\n        edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})\n    \"\"\"\n\n    def __init__(\n        self,\n        raw_dir=None,\n        force_reload=None,\n        verbose=None,\n        transform=None,\n        smiles2graph=smiles2graph_OGB,\n    ):\n        self.smiles2graph = smiles2graph\n        # MD5 hash of the dataset file.\n        self.md5sum_data = \"9786061a34298a0684150f2e4ff13f47\"\n        self.url_stratified_split = \"\"\"\n        
https://www.dropbox.com/s/9dfifzft1hqgow6/splits_random_stratified_peptide_structure.pickle?dl=1\n        \"\"\"\n        self.md5sum_stratified_split = \"5a0114bdadc80b94fc7ae974f13ef061\"\n        self.graphs = []\n        self.labels = []\n\n        super().__init__(\n            name=\"Peptides-struc\",\n            raw_dir=raw_dir,\n            url=\"\"\"\n            https://www.dropbox.com/s/464u3303eu2u4zp/peptide_structure_dataset.csv.gz?dl=1\n            \"\"\",\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    @property\n    def raw_data_path(self):\n        r\"\"\"Path to save the raw dataset file.\"\"\"\n        return os.path.join(self.raw_path, \"peptide_structure_dataset.csv.gz\")\n\n    @property\n    def split_data_path(self):\n        r\"\"\"Path to save the dataset split file.\"\"\"\n        return os.path.join(\n            self.raw_path, \"splits_random_stratified_peptide_structure.pickle\"\n        )\n\n    @property\n    def graph_path(self):\n        r\"\"\"Path to save the processed dataset file.\"\"\"\n        return os.path.join(self.save_path, \"Peptides-struc.bin\")\n\n    @property\n    def num_atom_types(self):\n        r\"\"\"Number of atom types.\"\"\"\n        return 9\n\n    @property\n    def num_bond_types(self):\n        r\"\"\"Number of bond types.\"\"\"\n        return 3\n\n    def _md5sum(self, path):\n        hash_md5 = hashlib.md5()\n        with open(path, \"rb\") as file:\n            buffer = file.read()\n            hash_md5.update(buffer)\n        return hash_md5.hexdigest()\n\n    def download(self):\n        path = download(self.url, path=self.raw_data_path)\n        # Save to disk the MD5 hash of the downloaded file.\n        hash_data = self._md5sum(path)\n        if hash_data != self.md5sum_data:\n            raise ValueError(\"Unexpected MD5 hash of the downloaded file\")\n        open(os.path.join(self.raw_path, hash_data), 
\"w\").close()\n        # Download train/val/test splits.\n        path_split = download(\n            self.url_stratified_split, path=self.split_data_path\n        )\n        hash_split = self._md5sum(path_split)\n        if hash_split != self.md5sum_stratified_split:\n            raise ValueError(\"Unexpected MD5 hash of the split file\")\n\n    def process(self):\n        data_df = pd.read_csv(self.raw_data_path)\n        smiles_list = data_df[\"smiles\"]\n        target_names = [\n            \"Inertia_mass_a\",\n            \"Inertia_mass_b\",\n            \"Inertia_mass_c\",\n            \"Inertia_valence_a\",\n            \"Inertia_valence_b\",\n            \"Inertia_valence_c\",\n            \"length_a\",\n            \"length_b\",\n            \"length_c\",\n            \"Spherocity\",\n            \"Plane_best_fit\",\n        ]\n        # Normalize to zero mean and unit standard deviation.\n        data_df.loc[:, target_names] = data_df.loc[:, target_names].apply(\n            lambda x: (x - x.mean()) / x.std(), axis=0\n        )\n        if self.verbose:\n            print(\"Converting SMILES strings into graphs...\")\n\n        for i in tqdm(range(len(smiles_list))):\n            smiles = smiles_list[i]\n            y = data_df.iloc[i][target_names]\n            graph = self.smiles2graph(smiles)\n\n            assert len(graph[\"edge_feat\"]) == graph[\"edge_index\"].shape[1]\n            assert len(graph[\"node_feat\"]) == graph[\"num_nodes\"]\n            DGLgraph = dgl_graph(\n                (graph[\"edge_index\"][0], graph[\"edge_index\"][1]),\n                num_nodes=graph[\"num_nodes\"],\n            )\n            DGLgraph.edata[\"feat\"] = F.zerocopy_from_numpy(\n                graph[\"edge_feat\"]\n            ).to(F.int64)\n            DGLgraph.ndata[\"feat\"] = F.zerocopy_from_numpy(\n                graph[\"node_feat\"]\n            ).to(F.int64)\n\n            self.graphs.append(DGLgraph)\n            self.labels.append(y)\n\n        
self.labels = F.tensor(self.labels, dtype=F.float32)\n\n    def load(self):\n        self.graphs, label_dict = load_graphs(self.graph_path)\n        self.labels = label_dict[\"labels\"]\n\n    def save(self):\n        save_graphs(\n            self.graph_path, self.graphs, labels={\"labels\": self.labels}\n        )\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def get_idx_split(self):\n        \"\"\"Get dataset splits.\n\n        Returns:\n            Dict with 'train', 'val', 'test', splits indices.\n        \"\"\"\n        with open(self.split_data_path, \"rb\") as file:\n            split_dict = pickle.load(file)\n        for key in split_dict.keys():\n            split_dict[key] = F.zerocopy_from_numpy(split_dict[key])\n        return split_dict\n\n    def __len__(self):\n        return len(self.graphs)\n\n    def __getitem__(self, idx):\n        \"\"\"Get the idx-th sample.\n\n        Parameters\n        ---------\n        idx : int or tensor\n            The sample index.\n            1-D tensor as `idx` is allowed when transform is None.\n\n        Returns\n        -------\n        (:class:`dgl.DGLGraph`, Tensor)\n            Graph with node feature stored in ``feat`` field and its label.\n        or\n        :class:`dgl.data.utils.Subset`\n            Subset of the dataset at specified indices\n        \"\"\"\n        if F.is_tensor(idx) and idx.dim() == 1:\n            if self._transform is None:\n                return Subset(self, idx.cpu())\n\n            raise ValueError(\n                \"Tensor idx not supported when transform is not None.\"\n            )\n\n        if self._transform is None:\n            return self.graphs[idx], self.labels[idx]\n\n        return self._transform(self.graphs[idx]), self.labels[idx]\n\n\nclass PeptidesFunctionalDataset(DGLDataset):\n    r\"\"\"Peptides functional dataset for the graph classification task.\n\n    DGL dataset of Peptides-func in the LRGB benchmark which 
contains\n    15,535 peptides represented as their molecular graph(SMILES) with\n    10-way multi-task binary classification of their functional classes.\n\n    The 10 classes represent the following functional classes (in order):\n        ['antifungal', 'cell_cell_communication', 'anticancer',\n        'drug_delivery_vehicle', 'antimicrobial', 'antiviral',\n        'antihypertensive', 'antibacterial', 'antiparasitic', 'toxic']\n\n    Reference `<https://arxiv.org/abs/2206.08164.pdf>`_\n\n    Statistics:\n\n    - Train examples: 10,873\n    - Valid examples: 2,331\n    - Test examples: 2,331\n    - Average number of nodes: 150.94\n    - Average number of edges: 307.30\n    - Number of atom types: 9\n    - Number of bond types: 3\n\n    Parameters\n    ----------\n    raw_dir : str\n        Directory to store all the downloaded raw datasets.\n        Default: \"~/.dgl/\".\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False.\n    verbose : bool\n        Whether to print out progress information.\n        Default: False.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    smiles2graph (callable):\n        A callable function that converts a SMILES string into a graph object.\n        * The default smiles2graph requires rdkit to be installed *\n\n    Examples\n    ---------\n    >>> from dgl.data import PeptidesFunctionalDataset\n\n    >>> dataset = PeptidesFunctionalDataset()\n    >>> len(dataset)\n    15535\n    >>> dataset.num_classes\n    10\n    >>> graph, label = dataset[0]\n    >>> graph\n    Graph(num_nodes=119, num_edges=244,\n        ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}\n        edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})\n\n\n    >>> # support tensor to be index when transform is None\n    >>> # see details in __getitem__ function\n    >>> # get train dataset\n    >>> split_dict = dataset.get_idx_split()\n    >>> trainset = dataset[split_dict[\"train\"]]\n    >>> graph, label = trainset[0]\n    >>> graph\n    Graph(num_nodes=338, num_edges=682,\n        ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}\n        edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})\n\n    >>> # get subset of dataset\n    >>> import torch\n    >>> idx = torch.tensor([0, 1, 2])\n    >>> dataset_subset = dataset[idx]\n    >>> graph, label = dataset_subset[0]\n    >>> graph\n    Graph(num_nodes=119, num_edges=244,\n        ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}\n        edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})\n    \"\"\"\n\n    def __init__(\n        self,\n        raw_dir=None,\n        force_reload=None,\n        verbose=None,\n        transform=None,\n        smiles2graph=smiles2graph_OGB,\n    ):\n        self.smiles2graph = smiles2graph\n        # MD5 hash of the dataset file.\n        self.md5sum_data = \"701eb743e899f4d793f0e13c8fa5a1b4\"\n        self.url_stratified_split = \"\"\"\n        
https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1\n        \"\"\"\n        self.md5sum_stratified_split = \"5a0114bdadc80b94fc7ae974f13ef061\"\n        self.graphs = []\n        self.labels = []\n\n        super().__init__(\n            name=\"Peptides-func\",\n            raw_dir=raw_dir,\n            url=\"\"\"\n            https://www.dropbox.com/s/ol2v01usvaxbsr8/peptide_multi_class_dataset.csv.gz?dl=1\n            \"\"\",\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    @property\n    def raw_data_path(self):\n        r\"\"\"Path to save the raw dataset file.\"\"\"\n        return os.path.join(self.raw_path, \"peptide_multi_class_dataset.csv.gz\")\n\n    @property\n    def split_data_path(self):\n        r\"\"\"Path to save the dataset split file.\"\"\"\n        return os.path.join(\n            self.raw_path, \"splits_random_stratified_peptide.pickle\"\n        )\n\n    @property\n    def graph_path(self):\n        r\"\"\"Path to save the processed dataset file.\"\"\"\n        return os.path.join(self.save_path, \"Peptides-func.bin\")\n\n    @property\n    def num_atom_types(self):\n        r\"\"\"Number of atom types.\"\"\"\n        return 9\n\n    @property\n    def num_bond_types(self):\n        r\"\"\"Number of bond types.\"\"\"\n        return 3\n\n    @property\n    def num_classes(self):\n        r\"\"\"Number of graph classes.\"\"\"\n        return 10\n\n    def _md5sum(self, path):\n        hash_md5 = hashlib.md5()\n        with open(path, \"rb\") as file:\n            buffer = file.read()\n            hash_md5.update(buffer)\n        return hash_md5.hexdigest()\n\n    def download(self):\n        path = download(self.url, path=self.raw_data_path)\n        # Save to disk the MD5 hash of the downloaded file.\n        hash_data = self._md5sum(path)\n        if hash_data != self.md5sum_data:\n            raise ValueError(\"Unexpected MD5 hash of 
the downloaded file\")\n        open(os.path.join(self.raw_path, hash_data), \"w\").close()\n        # Download train/val/test splits.\n        path_split = download(\n            self.url_stratified_split, path=self.split_data_path\n        )\n        hash_split = self._md5sum(path_split)\n        if hash_split != self.md5sum_stratified_split:\n            raise ValueError(\"Unexpected MD5 hash of the split file\")\n\n    def process(self):\n        data_df = pd.read_csv(self.raw_data_path)\n        smiles_list = data_df[\"smiles\"]\n        if self.verbose:\n            print(\"Converting SMILES strings into graphs...\")\n\n        for i in tqdm(range(len(smiles_list))):\n            smiles = smiles_list[i]\n            graph = self.smiles2graph(smiles)\n\n            assert len(graph[\"edge_feat\"]) == graph[\"edge_index\"].shape[1]\n            assert len(graph[\"node_feat\"]) == graph[\"num_nodes\"]\n            DGLgraph = dgl_graph(\n                (graph[\"edge_index\"][0], graph[\"edge_index\"][1]),\n                num_nodes=graph[\"num_nodes\"],\n            )\n            DGLgraph.edata[\"feat\"] = F.zerocopy_from_numpy(\n                graph[\"edge_feat\"]\n            ).to(F.int64)\n            DGLgraph.ndata[\"feat\"] = F.zerocopy_from_numpy(\n                graph[\"node_feat\"]\n            ).to(F.int64)\n            self.graphs.append(DGLgraph)\n            self.labels.append(eval(data_df[\"labels\"].iloc[i]))\n        self.labels = F.tensor(self.labels, dtype=F.float32)\n\n    def load(self):\n        self.graphs, label_dict = load_graphs(self.graph_path)\n        self.labels = label_dict[\"labels\"]\n\n    def save(self):\n        save_graphs(\n            self.graph_path, self.graphs, labels={\"labels\": self.labels}\n        )\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def get_idx_split(self):\n        \"\"\"Get dataset splits.\n\n        Returns:\n            Dict with 'train', 'val', 'test', splits 
indices.\n        \"\"\"\n        with open(self.split_data_path, \"rb\") as file:\n            split_dict = pickle.load(file)\n        for key in split_dict.keys():\n            split_dict[key] = F.zerocopy_from_numpy(split_dict[key])\n        return split_dict\n\n    def __len__(self):\n        return len(self.graphs)\n\n    def __getitem__(self, idx):\n        \"\"\"Get the idx-th sample.\n\n        Parameters\n        ---------\n        idx : int or tensor\n            The sample index.\n            1-D tensor as `idx` is allowed when transform is None.\n\n        Returns\n        -------\n        (:class:`dgl.DGLGraph`, Tensor)\n            Graph with node feature stored in ``feat`` field and its label.\n        or\n        :class:`dgl.data.utils.Subset`\n            Subset of the dataset at specified indices\n        \"\"\"\n        if F.is_tensor(idx) and idx.dim() == 1:\n            if self._transform is None:\n                return Subset(self, idx.cpu())\n\n            raise ValueError(\n                \"Tensor idx not supported when transform is not None.\"\n            )\n\n        if self._transform is None:\n            return self.graphs[idx], self.labels[idx]\n\n        return self._transform(self.graphs[idx]), self.labels[idx]\n\n\nclass VOCSuperpixelsDataset(DGLDataset):\n    r\"\"\"VOCSuperpixels dataset for the node classification task.\n\n    DGL dataset of PascalVOC-SP in the LRGB benchmark which contains image\n    superpixels and a semantic segmentation label for each node superpixel.\n\n    color map\n    0=background, 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle,\n    6=bus, 7=car, 8=cat, 9=chair, 10=cow,\n    11=diningtable, 12=dog, 13=horse, 14=motorbike, 15=person,\n    16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor\n\n    Reference `<https://arxiv.org/abs/2206.08164.pdf>`_\n\n    Statistics:\n\n    - Train examples: 8,498\n    - Valid examples: 1,428\n    - Test examples: 1,429\n    - Average number of nodes: 
479.40\n    - Average number of edges: 2,710.48\n\n    Parameters\n    ----------\n    raw_dir : str\n        Directory to store all the downloaded raw datasets.\n        Default: \"~/.dgl/\".\n    split : str\n        Should be chosen from [\"train\", \"val\", \"test\"]\n        Default: \"train\".\n    construct_format : str, optional\n        Option to select the graph construction format.\n        Should be chosen from the following formats:\n\n        - \"edge_wt_only_coord\": the graphs are 8-nn graphs with the edge weights\n          computed based on only spatial coordinates of superpixel nodes.\n        - \"edge_wt_coord_feat\": the graphs are 8-nn graphs with the edge weights\n          computed based on combination of spatial coordinates and feature\n          values of superpixel nodes.\n        - \"edge_wt_region_boundary\": the graphs are region boundary graphs where\n          two regions (i.e. superpixel nodes) have an edge between them if they\n          share a boundary in the original image.\n\n        Default: \"edge_wt_region_boundary\".\n    slic_compactness : int, optional\n        Option to select compactness of slic that was used for superpixels\n        Should be chosen from [10, 30]\n        Default: 30.\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False.\n    verbose : bool\n        Whether to print out progress information.\n        Default: False.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Examples\n    ---------\n    >>> from dgl.data import VOCSuperpixelsDataset\n\n    >>> train_dataset = VOCSuperpixelsDataset(split=\"train\")\n    >>> len(train_dataset)\n    8498\n    >>> train_dataset.num_classes\n    21\n    >>> graph = train_dataset[0]\n    >>> graph\n    Graph(num_nodes=460, num_edges=2632,\n        ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32),\n                        'label': Scheme(shape=(), dtype=torch.int32)}\n        edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})\n\n    >>> # support tensor to be index when transform is None\n    >>> # see details in __getitem__ function\n    >>> import torch\n    >>> idx = torch.tensor([0, 1, 2])\n    >>> train_dataset_subset = train_dataset[idx]\n    >>> train_dataset_subset[0]\n    Graph(num_nodes=460, num_edges=2632,\n        ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32),\n                        'label': Scheme(shape=(), dtype=torch.int32)}\n        edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})\n    \"\"\"\n\n    urls = {\n        10: {\n            \"edge_wt_only_coord\": \"\"\"\n            https://www.dropbox.com/s/rk6pfnuh7tq3t37/voc_superpixels_edge_wt_only_coord.zip?dl=1\n            \"\"\",\n            \"edge_wt_coord_feat\": \"\"\"\n            https://www.dropbox.com/s/2a53nmfp6llqg8y/voc_superpixels_edge_wt_coord_feat.zip?dl=1\n            \"\"\",\n            \"edge_wt_region_boundary\": \"\"\"\n            https://www.dropbox.com/s/6pfz2mccfbkj7r3/voc_superpixels_edge_wt_region_boundary.zip?dl=1\n            \"\"\",\n        },\n        30: {\n            \"edge_wt_only_coord\": \"\"\"\n            https://www.dropbox.com/s/toqulkdpb1jrswk/voc_superpixels_edge_wt_only_coord.zip?dl=1\n            \"\"\",\n            \"edge_wt_coord_feat\": \"\"\"\n            
https://www.dropbox.com/s/xywki8ysj63584d/voc_superpixels_edge_wt_coord_feat.zip?dl=1\n            \"\"\",\n            \"edge_wt_region_boundary\": \"\"\"\n            https://www.dropbox.com/s/8x722ai272wqwl4/voc_superpixels_edge_wt_region_boundary.zip?dl=1\n            \"\"\",\n        },\n    }\n\n    def __init__(\n        self,\n        raw_dir=None,\n        split=\"train\",\n        construct_format=\"edge_wt_region_boundary\",\n        slic_compactness=30,\n        force_reload=None,\n        verbose=None,\n        transform=None,\n    ):\n        assert split in [\"train\", \"val\", \"test\"], \"split not valid.\"\n        assert construct_format in [\n            \"edge_wt_only_coord\",\n            \"edge_wt_coord_feat\",\n            \"edge_wt_region_boundary\",\n        ], \"construct_format not valid.\"\n        assert slic_compactness in [10, 30], \"slic_compactness not valid.\"\n\n        self.construct_format = construct_format\n        self.slic_compactness = slic_compactness\n        self.split = split\n        self.graphs = []\n\n        super().__init__(\n            name=\"PascalVOC-SP\",\n            raw_dir=raw_dir,\n            url=self.urls[self.slic_compactness][self.construct_format],\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    @property\n    def save_path(self):\n        r\"\"\"Directory to save the processed dataset.\"\"\"\n        return os.path.join(\n            self.raw_path,\n            \"slic_compactness_\" + str(self.slic_compactness),\n            self.construct_format,\n        )\n\n    @property\n    def raw_data_path(self):\n        r\"\"\"Path to save the raw dataset file.\"\"\"\n        return os.path.join(self.save_path, f\"{self.split}.pickle\")\n\n    @property\n    def graph_path(self):\n        r\"\"\"Path to save the processed dataset file.\"\"\"\n        return os.path.join(self.save_path, f\"processed_{self.split}.pkl\")\n\n    
@property\n    def num_classes(self):\n        r\"\"\"Number of classes for each node.\"\"\"\n        return 21\n\n    def __len__(self):\n        r\"\"\"The number of examples in the dataset.\"\"\"\n        return len(self.graphs)\n\n    def download(self):\n        zip_file_path = os.path.join(\n            self.raw_path, \"voc_superpixels_\" + self.construct_format + \".zip\"\n        )\n        path = download(self.url, path=zip_file_path)\n        extract_archive(path, self.raw_path, overwrite=True)\n        makedirs(self.save_path)\n        os.rename(\n            os.path.join(\n                self.raw_path, \"voc_superpixels_\" + self.construct_format\n            ),\n            self.save_path,\n        )\n        os.unlink(path)\n\n    def process(self):\n        with open(self.raw_data_path, \"rb\") as file:\n            graphs = pickle.load(file)\n\n        for idx in tqdm(\n            range(len(graphs)), desc=f\"Processing {self.split} dataset\"\n        ):\n            graph = graphs[idx]\n\n            \"\"\"\n            Each `graph` is a tuple (x, edge_attr, edge_index, y)\n                Shape of x : [num_nodes, 14]\n                Shape of edge_attr : [num_edges, 1] or [num_edges, 2]\n                Shape of edge_index : [2, num_edges]\n                Shape of y : [num_nodes]\n            \"\"\"\n            DGLgraph = dgl_graph(\n                (graph[2][0], graph[2][1]),\n                num_nodes=len(graph[3]),\n            )\n            DGLgraph.ndata[\"feat\"] = graph[0].to(F.float32)\n            DGLgraph.edata[\"feat\"] = graph[1].to(F.float32)\n            DGLgraph.ndata[\"label\"] = F.tensor(graph[3])\n            self.graphs.append(DGLgraph)\n\n    def load(self):\n        with open(self.graph_path, \"rb\") as file:\n            graphs = pickle.load(file)\n            self.graphs = graphs\n\n    def save(self):\n        with open(os.path.join(self.graph_path), \"wb\") as file:\n            pickle.dump(self.graphs, file)\n\n    
def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def __getitem__(self, idx):\n        r\"\"\"Get the idx-th sample.\n\n        Parameters\n        ---------\n        idx : int or tensor\n            The sample index.\n            1-D tensor as `idx` is allowed when transform is None.\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n            graph structure, node features, node labels and edge features.\n\n            - ``ndata['feat']``: node features\n            - ``ndata['label']``: node labels\n            - ``edata['feat']``: edge features\n        or\n        :class:`dgl.data.utils.Subset`\n            Subset of the dataset at specified indices\n        \"\"\"\n        if F.is_tensor(idx) and idx.dim() == 1:\n            if self._transform is None:\n                return Subset(self, idx.cpu())\n\n            raise ValueError(\n                \"Tensor idx not supported when transform is not None.\"\n            )\n\n        if self._transform is None:\n            return self.graphs[idx]\n\n        return self._transform(self.graphs[idx])\n\n\nclass COCOSuperpixelsDataset(DGLDataset):\n    r\"\"\"COCO superpixel dataset for the node classification task.\n\n    DGL dataset of COCO-SP in the LRGB benchmark which contains image\n    superpixels and a semantic segmentation label for each node superpixel.\n\n    Based on the COCO 2017 dataset. 
Original source `<https://cocodataset.org>`_\n\n    Reference `<https://arxiv.org/abs/2206.08164.pdf>`_\n\n    Statistics:\n\n    - Train examples: 113,286\n    - Valid examples: 5,000\n    - Test examples: 5,000\n    - Average number of nodes: 476.88\n    - Average number of edges: 2,710.48\n    - Number of node classes: 81\n\n    Parameters\n    ----------\n    raw_dir : str\n        Directory to store all the downloaded raw datasets.\n        Default: \"~/.dgl/\".\n    split : str\n        Should be chosen from [\"train\", \"val\", \"test\"]\n        Default: \"train\".\n    construct_format : str, optional\n        Option to select the graph construction format.\n        Should be chosen from the following formats:\n\n        - \"edge_wt_only_coord\": the graphs are 8-nn graphs with the edge weights\n          computed based on only spatial coordinates of superpixel nodes.\n        - \"edge_wt_coord_feat\": the graphs are 8-nn graphs with the edge weights\n          computed based on combination of spatial coordinates and feature\n          values of superpixel nodes.\n        - \"edge_wt_region_boundary\": the graphs are region boundary graphs where\n          two regions (i.e. superpixel nodes) have an edge between them if they\n          share a boundary in the original image.\n\n        Default: \"edge_wt_region_boundary\".\n    slic_compactness : int, optional\n        Option to select compactness of slic that was used for superpixels\n        Should be chosen from [10, 30]\n        Default: 30.\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False.\n    verbose : bool\n        Whether to print out progress information.\n        Default: False.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Examples\n    ---------\n    >>> from dgl.data import COCOSuperpixelsDataset\n\n    >>> train_dataset = COCOSuperpixelsDataset(split=\"train\")\n    >>> len(train_dataset)\n    113286\n    >>> train_dataset.num_classes\n    81\n    >>> graph = train_dataset[0]\n    >>> graph\n    Graph(num_nodes=488, num_edges=2766,\n        ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32),\n                        'label': Scheme(shape=(), dtype=torch.uint8)}\n        edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})\n\n    >>> # support tensor to be index when transform is None\n    >>> # see details in __getitem__ function\n    >>> import torch\n    >>> idx = torch.tensor([0, 1, 2])\n    >>> train_dataset_subset = train_dataset[idx]\n    >>> train_dataset_subset[0]\n    Graph(num_nodes=488, num_edges=2766,\n        ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32),\n                        'label': Scheme(shape=(), dtype=torch.uint8)}\n        edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})\n    \"\"\"\n\n    urls = {\n        10: {\n            \"edge_wt_only_coord\": \"\"\"\n            https://www.dropbox.com/s/prqizdep8gk0ndk/coco_superpixels_edge_wt_only_coord.zip?dl=1\n            \"\"\",\n            \"edge_wt_coord_feat\": \"\"\"\n            https://www.dropbox.com/s/zftoyln1pkcshcg/coco_superpixels_edge_wt_coord_feat.zip?dl=1\n            \"\"\",\n            \"edge_wt_region_boundary\": \"\"\"\n            https://www.dropbox.com/s/fhihfcyx2y978u8/coco_superpixels_edge_wt_region_boundary.zip?dl=1\n            \"\"\",\n        },\n        30: {\n            \"edge_wt_only_coord\": \"\"\"\n            https://www.dropbox.com/s/hrbfkxmc5z9lsaz/coco_superpixels_edge_wt_only_coord.zip?dl=1\n            \"\"\",\n            \"edge_wt_coord_feat\": \"\"\"\n            
https://www.dropbox.com/s/4rfa2d5ij1gfu9b/coco_superpixels_edge_wt_coord_feat.zip?dl=1\n            \"\"\",\n            \"edge_wt_region_boundary\": \"\"\"\n            https://www.dropbox.com/s/r6ihg1f4pmyjjy0/coco_superpixels_edge_wt_region_boundary.zip?dl=1\n            \"\"\",\n        },\n    }\n\n    def __init__(\n        self,\n        raw_dir=None,\n        split=\"train\",\n        construct_format=\"edge_wt_region_boundary\",\n        slic_compactness=30,\n        force_reload=None,\n        verbose=None,\n        transform=None,\n    ):\n        assert split in [\"train\", \"val\", \"test\"], \"split not valid.\"\n        assert construct_format in [\n            \"edge_wt_only_coord\",\n            \"edge_wt_coord_feat\",\n            \"edge_wt_region_boundary\",\n        ], \"construct_format not valid.\"\n        assert slic_compactness in [10, 30], \"slic_compactness not valid.\"\n\n        self.construct_format = construct_format\n        self.slic_compactness = slic_compactness\n        self.split = split\n        self.graphs = []\n\n        super().__init__(\n            name=\"COCO-SP\",\n            raw_dir=raw_dir,\n            url=self.urls[self.slic_compactness][self.construct_format],\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    @property\n    def save_path(self):\n        r\"\"\"Directory to save the processed dataset.\"\"\"\n        return os.path.join(\n            self.raw_path,\n            \"slic_compactness_\" + str(self.slic_compactness),\n            self.construct_format,\n        )\n\n    @property\n    def raw_data_path(self):\n        r\"\"\"Path to save the raw dataset file.\"\"\"\n        return os.path.join(self.save_path, f\"{self.split}.pickle\")\n\n    @property\n    def graph_path(self):\n        r\"\"\"Path to save the processed dataset file.\"\"\"\n        return os.path.join(self.save_path, f\"processed_{self.split}.pkl\")\n\n    @property\n 
   def num_classes(self):\n        r\"\"\"Number of classes for each node.\"\"\"\n        return 81\n\n    def __len__(self):\n        r\"\"\"The number of examples in the dataset.\"\"\"\n        return len(self.graphs)\n\n    def download(self):\n        zip_file_path = os.path.join(\n            self.raw_path, \"coco_superpixels_\" + self.construct_format + \".zip\"\n        )\n        path = download(self.url, path=zip_file_path, overwrite=True)\n        extract_archive(path, self.raw_path, overwrite=True)\n        makedirs(self.save_path)\n        os.rename(\n            os.path.join(\n                self.raw_path, \"coco_superpixels_\" + self.construct_format\n            ),\n            self.save_path,\n        )\n        os.unlink(path)\n\n    def label_remap(self):\n        # Util function to remap the labels as the original label\n        # idxs are not contiguous\n        # fmt: off\n        original_label_idx = [\n            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19,\n            20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39,\n            40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,\n            58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78,\n            79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90\n        ]\n        # fmt: on\n        label_map = {}\n        for i, key in enumerate(original_label_idx):\n            label_map[key] = i\n\n        return label_map\n\n    def process(self):\n        with open(self.raw_data_path, \"rb\") as file:\n            graphs = pickle.load(file)\n\n        label_map = self.label_remap()\n\n        for idx in tqdm(\n            range(len(graphs)), desc=f\"Processing {self.split} dataset\"\n        ):\n            graph = graphs[idx]\n\n            \"\"\"\n            Each `graph` is a tuple (x, edge_attr, edge_index, y)\n                Shape of x : [num_nodes, 14]\n                Shape of edge_attr : [num_edges, 1] or [num_edges, 
2]\n                Shape of edge_index : [2, num_edges]\n                Shape of y : [num_nodes]\n            \"\"\"\n\n            DGLgraph = dgl_graph(\n                (graph[2][0], graph[2][1]),\n                num_nodes=len(graph[3]),\n            )\n            DGLgraph.ndata[\"feat\"] = graph[0].to(F.float32)\n            DGLgraph.edata[\"feat\"] = graph[1].to(F.float32)\n\n            y = F.tensor(graph[3])\n\n            # Label remapping. See self.label_remap() func\n            for i, label in enumerate(y):\n                y[i] = label_map[label.item()]\n\n            DGLgraph.ndata[\"label\"] = y\n            self.graphs.append(DGLgraph)\n\n    def load(self):\n        with open(self.graph_path, \"rb\") as file:\n            graphs = pickle.load(file)\n            self.graphs = graphs\n\n    def save(self):\n        with open(os.path.join(self.graph_path), \"wb\") as file:\n            pickle.dump(self.graphs, file)\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def __getitem__(self, idx):\n        r\"\"\"Get the idx-th sample.\n\n        Parameters\n        ---------\n        idx : int or tensor\n            The sample index.\n            1-D tensor as `idx` is allowed when transform is None.\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n            graph structure, node features, node labels and edge features.\n\n            - ``ndata['feat']``: node features\n            - ``ndata['label']``: node labels\n            - ``edata['feat']``: edge features\n        or\n        :class:`dgl.data.utils.Subset`\n            Subset of the dataset at specified indices\n        \"\"\"\n        if F.is_tensor(idx) and idx.dim() == 1:\n            if self._transform is None:\n                return Subset(self, idx.cpu())\n            raise ValueError(\n                \"Tensor idx not supported when transform is not None.\"\n            )\n\n        if self._transform is None:\n            return 
self.graphs[idx]\n\n        return self._transform(self.graphs[idx])\n"
  },
  {
    "path": "python/dgl/data/minigc.py",
    "content": "\"\"\"A mini synthetic dataset for graph classification benchmark.\"\"\"\nimport math\nimport os\n\nimport networkx as nx\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..convert import from_networkx\nfrom ..transforms import add_self_loop\nfrom .dgl_dataset import DGLDataset\nfrom .utils import load_graphs, makedirs, save_graphs\n\n__all__ = [\"MiniGCDataset\"]\n\n\nclass MiniGCDataset(DGLDataset):\n    \"\"\"The synthetic graph classification dataset class.\n\n    The datset contains 8 different types of graphs.\n\n    - class 0 : cycle graph\n    - class 1 : star graph\n    - class 2 : wheel graph\n    - class 3 : lollipop graph\n    - class 4 : hypercube graph\n    - class 5 : grid graph\n    - class 6 : clique graph\n    - class 7 : circular ladder graph\n\n    Parameters\n    ----------\n    num_graphs: int\n        Number of graphs in this dataset.\n    min_num_v: int\n        Minimum number of nodes for graphs\n    max_num_v: int\n        Maximum number of nodes for graphs\n    seed: int, default is 0\n        Random seed for data generation\n    transform: callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_graphs : int\n        Number of graphs\n    min_num_v : int\n        The minimum number of nodes\n    max_num_v : int\n        The maximum number of nodes\n    num_classes : int\n        The number of classes\n\n    Examples\n    --------\n    >>> data = MiniGCDataset(100, 16, 32, seed=0)\n\n    The dataset instance is an iterable\n\n    >>> len(data)\n    100\n    >>> g, label = data[64]\n    >>> g\n    Graph(num_nodes=20, num_edges=82,\n          ndata_schemes={}\n          edata_schemes={})\n    >>> label\n    tensor(5)\n\n    Batch the graphs and labels for mini-batch training\n\n    >>> graphs, labels = zip(*[data[i] for i in range(16)])\n    >>> batched_graphs = dgl.batch(graphs)\n    >>> batched_labels = torch.tensor(labels)\n    >>> batched_graphs\n    Graph(num_nodes=356, num_edges=1060,\n          ndata_schemes={}\n          edata_schemes={})\n    \"\"\"\n\n    def __init__(\n        self,\n        num_graphs,\n        min_num_v,\n        max_num_v,\n        seed=0,\n        save_graph=True,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        self.num_graphs = num_graphs\n        self.min_num_v = min_num_v\n        self.max_num_v = max_num_v\n        self.seed = seed\n        self.save_graph = save_graph\n\n        super(MiniGCDataset, self).__init__(\n            name=\"minigc\",\n            hash_key=(num_graphs, min_num_v, max_num_v, seed),\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        self.graphs = []\n        self.labels = []\n        self._generate(self.seed)\n\n    def __len__(self):\n        \"\"\"Return the number of graphs in the dataset.\"\"\"\n        return len(self.graphs)\n\n    def __getitem__(self, idx):\n        \"\"\"Get the idx-th sample.\n\n        Parameters\n      
  ---------\n        idx : int\n            The sample index.\n\n        Returns\n        -------\n        (:class:`dgl.Graph`, Tensor)\n            The graph and its label.\n        \"\"\"\n        if self._transform is None:\n            g = self.graphs[idx]\n        else:\n            g = self._transform(self.graphs[idx])\n        return g, self.labels[idx]\n\n    def has_cache(self):\n        graph_path = os.path.join(\n            self.save_path, \"dgl_graph_{}.bin\".format(self.hash)\n        )\n        if os.path.exists(graph_path):\n            return True\n\n        return False\n\n    def save(self):\n        \"\"\"save the graph list and the labels\"\"\"\n        if self.save_graph:\n            graph_path = os.path.join(\n                self.save_path, \"dgl_graph_{}.bin\".format(self.hash)\n            )\n            save_graphs(str(graph_path), self.graphs, {\"labels\": self.labels})\n\n    def load(self):\n        graphs, label_dict = load_graphs(\n            os.path.join(self.save_path, \"dgl_graph_{}.bin\".format(self.hash))\n        )\n        self.graphs = graphs\n        self.labels = label_dict[\"labels\"]\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes.\"\"\"\n        return 8\n\n    def _generate(self, seed):\n        if seed is not None:\n            np.random.seed(seed)\n        self._gen_cycle(self.num_graphs // 8)\n        self._gen_star(self.num_graphs // 8)\n        self._gen_wheel(self.num_graphs // 8)\n        self._gen_lollipop(self.num_graphs // 8)\n        self._gen_hypercube(self.num_graphs // 8)\n        self._gen_grid(self.num_graphs // 8)\n        self._gen_clique(self.num_graphs // 8)\n        self._gen_circular_ladder(self.num_graphs - len(self.graphs))\n        # preprocess\n        for i in range(self.num_graphs):\n            # convert to DGLGraph, and add self loops\n            self.graphs[i] = add_self_loop(from_networkx(self.graphs[i]))\n        self.labels = 
F.tensor(np.array(self.labels).astype(np.int64))\n\n    def _gen_cycle(self, n):\n        for _ in range(n):\n            num_v = np.random.randint(self.min_num_v, self.max_num_v)\n            g = nx.cycle_graph(num_v)\n            self.graphs.append(g)\n            self.labels.append(0)\n\n    def _gen_star(self, n):\n        for _ in range(n):\n            num_v = np.random.randint(self.min_num_v, self.max_num_v)\n            # nx.star_graph(N) gives a star graph with N+1 nodes\n            g = nx.star_graph(num_v - 1)\n            self.graphs.append(g)\n            self.labels.append(1)\n\n    def _gen_wheel(self, n):\n        for _ in range(n):\n            num_v = np.random.randint(self.min_num_v, self.max_num_v)\n            g = nx.wheel_graph(num_v)\n            self.graphs.append(g)\n            self.labels.append(2)\n\n    def _gen_lollipop(self, n):\n        for _ in range(n):\n            num_v = np.random.randint(self.min_num_v, self.max_num_v)\n            path_len = np.random.randint(2, num_v // 2)\n            g = nx.lollipop_graph(m=num_v - path_len, n=path_len)\n            self.graphs.append(g)\n            self.labels.append(3)\n\n    def _gen_hypercube(self, n):\n        for _ in range(n):\n            num_v = np.random.randint(self.min_num_v, self.max_num_v)\n            g = nx.hypercube_graph(int(math.log(num_v, 2)))\n            g = nx.convert_node_labels_to_integers(g)\n            self.graphs.append(g)\n            self.labels.append(4)\n\n    def _gen_grid(self, n):\n        for _ in range(n):\n            num_v = np.random.randint(self.min_num_v, self.max_num_v)\n            assert num_v >= 4, (\n                \"We require a grid graph to contain at least two \"\n                \"rows and two columns, thus 4 nodes, got {:d} \"\n                \"nodes\".format(num_v)\n            )\n            n_rows = np.random.randint(2, num_v // 2)\n            n_cols = num_v // n_rows\n            g = nx.grid_graph([n_rows, n_cols])\n            g 
= nx.convert_node_labels_to_integers(g)\n            self.graphs.append(g)\n            self.labels.append(5)\n\n    def _gen_clique(self, n):\n        for _ in range(n):\n            num_v = np.random.randint(self.min_num_v, self.max_num_v)\n            g = nx.complete_graph(num_v)\n            self.graphs.append(g)\n            self.labels.append(6)\n\n    def _gen_circular_ladder(self, n):\n        for _ in range(n):\n            num_v = np.random.randint(self.min_num_v, self.max_num_v)\n            g = nx.circular_ladder_graph(num_v // 2)\n            self.graphs.append(g)\n            self.labels.append(7)\n"
  },
  {
    "path": "python/dgl/data/movielens.py",
    "content": "\"\"\"MovieLens dataset\"\"\"\nimport os\n\nimport numpy as np\nimport pandas as pd\n\nfrom torch import LongTensor, Tensor\n\nfrom ..base import dgl_warning\nfrom ..convert import heterograph\nfrom .dgl_dataset import DGLDataset\n\nfrom .utils import (\n    _get_dgl_url,\n    download,\n    extract_archive,\n    load_graphs,\n    load_info,\n    save_graphs,\n    save_info,\n    split_dataset,\n)\n\nGENRES_ML_100K = [\n    \"unknown\",\n    \"Action\",\n    \"Adventure\",\n    \"Animation\",\n    \"Children\",\n    \"Comedy\",\n    \"Crime\",\n    \"Documentary\",\n    \"Drama\",\n    \"Fantasy\",\n    \"Film-Noir\",\n    \"Horror\",\n    \"Musical\",\n    \"Mystery\",\n    \"Romance\",\n    \"Sci-Fi\",\n    \"Thriller\",\n    \"War\",\n    \"Western\",\n]\nGENRES_ML_1M = GENRES_ML_100K[1:]\nGENRES_ML_10M = GENRES_ML_100K + [\"IMAX\"]\n\ntry:\n    import torch\nexcept ImportError:\n    HAS_TORCH = False\nelse:\n    HAS_TORCH = True\n\n\ndef check_pytorch():\n    \"\"\"Check if PyTorch is the backend.\"\"\"\n    if not HAS_TORCH:\n        raise ModuleNotFoundError(\n            \"MovieLensDataset requires PyTorch to be the backend.\"\n        )\n\n\nclass MovieLensDataset(DGLDataset):\n    r\"\"\"MovieLens dataset for edge prediction tasks. 
The raw datasets are extracted from\n    `MovieLens <https://grouplens.org/datasets/movielens/>`_, introduced by\n    `Movielens unplugged: experiences with an occasionally connected recommender system <https://dl.acm.org/doi/10.1145/604045.604094>`_.\n\n    The datasets consist of user ratings for movies and incorporate additional user/movie information in the form of features.\n    The nodes represent users and movies, and the edges store ratings that users assign to movies.\n\n    Statistics:\n\n    MovieLens-100K (ml-100k)\n\n    - Users: 943\n    - Movies: 1,682\n    - Ratings: 100,000 (1, 2, 3, 4, 5)\n\n    MovieLens-1M (ml-1m)\n\n    - Users: 6,040\n    - Movies: 3,706\n    - Ratings: 1,000,209 (1, 2, 3, 4, 5)\n\n    MovieLens-10M (ml-10m)\n\n    - Users: 69,878\n    - Movies: 10,677\n    - Ratings: 10,000,054 (0.5, 1, 1.5, ..., 4.5, 5.0)\n\n    Parameters\n    ----------\n    name : str\n        Dataset name. (:obj:`\"ml-100k\"`, :obj:`\"ml-1m\"`, :obj:`\"ml-10m\"`).\n    valid_ratio : float\n        Ratio of validation samples out of the whole dataset. Should be in (0.0, 1.0).\n    test_ratio : float, optional\n        Ratio of testing samples out of the whole dataset. Should be in (0.0, 1.0). And its sum with\n        :obj:`valid_ratio` should be in (0.0, 1.0) as well. This parameter is ignored\n        when :obj:`name` is :obj:`\"ml-100k\"`, since its testing samples are pre-specified.\n        Default: None\n    raw_dir : str, optional\n        Raw file directory to download/store the data.\n        Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to re-download (if the dataset has not been downloaded) and re-process the dataset.\n        Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    random_state : int, optional\n        Random seed used for random dataset split. Default: 0\n\n    Notes\n    -----\n    - When :obj:`name` is :obj:`\"ml-100k\"`, the :obj:`test_ratio` is invalid, and the training ratio is equal to 1-:obj:`valid_ratio`.\n    When :obj:`name` is :obj:`\"ml-1m\"` or :obj:`\"ml-10m\"`, the :obj:`test_ratio` is valid,\n    and the training ratio is equal to 1-:obj:`valid_ratio`-:obj:`test_ratio`.\n    - The number of edges is doubled to form an undirected(bidirected) graph structure.\n\n    Examples\n    --------\n    >>> from dgl.data import MovieLensDataset\n    >>> dataset = MovieLensDataset(name='ml-100k', valid_ratio=0.2)\n    >>> g = dataset[0]\n    >>> g\n    Graph(num_nodes={'movie': 1682, 'user': 943},\n          num_edges={('movie', 'movie-user', 'user'): 100000, ('user', 'user-movie', 'movie'): 100000},\n          metagraph=[('movie', 'user', 'movie-user'), ('user', 'movie', 'user-movie')])\n\n    >>> # get ratings of edges in the training graph.\n    >>> rate = g.edges['user-movie'].data['rate'] # or rate = g.edges['movie-user'].data['rate']\n    >>> rate\n    tensor([5., 5., 3.,  ..., 3., 3., 5.])\n\n    >>> # get train, valid and test mask of edges\n    >>> train_mask = g.edges['user-movie'].data['train_mask']\n    >>> valid_mask = g.edges['user-movie'].data['valid_mask']\n    >>> test_mask = g.edges['user-movie'].data['test_mask']\n\n    >>> # get train, valid and test ratings\n    >>> train_ratings = rate[train_mask]\n    >>> valid_ratings = rate[valid_mask]\n    >>> test_ratings = rate[test_mask]\n\n    >>> # get input features of users\n    >>> g.nodes[\"user\"].data[\"feat\"] # or g.nodes[\"movie\"].data[\"feat\"] for movie nodes\n    tensor([[0.4800, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n            [1.0600, 1.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n            [0.4600, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 
0.0000],\n            ...,\n            [0.4000, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000],\n            [0.9600, 1.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n            [0.4400, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000]])\n\n    \"\"\"\n\n    _url = {\n        \"ml-100k\": \"dataset/ml-100k.zip\",\n        \"ml-1m\": \"dataset/ml-1m.zip\",\n        \"ml-10m\": \"dataset/ml-10m.zip\",\n    }\n\n    def __init__(\n        self,\n        name,\n        valid_ratio,\n        test_ratio=None,\n        raw_dir=None,\n        force_reload=None,\n        verbose=None,\n        transform=None,\n        random_state=0,\n    ):\n        check_pytorch()\n        assert name in [\n            \"ml-100k\",\n            \"ml-1m\",\n            \"ml-10m\",\n        ], f\"currently movielens does not support {name}\"\n\n        # test regarding valid and test split ratio\n        assert (\n            valid_ratio > 0.0 and valid_ratio < 1.0\n        ), f\"valid_ratio {valid_ratio} must be in (0.0, 1.0)\"\n\n        if name in [\"ml-1m\", \"ml-10m\"]:\n            assert (\n                test_ratio is not None and test_ratio > 0.0 and test_ratio < 1.0\n            ), f\"test_ratio({test_ratio}) must be set to a value in (0.0, 1.0) when using ml-1m and ml-10m\"\n            assert (\n                test_ratio + valid_ratio > 0.0\n                and test_ratio + valid_ratio < 1.0\n            ), f\"test_ratio({test_ratio}) + valid_ratio({valid_ratio}) must be set to (0.0, 1.0) when using ml-1m and ml-10m\"\n\n        if name == \"ml-100k\" and test_ratio is not None:\n            dgl_warning(\n                f\"test_ratio ({test_ratio}) is not set to None for ml-100k. 
\"\n                \"Note that dataset split would not be affected by the test_ratio since \"\n                \"testing samples of ml-100k have been pre-specified.\"\n            )\n\n        self.valid_ratio = valid_ratio\n        self.test_ratio = test_ratio\n        self.random_state = random_state\n\n        if name == \"ml-100k\":\n            self.genres = GENRES_ML_100K\n        elif name == \"ml-1m\":\n            self.genres = GENRES_ML_1M\n        elif name == \"ml-10m\":\n            self.genres = GENRES_ML_10M\n        else:\n            raise NotImplementedError\n\n        super(MovieLensDataset, self).__init__(\n            name=name,\n            url=_get_dgl_url(self._url[name]),\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def check_version(self):\n        valid_ratio, test_ratio = load_info(self.version_path)\n        if self.valid_ratio == valid_ratio and (\n            self.test_ratio == test_ratio if self.name != \"ml-100k\" else True\n        ):\n            return True\n        else:\n            if self.name == \"ml-100k\":\n                print(\n                    f\"The current valid ratio ({self.valid_ratio}) \"\n                    \"is not the same as the last setting \"\n                    f\"(valid: {valid_ratio}). \"\n                    f\"MovieLens {self.name} will be re-processed with the new dataset split setting.\"\n                )\n            else:\n                print(\n                    f\"At least one of current valid ({self.valid_ratio}) and test ({self.test_ratio}) ratio \"\n                    \"are not the same as the last setting \"\n                    f\"(valid: {valid_ratio}, test: {test_ratio}). 
\"\n                    f\"MovieLens {self.name} will be re-processed with the new dataset split setting.\"\n                )\n            return False\n\n    def download(self):\n        zip_file_path = os.path.join(self.raw_dir, self.name + \".zip\")\n        download(self.url, path=zip_file_path)\n        extract_archive(zip_file_path, self.raw_dir, overwrite=True)\n\n    def process(self):\n        print(f\"Starting processing {self.name} ...\")\n\n        # 0. loading movie features\n        movie_feat = load_info(\n            os.path.join(self.raw_path, \"movie_feat.pkl\")\n        ).to(torch.float)\n        # 1. dataset split: train + (valid + ) test\n        if self.name == \"ml-100k\":\n            train_rating_data = self._load_raw_rates(\n                os.path.join(self.raw_path, \"u1.base\"), \"\\t\"\n            )\n            test_rating_data = self._load_raw_rates(\n                os.path.join(self.raw_path, \"u1.test\"), \"\\t\"\n            )\n            indices = np.arange(len(train_rating_data))\n            train, valid, _ = split_dataset(\n                indices,\n                [1 - self.valid_ratio, self.valid_ratio, 0.0],\n                shuffle=True,\n                random_state=self.random_state,\n            )\n            train_rating_data, valid_rating_data = (\n                train_rating_data.iloc[train.indices],\n                train_rating_data.iloc[valid.indices],\n            )\n            all_rating_data = pd.concat(\n                [train_rating_data, valid_rating_data, test_rating_data]\n            )\n\n        elif self.name == \"ml-1m\" or self.name == \"ml-10m\":\n            all_rating_data = self._load_raw_rates(\n                os.path.join(self.raw_path, \"ratings.dat\"), \"::\"\n            )\n            indices = np.arange(len(all_rating_data))\n            train, valid, test = split_dataset(\n                indices,\n                [\n                    1 - self.valid_ratio - self.test_ratio,\n     
               self.valid_ratio,\n                    self.test_ratio,\n                ],\n                shuffle=True,\n                random_state=self.random_state,\n            )\n            train_rating_data, valid_rating_data, test_rating_data = (\n                all_rating_data.iloc[train.indices],\n                all_rating_data.iloc[valid.indices],\n                all_rating_data.iloc[test.indices],\n            )\n\n        # 2. load user and movie data, and drop those unseen in rating_data\n        user_data = self._load_raw_user_data()\n        movie_data = self._load_raw_movie_data()\n        user_data = self._drop_unseen_nodes(\n            data_df=user_data,\n            col_name=\"id\",\n            reserved_ids_set=set(all_rating_data[\"user_id\"].values),\n        )\n        movie_data = self._drop_unseen_nodes(\n            data_df=movie_data,\n            col_name=\"id\",\n            reserved_ids_set=set(all_rating_data[\"movie_id\"].values),\n        )\n\n        user_feat = Tensor(self._process_user_feat(user_data))\n\n        # 3. 
generate rating pairs\n        # Map user/movie to the global id\n        self._global_user_id_map = {\n            ele: i for i, ele in enumerate(user_data[\"id\"])\n        }\n        self._global_movie_id_map = {\n            ele: i for i, ele in enumerate(movie_data[\"id\"])\n        }\n\n        # pair value is idx rather than id\n        u_indices, v_indices, labels = self._generate_pair_value(\n            all_rating_data\n        )\n        all_rating_pairs = (\n            LongTensor(u_indices),\n            LongTensor(v_indices),\n        )\n        all_rating_values = Tensor(labels)\n\n        graph = self.construct_g(\n            all_rating_pairs, all_rating_values, user_feat, movie_feat\n        )\n        self.graph = self.add_masks(\n            graph, train_rating_data, valid_rating_data, test_rating_data\n        )\n\n        print(f\"End processing {self.name} ...\")\n\n    def construct_g(self, rate_pairs, rate_values, user_feat, movie_feat):\n        g = heterograph(\n            {\n                (\"user\", \"user-movie\", \"movie\"): (rate_pairs[0], rate_pairs[1]),\n                (\"movie\", \"movie-user\", \"user\"): (rate_pairs[1], rate_pairs[0]),\n            }\n        )\n        ndata = {\"user\": user_feat, \"movie\": movie_feat}\n        edata = {\"user-movie\": rate_values, \"movie-user\": rate_values}\n        g.ndata[\"feat\"] = ndata\n        g.edata[\"rate\"] = edata\n        return g\n\n    def add_masks(\n        self, g, train_rating_data, valid_rating_data, test_rating_data\n    ):\n        train_u_indices, train_v_indices, _ = self._generate_pair_value(\n            train_rating_data\n        )\n        valid_u_indices, valid_v_indices, _ = self._generate_pair_value(\n            valid_rating_data\n        )\n        test_u_indices, test_v_indices, _ = self._generate_pair_value(\n            test_rating_data\n        )\n\n        # user-movie\n        train_mask = torch.zeros((g.num_edges(\"user-movie\"),), 
dtype=torch.bool)\n        train_mask[\n            g.edge_ids(train_u_indices, train_v_indices, etype=\"user-movie\")\n        ] = True\n        valid_mask = torch.zeros((g.num_edges(\"user-movie\"),), dtype=torch.bool)\n        valid_mask[\n            g.edge_ids(valid_u_indices, valid_v_indices, etype=\"user-movie\")\n        ] = True\n        test_mask = torch.zeros((g.num_edges(\"user-movie\"),), dtype=torch.bool)\n        test_mask[\n            g.edge_ids(test_u_indices, test_v_indices, etype=\"user-movie\")\n        ] = True\n\n        g.edges[\"user-movie\"].data[\"train_mask\"] = train_mask\n        g.edges[\"user-movie\"].data[\"valid_mask\"] = valid_mask\n        g.edges[\"user-movie\"].data[\"test_mask\"] = test_mask\n\n        # movie-user\n        train_mask_rev = torch.zeros(\n            (g.num_edges(\"movie-user\"),), dtype=torch.bool\n        )\n        train_mask_rev[\n            g.edge_ids(train_v_indices, train_u_indices, etype=\"movie-user\")\n        ] = True\n        valid_mask_rev = torch.zeros(\n            (g.num_edges(\"movie-user\"),), dtype=torch.bool\n        )\n        valid_mask_rev[\n            g.edge_ids(valid_v_indices, valid_u_indices, etype=\"movie-user\")\n        ] = True\n        test_mask_rev = torch.zeros(\n            (g.num_edges(\"movie-user\"),), dtype=torch.bool\n        )\n        test_mask_rev[\n            g.edge_ids(test_v_indices, test_u_indices, etype=\"movie-user\")\n        ] = True\n\n        g.edges[\"movie-user\"].data[\"train_mask\"] = train_mask_rev\n        g.edges[\"movie-user\"].data[\"valid_mask\"] = valid_mask_rev\n        g.edges[\"movie-user\"].data[\"test_mask\"] = test_mask_rev\n\n        return g\n\n    def has_cache(self):\n        if (\n            os.path.exists(self.graph_path)\n            and os.path.exists(self.version_path)\n            and self.check_version()\n        ):\n            return True\n        return False\n\n    def save(self):\n        save_graphs(self.graph_path, 
[self.graph])\n        save_info(self.version_path, [self.valid_ratio, self.test_ratio])\n        if self.verbose:\n            print(f\"Done saving data into {self.raw_path}.\")\n\n    def load(self):\n        g_list, _ = load_graphs(self.graph_path)\n        self.graph = g_list[0]\n\n        \"\"\"\n        To avoid the problem each time loading boolean tensor from the disk, boolean values\n        would be automatically converted into torch.uint8 types, and a deprecation warning would\n        be raised for using torch.uint8\n        \"\"\"\n        for e in self.graph.etypes:\n            self.graph.edges[e].data[\"train_mask\"] = (\n                self.graph.edges[e].data[\"train_mask\"].to(torch.bool)\n            )\n            self.graph.edges[e].data[\"valid_mask\"] = (\n                self.graph.edges[e].data[\"valid_mask\"].to(torch.bool)\n            )\n            self.graph.edges[e].data[\"test_mask\"] = (\n                self.graph.edges[e].data[\"test_mask\"].to(torch.bool)\n            )\n\n    def __getitem__(self, idx):\n        assert (\n            idx == 0\n        ), \"This dataset has only one set of training, validation and testing graph\"\n        if self._transform is None:\n            return self.graph\n        else:\n            return self._transform(self.graph)\n\n    def __len__(self):\n        return 1\n\n    @property\n    def raw_path(self):\n        return os.path.join(self.raw_dir, self.name)\n\n    @property\n    def graph_path(self):\n        return os.path.join(self.raw_path, self.name + \".bin\")\n\n    @property\n    def version_path(self):\n        return os.path.join(self.raw_path, self.name + \"_version.pkl\")\n\n    def _process_user_feat(self, user_data):\n        if self.name == \"ml-100k\" or self.name == \"ml-1m\":\n            ages = user_data[\"age\"].values.astype(np.float32)\n            gender = (user_data[\"gender\"] == \"F\").values.astype(np.float32)\n            all_occupations = 
set(user_data[\"occupation\"])\n            occupation_map = {ele: i for i, ele in enumerate(all_occupations)}\n            occupation_one_hot = np.zeros(\n                shape=(user_data.shape[0], len(all_occupations)),\n                dtype=np.float32,\n            )\n            occupation_one_hot[\n                np.arange(user_data.shape[0]),\n                np.array(\n                    [occupation_map[ele] for ele in user_data[\"occupation\"]]\n                ),\n            ] = 1\n            user_features = np.concatenate(\n                [\n                    ages.reshape((user_data.shape[0], 1)) / 50.0,\n                    gender.reshape((user_data.shape[0], 1)),\n                    occupation_one_hot,\n                ],\n                axis=1,\n            )\n        elif self.name == \"ml-10m\":\n            user_features = np.zeros(\n                shape=(user_data.shape[0], 1), dtype=np.float32\n            )\n        else:\n            raise NotImplementedError\n        return user_features\n\n    def _load_raw_user_data(self):\n        if self.name == \"ml-100k\":\n            user_data = pd.read_csv(\n                os.path.join(self.raw_path, \"u.user\"),\n                sep=\"|\",\n                header=None,\n                names=[\"id\", \"age\", \"gender\", \"occupation\", \"zip_code\"],\n                engine=\"python\",\n            )\n        elif self.name == \"ml-1m\":\n            user_data = pd.read_csv(\n                os.path.join(self.raw_path, \"users.dat\"),\n                sep=\"::\",\n                header=None,\n                names=[\"id\", \"gender\", \"age\", \"occupation\", \"zip_code\"],\n                engine=\"python\",\n            )\n        elif self.name == \"ml-10m\":\n            rating_info = pd.read_csv(\n                os.path.join(self.raw_path, \"ratings.dat\"),\n                sep=\"::\",\n                header=None,\n                names=[\"user_id\", \"movie_id\", \"rating\", 
\"timestamp\"],\n                dtype={\n                    \"user_id\": np.int32,\n                    \"movie_id\": np.int32,\n                    \"ratings\": np.float32,\n                    \"timestamp\": np.int64,\n                },\n                engine=\"python\",\n            )\n            user_data = pd.DataFrame(\n                np.unique(rating_info[\"user_id\"].values.astype(np.int32)),\n                columns=[\"id\"],\n            )\n        else:\n            raise NotImplementedError\n        return user_data\n\n    def _load_raw_movie_data(self):\n        file_path = os.path.join(self.raw_path, \"u.item\")\n        if self.name == \"ml-100k\":\n            movie_data = pd.read_csv(\n                file_path,\n                sep=\"|\",\n                header=None,\n                names=[\n                    \"id\",\n                    \"title\",\n                    \"release_date\",\n                    \"video_release_date\",\n                    \"url\",\n                ]\n                + GENRES_ML_100K,\n                engine=\"python\",\n                encoding=\"ISO-8859-1\",\n            )\n        elif self.name == \"ml-1m\" or self.name == \"ml-10m\":\n            file_path = os.path.join(self.raw_path, \"movies.dat\")\n            movie_data = pd.read_csv(\n                file_path,\n                sep=\"::\",\n                header=None,\n                names=[\"id\", \"title\", \"genres\"],\n                encoding=\"iso-8859-1\",\n                engine=\"python\",\n            )\n            genre_map = {ele: i for i, ele in enumerate(self.genres)}\n            genre_map[\"Children's\"] = genre_map[\"Children\"]\n            genre_map[\"Childrens\"] = genre_map[\"Children\"]\n            movie_genres = np.zeros(\n                shape=(movie_data.shape[0], len(self.genres)), dtype=np.float32\n            )\n            for i, genres in enumerate(movie_data[\"genres\"]):\n                for ele in 
genres.split(\"|\"):\n                    if ele in genre_map:\n                        movie_genres[i, genre_map[ele]] = 1.0\n                    else:\n                        movie_genres[i, genre_map[\"unknown\"]] = 1.0\n            for idx, genre_name in enumerate(self.genres):\n                movie_data[genre_name] = movie_genres[:, idx]\n            movie_data = movie_data.drop(columns=[\"genres\"])\n        else:\n            raise NotImplementedError\n\n        return movie_data\n\n    def _load_raw_rates(self, file_path, sep):\n        rating_data = pd.read_csv(\n            file_path,\n            sep=sep,\n            header=None,\n            names=[\"user_id\", \"movie_id\", \"rating\", \"timestamp\"],\n            dtype={\n                \"user_id\": np.int32,\n                \"movie_id\": np.int32,\n                \"ratings\": np.float32,\n                \"timestamp\": np.int64,\n            },\n            engine=\"python\",\n        )\n        rating_data = rating_data.reset_index(drop=True)\n        return rating_data\n\n    def _drop_unseen_nodes(self, data_df, col_name, reserved_ids_set):\n        data_df = data_df[data_df[col_name].isin(reserved_ids_set)]\n        data_df.reset_index(drop=True, inplace=True)\n        return data_df\n\n    def _generate_pair_value(self, rating_data):\n        rating_pairs = (\n            np.array(\n                [\n                    self._global_user_id_map[ele]\n                    for ele in rating_data[\"user_id\"]\n                ],\n                dtype=np.int32,\n            ),\n            np.array(\n                [\n                    self._global_movie_id_map[ele]\n                    for ele in rating_data[\"movie_id\"]\n                ],\n                dtype=np.int32,\n            ),\n        )\n        rating_values = rating_data[\"rating\"].values.astype(np.float32)\n        return rating_pairs[0], rating_pairs[1], rating_values\n\n    def __repr__(self):\n        return (\n       
     f'Dataset(\"{self.name}\", num_graphs={len(self)},'\n            + f\" save_path={self.raw_path}), valid_ratio={self.valid_ratio}, test_ratio={self.test_ratio}\"\n        )\n"
  },
  {
    "path": "python/dgl/data/pattern.py",
    "content": "\"\"\" PATTERNDataset for inductive learning. \"\"\"\nimport os\n\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, load_graphs\n\n\nclass PATTERNDataset(DGLBuiltinDataset):\n    r\"\"\"PATTERN dataset for graph pattern recognition task.\n\n    Each graph G contains 5 communities with sizes randomly selected between [5, 35].\n    The SBM of each community is p = 0.5, q = 0.35, and the node features on G are\n    generated with a uniform random distribution with a vocabulary of size 3, i.e. {0, 1, 2}.\n    Then randomly generate 100 patterns P composed of 20 nodes with intra-probability :math:`p_P` = 0.5\n    and extra-probability :math:`q_P` = 0.5 (i.e. 50% of nodes in P are connected to G). The node features\n    for P are also generated as a random signal with values {0, 1, 2}. The graphs are of sizes\n    44-188 nodes. The output node labels have value 1 if the node belongs to P and value 0 if it is in G.\n\n    Reference `<https://arxiv.org/pdf/2003.00982.pdf>`_\n\n    Statistics:\n\n    - Train examples: 10,000\n    - Valid examples: 2,000\n    - Test examples: 2,000\n    - Number of classes for each node: 2\n\n    Parameters\n    ----------\n    mode : str\n        Must be one of ('train', 'valid', 'test').\n        Default: 'train'\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False\n    verbose : bool\n        Whether to print out progress information.\n        Default: False\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes for each node.\n\n    Examples\n    --------\n    >>> from dgl.data import PATTERNDataset\n    >>> data = PATTERNDataset(mode='train')\n    >>> data.num_classes\n    2\n    >>> len(data)\n    10000\n    >>> data[0]\n    Graph(num_nodes=108, num_edges=4884, ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(), dtype=torch.int16)}\n    edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})\n    \"\"\"\n\n    def __init__(\n        self,\n        mode=\"train\",\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        assert mode in [\"train\", \"valid\", \"test\"]\n        self.mode = mode\n        _url = _get_dgl_url(\"dataset/SBM_PATTERN.zip\")\n\n        super(PATTERNDataset, self).__init__(\n            name=\"pattern\",\n            url=_url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        self.load()\n\n    @property\n    def graph_path(self):\n        return os.path.join(\n            self.save_path, \"SBM_PATTERN_{}.bin\".format(self.mode)\n        )\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def load(self):\n        self._graphs, _ = load_graphs(self.graph_path)\n\n    @property\n    def num_classes(self):\n        r\"\"\"Number of classes for each node.\"\"\"\n        return 2\n\n    def __len__(self):\n        r\"\"\"The number of examples in the dataset.\"\"\"\n        return len(self._graphs)\n\n    def __getitem__(self, idx):\n        r\"\"\"Get the idx^th sample.\n\n        Parameters\n        ----------\n        idx : int\n            The sample index.\n\n        Returns\n        -------\n        
:class:`dgl.DGLGraph`\n            graph structure, node features, node labels and edge features.\n\n            - ``ndata['feat']``: node features\n            - ``ndata['label']``: node labels\n            - ``edata['feat']``: edge features\n        \"\"\"\n        if self._transform is None:\n            return self._graphs[idx]\n        else:\n            return self._transform(self._graphs[idx])\n"
  },
  {
    "path": "python/dgl/data/ppi.py",
    "content": "\"\"\" PPIDataset for inductive learning. \"\"\"\nimport json\nimport os\n\nimport networkx as nx\nimport numpy as np\nfrom networkx.readwrite import json_graph\n\nfrom .. import backend as F\nfrom ..convert import from_networkx\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, load_graphs, load_info, save_graphs, save_info\n\n\nclass PPIDataset(DGLBuiltinDataset):\n    r\"\"\"Protein-Protein Interaction dataset for inductive node classification\n\n    A toy Protein-Protein Interaction network dataset. The dataset contains\n    24 graphs. The average number of nodes per graph is 2372. Each node has\n    50 features and 121 labels. 20 graphs for training, 2 for validation\n    and 2 for testing.\n\n    Reference: `<http://snap.stanford.edu/graphsage/>`_\n\n    Statistics:\n\n    - Train examples: 20\n    - Valid examples: 2\n    - Test examples: 2\n\n    Parameters\n    ----------\n    mode : str\n        Must be one of ('train', 'valid', 'test').\n        Default: 'train'\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False\n    verbose : bool\n        Whether to print out progress information.\n        Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_labels : int\n        Number of labels for each node\n    labels : Tensor\n        Node labels\n    features : Tensor\n        Node features\n\n    Examples\n    --------\n    >>> dataset = PPIDataset(mode='valid')\n    >>> num_classes = dataset.num_classes\n    >>> for g in dataset:\n    ....    feat = g.ndata['feat']\n    ....    label = g.ndata['label']\n    ....    
# your code here\n    >>>\n    \"\"\"\n\n    def __init__(\n        self,\n        mode=\"train\",\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        assert mode in [\"train\", \"valid\", \"test\"]\n        self.mode = mode\n        _url = _get_dgl_url(\"dataset/ppi.zip\")\n        super(PPIDataset, self).__init__(\n            name=\"ppi\",\n            url=_url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        graph_file = os.path.join(\n            self.save_path, \"{}_graph.json\".format(self.mode)\n        )\n        label_file = os.path.join(\n            self.save_path, \"{}_labels.npy\".format(self.mode)\n        )\n        feat_file = os.path.join(\n            self.save_path, \"{}_feats.npy\".format(self.mode)\n        )\n        graph_id_file = os.path.join(\n            self.save_path, \"{}_graph_id.npy\".format(self.mode)\n        )\n\n        g_data = json.load(open(graph_file))\n        self._labels = np.load(label_file)\n        self._feats = np.load(feat_file)\n        self.graph = from_networkx(\n            nx.DiGraph(json_graph.node_link_graph(g_data))\n        )\n        graph_id = np.load(graph_id_file)\n\n        # lo, hi means the range of graph ids for different portion of the dataset,\n        # 20 graphs for training, 2 for validation and 2 for testing.\n        lo, hi = 1, 21\n        if self.mode == \"valid\":\n            lo, hi = 21, 23\n        elif self.mode == \"test\":\n            lo, hi = 23, 25\n\n        graph_masks = []\n        self.graphs = []\n        for g_id in range(lo, hi):\n            g_mask = np.where(graph_id == g_id)[0]\n            graph_masks.append(g_mask)\n            g = self.graph.subgraph(g_mask)\n            g.ndata[\"feat\"] = F.tensor(\n                self._feats[g_mask], dtype=F.data_type_dict[\"float32\"]\n 
           )\n            g.ndata[\"label\"] = F.tensor(\n                self._labels[g_mask], dtype=F.data_type_dict[\"float32\"]\n            )\n            self.graphs.append(g)\n\n    @property\n    def graph_list_path(self):\n        return os.path.join(\n            self.save_path, \"{}_dgl_graph_list.bin\".format(self.mode)\n        )\n\n    @property\n    def g_path(self):\n        return os.path.join(\n            self.save_path, \"{}_dgl_graph.bin\".format(self.mode)\n        )\n\n    @property\n    def info_path(self):\n        return os.path.join(self.save_path, \"{}_info.pkl\".format(self.mode))\n\n    def has_cache(self):\n        return (\n            os.path.exists(self.graph_list_path)\n            and os.path.exists(self.g_path)\n            and os.path.exists(self.info_path)\n        )\n\n    def save(self):\n        save_graphs(self.graph_list_path, self.graphs)\n        save_graphs(self.g_path, self.graph)\n        save_info(\n            self.info_path, {\"labels\": self._labels, \"feats\": self._feats}\n        )\n\n    def load(self):\n        self.graphs = load_graphs(self.graph_list_path)[0]\n        g, _ = load_graphs(self.g_path)\n        self.graph = g[0]\n        info = load_info(self.info_path)\n        self._labels = info[\"labels\"]\n        self._feats = info[\"feats\"]\n\n    @property\n    def num_labels(self):\n        return 121\n\n    @property\n    def num_classes(self):\n        return 121\n\n    def __len__(self):\n        \"\"\"Return number of samples in this dataset.\"\"\"\n        return len(self.graphs)\n\n    def __getitem__(self, item):\n        \"\"\"Get the item^th sample.\n\n        Parameters\n        ---------\n        item : int\n            The sample index.\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n            graph structure, node features and node labels.\n\n            - ``ndata['feat']``: node features\n            - ``ndata['label']``: node labels\n        \"\"\"\n        if 
self._transform is None:\n            return self.graphs[item]\n        else:\n            return self._transform(self.graphs[item])\n\n\nclass LegacyPPIDataset(PPIDataset):\n    \"\"\"Legacy version of PPI Dataset\"\"\"\n\n    def __getitem__(self, item):\n        \"\"\"Get the item^th sample.\n\n        Parameters\n        ----------\n        item : int\n            The sample index.\n\n        Returns\n        -------\n        (dgl.DGLGraph, Tensor, Tensor)\n            The graph, features and its label.\n        \"\"\"\n        if self._transform is None:\n            g = self.graphs[item]\n        else:\n            g = self._transform(self.graphs[item])\n        return g, g.ndata[\"feat\"], g.ndata[\"label\"]\n"
  },
  {
    "path": "python/dgl/data/qm7b.py",
    "content": "\"\"\"QM7b dataset for graph property prediction (regression).\"\"\"\nimport os\n\nfrom scipy import io\n\nfrom .. import backend as F\nfrom ..convert import graph as dgl_graph\n\nfrom .dgl_dataset import DGLDataset\nfrom .utils import check_sha1, download, load_graphs, save_graphs\n\n\nclass QM7bDataset(DGLDataset):\n    r\"\"\"QM7b dataset for graph property prediction (regression)\n\n    This dataset consists of 7,211 molecules with 14 regression targets.\n    Nodes means atoms and edges means bonds. Edge data 'h' means\n    the entry of Coulomb matrix.\n\n    Reference: `<http://quantum-machine.org/datasets/>`_\n\n    Statistics:\n\n    - Number of graphs: 7,211\n    - Number of regression targets: 14\n    - Average number of nodes: 15\n    - Average number of edges: 245\n    - Edge feature size: 1\n\n    Parameters\n    ----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_tasks : int\n        Number of prediction tasks\n    num_labels : int\n        (DEPRECATED, use num_tasks instead) Number of prediction tasks\n\n    Raises\n    ------\n    UserWarning\n        If the raw data is changed in the remote server by the author.\n\n    Examples\n    --------\n    >>> data = QM7bDataset()\n    >>> data.num_tasks\n    14\n    >>>\n    >>> # iterate over the dataset\n    >>> for g, label in data:\n    ...     edge_feat = g.edata['h']  # get edge feature\n    ...     
# your code here...\n    ...\n    >>>\n    \"\"\"\n\n    _url = (\n        \"http://deepchem.io.s3-website-us-west-1.amazonaws.com/\"\n        \"datasets/qm7b.mat\"\n    )\n    _sha1_str = \"4102c744bb9d6fd7b40ac67a300e49cd87e28392\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=False, transform=None\n    ):\n        super(QM7bDataset, self).__init__(\n            name=\"qm7b\",\n            url=self._url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        mat_path = os.path.join(self.raw_dir, self.name + \".mat\")\n        self.graphs, self.label = self._load_graph(mat_path)\n\n    def _load_graph(self, filename):\n        data = io.loadmat(filename)\n        labels = F.tensor(data[\"T\"], dtype=F.data_type_dict[\"float32\"])\n        feats = data[\"X\"]\n        num_graphs = labels.shape[0]\n        graphs = []\n        for i in range(num_graphs):\n            edge_list = feats[i].nonzero()\n            g = dgl_graph(edge_list)\n            g.edata[\"h\"] = F.tensor(\n                feats[i][edge_list[0], edge_list[1]].reshape(-1, 1),\n                dtype=F.data_type_dict[\"float32\"],\n            )\n            graphs.append(g)\n        return graphs, labels\n\n    def save(self):\n        \"\"\"save the graph list and the labels\"\"\"\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        save_graphs(str(graph_path), self.graphs, {\"labels\": self.label})\n\n    def has_cache(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        return os.path.exists(graph_path)\n\n    def load(self):\n        graphs, label_dict = load_graphs(\n            os.path.join(self.save_path, \"dgl_graph.bin\")\n        )\n        self.graphs = graphs\n        self.label = label_dict[\"labels\"]\n\n    def download(self):\n        file_path = os.path.join(self.raw_dir, 
self.name + \".mat\")\n        download(self.url, path=file_path)\n        if not check_sha1(file_path, self._sha1_str):\n            raise UserWarning(\n                \"File {} is downloaded but the content hash does not match.\"\n                \"The repo may be outdated or download may be incomplete. \"\n                \"Otherwise you can create an issue for it.\".format(self.name)\n            )\n\n    @property\n    def num_tasks(self):\n        \"\"\"Number of prediction tasks.\"\"\"\n        return self.num_labels\n\n    @property\n    def num_labels(self):\n        \"\"\"Number of prediction tasks.\"\"\"\n        return 14\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of prediction tasks.\"\"\"\n        return 14\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph and label by index\n\n        Parameters\n        ----------\n        idx : int\n            Item index\n\n        Returns\n        -------\n        (:class:`dgl.DGLGraph`, Tensor)\n        \"\"\"\n        if self._transform is None:\n            g = self.graphs[idx]\n        else:\n            g = self._transform(self.graphs[idx])\n        return g, self.label[idx]\n\n    def __len__(self):\n        r\"\"\"Number of graphs in the dataset.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return len(self.graphs)\n\n\nQM7b = QM7bDataset\n"
  },
  {
    "path": "python/dgl/data/qm9.py",
    "content": "\"\"\"QM9 dataset for graph property prediction (regression).\"\"\"\nimport os\n\nimport numpy as np\nimport scipy.sparse as sp\n\nfrom .. import backend as F\nfrom ..convert import graph as dgl_graph\nfrom ..transforms import to_bidirected\n\nfrom .dgl_dataset import DGLDataset\nfrom .utils import _get_dgl_url, download\n\n\nclass QM9Dataset(DGLDataset):\n    r\"\"\"QM9 dataset for graph property prediction (regression)\n\n    This dataset consists of 130,831 molecules with 12 regression targets.\n    Nodes correspond to atoms and edges correspond to close atom pairs.\n\n    This dataset differs from :class:`~dgl.data.QM9EdgeDataset` in the following aspects:\n        1. Edges in this dataset are purely distance-based.\n        2. It only provides atoms' coordinates and atomic numbers as node features\n        3. It only provides 12 regression targets.\n\n    Reference:\n\n    - `\"Quantum-Machine.org\" <http://quantum-machine.org/datasets/>`_,\n    - `\"Directional Message Passing for Molecular Graphs\" <https://arxiv.org/abs/2003.03123>`_\n\n    Statistics:\n\n    - Number of graphs: 130,831\n    - Number of regression targets: 12\n\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | Keys   | Property                         | Description                                                                       | Unit                                        |\n    +========+==================================+===================================================================================+=============================================+\n    | mu     | :math:`\\mu`                      | Dipole moment                                                                     | :math:`\\textrm{D}`                          |\n    
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | alpha  | :math:`\\alpha`                   | Isotropic polarizability                                                          | :math:`{a_0}^3`                             |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | homo   | :math:`\\epsilon_{\\textrm{HOMO}}` | Highest occupied molecular orbital energy                                         | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | lumo   | :math:`\\epsilon_{\\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy                                        | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | gap    | :math:`\\Delta \\epsilon`          | Gap between :math:`\\epsilon_{\\textrm{HOMO}}` and :math:`\\epsilon_{\\textrm{LUMO}}` | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | r2     | :math:`\\langle R^2 \\rangle`      | Electronic spatial extent                                                         | :math:`{a_0}^2`                             |\n    
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | zpve   | :math:`\\textrm{ZPVE}`            | Zero point vibrational energy                                                     | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | U0     | :math:`U_0`                      | Internal energy at 0K                                                             | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | U      | :math:`U`                        | Internal energy at 298.15K                                                        | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | H      | :math:`H`                        | Enthalpy at 298.15K                                                               | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | G      | :math:`G`                        | Free energy at 298.15K                                                            | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 
Cv     | :math:`c_{\\textrm{v}}`           | Heat capacity at 298.15K                                                          | :math:`\\frac{\\textrm{cal}}{\\textrm{mol K}}` |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n\n    Parameters\n    ----------\n    label_keys : list\n        Names of the regression property, which should be a subset of the keys in the table above.\n    cutoff : float\n        Cutoff distance for interatomic interactions, i.e. two atoms are connected in the corresponding graph if the distance between them is no larger than this.\n        Default: 5.0 Angstrom\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: False.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_tasks : int\n        Number of prediction tasks\n    num_labels : int\n        (DEPRECATED, use num_tasks instead) Number of prediction tasks\n\n    Raises\n    ------\n    UserWarning\n        If the raw data is changed in the remote server by the author.\n\n    Examples\n    --------\n    >>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)\n    >>> data.num_tasks\n    2\n    >>>\n    >>> # iterate over the dataset\n    >>> for g, label in data:\n    ...     R = g.ndata['R'] # get coordinates of each atom\n    ...     Z = g.ndata['Z'] # get atomic numbers of each atom\n    ...     
# your code here...\n    >>>\n    \"\"\"\n\n    def __init__(\n        self,\n        label_keys,\n        cutoff=5.0,\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        self.cutoff = cutoff\n        self.label_keys = label_keys\n        self._url = _get_dgl_url(\"dataset/qm9_eV.npz\")\n\n        super(QM9Dataset, self).__init__(\n            name=\"qm9\",\n            url=self._url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        npz_path = f\"{self.raw_dir}/qm9_eV.npz\"\n        data_dict = np.load(npz_path, allow_pickle=True)\n        # data_dict['N'] contains the number of atoms in each molecule.\n        # Atomic properties (Z and R) of all molecules are concatenated as single tensors,\n        # so you need this value to select the correct atoms for each molecule.\n        self.N = data_dict[\"N\"]\n        self.R = data_dict[\"R\"]\n        self.Z = data_dict[\"Z\"]\n        self.label = np.stack(\n            [data_dict[key] for key in self.label_keys], axis=1\n        )\n        self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])\n\n    def download(self):\n        file_path = f\"{self.raw_dir}/qm9_eV.npz\"\n        if not os.path.exists(file_path):\n            download(self._url, path=file_path)\n\n    @property\n    def num_labels(self):\n        r\"\"\"\n        Returns\n        --------\n        int\n            Number of prediction tasks.\n        \"\"\"\n        return self.label.shape[1]\n\n    @property\n    def num_classes(self):\n        r\"\"\"\n        Returns\n        --------\n        int\n            Number of prediction tasks.\n        \"\"\"\n        return self.label.shape[1]\n\n    @property\n    def num_tasks(self):\n        r\"\"\"\n        Returns\n        --------\n        int\n            Number of prediction tasks.\n        
\"\"\"\n        return self.label.shape[1]\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph and label by index\n\n        Parameters\n        ----------\n        idx : int\n            Item index\n\n        Returns\n        -------\n        dgl.DGLGraph\n            The graph contains:\n\n            - ``ndata['R']``: the coordinates of each atom\n            - ``ndata['Z']``: the atomic number\n\n        Tensor\n            Property values of molecular graphs\n        \"\"\"\n        label = F.tensor(self.label[idx], dtype=F.data_type_dict[\"float32\"])\n        n_atoms = self.N[idx]\n        R = self.R[self.N_cumsum[idx] : self.N_cumsum[idx + 1]]\n        dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)\n        adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(\n            n_atoms, dtype=np.bool_\n        )\n        adj = adj.tocoo()\n        u, v = F.tensor(adj.row), F.tensor(adj.col)\n        g = dgl_graph((u, v))\n        g = to_bidirected(g)\n        g.ndata[\"R\"] = F.tensor(R, dtype=F.data_type_dict[\"float32\"])\n        g.ndata[\"Z\"] = F.tensor(\n            self.Z[self.N_cumsum[idx] : self.N_cumsum[idx + 1]],\n            dtype=F.data_type_dict[\"int64\"],\n        )\n\n        if self._transform is not None:\n            g = self._transform(g)\n\n        return g, label\n\n    def __len__(self):\n        r\"\"\"Number of graphs in the dataset.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return self.label.shape[0]\n\n\nQM9 = QM9Dataset\n"
  },
  {
    "path": "python/dgl/data/qm9_edge.py",
    "content": "\"\"\" QM9 dataset for graph property prediction (regression) \"\"\"\n\nimport os\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..convert import graph as dgl_graph\n\nfrom .dgl_dataset import DGLDataset\nfrom .utils import _get_dgl_url, download, extract_archive\n\n\nclass QM9EdgeDataset(DGLDataset):\n    r\"\"\"QM9Edge dataset for graph property prediction (regression)\n\n    This dataset consists of 130,831 molecules with 19 regression targets.\n    Nodes correspond to atoms and edges correspond to bonds.\n\n    This dataset differs from :class:`~dgl.data.QM9Dataset` in the following aspects:\n        1. It includes the bonds in a molecule in the edges of the corresponding graph while the edges in :class:`~dgl.data.QM9Dataset` are purely distance-based.\n        2. It provides edge features, and node features in addition to the atoms' coordinates and atomic numbers.\n        3. It provides another 7 regression tasks(from 12 to 19).\n\n    This class is built based on a preprocessed version of the dataset, and we provide the preprocessing datails `here <https://gist.github.com/hengruizhang98/a2da30213b2356fff18b25385c9d3cd2>`_.\n\n    Reference:\n\n    - `\"MoleculeNet: A Benchmark for Molecular Machine Learning\" <https://arxiv.org/abs/1703.00564>`_\n    - `\"Neural Message Passing for Quantum Chemistry\" <https://arxiv.org/abs/1704.01212>`_\n\n    For\n    Statistics:\n\n    - Number of graphs: 130,831.\n    - Number of regression targets: 19.\n\n    Node attributes:\n\n    - pos: the 3D coordinates of each atom.\n    - attr: the 11D atom features.\n\n    Edge attributes:\n\n    - edge_attr: the 4D bond features.\n\n    Regression targets:\n\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | Keys   | Property                         | Description                                                       
                | Unit                                        |\n    +========+==================================+===================================================================================+=============================================+\n    | mu     | :math:`\\mu`                      | Dipole moment                                                                     | :math:`\\textrm{D}`                          |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | alpha  | :math:`\\alpha`                   | Isotropic polarizability                                                          | :math:`{a_0}^3`                             |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | homo   | :math:`\\epsilon_{\\textrm{HOMO}}` | Highest occupied molecular orbital energy                                         | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | lumo   | :math:`\\epsilon_{\\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy                                        | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | gap    | :math:`\\Delta \\epsilon`          | Gap between :math:`\\epsilon_{\\textrm{HOMO}}` and :math:`\\epsilon_{\\textrm{LUMO}}` | :math:`\\textrm{eV}`                         |\n    
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | r2     | :math:`\\langle R^2 \\rangle`      | Electronic spatial extent                                                         | :math:`{a_0}^2`                             |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | zpve   | :math:`\\textrm{ZPVE}`            | Zero point vibrational energy                                                     | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | U0     | :math:`U_0`                      | Internal energy at 0K                                                             | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | U      | :math:`U`                        | Internal energy at 298.15K                                                        | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | H      | :math:`H`                        | Enthalpy at 298.15K                                                               | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 
G      | :math:`G`                        | Free energy at 298.15K                                                            | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | Cv     | :math:`c_{\\textrm{v}}`           | Heat capacity at 298.15K                                                          | :math:`\\frac{\\textrm{cal}}{\\textrm{mol K}}` |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | U0_atom| :math:`U_0^{\\textrm{ATOM}}`      | Atomization energy at 0K                                                          | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | U_atom | :math:`U^{\\textrm{ATOM}}`        | Atomization energy at 298.15K                                                     | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | H_atom | :math:`H^{\\textrm{ATOM}}`        | Atomization enthalpy at 298.15K                                                   | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | G_atom | :math:`G^{\\textrm{ATOM}}`        | Atomization free energy at 298.15K                                                | :math:`\\textrm{eV}`                         |\n  
  +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | A      | :math:`A`                        | Rotational constant                                                               | :math:`\\textrm{GHz}`                        |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | B      | :math:`B`                        | Rotational constant                                                               | :math:`\\textrm{GHz}`                        |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | C      | :math:`C`                        | Rotational constant                                                               | :math:`\\textrm{GHz}`                        |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n\n    Parameters\n    ----------\n    label_keys : list\n        Names of the regression property, which should be a subset of the keys in the table above.\n        If not provided, it will load all the labels.\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False.\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_tasks : int\n        Number of prediction tasks\n    num_labels : int\n        (DEPRECATED, use num_tasks instead) Number of prediction tasks\n\n    Raises\n    ------\n    UserWarning\n        If the raw data is changed in the remote server by the author.\n\n    Examples\n    --------\n    >>> data = QM9EdgeDataset(label_keys=['mu', 'alpha'])\n    >>> data.num_tasks\n    2\n\n    >>> # iterate over the dataset\n    >>> for graph, labels in data:\n    ...     print(graph) # get information of each graph\n    ...     print(labels) # get labels of the corresponding graph\n    ...     # your code here...\n    >>>\n    \"\"\"\n\n    keys = [\n        \"mu\",\n        \"alpha\",\n        \"homo\",\n        \"lumo\",\n        \"gap\",\n        \"r2\",\n        \"zpve\",\n        \"U0\",\n        \"U\",\n        \"H\",\n        \"G\",\n        \"Cv\",\n        \"U0_atom\",\n        \"U_atom\",\n        \"H_atom\",\n        \"G_atom\",\n        \"A\",\n        \"B\",\n        \"C\",\n    ]\n    map_dict = {}\n\n    for i, key in enumerate(keys):\n        map_dict[key] = i\n\n    def __init__(\n        self,\n        label_keys=None,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        if label_keys is None:\n            self.label_keys = None\n            self.num_labels = 19\n        else:\n            self.label_keys = [self.map_dict[i] for i in label_keys]\n            self.num_labels = len(label_keys)\n\n        self._url = _get_dgl_url(\"dataset/qm9_edge.npz\")\n\n        super(QM9EdgeDataset, self).__init__(\n            name=\"qm9Edge\",\n            raw_dir=raw_dir,\n            url=self._url,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def download(self):\n        if not 
os.path.exists(self.npz_path):\n            download(self._url, path=self.npz_path)\n\n    def process(self):\n        self.load()\n\n    @property\n    def npz_path(self):\n        return f\"{self.raw_dir}/qm9_edge.npz\"\n\n    def has_cache(self):\n        return os.path.exists(self.npz_path)\n\n    def save(self):\n        np.savez_compressed(\n            self.npz_path,\n            n_node=self.n_node,\n            n_edge=self.n_edge,\n            node_attr=self.node_attr,\n            node_pos=self.node_pos,\n            edge_attr=self.edge_attr,\n            src=self.src,\n            dst=self.dst,\n            targets=self.targets,\n        )\n\n    def load(self):\n        data_dict = np.load(self.npz_path, allow_pickle=True)\n\n        self.n_node = data_dict[\"n_node\"]\n        self.n_edge = data_dict[\"n_edge\"]\n        self.node_attr = data_dict[\"node_attr\"]\n        self.node_pos = data_dict[\"node_pos\"]\n        self.edge_attr = data_dict[\"edge_attr\"]\n        self.targets = data_dict[\"targets\"]\n\n        self.src = data_dict[\"src\"]\n        self.dst = data_dict[\"dst\"]\n\n        self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)])\n        self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)])\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph and label by index\n\n        Parameters\n        ----------\n        idx : int\n            Item index\n\n        Returns\n        -------\n        dgl.DGLGraph\n           The graph contains:\n\n           - ``ndata['pos']``: the coordinates of each atom\n           - ``ndata['attr']``: the features of each atom\n           - ``edata['edge_attr']``: the features of each bond\n\n        Tensor\n            Property values of molecular graphs\n        \"\"\"\n\n        pos = self.node_pos[self.n_cumsum[idx] : self.n_cumsum[idx + 1]]\n        src = self.src[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]\n        dst = self.dst[self.ne_cumsum[idx] : self.ne_cumsum[idx + 
1]]\n\n        g = dgl_graph((src, dst))\n\n        g.ndata[\"pos\"] = F.tensor(pos, dtype=F.data_type_dict[\"float32\"])\n        g.ndata[\"attr\"] = F.tensor(\n            self.node_attr[self.n_cumsum[idx] : self.n_cumsum[idx + 1]],\n            dtype=F.data_type_dict[\"float32\"],\n        )\n        g.edata[\"edge_attr\"] = F.tensor(\n            self.edge_attr[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]],\n            dtype=F.data_type_dict[\"float32\"],\n        )\n\n        label = F.tensor(\n            self.targets[idx][self.label_keys],\n            dtype=F.data_type_dict[\"float32\"],\n        )\n\n        if self._transform is not None:\n            g = self._transform(g)\n\n        return g, label\n\n    def __len__(self):\n        r\"\"\"Number of graphs in the dataset.\n\n        Returns\n        -------\n        int\n        \"\"\"\n        return self.n_node.shape[0]\n\n    @property\n    def num_tasks(self):\n        r\"\"\"\n        Returns\n        -------\n        int\n            Number of prediction tasks\n        \"\"\"\n        return self.num_labels\n\n\nQM9Edge = QM9EdgeDataset\n"
  },
  {
    "path": "python/dgl/data/rdf.py",
    "content": "\"\"\"RDF datasets\nDatasets from \"A Collection of Benchmark Datasets for\nSystematic Evaluations of Machine Learning on\nthe Semantic Web\"\n\"\"\"\nimport abc\nimport itertools\nimport os\nimport re\nfrom collections import OrderedDict\n\nimport networkx as nx\nimport numpy as np\n\nimport dgl\nimport dgl.backend as F\n\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import (\n    _get_dgl_url,\n    generate_mask_tensor,\n    idx2mask,\n    load_graphs,\n    load_info,\n    save_graphs,\n    save_info,\n)\n\n__all__ = [\"AIFBDataset\", \"MUTAGDataset\", \"BGSDataset\", \"AMDataset\"]\n\n# Dictionary for renaming reserved node/edge type names to the ones\n# that are allowed by nn.Module.\nRENAME_DICT = {\n    \"type\": \"rdftype\",\n    \"rev-type\": \"rev-rdftype\",\n}\n\n\nclass Entity:\n    \"\"\"Class for entities\n    Parameters\n    ----------\n    id : str\n        ID of this entity\n    cls : str\n        Type of this entity\n    \"\"\"\n\n    def __init__(self, e_id, cls):\n        self.id = e_id\n        self.cls = cls\n\n    def __str__(self):\n        return \"{}/{}\".format(self.cls, self.id)\n\n\nclass Relation:\n    \"\"\"Class for relations\n    Parameters\n    ----------\n    cls : str\n        Type of this relation\n    \"\"\"\n\n    def __init__(self, cls):\n        self.cls = cls\n\n    def __str__(self):\n        return str(self.cls)\n\n\nclass RDFGraphDataset(DGLBuiltinDataset):\n    \"\"\"Base graph dataset class from RDF tuples.\n\n    To derive from this, implement the following abstract methods:\n    * ``parse_entity``\n    * ``parse_relation``\n    * ``process_tuple``\n    * ``process_idx_file_line``\n    * ``predict_category``\n    Preprocessed graph and other data will be cached in the download folder\n    to speedup data loading.\n    The dataset should contain a \"trainingSet.tsv\" and a \"testSet.tsv\" file\n    for training and testing samples.\n\n    Attributes\n    ----------\n    num_classes : int\n     
   Number of classes to predict\n    predict_category : str\n        The entity category (node type) that has labels for prediction\n\n    Parameters\n    ----------\n    name : str\n        Name of the dataset\n    url : str or path\n        URL to download the raw dataset.\n    predict_category : str\n        Predict category.\n    print_every : int, optional\n        Preprocessing log for every X tuples.\n    insert_reverse : bool, optional\n        If true, add reverse edge and reverse relations to the final graph.\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool, optional\n        If true, force load and process from raw data. Ignore cached pre-processed data.\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    \"\"\"\n\n    def __init__(\n        self,\n        name,\n        url,\n        predict_category,\n        print_every=10000,\n        insert_reverse=True,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        self._insert_reverse = insert_reverse\n        self._print_every = print_every\n        self._predict_category = predict_category\n\n        super(RDFGraphDataset, self).__init__(\n            name,\n            url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        raw_tuples = self.load_raw_tuples(self.raw_path)\n        self.process_raw_tuples(raw_tuples, self.raw_path)\n\n    def load_raw_tuples(self, root_path):\n        \"\"\"Loading raw RDF dataset\n\n        Parameters\n        ----------\n        root_path : str\n            Root path containing the data\n\n        Returns\n        -------\n            Loaded rdf data\n        \"\"\"\n        import rdflib as rdf\n\n        raw_rdf_graphs = []\n        for _, filename in enumerate(os.listdir(root_path)):\n            fmt = None\n            if filename.endswith(\"nt\"):\n                fmt = \"nt\"\n            elif filename.endswith(\"n3\"):\n                fmt = \"n3\"\n            if fmt is None:\n                continue\n            g = rdf.Graph()\n            print(\"Parsing file %s ...\" % filename)\n            g.parse(os.path.join(root_path, filename), format=fmt)\n            raw_rdf_graphs.append(g)\n        return itertools.chain(*raw_rdf_graphs)\n\n    def process_raw_tuples(self, raw_tuples, root_path):\n        \"\"\"Processing raw RDF dataset\n\n        Parameters\n        ----------\n        raw_tuples:\n            Raw rdf tuples\n        root_path: str\n            Root path containing the data\n      
  \"\"\"\n        mg = nx.MultiDiGraph()\n        ent_classes = OrderedDict()\n        rel_classes = OrderedDict()\n        entities = OrderedDict()\n        src = []\n        dst = []\n        ntid = []\n        etid = []\n        sorted_tuples = []\n        for t in raw_tuples:\n            sorted_tuples.append(t)\n        sorted_tuples.sort()\n\n        for i, (sbj, pred, obj) in enumerate(sorted_tuples):\n            if self.verbose and i % self._print_every == 0:\n                print(\n                    \"Processed %d tuples, found %d valid tuples.\"\n                    % (i, len(src))\n                )\n            sbjent = self.parse_entity(sbj)\n            rel = self.parse_relation(pred)\n            objent = self.parse_entity(obj)\n            processed = self.process_tuple(\n                (sbj, pred, obj), sbjent, rel, objent\n            )\n            if processed is None:\n                # ignored\n                continue\n            # meta graph\n            sbjclsid = _get_id(ent_classes, sbjent.cls)\n            objclsid = _get_id(ent_classes, objent.cls)\n            relclsid = _get_id(rel_classes, rel.cls)\n            mg.add_edge(sbjent.cls, objent.cls, key=rel.cls)\n            if self._insert_reverse:\n                mg.add_edge(objent.cls, sbjent.cls, key=\"rev-%s\" % rel.cls)\n            # instance graph\n            src_id = _get_id(entities, str(sbjent))\n            if len(entities) > len(ntid):  # found new entity\n                ntid.append(sbjclsid)\n            dst_id = _get_id(entities, str(objent))\n            if len(entities) > len(ntid):  # found new entity\n                ntid.append(objclsid)\n            src.append(src_id)\n            dst.append(dst_id)\n            etid.append(relclsid)\n\n        src = np.asarray(src)\n        dst = np.asarray(dst)\n        ntid = np.asarray(ntid)\n        etid = np.asarray(etid)\n        ntypes = list(ent_classes.keys())\n        etypes = list(rel_classes.keys())\n\n        
# add reverse edge with reverse relation\n        if self._insert_reverse:\n            if self.verbose:\n                print(\"Adding reverse edges ...\")\n            newsrc = np.hstack([src, dst])\n            newdst = np.hstack([dst, src])\n            src = newsrc\n            dst = newdst\n            etid = np.hstack([etid, etid + len(etypes)])\n            etypes.extend([\"rev-%s\" % t for t in etypes])\n\n        hg = self.build_graph(mg, src, dst, ntid, etid, ntypes, etypes)\n\n        if self.verbose:\n            print(\"Load training/validation/testing split ...\")\n        idmap = F.asnumpy(hg.nodes[self.predict_category].data[dgl.NID])\n        glb2lcl = {glbid: lclid for lclid, glbid in enumerate(idmap)}\n\n        def findidfn(ent):\n            if ent not in entities:\n                return None\n            else:\n                return glb2lcl[entities[ent]]\n\n        self._hg = hg\n        train_idx, test_idx, labels, num_classes = self.load_data_split(\n            findidfn, root_path\n        )\n\n        train_mask = idx2mask(\n            train_idx, self._hg.num_nodes(self.predict_category)\n        )\n        test_mask = idx2mask(\n            test_idx, self._hg.num_nodes(self.predict_category)\n        )\n        labels = F.tensor(labels, F.data_type_dict[\"int64\"])\n\n        train_mask = generate_mask_tensor(train_mask)\n        test_mask = generate_mask_tensor(test_mask)\n        self._hg.nodes[self.predict_category].data[\"train_mask\"] = train_mask\n        self._hg.nodes[self.predict_category].data[\"test_mask\"] = test_mask\n        # TODO(minjie): Deprecate 'labels', use 'label' for consistency.\n        self._hg.nodes[self.predict_category].data[\"labels\"] = labels\n        self._hg.nodes[self.predict_category].data[\"label\"] = labels\n        self._num_classes = num_classes\n\n    def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):\n        \"\"\"Build the graphs\n\n        Parameters\n        ----------\n    
    mg: MultiDiGraph\n            Input graph\n        src: Numpy array\n            Source nodes\n        dst: Numpy array\n            Destination nodes\n        ntid: Numpy array\n            Node types for each node\n        etid: Numpy array\n            Edge types for each edge\n        ntypes: list\n            Node types\n        etypes: list\n            Edge types\n\n        Returns\n        -------\n        g: DGLGraph\n        \"\"\"\n        # create homo graph\n        if self.verbose:\n            print(\"Creating one whole graph ...\")\n        g = dgl.graph((src, dst))\n        g.ndata[dgl.NTYPE] = F.tensor(ntid)\n        g.edata[dgl.ETYPE] = F.tensor(etid)\n        if self.verbose:\n            print(\"Total #nodes:\", g.num_nodes())\n            print(\"Total #edges:\", g.num_edges())\n\n        # rename names such as 'type' so that they can be used as keys\n        # to nn.ModuleDict\n        etypes = [RENAME_DICT.get(ty, ty) for ty in etypes]\n        mg_edges = mg.edges(keys=True)\n        mg = nx.MultiDiGraph()\n        for sty, dty, ety in mg_edges:\n            mg.add_edge(sty, dty, key=RENAME_DICT.get(ety, ety))\n\n        # convert to heterograph\n        if self.verbose:\n            print(\"Convert to heterograph ...\")\n        hg = dgl.to_heterogeneous(g, ntypes, etypes, metagraph=mg)\n        if self.verbose:\n            print(\"#Node types:\", len(hg.ntypes))\n            print(\"#Canonical edge types:\", len(hg.etypes))\n            print(\"#Unique edge type names:\", len(set(hg.etypes)))\n        return hg\n\n    def load_data_split(self, ent2id, root_path):\n        \"\"\"Load data split\n\n        Parameters\n        ----------\n        ent2id: func\n            A function mapping entity to id\n        root_path: str\n            Root path containing the data\n\n        Return\n        ------\n        train_idx: Numpy array\n            Training set\n        test_idx: Numpy array\n            Testing set\n        labels: Numpy 
array\n            Labels\n        num_classes: int\n            Number of classes\n        \"\"\"\n        label_dict = {}\n        labels = np.zeros((self._hg.num_nodes(self.predict_category),)) - 1\n        train_idx = self.parse_idx_file(\n            os.path.join(root_path, \"trainingSet.tsv\"),\n            ent2id,\n            label_dict,\n            labels,\n        )\n        test_idx = self.parse_idx_file(\n            os.path.join(root_path, \"testSet.tsv\"), ent2id, label_dict, labels\n        )\n        train_idx = np.array(train_idx)\n        test_idx = np.array(test_idx)\n        labels = np.array(labels)\n        num_classes = len(label_dict)\n        return train_idx, test_idx, labels, num_classes\n\n    def parse_idx_file(self, filename, ent2id, label_dict, labels):\n        \"\"\"Parse idx files\n\n        Parameters\n        ----------\n        filename: str\n            File to parse\n        ent2id: func\n            A function mapping entity to id\n        label_dict: dict\n            Map label to label id\n        labels: Numpy array\n            Map entity id to label id\n\n        Return\n        ------\n        idx: list\n            Entity ids\n        \"\"\"\n        idx = []\n        with open(filename, \"r\") as f:\n            for i, line in enumerate(f):\n                if i == 0:\n                    continue  # first line is the header\n                sample, label = self.process_idx_file_line(line)\n                # person, _, label = line.strip().split('\\t')\n                ent = self.parse_entity(sample)\n                entid = ent2id(str(ent))\n                if entid is None:\n                    print(\n                        'Warning: entity \"%s\" does not have any valid links associated. 
Ignored.'\n                        % str(ent)\n                    )\n                else:\n                    idx.append(entid)\n                    lblid = _get_id(label_dict, label)\n                    labels[entid] = lblid\n        return idx\n\n    def has_cache(self):\n        \"\"\"check if there is a processed data\"\"\"\n        graph_path = os.path.join(self.save_path, self.save_name + \".bin\")\n        info_path = os.path.join(self.save_path, self.save_name + \".pkl\")\n        if os.path.exists(graph_path) and os.path.exists(info_path):\n            return True\n\n        return False\n\n    def save(self):\n        \"\"\"save the graph list and the labels\"\"\"\n        graph_path = os.path.join(self.save_path, self.save_name + \".bin\")\n        info_path = os.path.join(self.save_path, self.save_name + \".pkl\")\n        save_graphs(str(graph_path), self._hg)\n        save_info(\n            str(info_path),\n            {\n                \"num_classes\": self.num_classes,\n                \"predict_category\": self.predict_category,\n            },\n        )\n\n    def load(self):\n        \"\"\"load the graph list and the labels from disk\"\"\"\n        graph_path = os.path.join(self.save_path, self.save_name + \".bin\")\n        info_path = os.path.join(self.save_path, self.save_name + \".pkl\")\n        graphs, _ = load_graphs(str(graph_path))\n\n        info = load_info(str(info_path))\n        self._num_classes = info[\"num_classes\"]\n        self._predict_category = info[\"predict_category\"]\n        self._hg = graphs[0]\n        # For backward compatibility\n        if \"label\" not in self._hg.nodes[self.predict_category].data:\n            self._hg.nodes[self.predict_category].data[\n                \"label\"\n            ] = self._hg.nodes[self.predict_category].data[\"labels\"]\n\n    def __getitem__(self, idx):\n        r\"\"\"Gets the graph object\"\"\"\n        g = self._hg\n        if self._transform is not None:\n            g 
= self._transform(g)\n        return g\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\"\"\"\n        return 1\n\n    @property\n    def save_name(self):\n        return self.name + \"_dgl_graph\"\n\n    @property\n    def predict_category(self):\n        return self._predict_category\n\n    @property\n    def num_classes(self):\n        return self._num_classes\n\n    @abc.abstractmethod\n    def parse_entity(self, term):\n        \"\"\"Parse one entity from an RDF term.\n        Return None if the term does not represent a valid entity and the\n        whole tuple should be ignored.\n        Parameters\n        ----------\n        term : rdflib.term.Identifier\n            RDF term\n        Returns\n        -------\n        Entity or None\n            An entity.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def parse_relation(self, term):\n        \"\"\"Parse one relation from an RDF term.\n        Return None if the term does not represent a valid relation and the\n        whole tuple should be ignored.\n        Parameters\n        ----------\n        term : rdflib.term.Identifier\n            RDF term\n        Returns\n        -------\n        Relation or None\n            A relation\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def process_tuple(self, raw_tuple, sbj, rel, obj):\n        \"\"\"Process the tuple.\n        Return (Entity, Relation, Entity) tuple as the final tuple.\n        Return None if the tuple should be ignored.\n\n        Parameters\n        ----------\n        raw_tuple : tuple of rdflib.term.Identifier\n            (subject, predicate, object) tuple\n        sbj : Entity\n            Subject entity\n        rel : Relation\n            Relation\n        obj : Entity\n            Object entity\n        Returns\n        -------\n        (Entity, Relation, Entity)\n            The final tuple or None if should be ignored\n        \"\"\"\n        pass\n\n    
@abc.abstractmethod\n    def process_idx_file_line(self, line):\n        \"\"\"Process one line of ``trainingSet.tsv`` or ``testSet.tsv``.\n        Parameters\n        ----------\n        line : str\n            One line of the file\n        Returns\n        -------\n        (str, str)\n            One sample and its label\n        \"\"\"\n        pass\n\n\ndef _get_id(dict, key):\n    id = dict.get(key, None)\n    if id is None:\n        id = len(dict)\n        dict[key] = id\n    return id\n\n\nclass AIFBDataset(RDFGraphDataset):\n    r\"\"\"AIFB dataset for node classification task\n\n    AIFB DataSet is a Semantic Web (RDF) dataset used as a benchmark in\n    data mining.  It records the organizational structure of AIFB at the\n    University of Karlsruhe.\n\n    AIFB dataset statistics:\n\n    - Nodes: 7262\n    - Edges: 48810 (including reverse edges)\n    - Target Category: Personen\n    - Number of Classes: 4\n    - Label Split:\n\n        - Train: 140\n        - Test: 36\n\n    Parameters\n    -----------\n    print_every : int\n        Preprocessing log for every X tuples. Default: 10000.\n    insert_reverse : bool\n        If true, add reverse edge and reverse relations to the final graph. Default: True.\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes to predict\n    predict_category : str\n        The entity category (node type) that has labels for prediction\n\n    Examples\n    --------\n    >>> dataset = dgl.data.rdf.AIFBDataset()\n    >>> g = dataset[0]\n    >>> category = dataset.predict_category\n    >>> num_classes = dataset.num_classes\n    >>>\n    >>> train_mask = g.nodes[category].data['train_mask']\n    >>> test_mask = g.nodes[category].data['test_mask']\n    >>> label = g.nodes[category].data['label']\n    \"\"\"\n\n    entity_prefix = \"http://www.aifb.uni-karlsruhe.de/\"\n    relation_prefix = \"http://swrc.ontoware.org/\"\n\n    def __init__(\n        self,\n        print_every=10000,\n        insert_reverse=True,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        import rdflib as rdf\n\n        self.employs = rdf.term.URIRef(\n            \"http://swrc.ontoware.org/ontology#employs\"\n        )\n        self.affiliation = rdf.term.URIRef(\n            \"http://swrc.ontoware.org/ontology#affiliation\"\n        )\n        url = _get_dgl_url(\"dataset/rdf/aifb-hetero.zip\")\n        name = \"aifb-hetero\"\n        predict_category = \"Personen\"\n        super(AIFBDataset, self).__init__(\n            name,\n            url,\n            predict_category,\n            print_every=print_every,\n            insert_reverse=insert_reverse,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Gets the graph object\n\n        Parameters\n        -----------\n        idx: int\n            Item index, AIFBDataset has only one graph object\n\n        Return\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph 
contains:\n\n            - ``ndata['train_mask']``: mask for training node set\n            - ``ndata['test_mask']``: mask for testing node set\n            - ``ndata['label']``: node labels\n        \"\"\"\n        return super(AIFBDataset, self).__getitem__(idx)\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return super(AIFBDataset, self).__len__()\n\n    def parse_entity(self, term):\n        import rdflib as rdf\n\n        if isinstance(term, rdf.Literal):\n            return Entity(e_id=str(term), cls=\"_Literal\")\n        if isinstance(term, rdf.BNode):\n            return None\n        entstr = str(term)\n        if entstr.startswith(self.entity_prefix):\n            sp = entstr.split(\"/\")\n            return Entity(e_id=sp[5], cls=sp[3])\n        else:\n            return None\n\n    def parse_relation(self, term):\n        if term == self.employs or term == self.affiliation:\n            return None\n        relstr = str(term)\n        if relstr.startswith(self.relation_prefix):\n            return Relation(cls=relstr.split(\"/\")[3])\n        else:\n            relstr = relstr.split(\"/\")[-1]\n            return Relation(cls=relstr)\n\n    def process_tuple(self, raw_tuple, sbj, rel, obj):\n        if sbj is None or rel is None or obj is None:\n            return None\n        return (sbj, rel, obj)\n\n    def process_idx_file_line(self, line):\n        person, _, label = line.strip().split(\"\\t\")\n        return person, label\n\n\nclass MUTAGDataset(RDFGraphDataset):\n    r\"\"\"MUTAG dataset for node classification task\n\n    Mutag dataset statistics:\n\n    - Nodes: 27163\n    - Edges: 148100 (including reverse edges)\n    - Target Category: d\n    - Number of Classes: 2\n    - Label Split:\n\n        - Train: 272\n        - Test: 68\n\n    Parameters\n    -----------\n    print_every : int\n        Preprocessing log for every X tuples. 
Default: 10000.\n    insert_reverse : bool\n        If true, add reverse edge and reverse relations to the final graph. Default: True.\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes to predict\n    predict_category : str\n        The entity category (node type) that has labels for prediction\n    graph : :class:`dgl.DGLGraph`\n        Graph structure\n\n    Examples\n    --------\n    >>> dataset = dgl.data.rdf.MUTAGDataset()\n    >>> g = dataset[0]\n    >>> category = dataset.predict_category\n    >>> num_classes = dataset.num_classes\n    >>>\n    >>> train_mask = g.nodes[category].data['train_mask']\n    >>> test_mask = g.nodes[category].data['test_mask']\n    >>> label = g.nodes[category].data['label']\n    \"\"\"\n\n    d_entity = re.compile(\"d[0-9]\")\n    bond_entity = re.compile(\"bond[0-9]\")\n\n    entity_prefix = \"http://dl-learner.org/carcinogenesis#\"\n    relation_prefix = entity_prefix\n\n    def __init__(\n        self,\n        print_every=10000,\n        insert_reverse=True,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        import rdflib as rdf\n\n        self.is_mutagenic = rdf.term.URIRef(\n            \"http://dl-learner.org/carcinogenesis#isMutagenic\"\n        )\n        self.rdf_type = rdf.term.URIRef(\n            \"http://www.w3.org/1999/02/22-rdf-syntax-ns#type\"\n        )\n        self.rdf_subclassof = rdf.term.URIRef(\n 
           \"http://www.w3.org/2000/01/rdf-schema#subClassOf\"\n        )\n        self.rdf_domain = rdf.term.URIRef(\n            \"http://www.w3.org/2000/01/rdf-schema#domain\"\n        )\n\n        url = _get_dgl_url(\"dataset/rdf/mutag-hetero.zip\")\n        name = \"mutag-hetero\"\n        predict_category = \"d\"\n        super(MUTAGDataset, self).__init__(\n            name,\n            url,\n            predict_category,\n            print_every=print_every,\n            insert_reverse=insert_reverse,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Gets the graph object\n\n        Parameters\n        -----------\n        idx: int\n            Item index, MUTAGDataset has only one graph object\n\n        Return\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains:\n\n            - ``ndata['train_mask']``: mask for training node set\n            - ``ndata['test_mask']``: mask for testing node set\n            - ``ndata['label']``: node labels\n        \"\"\"\n        return super(MUTAGDataset, self).__getitem__(idx)\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return super(MUTAGDataset, self).__len__()\n\n    def parse_entity(self, term):\n        import rdflib as rdf\n\n        if isinstance(term, rdf.Literal):\n            return Entity(e_id=str(term), cls=\"_Literal\")\n        elif isinstance(term, rdf.BNode):\n            return None\n        entstr = str(term)\n        if entstr.startswith(self.entity_prefix):\n            inst = entstr[len(self.entity_prefix) :]\n            if self.d_entity.match(inst):\n                cls = \"d\"\n            elif self.bond_entity.match(inst):\n                cls = \"bond\"\n            else:\n                cls = None\n            
return Entity(e_id=inst, cls=cls)\n        else:\n            return None\n\n    def parse_relation(self, term):\n        if term == self.is_mutagenic:\n            return None\n        relstr = str(term)\n        if relstr.startswith(self.relation_prefix):\n            cls = relstr[len(self.relation_prefix) :]\n            return Relation(cls=cls)\n        else:\n            relstr = relstr.split(\"/\")[-1]\n            return Relation(cls=relstr)\n\n    def process_tuple(self, raw_tuple, sbj, rel, obj):\n        if sbj is None or rel is None or obj is None:\n            return None\n\n        if not raw_tuple[1].startswith(\"http://dl-learner.org/carcinogenesis#\"):\n            obj.cls = \"SCHEMA\"\n            if sbj.cls is None:\n                sbj.cls = \"SCHEMA\"\n        if obj.cls is None:\n            obj.cls = rel.cls\n\n        assert sbj.cls is not None and obj.cls is not None\n\n        return (sbj, rel, obj)\n\n    def process_idx_file_line(self, line):\n        bond, _, label = line.strip().split(\"\\t\")\n        return bond, label\n\n\nclass BGSDataset(RDFGraphDataset):\n    r\"\"\"BGS dataset for node classification task\n\n    BGS namespace convention:\n    ``http://data.bgs.ac.uk/(ref|id)/<Major Concept>/<Sub Concept>/INSTANCE``.\n    We ignored all literal nodes and the relations connecting them in the\n    output graph. We also ignored the relation used to mark whether a\n    term is CURRENT or DEPRECATED.\n\n    BGS dataset statistics:\n\n    - Nodes: 94806\n    - Edges: 672884 (including reverse edges)\n    - Target Category: Lexicon/NamedRockUnit\n    - Number of Classes: 2\n    - Label Split:\n\n        - Train: 117\n        - Test: 29\n\n    Parameters\n    -----------\n    print_every : int\n        Preprocessing log for every X tuples. Default: 10000.\n    insert_reverse : bool\n        If true, add reverse edge and reverse relations to the final graph. 
Default: True.\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes to predict\n    predict_category : str\n        The entity category (node type) that has labels for prediction\n\n    Examples\n    --------\n    >>> dataset = dgl.data.rdf.BGSDataset()\n    >>> g = dataset[0]\n    >>> category = dataset.predict_category\n    >>> num_classes = dataset.num_classes\n    >>>\n    >>> train_mask = g.nodes[category].data['train_mask']\n    >>> test_mask = g.nodes[category].data['test_mask']\n    >>> label = g.nodes[category].data['label']\n    \"\"\"\n\n    entity_prefix = \"http://data.bgs.ac.uk/\"\n    status_prefix = \"http://data.bgs.ac.uk/ref/CurrentStatus\"\n    relation_prefix = \"http://data.bgs.ac.uk/ref\"\n\n    def __init__(\n        self,\n        print_every=10000,\n        insert_reverse=True,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        import rdflib as rdf\n\n        url = _get_dgl_url(\"dataset/rdf/bgs-hetero.zip\")\n        name = \"bgs-hetero\"\n        predict_category = \"Lexicon/NamedRockUnit\"\n        self.lith = rdf.term.URIRef(\n            \"http://data.bgs.ac.uk/ref/Lexicon/hasLithogenesis\"\n        )\n        super(BGSDataset, self).__init__(\n            name,\n            url,\n            predict_category,\n            print_every=print_every,\n            insert_reverse=insert_reverse,\n            raw_dir=raw_dir,\n            
force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Gets the graph object\n\n        Parameters\n        -----------\n        idx: int\n            Item index, BGSDataset has only one graph object\n\n        Return\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains:\n\n            - ``ndata['train_mask']``: mask for training node set\n            - ``ndata['test_mask']``: mask for testing node set\n            - ``ndata['label']``: node labels\n        \"\"\"\n        return super(BGSDataset, self).__getitem__(idx)\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return super(BGSDataset, self).__len__()\n\n    def parse_entity(self, term):\n        import rdflib as rdf\n\n        if isinstance(term, rdf.Literal):\n            return None\n        elif isinstance(term, rdf.BNode):\n            return None\n        entstr = str(term)\n        if entstr.startswith(self.status_prefix):\n            return None\n        if entstr.startswith(self.entity_prefix):\n            sp = entstr.split(\"/\")\n            if len(sp) != 7:\n                return None\n            # instance\n            cls = \"%s/%s\" % (sp[4], sp[5])\n            inst = sp[6]\n            return Entity(e_id=inst, cls=cls)\n        else:\n            return None\n\n    def parse_relation(self, term):\n        if term == self.lith:\n            return None\n        relstr = str(term)\n        if relstr.startswith(self.relation_prefix):\n            sp = relstr.split(\"/\")\n            if len(sp) < 6:\n                return None\n            assert len(sp) == 6, relstr\n            cls = \"%s/%s\" % (sp[4], sp[5])\n            return Relation(cls=cls)\n        else:\n            relstr = relstr.replace(\".\", \"_\")\n            return Relation(cls=relstr)\n\n    
def process_tuple(self, raw_tuple, sbj, rel, obj):\n        if sbj is None or rel is None or obj is None:\n            return None\n        return (sbj, rel, obj)\n\n    def process_idx_file_line(self, line):\n        _, rock, label = line.strip().split(\"\\t\")\n        return rock, label\n\n\nclass AMDataset(RDFGraphDataset):\n    \"\"\"AM dataset. for node classification task\n\n    Namespace convention:\n\n    - Instance: ``http://purl.org/collections/nl/am/<type>-<id>``\n    - Relation: ``http://purl.org/collections/nl/am/<name>``\n\n    We ignored all literal nodes and the relations connecting them in the\n    output graph.\n\n    AM dataset statistics:\n\n    - Nodes: 881680\n    - Edges: 5668682 (including reverse edges)\n    - Target Category: proxy\n    - Number of Classes: 11\n    - Label Split:\n\n        - Train: 802\n        - Test: 198\n\n    Parameters\n    -----------\n    print_every : int\n        Preprocessing log for every X tuples. Default: 10000.\n    insert_reverse : bool\n        If true, add reverse edge and reverse relations to the final graph. Default: True.\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes to predict\n    predict_category : str\n        The entity category (node type) that has labels for prediction\n\n    Examples\n    --------\n    >>> dataset = dgl.data.rdf.AMDataset()\n    >>> graph = dataset[0]\n    >>> category = dataset.predict_category\n    >>> num_classes = dataset.num_classes\n    >>>\n    >>> train_mask = graph.nodes[category].data['train_mask']\n    >>> test_mask = graph.nodes[category].data['test_mask']\n    >>> label = graph.nodes[category].data['label']\n    \"\"\"\n\n    entity_prefix = \"http://purl.org/collections/nl/am/\"\n    relation_prefix = entity_prefix\n\n    def __init__(\n        self,\n        print_every=10000,\n        insert_reverse=True,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        import rdflib as rdf\n\n        self.objectCategory = rdf.term.URIRef(\n            \"http://purl.org/collections/nl/am/objectCategory\"\n        )\n        self.material = rdf.term.URIRef(\n            \"http://purl.org/collections/nl/am/material\"\n        )\n        url = _get_dgl_url(\"dataset/rdf/am-hetero.zip\")\n        name = \"am-hetero\"\n        predict_category = \"proxy\"\n        super(AMDataset, self).__init__(\n            name,\n            url,\n            predict_category,\n            print_every=print_every,\n            insert_reverse=insert_reverse,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def __getitem__(self, idx):\n        r\"\"\"Gets the graph object\n\n        Parameters\n        -----------\n        idx: int\n            Item index, AMDataset has only one graph object\n\n        Return\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains:\n\n           
 - ``ndata['train_mask']``: mask for training node set\n            - ``ndata['test_mask']``: mask for testing node set\n            - ``ndata['label']``: node labels\n        \"\"\"\n        return super(AMDataset, self).__getitem__(idx)\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\n\n        Return\n        -------\n        int\n        \"\"\"\n        return super(AMDataset, self).__len__()\n\n    def parse_entity(self, term):\n        import rdflib as rdf\n\n        if isinstance(term, rdf.Literal):\n            return None\n        elif isinstance(term, rdf.BNode):\n            return Entity(e_id=str(term), cls=\"_BNode\")\n        entstr = str(term)\n        if entstr.startswith(self.entity_prefix):\n            sp = entstr.split(\"/\")\n            assert len(sp) == 7, entstr\n            spp = sp[6].split(\"-\")\n            if len(spp) == 2:\n                # instance\n                cls, inst = spp\n            else:\n                cls = \"TYPE\"\n                inst = spp\n            return Entity(e_id=inst, cls=cls)\n        else:\n            return None\n\n    def parse_relation(self, term):\n        if term == self.objectCategory or term == self.material:\n            return None\n        relstr = str(term)\n        if relstr.startswith(self.relation_prefix):\n            sp = relstr.split(\"/\")\n            assert len(sp) == 7, relstr\n            cls = sp[6]\n            return Relation(cls=cls)\n        else:\n            relstr = relstr.replace(\".\", \"_\")\n            return Relation(cls=relstr)\n\n    def process_tuple(self, raw_tuple, sbj, rel, obj):\n        if sbj is None or rel is None or obj is None:\n            return None\n        return (sbj, rel, obj)\n\n    def process_idx_file_line(self, line):\n        proxy, _, label = line.strip().split(\"\\t\")\n        return proxy, label\n"
  },
  {
    "path": "python/dgl/data/reddit.py",
    "content": "\"\"\" Reddit dataset for community detection \"\"\"\nfrom __future__ import absolute_import\n\nimport os\n\nimport numpy as np\n\nimport scipy.sparse as sp\n\nfrom .. import backend as F\nfrom ..convert import from_scipy\nfrom ..transforms import reorder_graph\n\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import (\n    _get_dgl_url,\n    deprecate_property,\n    generate_mask_tensor,\n    load_graphs,\n    save_graphs,\n)\n\n\nclass RedditDataset(DGLBuiltinDataset):\n    r\"\"\"Reddit dataset for community detection (node classification)\n\n    This is a graph dataset from Reddit posts made in the month of September, 2014.\n    The node label in this case is the community, or “subreddit”, that a post belongs to.\n    The authors sampled 50 large communities and built a post-to-post graph, connecting\n    posts if the same user comments on both. In total this dataset contains 232,965\n    posts with an average degree of 492. We use the first 20 days for training and the\n    remaining days for testing (with 30% used for validation).\n\n    Reference: `<http://snap.stanford.edu/graphsage/>`_\n\n    Statistics\n\n    - Nodes: 232,965\n    - Edges: 114,615,892\n    - Node feature size: 602\n    - Number of training samples: 153,431\n    - Number of validation samples: 23,831\n    - Number of test samples: 55,703\n\n    Parameters\n    ----------\n    self_loop : bool\n        Whether load dataset with self loop connections. Default: False\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. Default: True.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of classes for each node\n\n    Examples\n    --------\n    >>> data = RedditDataset()\n    >>> g = data[0]\n    >>> num_classes = data.num_classes\n    >>>\n    >>> # get node feature\n    >>> feat = g.ndata['feat']\n    >>>\n    >>> # get data split\n    >>> train_mask = g.ndata['train_mask']\n    >>> val_mask = g.ndata['val_mask']\n    >>> test_mask = g.ndata['test_mask']\n    >>>\n    >>> # get labels\n    >>> label = g.ndata['label']\n    >>>\n    >>> # Train, Validation and Test\n    \"\"\"\n\n    def __init__(\n        self,\n        self_loop=False,\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        self_loop_str = \"\"\n        if self_loop:\n            self_loop_str = \"_self_loop\"\n        _url = _get_dgl_url(\"dataset/reddit{}.zip\".format(self_loop_str))\n        self._self_loop_str = self_loop_str\n        super(RedditDataset, self).__init__(\n            name=\"reddit{}\".format(self_loop_str),\n            url=_url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        # graph\n        coo_adj = sp.load_npz(\n            os.path.join(\n                self.raw_path, \"reddit{}_graph.npz\".format(self._self_loop_str)\n            )\n        )\n        self._graph = from_scipy(coo_adj)\n        # features and labels\n        reddit_data = np.load(os.path.join(self.raw_path, \"reddit_data.npz\"))\n        features = reddit_data[\"feature\"]\n        labels = reddit_data[\"label\"]\n        # train/val/test indices\n        node_types = reddit_data[\"node_types\"]\n        train_mask = node_types == 1\n        val_mask = node_types == 2\n        test_mask = node_types == 3\n        
self._graph.ndata[\"train_mask\"] = generate_mask_tensor(train_mask)\n        self._graph.ndata[\"val_mask\"] = generate_mask_tensor(val_mask)\n        self._graph.ndata[\"test_mask\"] = generate_mask_tensor(test_mask)\n        self._graph.ndata[\"feat\"] = F.tensor(\n            features, dtype=F.data_type_dict[\"float32\"]\n        )\n        self._graph.ndata[\"label\"] = F.tensor(\n            labels, dtype=F.data_type_dict[\"int64\"]\n        )\n        self._graph = reorder_graph(\n            self._graph,\n            node_permute_algo=\"rcmk\",\n            edge_permute_algo=\"dst\",\n            store_ids=False,\n        )\n\n        self._print_info()\n\n    def has_cache(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        if os.path.exists(graph_path):\n            return True\n        return False\n\n    def save(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        save_graphs(graph_path, self._graph)\n\n    def load(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        graphs, _ = load_graphs(graph_path)\n        self._graph = graphs[0]\n        self._graph.ndata[\"train_mask\"] = generate_mask_tensor(\n            self._graph.ndata[\"train_mask\"].numpy()\n        )\n        self._graph.ndata[\"val_mask\"] = generate_mask_tensor(\n            self._graph.ndata[\"val_mask\"].numpy()\n        )\n        self._graph.ndata[\"test_mask\"] = generate_mask_tensor(\n            self._graph.ndata[\"test_mask\"].numpy()\n        )\n        self._print_info()\n\n    def _print_info(self):\n        if self.verbose:\n            print(\"Finished data loading.\")\n            print(\"  NumNodes: {}\".format(self._graph.num_nodes()))\n            print(\"  NumEdges: {}\".format(self._graph.num_edges()))\n            print(\"  NumFeats: {}\".format(self._graph.ndata[\"feat\"].shape[1]))\n            print(\"  NumClasses: {}\".format(self.num_classes))\n            
print(\n                \"  NumTrainingSamples: {}\".format(\n                    F.nonzero_1d(self._graph.ndata[\"train_mask\"]).shape[0]\n                )\n            )\n            print(\n                \"  NumValidationSamples: {}\".format(\n                    F.nonzero_1d(self._graph.ndata[\"val_mask\"]).shape[0]\n                )\n            )\n            print(\n                \"  NumTestSamples: {}\".format(\n                    F.nonzero_1d(self._graph.ndata[\"test_mask\"]).shape[0]\n                )\n            )\n\n    @property\n    def num_classes(self):\n        r\"\"\"Number of classes for each node.\"\"\"\n        return 41\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph by index\n\n        Parameters\n        ----------\n        idx : int\n            Item index\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n            graph structure, node labels, node features and splitting masks:\n\n            - ``ndata['label']``: node label\n            - ``ndata['feat']``: node feature\n            - ``ndata['train_mask']``: mask for training node set\n            - ``ndata['val_mask']``: mask for validation node set\n            - ``ndata['test_mask']``: mask for test node set\n        \"\"\"\n        assert idx == 0, \"Reddit Dataset only has one graph\"\n        if self._transform is None:\n            return self._graph\n        else:\n            return self._transform(self._graph)\n\n    def __len__(self):\n        r\"\"\"Number of graphs in the dataset\"\"\"\n        return 1\n"
  },
  {
    "path": "python/dgl/data/sbm.py",
    "content": "\"\"\"Dataset for stochastic block model.\"\"\"\nimport math\nimport os\nimport random\n\nimport numpy as np\nimport numpy.random as npr\nimport scipy as sp\n\nfrom .. import batch\nfrom ..convert import from_scipy\nfrom .dgl_dataset import DGLDataset\nfrom .utils import load_graphs, load_info, save_graphs, save_info\n\n\ndef sbm(n_blocks, block_size, p, q, rng=None):\n    \"\"\"(Symmetric) Stochastic Block Model\n\n    Parameters\n    ----------\n    n_blocks : int\n        Number of blocks.\n    block_size : int\n        Block size.\n    p : float\n        Probability for intra-community edge.\n    q : float\n        Probability for inter-community edge.\n    rng : numpy.random.RandomState, optional\n        Random number generator.\n\n    Returns\n    -------\n    scipy sparse matrix\n        The adjacency matrix of generated graph.\n    \"\"\"\n    n = n_blocks * block_size\n    p /= n\n    q /= n\n    rng = np.random.RandomState() if rng is None else rng\n\n    rows = []\n    cols = []\n    for i in range(n_blocks):\n        for j in range(i, n_blocks):\n            density = p if i == j else q\n            block = sp.sparse.random(\n                block_size,\n                block_size,\n                density,\n                random_state=rng,\n                data_rvs=lambda n: np.ones(n),\n            )\n            rows.append(block.row + i * block_size)\n            cols.append(block.col + j * block_size)\n\n    rows = np.hstack(rows)\n    cols = np.hstack(cols)\n    a = sp.sparse.coo_matrix(\n        (np.ones(rows.shape[0]), (rows, cols)), shape=(n, n)\n    )\n    adj = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose()\n    return adj\n\n\nclass SBMMixtureDataset(DGLDataset):\n    r\"\"\"Symmetric Stochastic Block Model Mixture\n\n    Reference: Appendix C of `Supervised Community Detection with Hierarchical Graph Neural Networks <https://arxiv.org/abs/1705.08415>`_\n\n    Parameters\n    ----------\n    n_graphs : int\n        
Number of graphs.\n    n_nodes : int\n        Number of nodes.\n    n_communities : int\n        Number of communities.\n    k : int, optional\n        Multiplier. Default: 2\n    avg_deg : int, optional\n        Average degree. Default: 3\n    pq : list of pair of nonnegative float or str, optional\n        Random densities. This parameter is for future extension,\n        for now it's always using the default value.\n        Default: Appendix_C\n    rng : numpy.random.RandomState, optional\n        Random number generator. If not given, it's numpy.random.RandomState() with `seed=None`,\n        which read data from /dev/urandom (or the Windows analogue) if available or seed from\n        the clock otherwise.\n        Default: None\n\n    Raises\n    ------\n    RuntimeError is raised if pq is not a list or string.\n\n    Examples\n    --------\n    >>> data = SBMMixtureDataset(n_graphs=16, n_nodes=10000, n_communities=2)\n    >>> from torch.utils.data import DataLoader\n    >>> dataloader = DataLoader(data, batch_size=1, collate_fn=data.collate_fn)\n    >>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader:\n    ...     
# your code here\n    \"\"\"\n\n    def __init__(\n        self,\n        n_graphs,\n        n_nodes,\n        n_communities,\n        k=2,\n        avg_deg=3,\n        pq=\"Appendix_C\",\n        rng=None,\n    ):\n        self._n_graphs = n_graphs\n        self._n_nodes = n_nodes\n        self._n_communities = n_communities\n        assert n_nodes % n_communities == 0\n        self._block_size = n_nodes // n_communities\n        self._k = k\n        self._avg_deg = avg_deg\n        self._pq = pq\n        self._rng = rng\n        super(SBMMixtureDataset, self).__init__(\n            name=\"sbmmixture\",\n            hash_key=(n_graphs, n_nodes, n_communities, k, avg_deg, pq, rng),\n        )\n\n    def process(self):\n        pq = self._pq\n        if type(pq) is list:\n            assert len(pq) == self._n_graphs\n        elif type(pq) is str:\n            generator = {\"Appendix_C\": self._appendix_c}[pq]\n            pq = [generator() for _ in range(self._n_graphs)]\n        else:\n            raise RuntimeError()\n        self._graphs = [\n            from_scipy(sbm(self._n_communities, self._block_size, *x))\n            for x in pq\n        ]\n        self._line_graphs = [\n            g.line_graph(backtracking=False) for g in self._graphs\n        ]\n        in_degrees = lambda g: g.in_degrees().float()\n        self._graph_degrees = [in_degrees(g) for g in self._graphs]\n        self._line_graph_degrees = [in_degrees(lg) for lg in self._line_graphs]\n        self._pm_pds = list(zip(*[g.edges() for g in self._graphs]))[0]\n\n    @property\n    def graph_path(self):\n        return os.path.join(self.save_path, \"graphs_{}.bin\".format(self.hash))\n\n    @property\n    def line_graph_path(self):\n        return os.path.join(\n            self.save_path, \"line_graphs_{}.bin\".format(self.hash)\n        )\n\n    @property\n    def info_path(self):\n        return os.path.join(self.save_path, \"info_{}.pkl\".format(self.hash))\n\n    def has_cache(self):\n      
  return (\n            os.path.exists(self.graph_path)\n            and os.path.exists(self.line_graph_path)\n            and os.path.exists(self.info_path)\n        )\n\n    def save(self):\n        save_graphs(self.graph_path, self._graphs)\n        save_graphs(self.line_graph_path, self._line_graphs)\n        save_info(\n            self.info_path,\n            {\n                \"graph_degree\": self._graph_degrees,\n                \"line_graph_degree\": self._line_graph_degrees,\n                \"pm_pds\": self._pm_pds,\n            },\n        )\n\n    def load(self):\n        self._graphs, _ = load_graphs(self.graph_path)\n        self._line_graphs, _ = load_graphs(self.line_graph_path)\n        info = load_info(self.info_path)\n        self._graph_degrees = info[\"graph_degree\"]\n        self._line_graph_degrees = info[\"line_graph_degree\"]\n        self._pm_pds = info[\"pm_pds\"]\n\n    def __len__(self):\n        r\"\"\"Number of graphs in the dataset.\"\"\"\n        return len(self._graphs)\n\n    def __getitem__(self, idx):\n        r\"\"\"Get one example by index\n\n        Parameters\n        ----------\n        idx : int\n            Item index\n\n        Returns\n        -------\n        graph: :class:`dgl.DGLGraph`\n            The original graph\n        line_graph: :class:`dgl.DGLGraph`\n            The line graph of `graph`\n        graph_degree: numpy.ndarray\n            In degrees for each node in `graph`\n        line_graph_degree: numpy.ndarray\n            In degrees for each node in `line_graph`\n        pm_pd: numpy.ndarray\n            Edge indicator matrices Pm and Pd\n        \"\"\"\n        return (\n            self._graphs[idx],\n            self._line_graphs[idx],\n            self._graph_degrees[idx],\n            self._line_graph_degrees[idx],\n            self._pm_pds[idx],\n        )\n\n    def _appendix_c(self):\n        q = npr.uniform(0, self._avg_deg - math.sqrt(self._avg_deg))\n        p = self._k * self._avg_deg - 
q\n        if random.random() < 0.5:\n            return p, q\n        else:\n            return q, p\n\n    def collate_fn(self, x):\n        r\"\"\"The `collate` function for dataloader\n\n        Parameters\n        ----------\n        x : tuple\n            a batch of data that contains:\n\n            - graph: :class:`dgl.DGLGraph`\n                The original graph\n            - line_graph: :class:`dgl.DGLGraph`\n                The line graph of `graph`\n            - graph_degree: numpy.ndarray\n                In degrees for each node in `graph`\n            - line_graph_degree: numpy.ndarray\n                In degrees for each node in `line_graph`\n            - pm_pd: numpy.ndarray\n                Edge indicator matrices Pm and Pd\n\n        Returns\n        -------\n        g_batch: :class:`dgl.DGLGraph`\n            Batched graphs\n        lg_batch: :class:`dgl.DGLGraph`\n            Batched line graphs\n        degg_batch: numpy.ndarray\n            A batch of in degrees for each node in `g_batch`\n        deglg_batch: numpy.ndarray\n            A batch of in degrees for each node in `lg_batch`\n        pm_pd_batch: numpy.ndarray\n            A batch of edge indicator matrices Pm and Pd\n        \"\"\"\n        g, lg, deg_g, deg_lg, pm_pd = zip(*x)\n        g_batch = batch.batch(g)\n        lg_batch = batch.batch(lg)\n        degg_batch = np.concatenate(deg_g, axis=0)\n        deglg_batch = np.concatenate(deg_lg, axis=0)\n        pm_pd_batch = np.concatenate(\n            [x + i * self._n_nodes for i, x in enumerate(pm_pd)], axis=0\n        )\n        return g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch\n\n\nSBMMixture = SBMMixtureDataset\n"
  },
  {
    "path": "python/dgl/data/superpixel.py",
    "content": "import os\nimport pickle\n\nimport numpy as np\nfrom scipy.spatial.distance import cdist\nfrom tqdm.auto import tqdm\n\nfrom .. import backend as F\nfrom ..convert import graph as dgl_graph\n\nfrom .dgl_dataset import DGLDataset\nfrom .utils import download, extract_archive, load_graphs, save_graphs, Subset\n\n\ndef sigma(dists, kth=8):\n    num_nodes = dists.shape[0]\n\n    # Compute sigma and reshape.\n    if kth > num_nodes:\n        # Handling for graphs with num_nodes less than kth.\n        sigma = np.array([1] * num_nodes).reshape(num_nodes, 1)\n    else:\n        # Get k-nearest neighbors for each node.\n        knns = np.partition(dists, kth, axis=-1)[:, : kth + 1]\n        sigma = knns.sum(axis=1).reshape((knns.shape[0], 1)) / kth\n\n    return sigma + 1e-8\n\n\ndef compute_adjacency_matrix_images(coord, feat, use_feat=True):\n    coord = coord.reshape(-1, 2)\n    # Compute coordinate distance.\n    c_dist = cdist(coord, coord)\n\n    if use_feat:\n        # Compute feature distance.\n        f_dist = cdist(feat, feat)\n        # Compute adjacency.\n        A = np.exp(\n            -((c_dist / sigma(c_dist)) ** 2) - (f_dist / sigma(f_dist)) ** 2\n        )\n    else:\n        A = np.exp(-((c_dist / sigma(c_dist)) ** 2))\n\n    # Convert to symmetric matrix.\n    A = 0.5 * (A + A.T)\n    A[np.diag_indices_from(A)] = 0\n    return A\n\n\ndef compute_edges_list(A, kth=9):\n    # Get k-similar neighbor indices for each node.\n    num_nodes = A.shape[0]\n    new_kth = num_nodes - kth\n\n    if num_nodes > kth:\n        knns = np.argpartition(A, new_kth - 1, axis=-1)[:, new_kth:-1]\n        knn_values = np.partition(A, new_kth - 1, axis=-1)[:, new_kth:-1]\n    else:\n        # Handling for graphs with less than kth nodes.\n        # In such cases, the resulting graph will be fully connected.\n        knns = np.tile(np.arange(num_nodes), num_nodes).reshape(\n            num_nodes, num_nodes\n        )\n        knn_values = A\n\n        # Removing 
self loop.\n        if num_nodes != 1:\n            knn_values = A[knns != np.arange(num_nodes)[:, None]].reshape(\n                num_nodes, -1\n            )\n            knns = knns[knns != np.arange(num_nodes)[:, None]].reshape(\n                num_nodes, -1\n            )\n    return knns, knn_values\n\n\nclass SuperPixelDataset(DGLDataset):\n    def __init__(\n        self,\n        raw_dir=None,\n        name=\"MNIST\",\n        split=\"train\",\n        use_feature=False,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        assert split in [\"train\", \"test\"], \"split not valid.\"\n        assert name in [\"MNIST\", \"CIFAR10\"], \"name not valid.\"\n\n        self.use_feature = use_feature\n        self.split = split\n        self._dataset_name = name\n        self.graphs = []\n        self.labels = []\n\n        super().__init__(\n            name=\"Superpixel\",\n            raw_dir=raw_dir,\n            url=\"\"\"\n            https://www.dropbox.com/s/y2qwa77a0fxem47/superpixels.zip?dl=1\n            \"\"\",\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    @property\n    def img_size(self):\n        r\"\"\"Size of dataset image.\"\"\"\n        if self._dataset_name == \"MNIST\":\n            return 28\n        return 32\n\n    @property\n    def save_path(self):\n        r\"\"\"Directory to save the processed dataset.\"\"\"\n        return os.path.join(self.raw_path, \"processed\")\n\n    @property\n    def raw_data_path(self):\n        r\"\"\"Path to save the raw dataset file.\"\"\"\n        return os.path.join(self.raw_path, \"superpixels.zip\")\n\n    @property\n    def graph_path(self):\n        r\"\"\"Path to save the processed dataset file.\"\"\"\n        if self.use_feature:\n            return os.path.join(\n                self.save_path,\n                f\"use_feat_{self._dataset_name}_{self.split}.pkl\",\n            )\n    
    return os.path.join(\n            self.save_path, f\"{self._dataset_name}_{self.split}.pkl\"\n        )\n\n    def download(self):\n        path = download(self.url, path=self.raw_data_path)\n        extract_archive(path, target_dir=self.raw_path, overwrite=True)\n\n    def process(self):\n        if self._dataset_name == \"MNIST\":\n            plk_file = \"mnist_75sp\"\n        elif self._dataset_name == \"CIFAR10\":\n            plk_file = \"cifar10_150sp\"\n\n        with open(\n            os.path.join(\n                self.raw_path, \"superpixels\", f\"{plk_file}_{self.split}.pkl\"\n            ),\n            \"rb\",\n        ) as f:\n            self.labels, self.sp_data = pickle.load(f)\n            self.labels = F.tensor(self.labels)\n\n        self.Adj_matrices = []\n        self.node_features = []\n        self.edges_lists = []\n        self.edge_features = []\n\n        for index, sample in enumerate(\n            tqdm(self.sp_data, desc=f\"Processing {self.split} dataset\")\n        ):\n            mean_px, coord = sample[:2]\n            coord = coord / self.img_size\n\n            if self.use_feature:\n                A = compute_adjacency_matrix_images(\n                    coord, mean_px\n                )  # using super-pixel locations + features\n            else:\n                A = compute_adjacency_matrix_images(\n                    coord, mean_px, False\n                )  # using only super-pixel locations\n            edges_list, edge_values_list = compute_edges_list(A)\n\n            N_nodes = A.shape[0]\n\n            mean_px = mean_px.reshape(N_nodes, -1)\n            coord = coord.reshape(N_nodes, 2)\n            x = np.concatenate((mean_px, coord), axis=1)\n\n            edge_values_list = edge_values_list.reshape(-1)\n\n            self.node_features.append(x)\n            self.edge_features.append(edge_values_list)\n            self.Adj_matrices.append(A)\n            self.edges_lists.append(edges_list)\n\n        for index 
in tqdm(\n            range(len(self.sp_data)), desc=f\"Dump {self.split} dataset\"\n        ):\n            N = self.node_features[index].shape[0]\n\n            src_nodes = []\n            dst_nodes = []\n            for src, dsts in enumerate(self.edges_lists[index]):\n                # handling for 1 node where the self loop would be the only edge\n                if N == 1:\n                    src_nodes.append(src)\n                    dst_nodes.append(dsts)\n                else:\n                    dsts = dsts[dsts != src]\n                    srcs = [src] * len(dsts)\n                    src_nodes.extend(srcs)\n                    dst_nodes.extend(dsts)\n\n            src_nodes = F.tensor(src_nodes)\n            dst_nodes = F.tensor(dst_nodes)\n\n            g = dgl_graph((src_nodes, dst_nodes), num_nodes=N)\n            g.ndata[\"feat\"] = F.zerocopy_from_numpy(\n                self.node_features[index]\n            ).to(F.float32)\n            g.edata[\"feat\"] = (\n                F.zerocopy_from_numpy(self.edge_features[index])\n                .to(F.float32)\n                .unsqueeze(1)\n            )\n\n            self.graphs.append(g)\n\n    def load(self):\n        self.graphs, label_dict = load_graphs(self.graph_path)\n        self.labels = label_dict[\"labels\"]\n\n    def save(self):\n        save_graphs(\n            self.graph_path, self.graphs, labels={\"labels\": self.labels}\n        )\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def __len__(self):\n        return len(self.graphs)\n\n    def __getitem__(self, idx):\n        \"\"\"Get the idx-th sample.\n\n        Parameters\n        ---------\n        idx : int or tensor\n            The sample index.\n            1-D tensor as `idx` is allowed when transform is None.\n\n        Returns\n        -------\n        (:class:`dgl.DGLGraph`, Tensor)\n            Graph with node feature stored in ``feat`` field and its label.\n        or\n        
:class:`dgl.data.utils.Subset`\n            Subset of the dataset at specified indices\n        \"\"\"\n        if F.is_tensor(idx) and idx.dim() == 1:\n            if self._transform is None:\n                return Subset(self, idx.cpu())\n\n            raise ValueError(\n                \"Tensor idx not supported when transform is not None.\"\n            )\n\n        if self._transform is None:\n            return self.graphs[idx], self.labels[idx]\n\n        return self._transform(self.graphs[idx]), self.labels[idx]\n\n\nclass MNISTSuperPixelDataset(SuperPixelDataset):\n    r\"\"\"MNIST superpixel dataset for the graph classification task.\n\n    DGL dataset of MNIST and CIFAR10 in the benchmark-gnn which contains graphs\n    converted from the original MNIST and CIFAR10 images.\n\n    Reference `<http://arxiv.org/abs/2003.00982>`_\n\n    Statistics:\n\n        - Train examples: 60,000\n        - Test examples: 10,000\n        - Size of dataset images: 28\n\n    Parameters\n    ----------\n    raw_dir : str\n        Directory to store all the downloaded raw datasets.\n        Default: \"~/.dgl/\".\n    split : str\n        Should be chosen from [\"train\", \"test\"]\n        Default: \"train\".\n    use_feature: bool\n\n        - True: Adj matrix defined from super-pixel locations + features\n        - False: Adj matrix defined from super-pixel locations (only)\n\n        Default: False.\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False.\n    verbose : bool\n        Whether to print out progress information.\n        Default: False.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Examples\n    ---------\n    >>> from dgl.data import MNISTSuperPixelDataset\n\n    >>> # MNIST dataset\n    >>> train_dataset = MNISTSuperPixelDataset(split=\"train\")\n    >>> len(train_dataset)\n    60000\n    >>> graph, label = train_dataset[0]\n    >>> graph\n    Graph(num_nodes=71, num_edges=568,\n        ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}\n        edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})\n\n    >>> # support tensor to be index when transform is None\n    >>> # see details in __getitem__ function\n    >>> import torch\n    >>> idx = torch.tensor([0, 1, 2])\n    >>> train_dataset_subset = train_dataset[idx]\n    >>> train_dataset_subset[0]\n    Graph(num_nodes=71, num_edges=568,\n        ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}\n        edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})\n    \"\"\"\n\n    def __init__(\n        self,\n        raw_dir=None,\n        split=\"train\",\n        use_feature=False,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        super().__init__(\n            raw_dir=raw_dir,\n            name=\"MNIST\",\n            split=split,\n            use_feature=use_feature,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n\nclass CIFAR10SuperPixelDataset(SuperPixelDataset):\n    r\"\"\"CIFAR10 superpixel dataset for the graph classification task.\n\n    DGL dataset of CIFAR10 in the benchmark-gnn which contains graphs\n    converted from the original CIFAR10 images.\n\n    Reference `<http://arxiv.org/abs/2003.00982>`_\n\n    Statistics:\n\n        - Train examples: 50,000\n        - Test examples: 10,000\n        - Size of dataset images: 32\n\n    Parameters\n    ----------\n    raw_dir : str\n        Directory to store all the downloaded raw 
datasets.\n        Default: \"~/.dgl/\".\n    split : str\n        Should be chosen from [\"train\", \"test\"]\n        Default: \"train\".\n    use_feature: bool\n\n        - True: Adj matrix defined from super-pixel locations + features\n        - False: Adj matrix defined from super-pixel locations (only)\n\n        Default: False.\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False.\n    verbose : bool\n        Whether to print out progress information.\n        Default: False.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Examples\n    ---------\n    >>> from dgl.data import CIFAR10SuperPixelDataset\n\n    >>> # CIFAR10 dataset\n    >>> train_dataset = CIFAR10SuperPixelDataset(split=\"train\")\n    >>> len(train_dataset)\n    50000\n    >>> graph, label = train_dataset[0]\n    >>> graph\n    Graph(num_nodes=123, num_edges=984,\n        ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)}\n        edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}),\n\n    >>> # support tensor to be index when transform is None\n    >>> # see details in __getitem__ function\n    >>> import torch\n    >>> idx = torch.tensor([0, 1, 2])\n    >>> train_dataset_subset = train_dataset[idx]\n    >>> train_dataset_subset[0]\n    Graph(num_nodes=123, num_edges=984,\n        ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)}\n        edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}),\n    \"\"\"\n\n    def __init__(\n        self,\n        raw_dir=None,\n        split=\"train\",\n        use_feature=False,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        super().__init__(\n            raw_dir=raw_dir,\n            name=\"CIFAR10\",\n            split=split,\n         
   use_feature=use_feature,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n"
  },
  {
    "path": "python/dgl/data/synthetic.py",
    "content": "\"\"\"Synthetic graph datasets.\"\"\"\nimport math\nimport os\nimport pickle\nimport random\n\nimport networkx as nx\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..batch import batch\nfrom ..convert import graph\nfrom ..transforms import reorder_graph\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, download, load_graphs, save_graphs\n\n\nclass BAShapeDataset(DGLBuiltinDataset):\n    r\"\"\"BA-SHAPES dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks\n    <https://arxiv.org/abs/1903.03894>`__\n\n    This is a synthetic dataset for node classification. It is generated by performing the\n    following steps in order.\n\n    - Construct a base Barabási–Albert (BA) graph.\n    - Construct a set of five-node house-structured network motifs.\n    - Attach the motifs to randomly selected nodes of the base graph.\n    - Perturb the graph by adding random edges.\n    - Nodes are assigned to 4 classes. Nodes of label 0 belong to the base BA graph. Nodes of\n      label 1, 2, 3 are separately at the middle, bottom, or top of houses.\n    - Generate constant feature for all nodes, which is 1.\n\n    Parameters\n    ----------\n    num_base_nodes : int, optional\n        Number of nodes in the base BA graph. Default: 300\n    num_base_edges_per_node : int, optional\n        Number of edges to attach from a new node to existing nodes in constructing the base BA\n        graph. Default: 5\n    num_motifs : int, optional\n        Number of house-structured network motifs to use. Default: 80\n    perturb_ratio : float, optional\n        Number of random edges to add in perturbation divided by the number of edges in the\n        original graph. Default: 0.01\n    seed : integer, random_state, or None, optional\n        Indicator of random number generation state. Default: None\n    raw_dir : str, optional\n        Raw file directory to store the processed data. 
Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to always generate the data from scratch rather than load a cached version.\n        Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Examples\n    --------\n\n    >>> from dgl.data import BAShapeDataset\n    >>> dataset = BAShapeDataset()\n    >>> dataset.num_classes\n    4\n    >>> g = dataset[0]\n    >>> label = g.ndata['label']\n    >>> feat = g.ndata['feat']\n    \"\"\"\n\n    def __init__(\n        self,\n        num_base_nodes=300,\n        num_base_edges_per_node=5,\n        num_motifs=80,\n        perturb_ratio=0.01,\n        seed=None,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        self.num_base_nodes = num_base_nodes\n        self.num_base_edges_per_node = num_base_edges_per_node\n        self.num_motifs = num_motifs\n        self.perturb_ratio = perturb_ratio\n        self.seed = seed\n        super(BAShapeDataset, self).__init__(\n            name=\"BA-SHAPES\",\n            url=None,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        g = nx.barabasi_albert_graph(\n            self.num_base_nodes, self.num_base_edges_per_node, self.seed\n        )\n        edges = list(g.edges())\n        src, dst = map(list, zip(*edges))\n        n = self.num_base_nodes\n\n        # Nodes in the base BA graph belong to class 0\n        node_labels = [0] * n\n        # The motifs will be evenly attached to the 
nodes in the base graph.\n        spacing = math.floor(n / self.num_motifs)\n\n        for motif_id in range(self.num_motifs):\n            # Construct a five-node house-structured network motif\n            motif_edges = [\n                (n, n + 1),\n                (n + 1, n + 2),\n                (n + 2, n + 3),\n                (n + 3, n),\n                (n + 4, n),\n                (n + 4, n + 1),\n            ]\n            motif_src, motif_dst = map(list, zip(*motif_edges))\n            src.extend(motif_src)\n            dst.extend(motif_dst)\n\n            # Nodes at the middle of a house belong to class 1\n            # Nodes at the bottom of a house belong to class 2\n            # Nodes at the top of a house belong to class 3\n            node_labels.extend([1, 1, 2, 2, 3])\n\n            # Attach the motif to the base BA graph\n            src.append(n)\n            dst.append(int(motif_id * spacing))\n            n += 5\n\n        g = graph((src, dst), num_nodes=n)\n\n        # Perturb the graph by adding non-self-loop random edges\n        num_real_edges = g.num_edges()\n        max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges\n        assert (\n            self.perturb_ratio <= max_ratio\n        ), \"perturb_ratio cannot exceed {:.4f}\".format(max_ratio)\n        num_random_edges = int(num_real_edges * self.perturb_ratio)\n\n        if self.seed is not None:\n            np.random.seed(self.seed)\n        for _ in range(num_random_edges):\n            while True:\n                u = np.random.randint(0, n)\n                v = np.random.randint(0, n)\n                if (not g.has_edges_between(u, v)) and (u != v):\n                    break\n            g.add_edges(u, v)\n\n        g.ndata[\"label\"] = F.tensor(node_labels, F.int64)\n        g.ndata[\"feat\"] = F.ones((n, 1), F.float32, F.cpu())\n        self._graph = reorder_graph(\n            g,\n            node_permute_algo=\"rcmk\",\n            edge_permute_algo=\"dst\",\n    
        store_ids=False,\n        )\n\n    @property\n    def graph_path(self):\n        return os.path.join(\n            self.save_path, \"{}_dgl_graph.bin\".format(self.name)\n        )\n\n    def save(self):\n        save_graphs(str(self.graph_path), self._graph)\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def load(self):\n        graphs, _ = load_graphs(str(self.graph_path))\n        self._graph = graphs[0]\n\n    def __getitem__(self, idx):\n        assert idx == 0, \"This dataset has only one graph.\"\n        if self._transform is None:\n            return self._graph\n        else:\n            return self._transform(self._graph)\n\n    def __len__(self):\n        return 1\n\n    @property\n    def num_classes(self):\n        return 4\n\n\nclass BACommunityDataset(DGLBuiltinDataset):\n    r\"\"\"BA-COMMUNITY dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks\n    <https://arxiv.org/abs/1903.03894>`__\n\n    This is a synthetic dataset for node classification. It is generated by performing the\n    following steps in order.\n\n    - Construct a base Barabási–Albert (BA) graph.\n    - Construct a set of five-node house-structured network motifs.\n    - Attach the motifs to randomly selected nodes of the base graph.\n    - Perturb the graph by adding random edges.\n    - Nodes are assigned to 4 classes. Nodes of label 0 belong to the base BA graph. Nodes of\n      label 1, 2, 3 are separately at the middle, bottom, or top of houses.\n    - Generate normally distributed features of length 10\n    - Repeat the above steps to generate another graph. Its nodes are assigned to class\n      4, 5, 6, 7. Its node features are generated with a distinct normal distribution.\n    - Join the two graphs by randomly adding edges between them.\n\n    Parameters\n    ----------\n    num_base_nodes : int, optional\n        Number of nodes in each base BA graph. 
Default: 300\n    num_base_edges_per_node : int, optional\n        Number of edges to attach from a new node to existing nodes in constructing a base BA\n        graph. Default: 4\n    num_motifs : int, optional\n        Number of house-structured network motifs to use in constructing each graph. Default: 80\n    perturb_ratio : float, optional\n        Number of random edges to add to a graph in perturbation divided by the number of original\n        edges in it. Default: 0.01\n    num_inter_edges : int, optional\n        Number of random edges to add between the two graphs. Default: 350\n    seed : integer, random_state, or None, optional\n        Indicator of random number generation state. Default: None\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to always generate the data from scratch rather than load a cached version.\n        Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. 
Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Examples\n    --------\n\n    >>> from dgl.data import BACommunityDataset\n    >>> dataset = BACommunityDataset()\n    >>> dataset.num_classes\n    8\n    >>> g = dataset[0]\n    >>> label = g.ndata['label']\n    >>> feat = g.ndata['feat']\n    \"\"\"\n\n    def __init__(\n        self,\n        num_base_nodes=300,\n        num_base_edges_per_node=4,\n        num_motifs=80,\n        perturb_ratio=0.01,\n        num_inter_edges=350,\n        seed=None,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        self.num_base_nodes = num_base_nodes\n        self.num_base_edges_per_node = num_base_edges_per_node\n        self.num_motifs = num_motifs\n        self.perturb_ratio = perturb_ratio\n        self.num_inter_edges = num_inter_edges\n        self.seed = seed\n        super(BACommunityDataset, self).__init__(\n            name=\"BA-COMMUNITY\",\n            url=None,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        if self.seed is not None:\n            random.seed(self.seed)\n            np.random.seed(self.seed)\n\n        # Construct two BA-SHAPES graphs\n        g1 = BAShapeDataset(\n            self.num_base_nodes,\n            self.num_base_edges_per_node,\n            self.num_motifs,\n            self.perturb_ratio,\n            force_reload=True,\n            verbose=False,\n        )[0]\n        g2 = BAShapeDataset(\n            self.num_base_nodes,\n            self.num_base_edges_per_node,\n            self.num_motifs,\n            self.perturb_ratio,\n            force_reload=True,\n            verbose=False,\n        )[0]\n\n        # Join them and randomly add edges between them\n        g = batch([g1, g2])\n        num_nodes = g.num_nodes() // 2\n        
src = np.random.randint(0, num_nodes, (self.num_inter_edges,))\n        dst = np.random.randint(\n            num_nodes, 2 * num_nodes, (self.num_inter_edges,)\n        )\n        src = F.astype(F.zerocopy_from_numpy(src), g.idtype)\n        dst = F.astype(F.zerocopy_from_numpy(dst), g.idtype)\n        g.add_edges(src, dst)\n        g.ndata[\"label\"] = F.cat(\n            [g1.ndata[\"label\"], g2.ndata[\"label\"] + 4], dim=0\n        )\n\n        # feature generation\n        random_mu = [0.0] * 8\n        random_sigma = [1.0] * 8\n\n        mu_1, sigma_1 = np.array([-1.0] * 2 + random_mu), np.array(\n            [0.5] * 2 + random_sigma\n        )\n        feat1 = np.random.multivariate_normal(mu_1, np.diag(sigma_1), num_nodes)\n\n        mu_2, sigma_2 = np.array([1.0] * 2 + random_mu), np.array(\n            [0.5] * 2 + random_sigma\n        )\n        feat2 = np.random.multivariate_normal(mu_2, np.diag(sigma_2), num_nodes)\n\n        feat = np.concatenate([feat1, feat2])\n        g.ndata[\"feat\"] = F.zerocopy_from_numpy(feat)\n        self._graph = reorder_graph(\n            g,\n            node_permute_algo=\"rcmk\",\n            edge_permute_algo=\"dst\",\n            store_ids=False,\n        )\n\n    @property\n    def graph_path(self):\n        return os.path.join(\n            self.save_path, \"{}_dgl_graph.bin\".format(self.name)\n        )\n\n    def save(self):\n        save_graphs(str(self.graph_path), self._graph)\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def load(self):\n        graphs, _ = load_graphs(str(self.graph_path))\n        self._graph = graphs[0]\n\n    def __getitem__(self, idx):\n        assert idx == 0, \"This dataset has only one graph.\"\n        if self._transform is None:\n            return self._graph\n        else:\n            return self._transform(self._graph)\n\n    def __len__(self):\n        return 1\n\n    @property\n    def num_classes(self):\n        return 8\n\n\nclass 
TreeCycleDataset(DGLBuiltinDataset):\n    r\"\"\"TREE-CYCLES dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks\n    <https://arxiv.org/abs/1903.03894>`__\n\n    This is a synthetic dataset for node classification. It is generated by performing the\n    following steps in order.\n\n    - Construct a balanced binary tree as the base graph.\n    - Construct a set of cycle motifs.\n    - Attach the motifs to randomly selected nodes of the base graph.\n    - Perturb the graph by adding random edges.\n    - Generate constant feature for all nodes, which is 1.\n    - Nodes in the tree belong to class 0 and nodes in cycles belong to class 1.\n\n    Parameters\n    ----------\n    tree_height : int, optional\n        Height of the balanced binary tree. Default: 8\n    num_motifs : int, optional\n        Number of cycle motifs to use. Default: 60\n    cycle_size : int, optional\n        Number of nodes in a cycle motif. Default: 6\n    perturb_ratio : float, optional\n        Number of random edges to add in perturbation divided by the\n        number of original edges in the graph. Default: 0.01\n    seed : integer, random_state, or None, optional\n        Indicator of random number generation state. Default: None\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to always generate the data from scratch rather than load a cached version.\n        Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. 
Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Examples\n    --------\n\n    >>> from dgl.data import TreeCycleDataset\n    >>> dataset = TreeCycleDataset()\n    >>> dataset.num_classes\n    2\n    >>> g = dataset[0]\n    >>> label = g.ndata['label']\n    >>> feat = g.ndata['feat']\n    \"\"\"\n\n    def __init__(\n        self,\n        tree_height=8,\n        num_motifs=60,\n        cycle_size=6,\n        perturb_ratio=0.01,\n        seed=None,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        self.tree_height = tree_height\n        self.num_motifs = num_motifs\n        self.cycle_size = cycle_size\n        self.perturb_ratio = perturb_ratio\n        self.seed = seed\n        super(TreeCycleDataset, self).__init__(\n            name=\"TREE-CYCLES\",\n            url=None,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        if self.seed is not None:\n            np.random.seed(self.seed)\n\n        g = nx.balanced_tree(r=2, h=self.tree_height)\n        edges = list(g.edges())\n        src, dst = map(list, zip(*edges))\n        n = nx.number_of_nodes(g)\n\n        # Nodes in the base tree graph belong to class 0\n        node_labels = [0] * n\n        # The motifs will be evenly attached to the nodes in the base graph.\n        spacing = math.floor(n / self.num_motifs)\n\n        for motif_id in range(self.num_motifs):\n            # Construct a six-node cycle\n            motif_edges = [(n + i, n + i + 1) for i in range(5)]\n            motif_edges.append((n + 5, n))\n            motif_src, motif_dst = map(list, zip(*motif_edges))\n            src.extend(motif_src)\n            dst.extend(motif_dst)\n\n            # Nodes in cycles belong to class 1\n            node_labels.extend([1] * self.cycle_size)\n\n     
       # Attach the motif to the base tree graph\n            anchor = int(motif_id * spacing)\n            src.append(n)\n            dst.append(anchor)\n\n            if np.random.random() > 0.5:\n                a = np.random.randint(1, 4)\n                b = np.random.randint(1, 4)\n                src.append(n + a)\n                dst.append(anchor + b)\n\n            n += self.cycle_size\n\n        g = graph((src, dst), num_nodes=n)\n\n        # Perturb the graph by adding non-self-loop random edges\n        num_real_edges = g.num_edges()\n        max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges\n        assert (\n            self.perturb_ratio <= max_ratio\n        ), \"perturb_ratio cannot exceed {:.4f}\".format(max_ratio)\n        num_random_edges = int(num_real_edges * self.perturb_ratio)\n\n        for _ in range(num_random_edges):\n            while True:\n                u = np.random.randint(0, n)\n                v = np.random.randint(0, n)\n                if (not g.has_edges_between(u, v)) and (u != v):\n                    break\n            g.add_edges(u, v)\n\n        g.ndata[\"label\"] = F.tensor(node_labels, F.int64)\n        g.ndata[\"feat\"] = F.ones((n, 1), F.float32, F.cpu())\n        self._graph = reorder_graph(\n            g,\n            node_permute_algo=\"rcmk\",\n            edge_permute_algo=\"dst\",\n            store_ids=False,\n        )\n\n    @property\n    def graph_path(self):\n        return os.path.join(\n            self.save_path, \"{}_dgl_graph.bin\".format(self.name)\n        )\n\n    def save(self):\n        save_graphs(str(self.graph_path), self._graph)\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def load(self):\n        graphs, _ = load_graphs(str(self.graph_path))\n        self._graph = graphs[0]\n\n    def __getitem__(self, idx):\n        assert idx == 0, \"This dataset has only one graph.\"\n        if self._transform is None:\n            return self._graph\n   
     else:\n            return self._transform(self._graph)\n\n    def __len__(self):\n        return 1\n\n    @property\n    def num_classes(self):\n        return 2\n\n\nclass TreeGridDataset(DGLBuiltinDataset):\n    r\"\"\"TREE-GRIDS dataset from `GNNExplainer: Generating Explanations for Graph Neural Networks\n    <https://arxiv.org/abs/1903.03894>`__\n\n    This is a synthetic dataset for node classification. It is generated by performing the\n    following steps in order.\n\n    - Construct a balanced binary tree as the base graph.\n    - Construct a set of n-by-n grid motifs.\n    - Attach the motifs to randomly selected nodes of the base graph.\n    - Perturb the graph by adding random edges.\n    - Generate constant feature for all nodes, which is 1.\n    - Nodes in the tree belong to class 0 and nodes in grids belong to class 1.\n\n    Parameters\n    ----------\n    tree_height : int, optional\n        Height of the balanced binary tree. Default: 8\n    num_motifs : int, optional\n        Number of grid motifs to use. Default: 80\n    grid_size : int, optional\n        The number of nodes in a grid motif will be grid_size ^ 2. Default: 3\n    perturb_ratio : float, optional\n        Number of random edges to add in perturbation divided by the\n        number of original edges in the graph. Default: 0.1\n    seed : integer, random_state, or None, optional\n        Indicator of random number generation state. Default: None\n    raw_dir : str, optional\n        Raw file directory to store the processed data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to always generate the data from scratch rather than load a cached version.\n        Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Examples\n    --------\n\n    >>> from dgl.data import TreeGridDataset\n    >>> dataset = TreeGridDataset()\n    >>> dataset.num_classes\n    2\n    >>> g = dataset[0]\n    >>> label = g.ndata['label']\n    >>> feat = g.ndata['feat']\n    \"\"\"\n\n    def __init__(\n        self,\n        tree_height=8,\n        num_motifs=80,\n        grid_size=3,\n        perturb_ratio=0.1,\n        seed=None,\n        raw_dir=None,\n        force_reload=False,\n        verbose=True,\n        transform=None,\n    ):\n        self.tree_height = tree_height\n        self.num_motifs = num_motifs\n        self.grid_size = grid_size\n        self.perturb_ratio = perturb_ratio\n        self.seed = seed\n        super(TreeGridDataset, self).__init__(\n            name=\"TREE-GRIDS\",\n            url=None,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        if self.seed is not None:\n            np.random.seed(self.seed)\n\n        g = nx.balanced_tree(r=2, h=self.tree_height)\n        edges = list(g.edges())\n        src, dst = map(list, zip(*edges))\n        n = nx.number_of_nodes(g)\n\n        # Nodes in the base tree graph belong to class 0\n        node_labels = [0] * n\n        # The motifs will be evenly attached to the nodes in the base graph.\n        spacing = math.floor(n / self.num_motifs)\n\n        # Construct an n-by-n grid\n        motif_g = nx.grid_graph([self.grid_size, self.grid_size])\n        grid_size = nx.number_of_nodes(motif_g)\n        motif_g = nx.convert_node_labels_to_integers(motif_g, first_label=0)\n        motif_edges = list(motif_g.edges())\n        motif_src, motif_dst = map(list, zip(*motif_edges))\n        motif_src, motif_dst = 
np.array(motif_src), np.array(motif_dst)\n\n        for motif_id in range(self.num_motifs):\n            src.extend((motif_src + n).tolist())\n            dst.extend((motif_dst + n).tolist())\n\n            # Nodes in grids belong to class 1\n            node_labels.extend([1] * grid_size)\n\n            # Attach the motif to the base tree graph\n            src.append(n)\n            dst.append(int(motif_id * spacing))\n\n            n += grid_size\n\n        g = graph((src, dst), num_nodes=n)\n\n        # Perturb the graph by adding non-self-loop random edges\n        num_real_edges = g.num_edges()\n        max_ratio = (n * (n - 1) - num_real_edges) / num_real_edges\n        assert (\n            self.perturb_ratio <= max_ratio\n        ), \"perturb_ratio cannot exceed {:.4f}\".format(max_ratio)\n        num_random_edges = int(num_real_edges * self.perturb_ratio)\n\n        for _ in range(num_random_edges):\n            while True:\n                u = np.random.randint(0, n)\n                v = np.random.randint(0, n)\n                if (not g.has_edges_between(u, v)) and (u != v):\n                    break\n            g.add_edges(u, v)\n\n        g.ndata[\"label\"] = F.tensor(node_labels, F.int64)\n        g.ndata[\"feat\"] = F.ones((n, 1), F.float32, F.cpu())\n        self._graph = reorder_graph(\n            g,\n            node_permute_algo=\"rcmk\",\n            edge_permute_algo=\"dst\",\n            store_ids=False,\n        )\n\n    @property\n    def graph_path(self):\n        return os.path.join(\n            self.save_path, \"{}_dgl_graph.bin\".format(self.name)\n        )\n\n    def save(self):\n        save_graphs(str(self.graph_path), self._graph)\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def load(self):\n        graphs, _ = load_graphs(str(self.graph_path))\n        self._graph = graphs[0]\n\n    def __getitem__(self, idx):\n        assert idx == 0, \"This dataset has only one graph.\"\n        if 
self._transform is None:\n            return self._graph\n        else:\n            return self._transform(self._graph)\n\n    def __len__(self):\n        return 1\n\n    @property\n    def num_classes(self):\n        return 2\n\n\nclass BA2MotifDataset(DGLBuiltinDataset):\n    r\"\"\"BA-2motifs dataset from `Parameterized Explainer for Graph Neural Network\n    <https://arxiv.org/abs/2011.04573>`__\n\n    This is a synthetic dataset for graph classification. It was generated by\n    performing the following steps in order.\n\n    - Construct 1000 base Barabási–Albert (BA) graphs.\n    - Attach house-structured network motifs to half of the base BA graphs.\n    - Attach five-node cycle motifs to the rest of the base BA graphs.\n    - Assign each graph to one of two classes according to the type of the attached motif.\n\n    Parameters\n    ----------\n    raw_dir : str, optional\n        Raw file directory to download and store the data. Default: ~/.dgl/\n    force_reload : bool, optional\n        Whether to reload the dataset. Default: False\n    verbose : bool, optional\n        Whether to print progress information. Default: True\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access. 
Default: None\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of graph classes\n\n    Examples\n    --------\n\n    >>> from dgl.data import BA2MotifDataset\n    >>> dataset = BA2MotifDataset()\n    >>> dataset.num_classes\n    2\n    >>> # Get the first graph and its label\n    >>> g, label = dataset[0]\n    >>> feat = g.ndata['feat']\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=True, transform=None\n    ):\n        super(BA2MotifDataset, self).__init__(\n            name=\"BA-2motifs\",\n            url=_get_dgl_url(\"dataset/BA-2motif.pkl\"),\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def download(self):\n        r\"\"\"Automatically download data.\"\"\"\n        file_path = os.path.join(self.raw_dir, self.name + \".pkl\")\n        download(self.url, path=file_path)\n\n    def process(self):\n        file_path = os.path.join(self.raw_dir, self.name + \".pkl\")\n        with open(file_path, \"rb\") as f:\n            adjs, features, labels = pickle.load(f)\n\n        self.graphs = []\n        self.labels = F.tensor(labels, F.int64)\n\n        for i in range(len(adjs)):\n            g = graph(adjs[i].nonzero())\n            g.ndata[\"feat\"] = F.zerocopy_from_numpy(features[i])\n            self.graphs.append(g)\n\n    @property\n    def graph_path(self):\n        return os.path.join(\n            self.save_path, \"{}_dgl_graph.bin\".format(self.name)\n        )\n\n    def save(self):\n        label_dict = {\"labels\": self.labels}\n        save_graphs(str(self.graph_path), self.graphs, label_dict)\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def load(self):\n        self.graphs, label_dict = load_graphs(str(self.graph_path))\n        self.labels = label_dict[\"labels\"]\n\n    def __getitem__(self, idx):\n        g = self.graphs[idx]\n        if 
self._transform is not None:\n            g = self._transform(g)\n        return g, self.labels[idx]\n\n    def __len__(self):\n        return len(self.graphs)\n\n    @property\n    def num_classes(self):\n        return 2\n"
  },
  {
    "path": "python/dgl/data/tensor_serialize.py",
    "content": "\"\"\"For Tensor Serialization\"\"\"\nfrom __future__ import absolute_import\n\nfrom .. import backend as F\nfrom .._ffi.function import _init_api\nfrom ..ndarray import NDArray\n\n__all__ = [\"save_tensors\", \"load_tensors\"]\n\n_init_api(\"dgl.data.tensor_serialize\")\n\n\ndef save_tensors(filename, tensor_dict):\n    \"\"\"\n    Save dict of tensors to file\n\n    Parameters\n    ----------\n    filename : str\n        File name to store dict of tensors.\n    tensor_dict: dict of dgl NDArray or backend tensor\n        Python dict using string as key and tensor as value\n\n    Returns\n    ----------\n    status : bool\n        Return whether save operation succeeds\n    \"\"\"\n    nd_dict = {}\n    is_empty_dict = len(tensor_dict) == 0\n    for key, value in tensor_dict.items():\n        if not isinstance(key, str):\n            raise Exception(\"Dict key has to be str\")\n        if F.is_tensor(value):\n            nd_dict[key] = F.zerocopy_to_dgl_ndarray(value)\n        elif isinstance(value, NDArray):\n            nd_dict[key] = value\n        else:\n            raise Exception(\n                \"Dict value has to be backend tensor or dgl ndarray\"\n            )\n\n    return _CAPI_SaveNDArrayDict(filename, nd_dict, is_empty_dict)\n\n\ndef load_tensors(filename, return_dgl_ndarray=False):\n    \"\"\"\n    load dict of tensors from file\n\n    Parameters\n    ----------\n    filename : str\n        File name to load dict of tensors.\n    return_dgl_ndarray: bool\n        Whether return dict of dgl NDArrays or backend tensors\n\n    Returns\n    ---------\n    tensor_dict : dict\n        dict of tensor or ndarray based on return_dgl_ndarray flag\n    \"\"\"\n    nd_dict = _CAPI_LoadNDArrayDict(filename)\n    tensor_dict = {}\n    for key, value in nd_dict.items():\n        if return_dgl_ndarray:\n            tensor_dict[key] = value\n        else:\n            tensor_dict[key] = F.zerocopy_from_dgl_ndarray(value)\n    return tensor_dict\n"
  },
  {
    "path": "python/dgl/data/tree.py",
    "content": "\"\"\"Tree-structured data.\nIncluding:\n    - Stanford Sentiment Treebank\n\"\"\"\nfrom __future__ import absolute_import\n\nimport os\n\nfrom collections import OrderedDict\n\nimport networkx as nx\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..convert import from_networkx\n\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import (\n    _get_dgl_url,\n    deprecate_property,\n    load_graphs,\n    load_info,\n    save_graphs,\n    save_info,\n)\n\n__all__ = [\"SST\", \"SSTDataset\"]\n\n\nclass SSTDataset(DGLBuiltinDataset):\n    r\"\"\"Stanford Sentiment Treebank dataset.\n\n    Each sample is the constituency tree of a sentence. The leaf nodes\n    represent words. The word is a int value stored in the ``x`` feature field.\n    The non-leaf node has a special value ``PAD_WORD`` in the ``x`` field.\n    Each node also has a sentiment annotation: 5 classes (very negative,\n    negative, neutral, positive and very positive). The sentiment label is a\n    int value stored in the ``y`` feature field.\n    Official site: `<http://nlp.stanford.edu/sentiment/index.html>`_\n\n    Statistics:\n\n    - Train examples: 8,544\n    - Dev examples: 1,101\n    - Test examples: 2,210\n    - Number of classes for each node: 5\n\n    Parameters\n    ----------\n    mode : str, optional\n        Should be one of ['train', 'dev', 'test', 'tiny']\n        Default: train\n    glove_embed_file : str, optional\n        The path to pretrained glove embedding file.\n        Default: None\n    vocab_file : str, optional\n        Optional vocabulary file. If not given, the default vacabulary file is used.\n        Default: None\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset. Default: False\n    verbose : bool\n        Whether to print out progress information. 
Default: False.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    vocab : OrderedDict\n        Vocabulary of the dataset\n    num_classes : int\n        Number of classes for each node\n    pretrained_emb: Tensor\n        Pretrained glove embedding with respect to the vocabulary.\n    vocab_size : int\n        The size of the vocabulary\n\n    Notes\n    -----\n    All the samples will be loaded and preprocessed in the memory first.\n\n    Examples\n    --------\n    >>> # get dataset\n    >>> train_data = SSTDataset()\n    >>> dev_data = SSTDataset(mode='dev')\n    >>> test_data = SSTDataset(mode='test')\n    >>> tiny_data = SSTDataset(mode='tiny')\n    >>>\n    >>> len(train_data)\n    8544\n    >>> train_data.num_classes\n    5\n    >>> glove_embed = train_data.pretrained_emb\n    >>> train_data.vocab_size\n    19536\n    >>> train_data[0]\n    Graph(num_nodes=71, num_edges=70,\n      ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}\n      edata_schemes={})\n    >>> for tree in train_data:\n    ...     input_ids = tree.ndata['x']\n    ...     labels = tree.ndata['y']\n    ...     mask = tree.ndata['mask']\n    ...     
# your code here\n    \"\"\"\n\n    PAD_WORD = -1  # special pad word id\n    UNK_WORD = -1  # out-of-vocabulary word id\n\n    def __init__(\n        self,\n        mode=\"train\",\n        glove_embed_file=None,\n        vocab_file=None,\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        assert mode in [\"train\", \"dev\", \"test\", \"tiny\"]\n        _url = _get_dgl_url(\"dataset/sst.zip\")\n        self._glove_embed_file = glove_embed_file if mode == \"train\" else None\n        self.mode = mode\n        self._vocab_file = vocab_file\n        super(SSTDataset, self).__init__(\n            name=\"sst\",\n            url=_url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        from nltk.corpus.reader import BracketParseCorpusReader\n\n        # load vocab file\n        self._vocab = OrderedDict()\n        vocab_file = (\n            self._vocab_file\n            if self._vocab_file is not None\n            else os.path.join(self.raw_path, \"vocab.txt\")\n        )\n        with open(vocab_file, encoding=\"utf-8\") as vf:\n            for line in vf.readlines():\n                line = line.strip()\n                self._vocab[line] = len(self._vocab)\n\n        # filter glove\n        if self._glove_embed_file is not None and os.path.exists(\n            self._glove_embed_file\n        ):\n            glove_emb = {}\n            with open(self._glove_embed_file, \"r\", encoding=\"utf-8\") as pf:\n                for line in pf.readlines():\n                    sp = line.split(\" \")\n                    if sp[0].lower() in self._vocab:\n                        glove_emb[sp[0].lower()] = np.asarray(\n                            [float(x) for x in sp[1:]]\n                        )\n        files = [\"{}.txt\".format(self.mode)]\n        corpus = 
BracketParseCorpusReader(self.raw_path, files)\n        sents = corpus.parsed_sents(files[0])\n\n        # initialize with glove\n        pretrained_emb = []\n        fail_cnt = 0\n        for line in self._vocab.keys():\n            if self._glove_embed_file is not None and os.path.exists(\n                self._glove_embed_file\n            ):\n                if not line.lower() in glove_emb:\n                    fail_cnt += 1\n                pretrained_emb.append(\n                    glove_emb.get(\n                        line.lower(), np.random.uniform(-0.05, 0.05, 300)\n                    )\n                )\n\n        self._pretrained_emb = None\n        if self._glove_embed_file is not None and os.path.exists(\n            self._glove_embed_file\n        ):\n            self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))\n            print(\n                \"Miss word in GloVe {0:.4f}\".format(\n                    1.0 * fail_cnt / len(self._pretrained_emb)\n                )\n            )\n        # build trees\n        self._trees = []\n        for sent in sents:\n            self._trees.append(self._build_tree(sent))\n\n    def _build_tree(self, root):\n        g = nx.DiGraph()\n\n        def _rec_build(nid, node):\n            for child in node:\n                cid = g.number_of_nodes()\n                if isinstance(child[0], str) or isinstance(child[0], bytes):\n                    # leaf node\n                    word = self.vocab.get(child[0].lower(), self.UNK_WORD)\n                    g.add_node(cid, x=word, y=int(child.label()), mask=1)\n                else:\n                    g.add_node(\n                        cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0\n                    )\n                    _rec_build(cid, child)\n                g.add_edge(cid, nid)\n\n        # add root\n        g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0)\n        _rec_build(0, root)\n        ret = from_networkx(g, 
node_attrs=[\"x\", \"y\", \"mask\"])\n        return ret\n\n    @property\n    def graph_path(self):\n        return os.path.join(self.save_path, self.mode + \"_dgl_graph.bin\")\n\n    @property\n    def vocab_path(self):\n        return os.path.join(self.save_path, \"vocab.pkl\")\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path) and os.path.exists(\n            self.vocab_path\n        )\n\n    def save(self):\n        save_graphs(self.graph_path, self._trees)\n        save_info(self.vocab_path, {\"vocab\": self.vocab})\n        if self.pretrained_emb:\n            emb_path = os.path.join(self.save_path, \"emb.pkl\")\n            save_info(emb_path, {\"embed\": self.pretrained_emb})\n\n    def load(self):\n        emb_path = os.path.join(self.save_path, \"emb.pkl\")\n\n        self._trees = load_graphs(self.graph_path)[0]\n        self._vocab = load_info(self.vocab_path)[\"vocab\"]\n        self._pretrained_emb = None\n        if os.path.exists(emb_path):\n            self._pretrained_emb = load_info(emb_path)[\"embed\"]\n\n    @property\n    def vocab(self):\n        r\"\"\"Vocabulary\n\n        Returns\n        -------\n        OrderedDict\n        \"\"\"\n        return self._vocab\n\n    @property\n    def pretrained_emb(self):\n        r\"\"\"Pre-trained word embedding, if given.\"\"\"\n        return self._pretrained_emb\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph by index\n\n        Parameters\n        ----------\n        idx : int\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n\n            graph structure, word id for each node, node labels and masks.\n\n            - ``ndata['x']``: word id of the node\n            - ``ndata['y']:`` label of the node\n            - ``ndata['mask']``: 1 if the node is a leaf, otherwise 0\n        \"\"\"\n        if self._transform is None:\n            return self._trees[idx]\n        else:\n            return self._transform(self._trees[idx])\n\n    def 
__len__(self):\n        r\"\"\"Number of graphs in the dataset.\"\"\"\n        return len(self._trees)\n\n    @property\n    def vocab_size(self):\n        r\"\"\"Vocabulary size.\"\"\"\n        return len(self._vocab)\n\n    @property\n    def num_classes(self):\n        r\"\"\"Number of classes for each node.\"\"\"\n        return 5\n\n\nSST = SSTDataset\n"
  },
  {
    "path": "python/dgl/data/tu.py",
    "content": "from __future__ import absolute_import\n\nimport os\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..convert import graph as dgl_graph\n\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import load_graphs, load_info, loadtxt, save_graphs, save_info\n\n\nclass LegacyTUDataset(DGLBuiltinDataset):\n    r\"\"\"LegacyTUDataset contains lots of graph kernel datasets for graph classification.\n\n    Parameters\n    ----------\n    name : str\n        Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the\n        datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.\n    use_pandas : bool\n        Numpy's file read function has performance issue when file is large,\n        using pandas can be faster.\n        Default: False\n    hidden_size : int\n        Some dataset doesn't contain features.\n        Use constant node features initialization instead, with hidden size as ``hidden_size``.\n        Default : 10\n    max_allow_node : int\n        Remove graphs that contains more nodes than ``max_allow_node``.\n        Default : None\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    max_num_node : int\n        Maximum number of nodes\n    num_classes : int\n        Number of classes\n    num_labels : numpy.int64\n        (DEPRECATED, use num_classes instead) Number of classes\n\n    Notes\n    -----\n    LegacyTUDataset uses provided node feature by default. 
If no feature provided, it uses one-hot node label instead.\n    If neither labels provided, it uses constant for node feature.\n\n    The dataset sorts graphs by their labels.\n    Shuffle is preferred before manual train/val split.\n\n    Examples\n    --------\n    >>> data = LegacyTUDataset('DD')\n\n    The dataset instance is an iterable\n\n    >>> len(data)\n    1178\n    >>> g, label = data[1024]\n    >>> g\n    Graph(num_nodes=88, num_edges=410,\n          ndata_schemes={'feat': Scheme(shape=(89,), dtype=torch.float32), '_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> label\n    tensor(1)\n\n    Batch the graphs and labels for mini-batch training\n\n    >>> graphs, labels = zip(*[data[i] for i in range(16)])\n    >>> batched_graphs = dgl.batch(graphs)\n    >>> batched_labels = torch.tensor(labels)\n    >>> batched_graphs\n    Graph(num_nodes=9539, num_edges=47382,\n          ndata_schemes={'feat': Scheme(shape=(89,), dtype=torch.float32), '_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n    \"\"\"\n\n    _url = r\"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip\"\n\n    def __init__(\n        self,\n        name,\n        use_pandas=False,\n        hidden_size=10,\n        max_allow_node=None,\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        url = self._url.format(name)\n        self.hidden_size = hidden_size\n        self.max_allow_node = max_allow_node\n        self.use_pandas = use_pandas\n        super(LegacyTUDataset, self).__init__(\n            name=name,\n            url=url,\n            raw_dir=raw_dir,\n            hash_key=(name, use_pandas, hidden_size, max_allow_node),\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        self.data_mode = 
None\n\n        if self.use_pandas:\n            import pandas as pd\n\n            DS_edge_list = self._idx_from_zero(\n                pd.read_csv(\n                    self._file_path(\"A\"), delimiter=\",\", dtype=int, header=None\n                ).values\n            )\n        else:\n            DS_edge_list = self._idx_from_zero(\n                np.genfromtxt(self._file_path(\"A\"), delimiter=\",\", dtype=int)\n            )\n\n        DS_indicator = self._idx_from_zero(\n            np.genfromtxt(self._file_path(\"graph_indicator\"), dtype=int)\n        )\n        if os.path.exists(self._file_path(\"graph_labels\")):\n            DS_graph_labels = self._idx_from_zero(\n                np.genfromtxt(self._file_path(\"graph_labels\"), dtype=int)\n            )\n            self.num_labels = max(DS_graph_labels) + 1\n            self.graph_labels = DS_graph_labels\n        elif os.path.exists(self._file_path(\"graph_attributes\")):\n            DS_graph_labels = np.genfromtxt(\n                self._file_path(\"graph_attributes\"), dtype=float\n            )\n            self.num_labels = None\n            self.graph_labels = DS_graph_labels\n        else:\n            raise Exception(\"Unknown graph label or graph attributes\")\n\n        g = dgl_graph(([], []))\n        g.add_nodes(int(DS_edge_list.max()) + 1)\n        g.add_edges(DS_edge_list[:, 0], DS_edge_list[:, 1])\n\n        node_idx_list = []\n        self.max_num_node = 0\n        for idx in range(np.max(DS_indicator) + 1):\n            node_idx = np.where(DS_indicator == idx)\n            node_idx_list.append(node_idx[0])\n            if len(node_idx[0]) > self.max_num_node:\n                self.max_num_node = len(node_idx[0])\n\n        self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list]\n\n        try:\n            DS_node_labels = self._idx_from_zero(\n                np.loadtxt(self._file_path(\"node_labels\"), dtype=int)\n            )\n            g.ndata[\"node_label\"] 
= F.tensor(DS_node_labels)\n            one_hot_node_labels = self._to_onehot(DS_node_labels)\n            for idxs, g in zip(node_idx_list, self.graph_lists):\n                g.ndata[\"feat\"] = F.tensor(\n                    one_hot_node_labels[idxs, :], F.float32\n                )\n            self.data_mode = \"node_label\"\n        except IOError:\n            print(\"No Node Label Data\")\n\n        try:\n            DS_node_attr = np.loadtxt(\n                self._file_path(\"node_attributes\"), delimiter=\",\"\n            )\n            if DS_node_attr.ndim == 1:\n                DS_node_attr = np.expand_dims(DS_node_attr, -1)\n            for idxs, g in zip(node_idx_list, self.graph_lists):\n                g.ndata[\"feat\"] = F.tensor(DS_node_attr[idxs, :], F.float32)\n            self.data_mode = \"node_attr\"\n        except IOError:\n            print(\"No Node Attribute Data\")\n\n        if \"feat\" not in g.ndata.keys():\n            for idxs, g in zip(node_idx_list, self.graph_lists):\n                g.ndata[\"feat\"] = F.ones(\n                    (g.num_nodes(), self.hidden_size), F.float32, F.cpu()\n                )\n            self.data_mode = \"constant\"\n            if self.verbose:\n                print(\n                    \"Use Constant one as Feature with hidden size {}\".format(\n                        self.hidden_size\n                    )\n                )\n\n        # remove graphs that are too large by user given standard\n        # optional pre-processing steop in conformity with Rex Ying's original\n        # DiffPool implementation\n        if self.max_allow_node:\n            preserve_idx = []\n            if self.verbose:\n                print(\"original dataset length : \", len(self.graph_lists))\n            for i, g in enumerate(self.graph_lists):\n                if g.num_nodes() <= self.max_allow_node:\n                    preserve_idx.append(i)\n            self.graph_lists = [self.graph_lists[i] for i in 
preserve_idx]\n            if self.verbose:\n                print(\n                    \"after pruning graphs that are too big : \",\n                    len(self.graph_lists),\n                )\n            self.graph_labels = [self.graph_labels[i] for i in preserve_idx]\n            self.max_num_node = self.max_allow_node\n        self.graph_labels = F.tensor(self.graph_labels)\n\n    def save(self):\n        label_dict = {\"labels\": self.graph_labels}\n        info_dict = {\n            \"max_num_node\": self.max_num_node,\n            \"num_labels\": self.num_labels,\n        }\n        save_graphs(str(self.graph_path), self.graph_lists, label_dict)\n        save_info(str(self.info_path), info_dict)\n\n    def load(self):\n        graphs, label_dict = load_graphs(str(self.graph_path))\n        info_dict = load_info(str(self.info_path))\n\n        self.graph_lists = graphs\n        self.graph_labels = label_dict[\"labels\"]\n        self.max_num_node = info_dict[\"max_num_node\"]\n        self.num_labels = info_dict[\"num_labels\"]\n\n    @property\n    def graph_path(self):\n        return os.path.join(\n            self.save_path, \"legacy_tu_{}_{}.bin\".format(self.name, self.hash)\n        )\n\n    @property\n    def info_path(self):\n        return os.path.join(\n            self.save_path, \"legacy_tu_{}_{}.pkl\".format(self.name, self.hash)\n        )\n\n    def has_cache(self):\n        if os.path.exists(self.graph_path) and os.path.exists(self.info_path):\n            return True\n        return False\n\n    def __getitem__(self, idx):\n        \"\"\"Get the idx-th sample.\n\n        Parameters\n        ---------\n        idx : int\n            The sample index.\n\n        Returns\n        -------\n        (:class:`dgl.DGLGraph`, Tensor)\n            Graph with node feature stored in ``feat`` field and node label in ``node_label`` if available.\n            And its label.\n        \"\"\"\n        g = self.graph_lists[idx]\n        if self._transform 
is not None:\n            g = self._transform(g)\n        return g, self.graph_labels[idx]\n\n    def __len__(self):\n        \"\"\"Return the number of graphs in the dataset.\"\"\"\n        return len(self.graph_lists)\n\n    def _file_path(self, category):\n        return os.path.join(\n            self.raw_path, self.name, \"{}_{}.txt\".format(self.name, category)\n        )\n\n    @staticmethod\n    def _idx_from_zero(idx_tensor):\n        return idx_tensor - np.min(idx_tensor)\n\n    @staticmethod\n    def _to_onehot(label_tensor):\n        label_num = label_tensor.shape[0]\n        assert np.min(label_tensor) == 0\n        one_hot_tensor = np.zeros((label_num, np.max(label_tensor) + 1))\n        one_hot_tensor[np.arange(label_num), label_tensor] = 1\n        return one_hot_tensor\n\n    def statistics(self):\n        return (\n            self.graph_lists[0].ndata[\"feat\"].shape[1],\n            self.num_labels,\n            self.max_num_node,\n        )\n\n    @property\n    def num_classes(self):\n        return int(self.num_labels)\n\n\nclass TUDataset(DGLBuiltinDataset):\n    r\"\"\"\n    TUDataset contains lots of graph kernel datasets for graph classification.\n\n    Parameters\n    ----------\n    name : str\n        Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the\n        datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    max_num_node : int\n        Maximum number of nodes\n    num_classes : int\n        Number of classes\n    num_labels : int\n        (DEPRECATED, use num_classes instead) Number of classes\n\n    Notes\n    -----\n    **IMPORTANT:** Some of the datasets have duplicate edges in the graphs, e.g.\n    the edges in ``IMDB-BINARY`` are all duplicated.  DGL faithfully keeps the duplicates\n    as per the original data.  Other frameworks such as PyTorch Geometric remove the\n    duplicates by default.  You can remove the duplicate edges with :func:`dgl.to_simple`.\n\n    Graphs may have node labels, node attributes, edge labels, and edge attributes,\n    varying across datasets.\n\n    Labels are mapped to :math:`\\lbrace 0,\\cdots,n-1 \\rbrace` where :math:`n` is the\n    number of labels (some datasets have raw labels :math:`\\lbrace -1, 1 \\rbrace` which\n    will be mapped to :math:`\\lbrace 0, 1 \\rbrace`). 
In previous versions, the minimum\n    label was added so that :math:`\\lbrace -1, 1 \\rbrace` was mapped to\n    :math:`\\lbrace 0, 2 \\rbrace`.\n\n    The dataset sorts graphs by their labels.\n    Shuffle is preferred before manual train/val split.\n\n    Examples\n    --------\n    >>> data = TUDataset('DD')\n\n    The dataset instance is an iterable\n\n    >>> len(data)\n    1178\n    >>> g, label = data[1024]\n    >>> g\n    Graph(num_nodes=88, num_edges=410,\n          ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'node_labels': Scheme(shape=(1,), dtype=torch.int64)}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> label\n    tensor([1])\n\n    Batch the graphs and labels for mini-batch training\n\n    >>> graphs, labels = zip(*[data[i] for i in range(16)])\n    >>> batched_graphs = dgl.batch(graphs)\n    >>> batched_labels = torch.tensor(labels)\n    >>> batched_graphs\n    Graph(num_nodes=9539, num_edges=47382,\n          ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n\n    \"\"\"\n\n    _url = r\"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip\"\n\n    def __init__(\n        self,\n        name,\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        url = self._url.format(name)\n        super(TUDataset, self).__init__(\n            name=name,\n            url=url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        DS_edge_list = self._idx_from_zero(\n            loadtxt(self._file_path(\"A\"), delimiter=\",\").astype(int)\n        )\n        DS_indicator = self._idx_from_zero(\n            loadtxt(self._file_path(\"graph_indicator\"), delimiter=\",\").astype(\n                int\n    
        )\n        )\n\n        if os.path.exists(self._file_path(\"graph_labels\")):\n            DS_graph_labels = self._idx_reset(\n                loadtxt(self._file_path(\"graph_labels\"), delimiter=\",\").astype(\n                    int\n                )\n            )\n            self.num_labels = int(max(DS_graph_labels) + 1)\n            self.graph_labels = F.tensor(DS_graph_labels)\n        elif os.path.exists(self._file_path(\"graph_attributes\")):\n            DS_graph_labels = loadtxt(\n                self._file_path(\"graph_attributes\"), delimiter=\",\"\n            ).astype(float)\n            self.num_labels = None\n            self.graph_labels = F.tensor(DS_graph_labels)\n        else:\n            raise Exception(\"Unknown graph label or graph attributes\")\n\n        g = dgl_graph(([], []))\n        g.add_nodes(int(DS_edge_list.max()) + 1)\n        g.add_edges(DS_edge_list[:, 0], DS_edge_list[:, 1])\n\n        node_idx_list = []\n        self.max_num_node = 0\n        for idx in range(np.max(DS_indicator) + 1):\n            node_idx = np.where(DS_indicator == idx)\n            node_idx_list.append(node_idx[0])\n            if len(node_idx[0]) > self.max_num_node:\n                self.max_num_node = len(node_idx[0])\n\n        self.attr_dict = {\n            \"node_labels\": (\"ndata\", \"node_labels\"),\n            \"node_attributes\": (\"ndata\", \"node_attr\"),\n            \"edge_labels\": (\"edata\", \"edge_labels\"),\n            \"edge_attributes\": (\"edata\", \"node_labels\"),\n        }\n\n        for filename, field_name in self.attr_dict.items():\n            try:\n                data = loadtxt(self._file_path(filename), delimiter=\",\")\n                if \"label\" in filename:\n                    data = F.tensor(self._idx_from_zero(data))\n                else:\n                    data = F.tensor(data)\n                getattr(g, field_name[0])[field_name[1]] = data\n            except IOError:\n                pass\n\n   
     self.graph_lists = [g.subgraph(node_idx) for node_idx in node_idx_list]\n\n    @property\n    def graph_path(self):\n        return os.path.join(self.save_path, \"tu_{}.bin\".format(self.name))\n\n    @property\n    def info_path(self):\n        return os.path.join(self.save_path, \"tu_{}.pkl\".format(self.name))\n\n    def save(self):\n        label_dict = {\"labels\": self.graph_labels}\n        info_dict = {\n            \"max_num_node\": self.max_num_node,\n            \"num_labels\": self.num_labels,\n        }\n        save_graphs(str(self.graph_path), self.graph_lists, label_dict)\n        save_info(str(self.info_path), info_dict)\n\n    def load(self):\n        graphs, label_dict = load_graphs(str(self.graph_path))\n        info_dict = load_info(str(self.info_path))\n\n        self.graph_lists = graphs\n        self.graph_labels = label_dict[\"labels\"]\n        self.max_num_node = info_dict[\"max_num_node\"]\n        self.num_labels = info_dict[\"num_labels\"]\n\n    def has_cache(self):\n        if os.path.exists(self.graph_path) and os.path.exists(self.info_path):\n            return True\n        return False\n\n    def __getitem__(self, idx):\n        \"\"\"Get the idx-th sample.\n\n        Parameters\n        ---------\n        idx : int\n            The sample index.\n\n        Returns\n        -------\n        (:class:`dgl.DGLGraph`, Tensor)\n            Graph with node feature stored in ``feat`` field and node label in ``node_labels`` if available.\n            And its label.\n        \"\"\"\n        g = self.graph_lists[idx]\n        if self._transform is not None:\n            g = self._transform(g)\n        return g, self.graph_labels[idx]\n\n    def __len__(self):\n        \"\"\"Return the number of graphs in the dataset.\"\"\"\n        return len(self.graph_lists)\n\n    def _file_path(self, category):\n        return os.path.join(\n            self.raw_path, self.name, \"{}_{}.txt\".format(self.name, category)\n        )\n\n    
@staticmethod\n    def _idx_from_zero(idx_tensor):\n        return idx_tensor - np.min(idx_tensor)\n\n    @staticmethod\n    def _idx_reset(idx_tensor):\n        \"\"\"Maps n unique labels to {0, ..., n-1} in an ordered fashion.\"\"\"\n        labels = np.unique(idx_tensor)\n        relabel_map = {x: i for i, x in enumerate(labels)}\n        new_idx_tensor = np.vectorize(relabel_map.get)(idx_tensor)\n        return new_idx_tensor\n\n    def statistics(self):\n        return (\n            self.graph_lists[0].ndata[\"feat\"].shape[1],\n            self.num_labels,\n            self.max_num_node,\n        )\n\n    @property\n    def num_classes(self):\n        return self.num_labels\n"
  },
  {
    "path": "python/dgl/data/utils.py",
    "content": "\"\"\"Dataset utilities.\"\"\"\nfrom __future__ import absolute_import\n\nimport errno\nimport hashlib\nimport os\nimport pickle\nimport sys\nimport warnings\n\nimport networkx.algorithms as A\n\nimport numpy as np\nimport requests\nfrom tqdm.auto import tqdm\n\nfrom .. import backend as F\nfrom .graph_serialize import load_graphs, load_labels, save_graphs\nfrom .tensor_serialize import load_tensors, save_tensors\n\n__all__ = [\n    \"loadtxt\",\n    \"download\",\n    \"check_sha1\",\n    \"extract_archive\",\n    \"get_download_dir\",\n    \"Subset\",\n    \"split_dataset\",\n    \"save_graphs\",\n    \"load_graphs\",\n    \"load_labels\",\n    \"save_tensors\",\n    \"load_tensors\",\n    \"add_nodepred_split\",\n    \"add_node_property_split\",\n    \"mask_nodes_by_property\",\n]\n\n\ndef loadtxt(path, delimiter, dtype=None):\n    try:\n        import pandas as pd\n\n        df = pd.read_csv(path, delimiter=delimiter, header=None)\n        return df.values\n    except ImportError:\n        warnings.warn(\n            \"Pandas is not installed, now using numpy.loadtxt to load data, \"\n            \"which could be extremely slow. 
Accelerate by installing pandas\"\n        )\n        return np.loadtxt(path, delimiter=delimiter)\n\n\ndef _get_dgl_url(file_url):\n    \"\"\"Get DGL online url for download.\"\"\"\n    dgl_repo_url = \"https://data.dgl.ai/\"\n    repo_url = os.environ.get(\"DGL_REPO\", dgl_repo_url)\n    if repo_url[-1] != \"/\":\n        repo_url = repo_url + \"/\"\n    return repo_url + file_url\n\n\ndef split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):\n    \"\"\"Split dataset into training, validation and test set.\n\n    Parameters\n    ----------\n    dataset\n        We assume ``len(dataset)`` gives the number of datapoints and ``dataset[i]``\n        gives the ith datapoint.\n    frac_list : list or None, optional\n        A list of length 3 containing the fraction to use for training,\n        validation and test. If None, we will use [0.8, 0.1, 0.1].\n    shuffle : bool, optional\n        By default we perform a consecutive split of the dataset. If True,\n        we will first randomly shuffle the dataset.\n    random_state : None, int or array_like, optional\n        Random seed used to initialize the pseudo-random number generator.\n        Can be any integer between 0 and 2**32 - 1 inclusive, an array\n        (or other sequence) of such integers, or None (the default).\n        If seed is None, then RandomState will try to read data from /dev/urandom\n        (or the Windows analogue) if available or seed from the clock otherwise.\n\n    Returns\n    -------\n    list of length 3\n        Subsets for training, validation and test.\n    \"\"\"\n    from itertools import accumulate\n\n    if frac_list is None:\n        frac_list = [0.8, 0.1, 0.1]\n    frac_list = np.asarray(frac_list)\n    assert np.allclose(\n        np.sum(frac_list), 1.0\n    ), \"Expect frac_list sum to 1, got {:.4f}\".format(np.sum(frac_list))\n    num_data = len(dataset)\n    lengths = (num_data * frac_list).astype(int)\n    lengths[-1] = num_data - np.sum(lengths[:-1])\n 
   if shuffle:\n        indices = np.random.RandomState(seed=random_state).permutation(num_data)\n    else:\n        indices = np.arange(num_data)\n    return [\n        Subset(dataset, indices[offset - length : offset])\n        for offset, length in zip(accumulate(lengths), lengths)\n    ]\n\n\ndef download(\n    url,\n    path=None,\n    overwrite=True,\n    sha1_hash=None,\n    retries=5,\n    verify_ssl=True,\n    log=True,\n):\n    \"\"\"Download a given URL.\n\n    Codes borrowed from mxnet/gluon/utils.py\n\n    Parameters\n    ----------\n    url : str\n        URL to download.\n    path : str, optional\n        Destination path to store downloaded file. By default stores to the\n        current directory with the same name as in url.\n    overwrite : bool, optional\n        Whether to overwrite the destination file if it already exists.\n        By default always overwrites the downloaded file.\n    sha1_hash : str, optional\n        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified\n        but doesn't match.\n    retries : integer, default 5\n        The number of times to attempt downloading in case of failure or non 200 return codes.\n    verify_ssl : bool, default True\n        Verify SSL certificates.\n    log : bool, default True\n        Whether to print the progress for download\n\n    Returns\n    -------\n    str\n        The file path of the downloaded file.\n    \"\"\"\n    if path is None:\n        fname = url.split(\"/\")[-1]\n        # Empty filenames are invalid\n        assert fname, (\n            \"Can't construct file-name from this URL. 
\"\n            \"Please set the `path` option manually.\"\n        )\n    else:\n        path = os.path.expanduser(path)\n        if os.path.isdir(path):\n            fname = os.path.join(path, url.split(\"/\")[-1])\n        else:\n            fname = path\n    assert retries >= 0, \"Number of retries should be at least 0\"\n\n    if not verify_ssl:\n        warnings.warn(\n            \"Unverified HTTPS request is being made (verify_ssl=False). \"\n            \"Adding certificate verification is strongly advised.\"\n        )\n\n    if (\n        overwrite\n        or not os.path.exists(fname)\n        or (sha1_hash and not check_sha1(fname, sha1_hash))\n    ):\n        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))\n        if not os.path.exists(dirname):\n            os.makedirs(dirname)\n        while retries + 1 > 0:\n            # Disable pyling too broad Exception\n            # pylint: disable=W0703\n            try:\n                if log:\n                    print(\"Downloading %s from %s...\" % (fname, url))\n                r = requests.get(url, stream=True, verify=verify_ssl)\n                if r.status_code != 200:\n                    raise RuntimeError(\"Failed downloading url %s\" % url)\n                # Get the total file size.\n                total_size = int(r.headers.get(\"content-length\", 0))\n                with tqdm(\n                    total=total_size, unit=\"B\", unit_scale=True, desc=fname\n                ) as bar:\n                    with open(fname, \"wb\") as f:\n                        for chunk in r.iter_content(chunk_size=1024):\n                            if chunk:  # filter out keep-alive new chunks\n                                f.write(chunk)\n                                bar.update(len(chunk))\n                if sha1_hash and not check_sha1(fname, sha1_hash):\n                    raise UserWarning(\n                        \"File {} is downloaded but the content hash does not 
match.\"\n                        \" The repo may be outdated or download may be incomplete. \"\n                        'If the \"repo_url\" is overridden, consider switching to '\n                        \"the default repo.\".format(fname)\n                    )\n                break\n            except Exception as e:\n                retries -= 1\n                if retries <= 0:\n                    raise e\n                else:\n                    if log:\n                        print(\n                            \"download failed, retrying, {} attempt{} left\".format(\n                                retries, \"s\" if retries > 1 else \"\"\n                            )\n                        )\n\n    return fname\n\n\ndef check_sha1(filename, sha1_hash):\n    \"\"\"Check whether the sha1 hash of the file content matches the expected hash.\n\n    Codes borrowed from mxnet/gluon/utils.py\n\n    Parameters\n    ----------\n    filename : str\n        Path to the file.\n    sha1_hash : str\n        Expected sha1 hash in hexadecimal digits.\n\n    Returns\n    -------\n    bool\n        Whether the file content matches the expected hash.\n    \"\"\"\n    sha1 = hashlib.sha1()\n    with open(filename, \"rb\") as f:\n        while True:\n            data = f.read(1048576)\n            if not data:\n                break\n            sha1.update(data)\n\n    return sha1.hexdigest() == sha1_hash\n\n\ndef extract_archive(file, target_dir, overwrite=True):\n    \"\"\"Extract archive file.\n\n    Parameters\n    ----------\n    file : str\n        Absolute path of the archive file.\n    target_dir : str\n        Target directory of the archive to be uncompressed.\n    overwrite : bool, default True\n        Whether to overwrite the contents inside the directory.\n        By default always overwrites.\n    \"\"\"\n    if os.path.exists(target_dir) and not overwrite:\n        return\n    print(\"Extracting file to {}\".format(target_dir))\n    if (\n        
file.endswith(\".tar.gz\")\n        or file.endswith(\".tar\")\n        or file.endswith(\".tgz\")\n    ):\n        import tarfile\n\n        with tarfile.open(file, \"r\") as archive:\n\n            def is_within_directory(directory, target):\n                abs_directory = os.path.abspath(directory)\n                abs_target = os.path.abspath(target)\n                prefix = os.path.commonprefix([abs_directory, abs_target])\n                return prefix == abs_directory\n\n            def safe_extract(\n                tar, path=\".\", members=None, *, numeric_owner=False\n            ):\n                for member in tar.getmembers():\n                    member_path = os.path.join(path, member.name)\n                    if not is_within_directory(path, member_path):\n                        raise Exception(\"Attempted Path Traversal in Tar File\")\n                tar.extractall(path, members, numeric_owner=numeric_owner)\n\n            safe_extract(archive, path=target_dir)\n    elif file.endswith(\".gz\"):\n        import gzip\n        import shutil\n\n        with gzip.open(file, \"rb\") as f_in:\n            target_file = os.path.join(target_dir, os.path.basename(file)[:-3])\n            with open(target_file, \"wb\") as f_out:\n                shutil.copyfileobj(f_in, f_out)\n    elif file.endswith(\".zip\"):\n        import zipfile\n\n        with zipfile.ZipFile(file, \"r\") as archive:\n            archive.extractall(path=target_dir)\n    else:\n        raise Exception(\"Unrecognized file type: \" + file)\n\n\ndef get_download_dir():\n    \"\"\"Get the absolute path to the download directory.\n\n    Returns\n    -------\n    dirname : str\n        Path to the download directory\n    \"\"\"\n    default_dir = os.path.join(os.path.expanduser(\"~\"), \".dgl\")\n    dirname = os.environ.get(\"DGL_DOWNLOAD_DIR\", default_dir)\n    if not os.path.exists(dirname):\n        os.makedirs(dirname)\n    return dirname\n\n\ndef makedirs(path):\n    try:\n       
 os.makedirs(os.path.expanduser(os.path.normpath(path)))\n    except OSError as e:\n        if e.errno != errno.EEXIST and os.path.isdir(path):\n            raise e\n\n\ndef save_info(path, info):\n    \"\"\"Save dataset related information into disk.\n\n    Parameters\n    ----------\n    path : str\n        File to save information.\n    info : dict\n        A python dict storing information to save on disk.\n    \"\"\"\n    with open(path, \"wb\") as pf:\n        pickle.dump(info, pf)\n\n\ndef load_info(path):\n    \"\"\"Load dataset related information from disk.\n\n    Parameters\n    ----------\n    path : str\n        File to load information from.\n\n    Returns\n    -------\n    info : dict\n        A python dict storing information loaded from disk.\n    \"\"\"\n    with open(path, \"rb\") as pf:\n        info = pickle.load(pf)\n    return info\n\n\ndef deprecate_property(old, new):\n    warnings.warn(\n        \"Property {} will be deprecated, please use {} instead.\".format(\n            old, new\n        )\n    )\n\n\ndef deprecate_function(old, new):\n    warnings.warn(\n        \"Function {} will be deprecated, please use {} instead.\".format(\n            old, new\n        )\n    )\n\n\ndef deprecate_class(old, new):\n    warnings.warn(\n        \"Class {} will be deprecated, please use {} instead.\".format(old, new)\n    )\n\n\ndef idx2mask(idx, len):\n    \"\"\"Create mask.\"\"\"\n    mask = np.zeros(len)\n    mask[idx] = 1\n    return mask\n\n\ndef generate_mask_tensor(mask):\n    \"\"\"Generate mask tensor according to different backend\n    For torch and tensorflow, it will create a bool tensor\n    For mxnet, it will create a float tensor\n    Parameters\n    ----------\n    mask: numpy ndarray\n        input mask tensor\n    \"\"\"\n    assert isinstance(mask, np.ndarray), (\n        \"input for generate_mask_tensor\" \"should be an numpy ndarray\"\n    )\n    if F.backend_name == \"mxnet\":\n        return F.tensor(mask, 
dtype=F.data_type_dict[\"float32\"])\n    else:\n        return F.tensor(mask, dtype=F.data_type_dict[\"bool\"])\n\n\nclass Subset(object):\n    \"\"\"Subset of a dataset at specified indices\n\n    Code adapted from PyTorch.\n\n    Parameters\n    ----------\n    dataset\n        dataset[i] should return the ith datapoint\n    indices : list\n        List of datapoint indices to construct the subset\n    \"\"\"\n\n    def __init__(self, dataset, indices):\n        self.dataset = dataset\n        self.indices = indices\n\n    def __getitem__(self, item):\n        \"\"\"Get the datapoint indexed by item\n\n        Returns\n        -------\n        tuple\n            datapoint\n        \"\"\"\n        return self.dataset[self.indices[item]]\n\n    def __len__(self):\n        \"\"\"Get subset size\n\n        Returns\n        -------\n        int\n            Number of datapoints in the subset\n        \"\"\"\n        return len(self.indices)\n\n\ndef add_nodepred_split(dataset, ratio, ntype=None):\n    \"\"\"Split the given dataset into training, validation and test sets for\n    transductive node predction task.\n\n    It adds three node mask arrays ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``,\n    to each graph in the dataset. Each sample in the dataset thus must be a :class:`DGLGraph`.\n\n    Fix the random seed of NumPy to make the result deterministic::\n\n        numpy.random.seed(42)\n\n    Parameters\n    ----------\n    dataset : DGLDataset\n        The dataset to modify.\n    ratio : (float, float, float)\n        Split ratios for training, validation and test sets. 
Must sum to one.\n    ntype : str, optional\n        The node type to add mask for.\n\n    Examples\n    --------\n    >>> dataset = dgl.data.AmazonCoBuyComputerDataset()\n    >>> print('train_mask' in dataset[0].ndata)\n    False\n    >>> dgl.data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])\n    >>> print('train_mask' in dataset[0].ndata)\n    True\n    \"\"\"\n    if len(ratio) != 3:\n        raise ValueError(\n            f\"Split ratio must be a float triplet but got {ratio}.\"\n        )\n    for i in range(len(dataset)):\n        g = dataset[i]\n        n = g.num_nodes(ntype)\n        idx = np.arange(0, n)\n        np.random.shuffle(idx)\n        n_train, n_val, n_test = (\n            int(n * ratio[0]),\n            int(n * ratio[1]),\n            int(n * ratio[2]),\n        )\n        train_mask = generate_mask_tensor(idx2mask(idx[:n_train], n))\n        val_mask = generate_mask_tensor(\n            idx2mask(idx[n_train : n_train + n_val], n)\n        )\n        test_mask = generate_mask_tensor(idx2mask(idx[n_train + n_val :], n))\n        g.nodes[ntype].data[\"train_mask\"] = train_mask\n        g.nodes[ntype].data[\"val_mask\"] = val_mask\n        g.nodes[ntype].data[\"test_mask\"] = test_mask\n\n\ndef mask_nodes_by_property(property_values, part_ratios, random_seed=None):\n    \"\"\"Provide the split masks for a node split with distributional shift based on a given\n    node property, as proposed in `Evaluating Robustness and Uncertainty of Graph Models\n    Under Structural Distributional Shifts <https://arxiv.org/abs/2302.13875>`__\n\n    It considers the in-distribution (ID) and out-of-distribution (OOD) subsets of nodes.\n    The ID subset includes training, validation and testing parts, while the OOD subset\n    includes validation and testing parts. 
It sorts the nodes in the ascending order of\n    their property values, splits them into 5 non-intersecting parts, and creates 5\n    associated node mask arrays:\n        - 3 for the ID nodes: ``'in_train_mask'``, ``'in_valid_mask'``, ``'in_test_mask'``,\n        - and 2 for the OOD nodes: ``'out_valid_mask'``, ``'out_test_mask'``.\n\n    Parameters\n    ----------\n    property_values : numpy ndarray\n        The node property (float) values by which the dataset will be split.\n        The length of the array must be equal to the number of nodes in graph.\n    part_ratios : list\n        A list of 5 ratios for training, ID validation, ID test,\n        OOD validation, OOD testing parts. The values in the list must sum to one.\n    random_seed : int, optional\n        Random seed to fix for the initial permutation of nodes. It is\n        used to create a random order for the nodes that have the same\n        property values or belong to the ID subset. (default: None)\n\n    Returns\n    ----------\n    split_masks : dict\n        A python dict storing the mask names as keys and the corresponding\n        node mask arrays as values.\n\n    Examples\n    --------\n    >>> num_nodes = 1000\n    >>> property_values = np.random.uniform(size=num_nodes)\n    >>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]\n    >>> split_masks = dgl.data.utils.mask_nodes_by_property(property_values, part_ratios)\n    >>> print('in_valid_mask' in split_masks)\n    True\n    \"\"\"\n\n    num_nodes = len(property_values)\n    part_sizes = np.round(num_nodes * np.array(part_ratios)).astype(int)\n    part_sizes[-1] -= np.sum(part_sizes) - num_nodes\n\n    generator = np.random.RandomState(random_seed)\n    permutation = generator.permutation(num_nodes)\n\n    node_indices = np.arange(num_nodes)[permutation]\n    property_values = property_values[permutation]\n    in_distribution_size = np.sum(part_sizes[:3])\n\n    node_indices_ordered = node_indices[np.argsort(property_values)]\n    
node_indices_ordered[:in_distribution_size] = generator.permutation(\n        node_indices_ordered[:in_distribution_size]\n    )\n\n    sections = np.cumsum(part_sizes)\n    node_split = np.split(node_indices_ordered, sections)[:-1]\n    mask_names = [\n        \"in_train_mask\",\n        \"in_valid_mask\",\n        \"in_test_mask\",\n        \"out_valid_mask\",\n        \"out_test_mask\",\n    ]\n    split_masks = {}\n\n    for mask_name, node_indices in zip(mask_names, node_split):\n        split_mask = idx2mask(node_indices, num_nodes)\n        split_masks[mask_name] = generate_mask_tensor(split_mask)\n\n    return split_masks\n\n\ndef add_node_property_split(\n    dataset, part_ratios, property_name, ascending=True, random_seed=None\n):\n    \"\"\"Create a node split with distributional shift based on a given node property,\n    as proposed in `Evaluating Robustness and Uncertainty of Graph Models Under\n    Structural Distributional Shifts <https://arxiv.org/abs/2302.13875>`__\n\n    It splits the nodes of each graph in the given dataset into 5 non-intersecting\n    parts based on their structural properties. This can be used for transductive node\n    prediction task with distributional shifts.\n\n    It considers the in-distribution (ID) and out-of-distribution (OOD) subsets of nodes.\n    The ID subset includes training, validation and testing parts, while the OOD subset\n    includes validation and testing parts. 
As a result, it creates 5 associated node mask\n    arrays for each graph:\n        - 3 for the ID nodes: ``'in_train_mask'``, ``'in_valid_mask'``, ``'in_test_mask'``,\n        - and 2 for the OOD nodes: ``'out_valid_mask'``, ``'out_test_mask'``.\n\n    This function implements 3 particular strategies for inducing distributional shifts\n    in graph — based on **popularity**, **locality** or **density**.\n\n    Parameters\n    ----------\n    dataset : :class:`~DGLDataset` or list of :class:`~dgl.DGLGraph`\n        The dataset to induce structural distributional shift.\n    part_ratios : list\n        A list of 5 ratio values for training, ID validation, ID test,\n        OOD validation and OOD test parts. The values must sum to 1.0.\n    property_name : str\n        The name of the node property to be used, which must be\n        ``'popularity'``, ``'locality'`` or ``'density'``.\n    ascending : bool, optional\n        Whether to sort nodes in the ascending order of the node property,\n        so that nodes with greater values of the property are considered\n        to be OOD (default: True)\n    random_seed : int, optional\n        Random seed to fix for the initial permutation of nodes. It is\n        used to create a random order for the nodes that have the same\n        property values or belong to the ID subset. 
(default: None)\n\n    Examples\n    --------\n    >>> dataset = dgl.data.AmazonCoBuyComputerDataset()\n    >>> print('in_valid_mask' in dataset[0].ndata)\n    False\n    >>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]\n    >>> property_name = 'popularity'\n    >>> dgl.data.utils.add_node_property_split(dataset, part_ratios, property_name)\n    >>> print('in_valid_mask' in dataset[0].ndata)\n    True\n    \"\"\"\n\n    assert property_name in [\n        \"popularity\",\n        \"locality\",\n        \"density\",\n    ], \"The name of property has to be 'popularity', 'locality', or 'density'\"\n\n    assert len(part_ratios) == 5, \"part_ratios must contain 5 values\"\n\n    import networkx as nx\n\n    for idx in range(len(dataset)):\n        graph_dgl = dataset[idx]\n        graph_nx = nx.Graph(graph_dgl.to_networkx())\n\n        compute_property_fn = _property_name_to_compute_fn[property_name]\n        property_values = compute_property_fn(graph_nx, ascending)\n\n        node_masks = mask_nodes_by_property(\n            property_values, part_ratios, random_seed\n        )\n\n        for mask_name, node_mask in node_masks.items():\n            graph_dgl.ndata[mask_name] = node_mask\n\n\ndef _compute_popularity_property(graph_nx, ascending=True):\n    direction = -1 if ascending else 1\n    property_values = direction * np.array(list(A.pagerank(graph_nx).values()))\n    return property_values\n\n\ndef _compute_locality_property(graph_nx, ascending=True):\n    num_nodes = graph_nx.number_of_nodes()\n    pagerank_values = np.array(list(A.pagerank(graph_nx).values()))\n\n    personalization = dict(zip(range(num_nodes), [0.0] * num_nodes))\n    personalization[np.argmax(pagerank_values)] = 1.0\n\n    direction = -1 if ascending else 1\n    property_values = direction * np.array(\n        list(A.pagerank(graph_nx, personalization=personalization).values())\n    )\n    return property_values\n\n\ndef _compute_density_property(graph_nx, ascending=True):\n    direction = -1 if 
ascending else 1\n    property_values = direction * np.array(\n        list(A.clustering(graph_nx).values())\n    )\n    return property_values\n\n\n_property_name_to_compute_fn = {\n    \"popularity\": _compute_popularity_property,\n    \"locality\": _compute_locality_property,\n    \"density\": _compute_density_property,\n}\n"
  },
  {
    "path": "python/dgl/data/wikics.py",
    "content": "\"\"\"Wiki-CS Dataset\"\"\"\nimport itertools\nimport json\nimport os\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..convert import graph\nfrom ..transforms import reorder_graph, to_bidirected\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs\n\n\nclass WikiCSDataset(DGLBuiltinDataset):\n    r\"\"\"Wiki-CS is a Wikipedia-based dataset for node classification from `Wiki-CS: A Wikipedia-Based\n    Benchmark for Graph Neural Networks <https://arxiv.org/abs/2007.02901v2>`_\n\n    The dataset consists of nodes corresponding to Computer Science articles, with edges based on\n    hyperlinks and 10 classes representing different branches of the field.\n\n    WikiCS dataset statistics:\n\n    - Nodes: 11,701\n    - Edges: 431,726 (note that the original dataset has 216,123 edges but DGL adds\n      the reverse edges and removes the duplicate edges, hence with a different number)\n    - Number of classes: 10\n    - Node feature size: 300\n    - Number of different train, validation, stopping splits: 20\n    - Number of test split: 1\n\n    Parameters\n    ----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False\n    verbose : bool\n        Whether to print out progress information.\n        Default: False\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Examples\n    --------\n    >>> from dgl.data import WikiCSDataset\n    >>> dataset = WikiCSDataset()\n    >>> dataset.num_classes\n    10\n    >>> g = dataset[0]\n    >>> # get node feature\n    >>> feat = g.ndata['feat']\n    >>> # get node labels\n    >>> labels = g.ndata['label']\n    >>> # get data split\n    >>> train_mask = g.ndata['train_mask']\n    >>> val_mask = g.ndata['val_mask']\n    >>> stopping_mask = g.ndata['stopping_mask']\n    >>> test_mask = g.ndata['test_mask']\n    >>> # The shape of train, val and stopping masks are (num_nodes, num_splits).\n    >>> # The num_splits is the number of different train, validation, stopping splits.\n    >>> # Due to the number of test spilt is 1, the shape of test mask is (num_nodes,).\n    >>> print(train_mask.shape, val_mask.shape, stopping_mask.shape)\n    (11701, 20) (11701, 20) (11701, 20)\n    >>> print(test_mask.shape)\n    (11701,)\n    \"\"\"\n\n    def __init__(\n        self, raw_dir=None, force_reload=False, verbose=False, transform=None\n    ):\n        _url = _get_dgl_url(\"dataset/wiki_cs.zip\")\n        super(WikiCSDataset, self).__init__(\n            name=\"wiki_cs\",\n            raw_dir=raw_dir,\n            url=_url,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        \"\"\"process raw data to graph, labels and masks\"\"\"\n        with open(os.path.join(self.raw_path, \"data.json\")) as f:\n            data = json.load(f)\n        features = F.tensor(np.array(data[\"features\"]), dtype=F.float32)\n        labels = F.tensor(np.array(data[\"labels\"]), dtype=F.int64)\n\n        train_masks = np.array(data[\"train_masks\"], dtype=bool).T\n        val_masks = np.array(data[\"val_masks\"], dtype=bool).T\n        
stopping_masks = np.array(data[\"stopping_masks\"], dtype=bool).T\n        test_mask = np.array(data[\"test_mask\"], dtype=bool)\n\n        edges = [[(i, j) for j in js] for i, js in enumerate(data[\"links\"])]\n        edges = np.array(list(itertools.chain(*edges)))\n        src, dst = edges[:, 0], edges[:, 1]\n\n        g = graph((src, dst))\n        g = to_bidirected(g)\n\n        g.ndata[\"feat\"] = features\n        g.ndata[\"label\"] = labels\n        g.ndata[\"train_mask\"] = generate_mask_tensor(train_masks)\n        g.ndata[\"val_mask\"] = generate_mask_tensor(val_masks)\n        g.ndata[\"stopping_mask\"] = generate_mask_tensor(stopping_masks)\n        g.ndata[\"test_mask\"] = generate_mask_tensor(test_mask)\n\n        g = reorder_graph(\n            g,\n            node_permute_algo=\"rcmk\",\n            edge_permute_algo=\"dst\",\n            store_ids=False,\n        )\n\n        self._graph = g\n\n    def has_cache(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        return os.path.exists(graph_path)\n\n    def save(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        save_graphs(graph_path, self._graph)\n\n    def load(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        g, _ = load_graphs(graph_path)\n        self._graph = g[0]\n\n    @property\n    def num_classes(self):\n        return 10\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\"\"\"\n        return 1\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph object\n\n        Parameters\n        ----------\n        idx : int\n            Item index, WikiCSDataset has only one graph object\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains:\n\n            - ``ndata['feat']``: node features\n            - ``ndata['label']``: node labels\n            - ``ndata['train_mask']``: train mask is for retrieving the 
nodes for training.\n            - ``ndata['val_mask']``: val mask is for retrieving the nodes for hyperparameter tuning.\n            - ``ndata['stopping_mask']``: stopping mask is for retrieving the nodes for early stopping criterion.\n            - ``ndata['test_mask']``: test mask is for retrieving the nodes for testing.\n\n        \"\"\"\n        assert idx == 0, \"This dataset has only one graph\"\n        if self._transform is None:\n            return self._graph\n        else:\n            return self._transform(self._graph)\n"
  },
  {
    "path": "python/dgl/data/yelp.py",
    "content": "\"\"\"Yelp Dataset\"\"\"\nimport json\nimport os\n\nimport numpy as np\nimport scipy.sparse as sp\n\nfrom .. import backend as F\nfrom ..convert import from_scipy\nfrom ..transforms import reorder_graph\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, generate_mask_tensor, load_graphs, save_graphs\n\n\nclass YelpDataset(DGLBuiltinDataset):\n    r\"\"\"Yelp dataset for node classification from `GraphSAINT: Graph Sampling Based Inductive\n    Learning Method <https://arxiv.org/abs/1907.04931>`_\n\n    The task of this dataset is categorizing types of businesses based on customer reviewers and\n    friendship.\n\n    Yelp dataset statistics:\n\n    - Nodes: 716,847\n    - Edges: 13,954,819\n    - Number of classes: 100 (Multi-class)\n    - Node feature size: 300\n\n    Parameters\n    ----------\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: ~/.dgl/\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False\n    verbose : bool\n        Whether to print out progress information.\n        Default: False\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n    reorder : bool\n        Whether to reorder the graph using :func:`~dgl.reorder_graph`.\n        Default: False.\n\n    Attributes\n    ----------\n    num_classes : int\n        Number of node classes\n\n    Examples\n    --------\n    >>> dataset = YelpDataset()\n    >>> dataset.num_classes\n    100\n    >>> g = dataset[0]\n    >>> # get node feature\n    >>> feat = g.ndata['feat']\n    >>> # get node labels\n    >>> labels = g.ndata['label']\n    >>> # get data split\n    >>> train_mask = g.ndata['train_mask']\n    >>> val_mask = g.ndata['val_mask']\n    >>> test_mask = g.ndata['test_mask']\n    \"\"\"\n\n    def __init__(\n        self,\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n        reorder=False,\n    ):\n        _url = _get_dgl_url(\"dataset/yelp.zip\")\n        self._reorder = reorder\n        super(YelpDataset, self).__init__(\n            name=\"yelp\",\n            raw_dir=raw_dir,\n            url=_url,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        \"\"\"process raw data to graph, labels and masks\"\"\"\n        coo_adj = sp.load_npz(os.path.join(self.raw_path, \"adj_full.npz\"))\n        g = from_scipy(coo_adj)\n\n        features = np.load(os.path.join(self.raw_path, \"feats.npy\"))\n        features = F.tensor(features, dtype=F.float32)\n\n        y = [-1] * features.shape[0]\n        with open(os.path.join(self.raw_path, \"class_map.json\")) as f:\n            class_map = json.load(f)\n            for key, item in class_map.items():\n                y[int(key)] = item\n        labels = F.tensor(np.array(y), dtype=F.int64)\n\n        with open(os.path.join(self.raw_path, \"role.json\")) as f:\n            role = json.load(f)\n\n        train_mask = np.zeros(features.shape[0], dtype=bool)\n        
train_mask[role[\"tr\"]] = True\n\n        val_mask = np.zeros(features.shape[0], dtype=bool)\n        val_mask[role[\"va\"]] = True\n\n        test_mask = np.zeros(features.shape[0], dtype=bool)\n        test_mask[role[\"te\"]] = True\n\n        g.ndata[\"feat\"] = features\n        g.ndata[\"label\"] = labels\n        g.ndata[\"train_mask\"] = generate_mask_tensor(train_mask)\n        g.ndata[\"val_mask\"] = generate_mask_tensor(val_mask)\n        g.ndata[\"test_mask\"] = generate_mask_tensor(test_mask)\n\n        if self._reorder:\n            self._graph = reorder_graph(\n                g,\n                node_permute_algo=\"rcmk\",\n                edge_permute_algo=\"dst\",\n                store_ids=False,\n            )\n        else:\n            self._graph = g\n\n    def has_cache(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        return os.path.exists(graph_path)\n\n    def save(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        save_graphs(graph_path, self._graph)\n\n    def load(self):\n        graph_path = os.path.join(self.save_path, \"dgl_graph.bin\")\n        g, _ = load_graphs(graph_path)\n        self._graph = g[0]\n\n    @property\n    def num_classes(self):\n        return 100\n\n    def __len__(self):\n        r\"\"\"The number of graphs in the dataset.\"\"\"\n        return 1\n\n    def __getitem__(self, idx):\n        r\"\"\"Get graph object\n\n        Parameters\n        ----------\n        idx : int\n            Item index, YelpDataset has only one graph object\n\n        Returns\n        -------\n        :class:`dgl.DGLGraph`\n\n            The graph contains:\n\n            - ``ndata['label']``: node label\n            - ``ndata['feat']``: node feature\n            - ``ndata['train_mask']``: mask for training node set\n            - ``ndata['val_mask']``: mask for validation node set\n            - ``ndata['test_mask']``: mask for test node set\n\n        \"\"\"\n 
       assert idx == 0, \"This dataset has only one graph\"\n        if self._transform is None:\n            return self._graph\n        else:\n            return self._transform(self._graph)\n"
  },
  {
    "path": "python/dgl/data/zinc.py",
    "content": "import os\n\nfrom .dgl_dataset import DGLBuiltinDataset\nfrom .utils import _get_dgl_url, load_graphs\n\n\nclass ZINCDataset(DGLBuiltinDataset):\n    r\"\"\"ZINC dataset for the graph regression task.\n\n    A subset (12K) of ZINC molecular graphs (250K) dataset is used to\n    regress a molecular property known as the constrained solubility.\n    For each molecular graph, the node features are the types of heavy\n    atoms, between which the edge features are the types of bonds.\n    Each graph contains 9-37 nodes and 16-84 edges.\n\n    Reference `<https://arxiv.org/pdf/2003.00982.pdf>`_\n\n    Statistics:\n\n    Train examples: 10,000\n    Valid examples: 1,000\n    Test examples: 1,000\n    Average number of nodes: 23.16\n    Average number of edges: 39.83\n    Number of atom types: 28\n    Number of bond types: 4\n\n    Parameters\n    ----------\n    mode : str, optional\n        Should be chosen from [\"train\", \"valid\", \"test\"]\n        Default: \"train\".\n    raw_dir : str\n        Raw file directory to download/contains the input data directory.\n        Default: \"~/.dgl/\".\n    force_reload : bool\n        Whether to reload the dataset.\n        Default: False.\n    verbose : bool\n        Whether to print out progress information.\n        Default: False.\n    transform : callable, optional\n        A transform that takes in a :class:`~dgl.DGLGraph` object and returns\n        a transformed version. 
The :class:`~dgl.DGLGraph` object will be\n        transformed before every access.\n\n    Attributes\n    ----------\n    num_atom_types : int\n        Number of atom types.\n    num_bond_types : int\n        Number of bond types.\n\n    Examples\n    ---------\n    >>> from dgl.data import ZINCDataset\n\n    >>> training_set = ZINCDataset(mode=\"train\")\n    >>> training_set.num_atom_types\n    28\n    >>> len(training_set)\n    10000\n    >>> graph, label = training_set[0]\n    >>> graph\n    Graph(num_nodes=29, num_edges=64,\n        ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}\n        edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)})\n    \"\"\"\n\n    def __init__(\n        self,\n        mode=\"train\",\n        raw_dir=None,\n        force_reload=False,\n        verbose=False,\n        transform=None,\n    ):\n        self._url = _get_dgl_url(\"dataset/ZINC12k.zip\")\n        self.mode = mode\n\n        super(ZINCDataset, self).__init__(\n            name=\"zinc\",\n            url=self._url,\n            raw_dir=raw_dir,\n            force_reload=force_reload,\n            verbose=verbose,\n            transform=transform,\n        )\n\n    def process(self):\n        self.load()\n\n    @property\n    def graph_path(self):\n        return os.path.join(self.save_path, \"ZincDGL_{}.bin\".format(self.mode))\n\n    def has_cache(self):\n        return os.path.exists(self.graph_path)\n\n    def load(self):\n        self._graphs, self._labels = load_graphs(self.graph_path)\n\n    @property\n    def num_atom_types(self):\n        return 28\n\n    @property\n    def num_bond_types(self):\n        return 4\n\n    def __len__(self):\n        return len(self._graphs)\n\n    def __getitem__(self, idx):\n        r\"\"\"Get one example by index.\n\n        Parameters\n        ----------\n        idx : int\n            The sample index.\n\n        Returns\n        -------\n        dgl.DGLGraph\n            Each graph contains:\n\n            - 
``ndata['feat']``: Types of heavy atoms as node features\n            - ``edata['feat']``: Types of bonds as edge features\n\n        Tensor\n            Constrained solubility as graph label\n        \"\"\"\n        labels = self._labels[\"g_label\"]\n        if self._transform is None:\n            return self._graphs[idx], labels[idx]\n        else:\n            return self._transform(self._graphs[idx]), labels[idx]\n"
  },
  {
    "path": "python/dgl/dataloading/__init__.py",
    "content": "\"\"\"Package for dataloaders and samplers.\"\"\"\n\nfrom .. import backend as F\nfrom . import negative_sampler\nfrom .base import *\nfrom .cluster_gcn import *\nfrom .graphsaint import *\nfrom .labor_sampler import *\nfrom .neighbor_sampler import *\nfrom .shadow import *\n\nif F.get_preferred_backend() == \"pytorch\":\n    from .spot_target import *\n    from .dataloader import *\n"
  },
  {
    "path": "python/dgl/dataloading/base.py",
    "content": "\"\"\"Base classes and functionalities for dataloaders\"\"\"\nimport inspect\nfrom collections.abc import Mapping\n\nfrom .. import backend as F\nfrom ..base import EID, NID\nfrom ..convert import heterograph\nfrom ..frame import LazyFeature\nfrom ..transforms import compact_graphs\nfrom ..utils import context_of, recursive_apply\n\n\ndef _set_lazy_features(x, xdata, feature_names):\n    if feature_names is None:\n        return\n    if not isinstance(feature_names, Mapping):\n        xdata.update({k: LazyFeature(k) for k in feature_names})\n    else:\n        for type_, names in feature_names.items():\n            x[type_].data.update({k: LazyFeature(k) for k in names})\n\n\ndef set_node_lazy_features(g, feature_names):\n    \"\"\"Assign lazy features to the ``ndata`` of the input graph for prefetching optimization.\n\n    When used in a :class:`~dgl.dataloading.Sampler`, lazy features mark which data\n    should be fetched before computation in model. See :ref:`guide-minibatch-prefetching`\n    for a detailed explanation.\n\n    If the graph is homogeneous, this is equivalent to:\n\n    .. code:: python\n\n       g.ndata.update({k: LazyFeature(k, g.ndata[dgl.NID]) for k in feature_names})\n\n    If the graph is heterogeneous, this is equivalent to:\n\n    .. 
code:: python\n\n        for type_, names in feature_names.items():\n            g.nodes[type_].data.update(\n                {k: LazyFeature(k, g.nodes[type_].data[dgl.NID]) for k in names})\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    feature_names : list[str] or dict[str, list[str]]\n        The feature names to prefetch.\n\n    See also\n    --------\n    dgl.LazyFeature\n    \"\"\"\n    return _set_lazy_features(g.nodes, g.ndata, feature_names)\n\n\ndef set_edge_lazy_features(g, feature_names):\n    \"\"\"Assign lazy features to the ``edata`` of the input graph for prefetching optimization.\n\n    When used in a :class:`~dgl.dataloading.Sampler`, lazy features mark which data\n    should be fetched before computation in model. See :ref:`guide-minibatch-prefetching`\n    for a detailed explanation.\n\n    If the graph is homogeneous, this is equivalent to:\n\n    .. code:: python\n\n       g.edata.update({k: LazyFeature(k, g.edata[dgl.EID]) for k in feature_names})\n\n    If the graph is heterogeneous, this is equivalent to:\n\n    .. code:: python\n\n        for type_, names in feature_names.items():\n            g.edges[type_].data.update(\n                {k: LazyFeature(k, g.edges[type_].data[dgl.EID]) for k in names})\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    feature_names : list[str] or dict[etype, list[str]]\n        The feature names to prefetch. The ``etype`` key is either a string\n        or a triplet.\n\n    See also\n    --------\n    dgl.LazyFeature\n    \"\"\"\n    return _set_lazy_features(g.edges, g.edata, feature_names)\n\n\ndef set_src_lazy_features(g, feature_names):\n    \"\"\"Assign lazy features to the ``srcdata`` of the input graph for prefetching optimization.\n\n    When used in a :class:`~dgl.dataloading.Sampler`, lazy features mark which data\n    should be fetched before computation in model. 
See :ref:`guide-minibatch-prefetching`\n    for a detailed explanation.\n\n    If the graph is homogeneous, this is equivalent to:\n\n    .. code:: python\n\n       g.srcdata.update({k: LazyFeature(k, g.srcdata[dgl.NID]) for k in feature_names})\n\n    If the graph is heterogeneous, this is equivalent to:\n\n    .. code:: python\n\n        for type_, names in feature_names.items():\n            g.srcnodes[type_].data.update(\n                {k: LazyFeature(k, g.srcnodes[type_].data[dgl.NID]) for k in names})\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    feature_names : list[str] or dict[str, list[str]]\n        The feature names to prefetch.\n\n    See also\n    --------\n    dgl.LazyFeature\n    \"\"\"\n    return _set_lazy_features(g.srcnodes, g.srcdata, feature_names)\n\n\ndef set_dst_lazy_features(g, feature_names):\n    \"\"\"Assign lazy features to the ``dstdata`` of the input graph for prefetching optimization.\n\n    When used in a :class:`~dgl.dataloading.Sampler`, lazy features mark which data\n    should be fetched before computation in model. See :ref:`guide-minibatch-prefetching`\n    for a detailed explanation.\n\n    If the graph is homogeneous, this is equivalent to:\n\n    .. code:: python\n\n       g.dstdata.update({k: LazyFeature(k, g.dstdata[dgl.NID]) for k in feature_names})\n\n    If the graph is heterogeneous, this is equivalent to:\n\n    .. 
code:: python\n\n        for type_, names in feature_names.items():\n            g.dstnodes[type_].data.update(\n                {k: LazyFeature(k, g.dstnodes[type_].data[dgl.NID]) for k in names})\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    feature_names : list[str] or dict[str, list[str]]\n        The feature names to prefetch.\n\n    See also\n    --------\n    dgl.LazyFeature\n    \"\"\"\n    return _set_lazy_features(g.dstnodes, g.dstdata, feature_names)\n\n\nclass Sampler(object):\n    \"\"\"Base class for graph samplers.\n\n    All graph samplers must subclass this class and override the ``sample``\n    method.\n\n    .. code:: python\n\n        from dgl.dataloading import Sampler\n\n        class SubgraphSampler(Sampler):\n            def __init__(self):\n                super().__init__()\n\n            def sample(self, g, indices):\n                return g.subgraph(indices)\n    \"\"\"\n\n    def sample(self, g, indices):\n        \"\"\"Abstract sample method.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        indices : object\n            Any object representing the indices selected in the current minibatch.\n        \"\"\"\n        raise NotImplementedError\n\n\nclass BlockSampler(Sampler):\n    \"\"\"Base class for sampling mini-batches in the form of Message-passing\n    Flow Graphs (MFGs).\n\n    It provides prefetching options to fetch the node features for the first MFG's ``srcdata``,\n    the node labels for the last MFG's ``dstdata`` and the edge features of all MFG's ``edata``.\n\n    Parameters\n    ----------\n    prefetch_node_feats : list[str] or dict[str, list[str]], optional\n        The node data to prefetch for the first MFG.\n\n        DGL will populate the first layer's MFG's ``srcnodes`` and ``srcdata`` with\n        the node data of the given names from the original graph.\n    prefetch_labels : list[str] or dict[str, list[str]], optional\n        The 
node data to prefetch for the last MFG.\n\n        DGL will populate the last layer's MFG's ``dstnodes`` and ``dstdata`` with\n        the node data of the given names from the original graph.\n    prefetch_edge_feats : list[str] or dict[etype, list[str]], optional\n        The edge data names to prefetch for all the MFGs.\n\n        DGL will populate every MFG's ``edges`` and ``edata`` with the edge data\n        of the given names from the original graph.\n    output_device : device, optional\n        The device of the output subgraphs or MFGs.  Default is the same as the\n        minibatch of seed nodes.\n    \"\"\"\n\n    def __init__(\n        self,\n        prefetch_node_feats=None,\n        prefetch_labels=None,\n        prefetch_edge_feats=None,\n        output_device=None,\n    ):\n        super().__init__()\n        self.prefetch_node_feats = prefetch_node_feats or []\n        self.prefetch_labels = prefetch_labels or []\n        self.prefetch_edge_feats = prefetch_edge_feats or []\n        self.output_device = output_device\n\n    def sample_blocks(self, g, seed_nodes, exclude_eids=None):\n        \"\"\"Generates a list of blocks from the given seed nodes.\n\n        This function must return a triplet where the first element is the input node IDs\n        for the first GNN layer (a tensor or a dict of tensors for heterogeneous graphs),\n        the second element is the output node IDs for the last GNN layer, and the third\n        element is the said list of blocks.\n        \"\"\"\n        raise NotImplementedError\n\n    def assign_lazy_features(self, result):\n        \"\"\"Assign lazy features for prefetching.\"\"\"\n        input_nodes, output_nodes, blocks = result\n        set_src_lazy_features(blocks[0], self.prefetch_node_feats)\n        set_dst_lazy_features(blocks[-1], self.prefetch_labels)\n        for block in blocks:\n            set_edge_lazy_features(block, self.prefetch_edge_feats)\n        return input_nodes, output_nodes, blocks\n\n  
  def sample(\n        self, g, seed_nodes, exclude_eids=None\n    ):  # pylint: disable=arguments-differ\n        \"\"\"Sample a list of blocks from the given seed nodes.\"\"\"\n        result = self.sample_blocks(g, seed_nodes, exclude_eids=exclude_eids)\n        return self.assign_lazy_features(result)\n\n\ndef _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):\n    if isinstance(eids, Mapping):\n        eids = {g.to_canonical_etype(k): v for k, v in eids.items()}\n        exclude_eids = {\n            k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)\n            for k, v in eids.items()\n        }\n    else:\n        exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)\n    return exclude_eids\n\n\ndef _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):\n    exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}\n    reverse_etype_map = {\n        g.to_canonical_etype(k): g.to_canonical_etype(v)\n        for k, v in reverse_etype_map.items()\n    }\n    for k, v in reverse_etype_map.items():\n        if k in exclude_eids:\n            if v in exclude_eids:\n                exclude_eids[v] = F.unique(\n                    F.cat((exclude_eids[k], exclude_eids[v]), dim=0)\n                )\n            else:\n                exclude_eids[v] = exclude_eids[k]\n    return exclude_eids\n\n\ndef _find_exclude_eids(g, exclude_mode, eids, **kwargs):\n    if exclude_mode is None:\n        return None\n    elif callable(exclude_mode):\n        return exclude_mode(eids)\n    elif F.is_tensor(exclude_mode) or (\n        isinstance(exclude_mode, Mapping)\n        and all(F.is_tensor(v) for v in exclude_mode.values())\n    ):\n        return exclude_mode\n    elif exclude_mode == \"self\":\n        return eids\n    elif exclude_mode == \"reverse_id\":\n        return _find_exclude_eids_with_reverse_id(\n            g, eids, kwargs[\"reverse_eid_map\"]\n        )\n    elif exclude_mode == \"reverse_types\":\n        
return _find_exclude_eids_with_reverse_types(\n            g, eids, kwargs[\"reverse_etype_map\"]\n        )\n    else:\n        raise ValueError(\"unsupported mode {}\".format(exclude_mode))\n\n\ndef find_exclude_eids(\n    g,\n    seed_edges,\n    exclude,\n    reverse_eids=None,\n    reverse_etypes=None,\n    output_device=None,\n):\n    \"\"\"Find all edge IDs to exclude according to :attr:`exclude_mode`.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    exclude :\n        Can be either of the following,\n\n        None (default)\n            Does not exclude any edge.\n\n        'self'\n            Exclude the given edges themselves but nothing else.\n\n        'reverse_id'\n            Exclude all edges specified in ``eids``, as well as their reverse edges\n            of the same edge type.\n\n            The mapping from each edge ID to its reverse edge ID is specified in\n            the keyword argument ``reverse_eid_map``.\n\n            This mode assumes that the reverse of an edge with ID ``e`` and type\n            ``etype`` will have ID ``reverse_eid_map[e]`` and type ``etype``.\n\n        'reverse_types'\n            Exclude all edges specified in ``eids``, as well as their reverse\n            edges of the corresponding edge types.\n\n            The mapping from each edge type to its reverse edge type is specified\n            in the keyword argument ``reverse_etype_map``.\n\n            This mode assumes that the reverse of an edge with ID ``e`` and type ``etype``\n            will have ID ``e`` and type ``reverse_etype_map[etype]``.\n\n        callable\n            Any function that takes in a single argument :attr:`seed_edges` and returns\n            a tensor or dict of tensors.\n    eids : Tensor or dict[etype, Tensor]\n        The edge IDs.\n    reverse_eids : Tensor or dict[etype, Tensor]\n        The mapping from edge ID to its reverse edge ID.\n    reverse_etypes : dict[etype, etype]\n        The mapping from 
edge etype to its reverse edge type.\n    output_device : device\n        The device of the output edge IDs.\n    \"\"\"\n    exclude_eids = _find_exclude_eids(\n        g,\n        exclude,\n        seed_edges,\n        reverse_eid_map=reverse_eids,\n        reverse_etype_map=reverse_etypes,\n    )\n    if exclude_eids is not None and output_device is not None:\n        exclude_eids = recursive_apply(\n            exclude_eids, lambda x: F.copy_to(x, output_device)\n        )\n    return exclude_eids\n\n\nclass EdgePredictionSampler(Sampler):\n    \"\"\"Sampler class that wraps an existing sampler for node classification into another\n    one for edge classification or link prediction.\n\n    See also\n    --------\n    as_edge_prediction_sampler\n    \"\"\"\n\n    def __init__(\n        self,\n        sampler,\n        exclude=None,\n        reverse_eids=None,\n        reverse_etypes=None,\n        negative_sampler=None,\n        prefetch_labels=None,\n    ):\n        super().__init__()\n        # Check if the sampler's sample method has an optional third argument.\n        argspec = inspect.getfullargspec(sampler.sample)\n        if len(argspec.args) < 4:  # ['self', 'g', 'indices', 'exclude_eids']\n            raise TypeError(\n                \"This sampler does not support edge or link prediction; please add an\"\n                \"optional third argument for edge IDs to exclude in its sample() method.\"\n            )\n        self.reverse_eids = reverse_eids\n        self.reverse_etypes = reverse_etypes\n        self.exclude = exclude\n        self.sampler = sampler\n        self.negative_sampler = negative_sampler\n        self.prefetch_labels = prefetch_labels or []\n        self.output_device = sampler.output_device\n\n    def _build_neg_graph(self, g, seed_edges):\n        neg_srcdst = self.negative_sampler(g, seed_edges)\n        if not isinstance(neg_srcdst, Mapping):\n            assert len(g.canonical_etypes) == 1, (\n                \"graph has 
multiple or no edge types; \"\n                \"please return a dict in negative sampler.\"\n            )\n            neg_srcdst = {g.canonical_etypes[0]: neg_srcdst}\n\n        dtype = F.dtype(list(neg_srcdst.values())[0][0])\n        ctx = context_of(seed_edges) if seed_edges is not None else g.device\n        neg_edges = {\n            etype: neg_srcdst.get(\n                etype,\n                (\n                    F.copy_to(F.tensor([], dtype), ctx=ctx),\n                    F.copy_to(F.tensor([], dtype), ctx=ctx),\n                ),\n            )\n            for etype in g.canonical_etypes\n        }\n        neg_pair_graph = heterograph(\n            neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}\n        )\n        return neg_pair_graph\n\n    def assign_lazy_features(self, result):\n        \"\"\"Assign lazy features for prefetching.\"\"\"\n        pair_graph = result[1]\n        set_edge_lazy_features(pair_graph, self.prefetch_labels)\n        # In-place updates\n        return result\n\n    def sample(self, g, seed_edges):  # pylint: disable=arguments-differ\n        \"\"\"Samples a list of blocks, as well as a subgraph containing the sampled\n        edges from the original graph.\n\n        If :attr:`negative_sampler` is given, also returns another graph containing the\n        negative pairs as edges.\n        \"\"\"\n        if isinstance(seed_edges, Mapping):\n            seed_edges = {\n                g.to_canonical_etype(k): v for k, v in seed_edges.items()\n            }\n        exclude = self.exclude\n        pair_graph = g.edge_subgraph(\n            seed_edges, relabel_nodes=False, output_device=self.output_device\n        )\n        eids = pair_graph.edata[EID]\n\n        if self.negative_sampler is not None:\n            neg_graph = self._build_neg_graph(g, seed_edges)\n            pair_graph, neg_graph = compact_graphs([pair_graph, neg_graph])\n        else:\n            pair_graph = compact_graphs(pair_graph)\n\n 
       pair_graph.edata[EID] = eids\n        seed_nodes = pair_graph.ndata[NID]\n\n        exclude_eids = find_exclude_eids(\n            g,\n            seed_edges,\n            exclude,\n            self.reverse_eids,\n            self.reverse_etypes,\n            self.output_device,\n        )\n\n        input_nodes, _, blocks = self.sampler.sample(\n            g, seed_nodes, exclude_eids\n        )\n\n        if self.negative_sampler is None:\n            return self.assign_lazy_features((input_nodes, pair_graph, blocks))\n        else:\n            return self.assign_lazy_features(\n                (input_nodes, pair_graph, neg_graph, blocks)\n            )\n\n\ndef as_edge_prediction_sampler(\n    sampler,\n    exclude=None,\n    reverse_eids=None,\n    reverse_etypes=None,\n    negative_sampler=None,\n    prefetch_labels=None,\n):\n    \"\"\"Create an edge-wise sampler from a node-wise sampler.\n\n    For each batch of edges, the sampler applies the provided node-wise sampler to\n    their source and destination nodes to extract subgraphs. 
It also generates negative\n    edges if a negative sampler is provided, and extract subgraphs for their incident\n    nodes as well.\n\n    For each iteration, the sampler will yield\n\n    * A tensor of input nodes necessary for computing the representation on edges, or\n      a dictionary of node type names and such tensors.\n\n    * A subgraph that contains only the edges in the minibatch and their incident nodes.\n      Note that the graph has an identical metagraph with the original graph.\n\n    * If a negative sampler is given, another graph that contains the \"negative edges\",\n      connecting the source and destination nodes yielded from the given negative sampler.\n\n    * The subgraphs or MFGs returned by the provided node-wise sampler, generated\n      from the incident nodes of the edges in the minibatch (as well as those of the\n      negative edges if applicable).\n\n    Parameters\n    ----------\n    sampler : Sampler\n        The node-wise sampler object.  It additionally requires that the :attr:`sample`\n        method must have an optional third argument :attr:`exclude_eids` representing the\n        edge IDs to exclude from neighborhood.  The argument will be either a tensor\n        for homogeneous graphs or a dict of edge types and tensors for heterogeneous\n        graphs.\n    exclude : Union[str, callable], optional\n        Whether and how to exclude dependencies related to the sampled edges in the\n        minibatch.  
Possible values are\n\n        * None, for not excluding any edges.\n\n        * ``self``, for excluding the edges in the current minibatch.\n\n        * ``reverse_id``, for excluding not only the edges in the current minibatch but\n          also their reverse edges according to the ID mapping in the argument\n          :attr:`reverse_eids`.\n\n        * ``reverse_types``, for excluding not only the edges in the current minibatch\n          but also their reverse edges stored in another type according to\n          the argument :attr:`reverse_etypes`.\n\n        * User-defined exclusion rule. It is a callable with edges in the current\n          minibatch as a single argument and should return the edges to be excluded.\n    reverse_eids : Tensor or dict[etype, Tensor], optional\n        A tensor of reverse edge ID mapping.  The i-th element indicates the ID of\n        the i-th edge's reverse edge.\n\n        If the graph is heterogeneous, this argument requires a dictionary of edge\n        types and the reverse edge ID mapping tensors.\n    reverse_etypes : dict[etype, etype], optional\n        The mapping from the original edge types to their reverse edge types.\n    negative_sampler : callable, optional\n        The negative sampler.\n    prefetch_labels : list[str] or dict[etype, list[str]], optional\n        The edge labels to prefetch for the returned positive pair graph.\n\n        See :ref:`guide-minibatch-prefetching` for a detailed explanation of prefetching.\n\n    Examples\n    --------\n    The following example shows how to train a 3-layer GNN for edge classification on a\n    set of edges ``train_eid`` on a homogeneous undirected graph. 
Each node takes\n    messages from all neighbors.\n\n    Given an array of source node IDs ``src`` and another array of destination\n    node IDs ``dst``, the following code creates a bidirectional graph:\n\n    >>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))\n\n    Edge :math:`i`'s reverse edge in the graph above is edge :math:`i + |E|`. Therefore, we can\n    create a reverse edge mapping ``reverse_eids`` by:\n\n    >>> E = len(src)\n    >>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])\n\n    By passing ``reverse_eids`` to the edge sampler, the edges in the current mini-batch and their\n    reversed edges will be excluded from the extracted subgraphs to avoid information leakage.\n\n    >>> sampler = dgl.dataloading.as_edge_prediction_sampler(\n    ...     dgl.dataloading.NeighborSampler([15, 10, 5]),\n    ...     exclude='reverse_id', reverse_eids=reverse_eids)\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, train_eid, sampler,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, pair_graph, blocks in dataloader:\n    ...     train_on(input_nodes, pair_graph, blocks)\n\n    For link prediction, one can provide a negative sampler to sample negative edges.\n    The code below uses DGL's :class:`~dgl.dataloading.negative_sampler.Uniform`\n    to generate 5 negative samples per edge:\n\n    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)\n    >>> sampler = dgl.dataloading.as_edge_prediction_sampler(\n    ...     dgl.dataloading.NeighborSampler([15, 10, 5]),\n    ...     exclude='reverse_id', reverse_eids=reverse_eids,\n    ...     negative_sampler=neg_sampler)\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, train_eid, sampler,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:\n    ...     
train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)\n\n    For heterogeneous graphs, reverse edges may belong to a different relation. For example,\n    the relations \"user-click-item\" and \"item-clicked-by-user\" in the graph below are\n    reverses of each other.\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'click', 'item'): (user, item),\n    ...     ('item', 'clicked-by', 'user'): (item, user)})\n\n    To correctly exclude edges from each mini-batch, set ``exclude='reverse_types'`` and\n    pass a dictionary ``{'click': 'clicked-by', 'clicked-by': 'click'}`` to the\n    ``reverse_etypes`` argument.\n\n    >>> sampler = dgl.dataloading.as_edge_prediction_sampler(\n    ...     dgl.dataloading.NeighborSampler([15, 10, 5]),\n    ...     exclude='reverse_types',\n    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'})\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, {'click': train_eid}, sampler,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, pair_graph, blocks in dataloader:\n    ...     train_on(input_nodes, pair_graph, blocks)\n\n    For link prediction, provide a negative sampler to generate negative samples:\n\n    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)\n    >>> sampler = dgl.dataloading.as_edge_prediction_sampler(\n    ...     dgl.dataloading.NeighborSampler([15, 10, 5]),\n    ...     exclude='reverse_types',\n    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},\n    ...     negative_sampler=neg_sampler)\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, {'click': train_eid}, sampler,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:\n    ...     
train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)\n    \"\"\"\n    return EdgePredictionSampler(\n        sampler,\n        exclude=exclude,\n        reverse_eids=reverse_eids,\n        reverse_etypes=reverse_etypes,\n        negative_sampler=negative_sampler,\n        prefetch_labels=prefetch_labels,\n    )\n"
  },
  {
    "path": "python/dgl/dataloading/capped_neighbor_sampler.py",
    "content": "\"\"\"Capped neighbor sampler.\"\"\"\nfrom collections import defaultdict\n\nimport numpy as np\nimport torch\n\nfrom ..sampling.utils import EidExcluder\nfrom .base import Sampler, set_edge_lazy_features, set_node_lazy_features\n\n\nclass CappedNeighborSampler(Sampler):\n    \"\"\"Subgraph sampler that sets an upper bound on the number of nodes included in\n    each layer of the sampled subgraph. At each layer, the frontier is randomly\n    subsampled. Rare node types can also be upsampled by taking the scaled square\n    root of the sampling probabilities. The sampler returns the subgraph induced by\n    all the sampled nodes.\n\n    This code was contributed by a community member\n    ([@ayushnoori](https://github.com/ayushnoori)). There aren't currently any unit\n    tests in place to verify its functionality, so please be cautious if you need\n    to make any changes to the code's logic.\n\n    Parameters\n    ----------\n    fanouts : list[int] or dict[etype, int]\n        List of neighbors to sample per edge type for each GNN layer, with the i-th\n        element being the fanout for the i-th GNN layer.\n        - If only a single integer is provided, DGL assumes that every edge type\n            will have the same fanout.\n        - If -1 is provided for one edge type on one layer, then all inbound edges\n            of that edge type will be included.\n    fixed_k : int\n            The number of nodes to sample for each GNN layer.\n    upsample_rare_types : bool\n        Whether or not to upsample rare node types.\n    replace : bool, default True\n        Whether to sample with replacement.\n    prob : str, optional\n        If given, the probability of each neighbor being sampled is proportional\n        to the edge feature value with the given name in ``g.edata``. 
The feature must be\n        a scalar on each edge.\n    \"\"\"\n\n    def __init__(\n        self,\n        fanouts,\n        fixed_k,\n        upsample_rare_types,\n        replace=False,\n        prob=None,\n        prefetch_node_feats=None,\n        prefetch_edge_feats=None,\n        output_device=None,\n    ):\n        super().__init__()\n        self.fanouts = fanouts\n        self.replace = replace\n        self.fixed_k = fixed_k\n        self.upsample_rare_types = upsample_rare_types\n        self.prob = prob\n        self.prefetch_node_feats = prefetch_node_feats\n        self.prefetch_edge_feats = prefetch_edge_feats\n        self.output_device = output_device\n\n    def sample(\n        self, g, indices, exclude_eids=None\n    ):  # pylint: disable=arguments-differ\n        \"\"\"Sampling function.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph to sample from.\n        indices : Tensor or dict[str, Tensor]\n            Nodes which induce the subgraph.\n        exclude_eids : Tensor or dict[etype, Tensor], optional\n            The edges to exclude from the sampled subgraph.\n\n        Returns\n        -------\n        input_nodes : Tensor or dict[str, Tensor]\n            The node IDs inducing the subgraph.\n        output_nodes : Tensor or dict[str, Tensor]\n            The node IDs that are sampled in this minibatch.\n        subg : DGLGraph\n            The subgraph itself.\n        \"\"\"\n\n        # Define empty dictionary to store reached nodes.\n        output_nodes = indices\n        all_reached_nodes = [indices]\n\n        # Iterate over fanout.\n        for fanout in reversed(self.fanouts):\n\n            # Sample frontier.\n            frontier = g.sample_neighbors(\n                indices,\n                fanout,\n                output_device=self.output_device,\n                replace=self.replace,\n                prob=self.prob,\n                exclude_edges=exclude_eids,\n            )\n\n     
       # Get reached nodes.\n            curr_reached = defaultdict(list)\n            for c_etype in frontier.canonical_etypes:\n                (src_type, _, _) = c_etype\n                src, _ = frontier.edges(etype=c_etype)\n                curr_reached[src_type].append(src)\n\n            # De-duplication.\n            curr_reached = {\n                ntype: torch.unique(torch.cat(srcs))\n                for ntype, srcs in curr_reached.items()\n            }\n\n            # Generate type sampling probabilities.\n            type_count = {\n                node_type: indices.shape[0]\n                for node_type, indices in curr_reached.items()\n            }\n            total_count = sum(type_count.values())\n            probs = {\n                node_type: count / total_count\n                for node_type, count in type_count.items()\n            }\n\n            # Upsample rare node types.\n            if self.upsample_rare_types:\n\n                # Take scaled square root of probabilities.\n                prob_dist = list(probs.values())\n                prob_dist = np.sqrt(prob_dist)\n                prob_dist = prob_dist / prob_dist.sum()\n\n                # Update probabilities.\n                probs = {\n                    node_type: prob_dist[i]\n                    for i, node_type in enumerate(probs.keys())\n                }\n\n            # Generate node counts per type.\n            n_per_type = {\n                node_type: int(self.fixed_k * prob)\n                for node_type, prob in probs.items()\n            }\n            remainder = self.fixed_k - sum(n_per_type.values())\n            for _ in range(remainder):\n                node_type = np.random.choice(\n                    list(probs.keys()), p=list(probs.values())\n                )\n                n_per_type[node_type] += 1\n\n            # Downsample nodes.\n            curr_reached_k = {}\n            for node_type, node_ids in curr_reached.items():\n\n              
  # Get number of total nodes and number to sample.\n                num_nodes = node_ids.shape[0]\n                n_to_sample = min(num_nodes, n_per_type[node_type])\n\n                # Downsample nodes of current type.\n                random_indices = torch.randperm(num_nodes)[:n_to_sample]\n                curr_reached_k[node_type] = node_ids[random_indices]\n\n            # Update seed nodes.\n            indices = curr_reached_k\n            all_reached_nodes.append(curr_reached_k)\n\n        # Merge all reached nodes before sending to `DGLGraph.subgraph`.\n        merged_nodes = {}\n        for ntype in g.ntypes:\n            merged_nodes[ntype] = torch.unique(\n                torch.cat(\n                    [reached.get(ntype, []) for reached in all_reached_nodes]\n                )\n            )\n        subg = g.subgraph(\n            merged_nodes, relabel_nodes=True, output_device=self.output_device\n        )\n\n        if exclude_eids is not None:\n            subg = EidExcluder(exclude_eids)(subg)\n\n        set_node_lazy_features(subg, self.prefetch_node_feats)\n        set_edge_lazy_features(subg, self.prefetch_edge_feats)\n\n        return indices, output_nodes, subg\n"
  },
  {
    "path": "python/dgl/dataloading/cluster_gcn.py",
    "content": "\"\"\"Cluster-GCN samplers.\"\"\"\nimport os\nimport pickle\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom ..base import DGLError\nfrom ..partition import metis_partition_assignment\nfrom .base import Sampler, set_edge_lazy_features, set_node_lazy_features\n\n\nclass ClusterGCNSampler(Sampler):\n    \"\"\"Cluster sampler from `Cluster-GCN: An Efficient Algorithm for Training\n    Deep and Large Graph Convolutional Networks\n    <https://arxiv.org/abs/1905.07953>`__\n\n    This sampler first partitions the graph with METIS partitioning, then it caches the nodes of\n    each partition to a file within the given cache directory.\n\n    The sampler then selects the graph partitions according to the provided\n    partition IDs, takes the union of all nodes in those partitions, and returns an\n    induced subgraph in its :attr:`sample` method.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The original graph.  Must be homogeneous and on CPU.\n    k : int\n        The number of partitions.\n    cache_path : str\n        The path to the cache directory for storing the partition result.\n    balance_ntypes, balance_edges, mode :\n        Passed to :func:`dgl.metis_partition_assignment`.\n    prefetch_ndata : list[str], optional\n        The node data to prefetch for the subgraph.\n\n        See :ref:`guide-minibatch-prefetching` for a detailed explanation of prefetching.\n    prefetch_edata : list[str], optional\n        The edge data to prefetch for the subgraph.\n\n        See :ref:`guide-minibatch-prefetching` for a detailed explanation of prefetching.\n    output_device : device, optional\n        The device of the output subgraphs or MFGs.  Default is the same as the\n        minibatch of partition indices.\n\n    Examples\n    --------\n    **Node classification**\n\n    With this sampler, the data loader will accept the list of partition IDs as\n    indices to iterate over.  
For instance, the following code first splits the\n    graph into 1000 partitions using METIS, and at each iteration it gets a subgraph\n    induced by the nodes covered by 20 randomly selected partitions.\n\n    >>> num_parts = 1000\n    >>> sampler = dgl.dataloading.ClusterGCNSampler(g, num_parts)\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, torch.arange(num_parts), sampler,\n    ...     batch_size=20, shuffle=True, drop_last=False, num_workers=4)\n    >>> for subg in dataloader:\n    ...     train_on(subg)\n    \"\"\"\n\n    def __init__(\n        self,\n        g,\n        k,\n        cache_path=\"cluster_gcn.pkl\",\n        balance_ntypes=None,\n        balance_edges=False,\n        mode=\"k-way\",\n        prefetch_ndata=None,\n        prefetch_edata=None,\n        output_device=None,\n    ):\n        super().__init__()\n        if os.path.exists(cache_path):\n            try:\n                with open(cache_path, \"rb\") as f:\n                    (\n                        self.partition_offset,\n                        self.partition_node_ids,\n                    ) = pickle.load(f)\n            except (EOFError, TypeError, ValueError):\n                raise DGLError(\n                    f\"The contents in the cache file {cache_path} is invalid. \"\n                    f\"Please remove the cache file {cache_path} or specify another path.\"\n                )\n            if len(self.partition_offset) != k + 1:\n                raise DGLError(\n                    f\"Number of partitions in the cache does not match the value of k. \"\n                    f\"Please remove the cache file {cache_path} or specify another path.\"\n                )\n            if len(self.partition_node_ids) != g.num_nodes():\n                raise DGLError(\n                    f\"Number of nodes in the cache does not match the given graph. 
\"\n                    f\"Please remove the cache file {cache_path} or specify another path.\"\n                )\n        else:\n            partition_ids = metis_partition_assignment(\n                g,\n                k,\n                balance_ntypes=balance_ntypes,\n                balance_edges=balance_edges,\n                mode=mode,\n            )\n            partition_ids = F.asnumpy(partition_ids)\n            partition_node_ids = np.argsort(partition_ids)\n            partition_size = F.zerocopy_from_numpy(\n                np.bincount(partition_ids, minlength=k)\n            )\n            partition_offset = F.zerocopy_from_numpy(\n                np.insert(np.cumsum(partition_size), 0, 0)\n            )\n            partition_node_ids = F.zerocopy_from_numpy(partition_node_ids)\n            with open(cache_path, \"wb\") as f:\n                pickle.dump((partition_offset, partition_node_ids), f)\n            self.partition_offset = partition_offset\n            self.partition_node_ids = partition_node_ids\n\n        self.prefetch_ndata = prefetch_ndata or []\n        self.prefetch_edata = prefetch_edata or []\n        self.output_device = output_device\n\n    def sample(self, g, partition_ids):  # pylint: disable=arguments-differ\n        \"\"\"Sampling function.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph to sample from.\n        partition_ids : Tensor\n            A 1-D integer tensor of partition IDs.\n\n        Returns\n        -------\n        DGLGraph\n            The sampled subgraph.\n        \"\"\"\n        node_ids = F.cat(\n            [\n                self.partition_node_ids[\n                    self.partition_offset[i] : self.partition_offset[i + 1]\n                ]\n                for i in F.asnumpy(partition_ids)\n            ],\n            0,\n        )\n        sg = g.subgraph(\n            node_ids, relabel_nodes=True, output_device=self.output_device\n        )\n        
set_node_lazy_features(sg, self.prefetch_ndata)\n        set_edge_lazy_features(sg, self.prefetch_edata)\n        return sg\n"
  },
  {
    "path": "python/dgl/dataloading/dataloader.py",
    "content": "\"\"\"DGL PyTorch DataLoaders\"\"\"\n\nimport atexit\nimport inspect\nimport itertools\nimport math\nimport operator\nimport os\nimport re\nimport threading\nfrom collections.abc import Mapping, Sequence\nfrom contextlib import contextmanager\nfrom functools import reduce\nfrom queue import Empty, Full, Queue\n\nimport numpy as np\nimport psutil\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data.distributed import DistributedSampler\n\nfrom .. import backend as F\nfrom .._ffi.base import is_tensor_adaptor_enabled\n\nfrom ..base import dgl_warning, DGLError, EID, NID\nfrom ..batch import batch as batch_graphs\nfrom ..cuda import GPUCache\nfrom ..frame import LazyFeature\nfrom ..heterograph import DGLGraph\nfrom ..storages import wrap_storage\nfrom ..utils import (\n    dtype_of,\n    ExceptionWrapper,\n    get_num_threads,\n    get_numa_nodes_cores,\n    recursive_apply,\n    recursive_apply_pair,\n    set_num_threads,\n)\n\nPYTHON_EXIT_STATUS = False\n\n\ndef _set_python_exit_flag():\n    global PYTHON_EXIT_STATUS\n    PYTHON_EXIT_STATUS = True\n\n\natexit.register(_set_python_exit_flag)\n\nprefetcher_timeout = int(os.environ.get(\"DGL_PREFETCHER_TIMEOUT\", \"30\"))\n\n\nclass _TensorizedDatasetIter(object):\n    def __init__(self, dataset, batch_size, drop_last, mapping_keys, shuffle):\n        self.dataset = dataset\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n        self.mapping_keys = mapping_keys\n        self.index = 0\n        self.shuffle = shuffle\n\n    # For PyTorch Lightning compatibility\n    def __iter__(self):\n        return self\n\n    def _next_indices(self):\n        num_items = self.dataset.shape[0]\n        if self.index >= num_items:\n            raise StopIteration\n        end_idx = self.index + self.batch_size\n        if end_idx > num_items:\n            if self.drop_last:\n                raise StopIteration\n            end_idx = num_items\n        batch = 
self.dataset[self.index : end_idx]\n        self.index += self.batch_size\n\n        return batch\n\n    def __next__(self):\n        batch = self._next_indices()\n        if self.mapping_keys is None:\n            # clone() fixes #3755, probably.  Not sure why.  Need to take a look afterwards.\n            return batch.clone()\n\n        # convert the type-ID pairs to dictionary\n        type_ids = batch[:, 0]\n        indices = batch[:, 1]\n        _, type_ids_sortidx = torch.sort(type_ids, stable=True)\n        type_ids = type_ids[type_ids_sortidx]\n        indices = indices[type_ids_sortidx]\n        type_id_uniq, type_id_count = torch.unique_consecutive(\n            type_ids, return_counts=True\n        )\n        type_id_uniq = type_id_uniq.tolist()\n        type_id_offset = type_id_count.cumsum(0).tolist()\n        type_id_offset.insert(0, 0)\n        id_dict = {\n            self.mapping_keys[type_id_uniq[i]]: indices[\n                type_id_offset[i] : type_id_offset[i + 1]\n            ].clone()\n            for i in range(len(type_id_uniq))\n        }\n        return id_dict\n\n\ndef _get_id_tensor_from_mapping(indices, device, keys):\n    dtype = dtype_of(indices)\n    id_tensor = torch.empty(\n        sum(v.shape[0] for v in indices.values()), 2, dtype=dtype, device=device\n    )\n\n    offset = 0\n    for i, k in enumerate(keys):\n        if k not in indices:\n            continue\n        index = indices[k]\n        length = index.shape[0]\n        id_tensor[offset : offset + length, 0] = i\n        id_tensor[offset : offset + length, 1] = index\n        offset += length\n    return id_tensor\n\n\ndef _split_to_local_id_tensor_from_mapping(\n    indices, keys, local_lower_bound, local_upper_bound\n):\n    dtype = dtype_of(indices)\n    device = next(iter(indices.values())).device\n    num_samples = local_upper_bound - local_lower_bound\n    id_tensor = torch.empty(num_samples, 2, dtype=dtype, device=device)\n\n    index_offset = 0\n    
split_id_offset = 0\n    for i, k in enumerate(keys):\n        if k not in indices:\n            continue\n        index = indices[k]\n        length = index.shape[0]\n        index_offset2 = index_offset + length\n        lower = max(local_lower_bound, index_offset)\n        upper = min(local_upper_bound, index_offset2)\n        if upper > lower:\n            split_id_offset2 = split_id_offset + (upper - lower)\n            assert split_id_offset2 <= num_samples\n            id_tensor[split_id_offset:split_id_offset2, 0] = i\n            id_tensor[split_id_offset:split_id_offset2, 1] = index[\n                lower - index_offset : upper - index_offset\n            ]\n            split_id_offset += upper - lower\n            if split_id_offset2 == num_samples:\n                break\n        index_offset = index_offset2\n    return id_tensor\n\n\ndef _split_to_local_id_tensor(indices, local_lower_bound, local_upper_bound):\n    dtype = dtype_of(indices)\n    device = indices.device\n    num_samples = local_upper_bound - local_lower_bound\n    id_tensor = torch.empty(num_samples, dtype=dtype, device=device)\n\n    if local_upper_bound > len(indices):\n        remainder = len(indices) - local_lower_bound\n        id_tensor[0:remainder] = indices[local_lower_bound:]\n    else:\n        id_tensor = indices[local_lower_bound:local_upper_bound]\n    return id_tensor\n\n\ndef _divide_by_worker(dataset, batch_size, drop_last):\n    num_samples = dataset.shape[0]\n    worker_info = torch.utils.data.get_worker_info()\n    if worker_info:\n        num_batches = (\n            num_samples + (0 if drop_last else batch_size - 1)\n        ) // batch_size\n        num_batches_per_worker = num_batches // worker_info.num_workers\n        left_over = num_batches % worker_info.num_workers\n        start = (num_batches_per_worker * worker_info.id) + min(\n            left_over, worker_info.id\n        )\n        end = start + num_batches_per_worker + (worker_info.id < left_over)\n     
   start *= batch_size\n        end = min(end * batch_size, num_samples)\n        dataset = dataset[start:end]\n    return dataset\n\n\nclass TensorizedDataset(torch.utils.data.IterableDataset):\n    \"\"\"Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.\n    When the dataset is on the GPU, this significantly reduces the overhead.\n    \"\"\"\n\n    def __init__(\n        self, indices, batch_size, drop_last, shuffle, use_shared_memory\n    ):\n        if isinstance(indices, Mapping):\n            self._mapping_keys = list(indices.keys())\n            self._device = next(iter(indices.values())).device\n            self._id_tensor = _get_id_tensor_from_mapping(\n                indices, self._device, self._mapping_keys\n            )\n        else:\n            self._id_tensor = indices\n            self._device = indices.device\n            self._mapping_keys = None\n        # Use a shared memory array to permute indices for shuffling.  This is to make sure that\n        # the worker processes can see it when persistent_workers=True, where self._indices\n        # would not be duplicated every epoch.\n        self._indices = torch.arange(\n            self._id_tensor.shape[0], dtype=torch.int64\n        )\n        if use_shared_memory:\n            self._indices.share_memory_()\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n        self._shuffle = shuffle\n\n    def shuffle(self):\n        \"\"\"Shuffle the dataset.\"\"\"\n        np.random.shuffle(self._indices.numpy())\n\n    def __iter__(self):\n        indices = _divide_by_worker(\n            self._indices, self.batch_size, self.drop_last\n        )\n        id_tensor = self._id_tensor[indices]\n        return _TensorizedDatasetIter(\n            id_tensor,\n            self.batch_size,\n            self.drop_last,\n            self._mapping_keys,\n            self._shuffle,\n        )\n\n    def __len__(self):\n        num_samples = 
self._id_tensor.shape[0]\n        return (\n            num_samples + (0 if self.drop_last else (self.batch_size - 1))\n        ) // self.batch_size\n\n\ndef _decompose_one_dimension(length, world_size, rank, drop_last):\n    if drop_last:\n        num_samples = math.floor(length / world_size)\n    else:\n        num_samples = math.ceil(length / world_size)\n    sta = rank * num_samples\n    end = (rank + 1) * num_samples\n    return sta, end\n\n\nclass DDPTensorizedDataset(torch.utils.data.IterableDataset):\n    \"\"\"Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.\n    When the dataset is on the GPU, this significantly reduces the overhead.\n\n    This class additionally saves the index tensor in shared memory and therefore\n    avoids duplicating the same index tensor during shuffling.\n    \"\"\"\n\n    def __init__(self, indices, batch_size, drop_last, ddp_seed, shuffle):\n        if isinstance(indices, Mapping):\n            self._mapping_keys = list(indices.keys())\n            len_indices = sum(len(v) for v in indices.values())\n        else:\n            self._mapping_keys = None\n            len_indices = len(indices)\n\n        self.rank = dist.get_rank()\n        self.num_replicas = dist.get_world_size()\n        self.seed = ddp_seed\n        self.epoch = 0\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n        self._shuffle = shuffle\n        (\n            self.local_lower_bound,\n            self.local_upper_bound,\n        ) = _decompose_one_dimension(\n            len_indices, self.num_replicas, self.rank, drop_last\n        )\n        self.num_samples = self.local_upper_bound - self.local_lower_bound\n        self.local_num_indices = self.num_samples\n        if self.local_upper_bound > len_indices:\n            assert not drop_last\n            self.local_num_indices = len_indices - self.local_lower_bound\n\n        if isinstance(indices, Mapping):\n            self._id_tensor = 
_split_to_local_id_tensor_from_mapping(\n                indices,\n                self._mapping_keys,\n                self.local_lower_bound,\n                self.local_upper_bound,\n            )\n        else:\n            self._id_tensor = _split_to_local_id_tensor(\n                indices, self.local_lower_bound, self.local_upper_bound\n            )\n        self._device = self._id_tensor.device\n        # padding self._indices when drop_last = False (self._indices always on cpu)\n        self._indices = torch.empty(self.num_samples, dtype=torch.int64)\n        torch.arange(\n            self.local_num_indices, out=self._indices[: self.local_num_indices]\n        )\n        if not drop_last:\n            torch.arange(\n                self.num_samples - self.local_num_indices,\n                out=self._indices[self.local_num_indices :],\n            )\n        assert len(self._id_tensor) == self.num_samples\n\n    def shuffle(self):\n        \"\"\"Shuffles the dataset.\"\"\"\n        np.random.shuffle(self._indices[: self.local_num_indices].numpy())\n        if not self.drop_last:\n            # pad extra from local indices\n            self._indices[self.local_num_indices :] = self._indices[\n                : self.num_samples - self.local_num_indices\n            ]\n\n    def __iter__(self):\n        indices = _divide_by_worker(\n            self._indices, self.batch_size, self.drop_last\n        )\n        id_tensor = self._id_tensor[indices]\n        return _TensorizedDatasetIter(\n            id_tensor,\n            self.batch_size,\n            self.drop_last,\n            self._mapping_keys,\n            self._shuffle,\n        )\n\n    def __len__(self):\n        return (\n            self.num_samples + (0 if self.drop_last else (self.batch_size - 1))\n        ) // self.batch_size\n\n\ndef _numel_of_shape(shape):\n    return reduce(operator.mul, shape, 1)\n\n\ndef _init_gpu_caches(graph, gpu_caches):\n    if not hasattr(graph, \"_gpu_caches\"):\n  
      graph._gpu_caches = {\"node\": {}, \"edge\": {}}\n    if gpu_caches is None:\n        return\n    assert isinstance(gpu_caches, dict), \"GPU cache argument should be a dict\"\n    for i, frames in enumerate([graph._node_frames, graph._edge_frames]):\n        node_or_edge = [\"node\", \"edge\"][i]\n        cache_inf = gpu_caches.get(node_or_edge, {})\n        for tid, frame in enumerate(frames):\n            type_ = [graph.ntypes, graph.canonical_etypes][i][tid]\n            for key in frame.keys():\n                if key in cache_inf and cache_inf[key] > 0:\n                    column = frame._columns[key]\n                    if (key, type_) not in graph._gpu_caches[node_or_edge]:\n                        cache = GPUCache(\n                            cache_inf[key],\n                            _numel_of_shape(column.shape),\n                            graph.idtype,\n                        )\n                        graph._gpu_caches[node_or_edge][key, type_] = (\n                            cache,\n                            column.shape,\n                        )\n\n\ndef _prefetch_update_feats(\n    feats,\n    frames,\n    types,\n    get_storage_func,\n    id_name,\n    device,\n    pin_prefetcher,\n    gpu_caches,\n):\n    for tid, frame in enumerate(frames):\n        type_ = types[tid]\n        default_id = frame.get(id_name, None)\n        for key in frame.keys():\n            column = frame._columns[key]\n            if isinstance(column, LazyFeature):\n                parent_key = column.name or key\n                if column.id_ is None and default_id is None:\n                    raise DGLError(\n                        \"Found a LazyFeature with no ID specified, \"\n                        \"and the graph does not have dgl.NID or dgl.EID columns\"\n                    )\n                ids = column.id_ or default_id\n                if (parent_key, type_) in gpu_caches:\n                    cache, item_shape = gpu_caches[parent_key, 
type_]\n                    values, missing_index, missing_keys = cache.query(ids)\n                    missing_values = get_storage_func(parent_key, type_).fetch(\n                        missing_keys, device, pin_prefetcher\n                    )\n                    cache.replace(\n                        missing_keys, F.astype(missing_values, F.float32)\n                    )\n                    values = F.astype(values, F.dtype(missing_values))\n                    F.scatter_row_inplace(values, missing_index, missing_values)\n                    # Reshape the flattened result to match the original shape.\n                    F.reshape(values, (values.shape[0],) + item_shape)\n                    values.__cache_miss__ = missing_keys.shape[0] / ids.shape[0]\n                    feats[tid, key] = values\n                else:\n                    feats[tid, key] = get_storage_func(parent_key, type_).fetch(\n                        ids, device, pin_prefetcher\n                    )\n\n\n# This class exists to avoid recursion into the feature dictionary returned by the\n# prefetcher when calling recursive_apply().\nclass _PrefetchedGraphFeatures(object):\n    __slots__ = [\"node_feats\", \"edge_feats\"]\n\n    def __init__(self, node_feats, edge_feats):\n        self.node_feats = node_feats\n        self.edge_feats = edge_feats\n\n\ndef _prefetch_for_subgraph(subg, dataloader):\n    node_feats, edge_feats = {}, {}\n    _prefetch_update_feats(\n        node_feats,\n        subg._node_frames,\n        subg.ntypes,\n        dataloader.graph.get_node_storage,\n        NID,\n        dataloader.device,\n        dataloader.pin_prefetcher,\n        dataloader.graph._gpu_caches[\"node\"],\n    )\n    _prefetch_update_feats(\n        edge_feats,\n        subg._edge_frames,\n        subg.canonical_etypes,\n        dataloader.graph.get_edge_storage,\n        EID,\n        dataloader.device,\n        dataloader.pin_prefetcher,\n        dataloader.graph._gpu_caches[\"edge\"],\n 
   )\n    return _PrefetchedGraphFeatures(node_feats, edge_feats)\n\n\ndef _prefetch_for(item, dataloader):\n    if isinstance(item, DGLGraph):\n        return _prefetch_for_subgraph(item, dataloader)\n    elif isinstance(item, LazyFeature):\n        return dataloader.other_storages[item.name].fetch(\n            item.id_, dataloader.device, dataloader.pin_prefetcher\n        )\n    else:\n        return None\n\n\ndef _await_or_return(x):\n    if hasattr(x, \"wait\"):\n        return x.wait()\n    elif isinstance(x, _PrefetchedGraphFeatures):\n        node_feats = recursive_apply(x.node_feats, _await_or_return)\n        edge_feats = recursive_apply(x.edge_feats, _await_or_return)\n        return _PrefetchedGraphFeatures(node_feats, edge_feats)\n    else:\n        return x\n\n\ndef _record_stream(x, stream):\n    if stream is None:\n        return x\n    if hasattr(x, \"record_stream\"):\n        x.record_stream(stream)\n        return x\n    elif isinstance(x, _PrefetchedGraphFeatures):\n        node_feats = recursive_apply(x.node_feats, _record_stream, stream)\n        edge_feats = recursive_apply(x.edge_feats, _record_stream, stream)\n        return _PrefetchedGraphFeatures(node_feats, edge_feats)\n    else:\n        return x\n\n\ndef _prefetch(batch, dataloader, stream):\n    # feats has the same nested structure as batch, except that\n    # (1) each subgraph is replaced with a pair of node features and edge features, both\n    #     being dictionaries whose keys are (type_id, column_name) and values are either\n    #     tensors or futures.\n    # (2) each LazyFeature object is replaced with a tensor or future.\n    # (3) everything else is replaced with None.\n    #\n    # Once the futures are fetched, this function waits for them to complete by\n    # calling their wait() methods.\n    if stream is not None:\n        current_stream = torch.cuda.current_stream()\n        current_stream.wait_stream(stream)\n    else:\n        current_stream = None\n    with 
torch.cuda.stream(stream):\n        # fetch node/edge features\n        feats = recursive_apply(batch, _prefetch_for, dataloader)\n        feats = recursive_apply(feats, _await_or_return)\n        feats = recursive_apply(feats, _record_stream, current_stream)\n        # transfer input nodes/seed nodes/subgraphs\n        batch = recursive_apply(\n            batch, lambda x: x.to(dataloader.device, non_blocking=True)\n        )\n        batch = recursive_apply(batch, _record_stream, current_stream)\n    stream_event = stream.record_event() if stream is not None else None\n    return batch, feats, stream_event\n\n\ndef _assign_for(item, feat):\n    if isinstance(item, DGLGraph):\n        subg = item\n        for (tid, key), value in feat.node_feats.items():\n            assert isinstance(subg._node_frames[tid][key], LazyFeature)\n            subg._node_frames[tid][key] = value\n        for (tid, key), value in feat.edge_feats.items():\n            assert isinstance(subg._edge_frames[tid][key], LazyFeature)\n            subg._edge_frames[tid][key] = value\n        return subg\n    elif isinstance(item, LazyFeature):\n        return feat\n    else:\n        return item\n\n\ndef _put_if_event_not_set(queue, result, event):\n    while not event.is_set():\n        try:\n            queue.put(result, timeout=1.0)\n            break\n        except Full:\n            continue\n\n\ndef _prefetcher_entry(\n    dataloader_it, dataloader, queue, num_threads, stream, done_event\n):\n    # PyTorch will set the number of threads to 1 which slows down pin_memory() calls\n    # in main process if a prefetching thread is created.\n    if num_threads is not None:\n        torch.set_num_threads(num_threads)\n\n    try:\n        while not done_event.is_set():\n            try:\n                batch = next(dataloader_it)\n            except StopIteration:\n                break\n            batch = recursive_apply(\n                batch, restore_parent_storage_columns, 
dataloader.graph\n            )\n            batch, feats, stream_event = _prefetch(batch, dataloader, stream)\n            _put_if_event_not_set(\n                queue, (batch, feats, stream_event, None), done_event\n            )\n        _put_if_event_not_set(queue, (None, None, None, None), done_event)\n    except:  # pylint: disable=bare-except\n        _put_if_event_not_set(\n            queue,\n            (None, None, None, ExceptionWrapper(where=\"in prefetcher\")),\n            done_event,\n        )\n\n\n# DGLGraphs have the semantics of lazy feature slicing with subgraphs.  Such behavior depends\n# on that DGLGraph's ndata and edata are maintained by Frames.  So to maintain compatibility\n# with older code, DGLGraphs and other graph storages are handled separately: (1)\n# DGLGraphs will preserve the lazy feature slicing for subgraphs.  (2) Other graph storages\n# will not have lazy feature slicing; all feature slicing will be eager.\ndef remove_parent_storage_columns(item, g):\n    \"\"\"Removes the storage objects in the given graphs' Frames if it is a sub-frame of the\n    given parent graph, so that the storages are not serialized during IPC from PyTorch\n    DataLoader workers.\n    \"\"\"\n    if not isinstance(item, DGLGraph) or not isinstance(g, DGLGraph):\n        return item\n\n    for subframe, frame in zip(\n        itertools.chain(item._node_frames, item._edge_frames),\n        itertools.chain(g._node_frames, g._edge_frames),\n    ):\n        for key in list(subframe.keys()):\n            subcol = subframe._columns[key]  # directly get the column object\n            if isinstance(subcol, LazyFeature):\n                continue\n            col = frame._columns.get(key, None)\n            if col is None:\n                continue\n            if col.storage is subcol.storage:\n                subcol.storage = None\n    return item\n\n\ndef restore_parent_storage_columns(item, g):\n    \"\"\"Restores the storage objects in the given graphs' 
Frames if it is a sub-frame of the\n    given parent graph (i.e. when the storage object is None).\n    \"\"\"\n    if not isinstance(item, DGLGraph) or not isinstance(g, DGLGraph):\n        return item\n\n    for subframe, frame in zip(\n        itertools.chain(item._node_frames, item._edge_frames),\n        itertools.chain(g._node_frames, g._edge_frames),\n    ):\n        for key in subframe.keys():\n            subcol = subframe._columns[key]\n            if isinstance(subcol, LazyFeature):\n                continue\n            col = frame._columns.get(key, None)\n            if col is None:\n                continue\n            if subcol.storage is None:\n                subcol.storage = col.storage\n    return item\n\n\nclass _PrefetchingIter(object):\n    def __init__(self, dataloader, dataloader_it, num_threads=None):\n        self.queue = Queue(1)\n        self.dataloader_it = dataloader_it\n        self.dataloader = dataloader\n        self.num_threads = num_threads\n\n        self.use_thread = dataloader.use_prefetch_thread\n        self.use_alternate_streams = dataloader.use_alternate_streams\n        self.device = self.dataloader.device\n        if self.use_alternate_streams and self.device.type == \"cuda\":\n            self.stream = torch.cuda.Stream(device=self.device)\n        else:\n            self.stream = None\n        self._shutting_down = False\n        if self.use_thread:\n            self._done_event = threading.Event()\n            thread = threading.Thread(\n                target=_prefetcher_entry,\n                args=(\n                    dataloader_it,\n                    dataloader,\n                    self.queue,\n                    num_threads,\n                    self.stream,\n                    self._done_event,\n                ),\n                daemon=True,\n            )\n            thread.start()\n            self.thread = thread\n\n    def __iter__(self):\n        return self\n\n    def _shutdown(self):\n        # 
Sometimes when Python is exiting complicated operations like\n        # self.queue.get_nowait() will hang.  So we set it to no-op and let Python handle\n        # the rest since the thread is daemonic.\n        # PyTorch takes the same solution.\n        if PYTHON_EXIT_STATUS is True or PYTHON_EXIT_STATUS is None:\n            return\n        if not self._shutting_down:\n            try:\n                self._shutting_down = True\n                self._done_event.set()\n\n                try:\n                    self.queue.get_nowait()  # In case the thread is blocking on put().\n                except:  # pylint: disable=bare-except\n                    pass\n\n                self.thread.join()\n            except:  # pylint: disable=bare-except\n                pass\n\n    def __del__(self):\n        if self.use_thread:\n            self._shutdown()\n\n    def _next_non_threaded(self):\n        batch = next(self.dataloader_it)\n        batch = recursive_apply(\n            batch, restore_parent_storage_columns, self.dataloader.graph\n        )\n        batch, feats, stream_event = _prefetch(\n            batch, self.dataloader, self.stream\n        )\n        return batch, feats, stream_event\n\n    def _next_threaded(self):\n        try:\n            batch, feats, stream_event, exception = self.queue.get(\n                timeout=prefetcher_timeout\n            )\n        except Empty:\n            raise RuntimeError(\n                f\"Prefetcher thread timed out at {prefetcher_timeout} seconds.\"\n            )\n        if batch is None:\n            self.thread.join()\n            if exception is None:\n                raise StopIteration\n            exception.reraise()\n        return batch, feats, stream_event\n\n    def __next__(self):\n        batch, feats, stream_event = (\n            self._next_non_threaded()\n            if not self.use_thread\n            else self._next_threaded()\n        )\n        batch = recursive_apply_pair(batch, feats, 
_assign_for)\n        if stream_event is not None:\n            stream_event.wait()\n        return batch\n\n\n# Make them classes to work with pickling in mp.spawn\nclass CollateWrapper(object):\n    \"\"\"Wraps a collate function with :func:`remove_parent_storage_columns` for serializing\n    from PyTorch DataLoader workers.\n    \"\"\"\n\n    def __init__(self, sample_func, g, use_uva, device):\n        self.sample_func = sample_func\n        self.g = g\n        self.use_uva = use_uva\n        self.device = device\n\n    def __call__(self, items):\n        graph_device = getattr(self.g, \"device\", None)\n        if self.use_uva or (graph_device != torch.device(\"cpu\")):\n            # Only copy the indices to the given device if in UVA mode or the graph\n            # is not on CPU.\n            items = recursive_apply(items, lambda x: x.to(self.device))\n        batch = self.sample_func(self.g, items)\n        return recursive_apply(batch, remove_parent_storage_columns, self.g)\n\n\nclass WorkerInitWrapper(object):\n    \"\"\"Wraps the :attr:`worker_init_fn` argument of the DataLoader to set the number of DGL\n    OMP threads to 1 for PyTorch DataLoader workers.\n    \"\"\"\n\n    def __init__(self, func):\n        self.func = func\n\n    def __call__(self, worker_id):\n        set_num_threads(1)\n        if self.func is not None:\n            self.func(worker_id)\n\n\ndef create_tensorized_dataset(\n    indices,\n    batch_size,\n    drop_last,\n    use_ddp,\n    ddp_seed,\n    shuffle,\n    use_shared_memory,\n):\n    \"\"\"Converts a given indices tensor to a TensorizedDataset, an IterableDataset\n    that returns views of the original tensor, to reduce overhead from having\n    a list of scalar tensors in default PyTorch DataLoader implementation.\n    \"\"\"\n    if use_ddp:\n        # DDP always uses shared memory\n        return DDPTensorizedDataset(\n            indices, batch_size, drop_last, ddp_seed, shuffle\n        )\n    else:\n        return 
TensorizedDataset(\n            indices, batch_size, drop_last, shuffle, use_shared_memory\n        )\n\n\ndef _get_device(device):\n    device = torch.device(device)\n    if device.type == \"cuda\" and device.index is None:\n        device = torch.device(\"cuda\", torch.cuda.current_device())\n    return device\n\n\nclass DataLoader(torch.utils.data.DataLoader):\n    \"\"\"Sampled graph data loader. Wrap a :class:`~dgl.DGLGraph` and a\n    :class:`~dgl.dataloading.Sampler` into an iterable over mini-batches of samples.\n\n    DGL's ``DataLoader`` extends PyTorch's ``DataLoader`` by handling creation\n    and transmission of graph samples. It supports iterating over a set of nodes,\n    edges or any kinds of indices to get samples in the form of ``DGLGraph``, message\n    flow graphs (MFGs), or any other structures necessary to train a graph neural network.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph.\n    indices : Tensor or dict[ntype, Tensor]\n        The set of indices.  It can either be a tensor of integer indices or a dictionary\n        of types and indices.\n\n        The actual meaning of the indices is defined by the :meth:`sample` method of\n        :attr:`graph_sampler`.\n    graph_sampler : dgl.dataloading.Sampler\n        The subgraph sampler.\n    device : device context, optional\n        The device of the generated MFGs in each iteration, which should be a\n        PyTorch device object (e.g., ``torch.device``).\n\n        By default this value is None. 
If :attr:`use_uva` is True, MFGs and graphs will\n        be generated in torch.cuda.current_device(), otherwise generated in the same device\n        of :attr:`g`.\n    use_ddp : boolean, optional\n        If True, tells the DataLoader to split the training set for each\n        participating process appropriately using\n        :class:`torch.utils.data.distributed.DistributedSampler`.\n\n        Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.\n    ddp_seed : int, optional\n        The seed for shuffling the dataset in\n        :class:`torch.utils.data.distributed.DistributedSampler`.\n\n        Only effective when :attr:`use_ddp` is True.\n    use_uva : bool, optional\n        Whether to use Unified Virtual Addressing (UVA) to directly sample the graph\n        and slice the features from CPU into GPU.  Setting it to True will pin the\n        graph and feature tensors into pinned memory.\n\n        If True, requires that :attr:`indices` must have the same device as the\n        :attr:`device` argument.\n\n        Default: False.\n    use_prefetch_thread : bool, optional\n        (Advanced option)\n        Spawns a new Python thread to perform feature slicing\n        asynchronously.  Can make things faster at the cost of GPU memory.\n\n        Default: True if the graph is on CPU and :attr:`device` is CUDA.  False otherwise.\n    use_alternate_streams : bool, optional\n        (Advanced option)\n        Whether to slice and transfer the features to GPU on a non-default stream.\n\n        Default: True if the graph is on CPU, :attr:`device` is CUDA, and :attr:`use_uva`\n        is False.  False otherwise.\n    pin_prefetcher : bool, optional\n        (Advanced option)\n        Whether to pin the feature tensors into pinned memory.\n\n        Default: True if the graph is on CPU and :attr:`device` is CUDA.  False otherwise.\n    gpu_cache : dict[dict], optional\n        Which node and edge features to cache using HugeCTR gpu_cache. 
Example:\n        {\"node\": {\"features\": 500000}, \"edge\": {\"types\": 4000000}} would\n        indicate that we want to cache 500k of the node \"features\" and 4M of the\n        edge \"types\" in GPU caches.\n\n        Is supported only on NVIDIA GPUs with compute capability 70 or above.\n        The dictionary holds the keys of features along with the corresponding\n        cache sizes. Please see\n        https://github.com/NVIDIA-Merlin/HugeCTR/blob/main/gpu_cache/ReadMe.md\n        for further reference.\n    kwargs : dict\n        Key-word arguments to be passed to the parent PyTorch\n        :py:class:`torch.utils.data.DataLoader` class. Common arguments are:\n\n          - ``batch_size`` (int): The number of indices in each batch.\n          - ``drop_last`` (bool): Whether to drop the last incomplete batch.\n          - ``shuffle`` (bool): Whether to randomly shuffle the indices at each epoch.\n\n\n    Examples\n    --------\n    To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on\n    a homogeneous graph where each node takes messages from 15 neighbors on the\n    first layer, 10 neighbors on the second, and 5 neighbors on the third (assume\n    the backend is PyTorch):\n\n    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, train_nid, sampler,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, output_nodes, blocks in dataloader:\n    ...     train_on(input_nodes, output_nodes, blocks)\n\n    **Using with Distributed Data Parallel**\n\n    If you are using PyTorch's distributed training (e.g. when using\n    :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by turning\n    on the `use_ddp` option:\n\n    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     
g, train_nid, sampler, use_ddp=True,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for epoch in range(start_epoch, n_epochs):\n    ...     for input_nodes, output_nodes, blocks in dataloader:\n    ...         train_on(input_nodes, output_nodes, blocks)\n\n    Notes\n    -----\n    Please refer to\n    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`\n    and :ref:`User Guide Section 6 <guide-minibatch>` for usage.\n\n    **Tips for selecting the proper device**\n\n    * If the input graph :attr:`g` is on GPU, the output device :attr:`device` must be the same GPU\n      and :attr:`num_workers` must be zero. In this case, the sampling and subgraph construction\n      will take place on the GPU. This is the recommended setting when using a single-GPU and\n      the whole graph fits in GPU memory.\n\n    * If the input graph :attr:`g` is on CPU while the output device :attr:`device` is GPU, then\n      depending on the value of :attr:`use_uva`:\n\n      - If :attr:`use_uva` is set to True, the sampling and subgraph construction will happen\n        on GPU even if the GPU itself cannot hold the entire graph. This is the recommended\n        setting unless there are operations not supporting UVA. 
:attr:`num_workers` must be 0\n        in this case.\n\n      - Otherwise, both the sampling and subgraph construction will take place on the CPU.\n    \"\"\"\n\n    def __init__(\n        self,\n        graph,\n        indices,\n        graph_sampler,\n        device=None,\n        use_ddp=False,\n        ddp_seed=0,\n        batch_size=1,\n        drop_last=False,\n        shuffle=False,\n        use_prefetch_thread=None,\n        use_alternate_streams=None,\n        pin_prefetcher=None,\n        use_uva=False,\n        gpu_cache=None,\n        **kwargs,\n    ):\n        # (BarclayII) PyTorch Lightning sometimes will recreate a DataLoader from an existing\n        # DataLoader with modifications to the original arguments.  The arguments are retrieved\n        # from the attributes with the same name, and because we change certain arguments\n        # when calling super().__init__() (e.g. batch_size attribute is None even if the\n        # batch_size argument is not, so the next DataLoader's batch_size argument will be\n        # None), we cannot reinitialize the DataLoader with attributes from the previous\n        # DataLoader directly.\n        # A workaround is to check whether \"collate_fn\" appears in kwargs.  
If \"collate_fn\"\n        # is indeed in kwargs and it's already a CollateWrapper object, we can assume that\n        # the arguments come from a previously created DGL DataLoader, and directly initialize\n        # the new DataLoader from kwargs without any changes.\n        if isinstance(kwargs.get(\"collate_fn\", None), CollateWrapper):\n            assert batch_size is None  # must be None\n            # restore attributes\n            self.graph = graph\n            self.indices = indices\n            self.graph_sampler = graph_sampler\n            self.device = device\n            self.use_ddp = use_ddp\n            self.ddp_seed = ddp_seed\n            self.shuffle = shuffle\n            self.drop_last = drop_last\n            self.use_prefetch_thread = use_prefetch_thread\n            self.use_alternate_streams = use_alternate_streams\n            self.pin_prefetcher = pin_prefetcher\n            self.use_uva = use_uva\n            kwargs[\"batch_size\"] = None\n            super().__init__(**kwargs)\n            return\n\n        # (BarclayII) I hoped that pin_prefetcher can be merged into PyTorch's native\n        # pin_memory argument.  But our neighbor samplers and subgraph samplers\n        # return indices, which could be CUDA tensors (e.g. during UVA sampling)\n        # hence cannot be pinned.  PyTorch's native pin memory thread does not ignore\n        # CUDA tensors when pinning and will crash.  To enable pin memory for prefetching\n        # features and disable pin memory for sampler's return value, I had to use\n        # a different argument.  
Of course I could change the meaning of pin_memory\n        # to pinning prefetched features and disable pin memory for sampler's returns\n        # no matter what, but I doubt if it's reasonable.\n        self.graph = graph\n        self.indices = indices  # For PyTorch-Lightning\n        num_workers = kwargs.get(\"num_workers\", 0)\n\n        indices_device = None\n        try:\n            if isinstance(indices, Mapping):\n                indices = {\n                    k: (torch.tensor(v) if not torch.is_tensor(v) else v)\n                    for k, v in indices.items()\n                }\n                indices_device = next(iter(indices.values())).device\n            else:\n                indices = (\n                    torch.tensor(indices)\n                    if not torch.is_tensor(indices)\n                    else indices\n                )\n                indices_device = indices.device\n        except:  # pylint: disable=bare-except\n            # ignore when it fails to convert to torch Tensors.\n            pass\n\n        if indices_device is None:\n            if not hasattr(indices, \"device\"):\n                raise AttributeError(\n                    'Custom indices dataset requires a \"device\" \\\n                attribute indicating where the indices is.'\n                )\n            indices_device = indices.device\n\n        if device is None:\n            if use_uva:\n                device = torch.cuda.current_device()\n            else:\n                device = self.graph.device\n        self.device = _get_device(device)\n\n        # Sanity check - we only check for DGLGraphs.\n        if isinstance(self.graph, DGLGraph):\n            # Check graph and indices device as well as num_workers\n            if use_uva:\n                if self.graph.device.type != \"cpu\":\n                    raise ValueError(\n                        \"Graph must be on CPU if UVA sampling is enabled.\"\n                    )\n                if 
num_workers > 0:\n                    raise ValueError(\n                        \"num_workers must be 0 if UVA sampling is enabled.\"\n                    )\n\n                # Create all the formats and pin the features - custom GraphStorages\n                # will need to do that themselves.\n                self.graph.create_formats_()\n                self.graph.pin_memory_()\n            else:\n                if self.graph.device != indices_device:\n                    raise ValueError(\n                        \"Expect graph and indices to be on the same device when use_uva=False. \"\n                    )\n                if self.graph.device.type == \"cuda\" and num_workers > 0:\n                    raise ValueError(\n                        \"num_workers must be 0 if graph and indices are on CUDA.\"\n                    )\n                if self.graph.device.type == \"cpu\" and num_workers > 0:\n                    # Instantiate all the formats if the number of workers is greater than 0.\n                    self.graph.create_formats_()\n\n            # Check pin_prefetcher and use_prefetch_thread - should be only effective\n            # if performing CPU sampling but output device is CUDA\n            if (\n                self.device.type == \"cuda\"\n                and self.graph.device.type == \"cpu\"\n                and not use_uva\n            ):\n                if pin_prefetcher is None:\n                    pin_prefetcher = True\n                if use_prefetch_thread is None:\n                    use_prefetch_thread = True\n            else:\n                if pin_prefetcher is True:\n                    raise ValueError(\n                        \"pin_prefetcher=True is only effective when device=cuda and \"\n                        \"sampling is performed on CPU.\"\n                    )\n                if pin_prefetcher is None:\n                    pin_prefetcher = False\n\n                if use_prefetch_thread is True:\n           
         raise ValueError(\n                        \"use_prefetch_thread=True is only effective when device=cuda and \"\n                        \"sampling is performed on CPU.\"\n                    )\n                if use_prefetch_thread is None:\n                    use_prefetch_thread = False\n\n            # Check use_alternate_streams\n            if use_alternate_streams is None:\n                use_alternate_streams = (\n                    self.device.type == \"cuda\"\n                    and self.graph.device.type == \"cpu\"\n                    and not use_uva\n                    and is_tensor_adaptor_enabled()\n                )\n            elif use_alternate_streams and not is_tensor_adaptor_enabled():\n                dgl_warning(\n                    \"use_alternate_streams is turned off because \"\n                    \"TensorAdaptor is not available.\"\n                )\n                use_alternate_streams = False\n\n        if torch.is_tensor(indices) or (\n            isinstance(indices, Mapping)\n            and all(torch.is_tensor(v) for v in indices.values())\n        ):\n            self.dataset = create_tensorized_dataset(\n                indices,\n                batch_size,\n                drop_last,\n                use_ddp,\n                ddp_seed,\n                shuffle,\n                kwargs.get(\"persistent_workers\", False),\n            )\n        else:\n            self.dataset = indices\n\n        self.ddp_seed = ddp_seed\n        self.use_ddp = use_ddp\n        self.use_uva = use_uva\n        self.shuffle = shuffle\n        self.drop_last = drop_last\n        self.graph_sampler = graph_sampler\n        self.use_alternate_streams = use_alternate_streams\n        self.pin_prefetcher = pin_prefetcher\n        self.use_prefetch_thread = use_prefetch_thread\n        self.cpu_affinity_enabled = False\n\n        worker_init_fn = WorkerInitWrapper(kwargs.pop(\"worker_init_fn\", None))\n\n        self.other_storages = 
{}\n\n        _init_gpu_caches(self.graph, gpu_cache)\n\n        super().__init__(\n            self.dataset,\n            collate_fn=CollateWrapper(\n                self.graph_sampler.sample, graph, self.use_uva, self.device\n            ),\n            batch_size=None,\n            pin_memory=self.pin_prefetcher,\n            worker_init_fn=worker_init_fn,\n            **kwargs,\n        )\n\n    def __iter__(self):\n        if (\n            self.device.type == \"cpu\"\n            and hasattr(psutil.Process, \"cpu_affinity\")\n            and not self.cpu_affinity_enabled\n        ):\n            link = \"https://docs.dgl.ai/tutorials/cpu/cpu_best_practises.html\"\n            dgl_warning(\n                f\"Dataloader CPU affinity opt is not enabled, consider switching it on \"\n                f\"(see enable_cpu_affinity() or CPU best practices for DGL [{link}])\"\n            )\n\n        if self.shuffle:\n            self.dataset.shuffle()\n        # When using multiprocessing PyTorch sometimes set the number of PyTorch threads to 1\n        # when spawning new Python threads.  
This drastically slows down pinning features.\n        num_threads = torch.get_num_threads() if self.num_workers > 0 else None\n        return _PrefetchingIter(\n            self, super().__iter__(), num_threads=num_threads\n        )\n\n    @contextmanager\n    def enable_cpu_affinity(\n        self, loader_cores=None, compute_cores=None, verbose=True\n    ):\n        \"\"\"Helper method for enabling cpu affinity for compute threads and dataloader workers\n        Only for CPU devices\n        Uses only NUMA node 0 by default for multi-node systems\n\n        Parameters\n        ----------\n        loader_cores : [int] (optional)\n            List of cpu cores to which dataloader workers should affinitize to.\n            default: node0_cores[0:num_workers]\n\n        compute_cores : [int] (optional)\n            List of cpu cores to which compute threads should affinitize to\n            default: node0_cores[num_workers:]\n\n        verbose : bool (optional)\n            If True, affinity information will be printed to the console\n\n        Usage\n        -----\n        with dataloader.enable_cpu_affinity():\n            <training loop>\n        \"\"\"\n        if self.device.type == \"cpu\":\n            if not self.num_workers > 0:\n                raise Exception(\n                    \"ERROR: affinity should be used with at least one DL worker\"\n                )\n            if loader_cores and len(loader_cores) != self.num_workers:\n                raise Exception(\n                    \"ERROR: cpu_affinity incorrect \"\n                    \"number of loader_cores={} for num_workers={}\".format(\n                        loader_cores, self.num_workers\n                    )\n                )\n\n            # False positive E0203 (access-member-before-definition) linter warning\n            worker_init_fn_old = self.worker_init_fn  # pylint: disable=E0203\n            affinity_old = psutil.Process().cpu_affinity()\n            nthreads_old = 
get_num_threads()\n\n            compute_cores = compute_cores[:] if compute_cores else []\n            loader_cores = loader_cores[:] if loader_cores else []\n\n            def init_fn(worker_id):\n                try:\n                    psutil.Process().cpu_affinity([loader_cores[worker_id]])\n                except:\n                    raise Exception(\n                        \"ERROR: cannot use affinity id={} cpu={}\".format(\n                            worker_id, loader_cores\n                        )\n                    )\n\n                worker_init_fn_old(worker_id)\n\n            if not loader_cores or not compute_cores:\n                numa_info = get_numa_nodes_cores()\n                if numa_info and len(numa_info[0]) > self.num_workers:\n                    # take one thread per each node 0 core\n                    node0_cores = [cpus[0] for core_id, cpus in numa_info[0]]\n                else:\n                    node0_cores = list(range(psutil.cpu_count(logical=False)))\n\n                if len(node0_cores) < self.num_workers:\n                    raise Exception(\"ERROR: more workers than available cores\")\n\n                loader_cores = loader_cores or node0_cores[0 : self.num_workers]\n                compute_cores = [\n                    cpu for cpu in node0_cores if cpu not in loader_cores\n                ]\n\n            try:\n                psutil.Process().cpu_affinity(compute_cores)\n                set_num_threads(len(compute_cores))\n                self.worker_init_fn = init_fn\n\n                self.cpu_affinity_enabled = True\n                if verbose:\n                    print(\n                        f\"{self.num_workers} DL workers are assigned to cpus \"\n                        f\"{loader_cores}, main process will use cpus \"\n                        f\"{compute_cores}\"\n                    )\n\n                yield\n            finally:\n                # restore omp_num_threads and cpu affinity\n        
        psutil.Process().cpu_affinity(affinity_old)\n                set_num_threads(nthreads_old)\n                self.worker_init_fn = worker_init_fn_old\n\n                self.cpu_affinity_enabled = False\n        else:\n            yield\n\n    # To allow data other than node/edge data to be prefetched.\n    def attach_data(self, name, data):\n        \"\"\"Add a data other than node and edge features for prefetching.\"\"\"\n        self.other_storages[name] = wrap_storage(data)\n\n\n######## Graph DataLoaders ########\n# GraphDataLoader loads a set of graphs so it's not relevant to the above.  They are currently\n# copied from the old DataLoader implementation.\n\n\ndef _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed):\n    # Note: will change the content of dataloader_kwargs\n    dist_sampler_kwargs = {\"shuffle\": dataloader_kwargs.get(\"shuffle\", False)}\n    dataloader_kwargs[\"shuffle\"] = False\n    dist_sampler_kwargs[\"seed\"] = ddp_seed\n    dist_sampler_kwargs[\"drop_last\"] = dataloader_kwargs.get(\"drop_last\", False)\n    dataloader_kwargs[\"drop_last\"] = False\n\n    return DistributedSampler(dataset, **dist_sampler_kwargs)\n\n\nclass GraphCollator(object):\n    \"\"\"Given a set of graphs as well as their graph-level data, the collate function will batch the\n    graphs into a batched graph, and stack the tensors into a single bigger tensor.  If the\n    example is a container (such as sequences or mapping), the collate function preserves\n    the structure and collates each of the elements recursively.\n\n    If the set of graphs has no graph-level data, the collate function will yield a batched graph.\n\n    Examples\n    --------\n    To train a GNN for graph classification on a set of graphs in ``dataset`` (assume\n    the backend is PyTorch):\n\n    >>> dataloader = dgl.dataloading.GraphDataLoader(\n    ...     
dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for batched_graph, labels in dataloader:\n    ...     train_on(batched_graph, labels)\n    \"\"\"\n\n    def __init__(self):\n        self.graph_collate_err_msg_format = (\n            \"graph_collate: batch must contain DGLGraph, tensors, numpy arrays, \"\n            \"numbers, dicts or lists; found {}\"\n        )\n        self.np_str_obj_array_pattern = re.compile(r\"[SaUO]\")\n\n    # This implementation is based on torch.utils.data._utils.collate.default_collate\n    def collate(self, items):\n        \"\"\"This function is similar to ``torch.utils.data._utils.collate.default_collate``.\n        It combines the sampled graphs and corresponding graph-level data\n        into a batched graph and tensors.\n\n        Parameters\n        ----------\n        items : list of data points or tuples\n            Elements in the list are expected to have the same length.\n            Each sub-element will be batched as a batched graph, or a\n            batched tensor correspondingly.\n\n        Returns\n        -------\n        A tuple of the batching results.\n        \"\"\"\n        elem = items[0]\n        elem_type = type(elem)\n        if isinstance(elem, DGLGraph):\n            batched_graphs = batch_graphs(items)\n            return batched_graphs\n        elif F.is_tensor(elem):\n            return F.stack(items, 0)\n        elif (\n            elem_type.__module__ == \"numpy\"\n            and elem_type.__name__ != \"str_\"\n            and elem_type.__name__ != \"string_\"\n        ):\n            if (\n                elem_type.__name__ == \"ndarray\"\n                or elem_type.__name__ == \"memmap\"\n            ):\n                # array of string classes and object\n                if (\n                    self.np_str_obj_array_pattern.search(elem.dtype.str)\n                    is not None\n                ):\n                    raise TypeError(\n                       
 self.graph_collate_err_msg_format.format(elem.dtype)\n                    )\n\n                return self.collate([F.tensor(b) for b in items])\n            elif elem.shape == ():  # scalars\n                return F.tensor(items)\n        elif isinstance(elem, float):\n            return F.tensor(items, dtype=F.float64)\n        elif isinstance(elem, int):\n            return F.tensor(items)\n        elif isinstance(elem, (str, bytes)):\n            return items\n        elif isinstance(elem, Mapping):\n            return {key: self.collate([d[key] for d in items]) for key in elem}\n        elif isinstance(elem, tuple) and hasattr(elem, \"_fields\"):  # namedtuple\n            return elem_type(\n                *(self.collate(samples) for samples in zip(*items))\n            )\n        elif isinstance(elem, Sequence):\n            # check to make sure that the elements in batch have consistent size\n            item_iter = iter(items)\n            elem_size = len(next(item_iter))\n            if not all(len(elem) == elem_size for elem in item_iter):\n                raise RuntimeError(\n                    \"each element in list of batch should be of equal size\"\n                )\n            transposed = zip(*items)\n            return [self.collate(samples) for samples in transposed]\n\n        raise TypeError(self.graph_collate_err_msg_format.format(elem_type))\n\n\nclass GraphDataLoader(torch.utils.data.DataLoader):\n    \"\"\"Batched graph data loader.\n\n    PyTorch dataloader for batch-iterating over a set of graphs, generating the batched\n    graph and corresponding label tensor (if provided) of the said minibatch.\n\n    Parameters\n    ----------\n    dataset : torch.utils.data.Dataset\n        The dataset to load graphs from.\n    collate_fn : Function, default is None\n        The customized collate function. 
Will use the default collate\n        function if not given.\n    use_ddp : boolean, optional\n        If True, tells the DataLoader to split the training set for each\n        participating process appropriately using\n        :class:`torch.utils.data.distributed.DistributedSampler`.\n\n        Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.\n    ddp_seed : int, optional\n        The seed for shuffling the dataset in\n        :class:`torch.utils.data.distributed.DistributedSampler`.\n\n        Only effective when :attr:`use_ddp` is True.\n    kwargs : dict\n        Key-word arguments to be passed to the parent PyTorch\n        :py:class:`torch.utils.data.DataLoader` class. Common arguments are:\n\n          - ``batch_size`` (int): The number of indices in each batch.\n          - ``drop_last`` (bool): Whether to drop the last incomplete batch.\n          - ``shuffle`` (bool): Whether to randomly shuffle the indices at each epoch.\n\n    Examples\n    --------\n    To train a GNN for graph classification on a set of graphs in ``dataset``:\n\n    >>> dataloader = dgl.dataloading.GraphDataLoader(\n    ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for batched_graph, labels in dataloader:\n    ...     train_on(batched_graph, labels)\n\n    **With Distributed Data Parallel**\n\n    If you are using PyTorch's distributed training (e.g. when using\n    :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by\n    turning on the :attr:`use_ddp` option:\n\n    >>> dataloader = dgl.dataloading.GraphDataLoader(\n    ...     dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for epoch in range(start_epoch, n_epochs):\n    ...     dataloader.set_epoch(epoch)\n    ...     for batched_graph, labels in dataloader:\n    ...         
train_on(batched_graph, labels)\n    \"\"\"\n\n    collator_arglist = inspect.getfullargspec(GraphCollator).args\n\n    def __init__(\n        self, dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs\n    ):\n        collator_kwargs = {}\n        dataloader_kwargs = {}\n        for k, v in kwargs.items():\n            if k in self.collator_arglist:\n                collator_kwargs[k] = v\n            else:\n                dataloader_kwargs[k] = v\n\n        self.use_ddp = use_ddp\n        if use_ddp:\n            self.dist_sampler = _create_dist_sampler(\n                dataset, dataloader_kwargs, ddp_seed\n            )\n            dataloader_kwargs[\"sampler\"] = self.dist_sampler\n\n        if collate_fn is None and kwargs.get(\"batch_size\", 1) is not None:\n            collate_fn = GraphCollator(**collator_kwargs).collate\n\n        super().__init__(\n            dataset=dataset, collate_fn=collate_fn, **dataloader_kwargs\n        )\n\n    def set_epoch(self, epoch):\n        \"\"\"Sets the epoch number for the underlying sampler which ensures all replicas\n        to use a different ordering for each epoch.\n\n        Only available when :attr:`use_ddp` is True.\n\n        Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.\n\n        Parameters\n        ----------\n        epoch : int\n            The epoch number.\n        \"\"\"\n        if self.use_ddp:\n            self.dist_sampler.set_epoch(epoch)\n        else:\n            raise DGLError(\"set_epoch is only available when use_ddp is True.\")\n\n\nclass NodeCollator:\n    \"\"\"Deprecated. Please use :class:`~dgl.distributed.NodeCollator` instead.\"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        dgl_warning(\n            \"NodeCollator is defined in dgl.distributed This class is for \"\n            \"backward compatibility and will be removed soon. 
Please update \"\n            \"your code to use `dgl.distributed.NodeCollator`.\"\n        )\n        from ..distributed import NodeCollator as NewNodeCollator\n\n        return NewNodeCollator(*args, **kwargs)\n\n\nclass EdgeCollator:\n    \"\"\"Deprecated. Please use :class:`~dgl.distributed.EdgeCollator` instead.\"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        dgl_warning(\n            \"EdgeCollator is defined in dgl.distributed This class is for \"\n            \"backward compatibility and will be removed soon. Please update \"\n            \"your code to use `dgl.distributed.EdgeCollator`.\"\n        )\n        from ..distributed import EdgeCollator as NewEdgeCollator\n\n        return NewEdgeCollator(*args, **kwargs)\n\n\ndef _remove_kwargs_dist(kwargs):\n    \"\"\"Deprecated.\"\"\"\n    if \"num_workers\" in kwargs:\n        del kwargs[\"num_workers\"]\n    if \"pin_memory\" in kwargs:\n        del kwargs[\"pin_memory\"]\n        print(\"Distributed DataLoaders do not support pin_memory.\")\n    return kwargs\n\n\nclass DistDataLoader:\n    \"\"\"Deprecated. Please use :class:`~dgl.distributed.DistDataLoader` instead.\"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        dgl_warning(\n            \"DistDataLoader is defined in dgl.distributed This class is for \"\n            \"backward compatibility and will be removed soon. Please update \"\n            \"your code to use `dgl.distributed.DistDataLoader`.\"\n        )\n        from ..distributed import DistDataLoader as NewDistDataLoader\n\n        return NewDistDataLoader(*args, **kwargs)\n\n\nclass DistNodeDataLoader:\n    \"\"\"Deprecated. Please use :class:`~dgl.distributed.DistNodeDataLoader`\n    instead.\n    \"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        dgl_warning(\n            \"dgl.dataloading.DistNodeDataLoader has been moved to \"\n            \"dgl.distributed.DistNodeDataLoader. This old class is deprecated \"\n            \"and will be removed soon. 
Please update your code to use the new \"\n            \"class.\"\n        )\n        from ..distributed import DistNodeDataLoader as NewDistNodeDataLoader\n\n        return NewDistNodeDataLoader(*args, **kwargs)\n\n\nclass DistEdgeDataLoader:\n    \"\"\"Deprecated. Please use :class:`~dgl.distributed.DistEdgeDataLoader`\n    instead.\n    \"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        dgl_warning(\n            \"dgl.dataloading.DistEdgeDataLoader has been moved to \"\n            \"dgl.distributed.DistEdgeDataLoader. This old class is deprecated \"\n            \"and will be removed soon. Please update your code to use the new \"\n            \"class.\"\n        )\n        from ..distributed import DistEdgeDataLoader as NewDistEdgeDataLoader\n\n        return NewDistEdgeDataLoader(*args, **kwargs)\n"
  },
  {
    "path": "python/dgl/dataloading/graphsaint.py",
    "content": "\"\"\"GraphSAINT samplers.\"\"\"\nfrom ..base import DGLError\nfrom ..random import choice\nfrom ..sampling import pack_traces, random_walk\nfrom .base import Sampler, set_edge_lazy_features, set_node_lazy_features\n\ntry:\n    import torch\nexcept ImportError:\n    pass\n\n\nclass SAINTSampler(Sampler):\n    \"\"\"Random node/edge/walk sampler from\n    `GraphSAINT: Graph Sampling Based Inductive Learning Method\n    <https://arxiv.org/abs/1907.04931>`__\n\n    For each call, the sampler samples a node subset and then returns a node induced subgraph.\n    There are three options for sampling node subsets:\n\n    - For :attr:`'node'` sampler, the probability to sample a node is in proportion\n      to its out-degree.\n    - The :attr:`'edge'` sampler first samples an edge subset and then use the\n      end nodes of the edges.\n    - The :attr:`'walk'` sampler uses the nodes visited by random walks. It uniformly selects\n      a number of root nodes and then performs a fixed-length random walk from each root node.\n\n    Parameters\n    ----------\n    mode : str\n        The sampler to use, which can be :attr:`'node'`, :attr:`'edge'`, or :attr:`'walk'`.\n    budget : int or tuple[int]\n        Sampler configuration.\n\n        - For :attr:`'node'` sampler, budget specifies the number of nodes\n          in each sampled subgraph.\n        - For :attr:`'edge'` sampler, budget specifies the number of edges\n          to sample for inducing a subgraph.\n        - For :attr:`'walk'` sampler, budget is a tuple. budget[0] specifies\n          the number of root nodes to generate random walks. budget[1] specifies\n          the length of a random walk.\n\n    cache : bool, optional\n        If False, it will not cache the probability arrays for sampling. 
Setting\n        it to False is required if you want to use the sampler across different graphs.\n    prefetch_ndata : list[str], optional\n        The node data to prefetch for the subgraph.\n\n        See :ref:`guide-minibatch-prefetching` for a detailed explanation of prefetching.\n    prefetch_edata : list[str], optional\n        The edge data to prefetch for the subgraph.\n\n        See :ref:`guide-minibatch-prefetching` for a detailed explanation of prefetching.\n    output_device : device, optional\n        The device of the output subgraphs.\n\n    Examples\n    --------\n\n    >>> import torch\n    >>> from dgl.dataloading import SAINTSampler, DataLoader\n    >>> num_iters = 1000\n    >>> sampler = SAINTSampler(mode='node', budget=6000)\n    >>> # Assume g.ndata['feat'] and g.ndata['label'] hold node features and labels\n    >>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4)\n    >>> for subg in dataloader:\n    ...     train_on(subg)\n    \"\"\"\n\n    def __init__(\n        self,\n        mode,\n        budget,\n        cache=True,\n        prefetch_ndata=None,\n        prefetch_edata=None,\n        output_device=\"cpu\",\n    ):\n        super().__init__()\n        self.budget = budget\n        if mode == \"node\":\n            self.sampler = self.node_sampler\n        elif mode == \"edge\":\n            self.sampler = self.edge_sampler\n        elif mode == \"walk\":\n            self.sampler = self.walk_sampler\n        else:\n            raise DGLError(\n                f\"Expect mode to be 'node', 'edge' or 'walk', got {mode}.\"\n            )\n\n        self.cache = cache\n        self.prob = None\n        self.prefetch_ndata = prefetch_ndata or []\n        self.prefetch_edata = prefetch_edata or []\n        self.output_device = output_device\n\n    def node_sampler(self, g):\n        \"\"\"Node ID sampler for random node sampler\"\"\"\n        # Alternatively, this can be realized by uniformly sampling an edge 
subset,\n        # and then take the src node of the sampled edges. However, the number of edges\n        # is typically much larger than the number of nodes.\n        if self.cache and self.prob is not None:\n            prob = self.prob\n        else:\n            prob = g.out_degrees().float().clamp(min=1)\n            if self.cache:\n                self.prob = prob\n        return (\n            torch.multinomial(prob, num_samples=self.budget, replacement=True)\n            .unique()\n            .type(g.idtype)\n        )\n\n    def edge_sampler(self, g):\n        \"\"\"Node ID sampler for random edge sampler\"\"\"\n        src, dst = g.edges()\n        if self.cache and self.prob is not None:\n            prob = self.prob\n        else:\n            in_deg = g.in_degrees().float().clamp(min=1)\n            out_deg = g.out_degrees().float().clamp(min=1)\n            # We can reduce the sample space by half if graphs are always symmetric.\n            prob = 1.0 / in_deg[dst.long()] + 1.0 / out_deg[src.long()]\n            prob /= prob.sum()\n            if self.cache:\n                self.prob = prob\n        sampled_edges = torch.unique(\n            choice(len(prob), size=self.budget, prob=prob)\n        )\n        sampled_nodes = torch.cat([src[sampled_edges], dst[sampled_edges]])\n        return sampled_nodes.unique().type(g.idtype)\n\n    def walk_sampler(self, g):\n        \"\"\"Node ID sampler for random walk sampler\"\"\"\n        num_roots, walk_length = self.budget\n        sampled_roots = torch.randint(0, g.num_nodes(), (num_roots,))\n        traces, types = random_walk(g, nodes=sampled_roots, length=walk_length)\n        sampled_nodes, _, _, _ = pack_traces(traces, types)\n        return sampled_nodes.unique().type(g.idtype)\n\n    def sample(self, g, indices):\n        \"\"\"Sampling function\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph to sample from.\n        indices : Tensor\n            Placeholder 
not used.\n\n        Returns\n        -------\n        DGLGraph\n            The sampled subgraph.\n        \"\"\"\n        node_ids = self.sampler(g)\n        sg = g.subgraph(\n            node_ids, relabel_nodes=True, output_device=self.output_device\n        )\n        set_node_lazy_features(sg, self.prefetch_ndata)\n        set_edge_lazy_features(sg, self.prefetch_edata)\n        return sg\n"
  },
  {
    "path": "python/dgl/dataloading/labor_sampler.py",
    "content": "#\n#   Copyright (c) 2022 by Contributors\n#\n#   Licensed under the Apache License, Version 2.0 (the \"License\");\n#   you may not use this file except in compliance with the License.\n#   You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#   Unless required by applicable law or agreed to in writing, software\n#   distributed under the License is distributed on an \"AS IS\" BASIS,\n#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#   See the License for the specific language governing permissions and\n#   limitations under the License.\n#\n#   Based off of neighbor_sampler.py\n#\n\n\"\"\"Data loading components for labor sampling\"\"\"\nfrom numpy.random import default_rng\n\nfrom .. import backend as F\nfrom ..base import EID, NID\nfrom ..random import choice\nfrom ..transforms import to_block\nfrom .base import BlockSampler\n\n\nclass LaborSampler(BlockSampler):\n    \"\"\"Sampler that builds computational dependency of node representations via\n    labor sampling for multilayer GNN from the NeurIPS 2023 paper\n    `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs\n    <https://arxiv.org/abs/2210.13339>`__\n\n    This sampler will make every node gather messages from a fixed number of\n    neighbors per edge type. The neighbors are picked uniformly with default\n    parameters. 
For every vertex t that will be considered to be sampled, there\n    will be a single random variate r_t.\n\n    Parameters\n    ----------\n    fanouts : list[int] or list[dict[etype, int]]\n        List of neighbors to sample per edge type for each GNN layer, with the\n        i-th element being the fanout for the i-th GNN layer.\n\n        If only a single integer is provided, DGL assumes that every edge type\n        will have the same fanout.\n\n        If -1 is provided for one edge type on one layer, then all inbound edges\n        of that edge type will be included.\n    edge_dir : str, default ``'in'``\n        Can be either ``'in'`` where the neighbors will be sampled according to\n        incoming edges, or ``'out'`` otherwise, same as\n        :func:`dgl.sampling.sample_neighbors`.\n    prob : str, optional\n        If given, the probability of each neighbor being sampled is proportional\n        to the edge feature value with the given name in ``g.edata``.\n        The feature must be a scalar on each edge. In this case, the returned\n        blocks edata include ``'edge_weights'`` that needs to be used in the\n        message passing operation.\n    importance_sampling : int, default ``0``\n        Whether to use importance sampling or uniform sampling, use of negative\n        values optimizes importance sampling probabilities until convergence\n        while use of positive values runs optimization steps that many times.\n        If the value is i, then LABOR-i variant is used. 
When used with a\n        nonzero parameter, the returned blocks edata include ``'edge_weights'``\n        that needs to be used in the message passing operation.\n    layer_dependency : bool, default ``False``\n        Specifies whether different layers should use same random variates.\n        Results into a reduction in the number of vertices sampled, but may\n        degrade the quality slightly.\n    batch_dependency : int, default ``1``\n        Specifies whether different minibatches should use similar random\n        variates. Results in a higher temporal access locality of sampled\n        vertices, but may degrade the quality slightly.\n    prefetch_node_feats : list[str] or dict[ntype, list[str]], optional\n        The source node data to prefetch for the first MFG, corresponding to the\n        input node features necessary for the first GNN layer.\n    prefetch_labels : list[str] or dict[ntype, list[str]], optional\n        The destination node data to prefetch for the last MFG, corresponding to\n        the node labels of the minibatch.\n    prefetch_edge_feats : list[str] or dict[etype, list[str]], optional\n        The edge data names to prefetch for all the MFGs, corresponding to the\n        edge features necessary for all GNN layers.\n    output_device : device, optional\n        The device of the output subgraphs or MFGs.  Default is the same as the\n        minibatch of seed nodes.\n\n    Examples\n    --------\n    **Node classification**\n\n    To train a 3-layer GNN for node classification on a set of nodes\n    ``train_nid`` on a homogeneous graph where each node takes messages from\n    5, 10, 15 neighbors for the first, second, and third layer respectively\n    (assuming the backend is PyTorch):\n\n    >>> sampler = dgl.dataloading.LaborSampler([5, 10, 15])\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, train_nid, sampler,\n    ...     
batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, output_nodes, blocks in dataloader:\n    ...     train_on(blocks)\n\n    If training on a heterogeneous graph and you want different number of\n    neighbors for each edge type, one should instead provide a list of dicts.\n    Each dict would specify the number of neighbors to pick per edge type.\n\n    >>> sampler = dgl.dataloading.LaborSampler([\n    ...     {('user', 'follows', 'user'): 5,\n    ...      ('user', 'plays', 'game'): 4,\n    ...      ('game', 'played-by', 'user'): 3}] * 3)\n\n    If you would like non-uniform labor sampling:\n\n    >>> # any non-negative 1D vector works\n    >>> g.edata['p'] = torch.rand(g.num_edges())\n    >>> sampler = dgl.dataloading.LaborSampler([5, 10, 15], prob='p')\n\n    **Edge classification and link prediction**\n\n    This class can also work for edge classification and link prediction\n    together with :func:`as_edge_prediction_sampler`.\n\n    >>> sampler = dgl.dataloading.LaborSampler([5, 10, 15])\n    >>> sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, train_eid, sampler,\n    ...     
batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n\n    See the documentation :func:`as_edge_prediction_sampler` for more details.\n\n    Notes\n    -----\n    For the concept of MFGs, please refer to\n    :ref:`User Guide Section 6 <guide-minibatch>` and\n    :doc:`Minibatch Training Tutorials\n    <tutorials/large/L0_neighbor_sampling_overview>`.\n    \"\"\"\n\n    def __init__(\n        self,\n        fanouts,\n        edge_dir=\"in\",\n        prob=None,\n        importance_sampling=0,\n        layer_dependency=False,\n        batch_dependency=1,\n        prefetch_node_feats=None,\n        prefetch_labels=None,\n        prefetch_edge_feats=None,\n        output_device=None,\n    ):\n        super().__init__(\n            prefetch_node_feats=prefetch_node_feats,\n            prefetch_labels=prefetch_labels,\n            prefetch_edge_feats=prefetch_edge_feats,\n            output_device=output_device,\n        )\n        self.fanouts = fanouts\n        self.edge_dir = edge_dir\n        self.prob = prob\n        self.importance_sampling = importance_sampling\n        self.layer_dependency = layer_dependency\n        self.cnt = F.zeros(2, F.int64, F.cpu())\n        self.cnt[0] = -1\n        self.cnt[1] = batch_dependency\n        self.random_seed = F.zeros(\n            2 if self.cnt[1] > 1 else 1, F.int64, F.cpu()\n        )\n        self.set_seed(None if batch_dependency > 0 else choice(1e18, 1).item())\n\n    def set_seed(self, random_seed=None):\n        \"\"\"Updates the underlying seed for the sampler\n\n        Calling this function enforces the sampling algorithm to use the same\n        seed on every edge type. 
This can reduce the number of nodes being\n        sampled because the passed random_seed makes it so that for any seed\n        vertex ``s`` and its neighbor ``t``, the rolled random variate ``r_t``\n        is the same for any instance of this class with the same random seed.\n        When sampling as part of the same batch, one would want identical seeds\n        so that LABOR can globally sample. One example is that for heterogeneous\n        graphs, there is a single random seed passed for each edge type. This\n        will sample far fewer vertices compared to having unique random seeds\n        for each edge type. If one called this function individually for each\n        edge type for a heterogeneous graph with different random seeds, then it\n        would run LABOR locally for each edge type, resulting in a larger\n        number of vertices being sampled.\n\n        If this function is called without any parameters, we get the random\n        seed by getting a random number from DGL. 
Call this function if multiple\n        instances of LaborSampler are used to sample as part of a single batch.\n\n        Parameters\n        ----------\n        random_seed : int, default ``None``\n            The random seed to be used for next sampling call.\n        \"\"\"\n        if random_seed is None:\n            self.cnt[0] += 1\n            if self.cnt[1] > 0 and self.cnt[0] % self.cnt[1] == 0:\n                if self.cnt[0] <= 0 or self.cnt[1] <= 1:\n                    if not hasattr(self, \"rng\"):\n                        self.rng = default_rng(choice(1e18, 1).item())\n                    self.random_seed[0] = self.rng.integers(1e18)\n                    if self.cnt[1] > 1:\n                        self.random_seed[1] = self.rng.integers(1e18)\n                else:\n                    self.random_seed[0] = self.random_seed[1]\n                    self.random_seed[1] = self.rng.integers(1e18)\n        else:\n            self.rng = default_rng(random_seed)\n            self.random_seed[0] = self.rng.integers(1e18)\n            if self.cnt[1] > 1:\n                self.random_seed[1] = self.rng.integers(1e18)\n            self.cnt[0] = 0\n\n    def sample_blocks(self, g, seed_nodes, exclude_eids=None):\n        output_nodes = seed_nodes\n        blocks = []\n        for i, fanout in enumerate(reversed(self.fanouts)):\n            random_seed_i = F.zerocopy_to_dgl_ndarray(\n                self.random_seed + (i if not self.layer_dependency else 0)\n            )\n            if self.cnt[1] <= 1:\n                seed2_contr = 0\n            else:\n                seed2_contr = ((self.cnt[0] % self.cnt[1]) / self.cnt[1]).item()\n            frontier, importances = g.sample_labors(\n                seed_nodes,\n                fanout,\n                edge_dir=self.edge_dir,\n                prob=self.prob,\n                importance_sampling=self.importance_sampling,\n                random_seed=random_seed_i,\n                
seed2_contribution=seed2_contr,\n                output_device=self.output_device,\n                exclude_edges=exclude_eids,\n            )\n            eid = frontier.edata[EID]\n            block = to_block(\n                frontier, seed_nodes, include_dst_in_src=True, src_nodes=None\n            )\n            block.edata[EID] = eid\n            if len(g.canonical_etypes) > 1:\n                for etype, importance in zip(g.canonical_etypes, importances):\n                    if importance.shape[0] == block.num_edges(etype):\n                        block.edata[\"edge_weights\"][etype] = importance\n            elif importances[0].shape[0] == block.num_edges():\n                block.edata[\"edge_weights\"] = importances[0]\n            seed_nodes = block.srcdata[NID]\n            blocks.insert(0, block)\n\n        self.set_seed()\n        return seed_nodes, output_nodes, blocks\n"
  },
  {
    "path": "python/dgl/dataloading/negative_sampler.py",
    "content": "\"\"\"Negative samplers\"\"\"\nfrom collections.abc import Mapping\n\nfrom .. import backend as F\n\n\nclass _BaseNegativeSampler(object):\n    def _generate(self, g, eids, canonical_etype):\n        raise NotImplementedError\n\n    def __call__(self, g, eids):\n        \"\"\"Returns negative samples.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        eids : Tensor or dict[etype, Tensor]\n            The sampled edges in the minibatch.\n\n        Returns\n        -------\n        tuple[Tensor, Tensor] or dict[etype, tuple[Tensor, Tensor]]\n            The returned source-destination pairs as negative samples.\n        \"\"\"\n        if isinstance(eids, Mapping):\n            eids = {g.to_canonical_etype(k): v for k, v in eids.items()}\n            neg_pair = {k: self._generate(g, v, k) for k, v in eids.items()}\n        else:\n            assert (\n                len(g.canonical_etypes) == 1\n            ), \"please specify a dict of etypes and ids for graphs with multiple edge types\"\n            neg_pair = self._generate(g, eids, g.canonical_etypes[0])\n\n        return neg_pair\n\n\nclass PerSourceUniform(_BaseNegativeSampler):\n    \"\"\"Negative sampler that randomly chooses negative destination nodes\n    for each source node according to a uniform distribution.\n\n    For each edge ``(u, v)`` of type ``(srctype, etype, dsttype)``, DGL generates\n    :attr:`k` pairs of negative edges ``(u, v')``, where ``v'`` is chosen\n    uniformly from all the nodes of type ``dsttype``.  
The resulting edges will\n    also have type ``(srctype, etype, dsttype)``.\n\n    Parameters\n    ----------\n    k : int\n        The number of negative samples per edge.\n\n    Examples\n    --------\n    >>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))\n    >>> neg_sampler = dgl.dataloading.negative_sampler.PerSourceUniform(2)\n    >>> neg_sampler(g, torch.tensor([0, 1]))\n    (tensor([0, 0, 1, 1]), tensor([1, 0, 2, 3]))\n    \"\"\"\n\n    def __init__(self, k):\n        self.k = k\n\n    def _generate(self, g, eids, canonical_etype):\n        _, _, vtype = canonical_etype\n        shape = F.shape(eids)\n        dtype = F.dtype(eids)\n        ctx = F.context(eids)\n        shape = (shape[0] * self.k,)\n        src, _ = g.find_edges(eids, etype=canonical_etype)\n        src = F.repeat(src, self.k, 0)\n        dst = F.randint(shape, dtype, ctx, 0, g.num_nodes(vtype))\n        return src, dst\n\n\n# Alias\nUniform = PerSourceUniform\n\n\nclass GlobalUniform(_BaseNegativeSampler):\n    \"\"\"Negative sampler that randomly chooses negative source-destination pairs according\n    to a uniform distribution.\n\n    For each edge ``(u, v)`` of type ``(srctype, etype, dsttype)``, DGL generates at most\n    :attr:`k` pairs of negative edges ``(u', v')``, where ``u'`` is chosen uniformly from\n    all the nodes of type ``srctype`` and ``v'`` is chosen uniformly from all the nodes\n    of type ``dsttype``.  The resulting edges will also have type\n    ``(srctype, etype, dsttype)``.  DGL guarantees that the sampled pairs will not have\n    edges in between.\n\n    Parameters\n    ----------\n    k : int\n        The desired number of negative samples to generate per edge.\n    exclude_self_loops : bool, optional\n        Whether to exclude self-loops from negative samples.  (Default: True)\n    replace : bool, optional\n        Whether to sample with replacement.  Setting it to True will make things\n        faster.  
(Default: False)\n\n    Notes\n    -----\n    This negative sampler will try to generate as many negative samples as possible, but\n    it may rarely return less than :attr:`k` negative samples per edge.\n    This is more likely to happen if a graph is so small or dense that not many unique\n    negative samples exist.\n\n    Examples\n    --------\n    >>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))\n    >>> neg_sampler = dgl.dataloading.negative_sampler.GlobalUniform(2, True)\n    >>> neg_sampler(g, torch.LongTensor([0, 1]))\n    (tensor([0, 1, 3, 2]), tensor([2, 0, 2, 1]))\n    \"\"\"\n\n    def __init__(self, k, exclude_self_loops=True, replace=False):\n        self.k = k\n        self.exclude_self_loops = exclude_self_loops\n        self.replace = replace\n\n    def _generate(self, g, eids, canonical_etype):\n        return g.global_uniform_negative_sampling(\n            len(eids) * self.k,\n            self.exclude_self_loops,\n            self.replace,\n            canonical_etype,\n        )\n"
  },
  {
    "path": "python/dgl/dataloading/neighbor_sampler.py",
    "content": "\"\"\"Data loading components for neighbor sampling\"\"\"\n\nfrom .. import backend as F\nfrom ..base import EID, NID\nfrom ..heterograph import DGLGraph\nfrom ..transforms import to_block\nfrom ..utils import get_num_threads\nfrom .base import BlockSampler\n\n\nclass NeighborSampler(BlockSampler):\n    \"\"\"Sampler that builds computational dependency of node representations via\n    neighbor sampling for multilayer GNN.\n\n    This sampler will make every node gather messages from a fixed number of neighbors\n    per edge type.  The neighbors are picked uniformly.\n\n    Parameters\n    ----------\n    fanouts : list[int] or list[dict[etype, int]]\n        List of neighbors to sample per edge type for each GNN layer, with the i-th\n        element being the fanout for the i-th GNN layer.\n\n        If only a single integer is provided, DGL assumes that every edge type\n        will have the same fanout.\n\n        If -1 is provided for one edge type on one layer, then all inbound edges\n        of that edge type will be included.\n    edge_dir : str, default ``'in'``\n        Can be either ``'in' `` where the neighbors will be sampled according to\n        incoming edges, or ``'out'`` otherwise, same as :func:`dgl.sampling.sample_neighbors`.\n    prob : str, optional\n        If given, the probability of each neighbor being sampled is proportional\n        to the edge feature value with the given name in ``g.edata``.  The feature must be\n        a scalar on each edge.\n\n        This argument is mutually exclusive with :attr:`mask`.  If you want to\n        specify both a mask and a probability, consider multiplying the probability\n        with the mask instead.\n    mask : str, optional\n        If given, a neighbor could be picked only if the edge mask with the given\n        name in ``g.edata`` is True.  The data must be boolean on each edge.\n\n        This argument is mutually exclusive with :attr:`prob`.  
If you want to\n        specify both a mask and a probability, consider multiplying the probability\n        with the mask instead.\n    replace : bool, default False\n        Whether to sample with replacement\n    prefetch_node_feats : list[str] or dict[ntype, list[str]], optional\n        The source node data to prefetch for the first MFG, corresponding to the\n        input node features necessary for the first GNN layer.\n    prefetch_labels : list[str] or dict[ntype, list[str]], optional\n        The destination node data to prefetch for the last MFG, corresponding to\n        the node labels of the minibatch.\n    prefetch_edge_feats : list[str] or dict[etype, list[str]], optional\n        The edge data names to prefetch for all the MFGs, corresponding to the\n        edge features necessary for all GNN layers.\n    output_device : device, optional\n        The device of the output subgraphs or MFGs.  Default is the same as the\n        minibatch of seed nodes.\n    fused : bool, default True\n        If True and device is CPU fused sample neighbors is invoked. This version\n        requires seed_nodes to be unique\n\n    Examples\n    --------\n    **Node classification**\n\n    To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on\n    a homogeneous graph where each node takes messages from 5, 10, 15 neighbors for\n    the first, second, and third layer respectively (assuming the backend is PyTorch):\n\n    >>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, train_nid, sampler,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, output_nodes, blocks in dataloader:\n    ...     train_on(blocks)\n\n    If training on a heterogeneous graph and you want different number of neighbors for each\n    edge type, one should instead provide a list of dicts.  
Each dict would specify the\n    number of neighbors to pick per edge type.\n\n    >>> sampler = dgl.dataloading.NeighborSampler([\n    ...     {('user', 'follows', 'user'): 5,\n    ...      ('user', 'plays', 'game'): 4,\n    ...      ('game', 'played-by', 'user'): 3}] * 3)\n\n    If you would like non-uniform neighbor sampling:\n\n    >>> g.edata['p'] = torch.rand(g.num_edges())   # any non-negative 1D vector works\n    >>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='p')\n\n    Or sampling on edge masks:\n\n    >>> g.edata['mask'] = torch.rand(g.num_edges()) < 0.2   # any 1D boolean mask works\n    >>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], mask='mask')\n\n    **Edge classification and link prediction**\n\n    This class can also work for edge classification and link prediction together\n    with :func:`as_edge_prediction_sampler`.\n\n    >>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])\n    >>> sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, train_eid, sampler,\n    ...     
batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n\n    See the documentation :func:`as_edge_prediction_sampler` for more details.\n\n    Notes\n    -----\n    For the concept of MFGs, please refer to\n    :ref:`User Guide Section 6 <guide-minibatch>` and\n    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.\n    \"\"\"\n\n    def __init__(\n        self,\n        fanouts,\n        edge_dir=\"in\",\n        prob=None,\n        mask=None,\n        replace=False,\n        prefetch_node_feats=None,\n        prefetch_labels=None,\n        prefetch_edge_feats=None,\n        output_device=None,\n        fused=True,\n    ):\n        super().__init__(\n            prefetch_node_feats=prefetch_node_feats,\n            prefetch_labels=prefetch_labels,\n            prefetch_edge_feats=prefetch_edge_feats,\n            output_device=output_device,\n        )\n        self.fanouts = fanouts\n        self.edge_dir = edge_dir\n        if mask is not None and prob is not None:\n            raise ValueError(\n                \"Mask and probability arguments are mutually exclusive. 
\"\n                \"Consider multiplying the probability with the mask \"\n                \"to achieve the same goal.\"\n            )\n        self.prob = prob or mask\n        self.replace = replace\n        self.fused = fused\n        self.mapping = {}\n        self.g = None\n\n    def sample_blocks(self, g, seed_nodes, exclude_eids=None):\n        output_nodes = seed_nodes\n        blocks = []\n        # sample_neighbors_fused function requires multithreading to be more efficient\n        # than sample_neighbors\n        if self.fused and get_num_threads() > 1:\n            cpu = F.device_type(g.device) == \"cpu\"\n            if isinstance(seed_nodes, dict):\n                for ntype in list(seed_nodes.keys()):\n                    if not cpu:\n                        break\n                    cpu = (\n                        cpu and F.device_type(seed_nodes[ntype].device) == \"cpu\"\n                    )\n            else:\n                cpu = cpu and F.device_type(seed_nodes.device) == \"cpu\"\n            if cpu and isinstance(g, DGLGraph) and F.backend_name == \"pytorch\":\n                if self.g != g:\n                    self.mapping = {}\n                    self.g = g\n                for fanout in reversed(self.fanouts):\n                    block = g.sample_neighbors_fused(\n                        seed_nodes,\n                        fanout,\n                        edge_dir=self.edge_dir,\n                        prob=self.prob,\n                        replace=self.replace,\n                        exclude_edges=exclude_eids,\n                        mapping=self.mapping,\n                    )\n                    seed_nodes = block.srcdata[NID]\n                    blocks.insert(0, block)\n                return seed_nodes, output_nodes, blocks\n\n        for fanout in reversed(self.fanouts):\n            frontier = g.sample_neighbors(\n                seed_nodes,\n                fanout,\n                edge_dir=self.edge_dir,\n     
           prob=self.prob,\n                replace=self.replace,\n                output_device=self.output_device,\n                exclude_edges=exclude_eids,\n            )\n            block = to_block(frontier, seed_nodes)\n            # If sampled from graphbolt-backed DistGraph, `EID` may not be in\n            # the block. If not exists, we should remove it from the block.\n            if EID in frontier.edata.keys():\n                block.edata[EID] = frontier.edata[EID]\n            else:\n                del block.edata[EID]\n            seed_nodes = block.srcdata[NID]\n            blocks.insert(0, block)\n\n        return seed_nodes, output_nodes, blocks\n\n\nMultiLayerNeighborSampler = NeighborSampler\n\n\nclass MultiLayerFullNeighborSampler(NeighborSampler):\n    \"\"\"Sampler that builds computational dependency of node representations by taking messages\n    from all neighbors for multilayer GNN.\n\n    This sampler will make every node gather messages from every single neighbor per edge type.\n\n    Parameters\n    ----------\n    num_layers : int\n        The number of GNN layers to sample.\n    kwargs :\n        Passed to :class:`dgl.dataloading.NeighborSampler`.\n\n    Examples\n    --------\n    To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on\n    a homogeneous graph where each node takes messages from all neighbors for the first,\n    second, and third layer respectively (assuming the backend is PyTorch):\n\n    >>> sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, train_nid, sampler,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, output_nodes, blocks in dataloader:\n    ...     
train_on(blocks)\n\n    Notes\n    -----\n    For the concept of MFGs, please refer to\n    :ref:`User Guide Section 6 <guide-minibatch>` and\n    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.\n    \"\"\"\n\n    def __init__(self, num_layers, **kwargs):\n        super().__init__([-1] * num_layers, **kwargs)\n"
  },
  {
    "path": "python/dgl/dataloading/shadow.py",
    "content": "\"\"\"ShaDow-GNN subgraph samplers.\"\"\"\nfrom .. import transforms\nfrom ..base import NID\nfrom ..sampling.utils import EidExcluder\nfrom .base import Sampler, set_edge_lazy_features, set_node_lazy_features\n\n\nclass ShaDowKHopSampler(Sampler):\n    \"\"\"K-hop subgraph sampler from `Deep Graph Neural Networks with Shallow\n    Subgraph Samplers <https://arxiv.org/abs/2012.01380>`__.\n\n    It performs node-wise neighbor sampling and returns the subgraph induced by\n    all the sampled nodes. The seed nodes from which the neighbors are sampled\n    will appear the first in the induced nodes of the subgraph.\n\n    Parameters\n    ----------\n    fanouts : list[int] or list[dict[etype, int]]\n        List of neighbors to sample per edge type for each GNN layer, with the i-th\n        element being the fanout for the i-th GNN layer.\n\n        If only a single integer is provided, DGL assumes that every edge type\n        will have the same fanout.\n\n        If -1 is provided for one edge type on one layer, then all inbound edges\n        of that edge type will be included.\n    replace : bool, default True\n        Whether to sample with replacement\n    prob : str, optional\n        If given, the probability of each neighbor being sampled is proportional\n        to the edge feature value with the given name in ``g.edata``. The feature must be\n        a scalar on each edge.\n\n    Examples\n    --------\n    **Node classification**\n\n    To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on\n    a homogeneous graph where each node takes messages from 5, 10, 15 neighbors for\n    the first, second, and third layer respectively (assuming the backend is PyTorch):\n\n    >>> g = dgl.data.CoraFullDataset()[0]\n    >>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])\n    >>> dataloader = dgl.dataloading.DataLoader(\n    ...     g, torch.arange(g.num_nodes()), sampler,\n    ...     
batch_size=5, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, output_nodes, subgraph in dataloader:\n    ...     print(subgraph)\n    ...     assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])\n    ...     assert torch.equal(input_nodes[:output_nodes.shape[0]], output_nodes)\n    ...     break\n    Graph(num_nodes=529, num_edges=3796,\n          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64),\n                         'feat': Scheme(shape=(8710,), dtype=torch.float32),\n                         '_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n\n    If training on a heterogeneous graph and you want different number of neighbors for each\n    edge type, one should instead provide a list of dicts. Each dict would specify the\n    number of neighbors to pick per edge type.\n\n    >>> sampler = dgl.dataloading.ShaDowKHopSampler([\n    ...     {('user', 'follows', 'user'): 5,\n    ...      ('user', 'plays', 'game'): 4,\n    ...      
('game', 'played-by', 'user'): 3}] * 3)\n\n    If you would like non-uniform neighbor sampling:\n\n    >>> g.edata['p'] = torch.rand(g.num_edges())   # any non-negative 1D vector works\n    >>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')\n    \"\"\"\n\n    def __init__(\n        self,\n        fanouts,\n        replace=False,\n        prob=None,\n        prefetch_node_feats=None,\n        prefetch_edge_feats=None,\n        output_device=None,\n    ):\n        super().__init__()\n        self.fanouts = fanouts\n        self.replace = replace\n        self.prob = prob\n        self.prefetch_node_feats = prefetch_node_feats\n        self.prefetch_edge_feats = prefetch_edge_feats\n        self.output_device = output_device\n\n    def sample(\n        self, g, seed_nodes, exclude_eids=None\n    ):  # pylint: disable=arguments-differ\n        \"\"\"Sampling function.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph to sample nodes from.\n        seed_nodes : Tensor or dict[str, Tensor]\n            The nodes sampled in the current minibatch.\n        exclude_eids : Tensor or dict[etype, Tensor], optional\n            The edges to exclude from neighborhood expansion.\n\n        Returns\n        -------\n        input_nodes, output_nodes, subg\n            A triplet containing (1) the node IDs inducing the subgraph, (2) the node\n            IDs that are sampled in this minibatch, and (3) the subgraph itself.\n        \"\"\"\n        output_nodes = seed_nodes\n        for fanout in reversed(self.fanouts):\n            frontier = g.sample_neighbors(\n                seed_nodes,\n                fanout,\n                output_device=self.output_device,\n                replace=self.replace,\n                prob=self.prob,\n                exclude_edges=exclude_eids,\n            )\n            block = transforms.to_block(frontier, seed_nodes)\n            seed_nodes = block.srcdata[NID]\n\n        subg = 
g.subgraph(\n            seed_nodes, relabel_nodes=True, output_device=self.output_device\n        )\n        if exclude_eids is not None:\n            subg = EidExcluder(exclude_eids)(subg)\n\n        set_node_lazy_features(subg, self.prefetch_node_feats)\n        set_edge_lazy_features(subg, self.prefetch_edge_feats)\n\n        return seed_nodes, output_nodes, subg\n"
  },
  {
    "path": "python/dgl/dataloading/spot_target.py",
    "content": "\"\"\"SpotTarget: Target edge excluder for link prediction\"\"\"\nimport torch\n\nfrom .base import find_exclude_eids\n\n\nclass SpotTarget(object):\n    \"\"\"Callable excluder object to exclude the edges by the degree threshold.\n\n    Besides excluding all the edges or given edges in the edge sampler\n    ``dgl.dataloading.as_edge_prediction_sampler`` in link prediction training,\n    this excluder can extend the exclusion function by only excluding the edges incident\n    to low-degree nodes in the graph to bring the performance increase in training\n    link prediction model. This function will exclude the edge if incident to at least\n    one node with degree larger or equal to ``degree_threshold``. The performance\n    boost by excluding the target edges incident to low-degree nodes can be found\n    in this paper: https://arxiv.org/abs/2306.00899\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    exclude : Union[str, callable]\n        Whether and how to exclude dependencies related to the sampled edges in the\n        minibatch.  Possible values are\n\n        * ``self``, for excluding the edges in the current minibatch.\n\n        * ``reverse_id``, for excluding not only the edges in the current minibatch but\n          also their reverse edges according to the ID mapping in the argument\n          :attr:`reverse_eids`.\n\n        * ``reverse_types``, for excluding not only the edges in the current minibatch\n          but also their reverse edges stored in another type according to\n          the argument :attr:`reverse_etypes`.\n\n        * User-defined exclusion rule. 
It is a callable with edges in the current\n          minibatch as a single argument and should return the edges to be excluded.\n    degree_threshold : int\n        The threshold of node in-degrees. If the source or target node of an edge\n        has an in-degree smaller than ``degree_threshold``, this edge will be excluded from\n        the graph.\n    reverse_eids : Tensor or dict[etype, Tensor], optional\n        A tensor of reverse edge ID mapping.  The i-th element indicates the ID of\n        the i-th edge's reverse edge.\n\n        If the graph is heterogeneous, this argument requires a dictionary of edge\n        types and the reverse edge ID mapping tensors.\n    reverse_etypes : dict[etype, etype], optional\n        The mapping from the original edge types to their reverse edge types.\n\n    Examples\n    --------\n    .. code:: python\n\n       low_degree_excluder = SpotTarget(\n           g, exclude='reverse_id', degree_threshold=10, reverse_eids=reverse_eids)\n       sampler = as_edge_prediction_sampler(sampler, exclude=low_degree_excluder,\n           reverse_eids=reverse_eids, negative_sampler=negative_sampler.Uniform(1))\n    \"\"\"\n\n    def __init__(\n        self,\n        g,\n        exclude,\n        degree_threshold=10,\n        reverse_eids=None,\n        reverse_etypes=None,\n    ):\n        self.g = g\n        self.exclude = exclude\n        self.degree_threshold = degree_threshold\n        self.reverse_eids = reverse_eids\n        self.reverse_etypes = reverse_etypes\n\n    def __call__(self, seed_edges):\n        g = self.g\n        src, dst = g.find_edges(seed_edges)\n        head_degree = g.in_degrees(src)\n        tail_degree = g.in_degrees(dst)\n\n        degree = torch.min(head_degree, tail_degree)\n        degree_mask = degree < self.degree_threshold\n        edges_need_to_exclude = seed_edges[degree_mask]\n        return find_exclude_eids(\n            g,\n            edges_need_to_exclude,\n            self.exclude,\n            self.reverse_eids,\n            self.reverse_etypes,\n  
      )\n"
  },
  {
    "path": "python/dgl/distgnn/__init__.py",
    "content": "\"\"\"\nThis package contains DistGNN and Libra based graph partitioning tools.\n\"\"\"\nfrom . import partition, tools\n"
  },
  {
    "path": "python/dgl/distgnn/partition/__init__.py",
    "content": "\"\"\"\nThis package contains Libra graph partitioner.\n\"\"\"\nfrom .libra_partition import partition_graph\n"
  },
  {
    "path": "python/dgl/distgnn/partition/libra_partition.py",
    "content": "r\"\"\"Libra partition functions.\n\nLibra partition is a vertex-cut based partitioning algorithm from\n`Distributed Power-law Graph Computing:\nTheoretical and Empirical Analysis\n<https://proceedings.neurips.cc/paper/2014/file/67d16d00201083a2b118dd5128dd6f59-Paper.pdf>`__\nfrom Xie et al.\n\"\"\"\n\n# Copyright (c) 2021 Intel Corporation\n#  \\file distgnn/partition/libra_partition.py\n#  \\brief Libra - Vertex-cut based graph partitioner for distributed training\n#  \\author Vasimuddin Md <vasimuddin.md@intel.com>,\n#          Guixiang Ma <guixiang.ma@intel.com>\n#          Sanchit Misra <sanchit.misra@intel.com>,\n#          Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>,\n#          Sasikanth Avancha <sasikanth.avancha@intel.com>\n#          Nesreen K. Ahmed <nesreen.k.ahmed@intel.com>\n#  \\cite Distributed Power-law Graph Computing: Theoretical and Empirical Analysis\n\nimport json\nimport os\nimport time\n\nimport torch as th\n\nfrom dgl import DGLGraph\nfrom dgl._sparse_ops import (\n    libra2dgl_build_adjlist,\n    libra2dgl_build_dict,\n    libra2dgl_set_lr,\n    libra_vertex_cut,\n)\nfrom dgl.base import DGLError\nfrom dgl.data.utils import save_graphs, save_tensors\n\n\ndef libra_partition(num_community, G, resultdir):\n    \"\"\"\n    Performs vertex-cut based graph partitioning and converts the partitioning\n    output to DGL input format.\n\n    Parameters\n    ----------\n    num_community : Number of partitions to create\n    G : Input graph to be partitioned\n    resultdir : Output location for storing the partitioned graphs\n\n    Output\n    ------\n    1. Creates X partition folder as XCommunities (say, X=2, so, 2Communities)\n       XCommunities contains file name communityZ.txt per partition Z (Z <- 0 .. X-1);\n       each such file contains a list of edges assigned to that partition.\n       These files constitute the output of Libra graph partitioner\n       (An intermediate result of this function).\n    2. 
The folder also contains partZ folders, each of these folders stores\n       DGL/DistGNN graphs for the Z partitions;\n       these graph files are used as input to DistGNN.\n    3. The folder also contains a json file which contains partitions' information.\n    \"\"\"\n\n    num_nodes = G.num_nodes()  # number of nodes\n    num_edges = G.num_edges()  # number of edges\n    print(\"Number of nodes in the graph: \", num_nodes)\n    print(\"Number of edges in the graph: \", num_edges)\n\n    in_d = G.in_degrees()\n    out_d = G.out_degrees()\n    node_degree = in_d + out_d\n    edgenum_unassigned = node_degree.clone()\n\n    u_t, v_t = G.edges()\n    weight_ = th.ones(u_t.shape[0], dtype=th.int64)\n    community_weights = th.zeros(num_community, dtype=th.int64)\n\n    # self_loop = 0\n    # for p, q in zip(u_t, v_t):\n    #     if p == q:\n    #         self_loop += 1\n    # print(\"#self loops in the dataset: \", self_loop)\n\n    # del G\n\n    ## call to C/C++ code\n    out = th.zeros(u_t.shape[0], dtype=th.int32)\n    libra_vertex_cut(\n        num_community,\n        node_degree,\n        edgenum_unassigned,\n        community_weights,\n        u_t,\n        v_t,\n        weight_,\n        out,\n        num_nodes,\n        num_edges,\n        resultdir,\n    )\n\n    print(\"Max partition size: \", int(community_weights.max()))\n    print(\" ** Converting libra partitions to dgl graphs **\")\n    fsize = int(community_weights.max()) + 1024  ## max edges in partition\n    # print(\"fsize: \", fsize, flush=True)\n\n    node_map = th.zeros(num_community, dtype=th.int64)\n    indices = th.zeros(num_nodes, dtype=th.int64)\n    lrtensor = th.zeros(num_nodes, dtype=th.int64)\n    gdt_key = th.zeros(num_nodes, dtype=th.int64)\n    gdt_value = th.zeros([num_nodes, num_community], dtype=th.int64)\n    offset = th.zeros(1, dtype=th.int64)\n    ldt_ar = []\n\n    gg_ar = [DGLGraph() for i in range(num_community)]\n    part_nodes = []\n\n    print(\">>> \", \"num_nodes   
\", \" \", \"num_edges\")\n    ## Iterator over number of partitions\n    for i in range(num_community):\n        g = gg_ar[i]\n\n        a_t = th.zeros(fsize, dtype=th.int64)\n        b_t = th.zeros(fsize, dtype=th.int64)\n        ldt_key = th.zeros(fsize, dtype=th.int64)\n        ldt_ar.append(ldt_key)\n\n        ## building node, parition dictionary\n        ## Assign local node ids and mapping to global node ids\n        ret = libra2dgl_build_dict(\n            a_t,\n            b_t,\n            indices,\n            ldt_key,\n            gdt_key,\n            gdt_value,\n            node_map,\n            offset,\n            num_community,\n            i,\n            fsize,\n            resultdir,\n        )\n\n        num_nodes_partition = int(ret[0])\n        num_edges_partition = int(ret[1])\n        part_nodes.append(num_nodes_partition)\n        print(\">>> \", num_nodes_partition, \" \", num_edges_partition)\n        g.add_edges(a_t[0:num_edges_partition], b_t[0:num_edges_partition])\n\n    ########################################################\n    ## fixing lr - 1-level tree for the split-nodes\n    libra2dgl_set_lr(gdt_key, gdt_value, lrtensor, num_community, num_nodes)\n    ########################################################\n    # graph_name = dataset\n    graph_name = resultdir.split(\"_\")[-1].split(\"/\")[0]\n    part_method = \"Libra\"\n    num_parts = num_community  ## number of paritions/communities\n    num_hops = 0\n    node_map_val = node_map.tolist()\n    edge_map_val = 0\n    out_path = resultdir\n\n    part_metadata = {\n        \"graph_name\": graph_name,\n        \"num_nodes\": G.num_nodes(),\n        \"num_edges\": G.num_edges(),\n        \"part_method\": part_method,\n        \"num_parts\": num_parts,\n        \"halo_hops\": num_hops,\n        \"node_map\": node_map_val,\n        \"edge_map\": edge_map_val,\n    }\n    ############################################################\n\n    for i in range(num_community):\n       
 g = gg_ar[0]\n        num_nodes_partition = part_nodes[i]\n        adj = th.zeros([num_nodes_partition, num_community - 1], dtype=th.int64)\n        inner_node = th.zeros(num_nodes_partition, dtype=th.int32)\n        lr_t = th.zeros(num_nodes_partition, dtype=th.int64)\n        ldt = ldt_ar[0]\n\n        try:\n            feat = G.ndata[\"feat\"]\n        except KeyError:\n            feat = G.ndata[\"features\"]\n\n        try:\n            labels = G.ndata[\"label\"]\n        except KeyError:\n            labels = G.ndata[\"labels\"]\n\n        trainm = G.ndata[\"train_mask\"].int()\n        testm = G.ndata[\"test_mask\"].int()\n        valm = G.ndata[\"val_mask\"].int()\n\n        feat_size = feat.shape[1]\n        gfeat = th.zeros([num_nodes_partition, feat_size], dtype=feat.dtype)\n\n        glabels = th.zeros(num_nodes_partition, dtype=labels.dtype)\n        gtrainm = th.zeros(num_nodes_partition, dtype=trainm.dtype)\n        gtestm = th.zeros(num_nodes_partition, dtype=testm.dtype)\n        gvalm = th.zeros(num_nodes_partition, dtype=valm.dtype)\n\n        ## build remote node databse per local node\n        ## gather feats, train, test, val, and labels for each partition\n        libra2dgl_build_adjlist(\n            feat,\n            gfeat,\n            adj,\n            inner_node,\n            ldt,\n            gdt_key,\n            gdt_value,\n            node_map,\n            lr_t,\n            lrtensor,\n            num_nodes_partition,\n            num_community,\n            i,\n            feat_size,\n            labels,\n            trainm,\n            testm,\n            valm,\n            glabels,\n            gtrainm,\n            gtestm,\n            gvalm,\n            feat.shape[0],\n        )\n\n        g.ndata[\"adj\"] = adj  ## database of remote clones\n        g.ndata[\"inner_node\"] = inner_node  ## split node '0' else '1'\n        g.ndata[\"feat\"] = gfeat  ## gathered features\n        g.ndata[\"lf\"] = lr_t  ## 1-level tree 
among split nodes\n\n        g.ndata[\"label\"] = glabels\n        g.ndata[\"train_mask\"] = gtrainm\n        g.ndata[\"test_mask\"] = gtestm\n        g.ndata[\"val_mask\"] = gvalm\n\n        # Validation code, run only small graphs\n        # for l in range(num_nodes_partition):\n        #     index = int(ldt[l])\n        #     assert glabels[l] == labels[index]\n        #     assert gtrainm[l] == trainm[index]\n        #     assert gtestm[l] == testm[index]\n        #     for j in range(feat_size):\n        #         assert gfeat[l][j] == feat[index][j]\n\n        print(\"Writing partition {} to file\".format(i), flush=True)\n\n        part = g\n        part_id = i\n        part_dir = os.path.join(out_path, \"part\" + str(part_id))\n        node_feat_file = os.path.join(part_dir, \"node_feat.dgl\")\n        edge_feat_file = os.path.join(part_dir, \"edge_feat.dgl\")\n        part_graph_file = os.path.join(part_dir, \"graph.dgl\")\n        part_metadata[\"part-{}\".format(part_id)] = {\n            \"node_feats\": node_feat_file,\n            \"edge_feats\": edge_feat_file,\n            \"part_graph\": part_graph_file,\n        }\n        os.makedirs(part_dir, mode=0o775, exist_ok=True)\n        save_tensors(node_feat_file, part.ndata)\n        save_graphs(part_graph_file, [part])\n\n        del g\n        del gg_ar[0]\n        del ldt\n        del ldt_ar[0]\n\n    with open(\"{}/{}.json\".format(out_path, graph_name), \"w\") as outfile:\n        json.dump(part_metadata, outfile, sort_keys=True, indent=4)\n\n    print(\"Conversion libra2dgl completed !!!\")\n\n\ndef partition_graph(num_community, G, resultdir):\n    \"\"\"\n    Performs vertex-cut based graph partitioning and converts the partitioning\n    output to DGL input format.\n\n    Given a graph, this function will create a folder named ``XCommunities`` where ``X``\n    stands for the number of communities.  
It will contain ``X`` files named\n    ``communityZ.txt`` for each partition Z (from 0 to X-1);\n    each such file contains a list of edges assigned to that partition.\n    These files constitute the output of Libra graph partitioner.\n\n    The folder also contains X subfolders named ``partZ``, each of these folders stores\n    DGL/DistGNN graphs for partition Z; these graph files are used as input to\n    DistGNN.\n\n    The folder also contains a json file which contains partitions' information.\n\n    Currently we require the graph's node data to contain the following columns:\n\n    * ``features`` for node features.\n    * ``label`` for node labels.\n    * ``train_mask`` as a boolean mask of training node set.\n    * ``val_mask`` as a boolean mask of validation node set.\n    * ``test_mask`` as a boolean mask of test node set.\n\n    Parameters\n    ----------\n    num_community : int\n        Number of partitions to create.\n    G : DGLGraph\n        Input graph to be partitioned.\n    resultdir : str\n        Output location for storing the partitioned graphs.\n    \"\"\"\n\n    print(\"num partitions: \", num_community)\n    print(\"output location: \", resultdir)\n\n    ## create ouptut directory\n    try:\n        os.makedirs(resultdir, mode=0o775, exist_ok=True)\n    except:\n        raise DGLError(\"Error: Could not create directory: \", resultdir)\n\n    tic = time.time()\n    print(\n        \"####################################################################\"\n    )\n    print(\"Executing parititons: \", num_community)\n    ltic = time.time()\n    try:\n        resultdir = os.path.join(resultdir, str(num_community) + \"Communities\")\n        os.makedirs(resultdir, mode=0o775, exist_ok=True)\n    except:\n        raise DGLError(\"Error: Could not create sub-directory: \", resultdir)\n\n    ## Libra partitioning\n    libra_partition(num_community, G, resultdir)\n\n    ltoc = time.time()\n    print(\n        \"Time taken by {} partitions {:0.4f} 
sec\".format(\n            num_community, ltoc - ltic\n        )\n    )\n    print()\n\n    toc = time.time()\n    print(\n        \"Generated \",\n        num_community,\n        \" partitions in {:0.4f} sec\".format(toc - tic),\n        flush=True,\n    )\n    print(\"Partitioning completed successfully !!!\")\n"
  },
  {
    "path": "python/dgl/distgnn/tools/__init__.py",
    "content": "\"\"\"\nThis package contains extra routines related to Libra graph partitioner.\n\"\"\"\nfrom .tools import load_proteins\n"
  },
  {
    "path": "python/dgl/distgnn/tools/tools.py",
    "content": "r\"\"\"\nCopyright (c) 2021 Intel Corporation\n \\file distgnn/tools/tools.py\n \\brief Tools for use in Libra graph partitioner.\n \\author Vasimuddin Md <vasimuddin.md@intel.com>\n\"\"\"\n\nimport os\nimport random\n\nimport requests\nimport torch as th\nfrom scipy.io import mmread\n\nimport dgl\nfrom dgl.base import DGLError\nfrom dgl.data.utils import load_graphs, save_graphs, save_tensors\n\n\ndef rep_per_node(prefix, num_community):\n    \"\"\"\n    Used on Libra partitioned data.\n    This function reports number of split-copes per node (replication) of\n    a partitioned graph\n    Parameters\n    ----------\n    prefix: Partition folder location (contains replicationlist.csv)\n    num_community: number of partitions or communities\n    \"\"\"\n    ifile = os.path.join(prefix, \"replicationlist.csv\")\n    fhandle = open(ifile, \"r\")\n    r_dt = {}\n\n    fline = fhandle.readline()  ## reading first line, contains the comment.\n    print(fline)\n    for line in fhandle:\n        if line[0] == \"#\":\n            raise DGLError(\"[Bug] Read Hash char in rep_per_node func.\")\n\n        node = line.strip(\"\\n\")\n        if r_dt.get(node, -100) == -100:\n            r_dt[node] = 1\n        else:\n            r_dt[node] += 1\n\n    fhandle.close()\n    ## sanity checks\n    for v in r_dt.values():\n        if v >= num_community:\n            raise DGLError(\n                \"[Bug] Unexpected event in rep_per_node() in tools.py.\"\n            )\n\n    return r_dt\n\n\ndef download_proteins():\n    \"\"\"\n    Downloads the proteins dataset\n    \"\"\"\n    print(\"Downloading dataset...\")\n    print(\"This might a take while..\")\n    url = \"https://portal.nersc.gov/project/m1982/GNN/\"\n    file_name = \"subgraph3_iso_vs_iso_30_70length_ALL.m100.propermm.mtx\"\n    url = url + file_name\n    try:\n        req = requests.get(url)\n    except:\n        raise DGLError(\n            \"Error: Failed to download Proteins dataset!! 
Aborting..\"\n        )\n\n    with open(\"proteins.mtx\", \"wb\") as handle:\n        handle.write(req.content)\n\n\ndef proteins_mtx2dgl():\n    \"\"\"\n    This function converts Proteins dataset from mtx to dgl format.\n    \"\"\"\n    print(\"Converting mtx2dgl..\")\n    print(\"This might a take while..\")\n    a_mtx = mmread(\"proteins.mtx\")\n    coo = a_mtx.tocoo()\n    u = th.tensor(coo.row, dtype=th.int64)\n    v = th.tensor(coo.col, dtype=th.int64)\n    g = dgl.DGLGraph()\n\n    g.add_edges(u, v)\n\n    n = g.num_nodes()\n    feat_size = 128  ## arbitrary number\n    feats = th.empty([n, feat_size], dtype=th.float32)\n\n    ## arbitrary numbers\n    train_size = 1000000\n    test_size = 500000\n    val_size = 5000\n    nlabels = 256\n\n    train_mask = th.zeros(n, dtype=th.bool)\n    test_mask = th.zeros(n, dtype=th.bool)\n    val_mask = th.zeros(n, dtype=th.bool)\n    label = th.zeros(n, dtype=th.int64)\n\n    for i in range(train_size):\n        train_mask[i] = True\n\n    for i in range(test_size):\n        test_mask[train_size + i] = True\n\n    for i in range(val_size):\n        val_mask[train_size + test_size + i] = True\n\n    for i in range(n):\n        label[i] = random.choice(range(nlabels))\n\n    g.ndata[\"feat\"] = feats\n    g.ndata[\"train_mask\"] = train_mask\n    g.ndata[\"test_mask\"] = test_mask\n    g.ndata[\"val_mask\"] = val_mask\n    g.ndata[\"label\"] = label\n\n    return g\n\n\ndef save(g, dataset):\n    \"\"\"\n    This function saves input dataset to dgl format\n    Parameters\n    ----------\n    g : graph to be saved\n    dataset : output folder name\n    \"\"\"\n    print(\"Saving dataset..\")\n    part_dir = os.path.join(\"./\" + dataset)\n    node_feat_file = os.path.join(part_dir, \"node_feat.dgl\")\n    part_graph_file = os.path.join(part_dir, \"graph.dgl\")\n    os.makedirs(part_dir, mode=0o775, exist_ok=True)\n    save_tensors(node_feat_file, g.ndata)\n    save_graphs(part_graph_file, [g])\n    print(\"Graph saved 
successfully !!\")\n\n\ndef load_proteins(dataset):\n    \"\"\"\n    This function downloads, converts, and loads the Proteins graph dataset\n    Parameter\n    ---------\n    dataset: output folder name\n    \"\"\"\n    part_dir = dataset\n    graph_file = os.path.join(part_dir + \"/graph.dgl\")\n\n    if not os.path.exists(\"proteins.mtx\"):\n        download_proteins()\n    if not os.path.exists(graph_file):\n        g = proteins_mtx2dgl()\n        save(g, dataset)\n    ## load\n    graph = load_graphs(graph_file)[0][0]\n    return graph\n"
  },
  {
    "path": "python/dgl/distributed/__init__.py",
    "content": "\"\"\"DGL distributed module\"\"\"\n\nfrom . import optim\nfrom .dist_context import exit_client, initialize\nfrom .dist_dataloader import (\n    DistDataLoader,\n    DistEdgeDataLoader,\n    DistNodeDataLoader,\n    EdgeCollator,\n    NodeCollator,\n)\nfrom .dist_graph import DistGraph, DistGraphServer, edge_split, node_split\nfrom .dist_tensor import DistTensor\nfrom .graph_partition_book import GraphPartitionBook, PartitionPolicy\nfrom .graph_services import *\nfrom .kvstore import KVClient, KVServer\nfrom .nn import *\nfrom .partition import (\n    dgl_partition_to_graphbolt,\n    gb_convert_single_dgl_partition,\n    load_partition,\n    load_partition_book,\n    load_partition_feats,\n    partition_graph,\n)\nfrom .rpc import *\nfrom .rpc_client import connect_to_server\nfrom .rpc_server import start_server\nfrom .server_state import ServerState\nfrom .constants import *\n"
  },
  {
    "path": "python/dgl/distributed/constants.py",
    "content": "\"\"\"Define all the constants used by DGL rpc\"\"\"\n\n# Maximum size of message queue in bytes\nMAX_QUEUE_SIZE = 20 * 1024 * 1024 * 1024\n\nSERVER_EXIT = \"server_exit\"\n\nDEFAULT_NTYPE = \"_N\"\nDEFAULT_ETYPE = (DEFAULT_NTYPE, \"_E\", DEFAULT_NTYPE)\n\nDGL2GB_EID = \"_dgl2gb_eid\"\nGB_DST_ID = \"_gb_dst_id\"\n"
  },
  {
    "path": "python/dgl/distributed/dist_context.py",
    "content": "\"\"\"Initialize the distributed services\"\"\"\n# pylint: disable=line-too-long\n\nimport atexit\nimport gc\nimport multiprocessing as mp\nimport os\nimport queue\nimport sys\nimport time\nimport traceback\nfrom enum import Enum\n\nfrom .. import utils\nfrom ..base import dgl_warning, DGLError\nfrom . import rpc\nfrom .constants import MAX_QUEUE_SIZE\nfrom .kvstore import close_kvstore, init_kvstore\nfrom .role import init_role\nfrom .rpc_client import connect_to_server\n\nSAMPLER_POOL = None\nNUM_SAMPLER_WORKERS = 0\nINITIALIZED = False\n\n\ndef set_initialized(value=True):\n    \"\"\"Set the initialized state of rpc\"\"\"\n    global INITIALIZED\n    INITIALIZED = value\n\n\ndef get_sampler_pool():\n    \"\"\"Return the sampler pool and num_workers\"\"\"\n    return SAMPLER_POOL, NUM_SAMPLER_WORKERS\n\n\ndef _init_rpc(\n    ip_config,\n    num_servers,\n    max_queue_size,\n    role,\n    num_threads,\n    group_id,\n):\n    \"\"\"This init function is called in the worker processes.\"\"\"\n    try:\n        utils.set_num_threads(num_threads)\n        if os.environ.get(\"DGL_DIST_MODE\", \"standalone\") != \"standalone\":\n            connect_to_server(ip_config, num_servers, max_queue_size, group_id)\n        init_role(role)\n        init_kvstore(ip_config, num_servers, role)\n    except Exception as e:\n        print(e, flush=True)\n        traceback.print_exc()\n        raise e\n\n\nclass MpCommand(Enum):\n    \"\"\"Enum class for multiprocessing command\"\"\"\n\n    INIT_RPC = 0  # Not used in the task queue\n    SET_COLLATE_FN = 1\n    CALL_BARRIER = 2\n    DELETE_COLLATE_FN = 3\n    CALL_COLLATE_FN = 4\n    CALL_FN_ALL_WORKERS = 5\n    FINALIZE_POOL = 6\n\n\ndef init_process(rpc_config, mp_contexts):\n    \"\"\"Work loop in the worker\"\"\"\n    try:\n        _init_rpc(*rpc_config)\n        keep_polling = True\n        data_queue, task_queue, barrier = mp_contexts\n        collate_fn_dict = {}\n\n        while keep_polling:\n            
try:\n                # Follow https://github.com/pytorch/pytorch/blob/d57ce8cf8989c0b737e636d8d7abe16c1f08f70b/torch/utils/data/_utils/worker.py#L260\n                command, args = task_queue.get(timeout=5)\n            except queue.Empty:\n                continue\n            if command == MpCommand.SET_COLLATE_FN:\n                dataloader_name, func = args\n                collate_fn_dict[dataloader_name] = func\n            elif command == MpCommand.CALL_BARRIER:\n                barrier.wait()\n            elif command == MpCommand.DELETE_COLLATE_FN:\n                (dataloader_name,) = args\n                del collate_fn_dict[dataloader_name]\n            elif command == MpCommand.CALL_COLLATE_FN:\n                dataloader_name, collate_args = args\n                data_queue.put(\n                    (\n                        dataloader_name,\n                        collate_fn_dict[dataloader_name](collate_args),\n                    )\n                )\n            elif command == MpCommand.CALL_FN_ALL_WORKERS:\n                func, func_args = args\n                func(func_args)\n            elif command == MpCommand.FINALIZE_POOL:\n                _exit()\n                keep_polling = False\n            else:\n                raise Exception(\"Unknown command\")\n    except Exception as e:\n        traceback.print_exc()\n        raise e\n\n\nclass CustomPool:\n    \"\"\"Customized worker pool\"\"\"\n\n    def __init__(self, num_workers, rpc_config):\n        \"\"\"\n        Customized worker pool init function\n        \"\"\"\n        ctx = mp.get_context(\"spawn\")\n        self.num_workers = num_workers\n        # As pool could be used by any number of dataloaders, queues\n        # should be able to take infinite elements to avoid dead lock.\n        self.queue_size = 0\n        self.result_queue = ctx.Queue(self.queue_size)\n        self.results = {}  # key is dataloader name, value is fetched batch.\n        self.task_queues = []\n  
      self.process_list = []\n        self.current_proc_id = 0\n        self.cache_result_dict = {}\n        self.barrier = ctx.Barrier(num_workers)\n        for _ in range(num_workers):\n            task_queue = ctx.Queue(self.queue_size)\n            self.task_queues.append(task_queue)\n            proc = ctx.Process(\n                target=init_process,\n                args=(\n                    rpc_config,\n                    (self.result_queue, task_queue, self.barrier),\n                ),\n            )\n            proc.daemon = True\n            proc.start()\n            self.process_list.append(proc)\n\n    def set_collate_fn(self, func, dataloader_name):\n        \"\"\"Set collate function in subprocess\"\"\"\n        for i in range(self.num_workers):\n            self.task_queues[i].put(\n                (MpCommand.SET_COLLATE_FN, (dataloader_name, func))\n            )\n        self.results[dataloader_name] = []\n\n    def submit_task(self, dataloader_name, args):\n        \"\"\"Submit task to workers\"\"\"\n        # Round robin\n        self.task_queues[self.current_proc_id].put(\n            (MpCommand.CALL_COLLATE_FN, (dataloader_name, args))\n        )\n        self.current_proc_id = (self.current_proc_id + 1) % self.num_workers\n\n    def submit_task_to_all_workers(self, func, args):\n        \"\"\"Submit task to all workers\"\"\"\n        for i in range(self.num_workers):\n            self.task_queues[i].put(\n                (MpCommand.CALL_FN_ALL_WORKERS, (func, args))\n            )\n\n    def get_result(self, dataloader_name, timeout=1800):\n        \"\"\"Get result from result queue\"\"\"\n        if dataloader_name not in self.results:\n            raise DGLError(\n                f\"Got result from an unknown dataloader {dataloader_name}.\"\n            )\n        while len(self.results[dataloader_name]) == 0:\n            dl_name, data = self.result_queue.get(timeout=timeout)\n            self.results[dl_name].append(data)\n        
return self.results[dataloader_name].pop(0)\n\n    def delete_collate_fn(self, dataloader_name):\n        \"\"\"Delete collate function\"\"\"\n        for i in range(self.num_workers):\n            self.task_queues[i].put(\n                (MpCommand.DELETE_COLLATE_FN, (dataloader_name,))\n            )\n        del self.results[dataloader_name]\n\n    def call_barrier(self):\n        \"\"\"Call barrier at all workers\"\"\"\n        for i in range(self.num_workers):\n            self.task_queues[i].put((MpCommand.CALL_BARRIER, tuple()))\n\n    def close(self):\n        \"\"\"Close worker pool\"\"\"\n        for i in range(self.num_workers):\n            self.task_queues[i].put(\n                (MpCommand.FINALIZE_POOL, tuple()), block=False\n            )\n            time.sleep(0.5)  # Fix for early python version\n\n    def join(self):\n        \"\"\"Join the close process of worker pool\"\"\"\n        for i in range(self.num_workers):\n            self.process_list[i].join()\n\n\ndef initialize(\n    ip_config,\n    max_queue_size=MAX_QUEUE_SIZE,\n    net_type=None,\n    num_worker_threads=1,\n    use_graphbolt=False,\n):\n    \"\"\"Initialize DGL's distributed module\n\n    This function initializes DGL's distributed module. It acts differently in server\n    or client modes. 
In the server mode, it runs the server code and never returns.\n    In the client mode, it builds connections with servers for communication and\n    creates worker processes for distributed sampling.\n\n    Parameters\n    ----------\n    ip_config: str\n        File path of ip_config file\n    max_queue_size : int\n        Maximal size (bytes) of client queue buffer (~20 GB on default).\n\n        Note that the 20 GB is just an upper-bound and DGL uses zero-copy and\n        it will not allocate 20GB memory at once.\n    net_type : str, optional\n        [Deprecated] Networking type, can be 'socket' only.\n    num_worker_threads: int\n        The number of OMP threads in each sampler process.\n    use_graphbolt: bool, optional\n        Whether to use GraphBolt for distributed train.\n\n    Note\n    ----\n    Users have to invoke this API before any DGL's distributed API and framework-specific\n    distributed API. For example, when used with Pytorch, users have to invoke this function\n    before Pytorch's `pytorch.distributed.init_process_group`.\n    \"\"\"\n    print(\n        f\"Initialize the distributed services with graphbolt: {use_graphbolt}\"\n    )\n    if net_type is not None:\n        dgl_warning(\n            \"net_type is deprecated and will be removed in future release.\"\n        )\n    if os.environ.get(\"DGL_ROLE\", \"client\") == \"server\":\n        from .dist_graph import DistGraphServer\n\n        assert (\n            os.environ.get(\"DGL_SERVER_ID\") is not None\n        ), \"Please define DGL_SERVER_ID to run DistGraph server\"\n        assert (\n            os.environ.get(\"DGL_IP_CONFIG\") is not None\n        ), \"Please define DGL_IP_CONFIG to run DistGraph server\"\n        assert (\n            os.environ.get(\"DGL_NUM_SERVER\") is not None\n        ), \"Please define DGL_NUM_SERVER to run DistGraph server\"\n        assert (\n            os.environ.get(\"DGL_NUM_CLIENT\") is not None\n        ), \"Please define DGL_NUM_CLIENT to 
run DistGraph server\"\n        assert (\n            os.environ.get(\"DGL_CONF_PATH\") is not None\n        ), \"Please define DGL_CONF_PATH to run DistGraph server\"\n        formats = os.environ.get(\"DGL_GRAPH_FORMAT\", \"csc\").split(\",\")\n        formats = [f.strip() for f in formats]\n        rpc.reset()\n        serv = DistGraphServer(\n            int(os.environ.get(\"DGL_SERVER_ID\")),\n            os.environ.get(\"DGL_IP_CONFIG\"),\n            int(os.environ.get(\"DGL_NUM_SERVER\")),\n            int(os.environ.get(\"DGL_NUM_CLIENT\")),\n            os.environ.get(\"DGL_CONF_PATH\"),\n            graph_format=formats,\n            use_graphbolt=use_graphbolt,\n        )\n        serv.start()\n        sys.exit()\n    else:\n        num_workers = int(os.environ.get(\"DGL_NUM_SAMPLER\", 0))\n        num_servers = int(os.environ.get(\"DGL_NUM_SERVER\", 1))\n        group_id = int(os.environ.get(\"DGL_GROUP_ID\", 0))\n        rpc.reset()\n        global SAMPLER_POOL\n        global NUM_SAMPLER_WORKERS\n        is_standalone = (\n            os.environ.get(\"DGL_DIST_MODE\", \"standalone\") == \"standalone\"\n        )\n        if num_workers > 0 and not is_standalone:\n            SAMPLER_POOL = CustomPool(\n                num_workers,\n                (\n                    ip_config,\n                    num_servers,\n                    max_queue_size,\n                    \"sampler\",\n                    num_worker_threads,\n                    group_id,\n                ),\n            )\n        else:\n            SAMPLER_POOL = None\n        NUM_SAMPLER_WORKERS = num_workers\n        if not is_standalone:\n            assert (\n                num_servers is not None and num_servers > 0\n            ), \"The number of servers per machine must be specified with a positive number.\"\n            connect_to_server(\n                ip_config,\n                num_servers,\n                max_queue_size,\n                group_id=group_id,\n          
  )\n        init_role(\"default\")\n        init_kvstore(ip_config, num_servers, \"default\")\n\n\ndef finalize_client():\n    \"\"\"Release resources of this client.\"\"\"\n    if os.environ.get(\"DGL_DIST_MODE\", \"standalone\") != \"standalone\":\n        rpc.finalize_sender()\n        rpc.finalize_receiver()\n\n\ndef _exit():\n    exit_client()\n    time.sleep(1)\n\n\ndef finalize_worker():\n    \"\"\"Finalize workers\n    Python's multiprocessing pool will not call atexit function when close\n    \"\"\"\n    global SAMPLER_POOL\n    if SAMPLER_POOL is not None:\n        SAMPLER_POOL.close()\n\n\ndef join_finalize_worker():\n    \"\"\"join the worker close process\"\"\"\n    global SAMPLER_POOL\n    if SAMPLER_POOL is not None:\n        SAMPLER_POOL.join()\n    SAMPLER_POOL = None\n\n\ndef is_initialized():\n    \"\"\"Is RPC initialized?\"\"\"\n    return INITIALIZED\n\n\ndef _shutdown_servers():\n    set_initialized(False)\n    # send ShutDownRequest to servers\n    if rpc.get_rank() == 0:  # Only client_0 issue this command\n        req = rpc.ShutDownRequest(rpc.get_rank())\n        for server_id in range(rpc.get_num_server()):\n            rpc.send_request(server_id, req)\n\n\ndef exit_client():\n    \"\"\"Trainer exits\n\n    This function is called automatically when a Python process exits. 
Normally,\n    the training script does not need to invoke this function at the end.\n\n    In the case that the training script needs to initialize the distributed module\n    multiple times (so far, this is needed in the unit tests), the training script\n    needs to call `exit_client` before calling `initialize` again.\n    \"\"\"\n    # Only client with rank_0 will send shutdown request to servers.\n    print(\n        \"Client[{}] in group[{}] is exiting...\".format(\n            rpc.get_rank(), rpc.get_group_id()\n        )\n    )\n    finalize_worker()  # finalize workers should be earlier than barrier, and non-blocking\n    # collect data such as DistTensor before exit\n    gc.collect()\n    if os.environ.get(\"DGL_DIST_MODE\", \"standalone\") != \"standalone\":\n        rpc.client_barrier()\n        _shutdown_servers()\n    finalize_client()\n    join_finalize_worker()\n    close_kvstore()\n    atexit.unregister(exit_client)\n"
  },
  {
    "path": "python/dgl/distributed/dist_dataloader.py",
    "content": "# pylint: disable=global-variable-undefined, invalid-name\n\"\"\"Multiprocess dataloader for distributed training\"\"\"\nimport inspect\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Mapping\n\nfrom .. import backend as F, transforms, utils\nfrom ..base import EID, NID\nfrom ..convert import heterograph\nfrom .dist_context import get_sampler_pool\n\n__all__ = [\n    \"NodeCollator\",\n    \"EdgeCollator\",\n    \"DistDataLoader\",\n    \"DistNodeDataLoader\",\n    \"DistEdgeDataLoader\",\n]\n\nDATALOADER_ID = 0\n\n\nclass DistDataLoader:\n    \"\"\"DGL customized multiprocessing dataloader.\n\n    DistDataLoader provides a similar interface to Pytorch's DataLoader to generate mini-batches\n    with multiprocessing. It utilizes the worker processes created by\n    :func:`dgl.distributed.initialize` to parallelize sampling.\n\n    Parameters\n    ----------\n    dataset: a tensor\n        Tensors of node IDs or edge IDs.\n    batch_size: int\n        The number of samples per batch to load.\n    shuffle: bool, optional\n        Set to ``True`` to have the data reshuffled at every epoch (default: ``False``).\n    collate_fn: callable, optional\n        The function is typically used to sample neighbors of the nodes in a batch\n        or the endpoint nodes of the edges in a batch.\n    drop_last: bool, optional\n        Set to ``True`` to drop the last incomplete batch, if the dataset size is not\n        divisible by the batch size. If ``False`` and the size of dataset is not divisible\n        by the batch size, then the last batch will be smaller. (default: ``False``)\n    queue_size: int, optional\n        Size of multiprocessing queue\n\n    Examples\n    --------\n    >>> g = dgl.distributed.DistGraph('graph-name')\n    >>> def sample(seeds):\n    ...     seeds = th.LongTensor(np.asarray(seeds))\n    ...     frontier = dgl.distributed.sample_neighbors(g, seeds, 10)\n    ...     
return dgl.to_block(frontier, seeds)\n    >>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000,\n                                                    collate_fn=sample, shuffle=True)\n    >>> for block in dataloader:\n    ...     feat = g.ndata['features'][block.srcdata[dgl.NID]]\n    ...     labels = g.ndata['labels'][block.dstdata[dgl.NID]]\n    ...     pred = model(block, feat)\n\n    Note\n    ----\n    When performing DGL's distributed sampling with multiprocessing, users have to use this class\n    instead of Pytorch's DataLoader because DGL's RPC requires that all processes establish\n    connections with servers before invoking any DGL's distributed API. Therefore, this dataloader\n    uses the worker processes created in :func:`dgl.distributed.initialize`.\n\n    Note\n    ----\n    This dataloader does not guarantee the iteration order. For example,\n    if dataset = [1, 2, 3, 4], batch_size = 2 and shuffle = False, the order of [1, 2]\n    and [3, 4] is not guaranteed.\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset,\n        batch_size,\n        shuffle=False,\n        collate_fn=None,\n        drop_last=False,\n        queue_size=None,\n    ):\n        self.pool, self.num_workers = get_sampler_pool()\n        if queue_size is None:\n            queue_size = self.num_workers * 4 if self.num_workers > 0 else 4\n        self.queue_size = queue_size  # prefetch size\n        self.batch_size = batch_size\n        self.num_pending = 0\n        self.collate_fn = collate_fn\n        self.current_pos = 0\n        self.queue = []  # Only used when pool is None\n        self.drop_last = drop_last\n        self.recv_idxs = 0\n        self.shuffle = shuffle\n        self.is_closed = False\n\n        self.dataset = dataset\n        self.data_idx = F.arange(0, len(dataset))\n        self.expected_idxs = len(dataset) // self.batch_size\n        if not self.drop_last and len(dataset) % self.batch_size != 0:\n            
self.expected_idxs += 1\n\n        # We need to have a unique ID for each data loader to identify itself\n        # in the sampler processes.\n        global DATALOADER_ID\n        self.name = \"dataloader-\" + str(DATALOADER_ID)\n        DATALOADER_ID += 1\n\n        if self.pool is not None:\n            self.pool.set_collate_fn(self.collate_fn, self.name)\n\n    def __del__(self):\n        # When the process exits, the process pool may have been closed. We should try\n        # and get the process pool again and see if we need to clean up the process pool.\n        self.pool, self.num_workers = get_sampler_pool()\n        if self.pool is not None:\n            self.pool.delete_collate_fn(self.name)\n\n    def __next__(self):\n        if self.pool is None:\n            num_reqs = 1\n        else:\n            num_reqs = self.queue_size - self.num_pending\n        for _ in range(num_reqs):\n            self._request_next_batch()\n        if self.recv_idxs < self.expected_idxs:\n            result = self._get_data_from_result_queue()\n            self.recv_idxs += 1\n            self.num_pending -= 1\n            return result\n        else:\n            assert self.num_pending == 0\n            raise StopIteration\n\n    def _get_data_from_result_queue(self, timeout=1800):\n        if self.pool is None:\n            ret = self.queue.pop(0)\n        else:\n            ret = self.pool.get_result(self.name, timeout=timeout)\n        return ret\n\n    def __iter__(self):\n        if self.shuffle:\n            self.data_idx = F.rand_shuffle(self.data_idx)\n        self.recv_idxs = 0\n        self.current_pos = 0\n        self.num_pending = 0\n        return self\n\n    def _request_next_batch(self):\n        next_data = self._next_data()\n        if next_data is None:\n            return\n        elif self.pool is not None:\n            self.pool.submit_task(self.name, next_data)\n        else:\n            result = self.collate_fn(next_data)\n            
self.queue.append(result)\n        self.num_pending += 1\n\n    def _next_data(self):\n        if self.current_pos == len(self.dataset):\n            return None\n\n        end_pos = 0\n        if self.current_pos + self.batch_size > len(self.dataset):\n            if self.drop_last:\n                return None\n            else:\n                end_pos = len(self.dataset)\n        else:\n            end_pos = self.current_pos + self.batch_size\n        idx = self.data_idx[self.current_pos : end_pos].tolist()\n        ret = [self.dataset[i] for i in idx]\n        # Sharing large number of tensors between processes will consume too many\n        # file descriptors, so let's convert each tensor to scalar value beforehand.\n        if isinstance(ret[0], tuple):\n            ret = [(type, F.as_scalar(id)) for (type, id) in ret]\n        else:\n            ret = [F.as_scalar(id) for id in ret]\n        self.current_pos = end_pos\n        return ret\n\n\n# [Note] As implementation of ``dgl.distributed.DistDataLoader`` is independent\n# of ``dgl.dataloading.DataLoader`` currently, dedicated collators are defined\n# here instead of using ``dgl.dataloading.CollateWrapper``.\n\n\ndef _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):\n    if isinstance(eids, Mapping):\n        eids = {g.to_canonical_etype(k): v for k, v in eids.items()}\n        exclude_eids = {\n            k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)\n            for k, v in eids.items()\n        }\n    else:\n        exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)\n    return exclude_eids\n\n\ndef _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):\n    exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}\n    reverse_etype_map = {\n        g.to_canonical_etype(k): g.to_canonical_etype(v)\n        for k, v in reverse_etype_map.items()\n    }\n    exclude_eids.update(\n        {reverse_etype_map[k]: v for k, v in 
exclude_eids.items()}\n    )\n    return exclude_eids\n\n\ndef _find_exclude_eids(g, exclude_mode, eids, **kwargs):\n    \"\"\"Find all edge IDs to exclude according to :attr:`exclude_mode`.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    exclude_mode : str, optional\n        Can be either of the following,\n\n        None (default)\n            Does not exclude any edge.\n\n        'self'\n            Exclude the given edges themselves but nothing else.\n\n        'reverse_id'\n            Exclude all edges specified in ``eids``, as well as their reverse edges\n            of the same edge type.\n\n            The mapping from each edge ID to its reverse edge ID is specified in\n            the keyword argument ``reverse_eid_map``.\n\n            This mode assumes that the reverse of an edge with ID ``e`` and type\n            ``etype`` will have ID ``reverse_eid_map[e]`` and type ``etype``.\n\n        'reverse_types'\n            Exclude all edges specified in ``eids``, as well as their reverse\n            edges of the corresponding edge types.\n\n            The mapping from each edge type to its reverse edge type is specified\n            in the keyword argument ``reverse_etype_map``.\n\n            This mode assumes that the reverse of an edge with ID ``e`` and type ``etype``\n            will have ID ``e`` and type ``reverse_etype_map[etype]``.\n    eids : Tensor or dict[etype, Tensor]\n        The edge IDs.\n    reverse_eid_map : Tensor or dict[etype, Tensor]\n        The mapping from edge ID to its reverse edge ID.\n    reverse_etype_map : dict[etype, etype]\n        The mapping from edge etype to its reverse edge type.\n    \"\"\"\n    if exclude_mode is None:\n        return None\n    elif exclude_mode == \"self\":\n        if isinstance(eids, Mapping):\n            eids = {g.to_canonical_etype(k): v for k, v in eids.items()}\n        return eids\n    elif exclude_mode == \"reverse_id\":\n        return 
_find_exclude_eids_with_reverse_id(\n            g, eids, kwargs[\"reverse_eid_map\"]\n        )\n    elif exclude_mode == \"reverse_types\":\n        return _find_exclude_eids_with_reverse_types(\n            g, eids, kwargs[\"reverse_etype_map\"]\n        )\n    else:\n        raise ValueError(\"unsupported mode {}\".format(exclude_mode))\n\n\nclass Collator(ABC):\n    \"\"\"Abstract DGL collator for training GNNs on downstream tasks stochastically.\n\n    Provides a :attr:`dataset` object containing the collection of all nodes or edges,\n    as well as a :attr:`collate` method that combines a set of items from\n    :attr:`dataset` and obtains the message flow graphs (MFGs).\n\n    Notes\n    -----\n    For the concept of MFGs, please refer to\n    :ref:`User Guide Section 6 <guide-minibatch>` and\n    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.\n    \"\"\"\n\n    @property\n    @abstractmethod\n    def dataset(self):\n        \"\"\"Returns the dataset object of the collator.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def collate(self, items):\n        \"\"\"Combines the items from the dataset object and obtains the list of MFGs.\n\n        Parameters\n        ----------\n        items : list[str, int]\n            The list of node or edge IDs or type-ID pairs.\n\n        Notes\n        -----\n        For the concept of MFGs, please refer to\n        :ref:`User Guide Section 6 <guide-minibatch>` and\n        :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.\n        \"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    def add_edge_attribute_to_graph(g, data_name, gb_padding):\n        \"\"\"Add data into the graph as an edge attribute.\n\n        For some cases such as prob/mask-based sampling on GraphBolt partitions,\n        we need to prepare such data beforehand. 
This is because data are\n        usually saved in DistGraph.ndata/edata, but such data is not in the\n        format that GraphBolt partitions require. And in GraphBolt, such data\n        are saved as edge attributes. So we need to add such data into the graph\n        before any sampling is kicked off.\n\n        Parameters\n        ----------\n        g : DistGraph\n            The graph.\n        data_name : str\n            The name of data that's stored in DistGraph.ndata/edata.\n        gb_padding : int, optional\n            The padding value for GraphBolt partitions' new edge_attributes.\n        \"\"\"\n        if g._use_graphbolt and data_name:\n            g.add_edge_attribute(data_name, gb_padding)\n\n\nclass NodeCollator(Collator):\n    \"\"\"DGL collator to combine nodes and their computation dependencies within a minibatch for\n    training node classification or regression on a single graph with neighborhood sampling.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    nids : Tensor or dict[ntype, Tensor]\n        The node set to compute outputs.\n    graph_sampler : dgl.dataloading.BlockSampler\n        The neighborhood sampler.\n    gb_padding : int, optional\n        The padding value for GraphBolt partitions' new edge_attributes if the attributes in DistGraph are None.\n        e.g. 
prob/mask-based sampling.\n        Only when the mask of one edge is set as 1, an edge will be sampled in dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors.\n        The argument will be used in add_edge_attribute_to_graph to add new edge_attributes in graphbolt.\n\n    Examples\n    --------\n    To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on\n    a homogeneous graph where each node takes messages from all neighbors (assume\n    the backend is PyTorch):\n\n    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])\n    >>> collator = dgl.dataloading.NodeCollator(g, train_nid, sampler)\n    >>> dataloader = torch.utils.data.DataLoader(\n    ...     collator.dataset, collate_fn=collator.collate,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, output_nodes, blocks in dataloader:\n    ...     train_on(input_nodes, output_nodes, blocks)\n\n    Notes\n    -----\n    For the concept of MFGs, please refer to\n    :ref:`User Guide Section 6 <guide-minibatch>` and\n    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.\n    \"\"\"\n\n    def __init__(self, g, nids, graph_sampler, gb_padding=1):\n        self.g = g\n        if not isinstance(nids, Mapping):\n            assert (\n                len(g.ntypes) == 1\n            ), \"nids should be a dict of node type and ids for graph with multiple node types\"\n        self.graph_sampler = graph_sampler\n\n        self.nids = utils.prepare_tensor_or_dict(g, nids, \"nids\")\n        self._dataset = utils.maybe_flatten_dict(self.nids)\n\n        # Add prob/mask into graphbolt partition's edge attributes if needed.\n        if hasattr(self.graph_sampler, \"prob\"):\n            Collator.add_edge_attribute_to_graph(\n                self.g, self.graph_sampler.prob, gb_padding\n            )\n\n    @property\n    def dataset(self):\n        return self._dataset\n\n    def collate(self, 
items):\n        \"\"\"Find the list of MFGs necessary for computing the representation of given\n        nodes for a node classification/regression task.\n\n        Parameters\n        ----------\n        items : list[int] or list[tuple[str, int]]\n            Either a list of node IDs (for homogeneous graphs), or a list of node type-ID\n            pairs (for heterogeneous graphs).\n\n        Returns\n        -------\n        input_nodes : Tensor or dict[ntype, Tensor]\n            The input nodes necessary for computation in this minibatch.\n\n            If the original graph has multiple node types, return a dictionary of\n            node type names and node ID tensors.  Otherwise, return a single tensor.\n        output_nodes : Tensor or dict[ntype, Tensor]\n            The nodes whose representations are to be computed in this minibatch.\n\n            If the original graph has multiple node types, return a dictionary of\n            node type names and node ID tensors.  Otherwise, return a single tensor.\n        MFGs : list[DGLGraph]\n            The list of MFGs necessary for computing the representation.\n        \"\"\"\n        if isinstance(items[0], tuple):\n            # returns a list of pairs: group them by node types into a dict\n            items = utils.group_as_dict(items)\n        items = utils.prepare_tensor_or_dict(self.g, items, \"items\")\n\n        input_nodes, output_nodes, blocks = self.graph_sampler.sample_blocks(\n            self.g, items\n        )\n\n        return input_nodes, output_nodes, blocks\n\n\nclass EdgeCollator(Collator):\n    \"\"\"DGL collator to combine edges and their computation dependencies within a minibatch for\n    training edge classification, edge regression, or link prediction on a single graph\n    with neighborhood sampling.\n\n    Given a set of edges, the collate function will yield\n\n    * A tensor of input nodes necessary for computing the representation on edges, or\n      a dictionary of node type 
names and such tensors.\n\n    * A subgraph that contains only the edges in the minibatch and their incident nodes.\n      Note that the graph has an identical metagraph with the original graph.\n\n    * If a negative sampler is given, another graph that contains the \"negative edges\",\n      connecting the source and destination nodes yielded from the given negative sampler.\n\n    * A list of MFGs necessary for computing the representation of the incident nodes\n      of the edges in the minibatch.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph from which the edges are iterated in minibatches and the subgraphs\n        are generated.\n    eids : Tensor or dict[etype, Tensor]\n        The edge set in graph :attr:`g` to compute outputs.\n    graph_sampler : dgl.dataloading.BlockSampler\n        The neighborhood sampler.\n    g_sampling : DGLGraph, optional\n        The graph where neighborhood sampling and message passing is performed.\n\n        Note that this is not necessarily the same as :attr:`g`.\n\n        If None, assume to be the same as :attr:`g`.\n    exclude : str, optional\n        Whether and how to exclude dependencies related to the sampled edges in the\n        minibatch.  Possible values are\n\n        * None, which excludes nothing.\n\n        * ``'self'``, which excludes the sampled edges themselves but nothing else.\n\n        * ``'reverse_id'``, which excludes the reverse edges of the sampled edges.  The said\n          reverse edges have the same edge type as the sampled edges.  Only works\n          on edge types whose source node type is the same as its destination node type.\n\n        * ``'reverse_types'``, which excludes the reverse edges of the sampled edges.  
The\n          said reverse edges have different edge types from the sampled edges.\n\n        If ``g_sampling`` is given, ``exclude`` is ignored and will be always ``None``.\n    reverse_eids : Tensor or dict[etype, Tensor], optional\n        A tensor of reverse edge ID mapping.  The i-th element indicates the ID of\n        the i-th edge's reverse edge.\n\n        If the graph is heterogeneous, this argument requires a dictionary of edge\n        types and the reverse edge ID mapping tensors.\n\n        Required and only used when ``exclude`` is set to ``reverse_id``.\n\n        For heterogeneous graph this will be a dict of edge type and edge IDs.  Note that\n        only the edge types whose source node type is the same as destination node type\n        are needed.\n    reverse_etypes : dict[etype, etype], optional\n        The mapping from the edge type to its reverse edge type.\n\n        Required and only used when ``exclude`` is set to ``reverse_types``.\n    negative_sampler : callable, optional\n        The negative sampler.  Can be omitted if no negative sampling is needed.\n\n        The negative sampler must be a callable that takes in the following arguments:\n\n        * The original (heterogeneous) graph.\n\n        * The ID array of sampled edges in the minibatch, or the dictionary of edge\n          types and ID array of sampled edges in the minibatch if the graph is\n          heterogeneous.\n\n        It should return\n\n        * A pair of source and destination node ID arrays as negative samples,\n          or a dictionary of edge types and such pairs if the graph is heterogeneous.\n\n        A set of builtin negative samplers are provided in\n        :ref:`the negative sampling module <api-dataloading-negative-sampling>`.\n    gb_padding : int, optional\n        The padding value for GraphBolt partitions' new edge_attributes if the attributes in DistGraph are None.\n        e.g. 
prob/mask-based sampling.\n        Only when the mask of one edge is set as 1, an edge will be sampled in dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors.\n        The argument will be used in add_edge_attribute_to_graph to add new edge_attributes in graphbolt.\n\n    Examples\n    --------\n    The following example shows how to train a 3-layer GNN for edge classification on a\n    set of edges ``train_eid`` on a homogeneous undirected graph. Each node takes\n    messages from all neighbors.\n\n    Say that you have an array of source node IDs ``src`` and another array of destination\n    node IDs ``dst``.  One can make it bidirectional by adding another set of edges\n    that connects from ``dst`` to ``src``:\n\n    >>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))\n\n    One can then know that the ID difference of an edge and its reverse edge is ``|E|``,\n    where ``|E|`` is the length of your source/destination array.  The reverse edge\n    mapping can be obtained by\n\n    >>> E = len(src)\n    >>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])\n\n    Note that the sampled edges as well as their reverse edges are removed from\n    computation dependencies of the incident nodes.  This is a common trick to avoid\n    information leakage.\n\n    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])\n    >>> collator = dgl.dataloading.EdgeCollator(\n    ...     g, train_eid, sampler, exclude='reverse_id',\n    ...     reverse_eids=reverse_eids)\n    >>> dataloader = torch.utils.data.DataLoader(\n    ...     collator.dataset, collate_fn=collator.collate,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, pair_graph, blocks in dataloader:\n    ...     
train_on(input_nodes, pair_graph, blocks)\n\n    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` on a\n    homogeneous graph where each node takes messages from all neighbors (assume the\n    backend is PyTorch), with 5 uniformly chosen negative samples per edge:\n\n    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])\n    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)\n    >>> collator = dgl.dataloading.EdgeCollator(\n    ...     g, train_eid, sampler, exclude='reverse_id',\n    ...     reverse_eids=reverse_eids, negative_sampler=neg_sampler)\n    >>> dataloader = torch.utils.data.DataLoader(\n    ...     collator.dataset, collate_fn=collator.collate,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:\n    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)\n\n    For heterogeneous graphs, the reverse of an edge may have a different edge type\n    from the original edge.  For instance, consider that you have an array of\n    user-item clicks, represented by a user array ``user`` and an item array ``item``.\n    You may want to build a heterogeneous graph with a user-click-item relation and an\n    item-clicked-by-user relation.\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'click', 'item'): (user, item),\n    ...     ('item', 'clicked-by', 'user'): (item, user)})\n\n    To train a 3-layer GNN for edge classification on a set of edges ``train_eid`` with\n    type ``click``, you can write\n\n    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])\n    >>> collator = dgl.dataloading.EdgeCollator(\n    ...     g, {'click': train_eid}, sampler, exclude='reverse_types',\n    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'})\n    >>> dataloader = torch.utils.data.DataLoader(\n    ...     collator.dataset, collate_fn=collator.collate,\n    ...    
 batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, pair_graph, blocks in dataloader:\n    ...     train_on(input_nodes, pair_graph, blocks)\n\n    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` with type\n    ``click``, you can write\n\n    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])\n    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)\n    >>> collator = dgl.dataloading.EdgeCollator(\n    ...     g, {'click': train_eid}, sampler, exclude='reverse_types',\n    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},\n    ...     negative_sampler=neg_sampler)\n    >>> dataloader = torch.utils.data.DataLoader(\n    ...     collator.dataset, collate_fn=collator.collate,\n    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)\n    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:\n    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)\n\n    Notes\n    -----\n    For the concept of MFGs, please refer to\n    :ref:`User Guide Section 6 <guide-minibatch>` and\n    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.\n    \"\"\"\n\n    def __init__(\n        self,\n        g,\n        eids,\n        graph_sampler,\n        g_sampling=None,\n        exclude=None,\n        reverse_eids=None,\n        reverse_etypes=None,\n        negative_sampler=None,\n        gb_padding=1,\n    ):\n        self.g = g\n        if not isinstance(eids, Mapping):\n            assert (\n                len(g.etypes) == 1\n            ), \"eids should be a dict of etype and ids for graph with multiple etypes\"\n        self.graph_sampler = graph_sampler\n\n        # One may wish to iterate over the edges in one graph while perform sampling in\n        # another graph.  
This may be the case for iterating over validation and test\n        # edge set while perform neighborhood sampling on the graph formed by only\n        # the training edge set.\n        # See GCMC for an example usage.\n        if g_sampling is not None:\n            self.g_sampling = g_sampling\n            self.exclude = None\n        else:\n            self.g_sampling = self.g\n            self.exclude = exclude\n\n        self.reverse_eids = reverse_eids\n        self.reverse_etypes = reverse_etypes\n        self.negative_sampler = negative_sampler\n\n        self.eids = utils.prepare_tensor_or_dict(g, eids, \"eids\")\n        self._dataset = utils.maybe_flatten_dict(self.eids)\n\n        # Add prob/mask into graphbolt partition's edge attributes if needed.\n        if hasattr(self.graph_sampler, \"prob\"):\n            Collator.add_edge_attribute_to_graph(\n                self.g, self.graph_sampler.prob, gb_padding\n            )\n\n    @property\n    def dataset(self):\n        return self._dataset\n\n    def _collate(self, items):\n        if isinstance(items[0], tuple):\n            # returns a list of pairs: group them by node types into a dict\n            items = utils.group_as_dict(items)\n        items = utils.prepare_tensor_or_dict(self.g_sampling, items, \"items\")\n\n        pair_graph = self.g.edge_subgraph(items)\n        seed_nodes = pair_graph.ndata[NID]\n\n        exclude_eids = _find_exclude_eids(\n            self.g_sampling,\n            self.exclude,\n            items,\n            reverse_eid_map=self.reverse_eids,\n            reverse_etype_map=self.reverse_etypes,\n        )\n\n        input_nodes, _, blocks = self.graph_sampler.sample_blocks(\n            self.g_sampling, seed_nodes, exclude_eids=exclude_eids\n        )\n\n        return input_nodes, pair_graph, blocks\n\n    def _collate_with_negative_sampling(self, items):\n        if isinstance(items[0], tuple):\n            # returns a list of pairs: group them by node types into 
a dict\n            items = utils.group_as_dict(items)\n        items = utils.prepare_tensor_or_dict(self.g_sampling, items, \"items\")\n\n        pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)\n        induced_edges = pair_graph.edata[EID]\n\n        neg_srcdst = self.negative_sampler(self.g, items)\n        if not isinstance(neg_srcdst, Mapping):\n            assert len(self.g.etypes) == 1, (\n                \"graph has multiple or no edge types; \"\n                \"please return a dict in negative sampler.\"\n            )\n            neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst}\n        # Get dtype from a tuple of tensors\n        dtype = F.dtype(list(neg_srcdst.values())[0][0])\n        ctx = F.context(pair_graph)\n        neg_edges = {\n            etype: neg_srcdst.get(\n                etype,\n                (\n                    F.copy_to(F.tensor([], dtype), ctx),\n                    F.copy_to(F.tensor([], dtype), ctx),\n                ),\n            )\n            for etype in self.g.canonical_etypes\n        }\n        neg_pair_graph = heterograph(\n            neg_edges,\n            {ntype: self.g.num_nodes(ntype) for ntype in self.g.ntypes},\n        )\n\n        pair_graph, neg_pair_graph = transforms.compact_graphs(\n            [pair_graph, neg_pair_graph]\n        )\n        pair_graph.edata[EID] = induced_edges\n\n        seed_nodes = pair_graph.ndata[NID]\n\n        exclude_eids = _find_exclude_eids(\n            self.g_sampling,\n            self.exclude,\n            items,\n            reverse_eid_map=self.reverse_eids,\n            reverse_etype_map=self.reverse_etypes,\n        )\n\n        input_nodes, _, blocks = self.graph_sampler.sample_blocks(\n            self.g_sampling, seed_nodes, exclude_eids=exclude_eids\n        )\n\n        return input_nodes, pair_graph, neg_pair_graph, blocks\n\n    def collate(self, items):\n        \"\"\"Combines the sampled edges into a minibatch for edge classification, 
edge\n        regression, and link prediction tasks.\n\n        Parameters\n        ----------\n        items : list[int] or list[tuple[str, int]]\n            Either a list of edge IDs (for homogeneous graphs), or a list of edge type-ID\n            pairs (for heterogeneous graphs).\n\n        Returns\n        -------\n        Either ``(input_nodes, pair_graph, blocks)``, or\n        ``(input_nodes, pair_graph, negative_pair_graph, blocks)`` if negative sampling is\n        enabled.\n\n        input_nodes : Tensor or dict[ntype, Tensor]\n            The input nodes necessary for computation in this minibatch.\n\n            If the original graph has multiple node types, return a dictionary of\n            node type names and node ID tensors.  Otherwise, return a single tensor.\n        pair_graph : DGLGraph\n            The graph that contains only the edges in the minibatch as well as their incident\n            nodes.\n\n            Note that the metagraph of this graph will be identical to that of the original\n            graph.\n        negative_pair_graph : DGLGraph\n            The graph that contains only the edges connecting the source and destination nodes\n            yielded from the given negative sampler, if negative sampling is enabled.\n\n            Note that the metagraph of this graph will be identical to that of the original\n            graph.\n        blocks : list[DGLGraph]\n            The list of MFGs necessary for computing the representation of the edges.\n        \"\"\"\n        if self.negative_sampler is None:\n            return self._collate(items)\n        else:\n            return self._collate_with_negative_sampling(items)\n\n\ndef _remove_kwargs_dist(kwargs):\n    if \"num_workers\" in kwargs:\n        del kwargs[\"num_workers\"]\n    if \"pin_memory\" in kwargs:\n        del kwargs[\"pin_memory\"]\n        print(\"Distributed DataLoaders do not support pin_memory.\")\n    return kwargs\n\n\nclass 
DistNodeDataLoader(DistDataLoader):\n    \"\"\"Sampled graph data loader over nodes for distributed graph storage.\n\n    It wraps an iterable over a set of nodes, generating the list\n    of message flow graphs (MFGs) as computation dependency of the said minibatch, on\n    a distributed graph.\n\n    All the arguments have the same meaning as the single-machine counterpart\n    :class:`dgl.dataloading.DataLoader` except the first argument\n    :attr:`g` which must be a :class:`dgl.distributed.DistGraph`.\n\n    Parameters\n    ----------\n    g : DistGraph\n        The distributed graph.\n\n    nids, graph_sampler, device, kwargs :\n        See :class:`dgl.dataloading.DataLoader`.\n\n    See also\n    --------\n    dgl.dataloading.DataLoader\n    \"\"\"\n\n    def __init__(self, g, nids, graph_sampler, device=None, **kwargs):\n        collator_kwargs = {}\n        dataloader_kwargs = {}\n        _collator_arglist = inspect.getfullargspec(NodeCollator).args\n        for k, v in kwargs.items():\n            if k in _collator_arglist:\n                collator_kwargs[k] = v\n            else:\n                dataloader_kwargs[k] = v\n        if device is None:\n            # for the distributed case default to the CPU\n            device = \"cpu\"\n        assert (\n            device == \"cpu\"\n        ), \"Only cpu is supported in the case of a DistGraph.\"\n        # Distributed DataLoader currently does not support heterogeneous graphs\n        # and does not copy features.  
Fallback to normal solution\n        self.collator = NodeCollator(g, nids, graph_sampler, **collator_kwargs)\n        _remove_kwargs_dist(dataloader_kwargs)\n        super().__init__(\n            self.collator.dataset,\n            collate_fn=self.collator.collate,\n            **dataloader_kwargs\n        )\n        self.device = device\n\n\nclass DistEdgeDataLoader(DistDataLoader):\n    \"\"\"Sampled graph data loader over edges for distributed graph storage.\n\n    It wraps an iterable over a set of edges, generating the list\n    of message flow graphs (MFGs) as computation dependency of the said minibatch for\n    edge classification, edge regression, and link prediction, on a distributed\n    graph.\n\n    All the arguments have the same meaning as the single-machine counterpart\n    :class:`dgl.dataloading.DataLoader` except the first argument\n    :attr:`g` which must be a :class:`dgl.distributed.DistGraph`.\n\n    Parameters\n    ----------\n    g : DistGraph\n        The distributed graph.\n\n    eids, graph_sampler, device, kwargs :\n        See :class:`dgl.dataloading.DataLoader`.\n\n    See also\n    --------\n    dgl.dataloading.DataLoader\n    \"\"\"\n\n    def __init__(self, g, eids, graph_sampler, device=None, **kwargs):\n        collator_kwargs = {}\n        dataloader_kwargs = {}\n        _collator_arglist = inspect.getfullargspec(EdgeCollator).args\n        for k, v in kwargs.items():\n            if k in _collator_arglist:\n                collator_kwargs[k] = v\n            else:\n                dataloader_kwargs[k] = v\n\n        if device is None:\n            # for the distributed case default to the CPU\n            device = \"cpu\"\n        assert (\n            device == \"cpu\"\n        ), \"Only cpu is supported in the case of a DistGraph.\"\n        # Distributed DataLoader currently does not support heterogeneous graphs\n        # and does not copy features.  
Fallback to normal solution\n        self.collator = EdgeCollator(g, eids, graph_sampler, **collator_kwargs)\n        _remove_kwargs_dist(dataloader_kwargs)\n        super().__init__(\n            self.collator.dataset,\n            collate_fn=self.collator.collate,\n            **dataloader_kwargs\n        )\n\n        self.device = device\n"
  },
  {
    "path": "python/dgl/distributed/dist_graph.py",
    "content": "\"\"\"Define distributed graph.\"\"\"\n\nimport gc\n\nimport os\nfrom collections import namedtuple\nfrom collections.abc import Mapping, MutableMapping\n\nimport numpy as np\nimport torch\n\nfrom .. import backend as F, graphbolt as gb, heterograph_index\nfrom .._ffi.ndarray import empty_shared_mem\nfrom ..base import ALL, DGLError, EID, ETYPE, is_all, NID\nfrom ..convert import graph as dgl_graph, heterograph as dgl_heterograph\nfrom ..frame import infer_scheme\n\nfrom ..heterograph import DGLGraph\nfrom ..ndarray import exist_shared_mem_array\nfrom ..transforms import compact_graphs\nfrom . import graph_services, role, rpc\nfrom .dist_tensor import DistTensor\nfrom .graph_partition_book import (\n    _etype_str_to_tuple,\n    EdgePartitionPolicy,\n    get_shared_mem_partition_book,\n    HeteroDataName,\n    NodePartitionPolicy,\n    parse_hetero_data_name,\n    PartitionPolicy,\n)\nfrom .graph_services import (\n    find_edges as dist_find_edges,\n    in_degrees as dist_in_degrees,\n    out_degrees as dist_out_degrees,\n)\nfrom .kvstore import get_kvstore, KVServer\nfrom .partition import (\n    load_partition,\n    load_partition_book,\n    load_partition_feats,\n    RESERVED_FIELD_DTYPE,\n)\nfrom .rpc_server import start_server\nfrom .server_state import ServerState\nfrom .shared_mem_utils import (\n    _get_edata_path,\n    _get_ndata_path,\n    _to_shared_mem,\n    DTYPE_DICT,\n)\n\nINIT_GRAPH = 800001\nQUERY_IF_USE_GRAPHBOLT = 800002\nADD_EDGE_ATTRIBUTE_FROM_KV = 800003\nADD_EDGE_ATTRIBUTE_FROM_SHARED_MEM = 800004\n\n\nclass InitGraphRequest(rpc.Request):\n    \"\"\"Init graph on the backup servers.\n\n    When the backup server starts, they don't load the graph structure.\n    This request tells the backup servers that they can map to the graph structure\n    with shared memory.\n    \"\"\"\n\n    def __init__(self, graph_name):\n        self._graph_name = graph_name\n\n    def __getstate__(self):\n        return self._graph_name\n\n    def 
__setstate__(self, state):\n        self._graph_name = state\n\n    def process_request(self, server_state):\n        if server_state.graph is None:\n            server_state.graph = _get_graph_from_shared_mem(\n                self._graph_name, server_state.use_graphbolt\n            )\n        return InitGraphResponse(self._graph_name)\n\n\nclass InitGraphResponse(rpc.Response):\n    \"\"\"Ack the init graph request\"\"\"\n\n    def __init__(self, graph_name):\n        self._graph_name = graph_name\n\n    def __getstate__(self):\n        return self._graph_name\n\n    def __setstate__(self, state):\n        self._graph_name = state\n\n\nclass QueryIfUseGraphBoltRequest(rpc.Request):\n    \"\"\"Query if use GraphBolt.\"\"\"\n\n    def __getstate__(self):\n        return None\n\n    def __setstate__(self, state):\n        pass\n\n    def process_request(self, server_state):\n        return QueryIfUseGraphBoltResponse(server_state.use_graphbolt)\n\n\nclass QueryIfUseGraphBoltResponse(rpc.Response):\n    \"\"\"Ack the query request about if use GraphBolt.\"\"\"\n\n    def __init__(self, use_graphbolt):\n        self._use_graphbolt = use_graphbolt\n\n    def __getstate__(self):\n        return self._use_graphbolt\n\n    def __setstate__(self, state):\n        self._use_graphbolt = state\n\n\ndef _copy_data_to_shared_mem(data, name):\n    \"\"\"Copy data to shared memory.\"\"\"\n    # [TODO] Copy data to shared memory.\n    assert data.dtype == torch.float32, \"Only float32 is supported.\"\n    data_type = F.reverse_data_type_dict[F.dtype(data)]\n    shared_data = empty_shared_mem(name, True, data.shape, data_type)\n    dlpack = shared_data.to_dlpack()\n    ret = F.zerocopy_from_dlpack(dlpack)\n    rpc.copy_data_to_shared_memory(ret, data)\n    return ret\n\n\ndef _copy_data_from_shared_mem(name, shape):\n    \"\"\"Copy data from shared memory.\"\"\"\n    data_type = F.reverse_data_type_dict[F.float32]\n    data = empty_shared_mem(name, False, shape, data_type)\n    
dlpack = data.to_dlpack()\n    return F.zerocopy_from_dlpack(dlpack)\n\n\nclass AddEdgeAttributeFromKVRequest(rpc.Request):\n    \"\"\"Add edge attribute from kvstore to local GraphBolt partition.\"\"\"\n\n    def __init__(self, name, kv_names, padding):\n        self._name = name\n        self._kv_names = kv_names\n        self._padding = padding\n\n    def __getstate__(self):\n        return self._name, self._kv_names, self._padding\n\n    def __setstate__(self, state):\n        self._name, self._kv_names, self._padding = state\n\n    def process_request(self, server_state):\n        # For now, this is only used to add prob/mask data to the graph.\n        name = self._name\n        g = server_state.graph\n        if name not in g.edge_attributes:\n            # Fetch target data from kvstore.\n            kv_store = server_state.kv_store\n            data = [\n                kv_store.data_store[kv_name] if kv_name else None\n                for kv_name in self._kv_names\n            ]\n            # Due to data type limitation in GraphBolt's sampling, we only support float32.\n            data_type = torch.float32\n            gpb = server_state.partition_book\n            # Initialize the edge attribute.\n            num_edges = g.total_num_edges\n\n            # Padding is used to fill missing edge attributes (e.g., 'prob' or 'mask') for certain edge types.\n            # In DGLGraph, some edges may lack these attributes or have them set to None, but DGL will still sample these edges.\n            # In contrast, GraphBolt samples edges based on specific attributes (e.g., 'mask' == 1) and will skip edges with missing attributes.\n            # To ensure consistent sampling behavior in GraphBolt, we pad missing attributes with default values (e.g., 'mask' = 1),\n            # allowing all edges to be sampled, even if their attributes were missing or None in DGLGraph.\n            attr_data = torch.full((num_edges,), self._padding, dtype=data_type)\n            
# Map data from kvstore to the local partition for inner edges only.\n            num_inner_edges = gpb.metadata()[gpb.partid][\"num_edges\"]\n            homo_eids = g.edge_attributes[EID][:num_inner_edges]\n            etype_ids, typed_eids = gpb.map_to_per_etype(homo_eids)\n            for etype_id, c_etype in enumerate(gpb.canonical_etypes):\n                curr_indices = torch.nonzero(etype_ids == etype_id).squeeze()\n                curr_typed_eids = typed_eids[curr_indices]\n                curr_local_eids = gpb.eid2localeid(\n                    curr_typed_eids, gpb.partid, etype=c_etype\n                )\n                if data[etype_id] is None:\n                    continue\n                attr_data[curr_indices] = data[etype_id][curr_local_eids].to(\n                    data_type\n                )\n            # Copy data to shared memory.\n            attr_data = _copy_data_to_shared_mem(attr_data, \"__edge__\" + name)\n            g.add_edge_attribute(name, attr_data)\n        return AddEdgeAttributeFromKVResponse(name)\n\n\nclass AddEdgeAttributeFromKVResponse(rpc.Response):\n    \"\"\"Ack the request of adding edge attribute.\"\"\"\n\n    def __init__(self, name):\n        self._name = name\n\n    def __getstate__(self):\n        return self._name\n\n    def __setstate__(self, state):\n        self._name = state\n\n\nclass AddEdgeAttributeFromSharedMemRequest(rpc.Request):\n    \"\"\"Add edge attribute from shared memory to local GraphBolt partition.\"\"\"\n\n    def __init__(self, name):\n        self._name = name\n\n    def __getstate__(self):\n        return self._name\n\n    def __setstate__(self, state):\n        self._name = state\n\n    def process_request(self, server_state):\n        name = self._name\n        g = server_state.graph\n        if name not in g.edge_attributes:\n            data = _copy_data_from_shared_mem(\n                \"__edge__\" + name, (g.total_num_edges,)\n            )\n            g.add_edge_attribute(name, 
data)\n        return AddEdgeAttributeFromSharedMemResponse(name)\n\n\nclass AddEdgeAttributeFromSharedMemResponse(rpc.Response):\n    \"\"\"Ack the request of adding edge attribute from shared memory.\"\"\"\n\n    def __init__(self, name):\n        self._name = name\n\n    def __getstate__(self):\n        return self._name\n\n    def __setstate__(self, state):\n        self._name = state\n\n\ndef _copy_graph_to_shared_mem(g, graph_name, graph_format, use_graphbolt):\n    if use_graphbolt:\n        return g.copy_to_shared_memory(graph_name)\n    new_g = g.shared_memory(graph_name, formats=graph_format)\n    # We should share the node/edge data to the client explicitly instead of putting them\n    # in the KVStore because some of the node/edge data may be duplicated.\n    new_g.ndata[\"inner_node\"] = _to_shared_mem(\n        g.ndata[\"inner_node\"], _get_ndata_path(graph_name, \"inner_node\")\n    )\n    new_g.ndata[NID] = _to_shared_mem(\n        g.ndata[NID], _get_ndata_path(graph_name, NID)\n    )\n\n    new_g.edata[\"inner_edge\"] = _to_shared_mem(\n        g.edata[\"inner_edge\"], _get_edata_path(graph_name, \"inner_edge\")\n    )\n    new_g.edata[EID] = _to_shared_mem(\n        g.edata[EID], _get_edata_path(graph_name, EID)\n    )\n    # for heterogeneous graph, we need to put ETYPE into KVStore\n    # for homogeneous graph, ETYPE does not exist\n    if ETYPE in g.edata:\n        new_g.edata[ETYPE] = _to_shared_mem(\n            g.edata[ETYPE],\n            _get_edata_path(graph_name, ETYPE),\n        )\n    return new_g\n\n\ndef _get_shared_mem_ndata(g, graph_name, name):\n    \"\"\"Get shared-memory node data from DistGraph server.\n\n    This is called by the DistGraph client to access the node data in the DistGraph server\n    with shared memory.\n    \"\"\"\n    shape = (g.num_nodes(),)\n    dtype = RESERVED_FIELD_DTYPE[name]\n    dtype = DTYPE_DICT[dtype]\n    data = empty_shared_mem(\n        _get_ndata_path(graph_name, name), False, shape, dtype\n    
)\n    dlpack = data.to_dlpack()\n    return F.zerocopy_from_dlpack(dlpack)\n\n\ndef _get_shared_mem_edata(g, graph_name, name):\n    \"\"\"Get shared-memory edge data from DistGraph server.\n\n    This is called by the DistGraph client to access the edge data in the DistGraph server\n    with shared memory.\n    \"\"\"\n    shape = (g.num_edges(),)\n    dtype = RESERVED_FIELD_DTYPE[name]\n    dtype = DTYPE_DICT[dtype]\n    data = empty_shared_mem(\n        _get_edata_path(graph_name, name), False, shape, dtype\n    )\n    dlpack = data.to_dlpack()\n    return F.zerocopy_from_dlpack(dlpack)\n\n\ndef _exist_shared_mem_array(graph_name, name):\n    return exist_shared_mem_array(_get_edata_path(graph_name, name))\n\n\ndef _get_graph_from_shared_mem(graph_name, use_graphbolt):\n    \"\"\"Get the graph from the DistGraph server.\n\n    The DistGraph server puts the graph structure of the local partition in the shared memory.\n    The client can access the graph structure and some metadata on nodes and edges directly\n    through shared memory to reduce the overhead of data access.\n    \"\"\"\n    if use_graphbolt:\n        return gb.load_from_shared_memory(graph_name)\n    g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(\n        graph_name\n    )\n    if g is None:\n        return None\n    g = DGLGraph(g, ntypes, etypes)\n\n    g.ndata[\"inner_node\"] = _get_shared_mem_ndata(g, graph_name, \"inner_node\")\n    g.ndata[NID] = _get_shared_mem_ndata(g, graph_name, NID)\n\n    g.edata[\"inner_edge\"] = _get_shared_mem_edata(g, graph_name, \"inner_edge\")\n    g.edata[EID] = _get_shared_mem_edata(g, graph_name, EID)\n\n    # heterogeneous graph has ETYPE\n    if _exist_shared_mem_array(graph_name, ETYPE):\n        g.edata[ETYPE] = _get_shared_mem_edata(g, graph_name, ETYPE)\n    return g\n\n\nNodeSpace = namedtuple(\"NodeSpace\", [\"data\"])\nEdgeSpace = namedtuple(\"EdgeSpace\", [\"data\"])\n\n\nclass HeteroNodeView(object):\n    \"\"\"A 
NodeView class to act as G.nodes for a DistGraph.\"\"\"\n\n    __slots__ = [\"_graph\"]\n\n    def __init__(self, graph):\n        self._graph = graph\n\n    def __getitem__(self, key):\n        assert isinstance(key, str)\n        return NodeSpace(data=NodeDataView(self._graph, key))\n\n\nclass HeteroEdgeView(object):\n    \"\"\"An EdgeView class to act as G.edges for a DistGraph.\"\"\"\n\n    __slots__ = [\"_graph\"]\n\n    def __init__(self, graph):\n        self._graph = graph\n\n    def __getitem__(self, key):\n        assert isinstance(key, str) or (\n            isinstance(key, tuple) and len(key) == 3\n        ), f\"Expect edge type in string or triplet of string, but got {key}.\"\n        return EdgeSpace(data=EdgeDataView(self._graph, key))\n\n\nclass NodeDataView(MutableMapping):\n    \"\"\"The data view class when dist_graph.ndata[...].data is called.\"\"\"\n\n    __slots__ = [\"_graph\", \"_data\"]\n\n    def __init__(self, g, ntype=None):\n        self._graph = g\n        if ntype is None or len(g.ntypes) == 1:\n            self._data = g._ndata_store\n        else:\n            if ntype not in g.ntypes:\n                raise DGLError(f\"Node type {ntype} does not exist.\")\n            self._data = g._ndata_store[ntype]\n\n    def _get_names(self):\n        return list(self._data.keys())\n\n    def __getitem__(self, key):\n        return self._data[key]\n\n    def __setitem__(self, key, val):\n        self._data[key] = val\n\n    def __delitem__(self, key):\n        del self._data[key]\n\n    def __len__(self):\n        # The number of node data may change. Let's count it every time we need them.\n        # It's not called frequently. 
It should be fine.\n        return len(self._data)\n\n    def __iter__(self):\n        return iter(self._data)\n\n    def __repr__(self):\n        reprs = {}\n        for name in self._data:\n            dtype = F.dtype(self._data[name])\n            shape = F.shape(self._data[name])\n            reprs[name] = \"DistTensor(shape={}, dtype={})\".format(\n                str(shape), str(dtype)\n            )\n        return repr(reprs)\n\n\nclass EdgeDataView(MutableMapping):\n    \"\"\"The data view class when G.edges[...].data is called.\"\"\"\n\n    __slots__ = [\"_graph\", \"_data\"]\n\n    def __init__(self, g, etype=None):\n        self._graph = g\n        if etype is None or len(g.canonical_etypes) == 1:\n            self._data = g._edata_store\n        else:\n            c_etype = g.to_canonical_etype(etype)\n            self._data = g._edata_store[c_etype]\n\n    def _get_names(self):\n        return list(self._data.keys())\n\n    def __getitem__(self, key):\n        return self._data[key]\n\n    def __setitem__(self, key, val):\n        self._data[key] = val\n\n    def __delitem__(self, key):\n        del self._data[key]\n\n    def __len__(self):\n        # The number of edge data may change. Let's count it every time we need them.\n        # It's not called frequently. 
It should be fine.\n        return len(self._data)\n\n    def __iter__(self):\n        return iter(self._data)\n\n    def __repr__(self):\n        reprs = {}\n        for name in self._data:\n            dtype = F.dtype(self._data[name])\n            shape = F.shape(self._data[name])\n            reprs[name] = \"DistTensor(shape={}, dtype={})\".format(\n                str(shape), str(dtype)\n            )\n        return repr(reprs)\n\n\ndef _format_partition(graph, graph_format):\n    \"\"\"Format the partition to the specified format.\"\"\"\n    if isinstance(graph, gb.FusedCSCSamplingGraph):\n        return graph\n    # formatting dtype\n    # TODO(Rui) Formatting forcely is not a perfect solution.\n    #   We'd better store all dtypes when mapping to shared memory\n    #   and map back with original dtypes.\n    for k, dtype in RESERVED_FIELD_DTYPE.items():\n        if k in graph.ndata:\n            graph.ndata[k] = F.astype(graph.ndata[k], dtype)\n        if k in graph.edata:\n            graph.edata[k] = F.astype(graph.edata[k], dtype)\n    # Create the graph formats specified the users.\n    print(\n        \"Start to create specified graph formats which may take \"\n        \"non-trivial time.\"\n    )\n    graph = graph.formats(graph_format)\n    graph.create_formats_()\n    print(f\"Finished creating specified graph formats: {graph_format}\")\n    return graph\n\n\nclass DistGraphServer(KVServer):\n    \"\"\"The DistGraph server.\n\n    This DistGraph server loads the graph data and sets up a service so that trainers and\n    samplers can read data of a graph partition (graph structure, node data and edge data)\n    from remote machines. A server is responsible for one graph partition.\n\n    Currently, each machine runs only one main server with a set of backup servers to handle\n    clients' requests. The main server and the backup servers all handle the requests for the same\n    graph partition. 
They all share the partition data (graph structure and node/edge data) with\n    shared memory.\n\n    By default, the partition data is shared with the DistGraph clients that run on\n    the same machine. However, a user can disable shared memory option. This is useful for the case\n    that a user wants to run the server and the client on different machines.\n\n    Parameters\n    ----------\n    server_id : int\n        The server ID (start from 0).\n    ip_config : str\n        Path of IP configuration file.\n    num_servers : int\n        Server count on each machine.\n    num_clients : int\n        Total number of client nodes.\n    part_config : string\n        The path of the config file generated by the partition tool.\n    disable_shared_mem : bool\n        Disable shared memory.\n    graph_format : str or list of str\n        The graph formats.\n    use_graphbolt : bool\n        Whether to load GraphBolt partition. Default: False.\n    \"\"\"\n\n    def __init__(\n        self,\n        server_id,\n        ip_config,\n        num_servers,\n        num_clients,\n        part_config,\n        disable_shared_mem=False,\n        graph_format=(\"csc\", \"coo\"),\n        use_graphbolt=False,\n    ):\n        super(DistGraphServer, self).__init__(\n            server_id=server_id,\n            ip_config=ip_config,\n            num_servers=num_servers,\n            num_clients=num_clients,\n        )\n        self.ip_config = ip_config\n        self.num_servers = num_servers\n        self.use_graphbolt = use_graphbolt\n        # Load graph partition data.\n        if self.is_backup_server():\n            # The backup server doesn't load the graph partition. 
It'll initialized afterwards.\n            self.gpb, graph_name, ntypes, etypes = load_partition_book(\n                part_config, self.part_id\n            )\n            self.client_g = None\n        else:\n            # Loading of node/edge_feats are deferred to lower the peak memory consumption.\n            (\n                self.client_g,\n                _,\n                _,\n                self.gpb,\n                graph_name,\n                ntypes,\n                etypes,\n            ) = load_partition(\n                part_config,\n                self.part_id,\n                load_feats=False,\n                use_graphbolt=use_graphbolt,\n            )\n            print(\"load \" + graph_name)\n            self.client_g = _format_partition(self.client_g, graph_format)\n            if not disable_shared_mem:\n                self.client_g = _copy_graph_to_shared_mem(\n                    self.client_g, graph_name, graph_format, use_graphbolt\n                )\n\n        if not disable_shared_mem:\n            self.gpb.shared_memory(graph_name)\n        assert self.gpb.partid == self.part_id\n        for ntype in ntypes:\n            node_name = HeteroDataName(True, ntype, \"\")\n            self.add_part_policy(\n                PartitionPolicy(node_name.policy_str, self.gpb)\n            )\n        for etype in etypes:\n            edge_name = HeteroDataName(False, etype, \"\")\n            self.add_part_policy(\n                PartitionPolicy(edge_name.policy_str, self.gpb)\n            )\n\n        if not self.is_backup_server():\n            node_feats, _ = load_partition_feats(\n                part_config, self.part_id, load_nodes=True, load_edges=False\n            )\n            for name in node_feats:\n                # The feature name has the following format: node_type + \"/\" + feature_name to avoid\n                # feature name collision for different node types.\n                ntype, feat_name = name.split(\"/\")\n      
          data_name = HeteroDataName(True, ntype, feat_name)\n                self.init_data(\n                    name=str(data_name),\n                    policy_str=data_name.policy_str,\n                    data_tensor=node_feats[name],\n                )\n                self.orig_data.add(str(data_name))\n            # Let's free once node features are copied to shared memory\n            del node_feats\n            gc.collect()\n            _, edge_feats = load_partition_feats(\n                part_config, self.part_id, load_nodes=False, load_edges=True\n            )\n            for name in edge_feats:\n                # The feature name has the following format: edge_type + \"/\" + feature_name to avoid\n                # feature name collision for different edge types.\n                etype, feat_name = name.split(\"/\")\n                etype = _etype_str_to_tuple(etype)\n                data_name = HeteroDataName(False, etype, feat_name)\n                self.init_data(\n                    name=str(data_name),\n                    policy_str=data_name.policy_str,\n                    data_tensor=edge_feats[name],\n                )\n                self.orig_data.add(str(data_name))\n            # Let's free once edge features are copied to shared memory\n            del edge_feats\n            gc.collect()\n\n    def start(self):\n        \"\"\"Start graph store server.\"\"\"\n        # start server\n        server_state = ServerState(\n            kv_store=self,\n            local_g=self.client_g,\n            partition_book=self.gpb,\n            use_graphbolt=self.use_graphbolt,\n        )\n        print(\n            \"start graph service on server {} for part {}\".format(\n                self.server_id, self.part_id\n            )\n        )\n        start_server(\n            server_id=self.server_id,\n            ip_config=self.ip_config,\n            num_servers=self.num_servers,\n            num_clients=self.num_clients,\n            
server_state=server_state,\n        )\n\n\nclass DistGraph:\n    \"\"\"The class for accessing a distributed graph.\n\n    This class provides a subset of DGLGraph APIs for accessing partitioned graph data in\n    distributed GNN training and inference. Thus, its main use case is to work with\n    distributed sampling APIs to generate mini-batches and perform forward and\n    backward computation on the mini-batches.\n\n    The class can run in two modes: the standalone mode and the distributed mode.\n\n    * When a user runs the training script normally, ``DistGraph`` will be in the standalone mode.\n      In this mode, the input data must be constructed by\n      :py:meth:`~dgl.distributed.partition.partition_graph` with only one partition. This mode is\n      used for testing and debugging purpose. In this mode, users have to provide ``part_config``\n      so that ``DistGraph`` can load the input graph.\n    * When a user runs the training script with the distributed launch script, ``DistGraph`` will\n      be set into the distributed mode. This is used for actual distributed training. All data of\n      partitions are loaded by the ``DistGraph`` servers, which are created by DGL's launch script.\n      ``DistGraph`` connects with the servers to access the partitioned graph data.\n\n    Currently, the ``DistGraph`` servers and clients run on the same set of machines\n    in the distributed mode. ``DistGraph`` uses shared-memory to access the partition data\n    in the local machine. This gives the best performance for distributed training\n\n    Users may want to run ``DistGraph`` servers and clients on separate sets of machines.\n    In this case, a user may want to disable shared memory by passing\n    ``disable_shared_mem=False`` when creating ``DistGraphServer``. When shared memory is disabled,\n    a user has to pass a partition book.\n\n    Parameters\n    ----------\n    graph_name : str\n        The name of the graph. 
This name has to be the same as the one used for\n        partitioning a graph in :py:meth:`dgl.distributed.partition.partition_graph`.\n    gpb : GraphPartitionBook, optional\n        The partition book object. Normally, users do not need to provide the partition book.\n        This argument is necessary only when users want to run server process and trainer\n        processes on different machines.\n    part_config : str, optional\n        The path of partition configuration file generated by\n        :py:meth:`dgl.distributed.partition.partition_graph`. It's used in the standalone mode.\n\n    Examples\n    --------\n    The example shows the creation of ``DistGraph`` in the standalone mode.\n\n    >>> dgl.distributed.partition_graph(g, 'graph_name', 1, num_hops=1, part_method='metis',\n    ...                                 out_path='output/')\n    >>> g = dgl.distributed.DistGraph('graph_name', part_config='output/graph_name.json')\n\n    The example shows the creation of ``DistGraph`` in the distributed mode.\n\n    >>> g = dgl.distributed.DistGraph('graph-name')\n\n    The code below shows the mini-batch training using ``DistGraph``.\n\n    >>> def sample(seeds):\n    ...     seeds = th.LongTensor(np.asarray(seeds))\n    ...     frontier = dgl.distributed.sample_neighbors(g, seeds, 10)\n    ...     return dgl.to_block(frontier, seeds)\n    >>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000,\n    ...                                             collate_fn=sample, shuffle=True)\n    >>> for block in dataloader:\n    ...     feat = g.ndata['features'][block.srcdata[dgl.NID]]\n    ...     labels = g.ndata['labels'][block.dstdata[dgl.NID]]\n    ...     pred = model(block, feat)\n\n    Note\n    ----\n    DGL's distributed training by default runs server processes and trainer processes on the same\n    set of machines. If users need to run them on different sets of machines, it requires\n    manually setting up servers and trainers. 
The setup is not fully tested yet.\n    \"\"\"\n\n    def __init__(self, graph_name, gpb=None, part_config=None):\n        self.graph_name = graph_name\n        self._added_edge_attributes = []  # For prob/mask sampling on GB.\n        if os.environ.get(\"DGL_DIST_MODE\", \"standalone\") == \"standalone\":\n            # \"GraphBolt is not supported in standalone mode.\"\n            self._use_graphbolt = False\n            assert (\n                part_config is not None\n            ), \"When running in the standalone model, the partition config file is required\"\n            self._client = get_kvstore()\n            assert (\n                self._client is not None\n            ), \"Distributed module is not initialized. Please call dgl.distributed.initialize.\"\n            # Load graph partition data.\n            g, node_feats, edge_feats, self._gpb, _, _, _ = load_partition(\n                part_config, 0\n            )\n            assert (\n                self._gpb.num_partitions() == 1\n            ), \"The standalone mode can only work with the graph data with one partition\"\n            if self._gpb is None:\n                self._gpb = gpb\n            self._g = g\n            for name in node_feats:\n                # The feature name has the following format: node_type + \"/\" + feature_name.\n                ntype, feat_name = name.split(\"/\")\n                self._client.add_data(\n                    str(HeteroDataName(True, ntype, feat_name)),\n                    node_feats[name],\n                    NodePartitionPolicy(self._gpb, ntype=ntype),\n                )\n            for name in edge_feats:\n                # The feature name has the following format: edge_type + \"/\" + feature_name.\n                etype, feat_name = name.split(\"/\")\n                etype = _etype_str_to_tuple(etype)\n                self._client.add_data(\n                    str(HeteroDataName(False, etype, feat_name)),\n                    
edge_feats[name],\n                    EdgePartitionPolicy(self._gpb, etype=etype),\n                )\n            self._client.map_shared_data(self._gpb)\n            rpc.set_num_client(1)\n        else:\n            # Query the main server about whether GraphBolt is used.\n            rpc.send_request(0, QueryIfUseGraphBoltRequest())\n            self._use_graphbolt = rpc.recv_response()._use_graphbolt\n\n            self._init(gpb)\n            # Tell the backup servers to load the graph structure from shared memory.\n            for server_id in range(self._client.num_servers):\n                rpc.send_request(server_id, InitGraphRequest(graph_name))\n            for server_id in range(self._client.num_servers):\n                rpc.recv_response()\n            self._client.barrier()\n\n        self._init_ndata_store()\n        self._init_edata_store()\n        self._init_metadata()\n\n    def _init(self, gpb):\n        self._client = get_kvstore()\n        assert (\n            self._client is not None\n        ), \"Distributed module is not initialized. 
Please call dgl.distributed.initialize.\"\n        self._g = _get_graph_from_shared_mem(\n            self.graph_name, self._use_graphbolt\n        )\n        self._gpb = get_shared_mem_partition_book(self.graph_name)\n        if self._gpb is None:\n            self._gpb = gpb\n        self._client.map_shared_data(self._gpb)\n\n    def _init_ndata_store(self):\n        \"\"\"Initialize node data store.\"\"\"\n        self._ndata_store = {}\n        for ntype in self.ntypes:\n            names = self._get_ndata_names(ntype)\n            data = {}\n            for name in names:\n                assert name.is_node()\n                policy = PartitionPolicy(\n                    name.policy_str, self.get_partition_book()\n                )\n                dtype, shape, _ = self._client.get_data_meta(str(name))\n                # We create a wrapper on the existing tensor in the kvstore.\n                data[name.get_name()] = DistTensor(\n                    shape,\n                    dtype,\n                    name.get_name(),\n                    part_policy=policy,\n                    attach=False,\n                )\n            if len(self.ntypes) == 1:\n                self._ndata_store = data\n            else:\n                self._ndata_store[ntype] = data\n\n    def _init_edata_store(self):\n        \"\"\"Initialize edge data store.\"\"\"\n        self._edata_store = {}\n        for etype in self.canonical_etypes:\n            names = self._get_edata_names(etype)\n            data = {}\n            for name in names:\n                assert name.is_edge()\n                policy = PartitionPolicy(\n                    name.policy_str, self.get_partition_book()\n                )\n                dtype, shape, _ = self._client.get_data_meta(str(name))\n                # We create a wrapper on the existing tensor in the kvstore.\n                data[name.get_name()] = DistTensor(\n                    shape,\n                    dtype,\n                
    name.get_name(),\n                    part_policy=policy,\n                    attach=False,\n                )\n            if len(self.canonical_etypes) == 1:\n                self._edata_store = data\n            else:\n                self._edata_store[etype] = data\n\n    def _init_metadata(self):\n        self._num_nodes = 0\n        self._num_edges = 0\n        for part_md in self._gpb.metadata():\n            self._num_nodes += int(part_md[\"num_nodes\"])\n            self._num_edges += int(part_md[\"num_edges\"])\n\n        # When we store node/edge types in a list, they are stored in the order of type IDs.\n        self._ntype_map = {ntype: i for i, ntype in enumerate(self.ntypes)}\n        self._etype_map = {\n            etype: i for i, etype in enumerate(self.canonical_etypes)\n        }\n\n    def __getstate__(self):\n        return (\n            self.graph_name,\n            self._gpb,\n            self._use_graphbolt,\n            self._added_edge_attributes,\n        )\n\n    def __setstate__(self, state):\n        (\n            self.graph_name,\n            gpb,\n            self._use_graphbolt,\n            self._added_edge_attributes,\n        ) = state\n        self._init(gpb)\n\n        self._init_ndata_store()\n        self._init_edata_store()\n        self._init_metadata()\n\n        # For prob/mask sampling on GB only.\n        if self._use_graphbolt and len(self._added_edge_attributes) > 0:\n            # Add edge attribute from main server's shared memory.\n            for name in self._added_edge_attributes:\n                data = _copy_data_from_shared_mem(\n                    \"__edge__\" + name, (self.local_partition.total_num_edges,)\n                )\n                self.local_partition.add_edge_attribute(name, data)\n\n    @property\n    def local_partition(self):\n        \"\"\"Return the local partition on the client\n\n        DistGraph provides a global view of the distributed graph. 
Internally,\n        it may contains a partition of the graph if it is co-located with\n        the server. When servers and clients run on separate sets of machines,\n        this returns None.\n\n        Returns\n        -------\n        DGLGraph\n            The local partition\n        \"\"\"\n        return self._g\n\n    @property\n    def nodes(self):\n        \"\"\"Return a node view\"\"\"\n        return HeteroNodeView(self)\n\n    @property\n    def edges(self):\n        \"\"\"Return an edge view\"\"\"\n        return HeteroEdgeView(self)\n\n    @property\n    def ndata(self):\n        \"\"\"Return the data view of all the nodes.\n\n        Returns\n        -------\n        NodeDataView\n            The data view in the distributed graph storage.\n        \"\"\"\n        assert (\n            len(self.ntypes) == 1\n        ), \"ndata only works for a graph with one node type.\"\n        return NodeDataView(self)\n\n    @property\n    def edata(self):\n        \"\"\"Return the data view of all the edges.\n\n        Returns\n        -------\n        EdgeDataView\n            The data view in the distributed graph storage.\n        \"\"\"\n        assert (\n            len(self.etypes) == 1\n        ), \"edata only works for a graph with one edge type.\"\n        return EdgeDataView(self)\n\n    @property\n    def idtype(self):\n        \"\"\"The dtype of graph index\n\n        Returns\n        -------\n        backend dtype object\n            th.int32/th.int64 or tf.int32/tf.int64 etc.\n\n        See Also\n        --------\n        long\n        int\n        \"\"\"\n        # TODO(da?): describe when self._g is None and idtype shouldn't be called.\n        # For GraphBolt partition, we use the global node ID's dtype.\n        return (\n            self.get_partition_book().global_nid_dtype\n            if self._use_graphbolt\n            else F.int64\n        )\n\n    @property\n    def device(self):\n        \"\"\"Get the device context of this graph.\n\n 
       Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1])\n        ... })\n        >>> print(g.device)\n        device(type='cpu')\n        >>> g = g.to('cuda:0')\n        >>> print(g.device)\n        device(type='cuda', index=0)\n\n        Returns\n        -------\n        Device context object\n        \"\"\"\n        # TODO(da?): describe when self._g is None and device shouldn't be called.\n        return F.cpu()\n\n    def is_pinned(self):\n        \"\"\"Check if the graph structure is pinned to the page-locked memory.\n\n        Returns\n        -------\n        bool\n            True if the graph structure is pinned.\n        \"\"\"\n        # (Xin Yao): Currently we don't support pinning a DistGraph.\n        return False\n\n    @property\n    def ntypes(self):\n        \"\"\"Return the list of node types of this graph.\n\n        Returns\n        -------\n        list of str\n\n        Examples\n        --------\n\n        >>> g = DistGraph(\"test\")\n        >>> g.ntypes\n        ['_U']\n        \"\"\"\n        return self._gpb.ntypes\n\n    @property\n    def etypes(self):\n        \"\"\"Return the list of edge types of this graph.\n\n        Returns\n        -------\n        list of str\n\n        Examples\n        --------\n\n        >>> g = DistGraph(\"test\")\n        >>> g.etypes\n        ['_E']\n        \"\"\"\n        return self._gpb.etypes\n\n    @property\n    def canonical_etypes(self):\n        \"\"\"Return all the canonical edge types in the graph.\n\n        A canonical edge type is a string triplet ``(str, str, str)``\n        for source node type, edge type and destination node type.\n\n        Returns\n        -------\n        list[(str, str, str)]\n            All the canonical edge type triplets in a list.\n\n        Notes\n        -----\n        DGL internally assigns an integer ID for each edge 
type. The returned\n        edge type names are sorted according to their IDs.\n\n        See Also\n        --------\n        etypes\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> g = DistGraph(\"test\")\n        >>> g.canonical_etypes\n        [('user', 'follows', 'user'),\n         ('user', 'follows', 'game'),\n         ('user', 'plays', 'game')]\n        \"\"\"\n        return self._gpb.canonical_etypes\n\n    def to_canonical_etype(self, etype):\n        \"\"\"Convert an edge type to the corresponding canonical edge type in the graph.\n\n        A canonical edge type is a string triplet ``(str, str, str)``\n        for source node type, edge type and destination node type.\n\n        The function expects the given edge type name can uniquely identify a canonical edge\n        type. DGL will raise error if this is not the case.\n\n        Parameters\n        ----------\n        etype : str or (str, str, str)\n            If :attr:`etype` is an edge type (str), it returns the corresponding canonical edge\n            type in the graph. 
If :attr:`etype` is already a canonical edge type,\n            it directly returns the input unchanged.\n\n        Returns\n        -------\n        (str, str, str)\n            The canonical edge type corresponding to the edge type.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> g = DistGraph(\"test\")\n        >>> g.canonical_etypes\n        [('user', 'follows', 'user'),\n         ('user', 'follows', 'game'),\n         ('user', 'plays', 'game')]\n\n        >>> g.to_canonical_etype('plays')\n        ('user', 'plays', 'game')\n        >>> g.to_canonical_etype(('user', 'plays', 'game'))\n        ('user', 'plays', 'game')\n\n        See Also\n        --------\n        canonical_etypes\n        \"\"\"\n        return self._gpb.to_canonical_etype(etype)\n\n    def get_ntype_id(self, ntype):\n        \"\"\"Return the ID of the given node type.\n\n        ntype can also be None. If so, there should be only one node type in the\n        graph.\n\n        Parameters\n        ----------\n        ntype : str\n            Node type\n\n        Returns\n        -------\n        int\n        \"\"\"\n        if ntype is None:\n            if len(self._ntype_map) != 1:\n                raise DGLError(\n                    \"Node type name must be specified if there are more than one \"\n                    \"node types.\"\n                )\n            return 0\n        return self._ntype_map[ntype]\n\n    def get_etype_id(self, etype):\n        \"\"\"Return the id of the given edge type.\n\n        etype can also be None. 
If so, there should be only one edge type in the\n        graph.\n\n        Parameters\n        ----------\n        etype : str or tuple of str\n            Edge type\n\n        Returns\n        -------\n        int\n        \"\"\"\n        if etype is None:\n            if len(self._etype_map) != 1:\n                raise DGLError(\n                    \"Edge type name must be specified if there are more than one \"\n                    \"edge types.\"\n                )\n            return 0\n        etype = self.to_canonical_etype(etype)\n        return self._etype_map[etype]\n\n    def number_of_nodes(self, ntype=None):\n        \"\"\"Alias of :func:`num_nodes`\"\"\"\n        return self.num_nodes(ntype)\n\n    def number_of_edges(self, etype=None):\n        \"\"\"Alias of :func:`num_edges`\"\"\"\n        return self.num_edges(etype)\n\n    def num_nodes(self, ntype=None):\n        \"\"\"Return the total number of nodes in the distributed graph.\n\n        Parameters\n        ----------\n        ntype : str, optional\n            The node type name. If given, it returns the number of nodes of the\n            type. 
If not given (default), it returns the total number of nodes of all types.\n\n        Returns\n        -------\n        int\n            The number of nodes\n\n        Examples\n        --------\n        >>> g = dgl.distributed.DistGraph('ogb-product')\n        >>> print(g.num_nodes())\n        2449029\n        \"\"\"\n        if ntype is None:\n            if len(self.ntypes) == 1:\n                return self._gpb._num_nodes(self.ntypes[0])\n            else:\n                return sum(\n                    [self._gpb._num_nodes(ntype) for ntype in self.ntypes]\n                )\n        return self._gpb._num_nodes(ntype)\n\n    def num_edges(self, etype=None):\n        \"\"\"Return the total number of edges in the distributed graph.\n\n        Parameters\n        ----------\n        etype : str or (str, str, str), optional\n            The type name of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            If not provided, return the total number of edges regardless of the types\n            in the graph.\n\n        Returns\n        -------\n        int\n            The number of edges\n\n        Examples\n        --------\n        >>> g = dgl.distributed.DistGraph('ogb-product')\n        >>> print(g.num_edges())\n        123718280\n        \"\"\"\n        if etype is None:\n            return sum(\n                [\n                    self._gpb._num_edges(c_etype)\n                    for c_etype in self.canonical_etypes\n                ]\n            )\n        return self._gpb._num_edges(etype)\n\n    def out_degrees(self, u=ALL):\n        \"\"\"Return the out-degree(s) of the given nodes.\n\n        It computes the out-degree(s).\n        It does not support heterogeneous graphs yet.\n\n        Parameters\n        ----------\n     
   u : node IDs\n            The node IDs. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n\n            If not given, return the in-degrees of all the nodes.\n\n        Returns\n        -------\n        int or Tensor\n            The out-degree(s) of the node(s) in a Tensor. The i-th element is the out-degree\n            of the i-th input node. If :attr:`v` is an ``int``, return an ``int`` too.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Query for all nodes.\n\n        >>> g.out_degrees()\n        tensor([2, 2, 0, 0])\n\n        Query for nodes 1 and 2.\n\n        >>> g.out_degrees(torch.tensor([1, 2]))\n        tensor([2, 0])\n\n        See Also\n        --------\n        in_degrees\n        \"\"\"\n        if is_all(u):\n            u = F.arange(0, self.num_nodes())\n        return dist_out_degrees(self, u)\n\n    def in_degrees(self, v=ALL):\n        \"\"\"Return the in-degree(s) of the given nodes.\n\n        It computes the in-degree(s).\n        It does not support heterogeneous graphs yet.\n\n        Parameters\n        ----------\n        v : node IDs\n            The node IDs. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n\n            If not given, return the in-degrees of all the nodes.\n\n        Returns\n        -------\n        int or Tensor\n            The in-degree(s) of the node(s) in a Tensor. The i-th element is the in-degree\n            of the i-th input node. 
If :attr:`v` is an ``int``, return an ``int`` too.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Query for all nodes.\n\n        >>> g.in_degrees()\n        tensor([0, 2, 1, 1])\n\n        Query for nodes 1 and 2.\n\n        >>> g.in_degrees(torch.tensor([1, 2]))\n        tensor([2, 1])\n\n        See Also\n        --------\n        out_degrees\n        \"\"\"\n        if is_all(v):\n            v = F.arange(0, self.num_nodes())\n        return dist_in_degrees(self, v)\n\n    def node_attr_schemes(self):\n        \"\"\"Return the node feature schemes.\n\n        Each feature scheme is a named tuple that stores the shape and data type\n        of the node feature.\n\n        Returns\n        -------\n        dict of str to schemes\n            The schemes of node feature columns.\n\n        Examples\n        --------\n        The following uses PyTorch backend.\n\n        >>> g.node_attr_schemes()\n        {'h': Scheme(shape=(4,), dtype=torch.float32)}\n\n        See Also\n        --------\n        edge_attr_schemes\n        \"\"\"\n        schemes = {}\n        for key in self.ndata:\n            schemes[key] = infer_scheme(self.ndata[key])\n        return schemes\n\n    def edge_attr_schemes(self):\n        \"\"\"Return the edge feature schemes.\n\n        Each feature scheme is a named tuple that stores the shape and data type\n        of the edge feature.\n\n        Returns\n        -------\n        dict of str to schemes\n            The schemes of edge feature columns.\n\n        Examples\n        --------\n        The following uses PyTorch backend.\n\n        >>> g.edge_attr_schemes()\n        {'h': Scheme(shape=(4,), dtype=torch.float32)}\n\n        See Also\n        --------\n        node_attr_schemes\n        \"\"\"\n        schemes = {}\n        for key in self.edata:\n            schemes[key] = infer_scheme(self.edata[key])\n        return 
schemes\n\n    def rank(self):\n        \"\"\"The rank of the current DistGraph.\n\n        This returns a unique number to identify the DistGraph object among all of\n        the client processes.\n\n        Returns\n        -------\n        int\n            The rank of the current DistGraph.\n        \"\"\"\n        return role.get_global_rank()\n\n    def find_edges(self, edges, etype=None):\n        \"\"\"Given an edge ID array, return the source\n        and destination node ID array ``s`` and ``d``.  ``s[i]`` and ``d[i]``\n        are source and destination node ID for edge ``eid[i]``.\n\n        Parameters\n        ----------\n        edges : Int Tensor\n            Each element is an ID. The tensor must have the same device type\n              and ID data type as the graph's.\n\n        etype : str or (str, str, str), optional\n            The type names of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Returns\n        -------\n        tensor\n            The source node ID array.\n        tensor\n            The destination node ID array.\n        \"\"\"\n        if etype is None:\n            assert (\n                len(self.etypes) == 1\n            ), \"find_edges requires etype for heterogeneous graphs.\"\n\n        gpb = self.get_partition_book()\n        if len(gpb.etypes) > 1:\n            edges = gpb.map_to_homo_eid(edges, etype)\n        src, dst = dist_find_edges(self, edges)\n        if len(gpb.ntypes) > 1:\n            _, src = gpb.map_to_per_ntype(src)\n            _, dst = gpb.map_to_per_ntype(dst)\n        return src, dst\n\n    def edge_subgraph(self, edges, relabel_nodes=True, store_ids=True):\n        \"\"\"Return a subgraph induced on the 
given edges.\n\n        An edge-induced subgraph is equivalent to creating a new graph using the given\n        edges. In addition to extracting the subgraph, DGL also copies the features\n        of the extracted nodes and edges to the resulting graph. The copy is *lazy*\n        and incurs data movement only when needed.\n\n        If the graph is heterogeneous, DGL extracts a subgraph per relation and composes\n        them as the resulting graph. Thus, the resulting graph has the same set of relations\n        as the input one.\n\n        Parameters\n        ----------\n        edges : Int Tensor or dict[(str, str, str), Int Tensor]\n            The edges to form the subgraph. Each element is an edge ID. The tensor must have\n            the same device type and ID data type as the graph's.\n\n            If the graph is homogeneous, one can directly pass an Int Tensor.\n            Otherwise, the argument must be a dictionary with keys being edge types\n            and values being the edge IDs in the above formats.\n        relabel_nodes : bool, optional\n            If True, it will remove the isolated nodes and relabel the incident nodes in the\n            extracted subgraph.\n        store_ids : bool, optional\n            If True, it will store the raw IDs of the extracted edges in the ``edata`` of the\n            resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will\n            also store the raw IDs of the incident nodes in the ``ndata`` of the resulting\n            graph under name ``dgl.NID``.\n\n        Returns\n        -------\n        G : DGLGraph\n            The subgraph.\n        \"\"\"\n        if isinstance(edges, dict):\n            # TODO(zhengda) we need to directly generate subgraph of all relations with\n            # one invocation.\n            subg = {}\n            for etype, edge in edges.items():\n                etype = self.to_canonical_etype(etype)\n                subg[etype] = 
self.find_edges(edge, etype)\n            num_nodes = {ntype: self.num_nodes(ntype) for ntype in self.ntypes}\n            subg = dgl_heterograph(subg, num_nodes_dict=num_nodes)\n            for etype in edges:\n                subg.edges[etype].data[EID] = edges[etype]\n        else:\n            assert len(self.etypes) == 1\n            subg = self.find_edges(edges)\n            subg = dgl_graph(subg, num_nodes=self.num_nodes())\n            subg.edata[EID] = edges\n\n        if relabel_nodes:\n            subg = compact_graphs(subg)\n        assert store_ids, \"edge_subgraph always stores original node/edge IDs.\"\n        return subg\n\n    def get_partition_book(self):\n        \"\"\"Get the partition information.\n\n        Returns\n        -------\n        GraphPartitionBook\n            Object that stores all graph partition information.\n        \"\"\"\n        return self._gpb\n\n    def get_node_partition_policy(self, ntype):\n        \"\"\"Get the partition policy for a node type.\n\n        When creating a new distributed tensor, we need to provide a partition policy\n        that indicates how to distribute data of the distributed tensor in a cluster\n        of machines. When we load a distributed graph in the cluster, we have pre-defined\n        partition policies for each node type and each edge type. 
By providing\n        the node type, we can refer to the pre-defined partition policy for the node type.\n\n        Parameters\n        ----------\n        ntype : str\n            The node type\n\n        Returns\n        -------\n        PartitionPolicy\n            The partition policy for the node type.\n        \"\"\"\n        return NodePartitionPolicy(self.get_partition_book(), ntype)\n\n    def get_edge_partition_policy(self, etype):\n        \"\"\"Get the partition policy for an edge type.\n\n        When creating a new distributed tensor, we need to provide a partition policy\n        that indicates how to distribute data of the distributed tensor in a cluster\n        of machines. When we load a distributed graph in the cluster, we have pre-defined\n        partition policies for each node type and each edge type. By providing\n        the edge type, we can refer to the pre-defined partition policy for the edge type.\n\n        Parameters\n        ----------\n        etype : str or (str, str, str)\n            The edge type\n\n        Returns\n        -------\n        PartitionPolicy\n            The partition policy for the edge type.\n        \"\"\"\n        etype = self.to_canonical_etype(etype)\n        return EdgePartitionPolicy(self.get_partition_book(), etype)\n\n    def barrier(self):\n        \"\"\"Barrier for all client nodes.\n\n        This API blocks the current process until all the clients invoke this API.\n        Please use this API with caution.\n        \"\"\"\n        self._client.barrier()\n\n    def sample_neighbors(\n        self,\n        seed_nodes,\n        fanout,\n        edge_dir=\"in\",\n        prob=None,\n        exclude_edges=None,\n        replace=False,\n        etype_sorted=True,\n        output_device=None,\n    ):\n        # pylint: disable=unused-argument\n        \"\"\"Sample neighbors from a distributed graph.\"\"\"\n        if exclude_edges is not None:\n            # Convert exclude edge IDs to 
homogeneous edge IDs.\n            gpb = self.get_partition_book()\n            if isinstance(exclude_edges, Mapping):\n                exclude_eids = []\n                for c_etype, eids in exclude_edges.items():\n                    exclude_eids.append(gpb.map_to_homo_eid(eids, c_etype))\n                exclude_edges = torch.cat(exclude_eids)\n        if len(self.etypes) > 1:\n            frontier = graph_services.sample_etype_neighbors(\n                self,\n                seed_nodes,\n                fanout,\n                replace=replace,\n                etype_sorted=etype_sorted,\n                prob=prob,\n                exclude_edges=exclude_edges,\n                use_graphbolt=self._use_graphbolt,\n            )\n        else:\n            frontier = graph_services.sample_neighbors(\n                self,\n                seed_nodes,\n                fanout,\n                replace=replace,\n                prob=prob,\n                exclude_edges=exclude_edges,\n                use_graphbolt=self._use_graphbolt,\n            )\n        return frontier\n\n    def _get_ndata_names(self, ntype=None):\n        \"\"\"Get the names of all node data.\"\"\"\n        names = self._client.gdata_name_list()\n        ndata_names = []\n        for name in names:\n            name = parse_hetero_data_name(name)\n            right_type = (\n                (name.get_type() == ntype) if ntype is not None else True\n            )\n            if name.is_node() and right_type:\n                ndata_names.append(name)\n        return ndata_names\n\n    def _get_edata_names(self, etype=None):\n        \"\"\"Get the names of all edge data.\"\"\"\n        if etype is not None:\n            etype = self.to_canonical_etype(etype)\n        names = self._client.gdata_name_list()\n        edata_names = []\n        for name in names:\n            name = parse_hetero_data_name(name)\n            right_type = (\n                (name.get_type() == etype) if etype is not 
None else True\n            )\n            if name.is_edge() and right_type:\n                edata_names.append(name)\n        return edata_names\n\n    def add_edge_attribute(self, name, padding):\n        \"\"\"Add an edge attribute into GraphBolt partition from edge data.\n\n        Parameters\n        ----------\n        name : str\n            The name of the edge attribute.\n        padding : int, optional\n            The padding value for the new edge attribute.\n        \"\"\"\n        # Sanity checks.\n        if not self._use_graphbolt:\n            raise DGLError(\"GraphBolt is not used.\")\n\n        # Send add request to main server on the same machine.\n        kv_names = [\n            (\n                self.edges[etype].data[name].kvstore_key\n                if name in self.edges[etype].data\n                else None\n            )\n            for etype in self.canonical_etypes\n        ]\n        rpc.send_request(\n            self._client._main_server_id,\n            AddEdgeAttributeFromKVRequest(name, kv_names, padding),\n        )\n        # Wait for the response.\n        assert rpc.recv_response()._name == name\n        # Send add request to local backup servers.\n        for i in range(self._client.group_count - 1):\n            server_id = (\n                self._client.machine_id * self._client.group_count + i + 1\n            )\n            rpc.send_request(\n                server_id, AddEdgeAttributeFromSharedMemRequest(name)\n            )\n        # Receive response from local backup servers.\n        for _ in range(self._client.group_count - 1):\n            response = rpc.recv_response()\n            assert response._name == name\n        # Add edge attribute from main server's shared memory.\n        data = _copy_data_from_shared_mem(\n            \"__edge__\" + name, (self.local_partition.total_num_edges,)\n        )\n        self.local_partition.add_edge_attribute(name, data)\n        # Sync local clients.\n        
self._client.barrier()\n\n        # Save the edge attribute into state. This is required by separate samplers.\n        self._added_edge_attributes.append(name)\n\n\ndef _get_overlap(mask_arr, ids):\n    \"\"\"Select the IDs given a boolean mask array.\n\n    The boolean mask array indicates all of the IDs to be selected. We want to\n    find the overlap between the IDs selected by the boolean mask array and\n    the ID array.\n\n    Parameters\n    ----------\n    mask_arr : 1D tensor\n        A boolean mask array.\n    ids : 1D tensor\n        A vector with IDs.\n\n    Returns\n    -------\n    1D tensor\n        The selected IDs.\n    \"\"\"\n    if isinstance(mask_arr, DistTensor):\n        masks = mask_arr[ids]\n        return F.boolean_mask(ids, masks)\n    else:\n        masks = F.gather_row(F.tensor(mask_arr), ids)\n        return F.boolean_mask(ids, masks)\n\n\ndef _split_local(partition_book, rank, elements, local_eles):\n    \"\"\"Split the input element list with respect to data locality.\"\"\"\n    num_clients = role.get_num_trainers()\n    num_client_per_part = num_clients // partition_book.num_partitions()\n    if rank is None:\n        rank = role.get_trainer_rank()\n    assert (\n        rank < num_clients\n    ), \"The input rank ({}) is incorrect. 
#Trainers: {}\".format(\n        rank, num_clients\n    )\n    # all ranks of the clients in the same machine are in a contiguous range.\n    client_id_in_part = rank % num_client_per_part\n    local_eles = _get_overlap(elements, local_eles)\n\n    # get a subset for the local client.\n    size = len(local_eles) // num_client_per_part\n    # if this isn't the last client in the partition.\n    if client_id_in_part + 1 < num_client_per_part:\n        return local_eles[\n            (size * client_id_in_part) : (size * (client_id_in_part + 1))\n        ]\n    else:\n        return local_eles[(size * client_id_in_part) :]\n\n\ndef _even_offset(n, k):\n    \"\"\"Split an array of length n into k segments whose lengths differ by\n    at most 1. Return the offset of each segment.\n    \"\"\"\n    eles_per_part = n // k\n    offset = np.array([0] + [eles_per_part] * k, dtype=int)\n    offset[1 : n - eles_per_part * k + 1] += 1\n    return np.cumsum(offset)\n\n\ndef _split_even_to_part(partition_book, elements):\n    \"\"\"Split the input element list evenly.\"\"\"\n    # here we divide the element list as evenly as possible. If we use range partitioning,\n    # the split results also respect the data locality. 
Range partitioning is the default\n    # strategy.\n    # TODO(zhengda) we need another way to divide the list for other partitioning strategy.\n    if isinstance(elements, DistTensor):\n        nonzero_count = elements.count_nonzero()\n    else:\n        elements = F.tensor(elements)\n        nonzero_count = F.count_nonzero(elements)\n    # compute the offset of each split and ensure that the difference of each partition size\n    # is 1.\n    offsets = _even_offset(nonzero_count, partition_book.num_partitions())\n    assert offsets[-1] == nonzero_count\n\n    # Get the elements that belong to the partition.\n    partid = partition_book.partid\n    left, right = offsets[partid], offsets[partid + 1]\n\n    x = y = 0\n    num_elements = len(elements)\n    block_size = num_elements // partition_book.num_partitions()\n    part_eles = F.tensor([], dtype=elements.dtype)\n    # compute the nonzero tensor of each partition instead of whole tensor to save memory\n    for idx in range(0, num_elements, block_size):\n        nonzero_block = F.nonzero_1d(\n            elements[idx : min(idx + block_size, num_elements)]\n        )\n        x = y\n        y += len(nonzero_block)\n        if y > left and x < right:\n            start = max(x, left) - x\n            end = min(y, right) - x\n            tmp = nonzero_block[start:end] + idx\n            part_eles = F.cat((part_eles, tmp), 0)\n        elif x >= right:\n            break\n\n    return part_eles\n\n\ndef _split_random_within_part(partition_book, rank, part_eles):\n    # If there are more than one client in a partition, we need to randomly select a subset of\n    # elements in the partition for a client. 
We have to make sure that the set of elements\n    # for different clients are disjoint.\n\n    num_clients = role.get_num_trainers()\n    num_client_per_part = num_clients // partition_book.num_partitions()\n    if num_client_per_part == 1:\n        return part_eles\n    if rank is None:\n        rank = role.get_trainer_rank()\n    assert (\n        rank < num_clients\n    ), \"The input rank ({}) is incorrect. #Trainers: {}\".format(\n        rank, num_clients\n    )\n    client_id_in_part = rank % num_client_per_part\n    offset = _even_offset(len(part_eles), num_client_per_part)\n\n    # We set the random seed for each partition, so that each process (client) in a partition\n    # permute the elements in a partition in the same way, so each process gets a disjoint subset\n    # of elements.\n    np.random.seed(partition_book.partid)\n    rand_idx = np.random.permutation(len(part_eles))\n    rand_idx = rand_idx[\n        offset[client_id_in_part] : offset[client_id_in_part + 1]\n    ]\n    idx, _ = F.sort_1d(F.tensor(rand_idx))\n    return F.gather_row(part_eles, idx)\n\n\ndef _split_by_trainer_id(\n    partition_book,\n    part_eles,\n    trainer_id,\n    num_client_per_part,\n    client_id_in_part,\n):\n    # TODO(zhengda): MXNet cannot deal with empty tensors, which makes the implementation\n    # much more difficult. Let's just use numpy for the computation for now. We just\n    # perform operations on vectors. It shouldn't be too difficult.\n    trainer_id = F.asnumpy(trainer_id)\n    part_eles = F.asnumpy(part_eles)\n    part_id = trainer_id // num_client_per_part\n    trainer_id = trainer_id % num_client_per_part\n    local_eles = part_eles[\n        np.nonzero(part_id[part_eles] == partition_book.partid)[0]\n    ]\n    # these are the Ids of the local elements in the partition. 
The Ids are global Ids.\n    remote_eles = part_eles[\n        np.nonzero(part_id[part_eles] != partition_book.partid)[0]\n    ]\n    # these are the Ids of the remote nodes in the partition. The Ids are global Ids.\n    local_eles_idx = np.concatenate(\n        [\n            np.nonzero(trainer_id[local_eles] == i)[0]\n            for i in range(num_client_per_part)\n        ],\n        # trainer_id[local_eles] is the trainer ids of local nodes in the partition and we\n        # pick out the indices where the node belongs to each trainer i respectively, and\n        # concatenate them.\n        axis=0,\n    )\n    # `local_eles_idx` is used to sort `local_eles` according to `trainer_id`. It is a\n    # permutation of 0...(len(local_eles)-1)\n    local_eles = local_eles[local_eles_idx]\n\n    # evenly split local nodes to trainers\n    local_offsets = _even_offset(len(local_eles), num_client_per_part)\n    # evenly split remote nodes to trainers\n    remote_offsets = _even_offset(len(remote_eles), num_client_per_part)\n\n    client_local_eles = local_eles[\n        local_offsets[client_id_in_part] : local_offsets[client_id_in_part + 1]\n    ]\n    client_remote_eles = remote_eles[\n        remote_offsets[client_id_in_part] : remote_offsets[\n            client_id_in_part + 1\n        ]\n    ]\n    client_eles = np.concatenate(\n        [client_local_eles, client_remote_eles], axis=0\n    )\n    return F.tensor(client_eles)\n\n\ndef node_split(\n    nodes,\n    partition_book=None,\n    ntype=\"_N\",\n    rank=None,\n    force_even=True,\n    node_trainer_ids=None,\n):\n    \"\"\"Split nodes and return a subset for the local rank.\n\n    This function splits the input nodes based on the partition book and\n    returns a subset of nodes for the local rank. This method is used for\n    dividing workloads for distributed training.\n\n    The input nodes are stored as a vector of masks. 
The length of the vector is\n    the same as the number of nodes in a graph; 1 indicates that the vertex in\n    the corresponding location exists.\n\n    There are two strategies to split the nodes. By default, it splits the nodes\n    in a way to maximize data locality. That is, all nodes that belong to a process\n    are returned. If ``force_even`` is set to true, the nodes are split evenly so\n    that each process gets almost the same number of nodes.\n\n    When ``force_even`` is True, the data locality is still preserved if a graph is partitioned\n    with Metis and the node/edge IDs are shuffled.\n    In this case, majority of the nodes returned for a process are the ones that\n    belong to the process. If node/edge IDs are not shuffled, data locality is not guaranteed.\n\n    Parameters\n    ----------\n    nodes : 1D tensor or DistTensor\n        A boolean mask vector that indicates input nodes.\n    partition_book : GraphPartitionBook, optional\n        The graph partition book\n    ntype : str, optional\n        The node type of the input nodes.\n    rank : int, optional\n        The rank of a process. If not given, the rank of the current process is used.\n    force_even : bool, optional\n        Force the nodes are split evenly.\n    node_trainer_ids : 1D tensor or DistTensor, optional\n        If not None, split the nodes to the trainers on the same machine according to\n        trainer IDs assigned to each node. 
Otherwise, split randomly.\n\n    Returns\n    -------\n    1D-tensor\n        The vector of node IDs that belong to the rank.\n    \"\"\"\n    if not isinstance(nodes, DistTensor):\n        assert (\n            partition_book is not None\n        ), \"Regular tensor requires a partition book.\"\n    elif partition_book is None:\n        partition_book = nodes.part_policy.partition_book\n\n    assert len(nodes) == partition_book._num_nodes(\n        ntype\n    ), \"The length of boolean mask vector should be the number of nodes in the graph.\"\n    if rank is None:\n        rank = role.get_trainer_rank()\n    if force_even:\n        num_clients = role.get_num_trainers()\n        num_client_per_part = num_clients // partition_book.num_partitions()\n        assert (\n            num_clients % partition_book.num_partitions() == 0\n        ), \"The total number of clients should be multiple of the number of partitions.\"\n        part_nid = _split_even_to_part(partition_book, nodes)\n        if num_client_per_part == 1:\n            return part_nid\n        elif node_trainer_ids is None:\n            return _split_random_within_part(partition_book, rank, part_nid)\n        else:\n            trainer_id = node_trainer_ids[0 : len(node_trainer_ids)]\n            max_trainer_id = F.as_scalar(F.reduce_max(trainer_id)) + 1\n\n            if max_trainer_id > num_clients:\n                # We hope the partition scheme with trainer_id could be used when the number of\n                # trainers is less than the `num_trainers_per_machine` previously assigned during\n                # partitioning.\n                assert max_trainer_id % num_clients == 0\n                trainer_id //= max_trainer_id // num_clients\n\n            client_id_in_part = rank % num_client_per_part\n            return _split_by_trainer_id(\n                partition_book,\n                part_nid,\n                trainer_id,\n                num_client_per_part,\n                
client_id_in_part,\n            )\n    else:\n        # Get all nodes that belong to the rank.\n        local_nids = partition_book.partid2nids(\n            partition_book.partid, ntype=ntype\n        )\n        return _split_local(partition_book, rank, nodes, local_nids)\n\n\ndef edge_split(\n    edges,\n    partition_book=None,\n    etype=\"_E\",\n    rank=None,\n    force_even=True,\n    edge_trainer_ids=None,\n):\n    \"\"\"Split edges and return a subset for the local rank.\n\n    This function splits the input edges based on the partition book and\n    returns a subset of edges for the local rank. This method is used for\n    dividing workloads for distributed training.\n\n    The input edges can be stored as a vector of masks. The length of the vector is\n    the same as the number of edges in a graph; 1 indicates that the edge in\n    the corresponding location exists.\n\n    There are two strategies to split the edges. By default, it splits the edges\n    in a way to maximize data locality. That is, all edges that belong to a process\n    are returned. If ``force_even`` is set to true, the edges are split evenly so\n    that each process gets almost the same number of edges.\n\n    When ``force_even`` is True, the data locality is still preserved if a graph is partitioned\n    with Metis and the node/edge IDs are shuffled.\n    In this case, majority of the nodes returned for a process are the ones that\n    belong to the process. If node/edge IDs are not shuffled, data locality is not guaranteed.\n\n    Parameters\n    ----------\n    edges : 1D tensor or DistTensor\n        A boolean mask vector that indicates input edges.\n    partition_book : GraphPartitionBook, optional\n        The graph partition book\n    etype : str or (str, str, str), optional\n        The edge type of the input edges.\n    rank : int, optional\n        The rank of a process. 
If not given, the rank of the current process is used.\n    force_even : bool, optional\n        Force the edges are split evenly.\n    edge_trainer_ids : 1D tensor or DistTensor, optional\n        If not None, split the edges to the trainers on the same machine according to\n        trainer IDs assigned to each edge. Otherwise, split randomly.\n\n    Returns\n    -------\n    1D-tensor\n        The vector of edge IDs that belong to the rank.\n    \"\"\"\n    if not isinstance(edges, DistTensor):\n        assert (\n            partition_book is not None\n        ), \"Regular tensor requires a partition book.\"\n    elif partition_book is None:\n        partition_book = edges.part_policy.partition_book\n    assert len(edges) == partition_book._num_edges(\n        etype\n    ), \"The length of boolean mask vector should be the number of edges in the graph.\"\n    if rank is None:\n        rank = role.get_trainer_rank()\n    if force_even:\n        num_clients = role.get_num_trainers()\n        num_client_per_part = num_clients // partition_book.num_partitions()\n        assert (\n            num_clients % partition_book.num_partitions() == 0\n        ), \"The total number of clients should be multiple of the number of partitions.\"\n        part_eid = _split_even_to_part(partition_book, edges)\n        if num_client_per_part == 1:\n            return part_eid\n        elif edge_trainer_ids is None:\n            return _split_random_within_part(partition_book, rank, part_eid)\n        else:\n            trainer_id = edge_trainer_ids[0 : len(edge_trainer_ids)]\n            max_trainer_id = F.as_scalar(F.reduce_max(trainer_id)) + 1\n\n            if max_trainer_id > num_clients:\n                # We hope the partition scheme with trainer_id could be used when the number of\n                # trainers is less than the `num_trainers_per_machine` previously assigned during\n                # partitioning.\n                assert max_trainer_id % num_clients == 0\n         
       trainer_id //= max_trainer_id // num_clients\n\n            client_id_in_part = rank % num_client_per_part\n            return _split_by_trainer_id(\n                partition_book,\n                part_eid,\n                trainer_id,\n                num_client_per_part,\n                client_id_in_part,\n            )\n    else:\n        # Get all edges that belong to the rank.\n        local_eids = partition_book.partid2eids(\n            partition_book.partid, etype=etype\n        )\n        return _split_local(partition_book, rank, edges, local_eids)\n\n\nrpc.register_service(INIT_GRAPH, InitGraphRequest, InitGraphResponse)\nrpc.register_service(\n    QUERY_IF_USE_GRAPHBOLT,\n    QueryIfUseGraphBoltRequest,\n    QueryIfUseGraphBoltResponse,\n)\nrpc.register_service(\n    ADD_EDGE_ATTRIBUTE_FROM_KV,\n    AddEdgeAttributeFromKVRequest,\n    AddEdgeAttributeFromKVResponse,\n)\nrpc.register_service(\n    ADD_EDGE_ATTRIBUTE_FROM_SHARED_MEM,\n    AddEdgeAttributeFromSharedMemRequest,\n    AddEdgeAttributeFromSharedMemResponse,\n)\n"
  },
  {
    "path": "python/dgl/distributed/dist_tensor.py",
    "content": "\"\"\"Define distributed tensor.\"\"\"\n\nimport os\n\nfrom .. import backend as F, utils\n\nfrom .dist_context import is_initialized\nfrom .kvstore import get_kvstore\nfrom .role import get_role\nfrom .rpc import get_group_id\n\n\ndef _default_init_data(shape, dtype):\n    return F.zeros(shape, dtype, F.cpu())\n\n\n# These IDs can identify the anonymous distributed tensors.\nDIST_TENSOR_ID = 0\n\n\nclass DistTensor:\n    \"\"\"Distributed tensor.\n\n    ``DistTensor`` references to a distributed tensor sharded and stored in a cluster of machines.\n    It has the same interface as Pytorch Tensor to access its metadata (e.g., shape and data type).\n    To access data in a distributed tensor, it supports slicing rows and writing data to rows.\n    It does not support any operators of a deep learning framework, such as addition and\n    multiplication.\n\n    Currently, distributed tensors are designed to store node data and edge data of a distributed\n    graph. Therefore, their first dimensions have to be the number of nodes or edges in the graph.\n    The tensors are sharded in the first dimension based on the partition policy of nodes\n    or edges. When a distributed tensor is created, the partition policy is automatically\n    determined based on the first dimension if the partition policy is not provided. 
If the first\n    dimension matches the number of nodes of a node type, ``DistTensor`` will use the partition\n    policy for this particular node type; if the first dimension matches the number of edges of\n    an edge type, ``DistTensor`` will use the partition policy for this particular edge type.\n    If DGL cannot determine the partition policy automatically (e.g., multiple node types or\n    edge types have the same number of nodes or edges), users have to explicitly provide\n    the partition policy.\n\n    A distributed tensor can be either named or anonymous.\n    When a distributed tensor has a name, the tensor can be persistent if ``persistent=True``.\n    Normally, DGL destroys the distributed tensor in the system when the ``DistTensor`` object\n    goes away. However, a persistent tensor lives in the system even if\n    the ``DistTensor`` object disappears in the trainer process. The persistent tensor has\n    the same life span as the DGL servers. DGL does not allow an anonymous tensor to be persistent.\n\n    When a ``DistTensor`` object is created, it may refer to an existing distributed tensor or\n    create a new one. A distributed tensor is identified by the name passed to the constructor.\n    If the name exists, ``DistTensor`` will reference the existing one.\n    In this case, the shape and the data type must match the existing tensor.\n    If the name doesn't exist, a new tensor will be created in the kvstore.\n\n    When a distributed tensor is created, its values are initialized to zero. Users\n    can define an initialization function to control how the values are initialized.\n    The init function has two input arguments: shape and data type and returns a tensor.\n    Below shows an example of an init function:\n\n    .. highlight:: python\n    .. 
code-block:: python\n\n        def init_func(shape, dtype):\n            return torch.ones(shape, dtype=dtype)\n\n    Parameters\n    ----------\n    shape : tuple\n        The shape of the tensor. The first dimension has to be the number of nodes or\n        the number of edges of a distributed graph.\n    dtype : dtype\n        The dtype of the tensor. The data type has to be the one in the deep learning framework.\n    name : string, optional\n        The name of the embeddings. The name can uniquely identify embeddings in a system\n        so that another ``DistTensor`` object can refer to the distributed tensor.\n    init_func : callable, optional\n        The function to initialize data in the tensor. If the init function is not provided,\n        the values of the embeddings are initialized to zero.\n    part_policy : PartitionPolicy, optional\n        The partition policy of the rows of the tensor to different machines in the cluster.\n        Currently, it only supports node partition policy or edge partition policy.\n        The system determines the right partition policy automatically.\n    persistent : bool\n        Whether the created tensor lives after the ``DistTensor`` object is destroyed.\n    is_gdata : bool\n        Whether the created tensor is a ndata/edata or not.\n    attach : bool\n        Whether to attach group ID into name to be globally unique.\n\n    Examples\n    --------\n    >>> init = lambda shape, dtype: th.ones(shape, dtype=dtype)\n    >>> arr = dgl.distributed.DistTensor((g.num_nodes(), 2), th.int32, init_func=init)\n    >>> print(arr[0:3])\n    tensor([[1, 1],\n            [1, 1],\n            [1, 1]], dtype=torch.int32)\n    >>> arr[0:3] = th.ones((3, 2), dtype=th.int32) * 2\n    >>> print(arr[0:3])\n    tensor([[2, 2],\n            [2, 2],\n            [2, 2]], dtype=torch.int32)\n\n    Note\n    ----\n    The creation of ``DistTensor`` is a synchronized operation. 
When a trainer process tries to\n    create a ``DistTensor`` object, the creation succeeds only when all trainer processes\n    do the same.\n    \"\"\"\n\n    def __init__(\n        self,\n        shape,\n        dtype,\n        name=None,\n        init_func=None,\n        part_policy=None,\n        persistent=False,\n        is_gdata=True,\n        attach=True,\n    ):\n        self.kvstore = get_kvstore()\n        assert (\n            self.kvstore is not None\n        ), \"Distributed module is not initialized. Please call dgl.distributed.initialize.\"\n        self._shape = shape\n        self._dtype = dtype\n        self._attach = attach\n        self._is_gdata = is_gdata\n\n        part_policies = self.kvstore.all_possible_part_policy\n        # If a user doesn't provide a partition policy, we should find one based on\n        # the input shape.\n        if part_policy is None:\n            for policy_name in part_policies:\n                policy = part_policies[policy_name]\n                if policy.get_size() == shape[0]:\n                    # If multiple partition policies match the input shape, we cannot\n                    # decide which is the right one automatically. We should ask users\n                    # to provide one.\n                    assert part_policy is None, (\n                        \"Multiple partition policies match the input shape. \"\n                        + \"Please provide a partition policy explicitly.\"\n                    )\n                    part_policy = policy\n            assert part_policy is not None, (\n                \"Cannot find a right partition policy. 
It is either because \"\n                + \"its first dimension does not match the number of nodes or edges \"\n                + \"of a distributed graph or there does not exist a distributed graph.\"\n            )\n\n        self._part_policy = part_policy\n        assert (\n            part_policy.get_size() == shape[0]\n        ), \"The partition policy does not match the input shape.\"\n\n        if init_func is None:\n            init_func = _default_init_data\n        exist_names = self.kvstore.data_name_list()\n        # If a user doesn't provide a name, we generate a name ourselves.\n        # We need to generate the name in a deterministic way.\n        if name is None:\n            assert (\n                not persistent\n            ), \"We cannot generate anonymous persistent distributed tensors\"\n            global DIST_TENSOR_ID\n            # All processes of the same role should create DistTensor synchronously.\n            # Thus, all of them should have the same IDs.\n            name = \"anonymous-\" + get_role() + \"-\" + str(DIST_TENSOR_ID)\n            DIST_TENSOR_ID += 1\n        assert isinstance(name, str), \"name {} is type {}\".format(\n            name, type(name)\n        )\n        name = self._attach_group_id(name)\n        self._tensor_name = name\n        data_name = part_policy.get_data_name(name)\n        self._name = str(data_name)\n        self._persistent = persistent\n        if self._name not in exist_names:\n            self._owner = True\n            self.kvstore.init_data(\n                self._name, shape, dtype, part_policy, init_func, is_gdata\n            )\n        else:\n            self._owner = False\n            dtype1, shape1, _ = self.kvstore.get_data_meta(self._name)\n            assert (\n                dtype == dtype1\n            ), \"The dtype does not match with the existing tensor\"\n            assert (\n                shape == shape1\n            ), \"The shape does not match with the existing 
tensor\"\n\n    def __del__(self):\n        initialized = (\n            os.environ.get(\"DGL_DIST_MODE\", \"standalone\") == \"standalone\"\n            or is_initialized()\n        )\n        if not self._persistent and self._owner and initialized:\n            self.kvstore.delete_data(self._name)\n\n    def __getitem__(self, idx):\n        idx = utils.toindex(idx)\n        idx = idx.tousertensor()\n        return self.kvstore.pull(name=self._name, id_tensor=idx)\n\n    def __setitem__(self, idx, val):\n        idx = utils.toindex(idx)\n        idx = idx.tousertensor()\n        # TODO(zhengda) how do we want to support broadcast (e.g., G.ndata['h'][idx] = 1).\n        self.kvstore.push(name=self._name, id_tensor=idx, data_tensor=val)\n\n    @property\n    def kvstore_key(self):\n        \"\"\"Return the key string of this DistTensor in the associated KVStore.\"\"\"\n        return self._name\n\n    @property\n    def local_partition(self):\n        \"\"\"Return the local partition of this DistTensor.\"\"\"\n        return self.kvstore.data_store[self._name]\n\n    def __or__(self, other):\n        new_dist_tensor = DistTensor(\n            self._shape,\n            self._dtype,\n            part_policy=self._part_policy,\n            persistent=self._persistent,\n            is_gdata=self._is_gdata,\n            attach=self._attach,\n        )\n        kvstore = self.kvstore\n        kvstore.union(self._name, other._name, new_dist_tensor._name)\n        return new_dist_tensor\n\n    def __len__(self):\n        return self._shape[0]\n\n    @property\n    def part_policy(self):\n        \"\"\"Return the partition policy\n\n        Returns\n        -------\n        PartitionPolicy\n            The partition policy of the distributed tensor.\n        \"\"\"\n        return self._part_policy\n\n    @property\n    def shape(self):\n        \"\"\"Return the shape of the distributed tensor.\n\n        Returns\n        -------\n        tuple\n            The shape of the 
distributed tensor.\n        \"\"\"\n        return self._shape\n\n    @property\n    def dtype(self):\n        \"\"\"Return the data type of the distributed tensor.\n\n        Returns\n        ------\n        dtype\n            The data type of the tensor.\n        \"\"\"\n        return self._dtype\n\n    @property\n    def name(self):\n        \"\"\"Return the name of the distributed tensor\n\n        Returns\n        -------\n        str\n            The name of the tensor.\n        \"\"\"\n        return self._detach_group_id(self._name)\n\n    @property\n    def tensor_name(self):\n        \"\"\"Return the tensor name\n\n        Returns\n        -------\n        str\n            The name of the tensor.\n        \"\"\"\n        return self._detach_group_id(self._tensor_name)\n\n    def count_nonzero(self):\n        \"\"\"Count and return the number of nonzero value\n\n        Returns\n        -------\n        int\n            the number of nonzero value\n        \"\"\"\n        return self.kvstore.count_nonzero(name=self._name)\n\n    def _attach_group_id(self, name):\n        \"\"\"Attach group ID if needed\n\n        Returns\n        -------\n        str\n            new name with group ID attached\n        \"\"\"\n        if not self._attach:\n            return name\n        return \"{}_{}\".format(name, get_group_id())\n\n    def _detach_group_id(self, name):\n        \"\"\"Detach group ID if needed\n\n        Returns\n        -------\n        str\n            original name without group ID\n        \"\"\"\n        if not self._attach:\n            return name\n        suffix = \"_{}\".format(get_group_id())\n        return name[: -len(suffix)]\n"
  },
  {
    "path": "python/dgl/distributed/graph_partition_book.py",
    "content": "\"\"\"Define graph partition book.\"\"\"\n\nimport pickle\nfrom abc import ABC\n\nimport numpy as np\n\nfrom .. import backend as F, utils\nfrom .._ffi.ndarray import empty_shared_mem\nfrom ..base import DGLError\nfrom ..ndarray import exist_shared_mem_array\nfrom ..partition import NDArrayPartition\nfrom .constants import DEFAULT_ETYPE, DEFAULT_NTYPE\nfrom .id_map import IdMap\nfrom .shared_mem_utils import (\n    _get_edata_path,\n    _get_ndata_path,\n    _to_shared_mem,\n    DTYPE_DICT,\n)\n\nCANONICAL_ETYPE_DELIMITER = \":\"\n\n\ndef _etype_tuple_to_str(c_etype):\n    \"\"\"Convert canonical etype from tuple to string.\n\n    Examples\n    --------\n    >>> c_etype = ('user', 'like', 'item')\n    >>> c_etype_str = _etype_tuple_to_str(c_etype)\n    >>> print(c_etype_str)\n    'user:like:item'\n\n    \"\"\"\n    assert isinstance(c_etype, tuple) and len(c_etype) == 3, (\n        \"Passed-in canonical etype should be in format of (str, str, str). \"\n        f\"But got {c_etype}.\"\n    )\n    return CANONICAL_ETYPE_DELIMITER.join(c_etype)\n\n\ndef _etype_str_to_tuple(c_etype):\n    \"\"\"Convert canonical etype from tuple to string.\n\n    Examples\n    --------\n    >>> c_etype_str = 'user:like:item'\n    >>> c_etype = _etype_str_to_tuple(c_etype_str)\n    >>> print(c_etype)\n    ('user', 'like', 'item')\n\n    \"\"\"\n    ret = tuple(c_etype.split(CANONICAL_ETYPE_DELIMITER))\n    assert len(ret) == 3, (\n        \"Passed-in canonical etype should be in format of 'str:str:str'. 
\"\n        f\"But got {c_etype}.\"\n    )\n    return ret\n\n\ndef _move_metadata_to_shared_mem(\n    graph_name,\n    num_nodes,\n    num_edges,\n    part_id,\n    num_partitions,\n    node_map,\n    edge_map,\n    is_range_part,\n):\n    \"\"\"Move all metadata of the partition book to the shared memory.\n\n    These metadata will be used to construct graph partition book.\n\n    Parameters\n    ----------\n    graph_name : str\n        The name of the graph\n    num_nodes : int\n        The total number of nodes\n    num_edges : int\n        The total number of edges\n    part_id : int\n        The partition ID.\n    num_partitions : int\n        The number of physical partitions generated for the graph.\n    node_map : Tensor\n        It stores the mapping information from node IDs to partitions. With range partitioning,\n        the tensor stores the serialized result of partition ranges.\n    edge_map : Tensor\n        It stores the mapping information from edge IDs to partitions. With range partitioning,\n        the tensor stores the serialized result of partition ranges.\n    is_range_part : bool\n        Indicate that we use a range partition. This is important for us to deserialize data\n        in node_map and edge_map.\n\n    Returns\n    -------\n    (Tensor, Tensor, Tensor)\n        The first tensor stores the serialized metadata, the second tensor stores the serialized\n        node map and the third tensor stores the serialized edge map. 
All tensors are stored in\n        shared memory.\n    \"\"\"\n    meta = _to_shared_mem(\n        F.tensor(\n            [\n                int(is_range_part),\n                num_nodes,\n                num_edges,\n                num_partitions,\n                part_id,\n                len(node_map),\n                len(edge_map),\n            ]\n        ),\n        _get_ndata_path(graph_name, \"meta\"),\n    )\n    node_map = _to_shared_mem(node_map, _get_ndata_path(graph_name, \"node_map\"))\n    edge_map = _to_shared_mem(edge_map, _get_edata_path(graph_name, \"edge_map\"))\n    return meta, node_map, edge_map\n\n\ndef _get_shared_mem_metadata(graph_name):\n    \"\"\"Get the metadata of the graph from shared memory.\n\n    The server serializes the metadata of a graph and store them in shared memory.\n    The client needs to deserialize the data in shared memory and get the metadata\n    of the graph.\n\n    Parameters\n    ----------\n    graph_name : str\n        The name of the graph. 
We can use the graph name to find the shared memory name.\n\n    Returns\n    -------\n    (bool, int, int, Tensor, Tensor)\n        The first element indicates whether it is range partitioning;\n        the second element is the partition ID;\n        the third element is the number of partitions;\n        the fourth element is the tensor that stores the serialized result of node maps;\n        the fifth element is the tensor that stores the serialized result of edge maps.\n    \"\"\"\n    # The metadata has 7 elements: is_range_part, num_nodes, num_edges, num_partitions, part_id,\n    # the length of node map and the length of the edge map.\n    shape = (7,)\n    dtype = F.int64\n    dtype = DTYPE_DICT[dtype]\n    data = empty_shared_mem(\n        _get_ndata_path(graph_name, \"meta\"), False, shape, dtype\n    )\n    dlpack = data.to_dlpack()\n    meta = F.asnumpy(F.zerocopy_from_dlpack(dlpack))\n    (\n        is_range_part,\n        _,\n        _,\n        num_partitions,\n        part_id,\n        node_map_len,\n        edge_map_len,\n    ) = meta\n\n    # Load node map\n    data = empty_shared_mem(\n        _get_ndata_path(graph_name, \"node_map\"), False, (node_map_len,), dtype\n    )\n    dlpack = data.to_dlpack()\n    node_map = F.zerocopy_from_dlpack(dlpack)\n\n    # Load edge_map\n    data = empty_shared_mem(\n        _get_edata_path(graph_name, \"edge_map\"), False, (edge_map_len,), dtype\n    )\n    dlpack = data.to_dlpack()\n    edge_map = F.zerocopy_from_dlpack(dlpack)\n\n    return is_range_part, part_id, num_partitions, node_map, edge_map\n\n\ndef get_shared_mem_partition_book(graph_name):\n    \"\"\"Get a graph partition book from shared memory.\n\n    A graph partition book of a specific graph can be serialized to shared memory.\n    We can reconstruct a graph partition book from shared memory.\n\n    Parameters\n    ----------\n    graph_name : str\n        The name of the graph.\n\n    Returns\n    -------\n    GraphPartitionBook\n        A 
graph partition book for a particular partition.\n    \"\"\"\n    if not exist_shared_mem_array(_get_ndata_path(graph_name, \"meta\")):\n        return None\n    (\n        is_range_part,\n        part_id,\n        num_parts,\n        node_map_data,\n        edge_map_data,\n    ) = _get_shared_mem_metadata(graph_name)\n    if is_range_part == 1:\n        # node ID ranges and edge ID ranges are stored in the order of node type IDs\n        # and edge type IDs.\n        node_map = {}\n        ntypes = {}\n        # node_map_data and edge_map_data were serialized with pickle and converted into\n        # a list of bytes and then stored in a numpy array before being placed in shared\n        # memory. To deserialize, we need to reverse the process.\n        node_map_data = pickle.loads(bytes(F.asnumpy(node_map_data).tolist()))\n        for i, (ntype, nid_range) in enumerate(node_map_data):\n            ntypes[ntype] = i\n            node_map[ntype] = nid_range\n\n        edge_map = {}\n        etypes = {}\n        edge_map_data = pickle.loads(bytes(F.asnumpy(edge_map_data).tolist()))\n        for i, (etype, eid_range) in enumerate(edge_map_data):\n            etypes[etype] = i\n            edge_map[etype] = eid_range\n        return RangePartitionBook(\n            part_id, num_parts, node_map, edge_map, ntypes, etypes\n        )\n    else:\n        raise TypeError(\"Only RangePartitionBook is supported currently.\")\n\n\ndef get_node_partition_from_book(book, device):\n    \"\"\"Get an NDArrayPartition of the nodes from a RangePartitionBook.\n\n    Parameters\n    ----------\n    book : RangePartitionBook\n        The partition book to extract the node partition from.\n    device : Device context object.\n        The location to node partition is to be used.\n\n    Returns\n    -------\n    NDarrayPartition\n        The NDArrayPartition object for the nodes in the graph.\n    \"\"\"\n    assert isinstance(book, RangePartitionBook), (\n        \"Can only convert \" 
\"RangePartitionBook to NDArrayPartition.\"\n    )\n    # create prefix-sum array on host\n    max_node_ids = F.zerocopy_from_numpy(book._max_node_ids)\n    cpu_range = F.cat(\n        [F.tensor([0], dtype=F.dtype(max_node_ids)), max_node_ids + 1], dim=0\n    )\n    gpu_range = F.copy_to(cpu_range, ctx=device)\n\n    # convert from numpy\n    array_size = int(F.as_scalar(cpu_range[-1]))\n    num_parts = book.num_partitions()\n\n    return NDArrayPartition(\n        array_size, num_parts, mode=\"range\", part_ranges=gpu_range\n    )\n\n\nclass GraphPartitionBook(ABC):\n    \"\"\"The base class of the graph partition book.\n\n    For distributed training, a graph is partitioned into multiple parts and is loaded\n    in multiple machines. The partition book contains all necessary information to locate\n    nodes and edges in the cluster.\n\n    The partition book contains various partition information, including\n\n    * the number of partitions,\n    * the partition ID that a node or edge belongs to,\n    * the node IDs and the edge IDs that a partition has.\n    * the local IDs of nodes and edges in a partition.\n\n    Currently, only one class that implement ``GraphPartitionBook``\n    :``RangePartitionBook``. 
It calculates the mapping between node/edge IDs\n    and partition IDs based on some small metadata because nodes/edges have been\n    relabeled to have IDs in the same partition fall in a contiguous ID range.\n\n    A graph partition book is constructed automatically when a graph is partitioned.\n    When a graph partition is loaded, a graph partition book is loaded as well.\n    Please see :py:meth:`~dgl.distributed.partition.partition_graph`,\n    :py:meth:`~dgl.distributed.partition.load_partition` and\n    :py:meth:`~dgl.distributed.partition.load_partition_book` for more details.\n    \"\"\"\n\n    def shared_memory(self, graph_name):\n        \"\"\"Move the partition book to shared memory.\n\n        Parameters\n        ----------\n        graph_name : str\n            The graph name. This name will be used to read the partition book from shared\n            memory in another process.\n        \"\"\"\n\n    def num_partitions(self):\n        \"\"\"Return the number of partitions.\n\n        Returns\n        -------\n        int\n            number of partitions\n        \"\"\"\n\n    def metadata(self):\n        \"\"\"Return the partition meta data.\n\n        The meta data includes:\n\n        * The machine ID.\n        * Number of nodes and edges of each partition.\n\n        Examples\n        --------\n        >>> print(g.get_partition_book().metadata())\n        >>> [{'machine_id' : 0, 'num_nodes' : 3000, 'num_edges' : 5000},\n        ...  {'machine_id' : 1, 'num_nodes' : 2000, 'num_edges' : 4888},\n        ...  
...]\n\n        Returns\n        -------\n        list[dict[str, any]]\n            Meta data of each partition.\n        \"\"\"\n\n    def nid2partid(self, nids, ntype):\n        \"\"\"From global node IDs to partition IDs\n\n        Parameters\n        ----------\n        nids : tensor\n            global node IDs\n        ntype : str\n            The node type\n\n        Returns\n        -------\n        tensor\n            partition IDs\n        \"\"\"\n\n    def eid2partid(self, eids, etype):\n        \"\"\"From global edge IDs to partition IDs\n\n        Parameters\n        ----------\n        eids : tensor\n            global edge IDs\n        etype : str or (str, str, str)\n            The edge type\n\n        Returns\n        -------\n        tensor\n            partition IDs\n        \"\"\"\n\n    def partid2nids(self, partid, ntype):\n        \"\"\"From partition id to global node IDs\n\n        Parameters\n        ----------\n        partid : int\n            partition id\n        ntype : str\n            The node type\n\n        Returns\n        -------\n        tensor\n            node IDs\n        \"\"\"\n\n    def partid2eids(self, partid, etype):\n        \"\"\"From partition id to global edge IDs\n\n        Parameters\n        ----------\n        partid : int\n            partition id\n        etype : str or (str, str, str)\n            The edge type\n\n        Returns\n        -------\n        tensor\n            edge IDs\n        \"\"\"\n\n    def nid2localnid(self, nids, partid, ntype):\n        \"\"\"Get local node IDs within the given partition.\n\n        Parameters\n        ----------\n        nids : tensor\n            global node IDs\n        partid : int\n            partition ID\n        ntype : str\n            The node type\n\n        Returns\n        -------\n        tensor\n             local node IDs\n        \"\"\"\n\n    def eid2localeid(self, eids, partid, etype):\n        \"\"\"Get the local edge ids within the given 
partition.\n\n        Parameters\n        ----------\n        eids : tensor\n            global edge IDs\n        partid : int\n            partition ID\n        etype : str or (str, str, str)\n            The edge type\n\n        Returns\n        -------\n        tensor\n             local edge IDs\n        \"\"\"\n\n    @property\n    def partid(self):\n        \"\"\"Get the current partition ID\n\n        Return\n        ------\n        int\n            The partition ID of current machine\n        \"\"\"\n\n    @property\n    def ntypes(self):\n        \"\"\"Get the list of node types\"\"\"\n\n    @property\n    def etypes(self):\n        \"\"\"Get the list of edge types\"\"\"\n\n    @property\n    def canonical_etypes(self):\n        \"\"\"Get the list of canonical edge types\n\n        Returns\n        -------\n        list[(str, str, str)]\n            A list of canonical etypes\n        \"\"\"\n\n    def to_canonical_etype(self, etype):\n        \"\"\"Convert an edge type to the corresponding canonical edge type.\n\n        Parameters\n        ----------\n        etype : str or (str, str, str)\n            The edge type\n\n        Returns\n        -------\n        (str, str, str)\n            The corresponding canonical edge type\n        \"\"\"\n\n    @property\n    def is_homogeneous(self):\n        \"\"\"check if homogeneous\"\"\"\n        return not (len(self.etypes) > 1 or len(self.ntypes) > 1)\n\n    def map_to_per_ntype(self, ids):\n        \"\"\"Map homogeneous node IDs to type-wise IDs and node types.\n\n        Parameters\n        ----------\n        ids : tensor\n            Homogeneous node IDs.\n\n        Returns\n        -------\n        (tensor, tensor)\n            node type IDs and type-wise node IDs.\n        \"\"\"\n\n    def map_to_per_etype(self, ids):\n        \"\"\"Map homogeneous edge IDs to type-wise IDs and edge types.\n\n        Parameters\n        ----------\n        ids : tensor\n            Homogeneous edge IDs.\n\n        
Returns\n        -------\n        (tensor, tensor)\n            edge type IDs and type-wise edge IDs.\n        \"\"\"\n\n    def map_to_homo_nid(self, ids, ntype):\n        \"\"\"Map type-wise node IDs and type IDs to homogeneous node IDs.\n\n        Parameters\n        ----------\n        ids : tensor\n            Type-wise node Ids\n        ntype : str\n            node type\n\n        Returns\n        -------\n        Tensor\n            Homogeneous node IDs.\n        \"\"\"\n\n    def map_to_homo_eid(self, ids, etype):\n        \"\"\"Map type-wise edge IDs and type IDs to homogeneous edge IDs.\n\n        Parameters\n        ----------\n        ids : tensor\n            Type-wise edge Ids\n        etype : str or (str, str, str)\n            The edge type\n\n        Returns\n        -------\n        Tensor\n            Homogeneous edge IDs.\n        \"\"\"\n\n\nclass RangePartitionBook(GraphPartitionBook):\n    \"\"\"This partition book supports more efficient storage of partition information.\n\n    This partition book is used if the nodes and edges of a graph partition are assigned\n    with contiguous IDs. It uses very small amount of memory to store the partition\n    information.\n\n    Parameters\n    ----------\n    part_id : int\n        partition ID of current partition book\n    num_parts : int\n        number of total partitions\n    node_map : dict[str, Tensor]\n        Global node ID ranges within partitions for each node type. The key is the node type\n        name in string. The value is a tensor of shape :math:`(K, 2)`, where :math:`K` is\n        the number of partitions. Each row has two integers: the starting and the ending IDs\n        for a particular node type in a partition. For example, all nodes of type ``\"T\"`` in\n        partition ``i`` has ID range ``node_map[\"T\"][i][0]`` to ``node_map[\"T\"][i][1]``.\n    edge_map : dict[(str, str, str), Tensor]\n        Global edge ID ranges within partitions for each edge type. 
The key is the edge type\n        name in string. The value is a tensor of shape :math:`(K, 2)`, where :math:`K` is\n        the number of partitions. Each row has two integers: the starting and the ending IDs\n        for a particular edge type in a partition. For example, all edges of type ``\"T\"`` in\n        partition ``i`` has ID range ``edge_map[\"T\"][i][0]`` to ``edge_map[\"T\"][i][1]``.\n    ntypes : dict[str, int]\n        map ntype strings to ntype IDs.\n    etypes : dict[(str, str, str), int]\n        map canonical etypes to etype IDs.\n\n    \"\"\"\n\n    def __init__(self, part_id, num_parts, node_map, edge_map, ntypes, etypes):\n        assert part_id >= 0, \"part_id cannot be a negative number.\"\n        assert num_parts > 0, \"num_parts must be greater than zero.\"\n        self._partid = part_id\n        self._num_partitions = num_parts\n        self._ntypes = [None] * len(ntypes)\n        self._etypes = [None] * len(etypes)\n        self._canonical_etypes = [None] * len(etypes)\n        # map etypes to canonical ones\n        self._etype2canonical = {}\n        for ntype in ntypes:\n            ntype_id = ntypes[ntype]\n            self._ntypes[ntype_id] = ntype\n        assert all(\n            ntype is not None for ntype in self._ntypes\n        ), \"The node types have invalid IDs.\"\n        for c_etype, etype_id in etypes.items():\n            assert isinstance(c_etype, tuple) and len(c_etype) == 3, (\n                \"Expect canonical edge type in a triplet of string, but got \"\n                f\"{c_etype}.\"\n            )\n            etype = c_etype[1]\n            self._etypes[etype_id] = etype\n            self._canonical_etypes[etype_id] = c_etype\n            if etype in self._etype2canonical:\n                # If one etype maps to multiple canonical etypes, empty tuple\n                # is used to indicate such ambiguity casued by etype. 
See more\n                # details in self.to_canonical_etype().\n                self._etype2canonical[etype] = tuple()\n            else:\n                self._etype2canonical[etype] = c_etype\n        assert all(\n            etype is not None for etype in self._etypes\n        ), \"The edge types have invalid IDs.\"\n\n        # This stores the node ID ranges for each node type in each partition.\n        # The key is the node type, the value is a NumPy matrix with two\n        # columns, in which each row indicates the start and the end of the\n        # node ID range in a partition. The node IDs are global node IDs in the\n        # homogeneous representation.\n        self._typed_nid_range = {}\n        # This stores the node ID map for per-node-type IDs in each partition.\n        # The key is the node type, the value is a NumPy vector which indicates\n        # the last node ID in a partition.\n        self._typed_max_node_ids = {}\n        max_node_map = np.zeros((num_parts,), dtype=np.int64)\n        for key in node_map:\n            assert key in ntypes, \"Unexpected ntype: {}.\".format(key)\n            if not isinstance(node_map[key], np.ndarray):\n                node_map[key] = F.asnumpy(node_map[key])\n            assert node_map[key].shape == (num_parts, 2)\n            self._typed_nid_range[key] = node_map[key]\n            # This is used for per-node-type lookup.\n            self._typed_max_node_ids[key] = np.cumsum(\n                self._typed_nid_range[key][:, 1]\n                - self._typed_nid_range[key][:, 0]\n            )\n            # This is used for homogeneous node ID lookup.\n            max_node_map = np.maximum(\n                self._typed_nid_range[key][:, 1], max_node_map\n            )\n        # This is a vector that indicates the last node ID in each partition.\n        # The ID is the global ID in the homogeneous representation.\n        self._max_node_ids = max_node_map\n\n        # Similar to _typed_nid_range.\n     
   self._typed_eid_range = {}\n        # similar to _typed_max_node_ids.\n        self._typed_max_edge_ids = {}\n        max_edge_map = np.zeros((num_parts,), dtype=np.int64)\n        for key in edge_map:\n            assert key in etypes, \"Unexpected etype: {}.\".format(key)\n            if not isinstance(edge_map[key], np.ndarray):\n                edge_map[key] = F.asnumpy(edge_map[key])\n            assert edge_map[key].shape == (num_parts, 2)\n            self._typed_eid_range[key] = edge_map[key]\n            # This is used for per-edge-type lookup.\n            self._typed_max_edge_ids[key] = np.cumsum(\n                self._typed_eid_range[key][:, 1]\n                - self._typed_eid_range[key][:, 0]\n            )\n            # This is used for homogeneous edge ID lookup.\n            max_edge_map = np.maximum(\n                self._typed_eid_range[key][:, 1], max_edge_map\n            )\n        # Similar to _max_node_ids\n        self._max_edge_ids = max_edge_map\n\n        # These two are map functions that map node/edge IDs to node/edge type IDs.\n        self._nid_map = IdMap(self._typed_nid_range)\n        self._eid_map = IdMap(self._typed_eid_range)\n\n        # Local node/edge type offset that maps the local homogenized node/edge IDs\n        # to local heterogenized node/edge IDs.  
One can do the mapping by binary search\n        # on these arrays.\n        self._local_ntype_offset = np.cumsum(\n            [0]\n            + [\n                v[self._partid, 1] - v[self._partid, 0]\n                for v in self._typed_nid_range.values()\n            ]\n        ).tolist()\n        self._local_etype_offset = np.cumsum(\n            [0]\n            + [\n                v[self._partid, 1] - v[self._partid, 0]\n                for v in self._typed_eid_range.values()\n            ]\n        ).tolist()\n\n        # Get meta data of the partition book\n        self._partition_meta_data = []\n        for partid in range(self._num_partitions):\n            nrange_start = max_node_map[partid - 1] if partid > 0 else 0\n            nrange_end = max_node_map[partid]\n            num_nodes = nrange_end - nrange_start\n\n            erange_start = max_edge_map[partid - 1] if partid > 0 else 0\n            erange_end = max_edge_map[partid]\n            num_edges = erange_end - erange_start\n\n            part_info = {}\n            part_info[\"machine_id\"] = partid\n            part_info[\"num_nodes\"] = int(num_nodes)\n            part_info[\"num_edges\"] = int(num_edges)\n            self._partition_meta_data.append(part_info)\n\n    def shared_memory(self, graph_name):\n        \"\"\"Move data to shared memory.\"\"\"\n        # we need to store the nid ranges and eid ranges of different types in the order defined\n        # by type IDs.\n        nid_range = [None] * len(self.ntypes)\n        for i, ntype in enumerate(self.ntypes):\n            nid_range[i] = (ntype, self._typed_nid_range[ntype])\n        nid_range_pickle = list(pickle.dumps(nid_range))\n\n        eid_range = [None] * len(self.canonical_etypes)\n        for i, etype in enumerate(self.canonical_etypes):\n            eid_range[i] = (etype, self._typed_eid_range[etype])\n        eid_range_pickle = list(pickle.dumps(eid_range))\n\n        self._meta = _move_metadata_to_shared_mem(\n        
    graph_name,\n            0,  # We don't need to provide the number of nodes\n            0,  # We don't need to provide the number of edges\n            self._partid,\n            self._num_partitions,\n            F.tensor(nid_range_pickle),\n            F.tensor(eid_range_pickle),\n            True,\n        )\n\n    def num_partitions(self):\n        \"\"\"Return the number of partitions.\"\"\"\n        return self._num_partitions\n\n    def _num_nodes(self, ntype=DEFAULT_NTYPE):\n        \"\"\"The total number of nodes\"\"\"\n        if ntype == DEFAULT_NTYPE:\n            return int(self._max_node_ids[-1])\n        else:\n            return int(self._typed_max_node_ids[ntype][-1])\n\n    def _num_edges(self, etype=DEFAULT_ETYPE):\n        \"\"\"The total number of edges\"\"\"\n        if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):\n            return int(self._max_edge_ids[-1])\n        else:\n            c_etype = self.to_canonical_etype(etype)\n            return int(self._typed_max_edge_ids[c_etype][-1])\n\n    def metadata(self):\n        \"\"\"Return the partition meta data.\"\"\"\n        return self._partition_meta_data\n\n    def map_to_per_ntype(self, ids):\n        \"\"\"Map global homogeneous node IDs to node type IDs.\n        Returns\n            type_ids, per_type_ids\n        \"\"\"\n        return self._nid_map(ids)\n\n    def map_to_per_etype(self, ids):\n        \"\"\"Map global homogeneous edge IDs to edge type IDs.\n        Returns\n            type_ids, per_type_ids\n        \"\"\"\n        return self._eid_map(ids)\n\n    def map_to_homo_nid(self, ids, ntype):\n        \"\"\"Map per-node-type IDs to global node IDs in the homogeneous format.\"\"\"\n        ids = utils.toindex(ids).tousertensor()\n        partids = self.nid2partid(ids, ntype)\n        typed_max_nids = F.zerocopy_from_numpy(self._typed_max_node_ids[ntype])\n        end_diff = F.gather_row(typed_max_nids, partids) - ids\n        typed_nid_range = F.zerocopy_from_numpy(\n 
           self._typed_nid_range[ntype][:, 1]\n        )\n        return F.gather_row(typed_nid_range, partids) - end_diff\n\n    def map_to_homo_eid(self, ids, etype):\n        \"\"\"Map per-edge-type IDs to global edge IDs in the homoenegeous format.\"\"\"\n        ids = utils.toindex(ids).tousertensor()\n        c_etype = self.to_canonical_etype(etype)\n        partids = self.eid2partid(ids, c_etype)\n        typed_max_eids = F.zerocopy_from_numpy(\n            self._typed_max_edge_ids[c_etype]\n        )\n        end_diff = F.gather_row(typed_max_eids, partids) - ids\n        typed_eid_range = F.zerocopy_from_numpy(\n            self._typed_eid_range[c_etype][:, 1]\n        )\n        return F.gather_row(typed_eid_range, partids) - end_diff\n\n    def nid2partid(self, nids, ntype=DEFAULT_NTYPE):\n        \"\"\"From global node IDs to partition IDs\"\"\"\n        nids = utils.toindex(nids)\n        if ntype == DEFAULT_NTYPE:\n            ret = np.searchsorted(\n                self._max_node_ids, nids.tonumpy(), side=\"right\"\n            )\n        else:\n            ret = np.searchsorted(\n                self._typed_max_node_ids[ntype], nids.tonumpy(), side=\"right\"\n            )\n        ret = utils.toindex(ret)\n        return ret.tousertensor()\n\n    def eid2partid(self, eids, etype=DEFAULT_ETYPE):\n        \"\"\"From global edge IDs to partition IDs\"\"\"\n        eids = utils.toindex(eids)\n        if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):\n            ret = np.searchsorted(\n                self._max_edge_ids, eids.tonumpy(), side=\"right\"\n            )\n        else:\n            c_etype = self.to_canonical_etype(etype)\n            ret = np.searchsorted(\n                self._typed_max_edge_ids[c_etype], eids.tonumpy(), side=\"right\"\n            )\n        ret = utils.toindex(ret)\n        return ret.tousertensor()\n\n    def partid2nids(self, partid, ntype=DEFAULT_NTYPE):\n        \"\"\"From partition ID to global node IDs\"\"\"\n       
 # TODO do we need to cache it?\n        if ntype == DEFAULT_NTYPE:\n            start = self._max_node_ids[partid - 1] if partid > 0 else 0\n            end = self._max_node_ids[partid]\n            return F.arange(start, end)\n        else:\n            start = (\n                self._typed_max_node_ids[ntype][partid - 1] if partid > 0 else 0\n            )\n            end = self._typed_max_node_ids[ntype][partid]\n            return F.arange(start, end)\n\n    def partid2eids(self, partid, etype=DEFAULT_ETYPE):\n        \"\"\"From partition ID to global edge IDs\"\"\"\n        # TODO do we need to cache it?\n        if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):\n            start = self._max_edge_ids[partid - 1] if partid > 0 else 0\n            end = self._max_edge_ids[partid]\n            return F.arange(start, end)\n        else:\n            c_etype = self.to_canonical_etype(etype)\n            start = (\n                self._typed_max_edge_ids[c_etype][partid - 1]\n                if partid > 0\n                else 0\n            )\n            end = self._typed_max_edge_ids[c_etype][partid]\n            return F.arange(start, end)\n\n    def nid2localnid(self, nids, partid, ntype=DEFAULT_NTYPE):\n        \"\"\"Get local node IDs within the given partition.\"\"\"\n        if partid != self._partid:\n            raise RuntimeError(\n                \"Now RangePartitionBook does not support \\\n                getting remote tensor of nid2localnid.\"\n            )\n\n        nids = utils.toindex(nids)\n        nids = nids.tousertensor()\n        if ntype == DEFAULT_NTYPE:\n            start = self._max_node_ids[partid - 1] if partid > 0 else 0\n        else:\n            start = (\n                self._typed_max_node_ids[ntype][partid - 1] if partid > 0 else 0\n            )\n        return nids - int(start)\n\n    def eid2localeid(self, eids, partid, etype=DEFAULT_ETYPE):\n        \"\"\"Get the local edge IDs within the given partition.\"\"\"\n        
if partid != self._partid:\n            raise RuntimeError(\n                \"Now RangePartitionBook does not support \\\n                getting remote tensor of eid2localeid.\"\n            )\n\n        eids = utils.toindex(eids)\n        eids = eids.tousertensor()\n        if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):\n            start = self._max_edge_ids[partid - 1] if partid > 0 else 0\n        else:\n            c_etype = self.to_canonical_etype(etype)\n            start = (\n                self._typed_max_edge_ids[c_etype][partid - 1]\n                if partid > 0\n                else 0\n            )\n        return eids - int(start)\n\n    @property\n    def partid(self):\n        \"\"\"Get the current partition ID.\"\"\"\n        return self._partid\n\n    @property\n    def ntypes(self):\n        \"\"\"Get the list of node types\"\"\"\n        return self._ntypes\n\n    @property\n    def etypes(self):\n        \"\"\"Get the list of edge types\"\"\"\n        return self._etypes\n\n    @property\n    def canonical_etypes(self):\n        \"\"\"Get the list of canonical edge types\n\n        Returns\n        -------\n        list[(str, str, str)] or list[None]\n            A list of canonical etypes. 
If keys of ``edge_map`` and ``etypes``\n            are strings, a list of ``None`` is returned as canonical etypes\n            are not available.\n        \"\"\"\n        return self._canonical_etypes\n\n    @property\n    def local_ntype_offset(self):\n        \"\"\"Get the node type offset array of the local partition.\n\n        The i-th element indicates the starting position of the i-th node type.\n        \"\"\"\n        return self._local_ntype_offset\n\n    @property\n    def local_etype_offset(self):\n        \"\"\"Get the edge type offset array of the local partition.\n\n        The i-th element indicates the starting position of the i-th edge type.\n        \"\"\"\n        return self._local_etype_offset\n\n    def to_canonical_etype(self, etype):\n        \"\"\"Convert an edge type to the corresponding canonical edge type.\n\n        Parameters\n        ----------\n        etype : str or (str, str, str)\n            The edge type\n\n        Returns\n        -------\n        (str, str, str)\n            The corresponding canonical edge type\n        \"\"\"\n        if isinstance(etype, tuple):\n            if etype not in self.canonical_etypes:\n                raise DGLError('Edge type \"{}\" does not exist.'.format(etype))\n            return etype\n        ret = self._etype2canonical.get(etype, None)\n        if ret is None:\n            raise DGLError('Edge type \"{}\" does not exist.'.format(etype))\n        if len(ret) == 0:\n            raise DGLError(\n                'Edge type \"%s\" is ambiguous. 
Please use canonical edge type '\n                \"in the form of (srctype, etype, dsttype)\" % etype\n            )\n        return ret\n\n    @property\n    def global_nid_dtype(self):\n        \"\"\"Get the node ID's dtype\"\"\"\n        return self._nid_map.torch_dtype\n\n    @property\n    def global_eid_dtype(self):\n        \"\"\"Get the edge ID's dtype\"\"\"\n        return self._eid_map.torch_dtype\n\n\nNODE_PART_POLICY = \"node\"\nEDGE_PART_POLICY = \"edge\"\nPOLICY_DELIMITER = \"~\"\n\n\nclass PartitionPolicy(object):\n    \"\"\"This defines a partition policy for a distributed tensor or distributed embedding.\n\n    When DGL shards tensors and stores them in a cluster of machines, it requires\n    partition policies that map rows of the tensors to machines in the cluster.\n\n    Although an arbitrary partition policy can be defined, DGL currently supports\n    two partition policies for mapping nodes and edges to machines. To define a partition\n    policy from a graph partition book, users need to specify the policy name ('node' or 'edge').\n\n    Parameters\n    ----------\n    policy_str : str\n        Partition policy name, e.g., 'edge~_N:_E:_N' or 'node~_N'.\n    partition_book : GraphPartitionBook\n        A graph partition book\n    \"\"\"\n\n    def __init__(self, policy_str, partition_book):\n        assert policy_str.startswith(NODE_PART_POLICY) or policy_str.startswith(\n            EDGE_PART_POLICY\n        ), (\n            f\"policy_str must start with {NODE_PART_POLICY} or \"\n            f\"{EDGE_PART_POLICY}, but got {policy_str}.\"\n        )\n        if NODE_PART_POLICY == policy_str:\n            policy_str = NODE_PART_POLICY + POLICY_DELIMITER + DEFAULT_NTYPE\n        if EDGE_PART_POLICY == policy_str:\n            policy_str = EDGE_PART_POLICY + POLICY_DELIMITER + DEFAULT_ETYPE[1]\n        self._policy_str = policy_str\n        self._part_id = partition_book.partid\n        self._partition_book = partition_book\n        
part_policy, self._type_name = policy_str.split(POLICY_DELIMITER, 1)\n        if part_policy == EDGE_PART_POLICY:\n            self._type_name = _etype_str_to_tuple(self._type_name)\n        self._is_node = self.policy_str.startswith(NODE_PART_POLICY)\n\n    @property\n    def policy_str(self):\n        \"\"\"Get the policy name\n\n        Returns\n        -------\n        str\n            The name of the partition policy.\n        \"\"\"\n        return self._policy_str\n\n    @property\n    def type_name(self):\n        \"\"\"Get the type name: ntype or etype\n\n        Returns\n        -------\n        str or (str, str, str)\n            The ntype or etype.\n        \"\"\"\n        return self._type_name\n\n    @property\n    def part_id(self):\n        \"\"\"Get partition ID\n\n        Returns\n        -------\n        int\n            The partition ID\n        \"\"\"\n        return self._part_id\n\n    @property\n    def partition_book(self):\n        \"\"\"Get partition book\n\n        Returns\n        -------\n        GraphPartitionBook\n            The graph partition book\n        \"\"\"\n        return self._partition_book\n\n    @property\n    def is_node(self):\n        \"\"\"Indicate whether the policy is for node or edge\n\n        Returns\n        -------\n        bool\n            node or edge\n        \"\"\"\n        return self._is_node\n\n    def get_data_name(self, name):\n        \"\"\"Get HeteroDataName\"\"\"\n        return HeteroDataName(self.is_node, self.type_name, name)\n\n    def to_local(self, id_tensor):\n        \"\"\"Mapping global ID to local ID.\n\n        Parameters\n        ----------\n        id_tensor : tensor\n            Gloabl ID tensor\n\n        Return\n        ------\n        tensor\n            local ID tensor\n        \"\"\"\n        if self.is_node:\n            return self._partition_book.nid2localnid(\n                id_tensor, self._part_id, self.type_name\n            )\n        else:\n            return 
self._partition_book.eid2localeid(\n                id_tensor, self._part_id, self.type_name\n            )\n\n    def to_partid(self, id_tensor):\n        \"\"\"Mapping global ID to partition ID.\n\n        Parameters\n        ----------\n        id_tensor : tensor\n            Global ID tensor\n\n        Return\n        ------\n        tensor\n            partition ID\n        \"\"\"\n        if self.is_node:\n            return self._partition_book.nid2partid(id_tensor, self.type_name)\n        else:\n            return self._partition_book.eid2partid(id_tensor, self.type_name)\n\n    def get_part_size(self):\n        \"\"\"Get data size of current partition.\n\n        Returns\n        -------\n        int\n            data size\n        \"\"\"\n        if self.is_node:\n            return len(\n                self._partition_book.partid2nids(self._part_id, self.type_name)\n            )\n        else:\n            return len(\n                self._partition_book.partid2eids(self._part_id, self.type_name)\n            )\n\n    def get_size(self):\n        \"\"\"Get the full size of the data.\n\n        Returns\n        -------\n        int\n            data size\n        \"\"\"\n        if self.is_node:\n            return self._partition_book._num_nodes(self.type_name)\n        else:\n            return self._partition_book._num_edges(self.type_name)\n\n\nclass NodePartitionPolicy(PartitionPolicy):\n    \"\"\"Partition policy for nodes.\"\"\"\n\n    def __init__(self, partition_book, ntype=DEFAULT_NTYPE):\n        super(NodePartitionPolicy, self).__init__(\n            NODE_PART_POLICY + POLICY_DELIMITER + ntype, partition_book\n        )\n\n\nclass EdgePartitionPolicy(PartitionPolicy):\n    \"\"\"Partition policy for edges.\"\"\"\n\n    def __init__(self, partition_book, etype=DEFAULT_ETYPE):\n        assert (\n            isinstance(etype, tuple) and len(etype) == 3\n        ), f\"Expect canonical edge type in a triplet of string, but got {etype}.\"\n      
  super(EdgePartitionPolicy, self).__init__(\n            EDGE_PART_POLICY + POLICY_DELIMITER + _etype_tuple_to_str(etype),\n            partition_book,\n        )\n\n\nclass HeteroDataName(object):\n    \"\"\"The data name in a heterogeneous graph.\n\n    A unique data name has three components:\n    * indicate it's node data or edge data.\n    * indicate the node/edge type.\n    * the name of the data.\n\n    Parameters\n    ----------\n    is_node : bool\n        Indicate whether it's node data or edge data.\n    entity_type : str or (str, str, str)\n        The type of the node/edge.\n    data_name : str\n        The name of the data.\n    \"\"\"\n\n    def __init__(self, is_node, entity_type, data_name):\n        self._policy = NODE_PART_POLICY if is_node else EDGE_PART_POLICY\n        if not is_node:\n            assert isinstance(entity_type, tuple) and len(entity_type) == 3, (\n                \"Expect canonical edge type in a triplet of string, but got \"\n                f\"{entity_type}.\"\n            )\n        self._entity_type = entity_type\n        self.data_name = data_name\n\n    @property\n    def policy_str(self):\n        \"\"\"concatenate policy and entity type into string\"\"\"\n        entity_type = self.get_type()\n        if self.is_edge():\n            entity_type = _etype_tuple_to_str(entity_type)\n        return self._policy + POLICY_DELIMITER + entity_type\n\n    def is_node(self):\n        \"\"\"Is this the name of node data\"\"\"\n        return self._policy == NODE_PART_POLICY\n\n    def is_edge(self):\n        \"\"\"Is this the name of edge data\"\"\"\n        return self._policy == EDGE_PART_POLICY\n\n    def get_type(self):\n        \"\"\"The type of the node/edge.\n        This is only meaningful in a heterogeneous graph.\n        In homogeneous graph, type is '_N' for a node and '_N:_E:_N' for an\n        edge.\n        \"\"\"\n        return self._entity_type\n\n    def get_name(self):\n        \"\"\"The name of the 
data.\"\"\"\n        return self.data_name\n\n    def __str__(self):\n        \"\"\"The full name of the data.\n\n        The full name is used as the key in the KVStore.\n        \"\"\"\n        return self.policy_str + POLICY_DELIMITER + self.data_name\n\n\ndef parse_hetero_data_name(name):\n    \"\"\"Parse data name and create HeteroDataName.\n\n    The data name has a specialized format. We can parse the name to determine if\n    it's node data or edge data, node/edge type and its actual name. The data name\n    has three fields and they are separated by \":\".\n\n    Parameters\n    ----------\n    name : str\n        The data name\n\n    Returns\n    -------\n    HeteroDataName\n    \"\"\"\n    names = name.split(POLICY_DELIMITER)\n    assert len(names) == 3, \"{} is not a valid heterograph data name\".format(\n        name\n    )\n    assert names[0] in (\n        NODE_PART_POLICY,\n        EDGE_PART_POLICY,\n    ), \"{} is not a valid heterograph data name\".format(name)\n    is_node = names[0] == NODE_PART_POLICY\n    entity_type = names[1]\n    if not is_node:\n        entity_type = _etype_str_to_tuple(entity_type)\n    return HeteroDataName(is_node, entity_type, names[2])\n"
  },
  {
    "path": "python/dgl/distributed/graph_services.py",
    "content": "\"\"\"A set of graph services of getting subgraphs from DistGraph\"\"\"\n\nimport os\nfrom collections import namedtuple\n\nimport numpy as np\n\nimport torch\n\nfrom .. import backend as F, graphbolt as gb\nfrom ..base import EID, ETYPE, NID\nfrom ..convert import graph, heterograph\nfrom ..sampling import (\n    sample_etype_neighbors as local_sample_etype_neighbors,\n    sample_neighbors as local_sample_neighbors,\n)\nfrom ..subgraph import in_subgraph as local_in_subgraph\nfrom ..utils import toindex\nfrom .constants import DGL2GB_EID, GB_DST_ID\nfrom .rpc import (\n    recv_responses,\n    register_service,\n    Request,\n    Response,\n    send_requests_to_machine,\n)\n\n__all__ = [\n    \"sample_neighbors\",\n    \"sample_etype_neighbors\",\n    \"in_subgraph\",\n    \"find_edges\",\n]\n\nSAMPLING_SERVICE_ID = 6657\nINSUBGRAPH_SERVICE_ID = 6658\nEDGES_SERVICE_ID = 6659\nOUTDEGREE_SERVICE_ID = 6660\nINDEGREE_SERVICE_ID = 6661\nETYPE_SAMPLING_SERVICE_ID = 6662\n\n\nclass SubgraphResponse(Response):\n    \"\"\"The response for sampling and in_subgraph\"\"\"\n\n    def __init__(\n        self, global_src, global_dst, *, global_eids=None, etype_ids=None\n    ):\n        self.global_src = global_src\n        self.global_dst = global_dst\n        self.global_eids = global_eids\n        self.etype_ids = etype_ids\n\n    def __setstate__(self, state):\n        (\n            self.global_src,\n            self.global_dst,\n            self.global_eids,\n            self.etype_ids,\n        ) = state\n\n    def __getstate__(self):\n        return (\n            self.global_src,\n            self.global_dst,\n            self.global_eids,\n            self.etype_ids,\n        )\n\n\nclass FindEdgeResponse(Response):\n    \"\"\"The response for sampling and in_subgraph\"\"\"\n\n    def __init__(self, global_src, global_dst, order_id):\n        self.global_src = global_src\n        self.global_dst = global_dst\n        self.order_id = order_id\n\n    def 
__setstate__(self, state):\n        self.global_src, self.global_dst, self.order_id = state\n\n    def __getstate__(self):\n        return self.global_src, self.global_dst, self.order_id\n\n\ndef _sample_neighbors_graphbolt(\n    g,\n    gpb,\n    nodes,\n    fanout,\n    edge_dir=\"in\",\n    prob=None,\n    exclude_edges=None,\n    replace=False,\n):\n    \"\"\"Sample from local partition via graphbolt.\n\n    The input nodes use global IDs. We need to map the global node IDs to local\n    node IDs, perform sampling and map the sampled results to the global IDs\n    space again. The sampled results are stored in three vectors that store\n    source nodes, destination nodes, etype IDs and edge IDs.\n\n    Parameters\n    ----------\n    g : FusedCSCSamplingGraph\n        The local partition.\n    gpb : GraphPartitionBook\n        The graph partition book.\n    nodes : tensor\n        The nodes to sample neighbors from.\n    fanout : tensor or int\n        The number of edges to be sampled for each node.\n    edge_dir : str, optional\n        Determines whether to sample inbound or outbound edges.\n    prob : tensor, optional\n        The probability associated with each neighboring edge of a node.\n    exclude_edges : tensor, optional\n        The edges to exclude when sampling.\n    replace : bool, optional\n        If True, sample with replacement.\n\n    Returns\n    -------\n    tensor\n        The source node ID array.\n    tensor\n        The destination node ID array.\n    tensor\n        The edge ID array.\n    tensor\n        The edge type ID array.\n    \"\"\"\n    assert (\n        edge_dir == \"in\"\n    ), f\"GraphBolt only supports inbound edge sampling but got {edge_dir}.\"\n    assert exclude_edges is None, \"GraphBolt does not support excluding edges.\"\n\n    # 1. 
Map global node IDs to local node IDs.\n    nodes = gpb.nid2localnid(nodes, gpb.partid)\n    # Local partition may be saved in torch.int32 even though the global graph\n    # is in torch.int64.\n    nodes = nodes.to(dtype=g.indices.dtype)\n\n    # 2. Perform sampling.\n    probs_or_mask = None\n    if prob is not None:\n        probs_or_mask = g.edge_attributes[prob]\n    # Sanity checks.\n    assert isinstance(\n        g, gb.FusedCSCSamplingGraph\n    ), \"Expect a FusedCSCSamplingGraph.\"\n    assert isinstance(nodes, torch.Tensor), \"Expect a tensor of nodes.\"\n    if isinstance(fanout, int):\n        fanout = torch.LongTensor([fanout])\n    assert isinstance(fanout, torch.Tensor), \"Expect a tensor of fanout.\"\n\n    subgraph = g._sample_neighbors(\n        nodes,\n        None,\n        fanout,\n        replace=replace,\n        probs_or_mask=probs_or_mask,\n    )\n\n    # 3. Map local node IDs to global node IDs.\n    local_src = subgraph.indices\n    local_dst = gb.expand_indptr(\n        subgraph.indptr,\n        dtype=local_src.dtype,\n        node_ids=subgraph.original_column_node_ids,\n        output_size=local_src.shape[0],\n    )\n    global_nid_mapping = g.node_attributes[NID]\n    global_src = global_nid_mapping[local_src]\n    global_dst = global_nid_mapping[local_dst]\n\n    global_eids = None\n    if g.edge_attributes is not None and EID in g.edge_attributes:\n        global_eids = g.edge_attributes[EID][subgraph.original_edge_ids]\n    return LocalSampledGraph(\n        global_src, global_dst, global_eids, subgraph.type_per_edge\n    )\n\n\ndef _sample_neighbors_dgl(\n    local_g,\n    partition_book,\n    seed_nodes,\n    fan_out,\n    edge_dir=\"in\",\n    prob=None,\n    exclude_edges=None,\n    replace=False,\n):\n    \"\"\"Sample from local partition.\n\n    The input nodes use global IDs. 
We need to map the global node IDs to local node IDs,\n    perform sampling and map the sampled results to the global IDs space again.\n    The sampled results are stored in three vectors that store source nodes, destination nodes\n    and edge IDs.\n    \"\"\"\n    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)\n    local_ids = F.astype(local_ids, local_g.idtype)\n    # local_ids = self.seed_nodes\n    sampled_graph = local_sample_neighbors(\n        local_g,\n        local_ids,\n        fan_out,\n        edge_dir=edge_dir,\n        prob=prob,\n        exclude_edges=exclude_edges,\n        replace=replace,\n        _dist_training=True,\n    )\n    global_nid_mapping = local_g.ndata[NID]\n    src, dst = sampled_graph.edges()\n    global_src, global_dst = F.gather_row(\n        global_nid_mapping, src\n    ), F.gather_row(global_nid_mapping, dst)\n    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])\n    return LocalSampledGraph(global_src, global_dst, global_eids)\n\n\ndef _sample_neighbors(use_graphbolt, *args, **kwargs):\n    \"\"\"Wrapper for sampling neighbors.\n\n    The actual sampling function depends on whether to use GraphBolt.\n\n    Parameters\n    ----------\n    use_graphbolt : bool\n        Whether to use GraphBolt for sampling.\n    args : list\n        The arguments for the sampling function.\n    kwargs : dict\n        The keyword arguments for the sampling function.\n\n    Returns\n    -------\n    tensor\n        The source node ID array.\n    tensor\n        The destination node ID array.\n    tensor\n        The edge ID array.\n    tensor\n        The edge type ID array.\n    \"\"\"\n    func = (\n        _sample_neighbors_graphbolt if use_graphbolt else _sample_neighbors_dgl\n    )\n    return func(*args, **kwargs)\n\n\ndef _sample_etype_neighbors_dgl(\n    local_g,\n    partition_book,\n    seed_nodes,\n    fan_out,\n    edge_dir=\"in\",\n    prob=None,\n    exclude_edges=None,\n    
replace=False,\n    etype_offset=None,\n    etype_sorted=False,\n):\n    \"\"\"Sample from local partition.\n\n    The input nodes use global IDs. We need to map the global node IDs to local node IDs,\n    perform sampling and map the sampled results to the global IDs space again.\n    The sampled results are stored in three vectors that store source nodes, destination nodes\n    and edge IDs.\n    \"\"\"\n    assert etype_offset is not None, \"The etype offset is not provided.\"\n\n    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)\n    local_ids = F.astype(local_ids, local_g.idtype)\n\n    sampled_graph = local_sample_etype_neighbors(\n        local_g,\n        local_ids,\n        etype_offset,\n        fan_out,\n        edge_dir=edge_dir,\n        prob=prob,\n        exclude_edges=exclude_edges,\n        replace=replace,\n        etype_sorted=etype_sorted,\n        _dist_training=True,\n    )\n    global_nid_mapping = local_g.ndata[NID]\n    src, dst = sampled_graph.edges()\n    global_src, global_dst = F.gather_row(\n        global_nid_mapping, src\n    ), F.gather_row(global_nid_mapping, dst)\n    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])\n    return LocalSampledGraph(global_src, global_dst, global_eids)\n\n\ndef _sample_etype_neighbors(use_graphbolt, *args, **kwargs):\n    \"\"\"Wrapper for sampling etype neighbors.\n\n    The actual sampling function depends on whether to use GraphBolt.\n\n    Parameters\n    ----------\n    use_graphbolt : bool\n        Whether to use GraphBolt for sampling.\n    args : list\n        The arguments for the sampling function.\n    kwargs : dict\n        The keyword arguments for the sampling function.\n\n    Returns\n    -------\n    tensor\n        The source node ID array.\n    tensor\n        The destination node ID array.\n    tensor\n        The edge ID array.\n    tensor\n        The edge type ID array.\n    \"\"\"\n    func = (\n        
_sample_neighbors_graphbolt\n        if use_graphbolt\n        else _sample_etype_neighbors_dgl\n    )\n    if use_graphbolt:\n        # GraphBolt does not require `etype_offset` and `etype_sorted`.\n        kwargs.pop(\"etype_offset\", None)\n        kwargs.pop(\"etype_sorted\", None)\n    return func(*args, **kwargs)\n\n\ndef _find_edges(local_g, partition_book, seed_edges):\n    \"\"\"Given an edge ID array, return the source\n    and destination node ID array ``s`` and ``d`` in the local partition.\n    \"\"\"\n    local_eids = partition_book.eid2localeid(seed_edges, partition_book.partid)\n    if isinstance(local_g, gb.FusedCSCSamplingGraph):\n        # When converting from DGLGraph to FusedCSCSamplingGraph, the edge IDs\n        # are re-ordered. In order to find the correct node pairs, we need to\n        # map the DGL edge IDs back to GraphBolt edge IDs.\n        if (\n            DGL2GB_EID not in local_g.edge_attributes\n            or GB_DST_ID not in local_g.edge_attributes\n        ):\n            raise ValueError(\n                \"The edge attributes DGL2GB_EID and GB_DST_ID are not found. 
\"\n                \"Please make sure `coo` format is available when generating \"\n                \"partitions in GraphBolt format.\"\n            )\n        local_eids = local_g.edge_attributes[DGL2GB_EID][local_eids]\n        local_src = local_g.indices[local_eids]\n        local_dst = local_g.edge_attributes[GB_DST_ID][local_eids]\n        global_nid_mapping = local_g.node_attributes[NID]\n    else:\n        local_eids = F.astype(local_eids, local_g.idtype)\n        local_src, local_dst = local_g.find_edges(local_eids)\n        global_nid_mapping = local_g.ndata[NID]\n    global_src = global_nid_mapping[local_src]\n    global_dst = global_nid_mapping[local_dst]\n    return global_src, global_dst\n\n\ndef _in_degrees(local_g, partition_book, n):\n    \"\"\"Get in-degree of the nodes in the local partition.\"\"\"\n    local_nids = partition_book.nid2localnid(n, partition_book.partid)\n    local_nids = F.astype(local_nids, local_g.idtype)\n    return local_g.in_degrees(local_nids)\n\n\ndef _out_degrees(local_g, partition_book, n):\n    \"\"\"Get out-degree of the nodes in the local partition.\"\"\"\n    local_nids = partition_book.nid2localnid(n, partition_book.partid)\n    local_nids = F.astype(local_nids, local_g.idtype)\n    return local_g.out_degrees(local_nids)\n\n\ndef _in_subgraph(local_g, partition_book, seed_nodes):\n    \"\"\"Get in subgraph from local partition.\n\n    The input nodes use global IDs. 
We need to map the global node IDs to local node IDs,\n    get in-subgraph and map the sampled results to the global IDs space again.\n    The results are stored in three vectors that store source nodes, destination nodes\n    and edge IDs.\n    \"\"\"\n    local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)\n    local_ids = F.astype(local_ids, local_g.idtype)\n    # local_ids = self.seed_nodes\n    sampled_graph = local_in_subgraph(local_g, local_ids)\n    global_nid_mapping = local_g.ndata[NID]\n    src, dst = sampled_graph.edges()\n    global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst]\n    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])\n    return LocalSampledGraph(global_src, global_dst, global_eids)\n\n\n# --- NOTE 1 ---\n# (BarclayII)\n# If the sampling algorithm needs node and edge data, ideally the\n# algorithm should query the underlying feature storage to get what it\n# just needs to complete the job.  For instance, with\n# sample_etype_neighbors, we only need the probability of the seed nodes'\n# neighbors.\n#\n# However, right now we are reusing the existing subgraph sampling\n# interfaces of DGLGraph (i.e. single machine solution), which needs\n# the data of *all* the nodes/edges.  Going distributed, we now need\n# the node/edge data of the *entire* local graph partition.\n#\n# If the sampling algorithm only use edge data, the current design works\n# because the local graph partition contains all the in-edges of the\n# assigned nodes as well as the data.  This is the case for\n# sample_etype_neighbors.\n#\n# However, if the sampling algorithm requires data of the neighbor nodes\n# (e.g. sample_neighbors_biased which performs biased sampling based on the\n# type of the neighbor nodes), the current design will fail because the\n# neighbor nodes (hence the data) may not belong to the current partition.\n# This is a limitation of the current DistDGL design.  
We should improve it\n# later.\n\n\nclass SamplingRequest(Request):\n    \"\"\"Sampling Request\"\"\"\n\n    def __init__(\n        self,\n        nodes,\n        fan_out,\n        edge_dir=\"in\",\n        prob=None,\n        exclude_edges=None,\n        replace=False,\n        use_graphbolt=False,\n    ):\n        self.seed_nodes = nodes\n        self.edge_dir = edge_dir\n        self.prob = prob\n        self.exclude_edges = exclude_edges\n        self.replace = replace\n        self.fan_out = fan_out\n        self.use_graphbolt = use_graphbolt\n\n    def __setstate__(self, state):\n        (\n            self.seed_nodes,\n            self.edge_dir,\n            self.prob,\n            self.exclude_edges,\n            self.replace,\n            self.fan_out,\n            self.use_graphbolt,\n        ) = state\n\n    def __getstate__(self):\n        return (\n            self.seed_nodes,\n            self.edge_dir,\n            self.prob,\n            self.exclude_edges,\n            self.replace,\n            self.fan_out,\n            self.use_graphbolt,\n        )\n\n    def process_request(self, server_state):\n        local_g = server_state.graph\n        partition_book = server_state.partition_book\n        kv_store = server_state.kv_store\n        if self.prob is not None and (not self.use_graphbolt):\n            prob = [kv_store.data_store[self.prob]]\n        else:\n            prob = self.prob\n        res = _sample_neighbors(\n            self.use_graphbolt,\n            local_g,\n            partition_book,\n            self.seed_nodes,\n            self.fan_out,\n            edge_dir=self.edge_dir,\n            prob=prob,\n            exclude_edges=self.exclude_edges,\n            replace=self.replace,\n        )\n        return SubgraphResponse(\n            res.global_src,\n            res.global_dst,\n            global_eids=res.global_eids,\n            etype_ids=res.etype_ids,\n        )\n\n\nclass SamplingRequestEtype(Request):\n    
\"\"\"Sampling Request\"\"\"\n\n    def __init__(\n        self,\n        nodes,\n        fan_out,\n        edge_dir=\"in\",\n        prob=None,\n        exclude_edges=None,\n        replace=False,\n        etype_sorted=True,\n        use_graphbolt=False,\n    ):\n        self.seed_nodes = nodes\n        self.edge_dir = edge_dir\n        self.prob = prob\n        self.exclude_edges = exclude_edges\n        self.replace = replace\n        self.fan_out = fan_out\n        self.etype_sorted = etype_sorted\n        self.use_graphbolt = use_graphbolt\n\n    def __setstate__(self, state):\n        (\n            self.seed_nodes,\n            self.edge_dir,\n            self.prob,\n            self.exclude_edges,\n            self.replace,\n            self.fan_out,\n            self.etype_sorted,\n            self.use_graphbolt,\n        ) = state\n\n    def __getstate__(self):\n        return (\n            self.seed_nodes,\n            self.edge_dir,\n            self.prob,\n            self.exclude_edges,\n            self.replace,\n            self.fan_out,\n            self.etype_sorted,\n            self.use_graphbolt,\n        )\n\n    def process_request(self, server_state):\n        local_g = server_state.graph\n        partition_book = server_state.partition_book\n        kv_store = server_state.kv_store\n        etype_offset = partition_book.local_etype_offset\n        # See NOTE 1\n        if self.prob is not None and (not self.use_graphbolt):\n            probs = [\n                kv_store.data_store[key] if key != \"\" else None\n                for key in self.prob\n            ]\n        else:\n            probs = self.prob\n        res = _sample_etype_neighbors(\n            self.use_graphbolt,\n            local_g,\n            partition_book,\n            self.seed_nodes,\n            self.fan_out,\n            edge_dir=self.edge_dir,\n            prob=probs,\n            exclude_edges=self.exclude_edges,\n            replace=self.replace,\n            
etype_offset=etype_offset,\n            etype_sorted=self.etype_sorted,\n        )\n        return SubgraphResponse(\n            res.global_src,\n            res.global_dst,\n            global_eids=res.global_eids,\n            etype_ids=res.etype_ids,\n        )\n\n\nclass EdgesRequest(Request):\n    \"\"\"Edges Request\"\"\"\n\n    def __init__(self, edge_ids, order_id):\n        self.edge_ids = edge_ids\n        self.order_id = order_id\n\n    def __setstate__(self, state):\n        self.edge_ids, self.order_id = state\n\n    def __getstate__(self):\n        return self.edge_ids, self.order_id\n\n    def process_request(self, server_state):\n        local_g = server_state.graph\n        partition_book = server_state.partition_book\n        global_src, global_dst = _find_edges(\n            local_g, partition_book, self.edge_ids\n        )\n\n        return FindEdgeResponse(global_src, global_dst, self.order_id)\n\n\nclass InDegreeRequest(Request):\n    \"\"\"In-degree Request\"\"\"\n\n    def __init__(self, n, order_id):\n        self.n = n\n        self.order_id = order_id\n\n    def __setstate__(self, state):\n        self.n, self.order_id = state\n\n    def __getstate__(self):\n        return self.n, self.order_id\n\n    def process_request(self, server_state):\n        local_g = server_state.graph\n        partition_book = server_state.partition_book\n        deg = _in_degrees(local_g, partition_book, self.n)\n\n        return InDegreeResponse(deg, self.order_id)\n\n\nclass InDegreeResponse(Response):\n    \"\"\"The response for in-degree\"\"\"\n\n    def __init__(self, deg, order_id):\n        self.val = deg\n        self.order_id = order_id\n\n    def __setstate__(self, state):\n        self.val, self.order_id = state\n\n    def __getstate__(self):\n        return self.val, self.order_id\n\n\nclass OutDegreeRequest(Request):\n    \"\"\"Out-degree Request\"\"\"\n\n    def __init__(self, n, order_id):\n        self.n = n\n        self.order_id = 
order_id\n\n    def __setstate__(self, state):\n        self.n, self.order_id = state\n\n    def __getstate__(self):\n        return self.n, self.order_id\n\n    def process_request(self, server_state):\n        local_g = server_state.graph\n        partition_book = server_state.partition_book\n        deg = _out_degrees(local_g, partition_book, self.n)\n\n        return OutDegreeResponse(deg, self.order_id)\n\n\nclass OutDegreeResponse(Response):\n    \"\"\"The response for out-degree\"\"\"\n\n    def __init__(self, deg, order_id):\n        self.val = deg\n        self.order_id = order_id\n\n    def __setstate__(self, state):\n        self.val, self.order_id = state\n\n    def __getstate__(self):\n        return self.val, self.order_id\n\n\nclass InSubgraphRequest(Request):\n    \"\"\"InSubgraph Request\"\"\"\n\n    def __init__(self, nodes):\n        self.seed_nodes = nodes\n\n    def __setstate__(self, state):\n        self.seed_nodes = state\n\n    def __getstate__(self):\n        return self.seed_nodes\n\n    def process_request(self, server_state):\n        local_g = server_state.graph\n        partition_book = server_state.partition_book\n        global_src, global_dst, global_eids = _in_subgraph(\n            local_g, partition_book, self.seed_nodes\n        )\n        return SubgraphResponse(global_src, global_dst, global_eids=global_eids)\n\n\ndef merge_graphs(res_list, num_nodes, exclude_edges=None):\n    \"\"\"Merge request from multiple servers\"\"\"\n    if len(res_list) > 1:\n        srcs = []\n        dsts = []\n        eids = []\n        etype_ids = []\n        for res in res_list:\n            srcs.append(res.global_src)\n            dsts.append(res.global_dst)\n            eids.append(res.global_eids)\n            etype_ids.append(res.etype_ids)\n        src_tensor = F.cat(srcs, 0)\n        dst_tensor = F.cat(dsts, 0)\n        eid_tensor = None if eids[0] is None else F.cat(eids, 0)\n        etype_id_tensor = None if etype_ids[0] is None else 
F.cat(etype_ids, 0)\n    else:\n        src_tensor = res_list[0].global_src\n        dst_tensor = res_list[0].global_dst\n        eid_tensor = res_list[0].global_eids\n        etype_id_tensor = res_list[0].etype_ids\n    if exclude_edges is not None:\n        mask = torch.isin(\n            eid_tensor, exclude_edges, assume_unique=True, invert=True\n        )\n        src_tensor = src_tensor[mask]\n        dst_tensor = dst_tensor[mask]\n        eid_tensor = eid_tensor[mask]\n        if etype_id_tensor is not None:\n            etype_id_tensor = etype_id_tensor[mask]\n    g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)\n    if eid_tensor is not None:\n        g.edata[EID] = eid_tensor\n    if etype_id_tensor is not None:\n        g.edata[ETYPE] = etype_id_tensor\n    return g\n\n\nLocalSampledGraph = namedtuple(  # pylint: disable=unexpected-keyword-arg\n    \"LocalSampledGraph\",\n    \"global_src global_dst global_eids etype_ids\",\n    defaults=(None, None, None, None),\n)\n\n\ndef _distributed_access(\n    g, nodes, issue_remote_req, local_access, exclude_edges=None\n):\n    \"\"\"A routine that fetches local neighborhood of nodes from the distributed graph.\n\n    The local neighborhood of some nodes are stored in the local machine and the other\n    nodes have their neighborhood on remote machines. This code will issue remote\n    access requests first before fetching data from the local machine. 
In the end,\n    we combine the data from the local machine and remote machines.\n    In this way, we can hide the latency of accessing data on remote machines.\n\n    Parameters\n    ----------\n    g : DistGraph\n        The distributed graph\n    nodes : tensor\n        The nodes whose neighborhood are to be fetched.\n    issue_remote_req : callable\n        The function that issues requests to access remote data.\n    local_access : callable\n        The function that reads data on the local machine.\n    exclude_edges : tensor\n        The edges to exclude after sampling.\n\n    Returns\n    -------\n    DGLGraph\n        The subgraph that contains the neighborhoods of all input nodes.\n    \"\"\"\n    req_list = []\n    partition_book = g.get_partition_book()\n    if not isinstance(nodes, torch.Tensor):\n        nodes = toindex(nodes).tousertensor()\n    partition_id = partition_book.nid2partid(nodes)\n    local_nids = None\n    for pid in range(partition_book.num_partitions()):\n        node_id = F.boolean_mask(nodes, partition_id == pid)\n        # We optimize the sampling on a local partition if the server and the client\n        # run on the same machine. With a good partitioning, most of the seed nodes\n        # should reside in the local partition. 
If the server and the client\n        # are not co-located, the client doesn't have a local partition.\n        if pid == partition_book.partid and g.local_partition is not None:\n            assert local_nids is None\n            local_nids = node_id\n        elif len(node_id) != 0:\n            req = issue_remote_req(node_id)\n            req_list.append((pid, req))\n\n    # send requests to the remote machine.\n    msgseq2pos = None\n    if len(req_list) > 0:\n        msgseq2pos = send_requests_to_machine(req_list)\n\n    # sample neighbors for the nodes in the local partition.\n    res_list = []\n    if local_nids is not None:\n        res = local_access(g.local_partition, partition_book, local_nids)\n        res_list.append(res)\n\n    # receive responses from remote machines.\n    if msgseq2pos is not None:\n        results = recv_responses(msgseq2pos)\n        res_list.extend(results)\n\n    sampled_graph = merge_graphs(\n        res_list, g.num_nodes(), exclude_edges=exclude_edges\n    )\n    return sampled_graph\n\n\ndef _frontier_to_heterogeneous_graph(g, frontier, gpb):\n    # We need to handle empty frontiers correctly.\n    if frontier.num_edges() == 0:\n        data_dict = {\n            etype: (np.zeros(0), np.zeros(0)) for etype in g.canonical_etypes\n        }\n        return heterograph(\n            data_dict,\n            {ntype: g.num_nodes(ntype) for ntype in g.ntypes},\n            idtype=g.idtype,\n        )\n\n    # For DGL partitions, the global edge IDs are always stored in the edata.\n    # For GraphBolt partitions, the edge type IDs are always stored in the\n    # edata. As for the edge IDs, they are stored in the edata if the graph is\n    # partitioned with `store_eids=True`. 
Otherwise, the edge IDs are not\n    # stored.\n    etype_ids, type_wise_eids = (\n        gpb.map_to_per_etype(frontier.edata[EID])\n        if EID in frontier.edata\n        else (frontier.edata[ETYPE], None)\n    )\n    etype_ids, idx = F.sort_1d(etype_ids)\n    if type_wise_eids is not None:\n        type_wise_eids = F.gather_row(type_wise_eids, idx)\n\n    # Sort the edges by their edge types.\n    src, dst = frontier.edges()\n    src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)\n    src_ntype_ids, src = gpb.map_to_per_ntype(src)\n    dst_ntype_ids, dst = gpb.map_to_per_ntype(dst)\n\n    data_dict = dict()\n    edge_ids = {}\n    for etid, etype in enumerate(g.canonical_etypes):\n        src_ntype, _, dst_ntype = etype\n        src_ntype_id = g.get_ntype_id(src_ntype)\n        dst_ntype_id = g.get_ntype_id(dst_ntype)\n        type_idx = etype_ids == etid\n        data_dict[etype] = (\n            F.boolean_mask(src, type_idx),\n            F.boolean_mask(dst, type_idx),\n        )\n        if \"DGL_DIST_DEBUG\" in os.environ:\n            assert torch.all(\n                src_ntype_id == src_ntype_ids[type_idx]\n            ), \"source ntype is is not expected.\"\n            assert torch.all(\n                dst_ntype_id == dst_ntype_ids[type_idx]\n            ), \"destination ntype is is not expected.\"\n        if type_wise_eids is not None:\n            edge_ids[etype] = F.boolean_mask(type_wise_eids, type_idx)\n    hg = heterograph(\n        data_dict,\n        {ntype: g.num_nodes(ntype) for ntype in g.ntypes},\n        idtype=g.idtype,\n    )\n\n    for etype in edge_ids:\n        hg.edges[etype].data[EID] = edge_ids[etype]\n    return hg\n\n\ndef sample_etype_neighbors(\n    g,\n    nodes,\n    fanout,\n    edge_dir=\"in\",\n    prob=None,\n    exclude_edges=None,\n    replace=False,\n    etype_sorted=True,\n    use_graphbolt=False,\n):\n    \"\"\"Sample from the neighbors of the given nodes from a distributed graph.\n\n    For each node, a 
number of inbound (or outbound when ``edge_dir == 'out'``) edges\n    will be randomly chosen.  The returned graph will contain all the nodes in the\n    original graph, but only the sampled edges.\n\n    Node/edge features are not preserved. The original IDs of\n    the sampled edges are stored as the `dgl.EID` feature in the returned graph.\n\n    This function assumes the input is a homogeneous ``DGLGraph`` with the edges\n    ordered by their edge types. The sampled subgraph is also\n    stored in the homogeneous graph format. That is, all nodes and edges are assigned\n    with unique IDs (in contrast, we typically use a type name and a node/edge ID to\n    identify a node or an edge in ``DGLGraph``). We refer to this type of IDs\n    as *homogeneous ID*.\n    Users can use :func:`dgl.distributed.GraphPartitionBook.map_to_per_ntype`\n    and :func:`dgl.distributed.GraphPartitionBook.map_to_per_etype`\n    to identify their node/edge types and node/edge IDs of that type.\n\n    Parameters\n    ----------\n    g : DistGraph\n        The distributed graph..\n    nodes : tensor or dict\n        Node IDs to sample neighbors from. If it's a dict, it should contain only\n        one key-value pair to make this API consistent with dgl.sampling.sample_neighbors.\n    fanout : int or dict[etype, int]\n        The number of edges to be sampled for each node per edge type.  If an integer\n        is given, DGL assumes that the same fanout is applied to every edge type.\n\n        If -1 is given, all of the neighbors will be selected.\n    edge_dir : str, optional\n        Determines whether to sample inbound or outbound edges.\n\n        Can take either ``in`` for inbound edges or ``out`` for outbound edges.\n    prob : str, optional\n        Feature name used as the (unnormalized) probabilities associated with each\n        neighboring edge of a node.  
The feature must have only one element for each\n        edge.\n\n        The features must be non-negative floats, and the sum of the features of\n        inbound/outbound edges for every node must be positive (though they don't have\n        to sum up to one).  Otherwise, the result will be undefined.\n    exclude_edges : tensor, optional\n        The edges to exclude when sampling. Homogeneous edge IDs are used.\n    replace : bool, optional\n        If True, sample with replacement.\n\n        When sampling with replacement, the sampled subgraph could have parallel edges.\n\n        For sampling without replacement, if fanout > the number of neighbors, all the\n        neighbors are sampled. If fanout == -1, all neighbors are collected.\n    etype_sorted : bool, optional\n        Indicates whether etypes are sorted.\n    use_graphbolt : bool, optional\n        Whether to use GraphBolt for sampling.\n\n    Returns\n    -------\n    DGLGraph\n        A sampled subgraph containing only the sampled neighboring edges.  It is on CPU.\n    \"\"\"\n    if isinstance(fanout, int):\n        fanout = F.full_1d(len(g.canonical_etypes), fanout, F.int64, F.cpu())\n    else:\n        etype_ids = {etype: i for i, etype in enumerate(g.canonical_etypes)}\n        fanout_array = [None] * len(g.canonical_etypes)\n        for etype, v in fanout.items():\n            c_etype = g.to_canonical_etype(etype)\n            fanout_array[etype_ids[c_etype]] = v\n        assert all(v is not None for v in fanout_array), (\n            \"Not all etypes have valid fanout. Please make sure passed-in \"\n            \"fanout in dict includes all the etypes in graph. 
Passed-in \"\n            f\"fanout: {fanout}, graph etypes: {g.canonical_etypes}.\"\n        )\n        fanout = F.tensor(fanout_array, dtype=F.int64)\n\n    gpb = g.get_partition_book()\n    if isinstance(nodes, dict):\n        homo_nids = []\n        for ntype in nodes.keys():\n            assert (\n                ntype in g.ntypes\n            ), \"The sampled node type {} does not exist in the input graph\".format(\n                ntype\n            )\n            if F.is_tensor(nodes[ntype]):\n                typed_nodes = nodes[ntype]\n            else:\n                typed_nodes = toindex(nodes[ntype]).tousertensor()\n            homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))\n        nodes = F.cat(homo_nids, 0)\n\n    def issue_remote_req(node_ids):\n        if prob is not None and (not use_graphbolt):\n            # See NOTE 1\n            _prob = [\n                (\n                    # NOTE (BarclayII)\n                    # Currently DistGraph.edges[] does not accept canonical etype.\n                    g.edges[etype].data[prob].kvstore_key\n                    if prob in g.edges[etype].data\n                    else \"\"\n                )\n                for etype in g.canonical_etypes\n            ]\n        else:\n            _prob = prob\n        return SamplingRequestEtype(\n            node_ids,\n            fanout,\n            edge_dir=edge_dir,\n            prob=_prob,\n            exclude_edges=None,\n            replace=replace,\n            etype_sorted=etype_sorted,\n            use_graphbolt=use_graphbolt,\n        )\n\n    def local_access(local_g, partition_book, local_nids):\n        etype_offset = gpb.local_etype_offset\n        # See NOTE 1\n        if prob is not None and (not use_graphbolt):\n            _prob = [\n                (\n                    g.edges[etype].data[prob].local_partition\n                    if prob in g.edges[etype].data\n                    else None\n                )\n                
for etype in g.canonical_etypes\n            ]\n        else:\n            _prob = prob\n        return _sample_etype_neighbors(\n            use_graphbolt,\n            local_g,\n            partition_book,\n            local_nids,\n            fanout,\n            edge_dir=edge_dir,\n            prob=_prob,\n            exclude_edges=None,\n            replace=replace,\n            etype_offset=etype_offset,\n            etype_sorted=etype_sorted,\n        )\n\n    frontier = _distributed_access(\n        g, nodes, issue_remote_req, local_access, exclude_edges=exclude_edges\n    )\n    if not gpb.is_homogeneous:\n        return _frontier_to_heterogeneous_graph(g, frontier, gpb)\n    else:\n        return frontier\n\n\ndef sample_neighbors(\n    g,\n    nodes,\n    fanout,\n    edge_dir=\"in\",\n    prob=None,\n    exclude_edges=None,\n    replace=False,\n    use_graphbolt=False,\n):\n    \"\"\"Sample from the neighbors of the given nodes from a distributed graph.\n\n    For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges\n    will be randomly chosen.  The returned graph will contain all the nodes in the\n    original graph, but only the sampled edges.\n\n    Node/edge features are not preserved. The original IDs of\n    the sampled edges are stored as the `dgl.EID` feature in the returned graph.\n\n    For heterogeneous graphs, ``nodes`` is a dictionary whose key is node type\n    and the value is type-specific node IDs.\n\n    Parameters\n    ----------\n    g : DistGraph\n        The distributed graph..\n    nodes : tensor or dict\n        Node IDs to sample neighbors from. 
If it's a dict, it should contain only\n        one key-value pair to make this API consistent with dgl.sampling.sample_neighbors.\n    fanout : int\n        The number of edges to be sampled for each node.\n\n        If -1 is given, all of the neighbors will be selected.\n    edge_dir : str, optional\n        Determines whether to sample inbound or outbound edges.\n\n        Can take either ``in`` for inbound edges or ``out`` for outbound edges.\n    prob : str, optional\n        Feature name used as the (unnormalized) probabilities associated with each\n        neighboring edge of a node.  The feature must have only one element for each\n        edge.\n\n        The features must be non-negative floats, and the sum of the features of\n        inbound/outbound edges for every node must be positive (though they don't have\n        to sum up to one).  Otherwise, the result will be undefined.\n    exclude_edges: tensor or dict, optional\n        Edge IDs to exclude during sampling neighbors for the seed nodes.\n\n        This argument can take a single ID tensor or a dictionary of edge types\n        and ID tensors. If a single tensor is given, the graph must only have\n        one type of nodes.\n    replace : bool, optional\n        If True, sample with replacement.\n\n        When sampling with replacement, the sampled subgraph could have parallel edges.\n\n        For sampling without replacement, if fanout > the number of neighbors, all the\n        neighbors are sampled. If fanout == -1, all neighbors are collected.\n    use_graphbolt : bool, optional\n        Whether to use GraphBolt for sampling.\n\n    Returns\n    -------\n    DGLGraph\n        A sampled subgraph containing only the sampled neighboring edges.  
It is on CPU.\n    \"\"\"\n    gpb = g.get_partition_book()\n    if not gpb.is_homogeneous:\n        assert isinstance(nodes, dict)\n        homo_nids = []\n        for ntype in nodes:\n            assert (\n                ntype in g.ntypes\n            ), \"The sampled node type does not exist in the input graph\"\n            if F.is_tensor(nodes[ntype]):\n                typed_nodes = nodes[ntype]\n            else:\n                typed_nodes = toindex(nodes[ntype]).tousertensor()\n            homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))\n        nodes = F.cat(homo_nids, 0)\n    elif isinstance(nodes, dict):\n        assert len(nodes) == 1\n        nodes = list(nodes.values())[0]\n\n    def issue_remote_req(node_ids):\n        if prob is not None and (not use_graphbolt):\n            # See NOTE 1\n            _prob = g.edata[prob].kvstore_key\n        else:\n            _prob = prob\n        return SamplingRequest(\n            node_ids,\n            fanout,\n            edge_dir=edge_dir,\n            prob=_prob,\n            exclude_edges=None,\n            replace=replace,\n            use_graphbolt=use_graphbolt,\n        )\n\n    def local_access(local_g, partition_book, local_nids):\n        # See NOTE 1\n        _prob = (\n            [g.edata[prob].local_partition]\n            if prob is not None and (not use_graphbolt)\n            else prob\n        )\n        return _sample_neighbors(\n            use_graphbolt,\n            local_g,\n            partition_book,\n            local_nids,\n            fanout,\n            edge_dir=edge_dir,\n            prob=_prob,\n            exclude_edges=None,\n            replace=replace,\n        )\n\n    frontier = _distributed_access(\n        g, nodes, issue_remote_req, local_access, exclude_edges=exclude_edges\n    )\n    if not gpb.is_homogeneous:\n        return _frontier_to_heterogeneous_graph(g, frontier, gpb)\n    else:\n        return frontier\n\n\ndef _distributed_edge_access(g, edges, 
issue_remote_req, local_access):\n    \"\"\"A routine that fetches local edges from distributed graph.\n\n    The source and destination nodes of local edges are stored in the local\n    machine and others are stored on remote machines. This code will issue\n    remote access requests first before fetching data from the local machine.\n    In the end, we combine the data from the local machine and remote machines.\n\n    Parameters\n    ----------\n    g : DistGraph\n        The distributed graph\n    edges : tensor\n        The edges to find their source and destination nodes.\n    issue_remote_req : callable\n        The function that issues requests to access remote data.\n    local_access : callable\n        The function that reads data on the local machine.\n\n    Returns\n    -------\n    tensor\n        The source node ID array.\n    tensor\n        The destination node ID array.\n    \"\"\"\n    req_list = []\n    partition_book = g.get_partition_book()\n    edges = toindex(edges).tousertensor()\n    partition_id = partition_book.eid2partid(edges)\n    local_eids = None\n    reorder_idx = []\n    for pid in range(partition_book.num_partitions()):\n        mask = partition_id == pid\n        edge_id = F.boolean_mask(edges, mask)\n        reorder_idx.append(F.nonzero_1d(mask))\n        if pid == partition_book.partid and g.local_partition is not None:\n            assert local_eids is None\n            local_eids = edge_id\n        elif len(edge_id) != 0:\n            req = issue_remote_req(edge_id, pid)\n            req_list.append((pid, req))\n\n    # send requests to the remote machine.\n    msgseq2pos = None\n    if len(req_list) > 0:\n        msgseq2pos = send_requests_to_machine(req_list)\n\n    # handle edges in local partition.\n    src_ids = F.zeros_like(edges)\n    dst_ids = F.zeros_like(edges)\n    if local_eids is not None:\n        src, dst = local_access(g.local_partition, partition_book, local_eids)\n        src_ids = F.scatter_row(\n           
 src_ids, reorder_idx[partition_book.partid], src\n        )\n        dst_ids = F.scatter_row(\n            dst_ids, reorder_idx[partition_book.partid], dst\n        )\n\n    # receive responses from remote machines.\n    if msgseq2pos is not None:\n        results = recv_responses(msgseq2pos)\n        for result in results:\n            src = result.global_src\n            dst = result.global_dst\n            src_ids = F.scatter_row(src_ids, reorder_idx[result.order_id], src)\n            dst_ids = F.scatter_row(dst_ids, reorder_idx[result.order_id], dst)\n    return src_ids, dst_ids\n\n\ndef find_edges(g, edge_ids):\n    \"\"\"Given an edge ID array, return the source and destination\n    node ID array ``s`` and ``d`` from a distributed graph.\n    ``s[i]`` and ``d[i]`` are source and destination node ID for\n    edge ``eid[i]``.\n\n    Parameters\n    ----------\n    g : DistGraph\n        The distributed graph.\n    edges : tensor\n        The edge ID array.\n\n    Returns\n    -------\n    tensor\n        The source node ID array.\n    tensor\n        The destination node ID array.\n    \"\"\"\n\n    def issue_remote_req(edge_ids, order_id):\n        return EdgesRequest(edge_ids, order_id)\n\n    def local_access(local_g, partition_book, edge_ids):\n        return _find_edges(local_g, partition_book, edge_ids)\n\n    return _distributed_edge_access(g, edge_ids, issue_remote_req, local_access)\n\n\ndef in_subgraph(g, nodes):\n    \"\"\"Return the subgraph induced on the inbound edges of the given nodes.\n\n    The subgraph keeps the same type schema and all the nodes are preserved regardless\n    of whether they have an edge or not.\n\n    Node/edge features are not preserved. 
The original IDs of\n    the extracted edges are stored as the `dgl.EID` feature in the returned graph.\n\n    For now, we only support the input graph with one node type and one edge type.\n\n\n    Parameters\n    ----------\n    g : DistGraph\n        The distributed graph structure.\n    nodes : tensor or dict\n        Node ids to sample neighbors from.\n\n    Returns\n    -------\n    DGLGraph\n        The subgraph.\n\n        One can retrieve the mapping from subgraph edge ID to parent\n        edge ID via ``dgl.EID`` edge features of the subgraph.\n    \"\"\"\n    if isinstance(nodes, dict):\n        assert (\n            len(nodes) == 1\n        ), \"The distributed in_subgraph only supports one node type for now.\"\n        nodes = list(nodes.values())[0]\n\n    def issue_remote_req(node_ids):\n        return InSubgraphRequest(node_ids)\n\n    def local_access(local_g, partition_book, local_nids):\n        return _in_subgraph(local_g, partition_book, local_nids)\n\n    return _distributed_access(g, nodes, issue_remote_req, local_access)\n\n\ndef _distributed_get_node_property(g, n, issue_remote_req, local_access):\n    req_list = []\n    partition_book = g.get_partition_book()\n    n = toindex(n).tousertensor()\n    partition_id = partition_book.nid2partid(n)\n    local_nids = None\n    reorder_idx = []\n    for pid in range(partition_book.num_partitions()):\n        mask = partition_id == pid\n        nid = F.boolean_mask(n, mask)\n        reorder_idx.append(F.nonzero_1d(mask))\n        if pid == partition_book.partid and g.local_partition is not None:\n            assert local_nids is None\n            local_nids = nid\n        elif len(nid) != 0:\n            req = issue_remote_req(nid, pid)\n            req_list.append((pid, req))\n\n    # send requests to the remote machine.\n    msgseq2pos = None\n    if len(req_list) > 0:\n        msgseq2pos = send_requests_to_machine(req_list)\n\n    # handle edges in local partition.\n    vals = None\n    if 
local_nids is not None:\n        local_vals = local_access(g.local_partition, partition_book, local_nids)\n        shape = list(F.shape(local_vals))\n        shape[0] = len(n)\n        vals = F.zeros(shape, F.dtype(local_vals), F.cpu())\n        vals = F.scatter_row(\n            vals, reorder_idx[partition_book.partid], local_vals\n        )\n\n    # receive responses from remote machines.\n    if msgseq2pos is not None:\n        results = recv_responses(msgseq2pos)\n        if len(results) > 0 and vals is None:\n            shape = list(F.shape(results[0].val))\n            shape[0] = len(n)\n            vals = F.zeros(shape, F.dtype(results[0].val), F.cpu())\n        for result in results:\n            val = result.val\n            vals = F.scatter_row(vals, reorder_idx[result.order_id], val)\n    return vals\n\n\ndef in_degrees(g, v):\n    \"\"\"Get in-degrees\"\"\"\n\n    def issue_remote_req(v, order_id):\n        return InDegreeRequest(v, order_id)\n\n    def local_access(local_g, partition_book, v):\n        return _in_degrees(local_g, partition_book, v)\n\n    return _distributed_get_node_property(g, v, issue_remote_req, local_access)\n\n\ndef out_degrees(g, u):\n    \"\"\"Get out-degrees\"\"\"\n\n    def issue_remote_req(u, order_id):\n        return OutDegreeRequest(u, order_id)\n\n    def local_access(local_g, partition_book, u):\n        return _out_degrees(local_g, partition_book, u)\n\n    return _distributed_get_node_property(g, u, issue_remote_req, local_access)\n\n\nregister_service(SAMPLING_SERVICE_ID, SamplingRequest, SubgraphResponse)\nregister_service(EDGES_SERVICE_ID, EdgesRequest, FindEdgeResponse)\nregister_service(INSUBGRAPH_SERVICE_ID, InSubgraphRequest, SubgraphResponse)\nregister_service(OUTDEGREE_SERVICE_ID, OutDegreeRequest, OutDegreeResponse)\nregister_service(INDEGREE_SERVICE_ID, InDegreeRequest, InDegreeResponse)\nregister_service(\n    ETYPE_SAMPLING_SERVICE_ID, SamplingRequestEtype, SubgraphResponse\n)\n"
  },
  {
    "path": "python/dgl/distributed/id_map.py",
    "content": "\"\"\"Module for mapping between node/edge IDs and node/edge types.\"\"\"\n\nimport numpy as np\nimport torch\n\nfrom .. import backend as F, utils\n\nfrom .._ffi.function import _init_api\n\n\n__all__ = [\"IdMap\"]\n\n\nclass IdMap:\n    \"\"\"A map for converting node/edge IDs to their type IDs and type-wise IDs.\n\n    For a heterogeneous graph, DGL assigns an integer ID to each node/edge type;\n    node and edge of different types have independent IDs starting from zero.\n    Therefore, a node/edge can be uniquely identified by an ID pair,\n    ``(type_id, type_wise_id)``. To make it convenient for distributed processing,\n    DGL further encodes the ID pair into one integer ID, which we refer to\n    as *homogeneous ID*.\n\n    DGL arranges nodes and edges so that all nodes of the same type have contiguous\n    homogeneous IDs. If the graph is partitioned, the nodes/edges of the same type\n    within a partition have contiguous homogeneous IDs.\n\n    Below is an example adjancency matrix of an unpartitioned heterogeneous graph\n    stored using the above ID assignment. Here, the graph has two types of nodes\n    (``T0`` and ``T1``), and four types of edges (``R0``, ``R1``, ``R2``, ``R3``).\n    There are a total of 400 nodes in the graph and each type has 200 nodes. 
Nodes\n    of type 0 have IDs in [0,200), while nodes of type 1 have IDs in [200, 400).\n\n    ```\n        0 <- T0 -> 200 <- T1 -> 400\n     0  +-----------+------------+\n        |           |            |\n     ^  |    R0     |     R1     |\n     T0 |           |            |\n     v  |           |            |\n    200 +-----------+------------+\n        |           |            |\n     ^  |    R2     |     R3     |\n     T1 |           |            |\n     v  |           |            |\n    400 +-----------+------------+\n    ```\n\n    Below shows the adjacency matrix after the graph is partitioned into two.\n    Note that each partition still has two node types and four edge types,\n    and nodes/edges of the same type have contiguous IDs.\n\n    ```\n                partition 0              partition 1\n\n        0 <- T0 -> 100 <- T1 -> 200 <- T0 -> 300 <- T1 -> 400\n     0  +-----------+------------+-----------+------------+\n        |           |            |                        |\n     ^  |    R0     |     R1     |                        |\n     T0 |           |            |                        |\n     v  |           |            |                        |\n    100 +-----------+------------+                        |\n        |           |            |                        |\n     ^  |    R2     |     R3     |                        |\n     T1 |           |            |                        |\n     v  |           |            |                        |\n    200 +-----------+------------+-----------+------------+\n        |                        |           |            |\n     ^  |                        |    R0     |     R1     |\n     T0 |                        |           |            |\n     v  |                        |           |            |\n    100 |                        +-----------+------------+\n        |                        |           |            |\n     ^  |                        |    R2     |     R3     |\n     T1 |     
                   |           |            |\n     v  |                        |           |            |\n    200 +-----------+------------+-----------+------------+\n    ```\n\n    The following table is an alternative way to represent the above ID assignments.\n    It is easy to see that the homogeneous ID range [0, 100) is used for nodes of type 0\n    in partition 0, [100, 200) is used for nodes of type 1 in partition 0, and so on.\n    ```\n    +---------+------+----------\n      range   | type | partition\n    [0, 100)  |   0  |    0\n    [100,200) |   1  |    0\n    [200,300) |   0  |    1\n    [300,400) |   1  |    1\n    ```\n\n    The goal of this class is to, given a node's homogenous ID, convert it into the\n    ID pair ``(type_id, type_wise_id)``. For example, homogeneous node ID 90 is mapped\n    to (0, 90); homogeneous node ID 201 is mapped to (0, 101).\n\n    Parameters\n    ----------\n    id_ranges : dict[str, Tensor].\n        Node ID ranges within partitions for each node type. The key is the node type\n        name in string. The value is a tensor of shape :math:`(K, 2)`, where :math:`K` is\n        the number of partitions. Each row has two integers: the starting and the ending IDs\n        for a particular node type in a partition. 
For example, all nodes of type ``\"T\"`` in\n        partition ``i`` has ID range ``id_ranges[\"T\"][i][0]`` to ``id_ranges[\"T\"][i][1]``.\n        It is the same as the `node_map` argument in `RangePartitionBook`.\n    \"\"\"\n\n    def __init__(self, id_ranges):\n        id_ranges_values = list(id_ranges.values())\n        assert isinstance(\n            id_ranges_values[0], np.ndarray\n        ), \"id_ranges should be a dict of numpy arrays.\"\n        self.num_parts = id_ranges_values[0].shape[0]\n        self.dtype = id_ranges_values[0].dtype\n        self.dtype_str = \"int32\" if self.dtype == np.int32 else \"int64\"\n        self.num_types = len(id_ranges)\n        ranges = np.zeros(\n            (self.num_parts * self.num_types, 2), dtype=self.dtype\n        )\n        typed_map = []\n        id_ranges = id_ranges_values\n        id_ranges.sort(key=lambda a: a[0, 0])\n        for i, id_range in enumerate(id_ranges):\n            ranges[i :: self.num_types] = id_range\n            map1 = np.cumsum(id_range[:, 1] - id_range[:, 0], dtype=self.dtype)\n            typed_map.append(map1)\n\n        assert np.all(np.diff(ranges[:, 0]) >= 0)\n        assert np.all(np.diff(ranges[:, 1]) >= 0)\n        self.range_start = utils.toindex(\n            np.ascontiguousarray(ranges[:, 0]), dtype=self.dtype_str\n        )\n        self.range_end = utils.toindex(\n            np.ascontiguousarray(ranges[:, 1]) - 1, dtype=self.dtype_str\n        )\n        self.typed_map = utils.toindex(\n            np.concatenate(typed_map), dtype=self.dtype_str\n        )\n\n    def __call__(self, ids):\n        \"\"\"Convert the homogeneous IDs to (type_id, type_wise_id).\n\n        Parameters\n        ----------\n        ids : 1D tensor\n            The homogeneous ID.\n\n        Returns\n        -------\n        type_ids : Tensor\n            Type IDs\n        per_type_ids : Tensor\n            Type-wise IDs\n        \"\"\"\n        if self.num_types == 0:\n            return 
F.zeros((len(ids),), F.dtype(ids), F.cpu()), ids\n        if len(ids) == 0:\n            return ids, ids\n\n        ids = utils.toindex(ids, dtype=self.dtype_str)\n        ret = _CAPI_DGLHeteroMapIds(\n            ids.todgltensor(),\n            self.range_start.todgltensor(),\n            self.range_end.todgltensor(),\n            self.typed_map.todgltensor(),\n            self.num_parts,\n            self.num_types,\n        )\n        ret = utils.toindex(ret, dtype=self.dtype_str).tousertensor()\n        return ret[: len(ids)], ret[len(ids) :]\n\n    @property\n    def torch_dtype(self):\n        \"\"\"Return the data type of the ID map.\"\"\"\n        # [TODO][Rui] Use torch instead of numpy.\n        return torch.int32 if self.dtype == np.int32 else torch.int64\n\n\n_init_api(\"dgl.distributed.id_map\")\n"
  },
  {
    "path": "python/dgl/distributed/kvstore.py",
    "content": "\"\"\"Define distributed kvstore\"\"\"\n\nimport os\n\nimport numpy as np\n\nfrom .. import backend as F, utils\nfrom .._ffi.ndarray import empty_shared_mem\n\nfrom . import rpc\nfrom .graph_partition_book import EdgePartitionPolicy, NodePartitionPolicy\nfrom .standalone_kvstore import KVClient as SA_KVClient\n\n############################ Register KVStore Requsts and Responses ###############################\n\nKVSTORE_PULL = 901231\n\n\nclass PullResponse(rpc.Response):\n    \"\"\"Send the sliced data tensor back to the client.\n\n    Parameters\n    ----------\n    server_id : int\n        ID of current server\n    data_tensor : tensor\n        sliced data tensor\n    \"\"\"\n\n    def __init__(self, server_id, data_tensor):\n        self.server_id = server_id\n        self.data_tensor = data_tensor\n\n    def __getstate__(self):\n        return self.server_id, self.data_tensor\n\n    def __setstate__(self, state):\n        self.server_id, self.data_tensor = state\n\n\nclass PullRequest(rpc.Request):\n    \"\"\"Send ID tensor to server and get target data tensor as response.\n\n    Parameters\n    ----------\n    name : str\n        data name\n    id_tensor : tensor\n        a vector storing the data ID\n    \"\"\"\n\n    def __init__(self, name, id_tensor):\n        self.name = name\n        self.id_tensor = id_tensor\n\n    def __getstate__(self):\n        return self.name, self.id_tensor\n\n    def __setstate__(self, state):\n        self.name, self.id_tensor = state\n\n    def process_request(self, server_state):\n        kv_store = server_state.kv_store\n        if self.name not in kv_store.part_policy:\n            raise RuntimeError(\n                \"KVServer cannot find partition policy with name: %s\"\n                % self.name\n            )\n        if self.name not in kv_store.data_store:\n            raise RuntimeError(\n                \"KVServer Cannot find data tensor with name: %s\" % self.name\n            )\n        
local_id = kv_store.part_policy[self.name].to_local(self.id_tensor)\n        data = kv_store.pull_handlers[self.name](\n            kv_store.data_store, self.name, local_id\n        )\n        res = PullResponse(kv_store.server_id, data)\n        return res\n\n\nKVSTORE_PUSH = 901232\n\n\nclass PushRequest(rpc.Request):\n    \"\"\"Send ID tensor and data tensor to server and update kvstore's data.\n\n    This request has no response.\n\n    Parameters\n    ----------\n    name : str\n        data name\n    id_tensor : tensor\n        a vector storing the data ID\n    data_tensor : tensor\n        a tensor with the same row size of data ID\n    \"\"\"\n\n    def __init__(self, name, id_tensor, data_tensor):\n        self.name = name\n        self.id_tensor = id_tensor\n        self.data_tensor = data_tensor\n\n    def __getstate__(self):\n        return self.name, self.id_tensor, self.data_tensor\n\n    def __setstate__(self, state):\n        self.name, self.id_tensor, self.data_tensor = state\n\n    def process_request(self, server_state):\n        kv_store = server_state.kv_store\n        if self.name not in kv_store.part_policy:\n            raise RuntimeError(\n                \"KVServer cannot find partition policy with name: %s\"\n                % self.name\n            )\n        if self.name not in kv_store.data_store:\n            raise RuntimeError(\n                \"KVServer Cannot find data tensor with name: %s\" % self.name\n            )\n        local_id = kv_store.part_policy[self.name].to_local(self.id_tensor)\n        kv_store.push_handlers[self.name](\n            kv_store.data_store, self.name, local_id, self.data_tensor\n        )\n\n\nINIT_DATA = 901233\nINIT_MSG = \"Init\"\n\n\nclass InitDataResponse(rpc.Response):\n    \"\"\"Send a confirmation response (just a short string message) of\n    InitDataRequest to client.\n\n    Parameters\n    ----------\n    msg : string\n        string message\n    \"\"\"\n\n    def __init__(self, msg):\n     
   self.msg = msg\n\n    def __getstate__(self):\n        return self.msg\n\n    def __setstate__(self, state):\n        self.msg = state\n\n\nclass InitDataRequest(rpc.Request):\n    \"\"\"Send meta data to server and init data tensor\n    on server using UDF init function.\n\n    Parameters\n    ----------\n    name : str\n        data name\n    shape : tuple\n        data shape\n    dtype : str\n        data type string, e.g., 'int64', 'float32', etc.\n    policy_str : str\n        partition-policy string, e.g., 'edge' or 'node'.\n    init_func : function\n        UDF init function.\n    \"\"\"\n\n    def __init__(self, name, shape, dtype, policy_str, init_func):\n        self.name = name\n        self.shape = shape\n        self.dtype = dtype\n        self.policy_str = policy_str\n        self.init_func = init_func\n\n    def __getstate__(self):\n        return (\n            self.name,\n            self.shape,\n            self.dtype,\n            self.policy_str,\n            self.init_func,\n        )\n\n    def __setstate__(self, state):\n        (\n            self.name,\n            self.shape,\n            self.dtype,\n            self.policy_str,\n            self.init_func,\n        ) = state\n\n    def process_request(self, server_state):\n        kv_store = server_state.kv_store\n        dtype = F.data_type_dict[self.dtype]\n\n        # We should see requests from multiple clients. 
We need to ignore the duplicated\n        # reqeusts.\n        if self.name in kv_store.data_store:\n            assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(\n                self.shape\n            )\n            assert (\n                F.reverse_data_type_dict[\n                    F.dtype(kv_store.data_store[self.name])\n                ]\n                == self.dtype\n            )\n            assert kv_store.part_policy[self.name].policy_str == self.policy_str\n        else:\n            if not kv_store.is_backup_server():\n                data_tensor = self.init_func(self.shape, dtype)\n                kv_store.init_data(\n                    name=self.name,\n                    policy_str=self.policy_str,\n                    data_tensor=data_tensor,\n                )\n            else:\n                kv_store.init_data(name=self.name, policy_str=self.policy_str)\n        res = InitDataResponse(INIT_MSG)\n        return res\n\n\nBARRIER = 901234\nBARRIER_MSG = \"Barrier\"\n\n\nclass BarrierResponse(rpc.Response):\n    \"\"\"Send an confimation signal (just a short string message) of\n    BarrierRequest to client.\n\n    Parameters\n    ----------\n    msg : string\n        string msg\n    \"\"\"\n\n    def __init__(self, msg):\n        self.msg = msg\n\n    def __getstate__(self):\n        return self.msg\n\n    def __setstate__(self, state):\n        self.msg = state\n\n\nclass BarrierRequest(rpc.Request):\n    \"\"\"Send a barrier signal (just a short string message) to server.\n\n    Parameters\n    ----------\n    role : string\n        client role\n    \"\"\"\n\n    def __init__(self, role):\n        self.role = role\n        self.group_id = rpc.get_group_id()\n\n    def __getstate__(self):\n        return self.role, self.group_id\n\n    def __setstate__(self, state):\n        self.role, self.group_id = state\n\n    def process_request(self, server_state):\n        kv_store = server_state.kv_store\n        roles = 
server_state.roles\n        role = roles[self.group_id]\n        barrier_count = kv_store.barrier_count[self.group_id]\n        count = barrier_count[self.role]\n        barrier_count[self.role] = count + 1\n        if barrier_count[self.role] == len(role[self.role]):\n            barrier_count[self.role] = 0\n            res_list = []\n            for client_id, _ in role[self.role]:\n                res_list.append((client_id, BarrierResponse(BARRIER_MSG)))\n            return res_list\n        return None\n\n\nREGISTER_PULL = 901235\nREGISTER_PULL_MSG = \"Register_Pull\"\n\n\nclass RegisterPullHandlerResponse(rpc.Response):\n    \"\"\"Send a confirmation signal (just a short string message) of\n    RegisterPullHandler to client.\n\n    Parameters\n    ----------\n    msg : string\n        string message\n    \"\"\"\n\n    def __init__(self, msg):\n        self.msg = msg\n\n    def __getstate__(self):\n        return self.msg\n\n    def __setstate__(self, state):\n        self.msg = state\n\n\nclass RegisterPullHandlerRequest(rpc.Request):\n    \"\"\"Send an UDF and register Pull handler on server.\n\n    Parameters\n    ----------\n    pull_func : func\n        UDF pull handler\n    \"\"\"\n\n    def __init__(self, name, pull_func):\n        self.name = name\n        self.pull_func = pull_func\n\n    def __getstate__(self):\n        return self.name, self.pull_func\n\n    def __setstate__(self, state):\n        self.name, self.pull_func = state\n\n    def process_request(self, server_state):\n        kv_store = server_state.kv_store\n        kv_store.pull_handlers[self.name] = self.pull_func\n        res = RegisterPullHandlerResponse(REGISTER_PULL_MSG)\n        return res\n\n\nREGISTER_PUSH = 901236\nREGISTER_PUSH_MSG = \"Register_Push\"\n\n\nclass RegisterPushHandlerResponse(rpc.Response):\n    \"\"\"Send a confirmation signal (just a short string message) of\n    RegisterPushHandler to client.\n\n    Parameters\n    ----------\n    msg : string\n        string 
message\n    \"\"\"\n\n    def __init__(self, msg):\n        self.msg = msg\n\n    def __getstate__(self):\n        return self.msg\n\n    def __setstate__(self, state):\n        self.msg = state\n\n\nclass RegisterPushHandlerRequest(rpc.Request):\n    \"\"\"Send an UDF to register Push handler on server.\n\n    Parameters\n    ----------\n    push_func : func\n        UDF push handler\n    \"\"\"\n\n    def __init__(self, name, push_func):\n        self.name = name\n        self.push_func = push_func\n\n    def __getstate__(self):\n        return self.name, self.push_func\n\n    def __setstate__(self, state):\n        self.name, self.push_func = state\n\n    def process_request(self, server_state):\n        kv_store = server_state.kv_store\n        kv_store.push_handlers[self.name] = self.push_func\n        res = RegisterPushHandlerResponse(REGISTER_PUSH_MSG)\n        return res\n\n\nGET_SHARED = 901237\nGET_SHARED_MSG = \"Get_Shared\"\n\n\nclass GetSharedDataResponse(rpc.Response):\n    \"\"\"Send meta data of shared-memory tensor to client.\n\n    Parameters\n    ----------\n    meta : dict\n        a dict of meta, e.g.,\n\n        {'data_0' : (shape, dtype, policy_str),\n         'data_1' : (shape, dtype, policy_str)}\n    \"\"\"\n\n    def __init__(self, meta):\n        self.meta = meta\n\n    def __getstate__(self):\n        return self.meta\n\n    def __setstate__(self, state):\n        self.meta = state\n\n\nclass GetSharedDataRequest(rpc.Request):\n    \"\"\"Send a signal (just a short string message) to get the\n    meta data of shared-tensor from server.\n\n    Parameters\n    ----------\n    msg : string\n        string message\n    \"\"\"\n\n    def __init__(self, msg):\n        self.msg = msg\n\n    def __getstate__(self):\n        return self.msg\n\n    def __setstate__(self, state):\n        self.msg = state\n\n    def process_request(self, server_state):\n        assert self.msg == GET_SHARED_MSG\n        meta = {}\n        kv_store = 
server_state.kv_store\n        for name, data in kv_store.data_store.items():\n            meta[name] = (\n                F.shape(data),\n                F.reverse_data_type_dict[F.dtype(data)],\n                kv_store.part_policy[name].policy_str,\n            )\n        res = GetSharedDataResponse(meta)\n        return res\n\n\nGET_PART_SHAPE = 901238\n\n\nclass GetPartShapeResponse(rpc.Response):\n    \"\"\"Send the partitioned data shape back to client.\n\n    Parameters\n    ----------\n    shape : tuple\n        shape of tensor\n    \"\"\"\n\n    def __init__(self, shape):\n        self.shape = shape\n\n    def __getstate__(self):\n        return self.shape\n\n    def __setstate__(self, state):\n        # When the shape has only one dimension, state is an integer.\n        if isinstance(state, int):\n            self.shape = (state,)\n        else:\n            self.shape = state\n\n\nclass GetPartShapeRequest(rpc.Request):\n    \"\"\"Send data name to get the partitioned data shape from server.\n\n    Parameters\n    ----------\n    name : str\n        data name\n    \"\"\"\n\n    def __init__(self, name):\n        self.name = name\n\n    def __getstate__(self):\n        return self.name\n\n    def __setstate__(self, state):\n        self.name = state\n\n    def process_request(self, server_state):\n        kv_store = server_state.kv_store\n        if self.name not in kv_store.data_store:\n            raise RuntimeError(\n                \"KVServer Cannot find data tensor with name: %s\" % self.name\n            )\n        data_shape = F.shape(kv_store.data_store[self.name])\n        res = GetPartShapeResponse(data_shape)\n        return res\n\n\nSEND_META_TO_BACKUP = 901239\nSEND_META_TO_BACKUP_MSG = \"Send_Meta_TO_Backup\"\n\n\nclass SendMetaToBackupResponse(rpc.Response):\n    \"\"\"Send a confirmation signal (just a short string message)\n    of SendMetaToBackupRequest to client.\n    \"\"\"\n\n    def __init__(self, msg):\n        self.msg = msg\n\n  
  def __getstate__(self):\n        return self.msg\n\n    def __setstate__(self, state):\n        self.msg = state\n\n\nclass SendMetaToBackupRequest(rpc.Request):\n    \"\"\"Send meta data to backup server and backup server\n    will use this meta data to read shared-memory tensor.\n\n    Parameters\n    ----------\n    name : str\n        data name\n    dtype : str\n        data type string\n    shape : tuple of int\n        data shape\n    policy_str : str\n        partition-policy string, e.g., 'edge' or 'node'.\n    pull_handler : callable\n        The callback function when data is pulled from kvstore.\n    push_handler : callable\n        The callback function when data is pushed to kvstore.\n    \"\"\"\n\n    def __init__(\n        self, name, dtype, shape, policy_str, pull_handler, push_handler\n    ):\n        self.name = name\n        self.dtype = dtype\n        self.shape = shape\n        self.policy_str = policy_str\n        self.pull_handler = pull_handler\n        self.push_handler = push_handler\n\n    def __getstate__(self):\n        return (\n            self.name,\n            self.dtype,\n            self.shape,\n            self.policy_str,\n            self.pull_handler,\n            self.push_handler,\n        )\n\n    def __setstate__(self, state):\n        (\n            self.name,\n            self.dtype,\n            self.shape,\n            self.policy_str,\n            self.pull_handler,\n            self.push_handler,\n        ) = state\n\n    def process_request(self, server_state):\n        kv_store = server_state.kv_store\n        assert kv_store.is_backup_server()\n        if self.name not in kv_store.data_store:\n            shared_data = empty_shared_mem(\n                self.name + \"-kvdata-\", False, self.shape, self.dtype\n            )\n            dlpack = shared_data.to_dlpack()\n            kv_store.data_store[self.name] = F.zerocopy_from_dlpack(dlpack)\n            kv_store.part_policy[self.name] = 
kv_store.find_policy(\n                self.policy_str\n            )\n            kv_store.pull_handlers[self.name] = self.pull_handler\n            kv_store.push_handlers[self.name] = self.push_handler\n        else:\n            assert tuple(F.shape(kv_store.data_store[self.name])) == tuple(\n                self.shape\n            )\n            assert (\n                F.reverse_data_type_dict[\n                    F.dtype(kv_store.data_store[self.name])\n                ]\n                == self.dtype\n            )\n            assert kv_store.part_policy[self.name].policy_str == self.policy_str\n            assert kv_store.pull_handlers[self.name] == self.pull_handler\n            assert kv_store.push_handlers[self.name] == self.push_handler\n        res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG)\n        return res\n\n\nDELETE_DATA = 901240\nDELETE_MSG = \"Delete_Data\"\n\n\nclass DeleteDataResponse(rpc.Response):\n    \"\"\"Send a confirmation signal (just a short string message)\n    of DeleteDataRequest to client.\n    \"\"\"\n\n    def __init__(self, msg):\n        self.msg = msg\n\n    def __getstate__(self):\n        return self.msg\n\n    def __setstate__(self, state):\n        self.msg = state\n\n\nclass DeleteDataRequest(rpc.Request):\n    \"\"\"Send message to server to delete data tensor\n\n    Parameters\n    ----------\n    name : str\n        data name\n    \"\"\"\n\n    def __init__(self, name):\n        self.name = name\n\n    def __getstate__(self):\n        return self.name\n\n    def __setstate__(self, state):\n        self.name = state\n\n    def process_request(self, server_state):\n        kv_store = server_state.kv_store\n        if self.name in kv_store.data_store:\n            del kv_store.data_store[self.name]\n            del kv_store.part_policy[self.name]\n            del kv_store.push_handlers[self.name]\n            del kv_store.pull_handlers[self.name]\n        res = DeleteDataResponse(DELETE_MSG)\n        return 
res\n\n\nCOUNT_LOCAL_NONZERO = 901241\n\n\nclass CountLocalNonzeroResponse(rpc.Response):\n    \"\"\"Send the number of nonzero value in local data\"\"\"\n\n    def __init__(self, num_local_nonzero):\n        self.num_local_nonzero = num_local_nonzero\n\n    def __getstate__(self):\n        return self.num_local_nonzero\n\n    def __setstate__(self, state):\n        self.num_local_nonzero = state\n\n\nclass CountLocalNonzeroRequest(rpc.Request):\n    \"\"\"Send data name to server to count local nonzero value\n    Parameters\n    ----------\n    name : str\n        data name\n    \"\"\"\n\n    def __init__(self, name):\n        self.name = name\n\n    def __getstate__(self):\n        return self.name\n\n    def __setstate__(self, state):\n        self.name = state\n\n    def process_request(self, server_state):\n        kv_store = server_state.kv_store\n        num_local_nonzero = kv_store.count_local_nonzero(self.name)\n        res = CountLocalNonzeroResponse(num_local_nonzero)\n        return res\n\n\n############################ KVServer ###############################\n\n\ndef default_push_handler(target, name, id_tensor, data_tensor):\n    \"\"\"Default handler for PUSH message.\n\n    On default, _push_handler perform scatter_row() operation for the tensor.\n\n    Parameters\n    ----------\n    target : tensor\n        target tensor\n    name : str\n        data name\n    id_tensor : tensor\n        a vector storing the ID list.\n    data_tensor : tensor\n        a tensor with the same row size of id\n    \"\"\"\n    # TODO(chao): support Tensorflow backend\n    target[name][id_tensor] = data_tensor\n\n\ndef default_pull_handler(target, name, id_tensor):\n    \"\"\"Default handler for PULL operation.\n\n    On default, _pull_handler perform gather_row() operation for the tensor.\n\n    Parameters\n    ----------\n    target : tensor\n        target tensor\n    name : str\n        data name\n    id_tensor : tensor\n        a vector storing the ID list.\n\n    
Return\n    ------\n    tensor\n        a tensor with the same row size of ID.\n    \"\"\"\n    # TODO(chao): support Tensorflow backend\n    return target[name][id_tensor]\n\n\nclass KVServer(object):\n    \"\"\"KVServer is a lightweight key-value store service for DGL distributed training.\n\n    In practice, developers can use KVServer to hold large-scale graph features or\n    graph embeddings across machines in a distributed setting. KVServer depends on DGL rpc\n    infrastructure thats support backup servers, which means we can lunach many KVServers\n    on the same machine for load-balancing.\n\n    DO NOT use KVServer in mult-threads because this behavior is not defined. For now, KVServer\n    can only support CPU-to-CPU communication. We may support GPU-communication in the future.\n\n    Parameters\n    ----------\n    server_id : int\n        ID of current server (starts from 0).\n    ip_config : str\n        Path of IP configuration file.\n    num_servers : int\n        Server count on each machine.\n    num_clients : int\n        Total number of KVClients that will be connected to the KVServer.\n    \"\"\"\n\n    def __init__(self, server_id, ip_config, num_servers, num_clients):\n        assert server_id >= 0, (\n            \"server_id (%d) cannot be a negative number.\" % server_id\n        )\n        assert num_servers > 0, (\n            \"num_servers (%d) must be a positive number.\" % num_servers\n        )\n        assert os.path.exists(ip_config), \"Cannot open file: %s\" % ip_config\n        assert num_clients >= 0, (\n            \"num_clients (%d) cannot be a negative number.\" % num_clients\n        )\n        # Register services on server\n        rpc.register_service(KVSTORE_PULL, PullRequest, PullResponse)\n        rpc.register_service(KVSTORE_PUSH, PushRequest, None)\n        rpc.register_service(INIT_DATA, InitDataRequest, InitDataResponse)\n        rpc.register_service(BARRIER, BarrierRequest, BarrierResponse)\n        
rpc.register_service(\n            REGISTER_PUSH,\n            RegisterPushHandlerRequest,\n            RegisterPushHandlerResponse,\n        )\n        rpc.register_service(\n            REGISTER_PULL,\n            RegisterPullHandlerRequest,\n            RegisterPullHandlerResponse,\n        )\n        rpc.register_service(\n            GET_SHARED, GetSharedDataRequest, GetSharedDataResponse\n        )\n        rpc.register_service(\n            GET_PART_SHAPE, GetPartShapeRequest, GetPartShapeResponse\n        )\n        rpc.register_service(\n            SEND_META_TO_BACKUP,\n            SendMetaToBackupRequest,\n            SendMetaToBackupResponse,\n        )\n        rpc.register_service(DELETE_DATA, DeleteDataRequest, DeleteDataResponse)\n        rpc.register_service(\n            COUNT_LOCAL_NONZERO,\n            CountLocalNonzeroRequest,\n            CountLocalNonzeroResponse,\n        )\n        # Store the tensor data with specified data name\n        self._data_store = {}\n        # Store original tensor data names when instantiating DistGraphServer\n        self._orig_data = set()\n        # Store the partition information with specified data name\n        self._policy_set = set()\n        self._part_policy = {}\n        # Basic information\n        self._server_id = server_id\n        self._server_namebook = rpc.read_ip_config(ip_config, num_servers)\n        assert (\n            server_id in self._server_namebook\n        ), \"Trying to start server {}, but there are {} servers in the config file\".format(\n            server_id, len(self._server_namebook)\n        )\n        self._machine_id = self._server_namebook[server_id][0]\n        self._group_count = self._server_namebook[server_id][3]\n        # We assume partition_id is equal to machine_id\n        self._part_id = self._machine_id\n        self._num_clients = num_clients\n        self._barrier_count = {}\n        # push and pull handler\n        self._push_handlers = {}\n        
self._pull_handlers = {}\n\n    @property\n    def server_id(self):\n        \"\"\"Get server ID\"\"\"\n        return self._server_id\n\n    @property\n    def barrier_count(self):\n        \"\"\"Get barrier count\"\"\"\n        return self._barrier_count\n\n    @barrier_count.setter\n    def barrier_count(self, count):\n        \"\"\"Set barrier count\"\"\"\n        self._barrier_count = count\n\n    @property\n    def num_clients(self):\n        \"\"\"Get number of clients\"\"\"\n        return self._num_clients\n\n    @property\n    def data_store(self):\n        \"\"\"Get data store\"\"\"\n        return self._data_store\n\n    @property\n    def orig_data(self):\n        \"\"\"Get original data\"\"\"\n        return self._orig_data\n\n    @property\n    def part_policy(self):\n        \"\"\"Get part policy\"\"\"\n        return self._part_policy\n\n    @property\n    def part_id(self):\n        \"\"\"Get part ID\"\"\"\n        return self._part_id\n\n    @property\n    def push_handlers(self):\n        \"\"\"Get push handler\"\"\"\n        return self._push_handlers\n\n    @property\n    def pull_handlers(self):\n        \"\"\"Get pull handler\"\"\"\n        return self._pull_handlers\n\n    def is_backup_server(self):\n        \"\"\"Return True if current server is a backup server.\"\"\"\n        if self._server_id % self._group_count == 0:\n            return False\n        return True\n\n    def add_part_policy(self, policy):\n        \"\"\"Add partition policy to kvserver.\n\n        Parameters\n        ----------\n        policy : PartitionPolicy\n            Store the partition information\n        \"\"\"\n        self._policy_set.add(policy)\n\n    def init_data(self, name, policy_str, data_tensor=None):\n        \"\"\"Init data tensor on kvserver.\n\n        Parameters\n        ----------\n        name : str\n            data name\n        policy_str : str\n            partition-policy string, e.g., 'edge' or 'node'.\n        data_tensor : tensor\n    
        If the data_tensor is None, KVServer will\n            read shared-memory when client invoking get_shared_data().\n        \"\"\"\n        assert len(name) > 0, \"name cannot be empty.\"\n        if name in self._data_store:\n            raise RuntimeError(\"Data %s has already exists!\" % name)\n        self._part_policy[name] = self.find_policy(policy_str)\n        if data_tensor is not None:  # Create shared-tensor\n            data_type = F.reverse_data_type_dict[F.dtype(data_tensor)]\n            shared_data = empty_shared_mem(\n                name + \"-kvdata-\", True, data_tensor.shape, data_type\n            )\n            dlpack = shared_data.to_dlpack()\n            self._data_store[name] = F.zerocopy_from_dlpack(dlpack)\n            rpc.copy_data_to_shared_memory(self._data_store[name], data_tensor)\n            assert (\n                self._part_policy[name].get_part_size() == data_tensor.shape[0]\n            ), \"kvserver expect partition {} for {} has {} rows, but gets {} rows\".format(\n                self._part_policy[name].part_id,\n                policy_str,\n                self._part_policy[name].get_part_size(),\n                data_tensor.shape[0],\n            )\n        self._pull_handlers[name] = default_pull_handler\n        self._push_handlers[name] = default_push_handler\n\n    def find_policy(self, policy_str):\n        \"\"\"Find a partition policy from existing policy set\n\n        Parameters\n        ----------\n        policy_str : str\n            partition-policy string, e.g., 'edge' or 'node'.\n        \"\"\"\n        for policy in self._policy_set:\n            if policy_str == policy.policy_str:\n                return policy\n        raise RuntimeError(\n            \"Cannot find policy_str: %s from kvserver.\" % policy_str\n        )\n\n    def count_local_nonzero(self, name):\n        \"\"\"Count nonzero in local data\n\n        Parameters\n        ----------\n        name : str\n            data name.\n\n    
    Returns\n        -------\n        int\n            the number of nonzero in local data.\n        \"\"\"\n        assert len(name) > 0, \"name cannot be empty.\"\n        if name not in self._data_store:\n            raise RuntimeError(\"Data %s has not be created!\" % name)\n        return F.count_nonzero(self._data_store[name])\n\n\n############################ KVClient ###############################\n\n\nclass KVClient(object):\n    \"\"\"KVClient is used to push/pull data to/from KVServer. If the\n    target kvclient and kvserver are in the same machine, they can\n    communicate with each other using local shared-memory\n    automatically, instead of going through the tcp/ip RPC.\n\n    DO NOT use KVClient in multi-threads because this behavior is\n    not defined. For now, KVClient can only support CPU-to-CPU communication.\n    We may support GPU-communication in the future.\n\n    Parameters\n    ----------\n    ip_config : str\n        Path of IP configuration file.\n    num_servers : int\n        Server count on each machine.\n    role : str\n        We can set different role for kvstore.\n    \"\"\"\n\n    def __init__(self, ip_config, num_servers, role=\"default\"):\n        assert (\n            rpc.get_rank() != -1\n        ), \"Please invoke rpc.connect_to_server() before creating KVClient.\"\n        assert os.path.exists(ip_config), \"Cannot open file: %s\" % ip_config\n        assert num_servers > 0, (\n            \"num_servers (%d) must be a positive number.\" % num_servers\n        )\n        # Register services on client\n        rpc.register_service(KVSTORE_PULL, PullRequest, PullResponse)\n        rpc.register_service(KVSTORE_PUSH, PushRequest, None)\n        rpc.register_service(INIT_DATA, InitDataRequest, InitDataResponse)\n        rpc.register_service(BARRIER, BarrierRequest, BarrierResponse)\n        rpc.register_service(\n            REGISTER_PUSH,\n            RegisterPushHandlerRequest,\n            RegisterPushHandlerResponse,\n  
      )\n        rpc.register_service(\n            REGISTER_PULL,\n            RegisterPullHandlerRequest,\n            RegisterPullHandlerResponse,\n        )\n        rpc.register_service(\n            GET_SHARED, GetSharedDataRequest, GetSharedDataResponse\n        )\n        rpc.register_service(\n            GET_PART_SHAPE, GetPartShapeRequest, GetPartShapeResponse\n        )\n        rpc.register_service(\n            SEND_META_TO_BACKUP,\n            SendMetaToBackupRequest,\n            SendMetaToBackupResponse,\n        )\n        rpc.register_service(DELETE_DATA, DeleteDataRequest, DeleteDataResponse)\n        rpc.register_service(\n            COUNT_LOCAL_NONZERO,\n            CountLocalNonzeroRequest,\n            CountLocalNonzeroResponse,\n        )\n        # Store the tensor data with specified data name\n        self._data_store = {}\n        # Store the partition information with specified data name\n        self._part_policy = {}\n        # This stores all unique partition policies in the kvstore. 
The key is the policy name.\n        self._all_possible_part_policy = {}\n        # Store the full data shape across kvserver\n        self._full_data_shape = {}\n        # Store all the data name\n        self._data_name_list = set()\n        # Store all graph data name\n        self._gdata_name_list = set()\n        # Basic information\n        self._server_namebook = rpc.read_ip_config(ip_config, num_servers)\n        self._server_count = len(self._server_namebook)\n        self._group_count = self._server_namebook[0][3]\n        self._machine_count = int(self._server_count / self._group_count)\n        self._client_id = rpc.get_rank()\n        self._machine_id = rpc.get_machine_id()\n        self._part_id = self._machine_id\n        self._main_server_id = self._machine_id * self._group_count\n\n        # push and pull handler\n        self._pull_handlers = {}\n        self._push_handlers = {}\n        # register role on server-0\n        self._role = role\n\n    @property\n    def all_possible_part_policy(self):\n        \"\"\"Get all possible partition policies\"\"\"\n        return self._all_possible_part_policy\n\n    @property\n    def client_id(self):\n        \"\"\"Get client ID\"\"\"\n        return self._client_id\n\n    @property\n    def role(self):\n        \"\"\"Get client role\"\"\"\n        return self._role\n\n    @property\n    def machine_id(self):\n        \"\"\"Get machine ID\"\"\"\n        return self._machine_id\n\n    @property\n    def num_servers(self):\n        \"\"\"Get the number of servers\"\"\"\n        return self._server_count\n\n    @property\n    def group_count(self):\n        \"\"\"Get the number of groups --num_servers\"\"\"\n        return self._group_count\n\n    def barrier(self):\n        \"\"\"Barrier for all client nodes.\n\n        This API will be blocked untill all the clients invoke this API.\n        \"\"\"\n        request = BarrierRequest(self._role)\n        rpc.send_request(0, request)\n        response = 
rpc.recv_response()\n        assert response.msg == BARRIER_MSG\n\n    def register_push_handler(self, name, func):\n        \"\"\"Register UDF push function.\n\n        This UDF is triggered for every push. The signature of the UDF is\n\n        ```\n        def push_handler(data_store, name, local_offset, data)\n        ```\n\n        ``data_store`` is a dict that contains all tensors in the kvstore. ``name`` is the name\n        of the tensor where new data is pushed to. ``local_offset`` is the offset where new\n        data should be written in the tensor in the local partition. ``data`` is the new data\n        to be written.\n\n        Parameters\n        ----------\n        name : str\n            The name of the tensor\n        func : callable\n            The function to be called.\n        \"\"\"\n        self.barrier()\n        request = RegisterPushHandlerRequest(name, func)\n        # send request to all the server nodes\n        for server_id in range(self._server_count):\n            rpc.send_request(server_id, request)\n        # recv response from all the server nodes\n        for _ in range(self._server_count):\n            response = rpc.recv_response()\n            assert response.msg == REGISTER_PUSH_MSG\n        self._push_handlers[name] = func\n        self.barrier()\n\n    def register_pull_handler(self, name, func):\n        \"\"\"Register UDF pull function.\n\n        This UDF is triggered for every pull. The signature of the UDF is\n\n        ```\n        def pull_handler(data_store, name, local_offset)\n        ```\n\n        ``data_store`` is a dict that contains all tensors in the kvstore. ``name`` is the name\n        of the tensor where new data is pushed to. 
``local_offset`` is the offset where new\n        data should be written in the tensor in the local partition.\n\n        Parameters\n        ----------\n        name : str\n            The name of the tensor\n        func : callable\n            The function to be called.\n        \"\"\"\n        self.barrier()\n        request = RegisterPullHandlerRequest(name, func)\n        # send request to all the server nodes\n        for server_id in range(self._server_count):\n            rpc.send_request(server_id, request)\n        # recv response from all the server nodes\n        for _ in range(self._server_count):\n            response = rpc.recv_response()\n            assert response.msg == REGISTER_PULL_MSG\n        self._pull_handlers[name] = func\n        self.barrier()\n\n    def init_data(\n        self, name, shape, dtype, part_policy, init_func, is_gdata=True\n    ):\n        \"\"\"Send message to kvserver to initialize new data tensor and mapping this\n        data from server side to client side.\n\n        Parameters\n        ----------\n        name : str\n            data name\n        shape : list or tuple of int\n            data shape\n        dtype : dtype\n            data type\n        part_policy : PartitionPolicy\n            partition policy.\n        init_func : func\n            UDF init function\n        is_gdata : bool\n            Whether the created tensor is a ndata/edata or not.\n        \"\"\"\n        assert len(name) > 0, \"name cannot be empty.\"\n        assert len(shape) > 0, \"shape cannot be empty\"\n        assert name not in self._data_name_list, (\n            \"data name: %s already exists.\" % name\n        )\n        self.barrier()\n        shape = list(shape)\n\n        # Send request to the servers to initialize data.\n        # The servers may handle the duplicated initializations.\n        part_shape = shape.copy()\n        part_shape[0] = part_policy.get_part_size()\n        request = InitDataRequest(\n            
name,\n            tuple(part_shape),\n            F.reverse_data_type_dict[dtype],\n            part_policy.policy_str,\n            init_func,\n        )\n        # The request is sent to the servers in one group, which are on the same machine.\n        for n in range(self._group_count):\n            server_id = part_policy.part_id * self._group_count + n\n            rpc.send_request(server_id, request)\n        for _ in range(self._group_count):\n            response = rpc.recv_response()\n            assert response.msg == INIT_MSG\n\n        self.barrier()\n        # Create local shared-data\n        local_shape = shape.copy()\n        local_shape[0] = part_policy.get_part_size()\n        if name in self._part_policy:\n            raise RuntimeError(\"Policy %s has already exists!\" % name)\n        if name in self._data_store:\n            raise RuntimeError(\"Data %s has already exists!\" % name)\n        if name in self._full_data_shape:\n            raise RuntimeError(\"Data shape %s has already exists!\" % name)\n        self._part_policy[name] = part_policy\n        self._all_possible_part_policy[part_policy.policy_str] = part_policy\n        shared_data = empty_shared_mem(\n            name + \"-kvdata-\",\n            False,\n            local_shape,\n            F.reverse_data_type_dict[dtype],\n        )\n        dlpack = shared_data.to_dlpack()\n        self._data_store[name] = F.zerocopy_from_dlpack(dlpack)\n        self._data_name_list.add(name)\n        if is_gdata:\n            self._gdata_name_list.add(name)\n        self._full_data_shape[name] = tuple(shape)\n        self._pull_handlers[name] = default_pull_handler\n        self._push_handlers[name] = default_push_handler\n\n        # Now we need to tell the backup server the new tensor.\n        request = SendMetaToBackupRequest(\n            name,\n            F.reverse_data_type_dict[dtype],\n            part_shape,\n            part_policy.policy_str,\n            
self._pull_handlers[name],\n            self._push_handlers[name],\n        )\n        # send request to all the backup server nodes\n        for i in range(self._group_count - 1):\n            server_id = self._machine_id * self._group_count + i + 1\n            rpc.send_request(server_id, request)\n        # recv response from all the backup server nodes\n        for _ in range(self._group_count - 1):\n            response = rpc.recv_response()\n            assert response.msg == SEND_META_TO_BACKUP_MSG\n        self.barrier()\n\n    def delete_data(self, name):\n        \"\"\"Send message to kvserver to delete tensor and clear the meta data\n\n        Parameters\n        ----------\n        name : str\n            data name\n        \"\"\"\n        assert len(name) > 0, \"name cannot be empty.\"\n        assert name in self._data_name_list, \"data name: %s not exists.\" % name\n        self.barrier()\n        part_policy = self._part_policy[name]\n\n        # send request to every server nodes\n        request = DeleteDataRequest(name)\n        for n in range(self._group_count):\n            server_id = part_policy.part_id * self._group_count + n\n            rpc.send_request(server_id, request)\n        for _ in range(self._group_count):\n            response = rpc.recv_response()\n            assert response.msg == DELETE_MSG\n\n        self.barrier()\n        self._data_name_list.remove(name)\n        if name in self._gdata_name_list:\n            self._gdata_name_list.remove(name)\n        # TODO(chao) : remove the delete log print\n        del self._data_store[name]\n        del self._full_data_shape[name]\n        del self._part_policy[name]\n        del self._pull_handlers[name]\n        del self._push_handlers[name]\n        self.barrier()\n\n    def map_shared_data(self, partition_book):\n        \"\"\"Mapping shared-memory tensor from server to client.\n\n        Parameters\n        ----------\n        partition_book : GraphPartitionBook\n            
Store the partition information\n        \"\"\"\n        # Get all partition policies\n        for ntype in partition_book.ntypes:\n            policy = NodePartitionPolicy(partition_book, ntype)\n            self._all_possible_part_policy[policy.policy_str] = policy\n        for etype in partition_book.canonical_etypes:\n            policy = EdgePartitionPolicy(partition_book, etype)\n            self._all_possible_part_policy[policy.policy_str] = policy\n\n        # Get shared data from server side\n        self.barrier()\n        request = GetSharedDataRequest(GET_SHARED_MSG)\n        rpc.send_request(self._main_server_id, request)\n        response = rpc.recv_response()\n        for name, meta in response.meta.items():\n            if name not in self._data_name_list:\n                shape, dtype, policy_str = meta\n                assert policy_str in self._all_possible_part_policy\n                shared_data = empty_shared_mem(\n                    name + \"-kvdata-\", False, shape, dtype\n                )\n                dlpack = shared_data.to_dlpack()\n                self._data_store[name] = F.zerocopy_from_dlpack(dlpack)\n                self._part_policy[name] = self._all_possible_part_policy[\n                    policy_str\n                ]\n                self._pull_handlers[name] = default_pull_handler\n                self._push_handlers[name] = default_push_handler\n        # Get full data shape across servers\n        for name, meta in response.meta.items():\n            if name not in self._data_name_list:\n                shape, _, _ = meta\n                data_shape = list(shape)\n                data_shape[0] = 0\n                request = GetPartShapeRequest(name)\n                # send request to all main server nodes\n                for machine_id in range(self._machine_count):\n                    server_id = machine_id * self._group_count\n                    rpc.send_request(server_id, request)\n                # recv response 
from all the main server nodes\n                for _ in range(self._machine_count):\n                    res = rpc.recv_response()\n                    data_shape[0] += res.shape[0]\n                self._full_data_shape[name] = tuple(data_shape)\n        # Send meta data to backup servers\n        for name, meta in response.meta.items():\n            shape, dtype, policy_str = meta\n            request = SendMetaToBackupRequest(\n                name,\n                dtype,\n                shape,\n                policy_str,\n                self._pull_handlers[name],\n                self._push_handlers[name],\n            )\n            # send request to all the backup server nodes\n            for i in range(self._group_count - 1):\n                server_id = self._machine_id * self._group_count + i + 1\n                rpc.send_request(server_id, request)\n            # recv response from all the backup server nodes\n            for _ in range(self._group_count - 1):\n                response = rpc.recv_response()\n                assert response.msg == SEND_META_TO_BACKUP_MSG\n            self._data_name_list.add(name)\n            # map_shared_data happens only at DistGraph initialization\n            # TODO(xiangsx): We assume there is no non-graph data initialized at this time\n            self._gdata_name_list.add(name)\n        self.barrier()\n\n    def gdata_name_list(self):\n        \"\"\"Get all the graph data name\"\"\"\n        return list(self._gdata_name_list)\n\n    def data_name_list(self):\n        \"\"\"Get all the data name\"\"\"\n        return list(self._data_name_list)\n\n    def get_data_meta(self, name):\n        \"\"\"Get meta data (data_type, data_shape, partition_policy)\"\"\"\n        assert len(name) > 0, \"name cannot be empty.\"\n        data_type = F.dtype(self._data_store[name])\n        data_shape = self._full_data_shape[name]\n        part_policy = self._part_policy[name]\n        return (data_type, data_shape, 
part_policy)\n\n    def get_partid(self, name, id_tensor):\n        \"\"\"\n        Parameters\n        ----------\n        name : str\n            data name\n        id_tensor : tensor\n            a vector storing the global data ID\n        \"\"\"\n        assert len(name) > 0, \"name cannot be empty.\"\n        id_tensor = utils.toindex(id_tensor)\n        id_tensor = id_tensor.tousertensor()\n        assert F.ndim(id_tensor) == 1, \"ID must be a vector.\"\n        # partition data\n        machine_id = self._part_policy[name].to_partid(id_tensor)\n\n        return machine_id\n\n    def push(self, name, id_tensor, data_tensor):\n        \"\"\"Push data to KVServer.\n\n        Note that, the push() is an non-blocking operation that will return immediately.\n\n        Parameters\n        ----------\n        name : str\n            data name\n        id_tensor : tensor\n            a vector storing the global data ID\n        data_tensor : tensor\n            a tensor with the same row size of data ID\n        \"\"\"\n        assert len(name) > 0, \"name cannot be empty.\"\n        id_tensor = utils.toindex(id_tensor)\n        id_tensor = id_tensor.tousertensor()\n        assert F.ndim(id_tensor) == 1, \"ID must be a vector.\"\n        assert (\n            F.shape(id_tensor)[0] == F.shape(data_tensor)[0]\n        ), \"The data must has the same row size with ID.\"\n        # partition data\n        machine_id = self._part_policy[name].to_partid(id_tensor)\n        # sort index by machine id\n        sorted_id = F.tensor(np.argsort(F.asnumpy(machine_id)))\n        id_tensor = id_tensor[sorted_id]\n        data_tensor = data_tensor[sorted_id]\n        machine, count = np.unique(F.asnumpy(machine_id), return_counts=True)\n        # push data to server by order\n        start = 0\n        local_id = None\n        local_data = None\n        for idx, machine_idx in enumerate(machine):\n            end = start + count[idx]\n            if start == end:  # No data for 
target machine\n                continue\n            partial_id = id_tensor[start:end]\n            partial_data = data_tensor[start:end]\n            if machine_idx == self._machine_id:  # local push\n                # Note that DO NOT push local data right now because we can overlap\n                # communication-local_push here\n                local_id = self._part_policy[name].to_local(partial_id)\n                local_data = partial_data\n            else:  # push data to remote server\n                request = PushRequest(name, partial_id, partial_data)\n                rpc.send_request_to_machine(machine_idx, request)\n            start += count[idx]\n        if local_id is not None:  # local push\n            self._push_handlers[name](\n                self._data_store, name, local_id, local_data\n            )\n\n    def pull(self, name, id_tensor):\n        \"\"\"Pull message from KVServer.\n\n        Parameters\n        ----------\n        name : str\n            data name\n        id_tensor : tensor\n            a vector storing the ID list\n\n        Returns\n        -------\n        tensor\n            a data tensor with the same row size of id_tensor.\n        \"\"\"\n        assert len(name) > 0, \"name cannot be empty.\"\n        id_tensor = utils.toindex(id_tensor)\n        id_tensor = id_tensor.tousertensor()\n        assert F.ndim(id_tensor) == 1, \"ID must be a vector.\"\n        if self._pull_handlers[name] is default_pull_handler:  # Use fast-pull\n            part_id = self._part_policy[name].to_partid(id_tensor)\n            return rpc.fast_pull(\n                name,\n                id_tensor,\n                part_id,\n                KVSTORE_PULL,\n                self._machine_count,\n                self._group_count,\n                self._machine_id,\n                self._client_id,\n                self._data_store[name],\n                self._part_policy[name],\n            )\n        else:\n            # partition data\n 
           machine_id = self._part_policy[name].to_partid(id_tensor)\n            # sort index by machine id\n            sorted_id = F.tensor(np.argsort(F.asnumpy(machine_id)))\n            back_sorted_id = F.tensor(np.argsort(F.asnumpy(sorted_id)))\n            id_tensor = id_tensor[sorted_id]\n            machine, count = np.unique(\n                F.asnumpy(machine_id), return_counts=True\n            )\n            # pull data from server by order\n            start = 0\n            pull_count = 0\n            local_id = None\n            for idx, machine_idx in enumerate(machine):\n                end = start + count[idx]\n                if start == end:  # No data for target machine\n                    continue\n                partial_id = id_tensor[start:end]\n                if machine_idx == self._machine_id:  # local pull\n                    # Note that DO NOT pull local data right now because we can overlap\n                    # communication-local_pull here\n                    local_id = self._part_policy[name].to_local(partial_id)\n                else:  # pull data from remote server\n                    request = PullRequest(name, partial_id)\n                    rpc.send_request_to_machine(machine_idx, request)\n                    pull_count += 1\n                start += count[idx]\n            # recv response\n            response_list = []\n            if local_id is not None:  # local pull\n                local_data = self._pull_handlers[name](\n                    self._data_store, name, local_id\n                )\n                server_id = self._main_server_id\n                local_response = PullResponse(server_id, local_data)\n                response_list.append(local_response)\n            # wait response from remote server nodes\n            for _ in range(pull_count):\n                remote_response = rpc.recv_response()\n                response_list.append(remote_response)\n            # sort response by server_id and 
concat tensor\n            response_list.sort(key=self._take_id)\n            data_tensor = F.cat(\n                seq=[response.data_tensor for response in response_list], dim=0\n            )\n            return data_tensor[\n                back_sorted_id\n            ]  # return data with original index order\n\n    def union(self, operand1_name, operand2_name, output_name):\n        \"\"\"Compute the union of two mask arrays in the KVStore.\"\"\"\n        # Each trainer computes its own result from its local storage.\n        self._data_store[output_name][:] = (\n            self._data_store[operand1_name] | self._data_store[operand2_name]\n        )\n\n    def _take_id(self, elem):\n        \"\"\"Used by sort response list\"\"\"\n        return elem.server_id\n\n    def count_nonzero(self, name):\n        \"\"\"Count nonzero value by pull request from KVServers.\n\n        Parameters\n        ----------\n        name : str\n            data name\n\n        Returns\n        -------\n        int\n            the number of nonzero in this data.\n        \"\"\"\n        total = 0\n        pull_count = 0\n        for machine_id in range(self._machine_count):\n            if machine_id == self._machine_id:\n                local_id = F.tensor(\n                    np.arange(\n                        self._part_policy[name].get_part_size(), dtype=np.int64\n                    )\n                )\n                total += F.count_nonzero(self._data_store[name][local_id])\n            else:\n                request = CountLocalNonzeroRequest(name)\n                rpc.send_request_to_machine(machine_id, request)\n                pull_count += 1\n        for _ in range(pull_count):\n            res = rpc.recv_response()\n            total += res.num_local_nonzero\n        return total\n\n    @property\n    def data_store(self):\n        \"\"\"Return the local partition of the data storage.\n\n        Returns\n        -------\n        dict[str, Tensor]\n            
The tensor storages of the local partition.\n        \"\"\"\n        return self._data_store\n\n\nKVCLIENT = None\n\n\ndef init_kvstore(ip_config, num_servers, role):\n    \"\"\"initialize KVStore\"\"\"\n    global KVCLIENT\n    if KVCLIENT is None:\n        if os.environ.get(\"DGL_DIST_MODE\", \"standalone\") == \"standalone\":\n            KVCLIENT = SA_KVClient()\n        else:\n            KVCLIENT = KVClient(ip_config, num_servers, role)\n\n\ndef close_kvstore():\n    \"\"\"Close the current KVClient\"\"\"\n    global KVCLIENT\n    KVCLIENT = None\n\n\ndef get_kvstore():\n    \"\"\"get the KVClient\"\"\"\n    return KVCLIENT\n"
  },
  {
    "path": "python/dgl/distributed/nn/__init__.py",
    "content": "\"\"\"dgl distributed.nn.\"\"\"\nimport importlib\nimport os\nimport sys\n\nfrom ...backend import backend_name\nfrom ...utils import expand_as_pair\n\n\ndef _load_backend(mod_name):\n    mod = importlib.import_module(\".%s\" % mod_name, __name__)\n    thismod = sys.modules[__name__]\n    for api, obj in mod.__dict__.items():\n        setattr(thismod, api, obj)\n\n\n_load_backend(backend_name)\n"
  },
  {
    "path": "python/dgl/distributed/nn/mxnet/__init__.py",
    "content": ""
  },
  {
    "path": "python/dgl/distributed/nn/pytorch/__init__.py",
    "content": "\"\"\"dgl distributed sparse embedding for pytorch.\"\"\"\nfrom .sparse_emb import DistEmbedding\n"
  },
  {
    "path": "python/dgl/distributed/nn/pytorch/sparse_emb.py",
    "content": "\"\"\"Define sparse embedding and optimizer.\"\"\"\n\nimport torch as th\n\nfrom .... import backend as F, utils\nfrom ...dist_tensor import DistTensor\n\n\nclass DistEmbedding:\n    \"\"\"Distributed node embeddings.\n\n    DGL provides a distributed embedding to support models that require learnable embeddings.\n    DGL's distributed embeddings are mainly used for learning node embeddings of graph models.\n    Because distributed embeddings are part of a model, they are updated by mini-batches.\n    The distributed embeddings have to be updated by DGL's optimizers instead of\n    the optimizers provided by the deep learning frameworks (e.g., Pytorch and MXNet).\n\n    To support efficient training on a graph with many nodes, the embeddings support sparse\n    updates. That is, only the embeddings involved in a mini-batch computation are updated.\n    Please refer to `Distributed Optimizers <https://docs.dgl.ai/api/python/dgl.distributed.html#\n    distributed-embedding-optimizer>`__ for available optimizers in DGL.\n\n    Distributed embeddings are sharded and stored in a cluster of machines in the same way as\n    :class:`dgl.distributed.DistTensor`, except that distributed embeddings are trainable.\n    Because distributed embeddings are sharded\n    in the same way as nodes and edges of a distributed graph, it is usually much more\n    efficient to access than the sparse embeddings provided by the deep learning frameworks.\n\n    Parameters\n    ----------\n    num_embeddings : int\n        The number of embeddings. Currently, the number of embeddings has to be the same as\n        the number of nodes or the number of edges.\n    embedding_dim : int\n        The dimension size of embeddings.\n    name : str, optional\n        The name of the embeddings. 
The name can uniquely identify embeddings in a system\n        so that another DistEmbedding object can referent to the same embeddings.\n    init_func : callable, optional\n        The function to create the initial data. If the init function is not provided,\n        the values of the embeddings are initialized to zero.\n    part_policy : PartitionPolicy, optional\n        The partition policy that assigns embeddings to different machines in the cluster.\n        Currently, it only supports node partition policy or edge partition policy.\n        The system determines the right partition policy automatically.\n\n    Examples\n    --------\n    >>> def initializer(shape, dtype):\n            arr = th.zeros(shape, dtype=dtype)\n            arr.uniform_(-1, 1)\n            return arr\n    >>> emb = dgl.distributed.DistEmbedding(g.num_nodes(), 10, init_func=initializer)\n    >>> optimizer = dgl.distributed.optim.SparseAdagrad([emb], lr=0.001)\n    >>> for blocks in dataloader:\n    ...     feats = emb(nids)\n    ...     loss = F.sum(feats + 1, 0)\n    ...     loss.backward()\n    ...     optimizer.step()\n\n    Note\n    ----\n    When a ``DistEmbedding``  object is used in the forward computation, users\n    have to invoke\n    :py:meth:`~dgl.distributed.optim.SparseAdagrad.step` afterwards. 
Otherwise,\n    there will be some memory leak.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_embeddings,\n        embedding_dim,\n        name=None,\n        init_func=None,\n        part_policy=None,\n    ):\n        self._tensor = DistTensor(\n            (num_embeddings, embedding_dim),\n            F.float32,\n            name,\n            init_func=init_func,\n            part_policy=part_policy,\n        )\n        self._trace = []\n        self._name = name\n        self._num_embeddings = num_embeddings\n        self._embedding_dim = embedding_dim\n\n        # Check whether it is multi-gpu/distributed training or not\n        if th.distributed.is_initialized():\n            self._rank = th.distributed.get_rank()\n            self._world_size = th.distributed.get_world_size()\n        # [TODO] The following code is clearly wrong but changing it to \"raise DGLError\"\n        # actually fails unit test.  ???\n        # else:\n        #     assert 'th.distributed should be initialized'\n        self._optm_state = None  # track optimizer state\n        self._part_policy = part_policy\n\n    def __call__(self, idx, device=th.device(\"cpu\")):\n        \"\"\"\n        node_ids : th.tensor\n            Index of the embeddings to collect.\n        device : th.device\n            Target device to put the collected embeddings.\n\n        Returns\n        -------\n        Tensor\n            The requested node embeddings\n        \"\"\"\n        idx = utils.toindex(idx).tousertensor()\n        emb = self._tensor[idx].to(device, non_blocking=True)\n        if F.is_recording():\n            emb = F.attach_grad(emb)\n            self._trace.append((idx.to(device, non_blocking=True), emb))\n        return emb\n\n    def reset_trace(self):\n        \"\"\"Reset the traced data.\"\"\"\n        self._trace = []\n\n    @property\n    def part_policy(self):\n        \"\"\"Return the partition policy\n\n        Returns\n        -------\n        PartitionPolicy\n  
          partition policy\n        \"\"\"\n        return self._part_policy\n\n    @property\n    def name(self):\n        \"\"\"Return the name of the embeddings\n\n        Returns\n        -------\n        str\n            The name of the embeddings\n        \"\"\"\n        return self._tensor.tensor_name\n\n    @property\n    def data_name(self):\n        \"\"\"Return the data name of the embeddings\n\n        Returns\n        -------\n        str\n            The data name of the embeddings\n        \"\"\"\n        return self._tensor._name\n\n    @property\n    def kvstore(self):\n        \"\"\"Return the kvstore client\n\n        Returns\n        -------\n        KVClient\n            The kvstore client\n        \"\"\"\n        return self._tensor.kvstore\n\n    @property\n    def num_embeddings(self):\n        \"\"\"Return the number of embeddings\n\n        Returns\n        -------\n        int\n            The number of embeddings\n        \"\"\"\n        return self._num_embeddings\n\n    @property\n    def embedding_dim(self):\n        \"\"\"Return the dimension of embeddings\n\n        Returns\n        -------\n        int\n            The dimension of embeddings\n        \"\"\"\n        return self._embedding_dim\n\n    @property\n    def optm_state(self):\n        \"\"\"Return the optimizer related state tensor.\n\n        Returns\n        -------\n        tuple of torch.Tensor\n            The optimizer related state.\n        \"\"\"\n        return self._optm_state\n\n    @property\n    def weight(self):\n        \"\"\"Return the tensor storing the node embeddings\n\n        Returns\n        -------\n        torch.Tensor\n            The tensor storing the node embeddings\n        \"\"\"\n        return self._tensor\n"
  },
  {
    "path": "python/dgl/distributed/nn/tensorflow/__init__.py",
    "content": ""
  },
  {
    "path": "python/dgl/distributed/optim/__init__.py",
    "content": "\"\"\"dgl distributed.optims.\"\"\"\nimport importlib\nimport os\nimport sys\n\nfrom ...backend import backend_name\nfrom ...utils import expand_as_pair\n\n\ndef _load_backend(mod_name):\n    mod = importlib.import_module(\".%s\" % mod_name, __name__)\n    thismod = sys.modules[__name__]\n    for api, obj in mod.__dict__.items():\n        setattr(thismod, api, obj)\n\n\n_load_backend(backend_name)\n"
  },
  {
    "path": "python/dgl/distributed/optim/mxnet/__init__.py",
    "content": ""
  },
  {
    "path": "python/dgl/distributed/optim/pytorch/__init__.py",
    "content": "\"\"\"dgl distributed sparse optimizer for pytorch.\"\"\"\nfrom .sparse_optim import SparseAdagrad, SparseAdam\n"
  },
  {
    "path": "python/dgl/distributed/optim/pytorch/sparse_optim.py",
    "content": "\"\"\"Node embedding optimizers for distributed training\"\"\"\nimport abc\nimport warnings\nfrom abc import abstractmethod\nfrom os.path import exists\n\nimport torch as th\n\nimport dgl\n\nfrom .... import backend as F\nfrom ...dist_tensor import DistTensor\nfrom ...graph_partition_book import EDGE_PART_POLICY, NODE_PART_POLICY\nfrom ...nn.pytorch import DistEmbedding\nfrom .utils import alltoall, alltoallv\n\nEMB_STATES = \"emb_states\"\nWORLD_SIZE = \"world_size\"\nIDS = \"ids\"\nPARAMS = \"params\"\nSTATES = \"states\"\n\n\nclass DistSparseGradOptimizer(abc.ABC):\n    r\"\"\"The abstract dist sparse optimizer.\n\n    Note: dgl dist sparse optimizer only work with dgl.distributed.DistEmbedding\n\n    Parameters\n    ----------\n    params : list of DistEmbedding\n        The list of DistEmbedding.\n    lr : float\n        The learning rate.\n    \"\"\"\n\n    def __init__(self, params, lr):\n        self._params = params\n        self._lr = lr\n        self._rank = None\n        self._world_size = None\n        self._shared_cache = {}\n        self._clean_grad = False\n        self._opt_meta = {}\n        self._state = {}\n        ## collect all hyper parameters for save\n        self._defaults = {}\n\n        if th.distributed.is_initialized():\n            self._rank = th.distributed.get_rank()\n            self._world_size = th.distributed.get_world_size()\n        else:\n            self._rank = 0\n            self._world_size = 1\n\n    def local_state_dict(self):\n        \"\"\"Return the state pertaining to current rank of the optimizer.\n\n        Returns\n        -------\n        dict\n            Local state dict\n            Example Dict of Adagrad Optimizer:\n            .. 
code-block:: json\n\n            {\n                \"params\": {\n                    \"_lr\": 0.01,\n                    \"_eps\": \"1e-8\",\n                    \"world_size\": 2\n                },\n                \"emb_states\": {\n                    \"emb_name1\": {\n                        \"ids\": [0, 2, 4, 6 ,8 ,10], ## tensor,\n                        \"emb_name1_sum\": [0.1 , 0.2, 0.5, 0.1, 0.2] ## tensor,\n                    },\n                    \"emb_name2\": {\n                        \"ids\": [0, 2, 4, 6 ,8 ,10], ## tensor,\n                        \"emb_name2_sum\": [0.3 , 0.2, 0.4, 0.5, 0.2] ## tensor,\n                    }\n                }\n            }\n\n            :param json: json object\n\n        See Also\n        --------\n        load_local_state_dict\n        \"\"\"\n        local_state_dict = {}\n        local_state_dict[EMB_STATES] = {}\n        local_state_dict[PARAMS] = {WORLD_SIZE: self._world_size}\n        for emb in self._params:\n            trainers_per_machine = self._world_size // max(\n                1, dgl.distributed.get_num_machines()\n            )\n            emb_state_dict = {}\n            part_policy = (\n                emb.part_policy if emb.part_policy else emb.weight.part_policy\n            )\n            idx = self._get_local_ids(part_policy)\n            if trainers_per_machine > 1:\n                kv_idx_split = (idx % trainers_per_machine).long()\n                local_rank = self._rank % trainers_per_machine\n                mask = kv_idx_split == local_rank\n                idx = F.boolean_mask(idx, mask)\n            emb_state_dict.update({IDS: idx})\n            emb_state = {}\n            states = (\n                list(self._state[emb.name])\n                if isinstance(self._state[emb.name], tuple)\n                else [self._state[emb.name]]\n            )\n            emb_state = {state.name: state[idx] for state in states}\n            emb_state_dict.update({STATES: emb_state})\n   
         local_state_dict[EMB_STATES].update({emb.name: emb_state_dict})\n        local_state_dict[PARAMS].update(self._defaults)\n        return local_state_dict\n\n    def load_local_state_dict(self, local_state_dict):\n        \"\"\"Load the local state from the input state_dict,\n        updating the optimizer as needed.\n\n        Parameters\n        ----------\n        local_state_dict : dict\n            Optimizer state; should be an object returned\n            from a call to local_state_dict().\n\n        See Also\n        --------\n        local_state_dict\n        \"\"\"\n        for emb_name, emb_state in local_state_dict[EMB_STATES].items():\n            idx = emb_state[IDS]\n            # As state of an embedding of different optimizers can be a single\n            # DistTensor(Adagrad) or a tuple(Adam) of that, converting it to list for\n            # consistency. The list contains reference(s) to original DistTensor(s).\n            states = (\n                list(self._state[emb_name])\n                if isinstance(self._state[emb_name], tuple)\n                else [self._state[emb_name]]\n            )\n            if len(emb_state[STATES]) != len(states):\n                raise ValueError(\n                    f\"loaded state dict has a different number of states\"\n                    f\" of embedding {emb_name}\"\n                )\n            name_to_index = {\n                state.name: index for index, state in enumerate(states)\n            }\n            for name, state in emb_state[STATES].items():\n                if name not in name_to_index:\n                    raise ValueError(\n                        \"loaded state dict contains a state {name}\"\n                        \"that can't be found in the optimizer states\"\n                    )\n                state_idx = name_to_index[name]\n                state = state.to(\n                    th.device(\"cpu\"), states[name_to_index[name]].dtype\n                )\n            
    states[state_idx][idx] = state\n        self._defaults.update(local_state_dict[PARAMS])\n        self.__dict__.update(local_state_dict[PARAMS])\n\n    def save(self, f):\n        \"\"\"Save the local state_dict to disk on per rank.\n\n        Saved dict contains 2 parts:\n\n        * 'params': hyper parameters of the optimizer.\n        * 'emb_states': partial optimizer states, each embedding contains 2 items:\n            1. ```ids```: global id of the nodes/edges stored in this rank.\n            2. ```states```: state data corrseponding to ```ids```.\n\n        NOTE: This needs to be called on all ranks.\n\n        Parameters\n        ----------\n        f : Union[str, os.PathLike]\n            The path of the file to save to.\n\n        See Also\n        --------\n        load\n        \"\"\"\n        if self._world_size > 1:\n            th.distributed.barrier()\n        f = f if isinstance(f, str) else str(f, \"UTF-8\")\n        f = f\"{f}_{self._rank}\"\n        th.save(self.local_state_dict(), f)\n        if self._world_size > 1:\n            th.distributed.barrier()\n\n    def load(self, f):\n        \"\"\"Load the local state of the optimizer from the file on per rank.\n\n        NOTE: This needs to be called on all ranks.\n\n        Parameters\n        ----------\n        f : Union[str, os.PathLike]\n            The path of the file to load from.\n\n        See Also\n        --------\n        save\n        \"\"\"\n        if self._world_size > 1:\n            th.distributed.barrier()\n        f = f if isinstance(f, str) else str(f, \"UTF-8\")\n        f_attach_rank = f\"{f}_{self._rank}\"\n        # Don't throw error here to support device number scale-out\n        # after reloading, but make sure your hyper parameter is same\n        # as before because new added local optimizers will be filled\n        # in nothing\n        if not exists(f_attach_rank):\n            warnings.warn(f\"File {f_attach_rank} can't be found, load nothing.\")\n        
else:\n            old_world_size = self._load_state_from(f_attach_rank)\n            # Device number scale-in\n            if self._world_size < old_world_size:\n                for rank in range(\n                    self._rank + self._world_size,\n                    old_world_size,\n                    self._world_size,\n                ):\n                    self._load_state_from(f\"{f}_{rank}\")\n        if self._world_size > 1:\n            th.distributed.barrier()\n\n    def _load_state_from(self, f):\n        local_state_dict = th.load(f)\n        world_size = local_state_dict[PARAMS].pop(WORLD_SIZE)\n        self.load_local_state_dict(local_state_dict)\n        return world_size\n\n    def _get_local_ids(self, part_policy):\n        if EDGE_PART_POLICY in part_policy.policy_str:\n            return part_policy.partition_book.partid2eids(\n                part_policy.part_id, part_policy.type_name\n            )\n        elif NODE_PART_POLICY in part_policy.policy_str:\n            return part_policy._partition_book.partid2nids(\n                part_policy.part_id, part_policy.type_name\n            )\n        else:\n            raise RuntimeError(\n                \"Cannot support policy: %s \" % part_policy.policy_str\n            )\n\n    def step(self):\n        \"\"\"The step function.\n\n        The step function is invoked at the end of every batch to push the gradients\n        of the embeddings involved in a mini-batch to DGL's servers and update the embeddings.\n        \"\"\"\n        with th.no_grad():\n            # [Rui]\n            # As `gloo` supports CPU tensors only while `nccl` supports GPU\n            # tensors only, we firstly create tensors on the corresponding\n            # devices and then copy the data to target device if needed.\n            # Please note that the target device can be different from the\n            # preferred device.\n            target_device = None\n            preferred_device = (\n                
th.device(f\"cuda:{self._rank}\")\n                if th.distributed.get_backend() == \"nccl\"\n                else th.device(\"cpu\")\n            )\n            local_indics = {emb.name: [] for emb in self._params}\n            local_grads = {emb.name: [] for emb in self._params}\n            for emb in self._params:\n                name = emb.weight.name\n                kvstore = emb.weight.kvstore\n                trainers_per_server = self._world_size // kvstore.num_servers\n\n                idics = []\n                grads = []\n                for trace in emb._trace:\n                    if trace[1].grad is not None:\n                        idics.append(trace[0])\n                        grads.append(trace[1].grad.data)\n                    else:\n                        assert len(trace[0]) == 0\n                # If the sparse embedding is not used in the previous forward step\n                # The idx and grad will be empty, initialize them as empty tensors to\n                # avoid crashing the optimizer step logic.\n                #\n                # Note: we cannot skip the gradient exchange and update steps as other\n                # working processes may send gradient update requests corresponding\n                # to certain embedding to this process.\n                #\n                # [WARNING][TODO][Rui]\n                # For empty idx and grad, we blindly create data on the\n                # preferred device, which may not be the device where the\n                # embedding is stored.\n                idics = (\n                    th.cat(idics, dim=0)\n                    if len(idics) != 0\n                    else th.zeros((0,), dtype=th.int64, device=preferred_device)\n                )\n                grads = (\n                    th.cat(grads, dim=0)\n                    if len(grads) != 0\n                    else th.zeros(\n                        (0, emb.embedding_dim),\n                        dtype=th.float32,\n   
                     device=preferred_device,\n                    )\n                )\n                target_device = grads.device\n\n                # will send grad to each corresponding trainer\n                if self._world_size > 1:\n                    # get idx split from kvstore\n                    idx_split = kvstore.get_partid(emb.data_name, idics)\n                    idx_split_size = []\n                    idics_list = []\n                    grad_list = []\n                    # split idx and grad first\n                    for i in range(kvstore.num_servers):\n                        mask = idx_split == i\n                        idx_i = idics[mask]\n                        grad_i = grads[mask]\n\n                        if trainers_per_server <= 1:\n                            idx_split_size.append(\n                                th.tensor(\n                                    [idx_i.shape[0]],\n                                    dtype=th.int64,\n                                    device=preferred_device,\n                                )\n                            )\n                            idics_list.append(idx_i)\n                            grad_list.append(grad_i)\n                        else:\n                            kv_idx_split = th.remainder(\n                                idx_i, trainers_per_server\n                            ).long()\n                            for j in range(trainers_per_server):\n                                mask = kv_idx_split == j\n                                idx_j = idx_i[mask]\n                                grad_j = grad_i[mask]\n                                idx_split_size.append(\n                                    th.tensor(\n                                        [idx_j.shape[0]],\n                                        dtype=th.int64,\n                                        device=preferred_device,\n                                    )\n                                
)\n                                idics_list.append(idx_j)\n                                grad_list.append(grad_j)\n\n                    # if one machine launch multiple KVServer, they share the same storage.\n                    # For each machine, the pytorch rank is num_trainers *\n                    # machine_id + i\n\n                    # use scatter to sync across trainers about the p2p tensor size\n                    # Note: If we have GPU nccl support, we can use all_to_all to\n                    # sync information here\n                    gather_list = list(\n                        th.empty(\n                            [self._world_size],\n                            dtype=th.int64,\n                            device=preferred_device,\n                        ).chunk(self._world_size)\n                    )\n                    alltoall(\n                        self._rank,\n                        self._world_size,\n                        gather_list,\n                        idx_split_size,\n                    )\n                    idx_gather_list = [\n                        th.empty(\n                            (int(num_emb),),\n                            dtype=idics.dtype,\n                            device=preferred_device,\n                        )\n                        for num_emb in gather_list\n                    ]\n                    alltoallv(\n                        self._rank,\n                        self._world_size,\n                        idx_gather_list,\n                        idics_list,\n                    )\n                    local_indics[name] = idx_gather_list\n                    grad_gather_list = [\n                        th.empty(\n                            (int(num_emb), grads.shape[1]),\n                            dtype=grads.dtype,\n                            device=preferred_device,\n                        )\n                        for num_emb in gather_list\n                    ]\n      
              alltoallv(\n                        self._rank,\n                        self._world_size,\n                        grad_gather_list,\n                        grad_list,\n                    )\n                    local_grads[name] = grad_gather_list\n                else:\n                    local_indics[name] = [idics]\n                    local_grads[name] = [grads]\n\n            if self._clean_grad:\n                # clean gradient track\n                for emb in self._params:\n                    emb.reset_trace()\n                self._clean_grad = False\n\n            # do local update\n            for emb in self._params:\n                name = emb.weight.name\n                idx = th.cat(local_indics[name], dim=0)\n                grad = th.cat(local_grads[name], dim=0)\n                self.update(\n                    idx.to(target_device, non_blocking=True),\n                    grad.to(target_device, non_blocking=True),\n                    emb,\n                )\n\n        # synchronized gradient update\n        if self._world_size > 1:\n            th.distributed.barrier()\n\n    @abstractmethod\n    def update(self, idx, grad, emb):\n        \"\"\"Update embeddings in a sparse manner\n        Sparse embeddings are updated in mini batches. 
We maintain gradient states for\n        each embedding so they can be updated separately.\n\n        Parameters\n        ----------\n        idx : tensor\n            Index of the embeddings to be updated.\n        grad : tensor\n            Gradient of each embedding.\n        emb : dgl.distributed.DistEmbedding\n            Sparse node embedding to update.\n        \"\"\"\n\n    def zero_grad(self):\n        \"\"\"clean grad cache\"\"\"\n        self._clean_grad = True\n\n\ndef initializer(shape, dtype):\n    \"\"\"Sparse optimizer state initializer\n\n    Parameters\n    ----------\n    shape : tuple of ints\n        The shape of the state tensor\n    dtype : torch dtype\n        The data type of the state tensor\n    \"\"\"\n    arr = th.zeros(shape, dtype=dtype)\n    return arr\n\n\nclass SparseAdagrad(DistSparseGradOptimizer):\n    r\"\"\"Distributed Node embedding optimizer using the Adagrad algorithm.\n\n    This optimizer implements a distributed sparse version of Adagrad algorithm for\n    optimizing :class:`dgl.distributed.DistEmbedding`. 
Being sparse means it only updates\n    the embeddings whose gradients have updates, which are usually a very\n    small portion of the total embeddings.\n\n    Adagrad maintains a :math:`G_{t,i,j}` for every parameter in the embeddings, where\n    :math:`G_{t,i,j}=G_{t-1,i,j} + g_{t,i,j}^2` and :math:`g_{t,i,j}` is the gradient of\n    the dimension :math:`j` of embedding :math:`i` at step :math:`t`.\n\n    NOTE: The support of sparse Adagrad optimizer is experimental.\n\n    Parameters\n    ----------\n    params : list[dgl.distributed.DistEmbedding]\n        The list of dgl.distributed.DistEmbedding.\n    lr : float\n        The learning rate.\n    eps : float, Optional\n        The term added to the denominator to improve numerical stability\n        Default: 1e-10\n    \"\"\"\n\n    def __init__(self, params, lr, eps=1e-10):\n        super(SparseAdagrad, self).__init__(params, lr)\n        self._eps = eps\n        self._defaults = {\"_lr\": lr, \"_eps\": eps}\n        # We need to register a state sum for each embedding in the kvstore.\n        for emb in params:\n            assert isinstance(\n                emb, DistEmbedding\n            ), \"SparseAdagrad only supports dgl.distributed.DistEmbedding\"\n\n            name = emb.name + \"_sum\"\n            state = DistTensor(\n                (emb.num_embeddings, emb.embedding_dim),\n                th.float32,\n                name,\n                init_func=initializer,\n                part_policy=emb.part_policy,\n                is_gdata=False,\n            )\n            assert (\n                emb.name not in self._state\n            ), \"{} already registered in the optimizer\".format(emb.name)\n            self._state[emb.name] = state\n\n    def update(self, idx, grad, emb):\n        \"\"\"Update embeddings in a sparse manner\n        Sparse embeddings are updated in mini batches. 
We maintain gradient states for\n        each embedding so they can be updated separately.\n\n        Parameters\n        ----------\n        idx : tensor\n            Index of the embeddings to be updated.\n        grad : tensor\n            Gradient of each embedding.\n        emb : dgl.distributed.DistEmbedding\n            Sparse embedding to update.\n        \"\"\"\n        eps = self._eps\n        clr = self._lr\n\n        state_dev = th.device(\"cpu\")\n        exec_dev = grad.device\n\n        # only perform async copies cpu -> gpu, or gpu-> gpu, but block\n        # when copying to the cpu, so as to ensure the copy is finished\n        # before operating on the data on the cpu\n        state_block = state_dev == th.device(\"cpu\") and exec_dev != state_dev\n\n        # the update is non-linear so indices must be unique\n        grad_indices, inverse, cnt = th.unique(\n            idx, return_inverse=True, return_counts=True\n        )\n        grad_values = th.zeros(\n            (grad_indices.shape[0], grad.shape[1]), device=exec_dev\n        )\n        grad_values.index_add_(0, inverse, grad)\n        grad_values = grad_values / cnt.unsqueeze(1)\n        grad_sum = grad_values * grad_values\n\n        # update grad state\n        grad_state = self._state[emb.name][grad_indices].to(exec_dev)\n        grad_state += grad_sum\n        grad_state_dst = grad_state.to(state_dev, non_blocking=True)\n        if state_block:\n            # use events to try and overlap CPU and GPU as much as possible\n            update_event = th.cuda.Event()\n            update_event.record()\n\n        # update emb\n        std_values = grad_state.sqrt_().add_(eps)\n        tmp = clr * grad_values / std_values\n        tmp_dst = tmp.to(state_dev, non_blocking=True)\n\n        if state_block:\n            std_event = th.cuda.Event()\n            std_event.record()\n            # wait for our transfers from exec_dev to state_dev to finish\n            # before we can use them\n   
         update_event.wait()\n        self._state[emb.name][grad_indices] = grad_state_dst\n\n        if state_block:\n            # wait for the transfer of std_values to finish before we\n            # can use it\n            std_event.wait()\n        emb._tensor[grad_indices] -= tmp_dst\n\n\nclass SparseAdam(DistSparseGradOptimizer):\n    r\"\"\"Distributed Node embedding optimizer using the Adam algorithm.\n\n    This optimizer implements a distributed sparse version of Adam algorithm for\n    optimizing :class:`dgl.distributed.DistEmbedding`. Being sparse means it only updates\n    the embeddings whose gradients have updates, which are usually a very\n    small portion of the total embeddings.\n\n    Adam maintains a :math:`Gm_{t,i,j}` and `Gp_{t,i,j}` for every parameter\n    in the embeddings, where\n    :math:`Gm_{t,i,j}=beta1 * Gm_{t-1,i,j} + (1-beta1) * g_{t,i,j}`,\n    :math:`Gp_{t,i,j}=beta2 * Gp_{t-1,i,j} + (1-beta2) * g_{t,i,j}^2`,\n    :math:`g_{t,i,j} = lr * Gm_{t,i,j} / (1 - beta1^t) / \\sqrt{Gp_{t,i,j} / (1 - beta2^t)}` and\n    :math:`g_{t,i,j}` is the gradient of the dimension :math:`j` of embedding :math:`i`\n    at step :math:`t`.\n\n    NOTE: The support of sparse Adam optimizer is experimental.\n\n    Parameters\n    ----------\n    params : list[dgl.distributed.DistEmbedding]\n        The list of dgl.distributed.DistEmbedding.\n    lr : float\n        The learning rate.\n    betas : tuple[float, float], Optional\n        Coefficients used for computing running averages of gradient and its square.\n        Default: (0.9, 0.999)\n    eps : float, Optional\n        The term added to the denominator to improve numerical stability\n        Default: 1e-8\n    \"\"\"\n\n    def __init__(self, params, lr, betas=(0.9, 0.999), eps=1e-08):\n        super(SparseAdam, self).__init__(params, lr)\n        self._eps = eps\n        # We need to register a state sum for each embedding in the kvstore.\n        self._beta1 = betas[0]\n        self._beta2 = 
betas[1]\n        self._defaults = {\n            \"_lr\": lr,\n            \"_eps\": eps,\n            \"_beta1\": betas[0],\n            \"_beta2\": betas[1],\n        }\n        for emb in params:\n            assert isinstance(\n                emb, DistEmbedding\n            ), \"SparseAdam only supports dgl.distributed.DistEmbedding\"\n\n            state_step = DistTensor(\n                (emb.num_embeddings,),\n                th.float32,\n                emb.name + \"_step\",\n                init_func=initializer,\n                part_policy=emb.part_policy,\n                is_gdata=False,\n            )\n            state_mem = DistTensor(\n                (emb.num_embeddings, emb.embedding_dim),\n                th.float32,\n                emb.name + \"_mem\",\n                init_func=initializer,\n                part_policy=emb.part_policy,\n                is_gdata=False,\n            )\n            state_power = DistTensor(\n                (emb.num_embeddings, emb.embedding_dim),\n                th.float32,\n                emb.name + \"_power\",\n                init_func=initializer,\n                part_policy=emb.part_policy,\n                is_gdata=False,\n            )\n            state = (state_step, state_mem, state_power)\n            assert (\n                emb.name not in self._state\n            ), \"{} already registered in the optimizer\".format(emb.name)\n            self._state[emb.name] = state\n\n    def update(self, idx, grad, emb):\n        \"\"\"Update embeddings in a sparse manner\n        Sparse embeddings are updated in mini batches. 
We maintain gradient states for\n        each embedding so they can be updated separately.\n\n        Parameters\n        ----------\n        idx : tensor\n            Index of the embeddings to be updated.\n        grad : tensor\n            Gradient of each embedding.\n        emb : dgl.distributed.DistEmbedding\n            Sparse embedding to update.\n        \"\"\"\n        beta1 = self._beta1\n        beta2 = self._beta2\n        eps = self._eps\n        clr = self._lr\n        state_step, state_mem, state_power = self._state[emb.name]\n\n        state_dev = th.device(\"cpu\")\n        exec_dev = grad.device\n\n        # only perform async copies cpu -> gpu, or gpu-> gpu, but block\n        # when copying to the cpu, so as to ensure the copy is finished\n        # before operating on the data on the cpu\n        state_block = state_dev == th.device(\"cpu\") and exec_dev != state_dev\n\n        # the update is non-linear so indices must be unique\n        grad_indices, inverse, cnt = th.unique(\n            idx, return_inverse=True, return_counts=True\n        )\n        # update grad state\n        state_idx = grad_indices.to(state_dev)\n        # The original implementation will cause read/write contention.\n        #    state_step[state_idx] += 1\n        #    state_step = state_step[state_idx].to(exec_dev, non_blocking=True)\n        # In a distributed environment, the first line of code will send write requests to\n        # kvstore servers to update the state_step which is asynchronous and the second line\n        # of code will also send read requests to kvstore servers. The write and read requests\n        # may be handled by different kvstore servers managing the same portion of the\n        # state_step dist tensor in the same node. 
So that, the read request may read an old\n        # value (i.e., 0 in the first iteration) which will cause\n        # update_power_corr to be NaN\n        state_val = state_step[state_idx] + 1\n        state_step[state_idx] = state_val\n        state_step = state_val.to(exec_dev)\n        orig_mem = state_mem[state_idx].to(exec_dev)\n        orig_power = state_power[state_idx].to(exec_dev)\n\n        grad_values = th.zeros(\n            (grad_indices.shape[0], grad.shape[1]), device=exec_dev\n        )\n        grad_values.index_add_(0, inverse, grad)\n        grad_values = grad_values / cnt.unsqueeze(1)\n        grad_mem = grad_values\n        grad_power = grad_values * grad_values\n        update_mem = beta1 * orig_mem + (1.0 - beta1) * grad_mem\n        update_power = beta2 * orig_power + (1.0 - beta2) * grad_power\n        update_mem_dst = update_mem.to(state_dev, non_blocking=True)\n        update_power_dst = update_power.to(state_dev, non_blocking=True)\n        if state_block:\n            # use events to try and overlap CPU and GPU as much as possible\n            update_event = th.cuda.Event()\n            update_event.record()\n\n        update_mem_corr = update_mem / (\n            1.0 - th.pow(th.tensor(beta1, device=exec_dev), state_step)\n        ).unsqueeze(1)\n        update_power_corr = update_power / (\n            1.0 - th.pow(th.tensor(beta2, device=exec_dev), state_step)\n        ).unsqueeze(1)\n        std_values = clr * update_mem_corr / (th.sqrt(update_power_corr) + eps)\n\n        std_values_dst = std_values.to(state_dev, non_blocking=True)\n\n        if state_block:\n            std_event = th.cuda.Event()\n            std_event.record()\n            # wait for our transfers from exec_dev to state_dev to finish\n            # before we can use them\n            update_event.wait()\n        state_mem[state_idx] = update_mem_dst\n        state_power[state_idx] = update_power_dst\n\n        if state_block:\n            # wait for the 
transfer of std_values to finish before we\n            # can use it\n            std_event.wait()\n        emb._tensor[state_idx] -= std_values_dst\n"
  },
  {
    "path": "python/dgl/distributed/optim/pytorch/utils.py",
    "content": "\"\"\"Provide utils for distributed sparse optimizers\n\"\"\"\nimport torch as th\nimport torch.distributed as dist\n\n\ndef alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):\n    \"\"\"Each process scatters list of input tensors to all processes in a cluster\n    and return gathered list of tensors in output list. The tensors should have the same shape.\n\n    Parameters\n    ----------\n    rank : int\n        The rank of current worker\n    world_size : int\n        The size of the entire communicator\n    output_tensor_list : List of tensor\n        The received tensors\n    input_tensor_list : List of tensor\n        The tensors to exchange\n    \"\"\"\n    input_tensor_list = [\n        tensor.to(th.device(\"cpu\")) for tensor in input_tensor_list\n    ]\n    for i in range(world_size):\n        dist.scatter(\n            output_tensor_list[i], input_tensor_list if i == rank else [], src=i\n        )\n\n\ndef alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list):\n    \"\"\"Each process scatters list of input tensors to all processes in a cluster\n    and return gathered list of tensors in output list.\n\n    Parameters\n    ----------\n    rank : int\n        The rank of current worker\n    world_size : int\n        The size of the entire communicator\n    output_tensor_list : List of tensor\n        The received tensors\n    input_tensor_list : List of tensor\n        The tensors to exchange\n    \"\"\"\n    # send tensor to each target trainer using torch.distributed.isend\n    # isend is async\n    senders = []\n    for i in range(world_size):\n        if i == rank:\n            output_tensor_list[i] = input_tensor_list[i].to(th.device(\"cpu\"))\n        else:\n            sender = dist.isend(\n                input_tensor_list[i].to(th.device(\"cpu\")), dst=i\n            )\n            senders.append(sender)\n\n    for i in range(world_size):\n        if i != rank:\n            
dist.recv(output_tensor_list[i], src=i)\n\n    th.distributed.barrier()\n\n\ndef alltoall(rank, world_size, output_tensor_list, input_tensor_list):\n    \"\"\"Each process scatters list of input tensors to all processes in a cluster\n    and return gathered list of tensors in output list. The tensors should have the same shape.\n\n    Parameters\n    ----------\n    rank : int\n        The rank of current worker\n    world_size : int\n        The size of the entire communicator\n    output_tensor_list : List of tensor\n        The received tensors\n    input_tensor_list : List of tensor\n        The tensors to exchange\n    \"\"\"\n    if th.distributed.get_backend() == \"nccl\":\n        th.distributed.all_to_all(output_tensor_list, input_tensor_list)\n    else:\n        alltoall_cpu(\n            rank,\n            world_size,\n            output_tensor_list,\n            input_tensor_list,\n        )\n\n\ndef alltoallv(rank, world_size, output_tensor_list, input_tensor_list):\n    \"\"\"Each process scatters list of input tensors to all processes in a cluster\n    and return gathered list of tensors in output list.\n\n    Parameters\n    ----------\n    rank : int\n        The rank of current worker\n    world_size : int\n        The size of the entire communicator\n    output_tensor_list : List of tensor\n        The received tensors\n    input_tensor_list : List of tensor\n        The tensors to exchange\n    \"\"\"\n    if th.distributed.get_backend() == \"nccl\":\n        th.distributed.all_to_all(output_tensor_list, input_tensor_list)\n    else:\n        alltoallv_cpu(\n            rank,\n            world_size,\n            output_tensor_list,\n            input_tensor_list,\n        )\n"
  },
  {
    "path": "python/dgl/distributed/optim/tensorflow/__init__.py",
    "content": ""
  },
  {
    "path": "python/dgl/distributed/partition.py",
    "content": "\"\"\"Functions for partitions. \"\"\"\n\nimport concurrent\nimport concurrent.futures\nimport copy\nimport json\nimport logging\nimport multiprocessing as mp\nimport os\nimport time\nfrom functools import partial\n\nimport numpy as np\n\nimport torch\n\nfrom .. import backend as F, graphbolt as gb\nfrom ..base import dgl_warning, DGLError, EID, ETYPE, NID, NTYPE\nfrom ..convert import heterograph, to_homogeneous\nfrom ..data.utils import load_graphs, load_tensors, save_graphs, save_tensors\nfrom ..partition import (\n    get_peak_mem,\n    metis_partition_assignment,\n    partition_graph_with_halo,\n)\nfrom ..random import choice as random_choice\nfrom ..transforms import sort_csc_by_tag, sort_csr_by_tag\nfrom .constants import DEFAULT_ETYPE, DEFAULT_NTYPE, DGL2GB_EID, GB_DST_ID\nfrom .graph_partition_book import (\n    _etype_str_to_tuple,\n    _etype_tuple_to_str,\n    RangePartitionBook,\n)\n\n\nRESERVED_FIELD_DTYPE = {\n    \"inner_node\": (\n        F.uint8\n    ),  # A flag indicates whether the node is inside a partition.\n    \"inner_edge\": (\n        F.uint8\n    ),  # A flag indicates whether the edge is inside a partition.\n    NID: F.int64,\n    EID: F.int64,\n    NTYPE: F.int16,\n    # `sort_csr_by_tag` and `sort_csc_by_tag` works on int32/64 only.\n    ETYPE: F.int32,\n}\n\n\ndef _format_part_metadata(part_metadata, formatter):\n    \"\"\"Format etypes with specified formatter.\"\"\"\n    for key in [\"edge_map\", \"etypes\"]:\n        if key not in part_metadata:\n            continue\n        orig_data = part_metadata[key]\n        if not isinstance(orig_data, dict):\n            continue\n        new_data = {}\n        for etype, data in orig_data.items():\n            etype = formatter(etype)\n            new_data[etype] = data\n        part_metadata[key] = new_data\n    return part_metadata\n\n\ndef _load_part_config(part_config):\n    \"\"\"Load part config and format.\"\"\"\n    try:\n        with open(part_config) as f:\n     
       part_metadata = _format_part_metadata(\n                json.load(f), _etype_str_to_tuple\n            )\n    except AssertionError as e:\n        raise DGLError(\n            f\"Failed to load partition config due to {e}. \"\n            \"Probably caused by outdated config. If so, please refer to \"\n            \"https://github.com/dmlc/dgl/tree/master/tools#change-edge-\"\n            \"type-to-canonical-edge-type-for-partition-configuration-json\"\n        )\n    return part_metadata\n\n\ndef _dump_part_config(part_config, part_metadata):\n    \"\"\"Format and dump part config.\"\"\"\n    part_metadata = _format_part_metadata(part_metadata, _etype_tuple_to_str)\n    with open(part_config, \"w\") as outfile:\n        json.dump(part_metadata, outfile, sort_keys=False, indent=4)\n\n\ndef process_partitions(g, formats=None, sort_etypes=False):\n    \"\"\"Preprocess partitions before saving:\n    1. format data types.\n    2. sort csc/csr by tag.\n    \"\"\"\n    for k, dtype in RESERVED_FIELD_DTYPE.items():\n        if k in g.ndata:\n            g.ndata[k] = F.astype(g.ndata[k], dtype)\n        if k in g.edata:\n            g.edata[k] = F.astype(g.edata[k], dtype)\n\n    if (sort_etypes) and (formats is not None):\n        if \"csr\" in formats:\n            g = sort_csr_by_tag(g, tag=g.edata[ETYPE], tag_type=\"edge\")\n        if \"csc\" in formats:\n            g = sort_csc_by_tag(g, tag=g.edata[ETYPE], tag_type=\"edge\")\n    return g\n\n\ndef _save_dgl_graphs(filename, g_list, formats=None):\n    save_graphs(filename, g_list, formats=formats)\n\n\ndef _get_inner_node_mask(graph, ntype_id, gpb=None):\n    ndata = (\n        graph.node_attributes\n        if isinstance(graph, gb.FusedCSCSamplingGraph)\n        else graph.ndata\n    )\n    assert \"inner_node\" in ndata, \"'inner_node' is not in nodes' data\"\n    if NTYPE in ndata or gpb is not None:\n        ntype = (\n            gpb.map_to_per_ntype(ndata[NID])[0]\n            if gpb is not None\n      
      else ndata[NTYPE]\n        )\n        dtype = F.dtype(ndata[\"inner_node\"])\n        return ndata[\"inner_node\"] * F.astype(ntype == ntype_id, dtype) == 1\n    else:\n        return ndata[\"inner_node\"] == 1\n\n\ndef _get_inner_edge_mask(\n    graph,\n    etype_id,\n):\n    edata = (\n        graph.edge_attributes\n        if isinstance(graph, gb.FusedCSCSamplingGraph)\n        else graph.edata\n    )\n    assert \"inner_edge\" in edata, \"'inner_edge' is not in edges' data\"\n    etype = (\n        graph.type_per_edge\n        if isinstance(graph, gb.FusedCSCSamplingGraph)\n        else (graph.edata[ETYPE] if ETYPE in graph.edata else None)\n    )\n    if etype is not None:\n        dtype = F.dtype(edata[\"inner_edge\"])\n        return edata[\"inner_edge\"] * F.astype(etype == etype_id, dtype) == 1\n    else:\n        return edata[\"inner_edge\"] == 1\n\n\ndef _get_part_ranges(id_ranges):\n    res = {}\n    for key in id_ranges:\n        # Normally, each element has two values that represent the starting ID and the ending ID\n        # of the ID range in a partition.\n        # If not, the data is probably still in the old format, in which only the ending ID is\n        # stored. 
We need to convert it to the format we expect.\n        if not isinstance(id_ranges[key][0], list):\n            start = 0\n            for i, end in enumerate(id_ranges[key]):\n                id_ranges[key][i] = [start, end]\n                start = end\n        res[key] = np.concatenate(\n            [np.array(l) for l in id_ranges[key]]\n        ).reshape(-1, 2)\n    return res\n\n\ndef _verify_dgl_partition(graph, part_id, gpb, ntypes, etypes):\n    \"\"\"Verify the partition of a DGL graph.\"\"\"\n    assert (\n        NID in graph.ndata\n    ), \"the partition graph should contain node mapping to global node ID\"\n    assert (\n        EID in graph.edata\n    ), \"the partition graph should contain edge mapping to global edge ID\"\n\n    for ntype in ntypes:\n        ntype_id = ntypes[ntype]\n        # graph.ndata[NID] are global homogeneous node IDs.\n        nids = F.boolean_mask(\n            graph.ndata[NID], _get_inner_node_mask(graph, ntype_id)\n        )\n        partids1 = gpb.nid2partid(nids)\n        _, per_type_nids = gpb.map_to_per_ntype(nids)\n        partids2 = gpb.nid2partid(per_type_nids, ntype)\n        assert np.all(F.asnumpy(partids1 == part_id)), (\n            \"Unexpected partition IDs are found in the loaded partition \"\n            \"while querying via global homogeneous node IDs.\"\n        )\n        assert np.all(F.asnumpy(partids2 == part_id)), (\n            \"Unexpected partition IDs are found in the loaded partition \"\n            \"while querying via type-wise node IDs.\"\n        )\n    for etype in etypes:\n        etype_id = etypes[etype]\n        # graph.edata[EID] are global homogeneous edge IDs.\n        eids = F.boolean_mask(\n            graph.edata[EID], _get_inner_edge_mask(graph, etype_id)\n        )\n        partids1 = gpb.eid2partid(eids)\n        _, per_type_eids = gpb.map_to_per_etype(eids)\n        partids2 = gpb.eid2partid(per_type_eids, etype)\n        assert np.all(F.asnumpy(partids1 == part_id)), (\n      
      \"Unexpected partition IDs are found in the loaded partition \"\n            \"while querying via global homogeneous edge IDs.\"\n        )\n        assert np.all(F.asnumpy(partids2 == part_id)), (\n            \"Unexpected partition IDs are found in the loaded partition \"\n            \"while querying via type-wise edge IDs.\"\n        )\n\n\ndef _verify_graphbolt_partition(graph, part_id, gpb, ntypes, etypes):\n    \"\"\"Verify the partition of a GraphBolt graph.\"\"\"\n    required_ndata_fields = [NID]\n    required_edata_fields = [EID]\n    assert all(\n        field in graph.node_attributes for field in required_ndata_fields\n    ), \"the partition graph should contain node mapping to global node ID.\"\n    assert all(\n        field in graph.edge_attributes for field in required_edata_fields\n    ), \"the partition graph should contain edge mapping to global edge ID.\"\n\n    num_edges = graph.total_num_edges\n    local_src_ids = graph.indices\n    local_dst_ids = gb.expand_indptr(\n        graph.csc_indptr, dtype=local_src_ids.dtype, output_size=num_edges\n    )\n    global_src_ids = graph.node_attributes[NID][local_src_ids]\n    global_dst_ids = graph.node_attributes[NID][local_dst_ids]\n\n    etype_ids, type_wise_eids = gpb.map_to_per_etype(graph.edge_attributes[EID])\n    if graph.type_per_edge is not None:\n        assert torch.equal(etype_ids, graph.type_per_edge)\n    etype_ids, etype_ids_indices = torch.sort(etype_ids)\n    global_src_ids = global_src_ids[etype_ids_indices]\n    global_dst_ids = global_dst_ids[etype_ids_indices]\n    type_wise_eids = type_wise_eids[etype_ids_indices]\n\n    src_ntype_ids, src_type_wise_nids = gpb.map_to_per_ntype(global_src_ids)\n    dst_ntype_ids, dst_type_wise_nids = gpb.map_to_per_ntype(global_dst_ids)\n\n    data_dict = dict()\n    edge_ids = dict()\n    for c_etype, etype_id in etypes.items():\n        idx = etype_ids == etype_id\n        src_ntype, etype, dst_ntype = c_etype\n        if idx.sum() == 0:\n  
          continue\n        actual_src_ntype_ids = src_ntype_ids[idx]\n        actual_dst_ntype_ids = dst_ntype_ids[idx]\n        expected_src_ntype_ids = ntypes[src_ntype]\n        expected_dst_ntype_ids = ntypes[dst_ntype]\n        assert all(actual_src_ntype_ids == expected_src_ntype_ids), (\n            f\"Unexpected types of source nodes for {c_etype}. Expected: \"\n            f\"{expected_src_ntype_ids}, but got: {actual_src_ntype_ids}.\"\n        )\n        assert all(actual_dst_ntype_ids == expected_dst_ntype_ids), (\n            f\"Unexpected types of destination nodes for {c_etype}. Expected: \"\n            f\"{expected_dst_ntype_ids}, but got: {actual_dst_ntype_ids}.\"\n        )\n        data_dict[c_etype] = (src_type_wise_nids[idx], dst_type_wise_nids[idx])\n        edge_ids[c_etype] = type_wise_eids[idx]\n\n    # Make sure node/edge IDs are not out of range.\n    hg = heterograph(\n        data_dict, {ntype: gpb._num_nodes(ntype) for ntype in ntypes}\n    )\n    for etype in edge_ids:\n        hg.edges[etype].data[EID] = edge_ids[etype]\n    assert all(\n        hg.num_edges(etype) == len(eids) for etype, eids in edge_ids.items()\n    ), \"The number of edges per etype in the partition graph is not correct.\"\n    assert num_edges == hg.num_edges(), (\n        f\"The total number of edges in the partition graph is not correct. \"\n        f\"Expected: {num_edges}, but got: {hg.num_edges()}.\"\n    )\n    print(f\"Partition {part_id} looks good!\")\n\n\ndef load_partition(part_config, part_id, load_feats=True, use_graphbolt=False):\n    \"\"\"Load data of a partition from the data path.\n\n    A partition data includes a graph structure of the partition, a dict of node tensors,\n    a dict of edge tensors and some metadata. The partition may contain the HALO nodes,\n    which are the nodes replicated from other partitions. However, the dict of node tensors\n    only contains the node data that belongs to the local partition. 
Similarly, edge tensors\n    only contains the edge data that belongs to the local partition. The metadata include\n    the information of the global graph (not the local partition), which includes the number\n    of nodes, the number of edges as well as the node assignment of the global graph.\n\n    The function currently loads data through the local filesystem interface.\n\n    Parameters\n    ----------\n    part_config : str\n        The path of the partition config file.\n    part_id : int\n        The partition ID.\n    load_feats : bool, optional\n        Whether to load node/edge feats. If False, the returned node/edge feature\n        dictionaries will be empty. Default: True.\n    use_graphbolt : bool, optional\n        Whether to load GraphBolt partition. Default: False.\n\n    Returns\n    -------\n    DGLGraph\n        The graph partition structure.\n    Dict[str, Tensor]\n        Node features.\n    Dict[(str, str, str), Tensor]\n        Edge features.\n    GraphPartitionBook\n        The graph partition information.\n    str\n        The graph name\n    List[str]\n        The node types\n    List[(str, str, str)]\n        The edge types\n    \"\"\"\n    config_path = os.path.dirname(part_config)\n    relative_to_config = lambda path: os.path.join(config_path, path)\n\n    with open(part_config) as conf_f:\n        part_metadata = json.load(conf_f)\n    assert (\n        \"part-{}\".format(part_id) in part_metadata\n    ), \"part-{} does not exist\".format(part_id)\n    part_files = part_metadata[\"part-{}\".format(part_id)]\n\n    exist_dgl_graph = exist_graphbolt_graph = False\n    if os.path.exists(os.path.join(config_path, f\"part{part_id}\", \"graph.dgl\")):\n        use_graphbolt = False\n        exist_dgl_graph = True\n    if os.path.exists(\n        os.path.join(\n            config_path, f\"part{part_id}\", \"fused_csc_sampling_graph.pt\"\n        )\n    ):\n        use_graphbolt = True\n        exist_graphbolt_graph = True\n\n    # Check if 
both DGL graph and GraphBolt graph exist or not exist. Make sure only one exists.\n    if not exist_dgl_graph and not exist_graphbolt_graph:\n        raise ValueError(\"The graph object doesn't exist.\")\n    if exist_dgl_graph and exist_graphbolt_graph:\n        raise ValueError(\n            \"Both DGL graph and GraphBolt graph exist. Please remove one.\"\n        )\n\n    if use_graphbolt:\n        part_graph_field = \"part_graph_graphbolt\"\n    else:\n        part_graph_field = \"part_graph\"\n    assert (\n        part_graph_field in part_files\n    ), f\"the partition does not contain graph structure: {part_graph_field}\"\n    partition_path = relative_to_config(part_files[part_graph_field])\n    logging.info(\n        \"Start to load partition from %s which is \"\n        \"%d bytes. It may take non-trivial \"\n        \"time for large partition.\",\n        partition_path,\n        os.path.getsize(partition_path),\n    )\n    graph = (\n        torch.load(partition_path, weights_only=False)\n        if use_graphbolt\n        else load_graphs(partition_path)[0][0]\n    )\n    logging.info(\"Finished loading partition from %s.\", partition_path)\n\n    gpb, graph_name, ntypes, etypes = load_partition_book(part_config, part_id)\n    ntypes_list = list(ntypes.keys())\n    etypes_list = list(etypes.keys())\n\n    if \"DGL_DIST_DEBUG\" in os.environ:\n        _verify_func = (\n            _verify_graphbolt_partition\n            if use_graphbolt\n            else _verify_dgl_partition\n        )\n        _verify_func(graph, part_id, gpb, ntypes, etypes)\n\n    node_feats = {}\n    edge_feats = {}\n    if load_feats:\n        node_feats, edge_feats = load_partition_feats(part_config, part_id)\n\n    return (\n        graph,\n        node_feats,\n        edge_feats,\n        gpb,\n        graph_name,\n        ntypes_list,\n        etypes_list,\n    )\n\n\ndef load_partition_feats(\n    part_config, part_id, load_nodes=True, load_edges=True\n):\n    \"\"\"Load 
node/edge feature data from a partition.\n\n    Parameters\n    ----------\n    part_config : str\n        The path of the partition config file.\n    part_id : int\n        The partition ID.\n    load_nodes : bool, optional\n        Whether to load node features. If ``False``, ``None`` is returned.\n    load_edges : bool, optional\n        Whether to load edge features. If ``False``, ``None`` is returned.\n\n    Returns\n    -------\n    Dict[str, Tensor] or None\n        Node features.\n    Dict[str, Tensor] or None\n        Edge features.\n    \"\"\"\n    config_path = os.path.dirname(part_config)\n    relative_to_config = lambda path: os.path.join(config_path, path)\n\n    with open(part_config) as conf_f:\n        part_metadata = json.load(conf_f)\n    assert (\n        \"part-{}\".format(part_id) in part_metadata\n    ), \"part-{} does not exist\".format(part_id)\n    part_files = part_metadata[\"part-{}\".format(part_id)]\n    assert (\n        \"node_feats\" in part_files\n    ), \"the partition does not contain node features.\"\n    assert (\n        \"edge_feats\" in part_files\n    ), \"the partition does not contain edge feature.\"\n    node_feats = None\n    if load_nodes:\n        feat_path = relative_to_config(part_files[\"node_feats\"])\n        logging.debug(\n            \"Start to load node data from %s which is \" \"%d bytes.\",\n            feat_path,\n            os.path.getsize(feat_path),\n        )\n        node_feats = load_tensors(feat_path)\n        logging.info(\"Finished loading node data.\")\n    edge_feats = None\n    if load_edges:\n        feat_path = relative_to_config(part_files[\"edge_feats\"])\n        logging.debug(\n            \"Start to load edge data from %s which is \" \"%d bytes.\",\n            feat_path,\n            os.path.getsize(feat_path),\n        )\n        edge_feats = load_tensors(feat_path)\n        logging.info(\"Finished loading edge data.\")\n    # In the old format, the feature name doesn't contain 
node/edge type.\n    # For compatibility, let's add node/edge types to the feature names.\n    if node_feats is not None:\n        new_feats = {}\n        for name in node_feats:\n            feat = node_feats[name]\n            if name.find(\"/\") == -1:\n                name = DEFAULT_NTYPE + \"/\" + name\n            new_feats[name] = feat\n        node_feats = new_feats\n    if edge_feats is not None:\n        new_feats = {}\n        for name in edge_feats:\n            feat = edge_feats[name]\n            if name.find(\"/\") == -1:\n                name = _etype_tuple_to_str(DEFAULT_ETYPE) + \"/\" + name\n            new_feats[name] = feat\n        edge_feats = new_feats\n\n    return node_feats, edge_feats\n\n\ndef load_partition_book(part_config, part_id, part_metadata=None):\n    \"\"\"Load a graph partition book from the partition config file.\n\n    Parameters\n    ----------\n    part_config : str\n        The path of the partition config file.\n    part_id : int\n        The partition ID.\n    part_metadata : dict\n        The meta data of partition.\n\n    Returns\n    -------\n    GraphPartitionBook\n        The global partition information.\n    str\n        The graph name\n    dict\n        The node types\n    dict\n        The edge types\n    \"\"\"\n    if part_metadata is None:\n        part_metadata = _load_part_config(part_config)\n    assert \"num_parts\" in part_metadata, \"num_parts does not exist.\"\n    assert (\n        part_metadata[\"num_parts\"] > part_id\n    ), \"part {} is out of range (#parts: {})\".format(\n        part_id, part_metadata[\"num_parts\"]\n    )\n    num_parts = part_metadata[\"num_parts\"]\n    assert (\n        \"num_nodes\" in part_metadata\n    ), \"cannot get the number of nodes of the global graph.\"\n    assert (\n        \"num_edges\" in part_metadata\n    ), \"cannot get the number of edges of the global graph.\"\n    assert \"node_map\" in part_metadata, \"cannot get the node map.\"\n    assert \"edge_map\" 
in part_metadata, \"cannot get the edge map.\"\n    assert \"graph_name\" in part_metadata, \"cannot get the graph name\"\n\n    # If this is a range partitioning, node_map actually stores a list, whose elements\n    # indicate the boundary of range partitioning. Otherwise, node_map stores a filename\n    # that contains node map in a NumPy array.\n    node_map = part_metadata[\"node_map\"]\n    edge_map = part_metadata[\"edge_map\"]\n    if isinstance(node_map, dict):\n        for key in node_map:\n            is_range_part = isinstance(node_map[key], list)\n            break\n    elif isinstance(node_map, list):\n        is_range_part = True\n        node_map = {DEFAULT_NTYPE: node_map}\n    else:\n        is_range_part = False\n    if isinstance(edge_map, list):\n        edge_map = {DEFAULT_ETYPE: edge_map}\n\n    ntypes = {DEFAULT_NTYPE: 0}\n    etypes = {DEFAULT_ETYPE: 0}\n    if \"ntypes\" in part_metadata:\n        ntypes = part_metadata[\"ntypes\"]\n    if \"etypes\" in part_metadata:\n        etypes = part_metadata[\"etypes\"]\n\n    if isinstance(node_map, dict):\n        for key in node_map:\n            assert key in ntypes, \"The node type {} is invalid\".format(key)\n    if isinstance(edge_map, dict):\n        for key in edge_map:\n            assert key in etypes, \"The edge type {} is invalid\".format(key)\n\n    if not is_range_part:\n        raise TypeError(\"Only RangePartitionBook is supported currently.\")\n\n    node_map = _get_part_ranges(node_map)\n    edge_map = _get_part_ranges(edge_map)\n\n    # Format dtype of node/edge map if dtype is specified.\n    def _format_node_edge_map(part_metadata, map_type, data):\n        key = f\"{map_type}_map_dtype\"\n        if key not in part_metadata:\n            return data\n        dtype = part_metadata[key]\n        assert dtype in [\"int32\", \"int64\"], (\n            f\"The {map_type} map dtype should be either int32 or int64, \"\n            f\"but got {dtype}.\"\n        )\n        for key in 
data:\n            data[key] = data[key].astype(dtype)\n        return data\n\n    node_map = _format_node_edge_map(part_metadata, \"node\", node_map)\n    edge_map = _format_node_edge_map(part_metadata, \"edge\", edge_map)\n\n    # Sort the node/edge maps by the node/edge type ID.\n    node_map = dict(sorted(node_map.items(), key=lambda x: ntypes[x[0]]))\n    edge_map = dict(sorted(edge_map.items(), key=lambda x: etypes[x[0]]))\n\n    def _assert_is_sorted(id_map):\n        id_ranges = np.array(list(id_map.values()))\n        ids = []\n        for i in range(num_parts):\n            ids.append(id_ranges[:, i, :])\n        ids = np.array(ids).flatten()\n        assert np.all(\n            ids[:-1] <= ids[1:]\n        ), f\"The node/edge map is not sorted: {ids}\"\n\n    _assert_is_sorted(node_map)\n    _assert_is_sorted(edge_map)\n\n    return (\n        RangePartitionBook(\n            part_id, num_parts, node_map, edge_map, ntypes, etypes\n        ),\n        part_metadata[\"graph_name\"],\n        ntypes,\n        etypes,\n    )\n\n\ndef _get_orig_ids(g, sim_g, orig_nids, orig_eids):\n    \"\"\"Convert/construct the original node IDs and edge IDs.\n\n    It handles multiple cases:\n     * If the graph has been reshuffled and it's a homogeneous graph, we just return\n       the original node IDs and edge IDs in the inputs.\n     * If the graph has been reshuffled and it's a heterogeneous graph, we need to\n       split the original node IDs and edge IDs in the inputs based on the node types\n       and edge types.\n     * If the graph is not shuffled, the original node IDs and edge IDs don't change.\n\n    Parameters\n    ----------\n    g : DGLGraph\n       The input graph for partitioning.\n    sim_g : DGLGraph\n        The homogeneous version of the input graph.\n    orig_nids : tensor or None\n        The original node IDs after the input graph is reshuffled.\n    orig_eids : tensor or None\n        The original edge IDs after the input graph is 
reshuffled.\n\n    Returns\n    -------\n    tensor or dict of tensors, tensor or dict of tensors\n    \"\"\"\n    is_hetero = not g.is_homogeneous\n    if is_hetero:\n        # Get the type IDs\n        orig_ntype = F.gather_row(sim_g.ndata[NTYPE], orig_nids)\n        orig_etype = F.gather_row(sim_g.edata[ETYPE], orig_eids)\n        # Mapping between shuffled global IDs to original per-type IDs\n        orig_nids = F.gather_row(sim_g.ndata[NID], orig_nids)\n        orig_eids = F.gather_row(sim_g.edata[EID], orig_eids)\n        orig_nids = {\n            ntype: F.boolean_mask(\n                orig_nids, orig_ntype == g.get_ntype_id(ntype)\n            )\n            for ntype in g.ntypes\n        }\n        orig_eids = {\n            etype: F.boolean_mask(\n                orig_eids, orig_etype == g.get_etype_id(etype)\n            )\n            for etype in g.canonical_etypes\n        }\n    return orig_nids, orig_eids\n\n\ndef _set_trainer_ids(g, sim_g, node_parts):\n    \"\"\"Set the trainer IDs for each node and edge on the input graph.\n\n    The trainer IDs will be stored as node data and edge data in the input graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n       The input graph for partitioning.\n    sim_g : DGLGraph\n        The homogeneous version of the input graph.\n    node_parts : tensor\n        The node partition ID for each node in `sim_g`.\n    \"\"\"\n    if g.is_homogeneous:\n        g.ndata[\"trainer_id\"] = node_parts\n        # An edge is assigned to a partition based on its destination node.\n        g.edata[\"trainer_id\"] = F.gather_row(node_parts, g.edges()[1])\n    else:\n        for ntype_id, ntype in enumerate(g.ntypes):\n            type_idx = sim_g.ndata[NTYPE] == ntype_id\n            orig_nid = F.boolean_mask(sim_g.ndata[NID], type_idx)\n            trainer_id = F.zeros((len(orig_nid),), F.dtype(node_parts), F.cpu())\n            F.scatter_row_inplace(\n                trainer_id, orig_nid, F.boolean_mask(node_parts, 
type_idx)\n            )\n            g.nodes[ntype].data[\"trainer_id\"] = trainer_id\n        for c_etype in g.canonical_etypes:\n            # An edge is assigned to a partition based on its destination node.\n            _, _, dst_type = c_etype\n            trainer_id = F.gather_row(\n                g.nodes[dst_type].data[\"trainer_id\"], g.edges(etype=c_etype)[1]\n            )\n            g.edges[c_etype].data[\"trainer_id\"] = trainer_id\n\n\ndef _partition_to_graphbolt(\n    parts,\n    part_i,\n    part_config,\n    part_metadata,\n    *,\n    store_eids=True,\n    store_inner_node=False,\n    store_inner_edge=False,\n    graph_formats=None,\n):\n    gpb, _, ntypes, etypes = load_partition_book(\n        part_config=part_config, part_id=part_i, part_metadata=part_metadata\n    )\n    graph = parts[part_i]\n    csc_graph = _convert_dgl_partition_to_gb(\n        ntypes=ntypes,\n        etypes=etypes,\n        gpb=gpb,\n        part_meta=part_metadata,\n        graph=graph,\n        store_eids=store_eids,\n        store_inner_edge=store_inner_edge,\n        store_inner_node=store_inner_node,\n        graph_formats=graph_formats,\n    )\n    rel_path_result = _save_graph_gb(\n        part_config=part_config, part_id=part_i, csc_graph=csc_graph\n    )\n    part_metadata[f\"part-{part_i}\"][\"part_graph_graphbolt\"] = rel_path_result\n\n\ndef _update_node_edge_map(node_map_val, edge_map_val, g, num_parts):\n    \"\"\"\n    If the original graph contains few nodes or edges for specific node/edge\n    types, the partitioned graph may have empty partitions for these types. And\n    the node_map_val and edge_map_val will have -1 for the start and end ID of\n    these types. This function updates the node_map_val and edge_map_val to be\n    contiguous.\n\n    Example case:\n    Suppose we have a heterogeneous graph with 3 node/edge types and the number\n    of partitions is 3. 
A possible node_map_val or edge_map_val is as follows:\n\n    | part_id\\\\Node/Edge Type| Type A |  Type B | Type C |\n    |------------------------|--------|---------|--------|\n    | 0                      | 0, 1   |  -1, -1 |  2, 3  |\n    | 1                      | -1, -1 |  3, 4   |  4, 5  |\n    | 2                      | 5, 6   |  7, 8   |  -1, -1|\n\n    As node/edge IDs are contiguous in node/edge type for each partition, we can\n    update the node_map_val and edge_map_val via updating the start and end ID\n    in row-wise order.\n\n    Updated node_map_val or edge_map_val:\n\n    | part_id\\\\Node/Edge Type| Type A |  Type B | Type C |\n    |------------------------|--------|---------|--------|\n    | 0                      |  0, 1  |  1, 1   |  2, 3  |\n    | 1                      |  3, 3  |  3, 4   |  4, 5  |\n    | 2                      |  5, 6  |  7, 8   |  8, 8  |\n\n    \"\"\"\n    # Update the node_map_val to be contiguous.\n    ntype_ids = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}\n    ntype_ids_reverse = {v: k for k, v in ntype_ids.items()}\n    for part_id in range(num_parts):\n        for ntype_id in list(ntype_ids.values()):\n            ntype = ntype_ids_reverse[ntype_id]\n            start_id = node_map_val[ntype][part_id][0]\n            end_id = node_map_val[ntype][part_id][1]\n            if not (start_id == -1 and end_id == -1):\n                continue\n            prev_ntype_id = (\n                ntype_ids[ntype] - 1\n                if ntype_ids[ntype] > 0\n                else max(ntype_ids.values())\n            )\n            prev_ntype = ntype_ids_reverse[prev_ntype_id]\n            if ntype_ids[ntype] == 0:\n                if part_id == 0:\n                    node_map_val[ntype][part_id][0] = 0\n                else:\n                    node_map_val[ntype][part_id][0] = node_map_val[prev_ntype][\n                        part_id - 1\n                    ][1]\n            else:\n                
node_map_val[ntype][part_id][0] = node_map_val[prev_ntype][\n                    part_id\n                ][1]\n            node_map_val[ntype][part_id][1] = node_map_val[ntype][part_id][0]\n    # Update the edge_map_val to be contiguous.\n    etype_ids = {etype: g.get_etype_id(etype) for etype in g.canonical_etypes}\n    etype_ids_reverse = {v: k for k, v in etype_ids.items()}\n    for part_id in range(num_parts):\n        for etype_id in list(etype_ids.values()):\n            etype = etype_ids_reverse[etype_id]\n            start_id = edge_map_val[etype][part_id][0]\n            end_id = edge_map_val[etype][part_id][1]\n            if not (start_id == -1 and end_id == -1):\n                continue\n            prev_etype_id = (\n                etype_ids[etype] - 1\n                if etype_ids[etype] > 0\n                else max(etype_ids.values())\n            )\n            prev_etype = etype_ids_reverse[prev_etype_id]\n            if etype_ids[etype] == 0:\n                if part_id == 0:\n                    edge_map_val[etype][part_id][0] = 0\n                else:\n                    edge_map_val[etype][part_id][0] = edge_map_val[prev_etype][\n                        part_id - 1\n                    ][1]\n            else:\n                edge_map_val[etype][part_id][0] = edge_map_val[prev_etype][\n                    part_id\n                ][1]\n            edge_map_val[etype][part_id][1] = edge_map_val[etype][part_id][0]\n\n\ndef partition_graph(\n    g,\n    graph_name,\n    num_parts,\n    out_path,\n    num_hops=1,\n    part_method=\"metis\",\n    balance_ntypes=None,\n    balance_edges=False,\n    return_mapping=False,\n    num_trainers_per_machine=1,\n    objtype=\"cut\",\n    graph_formats=None,\n    use_graphbolt=False,\n    **kwargs,\n):\n    \"\"\"Partition a graph for distributed training and store the partitions on files.\n\n    The partitioning occurs in three steps: 1) run a partition algorithm (e.g., Metis) to\n    assign nodes to 
partitions; 2) construct partition graph structure based on\n    the node assignment; 3) split the node features and edge features based on\n    the partition result.\n\n    When a graph is partitioned, each partition can contain *HALO* nodes, which are assigned\n    to other partitions but are included in this partition for efficiency purposes.\n    In this document, *local nodes/edges* refers to the nodes and edges that truly belong to\n    a partition. The rest are \"HALO nodes/edges\".\n\n    The partitioned data is stored into multiple files organized as follows:\n\n    .. code-block:: none\n\n        data_root_dir/\n          |-- graph_name.json     # partition configuration file in JSON\n          |-- node_map.npy        # partition id of each node stored in a numpy array (optional)\n          |-- edge_map.npy        # partition id of each edge stored in a numpy array (optional)\n          |-- part0/              # data for partition 0\n              |-- node_feat.dgl   # node features stored in binary format\n              |-- edge_feat.dgl   # edge features stored in binary format\n              |-- graph.dgl       # graph structure of this partition stored in binary format\n          |-- part1/              # data for partition 1\n              |-- node_feat.dgl\n              |-- edge_feat.dgl\n              |-- graph.dgl\n\n    First, the metadata of the original graph and the partitioning is stored in a JSON file\n    named after ``graph_name``. This JSON file contains the information of the original graph\n    as well as the path of the files that store each partition. An example is shown below.\n\n    .. 
code-block:: none\n\n        {\n           \"graph_name\" : \"test\",\n           \"part_method\" : \"metis\",\n           \"num_parts\" : 2,\n           \"halo_hops\" : 1,\n           \"node_map\": {\n               \"_N\": [ [ 0, 1261310 ],\n                       [ 1261310, 2449029 ] ]\n           },\n           \"edge_map\": {\n               \"_N:_E:_N\": [ [ 0, 62539528 ],\n                             [ 62539528, 123718280 ] ]\n           },\n           \"etypes\": { \"_N:_E:_N\": 0 },\n           \"ntypes\": { \"_N\": 0 },\n           \"num_nodes\" : 2449029,\n           \"num_edges\" : 123718280,\n           \"part-0\" : {\n             \"node_feats\" : \"part0/node_feat.dgl\",\n             \"edge_feats\" : \"part0/edge_feat.dgl\",\n             \"part_graph\" : \"part0/graph.dgl\"\n           },\n           \"part-1\" : {\n             \"node_feats\" : \"part1/node_feat.dgl\",\n             \"edge_feats\" : \"part1/edge_feat.dgl\",\n             \"part_graph\" : \"part1/graph.dgl\"\n           }\n        }\n\n    Here are the definitions of the fields in the partition configuration file:\n\n    * ``graph_name`` is the name of the graph given by a user.\n    * ``part_method`` is the method used to assign nodes to partitions.\n      Currently, it supports \"random\" and \"metis\".\n    * ``num_parts`` is the number of partitions.\n    * ``halo_hops`` is the number of hops of nodes we include in a partition as HALO nodes.\n    * ``node_map`` is the node assignment map, which tells the partition ID a node is assigned to.\n      The format of ``node_map`` is described below.\n    * ``edge_map`` is the edge assignment map, which tells the partition ID an edge is assigned to.\n    * ``num_nodes`` is the number of nodes in the global graph.\n    * ``num_edges`` is the number of edges in the global graph.\n    * `part-*` stores the data of a partition.\n\n    As node/edge IDs 
are reshuffled, ``node_map`` and ``edge_map`` contain the information\n    for mapping between global node/edge IDs and partition-local node/edge IDs.\n    For heterogeneous graphs, the information in ``node_map`` and ``edge_map`` can also be used\n    to compute node types and edge types. The format of the data in ``node_map`` and ``edge_map``\n    is as follows:\n\n    .. code-block:: none\n\n        {\n            \"node_type\": [ [ part1_start, part1_end ],\n                           [ part2_start, part2_end ],\n                           ... ],\n            ...\n        },\n\n    Essentially, ``node_map`` and ``edge_map`` are dictionaries. The keys are\n    node types and canonical edge types respectively. The values are lists of pairs\n    containing the start and end of the ID range for the corresponding types in a partition.\n    The length of the list is the number of\n    partitions; each element in the list is a tuple that stores the start and the end of\n    an ID range for a particular node/edge type in the partition.\n\n    The graph structure of a partition is stored in a file with the DGLGraph format.\n    Nodes in each partition are *relabeled* to always start with zero. We call the node\n    ID in the original graph, *global ID*, while the relabeled ID in each partition,\n    *local ID*. Each partition graph has an integer node data tensor stored under name\n    `dgl.NID` and each value is the node's global ID. Similarly, edges are relabeled too\n    and the mapping from local ID to global ID is stored as an integer edge data tensor\n    under name `dgl.EID`. 
For a heterogeneous graph, the DGLGraph also contains a node\n    data `dgl.NTYPE` for node type and an edge data `dgl.ETYPE` for the edge type.\n\n    The partition graph contains additional node data (\"inner_node\") and\n    edge data (\"inner_edge\"):\n\n    * \"inner_node\" indicates whether a node belongs to a partition.\n    * \"inner_edge\" indicates whether an edge belongs to a partition.\n\n    Node and edge features are split and stored together with each graph partition.\n    All node/edge features in a partition are stored in a file with DGL format. The node/edge\n    features are stored in dictionaries, in which the key is the node/edge data name and\n    the value is a tensor. We do not store features of HALO nodes and edges.\n\n    When performing Metis partitioning, we can put some constraints on the partitioning.\n    Currently, it supports two constraints to balance the partitioning. By default, Metis\n    always tries to balance the number of nodes in each partition.\n\n    * ``balance_ntypes`` balances the number of nodes of different types in each partition.\n    * ``balance_edges`` balances the number of edges in each partition.\n\n    To balance the node types, a user needs to pass a vector of N elements to indicate\n    the type of each node. N is the number of nodes in the input graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph to partition\n    graph_name : str\n        The name of the graph. The name will be used to construct\n        :py:meth:`~dgl.distributed.DistGraph`.\n    num_parts : int\n        The number of partitions\n    out_path : str\n        The path to store the files for all partitioned data.\n    num_hops : int, optional\n        The number of hops of HALO nodes we construct on a partition graph structure.\n        The default value is 1.\n    part_method : str, optional\n        The partition method. It supports \"random\" and \"metis\". 
The default value is \"metis\".\n    balance_ntypes : tensor, optional\n        Node type of each node. This is a 1D-array of integers. Its values indicate the node\n        type of each node. This argument is used by Metis partition. When the argument is\n        specified, the Metis algorithm will try to partition the input graph into partitions where\n        each partition has roughly the same number of nodes for each node type. The default value\n        is None, which means Metis partitions the graph to only balance the number of nodes.\n    balance_edges : bool\n        Indicate whether to balance the edges in each partition. This argument is used by\n        the Metis algorithm.\n    return_mapping : bool\n        Indicate whether to return the mapping between shuffled node/edge IDs and the original\n        node/edge IDs.\n    num_trainers_per_machine : int, optional\n        The number of trainers per machine. If it is not 1, the whole graph will be first partitioned\n        to each trainer, that is num_parts*num_trainers_per_machine parts. And the trainer ids of\n        each node will be stored in the node feature 'trainer_id'. Then the partitions of trainers\n        on the same machine will be coalesced into one larger partition. The final number of\n        partitions is `num_parts`.\n    objtype : str, \"cut\" or \"vol\"\n        Set the objective as edge-cut minimization or communication volume minimization. This\n        argument is used by the Metis algorithm.\n    graph_formats : str or list[str]\n        Save partitions in specified formats. It could be any combination of ``coo``,\n        ``csc`` and ``csr``. If not specified, save one format only according to what\n        format is available. If multiple formats are available, selection priority\n        from high to low is ``coo``, ``csc``, ``csr``.\n    use_graphbolt : bool, optional\n        Whether to save partitions in GraphBolt format. 
Default: False.\n    kwargs : dict\n        Other keyword arguments for converting DGL partitions to GraphBolt.\n\n    Returns\n    -------\n    Tensor or dict of tensors, optional\n        If `return_mapping=True`, return a 1D tensor that indicates the mapping between shuffled\n        node IDs and the original node IDs for a homogeneous graph; return a dict of 1D tensors\n        whose key is the node type and value is a 1D tensor mapping between shuffled node IDs and\n        the original node IDs for each node type for a heterogeneous graph.\n    Tensor or dict of tensors, optional\n        If `return_mapping=True`, return a 1D tensor that indicates the mapping between shuffled\n        edge IDs and the original edge IDs for a homogeneous graph; return a dict of 1D tensors\n        whose key is the edge type and value is a 1D tensor mapping between shuffled edge IDs and\n        the original edge IDs for each edge type for a heterogeneous graph.\n\n    Examples\n    --------\n    >>> dgl.distributed.partition_graph(g, 'test', 4, num_hops=1, part_method='metis',\n    ...                                 out_path='output/',\n    ...                                 balance_ntypes=g.ndata['train_mask'],\n    ...                                 balance_edges=True)\n    >>> (\n    ...     g, node_feats, edge_feats, gpb, graph_name, ntypes_list, etypes_list,\n    ... 
) = dgl.distributed.load_partition('output/test.json', 0)\n    \"\"\"\n    # 'coo' is required for partition\n    assert \"coo\" in np.concatenate(\n        list(g.formats().values())\n    ), \"'coo' format should be allowed for partitioning graph.\"\n\n    def get_homogeneous(g, balance_ntypes):\n        if g.is_homogeneous:\n            sim_g = to_homogeneous(g)\n            if isinstance(balance_ntypes, dict):\n                assert len(balance_ntypes) == 1\n                bal_ntypes = list(balance_ntypes.values())[0]\n            else:\n                bal_ntypes = balance_ntypes\n        elif isinstance(balance_ntypes, dict):\n            # Here we assign node types for load balancing.\n            # The new node types includes the ones provided by users.\n            num_ntypes = 0\n            for key in g.ntypes:\n                if key in balance_ntypes:\n                    g.nodes[key].data[\"bal_ntype\"] = (\n                        F.astype(balance_ntypes[key], F.int32) + num_ntypes\n                    )\n                    uniq_ntypes = F.unique(balance_ntypes[key])\n                    assert np.all(\n                        F.asnumpy(uniq_ntypes) == np.arange(len(uniq_ntypes))\n                    )\n                    num_ntypes += len(uniq_ntypes)\n                else:\n                    g.nodes[key].data[\"bal_ntype\"] = (\n                        F.ones((g.num_nodes(key),), F.int32, F.cpu())\n                        * num_ntypes\n                    )\n                    num_ntypes += 1\n            sim_g = to_homogeneous(g, ndata=[\"bal_ntype\"])\n            bal_ntypes = sim_g.ndata[\"bal_ntype\"]\n            print(\n                \"The graph has {} node types and balance among {} types\".format(\n                    len(g.ntypes), len(F.unique(bal_ntypes))\n                )\n            )\n            # We now no longer need them.\n            for key in g.ntypes:\n                del g.nodes[key].data[\"bal_ntype\"]\n            
del sim_g.ndata[\"bal_ntype\"]\n        else:\n            sim_g = to_homogeneous(g)\n            bal_ntypes = sim_g.ndata[NTYPE]\n        return sim_g, bal_ntypes\n\n    if objtype not in [\"cut\", \"vol\"]:\n        raise ValueError\n\n    if num_parts == 1:\n        start = time.time()\n        sim_g, balance_ntypes = get_homogeneous(g, balance_ntypes)\n        print(\n            \"Converting to homogeneous graph takes {:.3f}s, peak mem: {:.3f} GB\".format(\n                time.time() - start, get_peak_mem()\n            )\n        )\n        assert num_trainers_per_machine >= 1\n        if num_trainers_per_machine > 1:\n            # First partition the whole graph to each trainer and save the trainer ids in\n            # the node feature \"trainer_id\".\n            start = time.time()\n            node_parts = metis_partition_assignment(\n                sim_g,\n                num_parts * num_trainers_per_machine,\n                balance_ntypes=balance_ntypes,\n                balance_edges=balance_edges,\n                mode=\"k-way\",\n            )\n            _set_trainer_ids(g, sim_g, node_parts)\n            print(\n                \"Assigning nodes to METIS partitions takes {:.3f}s, peak mem: {:.3f} GB\".format(\n                    time.time() - start, get_peak_mem()\n                )\n            )\n\n        node_parts = F.zeros((sim_g.num_nodes(),), F.int64, F.cpu())\n        parts = {0: sim_g.clone()}\n        orig_nids = parts[0].ndata[NID] = F.arange(0, sim_g.num_nodes())\n        orig_eids = parts[0].edata[EID] = F.arange(0, sim_g.num_edges())\n        # For one partition, we don't really shuffle nodes and edges. 
We just need to simulate\n        # it and set node data and edge data of orig_id.\n        parts[0].ndata[\"orig_id\"] = orig_nids\n        parts[0].edata[\"orig_id\"] = orig_eids\n        if return_mapping:\n            if g.is_homogeneous:\n                orig_nids = F.arange(0, sim_g.num_nodes())\n                orig_eids = F.arange(0, sim_g.num_edges())\n            else:\n                orig_nids = {\n                    ntype: F.arange(0, g.num_nodes(ntype)) for ntype in g.ntypes\n                }\n                orig_eids = {\n                    etype: F.arange(0, g.num_edges(etype))\n                    for etype in g.canonical_etypes\n                }\n        parts[0].ndata[\"inner_node\"] = F.ones(\n            (sim_g.num_nodes(),),\n            RESERVED_FIELD_DTYPE[\"inner_node\"],\n            F.cpu(),\n        )\n        parts[0].edata[\"inner_edge\"] = F.ones(\n            (sim_g.num_edges(),),\n            RESERVED_FIELD_DTYPE[\"inner_edge\"],\n            F.cpu(),\n        )\n    elif part_method in (\"metis\", \"random\"):\n        start = time.time()\n        sim_g, balance_ntypes = get_homogeneous(g, balance_ntypes)\n        print(\n            \"Converting to homogeneous graph takes {:.3f}s, peak mem: {:.3f} GB\".format(\n                time.time() - start, get_peak_mem()\n            )\n        )\n        if part_method == \"metis\":\n            assert num_trainers_per_machine >= 1\n            start = time.time()\n            if num_trainers_per_machine > 1:\n                # First partition the whole graph to each trainer and save the trainer ids in\n                # the node feature \"trainer_id\".\n                node_parts = metis_partition_assignment(\n                    sim_g,\n                    num_parts * num_trainers_per_machine,\n                    balance_ntypes=balance_ntypes,\n                    balance_edges=balance_edges,\n                    mode=\"k-way\",\n                    objtype=objtype,\n              
  )\n                _set_trainer_ids(g, sim_g, node_parts)\n\n                # And then coalesce the partitions of trainers on the same machine into one\n                # larger partition.\n                node_parts = F.floor_div(node_parts, num_trainers_per_machine)\n            else:\n                node_parts = metis_partition_assignment(\n                    sim_g,\n                    num_parts,\n                    balance_ntypes=balance_ntypes,\n                    balance_edges=balance_edges,\n                    objtype=objtype,\n                )\n            print(\n                \"Assigning nodes to METIS partitions takes {:.3f}s, peak mem: {:.3f} GB\".format(\n                    time.time() - start, get_peak_mem()\n                )\n            )\n        else:\n            node_parts = random_choice(num_parts, sim_g.num_nodes())\n        start = time.time()\n        parts, orig_nids, orig_eids = partition_graph_with_halo(\n            sim_g, node_parts, num_hops, reshuffle=True\n        )\n        print(\n            \"Splitting the graph into partitions takes {:.3f}s, peak mem: {:.3f} GB\".format(\n                time.time() - start, get_peak_mem()\n            )\n        )\n        if return_mapping:\n            orig_nids, orig_eids = _get_orig_ids(g, sim_g, orig_nids, orig_eids)\n    else:\n        raise Exception(\"Unknown partitioning method: \" + part_method)\n\n    # If the input is a heterogeneous graph, get the original node types and original node IDs.\n    # `part' has three types of node data at this point.\n    # NTYPE: the node type.\n    # orig_id: the global node IDs in the homogeneous version of input graph.\n    # NID: the global node IDs in the reshuffled homogeneous version of the input graph.\n    if not g.is_homogeneous:\n        for name in parts:\n            orig_ids = parts[name].ndata[\"orig_id\"]\n            ntype = F.gather_row(sim_g.ndata[NTYPE], orig_ids)\n            parts[name].ndata[NTYPE] = F.astype(\n    
            ntype, RESERVED_FIELD_DTYPE[NTYPE]\n            )\n            assert np.all(\n                F.asnumpy(ntype) == F.asnumpy(parts[name].ndata[NTYPE])\n            )\n            # Get the original edge types and original edge IDs.\n            orig_ids = parts[name].edata[\"orig_id\"]\n            etype = F.gather_row(sim_g.edata[ETYPE], orig_ids)\n            parts[name].edata[ETYPE] = F.astype(\n                etype, RESERVED_FIELD_DTYPE[ETYPE]\n            )\n            assert np.all(\n                F.asnumpy(etype) == F.asnumpy(parts[name].edata[ETYPE])\n            )\n\n            # Calculate the global node IDs to per-node IDs mapping.\n            inner_ntype = F.boolean_mask(\n                parts[name].ndata[NTYPE], parts[name].ndata[\"inner_node\"] == 1\n            )\n            inner_nids = F.boolean_mask(\n                parts[name].ndata[NID], parts[name].ndata[\"inner_node\"] == 1\n            )\n            for ntype in g.ntypes:\n                inner_ntype_mask = inner_ntype == g.get_ntype_id(ntype)\n                if F.sum(F.astype(inner_ntype_mask, F.int64), 0) == 0:\n                    # Skip if there is no node of this type in this partition.\n                    continue\n                typed_nids = F.boolean_mask(inner_nids, inner_ntype_mask)\n                # inner node IDs are in a contiguous ID range.\n                expected_range = np.arange(\n                    int(F.as_scalar(typed_nids[0])),\n                    int(F.as_scalar(typed_nids[-1])) + 1,\n                )\n                assert np.all(F.asnumpy(typed_nids) == expected_range)\n            # Calculate the global edge IDs to per-edge IDs mapping.\n            inner_etype = F.boolean_mask(\n                parts[name].edata[ETYPE], parts[name].edata[\"inner_edge\"] == 1\n            )\n            inner_eids = F.boolean_mask(\n                parts[name].edata[EID], parts[name].edata[\"inner_edge\"] == 1\n            )\n            for etype in 
g.canonical_etypes:\n                inner_etype_mask = inner_etype == g.get_etype_id(etype)\n                if F.sum(F.astype(inner_etype_mask, F.int64), 0) == 0:\n                    # Skip if there is no edge of this type in this partition.\n                    continue\n                typed_eids = np.sort(\n                    F.asnumpy(F.boolean_mask(inner_eids, inner_etype_mask))\n                )\n                assert np.all(\n                    typed_eids\n                    == np.arange(int(typed_eids[0]), int(typed_eids[-1]) + 1)\n                )\n\n    os.makedirs(out_path, mode=0o775, exist_ok=True)\n    tot_num_inner_edges = 0\n    out_path = os.path.abspath(out_path)\n\n    # With reshuffling, we can ensure that all nodes and edges are reshuffled\n    # and are in contiguous ID space.\n    if num_parts > 1:\n        node_map_val = {}\n        edge_map_val = {}\n        for ntype in g.ntypes:\n            ntype_id = g.get_ntype_id(ntype)\n            val = []\n            node_map_val[ntype] = []\n            for i in parts:\n                inner_node_mask = _get_inner_node_mask(parts[i], ntype_id)\n                val.append(\n                    F.as_scalar(F.sum(F.astype(inner_node_mask, F.int64), 0))\n                )\n                if F.sum(F.astype(inner_node_mask, F.int64), 0) == 0:\n                    node_map_val[ntype].append([-1, -1])\n                    continue\n                inner_nids = F.boolean_mask(\n                    parts[i].ndata[NID], inner_node_mask\n                )\n                node_map_val[ntype].append(\n                    [\n                        int(F.as_scalar(inner_nids[0])),\n                        int(F.as_scalar(inner_nids[-1])) + 1,\n                    ]\n                )\n            val = np.cumsum(val).tolist()\n            assert val[-1] == g.num_nodes(ntype)\n        for etype in g.canonical_etypes:\n            etype_id = g.get_etype_id(etype)\n            val = []\n            
edge_map_val[etype] = []\n            for i in parts:\n                inner_edge_mask = _get_inner_edge_mask(parts[i], etype_id)\n                val.append(\n                    F.as_scalar(F.sum(F.astype(inner_edge_mask, F.int64), 0))\n                )\n                if F.sum(F.astype(inner_edge_mask, F.int64), 0) == 0:\n                    edge_map_val[etype].append([-1, -1])\n                    continue\n                inner_eids = np.sort(\n                    F.asnumpy(\n                        F.boolean_mask(parts[i].edata[EID], inner_edge_mask)\n                    )\n                )\n                edge_map_val[etype].append(\n                    [int(inner_eids[0]), int(inner_eids[-1]) + 1]\n                )\n            val = np.cumsum(val).tolist()\n            assert val[-1] == g.num_edges(etype)\n        # Update the node_map_val and edge_map_val to be contiguous.\n        _update_node_edge_map(node_map_val, edge_map_val, g, num_parts)\n    else:\n        node_map_val = {}\n        edge_map_val = {}\n        for ntype in g.ntypes:\n            ntype_id = g.get_ntype_id(ntype)\n            inner_node_mask = _get_inner_node_mask(parts[0], ntype_id)\n            inner_nids = F.boolean_mask(parts[0].ndata[NID], inner_node_mask)\n            node_map_val[ntype] = [\n                [\n                    int(F.as_scalar(inner_nids[0])),\n                    int(F.as_scalar(inner_nids[-1])) + 1,\n                ]\n            ]\n        for etype in g.canonical_etypes:\n            etype_id = g.get_etype_id(etype)\n            inner_edge_mask = _get_inner_edge_mask(parts[0], etype_id)\n            inner_eids = F.boolean_mask(parts[0].edata[EID], inner_edge_mask)\n            edge_map_val[etype] = [\n                [\n                    int(F.as_scalar(inner_eids[0])),\n                    int(F.as_scalar(inner_eids[-1])) + 1,\n                ]\n            ]\n\n        # Double check that the node IDs in the global ID space are sorted.\n       
 for ntype in node_map_val:\n            val = np.concatenate([np.array(l) for l in node_map_val[ntype]])\n            assert np.all(val[:-1] <= val[1:])\n        for etype in edge_map_val:\n            val = np.concatenate([np.array(l) for l in edge_map_val[etype]])\n            assert np.all(val[:-1] <= val[1:])\n\n    start = time.time()\n    ntypes = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}\n    etypes = {etype: g.get_etype_id(etype) for etype in g.canonical_etypes}\n    part_metadata = {\n        \"graph_name\": graph_name,\n        \"num_nodes\": g.num_nodes(),\n        \"num_edges\": g.num_edges(),\n        \"part_method\": part_method,\n        \"num_parts\": num_parts,\n        \"halo_hops\": num_hops,\n        \"node_map\": node_map_val,\n        \"edge_map\": edge_map_val,\n        \"ntypes\": ntypes,\n        \"etypes\": etypes,\n    }\n    part_config = os.path.join(out_path, graph_name + \".json\")\n    for part_id in range(num_parts):\n        part = parts[part_id]\n\n        # Get the node/edge features of each partition.\n        node_feats = {}\n        edge_feats = {}\n        if num_parts > 1:\n            for ntype in g.ntypes:\n                ntype_id = g.get_ntype_id(ntype)\n                # To get the edges in the input graph, we should use original node IDs.\n                # Both orig_id and NID stores the per-node-type IDs.\n                ndata_name = \"orig_id\"\n                inner_node_mask = _get_inner_node_mask(part, ntype_id)\n                # This is global node IDs.\n                local_nodes = F.boolean_mask(\n                    part.ndata[ndata_name], inner_node_mask\n                )\n                if len(g.ntypes) > 1:\n                    # If the input is a heterogeneous graph.\n                    local_nodes = F.gather_row(sim_g.ndata[NID], local_nodes)\n                    print(\n                        \"part {} has {} nodes of type {} and {} are inside the partition\".format(\n                 
           part_id,\n                            F.as_scalar(\n                                F.sum(part.ndata[NTYPE] == ntype_id, 0)\n                            ),\n                            ntype,\n                            len(local_nodes),\n                        )\n                    )\n                else:\n                    print(\n                        \"part {} has {} nodes and {} are inside the partition\".format(\n                            part_id, part.num_nodes(), len(local_nodes)\n                        )\n                    )\n\n                for name in g.nodes[ntype].data:\n                    if name in [NID, \"inner_node\"]:\n                        continue\n                    node_feats[ntype + \"/\" + name] = F.gather_row(\n                        g.nodes[ntype].data[name], local_nodes\n                    )\n\n            for etype in g.canonical_etypes:\n                etype_id = g.get_etype_id(etype)\n                edata_name = \"orig_id\"\n                inner_edge_mask = _get_inner_edge_mask(part, etype_id)\n                # This is global edge IDs.\n                local_edges = F.boolean_mask(\n                    part.edata[edata_name], inner_edge_mask\n                )\n                if not g.is_homogeneous:\n                    local_edges = F.gather_row(sim_g.edata[EID], local_edges)\n                    print(\n                        \"part {} has {} edges of type {} and {} are inside the partition\".format(\n                            part_id,\n                            F.as_scalar(\n                                F.sum(part.edata[ETYPE] == etype_id, 0)\n                            ),\n                            etype,\n                            len(local_edges),\n                        )\n                    )\n                else:\n                    print(\n                        \"part {} has {} edges and {} are inside the partition\".format(\n                            part_id, 
part.num_edges(), len(local_edges)\n                        )\n                    )\n                tot_num_inner_edges += len(local_edges)\n\n                for name in g.edges[etype].data:\n                    if name in [EID, \"inner_edge\"]:\n                        continue\n                    edge_feats[\n                        _etype_tuple_to_str(etype) + \"/\" + name\n                    ] = F.gather_row(g.edges[etype].data[name], local_edges)\n        else:\n            for ntype in g.ntypes:\n                if len(g.ntypes) > 1:\n                    ndata_name = \"orig_id\"\n                    ntype_id = g.get_ntype_id(ntype)\n                    inner_node_mask = _get_inner_node_mask(part, ntype_id)\n                    # This is global node IDs.\n                    local_nodes = F.boolean_mask(\n                        part.ndata[ndata_name], inner_node_mask\n                    )\n                    local_nodes = F.gather_row(sim_g.ndata[NID], local_nodes)\n                else:\n                    local_nodes = sim_g.ndata[NID]\n                for name in g.nodes[ntype].data:\n                    if name in [NID, \"inner_node\"]:\n                        continue\n                    node_feats[ntype + \"/\" + name] = F.gather_row(\n                        g.nodes[ntype].data[name], local_nodes\n                    )\n            for etype in g.canonical_etypes:\n                if not g.is_homogeneous:\n                    edata_name = \"orig_id\"\n                    etype_id = g.get_etype_id(etype)\n                    inner_edge_mask = _get_inner_edge_mask(part, etype_id)\n                    # This is global edge IDs.\n                    local_edges = F.boolean_mask(\n                        part.edata[edata_name], inner_edge_mask\n                    )\n                    local_edges = F.gather_row(sim_g.edata[EID], local_edges)\n                else:\n                    local_edges = sim_g.edata[EID]\n                for name in 
g.edges[etype].data:\n                    if name in [EID, \"inner_edge\"]:\n                        continue\n                    edge_feats[\n                        _etype_tuple_to_str(etype) + \"/\" + name\n                    ] = F.gather_row(g.edges[etype].data[name], local_edges)\n        # delete `orig_id` from ndata/edata\n        del part.ndata[\"orig_id\"]\n        del part.edata[\"orig_id\"]\n\n        part_dir = os.path.join(out_path, \"part\" + str(part_id))\n        node_feat_file = os.path.join(part_dir, \"node_feat.dgl\")\n        edge_feat_file = os.path.join(part_dir, \"edge_feat.dgl\")\n\n        os.makedirs(part_dir, mode=0o775, exist_ok=True)\n        save_tensors(node_feat_file, node_feats)\n        save_tensors(edge_feat_file, edge_feats)\n\n        part_metadata[\"part-{}\".format(part_id)] = {\n            \"node_feats\": os.path.relpath(node_feat_file, out_path),\n            \"edge_feats\": os.path.relpath(edge_feat_file, out_path),\n        }\n        sort_etypes = len(g.etypes) > 1\n        part = process_partitions(part, graph_formats, sort_etypes)\n\n    # transmit to graphbolt and save graph\n    if use_graphbolt:\n        # save FusedCSCSamplingGraph\n        kwargs[\"graph_formats\"] = graph_formats\n        n_jobs = kwargs.pop(\"n_jobs\", 1)\n        mp_ctx = mp.get_context(\"spawn\")\n        with concurrent.futures.ProcessPoolExecutor(  # pylint: disable=unexpected-keyword-arg\n            max_workers=min(num_parts, n_jobs),\n            mp_context=mp_ctx,\n        ) as executor:\n            for part_id in range(num_parts):\n                executor.submit(\n                    _partition_to_graphbolt(\n                        part_i=part_id,\n                        part_config=part_config,\n                        part_metadata=part_metadata,\n                        parts=parts,\n                        **kwargs,\n                    )\n                )\n        part_metadata[\"node_map_dtype\"] = \"int64\"\n        
part_metadata[\"edge_map_dtype\"] = \"int64\"\n    else:\n        for part_id, part in parts.items():\n            part_dir = os.path.join(out_path, \"part\" + str(part_id))\n            part_graph_file = os.path.join(part_dir, \"graph.dgl\")\n            part_metadata[\"part-{}\".format(part_id)][\n                \"part_graph\"\n            ] = os.path.relpath(part_graph_file, out_path)\n            # save DGLGraph\n            _save_dgl_graphs(\n                part_graph_file,\n                [part],\n                formats=graph_formats,\n            )\n\n    _dump_part_config(part_config, part_metadata)\n\n    num_cuts = sim_g.num_edges() - tot_num_inner_edges\n    if num_parts == 1:\n        num_cuts = 0\n    print(\n        \"There are {} edges in the graph and {} edge cuts for {} partitions.\".format(\n            g.num_edges(), num_cuts, num_parts\n        )\n    )\n\n    print(\n        \"Save partitions: {:.3f} seconds, peak memory: {:.3f} GB\".format(\n            time.time() - start, get_peak_mem()\n        )\n    )\n\n    if return_mapping:\n        return orig_nids, orig_eids\n\n\n# [TODO][Rui] Due to int64_t is expected in RPC, we have to limit the data type\n# of node/edge IDs to int64_t. 
See more details in #7175.\nDTYPES_TO_CHECK = {\n    \"default\": [torch.int32, torch.int64],\n    NID: [torch.int64],\n    EID: [torch.int64],\n    NTYPE: [torch.int8, torch.int16, torch.int32, torch.int64],\n    ETYPE: [torch.int8, torch.int16, torch.int32, torch.int64],\n    \"inner_node\": [torch.uint8],\n    \"inner_edge\": [torch.uint8],\n    \"part_id\": [torch.int8, torch.int16, torch.int32, torch.int64],\n}\n\n\ndef _cast_to_minimum_dtype(predicate, data, field=None):\n    if data is None:\n        return data\n    dtypes_to_check = DTYPES_TO_CHECK.get(field, DTYPES_TO_CHECK[\"default\"])\n    if data.dtype not in dtypes_to_check:\n        dgl_warning(\n            f\"Skipping as the data type of field {field} is {data.dtype}, \"\n            f\"while supported data types are {dtypes_to_check}.\"\n        )\n        return data\n    for dtype in dtypes_to_check:\n        if predicate < torch.iinfo(dtype).max:\n            return data.to(dtype)\n    return data\n\n\n# Utility functions.\ndef is_homogeneous(ntypes, etypes):\n    \"\"\"Checks if the provided ntypes and etypes form a homogeneous graph.\"\"\"\n    return len(ntypes) == 1 and len(etypes) == 1\n\n\ndef init_type_per_edge(graph, gpb):\n    \"\"\"Initialize edge ids for every edge type.\"\"\"\n    etype_ids = gpb.map_to_per_etype(graph.edata[EID])[0]\n    return etype_ids\n\n\ndef _load_part(part_config, part_id, parts=None):\n    \"\"\"load parts from variable or dist.\"\"\"\n    if parts is None:\n        graph, _, _, _, _, _, _ = load_partition(\n            part_config, part_id, load_feats=False\n        )\n    else:\n        graph = parts[part_id]\n    return graph\n\n\ndef _save_graph_gb(part_config, part_id, csc_graph):\n    csc_graph_save_dir = os.path.join(\n        os.path.dirname(part_config),\n        f\"part{part_id}\",\n    )\n    csc_graph_path = os.path.join(\n        csc_graph_save_dir, \"fused_csc_sampling_graph.pt\"\n    )\n    torch.save(csc_graph, csc_graph_path)\n\n    return 
os.path.relpath(csc_graph_path, os.path.dirname(part_config))\n\n\ndef cast_various_to_minimum_dtype_gb(\n    num_parts,\n    indptr,\n    indices,\n    type_per_edge,\n    etypes,\n    ntypes,\n    node_attributes,\n    edge_attributes,\n    part_meta=None,\n    graph=None,\n    edge_count=None,\n    node_count=None,\n    tot_edge_count=None,\n    tot_node_count=None,\n):\n    \"\"\"Cast various data to minimum dtype.\"\"\"\n    if graph is not None:\n        assert part_meta is not None\n        tot_edge_count = graph.num_edges()\n        tot_node_count = graph.num_nodes()\n        node_count = part_meta[\"num_nodes\"]\n        edge_count = part_meta[\"num_edges\"]\n    else:\n        assert tot_edge_count is not None\n        assert tot_node_count is not None\n        assert edge_count is not None\n        assert node_count is not None\n\n    # Cast 1: indptr.\n    indptr = _cast_to_minimum_dtype(tot_edge_count, indptr)\n    # Cast 2: indices.\n    indices = _cast_to_minimum_dtype(tot_node_count, indices)\n    # Cast 3: type_per_edge.\n    type_per_edge = _cast_to_minimum_dtype(\n        len(etypes), type_per_edge, field=ETYPE\n    )\n    # Cast 4: node/edge_attributes.\n    predicates = {\n        NID: node_count,\n        \"part_id\": num_parts,\n        NTYPE: len(ntypes),\n        EID: edge_count,\n        ETYPE: len(etypes),\n        DGL2GB_EID: edge_count,\n        GB_DST_ID: node_count,\n    }\n    for attributes in [node_attributes, edge_attributes]:\n        for key in attributes:\n            if key not in predicates:\n                continue\n            attributes[key] = _cast_to_minimum_dtype(\n                predicates[key], attributes[key], field=key\n            )\n    return indptr, indices, type_per_edge\n\n\ndef _create_attributes_gb(\n    graph,\n    gpb,\n    edge_ids,\n    is_homo,\n    store_inner_node,\n    store_inner_edge,\n    store_eids,\n    debug_mode,\n):\n    # Save node attributes. 
Detailed attributes are shown below.\n    #  DGL_GB\\Attributes  dgl.NID(\"_ID\")  dgl.NTYPE(\"_TYPE\")  \"inner_node\"  \"part_id\"\n    #  DGL_Homograph           ✅                🚫                  ✅          ✅\n    #  GB_Homograph            ✅                🚫               optional       🚫\n    #  DGL_Heterograph         ✅                ✅                  ✅          ✅\n    #  GB_Heterograph          ✅                🚫               optional       🚫\n    required_node_attrs = [NID]\n    if store_inner_node:\n        required_node_attrs.append(\"inner_node\")\n    if debug_mode:\n        required_node_attrs = list(graph.ndata.keys())\n    node_attributes = {attr: graph.ndata[attr] for attr in required_node_attrs}\n\n    # Save edge attributes. Detailed attributes are shown below.\n    #  DGL_GB\\Attributes  dgl.EID(\"_ID\")  dgl.ETYPE(\"_TYPE\")  \"inner_edge\"\n    #  DGL_Homograph           ✅               🚫                  ✅\n    #  GB_Homograph         optional            🚫               optional\n    #  DGL_Heterograph         ✅               ✅                  ✅\n    #  GB_Heterograph       optional            ✅               optional\n    type_per_edge = None\n    if not is_homo:\n        type_per_edge = init_type_per_edge(graph, gpb)[edge_ids]\n        type_per_edge = type_per_edge.to(RESERVED_FIELD_DTYPE[ETYPE])\n    required_edge_attrs = []\n    if store_eids:\n        required_edge_attrs.append(EID)\n    if store_inner_edge:\n        required_edge_attrs.append(\"inner_edge\")\n    if debug_mode:\n        required_edge_attrs = list(graph.edata.keys())\n    edge_attributes = {\n        attr: graph.edata[attr][edge_ids] for attr in required_edge_attrs\n    }\n    return node_attributes, edge_attributes, type_per_edge\n\n\ndef _convert_dgl_partition_to_gb(\n    ntypes,\n    etypes,\n    gpb,\n    part_meta,\n    graph,\n    graph_formats=None,\n    store_eids=False,\n    store_inner_node=False,\n    store_inner_edge=False,\n):\n    \"\"\"Converts a 
single DGL partition to GraphBolt.\n\n    Parameters\n    ----------\n    node types : dict\n        The node types\n    edge types : dict\n        The edge types\n    gpb : GraphPartitionBook\n        The global partition information.\n    part_meta : dict\n        Contain the meta data of the partition.\n    graph : DGLGraph\n        The graph to be converted to graphbolt graph.\n    graph_formats : str or list[str], optional\n        Save partitions in specified formats. It could be any combination of\n        `coo`, `csc`. As `csc` format is mandatory for `FusedCSCSamplingGraph`,\n        it is not necessary to specify this argument. It's mainly for\n        specifying `coo` format to save edge ID mapping and destination node\n        IDs. If not specified, whether to save `coo` format is determined by\n        the availability of the format in DGL partitions. Default: None.\n    store_eids : bool, optional\n        Whether to store edge IDs in the new graph. Default: True.\n    store_inner_node : bool, optional\n        Whether to store inner node mask in the new graph. Default: False.\n    store_inner_edge : bool, optional\n        Whether to store inner edge mask in the new graph. 
Default: False.\n    \"\"\"\n    debug_mode = \"DGL_DIST_DEBUG\" in os.environ\n    if debug_mode:\n        dgl_warning(\n            \"Running in debug mode which means all attributes of DGL partitions\"\n            \" will be saved to the new format.\"\n        )\n    num_parts = part_meta[\"num_parts\"]\n\n    is_homo = is_homogeneous(ntypes, etypes)\n    node_type_to_id = (\n        None if is_homo else {ntype: ntid for ntid, ntype in enumerate(ntypes)}\n    )\n    edge_type_to_id = (\n        None\n        if is_homo\n        else {\n            gb.etype_tuple_to_str(etype): etid for etype, etid in etypes.items()\n        }\n    )\n    # Obtain CSC indtpr and indices.\n    indptr, indices, edge_ids = graph.adj_tensors(\"csc\")\n\n    node_attributes, edge_attributes, type_per_edge = _create_attributes_gb(\n        graph,\n        gpb,\n        edge_ids,\n        is_homo,\n        store_inner_node,\n        store_inner_edge,\n        store_eids,\n        debug_mode,\n    )\n    # When converting DGLGraph to FusedCSCSamplingGraph, edge IDs are\n    # re-ordered(actually FusedCSCSamplingGraph does not have edge IDs\n    # in nature). So we need to save such re-order info for any\n    # operations that uses original local edge IDs. 
For now, this is\n    # required by `DistGraph.find_edges()` for link prediction tasks.\n    #\n    # What's more, in order to find the dst nodes efficiently, we save\n    # dst nodes directly in the edge attributes.\n    #\n    # So we require additional `(2 * E) * dtype` space in total.\n    if graph_formats is not None and isinstance(graph_formats, str):\n        graph_formats = [graph_formats]\n    save_coo = (\n        graph_formats is None and \"coo\" in graph.formats()[\"created\"]\n    ) or (graph_formats is not None and \"coo\" in graph_formats)\n    if save_coo:\n        edge_attributes[DGL2GB_EID] = torch.argsort(edge_ids)\n        edge_attributes[GB_DST_ID] = gb.expand_indptr(\n            indptr, dtype=indices.dtype\n        )\n\n    indptr, indices, type_per_edge = cast_various_to_minimum_dtype_gb(\n        graph=graph,\n        part_meta=part_meta,\n        num_parts=num_parts,\n        indptr=indptr,\n        indices=indices,\n        type_per_edge=type_per_edge,\n        etypes=etypes,\n        ntypes=ntypes,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n    )\n\n    csc_graph = gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=None,\n        type_per_edge=type_per_edge,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n        node_type_to_id=node_type_to_id,\n        edge_type_to_id=edge_type_to_id,\n    )\n    return csc_graph\n\n\ndef gb_convert_single_dgl_partition(\n    part_id,\n    graph_formats,\n    part_config,\n    store_eids=True,\n    store_inner_node=True,\n    store_inner_edge=True,\n):\n    \"\"\"\n    The pipeline converting signle partition to graphbolt.\n\n    Parameters\n    ----------\n    part_id : int\n        The partition ID.\n    graph_formats : str or list[str]\n        Save partitions in specified formats. It could be any combination of\n        `coo`, `csc`. 
As `csc` format is mandatory for `FusedCSCSamplingGraph`,\n        it is not necessary to specify this argument. It's mainly for\n        specifying `coo` format to save edge ID mapping and destination node\n        IDs. If not specified, whether to save `coo` format is determined by\n        the availability of the format in DGL partitions. Default: None.\n    part_config : str\n        The path of the partition config file.\n    store_eids : bool, optional\n        Whether to store edge IDs in the new graph. Default: True.\n    store_inner_node : bool, optional\n        Whether to store inner node mask in the new graph. Default: False.\n    store_inner_edge : bool, optional\n        Whether to store inner edge mask in the new graph. Default: False.\n\n    Returns\n    -------\n    str\n        The path csc_graph to save.\n    \"\"\"\n    gpb, _, ntypes, etypes = load_partition_book(\n        part_config=part_config, part_id=part_id\n    )\n    part = _load_part(part_config, part_id)\n    part_meta = copy.deepcopy(_load_part_config(part_config))\n    csc_graph = _convert_dgl_partition_to_gb(\n        graph=part,\n        ntypes=ntypes,\n        etypes=etypes,\n        gpb=gpb,\n        part_meta=part_meta,\n        graph_formats=graph_formats,\n        store_eids=store_eids,\n        store_inner_node=store_inner_node,\n        store_inner_edge=store_inner_edge,\n    )\n    rel_path = _save_graph_gb(part_config, part_id, csc_graph)\n    return rel_path\n\n\ndef _convert_partition_to_graphbolt_wrapper(\n    graph_formats,\n    part_config,\n    store_eids,\n    store_inner_node,\n    store_inner_edge,\n    n_jobs,\n    num_parts,\n):\n    # [Rui] DGL partitions are always saved as homogeneous graphs even though\n    # the original graph is heterogeneous. 
But heterogeneous information like\n    # node/edge types are saved as node/edge data alongside with partitions.\n    # What needs more attention is that due to the existence of HALO nodes in\n    # each partition, the local node IDs are not sorted according to the node\n    # types. So we fail to assign ``node_type_offset`` as required by GraphBolt.\n    # But this is not a problem since such information is not used in sampling.\n    # We can simply pass None to it.\n\n    # Iterate over partitions.\n    convert_with_format = partial(\n        gb_convert_single_dgl_partition,\n        part_config=part_config,\n        graph_formats=graph_formats,\n        store_eids=store_eids,\n        store_inner_node=store_inner_node,\n        store_inner_edge=store_inner_edge,\n    )\n    # Need to create entirely new interpreters, because we call C++ downstream\n    # See https://docs.python.org/3.12/library/multiprocessing.html#contexts-and-start-methods\n    # and https://pybind11.readthedocs.io/en/stable/advanced/misc.html#global-interpreter-lock-gil\n    rel_path_results = []\n    if n_jobs > 1 and num_parts > 1:\n        mp_ctx = mp.get_context(\"spawn\")\n        with concurrent.futures.ProcessPoolExecutor(  # pylint: disable=unexpected-keyword-arg\n            max_workers=min(num_parts, n_jobs),\n            mp_context=mp_ctx,\n        ) as executor:\n            for part_id in range(num_parts):\n                rel_path_results.append(\n                    executor.submit(\n                        convert_with_format, part_id=part_id\n                    ).result()\n                )\n\n    else:\n        # If running single-threaded, avoid spawning new interpreter, which is slow\n        for part_id in range(num_parts):\n            rel_path = convert_with_format(part_id=part_id)\n            rel_path_results.append(rel_path)\n    part_meta = _load_part_config(part_config)\n    for part_id in range(num_parts):\n        # Update graph path.\n        
part_meta[f\"part-{part_id}\"][\"part_graph_graphbolt\"] = rel_path_results[\n            part_id\n        ]\n\n    # Save dtype info into partition config.\n    # [TODO][Rui] Always use int64_t for node/edge IDs in GraphBolt. See more\n    # details in #7175.\n    part_meta[\"node_map_dtype\"] = \"int64\"\n    part_meta[\"edge_map_dtype\"] = \"int64\"\n\n    return part_meta\n\n\ndef dgl_partition_to_graphbolt(\n    part_config,\n    *,\n    store_eids=True,\n    store_inner_node=False,\n    store_inner_edge=False,\n    graph_formats=None,\n    n_jobs=1,\n):\n    \"\"\"Convert partitions of dgl to FusedCSCSamplingGraph of GraphBolt.\n\n    This API converts `DGLGraph` partitions to `FusedCSCSamplingGraph` which is\n    dedicated for sampling in `GraphBolt`. New graphs will be stored alongside\n    original graph as `fused_csc_sampling_graph.pt`.\n\n    In the near future, partitions are supposed to be saved as\n    `FusedCSCSamplingGraph` directly. At that time, this API should be deprecated.\n\n    Parameters\n    ----------\n    part_config : str\n        The partition configuration JSON file.\n    store_eids : bool, optional\n        Whether to store edge IDs in the new graph. Default: True.\n    store_inner_node : bool, optional\n        Whether to store inner node mask in the new graph. Default: False.\n    store_inner_edge : bool, optional\n        Whether to store inner edge mask in the new graph. Default: False.\n    graph_formats : str or list[str], optional\n        Save partitions in specified formats. It could be any combination of\n        `coo`, `csc`. As `csc` format is mandatory for `FusedCSCSamplingGraph`,\n        it is not necessary to specify this argument. It's mainly for\n        specifying `coo` format to save edge ID mapping and destination node\n        IDs. If not specified, whether to save `coo` format is determined by\n        the availability of the format in DGL partitions. 
Default: None.\n    n_jobs: int\n        Number of parallel jobs to run during partition conversion. Max parallelism\n        is determined by the partition count.\n    \"\"\"\n    debug_mode = \"DGL_DIST_DEBUG\" in os.environ\n    if debug_mode:\n        dgl_warning(\n            \"Running in debug mode which means all attributes of DGL partitions\"\n            \" will be saved to the new format.\"\n        )\n    part_meta = _load_part_config(part_config)\n    num_parts = part_meta[\"num_parts\"]\n    part_meta = _convert_partition_to_graphbolt_wrapper(\n        graph_formats=graph_formats,\n        part_config=part_config,\n        store_eids=store_eids,\n        store_inner_node=store_inner_node,\n        store_inner_edge=store_inner_edge,\n        n_jobs=n_jobs,\n        num_parts=num_parts,\n    )\n    _dump_part_config(part_config, part_meta)\n"
  },
  {
    "path": "python/dgl/distributed/role.py",
    "content": "\"\"\"Manage the roles in different clients.\n\nRight now, the clients have different roles. Some clients work as samplers and\nsome work as trainers.\n\"\"\"\n\nimport os\n\nimport numpy as np\n\nfrom . import rpc\n\nREGISTER_ROLE = 700001\nREG_ROLE_MSG = \"Register_Role\"\n\n\nclass RegisterRoleResponse(rpc.Response):\n    \"\"\"Send a confirmation signal (just a short string message)\n    of RegisterRoleRequest to client.\n    \"\"\"\n\n    def __init__(self, msg):\n        self.msg = msg\n\n    def __getstate__(self):\n        return self.msg\n\n    def __setstate__(self, state):\n        self.msg = state\n\n\nclass RegisterRoleRequest(rpc.Request):\n    \"\"\"Send client id and role to server\n\n    Parameters\n    ----------\n    client_id : int\n        ID of client\n    role : str\n        role of client\n    \"\"\"\n\n    def __init__(self, client_id, machine_id, role):\n        self.client_id = client_id\n        self.machine_id = machine_id\n        self.role = role\n        self.group_id = rpc.get_group_id()\n\n    def __getstate__(self):\n        return self.client_id, self.machine_id, self.role, self.group_id\n\n    def __setstate__(self, state):\n        self.client_id, self.machine_id, self.role, self.group_id = state\n\n    def process_request(self, server_state):\n        kv_store = server_state.kv_store\n        role = server_state.roles.setdefault(self.group_id, {})\n        if self.role not in role:\n            role[self.role] = set()\n            if kv_store is not None:\n                barrier_count = kv_store.barrier_count.setdefault(\n                    self.group_id, {}\n                )\n                barrier_count[self.role] = 0\n        role[self.role].add((self.client_id, self.machine_id))\n        total_count = 0\n        for key in role:\n            total_count += len(role[key])\n        # Clients are blocked util all clients register their roles.\n        if total_count == rpc.get_num_client():\n            
res_list = []\n            for target_id in range(rpc.get_num_client()):\n                res_list.append((target_id, RegisterRoleResponse(REG_ROLE_MSG)))\n            return res_list\n        return None\n\n\nGET_ROLE = 700002\nGET_ROLE_MSG = \"Get_Role\"\n\n\nclass GetRoleResponse(rpc.Response):\n    \"\"\"Send the roles of all client processes\"\"\"\n\n    def __init__(self, role):\n        self.role = role\n        self.msg = GET_ROLE_MSG\n\n    def __getstate__(self):\n        return self.role, self.msg\n\n    def __setstate__(self, state):\n        self.role, self.msg = state\n\n\nclass GetRoleRequest(rpc.Request):\n    \"\"\"Send a request to get the roles of all client processes.\"\"\"\n\n    def __init__(self):\n        self.msg = GET_ROLE_MSG\n        self.group_id = rpc.get_group_id()\n\n    def __getstate__(self):\n        return self.msg, self.group_id\n\n    def __setstate__(self, state):\n        self.msg, self.group_id = state\n\n    def process_request(self, server_state):\n        return GetRoleResponse(server_state.roles[self.group_id])\n\n\n# The key is role, the value is a dict of mapping RPC rank to a rank within the role.\nPER_ROLE_RANK = {}\n\n# The global rank of a client process. The client processes of the same role have\n# global ranks that fall in a contiguous range.\nGLOBAL_RANK = {}\n\n# The role of the current process\nCUR_ROLE = None\n\nIS_STANDALONE = False\n\n\ndef init_role(role):\n    \"\"\"Initialize the role of the current process.\n\n    Each process is associated with a role so that we can determine what\n    function can be invoked in a process. For example, we do not allow some\n    functions in sampler processes.\n\n    The initialization includes registeration the role of the current process and\n    get the roles of all client processes. 
It also computes the rank of all client\n    processes in a deterministic way so that all clients will have the same rank for\n    the same client process.\n    \"\"\"\n    global CUR_ROLE\n    CUR_ROLE = role\n\n    global PER_ROLE_RANK\n    global GLOBAL_RANK\n    global IS_STANDALONE\n\n    if os.environ.get(\"DGL_DIST_MODE\", \"standalone\") == \"standalone\":\n        if role == \"default\":\n            GLOBAL_RANK[0] = 0\n            PER_ROLE_RANK[\"default\"] = {0: 0}\n        IS_STANDALONE = True\n        return\n\n    PER_ROLE_RANK = {}\n    GLOBAL_RANK = {}\n\n    # Register the current role. This blocks until all clients register themselves.\n    client_id = rpc.get_rank()\n    machine_id = rpc.get_machine_id()\n    request = RegisterRoleRequest(client_id, machine_id, role)\n    rpc.send_request(0, request)\n    response = rpc.recv_response()\n    assert response.msg == REG_ROLE_MSG\n\n    # Get all clients on all machines.\n    request = GetRoleRequest()\n    rpc.send_request(0, request)\n    response = rpc.recv_response()\n    assert response.msg == GET_ROLE_MSG\n\n    # Here we want to compute a new rank for each client.\n    # We compute the per-role rank as well as global rank.\n    # For per-role rank, we ensure that all ranks within a machine is contiguous.\n    # For global rank, we also ensure that all ranks within a machine are contiguous,\n    # and all ranks within a role are contiguous.\n    global_rank = 0\n\n    # We want to ensure that the global rank of the trainer process starts from 0.\n    role_names = [\"default\"]\n    for role_name in response.role:\n        if role_name not in role_names:\n            role_names.append(role_name)\n\n    for role_name in role_names:\n        # Let's collect the ranks of this role in all machines.\n        machines = {}\n        for client_id, machine_id in response.role[role_name]:\n            if machine_id not in machines:\n                machines[machine_id] = []\n            
machines[machine_id].append(client_id)\n\n        num_machines = len(machines)\n        PER_ROLE_RANK[role_name] = {}\n        per_role_rank = 0\n        for i in range(num_machines):\n            clients = machines[i]\n            clients = np.sort(clients)\n            for client_id in clients:\n                GLOBAL_RANK[client_id] = global_rank\n                global_rank += 1\n                PER_ROLE_RANK[role_name][client_id] = per_role_rank\n                per_role_rank += 1\n\n\ndef get_global_rank():\n    \"\"\"Get the global rank\n\n    The rank can globally identify the client process. For the client processes\n    of the same role, their ranks are in a contiguous range.\n    \"\"\"\n    if IS_STANDALONE:\n        return 0\n    else:\n        return GLOBAL_RANK[rpc.get_rank()]\n\n\ndef get_rank(role):\n    \"\"\"Get the role-specific rank\"\"\"\n    if IS_STANDALONE:\n        return 0\n    else:\n        return PER_ROLE_RANK[role][rpc.get_rank()]\n\n\ndef get_trainer_rank():\n    \"\"\"Get the rank of the current trainer process.\n\n    This function can only be called in the trainer process. It will result in\n    an error if it's called in the process of other roles.\n    \"\"\"\n    assert CUR_ROLE == \"default\"\n    if IS_STANDALONE:\n        return 0\n    else:\n        return PER_ROLE_RANK[\"default\"][rpc.get_rank()]\n\n\ndef get_role():\n    \"\"\"Get the role of the current process\"\"\"\n    return CUR_ROLE\n\n\ndef get_num_trainers():\n    \"\"\"Get the number of trainer processes\"\"\"\n    return len(PER_ROLE_RANK[\"default\"])\n\n\nrpc.register_service(REGISTER_ROLE, RegisterRoleRequest, RegisterRoleResponse)\nrpc.register_service(GET_ROLE, GetRoleRequest, GetRoleResponse)\n"
  },
  {
    "path": "python/dgl/distributed/rpc.py",
    "content": "\"\"\"RPC components. They are typically functions or utilities used by both\nserver and clients.\"\"\"\nimport abc\nimport os\nimport pickle\nimport random\n\nimport numpy as np\n\nfrom .. import backend as F\nfrom .._ffi.function import _init_api\nfrom .._ffi.object import ObjectBase, register_object\nfrom ..base import DGLError\nfrom .constants import SERVER_EXIT\n\n__all__ = [\n    \"set_rank\",\n    \"get_rank\",\n    \"Request\",\n    \"Response\",\n    \"register_service\",\n    \"create_sender\",\n    \"create_receiver\",\n    \"finalize_sender\",\n    \"finalize_receiver\",\n    \"wait_for_senders\",\n    \"connect_receiver\",\n    \"read_ip_config\",\n    \"get_group_id\",\n    \"get_num_machines\",\n    \"set_num_machines\",\n    \"get_machine_id\",\n    \"set_machine_id\",\n    \"send_request\",\n    \"recv_request\",\n    \"send_response\",\n    \"recv_response\",\n    \"remote_call\",\n    \"send_request_to_machine\",\n    \"remote_call_to_machine\",\n    \"fast_pull\",\n    \"DistConnectError\",\n    \"get_num_client\",\n    \"set_num_client\",\n    \"client_barrier\",\n    \"copy_data_to_shared_memory\",\n]\n\nREQUEST_CLASS_TO_SERVICE_ID = {}\nRESPONSE_CLASS_TO_SERVICE_ID = {}\nSERVICE_ID_TO_PROPERTY = {}\n\nDEFUALT_PORT = 30050\n\n\ndef read_ip_config(filename, num_servers):\n    \"\"\"Read network configuration information of server from file.\n\n    For exampple, the following TXT shows a 4-machine configuration:\n\n        172.31.40.143\n        172.31.36.140\n        172.31.47.147\n        172.31.30.180\n\n    Users can also set user-specified port for this network configuration. For example:\n\n        172.31.40.143 20090\n        172.31.36.140 20090\n        172.31.47.147 20090\n        172.31.30.180 20090\n\n    Note that, DGL supports multiple backup servers that shares data with each others\n    on the same machine via shared-memory tensor. The num_servers should be >= 1. 
For example,\n    if we set num_servers to 5, it means that we have 1 main server and 4 backup servers on\n    current machine.\n\n    Parameters\n    ----------\n    filename : str\n        Path of IP configuration file.\n\n    num_servers : int\n        Server count on each machine.\n\n    Returns\n    -------\n    dict\n        server namebook.\n        The key is server_id (int)\n        The value is [machine_id, ip, port, num_servers] ([int, str, int, int])\n\n        e.g.,\n\n          {0:[0, '172.31.40.143', 30050, 2],\n           1:[0, '172.31.40.143', 30051, 2],\n           2:[1, '172.31.36.140', 30050, 2],\n           3:[1, '172.31.36.140', 30051, 2],\n           4:[2, '172.31.47.147', 30050, 2],\n           5:[2, '172.31.47.147', 30051, 2],\n           6:[3, '172.31.30.180', 30050, 2],\n           7:[3, '172.31.30.180', 30051, 2]}\n    \"\"\"\n    assert len(filename) > 0, \"filename cannot be empty.\"\n    assert num_servers > 0, (\n        \"num_servers (%d) must be a positive number.\" % num_servers\n    )\n    server_namebook = {}\n    try:\n        server_id = 0\n        machine_id = 0\n        lines = [line.rstrip(\"\\n\") for line in open(filename)]\n        for line in lines:\n            result = line.split()\n            if len(result) == 2:\n                port = int(result[1])\n            elif len(result) == 1:\n                port = DEFUALT_PORT\n            else:\n                raise RuntimeError(\"length of result can only be 1 or 2.\")\n            ip_addr = result[0]\n            for s_count in range(num_servers):\n                server_namebook[server_id] = [\n                    machine_id,\n                    ip_addr,\n                    port + s_count,\n                    num_servers,\n                ]\n                server_id += 1\n            machine_id += 1\n    except RuntimeError:\n        print(\"Error: data format on each line should be: [ip] [port]\")\n    return server_namebook\n\n\ndef reset():\n    \"\"\"Reset 
the rpc context\"\"\"\n    _CAPI_DGLRPCReset()\n\n\ndef create_sender(max_queue_size):\n    \"\"\"Create rpc sender of this process.\n\n    Parameters\n    ----------\n    max_queue_size : int\n        Maximal size (bytes) of network queue buffer.\n    \"\"\"\n    max_thread_count = int(os.getenv(\"DGL_SOCKET_MAX_THREAD_COUNT\", \"0\"))\n    _CAPI_DGLRPCCreateSender(int(max_queue_size), max_thread_count)\n\n\ndef create_receiver(max_queue_size):\n    \"\"\"Create rpc receiver of this process.\n\n    Parameters\n    ----------\n    max_queue_size : int\n        Maximal size (bytes) of network queue buffer.\n    \"\"\"\n    max_thread_count = int(os.getenv(\"DGL_SOCKET_MAX_THREAD_COUNT\", \"0\"))\n    _CAPI_DGLRPCCreateReceiver(int(max_queue_size), max_thread_count)\n\n\ndef finalize_sender():\n    \"\"\"Finalize rpc sender of this process.\"\"\"\n    _CAPI_DGLRPCFinalizeSender()\n\n\ndef finalize_receiver():\n    \"\"\"Finalize rpc receiver of this process.\"\"\"\n    _CAPI_DGLRPCFinalizeReceiver()\n\n\ndef wait_for_senders(ip_addr, port, num_senders):\n    \"\"\"Wait all of the senders' connections.\n\n    This api will be blocked until all the senders connect to the receiver.\n\n    Parameters\n    ----------\n    ip_addr : str\n        receiver's IP address, e,g, '192.168.8.12'\n    port : int\n        receiver's port\n    num_senders : int\n        total number of senders\n    \"\"\"\n    _CAPI_DGLRPCWaitForSenders(ip_addr, int(port), int(num_senders))\n\n\ndef connect_receiver(ip_addr, port, recv_id, group_id=-1):\n    \"\"\"Connect to target receiver\n\n    Parameters\n    ----------\n    ip_addr : str\n        receiver's IP address, e,g, '192.168.8.12'\n    port : int\n        receiver's listening port\n    recv_id : int\n        receiver's ID\n    \"\"\"\n    target_id = (\n        recv_id if group_id == -1 else register_client(recv_id, group_id)\n    )\n    if target_id < 0:\n        raise DGLError(\"Invalid target id: {}\".format(target_id))\n    return 
_CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(target_id))\n\n\ndef connect_receiver_finalize(max_try_times):\n    \"\"\"Finalize the action to connect to receivers. Make sure that either all connections are\n    successfully established or connection fails.\n\n    When \"socket\" network backend is in use, the function issues actual requests to receiver\n    sockets to establish connections.\n\n    Parameters\n    ----------\n    max_try_times : int\n        maximum try times\n    \"\"\"\n    return _CAPI_DGLRPCConnectReceiverFinalize(max_try_times)\n\n\ndef set_rank(rank):\n    \"\"\"Set the rank of this process.\n\n    If the process is a client, this is equal to client ID. Otherwise, the process\n    is a server and this is equal to server ID.\n\n    Parameters\n    ----------\n    rank : int\n        Rank value\n    \"\"\"\n    _CAPI_DGLRPCSetRank(int(rank))\n\n\ndef get_rank():\n    \"\"\"Get the rank of this process.\n\n    If the process is a client, this is equal to client ID. 
Otherwise, the process\n    is a server and this is equal to server ID.\n\n    Returns\n    -------\n    int\n        Rank value\n    \"\"\"\n    return _CAPI_DGLRPCGetRank()\n\n\ndef set_machine_id(machine_id):\n    \"\"\"Set current machine ID\n\n    Parameters\n    ----------\n    machine_id : int\n        Current machine ID\n    \"\"\"\n    _CAPI_DGLRPCSetMachineID(int(machine_id))\n\n\ndef get_machine_id():\n    \"\"\"Get current machine ID\n\n    Returns\n    -------\n    int\n        machine ID\n    \"\"\"\n    return _CAPI_DGLRPCGetMachineID()\n\n\ndef set_num_machines(num_machines):\n    \"\"\"Set number of machine\n\n    Parameters\n    ----------\n    num_machines : int\n        Number of machine\n    \"\"\"\n    _CAPI_DGLRPCSetNumMachines(int(num_machines))\n\n\ndef get_num_machines():\n    \"\"\"Get number of machines\n\n    Returns\n    -------\n    int\n        number of machines\n    \"\"\"\n    return _CAPI_DGLRPCGetNumMachines()\n\n\ndef set_num_server(num_server):\n    \"\"\"Set the total number of server.\"\"\"\n    _CAPI_DGLRPCSetNumServer(int(num_server))\n\n\ndef get_num_server():\n    \"\"\"Get the total number of server.\"\"\"\n    return _CAPI_DGLRPCGetNumServer()\n\n\ndef set_num_client(num_client):\n    \"\"\"Set the total number of client.\"\"\"\n    _CAPI_DGLRPCSetNumClient(int(num_client))\n\n\ndef get_num_client():\n    \"\"\"Get the total number of client.\"\"\"\n    return _CAPI_DGLRPCGetNumClient()\n\n\ndef set_num_server_per_machine(num_server):\n    \"\"\"Set the total number of server per machine\"\"\"\n    _CAPI_DGLRPCSetNumServerPerMachine(num_server)\n\n\ndef get_num_server_per_machine():\n    \"\"\"Get the total number of server per machine\"\"\"\n    return _CAPI_DGLRPCGetNumServerPerMachine()\n\n\ndef incr_msg_seq():\n    \"\"\"Increment the message sequence number and return the old one.\n\n    Returns\n    -------\n    long\n        Message sequence number\n    \"\"\"\n    return _CAPI_DGLRPCIncrMsgSeq()\n\n\ndef 
get_msg_seq():\n    \"\"\"Get the current message sequence number.\n\n    Returns\n    -------\n    long\n        Message sequence number\n    \"\"\"\n    return _CAPI_DGLRPCGetMsgSeq()\n\n\ndef set_msg_seq(msg_seq):\n    \"\"\"Set the current message sequence number.\n\n    Parameters\n    ----------\n    msg_seq : int\n        sequence number of current rpc message.\n    \"\"\"\n    _CAPI_DGLRPCSetMsgSeq(int(msg_seq))\n\n\ndef register_service(service_id, req_cls, res_cls=None):\n    \"\"\"Register a service to RPC.\n\n    Parameter\n    ---------\n    service_id : int\n        Service ID.\n    req_cls : class\n        Request class.\n    res_cls : class, optional\n        Response class. If none, the service has no response.\n    \"\"\"\n    REQUEST_CLASS_TO_SERVICE_ID[req_cls] = service_id\n    if res_cls is not None:\n        RESPONSE_CLASS_TO_SERVICE_ID[res_cls] = service_id\n    SERVICE_ID_TO_PROPERTY[service_id] = (req_cls, res_cls)\n\n\ndef get_service_property(service_id):\n    \"\"\"Get service property.\n\n    Parameters\n    ----------\n    service_id : int\n        Service ID.\n\n    Returns\n    -------\n    (class, class)\n        (Request class, Response class)\n    \"\"\"\n    return SERVICE_ID_TO_PROPERTY[service_id]\n\n\nclass Request:\n    \"\"\"Base request class\"\"\"\n\n    @abc.abstractmethod\n    def __getstate__(self):\n        \"\"\"Get serializable states.\n\n        Must be inherited by subclasses. 
For array members, return them as\n        individual return values (i.e., do not put them in containers like\n        dictionary or list).\n        \"\"\"\n\n    @abc.abstractmethod\n    def __setstate__(self, state):\n        \"\"\"Construct the request object from serialized states.\n\n        Must be inherited by subclasses.\n        \"\"\"\n\n    @abc.abstractmethod\n    def process_request(self, server_state):\n        \"\"\"Server-side function to process the request.\n\n        Must be inherited by subclasses.\n\n        Parameters\n        ----------\n        server_state : ServerState\n            Server state data.\n\n        Returns\n        -------\n        Response\n            Response of this request or None if no response.\n        \"\"\"\n\n    @property\n    def service_id(self):\n        \"\"\"Get service ID.\"\"\"\n        cls = self.__class__\n        sid = REQUEST_CLASS_TO_SERVICE_ID.get(cls, None)\n        if sid is None:\n            raise DGLError(\n                \"Request class {} has not been registered as a service.\".format(\n                    cls\n                )\n            )\n        return sid\n\n\nclass Response:\n    \"\"\"Base response class\"\"\"\n\n    @abc.abstractmethod\n    def __getstate__(self):\n        \"\"\"Get serializable states.\n\n        Must be inherited by subclasses. 
For array members, return them as\n        individual return values (i.e., do not put them in containers like\n        dictionary or list).\n        \"\"\"\n\n    @abc.abstractmethod\n    def __setstate__(self, state):\n        \"\"\"Construct the response object from serialized states.\n\n        Must be inherited by subclasses.\n        \"\"\"\n\n    @property\n    def service_id(self):\n        \"\"\"Get service ID.\"\"\"\n        cls = self.__class__\n        sid = RESPONSE_CLASS_TO_SERVICE_ID.get(cls, None)\n        if sid is None:\n            raise DGLError(\n                \"Response class {} has not been registered as a service.\".format(\n                    cls\n                )\n            )\n        return sid\n\n\ndef serialize_to_payload(serializable):\n    \"\"\"Serialize an object to payloads.\n\n    The object must have implemented the __getstate__ function.\n\n    Parameters\n    ----------\n    serializable : object\n        Any serializable object.\n\n    Returns\n    -------\n    bytearray\n        Serialized payload buffer.\n    list[Tensor]\n        A list of tensor payloads.\n    \"\"\"\n    state = serializable.__getstate__()\n    if not isinstance(state, tuple):\n        state = (state,)\n    nonarray_pos = []\n    nonarray_state = []\n    array_state = []\n    for i, arr_state in enumerate(state):\n        if F.is_tensor(arr_state):\n            array_state.append(arr_state)\n        else:\n            nonarray_state.append(arr_state)\n            nonarray_pos.append(i)\n    data = bytearray(pickle.dumps((nonarray_pos, nonarray_state)))\n    return data, array_state\n\n\nclass PlaceHolder:\n    \"\"\"PlaceHolder object for deserialization\"\"\"\n\n\n_PLACEHOLDER = PlaceHolder()\n\n\ndef deserialize_from_payload(cls, data, tensors):\n    \"\"\"Deserialize and reconstruct the object from payload.\n\n    The object must have implemented the __setstate__ function.\n\n    Parameters\n    ----------\n    cls : class\n        The object 
class.\n    data : bytearray\n        Serialized data buffer.\n    tensors : list[Tensor]\n        A list of tensor payloads.\n\n    Returns\n    -------\n    object\n        De-serialized object of class cls.\n    \"\"\"\n    pos, nonarray_state = pickle.loads(data)\n    # Use _PLACEHOLDER to distinguish with other deserizliaed elements\n    state = [_PLACEHOLDER] * (len(nonarray_state) + len(tensors))\n    for i, no_state in zip(pos, nonarray_state):\n        state[i] = no_state\n    if len(tensors) != 0:\n        j = 0\n        state_len = len(state)\n        for i in range(state_len):\n            if state[i] is _PLACEHOLDER:\n                state[i] = tensors[j]\n                j += 1\n    if len(state) == 1:\n        state = state[0]\n    else:\n        state = tuple(state)\n    obj = cls.__new__(cls)\n    obj.__setstate__(state)\n    return obj\n\n\n@register_object(\"rpc.RPCMessage\")\nclass RPCMessage(ObjectBase):\n    \"\"\"Serialized RPC message that can be sent to remote processes.\n\n    This class can be used as argument or return value for C API.\n\n    Attributes\n    ----------\n    service_id : int\n        The remote service ID the message wishes to invoke.\n    msg_seq : int\n        Sequence number of this message.\n    client_id : int\n        The client ID.\n    server_id : int\n        The server ID.\n    data : bytearray\n        Payload buffer carried by this request.\n    tensors : list[tensor]\n        Extra payloads in the form of tensors.\n    group_id : int\n        The group ID\n    \"\"\"\n\n    def __init__(\n        self,\n        service_id,\n        msg_seq,\n        client_id,\n        server_id,\n        data,\n        tensors,\n        group_id=0,\n    ):\n        self.__init_handle_by_constructor__(\n            _CAPI_DGLRPCCreateRPCMessage,\n            int(service_id),\n            int(msg_seq),\n            int(client_id),\n            int(server_id),\n            data,\n            [F.zerocopy_to_dgl_ndarray(tsor) for 
tsor in tensors],\n            int(group_id),\n        )\n\n    @property\n    def service_id(self):\n        \"\"\"Get service ID.\"\"\"\n        return _CAPI_DGLRPCMessageGetServiceId(self)\n\n    @property\n    def msg_seq(self):\n        \"\"\"Get message sequence number.\"\"\"\n        return _CAPI_DGLRPCMessageGetMsgSeq(self)\n\n    @property\n    def client_id(self):\n        \"\"\"Get client ID.\"\"\"\n        return _CAPI_DGLRPCMessageGetClientId(self)\n\n    @property\n    def server_id(self):\n        \"\"\"Get server ID.\"\"\"\n        return _CAPI_DGLRPCMessageGetServerId(self)\n\n    @property\n    def data(self):\n        \"\"\"Get payload buffer.\"\"\"\n        return _CAPI_DGLRPCMessageGetData(self)\n\n    @property\n    def tensors(self):\n        \"\"\"Get tensor payloads.\"\"\"\n        rst = _CAPI_DGLRPCMessageGetTensors(self)\n        return [F.zerocopy_from_dgl_ndarray(tsor) for tsor in rst]\n\n    @property\n    def group_id(self):\n        \"\"\"Get group ID.\"\"\"\n        return _CAPI_DGLRPCMessageGetGroupId(self)\n\n\ndef send_request(target, request):\n    \"\"\"Send one request to the target server.\n\n    Serialize the given request object to an :class:`RPCMessage` and send it\n    out.\n\n    The operation is non-blocking -- it does not guarantee the payloads have\n    reached the target or even have left the sender process. 
However,\n    all the payloads (i.e., data and arrays) can be safely freed after this\n    function returns.\n\n    Parameters\n    ----------\n    target : int\n        ID of target server.\n    request : Request\n        The request to send.\n\n    Raises\n    ------\n    ConnectionError if there is any problem with the connection.\n    \"\"\"\n    service_id = request.service_id\n    msg_seq = incr_msg_seq()\n    client_id = get_rank()\n    server_id = target\n    data, tensors = serialize_to_payload(request)\n    msg = RPCMessage(\n        service_id,\n        msg_seq,\n        client_id,\n        server_id,\n        data,\n        tensors,\n        group_id=get_group_id(),\n    )\n    send_rpc_message(msg, server_id)\n\n\ndef send_request_to_machine(target, request):\n    \"\"\"Send one request to the target machine, which will randomly\n    select a server node to process this request.\n\n    The operation is non-blocking -- it does not guarantee the payloads have\n    reached the target or even have left the sender process. 
However,\n    all the payloads (i.e., data and arrays) can be safely freed after this\n    function returns.\n\n    Parameters\n    ----------\n    target : int\n        ID of target machine.\n    request : Request\n        The request to send.\n\n    Raises\n    ------\n    ConnectionError if there is any problem with the connection.\n    \"\"\"\n    service_id = request.service_id\n    msg_seq = incr_msg_seq()\n    client_id = get_rank()\n    server_id = random.randint(\n        target * get_num_server_per_machine(),\n        (target + 1) * get_num_server_per_machine() - 1,\n    )\n    data, tensors = serialize_to_payload(request)\n    msg = RPCMessage(\n        service_id, msg_seq, client_id, server_id, data, tensors, get_group_id()\n    )\n    send_rpc_message(msg, server_id)\n\n\ndef send_response(target, response, group_id):\n    \"\"\"Send one response to the target client.\n\n    Serialize the given response object to an :class:`RPCMessage` and send it\n    out.\n\n    The operation is non-blocking -- it does not guarantee the payloads have\n    reached the target or even have left the sender process. 
However,\n    all the payloads (i.e., data and arrays) can be safely freed after this\n    function returns.\n\n    Parameters\n    ----------\n    target : int\n        ID of target client.\n    response : Response\n        The response to send.\n    group_id : int\n        Group ID of target client.\n\n    Raises\n    ------\n    ConnectionError if there is any problem with the connection.\n    \"\"\"\n    service_id = response.service_id\n    msg_seq = get_msg_seq()\n    client_id = target\n    server_id = get_rank()\n    data, tensors = serialize_to_payload(response)\n    msg = RPCMessage(\n        service_id, msg_seq, client_id, server_id, data, tensors, group_id\n    )\n    send_rpc_message(msg, get_client(client_id, group_id))\n\n\ndef recv_request(timeout=0):\n    \"\"\"Receive one request.\n\n    Receive one :class:`RPCMessage` and de-serialize it into a proper Request object.\n\n    The operation is blocking -- it returns when it receives any message\n    or it times out.\n\n    Parameters\n    ----------\n    timeout : int, optional\n        The timeout value in milliseconds. 
If zero, wait indefinitely.\n\n    Returns\n    -------\n    req : request\n        One request received from the target, or None if it times out.\n    client_id : int\n        Client ID received from the target, or -1 if it times out.\n    group_id : int\n        Group ID received from the target, or -1 if it times out.\n\n    Raises\n    ------\n    ConnectionError if there is any problem with the connection.\n    \"\"\"\n    msg = recv_rpc_message(timeout)\n    if msg is None:\n        return None, -1, -1\n    set_msg_seq(msg.msg_seq)\n    req_cls, _ = SERVICE_ID_TO_PROPERTY[msg.service_id]\n    if req_cls is None:\n        raise DGLError(\n            \"Got request message from service ID {}, \"\n            \"but no request class is registered.\".format(msg.service_id)\n        )\n    req = deserialize_from_payload(req_cls, msg.data, msg.tensors)\n    if msg.server_id != get_rank():\n        raise DGLError(\n            \"Got request sent to server {}, \"\n            \"different from my rank {}!\".format(msg.server_id, get_rank())\n        )\n    return req, msg.client_id, msg.group_id\n\n\ndef recv_response(timeout=0):\n    \"\"\"Receive one response.\n\n    Receive one :class:`RPCMessage` and de-serialize it into a proper Response object.\n\n    The operation is blocking -- it returns when it receives any message\n    or it times out.\n\n    Parameters\n    ----------\n    timeout : int, optional\n        The timeout value in milliseconds. 
If zero, wait indefinitely.\n\n    Returns\n    -------\n    res : Response\n        One response received from the target, or None if it times out.\n\n    Raises\n    ------\n    ConnectionError if there is any problem with the connection.\n    \"\"\"\n    msg = recv_rpc_message(timeout)\n    if msg is None:\n        return None\n    _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]\n    if res_cls is None:\n        raise DGLError(\n            \"Got response message from service ID {}, \"\n            \"but no response class is registered.\".format(msg.service_id)\n        )\n    res = deserialize_from_payload(res_cls, msg.data, msg.tensors)\n    if msg.client_id != get_rank() and get_rank() != -1:\n        raise DGLError(\n            \"Got response of request sent by client {}, \"\n            \"different from my rank {}!\".format(msg.client_id, get_rank())\n        )\n    if msg.group_id != get_group_id():\n        raise DGLError(\n            \"Got response of request sent by group {}, \"\n            \"different from my group {}!\".format(msg.group_id, get_group_id())\n        )\n    return res\n\n\ndef remote_call(target_and_requests, timeout=0):\n    \"\"\"Invoke registered services on remote servers and collect responses.\n\n    The operation is blocking -- it returns when it receives all responses\n    or it times out.\n\n    If the target server state is available locally, it invokes local computation\n    to calculate the response.\n\n    Parameters\n    ----------\n    target_and_requests : list[(int, Request)]\n        A list of requests and the server they should be sent to.\n    timeout : int, optional\n        The timeout value in milliseconds. If zero, wait indefinitely.\n\n    Returns\n    -------\n    list[Response]\n        Responses for each target-request pair. 
If the request does not have\n        response, None is placed.\n\n    Raises\n    ------\n    ConnectionError if there is any problem with the connection.\n    \"\"\"\n    all_res = [None] * len(target_and_requests)\n    msgseq2pos = {}\n    num_res = 0\n    myrank = get_rank()\n    for pos, (target, request) in enumerate(target_and_requests):\n        # send request\n        service_id = request.service_id\n        msg_seq = incr_msg_seq()\n        client_id = get_rank()\n        server_id = random.randint(\n            target * get_num_server_per_machine(),\n            (target + 1) * get_num_server_per_machine() - 1,\n        )\n        data, tensors = serialize_to_payload(request)\n        msg = RPCMessage(\n            service_id,\n            msg_seq,\n            client_id,\n            server_id,\n            data,\n            tensors,\n            get_group_id(),\n        )\n        send_rpc_message(msg, server_id)\n        # check if has response\n        res_cls = get_service_property(service_id)[1]\n        if res_cls is not None:\n            num_res += 1\n            msgseq2pos[msg_seq] = pos\n    while num_res != 0:\n        # recv response\n        msg = recv_rpc_message(timeout)\n        if msg is None:\n            raise DGLError(\n                f\"Timed out for receiving message within {timeout} milliseconds\"\n            )\n        num_res -= 1\n        _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]\n        if res_cls is None:\n            raise DGLError(\n                \"Got response message from service ID {}, \"\n                \"but no response class is registered.\".format(msg.service_id)\n            )\n        res = deserialize_from_payload(res_cls, msg.data, msg.tensors)\n        if msg.client_id != myrank:\n            raise DGLError(\n                \"Got reponse of request sent by client {}, \"\n                \"different from my rank {}!\".format(msg.client_id, myrank)\n            )\n        # set response\n        
all_res[msgseq2pos[msg.msg_seq]] = res\n    return all_res\n\n\ndef send_requests_to_machine(target_and_requests):\n    \"\"\"Send requests to the remote machines.\n\n    This operation is non-blocking. It returns immediately once it sends all requests.\n\n    Parameters\n    ----------\n    target_and_requests : list[(int, Request)]\n        A list of requests and the machine they should be sent to.\n    timeout : int, optional\n        The timeout value in milliseconds. If zero, wait indefinitely.\n\n    Returns\n    -------\n    msgseq2pos : dict\n        map the message sequence number to its position in the input list.\n    \"\"\"\n    msgseq2pos = {}\n    for pos, (target, request) in enumerate(target_and_requests):\n        # send request\n        service_id = request.service_id\n        msg_seq = incr_msg_seq()\n        client_id = get_rank()\n\n        server_id = random.randint(\n            target * get_num_server_per_machine(),\n            (target + 1) * get_num_server_per_machine() - 1,\n        )\n        data, tensors = serialize_to_payload(request)\n        msg = RPCMessage(\n            service_id,\n            msg_seq,\n            client_id,\n            server_id,\n            data,\n            tensors,\n            get_group_id(),\n        )\n        send_rpc_message(msg, server_id)\n        # check if has response\n        res_cls = get_service_property(service_id)[1]\n        if res_cls is not None:\n            msgseq2pos[msg_seq] = pos\n    return msgseq2pos\n\n\ndef recv_responses(msgseq2pos, timeout=0):\n    \"\"\"Receive responses\n\n    It returns the responses in the same order as the requests. 
The order of requests\n    is stored in msgseq2pos.\n\n    The operation is blocking -- it returns when it receives all responses\n    or it times out.\n\n    Parameters\n    ----------\n    msgseq2pos : dict\n        map the message sequence number to its position in the input list.\n    timeout : int, optional\n        The timeout value in milliseconds. If zero, wait indefinitely.\n\n    Returns\n    -------\n    list[Response]\n        Responses for each target-request pair. If the request does not have\n        response, None is placed.\n    \"\"\"\n    myrank = get_rank()\n    size = np.max(list(msgseq2pos.values())) + 1\n    all_res = [None] * size\n    num_res = len(msgseq2pos)\n    while num_res != 0:\n        # recv response\n        msg = recv_rpc_message(timeout)\n        if msg is None:\n            raise DGLError(\n                f\"Timed out for receiving message within {timeout} milliseconds\"\n            )\n        num_res -= 1\n        _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]\n        if res_cls is None:\n            raise DGLError(\n                \"Got response message from service ID {}, \"\n                \"but no response class is registered.\".format(msg.service_id)\n            )\n        res = deserialize_from_payload(res_cls, msg.data, msg.tensors)\n        if msg.client_id != myrank:\n            raise DGLError(\n                \"Got reponse of request sent by client {}, \"\n                \"different from my rank {}!\".format(msg.client_id, myrank)\n            )\n        # set response\n        all_res[msgseq2pos[msg.msg_seq]] = res\n    return all_res\n\n\ndef remote_call_to_machine(target_and_requests, timeout=0):\n    \"\"\"Invoke registered services on remote machine\n    (which will randomly select a server to process the request) and collect responses.\n\n    The operation is blocking -- it returns when it receives all responses\n    or it times out.\n\n    If the target server state is available locally, it 
invokes local computation\n    to calculate the response.\n\n    Parameters\n    ----------\n    target_and_requests : list[(int, Request)]\n        A list of requests and the machine they should be sent to.\n    timeout : int, optional\n        The timeout value in milliseconds. If zero, wait indefinitely.\n\n    Returns\n    -------\n    list[Response]\n        Responses for each target-request pair. If the request does not have\n        response, None is placed.\n\n    Raises\n    ------\n    ConnectionError if there is any problem with the connection.\n    \"\"\"\n    msgseq2pos = send_requests_to_machine(target_and_requests)\n    return recv_responses(msgseq2pos, timeout)\n\n\ndef send_rpc_message(msg, target):\n    \"\"\"Send one message to the target server.\n\n    The operation is non-blocking -- it does not guarantee the payloads have\n    reached the target or even have left the sender process. However,\n    all the payloads (i.e., data and arrays) can be safely freed after this\n    function returns.\n\n    The data buffer in the request will be copied to internal buffer for actual\n    transmission, while no memory copy for tensor payloads (a.k.a. zero-copy).\n    The underlying sending threads will hold references to the tensors until\n    the contents have been transmitted.\n\n    Parameters\n    ----------\n    msg : RPCMessage\n        The message to send.\n    target : int\n        target ID\n\n    Raises\n    ------\n    ConnectionError if there is any problem with the connection.\n    \"\"\"\n    _CAPI_DGLRPCSendRPCMessage(msg, int(target))\n\n\ndef recv_rpc_message(timeout=0):\n    \"\"\"Receive one message.\n\n    The operation is blocking -- it returns when it receives any message\n    or it times out.\n\n    Parameters\n    ----------\n    timeout : int, optional\n        The timeout value in milliseconds. 
If zero, wait indefinitely.\n\n    Returns\n    -------\n    msg : RPCMessage\n        One rpc message received from the target, or None if it times out.\n\n    Raises\n    ------\n    ConnectionError if there is any problem with the connection.\n    \"\"\"\n    msg = _CAPI_DGLRPCCreateEmptyRPCMessage()\n    status = _CAPI_DGLRPCRecvRPCMessage(timeout, msg)\n    return msg if status == 0 else None\n\n\ndef client_barrier():\n    \"\"\"Barrier all client processes\"\"\"\n    req = ClientBarrierRequest()\n    send_request(0, req)\n    res = recv_response()\n    assert res.msg == \"barrier\"\n\n\ndef finalize_server():\n    \"\"\"Finalize resources of current server\"\"\"\n    finalize_sender()\n    finalize_receiver()\n    print(\"Server (%d) shutdown.\" % get_rank())\n\n\ndef fast_pull(\n    name,\n    id_tensor,\n    part_id,\n    service_id,\n    machine_count,\n    group_count,\n    machine_id,\n    client_id,\n    local_data,\n    policy,\n):\n    \"\"\"Fast-pull api used by kvstore.\n\n    Parameters\n    ----------\n    name : str\n        data name\n    id_tensor : tensor\n        data ID\n    part_id : tensor\n        partition ID of id_tensor\n    service_id : int\n        service_id of pull request\n    machine_count : int\n        total number of machine\n    group_count : int\n        total number of server inside machine\n    machine_id : int\n        current machine ID\n    client_id : int\n        current client ID\n    local_data : tensor\n        local data tensor\n    policy : PartitionPolicy\n        store the partition information\n    \"\"\"\n    msg_seq = incr_msg_seq()\n    pickle_data = bytearray(pickle.dumps(([0], [name])))\n    global_id = _CAPI_DGLRPCGetGlobalIDFromLocalPartition(\n        F.zerocopy_to_dgl_ndarray(id_tensor),\n        F.zerocopy_to_dgl_ndarray(part_id),\n        machine_id,\n    )\n    global_id = F.zerocopy_from_dgl_ndarray(global_id)\n    g2l_id = policy.to_local(global_id)\n    res_tensor = _CAPI_DGLRPCFastPull(\n      
  name,\n        int(machine_id),\n        int(machine_count),\n        int(group_count),\n        int(client_id),\n        int(service_id),\n        int(msg_seq),\n        pickle_data,\n        F.zerocopy_to_dgl_ndarray(id_tensor),\n        F.zerocopy_to_dgl_ndarray(part_id),\n        F.zerocopy_to_dgl_ndarray(g2l_id),\n        F.zerocopy_to_dgl_ndarray(local_data),\n    )\n    return F.zerocopy_from_dgl_ndarray(res_tensor)\n\n\ndef register_sig_handler():\n    \"\"\"Register for handling signal event.\"\"\"\n    _CAPI_DGLRPCHandleSignal()\n\n\ndef copy_data_to_shared_memory(dst, source):\n    \"\"\"Copy tensor data to shared-memory tensor\"\"\"\n    F.zerocopy_to_dgl_ndarray(dst).copyfrom(F.zerocopy_to_dgl_ndarray(source))\n\n\n############### Some basic services will be defined here #############\n\nCLIENT_REGISTER = 22451\n\n\nclass ClientRegisterRequest(Request):\n    \"\"\"This request will send client's ip to server.\n\n    Parameters\n    ----------\n    ip_addr : str\n        client's IP address\n    \"\"\"\n\n    def __init__(self, ip_addr):\n        self.ip_addr = ip_addr\n\n    def __getstate__(self):\n        return self.ip_addr\n\n    def __setstate__(self, state):\n        self.ip_addr = state\n\n    def process_request(self, server_state):\n        return None  # do nothing\n\n\nclass ClientRegisterResponse(Response):\n    \"\"\"This response will send assigned ID to client.\n\n    Parameters\n    ----------\n    ID : int\n        client's ID\n    \"\"\"\n\n    def __init__(self, client_id):\n        self.client_id = client_id\n\n    def __getstate__(self):\n        return self.client_id\n\n    def __setstate__(self, state):\n        self.client_id = state\n\n\nSHUT_DOWN_SERVER = 22452\n\n\nclass ShutDownRequest(Request):\n    \"\"\"Client send this request to shut-down a server.\n\n    This request has no response.\n\n    Parameters\n    ----------\n    client_id : int\n        client's ID\n    \"\"\"\n\n    def __init__(self, client_id, 
force_shutdown_server=False):\n        self.client_id = client_id\n        self.force_shutdown_server = force_shutdown_server\n\n    def __getstate__(self):\n        return self.client_id, self.force_shutdown_server\n\n    def __setstate__(self, state):\n        self.client_id, self.force_shutdown_server = state\n\n    def process_request(self, server_state):\n        assert self.client_id == 0\n        finalize_server()\n        return SERVER_EXIT\n\n\nGET_NUM_CLIENT = 22453\n\n\nclass GetNumberClientsResponse(Response):\n    \"\"\"This reponse will send total number of clients.\n\n    Parameters\n    ----------\n    num_client : int\n        total number of clients\n    \"\"\"\n\n    def __init__(self, num_client):\n        self.num_client = num_client\n\n    def __getstate__(self):\n        return self.num_client\n\n    def __setstate__(self, state):\n        self.num_client = state\n\n\nclass GetNumberClientsRequest(Request):\n    \"\"\"Client send this request to get the total number of client.\n\n    Parameters\n    ----------\n    client_id : int\n        client's ID\n    \"\"\"\n\n    def __init__(self, client_id):\n        self.client_id = client_id\n\n    def __getstate__(self):\n        return self.client_id\n\n    def __setstate__(self, state):\n        self.client_id = state\n\n    def process_request(self, server_state):\n        res = GetNumberClientsResponse(get_num_client())\n        return res\n\n\nCLIENT_BARRIER = 22454\n\n\nclass ClientBarrierResponse(Response):\n    \"\"\"Send the barrier confirmation to client\n\n    Parameters\n    ----------\n    msg : str\n        string msg\n    \"\"\"\n\n    def __init__(self, msg=\"barrier\"):\n        self.msg = msg\n\n    def __getstate__(self):\n        return self.msg\n\n    def __setstate__(self, state):\n        self.msg = state\n\n\nclass ClientBarrierRequest(Request):\n    \"\"\"Send the barrier information to server\n\n    Parameters\n    ----------\n    msg : str\n        string msg\n    
\"\"\"\n\n    def __init__(self, msg=\"barrier\"):\n        self.msg = msg\n        self.group_id = get_group_id()\n\n    def __getstate__(self):\n        return self.msg, self.group_id\n\n    def __setstate__(self, state):\n        self.msg, self.group_id = state\n\n    def process_request(self, server_state):\n        _CAPI_DGLRPCSetBarrierCount(\n            _CAPI_DGLRPCGetBarrierCount(self.group_id) + 1, self.group_id\n        )\n        if _CAPI_DGLRPCGetBarrierCount(self.group_id) == get_num_client():\n            _CAPI_DGLRPCSetBarrierCount(0, self.group_id)\n            res_list = []\n            for target_id in range(get_num_client()):\n                res_list.append((target_id, ClientBarrierResponse()))\n            return res_list\n        return None\n\n\ndef set_group_id(group_id):\n    \"\"\"Set current group ID\n\n    Parameters\n    ----------\n    group_id : int\n        Current group ID\n    \"\"\"\n    _CAPI_DGLRPCSetGroupID(int(group_id))\n\n\ndef get_group_id():\n    \"\"\"Get current group ID\n\n    Returns\n    -------\n    int\n        group ID\n    \"\"\"\n    return _CAPI_DGLRPCGetGroupID()\n\n\ndef register_client(client_id, group_id):\n    \"\"\"Register client\n\n    Returns\n    -------\n    int\n        unique client ID\n    \"\"\"\n    return _CAPI_DGLRPCRegisterClient(int(client_id), int(group_id))\n\n\ndef get_client(client_id, group_id):\n    \"\"\"Get global client ID\n\n    Parameters\n    ----------\n    client_id : int\n        client ID\n    group_id : int\n        group ID\n\n    Returns\n    -------\n    int\n        global client ID\n    \"\"\"\n    return _CAPI_DGLRPCGetClient(int(client_id), int(group_id))\n\n\nclass DistConnectError(DGLError):\n    \"\"\"Exception raised for errors if fail to connect peer.\n\n    Attributes\n    ----------\n    kv_store : KVServer\n        reference for KVServer\n    \"\"\"\n\n    def __init__(self, max_try_times, ip=\"\", port=\"\"):\n        peer_str = \"peer[{}:{}]\".format(ip, 
port) if ip != \"\" else \"peer\"\n        self.message = (\n            \"Failed to build conncetion with {} after {} retries. \"\n            \"Please check network availability or increase max try \"\n            \"times via 'DGL_DIST_MAX_TRY_TIMES'.\".format(\n                peer_str, max_try_times\n            )\n        )\n        super().__init__(self.message)\n\n\n_init_api(\"dgl.distributed.rpc\")\n"
  },
  {
    "path": "python/dgl/distributed/rpc_client.py",
    "content": "\"\"\"Functions used by client.\"\"\"\n\nimport atexit\nimport logging\nimport os\nimport socket\nimport time\n\nfrom . import rpc\nfrom .constants import MAX_QUEUE_SIZE\n\nif os.name != \"nt\":\n    import fcntl\n    import struct\n\n\ndef local_ip4_addr_list():\n    \"\"\"Return a set of IPv4 address\n\n    You can use\n    `logging.getLogger(\"dgl-distributed-socket\").setLevel(logging.WARNING+1)`\n    to disable the warning here\n    \"\"\"\n    assert os.name != \"nt\", \"Do not support Windows rpc yet.\"\n    nic = set()\n    logger = logging.getLogger(\"dgl-distributed-socket\")\n    for if_nidx in socket.if_nameindex():\n        name = if_nidx[1]\n        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)\n        try:\n            ip_of_ni = fcntl.ioctl(\n                sock.fileno(),\n                0x8915,  # SIOCGIFADDR\n                struct.pack(\"256s\", name[:15].encode(\"UTF-8\")),\n            )\n        except OSError as e:\n            if e.errno == 99:  # EADDRNOTAVAIL\n                logger.warning(\n                    \"Warning! 
Interface: %s \\n\"\n                    \"IP address not available for interface.\",\n                    name,\n                )\n                continue\n            raise e\n\n        ip_addr = socket.inet_ntoa(ip_of_ni[20:24])\n        nic.add(ip_addr)\n    return nic\n\n\ndef get_local_machine_id(server_namebook):\n    \"\"\"Given server_namebook, find local machine ID\n\n    Parameters\n    ----------\n    server_namebook: dict\n        IP address namebook of server nodes, where key is the server's ID\n        (start from 0) and value is the server's machine_id, IP address,\n        port, and group_count, e.g.,\n\n          {0:'[0, '172.31.40.143', 30050, 2],\n           1:'[0, '172.31.40.143', 30051, 2],\n           2:'[1, '172.31.36.140', 30050, 2],\n           3:'[1, '172.31.36.140', 30051, 2],\n           4:'[2, '172.31.47.147', 30050, 2],\n           5:'[2, '172.31.47.147', 30051, 2],\n           6:'[3, '172.31.30.180', 30050, 2],\n           7:'[3, '172.31.30.180', 30051, 2]}\n\n    Returns\n    -------\n    int\n        local machine ID\n    \"\"\"\n    res = 0\n    ip_list = local_ip4_addr_list()\n    for _, data in server_namebook.items():\n        machine_id = data[0]\n        ip_addr = data[1]\n        if ip_addr in ip_list:\n            res = machine_id\n            break\n    return res\n\n\ndef get_local_usable_addr(probe_addr):\n    \"\"\"Get local usable IP and port\n\n    Returns\n    -------\n    str\n        IP address, e.g., '192.168.8.12:50051'\n    \"\"\"\n    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)\n    try:\n        # should get the address on the same subnet as probe_addr's\n        sock.connect((probe_addr, 1))\n        ip_addr = sock.getsockname()[0]\n    except ValueError:\n        ip_addr = \"127.0.0.1\"\n    finally:\n        sock.close()\n    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n    sock.bind((\"\", 0))\n    sock.listen(1)\n    port = sock.getsockname()[1]\n    sock.close()\n\n    return 
ip_addr + \":\" + str(port)\n\n\ndef connect_to_server(\n    ip_config,\n    num_servers,\n    max_queue_size=MAX_QUEUE_SIZE,\n    group_id=0,\n):\n    \"\"\"Connect this client to server.\n\n    Parameters\n    ----------\n    ip_config : str\n        Path of server IP configuration file.\n    num_servers : int\n        server count on each machine.\n    max_queue_size : int\n        Maximal size (bytes) of client queue buffer (~20 GB on default).\n        Note that the 20 GB is just an upper-bound and DGL uses zero-copy and\n        it will not allocate 20GB memory at once.\n    group_id : int\n        Indicates which group this client belongs to. Clients that are\n        booted together in each launch are gathered as a group and should\n        have same unique group_id.\n\n    Raises\n    ------\n    ConnectionError : If anything wrong with the connection.\n    \"\"\"\n    assert num_servers > 0, (\n        \"num_servers (%d) must be a positive number.\" % num_servers\n    )\n    assert max_queue_size > 0, (\n        \"queue_size (%d) cannot be a negative number.\" % max_queue_size\n    )\n    # Register some basic service\n    rpc.register_service(\n        rpc.CLIENT_REGISTER,\n        rpc.ClientRegisterRequest,\n        rpc.ClientRegisterResponse,\n    )\n    rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)\n    rpc.register_service(\n        rpc.GET_NUM_CLIENT,\n        rpc.GetNumberClientsRequest,\n        rpc.GetNumberClientsResponse,\n    )\n    rpc.register_service(\n        rpc.CLIENT_BARRIER, rpc.ClientBarrierRequest, rpc.ClientBarrierResponse\n    )\n    rpc.register_sig_handler()\n    server_namebook = rpc.read_ip_config(ip_config, num_servers)\n    num_servers = len(server_namebook)\n    rpc.set_num_server(num_servers)\n    # group_count means how many servers\n    # (main_server + bakcup_server) in total inside a machine.\n    group_count = []\n    max_machine_id = 0\n    for server_info in server_namebook.values():\n        
group_count.append(server_info[3])\n        if server_info[0] > max_machine_id:\n            max_machine_id = server_info[0]\n    rpc.set_num_server_per_machine(group_count[0])\n    num_machines = max_machine_id + 1\n    rpc.set_num_machines(num_machines)\n    machine_id = get_local_machine_id(server_namebook)\n    rpc.set_machine_id(machine_id)\n    rpc.set_group_id(group_id)\n    rpc.create_sender(max_queue_size)\n    rpc.create_receiver(max_queue_size)\n    # Get connected with all server nodes\n    max_try_times = int(os.environ.get(\"DGL_DIST_MAX_TRY_TIMES\", 1024))\n    for server_id, addr in server_namebook.items():\n        server_ip = addr[1]\n        server_port = addr[2]\n        try_times = 0\n        while not rpc.connect_receiver(server_ip, server_port, server_id):\n            try_times += 1\n            if try_times % 200 == 0:\n                print(\n                    \"Client is trying to connect server receiver: {}:{}\".format(\n                        server_ip, server_port\n                    )\n                )\n            if try_times >= max_try_times:\n                raise rpc.DistConnectError(\n                    max_try_times, server_ip, server_port\n                )\n            time.sleep(3)\n    if not rpc.connect_receiver_finalize(max_try_times):\n        raise rpc.DistConnectError(max_try_times)\n    # Get local usable IP address and port\n    ip_addr = get_local_usable_addr(server_ip)\n    client_ip, client_port = ip_addr.split(\":\")\n    # Register client on server\n    register_req = rpc.ClientRegisterRequest(ip_addr)\n    for server_id in range(num_servers):\n        rpc.send_request(server_id, register_req)\n    # wait server connect back\n    rpc.wait_for_senders(client_ip, client_port, num_servers)\n    print(\n        \"Client [{}] waits on {}:{}\".format(os.getpid(), client_ip, client_port)\n    )\n    # recv client ID from server\n    res = rpc.recv_response()\n    rpc.set_rank(res.client_id)\n    print(\n        
\"Machine (%d) group (%d) client (%d) connect to server successfuly!\"\n        % (machine_id, group_id, rpc.get_rank())\n    )\n    # get total number of client\n    get_client_num_req = rpc.GetNumberClientsRequest(rpc.get_rank())\n    rpc.send_request(0, get_client_num_req)\n    res = rpc.recv_response()\n    rpc.set_num_client(res.num_client)\n    from .dist_context import exit_client, set_initialized\n\n    atexit.register(exit_client)\n    set_initialized(True)\n"
  },
  {
    "path": "python/dgl/distributed/rpc_server.py",
    "content": "\"\"\"Functions used by server.\"\"\"\n\nimport os\nimport time\n\nfrom ..base import DGLError\nfrom . import rpc\nfrom .constants import MAX_QUEUE_SIZE, SERVER_EXIT\n\n\ndef start_server(\n    server_id,\n    ip_config,\n    num_servers,\n    num_clients,\n    server_state,\n    max_queue_size=MAX_QUEUE_SIZE,\n):\n    \"\"\"Start DGL server, which will be shared with all the rpc services.\n\n    This is a blocking function -- it returns only when the server shutdown.\n\n    Parameters\n    ----------\n    server_id : int\n        Current server ID (starts from 0).\n    ip_config : str\n        Path of IP configuration file.\n    num_servers : int\n        Server count on each machine.\n    num_clients : int\n        Total number of clients that will be connected to the server.\n        Note that, we do not support dynamic connection for now. It means\n        that when all the clients connect to server, no client will can be added\n        to the cluster.\n    server_state : ServerSate object\n        Store in main data used by server.\n    max_queue_size : int\n        Maximal size (bytes) of server queue buffer (~20 GB on default).\n        Note that the 20 GB is just an upper-bound because DGL uses zero-copy and\n        it will not allocate 20GB memory at once.\n    \"\"\"\n    assert server_id >= 0, (\n        \"server_id (%d) cannot be a negative number.\" % server_id\n    )\n    assert num_servers > 0, (\n        \"num_servers (%d) must be a positive number.\" % num_servers\n    )\n    assert num_clients >= 0, (\n        \"num_client (%d) cannot be a negative number.\" % num_clients\n    )\n    assert max_queue_size > 0, (\n        \"queue_size (%d) cannot be a negative number.\" % max_queue_size\n    )\n    # Register signal handler.\n    rpc.register_sig_handler()\n    # Register some basic services\n    rpc.register_service(\n        rpc.CLIENT_REGISTER,\n        rpc.ClientRegisterRequest,\n        rpc.ClientRegisterResponse,\n    )\n    
rpc.register_service(rpc.SHUT_DOWN_SERVER, rpc.ShutDownRequest, None)\n    rpc.register_service(\n        rpc.GET_NUM_CLIENT,\n        rpc.GetNumberClientsRequest,\n        rpc.GetNumberClientsResponse,\n    )\n    rpc.register_service(\n        rpc.CLIENT_BARRIER, rpc.ClientBarrierRequest, rpc.ClientBarrierResponse\n    )\n    rpc.set_rank(server_id)\n    server_namebook = rpc.read_ip_config(ip_config, num_servers)\n    machine_id = server_namebook[server_id][0]\n    rpc.set_machine_id(machine_id)\n    ip_addr = server_namebook[server_id][1]\n    port = server_namebook[server_id][2]\n    rpc.create_sender(max_queue_size)\n    rpc.create_receiver(max_queue_size)\n    # wait all the senders connect to server.\n    # Once all the senders connect to server, server will not\n    # accept new sender's connection\n    print(\n        \"Server is waiting for connections on [{}:{}]...\".format(ip_addr, port)\n    )\n    rpc.wait_for_senders(ip_addr, port, num_clients)\n    rpc.set_num_client(num_clients)\n    recv_clients = {}\n    while True:\n        # go through if any client group is ready for connection\n        for group_id in list(recv_clients.keys()):\n            ips = recv_clients[group_id]\n            if len(ips) < rpc.get_num_client():\n                continue\n\n            del recv_clients[group_id]\n            # a new client group is ready\n            ips.sort()\n            client_namebook = dict(enumerate(ips))\n            time.sleep(3)  # wait for clients' receivers ready\n            max_try_times = int(os.environ.get(\"DGL_DIST_MAX_TRY_TIMES\", 120))\n            for client_id, addr in client_namebook.items():\n                client_ip, client_port = addr.split(\":\")\n                try_times = 0\n                while not rpc.connect_receiver(\n                    client_ip, client_port, client_id, group_id\n                ):\n                    try_times += 1\n                    if try_times % 200 == 0:\n                        print(\n     
                       \"Server~{} is trying to connect client receiver: {}:{}\".format(\n                                server_id, client_ip, client_port\n                            )\n                        )\n                    if try_times >= max_try_times:\n                        raise rpc.DistConnectError(\n                            max_try_times, client_ip, client_port\n                        )\n                    time.sleep(1)\n            if not rpc.connect_receiver_finalize(max_try_times):\n                raise rpc.DistConnectError(max_try_times)\n            if rpc.get_rank() == 0:  # server_0 send all the IDs\n                for client_id, _ in client_namebook.items():\n                    register_res = rpc.ClientRegisterResponse(client_id)\n                    rpc.send_response(client_id, register_res, group_id)\n        # receive incomming client requests\n        timeout = 60 * 1000  # in milliseconds\n        req, client_id, group_id = rpc.recv_request(timeout)\n        if req is None:\n            continue\n        if isinstance(req, rpc.ClientRegisterRequest):\n            if group_id not in recv_clients:\n                recv_clients[group_id] = []\n            recv_clients[group_id].append(req.ip_addr)\n            continue\n\n        res = req.process_request(server_state)\n        if res is not None:\n            if isinstance(res, list):\n                for response in res:\n                    target_id, res_data = response\n                    rpc.send_response(target_id, res_data, group_id)\n            elif isinstance(res, str):\n                if res == SERVER_EXIT:\n                    print(\"Server is exiting...\")\n                    return\n                else:\n                    raise DGLError(\"Unexpected response: {}\".format(res))\n            else:\n                rpc.send_response(client_id, res, group_id)\n"
  },
  {
    "path": "python/dgl/distributed/server_state.py",
    "content": "\"\"\"Server data\"\"\"\n\nfrom .._ffi.function import _init_api\n\n# Remove C++ bindings for now, since not used\n\n\nclass ServerState:\n    \"\"\"Data stored in one DGL server.\n\n    In a distributed setting, DGL partitions all data associated with the graph\n    (e.g., node and edge features, graph structure, etc.) to multiple partitions,\n    each handled by one DGL server. Hence, the ServerState class includes all\n    the data associated with a graph partition.\n\n    Under some setup, users may want to deploy servers in a heterogeneous way\n    -- servers are further divided into special groups for fetching/updating\n    node/edge data and for sampling/querying on graph structure respectively.\n    In this case, the ServerState can be configured to include only node/edge\n    data or graph structure.\n\n    Each machine can have multiple server and client processes, but only one\n    server is the *master* server while all the others are backup servers. All\n    clients and backup servers share the state of the master server via shared\n    memory, which means the ServerState class must be serializable and large\n    bulk data (e.g., node/edge features) must be stored in NDArray to leverage\n    shared memory.\n\n    Attributes\n    ----------\n    kv_store : KVServer\n        reference for KVServer\n    graph : DGLGraph\n        Graph structure of one partition\n    total_num_nodes : int\n        Total number of nodes\n    total_num_edges : int\n        Total number of edges\n    partition_book : GraphPartitionBook\n        Graph Partition book\n    use_graphbolt : bool\n        Whether to use graphbolt for dataloading.\n    \"\"\"\n\n    def __init__(self, kv_store, local_g, partition_book, use_graphbolt=False):\n        self._kv_store = kv_store\n        self._graph = local_g\n        self.partition_book = partition_book\n        self._roles = {}\n        self._use_graphbolt = use_graphbolt\n\n    @property\n    def roles(self):\n        
\"\"\"Roles of the client processes\"\"\"\n        return self._roles\n\n    @property\n    def kv_store(self):\n        \"\"\"Get data store.\"\"\"\n        return self._kv_store\n\n    @kv_store.setter\n    def kv_store(self, kv_store):\n        self._kv_store = kv_store\n\n    @property\n    def graph(self):\n        \"\"\"Get graph data.\"\"\"\n        return self._graph\n\n    @graph.setter\n    def graph(self, graph):\n        self._graph = graph\n\n    @property\n    def use_graphbolt(self):\n        \"\"\"Whether to use graphbolt for dataloading.\"\"\"\n        return self._use_graphbolt\n\n\n_init_api(\"dgl.distributed.server_state\")\n"
  },
  {
    "path": "python/dgl/distributed/shared_mem_utils.py",
    "content": "\"\"\"Define utility functions for shared memory.\"\"\"\n\nfrom .. import backend as F, ndarray as nd\nfrom .._ffi.ndarray import empty_shared_mem\n\nDTYPE_DICT = F.data_type_dict\nDTYPE_DICT = {DTYPE_DICT[key]: key for key in DTYPE_DICT}\n\n\ndef _get_ndata_path(graph_name, ndata_name):\n    return \"/\" + graph_name + \"_node_\" + ndata_name\n\n\ndef _get_edata_path(graph_name, edata_name):\n    return \"/\" + graph_name + \"_edge_\" + edata_name\n\n\ndef _to_shared_mem(arr, name):\n    dlpack = F.zerocopy_to_dlpack(arr)\n    dgl_tensor = nd.from_dlpack(dlpack)\n    new_arr = empty_shared_mem(\n        name, True, F.shape(arr), DTYPE_DICT[F.dtype(arr)]\n    )\n    dgl_tensor.copyto(new_arr)\n    dlpack = new_arr.to_dlpack()\n    return F.zerocopy_from_dlpack(dlpack)\n"
  },
  {
    "path": "python/dgl/distributed/standalone_kvstore.py",
    "content": "\"\"\"Define a fake kvstore\n\nThis kvstore is used when running in the standalone mode\n\"\"\"\n\nfrom .. import backend as F\n\n\nclass KVClient(object):\n    \"\"\"The fake KVStore client.\n\n    This is to mimic the distributed KVStore client. It's used for DistGraph\n    in standalone mode.\n    \"\"\"\n\n    def __init__(self):\n        self._data = {}\n        self._all_possible_part_policy = {}\n        self._push_handlers = {}\n        self._pull_handlers = {}\n        # Store all graph data name\n        self._gdata_name_list = set()\n\n    @property\n    def all_possible_part_policy(self):\n        \"\"\"Get all possible partition policies\"\"\"\n        return self._all_possible_part_policy\n\n    @property\n    def num_servers(self):\n        \"\"\"Get the number of servers\"\"\"\n        return 1\n\n    def barrier(self):\n        \"\"\"barrier\"\"\"\n\n    def register_push_handler(self, name, func):\n        \"\"\"register push handler\"\"\"\n        self._push_handlers[name] = func\n\n    def register_pull_handler(self, name, func):\n        \"\"\"register pull handler\"\"\"\n        self._pull_handlers[name] = func\n\n    def add_data(self, name, tensor, part_policy):\n        \"\"\"add data to the client\"\"\"\n        self._data[name] = tensor\n        self._gdata_name_list.add(name)\n        if part_policy.policy_str not in self._all_possible_part_policy:\n            self._all_possible_part_policy[part_policy.policy_str] = part_policy\n\n    def init_data(\n        self, name, shape, dtype, part_policy, init_func, is_gdata=True\n    ):\n        \"\"\"add new data to the client\"\"\"\n        self._data[name] = init_func(shape, dtype)\n        if part_policy.policy_str not in self._all_possible_part_policy:\n            self._all_possible_part_policy[part_policy.policy_str] = part_policy\n        if is_gdata:\n            self._gdata_name_list.add(name)\n\n    def delete_data(self, name):\n        \"\"\"delete the data\"\"\"\n   
     del self._data[name]\n        if name in self._gdata_name_list:\n            self._gdata_name_list.remove(name)\n\n    def data_name_list(self):\n        \"\"\"get the names of all data\"\"\"\n        return list(self._data.keys())\n\n    def gdata_name_list(self):\n        \"\"\"get the names of graph data\"\"\"\n        return list(self._gdata_name_list)\n\n    def get_data_meta(self, name):\n        \"\"\"get the metadata of data\"\"\"\n        return F.dtype(self._data[name]), F.shape(self._data[name]), None\n\n    def push(self, name, id_tensor, data_tensor):\n        \"\"\"push data to kvstore\"\"\"\n        if name in self._push_handlers:\n            self._push_handlers[name](self._data, name, id_tensor, data_tensor)\n        else:\n            F.scatter_row_inplace(self._data[name], id_tensor, data_tensor)\n\n    def pull(self, name, id_tensor):\n        \"\"\"pull data from kvstore\"\"\"\n        if name in self._pull_handlers:\n            return self._pull_handlers[name](self._data, name, id_tensor)\n        else:\n            return F.gather_row(self._data[name], id_tensor)\n\n    def map_shared_data(self, partition_book):\n        \"\"\"Mapping shared-memory tensor from server to client.\"\"\"\n\n    def count_nonzero(self, name):\n        \"\"\"Count nonzero value by pull request from KVServers.\n\n        Parameters\n        ----------\n        name : str\n            data name\n\n        Returns\n        -------\n        int\n            the number of nonzero in this data.\n        \"\"\"\n        return F.count_nonzero(self._data[name])\n\n    @property\n    def data_store(self):\n        \"\"\"Return the local partition of the data storage.\n\n        Returns\n        -------\n        dict[str, Tensor]\n            The tensor storages of the local partition.\n        \"\"\"\n        return self._data\n\n    def union(self, operand1_name, operand2_name, output_name):\n        \"\"\"Compute the union of two mask arrays in the KVStore.\"\"\"\n  
      self._data[output_name][:] = (\n            self._data[operand1_name] | self._data[operand2_name]\n        )\n"
  },
  {
    "path": "python/dgl/frame.py",
    "content": "\"\"\"Columnar storage for DGLGraph.\"\"\"\nfrom __future__ import absolute_import\n\nfrom collections import namedtuple\nfrom collections.abc import MutableMapping\n\nfrom . import backend as F\nfrom .base import dgl_warning, DGLError\nfrom .init import zero_initializer\nfrom .storages import TensorStorage\nfrom .utils import gather_pinned_tensor_rows, pin_memory_inplace\n\n\nclass _LazyIndex(object):\n    def __init__(self, index):\n        if isinstance(index, list):\n            self._indices = index\n        else:\n            self._indices = [index]\n\n    def __len__(self):\n        return len(self._indices[-1])\n\n    def slice(self, index):\n        \"\"\"Create a new _LazyIndex object sliced by the given index tensor.\"\"\"\n        # if our indices are in the same context, lets just slice now and free\n        # memory, otherwise do nothing until we have to\n        if F.context(self._indices[-1]) == F.context(index):\n            return _LazyIndex(\n                self._indices[:-1] + [F.gather_row(self._indices[-1], index)]\n            )\n        return _LazyIndex(self._indices + [index])\n\n    def flatten(self):\n        \"\"\"Evaluate the chain of indices, and return a single index tensor.\"\"\"\n        flat_index = self._indices[0]\n        # here we actually need to resolve it\n        for index in self._indices[1:]:\n            if F.context(index) != F.context(flat_index):\n                index = F.copy_to(index, F.context(flat_index))\n            flat_index = F.gather_row(flat_index, index)\n        return flat_index\n\n    def record_stream(self, stream):\n        \"\"\"Record stream for index.\n\n        Parameters\n        ----------\n        stream : torch.cuda.Stream.\n        \"\"\"\n        for index in self._indices:\n            if F.context(index) != F.cpu():\n                index.record_stream(stream)\n\n\nclass LazyFeature(object):\n    \"\"\"Placeholder for feature prefetching.\n\n    One can assign this 
object to ``ndata`` or ``edata`` of the graphs returned by various\n    samplers' :attr:`sample` method.  When DGL's dataloader receives the subgraphs\n    returned by the sampler, it will automatically look up all the ``ndata`` and ``edata``\n    whose data is a LazyFeature, replacing them with the actual data of the corresponding\n    nodes/edges from the original graph instead.  In particular, for a subgraph returned\n    by the sampler that has a LazyFeature with name ``k`` in ``subgraph.ndata[key]``:\n\n    .. code:: python\n\n       subgraph.ndata[key] = LazyFeature(k)\n\n    Assuming that ``graph`` is the original graph, DGL's dataloader will perform\n\n    .. code:: python\n\n       subgraph.ndata[key] = graph.ndata[k][subgraph.ndata[dgl.NID]]\n\n    DGL dataloader performs similar replacement for ``edata``.\n    For heterogeneous graphs, the replacement is:\n\n    .. code:: python\n\n       subgraph.nodes[ntype].data[key] = graph.nodes[ntype].data[k][\n           subgraph.nodes[ntype].data[dgl.NID]]\n\n    For MFGs' ``srcdata`` (and similarly ``dstdata``), the replacement is\n\n    .. code:: python\n\n       mfg.srcdata[key] = graph.ndata[k][mfg.srcdata[dgl.NID]]\n\n    Parameters\n    ----------\n    name : str\n        The name of the data in the original graph.\n    id_ : Tensor, optional\n        The ID tensor.\n    \"\"\"\n\n    __slots__ = [\"name\", \"id_\"]\n\n    def __init__(self, name=None, id_=None):\n        self.name = name\n        self.id_ = id_\n\n    def to(\n        self, *args, **kwargs\n    ):  # pylint: disable=invalid-name, unused-argument\n        \"\"\"No-op.  For compatibility of :meth:`Frame.to` method.\"\"\"\n        return self\n\n    @property\n    def data(self):\n        \"\"\"No-op.  For compatibility of :meth:`Frame.__repr__` method.\"\"\"\n        return self\n\n    def pin_memory_(self):\n        \"\"\"No-op.  For compatibility of :meth:`Frame.pin_memory_` method.\"\"\"\n\n    def unpin_memory_(self):\n        \"\"\"No-op.  
For compatibility of :meth:`Frame.unpin_memory_` method.\"\"\"\n\n    def record_stream(self, stream):\n        \"\"\"No-op.  For compatibility of :meth:`Frame.record_stream` method.\"\"\"\n\n\nclass Scheme(namedtuple(\"Scheme\", [\"shape\", \"dtype\"])):\n    \"\"\"The column scheme.\n\n    Parameters\n    ----------\n    shape : tuple of int\n        The feature shape.\n    dtype : backend-specific type object\n        The feature data type.\n    \"\"\"\n\n    # Pickling torch dtypes could be problematic; this is a workaround.\n    # I also have to create data_type_dict and reverse_data_type_dict\n    # attribute just for this bug.\n    # I raised an issue in PyTorch bug tracker:\n    # https://github.com/pytorch/pytorch/issues/14057\n    def __reduce__(self):\n        state = (self.shape, F.reverse_data_type_dict[self.dtype])\n        return self._reconstruct_scheme, state\n\n    @classmethod\n    def _reconstruct_scheme(cls, shape, dtype_str):\n        dtype = F.data_type_dict[dtype_str]\n        return cls(shape, dtype)\n\n\ndef infer_scheme(tensor):\n    \"\"\"Infer column scheme from the given tensor data.\n\n    Parameters\n    ----------\n    tensor : Tensor\n        The tensor data.\n\n    Returns\n    -------\n    Scheme\n        The column scheme.\n    \"\"\"\n    return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor))\n\n\nclass Column(TensorStorage):\n    \"\"\"A column is a compact store of features of multiple nodes/edges.\n\n    It batches all the feature tensors together along the first dimension\n    as one dense tensor.\n\n    The column can optionally have an index tensor I.\n    In this case, the i^th feature is stored in ``storage[index[i]]``.\n    The column class implements a Copy-On-Read semantics -- the index\n    select operation happens upon the first read of the feature data.\n    This is useful when one extracts a subset of the feature data\n    but wishes the actual index select happens on-demand.\n\n    Parameters\n    
----------\n    storage : Tensor\n        The feature data storage.\n    scheme : Scheme, optional\n        The scheme of the column. Will be inferred if not provided.\n    index : Tensor, optional\n        The row index to the feature data storage. None means an\n        identity mapping.\n\n    Attributes\n    ----------\n    storage : Tensor\n        The storage tensor. The storage tensor may not be the actual data\n        tensor of this column when the index tensor is not None.\n        This typically happens when the column is extracted from another\n        column using the `subcolumn` method.\n\n        It can also be None, which may only happen when transmitting a\n        not-yet-materialized subcolumn from a subprocess to the main process.\n        In this case, the main process should already maintain the content of\n        the storage, and is responsible for restoring the subcolumn's storage pointer.\n    data : Tensor\n        The actual data tensor of this column.\n    scheme : Scheme\n        The scheme of the column.\n    index : Tensor\n        Index tensor\n    \"\"\"\n\n    def __init__(self, storage, *args, **kwargs):\n        super().__init__(storage)\n        self._init(*args, **kwargs)\n\n    def __len__(self):\n        \"\"\"The number of features (number of rows) in this column.\"\"\"\n        if self.index is None:\n            return F.shape(self.storage)[0]\n        else:\n            return len(self.index)\n\n    @property\n    def shape(self):\n        \"\"\"Return the scheme shape (feature shape) of this column.\"\"\"\n        return self.scheme.shape\n\n    @property\n    def data(self):\n        \"\"\"Return the feature data. 
Perform index selecting if needed.\"\"\"\n        if self.index is not None:\n            if isinstance(self.index, _LazyIndex):\n                self.index = self.index.flatten()\n\n            storage_ctx = F.context(self.storage)\n            index_ctx = F.context(self.index)\n            # If under the special case where the storage is pinned and the index is on\n            # CUDA, directly call UVA slicing (even if they aree not in the same context).\n            if (\n                storage_ctx != index_ctx\n                and storage_ctx == F.cpu()\n                and F.is_pinned(self.storage)\n            ):\n                self.storage = gather_pinned_tensor_rows(\n                    self.storage, self.index\n                )\n            else:\n                # If index and storage is not in the same context,\n                # copy index to the same context of storage.\n                # Copy index is usually cheaper than copy data\n                if storage_ctx != index_ctx:\n                    kwargs = {}\n                    if self.device is not None:\n                        kwargs = self.device[1]\n                    self.index = F.copy_to(self.index, storage_ctx, **kwargs)\n                self.storage = F.gather_row(self.storage, self.index)\n            self.index = None\n\n        # move data to the right device\n        if self.device is not None:\n            self.storage = F.copy_to(\n                self.storage, self.device[0], **self.device[1]\n            )\n            self.device = None\n\n        # convert data to the right type\n        if self.deferred_dtype is not None:\n            self.storage = F.astype(self.storage, self.deferred_dtype)\n            self.deferred_dtype = None\n        return self.storage\n\n    @data.setter\n    def data(self, val):\n        \"\"\"Update the column data.\"\"\"\n        self.index = None\n        self.device = None\n        self.deferred_dtype = None\n        self.storage = val\n      
  self._data_nd = None  # should unpin data if it was pinned.\n        self.pinned_by_dgl = False\n\n    def to(self, device, **kwargs):  # pylint: disable=invalid-name\n        \"\"\"Return a new column with columns copy to the targeted device (cpu/gpu).\n\n        Parameters\n        ----------\n        device : Framework-specific device context object\n            The context to move data to.\n        kwargs : Key-word arguments.\n            Key-word arguments fed to the framework copy function.\n\n        Returns\n        -------\n        Column\n            A new column\n        \"\"\"\n        col = self.clone()\n        col.device = (device, kwargs)\n        return col\n\n    @property\n    def dtype(self):\n        \"\"\"Return the effective data type of this Column\"\"\"\n        if self.deferred_dtype is not None:\n            return self.deferred_dtype\n        return self.storage.dtype\n\n    def astype(self, new_dtype):\n        \"\"\"Return a new column such that when its data is requested,\n        it will be converted to new_dtype.\n\n        Parameters\n        ----------\n        new_dtype : Framework-specific type object\n            The type to convert the data to.\n\n        Returns\n        -------\n        Column\n            A new column\n        \"\"\"\n        col = self.clone()\n        if col.dtype != new_dtype:\n            # If there is already a pending conversion, ensure that the pending\n            # conversion and transfer/sampling are done before this new conversion.\n            if col.deferred_dtype is not None:\n                _ = col.data\n\n            if (col.device is None) and (col.index is None):\n                # Do the conversion immediately if no device transfer or index\n                # sampling is pending.  
The assumption is that this is most\n                # likely to be the desired behaviour, such as converting an\n                # entire graph's feature data to float16 (half) before transfer\n                # to device when training, or converting back to float32 (float)\n                # after fetching the data to a device.\n                col.storage = F.astype(col.storage, new_dtype)\n            else:\n                # Defer the conversion if there is a pending transfer or sampling.\n                # This is so that feature data that never gets accessed on the\n                # device never needs to be transferred or sampled or converted.\n                col.deferred_dtype = new_dtype\n        return col\n\n    def __getitem__(self, rowids):\n        \"\"\"Return the feature data given the rowids.\n\n        The operation triggers index selection.\n\n        Parameters\n        ----------\n        rowids : Tensor\n            Row ID tensor.\n\n        Returns\n        -------\n        Tensor\n            The feature data\n        \"\"\"\n        return F.gather_row(self.data, rowids)\n\n    def __setitem__(self, rowids, feats):\n        \"\"\"Update the feature data given the index.\n\n        The update is performed out-placely so it can be used in autograd mode.\n        The operation triggers index selection.\n\n        Parameters\n        ----------\n        rowids : Tensor\n            Row IDs.\n        feats : Tensor\n            New features.\n        \"\"\"\n        self.update(rowids, feats)\n\n    def update(self, rowids, feats):\n        \"\"\"Update the feature data given the index.\n\n        Parameters\n        ----------\n        rowids : Tensor\n            Row IDs.\n        feats : Tensor\n            New features.\n        \"\"\"\n        feat_scheme = infer_scheme(feats)\n        if feat_scheme != self.scheme:\n            raise DGLError(\n                \"Cannot update column of scheme %s using feature of scheme %s.\"\n            
    % (feat_scheme, self.scheme)\n            )\n        self.data = F.scatter_row(self.data, rowids, feats)\n\n    def extend(self, feats, feat_scheme=None):\n        \"\"\"Extend the feature data.\n\n        The operation triggers index selection.\n\n        Parameters\n        ----------\n        feats : Tensor\n            The new features.\n        feat_scheme : Scheme, optional\n            The scheme\n        \"\"\"\n        if feat_scheme is None:\n            feat_scheme = infer_scheme(feats)\n\n        if feat_scheme != self.scheme:\n            raise DGLError(\n                \"Cannot update column of scheme %s using feature of scheme %s.\"\n                % (feat_scheme, self.scheme)\n            )\n\n        self.data = F.cat([self.data, feats], dim=0)\n\n    def clone(self):\n        \"\"\"Return a shallow copy of this column.\"\"\"\n        return Column(\n            self.storage,\n            self.scheme,\n            self.index,\n            self.device,\n            self.deferred_dtype,\n        )\n\n    def deepclone(self):\n        \"\"\"Return a deepcopy of this column.\n\n        The operation triggers index selection.\n        \"\"\"\n        return Column(F.clone(self.data), copy.deepcopy(self.scheme))\n\n    def subcolumn(self, rowids):\n        \"\"\"Return a subcolumn.\n\n        The resulting column will share the same storage as this column so this operation\n        is quite efficient. If the current column is also a sub-column (i.e.,\n        the index tensor is not None), the current index tensor will be sliced\n        by 'rowids', if they are on the same context. 
Otherwise, both index\n        tensors are saved, and only applied when the data is accessed.\n\n        Parameters\n        ----------\n        rowids : Tensor\n            Row IDs.\n\n        Returns\n        -------\n        Column\n            Sub-column\n        \"\"\"\n        if self.index is None:\n            return Column(\n                self.storage,\n                self.scheme,\n                rowids,\n                self.device,\n                self.deferred_dtype,\n            )\n        else:\n            index = self.index\n            if not isinstance(index, _LazyIndex):\n                index = _LazyIndex(self.index)\n            index = index.slice(rowids)\n            return Column(\n                self.storage,\n                self.scheme,\n                index,\n                self.device,\n                self.deferred_dtype,\n            )\n\n    @staticmethod\n    def create(data):\n        \"\"\"Create a new column using the given data.\"\"\"\n        if isinstance(data, Column):\n            return data.clone()\n        else:\n            return Column(data)\n\n    def __repr__(self):\n        return repr(self.data)\n\n    def __getstate__(self):\n        if self.storage is not None:\n            # flush any deferred operations\n            _ = self.data\n        state = self.__dict__.copy()\n        # data pinning does not get serialized, so we need to remove that from\n        # the state\n        state[\"_data_nd\"] = None\n        state[\"pinned_by_dgl\"] = False\n        return state\n\n    def __setstate__(self, state):\n        index = None\n        device = None\n        if \"storage\" in state and state[\"storage\"] is not None:\n            assert \"index\" not in state or state[\"index\"] is None\n            assert \"device\" not in state or state[\"device\"] is None\n        else:\n            # we may have a column with only index information, and that is\n            # valid\n            index = None if \"index\" 
not in state else state[\"index\"]\n            device = None if \"device\" not in state else state[\"device\"]\n        assert \"deferred_dtype\" not in state or state[\"deferred_dtype\"] is None\n        assert \"pinned_by_dgl\" not in state or state[\"pinned_by_dgl\"] is False\n        assert \"_data_nd\" not in state or state[\"_data_nd\"] is None\n\n        self.__dict__ = state\n        # properly initialize this object\n        self._init(\n            self.scheme if hasattr(self, \"scheme\") else None,\n            index=index,\n            device=device,\n        )\n\n    def _init(self, scheme=None, index=None, device=None, deferred_dtype=None):\n        self.scheme = scheme if scheme else infer_scheme(self.storage)\n        self.index = index\n        self.device = device\n        self.deferred_dtype = deferred_dtype\n        self.pinned_by_dgl = False\n        self._data_nd = None\n\n    def __copy__(self):\n        return self.clone()\n\n    def fetch(self, indices, device, pin_memory=False, **kwargs):\n        _ = self.data  # materialize in case of lazy slicing & data transfer\n        return super().fetch(indices, device, pin_memory=pin_memory, **kwargs)\n\n    def pin_memory_(self):\n        \"\"\"Pin the storage into page-locked memory.\n\n        Does nothing if the storage is already pinned.\n        \"\"\"\n        if not self.pinned_by_dgl and not F.is_pinned(self.data):\n            self._data_nd = pin_memory_inplace(self.data)\n            self.pinned_by_dgl = True\n\n    def unpin_memory_(self):\n        \"\"\"Unpin the storage pinned by ``pin_memory_`` method.\n\n        Does nothing if the storage is not pinned by ``pin_memory_`` method, even if\n        it is actually in page-locked memory.\n        \"\"\"\n        if self.pinned_by_dgl:\n            self._data_nd.unpin_memory_()\n            self._data_nd = None\n            self.pinned_by_dgl = False\n\n    def record_stream(self, stream):\n        \"\"\"Record stream that is using the 
storage.\n        Raises DGLError if the backend is not PyTorch.\n\n        Parameters\n        ----------\n        stream : torch.cuda.Stream.\n        \"\"\"\n        if F.get_preferred_backend() != \"pytorch\":\n            raise DGLError(\"record_stream only supports the PyTorch backend.\")\n        if self.index is not None and (\n            isinstance(self.index, _LazyIndex)\n            or F.context(self.index) != F.cpu()\n        ):\n            self.index.record_stream(stream)\n        if F.context(self.storage) != F.cpu():\n            self.storage.record_stream(stream)\n\n\nclass Frame(MutableMapping):\n    \"\"\"The columnar storage for node/edge features.\n\n    The frame is a dictionary from feature names to feature columns.\n    All columns should have the same number of rows (i.e. the same first dimension).\n\n    Parameters\n    ----------\n    data : dict-like, optional\n        The frame data in dictionary. If the provided data is another frame,\n        this frame will NOT share columns with the given frame. So any out-place\n        update on one will not reflect to the other.\n    num_rows : int, optional\n        The number of rows in this frame. 
If ``data`` is provided and is not empty,\n        ``num_rows`` will be ignored and inferred from the given data.\n    \"\"\"\n\n    def __init__(self, data=None, num_rows=None):\n        if data is None:\n            self._columns = dict()\n            self._num_rows = 0 if num_rows is None else num_rows\n        else:\n            assert not isinstance(data, Frame)  # sanity check for code refactor\n            # Note that we always create a new column for the given data.\n            # This avoids two frames accidentally sharing the same column.\n            self._columns = {\n                k: v if isinstance(v, LazyFeature) else Column.create(v)\n                for k, v in data.items()\n            }\n            self._num_rows = num_rows\n            # infer num_rows & sanity check\n            for name, col in self._columns.items():\n                if isinstance(col, LazyFeature):\n                    continue\n                if self._num_rows is None:\n                    self._num_rows = len(col)\n                elif len(col) != self._num_rows:\n                    raise DGLError(\n                        \"Expected all columns to have same # rows (%d), \"\n                        \"got %d on %r.\" % (self._num_rows, len(col), name)\n                    )\n\n        # Initializer for empty values. 
Initializer is a callable.\n        # If is none, then a warning will be raised\n        # in the first call and zero initializer will be used later.\n        self._initializers = {}  # per-column initializers\n        self._default_initializer = None\n\n    def _set_zero_default_initializer(self):\n        \"\"\"Set the default initializer to be zero initializer.\"\"\"\n        self._default_initializer = zero_initializer\n\n    def get_initializer(self, column=None):\n        \"\"\"Get the initializer for empty values for the given column.\n\n        Parameters\n        ----------\n        column : str\n            The column\n\n        Returns\n        -------\n        callable\n            The initializer\n        \"\"\"\n        return self._initializers.get(column, self._default_initializer)\n\n    def set_initializer(self, initializer, column=None):\n        \"\"\"Set the initializer for empty values, for a given column or all future\n        columns.\n\n        Initializer is a callable that returns a tensor given the shape and data type.\n\n        Parameters\n        ----------\n        initializer : callable\n            The initializer.\n        column : str, optional\n            The column name\n        \"\"\"\n        if column is None:\n            self._default_initializer = initializer\n        else:\n            self._initializers[column] = initializer\n\n    @property\n    def schemes(self):\n        \"\"\"Return a dictionary of column name to column schemes.\"\"\"\n        return {k: col.scheme for k, col in self._columns.items()}\n\n    @property\n    def num_columns(self):\n        \"\"\"Return the number of columns in this frame.\"\"\"\n        return len(self._columns)\n\n    @property\n    def num_rows(self):\n        \"\"\"Return the number of rows in this frame.\"\"\"\n        return self._num_rows\n\n    def __contains__(self, name):\n        \"\"\"Return true if the given column name exists.\"\"\"\n        return name in 
self._columns\n\n    def __getitem__(self, name):\n        \"\"\"Return the column of the given name.\n\n        Parameters\n        ----------\n        name : str\n            The column name.\n\n        Returns\n        -------\n        Tensor\n            Column data.\n        \"\"\"\n        return self._columns[name].data\n\n    def __setitem__(self, name, data):\n        \"\"\"Update the whole column.\n\n        Parameters\n        ----------\n        name : str\n            The column name.\n        col : Column or data convertible to Column\n            The column data.\n        \"\"\"\n        self.update_column(name, data)\n\n    def __delitem__(self, name):\n        \"\"\"Delete the whole column.\n\n        Parameters\n        ----------\n        name : str\n            The column name.\n        \"\"\"\n        del self._columns[name]\n\n    def add_column(self, name, scheme, ctx):\n        \"\"\"Add a new column to the frame.\n\n        The frame will be initialized by the initializer.\n\n        Parameters\n        ----------\n        name : str\n            The column name.\n        scheme : Scheme\n            The column scheme.\n        ctx : DGLContext\n            The column context.\n        \"\"\"\n        if name in self:\n            dgl_warning(\n                'Column \"%s\" already exists. 
Ignore adding this column again.'\n                % name\n            )\n            return\n\n        if self.get_initializer(name) is None:\n            self._set_zero_default_initializer()\n        initializer = self.get_initializer(name)\n        init_data = initializer(\n            (self.num_rows,) + scheme.shape,\n            scheme.dtype,\n            ctx,\n            slice(0, self.num_rows),\n        )\n        self._columns[name] = Column(init_data, scheme)\n\n    def add_rows(self, num_rows):\n        \"\"\"Add blank rows to this frame.\n\n        For existing fields, the rows will be extended according to their\n        initializers.\n\n        Parameters\n        ----------\n        num_rows : int\n            The number of new rows\n        \"\"\"\n        feat_placeholders = {}\n        for key, col in self._columns.items():\n            scheme = col.scheme\n            ctx = F.context(col.data)\n            if self.get_initializer(key) is None:\n                self._set_zero_default_initializer()\n            initializer = self.get_initializer(key)\n            new_data = initializer(\n                (num_rows,) + scheme.shape,\n                scheme.dtype,\n                ctx,\n                slice(self._num_rows, self._num_rows + num_rows),\n            )\n            feat_placeholders[key] = new_data\n        self._append(Frame(feat_placeholders))\n        self._num_rows += num_rows\n\n    def update_column(self, name, data):\n        \"\"\"Add or replace the column with the given name and data.\n\n        Parameters\n        ----------\n        name : str\n            The column name.\n        data : Column or data convertible to Column\n            The column data.\n        \"\"\"\n        if isinstance(data, LazyFeature):\n            self._columns[name] = data\n            return\n\n        col = Column.create(data)\n        if len(col) != self.num_rows:\n            raise DGLError(\n                \"Expected data to have %d rows, got 
%d.\"\n                % (self.num_rows, len(col))\n            )\n        self._columns[name] = col\n\n    def update_row(self, rowids, data):\n        \"\"\"Update the feature data of the given rows.\n\n        If the data contains new keys (new columns) that do not exist in\n        this frame, add a new column.\n\n        The ``rowids`` shall not contain duplicates. Otherwise, the behavior\n        is undefined.\n\n        Parameters\n        ----------\n        rowids : Tensor\n            Row Ids.\n        data : dict[str, Tensor]\n            Row data.\n        \"\"\"\n        for key, val in data.items():\n            if key not in self:\n                scheme = infer_scheme(val)\n                ctx = F.context(val)\n                self.add_column(key, scheme, ctx)\n        for key, val in data.items():\n            self._columns[key].update(rowids, val)\n\n    def _append(self, other):\n        \"\"\"Append ``other`` frame to ``self`` frame.\"\"\"\n        # pad columns that are not provided in the other frame with initial values\n        for key, col in self._columns.items():\n            if key in other:\n                continue\n            scheme = col.scheme\n            ctx = F.context(col.data)\n            if self.get_initializer(key) is None:\n                self._set_zero_default_initializer()\n            initializer = self.get_initializer(key)\n            new_data = initializer(\n                (other.num_rows,) + scheme.shape,\n                scheme.dtype,\n                ctx,\n                slice(self._num_rows, self._num_rows + other.num_rows),\n            )\n            other[key] = new_data\n        # append other to self\n        for key, col in other._columns.items():\n            if key not in self._columns:\n                # the column does not exist; init a new column\n                self.add_column(key, col.scheme, F.context(col.data))\n            self._columns[key].extend(col.data, col.scheme)\n\n    def append(self, 
other):\n        \"\"\"Append another frame's data into this frame.\n\n        If the current frame is empty, it will just use the columns of the\n        given frame. Otherwise, the given data should contain all the\n        column keys of this frame.\n\n        Parameters\n        ----------\n        other : Frame or dict-like\n            The frame data to be appended.\n        \"\"\"\n        if not isinstance(other, Frame):\n            other = Frame(other)\n        self._append(other)\n        self._num_rows += other.num_rows\n\n    def clear(self):\n        \"\"\"Clear this frame. Remove all the columns.\"\"\"\n        self._columns = {}\n        self._num_rows = 0\n\n    def __iter__(self):\n        \"\"\"Return an iterator of columns.\"\"\"\n        return iter(self._columns)\n\n    def __len__(self):\n        \"\"\"Return the number of columns.\"\"\"\n        return self.num_columns\n\n    def keys(self):\n        \"\"\"Return the keys.\"\"\"\n        return self._columns.keys()\n\n    def values(self):\n        \"\"\"Return the values.\"\"\"\n        return self._columns.values()\n\n    def clone(self):\n        \"\"\"Return a clone of this frame.\n\n        The clone frame does not share the underlying storage with this frame,\n        i.e., adding or removing columns will not be visible to each other. However,\n        they still share the tensor contents so any mutable operation on the column\n        tensor are visible to each other. Hence, the function does not allocate extra\n        tensor memory. 
Use :func:`~dgl.Frame.deepclone` for cloning\n        a frame that does not share any data.\n\n        Returns\n        -------\n        Frame\n            A cloned frame.\n        \"\"\"\n        newframe = Frame(self._columns, self._num_rows)\n        newframe._initializers = self._initializers\n        newframe._default_initializer = self._default_initializer\n        return newframe\n\n    def deepclone(self):\n        \"\"\"Return a deep clone of this frame.\n\n        The clone frame has a copy of this frame and any modification to the clone frame\n        is not visible to this frame. The function allocates new tensors and copies the contents\n        from this frame. Use :func:`~dgl.Frame.clone` for cloning a frame that does not\n        allocate extra tensor memory.\n\n        Returns\n        -------\n        Frame\n            A deep-cloned frame.\n        \"\"\"\n        newframe = Frame(\n            {k: col.deepclone() for k, col in self._columns.items()},\n            self._num_rows,\n        )\n        newframe._initializers = self._initializers\n        newframe._default_initializer = self._default_initializer\n        return newframe\n\n    def subframe(self, rowids):\n        \"\"\"Return a new frame whose columns are subcolumns of this frame.\n\n        The given row IDs should be within range [0, self.num_rows), and allow\n        duplicate IDs.\n\n        Parameters\n        ----------\n        rowids : Tensor\n            Row IDs\n\n        Returns\n        -------\n        Frame\n            A new subframe.\n        \"\"\"\n        subcols = {k: col.subcolumn(rowids) for k, col in self._columns.items()}\n        subf = Frame(subcols, len(rowids))\n        subf._initializers = self._initializers\n        subf._default_initializer = self._default_initializer\n        return subf\n\n    def to(self, device, **kwargs):  # pylint: disable=invalid-name\n        \"\"\"Return a new frame with columns copy to the targeted device (cpu/gpu).\n\n        
Parameters\n        ----------\n        device : Framework-specific device context object\n            The context to move data to.\n        kwargs : Key-word arguments.\n            Key-word arguments fed to the framework copy function.\n\n        Returns\n        -------\n        Frame\n            A new frame\n        \"\"\"\n        newframe = self.clone()\n        new_columns = {\n            key: col.to(device, **kwargs)\n            for key, col in newframe._columns.items()\n        }\n        newframe._columns = new_columns\n        return newframe\n\n    def __repr__(self):\n        return repr(dict(self))\n\n    def pin_memory_(self):\n        \"\"\"Registers the data of every column into pinned memory, materializing them if\n        necessary.\"\"\"\n        for column in self._columns.values():\n            column.pin_memory_()\n\n    def unpin_memory_(self):\n        \"\"\"Unregisters the data of every column from pinned memory, materializing them\n        if necessary.\"\"\"\n        for column in self._columns.values():\n            column.unpin_memory_()\n\n    def record_stream(self, stream):\n        \"\"\"Record stream that is using the data of every column, materializing them\n        if necessary.\"\"\"\n        for column in self._columns.values():\n            column.record_stream(stream)\n\n    def _astype_float(self, new_type):\n        assert new_type in [\n            F.float64,\n            F.float32,\n            F.float16,\n            F.bfloat16,\n        ], \"'new_type' must be floating-point type: %s\" % str(new_type)\n        newframe = self.clone()\n        new_columns = {}\n        for name, column in self._columns.items():\n            dtype = column.dtype\n            if dtype != new_type and dtype in [\n                F.float64,\n                F.float32,\n                F.float16,\n                F.bfloat16,\n            ]:\n                new_columns[name] = column.astype(new_type)\n            else:\n                
new_columns[name] = column\n        newframe._columns = new_columns\n        return newframe\n\n    def bfloat16(self):\n        \"\"\"Return a new frame with all floating-point columns converted\n        to bfloat16\"\"\"\n        return self._astype_float(F.bfloat16)\n\n    def half(self):\n        \"\"\"Return a new frame with all floating-point columns converted\n        to half-precision (float16)\"\"\"\n        return self._astype_float(F.float16)\n\n    def float(self):\n        \"\"\"Return a new frame with all floating-point columns converted\n        to single-precision (float32)\"\"\"\n        return self._astype_float(F.float32)\n\n    def double(self):\n        \"\"\"Return a new frame with all floating-point columns converted\n        to double-precision (float64)\"\"\"\n        return self._astype_float(F.float64)\n"
  },
  {
    "path": "python/dgl/function/__init__.py",
    "content": "\"\"\"DGL builtin functors\"\"\"\n# pylint: disable=redefined-builtin\nfrom __future__ import absolute_import\n\nfrom .base import *\nfrom .message import *\nfrom .reducer import *\n"
  },
  {
    "path": "python/dgl/function/base.py",
    "content": "\"\"\"Built-in function base class\"\"\"\nfrom __future__ import absolute_import\n\n__all__ = [\"BuiltinFunction\", \"TargetCode\"]\n\n\nclass TargetCode(object):\n    \"\"\"Code for target\n\n    Note: must be consistent with the target code definition in C++ side:\n          src/kernel/binary_reduce_common.h\n    \"\"\"\n\n    SRC = 0\n    DST = 1\n    EDGE = 2\n\n    CODE2STR = {\n        0: \"u\",\n        1: \"v\",\n        2: \"e\",\n    }\n\n\nclass BuiltinFunction(object):\n    \"\"\"Base builtin function class.\"\"\"\n\n    @property\n    def name(self):\n        \"\"\"Return the name of this builtin function.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "python/dgl/function/message.py",
    "content": "\"\"\"Built-in message function.\"\"\"\nfrom __future__ import absolute_import\n\nimport sys\nfrom itertools import product\n\nfrom .base import BuiltinFunction, TargetCode\n\n\n__all__ = [\"copy_u\", \"copy_e\", \"BinaryMessageFunction\", \"CopyMessageFunction\"]\n\n\nclass MessageFunction(BuiltinFunction):\n    \"\"\"Base builtin message function class.\"\"\"\n\n    @property\n    def name(self):\n        \"\"\"Return the name of this builtin function.\"\"\"\n        raise NotImplementedError\n\n\nclass BinaryMessageFunction(MessageFunction):\n    \"\"\"Class for the lhs_op_rhs builtin message function.\n\n    See Also\n    --------\n    u_mul_e\n    \"\"\"\n\n    def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field):\n        self.binary_op = binary_op\n        self.lhs = lhs\n        self.rhs = rhs\n        self.lhs_field = lhs_field\n        self.rhs_field = rhs_field\n        self.out_field = out_field\n\n    @property\n    def name(self):\n        lhs = TargetCode.CODE2STR[self.lhs]\n        rhs = TargetCode.CODE2STR[self.rhs]\n        return \"{}_{}_{}\".format(lhs, self.binary_op, rhs)\n\n\nclass CopyMessageFunction(MessageFunction):\n    \"\"\"Class for the copy builtin message function.\n\n    See Also\n    --------\n    copy_u\n    \"\"\"\n\n    def __init__(self, target, in_field, out_field):\n        self.target = target\n        self.in_field = in_field\n        self.out_field = out_field\n\n    @property\n    def name(self):\n        return \"copy_{}\".format(TargetCode.CODE2STR[self.target])\n\n\ndef copy_u(u, out):\n    \"\"\"Builtin message function that computes message using source node\n    feature.\n\n    Parameters\n    ----------\n    u : str\n        The source feature field.\n    out : str\n        The output message field.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> message_func = dgl.function.copy_u('h', 'm')\n\n    The above example is equivalent to the following user defined 
function:\n\n    >>> def message_func(edges):\n    >>>     return {'m': edges.src['h']}\n    \"\"\"\n    return CopyMessageFunction(TargetCode.SRC, u, out)\n\n\ndef copy_e(e, out):\n    \"\"\"Builtin message function that computes message using edge feature.\n\n    Parameters\n    ----------\n    e : str\n        The edge feature field.\n    out : str\n        The output message field.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> message_func = dgl.function.copy_e('h', 'm')\n\n    The above example is equivalent to the following user defined function:\n\n    >>> def message_func(edges):\n    >>>     return {'m': edges.data['h']}\n    \"\"\"\n    return CopyMessageFunction(TargetCode.EDGE, e, out)\n\n\n###############################################################################\n# Generate all following  builtin message functions:\n# element-wise message functions:\n# u_add_v, u_sub_v, u_mul_v, u_div_v\n# u_add_e, u_sub_e, u_mul_e, u_div_e\n# v_add_u, v_sub_u, v_mul_u, v_div_u\n# v_add_e, v_sub_e, v_mul_e, v_div_e\n# e_add_u, e_sub_u, e_mul_u, e_div_u\n# e_add_v, e_sub_v, e_mul_v, e_div_v\n#\n# dot message functions:\n# u_dot_v, u_dot_e, v_dot_e\n# v_dot_u, e_dot_u, e_dot_v\n\n_TARGET_MAP = {\n    \"u\": TargetCode.SRC,\n    \"v\": TargetCode.DST,\n    \"e\": TargetCode.EDGE,\n}\n\n\ndef _gen_message_builtin(lhs, rhs, binary_op):\n    name = \"{}_{}_{}\".format(lhs, binary_op, rhs)\n    docstring = \"\"\"Builtin message function that computes a message on an edge\n    by performing element-wise {} between features of {} and {}\n    if the features have the same shape; otherwise, it first broadcasts the features\n    to a new shape and performs the element-wise operation.\n\n    Broadcasting follows NumPy semantics. 
Please see\n    https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html\n    for more details about the NumPy broadcasting semantics.\n\n    Parameters\n    ----------\n    lhs_field : str\n        The feature field of {}.\n    rhs_field : str\n        The feature field of {}.\n    out : str\n        The output message field.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> message_func = dgl.function.{}('h', 'h', 'm')\n    \"\"\".format(\n        binary_op,\n        TargetCode.CODE2STR[_TARGET_MAP[lhs]],\n        TargetCode.CODE2STR[_TARGET_MAP[rhs]],\n        TargetCode.CODE2STR[_TARGET_MAP[lhs]],\n        TargetCode.CODE2STR[_TARGET_MAP[rhs]],\n        name,\n    )\n\n    def func(lhs_field, rhs_field, out):\n        return BinaryMessageFunction(\n            binary_op,\n            _TARGET_MAP[lhs],\n            _TARGET_MAP[rhs],\n            lhs_field,\n            rhs_field,\n            out,\n        )\n\n    func.__name__ = name\n    func.__doc__ = docstring\n    return func\n\n\ndef _register_builtin_message_func():\n    \"\"\"Register builtin message functions\"\"\"\n    target = [\"u\", \"v\", \"e\"]\n    for lhs, rhs in product(target, target):\n        if lhs != rhs:\n            for binary_op in [\"add\", \"sub\", \"mul\", \"div\", \"dot\"]:\n                func = _gen_message_builtin(lhs, rhs, binary_op)\n                setattr(sys.modules[__name__], func.__name__, func)\n                __all__.append(func.__name__)\n\n\n_register_builtin_message_func()\n"
  },
  {
    "path": "python/dgl/function/reducer.py",
    "content": "\"\"\"Built-in reducer function.\"\"\"\n# pylint: disable=redefined-builtin\nfrom __future__ import absolute_import\n\nimport sys\n\nfrom .base import BuiltinFunction\n\n\nclass ReduceFunction(BuiltinFunction):\n    \"\"\"Base builtin reduce function class.\"\"\"\n\n    @property\n    def name(self):\n        \"\"\"Return the name of this builtin function.\"\"\"\n        raise NotImplementedError\n\n\nclass SimpleReduceFunction(ReduceFunction):\n    \"\"\"Builtin reduce function that aggregates a single field into another\n    single field.\"\"\"\n\n    def __init__(self, name, msg_field, out_field):\n        self._name = name\n        self.msg_field = msg_field\n        self.out_field = out_field\n\n    @property\n    def name(self):\n        return self._name\n\n\n###############################################################################\n# Generate all following reducer functions:\n# sum, max, min, mean, prod\n\n\ndef _gen_reduce_builtin(reducer):\n    docstring = \"\"\"Builtin reduce function that aggregates messages by {0}.\n\n    Parameters\n    ----------\n    msg : str\n        The message field.\n    out : str\n        The output node feature field.\n    Examples\n    --------\n    >>> import dgl\n    >>> reduce_func = dgl.function.{0}('m', 'h')\n\n    The above example is equivalent to the following user defined function\n    (if using PyTorch):\n\n    >>> import torch\n    >>> def reduce_func(nodes):\n    >>>     return {{'h': torch.{0}(nodes.mailbox['m'], dim=1)}}\n    \"\"\".format(\n        reducer\n    )\n\n    def func(msg, out):\n        return SimpleReduceFunction(reducer, msg, out)\n\n    func.__name__ = str(reducer)\n    func.__qualname__ = str(reducer)\n    func.__doc__ = docstring\n    return func\n\n\n__all__ = []\n\n\ndef _register_builtin_reduce_func():\n    \"\"\"Register builtin reduce functions\"\"\"\n    for reduce_op in [\"max\", \"min\", \"sum\", \"mean\"]:\n        builtin = _gen_reduce_builtin(reduce_op)\n       
 setattr(sys.modules[__name__], reduce_op, builtin)\n        __all__.append(reduce_op)\n\n\n_register_builtin_reduce_func()\n"
  },
  {
    "path": "python/dgl/generators.py",
    "content": "\"\"\"Module for various graph generator functions.\"\"\"\n\nfrom . import backend as F, convert, random\n\n__all__ = [\"rand_graph\", \"rand_bipartite\"]\n\n\ndef rand_graph(num_nodes, num_edges, idtype=F.int64, device=F.cpu()):\n    \"\"\"Generate a random graph of the given number of nodes/edges and return.\n\n    It uniformly chooses ``num_edges`` from all possible node pairs and form a graph.\n    The random choice is without replacement, which means there will be no multi-edge\n    in the resulting graph.\n\n    To control the randomness, set the random seed via :func:`dgl.seed`.\n\n    Parameters\n    ----------\n    num_nodes : int\n        The number of nodes\n    num_edges : int\n        The number of edges\n    idtype : int32, int64, optional\n        The data type for storing the structure-related graph information\n        such as node and edge IDs. It should be a framework-specific data type object\n        (e.g., torch.int32). By default, DGL uses int64.\n    device : Device context, optional\n        The device of the resulting graph. It should be a framework-specific device\n        object (e.g., torch.device). 
By default, DGL stores the graph on CPU.\n\n    Returns\n    -------\n    DGLGraph\n        The generated random graph.\n\n    See Also\n    --------\n    rand_bipartite\n\n    Examples\n    --------\n    >>> import dgl\n    >>> dgl.rand_graph(100, 10)\n    Graph(num_nodes=100, num_edges=10,\n          ndata_schemes={}\n          edata_schemes={})\n    \"\"\"\n    # TODO(minjie): support RNG as one of the arguments.\n    eids = random.choice(num_nodes * num_nodes, num_edges, replace=False)\n    eids = F.zerocopy_to_numpy(eids)\n    rows = F.zerocopy_from_numpy(eids // num_nodes)\n    cols = F.zerocopy_from_numpy(eids % num_nodes)\n    rows = F.copy_to(F.astype(rows, idtype), device)\n    cols = F.copy_to(F.astype(cols, idtype), device)\n    return convert.graph(\n        (rows, cols), num_nodes=num_nodes, idtype=idtype, device=device\n    )\n\n\ndef rand_bipartite(\n    utype,\n    etype,\n    vtype,\n    num_src_nodes,\n    num_dst_nodes,\n    num_edges,\n    idtype=F.int64,\n    device=F.cpu(),\n):\n    \"\"\"Generate a random uni-directional bipartite graph and return.\n\n    It uniformly chooses ``num_edges`` from all possible node pairs and form a graph.\n    The random choice is without replacement, which means there will be no multi-edge\n    in the resulting graph.\n\n    To control the randomness, set the random seed via :func:`dgl.seed`.\n\n    Parameters\n    ----------\n    utype : str, optional\n        The name of the source node type.\n    etype : str, optional\n        The name of the edge type.\n    vtype : str, optional\n        The name of the destination node type.\n    num_src_nodes : int\n        The number of source nodes.\n    num_dst_nodes : int\n        The number of destination nodes.\n    num_edges : int\n        The number of edges\n    idtype : int32, int64, optional\n        The data type for storing the structure-related graph information\n        such as node and edge IDs. 
It should be a framework-specific data type object\n        (e.g., torch.int32). By default, DGL uses int64.\n    device : Device context, optional\n        The device of the resulting graph. It should be a framework-specific device\n        object (e.g., torch.device). By default, DGL stores the graph on CPU.\n\n    Returns\n    -------\n    DGLGraph\n        The generated random bipartite graph.\n\n    See Also\n    --------\n    rand_graph\n\n    Examples\n    --------\n    >>> import dgl\n    >>> dgl.rand_bipartite('user', 'buys', 'game', 50, 100, 10)\n    Graph(num_nodes={'game': 100, 'user': 50},\n          num_edges={('user', 'buys', 'game'): 10},\n          metagraph=[('user', 'game', 'buys')])\n    \"\"\"\n    # TODO(minjie): support RNG as one of the arguments.\n    eids = random.choice(\n        num_src_nodes * num_dst_nodes, num_edges, replace=False\n    )\n    eids = F.zerocopy_to_numpy(eids)\n    rows = F.zerocopy_from_numpy(eids // num_dst_nodes)\n    cols = F.zerocopy_from_numpy(eids % num_dst_nodes)\n    rows = F.copy_to(F.astype(rows, idtype), device)\n    cols = F.copy_to(F.astype(cols, idtype), device)\n    return convert.heterograph(\n        {(utype, etype, vtype): (rows, cols)},\n        {utype: num_src_nodes, vtype: num_dst_nodes},\n        idtype=idtype,\n        device=device,\n    )\n"
  },
  {
    "path": "python/dgl/geometry/__init__.py",
    "content": "\"\"\"The ``dgl.geometry`` package contains geometry operations:\n\n* Farthest point sampling for point cloud sampling\n\n* Neighbor matching module for graclus pooling\n\n.. note::\n    This package is experimental and the interfaces may be subject\n    to changes in future releases.\n\"\"\"\nfrom .edge_coarsening import *\nfrom .fps import *\n"
  },
  {
    "path": "python/dgl/geometry/capi.py",
    "content": "\"\"\"Python interfaces to DGL farthest point sampler.\"\"\"\nimport numpy as np\n\nfrom .. import backend as F, ndarray as nd\nfrom .._ffi.base import DGLError\nfrom .._ffi.function import _init_api\n\n\ndef _farthest_point_sampler(\n    data, batch_size, sample_points, dist, start_idx, result\n):\n    r\"\"\"Farthest Point Sampler\n\n    Parameters\n    ----------\n    data : tensor\n        A tensor of shape (N, d) where N is the number of points and d is the dimension.\n    batch_size : int\n        The number of batches in the ``data``. N should be divisible by batch_size.\n    sample_points : int\n        The number of points to sample in each batch.\n    dist : tensor\n        Pre-allocated tensor of shape (N, ) for to-sample distance.\n    start_idx : tensor of int\n        Pre-allocated tensor of shape (batch_size, ) for the starting sample in each batch.\n    result : tensor of int\n        Pre-allocated tensor of shape (sample_points * batch_size, ) for the sampled index.\n\n    Returns\n    -------\n    No return value. The input variable ``result`` will be overwriten with sampled indices.\n\n    \"\"\"\n    assert F.shape(data)[0] >= sample_points * batch_size\n    assert F.shape(data)[0] % batch_size == 0\n\n    _CAPI_FarthestPointSampler(\n        F.zerocopy_to_dgl_ndarray(data),\n        batch_size,\n        sample_points,\n        F.zerocopy_to_dgl_ndarray(dist),\n        F.zerocopy_to_dgl_ndarray(start_idx),\n        F.zerocopy_to_dgl_ndarray(result),\n    )\n\n\ndef _neighbor_matching(\n    graph_idx, num_nodes, edge_weights=None, relabel_idx=True\n):\n    \"\"\"\n    Description\n    -----------\n    The neighbor matching procedure of edge coarsening used in\n    `Metis <http://cacs.usc.edu/education/cs653/Karypis-METIS-SIAMJSC98.pdf>`__\n    and\n    `Graclus <https://www.cs.utexas.edu/users/inderjit/public_papers/multilevel_pami.pdf>`__\n    for homogeneous graph coarsening. 
This procedure keeps picking an unmarked\n    vertex and matching it with one of its unmarked neighbors (that maximizes its\n    edge weight) until no match can be done.\n\n    If no edge weight is given, this procedure will randomly pick a neighbor for each\n    vertex.\n\n    The GPU implementation is based on `A GPU Algorithm for Greedy Graph Matching\n    <http://www.staff.science.uu.nl/~bisse101/Articles/match12.pdf>`__\n\n    NOTE: The input graph must be a bi-directed (undirected) graph. Call :obj:`dgl.to_bidirected`\n    if you are not sure your graph is bi-directed.\n\n    Parameters\n    ----------\n    graph_idx : HeteroGraphIndex\n        The input homogeneous graph.\n    num_nodes : int\n        The number of nodes in this homogeneous graph.\n    edge_weights : tensor, optional\n        The edge weight tensor holding non-negative scalar weight for each edge.\n        default: :obj:`None`\n    relabel_idx : bool, optional\n        If true, relabel resulting node labels to have consecutive node ids.\n        default: :obj:`True`\n\n    Returns\n    -------\n    a 1-D tensor\n        A vector with each element that indicates the cluster ID of a vertex.\n    \"\"\"\n    edge_weight_capi = nd.NULL[\"int64\"]\n    if edge_weights is not None:\n        edge_weight_capi = F.zerocopy_to_dgl_ndarray(edge_weights)\n    node_label = F.full_1d(\n        num_nodes,\n        -1,\n        getattr(F, graph_idx.dtype),\n        F.to_backend_ctx(graph_idx.ctx),\n    )\n    node_label_capi = F.zerocopy_to_dgl_ndarray_for_write(node_label)\n    _CAPI_NeighborMatching(graph_idx, edge_weight_capi, node_label_capi)\n    if F.reduce_sum(node_label < 0).item() != 0:\n        raise DGLError(\"Find unmatched node\")\n\n    # reorder node id\n    # TODO: actually we can add `return_inverse` option for `unique`\n    #       function in backend for efficiency.\n    if relabel_idx:\n        node_label_np = F.zerocopy_to_numpy(node_label)\n        _, node_label_np = np.unique(node_label_np, 
return_inverse=True)\n        return F.tensor(node_label_np)\n    else:\n        return node_label\n\n\n_init_api(\"dgl.geometry\", __name__)\n"
  },
  {
    "path": "python/dgl/geometry/edge_coarsening.py",
    "content": "\"\"\"Edge coarsening procedure used in Metis and Graclus, for pytorch\"\"\"\n# pylint: disable=no-member, invalid-name, W0613\nfrom .. import remove_self_loop\nfrom .capi import _neighbor_matching\n\n__all__ = [\"neighbor_matching\"]\n\n\ndef neighbor_matching(graph, e_weights=None, relabel_idx=True):\n    r\"\"\"\n    Description\n    -----------\n    The neighbor matching procedure of edge coarsening in\n    `Metis <http://cacs.usc.edu/education/cs653/Karypis-METIS-SIAMJSC98.pdf>`__\n    and\n    `Graclus <https://www.cs.utexas.edu/users/inderjit/public_papers/multilevel_pami.pdf>`__\n    for homogeneous graph coarsening. This procedure keeps picking an unmarked\n    vertex and matching it with one its unmarked neighbors (that maximizes its\n    edge weight) until no match can be done.\n\n    If no edge weight is given, this procedure will randomly pick neighbor for each\n    vertex.\n\n    The GPU implementation is based on `A GPU Algorithm for Greedy Graph Matching\n    <http://www.staff.science.uu.nl/~bisse101/Articles/match12.pdf>`__\n\n    NOTE: The input graph must be bi-directed (undirected) graph. 
Call :obj:`dgl.to_bidirected`\n          if you are not sure your graph is bi-directed.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The input homogeneous graph.\n    e_weights : torch.Tensor, optional\n        The edge weight tensor holding non-negative scalar weight for each edge.\n        default: :obj:`None`\n    relabel_idx : bool, optional\n        If true, relabel resulting node labels to have consecutive node ids.\n        default: :obj:`True`\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import torch, dgl\n    >>> from dgl.geometry import neighbor_matching\n    >>>\n    >>> g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))\n    >>> res = neighbor_matching(g)\n        tensor([0, 1, 1])\n    \"\"\"\n    assert (\n        graph.is_homogeneous\n    ), \"The graph used in graph node matching must be homogeneous\"\n    if e_weights is not None:\n        graph.edata[\"e_weights\"] = e_weights\n        graph = remove_self_loop(graph)\n        e_weights = graph.edata[\"e_weights\"]\n        graph.edata.pop(\"e_weights\")\n    else:\n        graph = remove_self_loop(graph)\n    return _neighbor_matching(\n        graph._graph, graph.num_nodes(), e_weights, relabel_idx\n    )\n"
  },
  {
    "path": "python/dgl/geometry/fps.py",
    "content": "\"\"\"Farthest Point Sampler for pytorch Geometry package\"\"\"\n# pylint: disable=no-member, invalid-name\n\nfrom .. import backend as F\nfrom ..base import DGLError\nfrom .capi import _farthest_point_sampler\n\n__all__ = [\"farthest_point_sampler\"]\n\n\ndef farthest_point_sampler(pos, npoints, start_idx=None):\n    \"\"\"Farthest Point Sampler without the need to compute all pairs of distance.\n\n    In each batch, the algorithm starts with the sample index specified by ``start_idx``.\n    Then for each point, we maintain the minimum to-sample distance.\n    Finally, we pick the point with the maximum such distance.\n    This process will be repeated for ``sample_points`` - 1 times.\n\n    Parameters\n    ----------\n    pos : tensor\n        The positional tensor of shape (B, N, C)\n    npoints : int\n        The number of points to sample in each batch.\n    start_idx : int, optional\n        If given, appoint the index of the starting point,\n        otherwise randomly select a point as the start point.\n        (default: None)\n\n    Returns\n    -------\n    tensor of shape (B, npoints)\n        The sampled indices in each batch.\n\n    Examples\n    --------\n    The following exmaple uses PyTorch backend.\n\n    >>> import torch\n    >>> from dgl.geometry import farthest_point_sampler\n    >>> x = torch.rand((2, 10, 3))\n    >>> point_idx = farthest_point_sampler(x, 2)\n    >>> print(point_idx)\n        tensor([[5, 6],\n                [7, 8]])\n    \"\"\"\n    ctx = F.context(pos)\n    B, N, C = pos.shape\n    pos = pos.reshape(-1, C)\n    dist = F.zeros((B * N), dtype=pos.dtype, ctx=ctx)\n    if start_idx is None:\n        start_idx = F.randint(\n            shape=(B,), dtype=F.int64, ctx=ctx, low=0, high=N - 1\n        )\n    else:\n        if start_idx >= N or start_idx < 0:\n            raise DGLError(\n                \"Invalid start_idx, expected 0 <= start_idx < {}, got {}\".format(\n                    N, start_idx\n               
 )\n            )\n        start_idx = F.full_1d(B, start_idx, dtype=F.int64, ctx=ctx)\n    result = F.zeros((npoints * B), dtype=F.int64, ctx=ctx)\n    _farthest_point_sampler(pos, B, npoints, dist, start_idx, result)\n    return result.reshape(B, npoints)\n"
  },
  {
    "path": "python/dgl/global_config.py",
    "content": "\"\"\"Module for global configuration operators.\"\"\"\nfrom ._ffi.function import _init_api\n\n__all__ = [\"is_libxsmm_enabled\", \"use_libxsmm\"]\n\n\ndef use_libxsmm(flag):\n    r\"\"\"Set whether DGL uses libxsmm at runtime.\n\n    Detailed information about libxsmm can be found here:\n    https://github.com/libxsmm/libxsmm\n\n    Parameters\n    ----------\n    flag : boolean\n        If True, use libxsmm, otherwise not.\n\n    See Also\n    --------\n    is_libxsmm_enabled\n    \"\"\"\n    _CAPI_DGLConfigSetLibxsmm(flag)\n\n\ndef is_libxsmm_enabled():\n    r\"\"\"Get whether the use_libxsmm flag is turned on.\n\n    Returns\n    ----------\n    use_libxsmm_flag[boolean]\n        True if the use_libxsmm flag is turned on.\n\n    See Also\n    ----------\n    use_libxsmm\n    \"\"\"\n    return _CAPI_DGLConfigGetLibxsmm()\n\n\n_init_api(\"dgl.global_config\")\n"
  },
  {
    "path": "python/dgl/graph_index.py",
    "content": "\"\"\"Module for graph index class definition.\"\"\"\nfrom __future__ import absolute_import\n\nimport networkx as nx\nimport numpy as np\nimport scipy\n\nfrom . import backend as F, utils\nfrom ._ffi.function import _init_api\nfrom ._ffi.object import ObjectBase, register_object\nfrom .base import dgl_warning, DGLError\n\n\nclass BoolFlag(object):\n    \"\"\"Bool flag with unknown value\"\"\"\n\n    BOOL_UNKNOWN = -1\n    BOOL_FALSE = 0\n    BOOL_TRUE = 1\n\n\n@register_object(\"graph.Graph\")\nclass GraphIndex(ObjectBase):\n    \"\"\"Graph index object.\n\n    Note\n    ----\n    Do not create GraphIndex directly, you can create graph index object using\n    following functions:\n\n    - `dgl.graph_index.from_edge_list`\n    - `dgl.graph_index.from_scipy_sparse_matrix`\n    - `dgl.graph_index.from_networkx`\n    - `dgl.graph_index.from_shared_mem_csr_matrix`\n    - `dgl.graph_index.from_csr`\n    - `dgl.graph_index.from_coo`\n    \"\"\"\n\n    def __new__(cls):\n        obj = ObjectBase.__new__(cls)\n        obj._readonly = None  # python-side cache of the flag\n        obj._cache = {}\n        return obj\n\n    def __getstate__(self):\n        src, dst, _ = self.edges()\n        n_nodes = self.num_nodes()\n        readonly = self.is_readonly()\n\n        return n_nodes, readonly, src, dst\n\n    def __setstate__(self, state):\n        \"\"\"The pickle state of GraphIndex is defined as a triplet\n        (num_nodes, readonly, src_nodes, dst_nodes)\n        \"\"\"\n        # Pickle compatibility check\n        # TODO: we should store a storage version number in later releases.\n        if isinstance(state, tuple) and len(state) == 5:\n            dgl_warning(\n                \"The object is pickled pre-0.4.2.  
Multigraph flag is ignored in 0.4.3\"\n            )\n            num_nodes, _, readonly, src, dst = state\n        elif isinstance(state, tuple) and len(state) == 4:\n            # post-0.4.3.\n            num_nodes, readonly, src, dst = state\n        else:\n            raise IOError(\"Unrecognized storage format.\")\n\n        self._cache = {}\n        self._readonly = readonly\n        self.__init_handle_by_constructor__(\n            _CAPI_DGLGraphCreate,\n            src.todgltensor(),\n            dst.todgltensor(),\n            int(num_nodes),\n            readonly,\n        )\n\n    def add_nodes(self, num):\n        \"\"\"Add nodes.\n\n        Parameters\n        ----------\n        num : int\n            Number of nodes to be added.\n        \"\"\"\n        _CAPI_DGLGraphAddVertices(self, int(num))\n        self.clear_cache()\n\n    def add_edge(self, u, v):\n        \"\"\"Add one edge.\n\n        Parameters\n        ----------\n        u : int\n            The src node.\n        v : int\n            The dst node.\n        \"\"\"\n        _CAPI_DGLGraphAddEdge(self, int(u), int(v))\n        self.clear_cache()\n\n    def add_edges(self, u, v):\n        \"\"\"Add many edges.\n\n        Parameters\n        ----------\n        u : utils.Index\n            The src nodes.\n        v : utils.Index\n            The dst nodes.\n        \"\"\"\n        u_array = u.todgltensor()\n        v_array = v.todgltensor()\n        _CAPI_DGLGraphAddEdges(self, u_array, v_array)\n        self.clear_cache()\n\n    def clear(self):\n        \"\"\"Clear the graph.\"\"\"\n        _CAPI_DGLGraphClear(self)\n        self.clear_cache()\n\n    def clear_cache(self):\n        \"\"\"Clear the cached graph structures.\"\"\"\n        self._cache.clear()\n\n    def is_multigraph(self):\n        \"\"\"Return whether the graph is a multigraph\n        The time cost will be O(E)\n\n        Returns\n        -------\n        bool\n            True if it is a multigraph, False otherwise.\n      
  \"\"\"\n        return bool(_CAPI_DGLGraphIsMultigraph(self))\n\n    def is_readonly(self):\n        \"\"\"Indicate whether the graph index is read-only.\n\n        Returns\n        -------\n        bool\n            True if it is a read-only graph, False otherwise.\n        \"\"\"\n        if self._readonly is None:\n            self._readonly = bool(_CAPI_DGLGraphIsReadonly(self))\n        return self._readonly\n\n    def readonly(self, readonly_state=True):\n        \"\"\"Set the readonly state of graph index in-place.\n\n        Parameters\n        ----------\n        readonly_state : bool\n            New readonly state of current graph index.\n        \"\"\"\n        # TODO(minjie): very ugly code, should fix this\n        n_nodes, _, src, dst = self.__getstate__()\n        self.clear_cache()\n        state = (n_nodes, readonly_state, src, dst)\n        self.__setstate__(state)\n\n    def num_nodes(self):\n        \"\"\"Return the number of nodes.\n\n        Returns\n        -------\n        int\n            The number of nodes.\n        \"\"\"\n        return _CAPI_DGLGraphNumVertices(self)\n\n    def num_edges(self):\n        \"\"\"Return the number of edges.\n\n        Returns\n        -------\n        int\n            The number of edges.\n        \"\"\"\n        return _CAPI_DGLGraphNumEdges(self)\n\n    # TODO(#5485): remove this method.\n    def number_of_nodes(self):\n        \"\"\"Return the number of nodes.\n\n        Returns\n        -------\n        int\n            The number of nodes\n        \"\"\"\n        return _CAPI_DGLGraphNumVertices(self)\n\n    # TODO(#5485): remove this method.\n    def number_of_edges(self):\n        \"\"\"Return the number of edges.\n\n        Returns\n        -------\n        int\n            The number of edges\n        \"\"\"\n        return _CAPI_DGLGraphNumEdges(self)\n\n    def has_node(self, vid):\n        \"\"\"Return true if the node exists.\n\n        Parameters\n        ----------\n        vid : int\n    
        The nodes\n\n        Returns\n        -------\n        bool\n            True if the node exists, False otherwise.\n        \"\"\"\n        return bool(_CAPI_DGLGraphHasVertex(self, int(vid)))\n\n    def has_nodes(self, vids):\n        \"\"\"Return true if the nodes exist.\n\n        Parameters\n        ----------\n        vid : utils.Index\n            The nodes\n\n        Returns\n        -------\n        utils.Index\n            0-1 array indicating existence\n        \"\"\"\n        vid_array = vids.todgltensor()\n        return utils.toindex(_CAPI_DGLGraphHasVertices(self, vid_array))\n\n    def has_edge_between(self, u, v):\n        \"\"\"Return true if the edge exists.\n\n        Parameters\n        ----------\n        u : int\n            The src node.\n        v : int\n            The dst node.\n\n        Returns\n        -------\n        bool\n            True if the edge exists, False otherwise\n        \"\"\"\n        return bool(_CAPI_DGLGraphHasEdgeBetween(self, int(u), int(v)))\n\n    def has_edges_between(self, u, v):\n        \"\"\"Return true if the edge exists.\n\n        Parameters\n        ----------\n        u : utils.Index\n            The src nodes.\n        v : utils.Index\n            The dst nodes.\n\n        Returns\n        -------\n        utils.Index\n            0-1 array indicating existence\n        \"\"\"\n        u_array = u.todgltensor()\n        v_array = v.todgltensor()\n        return utils.toindex(\n            _CAPI_DGLGraphHasEdgesBetween(self, u_array, v_array)\n        )\n\n    def predecessors(self, v, radius=1):\n        \"\"\"Return the predecessors of the node.\n\n        Parameters\n        ----------\n        v : int\n            The node.\n        radius : int, optional\n            The radius of the neighborhood.\n\n        Returns\n        -------\n        utils.Index\n            Array of predecessors\n        \"\"\"\n        return utils.toindex(\n            _CAPI_DGLGraphPredecessors(self, int(v), 
int(radius))\n        )\n\n    def successors(self, v, radius=1):\n        \"\"\"Return the successors of the node.\n\n        Parameters\n        ----------\n        v : int\n            The node.\n        radius : int, optional\n            The radius of the neighborhood.\n\n        Returns\n        -------\n        utils.Index\n            Array of successors\n        \"\"\"\n        return utils.toindex(\n            _CAPI_DGLGraphSuccessors(self, int(v), int(radius))\n        )\n\n    def edge_id(self, u, v):\n        \"\"\"Return the id array of all edges between u and v.\n\n        Parameters\n        ----------\n        u : int\n            The src node.\n        v : int\n            The dst node.\n\n        Returns\n        -------\n        utils.Index\n            The edge id array.\n        \"\"\"\n        return utils.toindex(_CAPI_DGLGraphEdgeId(self, int(u), int(v)))\n\n    def edge_ids(self, u, v):\n        \"\"\"Return a triplet of arrays that contains the edge IDs.\n\n        Parameters\n        ----------\n        u : utils.Index\n            The src nodes.\n        v : utils.Index\n            The dst nodes.\n\n        Returns\n        -------\n        utils.Index\n            The src nodes.\n        utils.Index\n            The dst nodes.\n        utils.Index\n            The edge ids.\n        \"\"\"\n        u_array = u.todgltensor()\n        v_array = v.todgltensor()\n        edge_array = _CAPI_DGLGraphEdgeIds(self, u_array, v_array)\n\n        src = utils.toindex(edge_array(0))\n        dst = utils.toindex(edge_array(1))\n        eid = utils.toindex(edge_array(2))\n\n        return src, dst, eid\n\n    def find_edge(self, eid):\n        \"\"\"Return the edge tuple of the given id.\n\n        Parameters\n        ----------\n        eid : int\n            The edge id.\n\n        Returns\n        -------\n        int\n            src node id\n        int\n            dst node id\n        \"\"\"\n        ret = _CAPI_DGLGraphFindEdge(self, 
int(eid))\n        return ret(0), ret(1)\n\n    def find_edges(self, eid):\n        \"\"\"Return a triplet of arrays that contains the edge IDs.\n\n        Parameters\n        ----------\n        eid : utils.Index\n            The edge ids.\n\n        Returns\n        -------\n        utils.Index\n            The src nodes.\n        utils.Index\n            The dst nodes.\n        utils.Index\n            The edge ids.\n        \"\"\"\n        eid_array = eid.todgltensor()\n        edge_array = _CAPI_DGLGraphFindEdges(self, eid_array)\n\n        src = utils.toindex(edge_array(0))\n        dst = utils.toindex(edge_array(1))\n        eid = utils.toindex(edge_array(2))\n\n        return src, dst, eid\n\n    def in_edges(self, v):\n        \"\"\"Return the in edges of the node(s).\n\n        Parameters\n        ----------\n        v : utils.Index\n            The node(s).\n\n        Returns\n        -------\n        utils.Index\n            The src nodes.\n        utils.Index\n            The dst nodes.\n        utils.Index\n            The edge ids.\n        \"\"\"\n        if len(v) == 1:\n            edge_array = _CAPI_DGLGraphInEdges_1(self, int(v[0]))\n        else:\n            v_array = v.todgltensor()\n            edge_array = _CAPI_DGLGraphInEdges_2(self, v_array)\n        src = utils.toindex(edge_array(0))\n        dst = utils.toindex(edge_array(1))\n        eid = utils.toindex(edge_array(2))\n        return src, dst, eid\n\n    def out_edges(self, v):\n        \"\"\"Return the out edges of the node(s).\n\n        Parameters\n        ----------\n        v : utils.Index\n            The node(s).\n\n        Returns\n        -------\n        utils.Index\n            The src nodes.\n        utils.Index\n            The dst nodes.\n        utils.Index\n            The edge ids.\n        \"\"\"\n        if len(v) == 1:\n            edge_array = _CAPI_DGLGraphOutEdges_1(self, int(v[0]))\n        else:\n            v_array = v.todgltensor()\n            edge_array = 
_CAPI_DGLGraphOutEdges_2(self, v_array)\n        src = utils.toindex(edge_array(0))\n        dst = utils.toindex(edge_array(1))\n        eid = utils.toindex(edge_array(2))\n        return src, dst, eid\n\n    def sort_csr(self):\n        \"\"\"Sort the CSR matrix in the graph index.\n\n        By default, when the CSR matrix is created, the edges may be stored\n        in an arbitrary order. Sometimes, we want to sort them to accelerate\n        some computation. For example, `has_edges_between` can be much faster\n        on a giant adjacency matrix if the edges in the matrix is sorted.\n        \"\"\"\n        _CAPI_DGLSortAdj(self)\n\n    @utils.cached_member(cache=\"_cache\", prefix=\"edges\")\n    def edges(self, order=None):\n        \"\"\"Return all the edges\n\n        Parameters\n        ----------\n        order : string\n            The order of the returned edges. Currently support:\n\n            - 'srcdst' : sorted by their src and dst ids.\n            - 'eid'    : sorted by edge Ids.\n            - None     : the arbitrary order.\n\n        Returns\n        -------\n        utils.Index\n            The src nodes.\n        utils.Index\n            The dst nodes.\n        utils.Index\n            The edge ids.\n        \"\"\"\n        if order is None:\n            order = \"\"\n        edge_array = _CAPI_DGLGraphEdges(self, order)\n        src = edge_array(0)\n        dst = edge_array(1)\n        eid = edge_array(2)\n        src = utils.toindex(src)\n        dst = utils.toindex(dst)\n        eid = utils.toindex(eid)\n        return src, dst, eid\n\n    def in_degree(self, v):\n        \"\"\"Return the in degree of the node.\n\n        Parameters\n        ----------\n        v : int\n            The node.\n\n        Returns\n        -------\n        int\n            The in degree.\n        \"\"\"\n        return _CAPI_DGLGraphInDegree(self, int(v))\n\n    def in_degrees(self, v):\n        \"\"\"Return the in degrees of the nodes.\n\n        
Parameters\n        ----------\n        v : utils.Index\n            The nodes.\n\n        Returns\n        -------\n        tensor\n            The in degree array.\n        \"\"\"\n        v_array = v.todgltensor()\n        return utils.toindex(_CAPI_DGLGraphInDegrees(self, v_array))\n\n    def out_degree(self, v):\n        \"\"\"Return the out degree of the node.\n\n        Parameters\n        ----------\n        v : int\n            The node.\n\n        Returns\n        -------\n        int\n            The out degree.\n        \"\"\"\n        return _CAPI_DGLGraphOutDegree(self, int(v))\n\n    def out_degrees(self, v):\n        \"\"\"Return the out degrees of the nodes.\n\n        Parameters\n        ----------\n        v : utils.Index\n            The nodes.\n\n        Returns\n        -------\n        tensor\n            The out degree array.\n        \"\"\"\n        v_array = v.todgltensor()\n        return utils.toindex(_CAPI_DGLGraphOutDegrees(self, v_array))\n\n    def node_subgraph(self, v):\n        \"\"\"Return the induced node subgraph.\n\n        Parameters\n        ----------\n        v : utils.Index\n            The nodes.\n\n        Returns\n        -------\n        SubgraphIndex\n            The subgraph index.\n        \"\"\"\n        v_array = v.todgltensor()\n        return _CAPI_DGLGraphVertexSubgraph(self, v_array)\n\n    def node_halo_subgraph(self, v, num_hops):\n        \"\"\"Return an induced subgraph with halo nodes.\n\n        Parameters\n        ----------\n        v : utils.Index\n            The nodes.\n\n        num_hops : int\n            The number of hops in which a HALO node can be accessed.\n\n        Returns\n        -------\n        SubgraphIndex\n            The subgraph index.\n        DGLTensor\n            Indicate if a node belongs to a partition.\n        DGLTensor\n            Indicate if an edge belongs to a partition.\n        \"\"\"\n        v_array = v.todgltensor()\n        subg = 
_CAPI_DGLGetSubgraphWithHalo(self, v_array, num_hops)\n        inner_nodes = _CAPI_GetHaloSubgraphInnerNodes(subg)\n        return subg, inner_nodes\n\n    def node_subgraphs(self, vs_arr):\n        \"\"\"Return the induced node subgraphs.\n\n        Parameters\n        ----------\n        vs_arr : a list of utils.Index\n            The nodes.\n\n        Returns\n        -------\n        a vector of SubgraphIndex\n            The subgraph index.\n        \"\"\"\n        gis = []\n        for v in vs_arr:\n            gis.append(self.node_subgraph(v))\n        return gis\n\n    def edge_subgraph(self, e, preserve_nodes=False):\n        \"\"\"Return the induced edge subgraph.\n\n        Parameters\n        ----------\n        e : utils.Index\n            The edges.\n        preserve_nodes : bool\n            Indicates whether to preserve all nodes or not.\n            If true, keep the nodes which have no edge connected in the subgraph;\n            If false, all nodes without edge connected to it would be removed.\n\n        Returns\n        -------\n        SubgraphIndex\n            The subgraph index.\n        \"\"\"\n        e_array = e.todgltensor()\n        return _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes)\n\n    @utils.cached_member(cache=\"_cache\", prefix=\"scipy_adj\")\n    def adjacency_matrix_scipy(self, transpose, fmt, return_edge_ids=None):\n        \"\"\"Return the scipy adjacency matrix representation of this graph.\n\n        By default, a row of returned adjacency matrix represents the destination\n        of an edge and the column represents the source.\n\n        When transpose is True, a row represents the source and a column represents\n        a destination.\n\n        Parameters\n        ----------\n        transpose : bool\n            A flag to transpose the returned adjacency matrix.\n        fmt : str\n            Indicates the format of returned adjacency matrix.\n        return_edge_ids : bool\n            Indicates 
whether to return edge IDs or 1 as elements.\n\n        Returns\n        -------\n        scipy.sparse.spmatrix\n            The scipy representation of adjacency matrix.\n        \"\"\"\n        if not isinstance(transpose, bool):\n            raise DGLError(\n                'Expect bool value for \"transpose\" arg,'\n                \" but got %s.\" % (type(transpose))\n            )\n\n        if return_edge_ids is None:\n            dgl_warning(\n                \"Adjacency matrix by default currently returns edge IDs.\"\n                \"  As a result there is one 0 entry which is not eliminated.\"\n                \"  In the next release it will return 1s by default,\"\n                \" and 0 will be eliminated otherwise.\",\n                FutureWarning,\n            )\n            return_edge_ids = True\n\n        rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)\n        if fmt == \"csr\":\n            indptr = utils.toindex(rst(0)).tonumpy()\n            indices = utils.toindex(rst(1)).tonumpy()\n            data = (\n                utils.toindex(rst(2)).tonumpy()\n                if return_edge_ids\n                else np.ones_like(indices)\n            )\n            n = self.num_nodes()\n            return scipy.sparse.csr_matrix(\n                (data, indices, indptr), shape=(n, n)\n            )\n        elif fmt == \"coo\":\n            idx = utils.toindex(rst(0)).tonumpy()\n            n = self.num_nodes()\n            m = self.num_edges()\n            row, col = np.reshape(idx, (2, m))\n            data = np.arange(0, m) if return_edge_ids else np.ones_like(row)\n            return scipy.sparse.coo_matrix((data, (row, col)), shape=(n, n))\n        else:\n            raise Exception(\"unknown format\")\n\n    @utils.cached_member(cache=\"_cache\", prefix=\"immu_gidx\")\n    def get_immutable_gidx(self, ctx):\n        \"\"\"Create an immutable graph index and copy to the given device context.\n\n        Note: this internal function is for 
DGL scheduler use only\n\n        Parameters\n        ----------\n        ctx : DGLContext\n            The context of the returned graph.\n\n        Returns\n        -------\n        GraphIndex\n        \"\"\"\n        return self.to_immutable().asbits(self.bits_needed()).copy_to(ctx)\n\n    def get_csr_shuffle_order(self):\n        \"\"\"Return the edge shuffling order when a coo graph is converted to csr format\n\n        Returns\n        -------\n        tuple of two utils.Index\n            The first element of the tuple is the shuffle order for outward graph\n            The second element of the tuple is the shuffle order for inward graph\n        \"\"\"\n        csr = _CAPI_DGLGraphGetAdj(self, True, \"csr\")\n        order = csr(2)\n        rev_csr = _CAPI_DGLGraphGetAdj(self, False, \"csr\")\n        rev_order = rev_csr(2)\n        return utils.toindex(order), utils.toindex(rev_order)\n\n    def adjacency_matrix(self, transpose, ctx):\n        \"\"\"Return the adjacency matrix representation of this graph.\n\n        By default, a row of returned adjacency matrix represents the destination\n        of an edge and the column represents the source.\n\n        When transpose is True, a row represents the source and a column represents\n        a destination.\n\n        Parameters\n        ----------\n        transpose : bool\n            A flag to transpose the returned adjacency matrix.\n        ctx : context\n            The context of the returned matrix.\n\n        Returns\n        -------\n        SparseTensor\n            The adjacency matrix.\n        utils.Index\n            A index for data shuffling due to sparse format change. 
Return None\n            if shuffle is not required.\n        \"\"\"\n        if not isinstance(transpose, bool):\n            raise DGLError(\n                'Expect bool value for \"transpose\" arg,'\n                \" but got %s.\" % (type(transpose))\n            )\n        fmt = F.get_preferred_sparse_format()\n        rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)\n        if fmt == \"csr\":\n            indptr = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)\n            indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)\n            shuffle = utils.toindex(rst(2))\n            dat = F.ones(indices.shape, dtype=F.float32, ctx=ctx)\n            spmat = F.sparse_matrix(\n                dat,\n                (\"csr\", indices, indptr),\n                (self.num_nodes(), self.num_nodes()),\n            )[0]\n            return spmat, shuffle\n        elif fmt == \"coo\":\n            ## FIXME(minjie): data type\n            idx = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)\n            m = self.num_edges()\n            idx = F.reshape(idx, (2, m))\n            dat = F.ones((m,), dtype=F.float32, ctx=ctx)\n            n = self.num_nodes()\n            adj, shuffle_idx = F.sparse_matrix(dat, (\"coo\", idx), (n, n))\n            shuffle_idx = (\n                utils.toindex(shuffle_idx) if shuffle_idx is not None else None\n            )\n            return adj, shuffle_idx\n        else:\n            raise Exception(\"unknown format\")\n\n    def incidence_matrix(self, typestr, ctx):\n        \"\"\"Return the incidence matrix representation of this graph.\n\n        An incidence matrix is an n x m sparse matrix, where n is\n        the number of nodes and m is the number of edges. 
Each nnz\n        value indicating whether the edge is incident to the node\n        or not.\n\n        There are three types of an incidence matrix `I`:\n        * \"in\":\n          - I[v, e] = 1 if e is the in-edge of v (or v is the dst node of e);\n          - I[v, e] = 0 otherwise.\n        * \"out\":\n          - I[v, e] = 1 if e is the out-edge of v (or v is the src node of e);\n          - I[v, e] = 0 otherwise.\n        * \"both\":\n          - I[v, e] = 1 if e is the in-edge of v;\n          - I[v, e] = -1 if e is the out-edge of v;\n          - I[v, e] = 0 otherwise (including self-loop).\n\n        Parameters\n        ----------\n        typestr : str\n            Can be either \"in\", \"out\" or \"both\"\n        ctx : context\n            The context of returned incidence matrix.\n\n        Returns\n        -------\n        SparseTensor\n            The incidence matrix.\n        utils.Index\n            A index for data shuffling due to sparse format change. Return None\n            if shuffle is not required.\n        \"\"\"\n        src, dst, eid = self.edges()\n        src = src.tousertensor(ctx)  # the index of the ctx will be cached\n        dst = dst.tousertensor(ctx)  # the index of the ctx will be cached\n        eid = eid.tousertensor(ctx)  # the index of the ctx will be cached\n        n = self.num_nodes()\n        m = self.num_edges()\n        if typestr == \"in\":\n            row = F.unsqueeze(dst, 0)\n            col = F.unsqueeze(eid, 0)\n            idx = F.cat([row, col], dim=0)\n            # FIXME(minjie): data type\n            dat = F.ones((m,), dtype=F.float32, ctx=ctx)\n            inc, shuffle_idx = F.sparse_matrix(dat, (\"coo\", idx), (n, m))\n        elif typestr == \"out\":\n            row = F.unsqueeze(src, 0)\n            col = F.unsqueeze(eid, 0)\n            idx = F.cat([row, col], dim=0)\n            # FIXME(minjie): data type\n            dat = F.ones((m,), dtype=F.float32, ctx=ctx)\n            inc, shuffle_idx = 
F.sparse_matrix(dat, (\"coo\", idx), (n, m))\n        elif typestr == \"both\":\n            # first remove entries for self loops\n            mask = F.logical_not(F.equal(src, dst))\n            src = F.boolean_mask(src, mask)\n            dst = F.boolean_mask(dst, mask)\n            eid = F.boolean_mask(eid, mask)\n            n_entries = F.shape(src)[0]\n            # create index\n            row = F.unsqueeze(F.cat([src, dst], dim=0), 0)\n            col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)\n            idx = F.cat([row, col], dim=0)\n            # FIXME(minjie): data type\n            x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)\n            y = F.ones((n_entries,), dtype=F.float32, ctx=ctx)\n            dat = F.cat([x, y], dim=0)\n            inc, shuffle_idx = F.sparse_matrix(dat, (\"coo\", idx), (n, m))\n        else:\n            raise DGLError(\"Invalid incidence matrix type: %s\" % str(typestr))\n        shuffle_idx = (\n            utils.toindex(shuffle_idx) if shuffle_idx is not None else None\n        )\n        return inc, shuffle_idx\n\n    def to_networkx(self):\n        \"\"\"Convert to networkx graph.\n\n        The edge id will be saved as the 'id' edge attribute.\n\n        Returns\n        -------\n        networkx.DiGraph\n            The nx graph\n        \"\"\"\n        src, dst, eid = self.edges()\n        # xiangsx: Always treat graph as multigraph\n        ret = nx.MultiDiGraph()\n        ret.add_nodes_from(range(self.num_nodes()))\n        for u, v, e in zip(src, dst, eid):\n            ret.add_edge(u, v, id=e)\n        return ret\n\n    def line_graph(self, backtracking=True):\n        \"\"\"Return the line graph of this graph.\n\n        Parameters\n        ----------\n        backtracking : bool, optional (default=False)\n          Whether (i, j) ~ (j, i) in L(G).\n          (i, j) ~ (j, i) is the behavior of networkx.line_graph.\n\n        Returns\n        -------\n        GraphIndex\n            The line graph of 
this graph.\n        \"\"\"\n        return _CAPI_DGLGraphLineGraph(self, backtracking)\n\n    def to_immutable(self):\n        \"\"\"Convert this graph index to an immutable one.\n\n        Returns\n        -------\n        GraphIndex\n            An immutable graph index.\n        \"\"\"\n        return _CAPI_DGLToImmutable(self)\n\n    def ctx(self):\n        \"\"\"Return the context of this graph index.\n\n        Returns\n        -------\n        DGLContext\n            The context of the graph.\n        \"\"\"\n        return _CAPI_DGLGraphContext(self)\n\n    @property\n    def dtype(self):\n        \"\"\"Return the index dtype\n\n        Returns\n        ----------\n        str\n            The dtype of graph index\n        \"\"\"\n        bits = self.nbits()\n        if bits == 32:\n            return \"int32\"\n        else:\n            return \"int64\"\n\n    def copy_to(self, ctx):\n        \"\"\"Copy this immutable graph index to the given device context.\n\n        NOTE: this method only works for immutable graph index\n\n        Parameters\n        ----------\n        ctx : DGLContext\n            The target device context.\n\n        Returns\n        -------\n        GraphIndex\n            The graph index on the given device context.\n        \"\"\"\n        return _CAPI_DGLImmutableGraphCopyTo(\n            self, ctx.device_type, ctx.device_id\n        )\n\n    def copyto_shared_mem(self, shared_mem_name):\n        \"\"\"Copy this immutable graph index to shared memory.\n\n        NOTE: this method only works for immutable graph index\n\n        Parameters\n        ----------\n        shared_mem_name : string\n            The name of the shared memory.\n\n        Returns\n        -------\n        GraphIndex\n            The graph index on the given device context.\n        \"\"\"\n        return _CAPI_DGLImmutableGraphCopyToSharedMem(self, shared_mem_name)\n\n    def nbits(self):\n        \"\"\"Return the number of integer bits used in the 
storage (32 or 64).\n\n        Returns\n        -------\n        int\n            The number of bits.\n        \"\"\"\n        return _CAPI_DGLGraphNumBits(self)\n\n    def bits_needed(self):\n        \"\"\"Return the number of integer bits needed to represent the graph\n\n        Returns\n        -------\n        int\n            The number of bits needed\n        \"\"\"\n        if self.num_edges() >= 0x80000000 or self.num_nodes() >= 0x80000000:\n            return 64\n        else:\n            return 32\n\n    def asbits(self, bits):\n        \"\"\"Transform the graph to a new one with the given number of bits storage.\n\n        NOTE: this method only works for immutable graph index\n\n        Parameters\n        ----------\n        bits : int\n            The number of integer bits (32 or 64)\n\n        Returns\n        -------\n        GraphIndex\n            The graph index stored using the given number of bits.\n        \"\"\"\n        return _CAPI_DGLImmutableGraphAsNumBits(self, int(bits))\n\n\n@register_object(\"graph.Subgraph\")\nclass SubgraphIndex(ObjectBase):\n    \"\"\"Subgraph data structure\"\"\"\n\n    @property\n    def graph(self):\n        \"\"\"The subgraph structure\n\n        Returns\n        -------\n        GraphIndex\n            The subgraph\n        \"\"\"\n        return _CAPI_DGLSubgraphGetGraph(self)\n\n    @property\n    def induced_nodes(self):\n        \"\"\"Induced nodes for each node type. The return list\n        length should be equal to the number of node types.\n\n        Returns\n        -------\n        list of utils.Index\n            Induced nodes\n        \"\"\"\n        ret = _CAPI_DGLSubgraphGetInducedVertices(self)\n        return utils.toindex(ret)\n\n    @property\n    def induced_edges(self):\n        \"\"\"Induced edges for each edge type. 
The return list\n        length should be equal to the number of edge types.\n\n        Returns\n        -------\n        list of utils.Index\n            Induced edges\n        \"\"\"\n        ret = _CAPI_DGLSubgraphGetInducedEdges(self)\n        return utils.toindex(ret)\n\n\n###############################################################\n# Conversion functions\n###############################################################\ndef from_coo(num_nodes, src, dst, readonly):\n    \"\"\"Convert from coo arrays.\n\n    Parameters\n    ----------\n    num_nodes : int\n        Number of nodes.\n    src : Tensor\n        Src end nodes of the edges.\n    dst : Tensor\n        Dst end nodes of the edges.\n    readonly : bool\n        True if the returned graph is readonly.\n\n    Returns\n    -------\n    GraphIndex\n        The graph index.\n    \"\"\"\n    src = utils.toindex(src)\n    dst = utils.toindex(dst)\n    if readonly:\n        gidx = _CAPI_DGLGraphCreate(\n            src.todgltensor(), dst.todgltensor(), int(num_nodes), readonly\n        )\n    else:\n        gidx = _CAPI_DGLGraphCreateMutable()\n        gidx.add_nodes(num_nodes)\n        gidx.add_edges(src, dst)\n    return gidx\n\n\ndef from_csr(indptr, indices, direction):\n    \"\"\"Load a graph from CSR arrays.\n\n    Parameters\n    ----------\n    indptr : Tensor\n        index pointer in the CSR format\n    indices : Tensor\n        column index array in the CSR format\n    direction : str\n\n    Returns\n    ------\n    GraphIndex\n        The graph index\n        the edge direction. 
Either \"in\" or \"out\".\n    \"\"\"\n    indptr = utils.toindex(indptr)\n    indices = utils.toindex(indices)\n    gidx = _CAPI_DGLGraphCSRCreate(\n        indptr.todgltensor(), indices.todgltensor(), direction\n    )\n    return gidx\n\n\ndef from_shared_mem_graph_index(shared_mem_name):\n    \"\"\"Load a graph index from the shared memory.\n\n    Parameters\n    ----------\n    shared_mem_name : string\n        the name of shared memory\n\n    Returns\n    ------\n    GraphIndex\n        The graph index\n    \"\"\"\n    return _CAPI_DGLGraphCSRCreateMMap(shared_mem_name)\n\n\ndef from_networkx(nx_graph, readonly):\n    \"\"\"Convert from networkx graph.\n\n    If 'id' edge attribute exists, the edge will be added follows\n    the edge id order. Otherwise, order is undefined.\n\n    Parameters\n    ----------\n    nx_graph : networkx.DiGraph\n        The nx graph or any graph that can be converted to nx.DiGraph\n    readonly : bool\n        True if the returned graph is readonly.\n\n    Returns\n    -------\n    GraphIndex\n        The graph index.\n    \"\"\"\n    if not isinstance(nx_graph, nx.Graph):\n        nx_graph = nx.DiGraph(nx_graph)\n    else:\n        if not nx_graph.is_directed():\n            # to_directed creates a deep copy of the networkx graph even if\n            # the original graph is already directed and we do not want to do it.\n            nx_graph = nx_graph.to_directed()\n    num_nodes = nx_graph.number_of_nodes()\n\n    # nx_graph.edges(data=True) returns src, dst, attr_dict\n    if nx_graph.number_of_edges() > 0:\n        has_edge_id = \"id\" in next(iter(nx_graph.edges(data=True)))[-1]\n    else:\n        has_edge_id = False\n\n    if has_edge_id:\n        num_edges = nx_graph.number_of_edges()\n        src = np.zeros((num_edges,), dtype=np.int64)\n        dst = np.zeros((num_edges,), dtype=np.int64)\n        for u, v, attr in nx_graph.edges(data=True):\n            eid = attr[\"id\"]\n            src[eid] = u\n            dst[eid] = 
v\n    else:\n        src = []\n        dst = []\n        for e in nx_graph.edges:\n            src.append(e[0])\n            dst.append(e[1])\n    num_nodes = nx_graph.number_of_nodes()\n    # We store edge Ids as an edge attribute.\n    src = utils.toindex(src)\n    dst = utils.toindex(dst)\n    return from_coo(num_nodes, src, dst, readonly)\n\n\ndef from_scipy_sparse_matrix(adj, readonly):\n    \"\"\"Convert from scipy sparse matrix.\n\n    Parameters\n    ----------\n    adj : scipy sparse matrix\n    readonly : bool\n        True if the returned graph is readonly.\n\n    Returns\n    -------\n    GraphIndex\n        The graph index.\n    \"\"\"\n    if adj.getformat() != \"csr\" or not readonly:\n        num_nodes = max(adj.shape[0], adj.shape[1])\n        adj_coo = adj.tocoo()\n        return from_coo(num_nodes, adj_coo.row, adj_coo.col, readonly)\n    else:\n        # If the input matrix is csr, we still treat it as multigraph.\n        return from_csr(adj.indptr, adj.indices, \"out\")\n\n\ndef from_edge_list(elist, readonly):\n    \"\"\"Convert from an edge list.\n\n    Parameters\n    ---------\n    elist : list, tuple\n        List of (u, v) edge tuple, or a tuple of src/dst lists\n    \"\"\"\n    if isinstance(elist, tuple):\n        src, dst = elist\n    else:\n        src, dst = zip(*elist)\n    src = np.asarray(src)\n    dst = np.asarray(dst)\n    src_ids = utils.toindex(src)\n    dst_ids = utils.toindex(dst)\n    num_nodes = max(src.max(), dst.max()) + 1\n    return from_coo(num_nodes, src_ids, dst_ids, readonly)\n\n\ndef map_to_subgraph_nid(induced_nodes, parent_nids):\n    \"\"\"Map parent node Ids to the subgraph node Ids.\n\n    Parameters\n    ----------\n    induced_nodes: utils.Index\n        Induced nodes of the subgraph.\n\n    parent_nids: utils.Index\n        Node Ids in the parent graph.\n\n    Returns\n    -------\n    utils.Index\n        Node Ids in the subgraph.\n    \"\"\"\n    return utils.toindex(\n        
_CAPI_DGLMapSubgraphNID(\n            induced_nodes.todgltensor(), parent_nids.todgltensor()\n        )\n    )\n\n\ndef transform_ids(mapping, ids):\n    \"\"\"Transform ids by the given mapping.\n\n    Parameters\n    ----------\n    mapping : utils.Index\n        The id mapping. new_id = mapping[old_id]\n    ids : utils.Index\n        The old ids.\n\n    Returns\n    -------\n    utils.Index\n        The new ids.\n    \"\"\"\n    return utils.toindex(\n        _CAPI_DGLMapSubgraphNID(mapping.todgltensor(), ids.todgltensor())\n    )\n\n\ndef disjoint_union(graphs):\n    \"\"\"Return a disjoint union of the input graphs.\n\n    The new graph will include all the nodes/edges in the given graphs.\n    Nodes/Edges will be relabeled by adding the cumsum of the previous graph sizes\n    in the given sequence order. For example, giving input [g1, g2, g3], where\n    they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become node#7\n    in the result graph. Edge ids are re-assigned similarly.\n\n    Parameters\n    ----------\n    graphs : iterable of GraphIndex\n        The input graphs\n\n    Returns\n    -------\n    GraphIndex\n        The disjoint union\n    \"\"\"\n    return _CAPI_DGLDisjointUnion(list(graphs))\n\n\ndef disjoint_partition(graph, num_or_size_splits):\n    \"\"\"Partition the graph disjointly.\n\n    This is a reverse operation of DisjointUnion. The graph will be partitioned\n    into num graphs. This requires the given number of partitions to evenly\n    divides the number of nodes in the graph. 
If the a size list is given,\n    the sum of the given sizes is equal.\n\n    Parameters\n    ----------\n    graph : GraphIndex\n        The graph to be partitioned\n    num_or_size_splits : int or utils.Index\n        The partition number of size splits\n\n    Returns\n    -------\n    list of GraphIndex\n        The partitioned graphs\n    \"\"\"\n    if isinstance(num_or_size_splits, utils.Index):\n        rst = _CAPI_DGLDisjointPartitionBySizes(\n            graph, num_or_size_splits.todgltensor()\n        )\n    else:\n        rst = _CAPI_DGLDisjointPartitionByNum(graph, int(num_or_size_splits))\n    return rst\n\n\ndef create_graph_index(graph_data, readonly):\n    \"\"\"Create a graph index object.\n\n    Parameters\n    ----------\n    graph_data : graph data\n        Data to initialize graph. Same as networkx's semantics.\n    readonly : bool\n        Whether the graph structure is read-only.\n    \"\"\"\n    if isinstance(graph_data, GraphIndex):\n        # FIXME(minjie): this return is not correct for mutable graph index\n        return graph_data\n\n    if graph_data is None:\n        if readonly:\n            raise Exception(\"can't create an empty immutable graph\")\n        return _CAPI_DGLGraphCreateMutable()\n    elif isinstance(graph_data, (list, tuple)):\n        # edge list\n        return from_edge_list(graph_data, readonly)\n    elif isinstance(graph_data, scipy.sparse.spmatrix):\n        # scipy format\n        return from_scipy_sparse_matrix(graph_data, readonly)\n    else:\n        # networkx - any format\n        try:\n            gidx = from_networkx(graph_data, readonly)\n        except Exception:  # pylint: disable=broad-except\n            raise DGLError(\n                'Error while creating graph from input of type \"%s\".'\n                % type(graph_data)\n            )\n        return gidx\n\n\ndef _get_halo_subgraph_inner_node(halo_subg):\n    return 
_CAPI_GetHaloSubgraphInnerNodes(halo_subg)\n\n\n_init_api(\"dgl.graph_index\")\n"
  },
  {
    "path": "python/dgl/graphbolt/__init__.py",
    "content": "\"\"\"Graphbolt.\"\"\"\nimport os\nimport sys\n\nfrom .internal_utils import *\n\nCUDA_ALLOCATOR_ENV_WARNING_STR = \"\"\"\nAn experimental feature for CUDA allocations is turned on for better allocation\npattern resulting in better memory usage for minibatch GNN training workloads.\nSee https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf,\nand set the environment variable `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False`\nif you want to disable it and set it True to acknowledge and disable the warning.\n\"\"\"\ncuda_allocator_env = os.getenv(\"PYTORCH_CUDA_ALLOC_CONF\")\nWARNING_STR_TO_BE_SHOWN = None\nconfigs = (\n    {}\n    if cuda_allocator_env is None or len(cuda_allocator_env) == 0\n    else {\n        kv_pair.split(\":\")[0]: kv_pair.split(\":\")[1]\n        for kv_pair in cuda_allocator_env.split(\",\")\n    }\n)\nif \"expandable_segments\" in configs:\n    if configs[\"expandable_segments\"] != \"True\":\n        WARNING_STR_TO_BE_SHOWN = (\n            \"You should consider `expandable_segments:True` in the\"\n            \" environment variable `PYTORCH_CUDA_ALLOC_CONF` for lower\"\n            \" memory usage. 
See \"\n            \"https://pytorch.org/docs/stable/notes/cuda.html\"\n            \"#optimizing-memory-usage-with-pytorch-cuda-alloc-conf\"\n        )\nelse:\n    configs[\"expandable_segments\"] = \"True\"\n    os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \",\".join(\n        [k + \":\" + v for k, v in configs.items()]\n    )\n    WARNING_STR_TO_BE_SHOWN = CUDA_ALLOCATOR_ENV_WARNING_STR\ndel configs\ndel cuda_allocator_env\ndel CUDA_ALLOCATOR_ENV_WARNING_STR\n\n# pylint: disable=wrong-import-position, wrong-import-order\nimport torch\n\n### FROM DGL @todo\nfrom .._ffi import libinfo\n\n\ndef load_graphbolt():\n    \"\"\"Load Graphbolt C++ library\"\"\"\n    vers = torch.__version__.split(\"+\", maxsplit=1)[0]\n\n    if sys.platform.startswith(\"linux\"):\n        basename = f\"libgraphbolt_pytorch_{vers}.so\"\n    elif sys.platform.startswith(\"darwin\"):\n        basename = f\"libgraphbolt_pytorch_{vers}.dylib\"\n    elif sys.platform.startswith(\"win\"):\n        basename = f\"graphbolt_pytorch_{vers}.dll\"\n    else:\n        raise NotImplementedError(\"Unsupported system: %s\" % sys.platform)\n\n    dirname = os.path.dirname(libinfo.find_lib_path()[0])\n    path = os.path.join(dirname, \"graphbolt\", basename)\n    if not os.path.exists(path):\n        raise FileNotFoundError(\n            f\"Unable to locate the DGL C++ GraphBolt library at {path}. This \"\n            \"error typically occurs due to a version mismatch between the \"\n            \"installed DGL and the PyTorch version you are currently using. \"\n            \"Please ensure that your DGL installation is compatible with your \"\n            \"PyTorch version. 
For more information, refer to the installation \"\n            \"guide at https://www.dgl.ai/pages/start.html.\"\n        )\n\n    try:\n        torch.classes.load_library(path)\n    except Exception:  # pylint: disable=W0703\n        raise ImportError(\"Cannot load Graphbolt C++ library\")\n\n\nload_graphbolt()\n\n# pylint: disable=wrong-import-position\nfrom .base import *\nfrom .minibatch import *\nfrom .dataloader import *\nfrom .datapipes import *\nfrom .dataset import *\nfrom .feature_fetcher import *\nfrom .feature_store import *\nfrom .impl import *\nfrom .itemset import *\nfrom .item_sampler import *\nfrom .minibatch_transformer import *\nfrom .negative_sampler import *\nfrom .sampled_subgraph import *\nfrom .subgraph_sampler import *\nfrom .external_utils import add_reverse_edges, exclude_seed_edges\nfrom .internal import (\n    compact_csc_format,\n    numpy_save_aligned,\n    unique_and_compact,\n    unique_and_compact_csc_formats,\n)\n\nif torch.cuda.is_available() and not built_with_cuda():\n    raise ImportError(\n        \"torch was installed with CUDA support while GraphBolt's CPU version \"\n        \"is installed. Consider reinstalling GraphBolt with CUDA support, see \"\n        \"installation instructions at https://www.dgl.ai/pages/start.html\"\n    )\n\nif torch.cuda.is_available() and WARNING_STR_TO_BE_SHOWN is not None:\n    gb_warning(WARNING_STR_TO_BE_SHOWN)\ndel WARNING_STR_TO_BE_SHOWN\n\ntorch.ops.graphbolt.set_num_io_uring_threads(\n    min((torch.get_num_threads() + 1) // 2, 8)\n)\n"
  },
  {
    "path": "python/dgl/graphbolt/base.py",
    "content": "\"\"\"Base types and utilities for Graph Bolt.\"\"\"\n\nfrom collections import deque\nfrom dataclasses import dataclass\n\nimport torch\nfrom torch.torch_version import TorchVersion\n\nif (\n    TorchVersion(torch.__version__) >= \"2.3.0\"\n    and TorchVersion(torch.__version__) < \"2.3.1\"\n):\n    # Due to https://github.com/dmlc/dgl/issues/7380, for torch 2.3.0, we need\n    # to check if dill is available before using it.\n    torch.utils.data.datapipes.utils.common.DILL_AVAILABLE = (\n        torch.utils._import_utils.dill_available()\n    )\n\n# pylint: disable=wrong-import-position\nfrom torch.utils.data import functional_datapipe, IterDataPipe\n\nfrom .internal_utils import (\n    get_nonproperty_attributes,\n    recursive_apply,\n    recursive_apply_reduce_all,\n)\n\n__all__ = [\n    \"CANONICAL_ETYPE_DELIMITER\",\n    \"ORIGINAL_EDGE_ID\",\n    \"etype_str_to_tuple\",\n    \"etype_tuple_to_str\",\n    \"CopyTo\",\n    \"Waiter\",\n    \"Bufferer\",\n    \"EndMarker\",\n    \"isin\",\n    \"index_select\",\n    \"expand_indptr\",\n    \"indptr_edge_ids\",\n    \"CSCFormatBase\",\n    \"seed\",\n    \"seed_type_str_to_ntypes\",\n    \"get_host_to_device_uva_stream\",\n    \"get_device_to_host_uva_stream\",\n]\n\nCANONICAL_ETYPE_DELIMITER = \":\"\nORIGINAL_EDGE_ID = \"_ORIGINAL_EDGE_ID\"\n\n\n# There needs to be a single instance of the uva_stream, if it is created\n# multiple times, it leads to multiple CUDA memory pools and memory leaks.\ndef get_host_to_device_uva_stream():\n    \"\"\"The host to device copy stream to be used for pipeline parallelism.\"\"\"\n    if not hasattr(get_host_to_device_uva_stream, \"stream\"):\n        get_host_to_device_uva_stream.stream = torch.cuda.Stream(priority=-1)\n    return get_host_to_device_uva_stream.stream\n\n\ndef get_device_to_host_uva_stream():\n    \"\"\"The device to host copy stream to be used for pipeline parallelism.\"\"\"\n    if not hasattr(get_device_to_host_uva_stream, \"stream\"):\n    
    get_device_to_host_uva_stream.stream = torch.cuda.Stream(priority=-1)\n    return get_device_to_host_uva_stream.stream\n\n\ndef seed(val):\n    \"\"\"Set the random seed of Graphbolt.\n\n    Parameters\n    ----------\n    val : int\n        The seed.\n    \"\"\"\n    torch.ops.graphbolt.set_seed(val)\n\n\ndef isin(elements, test_elements):\n    \"\"\"Tests if each element of elements is in test_elements. Returns a boolean\n    tensor of the same shape as elements that is True for elements in\n    test_elements and False otherwise.\n\n    Parameters\n    ----------\n    elements : torch.Tensor\n        A 1D tensor represents the input elements.\n    test_elements : torch.Tensor\n        A 1D tensor represents the values to test against for each input.\n\n    Examples\n    --------\n    >>> isin(torch.tensor([1, 2, 3, 4]), torch.tensor([2, 3]))\n    tensor([[False,  True,  True,  False]])\n    \"\"\"\n    assert elements.dim() == 1, \"Elements should be 1D tensor.\"\n    assert test_elements.dim() == 1, \"Test_elements should be 1D tensor.\"\n    return torch.ops.graphbolt.isin(elements, test_elements)\n\n\nif TorchVersion(torch.__version__) >= TorchVersion(\"2.2.0a0\"):\n\n    torch_fake_decorator = (\n        torch.library.impl_abstract\n        if TorchVersion(torch.__version__) < TorchVersion(\"2.4.0a0\")\n        else torch.library.register_fake\n    )\n\n    @torch_fake_decorator(\"graphbolt::expand_indptr\")\n    def expand_indptr_fake(indptr, dtype, node_ids, output_size):\n        \"\"\"Fake implementation of expand_indptr for torch.compile() support.\"\"\"\n        if output_size is None:\n            output_size = torch.library.get_ctx().new_dynamic_size()\n        if dtype is None:\n            dtype = node_ids.dtype\n        return indptr.new_empty(output_size, dtype=dtype)\n\n\ndef expand_indptr(indptr, dtype=None, node_ids=None, output_size=None):\n    \"\"\"Converts a given indptr offset tensor to a COO format tensor. 
If\n    node_ids is not given, it is assumed to be equal to\n    torch.arange(indptr.size(0) - 1, dtype=dtype, device=indptr.device).\n\n    This is equivalent to\n\n    .. code:: python\n\n       if node_ids is None:\n           node_ids = torch.arange(len(indptr) - 1, dtype=dtype, device=indptr.device)\n       return node_ids.to(dtype).repeat_interleave(indptr.diff())\n\n    Parameters\n    ----------\n    indptr : torch.Tensor\n        A 1D tensor represents the csc_indptr tensor.\n    dtype : Optional[torch.dtype]\n        The dtype of the returned output tensor.\n    node_ids : Optional[torch.Tensor]\n        A 1D tensor represents the column node ids that the returned tensor will\n        be populated with.\n    output_size : Optional[int]\n        The size of the output tensor. Should be equal to indptr[-1]. Using this\n        argument avoids a stream synchronization to calculate the output shape.\n\n    Returns\n    -------\n    torch.Tensor\n        The converted COO tensor with values from node_ids.\n    \"\"\"\n    assert indptr.dim() == 1, \"Indptr should be 1D tensor.\"\n    assert not (\n        node_ids is None and dtype is None\n    ), \"One of node_ids or dtype must be given.\"\n    assert (\n        node_ids is None or node_ids.dim() == 1\n    ), \"Node_ids should be 1D tensor.\"\n    if dtype is None:\n        dtype = node_ids.dtype\n    return torch.ops.graphbolt.expand_indptr(\n        indptr, dtype, node_ids, output_size\n    )\n\n\nif TorchVersion(torch.__version__) >= TorchVersion(\"2.2.0a0\"):\n\n    torch_fake_decorator = (\n        torch.library.impl_abstract\n        if TorchVersion(torch.__version__) < TorchVersion(\"2.4.0a0\")\n        else torch.library.register_fake\n    )\n\n    @torch_fake_decorator(\"graphbolt::indptr_edge_ids\")\n    def indptr_edge_ids_fake(indptr, dtype, offset, output_size):\n        \"\"\"Fake implementation of indptr_edge_ids for torch.compile() support.\"\"\"\n        if output_size is None:\n            
output_size = torch.library.get_ctx().new_dynamic_size()\n        if dtype is None:\n            dtype = offset.dtype\n        return indptr.new_empty(output_size, dtype=dtype)\n\n\ndef indptr_edge_ids(indptr, dtype=None, offset=None, output_size=None):\n    \"\"\"Converts a given indptr offset tensor to a COO format tensor for the edge\n    ids. For a given indptr [0, 2, 5, 7] and offset tensor [0, 100, 200], the\n    output will be [0, 1, 100, 101, 102, 201, 202]. If offset was not provided,\n    the output would be [0, 1, 0, 1, 2, 0, 1].\n\n    Parameters\n    ----------\n    indptr : torch.Tensor\n        A 1D tensor represents the csc_indptr tensor.\n    dtype : Optional[torch.dtype]\n        The dtype of the returned output tensor.\n    offset : Optional[torch.Tensor]\n        A 1D tensor represents the offsets that the returned tensor will be\n        populated with.\n    output_size : Optional[int]\n        The size of the output tensor. Should be equal to indptr[-1]. Using this\n        argument avoids a stream synchronization to calculate the output shape.\n\n    Returns\n    -------\n    torch.Tensor\n        The converted COO edge ids tensor.\n    \"\"\"\n    assert indptr.dim() == 1, \"Indptr should be 1D tensor.\"\n    assert offset is None or offset.dim() == 1, \"Offset should be 1D tensor.\"\n    if dtype is None:\n        dtype = offset.dtype\n    return torch.ops.graphbolt.indptr_edge_ids(\n        indptr, dtype, offset, output_size\n    )\n\n\ndef index_select(tensor, index):\n    \"\"\"Returns a new tensor which indexes the input tensor along dimension dim\n    using the entries in index.\n\n    The returned tensor has the same number of dimensions as the original tensor\n    (tensor). 
The first dimension has the same size as the length of index;\n    other dimensions have the same size as in the original tensor.\n\n    When tensor is a pinned tensor and index.is_cuda is True, the operation runs\n    on the CUDA device and the returned tensor will also be on CUDA.\n\n    Parameters\n    ----------\n    tensor : torch.Tensor\n        The input tensor.\n    index : torch.Tensor\n        The 1-D tensor containing the indices to index.\n\n    Returns\n    -------\n    torch.Tensor\n        The indexed input tensor, equivalent to tensor[index]. If index is in\n        pinned memory, then the result is placed into pinned memory as well.\n    \"\"\"\n    assert index.dim() == 1, \"Index should be 1D tensor.\"\n    return torch.ops.graphbolt.index_select(tensor, index)\n\n\ndef etype_tuple_to_str(c_etype):\n    \"\"\"Convert canonical etype from tuple to string.\n\n    Examples\n    --------\n    >>> c_etype = (\"user\", \"like\", \"item\")\n    >>> c_etype_str = _etype_tuple_to_str(c_etype)\n    >>> print(c_etype_str)\n    \"user:like:item\"\n    \"\"\"\n    assert isinstance(c_etype, tuple) and len(c_etype) == 3, (\n        \"Passed-in canonical etype should be in format of (str, str, str). \"\n        f\"But got {c_etype}.\"\n    )\n    return CANONICAL_ETYPE_DELIMITER.join(c_etype)\n\n\ndef etype_str_to_tuple(c_etype):\n    \"\"\"Convert canonical etype from string to tuple.\n\n    Examples\n    --------\n    >>> c_etype_str = \"user:like:item\"\n    >>> c_etype = _etype_str_to_tuple(c_etype_str)\n    >>> print(c_etype)\n    (\"user\", \"like\", \"item\")\n    \"\"\"\n    if isinstance(c_etype, tuple):\n        return c_etype\n    ret = tuple(c_etype.split(CANONICAL_ETYPE_DELIMITER))\n    assert len(ret) == 3, (\n        \"Passed-in canonical etype should be in format of 'str:str:str'. 
\"\n        f\"But got {c_etype}.\"\n    )\n    return ret\n\n\ndef seed_type_str_to_ntypes(seed_type, seed_size):\n    \"\"\"Convert seeds type to node types from string to list.\n\n    Examples\n    --------\n    1. node pairs\n\n    >>> seed_type = \"user:like:item\"\n    >>> seed_size = 2\n    >>> node_type = seed_type_str_to_ntypes(seed_type, seed_size)\n    >>> print(node_type)\n    [\"user\", \"item\"]\n\n    2. hyperlink\n\n    >>> seed_type = \"query:user:item\"\n    >>> seed_size = 3\n    >>> node_type = seed_type_str_to_ntypes(seed_type, seed_size)\n    >>> print(node_type)\n    [\"query\", \"user\", \"item\"]\n    \"\"\"\n    assert isinstance(\n        seed_type, str\n    ), f\"Passed-in seed type should be string, but got {type(seed_type)}\"\n    ntypes = seed_type.split(CANONICAL_ETYPE_DELIMITER)\n    is_hyperlink = len(ntypes) == seed_size\n    if not is_hyperlink:\n        ntypes = ntypes[::2]\n    return ntypes\n\n\ndef apply_to(x, device, non_blocking=False):\n    \"\"\"Apply `to` function to object x only if it has `to`.\"\"\"\n\n    if device == \"pinned\" and hasattr(x, \"pin_memory\"):\n        return x.pin_memory()\n    if not hasattr(x, \"to\"):\n        return x\n    if not non_blocking:\n        return x.to(device)\n    return x.to(device, non_blocking=True)\n\n\ndef is_object_pinned(obj):\n    \"\"\"Recursively check all members of the object and return True if only if\n    all are pinned.\"\"\"\n\n    for attr in get_nonproperty_attributes(obj):\n        member_result = recursive_apply_reduce_all(\n            getattr(obj, attr),\n            lambda x: x is None or x.is_pinned(),\n        )\n        if not member_result:\n            return False\n    return True\n\n\n@functional_datapipe(\"copy_to\")\nclass CopyTo(IterDataPipe):\n    \"\"\"DataPipe that transfers each element yielded from the previous DataPipe\n    to the given device. 
For MiniBatch, only the related attributes\n    (automatically inferred) will be transferred by default.\n\n    Functional name: :obj:`copy_to`.\n\n    When ``data`` has ``to`` method implemented, ``CopyTo`` will be equivalent\n    to\n\n    .. code:: python\n\n       for data in datapipe:\n           yield data.to(device)\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The DataPipe.\n    device : torch.device\n        The PyTorch CUDA device.\n    non_blocking : bool\n        Whether the copy should be performed without blocking. All elements have\n        to be already in pinned system memory if enabled. Default is False.\n    \"\"\"\n\n    def __init__(self, datapipe, device, non_blocking=False):\n        super().__init__()\n        self.datapipe = datapipe\n        self.device = torch.device(device)\n        self.non_blocking = non_blocking\n\n    def __iter__(self):\n        for data in self.datapipe:\n            yield recursive_apply(\n                data, apply_to, self.device, self.non_blocking\n            )\n\n\n@functional_datapipe(\"mark_end\")\nclass EndMarker(IterDataPipe):\n    \"\"\"Used to mark the end of a datapipe and is a no-op.\"\"\"\n\n    def __init__(self, datapipe):\n        self.datapipe = datapipe\n\n    def __iter__(self):\n        yield from self.datapipe\n\n\n@functional_datapipe(\"buffer\")\nclass Bufferer(IterDataPipe):\n    \"\"\"Buffers items before yielding them.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The data pipeline.\n    buffer_size : int, optional\n        The size of the buffer which stores the fetched samples. 
If data coming\n        from datapipe has latency spikes, consider setting to a higher value.\n        Default is 1.\n    \"\"\"\n\n    def __init__(self, datapipe, buffer_size=1):\n        self.datapipe = datapipe\n        if buffer_size <= 0:\n            raise ValueError(\n                \"'buffer_size' is required to be a positive integer.\"\n            )\n        self.buffer = deque(maxlen=buffer_size)\n\n    def __iter__(self):\n        for data in self.datapipe:\n            if len(self.buffer) < self.buffer.maxlen:\n                self.buffer.append(data)\n            else:\n                return_data = self.buffer.popleft()\n                self.buffer.append(data)\n                yield return_data\n        while len(self.buffer) > 0:\n            yield self.buffer.popleft()\n\n    def __getstate__(self):\n        state = (self.datapipe, self.buffer.maxlen)\n        if IterDataPipe.getstate_hook is not None:\n            return IterDataPipe.getstate_hook(state)\n        return state\n\n    def __setstate__(self, state):\n        self.datapipe, buffer_size = state\n        self.buffer = deque(maxlen=buffer_size)\n\n    def reset(self):\n        \"\"\"Resets the state of the datapipe.\"\"\"\n        self.buffer.clear()\n\n\n@functional_datapipe(\"wait\")\nclass Waiter(IterDataPipe):\n    \"\"\"Calls the wait function of all items.\"\"\"\n\n    def __init__(self, datapipe):\n        self.datapipe = datapipe\n\n    def __iter__(self):\n        for data in self.datapipe:\n            data.wait()\n            yield data\n\n\n@dataclass\nclass CSCFormatBase:\n    r\"\"\"Basic class representing data in Compressed Sparse Column (CSC) format.\n\n    Examples\n    --------\n    >>> indptr = torch.tensor([0, 1, 3])\n    >>> indices = torch.tensor([1, 4, 2])\n    >>> csc_foramt_base = CSCFormatBase(indptr=indptr, indices=indices)\n    >>> print(csc_format_base.indptr)\n    ... torch.tensor([0, 1, 3])\n    >>> print(csc_foramt_base)\n    ... 
torch.tensor([1, 4, 2])\n    \"\"\"\n\n    indptr: torch.Tensor = None\n    indices: torch.Tensor = None\n\n    def __init__(self, indptr: torch.Tensor, indices: torch.Tensor):\n        self.indptr = indptr\n        self.indices = indices\n        if not indptr.is_cuda:\n            assert self.indptr[-1] == len(\n                self.indices\n            ), \"The last element of indptr should be the same as the length of indices.\"\n\n    def __repr__(self) -> str:\n        return _csc_format_base_str(self)\n\n    def to(  # pylint: disable=invalid-name\n        self, device: torch.device, non_blocking=False\n    ) -> None:\n        \"\"\"Copy `CSCFormatBase` to the specified device using reflection.\"\"\"\n\n        for attr in dir(self):\n            # Only copy member variables.\n            if not callable(getattr(self, attr)) and not attr.startswith(\"__\"):\n                setattr(\n                    self,\n                    attr,\n                    recursive_apply(\n                        getattr(self, attr),\n                        apply_to,\n                        device,\n                        non_blocking=non_blocking,\n                    ),\n                )\n\n        return self\n\n    def pin_memory(self):\n        \"\"\"Copy `SampledSubgraph` to the pinned memory using reflection.\"\"\"\n\n        return self.to(\"pinned\")\n\n    def is_pinned(self) -> bool:\n        \"\"\"Check whether `SampledSubgraph` is pinned using reflection.\"\"\"\n\n        return is_object_pinned(self)\n\n\ndef _csc_format_base_str(csc_format_base: CSCFormatBase) -> str:\n    final_str = \"CSCFormatBase(\"\n\n    def _add_indent(_str, indent):\n        lines = _str.split(\"\\n\")\n        lines = [lines[0]] + [\" \" * indent + line for line in lines[1:]]\n        return \"\\n\".join(lines)\n\n    final_str += (\n        f\"indptr={_add_indent(str(csc_format_base.indptr), 21)},\\n\" + \" \" * 14\n    )\n    final_str += (\n        
f\"indices={_add_indent(str(csc_format_base.indices), 22)},\\n\" + \")\"\n    )\n    return final_str\n"
  },
  {
    "path": "python/dgl/graphbolt/dataloader.py",
    "content": "\"\"\"Graph Bolt DataLoaders\"\"\"\n\nimport torch\nimport torch.utils.data as torch_data\n\nfrom .base import CopyTo\nfrom .datapipes import (\n    datapipe_graph_to_adjlist,\n    find_dps,\n    replace_dp,\n    traverse_dps,\n)\nfrom .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker\nfrom .impl.neighbor_sampler import SamplePerLayer\nfrom .internal_utils import gb_warning\nfrom .item_sampler import ItemSampler\nfrom .minibatch_transformer import MiniBatchTransformer\n\n\n__all__ = [\n    \"DataLoader\",\n]\n\n\ndef _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):\n    \"\"\"Find parent of target_datapipe and wrap it with .\"\"\"\n    datapipes = find_dps(\n        datapipe_graph,\n        target_datapipe,\n    )\n    datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)\n    for datapipe in datapipes:\n        datapipe_id = id(datapipe)\n        for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:\n            parent_datapipe, _ = datapipe_adjlist[parent_datapipe_id]\n            datapipe_graph = replace_dp(\n                datapipe_graph,\n                parent_datapipe,\n                wrapper(parent_datapipe, **kwargs),\n            )\n    return datapipe_graph\n\n\ndef _set_worker_id(worked_id):\n    torch.ops.graphbolt.set_worker_id(worked_id)\n\n\nclass MultiprocessingWrapper(torch_data.IterDataPipe):\n    \"\"\"Wraps a datapipe with multiprocessing.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The data pipeline.\n    num_workers : int, optional\n        The number of worker processes. Default is 0, meaning that there\n        will be no multiprocessing.\n    persistent_workers : bool, optional\n        If True, the data loader will not shut down the worker processes after a\n        dataset has been consumed once. 
This allows to maintain the workers\n        instances alive.\n    \"\"\"\n\n    def __init__(self, datapipe, num_workers=0, persistent_workers=True):\n        self.datapipe = datapipe\n        self.dataloader = torch_data.DataLoader(\n            datapipe,\n            batch_size=None,\n            num_workers=num_workers,\n            persistent_workers=(num_workers > 0) and persistent_workers,\n            worker_init_fn=_set_worker_id if num_workers > 0 else None,\n        )\n\n    def __iter__(self):\n        yield from self.dataloader\n\n\nclass DataLoader(MiniBatchTransformer):\n    \"\"\"Multiprocessing DataLoader.\n\n    Iterates over the data pipeline with everything before feature fetching\n    (i.e. :class:`dgl.graphbolt.FeatureFetcher`) in subprocesses, and\n    everything after feature fetching in the main process. The datapipe\n    is modified in-place as a result.\n\n    When the copy_to operation is placed earlier in the data pipeline, the\n    num_workers argument is required to be 0 as utilizing CUDA in multiple\n    worker processes is not supported.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The data pipeline.\n    num_workers : int, optional\n        Number of worker processes. Default is 0.\n    persistent_workers : bool, optional\n        If True, the data loader will not shut down the worker processes after a\n        dataset has been consumed once. This allows to maintain the workers\n        instances alive.\n    max_uva_threads : int, optional\n        Limits the number of CUDA threads used for UVA copies so that the rest\n        of the computations can run simultaneously with it. Setting it to a too\n        high value will limit the amount of overlap while setting it too low may\n        cause the PCI-e bandwidth to not get fully utilized. 
Manually tuned\n        default is 10240, meaning around 5-7 Streaming Multiprocessors.\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        num_workers=0,\n        persistent_workers=True,\n        max_uva_threads=10240,\n    ):\n        # Multiprocessing requires two modifications to the datapipe:\n        #\n        # 1. Insert a stage after ItemSampler to distribute the\n        #    minibatches evenly across processes.\n        # 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe\n        #    of the FeatureFetcher with a multiprocessing PyTorch DataLoader.\n\n        datapipe = datapipe.mark_end()\n        datapipe_graph = traverse_dps(datapipe)\n\n        if num_workers > 0:\n            # (1) Insert minibatch distribution.\n            # TODO(BarclayII): Currently I'm using sharding_filter() as a\n            # concept demonstration. Later on minibatch distribution should be\n            # merged into ItemSampler to maximize efficiency.\n            item_samplers = find_dps(\n                datapipe_graph,\n                ItemSampler,\n            )\n            for item_sampler in item_samplers:\n                datapipe_graph = replace_dp(\n                    datapipe_graph,\n                    item_sampler,\n                    item_sampler.sharding_filter(),\n                )\n\n            # (2) Cut datapipe at FeatureFetcher and wrap.\n            datapipe_graph = _find_and_wrap_parent(\n                datapipe_graph,\n                FeatureFetcherStartMarker,\n                MultiprocessingWrapper,\n                num_workers=num_workers,\n                persistent_workers=persistent_workers,\n            )\n\n        # (3) Limit the number of UVA threads used if the feature_fetcher\n        # or any of the samplers have overlapping optimization enabled.\n        if num_workers == 0 and torch.cuda.is_available():\n            feature_fetchers = find_dps(\n                datapipe_graph,\n               
 FeatureFetcher,\n            )\n            for feature_fetcher in feature_fetchers:\n                if feature_fetcher.max_num_stages > 0:  # Overlap enabled.\n                    torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)\n\n        if num_workers == 0 and torch.cuda.is_available():\n            samplers = find_dps(\n                datapipe_graph,\n                SamplePerLayer,\n            )\n            for sampler in samplers:\n                if sampler.overlap_fetch:\n                    torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)\n\n        # (4) Cut datapipe at CopyTo and wrap with pinning and prefetching\n        # before it. This enables enables non_blocking copies to the device.\n        # Prefetching enables the data pipeline up to the CopyTo to run in a\n        # separate thread.\n        copiers = find_dps(datapipe_graph, CopyTo)\n        if len(copiers) > 1:\n            gb_warning(\n                \"Multiple CopyTo operations were found in the datapipe graph.\"\n                \" This case is not officially supported.\"\n            )\n        for copier in copiers:\n            # We enable the prefetch at all times for good CPU only performance.\n            datapipe_graph = replace_dp(\n                datapipe_graph,\n                copier,\n                # Add prefetch so that CPU and GPU can run concurrently.\n                copier.datapipe.prefetch(2).copy_to(\n                    copier.device, non_blocking=True\n                ),\n            )\n\n        super().__init__(datapipe)\n"
  },
  {
    "path": "python/dgl/graphbolt/datapipes/__init__.py",
    "content": "\"\"\"GraphBolt's datapipes, mostly copied from \"torchdata==0.7.1\".\"\"\"\nfrom .utils import *\nfrom .visualization import *\n"
  },
  {
    "path": "python/dgl/graphbolt/datapipes/utils.py",
    "content": "\"\"\"DataPipe utilities\"\"\"\n\nimport threading\nimport time\n\nfrom collections import deque\nfrom typing import final, List, Set, Type  # pylint: disable=no-name-in-module\n\nfrom torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe\nfrom torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps\n\n__all__ = [\n    \"datapipe_graph_to_adjlist\",\n    \"find_dps\",\n    \"replace_dp\",\n    \"traverse_dps\",\n]\n\n# Copied from:\n# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/datapipes/iter/util/prefetcher.py#L19-L20\n# Interval between buffer fulfillment checks\nPRODUCER_SLEEP_INTERVAL = 0.0001\n# Interval between checking items availability in buffer\nCONSUMER_SLEEP_INTERVAL = 0.0001\n\n\ndef _get_parents(result_dict, datapipe_graph):\n    for k, (v, parents) in datapipe_graph.items():\n        if k not in result_dict:\n            result_dict[k] = (v, list(parents.keys()))\n            _get_parents(result_dict, parents)\n\n\ndef datapipe_graph_to_adjlist(datapipe_graph):\n    \"\"\"Given a DataPipe graph returned by\n    :func:`torch.utils.data.graph.traverse_dps` in DAG form, convert it into\n    adjacency list form.\n\n    Namely, :func:`torch.utils.data.graph.traverse_dps` returns the following\n    data structure:\n\n    .. code::\n\n       {\n           id(datapipe): (\n               datapipe,\n               {\n                   id(parent1_of_datapipe): (parent1_of_datapipe, {...}),\n                   id(parent2_of_datapipe): (parent2_of_datapipe, {...}),\n                   ...\n               }\n           )\n       }\n\n    We convert it into the following for easier access:\n\n    .. 
code::\n\n       {\n           id(datapipe1): (\n               datapipe1,\n               [id(parent1_of_datapipe1), id(parent2_of_datapipe1), ...]\n           ),\n           id(datapipe2): (\n               datapipe2,\n               [id(parent1_of_datapipe2), id(parent2_of_datapipe2), ...]\n           ),\n           ...\n       }\n    \"\"\"\n\n    result_dict = {}\n    _get_parents(result_dict, datapipe_graph)\n    return result_dict\n\n\n# Copied from:\n# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/dataloader2/graph/utils.py#L16-L35\ndef find_dps(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> List[DataPipe]:\n    r\"\"\"\n    Given the graph of DataPipe generated by ``traverse_dps`` function, return DataPipe\n    instances with the provided DataPipe type.\n    \"\"\"\n    dps: List[DataPipe] = []\n    cache: Set[int] = set()\n\n    def helper(g) -> None:  # pyre-ignore\n        for dp_id, (dp, src_graph) in g.items():\n            if dp_id in cache:\n                continue\n            cache.add(dp_id)\n            # Please not use `isinstance`, there is a bug.\n            if type(dp) is dp_type:  # pylint: disable=unidiomatic-typecheck\n                dps.append(dp)\n            helper(src_graph)\n\n    helper(graph)\n\n    return dps\n\n\n# Copied from:\n# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/dataloader2/graph/utils.py#L82-L97\n# Given the DataPipe needs to be replaced and the expected DataPipe, return a new graph\ndef replace_dp(\n    graph: DataPipeGraph, old_datapipe: DataPipe, new_datapipe: DataPipe\n) -> DataPipeGraph:\n    r\"\"\"\n    Given the graph of DataPipe generated by ``traverse_dps`` function and the\n    DataPipe to be replaced and the new DataPipe, return the new graph of\n    DataPipe.\n    \"\"\"\n    assert len(graph) == 1\n\n    if id(old_datapipe) in graph:\n        graph = traverse_dps(new_datapipe)\n\n    final_datapipe = 
list(graph.values())[0][0]\n\n    for recv_dp, send_graph in graph.values():\n        _replace_dp(recv_dp, send_graph, old_datapipe, new_datapipe)\n\n    return traverse_dps(final_datapipe)\n\n\n# For each `recv_dp`, find if the source_datapipe needs to be replaced by the new one.\n# If found, find where the `old_dp` is located in `recv_dp` and switch it to the `new_dp`\ndef _replace_dp(\n    recv_dp, send_graph: DataPipeGraph, old_dp: DataPipe, new_dp: DataPipe\n) -> None:\n    old_dp_id = id(old_dp)\n    for send_id in send_graph:\n        if send_id == old_dp_id:\n            _assign_attr(recv_dp, old_dp, new_dp, inner_dp=True)\n        else:\n            send_dp, sub_send_graph = send_graph[send_id]\n            _replace_dp(send_dp, sub_send_graph, old_dp, new_dp)\n\n\n# Recursively re-assign datapipe for the sake of nested data structure\n# `inner_dp` is used to prevent recursive call if we have already met a `DataPipe`\ndef _assign_attr(obj, old_dp, new_dp, inner_dp: bool = False):\n    if obj is old_dp:\n        return new_dp\n    elif isinstance(obj, (IterDataPipe, MapDataPipe)):\n        # Prevent recursive call for DataPipe\n        if not inner_dp:\n            return None\n        for k in list(obj.__dict__.keys()):\n            new_obj = _assign_attr(obj.__dict__[k], old_dp, new_dp)\n            if new_obj is not None:\n                obj.__dict__[k] = new_obj\n                break\n        return None\n    elif isinstance(obj, dict):\n        for k in list(obj.keys()):\n            new_obj = _assign_attr(obj[k], old_dp, new_dp)\n            if new_obj is not None:\n                obj[k] = new_obj\n                break\n        return None\n    # Tuple is immutable, has to re-create a tuple\n    elif isinstance(obj, tuple):\n        temp_list = []\n        flag = False\n        for item in obj:\n            new_obj = _assign_attr(item, old_dp, new_dp, inner_dp)\n            if new_obj is not None:\n                flag = True\n                
temp_list.append(new_dp)\n            else:\n                temp_list.append(item)\n        if flag:\n            return tuple(temp_list)  # Special case\n        else:\n            return None\n    elif isinstance(obj, list):\n        for i in range(len(obj)):  # pylint: disable=consider-using-enumerate\n            new_obj = _assign_attr(obj[i], old_dp, new_dp, inner_dp)\n            if new_obj is not None:\n                obj[i] = new_obj\n                break\n        return None\n    elif isinstance(obj, set):\n        new_obj = None\n        for item in obj:\n            if _assign_attr(item, old_dp, new_dp, inner_dp) is not None:\n                new_obj = new_dp\n                break\n        if new_obj is not None:\n            obj.remove(old_dp)\n            obj.add(new_dp)\n        return None\n    else:\n        return None\n\n\nclass _PrefetchData:\n    def __init__(self, source_datapipe, buffer_size: int):\n        self.run_prefetcher: bool = True\n        self.prefetch_buffer: Deque = deque()\n        self.buffer_size: int = buffer_size\n        self.source_datapipe = source_datapipe\n        self.stop_iteration: bool = False\n        self.paused: bool = False\n\n\n# Copied from:\n# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/datapipes/iter/util/prefetcher.py#L34-L172\n@functional_datapipe(\"prefetch\")\nclass PrefetcherIterDataPipe(IterDataPipe):\n    r\"\"\"\n    Prefetches elements from the source DataPipe and puts them into a buffer\n    (functional name: ``prefetch``). 
Prefetching performs the operations (e.g.\n    I/O, computations) of the DataPipes up to this one ahead of time and stores\n    the result in the buffer, ready to be consumed by the subsequent DataPipe.\n    It has no effect aside from getting the sample ready ahead of time.\n\n    This is used by ``MultiProcessingReadingService`` when the arguments\n    ``worker_prefetch_cnt`` (for prefetching at each worker process) or\n    ``main_prefetch_cnt`` (for prefetching at the main loop) are greater than 0.\n\n    Beyond the built-in use cases, this can be useful to put after I/O DataPipes\n    that have expensive I/O operations (e.g. takes a long time to request a file\n    from a remote server).\n\n    Args:\n        source_datapipe: IterDataPipe from which samples are prefetched\n        buffer_size: the size of the buffer which stores the prefetched samples\n\n    Example:\n        >>> from torchdata.datapipes.iter import IterableWrapper\n        >>> dp = IterableWrapper(file_paths).open_files().prefetch(5)\n    \"\"\"\n\n    def __init__(self, source_datapipe, buffer_size: int = 10):\n        self.source_datapipe = source_datapipe\n        if buffer_size <= 0:\n            raise ValueError(\n                \"'buffer_size' is required to be a positive integer.\"\n            )\n        self.buffer_size = buffer_size\n        self.thread: Optional[threading.Thread] = None\n        self.prefetch_data: Optional[_PrefetchData] = None\n\n    @staticmethod\n    def thread_worker(\n        prefetch_data: _PrefetchData,\n    ):  # pylint: disable=missing-function-docstring\n        itr = iter(prefetch_data.source_datapipe)\n        while not prefetch_data.stop_iteration:\n            # Run if not paused\n            while prefetch_data.run_prefetcher:\n                if (\n                    len(prefetch_data.prefetch_buffer)\n                    < prefetch_data.buffer_size\n                ):\n                    try:\n                        item = next(itr)\n           
             prefetch_data.prefetch_buffer.append(item)\n                    except Exception as e:  # pylint: disable=broad-except\n                        prefetch_data.run_prefetcher = False\n                        prefetch_data.stop_iteration = True\n                        prefetch_data.prefetch_buffer.append(e)\n                else:  # Buffer is full, waiting for main thread to consume items\n                    # TODO: Calculate sleep interval based on previous consumption speed\n                    time.sleep(PRODUCER_SLEEP_INTERVAL)\n            prefetch_data.paused = True\n            # Sleep longer when this prefetcher thread is paused\n            time.sleep(PRODUCER_SLEEP_INTERVAL * 10)\n\n    def __iter__(self):\n        try:\n            prefetch_data = _PrefetchData(\n                self.source_datapipe, self.buffer_size\n            )\n            self.prefetch_data = prefetch_data\n            thread = threading.Thread(\n                target=PrefetcherIterDataPipe.thread_worker,\n                args=(prefetch_data,),\n                daemon=True,\n            )\n            thread.start()\n            self.thread = thread\n\n            while (\n                not prefetch_data.stop_iteration\n                or len(prefetch_data.prefetch_buffer) > 0\n            ):\n                if len(prefetch_data.prefetch_buffer) > 0:\n                    data = prefetch_data.prefetch_buffer.popleft()\n                    if isinstance(data, Exception):\n                        if isinstance(data, StopIteration):\n                            break\n                        raise data\n                    yield data\n                else:\n                    time.sleep(CONSUMER_SLEEP_INTERVAL)\n        finally:\n            if \"prefetch_data\" in locals():\n                prefetch_data.run_prefetcher = False\n                prefetch_data.stop_iteration = True\n                prefetch_data.paused = False\n            if \"thread\" in locals():\n    
            thread.join()\n\n    def __getstate__(self):\n        \"\"\"\n        Getting state in threading environment requires next operations:\n            1) Stopping of the producer thread.\n            2) Saving buffer.\n            3) Adding lazy restart of producer thread when __next__ is called again\n              (this will guarantee that you only change state of the source_datapipe\n               after entire state of the graph is saved).\n        \"\"\"\n        # TODO: Update __getstate__ and __setstate__ to support snapshotting and restoration\n        return {\n            \"source_datapipe\": self.source_datapipe,\n            \"buffer_size\": self.buffer_size,\n        }\n\n    def __setstate__(self, state):\n        self.source_datapipe = state[\"source_datapipe\"]\n        self.buffer_size = state[\"buffer_size\"]\n        self.thread = None\n\n    @final\n    def reset(self):  # pylint: disable=missing-function-docstring\n        self.shutdown()\n\n    def pause(self):  # pylint: disable=missing-function-docstring\n        if self.thread is not None:\n            assert self.prefetch_data is not None\n            self.prefetch_data.run_prefetcher = False\n            if self.thread.is_alive():\n                # Blocking until the thread is paused\n                while not self.prefetch_data.paused:\n                    time.sleep(PRODUCER_SLEEP_INTERVAL * 10)\n\n    @final\n    def resume(self):  # pylint: disable=missing-function-docstring\n        if (\n            self.thread is not None\n            and self.prefetch_data is not None\n            and (\n                not self.prefetch_data.stop_iteration\n                or len(self.prefetch_data.prefetch_buffer) > 0\n            )\n        ):\n            self.prefetch_data.run_prefetcher = True\n            self.prefetch_data.paused = False\n\n    @final\n    def shutdown(self):  # pylint: disable=missing-function-docstring\n        if hasattr(self, \"prefetch_data\") and 
self.prefetch_data is not None:\n            self.prefetch_data.run_prefetcher = False\n            self.prefetch_data.stop_iteration = True\n            self.prefetch_data.paused = False\n            self.prefetch_data = None\n        if hasattr(self, \"thread\") and self.thread is not None:\n            self.thread.join()\n            self.thread = None\n\n    def __del__(self):\n        self.shutdown()\n\n    def __len__(self) -> int:\n        if isinstance(self.source_datapipe, Sized):\n            return len(self.source_datapipe)\n        raise TypeError(\n            f\"{type(self).__name__} instance doesn't have valid length\"\n        )\n"
  },
  {
    "path": "python/dgl/graphbolt/datapipes/visualization.py",
    "content": "# pylint: disable=W,C,R\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n# Original source:\n# https://github.com/pytorch/data/blob/v0.7.1/torchdata/datapipes/utils/_visualization.py\n\nimport itertools\nfrom collections import defaultdict\n\nfrom typing import Optional, Set, TYPE_CHECKING\n\nfrom torch.utils.data.datapipes.iter.combining import _ChildDataPipe\n\nfrom .utils import IterDataPipe, traverse_dps\n\nif TYPE_CHECKING:\n    import graphviz\n\n\n__all__ = [\n    \"to_graph\",\n]\n\n\nclass Node:\n    def __init__(self, dp, *, name=None):\n        self.dp = dp\n        self.name = name or type(dp).__name__.replace(\"IterDataPipe\", \"\")\n        self.childs = set()\n        self.parents = set()\n\n    def add_child(self, child):\n        self.childs.add(child)\n        child.parents.add(self)\n\n    def remove_child(self, child):\n        self.childs.remove(child)\n        child.parents.remove(self)\n\n    def add_parent(self, parent):\n        self.parents.add(parent)\n        parent.childs.add(self)\n\n    def remove_parent(self, parent):\n        self.parents.remove(parent)\n        parent.childs.remove(self)\n\n    def __eq__(self, other):\n        if not isinstance(other, Node):\n            return NotImplemented\n\n        return hash(self) == hash(other)\n\n    def __hash__(self):\n        return hash(self.dp)\n\n    def __str__(self):\n        return self.name\n\n    def __repr__(self):\n        return f\"{self}-{hash(self)}\"\n\n\ndef to_nodes(dp, *, debug: bool) -> Set[Node]:\n    def recurse(dp_graph, child=None):\n        for _dp_id, (dp_node, dp_parents) in dp_graph.items():\n            node = Node(dp_node)\n            if child is not None:\n                node.add_child(child)\n            yield node\n            yield from recurse(dp_parents, child=node)\n\n    
def aggregate(nodes):\n        groups = defaultdict(list)\n        for node in nodes:\n            groups[node].append(node)\n\n        nodes = set()\n        for node, group in groups.items():\n            if len(group) == 1:\n                nodes.add(node)\n                continue\n\n            aggregated_node = Node(node.dp)\n\n            for duplicate_node in group:\n                for child in duplicate_node.childs.copy():\n                    duplicate_node.remove_child(child)\n                    aggregated_node.add_child(child)\n\n                for parent in duplicate_node.parents.copy():\n                    duplicate_node.remove_parent(parent)\n                    aggregated_node.add_parent(parent)\n\n            nodes.add(aggregated_node)\n\n        if debug:\n            return nodes\n\n        child_dp_nodes = set(\n            itertools.chain.from_iterable(\n                node.parents\n                for node in nodes\n                if isinstance(node.dp, _ChildDataPipe)\n            )\n        )\n\n        if not child_dp_nodes:\n            return nodes\n\n        for node in child_dp_nodes:\n            fixed_parent_node = Node(\n                type(\n                    str(node).lstrip(\"_\"),\n                    (IterDataPipe,),\n                    dict(dp=node.dp, childs=node.childs),\n                )()\n            )\n            nodes.remove(node)\n            nodes.add(fixed_parent_node)\n\n            for parent in node.parents.copy():\n                node.remove_parent(parent)\n                fixed_parent_node.add_parent(parent)\n\n            for child in node.childs:\n                nodes.remove(child)\n                for actual_child in child.childs.copy():\n                    actual_child.remove_parent(child)\n                    actual_child.add_parent(fixed_parent_node)\n\n        return nodes\n\n    return aggregate(recurse(traverse_dps(dp)))\n\n\ndef to_graph(dp, *, debug: bool = False) -> 
\"graphviz.Digraph\":\n    \"\"\"Visualizes a DataPipe by returning a :class:`graphviz.Digraph`, which is a graph of the data pipeline.\n    This allows you to visually inspect all the transformation that takes place in your DataPipes.\n\n    .. note::\n\n        The package :mod:`graphviz` is required to use this function.\n\n    .. note::\n\n        The most common interfaces for the returned graph object are:\n\n        - :meth:`~graphviz.Digraph.render`: Save the graph to a file.\n        - :meth:`~graphviz.Digraph.view`: Open the graph in a viewer.\n\n    Args:\n        dp: DataPipe that you would like to visualize (generally the last one in a chain of DataPipes).\n        debug (bool): If ``True``, renders internal datapipes that are usually hidden from the user\n            (such as ``ChildDataPipe`` of `demux` and `fork`). Defaults to ``False``.\n\n    Example:\n        >>> from torchdata.datapipes.iter import IterableWrapper\n        >>> from torchdata.datapipes.utils import to_graph\n        >>> dp = IterableWrapper(range(10))\n        >>> dp1, dp2 = dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)\n        >>> dp1 = dp1.map(lambda x: x + 1)\n        >>> dp2 = dp2.filter(lambda _: True)\n        >>> dp3 = dp1.zip(dp2).map(lambda t: t[0] + t[1])\n        >>> g = to_graph(dp3)\n        >>> g.view()  # This will open the graph in a viewer\n    \"\"\"\n    try:\n        import graphviz\n    except ModuleNotFoundError:\n        raise ModuleNotFoundError(\n            \"The package `graphviz` is required to be installed to use this function. 
\"\n            \"Please `pip install graphviz` or `conda install -c conda-forge graphviz`.\"\n        ) from None\n\n    # The graph style as well as the color scheme below was copied from https://github.com/szagoruyko/pytorchviz/\n    # https://github.com/szagoruyko/pytorchviz/blob/0adcd83af8aa7ab36d6afd139cabbd9df598edb7/torchviz/dot.py#L78-L85\n    node_attr = dict(\n        style=\"filled\",\n        shape=\"box\",\n        align=\"left\",\n        fontsize=\"10\",\n        ranksep=\"0.1\",\n        height=\"0.2\",\n        fontname=\"monospace\",\n    )\n    graph = graphviz.Digraph(node_attr=node_attr, graph_attr=dict(size=\"12,12\"))\n\n    for node in to_nodes(dp, debug=debug):\n        fillcolor: Optional[str]\n        if not node.parents:\n            fillcolor = \"lightblue\"\n        elif not node.childs:\n            fillcolor = \"darkolivegreen1\"\n        else:\n            fillcolor = None\n\n        graph.node(name=repr(node), label=str(node), fillcolor=fillcolor)\n\n        for child in node.childs:\n            graph.edge(repr(node), repr(child))\n\n    return graph\n"
  },
  {
    "path": "python/dgl/graphbolt/dataset.py",
    "content": "\"\"\"GraphBolt Dataset.\"\"\"\n\nfrom typing import Dict, List, Union\n\nfrom .feature_store import FeatureStore\nfrom .itemset import HeteroItemSet, ItemSet\nfrom .sampling_graph import SamplingGraph\n\n__all__ = [\n    \"Task\",\n    \"Dataset\",\n]\n\n\nclass Task:\n    \"\"\"An abstract task which consists of meta information and\n    Train/Validation/Test Set.\n\n    * meta information\n        The meta information of a task includes any kinds of data that are\n        defined by the user in YAML when instantiating the task.\n\n    * Train/Validation/Test Set\n        The train/validation/test (TVT) set which is used to train the neural\n        networks. We calculate the embeddings based on their respective features\n        and the graph structure, and then utilize the embeddings to optimize the\n        neural network parameters.\n    \"\"\"\n\n    @property\n    def metadata(self) -> Dict:\n        \"\"\"Return the task metadata.\"\"\"\n        raise NotImplementedError\n\n    @property\n    def train_set(self) -> Union[ItemSet, HeteroItemSet]:\n        \"\"\"Return the training set.\"\"\"\n        raise NotImplementedError\n\n    @property\n    def validation_set(self) -> Union[ItemSet, HeteroItemSet]:\n        \"\"\"Return the validation set.\"\"\"\n        raise NotImplementedError\n\n    @property\n    def test_set(self) -> Union[ItemSet, HeteroItemSet]:\n        \"\"\"Return the test set.\"\"\"\n        raise NotImplementedError\n\n\nclass Dataset:\n    \"\"\"An abstract dataset which provides abstraction for accessing the data\n    required for training.\n\n    The data abstraction could be a native CPU memory block, a shared memory\n    block, a file handle of an opened file on disk, a service that provides\n    the API to access the data e.t.c. There are 3 primary components in the\n    dataset:\n\n    * Task\n        A task consists of several meta information and the\n        Train/Validation/Test Set. 
A dataset could have multiple tasks.\n\n    * Feature Storage\n        A key-value store which stores node/edge/graph features.\n\n    * Graph Topology\n        Graph topology is used by the subgraph sampling algorithm to generate\n        a subgraph.\n    \"\"\"\n\n    @property\n    def tasks(self) -> List[Task]:\n        \"\"\"Return the tasks.\"\"\"\n        raise NotImplementedError\n\n    @property\n    def graph(self) -> SamplingGraph:\n        \"\"\"Return the graph.\"\"\"\n        raise NotImplementedError\n\n    @property\n    def feature(self) -> FeatureStore:\n        \"\"\"Return the feature.\"\"\"\n        raise NotImplementedError\n\n    @property\n    def dataset_name(self) -> str:\n        \"\"\"Return the dataset name.\"\"\"\n        raise NotImplementedError\n\n    @property\n    def all_nodes_set(self) -> Union[ItemSet, HeteroItemSet]:\n        \"\"\"Return the itemset containing all nodes.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "python/dgl/graphbolt/external_utils.py",
    "content": "\"\"\"Utility functions for external use.\"\"\"\nfrom functools import partial\nfrom typing import Dict, Union\n\nimport torch\n\nfrom torch.utils.data import functional_datapipe\n\nfrom .minibatch import MiniBatch\nfrom .minibatch_transformer import MiniBatchTransformer\n\n\n@functional_datapipe(\"exclude_seed_edges\")\nclass SeedEdgesExcluder(MiniBatchTransformer):\n    \"\"\"A mini-batch transformer used to exclude seed edges from the sampled\n    subgraphs of each mini-batch.\n\n    Functional name: :obj:`exclude_seed_edges`.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The datapipe.\n    include_reverse_edges : bool\n        Whether reverse edges should be excluded as well. Default is False.\n    reverse_etypes_mapping : Dict[str, str] = None\n        The mapping from the original edge types to their reverse edge types.\n    asynchronous: bool\n        Boolean indicating whether edge exclusion stages should run on\n        background threads to hide the latency of CPU GPU synchronization.\n        Should be enabled only when sampling on the GPU.\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        include_reverse_edges: bool = False,\n        reverse_etypes_mapping: Dict[str, str] = None,\n        asynchronous=False,\n    ):\n        exclude_seed_edges_fn = partial(\n            exclude_seed_edges,\n            include_reverse_edges=include_reverse_edges,\n            reverse_etypes_mapping=reverse_etypes_mapping,\n            async_op=asynchronous,\n        )\n        datapipe = datapipe.transform(exclude_seed_edges_fn)\n        if asynchronous:\n            datapipe = datapipe.buffer()\n            datapipe = datapipe.transform(self._wait_for_sampled_subgraphs)\n        super().__init__(datapipe)\n\n    @staticmethod\n    def _wait_for_sampled_subgraphs(minibatch):\n        minibatch.sampled_subgraphs = [\n            subgraph.wait() for subgraph in minibatch.sampled_subgraphs\n        ]\n        return minibatch\n\n\ndef add_reverse_edges(\n    edges: 
Union[Dict[str, torch.Tensor], torch.Tensor],\n    reverse_etypes_mapping: Dict[str, str] = None,\n):\n    r\"\"\"\n    This function finds the reverse edges of the given `edges` and returns the\n    composition of them. In a homogeneous graph, reverse edges have inverted\n    source and destination node IDs. While in a heterogeneous graph, reversing\n    also involves swapping node IDs and their types. This function could be\n    used before `exclude_edges` function to help find targeting edges.\n    Note: The found reverse edges may not actually exist in the original graph.\n    And repeated edges could be added because reverse edges may already exist in\n    the `edges`.\n\n    Parameters\n    ----------\n    edges : Union[Dict[str, torch.Tensor], torch.Tensor]\n      - If sampled subgraph is homogeneous, then `edges` should be an N*2\n        tensor.\n      - If sampled subgraph is heterogeneous, then `edges` should be a\n        dictionary of edge types and the corresponding edges to exclude.\n    reverse_etypes_mapping : Dict[str, str], optional\n        The mapping from the original edge types to their reverse edge types.\n\n    Returns\n    -------\n    Union[Dict[str, torch.Tensor], torch.Tensor]\n        The node pairs contain both the original edges and their reverse\n        counterparts.\n\n    Examples\n    --------\n    >>> edges = {\"A:r:B\": torch.tensor([[0, 1],[1, 2]])}\n    >>> print(gb.add_reverse_edges(edges, {\"A:r:B\": \"B:rr:A\"}))\n    {'A:r:B': torch.tensor([[0, 1],[1, 2]]),\n    'B:rr:A': torch.tensor([[1, 0],[2, 1]])}\n\n    >>> edges = torch.tensor([[0, 1],[1, 2]])\n    >>> print(gb.add_reverse_edges(edges))\n    torch.tensor([[0, 1],[1, 2],[1, 0],[2, 1]])\n    \"\"\"\n    if isinstance(edges, torch.Tensor):\n        assert edges.ndim == 2 and edges.shape[1] == 2, (\n            \"Only tensor with shape N*2 is supported now, but got \"\n            + f\"{edges.shape}.\"\n        )\n        reverse_edges = edges.flip(dims=(1,))\n        return 
torch.cat((edges, reverse_edges))\n    else:\n        combined_edges = edges.copy()\n        for etype, reverse_etype in reverse_etypes_mapping.items():\n            if etype in edges:\n                assert edges[etype].ndim == 2 and edges[etype].shape[1] == 2, (\n                    \"Only tensor with shape N*2 is supported now, but got \"\n                    + f\"{edges[etype].shape}.\"\n                )\n                if reverse_etype in combined_edges:\n                    combined_edges[reverse_etype] = torch.cat(\n                        (\n                            combined_edges[reverse_etype],\n                            edges[etype].flip(dims=(1,)),\n                        )\n                    )\n                else:\n                    combined_edges[reverse_etype] = edges[etype].flip(dims=(1,))\n        return combined_edges\n\n\ndef exclude_seed_edges(\n    minibatch: MiniBatch,\n    include_reverse_edges: bool = False,\n    reverse_etypes_mapping: Dict[str, str] = None,\n    async_op: bool = False,\n):\n    \"\"\"\n    Exclude seed edges with or without their reverse edges from the sampled\n    subgraphs in the minibatch.\n\n    Parameters\n    ----------\n    minibatch : MiniBatch\n        The minibatch.\n    include_reverse_edges : bool\n        Whether reverse edges should be excluded as well. Default is False.\n    reverse_etypes_mapping : Dict[str, str] = None\n        The mapping from the original edge types to their reverse edge types.\n    async_op: bool\n        Boolean indicating whether the call is asynchronous. 
If so, the result\n        can be obtained by calling wait on the modified sampled_subgraphs.\n    \"\"\"\n    edges_to_exclude = minibatch.seeds\n    if include_reverse_edges:\n        edges_to_exclude = add_reverse_edges(\n            edges_to_exclude, reverse_etypes_mapping\n        )\n    minibatch.sampled_subgraphs = [\n        subgraph.exclude_edges(edges_to_exclude, async_op=async_op)\n        for subgraph in minibatch.sampled_subgraphs\n    ]\n    return minibatch\n"
  },
  {
    "path": "python/dgl/graphbolt/feature_fetcher.py",
    "content": "\"\"\"Feature fetchers\"\"\"\n\nfrom functools import partial\nfrom typing import Dict\n\nimport torch\n\nfrom torch.utils.data import functional_datapipe\n\nfrom .base import etype_tuple_to_str\nfrom .impl.cooperative_conv import CooperativeConvFunction\n\nfrom .minibatch_transformer import MiniBatchTransformer\n\n\n__all__ = [\n    \"FeatureFetcher\",\n    \"FeatureFetcherStartMarker\",\n]\n\n\ndef get_feature_key_list(feature_keys, domain):\n    \"\"\"Processes node_feature_keys and extracts their feature keys to a list.\"\"\"\n    if isinstance(feature_keys, Dict):\n        return [\n            (domain, type_name, feature_name)\n            for type_name, feature_names in feature_keys.items()\n            for feature_name in feature_names\n        ]\n    elif feature_keys is not None:\n        return [(domain, None, feature_name) for feature_name in feature_keys]\n    else:\n        return []\n\n\n@functional_datapipe(\"mark_feature_fetcher_start\")\nclass FeatureFetcherStartMarker(MiniBatchTransformer):\n    \"\"\"Used to mark the start of a FeatureFetcher and is a no-op. 
All the\n    datapipes created during a FeatureFetcher instantiation are guaranteed to be\n    contained between FeatureFetcherStartMarker and FeatureFetcher instances in\n    the datapipe graph.\n    \"\"\"\n\n    def __init__(self, datapipe):\n        super().__init__(datapipe, self._identity)\n\n\n@functional_datapipe(\"fetch_feature\")\nclass FeatureFetcher(MiniBatchTransformer):\n    \"\"\"A feature fetcher used to fetch features for node/edge in graphbolt.\n\n    Functional name: :obj:`fetch_feature`.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The datapipe.\n    feature_store : FeatureStore\n        A storage for features that supports read and update.\n    node_feature_keys : List[str] or Dict[str, List[str]]\n        Node feature keys indicate the node features that need to be read.\n        - If `node_feature_keys` is a list: It means the graph is homogeneous\n        graph, and the 'str' inside are feature names.\n        - If `node_feature_keys` is a dictionary: The keys should be node types\n        and the values are lists of feature names.\n    edge_feature_keys : List[str] or Dict[str, List[str]]\n        Edge feature keys indicate the edge features that need to be read.\n        - If `edge_feature_keys` is a list: It means the graph is homogeneous\n        graph, and the 'str' inside are feature names.\n        - If `edge_feature_keys` is a dictionary: The keys are edge types,\n        following the format 'str:str:str', and the values are lists of\n        feature names.\n    overlap_fetch : bool, optional\n        If True, the feature fetcher will overlap the UVA feature fetcher\n        operations with the rest of operations by using an alternative CUDA\n        stream or utilizing asynchronous operations. 
Default is True.\n    cooperative: bool, optional\n        Boolean indicating whether Cooperative Minibatching, which was initially\n        proposed in\n        `Deep Graph Library PR#4337 <https://github.com/dmlc/dgl/pull/4337>`__\n        and was later first fully described in\n        `Cooperative Minibatching in Graph Neural Networks\n        <https://arxiv.org/abs/2310.12403>`__, is enabled. Cooperation between\n        the GPUs eliminates duplicate work performed across the GPUs due to the\n        overlapping sampled k-hop neighborhoods of seed nodes when performing\n        GNN minibatching. Default is False.\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        feature_store,\n        node_feature_keys=None,\n        edge_feature_keys=None,\n        overlap_fetch=True,\n        cooperative=False,\n    ):\n        datapipe = datapipe.mark_feature_fetcher_start()\n        self.feature_store = feature_store\n        self.node_feature_keys = node_feature_keys\n        self.edge_feature_keys = edge_feature_keys\n        max_val = 0\n        if overlap_fetch:\n            for feature_key_list in [\n                get_feature_key_list(node_feature_keys, \"node\"),\n                get_feature_key_list(edge_feature_keys, \"edge\"),\n            ]:\n                for feature_key in feature_key_list:\n                    if feature_key not in feature_store:\n                        continue\n                    for device_str in [\"cpu\", \"cuda\"]:\n                        try:\n                            max_val = max(\n                                feature_store[\n                                    feature_key\n                                ].read_async_num_stages(\n                                    torch.device(device_str)\n                                ),\n                                max_val,\n                            )\n                        except AssertionError:\n                            pass\n        datapipe = 
datapipe.transform(self._read)\n        for i in range(max_val, 0, -1):\n            datapipe = datapipe.transform(\n                partial(self._execute_stage, i)\n            ).buffer(1)\n        if max_val > 0:\n            datapipe = datapipe.transform(self._final_stage)\n        if cooperative:\n            datapipe = datapipe.transform(self._cooperative_exchange)\n            datapipe = datapipe.buffer()\n        super().__init__(datapipe)\n        # A positive value indicates that the overlap optimization is enabled.\n        self.max_num_stages = max_val\n\n    @staticmethod\n    def _execute_stage(current_stage, data):\n        all_features = [data.node_features] + [\n            data.edge_features[i] for i in range(data.num_layers())\n        ]\n        for features in all_features:\n            for key in features:\n                handle, stage = features[key]\n                assert current_stage >= stage\n                if current_stage == stage:\n                    value = next(handle)\n                    features[key] = (handle if stage > 1 else value, stage - 1)\n        return data\n\n    @staticmethod\n    def _final_stage(data):\n        all_features = [data.node_features] + [\n            data.edge_features[i] for i in range(data.num_layers())\n        ]\n        for features in all_features:\n            for key in features:\n                value, stage = features[key]\n                assert stage == 0\n                features[key] = value.wait()\n        return data\n\n    def _cooperative_exchange(self, data):\n        subgraph = data.sampled_subgraphs[0]\n        is_heterogeneous = isinstance(\n            self.node_feature_keys, Dict\n        ) or isinstance(self.edge_feature_keys, Dict)\n        if is_heterogeneous:\n            node_features = {key: {} for key, _ in data.node_features.keys()}\n            for (key, ntype), feature in data.node_features.items():\n                node_features[key][ntype] = feature\n            for 
key, feature in node_features.items():\n                new_feature = CooperativeConvFunction.apply(subgraph, feature)\n                for ntype, tensor in new_feature.items():\n                    data.node_features[(key, ntype)] = tensor\n        else:\n            for key in data.node_features:\n                feature = data.node_features[key]\n                new_feature = CooperativeConvFunction.apply(subgraph, feature)\n                data.node_features[key] = new_feature\n        return data\n\n    def _read(self, data):\n        \"\"\"\n        Fill in the node/edge features field in data.\n\n        Parameters\n        ----------\n        data : MiniBatch\n            An instance of :class:`MiniBatch`. Even if 'node_feature' or\n            'edge_feature' is already filled, it will be overwritten for\n            overlapping features.\n\n        Returns\n        -------\n        MiniBatch\n            An instance of :class:`MiniBatch` filled with required features.\n        \"\"\"\n        node_features = {}\n        num_layers = data.num_layers()\n        edge_features = [{} for _ in range(num_layers)]\n        is_heterogeneous = isinstance(\n            self.node_feature_keys, Dict\n        ) or isinstance(self.edge_feature_keys, Dict)\n        # Read Node features.\n        input_nodes = data.node_ids()\n\n        def read_helper(feature_key, index):\n            if self.max_num_stages > 0:\n                feature = self.feature_store[feature_key]\n                num_stages = feature.read_async_num_stages(index.device)\n                if num_stages > 0:\n                    return (feature.read_async(index), num_stages)\n                else:  # Asynchronicity is not needed, compute in _final_stage.\n\n                    class _Waiter:\n                        def __init__(self, feature, index):\n                            self.feature = feature\n                            self.index = index\n\n                        def wait(self):\n          
                  \"\"\"Returns the stored value when invoked.\"\"\"\n                            result = self.feature.read(self.index)\n                            # Ensure there is no memory leak.\n                            self.feature = self.index = None\n                            return result\n\n                    return (_Waiter(feature, index), 0)\n            else:\n                domain, type_name, feature_name = feature_key\n                return self.feature_store.read(\n                    domain, type_name, feature_name, index\n                )\n\n        if self.node_feature_keys and input_nodes is not None:\n            if is_heterogeneous:\n                for type_name, nodes in input_nodes.items():\n                    if type_name not in self.node_feature_keys or nodes is None:\n                        continue\n                    for feature_name in self.node_feature_keys[type_name]:\n                        node_features[(type_name, feature_name)] = read_helper(\n                            (\"node\", type_name, feature_name), nodes\n                        )\n            else:\n                for feature_name in self.node_feature_keys:\n                    node_features[feature_name] = read_helper(\n                        (\"node\", None, feature_name), input_nodes\n                    )\n        # Read Edge features.\n        if self.edge_feature_keys and num_layers > 0:\n            for i in range(num_layers):\n                original_edge_ids = data.edge_ids(i)\n                if is_heterogeneous:\n                    # Convert edge type to string.\n                    original_edge_ids = {\n                        (\n                            etype_tuple_to_str(key)\n                            if isinstance(key, tuple)\n                            else key\n                        ): value\n                        for key, value in original_edge_ids.items()\n                    }\n                    for type_name, edges 
in original_edge_ids.items():\n                        if (\n                            type_name not in self.edge_feature_keys\n                            or edges is None\n                        ):\n                            continue\n                        for feature_name in self.edge_feature_keys[type_name]:\n                            edge_features[i][\n                                (type_name, feature_name)\n                            ] = read_helper(\n                                (\"edge\", type_name, feature_name), edges\n                            )\n                else:\n                    for feature_name in self.edge_feature_keys:\n                        edge_features[i][feature_name] = read_helper(\n                            (\"edge\", None, feature_name), original_edge_ids\n                        )\n        data.set_node_features(node_features)\n        data.set_edge_features(edge_features)\n        return data\n"
  },
  {
    "path": "python/dgl/graphbolt/feature_store.py",
    "content": "\"\"\"Feature store for GraphBolt.\"\"\"\n\nfrom typing import Dict, NamedTuple, Union\n\nimport torch\n\n__all__ = [\n    \"bytes_to_number_of_items\",\n    \"Feature\",\n    \"FeatureStore\",\n    \"FeatureKey\",\n    \"wrap_with_cached_feature\",\n]\n\n\nclass FeatureKey(NamedTuple):\n    \"\"\"A named tuple class to represent feature keys in FeatureStore classes.\n    The fields are domain, type and name all of which take string values.\n    \"\"\"\n\n    domain: str\n    type: str\n    name: int\n\n\nclass Feature:\n    r\"\"\"A wrapper of feature data for access.\"\"\"\n\n    def __init__(self):\n        pass\n\n    def read(self, ids: torch.Tensor = None):\n        \"\"\"Read from the feature.\n\n        Parameters\n        ----------\n        ids : torch.Tensor, optional\n            The index of the feature. If specified, only the specified indices\n            of the feature are read. If None, the entire feature is returned.\n        Returns\n        -------\n        torch.Tensor\n            The read feature.\n        \"\"\"\n        raise NotImplementedError\n\n    def read_async(self, ids: torch.Tensor):\n        \"\"\"Read the feature by index asynchronously.\n\n        Parameters\n        ----------\n        ids : torch.Tensor\n            The index of the feature. Only the specified indices of the\n            feature are read.\n        Returns\n        -------\n        A generator object.\n            The returned generator object returns a future on\n            `read_async_num_stages(ids.device)`th invocation. The return result\n            can be accessed by calling `.wait()`. on the returned future object.\n            It is undefined behavior to call `.wait()` more than once.\n\n        Example Usage\n        --------\n        >>> import dgl.graphbolt as gb\n        >>> feature = gb.Feature(...)\n        >>> ids = torch.tensor([0, 2])\n        >>> for stage, future in enumerate(feature.read_async(ids)):\n        ...     
pass\n        >>> assert stage + 1 == feature.read_async_num_stages(ids.device)\n        >>> result = future.wait()  # result contains the read values.\n        \"\"\"\n        raise NotImplementedError\n\n    def read_async_num_stages(self, ids_device: torch.device):\n        \"\"\"The number of stages of the read_async operation. See read_async\n        function for directions on its use. This function is required to return\n        the number of yield operations when read_async is used with a tensor\n        residing on ids_device.\n\n        Parameters\n        ----------\n        ids_device : torch.device\n            The device of the ids parameter passed into read_async.\n        Returns\n        -------\n        int\n            The number of stages of the read_async operation.\n        \"\"\"\n        raise NotImplementedError\n\n    def size(self):\n        \"\"\"Get the size of the feature.\n\n        Returns\n        -------\n        torch.Size\n            The size of the feature.\n        \"\"\"\n        raise NotImplementedError\n\n    def count(self):\n        \"\"\"Get the count of the feature.\n\n        Returns\n        -------\n        int\n            The count of the feature.\n        \"\"\"\n        raise NotImplementedError\n\n    def update(self, value: torch.Tensor, ids: torch.Tensor = None):\n        \"\"\"Update the feature.\n\n        Parameters\n        ----------\n        value : torch.Tensor\n            The updated value of the feature.\n        ids : torch.Tensor, optional\n            The indices of the feature to update. If specified, only the\n            specified indices of the feature will be updated. For the feature,\n            the `ids[i]` row is updated to `value[i]`. So the indices and value\n            must have the same length. 
If None, the entire feature will be\n            updated.\n        \"\"\"\n        raise NotImplementedError\n\n    def metadata(self):\n        \"\"\"Get the metadata of the feature.\n\n        Returns\n        -------\n        Dict\n            The metadata of the feature.\n        \"\"\"\n        return {}\n\n\nclass FeatureStore:\n    r\"\"\"A store to manage multiple features for access.\"\"\"\n\n    def __init__(self):\n        pass\n\n    def __getitem__(self, feature_key: FeatureKey) -> Feature:\n        \"\"\"Access the underlying `Feature` with its (domain, type, name) as\n        the feature_key.\n        \"\"\"\n        raise NotImplementedError\n\n    def __setitem__(self, feature_key: FeatureKey, feature: Feature):\n        \"\"\"Set the underlying `Feature` with its (domain, type, name) as\n        the feature_key and feature as the value.\n        \"\"\"\n        raise NotImplementedError\n\n    def __contains__(self, feature_key: FeatureKey) -> bool:\n        \"\"\"Checks whether the provided (domain, type, name) as the feature_key\n        is contained in the FeatureStore.\"\"\"\n        raise NotImplementedError\n\n    def read(\n        self,\n        domain: str,\n        type_name: str,\n        feature_name: str,\n        ids: torch.Tensor = None,\n    ):\n        \"\"\"Read from the feature store.\n\n        Parameters\n        ----------\n        domain : str\n            The domain of the feature such as \"node\", \"edge\" or \"graph\".\n        type_name : str\n            The node or edge type name.\n        feature_name : str\n            The feature name.\n        ids : torch.Tensor, optional\n            The index of the feature. If specified, only the specified indices\n            of the feature are read. 
If None, the entire feature is returned.\n\n        Returns\n        -------\n        torch.Tensor\n            The read feature.\n        \"\"\"\n        return self.__getitem__((domain, type_name, feature_name)).read(ids)\n\n    def size(\n        self,\n        domain: str,\n        type_name: str,\n        feature_name: str,\n    ):\n        \"\"\"Get the size of the specified feature in the feature store.\n\n        Parameters\n        ----------\n        domain : str\n            The domain of the feature such as \"node\", \"edge\" or \"graph\".\n        type_name : str\n            The node or edge type name.\n        feature_name : str\n            The feature name.\n        Returns\n        -------\n        torch.Size\n            The size of the specified feature in the feature store.\n        \"\"\"\n        return self.__getitem__((domain, type_name, feature_name)).size()\n\n    def count(\n        self,\n        domain: str,\n        type_name: str,\n        feature_name: str,\n    ):\n        \"\"\"Get the count of the specified feature in the feature store.\n\n        Parameters\n        ----------\n        domain : str\n            The domain of the feature such as \"node\", \"edge\" or \"graph\".\n        type_name : str\n            The node or edge type name.\n        feature_name : str\n            The feature name.\n        Returns\n        -------\n        int\n            The count of the specified feature in the feature store.\n        \"\"\"\n        return self.__getitem__((domain, type_name, feature_name)).count()\n\n    def metadata(\n        self,\n        domain: str,\n        type_name: str,\n        feature_name: str,\n    ):\n        \"\"\"Get the metadata of the specified feature in the feature store.\n\n        Parameters\n        ----------\n        domain : str\n            The domain of the feature such as \"node\", \"edge\" or \"graph\".\n        type_name : str\n            The node or edge type name.\n        feature_name : 
str\n            The feature name.\n        Returns\n        -------\n        Dict\n            The metadata of the feature.\n        \"\"\"\n        return self.__getitem__((domain, type_name, feature_name)).metadata()\n\n    def update(\n        self,\n        domain: str,\n        type_name: str,\n        feature_name: str,\n        value: torch.Tensor,\n        ids: torch.Tensor = None,\n    ):\n        \"\"\"Update the feature store.\n\n        Parameters\n        ----------\n        domain : str\n            The domain of the feature such as \"node\", \"edge\" or \"graph\".\n        type_name : str\n            The node or edge type name.\n        feature_name : str\n            The feature name.\n        value : torch.Tensor\n            The updated value of the feature.\n        ids : torch.Tensor, optional\n            The indices of the feature to update. If specified, only the\n            specified indices of the feature will be updated. For the feature,\n            the `ids[i]` row is updated to `value[i]`. So the indices and value\n            must have the same length. If None, the entire feature will be\n            updated.\n        \"\"\"\n        self.__getitem__((domain, type_name, feature_name)).update(value, ids)\n\n    def keys(self):\n        \"\"\"Get the keys of the features.\n\n        Returns\n        -------\n        List[tuple]\n            The keys of the features. 
The tuples are in `(domain, type_name,\n            feat_name)` format.\n        \"\"\"\n        raise NotImplementedError\n\n\ndef bytes_to_number_of_items(cache_capacity_in_bytes, single_item):\n    \"\"\"Returns the number of rows to be cached.\"\"\"\n    item_bytes = single_item.nbytes\n    # Round up so that we never get a size of 0, unless bytes is 0.\n    return (cache_capacity_in_bytes + item_bytes - 1) // item_bytes\n\n\ndef wrap_with_cached_feature(\n    cached_feature_type,\n    fallback_features: Union[Feature, Dict[FeatureKey, Feature]],\n    max_cache_size_in_bytes: int,\n    *args,\n    **kwargs,\n) -> Union[Feature, Dict[FeatureKey, Feature]]:\n    \"\"\"Wraps the given features with the given cached feature type using\n    a single cache instance.\"\"\"\n    if not isinstance(fallback_features, dict):\n        assert isinstance(fallback_features, Feature)\n        return wrap_with_cached_feature(\n            cached_feature_type,\n            {\"a\": fallback_features},\n            max_cache_size_in_bytes,\n            *args,\n            **kwargs,\n        )[\"a\"]\n    row_bytes = None\n    cache = None\n    wrapped_features = {}\n    offset = 0\n    for feature_key, fallback_feature in fallback_features.items():\n        # Fetching the feature dimension from the underlying feature.\n        feat0 = fallback_feature.read(torch.tensor([0]))\n        if row_bytes is None:\n            row_bytes = feat0.nbytes\n        else:\n            assert (\n                row_bytes == feat0.nbytes\n            ), \"The # bytes of a single row of the features should match.\"\n        cache_size = bytes_to_number_of_items(max_cache_size_in_bytes, feat0)\n        if cache is None:\n            cache = cached_feature_type._cache_type(\n                cache_shape=(cache_size,) + feat0.shape[1:],\n                dtype=feat0.dtype,\n                *args,\n                **kwargs,\n            )\n        wrapped_features[feature_key] = cached_feature_type(\n    
        fallback_feature, cache=cache, offset=offset\n        )\n        offset += fallback_feature.count()\n\n    return wrapped_features\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/__init__.py",
    "content": "\"\"\"Implementation of GraphBolt.\"\"\"\nfrom .basic_feature_store import *\nfrom .fused_csc_sampling_graph import *\nfrom .gpu_feature_cache import *\nfrom .gpu_cached_feature import *\nfrom .in_subgraph_sampler import *\nfrom .legacy_dataset import *\nfrom .neighbor_sampler import *\nfrom .temporal_neighbor_sampler import *\nfrom .ondisk_dataset import *\nfrom .ondisk_metadata import *\nfrom .sampled_subgraph_impl import *\nfrom .torch_based_feature_store import *\nfrom .uniform_negative_sampler import *\nfrom .gpu_graph_cache import *\nfrom .cpu_feature_cache import *\nfrom .cpu_cached_feature import *\nfrom .cooperative_conv import *\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/basic_feature_store.py",
    "content": "\"\"\"Basic feature store for GraphBolt.\"\"\"\n\nfrom typing import Dict, Tuple\n\nfrom ..feature_store import Feature, FeatureKey, FeatureStore\n\n__all__ = [\"BasicFeatureStore\"]\n\n\nclass BasicFeatureStore(FeatureStore):\n    r\"\"\"A basic feature store to manage multiple features for access.\"\"\"\n\n    def __init__(self, features: Dict[Tuple[str, str, str], Feature]):\n        r\"\"\"Initiate a basic feature store.\n\n\n        Parameters\n        ----------\n        features : Dict[Tuple[str, str, str], Feature]\n            The dict of features served by the feature store, in which the key\n            is a tuple of (domain, type_name, feature_name).\n\n        Returns\n        -------\n        The feature stores.\n        \"\"\"\n        super().__init__()\n        self._features = features\n\n    def __getitem__(self, feature_key: FeatureKey) -> Feature:\n        \"\"\"Access the underlying `Feature` with its (domain, type, name) as\n        the feature_key.\n        \"\"\"\n        return self._features[feature_key]\n\n    def __setitem__(self, feature_key: FeatureKey, feature: Feature):\n        \"\"\"Set the underlying `Feature` with its (domain, type, name) as\n        the feature_key and feature as the value.\n        \"\"\"\n        self._features[feature_key] = feature\n\n    def __contains__(self, feature_key: FeatureKey) -> bool:\n        \"\"\"Checks whether the provided (domain, type, name) as the feature_key\n        is contained in the BasicFeatureStore.\"\"\"\n        return feature_key in self._features\n\n    def __len__(self):\n        \"\"\"Return the number of features.\"\"\"\n        return len(self._features)\n\n    def keys(self):\n        \"\"\"Get the keys of the features.\n\n        Returns\n        -------\n        List[tuple]\n            The keys of the features. The tuples are in `(domain, type_name,\n            feat_name)` format.\n        \"\"\"\n        return list(self._features.keys())\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/cooperative_conv.py",
    "content": "\"\"\"Graphbolt cooperative convolution.\"\"\"\nfrom typing import Dict, Union\n\nimport torch\n\nfrom ..sampled_subgraph import SampledSubgraph\nfrom ..subgraph_sampler import all_to_all, convert_to_hetero, revert_to_homo\n\n__all__ = [\"CooperativeConvFunction\", \"CooperativeConv\"]\n\n\nclass CooperativeConvFunction(torch.autograd.Function):\n    \"\"\"Cooperative convolution operation from Cooperative Minibatching.\n\n    Implements the `all-to-all` message passing algorithm\n    in Cooperative Minibatching, which was initially proposed in\n    `Deep Graph Library PR#4337<https://github.com/dmlc/dgl/pull/4337>`__ and\n    was later first fully described in\n    `Cooperative Minibatching in Graph Neural Networks\n    <https://arxiv.org/abs/2310.12403>`__.\n    Cooperation between the GPUs eliminates duplicate work performed across the\n    GPUs due to the overlapping sampled k-hop neighborhoods of seed nodes when\n    performing GNN minibatching. This reduces the redundant computations across\n    GPUs at the expense of communication.\n    \"\"\"\n\n    @staticmethod\n    def forward(\n        ctx,\n        subgraph: SampledSubgraph,\n        tensor: Union[torch.Tensor, Dict[str, torch.Tensor]],\n    ):\n        \"\"\"Implements the forward pass.\"\"\"\n        counts_sent = convert_to_hetero(subgraph._counts_sent)\n        counts_received = convert_to_hetero(subgraph._counts_received)\n        seed_inverse_ids = convert_to_hetero(subgraph._seed_inverse_ids)\n        seed_sizes = convert_to_hetero(subgraph._seed_sizes)\n        ctx.communication_variables = (\n            counts_sent,\n            counts_received,\n            seed_inverse_ids,\n            seed_sizes,\n        )\n        outs = {}\n        for ntype, typed_tensor in convert_to_hetero(tensor).items():\n            out = typed_tensor.new_empty(\n                (sum(counts_sent[ntype]),) + typed_tensor.shape[1:]\n            )\n            all_to_all(\n                
torch.split(out, counts_sent[ntype]),\n                torch.split(\n                    typed_tensor[seed_inverse_ids[ntype]],\n                    counts_received[ntype],\n                ),\n            )\n            outs[ntype] = out\n        return revert_to_homo(out)\n\n    @staticmethod\n    def backward(\n        ctx, grad_output: Union[torch.Tensor, Dict[str, torch.Tensor]]\n    ):\n        \"\"\"Implements the backward pass.\"\"\"\n        (\n            counts_sent,\n            counts_received,\n            seed_inverse_ids,\n            seed_sizes,\n        ) = ctx.communication_variables\n        delattr(ctx, \"communication_variables\")\n        outs = {}\n        for ntype, typed_grad_output in convert_to_hetero(grad_output).items():\n            out = typed_grad_output.new_empty(\n                (sum(counts_received[ntype]),) + typed_grad_output.shape[1:]\n            )\n            all_to_all(\n                torch.split(out, counts_received[ntype]),\n                torch.split(typed_grad_output, counts_sent[ntype]),\n            )\n            i = out.new_empty(2, out.shape[0], dtype=torch.int64)\n            i[0] = seed_inverse_ids[ntype]  # src\n            i[1] = torch.arange(\n                out.shape[0], device=typed_grad_output.device\n            )  # dst\n            coo = torch.sparse_coo_tensor(\n                i,\n                torch.ones(\n                    i.shape[1], dtype=grad_output.dtype, device=i.device\n                ),\n                size=(seed_sizes[ntype], i.shape[1]),\n            )\n            outs[ntype] = torch.sparse.mm(coo, out)\n        return None, revert_to_homo(outs)\n\n\nclass CooperativeConv(torch.nn.Module):\n    \"\"\"Cooperative convolution operation from Cooperative Minibatching.\n\n    Implements the `all-to-all` message passing algorithm\n    in Cooperative Minibatching, which was initially proposed in\n    `Deep Graph Library PR#4337<https://github.com/dmlc/dgl/pull/4337>`__ and\n    was 
later first fully described in\n    `Cooperative Minibatching in Graph Neural Networks\n    <https://arxiv.org/abs/2310.12403>`__.\n    Cooperation between the GPUs eliminates duplicate work performed across the\n    GPUs due to the overlapping sampled k-hop neighborhoods of seed nodes when\n    performing GNN minibatching. This reduces the redundant computations across\n    GPUs at the expense of communication.\n    \"\"\"\n\n    def forward(\n        self,\n        subgraph: SampledSubgraph,\n        x: Union[torch.Tensor, Dict[str, torch.Tensor]],\n    ):\n        \"\"\"Implements the forward pass.\"\"\"\n        return CooperativeConvFunction.apply(subgraph, x)\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/cpu_cached_feature.py",
    "content": "\"\"\"CPU cached feature for GraphBolt.\"\"\"\nfrom typing import Dict, Optional, Union\n\nimport torch\n\nfrom ..base import get_device_to_host_uva_stream, get_host_to_device_uva_stream\nfrom ..feature_store import (\n    bytes_to_number_of_items,\n    Feature,\n    FeatureKey,\n    wrap_with_cached_feature,\n)\n\nfrom .cpu_feature_cache import CPUFeatureCache\n\n__all__ = [\"CPUCachedFeature\", \"cpu_cached_feature\"]\n\n\nclass CPUCachedFeature(Feature):\n    r\"\"\"CPU cached feature wrapping a fallback feature. Use `cpu_cached_feature`\n    to construct an instance of this class.\n\n    Parameters\n    ----------\n    fallback_feature : Feature\n        The fallback feature.\n    cache : CPUFeatureCache\n        A CPUFeatureCache instance to serve as the cache backend.\n    offset : int, optional\n        The offset value to add to the given ids before using the cache. This\n        parameter is useful if multiple `CPUCachedFeature`s are sharing a single\n        CPUFeatureCache object.\n    \"\"\"\n\n    _cache_type = CPUFeatureCache\n\n    def __init__(\n        self,\n        fallback_feature: Feature,\n        cache: CPUFeatureCache,\n        offset: int = 0,\n    ):\n        super(CPUCachedFeature, self).__init__()\n        assert isinstance(fallback_feature, Feature), (\n            f\"The fallback_feature must be an instance of Feature, but got \"\n            f\"{type(fallback_feature)}.\"\n        )\n        self._fallback_feature = fallback_feature\n        self._feature = cache\n        self._offset = offset\n\n    def read(self, ids: torch.Tensor = None):\n        \"\"\"Read the feature by index.\n\n        Parameters\n        ----------\n        ids : torch.Tensor, optional\n            The index of the feature. If specified, only the specified indices\n            of the feature are read. 
If None, the entire feature is returned.\n\n        Returns\n        -------\n        torch.Tensor\n            The read feature.\n        \"\"\"\n        if ids is None:\n            return self._fallback_feature.read()\n        return self._feature.query_and_replace(\n            ids.cpu(), self._fallback_feature.read, self._offset\n        ).to(ids.device)\n\n    def read_async(self, ids: torch.Tensor):\n        r\"\"\"Read the feature by index asynchronously.\n\n        Parameters\n        ----------\n        ids : torch.Tensor\n            The index of the feature. Only the specified indices of the\n            feature are read.\n        Returns\n        -------\n        A generator object.\n            The returned generator object returns a future on\n            ``read_async_num_stages(ids.device)``\\ th invocation. The return result\n            can be accessed by calling ``.wait()``. on the returned future object.\n            It is undefined behavior to call ``.wait()`` more than once.\n\n        Examples\n        --------\n        >>> import dgl.graphbolt as gb\n        >>> feature = gb.Feature(...)\n        >>> ids = torch.tensor([0, 2])\n        >>> for stage, future in enumerate(feature.read_async(ids)):\n        ...     
pass\n        >>> assert stage + 1 == feature.read_async_num_stages(ids.device)\n        >>> result = future.wait()  # result contains the read values.\n        \"\"\"\n        policy = self._feature._policy\n        cache = self._feature._cache\n        if ids.is_cuda and self.is_pinned():\n            ids_device = ids.device\n            current_stream = torch.cuda.current_stream()\n            device_to_host_stream = get_device_to_host_uva_stream()\n            device_to_host_stream.wait_stream(current_stream)\n            with torch.cuda.stream(device_to_host_stream):\n                ids.record_stream(torch.cuda.current_stream())\n                ids = ids.to(\"cpu\", non_blocking=True)\n                ids_copy_event = torch.cuda.Event()\n                ids_copy_event.record()\n\n            yield  # first stage is done.\n\n            ids_copy_event.synchronize()\n            policy_future = policy.query_and_replace_async(ids, self._offset)\n\n            yield\n\n            (\n                positions,\n                index,\n                pointers,\n                missing_keys,\n                found_offsets,\n                missing_offsets,\n            ) = policy_future.wait()\n            self._feature.total_queries += ids.shape[0]\n            self._feature.total_miss += missing_keys.shape[0]\n            found_cnt = ids.size(0) - missing_keys.size(0)\n            found_positions = positions[:found_cnt]\n            missing_positions = positions[found_cnt:]\n            found_pointers = pointers[:found_cnt]\n            missing_pointers = pointers[found_cnt:]\n            host_to_device_stream = get_host_to_device_uva_stream()\n            with torch.cuda.stream(host_to_device_stream):\n                found_positions = found_positions.to(\n                    ids_device, non_blocking=True\n                )\n                values_from_cpu = cache.index_select(found_positions)\n                values_from_cpu.record_stream(current_stream)\n    
            values_from_cpu_copy_event = torch.cuda.Event()\n                values_from_cpu_copy_event.record()\n\n            fallback_reader = self._fallback_feature.read_async(missing_keys)\n            for _ in range(\n                self._fallback_feature.read_async_num_stages(\n                    missing_keys.device\n                )\n            ):\n                missing_values_future = next(fallback_reader, None)\n                yield  # fallback feature stages.\n\n            values_from_cpu_copy_event.synchronize()\n            reading_completed = policy.reading_completed_async(\n                found_pointers, found_offsets\n            )\n\n            missing_values = missing_values_future.wait()\n            replace_future = cache.replace_async(\n                missing_positions, missing_values\n            )\n\n            host_to_device_stream = get_host_to_device_uva_stream()\n            with torch.cuda.stream(host_to_device_stream):\n                index = index.to(ids_device, non_blocking=True)\n                missing_values = missing_values.to(\n                    ids_device, non_blocking=True\n                )\n                index.record_stream(current_stream)\n                missing_values.record_stream(current_stream)\n                missing_values_copy_event = torch.cuda.Event()\n                missing_values_copy_event.record()\n\n            yield\n\n            reading_completed.wait()\n            replace_future.wait()\n            writing_completed = policy.writing_completed_async(\n                missing_pointers, missing_offsets\n            )\n\n            class _Waiter:\n                def __init__(self, events, existing, missing, index):\n                    self.events = events\n                    self.existing = existing\n                    self.missing = missing\n                    self.index = index\n\n                def wait(self):\n                    \"\"\"Returns the stored value when 
invoked.\"\"\"\n                    for event in self.events:\n                        event.wait()\n                    values = torch.empty(\n                        (self.index.shape[0],) + self.missing.shape[1:],\n                        dtype=self.missing.dtype,\n                        device=ids_device,\n                    )\n                    num_found = self.existing.size(0)\n                    found_index = self.index[:num_found]\n                    missing_index = self.index[num_found:]\n                    values[found_index] = self.existing\n                    values[missing_index] = self.missing\n                    # Ensure there is no memory leak.\n                    self.events = self.existing = None\n                    self.missing = self.index = None\n                    return values\n\n            yield _Waiter(\n                [\n                    writing_completed,\n                    values_from_cpu_copy_event,\n                    missing_values_copy_event,\n                ],\n                values_from_cpu,\n                missing_values,\n                index,\n            )\n        elif ids.is_cuda:\n            ids_device = ids.device\n            current_stream = torch.cuda.current_stream()\n            device_to_host_stream = get_device_to_host_uva_stream()\n            device_to_host_stream.wait_stream(current_stream)\n            with torch.cuda.stream(device_to_host_stream):\n                ids.record_stream(torch.cuda.current_stream())\n                ids = ids.to(\"cpu\", non_blocking=True)\n                ids_copy_event = torch.cuda.Event()\n                ids_copy_event.record()\n\n            yield  # first stage is done.\n\n            ids_copy_event.synchronize()\n            policy_future = policy.query_and_replace_async(ids, self._offset)\n\n            yield\n\n            (\n                positions,\n                index,\n                pointers,\n                missing_keys,\n                
found_offsets,\n                missing_offsets,\n            ) = policy_future.wait()\n            self._feature.total_queries += ids.shape[0]\n            self._feature.total_miss += missing_keys.shape[0]\n            found_cnt = ids.size(0) - missing_keys.size(0)\n            found_positions = positions[:found_cnt]\n            missing_positions = positions[found_cnt:]\n            found_pointers = pointers[:found_cnt]\n            missing_pointers = pointers[found_cnt:]\n            values_future = cache.query_async(\n                found_positions, index, ids.shape[0]\n            )\n\n            fallback_reader = self._fallback_feature.read_async(missing_keys)\n            for _ in range(\n                self._fallback_feature.read_async_num_stages(\n                    missing_keys.device\n                )\n            ):\n                missing_values_future = next(fallback_reader, None)\n                yield  # fallback feature stages.\n\n            values = values_future.wait()\n            reading_completed = policy.reading_completed_async(\n                found_pointers, found_offsets\n            )\n\n            missing_index = index[found_cnt:]\n\n            missing_values = missing_values_future.wait()\n            replace_future = cache.replace_async(\n                missing_positions, missing_values\n            )\n            values = torch.ops.graphbolt.scatter_async(\n                values, missing_index, missing_values\n            )\n\n            yield\n\n            host_to_device_stream = get_host_to_device_uva_stream()\n            with torch.cuda.stream(host_to_device_stream):\n                values = values.wait().to(ids_device, non_blocking=True)\n                values.record_stream(current_stream)\n                values_copy_event = torch.cuda.Event()\n                values_copy_event.record()\n\n            reading_completed.wait()\n            replace_future.wait()\n            writing_completed = 
policy.writing_completed_async(\n                missing_pointers, missing_offsets\n            )\n\n            class _Waiter:\n                def __init__(self, events, values):\n                    self.events = events\n                    self.values = values\n\n                def wait(self):\n                    \"\"\"Returns the stored value when invoked.\"\"\"\n                    for event in self.events:\n                        event.wait()\n                    values = self.values\n                    # Ensure there is no memory leak.\n                    self.events = self.values = None\n                    return values\n\n            yield _Waiter([values_copy_event, writing_completed], values)\n        else:\n            policy_future = policy.query_and_replace_async(ids, self._offset)\n\n            yield\n\n            (\n                positions,\n                index,\n                pointers,\n                missing_keys,\n                found_offsets,\n                missing_offsets,\n            ) = policy_future.wait()\n            self._feature.total_queries += ids.shape[0]\n            self._feature.total_miss += missing_keys.shape[0]\n            found_cnt = ids.size(0) - missing_keys.size(0)\n            found_positions = positions[:found_cnt]\n            missing_positions = positions[found_cnt:]\n            found_pointers = pointers[:found_cnt]\n            missing_pointers = pointers[found_cnt:]\n            values_future = cache.query_async(\n                found_positions, index, ids.shape[0]\n            )\n\n            fallback_reader = self._fallback_feature.read_async(missing_keys)\n            for _ in range(\n                self._fallback_feature.read_async_num_stages(\n                    missing_keys.device\n                )\n            ):\n                missing_values_future = next(fallback_reader, None)\n                yield  # fallback feature stages.\n\n            values = values_future.wait()\n          
  reading_completed = policy.reading_completed_async(\n                found_pointers, found_offsets\n            )\n\n            missing_index = index[found_cnt:]\n\n            missing_values = missing_values_future.wait()\n            replace_future = cache.replace_async(\n                missing_positions, missing_values\n            )\n            values = torch.ops.graphbolt.scatter_async(\n                values, missing_index, missing_values\n            )\n\n            yield\n\n            reading_completed.wait()\n            replace_future.wait()\n            writing_completed = policy.writing_completed_async(\n                missing_pointers, missing_offsets\n            )\n\n            class _Waiter:\n                def __init__(self, event, values):\n                    self.event = event\n                    self.values = values\n\n                def wait(self):\n                    \"\"\"Returns the stored value when invoked.\"\"\"\n                    self.event.wait()\n                    values = self.values.wait()\n                    # Ensure there is no memory leak.\n                    self.event = self.values = None\n                    return values\n\n            yield _Waiter(writing_completed, values)\n\n    def read_async_num_stages(self, ids_device: torch.device):\n        \"\"\"The number of stages of the read_async operation. See read_async\n        function for directions on its use. 
This function is required to return\n        the number of yield operations when read_async is used with a tensor\n        residing on ids_device.\n\n        Parameters\n        ----------\n        ids_device : torch.device\n            The device of the ids parameter passed into read_async.\n        Returns\n        -------\n        int\n            The number of stages of the read_async operation.\n        \"\"\"\n        if ids_device.type == \"cuda\":\n            return 4 + self._fallback_feature.read_async_num_stages(\n                torch.device(\"cpu\")\n            )\n        else:\n            return 3 + self._fallback_feature.read_async_num_stages(ids_device)\n\n    def size(self):\n        \"\"\"Get the size of the feature.\n\n        Returns\n        -------\n        torch.Size\n            The size of the feature.\n        \"\"\"\n        return self._fallback_feature.size()\n\n    def count(self):\n        \"\"\"Get the count of the feature.\n\n        Returns\n        -------\n        int\n            The count of the feature.\n        \"\"\"\n        return self._fallback_feature.count()\n\n    def update(self, value: torch.Tensor, ids: torch.Tensor = None):\n        \"\"\"Update the feature.\n\n        Parameters\n        ----------\n        value : torch.Tensor\n            The updated value of the feature.\n        ids : torch.Tensor, optional\n            The indices of the feature to update. If specified, only the\n            specified indices of the feature will be updated. For the feature,\n            the `ids[i]` row is updated to `value[i]`. So the indices and value\n            must have the same length. 
If None, the entire feature will be\n            updated.\n        \"\"\"\n        if ids is None:\n            feat0 = value[:1]\n            self._fallback_feature.update(value)\n            cache_size = min(\n                bytes_to_number_of_items(self.cache_size_in_bytes, feat0),\n                value.shape[0],\n            )\n            self._feature = None  # Destroy the existing cache first.\n            self._feature = self._cache_type(\n                (cache_size,) + feat0.shape[1:], feat0.dtype\n            )\n        else:\n            self._fallback_feature.update(value, ids)\n            self._feature.replace(ids, value, None, self._offset)\n\n    def is_pinned(self):\n        \"\"\"Returns True if the cache storage is pinned.\"\"\"\n        return self._feature.is_pinned()\n\n    @property\n    def cache_size_in_bytes(self):\n        \"\"\"Return the size taken by the cache in bytes.\"\"\"\n        return self._feature.max_size_in_bytes\n\n    @property\n    def miss_rate(self):\n        \"\"\"Returns the cache miss rate since creation.\"\"\"\n        return self._feature.miss_rate\n\n\ndef cpu_cached_feature(\n    fallback_features: Union[Feature, Dict[FeatureKey, Feature]],\n    max_cache_size_in_bytes: int,\n    policy: Optional[str] = None,\n    pin_memory: bool = False,\n) -> Union[CPUCachedFeature, Dict[FeatureKey, CPUCachedFeature]]:\n    r\"\"\"CPU cached feature wrapping a fallback feature.\n\n    Parameters\n    ----------\n    fallback_features : Union[Feature, Dict[FeatureKey, Feature]]\n        The fallback feature(s).\n    max_cache_size_in_bytes : int\n        The capacity of the cache in bytes. The size should be a few factors\n        larger than the size of each read request. Otherwise, the caching policy\n        will hang due to all cache entries being read and/or write locked,\n        resulting in a deadlock.\n    policy : str, optional\n        The cache eviction policy algorithm name. 
The available policies are\n        [\"s3-fifo\", \"sieve\", \"lru\", \"clock\"]. Default is \"sieve\".\n    pin_memory : bool, optional\n        Whether the cache storage should be allocated on system pinned memory.\n        Default is False.\n    Returns\n    -------\n    Union[CPUCachedFeature, Dict[FeatureKey, CPUCachedFeature]]\n        New feature(s) wrapped with CPUCachedFeature.\n    \"\"\"\n    return wrap_with_cached_feature(\n        CPUCachedFeature,\n        fallback_features,\n        max_cache_size_in_bytes,\n        policy=policy,\n        pin_memory=pin_memory,\n    )\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/cpu_feature_cache.py",
    "content": "\"\"\"CPU Feature Cache implementation wrapper for graphbolt.\"\"\"\nimport torch\n\n__all__ = [\"CPUFeatureCache\"]\n\ncaching_policies = {\n    \"s3-fifo\": torch.ops.graphbolt.s3_fifo_cache_policy,\n    \"sieve\": torch.ops.graphbolt.sieve_cache_policy,\n    \"lru\": torch.ops.graphbolt.lru_cache_policy,\n    \"clock\": torch.ops.graphbolt.clock_cache_policy,\n}\n\n\nclass CPUFeatureCache(object):\n    r\"\"\"High level wrapper for the CPU feature cache.\n\n    Parameters\n    ----------\n    cache_shape : List[int]\n        The shape of the cache. cache_shape[0] gives us the capacity.\n    dtype : torch.dtype\n        The data type of the elements stored in the cache.\n    policy: str, optional\n        The cache policy. Default is \"sieve\". \"s3-fifo\", \"lru\" and \"clock\" are\n        also available.\n    num_parts: int, optional\n        The number of cache partitions for parallelism. Default is\n        `torch.get_num_threads()`.\n    pin_memory: bool, optional\n        Whether the cache storage should be pinned.\n    \"\"\"\n\n    def __init__(\n        self,\n        cache_shape,\n        dtype,\n        policy=None,\n        num_parts=None,\n        pin_memory=False,\n    ):\n        if policy is None:\n            policy = \"sieve\"\n        assert (\n            policy in caching_policies\n        ), f\"{list(caching_policies.keys())} are the available caching policies.\"\n        if num_parts is None:\n            num_parts = torch.get_num_threads()\n        min_num_cache_items = num_parts * (10 if policy == \"s3-fifo\" else 1)\n        # Since we partition the cache, each partition needs to have a positive\n        # number of slots. 
In addition, each \"s3-fifo\" partition needs at least\n        # 10 slots since the small queue is 10% and the small queue needs a\n        # positive size.\n        if cache_shape[0] < min_num_cache_items:\n            cache_shape = (min_num_cache_items,) + cache_shape[1:]\n        self._policy = caching_policies[policy](cache_shape[0], num_parts)\n        self._cache = torch.ops.graphbolt.feature_cache(\n            cache_shape, dtype, pin_memory\n        )\n        self.total_miss = 0\n        self.total_queries = 0\n\n    def is_pinned(self):\n        \"\"\"Returns True if the cache storage is pinned.\"\"\"\n        return self._cache.is_pinned()\n\n    @property\n    def max_size_in_bytes(self):\n        \"\"\"Return the size taken by the cache in bytes.\"\"\"\n        return self._cache.nbytes\n\n    def query(self, keys, offset=0):\n        \"\"\"Queries the cache.\n\n        Parameters\n        ----------\n        keys : Tensor\n            The keys to query the cache with.\n        offset : int\n            The offset to be added to the keys. Default is 0.\n\n        Returns\n        -------\n        tuple(Tensor, Tensor, Tensor, Tensor)\n            A tuple containing\n            (values, missing_indices, missing_keys, missing_offsets) where\n            values[missing_indices] corresponds to cache misses that should be\n            filled by querying another source with missing_keys. If keys is\n            pinned, then the returned values tensor is pinned as well. 
The\n            missing_offsets tensor has the partition offsets of missing_keys.\n        \"\"\"\n        self.total_queries += keys.shape[0]\n        (\n            positions,\n            index,\n            missing_keys,\n            found_pointers,\n            found_offsets,\n            missing_offsets,\n        ) = self._policy.query(keys, offset)\n        values = self._cache.query(positions, index, keys.shape[0])\n        self._policy.reading_completed(found_pointers, found_offsets)\n        self.total_miss += missing_keys.shape[0]\n        missing_index = index[positions.size(0) :]\n        return values, missing_index, missing_keys, missing_offsets\n\n    def query_and_replace(self, keys, reader_fn, offset=0):\n        \"\"\"Queries the cache. Then inserts the keys that are not found by\n        reading them by calling `reader_fn(missing_keys)`, which are then\n        inserted into the cache using the selected caching policy algorithm\n        to remove the old entries if it is full.\n\n        Parameters\n        ----------\n        keys : Tensor\n            The keys to query the cache with.\n        reader_fn : reader_fn(keys: torch.Tensor) -> torch.Tensor\n            A function that will take a missing keys tensor and will return\n            their values.\n        offset : int\n            The offset to be added to the keys. Default is 0.\n\n        Returns\n        -------\n        Tensor\n            A tensor containing values corresponding to the keys. 
Should equal\n            `reader_fn(keys)`, computed in a faster way.\n        \"\"\"\n        self.total_queries += keys.shape[0]\n        (\n            positions,\n            index,\n            pointers,\n            missing_keys,\n            found_offsets,\n            missing_offsets,\n        ) = self._policy.query_and_replace(keys, offset)\n        found_cnt = keys.size(0) - missing_keys.size(0)\n        found_positions = positions[:found_cnt]\n        values = self._cache.query(found_positions, index, keys.shape[0])\n        found_pointers = pointers[:found_cnt]\n        self._policy.reading_completed(found_pointers, found_offsets)\n        self.total_miss += missing_keys.shape[0]\n        missing_index = index[found_cnt:]\n        missing_values = reader_fn(missing_keys)\n        values[missing_index] = missing_values\n        missing_positions = positions[found_cnt:]\n        self._cache.replace(missing_positions, missing_values)\n        missing_pointers = pointers[found_cnt:]\n        self._policy.writing_completed(missing_pointers, missing_offsets)\n        return values\n\n    def replace(self, keys, values, offsets=None, offset=0):\n        \"\"\"Inserts key-value pairs into the cache using the selected caching\n        policy algorithm to remove old key-value pairs if it is full.\n\n        Parameters\n        ----------\n        keys : Tensor\n            The keys to insert to the cache.\n        values : Tensor\n            The values to insert to the cache.\n        offsets : Tensor, optional\n            The partition offsets of the keys.\n        offset : int\n            The offset to be added to the keys. 
Default is 0.\n        \"\"\"\n        positions, pointers, offsets = self._policy.replace(\n            keys, offsets, offset\n        )\n        self._cache.replace(positions, values)\n        self._policy.writing_completed(pointers, offsets)\n\n    @property\n    def miss_rate(self):\n        \"\"\"Returns the cache miss rate since creation.\"\"\"\n        return self.total_miss / self.total_queries\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/fused_csc_sampling_graph.py",
    "content": "\"\"\"CSC format sampling graph.\"\"\"\n\nimport textwrap\n\n# pylint: disable= invalid-name\nfrom typing import Dict, Optional, Union\n\nimport torch\n\nfrom ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID\nfrom ..internal_utils import gb_warning, is_wsl, recursive_apply\nfrom ..sampling_graph import SamplingGraph\nfrom .gpu_graph_cache import GPUGraphCache\nfrom .sampled_subgraph_impl import CSCFormatBase, SampledSubgraphImpl\n\n\n__all__ = [\n    \"FusedCSCSamplingGraph\",\n    \"fused_csc_sampling_graph\",\n    \"load_from_shared_memory\",\n    \"from_dglgraph\",\n]\n\n\nclass _SampleNeighborsWaiter:\n    def __init__(\n        self, fn, future, seed_offsets, fetching_original_edge_ids_is_optional\n    ):\n        self.fn = fn\n        self.future = future\n        self.seed_offsets = seed_offsets\n        self.fetching_original_edge_ids_is_optional = (\n            fetching_original_edge_ids_is_optional\n        )\n\n    def wait(self):\n        \"\"\"Returns the stored value when invoked.\"\"\"\n        fn = self.fn\n        C_sampled_subgraph = self.future.wait()\n        seed_offsets = self.seed_offsets\n        fetching_original_edge_ids_is_optional = (\n            self.fetching_original_edge_ids_is_optional\n        )\n        # Ensure there is no memory leak.\n        self.fn = self.future = self.seed_offsets = None\n        self.fetching_original_edge_ids_is_optional = None\n        return fn(\n            C_sampled_subgraph,\n            seed_offsets,\n            fetching_original_edge_ids_is_optional,\n        )\n\n\nclass FusedCSCSamplingGraph(SamplingGraph):\n    r\"\"\"A sampling graph in CSC format.\"\"\"\n\n    def __repr__(self):\n        final_str = (\n            \"{classname}(csc_indptr={csc_indptr},\\n\"\n            \"indices={indices},\\n\"\n            \"{metadata})\"\n        )\n\n        classname_str = self.__class__.__name__\n        csc_indptr_str = str(self.csc_indptr)\n        indices_str = 
str(self.indices)\n        meta_str = f\"total_num_nodes={self.total_num_nodes}, num_edges={self.num_edges},\"\n        if self.node_type_offset is not None:\n            meta_str += f\"\\nnode_type_offset={self.node_type_offset},\"\n        if self.type_per_edge is not None:\n            meta_str += f\"\\ntype_per_edge={self.type_per_edge},\"\n        if self.node_type_to_id is not None:\n            meta_str += f\"\\nnode_type_to_id={self.node_type_to_id},\"\n        if self.edge_type_to_id is not None:\n            meta_str += f\"\\nedge_type_to_id={self.edge_type_to_id},\"\n        if self.node_attributes is not None:\n            meta_str += f\"\\nnode_attributes={self.node_attributes},\"\n        if self.edge_attributes is not None:\n            meta_str += f\"\\nedge_attributes={self.edge_attributes},\"\n\n        final_str = final_str.format(\n            classname=classname_str,\n            csc_indptr=csc_indptr_str,\n            indices=indices_str,\n            metadata=meta_str,\n        )\n        return textwrap.indent(\n            final_str, \" \" * (len(classname_str) + 1)\n        ).strip()\n\n    def __init__(\n        self,\n        c_csc_graph: torch.ScriptObject,\n    ):\n        super().__init__()\n        self._c_csc_graph = c_csc_graph\n\n    def __del__(self):\n        # torch.Tensor.pin_memory() is not an inplace operation. To make it\n        # truly in-place, we need to use cudaHostRegister. 
Then, we need to use\n        # cudaHostUnregister to unpin the tensor in the destructor.\n        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842\n        if hasattr(self, \"_is_inplace_pinned\"):\n            for tensor in self._is_inplace_pinned:\n                assert self._inplace_unpinner(tensor.data_ptr()) == 0\n\n    @property\n    def total_num_nodes(self) -> int:\n        \"\"\"Returns the number of nodes in the graph.\n\n        Returns\n        -------\n        int\n            The number of rows in the dense format.\n        \"\"\"\n        return self._c_csc_graph.num_nodes()\n\n    @property\n    def total_num_edges(self) -> int:\n        \"\"\"Returns the number of edges in the graph.\n\n        Returns\n        -------\n        int\n            The number of edges in the graph.\n        \"\"\"\n        return self._c_csc_graph.num_edges()\n\n    @property\n    def num_nodes(self) -> Union[int, Dict[str, int]]:\n        \"\"\"The number of nodes in the graph.\n        - If the graph is homogenous, returns an integer.\n        - If the graph is heterogenous, returns a dictionary.\n\n        Returns\n        -------\n        Union[int, Dict[str, int]]\n            The number of nodes. Integer indicates the total nodes number of a\n            homogenous graph; dict indicates nodes number per node types of a\n            heterogenous graph.\n\n        Examples\n        --------\n        >>> import dgl.graphbolt as gb, torch\n        >>> total_num_nodes = 5\n        >>> total_num_edges = 12\n        >>> ntypes = {\"N0\": 0, \"N1\": 1}\n        >>> etypes = {\"N0:R0:N0\": 0, \"N0:R1:N1\": 1,\n        ...     \"N1:R2:N0\": 2, \"N1:R3:N1\": 3}\n        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])\n        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])\n        >>> node_type_offset = torch.LongTensor([0, 2, 5])\n        >>> type_per_edge = torch.LongTensor(\n        ...     
[0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])\n        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,\n        ...     node_type_offset=node_type_offset,\n        ...     type_per_edge=type_per_edge,\n        ...     node_type_to_id=ntypes,\n        ...     edge_type_to_id=etypes)\n        >>> print(graph.num_nodes)\n        {'N0': 2, 'N1': 3}\n        \"\"\"\n\n        offset = self._node_type_offset_list\n\n        # Homogenous.\n        if offset is None or self.node_type_to_id is None:\n            return self._c_csc_graph.num_nodes()\n\n        # Heterogenous\n        else:\n            num_nodes_per_type = {\n                _type: offset[_idx + 1] - offset[_idx]\n                for _type, _idx in self.node_type_to_id.items()\n            }\n\n            return num_nodes_per_type\n\n    @property\n    def num_edges(self) -> Union[int, Dict[str, int]]:\n        \"\"\"The number of edges in the graph.\n        - If the graph is homogenous, returns an integer.\n        - If the graph is heterogenous, returns a dictionary.\n\n        Returns\n        -------\n        Union[int, Dict[str, int]]\n            The number of edges. Integer indicates the total edges number of a\n            homogenous graph; dict indicates edges number per edge types of a\n            heterogenous graph.\n\n        Examples\n        --------\n        >>> import dgl.graphbolt as gb, torch\n        >>> total_num_nodes = 5\n        >>> total_num_edges = 12\n        >>> ntypes = {\"N0\": 0, \"N1\": 1}\n        >>> etypes = {\"N0:R0:N0\": 0, \"N0:R1:N1\": 1,\n        ...     \"N1:R2:N0\": 2, \"N1:R3:N1\": 3}\n        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])\n        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])\n        >>> node_type_offset = torch.LongTensor([0, 2, 5])\n        >>> type_per_edge = torch.LongTensor(\n        ...     
[0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])\n        >>> metadata = gb.GraphMetadata(ntypes, etypes)\n        >>> graph = gb.fused_csc_sampling_graph(indptr, indices, node_type_offset,\n        ...     type_per_edge, None, metadata)\n        >>> print(graph.num_edges)\n        {'N0:R0:N0': 2, 'N0:R1:N1': 1, 'N1:R2:N0': 2, 'N1:R3:N1': 3}\n        \"\"\"\n\n        type_per_edge = self.type_per_edge\n\n        # Homogenous.\n        if type_per_edge is None or self.edge_type_to_id is None:\n            return self._c_csc_graph.num_edges()\n\n        # Heterogenous\n        bincount = torch.bincount(type_per_edge)\n        num_edges_per_type = {}\n        for etype, etype_id in self.edge_type_to_id.items():\n            if etype_id < len(bincount):\n                num_edges_per_type[etype] = bincount[etype_id].item()\n            else:\n                num_edges_per_type[etype] = 0\n        return num_edges_per_type\n\n    @property\n    def csc_indptr(self) -> torch.tensor:\n        \"\"\"Returns the indices pointer in the CSC graph.\n\n        Returns\n        -------\n        torch.tensor\n            The indices pointer in the CSC graph. An integer tensor with\n            shape `(total_num_nodes+1,)`.\n        \"\"\"\n        return self._c_csc_graph.csc_indptr()\n\n    @csc_indptr.setter\n    def csc_indptr(self, csc_indptr: torch.tensor) -> None:\n        \"\"\"Sets the indices pointer in the CSC graph.\"\"\"\n        self._c_csc_graph.set_csc_indptr(csc_indptr)\n\n    @property\n    def indices(self) -> torch.tensor:\n        \"\"\"Returns the indices in the CSC graph.\n\n        Returns\n        -------\n        torch.tensor\n            The indices in the CSC graph. 
An integer tensor with shape\n            `(total_num_edges,)`.\n\n        Notes\n        -------\n        It is assumed that edges of each node are already sorted by edge type\n        ids.\n        \"\"\"\n        return self._c_csc_graph.indices()\n\n    @indices.setter\n    def indices(self, indices: torch.tensor) -> None:\n        \"\"\"Sets the indices in the CSC graph.\"\"\"\n        self._c_csc_graph.set_indices(indices)\n\n    @property\n    def node_type_offset(self) -> Optional[torch.Tensor]:\n        \"\"\"Returns the node type offset tensor if present. Do not modify the\n        returned tensor in place.\n\n        Returns\n        -------\n        torch.Tensor or None\n            If present, returns a 1D integer tensor of shape\n            `(num_node_types + 1,)`. The tensor is in ascending order as nodes\n            of the same type have continuous IDs, and larger node IDs are\n            paired with larger node type IDs. The first value is 0 and last\n            value is the number of nodes. And nodes with IDs between\n            `node_type_offset_[i]~node_type_offset_[i+1]` are of type id 'i'.\n\n        \"\"\"\n        return self._c_csc_graph.node_type_offset()\n\n    @property\n    def _node_type_offset_list(self) -> Optional[list]:\n        \"\"\"Returns the node type offset list if present.\n\n        Returns\n        -------\n        list or None\n            If present, returns a 1D integer list of shape\n            `(num_node_types + 1,)`. The list is in ascending order as nodes\n            of the same type have continuous IDs, and larger node IDs are\n            paired with larger node type IDs. The first value is 0 and last\n            value is the number of nodes. 
And nodes with IDs between\n            `node_type_offset_[i]~node_type_offset_[i+1]` are of type id 'i'.\n\n        \"\"\"\n        if (\n            not hasattr(self, \"_node_type_offset_cached_list\")\n            or self._node_type_offset_cached_list is None\n        ):\n            self._node_type_offset_cached_list = self.node_type_offset\n            if self._node_type_offset_cached_list is not None:\n                self._node_type_offset_cached_list = (\n                    self._node_type_offset_cached_list.tolist()\n                )\n        return self._node_type_offset_cached_list\n\n    @node_type_offset.setter\n    def node_type_offset(\n        self, node_type_offset: Optional[torch.Tensor]\n    ) -> None:\n        \"\"\"Sets the node type offset tensor if present.\"\"\"\n        self._c_csc_graph.set_node_type_offset(node_type_offset)\n        self._node_type_offset_cached_list = None\n\n    @property\n    def _indptr_node_type_offset_list(self) -> Optional[list]:\n        \"\"\"Returns the indptr node type offset list which presents the column id\n        space when it does not match the global id space. It is useful when we\n        slice a subgraph from another FusedCSCSamplingGraph.\n\n        Returns\n        -------\n        list or None\n            If present, returns a 1D integer list of shape\n            `(num_node_types + 1,)`. The list is in ascending order as nodes\n            of the same type have continuous IDs, and larger node IDs are\n            paired with larger node type IDs. The first value is 0 and last\n            value is the number of nodes. 
And nodes with IDs between\n            `node_type_offset_[i]~node_type_offset_[i+1]` are of type id 'i'.\n        \"\"\"\n        return (\n            self._indptr_node_type_offset_list_\n            if hasattr(self, \"_indptr_node_type_offset_list_\")\n            else None\n        )\n\n    @_indptr_node_type_offset_list.setter\n    def _indptr_node_type_offset_list(\n        self, indptr_node_type_offset_list: Optional[torch.Tensor]\n    ):\n        \"\"\"Sets the indptr node type offset list if present.\"\"\"\n        self._indptr_node_type_offset_list_ = indptr_node_type_offset_list\n\n    @property\n    def _gpu_graph_cache(self) -> Optional[GPUGraphCache]:\n        return (\n            self._gpu_graph_cache_\n            if hasattr(self, \"_gpu_graph_cache_\")\n            else None\n        )\n\n    @property\n    def type_per_edge(self) -> Optional[torch.Tensor]:\n        \"\"\"Returns the edge type tensor if present.\n\n        Returns\n        -------\n        torch.Tensor or None\n            If present, returns a 1D integer tensor of shape (total_num_edges,)\n            containing the type of each edge in the graph.\n        \"\"\"\n        return self._c_csc_graph.type_per_edge()\n\n    @type_per_edge.setter\n    def type_per_edge(self, type_per_edge: Optional[torch.Tensor]) -> None:\n        \"\"\"Sets the edge type tensor if present.\"\"\"\n        self._c_csc_graph.set_type_per_edge(type_per_edge)\n\n    @property\n    def node_type_to_id(self) -> Optional[Dict[str, int]]:\n        \"\"\"Returns the node type to id dictionary if present.\n\n        Returns\n        -------\n        Dict[str, int] or None\n            If present, returns a dictionary mapping node type to node type\n            id.\n        \"\"\"\n        return self._c_csc_graph.node_type_to_id()\n\n    @node_type_to_id.setter\n    def node_type_to_id(\n        self, node_type_to_id: Optional[Dict[str, int]]\n    ) -> None:\n        \"\"\"Sets the node type to id dictionary if 
present.\"\"\"\n        self._c_csc_graph.set_node_type_to_id(node_type_to_id)\n\n    @property\n    def edge_type_to_id(self) -> Optional[Dict[str, int]]:\n        \"\"\"Returns the edge type to id dictionary if present.\n\n        Returns\n        -------\n        Dict[str, int] or None\n            If present, returns a dictionary mapping edge type to edge type\n            id.\n        \"\"\"\n        return self._c_csc_graph.edge_type_to_id()\n\n    @edge_type_to_id.setter\n    def edge_type_to_id(\n        self, edge_type_to_id: Optional[Dict[str, int]]\n    ) -> None:\n        \"\"\"Sets the edge type to id dictionary if present.\"\"\"\n        self._c_csc_graph.set_edge_type_to_id(edge_type_to_id)\n\n    @property\n    def node_attributes(self) -> Optional[Dict[str, torch.Tensor]]:\n        \"\"\"Returns the node attributes dictionary.\n\n        Returns\n        -------\n        Dict[str, torch.Tensor] or None\n            If present, returns a dictionary of node attributes. Each key\n            represents the attribute's name, while the corresponding value\n            holds the attribute's specific value. The length of each value\n            should match the total number of nodes.\n        \"\"\"\n        return self._c_csc_graph.node_attributes()\n\n    @node_attributes.setter\n    def node_attributes(\n        self, node_attributes: Optional[Dict[str, torch.Tensor]]\n    ) -> None:\n        \"\"\"Sets the node attributes dictionary.\"\"\"\n        self._c_csc_graph.set_node_attributes(node_attributes)\n\n    @property\n    def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:\n        \"\"\"Returns the edge attributes dictionary.\n\n        Returns\n        -------\n        Dict[str, torch.Tensor] or None\n            If present, returns a dictionary of edge attributes. Each key\n            represents the attribute's name, while the corresponding value\n            holds the attribute's specific value. 
The length of each value\n            should match the total number of edges.\"\n        \"\"\"\n        return self._c_csc_graph.edge_attributes()\n\n    @edge_attributes.setter\n    def edge_attributes(\n        self, edge_attributes: Optional[Dict[str, torch.Tensor]]\n    ) -> None:\n        \"\"\"Sets the edge attributes dictionary.\"\"\"\n        self._c_csc_graph.set_edge_attributes(edge_attributes)\n\n    def node_attribute(self, name: str) -> Optional[torch.Tensor]:\n        \"\"\"Returns the node attribute tensor by name.\n\n        Parameters\n        ----------\n        name: str\n            The name of the node attribute.\n\n        Returns\n        -------\n        torch.Tensor or None\n            If present, returns the node attribute tensor.\n        \"\"\"\n        return self._c_csc_graph.node_attribute(name)\n\n    def add_node_attribute(self, name: str, tensor: torch.Tensor) -> None:\n        \"\"\"Adds node attribute tensor by name.\n\n        Parameters\n        ----------\n        name: str\n            The name of the node attribute.\n        tensor: torch.Tensor\n            The node attribute tensor.\n        \"\"\"\n        self._c_csc_graph.add_node_attribute(name, tensor)\n\n    def edge_attribute(self, name: str) -> Optional[torch.Tensor]:\n        \"\"\"Returns the edge attribute tensor by name.\n\n        Parameters\n        ----------\n        name: str\n            The name of the edge attribute.\n\n        Returns\n        -------\n        torch.Tensor or None\n            If present, returns the edge attribute tensor.\n        \"\"\"\n        return self._c_csc_graph.edge_attribute(name)\n\n    def add_edge_attribute(self, name: str, tensor: torch.Tensor) -> None:\n        \"\"\"Adds edge attribute tensor by name.\n\n        Parameters\n        ----------\n        name: str\n            The name of the edge attribute.\n        tensor: torch.Tensor\n            The edge attribute tensor.\n        \"\"\"\n        
self._c_csc_graph.add_edge_attribute(name, tensor)\n\n    def in_subgraph(\n        self,\n        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],\n    ) -> SampledSubgraphImpl:\n        \"\"\"Return the subgraph induced on the inbound edges of the given nodes.\n\n        An in subgraph is equivalent to creating a new graph using the incoming\n        edges of the given nodes. Subgraph is compacted according to the order\n        of passed-in `nodes`.\n\n        Parameters\n        ----------\n        nodes: torch.Tensor or Dict[str, torch.Tensor]\n            IDs of the given seed nodes.\n              - If `nodes` is a tensor: It means the graph is homogeneous\n                graph, and ids inside are homogeneous ids.\n              - If `nodes` is a dictionary: The keys should be node type and\n                ids inside are heterogeneous ids.\n\n        Returns\n        -------\n        SampledSubgraphImpl\n            The in subgraph.\n\n        Examples\n        --------\n        >>> import dgl.graphbolt as gb\n        >>> import torch\n        >>> total_num_nodes = 5\n        >>> total_num_edges = 12\n        >>> ntypes = {\"N0\": 0, \"N1\": 1}\n        >>> etypes = {\n        ...     \"N0:R0:N0\": 0, \"N0:R1:N1\": 1, \"N1:R2:N0\": 2, \"N1:R3:N1\": 3}\n        >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])\n        >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])\n        >>> node_type_offset = torch.LongTensor([0, 2, 5])\n        >>> type_per_edge = torch.LongTensor(\n        ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])\n        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,\n        ...     node_type_offset=node_type_offset,\n        ...     type_per_edge=type_per_edge,\n        ...     node_type_to_id=ntypes,\n        ...     
edge_type_to_id=etypes)\n        >>> nodes = {\"N0\":torch.LongTensor([1]), \"N1\":torch.LongTensor([1, 2])}\n        >>> in_subgraph = graph.in_subgraph(nodes)\n        >>> print(in_subgraph.sampled_csc)\n        {'N0:R0:N0': CSCFormatBase(indptr=tensor([0, 0]),\n              indices=tensor([], dtype=torch.int64),\n        ), 'N0:R1:N1': CSCFormatBase(indptr=tensor([0, 1, 2]),\n                    indices=tensor([1, 0]),\n        ), 'N1:R2:N0': CSCFormatBase(indptr=tensor([0, 2]),\n                    indices=tensor([0, 1]),\n        ), 'N1:R3:N1': CSCFormatBase(indptr=tensor([0, 1, 3]),\n                    indices=tensor([0, 1, 2]),\n        )}\n        \"\"\"\n        if isinstance(nodes, dict):\n            nodes, _ = self._convert_to_homogeneous_nodes(nodes)\n        # Ensure nodes is 1-D tensor.\n        assert nodes.dim() == 1, \"Nodes should be 1-D tensor.\"\n\n        _in_subgraph = self._c_csc_graph.in_subgraph(nodes)\n        return self._convert_to_sampled_subgraph(_in_subgraph)\n\n    def _convert_to_homogeneous_nodes(\n        self, nodes, timestamps=None, time_windows=None\n    ):\n        homogeneous_nodes = []\n        homogeneous_node_offsets = [0]\n        homogeneous_timestamps = []\n        homogeneous_time_windows = []\n        offset = self._node_type_offset_list\n        for ntype, ntype_id in self.node_type_to_id.items():\n            ids = nodes.get(ntype, [])\n            if len(ids) > 0:\n                homogeneous_nodes.append(ids + offset[ntype_id])\n                if timestamps is not None:\n                    homogeneous_timestamps.append(timestamps[ntype])\n                if time_windows is not None:\n                    homogeneous_time_windows.append(time_windows[ntype])\n            homogeneous_node_offsets.append(\n                homogeneous_node_offsets[-1] + len(ids)\n            )\n        if timestamps is not None:\n            homogeneous_time_windows = (\n                torch.cat(homogeneous_time_windows)\n         
       if homogeneous_time_windows\n                else None\n            )\n            return (\n                torch.cat(homogeneous_nodes),\n                homogeneous_node_offsets,\n                torch.cat(homogeneous_timestamps),\n                homogeneous_time_windows,\n            )\n        return torch.cat(homogeneous_nodes), homogeneous_node_offsets\n\n    def _convert_to_sampled_subgraph(\n        self,\n        C_sampled_subgraph: torch.ScriptObject,\n        seed_offsets: Optional[list] = None,\n        fetching_original_edge_ids_is_optional: bool = False,\n    ) -> SampledSubgraphImpl:\n        \"\"\"An internal function used to convert a fused homogeneous sampled\n        subgraph to general struct 'SampledSubgraphImpl'.\"\"\"\n        indptr = C_sampled_subgraph.indptr\n        indices = C_sampled_subgraph.indices\n        type_per_edge = C_sampled_subgraph.type_per_edge\n        column = C_sampled_subgraph.original_column_node_ids\n        edge_ids_in_fused_csc_sampling_graph = (\n            C_sampled_subgraph.original_edge_ids\n        )\n        etype_offsets = C_sampled_subgraph.etype_offsets\n        if etype_offsets is not None:\n            etype_offsets = etype_offsets.tolist()\n\n        has_original_eids = (\n            self.edge_attributes is not None\n            and ORIGINAL_EDGE_ID in self.edge_attributes\n        )\n        original_edge_ids = (\n            (\n                torch.ops.graphbolt.index_select(\n                    self.edge_attributes[ORIGINAL_EDGE_ID],\n                    edge_ids_in_fused_csc_sampling_graph,\n                )\n                if not fetching_original_edge_ids_is_optional\n                or not edge_ids_in_fused_csc_sampling_graph.is_cuda\n                or not self.edge_attributes[ORIGINAL_EDGE_ID].is_pinned()\n                else None\n            )\n            if has_original_eids\n            else edge_ids_in_fused_csc_sampling_graph\n        )\n        if type_per_edge is None 
and etype_offsets is None:\n            # The sampled graph is already a homogeneous graph.\n            sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)\n            if indices is not None and original_edge_ids is not None:\n                # Only needed to fetch indices or original_edge_ids.\n                edge_ids_in_fused_csc_sampling_graph = None\n        else:\n            offset = self._node_type_offset_list\n\n            original_hetero_edge_ids = {}\n            sub_indices = {}\n            sub_indptr = {}\n            if etype_offsets is None:\n                # UVA sampling requires us to move node_type_offset to GPU.\n                self.node_type_offset = self.node_type_offset.to(column.device)\n                # 1. Find node types for each nodes in column.\n                node_types = (\n                    torch.searchsorted(\n                        self.node_type_offset, column, right=True\n                    )\n                    - 1\n                )\n                for ntype, ntype_id in self.node_type_to_id.items():\n                    # Get all nodes of a specific node type in column.\n                    nids = torch.nonzero(node_types == ntype_id).view(-1)\n                    nids_original_indptr = indptr[nids + 1]\n                    for etype, etype_id in self.edge_type_to_id.items():\n                        src_ntype, _, dst_ntype = etype_str_to_tuple(etype)\n                        if dst_ntype != ntype:\n                            continue\n                        # Get all edge ids of a specific edge type.\n                        eids = torch.nonzero(type_per_edge == etype_id).view(-1)\n                        src_ntype_id = self.node_type_to_id[src_ntype]\n                        sub_indices[etype] = (\n                            indices[eids] - offset[src_ntype_id]\n                        )\n                        cum_edges = torch.searchsorted(\n                            eids, nids_original_indptr, 
right=False\n                        )\n                        sub_indptr[etype] = torch.cat(\n                            (torch.tensor([0], device=indptr.device), cum_edges)\n                        )\n                        original_hetero_edge_ids[etype] = original_edge_ids[\n                            eids\n                        ]\n                sampled_hetero_edge_ids_in_fused_csc_sampling_graph = None\n            else:\n                sampled_hetero_edge_ids_in_fused_csc_sampling_graph = {}\n                edge_offsets = [0]\n                for etype, etype_id in self.edge_type_to_id.items():\n                    src_ntype, _, dst_ntype = etype_str_to_tuple(etype)\n                    ntype_id = self.node_type_to_id[dst_ntype]\n                    edge_offsets.append(\n                        edge_offsets[-1]\n                        + seed_offsets[ntype_id + 1]\n                        - seed_offsets[ntype_id]\n                        + 1\n                    )\n                for etype, etype_id in self.edge_type_to_id.items():\n                    src_ntype, _, dst_ntype = etype_str_to_tuple(etype)\n                    ntype_id = self.node_type_to_id[dst_ntype]\n                    sub_indptr[etype] = indptr[\n                        edge_offsets[etype_id] : edge_offsets[etype_id + 1]\n                    ]\n                    sub_indices[etype] = (\n                        None\n                        if indices is None\n                        else indices[\n                            etype_offsets[etype_id] : etype_offsets[\n                                etype_id + 1\n                            ]\n                        ]\n                    )\n                    original_hetero_edge_ids[etype] = (\n                        None\n                        if original_edge_ids is None\n                        else original_edge_ids[\n                            etype_offsets[etype_id] : etype_offsets[\n                                
etype_id + 1\n                            ]\n                        ]\n                    )\n                    if indices is None or original_edge_ids is None:\n                        # Only needed to fetch indices or original edge ids.\n                        sampled_hetero_edge_ids_in_fused_csc_sampling_graph[\n                            etype\n                        ] = edge_ids_in_fused_csc_sampling_graph[\n                            etype_offsets[etype_id] : etype_offsets[\n                                etype_id + 1\n                            ]\n                        ]\n\n            original_edge_ids = original_hetero_edge_ids\n            edge_ids_in_fused_csc_sampling_graph = (\n                sampled_hetero_edge_ids_in_fused_csc_sampling_graph\n            )\n            sampled_csc = {\n                etype: CSCFormatBase(\n                    indptr=sub_indptr[etype],\n                    indices=sub_indices[etype],\n                )\n                for etype in self.edge_type_to_id.keys()\n            }\n        return SampledSubgraphImpl(\n            sampled_csc=sampled_csc,\n            original_edge_ids=original_edge_ids,\n            _edge_ids_in_fused_csc_sampling_graph=edge_ids_in_fused_csc_sampling_graph,\n        )\n\n    def sample_neighbors(\n        self,\n        seeds: Union[torch.Tensor, Dict[str, torch.Tensor]],\n        fanouts: torch.Tensor,\n        replace: bool = False,\n        probs_name: Optional[str] = None,\n        returning_indices_and_original_edge_ids_are_optional: bool = False,\n        async_op: bool = False,\n    ) -> SampledSubgraphImpl:\n        \"\"\"Sample neighboring edges of the given nodes and return the induced\n        subgraph.\n\n        Parameters\n        ----------\n        seeds: torch.Tensor or Dict[str, torch.Tensor]\n            IDs of the given seed nodes.\n              - If `nodes` is a tensor: It means the graph is homogeneous\n                graph, and ids inside are homogeneous 
ids.\n              - If `nodes` is a dictionary: The keys should be node type and\n                ids inside are heterogeneous ids.\n        fanouts: torch.Tensor\n            The number of edges to be sampled for each node with or without\n            considering edge types.\n              - When the length is 1, it indicates that the fanout applies to\n                all neighbors of the node as a collective, regardless of the\n                edge type.\n              - Otherwise, the length should equal the number of edge\n                types, and each fanout value corresponds to a specific edge\n                type of the nodes.\n            The value of each fanout should be >= 0 or = -1.\n              - When the value is -1, all neighbors (with non-zero probability,\n                if weighted) will be sampled once regardless of replacement. It\n                is equivalent to selecting all neighbors with non-zero\n                probability when the fanout is >= the number of neighbors (and\n                replace is set to false).\n              - When the value is a non-negative integer, it serves as a\n                minimum threshold for selecting neighbors.\n        replace: bool\n            Boolean indicating whether the sample is performed with or\n            without replacement. If True, a value can be selected multiple\n            times. Otherwise, each value can be selected only once.\n        probs_name: str, optional\n            An optional string specifying the name of an edge attribute used.\n            This attribute tensor should contain (unnormalized) probabilities\n            corresponding to each neighboring edge of a node. 
It must be a 1D\n            floating-point or boolean tensor, with the number of elements\n            equalling the total number of edges.\n        returning_indices_and_original_edge_ids_are_optional: bool\n            Boolean indicating whether it is okay for the call to this function\n            to leave the indices and the original edge ids tensors\n            uninitialized. In this case, it is the user's responsibility to\n            gather them using _edge_ids_in_fused_csc_sampling_graph if either is\n            missing.\n        async_op: bool\n            Boolean indicating whether the call is asynchronous. If so, the\n            result can be obtained by calling wait on the returned future.\n\n        Returns\n        -------\n        SampledSubgraphImpl\n            The sampled subgraph.\n\n        Examples\n        --------\n        >>> import dgl.graphbolt as gb\n        >>> import torch\n        >>> ntypes = {\"n1\": 0, \"n2\": 1}\n        >>> etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])\n        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])\n        >>> node_type_offset = torch.LongTensor([0, 2, 5])\n        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])\n        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,\n        ...     node_type_offset=node_type_offset,\n        ...     type_per_edge=type_per_edge,\n        ...     node_type_to_id=ntypes,\n        ...     
edge_type_to_id=etypes)\n        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}\n        >>> fanouts = torch.tensor([1, 1])\n        >>> subgraph = graph.sample_neighbors(nodes, fanouts)\n        >>> print(subgraph.sampled_csc)\n        {'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 1]),\n                    indices=tensor([0]),\n        ), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1]),\n                    indices=tensor([2]),\n        )}\n        \"\"\"\n        seed_offsets = None\n        if isinstance(seeds, dict):\n            seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)\n        elif seeds is None:\n            seed_offsets = self._indptr_node_type_offset_list\n        probs_or_mask = self.edge_attributes[probs_name] if probs_name else None\n        C_sampled_subgraph = self._sample_neighbors(\n            seeds,\n            seed_offsets,\n            fanouts,\n            replace=replace,\n            probs_or_mask=probs_or_mask,\n            returning_indices_is_optional=returning_indices_and_original_edge_ids_are_optional,\n            async_op=async_op,\n        )\n        if async_op:\n            return _SampleNeighborsWaiter(\n                self._convert_to_sampled_subgraph,\n                C_sampled_subgraph,\n                seed_offsets,\n                returning_indices_and_original_edge_ids_are_optional,\n            )\n        else:\n            return self._convert_to_sampled_subgraph(\n                C_sampled_subgraph,\n                seed_offsets,\n                returning_indices_and_original_edge_ids_are_optional,\n            )\n\n    def _check_sampler_arguments(self, nodes, fanouts, probs_or_mask):\n        if nodes is not None:\n            assert nodes.dim() == 1, \"Nodes should be 1-D tensor.\"\n            assert nodes.dtype == self.indices.dtype, (\n                f\"Data type of nodes must be consistent with \"\n                f\"indices.dtype({self.indices.dtype}), but got 
{nodes.dtype}.\"\n            )\n        assert fanouts.dim() == 1, \"Fanouts should be 1-D tensor.\"\n        expected_fanout_len = 1\n        if self.edge_type_to_id:\n            expected_fanout_len = len(self.edge_type_to_id)\n        assert len(fanouts) in [\n            expected_fanout_len,\n            1,\n        ], \"Fanouts should have the same number of elements as etypes or \\\n            should have a length of 1.\"\n        if fanouts.size(0) > 1:\n            assert (\n                self.type_per_edge is not None\n            ), \"To perform sampling for each edge type (when the length of \\\n                `fanouts` > 1), the graph must include edge type information.\"\n        assert torch.all(\n            (fanouts >= 0) | (fanouts == -1)\n        ), \"Fanouts should consist of values that are either -1 or \\\n            greater than or equal to 0.\"\n        if probs_or_mask is not None:\n            assert probs_or_mask.dim() == 1, \"Probs should be 1-D tensor.\"\n            assert (\n                probs_or_mask.size(0) == self.total_num_edges\n            ), \"Probs should have the same number of elements as the number \\\n                of edges.\"\n            assert probs_or_mask.dtype in [\n                torch.bool,\n                torch.float16,\n                torch.bfloat16,\n                torch.float32,\n                torch.float64,\n            ], \"Probs should have a floating-point or boolean data type.\"\n\n    def _sample_neighbors(\n        self,\n        seeds: torch.Tensor,\n        seed_offsets: Optional[list],\n        fanouts: torch.Tensor,\n        replace: bool = False,\n        probs_or_mask: Optional[torch.Tensor] = None,\n        returning_indices_is_optional: bool = False,\n        async_op: bool = False,\n    ) -> torch.ScriptObject:\n        \"\"\"Sample neighboring edges of the given nodes and return the induced\n        subgraph.\n\n        Parameters\n        ----------\n        seeds: 
torch.Tensor\n            IDs of the given seed nodes.\n        seed_offsets: list, optional\n            The offsets of the given seeds,\n            seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type i.\n        fanouts: torch.Tensor\n            The number of edges to be sampled for each node with or without\n            considering edge types.\n              - When the length is 1, it indicates that the fanout applies to\n                all neighbors of the node as a collective, regardless of the\n                edge type.\n              - Otherwise, the length should equal the number of edge\n                types, and each fanout value corresponds to a specific edge\n                type of the nodes.\n            The value of each fanout should be >= 0 or = -1.\n              - When the value is -1, all neighbors (with non-zero probability,\n                if weighted) will be sampled once regardless of replacement. It\n                is equivalent to selecting all neighbors with non-zero\n                probability when the fanout is >= the number of neighbors (and\n                replace is set to false).\n              - When the value is a non-negative integer, it serves as a\n                minimum threshold for selecting neighbors.\n        replace: bool\n            Boolean indicating whether the sample is performed with or\n            without replacement. If True, a value can be selected multiple\n            times. Otherwise, each value can be selected only once.\n        probs_or_mask: torch.Tensor, optional\n            An optional tensor of edge attribute for probability or masks. This\n            attribute tensor should contain (unnormalized) probabilities\n            corresponding to each neighboring edge of a node. 
It must be a 1D\n            floating-point or boolean tensor, with the number of elements\n            equalling the total number of edges.\n        returning_indices_is_optional: bool\n            Boolean indicating whether it is okay for the call to this function\n            to leave the indices tensor uninitialized. In this case, it is the\n            user's responsibility to gather it using the edge ids.\n        async_op: bool\n            Boolean indicating whether the call is asynchronous. If so, the\n            result can be obtained by calling wait on the returned future.\n\n        Returns\n        -------\n        torch.classes.graphbolt.SampledSubgraph\n            The sampled C subgraph.\n        \"\"\"\n        # Ensure nodes is 1-D tensor.\n        self._check_sampler_arguments(seeds, fanouts, probs_or_mask)\n        sampling_fn = (\n            self._c_csc_graph.sample_neighbors_async\n            if async_op\n            else self._c_csc_graph.sample_neighbors\n        )\n        return sampling_fn(\n            seeds,\n            seed_offsets,\n            fanouts.tolist(),\n            replace,\n            False,  # is_labor\n            returning_indices_is_optional,\n            probs_or_mask,\n            None,  # random_seed, labor parameter\n            0,  # seed2_contribution, labor_parameter\n        )\n\n    def sample_layer_neighbors(\n        self,\n        seeds: Union[torch.Tensor, Dict[str, torch.Tensor]],\n        fanouts: torch.Tensor,\n        replace: bool = False,\n        probs_name: Optional[str] = None,\n        returning_indices_and_original_edge_ids_are_optional: bool = False,\n        random_seed: torch.Tensor = None,\n        seed2_contribution: float = 0.0,\n        async_op: bool = False,\n    ) -> SampledSubgraphImpl:\n        \"\"\"Sample neighboring edges of the given nodes and return the induced\n        subgraph via layer-neighbor sampling from the NeurIPS 2023 paper\n        `Layer-Neighbor Sampling -- 
Defusing Neighborhood Explosion in GNNs\n        <https://proceedings.neurips.cc/paper_files/paper/2023/file/51f9036d5e7ae822da8f6d4adda1fb39-Paper-Conference.pdf>`__\n\n        Parameters\n        ----------\n        seeds: torch.Tensor or Dict[str, torch.Tensor]\n            IDs of the given seed nodes.\n              - If `seeds` is a tensor: It means the graph is homogeneous\n                graph, and ids inside are homogeneous ids.\n              - If `seeds` is a dictionary: The keys should be node type and\n                ids inside are heterogeneous ids.\n        fanouts: torch.Tensor\n            The number of edges to be sampled for each node with or without\n            considering edge types.\n              - When the length is 1, it indicates that the fanout applies to\n                all neighbors of the node as a collective, regardless of the\n                edge type.\n              - Otherwise, the length should equal the number of edge\n                types, and each fanout value corresponds to a specific edge\n                type of the nodes.\n            The value of each fanout should be >= 0 or = -1.\n              - When the value is -1, all neighbors (with non-zero probability,\n                if weighted) will be sampled once regardless of replacement. It\n                is equivalent to selecting all neighbors with non-zero\n                probability when the fanout is >= the number of neighbors (and\n                replace is set to false).\n              - When the value is a non-negative integer, it serves as a\n                minimum threshold for selecting neighbors.\n        replace: bool\n            Boolean indicating whether the sample is performed with or\n            without replacement. If True, a value can be selected multiple\n            times. Otherwise, each value can be selected only once.\n        probs_name: str, optional\n            An optional string specifying the name of an edge attribute. 
This\n            attribute tensor should contain (unnormalized) probabilities\n            corresponding to each neighboring edge of a node. It must be a 1D\n            floating-point or boolean tensor, with the number of elements\n            equalling the total number of edges.\n        returning_indices_and_original_edge_ids_are_optional: bool\n            Boolean indicating whether it is okay for the call to this function\n            to leave the indices and the original edge ids tensors\n            uninitialized. In this case, it is the user's responsibility to\n            gather them using _edge_ids_in_fused_csc_sampling_graph if either is\n            missing.\n        random_seed: torch.Tensor, optional\n            An int64 tensor with one or two elements.\n\n            The passed random_seed makes it so that for any seed node ``s`` and\n            its neighbor ``t``, the rolled random variate ``r_t`` is the same\n            for any call to this function with the same random seed. When\n            sampling as part of the same batch, one would want identical seeds\n            so that LABOR can globally sample. One example is that for\n            heterogeneous graphs, there is a single random seed passed for each\n            edge type. This will sample much fewer nodes compared to having\n            unique random seeds for each edge type. If one called this function\n            individually for each edge type for a heterogeneous graph with\n            different random seeds, then it would run LABOR locally for each\n            edge type, resulting in a larger number of nodes being sampled.\n\n            If this function is called without a ``random_seed``, we get the\n            random seed by getting a random number from GraphBolt. 
Use this\n            argument with identical random_seed if multiple calls to this\n            function are used to sample as part of a single batch.\n\n            If given two numbers, then the ``seed2_contribution`` argument\n            determines the interpolation between the two random seeds.\n        seed2_contribution: float, optional\n            A float value between [0, 1) that determines the contribution of the\n            second random seed, ``random_seed[-1]``, to generate the random\n            variates.\n        async_op: bool\n            Boolean indicating whether the call is asynchronous. If so, the\n            result can be obtained by calling wait on the returned future.\n\n        Returns\n        -------\n        SampledSubgraphImpl\n            The sampled subgraph.\n\n        Examples\n        --------\n        >>> import dgl.graphbolt as gb\n        >>> import torch\n        >>> ntypes = {\"n1\": 0, \"n2\": 1}\n        >>> etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n        >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])\n        >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])\n        >>> node_type_offset = torch.LongTensor([0, 2, 5])\n        >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])\n        >>> graph = gb.fused_csc_sampling_graph(indptr, indices,\n        ...     node_type_offset=node_type_offset,\n        ...     type_per_edge=type_per_edge,\n        ...     node_type_to_id=ntypes,\n        ...     
edge_type_to_id=etypes)\n        >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}\n        >>> fanouts = torch.tensor([1, 1])\n        >>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)\n        >>> print(subgraph.sampled_csc)\n        {'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 1]),\n                    indices=tensor([0]),\n        ), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1]),\n                    indices=tensor([2]),\n        )}\n        \"\"\"\n        if random_seed is not None:\n            assert (\n                1 <= len(random_seed) <= 2\n            ), \"There should be a 1 or 2 random seeds.\"\n            if len(random_seed) == 2:\n                assert (\n                    0 <= seed2_contribution <= 1\n                ), \"seed2_contribution should be in [0, 1].\"\n\n        seed_offsets = None\n        if isinstance(seeds, dict):\n            seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)\n        elif seeds is None:\n            seed_offsets = self._indptr_node_type_offset_list\n        probs_or_mask = self.edge_attributes[probs_name] if probs_name else None\n        self._check_sampler_arguments(seeds, fanouts, probs_or_mask)\n        sampling_fn = (\n            self._c_csc_graph.sample_neighbors_async\n            if async_op\n            else self._c_csc_graph.sample_neighbors\n        )\n        C_sampled_subgraph = sampling_fn(\n            seeds,\n            seed_offsets,\n            fanouts.tolist(),\n            replace,\n            True,  # is_labor\n            returning_indices_and_original_edge_ids_are_optional,\n            probs_or_mask,\n            random_seed,\n            seed2_contribution,\n        )\n        if async_op:\n            return _SampleNeighborsWaiter(\n                self._convert_to_sampled_subgraph,\n                C_sampled_subgraph,\n                seed_offsets,\n                returning_indices_and_original_edge_ids_are_optional,\n            
)\n        else:\n            return self._convert_to_sampled_subgraph(\n                C_sampled_subgraph,\n                seed_offsets,\n                returning_indices_and_original_edge_ids_are_optional,\n            )\n\n    def temporal_sample_neighbors(\n        self,\n        seeds: Union[torch.Tensor, Dict[str, torch.Tensor]],\n        seeds_timestamp: Union[torch.Tensor, Dict[str, torch.Tensor]],\n        fanouts: torch.Tensor,\n        replace: bool = False,\n        seeds_pre_time_window: Optional[\n            Union[torch.Tensor, Dict[str, torch.Tensor]]\n        ] = None,\n        probs_name: Optional[str] = None,\n        node_timestamp_attr_name: Optional[str] = None,\n        edge_timestamp_attr_name: Optional[str] = None,\n    ) -> torch.ScriptObject:\n        \"\"\"Temporally Sample neighboring edges of the given nodes and return the induced\n        subgraph.\n\n        If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given,\n        the sampled neighbor or edge of an seed node must have a timestamp\n        that is smaller than that of the seed node.\n\n        Parameters\n        ----------\n        seeds: torch.Tensor\n            IDs of the given seed nodes.\n        seeds_timestamp: torch.Tensor\n            Timestamps of the given seed nodes.\n        fanouts: torch.Tensor\n            The number of edges to be sampled for each node with or without\n            considering edge types.\n              - When the length is 1, it indicates that the fanout applies to\n                all neighbors of the node as a collective, regardless of the\n                edge type.\n              - Otherwise, the length should equal to the number of edge\n                types, and each fanout value corresponds to a specific edge\n                type of the nodes.\n            The value of each fanout should be >= 0 or = -1.\n              - When the value is -1, all neighbors (with non-zero probability,\n                if weighted) 
will be sampled once regardless of replacement. It\n                is equivalent to selecting all neighbors with non-zero\n                probability when the fanout is >= the number of neighbors (and\n                replace is set to false).\n              - When the value is a non-negative integer, it serves as a\n                minimum threshold for selecting neighbors.\n        replace: bool\n            Boolean indicating whether the sample is performed with or\n            without replacement. If True, a value can be selected multiple\n            times. Otherwise, each value can be selected only once.\n        seeds_pre_time_window: torch.Tensor\n            The time window of the nodes represents a period of time before\n            `seeds_timestamp`. If provided, only neighbors and related\n            edges whose timestamps fall within `[seeds_timestamp -\n            seeds_pre_time_window, seeds_timestamp]` will be filtered.\n        probs_name: str, optional\n            An optional string specifying the name of an edge attribute. This\n            attribute tensor should contain (unnormalized) probabilities\n            corresponding to each neighboring edge of a node. 
It must be a 1D\n            floating-point or boolean tensor, with the number of elements\n            equalling the total number of edges.\n        node_timestamp_attr_name: str, optional\n            An optional string specifying the name of an node attribute.\n        edge_timestamp_attr_name: str, optional\n            An optional string specifying the name of an edge attribute.\n\n        Returns\n        -------\n        SampledSubgraphImpl\n            The sampled subgraph.\n        \"\"\"\n        seed_offsets = None\n        if isinstance(seeds, dict):\n            (\n                seeds,\n                seed_offsets,\n                seeds_timestamp,\n                seeds_pre_time_window,\n            ) = self._convert_to_homogeneous_nodes(\n                seeds, seeds_timestamp, seeds_pre_time_window\n            )\n        elif seeds is None:\n            seed_offsets = self._indptr_node_type_offset_list\n\n        # Ensure nodes is 1-D tensor.\n        probs_or_mask = self.edge_attributes[probs_name] if probs_name else None\n        self._check_sampler_arguments(seeds, fanouts, probs_or_mask)\n        C_sampled_subgraph = self._c_csc_graph.temporal_sample_neighbors(\n            seeds,\n            seed_offsets,\n            seeds_timestamp,\n            fanouts.tolist(),\n            replace,\n            False,  # is_labor\n            False,  # returning_indices_is_optional\n            seeds_pre_time_window,\n            probs_or_mask,\n            node_timestamp_attr_name,\n            edge_timestamp_attr_name,\n            None,  # random_seed, labor parameter\n            0,  # seed2_contribution, labor_parameter\n        )\n        return self._convert_to_sampled_subgraph(\n            C_sampled_subgraph, seed_offsets\n        )\n\n    def temporal_sample_layer_neighbors(\n        self,\n        seeds: Union[torch.Tensor, Dict[str, torch.Tensor]],\n        seeds_timestamp: Union[torch.Tensor, Dict[str, torch.Tensor]],\n        fanouts: 
torch.Tensor,\n        replace: bool = False,\n        seeds_pre_time_window: Optional[\n            Union[torch.Tensor, Dict[str, torch.Tensor]]\n        ] = None,\n        probs_name: Optional[str] = None,\n        node_timestamp_attr_name: Optional[str] = None,\n        edge_timestamp_attr_name: Optional[str] = None,\n        random_seed: torch.Tensor = None,\n        seed2_contribution: float = 0.0,\n    ) -> torch.ScriptObject:\n        \"\"\"Temporally Sample neighboring edges of the given nodes and return the induced\n        subgraph via layer-neighbor sampling from the NeurIPS 2023 paper\n        `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs\n        <https://proceedings.neurips.cc/paper_files/paper/2023/file/51f9036d5e7ae822da8f6d4adda1fb39-Paper-Conference.pdf>`__\n\n        If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given,\n        the sampled neighbor or edge of an seed node must have a timestamp\n        that is smaller than that of the seed node.\n\n        Parameters\n        ----------\n        seeds: torch.Tensor\n            IDs of the given seed nodes.\n        seeds_timestamp: torch.Tensor\n            Timestamps of the given seed nodes.\n        fanouts: torch.Tensor\n            The number of edges to be sampled for each node with or without\n            considering edge types.\n              - When the length is 1, it indicates that the fanout applies to\n                all neighbors of the node as a collective, regardless of the\n                edge type.\n              - Otherwise, the length should equal to the number of edge\n                types, and each fanout value corresponds to a specific edge\n                type of the nodes.\n            The value of each fanout should be >= 0 or = -1.\n              - When the value is -1, all neighbors (with non-zero probability,\n                if weighted) will be sampled once regardless of replacement. 
It\n                is equivalent to selecting all neighbors with non-zero\n                probability when the fanout is >= the number of neighbors (and\n                replace is set to false).\n              - When the value is a non-negative integer, it serves as a\n                minimum threshold for selecting neighbors.\n        replace: bool\n            Boolean indicating whether the sample is performed with or\n            without replacement. If True, a value can be selected multiple\n            times. Otherwise, each value can be selected only once.\n        seeds_pre_time_window: torch.Tensor\n            The time window of the nodes represents a period of time before\n            `seeds_timestamp`. If provided, only neighbors and related\n            edges whose timestamps fall within `[seeds_timestamp -\n            seeds_pre_time_window, seeds_timestamp]` will be\n            filtered.\n        probs_name: str, optional\n            An optional string specifying the name of an edge attribute. This\n            attribute tensor should contain (unnormalized) probabilities\n            corresponding to each neighboring edge of a node. It must be a 1D\n            floating-point or boolean tensor, with the number of elements\n            equalling the total number of edges.\n        node_timestamp_attr_name: str, optional\n            An optional string specifying the name of a node attribute.\n        edge_timestamp_attr_name: str, optional\n            An optional string specifying the name of an edge attribute.\n        random_seed: torch.Tensor, optional\n            An int64 tensor with one or two elements.\n\n            The passed random_seed makes it so that for any seed node ``s`` and\n            its neighbor ``t``, the rolled random variate ``r_t`` is the same\n            for any call to this function with the same random seed. 
When\n            sampling as part of the same batch, one would want identical seeds\n            so that LABOR can globally sample. One example is that for\n            heterogenous graphs, there is a single random seed passed for each\n            edge type. This will sample much fewer nodes compared to having\n            unique random seeds for each edge type. If one called this function\n            individually for each edge type for a heterogenous graph with\n            different random seeds, then it would run LABOR locally for each\n            edge type, resulting into a larger number of nodes being sampled.\n\n            If this function is called without a ``random_seed``, we get the\n            random seed by getting a random number from GraphBolt. Use this\n            argument with identical random_seed if multiple calls to this\n            function are used to sample as part of a single batch.\n\n            If given two numbers, then the ``seed2_contribution`` argument\n            determines the interpolation between the two random seeds.\n        seed2_contribution: float, optional\n            A float value between [0, 1) that determines the contribution of the\n            second random seed, ``random_seed[-1]``, to generate the random\n            variates.\n\n        Returns\n        -------\n        SampledSubgraphImpl\n            The sampled subgraph.\n        \"\"\"\n        seed_offsets = None\n        if isinstance(seeds, dict):\n            (\n                seeds,\n                seed_offsets,\n                seeds_timestamp,\n                seeds_pre_time_window,\n            ) = self._convert_to_homogeneous_nodes(\n                seeds, seeds_timestamp, seeds_pre_time_window\n            )\n        elif seeds is None:\n            seed_offsets = self._indptr_node_type_offset_list\n\n        # Ensure nodes is 1-D tensor.\n        probs_or_mask = self.edge_attributes[probs_name] if probs_name else None\n        
self._check_sampler_arguments(seeds, fanouts, probs_or_mask)\n        C_sampled_subgraph = self._c_csc_graph.temporal_sample_neighbors(\n            seeds,\n            seed_offsets,\n            seeds_timestamp,\n            fanouts.tolist(),\n            replace,\n            True,  # is_labor\n            False,  # returning_indices_is_optional\n            seeds_pre_time_window,\n            probs_or_mask,\n            node_timestamp_attr_name,\n            edge_timestamp_attr_name,\n            random_seed,\n            seed2_contribution,\n        )\n        return self._convert_to_sampled_subgraph(\n            C_sampled_subgraph, seed_offsets\n        )\n\n    def sample_negative_edges_uniform(\n        self, edge_type, node_pairs, negative_ratio\n    ):\n        \"\"\"\n        Sample negative edges by randomly choosing negative source-destination\n        edges according to a uniform distribution. For each edge ``(u, v)``,\n        it is supposed to generate `negative_ratio` pairs of negative edges\n        ``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in\n        the graph. ``u`` is exactly same as the corresponding positive edges.\n        It returns positive edges concatenated with negative edges. In\n        negative edges, negative sources are constructed from the\n        corresponding positive edges.\n\n        Parameters\n        ----------\n        edge_type: str\n            The type of edges in the provided node_pairs. Any negative edges\n            sampled will also have the same type. If set to None, it will be\n            considered as a homogeneous graph.\n        node_pairs : torch.Tensor\n            A 2D tensors that represent the N pairs of positive edges in\n            source-destination format, with 'positive' indicating that these\n            edges are present in the graph. 
It's important to note that within\n            the context of a heterogeneous graph, the ids in these tensors\n            signify heterogeneous ids.\n        negative_ratio: int\n            The ratio of the number of negative samples to positive samples.\n\n        Returns\n        -------\n        torch.Tensor\n            A 2D tensors represents the N pairs of positive and negative\n            source-destination node pairs. In the context of a heterogeneous\n            graph, both the input nodes and the selected nodes are represented\n            by heterogeneous IDs, and the formed edges are of the input type\n            `edge_type`. Note that negative refers to false negatives, which\n            means the edge could be present or not present in the graph.\n        \"\"\"\n        if edge_type:\n            _, _, dst_ntype = etype_str_to_tuple(edge_type)\n            max_node_id = self.num_nodes[dst_ntype]\n        else:\n            max_node_id = self.total_num_nodes\n        pos_src = node_pairs[:, 0]\n        num_negative = node_pairs.shape[0] * negative_ratio\n        negative_seeds = (\n            torch.cat(\n                (\n                    pos_src.repeat_interleave(negative_ratio),\n                    torch.randint(\n                        0,\n                        max_node_id,\n                        (num_negative,),\n                        dtype=node_pairs.dtype,\n                        device=node_pairs.device,\n                    ),\n                ),\n            )\n            .view(2, num_negative)\n            .T\n        )\n        seeds = torch.cat((node_pairs, negative_seeds))\n        return seeds\n\n    def copy_to_shared_memory(self, shared_memory_name: str):\n        \"\"\"Copy the graph to shared memory.\n\n        Parameters\n        ----------\n        shared_memory_name : str\n            Name of the shared memory.\n\n        Returns\n        -------\n        FusedCSCSamplingGraph\n            The copied 
FusedCSCSamplingGraph object on shared memory.\n        \"\"\"\n        return FusedCSCSamplingGraph(\n            self._c_csc_graph.copy_to_shared_memory(shared_memory_name),\n        )\n\n    def _apply_to_members(self, fn):\n        \"\"\"Apply passed fn to all members of `FusedCSCSamplingGraph`.\"\"\"\n        self.csc_indptr = recursive_apply(self.csc_indptr, fn)\n        self.indices = recursive_apply(self.indices, fn)\n        self.node_type_offset = recursive_apply(self.node_type_offset, fn)\n        self.type_per_edge = recursive_apply(self.type_per_edge, fn)\n        self.node_attributes = recursive_apply(self.node_attributes, fn)\n        self.edge_attributes = recursive_apply(self.edge_attributes, fn)\n\n        return self\n\n    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name\n        \"\"\"Copy `FusedCSCSamplingGraph` to the specified device.\"\"\"\n\n        def _to(x):\n            return x.to(device) if hasattr(x, \"to\") else x\n\n        def _pin(x):\n            return x.pin_memory() if hasattr(x, \"pin_memory\") else x\n\n        # Create a copy of self.\n        self2 = fused_csc_sampling_graph(\n            self.csc_indptr,\n            self.indices,\n            self.node_type_offset,\n            self.type_per_edge,\n            self.node_type_to_id,\n            self.edge_type_to_id,\n            self.node_attributes,\n            self.edge_attributes,\n        )\n        return self2._apply_to_members(_pin if device == \"pinned\" else _to)\n\n    def pin_memory_(self):\n        \"\"\"Copy `FusedCSCSamplingGraph` to the pinned memory in-place. Returns\n        the same object modified in-place.\"\"\"\n        if is_wsl():\n            gb_warning(\n                \"In place pinning is not supported on WSL. \"\n                \"Returning the out of place pinned `FusedCSCSamplingGraph`.\"\n            )\n            return self.to(\"pinned\")\n        # torch.Tensor.pin_memory() is not an inplace operation. 
To make it\n        # truly in-place, we need to use cudaHostRegister. Then, we need to use\n        # cudaHostUnregister to unpin the tensor in the destructor.\n        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842\n        cudart = torch.cuda.cudart()\n        if not hasattr(self, \"_is_inplace_pinned\"):\n            self._is_inplace_pinned = set()\n\n        def _pin(x):\n            if hasattr(x, \"pin_memory_\"):\n                x.pin_memory_()\n            elif (\n                isinstance(x, torch.Tensor)\n                and not x.is_pinned()\n                and x.device.type == \"cpu\"\n            ):\n                assert (\n                    x.is_contiguous()\n                ), \"Tensor pinning is only supported for contiguous tensors.\"\n                assert (\n                    cudart.cudaHostRegister(\n                        x.data_ptr(), x.numel() * x.element_size(), 0\n                    )\n                    == 0\n                )\n\n                self._is_inplace_pinned.add(x)\n                self._inplace_unpinner = cudart.cudaHostUnregister\n\n            return x\n\n        return self._apply_to_members(_pin)\n\n    def _initialize_gpu_graph_cache(\n        self,\n        num_gpu_cached_edges: int,\n        gpu_cache_threshold: int,\n        prob_name: Optional[str] = None,\n    ):\n        \"Construct a GPUGraphCache given the cache parameters.\"\n        num_gpu_cached_edges = min(num_gpu_cached_edges, self.total_num_edges)\n        dtypes = [self.indices.dtype]\n        if self.type_per_edge is not None:\n            dtypes.append(self.type_per_edge.dtype)\n        has_original_edge_ids = False\n        if self.edge_attributes is not None:\n            probs_or_mask = self.edge_attributes.get(prob_name, None)\n            if probs_or_mask is not None:\n                dtypes.append(probs_or_mask.dtype)\n            original_edge_ids = self.edge_attributes.get(ORIGINAL_EDGE_ID, None)\n            
if original_edge_ids is not None:\n                dtypes.append(original_edge_ids.dtype)\n                has_original_edge_ids = True\n        self._gpu_graph_cache_ = GPUGraphCache(\n            num_gpu_cached_edges,\n            gpu_cache_threshold,\n            self.csc_indptr.dtype,\n            dtypes,\n            has_original_edge_ids,\n        )\n\n\ndef fused_csc_sampling_graph(\n    csc_indptr: torch.Tensor,\n    indices: torch.Tensor,\n    node_type_offset: Optional[torch.tensor] = None,\n    type_per_edge: Optional[torch.tensor] = None,\n    node_type_to_id: Optional[Dict[str, int]] = None,\n    edge_type_to_id: Optional[Dict[str, int]] = None,\n    node_attributes: Optional[Dict[str, torch.tensor]] = None,\n    edge_attributes: Optional[Dict[str, torch.tensor]] = None,\n) -> FusedCSCSamplingGraph:\n    \"\"\"Create a FusedCSCSamplingGraph object from a CSC representation.\n\n    Parameters\n    ----------\n    csc_indptr : torch.Tensor\n        Pointer to the start of each row in the `indices`. An integer tensor\n        with shape `(total_num_nodes+1,)`.\n    indices : torch.Tensor\n        Column indices of the non-zero elements in the CSC graph. An integer\n        tensor with shape `(total_num_edges,)`.\n    node_type_offset : Optional[torch.tensor], optional\n        Offset of node types in the graph, by default None.\n    type_per_edge : Optional[torch.tensor], optional\n        Type ids of each edge in the graph, by default None. If provided, it is\n        required that the edge types in each vertex neighborhood are in sorted\n        order. 
To be more precise, For each i in [0, csc_indptr.size(0) - 1),\n        `type_per_edge[indptr[i]: indptr[i + 1]]` is expected to be\n        monotonically nondecreasing.\n    node_type_to_id : Optional[Dict[str, int]], optional\n        Map node types to ids, by default None.\n    edge_type_to_id : Optional[Dict[str, int]], optional\n        Map edge types to ids, by default None.\n    node_attributes: Optional[Dict[str, torch.tensor]], optional\n        Node attributes of the graph, by default None.\n    edge_attributes: Optional[Dict[str, torch.tensor]], optional\n        Edge attributes of the graph, by default None.\n\n    Returns\n    -------\n    FusedCSCSamplingGraph\n        The created FusedCSCSamplingGraph object.\n\n    Examples\n    --------\n    >>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}\n    >>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}\n    >>> csc_indptr = torch.tensor([0, 2, 5, 7, 8])\n    >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3, 2])\n    >>> node_type_offset = torch.tensor([0, 1, 2, 4])\n    >>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0, 0])\n    >>> graph = graphbolt.fused_csc_sampling_graph(csc_indptr, indices,\n    ...         node_type_offset=node_type_offset,\n    ...         type_per_edge=type_per_edge,\n    ...         node_type_to_id=ntypes, edge_type_to_id=etypes,\n    ...         
node_attributes=None, edge_attributes=None,)\n    >>> print(graph)\n    FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7, 8]),\n                          indices=tensor([1, 3, 0, 1, 2, 0, 3, 2]),\n                          total_num_nodes=4, num_edges={'n1:e1:n2': 5, 'n1:e2:n3': 3},\n                          node_type_offset=tensor([0, 1, 2, 4]),\n                          type_per_edge=tensor([0, 1, 0, 1, 1, 0, 0, 0]),\n                          node_type_to_id={'n1': 0, 'n2': 1, 'n3': 2},\n                          edge_type_to_id={'n1:e1:n2': 0, 'n1:e2:n3': 1},)\n    \"\"\"\n    if node_type_to_id is not None and edge_type_to_id is not None:\n        node_types = list(node_type_to_id.keys())\n        edge_types = list(edge_type_to_id.keys())\n        node_type_ids = list(node_type_to_id.values())\n        edge_type_ids = list(edge_type_to_id.values())\n\n        # Validate node_type_to_id.\n        assert all(\n            isinstance(x, str) for x in node_types\n        ), \"Node type name should be string.\"\n        assert all(\n            isinstance(x, int) for x in node_type_ids\n        ), \"Node type id should be int.\"\n        assert len(node_type_ids) == len(\n            set(node_type_ids)\n        ), \"Multiple node types shoud not be mapped to a same id.\"\n        # Validate edge_type_to_id.\n        for edge_type in edge_types:\n            src, edge, dst = etype_str_to_tuple(edge_type)\n            assert isinstance(edge, str), \"Edge type name should be string.\"\n            assert (\n                src in node_types\n            ), f\"Unrecognized node type {src} in edge type {edge_type}\"\n            assert (\n                dst in node_types\n            ), f\"Unrecognized node type {dst} in edge type {edge_type}\"\n        assert all(\n            isinstance(x, int) for x in edge_type_ids\n        ), \"Edge type id should be int.\"\n        assert len(edge_type_ids) == len(\n            set(edge_type_ids)\n        ), \"Multiple edge 
types shoud not be mapped to a same id.\"\n\n        if node_type_offset is not None:\n            assert len(node_type_to_id) + 1 == node_type_offset.size(\n                0\n            ), \"node_type_offset length should be |ntypes| + 1.\"\n    return FusedCSCSamplingGraph(\n        torch.ops.graphbolt.fused_csc_sampling_graph(\n            csc_indptr,\n            indices,\n            node_type_offset,\n            type_per_edge,\n            node_type_to_id,\n            edge_type_to_id,\n            node_attributes,\n            edge_attributes,\n        ),\n    )\n\n\ndef load_from_shared_memory(\n    shared_memory_name: str,\n) -> FusedCSCSamplingGraph:\n    \"\"\"Load a FusedCSCSamplingGraph object from shared memory.\n\n    Parameters\n    ----------\n    shared_memory_name : str\n        Name of the shared memory.\n\n    Returns\n    -------\n    FusedCSCSamplingGraph\n        The loaded FusedCSCSamplingGraph object on shared memory.\n    \"\"\"\n    return FusedCSCSamplingGraph(\n        torch.ops.graphbolt.load_from_shared_memory(shared_memory_name),\n    )\n\n\ndef from_dglgraph(\n    DGLGraphInstance,\n    is_homogeneous: bool = False,\n    include_original_edge_id: bool = False,\n) -> FusedCSCSamplingGraph:\n    \"\"\"Convert a DGLGraph to FusedCSCSamplingGraph.\"\"\"\n    from dgl.base import EID, ETYPE, NID, NTYPE\n    from dgl.convert import to_homogeneous\n\n    g = DGLGraphInstance\n\n    homo_g, ntype_count, _ = to_homogeneous(\n        g, ndata=g.ndata, edata=g.edata, return_count=True\n    )\n\n    if is_homogeneous:\n        node_type_to_id = None\n        edge_type_to_id = None\n    else:\n        # Initialize metadata.\n        node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}\n        edge_type_to_id = {\n            etype_tuple_to_str(etype): g.get_etype_id(etype)\n            for etype in g.canonical_etypes\n        }\n\n    # Obtain CSC matrix.\n    indptr, indices, edge_ids = homo_g.adj_tensors(\"csc\")\n    
ntype_count.insert(0, 0)\n    node_type_offset = (\n        None\n        if is_homogeneous\n        else torch.cumsum(torch.LongTensor(ntype_count), 0)\n    )\n\n    # Assign edge type according to the order of CSC matrix.\n    type_per_edge = (\n        None\n        if is_homogeneous\n        else torch.index_select(homo_g.edata[ETYPE], dim=0, index=edge_ids)\n    )\n\n    node_attributes = {}\n    edge_attributes = {}\n    for feat_name, feat_data in homo_g.ndata.items():\n        if feat_name not in (NID, NTYPE):\n            node_attributes[feat_name] = feat_data\n    for feat_name, feat_data in homo_g.edata.items():\n        if feat_name not in (EID, ETYPE):\n            edge_attributes[feat_name] = feat_data\n    if include_original_edge_id:\n        # Assign edge attributes according to the original eids mapping.\n        edge_attributes[ORIGINAL_EDGE_ID] = torch.index_select(\n            homo_g.edata[EID], dim=0, index=edge_ids\n        )\n\n    return FusedCSCSamplingGraph(\n        torch.ops.graphbolt.fused_csc_sampling_graph(\n            indptr,\n            indices,\n            node_type_offset,\n            type_per_edge,\n            node_type_to_id,\n            edge_type_to_id,\n            node_attributes,\n            edge_attributes,\n        ),\n    )\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/gpu_cached_feature.py",
    "content": "\"\"\"GPU cached feature for GraphBolt.\"\"\"\nfrom typing import Dict, Union\n\nimport torch\n\nfrom ..feature_store import (\n    bytes_to_number_of_items,\n    Feature,\n    FeatureKey,\n    wrap_with_cached_feature,\n)\n\nfrom .gpu_feature_cache import GPUFeatureCache\n\n__all__ = [\"GPUCachedFeature\", \"gpu_cached_feature\"]\n\n\nclass GPUCachedFeature(Feature):\n    r\"\"\"GPU cached feature wrapping a fallback feature. It uses the least\n    recently used (LRU) algorithm as the cache eviction policy. Use\n    `gpu_cached_feature` to construct an instance of this class.\n\n    Places the GPU cache to torch.cuda.current_device().\n\n    Parameters\n    ----------\n    fallback_feature : Feature\n        The fallback feature.\n    cache : GPUFeatureCache\n        A GPUFeatureCache instance to serve as the cache backend.\n    offset : int, optional\n        The offset value to add to the given ids before using the cache. This\n        parameter is useful if multiple `GPUCachedFeature`s are sharing a single\n        GPUFeatureCache object.\n\n    Examples\n    --------\n    >>> import torch\n    >>> from dgl import graphbolt as gb\n    >>> torch_feat = torch.arange(10).reshape(2, -1).to(\"cuda\")\n    >>> cache_size = 5\n    >>> fallback_feature = gb.TorchBasedFeature(torch_feat)\n    >>> feature = gb.gpu_cached_feature(fallback_feature, cache_size)\n    >>> feature.read()\n    tensor([[0, 1, 2, 3, 4],\n            [5, 6, 7, 8, 9]], device='cuda:0')\n    >>> feature.read(torch.tensor([0]).to(\"cuda\"))\n    tensor([[0, 1, 2, 3, 4]], device='cuda:0')\n    >>> feature.update(torch.tensor([[1 for _ in range(5)]]).to(\"cuda\"),\n    ...                
torch.tensor([1]).to(\"cuda\"))\n    >>> feature.read(torch.tensor([0, 1]).to(\"cuda\"))\n    tensor([[0, 1, 2, 3, 4],\n            [1, 1, 1, 1, 1]], device='cuda:0')\n    >>> feature.size()\n    torch.Size([5])\n    \"\"\"\n\n    _cache_type = GPUFeatureCache\n\n    def __init__(\n        self,\n        fallback_feature: Feature,\n        cache: GPUFeatureCache,\n        offset: int = 0,\n    ):\n        super(GPUCachedFeature, self).__init__()\n        assert isinstance(fallback_feature, Feature), (\n            f\"The fallback_feature must be an instance of Feature, but got \"\n            f\"{type(fallback_feature)}.\"\n        )\n        self._fallback_feature = fallback_feature\n        self._feature = cache\n        self._offset = offset\n\n    def read(self, ids: torch.Tensor = None):\n        \"\"\"Read the feature by index.\n\n        The returned tensor is always in GPU memory, no matter whether the\n        fallback feature is in memory or on disk.\n\n        Parameters\n        ----------\n        ids : torch.Tensor, optional\n            The index of the feature. If specified, only the specified indices\n            of the feature are read. 
If None, the entire feature is returned.\n\n        Returns\n        -------\n        torch.Tensor\n            The read feature.\n        \"\"\"\n        if ids is None:\n            return self._fallback_feature.read()\n        values, missing_index, missing_keys = self._feature.query(\n            ids if self._offset == 0 else ids + self._offset\n        )\n        missing_values = self._fallback_feature.read(\n            missing_keys if self._offset == 0 else missing_keys - self._offset\n        )\n        values[missing_index] = missing_values\n        self._feature.replace(missing_keys, missing_values)\n        return values\n\n    def read_async(self, ids: torch.Tensor):\n        r\"\"\"Read the feature by index asynchronously.\n\n        Parameters\n        ----------\n        ids : torch.Tensor\n            The index of the feature. Only the specified indices of the\n            feature are read.\n        Returns\n        -------\n        A generator object.\n            The returned generator object returns a future on\n            ``read_async_num_stages(ids.device)``\\ th invocation. The return result\n            can be accessed by calling ``.wait()``. on the returned future object.\n            It is undefined behavior to call ``.wait()`` more than once.\n\n        Examples\n        --------\n        >>> import dgl.graphbolt as gb\n        >>> feature = gb.Feature(...)\n        >>> ids = torch.tensor([0, 2])\n        >>> for stage, future in enumerate(feature.read_async(ids)):\n        ...     
pass\n        >>> assert stage + 1 == feature.read_async_num_stages(ids.device)\n        >>> result = future.wait()  # result contains the read values.\n        \"\"\"\n        future = self._feature.query(\n            ids if self._offset == 0 else ids + self._offset, async_op=True\n        )\n\n        yield\n\n        values, missing_index, missing_keys = future.wait()\n\n        fallback_reader = self._fallback_feature.read_async(\n            missing_keys if self._offset == 0 else missing_keys - self._offset\n        )\n        fallback_num_stages = self._fallback_feature.read_async_num_stages(\n            missing_keys.device\n        )\n        for i in range(fallback_num_stages):\n            missing_values_future = next(fallback_reader, None)\n            if i < fallback_num_stages - 1:\n                yield  # fallback feature stages.\n\n        class _Waiter:\n            def __init__(\n                self,\n                feature,\n                values,\n                missing_index,\n                missing_keys,\n                missing_values_future,\n            ):\n                self.feature = feature\n                self.values = values\n                self.missing_index = missing_index\n                self.missing_keys = missing_keys\n                self.missing_values_future = missing_values_future\n\n            def wait(self):\n                \"\"\"Returns the stored value when invoked.\"\"\"\n                missing_values = self.missing_values_future.wait()\n                self.feature.replace(self.missing_keys, missing_values)\n                self.values[self.missing_index] = missing_values\n                values = self.values\n                # Ensure there is no memory leak.\n                self.feature = self.values = self.missing_index = None\n                self.missing_keys = self.missing_values_future = None\n                return values\n\n        yield _Waiter(\n            self._feature,\n            values,\n   
         missing_index,\n            missing_keys,\n            missing_values_future,\n        )\n\n    def read_async_num_stages(self, ids_device: torch.device):\n        \"\"\"The number of stages of the read_async operation. See read_async\n        function for directions on its use. This function is required to return\n        the number of yield operations when read_async is used with a tensor\n        residing on ids_device.\n\n        Parameters\n        ----------\n        ids_device : torch.device\n            The device of the ids parameter passed into read_async.\n        Returns\n        -------\n        int\n            The number of stages of the read_async operation.\n        \"\"\"\n        assert ids_device.type == \"cuda\"\n        return 1 + self._fallback_feature.read_async_num_stages(ids_device)\n\n    def size(self):\n        \"\"\"Get the size of the feature.\n\n        Returns\n        -------\n        torch.Size\n            The size of the feature.\n        \"\"\"\n        return self._fallback_feature.size()\n\n    def count(self):\n        \"\"\"Get the count of the feature.\n\n        Returns\n        -------\n        int\n            The count of the feature.\n        \"\"\"\n        return self._fallback_feature.count()\n\n    def update(self, value: torch.Tensor, ids: torch.Tensor = None):\n        \"\"\"Update the feature.\n\n        Parameters\n        ----------\n        value : torch.Tensor\n            The updated value of the feature.\n        ids : torch.Tensor, optional\n            The indices of the feature to update. If specified, only the\n            specified indices of the feature will be updated. For the feature,\n            the `ids[i]` row is updated to `value[i]`. So the indices and value\n            must have the same length. 
If None, the entire feature will be\n            updated.\n        \"\"\"\n        if ids is None:\n            feat0 = value[:1]\n            self._fallback_feature.update(value)\n            cache_size = min(\n                bytes_to_number_of_items(self.cache_size_in_bytes, feat0),\n                value.shape[0],\n            )\n            self._feature = None  # Destroy the existing cache first.\n            self._feature = self._cache_type(\n                (cache_size,) + feat0.shape[1:], feat0.dtype\n            )\n        else:\n            self._fallback_feature.update(value, ids)\n            self._feature.replace(ids, value)\n\n    @property\n    def cache_size_in_bytes(self):\n        \"\"\"Return the size taken by the cache in bytes.\"\"\"\n        return self._feature.max_size_in_bytes\n\n    @property\n    def miss_rate(self):\n        \"\"\"Returns the cache miss rate since creation.\"\"\"\n        return self._feature.miss_rate\n\n\ndef gpu_cached_feature(\n    fallback_features: Union[Feature, Dict[FeatureKey, Feature]],\n    max_cache_size_in_bytes: int,\n) -> Union[GPUCachedFeature, Dict[FeatureKey, GPUCachedFeature]]:\n    r\"\"\"GPU cached feature wrapping a fallback feature. It uses the least\n    recently used (LRU) algorithm as the cache eviction policy.\n\n    Places the GPU cache to torch.cuda.current_device().\n\n    Parameters\n    ----------\n    fallback_features : Union[Feature, Dict[FeatureKey, Feature]]\n        The fallback feature(s).\n    max_cache_size_in_bytes : int\n        The capacity of the GPU cache in bytes.\n    Returns\n    -------\n    Union[GPUCachedFeature, Dict[FeatureKey, GPUCachedFeature]]\n        The feature(s) wrapped with GPUCachedFeature.\n    \"\"\"\n    return wrap_with_cached_feature(\n        GPUCachedFeature, fallback_features, max_cache_size_in_bytes\n    )\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/gpu_feature_cache.py",
    "content": "\"\"\"HugeCTR gpu_cache wrapper for graphbolt.\"\"\"\nfrom functools import reduce\nfrom operator import mul\n\nimport torch\n\n\nclass GPUFeatureCache(object):\n    \"\"\"High-level wrapper for GPU embedding cache\"\"\"\n\n    def __init__(self, cache_shape, dtype):\n        major, _ = torch.cuda.get_device_capability()\n        assert (\n            major >= 7\n        ), \"GPUFeatureCache is supported only on CUDA compute capability >= 70 (Volta).\"\n        self._cache = torch.ops.graphbolt.gpu_cache(cache_shape, dtype)\n        element_size = torch.tensor([], dtype=dtype).element_size()\n        self.max_size_in_bytes = reduce(mul, cache_shape) * element_size\n        self.total_miss = 0\n        self.total_queries = 0\n\n    def query(self, keys, async_op=False):\n        \"\"\"Queries the GPU cache.\n\n        Parameters\n        ----------\n        keys : Tensor\n            The keys to query the GPU cache with.\n        async_op: bool\n            Boolean indicating whether the call is asynchronous. 
If so, the\n            result can be obtained by calling wait on the returned future.\n\n        Returns\n        -------\n        tuple(Tensor, Tensor, Tensor)\n            A tuple containing (values, missing_indices, missing_keys) where\n            values[missing_indices] corresponds to cache misses that should be\n            filled by querying another source with missing_keys.\n        \"\"\"\n\n        class _Waiter:\n            def __init__(self, gpu_cache, future):\n                self.gpu_cache = gpu_cache\n                self.future = future\n\n            def wait(self):\n                \"\"\"Returns the stored value when invoked.\"\"\"\n                gpu_cache = self.gpu_cache\n                values, missing_index, missing_keys = (\n                    self.future.wait() if async_op else self.future\n                )\n                # Ensure there is no leak.\n                self.gpu_cache = self.future = None\n\n                gpu_cache.total_queries += values.shape[0]\n                gpu_cache.total_miss += missing_keys.shape[0]\n                return values, missing_index, missing_keys\n\n        if async_op:\n            return _Waiter(self, self._cache.query_async(keys))\n        else:\n            return _Waiter(self, self._cache.query(keys)).wait()\n\n    def replace(self, keys, values):\n        \"\"\"Inserts key-value pairs into the GPU cache using the Least-Recently\n        Used (LRU) algorithm to remove old key-value pairs if it is full.\n\n        Parameters\n        ----------\n        keys: Tensor\n            The keys to insert to the GPU cache.\n        values: Tensor\n            The values to insert to the GPU cache.\n        \"\"\"\n        self._cache.replace(keys, values)\n\n    @property\n    def miss_rate(self):\n        \"\"\"Returns the cache miss rate since creation.\"\"\"\n        return self.total_miss / self.total_queries\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/gpu_graph_cache.py",
    "content": "\"\"\"HugeCTR gpu_cache wrapper for graphbolt.\"\"\"\nimport torch\n\n\nclass GPUGraphCache(object):\n    r\"\"\"High-level wrapper for GPU graph cache.\n\n    Places the GPU graph cache to torch.cuda.current_device().\n\n    Parameters\n    ----------\n    num_edges : int\n        Upperbound on number of edges to cache.\n    threshold : int\n        The number of accesses before the neighborhood of a vertex is cached.\n    indptr_dtype : torch.dtype\n        The dtype of the indptr tensor of the graph.\n    dtypes : list[torch.dtype]\n        The dtypes of the edge tensors that are going to be cached.\n    has_original_edge_ids : bool\n        Whether the graph to be cached has original edge ids.\n    \"\"\"\n\n    def __init__(\n        self, num_edges, threshold, indptr_dtype, dtypes, has_original_edge_ids\n    ):\n        major, _ = torch.cuda.get_device_capability()\n        assert (\n            major >= 7\n        ), \"GPUGraphCache is supported only on CUDA compute capability >= 70 (Volta).\"\n        self._cache = torch.ops.graphbolt.gpu_graph_cache(\n            num_edges, threshold, indptr_dtype, dtypes, has_original_edge_ids\n        )\n        self.total_miss = 0\n        self.total_queries = 0\n\n    def query(self, keys):\n        \"\"\"Queries the GPU cache.\n\n        Parameters\n        ----------\n        keys : Tensor\n            The keys to query the GPU graph cache with.\n\n        Returns\n        -------\n        tuple(Tensor, func)\n            A tuple containing (missing_keys, replace_fn) where replace_fn is a\n            function that should be called with the graph structure\n            corresponding to the missing keys. 
Its arguments are\n            (Tensor, list(Tensor)), where the first tensor is the missing indptr\n            and the second list is the missing edge tensors.\n        \"\"\"\n        self.total_queries += keys.shape[0]\n        (\n            index,\n            position,\n            num_hit,\n            num_threshold,\n        ) = self._cache.query(keys)\n        self.total_miss += keys.shape[0] - num_hit\n\n        def replace_functional(missing_indptr, missing_edge_tensors):\n            return self._cache.replace(\n                keys,\n                index,\n                position,\n                num_hit,\n                num_threshold,\n                missing_indptr,\n                missing_edge_tensors,\n            )\n\n        return keys[index[num_hit:]], replace_functional\n\n    def query_async(self, keys):\n        \"\"\"Queries the GPU cache asynchronously.\n\n        Parameters\n        ----------\n        keys : Tensor\n            The keys to query the GPU graph cache with.\n\n        Returns\n        -------\n        A generator object.\n            The returned generator object returns the missing keys on the second\n            invocation and expects the fetched indptr and edge tensors on the\n            next invocation. The third and last invocation returns a future\n            object and the return result can be accessed by calling `.wait()`\n            on the returned future object. 
It is undefined behavior to call\n            `.wait()` more than once.\n        \"\"\"\n        future = self._cache.query_async(keys)\n\n        yield\n\n        index, position, num_hit, num_threshold = future.wait()\n\n        self.total_queries += keys.shape[0]\n        self.total_miss += keys.shape[0] - num_hit\n\n        missing_indptr, missing_edge_tensors = yield keys[index[num_hit:]]\n\n        yield self._cache.replace_async(\n            keys,\n            index,\n            position,\n            num_hit,\n            num_threshold,\n            missing_indptr,\n            missing_edge_tensors,\n        )\n\n    @property\n    def miss_rate(self):\n        \"\"\"Returns the cache miss rate since creation.\"\"\"\n        return self.total_miss / self.total_queries\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/in_subgraph_sampler.py",
    "content": "\"\"\"In-subgraph sampler for GraphBolt.\"\"\"\n\nfrom torch.utils.data import functional_datapipe\n\nfrom ..internal import unique_and_compact_csc_formats\n\nfrom ..subgraph_sampler import SubgraphSampler\nfrom .sampled_subgraph_impl import SampledSubgraphImpl\n\n\n__all__ = [\"InSubgraphSampler\"]\n\n\n@functional_datapipe(\"sample_in_subgraph\")\nclass InSubgraphSampler(SubgraphSampler):\n    \"\"\"Sample the subgraph induced on the inbound edges of the given nodes.\n\n    Functional name: :obj:`sample_in_subgraph`.\n\n    In-subgraph sampler is responsible for sampling a subgraph from given data,\n    returning an induced subgraph along with compacted information.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The datapipe.\n    graph : FusedCSCSamplingGraph\n        The graph on which to perform in_subgraph sampling.\n\n    Examples\n    -------\n    >>> import dgl.graphbolt as gb\n    >>> import torch\n    >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])\n    >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])\n    >>> graph = gb.fused_csc_sampling_graph(indptr, indices)\n    >>> item_set = gb.ItemSet(len(indptr) - 1, names=\"seeds\")\n    >>> item_sampler = gb.ItemSampler(item_set, batch_size=2)\n    >>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)\n    >>> for _, data in enumerate(insubgraph_sampler):\n    ...     print(data.sampled_subgraphs[0].sampled_csc)\n    ...     print(data.sampled_subgraphs[0].original_row_node_ids)\n    ...     
print(data.sampled_subgraphs[0].original_column_node_ids)\n    CSCFormatBase(indptr=tensor([0, 3, 5]),\n                indices=tensor([0, 1, 2, 3, 4]),\n    )\n    tensor([0, 1, 4, 2, 3])\n    tensor([0, 1])\n    CSCFormatBase(indptr=tensor([0, 2, 4]),\n                indices=tensor([2, 3, 4, 0]),\n    )\n    tensor([2, 3, 0, 5, 1])\n    tensor([2, 3])\n    CSCFormatBase(indptr=tensor([0, 3, 5]),\n                indices=tensor([2, 3, 1, 4, 0]),\n    )\n    tensor([4, 5, 0, 3, 1])\n    tensor([4, 5])\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        graph,\n    ):\n        super().__init__(datapipe)\n        self.graph = graph\n        self.sampler = graph.in_subgraph\n\n    def sample_subgraphs(\n        self, seeds, seeds_timestamp, seeds_pre_time_window=None\n    ):\n        subgraph = self.sampler(seeds)\n        (\n            original_row_node_ids,\n            compacted_csc_formats,\n            _,\n        ) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)\n        subgraph = SampledSubgraphImpl(\n            sampled_csc=compacted_csc_formats,\n            original_column_node_ids=seeds,\n            original_row_node_ids=original_row_node_ids,\n            original_edge_ids=subgraph.original_edge_ids,\n        )\n        seeds = original_row_node_ids\n        return (seeds, [subgraph])\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/legacy_dataset.py",
    "content": "\"\"\"Graphbolt dataset for legacy DGLDataset.\"\"\"\n\nfrom typing import List, Union\n\nfrom ..base import etype_tuple_to_str\nfrom ..dataset import Dataset, Task\nfrom ..itemset import HeteroItemSet, ItemSet\nfrom ..sampling_graph import SamplingGraph\nfrom .basic_feature_store import BasicFeatureStore\nfrom .fused_csc_sampling_graph import from_dglgraph\nfrom .ondisk_dataset import OnDiskTask\nfrom .torch_based_feature_store import TorchBasedFeature\n\n\nclass LegacyDataset(Dataset):\n    \"\"\"A Graphbolt dataset for legacy DGLDataset.\"\"\"\n\n    def __init__(self, legacy):\n        # Only supports single graph cases.\n        assert len(legacy) == 1\n        graph = legacy[0]\n        # Handle OGB Dataset.\n        if isinstance(graph, tuple):\n            graph, _ = graph\n        if graph.is_homogeneous:\n            self._init_as_homogeneous_node_pred(legacy)\n        else:\n            self._init_as_heterogeneous_node_pred(legacy)\n\n    def _init_as_heterogeneous_node_pred(self, legacy):\n        def _init_item_set_dict(idx, labels):\n            item_set_dict = {}\n            for key in idx.keys():\n                item_set = ItemSet(\n                    (idx[key], labels[key][idx[key]]),\n                    names=(\"seeds\", \"labels\"),\n                )\n                item_set_dict[key] = item_set\n            return HeteroItemSet(item_set_dict)\n\n        # OGB Dataset has the idx split.\n        if hasattr(legacy, \"get_idx_split\"):\n            graph, labels = legacy[0]\n            split_idx = legacy.get_idx_split()\n\n            # Initialize tasks.\n            tasks = []\n            metadata = {\n                \"num_classes\": legacy.num_classes,\n                \"name\": \"node_classification\",\n            }\n            train_set = _init_item_set_dict(split_idx[\"train\"], labels)\n            validation_set = _init_item_set_dict(split_idx[\"valid\"], labels)\n            test_set = 
_init_item_set_dict(split_idx[\"test\"], labels)\n            task = OnDiskTask(metadata, train_set, validation_set, test_set)\n            tasks.append(task)\n            self._tasks = tasks\n\n            item_set_dict = {}\n            for ntype in graph.ntypes:\n                item_set = ItemSet(graph.num_nodes(ntype), names=\"seeds\")\n                item_set_dict[ntype] = item_set\n            self._all_nodes_set = HeteroItemSet(item_set_dict)\n\n            features = {}\n            for ntype in graph.ntypes:\n                for name in graph.nodes[ntype].data.keys():\n                    tensor = graph.nodes[ntype].data[name]\n                    if tensor.dim() == 1:\n                        tensor = tensor.view(-1, 1)\n                    features[(\"node\", ntype, name)] = TorchBasedFeature(tensor)\n            for etype in graph.canonical_etypes:\n                for name in graph.edges[etype].data.keys():\n                    tensor = graph.edges[etype].data[name]\n                    if tensor.dim() == 1:\n                        tensor = tensor.view(-1, 1)\n                    gb_etype = etype_tuple_to_str(etype)\n                    features[(\"edge\", gb_etype, name)] = TorchBasedFeature(\n                        tensor\n                    )\n            self._feature = BasicFeatureStore(features)\n            self._graph = from_dglgraph(graph, is_homogeneous=False)\n            self._dataset_name = legacy.name\n        else:\n            raise NotImplementedError(\n                \"Only support heterogeneous ogn node pred dataset\"\n            )\n\n    def _init_as_homogeneous_node_pred(self, legacy):\n        from dgl.data import AsNodePredDataset\n\n        legacy = AsNodePredDataset(legacy)\n\n        # Initialize tasks.\n        tasks = []\n        metadata = {\n            \"num_classes\": legacy.num_classes,\n            \"name\": \"node_classification\",\n        }\n        train_labels = 
legacy[0].ndata[\"label\"][legacy.train_idx]\n        validation_labels = legacy[0].ndata[\"label\"][legacy.val_idx]\n        test_labels = legacy[0].ndata[\"label\"][legacy.test_idx]\n        train_set = ItemSet(\n            (legacy.train_idx, train_labels),\n            names=(\"seeds\", \"labels\"),\n        )\n        validation_set = ItemSet(\n            (legacy.val_idx, validation_labels),\n            names=(\"seeds\", \"labels\"),\n        )\n        test_set = ItemSet(\n            (legacy.test_idx, test_labels), names=(\"seeds\", \"labels\")\n        )\n        task = OnDiskTask(metadata, train_set, validation_set, test_set)\n        tasks.append(task)\n        self._tasks = tasks\n\n        num_nodes = legacy[0].num_nodes()\n        self._all_nodes_set = ItemSet(num_nodes, names=\"seeds\")\n        features = {}\n        for name in legacy[0].ndata.keys():\n            tensor = legacy[0].ndata[name]\n            if tensor.dim() == 1:\n                tensor = tensor.view(-1, 1)\n            features[(\"node\", None, name)] = TorchBasedFeature(tensor)\n        for name in legacy[0].edata.keys():\n            tensor = legacy[0].edata[name]\n            if tensor.dim() == 1:\n                tensor = tensor.view(-1, 1)\n            features[(\"edge\", None, name)] = TorchBasedFeature(tensor)\n        self._feature = BasicFeatureStore(features)\n        self._graph = from_dglgraph(legacy[0], is_homogeneous=True)\n        self._dataset_name = legacy.name\n\n    @property\n    def tasks(self) -> List[Task]:\n        \"\"\"Return the tasks.\"\"\"\n        return self._tasks\n\n    @property\n    def graph(self) -> SamplingGraph:\n        \"\"\"Return the graph.\"\"\"\n        return self._graph\n\n    @property\n    def feature(self) -> BasicFeatureStore:\n        \"\"\"Return the feature.\"\"\"\n        return self._feature\n\n    @property\n    def dataset_name(self) -> str:\n        \"\"\"Return the dataset name.\"\"\"\n        return 
self._dataset_name\n\n    @property\n    def all_nodes_set(self) -> Union[ItemSet, HeteroItemSet]:\n        \"\"\"Return the itemset containing all nodes.\"\"\"\n        return self._all_nodes_set\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/neighbor_sampler.py",
    "content": "\"\"\"Neighbor subgraph samplers for GraphBolt.\"\"\"\n\nfrom functools import partial\n\nimport torch\nimport torch.distributed as thd\nfrom torch.utils.data import functional_datapipe\nfrom torch.utils.data.datapipes.iter import Mapper\n\nfrom ..base import (\n    etype_str_to_tuple,\n    get_host_to_device_uva_stream,\n    index_select,\n    ORIGINAL_EDGE_ID,\n)\nfrom ..internal import (\n    compact_csc_format,\n    unique_and_compact,\n    unique_and_compact_csc_formats,\n)\nfrom ..minibatch_transformer import MiniBatchTransformer\n\nfrom ..subgraph_sampler import all_to_all, revert_to_homo, SubgraphSampler\nfrom .fused_csc_sampling_graph import fused_csc_sampling_graph\nfrom .sampled_subgraph_impl import SampledSubgraphImpl\n\n\n__all__ = [\n    \"NeighborSampler\",\n    \"LayerNeighborSampler\",\n    \"SamplePerLayer\",\n    \"FetchInsubgraphData\",\n    \"CombineCachedAndFetchedInSubgraph\",\n]\n\n\n@functional_datapipe(\"fetch_cached_insubgraph_data\")\nclass FetchCachedInsubgraphData(Mapper):\n    \"\"\"Queries the GPUGraphCache and returns the missing seeds and a generator\n    handle that can be called with the fetched graph structure.\n    \"\"\"\n\n    def __init__(self, datapipe, gpu_graph_cache):\n        datapipe = datapipe.transform(self._fetch_per_layer).buffer()\n        super().__init__(datapipe, self._wait_query_future)\n        self.cache = gpu_graph_cache\n\n    def _fetch_per_layer(self, minibatch):\n        minibatch._async_handle = self.cache.query_async(minibatch._seeds)\n        # Start first stage\n        next(minibatch._async_handle)\n\n        return minibatch\n\n    @staticmethod\n    def _wait_query_future(minibatch):\n        minibatch._seeds = next(minibatch._async_handle)\n\n        return minibatch\n\n\n@functional_datapipe(\"combine_cached_and_fetched_insubgraph\")\nclass CombineCachedAndFetchedInSubgraph(Mapper):\n    \"\"\"Combined the fetched graph structure with the graph structure already\n    found 
inside the GPUGraphCache.\n    \"\"\"\n\n    def __init__(self, datapipe, prob_name):\n        datapipe = datapipe.transform(self._combine_per_layer).buffer()\n        super().__init__(datapipe, self._wait_replace_future)\n        self.prob_name = prob_name\n\n    def _combine_per_layer(self, minibatch):\n        subgraph = minibatch._sliced_sampling_graph\n\n        edge_tensors = [subgraph.indices]\n        if subgraph.type_per_edge is not None:\n            edge_tensors.append(subgraph.type_per_edge)\n        probs_or_mask = subgraph.edge_attribute(self.prob_name)\n        if probs_or_mask is not None:\n            edge_tensors.append(probs_or_mask)\n        edge_tensors.append(subgraph.edge_attribute(ORIGINAL_EDGE_ID))\n\n        minibatch._future = minibatch._async_handle.send(\n            (subgraph.csc_indptr, edge_tensors)\n        )\n        delattr(minibatch, \"_async_handle\")\n\n        return minibatch\n\n    def _wait_replace_future(self, minibatch):\n        subgraph = minibatch._sliced_sampling_graph\n        subgraph.csc_indptr, edge_tensors = minibatch._future.wait()\n        delattr(minibatch, \"_future\")\n\n        subgraph.indices = edge_tensors[0]\n        edge_tensors = edge_tensors[1:]\n        if subgraph.type_per_edge is not None:\n            subgraph.type_per_edge = edge_tensors[0]\n            edge_tensors = edge_tensors[1:]\n        probs_or_mask = subgraph.edge_attribute(self.prob_name)\n        if probs_or_mask is not None:\n            subgraph.add_edge_attribute(self.prob_name, edge_tensors[0])\n            edge_tensors = edge_tensors[1:]\n        subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_tensors[0])\n        edge_tensors = edge_tensors[1:]\n        assert len(edge_tensors) == 0\n\n        return minibatch\n\n\n@functional_datapipe(\"fetch_insubgraph_data\")\nclass FetchInsubgraphData(MiniBatchTransformer):\n    \"\"\"Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. 
If\n    the provided sample_per_layer_obj has a valid prob_name, then it reads the\n    probabilities of all the fetched edges. Furthermore, if type_per_edge tensor\n    exists in the underlying graph, then the types of all the fetched edges are\n    read as well.\"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        graph,\n        prob_name,\n    ):\n        datapipe = datapipe.transform(self._concat_hetero_seeds)\n        if graph._gpu_graph_cache is not None:\n            datapipe = datapipe.fetch_cached_insubgraph_data(\n                graph._gpu_graph_cache\n            )\n        datapipe = datapipe.transform(self._fetch_per_layer_stage_1)\n        datapipe = datapipe.buffer()\n        datapipe = datapipe.transform(self._fetch_per_layer_stage_2)\n        if graph._gpu_graph_cache is not None:\n            datapipe = datapipe.combine_cached_and_fetched_insubgraph(prob_name)\n        super().__init__(datapipe)\n        self.graph = graph\n        self.prob_name = prob_name\n\n    def _concat_hetero_seeds(self, minibatch):\n        \"\"\"Concatenates the seeds into a single tensor in the hetero case.\"\"\"\n        seeds = minibatch._seed_nodes\n        if isinstance(seeds, dict):\n            (\n                seeds,\n                seed_offsets,\n            ) = self.graph._convert_to_homogeneous_nodes(seeds)\n        else:\n            seed_offsets = None\n        minibatch._seeds = seeds\n        minibatch._seed_offsets = seed_offsets\n\n        return minibatch\n\n    def _fetch_per_layer_stage_1(self, minibatch):\n        minibatch._async_handle_fetch = self._fetch_per_layer_async(minibatch)\n        next(minibatch._async_handle_fetch)\n        return minibatch\n\n    def _fetch_per_layer_stage_2(self, minibatch):\n        minibatch = next(minibatch._async_handle_fetch)\n        delattr(minibatch, \"_async_handle_fetch\")\n        return minibatch\n\n    def _fetch_per_layer_async(self, minibatch):\n        stream = 
torch.cuda.current_stream()\n        uva_stream = get_host_to_device_uva_stream()\n        uva_stream.wait_stream(stream)\n        with torch.cuda.stream(uva_stream):\n            seeds = minibatch._seeds\n            seed_offsets = minibatch._seed_offsets\n            delattr(minibatch, \"_seeds\")\n            delattr(minibatch, \"_seed_offsets\")\n\n            seeds.record_stream(torch.cuda.current_stream())\n\n            # Packs tensors for batch slicing.\n            tensors_to_be_sliced = [self.graph.indices]\n\n            has_type_per_edge = False\n            if self.graph.type_per_edge is not None:\n                tensors_to_be_sliced.append(self.graph.type_per_edge)\n                has_type_per_edge = True\n\n            has_probs_or_mask = False\n            has_original_edge_ids = False\n            if self.graph.edge_attributes is not None:\n                probs_or_mask = self.graph.edge_attributes.get(\n                    self.prob_name, None\n                )\n                if probs_or_mask is not None:\n                    tensors_to_be_sliced.append(probs_or_mask)\n                    has_probs_or_mask = True\n                original_edge_ids = self.graph.edge_attributes.get(\n                    ORIGINAL_EDGE_ID, None\n                )\n                if original_edge_ids is not None:\n                    tensors_to_be_sliced.append(original_edge_ids)\n                    has_original_edge_ids = True\n\n            # Slices the batched tensors.\n            future = torch.ops.graphbolt.index_select_csc_batched_async(\n                self.graph.csc_indptr,\n                tensors_to_be_sliced,\n                seeds,\n                # When there are no edge ids, we assume it is arange(num_edges).\n                not has_original_edge_ids,\n                None,\n            )\n\n        yield\n\n        # graphbolt::async has already recorded a CUDAEvent for us and\n        # called CUDAStreamWaitEvent for us on the current 
stream.\n        indptr, sliced_tensors = future.wait()\n\n        for tensor in [indptr] + sliced_tensors:\n            tensor.record_stream(stream)\n\n        # Unpacks the sliced tensors.\n        indices = sliced_tensors[0]\n        sliced_tensors = sliced_tensors[1:]\n\n        type_per_edge = None\n        if has_type_per_edge:\n            type_per_edge = sliced_tensors[0]\n            sliced_tensors = sliced_tensors[1:]\n\n        probs_or_mask = None\n        if has_probs_or_mask:\n            probs_or_mask = sliced_tensors[0]\n            sliced_tensors = sliced_tensors[1:]\n\n        edge_ids = sliced_tensors[0]\n        sliced_tensors = sliced_tensors[1:]\n        assert len(sliced_tensors) == 0\n\n        subgraph = fused_csc_sampling_graph(\n            indptr,\n            indices,\n            node_type_offset=self.graph.node_type_offset,\n            type_per_edge=type_per_edge,\n            node_type_to_id=self.graph.node_type_to_id,\n            edge_type_to_id=self.graph.edge_type_to_id,\n        )\n        if self.prob_name is not None and probs_or_mask is not None:\n            subgraph.add_edge_attribute(self.prob_name, probs_or_mask)\n        subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_ids)\n\n        subgraph._indptr_node_type_offset_list = seed_offsets\n        minibatch._sliced_sampling_graph = subgraph\n\n        yield minibatch\n\n\n@functional_datapipe(\"sample_per_layer\")\nclass SamplePerLayer(MiniBatchTransformer):\n    \"\"\"Sample neighbor edges from a graph for a single layer.\"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        sampler,\n        fanout,\n        replace,\n        prob_name,\n        overlap_fetch,\n        asynchronous=False,\n    ):\n        graph = sampler.__self__\n        self.returning_indices_and_original_edge_ids_are_optional = False\n        original_edge_ids = (\n            None\n            if graph.edge_attributes is None\n            else 
graph.edge_attributes.get(ORIGINAL_EDGE_ID, None)\n        )\n        if (\n            overlap_fetch\n            and sampler.__name__ == \"sample_neighbors\"\n            and (\n                graph.indices.is_pinned()\n                or (\n                    original_edge_ids is not None\n                    and original_edge_ids.is_pinned()\n                )\n            )\n            and graph._gpu_graph_cache is None\n        ):\n            datapipe = datapipe.transform(self._sample_per_layer)\n            if asynchronous:\n                datapipe = datapipe.buffer()\n                datapipe = datapipe.transform(self._wait_subgraph_future)\n            fetch_indices_and_original_edge_ids_fn = partial(\n                self._fetch_indices_and_original_edge_ids,\n                graph.indices,\n                original_edge_ids,\n            )\n            datapipe = (\n                datapipe.transform(fetch_indices_and_original_edge_ids_fn)\n                .buffer()\n                .wait()\n            )\n            if graph.type_per_edge is not None:\n                # Hetero case.\n                datapipe = datapipe.transform(\n                    partial(\n                        self._subtract_hetero_indices_offset,\n                        graph._node_type_offset_list,\n                        graph.node_type_to_id,\n                    )\n                )\n            self.returning_indices_and_original_edge_ids_are_optional = True\n        elif overlap_fetch:\n            datapipe = datapipe.fetch_insubgraph_data(graph, prob_name)\n            datapipe = datapipe.transform(\n                self._sample_per_layer_from_fetched_subgraph\n            )\n            if asynchronous:\n                datapipe = datapipe.buffer()\n                datapipe = datapipe.transform(self._wait_subgraph_future)\n        else:\n            datapipe = datapipe.transform(self._sample_per_layer)\n            if asynchronous:\n                datapipe = 
datapipe.buffer()\n                datapipe = datapipe.transform(self._wait_subgraph_future)\n        super().__init__(datapipe)\n        self.sampler = sampler\n        self.fanout = fanout\n        self.replace = replace\n        self.prob_name = prob_name\n        self.overlap_fetch = overlap_fetch\n        self.asynchronous = asynchronous\n\n    def _sample_per_layer(self, minibatch):\n        kwargs = {\n            key[1:]: getattr(minibatch, key)\n            for key in [\"_random_seed\", \"_seed2_contribution\"]\n            if hasattr(minibatch, key)\n        }\n        subgraph = self.sampler(\n            minibatch._seed_nodes,\n            self.fanout,\n            self.replace,\n            self.prob_name,\n            self.returning_indices_and_original_edge_ids_are_optional,\n            async_op=self.asynchronous,\n            **kwargs,\n        )\n        minibatch.sampled_subgraphs.insert(0, subgraph)\n        return minibatch\n\n    def _sample_per_layer_from_fetched_subgraph(self, minibatch):\n        subgraph = minibatch._sliced_sampling_graph\n        delattr(minibatch, \"_sliced_sampling_graph\")\n        kwargs = {\n            key[1:]: getattr(minibatch, key)\n            for key in [\"_random_seed\", \"_seed2_contribution\"]\n            if hasattr(minibatch, key)\n        }\n        sampled_subgraph = getattr(subgraph, self.sampler.__name__)(\n            None,\n            self.fanout,\n            self.replace,\n            self.prob_name,\n            async_op=self.asynchronous,\n            **kwargs,\n        )\n        minibatch.sampled_subgraphs.insert(0, sampled_subgraph)\n        return minibatch\n\n    @staticmethod\n    def _wait_subgraph_future(minibatch):\n        minibatch.sampled_subgraphs[0] = minibatch.sampled_subgraphs[0].wait()\n        return minibatch\n\n    @staticmethod\n    def _fetch_indices_and_original_edge_ids(indices, orig_edge_ids, minibatch):\n        stream = torch.cuda.current_stream()\n        
host_to_device_stream = get_host_to_device_uva_stream()\n        host_to_device_stream.wait_stream(stream)\n\n        def record_stream(tensor):\n            tensor.record_stream(stream)\n            return tensor\n\n        with torch.cuda.stream(host_to_device_stream):\n            minibatch._indices_needs_offset_subtraction = False\n            subgraph = minibatch.sampled_subgraphs[0]\n            if isinstance(subgraph.sampled_csc, dict):\n                for etype, pair in subgraph.sampled_csc.items():\n                    if pair.indices is None:\n                        edge_ids = (\n                            subgraph._edge_ids_in_fused_csc_sampling_graph[\n                                etype\n                            ]\n                        )\n                        edge_ids.record_stream(torch.cuda.current_stream())\n                        pair.indices = record_stream(\n                            index_select(indices, edge_ids)\n                        )\n                        minibatch._indices_needs_offset_subtraction = True\n                    if (\n                        orig_edge_ids is not None\n                        and subgraph.original_edge_ids[etype] is None\n                    ):\n                        edge_ids = (\n                            subgraph._edge_ids_in_fused_csc_sampling_graph[\n                                etype\n                            ]\n                        )\n                        edge_ids.record_stream(torch.cuda.current_stream())\n                        subgraph.original_edge_ids[etype] = record_stream(\n                            index_select(orig_edge_ids, edge_ids)\n                        )\n            else:\n                if subgraph.sampled_csc.indices is None:\n                    subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream(\n                        torch.cuda.current_stream()\n                    )\n                    subgraph.sampled_csc.indices = 
record_stream(\n                        index_select(\n                            indices,\n                            subgraph._edge_ids_in_fused_csc_sampling_graph,\n                        )\n                    )\n                if (\n                    orig_edge_ids is not None\n                    and subgraph.original_edge_ids is None\n                ):\n                    subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream(\n                        torch.cuda.current_stream()\n                    )\n                    subgraph.original_edge_ids = record_stream(\n                        index_select(\n                            orig_edge_ids,\n                            subgraph._edge_ids_in_fused_csc_sampling_graph,\n                        )\n                    )\n            subgraph._edge_ids_in_fused_csc_sampling_graph = None\n            minibatch.wait = torch.cuda.current_stream().record_event().wait\n\n        return minibatch\n\n    @staticmethod\n    def _subtract_hetero_indices_offset(\n        node_type_offset, node_type_to_id, minibatch\n    ):\n        if minibatch._indices_needs_offset_subtraction:\n            subgraph = minibatch.sampled_subgraphs[0]\n            for etype, pair in subgraph.sampled_csc.items():\n                src_ntype = etype_str_to_tuple(etype)[0]\n                src_ntype_id = node_type_to_id[src_ntype]\n                pair.indices -= node_type_offset[src_ntype_id]\n        delattr(minibatch, \"_indices_needs_offset_subtraction\")\n\n        return minibatch\n\n\n@functional_datapipe(\"compact_per_layer\")\nclass CompactPerLayer(MiniBatchTransformer):\n    \"\"\"Compact the sampled edges for a single layer.\"\"\"\n\n    def __init__(\n        self, datapipe, deduplicate, cooperative=False, asynchronous=False\n    ):\n        self.deduplicate = deduplicate\n        self.cooperative = cooperative\n        if asynchronous and deduplicate:\n            datapipe = 
datapipe.transform(self._compact_per_layer_async)\n            datapipe = datapipe.buffer()\n            datapipe = datapipe.transform(self._compact_per_layer_wait_future)\n            if cooperative:\n                datapipe = datapipe.transform(\n                    self._seeds_cooperative_exchange_1\n                )\n                datapipe = datapipe.buffer()\n                datapipe = datapipe.transform(\n                    self._seeds_cooperative_exchange_2\n                )\n                datapipe = datapipe.buffer()\n                datapipe = datapipe.transform(\n                    self._seeds_cooperative_exchange_3\n                )\n                datapipe = datapipe.buffer()\n                datapipe = datapipe.transform(\n                    self._seeds_cooperative_exchange_4\n                )\n            super().__init__(datapipe)\n        else:\n            super().__init__(datapipe, self._compact_per_layer)\n\n    def _compact_per_layer(self, minibatch):\n        subgraph = minibatch.sampled_subgraphs[0]\n        seeds = minibatch._seed_nodes\n        if self.deduplicate:\n            (\n                original_row_node_ids,\n                compacted_csc_format,\n                _,\n            ) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)\n            subgraph = SampledSubgraphImpl(\n                sampled_csc=compacted_csc_format,\n                original_column_node_ids=seeds,\n                original_row_node_ids=original_row_node_ids,\n                original_edge_ids=subgraph.original_edge_ids,\n            )\n        else:\n            (\n                original_row_node_ids,\n                compacted_csc_format,\n            ) = compact_csc_format(subgraph.sampled_csc, seeds)\n            subgraph = SampledSubgraphImpl(\n                sampled_csc=compacted_csc_format,\n                original_column_node_ids=seeds,\n                original_row_node_ids=original_row_node_ids,\n                
original_edge_ids=subgraph.original_edge_ids,\n            )\n        minibatch._seed_nodes = original_row_node_ids\n        minibatch.sampled_subgraphs[0] = subgraph\n        return minibatch\n\n    def _compact_per_layer_async(self, minibatch):\n        subgraph = minibatch.sampled_subgraphs[0]\n        seeds = minibatch._seed_nodes\n        assert self.deduplicate\n        rank = thd.get_rank() if self.cooperative else 0\n        world_size = thd.get_world_size() if self.cooperative else 1\n        minibatch._future = unique_and_compact_csc_formats(\n            subgraph.sampled_csc, seeds, rank, world_size, async_op=True\n        )\n        return minibatch\n\n    def _compact_per_layer_wait_future(self, minibatch):\n        subgraph = minibatch.sampled_subgraphs[0]\n        seeds = minibatch._seed_nodes\n        (\n            original_row_node_ids,\n            compacted_csc_format,\n            seeds_offsets,\n        ) = minibatch._future.wait()\n        delattr(minibatch, \"_future\")\n        subgraph = SampledSubgraphImpl(\n            sampled_csc=compacted_csc_format,\n            original_column_node_ids=seeds,\n            original_row_node_ids=original_row_node_ids,\n            original_edge_ids=subgraph.original_edge_ids,\n        )\n        minibatch._seed_nodes = original_row_node_ids\n        minibatch.sampled_subgraphs[0] = subgraph\n        if self.cooperative:\n            subgraph._seeds_offsets = seeds_offsets\n        return minibatch\n\n    @staticmethod\n    def _seeds_cooperative_exchange_1(minibatch):\n        world_size = thd.get_world_size()\n        subgraph = minibatch.sampled_subgraphs[0]\n        seeds_offsets = subgraph._seeds_offsets\n        is_homogeneous = not isinstance(seeds_offsets, dict)\n        if is_homogeneous:\n            seeds_offsets = {\"_N\": seeds_offsets}\n        num_ntypes = len(seeds_offsets)\n        counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)\n        for i, offsets in 
enumerate(seeds_offsets.values()):\n            counts_sent[\n                torch.arange(i, world_size * num_ntypes, num_ntypes)\n            ] = offsets.diff()\n        counts_received = torch.empty_like(counts_sent)\n        subgraph._counts_future = all_to_all(\n            counts_received.split(num_ntypes),\n            counts_sent.split(num_ntypes),\n            async_op=True,\n        )\n        subgraph._counts_sent = counts_sent\n        subgraph._counts_received = counts_received\n        return minibatch\n\n    @staticmethod\n    def _seeds_cooperative_exchange_2(minibatch):\n        world_size = thd.get_world_size()\n        seeds = minibatch._seed_nodes\n        is_homogenous = not isinstance(seeds, dict)\n        if is_homogenous:\n            seeds = {\"_N\": seeds}\n        subgraph = minibatch.sampled_subgraphs[0]\n        subgraph._counts_future.wait()\n        delattr(subgraph, \"_counts_future\")\n        num_ntypes = len(seeds.keys())\n        seeds_received = {}\n        counts_sent = {}\n        counts_received = {}\n        for i, (ntype, typed_seeds) in enumerate(seeds.items()):\n            idx = torch.arange(i, world_size * num_ntypes, num_ntypes)\n            typed_counts_sent = subgraph._counts_sent[idx].tolist()\n            typed_counts_received = subgraph._counts_received[idx].tolist()\n            typed_seeds_received = typed_seeds.new_empty(\n                sum(typed_counts_received)\n            )\n            all_to_all(\n                typed_seeds_received.split(typed_counts_received),\n                typed_seeds.split(typed_counts_sent),\n            )\n            seeds_received[ntype] = typed_seeds_received\n            counts_sent[ntype] = typed_counts_sent\n            counts_received[ntype] = typed_counts_received\n        minibatch._seed_nodes = seeds_received\n        subgraph._counts_sent = revert_to_homo(counts_sent)\n        subgraph._counts_received = revert_to_homo(counts_received)\n        return minibatch\n\n  
  @staticmethod\n    def _seeds_cooperative_exchange_3(minibatch):\n        nodes = {\n            ntype: [typed_seeds]\n            for ntype, typed_seeds in minibatch._seed_nodes.items()\n        }\n        minibatch._unique_future = unique_and_compact(\n            nodes, 0, 1, async_op=True\n        )\n        return minibatch\n\n    @staticmethod\n    def _seeds_cooperative_exchange_4(minibatch):\n        unique_seeds, inverse_seeds, _ = minibatch._unique_future.wait()\n        delattr(minibatch, \"_unique_future\")\n        inverse_seeds = {\n            ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items()\n        }\n        minibatch._seed_nodes = revert_to_homo(unique_seeds)\n        subgraph = minibatch.sampled_subgraphs[0]\n        sizes = {\n            ntype: typed_seeds.size(0)\n            for ntype, typed_seeds in unique_seeds.items()\n        }\n        subgraph._seed_sizes = revert_to_homo(sizes)\n        subgraph._seed_inverse_ids = revert_to_homo(inverse_seeds)\n        return minibatch\n\n\nclass NeighborSamplerImpl(SubgraphSampler):\n    # pylint: disable=abstract-method\n    \"\"\"Base class for NeighborSamplers.\"\"\"\n\n    # pylint: disable=useless-super-delegation\n    def __init__(\n        self,\n        datapipe,\n        graph,\n        fanouts,\n        replace,\n        prob_name,\n        deduplicate,\n        sampler,\n        overlap_fetch,\n        num_gpu_cached_edges,\n        gpu_cache_threshold,\n        cooperative,\n        asynchronous,\n        layer_dependency=None,\n        batch_dependency=None,\n    ):\n        if overlap_fetch and num_gpu_cached_edges > 0:\n            if graph._gpu_graph_cache is None:\n                graph._initialize_gpu_graph_cache(\n                    num_gpu_cached_edges, gpu_cache_threshold, prob_name\n                )\n        if sampler.__name__ == \"sample_layer_neighbors\":\n            self._init_seed(batch_dependency)\n        super().__init__(\n            datapipe,\n   
         graph,\n            fanouts,\n            replace,\n            prob_name,\n            deduplicate,\n            sampler,\n            overlap_fetch,\n            cooperative=cooperative,\n            asynchronous=asynchronous,\n            layer_dependency=layer_dependency,\n        )\n\n    def _init_seed(self, batch_dependency):\n        self.rng = torch.random.manual_seed(\n            torch.randint(0, int(1e18), size=tuple())\n        )\n        self.cnt = [-1, int(batch_dependency)]\n        self.random_seed = torch.empty(\n            2 if self.cnt[1] > 1 else 1, dtype=torch.int64\n        )\n        self.random_seed.random_(generator=self.rng)\n\n    def _set_seed(self, minibatch):\n        self.cnt[0] += 1\n        if self.cnt[1] > 0 and self.cnt[0] % self.cnt[1] == 0:\n            self.random_seed[0] = self.random_seed[-1]\n            self.random_seed[-1:].random_(generator=self.rng)\n        minibatch._random_seed = self.random_seed.clone()\n        minibatch._seed2_contribution = (\n            0.0\n            if self.cnt[1] <= 1\n            else (self.cnt[0] % self.cnt[1]) / self.cnt[1]\n        )\n        minibatch._iter = self.cnt[0]\n        return minibatch\n\n    @staticmethod\n    def _increment_seed(minibatch):\n        minibatch._random_seed = 1 + minibatch._random_seed\n        return minibatch\n\n    @staticmethod\n    def _delattr_dependency(minibatch):\n        delattr(minibatch, \"_random_seed\")\n        delattr(minibatch, \"_seed2_contribution\")\n        return minibatch\n\n    @staticmethod\n    def _prepare(node_type_to_id, minibatch):\n        seeds = minibatch._seed_nodes\n        # Enrich seeds with all node types.\n        if isinstance(seeds, dict):\n            ntypes = list(node_type_to_id.keys())\n            # Loop over different seeds to extract the device they are on.\n            device = None\n            dtype = None\n            for _, seed in seeds.items():\n                device = seed.device\n           
     dtype = seed.dtype\n                break\n            default_tensor = torch.tensor([], dtype=dtype, device=device)\n            seeds = {\n                ntype: seeds.get(ntype, default_tensor) for ntype in ntypes\n            }\n        minibatch._seed_nodes = seeds\n        minibatch.sampled_subgraphs = []\n        return minibatch\n\n    @staticmethod\n    def _set_input_nodes(minibatch):\n        minibatch.input_nodes = minibatch._seed_nodes\n        return minibatch\n\n    # pylint: disable=arguments-differ\n    def sampling_stages(\n        self,\n        datapipe,\n        graph,\n        fanouts,\n        replace,\n        prob_name,\n        deduplicate,\n        sampler,\n        overlap_fetch,\n        cooperative,\n        asynchronous,\n        layer_dependency,\n    ):\n        datapipe = datapipe.transform(\n            partial(self._prepare, graph.node_type_to_id)\n        )\n        is_labor = sampler.__name__ == \"sample_layer_neighbors\"\n        if is_labor:\n            datapipe = datapipe.transform(self._set_seed)\n        for fanout in reversed(fanouts):\n            # Convert fanout to tensor.\n            if not isinstance(fanout, torch.Tensor):\n                fanout = torch.LongTensor([int(fanout)])\n            datapipe = datapipe.sample_per_layer(\n                sampler, fanout, replace, prob_name, overlap_fetch, asynchronous\n            )\n            datapipe = datapipe.compact_per_layer(\n                deduplicate, cooperative, asynchronous\n            )\n            if is_labor and not layer_dependency:\n                datapipe = datapipe.transform(self._increment_seed)\n        if is_labor:\n            datapipe = datapipe.transform(self._delattr_dependency)\n        return datapipe.transform(self._set_input_nodes)\n\n\n@functional_datapipe(\"sample_neighbor\")\nclass NeighborSampler(NeighborSamplerImpl):\n    # pylint: disable=abstract-method\n    \"\"\"Sample neighbor edges from a graph and return a subgraph.\n\n  
  Functional name: :obj:`sample_neighbor`.\n\n    Neighbor sampler is responsible for sampling a subgraph from given data. It\n    returns an induced subgraph along with compacted information. In the\n    context of a node classification task, the neighbor sampler directly\n    utilizes the nodes provided as seed nodes. However, in scenarios involving\n    link prediction, the process needs another pre-peocess operation. That is,\n    gathering unique nodes from the given node pairs, encompassing both\n    positive and negative node pairs, and employs these nodes as the seed nodes\n    for subsequent steps. When the graph is hetero, sampled subgraphs in\n    minibatch will contain every edge type even though it is empty after\n    sampling.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The datapipe.\n    graph : FusedCSCSamplingGraph\n        The graph on which to perform subgraph sampling.\n    fanouts: list[torch.Tensor] or list[int]\n        The number of edges to be sampled for each node with or without\n        considering edge types. The length of this parameter implicitly\n        signifies the layer of sampling being conducted.\n        Note: The fanout order is from the outermost layer to innermost layer.\n        For example, the fanout '[15, 10, 5]' means that 15 to the outermost\n        layer, 10 to the intermediate layer and 5 corresponds to the innermost\n        layer.\n    replace: bool\n        Boolean indicating whether the sample is preformed with or\n        without replacement. If True, a value can be selected multiple\n        times. Otherwise, each value can be selected only once.\n    prob_name: str, optional\n        The name of an edge attribute used as the weights of sampling for\n        each node. 
This attribute tensor should contain (unnormalized)\n        probabilities corresponding to each neighboring edge of a node.\n        It must be a 1D floating-point or boolean tensor, with the number\n        of elements equalling the total number of edges.\n    deduplicate: bool\n        Boolean indicating whether seeds between hops will be deduplicated.\n        If True, the same elements in seeds will be deleted to only one.\n        Otherwise, the same elements will be remained.\n    overlap_fetch : bool, optional\n        If True, the data loader will overlap the UVA graph fetching operations\n        with the rest of operations by using an alternative CUDA stream. This\n        option should be enabled if you have moved your graph to the pinned\n        memory for optimal performance. Default is False.\n    num_gpu_cached_edges : int, optional\n        If positive and overlap_graph_fetch is True, then the GPU will cache\n        frequently accessed vertex neighborhoods to reduce the PCI-e bandwidth\n        demand due to pinned graph accesses.\n    gpu_cache_threshold : int, optional\n        Determines how many times a vertex needs to be accessed before its\n        neighborhood ends up being cached on the GPU.\n    cooperative: bool, optional\n        Boolean indicating whether Cooperative Minibatching, which was initially\n        proposed in\n        `Deep Graph Library PR#4337<https://github.com/dmlc/dgl/pull/4337>`__\n        and was later first fully described in\n        `Cooperative Minibatching in Graph Neural Networks\n        <https://arxiv.org/abs/2310.12403>`__. 
Cooperation between the GPUs\n        eliminates duplicate work performed across the GPUs due to the\n        overlapping sampled k-hop neighborhoods of seed nodes when performing\n        GNN minibatching.\n    asynchronous: bool\n        Boolean indicating whether sampling and compaction stages should run\n        in background threads to hide the latency of CPU GPU synchronization.\n        Should be enabled only when sampling on the GPU.\n\n    Examples\n    -------\n    >>> import torch\n    >>> import dgl.graphbolt as gb\n    >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])\n    >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])\n    >>> graph = gb.fused_csc_sampling_graph(indptr, indices)\n    >>> seeds = torch.LongTensor([[0, 1], [1, 2]])\n    >>> item_set = gb.ItemSet(seeds, names=\"seeds\")\n    >>> datapipe = gb.ItemSampler(item_set, batch_size=1)\n    >>> datapipe = datapipe.sample_uniform_negative(graph, 2)\n    >>> datapipe = datapipe.sample_neighbor(graph, [5, 10, 15])\n    >>> next(iter(datapipe)).sampled_subgraphs\n    [SampledSubgraphImpl(sampled_csc=CSCFormatBase(\n            indptr=tensor([0, 2, 4, 5, 6, 7, 8]),\n            indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),\n        ),\n        original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),\n        original_edge_ids=None,\n        original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),\n    ),\n    SampledSubgraphImpl(sampled_csc=CSCFormatBase(\n            indptr=tensor([0, 2, 4, 5, 6, 7, 8]),\n            indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),\n        ),\n        original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),\n        original_edge_ids=None,\n        original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),\n    ),\n    SampledSubgraphImpl(sampled_csc=CSCFormatBase(\n            indptr=tensor([0, 2, 4, 5, 6]),\n            indices=tensor([1, 4, 0, 5, 5, 3]),\n        ),\n        original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),\n        original_edge_ids=None,\n        
original_column_node_ids=tensor([0, 1, 4, 5]),\n    )]\n    \"\"\"\n\n    # pylint: disable=useless-super-delegation\n    def __init__(\n        self,\n        datapipe,\n        graph,\n        fanouts,\n        replace=False,\n        prob_name=None,\n        deduplicate=True,\n        overlap_fetch=False,\n        num_gpu_cached_edges=0,\n        gpu_cache_threshold=1,\n        cooperative=False,\n        asynchronous=False,\n    ):\n        super().__init__(\n            datapipe,\n            graph,\n            fanouts,\n            replace,\n            prob_name,\n            deduplicate,\n            graph.sample_neighbors,\n            overlap_fetch,\n            num_gpu_cached_edges,\n            gpu_cache_threshold,\n            cooperative,\n            asynchronous,\n        )\n\n\n@functional_datapipe(\"sample_layer_neighbor\")\nclass LayerNeighborSampler(NeighborSamplerImpl):\n    # pylint: disable=abstract-method\n    \"\"\"Sample layer neighbor edges from a graph and return a subgraph.\n\n    Functional name: :obj:`sample_layer_neighbor`.\n\n    Sampler that builds computational dependency of node representations via\n    labor sampling for multilayer GNN from the NeurIPS 2023 paper\n    `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs\n    <https://proceedings.neurips.cc/paper_files/paper/2023/file/51f9036d5e7ae822da8f6d4adda1fb39-Paper-Conference.pdf>`__\n\n    Layer-Neighbor sampler is responsible for sampling a subgraph from given\n    data. It returns an induced subgraph along with compacted information. In\n    the context of a node classification task, the neighbor sampler directly\n    utilizes the nodes provided as seed nodes. However, in scenarios involving\n    link prediction, the process needs another pre-process operation. That is,\n    gathering unique nodes from the given node pairs, encompassing both\n    positive and negative node pairs, and employs these nodes as the seed nodes\n    for subsequent steps. 
When the graph is hetero, sampled subgraphs in\n    minibatch will contain every edge type even though it is empty after\n    sampling.\n\n    Implements the approach described in Appendix A.3 of the paper. Similar to\n    dgl.dataloading.LaborSampler but this uses sequential poisson sampling\n    instead of poisson sampling to keep the count of sampled edges per vertex\n    deterministic like NeighborSampler. Thus, it is a drop-in replacement for\n    NeighborSampler. However, unlike NeighborSampler, it samples fewer vertices\n    and edges for multilayer GNN scenario without harming convergence speed with\n    respect to training iterations.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The datapipe.\n    graph : FusedCSCSamplingGraph\n        The graph on which to perform subgraph sampling.\n    fanouts: list[torch.Tensor]\n        The number of edges to be sampled for each node with or without\n        considering edge types. The length of this parameter implicitly\n        signifies the layer of sampling being conducted.\n    replace: bool\n        Boolean indicating whether the sample is preformed with or\n        without replacement. If True, a value can be selected multiple\n        times. Otherwise, each value can be selected only once.\n    prob_name: str, optional\n        The name of an edge attribute used as the weights of sampling for\n        each node. 
This attribute tensor should contain (unnormalized)\n        probabilities corresponding to each neighboring edge of a node.\n        It must be a 1D floating-point or boolean tensor, with the number\n        of elements equalling the total number of edges.\n    deduplicate: bool\n        Boolean indicating whether seeds between hops will be deduplicated.\n        If True, the same elements in seeds will be deleted to only one.\n        Otherwise, the same elements will be remained.\n    layer_dependency: bool\n        Boolean indicating whether different layers should use the same random\n        variates. Results in a reduction in the number of nodes sampled and\n        turns LayerNeighborSampler into a subgraph sampling method. Later layers\n        will be guaranteed to sample overlapping neighbors as the previous\n        layers.\n    batch_dependency: int\n        Specifies whether consecutive minibatches should use similar random\n        variates. Results in a higher temporal access locality of sampled\n        nodes and edges. Setting it to :math:`\\\\kappa` slows down the change in\n        the random variates proportional to :math:`\\\\frac{1}{\\\\kappa}`. Implements\n        the dependent minibatching approach in `arXiv:2310.12403\n        <https://arxiv.org/abs/2310.12403>`__.\n    overlap_fetch : bool, optional\n        If True, the data loader will overlap the UVA graph fetching operations\n        with the rest of operations by using an alternative CUDA stream. This\n        option should be enabled if you have moved your graph to the pinned\n        memory for optimal performance. 
Default is False.\n    num_gpu_cached_edges : int, optional\n        If positive and overlap_graph_fetch is True, then the GPU will cache\n        frequently accessed vertex neighborhoods to reduce the PCI-e bandwidth\n        demand due to pinned graph accesses.\n    gpu_cache_threshold : int, optional\n        Determines how many times a vertex needs to be accessed before its\n        neighborhood ends up being cached on the GPU.\n    cooperative: bool, optional\n        Boolean indicating whether Cooperative Minibatching, which was initially\n        proposed in\n        `Deep Graph Library PR#4337<https://github.com/dmlc/dgl/pull/4337>`__\n        and was later first fully described in\n        `Cooperative Minibatching in Graph Neural Networks\n        <https://arxiv.org/abs/2310.12403>`__. Cooperation between the GPUs\n        eliminates duplicate work performed across the GPUs due to the\n        overlapping sampled k-hop neighborhoods of seed nodes when performing\n        GNN minibatching.\n    asynchronous: bool\n        Boolean indicating whether sampling and compaction stages should run\n        in background threads to hide the latency of CPU GPU synchronization.\n        Should be enabled only when sampling on the GPU.\n\n    Examples\n    -------\n    >>> import dgl.graphbolt as gb\n    >>> import torch\n    >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])\n    >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])\n    >>> graph = gb.fused_csc_sampling_graph(indptr, indices)\n    >>> seeds = torch.LongTensor([[0, 1], [1, 2]])\n    >>> item_set = gb.ItemSet(seeds, names=\"seeds\")\n    >>> item_sampler = gb.ItemSampler(item_set, batch_size=1,)\n    >>> neg_sampler = gb.UniformNegativeSampler(item_sampler, graph, 2)\n    >>> fanouts = [torch.LongTensor([5]),\n    ...     
torch.LongTensor([10]),torch.LongTensor([15])]\n    >>> subgraph_sampler = gb.LayerNeighborSampler(neg_sampler, graph, fanouts)\n    >>> next(iter(subgraph_sampler)).sampled_subgraphs\n    [SampledSubgraphImpl(sampled_csc=CSCFormatBase(\n            indptr=tensor([0, 2, 4, 5, 6, 7, 8]),\n            indices=tensor([1, 3, 0, 4, 2, 2, 5, 4]),\n        ),\n        original_row_node_ids=tensor([0, 1, 5, 2, 3, 4]),\n        original_edge_ids=None,\n        original_column_node_ids=tensor([0, 1, 5, 2, 3, 4]),\n    ),\n    SampledSubgraphImpl(sampled_csc=CSCFormatBase(\n            indptr=tensor([0, 2, 4, 5, 6, 7]),\n            indices=tensor([1, 3, 0, 4, 2, 2, 5]),\n        ),\n        original_row_node_ids=tensor([0, 1, 5, 2, 3, 4]),\n        original_edge_ids=None,\n        original_column_node_ids=tensor([0, 1, 5, 2, 3]),\n    ),\n    SampledSubgraphImpl(sampled_csc=CSCFormatBase(\n            indptr=tensor([0, 2, 4, 5, 6]),\n            indices=tensor([1, 3, 0, 4, 2, 2]),\n        ),\n        original_row_node_ids=tensor([0, 1, 5, 2, 3]),\n        original_edge_ids=None,\n        original_column_node_ids=tensor([0, 1, 5, 2]),\n    )]\n    >>> next(iter(subgraph_sampler)).compacted_seeds\n    tensor([[0, 1], [0, 2], [0, 3]])\n    >>> next(iter(subgraph_sampler)).labels\n    tensor([1., 0., 0.])\n    >>> next(iter(subgraph_sampler)).indexes\n    tensor([0, 0, 0])\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        graph,\n        fanouts,\n        replace=False,\n        prob_name=None,\n        deduplicate=True,\n        layer_dependency=False,\n        batch_dependency=1,\n        overlap_fetch=False,\n        num_gpu_cached_edges=0,\n        gpu_cache_threshold=1,\n        cooperative=False,\n        asynchronous=False,\n    ):\n        super().__init__(\n            datapipe,\n            graph,\n            fanouts,\n            replace,\n            prob_name,\n            deduplicate,\n            graph.sample_layer_neighbors,\n          
  overlap_fetch,\n            num_gpu_cached_edges,\n            gpu_cache_threshold,\n            cooperative,\n            asynchronous,\n            layer_dependency,\n            batch_dependency,\n        )\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/ondisk_dataset.py",
    "content": "\"\"\"GraphBolt OnDiskDataset.\"\"\"\n\nimport bisect\nimport json\nimport os\nimport shutil\nimport textwrap\nfrom copy import deepcopy\nfrom typing import Dict, List, Union\n\nimport numpy as np\n\nimport torch\nimport yaml\n\nfrom ..base import etype_str_to_tuple, ORIGINAL_EDGE_ID\nfrom ..dataset import Dataset, Task\nfrom ..internal import (\n    calculate_dir_hash,\n    check_dataset_change,\n    copy_or_convert_data,\n    read_data,\n    read_edges,\n)\nfrom ..internal_utils import (\n    download,\n    extract_archive,\n    gb_warning,\n    get_attributes,\n)\nfrom ..itemset import HeteroItemSet, ItemSet\nfrom ..sampling_graph import SamplingGraph\nfrom .fused_csc_sampling_graph import (\n    fused_csc_sampling_graph,\n    FusedCSCSamplingGraph,\n)\nfrom .ondisk_metadata import (\n    OnDiskGraphTopology,\n    OnDiskMetaData,\n    OnDiskTaskData,\n    OnDiskTVTSet,\n)\nfrom .torch_based_feature_store import TorchBasedFeatureStore\n\n__all__ = [\"OnDiskDataset\", \"preprocess_ondisk_dataset\", \"BuiltinDataset\"]\n\nNAMES_INDICATING_NODE_IDS = [\n    \"seeds\",\n]\n\n\ndef _graph_data_to_fused_csc_sampling_graph(\n    dataset_dir: str,\n    graph_data: Dict,\n    include_original_edge_id: bool,\n    auto_cast_to_optimal_dtype: bool,\n) -> FusedCSCSamplingGraph:\n    \"\"\"Convert the raw graph data into FusedCSCSamplingGraph.\n\n    Parameters\n    ----------\n    dataset_dir : str\n        The path to the dataset directory.\n    graph_data : Dict\n        The raw data read from yaml file.\n    include_original_edge_id : bool\n        Whether to include the original edge id in the FusedCSCSamplingGraph.\n    auto_cast_to_optimal_dtype: bool, optional\n        Casts the dtypes of tensors in the dataset into smallest possible dtypes\n        for reduced storage requirements and potentially increased performance.\n\n    Returns\n    -------\n    sampling_graph : FusedCSCSamplingGraph\n        The FusedCSCSamplingGraph constructed from the raw 
data.\n    \"\"\"\n    from ...sparse import spmatrix\n\n    is_homogeneous = (\n        len(graph_data[\"nodes\"]) == 1\n        and len(graph_data[\"edges\"]) == 1\n        and \"type\" not in graph_data[\"nodes\"][0]\n        and \"type\" not in graph_data[\"edges\"][0]\n    )\n\n    if is_homogeneous:\n        # Homogeneous graph.\n        edge_fmt = graph_data[\"edges\"][0][\"format\"]\n        edge_path = graph_data[\"edges\"][0][\"path\"]\n        src, dst = read_edges(dataset_dir, edge_fmt, edge_path)\n        num_nodes = graph_data[\"nodes\"][0][\"num\"]\n        num_edges = len(src)\n        coo_tensor = torch.tensor(np.array([src, dst]))\n        sparse_matrix = spmatrix(coo_tensor, shape=(num_nodes, num_nodes))\n        del coo_tensor\n        indptr, indices, edge_ids = sparse_matrix.csc()\n        del sparse_matrix\n\n        if auto_cast_to_optimal_dtype:\n            if num_nodes <= torch.iinfo(torch.int32).max:\n                indices = indices.to(torch.int32)\n            if num_edges <= torch.iinfo(torch.int32).max:\n                indptr = indptr.to(torch.int32)\n                edge_ids = edge_ids.to(torch.int32)\n\n        node_type_offset = None\n        type_per_edge = None\n        node_type_to_id = None\n        edge_type_to_id = None\n        node_attributes = {}\n        edge_attributes = {}\n        if include_original_edge_id:\n            edge_attributes[ORIGINAL_EDGE_ID] = edge_ids\n    else:\n        # Heterogeneous graph.\n        # Sort graph_data by ntype/etype lexicographically to ensure ordering.\n        graph_data[\"nodes\"].sort(key=lambda x: x[\"type\"])\n        graph_data[\"edges\"].sort(key=lambda x: x[\"type\"])\n        # Construct node_type_offset and node_type_to_id.\n        node_type_offset = [0]\n        node_type_to_id = {}\n        for ntype_id, node_info in enumerate(graph_data[\"nodes\"]):\n            node_type_to_id[node_info[\"type\"]] = ntype_id\n            node_type_offset.append(node_type_offset[-1] + 
node_info[\"num\"])\n        total_num_nodes = node_type_offset[-1]\n        # Construct edge_type_offset, edge_type_to_id and coo_tensor.\n        edge_type_offset = [0]\n        edge_type_to_id = {}\n        coo_src_list = []\n        coo_dst_list = []\n        coo_etype_list = []\n        for etype_id, edge_info in enumerate(graph_data[\"edges\"]):\n            edge_type_to_id[edge_info[\"type\"]] = etype_id\n            edge_fmt = edge_info[\"format\"]\n            edge_path = edge_info[\"path\"]\n            src, dst = read_edges(dataset_dir, edge_fmt, edge_path)\n            edge_type_offset.append(edge_type_offset[-1] + len(src))\n            src_type, _, dst_type = etype_str_to_tuple(edge_info[\"type\"])\n            src += node_type_offset[node_type_to_id[src_type]]\n            dst += node_type_offset[node_type_to_id[dst_type]]\n            coo_src_list.append(torch.tensor(src))\n            coo_dst_list.append(torch.tensor(dst))\n            coo_etype_list.append(torch.full((len(src),), etype_id))\n        total_num_edges = edge_type_offset[-1]\n\n        coo_src = torch.cat(coo_src_list)\n        del coo_src_list\n        coo_dst = torch.cat(coo_dst_list)\n        del coo_dst_list\n        if auto_cast_to_optimal_dtype:\n            dtypes = [torch.uint8, torch.int16, torch.int32, torch.int64]\n            dtype_maxes = [torch.iinfo(dtype).max for dtype in dtypes]\n            dtype_id = bisect.bisect_left(dtype_maxes, len(edge_type_to_id) - 1)\n            etype_dtype = dtypes[dtype_id]\n            coo_etype_list = [\n                tensor.to(etype_dtype) for tensor in coo_etype_list\n            ]\n        coo_etype = torch.cat(coo_etype_list)\n        del coo_etype_list\n\n        sparse_matrix = spmatrix(\n            indices=torch.stack((coo_src, coo_dst), dim=0),\n            shape=(total_num_nodes, total_num_nodes),\n        )\n        del coo_src, coo_dst\n        indptr, indices, edge_ids = sparse_matrix.csc()\n        del sparse_matrix\n\n   
     if auto_cast_to_optimal_dtype:\n            if total_num_nodes <= torch.iinfo(torch.int32).max:\n                indices = indices.to(torch.int32)\n            if total_num_edges <= torch.iinfo(torch.int32).max:\n                indptr = indptr.to(torch.int32)\n                edge_ids = edge_ids.to(torch.int32)\n\n        node_type_offset = torch.tensor(node_type_offset, dtype=indices.dtype)\n        type_per_edge = torch.index_select(coo_etype, dim=0, index=edge_ids)\n        del coo_etype\n        node_attributes = {}\n        edge_attributes = {}\n        if include_original_edge_id:\n            # If uint8 or int16 was chosen above for etypes, we cast to int.\n            temp_etypes = (\n                type_per_edge.int()\n                if type_per_edge.element_size() < 4\n                else type_per_edge\n            )\n            edge_ids -= torch.index_select(\n                torch.tensor(edge_type_offset, dtype=edge_ids.dtype),\n                dim=0,\n                index=temp_etypes,\n            )\n            del temp_etypes\n            edge_attributes[ORIGINAL_EDGE_ID] = edge_ids\n\n    # Load the sampling related node/edge features and add them to\n    # the sampling-graph.\n    if graph_data.get(\"feature_data\", None):\n        if is_homogeneous:\n            # Homogeneous graph.\n            for graph_feature in graph_data[\"feature_data\"]:\n                in_memory = (\n                    True\n                    if \"in_memory\" not in graph_feature\n                    else graph_feature[\"in_memory\"]\n                )\n                if graph_feature[\"domain\"] == \"node\":\n                    node_data = read_data(\n                        os.path.join(dataset_dir, graph_feature[\"path\"]),\n                        graph_feature[\"format\"],\n                        in_memory=in_memory,\n                    )\n                    assert node_data.shape[0] == num_nodes\n                    
node_attributes[graph_feature[\"name\"]] = node_data\n                elif graph_feature[\"domain\"] == \"edge\":\n                    edge_data = read_data(\n                        os.path.join(dataset_dir, graph_feature[\"path\"]),\n                        graph_feature[\"format\"],\n                        in_memory=in_memory,\n                    )\n                    assert edge_data.shape[0] == num_edges\n                    edge_attributes[graph_feature[\"name\"]] = edge_data\n        else:\n            # Heterogeneous graph.\n            node_feature_collector = {}\n            edge_feature_collector = {}\n            for graph_feature in graph_data[\"feature_data\"]:\n                in_memory = (\n                    True\n                    if \"in_memory\" not in graph_feature\n                    else graph_feature[\"in_memory\"]\n                )\n                if graph_feature[\"domain\"] == \"node\":\n                    node_data = read_data(\n                        os.path.join(dataset_dir, graph_feature[\"path\"]),\n                        graph_feature[\"format\"],\n                        in_memory=in_memory,\n                    )\n                    if graph_feature[\"name\"] not in node_feature_collector:\n                        node_feature_collector[graph_feature[\"name\"]] = {}\n                    node_feature_collector[graph_feature[\"name\"]][\n                        graph_feature[\"type\"]\n                    ] = node_data\n                elif graph_feature[\"domain\"] == \"edge\":\n                    edge_data = read_data(\n                        os.path.join(dataset_dir, graph_feature[\"path\"]),\n                        graph_feature[\"format\"],\n                        in_memory=in_memory,\n                    )\n                    if graph_feature[\"name\"] not in edge_feature_collector:\n                        edge_feature_collector[graph_feature[\"name\"]] = {}\n                    
edge_feature_collector[graph_feature[\"name\"]][\n                        graph_feature[\"type\"]\n                    ] = edge_data\n\n            # For heterogenous, a node/edge feature must cover all node/edge types.\n            all_node_types = set(node_type_to_id.keys())\n            for feat_name, feat_data in node_feature_collector.items():\n                existing_node_type = set(feat_data.keys())\n                assert all_node_types == existing_node_type, (\n                    f\"Node feature {feat_name} does not cover all node types. \"\n                    f\"Existing types: {existing_node_type}. \"\n                    f\"Expected types: {all_node_types}.\"\n                )\n            all_edge_types = set(edge_type_to_id.keys())\n            for feat_name, feat_data in edge_feature_collector.items():\n                existing_edge_type = set(feat_data.keys())\n                assert all_edge_types == existing_edge_type, (\n                    f\"Edge feature {feat_name} does not cover all edge types. \"\n                    f\"Existing types: {existing_edge_type}. 
\"\n                    f\"Expected types: {all_edge_types}.\"\n                )\n\n            for feat_name, feat_data in node_feature_collector.items():\n                _feat = next(iter(feat_data.values()))\n                feat_tensor = torch.empty(\n                    ([total_num_nodes] + list(_feat.shape[1:])),\n                    dtype=_feat.dtype,\n                )\n                for ntype, feat in feat_data.items():\n                    feat_tensor[\n                        node_type_offset[\n                            node_type_to_id[ntype]\n                        ] : node_type_offset[node_type_to_id[ntype] + 1]\n                    ] = feat\n                node_attributes[feat_name] = feat_tensor\n            del node_feature_collector\n            for feat_name, feat_data in edge_feature_collector.items():\n                _feat = next(iter(feat_data.values()))\n                feat_tensor = torch.empty(\n                    ([total_num_edges] + list(_feat.shape[1:])),\n                    dtype=_feat.dtype,\n                )\n                for etype, feat in feat_data.items():\n                    feat_tensor[\n                        edge_type_offset[\n                            edge_type_to_id[etype]\n                        ] : edge_type_offset[edge_type_to_id[etype] + 1]\n                    ] = feat\n                edge_attributes[feat_name] = feat_tensor\n            del edge_feature_collector\n\n    if not bool(node_attributes):\n        node_attributes = None\n    if not bool(edge_attributes):\n        edge_attributes = None\n\n    # Construct the FusedCSCSamplingGraph.\n    return fused_csc_sampling_graph(\n        csc_indptr=indptr,\n        indices=indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=node_type_to_id,\n        edge_type_to_id=edge_type_to_id,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n    )\n\n\ndef 
preprocess_ondisk_dataset(\n    dataset_dir: str,\n    include_original_edge_id: bool = False,\n    force_preprocess: bool = None,\n    auto_cast_to_optimal_dtype: bool = True,\n) -> str:\n    \"\"\"Preprocess the on-disk dataset. Parse the input config file,\n    load the data, and save the data in the format that GraphBolt supports.\n\n    Parameters\n    ----------\n    dataset_dir : str\n        The path to the dataset directory.\n    include_original_edge_id : bool, optional\n        Whether to include the original edge id in the FusedCSCSamplingGraph.\n    force_preprocess: bool, optional\n        Whether to force reload the ondisk dataset.\n    auto_cast_to_optimal_dtype: bool, optional\n        Casts the dtypes of tensors in the dataset into smallest possible dtypes\n        for reduced storage requirements and potentially increased performance.\n        Default is True.\n\n    Returns\n    -------\n    output_config_path : str\n        The path to the output config file.\n    \"\"\"\n    # Check if the dataset path is valid.\n    if not os.path.exists(dataset_dir):\n        raise RuntimeError(f\"Invalid dataset path: {dataset_dir}\")\n\n    # Check if the dataset_dir is a directory.\n    if not os.path.isdir(dataset_dir):\n        raise RuntimeError(\n            f\"The dataset must be a directory. But got {dataset_dir}\"\n        )\n\n    # 0. 
Check if the dataset is already preprocessed.\n    processed_dir_prefix = \"preprocessed\"\n    preprocess_metadata_path = os.path.join(\n        processed_dir_prefix, \"metadata.yaml\"\n    )\n    if os.path.exists(os.path.join(dataset_dir, preprocess_metadata_path)):\n        if force_preprocess is None:\n            with open(\n                os.path.join(dataset_dir, preprocess_metadata_path), \"r\"\n            ) as f:\n                preprocess_config = yaml.safe_load(f)\n            if (\n                preprocess_config.get(\"include_original_edge_id\", None)\n                == include_original_edge_id\n            ):\n                force_preprocess = check_dataset_change(\n                    dataset_dir, processed_dir_prefix\n                )\n            else:\n                force_preprocess = True\n        if force_preprocess:\n            shutil.rmtree(os.path.join(dataset_dir, processed_dir_prefix))\n            print(\n                \"The on-disk dataset is re-preprocessing, so the existing \"\n                + \"preprocessed dataset has been removed.\"\n            )\n        else:\n            print(\"The dataset is already preprocessed.\")\n            return os.path.join(dataset_dir, preprocess_metadata_path)\n\n    print(\"Start to preprocess the on-disk dataset.\")\n\n    # Check if the metadata.yaml exists.\n    metadata_file_path = os.path.join(dataset_dir, \"metadata.yaml\")\n    if not os.path.exists(metadata_file_path):\n        raise RuntimeError(\"metadata.yaml does not exist.\")\n\n    # Read the input config.\n    with open(metadata_file_path, \"r\") as f:\n        input_config = yaml.safe_load(f)\n\n    # 1. Make `processed_dir_abs` directory if it does not exist.\n    os.makedirs(os.path.join(dataset_dir, processed_dir_prefix), exist_ok=True)\n    output_config = deepcopy(input_config)\n\n    # 2. 
Load the data and create a FusedCSCSamplingGraph.\n    if \"graph\" not in input_config:\n        raise RuntimeError(\"Invalid config: does not contain graph field.\")\n\n    sampling_graph = _graph_data_to_fused_csc_sampling_graph(\n        dataset_dir,\n        input_config[\"graph\"],\n        include_original_edge_id,\n        auto_cast_to_optimal_dtype,\n    )\n\n    # 3. Record value of include_original_edge_id.\n    output_config[\"include_original_edge_id\"] = include_original_edge_id\n\n    # 4. Save the FusedCSCSamplingGraph and modify the output_config.\n    output_config[\"graph_topology\"] = {}\n    output_config[\"graph_topology\"][\"type\"] = \"FusedCSCSamplingGraph\"\n    output_config[\"graph_topology\"][\"path\"] = os.path.join(\n        processed_dir_prefix, \"fused_csc_sampling_graph.pt\"\n    )\n\n    node_ids_within_int32 = (\n        sampling_graph.indices.dtype == torch.int32\n        and auto_cast_to_optimal_dtype\n    )\n    torch.save(\n        sampling_graph,\n        os.path.join(\n            dataset_dir,\n            output_config[\"graph_topology\"][\"path\"],\n        ),\n    )\n    del sampling_graph\n    del output_config[\"graph\"]\n\n    # 5. 
Load the node/edge features and do necessary conversion.\n    if input_config.get(\"feature_data\", None):\n        has_edge_feature_data = False\n        for feature, out_feature in zip(\n            input_config[\"feature_data\"], output_config[\"feature_data\"]\n        ):\n            # Always save the feature in numpy format.\n            out_feature[\"format\"] = \"numpy\"\n            out_feature[\"path\"] = os.path.join(\n                processed_dir_prefix, feature[\"path\"].replace(\"pt\", \"npy\")\n            )\n            in_memory = (\n                True if \"in_memory\" not in feature else feature[\"in_memory\"]\n            )\n            if not has_edge_feature_data and feature[\"domain\"] == \"edge\":\n                has_edge_feature_data = True\n            copy_or_convert_data(\n                os.path.join(dataset_dir, feature[\"path\"]),\n                os.path.join(dataset_dir, out_feature[\"path\"]),\n                feature[\"format\"],\n                output_format=out_feature[\"format\"],\n                in_memory=in_memory,\n                is_feature=True,\n            )\n        if has_edge_feature_data and not include_original_edge_id:\n            gb_warning(\"Edge feature is stored, but edge IDs are not saved.\")\n\n    # 6. 
Save tasks and train/val/test split according to the output_config.\n    if input_config.get(\"tasks\", None):\n        for input_task, output_task in zip(\n            input_config[\"tasks\"], output_config[\"tasks\"]\n        ):\n            for set_name in [\"train_set\", \"validation_set\", \"test_set\"]:\n                if set_name not in input_task:\n                    continue\n                for input_set_per_type, output_set_per_type in zip(\n                    input_task[set_name], output_task[set_name]\n                ):\n                    for input_data, output_data in zip(\n                        input_set_per_type[\"data\"], output_set_per_type[\"data\"]\n                    ):\n                        # Always save the feature in numpy format.\n                        output_data[\"format\"] = \"numpy\"\n                        output_data[\"path\"] = os.path.join(\n                            processed_dir_prefix,\n                            input_data[\"path\"].replace(\"pt\", \"npy\"),\n                        )\n                        name = (\n                            input_data[\"name\"] if \"name\" in input_data else None\n                        )\n                        copy_or_convert_data(\n                            os.path.join(dataset_dir, input_data[\"path\"]),\n                            os.path.join(dataset_dir, output_data[\"path\"]),\n                            input_data[\"format\"],\n                            output_data[\"format\"],\n                            within_int32=node_ids_within_int32\n                            and name in NAMES_INDICATING_NODE_IDS,\n                        )\n\n    # 7. Save the output_config.\n    output_config_path = os.path.join(dataset_dir, preprocess_metadata_path)\n    with open(output_config_path, \"w\") as f:\n        yaml.dump(output_config, f)\n    print(\"Finish preprocessing the on-disk dataset.\")\n\n    # 8. 
Calculate and save the hash value of the dataset directory.\n    hash_value_file = \"dataset_hash_value.txt\"\n    hash_value_file_path = os.path.join(\n        dataset_dir, processed_dir_prefix, hash_value_file\n    )\n    if os.path.exists(hash_value_file_path):\n        os.remove(hash_value_file_path)\n    dir_hash = calculate_dir_hash(dataset_dir)\n    with open(hash_value_file_path, \"w\") as f:\n        f.write(json.dumps(dir_hash, indent=4))\n\n    # 9. Return the absolute path of the preprocessing yaml file.\n    return output_config_path\n\n\nclass OnDiskTask:\n    \"\"\"An on-disk task.\n\n    An on-disk task is for ``OnDiskDataset``. It contains the metadata and the\n    train/val/test sets.\n    \"\"\"\n\n    def __init__(\n        self,\n        metadata: Dict,\n        train_set: Union[ItemSet, HeteroItemSet],\n        validation_set: Union[ItemSet, HeteroItemSet],\n        test_set: Union[ItemSet, HeteroItemSet],\n    ):\n        \"\"\"Initialize a task.\n\n        Parameters\n        ----------\n        metadata : Dict\n            Metadata.\n        train_set : Union[ItemSet, HeteroItemSet]\n            Training set.\n        validation_set : Union[ItemSet, HeteroItemSet]\n            Validation set.\n        test_set : Union[ItemSet, HeteroItemSet]\n            Test set.\n        \"\"\"\n        self._metadata = metadata\n        self._train_set = train_set\n        self._validation_set = validation_set\n        self._test_set = test_set\n\n    @property\n    def metadata(self) -> Dict:\n        \"\"\"Return the task metadata.\"\"\"\n        return self._metadata\n\n    @property\n    def train_set(self) -> Union[ItemSet, HeteroItemSet]:\n        \"\"\"Return the training set.\"\"\"\n        return self._train_set\n\n    @property\n    def validation_set(self) -> Union[ItemSet, HeteroItemSet]:\n        \"\"\"Return the validation set.\"\"\"\n        return self._validation_set\n\n    @property\n    def test_set(self) -> Union[ItemSet, 
HeteroItemSet]:\n        \"\"\"Return the test set.\"\"\"\n        return self._test_set\n\n    def __repr__(self) -> str:\n        ret = \"{Classname}({attributes})\"\n\n        attributes_str = \"\"\n\n        attributes = get_attributes(self)\n        attributes.reverse()\n        for attribute in attributes:\n            if attribute[0] == \"_\":\n                continue\n            value = getattr(self, attribute)\n            attributes_str += f\"{attribute}={value},\\n\"\n        attributes_str = textwrap.indent(\n            attributes_str, \" \" * len(\"OnDiskTask(\")\n        ).strip()\n\n        return ret.format(\n            Classname=self.__class__.__name__, attributes=attributes_str\n        )\n\n\nclass OnDiskDataset(Dataset):\n    \"\"\"An on-disk dataset which reads graph topology, feature data and\n    Train/Validation/Test set from disk.\n\n    Due to limited resources, the data which are too large to fit into RAM will\n    remain on disk while others reside in RAM once ``OnDiskDataset`` is\n    initialized. This behavior could be controlled by user via ``in_memory``\n    field in YAML file. All paths in YAML file are relative paths to the\n    dataset directory.\n\n    A full example of YAML file is as follows:\n\n    .. 
code-block:: yaml\n\n        dataset_name: graphbolt_test\n        graph:\n          nodes:\n            - type: paper # could be omitted for homogeneous graph.\n              num: 1000\n            - type: author\n              num: 1000\n          edges:\n            - type: author:writes:paper # could be omitted for homogeneous graph.\n              format: csv # Can be csv only.\n              path: edge_data/author-writes-paper.csv\n            - type: paper:cites:paper\n              format: csv\n              path: edge_data/paper-cites-paper.csv\n        feature_data:\n          - domain: node\n            type: paper # could be omitted for homogeneous graph.\n            name: feat\n            format: numpy\n            in_memory: false # If not specified, default to true.\n            path: node_data/paper-feat.npy\n          - domain: edge\n            type: \"author:writes:paper\"\n            name: feat\n            format: numpy\n            in_memory: false\n            path: edge_data/author-writes-paper-feat.npy\n        tasks:\n          - name: \"edge_classification\"\n            num_classes: 10\n            train_set:\n              - type: paper # could be omitted for homogeneous graph.\n                data: # multiple data sources could be specified.\n                  - name: seeds\n                    format: numpy # Can be numpy or torch.\n                    in_memory: true # If not specified, default to true.\n                    path: set/paper-train-seeds.npy\n                  - name: labels\n                    format: numpy\n                    path: set/paper-train-labels.npy\n            validation_set:\n              - type: paper\n                data:\n                  - name: seeds\n                    format: numpy\n                    path: set/paper-validation-seeds.npy\n                  - name: labels\n                    format: numpy\n                    path: set/paper-validation-labels.npy\n            test_set:\n  
            - type: paper\n                data:\n                  - name: seeds\n                    format: numpy\n                    path: set/paper-test-seeds.npy\n                  - name: labels\n                    format: numpy\n                    path: set/paper-test-labels.npy\n\n    Parameters\n    ----------\n    path: str\n        The YAML file path.\n    include_original_edge_id: bool, optional\n        Whether to include the original edge id in the FusedCSCSamplingGraph.\n    force_preprocess: bool, optional\n        Whether to force reload the ondisk dataset.\n    auto_cast_to_optimal_dtype: bool, optional\n        Casts the dtypes of tensors in the dataset into smallest possible dtypes\n        for reduced storage requirements and potentially increased performance.\n        Default is True.\n    \"\"\"\n\n    def __init__(\n        self,\n        path: str,\n        include_original_edge_id: bool = False,\n        force_preprocess: bool = None,\n        auto_cast_to_optimal_dtype: bool = True,\n    ) -> None:\n        # Always call the preprocess function first. 
If already preprocessed,\n        # the function will return the original path directly.\n        self._dataset_dir = path\n        yaml_path = preprocess_ondisk_dataset(\n            path,\n            include_original_edge_id,\n            force_preprocess,\n            auto_cast_to_optimal_dtype,\n        )\n        with open(yaml_path) as f:\n            self._yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)\n        self._loaded = False\n\n    def _convert_yaml_path_to_absolute_path(self):\n        \"\"\"Convert the path in YAML file to absolute path.\"\"\"\n        if \"graph_topology\" in self._yaml_data:\n            self._yaml_data[\"graph_topology\"][\"path\"] = os.path.join(\n                self._dataset_dir, self._yaml_data[\"graph_topology\"][\"path\"]\n            )\n        if \"feature_data\" in self._yaml_data:\n            for feature in self._yaml_data[\"feature_data\"]:\n                feature[\"path\"] = os.path.join(\n                    self._dataset_dir, feature[\"path\"]\n                )\n        if \"tasks\" in self._yaml_data:\n            for task in self._yaml_data[\"tasks\"]:\n                for set_name in [\"train_set\", \"validation_set\", \"test_set\"]:\n                    if set_name not in task:\n                        continue\n                    for set_per_type in task[set_name]:\n                        for data in set_per_type[\"data\"]:\n                            data[\"path\"] = os.path.join(\n                                self._dataset_dir, data[\"path\"]\n                            )\n\n    def load(self, tasks: List[str] = None):\n        \"\"\"Load the dataset.\n\n        Parameters\n        ----------\n        tasks: List[str] = None\n            The name of the tasks to be loaded. For single task, the type of\n            tasks can be both string and List[str]. For multiple tasks, only\n            List[str] is acceptable.\n\n        Examples\n        --------\n        1. 
Loading via single task name \"node_classification\".\n\n        >>> dataset = gb.OnDiskDataset(base_dir).load(\n        ...     tasks=\"node_classification\")\n        >>> len(dataset.tasks)\n        1\n        >>> dataset.tasks[0].metadata[\"name\"]\n        \"node_classification\"\n\n        2. Loading via single task name [\"node_classification\"].\n\n        >>> dataset = gb.OnDiskDataset(base_dir).load(\n        ...     tasks=[\"node_classification\"])\n        >>> len(dataset.tasks)\n        1\n        >>> dataset.tasks[0].metadata[\"name\"]\n        \"node_classification\"\n\n        3. Loading via multiple task names [\"node_classification\",\n        \"link_prediction\"].\n\n        >>> dataset = gb.OnDiskDataset(base_dir).load(\n        ...     tasks=[\"node_classification\",\"link_prediction\"])\n        >>> len(dataset.tasks)\n        2\n        >>> dataset.tasks[0].metadata[\"name\"]\n        \"node_classification\"\n        >>> dataset.tasks[1].metadata[\"name\"]\n        \"link_prediction\"\n        \"\"\"\n        self._convert_yaml_path_to_absolute_path()\n        self._meta = OnDiskMetaData(**self._yaml_data)\n        self._dataset_name = self._meta.dataset_name\n        self._graph = self._load_graph(self._meta.graph_topology)\n        self._feature = TorchBasedFeatureStore(self._meta.feature_data)\n        self._tasks = self._init_tasks(self._meta.tasks, tasks)\n        self._all_nodes_set = self._init_all_nodes_set(self._graph)\n        self._loaded = True\n        return self\n\n    @property\n    def yaml_data(self) -> Dict:\n        \"\"\"Return the YAML data.\"\"\"\n        return self._yaml_data\n\n    @property\n    def tasks(self) -> List[Task]:\n        \"\"\"Return the tasks.\"\"\"\n        self._check_loaded()\n        return self._tasks\n\n    @property\n    def graph(self) -> SamplingGraph:\n        \"\"\"Return the graph.\"\"\"\n        self._check_loaded()\n        return self._graph\n\n    @property\n    def feature(self) -> 
TorchBasedFeatureStore:\n        \"\"\"Return the feature.\"\"\"\n        self._check_loaded()\n        return self._feature\n\n    @property\n    def dataset_name(self) -> str:\n        \"\"\"Return the dataset name.\"\"\"\n        self._check_loaded()\n        return self._dataset_name\n\n    @property\n    def all_nodes_set(self) -> Union[ItemSet, HeteroItemSet]:\n        \"\"\"Return the itemset containing all nodes.\"\"\"\n        self._check_loaded()\n        return self._all_nodes_set\n\n    def _init_tasks(\n        self, tasks: List[OnDiskTaskData], selected_tasks: List[str]\n    ) -> List[OnDiskTask]:\n        \"\"\"Initialize the tasks.\"\"\"\n        if isinstance(selected_tasks, str):\n            selected_tasks = [selected_tasks]\n        if selected_tasks and not isinstance(selected_tasks, list):\n            raise TypeError(\n                f\"The type of selected_task should be list, but got {type(selected_tasks)}\"\n            )\n        ret = []\n        if tasks is None:\n            return ret\n        task_names = set()\n        for task in tasks:\n            task_name = task.extra_fields.get(\"name\", None)\n            if selected_tasks is None or task_name in selected_tasks:\n                ret.append(\n                    OnDiskTask(\n                        task.extra_fields,\n                        self._init_tvt_set(task.train_set),\n                        self._init_tvt_set(task.validation_set),\n                        self._init_tvt_set(task.test_set),\n                    )\n                )\n                if selected_tasks:\n                    task_names.add(task_name)\n        if selected_tasks:\n            not_found_tasks = set(selected_tasks) - task_names\n            if len(not_found_tasks):\n                gb_warning(\n                    f\"Below tasks are not found in YAML: {not_found_tasks}. 
Skipped.\"\n                )\n        return ret\n\n    def _check_loaded(self):\n        assert self._loaded, (\n            \"Please ensure that you have called the OnDiskDataset.load() method\"\n            + \" to properly load the data.\"\n        )\n\n    def _load_graph(\n        self, graph_topology: OnDiskGraphTopology\n    ) -> FusedCSCSamplingGraph:\n        \"\"\"Load the graph topology.\"\"\"\n        if graph_topology is None:\n            return None\n        if graph_topology.type == \"FusedCSCSamplingGraph\":\n            return torch.load(graph_topology.path, weights_only=False)\n        raise NotImplementedError(\n            f\"Graph topology type {graph_topology.type} is not supported.\"\n        )\n\n    def _init_tvt_set(\n        self, tvt_set: List[OnDiskTVTSet]\n    ) -> Union[ItemSet, HeteroItemSet]:\n        \"\"\"Initialize the TVT set.\"\"\"\n        ret = None\n        if (tvt_set is None) or (len(tvt_set) == 0):\n            return ret\n        if tvt_set[0].type is None:\n            assert (\n                len(tvt_set) == 1\n            ), \"Only one TVT set is allowed if type is not specified.\"\n            ret = ItemSet(\n                tuple(\n                    read_data(data.path, data.format, data.in_memory)\n                    for data in tvt_set[0].data\n                ),\n                names=tuple(data.name for data in tvt_set[0].data),\n            )\n        else:\n            itemsets = {}\n            for tvt in tvt_set:\n                itemsets[tvt.type] = ItemSet(\n                    tuple(\n                        read_data(data.path, data.format, data.in_memory)\n                        for data in tvt.data\n                    ),\n                    names=tuple(data.name for data in tvt.data),\n                )\n            ret = HeteroItemSet(itemsets)\n        return ret\n\n    def _init_all_nodes_set(self, graph) -> Union[ItemSet, HeteroItemSet]:\n        if graph is None:\n            
gb_warning(\n                \"`all_nodes_set` is returned as None, since graph is None.\"\n            )\n            return None\n        num_nodes = graph.num_nodes\n        dtype = graph.indices.dtype\n        if isinstance(num_nodes, int):\n            return ItemSet(\n                torch.tensor(num_nodes, dtype=dtype),\n                names=\"seeds\",\n            )\n        else:\n            data = {\n                node_type: ItemSet(\n                    torch.tensor(num_node, dtype=dtype),\n                    names=\"seeds\",\n                )\n                for node_type, num_node in num_nodes.items()\n            }\n            return HeteroItemSet(data)\n\n\nclass BuiltinDataset(OnDiskDataset):\n    \"\"\"A utility class to download built-in dataset from AWS S3 and load it as\n    :class:`OnDiskDataset`.\n\n    Available built-in datasets include:\n\n    **cora**\n        The cora dataset is a homogeneous citation network dataset, which is\n        designed for the node classification task.\n\n    **ogbn-mag**\n        The ogbn-mag dataset is a heterogeneous network composed of a subset of\n        the Microsoft Academic Graph (MAG). See more details in\n        `ogbn-mag <https://ogb.stanford.edu/docs/nodeprop/#ogbn-mag>`_.\n\n        .. note::\n            Reverse edges are added to the original graph and duplicated\n            edges are removed.\n\n    **ogbl-citation2**\n        The ogbl-citation2 dataset is a directed graph, representing the\n        citation network between a subset of papers extracted from MAG. See\n        more details in `ogbl-citation2\n        <https://ogb.stanford.edu/docs/linkprop/#ogbl-citation2>`_.\n\n        .. 
note::\n            Reverse edges are added to the original graph and duplicated\n            edges are removed.\n\n    **ogbn-arxiv**\n        The ogbn-arxiv dataset is a directed graph, representing the citation\n        network between all Computer Science (CS) arXiv papers indexed by MAG.\n        See more details in `ogbn-arxiv\n        <https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv>`_.\n\n        .. note::\n            Reverse edges are added to the original graph and duplicated\n            edges are removed.\n\n    **ogbn-papers100M**\n        The ogbn-papers100M dataset is a directed graph, representing the citation\n        network between all Computer Science (CS) arXiv papers indexed by MAG.\n        See more details in `ogbn-papers100M\n        <https://ogb.stanford.edu/docs/nodeprop/#ogbn-papers100M>`_.\n\n        .. note::\n            Reverse edges are added to the original graph and duplicated\n            edges are removed.\n\n    **ogbn-products**\n        The ogbn-products dataset is an undirected and unweighted graph,\n        representing an Amazon product co-purchasing network. See more details\n        in `ogbn-products\n        <https://ogb.stanford.edu/docs/nodeprop/#ogbn-products>`_.\n\n        .. note::\n            Reverse edges are added to the original graph.\n            Node features are stored as float32.\n\n    **ogb-lsc-mag240m**\n        The ogb-lsc-mag240m dataset is a heterogeneous academic graph extracted\n        from the Microsoft Academic Graph (MAG). See more details in\n        `ogb-lsc-mag240m <https://ogb.stanford.edu/docs/lsc/mag240m/>`_.\n\n        .. note::\n            Reverse edges are added to the original graph.\n\n    **igb-hom and igb-hom-[tiny|small|medium|large]**\n        The igb-hom-[tiny|small|medium|large] and igb-hom dataset is a homogeneous\n        citation network, which is designed for developers to train and evaluate\n        GNN models with high fidelity. 
See more details in\n        `igb-hom-[tiny|small|medium|large]\n        <https://github.com/IllinoisGraphBenchmark/IGB-Datasets>`_.\n\n        .. note::\n            Self edges are added to the original graph.\n            Node features are stored as float32.\n\n    **igb-het-[tiny|small|medium]**\n        The igb-het-[tiny|small|medium] dataset is a heterogeneous citation network,\n        which is designed for developers to train and evaluate GNN models with\n        high fidelity. See more details in `igb-het-[tiny|small|medium]\n        <https://github.com/IllinoisGraphBenchmark/IGB-Datasets>`_.\n\n        .. note::\n            Four Reverse edge types are added to the original graph.\n            Node features are stored as float32.\n\n    Parameters\n    ----------\n    name : str\n        The name of the builtin dataset.\n    root : str, optional\n        The root directory of the dataset. Default to ``datasets``.\n    \"\"\"\n\n    # For dataset that is smaller than 30GB, we use the base url.\n    # Otherwise, we use the accelerated url.\n    _base_url = \"https://data.dgl.ai/dataset/graphbolt/\"\n    _accelerated_url = (\n        \"https://dgl-data.s3-accelerate.amazonaws.com/dataset/graphbolt/\"\n    )\n    _datasets = [\n        \"cora\",\n        \"cora-seeds\",\n        \"ogbn-mag\",\n        \"ogbn-mag-seeds\",\n        \"ogbl-citation2\",\n        \"ogbl-citation2-seeds\",\n        \"ogbn-products\",\n        \"ogbn-products-seeds\",\n        \"ogbn-arxiv\",\n        \"ogbn-arxiv-seeds\",\n        \"igb-hom-tiny\",\n        \"igb-hom-tiny-seeds\",\n        \"igb-hom-small\",\n        \"igb-hom-small-seeds\",\n        \"igb-het-tiny\",\n        \"igb-het-tiny-seeds\",\n        \"igb-het-small\",\n        \"igb-het-small-seeds\",\n    ]\n    _large_datasets = [\n        \"ogb-lsc-mag240m\",\n        \"ogb-lsc-mag240m-seeds\",\n        \"ogbn-papers100M\",\n        \"ogbn-papers100M-seeds\",\n        \"igb-hom-medium\",\n        
\"igb-hom-medium-seeds\",\n        \"igb-hom-large\",\n        \"igb-hom-large-seeds\",\n        \"igb-hom\",\n        \"igb-hom-seeds\",\n        \"igb-het-medium\",\n        \"igb-het-medium-seeds\",\n    ]\n    _all_datasets = _datasets + _large_datasets\n\n    def __init__(self, name: str, root: str = \"datasets\") -> OnDiskDataset:\n        # For user using DGL 2.2 or later version, we prefer them to use\n        # datasets with `seeds` suffix. This hack should be removed, when the\n        # datasets with `seeds` suffix have covered previous ones.\n        if \"seeds\" not in name:\n            name += \"-seeds\"\n        dataset_dir = os.path.join(root, name)\n        if not os.path.exists(dataset_dir):\n            if name not in self._all_datasets:\n                raise RuntimeError(\n                    f\"Dataset {name} is not available. Available datasets are \"\n                    f\"{self._all_datasets}.\"\n                )\n            url = (\n                self._accelerated_url\n                if name in self._large_datasets\n                else self._base_url\n            )\n            url += name + \".zip\"\n            os.makedirs(root, exist_ok=True)\n            zip_file_path = os.path.join(root, name + \".zip\")\n            download(url, path=zip_file_path)\n            extract_archive(zip_file_path, root, overwrite=True)\n            os.remove(zip_file_path)\n        super().__init__(dataset_dir, force_preprocess=False)\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/ondisk_metadata.py",
    "content": "\"\"\"Ondisk metadata of GraphBolt.\"\"\"\n\nfrom enum import Enum\nfrom typing import Any, Dict, List, Optional\n\nimport pydantic\n\nfrom ..internal_utils import version\n\n\n__all__ = [\n    \"OnDiskFeatureDataFormat\",\n    \"OnDiskTVTSetData\",\n    \"OnDiskTVTSet\",\n    \"OnDiskFeatureDataDomain\",\n    \"OnDiskFeatureData\",\n    \"OnDiskMetaData\",\n    \"OnDiskGraphTopologyType\",\n    \"OnDiskGraphTopology\",\n    \"OnDiskTaskData\",\n]\n\n\nclass ExtraMetaData(pydantic.BaseModel, extra=\"allow\"):\n    \"\"\"Group extra fields into metadata. Internal use only.\"\"\"\n\n    extra_fields: Optional[Dict[str, Any]] = {}\n\n    # As pydantic 2.0 has changed the API of validators, we need to use\n    # different validators for different versions to be compatible with\n    # previous versions.\n    if version.parse(pydantic.__version__) >= version.parse(\"2.0\"):\n\n        @pydantic.model_validator(mode=\"before\")\n        @classmethod\n        def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:\n            \"\"\"Build extra fields.\"\"\"\n            for key in list(values.keys()):\n                if key not in cls.model_fields:\n                    values[\"extra_fields\"] = values.get(\"extra_fields\", {})\n                    values[\"extra_fields\"][key] = values.pop(key)\n            return values\n\n    else:\n\n        @pydantic.root_validator(pre=True)\n        @classmethod\n        def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:\n            \"\"\"Build extra fields.\"\"\"\n            for key in list(values.keys()):\n                if key not in cls.__fields__:\n                    values[\"extra_fields\"] = values.get(\"extra_fields\", {})\n                    values[\"extra_fields\"][key] = values.pop(key)\n            return values\n\n\nclass OnDiskFeatureDataFormat(str, Enum):\n    \"\"\"Enum of data format.\"\"\"\n\n    TORCH = \"torch\"\n    NUMPY = \"numpy\"\n\n\nclass 
OnDiskTVTSetData(pydantic.BaseModel):\n    \"\"\"Train-Validation-Test set data.\"\"\"\n\n    name: Optional[str] = None\n    format: OnDiskFeatureDataFormat\n    in_memory: Optional[bool] = True\n    path: str\n\n\nclass OnDiskTVTSet(pydantic.BaseModel):\n    \"\"\"Train-Validation-Test set.\"\"\"\n\n    type: Optional[str] = None\n    data: List[OnDiskTVTSetData]\n\n\nclass OnDiskFeatureDataDomain(str, Enum):\n    \"\"\"Enum of feature data domain.\"\"\"\n\n    NODE = \"node\"\n    EDGE = \"edge\"\n    GRAPH = \"graph\"\n\n\nclass OnDiskFeatureData(ExtraMetaData):\n    r\"\"\"The description of an on-disk feature.\"\"\"\n    domain: OnDiskFeatureDataDomain\n    type: Optional[str] = None\n    name: str\n    format: OnDiskFeatureDataFormat\n    path: str\n    in_memory: Optional[bool] = True\n\n\nclass OnDiskGraphTopologyType(str, Enum):\n    \"\"\"Enum of graph topology type.\"\"\"\n\n    FUSED_CSC_SAMPLING = \"FusedCSCSamplingGraph\"\n\n\nclass OnDiskGraphTopology(pydantic.BaseModel):\n    \"\"\"The description of an on-disk graph topology.\"\"\"\n\n    type: OnDiskGraphTopologyType\n    path: str\n\n\nclass OnDiskTaskData(ExtraMetaData):\n    \"\"\"Task specification in YAML.\"\"\"\n\n    train_set: Optional[List[OnDiskTVTSet]] = []\n    validation_set: Optional[List[OnDiskTVTSet]] = []\n    test_set: Optional[List[OnDiskTVTSet]] = []\n\n\nclass OnDiskMetaData(pydantic.BaseModel):\n    \"\"\"Metadata specification in YAML.\n\n    As multiple node/edge types and multiple splits are supported, each TVT set\n    is a list of list of ``OnDiskTVTSet``.\n    \"\"\"\n\n    dataset_name: Optional[str] = None\n    graph_topology: Optional[OnDiskGraphTopology] = None\n    feature_data: Optional[List[OnDiskFeatureData]] = []\n    tasks: Optional[List[OnDiskTaskData]] = []\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/sampled_subgraph_impl.py",
    "content": "\"\"\"Sampled subgraph for FusedCSCSamplingGraph.\"\"\"\n# pylint: disable= invalid-name\nfrom dataclasses import dataclass\nfrom typing import Dict, Union\n\nimport torch\n\nfrom ..base import CSCFormatBase, etype_str_to_tuple\nfrom ..internal_utils import get_attributes\nfrom ..sampled_subgraph import SampledSubgraph\n\n__all__ = [\"SampledSubgraphImpl\"]\n\n\n@dataclass\nclass SampledSubgraphImpl(SampledSubgraph):\n    r\"\"\"Sampled subgraph of CSCSamplingGraph.\n\n    Examples\n    --------\n    >>> sampled_csc = {\"A:relation:B\": CSCFormatBase(indptr=torch.tensor([0, 1, 2, 3]),\n    ... indices=torch.tensor([0, 1, 2]))}\n    >>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}\n    >>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}\n    >>> original_edge_ids = {\"A:relation:B\": torch.tensor([19, 20, 21])}\n    >>> subgraph = gb.SampledSubgraphImpl(\n    ... sampled_csc=sampled_csc,\n    ... original_column_node_ids=original_column_node_ids,\n    ... original_row_node_ids=original_row_node_ids,\n    ... original_edge_ids=original_edge_ids\n    ... )\n    >>> print(subgraph.sampled_csc)\n    {\"A:relation:B\": CSCForamtBase(indptr=torch.tensor([0, 1, 2, 3]),\n    ... 
indices=torch.tensor([0, 1, 2]))}\n    >>> print(subgraph.original_column_node_ids)\n    {'B': tensor([10, 11, 12])}\n    >>> print(subgraph.original_row_node_ids)\n    {'A': tensor([13, 14, 15])}\n    >>> print(subgraph.original_edge_ids)\n    {\"A:relation:B\": tensor([19, 20, 21])}\n    \"\"\"\n    sampled_csc: Union[CSCFormatBase, Dict[str, CSCFormatBase]] = None\n    original_column_node_ids: Union[\n        Dict[str, torch.Tensor], torch.Tensor\n    ] = None\n    original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None\n    original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None\n    # Used to fetch sampled_csc.indices if it is missing.\n    _edge_ids_in_fused_csc_sampling_graph: Union[\n        Dict[str, torch.Tensor], torch.Tensor\n    ] = None\n\n    def __post_init__(self):\n        if isinstance(self.sampled_csc, dict):\n            for etype, pair in self.sampled_csc.items():\n                assert (\n                    isinstance(etype, str)\n                    and len(etype_str_to_tuple(etype)) == 3\n                ), \"Edge type should be a string in format of str:str:str.\"\n                assert pair.indptr is not None and isinstance(\n                    pair.indptr, torch.Tensor\n                ), \"Node pair should be have indptr of type torch.Tensor.\"\n                # For CUDA, indices may be None because it will be fetched later.\n                if not pair.indptr.is_cuda or pair.indices is not None:\n                    assert isinstance(\n                        pair.indices, torch.Tensor\n                    ), \"Node pair should be have indices of type torch.Tensor.\"\n                else:\n                    assert isinstance(\n                        self._edge_ids_in_fused_csc_sampling_graph.get(\n                            etype, None\n                        ),\n                        torch.Tensor,\n                    ), \"When indices is missing, sampled edge ids needs to be 
provided.\"\n        else:\n            assert self.sampled_csc.indptr is not None and isinstance(\n                self.sampled_csc.indptr, torch.Tensor\n            ), \"Node pair should be have torch.Tensor indptr.\"\n            # For CUDA, indices may be None because it will be fetched later.\n            if (\n                not self.sampled_csc.indptr.is_cuda\n                or self.sampled_csc.indices is not None\n            ):\n                assert isinstance(\n                    self.sampled_csc.indices, torch.Tensor\n                ), \"Node pair should have a torch.Tensor indices.\"\n            else:\n                assert isinstance(\n                    self._edge_ids_in_fused_csc_sampling_graph, torch.Tensor\n                ), \"When indices is missing, sampled edge ids needs to be provided.\"\n\n    def __repr__(self) -> str:\n        return _sampled_subgraph_str(self, \"SampledSubgraphImpl\")\n\n\ndef _sampled_subgraph_str(sampled_subgraph: SampledSubgraph, classname) -> str:\n    final_str = classname + \"(\"\n\n    attributes = get_attributes(sampled_subgraph)\n    attributes.reverse()\n\n    for name in attributes:\n        if name in \"_edge_ids_in_fused_csc_sampling_graph\":\n            continue\n        val = getattr(sampled_subgraph, name)\n\n        def _add_indent(_str, indent):\n            lines = _str.split(\"\\n\")\n            lines = [lines[0]] + [\" \" * indent + line for line in lines[1:]]\n            return \"\\n\".join(lines)\n\n        val = str(val)\n        final_str = (\n            final_str\n            + f\"{name}={_add_indent(val, len(name) + len(classname) + 1)},\\n\"\n            + \" \" * len(classname)\n        )\n    return final_str[: -len(classname)] + \")\"\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/temporal_neighbor_sampler.py",
    "content": "\"\"\"Temporal neighbor subgraph samplers for GraphBolt.\"\"\"\nimport torch\nfrom torch.utils.data import functional_datapipe\n\nfrom ..internal import compact_csc_format\n\nfrom ..subgraph_sampler import SubgraphSampler\nfrom .sampled_subgraph_impl import SampledSubgraphImpl\n\n\n__all__ = [\"TemporalNeighborSampler\", \"TemporalLayerNeighborSampler\"]\n\n\nclass TemporalNeighborSamplerImpl(SubgraphSampler):\n    \"\"\"Base class for TemporalNeighborSamplers.\"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        graph,\n        fanouts,\n        replace,\n        prob_name,\n        node_timestamp_attr_name,\n        edge_timestamp_attr_name,\n        sampler,\n    ):\n        super().__init__(datapipe)\n        self.graph = graph\n        # Convert fanouts to a list of tensors.\n        self.fanouts = []\n        for fanout in fanouts:\n            if not isinstance(fanout, torch.Tensor):\n                fanout = torch.LongTensor([int(fanout)])\n            self.fanouts.insert(0, fanout)\n        self.replace = replace\n        self.prob_name = prob_name\n        self.node_timestamp_attr_name = node_timestamp_attr_name\n        self.edge_timestamp_attr_name = edge_timestamp_attr_name\n        self.sampler = sampler\n\n    def sample_subgraphs(\n        self, seeds, seeds_timestamp, seeds_pre_time_window=None\n    ):\n        assert (\n            seeds_timestamp is not None\n        ), \"seeds_timestamp must be provided for temporal neighbor sampling.\"\n        subgraphs = []\n        num_layers = len(self.fanouts)\n        # Enrich seeds with all node types. 
Ensure that the dtype and device\n        # remain consistent with those of the existing seeds.\n        if isinstance(seeds, dict):\n            first_val = next(iter(seeds.items()))[1]\n            ntypes = list(self.graph.node_type_to_id.keys())\n            seeds = {\n                ntype: seeds.get(\n                    ntype,\n                    torch.tensor(\n                        [], dtype=first_val.dtype, device=first_val.device\n                    ),\n                )\n                for ntype in ntypes\n            }\n            empty_tensor = torch.tensor(\n                [], dtype=torch.int64, device=first_val.device\n            )\n            seeds_timestamp = {\n                ntype: seeds_timestamp.get(ntype, empty_tensor)\n                for ntype in ntypes\n            }\n            if seeds_pre_time_window:\n                seeds_pre_time_window = {\n                    ntype: seeds_pre_time_window.get(ntype, empty_tensor)\n                    for ntype in ntypes\n                }\n        for hop in range(num_layers):\n            subgraph = self.sampler(\n                seeds,\n                seeds_timestamp,\n                self.fanouts[hop],\n                self.replace,\n                seeds_pre_time_window,\n                self.prob_name,\n                self.node_timestamp_attr_name,\n                self.edge_timestamp_attr_name,\n            )\n            (\n                original_row_node_ids,\n                compacted_csc_formats,\n                row_timestamps,\n            ) = compact_csc_format(subgraph.sampled_csc, seeds, seeds_timestamp)\n\n            subgraph = SampledSubgraphImpl(\n                sampled_csc=compacted_csc_formats,\n                original_column_node_ids=seeds,\n                original_row_node_ids=original_row_node_ids,\n                original_edge_ids=subgraph.original_edge_ids,\n            )\n\n            subgraphs.insert(0, subgraph)\n            seeds = 
original_row_node_ids\n            seeds_timestamp = row_timestamps\n        return seeds, subgraphs\n\n\n@functional_datapipe(\"temporal_sample_neighbor\")\nclass TemporalNeighborSampler(TemporalNeighborSamplerImpl):\n    \"\"\"Temporally sample neighbor edges from a graph and return sampled\n    subgraphs.\n\n    Functional name: :obj:`temporal_sample_neighbor`.\n\n    Neighbor sampler is responsible for sampling a subgraph from given data. It\n    returns an induced subgraph along with compacted information. In the\n    context of a node classification task, the neighbor sampler directly\n    utilizes the nodes provided as seed nodes. However, in scenarios involving\n    link prediction, the process needs another pre-peocess operation. That is,\n    gathering unique nodes from the given node pairs, encompassing both\n    positive and negative node pairs, and employs these nodes as the seed nodes\n    for subsequent steps.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The datapipe.\n    graph : FusedCSCSamplingGraph\n        The graph on which to perform subgraph sampling.\n    fanouts: list[torch.Tensor] or list[int]\n        The number of edges to be sampled for each node with or without\n        considering edge types. The length of this parameter implicitly\n        signifies the layer of sampling being conducted.\n        Note: The fanout order is from the outermost layer to innermost layer.\n        For example, the fanout '[15, 10, 5]' means that 15 to the outermost\n        layer, 10 to the intermediate layer and 5 corresponds to the innermost\n        layer.\n    replace: bool\n        Boolean indicating whether the sample is preformed with or\n        without replacement. If True, a value can be selected multiple\n        times. Otherwise, each value can be selected only once.\n    prob_name: str, optional\n        The name of an edge attribute used as the weights of sampling for\n        each node. 
This attribute tensor should contain (unnormalized)\n        probabilities corresponding to each neighboring edge of a node.\n        It must be a 1D floating-point or boolean tensor, with the number\n        of elements equalling the total number of edges.\n    node_timestamp_attr_name: str, optional\n        The name of an node attribute used as the timestamps of nodes.\n        It must be a 1D integer tensor, with the number of elements\n        equalling the total number of nodes.\n    edge_timestamp_attr_name: str, optional\n        The name of an edge attribute used as the timestamps of edges.\n        It must be a 1D integer tensor, with the number of elements\n        equalling the total number of edges.\n\n    Examples\n    -------\n    TODO(zhenkun) : Add an example after the API to pass timestamps is finalized.\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        graph,\n        fanouts,\n        replace=False,\n        prob_name=None,\n        node_timestamp_attr_name=None,\n        edge_timestamp_attr_name=None,\n    ):\n        super().__init__(\n            datapipe,\n            graph,\n            fanouts,\n            replace,\n            prob_name,\n            node_timestamp_attr_name,\n            edge_timestamp_attr_name,\n            graph.temporal_sample_neighbors,\n        )\n\n\n@functional_datapipe(\"temporal_sample_layer_neighbor\")\nclass TemporalLayerNeighborSampler(TemporalNeighborSamplerImpl):\n    \"\"\"Temporally sample neighbor edges from a graph and return sampled\n    subgraphs.\n\n    Functional name: :obj:`temporal_sample_layer_neighbor`.\n\n    Sampler that builds computational dependency of node representations via\n    labor sampling for multilayer GNN from the NeurIPS 2023 paper\n    `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs\n    <https://proceedings.neurips.cc/paper_files/paper/2023/file/51f9036d5e7ae822da8f6d4adda1fb39-Paper-Conference.pdf>`__\n\n    Layer-Neighbor 
sampler is responsible for sampling a subgraph from given\n    data. It returns an induced subgraph along with compacted information. In\n    the context of a node classification task, the neighbor sampler directly\n    utilizes the nodes provided as seed nodes. However, in scenarios involving\n    link prediction, the process needs another pre-process operation. That is,\n    gathering unique nodes from the given node pairs, encompassing both\n    positive and negative node pairs, and employs these nodes as the seed nodes\n    for subsequent steps. When the graph is hetero, sampled subgraphs in\n    minibatch will contain every edge type even though it is empty after\n    sampling.\n\n    Implements the approach described in Appendix A.3 of the paper. Similar to\n    dgl.dataloading.LaborSampler but this uses sequential poisson sampling\n    instead of poisson sampling to keep the count of sampled edges per vertex\n    deterministic like NeighborSampler. Thus, it is a drop-in replacement for\n    NeighborSampler. However, unlike NeighborSampler, it samples fewer vertices\n    and edges for multilayer GNN scenario without harming convergence speed with\n    respect to training iterations.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The datapipe.\n    graph : FusedCSCSamplingGraph\n        The graph on which to perform subgraph sampling.\n    fanouts: list[torch.Tensor] or list[int]\n        The number of edges to be sampled for each node with or without\n        considering edge types. 
The length of this parameter implicitly\n        signifies the layer of sampling being conducted.\n        Note: The fanout order is from the outermost layer to innermost layer.\n        For example, the fanout '[15, 10, 5]' means that 15 to the outermost\n        layer, 10 to the intermediate layer and 5 corresponds to the innermost\n        layer.\n    replace: bool\n        Boolean indicating whether the sample is preformed with or\n        without replacement. If True, a value can be selected multiple\n        times. Otherwise, each value can be selected only once.\n    prob_name: str, optional\n        The name of an edge attribute used as the weights of sampling for\n        each node. This attribute tensor should contain (unnormalized)\n        probabilities corresponding to each neighboring edge of a node.\n        It must be a 1D floating-point or boolean tensor, with the number\n        of elements equalling the total number of edges.\n    node_timestamp_attr_name: str, optional\n        The name of an node attribute used as the timestamps of nodes.\n        It must be a 1D integer tensor, with the number of elements\n        equalling the total number of nodes.\n    edge_timestamp_attr_name: str, optional\n        The name of an edge attribute used as the timestamps of edges.\n        It must be a 1D integer tensor, with the number of elements\n        equalling the total number of edges.\n\n    Examples\n    -------\n    TODO(zhenkun) : Add an example after the API to pass timestamps is finalized.\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        graph,\n        fanouts,\n        replace=False,\n        prob_name=None,\n        node_timestamp_attr_name=None,\n        edge_timestamp_attr_name=None,\n    ):\n        super().__init__(\n            datapipe,\n            graph,\n            fanouts,\n            replace,\n            prob_name,\n            node_timestamp_attr_name,\n            edge_timestamp_attr_name,\n          
  graph.temporal_sample_layer_neighbors,\n        )\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/torch_based_feature_store.py",
    "content": "\"\"\"Torch-based feature store for GraphBolt.\"\"\"\n\nimport copy\nimport textwrap\nfrom typing import Dict, List\n\nimport numpy as np\nimport torch\n\nfrom ..base import (\n    get_device_to_host_uva_stream,\n    get_host_to_device_uva_stream,\n    index_select,\n)\nfrom ..feature_store import Feature\nfrom ..internal_utils import gb_warning, is_wsl\nfrom .basic_feature_store import BasicFeatureStore\nfrom .ondisk_metadata import OnDiskFeatureData\n\n__all__ = [\"TorchBasedFeature\", \"DiskBasedFeature\", \"TorchBasedFeatureStore\"]\n\n\nclass _Waiter:\n    def __init__(self, event, values):\n        self.event = event\n        self.values = values\n\n    def wait(self):\n        \"\"\"Returns the stored value when invoked.\"\"\"\n        self.event.wait()\n        values = self.values\n        # Ensure there is no memory leak.\n        self.event = self.values = None\n        return values\n\n\nclass TorchBasedFeature(Feature):\n    r\"\"\"A wrapper of pytorch based feature.\n\n    Initialize a torch based feature store by a torch feature.\n    Note that the feature can be either in memory or on disk.\n\n    Parameters\n    ----------\n    torch_feature : torch.Tensor\n        The torch feature.\n        Note that the dimension of the tensor should be greater than 1.\n\n    Examples\n    --------\n    >>> import torch\n    >>> from dgl import graphbolt as gb\n\n    1. The feature is in memory.\n\n    >>> torch_feat = torch.arange(10).reshape(2, -1)\n    >>> feature = gb.TorchBasedFeature(torch_feat)\n    >>> feature.read()\n    tensor([[0, 1, 2, 3, 4],\n            [5, 6, 7, 8, 9]])\n    >>> feature.read(torch.tensor([0]))\n    tensor([[0, 1, 2, 3, 4]])\n    >>> feature.update(torch.tensor([[1 for _ in range(5)]]),\n    ...                      torch.tensor([1]))\n    >>> feature.read(torch.tensor([0, 1]))\n    tensor([[0, 1, 2, 3, 4],\n            [1, 1, 1, 1, 1]])\n    >>> feature.size()\n    torch.Size([5])\n\n    2. The feature is on disk. 
Note that you can use gb.numpy_save_aligned as a\n    replacement for np.save to potentially get increased performance.\n\n    >>> import numpy as np\n    >>> arr = np.array([[1, 2], [3, 4]])\n    >>> np.save(\"/tmp/arr.npy\", arr)\n    >>> torch_feat = torch.from_numpy(np.load(\"/tmp/arr.npy\", mmap_mode=\"r+\"))\n    >>> feature = gb.TorchBasedFeature(torch_feat)\n    >>> feature.read()\n    tensor([[1, 2],\n            [3, 4]])\n    >>> feature.read(torch.tensor([0]))\n    tensor([[1, 2]])\n\n    3. Pinned CPU feature.\n\n    >>> torch_feat = torch.arange(10).reshape(2, -1).pin_memory()\n    >>> feature = gb.TorchBasedFeature(torch_feat)\n    >>> feature.read().device\n    device(type='cuda', index=0)\n    >>> feature.read(torch.tensor([0]).cuda()).device\n    device(type='cuda', index=0)\n    \"\"\"\n\n    def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None):\n        super().__init__()\n        self._is_inplace_pinned = set()\n        assert isinstance(torch_feature, torch.Tensor), (\n            f\"torch_feature in TorchBasedFeature must be torch.Tensor, \"\n            f\"but got {type(torch_feature)}.\"\n        )\n        assert torch_feature.dim() > 1, (\n            f\"dimension of torch_feature in TorchBasedFeature must be greater \"\n            f\"than 1, but got {torch_feature.dim()} dimension.\"\n        )\n        # Make sure the tensor is contiguous.\n        self._tensor = torch_feature.contiguous()\n        self._metadata = metadata\n\n    def __del__(self):\n        # torch.Tensor.pin_memory() is not an inplace operation. To make it\n        # truly in-place, we need to use cudaHostRegister. 
Then, we need to use\n        # cudaHostUnregister to unpin the tensor in the destructor.\n        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842\n        for tensor in self._is_inplace_pinned:\n            assert self._inplace_unpinner(tensor.data_ptr()) == 0\n\n    def read(self, ids: torch.Tensor = None):\n        \"\"\"Read the feature by index.\n\n        If the feature is on pinned CPU memory and `ids` is on GPU or pinned CPU\n        memory, it will be read by GPU and the returned tensor will be on GPU.\n        Otherwise, the returned tensor will be on CPU.\n\n        Parameters\n        ----------\n        ids : torch.Tensor, optional\n            The index of the feature. If specified, only the specified indices\n            of the feature are read. If None, the entire feature is returned.\n\n        Returns\n        -------\n        torch.Tensor\n            The read feature.\n        \"\"\"\n        if ids is None:\n            if self._tensor.is_pinned():\n                return self._tensor.cuda()\n            return self._tensor\n        return index_select(self._tensor, ids)\n\n    def read_async(self, ids: torch.Tensor):\n        r\"\"\"Read the feature by index asynchronously.\n\n        Parameters\n        ----------\n        ids : torch.Tensor\n            The index of the feature. Only the specified indices of the\n            feature are read.\n        Returns\n        -------\n        A generator object.\n            The returned generator object returns a future on\n            ``read_async_num_stages(ids.device)``\\ th invocation. The return result\n            can be accessed by calling ``.wait()``. 
on the returned future object.\n            It is undefined behavior to call ``.wait()`` more than once.\n\n        Examples\n        --------\n        >>> import dgl.graphbolt as gb\n        >>> feature = gb.Feature(...)\n        >>> ids = torch.tensor([0, 2])\n        >>> for stage, future in enumerate(feature.read_async(ids)):\n        ...     pass\n        >>> assert stage + 1 == feature.read_async_num_stages(ids.device)\n        >>> result = future.wait()  # result contains the read values.\n        \"\"\"\n        assert self._tensor.device.type == \"cpu\"\n        if ids.is_cuda and self.is_pinned():\n            current_stream = torch.cuda.current_stream()\n            host_to_device_stream = get_host_to_device_uva_stream()\n            host_to_device_stream.wait_stream(current_stream)\n            with torch.cuda.stream(host_to_device_stream):\n                ids.record_stream(torch.cuda.current_stream())\n                values = index_select(self._tensor, ids)\n                values.record_stream(current_stream)\n                values_copy_event = torch.cuda.Event()\n                values_copy_event.record()\n\n            yield _Waiter(values_copy_event, values)\n        elif ids.is_cuda:\n            ids_device = ids.device\n            current_stream = torch.cuda.current_stream()\n            device_to_host_stream = get_device_to_host_uva_stream()\n            device_to_host_stream.wait_stream(current_stream)\n            with torch.cuda.stream(device_to_host_stream):\n                ids.record_stream(torch.cuda.current_stream())\n                ids = ids.to(self._tensor.device, non_blocking=True)\n                ids_copy_event = torch.cuda.Event()\n                ids_copy_event.record()\n\n            yield  # first stage is done.\n\n            ids_copy_event.synchronize()\n            values = torch.ops.graphbolt.index_select_async(self._tensor, ids)\n            yield\n\n            host_to_device_stream = get_host_to_device_uva_stream()\n 
           with torch.cuda.stream(host_to_device_stream):\n                values_cuda = values.wait().to(ids_device, non_blocking=True)\n                values_cuda.record_stream(current_stream)\n                values_copy_event = torch.cuda.Event()\n                values_copy_event.record()\n\n            yield _Waiter(values_copy_event, values_cuda)\n        else:\n            yield torch.ops.graphbolt.index_select_async(self._tensor, ids)\n\n    def read_async_num_stages(self, ids_device: torch.device):\n        \"\"\"The number of stages of the read_async operation. See read_async\n        function for directions on its use. This function is required to return\n        the number of yield operations when read_async is used with a tensor\n        residing on ids_device.\n\n        Parameters\n        ----------\n        ids_device : torch.device\n            The device of the ids parameter passed into read_async.\n        Returns\n        -------\n        int\n            The number of stages of the read_async operation.\n        \"\"\"\n        if ids_device.type == \"cuda\":\n            if self._tensor.is_cuda:\n                # If the ids and the tensor are on cuda, no need for async.\n                return 0\n            return 1 if self.is_pinned() else 3\n        else:\n            return 1\n\n    def size(self):\n        \"\"\"Get the size of the feature.\n\n        Returns\n        -------\n        torch.Size\n            The size of the feature.\n        \"\"\"\n        return self._tensor.size()[1:]\n\n    def count(self):\n        \"\"\"Get the count of the feature.\n\n        Returns\n        -------\n        int\n            The count of the feature.\n        \"\"\"\n        return self._tensor.size()[0]\n\n    def update(self, value: torch.Tensor, ids: torch.Tensor = None):\n        \"\"\"Update the feature store.\n\n        Parameters\n        ----------\n        value : torch.Tensor\n            The updated value of the feature.\n        
ids : torch.Tensor, optional\n            The indices of the feature to update. If specified, only the\n            specified indices of the feature will be updated. For the feature,\n            the `ids[i]` row is updated to `value[i]`. So the indices and value\n            must have the same length. If None, the entire feature will be\n            updated.\n        \"\"\"\n        if ids is None:\n            self._tensor = value\n        else:\n            assert ids.shape[0] == value.shape[0], (\n                f\"ids and value must have the same length, \"\n                f\"but got {ids.shape[0]} and {value.shape[0]}.\"\n            )\n            assert self.size() == value.size()[1:], (\n                f\"The size of the feature is {self.size()}, \"\n                f\"while the size of the value is {value.size()[1:]}.\"\n            )\n            if self._tensor.is_pinned() and value.is_cuda and ids.is_cuda:\n                raise NotImplementedError(\n                    \"Update the feature on pinned CPU memory by GPU is not \"\n                    \"supported yet.\"\n                )\n            self._tensor[ids] = value\n\n    def metadata(self):\n        \"\"\"Get the metadata of the feature.\n\n        Returns\n        -------\n        Dict\n            The metadata of the feature.\n        \"\"\"\n        return (\n            self._metadata if self._metadata is not None else super().metadata()\n        )\n\n    def pin_memory_(self):\n        \"\"\"In-place operation to copy the feature to pinned memory. Returns the\n        same object modified in-place.\"\"\"\n        # torch.Tensor.pin_memory() is not an inplace operation. To make it\n        # truly in-place, we need to use cudaHostRegister. 
Then, we need to use\n        # cudaHostUnregister to unpin the tensor in the destructor.\n        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842\n        x = self._tensor\n        if not x.is_pinned() and x.device.type == \"cpu\":\n            assert (\n                x.is_contiguous()\n            ), \"Tensor pinning is only supported for contiguous tensors.\"\n            cudart = torch.cuda.cudart()\n            assert (\n                cudart.cudaHostRegister(\n                    x.data_ptr(), x.numel() * x.element_size(), 0\n                )\n                == 0\n            )\n\n            self._is_inplace_pinned.add(x)\n            self._inplace_unpinner = cudart.cudaHostUnregister\n\n        return self\n\n    def is_pinned(self):\n        \"\"\"Returns True if the stored feature is pinned.\"\"\"\n        return self._tensor.is_pinned()\n\n    def to(self, device):  # pylint: disable=invalid-name\n        \"\"\"Copy `TorchBasedFeature` to the specified device.\"\"\"\n        # copy.copy is a shallow copy so it does not copy tensor memory.\n        self2 = copy.copy(self)\n        if device == \"pinned\":\n            self2._tensor = self2._tensor.pin_memory()\n        else:\n            self2._tensor = self2._tensor.to(device)\n        return self2\n\n    def __repr__(self) -> str:\n        ret = (\n            \"{Classname}(\\n\"\n            \"    feature={feature},\\n\"\n            \"    metadata={metadata},\\n\"\n            \")\"\n        )\n\n        feature_str = textwrap.indent(\n            str(self._tensor), \" \" * len(\"    feature=\")\n        ).strip()\n        metadata_str = textwrap.indent(\n            str(self.metadata()), \" \" * len(\"    metadata=\")\n        ).strip()\n\n        return ret.format(\n            Classname=self.__class__.__name__,\n            feature=feature_str,\n            metadata=metadata_str,\n        )\n\n\nclass DiskBasedFeature(Feature):\n    r\"\"\"A wrapper of disk based 
feature.\n\n    Initialize a disk based feature fetcher by a numpy file. Note that you can\n    use gb.numpy_save_aligned as a replacement for np.save to potentially get\n    increased performance.\n\n    Parameters\n    ----------\n    path : string\n        The path to the numpy feature file.\n        Note that the dimension of the numpy should be greater than 1.\n    metadata : Dict\n        The metadata of the feature.\n    num_threads : int\n        The number of threads driving io_uring queues.\n    Examples\n    --------\n    >>> import torch\n    >>> from dgl import graphbolt as gb\n    >>> torch_feat = torch.arange(10).reshape(2, -1)\n    >>> pth = \"path/to/feat.npy\"\n    >>> np.save(pth, torch_feat)\n    >>> feature = gb.DiskBasedFeature(pth)\n    >>> feature.read(torch.tensor([0]))\n    tensor([[0, 1, 2, 3, 4]])\n    >>> feature.size()\n    torch.Size([5])\n    \"\"\"\n\n    def __init__(self, path: str, metadata: Dict = None, num_threads=None):\n        super().__init__()\n        mmap_mode = \"r+\"\n        ondisk_data = np.load(path, mmap_mode=mmap_mode)\n        assert ondisk_data.flags[\n            \"C_CONTIGUOUS\"\n        ], \"DiskBasedFeature only supports C_CONTIGUOUS array.\"\n        self._tensor = torch.from_numpy(ondisk_data)\n\n        self._metadata = metadata\n        if torch.ops.graphbolt.detect_io_uring():\n            self._ondisk_npy_array = torch.ops.graphbolt.ondisk_npy_array(\n                path, self._tensor.dtype, self._tensor.shape, num_threads\n            )\n\n    def read(self, ids: torch.Tensor = None):\n        \"\"\"Read the feature by index.\n        The returned tensor will be on CPU.\n        Parameters\n        ----------\n        ids : torch.Tensor\n            The index of the feature. 
Only the specified indices of the\n            feature are read.\n        Returns\n        -------\n        torch.Tensor\n            The read feature.\n        \"\"\"\n        if ids is None:\n            return self._tensor\n        elif torch.ops.graphbolt.detect_io_uring():\n            try:\n                return self._ondisk_npy_array.index_select(ids).wait()\n            except RuntimeError:\n                raise IndexError\n        else:\n            return index_select(self._tensor, ids)\n\n    def read_async(self, ids: torch.Tensor):\n        r\"\"\"Read the feature by index asynchronously.\n\n        Parameters\n        ----------\n        ids : torch.Tensor\n            The index of the feature. Only the specified indices of the\n            feature are read.\n        Returns\n        -------\n        A generator object.\n            The returned generator object returns a future on\n            ``read_async_num_stages(ids.device)``\\ th invocation. The return result\n            can be accessed by calling ``.wait()``. on the returned future object.\n            It is undefined behavior to call ``.wait()`` more than once.\n\n        Examples\n        --------\n        >>> import dgl.graphbolt as gb\n        >>> feature = gb.Feature(...)\n        >>> ids = torch.tensor([0, 2])\n        >>> for stage, future in enumerate(feature.read_async(ids)):\n        ...     
pass\n        >>> assert stage + 1 == feature.read_async_num_stages(ids.device)\n        >>> result = future.wait()  # result contains the read values.\n        \"\"\"\n        assert torch.ops.graphbolt.detect_io_uring()\n        if ids.is_cuda:\n            ids_device = ids.device\n            current_stream = torch.cuda.current_stream()\n            device_to_host_stream = get_device_to_host_uva_stream()\n            device_to_host_stream.wait_stream(current_stream)\n            with torch.cuda.stream(device_to_host_stream):\n                ids.record_stream(torch.cuda.current_stream())\n                ids = ids.to(self._tensor.device, non_blocking=True)\n                ids_copy_event = torch.cuda.Event()\n                ids_copy_event.record()\n\n            yield  # first stage is done.\n\n            ids_copy_event.synchronize()\n            values = self._ondisk_npy_array.index_select(ids)\n            yield\n\n            host_to_device_stream = get_host_to_device_uva_stream()\n            with torch.cuda.stream(host_to_device_stream):\n                values_cuda = values.wait().to(ids_device, non_blocking=True)\n                values_cuda.record_stream(current_stream)\n                values_copy_event = torch.cuda.Event()\n                values_copy_event.record()\n\n            yield _Waiter(values_copy_event, values_cuda)\n        else:\n            yield self._ondisk_npy_array.index_select(ids)\n\n    def read_async_num_stages(self, ids_device: torch.device):\n        \"\"\"The number of stages of the read_async operation. See read_async\n        function for directions on its use. 
This function is required to return\n        the number of yield operations when read_async is used with a tensor\n        residing on ids_device.\n\n        Parameters\n        ----------\n        ids_device : torch.device\n            The device of the ids parameter passed into read_async.\n        Returns\n        -------\n        int\n            The number of stages of the read_async operation.\n        \"\"\"\n        return 3 if ids_device.type == \"cuda\" else 1\n\n    def size(self):\n        \"\"\"Get the size of the feature.\n        Returns\n        -------\n        torch.Size\n            The size of the feature.\n        \"\"\"\n        return self._tensor.size()[1:]\n\n    def count(self):\n        \"\"\"Get the count of the feature.\n\n        Returns\n        -------\n        int\n            The count of the feature.\n        \"\"\"\n        return self._tensor.size()[0]\n\n    def update(self, value: torch.Tensor, ids: torch.Tensor = None):\n        \"\"\"Disk based feature does not support update for now.\"\"\"\n        raise NotImplementedError\n\n    def metadata(self):\n        \"\"\"Get the metadata of the feature.\n        Returns\n        -------\n        Dict\n            The metadata of the feature.\n        \"\"\"\n        return (\n            self._metadata if self._metadata is not None else super().metadata()\n        )\n\n    def read_into_memory(self) -> TorchBasedFeature:\n        \"\"\"Change disk-based feature to torch-based feature.\"\"\"\n        return TorchBasedFeature(self._tensor, self._metadata)\n\n    def to(self, _):  # pylint: disable=invalid-name\n        \"\"\"Placeholder `DiskBasedFeature` to implementation. It is a no-op.\"\"\"\n        gb_warning(\n            \"`DiskBasedFeature.to(device)` is not supported. Leaving unmodified.\"\n        )\n        return self\n\n    def pin_memory_(self):  # pylint: disable=invalid-name\n        r\"\"\"Placeholder `DiskBasedFeature` pin_memory_ implementation. 
It is a no-op.\"\"\"\n        gb_warning(\n            \"`DiskBasedFeature.pin_memory_()` is not supported. Leaving unmodified.\"\n        )\n        return self\n\n    def __repr__(self) -> str:\n        ret = (\n            \"{Classname}(\\n\"\n            \"    feature={feature},\\n\"\n            \"    metadata={metadata},\\n\"\n            \")\"\n        )\n\n        feature_str = textwrap.indent(\n            str(self._tensor), \" \" * len(\"    feature=\")\n        ).strip()\n        metadata_str = textwrap.indent(\n            str(self.metadata()), \" \" * len(\"    metadata=\")\n        ).strip()\n\n        return ret.format(\n            Classname=self.__class__.__name__,\n            feature=feature_str,\n            metadata=metadata_str,\n        )\n\n\nclass TorchBasedFeatureStore(BasicFeatureStore):\n    r\"\"\"A store to manage multiple pytorch based feature for access.\n\n    The feature stores are described by the `feat_data`. The `feat_data` is a\n    list of `OnDiskFeatureData`.\n\n    For a feature store, its format must be either \"pt\" or \"npy\" for Pytorch or\n    Numpy formats. If the format is \"pt\", the feature store must be loaded in\n    memory. If the format is \"npy\", the feature store can be loaded in memory or\n    on disk. Note that you can use gb.numpy_save_aligned as a replacement for\n    np.save to potentially get increased performance.\n\n    Parameters\n    ----------\n    feat_data : List[OnDiskFeatureData]\n        The description of the feature stores.\n\n    Examples\n    --------\n    >>> import torch\n    >>> import numpy as np\n    >>> from dgl import graphbolt as gb\n    >>> edge_label = torch.tensor([[1], [2], [3]])\n    >>> node_feat = torch.tensor([[1, 2, 3], [4, 5, 6]])\n    >>> torch.save(edge_label, \"/tmp/edge_label.pt\")\n    >>> gb.numpy_save_aligned(\"/tmp/node_feat.npy\", node_feat.numpy())\n    >>> feat_data = [\n    ...     gb.OnDiskFeatureData(domain=\"edge\", type=\"author:writes:paper\",\n    ...    
     name=\"label\", format=\"torch\", path=\"/tmp/edge_label.pt\",\n    ...         in_memory=True),\n    ...     gb.OnDiskFeatureData(domain=\"node\", type=\"paper\", name=\"feat\",\n    ...         format=\"numpy\", path=\"/tmp/node_feat.npy\", in_memory=False),\n    ... ]\n    >>> feature_store = gb.TorchBasedFeatureStore(feat_data)\n    \"\"\"\n\n    def __init__(self, feat_data: List[OnDiskFeatureData]):\n        features = {}\n        for spec in feat_data:\n            key = (spec.domain, spec.type, spec.name)\n            metadata = spec.extra_fields\n            if spec.format == \"torch\":\n                assert spec.in_memory, (\n                    f\"Pytorch tensor can only be loaded in memory, \"\n                    f\"but the feature {key} is loaded on disk.\"\n                )\n                features[key] = TorchBasedFeature(\n                    torch.load(spec.path, weights_only=False), metadata=metadata\n                )\n            elif spec.format == \"numpy\":\n                if spec.in_memory:\n                    # TorchBasedFeature is always in memory by default.\n                    features[key] = TorchBasedFeature(\n                        torch.as_tensor(np.load(spec.path)), metadata=metadata\n                    )\n                else:\n                    # DiskBasedFeature is always out of memory by default.\n                    features[key] = DiskBasedFeature(\n                        spec.path, metadata=metadata\n                    )\n            else:\n                raise ValueError(f\"Unknown feature format {spec.format}\")\n        super().__init__(features)\n\n    def pin_memory_(self):\n        \"\"\"In-place operation to copy the feature store to pinned memory.\n        Returns the same object modified in-place.\"\"\"\n        if is_wsl():\n            gb_warning(\n                \"In place pinning is not supported on WSL. 
\"\n                \"Returning the out of place pinned `TorchBasedFeatureStore`.\"\n            )\n            return self.to(\"pinned\")\n        for feature in self._features.values():\n            feature.pin_memory_()\n\n        return self\n\n    def is_pinned(self):\n        \"\"\"Returns True if all the stored features are pinned.\"\"\"\n        return all(feature.is_pinned() for feature in self._features.values())\n\n    def to(self, device):  # pylint: disable=invalid-name\n        \"\"\"Copy `TorchBasedFeatureStore` to the specified device.\"\"\"\n        # copy.copy is a shallow copy so it does not copy tensor memory.\n        self2 = copy.copy(self)\n        self2._features = {k: v.to(device) for k, v in self2._features.items()}\n        return self2\n\n    def __repr__(self) -> str:\n        ret = \"{Classname}(\\n\" + \"    {features}\\n\" + \")\"\n        features_str = textwrap.indent(str(self._features), \"    \").strip()\n        return ret.format(\n            Classname=self.__class__.__name__, features=features_str\n        )\n"
  },
  {
    "path": "python/dgl/graphbolt/impl/uniform_negative_sampler.py",
    "content": "\"\"\"Uniform negative sampler for GraphBolt.\"\"\"\n\nimport torch\nfrom torch.utils.data import functional_datapipe\n\nfrom ..negative_sampler import NegativeSampler\n\n__all__ = [\"UniformNegativeSampler\"]\n\n\n@functional_datapipe(\"sample_uniform_negative\")\nclass UniformNegativeSampler(NegativeSampler):\n    \"\"\"Sample negative destination nodes for each source node based on a uniform\n    distribution.\n\n    Functional name: :obj:`sample_uniform_negative`.\n\n    It's important to note that the term 'negative' refers to false negatives,\n    indicating that the sampled pairs are not ensured to be absent in the graph.\n    For each edge ``(u, v)``, it is supposed to generate `negative_ratio` pairs\n    of negative edges ``(u, v')``, where ``v'`` is chosen uniformly from all\n    the nodes in the graph.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The datapipe.\n    graph : FusedCSCSamplingGraph\n        The graph on which to perform negative sampling.\n    negative_ratio : int\n        The proportion of negative samples to positive samples.\n\n    Examples\n    --------\n    >>> from dgl import graphbolt as gb\n    >>> indptr = torch.LongTensor([0, 1, 2, 3, 4])\n    >>> indices = torch.LongTensor([1, 2, 3, 0])\n    >>> graph = gb.fused_csc_sampling_graph(indptr, indices)\n    >>> seeds = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])\n    >>> item_set = gb.ItemSet(seeds, names=\"seeds\")\n    >>> item_sampler = gb.ItemSampler(\n    ...     item_set, batch_size=4,)\n    >>> neg_sampler = gb.UniformNegativeSampler(\n    ...     item_sampler, graph, 2)\n    >>> for minibatch in neg_sampler:\n    ...       print(minibatch.seeds)\n    ...       print(minibatch.labels)\n    ...       
print(minibatch.indexes)\n    tensor([[0, 1], [1, 2], [2, 3], [3, 0], [0, 1], [0, 3], [1, 1], [1, 2],\n        [2, 1], [2, 0], [3, 0], [3, 2]])\n    tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])\n    tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3])\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        graph,\n        negative_ratio,\n    ):\n        super().__init__(datapipe, negative_ratio)\n        self.graph = graph\n\n    def _sample_with_etype(self, seeds, etype=None):\n        assert seeds.ndim == 2 and seeds.shape[1] == 2, (\n            \"Only tensor with shape N*2 is supported for negative\"\n            + f\" sampling, but got {seeds.shape}.\"\n        )\n        # Sample negative edges, and concatenate positive edges with them.\n        all_seeds = self.graph.sample_negative_edges_uniform(\n            etype,\n            seeds,\n            self.negative_ratio,\n        )\n        # Construct indexes for all node pairs.\n        pos_num = seeds.shape[0]\n        negative_ratio = self.negative_ratio\n        pos_indexes = torch.arange(0, pos_num, device=all_seeds.device)\n        neg_indexes = pos_indexes.repeat_interleave(negative_ratio)\n        indexes = torch.cat((pos_indexes, neg_indexes))\n        # Construct labels for all node pairs.\n        neg_num = all_seeds.shape[0] - pos_num\n        labels = torch.empty(pos_num + neg_num, device=all_seeds.device)\n        labels[:pos_num] = 1\n        labels[pos_num:] = 0\n        return all_seeds, labels, indexes\n"
  },
  {
    "path": "python/dgl/graphbolt/internal/__init__.py",
    "content": "\"\"\"Utility functions for GraphBolt.\"\"\"\nfrom .utils import *\nfrom .sample_utils import *\nfrom .item_sampler_utils import *\n"
  },
  {
    "path": "python/dgl/graphbolt/internal/item_sampler_utils.py",
    "content": "\"\"\"Utility functions for DistributedItemSampler.\"\"\"\n\n\ndef count_split(total, num_workers, worker_id, batch_size=1):\n    \"\"\"Calculate the number of assigned items after splitting them by batch\n    size evenly. It will return the number for this worker and also a sum of\n    previous workers.\n    \"\"\"\n    quotient, remainder = divmod(total, num_workers * batch_size)\n    if batch_size == 1:\n        assigned = quotient + (worker_id < remainder)\n    else:\n        batch_count, last_batch = divmod(remainder, batch_size)\n        assigned = quotient * batch_size + (\n            batch_size\n            if worker_id < batch_count\n            else (last_batch if worker_id == batch_count else 0)\n        )\n    prefix_sum = quotient * worker_id * batch_size + min(\n        worker_id * batch_size, remainder\n    )\n    return (assigned, prefix_sum)\n\n\ndef calculate_range(\n    distributed,\n    total,\n    num_replicas,\n    rank,\n    num_workers,\n    worker_id,\n    batch_size,\n    drop_last,\n    drop_uneven_inputs,\n):\n    \"\"\"Calculates the range of items to be assigned to the current worker.\n\n    This function evenly distributes `total` items among multiple workers,\n    batching them using `batch_size`. Each replica has `num_workers` workers.\n    The batches generated by workers within the same replica are combined into\n    the replica`s output. The `drop_last` parameter determines whether\n    incomplete batches should be dropped. If `drop_last` is True, incomplete\n    batches are discarded. The `drop_uneven_inputs` parameter determines if the\n    number of batches assigned to each replica should be the same. 
If\n    `drop_uneven_inputs` is True, excessive batches for some replicas will be\n    dropped.\n\n    Args:\n        distributed (bool): Whether it's in distributed mode.\n        total (int): The total number of items.\n        num_replicas (int): The total number of replicas.\n        rank (int): The rank of the current replica.\n        num_workers (int): The number of workers per replica.\n        worker_id (int): The ID of the current worker.\n        batch_size (int): The desired batch size.\n        drop_last (bool): Whether to drop incomplete batches.\n        drop_uneven_inputs (bool): Whether to drop excessive batches for some\n          replicas.\n\n    Returns:\n        tuple: A tuple containing three numbers:\n            - start_offset (int): The starting offset of the range assigned to\n              the current worker.\n            - assigned_count (int): The length of the range assigned to the\n              current worker.\n            - output_count (int): The number of items that the current worker\n              will produce after dropping.\n    \"\"\"\n    # Check if it's distributed mode.\n    if not distributed:\n        if not drop_last:\n            return (0, total, total)\n        else:\n            return (0, total, total // batch_size * batch_size)\n    # First, equally distribute items into all replicas.\n    assigned_count, start_offset = count_split(\n        total, num_replicas, rank, batch_size\n    )\n    # Calculate the number of outputs when drop_uneven_inputs is True.\n    # `assigned_count` is the number of items distributed to the current\n    # process. 
`output_count` is the number of items that should be output\n    # by this process after dropping.\n    if not drop_uneven_inputs:\n        if not drop_last:\n            output_count = assigned_count\n        else:\n            output_count = assigned_count // batch_size * batch_size\n    else:\n        if not drop_last:\n            min_item_count, _ = count_split(\n                total, num_replicas, num_replicas - 1, batch_size\n            )\n            min_batch_count = (min_item_count + batch_size - 1) // batch_size\n            output_count = min(min_batch_count * batch_size, assigned_count)\n        else:\n            output_count = total // (batch_size * num_replicas) * batch_size\n    # If there are multiple workers, equally distribute the batches to\n    # all workers.\n    if num_workers > 1:\n        # Equally distribute the dropped number too.\n        dropped_items, prev_dropped_items = count_split(\n            assigned_count - output_count, num_workers, worker_id\n        )\n        output_count, prev_output_count = count_split(\n            output_count,\n            num_workers,\n            worker_id,\n            batch_size,\n        )\n        assigned_count = output_count + dropped_items\n        start_offset += prev_output_count + prev_dropped_items\n    return (start_offset, assigned_count, output_count)\n"
  },
  {
    "path": "python/dgl/graphbolt/internal/sample_utils.py",
    "content": "\"\"\"Utility functions for sampling.\"\"\"\n\nfrom collections import defaultdict\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch\n\nfrom ..base import CSCFormatBase, etype_str_to_tuple, expand_indptr\n\n\ndef unique_and_compact(\n    nodes: Union[\n        List[torch.Tensor],\n        Dict[str, List[torch.Tensor]],\n    ],\n    rank: int = 0,\n    world_size: int = 1,\n    async_op: bool = False,\n):\n    \"\"\"\n    Compact a list of nodes tensor. The `rank` and `world_size` parameters are\n    relevant when using Cooperative Minibatching, which was initially proposed\n    in `Deep Graph Library PR#4337<https://github.com/dmlc/dgl/pull/4337>`__ and\n    was later first fully described in\n    `Cooperative Minibatching in Graph Neural Networks\n    <https://arxiv.org/abs/2310.12403>`__.\n    Cooperation between the GPUs eliminates duplicate work performed across the\n    GPUs due to the overlapping sampled k-hop neighborhoods of seed nodes when\n    performing GNN minibatching.\n\n    When `world_size` is greater than 1, then the given ids are partitioned\n    between the available ranks. The ids corresponding to the given rank are\n    guaranteed to come before the ids of other ranks. To do this, the\n    partitioned ids are rotated backwards by the given rank so that the ids are\n    ordered as: `[rank, rank + 1, world_size, 0, ..., rank - 1]`. 
This is\n    supported only for Volta and later generation NVIDIA GPUs.\n\n    Parameters\n    ----------\n    nodes : List[torch.Tensor] or Dict[str, List[torch.Tensor]]\n        List of nodes for compacting.\n        The unique_and_compact will be done per type.\n        - If `nodes` is a list of tensors: All the tensors will do unique and\n        compact together, usually it is used for homogeneous graph.\n        - If `nodes` is a dictionary: The keys should be node type and\n        the values should be corresponding nodes, the unique and compact will\n        be done per type, usually it is used for heterogeneous graph.\n    rank : int\n        The rank of the current process.\n    world_size : int\n        The number of processes.\n    async_op: bool\n        Boolean indicating whether the call is asynchronous. If so, the result\n        can be obtained by calling wait on the returned future.\n\n    Returns\n    -------\n    Tuple[unique_nodes, compacted_node_list, unique_nodes_offsets]\n        The unique nodes (per type) of all nodes in the input. And the compacted\n        nodes list, where IDs inside are replaced with compacted node IDs.\n        \"Compacted node list\" indicates that the node IDs in the input node\n        list are replaced with mapped node IDs, where each type of node is\n        mapped to a contiguous space of IDs ranging from 0 to N.\n        The unique nodes offsets tensor partitions the unique_nodes tensor. 
Has\n        size `world_size + 1` and `unique_nodes[offsets[i]: offsets[i + 1]]`\n        belongs to the rank `(rank + i) % world_size`.\n    \"\"\"\n    is_heterogeneous = isinstance(nodes, dict)\n\n    if not is_heterogeneous:\n        homo_ntype = \"a\"\n        nodes = {homo_ntype: nodes}\n\n    nums = {}\n    concat_nodes, empties = [], []\n    for ntype, nodes_of_type in nodes.items():\n        nums[ntype] = [node.size(0) for node in nodes_of_type]\n        concat_nodes.append(torch.cat(nodes_of_type))\n        empties.append(concat_nodes[-1].new_empty(0))\n    unique_fn = (\n        torch.ops.graphbolt.unique_and_compact_batched_async\n        if async_op\n        else torch.ops.graphbolt.unique_and_compact_batched\n    )\n    results = unique_fn(concat_nodes, empties, empties, rank, world_size)\n\n    class _Waiter:\n        def __init__(self, future, ntypes, nums):\n            self.future = future\n            self.ntypes = ntypes\n            self.nums = nums\n\n        def wait(self):\n            \"\"\"Returns the stored value when invoked.\"\"\"\n            results = self.future.wait() if async_op else self.future\n            ntypes = self.ntypes\n            nums = self.nums\n            # Ensure there is no memory leak.\n            self.future = self.ntypes = self.nums = None\n\n            unique, compacted, offsets = {}, {}, {}\n            for ntype, result in zip(ntypes, results):\n                (\n                    unique[ntype],\n                    concat_compacted,\n                    _,\n                    offsets[ntype],\n                ) = result\n                compacted[ntype] = list(concat_compacted.split(nums[ntype]))\n            if is_heterogeneous:\n                return unique, compacted, offsets\n            else:\n                return (\n                    unique[homo_ntype],\n                    compacted[homo_ntype],\n                    offsets[homo_ntype],\n                )\n\n    post_processer = 
_Waiter(results, nodes.keys(), nums)\n    if async_op:\n        return post_processer\n    else:\n        return post_processer.wait()\n\n\ndef compact_temporal_nodes(nodes, nodes_timestamp):\n    \"\"\"Compact a list of temporal nodes without unique.\n\n    Note that since there is no unique, the nodes and nodes_timestamp are simply\n    concatenated. And the compacted nodes are consecutive numbers starting from\n    0.\n\n    Parameters\n    ----------\n    nodes : List[torch.Tensor] or Dict[str, List[torch.Tensor]]\n        List of nodes for compacting.\n        the compact operator will be done per type\n        - If `nodes` is a list of tensor: All the tensors will compact together,\n        usually it is used for homogeneous graph.\n        - If `nodes` is a list of dictionary: The keys should be node type and\n        the values should be corresponding nodes, the compact will be done per\n        type, usually it is used for heterogeneous graph.\n\n    nodes_timestamp : List[torch.Tensor] or Dict[str, List[torch.Tensor]]\n        List of timestamps for compacting.\n\n    Returns\n    -------\n    Tuple[nodes, nodes_timestamp, compacted_node_list]\n\n    The concatenated nodes and nodes_timestamp, and the compacted nodes list,\n    where IDs inside are replaced with compacted node IDs.\n    \"\"\"\n\n    def _compact_per_type(per_type_nodes, per_type_nodes_timestamp):\n        nums = [node.size(0) for node in per_type_nodes]\n        per_type_nodes = torch.cat(per_type_nodes)\n        per_type_nodes_timestamp = torch.cat(per_type_nodes_timestamp)\n        compacted_nodes = torch.arange(\n            0,\n            per_type_nodes.numel(),\n            dtype=per_type_nodes.dtype,\n            device=per_type_nodes.device,\n        )\n        compacted_nodes = list(compacted_nodes.split(nums))\n        return per_type_nodes, per_type_nodes_timestamp, compacted_nodes\n\n    if isinstance(nodes, dict):\n        ret_nodes, ret_timestamp, compacted = {}, {}, {}\n   
     for ntype, nodes_of_type in nodes.items():\n            (\n                ret_nodes[ntype],\n                ret_timestamp[ntype],\n                compacted[ntype],\n            ) = _compact_per_type(nodes_of_type, nodes_timestamp[ntype])\n        return ret_nodes, ret_timestamp, compacted\n    else:\n        return _compact_per_type(nodes, nodes_timestamp)\n\n\ndef unique_and_compact_csc_formats(\n    csc_formats: Union[\n        Tuple[torch.Tensor, torch.Tensor],\n        Dict[str, Tuple[torch.Tensor, torch.Tensor]],\n    ],\n    unique_dst_nodes: Union[\n        torch.Tensor,\n        Dict[str, torch.Tensor],\n    ],\n    rank: int = 0,\n    world_size: int = 1,\n    async_op: bool = False,\n):\n    \"\"\"\n    Compact csc formats and return unique nodes (per type). The `rank` and\n    `world_size` parameters are relevant when using Cooperative Minibatching,\n    which was initially proposed in\n    `Deep Graph Library PR#4337\n    <https://github.com/dmlc/dgl/pull/4337>`__\n    and was later first fully described in\n    `Cooperative Minibatching in Graph Neural Networks\n    <https://arxiv.org/abs/2310.12403>`__.\n    Cooperation between the GPUs eliminates duplicate work performed across the\n    GPUs due to the overlapping sampled k-hop neighborhoods of seed nodes when\n    performing GNN minibatching.\n\n    When `world_size` is greater than 1, then the given ids are partitioned\n    between the available ranks. The ids corresponding to the given rank are\n    guaranteed to come before the ids of other ranks. To do this, the\n    partitioned ids are rotated backwards by the given rank so that the ids are\n    ordered as:\n    `[rank, rank + 1, ..., world_size - 1, 0, ..., rank - 1]`. 
This is\n    supported only for Volta and later generation NVIDIA GPUs.\n\n    Parameters\n    ----------\n    csc_formats : Union[CSCFormatBase, Dict[str, CSCFormatBase]]\n        CSC formats representing source-destination edges.\n        - If `csc_formats` is a CSCFormatBase: It means the graph is\n        homogeneous. Also, indptr and indices in it should be torch.tensor\n        representing source and destination pairs in csc format. And IDs inside\n        are homogeneous ids.\n        - If `csc_formats` is a Dict[str, CSCFormatBase]: The keys\n        should be edge type and the values should be csc format node pairs.\n        And IDs inside are heterogeneous ids.\n    unique_dst_nodes: torch.Tensor or Dict[str, torch.Tensor]\n        Unique nodes of all destination nodes in the node pairs.\n        - If `unique_dst_nodes` is a tensor: It means the graph is homogeneous.\n        - If `unique_dst_nodes` is a dictionary: The keys are node type and the\n        values are corresponding nodes. And IDs inside are heterogeneous ids.\n    rank : int\n        The rank of the current process.\n    world_size : int\n        The number of processes.\n    async_op: bool\n        Boolean indicating whether the call is asynchronous. If so, the result\n        can be obtained by calling wait on the returned future.\n\n    Returns\n    -------\n    Tuple[unique_nodes, csc_formats, unique_nodes_offsets]\n        The compacted csc formats, where node IDs are replaced with mapped node\n        IDs, and the unique nodes (per type).\n        \"Compacted csc formats\" indicates that the node IDs in the input node\n        pairs are replaced with mapped node IDs, where each type of node is\n        mapped to a contiguous space of IDs ranging from 0 to N. The unique\n        nodes offsets tensor partitions the unique_nodes tensor. 
Has size\n        `world_size + 1` and `unique_nodes[offsets[i]: offsets[i + 1]]` belongs\n        to the rank `(rank + i) % world_size`.\n\n    Examples\n    --------\n    >>> import dgl.graphbolt as gb\n    >>> N1 = torch.LongTensor([1, 2, 2])\n    >>> N2 = torch.LongTensor([5, 5, 6])\n    >>> unique_dst = {\n    ...     \"n1\": torch.LongTensor([1, 2]),\n    ...     \"n2\": torch.LongTensor([5, 6])}\n    >>> csc_formats = {\n    ...     \"n1:e1:n2\": gb.CSCFormatBase(indptr=torch.tensor([0, 2, 3]),indices=N1),\n    ...     \"n2:e2:n1\": gb.CSCFormatBase(indptr=torch.tensor([0, 1, 3]),indices=N2)}\n    >>> unique_nodes, compacted_csc_formats, _ = gb.unique_and_compact_csc_formats(\n    ...     csc_formats, unique_dst\n    ... )\n    >>> print(unique_nodes)\n    {'n1': tensor([1, 2]), 'n2': tensor([5, 6])}\n    >>> print(compacted_csc_formats)\n    {\"n1:e1:n2\": CSCFormatBase(indptr=torch.tensor([0, 2, 3]),\n                               indices=torch.tensor([0, 1, 1])),\n     \"n2:e2:n1\": CSCFormatBase(indptr=torch.tensor([0, 1, 3]),\n                               indices=torch.Longtensor([0, 0, 1]))}\n    \"\"\"\n    is_homogeneous = not isinstance(csc_formats, dict)\n    if is_homogeneous:\n        csc_formats = {\"_N:_E:_N\": csc_formats}\n        if unique_dst_nodes is not None:\n            assert isinstance(\n                unique_dst_nodes, torch.Tensor\n            ), \"Edge type not supported in homogeneous graph.\"\n            unique_dst_nodes = {\"_N\": unique_dst_nodes}\n\n    # Collect all source and destination nodes for each node type.\n    indices = defaultdict(list)\n    device = None\n    for etype, csc_format in csc_formats.items():\n        if device is None:\n            device = csc_format.indices.device\n        src_type, _, dst_type = etype_str_to_tuple(etype)\n        assert len(unique_dst_nodes.get(dst_type, [])) + 1 == len(\n            csc_format.indptr\n        ), \"The seed nodes should correspond to indptr.\"\n        
indices[src_type].append(csc_format.indices)\n    indices = {ntype: torch.cat(nodes) for ntype, nodes in indices.items()}\n\n    ntypes = set(indices.keys())\n    dtype = list(indices.values())[0].dtype\n    default_tensor = torch.tensor([], dtype=dtype, device=device)\n    indice_list = []\n    unique_dst_list = []\n    for ntype in ntypes:\n        indice_list.append(indices.get(ntype, default_tensor))\n        unique_dst_list.append(unique_dst_nodes.get(ntype, default_tensor))\n    dst_list = [torch.tensor([], dtype=dtype, device=device)] * len(\n        unique_dst_list\n    )\n    uniq_fn = (\n        torch.ops.graphbolt.unique_and_compact_batched_async\n        if async_op\n        else torch.ops.graphbolt.unique_and_compact_batched\n    )\n    results = uniq_fn(indice_list, dst_list, unique_dst_list, rank, world_size)\n\n    class _Waiter:\n        def __init__(self, future, csc_formats):\n            self.future = future\n            self.csc_formats = csc_formats\n\n        def wait(self):\n            \"\"\"Returns the stored value when invoked.\"\"\"\n            results = self.future.wait() if async_op else self.future\n            csc_formats = self.csc_formats\n            # Ensure there is no memory leak.\n            self.future = self.csc_formats = None\n\n            unique_nodes = {}\n            compacted_indices = {}\n            offsets = {}\n            for i, ntype in enumerate(ntypes):\n                (\n                    unique_nodes[ntype],\n                    compacted_indices[ntype],\n                    _,\n                    offsets[ntype],\n                ) = results[i]\n\n            compacted_csc_formats = {}\n            # Map back with the same order.\n            for etype, csc_format in csc_formats.items():\n                num_elem = csc_format.indices.size(0)\n                src_type, _, _ = etype_str_to_tuple(etype)\n                indice = compacted_indices[src_type][:num_elem]\n                indptr = 
csc_format.indptr\n                compacted_csc_formats[etype] = CSCFormatBase(\n                    indptr=indptr, indices=indice\n                )\n                compacted_indices[src_type] = compacted_indices[src_type][\n                    num_elem:\n                ]\n\n            # Return singleton for a homogeneous graph.\n            if is_homogeneous:\n                compacted_csc_formats = list(compacted_csc_formats.values())[0]\n                unique_nodes = list(unique_nodes.values())[0]\n                offsets = list(offsets.values())[0]\n\n            return unique_nodes, compacted_csc_formats, offsets\n\n    post_processer = _Waiter(results, csc_formats)\n    if async_op:\n        return post_processer\n    else:\n        return post_processer.wait()\n\n\ndef _broadcast_timestamps(csc, dst_timestamps):\n    \"\"\"Broadcast the timestamp of each destination node to its corresponding\n    source nodes.\"\"\"\n    return expand_indptr(\n        csc.indptr, node_ids=dst_timestamps, output_size=len(csc.indices)\n    )\n\n\ndef compact_csc_format(\n    csc_formats: Union[CSCFormatBase, Dict[str, CSCFormatBase]],\n    dst_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],\n    dst_timestamps: Optional[\n        Union[torch.Tensor, Dict[str, torch.Tensor]]\n    ] = None,\n):\n    \"\"\"\n    Relabel the row (source) IDs in the csc formats into a contiguous range from\n    0 and return the original row node IDs per type.\n\n    Note that\n    1. The column (destination) IDs are included in the relabeled row IDs.\n    2. If there are repeated row IDs, they would not be uniqued and will be\n    treated as different nodes.\n    3. 
If `dst_timestamps` is given, the timestamp of each destination node will\n    be broadcasted to its corresponding source nodes.\n\n    Parameters\n    ----------\n    csc_formats: Union[CSCFormatBase, Dict[str, CSCFormatBase]]\n        CSC formats representing source-destination edges.\n        - If `csc_formats` is a CSCFormatBase: It means the graph is\n        homogeneous. Also, indptr and indice in it should be torch.tensor\n        representing source and destination pairs in csc format. And IDs inside\n        are homogeneous ids.\n        - If `csc_formats` is a Dict[str, CSCFormatBase]: The keys\n        should be edge type and the values should be csc format node pairs.\n        And IDs inside are heterogeneous ids.\n    dst_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]\n        Nodes of all destination nodes in the node pairs.\n        - If `dst_nodes` is a tensor: It means the graph is homogeneous.\n        - If `dst_nodes` is a dictionary: The keys are node type and the\n        values are corresponding nodes. And IDs inside are heterogeneous ids.\n\n    dst_timestamps: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]\n        Timestamps of all destination nodes in the csc formats.\n        If given, the timestamp of each destination node will be broadcasted\n        to its corresponding source nodes.\n\n    Returns\n    -------\n    Tuple[original_row_node_ids, compacted_csc_formats, ...]\n        A tensor of original row node IDs (per type) of all nodes in the input.\n        The compacted CSC formats, where node IDs are replaced with mapped node\n        IDs ranging from 0 to N.\n        The source timestamps (per type) of all nodes in the input if\n        `dst_timestamps` is given.\n\n    Examples\n    --------\n    >>> import dgl.graphbolt as gb\n    >>> csc_formats = {\n    ...     \"n2:e2:n1\": gb.CSCFormatBase(\n    ...         indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6])\n    ...     ),\n    ...     
\"n1:e1:n1\": gb.CSCFormatBase(\n    ...         indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3])\n    ...     ),\n    ... }\n    >>> dst_nodes = {\"n1\": torch.LongTensor([2, 4])}\n    >>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format(\n    ...     csc_formats, dst_nodes\n    ... )\n    >>> original_row_node_ids\n    {'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])}\n    >>> compacted_csc_formats\n    {'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),\n                indices=tensor([0, 1, 2]),\n    ), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),\n                indices=tensor([2, 3, 4]),\n    )}\n\n    >>> csc_formats = {\n    ...     \"n2:e2:n1\": gb.CSCFormatBase(\n    ...         indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6])\n    ...     ),\n    ...     \"n1:e1:n1\": gb.CSCFormatBase(\n    ...         indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3])\n    ...     ),\n    ... }\n    >>> dst_nodes = {\"n1\": torch.LongTensor([2, 4])}\n    >>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format(\n    ...     csc_formats, dst_nodes\n    ... )\n    >>> original_row_node_ids\n    {'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])}\n    >>> compacted_csc_formats\n    {'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),\n                indices=tensor([0, 1, 2]),\n    ), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),\n                indices=tensor([2, 3, 4]),\n    )}\n\n    >>> dst_timestamps = {\"n1\": torch.LongTensor([10, 20])}\n    >>> (\n    ...     original_row_node_ids,\n    ...     compacted_csc_formats,\n    ...     src_timestamps,\n    ... 
) = gb.compact_csc_format(csc_formats, dst_nodes, dst_timestamps)\n    >>> src_timestamps\n    {'n1': tensor([10, 20, 10, 20, 20]), 'n2': tensor([10, 20, 20])}\n    \"\"\"\n    is_homogeneous = not isinstance(csc_formats, dict)\n    has_timestamp = dst_timestamps is not None\n    if is_homogeneous:\n        if dst_nodes is not None:\n            assert isinstance(\n                dst_nodes, torch.Tensor\n            ), \"Edge type not supported in homogeneous graph.\"\n            assert len(dst_nodes) + 1 == len(\n                csc_formats.indptr\n            ), \"The seed nodes should correspond to indptr.\"\n        offset = dst_nodes.size(0)\n        original_row_ids = torch.cat((dst_nodes, csc_formats.indices))\n        compacted_csc_formats = CSCFormatBase(\n            indptr=csc_formats.indptr,\n            indices=(\n                torch.arange(\n                    0,\n                    csc_formats.indices.size(0),\n                    device=csc_formats.indices.device,\n                )\n                + offset\n            ),\n        )\n\n        src_timestamps = None\n        if has_timestamp:\n            src_timestamps = torch.cat(\n                [\n                    dst_timestamps,\n                    _broadcast_timestamps(\n                        compacted_csc_formats, dst_timestamps\n                    ),\n                ]\n            )\n    else:\n        compacted_csc_formats = {}\n        src_timestamps = None\n        original_row_ids = {key: val.clone() for key, val in dst_nodes.items()}\n        if has_timestamp:\n            src_timestamps = {\n                key: val.clone() for key, val in dst_timestamps.items()\n            }\n        for etype, csc_format in csc_formats.items():\n            src_type, _, dst_type = etype_str_to_tuple(etype)\n            assert len(dst_nodes.get(dst_type, [])) + 1 == len(\n                csc_format.indptr\n            ), \"The seed nodes should correspond to indptr.\"\n            
device = csc_format.indices.device\n            offset = original_row_ids.get(\n                src_type, torch.tensor([], device=device)\n            ).size(0)\n            original_row_ids[src_type] = torch.cat(\n                (\n                    original_row_ids.get(\n                        src_type,\n                        torch.tensor(\n                            [], dtype=csc_format.indices.dtype, device=device\n                        ),\n                    ),\n                    csc_format.indices,\n                )\n            )\n            compacted_csc_formats[etype] = CSCFormatBase(\n                indptr=csc_format.indptr,\n                indices=(\n                    torch.arange(\n                        0,\n                        csc_format.indices.size(0),\n                        dtype=csc_format.indices.dtype,\n                        device=device,\n                    )\n                    + offset\n                ),\n            )\n            if has_timestamp:\n                # If destination timestamps are given, broadcast them to the\n                # corresponding source nodes.\n                src_timestamps[src_type] = torch.cat(\n                    (\n                        src_timestamps.get(\n                            src_type,\n                            torch.tensor(\n                                [],\n                                dtype=dst_timestamps[dst_type].dtype,\n                                device=device,\n                            ),\n                        ),\n                        _broadcast_timestamps(\n                            csc_format, dst_timestamps[dst_type]\n                        ),\n                    )\n                )\n    if has_timestamp:\n        return original_row_ids, compacted_csc_formats, src_timestamps\n    return original_row_ids, compacted_csc_formats\n"
  },
  {
    "path": "python/dgl/graphbolt/internal/utils.py",
    "content": "\"\"\"Utility functions for GraphBolt.\"\"\"\n\nimport hashlib\nimport json\nimport os\nimport shutil\nfrom typing import List, Union\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom numpy.lib.format import read_array_header_1_0, read_array_header_2_0\n\n\ndef numpy_save_aligned(*args, **kwargs):\n    \"\"\"A wrapper for numpy.save(), ensures the array is stored 4KiB aligned.\"\"\"\n    # https://github.com/numpy/numpy/blob/2093a6d5b933f812d15a3de0eafeeb23c61f948a/numpy/lib/format.py#L179\n    has_array_align = hasattr(np.lib.format, \"ARRAY_ALIGN\")\n    if has_array_align:\n        default_alignment = np.lib.format.ARRAY_ALIGN\n        # The maximum allowed alignment by the numpy code linked above is 4K.\n        # Most filesystems work with block sizes of 4K so in practice, the file\n        # size on the disk won't be larger.\n        np.lib.format.ARRAY_ALIGN = 4096\n    np.save(*args, **kwargs)\n    if has_array_align:\n        np.lib.format.ARRAY_ALIGN = default_alignment\n\n\ndef _read_torch_data(path):\n    return torch.load(path, weights_only=False)\n\n\ndef _read_numpy_data(path, in_memory=True):\n    if in_memory:\n        return torch.from_numpy(np.load(path))\n    return torch.as_tensor(np.load(path, mmap_mode=\"r+\"))\n\n\ndef read_data(path, fmt, in_memory=True):\n    \"\"\"Read data from disk.\"\"\"\n    if fmt == \"torch\":\n        return _read_torch_data(path)\n    elif fmt == \"numpy\":\n        return _read_numpy_data(path, in_memory=in_memory)\n    else:\n        raise RuntimeError(f\"Unsupported format: {fmt}\")\n\n\ndef save_data(data, path, fmt):\n    \"\"\"Save data into disk.\"\"\"\n    # Make sure the directory exists.\n    os.makedirs(os.path.dirname(path), exist_ok=True)\n\n    if fmt not in [\"numpy\", \"torch\"]:\n        raise RuntimeError(f\"Unsupported format: {fmt}\")\n\n    # Perform necessary conversion.\n    if fmt == \"numpy\" and isinstance(data, torch.Tensor):\n        data = 
data.cpu().numpy()\n    elif fmt == \"torch\" and isinstance(data, np.ndarray):\n        data = torch.from_numpy(data).cpu()\n\n    # Save the data.\n    if fmt == \"numpy\":\n        if not data.flags[\"C_CONTIGUOUS\"]:\n            Warning(\n                \"The ndarray saved to disk is not contiguous, \"\n                \"so it will be copied to contiguous memory.\"\n            )\n            data = np.ascontiguousarray(data)\n        numpy_save_aligned(path, data)\n    elif fmt == \"torch\":\n        if not data.is_contiguous():\n            Warning(\n                \"The tensor saved to disk is not contiguous, \"\n                \"so it will be copied to contiguous memory.\"\n            )\n            data = data.contiguous()\n        torch.save(data, path)\n\n\ndef get_npy_dim(npy_path):\n    \"\"\"Get the dim of numpy file.\"\"\"\n    with open(npy_path, \"rb\") as f:\n        # For the read_array_header API provided by numpy will only read the\n        # length of the header, it will cause parsing failure and error if\n        # first 8 bytes which contains magin string and version are not read\n        # ahead of time. So, we need to make sure we have skipped these 8\n        # bytes.\n        f.seek(8, 0)\n        try:\n            shape, _, _ = read_array_header_1_0(f)\n        except ValueError:\n            try:\n                shape, _, _ = read_array_header_2_0(f)\n            except ValueError:\n                raise ValueError(\"Invalid file format\")\n\n        return len(shape)\n\n\ndef _to_int32(data):\n    if isinstance(data, torch.Tensor):\n        return data.to(torch.int32)\n    elif isinstance(data, np.ndarray):\n        return data.astype(np.int32)\n    else:\n        raise TypeError(\n            \"Unsupported input type. 
Please provide a torch tensor or numpy array.\"\n        )\n\n\ndef copy_or_convert_data(\n    input_path,\n    output_path,\n    input_format,\n    output_format=\"numpy\",\n    in_memory=True,\n    is_feature=False,\n    within_int32=False,\n):\n    \"\"\"Copy or convert the data from input_path to output_path.\"\"\"\n    assert (\n        output_format == \"numpy\"\n    ), \"The output format of the data should be numpy.\"\n    os.makedirs(os.path.dirname(output_path), exist_ok=True)\n    # We read the data always in case we need to cast its type.\n    data = read_data(input_path, input_format, in_memory)\n    if within_int32:\n        data = _to_int32(data)\n    if input_format == \"numpy\":\n        # If dim of the data is 1, reshape it to n * 1 and save it to output_path.\n        if is_feature and get_npy_dim(input_path) == 1:\n            data = data.reshape(-1, 1)\n        # If the data does not need to be modified, just copy the file.\n        elif not within_int32 and data.numpy().flags[\"C_CONTIGUOUS\"]:\n            shutil.copyfile(input_path, output_path)\n            return\n    else:\n        # If dim of the data is 1, reshape it to n * 1 and save it to output_path.\n        if is_feature and data.dim() == 1:\n            data = data.reshape(-1, 1)\n    save_data(data, output_path, output_format)\n\n\ndef read_edges(dataset_dir, edge_fmt, edge_path):\n    \"\"\"Read egde data from numpy or csv.\"\"\"\n    assert edge_fmt in [\n        \"numpy\",\n        \"csv\",\n    ], f\"`numpy` or `csv` is expected when reading edges but got `{edge_fmt}`.\"\n    if edge_fmt == \"numpy\":\n        edge_data = read_data(\n            os.path.join(dataset_dir, edge_path),\n            edge_fmt,\n        )\n        assert (\n            edge_data.shape[0] == 2 and len(edge_data.shape) == 2\n        ), f\"The shape of edges should be (2, N), but got {edge_data.shape}.\"\n        src, dst = edge_data.numpy()\n    else:\n        edge_data = pd.read_csv(\n            
os.path.join(dataset_dir, edge_path),\n            names=[\"src\", \"dst\"],\n        )\n        src, dst = edge_data[\"src\"].to_numpy(), edge_data[\"dst\"].to_numpy()\n    return (src, dst)\n\n\ndef calculate_file_hash(file_path, hash_algo=\"md5\"):\n    \"\"\"Calculate the hash value of a file.\"\"\"\n    hash_algos = [\"md5\", \"sha1\", \"sha224\", \"sha256\", \"sha384\", \"sha512\"]\n    if hash_algo in hash_algos:\n        hash_obj = getattr(hashlib, hash_algo)()\n    else:\n        raise ValueError(\n            f\"Hash algorithm must be one of: {hash_algos}, but got `{hash_algo}`.\"\n        )\n    with open(file_path, \"rb\") as file:\n        for chunk in iter(lambda: file.read(4096), b\"\"):\n            hash_obj.update(chunk)\n    return hash_obj.hexdigest()\n\n\ndef calculate_dir_hash(\n    dir_path, hash_algo=\"md5\", ignore: Union[str, List[str]] = None\n):\n    \"\"\"Calculte the hash values of all files under the directory.\"\"\"\n    hashes = {}\n    for dirpath, _, filenames in os.walk(dir_path):\n        for filename in filenames:\n            if ignore and filename in ignore:\n                continue\n            filepath = os.path.join(dirpath, filename)\n            file_hash = calculate_file_hash(filepath, hash_algo=hash_algo)\n            hashes[filepath] = file_hash\n    return hashes\n\n\ndef check_dataset_change(dataset_dir, processed_dir):\n    \"\"\"Check whether dataset has been changed by checking its hash value.\"\"\"\n    hash_value_file = \"dataset_hash_value.txt\"\n    hash_value_file_path = os.path.join(\n        dataset_dir, processed_dir, hash_value_file\n    )\n    if not os.path.exists(hash_value_file_path):\n        return True\n    with open(hash_value_file_path, \"r\") as f:\n        oringinal_hash_value = json.load(f)\n    present_hash_value = calculate_dir_hash(dataset_dir, ignore=hash_value_file)\n    if oringinal_hash_value == present_hash_value:\n        force_preprocess = False\n    else:\n        force_preprocess 
= True\n    return force_preprocess\n"
  },
  {
    "path": "python/dgl/graphbolt/internal_utils.py",
    "content": "\"\"\"Miscallenous internal utils.\"\"\"\nimport functools\nimport hashlib\nimport os\nimport platform\nimport warnings\nfrom collections.abc import Mapping, Sequence\n\nimport requests\nimport torch\nfrom tqdm.auto import tqdm\n\ntry:\n    from packaging import version  # pylint: disable=unused-import\nexcept ImportError:\n    # If packaging isn't installed, try and use the vendored copy in setuptools\n    from setuptools.extern.packaging import version\n\n\n@functools.lru_cache(maxsize=None)\ndef is_wsl(v: str = platform.uname().release) -> int:\n    \"\"\"Detects if Python is running in WSL\"\"\"\n\n    if v.endswith(\"-Microsoft\"):\n        return 1\n    elif v.endswith(\"microsoft-standard-WSL2\"):\n        return 2\n\n    return 0\n\n\n# pylint: disable=invalid-name\n_default_formatwarning = warnings.formatwarning\n\n\ndef built_with_cuda():\n    \"\"\"Returns whether GraphBolt was built with CUDA support.\"\"\"\n    # This op is defined if graphbolt is built with CUDA support.\n    return hasattr(torch.ops.graphbolt, \"set_max_uva_threads\")\n\n\nclass GBWarning(UserWarning):\n    \"\"\"GraphBolt Warning class.\"\"\"\n\n\n# pylint: disable=unused-argument\ndef gb_warning_format(message, category, filename, lineno, line=None):\n    \"\"\"Format GraphBolt warnings.\"\"\"\n    if isinstance(category, GBWarning):\n        return \"GraphBolt Warning: {}\\n\".format(message)\n    else:\n        return _default_formatwarning(\n            message, category, filename, lineno, line=None\n        )\n\n\ndef gb_warning(message, category=GBWarning, stacklevel=2):\n    \"\"\"GraphBolt warning wrapper that defaults to ``GBWarning`` instead of\n    ``UserWarning`` category.\n    \"\"\"\n    return warnings.warn(message, category=category, stacklevel=stacklevel)\n\n\nwarnings.formatwarning = gb_warning_format\n\n\ndef is_listlike(data):\n    \"\"\"Return if the data is a sequence but not a string.\"\"\"\n    return isinstance(data, Sequence) and not 
isinstance(data, str)\n\n\ndef recursive_apply(data, fn, *args, **kwargs):\n    \"\"\"Recursively apply a function to every element in a container.\n\n    If the input data is a list or any sequence other than a string, returns a list\n    whose elements are the same elements applied with the given function.\n\n    If the input data is a dict or any mapping, returns a dict whose keys are the same\n    and values are the elements applied with the given function.\n\n    If the input data is a nested container, the result will have the same nested\n    structure where each element is transformed recursively.\n\n    The first argument of the function will be passed with the individual elements from\n    the input data, followed by the arguments in :attr:`args` and :attr:`kwargs`.\n\n    Parameters\n    ----------\n    data : any\n        Any object.\n    fn : callable\n        Any function.\n    args, kwargs :\n        Additional arguments and keyword-arguments passed to the function.\n\n    Examples\n    --------\n    Applying a ReLU function to a dictionary of tensors:\n\n    >>> h = {k: torch.randn(3) for k in ['A', 'B', 'C']}\n    >>> h = recursive_apply(h, torch.nn.functional.relu)\n    >>> assert all((v >= 0).all() for v in h.values())\n    \"\"\"\n    if isinstance(data, Mapping):\n        return {\n            k: recursive_apply(v, fn, *args, **kwargs) for k, v in data.items()\n        }\n    elif isinstance(data, tuple):\n        return tuple(recursive_apply(v, fn, *args, **kwargs) for v in data)\n    elif is_listlike(data):\n        return [recursive_apply(v, fn, *args, **kwargs) for v in data]\n    else:\n        return fn(data, *args, **kwargs)\n\n\ndef recursive_apply_reduce_all(data, fn, *args, **kwargs):\n    \"\"\"Recursively apply a function to every element in a container and reduce\n    the boolean results with all.\n\n    If the input data is a list or any sequence other than a string, returns\n    True if and only if the given function returns True 
for all elements.\n\n    If the input data is a dict or any mapping, returns True if and only if the\n    given function returns True for values.\n\n    If the input data is a nested container, the result will be reduced over the\n    nested structure where each element is tested recursively.\n\n    The first argument of the function will be passed with the individual elements from\n    the input data, followed by the arguments in :attr:`args` and :attr:`kwargs`.\n\n    Parameters\n    ----------\n    data : any\n        Any object.\n    fn : callable\n        Any function returning a boolean.\n    args, kwargs :\n        Additional arguments and keyword-arguments passed to the function.\n    \"\"\"\n    if isinstance(data, Mapping):\n        return all(\n            recursive_apply_reduce_all(v, fn, *args, **kwargs)\n            for v in data.values()\n        )\n    elif isinstance(data, tuple) or is_listlike(data):\n        return all(\n            recursive_apply_reduce_all(v, fn, *args, **kwargs) for v in data\n        )\n    else:\n        return fn(data, *args, **kwargs)\n\n\ndef get_nonproperty_attributes(_obj) -> list:\n    \"\"\"Get attributes of the class except for the properties.\"\"\"\n    attributes = [\n        attribute\n        for attribute in dir(_obj)\n        if not attribute.startswith(\"__\")\n        and (\n            not hasattr(type(_obj), attribute)\n            or not isinstance(getattr(type(_obj), attribute), property)\n        )\n        and not callable(getattr(_obj, attribute))\n    ]\n    return attributes\n\n\ndef get_attributes(_obj) -> list:\n    \"\"\"Get attributes of the class.\"\"\"\n    attributes = [\n        attribute\n        for attribute in dir(_obj)\n        if not attribute.startswith(\"__\")\n        and not callable(getattr(_obj, attribute))\n    ]\n    return attributes\n\n\ndef download(\n    url,\n    path=None,\n    overwrite=True,\n    sha1_hash=None,\n    retries=5,\n    verify_ssl=True,\n    log=True,\n):\n 
   \"\"\"Download a given URL.\n\n    Codes borrowed from mxnet/gluon/utils.py\n\n    Parameters\n    ----------\n    url : str\n        URL to download.\n    path : str, optional\n        Destination path to store downloaded file. By default stores to the\n        current directory with the same name as in url.\n    overwrite : bool, optional\n        Whether to overwrite the destination file if it already exists.\n        By default always overwrites the downloaded file.\n    sha1_hash : str, optional\n        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified\n        but doesn't match.\n    retries : integer, default 5\n        The number of times to attempt downloading in case of failure or non 200 return codes.\n    verify_ssl : bool, default True\n        Verify SSL certificates.\n    log : bool, default True\n        Whether to print the progress for download\n\n    Returns\n    -------\n    str\n        The file path of the downloaded file.\n    \"\"\"\n    if path is None:\n        fname = url.split(\"/\")[-1]\n        # Empty filenames are invalid\n        assert fname, (\n            \"Can't construct file-name from this URL. \"\n            \"Please set the `path` option manually.\"\n        )\n    else:\n        path = os.path.expanduser(path)\n        if os.path.isdir(path):\n            fname = os.path.join(path, url.split(\"/\")[-1])\n        else:\n            fname = path\n    assert retries >= 0, \"Number of retries should be at least 0\"\n\n    if not verify_ssl:\n        warnings.warn(\n            \"Unverified HTTPS request is being made (verify_ssl=False). 
\"\n            \"Adding certificate verification is strongly advised.\"\n        )\n\n    if (\n        overwrite\n        or not os.path.exists(fname)\n        or (sha1_hash and not check_sha1(fname, sha1_hash))\n    ):\n        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))\n        if not os.path.exists(dirname):\n            os.makedirs(dirname)\n        while retries + 1 > 0:\n            # Disable pyling too broad Exception\n            # pylint: disable=W0703\n            try:\n                if log:\n                    print(\"Downloading %s from %s...\" % (fname, url))\n                r = requests.get(url, stream=True, verify=verify_ssl)\n                if r.status_code != 200:\n                    raise RuntimeError(\"Failed downloading url %s\" % url)\n                # Get the total file size.\n                total_size = int(r.headers.get(\"content-length\", 0))\n                with tqdm(\n                    total=total_size, unit=\"B\", unit_scale=True, desc=fname\n                ) as progress_bar:\n                    with open(fname, \"wb\") as f:\n                        for chunk in r.iter_content(chunk_size=1024):\n                            if chunk:  # filter out keep-alive new chunks\n                                f.write(chunk)\n                                progress_bar.update(len(chunk))\n                if sha1_hash and not check_sha1(fname, sha1_hash):\n                    raise UserWarning(\n                        \"File {} is downloaded but the content hash does not match.\"\n                        \" The repo may be outdated or download may be incomplete. 
\"\n                        'If the \"repo_url\" is overridden, consider switching to '\n                        \"the default repo.\".format(fname)\n                    )\n                break\n            except Exception as e:\n                retries -= 1\n                if retries <= 0:\n                    raise e\n                if log:\n                    print(\n                        \"download failed, retrying, {} attempt{} left\".format(\n                            retries, \"s\" if retries > 1 else \"\"\n                        )\n                    )\n\n    return fname\n\n\ndef check_sha1(filename, sha1_hash):\n    \"\"\"Check whether the sha1 hash of the file content matches the expected hash.\n\n    Codes borrowed from mxnet/gluon/utils.py\n\n    Parameters\n    ----------\n    filename : str\n        Path to the file.\n    sha1_hash : str\n        Expected sha1 hash in hexadecimal digits.\n\n    Returns\n    -------\n    bool\n        Whether the file content matches the expected hash.\n    \"\"\"\n    sha1 = hashlib.sha1()\n    with open(filename, \"rb\") as f:\n        while True:\n            data = f.read(1048576)\n            if not data:\n                break\n            sha1.update(data)\n\n    return sha1.hexdigest() == sha1_hash\n\n\ndef extract_archive(file, target_dir, overwrite=True):\n    \"\"\"Extract archive file.\n\n    Parameters\n    ----------\n    file : str\n        Absolute path of the archive file.\n    target_dir : str\n        Target directory of the archive to be uncompressed.\n    overwrite : bool, default True\n        Whether to overwrite the contents inside the directory.\n        By default always overwrites.\n    \"\"\"\n    if os.path.exists(target_dir) and not overwrite:\n        return\n    print(\"Extracting file to {}\".format(target_dir))\n    if (\n        file.endswith(\".tar.gz\")\n        or file.endswith(\".tar\")\n        or file.endswith(\".tgz\")\n    ):\n        import tarfile\n\n        with 
tarfile.open(file, \"r\") as archive:\n\n            def is_within_directory(directory, target):\n                abs_directory = os.path.abspath(directory)\n                abs_target = os.path.abspath(target)\n                prefix = os.path.commonprefix([abs_directory, abs_target])\n                return prefix == abs_directory\n\n            def safe_extract(\n                tar, path=\".\", members=None, *, numeric_owner=False\n            ):\n                for member in tar.getmembers():\n                    member_path = os.path.join(path, member.name)\n                    if not is_within_directory(path, member_path):\n                        raise Exception(\"Attempted Path Traversal in Tar File\")\n                tar.extractall(path, members, numeric_owner=numeric_owner)\n\n            safe_extract(archive, path=target_dir)\n    elif file.endswith(\".gz\"):\n        import gzip\n        import shutil\n\n        with gzip.open(file, \"rb\") as f_in:\n            target_file = os.path.join(target_dir, os.path.basename(file)[:-3])\n            with open(target_file, \"wb\") as f_out:\n                shutil.copyfileobj(f_in, f_out)\n    elif file.endswith(\".zip\"):\n        import zipfile\n\n        with zipfile.ZipFile(file, \"r\") as archive:\n            archive.extractall(path=target_dir)\n    else:\n        raise Exception(\"Unrecognized file type: \" + file)\n"
  },
  {
    "path": "python/dgl/graphbolt/item_sampler.py",
    "content": "\"\"\"Item Sampler\"\"\"\n\nfrom collections.abc import Mapping\nfrom typing import Callable, Iterator, Optional, Union\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data import IterDataPipe\n\nfrom .internal import calculate_range\nfrom .internal_utils import gb_warning\nfrom .itemset import HeteroItemSet, ItemSet\nfrom .minibatch import MiniBatch\n\n__all__ = [\"ItemSampler\", \"DistributedItemSampler\", \"minibatcher_default\"]\n\n\ndef minibatcher_default(batch, names):\n    \"\"\"Default minibatcher which maps a list of items to a `MiniBatch` with the\n    same names as the items. The names of items are supposed to be provided\n    and align with the data attributes of `MiniBatch`. If any unknown item name\n    is provided, exception will be raised. If the names of items are not\n    provided, the item list is returned as is and a warning will be raised.\n\n    Parameters\n    ----------\n    batch : list\n        List of items.\n    names : Tuple[str] or None\n        Names of items in `batch` with same length. The order should align\n        with `batch`.\n\n    Returns\n    -------\n    MiniBatch\n        A minibatch.\n    \"\"\"\n    if names is None:\n        gb_warning(\n            \"Failed to map item list to `MiniBatch` as the names of items are \"\n            \"not provided. Please provide a customized `MiniBatcher`. 
\"\n            \"The item list is returned as is.\"\n        )\n        return batch\n    if len(names) == 1:\n        # Handle the case of single item: batch = tensor([0, 1, 2, 3]), names =\n        # (\"seeds\",) as `zip(batch, names)` will iterate over the tensor\n        # instead of the batch.\n        init_data = {names[0]: batch}\n    else:\n        if isinstance(batch, Mapping):\n            init_data = {\n                name: {k: v[i] for k, v in batch.items()}\n                for i, name in enumerate(names)\n            }\n        else:\n            init_data = {name: item for item, name in zip(batch, names)}\n    minibatch = MiniBatch()\n    # TODO(#7254): Hacks for original `seed_nodes` and `node_pairs`, which need\n    # to be cleaned up later.\n    if \"node_pairs\" in names:\n        pos_seeds = init_data[\"node_pairs\"]\n        # Build negative graph.\n        if \"negative_srcs\" in names and \"negative_dsts\" in names:\n            neg_srcs = init_data[\"negative_srcs\"]\n            neg_dsts = init_data[\"negative_dsts\"]\n            (\n                init_data[\"seeds\"],\n                init_data[\"labels\"],\n                init_data[\"indexes\"],\n            ) = _construct_seeds(\n                pos_seeds, neg_srcs=neg_srcs, neg_dsts=neg_dsts\n            )\n        elif \"negative_srcs\" in names:\n            neg_srcs = init_data[\"negative_srcs\"]\n            (\n                init_data[\"seeds\"],\n                init_data[\"labels\"],\n                init_data[\"indexes\"],\n            ) = _construct_seeds(pos_seeds, neg_srcs=neg_srcs)\n        elif \"negative_dsts\" in names:\n            neg_dsts = init_data[\"negative_dsts\"]\n            (\n                init_data[\"seeds\"],\n                init_data[\"labels\"],\n                init_data[\"indexes\"],\n            ) = _construct_seeds(pos_seeds, neg_dsts=neg_dsts)\n        else:\n            init_data[\"seeds\"] = pos_seeds\n    for name, item in 
init_data.items():\n        if not hasattr(minibatch, name):\n            gb_warning(\n                f\"Unknown item name '{name}' is detected and added into \"\n                \"`MiniBatch`. You probably need to provide a customized \"\n                \"`MiniBatcher`.\"\n            )\n        # TODO(#7254): Hacks for original `seed_nodes` and `node_pairs`, which\n        # need to be cleaned up later.\n        if name == \"seed_nodes\":\n            name = \"seeds\"\n        if name in (\"node_pairs\", \"negative_srcs\", \"negative_dsts\"):\n            continue\n        setattr(minibatch, name, item)\n    return minibatch\n\n\nclass ItemSampler(IterDataPipe):\n    \"\"\"A sampler to iterate over input items and create minibatches.\n\n    Input items could be node IDs, node pairs with or without labels, node\n    pairs with negative sources/destinations.\n\n    Note: This class `ItemSampler` is not decorated with\n    `torch.utils.data.functional_datapipe` on purpose. This indicates it\n    does not support function-like call. But any iterable datapipes from\n    `torch.utils.data.datapipes` can be further appended.\n\n    Parameters\n    ----------\n    item_set : Union[ItemSet, HeteroItemSet]\n        Data to be sampled.\n    batch_size : int\n        The size of each batch.\n    minibatcher : Optional[Callable]\n        A callable that takes in a list of items and returns a `MiniBatch`.\n    drop_last : bool\n        Option to drop the last batch if it's not full.\n    shuffle : bool\n        Option to shuffle before sample.\n    seed: int\n        The seed for reproducible stochastic shuffling. If None, a random seed\n        will be generated.\n\n    Examples\n    --------\n    1. Node IDs.\n\n    >>> import torch\n    >>> from dgl import graphbolt as gb\n    >>> item_set = gb.ItemSet(torch.arange(0, 10), names=\"seeds\")\n    >>> item_sampler = gb.ItemSampler(\n    ...     item_set, batch_size=4, shuffle=False, drop_last=False\n    ... 
)\n    >>> next(iter(item_sampler))\n    MiniBatch(seeds=tensor([0, 1, 2, 3]), sampled_subgraphs=None,\n        node_features=None, labels=None, input_nodes=None,\n        indexes=None, edge_features=None, compacted_seeds=None,\n        blocks=None,)\n\n    2. Node pairs.\n\n    >>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),\n    ...     names=\"seeds\")\n    >>> item_sampler = gb.ItemSampler(\n    ...     item_set, batch_size=4, shuffle=False, drop_last=False\n    ... )\n    >>> next(iter(item_sampler))\n    MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),\n        sampled_subgraphs=None, node_features=None, labels=None,\n        input_nodes=None, indexes=None, edge_features=None,\n        compacted_seeds=None, blocks=None,)\n\n    3. Node pairs and labels.\n\n    >>> item_set = gb.ItemSet(\n    ...     (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 20)),\n    ...     names=(\"seeds\", \"labels\")\n    ... )\n    >>> item_sampler = gb.ItemSampler(\n    ...     item_set, batch_size=4, shuffle=False, drop_last=False\n    ... )\n    >>> next(iter(item_sampler))\n    MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),\n        sampled_subgraphs=None, node_features=None,\n        labels=tensor([10, 11, 12, 13]), input_nodes=None,\n        indexes=None, edge_features=None, compacted_seeds=None,\n        blocks=None,)\n\n    4. Node pairs, labels and indexes.\n\n    >>> seeds = torch.arange(0, 20).reshape(-1, 2)\n    >>> labels = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0])\n    >>> indexes = torch.tensor([0, 1, 0, 0, 0, 0, 1, 1, 1, 1])\n    >>> item_set = gb.ItemSet((seeds, labels, indexes), names=(\"seeds\",\n    ...     \"labels\", \"indexes\"))\n    >>> item_sampler = gb.ItemSampler(\n    ...     item_set, batch_size=4, shuffle=False, drop_last=False\n    ... 
)\n    >>> next(iter(item_sampler))\n    MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),\n        sampled_subgraphs=None, node_features=None,\n        labels=tensor([1, 1, 0, 0]), input_nodes=None,\n        indexes=tensor([0, 1, 0, 0]), edge_features=None,\n        compacted_seeds=None, blocks=None,)\n\n    5. Further process batches with other datapipes such as\n    :class:`torch.utils.data.datapipes.iter.Mapper`.\n\n    >>> item_set = gb.ItemSet(torch.arange(0, 10))\n    >>> data_pipe = gb.ItemSampler(item_set, 4)\n    >>> def add_one(batch):\n    ...     return batch + 1\n    >>> data_pipe = data_pipe.map(add_one)\n    >>> list(data_pipe)\n    [tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10])]\n\n    6. Heterogeneous node IDs.\n\n    >>> ids = {\n    ...     \"user\": gb.ItemSet(torch.arange(0, 5), names=\"seeds\"),\n    ...     \"item\": gb.ItemSet(torch.arange(0, 6), names=\"seeds\"),\n    ... }\n    >>> item_set = gb.HeteroItemSet(ids)\n    >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)\n    >>> next(iter(item_sampler))\n    MiniBatch(seeds={'user': tensor([0, 1, 2, 3])}, sampled_subgraphs=None,\n        node_features=None, labels=None, input_nodes=None, indexes=None,\n        edge_features=None, compacted_seeds=None, blocks=None,)\n\n    7. Heterogeneous node pairs.\n\n    >>> seeds_like = torch.arange(0, 10).reshape(-1, 2)\n    >>> seeds_follow = torch.arange(10, 20).reshape(-1, 2)\n    >>> item_set = gb.HeteroItemSet({\n    ...     \"user:like:item\": gb.ItemSet(\n    ...         seeds_like, names=\"seeds\"),\n    ...     \"user:follow:user\": gb.ItemSet(\n    ...         seeds_follow, names=\"seeds\"),\n    ... 
})\n    >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)\n    >>> next(iter(item_sampler))\n    MiniBatch(seeds={'user:like:item':\n        tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,\n        node_features=None, labels=None, input_nodes=None, indexes=None,\n        edge_features=None, compacted_seeds=None, blocks=None,)\n\n    8. Heterogeneous node pairs and labels.\n\n    >>> seeds_like = torch.arange(0, 10).reshape(-1, 2)\n    >>> labels_like = torch.arange(0, 5)\n    >>> seeds_follow = torch.arange(10, 20).reshape(-1, 2)\n    >>> labels_follow = torch.arange(5, 10)\n    >>> item_set = gb.HeteroItemSet({\n    ...     \"user:like:item\": gb.ItemSet((seeds_like, labels_like),\n    ...         names=(\"seeds\", \"labels\")),\n    ...     \"user:follow:user\": gb.ItemSet((seeds_follow, labels_follow),\n    ...         names=(\"seeds\", \"labels\")),\n    ... })\n    >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)\n    >>> next(iter(item_sampler))\n    MiniBatch(seeds={'user:like:item':\n        tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,\n        node_features=None, labels={'user:like:item': tensor([0, 1, 2, 3])},\n        input_nodes=None, indexes=None, edge_features=None,\n        compacted_seeds=None, blocks=None,)\n\n    9. Heterogeneous node pairs, labels and indexes.\n\n    >>> seeds_like = torch.arange(0, 10).reshape(-1, 2)\n    >>> labels_like = torch.tensor([1, 1, 0, 0, 0])\n    >>> indexes_like = torch.tensor([0, 1, 0, 0, 1])\n    >>> seeds_follow = torch.arange(20, 30).reshape(-1, 2)\n    >>> labels_follow = torch.tensor([1, 1, 0, 0, 0])\n    >>> indexes_follow = torch.tensor([0, 1, 0, 0, 1])\n    >>> item_set = gb.HeteroItemSet({\n    ...     \"user:like:item\": gb.ItemSet((seeds_like, labels_like,\n    ...         indexes_like), names=(\"seeds\", \"labels\", \"indexes\")),\n    ...     \"user:follow:user\": gb.ItemSet((seeds_follow,labels_follow,\n    ...         
indexes_follow), names=(\"seeds\", \"labels\", \"indexes\")),\n    ... })\n    >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)\n    >>> next(iter(item_sampler))\n    MiniBatch(seeds={'user:like:item':\n        tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,\n        node_features=None, labels={'user:like:item': tensor([1, 1, 0, 0])},\n        input_nodes=None, indexes={'user:like:item': tensor([0, 1, 0, 0])},\n        edge_features=None, compacted_seeds=None, blocks=None,)\n    \"\"\"\n\n    def __init__(\n        self,\n        item_set: Union[ItemSet, HeteroItemSet],\n        batch_size: int,\n        minibatcher: Optional[Callable] = minibatcher_default,\n        drop_last: Optional[bool] = False,\n        shuffle: Optional[bool] = False,\n        seed: Optional[int] = None,\n    ) -> None:\n        super().__init__()\n        self._item_set = item_set\n        self._names = item_set.names\n        self._batch_size = batch_size\n        self._minibatcher = minibatcher\n        self._drop_last = drop_last\n        self._shuffle = shuffle\n        self._distributed = False\n        self._drop_uneven_inputs = False\n        self._world_size = None\n        self._rank = None\n        # For the sake of reproducibility, the seed should be allowed to be\n        # manually set by the user.\n        if seed is None:\n            self._seed = np.random.randint(0, np.iinfo(np.int32).max)\n        else:\n            self._seed = seed\n        # The attribute `self._epoch` is added to make shuffling work properly\n        # across multiple epochs. 
Otherwise, the same ordering will always be\n        # used in every epoch.\n        self._epoch = 0\n\n    def __iter__(self) -> Iterator:\n        worker_info = torch.utils.data.get_worker_info()\n        if worker_info is not None:\n            num_workers = worker_info.num_workers\n            worker_id = worker_info.id\n        else:\n            num_workers = 1\n            worker_id = 0\n        total = len(self._item_set)\n        start_offset, assigned_count, output_count = calculate_range(\n            self._distributed,\n            total,\n            self._world_size,\n            self._rank,\n            num_workers,\n            worker_id,\n            self._batch_size,\n            self._drop_last,\n            self._drop_uneven_inputs,\n        )\n        if self._shuffle:\n            g = torch.Generator().manual_seed(self._seed + self._epoch)\n            permutation = torch.randperm(total, generator=g)\n            indices = permutation[start_offset : start_offset + assigned_count]\n        else:\n            indices = torch.arange(start_offset, start_offset + assigned_count)\n        for i in range(0, assigned_count, self._batch_size):\n            if output_count <= 0:\n                break\n            yield self._minibatcher(\n                self._item_set[\n                    indices[i : i + min(self._batch_size, output_count)]\n                ],\n                self._names,\n            )\n            output_count -= self._batch_size\n\n        self._epoch += 1\n\n\nclass DistributedItemSampler(ItemSampler):\n    \"\"\"A sampler to iterate over input items and create subsets distributedly.\n\n    This sampler creates a distributed subset of items from the given data set,\n    which can be used for training with PyTorch's Distributed Data Parallel\n    (DDP). The items can be node IDs, node pairs with or without labels, node\n    pairs with negative sources/destinations, DGLGraphs, or heterogeneous\n    counterparts. 
The original item set is split such that each replica\n    (process) receives an exclusive subset.\n\n    Note: The items will be first split onto each replica, then get shuffled\n    (if needed) and batched. Therefore, each replica will always get the same\n    set of items.\n\n    Note: This class `DistributedItemSampler` is not decorated with\n    `torch.utils.data.functional_datapipe` on purpose. This indicates it\n    does not support function-like call. But any iterable datapipes from\n    `torch.utils.data.datapipes` can be further appended.\n\n    Parameters\n    ----------\n    item_set : Union[ItemSet, HeteroItemSet]\n        Data to be sampled.\n    batch_size : int\n        The size of each batch.\n    minibatcher : Optional[Callable]\n        A callable that takes in a list of items and returns a `MiniBatch`.\n    drop_last : bool\n        Option to drop the last batch if it's not full.\n    shuffle : bool\n        Option to shuffle before sample.\n    num_replicas: int\n        The number of model replicas that will be created during Distributed\n        Data Parallel (DDP) training. It should be the same as the real world\n        size, otherwise it could cause errors. By default, it is retrieved from\n        the current distributed group.\n    drop_uneven_inputs : bool\n        Option to make sure the numbers of batches for each replica are the\n        same. If some of the replicas have more batches than the others, the\n        redundant batches of those replicas will be dropped. If the drop_last\n        parameter is also set to True, the last batch will be dropped before the\n        redundant batches are dropped.\n        Note: When using Distributed Data Parallel (DDP) training, the program\n        may hang or error if a replica has fewer inputs. It is recommended\n        to use the Join Context Manager provided by PyTorch to solve this\n        problem. Please refer to\n        https://pytorch.org/tutorials/advanced/generic_join.html. 
However, this\n        option can be used if the Join Context Manager is not helpful for any\n        reason.\n    seed: int\n        The seed for reproducible stochastic shuffling. If None, a random seed\n        will be generated.\n\n    Examples\n    --------\n    0. Preparation: DistributedItemSampler needs multi-processing environment to\n    work. You need to spawn subprocesses and initialize processing group before\n    executing following examples. Due to randomness, the output is not always\n    the same as listed below.\n\n    >>> import torch\n    >>> from dgl import graphbolt as gb\n    >>> item_set = gb.ItemSet(torch.arange(15))\n    >>> num_replicas = 4\n    >>> batch_size = 2\n    >>> mp.spawn(...)\n\n    1. shuffle = False, drop_last = False, drop_uneven_inputs = False.\n\n    >>> item_sampler = gb.DistributedItemSampler(\n    >>>     item_set, batch_size=2, shuffle=False, drop_last=False,\n    >>>     drop_uneven_inputs=False\n    >>> )\n    >>> data_loader = gb.DataLoader(item_sampler)\n    >>> print(f\"Replica#{proc_id}: {list(data_loader)})\n    Replica#0: [tensor([0, 1]), tensor([2, 3])]\n    Replica#1: [tensor([4, 5]), tensor([6, 7])]\n    Replica#2: [tensor([8, 9]), tensor([10, 11])]\n    Replica#3: [tensor([12, 13]), tensor([14])]\n\n    2. shuffle = False, drop_last = True, drop_uneven_inputs = False.\n\n    >>> item_sampler = gb.DistributedItemSampler(\n    >>>     item_set, batch_size=2, shuffle=False, drop_last=True,\n    >>>     drop_uneven_inputs=False\n    >>> )\n    >>> data_loader = gb.DataLoader(item_sampler)\n    >>> print(f\"Replica#{proc_id}: {list(data_loader)})\n    Replica#0: [tensor([0, 1]), tensor([2, 3])]\n    Replica#1: [tensor([4, 5]), tensor([6, 7])]\n    Replica#2: [tensor([8, 9]), tensor([10, 11])]\n    Replica#3: [tensor([12, 13])]\n\n    3. 
shuffle = False, drop_last = False, drop_uneven_inputs = True.\n\n    >>> item_sampler = gb.DistributedItemSampler(\n    >>>     item_set, batch_size=2, shuffle=False, drop_last=False,\n    >>>     drop_uneven_inputs=True\n    >>> )\n    >>> data_loader = gb.DataLoader(item_sampler)\n    >>> print(f\"Replica#{proc_id}: {list(data_loader)})\n    Replica#0: [tensor([0, 1]), tensor([2, 3])]\n    Replica#1: [tensor([4, 5]), tensor([6, 7])]\n    Replica#2: [tensor([8, 9]), tensor([10, 11])]\n    Replica#3: [tensor([12, 13]), tensor([14])]\n\n    4. shuffle = False, drop_last = True, drop_uneven_inputs = True.\n\n    >>> item_sampler = gb.DistributedItemSampler(\n    >>>     item_set, batch_size=2, shuffle=False, drop_last=True,\n    >>>     drop_uneven_inputs=True\n    >>> )\n    >>> data_loader = gb.DataLoader(item_sampler)\n    >>> print(f\"Replica#{proc_id}: {list(data_loader)})\n    Replica#0: [tensor([0, 1])]\n    Replica#1: [tensor([4, 5])]\n    Replica#2: [tensor([8, 9])]\n    Replica#3: [tensor([12, 13])]\n\n    5. shuffle = True, drop_last = True, drop_uneven_inputs = False.\n\n    >>> item_sampler = gb.DistributedItemSampler(\n    >>>     item_set, batch_size=2, shuffle=True, drop_last=True,\n    >>>     drop_uneven_inputs=False\n    >>> )\n    >>> data_loader = gb.DataLoader(item_sampler)\n    >>> print(f\"Replica#{proc_id}: {list(data_loader)})\n    (One possible output:)\n    Replica#0: [tensor([3, 2]), tensor([0, 1])]\n    Replica#1: [tensor([6, 5]), tensor([7, 4])]\n    Replica#2: [tensor([8, 10])]\n    Replica#3: [tensor([14, 12])]\n\n    6. 
shuffle = True, drop_last = True, drop_uneven_inputs = True.\n\n    >>> item_sampler = gb.DistributedItemSampler(\n    >>>     item_set, batch_size=2, shuffle=True, drop_last=True,\n    >>>     drop_uneven_inputs=True\n    >>> )\n    >>> data_loader = gb.DataLoader(item_sampler)\n    >>> print(f\"Replica#{proc_id}: {list(data_loader)})\n    (One possible output:)\n    Replica#0: [tensor([1, 3])]\n    Replica#1: [tensor([7, 5])]\n    Replica#2: [tensor([11, 9])]\n    Replica#3: [tensor([13, 14])]\n    \"\"\"\n\n    def __init__(\n        self,\n        item_set: Union[ItemSet, HeteroItemSet],\n        batch_size: int,\n        minibatcher: Optional[Callable] = minibatcher_default,\n        drop_last: Optional[bool] = False,\n        shuffle: Optional[bool] = False,\n        drop_uneven_inputs: Optional[bool] = False,\n        seed: Optional[int] = None,\n    ) -> None:\n        super().__init__(\n            item_set,\n            batch_size,\n            minibatcher,\n            drop_last,\n            shuffle,\n            seed,\n        )\n        self._distributed = True\n        self._drop_uneven_inputs = drop_uneven_inputs\n        if not dist.is_available():\n            raise RuntimeError(\n                \"Distributed item sampler requires distributed package.\"\n            )\n        self._world_size = dist.get_world_size()\n        self._rank = dist.get_rank()\n        if self._world_size > 1:\n            # For the sake of reproducibility, the seed should be allowed to be\n            # manually set by the user.\n            self._align_seeds(src=0, seed=seed)\n\n    def _align_seeds(\n        self, src: Optional[int] = 0, seed: Optional[int] = None\n    ) -> None:\n        \"\"\"Aligns seeds across distributed processes.\n\n        This method synchronizes seeds across distributed processes, ensuring\n        consistent randomness.\n\n        Parameters\n        ----------\n        src: int, optional\n            The source process rank. 
Defaults to 0.\n        seed: int, optional\n            The seed value to synchronize. If None, a random seed will be\n            generated. Defaults to None.\n        \"\"\"\n        device = (\n            torch.cuda.current_device()\n            if torch.cuda.is_available() and dist.get_backend() == \"nccl\"\n            else \"cpu\"\n        )\n        if seed is None:\n            seed = np.random.randint(0, np.iinfo(np.int32).max)\n        if self._rank == src:\n            seed_tensor = torch.tensor(seed, dtype=torch.int32, device=device)\n        else:\n            seed_tensor = torch.empty([], dtype=torch.int32, device=device)\n        dist.broadcast(seed_tensor, src=src)\n        self._seed = seed_tensor.item()\n\n\ndef _construct_seeds(pos_seeds, neg_srcs=None, neg_dsts=None):\n    # For homogeneous graph.\n    if isinstance(pos_seeds, torch.Tensor):\n        negative_ratio = neg_srcs.size(1) if neg_srcs else neg_dsts.size(1)\n        neg_srcs = (\n            neg_srcs\n            if neg_srcs is not None\n            else pos_seeds[:, 0].repeat_interleave(negative_ratio)\n        ).view(-1)\n        neg_dsts = (\n            neg_dsts\n            if neg_dsts is not None\n            else pos_seeds[:, 1].repeat_interleave(negative_ratio)\n        ).view(-1)\n        neg_seeds = torch.cat((neg_srcs, neg_dsts)).view(2, -1).T\n        seeds = torch.cat((pos_seeds, neg_seeds))\n        pos_seeds_num = pos_seeds.size(0)\n        labels = torch.empty(seeds.size(0), device=pos_seeds.device)\n        labels[:pos_seeds_num] = 1\n        labels[pos_seeds_num:] = 0\n        pos_indexes = torch.arange(\n            0,\n            pos_seeds_num,\n            device=pos_seeds.device,\n        )\n        neg_indexes = pos_indexes.repeat_interleave(negative_ratio)\n        indexes = torch.cat((pos_indexes, neg_indexes))\n    # For heterogeneous graph.\n    else:\n        negative_ratio = (\n            list(neg_srcs.values())[0].size(1)\n            if neg_srcs\n     
       else list(neg_dsts.values())[0].size(1)\n        )\n        seeds = {}\n        labels = {}\n        indexes = {}\n        for etype in pos_seeds:\n            neg_src = (\n                neg_srcs[etype]\n                if neg_srcs is not None\n                else pos_seeds[etype][:, 0].repeat_interleave(negative_ratio)\n            ).view(-1)\n            neg_dst = (\n                neg_dsts[etype]\n                if neg_dsts is not None\n                else pos_seeds[etype][:, 1].repeat_interleave(negative_ratio)\n            ).view(-1)\n            seeds[etype] = torch.cat(\n                (\n                    pos_seeds[etype],\n                    torch.cat(\n                        (\n                            neg_src,\n                            neg_dst,\n                        )\n                    )\n                    .view(2, -1)\n                    .T,\n                )\n            )\n            pos_seeds_num = pos_seeds[etype].size(0)\n            labels[etype] = torch.empty(\n                seeds[etype].size(0), device=pos_seeds[etype].device\n            )\n            labels[etype][:pos_seeds_num] = 1\n            labels[etype][pos_seeds_num:] = 0\n            pos_indexes = torch.arange(\n                0,\n                pos_seeds_num,\n                device=pos_seeds[etype].device,\n            )\n            neg_indexes = pos_indexes.repeat_interleave(negative_ratio)\n            indexes[etype] = torch.cat((pos_indexes, neg_indexes))\n    return seeds, labels, indexes\n"
  },
  {
    "path": "python/dgl/graphbolt/itemset.py",
    "content": "\"\"\"GraphBolt Itemset.\"\"\"\n\nimport textwrap\nfrom typing import Dict, Iterable, Tuple, Union\n\nimport torch\n\nfrom .internal_utils import gb_warning\n\n__all__ = [\"ItemSet\", \"HeteroItemSet\", \"ItemSetDict\"]\n\n\ndef is_scalar(x):\n    \"\"\"Checks if the input is a scalar.\"\"\"\n    return (\n        len(x.shape) == 0 if isinstance(x, torch.Tensor) else isinstance(x, int)\n    )\n\n\nclass ItemSet:\n    r\"\"\"A wrapper of a tensor or tuple of tensors.\n\n    Parameters\n    ----------\n    items: Union[int, torch.Tensor, Tuple[torch.Tensor]]\n        The tensors to be wrapped.\n        - If it is a single scalar (an integer or a tensor that holds a single\n          value), the item would be considered as a range_tensor created by\n          `torch.arange`.\n        - If it is a multi-dimensional tensor, the indexing will be performed\n          along the first dimension.\n        - If it is a tuple, each item in the tuple must be a tensor.\n\n    names: Union[str, Tuple[str]], optional\n        The names of the items. If it is a tuple, each name must corresponds to\n        an item in the `items` parameter. The naming is arbitrary, but in\n        general practice, the names should be chosen from ['labels', 'seeds',\n        'indexes'] to align with the attributes of class\n        `dgl.graphbolt.MiniBatch`.\n\n    Examples\n    --------\n    >>> import torch\n    >>> from dgl import graphbolt as gb\n\n    1. Integer: number of nodes.\n\n    >>> num = 10\n    >>> item_set = gb.ItemSet(num, names=\"seeds\")\n    >>> list(item_set)\n    [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5),\n     tensor(6), tensor(7), tensor(8), tensor(9)]\n    >>> item_set[:]\n    tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])\n    >>> item_set.names\n    ('seeds',)\n\n    2. Torch scalar: number of nodes. 
Customizable dtype compared to Integer.\n\n    >>> num = torch.tensor(10, dtype=torch.int32)\n    >>> item_set = gb.ItemSet(num, names=\"seeds\")\n    >>> list(item_set)\n    [tensor(0, dtype=torch.int32), tensor(1, dtype=torch.int32),\n     tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32),\n     tensor(4, dtype=torch.int32), tensor(5, dtype=torch.int32),\n     tensor(6, dtype=torch.int32), tensor(7, dtype=torch.int32),\n     tensor(8, dtype=torch.int32), tensor(9, dtype=torch.int32)]\n    >>> item_set[:]\n    tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)\n    >>> item_set.names\n    ('seeds',)\n\n    3. Single tensor: seed nodes.\n\n    >>> node_ids = torch.arange(0, 5)\n    >>> item_set = gb.ItemSet(node_ids, names=\"seeds\")\n    >>> list(item_set)\n    [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]\n    >>> item_set[:]\n    tensor([0, 1, 2, 3, 4])\n    >>> item_set.names\n    ('seeds',)\n\n    4. Tuple of tensors with same shape: seed nodes and labels.\n\n    >>> node_ids = torch.arange(0, 5)\n    >>> labels = torch.arange(5, 10)\n    >>> item_set = gb.ItemSet(\n    ...     (node_ids, labels), names=(\"seeds\", \"labels\"))\n    >>> list(item_set)\n    [(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),\n     (tensor(3), tensor(8)), (tensor(4), tensor(9))]\n    >>> item_set[:]\n    (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))\n    >>> item_set.names\n    ('seeds', 'labels')\n\n    5. Tuple of tensors with different shape: seeds and labels.\n\n    >>> seeds = torch.arange(0, 10).reshape(-1, 2)\n    >>> labels = torch.tensor([1, 1, 0, 0, 0])\n    >>> item_set = gb.ItemSet(\n    ...     
(seeds, labels), names=(\"seeds\", \"labels\"))\n    >>> list(item_set)\n    [(tensor([0, 1]), tensor([1])),\n     (tensor([2, 3]), tensor([1])),\n     (tensor([4, 5]), tensor([0])),\n     (tensor([6, 7]), tensor([0])),\n     (tensor([8, 9]), tensor([0]))]\n    >>> item_set[:]\n    (tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]),\n     tensor([1, 1, 0, 0, 0]))\n    >>> item_set.names\n    ('seeds', 'labels')\n\n    6. Tuple of tensors with different shape: hyperlink and labels.\n\n    >>> seeds = torch.arange(0, 10).reshape(-1, 5)\n    >>> labels = torch.tensor([1, 0])\n    >>> item_set = gb.ItemSet(\n    ...     (seeds, labels), names=(\"seeds\", \"labels\"))\n    >>> list(item_set)\n    [(tensor([0, 1, 2, 3, 4]), tensor([1])),\n     (tensor([5, 6, 7, 8, 9]), tensor([0]))]\n    >>> item_set[:]\n    (tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),\n     tensor([1, 0]))\n    >>> item_set.names\n    ('seeds', 'labels')\n    \"\"\"\n\n    def __init__(\n        self,\n        items: Union[int, torch.Tensor, Tuple[torch.Tensor]],\n        names: Union[str, Tuple[str]] = None,\n    ) -> None:\n        if is_scalar(items):\n            self._length = int(items)\n            self._items = items\n        elif isinstance(items, tuple):\n            self._length = len(items[0])\n            if any(self._length != len(item) for item in items):\n                raise ValueError(\"Size mismatch between items.\")\n            self._items = items\n        else:\n            self._length = len(items)\n            self._items = (items,)\n        self._num_items = (\n            len(self._items) if isinstance(self._items, tuple) else 1\n        )\n        if names is not None:\n            if isinstance(names, tuple):\n                self._names = names\n            else:\n                self._names = (names,)\n            assert self._num_items == len(self._names), (\n                f\"Number of items ({self._num_items}) and \"\n                f\"names ({len(self._names)}) don't 
match.\"\n            )\n        else:\n            self._names = None\n\n    def __len__(self) -> int:\n        return self._length\n\n    def __getitem__(self, index: Union[int, slice, Iterable[int]]):\n        if is_scalar(self._items):\n            dtype = getattr(self._items, \"dtype\", torch.int64)\n            if isinstance(index, slice):\n                start, stop, step = index.indices(self._length)\n                return torch.arange(start, stop, step, dtype=dtype)\n            elif isinstance(index, int):\n                if index < 0:\n                    index += self._length\n                if index < 0 or index >= self._length:\n                    raise IndexError(\n                        f\"{type(self).__name__} index out of range.\"\n                    )\n                return torch.tensor(index, dtype=dtype)\n            elif isinstance(index, torch.Tensor):\n                return index.to(dtype)\n            else:\n                raise TypeError(\n                    f\"{type(self).__name__} indices must be int, slice, or \"\n                    f\"torch.Tensor, not {type(index)}.\"\n                )\n        elif self._num_items == 1:\n            return self._items[0][index]\n        else:\n            return tuple(item[index] for item in self._items)\n\n    @property\n    def names(self) -> Tuple[str]:\n        \"\"\"Return the names of the items.\"\"\"\n        return self._names\n\n    @property\n    def num_items(self) -> int:\n        \"\"\"Return the number of the items.\"\"\"\n        return self._num_items\n\n    def __repr__(self) -> str:\n        ret = (\n            f\"{self.__class__.__name__}(\\n\"\n            f\"    items={self._items},\\n\"\n            f\"    names={self._names},\\n\"\n            f\")\"\n        )\n        return ret\n\n\nclass HeteroItemSet:\n    r\"\"\"A collection of itemsets, each associated with a unique type.\n\n    This class aims to assemble existing itemsets with different types, for\n    
example, seed_nodes of different node types in a graph.\n\n    Parameters\n    ----------\n    itemsets: Dict[str, ItemSet]\n        A dictionary whose keys are types and values are ItemSet instances.\n\n    Examples\n    --------\n    >>> import torch\n    >>> from dgl import graphbolt as gb\n\n    1. Each itemset is a single tensor: seed nodes.\n\n    >>> node_ids_user = torch.arange(0, 5)\n    >>> node_ids_item = torch.arange(5, 10)\n    >>> item_set = gb.HeteroItemSet({\n    ...     \"user\": gb.ItemSet(node_ids_user, names=\"seeds\"),\n    ...     \"item\": gb.ItemSet(node_ids_item, names=\"seeds\")})\n    >>> list(item_set)\n    [{\"user\": tensor(0)}, {\"user\": tensor(1)}, {\"user\": tensor(2)},\n     {\"user\": tensor(3)}, {\"user\": tensor(4)}, {\"item\": tensor(5)},\n     {\"item\": tensor(6)}, {\"item\": tensor(7)}, {\"item\": tensor(8)},\n     {\"item\": tensor(9)}}]\n    >>> item_set[:]\n    {\"user\": tensor([0, 1, 2, 3, 4]), \"item\": tensor([5, 6, 7, 8, 9])}\n    >>> item_set.names\n    ('seeds',)\n\n    2. Each itemset is a tuple of tensors with same shape: seed nodes and\n    labels.\n\n    >>> node_ids_user = torch.arange(0, 2)\n    >>> labels_user = torch.arange(0, 2)\n    >>> node_ids_item = torch.arange(2, 5)\n    >>> labels_item = torch.arange(2, 5)\n    >>> item_set = gb.HeteroItemSet({\n    ...     \"user\": gb.ItemSet(\n    ...         (node_ids_user, labels_user),\n    ...         names=(\"seeds\", \"labels\")),\n    ...     \"item\": gb.ItemSet(\n    ...         (node_ids_item, labels_item),\n    ...         
names=(\"seeds\", \"labels\"))})\n    >>> list(item_set)\n    [{\"user\": (tensor(0), tensor(0))}, {\"user\": (tensor(1), tensor(1))},\n     {\"item\": (tensor(2), tensor(2))}, {\"item\": (tensor(3), tensor(3))},\n     {\"item\": (tensor(4), tensor(4))}}]\n    >>> item_set[:]\n    {\"user\": (tensor([0, 1]), tensor([0, 1])),\n     \"item\": (tensor([2, 3, 4]), tensor([2, 3, 4]))}\n    >>> item_set.names\n    ('seeds', 'labels')\n\n    3. Each itemset is a tuple of tensors with different shape: seeds and\n    labels.\n\n    >>> seeds_like = torch.arange(0, 4).reshape(-1, 2)\n    >>> labels_like = torch.tensor([1, 0])\n    >>> seeds_follow = torch.arange(0, 6).reshape(-1, 2)\n    >>> labels_follow = torch.tensor([1, 1, 0])\n    >>> item_set = gb.HeteroItemSet({\n    ...     \"user:like:item\": gb.ItemSet(\n    ...         (seeds_like, labels_like),\n    ...         names=(\"seeds\", \"labels\")),\n    ...     \"user:follow:user\": gb.ItemSet(\n    ...         (seeds_follow, labels_follow),\n    ...         names=(\"seeds\", \"labels\"))})\n    >>> list(item_set)\n    [{'user:like:item': (tensor([0, 1]), tensor(1))},\n     {'user:like:item': (tensor([2, 3]), tensor(0))},\n     {'user:follow:user': (tensor([0, 1]), tensor(1))},\n     {'user:follow:user': (tensor([2, 3]), tensor(1))},\n     {'user:follow:user': (tensor([4, 5]), tensor(0))}]\n    >>> item_set[:]\n    {'user:like:item': (tensor([[0, 1], [2, 3]]),\n                        tensor([1, 0])),\n     'user:follow:user': (tensor([[0, 1], [2, 3], [4, 5]]),\n                          tensor([1, 1, 0]))}\n    >>> item_set.names\n    ('seeds', 'labels')\n\n    4. Each itemset is a tuple of tensors with different shape: hyperlink and\n    labels.\n\n    >>> first_seeds = torch.arange(0, 6).reshape(-1, 3)\n    >>> first_labels = torch.tensor([1, 0])\n    >>> second_seeds = torch.arange(0, 2).reshape(-1, 1)\n    >>> second_labels = torch.tensor([1, 0])\n    >>> item_set = gb.HeteroItemSet({\n    ...     
\"query:user:item\": gb.ItemSet(\n    ...         (first_seeds, first_labels),\n    ...         names=(\"seeds\", \"labels\")),\n    ...     \"user\": gb.ItemSet(\n    ...         (second_seeds, second_labels),\n    ...         names=(\"seeds\", \"labels\"))})\n    >>> list(item_set)\n    [{'query:user:item': (tensor([0, 1, 2]), tensor(1))},\n     {'query:user:item': (tensor([3, 4, 5]), tensor(0))},\n     {'user': (tensor([0]), tensor(1))},\n     {'user': (tensor([1]), tensor(0))}]\n    >>> item_set[:]\n    {'query:user:item': (tensor([[0, 1, 2], [3, 4, 5]]),\n                        tensor([1, 0])),\n     'user': (tensor([[0], [1]]),tensor([1, 0]))}\n    >>> item_set.names\n    ('seeds', 'labels')\n    \"\"\"\n\n    def __init__(self, itemsets: Dict[str, ItemSet]) -> None:\n        self._itemsets = itemsets\n        self._names = next(iter(itemsets.values())).names\n        assert all(\n            self._names == itemset.names for itemset in itemsets.values()\n        ), \"All itemsets must have the same names.\"\n        offset = [0] + [len(itemset) for itemset in self._itemsets.values()]\n        self._offsets = torch.tensor(offset).cumsum(0)\n        self._length = int(self._offsets[-1])\n        self._keys = list(self._itemsets.keys())\n\n    def __len__(self) -> int:\n        return self._length\n\n    def __getitem__(self, index: Union[int, slice, Iterable[int]]):\n        if isinstance(index, int):\n            if index < 0:\n                index += self._length\n            if index < 0 or index >= self._length:\n                raise IndexError(f\"{type(self).__name__} index out of range.\")\n            offset_idx = torch.searchsorted(self._offsets, index, right=True)\n            offset_idx -= 1\n            index -= self._offsets[offset_idx]\n            key = self._keys[offset_idx]\n            return {key: self._itemsets[key][index]}\n        elif isinstance(index, slice):\n            start, stop, step = index.indices(self._length)\n            if 
step != 1:\n                return self.__getitem__(torch.arange(start, stop, step))\n            assert start < stop, \"Start must be smaller than stop.\"\n            data = {}\n            offset_idx_start = max(\n                1, torch.searchsorted(self._offsets, start, right=False)\n            )\n            for offset_idx in range(offset_idx_start, len(self._offsets)):\n                key = self._keys[offset_idx - 1]\n                data[key] = self._itemsets[key][\n                    max(0, start - self._offsets[offset_idx - 1]) : stop\n                    - self._offsets[offset_idx - 1]\n                ]\n                if stop <= self._offsets[offset_idx]:\n                    break\n            return data\n        elif isinstance(index, Iterable):\n            if not isinstance(index, torch.Tensor):\n                index = torch.tensor(index)\n            assert torch.all((index >= 0) & (index < self._length))\n            key_indices = (\n                torch.searchsorted(self._offsets, index, right=True) - 1\n            )\n            data = {}\n            for key_id, key in enumerate(self._keys):\n                mask = (key_indices == key_id).nonzero().squeeze(1)\n                if len(mask) == 0:\n                    continue\n                data[key] = self._itemsets[key][\n                    index[mask] - self._offsets[key_id]\n                ]\n            return data\n        else:\n            raise TypeError(\n                f\"{type(self).__name__} indices must be int, slice, or \"\n                f\"iterable of int, not {type(index)}.\"\n            )\n\n    @property\n    def names(self) -> Tuple[str]:\n        \"\"\"Return the names of the items.\"\"\"\n        return self._names\n\n    def __repr__(self) -> str:\n        ret = (\n            \"{Classname}(\\n\"\n            \"    itemsets={itemsets},\\n\"\n            \"    names={names},\\n\"\n            \")\"\n        )\n\n        itemsets_str = textwrap.indent(\n     
       repr(self._itemsets), \" \" * len(\"    itemsets=\")\n        ).strip()\n\n        return ret.format(\n            Classname=self.__class__.__name__,\n            itemsets=itemsets_str,\n            names=self._names,\n        )\n\n\nclass ItemSetDict:\n    \"\"\"`ItemSetDict` is a deprecated class and will be removed in a future\n    version. Please use `HeteroItemSet` instead.\n\n    This class is an alias for `HeteroItemSet` and serves as a wrapper to\n    provide a smooth transition for users of the old class name. It issues a\n    deprecation warning upon instantiation and forwards all attribute access\n    and method calls to an instance of `HeteroItemSet`.\n    \"\"\"\n\n    def __init__(self, itemsets: Dict[str, ItemSet]) -> None:\n        gb_warning(\n            \"ItemSetDict is deprecated and will be removed in the future. \"\n            \"Please use HeteroItemSet instead.\",\n            category=DeprecationWarning,\n        )\n        self._new_instance = HeteroItemSet(itemsets)\n\n    def __getattr__(self, name: str):\n        return getattr(self._new_instance, name)\n\n    def __getitem__(self, index):\n        return self._new_instance[index]\n\n    def __len__(self) -> int:\n        return len(self._new_instance)\n\n    def __repr__(self) -> str:\n        ret = (\n            \"{Classname}(\\n\"\n            \"    itemsets={itemsets},\\n\"\n            \"    names={names},\\n\"\n            \")\"\n        )\n        itemsets_str = textwrap.indent(\n            repr(self._itemsets), \" \" * len(\"    itemsets=\")\n        ).strip()\n        return ret.format(\n            Classname=self.__class__.__name__,\n            itemsets=itemsets_str,\n            names=self._names,\n        )\n"
  },
  {
    "path": "python/dgl/graphbolt/minibatch.py",
    "content": "\"\"\"Unified data structure for input and ouput of all the stages in loading process.\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Tuple, Union\n\nimport torch\n\nfrom .base import (\n    apply_to,\n    CSCFormatBase,\n    etype_str_to_tuple,\n    expand_indptr,\n    is_object_pinned,\n)\nfrom .internal_utils import (\n    get_attributes,\n    get_nonproperty_attributes,\n    recursive_apply,\n)\nfrom .sampled_subgraph import SampledSubgraph\n\n__all__ = [\"MiniBatch\"]\n\n\n@dataclass\nclass MiniBatch:\n    r\"\"\"A composite data class for data structure in the graphbolt.\n\n    It is designed to facilitate the exchange of data among different components\n    involved in processing data. The purpose of this class is to unify the\n    representation of input and output data across different stages, ensuring\n    consistency and ease of use throughout the loading process.\"\"\"\n\n    labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None\n    \"\"\"\n    Labels associated with seeds in the graph.\n    - If `labels` is a tensor: It indicates the graph is homogeneous. The value\n      should be corresponding labels to given 'seeds'.\n    - If `labels` is a dictionary: The keys should be node or edge type and the\n      value should be corresponding labels to given 'seeds'.\n    \"\"\"\n\n    seeds: Union[\n        torch.Tensor,\n        Dict[str, torch.Tensor],\n    ] = None\n    \"\"\"\n    Representation of seed items utilized in node classification tasks, link\n    prediction tasks and hyperlinks tasks.\n    - If `seeds` is a tensor: it indicates that the seeds originate from a\n      homogeneous graph. 
It can be either a 1-dimensional or 2-dimensional\n      tensor:\n        - 1-dimensional tensor: Each element directly represents a seed node\n          within the graph.\n        - 2-dimensional tensor: Each row designates a seed item, which can\n          encompass various entities such as edges, hyperlinks, or other graph\n          components depending on the specific context.\n    - If `seeds` is a dictionary: it indicates that the seeds originate from a\n      heterogeneous graph. The keys should be edge or node type, and the value\n      should be a tensor, which can be either a 1-dimensional or 2-dimensional\n      tensor:\n        - 1-dimensional tensor: Each element directly represents a seed node\n        of the given type within the graph.\n        - 2-dimensional tensor: Each row designates a seed item of the given\n          type, which can encompass various entities such as edges, hyperlinks,\n          or other graph components depending on the specific context.\n    \"\"\"\n\n    indexes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None\n    \"\"\"\n    Indexes associated with seeds in the graph, which\n    indicates to which query a seeds belongs.\n    - If `indexes` is a tensor: It indicates the graph is homogeneous. The\n      value should be corresponding query to given 'seeds'.\n    - If `indexes` is a dictionary: It indicates the graph is heterogeneous.\n      The keys should be node or edge type and the value should be\n      corresponding query to given 'seeds'. For each key, indexes are\n      consecutive integers starting from zero.\n    \"\"\"\n\n    sampled_subgraphs: List[SampledSubgraph] = None\n    \"\"\"A list of 'SampledSubgraph's, each one corresponding to one layer,\n    representing a subset of a larger graph structure.\n    \"\"\"\n\n    input_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None\n    \"\"\"A representation of input nodes in the outermost layer. 
Contains all nodes\n       in the 'sampled_subgraphs'.\n    - If `input_nodes` is a tensor: It indicates the graph is homogeneous.\n    - If `input_nodes` is a dictionary: The keys should be node type and the\n      value should be corresponding heterogeneous node id.\n    \"\"\"\n\n    node_features: Union[\n        Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]\n    ] = None\n    \"\"\"A representation of node features.\n      - If keys are single strings: It means the graph is homogeneous, and the\n      keys are feature names.\n      - If keys are tuples: It means the graph is heterogeneous, and the keys\n      are tuples of '(node_type, feature_name)'.\n    \"\"\"\n\n    edge_features: List[\n        Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]\n    ] = None\n    \"\"\"Edge features associated with the 'sampled_subgraphs'.\n      - If keys are single strings: It means the graph is homogeneous, and the\n      keys are feature names.\n      - If keys are tuples: It means the graph is heterogeneous, and the keys\n      are tuples of '(edge_type, feature_name)'. Note, edge type is single\n      string of format 'str:str:str'.\n    \"\"\"\n\n    compacted_seeds: Union[\n        torch.Tensor,\n        Dict[str, torch.Tensor],\n    ] = None\n    \"\"\"\n    Representation of compacted seeds corresponding to 'seeds', where\n    all node ids inside are compacted.\n    \"\"\"\n\n    _blocks: list = None\n    \"\"\"\n    A list of `DGLBlock`s.\n    \"\"\"\n\n    def __repr__(self) -> str:\n        return _minibatch_str(self)\n\n    def node_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:\n        \"\"\"A representation of input nodes in the outermost layer. 
Contains all\n        nodes in the `sampled_subgraphs`.\n        - If `input_nodes` is a tensor: It indicates the graph is homogeneous.\n        - If `input_nodes` is a dictionary: The keys should be node type and the\n          value should be corresponding heterogeneous node id.\n        \"\"\"\n        return self.input_nodes\n\n    def num_layers(self) -> int:\n        \"\"\"Return the number of layers.\"\"\"\n        if self.sampled_subgraphs is None:\n            return 0\n        return len(self.sampled_subgraphs)\n\n    def edge_ids(\n        self, layer_id: int\n    ) -> Union[Dict[str, torch.Tensor], torch.Tensor]:\n        \"\"\"Get the edge ids of a layer.\"\"\"\n        return self.sampled_subgraphs[layer_id].original_edge_ids\n\n    def set_node_features(\n        self,\n        node_features: Union[\n            Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]\n        ],\n    ) -> None:\n        \"\"\"Set node features.\"\"\"\n        self.node_features = node_features\n\n    def set_edge_features(\n        self,\n        edge_features: List[\n            Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]\n        ],\n    ) -> None:\n        \"\"\"Set edge features.\"\"\"\n        self.edge_features = edge_features\n\n    @property\n    def blocks(self) -> list:\n        \"\"\"DGL blocks extracted from `MiniBatch` containing graphical structures\n        and ID mappings.\n        \"\"\"\n        if not self.sampled_subgraphs:\n            return None\n\n        if self._blocks is None:\n            self._blocks = self.compute_blocks()\n        return self._blocks\n\n    def compute_blocks(self) -> list:\n        \"\"\"Extracts DGL blocks from `MiniBatch` to construct graphical\n        structures and ID mappings.\n        \"\"\"\n        from dgl.convert import create_block, EID, NID\n\n        is_heterogeneous = isinstance(\n            self.sampled_subgraphs[0].sampled_csc, Dict\n        )\n\n        # Casts to minimum 
dtype in-place and returns self.\n        def cast_to_minimum_dtype(v: CSCFormatBase):\n            # Checks if number of vertices and edges fit into an int32.\n            dtype = (\n                torch.int32\n                if max(v.indptr.size(0) - 2, v.indices.size(0))\n                <= torch.iinfo(torch.int32).max\n                else torch.int64\n            )\n            v.indptr = v.indptr.to(dtype)\n            v.indices = v.indices.to(dtype)\n            return v\n\n        blocks = []\n        for subgraph in self.sampled_subgraphs:\n            original_row_node_ids = subgraph.original_row_node_ids\n            assert (\n                original_row_node_ids is not None\n            ), \"Missing `original_row_node_ids` in sampled subgraph.\"\n            original_column_node_ids = subgraph.original_column_node_ids\n            assert (\n                original_column_node_ids is not None\n            ), \"Missing `original_column_node_ids` in sampled subgraph.\"\n            if is_heterogeneous:\n                node_types = set()\n                sampled_csc = {}\n                for v in subgraph.sampled_csc.values():\n                    cast_to_minimum_dtype(v)\n                for etype, v in subgraph.sampled_csc.items():\n                    etype_tuple = etype_str_to_tuple(etype)\n                    node_types.add(etype_tuple[0])\n                    node_types.add(etype_tuple[2])\n                    sampled_csc[etype_tuple] = (\n                        \"csc\",\n                        (\n                            v.indptr,\n                            v.indices,\n                            torch.arange(\n                                0,\n                                len(v.indices),\n                                device=v.indptr.device,\n                                dtype=v.indptr.dtype,\n                            ),\n                        ),\n                    )\n                num_src_nodes = {\n                   
 ntype: (\n                        original_row_node_ids[ntype].size(0)\n                        if original_row_node_ids.get(ntype) is not None\n                        else 0\n                    )\n                    for ntype in node_types\n                }\n                num_dst_nodes = {\n                    ntype: (\n                        original_column_node_ids[ntype].size(0)\n                        if original_column_node_ids.get(ntype) is not None\n                        else 0\n                    )\n                    for ntype in node_types\n                }\n            else:\n                sampled_csc = cast_to_minimum_dtype(subgraph.sampled_csc)\n                sampled_csc = (\n                    \"csc\",\n                    (\n                        sampled_csc.indptr,\n                        sampled_csc.indices,\n                        torch.arange(\n                            0,\n                            len(sampled_csc.indices),\n                            device=sampled_csc.indptr.device,\n                            dtype=sampled_csc.indptr.dtype,\n                        ),\n                    ),\n                )\n                num_src_nodes = original_row_node_ids.size(0)\n                num_dst_nodes = original_column_node_ids.size(0)\n            blocks.append(\n                create_block(\n                    sampled_csc,\n                    num_src_nodes=num_src_nodes,\n                    num_dst_nodes=num_dst_nodes,\n                    node_count_check=False,\n                )\n            )\n\n        if is_heterogeneous:\n            # Assign reverse node ids to the outermost layer's source nodes.\n            for node_type, reverse_ids in self.sampled_subgraphs[\n                0\n            ].original_row_node_ids.items():\n                blocks[0].srcnodes[node_type].data[NID] = reverse_ids\n            # Assign reverse edges ids.\n            for block, subgraph in zip(blocks, 
self.sampled_subgraphs):\n                if subgraph.original_edge_ids is not None:\n                    for (\n                        edge_type,\n                        reverse_ids,\n                    ) in subgraph.original_edge_ids.items():\n                        block.edges[etype_str_to_tuple(edge_type)].data[\n                            EID\n                        ] = reverse_ids\n        else:\n            blocks[0].srcdata[NID] = self.sampled_subgraphs[\n                0\n            ].original_row_node_ids\n            # Assign reverse edges ids.\n            for block, subgraph in zip(blocks, self.sampled_subgraphs):\n                if subgraph.original_edge_ids is not None:\n                    block.edata[EID] = subgraph.original_edge_ids\n        return blocks\n\n    def to_pyg_data(self):\n        \"\"\"Construct a PyG Data from `MiniBatch`. This function only supports\n        node classification task on a homogeneous graph and the number of\n        features cannot be more than one.\n        \"\"\"\n        from torch_geometric.data import Data\n\n        if self.sampled_subgraphs is None:\n            edge_index = None\n        else:\n            col_nodes = []\n            row_nodes = []\n            for subgraph in self.sampled_subgraphs:\n                if subgraph is None:\n                    continue\n                sampled_csc = subgraph.sampled_csc\n                indptr = sampled_csc.indptr\n                indices = sampled_csc.indices\n                expanded_indptr = expand_indptr(\n                    indptr, dtype=indices.dtype, output_size=len(indices)\n                )\n                col_nodes.append(expanded_indptr)\n                row_nodes.append(indices)\n            col_nodes = torch.cat(col_nodes)\n            row_nodes = torch.cat(row_nodes)\n            edge_index = torch.unique(\n                torch.stack((row_nodes, col_nodes)), dim=1\n            ).long()\n\n        if self.node_features is None:\n      
      node_features = None\n        else:\n            assert (\n                len(self.node_features) == 1\n            ), \"`to_pyg_data` only supports single feature homogeneous graph.\"\n            node_features = next(iter(self.node_features.values()))\n\n        if self.seeds is not None:\n            if isinstance(self.seeds, Dict):\n                batch_size = len(next(iter(self.seeds.values())))\n            else:\n                batch_size = len(self.seeds)\n        else:\n            batch_size = None\n        pyg_data = Data(\n            x=node_features,\n            edge_index=edge_index,\n            y=self.labels,\n            batch_size=batch_size,\n            n_id=self.node_ids(),\n        )\n        return pyg_data\n\n    def to(\n        self, device: torch.device, non_blocking=False\n    ):  # pylint: disable=invalid-name\n        \"\"\"Copy `MiniBatch` to the specified device using reflection.\"\"\"\n\n        copy_fn = lambda x: apply_to(x, device, non_blocking=non_blocking)\n\n        transfer_attrs = get_nonproperty_attributes(self)\n\n        for attr in transfer_attrs:\n            # Only copy member variables.\n            setattr(self, attr, recursive_apply(getattr(self, attr), copy_fn))\n\n        return self\n\n    def pin_memory(self):\n        \"\"\"Copy `MiniBatch` to the pinned memory using reflection.\"\"\"\n\n        return self.to(\"pinned\")\n\n    def is_pinned(self) -> bool:\n        \"\"\"Check whether `SampledSubgraph` is pinned using reflection.\"\"\"\n\n        return is_object_pinned(self)\n\n\ndef _minibatch_str(minibatch: MiniBatch) -> str:\n    final_str = \"\"\n    # Get all attributes in the class except methods.\n    attributes = get_attributes(minibatch)\n    attributes.reverse()\n    # Insert key with its value into the string.\n    for name in attributes:\n        if name[0] == \"_\":\n            continue\n        val = getattr(minibatch, name)\n\n        def _add_indent(_str, indent):\n            lines 
= _str.split(\"\\n\")\n            lines = [lines[0]] + [\n                \" \" * (indent + 10) + line for line in lines[1:]\n            ]\n            return \"\\n\".join(lines)\n\n        # Let the variables in the list occupy one line each, and adjust the\n        # indentation on top of the original if the original data output has\n        # line feeds.\n        if isinstance(val, list):\n            val = [str(val_str) for val_str in val]\n            val = \"[\" + \",\\n\".join(val) + \"]\"\n        elif isinstance(val, tuple):\n            val = [str(val_str) for val_str in val]\n            val = \"(\" + \",\\n\".join(val) + \")\"\n        else:\n            val = str(val)\n        final_str = (\n            final_str + f\"{name}={_add_indent(val, len(name)+1)},\\n\" + \" \" * 10\n        )\n    return \"MiniBatch(\" + final_str[:-3] + \")\"\n"
  },
  {
    "path": "python/dgl/graphbolt/minibatch_transformer.py",
    "content": "\"\"\"Mini-batch transformer\"\"\"\n\nfrom torch.utils.data import functional_datapipe\n\nfrom torch.utils.data.datapipes.iter import Mapper\n\nfrom .minibatch import MiniBatch\n\n__all__ = [\n    \"MiniBatchTransformer\",\n]\n\n\n@functional_datapipe(\"transform\")\nclass MiniBatchTransformer(Mapper):\n    \"\"\"A mini-batch transformer used to manipulate mini-batch.\n\n    Functional name: :obj:`transform`.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The datapipe.\n    transformer:\n        The function applied to each minibatch which is responsible for\n        transforming the minibatch.\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        transformer=None,\n    ):\n        super().__init__(datapipe, self._transformer)\n        self.transformer = transformer or self._identity\n\n    def _transformer(self, minibatch):\n        minibatch = self.transformer(minibatch)\n        assert isinstance(\n            minibatch, (MiniBatch,)\n        ), \"The transformer output should be an instance of MiniBatch\"\n        return minibatch\n\n    @staticmethod\n    def _identity(minibatch):\n        return minibatch\n"
  },
  {
    "path": "python/dgl/graphbolt/negative_sampler.py",
    "content": "\"\"\"Negative samplers.\"\"\"\n\nfrom _collections_abc import Mapping\n\nfrom torch.utils.data import functional_datapipe\n\nfrom .minibatch_transformer import MiniBatchTransformer\n\n__all__ = [\n    \"NegativeSampler\",\n]\n\n\n@functional_datapipe(\"sample_negative\")\nclass NegativeSampler(MiniBatchTransformer):\n    \"\"\"\n    A negative sampler used to generate negative samples and return\n    a mix of positive and negative samples.\n\n    Functional name: :obj:`sample_negative`.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The datapipe.\n    negative_ratio : int\n        The proportion of negative samples to positive samples.\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        negative_ratio,\n    ):\n        super().__init__(datapipe, self._sample)\n        assert negative_ratio > 0, \"Negative_ratio should be positive Integer.\"\n        self.negative_ratio = negative_ratio\n\n    def _sample(self, minibatch):\n        \"\"\"\n        Generate a mix of positive and negative samples. If `seeds` in\n        minibatch is not None, `labels` and `indexes` will be constructed\n        after negative sampling, based on corresponding seeds.\n\n        Parameters\n        ----------\n        minibatch : MiniBatch\n            An instance of 'MiniBatch' class requires the 'seeds' field. 
This\n            function is responsible for generating negative edges corresponding\n            to the positive edges defined by the 'seeds'.\n\n        Returns\n        -------\n        MiniBatch\n            An instance of 'MiniBatch' encompasses both positive and negative\n            samples.\n        \"\"\"\n        seeds = minibatch.seeds\n        if isinstance(seeds, Mapping):\n            if minibatch.indexes is None:\n                minibatch.indexes = {}\n            if minibatch.labels is None:\n                minibatch.labels = {}\n            for etype, pos_pairs in seeds.items():\n                (\n                    minibatch.seeds[etype],\n                    minibatch.labels[etype],\n                    minibatch.indexes[etype],\n                ) = self._sample_with_etype(pos_pairs, etype)\n        else:\n            (\n                minibatch.seeds,\n                minibatch.labels,\n                minibatch.indexes,\n            ) = self._sample_with_etype(seeds)\n        return minibatch\n\n    def _sample_with_etype(self, seeds, etype=None):\n        \"\"\"Generate negative pairs for a given etype from positive pairs\n        for a given etype. If `seeds` is a 2D tensor, which represents\n        `seeds` is used in minibatch, corresponding labels and indexes will be\n        constructed.\n\n        Parameters\n        ----------\n        seeds : Tensor, Tensor\n            A N*2 tensors that represent source-destination node pairs of\n            positive edges, where positive means the edge must exist in the\n            graph.\n        etype : str\n            Canonical edge type.\n\n        Returns\n        -------\n        Tensor\n            A collection of positive and negative node pairs.\n        Tensor\n            Corresponding labels. If label is True, corresponding edge is\n            positive. 
If label is False, corresponding edge is negative.\n        Tensor\n            Corresponding indexes, indicates to which query an edge belongs.\n\n        \"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "python/dgl/graphbolt/sampled_subgraph.py",
    "content": "\"\"\"Graphbolt sampled subgraph.\"\"\"\n\n# pylint: disable= invalid-name\nfrom typing import Dict, NamedTuple, Tuple, Union\n\nimport torch\n\nfrom .base import (\n    apply_to,\n    CSCFormatBase,\n    etype_str_to_tuple,\n    expand_indptr,\n    is_object_pinned,\n    isin,\n)\n\nfrom .internal_utils import recursive_apply\n\n\n__all__ = [\"SampledSubgraph\"]\n\n\nclass _ExcludeEdgesWaiter:\n    def __init__(self, sampled_subgraph, index):\n        self.sampled_subgraph = sampled_subgraph\n        self.index = index\n\n    def wait(self):\n        \"\"\"Returns the stored value when invoked.\"\"\"\n        sampled_subgraph = self.sampled_subgraph\n        index = self.index\n        # Ensure there is no memory leak.\n        self.sampled_subgraph = self.index = None\n\n        if isinstance(index, dict):\n            for k in list(index.keys()):\n                index[k] = index[k].wait()\n        else:\n            index = index.wait()\n\n        return type(sampled_subgraph)(*_slice_subgraph(sampled_subgraph, index))\n\n\nclass PyGLayerData(NamedTuple):\n    \"\"\"A named tuple class to represent homogenous inputs to a PyG model layer.\n    The fields are x (input features), edge_index and size\n    (source and destination sizes).\n    \"\"\"\n\n    x: torch.Tensor\n    edge_index: torch.Tensor\n    size: Tuple[int, int]\n\n\nclass PyGLayerHeteroData(NamedTuple):\n    \"\"\"A named tuple class to represent heterogenous inputs to a PyG model\n    layer. The fields are x (input features), edge_index and size\n    (source and destination sizes), and all fields are dictionaries.\n    \"\"\"\n\n    x: Dict[str, torch.Tensor]\n    edge_index: Dict[str, torch.Tensor]\n    size: Dict[str, Tuple[int, int]]\n\n\nclass SampledSubgraph:\n    r\"\"\"An abstract class for sampled subgraph. In the context of a\n    heterogeneous graph, each field should be of `Dict` type. 
Otherwise,\n    for homogeneous graphs, each field should correspond to its respective\n    value type.\"\"\"\n\n    @property\n    def sampled_csc(\n        self,\n    ) -> Union[CSCFormatBase, Dict[str, CSCFormatBase],]:\n        \"\"\"Returns the node pairs representing edges in csc format.\n          - If `sampled_csc` is a CSCFormatBase: It should be in the csc\n            format. `indptr` stores the index in the data array where each\n            column starts. `indices` stores the row indices of the non-zero\n            elements.\n          - If `sampled_csc` is a dictionary: The keys should be edge type and\n            the values should be corresponding node pairs. The ids inside is\n            heterogeneous ids.\n\n        Examples\n        --------\n        1. Homogeneous graph.\n\n        >>> import dgl.graphbolt as gb\n        >>> import torch\n        >>> sampled_csc = gb.CSCFormatBase(\n        ...     indptr=torch.tensor([0, 1, 2, 3]),\n        ...     indices=torch.tensor([0, 1, 2]))\n        >>> print(sampled_csc)\n        CSCFormatBase(indptr=tensor([0, 1, 2, 3]),\n                    indices=tensor([0, 1, 2]),\n        )\n\n        2. Heterogeneous graph.\n\n        >>> sampled_csc = {\"A:relation:B\": gb.CSCFormatBase(\n        ...     indptr=torch.tensor([0, 1, 2, 3]),\n        ...     indices=torch.tensor([0, 1, 2]))}\n        >>> print(sampled_csc)\n        {'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),\n                    indices=tensor([0, 1, 2]),\n        )}\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def original_column_node_ids(\n        self,\n    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:\n        \"\"\"Returns corresponding reverse column node ids the original graph.\n        Column's reverse node ids in the original graph. 
A graph structure\n        can be treated as a coordinated row and column pair, and this is\n        the mapped ids of the column.\n          - If `original_column_node_ids` is a tensor: It represents the\n            original node ids.\n          - If `original_column_node_ids` is a dictionary: The keys should be\n            node type and the values should be corresponding original\n            heterogeneous node ids.\n        If present, it means column IDs are compacted, and `sampled_csc`\n        column IDs match these compacted ones.\n        \"\"\"\n        return None\n\n    @property\n    def original_row_node_ids(\n        self,\n    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:\n        \"\"\"Returns corresponding reverse row node ids in the original graph.\n        Row's reverse node ids in the original graph. A graph structure\n        can be treated as a coordinated row and column pair, and this is\n        the mapped ids of the row.\n          - If `original_row_node_ids` is a tensor: It represents the original\n            node ids.\n          - If `original_row_node_ids` is a dictionary: The keys should be node\n            type and the values should be corresponding original heterogeneous\n            node ids.\n        If present, it means row IDs are compacted, and `sampled_csc`\n        row IDs match these compacted ones.\"\"\"\n        return None\n\n    @property\n    def original_edge_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:\n        \"\"\"Returns corresponding reverse edge ids in the original graph.\n        Reverse edge ids in the original graph. 
This is useful when edge\n        features are needed.\n          - If `original_edge_ids` is a tensor: It represents the original edge\n            ids.\n          - If `original_edge_ids` is a dictionary: The keys should be edge\n            type and the values should be corresponding original heterogeneous\n            edge ids.\n        \"\"\"\n        return None\n\n    def exclude_edges(\n        self,\n        edges: Union[\n            Dict[str, torch.Tensor],\n            torch.Tensor,\n        ],\n        assume_num_node_within_int32: bool = True,\n        async_op: bool = False,\n    ):\n        r\"\"\"Exclude edges from the sampled subgraph.\n\n        This function can be used with sampled subgraphs, regardless of\n        whether they have compacted row/column nodes or not. If the original\n        subgraph has compacted row or column nodes, the corresponding row or\n        column nodes in the returned subgraph will also be compacted.\n\n        Parameters\n        ----------\n        self : SampledSubgraph\n            The sampled subgraph.\n        edges : Union[torch.Tensor, Dict[str, torch.Tensor]]\n            Edges to exclude. If sampled subgraph is homogeneous, then `edges`\n            should be a N*2 tensors representing the edges to exclude. If\n            sampled subgraph is heterogeneous, then `edges` should be a\n            dictionary of edge types and the corresponding edges to exclude.\n        assume_num_node_within_int32: bool\n            If True, assumes the value of node IDs in the provided `edges` fall\n            within the int32 range, which can significantly enhance computation\n            speed. Default: True\n        async_op: bool\n            Boolean indicating whether the call is asynchronous. 
If so, the\n            result can be obtained by calling wait on the returned future.\n\n        Returns\n        -------\n        SampledSubgraph\n            An instance of a class that inherits from `SampledSubgraph`.\n\n        Examples\n        --------\n        >>> import dgl.graphbolt as gb\n        >>> import torch\n        >>> sampled_csc = {\"A:relation:B\": gb.CSCFormatBase(\n        ...     indptr=torch.tensor([0, 1, 2, 3]),\n        ...     indices=torch.tensor([0, 1, 2]))}\n        >>> original_column_node_ids = {\"B\": torch.tensor([10, 11, 12])}\n        >>> original_row_node_ids = {\"A\": torch.tensor([13, 14, 15])}\n        >>> original_edge_ids = {\"A:relation:B\": torch.tensor([19, 20, 21])}\n        >>> subgraph = gb.SampledSubgraphImpl(\n        ...     sampled_csc=sampled_csc,\n        ...     original_column_node_ids=original_column_node_ids,\n        ...     original_row_node_ids=original_row_node_ids,\n        ...     original_edge_ids=original_edge_ids\n        ... 
)\n        >>> edges_to_exclude = {\"A:relation:B\": torch.tensor([[14, 11], [15, 12]])}\n        >>> result = subgraph.exclude_edges(edges_to_exclude)\n        >>> print(result.sampled_csc)\n        {'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]),\n                    indices=tensor([0]),\n        )}\n        >>> print(result.original_column_node_ids)\n        {'B': tensor([10, 11, 12])}\n        >>> print(result.original_row_node_ids)\n        {'A': tensor([13, 14, 15])}\n        >>> print(result.original_edge_ids)\n        {'A:relation:B': tensor([19])}\n        \"\"\"\n        # TODO: Add support for value > int32, then remove this line.\n        assert (\n            assume_num_node_within_int32\n        ), \"Values > int32 are not supported yet.\"\n        assert (isinstance(self.sampled_csc, CSCFormatBase)) == isinstance(\n            edges, torch.Tensor\n        ), (\n            \"The sampled subgraph and the edges to exclude should be both \"\n            \"homogeneous or both heterogeneous.\"\n        )\n        # Get type of calling class.\n        calling_class = type(self)\n\n        # Three steps to exclude edges:\n        # 1. Convert the node pairs to the original ids if they are compacted.\n        # 2. Exclude the edges and get the index of the edges to keep.\n        # 3. 
Slice the subgraph according to the index.\n        if isinstance(self.sampled_csc, CSCFormatBase):\n            reverse_edges = _to_reverse_ids(\n                self.sampled_csc,\n                self.original_row_node_ids,\n                self.original_column_node_ids,\n            )\n            index = _exclude_homo_edges(\n                reverse_edges, edges, assume_num_node_within_int32, async_op\n            )\n        else:\n            index = {}\n            for etype, pair in self.sampled_csc.items():\n                if etype not in edges:\n                    # No edges need to be excluded.\n                    index[etype] = None\n                    continue\n                src_type, _, dst_type = etype_str_to_tuple(etype)\n                original_row_node_ids = (\n                    None\n                    if self.original_row_node_ids is None\n                    else self.original_row_node_ids.get(src_type)\n                )\n                original_column_node_ids = (\n                    None\n                    if self.original_column_node_ids is None\n                    else self.original_column_node_ids.get(dst_type)\n                )\n                reverse_edges = _to_reverse_ids(\n                    pair,\n                    original_row_node_ids,\n                    original_column_node_ids,\n                )\n                index[etype] = _exclude_homo_edges(\n                    reverse_edges,\n                    edges[etype],\n                    assume_num_node_within_int32,\n                    async_op,\n                )\n        if async_op:\n            return _ExcludeEdgesWaiter(self, index)\n        else:\n            return calling_class(*_slice_subgraph(self, index))\n\n    def to_pyg(\n        self, x: Union[torch.Tensor, Dict[str, torch.Tensor]]\n    ) -> Union[PyGLayerData, PyGLayerHeteroData]:\n        \"\"\"\n        Process layer inputs so that they can be consumed by a PyG model layer.\n\n        
Parameters\n        ----------\n        x : Union[torch.Tensor, Dict[str, torch.Tensor]]\n            The input node features to the GNN layer.\n\n        Returns\n        -------\n        Union[PyGLayerData, PyGLayerHeteroData]\n            A named tuple class with `x`, `edge_index` and `size` fields.\n            Typically, a PyG GNN layer's forward method will accept these as\n            arguments.\n        \"\"\"\n        if isinstance(x, torch.Tensor):\n            # Homogenous\n            src = self.sampled_csc.indices\n            dst = expand_indptr(\n                self.sampled_csc.indptr,\n                dtype=src.dtype,\n                output_size=src.size(0),\n            )\n            edge_index = torch.stack([src, dst], dim=0).long()\n            dst_size = self.sampled_csc.indptr.size(0) - 1\n            # h and h[:dst_size] correspond to source and destination features resp.\n            return PyGLayerData(\n                (x, x[:dst_size]), edge_index, (x.size(0), dst_size)\n            )\n        else:\n            # Heterogenous\n            x_dst_dict = {}\n            edge_index_dict = {}\n            sizes_dict = {}\n            for etype, sampled_csc in self.sampled_csc.items():\n                src = sampled_csc.indices\n                dst = expand_indptr(\n                    sampled_csc.indptr,\n                    dtype=src.dtype,\n                    output_size=src.size(0),\n                )\n                edge_index = torch.stack([src, dst], dim=0).long()\n                dst_size = sampled_csc.indptr.size(0) - 1\n                # h and h[:dst_size] correspond to source and destination features resp.\n                src_ntype, _, dst_ntype = etype_str_to_tuple(etype)\n                x_dst_dict[dst_ntype] = x[dst_ntype][:dst_size]\n                edge_index_dict[etype] = edge_index\n                sizes_dict[etype] = (x[src_ntype].size(0), dst_size)\n\n            return PyGLayerHeteroData(\n                (x, 
x_dst_dict), edge_index_dict, sizes_dict\n            )\n\n    def to(\n        self, device: torch.device, non_blocking=False\n    ) -> None:  # pylint: disable=invalid-name\n        \"\"\"Copy `SampledSubgraph` to the specified device using reflection.\"\"\"\n\n        for attr in dir(self):\n            # Only copy member variables.\n            if not callable(getattr(self, attr)) and not attr.startswith(\"__\"):\n                setattr(\n                    self,\n                    attr,\n                    recursive_apply(\n                        getattr(self, attr),\n                        apply_to,\n                        device,\n                        non_blocking=non_blocking,\n                    ),\n                )\n\n        return self\n\n    def pin_memory(self):\n        \"\"\"Copy `SampledSubgraph` to the pinned memory using reflection.\"\"\"\n\n        return self.to(\"pinned\")\n\n    def is_pinned(self) -> bool:\n        \"\"\"Check whether `SampledSubgraph` is pinned using reflection.\"\"\"\n\n        return is_object_pinned(self)\n\n\ndef _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):\n    indptr = node_pair.indptr\n    indices = node_pair.indices\n    if original_row_node_ids is not None:\n        indices = torch.index_select(\n            original_row_node_ids, dim=0, index=indices\n        )\n    indptr = expand_indptr(\n        indptr, indices.dtype, original_column_node_ids, len(indices)\n    )\n    return (indices, indptr)\n\n\ndef _relabel_two_arrays(lhs_array, rhs_array):\n    \"\"\"Relabel two arrays into a consecutive range starting from 0.\"\"\"\n    concated = torch.cat([lhs_array, rhs_array])\n    _, mapping = torch.unique(concated, return_inverse=True)\n    return mapping[: lhs_array.numel()], mapping[lhs_array.numel() :]\n\n\ndef _exclude_homo_edges(\n    edges: Tuple[torch.Tensor, torch.Tensor],\n    edges_to_exclude: torch.Tensor,\n    assume_num_node_within_int32: bool,\n    async_op: 
bool,\n):\n    \"\"\"Return the indices of edges to be included.\"\"\"\n    if assume_num_node_within_int32:\n        val = edges[0].long() << 32 | edges[1].long()\n        edges_to_exclude_trans = edges_to_exclude.T\n        val_to_exclude = (\n            edges_to_exclude_trans[0].long() << 32\n            | edges_to_exclude_trans[1].long()\n        )\n    else:\n        # TODO: Add support for value > int32.\n        raise NotImplementedError(\n            \"Values out of range int32 are not supported yet\"\n        )\n    if async_op:\n        return torch.ops.graphbolt.is_not_in_index_async(val, val_to_exclude)\n    else:\n        mask = ~isin(val, val_to_exclude)\n        return torch.nonzero(mask, as_tuple=True)[0]\n\n\ndef _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):\n    \"\"\"Slice the subgraph according to the index.\"\"\"\n\n    def _index_select(obj, index):\n        if obj is None:\n            return None\n        if index is None:\n            return obj\n        if isinstance(obj, CSCFormatBase):\n            new_indices = obj.indices[index]\n            new_indptr = torch.searchsorted(index, obj.indptr)\n            return CSCFormatBase(\n                indptr=new_indptr,\n                indices=new_indices,\n            )\n        if isinstance(obj, torch.Tensor):\n            return obj[index]\n        # Handle the case when obj is a dictionary.\n        assert isinstance(obj, dict)\n        assert isinstance(index, dict)\n        ret = {}\n        for k, v in obj.items():\n            ret[k] = _index_select(v, index[k])\n        return ret\n\n    return (\n        _index_select(subgraph.sampled_csc, index),\n        subgraph.original_column_node_ids,\n        subgraph.original_row_node_ids,\n        _index_select(subgraph.original_edge_ids, index),\n    )\n"
  },
  {
    "path": "python/dgl/graphbolt/sampling_graph.py",
    "content": "\"\"\"Sampling Graphs.\"\"\"\n\nfrom typing import Dict, Union\n\nimport torch\n\n\n__all__ = [\"SamplingGraph\"]\n\n\nclass SamplingGraph:\n    r\"\"\"Class for sampling graph.\"\"\"\n\n    def __init__(self):\n        pass\n\n    def __repr__(self) -> str:\n        \"\"\"Return a string representation of the graph.\n\n        Returns\n        -------\n        str\n            String representation of the graph.\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def num_nodes(self) -> Union[int, Dict[str, int]]:\n        \"\"\"The number of nodes in the graph.\n        - If the graph is homogenous, returns an integer.\n        - If the graph is heterogenous, returns a dictionary.\n\n        Returns\n        -------\n        Union[int, Dict[str, int]]\n            The number of nodes. Integer indicates the total nodes number of a\n            homogenous graph; dict indicates nodes number per node types of a\n            heterogenous graph.\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def num_edges(self) -> Union[int, Dict[str, int]]:\n        \"\"\"The number of edges in the graph.\n        - If the graph is homogenous, returns an integer.\n        - If the graph is heterogenous, returns a dictionary.\n\n        Returns\n        -------\n        Union[int, Dict[str, int]]\n            The number of edges. 
Integer indicates the total edges number of a\n            homogenous graph; dict indicates edges number per edge types of a\n            heterogenous graph.\n        \"\"\"\n        raise NotImplementedError\n\n    def copy_to_shared_memory(self, shared_memory_name: str) -> \"SamplingGraph\":\n        \"\"\"Copy the graph to shared memory.\n\n        Parameters\n        ----------\n        shared_memory_name : str\n            Name of the shared memory.\n\n        Returns\n        -------\n        SamplingGraph\n            The copied SamplingGraph object on shared memory.\n        \"\"\"\n        raise NotImplementedError\n\n    # pylint: disable=invalid-name\n    def to(self, device: torch.device) -> \"SamplingGraph\":\n        \"\"\"Copy graph to the specified device.\n\n        Parameters\n        ----------\n        device : torch.device\n            The destination device.\n\n        Returns\n        -------\n        SamplingGraph\n            The graph on the specified device.\n        \"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "python/dgl/graphbolt/subgraph_sampler.py",
    "content": "\"\"\"Subgraph samplers\"\"\"\n\nfrom collections import defaultdict\nfrom functools import partial\nfrom typing import Dict\n\nimport torch\nimport torch.distributed as thd\nfrom torch.utils.data import functional_datapipe\n\nfrom .base import seed_type_str_to_ntypes\nfrom .internal import compact_temporal_nodes, unique_and_compact\nfrom .minibatch import MiniBatch\nfrom .minibatch_transformer import MiniBatchTransformer\n\n__all__ = [\n    \"SubgraphSampler\",\n    \"all_to_all\",\n    \"convert_to_hetero\",\n    \"revert_to_homo\",\n]\n\n\nclass _NoOpWaiter:\n    def __init__(self, result):\n        self.result = result\n\n    def wait(self):\n        \"\"\"Returns the stored value when invoked.\"\"\"\n        result = self.result\n        # Ensure there is no memory leak.\n        self.result = None\n        return result\n\n\ndef _shift(inputs: list, group=None):\n    cutoff = len(inputs) - thd.get_rank(group)\n    return inputs[cutoff:] + inputs[:cutoff]\n\n\ndef all_to_all(outputs, inputs, group=None, async_op=False):\n    \"\"\"Wrapper for thd.all_to_all that permuted outputs and inputs before\n    calling it. 
The arguments have the permutation\n    `rank, ..., world_size - 1, 0, ..., rank - 1` and we make it\n    `0, ..., world_size - 1` before calling `thd.all_to_all`.\"\"\"\n    shift_fn = partial(_shift, group=group)\n    outputs = shift_fn(list(outputs))\n    inputs = shift_fn(list(inputs))\n    if outputs[0].is_cuda:\n        return thd.all_to_all(outputs, inputs, group, async_op)\n    # gloo backend will be used.\n    outputs_single = torch.cat(outputs)\n    output_split_sizes = [o.size(0) for o in outputs]\n    handle = thd.all_to_all_single(\n        outputs_single,\n        torch.cat(inputs),\n        output_split_sizes,\n        [i.size(0) for i in inputs],\n        group,\n        async_op,\n    )\n    temp_outputs = outputs_single.split(output_split_sizes)\n\n    class _Waiter:\n        def __init__(self, handle, outputs, temp_outputs):\n            self.handle = handle\n            self.outputs = outputs\n            self.temp_outputs = temp_outputs\n\n        def wait(self):\n            \"\"\"Returns the stored value when invoked.\"\"\"\n            handle = self.handle\n            outputs = self.outputs\n            temp_outputs = self.temp_outputs\n            # Ensure that there is no leak\n            self.handle = self.outputs = self.temp_outputs = None\n\n            if handle is not None:\n                handle.wait()\n            for output, temp_output in zip(outputs, temp_outputs):\n                output.copy_(temp_output)\n\n    post_processor = _Waiter(handle, outputs, temp_outputs)\n    return post_processor if async_op else post_processor.wait()\n\n\ndef revert_to_homo(d: dict):\n    \"\"\"Utility function to convert a dictionary that stores homogenous data.\"\"\"\n    is_homogenous = len(d) == 1 and \"_N\" in d\n    return list(d.values())[0] if is_homogenous else d\n\n\ndef convert_to_hetero(item):\n    \"\"\"Utility function to convert homogenous data to heterogenous with a single\n    node type.\"\"\"\n    is_heterogenous = 
isinstance(item, dict)\n    return item if is_heterogenous else {\"_N\": item}\n\n\n@functional_datapipe(\"sample_subgraph\")\nclass SubgraphSampler(MiniBatchTransformer):\n    \"\"\"A subgraph sampler used to sample a subgraph from a given set of nodes\n    from a larger graph.\n\n    Functional name: :obj:`sample_subgraph`.\n\n    This class is the base class of all subgraph samplers. Any subclass of\n    SubgraphSampler should implement either the :meth:`sample_subgraphs` method\n    or the :meth:`sampling_stages` method to define the fine-grained sampling\n    stages to take advantage of optimizations provided by the GraphBolt\n    DataLoader.\n\n    Parameters\n    ----------\n    datapipe : DataPipe\n        The datapipe.\n    args : Non-Keyword Arguments\n        Arguments to be passed into sampling_stages.\n    kwargs : Keyword Arguments\n        Arguments to be passed into sampling_stages. Preprocessing stage makes\n        use of the `asynchronous` and `cooperative` parameters before they are\n        passed to the sampling stages.\n    \"\"\"\n\n    def __init__(\n        self,\n        datapipe,\n        *args,\n        **kwargs,\n    ):\n        async_op = kwargs.get(\"asynchronous\", False)\n        cooperative = kwargs.get(\"cooperative\", False)\n        preprocess_fn = partial(\n            self._preprocess, cooperative=cooperative, async_op=async_op\n        )\n        datapipe = datapipe.transform(preprocess_fn)\n        if async_op:\n            fn = partial(self._wait_preprocess_future, cooperative=cooperative)\n            datapipe = datapipe.buffer().transform(fn)\n        if cooperative:\n            datapipe = datapipe.transform(self._seeds_cooperative_exchange_1)\n            datapipe = datapipe.buffer()\n            datapipe = datapipe.transform(\n                self._seeds_cooperative_exchange_1_wait_future\n            ).buffer()\n            datapipe = datapipe.transform(self._seeds_cooperative_exchange_2)\n            datapipe = 
datapipe.buffer()\n            datapipe = datapipe.transform(self._seeds_cooperative_exchange_3)\n            datapipe = datapipe.buffer()\n            datapipe = datapipe.transform(self._seeds_cooperative_exchange_4)\n        datapipe = self.sampling_stages(datapipe, *args, **kwargs)\n        datapipe = datapipe.transform(self._postprocess)\n        super().__init__(datapipe)\n\n    @staticmethod\n    def _postprocess(minibatch):\n        delattr(minibatch, \"_seed_nodes\")\n        delattr(minibatch, \"_seeds_timestamp\")\n        return minibatch\n\n    @staticmethod\n    def _preprocess(minibatch, cooperative: bool, async_op: bool):\n        if minibatch.seeds is None:\n            raise ValueError(\n                f\"Invalid minibatch {minibatch}: `seeds` should have a value.\"\n            )\n        rank = thd.get_rank() if cooperative else 0\n        world_size = thd.get_world_size() if cooperative else 1\n        results = SubgraphSampler._seeds_preprocess(\n            minibatch, rank, world_size, async_op\n        )\n        if async_op:\n            minibatch._preprocess_future = results\n        else:\n            (\n                minibatch._seed_nodes,\n                minibatch._seeds_timestamp,\n                minibatch.compacted_seeds,\n                offsets,\n            ) = results\n            if cooperative:\n                minibatch._seeds_offsets = offsets\n        return minibatch\n\n    @staticmethod\n    def _wait_preprocess_future(minibatch, cooperative: bool):\n        (\n            minibatch._seed_nodes,\n            minibatch._seeds_timestamp,\n            minibatch.compacted_seeds,\n            offsets,\n        ) = minibatch._preprocess_future.wait()\n        delattr(minibatch, \"_preprocess_future\")\n        if cooperative:\n            minibatch._seeds_offsets = offsets\n        return minibatch\n\n    @staticmethod\n    def _seeds_cooperative_exchange_1(minibatch):\n        rank = thd.get_rank()\n        world_size = 
thd.get_world_size()\n        seeds = minibatch._seed_nodes\n        is_homogeneous = not isinstance(seeds, dict)\n        if is_homogeneous:\n            seeds = {\"_N\": seeds}\n        if minibatch._seeds_offsets is None:\n            assert minibatch.compacted_seeds is None\n            minibatch._rank_sort_future = torch.ops.graphbolt.rank_sort_async(\n                list(seeds.values()), rank, world_size\n            )\n        return minibatch\n\n    @staticmethod\n    def _seeds_cooperative_exchange_1_wait_future(minibatch):\n        world_size = thd.get_world_size()\n        seeds = minibatch._seed_nodes\n        is_homogeneous = not isinstance(seeds, dict)\n        if is_homogeneous:\n            seeds = {\"_N\": seeds}\n        num_ntypes = len(seeds.keys())\n        if minibatch._seeds_offsets is None:\n            result = minibatch._rank_sort_future.wait()\n            delattr(minibatch, \"_rank_sort_future\")\n            sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {}\n            for i, (\n                seed_type,\n                (typed_sorted_seeds, typed_index, typed_offsets),\n            ) in enumerate(zip(seeds.keys(), result)):\n                sorted_seeds[seed_type] = typed_sorted_seeds\n                sorted_compacted[seed_type] = typed_index\n                sorted_offsets[seed_type] = typed_offsets\n\n            minibatch._seed_nodes = sorted_seeds\n            minibatch.compacted_seeds = revert_to_homo(sorted_compacted)\n            minibatch._seeds_offsets = sorted_offsets\n        else:\n            minibatch._seeds_offsets = {\"_N\": minibatch._seeds_offsets}\n        counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)\n        for i, offsets in enumerate(minibatch._seeds_offsets.values()):\n            counts_sent[\n                torch.arange(i, world_size * num_ntypes, num_ntypes)\n            ] = offsets.diff()\n        delattr(minibatch, \"_seeds_offsets\")\n        counts_received = 
torch.empty_like(counts_sent)\n        minibatch._counts_future = all_to_all(\n            counts_received.split(num_ntypes),\n            counts_sent.split(num_ntypes),\n            async_op=True,\n        )\n        minibatch._counts_sent = counts_sent\n        minibatch._counts_received = counts_received\n        return minibatch\n\n    @staticmethod\n    def _seeds_cooperative_exchange_2(minibatch):\n        world_size = thd.get_world_size()\n        seeds = minibatch._seed_nodes\n        minibatch._counts_future.wait()\n        delattr(minibatch, \"_counts_future\")\n        num_ntypes = len(seeds.keys())\n        seeds_received = {}\n        counts_sent = {}\n        counts_received = {}\n        for i, (ntype, typed_seeds) in enumerate(seeds.items()):\n            idx = torch.arange(i, world_size * num_ntypes, num_ntypes)\n            typed_counts_sent = minibatch._counts_sent[idx].tolist()\n            typed_counts_received = minibatch._counts_received[idx].tolist()\n            typed_seeds_received = typed_seeds.new_empty(\n                sum(typed_counts_received)\n            )\n            all_to_all(\n                typed_seeds_received.split(typed_counts_received),\n                typed_seeds.split(typed_counts_sent),\n            )\n            seeds_received[ntype] = typed_seeds_received\n            counts_sent[ntype] = typed_counts_sent\n            counts_received[ntype] = typed_counts_received\n        minibatch._seed_nodes = seeds_received\n        minibatch._counts_sent = revert_to_homo(counts_sent)\n        minibatch._counts_received = revert_to_homo(counts_received)\n        return minibatch\n\n    @staticmethod\n    def _seeds_cooperative_exchange_3(minibatch):\n        nodes = {\n            ntype: [typed_seeds]\n            for ntype, typed_seeds in minibatch._seed_nodes.items()\n        }\n        minibatch._unique_future = unique_and_compact(\n            nodes, 0, 1, async_op=True\n        )\n        return minibatch\n\n    
@staticmethod\n    def _seeds_cooperative_exchange_4(minibatch):\n        unique_seeds, inverse_seeds, _ = minibatch._unique_future.wait()\n        delattr(minibatch, \"_unique_future\")\n        inverse_seeds = {\n            ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items()\n        }\n        minibatch._seed_nodes = revert_to_homo(unique_seeds)\n        sizes = {\n            ntype: typed_seeds.size(0)\n            for ntype, typed_seeds in unique_seeds.items()\n        }\n        minibatch._seed_sizes = revert_to_homo(sizes)\n        minibatch._seed_inverse_ids = revert_to_homo(inverse_seeds)\n        return minibatch\n\n    def _sample(self, minibatch):\n        (\n            minibatch.input_nodes,\n            minibatch.sampled_subgraphs,\n        ) = self.sample_subgraphs(\n            minibatch._seed_nodes, minibatch._seeds_timestamp\n        )\n        return minibatch\n\n    def sampling_stages(self, datapipe):\n        \"\"\"The sampling stages are defined here by chaining to the datapipe. The\n        default implementation expects :meth:`sample_subgraphs` to be\n        implemented. To define fine-grained stages, this method should be\n        overridden.\n        \"\"\"\n        return datapipe.transform(self._sample)\n\n    @staticmethod\n    def _seeds_preprocess(\n        minibatch: MiniBatch,\n        rank: int = 0,\n        world_size: int = 1,\n        async_op: bool = False,\n    ):\n        \"\"\"Preprocess `seeds` in a minibatch to construct `unique_seeds`,\n        `node_timestamp` and `compacted_seeds` for further sampling. It\n        optionally incorporates timestamps for temporal graphs, organizing and\n        compacting seeds based on their types and timestamps. 
In heterogeneous\n        graph, `seeds` with same node type will be unqiued together.\n\n        Parameters\n        ----------\n        minibatch: MiniBatch\n            The minibatch.\n        rank : int\n            The rank of the current process among cooperating processes.\n        world_size : int\n            The number of cooperating\n            (`arXiv:2210.13339<https://arxiv.org/abs/2310.12403>`__) processes.\n        async_op: bool\n            Boolean indicating whether the call is asynchronous. If so, the\n            result can be obtained by calling wait on the returned future.\n\n        Returns\n        -------\n        unique_seeds: torch.Tensor or Dict[str, torch.Tensor]\n            A tensor or a dictionary of tensors representing the unique seeds.\n            In heterogeneous graphs, seeds are returned for each node type.\n        nodes_timestamp: None or a torch.Tensor or Dict[str, torch.Tensor]\n            Containing timestamps for each seed. This is only returned if\n            `minibatch` includes timestamps and the graph is temporal.\n        compacted_seeds: torch.tensor or a Dict[str, torch.Tensor]\n            Representation of compacted seeds corresponding to 'seeds', where\n            all node ids inside are compacted.\n        offsets: None or torch.Tensor or Dict[src, torch.Tensor]\n            The unique nodes offsets tensor partitions the unique_nodes tensor.\n            Has size `world_size + 1` and\n            `unique_nodes[offsets[i]: offsets[i + 1]]` belongs to the rank\n            `(rank + i) % world_size`.\n        \"\"\"\n        use_timestamp = hasattr(minibatch, \"timestamp\")\n        assert (\n            not use_timestamp or world_size == 1\n        ), \"Temporal code path does not currently support Cooperative Minibatching\"\n        seeds = minibatch.seeds\n        is_heterogeneous = isinstance(seeds, Dict)\n        if is_heterogeneous:\n            # Collect nodes from all types of input.\n            
nodes = defaultdict(list)\n            nodes_timestamp = None\n            if use_timestamp:\n                nodes_timestamp = defaultdict(list)\n            for seed_type, typed_seeds in seeds.items():\n                # When typed_seeds is a one-dimensional tensor, it represents\n                # seed nodes, which does not need to do unique and compact.\n                if typed_seeds.ndim == 1:\n                    nodes_timestamp = (\n                        minibatch.timestamp\n                        if hasattr(minibatch, \"timestamp\")\n                        else None\n                    )\n                    result = _NoOpWaiter((seeds, nodes_timestamp, None, None))\n                    break\n                result = None\n                assert typed_seeds.ndim == 2, (\n                    \"Only tensor with shape 1*N and N*M is \"\n                    + f\"supported now, but got {typed_seeds.shape}.\"\n                )\n                ntypes = seed_type_str_to_ntypes(\n                    seed_type, typed_seeds.shape[1]\n                )\n                if use_timestamp:\n                    negative_ratio = (\n                        typed_seeds.shape[0]\n                        // minibatch.timestamp[seed_type].shape[0]\n                        - 1\n                    )\n                    neg_timestamp = minibatch.timestamp[\n                        seed_type\n                    ].repeat_interleave(negative_ratio)\n                for i, ntype in enumerate(ntypes):\n                    nodes[ntype].append(typed_seeds[:, i])\n                    if use_timestamp:\n                        nodes_timestamp[ntype].append(\n                            minibatch.timestamp[seed_type]\n                        )\n                        nodes_timestamp[ntype].append(neg_timestamp)\n\n            class _Waiter:\n                def __init__(self, nodes, nodes_timestamp, seeds):\n                    # Unique and compact the collected nodes.\n         
           if use_timestamp:\n                        self.future = compact_temporal_nodes(\n                            nodes, nodes_timestamp\n                        )\n                    else:\n                        self.future = unique_and_compact(\n                            nodes, rank, world_size, async_op\n                        )\n                    self.seeds = seeds\n\n                def wait(self):\n                    \"\"\"Returns the stored value when invoked.\"\"\"\n                    if use_timestamp:\n                        unique_seeds, nodes_timestamp, compacted = self.future\n                        offsets = None\n                    else:\n                        unique_seeds, compacted, offsets = (\n                            self.future.wait() if async_op else self.future\n                        )\n                        nodes_timestamp = None\n                    seeds = self.seeds\n                    # Ensure there is no memory leak.\n                    self.future = self.seeds = None\n\n                    compacted_seeds = {}\n                    # Map back in same order as collect.\n                    for seed_type, typed_seeds in seeds.items():\n                        ntypes = seed_type_str_to_ntypes(\n                            seed_type, typed_seeds.shape[1]\n                        )\n                        compacted_seed = []\n                        for ntype in ntypes:\n                            compacted_seed.append(compacted[ntype].pop(0))\n                        compacted_seeds[seed_type] = (\n                            torch.cat(compacted_seed).view(len(ntypes), -1).T\n                        )\n\n                    return (\n                        unique_seeds,\n                        nodes_timestamp,\n                        compacted_seeds,\n                        offsets,\n                    )\n\n            # When typed_seeds is not a one-dimensional tensor\n            if result is None:\n   
             result = _Waiter(nodes, nodes_timestamp, seeds)\n        else:\n            # When seeds is a one-dimensional tensor, it represents seed nodes,\n            # which does not need to do unique and compact.\n            if seeds.ndim == 1:\n                nodes_timestamp = (\n                    minibatch.timestamp\n                    if hasattr(minibatch, \"timestamp\")\n                    else None\n                )\n                result = _NoOpWaiter((seeds, nodes_timestamp, None, None))\n            else:\n                # Collect nodes from all types of input.\n                nodes = [seeds.view(-1)]\n                nodes_timestamp = None\n                if use_timestamp:\n                    # Timestamp for source and destination nodes are the same.\n                    negative_ratio = (\n                        seeds.shape[0] // minibatch.timestamp.shape[0] - 1\n                    )\n                    neg_timestamp = minibatch.timestamp.repeat_interleave(\n                        negative_ratio\n                    )\n                    seeds_timestamp = torch.cat(\n                        (minibatch.timestamp, neg_timestamp)\n                    )\n                    nodes_timestamp = [\n                        seeds_timestamp for _ in range(seeds.shape[1])\n                    ]\n\n                class _Waiter:\n                    def __init__(self, nodes, nodes_timestamp, seeds):\n                        # Unique and compact the collected nodes.\n                        if use_timestamp:\n                            self.future = compact_temporal_nodes(\n                                nodes, nodes_timestamp\n                            )\n                        else:\n                            self.future = unique_and_compact(\n                                nodes, async_op=async_op\n                            )\n                        self.seeds = seeds\n\n                    def wait(self):\n                        
\"\"\"Returns the stored value when invoked.\"\"\"\n                        if use_timestamp:\n                            (\n                                unique_seeds,\n                                nodes_timestamp,\n                                compacted,\n                            ) = self.future\n                            offsets = None\n                        else:\n                            unique_seeds, compacted, offsets = (\n                                self.future.wait() if async_op else self.future\n                            )\n                            nodes_timestamp = None\n                        seeds = self.seeds\n                        # Ensure there is no memory leak.\n                        self.future = self.seeds = None\n\n                        # Map back in same order as collect.\n                        compacted_seeds = compacted[0].view(seeds.shape)\n\n                        return (\n                            unique_seeds,\n                            nodes_timestamp,\n                            compacted_seeds,\n                            offsets,\n                        )\n\n                result = _Waiter(nodes, nodes_timestamp, seeds)\n\n        return result if async_op else result.wait()\n\n    def sample_subgraphs(\n        self, seeds, seeds_timestamp, seeds_pre_time_window=None\n    ):\n        \"\"\"Sample subgraphs from the given seeds, possibly with temporal constraints.\n\n        Any subclass of SubgraphSampler should implement this method.\n\n        Parameters\n        ----------\n        seeds : Union[torch.Tensor, Dict[str, torch.Tensor]]\n            The seed nodes.\n\n        seeds_timestamp : Union[torch.Tensor, Dict[str, torch.Tensor]]\n            The timestamps of the seed nodes. If given, the sampled subgraphs\n            should not contain any nodes or edges that are newer than the\n            timestamps of the seed nodes. 
Default: None.\n\n        seeds_pre_time_window : Union[torch.Tensor, Dict[str, torch.Tensor]]\n            The time window of the nodes represents a period of time before\n            `seeds_timestamp`. If provided, only neighbors and related edges\n            whose timestamps fall within `[seeds_timestamp -\n            seeds_pre_time_window, seeds_timestamp]` will be filtered.\n        Returns\n        -------\n        Union[torch.Tensor, Dict[str, torch.Tensor]]\n            The input nodes.\n        List[SampledSubgraph]\n            The sampled subgraphs.\n\n        Examples\n        --------\n        >>> @functional_datapipe(\"my_sample_subgraph\")\n        >>> class MySubgraphSampler(SubgraphSampler):\n        >>>     def __init__(self, datapipe, graph, fanouts):\n        >>>         super().__init__(datapipe)\n        >>>         self.graph = graph\n        >>>         self.fanouts = fanouts\n        >>>     def sample_subgraphs(self, seeds):\n        >>>         # Sample subgraphs from the given seeds.\n        >>>         subgraphs = []\n        >>>         subgraphs_nodes = []\n        >>>         for fanout in reversed(self.fanouts):\n        >>>             subgraph = self.graph.sample_neighbors(seeds, fanout)\n        >>>             subgraphs.insert(0, subgraph)\n        >>>             subgraphs_nodes.append(subgraph.nodes)\n        >>>             seeds = subgraph.nodes\n        >>>         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))\n        >>>         return subgraphs_nodes, subgraphs\n        \"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "python/dgl/heterograph.py",
    "content": "\"\"\"Classes for heterogeneous graphs.\"\"\"\nimport copy\nimport itertools\nimport numbers\n\n# pylint: disable= too-many-lines\nfrom collections import defaultdict\nfrom collections.abc import Iterable, Mapping\nfrom contextlib import contextmanager\n\nimport networkx as nx\nimport numpy as np\n\nfrom . import backend as F, core, graph_index, heterograph_index, utils\n\nfrom ._ffi.function import _init_api\nfrom .base import (\n    ALL,\n    dgl_warning,\n    DGLError,\n    EID,\n    ETYPE,\n    is_all,\n    NID,\n    NTYPE,\n    SLICE_FULL,\n)\nfrom .frame import Frame\nfrom .ops import segment\nfrom .view import (\n    HeteroEdgeDataView,\n    HeteroEdgeView,\n    HeteroNodeDataView,\n    HeteroNodeView,\n)\n\n__all__ = [\"DGLGraph\", \"combine_names\"]\n\n\nclass DGLGraph(object):\n    \"\"\"Class for storing graph structure and node/edge feature data.\n\n    There are a few ways to create a DGLGraph:\n\n    * To create a homogeneous graph from Tensor data, use :func:`dgl.graph`.\n    * To create a heterogeneous graph from Tensor data, use :func:`dgl.heterograph`.\n    * To create a graph from other data sources, use ``dgl.*`` create ops. See\n      :ref:`api-graph-create-ops`.\n\n    Read the user guide chapter :ref:`guide-graph` for an in-depth explanation about its\n    usage.\n    \"\"\"\n\n    is_block = False\n\n    # pylint: disable=unused-argument, dangerous-default-value\n    def __init__(\n        self,\n        gidx=[],\n        ntypes=[\"_N\"],\n        etypes=[\"_E\"],\n        node_frames=None,\n        edge_frames=None,\n        **deprecate_kwargs\n    ):\n        \"\"\"Internal constructor for creating a DGLGraph.\n\n        Parameters\n        ----------\n        gidx : HeteroGraphIndex\n            Graph index object.\n        ntypes : list of str, pair of list of str\n            Node type list. 
``ntypes[i]`` stores the name of node type i.\n            If a pair is given, the graph created is a uni-directional bipartite graph,\n            and its SRC node types and DST node types are given as in the pair.\n        etypes : list of str\n            Edge type list. ``etypes[i]`` stores the name of edge type i.\n        node_frames : list[Frame], optional\n            Node feature storage. If None, empty frame is created.\n            Otherwise, ``node_frames[i]`` stores the node features\n            of node type i. (default: None)\n        edge_frames : list[Frame], optional\n            Edge feature storage. If None, empty frame is created.\n            Otherwise, ``edge_frames[i]`` stores the edge features\n            of edge type i. (default: None)\n        \"\"\"\n        if isinstance(gidx, DGLGraph):\n            raise DGLError(\n                \"The input is already a DGLGraph. No need to create it again.\"\n            )\n        if not isinstance(gidx, heterograph_index.HeteroGraphIndex):\n            dgl_warning(\n                \"Recommend creating graphs by `dgl.graph(data)`\"\n                \" instead of `dgl.DGLGraph(data)`.\"\n            )\n            (sparse_fmt, arrays), num_src, num_dst = utils.graphdata2tensors(\n                gidx\n            )\n            if sparse_fmt == \"coo\":\n                gidx = heterograph_index.create_unitgraph_from_coo(\n                    1,\n                    num_src,\n                    num_dst,\n                    arrays[0],\n                    arrays[1],\n                    [\"coo\", \"csr\", \"csc\"],\n                )\n            else:\n                gidx = heterograph_index.create_unitgraph_from_csr(\n                    1,\n                    num_src,\n                    num_dst,\n                    arrays[0],\n                    arrays[1],\n                    arrays[2],\n                    [\"coo\", \"csr\", \"csc\"],\n                    sparse_fmt == \"csc\",\n       
         )\n        if len(deprecate_kwargs) != 0:\n            dgl_warning(\n                \"Keyword arguments {} are deprecated in v0.5, and can be safely\"\n                \" removed in all cases.\".format(list(deprecate_kwargs.keys()))\n            )\n        self._init(gidx, ntypes, etypes, node_frames, edge_frames)\n\n    def _init(self, gidx, ntypes, etypes, node_frames, edge_frames):\n        \"\"\"Init internal states.\"\"\"\n        self._graph = gidx\n        self._canonical_etypes = None\n        self._batch_num_nodes = None\n        self._batch_num_edges = None\n\n        # Handle node types\n        if isinstance(ntypes, tuple):\n            if len(ntypes) != 2:\n                errmsg = \"Invalid input. Expect a pair (srctypes, dsttypes) but got {}\".format(\n                    ntypes\n                )\n                raise TypeError(errmsg)\n            if not self._graph.is_metagraph_unibipartite():\n                raise ValueError(\n                    \"Invalid input. 
The metagraph must be a uni-directional\"\n                    \" bipartite graph.\"\n                )\n            self._ntypes = ntypes[0] + ntypes[1]\n            self._srctypes_invmap = {t: i for i, t in enumerate(ntypes[0])}\n            self._dsttypes_invmap = {\n                t: i + len(ntypes[0]) for i, t in enumerate(ntypes[1])\n            }\n            self._is_unibipartite = True\n            if len(ntypes[0]) == 1 and len(ntypes[1]) == 1 and len(etypes) == 1:\n                self._canonical_etypes = [\n                    (ntypes[0][0], etypes[0], ntypes[1][0])\n                ]\n        else:\n            self._ntypes = ntypes\n            if len(ntypes) == 1:\n                src_dst_map = None\n            else:\n                src_dst_map = find_src_dst_ntypes(\n                    self._ntypes, self._graph.metagraph\n                )\n            self._is_unibipartite = src_dst_map is not None\n            if self._is_unibipartite:\n                self._srctypes_invmap, self._dsttypes_invmap = src_dst_map\n            else:\n                self._srctypes_invmap = {\n                    t: i for i, t in enumerate(self._ntypes)\n                }\n                self._dsttypes_invmap = self._srctypes_invmap\n\n        # Handle edge types\n        self._etypes = etypes\n        if self._canonical_etypes is None:\n            if len(etypes) == 1 and len(ntypes) == 1:\n                self._canonical_etypes = [(ntypes[0], etypes[0], ntypes[0])]\n            else:\n                self._canonical_etypes = make_canonical_etypes(\n                    self._etypes, self._ntypes, self._graph.metagraph\n                )\n\n        # An internal map from etype to canonical etype tuple.\n        # If two etypes have the same name, an empty tuple is stored instead to indicate\n        # ambiguity.\n        self._etype2canonical = {}\n        for i, ety in enumerate(self._etypes):\n            if ety in self._etype2canonical:\n                
self._etype2canonical[ety] = tuple()\n            else:\n                self._etype2canonical[ety] = self._canonical_etypes[i]\n        self._etypes_invmap = {\n            t: i for i, t in enumerate(self._canonical_etypes)\n        }\n\n        # node and edge frame\n        if node_frames is None:\n            node_frames = [None] * len(self._ntypes)\n        node_frames = [\n            Frame(num_rows=self._graph.num_nodes(i)) if frame is None else frame\n            for i, frame in enumerate(node_frames)\n        ]\n        self._node_frames = node_frames\n\n        if edge_frames is None:\n            edge_frames = [None] * len(self._etypes)\n        edge_frames = [\n            Frame(num_rows=self._graph.num_edges(i)) if frame is None else frame\n            for i, frame in enumerate(edge_frames)\n        ]\n        self._edge_frames = edge_frames\n\n    def __setstate__(self, state):\n        # Compatibility check\n        # TODO: version the storage\n        if isinstance(state, dict):\n            # Since 0.5 we use the default __dict__ method\n            self.__dict__.update(state)\n        elif isinstance(state, tuple) and len(state) == 5:\n            # DGL == 0.4.3\n            dgl_warning(\n                \"The object is pickled with DGL == 0.4.3.  \"\n                \"Some of the original attributes are ignored.\"\n            )\n            self._init(*state)\n        elif isinstance(state, dict):\n            # DGL <= 0.4.2\n            dgl_warning(\n                \"The object is pickled with DGL <= 0.4.2.  
\"\n                \"Some of the original attributes are ignored.\"\n            )\n            self._init(\n                state[\"_graph\"],\n                state[\"_ntypes\"],\n                state[\"_etypes\"],\n                state[\"_node_frames\"],\n                state[\"_edge_frames\"],\n            )\n        else:\n            raise IOError(\"Unrecognized pickle format.\")\n\n    def __repr__(self):\n        if len(self.ntypes) == 1 and len(self.etypes) == 1:\n            ret = (\n                \"Graph(num_nodes={node}, num_edges={edge},\\n\"\n                \"      ndata_schemes={ndata}\\n\"\n                \"      edata_schemes={edata})\"\n            )\n            return ret.format(\n                node=self.num_nodes(),\n                edge=self.num_edges(),\n                ndata=str(self.node_attr_schemes()),\n                edata=str(self.edge_attr_schemes()),\n            )\n        else:\n            ret = (\n                \"Graph(num_nodes={node},\\n\"\n                \"      num_edges={edge},\\n\"\n                \"      metagraph={meta})\"\n            )\n            nnode_dict = {\n                self.ntypes[i]: self._graph.num_nodes(i)\n                for i in range(len(self.ntypes))\n            }\n            nedge_dict = {\n                self.canonical_etypes[i]: self._graph.num_edges(i)\n                for i in range(len(self.etypes))\n            }\n            meta = str(self.metagraph().edges(keys=True))\n            return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)\n\n    def __copy__(self):\n        \"\"\"Shallow copy implementation.\"\"\"\n        # TODO(minjie): too many states in python; should clean up and lower to C\n        cls = type(self)\n        obj = cls.__new__(cls)\n        obj.__dict__.update(self.__dict__)\n        return obj\n\n    #################################################################\n    # Mutation operations\n    
#################################################################\n\n    def add_nodes(self, num, data=None, ntype=None):\n        r\"\"\"Add new nodes of the same node type\n\n        Parameters\n        ----------\n        num : int\n            Number of nodes to add.\n        data : dict, optional\n            Feature data of the added nodes.\n        ntype : str, optional\n            The type of the new nodes. Can be omitted if there is\n            only one node type in the graph.\n\n        Notes\n        -----\n\n        * Inplace update is applied to the current graph.\n        * If the key of ``data`` does not contain some existing feature fields,\n          those features for the new nodes will be created by initializers\n          defined with :func:`set_n_initializer` (default initializer fills zeros).\n        * If the key of ``data`` contains new feature fields, those features for\n          the old nodes will be created by initializers defined with\n          :func:`set_n_initializer` (default initializer fills zeros).\n        * This function discards the batch information. 
Please use\n          :func:`dgl.DGLGraph.set_batch_num_nodes`\n          and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n          to maintain the information.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        **Homogeneous Graphs or Heterogeneous Graphs with A Single Node Type**\n\n        >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n        >>> g.num_nodes()\n        3\n        >>> g.add_nodes(2)\n        >>> g.num_nodes()\n        5\n\n        If the graph has some node features and new nodes are added without\n        features, their features will be created by initializers defined\n        with :func:`set_n_initializer`.\n\n        >>> g.ndata['h'] = torch.ones(5, 1)\n        >>> g.add_nodes(1)\n        >>> g.ndata['h']\n        tensor([[1.], [1.], [1.], [1.], [1.], [0.]])\n\n        We can also assign features for the new nodes in adding new nodes.\n\n        >>> g.add_nodes(1, {'h': torch.ones(1, 1), 'w': torch.ones(1, 1)})\n        >>> g.ndata['h']\n        tensor([[1.], [1.], [1.], [1.], [1.], [0.], [1.]])\n\n        Since ``data`` contains new feature fields, the features for old nodes\n        will be created by initializers defined with :func:`set_n_initializer`.\n\n        >>> g.ndata['w']\n        tensor([[0.], [0.], [0.], [0.], [0.], [0.], [1.]])\n\n\n        **Heterogeneous Graphs with Multiple Node Types**\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n        ...                                 torch.tensor([0, 0, 1, 1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n        ...                                         torch.tensor([0, 1]))\n        ...     
})\n        >>> g.add_nodes(2)\n        DGLError: Node type name must be specified\n        if there are more than one node types.\n        >>> g.num_nodes('user')\n        3\n        >>> g.add_nodes(2, ntype='user')\n        >>> g.num_nodes('user')\n        5\n\n        See Also\n        --------\n        remove_nodes\n        add_edges\n        remove_edges\n        \"\"\"\n        # TODO(xiangsx): block do not support add_nodes\n        if ntype is None:\n            if self._graph.number_of_ntypes() != 1:\n                raise DGLError(\n                    \"Node type name must be specified if there are more than one \"\n                    \"node types.\"\n                )\n\n        # nothing happen\n        if num == 0:\n            return\n\n        assert num > 0, \"Number of new nodes should be larger than one.\"\n        ntid = self.get_ntype_id(ntype)\n        # update graph idx\n        metagraph = self._graph.metagraph\n        num_nodes_per_type = []\n        for c_ntype in self.ntypes:\n            if self.get_ntype_id(c_ntype) == ntid:\n                num_nodes_per_type.append(self.num_nodes(c_ntype) + num)\n            else:\n                num_nodes_per_type.append(self.num_nodes(c_ntype))\n\n        relation_graphs = []\n        for c_etype in self.canonical_etypes:\n            # src or dst == ntype, update the relation graph\n            if (\n                self.get_ntype_id(c_etype[0]) == ntid\n                or self.get_ntype_id(c_etype[2]) == ntid\n            ):\n                u, v = self.edges(form=\"uv\", order=\"eid\", etype=c_etype)\n                hgidx = heterograph_index.create_unitgraph_from_coo(\n                    1 if c_etype[0] == c_etype[2] else 2,\n                    self.num_nodes(c_etype[0])\n                    + (num if self.get_ntype_id(c_etype[0]) == ntid else 0),\n                    self.num_nodes(c_etype[2])\n                    + (num if self.get_ntype_id(c_etype[2]) == ntid else 0),\n                   
 u,\n                    v,\n                    [\"coo\", \"csr\", \"csc\"],\n                )\n                relation_graphs.append(hgidx)\n            else:\n                # do nothing\n                relation_graphs.append(\n                    self._graph.get_relation_graph(self.get_etype_id(c_etype))\n                )\n        hgidx = heterograph_index.create_heterograph_from_relations(\n            metagraph,\n            relation_graphs,\n            utils.toindex(num_nodes_per_type, \"int64\"),\n        )\n        self._graph = hgidx\n\n        # update data frames\n        if data is None:\n            # Initialize feature with :func:`set_n_initializer`\n            self._node_frames[ntid].add_rows(num)\n        else:\n            self._node_frames[ntid].append(data)\n        self._reset_cached_info()\n\n    def add_edges(self, u, v, data=None, etype=None):\n        r\"\"\"Add multiple new edges for the specified edge type\n\n        The i-th new edge will be from ``u[i]`` to ``v[i]``.\n\n        Parameters\n        ----------\n        u : int, tensor, numpy.ndarray, list\n            Source node IDs, ``u[i]`` gives the source node for the i-th new edge.\n        v : int, tensor, numpy.ndarray, list\n            Destination node IDs, ``v[i]`` gives the destination node for the i-th new edge.\n        data : dict, optional\n            Feature data of the added edges. The i-th row of the feature data\n            corresponds to the i-th new edge.\n        etype : str or tuple of str, optional\n            The type of the new edges. Can be omitted if there is\n            only one edge type in the graph.\n\n        Notes\n        -----\n\n        * Inplace update is applied to the current graph.\n        * If end nodes of adding edges does not exists, add_nodes is invoked\n          to add new nodes. 
The node features of the new nodes will be created\n          by initializers defined with :func:`set_n_initializer` (default\n          initializer fills zeros). In certain cases, it is recommanded to\n          add_nodes first and then add_edges.\n        * If the key of ``data`` does not contain some existing feature fields,\n          those features for the new edges will be created by initializers\n          defined with :func:`set_n_initializer` (default initializer fills zeros).\n        * If the key of ``data`` contains new feature fields, those features for\n          the old edges will be created by initializers defined with\n          :func:`set_n_initializer` (default initializer fills zeros).\n        * This function discards the batch information. Please use\n          :func:`dgl.DGLGraph.set_batch_num_nodes`\n          and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n          to maintain the information.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        **Homogeneous Graphs or Heterogeneous Graphs with A Single Edge Type**\n\n        >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n        >>> g.num_edges()\n        2\n        >>> g.add_edges(torch.tensor([1, 3]), torch.tensor([0, 1]))\n        >>> g.num_edges()\n        4\n\n        Since ``u`` or ``v`` contains a non-existing node ID, the nodes are\n        added implicitly.\n        >>> g.num_nodes()\n        4\n\n        If the graph has some edge features and new edges are added without\n        features, their features will be created by initializers defined\n        with :func:`set_n_initializer`.\n\n        >>> g.edata['h'] = torch.ones(4, 1)\n        >>> g.add_edges(torch.tensor([1]), torch.tensor([1]))\n        >>> g.edata['h']\n        tensor([[1.], [1.], [1.], [1.], [0.]])\n\n        We can also assign features for the new edges in adding new 
edges.\n\n        >>> g.add_edges(torch.tensor([0, 0]), torch.tensor([2, 2]),\n        ...             {'h': torch.tensor([[1.], [2.]]), 'w': torch.ones(2, 1)})\n        >>> g.edata['h']\n        tensor([[1.], [1.], [1.], [1.], [0.], [1.], [2.]])\n\n        Since ``data`` contains new feature fields, the features for old edges\n        will be created by initializers defined with :func:`set_n_initializer`.\n\n        >>> g.edata['w']\n        tensor([[0.], [0.], [0.], [0.], [0.], [1.], [1.]])\n\n        **Heterogeneous Graphs with Multiple Edge Types**\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n        ...                                 torch.tensor([0, 0, 1, 1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n        ...                                         torch.tensor([0, 1]))\n        ...     })\n        >>> g.add_edges(torch.tensor([3]), torch.tensor([3]))\n        DGLError: Edge type name must be specified\n        if there are more than one edge types.\n        >>> g.num_edges('plays')\n        4\n        >>> g.add_edges(torch.tensor([3]), torch.tensor([3]), etype='plays')\n        >>> g.num_edges('plays')\n        5\n\n        See Also\n        --------\n        add_nodes\n        remove_nodes\n        remove_edges\n        \"\"\"\n        # TODO(xiangsx): block do not support add_edges\n        u = utils.prepare_tensor(self, u, \"u\")\n        v = utils.prepare_tensor(self, v, \"v\")\n\n        if etype is None:\n            if self._graph.number_of_etypes() != 1:\n                raise DGLError(\n                    \"Edge type name must be specified if there are more than one \"\n                    \"edge types.\"\n                )\n\n        # nothing changed\n        if len(u) == 0 or len(v) == 0:\n            return\n\n        assert len(u) == len(v) or len(u) == 1 or len(v) == 1, (\n            \"The number of source nodes and the number of 
destination nodes should be same, \"\n            \"or either the number of source nodes or the number of destination nodes is 1.\"\n        )\n\n        if len(u) == 1 and len(v) > 1:\n            u = F.full_1d(\n                len(v), F.as_scalar(u), dtype=F.dtype(u), ctx=F.context(u)\n            )\n        if len(v) == 1 and len(u) > 1:\n            v = F.full_1d(\n                len(u), F.as_scalar(v), dtype=F.dtype(v), ctx=F.context(v)\n            )\n\n        u_type, e_type, v_type = self.to_canonical_etype(etype)\n        # if end nodes of adding edges does not exists\n        # use add_nodes to add new nodes first.\n        num_of_u = self.num_nodes(u_type)\n        num_of_v = self.num_nodes(v_type)\n        u_max = F.as_scalar(F.max(u, dim=0)) + 1\n        v_max = F.as_scalar(F.max(v, dim=0)) + 1\n\n        if u_type == v_type:\n            num_nodes = max(u_max, v_max)\n            if num_nodes > num_of_u:\n                self.add_nodes(num_nodes - num_of_u, ntype=u_type)\n        else:\n            if u_max > num_of_u:\n                self.add_nodes(u_max - num_of_u, ntype=u_type)\n            if v_max > num_of_v:\n                self.add_nodes(v_max - num_of_v, ntype=v_type)\n\n        # metagraph is not changed\n        metagraph = self._graph.metagraph\n        num_nodes_per_type = []\n        for ntype in self.ntypes:\n            num_nodes_per_type.append(self.num_nodes(ntype))\n        # update graph idx\n        relation_graphs = []\n        for c_etype in self.canonical_etypes:\n            # the target edge type\n            if c_etype == (u_type, e_type, v_type):\n                old_u, old_v = self.edges(form=\"uv\", order=\"eid\", etype=c_etype)\n                hgidx = heterograph_index.create_unitgraph_from_coo(\n                    1 if u_type == v_type else 2,\n                    self.num_nodes(u_type),\n                    self.num_nodes(v_type),\n                    F.cat([old_u, u], dim=0),\n                    F.cat([old_v, 
v], dim=0),\n                    [\"coo\", \"csr\", \"csc\"],\n                )\n                relation_graphs.append(hgidx)\n            else:\n                # do nothing\n                # Note: node range change has been handled in add_nodes()\n                relation_graphs.append(\n                    self._graph.get_relation_graph(self.get_etype_id(c_etype))\n                )\n\n        hgidx = heterograph_index.create_heterograph_from_relations(\n            metagraph,\n            relation_graphs,\n            utils.toindex(num_nodes_per_type, \"int64\"),\n        )\n        self._graph = hgidx\n\n        # handle data\n        etid = self.get_etype_id(etype)\n        if data is None:\n            self._edge_frames[etid].add_rows(len(u))\n        else:\n            self._edge_frames[etid].append(data)\n        self._reset_cached_info()\n\n    def remove_edges(self, eids, etype=None, store_ids=False):\n        r\"\"\"Remove multiple edges with the specified edge type\n\n        Nodes will not be removed. After removing edges, the rest\n        edges will be re-indexed using consecutive integers from 0,\n        with their relative order preserved.\n\n        The features for the removed edges will be removed accordingly.\n\n        Parameters\n        ----------\n        eids : int, tensor, numpy.ndarray, list\n            IDs for the edges to remove.\n        etype : str or tuple of str, optional\n            The type of the edges to remove. 
Can be omitted if there is\n            only one edge type in the graph.\n        store_ids : bool, optional\n            If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``\n            and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,\n            respectively.\n\n        Notes\n        -----\n        This function preserves the batch information.\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import torch\n\n        **Homogeneous Graphs or Heterogeneous Graphs with A Single Edge Type**\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 2]), torch.tensor([0, 1, 2])))\n        >>> g.edata['he'] = torch.arange(3).float().reshape(-1, 1)\n        >>> g.remove_edges(torch.tensor([0, 1]))\n        >>> g\n        Graph(num_nodes=3, num_edges=1,\n            ndata_schemes={}\n            edata_schemes={'he': Scheme(shape=(1,), dtype=torch.float32)})\n        >>> g.edges('all')\n        (tensor([2]), tensor([2]), tensor([0]))\n        >>> g.edata['he']\n        tensor([[2.]])\n\n        Removing edges from a batched graph preserves batch information.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 2]), torch.tensor([0, 1, 2])))\n        >>> g2 = dgl.graph((torch.tensor([1, 2, 3]), torch.tensor([1, 3, 4])))\n        >>> bg = dgl.batch([g, g2])\n        >>> bg.batch_num_edges()\n        tensor([3, 3])\n        >>> bg.remove_edges([1, 4])\n        >>> bg.batch_num_edges()\n        tensor([2, 2])\n\n        **Heterogeneous Graphs with Multiple Edge Types**\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n        ...                                 torch.tensor([0, 0, 1, 1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n        ...                                         torch.tensor([0, 1]))\n        ...     
})\n        >>> g.remove_edges(torch.tensor([0, 1]))\n        DGLError: Edge type name must be specified\n        if there are more than one edge types.\n        >>> g.remove_edges(torch.tensor([0, 1]), 'plays')\n        >>> g.edges('all', etype='plays')\n        (tensor([0, 1]), tensor([0, 0]), tensor([0, 1]))\n\n        See Also\n        --------\n        add_nodes\n        add_edges\n        remove_nodes\n        \"\"\"\n        # TODO(xiangsx): block do not support remove_edges\n        if etype is None:\n            if self._graph.number_of_etypes() != 1:\n                raise DGLError(\n                    \"Edge type name must be specified if there are more than one \"\n                    \"edge types.\"\n                )\n        eids = utils.prepare_tensor(self, eids, \"u\")\n        if len(eids) == 0:\n            # no edge to delete\n            return\n        assert self.num_edges(etype) > F.as_scalar(\n            F.max(eids, dim=0)\n        ), \"The input eid {} is out of the range [0:{})\".format(\n            F.as_scalar(F.max(eids, dim=0)), self.num_edges(etype)\n        )\n\n        # edge_subgraph\n        edges = {}\n        u_type, e_type, v_type = self.to_canonical_etype(etype)\n        for c_etype in self.canonical_etypes:\n            # the target edge type\n            if c_etype == (u_type, e_type, v_type):\n                origin_eids = self.edges(form=\"eid\", order=\"eid\", etype=c_etype)\n                edges[c_etype] = utils.compensate(eids, origin_eids)\n            else:\n                edges[c_etype] = self.edges(\n                    form=\"eid\", order=\"eid\", etype=c_etype\n                )\n\n        # If the graph is batched, update batch_num_edges\n        batched = self._batch_num_edges is not None\n        if batched:\n            c_etype = (u_type, e_type, v_type)\n            one_hot_removed_edges = F.zeros(\n                (self.num_edges(c_etype),), F.float32, self.device\n            )\n            
one_hot_removed_edges = F.scatter_row(\n                one_hot_removed_edges,\n                eids,\n                F.full_1d(len(eids), 1.0, F.float32, self.device),\n            )\n            c_etype_batch_num_edges = self._batch_num_edges[c_etype]\n            batch_num_removed_edges = segment.segment_reduce(\n                c_etype_batch_num_edges, one_hot_removed_edges, reducer=\"sum\"\n            )\n            self._batch_num_edges[c_etype] = c_etype_batch_num_edges - F.astype(\n                batch_num_removed_edges, self.idtype\n            )\n\n        sub_g = self.edge_subgraph(\n            edges, relabel_nodes=False, store_ids=store_ids\n        )\n        self._graph = sub_g._graph\n        self._node_frames = sub_g._node_frames\n        self._edge_frames = sub_g._edge_frames\n\n    def remove_nodes(self, nids, ntype=None, store_ids=False):\n        r\"\"\"Remove multiple nodes with the specified node type\n\n        Edges that connect to the nodes will be removed as well. After removing\n        nodes and edges, the rest nodes and edges will be re-indexed using\n        consecutive integers from 0, with their relative order preserved.\n\n        The features for the removed nodes/edges will be removed accordingly.\n\n        Parameters\n        ----------\n        nids : int, tensor, numpy.ndarray, list\n            Nodes to remove.\n        ntype : str, optional\n            The type of the nodes to remove. 
Can be omitted if there is\n            only one node type in the graph.\n        store_ids : bool, optional\n            If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``\n            and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,\n            respectively.\n\n        Notes\n        -----\n        This function preserves the batch information.\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import torch\n\n        **Homogeneous Graphs or Heterogeneous Graphs with A Single Node Type**\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 2]), torch.tensor([0, 1, 2])))\n        >>> g.ndata['hv'] = torch.arange(3).float().reshape(-1, 1)\n        >>> g.edata['he'] = torch.arange(3).float().reshape(-1, 1)\n        >>> g.remove_nodes(torch.tensor([0, 1]))\n        >>> g\n        Graph(num_nodes=1, num_edges=1,\n            ndata_schemes={'hv': Scheme(shape=(1,), dtype=torch.float32)}\n            edata_schemes={'he': Scheme(shape=(1,), dtype=torch.float32)})\n        >>> g.ndata['hv']\n        tensor([[2.]])\n        >>> g.edata['he']\n        tensor([[2.]])\n\n        Removing nodes from a batched graph preserves batch information.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 2]), torch.tensor([0, 1, 2])))\n        >>> g2 = dgl.graph((torch.tensor([1, 2, 3]), torch.tensor([1, 3, 4])))\n        >>> bg = dgl.batch([g, g2])\n        >>> bg.batch_num_nodes()\n        tensor([3, 5])\n        >>> bg.remove_nodes([1, 4])\n        >>> bg.batch_num_nodes()\n        tensor([2, 4])\n        >>> bg.batch_num_edges()\n        tensor([2, 2])\n\n        **Heterogeneous Graphs with Multiple Node Types**\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n        ...                                 torch.tensor([0, 0, 1, 1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n        ...                       
                  torch.tensor([0, 1]))\n        ...     })\n        >>> g.remove_nodes(torch.tensor([0, 1]))\n        DGLError: Node type name must be specified\n        if there are more than one node types.\n        >>> g.remove_nodes(torch.tensor([0, 1]), ntype='game')\n        >>> g.num_nodes('user')\n        3\n        >>> g.num_nodes('game')\n        0\n        >>> g.num_edges('plays')\n        0\n\n        See Also\n        --------\n        add_nodes\n        add_edges\n        remove_edges\n        \"\"\"\n        # TODO(xiangsx): block do not support remove_nodes\n        if ntype is None:\n            if self._graph.number_of_ntypes() != 1:\n                raise DGLError(\n                    \"Node type name must be specified if there are more than one \"\n                    \"node types.\"\n                )\n\n        nids = utils.prepare_tensor(self, nids, \"u\")\n        if len(nids) == 0:\n            # no node to delete\n            return\n        assert self.num_nodes(ntype) > F.as_scalar(\n            F.max(nids, dim=0)\n        ), \"The input nids {} is out of the range [0:{})\".format(\n            F.as_scalar(F.max(nids, dim=0)), self.num_nodes(ntype)\n        )\n\n        ntid = self.get_ntype_id(ntype)\n        nodes = {}\n        for c_ntype in self.ntypes:\n            if self.get_ntype_id(c_ntype) == ntid:\n                target_ntype = c_ntype\n                original_nids = self.nodes(c_ntype)\n                nodes[c_ntype] = utils.compensate(nids, original_nids)\n            else:\n                nodes[c_ntype] = self.nodes(c_ntype)\n\n        # If the graph is batched, update batch_num_nodes\n        batched = self._batch_num_nodes is not None\n        if batched:\n            one_hot_removed_nodes = F.zeros(\n                (self.num_nodes(target_ntype),), F.float32, self.device\n            )\n            one_hot_removed_nodes = F.scatter_row(\n                one_hot_removed_nodes,\n                nids,\n                
F.full_1d(len(nids), 1.0, F.float32, self.device),\n            )\n            c_ntype_batch_num_nodes = self._batch_num_nodes[target_ntype]\n            batch_num_removed_nodes = segment.segment_reduce(\n                c_ntype_batch_num_nodes, one_hot_removed_nodes, reducer=\"sum\"\n            )\n            self._batch_num_nodes[\n                target_ntype\n            ] = c_ntype_batch_num_nodes - F.astype(\n                batch_num_removed_nodes, self.idtype\n            )\n            # Record old num_edges to check later whether some edges were removed\n            old_num_edges = {\n                c_etype: self._graph.num_edges(self.get_etype_id(c_etype))\n                for c_etype in self.canonical_etypes\n            }\n\n        # node_subgraph\n        # If batch_num_edges is to be updated, record the original edge IDs\n        sub_g = self.subgraph(nodes, store_ids=store_ids or batched)\n        self._graph = sub_g._graph\n        self._node_frames = sub_g._node_frames\n        self._edge_frames = sub_g._edge_frames\n\n        # If the graph is batched, update batch_num_edges\n        if batched:\n            canonical_etypes = [\n                c_etype\n                for c_etype in self.canonical_etypes\n                if self._graph.num_edges(self.get_etype_id(c_etype))\n                != old_num_edges[c_etype]\n            ]\n\n            for c_etype in canonical_etypes:\n                if self._graph.num_edges(self.get_etype_id(c_etype)) == 0:\n                    self._batch_num_edges[c_etype] = F.zeros(\n                        (self.batch_size,), self.idtype, self.device\n                    )\n                    continue\n\n                one_hot_left_edges = F.zeros(\n                    (old_num_edges[c_etype],), F.float32, self.device\n                )\n                eids = self.edges[c_etype].data[EID]\n                one_hot_left_edges = F.scatter_row(\n                    one_hot_left_edges,\n                    
eids,\n                    F.full_1d(len(eids), 1.0, F.float32, self.device),\n                )\n                batch_num_left_edges = segment.segment_reduce(\n                    self._batch_num_edges[c_etype],\n                    one_hot_left_edges,\n                    reducer=\"sum\",\n                )\n                self._batch_num_edges[c_etype] = F.astype(\n                    batch_num_left_edges, self.idtype\n                )\n\n        if batched and not store_ids:\n            for c_ntype in self.ntypes:\n                self.nodes[c_ntype].data.pop(NID)\n            for c_etype in self.canonical_etypes:\n                self.edges[c_etype].data.pop(EID)\n\n    def _reset_cached_info(self):\n        \"\"\"Some info like batch_num_nodes may be stale after mutation\n        Clean these cached info\n        \"\"\"\n        self._batch_num_nodes = None\n        self._batch_num_edges = None\n\n    #################################################################\n    # Metagraph query\n    #################################################################\n\n    @property\n    def is_unibipartite(self):\n        \"\"\"Return whether the graph is a uni-bipartite graph.\n\n        A uni-bipartite heterograph can further divide its node types into two sets:\n        SRC and DST. All edges are from nodes in SRC to nodes in DST. The following APIs\n        can be used to get the type, data, and nodes that belong to SRC and DST sets:\n\n        * :func:`srctype` and :func:`dsttype`\n        * :func:`srcdata` and :func:`dstdata`\n        * :func:`srcnodes` and :func:`dstnodes`\n\n        Note that we allow two node types to have the same name as long as one\n        belongs to SRC while the other belongs to DST. 
To distinguish them, prepend\n        the name with ``\"SRC/\"`` or ``\"DST/\"`` when specifying a node type.\n        \"\"\"\n        return self._is_unibipartite\n\n    @property\n    def ntypes(self):\n        \"\"\"Return all the node type names in the graph.\n\n        Returns\n        -------\n        list[str]\n            All the node type names in a list.\n\n        Notes\n        -----\n        DGL internally assigns an integer ID for each node type. The returned\n        node type names are sorted according to their IDs.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([1, 3]), torch.tensor([2, 3]))\n        ... })\n        >>> g.ntypes\n        ['game', 'user']\n        \"\"\"\n        return self._ntypes\n\n    @property\n    def etypes(self):\n        \"\"\"Return all the edge type names in the graph.\n\n        Returns\n        -------\n        list[str]\n            All the edge type names in a list.\n\n        Notes\n        -----\n        DGL internally assigns an integer ID for each edge type. The returned\n        edge type names are sorted according to their IDs.\n\n        The complete format to specify an relation is a string triplet ``(str, str, str)``\n        for source node type, edge type and destination node type. DGL calls this\n        format *canonical edge type*. 
An edge type can appear in multiple canonical edge types.\n        For example, ``'interacts'`` can appear in two canonical edge types\n        ``('drug', 'interacts', 'drug')`` and ``('protein', 'interacts', 'protein')``.\n\n        See Also\n        --------\n        canonical_etypes\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([1, 3]), torch.tensor([2, 3]))\n        ... })\n        >>> g.etypes\n        ['follows', 'follows', 'plays']\n        \"\"\"\n        return self._etypes\n\n    @property\n    def canonical_etypes(self):\n        \"\"\"Return all the canonical edge types in the graph.\n\n        A canonical edge type is a string triplet ``(str, str, str)``\n        for source node type, edge type and destination node type.\n\n        Returns\n        -------\n        list[(str, str, str)]\n            All the canonical edge type triplets in a list.\n\n        Notes\n        -----\n        DGL internally assigns an integer ID for each edge type. The returned\n        edge type names are sorted according to their IDs.\n\n        See Also\n        --------\n        etypes\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([1, 3]), torch.tensor([2, 3]))\n        ... 
})\n        >>> g.canonical_etypes\n        [('user', 'follows', 'user'),\n         ('user', 'follows', 'game'),\n         ('user', 'plays', 'game')]\n        \"\"\"\n        return self._canonical_etypes\n\n    @property\n    def srctypes(self):\n        \"\"\"Return all the source node type names in this graph.\n\n        If the graph can further divide its node types into two subsets A and B where\n        all the edeges are from nodes of types in A to nodes of types in B, we call\n        this graph a *uni-bipartite* graph and the nodes in A being the *source*\n        nodes and the ones in B being the *destination* nodes. If the graph is not\n        uni-bipartite, the source and destination nodes are just the entire set of\n        nodes in the graph.\n\n        Returns\n        -------\n        list[str]\n            All the source node type names in a list.\n\n        See Also\n        --------\n        dsttypes\n        is_unibipartite\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Query for a uni-bipartite graph.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0]), torch.tensor([1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([1]), torch.tensor([2]))\n        ... })\n        >>> g.srctypes\n        ['developer', 'user']\n\n        Query for a graph that is not uni-bipartite.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0]), torch.tensor([1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([1]), torch.tensor([2]))\n        ... 
})\n        >>> g.srctypes\n        ['developer', 'game', 'user']\n        \"\"\"\n        if self.is_unibipartite:\n            return sorted(list(self._srctypes_invmap.keys()))\n        else:\n            return self.ntypes\n\n    @property\n    def dsttypes(self):\n        \"\"\"Return all the destination node type names in this graph.\n\n        If the graph can further divide its node types into two subsets A and B where\n        all the edeges are from nodes of types in A to nodes of types in B, we call\n        this graph a *uni-bipartite* graph and the nodes in A being the *source*\n        nodes and the ones in B being the *destination* nodes. If the graph is not\n        uni-bipartite, the source and destination nodes are just the entire set of\n        nodes in the graph.\n\n        Returns\n        -------\n        list[str]\n            All the destination node type names in a list.\n\n        See Also\n        --------\n        srctypes\n        is_unibipartite\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Query for a uni-bipartite graph.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0]), torch.tensor([1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([1]), torch.tensor([2]))\n        ... })\n        >>> g.dsttypes\n        ['game']\n\n        Query for a graph that is not uni-bipartite.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0]), torch.tensor([1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([1]), torch.tensor([2]))\n        ... 
})\n        >>> g.dsttypes\n        ['developer', 'game', 'user']\n        \"\"\"\n        if self.is_unibipartite:\n            return sorted(list(self._dsttypes_invmap.keys()))\n        else:\n            return self.ntypes\n\n    def metagraph(self):\n        \"\"\"Return the metagraph of the heterograph.\n\n        The metagraph (or network schema) of a heterogeneous network specifies type constraints\n        on the sets of nodes and edges between the nodes. For a formal definition, refer to\n        `Yizhou et al. <https://www.kdd.org/exploration_files/V14-02-03-Sun.pdf>`_.\n\n        Returns\n        -------\n        networkx.MultiDiGraph\n            The metagraph.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([1, 3]), torch.tensor([2, 3]))\n        ... 
})\n        >>> meta_g = g.metagraph()\n        >>> meta_g.nodes()\n        NodeView(('user', 'game'))\n        >>> meta_g.edges()\n        OutMultiEdgeDataView([('user', 'user'), ('user', 'game'), ('user', 'game')])\n        \"\"\"\n        nx_graph = self._graph.metagraph.to_networkx()\n        nx_metagraph = nx.MultiDiGraph()\n        for u_v in nx_graph.edges:\n            srctype, etype, dsttype = self.canonical_etypes[\n                nx_graph.edges[u_v][\"id\"]\n            ]\n            nx_metagraph.add_edge(srctype, dsttype, etype)\n        return nx_metagraph\n\n    def to_canonical_etype(self, etype):\n        \"\"\"Convert an edge type to the corresponding canonical edge type in the graph.\n\n        A canonical edge type is a string triplet ``(str, str, str)``\n        for source node type, edge type and destination node type.\n\n        The function expects the given edge type name can uniquely identify a canonical edge\n        type. DGL will raise error if this is not the case.\n\n        Parameters\n        ----------\n        etype : str or (str, str, str)\n            If :attr:`etype` is an edge type (str), it returns the corresponding canonical edge\n            type in the graph. If :attr:`etype` is already a canonical edge type,\n            it directly returns the input unchanged.\n\n        Returns\n        -------\n        (str, str, str)\n            The canonical edge type corresponding to the edge type.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a heterograph.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): ([0, 1], [1, 2]),\n        ...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),\n        ...     ('developer', 'follows', 'game'): ([0, 1], [0, 1])\n        ... 
})\n\n        Map an edge type to its corresponding canonical edge type.\n\n        >>> g.to_canonical_etype('plays')\n        ('user', 'plays', 'game')\n        >>> g.to_canonical_etype(('user', 'plays', 'game'))\n        ('user', 'plays', 'game')\n\n        See Also\n        --------\n        canonical_etypes\n        \"\"\"\n        if etype is None:\n            if len(self.etypes) != 1:\n                raise DGLError(\n                    \"Edge type name must be specified if there are more than one \"\n                    \"edge types.\"\n                )\n            etype = self.etypes[0]\n        if isinstance(etype, tuple):\n            return etype\n        else:\n            ret = self._etype2canonical.get(etype, None)\n            if ret is None:\n                raise DGLError('Edge type \"{}\" does not exist.'.format(etype))\n            if len(ret) == 0:\n                raise DGLError(\n                    'Edge type \"%s\" is ambiguous. Please use canonical edge type '\n                    \"in the form of (srctype, etype, dsttype)\" % etype\n                )\n            return ret\n\n    def get_ntype_id(self, ntype):\n        \"\"\"Return the ID of the given node type.\n\n        ntype can also be None. 
If so, there should be only one node type in the\n        graph.\n\n        Parameters\n        ----------\n        ntype : str\n            Node type\n\n        Returns\n        -------\n        int\n        \"\"\"\n        if self.is_unibipartite and ntype is not None:\n            # Only check 'SRC/' and 'DST/' prefix when is_unibipartite graph is True.\n            if ntype.startswith(\"SRC/\"):\n                return self.get_ntype_id_from_src(ntype[4:])\n            elif ntype.startswith(\"DST/\"):\n                return self.get_ntype_id_from_dst(ntype[4:])\n            # If there is no prefix, fallback to normal lookup.\n\n        # Lookup both SRC and DST\n        if ntype is None:\n            if self.is_unibipartite or len(self._srctypes_invmap) != 1:\n                raise DGLError(\n                    \"Node type name must be specified if there are more than one \"\n                    \"node types.\"\n                )\n            return 0\n        ntid = self._srctypes_invmap.get(\n            ntype, self._dsttypes_invmap.get(ntype, None)\n        )\n        if ntid is None:\n            raise DGLError('Node type \"{}\" does not exist.'.format(ntype))\n        return ntid\n\n    def get_ntype_id_from_src(self, ntype):\n        \"\"\"Internal function to return the ID of the given SRC node type.\n\n        ntype can also be None. If so, there should be only one node type in the\n        SRC category. 
Callable even when the self graph is not uni-bipartite.\n\n        Parameters\n        ----------\n        ntype : str\n            Node type\n\n        Returns\n        -------\n        int\n        \"\"\"\n        if ntype is None:\n            if len(self._srctypes_invmap) != 1:\n                raise DGLError(\n                    \"SRC node type name must be specified if there are more than one \"\n                    \"SRC node types.\"\n                )\n            return next(iter(self._srctypes_invmap.values()))\n        ntid = self._srctypes_invmap.get(ntype, None)\n        if ntid is None:\n            raise DGLError('SRC node type \"{}\" does not exist.'.format(ntype))\n        return ntid\n\n    def get_ntype_id_from_dst(self, ntype):\n        \"\"\"Internal function to return the ID of the given DST node type.\n\n        ntype can also be None. If so, there should be only one node type in the\n        DST category. Callable even when the self graph is not uni-bipartite.\n\n        Parameters\n        ----------\n        ntype : str\n            Node type\n\n        Returns\n        -------\n        int\n        \"\"\"\n        if ntype is None:\n            if len(self._dsttypes_invmap) != 1:\n                raise DGLError(\n                    \"DST node type name must be specified if there are more than one \"\n                    \"DST node types.\"\n                )\n            return next(iter(self._dsttypes_invmap.values()))\n        ntid = self._dsttypes_invmap.get(ntype, None)\n        if ntid is None:\n            raise DGLError('DST node type \"{}\" does not exist.'.format(ntype))\n        return ntid\n\n    def get_etype_id(self, etype):\n        \"\"\"Return the id of the given edge type.\n\n        etype can also be None. 
If so, there should be only one edge type in the\n        graph.\n\n        Parameters\n        ----------\n        etype : str or tuple of str\n            Edge type\n\n        Returns\n        -------\n        int\n        \"\"\"\n        if etype is None:\n            if self._graph.number_of_etypes() != 1:\n                raise DGLError(\n                    \"Edge type name must be specified if there are more than one \"\n                    \"edge types.\"\n                )\n            return 0\n        etid = self._etypes_invmap.get(self.to_canonical_etype(etype), None)\n        if etid is None:\n            raise DGLError('Edge type \"{}\" does not exist.'.format(etype))\n        return etid\n\n    #################################################################\n    # Batching\n    #################################################################\n    @property\n    def batch_size(self):\n        \"\"\"Return the number of graphs in the batched graph.\n\n        Returns\n        -------\n        int\n            The Number of graphs in the batch. If the graph is not a batched one,\n            it will return 1.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Query for homogeneous graphs.\n\n        >>> g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))\n        >>> g1.batch_size\n        1\n        >>> g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))\n        >>> bg = dgl.batch([g1, g2])\n        >>> bg.batch_size\n        2\n\n        Query for heterogeneous graphs.\n\n        >>> hg1 = dgl.heterograph({\n        ...       ('user', 'plays', 'game') : (torch.tensor([0, 1]), torch.tensor([0, 0]))})\n        >>> hg1.batch_size\n        1\n        >>> hg2 = dgl.heterograph({\n        ...       
('user', 'plays', 'game') : (torch.tensor([0, 0]), torch.tensor([1, 0]))})\n        >>> bg = dgl.batch([hg1, hg2])\n        >>> bg.batch_size\n        2\n        \"\"\"\n        return len(self.batch_num_nodes(self.ntypes[0]))\n\n    def batch_num_nodes(self, ntype=None):\n        \"\"\"Return the number of nodes for each graph in the batch with the specified node type.\n\n        Parameters\n        ----------\n        ntype : str, optional\n            The node type for query. If the graph has multiple node types, one must\n            specify the argument. Otherwise, it can be omitted. If the graph is not a batched\n            one, it will return a list of length 1 that holds the number of nodes in the graph.\n\n        Returns\n        -------\n        Tensor\n            The number of nodes with the specified type for each graph in the batch. The i-th\n            element of it is the number of nodes with the specified type for the i-th graph.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Query for homogeneous graphs.\n\n        >>> g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))\n        >>> g1.batch_num_nodes()\n        tensor([4])\n        >>> g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))\n        >>> bg = dgl.batch([g1, g2])\n        >>> bg.batch_num_nodes()\n        tensor([4, 3])\n\n        Query for heterogeneous graphs.\n\n        >>> hg1 = dgl.heterograph({\n        ...       ('user', 'plays', 'game') : (torch.tensor([0, 1]), torch.tensor([0, 0]))})\n        >>> hg2 = dgl.heterograph({\n        ...       
('user', 'plays', 'game') : (torch.tensor([0, 0]), torch.tensor([1, 0]))})\n        >>> bg = dgl.batch([hg1, hg2])\n        >>> bg.batch_num_nodes('user')\n        tensor([2, 1])\n        \"\"\"\n        if ntype is not None and ntype not in self.ntypes:\n            raise DGLError(\n                \"Expect ntype in {}, got {}\".format(self.ntypes, ntype)\n            )\n\n        if self._batch_num_nodes is None:\n            self._batch_num_nodes = {}\n            for ty in self.ntypes:\n                bnn = F.copy_to(\n                    F.tensor([self.num_nodes(ty)], self.idtype), self.device\n                )\n                self._batch_num_nodes[ty] = bnn\n        if ntype is None:\n            if len(self.ntypes) != 1:\n                raise DGLError(\n                    \"Node type name must be specified if there are more than one \"\n                    \"node types.\"\n                )\n            ntype = self.ntypes[0]\n        return self._batch_num_nodes[ntype]\n\n    def set_batch_num_nodes(self, val):\n        \"\"\"Manually set the number of nodes for each graph in the batch with the specified node\n        type.\n\n        Parameters\n        ----------\n        val : Tensor or Mapping[str, Tensor]\n            The dictionary storing number of nodes for each graph in the batch for all node types.\n            If the graph has only one node type, ``val`` can also be a single array indicating the\n            number of nodes per graph in the batch.\n\n        Notes\n        -----\n        This API is always used together with ``set_batch_num_edges`` to specify batching\n        information of a graph. It does not check the correspondence between the graph structure\n        and batching information, so the user must guarantee there are no cross-graph edges in the\n        batch.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create 
a homogeneous graph.\n\n        >>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))\n\n        Manually set batch information\n\n        >>> g.set_batch_num_nodes(torch.tensor([3, 3]))\n        >>> g.set_batch_num_edges(torch.tensor([3, 3]))\n\n        Unbatch the graph.\n\n        >>> dgl.unbatch(g)\n        [Graph(num_nodes=3, num_edges=3,\n              ndata_schemes={}\n              edata_schemes={}), Graph(num_nodes=3, num_edges=3,\n              ndata_schemes={}\n              edata_schemes={})]\n\n        Create a heterogeneous graph.\n\n        >>> hg = dgl.heterograph({\n        ...      ('user', 'plays', 'game') : ([0, 1, 2, 3, 4, 5], [0, 1, 1, 3, 3, 2]),\n        ...      ('developer', 'develops', 'game') : ([0, 1, 2, 3], [1, 0, 3, 2])})\n\n        Manually set batch information.\n\n        >>> hg.set_batch_num_nodes({\n        ...     'user': torch.tensor([3, 3]),\n        ...     'game': torch.tensor([2, 2]),\n        ...     'developer': torch.tensor([2, 2])})\n        >>> hg.set_batch_num_edges({\n        ...     ('user', 'plays', 'game'): torch.tensor([3, 3]),\n        ...     
('developer', 'develops', 'game'): torch.tensor([2, 2])})\n\n        Unbatch the graph.\n\n        >>> g1, g2 = dgl.unbatch(hg)\n        >>> g1\n        Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},\n              num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},\n              metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])\n        >>> g2\n        Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},\n              num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},\n              metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])\n\n        See Also\n        --------\n        set_batch_num_edges\n        batch\n        unbatch\n        \"\"\"\n        val = utils.prepare_tensor_or_dict(self, val, \"batch_num_nodes\")\n        if not isinstance(val, Mapping):\n            if len(self.ntypes) != 1:\n                raise DGLError(\n                    \"Must provide a dictionary when there are multiple node types.\"\n                )\n            val = {self.ntypes[0]: val}\n        self._batch_num_nodes = val\n\n    def batch_num_edges(self, etype=None):\n        \"\"\"Return the number of edges for each graph in the batch with the specified edge type.\n\n        Parameters\n        ----------\n        etype : str or tuple of str, optional\n            The edge type for query, which can be an edge type (str) or a canonical edge type\n            (3-tuple of str). When an edge type appears in multiple canonical edge types, one\n            must use a canonical edge type. If the graph has multiple edge types, one must\n            specify the argument. Otherwise, it can be omitted.\n\n        Returns\n        -------\n        Tensor\n            The number of edges with the specified type for each graph in the batch. 
The i-th\n            element of it is the number of edges with the specified type for the i-th graph.\n            If the graph is not a batched one, it will return a list of length 1 that holds\n            the number of edges in the graph.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Query for homogeneous graphs.\n\n        >>> g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))\n        >>> g1.batch_num_edges()\n        tensor([3])\n        >>> g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))\n        >>> bg = dgl.batch([g1, g2])\n        >>> bg.batch_num_edges()\n        tensor([3, 4])\n\n        Query for heterogeneous graphs.\n\n        >>> hg1 = dgl.heterograph({\n        ...       ('user', 'plays', 'game') : (torch.tensor([0, 1]), torch.tensor([0, 0]))})\n        >>> hg2 = dgl.heterograph({\n        ...       ('user', 'plays', 'game') : (torch.tensor([0, 0]), torch.tensor([1, 0]))})\n        >>> bg = dgl.batch([hg1, hg2])\n        >>> bg.batch_num_edges('plays')\n        tensor([2, 2])\n        \"\"\"\n        if self._batch_num_edges is None:\n            self._batch_num_edges = {}\n            for ty in self.canonical_etypes:\n                bne = F.copy_to(\n                    F.tensor([self.num_edges(ty)], self.idtype), self.device\n                )\n                self._batch_num_edges[ty] = bne\n        if etype is None:\n            if len(self.etypes) != 1:\n                raise DGLError(\n                    \"Edge type name must be specified if there are more than one \"\n                    \"edge types.\"\n                )\n            etype = self.canonical_etypes[0]\n        else:\n            etype = self.to_canonical_etype(etype)\n        return self._batch_num_edges[etype]\n\n    def set_batch_num_edges(self, val):\n        \"\"\"Manually set the number of edges for each graph in the 
batch with the specified edge\n        type.\n\n        Parameters\n        ----------\n        val : Tensor or Mapping[str, Tensor]\n            The dictionary storing number of edges for each graph in the batch for all edge types.\n            If the graph has only one edge type, ``val`` can also be a single array indicating the\n            number of edges per graph in the batch.\n\n        Notes\n        -----\n        This API is always used together with ``set_batch_num_nodes`` to specify batching\n        information of a graph. It does not check the correspondence between the graph structure\n        and batching information, so the user must guarantee there are no cross-graph edges in the\n        batch.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph.\n\n        >>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))\n\n        Manually set batch information\n\n        >>> g.set_batch_num_nodes(torch.tensor([3, 3]))\n        >>> g.set_batch_num_edges(torch.tensor([3, 3]))\n\n        Unbatch the graph.\n\n        >>> dgl.unbatch(g)\n        [Graph(num_nodes=3, num_edges=3,\n              ndata_schemes={}\n              edata_schemes={}), Graph(num_nodes=3, num_edges=3,\n              ndata_schemes={}\n              edata_schemes={})]\n\n        Create a heterogeneous graph.\n\n        >>> hg = dgl.heterograph({\n        ...      ('user', 'plays', 'game') : ([0, 1, 2, 3, 4, 5], [0, 1, 1, 3, 3, 2]),\n        ...      ('developer', 'develops', 'game') : ([0, 1, 2, 3], [1, 0, 3, 2])})\n\n        Manually set batch information.\n\n        >>> hg.set_batch_num_nodes({\n        ...     'user': torch.tensor([3, 3]),\n        ...     'game': torch.tensor([2, 2]),\n        ...     'developer': torch.tensor([2, 2])})\n        >>> hg.set_batch_num_edges(\n        ...     
{('user', 'plays', 'game'): torch.tensor([3, 3]),\n        ...     ('developer', 'develops', 'game'): torch.tensor([2, 2])})\n\n        Unbatch the graph.\n\n        >>> g1, g2 = dgl.unbatch(hg)\n        >>> g1\n        Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},\n              num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},\n              metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])\n        >>> g2\n        Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},\n              num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},\n              metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])\n\n        See Also\n        --------\n        set_batch_num_nodes\n        batch\n        unbatch\n        \"\"\"\n        val = utils.prepare_tensor_or_dict(self, val, \"batch_num_edges\")\n        if not isinstance(val, Mapping):\n            if len(self.etypes) != 1:\n                raise DGLError(\n                    \"Must provide a dictionary when there are multiple edge types.\"\n                )\n            val = {self.canonical_etypes[0]: val}\n        self._batch_num_edges = val\n\n    #################################################################\n    # View\n    #################################################################\n\n    def get_node_storage(self, key, ntype=None):\n        \"\"\"Get storage object of node feature of type :attr:`ntype` and name :attr:`key`.\"\"\"\n        return self._node_frames[self.get_ntype_id(ntype)]._columns[key]\n\n    def get_edge_storage(self, key, etype=None):\n        \"\"\"Get storage object of edge feature of type :attr:`etype` and name :attr:`key`.\"\"\"\n        return self._edge_frames[self.get_etype_id(etype)]._columns[key]\n\n    @property\n    def nodes(self):\n        \"\"\"Return a node view\n\n        One can use it for:\n\n        1. 
Getting the node IDs for a single node type.\n        2. Setting/getting features for all nodes of a single node type.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph and a heterogeneous graph of two node types.\n\n        >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n        >>> hg = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... })\n\n        Get the node IDs of the homogeneous graph.\n\n        >>> g.nodes()\n        tensor([0, 1, 2])\n\n        Get the node IDs of the heterogeneous graph. With multiple node types introduced,\n        one needs to specify the node type for query.\n\n        >>> hg.nodes('user')\n        tensor([0, 1, 2, 3, 4])\n\n        Set and get a feature 'h' for all nodes of a single type in the heterogeneous graph.\n\n        >>> hg.nodes['user'].data['h'] = torch.ones(5, 1)\n        >>> hg.nodes['user'].data['h']\n        tensor([[1.], [1.], [1.], [1.], [1.]])\n\n        To set node features for a graph with a single node type, use :func:`DGLGraph.ndata`.\n\n        See Also\n        --------\n        ndata\n        \"\"\"\n        # Todo (Mufei) Replace the syntax g.nodes[...].ndata[...] with g.nodes[...][...]\n        return HeteroNodeView(self, self.get_ntype_id)\n\n    @property\n    def srcnodes(self):\n        \"\"\"Return a node view for source nodes\n\n        If the graph is a uni-bipartite graph (see :func:`is_unibipartite` for reference),\n        this is :func:`nodes` restricted to source node types. Otherwise, it is an alias\n        for :func:`nodes`.\n\n        One can use it for:\n\n        1. Getting the node IDs for a single node type.\n        2. 
Setting/getting features for all nodes of a single node type.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a uni-bipartite graph.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0]), torch.tensor([1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([1]), torch.tensor([2]))\n        ... })\n\n        Get the node IDs for source node types.\n\n        >>> g.srcnodes('user')\n        tensor([0])\n        >>> g.srcnodes('developer')\n        tensor([0, 1])\n\n        Set/get features for source node types.\n\n        >>> g.srcnodes['user'].data['h'] = torch.ones(1, 1)\n        >>> g.srcnodes['user'].data['h']\n        tensor([[1.]])\n\n        Create a graph that is not uni-bipartite.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0]), torch.tensor([1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([1]), torch.tensor([2]))\n        ... })\n\n        :func:`dgl.DGLGraph.srcnodes` falls back to :func:`dgl.DGLGraph.nodes` and one can\n        get the node IDs for both source and destination node types.\n\n        >>> g.srcnodes('game')\n        tensor([0, 1, 2])\n\n        One can also set/get features for destination node types in this case.\n\n        >>> g.srcnodes['game'].data['h'] = torch.ones(3, 1)\n        >>> g.srcnodes['game'].data['h']\n        tensor([[1.],\n                [1.],\n                [1.]])\n\n        See Also\n        --------\n        srcdata\n        \"\"\"\n        return HeteroNodeView(self, self.get_ntype_id_from_src)\n\n    @property\n    def dstnodes(self):\n        \"\"\"Return a node view for destination nodes\n\n        If the graph is a uni-bipartite graph (see :func:`is_unibipartite` for reference),\n        this is :func:`nodes` restricted to destination node types. 
Otherwise, it is an alias\n        for :func:`nodes`.\n\n        One can use it for:\n\n        1. Getting the node IDs for a single node type.\n        2. Setting/getting features for all nodes of a single node type.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a uni-bipartite graph.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0]), torch.tensor([1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([1]), torch.tensor([2]))\n        ... })\n\n        Get the node IDs for destination node types.\n\n        >>> g.dstnodes('game')\n        tensor([0, 1, 2])\n\n        Set/get features for destination node types.\n\n        >>> g.dstnodes['game'].data['h'] = torch.ones(3, 1)\n        >>> g.dstnodes['game'].data['h']\n        tensor([[1.],\n                [1.],\n                [1.]])\n\n        Create a graph that is not uni-bipartite.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0]), torch.tensor([1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([1]), torch.tensor([2]))\n        ... 
})\n\n        :func:`dgl.DGLGraph.dstnodes` falls back to :func:`dgl.DGLGraph.nodes` and one can\n        get the node IDs for both source and destination node types.\n\n        >>> g.dstnodes('developer')\n        tensor([0, 1])\n\n        One can also set/get features for source node types in this case.\n\n        >>> g.dstnodes['developer'].data['h'] = torch.ones(2, 1)\n        >>> g.dstnodes['developer'].data['h']\n        tensor([[1.],\n                [1.]])\n\n        See Also\n        --------\n        dstdata\n        \"\"\"\n        return HeteroNodeView(self, self.get_ntype_id_from_dst)\n\n    @property\n    def ndata(self):\n        \"\"\"Return a node data view for setting/getting node features\n\n        Let ``g`` be a DGLGraph. If ``g`` is a graph of a single node type, ``g.ndata[feat]``\n        returns the node feature associated with the name ``feat``. One can also set a node\n        feature associated with the name ``feat`` by setting ``g.ndata[feat]`` to a tensor.\n\n        If ``g`` is a graph of multiple node types, ``g.ndata[feat]`` returns a\n        dict[str, Tensor] mapping node types to the node features associated with the name\n        ``feat`` for the corresponding type. 
One can also set a node feature associated\n        with the name ``feat`` for some node type(s) by setting ``g.ndata[feat]`` to a\n        dictionary as described.\n\n        Notes\n        -----\n        For setting features, the device of the features must be the same as the device\n        of the graph.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Set and get feature 'h' for a graph of a single node type.\n\n        >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n        >>> g.ndata['h'] = torch.ones(3, 1)\n        >>> g.ndata['h']\n        tensor([[1.],\n                [1.],\n                [1.]])\n\n        Set and get feature 'h' for a graph of multiple node types.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),\n        ...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))\n        ... 
})\n        >>> g.ndata['h'] = {'game': torch.zeros(2, 1), 'player': torch.ones(3, 1)}\n        >>> g.ndata['h']\n        {'game': tensor([[0.], [0.]]),\n         'player': tensor([[1.], [1.], [1.]])}\n        >>> g.ndata['h'] = {'game': torch.ones(2, 1)}\n        >>> g.ndata['h']\n        {'game': tensor([[1.], [1.]]),\n         'player': tensor([[1.], [1.], [1.]])}\n\n        See Also\n        --------\n        nodes\n        \"\"\"\n        if len(self.ntypes) == 1:\n            ntid = self.get_ntype_id(None)\n            ntype = self.ntypes[0]\n            return HeteroNodeDataView(self, ntype, ntid, ALL)\n        else:\n            ntids = [self.get_ntype_id(ntype) for ntype in self.ntypes]\n            ntypes = self.ntypes\n            return HeteroNodeDataView(self, ntypes, ntids, ALL)\n\n    @property\n    def srcdata(self):\n        \"\"\"Return a node data view for setting/getting source node features.\n\n        Let ``g`` be a DGLGraph. If ``g`` is a graph of a single source node type,\n        ``g.srcdata[feat]`` returns the source node feature associated with the name ``feat``.\n        One can also set a source node feature associated with the name ``feat`` by\n        setting ``g.srcdata[feat]`` to a tensor.\n\n        If ``g`` is a graph of multiple source node types, ``g.srcdata[feat]`` returns a\n        dict[str, Tensor] mapping source node types to the node features associated with\n        the name ``feat`` for the corresponding type. 
One can also set a node feature\n        associated with the name ``feat`` for some source node type(s) by setting\n        ``g.srcdata[feat]`` to a dictionary as described.\n\n        Notes\n        -----\n        For setting features, the device of the features must be the same as the device\n        of the graph.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Set and get feature 'h' for a graph of a single source node type.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1]), torch.tensor([1, 2]))})\n        >>> g.srcdata['h'] = torch.ones(2, 1)\n        >>> g.srcdata['h']\n        tensor([[1.],\n                [1.]])\n\n        Set and get feature 'h' for a graph of multiple source node types.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([1, 2]), torch.tensor([3, 4])),\n        ...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))\n        ... 
})\n        >>> g.srcdata['h'] = {'user': torch.zeros(3, 1), 'player': torch.ones(3, 1)}\n        >>> g.srcdata['h']\n        {'player': tensor([[1.], [1.], [1.]]),\n         'user': tensor([[0.], [0.], [0.]])}\n        >>> g.srcdata['h'] = {'user': torch.ones(3, 1)}\n        >>> g.srcdata['h']\n        {'player': tensor([[1.], [1.], [1.]]),\n         'user': tensor([[1.], [1.], [1.]])}\n\n        See Also\n        --------\n        nodes\n        ndata\n        srcnodes\n        \"\"\"\n        if len(self.srctypes) == 1:\n            ntype = self.srctypes[0]\n            ntid = self.get_ntype_id_from_src(ntype)\n            return HeteroNodeDataView(self, ntype, ntid, ALL)\n        else:\n            ntypes = self.srctypes\n            ntids = [self.get_ntype_id_from_src(ntype) for ntype in ntypes]\n            return HeteroNodeDataView(self, ntypes, ntids, ALL)\n\n    @property\n    def dstdata(self):\n        \"\"\"Return a node data view for setting/getting destination node features.\n\n        Let ``g`` be a DGLGraph. If ``g`` is a graph of a single destination node type,\n        ``g.dstdata[feat]`` returns the destination node feature associated with the name\n        ``feat``. One can also set a destination node feature associated with the name\n        ``feat`` by setting ``g.dstdata[feat]`` to a tensor.\n\n        If ``g`` is a graph of multiple destination node types, ``g.dstdata[feat]`` returns a\n        dict[str, Tensor] mapping destination node types to the node features associated with\n        the name ``feat`` for the corresponding type. 
One can also set a node feature\n        associated with the name ``feat`` for some destination node type(s) by setting\n        ``g.dstdata[feat]`` to a dictionary as described.\n\n        Notes\n        -----\n        For setting features, the device of the features must be the same as the device\n        of the graph.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Set and get feature 'h' for a graph of a single destination node type.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1]), torch.tensor([1, 2]))})\n        >>> g.dstdata['h'] = torch.ones(3, 1)\n        >>> g.dstdata['h']\n        tensor([[1.],\n                [1.],\n                [1.]])\n\n        Set and get feature 'h' for a graph of multiple destination node types.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([1, 2]), torch.tensor([1, 2])),\n        ...     ('user', 'watches', 'movie'): (torch.tensor([2, 2]), torch.tensor([1, 1]))\n        ... 
})\n        >>> g.dstdata['h'] = {'game': torch.zeros(3, 1), 'movie': torch.ones(2, 1)}\n        >>> g.dstdata['h']\n        {'game': tensor([[0.], [0.], [0.]]),\n         'movie': tensor([[1.], [1.]])}\n        >>> g.dstdata['h'] = {'game': torch.ones(3, 1)}\n        >>> g.dstdata['h']\n        {'game': tensor([[1.], [1.], [1.]]),\n         'movie': tensor([[1.], [1.]])}\n\n        See Also\n        --------\n        nodes\n        ndata\n        dstnodes\n        \"\"\"\n        if len(self.dsttypes) == 1:\n            ntype = self.dsttypes[0]\n            ntid = self.get_ntype_id_from_dst(ntype)\n            return HeteroNodeDataView(self, ntype, ntid, ALL)\n        else:\n            ntypes = self.dsttypes\n            ntids = [self.get_ntype_id_from_dst(ntype) for ntype in ntypes]\n            return HeteroNodeDataView(self, ntypes, ntids, ALL)\n\n    @property\n    def edges(self):\n        \"\"\"Return an edge view\n\n        One can use it for:\n\n        1. Getting the edges for a single edge type. 
In this case, it can take the\n           following optional arguments:\n\n            - form : str, optional\n                  The return form, which can be one of the following:\n\n                  - ``'uv'`` (default): The returned result is a 2-tuple of 1D tensors\n                    :math:`(U, V)`, representing the source and destination nodes of all edges.\n                    For each :math:`i`, :math:`(U[i], V[i])` forms an edge.\n                  - ``'eid'``: The returned result is a 1D tensor :math:`EID`, representing\n                    the IDs of all edges.\n                  - ``'all'``: The returned result is a 3-tuple of 1D tensors :math:`(U, V, EID)`,\n                    representing the source nodes, destination nodes and IDs of all edges.\n                    For each :math:`i`, :math:`(U[i], V[i])` forms an edge with ID :math:`EID[i]`.\n            - order : str, optional\n                  The order of the returned edges, which can be one of the following:\n\n                  - ``'eid'`` (default): The edges are sorted by their IDs.\n                  - ``'srcdst'``: The edges are sorted first by their source node IDs and then\n                    by their destination node IDs to break ties.\n            - etype : str or tuple of str, optional\n                  The edge type for query, which can be an edge type (str) or a canonical edge\n                  type (3-tuple of str). When an edge type appears in multiple canonical edge\n                  types, one must use a canonical edge type. If the graph has multiple edge\n                  types, one must specify the argument. Otherwise, it can be omitted.\n        2. Setting/getting features for all edges of a single edge type. 
To set/get a feature\n           ``feat`` for edges of type ``etype`` in a graph ``g``, one can use\n           ``g.edges[etype].data[feat]``.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        **Get the Edges for a Single Edge Type**\n\n        Create a graph with a single edge type.\n\n        >>> g = dgl.graph((torch.tensor([1, 0, 0]), torch.tensor([1, 1, 0])))\n        >>> g.edges()\n        (tensor([1, 0, 0]), tensor([1, 1, 0]))\n\n        Specify a different value for :attr:`form` and :attr:`order`.\n\n        >>> g.edges(form='all', order='srcdst')\n        (tensor([0, 0, 1]), tensor([0, 1, 1]), tensor([2, 1, 0]))\n\n        For a graph of multiple edge types, it is required to specify the edge type in query.\n\n        >>> hg = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... })\n        >>> hg.edges(etype='plays')\n        (tensor([3, 4]), tensor([5, 6]))\n\n        **Set/get Features for All Edges of a Single Edge Type**\n\n        Create a heterogeneous graph of two edge types.\n\n        >>> hg = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... })\n\n        Set and get a feature 'h' for all edges of a single type in the heterogeneous graph.\n\n        >>> hg.edges['follows'].data['h'] = torch.ones(2, 1)\n        >>> hg.edges['follows'].data['h']\n        tensor([[1.], [1.]])\n\n        To set edge features for a graph with a single edge type, use :func:`DGLGraph.edata`.\n\n        See Also\n        --------\n        edata\n        \"\"\"\n        # TODO(Mufei): Replace the syntax g.edges[...].edata[...] 
with g.edges[...][...]\n        return HeteroEdgeView(self)\n\n    @property\n    def edata(self):\n        \"\"\"Return an edge data view for setting/getting edge features.\n\n        Let ``g`` be a DGLGraph. If ``g`` is a graph of a single edge type, ``g.edata[feat]``\n        returns the edge feature associated with the name ``feat``. One can also set an\n        edge feature associated with the name ``feat`` by setting ``g.edata[feat]`` to a tensor.\n\n        If ``g`` is a graph of multiple edge types, ``g.edata[feat]`` returns a\n        dict[str, Tensor] mapping canonical edge types to the edge features associated with\n        the name ``feat`` for the corresponding type. One can also set an edge feature\n        associated with the name ``feat`` for some edge type(s) by setting\n        ``g.edata[feat]`` to a dictionary as described.\n\n        Notes\n        -----\n        For setting features, the device of the features must be the same as the device\n        of the graph.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Set and get feature 'h' for a graph of a single edge type.\n\n        >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n        >>> g.edata['h'] = torch.ones(2, 1)\n        >>> g.edata['h']\n        tensor([[1.],\n                [1.]])\n\n        Set and get feature 'h' for a graph of multiple edge types.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),\n        ...     ('user', 'plays', 'user'): (torch.tensor([2, 2]), torch.tensor([1, 1])),\n        ...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))\n        ... })\n        >>> g.edata['h'] = {('user', 'follows', 'user'): torch.zeros(2, 1),\n        ...                 
('user', 'plays', 'user'): torch.ones(2, 1)}\n        >>> g.edata['h']\n        {('user', 'follows', 'user'): tensor([[0.], [0.]]),\n         ('user', 'plays', 'user'): tensor([[1.], [1.]])}\n        >>> g.edata['h'] = {('user', 'follows', 'user'): torch.ones(2, 1)}\n        >>> g.edata['h']\n        {('user', 'follows', 'user'): tensor([[1.], [1.]]),\n         ('user', 'plays', 'user'): tensor([[1.], [1.]])}\n\n        See Also\n        --------\n        edges\n        \"\"\"\n        if len(self.canonical_etypes) == 1:\n            return HeteroEdgeDataView(self, None, ALL)\n        else:\n            return HeteroEdgeDataView(self, self.canonical_etypes, ALL)\n\n    def _find_etypes(self, key):\n        etypes = [\n            i\n            for i, (srctype, etype, dsttype) in enumerate(\n                self._canonical_etypes\n            )\n            if (key[0] == SLICE_FULL or key[0] == srctype)\n            and (key[1] == SLICE_FULL or key[1] == etype)\n            and (key[2] == SLICE_FULL or key[2] == dsttype)\n        ]\n        return etypes\n\n    def __getitem__(self, key):\n        \"\"\"Return the relation slice of this graph.\n\n        You can get a relation slice with ``self[srctype, etype, dsttype]``, where\n        ``srctype``, ``etype``, and ``dsttype`` can be either a string or a full\n        slice (``:``) representing wildcard (i.e. any source/edge/destination type).\n\n        A relation slice is a homogeneous (with one node type and one edge type) or\n        bipartite (with two node types and one edge type) graph, transformed from\n        the original heterogeneous graph.\n\n        If there is only one canonical edge type found, then the returned relation\n        slice would be a subgraph induced from the original graph.  That is, it is\n        equivalent to ``self.edge_type_subgraph(etype)``.  
The node and edge features\n        of the returned graph would be shared with the original graph.\n\n        If there are multiple canonical edge types found, then the source/edge/destination\n        node types would be a *concatenation* of original node/edge types.  The\n        new source/destination node type would have the concatenation determined by\n        :func:`dgl.combine_names() <dgl.combine_names>` called on original source/destination\n        types as its name.  The source/destination node would be formed by concatenating the\n        common features of the original source/destination types.  Therefore they are not\n        shared with the original graph.  Edge type is similar.\n\n        Parameters\n        ----------\n        key : str or tuple\n            Either a string representing the edge type name, or a tuple in the form of\n            ``(srctype, etype, dsttype)`` where ``srctype``, ``etype``, ``dsttype`` can be either\n            strings representing type names or a full slice object (`:`).\n\n        Returns\n        -------\n        DGLGraph\n            The relation slice.\n\n        Notes\n        -----\n        This function returns a new graph.  Changing the content of this graph does not reflect\n        onto the original graph.\n\n        If the graph combines multiple node types or edge types together, it will have the\n        mapping of node/edge types and IDs from the new graph to the original graph.\n        The mappings have the name ``dgl.NTYPE``, ``dgl.NID``, ``dgl.ETYPE`` and ``dgl.EID``,\n        similar to the function :func:`dgl.to_homogeneous`.\n\n        Examples\n        --------\n        >>> g = dgl.heterograph({\n        ...     ('A1', 'AB1', 'B'): ([0, 1, 2], [1, 2, 3]),\n        ...     ('A1', 'AB2', 'B'): ([1, 2, 3], [3, 4, 5]),\n        ...     
('A2', 'AB2', 'B'): ([1, 3, 5], [2, 4, 6])})\n        >>> new_g = g['A1', :, 'B']         # combines all edge types between A1 and B\n        >>> new_g\n        Graph(num_nodes={'A1': 4, 'B': 7},\n              num_edges={('A1', 'AB1+AB2', 'B'): 6},\n              metagraph=[('A1', 'B', 'AB1+AB2')])\n        >>> new_g.edges()\n        (tensor([0, 1, 2, 1, 2, 3]), tensor([1, 2, 3, 3, 4, 5]))\n        >>> new_g2 = g[:, 'AB2', 'B']        # combines all node types that are source of AB2\n        >>> new_g2\n        Graph(num_nodes={'A1+A2': 10, 'B': 7},\n              num_edges={('A1+A2', 'AB2+AB2', 'B'): 6},\n              metagraph=[('A1+A2', 'B', 'AB2+AB2')])\n        >>> new_g2.edges()\n        (tensor([1, 2, 3, 5, 7, 9]), tensor([3, 4, 5, 2, 4, 6]))\n\n        If a combination of multiple node types and edge types occur, one can find\n        the mapping to the original node type and IDs like the following:\n\n        >>> new_g1.edges['AB1+AB2'].data[dgl.EID]\n        tensor([0, 1, 2, 0, 1, 2])\n        >>> new_g1.edges['AB1+AB2'].data[dgl.ETYPE]\n        tensor([0, 0, 0, 1, 1, 1])\n        >>> new_g2.nodes['A1+A2'].data[dgl.NID]\n        tensor([0, 1, 2, 3, 0, 1, 2, 3, 4, 5])\n        >>> new_g2.nodes['A1+A2'].data[dgl.NTYPE]\n        tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])\n        \"\"\"\n        err_msg = (\n            \"Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] \"\n            + \"to get view of one relation type. Use : to slice multiple types (e.g. \"\n            + \"G['srctype', :, 'dsttype']).\"\n        )\n\n        orig_key = key\n        if not isinstance(key, tuple):\n            key = (SLICE_FULL, key, SLICE_FULL)\n\n        if len(key) != 3:\n            raise DGLError(err_msg)\n\n        etypes = self._find_etypes(key)\n\n        if len(etypes) == 0:\n            raise DGLError(\n                'Invalid key \"{}\". 
Must be one of the edge types.'.format(\n                    orig_key\n                )\n            )\n\n        if len(etypes) == 1:\n            # no ambiguity: return the unitgraph itself\n            srctype, etype, dsttype = self._canonical_etypes[etypes[0]]\n            stid = self.get_ntype_id_from_src(srctype)\n            etid = self.get_etype_id((srctype, etype, dsttype))\n            dtid = self.get_ntype_id_from_dst(dsttype)\n            new_g = self._graph.get_relation_graph(etid)\n\n            if stid == dtid:\n                new_ntypes = [srctype]\n                new_nframes = [self._node_frames[stid]]\n            else:\n                new_ntypes = ([srctype], [dsttype])\n                new_nframes = [self._node_frames[stid], self._node_frames[dtid]]\n            new_etypes = [etype]\n            new_eframes = [self._edge_frames[etid]]\n\n            return self.__class__(\n                new_g, new_ntypes, new_etypes, new_nframes, new_eframes\n            )\n        else:\n            flat = self._graph.flatten_relations(etypes)\n            new_g = flat.graph\n\n            # merge frames\n            stids = flat.induced_srctype_set.asnumpy()\n            dtids = flat.induced_dsttype_set.asnumpy()\n            etids = flat.induced_etype_set.asnumpy()\n            new_ntypes = [combine_names(self.ntypes, stids)]\n            if new_g.number_of_ntypes() == 2:\n                new_ntypes.append(combine_names(self.ntypes, dtids))\n                new_nframes = [\n                    combine_frames(self._node_frames, stids),\n                    combine_frames(self._node_frames, dtids),\n                ]\n            else:\n                assert np.array_equal(stids, dtids)\n                new_nframes = [combine_frames(self._node_frames, stids)]\n            new_etypes = [combine_names(self.etypes, etids)]\n            new_eframes = [combine_frames(self._edge_frames, etids)]\n\n            # create new heterograph\n            new_hg = 
self.__class__(\n                new_g, new_ntypes, new_etypes, new_nframes, new_eframes\n            )\n\n            src = new_ntypes[0]\n            dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src\n            # put the parent node/edge type and IDs\n            new_hg.nodes[src].data[NTYPE] = F.zerocopy_from_dgl_ndarray(\n                flat.induced_srctype\n            )\n            new_hg.nodes[src].data[NID] = F.zerocopy_from_dgl_ndarray(\n                flat.induced_srcid\n            )\n            new_hg.nodes[dst].data[NTYPE] = F.zerocopy_from_dgl_ndarray(\n                flat.induced_dsttype\n            )\n            new_hg.nodes[dst].data[NID] = F.zerocopy_from_dgl_ndarray(\n                flat.induced_dstid\n            )\n            new_hg.edata[ETYPE] = F.zerocopy_from_dgl_ndarray(\n                flat.induced_etype\n            )\n            new_hg.edata[EID] = F.zerocopy_from_dgl_ndarray(flat.induced_eid)\n\n            return new_hg\n\n    #################################################################\n    # Graph query\n    #################################################################\n\n    def number_of_nodes(self, ntype=None):\n        \"\"\"Alias of :meth:`num_nodes`\"\"\"\n        return self.num_nodes(ntype)\n\n    def num_nodes(self, ntype=None):\n        \"\"\"Return the number of nodes in the graph.\n\n        Parameters\n        ----------\n        ntype : str, optional\n            The node type name. If given, it returns the number of nodes of the\n            type. If not given (default), it returns the total number of nodes of all types.\n\n        Returns\n        -------\n        int\n            The number of nodes.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a graph with two node types -- 'user' and 'game'.\n\n        >>> g = dgl.heterograph({\n        ...     
('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... })\n\n        Query for the number of nodes.\n\n        >>> g.num_nodes('user')\n        5\n        >>> g.num_nodes('game')\n        7\n        >>> g.num_nodes()\n        12\n        \"\"\"\n        if ntype is None:\n            return sum(\n                [\n                    self._graph.num_nodes(ntid)\n                    for ntid in range(len(self.ntypes))\n                ]\n            )\n        else:\n            return self._graph.num_nodes(self.get_ntype_id(ntype))\n\n    def number_of_src_nodes(self, ntype=None):\n        \"\"\"Alias of :meth:`num_src_nodes`\"\"\"\n        return self.num_src_nodes(ntype)\n\n    def num_src_nodes(self, ntype=None):\n        \"\"\"Return the number of source nodes in the graph.\n\n        If the graph can further divide its node types into two subsets A and B where\n        all the edges are from nodes of types in A to nodes of types in B, we call\n        this graph a *uni-bipartite* graph and the nodes in A being the *source*\n        nodes and the ones in B being the *destination* nodes. If the graph is not\n        uni-bipartite, the source and destination nodes are just the entire set of\n        nodes in the graph.\n\n        Parameters\n        ----------\n        ntype : str, optional\n            The source node type name. If given, it returns the number of nodes for\n            the source node type. 
If not given (default), it returns the number of\n            nodes summed over all source node types.\n\n        Returns\n        -------\n        int\n            The number of nodes\n\n        See Also\n        --------\n        num_dst_nodes\n        is_unibipartite\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph for query.\n\n        >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n        >>> g.num_src_nodes()\n        3\n\n        Create a heterogeneous graph with two source node types -- 'developer' and 'user'.\n\n        >>> g = dgl.heterograph({\n        ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... })\n\n        Query for the number of nodes.\n\n        >>> g.num_src_nodes('developer')\n        2\n        >>> g.num_src_nodes('user')\n        5\n        >>> g.num_src_nodes()\n        7\n        \"\"\"\n        if ntype is None:\n            return sum(\n                [\n                    self._graph.num_nodes(self.get_ntype_id_from_src(nty))\n                    for nty in self.srctypes\n                ]\n            )\n        else:\n            return self._graph.num_nodes(self.get_ntype_id_from_src(ntype))\n\n    def number_of_dst_nodes(self, ntype=None):\n        \"\"\"Alias of :func:`num_dst_nodes`\"\"\"\n        return self.num_dst_nodes(ntype)\n\n    def num_dst_nodes(self, ntype=None):\n        \"\"\"Return the number of destination nodes in the graph.\n\n        If the graph can further divide its node types into two subsets A and B where\n        all the edges are from nodes of types in A to nodes of types in B, we call\n        this graph a *uni-bipartite* graph and the nodes in A being the *source*\n        nodes and the ones in B being the 
*destination* nodes. If the graph is not\n        uni-bipartite, the source and destination nodes are just the entire set of\n        nodes in the graph.\n\n        Parameters\n        ----------\n        ntype : str, optional\n            The destination node type name. If given, it returns the number of nodes of\n            the destination node type. If not given (default), it returns the number of\n            nodes summed over all the destination node types.\n\n        Returns\n        -------\n        int\n            The number of nodes\n\n        See Also\n        --------\n        num_src_nodes\n        is_unibipartite\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph for query.\n\n        >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n        >>> g.num_dst_nodes()\n        3\n\n        Create a heterogeneous graph with two destination node types -- 'user' and 'game'.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... 
})\n\n        Query for the number of nodes.\n\n        >>> g.num_dst_nodes('user')\n        5\n        >>> g.num_dst_nodes('game')\n        7\n        >>> g.num_dst_nodes()\n        12\n        \"\"\"\n        if ntype is None:\n            return sum(\n                [\n                    self._graph.num_nodes(self.get_ntype_id_from_dst(nty))\n                    for nty in self.dsttypes\n                ]\n            )\n        else:\n            return self._graph.num_nodes(self.get_ntype_id_from_dst(ntype))\n\n    def number_of_edges(self, etype=None):\n        \"\"\"Alias of :func:`num_edges`\"\"\"\n        return self.num_edges(etype)\n\n    def num_edges(self, etype=None):\n        \"\"\"Return the number of edges in the graph.\n\n        Parameters\n        ----------\n        etype : str or (str, str, str), optional\n            The type name of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            If not provided, return the total number of edges regardless of the types\n            in the graph.\n\n        Returns\n        -------\n        int\n            The number of edges.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a graph with three canonical edge types.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([1, 3]), torch.tensor([2, 3]))\n        ... 
})\n\n        Query for the number of edges.\n\n        >>> g.num_edges('plays')\n        2\n        >>> g.num_edges()\n        7\n\n        Use a canonical edge type instead when there is ambiguity for an edge type.\n\n        >>> g.num_edges(('user', 'follows', 'user'))\n        2\n        >>> g.num_edges(('user', 'follows', 'game'))\n        3\n        \"\"\"\n        if etype is None:\n            return sum(\n                [\n                    self._graph.num_edges(etid)\n                    for etid in range(len(self.canonical_etypes))\n                ]\n            )\n        else:\n            return self._graph.num_edges(self.get_etype_id(etype))\n\n    @property\n    def is_multigraph(self):\n        \"\"\"Return whether the graph is a multigraph with parallel edges.\n\n        A multigraph has more than one edges between the same pair of nodes, called\n        *parallel edges*.  For heterogeneous graphs, parallel edge further requires\n        the canonical edge type to be the same (see :meth:`canonical_etypes` for the\n        definition).\n\n        Returns\n        -------\n        bool\n            True if the graph is a multigraph.\n\n        Notes\n        -----\n        Checking whether the graph is a multigraph could be expensive for a large one.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Check for homogeneous graphs.\n\n        >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 3])))\n        >>> g.is_multigraph\n        False\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 3, 3])))\n        >>> g.is_multigraph\n        True\n\n        Check for heterogeneous graphs.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))\n        ... 
})\n        >>> g.is_multigraph\n        False\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1, 1]), torch.tensor([1, 2, 2])),\n        ...     ('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))\n        ... })\n        >>> g.is_multigraph\n        True\n        \"\"\"\n        return self._graph.is_multigraph()\n\n    @property\n    def is_homogeneous(self):\n        \"\"\"Return whether the graph is a homogeneous graph.\n\n        A homogeneous graph only has one node type and one edge type.\n\n        Returns\n        -------\n        bool\n            True if the graph is a homogeneous graph.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph for check.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 1, 1]), torch.tensor([1, 0, 2, 3])))\n        >>> g.is_homogeneous\n        True\n\n        Create a heterogeneous graph for check.\n\n        If the graph has multiple edge types, one need to specify the edge type.\n\n        >>> g = dgl.heterograph({\n        ...     
('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))})\n        >>> g.is_homogeneous\n        False\n        \"\"\"\n        return len(self.ntypes) == 1 and len(self.etypes) == 1\n\n    @property\n    def idtype(self):\n        \"\"\"The data type for storing the structure-related graph information\n        such as node and edge IDs.\n\n        Returns\n        -------\n        Framework-specific device object\n            For example, this can be ``torch.int32`` or ``torch.int64`` for PyTorch.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> src_ids = torch.tensor([0, 0, 1])\n        >>> dst_ids = torch.tensor([1, 2, 2])\n        >>> g = dgl.graph((src_ids, dst_ids))\n        >>> g.idtype\n        torch.int64\n        >>> g = dgl.graph((src_ids, dst_ids), idtype=torch.int32)\n        >>> g.idtype\n        torch.int32\n\n        See Also\n        --------\n        long\n        int\n        \"\"\"\n        return getattr(F, self._graph.dtype)\n\n    @property\n    def _idtype_str(self):\n        \"\"\"The dtype of graph index\n\n        Returns\n        -------\n        backend dtype object\n            th.int32/th.int64 or tf.int32/tf.int64 etc.\n        \"\"\"\n        return self._graph.dtype\n\n    def has_nodes(self, vid, ntype=None):\n        \"\"\"Return whether the graph contains the given nodes.\n\n        Parameters\n        ----------\n        vid : node ID(s)\n            The nodes IDs. The allowed nodes ID formats are:\n\n            * ``int``: The ID of a single node.\n            * Int Tensor: Each element is a node ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n\n        ntype : str, optional\n            The node type name. 
Can be omitted if there is\n            only one type of nodes in the graph.\n\n        Returns\n        -------\n        bool or bool Tensor\n            A tensor of bool flags where each element is True if the node is in the graph.\n            If the input is a single node, return one bool value.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a graph with two node types -- 'user' and 'game'.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([0, 1]))\n        ... })\n\n        Query for the nodes.\n\n        >>> g.has_nodes(0, 'user')\n        True\n        >>> g.has_nodes(3, 'game')\n        False\n        >>> g.has_nodes(torch.tensor([3, 0, 1]), 'game')\n        tensor([False,  True,  True])\n        \"\"\"\n        vid_tensor = utils.prepare_tensor(self, vid, \"vid\")\n        if len(vid_tensor) > 0 and F.as_scalar(F.min(vid_tensor, 0)) < 0 < len(\n            vid_tensor\n        ):\n            raise DGLError(\"All IDs must be non-negative integers.\")\n        ret = self._graph.has_nodes(self.get_ntype_id(ntype), vid_tensor)\n        if isinstance(vid, numbers.Integral):\n            return bool(F.as_scalar(ret))\n        else:\n            return F.astype(ret, F.bool)\n\n    def has_edges_between(self, u, v, etype=None):\n        \"\"\"Return whether the graph contains the given edges.\n\n        Parameters\n        ----------\n        u : node IDs\n            The source node IDs of the edges. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. 
The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n\n        v : node IDs\n            The destination node IDs of the edges. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n\n        etype : str or (str, str, str), optional\n            The type names of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n\n        Returns\n        -------\n        bool or bool Tensor\n            A tensor of bool flags where each element is True if the edge is in the graph.\n            If the input is a single pair of nodes, return one bool value.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 1, 1]), torch.tensor([1, 0, 2, 3])))\n\n        Query for the edges.\n\n        >>> g.has_edges_between(1, 2)\n        True\n        >>> g.has_edges_between(torch.tensor([1, 2]), torch.tensor([2, 3]))\n        tensor([ True, False])\n\n        If the graph has multiple edge types, one needs to specify the edge type.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])),\n        ...     
('user', 'plays', 'game'): (torch.tensor([1, 3]), torch.tensor([2, 3]))\n        ... })\n        >>> g.has_edges_between(torch.tensor([1, 2]), torch.tensor([2, 3]), 'plays')\n        tensor([ True, False])\n\n        Use a canonical edge type instead when there is ambiguity for an edge type.\n\n        >>> g.has_edges_between(torch.tensor([1, 2]), torch.tensor([2, 3]),\n        ...                     ('user', 'follows', 'user'))\n        tensor([ True, False])\n        >>> g.has_edges_between(torch.tensor([1, 2]), torch.tensor([2, 3]),\n        ...                     ('user', 'follows', 'game'))\n        tensor([True, True])\n        \"\"\"\n        srctype, _, dsttype = self.to_canonical_etype(etype)\n        u_tensor = utils.prepare_tensor(self, u, \"u\")\n        if F.as_scalar(\n            F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)\n        ) != len(u_tensor):\n            raise DGLError(\"u contains invalid node IDs\")\n        v_tensor = utils.prepare_tensor(self, v, \"v\")\n        if F.as_scalar(\n            F.sum(self.has_nodes(v_tensor, ntype=dsttype), dim=0)\n        ) != len(v_tensor):\n            raise DGLError(\"v contains invalid node IDs\")\n        ret = self._graph.has_edges_between(\n            self.get_etype_id(etype), u_tensor, v_tensor\n        )\n        if isinstance(u, numbers.Integral) and isinstance(v, numbers.Integral):\n            return bool(F.as_scalar(ret))\n        else:\n            return F.astype(ret, F.bool)\n\n    def predecessors(self, v, etype=None):\n        \"\"\"Return the predecessor(s) of a particular node with the specified edge type.\n\n        Node ``u`` is a predecessor of node ``v`` if there is an edge ``(u, v)`` with type\n        ``etype`` in the graph.\n\n        Parameters\n        ----------\n        v : int\n            The node ID. 
If the graph has multiple edge types, the ID is for the destination\n            type corresponding to the edge type.\n        etype : str or (str, str, str), optional\n            The type names of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n\n        Returns\n        -------\n        Tensor\n            The predecessors of :attr:`v` with the specified edge type.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 2, 3])))\n\n        Query for node 1.\n\n        >>> g.predecessors(1)\n        tensor([0, 0])\n\n        For a graph of multiple edge types, it is required to specify the edge type in query.\n\n        >>> hg = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... 
})\n        >>> hg.predecessors(1, etype='follows')\n        tensor([0])\n\n        See Also\n        --------\n        successors\n        \"\"\"\n        if not self.has_nodes(v, self.to_canonical_etype(etype)[-1]):\n            raise DGLError(\"Non-existing node ID {}\".format(v))\n        return self._graph.predecessors(self.get_etype_id(etype), v)\n\n    def successors(self, v, etype=None):\n        \"\"\"Return the successor(s) of a particular node with the specified edge type.\n\n        Node ``u`` is a successor of node ``v`` if there is an edge ``(v, u)`` with type\n        ``etype`` in the graph.\n\n        Parameters\n        ----------\n        v : int\n            The node ID. If the graph has multiple edge types, the ID is for the source\n            type corresponding to the edge type.\n        etype : str or (str, str, str), optional\n            The type names of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Returns\n        -------\n        Tensor\n            The successors of :attr:`v` with the specified edge type.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 2, 3])))\n\n        Query for node 1.\n\n        >>> g.successors(1)\n        tensor([2, 3])\n\n        For a graph of multiple edge types, it is required to specify the edge type in query.\n\n        >>> hg = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     
('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... })\n        >>> hg.successors(1, etype='follows')\n        tensor([2])\n\n        See Also\n        --------\n        predecessors\n        \"\"\"\n        if not self.has_nodes(v, self.to_canonical_etype(etype)[0]):\n            raise DGLError(\"Non-existing node ID {}\".format(v))\n        return self._graph.successors(self.get_etype_id(etype), v)\n\n    def edge_ids(self, u, v, return_uv=False, etype=None):\n        \"\"\"Return the edge ID(s) given the two endpoints of the edge(s).\n\n        Parameters\n        ----------\n        u : node IDs\n            The source node IDs of the edges. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n\n        v : node IDs\n            The destination node IDs of the edges. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n        return_uv : bool, optional\n            Whether to return the source and destination node IDs along with the edges. If\n            False (default), it assumes that the graph is a simple graph and there is only\n            one edge from one node to another. If True, there can be multiple edges found\n            from one node to another.\n        etype : str or (str, str, str), optional\n            The type names of the edges. 
The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Returns\n        -------\n        Tensor, or (Tensor, Tensor, Tensor)\n\n            * If ``return_uv=False``, it returns the edge IDs in a tensor, where the i-th\n              element is the ID of the edge ``(u[i], v[i])``.\n            * If ``return_uv=True``, it returns a tuple of three 1D tensors ``(eu, ev, e)``.\n              ``e[i]`` is the ID of an edge from ``eu[i]`` to ``ev[i]``. It returns all edges\n              (including parallel edges) from ``eu[i]`` to ``ev[i]`` in this case.\n\n        Notes\n        -----\n        If the graph is a simple graph, ``return_uv=False``, and there are no edges\n        between some pairs of node(s), it will raise an error.\n\n        If the graph is a multigraph, ``return_uv=False``, and there are multiple edges\n        between some pairs of node(s), it returns an arbitrary one from them.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 1, 1, 1]), torch.tensor([1, 0, 2, 3, 2])))\n\n        Query for the edges.\n\n        >>> g.edge_ids(0, 0)\n        1\n        >>> g.edge_ids(torch.tensor([1, 0]), torch.tensor([3, 1]))\n        tensor([3, 0])\n\n        Get all edges for pairs of nodes.\n\n        >>> g.edge_ids(torch.tensor([1, 0]), torch.tensor([3, 1]), return_uv=True)\n        (tensor([1, 0]), tensor([3, 1]), tensor([3, 0]))\n\n        If the graph has multiple edge types, one need to specify the edge type.\n\n        >>> g = dgl.heterograph({\n        ...     
('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'follows', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([1, 3]), torch.tensor([2, 3]))\n        ... })\n        >>> g.edge_ids(torch.tensor([1]), torch.tensor([2]), etype='plays')\n        tensor([0])\n\n        Use a canonical edge type instead when there is ambiguity for an edge type.\n\n        >>> g.edge_ids(torch.tensor([0, 1]), torch.tensor([1, 2]),\n        ...            etype=('user', 'follows', 'user'))\n        tensor([0, 1])\n        >>> g.edge_ids(torch.tensor([1, 2]), torch.tensor([2, 3]),\n        ...            etype=('user', 'follows', 'game'))\n        tensor([1, 2])\n        \"\"\"\n        is_int = isinstance(u, numbers.Integral) and isinstance(\n            v, numbers.Integral\n        )\n        srctype, _, dsttype = self.to_canonical_etype(etype)\n        u = utils.prepare_tensor(self, u, \"u\")\n        if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(\n            u\n        ):\n            raise DGLError(\"u contains invalid node IDs\")\n        v = utils.prepare_tensor(self, v, \"v\")\n        if F.as_scalar(F.sum(self.has_nodes(v, ntype=dsttype), dim=0)) != len(\n            v\n        ):\n            raise DGLError(\"v contains invalid node IDs\")\n\n        if return_uv:\n            return self._graph.edge_ids_all(self.get_etype_id(etype), u, v)\n        else:\n            eid = self._graph.edge_ids_one(self.get_etype_id(etype), u, v)\n            is_neg_one = F.equal(eid, -1)\n            if F.as_scalar(F.sum(is_neg_one, 0)):\n                # Raise error since some (u, v) pair is not a valid edge.\n                idx = F.nonzero_1d(is_neg_one)\n                raise DGLError(\n                    \"Error: (%d, %d) does not form a valid edge.\"\n                    % (\n                        F.as_scalar(F.gather_row(u, idx)),\n           
             F.as_scalar(F.gather_row(v, idx)),\n                    )\n                )\n            return F.as_scalar(eid) if is_int else eid\n\n    def find_edges(self, eid, etype=None):\n        \"\"\"Return the source and destination node ID(s) given the edge ID(s).\n\n        Parameters\n        ----------\n        eid : edge ID(s)\n            The edge IDs. The allowed formats are:\n\n            * ``int``: A single ID.\n            * Int Tensor: Each element is an ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is an ID.\n\n        etype : str or (str, str, str), optional\n            The type names of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Returns\n        -------\n        Tensor\n            The source node IDs of the edges. The i-th element is the source node ID of\n            the i-th edge.\n        Tensor\n            The destination node IDs of the edges. The i-th element is the destination node\n            ID of the i-th edge.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 1, 1]), torch.tensor([1, 0, 2, 3])))\n\n        Find edges of IDs 0 and 2.\n\n        >>> g.find_edges(torch.tensor([0, 2]))\n        (tensor([0, 1]), tensor([1, 2]))\n\n        For a graph of multiple edge types, it is required to specify the edge type in query.\n\n        >>> hg = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...   
  ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... })\n        >>> hg.find_edges(torch.tensor([1, 0]), 'plays')\n        (tensor([4, 3]), tensor([6, 5]))\n        \"\"\"\n        eid = utils.prepare_tensor(self, eid, \"eid\")\n        if len(eid) > 0:\n            min_eid = F.as_scalar(F.min(eid, 0))\n            if min_eid < 0:\n                raise DGLError(\"Invalid edge ID {:d}\".format(min_eid))\n            max_eid = F.as_scalar(F.max(eid, 0))\n            if max_eid >= self.num_edges(etype):\n                raise DGLError(\"Invalid edge ID {:d}\".format(max_eid))\n\n        if len(eid) == 0:\n            empty = F.copy_to(F.tensor([], self.idtype), self.device)\n            return empty, empty\n        src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid)\n        return src, dst\n\n    def in_edges(self, v, form=\"uv\", etype=None):\n        \"\"\"Return the incoming edges of the given nodes.\n\n        Parameters\n        ----------\n        v : node ID(s)\n            The node IDs. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n        form : str, optional\n            The result format, which can be one of the following:\n\n            - ``'eid'``: The returned result is a 1D tensor :math:`EID`, representing\n              the IDs of all edges.\n            - ``'uv'`` (default): The returned result is a 2-tuple of 1D tensors :math:`(U, V)`,\n              representing the source and destination nodes of all edges. 
For each :math:`i`,\n              :math:`(U[i], V[i])` forms an edge.\n            - ``'all'``: The returned result is a 3-tuple of 1D tensors :math:`(U, V, EID)`,\n              representing the source nodes, destination nodes and IDs of all edges.\n              For each :math:`i`, :math:`(U[i], V[i])` forms an edge with ID :math:`EID[i]`.\n        etype : str or (str, str, str), optional\n            The type names of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Returns\n        -------\n        Tensor or (Tensor, Tensor) or (Tensor, Tensor, Tensor)\n            All incoming edges of the nodes with the specified type. For a description of the\n            returned result, see the description of :attr:`form`.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 1, 1]), torch.tensor([1, 0, 2, 3])))\n\n        Query for the nodes 1 and 0.\n\n        >>> g.in_edges(torch.tensor([1, 0]))\n        (tensor([0, 0]), tensor([1, 0]))\n\n        Specify a different value for :attr:`form`.\n\n        >>> g.in_edges(torch.tensor([1, 0]), form='all')\n        (tensor([0, 0]), tensor([1, 0]), tensor([0, 1]))\n\n        For a graph of multiple edge types, it is required to specify the edge type in query.\n\n        >>> hg = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... 
})\n        >>> hg.in_edges(torch.tensor([1, 0]), etype='follows')\n        (tensor([0]), tensor([1]))\n\n        See Also\n        --------\n        edges\n        out_edges\n        \"\"\"\n        v = utils.prepare_tensor(self, v, \"v\")\n        src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v)\n        if form == \"all\":\n            return src, dst, eid\n        elif form == \"uv\":\n            return src, dst\n        elif form == \"eid\":\n            return eid\n        else:\n            raise DGLError(\n                'Invalid form: {}. Must be \"all\", \"uv\" or \"eid\".'.format(form)\n            )\n\n    def out_edges(self, u, form=\"uv\", etype=None):\n        \"\"\"Return the outgoing edges of the given nodes.\n\n        Parameters\n        ----------\n        u : node ID(s)\n            The node IDs. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n        form : str, optional\n            The return form, which can be one of the following:\n\n            - ``'eid'``: The returned result is a 1D tensor :math:`EID`, representing\n              the IDs of all edges.\n            - ``'uv'`` (default): The returned result is a 2-tuple of 1D tensors :math:`(U, V)`,\n              representing the source and destination nodes of all edges. For each :math:`i`,\n              :math:`(U[i], V[i])` forms an edge.\n            - ``'all'``: The returned result is a 3-tuple of 1D tensors :math:`(U, V, EID)`,\n              representing the source nodes, destination nodes and IDs of all edges.\n              For each :math:`i`, :math:`(U[i], V[i])` forms an edge with ID :math:`EID[i]`.\n        etype : str or (str, str, str), optional\n            The type names of the edges. 
The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Returns\n        -------\n        Tensor or (Tensor, Tensor) or (Tensor, Tensor, Tensor)\n            All outgoing edges of the nodes with the specified type. For a description of the\n            returned result, see the description of :attr:`form`.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 1, 1]), torch.tensor([1, 0, 2, 3])))\n\n        Query for the nodes 1 and 2.\n\n        >>> g.out_edges(torch.tensor([1, 2]))\n        (tensor([1, 1]), tensor([2, 3]))\n\n        Specify a different value for :attr:`form`.\n\n        >>> g.out_edges(torch.tensor([1, 2]), form='all')\n        (tensor([1, 1]), tensor([2, 3]), tensor([2, 3]))\n\n        For a graph of multiple edge types, it is required to specify the edge type in query.\n\n        >>> hg = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... 
})\n        >>> hg.out_edges(torch.tensor([1, 2]), etype='follows')\n        (tensor([1]), tensor([2]))\n\n        See Also\n        --------\n        edges\n        in_edges\n        \"\"\"\n        u = utils.prepare_tensor(self, u, \"u\")\n        srctype, _, _ = self.to_canonical_etype(etype)\n        if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(\n            u\n        ):\n            raise DGLError(\"u contains invalid node IDs\")\n        src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), u)\n        if form == \"all\":\n            return src, dst, eid\n        elif form == \"uv\":\n            return src, dst\n        elif form == \"eid\":\n            return eid\n        else:\n            raise DGLError(\n                'Invalid form: {}. Must be \"all\", \"uv\" or \"eid\".'.format(form)\n            )\n\n    def all_edges(self, form=\"uv\", order=\"eid\", etype=None):\n        \"\"\"Return all edges with the specified edge type.\n\n        Parameters\n        ----------\n        form : str, optional\n            The return form, which can be one of the following:\n\n            - ``'eid'``: The returned result is a 1D tensor :math:`EID`, representing\n              the IDs of all edges.\n            - ``'uv'`` (default): The returned result is a 2-tuple of 1D tensors :math:`(U, V)`,\n              representing the source and destination nodes of all edges. 
For each :math:`i`,\n              :math:`(U[i], V[i])` forms an edge.\n            - ``'all'``: The returned result is a 3-tuple of 1D tensors :math:`(U, V, EID)`,\n              representing the source nodes, destination nodes and IDs of all edges.\n              For each :math:`i`, :math:`(U[i], V[i])` forms an edge with ID :math:`EID[i]`.\n        order : str, optional\n            The order of the returned edges, which can be one of the following:\n\n            - ``'srcdst'``: The edges are sorted first by their source node IDs and then\n              by their destination node IDs to break ties.\n            - ``'eid'`` (default): The edges are sorted by their IDs.\n        etype : str or tuple of str, optional\n            The edge type for query, which can be an edge type (str) or a canonical edge type\n            (3-tuple of str). When an edge type appears in multiple canonical edge types, one\n            must use a canonical edge type. If the graph has multiple edge types, one must\n            specify the argument. Otherwise, it can be omitted.\n\n        Returns\n        -------\n        Tensor or (Tensor, Tensor) or (Tensor, Tensor, Tensor)\n            All edges of the specified edge type. 
For a description of the returned result,\n            see the description of :attr:`form`.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 1, 1]), torch.tensor([1, 0, 2, 3])))\n\n        Query for edges.\n\n        >>> g.all_edges()\n        (tensor([0, 0, 1, 1]), tensor([1, 0, 2, 3]))\n\n        Specify a different value for :attr:`form` and :attr:`order`.\n\n        >>> g.all_edges(form='all', order='srcdst')\n        (tensor([0, 0, 1, 1]), tensor([0, 1, 2, 3]), tensor([1, 0, 2, 3]))\n\n        For a graph of multiple edge types, it is required to specify the edge type in query.\n\n        >>> hg = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... })\n        >>> hg.all_edges(etype='plays')\n        (tensor([3, 4]), tensor([5, 6]))\n\n        See Also\n        --------\n        edges\n        in_edges\n        out_edges\n        \"\"\"\n        src, dst, eid = self._graph.edges(self.get_etype_id(etype), order)\n        if form == \"all\":\n            return src, dst, eid\n        elif form == \"uv\":\n            return src, dst\n        elif form == \"eid\":\n            return eid\n        else:\n            raise DGLError(\n                'Invalid form: {}. Must be \"all\", \"uv\" or \"eid\".'.format(form)\n            )\n\n    def in_degrees(self, v=ALL, etype=None):\n        \"\"\"Return the in-degree(s) of the given nodes.\n\n        It computes the in-degree(s) w.r.t. to the edges of the given edge type.\n\n        Parameters\n        ----------\n        v : node IDs\n            The node IDs. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. 
The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n\n            If not given, return the in-degrees of all the nodes.\n        etype : str or (str, str, str), optional\n            The type name of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Returns\n        -------\n        int or Tensor\n            The in-degree(s) of the node(s) in a Tensor. The i-th element is the in-degree\n            of the i-th input node. If :attr:`v` is an ``int``, return an ``int`` too.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 2, 3])))\n\n        Query for all nodes.\n\n        >>> g.in_degrees()\n        tensor([0, 2, 1, 1])\n\n        Query for nodes 1 and 2.\n\n        >>> g.in_degrees(torch.tensor([1, 2]))\n        tensor([2, 1])\n\n        For a graph of multiple edge types, it is required to specify the edge type in query.\n\n        >>> hg = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... 
})\n        >>> hg.in_degrees(torch.tensor([1, 0]), etype='follows')\n        tensor([1, 0])\n\n        See Also\n        --------\n        out_degrees\n        \"\"\"\n        dsttype = self.to_canonical_etype(etype)[2]\n        etid = self.get_etype_id(etype)\n        if is_all(v):\n            v = self.dstnodes(dsttype)\n        v_tensor = utils.prepare_tensor(self, v, \"v\")\n        deg = self._graph.in_degrees(etid, v_tensor)\n        if isinstance(v, numbers.Integral):\n            return F.as_scalar(deg)\n        else:\n            return deg\n\n    def out_degrees(self, u=ALL, etype=None):\n        \"\"\"Return the out-degree(s) of the given nodes.\n\n        It computes the out-degree(s) w.r.t. to the edges of the given edge type.\n\n        Parameters\n        ----------\n        u : node IDs\n            The node IDs. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n\n            If not given, return the in-degrees of all the nodes.\n        etype : str or (str, str, str), optional\n            The type names of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Returns\n        -------\n        int or Tensor\n            The out-degree(s) of the node(s) in a Tensor. The i-th element is the out-degree\n            of the i-th input node. 
If :attr:`v` is an ``int``, return an ``int`` too.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 2, 3])))\n\n        Query for all nodes.\n\n        >>> g.out_degrees()\n        tensor([2, 2, 0, 0])\n\n        Query for nodes 1 and 2.\n\n        >>> g.out_degrees(torch.tensor([1, 2]))\n        tensor([2, 0])\n\n        For a graph of multiple edge types, it is required to specify the edge type in query.\n\n        >>> hg = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),\n        ...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))\n        ... })\n        >>> hg.out_degrees(torch.tensor([1, 0]), etype='follows')\n        tensor([1, 1])\n\n        See Also\n        --------\n        in_degrees\n        \"\"\"\n        srctype = self.to_canonical_etype(etype)[0]\n        etid = self.get_etype_id(etype)\n        if is_all(u):\n            u = self.srcnodes(srctype)\n        u_tensor = utils.prepare_tensor(self, u, \"u\")\n        if F.as_scalar(\n            F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)\n        ) != len(u_tensor):\n            raise DGLError(\"u contains invalid node IDs\")\n        deg = self._graph.out_degrees(etid, utils.prepare_tensor(self, u, \"u\"))\n        if isinstance(u, numbers.Integral):\n            return F.as_scalar(deg)\n        else:\n            return deg\n\n    def adjacency_matrix(self, etype=None):\n        \"\"\"Alias of :meth:`adj`\"\"\"\n        return self.adj(etype)\n\n    def adj(self, etype=None, eweight_name=None):\n        \"\"\"Get the adjacency matrix of the graph.\n\n        Parameters\n        ----------\n        etype : str or (str, str, str), optional\n            The type names of the edges. 
The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and\n            destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        eweight_name : str, optional\n            The name of edge feature used as the non-zero values. If not given,\n            the non-zero values are all 1.\n\n        Returns\n        -------\n        SparseMatrix\n            The adjacency matrix.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))\n        >>> g.adj()\n        SparseMatrix(indices=tensor([[0, 1, 2],\n                                     [1, 2, 3]]),\n                     values=tensor([1., 1., 1.]),\n                     shape=(4, 4), nnz=3)\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): ([0, 1], [0, 1]),\n        ...     ('developer', 'develops', 'game'): ([0, 1], [0, 2])\n        ... 
})\n\n        >>> g.adj(etype='develops')\n        SparseMatrix(indices=tensor([[0, 1],\n                                     [0, 2]]),\n                     values=tensor([1., 1.]),\n                     shape=(2, 3), nnz=2)\n        >>> g.edata['h'] = {('user', 'follows', 'user'): torch.tensor([3, 2])}\n        >>> g.adj(etype='follows', eweight_name='h')\n        SparseMatrix(indices=tensor([[0, 1],\n                                     [0, 1]]),\n                     values=tensor([3, 2]),\n                     shape=(2, 2), nnz=2)\n        \"\"\"\n        assert F.backend_name == \"pytorch\", \"Only PyTorch backend supports adj.\"\n        # Temporal fix to introduce a dependency on torch\n        import torch\n\n        from .sparse import spmatrix\n\n        etype = self.to_canonical_etype(etype)\n        indices = torch.stack(self.all_edges(etype=etype))\n        shape = (self.num_nodes(etype[0]), self.number_of_nodes(etype[2]))\n        if eweight_name is not None:\n            val = self.edata[eweight_name][etype]\n        else:\n            val = None\n        return spmatrix(\n            indices,\n            val=val,\n            shape=shape,\n        )\n\n    def adj_external(\n        self, transpose=False, ctx=F.cpu(), scipy_fmt=None, etype=None\n    ):\n        \"\"\"Return the adjacency matrix in an external format, such as Scipy or\n        backend dependent sparse tensor.\n\n        By default, a row of returned adjacency matrix represents the\n        source of an edge and the column represents the destination.\n\n        When transpose is True, a row represents the destination and a column\n        represents the source.\n\n        Parameters\n        ----------\n        transpose : bool, optional\n            A flag to transpose the returned adjacency matrix. (Default: False)\n        ctx : context, optional\n            The context of returned adjacency matrix. 
(Default: cpu)\n        scipy_fmt : str, optional\n            If specified, return a scipy sparse matrix in the given format.\n            Otherwise, return a backend dependent sparse tensor. (Default: None)\n        etype : str or (str, str, str), optional\n            The type names of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Returns\n        -------\n        SparseTensor or scipy.sparse.spmatrix\n            Adjacency matrix.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Instantiate a heterogeneous graph.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): ([0, 1], [0, 1]),\n        ...     ('developer', 'develops', 'game'): ([0, 1], [0, 2])\n        ... })\n\n        Get a backend dependent sparse tensor. 
Here we use PyTorch for example.\n\n        >>> g.adj_external(etype='develops')\n        tensor(indices=tensor([[0, 1],\n                               [0, 2]]),\n               values=tensor([1., 1.]),\n               size=(2, 3), nnz=2, layout=torch.sparse_coo)\n\n        Get a scipy coo sparse matrix.\n\n        >>> g.adj_external(scipy_fmt='coo', etype='develops')\n        <2x3 sparse matrix of type '<class 'numpy.int64'>'\n           with 2 stored elements in COOrdinate format>\n        \"\"\"\n        etid = self.get_etype_id(etype)\n        if scipy_fmt is None:\n            return self._graph.adjacency_matrix(etid, transpose, ctx)[0]\n        else:\n            return self._graph.adjacency_matrix_scipy(\n                etid, transpose, scipy_fmt, False\n            )\n\n    def adj_tensors(self, fmt, etype=None):\n        \"\"\"Return the adjacency matrix of edges of the given edge type as tensors of\n        a sparse matrix representation.\n        By default, a row of returned adjacency matrix represents the\n        source of an edge and the column represents the destination.\n        Parameters\n        ----------\n        fmt : str\n            Either ``coo``, ``csr`` or ``csc``.\n        etype : str or (str, str, str), optional\n            The type names of the edges. The allowed type name formats are:\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n            Can be omitted if the graph has only one type of edges.\n        Returns\n        -------\n        tuple[Tensor]\n            If :attr:`fmt` is ``coo``, returns a pair of source and destination node ID\n            tensors.\n            If :attr:`fmt` is ``csr`` or ``csc``, return the CSR or CSC representation\n            of the adjacency matrix as a triplet of tensors\n            ``(indptr, indices, edge_ids)``.  
Namely ``edge_ids`` could be an empty\n            tensor with 0 elements, in which case the edge IDs are consecutive\n            integers starting from 0.\n        Examples\n        --------\n        >>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))\n        >>> g.adj_tensors('coo')\n        (tensor([0, 1, 2]), tensor([1, 2, 3]))\n        >>> g.adj_tensors('csr')\n        (tensor([0, 1, 2, 3, 3]), tensor([1, 2, 3]), tensor([0, 1, 2]))\n        \"\"\"\n        etid = self.get_etype_id(etype)\n        if fmt == \"csc\":\n            # The first two elements are number of rows and columns\n            return self._graph.adjacency_matrix_tensors(etid, True, \"csr\")[2:]\n        else:\n            return self._graph.adjacency_matrix_tensors(etid, False, fmt)[2:]\n\n    def inc(self, typestr, ctx=F.cpu(), etype=None):\n        \"\"\"Return the incidence matrix representation of edges with the given\n        edge type.\n\n        An incidence matrix is an n-by-m sparse matrix, where n is\n        the number of nodes and m is the number of edges. 
Each nnz\n        value indicating whether the edge is incident to the node\n        or not.\n\n        There are three types of incidence matrices :math:`I`:\n\n        * ``in``:\n\n            - :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`\n              (or :math:`v` is the dst node of :math:`e`);\n            - :math:`I[v, e] = 0` otherwise.\n\n        * ``out``:\n\n            - :math:`I[v, e] = 1` if :math:`e` is the out-edge of :math:`v`\n              (or :math:`v` is the src node of :math:`e`);\n            - :math:`I[v, e] = 0` otherwise.\n\n        * ``both`` (only if source and destination node type are the same):\n\n            - :math:`I[v, e] = 1` if :math:`e` is the in-edge of :math:`v`;\n            - :math:`I[v, e] = -1` if :math:`e` is the out-edge of :math:`v`;\n            - :math:`I[v, e] = 0` otherwise (including self-loop).\n\n        Parameters\n        ----------\n        typestr : str\n            Can be either ``in``, ``out`` or ``both``\n        ctx : context, optional\n            The context of returned incidence matrix. (Default: cpu)\n        etype : str or (str, str, str), optional\n            The type names of the edges. 
The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Returns\n        -------\n        Framework SparseTensor\n            The incidence matrix.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n\n        >>> g = dgl.graph(([0, 1], [0, 2]))\n        >>> g.inc('in')\n        tensor(indices=tensor([[0, 2],\n                               [0, 1]]),\n               values=tensor([1., 1.]),\n               size=(3, 2), nnz=2, layout=torch.sparse_coo)\n        >>> g.inc('out')\n        tensor(indices=tensor([[0, 1],\n                               [0, 1]]),\n               values=tensor([1., 1.]),\n               size=(3, 2), nnz=2, layout=torch.sparse_coo)\n        >>> g.inc('both')\n        tensor(indices=tensor([[1, 2],\n                               [1, 1]]),\n               values=tensor([-1.,  1.]),\n               size=(3, 2), nnz=2, layout=torch.sparse_coo)\n        \"\"\"\n        etid = self.get_etype_id(etype)\n        return self._graph.incidence_matrix(etid, typestr, ctx)[0]\n\n    incidence_matrix = inc\n\n    #################################################################\n    # Features\n    #################################################################\n\n    def node_attr_schemes(self, ntype=None):\n        \"\"\"Return the node feature schemes for the specified type.\n\n        The scheme of a feature describes the shape and data type of it.\n\n        Parameters\n        ----------\n        ntype : str, optional\n            The node type name. 
Can be omitted if there is only one type of nodes\n            in the graph.\n\n        Returns\n        -------\n        dict[str, Scheme]\n            A dictionary mapping a feature name to its associated feature scheme.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Query for a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n        >>> g.ndata['h1'] = torch.randn(3, 1)\n        >>> g.ndata['h2'] = torch.randn(3, 2)\n        >>> g.node_attr_schemes()\n        {'h1': Scheme(shape=(1,), dtype=torch.float32),\n         'h2': Scheme(shape=(2,), dtype=torch.float32)}\n\n        Query for a heterogeneous graph of multiple node types.\n\n        >>> g = dgl.heterograph({('user', 'plays', 'game'):\n        ...                      (torch.tensor([1, 2]), torch.tensor([3, 4]))})\n        >>> g.nodes['user'].data['h1'] = torch.randn(3, 1)\n        >>> g.nodes['user'].data['h2'] = torch.randn(3, 2)\n        >>> g.node_attr_schemes('user')\n        {'h1': Scheme(shape=(1,), dtype=torch.float32),\n         'h2': Scheme(shape=(2,), dtype=torch.float32)}\n\n        See Also\n        --------\n        edge_attr_schemes\n        \"\"\"\n        return self._node_frames[self.get_ntype_id(ntype)].schemes\n\n    def edge_attr_schemes(self, etype=None):\n        \"\"\"Return the edge feature schemes for the specified type.\n\n        The scheme of a feature describes the shape and data type of it.\n\n        Parameters\n        ----------\n        etype : str or (str, str, str), optional\n            The type names of the edges. 
The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n\n        Returns\n        -------\n        dict[str, Scheme]\n            A dictionary mapping a feature name to its associated feature scheme.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Query for a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n        >>> g.edata['h1'] = torch.randn(2, 1)\n        >>> g.edata['h2'] = torch.randn(2, 2)\n        >>> g.edge_attr_schemes()\n        {'h1': Scheme(shape=(1,), dtype=torch.float32),\n         'h2': Scheme(shape=(2,), dtype=torch.float32)}\n\n        Query for a heterogeneous graph of multiple edge types.\n\n        >>> g = dgl.heterograph({('user', 'plays', 'game'):\n        ...                      (torch.tensor([1, 2]), torch.tensor([3, 4])),\n        ...                      ('user', 'follows', 'user'):\n        ...                      (torch.tensor([3, 4]), torch.tensor([5, 6]))})\n        >>> g.edges['plays'].data['h1'] = torch.randn(2, 1)\n        >>> g.edges['plays'].data['h2'] = torch.randn(2, 2)\n        >>> g.edge_attr_schemes('plays')\n        {'h1': Scheme(shape=(1,), dtype=torch.float32),\n         'h2': Scheme(shape=(2,), dtype=torch.float32)}\n\n        See Also\n        --------\n        node_attr_schemes\n        \"\"\"\n        return self._edge_frames[self.get_etype_id(etype)].schemes\n\n    def set_n_initializer(self, initializer, field=None, ntype=None):\n        \"\"\"Set the initializer for node features.\n\n        When only part of the nodes have a feature (e.g. 
new nodes are added,\n        features are set for a subset of nodes), the initializer initializes\n        features for the rest nodes.\n\n        Parameters\n        ----------\n        initializer : callable\n            A function of signature ``func(shape, dtype, ctx, id_range) -> Tensor``.\n            The tensor will be the initialized features. The arguments are:\n\n            - ``shape``: The shape of the tensor to return, which is a tuple of int.\n              The first dimension is the number of nodes for feature initialization.\n            - ``dtype``: The data type of the tensor to return, which is a\n              framework-specific data type object.\n            - ``ctx``: The device of the tensor to return, which is a framework-specific\n              device object.\n            - ``id_range``: The start and end ID of the nodes for feature initialization,\n              which is a slice.\n        field : str, optional\n            The name of the feature that the initializer applies. If not given, the\n            initializer applies to all features.\n        ntype : str, optional\n            The type name of the nodes. Can be omitted if the graph has only one type of nodes.\n\n        Notes\n        -----\n        Without setting a node feature initializer, zero tensors are generated\n        for nodes without a feature.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Define a function for initializer.\n\n        >>> def init_feats(shape, dtype, device, id_range):\n        ...     
return torch.ones(shape, dtype=dtype, device=device)\n\n        An example for a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0]), torch.tensor([1])))\n        >>> g.ndata['h1'] = torch.zeros(2, 2)\n        >>> g.ndata['h2'] = torch.ones(2, 1)\n        >>> # Apply the initializer to feature 'h2' only.\n        >>> g.set_n_initializer(init_feats, field='h2')\n        >>> g.add_nodes(1)\n        >>> print(g.ndata['h1'])\n        tensor([[0., 0.],\n                [0., 0.],\n                [0., 0.]])\n        >>> print(g.ndata['h2'])\n        tensor([[1.], [1.], [1.]])\n\n        An example for a heterogeneous graph of multiple node types.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n        ...                                 torch.tensor([0, 0, 1, 1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n        ...                                         torch.tensor([0, 1]))\n        ...     })\n        >>> g.nodes['user'].data['h'] = torch.zeros(3, 2)\n        >>> g.nodes['game'].data['w'] = torch.ones(2, 2)\n        >>> g.set_n_initializer(init_feats, ntype='game')\n        >>> g.add_nodes(1, ntype='user')\n        >>> # Initializer not set for 'user', use zero tensors by default\n        >>> g.nodes['user'].data['h']\n        tensor([[0., 0.],\n                [0., 0.],\n                [0., 0.],\n                [0., 0.]])\n        >>> # Initializer set for 'game'\n        >>> g.add_nodes(1, ntype='game')\n        >>> g.nodes['game'].data['w']\n        tensor([[1., 1.],\n                [1., 1.],\n                [1., 1.]])\n        \"\"\"\n        ntid = self.get_ntype_id(ntype)\n        self._node_frames[ntid].set_initializer(initializer, field)\n\n    def set_e_initializer(self, initializer, field=None, etype=None):\n        \"\"\"Set the initializer for edge features.\n\n        When only part of the edges have a feature (e.g. 
new edges are added,\n        features are set for a subset of edges), the initializer initializes\n        features for the rest edges.\n\n        Parameters\n        ----------\n        initializer : callable\n            A function of signature ``func(shape, dtype, ctx, id_range) -> Tensor``.\n            The tensor will be the initialized features. The arguments are:\n\n            - ``shape``: The shape of the tensor to return, which is a tuple of int.\n              The first dimension is the number of edges for feature initialization.\n            - ``dtype``: The data type of the tensor to return, which is a\n              framework-specific data type object.\n            - ``ctx``: The device of the tensor to return, which is a framework-specific\n              device object.\n            - ``id_range``: The start and end ID of the edges for feature initialization,\n              which is a slice.\n        field : str, optional\n            The name of the feature that the initializer applies. If not given, the\n            initializer applies to all features.\n        etype : str or (str, str, str), optional\n            The type names of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n\n        Notes\n        -----\n        Without setting an edge feature initializer, zero tensors are generated\n        for edges without a feature.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Define a function for initializer.\n\n        >>> def init_feats(shape, dtype, device, id_range):\n        ...     
return torch.ones(shape, dtype=dtype, device=device)\n\n        An example for a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0]), torch.tensor([1])))\n        >>> g.edata['h1'] = torch.zeros(1, 2)\n        >>> g.edata['h2'] = torch.ones(1, 1)\n        >>> # Apply the initializer to feature 'h2' only.\n        >>> g.set_e_initializer(init_feats, field='h2')\n        >>> g.add_edges(torch.tensor([1]), torch.tensor([1]))\n        >>> print(g.edata['h1'])\n        tensor([[0., 0.],\n                [0., 0.]])\n        >>> print(g.edata['h2'])\n        tensor([[1.], [1.]])\n\n        An example for a heterogeneous graph of multiple edge types.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1]),\n        ...                                 torch.tensor([0, 0])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n        ...                                         torch.tensor([0, 1]))\n        ...     })\n        >>> g.edges['plays'].data['h'] = torch.zeros(2, 2)\n        >>> g.edges['develops'].data['w'] = torch.ones(2, 2)\n        >>> g.set_e_initializer(init_feats, etype='plays')\n        >>> # Initializer not set for 'develops', use zero tensors by default\n        >>> g.add_edges(torch.tensor([1]), torch.tensor([1]), etype='develops')\n        >>> g.edges['develops'].data['w']\n        tensor([[1., 1.],\n                [1., 1.],\n                [0., 0.]])\n        >>> # Initializer set for 'plays'\n        >>> g.add_edges(torch.tensor([1]), torch.tensor([1]), etype='plays')\n        >>> g.edges['plays'].data['h']\n        tensor([[0., 0.],\n                [0., 0.],\n                [1., 1.]])\n        \"\"\"\n        etid = self.get_etype_id(etype)\n        self._edge_frames[etid].set_initializer(initializer, field)\n\n    def _set_n_repr(self, ntid, u, data):\n        \"\"\"Internal API to set node features.\n\n        `data` is a dictionary from the feature name to 
feature tensor. Each tensor\n        is of shape (B, D1, D2, ...), where B is the number of nodes to be updated,\n        and (D1, D2, ...) be the shape of the node representation tensor. The\n        length of the given node ids must match B (i.e, len(u) == B).\n\n        All updates will be done out of place to work with autograd.\n\n        Parameters\n        ----------\n        ntid : int\n            Node type id.\n        u : node, container or tensor\n            The node(s).\n        data : dict of tensor\n            Node representation.\n        \"\"\"\n        if is_all(u):\n            num_nodes = self._graph.num_nodes(ntid)\n        else:\n            u = utils.prepare_tensor(self, u, \"u\")\n            num_nodes = len(u)\n        for key, val in data.items():\n            nfeats = F.shape(val)[0]\n            if nfeats != num_nodes:\n                raise DGLError(\n                    \"Expect number of features to match number of nodes (len(u)).\"\n                    \" Got %d and %d instead.\" % (nfeats, num_nodes)\n                )\n            if F.context(val) != self.device:\n                raise DGLError(\n                    'Cannot assign node feature \"{}\" on device {} to a graph on'\n                    \" device {}. 
Call DGLGraph.to() to copy the graph to the\"\n                    \" same device.\".format(key, F.context(val), self.device)\n                )\n            # To prevent users from doing things like:\n            #\n            #     g.pin_memory_()\n            #     g.ndata['x'] = torch.randn(...)\n            #     sg = g.sample_neighbors(torch.LongTensor([...]).cuda())\n            #     sg.ndata['x']    # Becomes a CPU tensor even if sg is on GPU due to lazy slicing\n            if (\n                self.is_pinned()\n                and F.context(val) == \"cpu\"\n                and not F.is_pinned(val)\n            ):\n                raise DGLError(\n                    \"Pinned graph requires the node data to be pinned as well. \"\n                    \"Please pin the node data before assignment.\"\n                )\n\n        if is_all(u):\n            self._node_frames[ntid].update(data)\n        else:\n            self._node_frames[ntid].update_row(u, data)\n\n    def _get_n_repr(self, ntid, u):\n        \"\"\"Get node(s) representation of a single node type.\n\n        The returned feature tensor batches multiple node features on the first dimension.\n\n        Parameters\n        ----------\n        ntid : int\n            Node type id.\n        u : node, container or tensor\n            The node(s).\n\n        Returns\n        -------\n        dict\n            Representation dict from feature name to feature tensor.\n        \"\"\"\n        if is_all(u):\n            return self._node_frames[ntid]\n        else:\n            u = utils.prepare_tensor(self, u, \"u\")\n            return self._node_frames[ntid].subframe(u)\n\n    def _pop_n_repr(self, ntid, key):\n        \"\"\"Internal API to get and remove the specified node feature.\n\n        Parameters\n        ----------\n        ntid : int\n            Node type id.\n        key : str\n            The attribute name.\n\n        Returns\n        -------\n        Tensor\n            The popped 
representation\n        \"\"\"\n        return self._node_frames[ntid].pop(key)\n\n    def _set_e_repr(self, etid, edges, data):\n        \"\"\"Internal API to set edge(s) features.\n\n        `data` is a dictionary from the feature name to feature tensor. Each tensor\n        is of shape (B, D1, D2, ...), where B is the number of edges to be updated,\n        and (D1, D2, ...) be the shape of the edge representation tensor.\n\n        All update will be done out of place to work with autograd.\n\n        Parameters\n        ----------\n        etid : int\n            Edge type id.\n        edges : edges\n            Edges can be either\n\n            * A pair of endpoint nodes (u, v), where u is the node ID of source\n              node type and v is that of destination node type.\n            * A tensor of edge ids of the given type.\n\n            The default value is all the edges.\n        data : tensor or dict of tensor\n            Edge representation.\n        \"\"\"\n        # parse argument\n        if not is_all(edges):\n            eid = utils.parse_edges_arg_to_eid(self, edges, etid, \"edges\")\n\n        # sanity check\n        if not utils.is_dict_like(data):\n            raise DGLError(\n                \"Expect dictionary type for feature data.\"\n                ' Got \"%s\" instead.' % type(data)\n            )\n\n        if is_all(edges):\n            num_edges = self._graph.num_edges(etid)\n        else:\n            num_edges = len(eid)\n        for key, val in data.items():\n            nfeats = F.shape(val)[0]\n            if nfeats != num_edges:\n                raise DGLError(\n                    \"Expect number of features to match number of edges.\"\n                    \" Got %d and %d instead.\" % (nfeats, num_edges)\n                )\n            if F.context(val) != self.device:\n                raise DGLError(\n                    'Cannot assign edge feature \"{}\" on device {} to a graph on'\n                    \" device {}. 
Call DGLGraph.to() to copy the graph to the\"\n                    \" same device.\".format(key, F.context(val), self.device)\n                )\n            # To prevent users from doing things like:\n            #\n            #     g.pin_memory_()\n            #     g.edata['x'] = torch.randn(...)\n            #     sg = g.sample_neighbors(torch.LongTensor([...]).cuda())\n            #     sg.edata['x']    # Becomes a CPU tensor even if sg is on GPU due to lazy slicing\n            if (\n                self.is_pinned()\n                and F.context(val) == \"cpu\"\n                and not F.is_pinned(val)\n            ):\n                raise DGLError(\n                    \"Pinned graph requires the edge data to be pinned as well. \"\n                    \"Please pin the edge data before assignment.\"\n                )\n\n        # set\n        if is_all(edges):\n            self._edge_frames[etid].update(data)\n        else:\n            self._edge_frames[etid].update_row(eid, data)\n\n    def _get_e_repr(self, etid, edges):\n        \"\"\"Internal API to get edge features.\n\n        Parameters\n        ----------\n        etid : int\n            Edge type id.\n        edges : edges\n            Edges can be a pair of endpoint nodes (u, v), or a\n            tensor of edge ids. 
The default value is all the edges.\n\n        Returns\n        -------\n        dict\n            Representation dict\n        \"\"\"\n        # parse argument\n        if is_all(edges):\n            return self._edge_frames[etid]\n        else:\n            eid = utils.parse_edges_arg_to_eid(self, edges, etid, \"edges\")\n            return self._edge_frames[etid].subframe(eid)\n\n    def _pop_e_repr(self, etid, key):\n        \"\"\"Get and remove the specified edge repr of a single edge type.\n\n        Parameters\n        ----------\n        etid : int\n            Edge type id.\n        key : str\n          The attribute name.\n\n        Returns\n        -------\n        Tensor\n            The popped representation\n        \"\"\"\n        self._edge_frames[etid].pop(key)\n\n    #################################################################\n    # Message passing\n    #################################################################\n\n    def apply_nodes(self, func, v=ALL, ntype=None):\n        \"\"\"Update the features of the specified nodes by the provided function.\n\n        Parameters\n        ----------\n        func : callable\n            The function to update node features. It must be\n            a :ref:`apiudf`.\n        v : node IDs\n            The node IDs. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n\n            If not given (default), use all the nodes in the graph.\n        ntype : str, optional\n            The node type name. 
Can be omitted if there is\n            only one type of nodes in the graph.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        **Homogeneous graph**\n\n        >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))\n        >>> g.ndata['h'] = torch.ones(5, 2)\n        >>> g.apply_nodes(lambda nodes: {'x' : nodes.data['h'] * 2})\n        >>> g.ndata['x']\n        tensor([[2., 2.],\n                [2., 2.],\n                [2., 2.],\n                [2., 2.],\n                [2., 2.]])\n\n        **Heterogeneous graph**\n\n        >>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1], [1, 2])})\n        >>> g.nodes['user'].data['h'] = torch.ones(3, 5)\n        >>> g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2}, ntype='user')\n        >>> g.nodes['user'].data['h']\n        tensor([[2., 2., 2., 2., 2.],\n                [2., 2., 2., 2., 2.],\n                [2., 2., 2., 2., 2.]])\n\n        See Also\n        --------\n        apply_edges\n        \"\"\"\n        ntid = self.get_ntype_id(ntype)\n        ntype = self.ntypes[ntid]\n        if is_all(v):\n            v_id = self.nodes(ntype)\n        else:\n            v_id = utils.prepare_tensor(self, v, \"v\")\n        ndata = core.invoke_node_udf(self, v_id, ntype, func, orig_nid=v_id)\n        self._set_n_repr(ntid, v, ndata)\n\n    def apply_edges(self, func, edges=ALL, etype=None):\n        \"\"\"Update the features of the specified edges by the provided function.\n\n        Parameters\n        ----------\n        func : dgl.function.BuiltinFunction or callable\n            The function to generate new edge features. It must be either\n            a :ref:`api-built-in` or a :ref:`apiudf`.\n        edges : edges\n            The edges to update features on. The allowed input formats are:\n\n            * ``int``: A single edge ID.\n            * Int Tensor: Each element is an edge ID.  
The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is an edge ID.\n            * (Tensor, Tensor): The node-tensors format where the i-th elements\n              of the two tensors specify an edge.\n            * (iterable[int], iterable[int]): Similar to the node-tensors format but\n              stores edge endpoints in python iterables.\n\n            Default value specifies all the edges in the graph.\n\n        etype : str or (str, str, str), optional\n            The type name of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Notes\n        -----\n        DGL recommends using DGL's built-in function for the :attr:`func` argument,\n        because DGL will invoke efficient kernels that avoid copying node features to\n        edge features in this case.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        **Homogeneous graph**\n\n        >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))\n        >>> g.ndata['h'] = torch.ones(5, 2)\n        >>> g.apply_edges(lambda edges: {'x' : edges.src['h'] + edges.dst['h']})\n        >>> g.edata['x']\n        tensor([[2., 2.],\n                [2., 2.],\n                [2., 2.],\n                [2., 2.]])\n\n        Use built-in function\n\n        >>> import dgl.function as fn\n        >>> g.apply_edges(fn.u_add_v('h', 'h', 'x'))\n        >>> g.edata['x']\n        tensor([[2., 2.],\n                [2., 2.],\n                [2., 2.],\n                [2., 2.]])\n\n        **Heterogeneous graph**\n\n        >>> g = dgl.heterograph({('user', 
'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1])})\n        >>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5)\n        >>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2})\n        >>> g.edges[('user', 'plays', 'game')].data['h']\n        tensor([[2., 2., 2., 2., 2.],\n                [2., 2., 2., 2., 2.],\n                [2., 2., 2., 2., 2.],\n                [2., 2., 2., 2., 2.]])\n\n        See Also\n        --------\n        apply_nodes\n        \"\"\"\n        # Graph with one relation type\n        if self._graph.number_of_etypes() == 1 or etype is not None:\n            etid = self.get_etype_id(etype)\n            etype = self.canonical_etypes[etid]\n            g = self if etype is None else self[etype]\n        else:  # heterogeneous graph with number of relation types > 1\n            if not core.is_builtin(func):\n                raise DGLError(\n                    \"User defined functions are not yet \"\n                    \"supported in apply_edges for heterogeneous graphs. 
\"\n                    \"Please use (apply_edges(func), etype = rel) instead.\"\n                )\n            g = self\n        if is_all(edges):\n            eid = ALL\n        else:\n            eid = utils.parse_edges_arg_to_eid(self, edges, etid, \"edges\")\n        if core.is_builtin(func):\n            if not is_all(eid):\n                g = g.edge_subgraph(eid, relabel_nodes=False)\n            edata = core.invoke_gsddmm(g, func)\n        else:\n            edata = core.invoke_edge_udf(g, eid, etype, func)\n\n        if self._graph.number_of_etypes() == 1 or etype is not None:\n            self._set_e_repr(etid, eid, edata)\n        else:\n            edata_tensor = {}\n            key = list(edata.keys())[0]\n            out_tensor_tuples = edata[key]\n            for etid in range(self._graph.number_of_etypes()):\n                # TODO (Israt): Check the logic why some output tensor is None\n                if out_tensor_tuples[etid] is not None:\n                    edata_tensor[key] = out_tensor_tuples[etid]\n                    self._set_e_repr(etid, eid, edata_tensor)\n\n    def send_and_recv(\n        self, edges, message_func, reduce_func, apply_node_func=None, etype=None\n    ):\n        \"\"\"Send messages along the specified edges and reduce them on\n        the destination nodes to update their features.\n\n        Parameters\n        ----------\n        edges : edges\n            The edges to send and receive messages on. The allowed input formats are:\n\n            * ``int``: A single edge ID.\n            * Int Tensor: Each element is an edge ID.  
The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is an edge ID.\n            * (Tensor, Tensor): The node-tensors format where the i-th elements\n              of the two tensors specify an edge.\n            * (iterable[int], iterable[int]): Similar to the node-tensors format but\n              stores edge endpoints in python iterables.\n\n        message_func : dgl.function.BuiltinFunction or callable\n            The message function to generate messages along the edges.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        reduce_func : dgl.function.BuiltinFunction or callable\n            The reduce function to aggregate the messages.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        apply_node_func : callable, optional\n            An optional apply function to further update the node features\n            after the message reduction. It must be a :ref:`apiudf`.\n        etype : str or (str, str, str), optional\n            The type name of the edges. 
The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Notes\n        -----\n        DGL recommends using DGL's built-in functions for the :attr:`message_func`\n        and the :attr:`reduce_func` arguments,\n        because DGL will invoke efficient kernels that avoid copying node features to\n        edge features in this case.\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import dgl.function as fn\n        >>> import torch\n\n        **Homogeneous graph**\n\n        >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))\n        >>> g.ndata['x'] = torch.ones(5, 2)\n        >>> # Specify edges using (Tensor, Tensor).\n        >>> g.send_and_recv(([1, 2], [2, 3]), fn.copy_u('x', 'm'), fn.sum('m', 'h'))\n        >>> g.ndata['h']\n        tensor([[0., 0.],\n                [0., 0.],\n                [1., 1.],\n                [1., 1.],\n                [0., 0.]])\n        >>> # Specify edges using IDs.\n        >>> g.send_and_recv([0, 2, 3], fn.copy_u('x', 'm'), fn.sum('m', 'h'))\n        >>> g.ndata['h']\n        tensor([[0., 0.],\n                [1., 1.],\n                [0., 0.],\n                [1., 1.],\n                [1., 1.]])\n\n        **Heterogeneous graph**\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): ([0, 1], [1, 2]),\n        ...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])\n        ... })\n        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])\n        >>> g.send_and_recv(g['follows'].edges(), fn.copy_u('h', 'm'),\n        ...                 
fn.sum('m', 'h'), etype='follows')\n        >>> g.nodes['user'].data['h']\n        tensor([[0.],\n                [0.],\n                [1.]])\n\n        **``send_and_recv`` using user-defined functions**\n\n        >>> import torch as th\n        >>> g = dgl.graph(([0, 1], [1, 2]))\n        >>> g.ndata['x'] = th.tensor([[1.], [2.], [3.]])\n\n        >>> # Define the function for sending node features as messages.\n        >>> def send_source(edges):\n        ...     return {'m': edges.src['x']}\n        >>> # Sum the messages received and use this to replace the original node feature.\n        >>> def simple_reduce(nodes):\n        ...     return {'x': nodes.mailbox['m'].sum(1)}\n\n        Send and receive messages.\n\n        >>> g.send_and_recv(g.edges(), send_source, simple_reduce)\n        >>> g.ndata['x']\n        tensor([[1.],\n                [1.],\n                [2.]])\n\n        Note that the feature of node 0 remains the same as it has no incoming edges.\n        \"\"\"\n        # edge type\n        etid = self.get_etype_id(etype)\n        _, dtid = self._graph.metagraph.find_edge(etid)\n        etype = self.canonical_etypes[etid]\n        # edge IDs\n        eid = utils.parse_edges_arg_to_eid(self, edges, etid, \"edges\")\n        if len(eid) == 0:\n            # no computation\n            return\n        u, v = self.find_edges(eid, etype=etype)\n        # call message passing on subgraph\n        g = self if etype is None else self[etype]\n        compute_graph, _, dstnodes, _ = _create_compute_graph(g, u, v, eid)\n        ndata = core.message_passing(\n            compute_graph, message_func, reduce_func, apply_node_func\n        )\n        self._set_n_repr(dtid, dstnodes, ndata)\n\n    def pull(\n        self, v, message_func, reduce_func, apply_node_func=None, etype=None\n    ):\n        \"\"\"Pull messages from the specified node(s)' predecessors along the\n        specified edge type, aggregate them to update the node features.\n\n        Parameters\n        ----------\n  
      v : node IDs\n            The node IDs. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n\n        message_func : dgl.function.BuiltinFunction or callable\n            The message function to generate messages along the edges.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        reduce_func : dgl.function.BuiltinFunction or callable\n            The reduce function to aggregate the messages.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        apply_node_func : callable, optional\n            An optional apply function to further update the node features\n            after the message reduction. It must be a :ref:`apiudf`.\n        etype : str or (str, str, str), optional\n            The type name of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Notes\n        -----\n        * If some of the given nodes :attr:`v` has no in-edges, DGL does not invoke\n          message and reduce functions for these nodes and fill their aggregated messages\n          with zero. 
Users can control the filled values via :meth:`set_n_initializer`.\n          DGL still invokes :attr:`apply_node_func` if provided.\n        * DGL recommends using DGL's built-in function for the :attr:`message_func`\n          and the :attr:`reduce_func` arguments,\n          because DGL will invoke efficient kernels that avoid copying node features to\n          edge features in this case.\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import dgl.function as fn\n        >>> import torch\n\n        **Homogeneous graph**\n\n        >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))\n        >>> g.ndata['x'] = torch.ones(5, 2)\n        >>> g.pull([0, 3, 4], fn.copy_u('x', 'm'), fn.sum('m', 'h'))\n        >>> g.ndata['h']\n        tensor([[0., 0.],\n                [0., 0.],\n                [0., 0.],\n                [1., 1.],\n                [1., 1.]])\n\n        **Heterogeneous graph**\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): ([0, 1], [1, 2]),\n        ...     ('user', 'plays', 'game'): ([0, 2], [0, 1])\n        ... 
})\n        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])\n\n        Pull.\n\n        >>> g['follows'].pull(2, fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows')\n        >>> g.nodes['user'].data['h']\n        tensor([[0.],\n                [1.],\n                [1.]])\n        \"\"\"\n        v = utils.prepare_tensor(self, v, \"v\")\n        if len(v) == 0:\n            # no computation\n            return\n        etid = self.get_etype_id(etype)\n        _, dtid = self._graph.metagraph.find_edge(etid)\n        etype = self.canonical_etypes[etid]\n        g = self if etype is None else self[etype]\n        # call message passing on subgraph\n        src, dst, eid = g.in_edges(v, form=\"all\")\n        compute_graph, _, dstnodes, _ = _create_compute_graph(\n            g, src, dst, eid, v\n        )\n        ndata = core.message_passing(\n            compute_graph, message_func, reduce_func, apply_node_func\n        )\n        self._set_n_repr(dtid, dstnodes, ndata)\n\n    def push(\n        self, u, message_func, reduce_func, apply_node_func=None, etype=None\n    ):\n        \"\"\"Send message from the specified node(s) to their successors\n        along the specified edge type and update their node features.\n\n        Parameters\n        ----------\n        u : node IDs\n            The node IDs. The allowed formats are:\n\n            * ``int``: A single node.\n            * Int Tensor: Each element is a node ID. 
The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is a node ID.\n\n        message_func : dgl.function.BuiltinFunction or callable\n            The message function to generate messages along the edges.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        reduce_func : dgl.function.BuiltinFunction or callable\n            The reduce function to aggregate the messages.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        apply_node_func : callable, optional\n            An optional apply function to further update the node features\n            after the message reduction. It must be a :ref:`apiudf`.\n        etype : str or (str, str, str), optional\n            The type name of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Notes\n        -----\n        DGL recommends using DGL's built-in function for the :attr:`message_func`\n        and the :attr:`reduce_func` arguments,\n        because DGL will invoke efficient kernels that avoid copying node features to\n        edge features in this case.\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import dgl.function as fn\n        >>> import torch\n\n        **Homogeneous graph**\n\n        >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))\n        >>> g.ndata['x'] = torch.ones(5, 2)\n        >>> g.push([0, 1], fn.copy_u('x', 'm'), fn.sum('m', 'h'))\n        >>> g.ndata['h']\n        tensor([[0., 0.],\n                [1., 1.],\n                [1., 1.],\n                [0., 0.],\n                [0., 0.]])\n\n        **Heterogeneous graph**\n\n        >>> 
g = dgl.heterograph({('user', 'follows', 'user'): ([0, 0], [1, 2])})\n        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])\n\n        Push.\n\n        >>> g['follows'].push(0, fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows')\n        >>> g.nodes['user'].data['h']\n        tensor([[0.],\n                [0.],\n                [0.]])\n        \"\"\"\n        edges = self.out_edges(u, form=\"eid\", etype=etype)\n        self.send_and_recv(\n            edges, message_func, reduce_func, apply_node_func, etype=etype\n        )\n\n    def update_all(\n        self, message_func, reduce_func, apply_node_func=None, etype=None\n    ):\n        \"\"\"Send messages along all the edges of the specified type\n        and update all the nodes of the corresponding destination type.\n\n        For heterogeneous graphs with number of relation types > 1, send messages\n        along all the edges, reduce them by type-wisely and across different types\n        at the same time. Then, update the node features of all the nodes.\n\n        Parameters\n        ----------\n        message_func : dgl.function.BuiltinFunction or callable\n            The message function to generate messages along the edges.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        reduce_func : dgl.function.BuiltinFunction or callable\n            The reduce function to aggregate the messages.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        apply_node_func : callable, optional\n            An optional apply function to further update the node features\n            after the message reduction. It must be a :ref:`apiudf`.\n        etype : str or (str, str, str), optional\n            The type name of the edges. 
The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Notes\n        -----\n        * If some of the nodes in the graph has no in-edges, DGL does not invoke\n          message and reduce functions for these nodes and fill their aggregated messages\n          with zero. Users can control the filled values via :meth:`set_n_initializer`.\n          DGL still invokes :attr:`apply_node_func` if provided.\n        * DGL recommends using DGL's built-in function for the :attr:`message_func`\n          and the :attr:`reduce_func` arguments,\n          because DGL will invoke efficient kernels that avoid copying node features to\n          edge features in this case.\n\n        Examples\n        --------\n        >>> import dgl\n        >>> import dgl.function as fn\n        >>> import torch\n\n        **Homogeneous graph**\n\n        >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))\n        >>> g.ndata['x'] = torch.ones(5, 2)\n        >>> g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))\n        >>> g.ndata['h']\n        tensor([[0., 0.],\n                [1., 1.],\n                [1., 1.],\n                [1., 1.],\n                [1., 1.]])\n\n        **Heterogeneous graph**\n\n        >>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2], [1, 2, 2])})\n\n        Update all.\n\n        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])\n        >>> g['follows'].update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows')\n        >>> g.nodes['user'].data['h']\n        tensor([[0.],\n                [0.],\n                [3.]])\n\n        **Heterogeneous graph (number of relation types > 1)**\n\n        >>> g = dgl.heterograph({\n        ...     
('user', 'follows', 'user'): ([0, 1], [1, 1]),\n        ...     ('game', 'attracts', 'user'): ([0], [1])\n        ... })\n\n        Update all.\n\n        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])\n        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])\n        >>> g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n        >>> g.nodes['user'].data['h']\n        tensor([[0.],\n                [4.]])\n        \"\"\"\n        # Graph with one relation type\n        if self._graph.number_of_etypes() == 1 or etype is not None:\n            etid = self.get_etype_id(etype)\n            etype = self.canonical_etypes[etid]\n            _, dtid = self._graph.metagraph.find_edge(etid)\n            g = self if etype is None else self[etype]\n            ndata = core.message_passing(\n                g, message_func, reduce_func, apply_node_func\n            )\n            if (\n                core.is_builtin(reduce_func)\n                and reduce_func.name in [\"min\", \"max\"]\n                and ndata\n            ):\n                # Replace infinity with zero for isolated nodes\n                key = list(ndata.keys())[0]\n                ndata[key] = F.replace_inf_with_zero(ndata[key])\n            self._set_n_repr(dtid, ALL, ndata)\n        else:  # heterogeneous graph with number of relation types > 1\n            if not core.is_builtin(message_func) or not core.is_builtin(\n                reduce_func\n            ):\n                raise DGLError(\n                    \"User defined functions are not yet \"\n                    \"supported in update_all for heterogeneous graphs. \"\n                    \"Please use multi_update_all instead.\"\n                )\n            if reduce_func.name in [\"mean\"]:\n                raise NotImplementedError(\n                    \"Cannot set both intra-type and inter-type reduce \"\n                    \"operators as 'mean' using update_all. 
Please use \"\n                    \"multi_update_all instead.\"\n                )\n            g = self\n            all_out = core.message_passing(\n                g, message_func, reduce_func, apply_node_func\n            )\n            key = list(all_out.keys())[0]\n            out_tensor_tuples = all_out[key]\n\n            dst_tensor = {}\n            for _, _, dsttype in g.canonical_etypes:\n                dtid = g.get_ntype_id(dsttype)\n                dst_tensor[key] = out_tensor_tuples[dtid]\n                if core.is_builtin(reduce_func) and reduce_func.name in [\n                    \"min\",\n                    \"max\",\n                ]:\n                    dst_tensor[key] = F.replace_inf_with_zero(dst_tensor[key])\n                self._node_frames[dtid].update(dst_tensor)\n\n    #################################################################\n    # Message passing on heterograph\n    #################################################################\n\n    def multi_update_all(self, etype_dict, cross_reducer, apply_node_func=None):\n        r\"\"\"Send messages along all the edges, reduce them by first type-wisely\n        then across different types, and then update the node features of all\n        the nodes.\n\n        Parameters\n        ----------\n        etype_dict : dict\n            Arguments for edge-type-wise message passing. 
The keys are edge types\n            while the values are message passing arguments.\n\n            The allowed key formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            The value must be a tuple ``(message_func, reduce_func, [apply_node_func])``, where\n\n            * message_func : dgl.function.BuiltinFunction or callable\n                The message function to generate messages along the edges.\n                It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n            * reduce_func : dgl.function.BuiltinFunction or callable\n                The reduce function to aggregate the messages.\n                It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n            * apply_node_func : callable, optional\n                An optional apply function to further update the node features\n                after the message reduction. It must be a :ref:`apiudf`.\n\n        cross_reducer : str or callable function\n            Cross type reducer. One of ``\"sum\"``, ``\"min\"``, ``\"max\"``, ``\"mean\"``, ``\"stack\"``\n            or a callable function. 
If a callable function is provided, the input argument must be\n            a single list of tensors containing aggregation results from each edge type, and the\n            output of function must be a single tensor.\n        apply_node_func : callable, optional\n            An optional apply function after the messages are reduced both\n            type-wisely and across different types.\n            It must be a :ref:`apiudf`.\n\n        Notes\n        -----\n        DGL recommends using DGL's built-in function for the message_func\n        and the reduce_func in the type-wise message passing arguments,\n        because DGL will invoke efficient kernels that avoid copying node features to\n        edge features in this case.\n\n\n        Examples\n        --------\n        >>> import dgl\n        >>> import dgl.function as fn\n        >>> import torch\n\n        Instantiate a heterograph.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'follows', 'user'): ([0, 1], [1, 1]),\n        ...     ('game', 'attracts', 'user'): ([0], [1])\n        ... })\n        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])\n        >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])\n\n        Update all.\n\n        >>> g.multi_update_all(\n        ...     {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),\n        ...      'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},\n        ... \"sum\")\n        >>> g.nodes['user'].data['h']\n        tensor([[0.],\n                [4.]])\n\n        User-defined cross reducer equivalent to \"sum\".\n\n        >>> def cross_sum(flist):\n        ...     return torch.sum(torch.stack(flist, dim=0), dim=0) if len(flist) > 1 else flist[0]\n\n        Use the user-defined cross reducer.\n\n        >>> g.multi_update_all(\n        ...     {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),\n        ...      'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},\n        ... 
cross_sum)\n        \"\"\"\n        all_out = defaultdict(list)\n        merge_order = defaultdict(list)\n        for etype, args in etype_dict.items():\n\n            etid = self.get_etype_id(etype)\n            _, dtid = self._graph.metagraph.find_edge(etid)\n            args = pad_tuple(args, 3)\n            if args is None:\n                raise DGLError(\n                    'Invalid arguments for edge type \"{}\". Should be '\n                    \"(msg_func, reduce_func, [apply_node_func])\".format(etype)\n                )\n            mfunc, rfunc, afunc = args\n            g = self if etype is None else self[etype]\n            all_out[dtid].append(core.message_passing(g, mfunc, rfunc, afunc))\n            merge_order[dtid].append(\n                etid\n            )  # use edge type id as merge order hint\n        for dtid, frames in all_out.items():\n            # merge by cross_reducer\n            out = reduce_dict_data(frames, cross_reducer, merge_order[dtid])\n            # Replace infinity with zero for isolated nodes when reducer is min/max\n            if core.is_builtin(rfunc) and rfunc.name in [\"min\", \"max\"]:\n                for key in out.keys():\n                    out[key] = (\n                        F.replace_inf_with_zero(out[key])\n                        if out[key] is not None\n                        else None\n                    )\n            self._node_frames[dtid].update(out)\n            # apply\n            if apply_node_func is not None:\n                self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid])\n\n    #################################################################\n    # Message propagation\n    #################################################################\n\n    def prop_nodes(\n        self,\n        nodes_generator,\n        message_func,\n        reduce_func,\n        apply_node_func=None,\n        etype=None,\n    ):\n        \"\"\"Propagate messages using graph traversal by sequentially 
triggering\n        :func:`pull()` on nodes.\n\n        The traversal order is specified by the ``nodes_generator``. It generates\n        node frontiers, which is a list or a tensor of nodes. The nodes in the\n        same frontier will be triggered together, while nodes in different frontiers\n        will be triggered according to the generating order.\n\n        Parameters\n        ----------\n        nodes_generator : iterable[node IDs]\n            The generator of node frontiers. Each frontier is a set of node IDs\n            stored in Tensor or python iterables.\n            It specifies which nodes perform :func:`pull` at each step.\n        message_func : dgl.function.BuiltinFunction or callable\n            The message function to generate messages along the edges.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        reduce_func : dgl.function.BuiltinFunction or callable\n            The reduce function to aggregate the messages.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        apply_node_func : callable, optional\n            An optional apply function to further update the node features\n            after the message reduction. It must be a :ref:`apiudf`.\n        etype : str or (str, str, str), optional\n            The type name of the edges. 
The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Examples\n        --------\n        >>> import torch\n        >>> import dgl\n        >>> import dgl.function as fn\n\n        Instantiate a heterograph and perform multiple rounds of message passing.\n\n        >>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])})\n        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])\n        >>> g['follows'].prop_nodes([[2, 3], [4]], fn.copy_u('h', 'm'),\n        ...                         fn.sum('m', 'h'), etype='follows')\n        >>> g.nodes['user'].data['h']\n        tensor([[1.],\n                [2.],\n                [1.],\n                [2.],\n                [3.]])\n\n        See Also\n        --------\n        prop_edges\n        \"\"\"\n        for node_frontier in nodes_generator:\n            self.pull(\n                node_frontier,\n                message_func,\n                reduce_func,\n                apply_node_func,\n                etype=etype,\n            )\n\n    def prop_edges(\n        self,\n        edges_generator,\n        message_func,\n        reduce_func,\n        apply_node_func=None,\n        etype=None,\n    ):\n        \"\"\"Propagate messages using graph traversal by sequentially triggering\n        :func:`send_and_recv()` on edges.\n\n        The traversal order is specified by the ``edges_generator``. It generates\n        edge frontiers. 
The edge frontiers should be of *valid edges type*.\n        See :func:`send` for more details.\n\n        Edges in the same frontier will be triggered together, and edges in\n        different frontiers will be triggered according to the generating order.\n\n        Parameters\n        ----------\n        edges_generator : generator\n            The generator of edge frontiers.\n        message_func : dgl.function.BuiltinFunction or callable\n            The message function to generate messages along the edges.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        reduce_func : dgl.function.BuiltinFunction or callable\n            The reduce function to aggregate the messages.\n            It must be either a :ref:`api-built-in` or a :ref:`apiudf`.\n        apply_node_func : callable, optional\n            An optional apply function to further update the node features\n            after the message reduction. It must be a :ref:`apiudf`.\n        etype : str or (str, str, str), optional\n            The type name of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Examples\n        --------\n        >>> import torch\n        >>> import dgl\n        >>> import dgl.function as fn\n\n        Instantiate a heterogrph and perform multiple rounds of message passing.\n\n        >>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2, 3], [2, 3, 4, 4])})\n        >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.], [3.], [4.], [5.]])\n        >>> g['follows'].prop_edges([[0, 1], [2, 3]], fn.copy_u('h', 'm'),\n        ...                         
fn.sum('m', 'h'), etype='follows')\n        >>> g.nodes['user'].data['h']\n        tensor([[1.],\n                [2.],\n                [1.],\n                [2.],\n                [3.]])\n\n        See Also\n        --------\n        prop_nodes\n        \"\"\"\n        for edge_frontier in edges_generator:\n            self.send_and_recv(\n                edge_frontier,\n                message_func,\n                reduce_func,\n                apply_node_func,\n                etype=etype,\n            )\n\n    #################################################################\n    # Misc\n    #################################################################\n\n    def filter_nodes(self, predicate, nodes=ALL, ntype=None):\n        \"\"\"Return the IDs of the nodes with the given node type that satisfy\n        the given predicate.\n\n        Parameters\n        ----------\n        predicate : callable\n            A function of signature ``func(nodes) -> Tensor``.\n            ``nodes`` are :class:`dgl.NodeBatch` objects.\n            Its output tensor should be a 1D boolean tensor with\n            each element indicating whether the corresponding node in\n            the batch satisfies the predicate.\n        nodes : node ID(s), optional\n            The node(s) for query. The allowed formats are:\n\n            - Tensor: A 1D tensor that contains the node(s) for query, whose data type\n              and device should be the same as the :py:attr:`idtype` and device of the graph.\n            - iterable[int] : Similar to the tensor, but stores node IDs in a sequence\n              (e.g. list, tuple, numpy.ndarray).\n\n            By default, it considers all nodes.\n        ntype : str, optional\n            The node type for query. If the graph has multiple node types, one must\n            specify the argument. 
Otherwise, it can be omitted.\n\n        Returns\n        -------\n        Tensor\n            A 1D tensor that contains the ID(s) of the node(s) that satisfy the predicate.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Define a predicate function.\n\n        >>> def nodes_with_feature_one(nodes):\n        ...     # Whether a node has feature 1\n        ...     return (nodes.data['h'] == 1.).squeeze(1)\n\n        Filter nodes for a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))\n        >>> g.ndata['h'] = torch.tensor([[0.], [1.], [1.], [0.]])\n        >>> print(g.filter_nodes(nodes_with_feature_one))\n        tensor([1, 2])\n\n        Filter on nodes with IDs 0 and 1\n\n        >>> print(g.filter_nodes(nodes_with_feature_one, nodes=torch.tensor([0, 1])))\n        tensor([1])\n\n        Filter nodes for a heterogeneous graph.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n        ...                                 
torch.tensor([0, 0, 1, 1]))})\n        >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.]])\n        >>> g.nodes['game'].data['h'] = torch.tensor([[0.], [1.]])\n        >>> # Filter for 'user' nodes\n        >>> print(g.filter_nodes(nodes_with_feature_one, ntype='user'))\n        tensor([1, 2])\n        \"\"\"\n        if is_all(nodes):\n            nodes = self.nodes(ntype)\n        v = utils.prepare_tensor(self, nodes, \"nodes\")\n        if F.as_scalar(F.sum(self.has_nodes(v, ntype=ntype), dim=0)) != len(v):\n            raise DGLError(\"v contains invalid node IDs\")\n\n        with self.local_scope():\n            self.apply_nodes(\n                lambda nbatch: {\"_mask\": predicate(nbatch)}, nodes, ntype\n            )\n            ntype = self.ntypes[0] if ntype is None else ntype\n            mask = self.nodes[ntype].data[\"_mask\"]\n            if is_all(nodes):\n                return F.nonzero_1d(mask)\n            else:\n                return F.boolean_mask(v, F.gather_row(mask, v))\n\n    def filter_edges(self, predicate, edges=ALL, etype=None):\n        \"\"\"Return the IDs of the edges with the given edge type that satisfy\n        the given predicate.\n\n        Parameters\n        ----------\n        predicate : callable\n            A function of signature ``func(edges) -> Tensor``.\n            ``edges`` are :class:`dgl.EdgeBatch` objects.\n            Its output tensor should be a 1D boolean tensor with\n            each element indicating whether the corresponding edge in\n            the batch satisfies the predicate.\n        edges : edges\n            The edges to send and receive messages on. The allowed input formats are:\n\n            * ``int``: A single edge ID.\n            * Int Tensor: Each element is an edge ID.  
The tensor must have the same device type\n              and ID data type as the graph's.\n            * iterable[int]: Each element is an edge ID.\n            * (Tensor, Tensor): The node-tensors format where the i-th elements\n              of the two tensors specify an edge.\n            * (iterable[int], iterable[int]): Similar to the node-tensors format but\n              stores edge endpoints in python iterables.\n\n            By default, it considers all the edges.\n        etype : str or (str, str, str), optional\n            The type name of the edges. The allowed type name formats are:\n\n            * ``(str, str, str)`` for source node type, edge type and destination node type.\n            * or one ``str`` edge type name if the name can uniquely identify a\n              triplet format in the graph.\n\n            Can be omitted if the graph has only one type of edges.\n\n        Returns\n        -------\n        Tensor\n            A 1D tensor that contains the ID(s) of the edge(s) that satisfy the predicate.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Define a predicate function.\n\n        >>> def edges_with_feature_one(edges):\n        ...     # Whether an edge has feature 1\n        ...     return (edges.data['h'] == 1.).squeeze(1)\n\n        Filter edges for a homogeneous graph.\n\n        >>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))\n        >>> g.edata['h'] = torch.tensor([[0.], [1.], [1.]])\n        >>> print(g.filter_edges(edges_with_feature_one))\n        tensor([1, 2])\n\n        Filter on edges with IDs 0 and 1\n\n        >>> print(g.filter_edges(edges_with_feature_one, edges=torch.tensor([0, 1])))\n        tensor([1])\n\n        Filter edges for a heterogeneous graph.\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n        ...               
                  torch.tensor([0, 0, 1, 1])),\n        ...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2]))})\n        >>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])\n        >>> # Filter for 'plays' nodes\n        >>> print(g.filter_edges(edges_with_feature_one, etype='plays'))\n        tensor([1, 2])\n        \"\"\"\n        if is_all(edges):\n            pass\n        elif isinstance(edges, tuple):\n            u, v = edges\n            srctype, _, dsttype = self.to_canonical_etype(etype)\n            u = utils.prepare_tensor(self, u, \"u\")\n            if F.as_scalar(\n                F.sum(self.has_nodes(u, ntype=srctype), dim=0)\n            ) != len(u):\n                raise DGLError(\"edges[0] contains invalid node IDs\")\n            v = utils.prepare_tensor(self, v, \"v\")\n            if F.as_scalar(\n                F.sum(self.has_nodes(v, ntype=dsttype), dim=0)\n            ) != len(v):\n                raise DGLError(\"edges[1] contains invalid node IDs\")\n        elif isinstance(edges, Iterable) or F.is_tensor(edges):\n            edges = utils.prepare_tensor(self, edges, \"edges\")\n            min_eid = F.as_scalar(F.min(edges, 0))\n            if len(edges) > 0 > min_eid:\n                raise DGLError(\"Invalid edge ID {:d}\".format(min_eid))\n            max_eid = F.as_scalar(F.max(edges, 0))\n            if len(edges) > 0 and max_eid >= self.num_edges(etype):\n                raise DGLError(\"Invalid edge ID {:d}\".format(max_eid))\n        else:\n            raise ValueError(\"Unsupported type of edges:\", type(edges))\n\n        with self.local_scope():\n            self.apply_edges(\n                lambda ebatch: {\"_mask\": predicate(ebatch)}, edges, etype\n            )\n            etype = self.canonical_etypes[0] if etype is None else etype\n            mask = self.edges[etype].data[\"_mask\"]\n            if is_all(edges):\n                return F.nonzero_1d(mask)\n            
else:\n                if isinstance(edges, tuple):\n                    e = self.edge_ids(edges[0], edges[1], etype=etype)\n                else:\n                    e = utils.prepare_tensor(self, edges, \"edges\")\n                return F.boolean_mask(e, F.gather_row(mask, e))\n\n    @property\n    def device(self):\n        \"\"\"Get the device of the graph.\n\n        Returns\n        -------\n        device context\n            The device of the graph, which should be a framework-specific device object\n            (e.g., ``torch.device``).\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a homogeneous graph for demonstration.\n\n        >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n        >>> print(g.device)\n        device(type='cpu')\n\n        The case of heterogeneous graphs is the same.\n        \"\"\"\n        return F.to_backend_ctx(self._graph.ctx)\n\n    def to(self, device, **kwargs):  # pylint: disable=invalid-name\n        \"\"\"Move ndata, edata and graph structure to the targeted device (cpu/gpu).\n\n        If the graph is already on the specified device, the function directly returns it.\n        Otherwise, it returns a cloned graph on the specified device.\n\n        Note that data of node and edge features are not moved to the specified\n        device before being accessed or `materialize_data()` is called.\n\n        Parameters\n        ----------\n        device : Framework-specific device context object\n            The context to move data to (e.g., ``torch.device``).\n        kwargs : Key-word arguments.\n            Key-word arguments fed to the framework copy function.\n\n        Returns\n        -------\n        DGLGraph\n            The graph on the specified device.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> 
import torch\n\n        >>> g = dgl.graph((torch.tensor([1, 0]), torch.tensor([1, 2])))\n        >>> g.ndata['h'] = torch.ones(3, 1)\n        >>> g.edata['h'] = torch.zeros(2, 2)\n        >>> g1 = g.to(torch.device('cuda:0'))\n        >>> print(g1.device)\n        device(type='cuda', index=0)\n        >>> print(g1.ndata['h'].device)\n        device(type='cuda', index=0)\n        >>> print(g1.nodes().device)\n        device(type='cuda', index=0)\n\n        The original graph is still on CPU.\n\n        >>> print(g.device)\n        device(type='cpu')\n        >>> print(g.ndata['h'].device)\n        device(type='cpu')\n        >>> print(g.nodes().device)\n        device(type='cpu')\n\n        The case of heterogeneous graphs is the same.\n        \"\"\"\n        if device is None or self.device == device:\n            return self\n\n        ret = copy.copy(self)\n\n        # 1. Copy graph structure\n        ret._graph = self._graph.copy_to(utils.to_dgl_context(device))\n\n        # 2. Copy features\n        # TODO(minjie): handle initializer\n        new_nframes = []\n        for nframe in self._node_frames:\n            new_nframes.append(nframe.to(device, **kwargs))\n        ret._node_frames = new_nframes\n\n        new_eframes = []\n        for eframe in self._edge_frames:\n            new_eframes.append(eframe.to(device, **kwargs))\n        ret._edge_frames = new_eframes\n\n        # 2. 
Copy misc info\n        if self._batch_num_nodes is not None:\n            new_bnn = {\n                k: F.copy_to(num, device, **kwargs)\n                for k, num in self._batch_num_nodes.items()\n            }\n            ret._batch_num_nodes = new_bnn\n        if self._batch_num_edges is not None:\n            new_bne = {\n                k: F.copy_to(num, device, **kwargs)\n                for k, num in self._batch_num_edges.items()\n            }\n            ret._batch_num_edges = new_bne\n\n        return ret\n\n    def cpu(self):\n        \"\"\"Return a new copy of this graph on CPU.\n\n        Returns\n        -------\n        DGLGraph\n            Graph on CPU.\n\n        See Also\n        --------\n        to\n        \"\"\"\n        return self.to(F.cpu())\n\n    def materialize_data(self):\n        \"\"\"Materialize the graph data on the current device.\n\n        This method is a no-op if the graph data is already materialized.\n\n        Returns\n        -------\n        DGLGraph\n            The graph on the current device.\n        \"\"\"\n        for frame in itertools.chain(self._node_frames, self._edge_frames):\n            for col in frame._columns.values():\n                col.data  # pylint: disable=pointless-statement\n        return self\n\n    def pin_memory_(self):\n        \"\"\"Pin the graph structure and node/edge data to the page-locked memory for\n        GPU zero-copy access.\n\n        This is an **inplace** method. The graph structure must be on CPU to be pinned.\n        If the graph structure is already pinned, the function directly returns it.\n\n        Materialization of new sparse formats for pinned graphs is not allowed.\n        To avoid implicit formats materialization during training,\n        you should create all the needed formats before pinning.\n        But cloning and materialization is fine. 
See the examples below.\n\n        Returns\n        -------\n        DGLGraph\n            The pinned graph.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> g = dgl.graph((torch.tensor([1, 0]), torch.tensor([1, 2])))\n        >>> g.pin_memory_()\n\n        Materialization of new sparse formats is not allowed for pinned graphs.\n\n        >>> g.create_formats_()  # This would raise an error! You should do this before pinning.\n\n        Cloning and materializing new formats is allowed. The returned graph is **not** pinned.\n\n        >>> g1 = g.formats(['csc'])\n        >>> assert not g1.is_pinned()\n\n        The pinned graph can be accessed from both CPU and GPU. The concrete device depends\n        on the context of ``query``. For example, ``eid`` in ``find_edges()`` is a query.\n        When ``eid`` is on CPU, ``find_edges()`` is executed on CPU, and the returned\n        values are CPU tensors.\n\n        >>> g.unpin_memory_()\n        >>> g.create_formats_()\n        >>> g.pin_memory_()\n        >>> eid = torch.tensor([1])\n        >>> g.find_edges(eid)\n        (tensor([0]), tensor([2]))\n\n        Moving ``eid`` to GPU, ``find_edges()`` will be executed on GPU, and the returned\n        values are GPU tensors.\n\n        >>> eid = eid.to('cuda:0')\n        >>> g.find_edges(eid)\n        (tensor([0], device='cuda:0'), tensor([2], device='cuda:0'))\n\n        If you don't provide a ``query``, methods will be executed on CPU by default.\n\n        >>> g.in_degrees()\n        tensor([0, 1, 1])\n        \"\"\"\n        if not self._graph.is_pinned():\n            if F.device_type(self.device) != \"cpu\":\n                raise DGLError(\n                    \"The graph structure must be on CPU to be pinned.\"\n                )\n            self._graph.pin_memory_()\n        for frame in itertools.chain(self._node_frames, self._edge_frames):\n            
for col in frame._columns.values():\n                col.pin_memory_()\n\n        return self\n\n    def unpin_memory_(self):\n        \"\"\"Unpin the graph structure and node/edge data from the page-locked memory.\n\n        This is an **inplace** method. If the graph structure is not pinned,\n        e.g., on CPU or GPU, the function directly returns it.\n\n        Returns\n        -------\n        DGLGraph\n            The unpinned graph.\n        \"\"\"\n        if self._graph.is_pinned():\n            self._graph.unpin_memory_()\n        for frame in itertools.chain(self._node_frames, self._edge_frames):\n            for col in frame._columns.values():\n                col.unpin_memory_()\n\n        return self\n\n    def is_pinned(self):\n        \"\"\"Check if the graph structure is pinned to the page-locked memory.\n\n        Returns\n        -------\n        bool\n            True if the graph structure is pinned.\n        \"\"\"\n        return self._graph.is_pinned()\n\n    def record_stream(self, stream):\n        \"\"\"Record the stream that is using this graph.\n        This method only supports the PyTorch backend and requires graphs on the GPU.\n\n        Parameters\n        ----------\n        stream : torch.cuda.Stream\n            The stream that is using this graph.\n\n        Returns\n        -------\n        DGLGraph\n            self.\n        \"\"\"\n        if F.get_preferred_backend() != \"pytorch\":\n            raise DGLError(\"record_stream only support the PyTorch backend.\")\n        if F.device_type(self.device) != \"cuda\":\n            raise DGLError(\"The graph must be on GPU to be recorded.\")\n        self._graph.record_stream(stream)\n        for frame in itertools.chain(self._node_frames, self._edge_frames):\n            for col in frame._columns.values():\n                col.record_stream(stream)\n\n        return self\n\n    def clone(self):\n        \"\"\"Return a heterograph object that is a clone of current graph.\n\n     
   Returns\n        -------\n        DGLGraph\n            The graph object that is a clone of current graph.\n        \"\"\"\n        # XXX(minjie): Do a shallow copy first to clone some internal metagraph information.\n        #   Not a beautiful solution though.\n        ret = copy.copy(self)\n\n        # Clone the graph structure\n        meta_edges = []\n        for s_ntype, _, d_ntype in self.canonical_etypes:\n            meta_edges.append(\n                (self.get_ntype_id(s_ntype), self.get_ntype_id(d_ntype))\n            )\n\n        metagraph = graph_index.from_edge_list(meta_edges, True)\n        # rebuild graph idx\n        num_nodes_per_type = [\n            self.num_nodes(c_ntype) for c_ntype in self.ntypes\n        ]\n        relation_graphs = [\n            self._graph.get_relation_graph(self.get_etype_id(c_etype))\n            for c_etype in self.canonical_etypes\n        ]\n        ret._graph = heterograph_index.create_heterograph_from_relations(\n            metagraph,\n            relation_graphs,\n            utils.toindex(num_nodes_per_type, \"int64\"),\n        )\n\n        # Clone the frames\n        ret._node_frames = [fr.clone() for fr in self._node_frames]\n        ret._edge_frames = [fr.clone() for fr in self._edge_frames]\n\n        # Copy the batch information\n        ret._batch_num_nodes = copy.copy(self._batch_num_nodes)\n        ret._batch_num_edges = copy.copy(self._batch_num_edges)\n\n        return ret\n\n    def local_var(self):\n        \"\"\"Return a graph object for usage in a local function scope.\n\n        The returned graph object shares the feature data and graph structure of this graph.\n        However, any out-place mutation to the feature data will not reflect to this graph,\n        thus making it easier to use in a function scope (e.g. 
forward computation of a model).\n\n        If set, the local graph object will use same initializers for node features and\n        edge features.\n\n        Returns\n        -------\n        DGLGraph\n            The graph object for a local variable.\n\n        Notes\n        -----\n        Inplace operations do reflect to the original graph. This function also has little\n        overhead when the number of feature tensors in this graph is small.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a function for computation on graphs.\n\n        >>> def foo(g):\n        ...     g = g.local_var()\n        ...     g.edata['h'] = torch.ones((g.num_edges(), 3))\n        ...     g.edata['h2'] = torch.ones((g.num_edges(), 3))\n        ...     return g.edata['h']\n\n        ``local_var`` avoids changing the graph features when exiting the function.\n\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([0, 0, 2])))\n        >>> g.edata['h'] = torch.zeros((g.num_edges(), 3))\n        >>> newh = foo(g)\n        >>> print(g.edata['h'])  # still get tensor of all zeros\n        tensor([[0., 0., 0.],\n                [0., 0., 0.],\n                [0., 0., 0.]])\n        >>> 'h2' in g.edata      # new feature set in the function scope is not found\n        False\n\n        In-place operations will still reflect to the original graph.\n\n        >>> def foo(g):\n        ...     g = g.local_var()\n        ...     # in-place operation\n        ...     g.edata['h'] += 1\n        ...     
return g.edata['h']\n\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([0, 0, 2])))\n        >>> g.edata['h'] = torch.zeros((g.num_edges(), 1))\n        >>> newh = foo(g)\n        >>> print(g.edata['h'])  # the result changes\n        tensor([[1.],\n                [1.],\n                [1.]])\n\n        See Also\n        --------\n        local_scope\n        \"\"\"\n        ret = copy.copy(self)\n        ret._node_frames = [fr.clone() for fr in self._node_frames]\n        ret._edge_frames = [fr.clone() for fr in self._edge_frames]\n        return ret\n\n    @contextmanager\n    def local_scope(self):\n        \"\"\"Enter a local scope context for the graph.\n\n        By entering a local scope, any out-place mutation to the feature data will\n        not reflect to the original graph, thus making it easier to use in a function scope\n        (e.g. forward computation of a model).\n\n        If set, the local scope will use same initializers for node features and\n        edge features.\n\n        Notes\n        -----\n        Inplace operations do reflect to the original graph. This function also has little\n        overhead when the number of feature tensors in this graph is small.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a function for computation on graphs.\n\n        >>> def foo(g):\n        ...     with g.local_scope():\n        ...         g.edata['h'] = torch.ones((g.num_edges(), 3))\n        ...         g.edata['h2'] = torch.ones((g.num_edges(), 3))\n        ...         
return g.edata['h']\n\n        ``local_scope`` avoids changing the graph features when exiting the function.\n\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([0, 0, 2])))\n        >>> g.edata['h'] = torch.zeros((g.num_edges(), 3))\n        >>> newh = foo(g)\n        >>> print(g.edata['h'])  # still get tensor of all zeros\n        tensor([[0., 0., 0.],\n                [0., 0., 0.],\n                [0., 0., 0.]])\n        >>> 'h2' in g.edata      # new feature set in the function scope is not found\n        False\n\n        In-place operations will still reflect to the original graph.\n\n        >>> def foo(g):\n        ...     with g.local_scope():\n        ...         # in-place operation\n        ...         g.edata['h'] += 1\n        ...         return g.edata['h']\n\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([0, 0, 2])))\n        >>> g.edata['h'] = torch.zeros((g.num_edges(), 1))\n        >>> newh = foo(g)\n        >>> print(g.edata['h'])  # the result changes\n        tensor([[1.],\n                [1.],\n                [1.]])\n\n        See Also\n        --------\n        local_var\n        \"\"\"\n        old_nframes = self._node_frames\n        old_eframes = self._edge_frames\n        self._node_frames = [fr.clone() for fr in self._node_frames]\n        self._edge_frames = [fr.clone() for fr in self._edge_frames]\n        try:\n            yield\n        finally:\n            self._node_frames = old_nframes\n            self._edge_frames = old_eframes\n\n    def formats(self, formats=None):\n        r\"\"\"Get a cloned graph with the specified allowed sparse format(s) or\n        query for the usage status of sparse formats.\n\n        The API copies both the graph structure and the features.\n\n        If the input graph has multiple edge types, they will have the same\n        sparse format.\n\n        When ``formats`` is not None, if the intersection between `formats` and\n        the current graph's created 
sparse format(s) is not empty, the returned\n        cloned graph only retains all sparse format(s) in the intersection. If\n        the intersection is empty, a sparse format will be selected to be\n        created following the order of ``'coo' -> 'csr' -> 'csc'``.\n\n        Parameters\n        ----------\n        formats : str or list of str or None\n\n            * If formats is None, return the usage status of sparse formats\n            * Otherwise, it can be ``'coo'``/``'csr'``/``'csc'`` or a sublist of\n              them, specifying the sparse formats to use.\n\n        Returns\n        -------\n        dict or DGLGraph\n\n            * If formats is None, the result will be a dict recording the usage\n              status of sparse formats.\n            * Otherwise, a DGLGraph will be returned, which is a clone of the\n              original graph with the specified allowed sparse format(s)\n              ``formats``.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        **Homographs or Heterographs with A Single Edge Type**\n\n        >>> g = dgl.graph(([0, 0, 1], [2, 3, 2]))\n        >>> g.ndata['h'] = torch.ones(4, 1)\n        >>> # Check status of format usage.\n        >>> g.formats()\n        {'created': ['coo'], 'not created': ['csr', 'csc']}\n        >>> # Get a clone of the graph with 'csr' format.\n        >>> csr_g = g.formats('csr')\n        >>> # Only allowed formats will be displayed in the status query.\n        >>> csr_g.formats()\n        {'created': ['csr'], 'not created': []}\n        >>> # Features are copied as well.\n        >>> csr_g.ndata['h']\n        tensor([[1.],\n                [1.],\n                [1.],\n                [1.]])\n\n        **Heterographs with Multiple Edge Types**\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n        ...                         
        torch.tensor([0, 0, 1, 1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n        ...                                         torch.tensor([0, 1]))\n        ...     })\n        >>> g.formats()\n        {'created': ['coo'], 'not created': ['csr', 'csc']}\n        >>> # Get a clone of the graph with 'csr' format.\n        >>> csr_g = g.formats('csr')\n        >>> # Only allowed formats will be displayed in the status query.\n        >>> csr_g.formats()\n        {'created': ['csr'], 'not created': []}\n\n        **When formats intersects with created formats**\n\n        >>> g = dgl.graph(([0, 0, 1], [2, 3, 2]))\n        >>> g = g.formats(['coo', 'csr'])\n        >>> g.create_formats_()\n        >>> g.formats()\n        {'created': ['coo', 'csr'], 'not created': []}\n        >>> # Get a clone of the graph allowed formats 'csr' and 'csc'.\n        >>> csr_csc_g = g.formats(['csr', 'csc'])\n        >>> # Only the intersection 'csr' will be retained.\n        >>> csr_csc_g.formats()\n        {'created': ['csr'], 'not created': ['csc']}\n\n        **When formats doesn't intersect with created formats**\n\n        >>> g = dgl.graph(([0, 0, 1], [2, 3, 2]))\n        >>> g = g.formats('coo')\n        >>> g.formats()\n        {'created': ['coo'], 'not created': []}\n        >>> # Get a clone of the graph allowed formats 'csr' and 'csc'.\n        >>> csr_csc_g = g.formats(['csr', 'csc'])\n        >>> # Since the intersection is empty, 'csr' will be created as it is\n        >>> # first in the order of 'coo' -> 'csr' -> 'csc'.\n        >>> csr_csc_g.formats()\n        {'created': ['csr'], 'not created': ['csc']}\n        \"\"\"\n        if formats is None:\n            # Return the format information.\n            return self._graph.formats()\n        else:\n            # Convert the graph to use another allowed format.\n            ret = copy.copy(self)\n            ret._graph = self._graph.formats(formats)\n            return ret\n\n    
def create_formats_(self):\n        r\"\"\"Create all sparse matrices allowed for the graph.\n\n        By default, we create sparse matrices for a graph only when necessary.\n        In some cases we may want to create them immediately (e.g. in a\n        multi-process data loader), which can be achieved via this API.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        **Homographs or Heterographs with A Single Edge Type**\n\n        >>> g = dgl.graph(([0, 0, 1], [2, 3, 2]))\n        >>> g.formats()\n        {'created': ['coo'], 'not created': ['csr', 'csc']}\n        >>> g.create_formats_()\n        >>> g.formats()\n        {'created': ['coo', 'csr', 'csc'], 'not created': []}\n\n        **Heterographs with Multiple Edge Types**\n\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n        ...                                 torch.tensor([0, 0, 1, 1])),\n        ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n        ...                                         torch.tensor([0, 1]))\n        ...     })\n        >>> g.formats()\n        {'created': ['coo'], 'not created': ['csr', 'csc']}\n        >>> g.create_formats_()\n        >>> g.formats()\n        {'created': ['coo', 'csr', 'csc'], 'not created': []}\n        \"\"\"\n        return self._graph.create_formats_()\n\n    def astype(self, idtype):\n        \"\"\"Cast this graph to use another ID type.\n\n        Features are copied (shallow copy) to the new graph.\n\n        Parameters\n        ----------\n        idtype : Data type object.\n            New ID type. 
Can only be int32 or int64.\n\n        Returns\n        -------\n        DGLGraph\n            Graph in the new ID type.\n        \"\"\"\n        if idtype is None:\n            return self\n        utils.check_valid_idtype(idtype)\n        if self.idtype == idtype:\n            return self\n        bits = 32 if idtype == F.int32 else 64\n        ret = copy.copy(self)\n        ret._graph = self._graph.asbits(bits)\n        return ret\n\n    # TODO: Formats should not be specified, just saving all the materialized formats\n    def shared_memory(self, name, formats=(\"coo\", \"csr\", \"csc\")):\n        \"\"\"Return a copy of this graph in shared memory, without node data or edge data.\n\n        It moves the graph index to shared memory and returns a DGLGraph object which\n        has the same graph structure, node types and edge types but does not contain node data\n        or edge data.\n\n        Parameters\n        ----------\n        name : str\n            The name of the shared memory.\n        formats : str or a list of str (optional)\n            Desired formats to be materialized.\n\n        Returns\n        -------\n        DGLGraph\n            The graph in shared memory\n        \"\"\"\n        assert len(name) > 0, \"The name of shared memory cannot be empty\"\n        assert len(formats) > 0\n        if isinstance(formats, str):\n            formats = [formats]\n        for fmt in formats:\n            assert fmt in (\n                \"coo\",\n                \"csr\",\n                \"csc\",\n            ), \"{} is not coo, csr or csc\".format(fmt)\n        gidx = self._graph.shared_memory(\n            name, self.ntypes, self.etypes, formats\n        )\n        return DGLGraph(gidx, self.ntypes, self.etypes)\n\n    def long(self):\n        \"\"\"Cast the graph to one with idtype int64\n\n        If the graph already has idtype int64, the function directly returns it. 
Otherwise,\n        it returns a cloned graph of idtype int64 with features copied (shallow copy).\n\n        Returns\n        -------\n        DGLGraph\n            The graph of idtype int64.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a graph of idtype int32.\n\n        >>> # (0, 1), (0, 2), (1, 2)\n        >>> g = dgl.graph((torch.tensor([0, 0, 1]).int(), torch.tensor([1, 2, 2]).int()))\n        >>> g.ndata['feat'] = torch.ones(3, 1)\n        >>> g.idtype\n        torch.int32\n\n        Cast the graph to one of idtype int64.\n\n        >>> # A cloned graph with an idtype of int64\n        >>> g_long = g.long()\n        >>> g_long.idtype\n        torch.int64\n        >>> # The idtype of the original graph does not change.\n        >>> g.idtype\n        torch.int32\n        >>> g_long.edges()\n        (tensor([0, 0, 1]), tensor([1, 2, 2]))\n        >>> g_long.ndata\n        {'feat': tensor([[1.],\n                         [1.],\n                         [1.]])}\n\n        See Also\n        --------\n        int\n        idtype\n        \"\"\"\n        return self.astype(F.int64)\n\n    def int(self):\n        \"\"\"Cast the graph to one with idtype int32\n\n        If the graph already has idtype int32, the function directly returns it. 
Otherwise,\n        it returns a cloned graph of idtype int32 with features copied (shallow copy).\n\n        Returns\n        -------\n        DGLGraph\n            The graph of idtype int32.\n\n        Examples\n        --------\n\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        Create a graph of idtype int64.\n\n        >>> # (0, 1), (0, 2), (1, 2)\n        >>> g = dgl.graph((torch.tensor([0, 0, 1]), torch.tensor([1, 2, 2])))\n        >>> g.ndata['feat'] = torch.ones(3, 1)\n        >>> g.idtype\n        torch.int64\n\n        Cast the graph to one of idtype int32.\n\n        >>> # A cloned graph with an idtype of int32\n        >>> g_int = g.int()\n        >>> g_int.idtype\n        torch.int32\n        >>> # The idtype of the original graph does not change.\n        >>> g.idtype\n        torch.int64\n        >>> g_int.edges()\n        (tensor([0, 0, 1], dtype=torch.int32), tensor([1, 2, 2], dtype=torch.int32))\n        >>> g_int.ndata\n        {'feat': tensor([[1.],\n                         [1.],\n                         [1.]])}\n\n        See Also\n        --------\n        long\n        idtype\n        \"\"\"\n        return self.astype(F.int32)\n\n\n############################################################\n# Internal APIs\n############################################################\n\n\ndef make_canonical_etypes(etypes, ntypes, metagraph):\n    \"\"\"Internal function to convert etype name to (srctype, etype, dsttype)\n\n    Parameters\n    ----------\n    etypes : list of str\n        Edge type list\n    ntypes : list of str\n        Node type list\n    metagraph : GraphIndex\n        Meta graph.\n\n    Returns\n    -------\n    list of tuples (srctype, etype, dsttype)\n    \"\"\"\n    # sanity check\n    if len(etypes) != metagraph.num_edges():\n        raise DGLError(\n            \"Length of edge type list must match the number of \"\n            \"edges in the metagraph. 
{} vs {}\".format(\n                len(etypes), metagraph.num_edges()\n            )\n        )\n    if len(ntypes) != metagraph.num_nodes():\n        raise DGLError(\n            \"Length of nodes type list must match the number of \"\n            \"nodes in the metagraph. {} vs {}\".format(\n                len(ntypes), metagraph.num_nodes()\n            )\n        )\n    if len(etypes) == 1 and len(ntypes) == 1:\n        return [(ntypes[0], etypes[0], ntypes[0])]\n    src, dst, eid = metagraph.edges(order=\"eid\")\n    rst = [\n        (ntypes[sid], etypes[eid], ntypes[did])\n        for sid, did, eid in zip(src, dst, eid)\n    ]\n    return rst\n\n\ndef find_src_dst_ntypes(ntypes, metagraph):\n    \"\"\"Internal function to split ntypes into SRC and DST categories.\n\n    If the metagraph is not a uni-bipartite graph (so that the SRC and DST categories\n    are not well-defined), return None.\n\n    For node types that are isolated (i.e., no relation is associated with it), they\n    are assigned to the SRC category.\n\n    Parameters\n    ----------\n    ntypes : list of str\n        Node type list\n    metagraph : GraphIndex\n        Meta graph.\n\n    Returns\n    -------\n    (dict[str, int], dict[str, int]) or None\n        Node types belonging to SRC and DST categories. Types are stored in\n        a dictionary from type name to type id. 
Return None if the graph is\n        not uni-bipartite.\n    \"\"\"\n    ret = _CAPI_DGLFindSrcDstNtypes(metagraph)\n    if ret is None:\n        return None\n    else:\n        src, dst = ret\n        srctypes = {ntypes[tid]: tid for tid in src}\n        dsttypes = {ntypes[tid]: tid for tid in dst}\n        return srctypes, dsttypes\n\n\ndef pad_tuple(tup, length, pad_val=None):\n    \"\"\"Pad the given tuple to the given length.\n\n    If the input is not a tuple, convert it to a tuple of length one.\n    Return None if pad fails.\n    \"\"\"\n    if not isinstance(tup, tuple):\n        tup = (tup,)\n    if len(tup) > length:\n        return None\n    elif len(tup) == length:\n        return tup\n    else:\n        return tup + (pad_val,) * (length - len(tup))\n\n\ndef reduce_dict_data(frames, reducer, order=None):\n    \"\"\"Merge tensor dictionaries into one. Resolve conflict fields using reducer.\n\n    Parameters\n    ----------\n    frames : list[dict[str, Tensor]]\n        Input tensor dictionaries\n    reducer : str or callable function\n        One of \"sum\", \"max\", \"min\", \"mean\", \"stack\" or a callable function.\n        If a callable function is provided, the input arguments must be a single list\n        of tensors containing aggregation results from each edge type, and the\n        output of function must be a single tensor.\n    order : list[Int], optional\n        Merge order hint. Useful for \"stack\" reducer.\n        If provided, each integer indicates the relative order\n        of the ``frames`` list. Frames are sorted according to this list\n        in ascending order. Tie is not handled so make sure the order values\n        are distinct.\n\n    Returns\n    -------\n    dict[str, Tensor]\n        Merged frame\n    \"\"\"\n    if len(frames) == 1 and reducer != \"stack\":\n        # Directly return the only one input. 
Stack reducer requires\n        # modifying tensor shape.\n        return frames[0]\n    if callable(reducer):\n        merger = reducer\n    elif reducer == \"stack\":\n        # Stack order does not matter. However, it must be consistent!\n        if order:\n            assert len(order) == len(frames)\n            sorted_with_key = sorted(zip(frames, order), key=lambda x: x[1])\n            frames = list(zip(*sorted_with_key))[0]\n\n        def merger(flist):\n            return F.stack(flist, 1)\n\n    else:\n        redfn = getattr(F, reducer, None)\n        if redfn is None:\n            raise DGLError(\n                \"Invalid cross type reducer. Must be one of \"\n                '\"sum\", \"max\", \"min\", \"mean\" or \"stack\".'\n            )\n\n        def merger(flist):\n            return redfn(F.stack(flist, 0), 0) if len(flist) > 1 else flist[0]\n\n    keys = set()\n    for frm in frames:\n        keys.update(frm.keys())\n    ret = {}\n    for k in keys:\n        flist = []\n        for frm in frames:\n            if k in frm:\n                flist.append(frm[k])\n        ret[k] = merger(flist)\n    return ret\n\n\ndef combine_frames(frames, ids, col_names=None):\n    \"\"\"Merge the frames into one frame, taking the common columns.\n\n    Return None if there is no common columns.\n\n    Parameters\n    ----------\n    frames : List[Frame]\n        List of frames\n    ids : List[int]\n        List of frame IDs\n    col_names : List[str], optional\n        Column names to consider. 
If not given, it considers all columns.\n\n    Returns\n    -------\n    Frame\n        The resulting frame\n    \"\"\"\n    # find common columns and check if their schemes match\n    schemes = None\n    for frame_id in ids:\n        frame = frames[frame_id]\n        if frame.num_rows == 0:\n            continue\n        if schemes is None:\n            schemes = frame.schemes\n            if col_names is not None:\n                schemes = {key: frame.schemes[key] for key in col_names}\n            continue\n        for key, scheme in list(schemes.items()):\n            if key in frame.schemes:\n                if frame.schemes[key] != scheme:\n                    raise DGLError(\n                        \"Cannot concatenate column %s with shape %s and shape %s\"\n                        % (key, frame.schemes[key], scheme)\n                    )\n            else:\n                del schemes[key]\n\n    if len(schemes) == 0:\n        return None\n\n    # concatenate the columns\n    to_cat = lambda key: [frames[i][key] for i in ids if frames[i].num_rows > 0]\n    cols = {key: F.cat(to_cat(key), dim=0) for key in schemes}\n    return Frame(cols)\n\n\ndef combine_names(names, ids=None):\n    \"\"\"Combine the selected names into one new name.\n\n    Parameters\n    ----------\n    names : list of str\n        String names\n    ids : numpy.ndarray, optional\n        Selected index\n\n    Returns\n    -------\n    str\n    \"\"\"\n    if ids is None:\n        return \"+\".join(sorted(names))\n    else:\n        selected = sorted([names[i] for i in ids])\n        return \"+\".join(selected)\n\n\nclass DGLBlock(DGLGraph):\n    \"\"\"Subclass that signifies the graph is a block created from\n    :func:`dgl.to_block`.\n    \"\"\"\n\n    # (BarclayII) I'm making a subclass because I don't want to make another version of\n    # serialization that contains the is_block flag.\n    is_block = True\n\n    def __repr__(self):\n        if (\n            len(self.srctypes) == 
1\n            and len(self.dsttypes) == 1\n            and len(self.etypes) == 1\n        ):\n            ret = \"Block(num_src_nodes={srcnode}, num_dst_nodes={dstnode}, num_edges={edge})\"\n            return ret.format(\n                srcnode=self.number_of_src_nodes(),\n                dstnode=self.number_of_dst_nodes(),\n                edge=self.num_edges(),\n            )\n        else:\n            ret = (\n                \"Block(num_src_nodes={srcnode},\\n\"\n                \"      num_dst_nodes={dstnode},\\n\"\n                \"      num_edges={edge},\\n\"\n                \"      metagraph={meta})\"\n            )\n            nsrcnode_dict = {\n                ntype: self.number_of_src_nodes(ntype)\n                for ntype in self.srctypes\n            }\n            ndstnode_dict = {\n                ntype: self.number_of_dst_nodes(ntype)\n                for ntype in self.dsttypes\n            }\n            nedge_dict = {\n                etype: self.num_edges(etype) for etype in self.canonical_etypes\n            }\n            meta = str(self.metagraph().edges(keys=True))\n            return ret.format(\n                srcnode=nsrcnode_dict,\n                dstnode=ndstnode_dict,\n                edge=nedge_dict,\n                meta=meta,\n            )\n\n\ndef _create_compute_graph(graph, u, v, eid, recv_nodes=None):\n    \"\"\"Create a computation graph from the given edges.\n\n    The compute graph is a uni-directional bipartite graph with only\n    one edge type. 
Similar to subgraph extraction, it stores the original node IDs\n    in the srcdata[NID] and dstdata[NID] and extracts features accordingly.\n    Edges are not relabeled.\n\n    This function is typically used during message passing to generate\n    a graph that contains only the active set of edges.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The input graph.\n    u : Tensor\n        Src nodes.\n    v : Tensor\n        Dst nodes.\n    eid : Tensor\n        Edge IDs.\n    recv_nodes : Tensor\n        Nodes that receive messages. If None, it is equal to unique(v).\n        Otherwise, it must be a superset of v and can contain nodes\n        that have no incoming edges.\n\n    Returns\n    -------\n    DGLGraph\n        A computation graph.\n    \"\"\"\n    if len(u) == 0:\n        # The computation graph has no edge and will not trigger message\n        # passing. However, because of the apply node phase, we still construct\n        # an empty graph to continue.\n        unique_src = new_u = new_v = u\n        assert recv_nodes is not None\n        unique_dst, _ = utils.relabel(recv_nodes)\n    else:\n        # relabel u and v to starting from 0\n        unique_src, src_map = utils.relabel(u)\n        if recv_nodes is None:\n            unique_dst, dst_map = utils.relabel(v)\n        else:\n            unique_dst, dst_map = utils.relabel(recv_nodes)\n        new_u = F.gather_row(src_map, u)\n        new_v = F.gather_row(dst_map, v)\n\n    srctype, etype, dsttype = graph.canonical_etypes[0]\n    # create graph\n    hgidx = heterograph_index.create_unitgraph_from_coo(\n        2, len(unique_src), len(unique_dst), new_u, new_v, [\"coo\", \"csr\", \"csc\"]\n    )\n    # create frame\n    srcframe = graph._node_frames[graph.get_ntype_id(srctype)].subframe(\n        unique_src\n    )\n    srcframe[NID] = unique_src\n    dstframe = graph._node_frames[graph.get_ntype_id(dsttype)].subframe(\n        unique_dst\n    )\n    dstframe[NID] = unique_dst\n    
eframe = graph._edge_frames[0].subframe(eid)\n    eframe[EID] = eid\n\n    return (\n        DGLGraph(\n            hgidx,\n            ([srctype], [dsttype]),\n            [etype],\n            node_frames=[srcframe, dstframe],\n            edge_frames=[eframe],\n        ),\n        unique_src,\n        unique_dst,\n        eid,\n    )\n\n\n_init_api(\"dgl.heterograph\")\n"
  },
  {
    "path": "python/dgl/heterograph_index.py",
    "content": "\"\"\"Module for heterogeneous graph index class definition.\"\"\"\nfrom __future__ import absolute_import\n\nimport itertools\nimport sys\n\nimport numpy as np\nimport scipy\n\nfrom . import backend as F, utils\nfrom ._ffi.function import _init_api\nfrom ._ffi.object import ObjectBase, register_object\nfrom ._ffi.streams import to_dgl_stream_handle\nfrom .base import dgl_warning, DGLError\nfrom .graph_index import from_coo\n\n\n@register_object(\"graph.HeteroGraph\")\nclass HeteroGraphIndex(ObjectBase):\n    \"\"\"HeteroGraph index object.\n\n    Note\n    ----\n    Do not create GraphIndex directly.\n    \"\"\"\n\n    def __new__(cls):\n        obj = ObjectBase.__new__(cls)\n        obj._cache = {}\n        return obj\n\n    def __getstate__(self):\n        \"\"\"Issue: https://github.com/pytorch/pytorch/issues/32351\n        Need to set the tensor created in the __getstate__ function\n         as object attribute to avoid potential bugs\n        \"\"\"\n        self._pk_state = _CAPI_DGLHeteroPickle(self)\n        return self._pk_state\n\n    def __setstate__(self, state):\n        self._cache = {}\n\n        # Pickle compatibility check\n        # TODO: we should store a storage version number in later releases.\n        if isinstance(state, HeteroPickleStates):\n            # post-0.4.3\n            self.__init_handle_by_constructor__(_CAPI_DGLHeteroUnpickle, state)\n        elif isinstance(state, tuple) and len(state) == 3:\n            # pre-0.4.2\n            metagraph, num_nodes, edges = state\n\n            self._cache = {}\n            # loop over etypes and recover unit graphs\n            rel_graphs = []\n            for i, edges_per_type in enumerate(edges):\n                src_ntype, dst_ntype = metagraph.find_edge(i)\n                num_src = num_nodes[src_ntype]\n                num_dst = num_nodes[dst_ntype]\n                src_id, dst_id, _ = edges_per_type\n                rel_graphs.append(\n                    
create_unitgraph_from_coo(\n                        1 if src_ntype == dst_ntype else 2,\n                        num_src,\n                        num_dst,\n                        src_id,\n                        dst_id,\n                        [\"coo\", \"csr\", \" csc\"],\n                    )\n                )\n            self.__init_handle_by_constructor__(\n                _CAPI_DGLHeteroCreateHeteroGraph, metagraph, rel_graphs\n            )\n\n    @property\n    def metagraph(self):\n        \"\"\"Meta graph\n\n        Returns\n        -------\n        GraphIndex\n            The meta graph.\n        \"\"\"\n        return _CAPI_DGLHeteroGetMetaGraph(self)\n\n    def is_metagraph_unibipartite(self):\n        \"\"\"Return whether or not the graph is unibiparite.\"\"\"\n        return _CAPI_DGLHeteroIsMetaGraphUniBipartite(self)\n\n    def number_of_ntypes(self):\n        \"\"\"Return number of node types.\"\"\"\n        return self.metagraph.num_nodes()\n\n    def number_of_etypes(self):\n        \"\"\"Return number of edge types.\"\"\"\n        return self.metagraph.num_edges()\n\n    def get_relation_graph(self, etype):\n        \"\"\"Get the unitgraph graph of the given edge/relation type.\n\n        Parameters\n        ----------\n        etype : int\n            The edge/relation type.\n\n        Returns\n        -------\n        HeteroGraphIndex\n            The unitgraph graph.\n        \"\"\"\n        return _CAPI_DGLHeteroGetRelationGraph(self, int(etype))\n\n    def flatten_relations(self, etypes):\n        \"\"\"Convert the list of requested unitgraph graphs into a single unitgraph\n        graph.\n\n        Parameters\n        ----------\n        etypes : list[int]\n            The edge/relation types.\n\n        Returns\n        -------\n        FlattenedHeteroGraph\n            A flattened heterograph object\n        \"\"\"\n        return _CAPI_DGLHeteroGetFlattenedGraph(self, etypes)\n\n    def add_nodes(self, ntype, num):\n        
\"\"\"Add nodes.\n\n        Parameters\n        ----------\n        ntype : int\n            Node type\n        num : int\n            Number of nodes to be added.\n        \"\"\"\n        _CAPI_DGLHeteroAddVertices(self, int(ntype), int(num))\n        self.clear_cache()\n\n    def add_edge(self, etype, u, v):\n        \"\"\"Add one edge.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        u : int\n            The src node.\n        v : int\n            The dst node.\n        \"\"\"\n        _CAPI_DGLHeteroAddEdge(self, int(etype), int(u), int(v))\n        self.clear_cache()\n\n    def add_edges(self, etype, u, v):\n        \"\"\"Add many edges.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        u : utils.Index\n            The src nodes.\n        v : utils.Index\n            The dst nodes.\n        \"\"\"\n        _CAPI_DGLHeteroAddEdges(\n            self, int(etype), u.todgltensor(), v.todgltensor()\n        )\n        self.clear_cache()\n\n    def clear(self):\n        \"\"\"Clear the graph.\"\"\"\n        _CAPI_DGLHeteroClear(self)\n        self._cache.clear()\n\n    @property\n    def dtype(self):\n        \"\"\"Return the data type of this graph index.\n\n        Returns\n        -------\n        DGLDataType\n            The data type of the graph.\n        \"\"\"\n        return _CAPI_DGLHeteroDataType(self)\n\n    @property\n    def ctx(self):\n        \"\"\"Return the context of this graph index.\n\n        Returns\n        -------\n        DGLContext\n            The context of the graph.\n        \"\"\"\n        return _CAPI_DGLHeteroContext(self)\n\n    def bits_needed(self, etype):\n        \"\"\"Return the number of integer bits needed to represent the unitgraph graph.\n\n        Parameters\n        ----------\n        etype : int\n            The edge type.\n\n        Returns\n        -------\n        int\n            The number of bits needed.\n        \"\"\"\n   
     stype, dtype = self.metagraph.find_edge(etype)\n        if (\n            self.num_edges(etype) >= 0x80000000\n            or self.num_nodes(stype) >= 0x80000000\n            or self.num_nodes(dtype) >= 0x80000000\n        ):\n            return 64\n        else:\n            return 32\n\n    def asbits(self, bits):\n        \"\"\"Transform the graph to a new one with the given number of bits storage.\n\n        NOTE: this method only works for immutable graph index\n\n        Parameters\n        ----------\n        bits : int\n            The number of integer bits (32 or 64)\n\n        Returns\n        -------\n        HeteroGraphIndex\n            The graph index stored using the given number of bits.\n        \"\"\"\n        return _CAPI_DGLHeteroAsNumBits(self, int(bits))\n\n    def copy_to(self, ctx):\n        \"\"\"Copy this immutable graph index to the given device context.\n\n        NOTE: this method only works for immutable graph index\n\n        Parameters\n        ----------\n        ctx : DGLContext\n            The target device context.\n\n        Returns\n        -------\n        HeteroGraphIndex\n            The graph index on the given device context.\n        \"\"\"\n        return _CAPI_DGLHeteroCopyTo(self, ctx.device_type, ctx.device_id)\n\n    def pin_memory(self):\n        \"\"\"Copies the graph structure to pinned memory, if it's not already\n        pinned.\n\n        NOTE: This function is similar to PyTorch's Tensor.pin_memory(), but\n              tailored for graphs. 
It utilizes the same pin_memory allocator as\n              PyTorch, so the lifecycle of the graph is also managed by PyTorch.\n              If a batch includes a DGL graph object (HeteroGraphIndex),\n              PyTorch's DataLoader memory pinning logic will detect it and\n              automatically activate this function when pin_memory=True.\n\n        Returns\n        -------\n        HeteroGraphIndex\n            The pinned graph index.\n        \"\"\"\n        return _CAPI_DGLHeteroPinMemory(self)\n\n    def pin_memory_(self):\n        \"\"\"Pin this graph to the page-locked memory.\n\n        NOTE: This is an inplace method to pin the current graph index, i.e.,\n              it does not require new memory allocation but simply flags the\n              existing graph structure to be page-locked. The graph structure\n              must be on CPU to be pinned. If the graph struture is already\n              pinned, the function directly returns it.\n\n        Returns\n        -------\n        HeteroGraphIndex\n            The pinned graph index.\n        \"\"\"\n        return _CAPI_DGLHeteroPinMemory_(self)\n\n    def unpin_memory_(self):\n        \"\"\"Unpin this graph from the page-locked memory.\n\n        NOTE: this is an inplace method.\n              If the graph struture is not pinned, e.g., on CPU or GPU,\n              the function directly returns it.\n\n        Returns\n        -------\n        HeteroGraphIndex\n            The unpinned graph index.\n        \"\"\"\n        return _CAPI_DGLHeteroUnpinMemory_(self)\n\n    def is_pinned(self):\n        \"\"\"Check if this graph is pinned to the page-locked memory.\n\n        Returns\n        -------\n        bool\n            True if the graph is pinned.\n        \"\"\"\n        return bool(_CAPI_DGLHeteroIsPinned(self))\n\n    def record_stream(self, stream):\n        \"\"\"Record the stream that is using this graph.\n\n        Parameters\n        ----------\n        stream : torch.cuda.Stream\n 
           The stream that is using this graph.\n\n        Returns\n        -------\n        HeteroGraphIndex\n            self.\n        \"\"\"\n        return _CAPI_DGLHeteroRecordStream(self, to_dgl_stream_handle(stream))\n\n    def shared_memory(\n        self, name, ntypes=None, etypes=None, formats=(\"coo\", \"csr\", \"csc\")\n    ):\n        \"\"\"Return a copy of this graph in shared memory\n\n        Parameters\n        ----------\n        name : str\n            The name of the shared memory.\n        ntypes : list of str\n            Name of node types\n        etypes : list of str\n            Name of edge types\n        format : list of str\n            Desired formats to be materialized.\n\n        Returns\n        -------\n        HeteroGraphIndex\n            The graph index in shared memory\n        \"\"\"\n        assert len(name) > 0, \"The name of shared memory cannot be empty\"\n        assert len(formats) > 0\n        for fmt in formats:\n            assert fmt in (\"coo\", \"csr\", \"csc\")\n        ntypes = [] if ntypes is None else ntypes\n        etypes = [] if etypes is None else etypes\n        return _CAPI_DGLHeteroCopyToSharedMem(\n            self, name, ntypes, etypes, formats\n        )\n\n    def is_multigraph(self):\n        \"\"\"Return whether the graph is a multigraph\n        The time cost will be O(E)\n\n        Returns\n        -------\n        bool\n            True if it is a multigraph, False otherwise.\n        \"\"\"\n        return bool(_CAPI_DGLHeteroIsMultigraph(self))\n\n    def is_readonly(self):\n        \"\"\"Return whether the graph index is read-only.\n\n        Returns\n        -------\n        bool\n            True if it is a read-only graph, False otherwise.\n        \"\"\"\n        return bool(_CAPI_DGLHeteroIsReadonly(self))\n\n    def num_nodes(self, ntype):\n        \"\"\"Return the number of nodes.\n\n        Parameters\n        ----------\n        ntype : int\n            Node type.\n\n        
Returns\n        -------\n        int\n            The number of nodes.\n        \"\"\"\n        return _CAPI_DGLHeteroNumVertices(self, int(ntype))\n\n    def num_edges(self, etype):\n        \"\"\"Return the number of edges.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type.\n\n        Returns\n        -------\n        int\n            The number of edges.\n        \"\"\"\n        return _CAPI_DGLHeteroNumEdges(self, int(etype))\n\n    # TODO(#5485): remove this method.\n    def number_of_nodes(self, ntype):\n        \"\"\"Return the number of nodes.\n\n        Parameters\n        ----------\n        ntype : int\n            Node type\n\n        Returns\n        -------\n        int\n            The number of nodes\n        \"\"\"\n        return _CAPI_DGLHeteroNumVertices(self, int(ntype))\n\n    # TODO(#5485): remove this method.\n    def number_of_edges(self, etype):\n        \"\"\"Return the number of edges.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n\n        Returns\n        -------\n        int\n            The number of edges\n        \"\"\"\n        return _CAPI_DGLHeteroNumEdges(self, int(etype))\n\n    def has_nodes(self, ntype, vids):\n        \"\"\"Return true if the nodes exist.\n\n        Parameters\n        ----------\n        ntype : int\n            Node type\n        vid : Tensor\n            Node IDs\n\n        Returns\n        -------\n        Tensor\n            0-1 array indicating existence\n        \"\"\"\n        return F.from_dgl_nd(\n            _CAPI_DGLHeteroHasVertices(self, int(ntype), F.to_dgl_nd(vids))\n        )\n\n    def has_edges_between(self, etype, u, v):\n        \"\"\"Return true if the edge exists.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        u : Tensor\n            Src node Ids.\n        v : Tensor\n            Dst node Ids.\n\n        Returns\n        -------\n        Tensor\n            0-1 
array indicating existence\n        \"\"\"\n        return F.from_dgl_nd(\n            _CAPI_DGLHeteroHasEdgesBetween(\n                self, int(etype), F.to_dgl_nd(u), F.to_dgl_nd(v)\n            )\n        )\n\n    def predecessors(self, etype, v):\n        \"\"\"Return the predecessors of the node.\n\n        Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        v : int\n            The node.\n\n        Returns\n        -------\n        Tensor\n            Array of predecessors\n        \"\"\"\n        return F.from_dgl_nd(\n            _CAPI_DGLHeteroPredecessors(self, int(etype), int(v))\n        )\n\n    def successors(self, etype, v):\n        \"\"\"Return the successors of the node.\n\n        Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        v : int\n            The node.\n\n        Returns\n        -------\n        Tensor\n            Array of successors\n        \"\"\"\n        return F.from_dgl_nd(\n            _CAPI_DGLHeteroSuccessors(self, int(etype), int(v))\n        )\n\n    def edge_ids_all(self, etype, u, v):\n        \"\"\"Return a triplet of arrays that contains the edge IDs.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        u : Tensor\n            The src nodes.\n        v : Tensor\n            The dst nodes.\n\n        Returns\n        -------\n        Tensor\n            The src nodes.\n        Tensor\n            The dst nodes.\n        Tensor\n            The edge ids.\n        \"\"\"\n        edge_array = _CAPI_DGLHeteroEdgeIdsAll(\n            self, int(etype), F.to_dgl_nd(u), F.to_dgl_nd(v)\n        )\n\n        src = F.from_dgl_nd(edge_array(0))\n        dst = F.from_dgl_nd(edge_array(1))\n        eid = F.from_dgl_nd(edge_array(2))\n\n        
return src, dst, eid\n\n    def edge_ids_one(self, etype, u, v):\n        \"\"\"Return an arrays of edge IDs.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        u : Tensor\n            The src nodes.\n        v : Tensor\n            The dst nodes.\n\n        Returns\n        -------\n        Tensor\n            The edge ids.\n        \"\"\"\n        eid = F.from_dgl_nd(\n            _CAPI_DGLHeteroEdgeIdsOne(\n                self, int(etype), F.to_dgl_nd(u), F.to_dgl_nd(v)\n            )\n        )\n        return eid\n\n    def find_edges(self, etype, eid):\n        \"\"\"Return a triplet of arrays that contains the edge IDs.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        eid : Tensor\n            Edge ids.\n\n        Returns\n        -------\n        Tensor\n            The src nodes.\n        Tensor\n            The dst nodes.\n        Tensor\n            The edge ids.\n        \"\"\"\n        edge_array = _CAPI_DGLHeteroFindEdges(\n            self, int(etype), F.to_dgl_nd(eid)\n        )\n\n        src = F.from_dgl_nd(edge_array(0))\n        dst = F.from_dgl_nd(edge_array(1))\n        eid = F.from_dgl_nd(edge_array(2))\n\n        return src, dst, eid\n\n    def in_edges(self, etype, v):\n        \"\"\"Return the in edges of the node(s).\n\n        Assume that node_type(v) == dst_type(etype). 
Thus, the ntype argument is omitted.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        v : Tensor\n            Node IDs.\n\n        Returns\n        -------\n        Tensor\n            The src nodes.\n        Tensor\n            The dst nodes.\n        Tensor\n            The edge ids.\n        \"\"\"\n        edge_array = _CAPI_DGLHeteroInEdges_2(self, int(etype), F.to_dgl_nd(v))\n        src = F.from_dgl_nd(edge_array(0))\n        dst = F.from_dgl_nd(edge_array(1))\n        eid = F.from_dgl_nd(edge_array(2))\n        return src, dst, eid\n\n    def out_edges(self, etype, v):\n        \"\"\"Return the out edges of the node(s).\n\n        Assume that node_type(v) == src_type(etype). Thus, the ntype argument is omitted.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        v : Tensor\n            Node IDs.\n\n        Returns\n        -------\n        Tensor\n            The src nodes.\n        Tensor\n            The dst nodes.\n        Tensor\n            The edge ids.\n        \"\"\"\n        edge_array = _CAPI_DGLHeteroOutEdges_2(self, int(etype), F.to_dgl_nd(v))\n        src = F.from_dgl_nd(edge_array(0))\n        dst = F.from_dgl_nd(edge_array(1))\n        eid = F.from_dgl_nd(edge_array(2))\n        return src, dst, eid\n\n    def edges(self, etype, order=None):\n        \"\"\"Return all the edges\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        order : string\n            The order of the returned edges. 
Currently support:\n\n            - 'srcdst' : sorted by their src and dst ids.\n            - 'eid'    : sorted by edge Ids.\n            - None     : the arbitrary order.\n\n        Returns\n        -------\n        Tensor\n            The src nodes.\n        Tensor\n            The dst nodes.\n        Tensor\n            The edge ids.\n        \"\"\"\n        if order is None:\n            order = \"\"\n        elif order not in [\"srcdst\", \"eid\"]:\n            raise DGLError(\n                \"Expect order to be one of None, 'srcdst', 'eid', \"\n                \"got {}\".format(order)\n            )\n        edge_array = _CAPI_DGLHeteroEdges(self, int(etype), order)\n        src = F.from_dgl_nd(edge_array(0))\n        dst = F.from_dgl_nd(edge_array(1))\n        eid = F.from_dgl_nd(edge_array(2))\n        return src, dst, eid\n\n    def in_degrees(self, etype, v):\n        \"\"\"Return the in degrees of the nodes.\n\n        Assume that node_type(v) == dst_type(etype). Thus, the ntype argument is omitted.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        v : Tensor\n            The nodes.\n\n        Returns\n        -------\n        Tensor\n            The in degree array.\n        \"\"\"\n        return F.from_dgl_nd(\n            _CAPI_DGLHeteroInDegrees(self, int(etype), F.to_dgl_nd(v))\n        )\n\n    def out_degrees(self, etype, v):\n        \"\"\"Return the out degrees of the nodes.\n\n        Assume that node_type(v) == src_type(etype). 
Thus, the ntype argument is omitted.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        v : Tensor\n            The nodes.\n\n        Returns\n        -------\n        Tensor\n            The out degree array.\n        \"\"\"\n        return F.from_dgl_nd(\n            _CAPI_DGLHeteroOutDegrees(self, int(etype), F.to_dgl_nd(v))\n        )\n\n    def adjacency_matrix(self, etype, transpose, ctx):\n        \"\"\"Return the adjacency matrix representation of this graph.\n\n        By default, a row of returned adjacency matrix represents the source\n        of an edge and the column represents the destination.\n\n        When transpose is True, a row represents the destination and a column represents\n        the source.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        transpose : bool\n            A flag to transpose the returned adjacency matrix.\n        ctx : context\n            The context of the returned matrix.\n\n        Returns\n        -------\n        SparseTensor\n            The adjacency matrix.\n        Tensor\n            A index for data shuffling due to sparse format change. 
Return None\n            if shuffle is not required.\n        \"\"\"\n        if not isinstance(transpose, bool):\n            raise DGLError(\n                'Expect bool value for \"transpose\" arg,'\n                \" but got %s.\" % (type(transpose))\n            )\n        fmt = F.get_preferred_sparse_format()\n        rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)\n        # convert to framework-specific sparse matrix\n        srctype, dsttype = self.metagraph.find_edge(etype)\n        nrows = (\n            self.num_nodes(dsttype) if transpose else self.num_nodes(srctype)\n        )\n        ncols = (\n            self.num_nodes(srctype) if transpose else self.num_nodes(dsttype)\n        )\n        nnz = self.num_edges(etype)\n        if fmt == \"csr\":\n            indptr = F.copy_to(F.from_dgl_nd(rst(0)), ctx)\n            indices = F.copy_to(F.from_dgl_nd(rst(1)), ctx)\n            shuffle = F.copy_to(F.from_dgl_nd(rst(2)), ctx)\n            dat = F.ones(\n                nnz, dtype=F.float32, ctx=ctx\n            )  # FIXME(minjie): data type\n            spmat = F.sparse_matrix(\n                dat, (\"csr\", indices, indptr), (nrows, ncols)\n            )[0]\n            return spmat, shuffle\n        elif fmt == \"coo\":\n            idx = F.copy_to(F.from_dgl_nd(rst(0)), ctx)\n            idx = F.reshape(idx, (2, nnz))\n            dat = F.ones((nnz,), dtype=F.float32, ctx=ctx)\n            adj, shuffle_idx = F.sparse_matrix(\n                dat, (\"coo\", idx), (nrows, ncols)\n            )\n            return adj, shuffle_idx\n        else:\n            raise Exception(\"unknown format\")\n\n    def adjacency_matrix_tensors(self, etype, transpose, fmt):\n        \"\"\"Return the adjacency matrix as a triplet of tensors.\n\n        By default, a row of returned adjacency matrix represents the source\n        of an edge and the column represents the destination.\n\n        When transpose is True, a row represents the destination 
and a column represents\n        the source.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        transpose : bool\n            A flag to transpose the returned adjacency matrix.\n        fmt : str\n            Indicates the format of returned adjacency matrix.\n\n        Returns\n        -------\n        tuple[int, int, Tensor, Tensor] or tuple[int, int, Tensor, Tensor, Tensor]\n            The number of rows and columns, followed by the adjacency matrix tensors\n            whose data type and device are the same as those of the graph.\n\n            If :attr:`fmt` is ``'coo'``, then the triplet will be\n            the row array and column array of the COO representation.\n\n            If :attr:`fmt` is ``'csr'``, then the triplet will be\n            the index pointer array (``indptr``), indices array, and data array\n            of the CSR representation.  The data array will contain the edge ID for\n            each entry of the adjacency matrix.  
If the data array is empty, then it is\n            equivalent to a consecutive array from zero to the number of edges minus one.\n        \"\"\"\n        if not isinstance(transpose, bool):\n            raise DGLError(\n                'Expect bool value for \"transpose\" arg,'\n                \" but got %s.\" % (type(transpose))\n            )\n\n        rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)\n        srctype, dsttype = self.metagraph.find_edge(etype)\n        nrows = (\n            self.num_nodes(dsttype) if transpose else self.num_nodes(srctype)\n        )\n        ncols = (\n            self.num_nodes(srctype) if transpose else self.num_nodes(dsttype)\n        )\n        nnz = self.num_edges(etype)\n        if fmt == \"csr\":\n            indptr = F.from_dgl_nd(rst(0))\n            indices = F.from_dgl_nd(rst(1))\n            data = F.from_dgl_nd(rst(2))\n            return nrows, ncols, indptr, indices, data\n        elif fmt == \"coo\":\n            idx = F.from_dgl_nd(rst(0))\n            row, col = F.reshape(idx, (2, nnz))\n            return nrows, ncols, row, col\n        else:\n            raise ValueError(\"unknown format\")\n\n    def adjacency_matrix_scipy(\n        self, etype, transpose, fmt, return_edge_ids=None\n    ):\n        \"\"\"Return the scipy adjacency matrix representation of this graph.\n\n        By default, a row of returned adjacency matrix represents the destination\n        of an edge and the column represents the source.\n\n        When transpose is True, a row represents the source and a column represents\n        a destination.\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        transpose : bool\n            A flag to transpose the returned adjacency matrix.\n        fmt : str\n            Indicates the format of returned adjacency matrix.\n        return_edge_ids : bool\n            Indicates whether to return edge IDs or 1 as elements.\n\n        Returns\n      
  -------\n        scipy.sparse.spmatrix\n            The scipy representation of adjacency matrix.\n        \"\"\"\n        if return_edge_ids is None:\n            dgl_warning(\n                \"Adjacency matrix by default currently returns edge IDs.\"\n                \"  As a result there is one 0 entry which is not eliminated.\"\n                \"  In the next release it will return 1s by default,\"\n                \" and 0 will be eliminated otherwise.\",\n                FutureWarning,\n            )\n            return_edge_ids = True\n\n        if fmt == \"csr\":\n            nrows, ncols, indptr, indices, data = self.adjacency_matrix_tensors(\n                etype, transpose, fmt\n            )\n            indptr = F.asnumpy(indptr)\n            indices = F.asnumpy(indices)\n            data = F.asnumpy(data)\n\n            # Check if edge ID is omitted\n            if return_edge_ids and data.shape[0] == 0:\n                data = np.arange(self.num_edges(etype))\n            else:\n                data = np.ones_like(indices)\n\n            return scipy.sparse.csr_matrix(\n                (data, indices, indptr), shape=(nrows, ncols)\n            )\n        elif fmt == \"coo\":\n            nrows, ncols, row, col = self.adjacency_matrix_tensors(\n                etype, transpose, fmt\n            )\n            row = F.asnumpy(row)\n            col = F.asnumpy(col)\n            data = (\n                np.arange(self.num_edges(etype))\n                if return_edge_ids\n                else np.ones_like(row)\n            )\n            return scipy.sparse.coo_matrix(\n                (data, (row, col)), shape=(nrows, ncols)\n            )\n        else:\n            raise ValueError(\"unknown format\")\n\n    def incidence_matrix(self, etype, typestr, ctx):\n        \"\"\"Return the incidence matrix representation of this graph.\n\n        An incidence matrix is an n x m sparse matrix, where n is\n        the number of nodes and m is the number 
of edges. Each nnz\n        value indicates whether the edge is incident to the node\n        or not.\n\n        There are three types of an incidence matrix `I`:\n        * \"in\":\n          - I[v, e] = 1 if e is the in-edge of v (or v is the dst node of e);\n          - I[v, e] = 0 otherwise.\n        * \"out\":\n          - I[v, e] = 1 if e is the out-edge of v (or v is the src node of e);\n          - I[v, e] = 0 otherwise.\n        * \"both\":\n          - I[v, e] = 1 if e is the in-edge of v;\n          - I[v, e] = -1 if e is the out-edge of v;\n          - I[v, e] = 0 otherwise (including self-loop).\n\n        Parameters\n        ----------\n        etype : int\n            Edge type\n        typestr : str\n            Can be either \"in\", \"out\" or \"both\"\n        ctx : context\n            The context of returned incidence matrix.\n\n        Returns\n        -------\n        SparseTensor\n            The incidence matrix.\n        utils.Index\n            An index for data shuffling due to sparse format change. 
Return None\n            if shuffle is not required.\n        \"\"\"\n        src, dst, eid = self.edges(etype)\n        srctype, dsttype = self.metagraph.find_edge(etype)\n\n        m = self.num_edges(etype)\n        if typestr == \"in\":\n            n = self.num_nodes(dsttype)\n            row = F.unsqueeze(dst, 0)\n            col = F.unsqueeze(eid, 0)\n            idx = F.copy_to(F.cat([row, col], dim=0), ctx)\n            # FIXME(minjie): data type\n            dat = F.ones((m,), dtype=F.float32, ctx=ctx)\n            inc, shuffle_idx = F.sparse_matrix(dat, (\"coo\", idx), (n, m))\n        elif typestr == \"out\":\n            n = self.num_nodes(srctype)\n            row = F.unsqueeze(src, 0)\n            col = F.unsqueeze(eid, 0)\n            idx = F.copy_to(F.cat([row, col], dim=0), ctx)\n            # FIXME(minjie): data type\n            dat = F.ones((m,), dtype=F.float32, ctx=ctx)\n            inc, shuffle_idx = F.sparse_matrix(dat, (\"coo\", idx), (n, m))\n        elif typestr == \"both\":\n            assert (\n                srctype == dsttype\n            ), \"'both' is supported only if source and destination type are the same\"\n            n = self.num_nodes(srctype)\n            # first remove entries for self loops\n            mask = F.logical_not(F.equal(src, dst))\n            src = F.boolean_mask(src, mask)\n            dst = F.boolean_mask(dst, mask)\n            eid = F.boolean_mask(eid, mask)\n            n_entries = F.shape(src)[0]\n            # create index\n            row = F.unsqueeze(F.cat([src, dst], dim=0), 0)\n            col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)\n            idx = F.copy_to(F.cat([row, col], dim=0), ctx)\n            # FIXME(minjie): data type\n            x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)\n            y = F.ones((n_entries,), dtype=F.float32, ctx=ctx)\n            dat = F.cat([x, y], dim=0)\n            inc, shuffle_idx = F.sparse_matrix(dat, (\"coo\", idx), (n, m))\n        else:\n     
       raise DGLError(\"Invalid incidence matrix type: %s\" % str(typestr))\n        return inc, shuffle_idx\n\n    def node_subgraph(self, induced_nodes):\n        \"\"\"Return the induced node subgraph.\n\n        Parameters\n        ----------\n        induced_nodes : list of utils.Index\n            Induced nodes. The length should be equal to the number of\n            node types in this heterograph.\n\n        Returns\n        -------\n        SubgraphIndex\n            The subgraph index.\n        \"\"\"\n        vids = [F.to_dgl_nd(nodes) for nodes in induced_nodes]\n        return _CAPI_DGLHeteroVertexSubgraph(self, vids)\n\n    def edge_subgraph(self, induced_edges, preserve_nodes):\n        \"\"\"Return the induced edge subgraph.\n\n        Parameters\n        ----------\n        induced_edges : list of utils.Index\n            Induced edges. The length should be equal to the number of\n            edge types in this heterograph.\n        preserve_nodes : bool\n            Indicates whether to preserve all nodes or not.\n            If true, keep the nodes which have no edge connected in the subgraph;\n            If false, all nodes without edge connected to it would be removed.\n\n        Returns\n        -------\n        SubgraphIndex\n            The subgraph index.\n        \"\"\"\n        eids = [F.to_dgl_nd(edges) for edges in induced_edges]\n        return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)\n\n    def get_unitgraph(self, etype, ctx):\n        \"\"\"Create a unitgraph graph from given edge type and copy to the given device\n        context.\n\n        Note: this internal function is for DGL scheduler use only\n\n        Parameters\n        ----------\n        etype : int\n            If the graph index is a Bipartite graph index, this argument must be None.\n            Otherwise, it represents the edge type.\n        ctx : DGLContext\n            The context of the returned graph.\n\n        Returns\n        -------\n        
HeteroGraphIndex\n        \"\"\"\n        g = self.get_relation_graph(etype)\n        return g.copy_to(ctx).asbits(self.bits_needed(etype or 0))\n\n    def get_csr_shuffle_order(self, etype):\n        \"\"\"Return the edge shuffling order when a coo graph is converted to csr format\n\n        Parameters\n        ----------\n        etype : int\n            The edge type\n\n        Returns\n        -------\n        tuple of two utils.Index\n            The first element of the tuple is the shuffle order for outward graph\n            The second element of the tuple is the shuffle order for inward graph\n        \"\"\"\n        csr = _CAPI_DGLHeteroGetAdj(self, int(etype), False, \"csr\")\n        order = csr(2)\n        rev_csr = _CAPI_DGLHeteroGetAdj(self, int(etype), True, \"csr\")\n        rev_order = rev_csr(2)\n        return utils.toindex(order, self.dtype), utils.toindex(\n            rev_order, self.dtype\n        )\n\n    def formats(self, formats=None):\n        \"\"\"Get a graph index with the specified allowed sparse format(s) or\n        query for the usage status of sparse formats.\n\n        If the graph has multiple edge types, they will have the same\n        sparse format.\n\n        When ``formats`` is not None, if the intersection between `formats` and\n        the current graph's created sparse format(s) is not empty, the returned\n        cloned graph only retains all sparse format(s) in the intersection. 
If\n        the intersection is empty, a sparse format will be selected to be\n        created following the order of ``'coo' -> 'csr' -> 'csc'``.\n\n        Parameters\n        ----------\n        formats : str or list of str or None\n\n            * If formats is None, return the usage status of sparse formats\n            * Otherwise, it can be ``'coo'``/``'csr'``/``'csc'`` or a sublist of\n            them, specifying the sparse formats to use.\n\n        Returns\n        -------\n        dict or GraphIndex\n\n            * If formats is None, the result will be a dict recording the usage\n              status of sparse formats.\n            * Otherwise, a GraphIndex will be returned, which is a clone of the\n              original graph with the specified allowed sparse format(s)\n              ``formats``.\n\n        \"\"\"\n        formats_allowed = _CAPI_DGLHeteroGetAllowedFormats(self)\n        formats_created = _CAPI_DGLHeteroGetCreatedFormats(self)\n        created = []\n        not_created = []\n        if formats is None:\n            for fmt in [\"coo\", \"csr\", \"csc\"]:\n                if fmt in formats_allowed:\n                    if fmt in formats_created:\n                        created.append(fmt)\n                    else:\n                        not_created.append(fmt)\n            return {\"created\": created, \"not created\": not_created}\n        else:\n            if isinstance(formats, str):\n                formats = [formats]\n            return _CAPI_DGLHeteroGetFormatGraph(self, formats)\n\n    def create_formats_(self):\n        \"\"\"Create all sparse matrices allowed for the graph.\"\"\"\n        return _CAPI_DGLHeteroCreateFormat(self)\n\n    def reverse(self):\n        \"\"\"Reverse the heterogeneous graph adjacency\n\n        The node types and edge types are not changed.\n\n        Returns\n        -------\n        A new graph index.\n        \"\"\"\n        return 
_CAPI_DGLHeteroReverse(self)\n\n\n@register_object(\"graph.HeteroSubgraph\")\nclass HeteroSubgraphIndex(ObjectBase):\n    \"\"\"Hetero-subgraph data structure\"\"\"\n\n    @property\n    def graph(self):\n        \"\"\"The subgraph structure\n\n        Returns\n        -------\n        HeteroGraphIndex\n            The subgraph\n        \"\"\"\n        return _CAPI_DGLHeteroSubgraphGetGraph(self)\n\n    @property\n    def induced_nodes(self):\n        \"\"\"Induced nodes for each node type. The return list\n        length should be equal to the number of node types.\n\n        Returns\n        -------\n        list of utils.Index\n            Induced nodes\n        \"\"\"\n        ret = _CAPI_DGLHeteroSubgraphGetInducedVertices(self)\n        return [F.from_dgl_nd(v) for v in ret]\n\n    @property\n    def induced_edges(self):\n        \"\"\"Induced edges for each edge type. The return list\n        length should be equal to the number of edge types.\n\n        Returns\n        -------\n        list of utils.Index\n            Induced edges\n        \"\"\"\n        ret = _CAPI_DGLHeteroSubgraphGetInducedEdges(self)\n        return [F.from_dgl_nd(v) for v in ret]\n\n\n#################################################################\n# Creators\n#################################################################\n\n\ndef create_metagraph_index(ntypes, canonical_etypes):\n    \"\"\"Return a GraphIndex instance for a metagraph given the node types and canonical\n    edge types.\n\n    This function will reorder the node types and canonical edge types.\n\n    Parameters\n    ----------\n    ntypes : Iterable[str]\n        The node types.\n    canonical_etypes : Iterable[tuple[str, str, str]]\n        The canonical edge types.\n\n    Returns\n    -------\n    GraphIndex\n        The index object for metagraph.\n    list[str]\n        The reordered node types for each node in the metagraph.\n    list[str]\n        The reordered edge types for each edge in the metagraph.\n  
  list[tuple[str, str, str]]\n        The reordered canonical edge types for each edge in the metagraph.\n    \"\"\"\n    # Sort the ntypes and relation tuples to have a deterministic order for the same set\n    # of type names.\n    ntypes = list(sorted(ntypes))\n    relations = list(sorted(canonical_etypes))\n    ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)}\n    meta_edges_src = []\n    meta_edges_dst = []\n    etypes = []\n    for srctype, etype, dsttype in relations:\n        meta_edges_src.append(ntype_dict[srctype])\n        meta_edges_dst.append(ntype_dict[dsttype])\n        etypes.append(etype)\n    # metagraph is DGLGraph, currently still using int64 as index dtype\n    metagraph = from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True)\n    return metagraph, ntypes, etypes, relations\n\n\ndef create_unitgraph_from_coo(\n    num_ntypes,\n    num_src,\n    num_dst,\n    row,\n    col,\n    formats,\n    row_sorted=False,\n    col_sorted=False,\n):\n    \"\"\"Create a unitgraph graph index from COO format\n\n    Parameters\n    ----------\n    num_ntypes : int\n        Number of node types (must be 1 or 2).\n    num_src : int\n        Number of nodes in the src type.\n    num_dst : int\n        Number of nodes in the dst type.\n    row : utils.Index\n        Row index.\n    col : utils.Index\n        Col index.\n    formats : list of str.\n        Restrict the storage formats allowed for the unit graph.\n    row_sorted : bool, optional\n        Whether or not the rows of the COO are in ascending order.\n    col_sorted : bool, optional\n        Whether or not the columns of the COO are in ascending order within\n        each row. 
This only has an effect when ``row_sorted`` is True.\n\n    Returns\n    -------\n    HeteroGraphIndex\n    \"\"\"\n    if isinstance(formats, str):\n        formats = [formats]\n    return _CAPI_DGLHeteroCreateUnitGraphFromCOO(\n        int(num_ntypes),\n        int(num_src),\n        int(num_dst),\n        F.to_dgl_nd(row),\n        F.to_dgl_nd(col),\n        formats,\n        row_sorted,\n        col_sorted,\n    )\n\n\ndef create_unitgraph_from_csr(\n    num_ntypes,\n    num_src,\n    num_dst,\n    indptr,\n    indices,\n    edge_ids,\n    formats,\n    transpose=False,\n):\n    \"\"\"Create a unitgraph graph index from CSR format\n\n    Parameters\n    ----------\n    num_ntypes : int\n        Number of node types (must be 1 or 2).\n    num_src : int\n        Number of nodes in the src type.\n    num_dst : int\n        Number of nodes in the dst type.\n    indptr : utils.Index\n        CSR indptr.\n    indices : utils.Index\n        CSR indices.\n    edge_ids : utils.Index\n        Edge shuffle id.\n    formats : str\n        Restrict the storage formats allowed for the unit graph.\n    transpose : bool, optional\n        If True, treats the input matrix as CSC.\n\n    Returns\n    -------\n    HeteroGraphIndex\n    \"\"\"\n    if isinstance(formats, str):\n        formats = [formats]\n    return _CAPI_DGLHeteroCreateUnitGraphFromCSR(\n        int(num_ntypes),\n        int(num_src),\n        int(num_dst),\n        F.to_dgl_nd(indptr),\n        F.to_dgl_nd(indices),\n        F.to_dgl_nd(edge_ids),\n        formats,\n        transpose,\n    )\n\n\ndef create_heterograph_from_relations(\n    metagraph, rel_graphs, num_nodes_per_type\n):\n    \"\"\"Create a heterograph from metagraph and graphs of every relation.\n\n    Parameters\n    ----------\n    metagraph : GraphIndex\n        Meta-graph.\n    rel_graphs : list of HeteroGraphIndex\n        Bipartite graph of each relation.\n    num_nodes_per_type : utils.Index, optional\n        Number of nodes per node 
type\n\n    Returns\n    -------\n    HeteroGraphIndex\n    \"\"\"\n    if num_nodes_per_type is None:\n        return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)\n    else:\n        return _CAPI_DGLHeteroCreateHeteroGraphWithNumNodes(\n            metagraph, rel_graphs, num_nodes_per_type.todgltensor()\n        )\n\n\ndef create_heterograph_from_shared_memory(name):\n    \"\"\"Create a heterograph from shared memory with the given name.\n\n    Parameters\n    ----------\n    name : str\n        The name of the shared memory\n\n    Returns\n    -------\n    HeteroGraphIndex (in shared memory)\n    ntypes : list of str\n        Names of node types\n    etypes : list of str\n        Names of edge types\n    \"\"\"\n    g, ntypes, etypes = _CAPI_DGLHeteroCreateFromSharedMem(name)\n    return g, list(ntypes), list(etypes)\n\n\ndef joint_union(metagraph, gidx_list):\n    \"\"\"Return a joint union of the input heterographs.\n\n    Parameters\n    ----------\n    metagraph : GraphIndex\n        Meta-graph.\n    gidx_list : list of HeteroGraphIndex\n        Heterographs to be joint_unioned.\n\n    Returns\n    -------\n    HeteroGraphIndex\n        joint_unioned Heterograph.\n    \"\"\"\n    return _CAPI_DGLHeteroJointUnion(metagraph, gidx_list)\n\n\ndef disjoint_union(metagraph, graphs):\n    \"\"\"Return a disjoint union of the input heterographs.\n\n    Parameters\n    ----------\n    metagraph : GraphIndex\n        Meta-graph.\n    graphs : list of HeteroGraphIndex\n        Heterographs to be batched.\n\n    Returns\n    -------\n    HeteroGraphIndex\n        Batched Heterograph.\n    \"\"\"\n    return _CAPI_DGLHeteroDisjointUnion_v2(metagraph, graphs)\n\n\ndef disjoint_partition(graph, bnn_all_types, bne_all_types):\n    \"\"\"Partition the graph disjointly.\n\n    Parameters\n    ----------\n    graph : HeteroGraphIndex\n        The graph to be partitioned.\n    bnn_all_types : list of list of int\n        bnn_all_types[t] gives the number of nodes with 
t-th type in the batch.\n    bne_all_types : list of list of int\n        bne_all_types[t] gives the number of edges with t-th type in the batch.\n\n    Returns\n    --------\n    list of HeteroGraphIndex\n        Heterographs unbatched.\n    \"\"\"\n    bnn_all_types = utils.toindex(\n        list(itertools.chain.from_iterable(bnn_all_types))\n    )\n    bne_all_types = utils.toindex(\n        list(itertools.chain.from_iterable(bne_all_types))\n    )\n    return _CAPI_DGLHeteroDisjointPartitionBySizes_v2(\n        graph, bnn_all_types.todgltensor(), bne_all_types.todgltensor()\n    )\n\n\ndef slice_gidx(graph, num_nodes, start_nid, num_edges, start_eid):\n    \"\"\"Slice a chunk of the graph.\n\n    Parameters\n    ----------\n    graph : HeteroGraphIndex\n        The batched graph to slice.\n    num_nodes : utils.Index\n        Number of nodes per node type in the result graph.\n    start_nid : utils.Index\n        Start node ID per node type in the result graph.\n    num_edges : utils.Index\n        Number of edges per edge type in the result graph.\n    start_eid : utils.Index\n        Start edge ID per edge type in the result graph.\n\n    Returns\n    -------\n    HeteroGraphIndex\n        The sliced graph.\n    \"\"\"\n    return _CAPI_DGLHeteroSlice(\n        graph,\n        num_nodes.todgltensor(),\n        start_nid.todgltensor(),\n        num_edges.todgltensor(),\n        start_eid.todgltensor(),\n    )\n\n\n#################################################################\n# Data structure used by C APIs\n#################################################################\n\n\n@register_object(\"graph.FlattenedHeteroGraph\")\nclass FlattenedHeteroGraph(ObjectBase):\n    \"\"\"FlattenedHeteroGraph object class in C++ backend.\"\"\"\n\n\n@register_object(\"graph.HeteroPickleStates\")\nclass HeteroPickleStates(ObjectBase):\n    \"\"\"Pickle states object class in C++ backend.\"\"\"\n\n    @property\n    def version(self):\n        \"\"\"Version number\n\n    
    Returns\n        -------\n        int\n            version number\n        \"\"\"\n        return _CAPI_DGLHeteroPickleStatesGetVersion(self)\n\n    @property\n    def meta(self):\n        \"\"\"Meta info\n\n        Returns\n        -------\n        bytearray\n            Serialized meta info\n        \"\"\"\n        return bytearray(_CAPI_DGLHeteroPickleStatesGetMeta(self))\n\n    @property\n    def arrays(self):\n        \"\"\"Arrays representing the graph structure (COO or CSR)\n\n        Returns\n        -------\n        list of dgl.ndarray.NDArray\n            Arrays\n        \"\"\"\n        num_arr = _CAPI_DGLHeteroPickleStatesGetArraysNum(self)\n        arr_func = _CAPI_DGLHeteroPickleStatesGetArrays(self)\n        return [arr_func(i) for i in range(num_arr)]\n\n    def __getstate__(self):\n        \"\"\"Issue: https://github.com/pytorch/pytorch/issues/32351\n        Need to set the tensor created in the __getstate__ function\n         as object attribute to avoid potential bugs\n        \"\"\"\n        self._pk_arrays = [\n            F.zerocopy_from_dgl_ndarray(arr) for arr in self.arrays\n        ]\n        return self.version, self.meta, self._pk_arrays\n\n    def __setstate__(self, state):\n        if isinstance(state[0], int):\n            version, meta, arrays = state\n            arrays = [F.zerocopy_to_dgl_ndarray(arr) for arr in arrays]\n            self.__init_handle_by_constructor__(\n                _CAPI_DGLCreateHeteroPickleStates, version, meta, arrays\n            )\n        else:\n            metagraph, num_nodes_per_type, adjs = state\n            num_nodes_per_type = F.zerocopy_to_dgl_ndarray(num_nodes_per_type)\n            self.__init_handle_by_constructor__(\n                _CAPI_DGLCreateHeteroPickleStatesOld,\n                metagraph,\n                num_nodes_per_type,\n                adjs,\n            )\n\n\ndef _forking_rebuild(pk_state):\n    version, meta, arrays = pk_state\n    arrays = [F.to_dgl_nd(arr) for arr in 
arrays]\n    states = _CAPI_DGLCreateHeteroPickleStates(version, meta, arrays)\n    graph_index = _CAPI_DGLHeteroForkingUnpickle(states)\n    graph_index._forking_pk_state = pk_state\n    return graph_index\n\n\ndef _forking_reduce(graph_index):\n    # Because F.from_dgl_nd(F.to_dgl_nd(x)) loses the information of shared memory\n    # file descriptor (because DLPack does not keep it), without caching the tensors\n    # PyTorch will allocate one shared memory region for every single worker.\n    # The downside is that if a graph_index is shared by forking and new formats are created\n    # afterwards, then sharing it again will not bring together the new formats.  This case\n    # should be rare though because (1) DataLoader will create all the formats if num_workers > 0\n    # anyway, and (2) we require the users to explicitly create all formats before calling\n    # mp.spawn().\n    if hasattr(graph_index, \"_forking_pk_state\"):\n        return _forking_rebuild, (graph_index._forking_pk_state,)\n    states = _CAPI_DGLHeteroForkingPickle(graph_index)\n    arrays = [F.from_dgl_nd(arr) for arr in states.arrays]\n    # Similar to what being mentioned in HeteroGraphIndex.__getstate__, we need to save\n    # the tensors as an attribute of the original graph index object.  Otherwise\n    # PyTorch will throw weird errors like bad value(s) in fds_to_keep or unable to\n    # resize file.\n    graph_index._forking_pk_state = (states.version, states.meta, arrays)\n    return _forking_rebuild, (graph_index._forking_pk_state,)\n\n\nif not (F.get_preferred_backend() == \"mxnet\" and sys.version_info.minor <= 6):\n    # Python 3.6 MXNet crashes with the following statement; remove until we no longer support\n    # 3.6 (which is EOL anyway).\n    from multiprocessing.reduction import ForkingPickler\n\n    ForkingPickler.register(HeteroGraphIndex, _forking_reduce)\n\n_init_api(\"dgl.heterograph_index\")\n"
  },
  {
    "path": "python/dgl/homophily.py",
    "content": "\"\"\"Utils for tracking graph homophily and heterophily\"\"\"\n# pylint: disable=W0611\nfrom . import function as fn, to_bidirected\n\ntry:\n    import torch\nexcept ImportError:\n    HAS_TORCH = False\nelse:\n    HAS_TORCH = True\n\n__all__ = [\n    \"node_homophily\",\n    \"edge_homophily\",\n    \"linkx_homophily\",\n    \"adjusted_homophily\",\n]\n\n\ndef check_pytorch():\n    \"\"\"Check if PyTorch is the backend.\"\"\"\n    if HAS_TORCH is False:\n        raise ModuleNotFoundError(\n            \"This function requires PyTorch to be the backend.\"\n        )\n\n\ndef get_long_edges(graph):\n    \"\"\"Internal function for getting the edges of a graph as long tensors.\"\"\"\n    src, dst = graph.edges()\n    return src.long(), dst.long()\n\n\ndef node_homophily(graph, y):\n    r\"\"\"Homophily measure from `Geom-GCN: Geometric Graph Convolutional\n    Networks <https://arxiv.org/abs/2002.05287>`__\n\n    We follow the practice of a later paper `Large Scale Learning on\n    Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods\n    <https://arxiv.org/abs/2110.14446>`__ to call it node homophily.\n\n    Mathematically it is defined as follows:\n\n    .. 
math::\n      \\frac{1}{|\\mathcal{V}|} \\sum_{v \\in \\mathcal{V}} \\frac{ | \\{u\n      \\in \\mathcal{N}(v): y_v = y_u \\} |  } { |\\mathcal{N}(v)| },\n\n    where :math:`\\mathcal{V}` is the set of nodes, :math:`\\mathcal{N}(v)` is\n    the predecessors of node :math:`v`, and :math:`y_v` is the class of node\n    :math:`v`.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph.\n    y : torch.Tensor\n        The node labels, which is a tensor of shape (|V|).\n\n    Returns\n    -------\n    float\n        The node homophily value.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n\n    >>> graph = dgl.graph(([1, 2, 0, 4], [0, 1, 2, 3]))\n    >>> y = torch.tensor([0, 0, 0, 0, 1])\n    >>> dgl.node_homophily(graph, y)\n    0.6000000238418579\n    \"\"\"\n    check_pytorch()\n    with graph.local_scope():\n        # Handle the case where graph is of dtype int32.\n        src, dst = get_long_edges(graph)\n        # Compute y_v = y_u for all edges.\n        graph.edata[\"same_class\"] = (y[src] == y[dst]).float()\n        graph.update_all(\n            fn.copy_e(\"same_class\", \"m\"), fn.mean(\"m\", \"same_class_deg\")\n        )\n        return graph.ndata[\"same_class_deg\"].mean(dim=0).item()\n\n\ndef edge_homophily(graph, y):\n    r\"\"\"Homophily measure from `Beyond Homophily in Graph Neural Networks:\n    Current Limitations and Effective Designs\n    <https://arxiv.org/abs/2006.11468>`__\n\n    Mathematically it is defined as follows:\n\n    .. 
math::\n      \\frac{| \\{ (u,v) : (u,v) \\in \\mathcal{E} \\wedge y_u = y_v \\} | }\n      {|\\mathcal{E}|},\n\n    where :math:`\\mathcal{E}` is the set of edges, and :math:`y_u` is the class\n    of node :math:`u`.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph.\n    y : torch.Tensor\n        The node labels, which is a tensor of shape (|V|).\n\n    Returns\n    -------\n    float\n        The edge homophily ratio value.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n\n    >>> graph = dgl.graph(([1, 2, 0, 4], [0, 1, 2, 3]))\n    >>> y = torch.tensor([0, 0, 0, 0, 1])\n    >>> dgl.edge_homophily(graph, y)\n    0.75\n    \"\"\"\n    check_pytorch()\n    with graph.local_scope():\n        # Handle the case where graph is of dtype int32.\n        src, dst = get_long_edges(graph)\n        # Compute y_v = y_u for all edges.\n        edge_indicator = (y[src] == y[dst]).float()\n        return edge_indicator.mean(dim=0).item()\n\n\ndef linkx_homophily(graph, y):\n    r\"\"\"Homophily measure from `Large Scale Learning on Non-Homophilous Graphs:\n    New Benchmarks and Strong Simple Methods\n    <https://arxiv.org/abs/2110.14446>`__\n\n    Mathematically it is defined as follows:\n\n    .. 
math::\n      \\frac{1}{C-1} \\sum_{k=1}^{C} \\max \\left(0, \\frac{\\sum_{v\\in C_k}|\\{u\\in\n      \\mathcal{N}(v): y_v = y_u \\}|}{\\sum_{v\\in C_k}|\\mathcal{N}(v)|} -\n      \\frac{|\\mathcal{C}_k|}{|\\mathcal{V}|} \\right),\n\n    where :math:`C` is the number of node classes, :math:`C_k` is the set of\n    nodes that belong to class k, :math:`\\mathcal{N}(v)` are the predecessors\n    of node :math:`v`, :math:`y_v` is the class of node :math:`v`, and\n    :math:`\\mathcal{V}` is the set of nodes.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph.\n    y : torch.Tensor\n        The node labels, which is a tensor of shape (|V|).\n\n    Returns\n    -------\n    float\n        The homophily value.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n\n    >>> graph = dgl.graph(([0, 1, 2, 3], [1, 2, 0, 4]))\n    >>> y = torch.tensor([0, 0, 0, 0, 1])\n    >>> dgl.linkx_homophily(graph, y)\n    0.19999998807907104\n    \"\"\"\n    check_pytorch()\n    with graph.local_scope():\n        # Compute |{u\\in N(v): y_v = y_u}| for each node v.\n        # Handle the case where graph is of dtype int32.\n        src, dst = get_long_edges(graph)\n        # Compute y_v = y_u for all edges.\n        graph.edata[\"same_class\"] = (y[src] == y[dst]).float()\n        graph.update_all(\n            fn.copy_e(\"same_class\", \"m\"), fn.sum(\"m\", \"same_class_deg\")\n        )\n\n        deg = graph.in_degrees().float()\n        num_nodes = graph.num_nodes()\n        num_classes = y.max(dim=0).values.item() + 1\n\n        value = torch.tensor(0.0).to(graph.device)\n        for k in range(num_classes):\n            # Get the nodes that belong to class k.\n            class_mask = y == k\n            same_class_deg_k = graph.ndata[\"same_class_deg\"][class_mask].sum()\n            deg_k = deg[class_mask].sum()\n            num_nodes_k = class_mask.sum()\n            value += max(0, same_class_deg_k / deg_k - num_nodes_k / num_nodes)\n\n  
      return value.item() / (num_classes - 1)\n\n\ndef adjusted_homophily(graph, y):\n    r\"\"\"Homophily measure recommended in `Characterizing Graph Datasets for\n    Node Classification: Homophily-Heterophily Dichotomy and Beyond\n    <https://arxiv.org/abs/2209.06177>`__\n\n    Adjusted homophily is edge homophily adjusted for the expected number of\n    edges connecting nodes with the same class label (taking into account the\n    number of classes, their sizes, and the distribution of node degrees among\n    them).\n\n    Mathematically it is defined as follows:\n\n    .. math::\n        \\frac{h_{edge} - \\sum_{k=1}^C \\bar{p}(k)^2}\n        {1 - \\sum_{k=1}^C \\bar{p}(k)^2},\n\n    where :math:`h_{edge}` denotes edge homophily, :math:`C` denotes the\n    number of classes, and :math:`\\bar{p}(\\cdot)` is the empirical\n    degree-weighted distribution of classes:\n    :math:`\\bar{p}(k) = \\frac{\\sum_{v\\,:\\,y_v = k} d(v)}{2|E|}`,\n    where :math:`d(v)` is the degree of node :math:`v`.\n\n    It has been shown that adjusted homophily satisfies more desirable\n    properties than other homophily measures, which makes it appropriate for\n    comparing the levels of homophily across datasets with different number\n    of classes, different class sizes, and different degree distributions\n    among classes.\n\n    Adjusted homophily can be negative. If adjusted homophily is zero, then\n    the edge pattern in the graph is independent of node class labels. 
If it\n    is positive, then the nodes in the graph tend to connect to nodes of the\n    same class more often, and if it is negative, then the nodes in the graph\n    tend to connect to nodes of different classes more often (compared to the\n    null model where edges are independent of node class labels).\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph.\n    y : torch.Tensor\n        The node labels, which is a tensor of shape (|V|).\n\n    Returns\n    -------\n    float\n        The adjusted homophily value.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n\n    >>> graph = dgl.graph(([1, 2, 0, 4], [0, 1, 2, 3]))\n    >>> y = torch.tensor([0, 0, 0, 0, 1])\n    >>> dgl.adjusted_homophily(graph, y)\n    -0.1428571492433548\n    \"\"\"\n    check_pytorch()\n\n    graph = to_bidirected(graph.cpu()).to(y.device)\n\n    h_edge = edge_homophily(graph, y)\n\n    degrees = graph.in_degrees().float()\n    num_classes = y.max().item() + 1\n    degree_sums = torch.zeros(num_classes).to(y.device)\n    degree_sums.index_add_(dim=0, index=y, source=degrees)\n\n    adjust = (degree_sums**2).sum() / graph.num_edges() ** 2\n\n    h_adj = (h_edge - adjust) / (1 - adjust)\n\n    return h_adj.item()\n"
  },
  {
    "path": "python/dgl/init.py",
    "content": "\"\"\"Module for common feature initializers.\"\"\"\nfrom __future__ import absolute_import\n\nfrom . import backend as F\n\n__all__ = [\"base_initializer\", \"zero_initializer\"]\n\n\ndef base_initializer(\n    shape, dtype, ctx, id_range\n):  # pylint: disable=unused-argument\n    \"\"\"The function signature for feature initializer.\n\n    Any customized feature initializer should follow this signature (see\n    example below).\n\n    Parameters\n    ----------\n    shape : tuple of int\n        The shape of the result features. The first dimension\n        is the batch dimension.\n    dtype : data type object\n        The data type of the returned features.\n    ctx : context object\n        The device context of the returned features.\n    id_range : slice\n        The start id and the end id of the features to be initialized.\n        The id could be node or edge id depending on the scenario.\n        Note that the step is always None.\n\n    Examples\n    --------\n    If PyTorch is used as backend, the following code defines a feature\n    initializer that initializes tensor value to 1\n\n    >>> import torch\n    >>> import dgl\n    >>> def initializer(shape, dtype, ctx, id_range):\n    >>>     return torch.ones(shape, dtype=dtype, device=ctx)\n    >>> g = dgl.DGLGraph()\n    >>> g.set_n_initializer(initializer)\n\n    See Also\n    --------\n    dgl.DGLGraph.set_n_initializer\n    dgl.DGLGraph.set_e_initializer\n    \"\"\"\n    raise NotImplementedError\n\n\ndef zero_initializer(\n    shape, dtype, ctx, id_range\n):  # pylint: disable=unused-argument\n    \"\"\"Zero feature initializer\n\n    Examples\n    --------\n    >>> import dgl\n    >>> g = dgl.DGLGraph()\n    >>> g.set_n_initializer(dgl.init.zero_initializer)\n\n    See Also\n    --------\n    dgl.DGLGraph.set_n_initializer\n    dgl.DGLGraph.set_e_initializer\n    \"\"\"\n    return F.zeros(shape, dtype, ctx)\n"
  },
  {
    "path": "python/dgl/label_informativeness.py",
    "content": "\"\"\"Utils for computing graph label informativeness\"\"\"\nfrom . import to_bidirected\n\ntry:\n    import torch\nexcept ImportError:\n    HAS_TORCH = False\nelse:\n    HAS_TORCH = True\n\n__all__ = [\"edge_label_informativeness\", \"node_label_informativeness\"]\n\n\ndef check_pytorch():\n    \"\"\"Check if PyTorch is the backend.\"\"\"\n    if HAS_TORCH is False:\n        raise ModuleNotFoundError(\n            \"This function requires PyTorch to be the backend.\"\n        )\n\n\ndef edge_label_informativeness(graph, y, eps=1e-8):\n    r\"\"\"Label informativeness (:math:`\\mathrm{LI}`) is a characteristic of\n    labeled graphs proposed in the `Characterizing Graph Datasets for Node\n    Classification: Homophily-Heterophily Dichotomy and Beyond\n    <https://arxiv.org/abs/2209.06177>`__\n\n    Label informativeness shows how much information about a node's label we\n    get from knowing its neighbor's label. Formally, assume that we sample an\n    edge :math:`(\\xi,\\eta) \\in E`. The class labels of nodes :math:`\\xi` and\n    :math:`\\eta` are then random variables :math:`y_\\xi` and :math:`y_\\eta`.\n    We want to measure the amount of knowledge the label :math:`y_\\eta` gives\n    for predicting :math:`y_\\xi`. The entropy :math:`H(y_\\xi)` measures the\n    `hardness' of predicting the label of :math:`\\xi` without knowing\n    :math:`y_\\eta`. Given :math:`y_\\eta`, this value is reduced to the\n    conditional entropy :math:`H(y_\\xi|y_\\eta)`. In other words, :math:`y_\\eta`\n    reveals :math:`I(y_\\xi,y_\\eta) = H(y_\\xi) - H(y_\\xi|y_\\eta)` information\n    about the label. To make the obtained quantity comparable across different\n    datasets, label informativeness is defined as the normalized mutual\n    information of :math:`y_{\\xi}` and :math:`y_{\\eta}`:\n\n    .. 
math::\n      \\mathrm{LI} = \\frac{I(y_\\xi,y_\\eta)}{H(y_\\xi)}\n\n    Depending on the distribution used for sampling an edge\n    :math:`(\\xi, \\eta)`, several variants of label informativeness can be\n    obtained. Two of them are particularly intuitive: in edge label\n    informativeness (:math:`\\mathrm{LI}_{edge}`), edges are sampled uniformly\n    at random, and in node label informativeness (:math:`\\mathrm{LI}_{node}`),\n    first a node is sampled uniformly at random and then an edge incident to it\n    is sampled uniformly at random. These two versions of label informativeness\n    differ in how they weight high/low-degree nodes. In edge label\n    informativeness, averaging is over the edges, thus high-degree nodes are\n    given more weight. In node label informativeness, averaging is over the\n    nodes, so all nodes are weighted equally.\n\n    This function computes edge label informativeness.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph.\n    y : torch.Tensor\n        The node labels, which is a tensor of shape (|V|).\n    eps : float, optional\n        A small constant for numerical stability. 
(default: 1e-8)\n\n    Returns\n    -------\n    float\n        The edge label informativeness value.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n\n    >>> graph = dgl.graph(([0, 1, 2, 2, 3, 4], [1, 2, 0, 3, 4, 5]))\n    >>> y = torch.tensor([0, 0, 0, 0, 1, 1])\n    >>> dgl.edge_label_informativeness(graph, y)\n    0.25177597999572754\n    \"\"\"\n    check_pytorch()\n\n    graph = to_bidirected(graph.cpu()).to(y.device)\n\n    degrees = graph.in_degrees().float()\n    num_classes = y.max() + 1\n    class_degree_weighted_probs = torch.zeros(num_classes).to(y.device)\n    class_degree_weighted_probs.index_add_(dim=0, index=y, source=degrees)\n    class_degree_weighted_probs /= class_degree_weighted_probs.sum()\n\n    edge_probs = torch.zeros(num_classes, num_classes).to(y.device)\n    labels_u = y[graph.edges()[0].long()]\n    labels_v = y[graph.edges()[1].long()]\n    edge_probs.index_put_(\n        indices=(labels_u, labels_v),\n        values=torch.ones(graph.num_edges()).to(y.device),\n        accumulate=True,\n    )\n    edge_probs /= edge_probs.sum()\n    edge_probs += eps\n\n    numerator = (edge_probs * torch.log(edge_probs)).sum()\n    denominator = (\n        class_degree_weighted_probs * torch.log(class_degree_weighted_probs)\n    ).sum()\n    li_edge = 2 - numerator / denominator\n\n    return li_edge.item()\n\n\ndef node_label_informativeness(graph, y, eps=1e-8):\n    r\"\"\"Label informativeness (:math:`\\mathrm{LI}`) is a characteristic of\n    labeled graphs proposed in the `Characterizing Graph Datasets for Node\n    Classification: Homophily-Heterophily Dichotomy and Beyond\n    <https://arxiv.org/abs/2209.06177>`__\n\n    Label informativeness shows how much information about a node's label we\n    get from knowing its neighbor's label. Formally, assume that we sample an\n    edge :math:`(\\xi,\\eta) \\in E`. 
The class labels of nodes :math:`\\xi` and\n    :math:`\\eta` are then random variables :math:`y_\\xi` and :math:`y_\\eta`.\n    We want to measure the amount of knowledge the label :math:`y_\\eta` gives\n    for predicting :math:`y_\\xi`. The entropy :math:`H(y_\\xi)` measures the\n    `hardness' of predicting the label of :math:`\\xi` without knowing\n    :math:`y_\\eta`. Given :math:`y_\\eta`, this value is reduced to the\n    conditional entropy :math:`H(y_\\xi|y_\\eta)`. In other words, :math:`y_\\eta`\n    reveals :math:`I(y_\\xi,y_\\eta) = H(y_\\xi) - H(y_\\xi|y_\\eta)` information\n    about the label. To make the obtained quantity comparable across different\n    datasets, label informativeness is defined as the normalized mutual\n    information of :math:`y_{\\xi}` and :math:`y_{\\eta}`:\n\n    .. math::\n      \\mathrm{LI} = \\frac{I(y_\\xi,y_\\eta)}{H(y_\\xi)}\n\n    Depending on the distribution used for sampling an edge\n    :math:`(\\xi, \\eta)`, several variants of label informativeness can be\n    obtained. Two of them are particularly intuitive: in edge label\n    informativeness (:math:`\\mathrm{LI}_{edge}`), edges are sampled uniformly\n    at random, and in node label informativeness (:math:`\\mathrm{LI}_{node}`),\n    first a node is sampled uniformly at random and then an edge incident to it\n    is sampled uniformly at random. These two versions of label informativeness\n    differ in how they weight high/low-degree nodes. In edge label\n    informativeness, averaging is over the edges, thus high-degree nodes are\n    given more weight. In node label informativeness, averaging is over the\n    nodes, so all nodes are weighted equally.\n\n    This function computes node label informativeness.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph.\n    y : torch.Tensor\n        The node labels, which is a tensor of shape (|V|).\n    eps : float, optional\n        A small constant for numerical stability. 
(default: 1e-8)\n\n    Returns\n    -------\n    float\n        The node label informativeness value.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n\n    >>> graph = dgl.graph(([0, 1, 2, 2, 3, 4], [1, 2, 0, 3, 4, 5]))\n    >>> y = torch.tensor([0, 0, 0, 0, 1, 1])\n    >>> dgl.node_label_informativeness(graph, y)\n    0.3381872773170471\n    \"\"\"\n    check_pytorch()\n\n    graph = to_bidirected(graph.cpu()).to(y.device)\n\n    degrees = graph.in_degrees().float()\n    num_classes = y.max() + 1\n\n    class_probs = torch.zeros(num_classes).to(y.device)\n    class_probs.index_add_(\n        dim=0, index=y, source=torch.ones(graph.num_nodes()).to(y.device)\n    )\n    class_probs /= class_probs.sum()\n\n    class_degree_weighted_probs = torch.zeros(num_classes).to(y.device)\n    class_degree_weighted_probs.index_add_(dim=0, index=y, source=degrees)\n    class_degree_weighted_probs /= class_degree_weighted_probs.sum()\n\n    num_nonzero_degree_nodes = (degrees > 0).sum()\n\n    edge_probs = torch.zeros(num_classes, num_classes).to(y.device)\n    labels_u = y[graph.edges()[0].long()]\n    labels_v = y[graph.edges()[1].long()]\n    degrees_u = degrees[graph.edges()[0].long()]\n    edge_probs.index_put_(\n        indices=(labels_u, labels_v),\n        values=1 / (num_nonzero_degree_nodes * degrees_u),\n        accumulate=True,\n    )\n    edge_probs += eps\n\n    log = torch.log(\n        edge_probs\n        / (class_probs[:, None] * class_degree_weighted_probs[None, :])\n    )\n    numerator = (edge_probs * log).sum()\n    denominator = (class_probs * torch.log(class_probs)).sum()\n    li_node = -numerator / denominator\n\n    return li_node.item()\n"
  },
  {
    "path": "python/dgl/logging.py",
    "content": "\"\"\"logging module for DGL\"\"\"\nimport logging\nimport os\n\n\ndef enable_verbose_logging():\n    \"\"\"\n    Enable debug level logging for DGL\n    \"\"\"\n    os.environ[\"DMLC_LOG_DEBUG\"] = \"1\"\n    logger = logging.getLogger(\"dgl-core\")\n    logger.setLevel(logging.DEBUG)\n    logging.info(\"DGL's logging level is set to DEBUG\")\n\n\ndef _setup_logger():\n    \"\"\"setup logger\"\"\"\n    logger = logging.getLogger(\"dgl-core\")\n    console = logging.StreamHandler()\n    formatter = logging.Formatter(\n        \"%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s \\\n        t:%(threadName)s: %(message)s\"\n    )\n    console.setFormatter(formatter)\n    console.setLevel(logging.DEBUG)\n    logger.addHandler(console)\n    logger.propagate = False\n    logger.setLevel(logging.INFO)\n\n\n_setup_logger()\n\nif os.environ.get(\"DGL_LOG_DEBUG\", None) == \"1\":\n    enable_verbose_logging()\n"
  },
  {
    "path": "python/dgl/merge.py",
    "content": "\"\"\"Utilities for merging graphs.\"\"\"\n\nimport dgl\n\nfrom . import backend as F\nfrom .base import DGLError\n\n__all__ = [\"merge\"]\n\n\ndef merge(graphs):\n    r\"\"\"Merge a sequence of graphs together into a single graph.\n\n    Nodes and edges that exist in ``graphs[i+1]`` but not in ``dgl.merge(graphs[0:i+1])``\n    will be added to ``dgl.merge(graphs[0:i+1])`` along with their data.\n    Nodes that exist in both ``dgl.merge(graphs[0:i+1])`` and ``graphs[i+1]``\n    will be updated with ``graphs[i+1]``'s data if they do not match.\n\n    Parameters\n    ----------\n    graphs : list[DGLGraph]\n        Input graphs.\n\n    Returns\n    -------\n    DGLGraph\n        The merged graph.\n\n    Notes\n    ----------\n    * Inplace updates are applied to a new, empty graph.\n    * Features that exist in ``dgl.graphs[i+1]`` will be created in\n      ``dgl.merge(dgl.graphs[i+1])`` if they do not already exist.\n\n    Examples\n    ----------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n    >>> g = dgl.graph((torch.tensor([0,1]), torch.tensor([2,3])))\n    >>> g.ndata[\"x\"] = torch.zeros(4)\n    >>> h = dgl.graph((torch.tensor([1,2]), torch.tensor([0,4])))\n    >>> h.ndata[\"x\"] = torch.ones(5)\n    >>> m = dgl.merge([g, h])\n\n    ``m`` now contains edges and nodes from ``h`` and ``g``.\n\n    >>> m.edges()\n    (tensor([0, 1, 1, 2]), tensor([2, 3, 0, 4]))\n    >>> m.nodes()\n    tensor([0, 1, 2, 3, 4])\n\n    ``g``'s data has updated with ``h``'s in ``m``.\n\n    >>> m.ndata[\"x\"]\n    tensor([1., 1., 1., 1., 1.])\n\n    See Also\n    ----------\n    add_nodes\n    add_edges\n    \"\"\"\n\n    if len(graphs) == 0:\n        raise DGLError(\"The input list of graphs cannot be empty.\")\n\n    ref = graphs[0]\n    ntypes = ref.ntypes\n    etypes = ref.canonical_etypes\n    data_dict = {etype: ([], []) for etype in etypes}\n    num_nodes_dict = {ntype: 0 for ntype in ntypes}\n    merged = 
dgl.heterograph(data_dict, num_nodes_dict, ref.idtype, ref.device)\n\n    # Merge edges and edge data.\n    for etype in etypes:\n        unmerged_us = []\n        unmerged_vs = []\n        edata_frames = []\n        for graph in graphs:\n            etype_id = graph.get_etype_id(etype)\n            us, vs = graph.edges(etype=etype)\n            unmerged_us.append(us)\n            unmerged_vs.append(vs)\n            edge_data = graph._edge_frames[etype_id]\n            edata_frames.append(edge_data)\n        keys = ref.edges[etype].data.keys()\n        if len(keys) == 0:\n            edges_data = None\n        else:\n            edges_data = {\n                k: F.cat([f[k] for f in edata_frames], dim=0) for k in keys\n            }\n        merged_us = F.copy_to(\n            F.astype(F.cat(unmerged_us, dim=0), ref.idtype), ref.device\n        )\n        merged_vs = F.copy_to(\n            F.astype(F.cat(unmerged_vs, dim=0), ref.idtype), ref.device\n        )\n        merged.add_edges(merged_us, merged_vs, edges_data, etype)\n\n    # Add node data and isolated nodes from next_graph to merged.\n    for next_graph in graphs:\n        for ntype in ntypes:\n            merged_ntype_id = merged.get_ntype_id(ntype)\n            next_ntype_id = next_graph.get_ntype_id(ntype)\n            next_ndata = next_graph._node_frames[next_ntype_id]\n            node_diff = next_graph.num_nodes(ntype=ntype) - merged.num_nodes(\n                ntype=ntype\n            )\n            n_extra_nodes = max(0, node_diff)\n            merged.add_nodes(n_extra_nodes, ntype=ntype)\n            next_nodes = F.arange(\n                0,\n                next_graph.num_nodes(ntype=ntype),\n                merged.idtype,\n                merged.device,\n            )\n            merged._node_frames[merged_ntype_id].update_row(\n                next_nodes, next_ndata\n            )\n\n    return merged\n"
  },
  {
    "path": "python/dgl/mpops/__init__.py",
    "content": "\"\"\"Message passing operator sub-package\"\"\"\n\nfrom .edgewise import *\nfrom .nodewise import *\nfrom .fused import *\n"
  },
  {
    "path": "python/dgl/mpops/edgewise.py",
    "content": "\"\"\"Operators for computing edge data.\"\"\"\nimport sys\n\nfrom .. import ops\n\n__all__ = [\"copy_u\", \"copy_v\"]\n\n#######################################################\n# Edge-wise operators that fetch node data to edges\n#######################################################\n\n\ndef copy_u(g, x_node, etype=None):\n    \"\"\"Compute new edge data by fetching from source node data.\n\n    Given an input graph :math:`G(V, E)` (or a unidirectional bipartite graph\n    :math:`G(V_{src}, V_{dst}, E)`) and an input tensor :math:`X`,\n    the operator computes a tensor :math:`Y` storing the new edge data.\n    For each edge :math:`e=(u,v) \\\\in E`, it computes:\n\n    .. math:\n\n        Y_e = X_u\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    x_node : Tensor\n        The tensor storing the source node data. Shape :math:`(|V_{src}|, *)`.\n    etype : str or (str, str, str), optional\n        Edge type. If not specified, the input graph must have only one type of\n        edges.\n\n    Returns\n    -------\n    Tensor\n        The tensor storing the new edge data. Shape :math:`(|E|, *)`.\n\n    Examples\n    --------\n\n    **Homogeneous graph**\n\n    >>> import torch, dgl\n    >>> g = dgl.rand_graph(100, 500)  # a random graph of 100 nodes, 500 edges\n    >>> x = torch.randn(g.num_nodes(), 5)  # 5 features\n    >>> y = dgl.copy_u(g, x)\n    >>> print(y.shape)\n    (500, 5)\n\n    **Heterogeneous graph**\n\n    >>> hg = dgl.heterograph({\n    ...     ('user', 'follow', 'user'): ([0, 1, 2], [2, 3, 4]),\n    ...     ('user', 'like', 'movie'): ([3, 3, 1, 2], [0, 0, 1, 1])\n    ... 
})\n    >>> x = torch.randn(hg.num_nodes('user'), 5)\n    >>> y = dgl.copy_u(hg, x, etype='like')\n    >>> print(y.shape)\n    (4, 5)\n    \"\"\"\n    etype_subg = g if etype is None else g[etype]\n    return ops.gsddmm(etype_subg, \"copy_lhs\", x_node, None)\n\n\ndef copy_v(g, x_node, etype=None):\n    \"\"\"Compute new edge data by fetching from destination node data.\n\n    Given an input graph :math:`G(V, E)` (or a unidirectional bipartite graph\n    :math:`G(V_{src}, V_{dst}, E)`) and an input tensor :math:`X`,\n    the operator computes a tensor :math:`Y` storing the new edge data.\n    For each edge :math:`e=(u,v) \\\\in E`, it computes:\n\n    .. math:\n\n        Y_e = X_v\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    x_node : Tensor\n        The tensor storing the destination node data. Shape :math:`(|V_{dst}|, *)`.\n    etype : str or (str, str, str), optional\n        Edge type. If not specified, the input graph must have\n        only one type of edges.\n\n    Returns\n    -------\n    Tensor\n        The tensor storing the new edge data. Shape :math:`(|E|, *)`.\n\n    Examples\n    --------\n\n    **Homogeneous graph**\n\n    >>> import torch, dgl\n    >>> g = dgl.rand_graph(100, 500)  # a random graph of 100 nodes, 500 edges\n    >>> x = torch.randn(g.num_nodes(), 5)  # 5 features\n    >>> y = dgl.copy_v(g, x)\n    >>> print(y.shape)\n    (500, 5)\n\n    **Heterogeneous graph**\n\n    >>> hg = dgl.heterograph({\n    ...     ('user', 'follow', 'user'): ([0, 1, 2], [2, 3, 4]),\n    ...     ('user', 'like', 'movie'): ([3, 3, 1, 2], [0, 0, 1, 1])\n    ... 
})\n    >>> x = torch.randn(hg.num_nodes('movie'), 5)\n    >>> y = dgl.copy_v(hg, x, etype='like')\n    >>> print(y.shape)\n    (4, 5)\n    \"\"\"\n    etype_subg = g if etype is None else g[etype]\n    return ops.gsddmm(etype_subg, \"copy_rhs\", None, x_node)\n\n\n#######################################################\n# Binary edge-wise operators\n#######################################################\n\n\ndef _gen_u_op_v(op):\n    \"\"\"Internal helper function to create binary edge-wise operators.\n\n    The function will return a Python function with:\n\n     - Name: u_{op}_v\n     - Docstring template\n\n    Parameters\n    ----------\n    op : str\n        Binary operator name. Must be 'add', 'sub', 'mul', 'div' or 'dot'.\n    \"\"\"\n    name = f\"u_{op}_v\"\n    op_verb = {\n        \"add\": \"adding\",\n        \"sub\": \"subtracting\",\n        \"mul\": \"multiplying\",\n        \"div\": \"dividing\",\n        \"dot\": \"dot-product\",\n    }\n    docstring = f\"\"\"Compute new edge data by {op_verb[op]} the source node data\nand destination node data.\n\nGiven an input graph :math:`G(V, E)` (or a unidirectional bipartite graph\n:math:`G(V_{{src}}, V_{{dst}}, E)`) and two input tensors :math:`X` and\n:math:`Y`, the operator computes a tensor :math:`Z` storing the new edge data.\nFor each edge :math:`e=(u,v) \\\\in E`, it computes:\n\n.. math:\n\n    Z_e = {op}(X_u, Y_v)\n\nIf :math:`X_u` and :math:`Y_v` are vectors or high-dimensional tensors, the\noperation is element-wise and supports shape broadcasting. Read more about\n`NumPy's broadcasting semantics\n<https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_.\n\nParameters\n----------\ng : DGLGraph\n    The input graph.\nx_node : Tensor\n    The tensor storing the source node data. Shape :math:`(|V_{{src}}|, *)`.\ny_node : Tensor\n    The tensor storing the destination node data. Shape :math:`(|V_{{dst}}|, *)`.\netype : str or (str, str, str), optional\n    Edge type. 
If not specified, the input graph must have\n    only one type of edges.\n\nReturns\n-------\nTensor\n    The tensor storing the new edge data. Shape :math:`(|E|, *)`.\n\nExamples\n--------\n\n**Homogeneous graph**\n\n>>> import torch, dgl\n>>> g = dgl.rand_graph(100, 500)  # a random graph of 100 nodes, 500 edges\n>>> x = torch.randn(g.num_nodes(), 5)  # 5 features\n>>> y = torch.randn(g.num_nodes(), 5)  # 5 features\n>>> z = dgl.{name}(g, x, y)\n>>> print(z.shape)\n(500, 5)\n\n**Heterogeneous graph**\n\n>>> hg = dgl.heterograph({{\n...     ('user', 'follow', 'user'): ([0, 1, 2], [2, 3, 4]),\n...     ('user', 'like', 'movie'): ([3, 3, 1, 2], [0, 0, 1, 1])\n... }})\n>>> x = torch.randn(hg.num_nodes('user'), 5)\n>>> y = torch.randn(hg.num_nodes('user'), 5)\n>>> z = dgl.{name}(hg, x, y, etype='follow')\n>>> print(z.shape)\n(3, 5)\n\n**Shape broadcasting**\n\n>>> x = torch.randn(g.num_nodes(), 5)  # 5 features\n>>> y = torch.randn(g.num_nodes(), 1)  # one feature\n>>> z = dgl.{name}(g, x, y)\n>>> print(z.shape)\n(500, 5)\n\"\"\"\n\n    def func(g, x_node, y_node, etype=None):\n        etype_subg = g if etype is None else g[etype]\n        return ops.gsddmm(\n            etype_subg, op, x_node, y_node, lhs_target=\"u\", rhs_target=\"v\"\n        )\n\n    func.__name__ = name\n    func.__doc__ = docstring\n    return func\n\n\ndef _register_func(func):\n    setattr(sys.modules[__name__], func.__name__, func)\n    __all__.append(func.__name__)\n\n\n_register_func(_gen_u_op_v(\"add\"))\n_register_func(_gen_u_op_v(\"sub\"))\n_register_func(_gen_u_op_v(\"mul\"))\n_register_func(_gen_u_op_v(\"div\"))\n_register_func(_gen_u_op_v(\"dot\"))\n"
  },
  {
    "path": "python/dgl/mpops/fused.py",
    "content": "\"\"\"Operators that fuse the computation and aggregation of edge data.\"\"\"\n"
  },
  {
    "path": "python/dgl/mpops/nodewise.py",
    "content": "\"\"\"Operators for aggregating/reducing edge data to node data.\"\"\"\n"
  },
  {
    "path": "python/dgl/multiprocessing/__init__.py",
    "content": "\"\"\"Wrapper of the multiprocessing module for multi-GPU training.\"\"\"\n\n# To avoid duplicating the graph structure for node classification or link prediction\n# training we recommend using fork() rather than spawn() for multiple GPU training.\n# However, we need to work around https://github.com/pytorch/pytorch/issues/17199 to\n# make fork() and openmp work together.\nfrom .. import backend as F\n\nif F.get_preferred_backend() == \"pytorch\":\n    # Wrap around torch.multiprocessing...\n    from torch.multiprocessing import *\n\n    # ... and override the Process initializer.\n    from .pytorch import *\nelse:\n    # Just import multiprocessing module.\n    from multiprocessing import *  # pylint: disable=redefined-builtin\n"
  },
  {
    "path": "python/dgl/multiprocessing/pytorch.py",
    "content": "\"\"\"PyTorch multiprocessing wrapper.\"\"\"\nimport random\nimport traceback\nfrom _thread import start_new_thread\nfrom functools import wraps\n\nimport torch\nimport torch.multiprocessing as mp\n\nfrom ..utils import create_shared_mem_array, get_shared_mem_array\n\n\ndef thread_wrapped_func(func):\n    \"\"\"\n    Wraps a process entry point to make it work with OpenMP.\n    \"\"\"\n\n    @wraps(func)\n    def decorated_function(*args, **kwargs):\n        queue = mp.Queue()\n\n        def _queue_result():\n            exception, trace, res = None, None, None\n            try:\n                res = func(*args, **kwargs)\n            except Exception as e:  # pylint: disable=broad-except\n                exception = e\n                trace = traceback.format_exc()\n            queue.put((res, exception, trace))\n\n        start_new_thread(_queue_result, ())\n        result, exception, trace = queue.get()\n        if exception is None:\n            return result\n        else:\n            assert isinstance(exception, Exception)\n            raise exception.__class__(trace)\n\n    return decorated_function\n\n\n# pylint: disable=missing-docstring\nclass Process(mp.Process):\n    # pylint: disable=dangerous-default-value\n    def __init__(\n        self,\n        group=None,\n        target=None,\n        name=None,\n        args=(),\n        kwargs={},\n        *,\n        daemon=None\n    ):\n        target = thread_wrapped_func(target)\n        super().__init__(group, target, name, args, kwargs, daemon=daemon)\n\n\ndef _get_shared_mem_name(id_):\n    return \"shared\" + str(id_)\n\n\ndef call_once_and_share(func, shape, dtype, rank=0):\n    \"\"\"Invoke the function in a single process of the PyTorch distributed process group,\n    and share the result with other processes.\n\n    Parameters\n    ----------\n    func : callable\n        Any callable that accepts no arguments and returns an arbitrary object.\n    shape : tuple[int]\n        The 
shape of the shared tensor.  Must match the output of :attr:`func`.\n    dtype : torch.dtype\n        The data type of the shared tensor.  Must match the output of :attr:`func`.\n    rank : int, optional\n        The process ID to actually execute the function.\n    \"\"\"\n    current_rank = torch.distributed.get_rank()\n    dist_buf = torch.LongTensor([1])\n\n    if torch.distributed.get_backend() == \"nccl\":\n        # Use .cuda() to transfer it to the correct device.  Should be OK since\n        # PyTorch recommends the users to call set_device() after getting inside\n        # torch.multiprocessing.spawn()\n        dist_buf = dist_buf.cuda()\n\n    # Process with the given rank creates and populates the shared memory array.\n    if current_rank == rank:\n        # PyTorch Lightning 1.6+ seems to set the random seed during process spawning\n        # to the same seed value.\n        random_ = random.Random()\n        id_ = random_.getrandbits(32)\n        name = _get_shared_mem_name(id_)\n        result = create_shared_mem_array(name, shape, dtype)\n        result[:] = func()\n        dist_buf[0] = id_\n\n    # Broadcasts the name of the shared array to other processes.\n    torch.distributed.broadcast(dist_buf, rank)\n    # If no exceptions, other processes open the same shared memory object.\n    if current_rank != rank:\n        id_ = dist_buf.item()\n        name = _get_shared_mem_name(id_)\n        result = get_shared_mem_array(name, shape, dtype)\n\n    return result\n\n\ndef shared_tensor(shape, dtype=torch.float32):\n    \"\"\"Create a tensor in shared memory accessible by all processes within the same\n    ``torch.distributed`` process group.\n\n    The content is uninitialized.\n\n    Parameters\n    ----------\n    shape : tuple[int]\n        The shape of the tensor.\n    dtype : torch.dtype, optional\n        The dtype of the tensor.\n\n    Returns\n    -------\n    Tensor\n        The shared tensor.\n    \"\"\"\n    return call_once_and_share(\n   
     lambda: torch.empty(*shape, dtype=dtype), shape, dtype\n    )\n"
  },
  {
    "path": "python/dgl/ndarray.py",
    "content": "\"\"\"DGL Runtime NDArray API.\n\ndgl.ndarray provides a minimum runtime array structure to be\nused with C++ library.\n\"\"\"\n# pylint: disable=invalid-name,unused-import\nfrom __future__ import absolute_import as _abs\n\nimport ctypes\nimport functools\nimport operator\n\nimport numpy as _np\n\nfrom . import backend as F\nfrom ._ffi.function import _init_api\nfrom ._ffi.ndarray import (\n    _set_class_ndarray,\n    context,\n    DGLContext,\n    DGLDataType,\n    empty,\n    empty_shared_mem,\n    from_dlpack,\n    NDArrayBase,\n    numpyasarray,\n)\nfrom ._ffi.object import ObjectBase, register_object\n\n\nclass NDArray(NDArrayBase):\n    \"\"\"Lightweight NDArray class for DGL framework.\"\"\"\n\n    def __len__(self):\n        return functools.reduce(operator.mul, self.shape, 1)\n\n    def shared_memory(self, name):\n        \"\"\"Return a copy of the ndarray in shared memory\n\n        Parameters\n        ----------\n        name : str\n            The name of the shared memory\n\n        Returns\n        -------\n        NDArray\n        \"\"\"\n        return empty_shared_mem(name, True, self.shape, self.dtype).copyfrom(\n            self\n        )\n\n\ndef cpu(dev_id=0):\n    \"\"\"Construct a CPU device\n\n    Parameters\n    ----------\n    dev_id : int, optional\n        The integer device id\n\n    Returns\n    -------\n    ctx : DGLContext\n        The created context\n    \"\"\"\n    return DGLContext(1, dev_id)\n\n\ndef gpu(dev_id=0):\n    \"\"\"Construct a CPU device\n\n    Parameters\n    ----------\n    dev_id : int, optional\n        The integer device id\n\n    Returns\n    -------\n    ctx : DGLContext\n        The created context\n    \"\"\"\n    return DGLContext(2, dev_id)\n\n\ndef array(arr, ctx=cpu(0)):\n    \"\"\"Create an array from source arr.\n\n    Parameters\n    ----------\n    arr : numpy.ndarray\n        The array to be copied from\n\n    ctx : DGLContext, optional\n        The device context to create the 
array\n\n    Returns\n    -------\n    ret : NDArray\n        The created array\n    \"\"\"\n    if not isinstance(arr, (_np.ndarray, NDArray)):\n        arr = _np.array(arr)\n    return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)\n\n\ndef zerocopy_from_numpy(np_data):\n    \"\"\"Create an array that shares the given numpy data.\n\n    Parameters\n    ----------\n    np_data : numpy.ndarray\n        The numpy data\n\n    Returns\n    -------\n    NDArray\n        The array\n    \"\"\"\n    arr, _ = numpyasarray(np_data)\n    handle = ctypes.pointer(arr)\n    return NDArray(handle, is_view=True)\n\n\ndef cast_to_signed(arr):\n    \"\"\"Cast this NDArray from unsigned integer to signed one.\n\n    uint64 -> int64\n    uint32 -> int32\n\n    Useful for backends with poor signed integer support (e.g., TensorFlow).\n\n    Parameters\n    ----------\n    arr : NDArray\n        Input array\n\n    Returns\n    -------\n    NDArray\n        Cased array\n    \"\"\"\n    return _CAPI_DGLArrayCastToSigned(arr)\n\n\ndef get_shared_mem_array(name, shape, dtype):\n    \"\"\"Get a tensor from shared memory with specific name\n\n    Parameters\n    ----------\n    name : str\n        The unique name of the shared memory\n    shape : tuple of int\n        The shape of the returned tensor\n    dtype : F.dtype\n        The dtype of the returned tensor\n\n    Returns\n    -------\n    F.tensor\n        The tensor got from shared memory.\n    \"\"\"\n    new_arr = empty_shared_mem(\n        name, False, shape, F.reverse_data_type_dict[dtype]\n    )\n    dlpack = new_arr.to_dlpack()\n    return F.zerocopy_from_dlpack(dlpack)\n\n\ndef create_shared_mem_array(name, shape, dtype):\n    \"\"\"Create a tensor from shared memory with the specific name\n\n    Parameters\n    ----------\n    name : str\n        The unique name of the shared memory\n    shape : tuple of int\n        The shape of the returned tensor\n    dtype : F.dtype\n        The dtype of the returned tensor\n\n    Returns\n 
   -------\n    F.tensor\n        The created tensor.\n    \"\"\"\n    new_arr = empty_shared_mem(\n        name, True, shape, F.reverse_data_type_dict[dtype]\n    )\n    dlpack = new_arr.to_dlpack()\n    return F.zerocopy_from_dlpack(dlpack)\n\n\ndef exist_shared_mem_array(name):\n    \"\"\"Check the existence of shared-memory array.\n\n    Parameters\n    ----------\n    name : str\n        The name of the shared-memory array.\n\n    Returns\n    -------\n    bool\n        The existence of the array\n    \"\"\"\n    return _CAPI_DGLExistSharedMemArray(name)\n\n\nclass SparseFormat:\n    \"\"\"Format code\"\"\"\n\n    ANY = 0\n    COO = 1\n    CSR = 2\n    CSC = 3\n\n    FORMAT2STR = {\n        0: \"ANY\",\n        1: \"COO\",\n        2: \"CSR\",\n        3: \"CSC\",\n    }\n\n\n@register_object(\"aten.SparseMatrix\")\nclass SparseMatrix(ObjectBase):\n    \"\"\"Sparse matrix object class in C++ backend.\"\"\"\n\n    @property\n    def format(self):\n        \"\"\"Sparse format enum\n\n        Returns\n        -------\n        int\n        \"\"\"\n        return _CAPI_DGLSparseMatrixGetFormat(self)\n\n    @property\n    def num_rows(self):\n        \"\"\"Number of rows.\n\n        Returns\n        -------\n        int\n        \"\"\"\n        return _CAPI_DGLSparseMatrixGetNumRows(self)\n\n    @property\n    def num_cols(self):\n        \"\"\"Number of columns.\n\n        Returns\n        -------\n        int\n        \"\"\"\n        return _CAPI_DGLSparseMatrixGetNumCols(self)\n\n    @property\n    def indices(self):\n        \"\"\"Index arrays.\n\n        Returns\n        -------\n        list of ndarrays\n        \"\"\"\n        ret = [_CAPI_DGLSparseMatrixGetIndices(self, i) for i in range(3)]\n        return [F.zerocopy_from_dgl_ndarray(arr) for arr in ret]\n\n    @property\n    def flags(self):\n        \"\"\"Flag arrays\n\n        Returns\n        -------\n        list of boolean\n        \"\"\"\n        return _CAPI_DGLSparseMatrixGetFlags(self)\n\n    def __getstate__(self):\n        return (\n            self.format,\n            self.num_rows,\n            self.num_cols,\n            self.indices,\n            self.flags,\n        )\n\n    def __setstate__(self, state):\n        fmt, nrows, ncols, indices, flags = state\n        indices = [F.zerocopy_to_dgl_ndarray(idx) for idx in indices]\n        self.__init_handle_by_constructor__(\n            _CAPI_DGLCreateSparseMatrix, fmt, nrows, ncols, indices, flags\n        )\n\n    def __repr__(self):\n        return 'SparseMatrix(fmt=\"{}\", shape=({},{}))'.format(\n            SparseFormat.FORMAT2STR[self.format], self.num_rows, self.num_cols\n        )\n\n\n_set_class_ndarray(NDArray)\n_init_api(\"dgl.ndarray\")\n_init_api(\"dgl.ndarray.uvm\", __name__)\n\n# An array representing null (no value) that can be safely converted to\n# other backend tensors.\nNULL = {\n    \"int64\": array(_np.array([], dtype=_np.int64)),\n    \"int32\": array(_np.array([], dtype=_np.int32)),\n}\n"
  },
  {
    "path": "python/dgl/nn/__init__.py",
    "content": "\"\"\"The ``dgl.nn`` package contains framework-specific implementations for\ncommon Graph Neural Network layers (or module in PyTorch, Block in MXNet).\nUsers can directly import ``dgl.nn.<layer_name>`` (e.g., ``dgl.nn.GraphConv``),\nand the package will dispatch the layer name to the actual implementation\naccording to the backend framework currently in use.\n\nNote that there are coverage differences among frameworks. If you encounter\nan ``ImportError: cannot import name 'XXX'`` error, that means the layer is\nnot available to the current backend. If you wish a module to appear in DGL,\nplease `create an issue <https://github.com/dmlc/dgl/issues>`_ started with\n\"[Feature Request] NN Module XXXModel\". If you want to contribute a NN module,\nplease `create a pull request <https://github.com/dmlc/dgl/pulls>`_ started\nwith \"[NN] XXX module\".\n\"\"\"\n\nimport importlib\nimport os\nimport sys\n\nfrom ..backend import backend_name\nfrom ..utils import expand_as_pair\n\n# [BarclayII] Not sure what's going on with pylint.\n# Possible issue: https://github.com/PyCQA/pylint/issues/2648\nfrom . import functional  # pylint: disable=import-self\n\n\ndef _load_backend(mod_name):\n    mod = importlib.import_module(\".%s\" % mod_name, __name__)\n    thismod = sys.modules[__name__]\n    for api, obj in mod.__dict__.items():\n        setattr(thismod, api, obj)\n\n\n_load_backend(backend_name)\n"
  },
  {
    "path": "python/dgl/nn/functional/__init__.py",
    "content": "\"\"\"Functions related to DGL NN Modules.\"\"\"\n\nfrom ...ops import edge_softmax\n"
  },
  {
    "path": "python/dgl/nn/mxnet/__init__.py",
    "content": "\"\"\"Package for mxnet-specific NN modules.\"\"\"\nfrom .conv import *\nfrom .glob import *\nfrom .hetero import *\nfrom .softmax import *\nfrom .utils import Sequential\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/__init__.py",
    "content": "\"\"\"MXNet modules for graph convolutions.\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\n\nfrom .agnnconv import AGNNConv\nfrom .appnpconv import APPNPConv\nfrom .chebconv import ChebConv\nfrom .densechebconv import DenseChebConv\nfrom .densegraphconv import DenseGraphConv\nfrom .densesageconv import DenseSAGEConv\nfrom .edgeconv import EdgeConv\nfrom .gatconv import GATConv\nfrom .gatedgraphconv import GatedGraphConv\nfrom .ginconv import GINConv\nfrom .gmmconv import GMMConv\nfrom .graphconv import GraphConv\nfrom .nnconv import NNConv\nfrom .relgraphconv import RelGraphConv\nfrom .sageconv import SAGEConv\nfrom .sgconv import SGConv\nfrom .tagconv import TAGConv\n\n__all__ = [\n    \"GraphConv\",\n    \"TAGConv\",\n    \"RelGraphConv\",\n    \"GATConv\",\n    \"SAGEConv\",\n    \"GatedGraphConv\",\n    \"ChebConv\",\n    \"AGNNConv\",\n    \"APPNPConv\",\n    \"DenseGraphConv\",\n    \"DenseSAGEConv\",\n    \"DenseChebConv\",\n    \"EdgeConv\",\n    \"GINConv\",\n    \"GMMConv\",\n    \"NNConv\",\n    \"SGConv\",\n]\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/agnnconv.py",
    "content": "\"\"\"MXNet Module for Attention-based Graph Neural Network layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport mxnet as mx\nfrom mxnet.gluon import nn\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\nfrom ...functional import edge_softmax\nfrom ..utils import normalize\n\n\nclass AGNNConv(nn.Block):\n    r\"\"\"Attention-based Graph Neural Network layer from `Attention-based Graph Neural Network for\n    Semi-Supervised Learning <https://arxiv.org/abs/1803.03735>`__\n\n    .. math::\n        H^{l+1} = P H^{l}\n\n    where :math:`P` is computed as:\n\n    .. math::\n        P_{ij} = \\mathrm{softmax}_i ( \\beta \\cdot \\cos(h_i^l, h_j^l))\n\n    where :math:`\\beta` is a single scalar parameter.\n\n    Parameters\n    ----------\n    init_beta : float, optional\n        The :math:`\\beta` in the formula, a single scalar parameter.\n    learn_beta : bool, optional\n        If True, :math:`\\beta` will be learnable parameter.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Default: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... 
# a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from dgl.nn import AGNNConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = mx.nd.ones((6, 10))\n    >>> conv = AGNNConv()\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, feat)\n    >>> res\n    [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n    [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n    [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n    [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n    [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n    [1. 1. 1. 1. 1. 1. 1. 1. 1. 
1.]]\n    <NDArray 6x10 @cpu(0)>\n    \"\"\"\n\n    def __init__(\n        self, init_beta=1.0, learn_beta=True, allow_zero_in_degree=False\n    ):\n        super(AGNNConv, self).__init__()\n        self._allow_zero_in_degree = allow_zero_in_degree\n        with self.name_scope():\n            self.beta = self.params.get(\n                \"beta\",\n                shape=(1,),\n                grad_req=\"write\" if learn_beta else \"null\",\n                init=mx.init.Constant(init_beta),\n            )\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute AGNN layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            The input feature of shape :math:`(N, *)` :math:`N` is the\n            number of nodes, and :math:`*` could be of any shape.\n            If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, *)` and :math:`(N_{out}, *)`, the :math:`*` in the later\n            tensor must equal the previous one.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, *)` where :math:`*`\n            should be the same as input shape.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. 
This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if graph.in_degrees().min() == 0:\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            feat_src, feat_dst = expand_as_pair(feat, graph)\n            graph.srcdata[\"h\"] = feat_src\n            graph.srcdata[\"norm_h\"] = normalize(feat_src, p=2, axis=-1)\n            if isinstance(feat, tuple) or graph.is_block:\n                graph.dstdata[\"norm_h\"] = normalize(feat_dst, p=2, axis=-1)\n            # compute cosine distance\n            graph.apply_edges(fn.u_dot_v(\"norm_h\", \"norm_h\", \"cos\"))\n            cos = graph.edata.pop(\"cos\")\n            e = self.beta.data(feat_src.context) * cos\n            graph.edata[\"p\"] = edge_softmax(graph, e)\n            graph.update_all(fn.u_mul_e(\"h\", \"p\", \"m\"), fn.sum(\"m\", \"h\"))\n            return graph.dstdata.pop(\"h\")\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/appnpconv.py",
    "content": "\"\"\"MXNet Module for APPNPConv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport mxnet as mx\nfrom mxnet import nd\nfrom mxnet.gluon import nn\n\nfrom .... import function as fn\n\n\nclass APPNPConv(nn.Block):\n    r\"\"\"Approximate Personalized Propagation of Neural Predictions layer from `Predict then\n    Propagate: Graph Neural Networks meet Personalized PageRank\n    <https://arxiv.org/pdf/1810.05997.pdf>`__\n\n    .. math::\n        H^{0} &= X\n\n        H^{l+1} &= (1-\\alpha)\\left(\\tilde{D}^{-1/2}\n        \\tilde{A} \\tilde{D}^{-1/2} H^{l}\\right) + \\alpha H^{0}\n\n    where :math:`\\tilde{A}` is :math:`A` + :math:`I`.\n\n    Parameters\n    ----------\n    k : int\n        The number of iterations :math:`K`.\n    alpha : float\n        The teleport probability :math:`\\alpha`.\n    edge_drop : float, optional\n        The dropout rate on edges that controls the\n        messages received by each node. Default: ``0``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from dgl.nn import APPNPConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = mx.nd.ones((6, 10))\n    >>> conv = APPNPConv(k=3, alpha=0.5)\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, feat)\n    >>> res\n    [[1.         1.         1.         1.         1.         1.\n    1.         1.         1.         1.        ]\n    [1.         1.         1.         1.         1.         1.\n    1.         1.         1.         1.        ]\n    [1.         1.         1.         1.         1.         1.\n    1.         1.         1.         1.        
]\n    [1.0303301  1.0303301  1.0303301  1.0303301  1.0303301  1.0303301\n    1.0303301  1.0303301  1.0303301  1.0303301 ]\n    [0.86427665 0.86427665 0.86427665 0.86427665 0.86427665 0.86427665\n    0.86427665 0.86427665 0.86427665 0.86427665]\n    [0.5        0.5        0.5        0.5        0.5        0.5\n    0.5        0.5        0.5        0.5       ]]\n    <NDArray 6x10 @cpu(0)>\n    \"\"\"\n\n    def __init__(self, k, alpha, edge_drop=0.0):\n        super(APPNPConv, self).__init__()\n        self._k = k\n        self._alpha = alpha\n        with self.name_scope():\n            self.edge_drop = nn.Dropout(edge_drop)\n\n    def forward(self, graph, feat):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute APPNP layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mx.NDArray\n            The input feature of shape :math:`(N, *)`. :math:`N` is the\n            number of nodes, and :math:`*` could be of any shape.\n\n        Returns\n        -------\n        mx.NDArray\n            The output feature of shape :math:`(N, *)` where :math:`*`\n            should be the same as input shape.\n        \"\"\"\n        with graph.local_scope():\n            norm = mx.nd.power(\n                mx.nd.clip(\n                    graph.in_degrees().astype(feat.dtype),\n                    a_min=1,\n                    a_max=float(\"inf\"),\n                ),\n                -0.5,\n            )\n            shp = norm.shape + (1,) * (feat.ndim - 1)\n            norm = norm.reshape(shp).as_in_context(feat.context)\n            feat_0 = feat\n            for _ in range(self._k):\n                # normalization by src node\n                feat = feat * norm\n                graph.ndata[\"h\"] = feat\n                graph.edata[\"w\"] = self.edge_drop(\n                    nd.ones((graph.num_edges(), 1), ctx=feat.context)\n                )\n                
graph.update_all(fn.u_mul_e(\"h\", \"w\", \"m\"), fn.sum(\"m\", \"h\"))\n                feat = graph.ndata.pop(\"h\")\n                # normalization by dst node\n                feat = feat * norm\n                feat = (1 - self._alpha) * feat + self._alpha * feat_0\n            return feat\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/chebconv.py",
    "content": "\"\"\"MXNet Module for Chebyshev Spectral Graph Convolution layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport mxnet as mx\nfrom mxnet import nd\nfrom mxnet.gluon import nn\n\nfrom .... import broadcast_nodes, function as fn\nfrom ....base import dgl_warning\n\n\nclass ChebConv(nn.Block):\n    r\"\"\"Chebyshev Spectral Graph Convolution layer from `Convolutional Neural Networks on Graphs\n    with Fast Localized Spectral Filtering <https://arxiv.org/pdf/1606.09375.pdf>`__\n\n    .. math::\n        h_i^{l+1} &= \\sum_{k=0}^{K-1} W^{k, l}z_i^{k, l}\n\n        Z^{0, l} &= H^{l}\n\n        Z^{1, l} &= \\tilde{L} \\cdot H^{l}\n\n        Z^{k, l} &= 2 \\cdot \\tilde{L} \\cdot Z^{k-1, l} - Z^{k-2, l}\n\n        \\tilde{L} &= 2\\left(I - \\tilde{D}^{-1/2} \\tilde{A} \\tilde{D}^{-1/2}\\right)/\\lambda_{max} - I\n\n    where :math:`\\tilde{A}` is :math:`A` + :math:`I`, :math:`W` is learnable weight.\n\n\n    Parameters\n    ----------\n    in_feats: int\n        Dimension of input features; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n    out_feats: int\n        Dimension of output features :math:`h_i^{(l+1)}`.\n    k : int\n        Chebyshev filter size :math:`K`.\n    activation : function, optional\n        Activation function. Default ``ReLu``.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from dgl.nn import ChebConv\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = mx.nd.ones((6, 10))\n    >>> conv = ChebConv(10, 2, 2)\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, feat)\n    >>> res\n    [[ 0.832592   -0.738757  ]\n    [ 0.832592   -0.738757  ]\n    [ 0.832592   -0.738757  ]\n    [ 0.43377423 -1.0455742 ]\n    [ 1.1145986  -0.5218046 ]\n    [ 1.7954229   0.00196505]]\n    <NDArray 6x2 @cpu(0)>\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, k, bias=True):\n        super(ChebConv, self).__init__()\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._k = k\n        with self.name_scope():\n            self.fc = nn.Sequential()\n            for _ in range(k):\n                self.fc.add(\n                    nn.Dense(\n                        out_feats,\n                        use_bias=False,\n                        weight_initializer=mx.init.Xavier(\n                            magnitude=math.sqrt(2.0)\n                        ),\n                        in_units=in_feats,\n                    )\n                )\n            if bias:\n                self.bias = self.params.get(\n                    \"bias\", shape=(out_feats,), init=mx.init.Zero()\n                )\n            else:\n                self.bias = None\n\n    def forward(self, graph, feat, lambda_max=None):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute ChebNet layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`\n            is size of input feature, :math:`N` is the number of nodes.\n        lambda_max : list or tensor or None, optional.\n            A list(tensor) with length 
:math:`B`, stores the largest eigenvalue\n            of the normalized laplacian of each individual graph in ``graph``,\n            where :math:`B` is the batch size of the input graph. Default: None.\n\n            If None, this method would set the default value to 2.\n            One can use :func:`dgl.laplacian_lambda_max` to compute this value.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n        with graph.local_scope():\n            degs = graph.in_degrees().astype(\"float32\")\n            norm = mx.nd.power(\n                mx.nd.clip(degs, a_min=1, a_max=float(\"inf\")), -0.5\n            )\n            norm = norm.expand_dims(-1).as_in_context(feat.context)\n\n            if lambda_max is None:\n                dgl_warning(\n                    \"lambda_max is not provided, using default value of 2.  \"\n                    \"Please use dgl.laplacian_lambda_max to compute the eigenvalues.\"\n                )\n                lambda_max = [2] * graph.batch_size\n\n            if isinstance(lambda_max, list):\n                lambda_max = nd.array(lambda_max).as_in_context(feat.context)\n            if lambda_max.ndim == 1:\n                lambda_max = lambda_max.expand_dims(-1)\n            # broadcast from (B, 1) to (N, 1)\n            lambda_max = broadcast_nodes(graph, lambda_max)\n            # T0(X)\n            Tx_0 = feat\n            rst = self.fc[0](Tx_0)\n            # T1(X)\n            if self._k > 1:\n                graph.ndata[\"h\"] = Tx_0 * norm\n                graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n                h = graph.ndata.pop(\"h\") * norm\n                # Λ = 2 * (I - D ^ -1/2 A D ^ -1/2) / lambda_max - I\n                #   = - 2(D ^ -1/2 A D ^ -1/2) / lambda_max + (2 / lambda_max - 1) I\n                Tx_1 = -2.0 * h / lambda_max + Tx_0 * 
(2.0 / lambda_max - 1)\n                rst = rst + self.fc[1](Tx_1)\n            # Ti(x), i = 2...k\n            for i in range(2, self._k):\n                graph.ndata[\"h\"] = Tx_1 * norm\n                graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n                h = graph.ndata.pop(\"h\") * norm\n                # Tx_k = 2 * Λ * Tx_(k-1) - Tx_(k-2)\n                #      = - 4(D ^ -1/2 A D ^ -1/2) / lambda_max Tx_(k-1) +\n                #        (4 / lambda_max - 2) Tx_(k-1) -\n                #        Tx_(k-2)\n                Tx_2 = (\n                    -4.0 * h / lambda_max + Tx_1 * (4.0 / lambda_max - 2) - Tx_0\n                )\n                rst = rst + self.fc[i](Tx_2)\n                Tx_1, Tx_0 = Tx_2, Tx_1\n            # add bias\n            if self.bias is not None:\n                rst = rst + self.bias.data(feat.context)\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/densechebconv.py",
    "content": "\"\"\"MXNet Module for DenseChebConv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport mxnet as mx\nfrom mxnet import nd\nfrom mxnet.gluon import nn\n\n\nclass DenseChebConv(nn.Block):\n    r\"\"\"Chebyshev Spectral Graph Convolution layer from `Convolutional Neural Networks on Graphs\n    with Fast Localized Spectral Filtering <https://arxiv.org/pdf/1606.09375.pdf>`__\n\n    We recommend to use this module when applying ChebConv on dense graphs.\n\n    Parameters\n    ----------\n    in_feats: int\n        Dimension of input features :math:`h_i^{(l)}`.\n    out_feats: int\n        Dimension of output features :math:`h_i^{(l+1)}`.\n    k : int\n        Chebyshev filter size.\n    activation : function, optional\n        Activation function, default is ReLu.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. Default: ``True``.\n\n    See also\n    --------\n    `ChebConv <https://docs.dgl.ai/api/python/nn.pytorch.html#chebconv>`__\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, k, bias=True):\n        super(DenseChebConv, self).__init__()\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._k = k\n        with self.name_scope():\n            self.fc = nn.Sequential()\n            for _ in range(k):\n                self.fc.add(\n                    nn.Dense(\n                        out_feats,\n                        in_units=in_feats,\n                        use_bias=False,\n                        weight_initializer=mx.init.Xavier(\n                            magnitude=math.sqrt(2.0)\n                        ),\n                    )\n                )\n            if bias:\n                self.bias = self.params.get(\n                    \"bias\", shape=(out_feats,), init=mx.init.Zero()\n                )\n            else:\n                self.bias = None\n\n    def forward(self, adj, feat, lambda_max=None):\n        
r\"\"\"\n\n        Description\n        -----------\n        Compute (Dense) Chebyshev Spectral Graph Convolution layer.\n\n        Parameters\n        ----------\n        adj : mxnet.NDArray\n            The adjacency matrix of the graph to apply Graph Convolution on,\n            should be of shape :math:`(N, N)`, where a row represents the destination\n            and a column represents the source.\n        feat : mxnet.NDArray\n            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`\n            is size of input feature, :math:`N` is the number of nodes.\n        lambda_max : float or None, optional\n            A float value indicates the largest eigenvalue of given graph.\n            Default: None.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n        A = adj.astype(feat.dtype).as_in_context(feat.context)\n        num_nodes = A.shape[0]\n\n        in_degree = 1.0 / nd.clip(A.sum(axis=1), 1, float(\"inf\")).sqrt()\n        D_invsqrt = nd.diag(in_degree)\n        I = nd.eye(num_nodes, ctx=A.context)\n        L = I - nd.dot(D_invsqrt, nd.dot(A, D_invsqrt))\n\n        if lambda_max is None:\n            # NOTE(zihao): this only works for undirected graphs, since syevd\n            # requires a symmetric matrix.\n            lambda_max = (nd.linalg.syevd(L)[1]).max()\n\n        L_hat = 2 * L / lambda_max - I\n        Z = [nd.eye(num_nodes, ctx=A.context)]\n        Zh = self.fc[0](feat)\n        for i in range(1, self._k):\n            if i == 1:\n                Z.append(L_hat)\n            else:\n                Z.append(2 * nd.dot(L_hat, Z[-1]) - Z[-2])\n            Zh = Zh + nd.dot(Z[i], self.fc[i](feat))\n\n        if self.bias is not None:\n            Zh = Zh + self.bias.data(feat.context)\n        return Zh\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/densegraphconv.py",
    "content": "\"\"\"MXNet Module for DenseGraphConv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport mxnet as mx\nfrom mxnet import nd\nfrom mxnet.gluon import nn\n\n\nclass DenseGraphConv(nn.Block):\n    \"\"\"Graph Convolutional layer from `Semi-Supervised Classification with Graph\n    Convolutional Networks <https://arxiv.org/abs/1609.02907>`__\n\n    We recommend user to use this module when applying graph convolution on\n    dense graphs.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n    out_feats : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    norm : str, optional\n        How to apply the normalizer. If is `'right'`, divide the aggregated messages\n        by each node's in-degrees, which is equivalent to averaging the received messages.\n        If is `'none'`, no normalization is applied. Default is `'both'`,\n        where the :math:`c_{ij}` in the paper is applied.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. Default: ``True``.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n\n    Notes\n    -----\n    Zero in-degree nodes will lead to all-zero output. 
A common practice\n    to avoid this is to add a self-loop for each node in the graph,\n    which can be achieved by setting the diagonal of the adjacency matrix to be 1.\n\n    See also\n    --------\n    `GraphConv <https://docs.dgl.ai/api/python/nn.pytorch.html#graphconv>`__\n    \"\"\"\n\n    def __init__(\n        self, in_feats, out_feats, norm=\"both\", bias=True, activation=None\n    ):\n        super(DenseGraphConv, self).__init__()\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._norm = norm\n        with self.name_scope():\n            self.weight = self.params.get(\n                \"weight\",\n                shape=(in_feats, out_feats),\n                init=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n            )\n            if bias:\n                self.bias = self.params.get(\n                    \"bias\", shape=(out_feats,), init=mx.init.Zero()\n                )\n            else:\n                self.bias = None\n            self._activation = activation\n\n    def forward(self, adj, feat):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute (Dense) Graph Convolution layer.\n\n        Parameters\n        ----------\n        adj : mxnet.NDArray\n            The adjacency matrix of the graph to apply Graph Convolution on, when\n            applied to a unidirectional bipartite graph, ``adj`` should be of shape\n            :math:`(N_{out}, N_{in})`; when applied to a homo\n            graph, ``adj`` should be of shape :math:`(N, N)`. 
In both cases,\n            a row represents a destination node while a column represents a source\n            node.\n        feat : mxnet.NDArray\n            The input feature.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n        adj = adj.astype(feat.dtype).as_in_context(feat.context)\n        src_degrees = nd.clip(adj.sum(axis=0), a_min=1, a_max=float(\"inf\"))\n        dst_degrees = nd.clip(adj.sum(axis=1), a_min=1, a_max=float(\"inf\"))\n        feat_src = feat\n\n        if self._norm == \"both\":\n            norm_src = nd.power(src_degrees, -0.5)\n            shp_src = norm_src.shape + (1,) * (feat.ndim - 1)\n            norm_src = norm_src.reshape(shp_src).as_in_context(feat.context)\n            feat_src = feat_src * norm_src\n\n        if self._in_feats > self._out_feats:\n            # mult W first to reduce the feature size for aggregation.\n            feat_src = nd.dot(feat_src, self.weight.data(feat_src.context))\n            rst = nd.dot(adj, feat_src)\n        else:\n            # aggregate first then mult W\n            rst = nd.dot(adj, feat_src)\n            rst = nd.dot(rst, self.weight.data(feat_src.context))\n\n        if self._norm != \"none\":\n            if self._norm == \"both\":\n                norm_dst = nd.power(dst_degrees, -0.5)\n            else:  # right\n                norm_dst = 1.0 / dst_degrees\n            shp_dst = norm_dst.shape + (1,) * (feat.ndim - 1)\n            norm_dst = norm_dst.reshape(shp_dst).as_in_context(feat.context)\n            rst = rst * norm_dst\n\n        if self.bias is not None:\n            rst = rst + self.bias.data(feat.context)\n\n        if self._activation is not None:\n            rst = self._activation(rst)\n\n        return rst\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/densesageconv.py",
    "content": "\"\"\"MXNet Module for DenseGraphSAGE\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport mxnet as mx\nfrom mxnet import nd\nfrom mxnet.gluon import nn\n\nfrom ....utils import check_eq_shape\n\n\nclass DenseSAGEConv(nn.Block):\n    \"\"\"GraphSAGE layer from `Inductive Representation Learning on Large Graphs\n    <https://arxiv.org/abs/1706.02216>`__\n\n    We recommend using this module when applying GraphSAGE on dense graphs.\n\n    Note that we only support gcn aggregator in DenseSAGEConv.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n    out_feats : int\n        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.\n    feat_drop : float, optional\n        Dropout rate on features. Default: 0.\n    bias : bool\n        If True, adds a learnable bias to the output. Default: ``True``.\n    norm : callable activation function/layer or None, optional\n        If not None, applies normalization to the updated node features.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n\n    See also\n    --------\n    `SAGEConv <https://docs.dgl.ai/api/python/nn.pytorch.html#sageconv>`__\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        feat_drop=0.0,\n        bias=True,\n        norm=None,\n        activation=None,\n    ):\n        super(DenseSAGEConv, self).__init__()\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._norm = norm\n        with self.name_scope():\n            self.feat_drop = nn.Dropout(feat_drop)\n            self.activation = activation\n            self.fc = nn.Dense(\n                out_feats,\n                in_units=in_feats,\n                use_bias=bias,\n                
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n            )\n\n    def forward(self, adj, feat):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute (Dense) Graph SAGE layer.\n\n        Parameters\n        ----------\n        adj : mxnet.NDArray\n            The adjacency matrix of the graph to apply SAGE Convolution on, when\n            applied to a unidirectional bipartite graph, ``adj`` should be of shape\n            :math:`(N_{out}, N_{in})`; when applied to a homo\n            graph, ``adj`` should be of shape :math:`(N, N)`. In both cases,\n            a row represents a destination node while a column represents a source\n            node.\n        feat : mxnet.NDArray or a pair of mxnet.NDArray\n            If a mxnet.NDArray is given, the input feature of shape :math:`(N, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n        check_eq_shape(feat)\n        if isinstance(feat, tuple):\n            feat_src = self.feat_drop(feat[0])\n            feat_dst = self.feat_drop(feat[1])\n        else:\n            feat_src = feat_dst = self.feat_drop(feat)\n        adj = adj.astype(feat_src.dtype).as_in_context(feat_src.context)\n        in_degrees = adj.sum(axis=1, keepdims=True)\n        h_neigh = (nd.dot(adj, feat_src) + feat_dst) / (in_degrees + 1)\n        rst = self.fc(h_neigh)\n        # activation\n        if self.activation is not None:\n            rst = self.activation(rst)\n        # normalization\n        if self._norm is not None:\n            rst = self._norm(rst)\n\n        return rst\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/edgeconv.py",
    "content": "\"\"\"MXNet Module for EdgeConv Layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport mxnet as mx\nfrom mxnet.gluon import nn\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\n\n\nclass EdgeConv(nn.Block):\n    r\"\"\"EdgeConv layer from `Dynamic Graph CNN for Learning on Point Clouds\n    <https://arxiv.org/pdf/1801.07829>`__\n\n    It can be described as follows:\n\n    .. math::\n       h_i^{(l+1)} = \\max_{j \\in \\mathcal{N}(i)} (\n       \\Theta \\cdot (h_j^{(l)} - h_i^{(l)}) + \\Phi \\cdot h_i^{(l)})\n\n    where :math:`\\mathcal{N}(i)` is the neighbor of :math:`i`.\n    :math:`\\Theta` and :math:`\\Phi` are linear layers.\n\n    .. note::\n\n       The original formulation includes a ReLU inside the maximum operator.\n       This is equivalent to first applying a maximum operator then applying\n       the ReLU.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n    out_feat : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    batch_norm : bool\n        Whether to include batch normalization on messages. Default: ``False``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Default: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. 
This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practice to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from mxnet import gluon\n    >>> from dgl.nn import EdgeConv\n    >>>\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = mx.nd.ones((6, 10))\n    >>> conv = EdgeConv(10, 2)\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, feat)\n    >>> res\n    [[1.0517545 0.8091326]\n    [1.0517545 0.8091326]\n    [1.0517545 0.8091326]\n    [1.0517545 0.8091326]\n    [1.0517545 0.8091326]\n    [1.0517545 0.8091326]]\n    <NDArray 6x2 @cpu(0)>\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.bipartite((u, v))\n    >>> u_fea = mx.nd.random.randn(2, 5)\n    >>> v_fea = mx.nd.random.randn(4, 5)\n    >>> conv = EdgeConv(5, 2, 3)\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, (u_fea, v_fea))\n    >>> res\n    [[-3.4617817   0.84700686]\n    [ 1.3170856  -1.5731761 ]\n    [-2.0761423   0.56653017]\n    [-1.015364    0.78919804]]\n    <NDArray 4x2 @cpu(0)>\n    \"\"\"\n\n    def __init__(\n        self, in_feat, out_feat, batch_norm=False, allow_zero_in_degree=False\n    
):\n        super(EdgeConv, self).__init__()\n        self.batch_norm = batch_norm\n        self._allow_zero_in_degree = allow_zero_in_degree\n\n        with self.name_scope():\n            self.theta = nn.Dense(\n                out_feat, in_units=in_feat, weight_initializer=mx.init.Xavier()\n            )\n            self.phi = nn.Dense(\n                out_feat, in_units=in_feat, weight_initializer=mx.init.Xavier()\n            )\n\n            if batch_norm:\n                self.bn = nn.BatchNorm(in_channels=out_feat)\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, g, h):\n        \"\"\"\n\n        Description\n        -----------\n        Forward computation\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        h : mxnet.NDArray or pair of mxnet.NDArray\n            The input feature of shape :math:`(N, D)` where :math:`N` is the number of nodes and\n            :math:`D` is the number of feature dimensions.\n\n            If a pair of mxnet.NDArray is given, the graph must be a uni-bipartite graph\n            with only one edge type, and the two tensors must have the same\n            dimensionality on all except the first axis.\n\n        Returns\n        -------\n        mxnet.NDArray\n            New node features.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. 
This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        with g.local_scope():\n            if not self._allow_zero_in_degree:\n                if g.in_degrees().min() == 0:\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            h_src, h_dst = expand_as_pair(h, g)\n            g.srcdata[\"x\"] = h_src\n            g.dstdata[\"x\"] = h_dst\n            g.apply_edges(fn.v_sub_u(\"x\", \"x\", \"theta\"))\n            g.edata[\"theta\"] = self.theta(g.edata[\"theta\"])\n            g.dstdata[\"phi\"] = self.phi(g.dstdata[\"x\"])\n            if not self.batch_norm:\n                g.update_all(fn.e_add_v(\"theta\", \"phi\", \"e\"), fn.max(\"e\", \"x\"))\n            else:\n                g.apply_edges(fn.e_add_v(\"theta\", \"phi\", \"e\"))\n                g.edata[\"e\"] = self.bn(g.edata[\"e\"])\n                g.update_all(fn.copy_e(\"e\", \"m\"), fn.max(\"m\", \"x\"))\n            return g.dstdata[\"x\"]\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/gatconv.py",
    "content": "\"\"\"MXNet modules for graph attention networks(GAT).\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport mxnet as mx\nfrom mxnet.gluon import nn\nfrom mxnet.gluon.contrib.nn import Identity\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\nfrom ...functional import edge_softmax\n\n\n# pylint: enable=W0235\nclass GATConv(nn.Block):\n    r\"\"\"Graph attention layer from `Graph Attention Network\n    <https://arxiv.org/pdf/1710.10903.pdf>`__\n\n    .. math::\n        h_i^{(l+1)} = \\sum_{j\\in \\mathcal{N}(i)} \\alpha_{i,j} W^{(l)} h_j^{(l)}\n\n    where :math:`\\alpha_{ij}` is the attention score between node :math:`i` and\n    node :math:`j`:\n\n    .. math::\n        \\alpha_{ij}^{l} &= \\mathrm{softmax_i} (e_{ij}^{l})\n\n        e_{ij}^{l} &= \\mathrm{LeakyReLU}\\left(\\vec{a}^T [W h_{i} \\| W h_{j}]\\right)\n\n    Parameters\n    ----------\n    in_feats : int, or pair of ints\n        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n        GATConv can be applied on homogeneous graph and unidirectional\n        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.\n        If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``\n        specifies the input feature size on both the source and destination nodes.  If\n        a scalar is given, the source and destination node feature size would take the\n        same value.\n    out_feats : int\n        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.\n    num_heads : int\n        Number of heads in Multi-Head Attention.\n    feat_drop : float, optional\n        Dropout rate on feature. Defaults: ``0``.\n    attn_drop : float, optional\n        Dropout rate on attention weight. Defaults: ``0``.\n    negative_slope : float, optional\n        LeakyReLU angle of negative slope. 
Defaults: ``0.2``.\n    residual : bool, optional\n        If True, use residual connection. Defaults: ``False``.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Defaults: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. 
Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from mxnet import gluon\n    >>> from dgl.nn import GATConv\n    >>>\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = mx.nd.ones((6, 10))\n    >>> gatconv = GATConv(10, 2, num_heads=3)\n    >>> gatconv.initialize(ctx=mx.cpu(0))\n    >>> res = gatconv(g, feat)\n    >>> res\n    [[[ 0.32368395 -0.10501936]\n    [ 1.0839728   0.92690575]\n    [-0.54581136 -0.84279203]]\n    [[ 0.32368395 -0.10501936]\n    [ 1.0839728   0.92690575]\n    [-0.54581136 -0.84279203]]\n    [[ 0.32368395 -0.10501936]\n    [ 1.0839728   0.92690575]\n    [-0.54581136 -0.84279203]]\n    [[ 0.32368395 -0.10501937]\n    [ 1.0839728   0.9269058 ]\n    [-0.5458114  -0.8427921 ]]\n    [[ 0.32368395 -0.10501936]\n    [ 1.0839728   0.92690575]\n    [-0.54581136 -0.84279203]]\n    [[ 0.32368395 -0.10501936]\n    [ 1.0839728   0.92690575]\n    [-0.54581136 -0.84279203]]]\n    <NDArray 6x3x2 @cpu(0)>\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})\n    >>> u_feat = mx.nd.random.randn(2, 5)\n    >>> v_feat = mx.nd.random.randn(4, 10)\n    >>> gatconv = GATConv((5,10), 2, 3)\n    >>> gatconv.initialize(ctx=mx.cpu(0))\n    >>> res = gatconv(g, (u_feat, v_feat))\n    >>> res\n    [[[-1.01624     1.8138596 ]\n    [ 1.2322129  -0.8410206 ]\n    [-1.9325689   1.3824553 ]]\n    [[ 0.9915016  -1.6564168 ]\n    [-0.32610354  0.42505783]\n    [ 1.5278397  -0.92114615]]\n    [[-0.32592064  0.62067866]\n    [ 0.6162219  -0.3405491 ]\n    [-1.356375    0.9988818 ]]\n  
  [[-1.01624     1.8138596 ]\n    [ 1.2322129  -0.8410206 ]\n    [-1.9325689   1.3824553 ]]]\n    <NDArray 4x3x2 @cpu(0)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        num_heads,\n        feat_drop=0.0,\n        attn_drop=0.0,\n        negative_slope=0.2,\n        residual=False,\n        activation=None,\n        allow_zero_in_degree=False,\n    ):\n        super(GATConv, self).__init__()\n        self._num_heads = num_heads\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._allow_zero_in_degree = allow_zero_in_degree\n        with self.name_scope():\n            if isinstance(in_feats, tuple):\n                self.fc_src = nn.Dense(\n                    out_feats * num_heads,\n                    use_bias=False,\n                    weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n                    in_units=self._in_src_feats,\n                )\n                self.fc_dst = nn.Dense(\n                    out_feats * num_heads,\n                    use_bias=False,\n                    weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n                    in_units=self._in_dst_feats,\n                )\n            else:\n                self.fc = nn.Dense(\n                    out_feats * num_heads,\n                    use_bias=False,\n                    weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n                    in_units=in_feats,\n                )\n            self.attn_l = self.params.get(\n                \"attn_l\",\n                shape=(1, num_heads, out_feats),\n                init=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n            )\n            self.attn_r = self.params.get(\n                \"attn_r\",\n                shape=(1, num_heads, out_feats),\n                init=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n            )\n       
     self.feat_drop = nn.Dropout(feat_drop)\n            self.attn_drop = nn.Dropout(attn_drop)\n            self.leaky_relu = nn.LeakyReLU(negative_slope)\n            if residual:\n                if in_feats != out_feats:\n                    self.res_fc = nn.Dense(\n                        out_feats * num_heads,\n                        use_bias=False,\n                        weight_initializer=mx.init.Xavier(\n                            magnitude=math.sqrt(2.0)\n                        ),\n                        in_units=in_feats,\n                    )\n                else:\n                    self.res_fc = Identity()\n            else:\n                self.res_fc = None\n            self.activation = activation\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat, get_attention=False):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute graph attention network layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray or pair of mxnet.NDArray\n            If a mxnet.NDArray is given, the input feature of shape :math:`(N, *, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.\n        get_attention : bool, optional\n            Whether to return the attention values. 
Default to False.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`\n            is the number of heads, and :math:`D_{out}` is size of output feature.\n        mxnet.NDArray, optional\n            The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of\n            edges. This is returned only when :attr:`get_attention` is ``True``.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if graph.in_degrees().min() == 0:\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. 
Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            if isinstance(feat, tuple):\n                src_prefix_shape = feat[0].shape[:-1]\n                dst_prefix_shape = feat[1].shape[:-1]\n                feat_dim = feat[0].shape[-1]\n                h_src = self.feat_drop(feat[0])\n                h_dst = self.feat_drop(feat[1])\n                if not hasattr(self, \"fc_src\"):\n                    self.fc_src, self.fc_dst = self.fc, self.fc\n                feat_src = self.fc_src(h_src.reshape(-1, feat_dim)).reshape(\n                    *src_prefix_shape, self._num_heads, self._out_feats\n                )\n                feat_dst = self.fc_dst(h_dst.reshape(-1, feat_dim)).reshape(\n                    *dst_prefix_shape, self._num_heads, self._out_feats\n                )\n            else:\n                src_prefix_shape = dst_prefix_shape = feat.shape[:-1]\n                feat_dim = feat[0].shape[-1]\n                h_src = h_dst = self.feat_drop(feat)\n                feat_src = feat_dst = self.fc(\n                    h_src.reshape(-1, feat_dim)\n                ).reshape(*src_prefix_shape, self._num_heads, self._out_feats)\n                if graph.is_block:\n                    feat_dst = feat_src[: graph.number_of_dst_nodes()]\n                    h_dst = h_dst[: graph.number_of_dst_nodes()]\n                    dst_prefix_shape = (\n                        graph.number_of_dst_nodes(),\n                    ) + dst_prefix_shape[1:]\n            # NOTE: GAT paper uses \"first concatenation then linear projection\"\n            # to compute attention scores, while ours is \"first projection then\n            # addition\", the two approaches are mathematically equivalent:\n            # We decompose the weight vector a mentioned in the paper into\n            # [a_l || a_r], then\n    
        # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j\n            # Our implementation is much efficient because we do not need to\n            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,\n            # addition could be optimized with DGL's built-in function u_add_v,\n            # which further speeds up computation and saves memory footprint.\n            el = (\n                (feat_src * self.attn_l.data(feat_src.context))\n                .sum(axis=-1)\n                .expand_dims(-1)\n            )\n            er = (\n                (feat_dst * self.attn_r.data(feat_src.context))\n                .sum(axis=-1)\n                .expand_dims(-1)\n            )\n            graph.srcdata.update({\"ft\": feat_src, \"el\": el})\n            graph.dstdata.update({\"er\": er})\n            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.\n            graph.apply_edges(fn.u_add_v(\"el\", \"er\", \"e\"))\n            e = self.leaky_relu(graph.edata.pop(\"e\"))\n            # compute softmax\n            graph.edata[\"a\"] = self.attn_drop(edge_softmax(graph, e))\n            graph.update_all(fn.u_mul_e(\"ft\", \"a\", \"m\"), fn.sum(\"m\", \"ft\"))\n            rst = graph.dstdata[\"ft\"]\n            # residual\n            if self.res_fc is not None:\n                resval = self.res_fc(h_dst.reshape(-1, feat_dim)).reshape(\n                    *dst_prefix_shape, -1, self._out_feats\n                )\n                rst = rst + resval\n            # activation\n            if self.activation:\n                rst = self.activation(rst)\n\n            if get_attention:\n                return rst, graph.edata[\"a\"]\n            else:\n                return rst\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/gatedgraphconv.py",
    "content": "\"\"\"MXNet Module for Gated Graph Convolution layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, cell-var-from-loop\nimport mxnet as mx\nfrom mxnet import gluon, nd\nfrom mxnet.gluon import nn\n\nfrom .... import function as fn\n\n\nclass GatedGraphConv(nn.Block):\n    r\"\"\"Gated Graph Convolution layer from `Gated Graph Sequence\n    Neural Networks <https://arxiv.org/pdf/1511.05493.pdf>`__\n\n    .. math::\n        h_{i}^{0} &= [ x_i \\| \\mathbf{0} ]\n\n        a_{i}^{t} &= \\sum_{j\\in\\mathcal{N}(i)} W_{e_{ij}} h_{j}^{t}\n\n        h_{i}^{t+1} &= \\mathrm{GRU}(a_{i}^{t}, h_{i}^{t})\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`x_i`.\n    out_feats : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(t+1)}`.\n    n_steps : int\n        Number of recurrent steps; i.e, the :math:`t` in the above formula.\n    n_etypes : int\n        Number of edge types.\n    bias : bool\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n        Can only be set to True in MXNet.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from dgl.nn import GatedGraphConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = mx.nd.ones((6, 10))\n    >>> conv = GatedGraphConv(10, 10, 2, 3)\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> etype = mx.nd.array([0,1,2,0,1,2])\n    >>> res = conv(g, feat, etype)\n    >>> res\n    [[0.24378185 0.17402579 0.2644723  0.2740628  0.14041871 0.32523093\n    0.2703067  0.18234392 0.32777587 0.30957845]\n    [0.17872348 0.28878236 0.2509409  0.20139427 0.3355541  0.22643831\n    0.2690711  0.22341749 0.27995753 0.21575949]\n    [0.23911178 0.16696918 0.26120248 0.27397877 0.13745922 0.3223175\n    0.27561218 0.18071817 0.3251124  0.30608907]\n    [0.25242943 0.3098581  0.25249368 0.27968448 0.24624602 0.12270881\n    0.335147   0.31550157 0.19065917 0.21087633]\n    [0.17503153 0.29523152 0.2474858  0.20848347 0.3526433  0.23443702\n    0.24741334 0.21986549 0.28935105 0.21859099]\n    [0.2159364  0.26942077 0.23083271 0.28329757 0.24758333 0.24230732\n    0.23958017 0.23430146 0.26431587 0.27001363]]\n    <NDArray 6x10 @cpu(0)>\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, n_steps, n_etypes, bias=True):\n        super(GatedGraphConv, self).__init__()\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._n_steps = n_steps\n        self._n_etypes = n_etypes\n        if not bias:\n            raise KeyError(\"MXNet do not support disabling bias in GRUCell.\")\n        with self.name_scope():\n            self.linears = nn.Sequential()\n            for _ in range(n_etypes):\n                self.linears.add(\n                    nn.Dense(\n                        out_feats,\n                        weight_initializer=mx.init.Xavier(),\n                        in_units=out_feats,\n                    )\n                )\n   
         self.gru = gluon.rnn.GRUCell(out_feats, input_size=out_feats)\n\n    def forward(self, graph, feat, etypes):\n        \"\"\"Compute Gated Graph Convolution layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            The input feature of shape :math:`(N, D_{in})` where :math:`N`\n            is the number of nodes of the graph and :math:`D_{in}` is the\n            input feature size.\n        etypes : mxnet.NDArray\n            The edge type tensor of shape :math:`(E,)` where :math:`E` is\n            the number of edges of the graph.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is the output feature size.\n        \"\"\"\n        with graph.local_scope():\n            assert graph.is_homogeneous, (\n                \"not a homogeneous graph; convert it with to_homogeneous \"\n                \"and pass in the edge type as argument\"\n            )\n            zero_pad = nd.zeros(\n                (feat.shape[0], self._out_feats - feat.shape[1]),\n                ctx=feat.context,\n            )\n            feat = nd.concat(feat, zero_pad, dim=-1)\n\n            for _ in range(self._n_steps):\n                graph.ndata[\"h\"] = feat\n                for i in range(self._n_etypes):\n                    eids = (etypes.asnumpy() == i).nonzero()[0]\n                    eids = (\n                        nd.from_numpy(eids, zero_copy=True)\n                        .as_in_context(feat.context)\n                        .astype(graph.idtype)\n                    )\n                    if len(eids) > 0:\n                        graph.apply_edges(\n                            lambda edges: {\n                                \"W_e*h\": self.linears[i](edges.src[\"h\"])\n                            },\n                            eids,\n                        )\n          
      graph.update_all(fn.copy_e(\"W_e*h\", \"m\"), fn.sum(\"m\", \"a\"))\n                a = graph.ndata.pop(\"a\")\n                feat = self.gru(a, [feat])[0]\n            return feat\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/ginconv.py",
    "content": "\"\"\"MXNet Module for Graph Isomorphism Network layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport mxnet as mx\nfrom mxnet.gluon import nn\n\nfrom .... import function as fn\nfrom ....utils import expand_as_pair\n\n\nclass GINConv(nn.Block):\n    r\"\"\"Graph Isomorphism layer from `How Powerful are Graph\n    Neural Networks? <https://arxiv.org/pdf/1810.00826.pdf>`__\n\n    .. math::\n        h_i^{(l+1)} = f_\\Theta \\left((1 + \\epsilon) h_i^{l} +\n        \\mathrm{aggregate}\\left(\\left\\{h_j^{l}, j\\in\\mathcal{N}(i)\n        \\right\\}\\right)\\right)\n\n    Parameters\n    ----------\n    apply_func : callable activation function/layer or None\n        If not None, apply this function to the updated node feature,\n        the :math:`f_\\Theta` in the formula.\n    aggregator_type : str\n        Aggregator type to use (``sum``, ``max`` or ``mean``).\n    init_eps : float, optional\n        Initial :math:`\\epsilon` value, default: ``0``.\n    learn_eps : bool, optional\n        If True, :math:`\\epsilon` will be a learnable parameter. 
Default: ``False``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from mxnet import gluon\n    >>> from dgl.nn import GINConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = mx.nd.ones((6, 10))\n    >>> lin = gluon.nn.Dense(10)\n    >>> lin.initialize(ctx=mx.cpu(0))\n    >>> conv = GINConv(lin, 'max')\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, feat)\n    >>> res\n    [[ 0.44832918 -0.05283341  0.20823681  0.16020004  0.37311912 -0.03372726\n    -0.05716725 -0.20730163  0.14121324  0.46083626]\n    [ 0.44832918 -0.05283341  0.20823681  0.16020004  0.37311912 -0.03372726\n    -0.05716725 -0.20730163  0.14121324  0.46083626]\n    [ 0.44832918 -0.05283341  0.20823681  0.16020004  0.37311912 -0.03372726\n    -0.05716725 -0.20730163  0.14121324  0.46083626]\n    [ 0.44832918 -0.05283341  0.20823681  0.16020004  0.37311912 -0.03372726\n    -0.05716725 -0.20730163  0.14121324  0.46083626]\n    [ 0.44832918 -0.05283341  0.20823681  0.16020004  0.37311912 -0.03372726\n    -0.05716725 -0.20730163  0.14121324  0.46083626]\n    [ 0.22416459 -0.0264167   0.10411841  0.08010002  0.18655956 -0.01686363\n    -0.02858362 -0.10365082  0.07060662  0.23041813]]\n    <NDArray 6x10 @cpu(0)>\n    \"\"\"\n\n    def __init__(\n        self, apply_func, aggregator_type, init_eps=0, learn_eps=False\n    ):\n        super(GINConv, self).__init__()\n        if aggregator_type == \"sum\":\n            self._reducer = fn.sum\n        elif aggregator_type == \"max\":\n            self._reducer = fn.max\n        elif aggregator_type == \"mean\":\n            self._reducer = fn.mean\n        else:\n            raise KeyError(\n                \"Aggregator type {} not recognized.\".format(aggregator_type)\n            )\n\n        with self.name_scope():\n            self.apply_func = apply_func\n            self.eps = self.params.get(\n                \"eps\",\n                
shape=(1,),\n                grad_req=\"write\" if learn_eps else \"null\",\n                init=mx.init.Constant(init_eps),\n            )\n\n    def forward(self, graph, feat):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute Graph Isomorphism Network layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray or a pair of mxnet.NDArray\n            If a mxnet.NDArray is given, the input feature of shape :math:`(N, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.\n            If ``apply_func`` is not None, :math:`D_{in}` should\n            fit the input dimensionality requirement of ``apply_func``.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, D_{out})` where\n            :math:`D_{out}` is the output dimensionality of ``apply_func``.\n            If ``apply_func`` is None, :math:`D_{out}` should be the same\n            as input dimensionality.\n        \"\"\"\n        with graph.local_scope():\n            feat_src, feat_dst = expand_as_pair(feat, graph)\n            graph.srcdata[\"h\"] = feat_src\n            graph.update_all(fn.copy_u(\"h\", \"m\"), self._reducer(\"m\", \"neigh\"))\n            rst = (\n                1 + self.eps.data(feat_dst.context)\n            ) * feat_dst + graph.dstdata[\"neigh\"]\n            if self.apply_func is not None:\n                rst = self.apply_func(rst)\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/gmmconv.py",
    "content": "\"\"\"MXNet Module for GMM Conv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport mxnet as mx\nfrom mxnet import nd\nfrom mxnet.gluon import nn\nfrom mxnet.gluon.contrib.nn import Identity\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\n\n\nclass GMMConv(nn.Block):\n    r\"\"\"Gaussian Mixture Model Convolution layer from `Geometric Deep Learning on Graphs and\n    Manifolds using Mixture Model CNNs <https://arxiv.org/abs/1611.08402>`__\n\n    .. math::\n        u_{ij} &= f(x_i, x_j), x_j \\in \\mathcal{N}(i)\n\n        w_k(u) &= \\exp\\left(-\\frac{1}{2}(u-\\mu_k)^T \\Sigma_k^{-1} (u - \\mu_k)\\right)\n\n        h_i^{l+1} &= \\mathrm{aggregate}\\left(\\left\\{\\frac{1}{K}\n         \\sum_{k}^{K} w_k(u_{ij}), \\forall j\\in \\mathcal{N}(i)\\right\\}\\right)\n\n    where :math:`u` denotes the pseudo-coordinates between a vertex and one of its neighbors,\n    computed using function :math:`f`, :math:`\\Sigma_k^{-1}` and :math:`\\mu_k` are\n    learnable parameters representing the covariance matrix and mean vector of a Gaussian kernel.\n\n    Parameters\n    ----------\n    in_feats : int\n        Number of input features; i.e., the number of dimensions of :math:`x_i`.\n    out_feats : int\n        Number of output features; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    dim : int\n        Dimensionality of pseudo-coordinate; i.e., the number of dimensions of :math:`u_{ij}`.\n    n_kernels : int\n        Number of kernels :math:`K`.\n    aggregator_type : str\n        Aggregator type (``sum``, ``mean``, ``max``). Default: ``sum``.\n    residual : bool\n        If True, use residual connection inside this layer. Default: ``False``.\n    bias : bool\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Default: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. 
Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from dgl.nn import GMMConv\n    >>>\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = mx.nd.ones((6, 10))\n    >>> conv = GMMConv(10, 2, 3, 2, 'mean')\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> pseudo = mx.nd.ones((12, 3))\n    >>> res = conv(g, feat, pseudo)\n    >>> res\n    [[-0.05083769 -0.1567954 ]\n    [-0.05083769 -0.1567954 ]\n    [-0.05083769 -0.1567954 ]\n    [-0.05083769 -0.1567954 ]\n    [-0.05083769 -0.1567954 ]\n    [-0.05083769 -0.1567954 ]]\n    <NDArray 6x2 @cpu(0)>\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})\n    >>> u_fea = mx.nd.random.randn(2, 5)\n    >>> v_fea = mx.nd.random.randn(4, 10)\n    >>> pseudo = mx.nd.ones((5, 3))\n    >>> conv = GMMConv((5, 10), 2, 3, 2, 'mean')\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, (u_fea, v_fea), pseudo)\n    >>> res\n    [[-0.1005067  -0.09494358]\n    [-0.0023314  -0.07597432]\n    [-0.05141905 -0.08545895]\n    [-0.1005067  -0.09494358]]\n    <NDArray 4x2 @cpu(0)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        dim,\n        n_kernels,\n        aggregator_type=\"sum\",\n        residual=False,\n        bias=True,\n        allow_zero_in_degree=False,\n    ):\n        super(GMMConv, self).__init__()\n\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        self._dim = dim\n        self._n_kernels = n_kernels\n        
self._allow_zero_in_degree = allow_zero_in_degree\n        if aggregator_type == \"sum\":\n            self._reducer = fn.sum\n        elif aggregator_type == \"mean\":\n            self._reducer = fn.mean\n        elif aggregator_type == \"max\":\n            self._reducer = fn.max\n        else:\n            raise KeyError(\n                \"Aggregator type {} not recognized.\".format(aggregator_type)\n            )\n\n        with self.name_scope():\n            self.mu = self.params.get(\n                \"mu\", shape=(n_kernels, dim), init=mx.init.Normal(0.1)\n            )\n            self.inv_sigma = self.params.get(\n                \"inv_sigma\", shape=(n_kernels, dim), init=mx.init.Constant(1)\n            )\n            self.fc = nn.Dense(\n                n_kernels * out_feats,\n                in_units=self._in_src_feats,\n                use_bias=False,\n                weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n            )\n            if residual:\n                if self._in_dst_feats != out_feats:\n                    self.res_fc = nn.Dense(\n                        out_feats, in_units=self._in_dst_feats, use_bias=False\n                    )\n                else:\n                    self.res_fc = Identity()\n            else:\n                self.res_fc = None\n\n            if bias:\n                self.bias = self.params.get(\n                    \"bias\", shape=(out_feats,), init=mx.init.Zero()\n                )\n            else:\n                self.bias = None\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat, pseudo):\n        \"\"\"\n\n        Description\n        -----------\n        Compute 
Gaussian Mixture Model Convolution layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            If a single tensor is given, the input feature of shape :math:`(N, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of tensors are given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.\n        pseudo : mxnet.NDArray\n            The pseudo coordinate tensor of shape :math:`(E, D_{u})` where\n            :math:`E` is the number of edges of the graph and :math:`D_{u}`\n            is the dimensionality of pseudo coordinate.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is the output feature size.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        if not self._allow_zero_in_degree:\n            if graph.in_degrees().min() == 0:\n                raise DGLError(\n                    \"There are 0-in-degree nodes in the graph, \"\n                    \"output for those nodes will be invalid. \"\n                    \"This is harmful for some applications, \"\n                    \"causing silent performance regression. \"\n                    \"Adding self-loop on the input graph by \"\n                    \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                    \"the issue. 
Setting ``allow_zero_in_degree`` \"\n                    \"to be `True` when constructing this module will \"\n                    \"suppress the check and let the code run.\"\n                )\n\n        feat_src, feat_dst = expand_as_pair(feat, graph)\n        with graph.local_scope():\n            graph.srcdata[\"h\"] = self.fc(feat_src).reshape(\n                -1, self._n_kernels, self._out_feats\n            )\n            E = graph.num_edges()\n            # compute gaussian weight\n            gaussian = -0.5 * (\n                (\n                    pseudo.reshape(E, 1, self._dim)\n                    - self.mu.data(feat_src.context).reshape(\n                        1, self._n_kernels, self._dim\n                    )\n                )\n                ** 2\n            )\n            gaussian = gaussian * (\n                self.inv_sigma.data(feat_src.context).reshape(\n                    1, self._n_kernels, self._dim\n                )\n                ** 2\n            )\n            gaussian = nd.exp(gaussian.sum(axis=-1, keepdims=True))  # (E, K, 1)\n            graph.edata[\"w\"] = gaussian\n            graph.update_all(fn.u_mul_e(\"h\", \"w\", \"m\"), self._reducer(\"m\", \"h\"))\n            rst = graph.dstdata[\"h\"].sum(1)\n            # residual connection\n            if self.res_fc is not None:\n                rst = rst + self.res_fc(feat_dst)\n            # bias\n            if self.bias is not None:\n                rst = rst + self.bias.data(feat_dst.context)\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/graphconv.py",
    "content": "\"\"\"MXNet modules for graph convolutions(GCN)\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport mxnet as mx\nfrom mxnet import gluon\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\n\n\nclass GraphConv(gluon.Block):\n    r\"\"\"Graph convolutional layer from `Semi-Supervised Classification with Graph Convolutional\n    Networks <https://arxiv.org/abs/1609.02907>`__\n\n    Mathematically it is defined as follows:\n\n    .. math::\n      h_i^{(l+1)} = \\sigma(b^{(l)} + \\sum_{j\\in\\mathcal{N}(i)}\\frac{1}{c_{ij}}h_j^{(l)}W^{(l)})\n\n    where :math:`\\mathcal{N}(i)` is the set of neighbors of node :math:`i`,\n    :math:`c_{ij}` is the product of the square root of node degrees\n    (i.e.,  :math:`c_{ij} = \\sqrt{|\\mathcal{N}(i)|}\\sqrt{|\\mathcal{N}(j)|}`),\n    and :math:`\\sigma` is an activation function.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n    out_feats : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    norm : str, optional\n        How to apply the normalizer.  Can be one of the following values:\n\n        * ``right``, to divide the aggregated messages by each node's in-degrees,\n          which is equivalent to averaging the received messages.\n\n        * ``none``, where no normalization is applied.\n\n        * ``both`` (default), where the messages are scaled with :math:`1/c_{ji}` above, equivalent\n          to symmetric normalization.\n\n        * ``left``, to divide the messages sent out from each node by its out-degrees,\n          equivalent to random walk normalization.\n    weight : bool, optional\n        If True, apply a linear layer. Otherwise, aggregating the messages\n        without a weight matrix.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Default: ``False``.\n\n    Attributes\n    ----------\n    weight : mxnet.gluon.Parameter\n        The learnable weight parameter.\n    bias : mxnet.gluon.Parameter\n        The learnable bias parameter.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. 
Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import mxnet as mx\n    >>> from mxnet import gluon\n    >>> import numpy as np\n    >>> from dgl.nn import GraphConv\n\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = mx.nd.ones((6, 10))\n    >>> conv = GraphConv(10, 2, norm='both', weight=True, bias=True)\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, feat)\n    >>> print(res)\n    [[1.0209361  0.22472616]\n    [1.1240715  0.24742813]\n    [1.0209361  0.22472616]\n    [1.2924911  0.28450024]\n    [1.3568745  0.29867214]\n    [0.7948386  0.17495811]]\n    <NDArray 6x2 @cpu(0)>\n\n    >>> # allow_zero_in_degree example\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> conv = GraphConv(10, 2, norm='both', weight=True, bias=True, allow_zero_in_degree=True)\n    >>> res = conv(g, feat)\n    >>> print(res)\n    [[1.0209361  0.22472616]\n    [1.1240715  0.24742813]\n    [1.0209361  0.22472616]\n    [1.2924911  0.28450024]\n    [1.3568745  0.29867214]\n    [0.  
0.]]\n    <NDArray 6x2 @cpu(0)>\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})\n    >>> u_fea = mx.nd.random.randn(2, 5)\n    >>> v_fea = mx.nd.random.randn(4, 5)\n    >>> conv = GraphConv(5, 2, norm='both', weight=True, bias=True)\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, (u_fea, v_fea))\n    >>> res\n    [[ 0.26967263  0.308129  ]\n    [ 0.05143356 -0.11355402]\n    [ 0.22705637  0.1375853 ]\n    [ 0.26967263  0.308129  ]]\n    <NDArray 4x2 @cpu(0)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        norm=\"both\",\n        weight=True,\n        bias=True,\n        activation=None,\n        allow_zero_in_degree=False,\n    ):\n        super(GraphConv, self).__init__()\n        if norm not in (\"none\", \"both\", \"right\", \"left\"):\n            raise DGLError(\n                'Invalid norm value. Must be either \"none\", \"both\", \"right\" or \"left\".'\n                ' But got \"{}\".'.format(norm)\n            )\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._norm = norm\n        self._allow_zero_in_degree = allow_zero_in_degree\n\n        with self.name_scope():\n            if weight:\n                self.weight = self.params.get(\n                    \"weight\",\n                    shape=(in_feats, out_feats),\n                    init=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n                )\n            else:\n                self.weight = None\n\n            if bias:\n                self.bias = self.params.get(\n                    \"bias\", shape=(out_feats,), init=mx.init.Zero()\n                )\n            else:\n                self.bias = None\n\n        self._activation = activation\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set 
allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat, weight=None):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute graph convolution.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray or pair of mxnet.NDArray\n            If a single tensor is given, it represents the input feature of shape\n            :math:`(N, D_{in})`\n            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of tensors are given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.\n\n            Note that in the special case of graph convolutional networks, if a pair of\n            tensors is given, the latter element will not participate in computation.\n        weight : mxnet.NDArray, optional\n            Optional external weight tensor.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. 
This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n\n        Note\n        ----\n        * Input shape: :math:`(N, *, \\text{in_feats})` where * means any number of additional\n          dimensions, :math:`N` is the number of nodes.\n        * Output shape: :math:`(N, *, \\text{out_feats})` where all but the last dimension are\n          the same shape as the input.\n        * Weight shape: :math:`(\\text{in_feats}, \\text{out_feats})`.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if graph.in_degrees().min() == 0:\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. 
Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            feat_src, feat_dst = expand_as_pair(feat, graph)\n            if self._norm in [\"both\", \"left\"]:\n                degs = (\n                    graph.out_degrees()\n                    .as_in_context(feat_dst.context)\n                    .astype(\"float32\")\n                )\n                degs = mx.nd.clip(degs, a_min=1, a_max=float(\"inf\"))\n                if self._norm == \"both\":\n                    norm = mx.nd.power(degs, -0.5)\n                else:\n                    norm = 1.0 / degs\n                shp = norm.shape + (1,) * (feat_src.ndim - 1)\n                norm = norm.reshape(shp)\n                feat_src = feat_src * norm\n\n            if weight is not None:\n                if self.weight is not None:\n                    raise DGLError(\n                        \"External weight is provided while at the same time the\"\n                        \" module has defined its own weight parameter. 
Please\"\n                        \" create the module with flag weight=False.\"\n                    )\n            else:\n                weight = self.weight.data(feat_src.context)\n\n            if self._in_feats > self._out_feats:\n                # mult W first to reduce the feature size for aggregation.\n                if weight is not None:\n                    feat_src = mx.nd.dot(feat_src, weight)\n                graph.srcdata[\"h\"] = feat_src\n                graph.update_all(\n                    fn.copy_u(u=\"h\", out=\"m\"), fn.sum(msg=\"m\", out=\"h\")\n                )\n                rst = graph.dstdata.pop(\"h\")\n            else:\n                # aggregate first then mult W\n                graph.srcdata[\"h\"] = feat_src\n                graph.update_all(\n                    fn.copy_u(u=\"h\", out=\"m\"), fn.sum(msg=\"m\", out=\"h\")\n                )\n                rst = graph.dstdata.pop(\"h\")\n                if weight is not None:\n                    rst = mx.nd.dot(rst, weight)\n\n            if self._norm in [\"both\", \"right\"]:\n                degs = (\n                    graph.in_degrees()\n                    .as_in_context(feat_dst.context)\n                    .astype(\"float32\")\n                )\n                degs = mx.nd.clip(degs, a_min=1, a_max=float(\"inf\"))\n                if self._norm == \"both\":\n                    norm = mx.nd.power(degs, -0.5)\n                else:\n                    norm = 1.0 / degs\n                shp = norm.shape + (1,) * (feat_dst.ndim - 1)\n                norm = norm.reshape(shp)\n                rst = rst * norm\n\n            if self.bias is not None:\n                rst = rst + self.bias.data(rst.context)\n\n            if self._activation is not None:\n                rst = self._activation(rst)\n\n            return rst\n\n    def __repr__(self):\n        summary = \"GraphConv(\"\n        summary += \"in={:d}, out={:d}, normalization={}, activation={}\".format(\n 
           self._in_feats, self._out_feats, self._norm, self._activation\n        )\n        summary += \")\"\n        return summary\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/nnconv.py",
    "content": "\"\"\"MXNet Module for NNConv layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport mxnet as mx\nfrom mxnet.gluon import nn\nfrom mxnet.gluon.contrib.nn import Identity\n\nfrom .... import function as fn\nfrom ....utils import expand_as_pair\n\n\nclass NNConv(nn.Block):\n    r\"\"\"Graph Convolution layer from `Neural Message Passing\n    for Quantum Chemistry <https://arxiv.org/pdf/1704.01212.pdf>`__\n\n    .. math::\n        h_{i}^{l+1} = h_{i}^{l} + \\mathrm{aggregate}\\left(\\left\\{\n        f_\\Theta (e_{ij}) \\cdot h_j^{l}, j\\in \\mathcal{N}(i) \\right\\}\\right)\n\n    where :math:`e_{ij}` is the edge feature, :math:`f_\\Theta` is a function\n    with learnable parameters.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n        NN can be applied on homogeneous graph and unidirectional\n        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.\n        If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``\n        specifies the input feature size on both the source and destination nodes.  If\n        a scalar is given, the source and destination node feature size would take the\n        same value.\n    out_feats : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    edge_func : callable activation function/layer\n        Maps each edge feature to a vector of shape\n        ``(in_feats * out_feats)`` as weight to compute\n        messages.\n        Also is the :math:`f_\\Theta` in the formula.\n    aggregator_type : str\n        Aggregator type to use (``sum``, ``mean`` or ``max``).\n    residual : bool, optional\n        If True, use residual connection. Default: ``False``.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from mxnet import gluon\n    >>> from dgl.nn import NNConv\n    >>>\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = mx.nd.ones((6, 10))\n    >>> lin = gluon.nn.Dense(20)\n    >>> lin.initialize(ctx=mx.cpu(0))\n    >>> def edge_func(efeat):\n    >>>      return lin(efeat)\n    >>> efeat = mx.nd.ones((12, 5))\n    >>> conv = NNConv(10, 2, edge_func, 'mean')\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, feat, efeat)\n    >>> res\n    [[0.39946803 0.32098457]\n    [0.39946803 0.32098457]\n    [0.39946803 0.32098457]\n    [0.39946803 0.32098457]\n    [0.39946803 0.32098457]\n    [0.39946803 0.32098457]]\n    <NDArray 6x2 @cpu(0)>\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})\n    >>> u_feat = mx.nd.random.randn(2, 10)\n    >>> v_feat = mx.nd.random.randn(4, 10)\n    >>> conv = NNConv(10, 2, edge_func, 'mean')\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> efeat = mx.nd.ones((5, 5))\n    >>> res = conv(g, (u_feat, v_feat), efeat)\n    >>> res\n    [[ 0.24425688  0.3238042 ]\n    [-0.11651017 -0.01738572]\n    [ 0.06387337  0.15320925]\n    [ 0.24425688  0.3238042 ]]\n    <NDArray 4x2 @cpu(0)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        edge_func,\n        aggregator_type,\n        residual=False,\n        bias=True,\n    ):\n        super(NNConv, self).__init__()\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        if aggregator_type == \"sum\":\n            self.reducer = fn.sum\n        elif aggregator_type == \"mean\":\n            self.reducer = fn.mean\n        elif aggregator_type == 
\"max\":\n            self.reducer = fn.max\n        else:\n            raise KeyError(\n                \"Aggregator type {} not recognized: \".format(aggregator_type)\n            )\n        self._aggre_type = aggregator_type\n\n        with self.name_scope():\n            self.edge_nn = edge_func\n            if residual:\n                if self._in_dst_feats != out_feats:\n                    self.res_fc = nn.Dense(\n                        out_feats,\n                        in_units=self._in_dst_feats,\n                        use_bias=False,\n                        weight_initializer=mx.init.Xavier(),\n                    )\n                else:\n                    self.res_fc = Identity()\n            else:\n                self.res_fc = None\n\n            if bias:\n                self.bias = self.params.get(\n                    \"bias\", shape=(out_feats,), init=mx.init.Zero()\n                )\n            else:\n                self.bias = None\n\n    def forward(self, graph, feat, efeat):\n        r\"\"\"Compute MPNN Graph Convolution layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray or pair of mxnet.NDArray\n            The input feature of shape :math:`(N, D_{in})` where :math:`N`\n            is the number of nodes of the graph and :math:`D_{in}` is the\n            input feature size.\n        efeat : mxnet.NDArray\n            The edge feature of shape :math:`(N, *)`, should fit the input\n            shape requirement of ``edge_nn``.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is the output feature size.\n        \"\"\"\n        with graph.local_scope():\n            feat_src, feat_dst = expand_as_pair(feat, graph)\n\n            # (n, d_in, 1)\n            graph.srcdata[\"h\"] = feat_src.expand_dims(-1)\n            # (n, d_in, d_out)\n            
graph.edata[\"w\"] = self.edge_nn(efeat).reshape(\n                -1, self._in_src_feats, self._out_feats\n            )\n            # (n, d_in, d_out)\n            graph.update_all(\n                fn.u_mul_e(\"h\", \"w\", \"m\"), self.reducer(\"m\", \"neigh\")\n            )\n            rst = graph.dstdata.pop(\"neigh\").sum(axis=1)  # (n, d_out)\n            # residual connection\n            if self.res_fc is not None:\n                rst = rst + self.res_fc(feat_dst)\n            # bias\n            if self.bias is not None:\n                rst = rst + self.bias.data(feat_dst.context)\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/relgraphconv.py",
    "content": "\"\"\"MXNet module for RelGraphConv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport mxnet as mx\nimport numpy as np\nfrom mxnet import gluon, nd\nfrom mxnet.gluon import nn\n\nfrom .... import function as fn\nfrom .. import utils\n\n\nclass RelGraphConv(gluon.Block):\n    r\"\"\"Relational graph convolution layer from `Modeling Relational Data with Graph\n    Convolutional Networks <https://arxiv.org/abs/1703.06103>`__\n\n    It can be described as below:\n\n    .. math::\n\n       h_i^{(l+1)} = \\sigma(\\sum_{r\\in\\mathcal{R}}\n       \\sum_{j\\in\\mathcal{N}^r(i)}\\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})\n\n    where :math:`\\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation\n    :math:`r`. :math:`c_{i,r}` is the normalizer equal\n    to :math:`|\\mathcal{N}^r(i)|`. :math:`\\sigma` is an activation function. :math:`W_0`\n    is the self-loop weight.\n\n    The basis regularization decomposes :math:`W_r` by:\n\n    .. math::\n\n       W_r^{(l)} = \\sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}\n\n    where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined\n    with coefficients :math:`a_{rb}^{(l)}`.\n\n    The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B`\n    number of block diagonal matrices. We refer :math:`B` as the number of bases.\n\n    The block regularization decomposes :math:`W_r` by:\n\n    .. math::\n\n       W_r^{(l)} = \\oplus_{b=1}^B Q_{rb}^{(l)}\n\n    where :math:`B` is the number of bases, :math:`Q_{rb}^{(l)}` are block\n    bases with shape :math:`R^{(d^{(l+1)}/B)*(d^{l}/B)}`.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n    out_feat : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    num_rels : int\n        Number of relations. 
\n    regularizer : str\n        Which weight regularizer to use: \"basis\" or \"bdd\".\n        \"basis\" is short for basis-decomposition.\n        \"bdd\" is short for block-diagonal-decomposition.\n    num_bases : int, optional\n        Number of bases. If None, use the number of relations. Default: ``None``.\n    bias : bool, optional\n        True if bias is added. Default: ``True``.\n    activation : callable, optional\n        Activation function. Default: ``None``.\n    self_loop : bool, optional\n        True to include self loop message. Default: ``True``.\n    low_mem : bool, optional\n        True to use low memory implementation of relation message passing function.\n        This option trades speed for lower memory consumption, and will slow down the\n        forward/backward pass. Turn it on when you encounter OOM problem during training or\n        evaluation. Note: the MXNet implementation currently does not support this option;\n        it must be ``False``. Default: ``False``.\n    dropout : float, optional\n        Dropout rate. Default: ``0.0``\n    layer_norm : bool, optional\n        Add layer norm. 
Default: ``False``\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from mxnet import gluon\n    >>> from dgl.nn import RelGraphConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = mx.nd.ones((6, 10))\n    >>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2)\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> etype = mx.nd.array(np.array([0,1,2,0,1,2]).astype(np.int64))\n    >>> res = conv(g, feat, etype)\n    [[ 0.561324    0.33745846]\n    [ 0.61585337  0.09992217]\n    [ 0.561324    0.33745846]\n    [-0.01557937  0.01227859]\n    [ 0.61585337  0.09992217]\n    [ 0.056508   -0.00307822]]\n    <NDArray 6x2 @cpu(0)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feat,\n        out_feat,\n        num_rels,\n        regularizer=\"basis\",\n        num_bases=None,\n        bias=True,\n        activation=None,\n        self_loop=True,\n        low_mem=False,\n        dropout=0.0,\n        layer_norm=False,\n    ):\n        super(RelGraphConv, self).__init__()\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.num_rels = num_rels\n        self.regularizer = regularizer\n        self.num_bases = num_bases\n        if (\n            self.num_bases is None\n            or self.num_bases > self.num_rels\n            or self.num_bases < 0\n        ):\n            self.num_bases = self.num_rels\n        self.bias = bias\n        self.activation = activation\n        self.self_loop = self_loop\n\n        assert (\n            low_mem is False\n        ), \"MXNet currently does not support low-memory implementation.\"\n        assert (\n            layer_norm is False\n        ), \"MXNet currently does not support layer norm.\"\n\n        if regularizer == \"basis\":\n            # add basis weights\n            self.weight = self.params.get(\n                \"weight\",\n                shape=(self.num_bases, self.in_feat, 
self.out_feat),\n                init=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n            )\n            if self.num_bases < self.num_rels:\n                # linear combination coefficients\n                self.w_comp = self.params.get(\n                    \"w_comp\",\n                    shape=(self.num_rels, self.num_bases),\n                    init=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n                )\n            # message func\n            self.message_func = self.basis_message_func\n        elif regularizer == \"bdd\":\n            if in_feat % num_bases != 0 or out_feat % num_bases != 0:\n                raise ValueError(\n                    \"Feature size must be a multiplier of num_bases.\"\n                )\n            # add block diagonal weights\n            self.submat_in = in_feat // self.num_bases\n            self.submat_out = out_feat // self.num_bases\n\n            # assuming in_feat and out_feat are both divisible by num_bases\n            self.weight = self.params.get(\n                \"weight\",\n                shape=(\n                    self.num_rels,\n                    self.num_bases * self.submat_in * self.submat_out,\n                ),\n                init=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n            )\n            # message func\n            self.message_func = self.bdd_message_func\n        else:\n            raise ValueError(\"Regularizer must be either 'basis' or 'bdd'\")\n\n        # bias\n        if self.bias:\n            self.h_bias = self.params.get(\n                \"bias\", shape=(out_feat,), init=mx.init.Zero()\n            )\n\n        # weight for self loop\n        if self.self_loop:\n            self.loop_weight = self.params.get(\n                \"W_0\",\n                shape=(in_feat, out_feat),\n                init=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n            )\n\n        self.dropout = nn.Dropout(dropout)\n\n    def basis_message_func(self, edges):\n        \"\"\"Message 
function for basis regularizer\"\"\"\n        ctx = edges.src[\"h\"].context\n        if self.num_bases < self.num_rels:\n            # generate all weights from bases\n            weight = self.weight.data(ctx).reshape(\n                self.num_bases, self.in_feat * self.out_feat\n            )\n            weight = nd.dot(self.w_comp.data(ctx), weight).reshape(\n                self.num_rels, self.in_feat, self.out_feat\n            )\n        else:\n            weight = self.weight.data(ctx)\n\n        msg = utils.bmm_maybe_select(edges.src[\"h\"], weight, edges.data[\"type\"])\n        if \"norm\" in edges.data:\n            msg = msg * edges.data[\"norm\"]\n        return {\"msg\": msg}\n\n    def bdd_message_func(self, edges):\n        \"\"\"Message function for block-diagonal-decomposition regularizer\"\"\"\n        ctx = edges.src[\"h\"].context\n        if (\n            edges.src[\"h\"].dtype in (np.int32, np.int64)\n            and len(edges.src[\"h\"].shape) == 1\n        ):\n            raise TypeError(\n                \"Block decomposition does not allow integer ID feature.\"\n            )\n        weight = self.weight.data(ctx)[edges.data[\"type\"], :].reshape(\n            -1, self.submat_in, self.submat_out\n        )\n        node = edges.src[\"h\"].reshape(-1, 1, self.submat_in)\n        msg = nd.batch_dot(node, weight).reshape(-1, self.out_feat)\n        if \"norm\" in edges.data:\n            msg = msg * edges.data[\"norm\"]\n        return {\"msg\": msg}\n\n    def forward(self, g, x, etypes, norm=None):\n        \"\"\"\n        Description\n        -----------\n\n        Forward computation\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        feat : mx.ndarray.NDArray\n            Input node features. Could be either\n\n                * :math:`(|V|, D)` dense tensor\n                * :math:`(|V|,)` int64 vector, representing the categorical values of each\n                  node. 
It then treat the input feature as an one-hot encoding feature.\n        etypes : mx.ndarray.NDArray\n            Edge type tensor. Shape: :math:`(|E|,)`\n        norm : mx.ndarray.NDArray\n            Optional edge normalizer tensor. Shape: :math:`(|E|, 1)`.\n\n        Returns\n        -------\n        mx.ndarray.NDArray\n            New node features.\n        \"\"\"\n        assert g.is_homogeneous, (\n            \"not a homogeneous graph; convert it with to_homogeneous \"\n            \"and pass in the edge type as argument\"\n        )\n        with g.local_scope():\n            g.ndata[\"h\"] = x\n            g.edata[\"type\"] = etypes\n            if norm is not None:\n                g.edata[\"norm\"] = norm\n            if self.self_loop:\n                loop_message = utils.matmul_maybe_select(\n                    x, self.loop_weight.data(x.context)\n                )\n\n            # message passing\n            g.update_all(self.message_func, fn.sum(msg=\"msg\", out=\"h\"))\n\n            # apply bias and activation\n            node_repr = g.ndata[\"h\"]\n            if self.bias:\n                node_repr = node_repr + self.h_bias.data(x.context)\n            if self.self_loop:\n                node_repr = node_repr + loop_message\n            if self.activation:\n                node_repr = self.activation(node_repr)\n            node_repr = self.dropout(node_repr)\n\n            return node_repr\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/sageconv.py",
    "content": "\"\"\"MXNet Module for GraphSAGE layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport mxnet as mx\nfrom mxnet import nd\nfrom mxnet.gluon import nn\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import check_eq_shape, expand_as_pair\n\n\nclass SAGEConv(nn.Block):\n    r\"\"\"GraphSAGE layer from `Inductive Representation Learning on\n    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__\n\n    .. math::\n        h_{\\mathcal{N}(i)}^{(l+1)} &= \\mathrm{aggregate}\n        \\left(\\{h_{j}^{l}, \\forall j \\in \\mathcal{N}(i) \\}\\right)\n\n        h_{i}^{(l+1)} &= \\sigma \\left(W \\cdot \\mathrm{concat}\n        (h_{i}^{l}, h_{\\mathcal{N}(i)}^{l+1}) \\right)\n\n        h_{i}^{(l+1)} &= \\mathrm{norm}(h_{i}^{(l+1)})\n\n    Parameters\n    ----------\n    in_feats : int, or pair of ints\n        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n\n        GATConv can be applied on homogeneous graph and unidirectional\n        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.\n        If the layer applies on a unidirectional bipartite graph, ``in_feats``\n        specifies the input feature size on both the source and destination nodes.  If\n        a scalar is given, the source and destination node feature size would take the\n        same value.\n\n        If aggregator type is ``gcn``, the feature size of source and destination nodes\n        are required to be the same.\n    out_feats : int\n        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.\n    aggregator_type : str\n        Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\n    feat_drop : float\n        Dropout rate on features, default: ``0``.\n    bias : bool\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n    norm : callable activation function/layer or None, optional\n        If not None, applies normalization to the updated node features.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from dgl.nn import SAGEConv\n    >>>\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = mx.nd.ones((6, 10))\n    >>> conv = SAGEConv(10, 2, 'pool')\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, feat)\n    >>> res\n    [[ 0.32144994 -0.8729614 ]\n    [ 0.32144994 -0.8729614 ]\n    [ 0.32144994 -0.8729614 ]\n    [ 0.32144994 -0.8729614 ]\n    [ 0.32144994 -0.8729614 ]\n    [ 0.32144994 -0.8729614 ]]\n    <NDArray 6x2 @cpu(0)>\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})\n    >>> u_fea = mx.nd.random.randn(2, 5)\n    >>> v_fea = mx.nd.random.randn(4, 10)\n    >>> conv = SAGEConv((5, 10), 2, 'pool')\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, (u_fea, v_fea))\n    >>> res\n    [[-0.60524774  0.7196473 ]\n    [ 0.8832787  -0.5928619 ]\n    [-1.8245722   1.159798  ]\n    [-1.0509381   2.2239418 ]]\n    <NDArray 4x2 @cpu(0)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        aggregator_type=\"mean\",\n        feat_drop=0.0,\n        bias=True,\n        norm=None,\n        activation=None,\n    ):\n        super(SAGEConv, self).__init__()\n        valid_aggre_types = {\"mean\", \"gcn\", \"pool\", \"lstm\"}\n        if aggregator_type not in valid_aggre_types:\n            raise DGLError(\n                \"Invalid aggregator_type. 
Must be one of {}. \"\n                \"But got {!r} instead.\".format(\n                    valid_aggre_types, aggregator_type\n                )\n            )\n\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        self._aggre_type = aggregator_type\n        with self.name_scope():\n            self.norm = norm\n            self.feat_drop = nn.Dropout(feat_drop)\n            self.activation = activation\n            if aggregator_type == \"pool\":\n                self.fc_pool = nn.Dense(\n                    self._in_src_feats,\n                    use_bias=bias,\n                    weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n                    in_units=self._in_src_feats,\n                )\n            if aggregator_type == \"lstm\":\n                raise NotImplementedError\n            if aggregator_type != \"gcn\":\n                self.fc_self = nn.Dense(\n                    out_feats,\n                    use_bias=bias,\n                    weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n                    in_units=self._in_dst_feats,\n                )\n            self.fc_neigh = nn.Dense(\n                out_feats,\n                use_bias=bias,\n                weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n                in_units=self._in_src_feats,\n            )\n\n    def forward(self, graph, feat):\n        r\"\"\"Compute GraphSAGE layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray or pair of mxnet.NDArray\n            If a single tensor is given, it represents the input feature of shape\n            :math:`(N, D_{in})`\n            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of tensors are given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in_{src}})` and 
:math:`(N_{out}, D_{in_{dst}})`.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n        with graph.local_scope():\n            if isinstance(feat, tuple):\n                feat_src = self.feat_drop(feat[0])\n                feat_dst = self.feat_drop(feat[1])\n            else:\n                feat_src = feat_dst = self.feat_drop(feat)\n                if graph.is_block:\n                    feat_dst = feat_src[: graph.number_of_dst_nodes()]\n\n            h_self = feat_dst\n\n            # Handle the case of graphs without edges\n            if graph.num_edges() == 0:\n                dst_neigh = mx.nd.zeros(\n                    (graph.number_of_dst_nodes(), self._in_src_feats)\n                )\n                dst_neigh = dst_neigh.as_in_context(feat_dst.context)\n                graph.dstdata[\"neigh\"] = dst_neigh\n\n            if self._aggre_type == \"mean\":\n                graph.srcdata[\"h\"] = feat_src\n                graph.update_all(fn.copy_u(\"h\", \"m\"), fn.mean(\"m\", \"neigh\"))\n                h_neigh = graph.dstdata[\"neigh\"]\n            elif self._aggre_type == \"gcn\":\n                check_eq_shape(feat)\n                graph.srcdata[\"h\"] = feat_src\n                graph.dstdata[\"h\"] = feat_dst  # same as above if homogeneous\n                graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"neigh\"))\n                # divide in degrees\n                degs = graph.in_degrees().astype(feat_dst.dtype)\n                degs = degs.as_in_context(feat_dst.context)\n                h_neigh = (graph.dstdata[\"neigh\"] + graph.dstdata[\"h\"]) / (\n                    degs.expand_dims(-1) + 1\n                )\n            elif self._aggre_type == \"pool\":\n                graph.srcdata[\"h\"] = nd.relu(self.fc_pool(feat_src))\n                
graph.update_all(fn.copy_u(\"h\", \"m\"), fn.max(\"m\", \"neigh\"))\n                h_neigh = graph.dstdata[\"neigh\"]\n            elif self._aggre_type == \"lstm\":\n                raise NotImplementedError\n            else:\n                raise KeyError(\n                    \"Aggregator type {} not recognized.\".format(\n                        self._aggre_type\n                    )\n                )\n\n            if self._aggre_type == \"gcn\":\n                rst = self.fc_neigh(h_neigh)\n            else:\n                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)\n            # activation\n            if self.activation is not None:\n                rst = self.activation(rst)\n            # normalization\n            if self.norm is not None:\n                rst = self.norm(rst)\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/sgconv.py",
    "content": "\"\"\"MXNet Module for Simplifying Graph Convolution layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\n\nimport mxnet as mx\nfrom mxnet import nd\nfrom mxnet.gluon import nn\n\nfrom .... import function as fn\nfrom ....base import DGLError\n\n\nclass SGConv(nn.Block):\n    r\"\"\"SGC layer from `Simplifying Graph Convolutional Networks\n    <https://arxiv.org/pdf/1902.07153.pdf>`__\n\n    .. math::\n        H^{K} = (\\tilde{D}^{-1/2} \\tilde{A} \\tilde{D}^{-1/2})^K X \\Theta\n\n    where :math:`\\tilde{A}` is :math:`A` + :math:`I`.\n    Thus the graph input is expected to have self-loop edges added.\n\n    Parameters\n    ----------\n    in_feats : int\n        Number of input features; i.e, the number of dimensions of :math:`X`.\n    out_feats : int\n        Number of output features; i.e, the number of dimensions of :math:`H^{K}`.\n    k : int\n        Number of hops :math:`K`. Defaults:``1``.\n    cached : bool\n        If True, the module would cache\n\n        .. math::\n            (\\tilde{D}^{-\\frac{1}{2}}\\tilde{A}\\tilde{D}^{-\\frac{1}{2}})^K X\\Theta\n\n        at the first forward call. This parameter should only be set to\n        ``True`` in Transductive Learning setting.\n    bias : bool\n        If True, adds a learnable bias to the output. Default: ``True``.\n    norm : callable activation function/layer or None, optional\n        If not None, applies normalization to the updated node features.  Default: ``False``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. 
Default: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be appied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from dgl.nn import SGConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = mx.nd.ones((6, 10))\n    >>> conv = SGConv(10, 2, k=2, cached=True)\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, feat)\n    >>> res\n    [[ 2.264404   -0.26684892]\n    [ 2.264404   -0.26684892]\n    [ 2.264404   -0.26684892]\n    [ 3.2273252  -0.3803246 ]\n    [ 2.247593   -0.2648679 ]\n    [ 2.2644043  -0.26684904]]\n    <NDArray 6x2 @cpu(0)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        k=1,\n        cached=False,\n        bias=True,\n        norm=None,\n        allow_zero_in_degree=False,\n    ):\n        super(SGConv, self).__init__()\n        self._cached = cached\n        self._cached_h = None\n        self._k = k\n        self._allow_zero_in_degree = allow_zero_in_degree\n        with self.name_scope():\n            self.norm = norm\n            self.fc = nn.Dense(\n                out_feats,\n                in_units=in_feats,\n      
          use_bias=bias,\n                weight_initializer=mx.init.Xavier(),\n            )\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute Simplifying Graph Convolution layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`\n            is size of input feature, :math:`N` is the number of nodes.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n\n        Note\n        ----\n        If ``cache`` is set to True, ``feat`` and ``graph`` should not change during\n        training, or you will get wrong results.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if graph.in_degrees().min() == 0:\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. 
\"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            if self._cached_h is not None:\n                feat = self._cached_h\n            else:\n                # compute normalization\n                degs = nd.clip(\n                    graph.in_degrees().astype(feat.dtype), 1, float(\"inf\")\n                )\n                norm = nd.power(degs, -0.5).expand_dims(1)\n                norm = norm.as_in_context(feat.context)\n                # compute (D^-1 A D)^k X\n                for _ in range(self._k):\n                    feat = feat * norm\n                    graph.ndata[\"h\"] = feat\n                    graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n                    feat = graph.ndata.pop(\"h\")\n                    feat = feat * norm\n\n                if self.norm is not None:\n                    feat = self.norm(feat)\n\n                # cache feature\n                if self._cached:\n                    self._cached_h = feat\n            return self.fc(feat)\n"
  },
  {
    "path": "python/dgl/nn/mxnet/conv/tagconv.py",
    "content": "\"\"\"MXNet module for TAGConv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport mxnet as mx\nfrom mxnet import gluon\n\nfrom .... import function as fn\n\n\nclass TAGConv(gluon.Block):\n    r\"\"\"Topology Adaptive Graph Convolutional layer from `Topology\n    Adaptive Graph Convolutional Networks <https://arxiv.org/pdf/1710.10370.pdf>`__.\n\n    .. math::\n        H^{K} = {\\sum}_{k=0}^K (D^{-1/2} A D^{-1/2})^{k} X {\\Theta}_{k},\n\n    where :math:`A` denotes the adjacency matrix,\n    :math:`D_{ii} = \\sum_{j=0} A_{ij}` its diagonal degree matrix,\n    :math:`{\\Theta}_{k}` denotes the linear weights to sum the results of different hops together.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size. i.e, the number of dimensions of :math:`X`.\n    out_feats : int\n        Output feature size.  i.e, the number of dimensions of :math:`H^{K}`.\n    k: int, optional\n        Number of hops :math:`K`. Default: ``2``.\n    bias: bool, optional\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n    activation: callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n\n    Attributes\n    ----------\n    lin : torch.Module\n        The learnable linear module.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import mxnet as mx\n    >>> from mxnet import gluon\n    >>> from dgl.nn import TAGConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = mx.nd.ones((6, 10))\n    >>> conv = TAGConv(10, 2, k=2)\n    >>> conv.initialize(ctx=mx.cpu(0))\n    >>> res = conv(g, feat)\n    >>> res\n    [[-0.86147034  0.10089529]\n    [-0.86147034  0.10089529]\n    [-0.86147034  0.10089529]\n    [-0.9707841   0.0360311 ]\n    [-0.6716844   0.02247889]\n    [ 0.32964635 -0.7669234 ]]\n    <NDArray 6x2 @cpu(0)>\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, k=2, bias=True, activation=None):\n        super(TAGConv, self).__init__()\n        self.out_feats = out_feats\n        self.k = k\n        self.bias = bias\n        self.activation = activation\n        self.in_feats = in_feats\n\n        self.lin = self.params.get(\n            \"weight\",\n            shape=(self.in_feats * (self.k + 1), self.out_feats),\n            init=mx.init.Xavier(magnitude=math.sqrt(2.0)),\n        )\n        if self.bias:\n            self.h_bias = self.params.get(\n                \"bias\", shape=(out_feats,), init=mx.init.Zero()\n            )\n\n    def forward(self, graph, feat):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute topology adaptive graph convolution.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`\n            is size of input feature, :math:`N` is the number of nodes.\n\n        Returns\n        
-------\n        mxnet.NDArray\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n        with graph.local_scope():\n            assert graph.is_homogeneous, \"Graph is not homogeneous\"\n\n            degs = graph.in_degrees().astype(\"float32\")\n            norm = mx.nd.power(\n                mx.nd.clip(degs, a_min=1, a_max=float(\"inf\")), -0.5\n            )\n            shp = norm.shape + (1,) * (feat.ndim - 1)\n            norm = norm.reshape(shp).as_in_context(feat.context)\n\n            rst = feat\n            for _ in range(self.k):\n                rst = rst * norm\n                graph.ndata[\"h\"] = rst\n\n                graph.update_all(\n                    fn.copy_u(u=\"h\", out=\"m\"), fn.sum(msg=\"m\", out=\"h\")\n                )\n                rst = graph.ndata[\"h\"]\n                rst = rst * norm\n                feat = mx.nd.concat(feat, rst, dim=-1)\n\n            rst = mx.nd.dot(feat, self.lin.data(feat.context))\n            if self.bias is not None:\n                rst = rst + self.h_bias.data(rst.context)\n\n            if self.activation is not None:\n                rst = self.activation(rst)\n\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/mxnet/glob.py",
    "content": "\"\"\"MXNet modules for graph global pooling.\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, W0235\nfrom mxnet import gluon, nd\nfrom mxnet.gluon import nn\n\nfrom ...readout import (\n    broadcast_nodes,\n    max_nodes,\n    mean_nodes,\n    softmax_nodes,\n    sum_nodes,\n    topk_nodes,\n)\n\n__all__ = [\n    \"SumPooling\",\n    \"AvgPooling\",\n    \"MaxPooling\",\n    \"SortPooling\",\n    \"GlobalAttentionPooling\",\n    \"Set2Set\",\n]\n\n\nclass SumPooling(nn.Block):\n    r\"\"\"Apply sum pooling over the nodes in the graph.\n\n    .. math::\n        r^{(i)} = \\sum_{k=1}^{N_i} x^{(i)}_k\n    \"\"\"\n\n    def __init__(self):\n        super(SumPooling, self).__init__()\n\n    def forward(self, graph, feat):\n        r\"\"\"Compute sum pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            The input feature with shape :math:`(N, *)` where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature with shape :math:`(B, *)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        with graph.local_scope():\n            graph.ndata[\"h\"] = feat\n            readout = sum_nodes(graph, \"h\")\n            graph.ndata.pop(\"h\")\n            return readout\n\n    def __repr__(self):\n        return \"SumPooling()\"\n\n\nclass AvgPooling(nn.Block):\n    r\"\"\"Apply average pooling over the nodes in the graph.\n\n    .. 
math::\n        r^{(i)} = \\frac{1}{N_i}\\sum_{k=1}^{N_i} x^{(i)}_k\n    \"\"\"\n\n    def __init__(self):\n        super(AvgPooling, self).__init__()\n\n    def forward(self, graph, feat):\n        r\"\"\"Compute average pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            The input feature with shape :math:`(N, *)` where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature with shape :math:`(B, *)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        with graph.local_scope():\n            graph.ndata[\"h\"] = feat\n            readout = mean_nodes(graph, \"h\")\n            graph.ndata.pop(\"h\")\n            return readout\n\n    def __repr__(self):\n        return \"AvgPooling()\"\n\n\nclass MaxPooling(nn.Block):\n    r\"\"\"Apply max pooling over the nodes in the graph.\n\n    .. 
math::\n        r^{(i)} = \\max_{k=1}^{N_i} \\left( x^{(i)}_k \\right)\n    \"\"\"\n\n    def __init__(self):\n        super(MaxPooling, self).__init__()\n\n    def forward(self, graph, feat):\n        r\"\"\"Compute max pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            The input feature with shape :math:`(N, *)` where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature with shape :math:`(B, *)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        with graph.local_scope():\n            graph.ndata[\"h\"] = feat\n            readout = max_nodes(graph, \"h\")\n            graph.ndata.pop(\"h\")\n            return readout\n\n    def __repr__(self):\n        return \"MaxPooling()\"\n\n\nclass SortPooling(nn.Block):\n    r\"\"\"Pooling layer from `An End-to-End Deep Learning Architecture for Graph Classification\n    <https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__\n\n    Parameters\n    ----------\n    k : int\n        The number of nodes to hold for each graph.\n    \"\"\"\n\n    def __init__(self, k):\n        super(SortPooling, self).__init__()\n        self.k = k\n\n    def forward(self, graph, feat):\n        r\"\"\"Compute sort pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            The input node feature with shape :math:`(N, D)` where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature with shape :math:`(B, k * D)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        # Sort the feature of each node in ascending order.\n        with graph.local_scope():\n            feat = feat.sort(axis=-1)\n            graph.ndata[\"h\"] = 
feat\n            # Sort nodes according to their last features.\n            ret = topk_nodes(graph, \"h\", self.k, sortby=-1)[0].reshape(\n                -1, self.k * feat.shape[-1]\n            )\n            return ret\n\n    def __repr__(self):\n        return \"SortPooling(k={})\".format(self.k)\n\n\nclass GlobalAttentionPooling(nn.Block):\n    r\"\"\"Global Attention Pooling layer from `Gated Graph Sequence Neural Networks\n    <https://arxiv.org/abs/1511.05493.pdf>`__\n\n    .. math::\n        r^{(i)} = \\sum_{k=1}^{N_i}\\mathrm{softmax}\\left(f_{gate}\n        \\left(x^{(i)}_k\\right)\\right) f_{feat}\\left(x^{(i)}_k\\right)\n\n    Parameters\n    ----------\n    gate_nn : gluon.nn.Block\n        A neural network that computes attention scores for each feature.\n    feat_nn : gluon.nn.Block, optional\n        A neural network applied to each feature before combining them\n        with attention scores.\n    \"\"\"\n\n    def __init__(self, gate_nn, feat_nn=None):\n        super(GlobalAttentionPooling, self).__init__()\n        with self.name_scope():\n            self.gate_nn = gate_nn\n            self.feat_nn = feat_nn\n\n    def forward(self, graph, feat):\n        r\"\"\"Compute global attention pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            The input node feature with shape :math:`(N, D)` where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature with shape :math:`(B, D)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        with graph.local_scope():\n            gate = self.gate_nn(feat)\n            assert (\n                gate.shape[-1] == 1\n            ), \"The output of gate_nn should have size 1 at the last axis.\"\n            feat = self.feat_nn(feat) if self.feat_nn else feat\n\n            graph.ndata[\"gate\"] = gate\n 
           gate = softmax_nodes(graph, \"gate\")\n\n            graph.ndata[\"r\"] = feat * gate\n            readout = sum_nodes(graph, \"r\")\n\n            return readout\n\n\nclass Set2Set(nn.Block):\n    r\"\"\"Set2Set operator from `Order Matters: Sequence to sequence for sets\n    <https://arxiv.org/pdf/1511.06391.pdf>`__\n\n    For each individual graph in the batch, set2set computes\n\n    .. math::\n        q_t &= \\mathrm{LSTM} (q^*_{t-1})\n\n        \\alpha_{i,t} &= \\mathrm{softmax}(x_i \\cdot q_t)\n\n        r_t &= \\sum_{i=1}^N \\alpha_{i,t} x_i\n\n        q^*_t &= q_t \\Vert r_t\n\n    for this graph.\n\n    Parameters\n    ----------\n    input_dim : int\n        Size of each input sample\n    n_iters : int\n        Number of iterations.\n    n_layers : int\n        Number of recurrent layers.\n    \"\"\"\n\n    def __init__(self, input_dim, n_iters, n_layers):\n        super(Set2Set, self).__init__()\n        self.input_dim = input_dim\n        self.output_dim = 2 * input_dim\n        self.n_iters = n_iters\n        self.n_layers = n_layers\n        with self.name_scope():\n            self.lstm = gluon.rnn.LSTM(\n                self.input_dim, num_layers=n_layers, input_size=self.output_dim\n            )\n\n    def forward(self, graph, feat):\n        r\"\"\"Compute set2set pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : mxnet.NDArray\n            The input node feature with shape :math:`(N, D)` where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        mxnet.NDArray\n            The output feature with shape :math:`(B, D)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        with graph.local_scope():\n            batch_size = graph.batch_size\n\n            h = (\n                nd.zeros(\n                    (self.n_layers, batch_size, self.input_dim),\n                    ctx=feat.context,\n    
            ),\n                nd.zeros(\n                    (self.n_layers, batch_size, self.input_dim),\n                    ctx=feat.context,\n                ),\n            )\n            q_star = nd.zeros((batch_size, self.output_dim), ctx=feat.context)\n\n            for _ in range(self.n_iters):\n                q, h = self.lstm(q_star.expand_dims(axis=0), h)\n                q = q.reshape((batch_size, self.input_dim))\n                e = (feat * broadcast_nodes(graph, q)).sum(\n                    axis=-1, keepdims=True\n                )\n                graph.ndata[\"e\"] = e\n                alpha = softmax_nodes(graph, \"e\")\n                graph.ndata[\"r\"] = feat * alpha\n                readout = sum_nodes(graph, \"r\")\n                q_star = nd.concat(q, readout, dim=-1)\n\n            return q_star\n\n    def __repr__(self):\n        summary = \"Set2Set(\"\n        summary += \"in={}, out={}, \" \"n_iters={}, n_layers={}\".format(\n            self.input_dim, self.output_dim, self.n_iters, self.n_layers\n        )\n        summary += \")\"\n        return summary\n"
  },
  {
    "path": "python/dgl/nn/mxnet/hetero.py",
    "content": "\"\"\"Heterograph NN modules\"\"\"\nfrom mxnet import nd\nfrom mxnet.gluon import nn\n\n__all__ = [\"HeteroGraphConv\"]\n\n\nclass HeteroGraphConv(nn.Block):\n    r\"\"\"A generic module for computing convolution on heterogeneous graphs\n\n    The heterograph convolution applies sub-modules on their associating\n    relation graphs, which reads the features from source nodes and writes the\n    updated ones to destination nodes. If multiple relations have the same\n    destination node types, their results are aggregated by the specified method.\n    If the relation graph has no edge, the corresponding module will not be called.\n\n    Pseudo-code:\n\n    .. code::\n\n        outputs = {nty : [] for nty in g.dsttypes}\n        # Apply sub-modules on their associating relation graphs in parallel\n        for relation in g.canonical_etypes:\n            stype, etype, dtype = relation\n            dstdata = relation_submodule(g[relation], ...)\n            outputs[dtype].append(dstdata)\n\n        # Aggregate the results for each destination node type\n        rsts = {}\n        for ntype, ntype_outputs in outputs.items():\n            if len(ntype_outputs) != 0:\n                rsts[ntype] = aggregate(ntype_outputs)\n        return rsts\n\n    Examples\n    --------\n\n    Create a heterograph with three types of relations and nodes.\n\n    >>> import dgl\n    >>> g = dgl.heterograph({\n    ...     ('user', 'follows', 'user') : edges1,\n    ...     ('user', 'plays', 'game') : edges2,\n    ...     ('store', 'sells', 'game')  : edges3})\n\n    Create a ``HeteroGraphConv`` that applies different convolution modules to\n    different relations. Note that the modules for ``'follows'`` and ``'plays'``\n    do not share weights.\n\n    >>> import dgl.nn.pytorch as dglnn\n    >>> conv = dglnn.HeteroGraphConv({\n    ...     'follows' : dglnn.GraphConv(...),\n    ...     'plays' : dglnn.GraphConv(...),\n    ...     'sells' : dglnn.SAGEConv(...)},\n    ...     
aggregate='sum')\n\n    Call forward with some ``'user'`` features. This computes new features for both\n    ``'user'`` and ``'game'`` nodes.\n\n    >>> import mxnet.ndarray as nd\n    >>> h1 = {'user' : nd.random.randn(g.num_nodes('user'), 5)}\n    >>> h2 = conv(g, h1)\n    >>> print(h2.keys())\n    dict_keys(['user', 'game'])\n\n    Call forward with both ``'user'`` and ``'store'`` features. Because both the\n    ``'plays'`` and ``'sells'`` relations will update the ``'game'`` features,\n    their results are aggregated by the specified method (i.e., summation here).\n\n    >>> f1 = {'user' : ..., 'store' : ...}\n    >>> f2 = conv(g, f1)\n    >>> print(f2.keys())\n    dict_keys(['user', 'game'])\n\n    Call forward with some ``'store'`` features. This only computes new features\n    for ``'game'`` nodes.\n\n    >>> g1 = {'store' : ...}\n    >>> g2 = conv(g, g1)\n    >>> print(g2.keys())\n    dict_keys(['game'])\n\n    Call forward with a pair of inputs is allowed and each submodule will also\n    be invoked with a pair of inputs.\n\n    >>> x_src = {'user' : ..., 'store' : ...}\n    >>> x_dst = {'user' : ..., 'game' : ...}\n    >>> y_dst = conv(g, (x_src, x_dst))\n    >>> print(y_dst.keys())\n    dict_keys(['user', 'game'])\n\n    Parameters\n    ----------\n    mods : dict[str, nn.Module]\n        Modules associated with every edge types. 
The forward function of each\n        module must have a `DGLGraph` object as the first argument, and\n        its second argument is either a tensor object representing the node\n        features or a pair of tensor object representing the source and destination\n        node features.\n    aggregate : str, callable, optional\n        Method for aggregating node features generated by different relations.\n        Allowed string values are 'sum', 'max', 'min', 'mean', 'stack'.\n        The 'stack' aggregation is performed along the second dimension, whose order\n        is deterministic.\n        User can also customize the aggregator by providing a callable instance.\n        For example, aggregation by summation is equivalent to the follows:\n\n        .. code::\n\n            def my_agg_func(tensors, dsttype):\n                # tensors: is a list of tensors to aggregate\n                # dsttype: string name of the destination node type for which the\n                #          aggregation is performed\n                stacked = mx.nd.stack(*tensors, axis=0)\n                return mx.nd.sum(stacked, axis=0)\n\n    Attributes\n    ----------\n    mods : dict[str, nn.Module]\n        Modules associated with every edge types.\n    \"\"\"\n\n    def __init__(self, mods, aggregate=\"sum\"):\n        super(HeteroGraphConv, self).__init__()\n        with self.name_scope():\n            for name, mod in mods.items():\n                self.register_child(mod, name)\n            self.mods = mods\n            # Do not break if graph has 0-in-degree nodes.\n            # Because there is no general rule to add self-loop for heterograph.\n            for _, v in self.mods.items():\n                set_allow_zero_in_degree_fn = getattr(\n                    v, \"set_allow_zero_in_degree\", None\n                )\n                if callable(set_allow_zero_in_degree_fn):\n                    set_allow_zero_in_degree_fn(True)\n            if isinstance(aggregate, str):\n    
            self.agg_fn = get_aggregate_fn(aggregate)\n            else:\n                self.agg_fn = aggregate\n\n    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):\n        \"\"\"Forward computation\n\n        Invoke the forward function with each module and aggregate their results.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            Graph data.\n        inputs : dict[str, Tensor] or pair of dict[str, Tensor]\n            Input node features.\n        mod_args : dict[str, tuple[any]], optional\n            Extra positional arguments for the sub-modules.\n        mod_kwargs : dict[str, dict[str, any]], optional\n            Extra key-word arguments for the sub-modules.\n\n        Returns\n        -------\n        dict[str, Tensor]\n            Output representations for every types of nodes.\n        \"\"\"\n        if mod_args is None:\n            mod_args = {}\n        if mod_kwargs is None:\n            mod_kwargs = {}\n        outputs = {nty: [] for nty in g.dsttypes}\n        if isinstance(inputs, tuple):\n            src_inputs, dst_inputs = inputs\n            for stype, etype, dtype in g.canonical_etypes:\n                rel_graph = g[stype, etype, dtype]\n                if stype not in src_inputs or dtype not in dst_inputs:\n                    continue\n                dstdata = self.mods[etype](\n                    rel_graph,\n                    (src_inputs[stype], dst_inputs[dtype]),\n                    *mod_args.get(etype, ()),\n                    **mod_kwargs.get(etype, {})\n                )\n                outputs[dtype].append(dstdata)\n        else:\n            for stype, etype, dtype in g.canonical_etypes:\n                rel_graph = g[stype, etype, dtype]\n                if stype not in inputs:\n                    continue\n                dstdata = self.mods[etype](\n                    rel_graph,\n                    (inputs[stype], inputs[dtype]),\n                    *mod_args.get(etype, 
()),\n                    **mod_kwargs.get(etype, {})\n                )\n                outputs[dtype].append(dstdata)\n        rsts = {}\n        for nty, alist in outputs.items():\n            if len(alist) != 0:\n                rsts[nty] = self.agg_fn(alist, nty)\n        return rsts\n\n    def __repr__(self):\n        summary = \"HeteroGraphConv({\\n\"\n        for name, mod in self.mods.items():\n            summary += \"  {} : {},\\n\".format(name, mod)\n        summary += \"\\n})\"\n        return summary\n\n\ndef get_aggregate_fn(agg):\n    \"\"\"Internal function to get the aggregation function for node data\n    generated from different relations.\n\n    Parameters\n    ----------\n    agg : str\n        Method for aggregating node features generated by different relations.\n        Allowed values are 'sum', 'max', 'min', 'mean', 'stack'.\n\n    Returns\n    -------\n    callable\n        Aggregator function that takes a list of tensors to aggregate\n        and returns one aggregated tensor.\n    \"\"\"\n    if agg == \"sum\":\n        fn = nd.sum\n    elif agg == \"max\":\n        fn = nd.max\n    elif agg == \"min\":\n        fn = nd.min\n    elif agg == \"mean\":\n        fn = nd.mean\n    elif agg == \"stack\":\n        fn = None  # will not be called\n    else:\n        raise DGLError(\n            \"Invalid cross type aggregator. Must be one of \"\n            '\"sum\", \"max\", \"min\", \"mean\" or \"stack\". But got \"%s\"' % agg\n        )\n    if agg == \"stack\":\n\n        def stack_agg(inputs, dsttype):  # pylint: disable=unused-argument\n            if len(inputs) == 0:\n                return None\n            return nd.stack(*inputs, axis=1)\n\n        return stack_agg\n    else:\n\n        def aggfn(inputs, dsttype):  # pylint: disable=unused-argument\n            if len(inputs) == 0:\n                return None\n            stacked = nd.stack(*inputs, axis=0)\n            return fn(stacked, axis=0)\n\n        return aggfn\n"
  },
  {
    "path": "python/dgl/nn/mxnet/softmax.py",
    "content": "\"\"\"Gluon layer for graph related softmax.\"\"\"\n# pylint: disable= unused-import\nfrom ..functional import edge_softmax\n"
  },
  {
    "path": "python/dgl/nn/mxnet/utils.py",
    "content": "\"\"\"Utilities for pytorch NN package\"\"\"\n# pylint: disable=no-member, invalid-name\n\nimport numpy as np\nfrom mxnet import gluon, nd\n\nfrom ... import DGLGraph\n\n\ndef matmul_maybe_select(A, B):\n    \"\"\"Perform Matrix multiplication C = A * B but A could be an integer id vector.\n\n    If A is an integer vector, we treat it as multiplying a one-hot encoded tensor.\n    In this case, the expensive dense matrix multiply can be replaced by a much\n    cheaper index lookup.\n\n    For example,\n    ::\n\n        A = [2, 0, 1],\n        B = [[0.1, 0.2],\n             [0.3, 0.4],\n             [0.5, 0.6]]\n\n    then matmul_maybe_select(A, B) is equivalent to\n    ::\n\n        [[0, 0, 1],     [[0.1, 0.2],\n         [1, 0, 0],  *   [0.3, 0.4],\n         [0, 1, 0]]      [0.5, 0.6]]\n\n    In all other cases, perform a normal matmul.\n\n    Parameters\n    ----------\n    A : mxnet.NDArray\n        lhs tensor\n    B : mxnet.NDArray\n        rhs tensor\n\n    Returns\n    -------\n    C : mxnet.NDArray\n        result tensor\n    \"\"\"\n    if A.dtype in (np.int32, np.int64) and len(A.shape) == 1:\n        return nd.take(B, A, axis=0)\n    else:\n        return nd.dot(A, B)\n\n\ndef bmm_maybe_select(A, B, index):\n    \"\"\"Slice submatrices of A by the given index and perform bmm.\n\n    B is a 3D tensor of shape (N, D1, D2), which can be viewed as a stack of\n    N matrices of shape (D1, D2). 
The input index is an integer vector of length M.\n    A could be either:\n    (1) a dense tensor of shape (M, D1),\n    (2) an integer vector of length M.\n    The result C is a 2D matrix of shape (M, D2)\n\n    For case (1), C is computed by bmm:\n    ::\n\n        C[i, :] = matmul(A[i, :], B[index[i], :, :])\n\n    For case (2), C is computed by index select:\n    ::\n\n        C[i, :] = B[index[i], A[i], :]\n\n    Parameters\n    ----------\n    A : mxnet.NDArray\n        lhs tensor\n    B : mxnet.NDArray\n        rhs tensor\n    index : mxnet.NDArray\n        index tensor\n\n    Returns\n    -------\n    C : mxnet.NDArray\n        return tensor\n    \"\"\"\n    if A.dtype in (np.int32, np.int64) and len(A.shape) == 1:\n        return B[index, A, :]\n    else:\n        BB = nd.take(B, index, axis=0)\n        return nd.batch_dot(A.expand_dims(1), BB).squeeze(1)\n\n\ndef normalize(x, p=2, axis=1, eps=1e-12):\n    r\"\"\"Performs :math:`L_p` normalization of inputs over specified dimension.\n\n    For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each\n    :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as\n\n    .. math::\n        v = \\frac{v}{\\max(\\lVert v \\rVert_p, \\epsilon)}.\n\n    With the default arguments it uses the Euclidean norm over vectors along dimension\n     :math:`1` for normalization.\n\n    Args:\n        x: input ndarray of any shape\n        ord (float): the exponent value in the norm formulation. Default: 2\n        dim (int): the dimension to reduce. Default: 1\n        eps (float): small value to avoid division by zero. 
Default: 1e-12\n    \"\"\"\n    denom = nd.clip(\n        nd.norm(x, ord=p, axis=axis, keepdims=True), eps, float(\"inf\")\n    )\n    return x / denom\n\n\nclass Sequential(gluon.nn.Sequential):\n    r\"\"\"A squential container for stacking graph neural network blocks\n\n    We support two modes: sequentially apply GNN blocks on the same graph or\n    a list of given graphs. In the second case, the number of graphs equals the\n    number of blocks inside this container.\n\n    Examples\n    --------\n\n    Mode 1: sequentially apply GNN modules on the same graph\n\n    >>> import dgl\n    >>> from mxnet import nd\n    >>> from mxnet.gluon import nn\n    >>> import dgl.function as fn\n    >>> from dgl.nn.mxnet import Sequential\n    >>> class ExampleLayer(nn.Block):\n    >>>     def __init__(self, **kwargs):\n    >>>         super().__init__(**kwargs)\n    >>>     def forward(self, graph, n_feat, e_feat):\n    >>>         with graph.local_scope():\n    >>>             graph.ndata['h'] = n_feat\n    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n    >>>             n_feat += graph.ndata['h']\n    >>>             graph.apply_edges(fn.u_add_v('h', 'h', 'e'))\n    >>>             e_feat += graph.edata['e']\n    >>>             return n_feat, e_feat\n    >>>\n    >>> g = dgl.DGLGraph()\n    >>> g.add_nodes(3)\n    >>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])\n    >>> net = Sequential()\n    >>> net.add(ExampleLayer())\n    >>> net.add(ExampleLayer())\n    >>> net.add(ExampleLayer())\n    >>> net.initialize()\n    >>> n_feat = nd.random.randn(3, 4)\n    >>> e_feat = nd.random.randn(9, 4)\n    >>> net(g, n_feat, e_feat)\n    (\n    [[ 12.412863   99.61184    21.472883  -57.625923 ]\n     [ 10.08097   100.68611    20.627377  -60.13458  ]\n     [ 11.7912245 101.80654    22.427956  -58.32772  ]]\n    <NDArray 3x4 @cpu(0)>,\n    [[  21.818504  198.12076    42.72387  -115.147736]\n     [  23.070837  195.49811    43.42292  
-116.17203 ]\n     [  24.330334  197.10927    42.40048  -118.06538 ]\n     [  21.907919  199.11469    42.1187   -115.35658 ]\n     [  22.849625  198.79213    43.866085 -113.65381 ]\n     [  20.926125  198.116      42.64334  -114.246704]\n     [  23.003159  197.06662    41.796425 -117.14977 ]\n     [  21.391375  198.3348     41.428078 -116.30361 ]\n     [  21.291483  200.0701     40.8239   -118.07314 ]]\n    <NDArray 9x4 @cpu(0)>)\n\n    Mode 2: sequentially apply GNN modules on different graphs\n\n    >>> import dgl\n    >>> from mxnet import nd\n    >>> from mxnet.gluon import nn\n    >>> import dgl.function as fn\n    >>> import networkx as nx\n    >>> from dgl.nn.mxnet import Sequential\n    >>> class ExampleLayer(nn.Block):\n    >>>     def __init__(self, **kwargs):\n    >>>         super().__init__(**kwargs)\n    >>>     def forward(self, graph, n_feat):\n    >>>         with graph.local_scope():\n    >>>             graph.ndata['h'] = n_feat\n    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n    >>>             n_feat += graph.ndata['h']\n    >>>             return n_feat.reshape(graph.num_nodes() // 2, 2, -1).sum(1)\n    >>>\n    >>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))\n    >>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))\n    >>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))\n    >>> net = Sequential()\n    >>> net.add(ExampleLayer())\n    >>> net.add(ExampleLayer())\n    >>> net.add(ExampleLayer())\n    >>> net.initialize()\n    >>> n_feat = nd.random.randn(32, 4)\n    >>> net([g1, g2, g3], n_feat)\n    [[-101.289566  -22.584694  -89.25348  -151.6447  ]\n     [-130.74239   -49.494812 -120.250854 -199.81546 ]\n     [-112.32089   -50.036713 -116.13266  -190.38638 ]\n     [-119.23065   -26.78553  -111.11185  -166.08322 ]]\n    <NDArray 4x4 @cpu(0)>\n    \"\"\"\n\n    def __init__(self, prefix=None, params=None):\n        super(Sequential, self).__init__(prefix=prefix, params=params)\n\n    def forward(self, 
graph, *feats):\n        r\"\"\"Sequentially apply modules to the input.\n\n        Parameters\n        ----------\n        graph : DGLGraph or list of DGLGraphs\n            The graph(s) to apply modules on.\n\n        *feats :\n            Input features.\n            The output of :math:`i`-th block should match that of the input\n            of :math:`(i+1)`-th block.\n        \"\"\"\n        if isinstance(graph, list):\n            for graph_i, module in zip(graph, self):\n                if not isinstance(feats, tuple):\n                    feats = (feats,)\n                feats = module(graph_i, *feats)\n        elif isinstance(graph, DGLGraph):\n            for module in self:\n                if not isinstance(feats, tuple):\n                    feats = (feats,)\n                feats = module(graph, *feats)\n        else:\n            raise TypeError(\n                \"The first argument of forward must be a DGLGraph\"\n                \" or a list of DGLGraph s\"\n            )\n        return feats\n"
  },
  {
    "path": "python/dgl/nn/pytorch/__init__.py",
    "content": "\"\"\"Package for pytorch-specific NN modules.\"\"\"\nfrom .conv import *\nfrom .explain import *\nfrom .link import *\nfrom .linear import *\nfrom .glob import *\nfrom .softmax import *\nfrom .factory import *\nfrom .hetero import *\nfrom .sparse_emb import NodeEmbedding\nfrom .utils import JumpingKnowledge, LabelPropagation, Sequential, WeightBasis\nfrom .network_emb import *\nfrom .gt import *\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/__init__.py",
    "content": "\"\"\"Torch modules for graph convolutions.\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\n\nfrom .agnnconv import AGNNConv\nfrom .appnpconv import APPNPConv\nfrom .atomicconv import AtomicConv\nfrom .cfconv import CFConv\nfrom .chebconv import ChebConv\nfrom .cugraph_gatconv import CuGraphGATConv\nfrom .cugraph_relgraphconv import CuGraphRelGraphConv\nfrom .cugraph_sageconv import CuGraphSAGEConv\nfrom .densechebconv import DenseChebConv\nfrom .densegraphconv import DenseGraphConv\nfrom .densesageconv import DenseSAGEConv\nfrom .dgnconv import DGNConv\nfrom .dotgatconv import DotGatConv\nfrom .edgeconv import EdgeConv\nfrom .edgegatconv import EdgeGATConv\nfrom .egatconv import EGATConv\nfrom .egnnconv import EGNNConv\nfrom .gatconv import GATConv\nfrom .gatedgcnconv import GatedGCNConv\nfrom .gatedgraphconv import GatedGraphConv\nfrom .gatv2conv import GATv2Conv\nfrom .gcn2conv import GCN2Conv\nfrom .ginconv import GINConv\nfrom .gineconv import GINEConv\nfrom .gmmconv import GMMConv\nfrom .graphconv import EdgeWeightNorm, GraphConv\nfrom .grouprevres import GroupRevRes\nfrom .hgtconv import HGTConv\nfrom .nnconv import NNConv\nfrom .pnaconv import PNAConv\nfrom .relgraphconv import RelGraphConv\nfrom .sageconv import SAGEConv\nfrom .sgconv import SGConv\nfrom .tagconv import TAGConv\nfrom .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention\n\n__all__ = [\n    \"GraphConv\",\n    \"EdgeWeightNorm\",\n    \"GATConv\",\n    \"GATv2Conv\",\n    \"EGATConv\",\n    \"EdgeGATConv\",\n    \"TAGConv\",\n    \"RelGraphConv\",\n    \"SAGEConv\",\n    \"SGConv\",\n    \"APPNPConv\",\n    \"GINConv\",\n    \"GINEConv\",\n    \"GatedGraphConv\",\n    \"GatedGCNConv\",\n    \"GMMConv\",\n    \"ChebConv\",\n    \"AGNNConv\",\n    \"NNConv\",\n    \"DenseGraphConv\",\n    \"DenseSAGEConv\",\n    \"DenseChebConv\",\n    \"EdgeConv\",\n    \"AtomicConv\",\n    \"CFConv\",\n    \"DotGatConv\",\n    \"TWIRLSConv\",\n    
\"TWIRLSUnfoldingAndAttention\",\n    \"GCN2Conv\",\n    \"HGTConv\",\n    \"GroupRevRes\",\n    \"EGNNConv\",\n    \"PNAConv\",\n    \"DGNConv\",\n    \"CuGraphGATConv\",\n    \"CuGraphRelGraphConv\",\n    \"CuGraphSAGEConv\",\n]\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/agnnconv.py",
    "content": "\"\"\"Torch Module for Attention-based Graph Neural Network layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\nfrom ...functional import edge_softmax\n\n\nclass AGNNConv(nn.Module):\n    r\"\"\"Attention-based Graph Neural Network layer from `Attention-based Graph Neural Network for\n    Semi-Supervised Learning <https://arxiv.org/abs/1803.03735>`__\n\n    .. math::\n        H^{l+1} = P H^{l}\n\n    where :math:`P` is computed as:\n\n    .. math::\n        P_{ij} = \\mathrm{softmax}_i ( \\beta \\cdot \\cos(h_i^l, h_j^l))\n\n    where :math:`\\beta` is a single scalar parameter.\n\n    Parameters\n    ----------\n    init_beta : float, optional\n        The :math:`\\beta` in the formula, a single scalar parameter.\n    learn_beta : bool, optional\n        If True, :math:`\\beta` will be learnable parameter.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Default: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be appied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... 
# a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import AGNNConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = th.ones(6, 10)\n    >>> conv = AGNNConv()\n    >>> res = conv(g, feat)\n    >>> res\n    tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n            [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n            [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n            [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n            [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n            [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],\n        grad_fn=<BinaryReduceBackward>)\n    \"\"\"\n\n    def __init__(\n        self, init_beta=1.0, learn_beta=True, allow_zero_in_degree=False\n    ):\n        super(AGNNConv, self).__init__()\n        self._allow_zero_in_degree = allow_zero_in_degree\n        if learn_beta:\n            self.beta = nn.Parameter(th.Tensor([init_beta]))\n        else:\n            self.register_buffer(\"beta\", th.Tensor([init_beta]))\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute AGNN 
layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            The input feature of shape :math:`(N, *)` :math:`N` is the\n            number of nodes, and :math:`*` could be of any shape.\n            If a pair of torch.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, *)` and :math:`(N_{out}, *)`, the :math:`*` in the later\n            tensor must equal the previous one.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, *)` where :math:`*`\n            should be the same as input shape.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. 
Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            feat_src, feat_dst = expand_as_pair(feat, graph)\n\n            graph.srcdata[\"h\"] = feat_src\n            graph.srcdata[\"norm_h\"] = F.normalize(feat_src, p=2, dim=-1)\n            if isinstance(feat, tuple) or graph.is_block:\n                graph.dstdata[\"norm_h\"] = F.normalize(feat_dst, p=2, dim=-1)\n            # compute cosine distance\n            graph.apply_edges(fn.u_dot_v(\"norm_h\", \"norm_h\", \"cos\"))\n            cos = graph.edata.pop(\"cos\")\n            e = self.beta * cos\n            graph.edata[\"p\"] = edge_softmax(graph, e)\n            graph.update_all(fn.u_mul_e(\"h\", \"p\", \"m\"), fn.sum(\"m\", \"h\"))\n            return graph.dstdata.pop(\"h\")\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/appnpconv.py",
    "content": "\"\"\"Torch Module for APPNPConv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\n\nfrom .... import function as fn\nfrom .graphconv import EdgeWeightNorm\n\n\nclass APPNPConv(nn.Module):\n    r\"\"\"Approximate Personalized Propagation of Neural Predictions layer from `Predict then\n    Propagate: Graph Neural Networks meet Personalized PageRank\n    <https://arxiv.org/pdf/1810.05997.pdf>`__\n\n    .. math::\n        H^{0} &= X\n\n        H^{l+1} &= (1-\\alpha)\\left(\\tilde{D}^{-1/2}\n        \\tilde{A} \\tilde{D}^{-1/2} H^{l}\\right) + \\alpha H^{0}\n\n    where :math:`\\tilde{A}` is :math:`A` + :math:`I`.\n\n    Parameters\n    ----------\n    k : int\n        The number of iterations :math:`K`.\n    alpha : float\n        The teleport probability :math:`\\alpha`.\n    edge_drop : float, optional\n        The dropout rate on edges that controls the\n        messages received by each node. Default: ``0``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import APPNPConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = th.ones(6, 10)\n    >>> conv = APPNPConv(k=3, alpha=0.5)\n    >>> res = conv(g, feat)\n    >>> print(res)\n    tensor([[0.8536, 0.8536, 0.8536, 0.8536, 0.8536, 0.8536, 0.8536, 0.8536, 0.8536,\n            0.8536],\n            [0.9268, 0.9268, 0.9268, 0.9268, 0.9268, 0.9268, 0.9268, 0.9268, 0.9268,\n            0.9268],\n            [0.9634, 0.9634, 0.9634, 0.9634, 0.9634, 0.9634, 0.9634, 0.9634, 0.9634,\n            0.9634],\n            [0.9268, 0.9268, 0.9268, 0.9268, 0.9268, 0.9268, 0.9268, 0.9268, 0.9268,\n            0.9268],\n            [0.9634, 0.9634, 0.9634, 0.9634, 0.9634, 0.9634, 0.9634, 0.9634, 0.9634,\n            0.9634],\n            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,\n            0.5000]])\n    \"\"\"\n\n    def 
__init__(self, k, alpha, edge_drop=0.0):\n        super(APPNPConv, self).__init__()\n        self._k = k\n        self._alpha = alpha\n        self.edge_drop = nn.Dropout(edge_drop)\n\n    def forward(self, graph, feat, edge_weight=None):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute APPNP layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            The input feature of shape :math:`(N, *)`. :math:`N` is the\n            number of nodes, and :math:`*` could be of any shape.\n        edge_weight: torch.Tensor, optional\n            edge_weight to use in the message passing process. This is equivalent to\n            using weighted adjacency matrix in the equation above, and\n            :math:`\\tilde{D}^{-1/2}\\tilde{A} \\tilde{D}^{-1/2}`\n            is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, *)` where :math:`*`\n            should be the same as input shape.\n        \"\"\"\n        with graph.local_scope():\n            if edge_weight is None:\n                src_norm = th.pow(\n                    graph.out_degrees().to(feat).clamp(min=1), -0.5\n                )\n                shp = src_norm.shape + (1,) * (feat.dim() - 1)\n                src_norm = th.reshape(src_norm, shp).to(feat.device)\n                dst_norm = th.pow(\n                    graph.in_degrees().to(feat).clamp(min=1), -0.5\n                )\n                shp = dst_norm.shape + (1,) * (feat.dim() - 1)\n                dst_norm = th.reshape(dst_norm, shp).to(feat.device)\n            else:\n                edge_weight = EdgeWeightNorm(\"both\")(graph, edge_weight)\n            feat_0 = feat\n            for _ in range(self._k):\n                # normalization by src node\n                if edge_weight is None:\n                    feat = 
feat * src_norm\n                graph.ndata[\"h\"] = feat\n                w = (\n                    th.ones(graph.num_edges(), 1)\n                    if edge_weight is None\n                    else edge_weight\n                )\n                graph.edata[\"w\"] = self.edge_drop(w).to(feat.device)\n                graph.update_all(fn.u_mul_e(\"h\", \"w\", \"m\"), fn.sum(\"m\", \"h\"))\n                feat = graph.ndata.pop(\"h\")\n                # normalization by dst node\n                if edge_weight is None:\n                    feat = feat * dst_norm\n                feat = (1 - self._alpha) * feat + self._alpha * feat_0\n            return feat\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/atomicconv.py",
    "content": "\"\"\"Torch Module for Atomic Convolution Layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\n\n\nclass RadialPooling(nn.Module):\n    r\"\"\"Radial pooling from `Atomic Convolutional Networks for\n    Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__\n\n    We denote the distance between atom :math:`i` and :math:`j` by :math:`r_{ij}`.\n\n    A radial pooling layer transforms distances with radial filters. For radial filter\n    indexed by :math:`k`, it projects edge distances with\n\n    .. math::\n        h_{ij}^{k} = \\exp(-\\gamma_{k}|r_{ij}-r_{k}|^2)\n\n    If :math:`r_{ij} < c_k`,\n\n    .. math::\n        f_{ij}^{k} = 0.5 * \\cos(\\frac{\\pi r_{ij}}{c_k} + 1),\n\n    else,\n\n    .. math::\n        f_{ij}^{k} = 0.\n\n    Finally,\n\n    .. math::\n        e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}\n\n    Parameters\n    ----------\n    interaction_cutoffs : float32 tensor of shape (K)\n        :math:`c_k` in the equations above. Roughly they can be considered as learnable cutoffs\n        and two atoms are considered as connected if the distance between them is smaller than\n        the cutoffs. K for the number of radial filters.\n    rbf_kernel_means : float32 tensor of shape (K)\n        :math:`r_k` in the equations above. K for the number of radial filters.\n    rbf_kernel_scaling : float32 tensor of shape (K)\n        :math:`\\gamma_k` in the equations above. 
K for the number of radial filters.\n    \"\"\"\n\n    def __init__(\n        self, interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling\n    ):\n        super(RadialPooling, self).__init__()\n\n        self.interaction_cutoffs = nn.Parameter(\n            interaction_cutoffs.reshape(-1, 1, 1), requires_grad=True\n        )\n        self.rbf_kernel_means = nn.Parameter(\n            rbf_kernel_means.reshape(-1, 1, 1), requires_grad=True\n        )\n        self.rbf_kernel_scaling = nn.Parameter(\n            rbf_kernel_scaling.reshape(-1, 1, 1), requires_grad=True\n        )\n\n    def forward(self, distances):\n        \"\"\"\n\n        Description\n        -----------\n        Apply the layer to transform edge distances.\n\n        Parameters\n        ----------\n        distances : Float32 tensor of shape (E, 1)\n            Distance between end nodes of edges. E for the number of edges.\n\n        Returns\n        -------\n        Float32 tensor of shape (K, E, 1)\n            Transformed edge distances. K for the number of radial filters.\n        \"\"\"\n        scaled_euclidean_distance = (\n            -self.rbf_kernel_scaling * (distances - self.rbf_kernel_means) ** 2\n        )  # (K, E, 1)\n        rbf_kernel_results = th.exp(scaled_euclidean_distance)  # (K, E, 1)\n\n        cos_values = 0.5 * (\n            th.cos(np.pi * distances / self.interaction_cutoffs) + 1\n        )  # (K, E, 1)\n        cutoff_values = th.where(\n            distances <= self.interaction_cutoffs,\n            cos_values,\n            th.zeros_like(cos_values),\n        )  # (K, E, 1)\n\n        # Note that there appears to be an inconsistency between the paper and\n        # DeepChem's implementation. In the paper, the scaled_euclidean_distance first\n        # gets multiplied by cutoff_values, followed by exponentiation. 
Here we follow\n        # the practice of DeepChem.\n        return rbf_kernel_results * cutoff_values\n\n\ndef msg_func(edges):\n    \"\"\"\n\n    Description\n    -----------\n    Send messages along edges.\n\n    Parameters\n    ----------\n    edges : EdgeBatch\n        A batch of edges.\n\n    Returns\n    -------\n    dict mapping 'm' to Float32 tensor of shape (E, K * T)\n        Messages computed. E for the number of edges, K for the number of\n        radial filters and T for the number of features to use\n        (types of atomic number in the paper).\n    \"\"\"\n    return {\n        \"m\": th.einsum(\"ij,ik->ijk\", edges.src[\"hv\"], edges.data[\"he\"]).view(\n            len(edges), -1\n        )\n    }\n\n\ndef reduce_func(nodes):\n    \"\"\"\n\n    Description\n    -----------\n    Collect messages and update node representations.\n\n    Parameters\n    ----------\n    nodes : NodeBatch\n        A batch of nodes.\n\n    Returns\n    -------\n    dict mapping 'hv_new' to Float32 tensor of shape (V, K * T)\n        Updated node representations. V for the number of nodes, K for the number of\n        radial filters and T for the number of features to use\n        (types of atomic number in the paper).\n    \"\"\"\n    return {\"hv_new\": nodes.mailbox[\"m\"].sum(1)}\n\n\nclass AtomicConv(nn.Module):\n    r\"\"\"Atomic Convolution Layer from `Atomic Convolutional Networks for\n    Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__\n\n    Denoting the type of atom :math:`i` by :math:`z_i` and the distance between atom\n    :math:`i` and :math:`j` by :math:`r_{ij}`.\n\n    **Distance Transformation**\n\n    An atomic convolution layer first transforms distances with radial filters and\n    then perform a pooling operation.\n\n    For radial filter indexed by :math:`k`, it projects edge distances with\n\n    .. math::\n        h_{ij}^{k} = \\exp(-\\gamma_{k}|r_{ij}-r_{k}|^2)\n\n    If :math:`r_{ij} < c_k`,\n\n    .. 
math::\n        f_{ij}^{k} = 0.5 * (\\cos(\\frac{\\pi r_{ij}}{c_k}) + 1),\n\n    else,\n\n    .. math::\n        f_{ij}^{k} = 0.\n\n    Finally,\n\n    .. math::\n        e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}\n\n    **Aggregation**\n\n    For each type :math:`t`, each atom collects distance information from all neighbor atoms\n    of type :math:`t`:\n\n    .. math::\n        p_{i, t}^{k} = \\sum_{j\\in N(i)} e_{ij}^{k} * 1(z_j == t)\n\n    Then concatenate the results for all RBF kernels and atom types.\n\n    Parameters\n    ----------\n    interaction_cutoffs : float32 tensor of shape (K)\n        :math:`c_k` in the equations above. Roughly they can be considered as learnable cutoffs\n        and two atoms are considered as connected if the distance between them is smaller than\n        the cutoffs. K for the number of radial filters.\n    rbf_kernel_means : float32 tensor of shape (K)\n        :math:`r_k` in the equations above. K for the number of radial filters.\n    rbf_kernel_scaling : float32 tensor of shape (K)\n        :math:`\\gamma_k` in the equations above. K for the number of radial filters.\n    features_to_use : None or float tensor of shape (T)\n        In the original paper, these are atomic numbers to consider, representing the types\n        of atoms. T for the number of types of atomic numbers. Default to None.\n\n    Note\n    ----\n\n    * This convolution operation is designed for molecular graphs in Chemistry, but it might\n      be possible to extend it to more general graphs.\n\n    * There seems to be an inconsistency about the definition of :math:`e_{ij}^{k}` in the\n      paper and the author's implementation. We follow the author's implementation. 
In the\n      paper, :math:`e_{ij}^{k}` was defined as\n      :math:`\\exp(-\\gamma_{k}|r_{ij}-r_{k}|^2 * f_{ij}^{k})`.\n\n    * :math:`\\gamma_{k}`, :math:`r_k` and :math:`c_k` are all learnable.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import AtomicConv\n\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = th.ones(6, 1)\n    >>> edist = th.ones(6, 1)\n    >>> interaction_cutoffs = th.ones(3).float() * 2\n    >>> rbf_kernel_means = th.ones(3).float()\n    >>> rbf_kernel_scaling = th.ones(3).float()\n    >>> conv = AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling)\n    >>> res = conv(g, feat, edist)\n    >>> res\n    tensor([[0.5000, 0.5000, 0.5000],\n                [0.5000, 0.5000, 0.5000],\n                [0.5000, 0.5000, 0.5000],\n                [1.0000, 1.0000, 1.0000],\n                [0.5000, 0.5000, 0.5000],\n                [0.0000, 0.0000, 0.0000]], grad_fn=<ViewBackward>)\n    \"\"\"\n\n    def __init__(\n        self,\n        interaction_cutoffs,\n        rbf_kernel_means,\n        rbf_kernel_scaling,\n        features_to_use=None,\n    ):\n        super(AtomicConv, self).__init__()\n\n        self.radial_pooling = RadialPooling(\n            interaction_cutoffs=interaction_cutoffs,\n            rbf_kernel_means=rbf_kernel_means,\n            rbf_kernel_scaling=rbf_kernel_scaling,\n        )\n        if features_to_use is None:\n            self.num_channels = 1\n            self.features_to_use = None\n        else:\n            self.num_channels = len(features_to_use)\n            self.features_to_use = nn.Parameter(\n                features_to_use, requires_grad=False\n            )\n\n    def forward(self, graph, feat, distances):\n        \"\"\"\n\n        Description\n        -----------\n        Apply the atomic convolution layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            Topology 
based on which message passing is performed.\n        feat : Float32 tensor of shape :math:`(V, 1)`\n            Initial node features, which are atomic numbers in the paper.\n            :math:`V` for the number of nodes.\n        distances : Float32 tensor of shape :math:`(E, 1)`\n            Distance between end nodes of edges. E for the number of edges.\n\n        Returns\n        -------\n        Float32 tensor of shape :math:`(V, K * T)`\n            Updated node representations. :math:`V` for the number of nodes, :math:`K` for the\n            number of radial filters, and :math:`T` for the number of types of atomic numbers.\n        \"\"\"\n        with graph.local_scope():\n            radial_pooled_values = self.radial_pooling(distances).to(\n                feat\n            )  # (K, E, 1)\n            if self.features_to_use is not None:\n                feat = (feat == self.features_to_use).to(feat)  # (V, T)\n            graph.ndata[\"hv\"] = feat\n            graph.edata[\"he\"] = radial_pooled_values.transpose(1, 0).squeeze(\n                -1\n            )  # (E, K)\n            graph.update_all(msg_func, reduce_func)\n\n            return graph.ndata[\"hv_new\"].view(\n                graph.num_nodes(), -1\n            )  # (V, K * T)\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/cfconv.py",
    "content": "\"\"\"Torch modules for interaction blocks in SchNet\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport numpy as np\nimport torch.nn as nn\n\nfrom .... import function as fn\n\n\nclass ShiftedSoftplus(nn.Module):\n    r\"\"\"Applies the element-wise function:\n\n    .. math::\n        \\text{SSP}(x) = \\frac{1}{\\beta} * \\log(1 + \\exp(\\beta * x)) - \\log(\\text{shift})\n\n    Attributes\n    ----------\n    beta : int\n        :math:`\\beta` value for the mathematical formulation. Default to 1.\n    shift : int\n        :math:`\\text{shift}` value for the mathematical formulation. Default to 2.\n    \"\"\"\n\n    def __init__(self, beta=1, shift=2, threshold=20):\n        super(ShiftedSoftplus, self).__init__()\n\n        self.shift = shift\n        self.softplus = nn.Softplus(beta=beta, threshold=threshold)\n\n    def forward(self, inputs):\n        \"\"\"\n\n        Description\n        -----------\n        Applies the activation function.\n\n        Parameters\n        ----------\n        inputs : float32 tensor of shape (N, *)\n            * denotes any number of additional dimensions.\n\n        Returns\n        -------\n        float32 tensor of shape (N, *)\n            Result of applying the activation function to the input.\n        \"\"\"\n        return self.softplus(inputs) - np.log(float(self.shift))\n\n\nclass CFConv(nn.Module):\n    r\"\"\"CFConv from `SchNet: A continuous-filter convolutional neural network for\n    modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__\n\n    It combines node and edge features in message passing and updates node representations.\n\n    .. math::\n        h_i^{(l+1)} = \\sum_{j\\in \\mathcal{N}(i)} h_j^{l} \\circ W^{(l)}e_ij\n\n    where :math:`\\circ` represents element-wise multiplication and for :math:`\\text{SPP}` :\n\n    .. 
math::\n        \\text{SSP}(x) = \\frac{1}{\\beta} * \\log(1 + \\exp(\\beta * x)) - \\log(\\text{shift})\n\n    Parameters\n    ----------\n    node_in_feats : int\n        Size for the input node features :math:`h_j^{(l)}`.\n    edge_in_feats : int\n        Size for the input edge features :math:`e_ij`.\n    hidden_feats : int\n        Size for the hidden representations.\n    out_feats : int\n        Size for the output representations :math:`h_j^{(l+1)}`.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import CFConv\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> nfeat = th.ones(6, 10)\n    >>> efeat = th.ones(6, 5)\n    >>> conv = CFConv(10, 5, 3, 2)\n    >>> res = conv(g, nfeat, efeat)\n    >>> res\n    tensor([[-0.1209, -0.2289],\n            [-0.1209, -0.2289],\n            [-0.1209, -0.2289],\n            [-0.1135, -0.2338],\n            [-0.1209, -0.2289],\n            [-0.1283, -0.2240]], grad_fn=<SubBackward0>)\n    \"\"\"\n\n    def __init__(self, node_in_feats, edge_in_feats, hidden_feats, out_feats):\n        super(CFConv, self).__init__()\n\n        self.project_edge = nn.Sequential(\n            nn.Linear(edge_in_feats, hidden_feats),\n            ShiftedSoftplus(),\n            nn.Linear(hidden_feats, hidden_feats),\n            ShiftedSoftplus(),\n        )\n        self.project_node = nn.Linear(node_in_feats, hidden_feats)\n        self.project_out = nn.Sequential(\n            nn.Linear(hidden_feats, out_feats), ShiftedSoftplus()\n        )\n\n    def forward(self, g, node_feats, edge_feats):\n        \"\"\"\n\n        Description\n        -----------\n        Performs message passing and updates node representations.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        node_feats : torch.Tensor or pair of torch.Tensor\n            The input node features. 
If a torch.Tensor is given, it represents the input\n            node feature of shape :math:`(N, D_{in})` where :math:`D_{in}` is size of\n            input feature, :math:`N` is the number of nodes.\n            If a pair of torch.Tensor is given, which is the case for bipartite graph,\n            the pair must contain two tensors of shape :math:`(N_{src}, D_{in_{src}})` and\n            :math:`(N_{dst}, D_{in_{dst}})` separately for the source and destination nodes.\n\n        edge_feats : torch.Tensor\n            The input edge feature of shape :math:`(E, edge_in_feats)`\n            where :math:`E` is the number of edges.\n\n        Returns\n        -------\n        torch.Tensor\n            The output node feature of shape :math:`(N_{out}, out_feats)`\n            where :math:`N_{out}` is the number of destination nodes.\n        \"\"\"\n        with g.local_scope():\n            if isinstance(node_feats, tuple):\n                node_feats_src, _ = node_feats\n            else:\n                node_feats_src = node_feats\n            g.srcdata[\"hv\"] = self.project_node(node_feats_src)\n            g.edata[\"he\"] = self.project_edge(edge_feats)\n            g.update_all(fn.u_mul_e(\"hv\", \"he\", \"m\"), fn.sum(\"m\", \"h\"))\n            return self.project_out(g.dstdata[\"h\"])\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/chebconv.py",
    "content": "\"\"\"Torch Module for Chebyshev Spectral Graph Convolution layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom .... import broadcast_nodes, function as fn\nfrom ....base import dgl_warning\n\n\nclass ChebConv(nn.Module):\n    r\"\"\"Chebyshev Spectral Graph Convolution layer from `Convolutional\n    Neural Networks on Graphs with Fast Localized Spectral Filtering\n    <https://arxiv.org/pdf/1606.09375.pdf>`__\n\n    .. math::\n        h_i^{l+1} &= \\sum_{k=0}^{K-1} W^{k, l}z_i^{k, l}\n\n        Z^{0, l} &= H^{l}\n\n        Z^{1, l} &= \\tilde{L} \\cdot H^{l}\n\n        Z^{k, l} &= 2 \\cdot \\tilde{L} \\cdot Z^{k-1, l} - Z^{k-2, l}\n\n        \\tilde{L} &= 2\\left(I - \\tilde{D}^{-1/2} \\tilde{A} \\tilde{D}^{-1/2}\\right)/\\lambda_{max} - I\n\n    where :math:`\\tilde{A}` is :math:`A` + :math:`I`, :math:`W` is learnable weight.\n\n\n    Parameters\n    ----------\n    in_feats: int\n        Dimension of input features; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n    out_feats: int\n        Dimension of output features :math:`h_i^{(l+1)}`.\n    k : int\n        Chebyshev filter size :math:`K`.\n    activation : function, optional\n        Activation function. Default ``ReLu``.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import ChebConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = th.ones(6, 10)\n    >>> conv = ChebConv(10, 2, 2)\n    >>> res = conv(g, feat)\n    >>> res\n    tensor([[ 0.6163, -0.1809],\n            [ 0.6163, -0.1809],\n            [ 0.6163, -0.1809],\n            [ 0.9698, -1.5053],\n            [ 0.3664,  0.7556],\n            [-0.2370,  3.0164]], grad_fn=<AddBackward0>)\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, k, activation=F.relu, bias=True):\n        super(ChebConv, self).__init__()\n        self._k = k\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self.activation = activation\n        self.linear = nn.Linear(k * in_feats, out_feats, bias)\n\n    def forward(self, graph, feat, lambda_max=None):\n        r\"\"\"Compute ChebNet layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`\n            is size of input feature, :math:`N` is the number of nodes.\n        lambda_max : list or tensor or None, optional.\n            A list(tensor) with length :math:`B`, stores the largest eigenvalue\n            of the normalized laplacian of each individual graph in ``graph``,\n            where :math:`B` is the batch size of the input graph. 
Default: None.\n\n            If None, this method would set the default value to 2.\n            One can use :func:`dgl.laplacian_lambda_max` to compute this value.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n\n        def unnLaplacian(feat, D_invsqrt, graph):\n            \"\"\"Operation Feat * D^-1/2 A D^-1/2\"\"\"\n            graph.ndata[\"h\"] = feat * D_invsqrt\n            graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n            return graph.ndata.pop(\"h\") * D_invsqrt\n\n        with graph.local_scope():\n            D_invsqrt = th.pow(\n                graph.in_degrees().to(feat).clamp(min=1), -0.5\n            ).unsqueeze(-1)\n\n            if lambda_max is None:\n                dgl_warning(\n                    \"lambda_max is not provided, using default value of 2.  \"\n                    \"Please use dgl.laplacian_lambda_max to compute the eigenvalues.\"\n                )\n                lambda_max = [2] * graph.batch_size\n\n            if isinstance(lambda_max, list):\n                lambda_max = th.Tensor(lambda_max).to(feat)\n            if lambda_max.dim() == 1:\n                lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)\n\n            # broadcast from (B, 1) to (N, 1)\n            lambda_max = broadcast_nodes(graph, lambda_max)\n            re_norm = 2.0 / lambda_max\n\n            # X_0 is the raw feature, Xt is the list of X_0, X_1, ... 
X_t\n            X_0 = feat\n            Xt = [X_0]\n\n            # X_1(f)\n            if self._k > 1:\n                h = unnLaplacian(X_0, D_invsqrt, graph)\n                X_1 = -re_norm * h + X_0 * (re_norm - 1)\n                # Append X_1 to Xt\n                Xt.append(X_1)\n\n            # X_i(x), i = 2...k-1\n            for _ in range(2, self._k):\n                h = unnLaplacian(X_1, D_invsqrt, graph)\n                X_i = -2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0\n                # Append X_i to Xt\n                Xt.append(X_i)\n                X_1, X_0 = X_i, X_1\n\n            # Create the concatenation\n            Xt = th.cat(Xt, dim=1)\n\n            # linear projection\n            h = self.linear(Xt)\n\n            # activation\n            if self.activation:\n                h = self.activation(h)\n\n        return h\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/cugraph_base.py",
    "content": "\"\"\"An abstract base class for cugraph-ops nn module.\"\"\"\nimport torch\nfrom torch import nn\n\n\nclass CuGraphBaseConv(nn.Module):\n    r\"\"\"An abstract base class for cugraph-ops nn module.\"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self._cached_offsets_fg = None\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        raise NotImplementedError\n\n    def forward(self, *args):\n        r\"\"\"Runs the forward pass of the module.\"\"\"\n        raise NotImplementedError\n\n    def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor:\n        r\"\"\"Pad zero-in-degree nodes to the end of offsets to reach size.\n\n        cugraph-ops often provides two variants of aggregation functions for a\n        specific model: one intended for sampled-graph use cases, one for\n        full-graph ones. The former is in general more performant, however, it\n        only works when the sample size (the max of in-degrees) is small (<200),\n        due to the limit of GPU shared memory. For graphs with a larger max\n        in-degree, we need to fall back to the full-graph option, which requires\n        to convert a DGL block to a full graph. 
With the csc-representation,\n        this is equivalent to pad zero-in-degree nodes to the end of the offsets\n        array (also called indptr or colptr).\n\n        Parameters\n        ----------\n        offsets :\n            The (monotonically increasing) index pointer array in a CSC-format\n            graph.\n        size : int\n            The length of offsets after padding.\n\n        Returns\n        -------\n        torch.Tensor\n            The augmented offsets array.\n        \"\"\"\n        if self._cached_offsets_fg is None:\n            self._cached_offsets_fg = torch.empty(\n                size, dtype=offsets.dtype, device=offsets.device\n            )\n        elif self._cached_offsets_fg.numel() < size:\n            self._cached_offsets_fg.resize_(size)\n\n        self._cached_offsets_fg[: offsets.numel()] = offsets\n        self._cached_offsets_fg[offsets.numel() : size] = offsets[-1]\n\n        return self._cached_offsets_fg[:size]\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/cugraph_gatconv.py",
    "content": "\"\"\"Torch Module for graph attention network layer using the aggregation\nprimitives in cugraph-ops\"\"\"\n# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments\n\nimport torch\nfrom torch import nn\n\nfrom .cugraph_base import CuGraphBaseConv\n\ntry:\n    from pylibcugraphops.pytorch import SampledCSC, StaticCSC\n    from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg\n\n    HAS_PYLIBCUGRAPHOPS = True\nexcept ImportError:\n    HAS_PYLIBCUGRAPHOPS = False\n\n\nclass CuGraphGATConv(CuGraphBaseConv):\n    r\"\"\"Graph attention layer from `Graph Attention Networks\n    <https://arxiv.org/pdf/1710.10903.pdf>`__, with the sparse aggregation\n    accelerated by cugraph-ops.\n\n    See :class:`dgl.nn.pytorch.conv.GATConv` for mathematical model.\n\n    This module depends on :code:`pylibcugraphops` package, which can be\n    installed via :code:`conda install -c nvidia pylibcugraphops=23.04`.\n    :code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x.\n\n    .. note::\n        This is an **experimental** feature.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size.\n    out_feats : int\n        Output feature size.\n    num_heads : int\n        Number of heads in Multi-Head Attention.\n    feat_drop : float, optional\n        Dropout rate on feature. Defaults: ``0``.\n    negative_slope : float, optional\n        LeakyReLU angle of negative slope. Defaults: ``0.2``.\n    residual : bool, optional\n        If True, use residual connection. Defaults: ``False``.\n    activation : callable activation function/layer or None, optional.\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n    bias : bool, optional\n        If True, learns a bias term. 
Defaults: ``True``.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n    >>> from dgl.nn import CuGraphGATConv\n    >>> device = 'cuda'\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device)\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = torch.ones(6, 10).to(device)\n    >>> conv = CuGraphGATConv(10, 2, num_heads=3).to(device)\n    >>> res = conv(g, feat)\n    >>> res\n    tensor([[[ 0.2340,  1.9226],\n            [ 1.6477, -1.9986],\n            [ 1.1138, -1.9302]],\n            [[ 0.2340,  1.9226],\n            [ 1.6477, -1.9986],\n            [ 1.1138, -1.9302]],\n            [[ 0.2340,  1.9226],\n            [ 1.6477, -1.9986],\n            [ 1.1138, -1.9302]],\n            [[ 0.2340,  1.9226],\n            [ 1.6477, -1.9986],\n            [ 1.1138, -1.9302]],\n            [[ 0.2340,  1.9226],\n            [ 1.6477, -1.9986],\n            [ 1.1138, -1.9302]],\n            [[ 0.2340,  1.9226],\n            [ 1.6477, -1.9986],\n            [ 1.1138, -1.9302]]], device='cuda:0', grad_fn=<ViewBackward0>)\n    \"\"\"\n    MAX_IN_DEGREE_MFG = 200\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        num_heads,\n        feat_drop=0.0,\n        negative_slope=0.2,\n        residual=False,\n        activation=None,\n        bias=True,\n    ):\n        if HAS_PYLIBCUGRAPHOPS is False:\n            raise ModuleNotFoundError(\n                f\"{self.__class__.__name__} requires pylibcugraphops=23.04. 
\"\n                f\"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`.\"\n                f\"pylibcugraphops requires Python 3.8 or 3.10.\"\n            )\n        super().__init__()\n        self.in_feats = in_feats\n        self.out_feats = out_feats\n        self.num_heads = num_heads\n        self.feat_drop = nn.Dropout(feat_drop)\n        self.negative_slope = negative_slope\n        self.activation = activation\n\n        self.fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)\n        self.attn_weights = nn.Parameter(\n            torch.Tensor(2 * num_heads * out_feats)\n        )\n\n        if bias:\n            self.bias = nn.Parameter(torch.Tensor(num_heads * out_feats))\n        else:\n            self.register_buffer(\"bias\", None)\n\n        if residual:\n            if in_feats == out_feats * num_heads:\n                self.res_fc = nn.Identity()\n            else:\n                self.res_fc = nn.Linear(\n                    in_feats, out_feats * num_heads, bias=False\n                )\n        else:\n            self.register_buffer(\"res_fc\", None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Reinitialize learnable parameters.\"\"\"\n\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_normal_(self.fc.weight, gain=gain)\n        nn.init.xavier_normal_(\n            self.attn_weights.view(2, self.num_heads, self.out_feats), gain=gain\n        )\n        if self.bias is not None:\n            nn.init.zeros_(self.bias)\n\n        if isinstance(self.res_fc, nn.Linear):\n            self.res_fc.reset_parameters()\n\n    def forward(self, g, feat, max_in_degree=None):\n        r\"\"\"Forward computation.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            Input features of shape :math:`(N, D_{in})`.\n        max_in_degree : int\n            Maximum in-degree of destination nodes. 
It is only effective when\n            :attr:`g` is a :class:`DGLBlock`, i.e., bipartite graph. When\n            :attr:`g` is generated from a neighbor sampler, the value should be\n            set to the corresponding :attr:`fanout`. If not given,\n            :attr:`max_in_degree` will be calculated on-the-fly.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, H, D_{out})` where\n            :math:`H` is the number of heads, and :math:`D_{out}` is size of\n            output feature.\n        \"\"\"\n        offsets, indices, _ = g.adj_tensors(\"csc\")\n\n        if g.is_block:\n            if max_in_degree is None:\n                max_in_degree = g.in_degrees().max().item()\n\n            if max_in_degree < self.MAX_IN_DEGREE_MFG:\n                _graph = SampledCSC(\n                    offsets,\n                    indices,\n                    max_in_degree,\n                    g.num_src_nodes(),\n                )\n            else:\n                offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)\n                _graph = StaticCSC(offsets_fg, indices)\n        else:\n            _graph = StaticCSC(offsets, indices)\n\n        feat = self.feat_drop(feat)\n        feat_transformed = self.fc(feat)\n        out = GATConvAgg(\n            feat_transformed,\n            self.attn_weights,\n            _graph,\n            self.num_heads,\n            \"LeakyReLU\",\n            self.negative_slope,\n            concat_heads=True,\n        )[: g.num_dst_nodes()].view(-1, self.num_heads, self.out_feats)\n\n        feat_dst = feat[: g.num_dst_nodes()]\n        if self.res_fc is not None:\n            out = out + self.res_fc(feat_dst).view(\n                -1, self.num_heads, self.out_feats\n            )\n\n        if self.bias is not None:\n            out = out + self.bias.view(-1, self.num_heads, self.out_feats)\n\n        if self.activation is not None:\n            out = 
self.activation(out)\n\n        return out\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py",
    "content": "\"\"\"Torch Module for Relational graph convolution layer using the aggregation\nprimitives in cugraph-ops\"\"\"\n# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments\nimport math\n\nimport torch\nfrom torch import nn\n\nfrom .cugraph_base import CuGraphBaseConv\n\ntry:\n    from pylibcugraphops.pytorch import HeteroCSC\n    from pylibcugraphops.pytorch.operators import (\n        agg_hg_basis_n2n_post as RelGraphConvAgg,\n    )\n\n    HAS_PYLIBCUGRAPHOPS = True\nexcept ImportError:\n    HAS_PYLIBCUGRAPHOPS = False\n\n\nclass CuGraphRelGraphConv(CuGraphBaseConv):\n    r\"\"\"An accelerated relational graph convolution layer from `Modeling\n    Relational Data with Graph Convolutional Networks\n    <https://arxiv.org/abs/1703.06103>`__ that leverages the highly-optimized\n    aggregation primitives in cugraph-ops.\n\n    See :class:`dgl.nn.pytorch.conv.RelGraphConv` for mathematical model.\n\n    This module depends on :code:`pylibcugraphops` package, which can be\n    installed via :code:`conda install -c nvidia pylibcugraphops=23.04`.\n    :code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x.\n\n    .. note::\n        This is an **experimental** feature.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size.\n    out_feat : int\n        Output feature size.\n    num_rels : int\n        Number of relations.\n    regularizer : str, optional\n        Which weight regularizer to use (\"basis\" or ``None``):\n         - \"basis\" is for basis-decomposition.\n         - ``None`` applies no regularization.\n        Default: ``None``.\n    num_bases : int, optional\n        Number of bases. It comes into effect when a regularizer is applied.\n        Default: ``None``.\n    bias : bool, optional\n        True if bias is added. Default: ``True``.\n    self_loop : bool, optional\n        True to include self loop message. Default: ``True``.\n    dropout : float, optional\n        Dropout rate. 
Default: ``0.0``.\n    apply_norm : bool, optional\n        True to normalize aggregation output by the in-degree of the destination\n        node per edge type, i.e. :math:`|\\mathcal{N}^r_i|`. Default: ``False``.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n    >>> from dgl.nn import CuGraphRelGraphConv\n    ...\n    >>> device = 'cuda'\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device)\n    >>> feat = torch.ones(6, 10).to(device)\n    >>> conv = CuGraphRelGraphConv(\n    ...     10, 2, 3, regularizer='basis', num_bases=2).to(device)\n    >>> etype = torch.tensor([0,1,2,0,1,2]).to(device)\n    >>> res = conv(g, feat, etype)\n    >>> res\n    tensor([[-1.7774, -2.0184],\n            [-1.4335, -2.3758],\n            [-1.7774, -2.0184],\n            [-0.4698, -3.0876],\n            [-1.4335, -2.3758],\n            [-1.4331, -2.3295]], device='cuda:0', grad_fn=<AddBackward0>)\n    \"\"\"\n    MAX_IN_DEGREE_MFG = 500\n\n    def __init__(\n        self,\n        in_feat,\n        out_feat,\n        num_rels,\n        regularizer=None,\n        num_bases=None,\n        bias=True,\n        self_loop=True,\n        dropout=0.0,\n        apply_norm=False,\n    ):\n        if HAS_PYLIBCUGRAPHOPS is False:\n            raise ModuleNotFoundError(\n                f\"{self.__class__.__name__} requires pylibcugraphops=23.04. 
\"\n                f\"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`.\"\n                f\"pylibcugraphops requires Python 3.8 or 3.10.\"\n            )\n        super().__init__()\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.num_rels = num_rels\n        self.apply_norm = apply_norm\n        self.dropout = nn.Dropout(dropout)\n\n        dim_self_loop = 1 if self_loop else 0\n        self.self_loop = self_loop\n        if regularizer is None:\n            self.W = nn.Parameter(\n                torch.Tensor(num_rels + dim_self_loop, in_feat, out_feat)\n            )\n            self.coeff = None\n        elif regularizer == \"basis\":\n            if num_bases is None:\n                raise ValueError(\n                    'Missing \"num_bases\" for basis regularization.'\n                )\n            self.W = nn.Parameter(\n                torch.Tensor(num_bases + dim_self_loop, in_feat, out_feat)\n            )\n            self.coeff = nn.Parameter(torch.Tensor(num_rels, num_bases))\n            self.num_bases = num_bases\n        else:\n            raise ValueError(\n                f\"Supported regularizer options: 'basis' or None, but got \"\n                f\"'{regularizer}'.\"\n            )\n        self.regularizer = regularizer\n\n        if bias:\n            self.bias = nn.Parameter(torch.Tensor(out_feat))\n        else:\n            self.register_parameter(\"bias\", None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Reinitialize learnable parameters.\"\"\"\n        bound = 1 / math.sqrt(self.in_feat)\n        end = -1 if self.self_loop else None\n        nn.init.uniform_(self.W[:end], -bound, bound)\n        if self.regularizer == \"basis\":\n            nn.init.xavier_uniform_(\n                self.coeff, gain=nn.init.calculate_gain(\"relu\")\n            )\n        if self.self_loop:\n            nn.init.xavier_uniform_(self.W[-1], 
nn.init.calculate_gain(\"relu\"))\n        if self.bias is not None:\n            nn.init.zeros_(self.bias)\n\n    def forward(self, g, feat, etypes, max_in_degree=None):\n        r\"\"\"Forward computation.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            A 2D tensor of node features. Shape: :math:`(|V|, D_{in})`.\n        etypes : torch.Tensor\n            A 1D integer tensor of edge types. Shape: :math:`(|E|,)`.\n            Note that cugraph-ops only accepts edge type tensors in int32,\n            so any input of other integer types will be cast to int32,\n            thus introducing some overhead. Pass in int32 tensors directly\n            for best performance.\n        max_in_degree : int, optional\n            Maximum in-degree of destination nodes. It is only effective when\n            :attr:`g` is a :class:`DGLBlock`, i.e., bipartite graph. When\n            :attr:`g` is generated from a neighbor sampler, the value should be\n            set to the corresponding :attr:`fanout`. If not given,\n            :attr:`max_in_degree` will be calculated on-the-fly.\n\n        Returns\n        -------\n        torch.Tensor\n            New node features. 
Shape: :math:`(|V|, D_{out})`.\n        \"\"\"\n        offsets, indices, edge_ids = g.adj_tensors(\"csc\")\n        edge_types_perm = etypes[edge_ids.long()].int()\n\n        if g.is_block:\n            if max_in_degree is None:\n                max_in_degree = g.in_degrees().max().item()\n\n            if max_in_degree < self.MAX_IN_DEGREE_MFG:\n                _graph = HeteroCSC(\n                    offsets,\n                    indices,\n                    edge_types_perm,\n                    g.num_src_nodes(),\n                    self.num_rels,\n                )\n            else:\n                offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)\n                _graph = HeteroCSC(\n                    offsets_fg,\n                    indices,\n                    edge_types_perm,\n                    g.num_src_nodes(),\n                    self.num_rels,\n                )\n        else:\n            _graph = HeteroCSC(\n                offsets,\n                indices,\n                edge_types_perm,\n                g.num_src_nodes(),\n                self.num_rels,\n            )\n\n        h = RelGraphConvAgg(\n            feat,\n            self.coeff,\n            _graph,\n            concat_own=self.self_loop,\n            norm_by_out_degree=self.apply_norm,\n        )[: g.num_dst_nodes()]\n        h = h @ self.W.view(-1, self.out_feat)\n        if self.bias is not None:\n            h = h + self.bias\n        h = self.dropout(h)\n\n        return h\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/cugraph_sageconv.py",
    "content": "\"\"\"Torch Module for GraphSAGE layer using the aggregation primitives in\ncugraph-ops\"\"\"\n# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments\n\nfrom torch import nn\n\nfrom .cugraph_base import CuGraphBaseConv\n\ntry:\n    from pylibcugraphops.pytorch import SampledCSC, StaticCSC\n    from pylibcugraphops.pytorch.operators import agg_concat_n2n as SAGEConvAgg\n\n    HAS_PYLIBCUGRAPHOPS = True\nexcept ImportError:\n    HAS_PYLIBCUGRAPHOPS = False\n\n\nclass CuGraphSAGEConv(CuGraphBaseConv):\n    r\"\"\"An accelerated GraphSAGE layer from `Inductive Representation Learning\n    on Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__ that leverages the\n    highly-optimized aggregation primitives in cugraph-ops:\n\n    .. math::\n        h_{\\mathcal{N}(i)}^{(l+1)} &= \\mathrm{aggregate}\n        \\left(\\{h_{j}^{l}, \\forall j \\in \\mathcal{N}(i) \\}\\right)\n\n        h_{i}^{(l+1)} &= W \\cdot \\mathrm{concat}\n        (h_{i}^{l}, h_{\\mathcal{N}(i)}^{(l+1)})\n\n    This module depends on :code:`pylibcugraphops` package, which can be\n    installed via :code:`conda install -c nvidia pylibcugraphops=23.04`.\n    :code:`pylibcugraphops` 23.04 requires python 3.8.x or 3.10.x.\n\n    .. note::\n        This is an **experimental** feature.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size.\n    out_feats : int\n        Output feature size.\n    aggregator_type : str\n        Aggregator type to use (``mean``, ``sum``, ``min``, ``max``).\n    feat_drop : float\n        Dropout rate on features, default: ``0``.\n    bias : bool\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n    >>> from dgl.nn import CuGraphSAGEConv\n    >>> device = 'cuda'\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device)\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = torch.ones(6, 10).to(device)\n    >>> conv = CuGraphSAGEConv(10, 2, 'mean').to(device)\n    >>> res = conv(g, feat)\n    >>> res\n    tensor([[-1.1690,  0.1952],\n            [-1.1690,  0.1952],\n            [-1.1690,  0.1952],\n            [-1.1690,  0.1952],\n            [-1.1690,  0.1952],\n            [-1.1690,  0.1952]], device='cuda:0', grad_fn=<AddmmBackward0>)\n    \"\"\"\n    MAX_IN_DEGREE_MFG = 500\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        aggregator_type=\"mean\",\n        feat_drop=0.0,\n        bias=True,\n    ):\n        if HAS_PYLIBCUGRAPHOPS is False:\n            raise ModuleNotFoundError(\n                f\"{self.__class__.__name__} requires pylibcugraphops=23.04. \"\n                f\"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`.\"\n                f\"pylibcugraphops requires Python 3.8 or 3.10.\"\n            )\n\n        valid_aggr_types = {\"max\", \"min\", \"mean\", \"sum\"}\n        if aggregator_type not in valid_aggr_types:\n            raise ValueError(\n                f\"Invalid aggregator_type. Must be one of {valid_aggr_types}. 
\"\n                f\"But got '{aggregator_type}' instead.\"\n            )\n\n        super().__init__()\n        self.in_feats = in_feats\n        self.out_feats = out_feats\n        self.aggr = aggregator_type\n        self.feat_drop = nn.Dropout(feat_drop)\n        self.linear = nn.Linear(2 * in_feats, out_feats, bias=bias)\n\n    def reset_parameters(self):\n        r\"\"\"Reinitialize learnable parameters.\"\"\"\n        self.linear.reset_parameters()\n\n    def forward(self, g, feat, max_in_degree=None):\n        r\"\"\"Forward computation.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            Node features. Shape: :math:`(N, D_{in})`.\n        max_in_degree : int\n            Maximum in-degree of destination nodes. It is only effective when\n            :attr:`g` is a :class:`DGLBlock`, i.e., bipartite graph. When\n            :attr:`g` is generated from a neighbor sampler, the value should be\n            set to the corresponding :attr:`fanout`. If not given,\n            :attr:`max_in_degree` will be calculated on-the-fly.\n\n        Returns\n        -------\n        torch.Tensor\n            Output node features. 
Shape: :math:`(N, D_{out})`.\n        \"\"\"\n        offsets, indices, _ = g.adj_tensors(\"csc\")\n\n        if g.is_block:\n            if max_in_degree is None:\n                max_in_degree = g.in_degrees().max().item()\n\n            if max_in_degree < self.MAX_IN_DEGREE_MFG:\n                _graph = SampledCSC(\n                    offsets,\n                    indices,\n                    max_in_degree,\n                    g.num_src_nodes(),\n                )\n            else:\n                offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)\n                _graph = StaticCSC(offsets_fg, indices)\n        else:\n            _graph = StaticCSC(offsets, indices)\n\n        feat = self.feat_drop(feat)\n        h = SAGEConvAgg(feat, _graph, self.aggr)[: g.num_dst_nodes()]\n        h = self.linear(h)\n\n        return h\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/densechebconv.py",
    "content": "\"\"\"Torch Module for DenseChebConv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\nfrom torch.nn import init\n\n\nclass DenseChebConv(nn.Module):\n    r\"\"\"Chebyshev Spectral Graph Convolution layer from `Convolutional\n    Neural Networks on Graphs with Fast Localized Spectral Filtering\n    <https://arxiv.org/pdf/1606.09375.pdf>`__\n\n    We recommend to use this module when applying ChebConv on dense graphs.\n\n    Parameters\n    ----------\n    in_feats: int\n        Dimension of input features :math:`h_i^{(l)}`.\n    out_feats: int\n        Dimension of output features :math:`h_i^{(l+1)}`.\n    k : int\n        Chebyshev filter size.\n    activation : function, optional\n        Activation function, default is ReLu.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. Default: ``True``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import DenseChebConv\n    >>>\n    >>> feat = th.ones(6, 10)\n    >>> adj = th.tensor([[0., 0., 1., 0., 0., 0.],\n    ...         [1., 0., 0., 0., 0., 0.],\n    ...         [0., 1., 0., 0., 0., 0.],\n    ...         [0., 0., 1., 0., 0., 1.],\n    ...         [0., 0., 0., 1., 0., 0.],\n    ...         
[0., 0., 0., 0., 0., 0.]])\n    >>> conv = DenseChebConv(10, 2, 2)\n    >>> res = conv(adj, feat)\n    >>> res\n    tensor([[-3.3516, -2.4797],\n            [-3.3516, -2.4797],\n            [-3.3516, -2.4797],\n            [-4.5192, -3.0835],\n            [-2.5259, -2.0527],\n            [-0.5327, -1.0219]], grad_fn=<AddBackward0>)\n\n    See also\n    --------\n    `ChebConv <https://docs.dgl.ai/api/python/nn.pytorch.html#chebconv>`__\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, k, bias=True):\n        super(DenseChebConv, self).__init__()\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._k = k\n        self.W = nn.Parameter(th.Tensor(k, in_feats, out_feats))\n        if bias:\n            self.bias = nn.Parameter(th.Tensor(out_feats))\n        else:\n            self.register_buffer(\"bias\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"Reinitialize learnable parameters.\"\"\"\n        if self.bias is not None:\n            init.zeros_(self.bias)\n        for i in range(self._k):\n            init.xavier_normal_(self.W[i], init.calculate_gain(\"relu\"))\n\n    def forward(self, adj, feat, lambda_max=None):\n        r\"\"\"Compute (Dense) Chebyshev Spectral Graph Convolution layer\n\n        Parameters\n        ----------\n        adj : torch.Tensor\n            The adjacency matrix of the graph to apply Graph Convolution on,\n            should be of shape :math:`(N, N)`, where a row represents the destination\n            and a column represents the source.\n        feat : torch.Tensor\n            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`\n            is size of input feature, :math:`N` is the number of nodes.\n        lambda_max : float or None, optional\n            A float value indicates the largest eigenvalue of given graph.\n            Default: None.\n\n        Returns\n        -------\n        torch.Tensor\n            The output 
feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n        A = adj.to(feat)\n        num_nodes = A.shape[0]\n\n        in_degree = 1 / A.sum(dim=1).clamp(min=1).sqrt()\n        D_invsqrt = th.diag(in_degree)\n        I = th.eye(num_nodes).to(A)\n        L = I - D_invsqrt @ A @ D_invsqrt\n\n        if lambda_max is None:\n            lambda_ = th.eig(L)[0][:, 0]\n            lambda_max = lambda_.max()\n\n        L_hat = 2 * L / lambda_max - I\n        Z = [th.eye(num_nodes).to(A)]\n        for i in range(1, self._k):\n            if i == 1:\n                Z.append(L_hat)\n            else:\n                Z.append(2 * L_hat @ Z[-1] - Z[-2])\n\n        Zs = th.stack(Z, 0)  # (k, n, n)\n\n        Zh = Zs @ feat.unsqueeze(0) @ self.W\n        Zh = Zh.sum(0)\n\n        if self.bias is not None:\n            Zh = Zh + self.bias\n        return Zh\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/densegraphconv.py",
    "content": "\"\"\"Torch Module for DenseGraphConv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\nfrom torch.nn import init\n\n\nclass DenseGraphConv(nn.Module):\n    \"\"\"Graph Convolutional layer from `Semi-Supervised Classification with Graph\n    Convolutional Networks <https://arxiv.org/abs/1609.02907>`__\n\n    We recommend user to use this module when applying graph convolution on\n    dense graphs.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n    out_feats : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    norm : str, optional\n        How to apply the normalizer. If is `'right'`, divide the aggregated messages\n        by each node's in-degrees, which is equivalent to averaging the received messages.\n        If is `'none'`, no normalization is applied. Default is `'both'`,\n        where the :math:`c_{ij}` in the paper is applied.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. Default: ``True``.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n\n    Notes\n    -----\n    Zero in-degree nodes will lead to all-zero output. A common practice\n    to avoid this is to add a self-loop for each node in the graph,\n    which can be achieved by setting the diagonal of the adjacency matrix to be 1.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import DenseGraphConv\n    >>>\n    >>> feat = th.ones(6, 10)\n    >>> adj = th.tensor([[0., 0., 1., 0., 0., 0.],\n    ...         [1., 0., 0., 0., 0., 0.],\n    ...         [0., 1., 0., 0., 0., 0.],\n    ...         [0., 0., 1., 0., 0., 1.],\n    ...         
[0., 0., 0., 1., 0., 0.],\n    ...         [0., 0., 0., 0., 0., 0.]])\n    >>> conv = DenseGraphConv(10, 2)\n    >>> res = conv(adj, feat)\n    >>> res\n    tensor([[0.2159, 1.9027],\n            [0.3053, 2.6908],\n            [0.3053, 2.6908],\n            [0.3685, 3.2481],\n            [0.3053, 2.6908],\n            [0.0000, 0.0000]], grad_fn=<AddBackward0>)\n\n    See also\n    --------\n    `GraphConv <https://docs.dgl.ai/api/python/nn.pytorch.html#graphconv>`__\n    \"\"\"\n\n    def __init__(\n        self, in_feats, out_feats, norm=\"both\", bias=True, activation=None\n    ):\n        super(DenseGraphConv, self).__init__()\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._norm = norm\n        self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))\n        if bias:\n            self.bias = nn.Parameter(th.Tensor(out_feats))\n        else:\n            self.register_buffer(\"bias\", None)\n\n        self.reset_parameters()\n        self._activation = activation\n\n    def reset_parameters(self):\n        \"\"\"Reinitialize learnable parameters.\"\"\"\n        init.xavier_uniform_(self.weight)\n        if self.bias is not None:\n            init.zeros_(self.bias)\n\n    def forward(self, adj, feat):\n        r\"\"\"Compute (Dense) Graph Convolution layer.\n\n        Parameters\n        ----------\n        adj : torch.Tensor\n            The adjacency matrix of the graph to apply Graph Convolution on, when\n            applied to a unidirectional bipartite graph, ``adj`` should be of shape\n            :math:`(N_{out}, N_{in})`; when applied to a homo\n            graph, ``adj`` should be of shape :math:`(N, N)`. 
In both cases,\n            a row represents a destination node while a column represents a source\n            node.\n        feat : torch.Tensor\n            The input feature.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n        adj = adj.to(feat)\n        src_degrees = adj.sum(dim=0).clamp(min=1)\n        dst_degrees = adj.sum(dim=1).clamp(min=1)\n        feat_src = feat\n\n        if self._norm == \"both\":\n            norm_src = th.pow(src_degrees, -0.5)\n            shp = norm_src.shape + (1,) * (feat.dim() - 1)\n            norm_src = th.reshape(norm_src, shp).to(feat.device)\n            feat_src = feat_src * norm_src\n\n        if self._in_feats > self._out_feats:\n            # mult W first to reduce the feature size for aggregation.\n            feat_src = th.matmul(feat_src, self.weight)\n            rst = adj @ feat_src\n        else:\n            # aggregate first then mult W\n            rst = adj @ feat_src\n            rst = th.matmul(rst, self.weight)\n\n        if self._norm != \"none\":\n            if self._norm == \"both\":\n                norm_dst = th.pow(dst_degrees, -0.5)\n            else:  # right\n                norm_dst = 1.0 / dst_degrees\n            shp = norm_dst.shape + (1,) * (feat.dim() - 1)\n            norm_dst = th.reshape(norm_dst, shp).to(feat.device)\n            rst = rst * norm_dst\n\n        if self.bias is not None:\n            rst = rst + self.bias\n\n        if self._activation is not None:\n            rst = self._activation(rst)\n\n        return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/densesageconv.py",
    "content": "\"\"\"Torch Module for DenseSAGEConv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nfrom torch import nn\n\nfrom ....utils import check_eq_shape\n\n\nclass DenseSAGEConv(nn.Module):\n    \"\"\"GraphSAGE layer from `Inductive Representation Learning on Large Graphs\n    <https://arxiv.org/abs/1706.02216>`__\n\n    We recommend to use this module when appying GraphSAGE on dense graphs.\n\n    Note that we only support gcn aggregator in DenseSAGEConv.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n    out_feats : int\n        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.\n    feat_drop : float, optional\n        Dropout rate on features. Default: 0.\n    bias : bool\n        If True, adds a learnable bias to the output. Default: ``True``.\n    norm : callable activation function/layer or None, optional\n        If not None, applies normalization to the updated node features.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import DenseSAGEConv\n    >>>\n    >>> feat = th.ones(6, 10)\n    >>> adj = th.tensor([[0., 0., 1., 0., 0., 0.],\n    ...         [1., 0., 0., 0., 0., 0.],\n    ...         [0., 1., 0., 0., 0., 0.],\n    ...         [0., 0., 1., 0., 0., 1.],\n    ...         [0., 0., 0., 1., 0., 0.],\n    ...         
[0., 0., 0., 0., 0., 0.]])\n    >>> conv = DenseSAGEConv(10, 2)\n    >>> res = conv(adj, feat)\n    >>> res\n    tensor([[1.0401, 2.1008],\n            [1.0401, 2.1008],\n            [1.0401, 2.1008],\n            [1.0401, 2.1008],\n            [1.0401, 2.1008],\n            [1.0401, 2.1008]], grad_fn=<AddmmBackward>)\n\n    See also\n    --------\n    `SAGEConv <https://docs.dgl.ai/api/python/nn.pytorch.html#sageconv>`__\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        feat_drop=0.0,\n        bias=True,\n        norm=None,\n        activation=None,\n    ):\n        super(DenseSAGEConv, self).__init__()\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._norm = norm\n        self.feat_drop = nn.Dropout(feat_drop)\n        self.activation = activation\n        self.fc = nn.Linear(in_feats, out_feats, bias=bias)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n\n        Notes\n        -----\n        The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.\n        \"\"\"\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_uniform_(self.fc.weight, gain=gain)\n\n    def forward(self, adj, feat):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute (Dense) Graph SAGE layer.\n\n        Parameters\n        ----------\n        adj : torch.Tensor\n            The adjacency matrix of the graph to apply SAGE Convolution on, when\n            applied to a unidirectional bipartite graph, ``adj`` should be of shape\n            :math:`(N_{out}, N_{in})`; when applied to a homo\n            graph, ``adj`` should be of shape :math:`(N, N)`. 
In both cases,\n            a row represents a destination node while a column represents a source\n            node.\n        feat : torch.Tensor or a pair of torch.Tensor\n            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of torch.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n        check_eq_shape(feat)\n        if isinstance(feat, tuple):\n            feat_src = self.feat_drop(feat[0])\n            feat_dst = self.feat_drop(feat[1])\n        else:\n            feat_src = feat_dst = self.feat_drop(feat)\n        adj = adj.to(feat_src)\n        in_degrees = adj.sum(dim=1, keepdim=True)\n        h_neigh = (adj @ feat_src + feat_dst) / (in_degrees + 1)\n        rst = self.fc(h_neigh)\n        # activation\n        if self.activation is not None:\n            rst = self.activation(rst)\n        # normalization\n        if self._norm is not None:\n            rst = self._norm(rst)\n\n        return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/dgnconv.py",
    "content": "\"\"\"Torch Module for Directional Graph Networks Convolution Layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\n\nfrom .pnaconv import AGGREGATORS, PNAConv, PNAConvTower, SCALERS\n\n\ndef aggregate_dir_av(h, eig_s, eig_d, eig_idx):\n    \"\"\"directional average aggregation\"\"\"\n    h_mod = torch.mul(\n        h,\n        (\n            torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx])\n            / (\n                torch.sum(\n                    torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),\n                    keepdim=True,\n                    dim=1,\n                )\n                + 1e-30\n            )\n        ).unsqueeze(-1),\n    )\n    return torch.sum(h_mod, dim=1)\n\n\ndef aggregate_dir_dx(h, eig_s, eig_d, h_in, eig_idx):\n    \"\"\"directional derivative aggregation\"\"\"\n    eig_w = (\n        (eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx])\n        / (\n            torch.sum(\n                torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),\n                keepdim=True,\n                dim=1,\n            )\n            + 1e-30\n        )\n    ).unsqueeze(-1)\n    h_mod = torch.mul(h, eig_w)\n    return torch.abs(torch.sum(h_mod, dim=1) - torch.sum(eig_w, dim=1) * h_in)\n\n\nfor k in range(1, 4):\n    AGGREGATORS[f\"dir{k}-av\"] = partial(aggregate_dir_av, eig_idx=k - 1)\n    AGGREGATORS[f\"dir{k}-dx\"] = partial(aggregate_dir_dx, eig_idx=k - 1)\n\n\nclass DGNConvTower(PNAConvTower):\n    \"\"\"A single DGN tower with modified reduce function\"\"\"\n\n    def message(self, edges):\n        \"\"\"message function for DGN layer\"\"\"\n        if self.edge_feat_size > 0:\n            f = torch.cat(\n                [edges.src[\"h\"], edges.dst[\"h\"], edges.data[\"a\"]], dim=-1\n            )\n        else:\n            f = torch.cat([edges.src[\"h\"], edges.dst[\"h\"]], dim=-1)\n        return {\n            
\"msg\": self.M(f),\n            \"eig_s\": edges.src[\"eig\"],\n            \"eig_d\": edges.dst[\"eig\"],\n        }\n\n    def reduce_func(self, nodes):\n        \"\"\"reduce function for DGN layer\"\"\"\n        h_in = nodes.data[\"h\"]\n        eig_s = nodes.mailbox[\"eig_s\"]\n        eig_d = nodes.mailbox[\"eig_d\"]\n        msg = nodes.mailbox[\"msg\"]\n        degree = msg.size(1)\n\n        h = []\n        for agg in self.aggregators:\n            if agg.startswith(\"dir\"):\n                if agg.endswith(\"av\"):\n                    h.append(AGGREGATORS[agg](msg, eig_s, eig_d))\n                else:\n                    h.append(AGGREGATORS[agg](msg, eig_s, eig_d, h_in))\n            else:\n                h.append(AGGREGATORS[agg](msg))\n        h = torch.cat(h, dim=1)\n        h = torch.cat(\n            [\n                SCALERS[scaler](h, D=degree, delta=self.delta)\n                if scaler != \"identity\"\n                else h\n                for scaler in self.scalers\n            ],\n            dim=1,\n        )\n        return {\"h_neigh\": h}\n\n\nclass DGNConv(PNAConv):\n    r\"\"\"Directional Graph Network Layer from `Directional Graph Networks\n    <https://arxiv.org/abs/2010.02863>`__\n\n    DGN introduces two special directional aggregators according to the vector field\n    :math:`F`, which is defined as the gradient of the low-frequency eigenvectors of graph\n    laplacian.\n\n    The directional average aggregator is defined as\n    :math:`h_i' = \\sum_{j\\in\\mathcal{N}(i)}\\frac{|F_{i,j}|\\cdot h_j}{||F_{i,:}||_1+\\epsilon}`\n\n    The directional derivative aggregator is defined as\n    :math:`h_i' = \\sum_{j\\in\\mathcal{N}(i)}\\frac{F_{i,j}\\cdot h_j}{||F_{i,:}||_1+\\epsilon}\n    -h_i\\cdot\\sum_{j\\in\\mathcal{N}(i)}\\frac{F_{i,j}}{||F_{i,:}||_1+\\epsilon}`\n\n    :math:`\\epsilon` is the infinitesimal to keep the computation numerically stable.\n\n    Parameters\n    ----------\n    in_size : int\n        Input feature 
size; i.e. the size of :math:`h_i^l`.\n    out_size : int\n        Output feature size; i.e. the size of :math:`h_i^{l+1}`.\n    aggregators : list of str\n        List of aggregation function names(each aggregator specifies a way to aggregate\n        messages from neighbours), selected from:\n\n        * ``mean``: the mean of neighbour messages\n\n        * ``max``: the maximum of neighbour messages\n\n        * ``min``: the minimum of neighbour messages\n\n        * ``std``: the standard deviation of neighbour messages\n\n        * ``var``: the variance of neighbour messages\n\n        * ``sum``: the sum of neighbour messages\n\n        * ``moment3``, ``moment4``, ``moment5``: the normalized moments aggregation\n        :math:`(E[(X-E[X])^n])^{1/n}`\n\n        * ``dir{k}-av``: directional average aggregation with directions defined by the k-th\n        smallest eigenvectors. k can be selected from 1, 2, 3.\n\n        * ``dir{k}-dx``: directional derivative aggregation with directions defined by the k-th\n        smallest eigenvectors. k can be selected from 1, 2, 3.\n\n        Note that using directional aggregation requires the LaplacianPE transform on the input\n        graph for eigenvector computation (the PE size must be >= k above).\n    scalers: list of str\n        List of scaler function names, selected from:\n\n        * ``identity``: no scaling\n\n        * ``amplification``: multiply the aggregated message by :math:`\\log(d+1)/\\delta`,\n        where :math:`d` is the in-degree of the node.\n\n        * ``attenuation``: multiply the aggregated message by :math:`\\delta/\\log(d+1)`\n    delta: float\n        The in-degree-related normalization factor computed over the training set, used by scalers\n        for normalization. :math:`E[\\log(d+1)]`, where :math:`d` is the in-degree for each node\n        in the training set.\n    dropout: float, optional\n        The dropout ratio. 
Default: 0.0.\n    num_towers: int, optional\n        The number of towers used. Default: 1. Note that in_size and out_size must be divisible\n        by num_towers.\n    edge_feat_size: int, optional\n        The edge feature size. Default: 0.\n    residual : bool, optional\n        The bool flag that determines whether to add a residual connection for the\n        output. Default: True. If in_size and out_size of the DGN conv layer are not\n        the same, this flag will be set as False forcibly.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import DGNConv\n    >>> from dgl import LaplacianPE\n    >>>\n    >>> # DGN requires precomputed eigenvectors, with 'eig' as feature name.\n    >>> transform = LaplacianPE(k=3, feat_name='eig')\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = transform(g)\n    >>> eig = g.ndata['eig']\n    >>> feat = th.ones(6, 10)\n    >>> conv = DGNConv(10, 10, ['dir1-av', 'dir1-dx', 'sum'], ['identity', 'amplification'], 2.5)\n    >>> ret = conv(g, feat, eig_vec=eig)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_size,\n        out_size,\n        aggregators,\n        scalers,\n        delta,\n        dropout=0.0,\n        num_towers=1,\n        edge_feat_size=0,\n        residual=True,\n    ):\n        super(DGNConv, self).__init__(\n            in_size,\n            out_size,\n            aggregators,\n            scalers,\n            delta,\n            dropout,\n            num_towers,\n            edge_feat_size,\n            residual,\n        )\n\n        self.towers = nn.ModuleList(\n            [\n                DGNConvTower(\n                    self.tower_in_size,\n                    self.tower_out_size,\n                    aggregators,\n                    scalers,\n                    delta,\n                    dropout=dropout,\n                    edge_feat_size=edge_feat_size,\n                )\n                for _ in 
range(num_towers)\n            ]\n        )\n\n        self.use_eig_vec = False\n        for aggr in aggregators:\n            if aggr.startswith(\"dir\"):\n                self.use_eig_vec = True\n                break\n\n    def forward(self, graph, node_feat, edge_feat=None, eig_vec=None):\n        r\"\"\"\n        Description\n        -----------\n        Compute DGN layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        node_feat : torch.Tensor\n            The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of\n            nodes, and :math:`h_n` must be the same as in_size.\n        edge_feat : torch.Tensor, optional\n            The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number of\n            edges, and :math:`h_e` must be the same as edge_feat_size.\n        eig_vec : torch.Tensor, optional\n            K smallest non-trivial eigenvectors of Graph Laplacian of shape :math:`(N, K)`.\n            It is only required when :attr:`aggregators` contains directional aggregators.\n\n        Returns\n        -------\n        torch.Tensor\n            The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`\n            should be the same as out_size.\n        \"\"\"\n        with graph.local_scope():\n            if self.use_eig_vec:\n                graph.ndata[\"eig\"] = eig_vec\n            return super().forward(graph, node_feat, edge_feat)\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/dotgatconv.py",
    "content": "\"\"\"Torch modules for graph attention networks(GAT).\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nfrom torch import nn\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\nfrom ...functional import edge_softmax\n\n\nclass DotGatConv(nn.Module):\n    r\"\"\"Apply dot product version of self attention in `Graph Attention Network\n    <https://arxiv.org/pdf/1710.10903.pdf>`__\n\n        .. math::\n            h_i^{(l+1)} = \\sum_{j\\in \\mathcal{N}(i)} \\alpha_{i, j} h_j^{(l)}\n\n        where :math:`\\alpha_{ij}` is the attention score bewteen node :math:`i` and node :math:`j`:\n\n        .. math::\n            \\alpha_{i, j} &= \\mathrm{softmax_i}(e_{ij}^{l})\n\n            e_{ij}^{l} &= ({W_i^{(l)} h_i^{(l)}})^T \\cdot {W_j^{(l)} h_j^{(l)}}\n\n        where :math:`W_i` and :math:`W_j` transform node :math:`i`'s and node :math:`j`'s\n        features into the same dimension, so that when compute note features' similarity,\n        it can use dot-product.\n\n    Parameters\n    ----------\n    in_feats : int, or pair of ints\n        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n        DotGatConv can be applied on homogeneous graph and unidirectional\n        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.\n        If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``\n        specifies the input feature size on both the source and destination nodes.  
If\n        a scalar is given, the source and destination node feature size would take the\n        same value.\n    out_feats : int\n        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.\n    num_heads : int\n        Number of heads in Multi-Head Attention\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Default: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. 
Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import DotGatConv\n\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = th.ones(6, 10)\n    >>> dotgatconv = DotGatConv(10, 2, num_heads=3)\n    >>> res = dotgatconv(g, feat)\n    >>> res\n    tensor([[[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]],\n            [[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]],\n            [[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]],\n            [[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]],\n            [[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]],\n            [[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]]], grad_fn=<BinaryReduceBackward>)\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})\n    >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))\n    >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))\n    >>> dotgatconv = DotGatConv((5,10), 2, 3)\n    >>> res = dotgatconv(g, (u_feat, v_feat))\n    >>> res\n    tensor([[[-0.6066,  1.0268],\n            [-0.5945, -0.4801],\n            [ 0.1594,  0.3825]],\n            [[ 0.0268,  1.0783],\n            [ 0.5041, -1.3025],\n            [ 0.6568,  0.7048]],\n            [[-0.2688,  1.0543],\n            [-0.0315, -0.9016],\n            [ 0.3943,  0.5347]],\n   
         [[-0.6066,  1.0268],\n            [-0.5945, -0.4801],\n            [ 0.1594,  0.3825]]], grad_fn=<BinaryReduceBackward>)\n    \"\"\"\n\n    def __init__(\n        self, in_feats, out_feats, num_heads, allow_zero_in_degree=False\n    ):\n        super(DotGatConv, self).__init__()\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        self._allow_zero_in_degree = allow_zero_in_degree\n        self._num_heads = num_heads\n\n        if isinstance(in_feats, tuple):\n            self.fc_src = nn.Linear(\n                self._in_src_feats,\n                self._out_feats * self._num_heads,\n                bias=False,\n            )\n            self.fc_dst = nn.Linear(\n                self._in_dst_feats,\n                self._out_feats * self._num_heads,\n                bias=False,\n            )\n        else:\n            self.fc = nn.Linear(\n                self._in_src_feats,\n                self._out_feats * self._num_heads,\n                bias=False,\n            )\n\n    def forward(self, graph, feat, get_attention=False):\n        r\"\"\"\n\n        Description\n        -----------\n        Apply dot product version of self attention in GCN.\n\n        Parameters\n        ----------\n        graph: DGLGraph or bi_partities graph\n            The graph\n        feat: torch.Tensor or pair of torch.Tensor\n            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of torch.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.\n        get_attention : bool, optional\n            Whether to return the attention values. 
Default to False.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` is size\n            of output feature.\n        torch.Tensor, optional\n            The attention values of shape :math:`(E, 1)`, where :math:`E` is the number of\n            edges. This is returned only when :attr:`get_attention` is ``True``.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n\n        graph = graph.local_var()\n\n        if not self._allow_zero_in_degree:\n            if (graph.in_degrees() == 0).any():\n                raise DGLError(\n                    \"There are 0-in-degree nodes in the graph, \"\n                    \"output for those nodes will be invalid. \"\n                    \"This is harmful for some applications, \"\n                    \"causing silent performance regression. \"\n                    \"Adding self-loop on the input graph by \"\n                    \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                    \"the issue. 
Setting ``allow_zero_in_degree`` \"\n                    \"to be `True` when constructing this module will \"\n                    \"suppress the check and let the code run.\"\n                )\n\n        # check if feat is a tuple\n        if isinstance(feat, tuple):\n            h_src = feat[0]\n            h_dst = feat[1]\n            feat_src = self.fc_src(h_src).view(\n                -1, self._num_heads, self._out_feats\n            )\n            feat_dst = self.fc_dst(h_dst).view(\n                -1, self._num_heads, self._out_feats\n            )\n        else:\n            h_src = feat\n            feat_src = feat_dst = self.fc(h_src).view(\n                -1, self._num_heads, self._out_feats\n            )\n            if graph.is_block:\n                feat_dst = feat_src[: graph.number_of_dst_nodes()]\n\n        # Assign features to nodes\n        graph.srcdata.update({\"ft\": feat_src})\n        graph.dstdata.update({\"ft\": feat_dst})\n\n        # Step 1. dot product\n        graph.apply_edges(fn.u_dot_v(\"ft\", \"ft\", \"a\"))\n\n        # Step 2. edge softmax to compute attention scores\n        graph.edata[\"sa\"] = edge_softmax(\n            graph, graph.edata[\"a\"] / self._out_feats**0.5\n        )\n\n        # Step 3. Broadcast softmax value to each edge, and aggregate dst node\n        graph.update_all(\n            fn.u_mul_e(\"ft\", \"sa\", \"attn\"), fn.sum(\"attn\", \"agg_u\")\n        )\n\n        # output results to the destination nodes\n        rst = graph.dstdata[\"agg_u\"]\n\n        if get_attention:\n            return rst, graph.edata[\"sa\"]\n        else:\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/edgeconv.py",
    "content": "\"\"\"Torch Module for EdgeConv Layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nfrom torch import nn\n\nfrom .... import function as fn\n\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\n\n\nclass EdgeConv(nn.Module):\n    r\"\"\"EdgeConv layer from `Dynamic Graph CNN for Learning on Point Clouds\n    <https://arxiv.org/pdf/1801.07829>`__\n\n    It can be described as follows:\n\n    .. math::\n       h_i^{(l+1)} = \\max_{j \\in \\mathcal{N}(i)} (\n       \\Theta \\cdot (h_j^{(l)} - h_i^{(l)}) + \\Phi \\cdot h_i^{(l)})\n\n    where :math:`\\mathcal{N}(i)` is the neighbor of :math:`i`.\n    :math:`\\Theta` and :math:`\\Phi` are linear layers.\n\n    .. note::\n\n       The original formulation includes a ReLU inside the maximum operator.\n       This is equivalent to first applying a maximum operator then applying\n       the ReLU.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n    out_feat : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    batch_norm : bool\n        Whether to include batch normalization on messages. Default: ``False``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Default: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. 
This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practice to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import EdgeConv\n\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = th.ones(6, 10)\n    >>> conv = EdgeConv(10, 2)\n    >>> res = conv(g, feat)\n    >>> res\n    tensor([[-0.2347,  0.5849],\n            [-0.2347,  0.5849],\n            [-0.2347,  0.5849],\n            [-0.2347,  0.5849],\n            [-0.2347,  0.5849],\n            [-0.2347,  0.5849]], grad_fn=<CopyReduceBackward>)\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})\n    >>> u_fea = th.rand(2, 5)\n    >>> v_fea = th.rand(4, 5)\n    >>> conv = EdgeConv(5, 2, 3)\n    >>> res = conv(g, (u_fea, v_fea))\n    >>> res\n    tensor([[ 1.6375,  0.2085],\n            [-1.1925, -1.2852],\n            [ 0.2101,  1.3466],\n            [ 0.2342, -0.9868]], grad_fn=<CopyReduceBackward>)\n    \"\"\"\n\n    def __init__(\n        self, in_feat, out_feat, batch_norm=False, allow_zero_in_degree=False\n    ):\n        super(EdgeConv, self).__init__()\n        self.batch_norm = 
batch_norm\n        self._allow_zero_in_degree = allow_zero_in_degree\n\n        self.theta = nn.Linear(in_feat, out_feat)\n        self.phi = nn.Linear(in_feat, out_feat)\n\n        if batch_norm:\n            self.bn = nn.BatchNorm1d(out_feat)\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, g, feat):\n        \"\"\"\n\n        Description\n        -----------\n        Forward computation\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        feat : Tensor or pair of tensors\n            :math:`(N, D)` where :math:`N` is the number of nodes and\n            :math:`D` is the number of feature dimensions.\n\n            If a pair of tensors is given, the graph must be a uni-bipartite graph\n            with only one edge type, and the two tensors must have the same\n            dimensionality on all except the first axis.\n\n        Returns\n        -------\n        torch.Tensor\n            New node features.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        with g.local_scope():\n            if not self._allow_zero_in_degree:\n                if (g.in_degrees() == 0).any():\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. 
\"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            h_src, h_dst = expand_as_pair(feat, g)\n            g.srcdata[\"x\"] = h_src\n            g.dstdata[\"x\"] = h_dst\n            g.apply_edges(fn.v_sub_u(\"x\", \"x\", \"theta\"))\n            g.edata[\"theta\"] = self.theta(g.edata[\"theta\"])\n            g.dstdata[\"phi\"] = self.phi(g.dstdata[\"x\"])\n            if not self.batch_norm:\n                g.update_all(fn.e_add_v(\"theta\", \"phi\", \"e\"), fn.max(\"e\", \"x\"))\n            else:\n                g.apply_edges(fn.e_add_v(\"theta\", \"phi\", \"e\"))\n                # Although the official implementation includes a per-edge\n                # batch norm within EdgeConv, I choose to replace it with a\n                # global batch norm for a number of reasons:\n                #\n                # (1) When the point clouds within each batch do not have the\n                #     same number of points, batch norm would not work.\n                #\n                # (2) Even if the point clouds always have the same number of\n                #     points, the points may as well be shuffled even with the\n                #     same (type of) object (and the official implementation\n                #     *does* shuffle the points of the same example for each\n                #     epoch).\n                #\n                #     For example, the first point of a point cloud of an\n                #     airplane does not always necessarily reside at its nose.\n    
            #\n                #     In this case, the learned statistics of each position\n                #     by batch norm are not as meaningful as those learned from\n                #     images.\n                g.edata[\"e\"] = self.bn(g.edata[\"e\"])\n                g.update_all(fn.copy_e(\"e\", \"e\"), fn.max(\"e\", \"x\"))\n            return g.dstdata[\"x\"]\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/edgegatconv.py",
    "content": "\"\"\"Torch modules for graph attention networks(GAT).\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\nfrom ...functional import edge_softmax\n\n# pylint: enable=W0235\nclass EdgeGATConv(nn.Module):\n    r\"\"\"Graph attention layer with edge features from `SCENE\n    <https://arxiv.org/pdf/2301.03512.pdf>`__\n\n    .. math::\n\n        \\mathbf{v}_i^\\prime = \\mathbf{\\Theta}_\\mathrm{s} \\cdot \\mathbf{v}_i +\n        \\sum\\limits_{j \\in \\mathcal{N}(v_i)} \\alpha_{j, i} \\left( \\mathbf{\\Theta}_\\mathrm{n}\n        \\cdot \\mathbf{v}_j + \\mathbf{\\Theta}_\\mathrm{e} \\cdot \\mathbf{e}_{j,i} \\right)\n\n    where :math:`\\mathbf{\\Theta}` is used to denote learnable weight matrices\n    for the transformation of features of the node to update (s=self),\n    neighboring nodes (n=neighbor) and edge features (e=edge).\n    Attention weights are obtained by\n\n    .. 
math::\n\n        \\alpha_{j, i} = \\mathrm{softmax}_i \\Big( \\mathrm{LeakyReLU} \\big( \\mathbf{a}^T\n        [ \\mathbf{\\Theta}_\\mathrm{n} \\cdot \\mathbf{v}_i || \\mathbf{\\Theta}_\\mathrm{n}\n        \\cdot \\mathbf{v}_j || \\mathbf{\\Theta}_\\mathrm{e} \\cdot \\mathbf{e}_{j,i} ] \\big) \\Big)\n\n    with :math:`\\mathbf{a}` corresponding to a learnable vector.\n    :math:`\\mathrm{softmax_i}` stands for the normalization by all incoming edges of node :math:`i`.\n\n    Parameters\n    ----------\n    in_feats : int, or pair of ints\n        Input feature size; i.e, the number of dimensions of :math:`\\mathbf{v}_i`.\n        EdgeGATConv can be applied on homogeneous graph and unidirectional\n        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.\n        If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``\n        specifies the input feature size on both the source and destination nodes.  If\n        a scalar is given, the source and destination node feature size would take the\n        same value.\n    edge_feats: int\n        Edge feature size; i.e., the number of dimensions of :math:`\\mathbf{e}_{j,i}`.\n    out_feats : int\n        Output feature size; i.e, the number of dimensions of :math:`\\mathbf{v}_i^\\prime`.\n    num_heads : int\n        Number of heads in Multi-Head Attention.\n    feat_drop : float, optional\n        Dropout rate on feature. Defaults: ``0``.\n    attn_drop : float, optional\n        Dropout rate on attention weight. Defaults: ``0``.\n    negative_slope : float, optional\n        LeakyReLU angle of negative slope. Defaults: ``0.2``.\n    residual : bool, optional\n        If True, use residual connection. 
Defaults: ``True``.\n    activation : callable activation function/layer or None, optional.\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Defaults: ``False``.\n    bias : bool, optional\n        If True, learns a bias term. Defaults: ``True``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. 
Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    ----------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import EdgeGATConv\n\n    >>> # Case 1: Homogeneous graph.\n    >>> num_nodes, num_edges = 8, 30\n    >>> # Generate a graph.\n    >>> graph = dgl.rand_graph(num_nodes,num_edges)\n    >>> node_feats = th.rand((num_nodes, 20))\n    >>> edge_feats = th.rand((num_edges, 12))\n    >>> edge_gat = EdgeGATConv(\n    ...     in_feats=20,\n    ...     edge_feats=12,\n    ...     out_feats=15,\n    ...     num_heads=3,\n    ... )\n    >>> # Forward pass.\n    >>> new_node_feats = edge_gat(graph, node_feats, edge_feats)\n    >>> new_node_feats.shape\n    torch.Size([8, 3, 15]) torch.Size([30, 3, 10])\n\n    >>> # Case 2: Unidirectional bipartite graph.\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})\n    >>> u_feat = th.tensor(np.random.rand(2, 25).astype(np.float32))\n    >>> v_feat = th.tensor(np.random.rand(4, 30).astype(np.float32))\n    >>> nfeats = (u_feat,v_feat)\n    >>> efeats = th.tensor(np.random.rand(5, 15).astype(np.float32))\n    >>> in_feats = (25,30)\n    >>> edge_feats = 15\n    >>> out_feats = 10\n    >>> num_heads = 3\n    >>> egat_model =  EdgeGATConv(\n    ...     in_feats,\n    ...     edge_feats,\n    ...     out_feats,\n    ...     num_heads,\n    ... 
)\n    >>> # Forward pass.\n    >>> new_node_feats, attention_weights = egat_model(g, nfeats, efeats, get_attention=True)\n    >>> new_node_feats.shape, attention_weights.shape\n    (torch.Size([4, 3, 10]), torch.Size([5, 3, 1]))\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        edge_feats,\n        out_feats,\n        num_heads,\n        feat_drop=0.0,\n        attn_drop=0.0,\n        negative_slope=0.2,\n        residual=True,\n        activation=None,\n        allow_zero_in_degree=False,\n        bias=True,\n    ):\n        super(EdgeGATConv, self).__init__()\n        self._num_heads = num_heads\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        self._allow_zero_in_degree = allow_zero_in_degree\n        if isinstance(in_feats, tuple):\n            self.fc_src = nn.Linear(\n                self._in_src_feats, out_feats * num_heads, bias=False\n            )\n            self.fc_dst = nn.Linear(\n                self._in_dst_feats, out_feats * num_heads, bias=False\n            )\n        else:\n            self.fc = nn.Linear(\n                self._in_src_feats, out_feats * num_heads, bias=False\n            )\n        self.attn_l = nn.Parameter(\n            th.FloatTensor(size=(1, num_heads, out_feats))\n        )\n        self.attn_r = nn.Parameter(\n            th.FloatTensor(size=(1, num_heads, out_feats))\n        )\n        self.feat_drop = nn.Dropout(feat_drop)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.leaky_relu = nn.LeakyReLU(negative_slope)\n        if bias:\n            self.bias = nn.Parameter(\n                th.FloatTensor(size=(num_heads * out_feats,))\n            )\n        else:\n            self.register_buffer(\"bias\", None)\n        if residual:\n            self.res_fc = nn.Linear(\n                self._in_dst_feats, num_heads * out_feats, bias=False\n            )\n        else:\n            
self.register_buffer(\"res_fc\", None)\n\n        self._edge_feats = edge_feats\n        self.fc_edge = nn.Linear(edge_feats, out_feats * num_heads, bias=False)\n        self.attn_edge = nn.Parameter(\n            th.FloatTensor(size=(1, num_heads, out_feats))\n        )\n\n        self.reset_parameters()\n        self.activation = activation\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n\n        Note\n        ----\n        The fc weights :math:`\\mathbf{\\Theta}` and the\n        attention weights are initialized using the Xavier normal initialization method.\n        \"\"\"\n        gain = nn.init.calculate_gain(\"relu\")\n        if hasattr(self, \"fc\"):\n            nn.init.xavier_normal_(self.fc.weight, gain=gain)\n        else:\n            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)\n            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)\n        nn.init.xavier_normal_(self.attn_l, gain=gain)\n        nn.init.xavier_normal_(self.attn_r, gain=gain)\n\n        nn.init.xavier_normal_(self.fc_edge.weight, gain=gain)\n        nn.init.xavier_normal_(self.attn_edge, gain=gain)\n        if self.bias is not None:\n            nn.init.constant_(self.bias, 0)\n        if isinstance(self.res_fc, nn.Linear):\n            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat, edge_feat, get_attention=False):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute graph attention network layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n      
  feat : torch.Tensor or pair of torch.Tensor\n            If a torch.Tensor is given, the input feature of shape :math:`(N, *, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of torch.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.\n        edge_feat : torch.Tensor\n            The input edge feature of shape :math:`(E, D_{in_{edge}})`,\n            where :math:`E` is the number of edges and :math:`D_{in_{edge}}`\n            the size of the edge features.\n        get_attention : bool, optional\n            Whether to return the attention values. Default to False.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`\n            is the number of heads, and :math:`D_{out}` is size of output feature.\n        torch.Tensor, optional\n            The attention values of shape :math:`(E, *, H, 1)`. This is returned only\n            when :attr:`get_attention` is ``True``.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. 
\"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            if isinstance(feat, tuple):\n                src_prefix_shape = feat[0].shape[:-1]\n                dst_prefix_shape = feat[1].shape[:-1]\n                h_src = self.feat_drop(feat[0])\n                h_dst = self.feat_drop(feat[1])\n                if not hasattr(self, \"fc_src\"):\n                    feat_src = self.fc(h_src).view(\n                        *src_prefix_shape, self._num_heads, self._out_feats\n                    )\n                    feat_dst = self.fc(h_dst).view(\n                        *dst_prefix_shape, self._num_heads, self._out_feats\n                    )\n                else:\n                    feat_src = self.fc_src(h_src).view(\n                        *src_prefix_shape, self._num_heads, self._out_feats\n                    )\n                    feat_dst = self.fc_dst(h_dst).view(\n                        *dst_prefix_shape, self._num_heads, self._out_feats\n                    )\n            else:\n                src_prefix_shape = dst_prefix_shape = feat.shape[:-1]\n                h_src = h_dst = self.feat_drop(feat)\n                feat_src = feat_dst = self.fc(h_src).view(\n                    *src_prefix_shape, self._num_heads, self._out_feats\n                )\n                if graph.is_block:\n                    feat_dst = feat_src[: graph.number_of_dst_nodes()]\n                    h_dst = h_dst[: graph.number_of_dst_nodes()]\n                    dst_prefix_shape = (\n                        graph.number_of_dst_nodes(),\n                    ) + dst_prefix_shape[1:]\n\n            # Linearly tranform the edge 
features.\n            n_edges = edge_feat.shape[:-1]\n            feat_edge = self.fc_edge(edge_feat).view(\n                *n_edges, self._num_heads, self._out_feats\n            )\n\n            # Add edge features to graph.\n            graph.edata[\"ft_edge\"] = feat_edge\n\n            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)\n            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)\n\n            # Calculate scalar for each edge.\n            ee = (feat_edge * self.attn_edge).sum(dim=-1).unsqueeze(-1)\n            graph.edata[\"ee\"] = ee\n\n            graph.srcdata.update({\"ft\": feat_src, \"el\": el})\n            graph.dstdata.update({\"er\": er})\n            # Compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.\n            graph.apply_edges(fn.u_add_v(\"el\", \"er\", \"e_tmp\"))\n\n            # e_tmp combines attention weights of source and destination node.\n            # Add the attention weight of the edge.\n            graph.edata[\"e\"] = graph.edata[\"e_tmp\"] + graph.edata[\"ee\"]\n\n            # Create new edges features that combine the\n            # features of the source node and the edge features.\n            graph.apply_edges(fn.u_add_e(\"ft\", \"ft_edge\", \"ft_combined\"))\n\n            e = self.leaky_relu(graph.edata.pop(\"e\"))\n            # Compute softmax.\n            graph.edata[\"a\"] = self.attn_drop(edge_softmax(graph, e))\n\n            # For each edge, element-wise multiply the combined features with\n            # the attention coefficient.\n            graph.edata[\"m_combined\"] = (\n                graph.edata[\"ft_combined\"] * graph.edata[\"a\"]\n            )\n\n            # First copy the edge features and then sum them up.\n            graph.update_all(fn.copy_e(\"m_combined\", \"m\"), fn.sum(\"m\", \"ft\"))\n\n            rst = graph.dstdata[\"ft\"]\n            # Residual.\n            if self.res_fc is not None:\n                # Use -1 rather than 
self._num_heads to handle broadcasting.\n                if h_dst.numel() != 0:\n                    resval = self.res_fc(h_dst).view(\n                        *dst_prefix_shape, -1, self._out_feats\n                    )\n                    rst = rst + resval\n            # Bias.\n            if self.bias is not None:\n                rst = rst + self.bias.view(\n                    *((1,) * len(dst_prefix_shape)),\n                    self._num_heads,\n                    self._out_feats\n                )\n            # Activation.\n            if self.activation:\n                rst = self.activation(rst)\n\n            if get_attention:\n                return rst, graph.edata[\"a\"]\n            else:\n                return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/egatconv.py",
    "content": "\"\"\"Torch modules for graph attention networks with fully valuable edges (EGAT).\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\nfrom torch.nn import init\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\nfrom ...functional import edge_softmax\n\n\n# pylint: enable=W0235\nclass EGATConv(nn.Module):\n    r\"\"\"Graph attention layer that handles edge features from `Rossmann-Toolbox\n    <https://pubmed.ncbi.nlm.nih.gov/34571541/>`__ (see supplementary data)\n\n    The difference lies in how unnormalized attention scores :math:`e_{ij}` are obtained:\n\n    .. math::\n        e_{ij} &= \\vec{F} (f_{ij}^{\\prime})\n\n        f_{ij}^{\\prime} &= \\mathrm{LeakyReLU}\\left(A [ h_{i} \\| f_{ij} \\| h_{j}]\\right)\n\n    where :math:`f_{ij}^{\\prime}` are edge features, :math:`\\mathrm{A}` is weight matrix and\n    :math:`\\vec{F}` is weight vector. After that, resulting node features\n    :math:`h_{i}^{\\prime}` are updated in the same way as in regular GAT.\n\n    Parameters\n    ----------\n    in_node_feats : int, or pair of ints\n        Input feature size; i.e, the number of dimensions of :math:`h_{i}`.\n        EGATConv can be applied on homogeneous graph and unidirectional\n        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.\n        If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``\n        specifies the input feature size on both the source and destination nodes.  
If\n        a scalar is given, the source and destination node feature size would take the\n        same value.\n    in_edge_feats : int\n        Input edge feature size :math:`f_{ij}`.\n    out_node_feats : int\n        Output node feature size.\n    out_edge_feats : int\n        Output edge feature size :math:`f_{ij}^{\\prime}`.\n    num_heads : int\n        Number of attention heads.\n    bias : bool, optional\n        If True, add bias term to :math:`f_{ij}^{\\prime}`. Defaults: ``True``.\n\n    Examples\n    ----------\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import EGATConv\n\n    >>> # Case 1: Homogeneous graph\n    >>> num_nodes, num_edges = 8, 30\n    >>> # generate a graph\n    >>> graph = dgl.rand_graph(num_nodes,num_edges)\n    >>> node_feats = th.rand((num_nodes, 20))\n    >>> edge_feats = th.rand((num_edges, 12))\n    >>> egat = EGATConv(in_node_feats=20,\n    ...                 in_edge_feats=12,\n    ...                 out_node_feats=15,\n    ...                 out_edge_feats=10,\n    ...                 num_heads=3)\n    >>> #forward pass\n    >>> new_node_feats, new_edge_feats = egat(graph, node_feats, edge_feats)\n    >>> new_node_feats.shape, new_edge_feats.shape\n    torch.Size([8, 3, 15]) torch.Size([30, 3, 10])\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})\n    >>> u_feat = th.tensor(np.random.rand(2, 25).astype(np.float32))\n    >>> v_feat = th.tensor(np.random.rand(4, 30).astype(np.float32))\n    >>> nfeats = (u_feat,v_feat)\n    >>> efeats = th.tensor(np.random.rand(5, 15).astype(np.float32))\n    >>> in_node_feats = (25,30)\n    >>> in_edge_feats = 15\n    >>> out_node_feats = 10\n    >>> out_edge_feats = 5\n    >>> num_heads = 3\n    >>> egat_model =  EGATConv(in_node_feats,\n    ...                        in_edge_feats,\n    ...                        out_node_feats,\n    ...               
         out_edge_feats,\n    ...                        num_heads,\n    ...                        bias=True)\n    >>> #forward pass\n    >>> new_node_feats,\n    >>> new_edge_feats,\n    >>> attentions = egat_model(g, nfeats, efeats, get_attention=True)\n    >>> new_node_feats.shape, new_edge_feats.shape, attentions.shape\n    (torch.Size([4, 3, 10]), torch.Size([5, 3, 5]), torch.Size([5, 3, 1]))\n    \"\"\"\n\n    def __init__(\n        self,\n        in_node_feats,\n        in_edge_feats,\n        out_node_feats,\n        out_edge_feats,\n        num_heads,\n        bias=True,\n    ):\n        super().__init__()\n        self._num_heads = num_heads\n        self._in_src_node_feats, self._in_dst_node_feats = expand_as_pair(\n            in_node_feats\n        )\n        self._out_node_feats = out_node_feats\n        self._out_edge_feats = out_edge_feats\n        if isinstance(in_node_feats, tuple):\n            self.fc_node_src = nn.Linear(\n                self._in_src_node_feats, out_node_feats * num_heads, bias=False\n            )\n            self.fc_ni = nn.Linear(\n                self._in_src_node_feats, out_edge_feats * num_heads, bias=False\n            )\n            self.fc_nj = nn.Linear(\n                self._in_dst_node_feats, out_edge_feats * num_heads, bias=False\n            )\n        else:\n            self.fc_node_src = nn.Linear(\n                self._in_src_node_feats, out_node_feats * num_heads, bias=False\n            )\n            self.fc_ni = nn.Linear(\n                self._in_src_node_feats, out_edge_feats * num_heads, bias=False\n            )\n            self.fc_nj = nn.Linear(\n                self._in_src_node_feats, out_edge_feats * num_heads, bias=False\n            )\n\n        self.fc_fij = nn.Linear(\n            in_edge_feats, out_edge_feats * num_heads, bias=False\n        )\n        self.attn = nn.Parameter(\n            th.FloatTensor(size=(1, num_heads, out_edge_feats))\n        )\n        if bias:\n            
self.bias = nn.Parameter(\n                th.FloatTensor(size=(num_heads * out_edge_feats,))\n            )\n        else:\n            self.register_buffer(\"bias\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"\n        Reinitialize learnable parameters.\n        \"\"\"\n        gain = init.calculate_gain(\"relu\")\n        init.xavier_normal_(self.fc_node_src.weight, gain=gain)\n        init.xavier_normal_(self.fc_ni.weight, gain=gain)\n        init.xavier_normal_(self.fc_fij.weight, gain=gain)\n        init.xavier_normal_(self.fc_nj.weight, gain=gain)\n        init.xavier_normal_(self.attn, gain=gain)\n        init.constant_(self.bias, 0)\n\n    def forward(\n        self, graph, nfeats, efeats, edge_weight=None, get_attention=False\n    ):\n        r\"\"\"\n        Compute new node and edge features.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        nfeats : torch.Tensor or pair of torch.Tensor\n            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})`\n            where:\n                :math:`D_{in}` is size of input node feature,\n                :math:`N` is the number of nodes.\n            If a pair of torch.Tensor is given, the pair must contain two tensors of shape\n                :math:`(N_{in}, D_{in_{src}})` and\n                :math:`(N_{out}, D_{in_{dst}})`.\n        efeats : torch.Tensor\n             The input edge feature of shape :math:`(E, F_{in})`\n             where:\n                 :math:`F_{in}` is size of input edge feature,\n                 :math:`E` is the number of edges.\n        edge_weight : torch.Tensor, optional\n            A 1D tensor of edge weight values.  Shape: :math:`(|E|,)`.\n        get_attention : bool, optional\n                Whether to return the attention values. 
Default to False.\n\n        Returns\n        -------\n        pair of torch.Tensor\n            node output features followed by edge output features.\n            The node output feature is of shape :math:`(N, H, D_{out})`\n            The edge output feature is of shape :math:`(E, H, F_{out})`\n            where:\n                :math:`H` is the number of heads,\n                :math:`D_{out}` is size of output node feature,\n                :math:`F_{out}` is size of output edge feature.\n        torch.Tensor, optional\n            The attention values of shape :math:`(E, H, 1)`.\n            This is returned only when :attr:`get_attention` is ``True``.\n        \"\"\"\n\n        with graph.local_scope():\n            if (graph.in_degrees() == 0).any():\n                raise DGLError(\n                    \"There are 0-in-degree nodes in the graph, \"\n                    \"output for those nodes will be invalid. \"\n                    \"This is harmful for some applications, \"\n                    \"causing silent performance regression. 
\"\n                    \"Adding self-loop on the input graph by \"\n                    \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                    \"the issue.\"\n                )\n\n            # calc edge attention\n            # same trick way as in dgl.nn.pytorch.GATConv, but also includes edge feats\n            # https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.py\n            if isinstance(nfeats, tuple):\n                nfeats_src, nfeats_dst = nfeats\n            else:\n                nfeats_src = nfeats_dst = nfeats\n\n            f_ni = self.fc_ni(nfeats_src)\n            f_nj = self.fc_nj(nfeats_dst)\n            f_fij = self.fc_fij(efeats)\n\n            graph.srcdata.update({\"f_ni\": f_ni})\n            graph.dstdata.update({\"f_nj\": f_nj})\n            # add ni, nj factors\n            graph.apply_edges(fn.u_add_v(\"f_ni\", \"f_nj\", \"f_tmp\"))\n            # add fij to node factor\n            f_out = graph.edata.pop(\"f_tmp\") + f_fij\n            if self.bias is not None:\n                f_out = f_out + self.bias\n            f_out = nn.functional.leaky_relu(f_out)\n            f_out = f_out.view(-1, self._num_heads, self._out_edge_feats)\n            # compute attention factor\n            e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)\n            graph.edata[\"a\"] = edge_softmax(graph, e)\n            if edge_weight is not None:\n                graph.edata[\"a\"] = graph.edata[\"a\"] * edge_weight.tile(\n                    1, self._num_heads, 1\n                ).transpose(0, 2)\n            graph.srcdata[\"h_out\"] = self.fc_node_src(nfeats_src).view(\n                -1, self._num_heads, self._out_node_feats\n            )\n            # calc weighted sum\n            graph.update_all(\n                fn.u_mul_e(\"h_out\", \"a\", \"m\"), fn.sum(\"m\", \"h_out\")\n            )\n\n            h_out = graph.dstdata[\"h_out\"].view(\n                -1, self._num_heads, 
self._out_node_feats\n            )\n            if get_attention:\n                return h_out, f_out, graph.edata.pop(\"a\")\n            else:\n                return h_out, f_out\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/egnnconv.py",
    "content": "\"\"\"Torch Module for E(n) Equivariant Graph Convolutional Layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch\nimport torch.nn as nn\n\nfrom .... import function as fn\n\n\nclass EGNNConv(nn.Module):\n    r\"\"\"Equivariant Graph Convolutional Layer from `E(n) Equivariant Graph\n    Neural Networks <https://arxiv.org/abs/2102.09844>`__\n\n    .. math::\n\n        m_{ij}=\\phi_e(h_i^l, h_j^l, ||x_i^l-x_j^l||^2, a_{ij})\n\n        x_i^{l+1} = x_i^l + C\\sum_{j\\in\\mathcal{N}(i)}(x_i^l-x_j^l)\\phi_x(m_{ij})\n\n        m_i = \\sum_{j\\in\\mathcal{N}(i)} m_{ij}\n\n        h_i^{l+1} = \\phi_h(h_i^l, m_i)\n\n    where :math:`h_i`, :math:`x_i`, :math:`a_{ij}` are node features, coordinate\n    features, and edge features respectively. :math:`\\phi_e`, :math:`\\phi_h`, and\n    :math:`\\phi_x` are two-layer MLPs. :math:`C` is a constant for normalization,\n    computed as :math:`1/|\\mathcal{N}(i)|`.\n\n    Parameters\n    ----------\n    in_size : int\n        Input feature size; i.e. the size of :math:`h_i^l`.\n    hidden_size : int\n        Hidden feature size; i.e. the size of hidden layer in the two-layer MLPs in\n        :math:`\\phi_e, \\phi_x, \\phi_h`.\n    out_size : int\n        Output feature size; i.e. the size of :math:`h_i^{l+1}`.\n    edge_feat_size : int, optional\n        Edge feature size; i.e. the size of :math:`a_{ij}`. 
Default: 0.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import EGNNConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> node_feat, coord_feat, edge_feat = th.ones(6, 10), th.ones(6, 3), th.ones(6, 2)\n    >>> conv = EGNNConv(10, 10, 10, 2)\n    >>> h, x = conv(g, node_feat, coord_feat, edge_feat)\n    \"\"\"\n\n    def __init__(self, in_size, hidden_size, out_size, edge_feat_size=0):\n        super(EGNNConv, self).__init__()\n\n        self.in_size = in_size\n        self.hidden_size = hidden_size\n        self.out_size = out_size\n        self.edge_feat_size = edge_feat_size\n        act_fn = nn.SiLU()\n\n        # \\phi_e\n        self.edge_mlp = nn.Sequential(\n            # +1 for the radial feature: ||x_i - x_j||^2\n            nn.Linear(in_size * 2 + edge_feat_size + 1, hidden_size),\n            act_fn,\n            nn.Linear(hidden_size, hidden_size),\n            act_fn,\n        )\n\n        # \\phi_h\n        self.node_mlp = nn.Sequential(\n            nn.Linear(in_size + hidden_size, hidden_size),\n            act_fn,\n            nn.Linear(hidden_size, out_size),\n        )\n\n        # \\phi_x\n        self.coord_mlp = nn.Sequential(\n            nn.Linear(hidden_size, hidden_size),\n            act_fn,\n            nn.Linear(hidden_size, 1, bias=False),\n        )\n\n    def message(self, edges):\n        \"\"\"message function for EGNN\"\"\"\n        # concat features for edge mlp\n        if self.edge_feat_size > 0:\n            f = torch.cat(\n                [\n                    edges.src[\"h\"],\n                    edges.dst[\"h\"],\n                    edges.data[\"radial\"],\n                    edges.data[\"a\"],\n                ],\n                dim=-1,\n            )\n        else:\n            f = torch.cat(\n                [edges.src[\"h\"], edges.dst[\"h\"], edges.data[\"radial\"]], dim=-1\n            )\n\n        msg_h = self.edge_mlp(f)\n        msg_x 
= self.coord_mlp(msg_h) * edges.data[\"x_diff\"]\n\n        return {\"msg_x\": msg_x, \"msg_h\": msg_h}\n\n    def forward(self, graph, node_feat, coord_feat, edge_feat=None):\n        r\"\"\"\n        Description\n        -----------\n        Compute EGNN layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        node_feat : torch.Tensor\n            The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of\n            nodes, and :math:`h_n` must be the same as in_size.\n        coord_feat : torch.Tensor\n            The coordinate feature of shape :math:`(N, h_x)`. :math:`N` is the\n            number of nodes, and :math:`h_x` can be any positive integer.\n        edge_feat : torch.Tensor, optional\n            The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number of\n            edges, and :math:`h_e` must be the same as edge_feat_size.\n\n        Returns\n        -------\n        node_feat_out : torch.Tensor\n            The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`\n            is the same as out_size.\n        coord_feat_out: torch.Tensor\n            The output coordinate feature of shape :math:`(N, h_x)` where :math:`h_x`\n            is the same as the input coordinate feature dimension.\n        \"\"\"\n        with graph.local_scope():\n            # node feature\n            graph.ndata[\"h\"] = node_feat\n            # coordinate feature\n            graph.ndata[\"x\"] = coord_feat\n            # edge feature\n            if self.edge_feat_size > 0:\n                assert edge_feat is not None, \"Edge features must be provided.\"\n                graph.edata[\"a\"] = edge_feat\n            # get coordinate diff & radial features\n            graph.apply_edges(fn.u_sub_v(\"x\", \"x\", \"x_diff\"))\n            graph.edata[\"radial\"] = (\n                graph.edata[\"x_diff\"].square().sum(dim=1).unsqueeze(-1)\n            )\n            # normalize 
coordinate difference\n            graph.edata[\"x_diff\"] = graph.edata[\"x_diff\"] / (\n                graph.edata[\"radial\"].sqrt() + 1e-30\n            )\n            graph.apply_edges(self.message)\n            graph.update_all(fn.copy_e(\"msg_x\", \"m\"), fn.mean(\"m\", \"x_neigh\"))\n            graph.update_all(fn.copy_e(\"msg_h\", \"m\"), fn.sum(\"m\", \"h_neigh\"))\n\n            h_neigh, x_neigh = graph.ndata[\"h_neigh\"], graph.ndata[\"x_neigh\"]\n\n            h = self.node_mlp(torch.cat([node_feat, h_neigh], dim=-1))\n            x = coord_feat + x_neigh\n\n            return h, x\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/gatconv.py",
    "content": "\"\"\"Torch modules for graph attention networks(GAT).\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\nfrom ...functional import edge_softmax\nfrom ..utils import Identity\n\n\n# pylint: enable=W0235\nclass GATConv(nn.Module):\n    r\"\"\"Graph attention layer from `Graph Attention Network\n    <https://arxiv.org/pdf/1710.10903.pdf>`__\n\n    .. math::\n        h_i^{(l+1)} = \\sum_{j\\in \\mathcal{N}(i)} \\alpha_{i,j} W^{(l)} h_j^{(l)}\n\n    where :math:`\\alpha_{ij}` is the attention score bewteen node :math:`i` and\n    node :math:`j`:\n\n    .. math::\n        \\alpha_{ij}^{l} &= \\mathrm{softmax_i} (e_{ij}^{l})\n\n        e_{ij}^{l} &= \\mathrm{LeakyReLU}\\left(\\vec{a}^T [W h_{i} \\| W h_{j}]\\right)\n\n    Parameters\n    ----------\n    in_feats : int, or pair of ints\n        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n        GATConv can be applied on homogeneous graph and unidirectional\n        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.\n        If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``\n        specifies the input feature size on both the source and destination nodes.  If\n        a scalar is given, the source and destination node feature size would take the\n        same value.\n    out_feats : int\n        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.\n    num_heads : int\n        Number of heads in Multi-Head Attention.\n    feat_drop : float, optional\n        Dropout rate on feature. Defaults: ``0``.\n    attn_drop : float, optional\n        Dropout rate on attention weight. Defaults: ``0``.\n    negative_slope : float, optional\n        LeakyReLU angle of negative slope. 
Defaults: ``0.2``.\n    residual : bool, optional\n        If True, use residual connection. Defaults: ``False``.\n    activation : callable activation function/layer or None, optional.\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Defaults: ``False``.\n    bias : bool, optional\n        If True, learns a bias term. Defaults: ``True``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. 
Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import GATConv\n\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = th.ones(6, 10)\n    >>> gatconv = GATConv(10, 2, num_heads=3)\n    >>> res = gatconv(g, feat)\n    >>> res\n    tensor([[[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]],\n            [[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]],\n            [[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]],\n            [[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]],\n            [[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]],\n            [[ 3.4570,  1.8634],\n            [ 1.3805, -0.0762],\n            [ 1.0390, -1.1479]]], grad_fn=<BinaryReduceBackward>)\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})\n    >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))\n    >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))\n    >>> gatconv = GATConv((5,10), 2, 3)\n    >>> res = gatconv(g, (u_feat, v_feat))\n    >>> res\n    tensor([[[-0.6066,  1.0268],\n            [-0.5945, -0.4801],\n            [ 0.1594,  0.3825]],\n            [[ 0.0268,  1.0783],\n            [ 0.5041, -1.3025],\n            [ 0.6568,  0.7048]],\n            [[-0.2688,  1.0543],\n            [-0.0315, -0.9016],\n            [ 0.3943,  0.5347]],\n            [[-0.6066,  
1.0268],\n            [-0.5945, -0.4801],\n            [ 0.1594,  0.3825]]], grad_fn=<BinaryReduceBackward>)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        num_heads,\n        feat_drop=0.0,\n        attn_drop=0.0,\n        negative_slope=0.2,\n        residual=False,\n        activation=None,\n        allow_zero_in_degree=False,\n        bias=True,\n    ):\n        super(GATConv, self).__init__()\n        self._num_heads = num_heads\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        self._allow_zero_in_degree = allow_zero_in_degree\n        if isinstance(in_feats, tuple):\n            self.fc_src = nn.Linear(\n                self._in_src_feats, out_feats * num_heads, bias=False\n            )\n            self.fc_dst = nn.Linear(\n                self._in_dst_feats, out_feats * num_heads, bias=False\n            )\n        else:\n            self.fc = nn.Linear(\n                self._in_src_feats, out_feats * num_heads, bias=False\n            )\n        self.attn_l = nn.Parameter(\n            th.FloatTensor(size=(1, num_heads, out_feats))\n        )\n        self.attn_r = nn.Parameter(\n            th.FloatTensor(size=(1, num_heads, out_feats))\n        )\n        self.feat_drop = nn.Dropout(feat_drop)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.leaky_relu = nn.LeakyReLU(negative_slope)\n\n        self.has_linear_res = False\n        self.has_explicit_bias = False\n        if residual:\n            if self._in_dst_feats != out_feats * num_heads:\n                self.res_fc = nn.Linear(\n                    self._in_dst_feats, num_heads * out_feats, bias=bias\n                )\n                self.has_linear_res = True\n            else:\n                self.res_fc = Identity()\n        else:\n            self.register_buffer(\"res_fc\", None)\n\n        if bias and not self.has_linear_res:\n            self.bias = 
nn.Parameter(\n                th.FloatTensor(size=(num_heads * out_feats,))\n            )\n            self.has_explicit_bias = True\n        else:\n            self.register_buffer(\"bias\", None)\n\n        self.reset_parameters()\n        self.activation = activation\n\n    def reset_parameters(self):\n        \"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n\n        Note\n        ----\n        The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.\n        The attention weights are using xavier initialization method.\n        \"\"\"\n        gain = nn.init.calculate_gain(\"relu\")\n        if hasattr(self, \"fc\"):\n            nn.init.xavier_normal_(self.fc.weight, gain=gain)\n        else:\n            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)\n            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)\n        nn.init.xavier_normal_(self.attn_l, gain=gain)\n        nn.init.xavier_normal_(self.attn_r, gain=gain)\n        if self.has_explicit_bias:\n            nn.init.constant_(self.bias, 0)\n        if isinstance(self.res_fc, nn.Linear):\n            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)\n            if self.res_fc.bias is not None:\n                nn.init.constant_(self.res_fc.bias, 0)\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat, edge_weight=None, get_attention=False):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute graph attention network layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor or pair of torch.Tensor\n      
      If a torch.Tensor is given, the input feature of shape :math:`(N, *, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of torch.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.\n        edge_weight : torch.Tensor, optional\n            A 1D tensor of edge weight values.  Shape: :math:`(|E|,)`.\n        get_attention : bool, optional\n            Whether to return the attention values. Default to False.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`\n            is the number of heads, and :math:`D_{out}` is size of output feature.\n        torch.Tensor, optional\n            The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of\n            edges. This is returned only when :attr:`get_attention` is ``True``.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. 
\"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            if isinstance(feat, tuple):\n                src_prefix_shape = feat[0].shape[:-1]\n                dst_prefix_shape = feat[1].shape[:-1]\n                h_src = self.feat_drop(feat[0])\n                h_dst = self.feat_drop(feat[1])\n                if not hasattr(self, \"fc_src\"):\n                    feat_src = self.fc(h_src).view(\n                        *src_prefix_shape, self._num_heads, self._out_feats\n                    )\n                    feat_dst = self.fc(h_dst).view(\n                        *dst_prefix_shape, self._num_heads, self._out_feats\n                    )\n                else:\n                    feat_src = self.fc_src(h_src).view(\n                        *src_prefix_shape, self._num_heads, self._out_feats\n                    )\n                    feat_dst = self.fc_dst(h_dst).view(\n                        *dst_prefix_shape, self._num_heads, self._out_feats\n                    )\n            else:\n                src_prefix_shape = dst_prefix_shape = feat.shape[:-1]\n                h_src = h_dst = self.feat_drop(feat)\n                feat_src = feat_dst = self.fc(h_src).view(\n                    *src_prefix_shape, self._num_heads, self._out_feats\n                )\n                if graph.is_block:\n                    feat_dst = feat_src[: graph.number_of_dst_nodes()]\n                    h_dst = h_dst[: graph.number_of_dst_nodes()]\n                    dst_prefix_shape = (\n                        graph.number_of_dst_nodes(),\n                    ) + dst_prefix_shape[1:]\n            # NOTE: GAT paper uses \"first 
concatenation then linear projection\"\n            # to compute attention scores, while ours is \"first projection then\n            # addition\", the two approaches are mathematically equivalent:\n            # We decompose the weight vector a mentioned in the paper into\n            # [a_l || a_r], then\n            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j\n            # Our implementation is much efficient because we do not need to\n            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,\n            # addition could be optimized with DGL's built-in function u_add_v,\n            # which further speeds up computation and saves memory footprint.\n            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)\n            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)\n            graph.srcdata.update({\"ft\": feat_src, \"el\": el})\n            graph.dstdata.update({\"er\": er})\n            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.\n            graph.apply_edges(fn.u_add_v(\"el\", \"er\", \"e\"))\n            e = self.leaky_relu(graph.edata.pop(\"e\"))\n            # compute softmax\n            graph.edata[\"a\"] = self.attn_drop(edge_softmax(graph, e))\n            if edge_weight is not None:\n                graph.edata[\"a\"] = graph.edata[\"a\"] * edge_weight.tile(\n                    1, self._num_heads, 1\n                ).transpose(0, 2)\n            # message passing\n            graph.update_all(fn.u_mul_e(\"ft\", \"a\", \"m\"), fn.sum(\"m\", \"ft\"))\n            rst = graph.dstdata[\"ft\"]\n            # residual\n            if self.res_fc is not None:\n                # Use -1 rather than self._num_heads to handle broadcasting\n                if h_dst.numel() != 0:\n                    resval = self.res_fc(h_dst).view(\n                        *dst_prefix_shape, -1, self._out_feats\n                    )\n                    rst = rst + resval\n            # bias\n        
    if self.has_explicit_bias:\n                rst = rst + self.bias.view(\n                    *((1,) * len(dst_prefix_shape)),\n                    self._num_heads,\n                    self._out_feats\n                )\n            # activation\n            if self.activation:\n                rst = self.activation(rst)\n\n            if get_attention:\n                return rst, graph.edata[\"a\"]\n            else:\n                return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/gatedgcnconv.py",
    "content": "\"\"\"Torch Module for GatedGCN layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, cell-var-from-loop\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom .... import function as fn\n\n\nclass GatedGCNConv(nn.Module):\n    r\"\"\"Gated graph convolutional layer from `Benchmarking Graph Neural Networks\n    <https://arxiv.org/abs/2003.00982>`__\n\n    .. math::\n        e_{ij}^{l+1}=D^l h_{i}^{l}+E^l h_{j}^{l}+C^l e_{ij}^{l}\n\n        norm_{ij}=\\Sigma_{j\\in N_{i}} \\sigma\\left(e_{ij}^{l+1}\\right)+\\varepsilon\n\n        \\hat{e}_{ij}^{l+1}=\\sigma(e_{ij}^{l+1}) / norm_{ij}\n\n        h_{i}^{l+1}=A^l h_{i}^{l}+\\Sigma_{j \\in N_{i}} \\hat{e}_{ij}^{l+1} \\odot B^l h_{j}^{l}\n\n    where :math:`h_{i}^{l}` is node :math:`i` feature of layer :math:`l`,\n    :math:`e_{ij}^{l}` is edge :math:`ij` feature of layer :math:`l`,\n    :math:`\\sigma` is sigmoid function, :math:`\\varepsilon` is a small fixed constant\n    for numerical stability, :math:`A^l, B^l, C^l, D^l, E^l` are linear layers.\n\n    Parameters\n    ----------\n    input_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`h_{i}^{l}`.\n    edge_feats: int\n        Edge feature size; i.e., the number of dimensions of :math:`e_{ij}^{l}`.\n    output_feats : int\n        Output feature size; i.e., the number of dimensions of :math:`h_{i}^{l+1}`.\n    dropout : float, optional\n        Dropout rate on node and edge feature. Default: ``0``.\n    batch_norm : bool, optional\n        Whether to include batch normalization on node and edge feature. Default: ``True``.\n    residual : bool, optional\n        Whether to include residual connections. 
Default: ``True``.\n    activation : callable activation function/layer or None, optional\n        If not None, apply an activation function to the updated node features.\n        Default: ``F.relu``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import torch as th\n    >>> import torch.nn.functional as F\n    >>> from dgl.nn import GatedGCNConv\n\n    >>> num_nodes, num_edges = 8, 30\n    >>> graph = dgl.rand_graph(num_nodes,num_edges)\n    >>> node_feats = th.rand(num_nodes, 20)\n    >>> edge_feats = th.rand(num_edges, 12)\n    >>> gatedGCN = GatedGCNConv(20, 12, 20)\n    >>> new_node_feats, new_edge_feats = gatedGCN(graph, node_feats, edge_feats)\n    >>> new_node_feats.shape, new_edge_feats.shape\n    (torch.Size([8, 20]), torch.Size([30, 20]))\n\n    \"\"\"\n\n    def __init__(\n        self,\n        input_feats,\n        edge_feats,\n        output_feats,\n        dropout=0,\n        batch_norm=True,\n        residual=True,\n        activation=F.relu,\n    ):\n        super(GatedGCNConv, self).__init__()\n        self.dropout = nn.Dropout(dropout)\n        self.batch_norm = batch_norm\n        self.residual = residual\n\n        if input_feats != output_feats or edge_feats != output_feats:\n            self.residual = False\n\n        # Linearly transform the node features.\n        self.A = nn.Linear(input_feats, output_feats, bias=True)\n        self.B = nn.Linear(input_feats, output_feats, bias=True)\n        self.D = nn.Linear(input_feats, output_feats, bias=True)\n        self.E = nn.Linear(input_feats, output_feats, bias=True)\n\n        # Linearly transform the edge features.\n        self.C = nn.Linear(edge_feats, output_feats, bias=True)\n\n        # Batch normalization on the node/edge features.\n        self.bn_node = nn.BatchNorm1d(output_feats)\n        self.bn_edge = nn.BatchNorm1d(output_feats)\n\n        self.activation = activation\n\n    def forward(self, graph, feat, edge_feat):\n        \"\"\"\n\n        Description\n        
-----------\n        Compute gated graph convolution layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            The input feature of shape :math:`(N, D_{in})` where :math:`N`\n            is the number of nodes of the graph and :math:`D_{in}` is the\n            input feature size.\n        edge_feat : torch.Tensor\n            The input edge feature of shape :math:`(E, D_{edge})`,\n            where :math:`E` is the number of edges and :math:`D_{edge}`\n            is the size of the edge features.\n\n        Returns\n        -------\n        torch.Tensor\n            The output node feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is the output feature size.\n        torch.Tensor\n            The output edge feature of shape :math:`(E, D_{out})` where :math:`D_{out}`\n            is the output feature size.\n        \"\"\"\n        with graph.local_scope():\n            # For residual connection\n            h_in = feat\n            e_in = edge_feat\n\n            graph.ndata[\"Ah\"] = self.A(feat)\n            graph.ndata[\"Bh\"] = self.B(feat)\n            graph.ndata[\"Dh\"] = self.D(feat)\n            graph.ndata[\"Eh\"] = self.E(feat)\n            graph.edata[\"Ce\"] = self.C(edge_feat)\n\n            graph.apply_edges(fn.u_add_v(\"Dh\", \"Eh\", \"DEh\"))\n\n            # Get edge feature\n            graph.edata[\"e\"] = graph.edata[\"DEh\"] + graph.edata[\"Ce\"]\n            graph.edata[\"sigma\"] = torch.sigmoid(graph.edata[\"e\"])\n\n            graph.update_all(\n                fn.u_mul_e(\"Bh\", \"sigma\", \"m\"), fn.sum(\"m\", \"sum_sigma_h\")\n            )\n            graph.update_all(fn.copy_e(\"sigma\", \"m\"), fn.sum(\"m\", \"sum_sigma\"))\n            graph.ndata[\"h\"] = graph.ndata[\"Ah\"] + graph.ndata[\n                \"sum_sigma_h\"\n            ] / (graph.ndata[\"sum_sigma\"] + 1e-6)\n\n            # Result of graph convolution.\n   
         feat = graph.ndata[\"h\"]\n            edge_feat = graph.edata[\"e\"]\n\n            # Batch normalization.\n            if self.batch_norm:\n                feat = self.bn_node(feat)\n                edge_feat = self.bn_edge(edge_feat)\n\n            # Non-linear activation.\n            if self.activation:\n                feat = self.activation(feat)\n                edge_feat = self.activation(edge_feat)\n\n            # Residual connection.\n            if self.residual:\n                feat = h_in + feat\n                edge_feat = e_in + edge_feat\n\n            feat = self.dropout(feat)\n            edge_feat = self.dropout(edge_feat)\n\n            return feat, edge_feat\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/gatedgraphconv.py",
    "content": "\"\"\"Torch Module for Gated Graph Convolution layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, cell-var-from-loop\nimport torch as th\nfrom torch import nn\nfrom torch.nn import init\n\nfrom .... import function as fn\n\n\nclass GatedGraphConv(nn.Module):\n    r\"\"\"Gated Graph Convolution layer from `Gated Graph Sequence\n    Neural Networks <https://arxiv.org/pdf/1511.05493.pdf>`__\n\n    .. math::\n        h_{i}^{0} &= [ x_i \\| \\mathbf{0} ]\n\n        a_{i}^{t} &= \\sum_{j\\in\\mathcal{N}(i)} W_{e_{ij}} h_{j}^{t}\n\n        h_{i}^{t+1} &= \\mathrm{GRU}(a_{i}^{t}, h_{i}^{t})\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`x_i`.\n    out_feats : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(t+1)}`.\n    n_steps : int\n        Number of recurrent steps; i.e, the :math:`t` in the above formula.\n    n_etypes : int\n        Number of edge types.\n    bias : bool\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import GatedGraphConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = th.ones(6, 10)\n    >>> conv = GatedGraphConv(10, 10, 2, 3)\n    >>> etype = th.tensor([0,1,2,0,1,2])\n    >>> res = conv(g, feat, etype)\n    >>> res\n    tensor([[ 0.4652,  0.4458,  0.5169,  0.4126,  0.4847,  0.2303,  0.2757,  0.7721,\n            0.0523,  0.0857],\n            [ 0.0832,  0.1388, -0.5643,  0.7053, -0.2524, -0.3847,  0.7587,  0.8245,\n            0.9315,  0.4063],\n            [ 0.6340,  0.4096,  0.7692,  0.2125,  0.2106,  0.4542, -0.0580,  0.3364,\n            -0.1376,  0.4948],\n            [ 0.5551,  0.7946,  0.6220,  0.8058,  0.5711,  0.3063, -0.5454,  0.2272,\n            -0.6931, -0.1607],\n            [ 0.2644,  0.2469, -0.6143,  0.6008, -0.1516, -0.3781,  0.5878,  0.7993,\n            0.9241,  0.1835],\n            [ 0.6393,  0.3447,  0.3893,  0.4279,  0.3342,  0.3809,  0.0406,  0.5030,\n            0.1342,  0.0425]], grad_fn=<AddBackward0>)\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, n_steps, n_etypes, bias=True):\n        super(GatedGraphConv, self).__init__()\n        assert in_feats <= out_feats, \"out_feats must be not less than in_feats\"\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._n_steps = n_steps\n        self._n_etypes = n_etypes\n        self.linears = nn.ModuleList(\n            [nn.Linear(out_feats, out_feats) for _ in range(n_etypes)]\n        )\n        self.gru = nn.GRUCell(out_feats, out_feats, bias=bias)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n\n        Note\n        ----\n        The model parameters are initialized using Glorot uniform initialization\n        and the bias is initialized to 
be zero.\n        \"\"\"\n        gain = init.calculate_gain(\"relu\")\n        self.gru.reset_parameters()\n        for linear in self.linears:\n            init.xavier_normal_(linear.weight, gain=gain)\n            init.zeros_(linear.bias)\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat, etypes=None):\n        \"\"\"\n\n        Description\n        -----------\n        Compute Gated Graph Convolution layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            The input feature of shape :math:`(N, D_{in})` where :math:`N`\n            is the number of nodes of the graph and :math:`D_{in}` is the\n            input feature size.\n        etypes : torch.LongTensor, or None\n            The edge type tensor of shape :math:`(E,)` where :math:`E` is\n            the number of edges of the graph. 
When there's only one edge type,\n            this argument can be skipped\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is the output feature size.\n        \"\"\"\n        with graph.local_scope():\n            assert graph.is_homogeneous, (\n                \"not a homogeneous graph; convert it with to_homogeneous \"\n                \"and pass in the edge type as argument\"\n            )\n            if self._n_etypes != 1:\n                assert (\n                    etypes.min() >= 0 and etypes.max() < self._n_etypes\n                ), \"edge type indices out of range [0, {})\".format(\n                    self._n_etypes\n                )\n\n            zero_pad = feat.new_zeros(\n                (feat.shape[0], self._out_feats - feat.shape[1])\n            )\n            feat = th.cat([feat, zero_pad], -1)\n\n            for _ in range(self._n_steps):\n                if self._n_etypes == 1 and etypes is None:\n                    # Fast path when graph has only one edge type\n                    graph.ndata[\"h\"] = self.linears[0](feat)\n                    graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"a\"))\n                    a = graph.ndata.pop(\"a\")  # (N, D)\n                else:\n                    graph.ndata[\"h\"] = feat\n                    for i in range(self._n_etypes):\n                        eids = (\n                            th.nonzero(etypes == i, as_tuple=False)\n                            .view(-1)\n                            .type(graph.idtype)\n                        )\n                        if len(eids) > 0:\n                            graph.apply_edges(\n                                lambda edges: {\n                                    \"W_e*h\": self.linears[i](edges.src[\"h\"])\n                                },\n                                eids,\n                            )\n           
         graph.update_all(fn.copy_e(\"W_e*h\", \"m\"), fn.sum(\"m\", \"a\"))\n                    a = graph.ndata.pop(\"a\")  # (N, D)\n                feat = self.gru(a, feat)\n            return feat\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/gatv2conv.py",
    "content": "\"\"\"Torch modules for graph attention networks v2 (GATv2).\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\nfrom ...functional import edge_softmax\nfrom ..utils import Identity\n\n\n# pylint: enable=W0235\nclass GATv2Conv(nn.Module):\n    r\"\"\"GATv2 from `How Attentive are Graph Attention Networks?\n    <https://arxiv.org/pdf/2105.14491.pdf>`__\n\n    .. math::\n        h_i^{(l+1)} = \\sum_{j\\in \\mathcal{N}(i)} \\alpha_{ij}^{(l)} W^{(l)}_{right} h_j^{(l)}\n\n    where :math:`\\alpha_{ij}` is the attention score between node :math:`i` and\n    node :math:`j`:\n\n    .. math::\n        \\alpha_{ij}^{(l)} &= \\mathrm{softmax_i} (e_{ij}^{(l)})\n\n        e_{ij}^{(l)} &= {\\vec{a}^T}^{(l)}\\mathrm{LeakyReLU}\\left(\n            W^{(l)}_{left} h_{i} + W^{(l)}_{right} h_{j}\\right)\n\n    Parameters\n    ----------\n    in_feats : int, or pair of ints\n        Input feature size; i.e., the number of dimensions of :math:`h_i^{(l)}`.\n        If the layer is to be applied to a unidirectional bipartite graph, `in_feats`\n        specifies the input feature size on both the source and destination nodes.\n        If a scalar is given, the source and destination node feature size\n        would take the same value.\n    out_feats : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    num_heads : int\n        Number of heads in Multi-Head Attention.\n    feat_drop : float, optional\n        Dropout rate on feature. Defaults: ``0``.\n    attn_drop : float, optional\n        Dropout rate on attention weight. Defaults: ``0``.\n    negative_slope : float, optional\n        LeakyReLU angle of negative slope. Defaults: ``0.2``.\n    residual : bool, optional\n        If True, use residual connection. 
Defaults: ``False``.\n    activation : callable activation function/layer or None, optional.\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Defaults: ``False``.\n    bias : bool, optional\n        If set to :obj:`False`, the layer will not learn\n        an additive bias. (default: :obj:`True`)\n    share_weights : bool, optional\n        If set to :obj:`True`, the same matrix for :math:`W_{left}` and :math:`W_{right}` in\n        the above equations, will be applied to the source and the target node of every edge.\n        (default: :obj:`False`)\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. 
Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practice to handle this is to filter out the nodes with zero-in-degree when using\n    the output after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import GATv2Conv\n\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = th.ones(6, 10)\n    >>> gatv2conv = GATv2Conv(10, 2, num_heads=3)\n    >>> res = gatv2conv(g, feat)\n    >>> res\n    tensor([[[ 1.9599,  1.0239],\n            [ 3.2015, -0.5512],\n            [ 2.3700, -2.2182]],\n            [[ 1.9599,  1.0239],\n            [ 3.2015, -0.5512],\n            [ 2.3700, -2.2182]],\n            [[ 1.9599,  1.0239],\n            [ 3.2015, -0.5512],\n            [ 2.3700, -2.2182]],\n            [[ 1.9599,  1.0239],\n            [ 3.2015, -0.5512],\n            [ 2.3700, -2.2182]],\n            [[ 1.9599,  1.0239],\n            [ 3.2015, -0.5512],\n            [ 2.3700, -2.2182]],\n            [[ 1.9599,  1.0239],\n            [ 3.2015, -0.5512],\n            [ 2.3700, -2.2182]]], grad_fn=<GSpMMBackward>)\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})\n    >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))\n    >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))\n    >>> gatv2conv = GATv2Conv((5,10), 2, 3)\n    >>> res = gatv2conv(g, (u_feat, v_feat))\n    >>> res\n    tensor([[[-0.0935, -0.4273],\n            [-1.1850,  0.1123],\n            [-0.2002,  0.1155]],\n            [[ 0.1908, -1.2095],\n            [-0.0129,  0.6408],\n            [-0.8135,  0.1157]],\n            [[ 0.0596, -0.8487],\n            [-0.5421,  0.4022],\n            [-0.4805,  0.1156]],\n            
[[-0.0935, -0.4273],\n            [-1.1850,  0.1123],\n            [-0.2002,  0.1155]]], grad_fn=<GSpMMBackward>)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        num_heads,\n        feat_drop=0.0,\n        attn_drop=0.0,\n        negative_slope=0.2,\n        residual=False,\n        activation=None,\n        allow_zero_in_degree=False,\n        bias=True,\n        share_weights=False,\n    ):\n        super(GATv2Conv, self).__init__()\n        self._num_heads = num_heads\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        self._allow_zero_in_degree = allow_zero_in_degree\n        if isinstance(in_feats, tuple):\n            self.fc_src = nn.Linear(\n                self._in_src_feats, out_feats * num_heads, bias=bias\n            )\n            self.fc_dst = nn.Linear(\n                self._in_dst_feats, out_feats * num_heads, bias=bias\n            )\n        else:\n            self.fc_src = nn.Linear(\n                self._in_src_feats, out_feats * num_heads, bias=bias\n            )\n            if share_weights:\n                self.fc_dst = self.fc_src\n            else:\n                self.fc_dst = nn.Linear(\n                    self._in_src_feats, out_feats * num_heads, bias=bias\n                )\n        self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))\n        self.feat_drop = nn.Dropout(feat_drop)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.leaky_relu = nn.LeakyReLU(negative_slope)\n        if residual:\n            if self._in_dst_feats != out_feats * num_heads:\n                self.res_fc = nn.Linear(\n                    self._in_dst_feats, num_heads * out_feats, bias=bias\n                )\n            else:\n                self.res_fc = Identity()\n        else:\n            self.register_buffer(\"res_fc\", None)\n        self.activation = activation\n        self.share_weights = 
share_weights\n        self.bias = bias\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"\n        Description\n        -----------\n        Reinitialize learnable parameters.\n\n        Note\n        ----\n        The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.\n        The attention weights are using xavier initialization method.\n        \"\"\"\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_normal_(self.fc_src.weight, gain=gain)\n        if self.bias:\n            nn.init.constant_(self.fc_src.bias, 0)\n        if not self.share_weights:\n            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)\n            if self.bias:\n                nn.init.constant_(self.fc_dst.bias, 0)\n        nn.init.xavier_normal_(self.attn, gain=gain)\n        if isinstance(self.res_fc, nn.Linear):\n            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)\n            if self.bias:\n                nn.init.constant_(self.res_fc.bias, 0)\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat, get_attention=False):\n        r\"\"\"\n        Description\n        -----------\n        Compute graph attention network layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor or pair of torch.Tensor\n            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of torch.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, 
D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.\n        get_attention : bool, optional\n            Whether to return the attention values. Default to False.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`\n            is the number of heads, and :math:`D_{out}` is size of output feature.\n        torch.Tensor, optional\n            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of\n            edges. This is returned only when :attr:`get_attention` is ``True``.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. 
Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            if isinstance(feat, tuple):\n                h_src = self.feat_drop(feat[0])\n                h_dst = self.feat_drop(feat[1])\n                feat_src = self.fc_src(h_src).view(\n                    -1, self._num_heads, self._out_feats\n                )\n                feat_dst = self.fc_dst(h_dst).view(\n                    -1, self._num_heads, self._out_feats\n                )\n            else:\n                h_src = h_dst = self.feat_drop(feat)\n                feat_src = self.fc_src(h_src).view(\n                    -1, self._num_heads, self._out_feats\n                )\n                if self.share_weights:\n                    feat_dst = feat_src\n                else:\n                    feat_dst = self.fc_dst(h_dst).view(\n                        -1, self._num_heads, self._out_feats\n                    )\n                if graph.is_block:\n                    feat_dst = feat_dst[: graph.number_of_dst_nodes()]\n                    h_dst = h_dst[: graph.number_of_dst_nodes()]\n            graph.srcdata.update(\n                {\"el\": feat_src}\n            )  # (num_src_edge, num_heads, out_dim)\n            graph.dstdata.update({\"er\": feat_dst})\n            graph.apply_edges(fn.u_add_v(\"el\", \"er\", \"e\"))\n            e = self.leaky_relu(\n                graph.edata.pop(\"e\")\n            )  # (num_src_edge, num_heads, out_dim)\n            e = (\n                (e * self.attn).sum(dim=-1).unsqueeze(dim=2)\n            )  # (num_edge, num_heads, 1)\n            # compute softmax\n            graph.edata[\"a\"] = self.attn_drop(\n                edge_softmax(graph, e)\n            )  # (num_edge, num_heads)\n            # message passing\n            graph.update_all(fn.u_mul_e(\"el\", \"a\", \"m\"), fn.sum(\"m\", 
\"ft\"))\n            rst = graph.dstdata[\"ft\"]\n            # residual\n            if self.res_fc is not None:\n                if h_dst.numel() != 0:\n                    resval = self.res_fc(h_dst).view(\n                        h_dst.shape[0], -1, self._out_feats\n                    )\n                    rst = rst + resval\n            # activation\n            if self.activation:\n                rst = self.activation(rst)\n\n            if get_attention:\n                return rst, graph.edata[\"a\"]\n            else:\n                return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/gcn2conv.py",
    "content": "\"\"\"Torch Module for Graph Convolutional Network via Initial residual\n    and Identity mapping (GCNII) layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport torch as th\nfrom torch import nn\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom .graphconv import EdgeWeightNorm\n\n\nclass GCN2Conv(nn.Module):\n    r\"\"\"Graph Convolutional Network via Initial residual\n    and Identity mapping (GCNII) from `Simple and Deep Graph Convolutional\n    Networks <https://arxiv.org/abs/2007.02133>`__\n\n    It is mathematically is defined as follows:\n\n    .. math::\n\n        \\mathbf{h}^{(l+1)} =\\left( (1 - \\alpha)(\\mathbf{D}^{-1/2} \\mathbf{\\hat{A}}\n        \\mathbf{D}^{-1/2})\\mathbf{h}^{(l)} + \\alpha {\\mathbf{h}^{(0)}} \\right)\n        \\left( (1 - \\beta_l) \\mathbf{I} + \\beta_l \\mathbf{W} \\right)\n\n    where :math:`\\mathbf{\\hat{A}}` is the adjacency matrix with self-loops,\n    :math:`\\mathbf{D}_{ii} = \\sum_{j=0} \\mathbf{A}_{ij}` is its diagonal degree matrix,\n    :math:`\\mathbf{h}^{(0)}` is the initial node features,\n    :math:`\\mathbf{h}^{(l)}` is the feature of layer :math:`l`,\n    :math:`\\alpha` is the fraction of initial node features, and\n    :math:`\\beta_l` is the hyperparameter to tune the strength of identity mapping.\n    It is defined by :math:`\\beta_l = \\log(\\frac{\\lambda}{l}+1)\\approx\\frac{\\lambda}{l}`,\n    where :math:`\\lambda` is a hyperparameter. :math:`\\beta` ensures that the decay of\n    the weight matrix adaptively increases as we stack more layers.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n    layer : int\n        the index of current layer.\n    alpha : float\n        The fraction of the initial input features. 
Default: ``0.1``\n    lambda_ : float\n        The hyperparameter to ensure the decay of the weight matrix\n        adaptively increases. Default: ``1``\n    project_initial_features : bool\n        Whether to share a weight matrix between initial features and\n        smoothed features. Default: ``True``\n    bias : bool, optional\n        If True, adds a learnable bias to the output. Default: ``True``.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Default: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. 
Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import GCN2Conv\n\n    >>> # Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = th.ones(6, 3)\n    >>> g = dgl.add_self_loop(g)\n    >>> conv1 = GCN2Conv(3, layer=1, alpha=0.5, \\\n    ...         project_initial_features=True, allow_zero_in_degree=True)\n    >>> conv2 = GCN2Conv(3, layer=2, alpha=0.5, \\\n    ...         project_initial_features=True, allow_zero_in_degree=True)\n    >>> res = feat\n    >>> res = conv1(g, res, feat)\n    >>> res = conv2(g, res, feat)\n    >>> print(res)\n    tensor([[1.3803, 3.3191, 2.9572],\n            [1.3803, 3.3191, 2.9572],\n            [1.3803, 3.3191, 2.9572],\n            [1.4770, 3.8326, 3.2451],\n            [1.3623, 3.2102, 2.8679],\n            [1.3803, 3.3191, 2.9572]], grad_fn=<AddBackward0>)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        layer,\n        alpha=0.1,\n        lambda_=1,\n        project_initial_features=True,\n        allow_zero_in_degree=False,\n        bias=True,\n        activation=None,\n    ):\n        super().__init__()\n\n        self._in_feats = in_feats\n        self._project_initial_features = project_initial_features\n\n        self.alpha = alpha\n        self.beta = math.log(lambda_ / layer + 1)\n\n        self._bias = bias\n        self._activation = activation\n        self._allow_zero_in_degree = allow_zero_in_degree\n\n        self.weight1 = nn.Parameter(th.Tensor(self._in_feats, self._in_feats))\n\n        if self._project_initial_features:\n            self.register_parameter(\"weight2\", None)\n        else:\n            self.weight2 = nn.Parameter(\n           
     th.Tensor(self._in_feats, self._in_feats)\n            )\n\n        if self._bias:\n            self.bias = nn.Parameter(th.Tensor(self._in_feats))\n        else:\n            self.register_parameter(\"bias\", None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n\n        \"\"\"\n        nn.init.normal_(self.weight1)\n        if not self._project_initial_features:\n            nn.init.normal_(self.weight2)\n        if self._bias:\n            nn.init.zeros_(self.bias)\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat, feat_0, edge_weight=None):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute graph convolution.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            The input feature of shape\n            :math:`(N, D_{in})`\n            where :math:`D_{in}` is the size of input feature and :math:`N` is the number of nodes.\n        feat_0 : torch.Tensor\n            The initial feature of shape :math:`(N, D_{in})`\n        edge_weight: torch.Tensor, optional\n            edge_weight to use in the message passing process. 
This is equivalent to\n            using weighted adjacency matrix in the equation above, and\n            :math:`\\tilde{D}^{-1/2}\\tilde{A} \\tilde{D}^{-1/2}`\n            is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.\n\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n\n        Note\n        ----\n        * Input shape: :math:`(N, *, \\text{in_feats})` where * means any number of additional\n          dimensions, :math:`N` is the number of nodes.\n        * Output shape: :math:`(N, *, \\text{out_feats})` where all but the last dimension are\n          the same shape as the input.\n        * Weight shape: :math:`(\\text{in_feats}, \\text{out_feats})`.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. 
Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            # normalize  to get smoothed representation\n            if edge_weight is None:\n                degs = graph.in_degrees().to(feat).clamp(min=1)\n                norm = th.pow(degs, -0.5)\n                norm = norm.to(feat.device).unsqueeze(1)\n            else:\n                edge_weight = EdgeWeightNorm(\"both\")(graph, edge_weight)\n\n            if edge_weight is None:\n                feat = feat * norm\n            graph.ndata[\"h\"] = feat\n            msg_func = fn.copy_u(\"h\", \"m\")\n            if edge_weight is not None:\n                graph.edata[\"_edge_weight\"] = edge_weight\n                msg_func = fn.u_mul_e(\"h\", \"_edge_weight\", \"m\")\n            graph.update_all(msg_func, fn.sum(\"m\", \"h\"))\n            feat = graph.ndata.pop(\"h\")\n            if edge_weight is None:\n                feat = feat * norm\n            # scale\n            feat = feat * (1 - self.alpha)\n\n            # initial residual connection to the first layer\n            feat_0 = feat_0[: feat.size(0)] * self.alpha\n            feat_sum = feat + feat_0\n\n            if self._project_initial_features:\n                feat_proj_sum = feat_sum @ self.weight1\n            else:\n                feat_proj_sum = feat @ self.weight1 + feat_0 @ self.weight2\n\n            rst = (1 - self.beta) * feat_sum + self.beta * feat_proj_sum\n\n            if self._bias:\n                rst = rst + self.bias\n\n            if self._activation is not None:\n                rst = self._activation(rst)\n\n            return rst\n\n    def extra_repr(self):\n        \"\"\"Set the extra representation of the module,\n        which will come into effect when printing the model.\n        \"\"\"\n        summary = \"in={_in_feats}\"\n        summary += \", 
alpha={alpha}, beta={beta}\"\n        if \"self._bias\" in self.__dict__:\n            summary += \", bias={bias}\"\n        if \"self._activation\" in self.__dict__:\n            summary += \", activation={_activation}\"\n\n        return summary.format(**self.__dict__)\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/ginconv.py",
    "content": "\"\"\"Torch Module for Graph Isomorphism Network layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\n\nfrom .... import function as fn\nfrom ....utils import expand_as_pair\n\n\nclass GINConv(nn.Module):\n    r\"\"\"Graph Isomorphism Network layer from `How Powerful are Graph\n    Neural Networks? <https://arxiv.org/pdf/1810.00826.pdf>`__\n\n    .. math::\n        h_i^{(l+1)} = f_\\Theta \\left((1 + \\epsilon) h_i^{l} +\n        \\mathrm{aggregate}\\left(\\left\\{h_j^{l}, j\\in\\mathcal{N}(i)\n        \\right\\}\\right)\\right)\n\n    If a weight tensor on each edge is provided, the weighted graph convolution is defined as:\n\n    .. math::\n        h_i^{(l+1)} = f_\\Theta \\left((1 + \\epsilon) h_i^{l} +\n        \\mathrm{aggregate}\\left(\\left\\{e_{ji} h_j^{l}, j\\in\\mathcal{N}(i)\n        \\right\\}\\right)\\right)\n\n    where :math:`e_{ji}` is the weight on the edge from node :math:`j` to node :math:`i`.\n    Please make sure that `e_{ji}` is broadcastable with `h_j^{l}`.\n\n    Parameters\n    ----------\n    apply_func : callable activation function/layer or None\n        If not None, apply this function to the updated node feature,\n        the :math:`f_\\Theta` in the formula, default: None.\n    aggregator_type : str\n        Aggregator type to use (``sum``, ``max`` or ``mean``), default: 'sum'.\n    init_eps : float, optional\n        Initial :math:`\\epsilon` value, default: ``0``.\n    learn_eps : bool, optional\n        If True, :math:`\\epsilon` will be a learnable parameter. 
Default: ``False``.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import GINConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = th.ones(6, 10)\n    >>> lin = th.nn.Linear(10, 10)\n    >>> conv = GINConv(lin, 'max')\n    >>> res = conv(g, feat)\n    >>> res\n    tensor([[-0.4821,  0.0207, -0.7665,  0.5721, -0.4682, -0.2134, -0.5236,  1.2855,\n            0.8843, -0.8764],\n            [-0.4821,  0.0207, -0.7665,  0.5721, -0.4682, -0.2134, -0.5236,  1.2855,\n            0.8843, -0.8764],\n            [-0.4821,  0.0207, -0.7665,  0.5721, -0.4682, -0.2134, -0.5236,  1.2855,\n            0.8843, -0.8764],\n            [-0.4821,  0.0207, -0.7665,  0.5721, -0.4682, -0.2134, -0.5236,  1.2855,\n            0.8843, -0.8764],\n            [-0.4821,  0.0207, -0.7665,  0.5721, -0.4682, -0.2134, -0.5236,  1.2855,\n            0.8843, -0.8764],\n            [-0.1804,  0.0758, -0.5159,  0.3569, -0.1408, -0.1395, -0.2387,  0.7773,\n            0.5266, -0.4465]], grad_fn=<AddmmBackward>)\n\n    >>> # With activation\n    >>> from torch.nn.functional import relu\n    >>> conv = GINConv(lin, 'max', activation=relu)\n    >>> res = conv(g, feat)\n    >>> res\n    tensor([[5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,\n             0.0000],\n            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,\n             0.0000],\n            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,\n             0.0000],\n            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,\n             0.0000],\n            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,\n             0.0000],\n   
         [2.5011, 0.0000, 0.0089, 2.0541, 0.8262, 0.0000, 0.0000, 0.1371, 0.0000,\n             0.0000]], grad_fn=<ReluBackward0>)\n    \"\"\"\n\n    def __init__(\n        self,\n        apply_func=None,\n        aggregator_type=\"sum\",\n        init_eps=0,\n        learn_eps=False,\n        activation=None,\n    ):\n        super(GINConv, self).__init__()\n        self.apply_func = apply_func\n        self._aggregator_type = aggregator_type\n        self.activation = activation\n        if aggregator_type not in (\"sum\", \"max\", \"mean\"):\n            raise KeyError(\n                \"Aggregator type {} not recognized.\".format(aggregator_type)\n            )\n        # to specify whether eps is trainable or not.\n        if learn_eps:\n            self.eps = th.nn.Parameter(th.FloatTensor([init_eps]))\n        else:\n            self.register_buffer(\"eps\", th.FloatTensor([init_eps]))\n\n    def forward(self, graph, feat, edge_weight=None):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute Graph Isomorphism Network layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor or pair of torch.Tensor\n            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of torch.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.\n            If ``apply_func`` is not None, :math:`D_{in}` should\n            fit the input dimensionality requirement of ``apply_func``.\n        edge_weight : torch.Tensor, optional\n            Optional tensor on the edge. 
If given, the convolution will weight\n            with regard to the message.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{out})` where\n            :math:`D_{out}` is the output dimensionality of ``apply_func``.\n            If ``apply_func`` is None, :math:`D_{out}` should be the same\n            as input dimensionality.\n        \"\"\"\n        _reducer = getattr(fn, self._aggregator_type)\n        with graph.local_scope():\n            aggregate_fn = fn.copy_u(\"h\", \"m\")\n            if edge_weight is not None:\n                assert edge_weight.shape[0] == graph.num_edges()\n                graph.edata[\"_edge_weight\"] = edge_weight\n                aggregate_fn = fn.u_mul_e(\"h\", \"_edge_weight\", \"m\")\n\n            feat_src, feat_dst = expand_as_pair(feat, graph)\n            graph.srcdata[\"h\"] = feat_src\n            graph.update_all(aggregate_fn, _reducer(\"m\", \"neigh\"))\n            rst = (1 + self.eps) * feat_dst + graph.dstdata[\"neigh\"]\n            if self.apply_func is not None:\n                rst = self.apply_func(rst)\n            # activation\n            if self.activation is not None:\n                rst = self.activation(rst)\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/gineconv.py",
    "content": "\"\"\"Torch Module for Graph Isomorphism Network layer variant with edge features\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom .... import function as fn\nfrom ....utils import expand_as_pair\n\n\nclass GINEConv(nn.Module):\n    r\"\"\"Graph Isomorphism Network with Edge Features, introduced by\n    `Strategies for Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__\n\n    .. math::\n        h_i^{(l+1)} = f_\\Theta \\left((1 + \\epsilon) h_i^{l} +\n        \\sum_{j\\in\\mathcal{N}(i)}\\mathrm{ReLU}(h_j^{l} + e_{j,i}^{l})\\right)\n\n    where :math:`e_{j,i}^{l}` is the edge feature.\n\n    Parameters\n    ----------\n    apply_func : callable module or None\n        The :math:`f_\\Theta` in the formula. If not None, it will be applied to\n        the updated node features. The default value is None.\n    init_eps : float, optional\n        Initial :math:`\\epsilon` value, default: ``0``.\n    learn_eps : bool, optional\n        If True, :math:`\\epsilon` will be a learnable parameter. 
Default: ``False``.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch\n    >>> import torch.nn as nn\n    >>> from dgl.nn import GINEConv\n\n    >>> g = dgl.graph(([0, 1, 2], [1, 1, 3]))\n    >>> in_feats = 10\n    >>> out_feats = 20\n    >>> nfeat = torch.randn(g.num_nodes(), in_feats)\n    >>> efeat = torch.randn(g.num_edges(), in_feats)\n    >>> conv = GINEConv(nn.Linear(in_feats, out_feats))\n    >>> res = conv(g, nfeat, efeat)\n    >>> print(res.shape)\n    torch.Size([4, 20])\n    \"\"\"\n\n    def __init__(self, apply_func=None, init_eps=0, learn_eps=False):\n        super(GINEConv, self).__init__()\n        self.apply_func = apply_func\n        # to specify whether eps is trainable or not.\n        if learn_eps:\n            self.eps = nn.Parameter(th.FloatTensor([init_eps]))\n        else:\n            self.register_buffer(\"eps\", th.FloatTensor([init_eps]))\n\n    def message(self, edges):\n        r\"\"\"User-defined Message Function\"\"\"\n        return {\"m\": F.relu(edges.src[\"hn\"] + edges.data[\"he\"])}\n\n    def forward(self, graph, node_feat, edge_feat):\n        r\"\"\"Forward computation.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        node_feat : torch.Tensor or pair of torch.Tensor\n            If a torch.Tensor is given, it is the input feature of shape :math:`(N, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of torch.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.\n            If ``apply_func`` is not None, :math:`D_{in}` should\n            fit the input feature size requirement of ``apply_func``.\n        edge_feat : torch.Tensor\n            Edge feature. 
It is a tensor of shape :math:`(E, D_{in})` where :math:`E`\n            is the number of edges.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{out})` where\n            :math:`D_{out}` is the output feature size of ``apply_func``.\n            If ``apply_func`` is None, :math:`D_{out}` should be the same\n            as :math:`D_{in}`.\n        \"\"\"\n        with graph.local_scope():\n            feat_src, feat_dst = expand_as_pair(node_feat, graph)\n            graph.srcdata[\"hn\"] = feat_src\n            graph.edata[\"he\"] = edge_feat\n            graph.update_all(self.message, fn.sum(\"m\", \"neigh\"))\n            rst = (1 + self.eps) * feat_dst + graph.dstdata[\"neigh\"]\n            if self.apply_func is not None:\n                rst = self.apply_func(rst)\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/gmmconv.py",
    "content": "\"\"\"Torch Module for GMM Conv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\nfrom torch.nn import init\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\nfrom ..utils import Identity\n\n\nclass GMMConv(nn.Module):\n    r\"\"\"Gaussian Mixture Model Convolution layer from `Geometric Deep\n    Learning on Graphs and Manifolds using Mixture Model CNNs\n    <https://arxiv.org/abs/1611.08402>`__\n\n    .. math::\n        u_{ij} &= f(x_i, x_j), x_j \\in \\mathcal{N}(i)\n\n        w_k(u) &= \\exp\\left(-\\frac{1}{2}(u-\\mu_k)^T \\Sigma_k^{-1} (u - \\mu_k)\\right)\n\n        h_i^{l+1} &= \\mathrm{aggregate}\\left(\\left\\{\\frac{1}{K}\n         \\sum_{k}^{K} w_k(u_{ij}), \\forall j\\in \\mathcal{N}(i)\\right\\}\\right)\n\n    where :math:`u` denotes the pseudo-coordinates between a vertex and one of its neighbor,\n    computed using function :math:`f`, :math:`\\Sigma_k^{-1}` and :math:`\\mu_k` are\n    learnable parameters representing the covariance matrix and mean vector of a Gaussian kernel.\n\n    Parameters\n    ----------\n    in_feats : int\n        Number of input features; i.e., the number of dimensions of :math:`x_i`.\n    out_feats : int\n        Number of output features; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    dim : int\n        Dimensionality of pseudo-coordinte; i.e, the number of dimensions of :math:`u_{ij}`.\n    n_kernels : int\n        Number of kernels :math:`K`.\n    aggregator_type : str\n        Aggregator type (``sum``, ``mean``, ``max``). Default: ``sum``.\n    residual : bool\n        If True, use residual connection inside this layer. Default: ``False``.\n    bias : bool\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Default: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. 
Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import GMMConv\n\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = th.ones(6, 10)\n    >>> conv = GMMConv(10, 2, 3, 2, 'mean')\n    >>> pseudo = th.ones(12, 3)\n    >>> res = conv(g, feat, pseudo)\n    >>> res\n    tensor([[-0.3462, -0.2654],\n            [-0.3462, -0.2654],\n            [-0.3462, -0.2654],\n            [-0.3462, -0.2654],\n            [-0.3462, -0.2654],\n            [-0.3462, -0.2654]], grad_fn=<AddBackward0>)\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})\n    >>> u_fea = th.rand(2, 5)\n    >>> v_fea = th.rand(4, 10)\n    >>> pseudo = th.ones(5, 3)\n    >>> conv = GMMConv((10, 5), 2, 3, 2, 'mean')\n    >>> res = conv(g, (u_fea, v_fea), pseudo)\n    >>> res\n    tensor([[-0.1107, -0.1559],\n            [-0.1646, -0.2326],\n            [-0.1377, -0.1943],\n            [-0.1107, -0.1559]], grad_fn=<AddBackward0>)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        dim,\n        n_kernels,\n        aggregator_type=\"sum\",\n        residual=False,\n        bias=True,\n        allow_zero_in_degree=False,\n    ):\n        super(GMMConv, self).__init__()\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        self._dim = dim\n        self._n_kernels = n_kernels\n        self._allow_zero_in_degree = allow_zero_in_degree\n        if aggregator_type == \"sum\":\n            
self._reducer = fn.sum\n        elif aggregator_type == \"mean\":\n            self._reducer = fn.mean\n        elif aggregator_type == \"max\":\n            self._reducer = fn.max\n        else:\n            raise KeyError(\n                \"Aggregator type {} not recognized.\".format(aggregator_type)\n            )\n\n        self.mu = nn.Parameter(th.Tensor(n_kernels, dim))\n        self.inv_sigma = nn.Parameter(th.Tensor(n_kernels, dim))\n        self.fc = nn.Linear(\n            self._in_src_feats, n_kernels * out_feats, bias=False\n        )\n        if residual:\n            if self._in_dst_feats != out_feats:\n                self.res_fc = nn.Linear(\n                    self._in_dst_feats, out_feats, bias=False\n                )\n            else:\n                self.res_fc = Identity()\n        else:\n            self.register_buffer(\"res_fc\", None)\n\n        if bias:\n            self.bias = nn.Parameter(th.Tensor(out_feats))\n        else:\n            self.register_buffer(\"bias\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n\n        Note\n        ----\n        The fc parameters are initialized using Glorot uniform initialization\n        and the bias is initialized to be zero.\n        The mu weight is initialized using normal distribution and\n        inv_sigma is initialized with constant value 1.0.\n        \"\"\"\n        gain = init.calculate_gain(\"relu\")\n        init.xavier_normal_(self.fc.weight, gain=gain)\n        if isinstance(self.res_fc, nn.Linear):\n            init.xavier_normal_(self.res_fc.weight, gain=gain)\n        init.normal_(self.mu.data, 0, 0.1)\n        init.constant_(self.inv_sigma.data, 1)\n        if self.bias is not None:\n            init.zeros_(self.bias.data)\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n     
   Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat, pseudo):\n        \"\"\"\n\n        Description\n        -----------\n        Compute Gaussian Mixture Model Convolution layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            If a single tensor is given, the input feature of shape :math:`(N, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of tensors are given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.\n        pseudo : torch.Tensor\n            The pseudo coordinate tensor of shape :math:`(E, D_{u})` where\n            :math:`E` is the number of edges of the graph and :math:`D_{u}`\n            is the dimensionality of pseudo coordinate.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is the output feature size.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. 
\"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            feat_src, feat_dst = expand_as_pair(feat, graph)\n            graph.srcdata[\"h\"] = self.fc(feat_src).view(\n                -1, self._n_kernels, self._out_feats\n            )\n            E = graph.num_edges()\n            # compute gaussian weight\n            gaussian = -0.5 * (\n                (\n                    pseudo.view(E, 1, self._dim)\n                    - self.mu.view(1, self._n_kernels, self._dim)\n                )\n                ** 2\n            )\n            gaussian = gaussian * (\n                self.inv_sigma.view(1, self._n_kernels, self._dim) ** 2\n            )\n            gaussian = th.exp(gaussian.sum(dim=-1, keepdim=True))  # (E, K, 1)\n            graph.edata[\"w\"] = gaussian\n            graph.update_all(fn.u_mul_e(\"h\", \"w\", \"m\"), self._reducer(\"m\", \"h\"))\n            rst = graph.dstdata[\"h\"].sum(1)\n            # residual connection\n            if self.res_fc is not None:\n                rst = rst + self.res_fc(feat_dst)\n            # bias\n            if self.bias is not None:\n                rst = rst + self.bias\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/graphconv.py",
    "content": "\"\"\"Torch modules for graph convolutions(GCN).\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\nfrom torch.nn import init\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....convert import block_to_graph\nfrom ....heterograph import DGLBlock\nfrom ....transforms import reverse\nfrom ....utils import expand_as_pair\n\n\nclass EdgeWeightNorm(nn.Module):\n    r\"\"\"This module normalizes positive scalar edge weights on a graph\n    following the form in `GCN <https://arxiv.org/abs/1609.02907>`__.\n\n    Mathematically, setting ``norm='both'`` yields the following normalization term:\n\n    .. math::\n      c_{ji} = (\\sqrt{\\sum_{k\\in\\mathcal{N}(j)}e_{jk}}\\sqrt{\\sum_{k\\in\\mathcal{N}(i)}e_{ki}})\n\n    And, setting ``norm='right'`` yields the following normalization term:\n\n    .. math::\n      c_{ji} = (\\sum_{k\\in\\mathcal{N}(i)}e_{ki})\n\n    where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.\n\n    The module returns the normalized weight :math:`e_{ji} / c_{ji}`.\n\n    Parameters\n    ----------\n    norm : str, optional\n        The normalizer as specified above. Default is `'both'`.\n    eps : float, optional\n        A small offset value in the denominator. 
Default is 0.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import EdgeWeightNorm, GraphConv\n\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = th.ones(6, 10)\n    >>> edge_weight = th.tensor([0.5, 0.6, 0.4, 0.7, 0.9, 0.1, 1, 1, 1, 1, 1, 1])\n    >>> norm = EdgeWeightNorm(norm='both')\n    >>> norm_edge_weight = norm(g, edge_weight)\n    >>> conv = GraphConv(10, 2, norm='none', weight=True, bias=True)\n    >>> res = conv(g, feat, edge_weight=norm_edge_weight)\n    >>> print(res)\n    tensor([[-1.1849, -0.7525],\n            [-1.3514, -0.8582],\n            [-1.2384, -0.7865],\n            [-1.9949, -1.2669],\n            [-1.3658, -0.8674],\n            [-0.8323, -0.5286]], grad_fn=<AddBackward0>)\n    \"\"\"\n\n    def __init__(self, norm=\"both\", eps=0.0):\n        super(EdgeWeightNorm, self).__init__()\n        self._norm = norm\n        self._eps = eps\n\n    def forward(self, graph, edge_weight):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute normalized edge weight for the GCN model.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        edge_weight : torch.Tensor\n            Unnormalized scalar weights on the edges.\n            The shape is expected to be :math:`(|E|)`.\n\n        Returns\n        -------\n        torch.Tensor\n            The normalized edge weight.\n\n        Raises\n        ------\n        DGLError\n            Case 1:\n            The edge weight is multi-dimensional. 
Currently this module\n            only supports a scalar weight on each edge.\n\n            Case 2:\n            The edge weight has non-positive values with ``norm='both'``.\n            This will trigger square root and division by a non-positive number.\n        \"\"\"\n        with graph.local_scope():\n            if isinstance(graph, DGLBlock):\n                graph = block_to_graph(graph)\n            if len(edge_weight.shape) > 1:\n                raise DGLError(\n                    \"Currently the normalization is only defined \"\n                    \"on scalar edge weight. Please customize the \"\n                    \"normalization for your high-dimensional weights.\"\n                )\n            if self._norm == \"both\" and th.any(edge_weight <= 0).item():\n                raise DGLError(\n                    'Non-positive edge weight detected with `norm=\"both\"`. '\n                    \"This leads to square root of zero or negative values.\"\n                )\n\n            dev = graph.device\n            dtype = edge_weight.dtype\n            graph.srcdata[\"_src_out_w\"] = th.ones(\n                graph.number_of_src_nodes(), dtype=dtype, device=dev\n            )\n            graph.dstdata[\"_dst_in_w\"] = th.ones(\n                graph.number_of_dst_nodes(), dtype=dtype, device=dev\n            )\n            graph.edata[\"_edge_w\"] = edge_weight\n\n            if self._norm == \"both\":\n                reversed_g = reverse(graph)\n                reversed_g.edata[\"_edge_w\"] = edge_weight\n                reversed_g.update_all(\n                    fn.copy_e(\"_edge_w\", \"m\"), fn.sum(\"m\", \"out_weight\")\n                )\n                degs = reversed_g.dstdata[\"out_weight\"] + self._eps\n                norm = th.pow(degs, -0.5)\n                graph.srcdata[\"_src_out_w\"] = norm\n\n            if self._norm != \"none\":\n                graph.update_all(\n                    fn.copy_e(\"_edge_w\", \"m\"), 
fn.sum(\"m\", \"in_weight\")\n                )\n                degs = graph.dstdata[\"in_weight\"] + self._eps\n                if self._norm == \"both\":\n                    norm = th.pow(degs, -0.5)\n                else:\n                    norm = 1.0 / degs\n                graph.dstdata[\"_dst_in_w\"] = norm\n\n            graph.apply_edges(\n                lambda e: {\n                    \"_norm_edge_weights\": e.src[\"_src_out_w\"]\n                    * e.dst[\"_dst_in_w\"]\n                    * e.data[\"_edge_w\"]\n                }\n            )\n            return graph.edata[\"_norm_edge_weights\"]\n\n\n# pylint: disable=W0235\nclass GraphConv(nn.Module):\n    r\"\"\"Graph convolutional layer from `Semi-Supervised Classification with Graph Convolutional\n    Networks <https://arxiv.org/abs/1609.02907>`__\n\n    Mathematically it is defined as follows:\n\n    .. math::\n      h_i^{(l+1)} = \\sigma(b^{(l)} + \\sum_{j\\in\\mathcal{N}(i)}\\frac{1}{c_{ji}}h_j^{(l)}W^{(l)})\n\n    where :math:`\\mathcal{N}(i)` is the set of neighbors of node :math:`i`,\n    :math:`c_{ji}` is the product of the square root of node degrees\n    (i.e.,  :math:`c_{ji} = \\sqrt{|\\mathcal{N}(j)|}\\sqrt{|\\mathcal{N}(i)|}`),\n    and :math:`\\sigma` is an activation function.\n\n    If a weight tensor on each edge is provided, the weighted graph convolution is defined as:\n\n    .. math::\n      h_i^{(l+1)} = \\sigma(b^{(l)} + \\sum_{j\\in\\mathcal{N}(i)}\\frac{e_{ji}}{c_{ji}}h_j^{(l)}W^{(l)})\n\n    where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.\n    This is NOT equivalent to the weighted graph convolutional network formulation in the paper.\n\n    To customize the normalization term :math:`c_{ji}`, one can first set ``norm='none'`` for\n    the model, and send the pre-normalized :math:`e_{ji}` to the forward computation. 
We provide\n    :class:`~dgl.nn.pytorch.EdgeWeightNorm` to normalize scalar edge weight following the GCN paper.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e., the number of dimensions of :math:`h_j^{(l)}`.\n    out_feats : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    norm : str, optional\n        How to apply the normalizer.  Can be one of the following values:\n\n        * ``right``, to divide the aggregated messages by each node's in-degrees,\n          which is equivalent to averaging the received messages.\n\n        * ``none``, where no normalization is applied.\n\n        * ``both`` (default), where the messages are scaled with :math:`1/c_{ji}` above, equivalent\n          to symmetric normalization.\n\n        * ``left``, to divide the messages sent out from each node by its out-degrees,\n          equivalent to random walk normalization.\n    weight : bool, optional\n        If True, apply a linear layer. Otherwise, aggregate the messages\n        without a weight matrix.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. Default: ``True``.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. 
Default: ``False``.\n\n    Attributes\n    ----------\n    weight : torch.Tensor\n        The learnable weight tensor.\n    bias : torch.Tensor\n        The learnable bias tensor.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practice to handle this is to filter out the nodes with zero-in-degree when used\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import GraphConv\n\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = th.ones(6, 10)\n    >>> conv = GraphConv(10, 2, norm='both', weight=True, bias=True)\n    >>> res = conv(g, feat)\n    >>> print(res)\n    tensor([[ 1.3326, -0.2797],\n            [ 1.4673, -0.3080],\n            [ 1.3326, -0.2797],\n            [ 1.6871, -0.3541],\n            [ 1.7711, -0.3717],\n            [ 1.0375, -0.2178]], grad_fn=<AddBackward0>)\n    >>> # allow_zero_in_degree example\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> conv = GraphConv(10, 2, norm='both', weight=True, bias=True, allow_zero_in_degree=True)\n    >>> res = conv(g, feat)\n    >>> print(res)\n    tensor([[-0.2473, -0.4631],\n            [-0.3497, -0.6549],\n            [-0.3497, -0.6549],\n            [-0.4221, 
-0.7905],\n            [-0.3497, -0.6549],\n            [ 0.0000,  0.0000]], grad_fn=<AddBackward0>)\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('_U', '_E', '_V') : (u, v)})\n    >>> u_fea = th.rand(2, 5)\n    >>> v_fea = th.rand(4, 5)\n    >>> conv = GraphConv(5, 2, norm='both', weight=True, bias=True)\n    >>> res = conv(g, (u_fea, v_fea))\n    >>> res\n    tensor([[-0.2994,  0.6106],\n            [-0.4482,  0.5540],\n            [-0.5287,  0.8235],\n            [-0.2994,  0.6106]], grad_fn=<AddBackward0>)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        norm=\"both\",\n        weight=True,\n        bias=True,\n        activation=None,\n        allow_zero_in_degree=False,\n    ):\n        super(GraphConv, self).__init__()\n        if norm not in (\"none\", \"both\", \"right\", \"left\"):\n            raise DGLError(\n                'Invalid norm value. 
Must be either \"none\", \"both\", \"right\" or \"left\".'\n                ' But got \"{}\".'.format(norm)\n            )\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._norm = norm\n        self._allow_zero_in_degree = allow_zero_in_degree\n\n        if weight:\n            self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))\n        else:\n            self.register_parameter(\"weight\", None)\n\n        if bias:\n            self.bias = nn.Parameter(th.Tensor(out_feats))\n        else:\n            self.register_parameter(\"bias\", None)\n\n        self.reset_parameters()\n\n        self._activation = activation\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n\n        Note\n        ----\n        The model parameters are initialized as in the\n        `original implementation <https://github.com/tkipf/gcn/blob/master/gcn/layers.py>`__\n        where the weight :math:`W^{(l)}` is initialized using Glorot uniform initialization\n        and the bias is initialized to be zero.\n\n        \"\"\"\n        if self.weight is not None:\n            init.xavier_uniform_(self.weight)\n        if self.bias is not None:\n            init.zeros_(self.bias)\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat, weight=None, edge_weight=None):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute graph convolution.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor or pair of torch.Tensor\n            If a torch.Tensor is given, it 
represents the input feature of shape\n            :math:`(N, D_{in})`\n            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of torch.Tensor is given, which is the case for bipartite graph, the pair\n            must contain two tensors of shape :math:`(N_{in}, D_{in_{src}})` and\n            :math:`(N_{out}, D_{in_{dst}})`.\n        weight : torch.Tensor, optional\n            Optional external weight tensor.\n        edge_weight : torch.Tensor, optional\n            Optional tensor on the edge. If given, the convolution will weight\n            with regard to the message.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature\n\n        Raises\n        ------\n        DGLError\n            Case 1:\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n\n            Case 2:\n            External weight is provided while at the same time the module\n            has defined its own weight parameter.\n\n        Note\n        ----\n        * Input shape: :math:`(N, *, \\text{in_feats})` where * means any number of additional\n          dimensions, :math:`N` is the number of nodes.\n        * Output shape: :math:`(N, *, \\text{out_feats})` where all but the last dimension are\n          the same shape as the input.\n        * Weight shape: :math:`(\\text{in_feats}, \\text{out_feats})`.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. 
\"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n            aggregate_fn = fn.copy_u(\"h\", \"m\")\n            if edge_weight is not None:\n                assert edge_weight.shape[0] == graph.num_edges()\n                graph.edata[\"_edge_weight\"] = edge_weight\n                aggregate_fn = fn.u_mul_e(\"h\", \"_edge_weight\", \"m\")\n\n            # (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.\n            feat_src, feat_dst = expand_as_pair(feat, graph)\n            if self._norm in [\"left\", \"both\"]:\n                degs = graph.out_degrees().to(feat_src).clamp(min=1)\n                if self._norm == \"both\":\n                    norm = th.pow(degs, -0.5)\n                else:\n                    norm = 1.0 / degs\n                shp = norm.shape + (1,) * (feat_src.dim() - 1)\n                norm = th.reshape(norm, shp)\n                feat_src = feat_src * norm\n\n            if weight is not None:\n                if self.weight is not None:\n                    raise DGLError(\n                        \"External weight is provided while at the same time the\"\n                        \" module has defined its own weight parameter. 
Please\"\n                        \" create the module with flag weight=False.\"\n                    )\n            else:\n                weight = self.weight\n\n            if self._in_feats > self._out_feats:\n                # mult W first to reduce the feature size for aggregation.\n                if weight is not None:\n                    feat_src = th.matmul(feat_src, weight)\n                graph.srcdata[\"h\"] = feat_src\n                graph.update_all(aggregate_fn, fn.sum(msg=\"m\", out=\"h\"))\n                rst = graph.dstdata[\"h\"]\n            else:\n                # aggregate first then mult W\n                graph.srcdata[\"h\"] = feat_src\n                graph.update_all(aggregate_fn, fn.sum(msg=\"m\", out=\"h\"))\n                rst = graph.dstdata[\"h\"]\n                if weight is not None:\n                    rst = th.matmul(rst, weight)\n\n            if self._norm in [\"right\", \"both\"]:\n                degs = graph.in_degrees().to(feat_dst).clamp(min=1)\n                if self._norm == \"both\":\n                    norm = th.pow(degs, -0.5)\n                else:\n                    norm = 1.0 / degs\n                shp = norm.shape + (1,) * (feat_dst.dim() - 1)\n                norm = th.reshape(norm, shp)\n                rst = rst * norm\n\n            if self.bias is not None:\n                rst = rst + self.bias\n\n            if self._activation is not None:\n                rst = self._activation(rst)\n\n            return rst\n\n    def extra_repr(self):\n        \"\"\"Set the extra representation of the module,\n        which will come into effect when printing the model.\n        \"\"\"\n        summary = \"in={_in_feats}, out={_out_feats}\"\n        summary += \", normalization={_norm}\"\n        if \"_activation\" in self.__dict__:\n            summary += \", activation={_activation}\"\n        return summary.format(**self.__dict__)\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/grouprevres.py",
    "content": "\"\"\"Torch module for grouped reversible residual connections for GNNs\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, C0116, R1728\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\n\nclass InvertibleCheckpoint(torch.autograd.Function):\n    r\"\"\"Extension of torch.autograd\"\"\"\n\n    @staticmethod\n    def forward(ctx, fn, fn_inverse, num_inputs, *inputs_and_weights):\n        ctx.fn = fn\n        ctx.fn_inverse = fn_inverse\n        ctx.weights = inputs_and_weights[num_inputs:]\n        inputs = inputs_and_weights[:num_inputs]\n        ctx.input_requires_grad = []\n\n        with torch.no_grad():\n            # Make a detached copy, which shares the storage\n            x = []\n            for element in inputs:\n                if isinstance(element, torch.Tensor):\n                    x.append(element.detach())\n                    ctx.input_requires_grad.append(element.requires_grad)\n                else:\n                    x.append(element)\n                    ctx.input_requires_grad.append(None)\n            # Detach the output, which then allows discarding the intermediary results\n            outputs = ctx.fn(*x).detach_()\n\n        # clear memory of input node features\n        inputs[1].untyped_storage().resize_(0)\n\n        # store for backward pass\n        ctx.inputs = [inputs]\n        ctx.outputs = [outputs]\n\n        return outputs\n\n    @staticmethod\n    def backward(ctx, *grad_outputs):\n        if not torch.autograd._is_checkpoint_valid():\n            raise RuntimeError(\n                \"InvertibleCheckpoint is not compatible with .grad(), \\\n                               please use .backward() if possible\"\n            )\n        # retrieve input and output tensor nodes\n        if len(ctx.outputs) == 0:\n            raise RuntimeError(\n                \"Trying to perform backward on the InvertibleCheckpoint \\\n                               for 
more than once.\"\n            )\n        inputs = ctx.inputs.pop()\n        outputs = ctx.outputs.pop()\n\n        # reconstruct input node features\n        with torch.no_grad():\n            # inputs[0] is DGLGraph and inputs[1] is input node features\n            inputs_inverted = ctx.fn_inverse(\n                *((inputs[0], outputs) + inputs[2:])\n            )\n            # clear memory of outputs\n            outputs.untyped_storage().resize_(0)\n\n            x = inputs[1]\n            x.untyped_storage().resize_(int(np.prod(x.size())))\n            x.set_(inputs_inverted)\n\n        # compute gradients\n        with torch.set_grad_enabled(True):\n            detached_inputs = []\n            for i, element in enumerate(inputs):\n                if isinstance(element, torch.Tensor):\n                    element = element.detach()\n                    element.requires_grad = ctx.input_requires_grad[i]\n                detached_inputs.append(element)\n\n            detached_inputs = tuple(detached_inputs)\n            temp_output = ctx.fn(*detached_inputs)\n\n        filtered_detached_inputs = tuple(\n            filter(\n                lambda x: getattr(x, \"requires_grad\", False), detached_inputs\n            )\n        )\n        gradients = torch.autograd.grad(\n            outputs=(temp_output,),\n            inputs=filtered_detached_inputs + ctx.weights,\n            grad_outputs=grad_outputs,\n        )\n\n        input_gradients = []\n        i = 0\n        for rg in ctx.input_requires_grad:\n            if rg:\n                input_gradients.append(gradients[i])\n                i += 1\n            else:\n                input_gradients.append(None)\n\n        gradients = tuple(input_gradients) + gradients[-len(ctx.weights) :]\n\n        return (None, None, None) + gradients\n\n\nclass GroupRevRes(nn.Module):\n    r\"\"\"Grouped reversible residual connections for GNNs, as introduced in\n    `Training Graph Neural Networks with 1000 Layers 
<https://arxiv.org/abs/2106.07476>`__\n\n    It uniformly partitions an input node feature :math:`X` into :math:`C` groups\n    :math:`X_1, X_2, \\cdots, X_C` across the channel dimension. Besides, it makes\n    :math:`C` copies of the input GNN module :math:`f_{w1}, \\cdots, f_{wC}`. In the\n    forward pass, each GNN module only takes the corresponding group of node features.\n\n    The output node representations :math:`X^{'}` are computed as follows.\n\n    .. math::\n\n        X_0^{'} = \\sum_{i=2}^{C}X_i\n\n        X_i^{'} = f_{wi}(X_{i-1}^{'}, g, U) + X_i, i\\in\\{1,\\cdots,C\\}\n\n        X^{'} = X_1^{'} \\, \\Vert \\, \\ldots \\, \\Vert \\, X_C^{'}\n\n    where :math:`g` is the input graph, :math:`U` is arbitrary additional input arguments like\n    edge features, and :math:`\\, \\Vert \\,` is concatenation.\n\n    Parameters\n    ----------\n    gnn_module : nn.Module\n        GNN module for message passing. :attr:`GroupRevRes` will clone the module for\n        :attr:`groups`-1 number of times, yielding :attr:`groups` copies in total.\n        The input and output node representation size need to be the same. Its forward\n        function needs to take a DGLGraph and the associated input node features in order,\n        optionally followed by additional arguments like edge features.\n    groups : int, optional\n        The number of groups.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch\n    >>> import torch.nn as nn\n    >>> from dgl.nn import GraphConv, GroupRevRes\n\n    >>> class GNNLayer(nn.Module):\n    ...     def __init__(self, feats, dropout=0.2):\n    ...         super(GNNLayer, self).__init__()\n    ...         # Use BatchNorm and dropout to prevent gradient vanishing\n    ...         # In particular if you use a large number of GNN layers\n    ...         self.norm = nn.BatchNorm1d(feats)\n    ...         self.conv = GraphConv(feats, feats)\n    ...         self.dropout = nn.Dropout(dropout)\n    ...\n    ...     
def forward(self, g, x):\n    ...         x = self.norm(x)\n    ...         x = self.dropout(x)\n    ...         return self.conv(g, x)\n\n    >>> num_nodes = 5\n    >>> num_edges = 20\n    >>> feats = 32\n    >>> groups = 2\n    >>> g = dgl.rand_graph(num_nodes, num_edges)\n    >>> x = torch.randn(num_nodes, feats)\n    >>> conv = GNNLayer(feats // groups)\n    >>> model = GroupRevRes(conv, groups)\n    >>> out = model(g, x)\n    \"\"\"\n\n    def __init__(self, gnn_module, groups=2):\n        super(GroupRevRes, self).__init__()\n        self.gnn_modules = nn.ModuleList()\n        for i in range(groups):\n            if i == 0:\n                self.gnn_modules.append(gnn_module)\n            else:\n                self.gnn_modules.append(deepcopy(gnn_module))\n        self.groups = groups\n\n    def _forward(self, g, x, *args):\n        xs = torch.chunk(x, self.groups, dim=-1)\n\n        if len(args) == 0:\n            args_chunks = [()] * self.groups\n        else:\n            chunked_args = list(\n                map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)\n            )\n            args_chunks = list(zip(*chunked_args))\n        y_in = sum(xs[1:])\n\n        ys = []\n        for i in range(self.groups):\n            y_in = xs[i] + self.gnn_modules[i](g, y_in, *args_chunks[i])\n            ys.append(y_in)\n\n        out = torch.cat(ys, dim=-1)\n\n        return out\n\n    def _inverse(self, g, y, *args):\n        ys = torch.chunk(y, self.groups, dim=-1)\n\n        if len(args) == 0:\n            args_chunks = [()] * self.groups\n        else:\n            chunked_args = list(\n                map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)\n            )\n            args_chunks = list(zip(*chunked_args))\n\n        xs = []\n        for i in range(self.groups - 1, -1, -1):\n            if i != 0:\n                y_in = ys[i - 1]\n            else:\n                y_in = sum(xs)\n\n            x = ys[i] - self.gnn_modules[i](g, 
y_in, *args_chunks[i])\n            xs.append(x)\n\n        x = torch.cat(xs[::-1], dim=-1)\n\n        return x\n\n    def forward(self, g, x, *args):\n        r\"\"\"Apply the GNN module with grouped reversible residual connection.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        x : torch.Tensor\n            The input feature of shape :math:`(N, D_{in})`, where :math:`D_{in}` is size\n            of input feature, :math:`N` is the number of nodes.\n        args\n            Additional arguments to pass to :attr:`gnn_module`.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{in})`.\n        \"\"\"\n        args = (g, x) + args\n        y = InvertibleCheckpoint.apply(\n            self._forward,\n            self._inverse,\n            len(args),\n            *(args + tuple([p for p in self.parameters() if p.requires_grad]))\n        )\n\n        return y\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/hgtconv.py",
    "content": "\"\"\"Heterogeneous Graph Transformer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport math\n\nimport torch\nimport torch.nn as nn\n\nfrom .... import function as fn\nfrom ..linear import TypedLinear\nfrom ..softmax import edge_softmax\n\n\nclass HGTConv(nn.Module):\n    r\"\"\"Heterogeneous graph transformer convolution from `Heterogeneous Graph Transformer\n    <https://arxiv.org/abs/2003.01332>`__\n\n    Given a graph :math:`G(V, E)` and input node features :math:`H^{(l-1)}`,\n    it computes the new node features as follows:\n\n    Compute a multi-head attention score for each edge :math:`(s, e, t)` in the graph:\n\n    .. math::\n\n      Attention(s, e, t) = \\text{Softmax}\\left(||_{i\\in[1,h]}ATT-head^i(s, e, t)\\right) \\\\\n      ATT-head^i(s, e, t) = \\left(K^i(s)W^{ATT}_{\\phi(e)}Q^i(t)^{\\top}\\right)\\cdot\n        \\frac{\\mu_{(\\tau(s),\\phi(e),\\tau(t)}}{\\sqrt{d}} \\\\\n      K^i(s) = \\text{K-Linear}^i_{\\tau(s)}(H^{(l-1)}[s]) \\\\\n      Q^i(t) = \\text{Q-Linear}^i_{\\tau(t)}(H^{(l-1)}[t]) \\\\\n\n    Compute the message to send on each edge :math:`(s, e, t)`:\n\n    .. math::\n\n      Message(s, e, t) = ||_{i\\in[1, h]} MSG-head^i(s, e, t) \\\\\n      MSG-head^i(s, e, t) = \\text{M-Linear}^i_{\\tau(s)}(H^{(l-1)}[s])W^{MSG}_{\\phi(e)} \\\\\n\n    Send messages to target nodes :math:`t` and aggregate:\n\n    .. math::\n\n      \\tilde{H}^{(l)}[t] = \\sum_{\\forall s\\in \\mathcal{N}(t)}\\left( Attention(s,e,t)\n      \\cdot Message(s,e,t)\\right)\n\n    Compute new node features:\n\n    .. math::\n\n      H^{(l)}[t]=\\text{A-Linear}_{\\tau(t)}(\\sigma(\\tilde(H)^{(l)}[t])) + H^{(l-1)}[t]\n\n    Parameters\n    ----------\n    in_size : int\n        Input node feature size.\n    head_size : int\n        Output head size. The output node feature size is ``head_size * num_heads``.\n    num_heads : int\n        Number of heads. 
The output node feature size is ``head_size * num_heads``.\n    num_ntypes : int\n        Number of node types.\n    num_etypes : int\n        Number of edge types.\n    dropout : optional, float\n        Dropout rate.\n    use_norm : optional, bool\n        If true, apply a layer norm on the output node feature.\n\n    Examples\n    --------\n    \"\"\"\n\n    def __init__(\n        self,\n        in_size,\n        head_size,\n        num_heads,\n        num_ntypes,\n        num_etypes,\n        dropout=0.2,\n        use_norm=False,\n    ):\n        super().__init__()\n        self.in_size = in_size\n        self.head_size = head_size\n        self.num_heads = num_heads\n        self.sqrt_d = math.sqrt(head_size)\n        self.use_norm = use_norm\n\n        self.linear_k = TypedLinear(in_size, head_size * num_heads, num_ntypes)\n        self.linear_q = TypedLinear(in_size, head_size * num_heads, num_ntypes)\n        self.linear_v = TypedLinear(in_size, head_size * num_heads, num_ntypes)\n        self.linear_a = TypedLinear(\n            head_size * num_heads, head_size * num_heads, num_ntypes\n        )\n\n        self.relation_pri = nn.ParameterList(\n            [nn.Parameter(torch.ones(num_etypes)) for i in range(num_heads)]\n        )\n        self.relation_att = nn.ModuleList(\n            [\n                TypedLinear(head_size, head_size, num_etypes)\n                for i in range(num_heads)\n            ]\n        )\n        self.relation_msg = nn.ModuleList(\n            [\n                TypedLinear(head_size, head_size, num_etypes)\n                for i in range(num_heads)\n            ]\n        )\n        self.skip = nn.Parameter(torch.ones(num_ntypes))\n        self.drop = nn.Dropout(dropout)\n        if use_norm:\n            self.norm = nn.LayerNorm(head_size * num_heads)\n        if in_size != head_size * num_heads:\n            self.residual_w = nn.Parameter(\n                torch.Tensor(in_size, head_size * num_heads)\n            )\n        
    nn.init.xavier_uniform_(self.residual_w)\n\n    def forward(self, g, x, ntype, etype, *, presorted=False):\n        \"\"\"Forward computation.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The input graph.\n        x : torch.Tensor\n            A 2D tensor of node features. Shape: :math:`(|V|, D_{in})`.\n        ntype : torch.Tensor\n            An 1D integer tensor of node types. Shape: :math:`(|V|,)`.\n        etype : torch.Tensor\n            An 1D integer tensor of edge types. Shape: :math:`(|E|,)`.\n        presorted : bool, optional\n            Whether *both* the nodes and the edges of the input graph have been sorted by\n            their types. Forward on pre-sorted graph may be faster. Graphs created by\n            :func:`~dgl.to_homogeneous` automatically satisfy the condition.\n            Also see :func:`~dgl.reorder_graph` for manually reordering the nodes and edges.\n\n        Returns\n        -------\n        torch.Tensor\n            New node features. 
Shape: :math:`(|V|, D_{head} * N_{head})`.\n        \"\"\"\n        self.presorted = presorted\n        if g.is_block:\n            x_src = x\n            x_dst = x[: g.num_dst_nodes()]\n            srcntype = ntype\n            dstntype = ntype[: g.num_dst_nodes()]\n        else:\n            x_src = x\n            x_dst = x\n            srcntype = ntype\n            dstntype = ntype\n        with g.local_scope():\n            k = self.linear_k(x_src, srcntype, presorted).view(\n                -1, self.num_heads, self.head_size\n            )\n            q = self.linear_q(x_dst, dstntype, presorted).view(\n                -1, self.num_heads, self.head_size\n            )\n            v = self.linear_v(x_src, srcntype, presorted).view(\n                -1, self.num_heads, self.head_size\n            )\n            g.srcdata[\"k\"] = k\n            g.dstdata[\"q\"] = q\n            g.srcdata[\"v\"] = v\n            g.edata[\"etype\"] = etype\n            g.apply_edges(self.message)\n            g.edata[\"m\"] = g.edata[\"m\"] * edge_softmax(\n                g, g.edata[\"a\"]\n            ).unsqueeze(-1)\n            g.update_all(fn.copy_e(\"m\", \"m\"), fn.sum(\"m\", \"h\"))\n            h = g.dstdata[\"h\"].view(-1, self.num_heads * self.head_size)\n            # target-specific aggregation\n            h = self.drop(self.linear_a(h, dstntype, presorted))\n            alpha = torch.sigmoid(self.skip[dstntype]).unsqueeze(-1)\n            if x_dst.shape != h.shape:\n                h = h * alpha + (x_dst @ self.residual_w) * (1 - alpha)\n            else:\n                h = h * alpha + x_dst * (1 - alpha)\n            if self.use_norm:\n                h = self.norm(h)\n            return h\n\n    def message(self, edges):\n        \"\"\"Message function.\"\"\"\n        a, m = [], []\n        etype = edges.data[\"etype\"]\n        k = torch.unbind(edges.src[\"k\"], dim=1)\n        q = torch.unbind(edges.dst[\"q\"], dim=1)\n        v = 
torch.unbind(edges.src[\"v\"], dim=1)\n        for i in range(self.num_heads):\n            kw = self.relation_att[i](k[i], etype, self.presorted)  # (E, O)\n            a.append(\n                (kw * q[i]).sum(-1) * self.relation_pri[i][etype] / self.sqrt_d\n            )  # (E,)\n            m.append(\n                self.relation_msg[i](v[i], etype, self.presorted)\n            )  # (E, O)\n        return {\"a\": torch.stack(a, dim=1), \"m\": torch.stack(m, dim=1)}\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/nnconv.py",
    "content": "\"\"\"Torch Module for NNConv layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\nfrom torch.nn import init\n\nfrom .... import function as fn\nfrom ....utils import expand_as_pair\nfrom ..utils import Identity\n\n\nclass NNConv(nn.Module):\n    r\"\"\"Graph Convolution layer from `Neural Message Passing\n    for Quantum Chemistry <https://arxiv.org/pdf/1704.01212.pdf>`__\n\n    .. math::\n        h_{i}^{l+1} = h_{i}^{l} + \\mathrm{aggregate}\\left(\\left\\{\n        f_\\Theta (e_{ij}) \\cdot h_j^{l}, j\\in \\mathcal{N}(i) \\right\\}\\right)\n\n    where :math:`e_{ij}` is the edge feature, :math:`f_\\Theta` is a function\n    with learnable parameters.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n        NNConv can be applied on homogeneous graph and unidirectional\n        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.\n        If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``\n        specifies the input feature size on both the source and destination nodes.  If\n        a scalar is given, the source and destination node feature size would take the\n        same value.\n    out_feats : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    edge_func : callable activation function/layer\n        Maps each edge feature to a vector of shape\n        ``(in_feats * out_feats)`` as weight to compute\n        messages.\n        Also is the :math:`f_\\Theta` in the formula.\n    aggregator_type : str\n        Aggregator type to use (``sum``, ``mean`` or ``max``).\n    residual : bool, optional\n        If True, use residual connection. Default: ``False``.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import NNConv\n\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = th.ones(6, 10)\n    >>> lin = th.nn.Linear(5, 20)\n    >>> def edge_func(efeat):\n    ...     return lin(efeat)\n    >>> efeat = th.ones(6+6, 5)\n    >>> conv = NNConv(10, 2, edge_func, 'mean')\n    >>> res = conv(g, feat, efeat)\n    >>> res\n    tensor([[-1.5243, -0.2719],\n            [-1.5243, -0.2719],\n            [-1.5243, -0.2719],\n            [-1.5243, -0.2719],\n            [-1.5243, -0.2719],\n            [-1.5243, -0.2719]], grad_fn=<AddBackward0>)\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})\n    >>> u_feat = th.tensor(np.random.rand(2, 10).astype(np.float32))\n    >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))\n    >>> conv = NNConv(10, 2, edge_func, 'mean')\n    >>> efeat = th.ones(5, 5)\n    >>> res = conv(g, (u_feat, v_feat), efeat)\n    >>> res\n    tensor([[-0.6568,  0.5042],\n            [ 0.9089, -0.5352],\n            [ 0.1261, -0.0155],\n            [-0.6568,  0.5042]], grad_fn=<AddBackward0>)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        edge_func,\n        aggregator_type=\"mean\",\n        residual=False,\n        bias=True,\n    ):\n        super(NNConv, self).__init__()\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        self.edge_func = edge_func\n        if aggregator_type == \"sum\":\n            self.reducer = fn.sum\n        elif aggregator_type == \"mean\":\n            self.reducer = fn.mean\n        elif aggregator_type == \"max\":\n            self.reducer = fn.max\n        
else:\n            raise KeyError(\n                \"Aggregator type {} not recognized: \".format(aggregator_type)\n            )\n        self._aggre_type = aggregator_type\n        if residual:\n            if self._in_dst_feats != out_feats:\n                self.res_fc = nn.Linear(\n                    self._in_dst_feats, out_feats, bias=False\n                )\n            else:\n                self.res_fc = Identity()\n        else:\n            self.register_buffer(\"res_fc\", None)\n        if bias:\n            self.bias = nn.Parameter(th.Tensor(out_feats))\n        else:\n            self.register_buffer(\"bias\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n\n        Note\n        ----\n        The model parameters are initialized using Glorot uniform initialization\n        and the bias is initialized to be zero.\n        \"\"\"\n        gain = init.calculate_gain(\"relu\")\n        if self.bias is not None:\n            nn.init.zeros_(self.bias)\n        if isinstance(self.res_fc, nn.Linear):\n            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)\n\n    def forward(self, graph, feat, efeat):\n        r\"\"\"Compute MPNN Graph Convolution layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor or pair of torch.Tensor\n            The input feature of shape :math:`(N, D_{in})` where :math:`N`\n            is the number of nodes of the graph and :math:`D_{in}` is the\n            input feature size.\n        efeat : torch.Tensor\n            The edge feature of shape :math:`(E, *)`, which should fit the input\n            shape requirement of ``edge_func``. 
:math:`E` is the number of edges\n            of the graph.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is the output feature size.\n        \"\"\"\n        with graph.local_scope():\n            feat_src, feat_dst = expand_as_pair(feat, graph)\n\n            # (n, d_in, 1)\n            graph.srcdata[\"h\"] = feat_src.unsqueeze(-1)\n            # (n, d_in, d_out)\n            graph.edata[\"w\"] = self.edge_func(efeat).view(\n                -1, self._in_src_feats, self._out_feats\n            )\n            # (n, d_in, d_out)\n            graph.update_all(\n                fn.u_mul_e(\"h\", \"w\", \"m\"), self.reducer(\"m\", \"neigh\")\n            )\n            rst = graph.dstdata[\"neigh\"].sum(dim=1)  # (n, d_out)\n            # residual connection\n            if self.res_fc is not None:\n                rst = rst + self.res_fc(feat_dst)\n            # bias\n            if self.bias is not None:\n                rst = rst + self.bias\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/pnaconv.py",
    "content": "\"\"\"Torch Module for Principal Neighbourhood Aggregation Convolution Layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\n\ndef aggregate_mean(h):\n    \"\"\"mean aggregation\"\"\"\n    return torch.mean(h, dim=1)\n\n\ndef aggregate_max(h):\n    \"\"\"max aggregation\"\"\"\n    return torch.max(h, dim=1)[0]\n\n\ndef aggregate_min(h):\n    \"\"\"min aggregation\"\"\"\n    return torch.min(h, dim=1)[0]\n\n\ndef aggregate_sum(h):\n    \"\"\"sum aggregation\"\"\"\n    return torch.sum(h, dim=1)\n\n\ndef aggregate_std(h):\n    \"\"\"standard deviation aggregation\"\"\"\n    return torch.sqrt(aggregate_var(h) + 1e-30)\n\n\ndef aggregate_var(h):\n    \"\"\"variance aggregation\"\"\"\n    h_mean_squares = torch.mean(h * h, dim=1)\n    h_mean = torch.mean(h, dim=1)\n    var = torch.relu(h_mean_squares - h_mean * h_mean)\n    return var\n\n\ndef _aggregate_moment(h, n):\n    \"\"\"moment aggregation: for each node (E[(X-E[X])^n])^{1/n}\"\"\"\n    h_mean = torch.mean(h, dim=1, keepdim=True)\n    h_n = torch.mean(torch.pow(h - h_mean, n), dim=1)\n    rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + 1e-30, 1.0 / n)\n    return rooted_h_n\n\n\ndef aggregate_moment_3(h):\n    \"\"\"moment aggregation with n=3\"\"\"\n    return _aggregate_moment(h, n=3)\n\n\ndef aggregate_moment_4(h):\n    \"\"\"moment aggregation with n=4\"\"\"\n    return _aggregate_moment(h, n=4)\n\n\ndef aggregate_moment_5(h):\n    \"\"\"moment aggregation with n=5\"\"\"\n    return _aggregate_moment(h, n=5)\n\n\ndef scale_identity(h):\n    \"\"\"identity scaling (no scaling operation)\"\"\"\n    return h\n\n\ndef scale_amplification(h, D, delta):\n    \"\"\"amplification scaling\"\"\"\n    return h * (np.log(D + 1) / delta)\n\n\ndef scale_attenuation(h, D, delta):\n    \"\"\"attenuation scaling\"\"\"\n    return h * (delta / np.log(D + 1))\n\n\nAGGREGATORS = {\n    \"mean\": aggregate_mean,\n    
\"sum\": aggregate_sum,\n    \"max\": aggregate_max,\n    \"min\": aggregate_min,\n    \"std\": aggregate_std,\n    \"var\": aggregate_var,\n    \"moment3\": aggregate_moment_3,\n    \"moment4\": aggregate_moment_4,\n    \"moment5\": aggregate_moment_5,\n}\nSCALERS = {\n    \"identity\": scale_identity,\n    \"amplification\": scale_amplification,\n    \"attenuation\": scale_attenuation,\n}\n\n\nclass PNAConvTower(nn.Module):\n    \"\"\"A single PNA tower in PNA layers\"\"\"\n\n    def __init__(\n        self,\n        in_size,\n        out_size,\n        aggregators,\n        scalers,\n        delta,\n        dropout=0.0,\n        edge_feat_size=0,\n    ):\n        super(PNAConvTower, self).__init__()\n        self.in_size = in_size\n        self.out_size = out_size\n        self.aggregators = aggregators\n        self.scalers = scalers\n        self.delta = delta\n        self.edge_feat_size = edge_feat_size\n\n        self.M = nn.Linear(2 * in_size + edge_feat_size, in_size)\n        self.U = nn.Linear(\n            (len(aggregators) * len(scalers) + 1) * in_size, out_size\n        )\n        self.dropout = nn.Dropout(dropout)\n        self.batchnorm = nn.BatchNorm1d(out_size)\n\n    def reduce_func(self, nodes):\n        \"\"\"reduce function for PNA layer:\n        tensordot of multiple aggregation and scaling operations\"\"\"\n        msg = nodes.mailbox[\"msg\"]\n        degree = msg.size(1)\n        h = torch.cat(\n            [AGGREGATORS[agg](msg) for agg in self.aggregators], dim=1\n        )\n        h = torch.cat(\n            [\n                SCALERS[scaler](h, D=degree, delta=self.delta)\n                if scaler != \"identity\"\n                else h\n                for scaler in self.scalers\n            ],\n            dim=1,\n        )\n        return {\"h_neigh\": h}\n\n    def message(self, edges):\n        \"\"\"message function for PNA layer\"\"\"\n        if self.edge_feat_size > 0:\n            f = torch.cat(\n                
[edges.src[\"h\"], edges.dst[\"h\"], edges.data[\"a\"]], dim=-1\n            )\n        else:\n            f = torch.cat([edges.src[\"h\"], edges.dst[\"h\"]], dim=-1)\n        return {\"msg\": self.M(f)}\n\n    def forward(self, graph, node_feat, edge_feat=None):\n        \"\"\"compute the forward pass of a single tower in PNA convolution layer\"\"\"\n        # calculate graph normalization factors\n        snorm_n = torch.cat(\n            [\n                torch.ones(N, 1).to(node_feat) / N\n                for N in graph.batch_num_nodes()\n            ],\n            dim=0,\n        ).sqrt()\n        with graph.local_scope():\n            graph.ndata[\"h\"] = node_feat\n            if self.edge_feat_size > 0:\n                assert edge_feat is not None, \"Edge features must be provided.\"\n                graph.edata[\"a\"] = edge_feat\n\n            graph.update_all(self.message, self.reduce_func)\n            h = self.U(torch.cat([node_feat, graph.ndata[\"h_neigh\"]], dim=-1))\n            h = h * snorm_n\n            return self.dropout(self.batchnorm(h))\n\n\nclass PNAConv(nn.Module):\n    r\"\"\"Principal Neighbourhood Aggregation Layer from `Principal Neighbourhood Aggregation\n    for Graph Nets <https://arxiv.org/abs/2004.05718>`__\n\n    A PNA layer is composed of multiple PNA towers. Each tower takes as input a split of the\n    input features, and computes the message passing as below.\n\n    .. math::\n        h_i^{l+1} = U(h_i^l, \\oplus_{(i,j)\\in E}M(h_i^l, e_{i,j}, h_j^l))\n\n    where :math:`h_i` and :math:`e_{i,j}` are node features and edge features, respectively.\n    :math:`M` and :math:`U` are MLPs, taking the concatenation of input for computing\n    output features. :math:`\\oplus` represents the combination of various aggregators\n    and scalers. Aggregators aggregate messages from neighbours and scalers scale the\n    aggregated messages in different ways. 
:math:`\\oplus` concatenates the output features\n    of each combination.\n\n    The output of multiple towers are concatenated and fed into a linear mixing layer for the\n    final output.\n\n    Parameters\n    ----------\n    in_size : int\n        Input feature size; i.e. the size of :math:`h_i^l`.\n    out_size : int\n        Output feature size; i.e. the size of :math:`h_i^{l+1}`.\n    aggregators : list of str\n        List of aggregation function names(each aggregator specifies a way to aggregate\n        messages from neighbours), selected from:\n\n        * ``mean``: the mean of neighbour messages\n\n        * ``max``: the maximum of neighbour messages\n\n        * ``min``: the minimum of neighbour messages\n\n        * ``std``: the standard deviation of neighbour messages\n\n        * ``var``: the variance of neighbour messages\n\n        * ``sum``: the sum of neighbour messages\n\n        * ``moment3``, ``moment4``, ``moment5``: the normalized moments aggregation\n        :math:`(E[(X-E[X])^n])^{1/n}`\n    scalers: list of str\n        List of scaler function names, selected from:\n\n        * ``identity``: no scaling\n\n        * ``amplification``: multiply the aggregated message by :math:`\\log(d+1)/\\delta`,\n        where :math:`d` is the degree of the node.\n\n        * ``attenuation``: multiply the aggregated message by :math:`\\delta/\\log(d+1)`\n    delta: float\n        The degree-related normalization factor computed over the training set, used by scalers\n        for normalization. :math:`E[\\log(d+1)]`, where :math:`d` is the degree for each node\n        in the training set.\n    dropout: float, optional\n        The dropout ratio. Default: 0.0.\n    num_towers: int, optional\n        The number of towers used. Default: 1. Note that in_size and out_size must be divisible\n        by num_towers.\n    edge_feat_size: int, optional\n        The edge feature size. 
Default: 0.\n    residual : bool, optional\n        The bool flag that determines whether to add a residual connection for the\n        output. Default: True. If in_size and out_size of the PNA conv layer are not\n        the same, this flag will be set as False forcibly.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import PNAConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = th.ones(6, 10)\n    >>> conv = PNAConv(10, 10, ['mean', 'max', 'sum'], ['identity', 'amplification'], 2.5)\n    >>> ret = conv(g, feat)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_size,\n        out_size,\n        aggregators,\n        scalers,\n        delta,\n        dropout=0.0,\n        num_towers=1,\n        edge_feat_size=0,\n        residual=True,\n    ):\n        super(PNAConv, self).__init__()\n\n        self.in_size = in_size\n        self.out_size = out_size\n        assert (\n            in_size % num_towers == 0\n        ), \"in_size must be divisible by num_towers\"\n        assert (\n            out_size % num_towers == 0\n        ), \"out_size must be divisible by num_towers\"\n        self.tower_in_size = in_size // num_towers\n        self.tower_out_size = out_size // num_towers\n        self.edge_feat_size = edge_feat_size\n        self.residual = residual\n        if self.in_size != self.out_size:\n            self.residual = False\n\n        self.towers = nn.ModuleList(\n            [\n                PNAConvTower(\n                    self.tower_in_size,\n                    self.tower_out_size,\n                    aggregators,\n                    scalers,\n                    delta,\n                    dropout=dropout,\n                    edge_feat_size=edge_feat_size,\n                )\n                for _ in range(num_towers)\n            ]\n        )\n\n        self.mixing_layer = nn.Sequential(\n            nn.Linear(out_size, out_size), nn.LeakyReLU()\n       
 )\n\n    def forward(self, graph, node_feat, edge_feat=None):\n        r\"\"\"\n        Description\n        -----------\n        Compute PNA layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        node_feat : torch.Tensor\n            The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of\n            nodes, and :math:`h_n` must be the same as in_size.\n        edge_feat : torch.Tensor, optional\n            The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number of\n            edges, and :math:`h_e` must be the same as edge_feat_size.\n\n        Returns\n        -------\n        torch.Tensor\n            The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`\n            should be the same as out_size.\n        \"\"\"\n        h_cat = torch.cat(\n            [\n                tower(\n                    graph,\n                    node_feat[\n                        :,\n                        ti * self.tower_in_size : (ti + 1) * self.tower_in_size,\n                    ],\n                    edge_feat,\n                )\n                for ti, tower in enumerate(self.towers)\n            ],\n            dim=1,\n        )\n        h_out = self.mixing_layer(h_cat)\n        # add residual connection\n        if self.residual:\n            h_out = h_out + node_feat\n\n        return h_out\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/relgraphconv.py",
    "content": "\"\"\"Torch Module for Relational graph convolution layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\n\nfrom .... import function as fn\nfrom ..linear import TypedLinear\n\n\nclass RelGraphConv(nn.Module):\n    r\"\"\"Relational graph convolution layer from `Modeling Relational Data with Graph\n    Convolutional Networks <https://arxiv.org/abs/1703.06103>`__\n\n    It can be described in as below:\n\n    .. math::\n\n       h_i^{(l+1)} = \\sigma(\\sum_{r\\in\\mathcal{R}}\n       \\sum_{j\\in\\mathcal{N}^r(i)}e_{j,i}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})\n\n    where :math:`\\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation\n    :math:`r`. :math:`e_{j,i}` is the normalizer. :math:`\\sigma` is an activation\n    function. :math:`W_0` is the self-loop weight.\n\n    The basis regularization decomposes :math:`W_r` by:\n\n    .. math::\n\n       W_r^{(l)} = \\sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}\n\n    where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined\n    with coefficients :math:`a_{rb}^{(l)}`.\n\n    The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B`\n    number of block diagonal matrices. We refer :math:`B` as the number of bases.\n\n    The block regularization decomposes :math:`W_r` by:\n\n    .. 
math::\n\n       W_r^{(l)} = \\oplus_{b=1}^B Q_{rb}^{(l)}\n\n    where :math:`B` is the number of bases, :math:`Q_{rb}^{(l)}` are block\n    bases with shape :math:`R^{(d^{(l+1)}/B)*(d^{l}/B)}`.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n    out_feat : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    num_rels : int\n        Number of relations.\n    regularizer : str, optional\n        Which weight regularizer to use (\"basis\", \"bdd\" or ``None``):\n\n         - \"basis\" is for basis-decomposition.\n         - \"bdd\" is for block-diagonal-decomposition.\n         - ``None`` applies no regularization.\n\n        Default: ``None``.\n    num_bases : int, optional\n        Number of bases. It comes into effect when a regularizer is applied.\n        If ``None``, it uses number of relations (``num_rels``). Default: ``None``.\n        Note that ``in_feat`` and ``out_feat`` must be divisible by ``num_bases``\n        when applying \"bdd\" regularizer.\n    bias : bool, optional\n        True if bias is added. Default: ``True``.\n    activation : callable, optional\n        Activation function. Default: ``None``.\n    self_loop : bool, optional\n        True to include self loop message. Default: ``True``.\n    dropout : float, optional\n        Dropout rate. Default: ``0.0``\n    layer_norm: bool, optional\n        True to add layer norm. 
Default: ``False``\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import RelGraphConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = th.ones(6, 10)\n    >>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2)\n    >>> etype = th.tensor([0,1,2,0,1,2])\n    >>> res = conv(g, feat, etype)\n    >>> res\n    tensor([[ 0.3996, -2.3303],\n            [-0.4323, -0.1440],\n            [ 0.3996, -2.3303],\n            [ 2.1046, -2.8654],\n            [-0.4323, -0.1440],\n            [-0.1309, -1.0000]], grad_fn=<AddBackward0>)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feat,\n        out_feat,\n        num_rels,\n        regularizer=None,\n        num_bases=None,\n        bias=True,\n        activation=None,\n        self_loop=True,\n        dropout=0.0,\n        layer_norm=False,\n    ):\n        super().__init__()\n        if regularizer is not None and num_bases is None:\n            num_bases = num_rels\n        self.linear_r = TypedLinear(\n            in_feat, out_feat, num_rels, regularizer, num_bases\n        )\n        self.bias = bias\n        self.activation = activation\n        self.self_loop = self_loop\n        self.layer_norm = layer_norm\n\n        # bias\n        if self.bias:\n            self.h_bias = nn.Parameter(th.Tensor(out_feat))\n            nn.init.zeros_(self.h_bias)\n\n        # TODO(minjie): consider remove those options in the future to make\n        #   the module only about graph convolution.\n        # layer norm\n        if self.layer_norm:\n            self.layer_norm_weight = nn.LayerNorm(\n                out_feat, elementwise_affine=True\n            )\n\n        # weight for self loop\n        if self.self_loop:\n            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))\n            nn.init.xavier_uniform_(\n                self.loop_weight, gain=nn.init.calculate_gain(\"relu\")\n 
           )\n\n        self.dropout = nn.Dropout(dropout)\n\n    def message(self, edges):\n        \"\"\"Message function.\"\"\"\n        m = self.linear_r(edges.src[\"h\"], edges.data[\"etype\"], self.presorted)\n        if \"norm\" in edges.data:\n            m = m * edges.data[\"norm\"]\n        return {\"m\": m}\n\n    def forward(self, g, feat, etypes, norm=None, *, presorted=False):\n        \"\"\"Forward computation.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            A 2D tensor of node features. Shape: :math:`(|V|, D_{in})`.\n        etypes : torch.Tensor or list[int]\n            An 1D integer tensor of edge types. Shape: :math:`(|E|,)`.\n        norm : torch.Tensor, optional\n            An 1D tensor of edge norm value.  Shape: :math:`(|E|,)`.\n        presorted : bool, optional\n            Whether the edges of the input graph have been sorted by their types.\n            Forward on pre-sorted graph may be faster. Graphs created\n            by :func:`~dgl.to_homogeneous` automatically satisfy the condition.\n            Also see :func:`~dgl.reorder_graph` for sorting edges manually.\n\n        Returns\n        -------\n        torch.Tensor\n            New node features. 
Shape: :math:`(|V|, D_{out})`.\n        \"\"\"\n        self.presorted = presorted\n        with g.local_scope():\n            g.srcdata[\"h\"] = feat\n            if norm is not None:\n                g.edata[\"norm\"] = norm\n            g.edata[\"etype\"] = etypes\n            # message passing\n            g.update_all(self.message, fn.sum(\"m\", \"h\"))\n            # apply bias and activation\n            h = g.dstdata[\"h\"]\n            if self.layer_norm:\n                h = self.layer_norm_weight(h)\n            if self.bias:\n                h = h + self.h_bias\n            if self.self_loop:\n                h = h + feat[: g.num_dst_nodes()] @ self.loop_weight\n            if self.activation:\n                h = self.activation(h)\n            h = self.dropout(h)\n            return h\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/sageconv.py",
    "content": "\"\"\"Torch Module for GraphSAGE layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import check_eq_shape, expand_as_pair\n\n\nclass SAGEConv(nn.Module):\n    r\"\"\"GraphSAGE layer from `Inductive Representation Learning on\n    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__\n\n    .. math::\n        h_{\\mathcal{N}(i)}^{(l+1)} &= \\mathrm{aggregate}\n        \\left(\\{h_{j}^{l}, \\forall j \\in \\mathcal{N}(i) \\}\\right)\n\n        h_{i}^{(l+1)} &= \\sigma \\left(W \\cdot \\mathrm{concat}\n        (h_{i}^{l}, h_{\\mathcal{N}(i)}^{l+1}) \\right)\n\n        h_{i}^{(l+1)} &= \\mathrm{norm}(h_{i}^{(l+1)})\n\n    If a weight tensor on each edge is provided, the aggregation becomes:\n\n    .. math::\n        h_{\\mathcal{N}(i)}^{(l+1)} = \\mathrm{aggregate}\n        \\left(\\{e_{ji} h_{j}^{l}, \\forall j \\in \\mathcal{N}(i) \\}\\right)\n\n    where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.\n    Please make sure that :math:`e_{ji}` is broadcastable with :math:`h_j^{l}`.\n\n    Parameters\n    ----------\n    in_feats : int, or pair of ints\n        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n\n        SAGEConv can be applied on homogeneous graph and unidirectional\n        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.\n        If the layer applies on a unidirectional bipartite graph, ``in_feats``\n        specifies the input feature size on both the source and destination nodes.  
If\n        a scalar is given, the source and destination node feature size would take the\n        same value.\n\n        If aggregator type is ``gcn``, the feature size of source and destination nodes\n        are required to be the same.\n    out_feats : int\n        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.\n    aggregator_type : str\n        Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\n    feat_drop : float\n        Dropout rate on features, default: ``0``.\n    bias : bool\n        If True, adds a learnable bias to the output. Default: ``True``.\n    norm : callable activation function/layer or None, optional\n        If not None, applies normalization to the updated node features.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import SAGEConv\n\n    >>> # Case 1: Homogeneous graph\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = th.ones(6, 10)\n    >>> conv = SAGEConv(10, 2, 'pool')\n    >>> res = conv(g, feat)\n    >>> res\n    tensor([[-1.0888, -2.1099],\n            [-1.0888, -2.1099],\n            [-1.0888, -2.1099],\n            [-1.0888, -2.1099],\n            [-1.0888, -2.1099],\n            [-1.0888, -2.1099]], grad_fn=<AddBackward0>)\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})\n    >>> u_fea = th.rand(2, 5)\n    >>> v_fea = th.rand(4, 10)\n    >>> conv = SAGEConv((5, 10), 2, 'mean')\n    >>> res = conv(g, (u_fea, v_fea))\n    >>> res\n    tensor([[ 0.3163,  3.1166],\n            [ 0.3866,  2.5398],\n            [ 0.5873,  1.6597],\n            [-0.2502,  2.8068]], 
grad_fn=<AddBackward0>)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        aggregator_type,\n        feat_drop=0.0,\n        bias=True,\n        norm=None,\n        activation=None,\n    ):\n        super(SAGEConv, self).__init__()\n        valid_aggre_types = {\"mean\", \"gcn\", \"pool\", \"lstm\"}\n        if aggregator_type not in valid_aggre_types:\n            raise DGLError(\n                \"Invalid aggregator_type. Must be one of {}. \"\n                \"But got {!r} instead.\".format(\n                    valid_aggre_types, aggregator_type\n                )\n            )\n\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        self._aggre_type = aggregator_type\n        self.norm = norm\n        self.feat_drop = nn.Dropout(feat_drop)\n        self.activation = activation\n\n        # aggregator type: mean/pool/lstm/gcn\n        if aggregator_type == \"pool\":\n            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)\n        if aggregator_type == \"lstm\":\n            self.lstm = nn.LSTM(\n                self._in_src_feats, self._in_src_feats, batch_first=True\n            )\n\n        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)\n\n        if aggregator_type != \"gcn\":\n            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)\n        elif bias:\n            self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))\n        else:\n            self.register_buffer(\"bias\", None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n\n        Note\n        ----\n        The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.\n        The LSTM module is using xavier initialization method for its weights.\n        
\"\"\"\n        gain = nn.init.calculate_gain(\"relu\")\n        if self._aggre_type == \"pool\":\n            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)\n        if self._aggre_type == \"lstm\":\n            self.lstm.reset_parameters()\n        if self._aggre_type != \"gcn\":\n            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)\n        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)\n\n    def _lstm_reducer(self, nodes):\n        \"\"\"LSTM reducer\n        NOTE(zihao): lstm reducer with default schedule (degree bucketing)\n        is slow, we could accelerate this with degree padding in the future.\n        \"\"\"\n        m = nodes.mailbox[\"m\"]  # (B, L, D)\n        batch_size = m.shape[0]\n        h = (\n            m.new_zeros((1, batch_size, self._in_src_feats)),\n            m.new_zeros((1, batch_size, self._in_src_feats)),\n        )\n        _, (rst, _) = self.lstm(m, h)\n        return {\"neigh\": rst.squeeze(0)}\n\n    def forward(self, graph, feat, edge_weight=None):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute GraphSAGE layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor or pair of torch.Tensor\n            If a torch.Tensor is given, it represents the input feature of shape\n            :math:`(N, D_{in})`\n            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of torch.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.\n        edge_weight : torch.Tensor, optional\n            Optional tensor on the edge. 
If given, the convolution will weight\n            with regard to the message.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N_{dst}, D_{out})`\n            where :math:`N_{dst}` is the number of destination nodes in the input graph,\n            :math:`D_{out}` is the size of the output feature.\n        \"\"\"\n        with graph.local_scope():\n            if isinstance(feat, tuple):\n                feat_src = self.feat_drop(feat[0])\n                feat_dst = self.feat_drop(feat[1])\n            else:\n                feat_src = feat_dst = self.feat_drop(feat)\n                if graph.is_block:\n                    feat_dst = feat_src[: graph.number_of_dst_nodes()]\n            msg_fn = fn.copy_u(\"h\", \"m\")\n            if edge_weight is not None:\n                assert edge_weight.shape[0] == graph.num_edges()\n                graph.edata[\"_edge_weight\"] = edge_weight\n                msg_fn = fn.u_mul_e(\"h\", \"_edge_weight\", \"m\")\n\n            h_self = feat_dst\n\n            # Handle the case of graphs without edges\n            if graph.num_edges() == 0:\n                graph.dstdata[\"neigh\"] = torch.zeros(\n                    feat_dst.shape[0], self._in_src_feats\n                ).to(feat_dst)\n\n            # Determine whether to apply linear transformation before message passing A(XW)\n            lin_before_mp = self._in_src_feats > self._out_feats\n\n            # Message Passing\n            if self._aggre_type == \"mean\":\n                graph.srcdata[\"h\"] = (\n                    self.fc_neigh(feat_src) if lin_before_mp else feat_src\n                )\n                graph.update_all(msg_fn, fn.mean(\"m\", \"neigh\"))\n                h_neigh = graph.dstdata[\"neigh\"]\n                if not lin_before_mp:\n                    h_neigh = self.fc_neigh(h_neigh)\n            elif self._aggre_type == \"gcn\":\n                check_eq_shape(feat)\n                
graph.srcdata[\"h\"] = (\n                    self.fc_neigh(feat_src) if lin_before_mp else feat_src\n                )\n                if isinstance(feat, tuple):  # heterogeneous\n                    graph.dstdata[\"h\"] = (\n                        self.fc_neigh(feat_dst) if lin_before_mp else feat_dst\n                    )\n                else:\n                    if graph.is_block:\n                        graph.dstdata[\"h\"] = graph.srcdata[\"h\"][\n                            : graph.num_dst_nodes()\n                        ]\n                    else:\n                        graph.dstdata[\"h\"] = graph.srcdata[\"h\"]\n                graph.update_all(msg_fn, fn.sum(\"m\", \"neigh\"))\n                # divide in_degrees\n                degs = graph.in_degrees().to(feat_dst)\n                h_neigh = (graph.dstdata[\"neigh\"] + graph.dstdata[\"h\"]) / (\n                    degs.unsqueeze(-1) + 1\n                )\n                if not lin_before_mp:\n                    h_neigh = self.fc_neigh(h_neigh)\n            elif self._aggre_type == \"pool\":\n                graph.srcdata[\"h\"] = F.relu(self.fc_pool(feat_src))\n                graph.update_all(msg_fn, fn.max(\"m\", \"neigh\"))\n                h_neigh = self.fc_neigh(graph.dstdata[\"neigh\"])\n            elif self._aggre_type == \"lstm\":\n                graph.srcdata[\"h\"] = feat_src\n                graph.update_all(msg_fn, self._lstm_reducer)\n                h_neigh = self.fc_neigh(graph.dstdata[\"neigh\"])\n            else:\n                raise KeyError(\n                    \"Aggregator type {} not recognized.\".format(\n                        self._aggre_type\n                    )\n                )\n\n            # GraphSAGE GCN does not require fc_self.\n            if self._aggre_type == \"gcn\":\n                rst = h_neigh\n                # add bias manually for GCN\n                if self.bias is not None:\n                    rst = rst + self.bias\n            
else:\n                rst = self.fc_self(h_self) + h_neigh\n\n            # activation\n            if self.activation is not None:\n                rst = self.activation(rst)\n            # normalization\n            if self.norm is not None:\n                rst = self.norm(rst)\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/sgconv.py",
    "content": "\"\"\"Torch Module for Simplifying Graph Convolution layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom .graphconv import EdgeWeightNorm\n\n\nclass SGConv(nn.Module):\n    r\"\"\"SGC layer from `Simplifying Graph\n    Convolutional Networks <https://arxiv.org/pdf/1902.07153.pdf>`__\n\n    .. math::\n        H^{K} = (\\tilde{D}^{-1/2} \\tilde{A} \\tilde{D}^{-1/2})^K X \\Theta\n\n    where :math:`\\tilde{A}` is :math:`A` + :math:`I`.\n    Thus the graph input is expected to have self-loop edges added.\n\n    Parameters\n    ----------\n    in_feats : int\n        Number of input features; i.e., the number of dimensions of :math:`X`.\n    out_feats : int\n        Number of output features; i.e., the number of dimensions of :math:`H^{K}`.\n    k : int\n        Number of hops :math:`K`. Default: ``1``.\n    cached : bool\n        If True, the module would cache\n\n        .. math::\n            (\\tilde{D}^{-\\frac{1}{2}}\\tilde{A}\\tilde{D}^{-\\frac{1}{2}})^K X\n\n        at the first forward call. This parameter should only be set to\n        ``True`` in Transductive Learning setting.\n    bias : bool\n        If True, adds a learnable bias to the output. Default: ``True``.\n    norm : callable activation function/layer or None, optional\n        If not None, applies normalization to the updated node features.  Default: ``None``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. 
Default: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practice to handle this is to filter out the nodes with zero-in-degree when used\n    after conv.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import SGConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> g = dgl.add_self_loop(g)\n    >>> feat = th.ones(6, 10)\n    >>> conv = SGConv(10, 2, k=2)\n    >>> res = conv(g, feat)\n    >>> res\n    tensor([[-1.9441, -0.9343],\n            [-1.9441, -0.9343],\n            [-1.9441, -0.9343],\n            [-2.7709, -1.3316],\n            [-1.9297, -0.9273],\n            [-1.9441, -0.9343]], grad_fn=<AddmmBackward>)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        k=1,\n        cached=False,\n        bias=True,\n        norm=None,\n        allow_zero_in_degree=False,\n    ):\n        super(SGConv, self).__init__()\n        self.fc = nn.Linear(in_feats, out_feats, bias=bias)\n        self._cached = cached\n        self._cached_h = None\n        self._k = k\n        self.norm = norm\n        self._allow_zero_in_degree = allow_zero_in_degree\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        
-----------\n        Reinitialize learnable parameters.\n\n        Note\n        ----\n        The model parameters are initialized using xavier initialization\n        and the bias is initialized to be zero.\n        \"\"\"\n        nn.init.xavier_uniform_(self.fc.weight)\n        if self.fc.bias is not None:\n            nn.init.zeros_(self.fc.bias)\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"\n\n        Description\n        -----------\n        Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def forward(self, graph, feat, edge_weight=None):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute Simplifying Graph Convolution layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`\n            is size of input feature, :math:`N` is the number of nodes.\n        edge_weight: torch.Tensor, optional\n            edge_weight to use in the message passing process. This is equivalent to\n            using weighted adjacency matrix in the equation above, and\n            :math:`\\tilde{D}^{-1/2}\\tilde{A} \\tilde{D}^{-1/2}`\n            is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. 
This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n\n        Note\n        ----\n        If ``cache`` is set to True, ``feat`` and ``graph`` should not change during\n        training, or you will get wrong results.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if (graph.in_degrees() == 0).any():\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            msg_func = fn.copy_u(\"h\", \"m\")\n            if edge_weight is not None:\n                graph.edata[\"_edge_weight\"] = EdgeWeightNorm(\"both\")(\n                    graph, edge_weight\n                )\n                msg_func = fn.u_mul_e(\"h\", \"_edge_weight\", \"m\")\n\n            if self._cached_h is not None:\n                feat = self._cached_h\n            else:\n                if edge_weight is None:\n                    # compute normalization\n                    degs = graph.in_degrees().to(feat).clamp(min=1)\n                    norm = th.pow(degs, -0.5)\n                    norm = norm.to(feat.device).unsqueeze(1)\n                # compute (D^-1 A^k D)^k X\n                for _ in range(self._k):\n                    if edge_weight is None:\n                        feat = feat * norm\n    
                graph.ndata[\"h\"] = feat\n                    graph.update_all(msg_func, fn.sum(\"m\", \"h\"))\n                    feat = graph.ndata.pop(\"h\")\n                    if edge_weight is None:\n                        feat = feat * norm\n\n                if self.norm is not None:\n                    feat = self.norm(feat)\n\n                # cache feature\n                if self._cached:\n                    self._cached_h = feat\n            return self.fc(feat)\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/tagconv.py",
    "content": "\"\"\"Torch Module for Topology Adaptive Graph Convolutional layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport torch as th\nfrom torch import nn\n\nfrom .... import function as fn\nfrom .graphconv import EdgeWeightNorm\n\n\nclass TAGConv(nn.Module):\n    r\"\"\"Topology Adaptive Graph Convolutional layer from `Topology\n    Adaptive Graph Convolutional Networks <https://arxiv.org/pdf/1710.10370.pdf>`__\n\n    .. math::\n        H^{K} = {\\sum}_{k=0}^K (D^{-1/2} A D^{-1/2})^{k} X {\\Theta}_{k},\n\n    where :math:`A` denotes the adjacency matrix,\n    :math:`D_{ii} = \\sum_{j=0} A_{ij}` its diagonal degree matrix,\n    :math:`{\\Theta}_{k}` denotes the linear weights to sum the results of different hops together.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size. i.e, the number of dimensions of :math:`X`.\n    out_feats : int\n        Output feature size.  i.e, the number of dimensions of :math:`H^{K}`.\n    k: int, optional\n        Number of hops :math:`K`. Default: ``2``.\n    bias: bool, optional\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n    activation: callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n\n    Attributes\n    ----------\n    lin : torch.nn.Module\n        The learnable linear module.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import torch as th\n    >>> from dgl.nn import TAGConv\n    >>>\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = th.ones(6, 10)\n    >>> conv = TAGConv(10, 2, k=2)\n    >>> res = conv(g, feat)\n    >>> res\n    tensor([[ 0.5490, -1.6373],\n            [ 0.5490, -1.6373],\n            [ 0.5490, -1.6373],\n            [ 0.5513, -1.8208],\n            [ 0.5215, -1.6044],\n            [ 0.3304, -1.9927]], grad_fn=<AddmmBackward>)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        k=2,\n        bias=True,\n        activation=None,\n    ):\n        super(TAGConv, self).__init__()\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._k = k\n        self._activation = activation\n        self.lin = nn.Linear(in_feats * (self._k + 1), out_feats, bias=bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n\n        Note\n        ----\n        The weights of the linear module are initialized using Glorot normal\n        initialization with the ReLU gain.\n        \"\"\"\n        gain = nn.init.calculate_gain(\"relu\")\n        nn.init.xavier_normal_(self.lin.weight, gain=gain)\n\n    def forward(self, graph, feat, edge_weight=None):\n        r\"\"\"\n\n        Description\n        -----------\n        Compute topology adaptive graph convolution.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : torch.Tensor\n            The input feature of shape :math:`(N, 
D_{in})` where :math:`D_{in}`\n            is the size of input feature, :math:`N` is the number of nodes.\n        edge_weight: torch.Tensor, optional\n            edge_weight to use in the message passing process. This is equivalent to\n            using a weighted adjacency matrix in the equation above, and\n            :math:`D^{-1/2} A D^{-1/2}`\n            is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is the size of output feature.\n        \"\"\"\n        with graph.local_scope():\n            assert graph.is_homogeneous, \"Graph is not homogeneous\"\n            if edge_weight is None:\n                norm = th.pow(graph.in_degrees().to(feat).clamp(min=1), -0.5)\n                shp = norm.shape + (1,) * (feat.dim() - 1)\n                norm = th.reshape(norm, shp).to(feat.device)\n\n            msg_func = fn.copy_u(\"h\", \"m\")\n            if edge_weight is not None:\n                graph.edata[\"_edge_weight\"] = EdgeWeightNorm(\"both\")(\n                    graph, edge_weight\n                )\n                msg_func = fn.u_mul_e(\"h\", \"_edge_weight\", \"m\")\n            # D^{-1/2} A D^{-1/2} X\n            fstack = [feat]\n            for _ in range(self._k):\n                if edge_weight is None:\n                    rst = fstack[-1] * norm\n                else:\n                    rst = fstack[-1]\n                graph.ndata[\"h\"] = rst\n\n                graph.update_all(msg_func, fn.sum(msg=\"m\", out=\"h\"))\n                rst = graph.ndata[\"h\"]\n                if edge_weight is None:\n                    rst = rst * norm\n                fstack.append(rst)\n\n            rst = self.lin(th.cat(fstack, dim=-1))\n\n            if self._activation is not None:\n                rst = self._activation(rst)\n\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/pytorch/conv/twirlsconv.py",
    "content": "\"\"\"Torch modules for TWIRLS\"\"\"\n# pylint: disable=invalid-name, useless-super-delegation, no-member\n\nimport torch as tc\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .... import function as fn\n\n\nclass TWIRLSConv(nn.Module):\n    r\"\"\"Convolution together with iteratively reweighted least squares from\n    `Graph Neural Networks Inspired by Classical Iterative Algorithms\n    <https://arxiv.org/pdf/2103.06064.pdf>`__\n\n    Parameters\n    ----------\n    input_d : int\n        Number of input features.\n    output_d : int\n        Number of output features.\n    hidden_d : int\n        Size of hidden layers.\n    prop_step : int\n        Number of propagation steps.\n    num_mlp_before : int\n        Number of mlp layers before propagation. Default: ``1``.\n    num_mlp_after : int\n        Number of mlp layers after propagation.  Default: ``1``.\n    norm : str\n        The type of norm layers inside mlp layers. Can be ``'batch'``, ``'layer'`` or ``'none'``.\n        Default: ``'none'``\n    precond : bool\n        If True, use pre-conditioning and the unnormalized Laplacian; otherwise do not use\n        pre-conditioning and use the normalized Laplacian. Default: ``True``\n    alp : float\n        The :math:`\\alpha` in paper. If equal to :math:`0`, will be automatically decided based\n        on other hyperparameters. Default: ``0``.\n    lam : float\n        The :math:`\\lambda` in paper. Default: ``1``.\n    attention : bool\n        If ``True``, add an attention layer inside propagations. Default: ``False``.\n    tau : float\n        The :math:`\\tau` in paper. Default: ``0.2``.\n    T : float\n        The :math:`T` in paper. If < 0, :math:`T` will be set to :math:`\\infty`. Default: ``-1``.\n    p : float\n        The :math:`p` in paper. Default: ``1``.\n    use_eta : bool\n        If ``True``, add a learnable weight on each dimension in attention. 
Default: ``False``.\n    attn_bef : bool\n        If ``True``, add another attention layer before propagation. Default: ``False``.\n    dropout : float\n        The dropout rate in mlp layers. Default: ``0.0``.\n    attn_dropout : float\n        The dropout rate of attention values. Default: ``0.0``.\n    inp_dropout : float\n        The dropout rate on input features. Default: ``0.0``.\n\n\n    Note\n    ----\n     ``add_self_loop`` will be automatically called before propagation.\n\n    Example\n    -------\n    >>> import dgl\n    >>> from dgl.nn import TWIRLSConv\n    >>> import torch as th\n\n    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>> feat = th.ones(6, 10)\n    >>> conv = TWIRLSConv(10, 2, 128, prop_step = 64)\n    >>> res = conv(g , feat)\n    >>> res.size()\n    torch.Size([6, 2])\n    \"\"\"\n\n    def __init__(\n        self,\n        input_d,\n        output_d,\n        hidden_d,\n        prop_step,\n        num_mlp_before=1,\n        num_mlp_after=1,\n        norm=\"none\",\n        precond=True,\n        alp=0,\n        lam=1,\n        attention=False,\n        tau=0.2,\n        T=-1,\n        p=1,\n        use_eta=False,\n        attn_bef=False,\n        dropout=0.0,\n        attn_dropout=0.0,\n        inp_dropout=0.0,\n    ):\n        super().__init__()\n        self.input_d = input_d\n        self.output_d = output_d\n        self.hidden_d = hidden_d\n        self.prop_step = prop_step\n        self.num_mlp_before = num_mlp_before\n        self.num_mlp_after = num_mlp_after\n        self.norm = norm\n        self.precond = precond\n        self.attention = attention\n        self.alp = alp\n        self.lam = lam\n        self.tau = tau\n        self.T = T\n        self.p = p\n        self.use_eta = use_eta\n        self.init_att = attn_bef\n        self.dropout = dropout\n        self.attn_dropout = attn_dropout\n        self.inp_dropout = inp_dropout\n\n        # ----- initialization of some variables -----\n        # where to 
put attention\n        self.attn_aft = prop_step // 2 if attention else -1\n\n        # whether we can cache unfolding result\n        self.cacheable = (\n            (not self.attention)\n            and self.num_mlp_before == 0\n            and self.inp_dropout <= 0\n        )\n        if self.cacheable:\n            self.cached_unfolding = None\n\n        # if only one layer, then no hidden size\n        self.size_bef_unf = self.hidden_d\n        self.size_aft_unf = self.hidden_d\n        if self.num_mlp_before == 0:\n            self.size_aft_unf = self.input_d  # as the input  of mlp_aft\n        if self.num_mlp_after == 0:\n            self.size_bef_unf = self.output_d  # as the output of mlp_bef\n\n        # ----- computational modules -----\n        self.mlp_bef = MLP(\n            self.input_d,\n            self.hidden_d,\n            self.size_bef_unf,\n            self.num_mlp_before,\n            self.dropout,\n            self.norm,\n            init_activate=False,\n        )\n\n        self.unfolding = TWIRLSUnfoldingAndAttention(\n            self.hidden_d,\n            self.alp,\n            self.lam,\n            self.prop_step,\n            self.attn_aft,\n            self.tau,\n            self.T,\n            self.p,\n            self.use_eta,\n            self.init_att,\n            self.attn_dropout,\n            self.precond,\n        )\n\n        # if there are really transformations before unfolding, then do init_activate in mlp_aft\n        self.mlp_aft = MLP(\n            self.size_aft_unf,\n            self.hidden_d,\n            self.output_d,\n            self.num_mlp_after,\n            self.dropout,\n            self.norm,\n            init_activate=(self.num_mlp_before > 0)\n            and (self.num_mlp_after > 0),\n        )\n\n    def forward(self, graph, feat):\n        r\"\"\"\n\n        Description\n        -----------\n        Run TWIRLS forward.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n           
 The graph.\n        feat : torch.Tensor\n            The initial node features.\n        Returns\n        -------\n        torch.Tensor\n            The output feature\n\n        Note\n        ----\n        * Input shape: :math:`(N, \\text{input_d})` where :math:`N` is the number of nodes.\n        * Output shape: :math:`(N, \\text{output_d})`.\n        \"\"\"\n\n        # ensure self loop\n        graph = graph.remove_self_loop()\n        graph = graph.add_self_loop()\n\n        x = feat\n\n        if self.cacheable:\n            # to cache unfolding result because there are no parameters before it\n            if self.cached_unfolding is None:\n                self.cached_unfolding = self.unfolding(graph, x)\n\n            x = self.cached_unfolding\n        else:\n            if self.inp_dropout > 0:\n                x = F.dropout(x, self.inp_dropout, training=self.training)\n            x = self.mlp_bef(x)\n            x = self.unfolding(graph, x)\n\n        x = self.mlp_aft(x)\n\n        return x\n\n\nclass Propagate(nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    The propagation method which is with pre-conditioning and reparameterizing. Correspond to\n    eq.28 in the paper.\n\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def _prop(self, graph, Y, lam):\n        \"\"\"propagation part.\"\"\"\n        Y = D_power_bias_X(graph, Y, -0.5, lam, 1 - lam)\n        Y = AX(graph, Y)\n        Y = D_power_bias_X(graph, Y, -0.5, lam, 1 - lam)\n\n        return Y\n\n    def forward(self, graph, Y, X, alp, lam):\n        r\"\"\"\n\n        Description\n        -----------\n        Propagation forward.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        Y : torch.Tensor\n            The feature under propagation. Corresponds to :math:`Z^{(k)}` in eq.28 in the paper.\n        X : torch.Tensor\n            The original feature. 
Corresponds to :math:`Z^{(0)}` in eq.28 in the paper.\n        alp : float\n            The step size. Corresponds to :math:`\\alpha` in the paper.\n        lam : torch.Tensor\n            The coefficient of smoothing term. Corresponds to :math:`\\lambda` in the paper.\n        Returns\n        -------\n        torch.Tensor\n            Propagated feature. :math:`Z^{(k+1)}` in eq.28 in the paper.\n        \"\"\"\n\n        return (\n            (1 - alp) * Y\n            + alp * lam * self._prop(graph, Y, lam)\n            + alp * D_power_bias_X(graph, X, -1, lam, 1 - lam)\n        )\n\n\nclass PropagateNoPrecond(nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    The propagation method which is without pre-conditioning and reparameterizing and using\n    normalized laplacian.\n    Correspond to eq.30 in the paper.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, graph, Y, X, alp, lam):\n        r\"\"\"\n\n        Description\n        -----------\n        Propagation forward.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        Y : torch.Tensor\n            The feature under propagation. Corresponds to :math:`Y^{(k)}` in eq.30 in the paper.\n        X : torch.Tensor\n            The original feature. Corresponds to :math:`Y^{(0)}` in eq.30 in the paper.\n        alp : float\n            The step size. Corresponds to :math:`\\alpha` in the paper.\n        lam : torch.Tensor\n            The coefficient of smoothing term. Corresponds to :math:`\\lambda` in the paper.\n        Returns\n        -------\n        torch.Tensor\n            Propagated feature. :math:`Y^{(k+1)}` in eq.30 in the paper.\n        \"\"\"\n\n        return (\n            (1 - alp * lam - alp) * Y\n            + alp * lam * normalized_AX(graph, Y)\n            + alp * X\n        )\n\n\nclass Attention(nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    The attention function. 
Correspond to :math:`s` in eq.27 in the paper.\n\n    Parameters\n    ----------\n    tau : float\n        The lower thresholding parameter. Correspond to :math:`\\tau` in the paper.\n    T : float\n        The upper thresholding parameter. Correspond to :math:`T` in the paper.\n    p : float\n        Correspond to :math:`\\rho` in the paper.\n    attn_dropout : float\n        The dropout rate of attention value. Default: ``0.0``.\n\n    Returns\n    -------\n    torch.Tensor\n        The output feature\n    \"\"\"\n\n    def __init__(self, tau, T, p, attn_dropout=0.0):\n        super().__init__()\n\n        self.tau = tau\n        self.T = T\n        self.p = p\n        self.attn_dropout = attn_dropout\n\n    def reweighting(self, graph):\n        \"\"\"Compute graph edge weight. Would be stored in ``graph.edata['w']``\"\"\"\n\n        w = graph.edata[\"w\"]\n\n        # It is not activation here but to ensure w > 0.\n        # w can be < 0 here because of some precision issue in dgl, which causes NaN afterwards.\n        w = F.relu(w) + 1e-7\n\n        w = tc.pow(w, 1 - 0.5 * self.p)\n\n        w[(w < self.tau)] = self.tau\n        if self.T > 0:\n            w[(w > self.T)] = float(\"inf\")\n\n        w = 1 / w\n\n        # if not (w == w).all():\n        #     raise \"nan occured!\"\n\n        graph.edata[\"w\"] = w + 1e-9  # avoid 0 degree\n\n    def forward(self, graph, Y, etas=None):\n        r\"\"\"\n\n        Description\n        -----------\n        Attention forward. Will update ``graph.edata['w']`` and ``graph.ndata['deg']``.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        Y : torch.Tensor\n            The feature to compute attention.\n        etas : torch.Tensor\n            The weight of each dimension. 
If ``None``, then weight of each dimension is 1.\n            Default: ``None``.\n\n        Returns\n        -------\n        DGLGraph\n            The graph.\n        \"\"\"\n\n        if etas is not None:\n            Y = Y * etas.view(-1)\n\n        # computing edge distance\n        graph.srcdata[\"h\"] = Y\n        graph.srcdata[\"h_norm\"] = (Y**2).sum(-1)\n        graph.apply_edges(fn.u_dot_v(\"h\", \"h\", \"dot_\"))\n        graph.apply_edges(fn.u_add_v(\"h_norm\", \"h_norm\", \"norm_\"))\n        graph.edata[\"dot_\"] = graph.edata[\"dot_\"].view(-1)\n        graph.edata[\"norm_\"] = graph.edata[\"norm_\"].view(-1)\n        graph.edata[\"w\"] = graph.edata[\"norm_\"] - 2 * graph.edata[\"dot_\"]\n\n        # apply edge distance to get edge weight\n        self.reweighting(graph)\n\n        # update node degrees\n        graph.update_all(fn.copy_e(\"w\", \"m\"), fn.sum(\"m\", \"deg\"))\n        graph.ndata[\"deg\"] = graph.ndata[\"deg\"].view(-1)\n\n        # attention dropout. the implementation can ensure the degrees do not change in expectation.\n        # FIXME: consider if there is a better way\n        if self.attn_dropout > 0:\n            graph.edata[\"w\"] = F.dropout(\n                graph.edata[\"w\"], self.attn_dropout, training=self.training\n            )\n\n        return graph\n\n\ndef normalized_AX(graph, X):\n    \"\"\"Y = D^{-1/2}AD^{-1/2}X\"\"\"\n\n    Y = D_power_X(graph, X, -0.5)  # Y = D^{-1/2}X\n    Y = AX(graph, Y)  # Y = AD^{-1/2}X\n    Y = D_power_X(graph, Y, -0.5)  # Y = D^{-1/2}AD^{-1/2}X\n\n    return Y\n\n\ndef AX(graph, X):\n    \"\"\"Y = AX\"\"\"\n\n    graph.srcdata[\"h\"] = X\n    graph.update_all(\n        fn.u_mul_e(\"h\", \"w\", \"m\"),\n        fn.sum(\"m\", \"h\"),\n    )\n    Y = graph.dstdata[\"h\"]\n\n    return Y\n\n\ndef D_power_X(graph, X, power):\n    \"\"\"Y = D^{power}X\"\"\"\n\n    degs = graph.ndata[\"deg\"]\n    norm = tc.pow(degs, power)\n    Y = X * norm.view(X.size(0), 1)\n    return Y\n\n\ndef 
D_power_bias_X(graph, X, power, coeff, bias):\n    \"\"\"Y = (coeff*D + bias*I)^{power} X\"\"\"\n    degs = graph.ndata[\"deg\"]\n    degs = coeff * degs + bias\n    norm = tc.pow(degs, power)\n    Y = X * norm.view(X.size(0), 1)\n    return Y\n\n\nclass TWIRLSUnfoldingAndAttention(nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    Combine propagation and attention together.\n\n    Parameters\n    ----------\n    d : int\n        Size of graph feature.\n    alp : float\n        Step size. :math:`\\alpha` in the paper.\n    lam : int\n        Coefficient of graph smooth term. :math:`\\lambda` in the paper.\n    prop_step : int\n        Number of propagation steps.\n    attn_aft : int\n        Where to put attention layer. i.e. number of propagation steps before attention.\n        If set to ``-1``, then no attention.\n    tau : float\n        The lower thresholding parameter. Correspond to :math:`\\tau` in the paper.\n    T : float\n        The upper thresholding parameter. Correspond to :math:`T` in the paper.\n    p : float\n        Correspond to :math:`\\rho` in the paper.\n    use_eta : bool\n        If ``True``, learn a weight vector for each dimension when doing attention.\n    init_att : bool\n        If ``True``, add an extra attention layer before propagation.\n    attn_dropout : float\n        The dropout rate of attention value. 
Default: ``0.0``.\n    precond : bool\n        If ``True``, use pre-conditioned & reparameterized version propagation (eq.28), else use\n        normalized laplacian (eq.30).\n\n    Example\n    -------\n    >>> import dgl\n    >>> from dgl.nn import TWIRLSUnfoldingAndAttention\n    >>> import torch as th\n\n    >>> g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3])).add_self_loop()\n    >>> feat = th.ones(6,5)\n    >>> prop = TWIRLSUnfoldingAndAttention(10, 1, 1, prop_step=3)\n    >>> res = prop(g,feat)\n    >>> res\n    tensor([[2.5000, 2.5000, 2.5000, 2.5000, 2.5000],\n            [2.5000, 2.5000, 2.5000, 2.5000, 2.5000],\n            [2.5000, 2.5000, 2.5000, 2.5000, 2.5000],\n            [3.7656, 3.7656, 3.7656, 3.7656, 3.7656],\n            [2.5217, 2.5217, 2.5217, 2.5217, 2.5217],\n            [4.0000, 4.0000, 4.0000, 4.0000, 4.0000]])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        d,\n        alp,\n        lam,\n        prop_step,\n        attn_aft=-1,\n        tau=0.2,\n        T=-1,\n        p=1,\n        use_eta=False,\n        init_att=False,\n        attn_dropout=0,\n        precond=True,\n    ):\n        super().__init__()\n\n        self.d = d\n        self.alp = alp if alp > 0 else 1 / (lam + 1)  # automatic set alpha\n        self.lam = lam\n        self.tau = tau\n        self.p = p\n        self.prop_step = prop_step\n        self.attn_aft = attn_aft\n        self.use_eta = use_eta\n        self.init_att = init_att\n\n        prop_method = Propagate if precond else PropagateNoPrecond\n        self.prop_layers = nn.ModuleList(\n            [prop_method() for _ in range(prop_step)]\n        )\n\n        self.init_attn = (\n            Attention(tau, T, p, attn_dropout) if self.init_att else None\n        )\n        self.attn_layer = (\n            Attention(tau, T, p, attn_dropout) if self.attn_aft >= 0 else None\n        )\n        self.etas = nn.Parameter(tc.ones(d)) if self.use_eta else None\n\n    def forward(self, g, X):\n       
 r\"\"\"\n\n        Description\n        -----------\n        Compute forward pass of propagation & attention.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        X : torch.Tensor\n            Init features.\n\n        Returns\n        -------\n        torch.Tensor\n            The graph.\n        \"\"\"\n        Y = X\n\n        g.edata[\"w\"] = tc.ones(g.num_edges(), 1, device=g.device)\n        g.ndata[\"deg\"] = g.in_degrees().to(X)\n\n        if self.init_att:\n            g = self.init_attn(g, Y, self.etas)\n\n        for k, layer in enumerate(self.prop_layers):\n            # do unfolding\n            Y = layer(g, Y, X, self.alp, self.lam)\n\n            # do attention at certain layer\n            if k == self.attn_aft - 1:\n                g = self.attn_layer(g, Y, self.etas)\n\n        return Y\n\n\nclass MLP(nn.Module):\n    r\"\"\"\n\n    Description\n    -----------\n    An MLP module.\n\n    Parameters\n    ----------\n    input_d : int\n        Number of input features.\n    output_d : int\n        Number of output features.\n    hidden_d : int\n        Size of hidden layers.\n    num_layers : int\n        Number of mlp layers.\n    dropout : float\n        The dropout rate in mlp layers.\n    norm : str\n        The type of norm layers inside mlp layers. 
Can be ``'batch'``, ``'layer'`` or ``'none'``.\n    init_activate : bool\n        If add a relu at the beginning.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        input_d,\n        hidden_d,\n        output_d,\n        num_layers,\n        dropout,\n        norm,\n        init_activate,\n    ):\n        super().__init__()\n\n        self.init_activate = init_activate\n        self.norm = norm\n        self.dropout = dropout\n\n        self.layers = nn.ModuleList([])\n\n        if num_layers == 1:\n            self.layers.append(nn.Linear(input_d, output_d))\n        elif num_layers > 1:\n            self.layers.append(nn.Linear(input_d, hidden_d))\n            for _ in range(num_layers - 2):\n                self.layers.append(nn.Linear(hidden_d, hidden_d))\n            self.layers.append(nn.Linear(hidden_d, output_d))\n\n        # how many norm layers we have\n        self.norm_cnt = num_layers - 1 + int(init_activate)\n        if norm == \"batch\":\n            self.norms = nn.ModuleList(\n                [nn.BatchNorm1d(hidden_d) for _ in range(self.norm_cnt)]\n            )\n        elif norm == \"layer\":\n            self.norms = nn.ModuleList(\n                [nn.LayerNorm(hidden_d) for _ in range(self.norm_cnt)]\n            )\n\n        self.reset_params()\n\n    def reset_params(self):\n        \"\"\"reset mlp parameters using xavier_norm\"\"\"\n        for layer in self.layers:\n            nn.init.xavier_normal_(layer.weight.data)\n            nn.init.constant_(layer.bias.data, 0)\n\n    def activate(self, x):\n        \"\"\"do normlaization and activation\"\"\"\n        if self.norm != \"none\":\n            x = self.norms[self.cur_norm_idx](x)  # use the last norm layer\n            self.cur_norm_idx += 1\n        x = F.relu(x)\n        x = F.dropout(x, self.dropout, training=self.training)\n        return x\n\n    def forward(self, x):\n        \"\"\"The forward pass of mlp.\"\"\"\n        self.cur_norm_idx = 0\n\n        if 
self.init_activate:\n            x = self.activate(x)\n\n        for i, layer in enumerate(self.layers):\n            x = layer(x)\n            if i != len(self.layers) - 1:  # do not activate in the last layer\n                x = self.activate(x)\n\n        return x\n"
  },
  {
    "path": "python/dgl/nn/pytorch/explain/__init__.py",
    "content": "\"\"\"Torch modules for explanation models.\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\n\nfrom .gnnexplainer import *\nfrom .subgraphx import *\nfrom .pgexplainer import *\n"
  },
  {
    "path": "python/dgl/nn/pytorch/explain/gnnexplainer.py",
    "content": "\"\"\"Torch Module for GNNExplainer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nfrom math import sqrt\n\nimport torch\n\nfrom torch import nn\nfrom tqdm.auto import tqdm\n\nfrom ....base import EID, NID\nfrom ....subgraph import khop_in_subgraph\n\n__all__ = [\"GNNExplainer\", \"HeteroGNNExplainer\"]\n\n\nclass GNNExplainer(nn.Module):\n    r\"\"\"GNNExplainer model from `GNNExplainer: Generating Explanations for\n    Graph Neural Networks <https://arxiv.org/abs/1903.03894>`__\n\n    It identifies compact subgraph structures and small subsets of node features that play a\n    critical role in GNN-based node classification and graph classification.\n\n    To generate an explanation, it learns an edge mask :math:`M` and a feature mask :math:`F`\n    by optimizing the following objective function.\n\n    .. math::\n      l(y, \\hat{y}) + \\alpha_1 \\|M\\|_1 + \\alpha_2 H(M) + \\beta_1 \\|F\\|_1 + \\beta_2 H(F)\n\n    where :math:`l` is the loss function, :math:`y` is the original model prediction,\n    :math:`\\hat{y}` is the model prediction with the edge and feature mask applied, :math:`H` is\n    the entropy function.\n\n    Parameters\n    ----------\n    model : nn.Module\n        The GNN model to explain.\n\n        * The required arguments of its forward function are graph and feat.\n          The latter one is for input node features.\n        * It should also optionally take an eweight argument for edge weights\n          and multiply the messages by it in message passing.\n        * The output of its forward function is the logits for the predicted\n          node/graph classes.\n\n        See also the example in :func:`explain_node` and :func:`explain_graph`.\n    num_hops : int\n        The number of hops for GNN information aggregation.\n    lr : float, optional\n        The learning rate to use, default to 0.01.\n    num_epochs : int, optional\n        The number of epochs to train.\n    alpha1 : float, optional\n 
       A higher value will make the explanation edge masks more sparse by decreasing\n        the sum of the edge mask.\n    alpha2 : float, optional\n        A higher value will make the explanation edge masks more sparse by decreasing\n        the entropy of the edge mask.\n    beta1 : float, optional\n        A higher value will make the explanation node feature masks more sparse by\n        decreasing the mean of the node feature mask.\n    beta2 : float, optional\n        A higher value will make the explanation node feature masks more sparse by\n        decreasing the entropy of the node feature mask.\n    log : bool, optional\n        If True, it will log the computation process, default to True.\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        num_hops,\n        lr=0.01,\n        num_epochs=100,\n        *,\n        alpha1=0.005,\n        alpha2=1.0,\n        beta1=1.0,\n        beta2=0.1,\n        log=True,\n    ):\n        super(GNNExplainer, self).__init__()\n        self.model = model\n        self.num_hops = num_hops\n        self.lr = lr\n        self.num_epochs = num_epochs\n        self.alpha1 = alpha1\n        self.alpha2 = alpha2\n        self.beta1 = beta1\n        self.beta2 = beta2\n        self.log = log\n\n    def _init_masks(self, graph, feat):\n        r\"\"\"Initialize learnable feature and edge mask.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            Input graph.\n        feat : Tensor\n            Input node features.\n\n        Returns\n        -------\n        feat_mask : Tensor\n            Feature mask of shape :math:`(1, D)`, where :math:`D`\n            is the feature size.\n        edge_mask : Tensor\n            Edge mask of shape :math:`(E)`, where :math:`E` is the\n            number of edges.\n        \"\"\"\n        num_nodes, feat_size = feat.size()\n        num_edges = graph.num_edges()\n        device = feat.device\n\n        std = 0.1\n        feat_mask = 
nn.Parameter(torch.randn(1, feat_size, device=device) * std)\n\n        std = nn.init.calculate_gain(\"relu\") * sqrt(2.0 / (2 * num_nodes))\n        edge_mask = nn.Parameter(torch.randn(num_edges, device=device) * std)\n\n        return feat_mask, edge_mask\n\n    def _loss_regularize(self, loss, feat_mask, edge_mask):\n        r\"\"\"Add regularization terms to the loss.\n\n        Parameters\n        ----------\n        loss : Tensor\n            Loss value.\n        feat_mask : Tensor\n            Feature mask of shape :math:`(1, D)`, where :math:`D`\n            is the feature size.\n        edge_mask : Tensor\n            Edge mask of shape :math:`(E)`, where :math:`E`\n            is the number of edges.\n\n        Returns\n        -------\n        Tensor\n            Loss value with regularization terms added.\n        \"\"\"\n        # epsilon for numerical stability\n        eps = 1e-15\n\n        edge_mask = edge_mask.sigmoid()\n        # Edge mask sparsity regularization\n        loss = loss + self.alpha1 * torch.sum(edge_mask)\n        # Edge mask entropy regularization\n        ent = -edge_mask * torch.log(edge_mask + eps) - (\n            1 - edge_mask\n        ) * torch.log(1 - edge_mask + eps)\n        loss = loss + self.alpha2 * ent.mean()\n\n        feat_mask = feat_mask.sigmoid()\n        # Feature mask sparsity regularization\n        loss = loss + self.beta1 * torch.mean(feat_mask)\n        # Feature mask entropy regularization\n        ent = -feat_mask * torch.log(feat_mask + eps) - (\n            1 - feat_mask\n        ) * torch.log(1 - feat_mask + eps)\n        loss = loss + self.beta2 * ent.mean()\n\n        return loss\n\n    def explain_node(self, node_id, graph, feat, **kwargs):\n        r\"\"\"Learn and return a node feature mask and subgraph that play a\n        crucial role to explain the prediction made by the GNN for node\n        :attr:`node_id`.\n\n        Parameters\n        ----------\n        node_id : int\n            The 
node to explain.\n        graph : DGLGraph\n            A homogeneous graph.\n        feat : Tensor\n            The input feature of shape :math:`(N, D)`. :math:`N` is the\n            number of nodes, and :math:`D` is the feature size.\n        kwargs : dict\n            Additional arguments passed to the GNN model. Tensors whose\n            first dimension is the number of nodes or edges will be\n            assumed to be node/edge features.\n\n        Returns\n        -------\n        new_node_id : Tensor\n            The new ID of the input center node.\n        sg : DGLGraph\n            The subgraph induced on the k-hop in-neighborhood of the input center node.\n        feat_mask : Tensor\n            Learned node feature importance mask of shape :math:`(D)`, where :math:`D` is the\n            feature size. The values are within range :math:`(0, 1)`.\n            The higher, the more important.\n        edge_mask : Tensor\n            Learned importance mask of the edges in the subgraph, which is a tensor\n            of shape :math:`(E)`, where :math:`E` is the number of edges in the\n            subgraph. The values are within range :math:`(0, 1)`.\n            The higher, the more important.\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import dgl.function as fn\n        >>> import torch\n        >>> import torch.nn as nn\n        >>> from dgl.data import CoraGraphDataset\n        >>> from dgl.nn import GNNExplainer\n\n        >>> # Load dataset\n        >>> data = CoraGraphDataset()\n        >>> g = data[0]\n        >>> features = g.ndata['feat']\n        >>> labels = g.ndata['label']\n        >>> train_mask = g.ndata['train_mask']\n\n        >>> # Define a model\n        >>> class Model(nn.Module):\n        ...     def __init__(self, in_feats, out_feats):\n        ...         super(Model, self).__init__()\n        ...         self.linear = nn.Linear(in_feats, out_feats)\n        ...\n        ...     
def forward(self, graph, feat, eweight=None):\n        ...         with graph.local_scope():\n        ...             feat = self.linear(feat)\n        ...             graph.ndata['h'] = feat\n        ...             if eweight is None:\n        ...                 graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n        ...             else:\n        ...                 graph.edata['w'] = eweight\n        ...                 graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))\n        ...             return graph.ndata['h']\n\n        >>> # Train the model\n        >>> model = Model(features.shape[1], data.num_classes)\n        >>> criterion = nn.CrossEntropyLoss()\n        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n        >>> for epoch in range(10):\n        ...     logits = model(g, features)\n        ...     loss = criterion(logits[train_mask], labels[train_mask])\n        ...     optimizer.zero_grad()\n        ...     loss.backward()\n        ...     
optimizer.step()\n\n        >>> # Explain the prediction for node 10\n        >>> explainer = GNNExplainer(model, num_hops=1)\n        >>> new_center, sg, feat_mask, edge_mask = explainer.explain_node(10, g, features)\n        >>> new_center\n        tensor([1])\n        >>> sg.num_edges()\n        12\n        >>> # Old IDs of the nodes in the subgraph\n        >>> sg.ndata[dgl.NID]\n        tensor([ 9, 10, 11, 12])\n        >>> # Old IDs of the edges in the subgraph\n        >>> sg.edata[dgl.EID]\n        tensor([51, 53, 56, 48, 52, 57, 47, 50, 55, 46, 49, 54])\n        >>> feat_mask\n        tensor([0.2638, 0.2738, 0.3039,  ..., 0.2794, 0.2643, 0.2733])\n        >>> edge_mask\n        tensor([0.0937, 0.1496, 0.8287, 0.8132, 0.8825, 0.8515, 0.8146, 0.0915, 0.1145,\n                0.9011, 0.1311, 0.8437])\n        \"\"\"\n        self.model = self.model.to(graph.device)\n        self.model.eval()\n        num_nodes = graph.num_nodes()\n        num_edges = graph.num_edges()\n\n        # Extract node-centered k-hop subgraph and\n        # its associated node and edge features.\n        sg, inverse_indices = khop_in_subgraph(graph, node_id, self.num_hops)\n        sg_nodes = sg.ndata[NID].long()\n        sg_edges = sg.edata[EID].long()\n        feat = feat[sg_nodes]\n        for key, item in kwargs.items():\n            if torch.is_tensor(item) and item.size(0) == num_nodes:\n                item = item[sg_nodes]\n            elif torch.is_tensor(item) and item.size(0) == num_edges:\n                item = item[sg_edges]\n            kwargs[key] = item\n\n        # Get the initial prediction.\n        with torch.no_grad():\n            logits = self.model(graph=sg, feat=feat, **kwargs)\n            pred_label = logits.argmax(dim=-1)\n\n        feat_mask, edge_mask = self._init_masks(sg, feat)\n\n        params = [feat_mask, edge_mask]\n        optimizer = torch.optim.Adam(params, lr=self.lr)\n\n        if self.log:\n            pbar = tqdm(total=self.num_epochs)\n    
        pbar.set_description(f\"Explain node {node_id}\")\n\n        for _ in range(self.num_epochs):\n            optimizer.zero_grad()\n            h = feat * feat_mask.sigmoid()\n            logits = self.model(\n                graph=sg, feat=h, eweight=edge_mask.sigmoid(), **kwargs\n            )\n            log_probs = logits.log_softmax(dim=-1)\n            loss = -log_probs[inverse_indices, pred_label[inverse_indices]]\n            loss = self._loss_regularize(loss, feat_mask, edge_mask)\n            loss.backward()\n            optimizer.step()\n\n            if self.log:\n                pbar.update(1)\n\n        if self.log:\n            pbar.close()\n\n        feat_mask = feat_mask.detach().sigmoid().squeeze()\n        edge_mask = edge_mask.detach().sigmoid()\n\n        return inverse_indices, sg, feat_mask, edge_mask\n\n    def explain_graph(self, graph, feat, **kwargs):\n        r\"\"\"Learn and return a node feature mask and an edge mask that play a\n        crucial role to explain the prediction made by the GNN for a graph.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            A homogeneous graph.\n        feat : Tensor\n            The input feature of shape :math:`(N, D)`. :math:`N` is the\n            number of nodes, and :math:`D` is the feature size.\n        kwargs : dict\n            Additional arguments passed to the GNN model. Tensors whose\n            first dimension is the number of nodes or edges will be\n            assumed to be node/edge features.\n\n        Returns\n        -------\n        feat_mask : Tensor\n            Learned feature importance mask of shape :math:`(D)`, where :math:`D` is the\n            feature size. 
The values are within range :math:`(0, 1)`.\n            The higher, the more important.\n        edge_mask : Tensor\n            Learned importance mask of the edges in the graph, which is a tensor\n            of shape :math:`(E)`, where :math:`E` is the number of edges in the\n            graph. The values are within range :math:`(0, 1)`. The higher,\n            the more important.\n\n        Examples\n        --------\n\n        >>> import dgl.function as fn\n        >>> import torch\n        >>> import torch.nn as nn\n        >>> from dgl.data import GINDataset\n        >>> from dgl.dataloading import GraphDataLoader\n        >>> from dgl.nn import AvgPooling, GNNExplainer\n\n        >>> # Load dataset\n        >>> data = GINDataset('MUTAG', self_loop=True)\n        >>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)\n\n        >>> # Define a model\n        >>> class Model(nn.Module):\n        ...     def __init__(self, in_feats, out_feats):\n        ...         super(Model, self).__init__()\n        ...         self.linear = nn.Linear(in_feats, out_feats)\n        ...         self.pool = AvgPooling()\n        ...\n        ...     def forward(self, graph, feat, eweight=None):\n        ...         with graph.local_scope():\n        ...             feat = self.linear(feat)\n        ...             graph.ndata['h'] = feat\n        ...             if eweight is None:\n        ...                 graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n        ...             else:\n        ...                 graph.edata['w'] = eweight\n        ...                 graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))\n        ...             
return self.pool(graph, graph.ndata['h'])\n\n        >>> # Train the model\n        >>> feat_size = data[0][0].ndata['attr'].shape[1]\n        >>> model = Model(feat_size, data.gclasses)\n        >>> criterion = nn.CrossEntropyLoss()\n        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n        >>> for bg, labels in dataloader:\n        ...     logits = model(bg, bg.ndata['attr'])\n        ...     loss = criterion(logits, labels)\n        ...     optimizer.zero_grad()\n        ...     loss.backward()\n        ...     optimizer.step()\n\n        >>> # Explain the prediction for graph 0\n        >>> explainer = GNNExplainer(model, num_hops=1)\n        >>> g, _ = data[0]\n        >>> features = g.ndata['attr']\n        >>> feat_mask, edge_mask = explainer.explain_graph(g, features)\n        >>> feat_mask\n        tensor([0.2362, 0.2497, 0.2622, 0.2675, 0.2649, 0.2962, 0.2533])\n        >>> edge_mask\n        tensor([0.2154, 0.2235, 0.8325, ..., 0.7787, 0.1735, 0.1847])\n        \"\"\"\n        self.model = self.model.to(graph.device)\n        self.model.eval()\n\n        # Get the initial prediction.\n        with torch.no_grad():\n            logits = self.model(graph=graph, feat=feat, **kwargs)\n            pred_label = logits.argmax(dim=-1)\n\n        feat_mask, edge_mask = self._init_masks(graph, feat)\n\n        params = [feat_mask, edge_mask]\n        optimizer = torch.optim.Adam(params, lr=self.lr)\n\n        if self.log:\n            pbar = tqdm(total=self.num_epochs)\n            pbar.set_description(\"Explain graph\")\n\n        for _ in range(self.num_epochs):\n            optimizer.zero_grad()\n            h = feat * feat_mask.sigmoid()\n            logits = self.model(\n                graph=graph, feat=h, eweight=edge_mask.sigmoid(), **kwargs\n            )\n            log_probs = logits.log_softmax(dim=-1)\n            loss = -log_probs[0, pred_label[0]]\n            loss = self._loss_regularize(loss, feat_mask, edge_mask)\n           
 loss.backward()\n            optimizer.step()\n\n            if self.log:\n                pbar.update(1)\n\n        if self.log:\n            pbar.close()\n\n        feat_mask = feat_mask.detach().sigmoid().squeeze()\n        edge_mask = edge_mask.detach().sigmoid()\n\n        return feat_mask, edge_mask\n\n\nclass HeteroGNNExplainer(nn.Module):\n    r\"\"\"GNNExplainer model from `GNNExplainer: Generating Explanations for\n    Graph Neural Networks <https://arxiv.org/abs/1903.03894>`__, adapted for heterogeneous graphs\n\n    It identifies compact subgraph structures and small subsets of node features that play a\n    critical role in GNN-based node classification and graph classification.\n\n    To generate an explanation, it learns an edge mask :math:`M` and a feature mask :math:`F`\n    by optimizing the following objective function.\n\n    .. math::\n      l(y, \\hat{y}) + \\alpha_1 \\|M\\|_1 + \\alpha_2 H(M) + \\beta_1 \\|F\\|_1 + \\beta_2 H(F)\n\n    where :math:`l` is the loss function, :math:`y` is the original model prediction,\n    :math:`\\hat{y}` is the model prediction with the edge and feature mask applied, :math:`H` is\n    the entropy function.\n\n    Parameters\n    ----------\n    model : nn.Module\n        The GNN model to explain.\n\n        * The required arguments of its forward function are graph and feat.\n          The latter one is for input node features.\n        * It should also optionally take an eweight argument for edge weights\n          and multiply the messages by it in message passing.\n        * The output of its forward function is the logits for the predicted\n          node/graph classes.\n\n        See also the example in :func:`explain_node` and :func:`explain_graph`.\n    num_hops : int\n        The number of hops for GNN information aggregation.\n    lr : float, optional\n        The learning rate to use, default to 0.01.\n    num_epochs : int, optional\n        The number of epochs to train.\n    alpha1 : float, 
optional\n        A higher value will make the explanation edge masks more sparse by decreasing\n        the sum of the edge mask.\n    alpha2 : float, optional\n        A higher value will make the explanation edge masks more sparse by decreasing\n        the entropy of the edge mask.\n    beta1 : float, optional\n        A higher value will make the explanation node feature masks more sparse by\n        decreasing the mean of the node feature mask.\n    beta2 : float, optional\n        A higher value will make the explanation node feature masks more sparse by\n        decreasing the entropy of the node feature mask.\n    log : bool, optional\n        If True, it will log the computation process, default to True.\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        num_hops,\n        lr=0.01,\n        num_epochs=100,\n        *,\n        alpha1=0.005,\n        alpha2=1.0,\n        beta1=1.0,\n        beta2=0.1,\n        log=True,\n    ):\n        super(HeteroGNNExplainer, self).__init__()\n        self.model = model\n        self.num_hops = num_hops\n        self.lr = lr\n        self.num_epochs = num_epochs\n        self.alpha1 = alpha1\n        self.alpha2 = alpha2\n        self.beta1 = beta1\n        self.beta2 = beta2\n        self.log = log\n\n    def _init_masks(self, graph, feat):\n        r\"\"\"Initialize learnable feature and edge mask.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            Input graph.\n        feat : dict[str, Tensor]\n            The dictionary that associates input node features (values) with\n            the respective node types (keys) present in the graph.\n\n        Returns\n        -------\n        feat_masks : dict[str, Tensor]\n            The dictionary that associates the node feature masks (values) with\n            the respective node types (keys). 
The feature masks are of shape :math:`(1, D_t)`,\n            where :math:`D_t` is the feature size for node type :math:`t`.\n        edge_masks : dict[tuple[str], Tensor]\n            The dictionary that associates the edge masks (values) with\n            the respective canonical edge types (keys). The edge masks are of shape :math:`(E_t)`,\n            where :math:`E_t` is the number of edges for canonical edge type :math:`t`.\n        \"\"\"\n        device = graph.device\n        feat_masks = {}\n        std = 0.1\n        for node_type, feature in feat.items():\n            _, feat_size = feature.size()\n            feat_masks[node_type] = nn.Parameter(\n                torch.randn(1, feat_size, device=device) * std\n            )\n\n        edge_masks = {}\n        for canonical_etype in graph.canonical_etypes:\n            src_num_nodes = graph.num_nodes(canonical_etype[0])\n            dst_num_nodes = graph.num_nodes(canonical_etype[-1])\n            num_nodes_sum = src_num_nodes + dst_num_nodes\n            num_edges = graph.num_edges(canonical_etype)\n            std = nn.init.calculate_gain(\"relu\")\n            if num_nodes_sum > 0:\n                std *= sqrt(2.0 / num_nodes_sum)\n            edge_masks[canonical_etype] = nn.Parameter(\n                torch.randn(num_edges, device=device) * std\n            )\n\n        return feat_masks, edge_masks\n\n    def _loss_regularize(self, loss, feat_masks, edge_masks):\n        r\"\"\"Add regularization terms to the loss.\n\n        Parameters\n        ----------\n        loss : Tensor\n            Loss value.\n        feat_masks : dict[str, Tensor]\n            The dictionary that associates the node feature masks (values) with\n            the respective node types (keys). 
The feature masks are of shape :math:`(1, D_t)`,\n            where :math:`D_t` is the feature size for node type :math:`t`.\n        edge_masks : dict[tuple[str], Tensor]\n            The dictionary that associates the edge masks (values) with\n            the respective canonical edge types (keys). The edge masks are of shape :math:`(E_t)`,\n            where :math:`E_t` is the number of edges for canonical edge type :math:`t`.\n\n        Returns\n        -------\n        Tensor\n            Loss value with regularization terms added.\n        \"\"\"\n        # epsilon for numerical stability\n        eps = 1e-15\n\n        for edge_mask in edge_masks.values():\n            edge_mask = edge_mask.sigmoid()\n            # Edge mask sparsity regularization\n            loss = loss + self.alpha1 * torch.sum(edge_mask)\n            # Edge mask entropy regularization\n            ent = -edge_mask * torch.log(edge_mask + eps) - (\n                1 - edge_mask\n            ) * torch.log(1 - edge_mask + eps)\n            loss = loss + self.alpha2 * ent.mean()\n\n        for feat_mask in feat_masks.values():\n            feat_mask = feat_mask.sigmoid()\n            # Feature mask sparsity regularization\n            loss = loss + self.beta1 * torch.mean(feat_mask)\n            # Feature mask entropy regularization\n            ent = -feat_mask * torch.log(feat_mask + eps) - (\n                1 - feat_mask\n            ) * torch.log(1 - feat_mask + eps)\n            loss = loss + self.beta2 * ent.mean()\n\n        return loss\n\n    def explain_node(self, ntype, node_id, graph, feat, **kwargs):\n        r\"\"\"Learn and return node feature masks and a subgraph that play a\n        crucial role to explain the prediction made by the GNN for node\n        :attr:`node_id` of type :attr:`ntype`.\n\n        It requires :attr:`model` to return a dictionary mapping node types to type-specific\n        predictions.\n\n        Parameters\n        ----------\n        ntype : str\n   
         The type of the node to explain. :attr:`model` must be trained to\n            make predictions for this particular node type.\n        node_id : int\n            The ID of the node to explain.\n        graph : DGLGraph\n            A heterogeneous graph.\n        feat : dict[str, Tensor]\n            The dictionary that associates input node features (values) with\n            the respective node types (keys) present in the graph.\n            The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is the\n            number of nodes for node type :math:`t`, and :math:`D_t` is the feature size for\n            node type :math:`t`\n        kwargs : dict\n            Additional arguments passed to the GNN model.\n\n        Returns\n        -------\n        new_node_id : Tensor\n            The new ID of the input center node.\n        sg : DGLGraph\n            The subgraph induced on the k-hop in-neighborhood of the input center node.\n        feat_mask : dict[str, Tensor]\n            The dictionary that associates the learned node feature importance masks (values) with\n            the respective node types (keys). The masks are of shape :math:`(D_t)`, where\n            :math:`D_t` is the node feature size for node type :attr:`t`. The values are within\n            range :math:`(0, 1)`. The higher, the more important.\n        edge_mask : dict[Tuple[str], Tensor]\n            The dictionary that associates the learned edge importance masks (values) with\n            the respective canonical edge types (keys). The masks are of shape :math:`(E_t)`,\n            where :math:`E_t` is the number of edges for canonical edge type :math:`t` in the\n            subgraph. 
The values are within range :math:`(0, 1)`.\n            The higher, the more important.\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import dgl.function as fn\n        >>> import torch as th\n        >>> import torch.nn as nn\n        >>> import torch.nn.functional as F\n        >>> from dgl.nn import HeteroGNNExplainer\n\n        >>> class Model(nn.Module):\n        ...     def __init__(self, in_dim, num_classes, canonical_etypes):\n        ...         super(Model, self).__init__()\n        ...         self.etype_weights = nn.ModuleDict({\n        ...             '_'.join(c_etype): nn.Linear(in_dim, num_classes)\n        ...             for c_etype in canonical_etypes\n        ...         })\n        ...\n        ...     def forward(self, graph, feat, eweight=None):\n        ...         with graph.local_scope():\n        ...             c_etype_func_dict = {}\n        ...             for c_etype in graph.canonical_etypes:\n        ...                 src_type, etype, dst_type = c_etype\n        ...                 wh = self.etype_weights['_'.join(c_etype)](feat[src_type])\n        ...                 graph.nodes[src_type].data[f'h_{c_etype}'] = wh\n        ...                 if eweight is None:\n        ...                     c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),\n        ...                         fn.mean('m', 'h'))\n        ...                 else:\n        ...                     graph.edges[c_etype].data['w'] = eweight[c_etype]\n        ...                     c_etype_func_dict[c_etype] = (\n        ...                         fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))\n        ...             graph.multi_update_all(c_etype_func_dict, 'sum')\n        ...             return graph.ndata['h']\n\n        >>> input_dim = 5\n        >>> num_classes = 2\n        >>> g = dgl.heterograph({\n        ...     
('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})\n        >>> g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), input_dim)\n        >>> g.nodes['game'].data['h'] = th.randn(g.num_nodes('game'), input_dim)\n\n        >>> transform = dgl.transforms.AddReverse()\n        >>> g = transform(g)\n\n        >>> # define and train the model\n        >>> model = Model(input_dim, num_classes, g.canonical_etypes)\n        >>> feat = g.ndata['h']\n        >>> optimizer = th.optim.Adam(model.parameters())\n        >>> for epoch in range(10):\n        ...     logits = model(g, feat)['user']\n        ...     loss = F.cross_entropy(logits, th.tensor([1, 1, 1]))\n        ...     optimizer.zero_grad()\n        ...     loss.backward()\n        ...     optimizer.step()\n\n        >>> # Explain the prediction for node 0 of type 'user'\n        >>> explainer = HeteroGNNExplainer(model, num_hops=1)\n        >>> new_center, sg, feat_mask, edge_mask = explainer.explain_node('user', 0, g, feat)\n        >>> new_center\n        tensor([0])\n        >>> sg\n        Graph(num_nodes={'game': 1, 'user': 1},\n              num_edges={('game', 'rev_plays', 'user'): 1, ('user', 'plays', 'game'): 1,\n                         ('user', 'rev_rev_plays', 'game'): 1},\n              metagraph=[('game', 'user', 'rev_plays'), ('user', 'game', 'plays'),\n                         ('user', 'game', 'rev_rev_plays')])\n        >>> feat_mask\n        {'game': tensor([0.2348, 0.2780, 0.2611, 0.2513, 0.2823]),\n         'user': tensor([0.2716, 0.2450, 0.2658, 0.2876, 0.2738])}\n        >>> edge_mask\n        {('game', 'rev_plays', 'user'): tensor([0.0630]),\n         ('user', 'plays', 'game'): tensor([0.1939]),\n         ('user', 'rev_rev_plays', 'game'): tensor([0.9166])}\n        \"\"\"\n        self.model = self.model.to(graph.device)\n        self.model.eval()\n\n        # Extract node-centered k-hop subgraph and\n        # its associated node and edge features.\n        sg, inverse_indices = 
khop_in_subgraph(\n            graph, {ntype: node_id}, self.num_hops\n        )\n        inverse_indices = inverse_indices[ntype]\n        sg_nodes = sg.ndata[NID]\n        sg_feat = {}\n\n        for node_type in sg_nodes.keys():\n            sg_feat[node_type] = feat[node_type][sg_nodes[node_type].long()]\n\n        # Get the initial prediction.\n        with torch.no_grad():\n            logits = self.model(graph=sg, feat=sg_feat, **kwargs)[ntype]\n            pred_label = logits.argmax(dim=-1)\n\n        feat_mask, edge_mask = self._init_masks(sg, sg_feat)\n\n        params = [*feat_mask.values(), *edge_mask.values()]\n        optimizer = torch.optim.Adam(params, lr=self.lr)\n\n        if self.log:\n            pbar = tqdm(total=self.num_epochs)\n            pbar.set_description(f\"Explain node {node_id} with type {ntype}\")\n\n        for _ in range(self.num_epochs):\n            optimizer.zero_grad()\n            h = {}\n            for node_type, sg_node_feat in sg_feat.items():\n                h[node_type] = sg_node_feat * feat_mask[node_type].sigmoid()\n            eweight = {}\n            for canonical_etype, canonical_etype_mask in edge_mask.items():\n                eweight[canonical_etype] = canonical_etype_mask.sigmoid()\n            logits = self.model(graph=sg, feat=h, eweight=eweight, **kwargs)[\n                ntype\n            ]\n            log_probs = logits.log_softmax(dim=-1)\n            loss = -log_probs[inverse_indices, pred_label[inverse_indices]]\n            loss = self._loss_regularize(loss, feat_mask, edge_mask)\n            loss.backward()\n            optimizer.step()\n\n            if self.log:\n                pbar.update(1)\n\n        if self.log:\n            pbar.close()\n\n        for node_type in feat_mask:\n            feat_mask[node_type] = (\n                feat_mask[node_type].detach().sigmoid().squeeze()\n            )\n\n        for canonical_etype in edge_mask:\n            edge_mask[canonical_etype] = (\n        
        edge_mask[canonical_etype].detach().sigmoid()\n            )\n\n        return inverse_indices, sg, feat_mask, edge_mask\n\n    def explain_graph(self, graph, feat, **kwargs):\n        r\"\"\"Learn and return node feature masks and edge masks that play a\n        crucial role to explain the prediction made by the GNN for a graph.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            A heterogeneous graph that will be explained.\n        feat : dict[str, Tensor]\n            The dictionary that associates input node features (values) with\n            the respective node types (keys) present in the graph.\n            The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is the\n            number of nodes for node type :math:`t`, and :math:`D_t` is the feature size for\n            node type :math:`t`\n        kwargs : dict\n            Additional arguments passed to the GNN model.\n\n        Returns\n        -------\n        feat_mask : dict[str, Tensor]\n            The dictionary that associates the learned node feature importance masks (values) with\n            the respective node types (keys). The masks are of shape :math:`(D_t)`, where\n            :math:`D_t` is the node feature size for node type :attr:`t`. The values are within\n            range :math:`(0, 1)`. The higher, the more important.\n        edge_mask : dict[Tuple[str], Tensor]\n            The dictionary that associates the learned edge importance masks (values) with\n            the respective canonical edge types (keys). The masks are of shape :math:`(E_t)`,\n            where :math:`E_t` is the number of edges for canonical edge type :math:`t` in the\n            graph. The values are within range :math:`(0, 1)`. 
The higher, the more important.\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import dgl.function as fn\n        >>> import torch as th\n        >>> import torch.nn as nn\n        >>> import torch.nn.functional as F\n        >>> from dgl.nn import HeteroGNNExplainer\n\n        >>> class Model(nn.Module):\n        ...     def __init__(self, in_dim, num_classes, canonical_etypes):\n        ...         super(Model, self).__init__()\n        ...         self.etype_weights = nn.ModuleDict({\n        ...             '_'.join(c_etype): nn.Linear(in_dim, num_classes)\n        ...             for c_etype in canonical_etypes\n        ...         })\n        ...\n        ...     def forward(self, graph, feat, eweight=None):\n        ...         with graph.local_scope():\n        ...             c_etype_func_dict = {}\n        ...             for c_etype in graph.canonical_etypes:\n        ...                 src_type, etype, dst_type = c_etype\n        ...                 wh = self.etype_weights['_'.join(c_etype)](feat[src_type])\n        ...                 graph.nodes[src_type].data[f'h_{c_etype}'] = wh\n        ...                 if eweight is None:\n        ...                     c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),\n        ...                         fn.mean('m', 'h'))\n        ...                 else:\n        ...                     graph.edges[c_etype].data['w'] = eweight[c_etype]\n        ...                     c_etype_func_dict[c_etype] = (\n        ...                         fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))\n        ...             graph.multi_update_all(c_etype_func_dict, 'sum')\n        ...             hg = 0\n        ...             for ntype in graph.ntypes:\n        ...                 if graph.num_nodes(ntype):\n        ...                     hg = hg + dgl.mean_nodes(graph, 'h', ntype=ntype)\n        ...             
return hg\n\n        >>> input_dim = 5\n        >>> num_classes = 2\n        >>> g = dgl.heterograph({\n        ...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})\n        >>> g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), input_dim)\n        >>> g.nodes['game'].data['h'] = th.randn(g.num_nodes('game'), input_dim)\n\n        >>> transform = dgl.transforms.AddReverse()\n        >>> g = transform(g)\n\n        >>> # define and train the model\n        >>> model = Model(input_dim, num_classes, g.canonical_etypes)\n        >>> feat = g.ndata['h']\n        >>> optimizer = th.optim.Adam(model.parameters())\n        >>> for epoch in range(10):\n        ...     logits = model(g, feat)\n        ...     loss = F.cross_entropy(logits, th.tensor([1]))\n        ...     optimizer.zero_grad()\n        ...     loss.backward()\n        ...     optimizer.step()\n\n        >>> # Explain for the graph\n        >>> explainer = HeteroGNNExplainer(model, num_hops=1)\n        >>> feat_mask, edge_mask = explainer.explain_graph(g, feat)\n        >>> feat_mask\n        {'game': tensor([0.2684, 0.2597, 0.3135, 0.2976, 0.2607]),\n         'user': tensor([0.2216, 0.2908, 0.2644, 0.2738, 0.2663])}\n        >>> edge_mask\n        {('game', 'rev_plays', 'user'): tensor([0.8922, 0.1966, 0.8371, 0.1330]),\n         ('user', 'plays', 'game'): tensor([0.1785, 0.1696, 0.8065, 0.2167])}\n        \"\"\"\n        self.model = self.model.to(graph.device)\n        self.model.eval()\n\n        # Get the initial prediction.\n        with torch.no_grad():\n            logits = self.model(graph=graph, feat=feat, **kwargs)\n            pred_label = logits.argmax(dim=-1)\n\n        feat_mask, edge_mask = self._init_masks(graph, feat)\n\n        params = [*feat_mask.values(), *edge_mask.values()]\n        optimizer = torch.optim.Adam(params, lr=self.lr)\n\n        if self.log:\n            pbar = tqdm(total=self.num_epochs)\n            pbar.set_description(\"Explain graph\")\n\n        
for _ in range(self.num_epochs):\n            optimizer.zero_grad()\n            h = {}\n            for node_type, node_feat in feat.items():\n                h[node_type] = node_feat * feat_mask[node_type].sigmoid()\n            eweight = {}\n            for canonical_etype, canonical_etype_mask in edge_mask.items():\n                eweight[canonical_etype] = canonical_etype_mask.sigmoid()\n            logits = self.model(graph=graph, feat=h, eweight=eweight, **kwargs)\n            log_probs = logits.log_softmax(dim=-1)\n            loss = -log_probs[0, pred_label[0]]\n            loss = self._loss_regularize(loss, feat_mask, edge_mask)\n            loss.backward()\n            optimizer.step()\n\n            if self.log:\n                pbar.update(1)\n\n        if self.log:\n            pbar.close()\n\n        for node_type in feat_mask:\n            feat_mask[node_type] = (\n                feat_mask[node_type].detach().sigmoid().squeeze()\n            )\n\n        for canonical_etype in edge_mask:\n            edge_mask[canonical_etype] = (\n                edge_mask[canonical_etype].detach().sigmoid()\n            )\n\n        return feat_mask, edge_mask\n"
  },
  {
    "path": "python/dgl/nn/pytorch/explain/pgexplainer.py",
    "content": "\"\"\"Torch Module for PGExplainer\"\"\"\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .... import batch, ETYPE, khop_in_subgraph, NID, to_homogeneous\n\n__all__ = [\"PGExplainer\", \"HeteroPGExplainer\"]\n\n\nclass PGExplainer(nn.Module):\n    r\"\"\"PGExplainer from `Parameterized Explainer for Graph Neural Network\n    <https://arxiv.org/pdf/2011.04573>`\n\n    PGExplainer adopts a deep neural network (explanation network) to\n    parameterize the generation process of explanations, which enables it to\n    explain multiple instances collectively. PGExplainer models the underlying\n    structure as edge distributions, from which the explanatory graph is\n    sampled.\n\n    Parameters\n    ----------\n    model : nn.Module\n        The GNN model to explain that tackles multiclass graph classification\n\n        * Its forward function must have the form\n          :attr:`forward(self, graph, nfeat, embed, edge_weight)`.\n        * The output of its forward function is the logits if embed=False else\n          the intermediate node embeddings.\n    num_features : int\n        Node embedding size used by :attr:`model`.\n    num_hops : int, optional\n        The number of hops for GNN information aggregation, which must match the\n        number of message passing layers employed by the GNN to be explained.\n    explain_graph : bool, optional\n        Whether to initialize the model for graph-level or node-level predictions.\n    coff_budget : float, optional\n        Size regularization to constrain the explanation size. Default: 0.01.\n    coff_connect : float, optional\n        Entropy regularization to constrain the connectivity of explanation. Default: 5e-4.\n    sample_bias : float, optional\n        Some members of a population are systematically more likely to be selected\n        in a sample than others. 
Default: 0.0.\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        num_features,\n        num_hops=None,\n        explain_graph=True,\n        coff_budget=0.01,\n        coff_connect=5e-4,\n        sample_bias=0.0,\n    ):\n        super(PGExplainer, self).__init__()\n\n        self.model = model\n        self.graph_explanation = explain_graph\n        # Node explanation requires additional self-embedding data.\n        self.num_features = num_features * (2 if self.graph_explanation else 3)\n        self.num_hops = num_hops\n\n        # training hyperparameters for PGExplainer\n        self.coff_budget = coff_budget\n        self.coff_connect = coff_connect\n        self.sample_bias = sample_bias\n\n        self.init_bias = 0.0\n\n        # Explanation network in PGExplainer\n        self.elayers = nn.Sequential(\n            nn.Linear(self.num_features, 64), nn.ReLU(), nn.Linear(64, 1)\n        )\n\n    def set_masks(self, graph, edge_mask=None):\n        r\"\"\"Set the edge mask that plays a crucial role to explain the\n        prediction made by the GNN for a graph. Initialize learnable edge\n        mask if it is None.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            A homogeneous graph.\n        edge_mask : Tensor, optional\n            Learned importance mask of the edges in the graph, which is a tensor\n            of shape :math:`(E)`, where :math:`E` is the number of edges in the\n            graph. The values are within range :math:`(0, 1)`. The higher,\n            the more important. 
Default: None.\n        \"\"\"\n        if edge_mask is None:\n            num_nodes = graph.num_nodes()\n            num_edges = graph.num_edges()\n\n            init_bias = self.init_bias\n            std = nn.init.calculate_gain(\"relu\") * math.sqrt(\n                2.0 / (2 * num_nodes)\n            )\n            self.edge_mask = torch.randn(num_edges) * std + init_bias\n        else:\n            self.edge_mask = edge_mask\n\n        self.edge_mask = self.edge_mask.to(graph.device)\n\n    def clear_masks(self):\n        r\"\"\"Clear the edge mask that plays a crucial role to explain the\n        prediction made by the GNN for a graph.\n        \"\"\"\n        self.edge_mask = None\n\n    def parameters(self):\n        r\"\"\"\n        Returns an iterator over the `Parameter` objects of the `nn.Linear`\n        layers in the `self.elayers` sequential module. Each `Parameter`\n        object contains the weight and bias parameters of an `nn.Linear`\n        layer, as learned during training.\n\n        Returns\n        -------\n        iterator\n            An iterator over the `Parameter` objects of the `nn.Linear`\n            layers in the `self.elayers` sequential module.\n        \"\"\"\n        return self.elayers.parameters()\n\n    def loss(self, prob, ori_pred):\n        r\"\"\"The loss function that is used to learn the edge\n        distribution.\n\n        Parameters\n        ----------\n        prob: Tensor\n            Tensor contains a set of probabilities for each possible\n            class label of some model for all the batched graphs,\n            which is of shape :math:`(B, L)`, where :math:`L` is the\n            different types of label in the dataset and :math:`B` is\n            the batch size.\n        ori_pred: Tensor\n            Tensor of shape :math:`(B)`, representing the original prediction\n            for the graph, where :math:`B` is the batch size.\n\n        Returns\n        -------\n        Tensor\n            The 
sum of the three loss components (prediction, size and entropy),\n            which is a scalar tensor representing the total loss.\n        \"\"\"\n        target_prob = prob.gather(-1, ori_pred.unsqueeze(-1))\n        # 1e-6 added to prob to avoid taking the logarithm of zero\n        target_prob += 1e-6\n        # mean negative log-likelihood of the original prediction over the batch\n        pred_loss = torch.mean(-torch.log(target_prob))\n\n        # size\n        edge_mask = self.sparse_mask_values\n        if self.coff_budget <= 0:\n            size_loss = self.coff_budget * torch.sum(edge_mask)\n        else:\n            size_loss = self.coff_budget * F.relu(\n                torch.sum(edge_mask) - self.coff_budget\n            )\n\n        # entropy\n        scale = 0.99\n        edge_mask = self.edge_mask * (2 * scale - 1.0) + (1.0 - scale)\n        mask_ent = -edge_mask * torch.log(edge_mask) - (\n            1 - edge_mask\n        ) * torch.log(1 - edge_mask)\n        mask_ent_loss = self.coff_connect * torch.mean(mask_ent)\n\n        loss = pred_loss + size_loss + mask_ent_loss\n        return loss\n\n    def concrete_sample(self, w, beta=1.0, training=True):\n        r\"\"\"Sample from the instantiation of concrete distribution when training.\n\n        Parameters\n        ----------\n        w : Tensor\n            A tensor representing the log of the prior probability of choosing the edges.\n        beta : float, optional\n            Controls the degree of randomness in the output of the sigmoid function.\n        training : bool, optional\n            Randomness is injected during training.\n\n        Returns\n        -------\n        Tensor\n            If training is set to True, the output is a tensor of probabilities that\n            represent the probability of activating the gate for each input element.\n            If training is set to False, the output is also a tensor of probabilities,\n            but they are determined solely by the values of :attr:`w`, 
without adding any\n            random noise.\n        \"\"\"\n        if training:\n            bias = self.sample_bias\n            random_noise = torch.rand(w.size()).to(w.device)\n            random_noise = bias + (1 - 2 * bias) * random_noise\n            gate_inputs = torch.log(random_noise) - torch.log(\n                1.0 - random_noise\n            )\n            gate_inputs = (gate_inputs + w) / beta\n            gate_inputs = torch.sigmoid(gate_inputs)\n        else:\n            gate_inputs = torch.sigmoid(w)\n\n        return gate_inputs\n\n    def train_step(self, graph, feat, temperature, **kwargs):\n        r\"\"\"Compute the loss of the explanation network for graph classification\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            Input batched homogeneous graph.\n        feat : Tensor\n            The input feature of shape :math:`(N, D)`. :math:`N` is the\n            number of nodes, and :math:`D` is the feature size.\n        temperature : float\n            The temperature parameter fed to the sampling procedure.\n        kwargs : dict\n            Additional arguments passed to the GNN model.\n\n        Returns\n        -------\n        Tensor\n            A scalar tensor representing the loss.\n        \"\"\"\n        assert (\n            self.graph_explanation\n        ), '\"explain_graph\" must be True when initializing the module.'\n\n        self.model = self.model.to(graph.device)\n        self.elayers = self.elayers.to(graph.device)\n\n        pred = self.model(graph, feat, embed=False, **kwargs)\n        pred = pred.argmax(-1).data\n\n        prob, _ = self.explain_graph(\n            graph, feat, temperature, training=True, **kwargs\n        )\n\n        loss = self.loss(prob, pred)\n        return loss\n\n    def train_step_node(self, nodes, graph, feat, temperature, **kwargs):\n        r\"\"\"Compute the loss of the explanation network for node classification\n\n        Parameters\n        
----------\n        nodes : int, iterable[int], tensor\n            The nodes from the graph used to train the explanation network,\n            which cannot have any duplicate value.\n        graph : DGLGraph\n            Input homogeneous graph.\n        feat : Tensor\n            The input feature of shape :math:`(N, D)`. :math:`N` is the\n            number of nodes, and :math:`D` is the feature size.\n        temperature : float\n            The temperature parameter fed to the sampling procedure.\n        kwargs : dict\n            Additional arguments passed to the GNN model.\n\n        Returns\n        -------\n        Tensor\n            A scalar tensor representing the loss.\n        \"\"\"\n        assert (\n            not self.graph_explanation\n        ), '\"explain_graph\" must be False when initializing the module.'\n\n        self.model = self.model.to(graph.device)\n        self.elayers = self.elayers.to(graph.device)\n\n        if isinstance(nodes, torch.Tensor):\n            nodes = nodes.tolist()\n        if isinstance(nodes, int):\n            nodes = [nodes]\n\n        prob, _, batched_graph, inverse_indices = self.explain_node(\n            nodes, graph, feat, temperature, training=True, **kwargs\n        )\n\n        pred = self.model(\n            batched_graph, self.batched_feats, embed=False, **kwargs\n        )\n        pred = pred.argmax(-1).data\n\n        loss = self.loss(prob[inverse_indices], pred[inverse_indices])\n        return loss\n\n    def explain_graph(\n        self, graph, feat, temperature=1.0, training=False, **kwargs\n    ):\n        r\"\"\"Learn and return an edge mask that plays a crucial role to\n        explain the prediction made by the GNN for a graph. 
Also, return\n        the prediction made with the edges chosen based on the edge mask.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            A homogeneous graph.\n        feat : Tensor\n            The input feature of shape :math:`(N, D)`. :math:`N` is the\n            number of nodes, and :math:`D` is the feature size.\n        temperature : float\n            The temperature parameter fed to the sampling procedure.\n        training : bool\n            Training the explanation network.\n        kwargs : dict\n            Additional arguments passed to the GNN model.\n\n        Returns\n        -------\n        Tensor\n            Classification probabilities given the masked graph. It is a tensor\n            of shape :math:`(B, L)`, where :math:`L` is the different types of\n            label in the dataset, and :math:`B` is the batch size.\n        Tensor\n            Edge weights which is a tensor of shape :math:`(E)`, where :math:`E`\n            is the number of edges in the graph. A higher weight suggests a\n            larger contribution of the edge.\n\n        Examples\n        --------\n\n        >>> import torch as th\n        >>> import torch.nn as nn\n        >>> import dgl\n        >>> from dgl.data import GINDataset\n        >>> from dgl.dataloading import GraphDataLoader\n        >>> from dgl.nn import GraphConv, PGExplainer\n        >>> import numpy as np\n\n        >>> # Define the model\n        >>> class Model(nn.Module):\n        ...     def __init__(self, in_feats, out_feats):\n        ...         super().__init__()\n        ...         self.conv = GraphConv(in_feats, out_feats)\n        ...         self.fc = nn.Linear(out_feats, out_feats)\n        ...         nn.init.xavier_uniform_(self.fc.weight)\n        ...\n        ...     def forward(self, g, h, embed=False, edge_weight=None):\n        ...         h = self.conv(g, h, edge_weight=edge_weight)\n        ...\n        ...         if embed:\n        ...         
    return h\n        ...\n        ...         with g.local_scope():\n        ...             g.ndata['h'] = h\n        ...             hg = dgl.mean_nodes(g, 'h')\n        ...             return self.fc(hg)\n\n        >>> # Load dataset\n        >>> data = GINDataset('MUTAG', self_loop=True)\n        >>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)\n\n        >>> # Train the model\n        >>> feat_size = data[0][0].ndata['attr'].shape[1]\n        >>> model = Model(feat_size, data.gclasses)\n        >>> criterion = nn.CrossEntropyLoss()\n        >>> optimizer = th.optim.Adam(model.parameters(), lr=1e-2)\n        >>> for bg, labels in dataloader:\n        ...     preds = model(bg, bg.ndata['attr'])\n        ...     loss = criterion(preds, labels)\n        ...     optimizer.zero_grad()\n        ...     loss.backward()\n        ...     optimizer.step()\n\n        >>> # Initialize the explainer\n        >>> explainer = PGExplainer(model, data.gclasses)\n\n        >>> # Train the explainer\n        >>> # Define explainer temperature parameter\n        >>> init_tmp, final_tmp = 5.0, 1.0\n        >>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)\n        >>> for epoch in range(20):\n        ...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))\n        ...     for bg, labels in dataloader:\n        ...          loss = explainer.train_step(bg, bg.ndata['attr'], tmp)\n        ...          optimizer_exp.zero_grad()\n        ...          loss.backward()\n        ...          
optimizer_exp.step()\n\n        >>> # Explain the prediction for graph 0\n        >>> graph, l = data[0]\n        >>> graph_feat = graph.ndata.pop(\"attr\")\n        >>> probs, edge_weight = explainer.explain_graph(graph, graph_feat)\n        \"\"\"\n        assert (\n            self.graph_explanation\n        ), '\"explain_graph\" must be True when initializing the module.'\n\n        self.model = self.model.to(graph.device)\n        self.elayers = self.elayers.to(graph.device)\n\n        embed = self.model(graph, feat, embed=True, **kwargs)\n        embed = embed.data\n\n        col, row = graph.edges()\n        col_emb = embed[col.long()]\n        row_emb = embed[row.long()]\n        emb = torch.cat([col_emb, row_emb], dim=-1)\n        emb = self.elayers(emb)\n        values = emb.reshape(-1)\n\n        values = self.concrete_sample(\n            values, beta=temperature, training=training\n        )\n        self.sparse_mask_values = values\n\n        reverse_eids = graph.edge_ids(row, col).long()\n        edge_mask = (values + values[reverse_eids]) / 2\n\n        self.set_masks(graph, edge_mask)\n\n        # the model prediction with the updated edge mask\n        logits = self.model(graph, feat, edge_weight=self.edge_mask, **kwargs)\n        probs = F.softmax(logits, dim=-1)\n\n        if training:\n            probs = probs.data\n        else:\n            self.clear_masks()\n\n        return (probs, edge_mask)\n\n    def explain_node(\n        self, nodes, graph, feat, temperature=1.0, training=False, **kwargs\n    ):\n        r\"\"\"Learn and return an edge mask that plays a crucial role to\n        explain the prediction made by the GNN for provided set of node IDs.\n        Also, return the prediction made with the graph and edge mask.\n\n        Parameters\n        ----------\n        nodes : int, iterable[int], tensor\n            The nodes from the graph, which cannot have any duplicate value.\n        graph : DGLGraph\n            A homogeneous 
graph.\n        feat : Tensor\n            The input feature of shape :math:`(N, D)`. :math:`N` is the\n            number of nodes, and :math:`D` is the feature size.\n        temperature : float\n            The temperature parameter fed to the sampling procedure.\n        training : bool\n            Training the explanation network.\n        kwargs : dict\n            Additional arguments passed to the GNN model.\n\n        Returns\n        -------\n        Tensor\n            Classification probabilities given the masked graph. It is a tensor\n            of shape :math:`(N, L)`, where :math:`L` is the different types\n            of node labels in the dataset, and :math:`N` is the number of nodes\n            in the graph.\n        Tensor\n            Edge weights which is a tensor of shape :math:`(E)`, where :math:`E`\n            is the number of edges in the graph. A higher weight suggests a\n            larger contribution of the edge.\n        DGLGraph\n            The batched set of subgraphs induced on the k-hop in-neighborhood\n            of the input center nodes.\n        Tensor\n            The new IDs of the subgraph center nodes.\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import numpy as np\n        >>> import torch\n\n        >>> # Define the model\n        >>> class Model(torch.nn.Module):\n        ...     def __init__(self, in_feats, out_feats):\n        ...         super().__init__()\n        ...         self.conv1 = dgl.nn.GraphConv(in_feats, out_feats)\n        ...         self.conv2 = dgl.nn.GraphConv(out_feats, out_feats)\n        ...\n        ...     def forward(self, g, h, embed=False, edge_weight=None):\n        ...         h = self.conv1(g, h, edge_weight=edge_weight)\n        ...         if embed:\n        ...             return h\n        ...         
return self.conv2(g, h)\n\n        >>> # Load dataset\n        >>> data = dgl.data.CoraGraphDataset(verbose=False)\n        >>> g = data[0]\n        >>> features = g.ndata[\"feat\"]\n        >>> labels = g.ndata[\"label\"]\n\n        >>> # Train the model\n        >>> model = Model(features.shape[1], data.num_classes)\n        >>> criterion = torch.nn.CrossEntropyLoss()\n        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n        >>> for epoch in range(20):\n        ...     logits = model(g, features)\n        ...     loss = criterion(logits, labels)\n        ...     optimizer.zero_grad()\n        ...     loss.backward()\n        ...     optimizer.step()\n\n        >>> # Initialize the explainer\n        >>> explainer = dgl.nn.PGExplainer(\n        ...     model, data.num_classes, num_hops=2, explain_graph=False\n        ... )\n\n        >>> # Train the explainer\n        >>> # Define explainer temperature parameter\n        >>> init_tmp, final_tmp = 5.0, 1.0\n        >>> optimizer_exp = torch.optim.Adam(explainer.parameters(), lr=0.01)\n        >>> epochs = 10\n        >>> for epoch in range(epochs):\n        ...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / epochs))\n        ...     loss = explainer.train_step_node(g.nodes(), g, features, tmp)\n        ...     optimizer_exp.zero_grad()\n        ...     loss.backward()\n        ...     optimizer_exp.step()\n\n        >>> # Explain the prediction for node 0\n        >>> probs, edge_weight, bg, inverse_indices = explainer.explain_node(\n        ...     0, g, features\n        ... 
)\n        \"\"\"\n        assert (\n            not self.graph_explanation\n        ), '\"explain_graph\" must be False when initializing the module.'\n        assert (\n            self.num_hops is not None\n        ), '\"num_hops\" must be provided when initializing the module.'\n\n        if isinstance(nodes, torch.Tensor):\n            nodes = nodes.tolist()\n        if isinstance(nodes, int):\n            nodes = [nodes]\n\n        self.model = self.model.to(graph.device)\n        self.elayers = self.elayers.to(graph.device)\n\n        batched_graph = []\n        batched_embed = []\n        for node_id in nodes:\n            sg, inverse_indices = khop_in_subgraph(\n                graph, node_id, self.num_hops\n            )\n            sg.ndata[\"feat\"] = feat[sg.ndata[NID].long()]\n            sg.ndata[\"train\"] = torch.tensor(\n                [nid in inverse_indices for nid in sg.nodes()], device=sg.device\n            )\n\n            embed = self.model(sg, sg.ndata[\"feat\"], embed=True, **kwargs)\n            embed = embed.data\n\n            col, row = sg.edges()\n            col_emb = embed[col.long()]\n            row_emb = embed[row.long()]\n            self_emb = embed[inverse_indices[0]].repeat(sg.num_edges(), 1)\n            emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)\n            batched_embed.append(emb)\n            batched_graph.append(sg)\n\n        batched_graph = batch(batched_graph)\n\n        batched_embed = torch.cat(batched_embed)\n        batched_embed = self.elayers(batched_embed)\n        values = batched_embed.reshape(-1)\n\n        values = self.concrete_sample(\n            values, beta=temperature, training=training\n        )\n        self.sparse_mask_values = values\n\n        col, row = batched_graph.edges()\n        reverse_eids = batched_graph.edge_ids(row, col).long()\n        edge_mask = (values + values[reverse_eids]) / 2\n\n        self.set_masks(batched_graph, edge_mask)\n\n        batched_feats = 
batched_graph.ndata[\"feat\"]\n        # the model prediction with the updated edge mask\n        logits = self.model(\n            batched_graph, batched_feats, edge_weight=self.edge_mask, **kwargs\n        )\n        probs = F.softmax(logits, dim=-1)\n\n        batched_inverse_indices = (\n            batched_graph.ndata[\"train\"].nonzero().squeeze(1)\n        )\n\n        if training:\n            self.batched_feats = batched_feats\n            probs = probs.data\n        else:\n            self.clear_masks()\n\n        return (\n            probs,\n            edge_mask,\n            batched_graph,\n            batched_inverse_indices,\n        )\n\n\nclass HeteroPGExplainer(PGExplainer):\n    r\"\"\"PGExplainer from `Parameterized Explainer for Graph Neural Network\n    <https://arxiv.org/pdf/2011.04573>`__, adapted for heterogeneous graphs\n\n    PGExplainer adopts a deep neural network (explanation network) to\n    parameterize the generation process of explanations, which enables it to\n    explain multiple instances collectively. PGExplainer models the underlying\n    structure as edge distributions, from which the explanatory graph is\n    sampled.\n\n    Parameters\n    ----------\n    model : nn.Module\n        The GNN model to explain that tackles multiclass graph classification\n\n        * Its forward function must have the form\n          :attr:`forward(self, graph, nfeat, embed, edge_weight)`.\n        * The output of its forward function is the logits if embed=False else\n          the intermediate node embeddings.\n    num_features : int\n        Node embedding size used by :attr:`model`.\n    coff_budget : float, optional\n        Size regularization to constrain the explanation size. Default: 0.01.\n    coff_connect : float, optional\n        Entropy regularization to constrain the connectivity of explanation. 
Default: 5e-4.\n    sample_bias : float, optional\n        Some members of a population are systematically more likely to be selected\n        in a sample than others. Default: 0.0.\n    \"\"\"\n\n    def train_step(self, graph, feat, temperature, **kwargs):\n        # pylint: disable=useless-super-delegation\n        r\"\"\"Compute the loss of the explanation network for graph classification\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            Input batched heterogeneous graph.\n        feat : dict[str, Tensor]\n            A dict mapping node types (keys) to feature tensors (values).\n            The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is\n            the number of nodes for node type :math:`t`, and :math:`D_t` is the\n            feature size for node type :math:`t`\n        temperature : float\n            The temperature parameter fed to the sampling procedure.\n        kwargs : dict\n            Additional arguments passed to the GNN model.\n\n        Returns\n        -------\n        Tensor\n            A scalar tensor representing the loss.\n        \"\"\"\n        return super().train_step(graph, feat, temperature, **kwargs)\n\n    def train_step_node(self, nodes, graph, feat, temperature, **kwargs):\n        r\"\"\"Compute the loss of the explanation network for node classification\n\n        Parameters\n        ----------\n        nodes : dict[str, Iterable[int]]\n            A dict mapping node types (keys) to an iterable set of node ids (values).\n        graph : DGLGraph\n            Input heterogeneous graph.\n        feat : dict[str, Tensor]\n            A dict mapping node types (keys) to feature tensors (values).\n            The input features are of shape :math:`(N_t, D_t)`. 
:math:`N_t` is\n            the number of nodes for node type :math:`t`, and :math:`D_t` is the\n            feature size for node type :math:`t`\n        temperature : float\n            The temperature parameter fed to the sampling procedure.\n        kwargs : dict\n            Additional arguments passed to the GNN model.\n\n        Returns\n        -------\n        Tensor\n            A scalar tensor representing the loss.\n        \"\"\"\n        assert (\n            not self.graph_explanation\n        ), '\"explain_graph\" must be False when initializing the module.'\n\n        self.model = self.model.to(graph.device)\n        self.elayers = self.elayers.to(graph.device)\n\n        prob, _, batched_graph, inverse_indices = self.explain_node(\n            nodes, graph, feat, temperature, training=True, **kwargs\n        )\n\n        pred = self.model(\n            batched_graph, self.batched_feats, embed=False, **kwargs\n        )\n        pred = {ntype: pred[ntype].argmax(-1).data for ntype in pred.keys()}\n\n        loss = self.loss(\n            torch.cat(\n                [prob[ntype][nid] for ntype, nid in inverse_indices.items()]\n            ),\n            torch.cat(\n                [pred[ntype][nid] for ntype, nid in inverse_indices.items()]\n            ),\n        )\n        return loss\n\n    def explain_graph(\n        self, graph, feat, temperature=1.0, training=False, **kwargs\n    ):\n        r\"\"\"Learn and return an edge mask that plays a crucial role to\n        explain the prediction made by the GNN for a graph. Also, return\n        the prediction made with the edges chosen based on the edge mask.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            A heterogeneous graph.\n        feat : dict[str, Tensor]\n            A dict mapping node types (keys) to feature tensors (values).\n            The input features are of shape :math:`(N_t, D_t)`. 
:math:`N_t` is\n            the number of nodes for node type :math:`t`, and :math:`D_t` is the\n            feature size for node type :math:`t`\n        temperature : float\n            The temperature parameter fed to the sampling procedure.\n        training : bool\n            Training the explanation network.\n        kwargs : dict\n            Additional arguments passed to the GNN model.\n\n        Returns\n        -------\n        Tensor\n            Classification probabilities given the masked graph. It is a tensor\n            of shape :math:`(B, L)`, where :math:`L` is the different types of\n            label in the dataset, and :math:`B` is the batch size.\n        dict[str, Tensor]\n            A dict mapping edge types (keys) to edge tensors (values) of shape\n            :math:`(E_t)`, where :math:`E_t` is the number of edges in the graph\n            for edge type :math:`t`.  A higher weight suggests a larger\n            contribution of the edge.\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import torch as th\n        >>> import torch.nn as nn\n        >>> import numpy as np\n\n        >>> # Define the model\n        >>> class Model(nn.Module):\n        ...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):\n        ...         super().__init__()\n        ...         self.conv = dgl.nn.HeteroGraphConv(\n        ...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},\n        ...             aggregate=\"sum\",\n        ...         )\n        ...         self.fc = nn.Linear(hid_feats, out_feats)\n        ...         nn.init.xavier_uniform_(self.fc.weight)\n        ...\n        ...     def forward(self, g, h, embed=False, edge_weight=None):\n        ...         if edge_weight:\n        ...             mod_kwargs = {\n        ...                 etype: {\"edge_weight\": mask} for etype, mask in edge_weight.items()\n        ...             }\n        ...             
h = self.conv(g, h, mod_kwargs=mod_kwargs)\n        ...         else:\n        ...             h = self.conv(g, h)\n        ...\n        ...         if embed:\n        ...             return h\n        ...\n        ...         with g.local_scope():\n        ...             g.ndata[\"h\"] = h\n        ...             hg = 0\n        ...             for ntype in g.ntypes:\n        ...                 hg = hg + dgl.mean_nodes(g, \"h\", ntype=ntype)\n        ...             return self.fc(hg)\n\n        >>> # Load dataset\n        >>> input_dim = 5\n        >>> hidden_dim = 5\n        >>> num_classes = 2\n        >>> g = dgl.heterograph({(\"user\", \"plays\", \"game\"): ([0, 1, 1, 2], [0, 0, 1, 1])})\n        >>> g.nodes[\"user\"].data[\"h\"] = th.randn(g.num_nodes(\"user\"), input_dim)\n        >>> g.nodes[\"game\"].data[\"h\"] = th.randn(g.num_nodes(\"game\"), input_dim)\n\n        >>> transform = dgl.transforms.AddReverse()\n        >>> g = transform(g)\n\n        >>> # define and train the model\n        >>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)\n        >>> optimizer = th.optim.Adam(model.parameters())\n        >>> for epoch in range(10):\n        ...     logits = model(g, g.ndata[\"h\"])\n        ...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1]))\n        ...     optimizer.zero_grad()\n        ...     loss.backward()\n        ...     optimizer.step()\n\n        >>> # Initialize the explainer\n        >>> explainer = dgl.nn.HeteroPGExplainer(model, hidden_dim)\n\n        >>> # Train the explainer\n        >>> # Define explainer temperature parameter\n        >>> init_tmp, final_tmp = 5.0, 1.0\n        >>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)\n        >>> for epoch in range(20):\n        ...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))\n        ...     loss = explainer.train_step(g, g.ndata[\"h\"], tmp)\n        ...     optimizer_exp.zero_grad()\n        ...     
loss.backward()\n        ...     optimizer_exp.step()\n\n        >>> # Explain the graph\n        >>> feat = g.ndata.pop(\"h\")\n        >>> probs, edge_mask = explainer.explain_graph(g, feat)\n        \"\"\"\n        assert (\n            self.graph_explanation\n        ), '\"explain_graph\" must be True when initializing the module.'\n\n        self.model = self.model.to(graph.device)\n        self.elayers = self.elayers.to(graph.device)\n\n        embed = self.model(graph, feat, embed=True, **kwargs)\n        for ntype, emb in embed.items():\n            graph.nodes[ntype].data[\"emb\"] = emb.data\n        homo_graph = to_homogeneous(graph, ndata=[\"emb\"])\n        homo_embed = homo_graph.ndata[\"emb\"]\n\n        col, row = homo_graph.edges()\n        col_emb = homo_embed[col.long()]\n        row_emb = homo_embed[row.long()]\n        emb = torch.cat([col_emb, row_emb], dim=-1)\n        emb = self.elayers(emb)\n        values = emb.reshape(-1)\n\n        values = self.concrete_sample(\n            values, beta=temperature, training=training\n        )\n        self.sparse_mask_values = values\n\n        reverse_eids = homo_graph.edge_ids(row, col).long()\n        edge_mask = (values + values[reverse_eids]) / 2\n\n        self.set_masks(homo_graph, edge_mask)\n\n        # convert the edge mask back into heterogeneous format\n        hetero_edge_mask = self._edge_mask_to_heterogeneous(\n            edge_mask=edge_mask,\n            homograph=homo_graph,\n            heterograph=graph,\n        )\n\n        # the model prediction with the updated edge mask\n        logits = self.model(graph, feat, edge_weight=hetero_edge_mask, **kwargs)\n        probs = F.softmax(logits, dim=-1)\n\n        if training:\n            probs = probs.data\n        else:\n            self.clear_masks()\n\n        return (probs, hetero_edge_mask)\n\n    def explain_node(\n        self, nodes, graph, feat, temperature=1.0, training=False, **kwargs\n    ):\n        r\"\"\"Learn and return 
an edge mask that plays a crucial role to\n        explain the prediction made by the GNN for provided set of node IDs.\n        Also, return the prediction made with the batched graph and edge mask.\n\n        Parameters\n        ----------\n        nodes : dict[str, Iterable[int]]\n            A dict mapping node types (keys) to an iterable set of node ids (values).\n        graph : DGLGraph\n            A heterogeneous graph.\n        feat : dict[str, Tensor]\n            A dict mapping node types (keys) to feature tensors (values).\n            The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is\n            the number of nodes for node type :math:`t`, and :math:`D_t` is the\n            feature size for node type :math:`t`\n        temperature : float\n            The temperature parameter fed to the sampling procedure.\n        training : bool\n            Training the explanation network.\n        kwargs : dict\n            Additional arguments passed to the GNN model.\n\n        Returns\n        -------\n        dict[str, Tensor]\n            A dict mapping node types (keys) to classification probabilities\n            for node labels (values). The values are tensors of shape\n            :math:`(N_t, L)`, where :math:`L` is the different types of node\n            labels in the dataset, and :math:`N_t` is the number of nodes in\n            the graph for node type :math:`t`.\n        dict[str, Tensor]\n            A dict mapping edge types (keys) to edge tensors (values) of shape\n            :math:`(E_t)`, where :math:`E_t` is the number of edges in the graph\n            for edge type :math:`t`.  
A higher weight suggests a larger\n            contribution of the edge.\n        DGLGraph\n            The batched set of subgraphs induced on the k-hop in-neighborhood\n            of the input center nodes.\n        dict[str, Tensor]\n            A dict mapping node types (keys) to a tensor of node IDs (values)\n            which correspond to the subgraph center nodes.\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import torch as th\n        >>> import torch.nn as nn\n        >>> import numpy as np\n\n        >>> # Define the model\n        >>> class Model(nn.Module):\n        ...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):\n        ...         super().__init__()\n        ...         self.conv = dgl.nn.HeteroGraphConv(\n        ...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},\n        ...             aggregate=\"sum\",\n        ...         )\n        ...         self.fc = nn.Linear(hid_feats, out_feats)\n        ...         nn.init.xavier_uniform_(self.fc.weight)\n        ...\n        ...     def forward(self, g, h, embed=False, edge_weight=None):\n        ...         if edge_weight:\n        ...             mod_kwargs = {\n        ...                 etype: {\"edge_weight\": mask} for etype, mask in edge_weight.items()\n        ...             }\n        ...             h = self.conv(g, h, mod_kwargs=mod_kwargs)\n        ...         else:\n        ...             h = self.conv(g, h)\n        ...\n        ...         
return h\n\n        >>> # Load dataset\n        >>> input_dim = 5\n        >>> hidden_dim = 5\n        >>> num_classes = 2\n        >>> g = dgl.heterograph({(\"user\", \"plays\", \"game\"): ([0, 1, 1, 2], [0, 0, 1, 1])})\n        >>> g.nodes[\"user\"].data[\"h\"] = th.randn(g.num_nodes(\"user\"), input_dim)\n        >>> g.nodes[\"game\"].data[\"h\"] = th.randn(g.num_nodes(\"game\"), input_dim)\n\n        >>> transform = dgl.transforms.AddReverse()\n        >>> g = transform(g)\n\n        >>> # define and train the model\n        >>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)\n        >>> optimizer = th.optim.Adam(model.parameters())\n        >>> for epoch in range(10):\n        ...     logits = model(g, g.ndata[\"h\"])['user']\n        ...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1,1,1]))\n        ...     optimizer.zero_grad()\n        ...     loss.backward()\n        ...     optimizer.step()\n\n        >>> # Initialize the explainer\n        >>> explainer = dgl.nn.HeteroPGExplainer(\n        ...     model, hidden_dim, num_hops=2, explain_graph=False\n        ... )\n\n        >>> # Train the explainer\n        >>> # Define explainer temperature parameter\n        >>> init_tmp, final_tmp = 5.0, 1.0\n        >>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)\n        >>> for epoch in range(20):\n        ...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))\n        ...     loss = explainer.train_step_node(\n        ...         { ntype: g.nodes(ntype) for ntype in g.ntypes },\n        ...         g, g.ndata[\"h\"], tmp\n        ...     )\n        ...     optimizer_exp.zero_grad()\n        ...     loss.backward()\n        ...     optimizer_exp.step()\n\n        >>> # Explain the graph\n        >>> feat = g.ndata.pop(\"h\")\n        >>> probs, edge_mask, bg, inverse_indices = explainer.explain_node(\n        ...     { \"user\": [0] }, g, feat\n        ... 
)\n        \"\"\"\n        assert (\n            not self.graph_explanation\n        ), '\"explain_graph\" must be False when initializing the module.'\n        assert (\n            self.num_hops is not None\n        ), '\"num_hops\" must be provided when initializing the module.'\n\n        self.model = self.model.to(graph.device)\n        self.elayers = self.elayers.to(graph.device)\n\n        batched_embed = []\n        batched_homo_graph = []\n        batched_hetero_graph = []\n        for target_ntype, target_nids in nodes.items():\n            if isinstance(target_nids, torch.Tensor):\n                target_nids = target_nids.tolist()\n\n            for target_nid in target_nids:\n                sg, inverse_indices = khop_in_subgraph(\n                    graph, {target_ntype: target_nid}, self.num_hops\n                )\n\n                for sg_ntype in sg.ntypes:\n                    sg_feat = feat[sg_ntype][sg.ndata[NID][sg_ntype].long()]\n                    train_mask = [\n                        sg_ntype in inverse_indices\n                        and node_id in inverse_indices[sg_ntype]\n                        for node_id in sg.nodes(sg_ntype)\n                    ]\n\n                    sg.nodes[sg_ntype].data[\"feat\"] = sg_feat\n                    sg.nodes[sg_ntype].data[\"train\"] = torch.tensor(\n                        train_mask, device=sg.device\n                    )\n\n                embed = self.model(sg, sg.ndata[\"feat\"], embed=True, **kwargs)\n                for ntype in embed.keys():\n                    sg.nodes[ntype].data[\"emb\"] = embed[ntype].data\n\n                homo_sg = to_homogeneous(sg, ndata=[\"emb\"])\n                homo_sg_embed = homo_sg.ndata[\"emb\"]\n\n                col, row = homo_sg.edges()\n                col_emb = homo_sg_embed[col.long()]\n                row_emb = homo_sg_embed[row.long()]\n                self_emb = homo_sg_embed[\n                    inverse_indices[target_ntype][0]\n          
      ].repeat(sg.num_edges(), 1)\n                emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)\n                batched_embed.append(emb)\n                batched_homo_graph.append(homo_sg)\n                batched_hetero_graph.append(sg)\n\n        batched_homo_graph = batch(batched_homo_graph)\n        batched_hetero_graph = batch(batched_hetero_graph)\n\n        batched_embed = torch.cat(batched_embed)\n        batched_embed = self.elayers(batched_embed)\n        values = batched_embed.reshape(-1)\n\n        values = self.concrete_sample(\n            values, beta=temperature, training=training\n        )\n        self.sparse_mask_values = values\n\n        col, row = batched_homo_graph.edges()\n        reverse_eids = batched_homo_graph.edge_ids(row, col).long()\n        edge_mask = (values + values[reverse_eids]) / 2\n\n        self.set_masks(batched_homo_graph, edge_mask)\n\n        # Convert the edge mask back into heterogeneous format.\n        hetero_edge_mask = self._edge_mask_to_heterogeneous(\n            edge_mask=edge_mask,\n            homograph=batched_homo_graph,\n            heterograph=batched_hetero_graph,\n        )\n\n        batched_feats = {\n            ntype: batched_hetero_graph.nodes[ntype].data[\"feat\"]\n            for ntype in batched_hetero_graph.ntypes\n        }\n\n        # The model prediction with the updated edge mask.\n        logits = self.model(\n            batched_hetero_graph,\n            batched_feats,\n            edge_weight=hetero_edge_mask,\n            **kwargs,\n        )\n        probs = {\n            ntype: F.softmax(logits[ntype], dim=-1) for ntype in logits.keys()\n        }\n\n        batched_inverse_indices = {\n            ntype: batched_hetero_graph.nodes[ntype]\n            .data[\"train\"]\n            .nonzero()\n            .squeeze(1)\n            for ntype in batched_hetero_graph.ntypes\n        }\n\n        if training:\n            self.batched_feats = batched_feats\n            probs = 
{ntype: probs[ntype].data for ntype in probs.keys()}\n        else:\n            self.clear_masks()\n\n        return (\n            probs,\n            hetero_edge_mask,\n            batched_hetero_graph,\n            batched_inverse_indices,\n        )\n\n    def _edge_mask_to_heterogeneous(self, edge_mask, homograph, heterograph):\n        r\"\"\"Convert an edge mask from homogeneous mappings built through\n        embeddings into heterogenous format by leveraging the context from\n        the source DGLGraphs in homogenous and heterogeneous form.\n\n        The `edge_mask` needs to have been built using the embedding of the\n        homogenous graph format for the mappings to work correctly.\n\n        Parameters\n        ----------\n        edge_mask : dict[str, Tensor]\n            A dict mapping node types (keys) to a tensor of edge weights (values).\n        homograph : DGLGraph\n            The homogeneous form of the source graph.\n        heterograph : DGLGraph\n            The heterogeneous form of the source graph.\n\n        Returns\n        -------\n        dict[str, Tensor]\n            A dict mapping node types (keys) to tensors of node ids (values)\n        \"\"\"\n        return {\n            etype: edge_mask[\n                (homograph.edata[ETYPE] == heterograph.get_etype_id(etype))\n                .nonzero()\n                .squeeze(1)\n            ]\n            for etype in heterograph.canonical_etypes\n        }\n"
  },
  {
    "path": "python/dgl/nn/pytorch/explain/subgraphx.py",
    "content": "\"\"\"Torch Module for SubgraphX\"\"\"\nimport math\n\nimport networkx as nx\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom .... import to_heterogeneous, to_homogeneous\nfrom ....base import NID\nfrom ....convert import to_networkx\nfrom ....subgraph import node_subgraph\nfrom ....transforms.functional import remove_nodes\n\n__all__ = [\"SubgraphX\", \"HeteroSubgraphX\"]\n\n\nclass MCTSNode:\n    r\"\"\"Monte Carlo Tree Search Node\n\n    Parameters\n    ----------\n    nodes : Tensor\n        The node IDs of the graph that are associated with this tree node\n    \"\"\"\n\n    def __init__(self, nodes):\n        self.nodes = nodes\n        self.num_visit = 0\n        self.total_reward = 0.0\n        self.immediate_reward = 0.0\n        self.children = []\n\n    def __repr__(self):\n        r\"\"\"Get the string representation of the node.\n\n        Returns\n        -------\n        str\n            The string representation of the node\n        \"\"\"\n        return str(self.nodes)\n\n\nclass SubgraphX(nn.Module):\n    r\"\"\"SubgraphX from `On Explainability of Graph Neural Networks via Subgraph\n    Explorations <https://arxiv.org/abs/2102.05152>`\n\n    It identifies the most important subgraph from the original graph that\n    plays a critical role in GNN-based graph classification.\n\n    It employs Monte Carlo tree search (MCTS) in efficiently exploring\n    different subgraphs for explanation and uses Shapley values as the measure\n    of subgraph importance.\n\n    Parameters\n    ----------\n    model : nn.Module\n        The GNN model to explain that tackles multiclass graph classification\n\n        * Its forward function must have the form\n          :attr:`forward(self, graph, nfeat)`.\n        * The output of its forward function is the logits.\n    num_hops : int\n        Number of message passing layers in the model\n    coef : float, optional\n        This hyperparameter controls the trade-off between exploration 
and\n        exploitation. A higher value encourages the algorithm to explore\n        relatively unvisited nodes. Default: 10.0\n    high2low : bool, optional\n        If True, it will use the \"High2low\" strategy for pruning actions,\n        expanding children nodes from high degree to low degree when extending\n        the children nodes in the search tree. Otherwise, it will use the\n        \"Low2high\" strategy. Default: True\n    num_child : int, optional\n        This is the number of children nodes to expand when extending the\n        children nodes in the search tree. Default: 12\n    num_rollouts : int, optional\n        This is the number of rollouts for MCTS. Default: 20\n    node_min : int, optional\n        This is the threshold to define a leaf node based on the number of\n        nodes in a subgraph. Default: 3\n    shapley_steps : int, optional\n        This is the number of steps for Monte Carlo sampling in estimating\n        Shapley values. Default: 100\n    log : bool, optional\n        If True, it will log the progress. 
Default: False\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        num_hops,\n        coef=10.0,\n        high2low=True,\n        num_child=12,\n        num_rollouts=20,\n        node_min=3,\n        shapley_steps=100,\n        log=False,\n    ):\n        super().__init__()\n        self.num_hops = num_hops\n        self.coef = coef\n        self.high2low = high2low\n        self.num_child = num_child\n        self.num_rollouts = num_rollouts\n        self.node_min = node_min\n        self.shapley_steps = shapley_steps\n        self.log = log\n\n        self.model = model\n\n    def shapley(self, subgraph_nodes):\n        r\"\"\"Compute Shapley value with Monte Carlo approximation.\n\n        Parameters\n        ----------\n        subgraph_nodes : tensor\n            The tensor node ids of the subgraph that are associated with this\n            tree node\n\n        Returns\n        -------\n        float\n            Shapley value\n        \"\"\"\n        num_nodes = self.graph.num_nodes()\n        subgraph_nodes = subgraph_nodes.tolist()\n\n        # Obtain neighboring nodes of the subgraph g_i, P'.\n        local_region = subgraph_nodes\n        for _ in range(self.num_hops - 1):\n            in_neighbors, _ = self.graph.in_edges(local_region)\n            _, out_neighbors = self.graph.out_edges(local_region)\n            neighbors = torch.cat([in_neighbors, out_neighbors]).tolist()\n            local_region = list(set(local_region + neighbors))\n\n        split_point = num_nodes\n        coalition_space = list(set(local_region) - set(subgraph_nodes)) + [\n            split_point\n        ]\n\n        marginal_contributions = []\n        device = self.feat.device\n        for _ in range(self.shapley_steps):\n            permuted_space = np.random.permutation(coalition_space)\n            split_idx = int(np.where(permuted_space == split_point)[0])\n\n            selected_nodes = permuted_space[:split_idx]\n\n            # Mask for coalition 
set S_i\n            exclude_mask = torch.ones(num_nodes)\n            exclude_mask[local_region] = 0.0\n            exclude_mask[selected_nodes] = 1.0\n\n            # Mask for set S_i and g_i\n            include_mask = exclude_mask.clone()\n            include_mask[subgraph_nodes] = 1.0\n\n            exclude_feat = self.feat * exclude_mask.unsqueeze(1).to(device)\n            include_feat = self.feat * include_mask.unsqueeze(1).to(device)\n\n            with torch.no_grad():\n                exclude_probs = self.model(\n                    self.graph, exclude_feat, **self.kwargs\n                ).softmax(dim=-1)\n                exclude_value = exclude_probs[:, self.target_class]\n                include_probs = self.model(\n                    self.graph, include_feat, **self.kwargs\n                ).softmax(dim=-1)\n                include_value = include_probs[:, self.target_class]\n            marginal_contributions.append(include_value - exclude_value)\n\n        return torch.cat(marginal_contributions).mean().item()\n\n    def get_mcts_children(self, mcts_node):\n        r\"\"\"Get the children of the MCTS node for the search.\n\n        Parameters\n        ----------\n        mcts_node : MCTSNode\n            Node in MCTS\n\n        Returns\n        -------\n        list\n            Children nodes after pruning\n        \"\"\"\n        if len(mcts_node.children) > 0:\n            return mcts_node.children\n\n        subg = node_subgraph(self.graph, mcts_node.nodes)\n        node_degrees = subg.out_degrees() + subg.in_degrees()\n        k = min(subg.num_nodes(), self.num_child)\n        chosen_nodes = torch.topk(\n            node_degrees, k, largest=self.high2low\n        ).indices\n\n        mcts_children_maps = dict()\n\n        for node in chosen_nodes:\n            new_subg = remove_nodes(subg, node.to(subg.idtype), store_ids=True)\n            # Get the largest weakly connected component in the subgraph.\n            nx_graph = 
to_networkx(new_subg.cpu())\n            largest_cc_nids = list(\n                max(nx.weakly_connected_components(nx_graph), key=len)\n            )\n            # Map to the original node IDs.\n            largest_cc_nids = new_subg.ndata[NID][largest_cc_nids].long()\n            largest_cc_nids = subg.ndata[NID][largest_cc_nids].sort().values\n            if str(largest_cc_nids) not in self.mcts_node_maps:\n                child_mcts_node = MCTSNode(largest_cc_nids)\n                self.mcts_node_maps[str(child_mcts_node)] = child_mcts_node\n            else:\n                child_mcts_node = self.mcts_node_maps[str(largest_cc_nids)]\n\n            if str(child_mcts_node) not in mcts_children_maps:\n                mcts_children_maps[str(child_mcts_node)] = child_mcts_node\n\n        mcts_node.children = list(mcts_children_maps.values())\n        for child_mcts_node in mcts_node.children:\n            if child_mcts_node.immediate_reward == 0:\n                child_mcts_node.immediate_reward = self.shapley(\n                    child_mcts_node.nodes\n                )\n\n        return mcts_node.children\n\n    def mcts_rollout(self, mcts_node):\n        r\"\"\"Perform a MCTS rollout.\n\n        Parameters\n        ----------\n        mcts_node : MCTSNode\n            Starting node for MCTS\n\n        Returns\n        -------\n        float\n            Reward for visiting the node this time\n        \"\"\"\n        if len(mcts_node.nodes) <= self.node_min:\n            return mcts_node.immediate_reward\n\n        children_nodes = self.get_mcts_children(mcts_node)\n        children_visit_sum = sum([child.num_visit for child in children_nodes])\n        children_visit_sum_sqrt = math.sqrt(children_visit_sum)\n        chosen_child = max(\n            children_nodes,\n            key=lambda c: c.total_reward / max(c.num_visit, 1)\n            + self.coef\n            * c.immediate_reward\n            * children_visit_sum_sqrt\n            / (1 + c.num_visit),\n 
       )\n        reward = self.mcts_rollout(chosen_child)\n        chosen_child.num_visit += 1\n        chosen_child.total_reward += reward\n\n        return reward\n\n    def explain_graph(self, graph, feat, target_class, **kwargs):\n        r\"\"\"Find the most important subgraph from the original graph for the\n        model to classify the graph into the target class.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            A homogeneous graph\n        feat : Tensor\n            The input node feature of shape :math:`(N, D)`, :math:`N` is the\n            number of nodes, and :math:`D` is the feature size\n        target_class : int\n            The target class to explain\n        kwargs : dict\n            Additional arguments passed to the GNN model\n\n        Returns\n        -------\n        Tensor\n            Nodes that represent the most important subgraph\n\n        Examples\n        --------\n\n        >>> import torch\n        >>> import torch.nn as nn\n        >>> import torch.nn.functional as F\n        >>> from dgl.data import GINDataset\n        >>> from dgl.dataloading import GraphDataLoader\n        >>> from dgl.nn import GraphConv, AvgPooling, SubgraphX\n\n        >>> # Define the model\n        >>> class Model(nn.Module):\n        ...     def __init__(self, in_dim, n_classes, hidden_dim=128):\n        ...         super().__init__()\n        ...         self.conv1 = GraphConv(in_dim, hidden_dim)\n        ...         self.conv2 = GraphConv(hidden_dim, n_classes)\n        ...         self.pool = AvgPooling()\n        ...\n        ...     def forward(self, g, h):\n        ...         h = F.relu(self.conv1(g, h))\n        ...         h = self.conv2(g, h)\n        ...         
return self.pool(g, h)\n\n        >>> # Load dataset\n        >>> data = GINDataset('MUTAG', self_loop=True)\n        >>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)\n\n        >>> # Train the model\n        >>> feat_size = data[0][0].ndata['attr'].shape[1]\n        >>> model = Model(feat_size, data.gclasses)\n        >>> criterion = nn.CrossEntropyLoss()\n        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n        >>> for bg, labels in dataloader:\n        ...     logits = model(bg, bg.ndata['attr'])\n        ...     loss = criterion(logits, labels)\n        ...     optimizer.zero_grad()\n        ...     loss.backward()\n        ...     optimizer.step()\n\n        >>> # Initialize the explainer\n        >>> explainer = SubgraphX(model, num_hops=2)\n\n        >>> # Explain the prediction for graph 0\n        >>> graph, l = data[0]\n        >>> graph_feat = graph.ndata.pop(\"attr\")\n        >>> g_nodes_explain = explainer.explain_graph(graph, graph_feat,\n        ...                                           
target_class=l)\n        \"\"\"\n        self.model.eval()\n        assert (\n            graph.num_nodes() > self.node_min\n        ), f\"The number of nodes in the\\\n            graph {graph.num_nodes()} should be bigger than {self.node_min}.\"\n\n        self.graph = graph\n        self.feat = feat\n        self.target_class = target_class\n        self.kwargs = kwargs\n\n        # book all nodes in MCTS\n        self.mcts_node_maps = dict()\n\n        root = MCTSNode(graph.nodes())\n        self.mcts_node_maps[str(root)] = root\n\n        for i in range(self.num_rollouts):\n            if self.log:\n                print(\n                    f\"Rollout {i}/{self.num_rollouts}, \\\n                    {len(self.mcts_node_maps)} subgraphs have been explored.\"\n                )\n            self.mcts_rollout(root)\n\n        best_leaf = None\n        best_immediate_reward = float(\"-inf\")\n        for mcts_node in self.mcts_node_maps.values():\n            if len(mcts_node.nodes) > self.node_min:\n                continue\n\n            if mcts_node.immediate_reward > best_immediate_reward:\n                best_leaf = mcts_node\n                best_immediate_reward = best_leaf.immediate_reward\n\n        return best_leaf.nodes\n\n\nclass HeteroSubgraphX(nn.Module):\n    r\"\"\"SubgraphX from `On Explainability of Graph Neural Networks via Subgraph\n    Explorations <https://arxiv.org/abs/2102.05152>`__, adapted for heterogeneous graphs\n\n    It identifies the most important subgraph from the original graph that\n    plays a critical role in GNN-based graph classification.\n\n    It employs Monte Carlo tree search (MCTS) in efficiently exploring\n    different subgraphs for explanation and uses Shapley values as the measure\n    of subgraph importance.\n\n    Parameters\n    ----------\n    model : nn.Module\n        The GNN model to explain that tackles multiclass graph classification\n\n        * Its forward function must have the form\n          
:attr:`forward(self, graph, nfeat)`.\n        * The output of its forward function is the logits.\n    num_hops : int\n        Number of message passing layers in the model\n    coef : float, optional\n        This hyperparameter controls the trade-off between exploration and\n        exploitation. A higher value encourages the algorithm to explore\n        relatively unvisited nodes. Default: 10.0\n    high2low : bool, optional\n        If True, it will use the \"High2low\" strategy for pruning actions,\n        expanding children nodes from high degree to low degree when extending\n        the children nodes in the search tree. Otherwise, it will use the\n        \"Low2high\" strategy. Default: True\n    num_child : int, optional\n        This is the number of children nodes to expand when extending the\n        children nodes in the search tree. Default: 12\n    num_rollouts : int, optional\n        This is the number of rollouts for MCTS. Default: 20\n    node_min : int, optional\n        This is the threshold to define a leaf node based on the number of\n        nodes in a subgraph. Default: 3\n    shapley_steps : int, optional\n        This is the number of steps for Monte Carlo sampling in estimating\n        Shapley values. Default: 100\n    log : bool, optional\n        If True, it will log the progress. 
Default: False\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        num_hops,\n        coef=10.0,\n        high2low=True,\n        num_child=12,\n        num_rollouts=20,\n        node_min=3,\n        shapley_steps=100,\n        log=False,\n    ):\n        super().__init__()\n        self.num_hops = num_hops\n        self.coef = coef\n        self.high2low = high2low\n        self.num_child = num_child\n        self.num_rollouts = num_rollouts\n        self.node_min = node_min\n        self.shapley_steps = shapley_steps\n        self.log = log\n\n        self.model = model\n\n    def shapley(self, subgraph_nodes):\n        r\"\"\"Compute Shapley value with Monte Carlo approximation.\n\n        Parameters\n        ----------\n        subgraph_nodes : dict[str, Tensor]\n            subgraph_nodes[nty] gives the tensor node IDs of node type nty\n            in the subgraph, which are associated with this tree node\n\n        Returns\n        -------\n        float\n            Shapley value\n        \"\"\"\n        # Obtain neighboring nodes of the subgraph g_i, P'.\n        local_regions = {\n            ntype: nodes.tolist() for ntype, nodes in subgraph_nodes.items()\n        }\n        for _ in range(self.num_hops - 1):\n            for c_etype in self.graph.canonical_etypes:\n                src_ntype, _, dst_ntype = c_etype\n                if (\n                    src_ntype not in local_regions\n                    or dst_ntype not in local_regions\n                ):\n                    continue\n\n                in_neighbors, _ = self.graph.in_edges(\n                    local_regions[dst_ntype], etype=c_etype\n                )\n                _, out_neighbors = self.graph.out_edges(\n                    local_regions[src_ntype], etype=c_etype\n                )\n                local_regions[src_ntype] = list(\n                    set(local_regions[src_ntype] + in_neighbors.tolist())\n                )\n                
local_regions[dst_ntype] = list(\n                    set(local_regions[dst_ntype] + out_neighbors.tolist())\n                )\n\n        split_point = self.graph.num_nodes()\n        coalition_space = {\n            ntype: list(\n                set(local_regions[ntype]) - set(subgraph_nodes[ntype].tolist())\n            )\n            + [split_point]\n            for ntype in subgraph_nodes.keys()\n        }\n\n        marginal_contributions = []\n        for _ in range(self.shapley_steps):\n            selected_node_map = dict()\n            for ntype, nodes in coalition_space.items():\n                permuted_space = np.random.permutation(nodes)\n                split_idx = int(np.where(permuted_space == split_point)[0])\n                selected_node_map[ntype] = permuted_space[:split_idx]\n\n            # Mask for coalition set S_i\n            exclude_mask = {\n                ntype: torch.ones(self.graph.num_nodes(ntype))\n                for ntype in self.graph.ntypes\n            }\n            for ntype, region in local_regions.items():\n                exclude_mask[ntype][region] = 0.0\n            for ntype, selected_nodes in selected_node_map.items():\n                exclude_mask[ntype][selected_nodes] = 1.0\n\n            # Mask for set S_i and g_i\n            include_mask = {\n                ntype: exclude_mask[ntype].clone()\n                for ntype in self.graph.ntypes\n            }\n            for ntype, subgn in subgraph_nodes.items():\n                exclude_mask[ntype][subgn] = 1.0\n\n            exclude_feat = {\n                ntype: self.feat[ntype]\n                * exclude_mask[ntype].unsqueeze(1).to(self.feat[ntype].device)\n                for ntype in self.graph.ntypes\n            }\n            include_feat = {\n                ntype: self.feat[ntype]\n                * include_mask[ntype].unsqueeze(1).to(self.feat[ntype].device)\n                for ntype in self.graph.ntypes\n            }\n\n            with 
torch.no_grad():\n                exclude_probs = self.model(\n                    self.graph, exclude_feat, **self.kwargs\n                ).softmax(dim=-1)\n                exclude_value = exclude_probs[:, self.target_class]\n                include_probs = self.model(\n                    self.graph, include_feat, **self.kwargs\n                ).softmax(dim=-1)\n                include_value = include_probs[:, self.target_class]\n            marginal_contributions.append(include_value - exclude_value)\n\n        return torch.cat(marginal_contributions).mean().item()\n\n    def get_mcts_children(self, mcts_node):\n        r\"\"\"Get the children of the MCTS node for the search.\n\n        Parameters\n        ----------\n        mcts_node : MCTSNode\n            Node in MCTS\n\n        Returns\n        -------\n        list\n            Children nodes after pruning\n        \"\"\"\n        if len(mcts_node.children) > 0:\n            return mcts_node.children\n\n        subg = node_subgraph(self.graph, mcts_node.nodes)\n        # Choose k nodes based on the highest degree in the subgraph\n        node_degrees_map = {\n            ntype: torch.zeros(\n                subg.num_nodes(ntype), device=subg.nodes(ntype).device\n            )\n            for ntype in subg.ntypes\n        }\n        for c_etype in subg.canonical_etypes:\n            src_ntype, _, dst_ntype = c_etype\n            node_degrees_map[src_ntype] += subg.out_degrees(etype=c_etype)\n            node_degrees_map[dst_ntype] += subg.in_degrees(etype=c_etype)\n\n        node_degrees_list = [\n            ((ntype, i), degree)\n            for ntype, node_degrees in node_degrees_map.items()\n            for i, degree in enumerate(node_degrees)\n        ]\n        node_degrees = torch.stack([v for _, v in node_degrees_list])\n        k = min(subg.num_nodes(), self.num_child)\n        chosen_node_indicies = torch.topk(\n            node_degrees, k, largest=self.high2low\n        ).indices\n        
chosen_nodes = [node_degrees_list[i][0] for i in chosen_node_indicies]\n\n        mcts_children_maps = dict()\n\n        for ntype, node in chosen_nodes:\n            new_subg = remove_nodes(subg, node, ntype, store_ids=True)\n\n            if new_subg.num_edges() > 0:\n                new_subg_homo = to_homogeneous(new_subg)\n                # Get the largest weakly connected component in the subgraph.\n                nx_graph = to_networkx(new_subg_homo.cpu())\n                largest_cc_nids = list(\n                    max(nx.weakly_connected_components(nx_graph), key=len)\n                )\n                largest_cc_homo = node_subgraph(new_subg_homo, largest_cc_nids)\n                largest_cc_hetero = to_heterogeneous(\n                    largest_cc_homo, new_subg.ntypes, new_subg.etypes\n                )\n\n                # Follow steps for backtracking to original graph node ids\n                # 1. retrieve instanced homograph from connected-component homograph\n                # 2. retrieve instanced heterograph from instanced homograph\n                # 3. retrieve hetero-subgraph from instanced heterograph\n                # 4. 
retrieve orignal graph ids from subgraph node ids\n                cc_nodes = {\n                    ntype: subg.ndata[NID][ntype][\n                        new_subg.ndata[NID][ntype][\n                            new_subg_homo.ndata[NID][\n                                largest_cc_homo.ndata[NID][indicies]\n                            ]\n                        ]\n                    ]\n                    for ntype, indicies in largest_cc_hetero.ndata[NID].items()\n                }\n            else:\n                available_ntypes = [\n                    ntype\n                    for ntype in new_subg.ntypes\n                    if new_subg.num_nodes(ntype) > 0\n                ]\n                chosen_ntype = np.random.choice(available_ntypes)\n                # backtrack from subgraph node ids to entire graph\n                chosen_node = subg.ndata[NID][chosen_ntype][\n                    np.random.choice(new_subg.nodes[chosen_ntype].data[NID])\n                ]\n                cc_nodes = {\n                    chosen_ntype: torch.tensor(\n                        [chosen_node],\n                        device=subg.device,\n                    )\n                }\n\n            if str(cc_nodes) not in self.mcts_node_maps:\n                child_mcts_node = MCTSNode(cc_nodes)\n                self.mcts_node_maps[str(child_mcts_node)] = child_mcts_node\n            else:\n                child_mcts_node = self.mcts_node_maps[str(cc_nodes)]\n\n            if str(child_mcts_node) not in mcts_children_maps:\n                mcts_children_maps[str(child_mcts_node)] = child_mcts_node\n\n        mcts_node.children = list(mcts_children_maps.values())\n        for child_mcts_node in mcts_node.children:\n            if child_mcts_node.immediate_reward == 0:\n                child_mcts_node.immediate_reward = self.shapley(\n                    child_mcts_node.nodes\n                )\n\n        return mcts_node.children\n\n    def mcts_rollout(self, 
mcts_node):\n        r\"\"\"Perform a MCTS rollout.\n\n        Parameters\n        ----------\n        mcts_node : MCTSNode\n            Starting node for MCTS\n\n        Returns\n        -------\n        float\n            Reward for visiting the node this time\n        \"\"\"\n        if (\n            sum(len(nodes) for nodes in mcts_node.nodes.values())\n            <= self.node_min\n        ):\n            return mcts_node.immediate_reward\n\n        children_nodes = self.get_mcts_children(mcts_node)\n        children_visit_sum = sum([child.num_visit for child in children_nodes])\n        children_visit_sum_sqrt = math.sqrt(children_visit_sum)\n        chosen_child = max(\n            children_nodes,\n            key=lambda c: c.total_reward / max(c.num_visit, 1)\n            + self.coef\n            * c.immediate_reward\n            * children_visit_sum_sqrt\n            / (1 + c.num_visit),\n        )\n        reward = self.mcts_rollout(chosen_child)\n        chosen_child.num_visit += 1\n        chosen_child.total_reward += reward\n\n        return reward\n\n    def explain_graph(self, graph, feat, target_class, **kwargs):\n        r\"\"\"Find the most important subgraph from the original graph for the\n        model to classify the graph into the target class.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            A heterogeneous graph\n        feat : dict[str, Tensor]\n            The dictionary that associates input node features (values) with\n            the respective node types (keys) present in the graph.\n            The input features are of shape :math:`(N_t, D_t)`. 
:math:`N_t` is the\n            number of nodes for node type :math:`t`, and :math:`D_t` is the feature size for\n            node type :math:`t`\n        target_class : int\n            The target class to explain\n        kwargs : dict\n            Additional arguments passed to the GNN model\n\n        Returns\n        -------\n        dict[str, Tensor]\n            The dictionary associating tensor node ids (values) to\n            node types (keys) that represents the most important subgraph\n\n        Examples\n        --------\n\n        >>> import dgl\n        >>> import dgl.function as fn\n        >>> import torch as th\n        >>> import torch.nn as nn\n        >>> import torch.nn.functional as F\n        >>> from dgl.nn import HeteroSubgraphX\n\n        >>> class Model(nn.Module):\n        ...     def __init__(self, in_dim, num_classes, canonical_etypes):\n        ...         super(Model, self).__init__()\n        ...         self.etype_weights = nn.ModuleDict(\n        ...             {\n        ...                 \"_\".join(c_etype): nn.Linear(in_dim, num_classes)\n        ...                 for c_etype in canonical_etypes\n        ...             }\n        ...         )\n        ...\n        ...     def forward(self, graph, feat):\n        ...         with graph.local_scope():\n        ...             c_etype_func_dict = {}\n        ...             for c_etype in graph.canonical_etypes:\n        ...                 src_type, etype, dst_type = c_etype\n        ...                 wh = self.etype_weights[\"_\".join(c_etype)](feat[src_type])\n        ...                 graph.nodes[src_type].data[f\"h_{c_etype}\"] = wh\n        ...                 c_etype_func_dict[c_etype] = (\n        ...                     fn.copy_u(f\"h_{c_etype}\", \"m\"),\n        ...                     fn.mean(\"m\", \"h\"),\n        ...                 )\n        ...             graph.multi_update_all(c_etype_func_dict, \"sum\")\n        ...             hg = 0\n        ...  
           for ntype in graph.ntypes:\n        ...                 if graph.num_nodes(ntype):\n        ...                     hg = hg + dgl.mean_nodes(graph, \"h\", ntype=ntype)\n        ...             return hg\n\n        >>> input_dim = 5\n        >>> num_classes = 2\n        >>> g = dgl.heterograph({(\"user\", \"plays\", \"game\"): ([0, 1, 1, 2], [0, 0, 1, 1])})\n        >>> g.nodes[\"user\"].data[\"h\"] = th.randn(g.num_nodes(\"user\"), input_dim)\n        >>> g.nodes[\"game\"].data[\"h\"] = th.randn(g.num_nodes(\"game\"), input_dim)\n\n        >>> transform = dgl.transforms.AddReverse()\n        >>> g = transform(g)\n\n        >>> # define and train the model\n        >>> model = Model(input_dim, num_classes, g.canonical_etypes)\n        >>> feat = g.ndata[\"h\"]\n        >>> optimizer = th.optim.Adam(model.parameters())\n        >>> for epoch in range(10):\n        ...     logits = model(g, feat)\n        ...     loss = F.cross_entropy(logits, th.tensor([1]))\n        ...     optimizer.zero_grad()\n        ...     loss.backward()\n        ...     
optimizer.step()\n\n        >>> # Explain for the graph\n        >>> explainer = HeteroSubgraphX(model, num_hops=1)\n        >>> explainer.explain_graph(g, feat, target_class=1)\n        {'game': tensor([0, 1]), 'user': tensor([1, 2])}\n        \"\"\"\n        self.model.eval()\n        assert (\n            graph.num_nodes() > self.node_min\n        ), f\"The number of nodes in the\\\n            graph {graph.num_nodes()} should be bigger than {self.node_min}.\"\n\n        self.graph = graph\n        self.feat = feat\n        self.target_class = target_class\n        self.kwargs = kwargs\n\n        # book all nodes in MCTS\n        self.mcts_node_maps = dict()\n\n        root_dict = {ntype: graph.nodes(ntype) for ntype in graph.ntypes}\n        root = MCTSNode(root_dict)\n        self.mcts_node_maps[str(root)] = root\n\n        for i in range(self.num_rollouts):\n            if self.log:\n                print(\n                    f\"Rollout {i}/{self.num_rollouts}, \\\n                    {len(self.mcts_node_maps)} subgraphs have been explored.\"\n                )\n            self.mcts_rollout(root)\n\n        best_leaf = None\n        best_immediate_reward = float(\"-inf\")\n        for mcts_node in self.mcts_node_maps.values():\n            if len(mcts_node.nodes) > self.node_min:\n                continue\n\n            if mcts_node.immediate_reward > best_immediate_reward:\n                best_leaf = mcts_node\n                best_immediate_reward = best_leaf.immediate_reward\n\n        return best_leaf.nodes\n"
  },
  {
    "path": "python/dgl/nn/pytorch/factory.py",
    "content": "\"\"\"Modules that transforms between graphs and between graph and tensors.\"\"\"\nimport torch.nn as nn\n\nfrom ...transforms import knn_graph, radius_graph, segmented_knn_graph\n\n\ndef pairwise_squared_distance(x):\n    \"\"\"\n    x : (n_samples, n_points, dims)\n    return : (n_samples, n_points, n_points)\n    \"\"\"\n    x2s = (x * x).sum(-1, keepdim=True)\n    return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2)\n\n\nclass KNNGraph(nn.Module):\n    r\"\"\"Layer that transforms one point set into a graph, or a batch of\n    point sets with the same number of points into a batched union of those graphs.\n\n    The KNNGraph is implemented in the following steps:\n\n    1. Compute an NxN matrix of pairwise distance for all points.\n    2. Pick the k points with the smallest distance for each point as their k-nearest neighbors.\n    3. Construct a graph with edges to each point as a node from its k-nearest neighbors.\n\n    The overall computational complexity is :math:`O(N^2(logN + D)`.\n\n    If a batch of point sets is provided, the point :math:`j` in point\n    set :math:`i` is mapped to graph node ID: :math:`i \\times M + j`, where\n    :math:`M` is the number of nodes in each point set.\n\n    The predecessors of each node are the k-nearest neighbors of the\n    corresponding point.\n\n    Parameters\n    ----------\n    k : int\n        The number of neighbors.\n\n    Notes\n    -----\n    The nearest neighbors found for a node include the node itself.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import torch\n    >>> from dgl.nn.pytorch.factory import KNNGraph\n    >>>\n    >>> kg = KNNGraph(2)\n    >>> x = torch.tensor([[0,1],\n                          [1,2],\n                          [1,3],\n                          [100, 101],\n                          [101, 102],\n                          [50, 50]])\n    >>> g = kg(x)\n    >>> print(g.edges())\n        (tensor([0, 1, 1, 1, 
2, 2, 2, 3, 3, 4, 4, 5]),\n         tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5]))\n    \"\"\"\n\n    def __init__(self, k):\n        super(KNNGraph, self).__init__()\n        self.k = k\n\n    # pylint: disable=invalid-name\n    def forward(\n        self,\n        x,\n        algorithm=\"bruteforce-blas\",\n        dist=\"euclidean\",\n        exclude_self=False,\n    ):\n        r\"\"\"\n\n        Forward computation.\n\n        Parameters\n        ----------\n        x : Tensor\n            :math:`(M, D)` or :math:`(N, M, D)` where :math:`N` means the\n            number of point sets, :math:`M` means the number of points in\n            each point set, and :math:`D` means the size of features.\n        algorithm : str, optional\n            Algorithm used to compute the k-nearest neighbors.\n\n            * 'bruteforce-blas' will first compute the distance matrix\n              using BLAS matrix multiplication operation provided by\n              backend frameworks. Then use topk algorithm to get\n              k-nearest neighbors. This method is fast when the point\n              set is small but has :math:`O(N^2)` memory complexity where\n              :math:`N` is the number of points.\n\n            * 'bruteforce' will compute distances pair by pair and\n              directly select the k-nearest neighbors during distance\n              computation. This method is slower than 'bruteforce-blas'\n              but has less memory overhead (i.e., :math:`O(Nk)` where :math:`N`\n              is the number of points, :math:`k` is the number of nearest\n              neighbors per node) since we do not need to store all distances.\n\n            * 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'\n              but use shared memory in CUDA devices for buffer. This method is\n              faster than 'bruteforce' when the dimension of input points\n              is not large. 
This method is only available on CUDA device.\n\n            * 'kd-tree' will use the kd-tree algorithm (CPU only).\n              This method is suitable for low-dimensional data (e.g. 3D\n              point clouds)\n\n            * 'nn-descent' is an approximate approach from the paper\n              `Efficient k-nearest neighbor graph construction for generic similarity\n              measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method\n              will search for nearest neighbor candidates in \"neighbors' neighbors\".\n\n            (default: 'bruteforce-blas')\n        dist : str, optional\n            The distance metric used to compute distance between points. It can be the following\n            metrics:\n            * 'euclidean': Use Euclidean distance (L2 norm)\n              :math:`\\sqrt{\\sum_{i} (x_{i} - y_{i})^{2}}`.\n            * 'cosine': Use cosine distance.\n            (default: 'euclidean')\n        exclude_self : bool, optional\n            If True, the output graph will not contain self loop edges, and each node will not\n            be counted as one of its own k neighbors.  
If False, the output graph will contain\n            self loop edges, and a node will be counted as one of its own k neighbors.\n\n        Returns\n        -------\n        DGLGraph\n            A DGLGraph without features.\n        \"\"\"\n        return knn_graph(\n            x, self.k, algorithm=algorithm, dist=dist, exclude_self=exclude_self\n        )\n\n\nclass SegmentedKNNGraph(nn.Module):\n    r\"\"\"Layer that transforms one point set into a graph, or a batch of\n    point sets with different number of points into a batched union of those graphs.\n\n    If a batch of point sets is provided, then the point :math:`j` in the point\n    set :math:`i` is mapped to graph node ID:\n    :math:`\\sum_{p<i} |V_p| + j`, where :math:`|V_p|` means the number of\n    points in the point set :math:`p`.\n\n    The predecessors of each node are the k-nearest neighbors of the\n    corresponding point.\n\n    Parameters\n    ----------\n    k : int\n        The number of neighbors.\n\n    Notes\n    -----\n    The nearest neighbors found for a node include the node itself.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import torch\n    >>> from dgl.nn.pytorch.factory import SegmentedKNNGraph\n    >>>\n    >>> kg = SegmentedKNNGraph(2)\n    >>> x = torch.tensor([[0,1],\n    ...                   [1,2],\n    ...                   [1,3],\n    ...                   [100, 101],\n    ...                   [101, 102],\n    ...                   [50, 50],\n    ...                   [24,25],\n    ...                   
[25,24]])\n    >>> g = kg(x, [3,3,2])\n    >>> print(g.edges())\n    (tensor([0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 6, 6, 7, 7]),\n     tensor([0, 0, 1, 2, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 6, 7]))\n    >>>\n\n    \"\"\"\n\n    def __init__(self, k):\n        super(SegmentedKNNGraph, self).__init__()\n        self.k = k\n\n    # pylint: disable=invalid-name\n    def forward(\n        self,\n        x,\n        segs,\n        algorithm=\"bruteforce-blas\",\n        dist=\"euclidean\",\n        exclude_self=False,\n    ):\n        r\"\"\"Forward computation.\n\n        Parameters\n        ----------\n        x : Tensor\n            :math:`(M, D)` where :math:`M` means the total number of points\n            in all point sets, and :math:`D` means the size of features.\n        segs : iterable of int\n            :math:`(N)` integers where :math:`N` means the number of point\n            sets.  The number of elements must sum up to :math:`M`. And any\n            :math:`N` should :math:`\\ge k`\n        algorithm : str, optional\n            Algorithm used to compute the k-nearest neighbors.\n\n            * 'bruteforce-blas' will first compute the distance matrix\n              using BLAS matrix multiplication operation provided by\n              backend frameworks. Then use topk algorithm to get\n              k-nearest neighbors. This method is fast when the point\n              set is small but has :math:`O(N^2)` memory complexity where\n              :math:`N` is the number of points.\n\n            * 'bruteforce' will compute distances pair by pair and\n              directly select the k-nearest neighbors during distance\n              computation. 
This method is slower than 'bruteforce-blas'\n              but has less memory overhead (i.e., :math:`O(Nk)` where :math:`N`\n              is the number of points, :math:`k` is the number of nearest\n              neighbors per node) since we do not need to store all distances.\n\n            * 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'\n              but use shared memory in CUDA devices for buffer. This method is\n              faster than 'bruteforce' when the dimension of input points\n              is not large. This method is only available on CUDA device.\n\n            * 'kd-tree' will use the kd-tree algorithm (CPU only).\n              This method is suitable for low-dimensional data (e.g. 3D\n              point clouds)\n\n            * 'nn-descent' is an approximate approach from the paper\n              `Efficient k-nearest neighbor graph construction for generic similarity\n              measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method\n              will search for nearest neighbor candidates in \"neighbors' neighbors\".\n\n            (default: 'bruteforce-blas')\n        dist : str, optional\n            The distance metric used to compute distance between points. It can be the following\n            metrics:\n            * 'euclidean': Use Euclidean distance (L2 norm)\n              :math:`\\sqrt{\\sum_{i} (x_{i} - y_{i})^{2}}`.\n            * 'cosine': Use cosine distance.\n            (default: 'euclidean')\n        exclude_self : bool, optional\n            If True, the output graph will not contain self loop edges, and each node will not\n            be counted as one of its own k neighbors.  
If False, the output graph will contain\n            self loop edges, and a node will be counted as one of its own k neighbors.\n\n        Returns\n        -------\n        DGLGraph\n            A batched DGLGraph without features.\n        \"\"\"\n\n        return segmented_knn_graph(\n            x,\n            self.k,\n            segs,\n            algorithm=algorithm,\n            dist=dist,\n            exclude_self=exclude_self,\n        )\n\n\nclass RadiusGraph(nn.Module):\n    r\"\"\"Layer that transforms one point set into a bidirected graph with\n    neighbors within given distance.\n\n    The RadiusGraph is implemented in the following steps:\n\n    1. Compute an NxN matrix of pairwise distance for all points.\n    2. Pick the points within distance to each point as their neighbors.\n    3. Construct a graph with edges to each point as a node from its neighbors.\n\n    The nodes of the returned graph correspond to the points, where the neighbors\n    of each point are within given distance.\n\n    Parameters\n    ----------\n    r : float\n        Radius of the neighbors.\n    p : float, optional\n        Power parameter for the Minkowski metric. 
When :attr:`p = 1` it is the\n        equivalent of Manhattan distance (L1 norm) and Euclidean distance\n        (L2 norm) for :attr:`p = 2`.\n\n        (default: 2)\n    self_loop : bool, optional\n        Whether the radius graph will contain self-loops.\n\n        (default: False)\n    compute_mode : str, optional\n        ``use_mm_for_euclid_dist_if_necessary`` - will use matrix multiplication\n        approach to calculate euclidean distance (p = 2) if P > 25 or R > 25\n        ``use_mm_for_euclid_dist`` - will always use matrix multiplication\n        approach to calculate euclidean distance (p = 2)\n        ``donot_use_mm_for_euclid_dist`` - will never use matrix multiplication\n        approach to calculate euclidean distance (p = 2).\n\n        (default: donot_use_mm_for_euclid_dist)\n\n    Examples\n    --------\n    The following examples use PyTorch backend.\n\n    >>> import dgl\n    >>> from dgl.nn.pytorch.factory import RadiusGraph\n\n    >>> x = torch.tensor([[0.0, 0.0, 1.0],\n    ...                   [1.0, 0.5, 0.5],\n    ...                   [0.5, 0.2, 0.2],\n    ...                   [0.3, 0.2, 0.4]])\n    >>> rg = RadiusGraph(0.75)\n    >>> g = rg(x)  # Each node has neighbors within 0.75 distance\n    >>> g.edges()\n    (tensor([0, 1, 2, 2, 3, 3]), tensor([3, 2, 1, 3, 0, 2]))\n\n    When :attr:`get_distances` is True, forward pass returns the radius graph and\n    distances for the corresponding edges.\n\n    >>> x = torch.tensor([[0.0, 0.0, 1.0],\n    ...                   [1.0, 0.5, 0.5],\n    ...                   [0.5, 0.2, 0.2],\n    ...                   
[0.3, 0.2, 0.4]])\n    >>> rg = RadiusGraph(0.75)\n    >>> g, dist = rg(x, get_distances=True)\n    >>> g.edges()\n    (tensor([0, 1, 2, 2, 3, 3]), tensor([3, 2, 1, 3, 0, 2]))\n    >>> dist\n    tensor([[0.7000],\n            [0.6557],\n            [0.6557],\n            [0.2828],\n            [0.7000],\n            [0.2828]])\n    \"\"\"\n\n    # pylint: disable=invalid-name\n    def __init__(\n        self,\n        r,\n        p=2,\n        self_loop=False,\n        compute_mode=\"donot_use_mm_for_euclid_dist\",\n    ):\n        super(RadiusGraph, self).__init__()\n        self.r = r\n        self.p = p\n        self.self_loop = self_loop\n        self.compute_mode = compute_mode\n\n    # pylint: disable=invalid-name\n    def forward(self, x, get_distances=False):\n        r\"\"\"\n        Forward computation.\n\n        Parameters\n        ----------\n        x : Tensor\n            The point coordinates. :math:`(N, D)` where :math:`N` means the\n            number of points in the point set, and :math:`D` means the size of\n            the features. It can be either on CPU or GPU. Device of the point\n            coordinates specifies device of the radius graph.\n        get_distances : bool, optional\n            Whether to return the distances for the corresponding edges in the\n            radius graph.\n\n            (default: False)\n\n        Returns\n        -------\n        DGLGraph\n            The constructed graph. The node IDs are in the same order as :attr:`x`.\n        torch.Tensor, optional\n            The distances for the edges in the constructed graph. The distances\n            are in the same order as edge IDs.\n        \"\"\"\n        return radius_graph(\n            x, self.r, self.p, self.self_loop, self.compute_mode, get_distances\n        )\n"
  },
  {
    "path": "python/dgl/nn/pytorch/glob.py",
    "content": "\"\"\"Torch modules for graph global pooling.\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, W0235\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\n\nfrom ...backend import pytorch as F\nfrom ...base import dgl_warning\nfrom ...readout import (\n    broadcast_nodes,\n    max_nodes,\n    mean_nodes,\n    softmax_nodes,\n    sum_nodes,\n    topk_nodes,\n)\n\n__all__ = [\n    \"SumPooling\",\n    \"AvgPooling\",\n    \"MaxPooling\",\n    \"SortPooling\",\n    \"GlobalAttentionPooling\",\n    \"Set2Set\",\n    \"SetTransformerEncoder\",\n    \"SetTransformerDecoder\",\n    \"WeightAndSum\",\n]\n\n\nclass SumPooling(nn.Module):\n    r\"\"\"Apply sum pooling over the nodes in a graph.\n\n    .. math::\n        r^{(i)} = \\sum_{k=1}^{N_i} x^{(i)}_k\n\n    Notes\n    -----\n        Input: Could be one graph, or a batch of graphs. If using a batch of graphs,\n        make sure nodes in all graphs have the same feature size, and concatenate\n        nodes' feature together as the input.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import SumPooling\n    >>>\n    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges\n    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5\n    >>> g1_node_feats\n    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],\n            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],\n            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])\n    >>>\n    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges\n    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5\n    >>> g2_node_feats\n    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],\n            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],\n            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],\n            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])\n    >>>\n    >>> sumpool = SumPooling()  # 
create a sum pooling layer\n\n    Case 1: Input a single graph\n\n    >>> sumpool(g1, g1_node_feats)\n    tensor([[2.2282, 1.8667, 2.4338, 1.7540, 1.4511]])\n\n    Case 2: Input a batch of graphs\n\n    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.\n\n    >>> batch_g = dgl.batch([g1, g2])\n    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])\n    >>>\n    >>> sumpool(batch_g, batch_f)\n    tensor([[2.2282, 1.8667, 2.4338, 1.7540, 1.4511],\n            [1.0608, 1.2080, 2.1780, 2.7849, 2.5420]])\n    \"\"\"\n\n    def __init__(self):\n        super(SumPooling, self).__init__()\n\n    def forward(self, graph, feat):\n        r\"\"\"\n\n        Compute sum pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            a DGLGraph or a batch of DGLGraphs\n        feat : torch.Tensor\n            The input feature with shape :math:`(N, D)`, where :math:`N` is the number\n            of nodes in the graph, and :math:`D` means the size of features.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature with shape :math:`(B, D)`, where :math:`B` refers to the\n            batch size of input graphs.\n        \"\"\"\n        with graph.local_scope():\n            graph.ndata[\"h\"] = feat\n            readout = sum_nodes(graph, \"h\")\n            return readout\n\n\nclass AvgPooling(nn.Module):\n    r\"\"\"Apply average pooling over the nodes in a graph.\n\n    .. math::\n        r^{(i)} = \\frac{1}{N_i}\\sum_{k=1}^{N_i} x^{(i)}_k\n\n    Notes\n    -----\n        Input: Could be one graph, or a batch of graphs. 
If using a batch of graphs,\n        make sure nodes in all graphs have the same feature size, and concatenate\n        nodes' feature together as the input.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import AvgPooling\n    >>>\n    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges\n    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5\n    >>> g1_node_feats\n    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],\n            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],\n            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])\n    >>>\n    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges\n    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5\n    >>> g2_node_feats\n    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],\n            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],\n            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],\n            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])\n    >>>\n    >>> avgpool = AvgPooling()  # create an average pooling layer\n\n    Case 1: Input single graph\n\n    >>> avgpool(g1, g1_node_feats)\n    tensor([[0.7427, 0.6222, 0.8113, 0.5847, 0.4837]])\n\n    Case 2: Input a batch of graphs\n\n    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.\n\n    >>> batch_g = dgl.batch([g1, g2])\n    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])\n    >>>\n    >>> avgpool(batch_g, batch_f)\n    tensor([[0.7427, 0.6222, 0.8113, 0.5847, 0.4837],\n            [0.2652, 0.3020, 0.5445, 0.6962, 0.6355]])\n    \"\"\"\n\n    def __init__(self):\n        super(AvgPooling, self).__init__()\n\n    def forward(self, graph, feat):\n        r\"\"\"\n\n        Compute average pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            A DGLGraph or a batch of DGLGraphs.\n        feat : torch.Tensor\n            The input 
feature with shape :math:`(N, D)`, where :math:`N` is the number\n            of nodes in the graph, and :math:`D` means the size of features.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature with shape :math:`(B, D)`, where\n            :math:`B` refers to the batch size of input graphs.\n        \"\"\"\n        with graph.local_scope():\n            graph.ndata[\"h\"] = feat\n            readout = mean_nodes(graph, \"h\")\n            return readout\n\n\nclass MaxPooling(nn.Module):\n    r\"\"\"Apply max pooling over the nodes in a graph.\n\n    .. math::\n        r^{(i)} = \\max_{k=1}^{N_i}\\left( x^{(i)}_k \\right)\n\n    Notes\n    -----\n        Input: Could be one graph, or a batch of graphs. If using a batch of graphs,\n        make sure nodes in all graphs have the same feature size, and concatenate\n        nodes' feature together as the input.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import MaxPooling\n    >>>\n    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges\n    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5\n    >>> g1_node_feats\n    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],\n            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],\n            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])\n    >>>\n    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges\n    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5\n    >>> g2_node_feats\n    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],\n            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],\n            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],\n            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])\n    >>>\n    >>> maxpool = MaxPooling()  # create a max pooling layer\n\n    Case 1: Input a single graph\n\n    >>> maxpool(g1, g1_node_feats)\n    tensor([[0.8948, 0.9030, 0.9137, 
0.7567, 0.6118]])\n\n    Case 2: Input a batch of graphs\n\n    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.\n\n    >>> batch_g = dgl.batch([g1, g2])\n    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])\n    >>>\n    >>> maxpool(batch_g, batch_f)\n    tensor([[0.8948, 0.9030, 0.9137, 0.7567, 0.6118],\n            [0.5278, 0.6365, 0.9990, 0.9028, 0.8945]])\n    \"\"\"\n\n    def __init__(self):\n        super(MaxPooling, self).__init__()\n\n    def forward(self, graph, feat):\n        r\"\"\"Compute max pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            A DGLGraph or a batch of DGLGraphs.\n        feat : torch.Tensor\n            The input feature with shape :math:`(N, *)`, where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature with shape :math:`(B, *)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        with graph.local_scope():\n            graph.ndata[\"h\"] = feat\n            readout = max_nodes(graph, \"h\")\n            return readout\n\n\nclass SortPooling(nn.Module):\n    r\"\"\"Sort Pooling from `An End-to-End Deep Learning Architecture for Graph Classification\n    <https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__\n\n    It first sorts the node features in ascending order along the feature dimension,\n    and selects the sorted features of top-k nodes (ranked by the largest value of each node).\n\n    Parameters\n    ----------\n    k : int\n        The number of nodes to hold for each graph.\n\n    Notes\n    -----\n        Input: Could be one graph, or a batch of graphs. 
If using a batch of graphs,\n        make sure nodes in all graphs have the same feature size, and concatenate\n        nodes' feature together as the input.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import SortPooling\n    >>>\n    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges\n    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5\n    >>> g1_node_feats\n    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],\n            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],\n            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])\n    >>>\n    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges\n    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5\n    >>> g2_node_feats\n    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],\n            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],\n            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],\n            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])\n    >>>\n    >>> sortpool = SortPooling(k=2)  # create a sort pooling layer\n\n    Case 1: Input a single graph\n\n    >>> sortpool(g1, g1_node_feats)\n    tensor([[0.0699, 0.3637, 0.7567, 0.8948, 0.9137, 0.4755, 0.5197, 0.5725, 0.6825,\n             0.9030]])\n\n    Case 2: Input a batch of graphs\n\n    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.\n\n    >>> batch_g = dgl.batch([g1, g2])\n    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])\n    >>>\n    >>> sortpool(batch_g, batch_f)\n    tensor([[0.0699, 0.3637, 0.7567, 0.8948, 0.9137, 0.4755, 0.5197, 0.5725, 0.6825,\n             0.9030],\n            [0.2351, 0.5278, 0.6365, 0.8945, 0.9990, 0.2053, 0.2426, 0.4111, 0.5658,\n             0.9028]])\n    \"\"\"\n\n    def __init__(self, k):\n        super(SortPooling, self).__init__()\n        self.k = k\n\n    def forward(self, graph, feat):\n        r\"\"\"\n\n        Compute sort pooling.\n\n        Parameters\n      
  ----------\n        graph : DGLGraph\n            A DGLGraph or a batch of DGLGraphs.\n        feat : torch.Tensor\n            The input node feature with shape :math:`(N, D)`, where :math:`N` is the\n            number of nodes in the graph, and :math:`D` means the size of features.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature with shape :math:`(B, k * D)`, where :math:`B` refers\n            to the batch size of input graphs.\n        \"\"\"\n        with graph.local_scope():\n            # Sort the feature of each node in ascending order.\n            feat, _ = feat.sort(dim=-1)\n            graph.ndata[\"h\"] = feat\n            # Sort nodes according to their last features.\n            ret = topk_nodes(graph, \"h\", self.k, sortby=-1)[0].view(\n                -1, self.k * feat.shape[-1]\n            )\n            return ret\n\n\nclass GlobalAttentionPooling(nn.Module):\n    r\"\"\"Global Attention Pooling from `Gated Graph Sequence Neural Networks\n    <https://arxiv.org/abs/1511.05493>`__\n\n    .. 
math::\n        r^{(i)} = \\sum_{k=1}^{N_i}\\mathrm{softmax}\\left(f_{gate}\n        \\left(x^{(i)}_k\\right)\\right) f_{feat}\\left(x^{(i)}_k\\right)\n\n    Parameters\n    ----------\n    gate_nn : torch.nn.Module\n        A neural network that computes attention scores for each feature.\n    feat_nn : torch.nn.Module, optional\n        A neural network applied to each feature before combining them with attention\n        scores.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import GlobalAttentionPooling\n    >>>\n    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges\n    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5\n    >>> g1_node_feats\n    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],\n            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],\n            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])\n    >>>\n    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges\n    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5\n    >>> g2_node_feats\n    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],\n            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],\n            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],\n            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])\n    >>>\n    >>> gate_nn = th.nn.Linear(5, 1)  # the gate layer that maps node feature to scalar\n    >>> gap = GlobalAttentionPooling(gate_nn)  # create a Global Attention Pooling layer\n\n    Case 1: Input a single graph\n\n    >>> gap(g1, g1_node_feats)\n    tensor([[0.7410, 0.6032, 0.8111, 0.5942, 0.4762]],\n           grad_fn=<SegmentReduceBackward>)\n\n    Case 2: Input a batch of graphs\n\n    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.\n\n    >>> batch_g = dgl.batch([g1, g2])\n    >>> batch_f = th.cat([g1_node_feats, g2_node_feats], 0)\n    >>>\n    >>> gap(batch_g, batch_f)\n    
tensor([[0.7410, 0.6032, 0.8111, 0.5942, 0.4762],\n            [0.2417, 0.2743, 0.5054, 0.7356, 0.6146]],\n           grad_fn=<SegmentReduceBackward>)\n\n    Notes\n    -----\n    See our `GGNN example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/ggnn>`_\n    on how to use GatedGraphConv and GlobalAttentionPooling layer to build a Graph Neural\n    Network that can solve Sudoku.\n    \"\"\"\n\n    def __init__(self, gate_nn, feat_nn=None):\n        super(GlobalAttentionPooling, self).__init__()\n        self.gate_nn = gate_nn\n        self.feat_nn = feat_nn\n\n    def forward(self, graph, feat, get_attention=False):\n        r\"\"\"\n\n        Compute global attention pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            A DGLGraph or a batch of DGLGraphs.\n        feat : torch.Tensor\n            The input node feature with shape :math:`(N, D)` where :math:`N` is the\n            number of nodes in the graph, and :math:`D` means the size of features.\n        get_attention : bool, optional\n            Whether to return the attention values from gate_nn. Default to False.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature with shape :math:`(B, D)`, where :math:`B` refers\n            to the batch size.\n        torch.Tensor, optional\n            The attention values of shape :math:`(N, 1)`, where :math:`N` is the number of\n            nodes in the graph. 
This is returned only when :attr:`get_attention` is ``True``.\n        \"\"\"\n        with graph.local_scope():\n            gate = self.gate_nn(feat)\n            assert (\n                gate.shape[-1] == 1\n            ), \"The output of gate_nn should have size 1 at the last axis.\"\n            feat = self.feat_nn(feat) if self.feat_nn else feat\n\n            graph.ndata[\"gate\"] = gate\n            gate = softmax_nodes(graph, \"gate\")\n            graph.ndata.pop(\"gate\")\n\n            graph.ndata[\"r\"] = feat * gate\n            readout = sum_nodes(graph, \"r\")\n            graph.ndata.pop(\"r\")\n\n            if get_attention:\n                return readout, gate\n            else:\n                return readout\n\n\nclass Set2Set(nn.Module):\n    r\"\"\"Set2Set operator from `Order Matters: Sequence to sequence for sets\n    <https://arxiv.org/pdf/1511.06391.pdf>`__\n\n    For each individual graph in the batch, set2set computes\n\n    .. math::\n        q_t &= \\mathrm{LSTM} (q^*_{t-1})\n\n        \\alpha_{i,t} &= \\mathrm{softmax}(x_i \\cdot q_t)\n\n        r_t &= \\sum_{i=1}^N \\alpha_{i,t} x_i\n\n        q^*_t &= q_t \\Vert r_t\n\n    for this graph.\n\n    Parameters\n    ----------\n    input_dim : int\n        The size of each input sample.\n    n_iters : int\n        The number of iterations.\n    n_layers : int\n        The number of recurrent layers.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import Set2Set\n    >>>\n    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges\n    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5\n    >>> g1_node_feats\n    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],\n            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],\n            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])\n    >>>\n    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 
edges\n    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5\n    >>> g2_node_feats\n    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],\n            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],\n            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],\n            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])\n    >>>\n    >>> s2s = Set2Set(5, 2, 1)  # create a Set2Set layer(n_iters=2, n_layers=1)\n\n    Case 1: Input a single graph\n\n    >>> s2s(g1, g1_node_feats)\n        tensor([[-0.0235, -0.2291,  0.2654,  0.0376,  0.1349,  0.7560,  0.5822,  0.8199,\n                  0.5960,  0.4760]], grad_fn=<CatBackward>)\n\n    Case 2: Input a batch of graphs\n\n    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.\n\n    >>> batch_g = dgl.batch([g1, g2])\n    >>> batch_f = th.cat([g1_node_feats, g2_node_feats], 0)\n    >>>\n    >>> s2s(batch_g, batch_f)\n    tensor([[-0.0235, -0.2291,  0.2654,  0.0376,  0.1349,  0.7560,  0.5822,  0.8199,\n              0.5960,  0.4760],\n            [-0.0483, -0.2010,  0.2324,  0.0145,  0.1361,  0.2703,  0.3078,  0.5529,\n              0.6876,  0.6399]], grad_fn=<CatBackward>)\n\n    Notes\n    -----\n    Set2Set is widely used in molecular property predictions, see\n    `dgl-lifesci's MPNN example <https://github.com/awslabs/dgl-lifesci/blob/\n    ecd95c905479ec048097777039cf9a19cfdcf223/python/dgllife/model/model_zoo/\n    mpnn_predictor.py>`__\n    on how to use DGL's Set2Set layer in graph property prediction applications.\n    \"\"\"\n\n    def __init__(self, input_dim, n_iters, n_layers):\n        super(Set2Set, self).__init__()\n        self.input_dim = input_dim\n        self.output_dim = 2 * input_dim\n        self.n_iters = n_iters\n        self.n_layers = n_layers\n        self.lstm = th.nn.LSTM(self.output_dim, self.input_dim, n_layers)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"Reinitialize learnable parameters.\"\"\"\n        
self.lstm.reset_parameters()\n\n    def forward(self, graph, feat):\n        r\"\"\"\n        Compute set2set pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The input graph.\n        feat : torch.Tensor\n            The input feature with shape :math:`(N, D)` where  :math:`N` is the\n            number of nodes in the graph, and :math:`D` means the size of features.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature with shape :math:`(B, 2D)`, where :math:`B` refers to\n            the batch size, and :math:`D` means the size of features.\n        \"\"\"\n        with graph.local_scope():\n            batch_size = graph.batch_size\n\n            h = (\n                feat.new_zeros((self.n_layers, batch_size, self.input_dim)),\n                feat.new_zeros((self.n_layers, batch_size, self.input_dim)),\n            )\n\n            q_star = feat.new_zeros(batch_size, self.output_dim)\n\n            for _ in range(self.n_iters):\n                q, h = self.lstm(q_star.unsqueeze(0), h)\n                q = q.view(batch_size, self.input_dim)\n                e = (feat * broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True)\n                graph.ndata[\"e\"] = e\n                alpha = softmax_nodes(graph, \"e\")\n                graph.ndata[\"r\"] = feat * alpha\n                readout = sum_nodes(graph, \"r\")\n                q_star = th.cat([q, readout], dim=-1)\n\n            return q_star\n\n    def extra_repr(self):\n        \"\"\"Set the extra representation of the module.\n        which will come into effect when printing the model.\n        \"\"\"\n        summary = \"n_iters={n_iters}\"\n        return summary.format(**self.__dict__)\n\n\ndef _gen_mask(lengths_x, lengths_y, max_len_x, max_len_y):\n    \"\"\"Generate binary mask array for given x and y input pairs.\n\n    Parameters\n    ----------\n    lengths_x : Tensor\n        The int tensor indicates the segment 
information of x.\n    lengths_y : Tensor\n        The int tensor indicates the segment information of y.\n    max_len_x : int\n        The maximum element in lengths_x.\n    max_len_y : int\n        The maximum element in lengths_y.\n\n    Returns\n    -------\n    Tensor\n        the mask tensor with shape (batch_size, 1, max_len_x, max_len_y)\n    \"\"\"\n    device = lengths_x.device\n    # x_mask: (batch_size, max_len_x)\n    x_mask = th.arange(max_len_x, device=device).unsqueeze(\n        0\n    ) < lengths_x.unsqueeze(1)\n    # y_mask: (batch_size, max_len_y)\n    y_mask = th.arange(max_len_y, device=device).unsqueeze(\n        0\n    ) < lengths_y.unsqueeze(1)\n    # mask: (batch_size, 1, max_len_x, max_len_y)\n    mask = (x_mask.unsqueeze(-1) & y_mask.unsqueeze(-2)).unsqueeze(1)\n    return mask\n\n\nclass MultiHeadAttention(nn.Module):\n    r\"\"\"Multi-Head Attention block, used in Transformer, Set Transformer and so on\n\n    Parameters\n    ----------\n    d_model : int\n        The feature size (input and output) in Multi-Head Attention layer.\n    num_heads : int\n        The number of heads.\n    d_head : int\n        The hidden size per head.\n    d_ff : int\n        The inner hidden size in the Feed-Forward Neural Network.\n    dropouth : float\n        The dropout rate of each sublayer.\n    dropouta : float\n        The dropout rate of attention heads.\n\n    Notes\n    -----\n    This module was used in SetTransformer layer.\n    \"\"\"\n\n    def __init__(\n        self, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0\n    ):\n        super(MultiHeadAttention, self).__init__()\n        self.d_model = d_model\n        self.num_heads = num_heads\n        self.d_head = d_head\n        self.d_ff = d_ff\n        self.proj_q = nn.Linear(d_model, num_heads * d_head, bias=False)\n        self.proj_k = nn.Linear(d_model, num_heads * d_head, bias=False)\n        self.proj_v = nn.Linear(d_model, num_heads * d_head, bias=False)\n        
self.proj_o = nn.Linear(num_heads * d_head, d_model, bias=False)\n        self.ffn = nn.Sequential(\n            nn.Linear(d_model, d_ff),\n            nn.ReLU(),\n            nn.Dropout(dropouth),\n            nn.Linear(d_ff, d_model),\n        )\n        self.droph = nn.Dropout(dropouth)\n        self.dropa = nn.Dropout(dropouta)\n        self.norm_in = nn.LayerNorm(d_model)\n        self.norm_inter = nn.LayerNorm(d_model)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"Reinitialize learnable parameters.\"\"\"\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, x, mem, lengths_x, lengths_mem):\n        \"\"\"\n        Compute multi-head self-attention.\n\n        Parameters\n        ----------\n        x : torch.Tensor\n            The input tensor used to compute queries.\n        mem : torch.Tensor\n            The memory tensor used to compute keys and values.\n        lengths_x : list\n            The array of node numbers, used to segment x.\n        lengths_mem : list\n            The array of node numbers, used to segment mem.\n        \"\"\"\n        batch_size = len(lengths_x)\n        max_len_x = max(lengths_x)\n        max_len_mem = max(lengths_mem)\n        device = x.device\n        lengths_x = th.as_tensor(lengths_x, dtype=th.int64, device=device)\n        lengths_mem = th.as_tensor(lengths_mem, dtype=th.int64, device=device)\n\n        queries = self.proj_q(x).view(-1, self.num_heads, self.d_head)\n        keys = self.proj_k(mem).view(-1, self.num_heads, self.d_head)\n        values = self.proj_v(mem).view(-1, self.num_heads, self.d_head)\n\n        # padding to (B, max_len_x/mem, num_heads, d_head)\n        queries = F.pad_packed_tensor(queries, lengths_x, 0)\n        keys = F.pad_packed_tensor(keys, lengths_mem, 0)\n        values = F.pad_packed_tensor(values, lengths_mem, 0)\n\n        # attention score with shape (B, num_heads, 
max_len_x, max_len_mem)\n        e = th.einsum(\"bxhd,byhd->bhxy\", queries, keys)\n        # normalize\n        e = e / np.sqrt(self.d_head)\n\n        # generate mask\n        mask = _gen_mask(lengths_x, lengths_mem, max_len_x, max_len_mem)\n        e = e.masked_fill(mask == 0, -float(\"inf\"))\n\n        # apply softmax\n        alpha = th.softmax(e, dim=-1)\n        # the following line addresses the NaN issue, see\n        # https://github.com/dmlc/dgl/issues/2657\n        alpha = alpha.masked_fill(mask == 0, 0.0)\n\n        # sum of value weighted by alpha\n        out = th.einsum(\"bhxy,byhd->bxhd\", alpha, values)\n        # project to output\n        out = self.proj_o(\n            out.contiguous().view(\n                batch_size, max_len_x, self.num_heads * self.d_head\n            )\n        )\n        # pack tensor\n        out = F.pack_padded_tensor(out, lengths_x)\n\n        # intra norm\n        x = self.norm_in(x + out)\n\n        # inter norm\n        x = self.norm_inter(x + self.ffn(x))\n\n        return x\n\n\nclass SetAttentionBlock(nn.Module):\n    r\"\"\"SAB block from `Set Transformer: A Framework for Attention-based\n    Permutation-Invariant Neural Networks <https://arxiv.org/abs/1810.00825>`__\n\n    Parameters\n    ----------\n    d_model : int\n        The feature size (input and output) in Multi-Head Attention layer.\n    num_heads : int\n        The number of heads.\n    d_head : int\n        The hidden size per head.\n    d_ff : int\n        The inner hidden size in the Feed-Forward Neural Network.\n    dropouth : float\n        The dropout rate of each sublayer.\n    dropouta : float\n        The dropout rate of attention heads.\n\n    Notes\n    -----\n    This module was used in SetTransformer layer.\n    \"\"\"\n\n    def __init__(\n        self, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0\n    ):\n        super(SetAttentionBlock, self).__init__()\n        self.mha = MultiHeadAttention(\n            d_model,\n   
         num_heads,\n            d_head,\n            d_ff,\n            dropouth=dropouth,\n            dropouta=dropouta,\n        )\n\n    def forward(self, feat, lengths):\n        \"\"\"\n        Compute a Set Attention Block.\n\n        Parameters\n        ----------\n        feat : torch.Tensor\n            The input feature.\n        lengths : list\n            The array of node numbers, used to segment feat tensor.\n        \"\"\"\n        return self.mha(feat, feat, lengths, lengths)\n\n\nclass InducedSetAttentionBlock(nn.Module):\n    r\"\"\"ISAB block from `Set Transformer: A Framework for Attention-based\n    Permutation-Invariant Neural Networks <https://arxiv.org/abs/1810.00825>`__\n\n    Parameters\n    ----------\n    m : int\n        The number of induced vectors.\n    d_model : int\n        The feature size (input and output) in Multi-Head Attention layer.\n    num_heads : int\n        The number of heads.\n    d_head : int\n        The hidden size per head.\n    d_ff : int\n        The inner hidden size in the Feed-Forward Neural Network.\n    dropouth : float\n        The dropout rate of each sublayer.\n    dropouta : float\n        The dropout rate of attention heads.\n\n    Notes\n    -----\n    This module was used in SetTransformer layer.\n    \"\"\"\n\n    def __init__(\n        self, m, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0\n    ):\n        super(InducedSetAttentionBlock, self).__init__()\n        self.m = m\n        if m == 1:\n            dgl_warning(\n                \"if m is set to 1, the parameters corresponding to query and key \"\n                \"projections would not get updated during training.\"\n            )\n        self.d_model = d_model\n        self.inducing_points = nn.Parameter(th.FloatTensor(m, d_model))\n        self.mha = nn.ModuleList(\n            [\n                MultiHeadAttention(\n                    d_model,\n                    num_heads,\n                    d_head,\n               
     d_ff,\n                    dropouth=dropouth,\n                    dropouta=dropouta,\n                )\n                for _ in range(2)\n            ]\n        )\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"Reinitialize learnable parameters.\"\"\"\n        nn.init.xavier_uniform_(self.inducing_points)\n\n    def forward(self, feat, lengths):\n        \"\"\"\n        Compute an Induced Set Attention Block.\n\n        Parameters\n        ----------\n        feat : torch.Tensor\n            The input feature.\n        lengths : list\n            The array of node numbers, used to segment feat tensor.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature\n        \"\"\"\n        batch_size = len(lengths)\n        query = self.inducing_points.repeat(batch_size, 1)\n        memory = self.mha[0](query, feat, [self.m] * batch_size, lengths)\n        return self.mha[1](feat, memory, lengths, [self.m] * batch_size)\n\n    def extra_repr(self):\n        \"\"\"Set the extra representation of the module.\n        which will come into effect when printing the model.\n        \"\"\"\n        shape_str = \"({}, {})\".format(\n            self.inducing_points.shape[0], self.inducing_points.shape[1]\n        )\n        return \"InducedVector: \" + shape_str\n\n\nclass PMALayer(nn.Module):\n    r\"\"\"Pooling by Multihead Attention from `Set Transformer: A Framework for Attention-based\n    Permutation-Invariant Neural Networks <https://arxiv.org/abs/1810.00825>`__\n\n    Parameters\n    ----------\n    k : int\n        The number of seed vectors.\n    d_model : int\n        The feature size (input and output) in Multi-Head Attention layer.\n    num_heads : int\n        The number of heads.\n    d_head : int\n        The hidden size per head.\n    d_ff : int\n        The kernel size in FFN (Positionwise Feed-Forward Network) layer.\n    dropouth : float\n        The dropout rate of each sublayer.\n    
dropouta : float\n        The dropout rate of attention heads.\n\n    Notes\n    -----\n    This module was used in SetTransformer layer.\n    \"\"\"\n\n    def __init__(\n        self, k, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0\n    ):\n        super(PMALayer, self).__init__()\n        self.k = k\n        if k == 1:\n            dgl_warning(\n                \"if k is set to 1, the parameters corresponding to query and key \"\n                \"projections would not get updated during training.\"\n            )\n        self.d_model = d_model\n        self.seed_vectors = nn.Parameter(th.FloatTensor(k, d_model))\n        self.mha = MultiHeadAttention(\n            d_model,\n            num_heads,\n            d_head,\n            d_ff,\n            dropouth=dropouth,\n            dropouta=dropouta,\n        )\n        self.ffn = nn.Sequential(\n            nn.Linear(d_model, d_ff),\n            nn.ReLU(),\n            nn.Dropout(dropouth),\n            nn.Linear(d_ff, d_model),\n        )\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"Reinitialize learnable parameters.\"\"\"\n        nn.init.xavier_uniform_(self.seed_vectors)\n\n    def forward(self, feat, lengths):\n        \"\"\"\n        Compute Pooling by Multihead Attention.\n\n        Parameters\n        ----------\n        feat : torch.Tensor\n            The input feature.\n        lengths : list\n            The array of node numbers, used to segment feat tensor.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature\n        \"\"\"\n        batch_size = len(lengths)\n        query = self.seed_vectors.repeat(batch_size, 1)\n        return self.mha(query, self.ffn(feat), [self.k] * batch_size, lengths)\n\n    def extra_repr(self):\n        \"\"\"Set the extra representation of the module.\n        which will come into effect when printing the model.\n        \"\"\"\n        shape_str = \"({}, {})\".format(\n          
  self.seed_vectors.shape[0], self.seed_vectors.shape[1]\n        )\n        return \"SeedVector: \" + shape_str\n\n\nclass SetTransformerEncoder(nn.Module):\n    r\"\"\"The Encoder module from `Set Transformer: A Framework for Attention-based\n    Permutation-Invariant Neural Networks <https://arxiv.org/pdf/1810.00825.pdf>`__\n\n    Parameters\n    ----------\n    d_model : int\n        The hidden size of the model.\n    n_heads : int\n        The number of heads.\n    d_head : int\n        The hidden size of each head.\n    d_ff : int\n        The kernel size in FFN (Positionwise Feed-Forward Network) layer.\n    n_layers : int\n        The number of layers.\n    block_type : str\n        Building block type: 'sab' (Set Attention Block) or 'isab' (Induced\n        Set Attention Block).\n    m : int or None\n        The number of induced vectors in ISAB Block. Set to None if block type\n        is 'sab'.\n    dropouth : float\n        The dropout rate of each sublayer.\n    dropouta : float\n        The dropout rate of attention heads.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import SetTransformerEncoder\n    >>>\n    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges\n    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5\n    >>> g1_node_feats\n    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],\n            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],\n            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])\n    >>>\n    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges\n    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5\n    >>> g2_node_feats\n    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],\n            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],\n            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],\n            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])\n    >>>\n    >>> set_trans_enc = SetTransformerEncoder(5, 4, 4, 20)  # create 
a settrans encoder.\n\n    Case 1: Input a single graph\n\n    >>> set_trans_enc(g1, g1_node_feats)\n    tensor([[ 0.1262, -1.9081,  0.7287,  0.1678,  0.8854],\n            [-0.0634, -1.1996,  0.6955, -0.9230,  1.4904],\n            [-0.9972, -0.7924,  0.6907, -0.5221,  1.6211]],\n           grad_fn=<NativeLayerNormBackward>)\n\n    Case 2: Input a batch of graphs\n\n    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.\n\n    >>> batch_g = dgl.batch([g1, g2])\n    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])\n    >>>\n    >>> set_trans_enc(batch_g, batch_f)\n    tensor([[ 0.1262, -1.9081,  0.7287,  0.1678,  0.8854],\n            [-0.0634, -1.1996,  0.6955, -0.9230,  1.4904],\n            [-0.9972, -0.7924,  0.6907, -0.5221,  1.6211],\n            [-0.7973, -1.3203,  0.0634,  0.5237,  1.5306],\n            [-0.4497, -1.0920,  0.8470, -0.8030,  1.4977],\n            [-0.4940, -1.6045,  0.2363,  0.4885,  1.3737],\n            [-0.9840, -1.0913, -0.0099,  0.4653,  1.6199]],\n           grad_fn=<NativeLayerNormBackward>)\n\n    See Also\n    --------\n    SetTransformerDecoder\n\n    Notes\n    -----\n    SetTransformerEncoder is not a readout layer; the tensor it returns is a nodewise\n    representation instead of a graphwise representation, and the SetTransformerDecoder\n    would return a graph readout tensor.\n    \"\"\"\n\n    def __init__(\n        self,\n        d_model,\n        n_heads,\n        d_head,\n        d_ff,\n        n_layers=1,\n        block_type=\"sab\",\n        m=None,\n        dropouth=0.0,\n        dropouta=0.0,\n    ):\n        super(SetTransformerEncoder, self).__init__()\n        self.n_layers = n_layers\n        self.block_type = block_type\n        self.m = m\n        layers = []\n        if block_type == \"isab\" and m is None:\n            raise KeyError(\n                \"The number of inducing points is not specified in ISAB block.\"\n            )\n\n        for _ in range(n_layers):\n     
       if block_type == \"sab\":\n                layers.append(\n                    SetAttentionBlock(\n                        d_model,\n                        n_heads,\n                        d_head,\n                        d_ff,\n                        dropouth=dropouth,\n                        dropouta=dropouta,\n                    )\n                )\n            elif block_type == \"isab\":\n                layers.append(\n                    InducedSetAttentionBlock(\n                        m,\n                        d_model,\n                        n_heads,\n                        d_head,\n                        d_ff,\n                        dropouth=dropouth,\n                        dropouta=dropouta,\n                    )\n                )\n            else:\n                raise KeyError(\n                    \"Unrecognized block type {}: we only support sab/isab\"\n                )\n\n        self.layers = nn.ModuleList(layers)\n\n    def forward(self, graph, feat):\n        \"\"\"\n        Compute the Encoder part of Set Transformer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The input graph.\n        feat : torch.Tensor\n            The input feature with shape :math:`(N, D)`, where :math:`N` is the\n            number of nodes in the graph.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature with shape :math:`(N, D)`.\n        \"\"\"\n        lengths = graph.batch_num_nodes()\n        for layer in self.layers:\n            feat = layer(feat, lengths)\n        return feat\n\n\nclass SetTransformerDecoder(nn.Module):\n    r\"\"\"The Decoder module from `Set Transformer: A Framework for Attention-based\n    Permutation-Invariant Neural Networks <https://arxiv.org/pdf/1810.00825.pdf>`__\n\n    Parameters\n    ----------\n    d_model : int\n        Hidden size of the model.\n    num_heads : int\n        The number of heads.\n    d_head : int\n        Hidden 
size of each head.\n    d_ff : int\n        Kernel size in FFN (Positionwise Feed-Forward Network) layer.\n    n_layers : int\n        The number of layers.\n    k : int\n        The number of seed vectors in PMA (Pooling by Multihead Attention) layer.\n    dropouth : float\n        Dropout rate of each sublayer.\n    dropouta : float\n        Dropout rate of attention heads.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import SetTransformerDecoder\n    >>>\n    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges\n    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5\n    >>> g1_node_feats\n    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],\n            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],\n            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])\n    >>>\n    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges\n    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5\n    >>> g2_node_feats\n    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],\n            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],\n            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],\n            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])\n    >>>\n    >>> set_trans_dec = SetTransformerDecoder(5, 4, 4, 20, 1, 3)  # define the layer\n\n    Case 1: Input a single graph\n\n    >>> set_trans_dec(g1, g1_node_feats)\n    tensor([[-0.5538,  1.8726, -1.0470,  0.0276, -0.2994, -0.6317,  1.6754, -1.3189,\n              0.2291,  0.0461, -0.4042,  0.8387, -1.7091,  1.0845,  0.1902]],\n           grad_fn=<ViewBackward>)\n\n    Case 2: Input a batch of graphs\n\n    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.\n\n    >>> batch_g = dgl.batch([g1, g2])\n    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])\n    >>>\n    >>> set_trans_dec(batch_g, batch_f)\n    tensor([[-0.5538,  1.8726, -1.0470,  0.0276, -0.2994, -0.6317,  1.6754, -1.3189,\n         
     0.2291,  0.0461, -0.4042,  0.8387, -1.7091,  1.0845,  0.1902],\n            [-0.5511,  1.8869, -1.0156,  0.0028, -0.3231, -0.6305,  1.6845, -1.3105,\n              0.2136,  0.0428, -0.3820,  0.8043, -1.7138,  1.1126,  0.1789]],\n           grad_fn=<ViewBackward>)\n\n    See Also\n    --------\n    SetTransformerEncoder\n    \"\"\"\n\n    def __init__(\n        self,\n        d_model,\n        num_heads,\n        d_head,\n        d_ff,\n        n_layers,\n        k,\n        dropouth=0.0,\n        dropouta=0.0,\n    ):\n        super(SetTransformerDecoder, self).__init__()\n        self.n_layers = n_layers\n        self.k = k\n        self.d_model = d_model\n        self.pma = PMALayer(\n            k,\n            d_model,\n            num_heads,\n            d_head,\n            d_ff,\n            dropouth=dropouth,\n            dropouta=dropouta,\n        )\n        layers = []\n        for _ in range(n_layers):\n            layers.append(\n                SetAttentionBlock(\n                    d_model,\n                    num_heads,\n                    d_head,\n                    d_ff,\n                    dropouth=dropouth,\n                    dropouta=dropouta,\n                )\n            )\n\n        self.layers = nn.ModuleList(layers)\n\n    def forward(self, graph, feat):\n        \"\"\"\n        Compute the decoder part of Set Transformer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The input graph.\n        feat : torch.Tensor\n            The input feature with shape :math:`(N, D)`, where :math:`N` is the\n            number of nodes in the graph, and :math:`D` means the size of features.\n\n        Returns\n        -------\n        torch.Tensor\n            The output feature with shape :math:`(B, k * D)`, where :math:`B` refers to\n            the batch size and :math:`k` is the number of seed vectors.\n        \"\"\"\n        len_pma = graph.batch_num_nodes()\n        len_sab = [self.k] * graph.batch_size\n        feat = self.pma(feat, len_pma)\n   
     for layer in self.layers:\n            feat = layer(feat, len_sab)\n        return feat.view(graph.batch_size, self.k * self.d_model)\n\n\nclass WeightAndSum(nn.Module):\n    \"\"\"Compute importance weights for atoms and perform a weighted sum.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input atom feature size\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import WeightAndSum\n    >>>\n    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges\n    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5\n    >>> g1_node_feats\n    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],\n            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],\n            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])\n    >>>\n    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges\n    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5\n    >>> g2_node_feats\n    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],\n            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],\n            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],\n            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])\n    >>>\n    >>> weight_and_sum = WeightAndSum(5)  # create a weight and sum layer (in_feats=5)\n\n    Case 1: Input a single graph\n\n    >>> weight_and_sum(g1, g1_node_feats)\n    tensor([[1.2194, 0.9490, 1.3235, 0.9609, 0.7710]],\n           grad_fn=<SegmentReduceBackward>)\n\n    Case 2: Input a batch of graphs\n\n    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.\n\n    >>> batch_g = dgl.batch([g1, g2])\n    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])\n    >>>\n    >>> weight_and_sum(batch_g, batch_f)\n    tensor([[1.2194, 0.9490, 1.3235, 0.9609, 0.7710],\n            [0.5322, 0.5840, 1.0729, 1.3665, 1.2360]],\n           grad_fn=<SegmentReduceBackward>)\n\n    Notes\n    -----\n    
WeightAndSum module is commonly used in molecular property prediction networks,\n    see the GCN predictor in `dgl-lifesci <https://github.com/awslabs/dgl-lifesci/blob/\n    ae0491431804611ba466ff413f69d435789dbfd5/python/dgllife/model/model_zoo/\n    gcn_predictor.py>`__\n    to understand how to use WeightAndSum layer to get the graph readout output.\n    \"\"\"\n\n    def __init__(self, in_feats):\n        super(WeightAndSum, self).__init__()\n        self.in_feats = in_feats\n        self.atom_weighting = nn.Sequential(\n            nn.Linear(in_feats, 1), nn.Sigmoid()\n        )\n\n    def forward(self, g, feats):\n        \"\"\"Compute molecule representations out of atom representations\n\n        Parameters\n        ----------\n        g : DGLGraph\n            DGLGraph with batch size B for processing multiple molecules in parallel\n        feats : FloatTensor of shape (N, self.in_feats)\n            Representations for all atoms in the molecules\n            * N is the total number of atoms in all molecules\n\n        Returns\n        -------\n        FloatTensor of shape (B, self.in_feats)\n            Representations for B molecules\n        \"\"\"\n        with g.local_scope():\n            g.ndata[\"h\"] = feats\n            g.ndata[\"w\"] = self.atom_weighting(g.ndata[\"h\"])\n            h_g_sum = sum_nodes(g, \"h\", \"w\")\n\n        return h_g_sum\n"
  },
  {
    "path": "python/dgl/nn/pytorch/gt/__init__.py",
    "content": "\"\"\"Torch modules for Graph Transformer.\"\"\"\n\nfrom .biased_mha import BiasedMHA\nfrom .degree_encoder import DegreeEncoder\nfrom .egt import EGTLayer\nfrom .graphormer import GraphormerLayer\nfrom .lap_pos_encoder import LapPosEncoder\nfrom .path_encoder import PathEncoder\nfrom .spatial_encoder import SpatialEncoder, SpatialEncoder3d\n"
  },
  {
    "path": "python/dgl/nn/pytorch/gt/biased_mha.py",
    "content": "\"\"\"Biased Multi-head Attention\"\"\"\n\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass BiasedMHA(nn.Module):\n    r\"\"\"Dense Multi-Head Attention Module with Graph Attention Bias.\n\n    Compute attention between nodes with attention bias obtained from graph\n    structures, as introduced in `Do Transformers Really Perform Bad for\n    Graph Representation? <https://arxiv.org/pdf/2106.05234>`__\n\n    .. math::\n\n        \\text{Attn}=\\text{softmax}(\\dfrac{QK^T}{\\sqrt{d}} \\circ b)\n\n    :math:`Q` and :math:`K` are feature representations of nodes. :math:`d`\n    is the corresponding :attr:`feat_size`. :math:`b` is attention bias, which\n    can be additive or multiplicative according to the operator :math:`\\circ`.\n\n    Parameters\n    ----------\n    feat_size : int\n        Feature size.\n    num_heads : int\n        Number of attention heads, by which :attr:`feat_size` is divisible.\n    bias : bool, optional\n        If True, it uses bias for linear projection. Default: True.\n    attn_bias_type : str, optional\n        The type of attention bias used for modifying attention. Selected from\n        'add' or 'mul'. Default: 'add'.\n\n        * 'add' is for additive attention bias.\n        * 'mul' is for multiplicative attention bias.\n    attn_drop : float, optional\n        Dropout probability on attention weights. 
Default: 0.1.\n\n    Examples\n    --------\n    >>> import torch as th\n    >>> from dgl.nn import BiasedMHA\n\n    >>> ndata = th.rand(16, 100, 512)\n    >>> bias = th.rand(16, 100, 100, 8)\n    >>> net = BiasedMHA(feat_size=512, num_heads=8)\n    >>> out = net(ndata, bias)\n    \"\"\"\n\n    def __init__(\n        self,\n        feat_size,\n        num_heads,\n        bias=True,\n        attn_bias_type=\"add\",\n        attn_drop=0.1,\n    ):\n        super().__init__()\n        self.feat_size = feat_size\n        self.num_heads = num_heads\n        self.head_dim = feat_size // num_heads\n        assert (\n            self.head_dim * num_heads == feat_size\n        ), \"feat_size must be divisible by num_heads\"\n        self.scaling = self.head_dim**-0.5\n        self.attn_bias_type = attn_bias_type\n\n        self.q_proj = nn.Linear(feat_size, feat_size, bias=bias)\n        self.k_proj = nn.Linear(feat_size, feat_size, bias=bias)\n        self.v_proj = nn.Linear(feat_size, feat_size, bias=bias)\n        self.out_proj = nn.Linear(feat_size, feat_size, bias=bias)\n\n        self.dropout = nn.Dropout(p=attn_drop)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"\n        Initialize parameters of projection matrices, the same settings as in\n        the original implementation of the paper.\n        \"\"\"\n        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-0.5)\n        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-0.5)\n        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-0.5)\n\n        nn.init.xavier_uniform_(self.out_proj.weight)\n        if self.out_proj.bias is not None:\n            nn.init.constant_(self.out_proj.bias, 0.0)\n\n    def forward(self, ndata, attn_bias=None, attn_mask=None):\n        \"\"\"Forward computation.\n\n        Parameters\n        ----------\n        ndata : torch.Tensor\n            A 3D input tensor. 
Shape: (batch_size, N, :attr:`feat_size`), where\n            N is the maximum number of nodes.\n        attn_bias : torch.Tensor, optional\n            The attention bias used for attention modification. Shape:\n            (batch_size, N, N, :attr:`num_heads`).\n        attn_mask : torch.Tensor, optional\n            The attention mask used for avoiding computation on invalid\n            positions, where invalid positions are indicated by `True` values.\n            Shape: (batch_size, N, N). Note: For rows corresponding to\n            nonexistent nodes, make sure at least one entry is set to `False` to\n            prevent obtaining NaNs with softmax.\n\n        Returns\n        -------\n        y : torch.Tensor\n            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)\n        \"\"\"\n        q_h = self.q_proj(ndata).transpose(0, 1)\n        k_h = self.k_proj(ndata).transpose(0, 1)\n        v_h = self.v_proj(ndata).transpose(0, 1)\n        bsz, N, _ = ndata.shape\n        q_h = (\n            q_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1)\n            * self.scaling\n        )\n        k_h = k_h.reshape(N, bsz * self.num_heads, self.head_dim).permute(\n            1, 2, 0\n        )\n        v_h = v_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(\n            0, 1\n        )\n\n        attn_weights = (\n            th.bmm(q_h, k_h)\n            .transpose(0, 2)\n            .reshape(N, N, bsz, self.num_heads)\n            .transpose(0, 2)\n        )\n\n        if attn_bias is not None:\n            if self.attn_bias_type == \"add\":\n                attn_weights += attn_bias\n            else:\n                attn_weights *= attn_bias\n        if attn_mask is not None:\n            attn_weights[attn_mask.to(th.bool)] = float(\"-inf\")\n        attn_weights = F.softmax(\n            attn_weights.transpose(0, 2)\n            .reshape(N, N, bsz * self.num_heads)\n            .transpose(0, 2),\n            
dim=2,\n        )\n\n        attn_weights = self.dropout(attn_weights)\n\n        attn = th.bmm(attn_weights, v_h).transpose(0, 1)\n\n        attn = self.out_proj(\n            attn.reshape(N, bsz, self.feat_size).transpose(0, 1)\n        )\n\n        return attn\n"
  },
  {
    "path": "python/dgl/nn/pytorch/gt/degree_encoder.py",
    "content": "\"\"\"Degree Encoder\"\"\"\n\nimport torch as th\nimport torch.nn as nn\n\n\nclass DegreeEncoder(nn.Module):\n    r\"\"\"Degree Encoder, as introduced in\n    `Do Transformers Really Perform Bad for Graph Representation?\n    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__\n\n    This module is a learnable degree embedding module.\n\n    Parameters\n    ----------\n    max_degree : int\n        Upper bound of degrees to be encoded.\n        Each degree will be clamped into the range [0, ``max_degree``].\n    embedding_dim : int\n        Output dimension of embedding vectors.\n    direction : str, optional\n        Degrees of which direction to be encoded,\n        selected from ``in``, ``out`` and ``both``.\n        ``both`` encodes degrees from both directions\n        and output the addition of them.\n        Default : ``both``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> from dgl.nn import DegreeEncoder\n    >>> import torch as th\n    >>> from torch.nn.utils.rnn import pad_sequence\n\n    >>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))\n    >>> g2 = dgl.graph(([0,1], [1,0]))\n    >>> in_degree = pad_sequence([g1.in_degrees(), g2.in_degrees()], batch_first=True)\n    >>> out_degree = pad_sequence([g1.out_degrees(), g2.out_degrees()], batch_first=True)\n    >>> print(in_degree.shape)\n    torch.Size([2, 4])\n    >>> degree_encoder = DegreeEncoder(5, 16)\n    >>> degree_embedding = degree_encoder(th.stack((in_degree, out_degree)))\n    >>> print(degree_embedding.shape)\n    torch.Size([2, 4, 16])\n    \"\"\"\n\n    def __init__(self, max_degree, embedding_dim, direction=\"both\"):\n        super(DegreeEncoder, self).__init__()\n        self.direction = direction\n        if direction == \"both\":\n            self.encoder1 = nn.Embedding(\n                max_degree + 1, embedding_dim, padding_idx=0\n            )\n            self.encoder2 = nn.Embedding(\n                
max_degree + 1, embedding_dim, padding_idx=0\n            )\n        else:\n            self.encoder = nn.Embedding(\n                max_degree + 1, embedding_dim, padding_idx=0\n            )\n        self.max_degree = max_degree\n\n    def forward(self, degrees):\n        \"\"\"\n        Parameters\n        ----------\n        degrees : Tensor\n            If :attr:`direction` is ``both``, it should be stacked in degrees and out degrees\n            of the batched graph with zero padding, a tensor of shape :math:`(2, B, N)`.\n            Otherwise, it should be zero-padded in degrees or out degrees of the batched\n            graph, a tensor of shape :math:`(B, N)`, where :math:`B` is the batch size\n            of the batched graph, and :math:`N` is the maximum number of nodes.\n\n        Returns\n        -------\n        Tensor\n            Return degree embedding vectors of shape :math:`(B, N, d)`,\n            where :math:`d` is :attr:`embedding_dim`.\n        \"\"\"\n        degrees = th.clamp(degrees, min=0, max=self.max_degree)\n\n        if self.direction == \"in\":\n            assert len(degrees.shape) == 2\n            degree_embedding = self.encoder(degrees)\n        elif self.direction == \"out\":\n            assert len(degrees.shape) == 2\n            degree_embedding = self.encoder(degrees)\n        elif self.direction == \"both\":\n            assert len(degrees.shape) == 3 and degrees.shape[0] == 2\n            degree_embedding = self.encoder1(degrees[0]) + self.encoder2(\n                degrees[1]\n            )\n        else:\n            raise ValueError(\n                f'Supported direction options: \"in\", \"out\" and \"both\", '\n                f\"but got {self.direction}\"\n            )\n        return degree_embedding\n"
  },
  {
    "path": "python/dgl/nn/pytorch/gt/egt.py",
    "content": "\"\"\"EGT Layer\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass EGTLayer(nn.Module):\n    r\"\"\"EGTLayer for Edge-augmented Graph Transformer (EGT), as introduced in\n    `Global Self-Attention as a Replacement for Graph Convolution\n    Reference `<https://arxiv.org/pdf/2108.03348.pdf>`_\n\n    Parameters\n    ----------\n    feat_size : int\n        Node feature size.\n    edge_feat_size : int\n        Edge feature size.\n    num_heads : int\n        Number of attention heads, by which :attr: `feat_size` is divisible.\n    num_virtual_nodes : int\n        Number of virtual nodes.\n    dropout : float, optional\n        Dropout probability. Default: 0.0.\n    attn_dropout : float, optional\n        Attention dropout probability. Default: 0.0.\n    activation : callable activation layer, optional\n        Activation function. Default: nn.ELU().\n    edge_update : bool, optional\n        Whether to update the edge embedding. Default: True.\n\n    Examples\n    --------\n    >>> import torch as th\n    >>> from dgl.nn import EGTLayer\n\n    >>> batch_size = 16\n    >>> num_nodes = 100\n    >>> feat_size, edge_feat_size = 128, 32\n    >>> nfeat = th.rand(batch_size, num_nodes, feat_size)\n    >>> efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)\n    >>> net = EGTLayer(\n            feat_size=feat_size,\n            edge_feat_size=edge_feat_size,\n            num_heads=8,\n            num_virtual_nodes=4,\n        )\n    >>> out = net(nfeat, efeat)\n    \"\"\"\n\n    def __init__(\n        self,\n        feat_size,\n        edge_feat_size,\n        num_heads,\n        num_virtual_nodes,\n        dropout=0,\n        attn_dropout=0,\n        activation=nn.ELU(),\n        edge_update=True,\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        self.num_virtual_nodes = num_virtual_nodes\n        self.edge_update = edge_update\n\n        assert (\n            feat_size % 
num_heads == 0\n        ), \"feat_size must be divisible by num_heads\"\n        self.dot_dim = feat_size // num_heads\n        self.mha_ln_h = nn.LayerNorm(feat_size)\n        self.mha_ln_e = nn.LayerNorm(edge_feat_size)\n        self.edge_input = nn.Linear(edge_feat_size, num_heads)\n        self.qkv_proj = nn.Linear(feat_size, feat_size * 3)\n        self.gate = nn.Linear(edge_feat_size, num_heads)\n        self.attn_dropout = nn.Dropout(attn_dropout)\n        self.node_output = nn.Linear(feat_size, feat_size)\n        self.mha_dropout_h = nn.Dropout(dropout)\n\n        self.node_ffn = nn.Sequential(\n            nn.LayerNorm(feat_size),\n            nn.Linear(feat_size, feat_size),\n            activation,\n            nn.Linear(feat_size, feat_size),\n            nn.Dropout(dropout),\n        )\n\n        if self.edge_update:\n            self.edge_output = nn.Linear(num_heads, edge_feat_size)\n            self.mha_dropout_e = nn.Dropout(dropout)\n            self.edge_ffn = nn.Sequential(\n                nn.LayerNorm(edge_feat_size),\n                nn.Linear(edge_feat_size, edge_feat_size),\n                activation,\n                nn.Linear(edge_feat_size, edge_feat_size),\n                nn.Dropout(dropout),\n            )\n\n    def forward(self, nfeat, efeat, mask=None):\n        \"\"\"Forward computation. Note: :attr:`nfeat` and :attr:`efeat` should be\n        padded with embedding of virtual nodes if :attr:`num_virtual_nodes` > 0,\n        while :attr:`mask` should be padded with `0` values for virtual nodes.\n        The padding should be put at the beginning.\n\n        Parameters\n        ----------\n        nfeat : torch.Tensor\n            A 3D input tensor. 
Shape: (batch_size, N, :attr:`feat_size`), where N\n            is the sum of the maximum number of nodes and the number of virtual nodes.\n        efeat : torch.Tensor\n            Edge embedding used for attention computation and self update.\n            Shape: (batch_size, N, N, :attr:`edge_feat_size`).\n        mask : torch.Tensor, optional\n            The attention mask used for avoiding computation on invalid\n            positions, where valid positions are indicated by `0` and\n            invalid positions are indicated by `-inf`.\n            Shape: (batch_size, N, N). Default: None.\n\n        Returns\n        -------\n        nfeat : torch.Tensor\n            The output node embedding. Shape: (batch_size, N, :attr:`feat_size`).\n        efeat : torch.Tensor, optional\n            The output edge embedding. Shape: (batch_size, N, N, :attr:`edge_feat_size`).\n            It is returned only if :attr:`edge_update` is True.\n        \"\"\"\n        nfeat_r1 = nfeat\n        efeat_r1 = efeat\n\n        nfeat_ln = self.mha_ln_h(nfeat)\n        efeat_ln = self.mha_ln_e(efeat)\n        qkv = self.qkv_proj(nfeat_ln)\n        e_bias = self.edge_input(efeat_ln)\n        gates = self.gate(efeat_ln)\n        bsz, N, _ = qkv.shape\n        q_h, k_h, v_h = qkv.view(bsz, N, -1, self.num_heads).split(\n            self.dot_dim, dim=2\n        )\n        attn_hat = torch.einsum(\"bldh,bmdh->blmh\", q_h, k_h)\n        attn_hat = attn_hat.clamp(-5, 5) + e_bias\n\n        if mask is None:\n            gates = torch.sigmoid(gates)\n            attn_tild = F.softmax(attn_hat, dim=2) * gates\n        else:\n            gates = torch.sigmoid(gates + mask.unsqueeze(-1))\n            attn_tild = F.softmax(attn_hat + mask.unsqueeze(-1), dim=2) * gates\n\n        attn_tild = self.attn_dropout(attn_tild)\n        v_attn = torch.einsum(\"blmh,bmkh->blkh\", attn_tild, v_h)\n\n        # Scale the aggregated values by degree.\n        degrees = torch.sum(gates, dim=2, keepdim=True)\n  
      degree_scalers = torch.log(1 + degrees)\n        degree_scalers[:, : self.num_virtual_nodes] = 1.0\n        v_attn = v_attn * degree_scalers\n\n        v_attn = v_attn.reshape(bsz, N, self.num_heads * self.dot_dim)\n        nfeat = self.node_output(v_attn)\n\n        nfeat = self.mha_dropout_h(nfeat)\n        nfeat.add_(nfeat_r1)\n        nfeat_r2 = nfeat\n        nfeat = self.node_ffn(nfeat)\n        nfeat.add_(nfeat_r2)\n\n        if self.edge_update:\n            efeat = self.edge_output(attn_hat)\n            efeat = self.mha_dropout_e(efeat)\n            efeat.add_(efeat_r1)\n            efeat_r2 = efeat\n            efeat = self.edge_ffn(efeat)\n            efeat.add_(efeat_r2)\n\n            return nfeat, efeat\n\n        return nfeat\n"
  },
  {
    "path": "python/dgl/nn/pytorch/gt/graphormer.py",
    "content": "\"\"\"Graphormer Layer\"\"\"\n\nimport torch.nn as nn\n\nfrom .biased_mha import BiasedMHA\n\n\nclass GraphormerLayer(nn.Module):\n    r\"\"\"Graphormer Layer with Dense Multi-Head Attention, as introduced\n    in `Do Transformers Really Perform Bad for Graph Representation?\n    <https://arxiv.org/pdf/2106.05234>`__\n\n    Parameters\n    ----------\n    feat_size : int\n        Feature size.\n    hidden_size : int\n        Hidden size of feedforward layers.\n    num_heads : int\n        Number of attention heads, by which :attr:`feat_size` is divisible.\n    attn_bias_type : str, optional\n        The type of attention bias used for modifying attention. Selected from\n        'add' or 'mul'. Default: 'add'.\n\n        * 'add' is for additive attention bias.\n        * 'mul' is for multiplicative attention bias.\n    norm_first : bool, optional\n        If True, it performs layer normalization before attention and\n        feedforward operations. Otherwise, it applies layer normalization\n        afterwards. Default: False.\n    dropout : float, optional\n        Dropout probability. Default: 0.1.\n    attn_dropout : float, optional\n        Attention dropout probability. Default: 0.1.\n    activation : callable activation layer, optional\n        Activation function. 
Default: nn.ReLU().\n\n    Examples\n    --------\n    >>> import torch as th\n    >>> from dgl.nn import GraphormerLayer\n\n    >>> batch_size = 16\n    >>> num_nodes = 100\n    >>> feat_size = 512\n    >>> num_heads = 8\n    >>> nfeat = th.rand(batch_size, num_nodes, feat_size)\n    >>> bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)\n    >>> net = GraphormerLayer(\n            feat_size=feat_size,\n            hidden_size=2048,\n            num_heads=num_heads\n        )\n    >>> out = net(nfeat, bias)\n    \"\"\"\n\n    def __init__(\n        self,\n        feat_size,\n        hidden_size,\n        num_heads,\n        attn_bias_type=\"add\",\n        norm_first=False,\n        dropout=0.1,\n        attn_dropout=0.1,\n        activation=nn.ReLU(),\n    ):\n        super().__init__()\n\n        self.norm_first = norm_first\n\n        self.attn = BiasedMHA(\n            feat_size=feat_size,\n            num_heads=num_heads,\n            attn_bias_type=attn_bias_type,\n            attn_drop=attn_dropout,\n        )\n        self.ffn = nn.Sequential(\n            nn.Linear(feat_size, hidden_size),\n            activation,\n            nn.Dropout(p=dropout),\n            nn.Linear(hidden_size, feat_size),\n            nn.Dropout(p=dropout),\n        )\n\n        self.dropout = nn.Dropout(p=dropout)\n        self.attn_layer_norm = nn.LayerNorm(feat_size)\n        self.ffn_layer_norm = nn.LayerNorm(feat_size)\n\n    def forward(self, nfeat, attn_bias=None, attn_mask=None):\n        \"\"\"Forward computation.\n\n        Parameters\n        ----------\n        nfeat : torch.Tensor\n            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where\n            N is the maximum number of nodes.\n        attn_bias : torch.Tensor, optional\n            The attention bias used for attention modification. 
Shape:\n            (batch_size, N, N, :attr:`num_heads`).\n        attn_mask : torch.Tensor, optional\n            The attention mask used for avoiding computation on invalid\n            positions, where invalid positions are indicated by `True` values.\n            Shape: (batch_size, N, N). Note: For rows corresponding to\n            unexisting nodes, make sure at least one entry is set to `False` to\n            prevent obtaining NaNs with softmax.\n\n        Returns\n        -------\n        y : torch.Tensor\n            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)\n        \"\"\"\n        residual = nfeat\n        if self.norm_first:\n            nfeat = self.attn_layer_norm(nfeat)\n        nfeat = self.attn(nfeat, attn_bias, attn_mask)\n        nfeat = self.dropout(nfeat)\n        nfeat = residual + nfeat\n        if not self.norm_first:\n            nfeat = self.attn_layer_norm(nfeat)\n        residual = nfeat\n        if self.norm_first:\n            nfeat = self.ffn_layer_norm(nfeat)\n        nfeat = self.ffn(nfeat)\n        nfeat = residual + nfeat\n        if not self.norm_first:\n            nfeat = self.ffn_layer_norm(nfeat)\n        return nfeat\n"
  },
  {
    "path": "python/dgl/nn/pytorch/gt/lap_pos_encoder.py",
    "content": "\"\"\"Laplacian Positional Encoder\"\"\"\n\nimport torch as th\nimport torch.nn as nn\n\n\nclass LapPosEncoder(nn.Module):\n    r\"\"\"Laplacian Positional Encoder (LPE), as introduced in\n    `GraphGPS: General Powerful Scalable Graph Transformers\n    <https://arxiv.org/abs/2205.12454>`__\n\n    This module is a learned laplacian positional encoding module using\n    Transformer or DeepSet.\n\n    Parameters\n    ----------\n    model_type : str\n        Encoder model type for LPE, can only be \"Transformer\" or \"DeepSet\".\n    num_layer : int\n        Number of layers in Transformer/DeepSet Encoder.\n    k : int\n        Number of smallest non-trivial eigenvectors.\n    dim : int\n        Output size of final laplacian encoding.\n    n_head : int, optional\n        Number of heads in Transformer Encoder.\n        Default : 1.\n    batch_norm : bool, optional\n        If True, apply batch normalization on raw laplacian positional\n        encoding. Default : False.\n    num_post_layer : int, optional\n        If num_post_layer > 0, apply an MLP of ``num_post_layer`` layers after\n        pooling. 
Default : 0.\n\n    Example\n    -------\n    >>> import dgl\n    >>> from dgl import LapPE\n    >>> from dgl.nn import LapPosEncoder\n\n    >>> transform = LapPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)\n    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))\n    >>> g = transform(g)\n    >>> eigvals, eigvecs = g.ndata['eigval'], g.ndata['eigvec']\n    >>> transformer_encoder = LapPosEncoder(\n            model_type=\"Transformer\", num_layer=3, k=5, dim=16, n_head=4\n        )\n    >>> pos_encoding = transformer_encoder(eigvals, eigvecs)\n    >>> deepset_encoder = LapPosEncoder(\n            model_type=\"DeepSet\", num_layer=3, k=5, dim=16, num_post_layer=2\n        )\n    >>> pos_encoding = deepset_encoder(eigvals, eigvecs)\n    \"\"\"\n\n    def __init__(\n        self,\n        model_type,\n        num_layer,\n        k,\n        dim,\n        n_head=1,\n        batch_norm=False,\n        num_post_layer=0,\n    ):\n        super(LapPosEncoder, self).__init__()\n        self.model_type = model_type\n        self.linear = nn.Linear(2, dim)\n\n        if self.model_type == \"Transformer\":\n            encoder_layer = nn.TransformerEncoderLayer(\n                d_model=dim, nhead=n_head, batch_first=True\n            )\n            self.pe_encoder = nn.TransformerEncoder(\n                encoder_layer, num_layers=num_layer\n            )\n        elif self.model_type == \"DeepSet\":\n            layers = []\n            if num_layer == 1:\n                layers.append(nn.ReLU())\n            else:\n                self.linear = nn.Linear(2, 2 * dim)\n                layers.append(nn.ReLU())\n                for _ in range(num_layer - 2):\n                    layers.append(nn.Linear(2 * dim, 2 * dim))\n                    layers.append(nn.ReLU())\n                layers.append(nn.Linear(2 * dim, dim))\n                layers.append(nn.ReLU())\n            self.pe_encoder = nn.Sequential(*layers)\n        else:\n            
raise ValueError(\n                f\"model_type '{model_type}' is not allowed, must be \"\n                \"'Transformer' or 'DeepSet'.\"\n            )\n\n        if batch_norm:\n            self.raw_norm = nn.BatchNorm1d(k)\n        else:\n            self.raw_norm = None\n\n        if num_post_layer > 0:\n            layers = []\n            if num_post_layer == 1:\n                layers.append(nn.Linear(dim, dim))\n                layers.append(nn.ReLU())\n            else:\n                layers.append(nn.Linear(dim, 2 * dim))\n                layers.append(nn.ReLU())\n                for _ in range(num_post_layer - 2):\n                    layers.append(nn.Linear(2 * dim, 2 * dim))\n                    layers.append(nn.ReLU())\n                layers.append(nn.Linear(2 * dim, dim))\n                layers.append(nn.ReLU())\n            self.post_mlp = nn.Sequential(*layers)\n        else:\n            self.post_mlp = None\n\n    def forward(self, eigvals, eigvecs):\n        r\"\"\"\n        Parameters\n        ----------\n        eigvals : Tensor\n            Laplacian Eigenvalues of shape :math:`(N, k)`, k different\n            eigenvalues repeat N times, can be obtained by using `LaplacianPE`.\n        eigvecs : Tensor\n            Laplacian Eigenvectors of shape :math:`(N, k)`, can be obtained by\n            using `LaplacianPE`.\n\n        Returns\n        -------\n        Tensor\n            Return the laplacian positional encodings of shape :math:`(N, d)`,\n            where :math:`N` is the number of nodes in the input graph,\n            :math:`d` is :attr:`dim`.\n        \"\"\"\n        pos_encoding = th.cat(\n            (eigvecs.unsqueeze(2), eigvals.unsqueeze(2)), dim=2\n        ).float()\n        empty_mask = th.isnan(pos_encoding)\n\n        pos_encoding[empty_mask] = 0\n        if self.raw_norm:\n            pos_encoding = self.raw_norm(pos_encoding)\n        pos_encoding = self.linear(pos_encoding)\n\n        if self.model_type == 
\"Transformer\":\n            pos_encoding = self.pe_encoder(\n                src=pos_encoding, src_key_padding_mask=empty_mask[:, :, 1]\n            )\n        else:\n            pos_encoding = self.pe_encoder(pos_encoding)\n\n        # Remove masked sequences.\n        pos_encoding[empty_mask[:, :, 1]] = 0\n\n        # Sum pooling.\n        pos_encoding = th.sum(pos_encoding, 1, keepdim=False)\n\n        # MLP post pooling.\n        if self.post_mlp:\n            pos_encoding = self.post_mlp(pos_encoding)\n\n        return pos_encoding\n"
  },
  {
    "path": "python/dgl/nn/pytorch/gt/path_encoder.py",
    "content": "\"\"\"Path Encoder\"\"\"\nimport torch as th\nimport torch.nn as nn\n\n\nclass PathEncoder(nn.Module):\n    r\"\"\"Path Encoder, as introduced in Edge Encoding of\n    `Do Transformers Really Perform Bad for Graph Representation?\n    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__\n\n    This module is a learnable path embedding module and encodes the shortest\n    path between each pair of nodes as attention bias.\n\n    Parameters\n    ----------\n    max_len : int\n        Maximum number of edges in each path to be encoded.\n        Exceeding part of each path will be truncated, i.e.\n        truncating edges with serial number no less than :attr:`max_len`.\n    feat_dim : int\n        Dimension of edge features in the input graph.\n    num_heads : int, optional\n        Number of attention heads if multi-head attention mechanism is applied.\n        Default : 1.\n\n    Examples\n    --------\n    >>> import torch as th\n    >>> import dgl\n    >>> from dgl.nn import PathEncoder\n    >>> from dgl import shortest_dist\n\n    >>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))\n    >>> edata = th.rand(8, 16)\n    >>> # Since shortest_dist returns -1 for unreachable node pairs,\n    >>> # edata[-1] should be filled with zero padding.\n    >>> edata = th.cat(\n            (edata, th.zeros(1, 16)), dim=0\n        )\n    >>> dist, path = shortest_dist(g, root=None, return_paths=True)\n    >>> path_data = edata[path[:, :, :2]]\n    >>> path_encoder = PathEncoder(2, 16, num_heads=8)\n    >>> out = path_encoder(dist.unsqueeze(0), path_data.unsqueeze(0))\n    >>> print(out.shape)\n    torch.Size([1, 4, 4, 8])\n    \"\"\"\n\n    def __init__(self, max_len, feat_dim, num_heads=1):\n        super().__init__()\n        self.max_len = max_len\n        self.feat_dim = feat_dim\n        self.num_heads = num_heads\n        self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)\n\n    def 
forward(self, dist, path_data):\n        \"\"\"\n        Parameters\n        ----------\n        dist : Tensor\n            Shortest path distance matrix of the batched graph with zero padding,\n            of shape :math:`(B, N, N)`, where :math:`B` is the batch size of\n            the batched graph, and :math:`N` is the maximum number of nodes.\n        path_data : Tensor\n            Edge feature along the shortest path with zero padding, of shape\n            :math:`(B, N, N, L, d)`, where :math:`L` is the maximum length of\n            the shortest paths, and :math:`d` is :attr:`feat_dim`.\n\n        Returns\n        -------\n        torch.Tensor\n            Return attention bias as path encoding, of shape\n            :math:`(B, N, N, H)`, where :math:`B` is the batch size of\n            the input graph, :math:`N` is the maximum number of nodes, and\n            :math:`H` is :attr:`num_heads`.\n        \"\"\"\n        shortest_distance = th.clamp(dist, min=1, max=self.max_len)\n        edge_embedding = self.embedding_table.weight.reshape(\n            self.max_len, self.num_heads, -1\n        )\n        path_encoding = th.div(\n            th.einsum(\"bxyld,lhd->bxyh\", path_data, edge_embedding).permute(\n                3, 0, 1, 2\n            ),\n            shortest_distance,\n        ).permute(1, 2, 3, 0)\n        return path_encoding\n"
  },
  {
    "path": "python/dgl/nn/pytorch/gt/spatial_encoder.py",
    "content": "\"\"\"Spatial Encoder\"\"\"\n\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef gaussian(x, mean, std):\n    \"\"\"compute gaussian basis kernel function\"\"\"\n    const_pi = 3.14159\n    a = (2 * const_pi) ** 0.5\n    return th.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)\n\n\nclass SpatialEncoder(nn.Module):\n    r\"\"\"Spatial Encoder, as introduced in\n    `Do Transformers Really Perform Bad for Graph Representation?\n    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__\n\n    This module is a learnable spatial embedding module, which encodes\n    the shortest distance between each node pair for attention bias.\n\n    Parameters\n    ----------\n    max_dist : int\n        Upper bound of the shortest path distance\n        between each node pair to be encoded.\n        All distance will be clamped into the range `[0, max_dist]`.\n    num_heads : int, optional\n        Number of attention heads if multi-head attention mechanism is applied.\n        Default : 1.\n\n    Examples\n    --------\n    >>> import torch as th\n    >>> import dgl\n    >>> from dgl.nn import SpatialEncoder\n    >>> from dgl import shortest_dist\n\n    >>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))\n    >>> g2 = dgl.graph(([0,1], [1,0]))\n    >>> n1, n2 = g1.num_nodes(), g2.num_nodes()\n    >>> # use -1 padding since shortest_dist returns -1 for unreachable node pairs\n    >>> dist = -th.ones((2, 4, 4), dtype=th.long)\n    >>> dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)\n    >>> dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)\n    >>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)\n    >>> out = spatial_encoder(dist)\n    >>> print(out.shape)\n    torch.Size([2, 4, 4, 8])\n    \"\"\"\n\n    def __init__(self, max_dist, num_heads=1):\n        super().__init__()\n        self.max_dist = max_dist\n        self.num_heads = 
num_heads\n        # deactivate node pair between which the distance is -1\n        self.embedding_table = nn.Embedding(\n            max_dist + 2, num_heads, padding_idx=0\n        )\n\n    def forward(self, dist):\n        \"\"\"\n        Parameters\n        ----------\n        dist : Tensor\n            Shortest path distance of the batched graph with -1 padding, a tensor\n            of shape :math:`(B, N, N)`, where :math:`B` is the batch size of\n            the batched graph, and :math:`N` is the maximum number of nodes.\n\n        Returns\n        -------\n        torch.Tensor\n            Return attention bias as spatial encoding of shape\n            :math:`(B, N, N, H)`, where :math:`H` is :attr:`num_heads`.\n        \"\"\"\n        spatial_encoding = self.embedding_table(\n            th.clamp(\n                dist,\n                min=-1,\n                max=self.max_dist,\n            )\n            + 1\n        )\n        return spatial_encoding\n\n\nclass SpatialEncoder3d(nn.Module):\n    r\"\"\"3D Spatial Encoder, as introduced in\n    `One Transformer Can Understand Both 2D & 3D Molecular Data\n    <https://arxiv.org/pdf/2210.01765.pdf>`__\n\n    This module encodes pair-wise relation between node pair :math:`(i,j)` in\n    the 3D geometric space, according to the Gaussian Basis Kernel function:\n\n    :math:`\\psi _{(i,j)} ^k = \\frac{1}{\\sqrt{2\\pi} \\lvert \\sigma^k \\rvert}\n    \\exp{\\left ( -\\frac{1}{2} \\left( \\frac{\\gamma_{(i,j)} \\lvert \\lvert r_i -\n    r_j \\rvert \\rvert + \\beta_{(i,j)} - \\mu^k}{\\lvert \\sigma^k \\rvert} \\right)\n    ^2 \\right)}，k=1,...,K,`\n\n    where :math:`K` is the number of Gaussian Basis kernels. :math:`r_i` is the\n    Cartesian coordinate of node :math:`i`.\n    :math:`\\gamma_{(i,j)}, \\beta_{(i,j)}` are learnable scaling factors and\n    biases determined by node types. 
:math:`\\mu^k, \\sigma^k` are learnable\n    centers and standard deviations of the Gaussian Basis kernels.\n\n    Parameters\n    ----------\n    num_kernels : int\n        Number of Gaussian Basis Kernels to be applied. Each Gaussian Basis\n        Kernel contains a learnable kernel center and a learnable standard\n        deviation.\n    num_heads : int, optional\n        Number of attention heads if multi-head attention mechanism is applied.\n        Default : 1.\n    max_node_type : int, optional\n        Maximum number of node types. Each node type has a corresponding\n        learnable scaling factor and a bias. Default : 100.\n\n    Examples\n    --------\n    >>> import torch as th\n    >>> import dgl\n    >>> from dgl.nn import SpatialEncoder3d\n\n    >>> coordinate = th.rand(1, 4, 3)\n    >>> node_type = th.tensor([[1, 0, 2, 1]])\n    >>> spatial_encoder = SpatialEncoder3d(num_kernels=4,\n    ...                                    num_heads=8,\n    ...                                    max_node_type=3)\n    >>> out = spatial_encoder(coordinate, node_type=node_type)\n    >>> print(out.shape)\n    torch.Size([1, 4, 4, 8])\n    \"\"\"\n\n    def __init__(self, num_kernels, num_heads=1, max_node_type=100):\n        super().__init__()\n        self.num_kernels = num_kernels\n        self.num_heads = num_heads\n        self.max_node_type = max_node_type\n        self.means = nn.Parameter(th.empty(num_kernels))\n        self.stds = nn.Parameter(th.empty(num_kernels))\n        self.linear_layer_1 = nn.Linear(num_kernels, num_kernels)\n        self.linear_layer_2 = nn.Linear(num_kernels, num_heads)\n        # There are 2 * max_node_type + 3 pairs of gamma and beta parameters:\n        # 1. Parameters at position 0 are for default gamma/beta when no node\n        #    type is given\n        # 2. Parameters at position 1 to max_node_type+1 are for src node types.\n        #    (position 1 is for padded unexisting nodes)\n        # 3. 
Parameters at position max_node_type+2 to 2*max_node_type+2 are\n        #    for tgt node types. (position max_node_type+2 is for padded)\n        #    unexisting nodes)\n        self.gamma = nn.Embedding(2 * max_node_type + 3, 1, padding_idx=0)\n        self.beta = nn.Embedding(2 * max_node_type + 3, 1, padding_idx=0)\n\n        nn.init.uniform_(self.means, 0, 3)\n        nn.init.uniform_(self.stds, 0, 3)\n        nn.init.constant_(self.gamma.weight, 1)\n        nn.init.constant_(self.beta.weight, 0)\n\n    def forward(self, coord, node_type=None):\n        \"\"\"\n        Parameters\n        ----------\n        coord : torch.Tensor\n            3D coordinates of nodes in shape :math:`(B, N, 3)`, where :math:`B`\n            is the batch size, :math:`N`: is the maximum number of nodes.\n        node_type : torch.Tensor, optional\n            Node type ids of nodes. Default : None.\n\n            * If specified, :attr:`node_type` should be a tensor in shape\n              :math:`(B, N,)`. 
The scaling factors in gaussian kernels of each\n              pair of nodes are determined by their node types.\n            * Otherwise, :attr:`node_type` will be set to zeros of the same\n              shape by default.\n\n        Returns\n        -------\n        torch.Tensor\n            Return attention bias as 3D spatial encoding of shape\n            :math:`(B, N, N, H)`, where :math:`H` is :attr:`num_heads`.\n        \"\"\"\n        bsz, N = coord.shape[:2]\n        euc_dist = th.cdist(coord, coord, p=2.0)  # shape: [B, n, n]\n        if node_type is None:\n            node_type = th.zeros([bsz, N, N, 2], device=coord.device).long()\n        else:\n            src_node_type = node_type.unsqueeze(-1).repeat(1, 1, N)\n            tgt_node_type = node_type.unsqueeze(1).repeat(1, N, 1)\n            node_type = th.stack(\n                [src_node_type + 2, tgt_node_type + self.max_node_type + 3],\n                dim=-1,\n            )  # shape: [B, n, n, 2]\n\n        # scaled euclidean distance\n        gamma = self.gamma(node_type).sum(dim=-2)  # shape: [B, n, n, 1]\n        beta = self.beta(node_type).sum(dim=-2)  # shape: [B, n, n, 1]\n        euc_dist = gamma * euc_dist.unsqueeze(-1) + beta  # shape: [B, n, n, 1]\n        # gaussian basis kernel\n        euc_dist = euc_dist.expand(-1, -1, -1, self.num_kernels)\n        gaussian_kernel = gaussian(\n            euc_dist, self.means, self.stds.abs() + 1e-2\n        )  # shape: [B, n, n, K]\n        # linear projection\n        encoding = self.linear_layer_1(gaussian_kernel)\n        encoding = F.gelu(encoding)\n        encoding = self.linear_layer_2(encoding)  # shape: [B, n, n, H]\n\n        return encoding\n"
  },
  {
    "path": "python/dgl/nn/pytorch/hetero.py",
    "content": "\"\"\"Heterograph NN modules\"\"\"\nfrom functools import partial\n\nimport torch as th\nimport torch.nn as nn\n\nfrom ...base import DGLError\n\n__all__ = [\"HeteroGraphConv\", \"HeteroLinear\", \"HeteroEmbedding\"]\n\n\nclass HeteroGraphConv(nn.Module):\n    r\"\"\"A generic module for computing convolution on heterogeneous graphs.\n\n    The heterograph convolution applies sub-modules on their associating\n    relation graphs, which reads the features from source nodes and writes the\n    updated ones to destination nodes. If multiple relations have the same\n    destination node types, their results are aggregated by the specified method.\n    If the relation graph has no edge, the corresponding module will not be called.\n\n    Pseudo-code:\n\n    .. code::\n\n        outputs = {nty : [] for nty in g.dsttypes}\n        # Apply sub-modules on their associating relation graphs in parallel\n        for relation in g.canonical_etypes:\n            stype, etype, dtype = relation\n            dstdata = relation_submodule(g[relation], ...)\n            outputs[dtype].append(dstdata)\n\n        # Aggregate the results for each destination node type\n        rsts = {}\n        for ntype, ntype_outputs in outputs.items():\n            if len(ntype_outputs) != 0:\n                rsts[ntype] = aggregate(ntype_outputs)\n        return rsts\n\n    Examples\n    --------\n\n    Create a heterograph with three types of relations and nodes.\n\n    >>> import dgl\n    >>> g = dgl.heterograph({\n    ...     ('user', 'follows', 'user') : edges1,\n    ...     ('user', 'plays', 'game') : edges2,\n    ...     ('store', 'sells', 'game')  : edges3})\n\n    Create a ``HeteroGraphConv`` that applies different convolution modules to\n    different relations. Note that the modules for ``'follows'`` and ``'plays'``\n    do not share weights.\n\n    >>> import dgl.nn.pytorch as dglnn\n    >>> conv = dglnn.HeteroGraphConv({\n    ...     
'follows' : dglnn.GraphConv(...),\n    ...     'plays' : dglnn.GraphConv(...),\n    ...     'sells' : dglnn.SAGEConv(...)},\n    ...     aggregate='sum')\n\n    Call forward with some ``'user'`` features. This computes new features for both\n    ``'user'`` and ``'game'`` nodes.\n\n    >>> import torch as th\n    >>> h1 = {'user' : th.randn((g.num_nodes('user'), 5))}\n    >>> h2 = conv(g, h1)\n    >>> print(h2.keys())\n    dict_keys(['user', 'game'])\n\n    Call forward with both ``'user'`` and ``'store'`` features. Because both the\n    ``'plays'`` and ``'sells'`` relations will update the ``'game'`` features,\n    their results are aggregated by the specified method (i.e., summation here).\n\n    >>> f1 = {'user' : ..., 'store' : ...}\n    >>> f2 = conv(g, f1)\n    >>> print(f2.keys())\n    dict_keys(['user', 'game'])\n\n    Call forward with some ``'store'`` features. This only computes new features\n    for ``'game'`` nodes.\n\n    >>> g1 = {'store' : ...}\n    >>> g2 = conv(g, g1)\n    >>> print(g2.keys())\n    dict_keys(['game'])\n\n    Call forward with a pair of inputs is allowed and each submodule will also\n    be invoked with a pair of inputs.\n\n    >>> x_src = {'user' : ..., 'store' : ...}\n    >>> x_dst = {'user' : ..., 'game' : ...}\n    >>> y_dst = conv(g, (x_src, x_dst))\n    >>> print(y_dst.keys())\n    dict_keys(['user', 'game'])\n\n    Parameters\n    ----------\n    mods : dict[str, nn.Module]\n        Modules associated with every edge types. 
The forward function of each\n        module must have a `DGLGraph` object as the first argument, and\n        its second argument is either a tensor object representing the node\n        features or a pair of tensor object representing the source and destination\n        node features.\n    aggregate : str, callable, optional\n        Method for aggregating node features generated by different relations.\n        Allowed string values are 'sum', 'max', 'min', 'mean', 'stack'.\n        The 'stack' aggregation is performed along the second dimension, whose order\n        is deterministic.\n        User can also customize the aggregator by providing a callable instance.\n        For example, aggregation by summation is equivalent to the follows:\n\n        .. code::\n\n            def my_agg_func(tensors, dsttype):\n                # tensors: is a list of tensors to aggregate\n                # dsttype: string name of the destination node type for which the\n                #          aggregation is performed\n                stacked = torch.stack(tensors, dim=0)\n                return torch.sum(stacked, dim=0)\n\n    Attributes\n    ----------\n    mods : dict[str, nn.Module]\n        Modules associated with every edge types.\n    \"\"\"\n\n    def __init__(self, mods, aggregate=\"sum\"):\n        super(HeteroGraphConv, self).__init__()\n        self.mod_dict = mods\n        mods = {str(k): v for k, v in mods.items()}\n        # Register as child modules\n        self.mods = nn.ModuleDict(mods)\n        # PyTorch ModuleDict doesn't have get() method, so I have to store two\n        # dictionaries so that I can index with both canonical edge type and\n        # edge type with the get() method.\n        # Do not break if graph has 0-in-degree nodes.\n        # Because there is no general rule to add self-loop for heterograph.\n        for _, v in self.mods.items():\n            set_allow_zero_in_degree_fn = getattr(\n                v, \"set_allow_zero_in_degree\", 
None\n            )\n            if callable(set_allow_zero_in_degree_fn):\n                set_allow_zero_in_degree_fn(True)\n        if isinstance(aggregate, str):\n            self.agg_fn = get_aggregate_fn(aggregate)\n        else:\n            self.agg_fn = aggregate\n\n    def _get_module(self, etype):\n        mod = self.mod_dict.get(etype, None)\n        if mod is not None:\n            return mod\n        if isinstance(etype, tuple):\n            # etype is canonical\n            _, etype, _ = etype\n            return self.mod_dict[etype]\n        raise KeyError(\"Cannot find module with edge type %s\" % etype)\n\n    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):\n        \"\"\"Forward computation\n\n        Invoke the forward function with each module and aggregate their results.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            Graph data.\n        inputs : dict[str, Tensor] or pair of dict[str, Tensor]\n            Input node features.\n        mod_args : dict[str, tuple[any]], optional\n            Extra positional arguments for the sub-modules.\n        mod_kwargs : dict[str, dict[str, any]], optional\n            Extra key-word arguments for the sub-modules.\n\n        Returns\n        -------\n        dict[str, Tensor]\n            Output representations for every types of nodes.\n        \"\"\"\n        if mod_args is None:\n            mod_args = {}\n        if mod_kwargs is None:\n            mod_kwargs = {}\n        outputs = {nty: [] for nty in g.dsttypes}\n        if isinstance(inputs, tuple) or g.is_block:\n            if isinstance(inputs, tuple):\n                src_inputs, dst_inputs = inputs\n            else:\n                src_inputs = inputs\n                dst_inputs = {\n                    k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()\n                }\n\n            for stype, etype, dtype in g.canonical_etypes:\n                rel_graph = g[stype, etype, dtype]\n      
          if stype not in src_inputs or dtype not in dst_inputs:\n                    continue\n                dstdata = self._get_module((stype, etype, dtype))(\n                    rel_graph,\n                    (src_inputs[stype], dst_inputs[dtype]),\n                    *mod_args.get(etype, ()),\n                    **mod_kwargs.get(etype, {})\n                )\n                outputs[dtype].append(dstdata)\n        else:\n            for stype, etype, dtype in g.canonical_etypes:\n                rel_graph = g[stype, etype, dtype]\n                if stype not in inputs:\n                    continue\n                dstdata = self._get_module((stype, etype, dtype))(\n                    rel_graph,\n                    (inputs[stype], inputs[dtype]),\n                    *mod_args.get(etype, ()),\n                    **mod_kwargs.get(etype, {})\n                )\n                outputs[dtype].append(dstdata)\n        rsts = {}\n        for nty, alist in outputs.items():\n            if len(alist) != 0:\n                rsts[nty] = self.agg_fn(alist, nty)\n        return rsts\n\n\ndef _max_reduce_func(inputs, dim):\n    return th.max(inputs, dim=dim)[0]\n\n\ndef _min_reduce_func(inputs, dim):\n    return th.min(inputs, dim=dim)[0]\n\n\ndef _sum_reduce_func(inputs, dim):\n    return th.sum(inputs, dim=dim)\n\n\ndef _mean_reduce_func(inputs, dim):\n    return th.mean(inputs, dim=dim)\n\n\ndef _stack_agg_func(inputs, dsttype):  # pylint: disable=unused-argument\n    if len(inputs) == 0:\n        return None\n    return th.stack(inputs, dim=1)\n\n\ndef _agg_func(inputs, dsttype, fn):  # pylint: disable=unused-argument\n    if len(inputs) == 0:\n        return None\n    stacked = th.stack(inputs, dim=0)\n    return fn(stacked, dim=0)\n\n\ndef get_aggregate_fn(agg):\n    \"\"\"Internal function to get the aggregation function for node data\n    generated from different relations.\n\n    Parameters\n    ----------\n    agg : str\n        Method for aggregating 
node features generated by different relations.\n        Allowed values are 'sum', 'max', 'min', 'mean', 'stack'.\n\n    Returns\n    -------\n    callable\n        Aggregator function that takes a list of tensors to aggregate\n        and returns one aggregated tensor.\n    \"\"\"\n    if agg == \"sum\":\n        fn = _sum_reduce_func\n    elif agg == \"max\":\n        fn = _max_reduce_func\n    elif agg == \"min\":\n        fn = _min_reduce_func\n    elif agg == \"mean\":\n        fn = _mean_reduce_func\n    elif agg == \"stack\":\n        fn = None  # will not be called\n    else:\n        raise DGLError(\n            \"Invalid cross type aggregator. Must be one of \"\n            '\"sum\", \"max\", \"min\", \"mean\" or \"stack\". But got \"%s\"' % agg\n        )\n    if agg == \"stack\":\n        return _stack_agg_func\n    else:\n        return partial(_agg_func, fn=fn)\n\n\nclass HeteroLinear(nn.Module):\n    \"\"\"Apply linear transformations on heterogeneous inputs.\n\n    Parameters\n    ----------\n    in_size : dict[key, int]\n        Input feature size for heterogeneous inputs. A key can be a string or a tuple of strings.\n    out_size : int\n        Output feature size.\n    bias : bool, optional\n        If True, learns a bias term. 
Defaults: ``True``.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl.nn import HeteroLinear\n\n    >>> layer = HeteroLinear({'user': 1, ('user', 'follows', 'user'): 2}, 3)\n    >>> in_feats = {'user': torch.randn(2, 1), ('user', 'follows', 'user'): torch.randn(3, 2)}\n    >>> out_feats = layer(in_feats)\n    >>> print(out_feats['user'].shape)\n    torch.Size([2, 3])\n    >>> print(out_feats[('user', 'follows', 'user')].shape)\n    torch.Size([3, 3])\n    \"\"\"\n\n    def __init__(self, in_size, out_size, bias=True):\n        super(HeteroLinear, self).__init__()\n\n        self.linears = nn.ModuleDict()\n        for typ, typ_in_size in in_size.items():\n            self.linears[str(typ)] = nn.Linear(typ_in_size, out_size, bias=bias)\n\n    def forward(self, feat):\n        \"\"\"Forward function\n\n        Parameters\n        ----------\n        feat : dict[key, Tensor]\n            Heterogeneous input features. It maps keys to features.\n\n        Returns\n        -------\n        dict[key, Tensor]\n            Transformed features.\n        \"\"\"\n        out_feat = dict()\n        for typ, typ_feat in feat.items():\n            out_feat[typ] = self.linears[str(typ)](typ_feat)\n\n        return out_feat\n\n\nclass HeteroEmbedding(nn.Module):\n    \"\"\"Create a heterogeneous embedding table.\n\n    It internally contains multiple ``torch.nn.Embedding`` with different dictionary sizes.\n\n    Parameters\n    ----------\n    num_embeddings : dict[key, int]\n        Size of the dictionaries. 
A key can be a string or a tuple of strings.\n    embedding_dim : int\n        Size of each embedding vector.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl.nn import HeteroEmbedding\n\n    >>> layer = HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)\n    >>> # Get the heterogeneous embedding table\n    >>> embeds = layer.weight\n    >>> print(embeds['user'].shape)\n    torch.Size([2, 4])\n    >>> print(embeds[('user', 'follows', 'user')].shape)\n    torch.Size([3, 4])\n\n    >>> # Get the embeddings for a subset\n    >>> input_ids = {'user': torch.LongTensor([0]),\n    ...              ('user', 'follows', 'user'): torch.LongTensor([0, 2])}\n    >>> embeds = layer(input_ids)\n    >>> print(embeds['user'].shape)\n    torch.Size([1, 4])\n    >>> print(embeds[('user', 'follows', 'user')].shape)\n    torch.Size([2, 4])\n    \"\"\"\n\n    def __init__(self, num_embeddings, embedding_dim):\n        super(HeteroEmbedding, self).__init__()\n\n        self.embeds = nn.ModuleDict()\n        self.raw_keys = dict()\n        for typ, typ_num_rows in num_embeddings.items():\n            self.embeds[str(typ)] = nn.Embedding(typ_num_rows, embedding_dim)\n            self.raw_keys[str(typ)] = typ\n\n    @property\n    def weight(self):\n        \"\"\"Get the heterogeneous embedding table\n\n        Returns\n        -------\n        dict[key, Tensor]\n            Heterogeneous embedding table\n        \"\"\"\n        return {\n            self.raw_keys[typ]: emb.weight for typ, emb in self.embeds.items()\n        }\n\n    def reset_parameters(self):\n        \"\"\"\n        Use the xavier method in nn.init module to make the parameters uniformly distributed\n        \"\"\"\n        for typ in self.embeds.keys():\n            nn.init.xavier_uniform_(self.embeds[typ].weight)\n\n    def forward(self, input_ids):\n        \"\"\"Forward function\n\n        Parameters\n        ----------\n        input_ids : dict[key, Tensor]\n   
         The row IDs to retrieve embeddings. It maps a key to key-specific IDs.\n\n        Returns\n        -------\n        dict[key, Tensor]\n            The retrieved embeddings.\n        \"\"\"\n        embeds = dict()\n        for typ, typ_ids in input_ids.items():\n            embeds[typ] = self.embeds[str(typ)](typ_ids)\n\n        return embeds\n"
  },
  {
    "path": "python/dgl/nn/pytorch/linear.py",
    "content": "\"\"\"Various commonly used linear modules\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, W0235\nimport math\n\nimport torch\nimport torch.nn as nn\n\nfrom ...ops import gather_mm, segment_mm\n\n__all__ = [\"TypedLinear\"]\n\n\nclass TypedLinear(nn.Module):\n    r\"\"\"Linear transformation according to types.\n\n    For each sample of the input batch :math:`x \\in X`, apply linear transformation\n    :math:`xW_t`, where :math:`t` is the type of :math:`x`.\n\n    The module supports two regularization methods (basis-decomposition and\n    block-diagonal-decomposition) proposed by \"`Modeling Relational Data\n    with Graph Convolutional Networks <https://arxiv.org/abs/1703.06103>`__\"\n\n    The basis regularization decomposes :math:`W_t` by:\n\n    .. math::\n\n       W_t^{(l)} = \\sum_{b=1}^B a_{tb}^{(l)}V_b^{(l)}\n\n    where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined\n    with coefficients :math:`a_{tb}^{(l)}`.\n\n    The block-diagonal-decomposition regularization decomposes :math:`W_t` into :math:`B`\n    block-diagonal matrices. We refer to :math:`B` as the number of bases:\n\n    .. math::\n\n       W_t^{(l)} = \\oplus_{b=1}^B Q_{tb}^{(l)}\n\n    where :math:`B` is the number of bases, :math:`Q_{tb}^{(l)}` are block\n    bases with shape :math:`R^{(d^{(l+1)}/B)\\times(d^{l}/B)}`.\n\n    Parameters\n    ----------\n    in_size : int\n        Input feature size.\n    out_size : int\n        Output feature size.\n    num_types : int\n        Total number of types.\n    regularizer : str, optional\n        Which weight regularizer to use \"basis\" or \"bdd\":\n\n         - \"basis\" is short for basis-decomposition.\n         - \"bdd\" is short for block-diagonal-decomposition.\n\n        Default applies no regularization.\n    num_bases : int, optional\n        Number of bases. Needed when ``regularizer`` is specified. 
Typically smaller\n        than ``num_types``.\n        Default: ``None``.\n\n    Examples\n    --------\n\n    No regularization.\n\n    >>> from dgl.nn import TypedLinear\n    >>> import torch\n    >>>\n    >>> x = torch.randn(100, 32)\n    >>> x_type = torch.randint(0, 5, (100,))\n    >>> m = TypedLinear(32, 64, 5)\n    >>> y = m(x, x_type)\n    >>> print(y.shape)\n    torch.Size([100, 64])\n\n    With basis regularization\n\n    >>> x = torch.randn(100, 32)\n    >>> x_type = torch.randint(0, 5, (100,))\n    >>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4)\n    >>> y = m(x, x_type)\n    >>> print(y.shape)\n    torch.Size([100, 64])\n    \"\"\"\n\n    def __init__(\n        self, in_size, out_size, num_types, regularizer=None, num_bases=None\n    ):\n        super().__init__()\n        self.in_size = in_size\n        self.out_size = out_size\n        self.num_types = num_types\n        if regularizer is None:\n            self.W = nn.Parameter(torch.Tensor(num_types, in_size, out_size))\n        elif regularizer == \"basis\":\n            if num_bases is None:\n                raise ValueError(\n                    'Missing \"num_bases\" for basis regularization.'\n                )\n            self.W = nn.Parameter(torch.Tensor(num_bases, in_size, out_size))\n            self.coeff = nn.Parameter(torch.Tensor(num_types, num_bases))\n            self.num_bases = num_bases\n        elif regularizer == \"bdd\":\n            if num_bases is None:\n                raise ValueError('Missing \"num_bases\" for bdd regularization.')\n            if in_size % num_bases != 0 or out_size % num_bases != 0:\n                raise ValueError(\n                    \"Input and output sizes must be divisible by num_bases.\"\n                )\n            self.submat_in = in_size // num_bases\n            self.submat_out = out_size // num_bases\n            self.W = nn.Parameter(\n                torch.Tensor(\n                    num_types, num_bases * 
self.submat_in * self.submat_out\n                )\n            )\n            self.num_bases = num_bases\n        else:\n            raise ValueError(\n                f'Supported regularizer options: \"basis\", \"bdd\", but got {regularizer}'\n            )\n        self.regularizer = regularizer\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"Reset parameters\"\"\"\n        with torch.no_grad():\n            # Follow torch.nn.Linear 's initialization to use kaiming_uniform_ on in_size\n            if self.regularizer is None:\n                nn.init.uniform_(\n                    self.W,\n                    -1 / math.sqrt(self.in_size),\n                    1 / math.sqrt(self.in_size),\n                )\n            elif self.regularizer == \"basis\":\n                nn.init.uniform_(\n                    self.W,\n                    -1 / math.sqrt(self.in_size),\n                    1 / math.sqrt(self.in_size),\n                )\n                nn.init.xavier_uniform_(\n                    self.coeff, gain=nn.init.calculate_gain(\"relu\")\n                )\n            elif self.regularizer == \"bdd\":\n                nn.init.uniform_(\n                    self.W,\n                    -1 / math.sqrt(self.submat_in),\n                    1 / math.sqrt(self.submat_in),\n                )\n            else:\n                raise ValueError(\n                    f'Supported regularizer options: \"basis\", \"bdd\", but got {regularizer}'\n                )\n\n    def get_weight(self):\n        \"\"\"Get type-wise weight\"\"\"\n        if self.regularizer is None:\n            return self.W\n        elif self.regularizer == \"basis\":\n            W = self.W.view(self.num_bases, self.in_size * self.out_size)\n            return (self.coeff @ W).view(\n                self.num_types, self.in_size, self.out_size\n            )\n        elif self.regularizer == \"bdd\":\n            return self.W\n        else:\n            
raise ValueError(\n                f'Supported regularizer options: \"basis\", \"bdd\", but got {regularizer}'\n            )\n\n    def forward(self, x, x_type, sorted_by_type=False):\n        \"\"\"Forward computation.\n\n        Parameters\n        ----------\n        x : torch.Tensor\n            A 2D input tensor. Shape: (N, D1)\n        x_type : torch.Tensor\n            A 1D integer tensor storing the type of the elements in ``x`` with one-to-one\n            correspondenc. Shape: (N,)\n        sorted_by_type : bool, optional\n            Whether the inputs have been sorted by the types. Forward on pre-sorted inputs may\n            be faster.\n\n        Returns\n        -------\n        y : torch.Tensor\n            The transformed output tensor. Shape: (N, D2)\n        \"\"\"\n        w = self.get_weight()\n        if self.regularizer == \"bdd\":\n            w = w.index_select(0, x_type).view(\n                -1, self.submat_in, self.submat_out\n            )\n            x = x.view(-1, 1, self.submat_in)\n            return torch.bmm(x, w).view(-1, self.out_size)\n        elif sorted_by_type:\n            pos_l = torch.searchsorted(\n                x_type, torch.arange(self.num_types, device=x.device)\n            )\n            pos_r = torch.cat(\n                [pos_l[1:], torch.tensor([len(x_type)], device=x.device)]\n            )\n            seglen = (\n                pos_r - pos_l\n            ).cpu()  # XXX(minjie): cause device synchronize\n            return segment_mm(x, w, seglen_a=seglen)\n        else:\n            return gather_mm(x, w, idx_b=x_type)\n\n    def __repr__(self):\n        if self.regularizer is None:\n            return (\n                f\"TypedLinear(in_size={self.in_size}, out_size={self.out_size}, \"\n                f\"num_types={self.num_types})\"\n            )\n        else:\n            return (\n                f\"TypedLinear(in_size={self.in_size}, out_size={self.out_size}, \"\n                
f\"num_types={self.num_types}, regularizer={self.regularizer}, \"\n                f\"num_bases={self.num_bases})\"\n            )\n"
  },
  {
    "path": "python/dgl/nn/pytorch/link/__init__.py",
    "content": "\"\"\"Torch modules for link prediction/knowledge graph completion.\"\"\"\n\nfrom .edgepred import EdgePredictor\nfrom .transe import TransE\nfrom .transr import TransR\n"
  },
  {
    "path": "python/dgl/nn/pytorch/link/edgepred.py",
    "content": "\"\"\"Predictor for edges in homogeneous graphs.\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, W0235\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass EdgePredictor(nn.Module):\n    r\"\"\"Predictor/score function for pairs of node representations\n\n    Given a pair of node representations, :math:`h_i` and :math:`h_j`, it combines them with\n\n    **dot product**\n\n    .. math::\n\n        h_i^{T} h_j\n\n    or **cosine similarity**\n\n    .. math::\n\n        \\frac{h_i^{T} h_j}{{\\| h_i \\|}_2 \\cdot {\\| h_j \\|}_2}\n\n    or **elementwise product**\n\n    .. math::\n\n        h_i \\odot h_j\n\n    or **concatenation**\n\n    .. math::\n\n        h_i \\Vert h_j\n\n    Optionally, it passes the combined results to a linear layer for the final prediction.\n\n    Parameters\n    ----------\n    op : str\n        The operation to apply. It can be 'dot', 'cos', 'ele', or 'cat',\n        corresponding to the equations above in order.\n    in_feats : int, optional\n        The input feature size of :math:`h_i` and :math:`h_j`. It is required\n        only if a linear layer is to be applied.\n    out_feats : int, optional\n        The output feature size. 
It is required only if a linear layer is to be applied.\n    bias : bool, optional\n        Whether to use bias for the linear layer if it applies.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import EdgePredictor\n    >>> num_nodes = 2\n    >>> num_edges = 3\n    >>> in_feats = 4\n    >>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges)\n    >>> h = th.randn(num_nodes, in_feats)\n    >>> src, dst = g.edges()\n    >>> h_src = h[src]\n    >>> h_dst = h[dst]\n\n    Case1: dot product\n\n    >>> predictor = EdgePredictor('dot')\n    >>> predictor(h_src, h_dst).shape\n    torch.Size([3, 1])\n    >>> predictor = EdgePredictor('dot', in_feats, out_feats=3)\n    >>> predictor.reset_parameters()\n    >>> predictor(h_src, h_dst).shape\n    torch.Size([3, 3])\n\n    Case2: cosine similarity\n\n    >>> predictor = EdgePredictor('cos')\n    >>> predictor(h_src, h_dst).shape\n    torch.Size([3, 1])\n    >>> predictor = EdgePredictor('cos', in_feats, out_feats=3)\n    >>> predictor.reset_parameters()\n    >>> predictor(h_src, h_dst).shape\n    torch.Size([3, 3])\n\n    Case3: elementwise product\n\n    >>> predictor = EdgePredictor('ele')\n    >>> predictor(h_src, h_dst).shape\n    torch.Size([3, 4])\n    >>> predictor = EdgePredictor('ele', in_feats, out_feats=3)\n    >>> predictor.reset_parameters()\n    >>> predictor(h_src, h_dst).shape\n    torch.Size([3, 3])\n\n    Case4: concatenation\n\n    >>> predictor = EdgePredictor('cat')\n    >>> predictor(h_src, h_dst).shape\n    torch.Size([3, 8])\n    >>> predictor = EdgePredictor('cat', in_feats, out_feats=3)\n    >>> predictor.reset_parameters()\n    >>> predictor(h_src, h_dst).shape\n    torch.Size([3, 3])\n    \"\"\"\n\n    def __init__(self, op, in_feats=None, out_feats=None, bias=False):\n        super(EdgePredictor, self).__init__()\n\n        assert op in [\n            \"dot\",\n            \"cos\",\n            \"ele\",\n            \"cat\",\n        ], 
\"Expect op to be in ['dot', 'cos', 'ele', 'cat'], got {}\".format(op)\n        self.op = op\n        if (in_feats is not None) and (out_feats is not None):\n            if op in [\"dot\", \"cos\"]:\n                in_feats = 1\n            elif op == \"cat\":\n                in_feats = 2 * in_feats\n            self.linear = nn.Linear(in_feats, out_feats, bias=bias)\n        else:\n            self.linear = None\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n        \"\"\"\n        if self.linear is not None:\n            self.linear.reset_parameters()\n\n    def forward(self, h_src, h_dst):\n        r\"\"\"\n\n        Description\n        -----------\n        Predict for pairs of node representations.\n\n        Parameters\n        ----------\n        h_src : torch.Tensor\n            Source node features. The tensor is of shape :math:`(E, D_{in})`,\n            where :math:`E` is the number of edges/node pairs, and :math:`D_{in}`\n            is the input feature size.\n        h_dst : torch.Tensor\n            Destination node features. The tensor is of shape :math:`(E, D_{in})`,\n            where :math:`E` is the number of edges/node pairs, and :math:`D_{in}`\n            is the input feature size.\n\n        Returns\n        -------\n        torch.Tensor\n            The output features.\n        \"\"\"\n        if self.op == \"dot\":\n            N, D = h_src.shape\n            h = torch.bmm(h_src.view(N, 1, D), h_dst.view(N, D, 1)).squeeze(-1)\n        elif self.op == \"cos\":\n            h = F.cosine_similarity(h_src, h_dst).unsqueeze(-1)\n        elif self.op == \"ele\":\n            h = h_src * h_dst\n        else:\n            h = torch.cat([h_src, h_dst], dim=-1)\n\n        if self.linear is not None:\n            h = self.linear(h)\n\n        return h\n"
  },
  {
    "path": "python/dgl/nn/pytorch/link/transe.py",
    "content": "\"\"\"TransE.\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, W0235\nimport torch\nimport torch.nn as nn\n\n\nclass TransE(nn.Module):\n    r\"\"\"Similarity measure from `Translating Embeddings for Modeling Multi-relational Data\n    <https://papers.nips.cc/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html>`__\n\n    Mathematically, it is defined as follows:\n\n    .. math::\n\n        - {\\| h + r - t \\|}_p\n\n    where :math:`h` is the head embedding, :math:`r` is the relation embedding, and\n    :math:`t` is the tail embedding.\n\n    Parameters\n    ----------\n    num_rels : int\n        Number of relation types.\n    feats : int\n        Embedding size.\n    p : int, optional\n        The p to use for Lp norm, which can be 1 or 2.\n\n    Attributes\n    ----------\n    rel_emb : torch.nn.Embedding\n        The learnable relation type embedding.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import TransE\n\n    >>> # input features\n    >>> num_nodes = 10\n    >>> num_edges = 30\n    >>> num_rels = 3\n    >>> feats = 4\n\n    >>> scorer = TransE(num_rels=num_rels, feats=feats)\n    >>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges)\n    >>> src, dst = g.edges()\n    >>> h = th.randn(num_nodes, feats)\n    >>> h_head = h[src]\n    >>> h_tail = h[dst]\n    >>> # Randomly initialize edge relation types for demonstration\n    >>> rels = th.randint(low=0, high=num_rels, size=(num_edges,))\n    >>> scorer(h_head, h_tail, rels).shape\n    torch.Size([30])\n    \"\"\"\n\n    def __init__(self, num_rels, feats, p=1):\n        super(TransE, self).__init__()\n\n        self.rel_emb = nn.Embedding(num_rels, feats)\n        self.p = p\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n        \"\"\"\n        self.rel_emb.reset_parameters()\n\n    def forward(self, 
h_head, h_tail, rels):\n        r\"\"\"\n\n        Description\n        -----------\n        Score triples.\n\n        Parameters\n        ----------\n        h_head : torch.Tensor\n            Head entity features. The tensor is of shape :math:`(E, D)`, where\n            :math:`E` is the number of triples, and :math:`D` is the feature size.\n        h_tail : torch.Tensor\n            Tail entity features. The tensor is of shape :math:`(E, D)`, where\n            :math:`E` is the number of triples, and :math:`D` is the feature size.\n        rels : torch.Tensor\n            Relation types. It is a LongTensor of shape :math:`(E)`, where\n            :math:`E` is the number of triples.\n\n        Returns\n        -------\n        torch.Tensor\n            The triple scores. The tensor is of shape :math:`(E)`.\n        \"\"\"\n        h_rel = self.rel_emb(rels)\n\n        return -torch.norm(h_head + h_rel - h_tail, p=self.p, dim=-1)\n"
  },
  {
    "path": "python/dgl/nn/pytorch/link/transr.py",
    "content": "\"\"\"TransR.\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, W0235\nimport torch\nimport torch.nn as nn\n\n\nclass TransR(nn.Module):\n    r\"\"\"Similarity measure from\n    `Learning entity and relation embeddings for knowledge graph completion\n    <https://ojs.aaai.org/index.php/AAAI/article/view/9491>`__\n\n    Mathematically, it is defined as follows:\n\n    .. math::\n\n        - {\\| M_r h + r - M_r t \\|}_p\n\n    where :math:`M_r` is a relation-specific projection matrix, :math:`h` is the\n    head embedding, :math:`r` is the relation embedding, and :math:`t` is the tail embedding.\n\n    Parameters\n    ----------\n    num_rels : int\n        Number of relation types.\n    rfeats : int\n        Relation embedding size.\n    nfeats : int\n        Entity embedding size.\n    p : int, optional\n        The p to use for Lp norm, which can be 1 or 2.\n\n    Attributes\n    ----------\n    rel_emb : torch.nn.Embedding\n        The learnable relation type embedding.\n    rel_project : torch.nn.Embedding\n        The learnable relation-type-specific projection.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import TransR\n\n    >>> # input features\n    >>> num_nodes = 10\n    >>> num_edges = 30\n    >>> num_rels = 3\n    >>> feats = 4\n\n    >>> scorer = TransR(num_rels=num_rels, rfeats=2, nfeats=feats)\n    >>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges)\n    >>> src, dst = g.edges()\n    >>> h = th.randn(num_nodes, feats)\n    >>> h_head = h[src]\n    >>> h_tail = h[dst]\n    >>> # Randomly initialize edge relation types for demonstration\n    >>> rels = th.randint(low=0, high=num_rels, size=(num_edges,))\n    >>> scorer(h_head, h_tail, rels).shape\n    torch.Size([30])\n    \"\"\"\n\n    def __init__(self, num_rels, rfeats, nfeats, p=1):\n        super(TransR, self).__init__()\n\n        self.rel_emb = nn.Embedding(num_rels, rfeats)\n        
self.rel_project = nn.Embedding(num_rels, nfeats * rfeats)\n        self.rfeats = rfeats\n        self.nfeats = nfeats\n        self.p = p\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters.\n        \"\"\"\n        self.rel_emb.reset_parameters()\n        self.rel_project.reset_parameters()\n\n    def forward(self, h_head, h_tail, rels):\n        r\"\"\"\n        Score triples.\n\n        Parameters\n        ----------\n        h_head : torch.Tensor\n            Head entity features. The tensor is of shape :math:`(E, D)`, where\n            :math:`E` is the number of triples, and :math:`D` is the feature size.\n        h_tail : torch.Tensor\n            Tail entity features. The tensor is of shape :math:`(E, D)`, where\n            :math:`E` is the number of triples, and :math:`D` is the feature size.\n        rels : torch.Tensor\n            Relation types. It is a LongTensor of shape :math:`(E)`, where\n            :math:`E` is the number of triples.\n\n        Returns\n        -------\n        torch.Tensor\n            The triple scores. The tensor is of shape :math:`(E)`.\n        \"\"\"\n        h_rel = self.rel_emb(rels)\n        proj_rel = self.rel_project(rels).reshape(-1, self.nfeats, self.rfeats)\n        h_head = (h_head.unsqueeze(1) @ proj_rel).squeeze(1)\n        h_tail = (h_tail.unsqueeze(1) @ proj_rel).squeeze(1)\n\n        return -torch.norm(h_head + h_rel - h_tail, p=self.p, dim=-1)\n"
  },
  {
    "path": "python/dgl/nn/pytorch/network_emb.py",
    "content": "\"\"\"Network Embedding NN Modules\"\"\"\n\n# pylint: disable= invalid-name\n\nimport random\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.nn import init\nfrom tqdm.auto import trange\n\nfrom ...base import NID\nfrom ...convert import to_heterogeneous, to_homogeneous\nfrom ...random import choice\nfrom ...sampling import random_walk\n\n__all__ = [\"DeepWalk\", \"MetaPath2Vec\"]\n\n\nclass DeepWalk(nn.Module):\n    \"\"\"DeepWalk module from `DeepWalk: Online Learning of Social Representations\n    <https://arxiv.org/abs/1403.6652>`__\n\n    For a graph, it learns the node representations from scratch by maximizing the similarity of\n    node pairs that are nearby (positive node pairs) and minimizing the similarity of other\n    random node pairs (negative node pairs).\n\n    Parameters\n    ----------\n    g : DGLGraph\n        Graph for learning node embeddings\n    emb_dim : int, optional\n        Size of each embedding vector. Default: 128\n    walk_length : int, optional\n        Number of nodes in a random walk sequence. Default: 40\n    window_size : int, optional\n        In a random walk :attr:`w`, a node :attr:`w[j]` is considered close to a node\n        :attr:`w[i]` if :attr:`i - window_size <= j <= i + window_size`. Default: 5\n    neg_weight : float, optional\n        Weight of the loss term for negative samples in the total loss. Default: 1.0\n    negative_size : int, optional\n        Number of negative samples to use for each positive sample. Default: 5\n    fast_neg : bool, optional\n        If True, it samples negative node pairs within a batch of random walks. 
Default: True\n    sparse : bool, optional\n        If True, gradients with respect to the learnable weights will be sparse.\n        Default: True\n\n    Attributes\n    ----------\n    node_embed : nn.Embedding\n        Embedding table of the nodes\n\n    Examples\n    --------\n\n    >>> import torch\n    >>> from dgl.data import CoraGraphDataset\n    >>> from dgl.nn import DeepWalk\n    >>> from torch.optim import SparseAdam\n    >>> from torch.utils.data import DataLoader\n    >>> from sklearn.linear_model import LogisticRegression\n\n    >>> dataset = CoraGraphDataset()\n    >>> g = dataset[0]\n    >>> model = DeepWalk(g)\n    >>> dataloader = DataLoader(torch.arange(g.num_nodes()), batch_size=128,\n    ...                         shuffle=True, collate_fn=model.sample)\n    >>> optimizer = SparseAdam(model.parameters(), lr=0.01)\n    >>> num_epochs = 5\n\n    >>> for epoch in range(num_epochs):\n    ...     for batch_walk in dataloader:\n    ...         loss = model(batch_walk)\n    ...         optimizer.zero_grad()\n    ...         loss.backward()\n    ...         
optimizer.step()\n\n    >>> train_mask = g.ndata['train_mask']\n    >>> test_mask = g.ndata['test_mask']\n    >>> X = model.node_embed.weight.detach()\n    >>> y = g.ndata['label']\n    >>> clf = LogisticRegression().fit(X[train_mask].numpy(), y[train_mask].numpy())\n    >>> clf.score(X[test_mask].numpy(), y[test_mask].numpy())\n    \"\"\"\n\n    def __init__(\n        self,\n        g,\n        emb_dim=128,\n        walk_length=40,\n        window_size=5,\n        neg_weight=1,\n        negative_size=5,\n        fast_neg=True,\n        sparse=True,\n    ):\n        super().__init__()\n\n        assert (\n            walk_length >= window_size + 1\n        ), f\"Expect walk_length >= window_size + 1, got {walk_length} and {window_size + 1}\"\n\n        self.g = g\n        self.emb_dim = emb_dim\n        self.window_size = window_size\n        self.walk_length = walk_length\n        self.neg_weight = neg_weight\n        self.negative_size = negative_size\n        self.fast_neg = fast_neg\n\n        num_nodes = g.num_nodes()\n\n        # center node embedding\n        self.node_embed = nn.Embedding(num_nodes, emb_dim, sparse=sparse)\n        self.context_embed = nn.Embedding(num_nodes, emb_dim, sparse=sparse)\n        self.reset_parameters()\n\n        if not fast_neg:\n            neg_prob = g.out_degrees().pow(0.75)\n            # categorical distribution for true negative sampling\n            self.neg_prob = neg_prob / neg_prob.sum()\n\n        # Get list index pairs for positive samples.\n        # Given i, positive index pairs are (i - window_size, i), ... 
,\n        # (i - 1, i), (i + 1, i), ..., (i + window_size, i)\n        idx_list_src = []\n        idx_list_dst = []\n\n        for i in range(walk_length):\n            for j in range(max(0, i - window_size), i):\n                idx_list_src.append(j)\n                idx_list_dst.append(i)\n            for j in range(i + 1, min(walk_length, i + 1 + window_size)):\n                idx_list_src.append(j)\n                idx_list_dst.append(i)\n\n        self.idx_list_src = torch.LongTensor(idx_list_src)\n        self.idx_list_dst = torch.LongTensor(idx_list_dst)\n\n    def reset_parameters(self):\n        \"\"\"Reinitialize learnable parameters\"\"\"\n        init_range = 1.0 / self.emb_dim\n        init.uniform_(self.node_embed.weight.data, -init_range, init_range)\n        init.constant_(self.context_embed.weight.data, 0)\n\n    def sample(self, indices):\n        \"\"\"Sample random walks\n\n        Parameters\n        ----------\n        indices : torch.Tensor\n            Nodes from which we perform random walk\n\n        Returns\n        -------\n        torch.Tensor\n            Random walks in the form of node ID sequences. The Tensor\n            is of shape :attr:`(len(indices), walk_length)`.\n        \"\"\"\n        return random_walk(self.g, indices, length=self.walk_length - 1)[0]\n\n    def forward(self, batch_walk):\n        \"\"\"Compute the loss for the batch of random walks\n\n        Parameters\n        ----------\n        batch_walk : torch.Tensor\n            Random walks in the form of node ID sequences. 
The Tensor\n            is of shape :attr:`(batch_size, walk_length)`.\n\n        Returns\n        -------\n        torch.Tensor\n            Loss value\n        \"\"\"\n        batch_size = len(batch_walk)\n        device = batch_walk.device\n\n        batch_node_embed = self.node_embed(batch_walk).view(-1, self.emb_dim)\n        batch_context_embed = self.context_embed(batch_walk).view(\n            -1, self.emb_dim\n        )\n\n        batch_idx_list_offset = torch.arange(batch_size) * self.walk_length\n        batch_idx_list_offset = batch_idx_list_offset.unsqueeze(1)\n        idx_list_src = batch_idx_list_offset + self.idx_list_src.unsqueeze(0)\n        idx_list_dst = batch_idx_list_offset + self.idx_list_dst.unsqueeze(0)\n        idx_list_src = idx_list_src.view(-1).to(device)\n        idx_list_dst = idx_list_dst.view(-1).to(device)\n\n        pos_src_emb = batch_node_embed[idx_list_src]\n        pos_dst_emb = batch_context_embed[idx_list_dst]\n\n        neg_idx_list_src = idx_list_dst.unsqueeze(1) + torch.zeros(\n            self.negative_size\n        ).unsqueeze(0).to(device)\n        neg_idx_list_src = neg_idx_list_src.view(-1)\n        neg_src_emb = batch_node_embed[neg_idx_list_src.long()]\n\n        if self.fast_neg:\n            neg_idx_list_dst = list(range(batch_size * self.walk_length)) * (\n                self.negative_size * self.window_size * 2\n            )\n            random.shuffle(neg_idx_list_dst)\n            neg_idx_list_dst = neg_idx_list_dst[: len(neg_idx_list_src)]\n            neg_idx_list_dst = torch.LongTensor(neg_idx_list_dst).to(device)\n            neg_dst_emb = batch_context_embed[neg_idx_list_dst]\n        else:\n            neg_dst = choice(\n                self.g.num_nodes(), size=len(neg_src_emb), prob=self.neg_prob\n            )\n            neg_dst_emb = self.context_embed(neg_dst.to(device))\n\n        pos_score = torch.sum(torch.mul(pos_src_emb, pos_dst_emb), dim=1)\n        pos_score = torch.clamp(pos_score, 
max=6, min=-6)\n        pos_score = torch.mean(-F.logsigmoid(pos_score))\n\n        neg_score = torch.sum(torch.mul(neg_src_emb, neg_dst_emb), dim=1)\n        neg_score = torch.clamp(neg_score, max=6, min=-6)\n        neg_score = (\n            torch.mean(-F.logsigmoid(-neg_score))\n            * self.negative_size\n            * self.neg_weight\n        )\n\n        return torch.mean(pos_score + neg_score)\n\n\nclass MetaPath2Vec(nn.Module):\n    r\"\"\"metapath2vec module from `metapath2vec: Scalable Representation Learning for\n    Heterogeneous Networks <https://dl.acm.org/doi/pdf/10.1145/3097983.3098036>`__\n\n    To achieve efficient optimization, we leverage the negative sampling technique for the\n    training process. Repeatedly for each node in meta-path, we treat it as the center node\n    and sample nearby positive nodes within context size and draw negative samples among all\n    types of nodes from all meta-paths. Then we can use the center-context paired nodes and\n    context-negative paired nodes to update the network.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        Graph for learning node embeddings. Two different canonical edge types\n        :attr:`(utype, etype, vtype)` are not allowed to have same :attr:`etype`.\n    metapath : list[str]\n        A sequence of edge types in the form of a string. It defines a new edge type by composing\n        multiple edge types in order. Note that the start node type and the end one are commonly\n        the same.\n    window_size : int\n        In a random walk :attr:`w`, a node :attr:`w[j]` is considered close to a node\n        :attr:`w[i]` if :attr:`i - window_size <= j <= i + window_size`.\n    emb_dim : int, optional\n        Size of each embedding vector. Default: 128\n    negative_size : int, optional\n        Number of negative samples to use for each positive sample. 
Default: 5\n    sparse : bool, optional\n        If True, gradients with respect to the learnable weights will be sparse.\n        Default: True\n\n    Attributes\n    ----------\n    node_embed : nn.Embedding\n        Embedding table of all nodes\n    local_to_global_nid : dict[str, list]\n        Mapping from type-specific node IDs to global node IDs\n\n    Examples\n    --------\n\n    >>> import torch\n    >>> import dgl\n    >>> from torch.optim import SparseAdam\n    >>> from torch.utils.data import DataLoader\n    >>> from dgl.nn.pytorch import MetaPath2Vec\n\n    >>> # Define a model\n    >>> g = dgl.heterograph({\n    ...     ('user', 'uc', 'company'): dgl.rand_graph(100, 1000).edges(),\n    ...     ('company', 'cp', 'product'): dgl.rand_graph(100, 1000).edges(),\n    ...     ('company', 'cu', 'user'): dgl.rand_graph(100, 1000).edges(),\n    ...     ('product', 'pc', 'company'): dgl.rand_graph(100, 1000).edges()\n    ... })\n    >>> model = MetaPath2Vec(g, ['uc', 'cu'], window_size=1)\n\n    >>> # Use the source node type of etype 'uc'\n    >>> dataloader = DataLoader(torch.arange(g.num_nodes('user')), batch_size=128,\n    ...                         shuffle=True, collate_fn=model.sample)\n    >>> optimizer = SparseAdam(model.parameters(), lr=0.025)\n\n    >>> for (pos_u, pos_v, neg_v) in dataloader:\n    ...     loss = model(pos_u, pos_v, neg_v)\n    ...     optimizer.zero_grad()\n    ...     loss.backward()\n    ...     
optimizer.step()\n\n    >>> # Get the embeddings of all user nodes\n    >>> user_nids = torch.LongTensor(model.local_to_global_nid['user'])\n    >>> user_emb = model.node_embed(user_nids)\n    \"\"\"\n\n    def __init__(\n        self,\n        g,\n        metapath,\n        window_size,\n        emb_dim=128,\n        negative_size=5,\n        sparse=True,\n    ):\n        super().__init__()\n\n        assert (\n            len(metapath) + 1 >= window_size\n        ), f\"Expect len(metapath) >= window_size - 1, got {metapath} and {window_size}\"\n\n        self.hg = g\n        self.emb_dim = emb_dim\n        self.metapath = metapath\n        self.window_size = window_size\n        self.negative_size = negative_size\n\n        # convert edge metapath to node metapath\n        # get initial source node type\n        src_type, _, _ = g.to_canonical_etype(metapath[0])\n        node_metapath = [src_type]\n        for etype in metapath:\n            _, _, dst_type = g.to_canonical_etype(etype)\n            node_metapath.append(dst_type)\n        self.node_metapath = node_metapath\n\n        # Convert the graph into a homogeneous one for global to local node ID mapping\n        g = to_homogeneous(g)\n        # Convert it back to the hetero one for local to global node ID mapping\n        hg = to_heterogeneous(g, self.hg.ntypes, self.hg.etypes)\n        local_to_global_nid = hg.ndata[NID]\n        for key, val in local_to_global_nid.items():\n            local_to_global_nid[key] = list(val.cpu().numpy())\n        self.local_to_global_nid = local_to_global_nid\n\n        num_nodes_total = hg.num_nodes()\n        node_frequency = torch.zeros(num_nodes_total)\n        # random walk\n        for idx in trange(hg.num_nodes(node_metapath[0])):\n            traces, _ = random_walk(g=hg, nodes=[idx], metapath=metapath)\n            for tr in traces.cpu().numpy():\n                tr_nids = [\n                    self.local_to_global_nid[node_metapath[i]][tr[i]]\n                   
 for i in range(len(tr))\n                ]\n                node_frequency[torch.LongTensor(tr_nids)] += 1\n\n        neg_prob = node_frequency.pow(0.75)\n        self.neg_prob = neg_prob / neg_prob.sum()\n\n        # center node embedding\n        self.node_embed = nn.Embedding(\n            num_nodes_total, self.emb_dim, sparse=sparse\n        )\n        self.context_embed = nn.Embedding(\n            num_nodes_total, self.emb_dim, sparse=sparse\n        )\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"Reinitialize learnable parameters\"\"\"\n        init_range = 1.0 / self.emb_dim\n        init.uniform_(self.node_embed.weight.data, -init_range, init_range)\n        init.constant_(self.context_embed.weight.data, 0)\n\n    def sample(self, indices):\n        \"\"\"Sample positive and negative samples\n\n        Parameters\n        ----------\n        indices : torch.Tensor\n            Node IDs of the source node type from which we perform random walks\n\n        Returns\n        -------\n        torch.Tensor\n            Positive center nodes\n        torch.Tensor\n            Positive context nodes\n        torch.Tensor\n            Negative context nodes\n        \"\"\"\n        traces, _ = random_walk(\n            g=self.hg, nodes=indices, metapath=self.metapath\n        )\n        u_list = []\n        v_list = []\n        for tr in traces.cpu().numpy():\n            tr_nids = [\n                self.local_to_global_nid[self.node_metapath[i]][tr[i]]\n                for i in range(len(tr))\n            ]\n            for i, u in enumerate(tr_nids):\n                for j, v in enumerate(\n                    tr_nids[max(i - self.window_size, 0) : i + self.window_size]\n                ):\n                    if i == j:\n                        continue\n                    u_list.append(u)\n                    v_list.append(v)\n\n        neg_v = choice(\n            self.hg.num_nodes(),\n            size=len(u_list) * 
self.negative_size,\n            prob=self.neg_prob,\n        ).reshape(len(u_list), self.negative_size)\n\n        return torch.LongTensor(u_list), torch.LongTensor(v_list), neg_v\n\n    def forward(self, pos_u, pos_v, neg_v):\n        r\"\"\"Compute the loss for the batch of positive and negative samples\n\n        Parameters\n        ----------\n        pos_u : torch.Tensor\n            Positive center nodes\n        pos_v : torch.Tensor\n            Positive context nodes\n        neg_v : torch.Tensor\n            Negative context nodes\n\n        Returns\n        -------\n        torch.Tensor\n            Loss value\n        \"\"\"\n        emb_u = self.node_embed(pos_u)\n        emb_v = self.context_embed(pos_v)\n        emb_neg_v = self.context_embed(neg_v)\n\n        score = torch.sum(torch.mul(emb_u, emb_v), dim=1)\n        score = torch.clamp(score, max=10, min=-10)\n        score = -F.logsigmoid(score)\n\n        neg_score = torch.bmm(emb_neg_v, emb_u.unsqueeze(2)).squeeze()\n        neg_score = torch.clamp(neg_score, max=10, min=-10)\n        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)\n\n        return torch.mean(score + neg_score)\n"
  },
  {
    "path": "python/dgl/nn/pytorch/softmax.py",
    "content": "\"\"\"Torch modules for graph related softmax.\"\"\"\n# pylint: disable= unused-import\nfrom ..functional import edge_softmax\n"
  },
  {
    "path": "python/dgl/nn/pytorch/sparse_emb.py",
    "content": "\"\"\"Torch NodeEmbedding.\"\"\"\nfrom datetime import timedelta\n\nimport torch as th\n\nfrom ...backend import pytorch as F\nfrom ...cuda import nccl\nfrom ...partition import NDArrayPartition\nfrom ...utils import create_shared_mem_array, get_shared_mem_array\n\n_STORE = None\n\n\nclass NodeEmbedding:  # NodeEmbedding\n    \"\"\"Class for storing node embeddings.\n\n    The class is optimized for training large-scale node embeddings. It updates the embedding in\n    a sparse way and can scale to graphs with millions of nodes. It also supports partitioning\n    to multiple GPUs (on a single machine) for more acceleration. It does not support partitioning\n    across machines.\n\n    Currently, DGL provides two optimizers that work with this NodeEmbedding\n    class: ``SparseAdagrad`` and ``SparseAdam``.\n\n    The implementation is based on torch.distributed package. It depends on the pytorch\n    default distributed process group to collect multi-process information and uses\n    ``torch.distributed.TCPStore`` to share meta-data information across multiple gpu processes.\n    It use the local address of '127.0.0.1:12346' to initialize the TCPStore.\n\n    NOTE: The support of NodeEmbedding is experimental.\n\n    Parameters\n    ----------\n    num_embeddings : int\n        The number of embeddings. Currently, the number of embeddings has to be the same as\n        the number of nodes.\n    embedding_dim : int\n        The dimension size of embeddings.\n    name : str\n        The name of the embeddings. The name should uniquely identify the embeddings in the system.\n    init_func : callable, optional\n        The function to create the initial data. 
If the init function is not provided,\n        the values of the embeddings are initialized to zero.\n    device : th.device\n        Device to store the embeddings on.\n    partition : NDArrayPartition\n        The partition to use to distribute the embeddings between\n        processes.\n\n    Examples\n    --------\n    Before launching multiple gpu processes\n\n    >>> def initializer(emb):\n    ...     th.nn.init.xavier_uniform_(emb)\n    ...     return emb\n\n    In each training process\n\n    >>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)\n    >>> optimizer = dgl.optim.SparseAdam([emb], lr=0.001)\n    >>> for blocks in dataloader:\n    ...     ...\n    ...     feats = emb(nids, gpu_0)\n    ...     loss = F.sum(feats + 1, 0)\n    ...     loss.backward()\n    ...     optimizer.step()\n    \"\"\"\n\n    def __init__(\n        self,\n        num_embeddings,\n        embedding_dim,\n        name,\n        init_func=None,\n        device=None,\n        partition=None,\n    ):\n        global _STORE\n\n        if device is None:\n            device = th.device(\"cpu\")\n\n        # Check whether it is multi-gpu training or not.\n        if th.distributed.is_initialized():\n            rank = th.distributed.get_rank()\n            world_size = th.distributed.get_world_size()\n        else:\n            rank = -1\n            world_size = 0\n        self._rank = rank\n        self._world_size = world_size\n        self._store = None\n        self._comm = None\n        self._partition = partition\n\n        host_name = \"127.0.0.1\"\n        port = 12346\n\n        if rank >= 0:\n            # for multi-gpu training, setup a TCPStore for\n            # embedding status synchronization across GPU processes\n            if _STORE is None:\n                _STORE = th.distributed.TCPStore(\n                    host_name,\n                    port,\n                    world_size,\n                    rank == 0,\n                    
timedelta(seconds=10 * 60),\n                )\n            self._store = _STORE\n\n        # embeddings is stored in CPU memory.\n        if th.device(device) == th.device(\"cpu\"):\n            if rank <= 0:\n                emb = create_shared_mem_array(\n                    name, (num_embeddings, embedding_dim), th.float32\n                )\n                if init_func is not None:\n                    emb = init_func(emb)\n            if rank == 0:  # the master gpu process\n                for _ in range(1, world_size):\n                    # send embs\n                    self._store.set(name, name)\n            elif rank > 0:\n                # receive\n                self._store.wait([name])\n                emb = get_shared_mem_array(\n                    name, (num_embeddings, embedding_dim), th.float32\n                )\n            self._tensor = emb\n        else:  # embeddings is stored in GPU memory.\n            self._comm = True\n\n            if not self._partition:\n                # for communication we need a partition\n                self._partition = NDArrayPartition(\n                    num_embeddings,\n                    self._world_size if self._world_size > 0 else 1,\n                    mode=\"remainder\",\n                )\n\n            # create local tensors for the weights\n            local_size = self._partition.local_size(max(self._rank, 0))\n\n            # TODO(dlasalle): support 16-bit/half embeddings\n            emb = th.empty(\n                [local_size, embedding_dim],\n                dtype=th.float32,\n                requires_grad=False,\n                device=device,\n            )\n            if init_func:\n                emb = init_func(emb)\n            self._tensor = emb\n\n        self._num_embeddings = num_embeddings\n        self._embedding_dim = embedding_dim\n        self._name = name\n        self._optm_state = None  # track optimizer state\n        self._trace = []  # track minibatch\n\n    def 
__call__(self, node_ids, device=th.device(\"cpu\")):\n        \"\"\"\n        node_ids : th.tensor\n            Index of the embeddings to collect.\n        device : th.device\n            Target device to put the collected embeddings.\n        \"\"\"\n        if not self._comm:\n            # For embeddings stored on the CPU.\n            emb = self._tensor[node_ids].to(device)\n        else:\n            # For embeddings stored on the GPU.\n            # The following method is designed to perform communication\n            # across multiple GPUs and can handle situations where only one GPU\n            # is present gracefully, a.k.a. self._world_size == 1 or\n            # 0 (when th.distributed.is_initialized() is false).\n            emb = nccl.sparse_all_to_all_pull(\n                node_ids, self._tensor, self._partition\n            )\n            emb = emb.to(device)\n        if F.is_recording():\n            emb = F.attach_grad(emb)\n            self._trace.append((node_ids.to(device), emb))\n\n        return emb\n\n    @property\n    def store(self):\n        \"\"\"Return torch.distributed.TCPStore for\n        meta data sharing across processes.\n\n        Returns\n        -------\n        torch.distributed.TCPStore\n            KVStore used for meta data sharing.\n        \"\"\"\n        return self._store\n\n    @property\n    def partition(self):\n        \"\"\"Return the partition identifying how the tensor is split across\n        processes.\n\n        Returns\n        -------\n        NDArrayPartition or None\n            The partition, or None if no partition has been set.\n        \"\"\"\n\n        return self._partition\n\n    @property\n    def rank(self):\n        \"\"\"Return rank of current process.\n\n        Returns\n        -------\n        int\n            The rank of current process.\n        \"\"\"\n        return self._rank\n\n    @property\n    def world_size(self):\n        \"\"\"Return world size of the pytorch distributed training env.\n\n        Returns\n        -------\n        int\n    
        The world size of the pytorch distributed training env.\n        \"\"\"\n        return self._world_size\n\n    @property\n    def name(self):\n        \"\"\"Return the name of NodeEmbedding.\n\n        Returns\n        -------\n        str\n            The name of NodeEmbedding.\n        \"\"\"\n        return self._name\n\n    @property\n    def num_embeddings(self):\n        \"\"\"Return the number of embeddings.\n\n        Returns\n        -------\n        int\n            The number of embeddings.\n        \"\"\"\n        return self._num_embeddings\n\n    @property\n    def embedding_dim(self):\n        \"\"\"Return the dimension of embeddings.\n\n        Returns\n        -------\n        int\n            The dimension of embeddings.\n        \"\"\"\n        return self._embedding_dim\n\n    def set_optm_state(self, state):\n        \"\"\"Store the optimizer related state tensor.\n\n        Parameters\n        ----------\n        state : tuple of torch.Tensor\n            Optimizer related state.\n        \"\"\"\n        self._optm_state = state\n\n    @property\n    def optm_state(self):\n        \"\"\"Return the optimizer related state tensor.\n\n        Returns\n        -------\n        tuple of torch.Tensor\n            The optimizer related state.\n        \"\"\"\n        return self._optm_state\n\n    @property\n    def trace(self):\n        \"\"\"Return a trace of the indices of embeddings\n        used in the training step(s).\n\n        Returns\n        -------\n        [torch.Tensor]\n            The indices of embeddings used in the training step(s).\n        \"\"\"\n        return self._trace\n\n    def reset_trace(self):\n        \"\"\"Clean up the trace of the indices of embeddings\n        used in the training step(s).\n        \"\"\"\n        self._trace = []\n\n    @property\n    def weight(self):\n        \"\"\"Return the tensor storing the node embeddings\n\n        Returns\n        -------\n        torch.Tensor\n            The 
tensor storing the node embeddings\n        \"\"\"\n        return self._tensor\n\n    def all_set_embedding(self, values):\n        \"\"\"Set the values of the embedding. This method must be called by all\n        processes sharing the embedding with identical tensors for\n        :attr:`values`.\n\n        NOTE: This method must be called by all processes sharing the\n        embedding, or it may result in a deadlock.\n\n        Parameters\n        ----------\n        values : Tensor\n            The global tensor to pull values from.\n        \"\"\"\n        if self._partition:\n            idxs = F.copy_to(\n                self._partition.get_local_indices(\n                    max(self._rank, 0),\n                    ctx=F.context(self._tensor),\n                ),\n                F.context(values),\n            )\n            self._tensor[:] = F.copy_to(\n                F.gather_row(values, idxs), ctx=F.context(self._tensor)\n            )[:]\n        else:\n            if self._rank == 0:\n                self._tensor[:] = F.copy_to(\n                    values, ctx=F.context(self._tensor)\n                )[:]\n        if th.distributed.is_initialized():\n            th.distributed.barrier()\n\n    def _all_get_tensor(self, shared_name, tensor, shape):\n        \"\"\"A helper function to get model-parallel tensors.\n\n        This method must and only need to be called in multi-GPU DDP training.\n        For now, it's only used in ``all_get_embedding`` and\n        ``_all_get_optm_state``.\n        \"\"\"\n        # create a shared memory tensor\n        if self._rank == 0:\n            # root process creates shared memory\n            val = create_shared_mem_array(\n                shared_name,\n                shape,\n                tensor.dtype,\n            )\n            self._store.set(shared_name, shared_name)\n        else:\n            self._store.wait([shared_name])\n            val = get_shared_mem_array(\n                shared_name,\n       
         shape,\n                tensor.dtype,\n            )\n        # need to map indices and slice into existing tensor\n        idxs = self._partition.map_to_global(\n            F.arange(0, tensor.shape[0], ctx=F.context(tensor)),\n            self._rank,\n        ).to(val.device)\n        val[idxs] = tensor.to(val.device)\n\n        self._store.delete_key(shared_name)\n        # wait for all processes to finish\n        th.distributed.barrier()\n        return val\n\n    def all_get_embedding(self):\n        \"\"\"Return a copy of the embedding stored in CPU memory. If this is a\n        multi-processing instance, the tensor will be returned in shared\n        memory. If the embedding is currently stored on multiple GPUs, all\n        processes must call this method in the same order.\n\n        NOTE: This method must be called by all processes sharing the\n        embedding, or it may result in a deadlock.\n\n        Returns\n        -------\n        torch.Tensor\n            The tensor storing the node embeddings.\n        \"\"\"\n        if self._partition:\n            if self._world_size == 0:\n                # non-multiprocessing\n                return self._tensor.to(th.device(\"cpu\"))\n            else:\n                return self._all_get_tensor(\n                    f\"{self._name}_gather\",\n                    self._tensor,\n                    (self._num_embeddings, self._embedding_dim),\n                )\n        else:\n            # already stored in CPU memory\n            return self._tensor\n\n    def _all_get_optm_state(self):\n        \"\"\"Return a copy of the whole optimizer states stored in CPU memory.\n        If this is a multi-processing instance, the states will be returned in\n        shared memory. 
If the embedding is currently stored on multiple GPUs,\n        all processes must call this method in the same order.\n\n        NOTE: This method must be called by all processes sharing the\n        embedding, or it may result in a deadlock.\n\n        Returns\n        -------\n        tuple of torch.Tensor\n            The optimizer states stored in CPU memory.\n        \"\"\"\n        if self._partition:\n            if self._world_size == 0:\n                # non-multiprocessing\n                return tuple(\n                    state.to(th.device(\"cpu\")) for state in self._optm_state\n                )\n            else:\n                return tuple(\n                    self._all_get_tensor(\n                        f\"state_gather_{self._name}_{i}\",\n                        state,\n                        (self._num_embeddings, *state.shape[1:]),\n                    )\n                    for i, state in enumerate(self._optm_state)\n                )\n        else:\n            # already stored in CPU memory\n            return self._optm_state\n\n    def _all_set_optm_state(self, states):\n        \"\"\"Set the optimizer states of the embedding. 
This method must be\n        called by all processes sharing the embedding with identical\n        :attr:`states`.\n\n        NOTE: This method must be called by all processes sharing the\n        embedding, or it may result in a deadlock.\n\n        Parameters\n        ----------\n        states : tuple of torch.Tensor\n            The global states to pull values from.\n        \"\"\"\n        if self._partition:\n            idxs = F.copy_to(\n                self._partition.get_local_indices(\n                    max(self._rank, 0), ctx=F.context(self._tensor)\n                ),\n                F.context(states[0]),\n            )\n            for state, new_state in zip(self._optm_state, states):\n                state[:] = F.copy_to(\n                    F.gather_row(new_state, idxs), ctx=F.context(self._tensor)\n                )[:]\n        else:\n            # stored in CPU memory\n            if self._rank <= 0:\n                for state, new_state in zip(self._optm_state, states):\n                    state[:] = F.copy_to(\n                        new_state, ctx=F.context(self._tensor)\n                    )[:]\n        if th.distributed.is_initialized():\n            th.distributed.barrier()\n"
  },
  {
    "path": "python/dgl/nn/pytorch/utils.py",
    "content": "\"\"\"Utilities for pytorch NN package\"\"\"\n# pylint: disable=no-member, invalid-name\n\nimport torch as th\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom ... import DGLGraph, function as fn\nfrom ...base import dgl_warning\n\n\ndef matmul_maybe_select(A, B):\n    \"\"\"Perform Matrix multiplication C = A * B but A could be an integer id vector.\n\n    If A is an integer vector, we treat it as multiplying a one-hot encoded tensor.\n    In this case, the expensive dense matrix multiply can be replaced by a much\n    cheaper index lookup.\n\n    For example,\n    ::\n\n        A = [2, 0, 1],\n        B = [[0.1, 0.2],\n             [0.3, 0.4],\n             [0.5, 0.6]]\n\n    then matmul_maybe_select(A, B) is equivalent to\n    ::\n\n        [[0, 0, 1],     [[0.1, 0.2],\n         [1, 0, 0],  *   [0.3, 0.4],\n         [0, 1, 0]]      [0.5, 0.6]]\n\n    In all other cases, perform a normal matmul.\n\n    Parameters\n    ----------\n    A : torch.Tensor\n        lhs tensor\n    B : torch.Tensor\n        rhs tensor\n\n    Returns\n    -------\n    C : torch.Tensor\n        result tensor\n    \"\"\"\n    if A.dtype == th.int64 and len(A.shape) == 1:\n        return B.index_select(0, A)\n    else:\n        return th.matmul(A, B)\n\n\ndef bmm_maybe_select(A, B, index):\n    \"\"\"Slice submatrices of A by the given index and perform bmm.\n\n    B is a 3D tensor of shape (N, D1, D2), which can be viewed as a stack of\n    N matrices of shape (D1, D2). 
The input index is an integer vector of length M.\n    A could be either:\n    (1) a dense tensor of shape (M, D1),\n    (2) an integer vector of length M.\n    The result C is a 2D matrix of shape (M, D2)\n\n    For case (1), C is computed by bmm:\n    ::\n\n        C[i, :] = matmul(A[i, :], B[index[i], :, :])\n\n    For case (2), C is computed by index select:\n    ::\n\n        C[i, :] = B[index[i], A[i], :]\n\n    Parameters\n    ----------\n    A : torch.Tensor\n        lhs tensor\n    B : torch.Tensor\n        rhs tensor\n    index : torch.Tensor\n        index tensor\n\n    Returns\n    -------\n    C : torch.Tensor\n        return tensor\n    \"\"\"\n    if A.dtype == th.int64 and len(A.shape) == 1:\n        # following is a faster version of B[index, A, :]\n        B = B.view(-1, B.shape[2])\n        flatidx = index * B.shape[1] + A\n        return B.index_select(0, flatidx)\n    else:\n        BB = B.index_select(0, index)\n        return th.bmm(A.unsqueeze(1), BB).squeeze()\n\n\n# pylint: disable=W0235\nclass Identity(nn.Module):\n    \"\"\"A placeholder identity operator that is argument-insensitive.\n    (Identity has already been supported by PyTorch 1.2, we will directly\n    import torch.nn.Identity in the future)\n    \"\"\"\n\n    def __init__(self):\n        super(Identity, self).__init__()\n\n    def forward(self, x):\n        \"\"\"Return input\"\"\"\n        return x\n\n\nclass Sequential(nn.Sequential):\n    r\"\"\"A sequential container for stacking graph neural network modules\n\n    DGL supports two modes: sequentially apply GNN modules on 1) the same graph or\n    2) a list of given graphs. 
In the second case, the number of graphs equals the\n    number of modules inside this container.\n\n    Parameters\n    ----------\n    *args :\n        Sub-modules of torch.nn.Module that will be added to the container in\n        the order by which they are passed in the constructor.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    Mode 1: sequentially apply GNN modules on the same graph\n\n    >>> import torch\n    >>> import dgl\n    >>> import torch.nn as nn\n    >>> import dgl.function as fn\n    >>> from dgl.nn.pytorch import Sequential\n    >>> class ExampleLayer(nn.Module):\n    >>>     def __init__(self):\n    >>>         super().__init__()\n    >>>     def forward(self, graph, n_feat, e_feat):\n    >>>         with graph.local_scope():\n    >>>             graph.ndata['h'] = n_feat\n    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n    >>>             n_feat += graph.ndata['h']\n    >>>             graph.apply_edges(fn.u_add_v('h', 'h', 'e'))\n    >>>             e_feat += graph.edata['e']\n    >>>             return n_feat, e_feat\n    >>>\n    >>> g = dgl.DGLGraph()\n    >>> g.add_nodes(3)\n    >>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])\n    >>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())\n    >>> n_feat = torch.rand(3, 4)\n    >>> e_feat = torch.rand(9, 4)\n    >>> net(g, n_feat, e_feat)\n    (tensor([[39.8597, 45.4542, 25.1877, 30.8086],\n             [40.7095, 45.3985, 25.4590, 30.0134],\n             [40.7894, 45.2556, 25.5221, 30.4220]]),\n     tensor([[80.3772, 89.7752, 50.7762, 60.5520],\n             [80.5671, 89.3736, 50.6558, 60.6418],\n             [80.4620, 89.5142, 50.3643, 60.3126],\n             [80.4817, 89.8549, 50.9430, 59.9108],\n             [80.2284, 89.6954, 50.0448, 60.1139],\n             [79.7846, 89.6882, 50.5097, 60.6213],\n             [80.2654, 90.2330, 50.2787, 60.6937],\n             [80.3468, 90.0341, 
50.2062, 60.2659],\n             [80.0556, 90.2789, 50.2882, 60.5845]]))\n\n    Mode 2: sequentially apply GNN modules on different graphs\n\n    >>> import torch\n    >>> import dgl\n    >>> import torch.nn as nn\n    >>> import dgl.function as fn\n    >>> import networkx as nx\n    >>> from dgl.nn.pytorch import Sequential\n    >>> class ExampleLayer(nn.Module):\n    >>>     def __init__(self):\n    >>>         super().__init__()\n    >>>     def forward(self, graph, n_feat):\n    >>>         with graph.local_scope():\n    >>>             graph.ndata['h'] = n_feat\n    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n    >>>             n_feat += graph.ndata['h']\n    >>>             return n_feat.view(graph.num_nodes() // 2, 2, -1).sum(1)\n    >>>\n    >>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))\n    >>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))\n    >>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))\n    >>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())\n    >>> n_feat = torch.rand(32, 4)\n    >>> net([g1, g2, g3], n_feat)\n    tensor([[209.6221, 225.5312, 193.8920, 220.1002],\n            [250.0169, 271.9156, 240.2467, 267.7766],\n            [220.4007, 239.7365, 213.8648, 234.9637],\n            [196.4630, 207.6319, 184.2927, 208.7465]])\n    \"\"\"\n\n    def __init__(self, *args):\n        super(Sequential, self).__init__(*args)\n\n    def forward(self, graph, *feats):\n        r\"\"\"\n\n        Sequentially apply modules to the input.\n\n        Parameters\n        ----------\n        graph : DGLGraph or list of DGLGraphs\n            The graph(s) to apply modules on.\n\n        *feats :\n            Input features.\n            The output of the :math:`i`-th module should match the input\n            of the :math:`(i+1)`-th module in the sequential.\n        \"\"\"\n        if isinstance(graph, list):\n            for graph_i, module in zip(graph, self):\n                if not 
isinstance(feats, tuple):\n                    feats = (feats,)\n                feats = module(graph_i, *feats)\n        elif isinstance(graph, DGLGraph):\n            for module in self:\n                if not isinstance(feats, tuple):\n                    feats = (feats,)\n                feats = module(graph, *feats)\n        else:\n            raise TypeError(\n                \"The first argument of forward must be a DGLGraph\"\n                \" or a list of DGLGraph s\"\n            )\n        return feats\n\n\nclass WeightBasis(nn.Module):\n    r\"\"\"Basis decomposition from `Modeling Relational Data with Graph\n    Convolutional Networks <https://arxiv.org/abs/1703.06103>`__\n\n    It can be described as below:\n\n    .. math::\n\n        W_o = \\sum_{b=1}^B a_{ob} V_b\n\n    Each weight output :math:`W_o` is essentially a linear combination of basis\n    transformations :math:`V_b` with coefficients :math:`a_{ob}`.\n\n    It is useful as a form of regularization on a large parameter matrix. 
Thus,\n    the number of weight outputs is usually larger than the number of bases.\n\n    Parameters\n    ----------\n    shape : tuple[int]\n        Shape of the basis parameter.\n    num_bases : int\n        Number of bases.\n    num_outputs : int\n        Number of outputs.\n    \"\"\"\n\n    def __init__(self, shape, num_bases, num_outputs):\n        super(WeightBasis, self).__init__()\n        self.shape = shape\n        self.num_bases = num_bases\n        self.num_outputs = num_outputs\n\n        if num_outputs <= num_bases:\n            dgl_warning(\n                \"The number of weight outputs should be larger than the number\"\n                \" of bases.\"\n            )\n\n        self.weight = nn.Parameter(th.Tensor(self.num_bases, *shape))\n        nn.init.xavier_uniform_(\n            self.weight, gain=nn.init.calculate_gain(\"relu\")\n        )\n        # linear combination coefficients\n        self.w_comp = nn.Parameter(th.Tensor(self.num_outputs, self.num_bases))\n        nn.init.xavier_uniform_(\n            self.w_comp, gain=nn.init.calculate_gain(\"relu\")\n        )\n\n    def forward(self):\n        r\"\"\"Forward computation\n\n        Returns\n        -------\n        weight : torch.Tensor\n            Composed weight tensor of shape ``(num_outputs,) + shape``\n        \"\"\"\n        # generate all weights from bases\n        weight = th.matmul(self.w_comp, self.weight.view(self.num_bases, -1))\n        return weight.view(self.num_outputs, *self.shape)\n\n\nclass JumpingKnowledge(nn.Module):\n    r\"\"\"The Jumping Knowledge aggregation module from `Representation Learning on\n    Graphs with Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__\n\n    It aggregates the output representations of multiple GNN layers with\n\n    **concatenation**\n\n    .. math::\n\n        h_i^{(1)} \\, \\Vert \\, \\ldots \\, \\Vert \\, h_i^{(T)}\n\n    or **max pooling**\n\n    .. 
math::\n\n        \\max \\left( h_i^{(1)}, \\ldots, h_i^{(T)} \\right)\n\n    or **LSTM**\n\n    .. math::\n\n        \\sum_{t=1}^T \\alpha_i^{(t)} h_i^{(t)}\n\n    with attention scores :math:`\\alpha_i^{(t)}` obtained from a BiLSTM\n\n    Parameters\n    ----------\n    mode : str\n        The aggregation to apply. It can be 'cat', 'max', or 'lstm',\n        corresponding to the equations above in order.\n    in_feats : int, optional\n        This argument is only required if :attr:`mode` is ``'lstm'``.\n        The output representation size of a single GNN layer. Note that\n        all GNN layers need to have the same output representation size.\n    num_layers : int, optional\n        This argument is only required if :attr:`mode` is ``'lstm'``.\n        The number of GNN layers for output aggregation.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl.nn import JumpingKnowledge\n\n    >>> # Output representations of two GNN layers\n    >>> num_nodes = 3\n    >>> in_feats = 4\n    >>> feat_list = [th.zeros(num_nodes, in_feats), th.ones(num_nodes, in_feats)]\n\n    >>> # Case1\n    >>> model = JumpingKnowledge()\n    >>> model(feat_list).shape\n    torch.Size([3, 8])\n\n    >>> # Case2\n    >>> model = JumpingKnowledge(mode='max')\n    >>> model(feat_list).shape\n    torch.Size([3, 4])\n\n    >>> # Case3\n    >>> model = JumpingKnowledge(mode='lstm', in_feats=in_feats, num_layers=len(feat_list))\n    >>> model(feat_list).shape\n    torch.Size([3, 4])\n    \"\"\"\n\n    def __init__(self, mode=\"cat\", in_feats=None, num_layers=None):\n        super(JumpingKnowledge, self).__init__()\n        assert mode in [\n            \"cat\",\n            \"max\",\n            \"lstm\",\n        ], \"Expect mode to be 'cat', or 'max' or 'lstm', got {}\".format(mode)\n        self.mode = mode\n\n        if mode == \"lstm\":\n            assert in_feats is not None, \"in_feats is required for lstm mode\"\n            assert (\n      
          num_layers is not None\n            ), \"num_layers is required for lstm mode\"\n            hidden_size = (num_layers * in_feats) // 2\n            self.lstm = nn.LSTM(\n                in_feats, hidden_size, bidirectional=True, batch_first=True\n            )\n            self.att = nn.Linear(2 * hidden_size, 1)\n\n    def reset_parameters(self):\n        r\"\"\"\n\n        Description\n        -----------\n        Reinitialize learnable parameters. This comes into effect only for the lstm mode.\n        \"\"\"\n        if self.mode == \"lstm\":\n            self.lstm.reset_parameters()\n            self.att.reset_parameters()\n\n    def forward(self, feat_list):\n        r\"\"\"\n\n        Description\n        -----------\n        Aggregate output representations across multiple GNN layers.\n\n        Parameters\n        ----------\n        feat_list : list[Tensor]\n            feat_list[i] is the output representations of a GNN layer.\n\n        Returns\n        -------\n        Tensor\n            The aggregated representations.\n        \"\"\"\n        if self.mode == \"cat\":\n            return th.cat(feat_list, dim=-1)\n        elif self.mode == \"max\":\n            return th.stack(feat_list, dim=-1).max(dim=-1)[0]\n        else:\n            # LSTM\n            stacked_feat_list = th.stack(\n                feat_list, dim=1\n            )  # (N, num_layers, in_feats)\n            alpha, _ = self.lstm(stacked_feat_list)\n            alpha = self.att(alpha).squeeze(-1)  # (N, num_layers)\n            alpha = th.softmax(alpha, dim=-1)\n            return (stacked_feat_list * alpha.unsqueeze(-1)).sum(dim=1)\n\n\nclass LabelPropagation(nn.Module):\n    r\"\"\"Label Propagation from `Learning from Labeled and Unlabeled Data with Label\n    Propagation <http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf>`__\n\n    .. 
math::\n\n        \\mathbf{Y}^{(t+1)} = \\alpha \\tilde{A} \\mathbf{Y}^{(t)} + (1 - \\alpha) \\mathbf{Y}^{(0)}\n\n    where unlabeled data is initially set to zero and inferred from labeled data via\n    propagation. :math:`\\alpha` is a weight parameter for balancing between updated labels\n    and initial labels. :math:`\\tilde{A}` denotes the normalized adjacency matrix.\n\n    Parameters\n    ----------\n    k: int\n        The number of propagation steps.\n    alpha : float\n        The :math:`\\alpha` coefficient in range [0, 1].\n    norm_type : str, optional\n        The type of normalization applied to the adjacency matrix, must be one of the\n        following choices:\n\n        * ``row``: row-normalized adjacency as :math:`D^{-1}A`\n\n        * ``sym``: symmetrically normalized adjacency as :math:`D^{-1/2}AD^{-1/2}`\n\n        Default: 'sym'.\n    clamp : bool, optional\n        A bool flag to indicate whether to clamp the labels to [0, 1] after propagation.\n        Default: True.\n    normalize: bool, optional\n        A bool flag to indicate whether to apply row-normalization after propagation.\n        Default: False.\n    reset : bool, optional\n        A bool flag to indicate whether to reset the known labels after each\n        propagation step. 
Default: False.\n\n    Examples\n    --------\n    >>> import torch\n    >>> import dgl\n    >>> from dgl.nn import LabelPropagation\n\n    >>> label_propagation = LabelPropagation(k=5, alpha=0.5, clamp=False, normalize=True)\n    >>> g = dgl.rand_graph(5, 10)\n    >>> labels = torch.tensor([0, 2, 1, 3, 0]).long()\n    >>> mask = torch.tensor([0, 1, 1, 1, 0]).bool()\n    >>> new_labels = label_propagation(g, labels, mask)\n    \"\"\"\n\n    def __init__(\n        self,\n        k,\n        alpha,\n        norm_type=\"sym\",\n        clamp=True,\n        normalize=False,\n        reset=False,\n    ):\n        super(LabelPropagation, self).__init__()\n        self.k = k\n        self.alpha = alpha\n        self.norm_type = norm_type\n        self.clamp = clamp\n        self.normalize = normalize\n        self.reset = reset\n\n    def forward(self, g, labels, mask=None):\n        r\"\"\"Compute the label propagation process.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The input graph.\n        labels : torch.Tensor\n            The input node labels. 
There are three cases supported.\n\n            * A LongTensor of shape :math:`(N, 1)` or :math:`(N,)` for node class labels in\n              multiclass classification, where :math:`N` is the number of nodes.\n            * A LongTensor of shape :math:`(N, C)` for one-hot encoding of node class labels\n              in multiclass classification, where :math:`C` is the number of classes.\n            * A LongTensor of shape :math:`(N, L)` for node labels in multilabel binary\n              classification, where :math:`L` is the number of labels.\n        mask : torch.Tensor\n            The bool indicators of shape :math:`(N,)` with True denoting labeled nodes.\n            Default: None, indicating all nodes are labeled.\n\n        Returns\n        -------\n        torch.Tensor\n            The propagated node labels of shape :math:`(N, D)` with float type, where :math:`D`\n            is the number of classes or labels.\n        \"\"\"\n        with g.local_scope():\n            # multi-label / multi-class\n            if len(labels.size()) > 1 and labels.size(1) > 1:\n                labels = labels.to(th.float32)\n            # single-label multi-class\n            else:\n                labels = F.one_hot(labels.view(-1)).to(th.float32)\n\n            y = labels\n            if mask is not None:\n                y = th.zeros_like(labels)\n                y[mask] = labels[mask]\n\n            init = (1 - self.alpha) * y\n            in_degs = g.in_degrees().float().clamp(min=1)\n            out_degs = g.out_degrees().float().clamp(min=1)\n            if self.norm_type == \"sym\":\n                norm_i = th.pow(in_degs, -0.5).to(labels.device).unsqueeze(1)\n                norm_j = th.pow(out_degs, -0.5).to(labels.device).unsqueeze(1)\n            elif self.norm_type == \"row\":\n                norm_i = th.pow(in_degs, -1.0).to(labels.device).unsqueeze(1)\n            else:\n                raise ValueError(\n                    f\"Expect norm_type to be 
'sym' or 'row', got {self.norm_type}\"\n                )\n\n            for _ in range(self.k):\n                g.ndata[\"h\"] = y * norm_j if self.norm_type == \"sym\" else y\n                g.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n                y = init + self.alpha * g.ndata[\"h\"] * norm_i\n\n                if self.clamp:\n                    y = y.clamp_(0.0, 1.0)\n                if self.normalize:\n                    y = F.normalize(y, p=1)\n                if self.reset:\n                    y[mask] = labels[mask]\n\n            return y\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/__init__.py",
    "content": "\"\"\"Package for Tensorflow-specific NN modules.\"\"\"\nfrom .conv import *\nfrom .glob import *\nfrom .hetero import *\nfrom .softmax import *\nfrom .utils import *\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/conv/__init__.py",
    "content": "\"\"\"TF NN conv module\"\"\"\nfrom .appnpconv import APPNPConv\nfrom .chebconv import ChebConv\nfrom .densechebconv import DenseChebConv\nfrom .edgeconv import EdgeConv\nfrom .gatconv import GATConv\nfrom .ginconv import GINConv\nfrom .graphconv import GraphConv\nfrom .relgraphconv import RelGraphConv\nfrom .sageconv import SAGEConv\nfrom .sgconv import SGConv\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/conv/appnpconv.py",
    "content": "\"\"\"TF Module for APPNPConv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\nfrom .... import function as fn\n\n\nclass APPNPConv(layers.Layer):\n    r\"\"\"Approximate Personalized Propagation of Neural Predictions\n    layer from `Predict then Propagate: Graph Neural Networks\n    meet Personalized PageRank <https://arxiv.org/pdf/1810.05997.pdf>`__\n\n    .. math::\n        H^{0} & = X\n\n        H^{t+1} & = (1-\\alpha)\\left(\\hat{D}^{-1/2}\n        \\hat{A} \\hat{D}^{-1/2} H^{t}\\right) + \\alpha H^{0}\n\n    Parameters\n    ----------\n    k : int\n        Number of iterations :math:`K`.\n    alpha : float\n        The teleport probability :math:`\\alpha`.\n    edge_drop : float, optional\n        Dropout rate on edges that controls the\n        messages received by each node. Default: ``0``.\n    \"\"\"\n\n    def __init__(self, k, alpha, edge_drop=0.0):\n        super(APPNPConv, self).__init__()\n        self._k = k\n        self._alpha = alpha\n        self.edge_drop = layers.Dropout(edge_drop)\n\n    def call(self, graph, feat):\n        r\"\"\"Compute APPNP layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor\n            The input feature of shape :math:`(N, *)` :math:`N` is the\n            number of nodes, and :math:`*` could be of any shape.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature of shape :math:`(N, *)` where :math:`*`\n            should be the same as input shape.\n        \"\"\"\n        with graph.local_scope():\n            degs = tf.clip_by_value(\n                tf.cast(graph.in_degrees(), tf.float32),\n                clip_value_min=1,\n                clip_value_max=np.inf,\n            )\n            norm = tf.pow(degs, -0.5)\n            shp = norm.shape + (1,) * (feat.ndim - 1)\n            norm = 
tf.reshape(norm, shp)\n            feat_0 = feat\n            for _ in range(self._k):\n                # normalization by src node\n                feat = feat * norm\n                graph.ndata[\"h\"] = feat\n                graph.edata[\"w\"] = self.edge_drop(tf.ones(graph.num_edges(), 1))\n                graph.update_all(fn.u_mul_e(\"h\", \"w\", \"m\"), fn.sum(\"m\", \"h\"))\n                feat = graph.ndata.pop(\"h\")\n                # normalization by dst node\n                feat = feat * norm\n                feat = (1 - self._alpha) * feat + self._alpha * feat_0\n            return feat\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/conv/chebconv.py",
    "content": "\"\"\"Tensorflow Module for Chebyshev Spectral Graph Convolution layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\nfrom .... import broadcast_nodes, function as fn\nfrom ....base import dgl_warning\n\n\nclass ChebConv(layers.Layer):\n    r\"\"\"Chebyshev Spectral Graph Convolution layer from `Convolutional\n    Neural Networks on Graphs with Fast Localized Spectral Filtering\n    <https://arxiv.org/pdf/1606.09375.pdf>`__\n\n    .. math::\n        h_i^{l+1} &= \\sum_{k=0}^{K-1} W^{k, l}z_i^{k, l}\n\n        Z^{0, l} &= H^{l}\n\n        Z^{1, l} &= \\tilde{L} \\cdot H^{l}\n\n        Z^{k, l} &= 2 \\cdot \\tilde{L} \\cdot Z^{k-1, l} - Z^{k-2, l}\n\n        \\tilde{L} &= 2\\left(I - \\tilde{D}^{-1/2} \\tilde{A} \\tilde{D}^{-1/2}\\right)/\\lambda_{max} - I\n\n    where :math:`\\tilde{A}` is :math:`A` + :math:`I`, :math:`W` is learnable weight.\n\n    Parameters\n    ----------\n    in_feats: int\n        Dimension of input features; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n    out_feats: int\n        Dimension of output features :math:`h_i^{(l+1)}`.\n    k : int\n        Chebyshev filter size :math:`K`.\n    activation : function, optional\n        Activation function. Default ``ReLu``.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. Default: ``True``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import tensorflow as tf\n    >>> from dgl.nn import ChebConv\n    >>> with tf.device(\"CPU:0\"):\n    ...     g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    ...     feat = tf.ones((6, 10))\n    ...     conv = ChebConv(10, 2, 2)\n    ...     res = conv(g, feat)\n    ...     
res\n    <tf.Tensor: shape=(6, 2), dtype=float32, numpy=\n    array([[ 0.6163, -0.1809],\n            [ 0.6163, -0.1809],\n            [ 0.6163, -0.1809],\n            [ 0.9698, -1.5053],\n            [ 0.3664,  0.7556],\n            [-0.2370,  3.0164]], dtype=float32)>\n    \"\"\"\n\n    def __init__(\n        self, in_feats, out_feats, k, activation=tf.nn.relu, bias=True\n    ):\n        super(ChebConv, self).__init__()\n        self._k = k\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self.activation = activation\n        self.linear = layers.Dense(out_feats, use_bias=bias)\n\n    def call(self, graph, feat, lambda_max=None):\n        r\"\"\"Compute ChebNet layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor\n            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`\n            is size of input feature, :math:`N` is the number of nodes.\n        lambda_max : list or tensor or None, optional.\n            A list(tensor) with length :math:`B`, stores the largest eigenvalue\n            of the normalized laplacian of each individual graph in ``graph``,\n            where :math:`B` is the batch size of the input graph. 
Default: None.\n\n            If None, this method would set the default value to 2.\n            One can use :func:`dgl.laplacian_lambda_max` to compute this value.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n\n        def unnLaplacian(feat, D_invsqrt, graph):\n            \"\"\"Operation Feat * D^-1/2 A D^-1/2\"\"\"\n            graph.ndata[\"h\"] = feat * D_invsqrt\n            graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n            return graph.ndata.pop(\"h\") * D_invsqrt\n\n        with graph.local_scope():\n            in_degrees = tf.clip_by_value(\n                tf.cast(graph.in_degrees(), tf.float32),\n                clip_value_min=1,\n                clip_value_max=np.inf,\n            )\n            D_invsqrt = tf.expand_dims(tf.pow(in_degrees, -0.5), axis=-1)\n\n            if lambda_max is None:\n                dgl_warning(\n                    \"lambda_max is not provided, using default value of 2.  \"\n                    \"Please use dgl.laplacian_lambda_max to compute the eigenvalues.\"\n                )\n                lambda_max = [2] * graph.batch_size\n\n            if isinstance(lambda_max, list):\n                lambda_max = tf.constant(lambda_max, dtype=tf.float32)\n            if lambda_max.ndim == 1:\n                lambda_max = tf.expand_dims(\n                    lambda_max, axis=-1\n                )  # (B,) to (B, 1)\n\n            # broadcast from (B, 1) to (N, 1)\n            lambda_max = broadcast_nodes(graph, lambda_max)\n            re_norm = 2.0 / lambda_max\n\n            # X_0 is the raw feature, Xt is the list of X_0, X_1, ... 
X_t\n            X_0 = feat\n            Xt = [X_0]\n\n            # X_1(f)\n            if self._k > 1:\n                h = unnLaplacian(X_0, D_invsqrt, graph)\n                X_1 = -re_norm * h + X_0 * (re_norm - 1)\n                # Append X_1 to Xt\n                Xt.append(X_1)\n\n            # Xi(x), i = 2...k\n            for _ in range(2, self._k):\n                h = unnLaplacian(X_1, D_invsqrt, graph)\n                X_i = -2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0\n                # Append X_i to Xt\n                Xt.append(X_i)\n                X_1, X_0 = X_i, X_1\n\n            # Create the concatenation\n            Xt = tf.concat(Xt, 1)\n\n            # linear projection\n            h = self.linear(Xt)\n\n            # activation\n            if self.activation:\n                h = self.activation(h)\n\n        return h\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/conv/densechebconv.py",
    "content": "\"\"\"Tensorflow Module for DenseChebConv\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\n\nclass DenseChebConv(layers.Layer):\n    r\"\"\"Chebyshev Spectral Graph Convolution layer from `Convolutional\n    Neural Networks on Graphs with Fast Localized Spectral Filtering\n    <https://arxiv.org/pdf/1606.09375.pdf>`__\n\n    We recommend to use this module when applying ChebConv on dense graphs.\n\n    Parameters\n    ----------\n    in_feats: int\n        Dimension of input features :math:`h_i^{(l)}`.\n    out_feats: int\n        Dimension of output features :math:`h_i^{(l+1)}`.\n    k : int\n        Chebyshev filter size.\n    activation : function, optional\n        Activation function, default is ReLu.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. Default: ``True``.\n\n    See also\n    --------\n    `ChebConv <https://docs.dgl.ai/api/python/nn.tensorflow.html#chebconv>`__\n    \"\"\"\n\n    def __init__(self, in_feats, out_feats, k, bias=True):\n        super(DenseChebConv, self).__init__()\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._k = k\n\n        # keras initializer assume last two dims as fan_in and fan_out\n        xinit = tf.keras.initializers.glorot_normal()\n        self.W = tf.Variable(\n            initial_value=xinit(\n                shape=(k, in_feats, out_feats), dtype=\"float32\"\n            ),\n            trainable=True,\n        )\n\n        if bias:\n            zeroinit = tf.keras.initializers.zeros()\n            self.bias = tf.Variable(\n                initial_value=zeroinit(shape=(out_feats), dtype=\"float32\"),\n                trainable=True,\n            )\n        else:\n            self.bias = None\n\n    def call(self, adj, feat, lambda_max=None):\n        r\"\"\"Compute (Dense) Chebyshev Spectral Graph Convolution layer.\n\n        
Parameters\n        ----------\n        adj : tf.Tensor\n            The adjacency matrix of the graph to apply Graph Convolution on,\n            should be of shape :math:`(N, N)`, where a row represents the destination\n            and a column represents the source.\n        feat : tf.Tensor\n            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`\n            is size of input feature, :math:`N` is the number of nodes.\n        lambda_max : float or None, optional\n            A float value indicates the largest eigenvalue of given graph.\n            Default: None.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n        A = adj\n        num_nodes = A.shape[0]\n        in_degree = 1 / tf.sqrt(\n            tf.clip_by_value(\n                tf.reduce_sum(A, 1), clip_value_min=1, clip_value_max=np.inf\n            )\n        )\n        D_invsqrt = tf.linalg.diag(in_degree)\n        I = tf.eye(num_nodes)\n        L = I - D_invsqrt @ A @ D_invsqrt\n\n        if lambda_max is None:\n            lambda_ = tf.linalg.eig(L)[0][:, 0]\n            lambda_max = tf.reduce_max(lambda_)\n\n        L_hat = 2 * L / lambda_max - I\n        Z = [tf.eye(num_nodes)]\n        for i in range(1, self._k):\n            if i == 1:\n                Z.append(L_hat)\n            else:\n                Z.append(2 * L_hat @ Z[-1] - Z[-2])\n\n        Zs = tf.stack(Z, 0)  # (k, n, n)\n\n        Zh = Zs @ tf.expand_dims(feat, axis=0) @ self.W\n        Zh = tf.reduce_sum(Zh, 0)\n\n        if self.bias is not None:\n            Zh = Zh + self.bias\n        return Zh\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/conv/edgeconv.py",
    "content": "\"\"\"Tensorflow modules for EdgeConv Layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\n\n\nclass EdgeConv(layers.Layer):\n    r\"\"\"EdgeConv layer from `Dynamic Graph CNN for Learning on Point Clouds\n    <https://arxiv.org/pdf/1801.07829>`__\n\n    It can be described as follows:\n\n    .. math::\n\n       h_i^{(l+1)} = \\max_{j \\in \\mathcal{N}(i)} (\n       \\Theta \\cdot (h_j^{(l)} - h_i^{(l)}) + \\Phi \\cdot h_i^{(l)})\n\n    where :math:`\\mathcal{N}(i)` is the neighbor of :math:`i`,\n    :math:`\\Theta` and :math:`\\Phi` are linear layers.\n\n    .. note::\n\n       The original formulation includes a ReLU inside the maximum operator.\n       This is equivalent to first applying a maximum operator then applying\n       the ReLU.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n    out_feat : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    batch_norm : bool\n        Whether to include batch normalization on messages. Default: ``False``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Default: ``False``.\n\n    Note\n    ----\n\n    Zero in-degree nodes will lead to invalid output value. 
This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practice to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n    \"\"\"\n\n    def __init__(self, out_feats, batch_norm=False, allow_zero_in_degree=False):\n        super(EdgeConv, self).__init__()\n        self.batch_norm = batch_norm\n        self._allow_zero_in_degree = allow_zero_in_degree\n\n        self.theta = layers.Dense(out_feats)\n        self.phi = layers.Dense(out_feats)\n        if batch_norm:\n            self.bn = layers.BatchNormalization()\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def call(self, g, feat):\n        \"\"\"Forward computation\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        feat : tf.Tensor or pair of tf.Tensor\n            :math:`(N, D)` where :math:`N` is the number of nodes and\n            :math:`D` is the number of feature dimensions.\n            If a pair of tensors is given, the graph must be a uni-bipartite graph\n            with only one edge type, and the two tensors must have the same\n            dimensionality on all except the first axis.\n\n        Returns\n        -------\n        
tf.Tensor or pair of tf.Tensor\n            New node features.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        with g.local_scope():\n            if not self._allow_zero_in_degree:\n                if tf.math.count_nonzero(g.in_degrees() == 0) > 0:\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. 
Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n            h_src, h_dst = expand_as_pair(feat, g)\n            g.srcdata[\"x\"] = h_src\n            g.dstdata[\"x\"] = h_dst\n            g.apply_edges(fn.v_sub_u(\"x\", \"x\", \"theta\"))\n            g.edata[\"theta\"] = self.theta(g.edata[\"theta\"])\n            g.dstdata[\"phi\"] = self.phi(g.dstdata[\"x\"])\n            if not self.batch_norm:\n                g.update_all(fn.e_add_v(\"theta\", \"phi\", \"e\"), fn.max(\"e\", \"x\"))\n            else:\n                g.apply_edges(fn.e_add_v(\"theta\", \"phi\", \"e\"))\n                # for more comments on why global batch norm instead\n                # of batch norm within EdgeConv go to\n                # https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/edgeconv.py\n                g.edata[\"e\"] = self.bn(g.edata[\"e\"])\n                g.update_all(fn.copy_e(\"e\", \"e\"), fn.max(\"e\", \"x\"))\n            return g.dstdata[\"x\"]\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/conv/gatconv.py",
    "content": "\"\"\"Tensorflow modules for graph attention networks(GAT).\"\"\"\nimport numpy as np\n\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ...functional import edge_softmax\nfrom ..utils import Identity\n\n# pylint: enable=W0235\n\n\nclass GATConv(layers.Layer):\n    r\"\"\"Graph Attention Layer from `Graph Attention Network\n    <https://arxiv.org/pdf/1710.10903.pdf>`__\n\n    .. math::\n        h_i^{(l+1)} = \\sum_{j\\in \\mathcal{N}(i)} \\alpha_{i,j} W^{(l)} h_j^{(l)}\n\n    where :math:`\\alpha_{ij}` is the attention score between node :math:`i` and\n    node :math:`j`:\n\n    .. math::\n        \\alpha_{ij}^{l} &= \\mathrm{softmax_i} (e_{ij}^{l})\n\n        e_{ij}^{l} &= \\mathrm{LeakyReLU}\\left(\\vec{a}^T [W h_{i} \\| W h_{j}]\\right)\n\n    Parameters\n    ----------\n    in_feats : int, or pair of ints\n        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n        GATConv can be applied on homogeneous graph and unidirectional\n        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.\n        If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``\n        specifies the input feature size on both the source and destination nodes.  If\n        a scalar is given, the source and destination node feature size would take the\n        same value.\n    out_feats : int\n        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.\n    num_heads : int\n        Number of heads in Multi-Head Attention.\n    feat_drop : float, optional\n        Dropout rate on feature. Defaults: ``0``.\n    attn_drop : float, optional\n        Dropout rate on attention weight. Defaults: ``0``.\n    negative_slope : float, optional\n        LeakyReLU angle of negative slope. 
Defaults: ``0.2``.\n    residual : bool, optional\n        If True, use residual connection. Defaults: ``False``.\n    activation : callable activation function/layer or None, optional.\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Defaults: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. 
Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practice to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import tensorflow as tf\n    >>> from dgl.nn import GATConv\n    >>>\n    >>> # Case 1: Homogeneous graph\n    >>> with tf.device(\"CPU:0\"):\n    >>>     g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>>     g = dgl.add_self_loop(g)\n    >>>     feat = tf.ones((6, 10))\n    >>>     gatconv = GATConv(10, 2, num_heads=3)\n    >>>     res = gatconv(g, feat)\n    >>>     res\n    <tf.Tensor: shape=(6, 3, 2), dtype=float32, numpy=\n    array([[[ 0.75311995, -1.8093625 ],\n            [-0.12128812, -0.78072834],\n            [-0.49870574, -0.15074375]],\n        [[ 0.75311995, -1.8093625 ],\n            [-0.12128812, -0.78072834],\n            [-0.49870574, -0.15074375]],\n        [[ 0.75311995, -1.8093625 ],\n            [-0.12128812, -0.78072834],\n            [-0.49870574, -0.15074375]],\n        [[ 0.75311995, -1.8093626 ],\n            [-0.12128813, -0.78072834],\n            [-0.49870574, -0.15074375]],\n        [[ 0.75311995, -1.8093625 ],\n            [-0.12128812, -0.78072834],\n            [-0.49870574, -0.15074375]],\n        [[ 0.75311995, -1.8093625 ],\n            [-0.12128812, -0.78072834],\n            [-0.49870574, -0.15074375]]], dtype=float32)>\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})\n    >>> with tf.device(\"CPU:0\"):\n    >>>     u_feat = tf.convert_to_tensor(np.random.rand(2, 5))\n    >>>     v_feat = tf.convert_to_tensor(np.random.rand(4, 10))\n    >>>     gatconv = GATConv((5,10), 2, 3)\n    >>>     res = gatconv(g, (u_feat, v_feat))\n    >>>     res\n    <tf.Tensor: shape=(4, 3, 2), dtype=float32, 
numpy=\n    array([[[-0.89649093, -0.74841046],\n            [ 0.5088224 ,  0.10908248],\n            [ 0.55670375, -0.6811229 ]],\n        [[-0.7905004 , -0.1457274 ],\n            [ 0.2248168 ,  0.93014705],\n            [ 0.12816726, -0.4093595 ]],\n        [[-0.85875374, -0.53382933],\n            [ 0.36841977,  0.51498866],\n            [ 0.31893706, -0.5303393 ]],\n        [[-0.89649093, -0.74841046],\n            [ 0.5088224 ,  0.10908248],\n            [ 0.55670375, -0.6811229 ]]], dtype=float32)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        num_heads,\n        feat_drop=0.0,\n        attn_drop=0.0,\n        negative_slope=0.2,\n        residual=False,\n        activation=None,\n        allow_zero_in_degree=False,\n    ):\n        super(GATConv, self).__init__()\n        self._num_heads = num_heads\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._allow_zero_in_degree = allow_zero_in_degree\n        xinit = tf.keras.initializers.VarianceScaling(\n            scale=np.sqrt(2), mode=\"fan_avg\", distribution=\"untruncated_normal\"\n        )\n        if isinstance(in_feats, tuple):\n            self.fc_src = layers.Dense(\n                out_feats * num_heads, use_bias=False, kernel_initializer=xinit\n            )\n            self.fc_dst = layers.Dense(\n                out_feats * num_heads, use_bias=False, kernel_initializer=xinit\n            )\n        else:\n            self.fc = layers.Dense(\n                out_feats * num_heads, use_bias=False, kernel_initializer=xinit\n            )\n        self.attn_l = tf.Variable(\n            initial_value=xinit(\n                shape=(1, num_heads, out_feats), dtype=\"float32\"\n            ),\n            trainable=True,\n        )\n        self.attn_r = tf.Variable(\n            initial_value=xinit(\n                shape=(1, num_heads, out_feats), dtype=\"float32\"\n            ),\n            trainable=True,\n       
 )\n        self.feat_drop = layers.Dropout(rate=feat_drop)\n        self.attn_drop = layers.Dropout(rate=attn_drop)\n        self.leaky_relu = layers.LeakyReLU(alpha=negative_slope)\n        if residual:\n            if in_feats != out_feats:\n                self.res_fc = layers.Dense(\n                    num_heads * out_feats,\n                    use_bias=False,\n                    kernel_initializer=xinit,\n                )\n            else:\n                self.res_fc = Identity()\n        else:\n            self.res_fc = None\n            # self.register_buffer('res_fc', None)\n        self.activation = activation\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def call(self, graph, feat, get_attention=False):\n        r\"\"\"Compute graph attention network layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor or pair of tf.Tensor\n            If a tf.Tensor is given, the input feature of shape :math:`(N, *, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of tf.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.\n        get_attention : bool, optional\n            Whether to return the attention values. 
Default to False.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`\n            is the number of heads, and :math:`D_{out}` is size of output feature.\n        tf.Tensor, optional\n            The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of\n            edges. This is returned only when :attr:`get_attention` is ``True``.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. 
Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            if isinstance(feat, tuple):\n                src_prefix_shape = tuple(feat[0].shape[:-1])\n                dst_prefix_shape = tuple(feat[1].shape[:-1])\n                h_src = self.feat_drop(feat[0])\n                h_dst = self.feat_drop(feat[1])\n                if not hasattr(self, \"fc_src\"):\n                    self.fc_src, self.fc_dst = self.fc, self.fc\n                feat_src = tf.reshape(\n                    self.fc_src(h_src),\n                    src_prefix_shape + (self._num_heads, self._out_feats),\n                )\n                feat_dst = tf.reshape(\n                    self.fc_dst(h_dst),\n                    dst_prefix_shape + (self._num_heads, self._out_feats),\n                )\n            else:\n                src_prefix_shape = dst_prefix_shape = tuple(feat.shape[:-1])\n                h_src = h_dst = self.feat_drop(feat)\n                feat_src = feat_dst = tf.reshape(\n                    self.fc(h_src),\n                    src_prefix_shape + (self._num_heads, self._out_feats),\n                )\n                if graph.is_block:\n                    feat_dst = feat_src[: graph.number_of_dst_nodes()]\n                    h_dst = h_dst[: graph.number_of_dst_nodes()]\n                    dst_prefix_shape = (\n                        graph.number_of_dst_nodes(),\n                    ) + dst_prefix_shape[1:]\n            # NOTE: GAT paper uses \"first concatenation then linear projection\"\n            # to compute attention scores, while ours is \"first projection then\n            # addition\", the two approaches are mathematically equivalent:\n            # We decompose the weight vector a mentioned in the paper into\n            # [a_l || a_r], then\n            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j\n 
           # Our implementation is much efficient because we do not need to\n            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,\n            # addition could be optimized with DGL's built-in function u_add_v,\n            # which further speeds up computation and saves memory footprint.\n            el = tf.reduce_sum(feat_src * self.attn_l, axis=-1, keepdims=True)\n            er = tf.reduce_sum(feat_dst * self.attn_r, axis=-1, keepdims=True)\n            graph.srcdata.update({\"ft\": feat_src, \"el\": el})\n            graph.dstdata.update({\"er\": er})\n            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.\n            graph.apply_edges(fn.u_add_v(\"el\", \"er\", \"e\"))\n            e = self.leaky_relu(graph.edata.pop(\"e\"))\n            # compute softmax\n            graph.edata[\"a\"] = self.attn_drop(edge_softmax(graph, e))\n            # message passing\n            graph.update_all(fn.u_mul_e(\"ft\", \"a\", \"m\"), fn.sum(\"m\", \"ft\"))\n            rst = graph.dstdata[\"ft\"]\n            # residual\n            if self.res_fc is not None:\n                resval = tf.reshape(\n                    self.res_fc(h_dst), dst_prefix_shape + (-1, self._out_feats)\n                )\n                rst = rst + resval\n            # activation\n            if self.activation:\n                rst = self.activation(rst)\n\n            if get_attention:\n                return rst, graph.edata[\"a\"]\n            else:\n                return rst\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/conv/ginconv.py",
    "content": "\"\"\"Tensorflow Module for Graph Isomorphism Network layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\nfrom .... import function as fn\nfrom ....utils import expand_as_pair\n\n\nclass GINConv(layers.Layer):\n    r\"\"\"Graph Isomorphism Network layer from `How Powerful are Graph\n    Neural Networks? <https://arxiv.org/pdf/1810.00826.pdf>`__\n\n    .. math::\n        h_i^{(l+1)} = f_\\Theta \\left((1 + \\epsilon) h_i^{l} +\n        \\mathrm{aggregate}\\left(\\left\\{h_j^{l}, j\\in\\mathcal{N}(i)\n        \\right\\}\\right)\\right)\n\n    Parameters\n    ----------\n    apply_func : callable activation function/layer or None\n        If not None, apply this function to the updated node feature,\n        the :math:`f_\\Theta` in the formula.\n    aggregator_type : str\n        Aggregator type to use (``sum``, ``max`` or ``mean``).\n    init_eps : float, optional\n        Initial :math:`\\epsilon` value, default: ``0``.\n    learn_eps : bool, optional\n        If True, :math:`\\epsilon` will be a learnable parameter. 
Default: ``False``.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import tensorflow as tf\n    >>> from dgl.nn import GINConv\n    >>>\n    >>> with tf.device(\"CPU:0\"):\n    >>>     g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>>     feat = tf.ones((6, 10))\n    >>>     lin = tf.keras.layers.Dense(10)\n    >>>     conv = GINConv(lin, 'max')\n    >>>     res = conv(g, feat)\n    >>>     res\n    <tf.Tensor: shape=(6, 10), dtype=float32, numpy=\n    array([[-0.1090256 ,  1.9050574 , -0.30704725, -1.995831  , -0.36399186,\n            1.10414   ,  2.4885745 , -0.35387516,  1.3568261 ,  1.7267858 ],\n        [-0.1090256 ,  1.9050574 , -0.30704725, -1.995831  , -0.36399186,\n            1.10414   ,  2.4885745 , -0.35387516,  1.3568261 ,  1.7267858 ],\n        [-0.1090256 ,  1.9050574 , -0.30704725, -1.995831  , -0.36399186,\n            1.10414   ,  2.4885745 , -0.35387516,  1.3568261 ,  1.7267858 ],\n        [-0.1090256 ,  1.9050574 , -0.30704725, -1.995831  , -0.36399186,\n            1.10414   ,  2.4885745 , -0.35387516,  1.3568261 ,  1.7267858 ],\n        [-0.1090256 ,  1.9050574 , -0.30704725, -1.995831  , -0.36399186,\n            1.10414   ,  2.4885745 , -0.35387516,  1.3568261 ,  1.7267858 ],\n        [-0.0545128 ,  0.9525287 , -0.15352362, -0.9979155 , -0.18199593,\n            0.55207   ,  1.2442873 , -0.17693758,  0.67841303,  0.8633929 ]],\n        dtype=float32)>\n    \"\"\"\n\n    def __init__(\n        self, apply_func, aggregator_type, init_eps=0, learn_eps=False\n    ):\n        super(GINConv, self).__init__()\n        self.apply_func = apply_func\n        if aggregator_type == \"sum\":\n            self._reducer = fn.sum\n        elif aggregator_type == \"max\":\n            self._reducer = fn.max\n        elif aggregator_type == \"mean\":\n            self._reducer = fn.mean\n        else:\n            raise KeyError(\n                \"Aggregator type {} not recognized.\".format(aggregator_type)\n       
     )\n        # to specify whether eps is trainable or not.\n        self.eps = tf.Variable(\n            initial_value=[init_eps], dtype=tf.float32, trainable=learn_eps\n        )\n\n    def call(self, graph, feat):\n        r\"\"\"Compute Graph Isomorphism Network layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor or pair of tf.Tensor\n            If a tf.Tensor is given, the input feature of shape :math:`(N, D_{in})` where\n            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of tf.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.\n            If ``apply_func`` is not None, :math:`D_{in}` should\n            fit the input dimensionality requirement of ``apply_func``.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature of shape :math:`(N, D_{out})` where\n            :math:`D_{out}` is the output dimensionality of ``apply_func``.\n            If ``apply_func`` is None, :math:`D_{out}` should be the same\n            as input dimensionality.\n        \"\"\"\n        with graph.local_scope():\n            feat_src, feat_dst = expand_as_pair(feat, graph)\n            graph.srcdata[\"h\"] = feat_src\n            graph.update_all(fn.copy_u(\"h\", \"m\"), self._reducer(\"m\", \"neigh\"))\n            rst = (1 + self.eps) * feat_dst + graph.dstdata[\"neigh\"]\n            if self.apply_func is not None:\n                rst = self.apply_func(rst)\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/conv/graphconv.py",
    "content": "\"\"\"Tensorflow modules for graph convolutions(GCN).\"\"\"\nimport numpy as np\n\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import expand_as_pair\n\n# pylint: disable=W0235\n\n\nclass GraphConv(layers.Layer):\n    r\"\"\"Graph convolution from `Semi-Supervised Classification with Graph Convolutional Networks\n    <https://arxiv.org/abs/1609.02907>`__\n\n    Mathematically it is defined as follows:\n\n    .. math::\n      h_i^{(l+1)} = \\sigma(b^{(l)} + \\sum_{j\\in\\mathcal{N}(i)}\\frac{1}{c_{ij}}h_j^{(l)}W^{(l)})\n\n    where :math:`\\mathcal{N}(i)` is the set of neighbors of node :math:`i`,\n    :math:`c_{ij}` is the product of the square root of node degrees\n    (i.e.,  :math:`c_{ij} = \\sqrt{|\\mathcal{N}(i)|}\\sqrt{|\\mathcal{N}(j)|}`),\n    and :math:`\\sigma` is an activation function.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n    out_feats : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    norm : str, optional\n        How to apply the normalizer.  Can be one of the following values:\n\n        * ``right``, to divide the aggregated messages by each node's in-degrees,\n          which is equivalent to averaging the received messages.\n\n        * ``none``, where no normalization is applied.\n\n        * ``both`` (default), where the messages are scaled with :math:`1/c_{ji}` above, equivalent\n          to symmetric normalization.\n\n        * ``left``, to divide the messages sent out from each node by its out-degrees,\n          equivalent to random walk normalization.\n    weight : bool, optional\n        If True, apply a linear layer. 
Otherwise, aggregate the messages\n        without a weight matrix.\n    bias : bool, optional\n        If True, adds a learnable bias to the output. Default: ``True``.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. Default: ``False``.\n\n    Attributes\n    ----------\n    weight : tf.Variable\n        The learnable weight tensor.\n    bias : tf.Variable\n        The learnable bias tensor.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, so the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. 
Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practise to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import tensorflow as tf\n    >>> from dgl.nn import GraphConv\n\n    >>> # Case 1: Homogeneous graph\n    >>> with tf.device(\"CPU:0\"):\n    ...     g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    ...     g = dgl.add_self_loop(g)\n    ...     feat = tf.ones((6, 10))\n    ...     conv = GraphConv(10, 2, norm='both', weight=True, bias=True)\n    ...     res = conv(g, feat)\n    >>> print(res)\n    <tf.Tensor: shape=(6, 2), dtype=float32, numpy=\n    array([[ 0.6208475 , -0.4896223 ],\n        [ 0.68356586, -0.5390842 ],\n        [ 0.6208475 , -0.4896223 ],\n        [ 0.7859846 , -0.61985517],\n        [ 0.8251371 , -0.65073216],\n        [ 0.48335412, -0.38119012]], dtype=float32)>\n    >>> # allow_zero_in_degree example\n    >>> with tf.device(\"CPU:0\"):\n    ...     g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    ...     conv = GraphConv(10, 2, norm='both', weight=True, bias=True, allow_zero_in_degree=True)\n    ...     res = conv(g, feat)\n    >>> print(res)\n        <tf.Tensor: shape=(6, 2), dtype=float32, numpy=\n        array([[ 0.6208475 , -0.4896223 ],\n            [ 0.68356586, -0.5390842 ],\n            [ 0.6208475 , -0.4896223 ],\n            [ 0.7859846 , -0.61985517],\n            [ 0.8251371 , -0.65073216],\n            [ 0., 0.]], dtype=float32)>\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> u = [0, 1, 0, 0, 1]\n    >>> v = [0, 1, 2, 3, 2]\n    >>> with tf.device(\"CPU:0\"):\n    ...     g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})\n    ...     u_fea = tf.convert_to_tensor(np.random.rand(2, 5))\n    ...     v_fea = tf.convert_to_tensor(np.random.rand(4, 5))\n    ...     
conv = GraphConv(5, 2, norm='both', weight=True, bias=True)\n    ...     res = conv(g, (u_fea, v_fea))\n    >>> res\n    <tf.Tensor: shape=(4, 2), dtype=float32, numpy=\n    array([[ 1.3607183, -0.1636453],\n        [ 1.6665325, -0.2004239],\n        [ 2.1405895, -0.2574358],\n        [ 1.3607183, -0.1636453]], dtype=float32)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        norm=\"both\",\n        weight=True,\n        bias=True,\n        activation=None,\n        allow_zero_in_degree=False,\n    ):\n        super(GraphConv, self).__init__()\n        if norm not in (\"none\", \"both\", \"right\", \"left\"):\n            raise DGLError(\n                'Invalid norm value. Must be either \"none\", \"both\", \"right\" or \"left\".'\n                ' But got \"{}\".'.format(norm)\n            )\n        self._in_feats = in_feats\n        self._out_feats = out_feats\n        self._norm = norm\n        self._allow_zero_in_degree = allow_zero_in_degree\n\n        if weight:\n            xinit = tf.keras.initializers.glorot_uniform()\n            self.weight = tf.Variable(\n                initial_value=xinit(\n                    shape=(in_feats, out_feats), dtype=\"float32\"\n                ),\n                trainable=True,\n            )\n        else:\n            self.weight = None\n\n        if bias:\n            zeroinit = tf.keras.initializers.zeros()\n            self.bias = tf.Variable(\n                initial_value=zeroinit(shape=(out_feats), dtype=\"float32\"),\n                trainable=True,\n            )\n        else:\n            self.bias = None\n\n        self._activation = activation\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def call(self, graph, feat, 
weight=None):\n        r\"\"\"Compute graph convolution.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor or pair of tf.Tensor\n            If a tf.Tensor is given, it represents the input feature of shape\n            :math:`(N, D_{in})`\n            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of tf.Tensor is given, which is the case for bipartite graph, the pair\n            must contain two tensors of shape :math:`(N_{in}, D_{in_{src}})` and\n            :math:`(N_{out}, D_{in_{dst}})`.\n        weight : tf.Tensor, optional\n            Optional external weight tensor.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n\n        Note\n        ----\n        * Input shape: :math:`(N, *, \\text{in_feats})` where * means any number of additional\n          dimensions, :math:`N` is the number of nodes.\n        * Output shape: :math:`(N, *, \\text{out_feats})` where all but the last dimension are\n          the same shape as the input.\n        * Weight shape: :math:`(\\text{in_feats}, \\text{out_feats})`.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. 
\"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. \"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            feat_src, feat_dst = expand_as_pair(feat, graph)\n            if self._norm in [\"both\", \"left\"]:\n                degs = tf.clip_by_value(\n                    tf.cast(graph.out_degrees(), tf.float32),\n                    clip_value_min=1,\n                    clip_value_max=np.inf,\n                )\n                if self._norm == \"both\":\n                    norm = tf.pow(degs, -0.5)\n                else:\n                    norm = 1.0 / degs\n                shp = norm.shape + (1,) * (feat_dst.ndim - 1)\n                norm = tf.reshape(norm, shp)\n                feat_src = feat_src * norm\n\n            if weight is not None:\n                if self.weight is not None:\n                    raise DGLError(\n                        \"External weight is provided while at the same time the\"\n                        \" module has defined its own weight parameter. 
Please\"\n                        \" create the module with flag weight=False.\"\n                    )\n            else:\n                weight = self.weight\n\n            if self._in_feats > self._out_feats:\n                # mult W first to reduce the feature size for aggregation.\n                if weight is not None:\n                    feat_src = tf.matmul(feat_src, weight)\n                graph.srcdata[\"h\"] = feat_src\n                graph.update_all(\n                    fn.copy_u(u=\"h\", out=\"m\"), fn.sum(msg=\"m\", out=\"h\")\n                )\n                rst = graph.dstdata[\"h\"]\n            else:\n                # aggregate first then mult W\n                graph.srcdata[\"h\"] = feat_src\n                graph.update_all(\n                    fn.copy_u(u=\"h\", out=\"m\"), fn.sum(msg=\"m\", out=\"h\")\n                )\n                rst = graph.dstdata[\"h\"]\n                if weight is not None:\n                    rst = tf.matmul(rst, weight)\n\n            if self._norm in [\"both\", \"right\"]:\n                degs = tf.clip_by_value(\n                    tf.cast(graph.in_degrees(), tf.float32),\n                    clip_value_min=1,\n                    clip_value_max=np.inf,\n                )\n                if self._norm == \"both\":\n                    norm = tf.pow(degs, -0.5)\n                else:\n                    norm = 1.0 / degs\n                shp = norm.shape + (1,) * (feat_dst.ndim - 1)\n                norm = tf.reshape(norm, shp)\n                rst = rst * norm\n\n            if self.bias is not None:\n                rst = rst + self.bias\n\n            if self._activation is not None:\n                rst = self._activation(rst)\n\n            return rst\n\n    def extra_repr(self):\n        \"\"\"Set the extra representation of the module,\n        which will come into effect when printing the model.\n        \"\"\"\n        summary = \"in={_in_feats}, out={_out_feats}\"\n        summary += 
\", normalization={_norm}\"\n        if \"_activation\" in self.__dict__:\n            summary += \", activation={_activation}\"\n        return summary.format(**self.__dict__)\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/conv/relgraphconv.py",
    "content": "\"\"\"Tensorflow Module for Relational graph convolution layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\nfrom .... import function as fn\nfrom .. import utils\n\n\nclass RelGraphConv(layers.Layer):\n    r\"\"\"Relational graph convolution layer from `Modeling Relational Data with Graph\n    Convolutional Networks <https://arxiv.org/abs/1703.06103>`__\n\n    It can be described as below:\n\n    .. math::\n\n       h_i^{(l+1)} = \\sigma(\\sum_{r\\in\\mathcal{R}}\n       \\sum_{j\\in\\mathcal{N}^r(i)}\\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})\n\n    where :math:`\\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation\n    :math:`r`. :math:`c_{i,r}` is the normalizer equal\n    to :math:`|\\mathcal{N}^r(i)|`. :math:`\\sigma` is an activation function. :math:`W_0`\n    is the self-loop weight.\n\n    The basis regularization decomposes :math:`W_r` by:\n\n    .. math::\n\n       W_r^{(l)} = \\sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}\n\n    where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined\n    with coefficients :math:`a_{rb}^{(l)}`.\n\n    The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B`\n    number of block diagonal matrices. We refer :math:`B` as the number of bases.\n\n    The block regularization decomposes :math:`W_r` by:\n\n    .. math::\n\n       W_r^{(l)} = \\oplus_{b=1}^B Q_{rb}^{(l)}\n\n    where :math:`B` is the number of bases, :math:`Q_{rb}^{(l)}` are block\n    bases with shape :math:`R^{(d^{(l+1)}/B)*(d^{l}/B)}`.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.\n    out_feat : int\n        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.\n    num_rels : int\n        Number of relations. 
\n    regularizer : str\n        Which weight regularizer to use \"basis\" or \"bdd\".\n        \"basis\" is short for basis-decomposition.\n        \"bdd\" is short for block-diagonal-decomposition.\n    num_bases : int, optional\n        Number of bases. If None, use the number of relations. Default: ``None``.\n    bias : bool, optional\n        True if bias is added. Default: ``True``.\n    activation : callable, optional\n        Activation function. Default: ``None``.\n    self_loop : bool, optional\n        True to include self loop message. Default: ``True``.\n    low_mem : bool, optional\n        True to use low memory implementation of relation message passing function.\n        This option trades speed for lower memory consumption, and will slow down the forward/backward.\n        Turn it on when you encounter OOM problem during training or evaluation. Default: ``False``.\n    dropout : float, optional\n        Dropout rate. Default: ``0.0``\n    layer_norm : bool, optional\n        Add layer norm. 
Default: ``False``\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import tensorflow as tf\n    >>> from dgl.nn import RelGraphConv\n    >>>\n    >>> with tf.device(\"CPU:0\"):\n    >>>     g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>>     feat = tf.ones((6, 10))\n    >>>     conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2)\n    >>>     etype = tf.convert_to_tensor(np.array([0,1,2,0,1,2]).astype(np.int64))\n    >>>     res = conv(g, feat, etype)\n    >>>     res\n    <tf.Tensor: shape=(6, 2), dtype=float32, numpy=\n    array([[-0.02938664,  1.7932655 ],\n        [ 0.1146394 ,  0.48319   ],\n        [-0.02938664,  1.7932655 ],\n        [ 1.2054908 , -0.26098895],\n        [ 0.1146394 ,  0.48319   ],\n        [ 0.75915515,  1.1454091 ]], dtype=float32)>\n\n    >>> # One-hot input\n    >>> with tf.device(\"CPU:0\"):\n    >>>     one_hot_feat = tf.convert_to_tensor(np.array([0,1,2,3,4,5]).astype(np.int64))\n    >>>     res = conv(g, one_hot_feat, etype)\n    >>>     res\n    <tf.Tensor: shape=(6, 2), dtype=float32, numpy=\n    array([[-0.24205256, -0.7922753 ],\n        [ 0.62085056,  0.4893622 ],\n        [-0.9484881 , -0.26546806],\n        [-0.2163915 , -0.12585883],\n        [-0.14293689,  0.77483284],\n        [ 0.091169  , -0.06761569]], dtype=float32)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feat,\n        out_feat,\n        num_rels,\n        regularizer=\"basis\",\n        num_bases=None,\n        bias=True,\n        activation=None,\n        self_loop=True,\n        low_mem=False,\n        dropout=0.0,\n        layer_norm=False,\n    ):\n        super(RelGraphConv, self).__init__()\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.num_rels = num_rels\n        self.regularizer = regularizer\n        self.num_bases = num_bases\n        if (\n            self.num_bases is None\n            or self.num_bases > self.num_rels\n            or 
self.num_bases < 0\n        ):\n            self.num_bases = self.num_rels\n        self.bias = bias\n        self.activation = activation\n        self.self_loop = self_loop\n        self.low_mem = low_mem\n\n        assert (\n            layer_norm is False\n        ), \"TensorFlow currently does not support layer norm.\"\n\n        xinit = tf.keras.initializers.glorot_uniform()\n        zeroinit = tf.keras.initializers.zeros()\n\n        if regularizer == \"basis\":\n            # add basis weights\n            self.weight = tf.Variable(\n                initial_value=xinit(\n                    shape=(self.num_bases, self.in_feat, self.out_feat),\n                    dtype=\"float32\",\n                ),\n                trainable=True,\n            )\n            if self.num_bases < self.num_rels:\n                # linear combination coefficients\n                self.w_comp = tf.Variable(\n                    initial_value=xinit(\n                        shape=(self.num_rels, self.num_bases), dtype=\"float32\"\n                    ),\n                    trainable=True,\n                )\n            # message func\n            self.message_func = self.basis_message_func\n        elif regularizer == \"bdd\":\n            if in_feat % num_bases != 0 or out_feat % num_bases != 0:\n                raise ValueError(\n                    \"Feature size must be a multiplier of num_bases.\"\n                )\n            # add block diagonal weights\n            self.submat_in = in_feat // self.num_bases\n            self.submat_out = out_feat // self.num_bases\n\n            # assuming in_feat and out_feat are both divisible by num_bases\n            self.weight = tf.Variable(\n                initial_value=xinit(\n                    shape=(\n                        self.num_rels,\n                        self.num_bases * self.submat_in * self.submat_out,\n                    ),\n                    dtype=\"float32\",\n                ),\n                
trainable=True,\n            )\n            # message func\n            self.message_func = self.bdd_message_func\n        else:\n            raise ValueError(\"Regularizer must be either 'basis' or 'bdd'\")\n\n        # bias\n        if self.bias:\n            self.h_bias = tf.Variable(\n                initial_value=zeroinit(shape=(out_feat), dtype=\"float32\"),\n                trainable=True,\n            )\n\n        # weight for self loop\n        if self.self_loop:\n            self.loop_weight = tf.Variable(\n                initial_value=xinit(shape=(in_feat, out_feat), dtype=\"float32\"),\n                trainable=True,\n            )\n\n        self.dropout = layers.Dropout(rate=dropout)\n\n    def basis_message_func(self, edges):\n        \"\"\"Message function for basis regularizer\"\"\"\n        if self.num_bases < self.num_rels:\n            # generate all weights from bases\n            weight = tf.reshape(\n                self.weight, (self.num_bases, self.in_feat * self.out_feat)\n            )\n            weight = tf.reshape(\n                tf.matmul(self.w_comp, weight),\n                (self.num_rels, self.in_feat, self.out_feat),\n            )\n        else:\n            weight = self.weight\n\n        # calculate msg @ W_r before put msg into edge\n        # if src is th.int64 we expect it is an index select\n        if edges.src[\"h\"].dtype != tf.int64 and self.low_mem:\n            etypes, _ = tf.unique(edges.data[\"type\"])\n            msg = tf.zeros([edges.src[\"h\"].shape[0], self.out_feat])\n            idx = tf.range(edges.src[\"h\"].shape[0])\n            for etype in etypes:\n                loc = edges.data[\"type\"] == etype\n                w = weight[etype]\n                src = tf.boolean_mask(edges.src[\"h\"], loc)\n                sub_msg = tf.matmul(src, w)\n                indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))\n                msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)\n        
else:\n            msg = utils.bmm_maybe_select(\n                edges.src[\"h\"], weight, edges.data[\"type\"]\n            )\n        if \"norm\" in edges.data:\n            msg = msg * edges.data[\"norm\"]\n        return {\"msg\": msg}\n\n    def bdd_message_func(self, edges):\n        \"\"\"Message function for block-diagonal-decomposition regularizer\"\"\"\n        if (edges.src[\"h\"].dtype == tf.int64) and len(\n            edges.src[\"h\"].shape\n        ) == 1:\n            raise TypeError(\n                \"Block decomposition does not allow integer ID feature.\"\n            )\n\n        # calculate msg @ W_r before put msg into edge\n        # if src is th.int64 we expect it is an index select\n        if self.low_mem:\n            etypes, _ = tf.unique(edges.data[\"type\"])\n            msg = tf.zeros([edges.src[\"h\"].shape[0], self.out_feat])\n            idx = tf.range(edges.src[\"h\"].shape[0])\n            for etype in etypes:\n                loc = edges.data[\"type\"] == etype\n                w = tf.reshape(\n                    self.weight[etype],\n                    (self.num_bases, self.submat_in, self.submat_out),\n                )\n                src = tf.reshape(\n                    tf.boolean_mask(edges.src[\"h\"], loc),\n                    (-1, self.num_bases, self.submat_in),\n                )\n                sub_msg = tf.einsum(\"abc,bcd->abd\", src, w)\n                sub_msg = tf.reshape(sub_msg, (-1, self.out_feat))\n                indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))\n                msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)\n        else:\n            weight = tf.reshape(\n                tf.gather(self.weight, edges.data[\"type\"]),\n                (-1, self.submat_in, self.submat_out),\n            )\n            node = tf.reshape(edges.src[\"h\"], (-1, 1, self.submat_in))\n            msg = tf.reshape(tf.matmul(node, weight), (-1, self.out_feat))\n        if \"norm\" in 
edges.data:\n            msg = msg * edges.data[\"norm\"]\n        return {\"msg\": msg}\n\n    def call(self, g, x, etypes, norm=None):\n        \"\"\"Forward computation\n\n        Parameters\n        ----------\n        g : DGLGraph\n            The graph.\n        x : tf.Tensor\n            Input node features. Could be either\n\n                * :math:`(|V|, D)` dense tensor\n                * :math:`(|V|,)` int64 vector, representing the categorical values of each\n                  node. We then treat the input feature as a one-hot encoding feature.\n        etypes : tf.Tensor\n            Edge type tensor. Shape: :math:`(|E|,)`\n        norm : tf.Tensor, optional\n            Optional edge normalizer tensor. Shape: :math:`(|E|, 1)`\n\n        Returns\n        -------\n        tf.Tensor\n            New node features.\n        \"\"\"\n        assert g.is_homogeneous, (\n            \"not a homogeneous graph; convert it with to_homogeneous \"\n            \"and pass in the edge type as argument\"\n        )\n        with g.local_scope():\n            g.ndata[\"h\"] = x\n            g.edata[\"type\"] = tf.cast(etypes, tf.int64)\n            if norm is not None:\n                g.edata[\"norm\"] = norm\n            if self.self_loop:\n                loop_message = utils.matmul_maybe_select(x, self.loop_weight)\n            # message passing\n            g.update_all(self.message_func, fn.sum(msg=\"msg\", out=\"h\"))\n            # apply bias and activation\n            node_repr = g.ndata[\"h\"]\n            if self.bias:\n                node_repr = node_repr + self.h_bias\n            if self.self_loop:\n                node_repr = node_repr + loop_message\n            if self.activation:\n                node_repr = self.activation(node_repr)\n            node_repr = self.dropout(node_repr)\n            return node_repr\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/conv/sageconv.py",
    "content": "\"\"\"Tensorflow Module for GraphSAGE layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\nfrom .... import function as fn\nfrom ....base import DGLError\nfrom ....utils import check_eq_shape, expand_as_pair\n\n\nclass SAGEConv(layers.Layer):\n    r\"\"\"GraphSAGE layer from `Inductive Representation Learning on\n    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__\n\n    .. math::\n        h_{\\mathcal{N}(i)}^{(l+1)} &= \\mathrm{aggregate}\n        \\left(\\{h_{j}^{l}, \\forall j \\in \\mathcal{N}(i) \\}\\right)\n\n        h_{i}^{(l+1)} &= \\sigma \\left(W \\cdot \\mathrm{concat}\n        (h_{i}^{l}, h_{\\mathcal{N}(i)}^{l+1}) \\right)\n\n        h_{i}^{(l+1)} &= \\mathrm{norm}(h_{i}^{(l+1)})\n\n    Parameters\n    ----------\n    in_feats : int, or pair of ints\n        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.\n\n        GATConv can be applied on homogeneous graph and unidirectional\n        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.\n        If the layer applies on a unidirectional bipartite graph, ``in_feats``\n        specifies the input feature size on both the source and destination nodes.  If\n        a scalar is given, the source and destination node feature size would take the\n        same value.\n\n        If aggregator type is ``gcn``, the feature size of source and destination nodes\n        are required to be the same.\n    out_feats : int\n        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.\n    aggregator_type : str\n        Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).\n    feat_drop : float\n        Dropout rate on features, default: ``0``.\n    bias : bool\n        If True, adds a learnable bias to the output. 
Default: ``True``.\n    norm : callable activation function/layer or None, optional\n        If not None, applies normalization to the updated node features.\n    activation : callable activation function/layer or None, optional\n        If not None, applies an activation function to the updated node features.\n        Default: ``None``.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import tensorflow as tf\n    >>> from dgl.nn import SAGEConv\n    >>>\n    >>> # Case 1: Homogeneous graph\n    >>> with tf.device(\"CPU:0\"):\n    >>>     g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>>     g = dgl.add_self_loop(g)\n    >>>     feat = tf.ones((6, 10))\n    >>>     conv = SAGEConv(10, 2, 'pool')\n    >>>     res = conv(g, feat)\n    >>>     res\n    <tf.Tensor: shape=(6, 2), dtype=float32, numpy=\n    array([[-3.6633523 , -0.90711546],\n        [-3.6633523 , -0.90711546],\n        [-3.6633523 , -0.90711546],\n        [-3.6633523 , -0.90711546],\n        [-3.6633523 , -0.90711546],\n        [-3.6633523 , -0.90711546]], dtype=float32)>\n\n    >>> # Case 2: Unidirectional bipartite graph\n    >>> with tf.device(\"CPU:0\"):\n    >>>     u = [0, 1, 0, 0, 1]\n    >>>     v = [0, 1, 2, 3, 2]\n    >>>     g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})\n    >>>     u_fea = tf.convert_to_tensor(np.random.rand(2, 5))\n    >>>     v_fea = tf.convert_to_tensor(np.random.rand(4, 5))\n    >>>     conv = SAGEConv((5, 10), 2, 'mean')\n    >>>     res = conv(g, (u_fea, v_fea))\n    >>>     res\n    <tf.Tensor: shape=(4, 2), dtype=float32, numpy=\n    array([[-0.59453356, -0.4055441 ],\n        [-0.47459763, -0.717764  ],\n        [ 0.3221837 , -0.29876417],\n        [-0.63356155,  0.09390211]], dtype=float32)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        aggregator_type,\n        feat_drop=0.0,\n        bias=True,\n        norm=None,\n        activation=None,\n    ):\n        
super(SAGEConv, self).__init__()\n        valid_aggre_types = {\"mean\", \"gcn\", \"pool\", \"lstm\"}\n        if aggregator_type not in valid_aggre_types:\n            raise DGLError(\n                \"Invalid aggregator_type. Must be one of {}. \"\n                \"But got {!r} instead.\".format(\n                    valid_aggre_types, aggregator_type\n                )\n            )\n\n        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)\n        self._out_feats = out_feats\n        self._aggre_type = aggregator_type\n        self.norm = norm\n        self.feat_drop = layers.Dropout(feat_drop)\n        self.activation = activation\n        # aggregator type: mean/pool/lstm/gcn\n        if aggregator_type == \"pool\":\n            self.fc_pool = layers.Dense(self._in_src_feats)\n        if aggregator_type == \"lstm\":\n            self.lstm = layers.LSTM(units=self._in_src_feats)\n        if aggregator_type != \"gcn\":\n            self.fc_self = layers.Dense(out_feats, use_bias=bias)\n        self.fc_neigh = layers.Dense(out_feats, use_bias=bias)\n\n    def _lstm_reducer(self, nodes):\n        \"\"\"LSTM reducer\n        NOTE(zihao): lstm reducer with default schedule (degree bucketing)\n        is slow, we could accelerate this with degree padding in the future.\n        \"\"\"\n        m = nodes.mailbox[\"m\"]  # (B, L, D)\n        rst = self.lstm(m)\n        return {\"neigh\": rst}\n\n    def call(self, graph, feat):\n        r\"\"\"Compute GraphSAGE layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor or pair of tf.Tensor\n            If a tf.Tensor is given, it represents the input feature of shape\n            :math:`(N, D_{in})`\n            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.\n            If a pair of tf.Tensor is given, the pair must contain two tensors of shape\n            :math:`(N_{in}, D_{in_{src}})` and 
:math:`(N_{out}, D_{in_{dst}})`.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n        \"\"\"\n        with graph.local_scope():\n            if isinstance(feat, tuple):\n                feat_src = self.feat_drop(feat[0])\n                feat_dst = self.feat_drop(feat[1])\n            else:\n                feat_src = feat_dst = self.feat_drop(feat)\n                if graph.is_block:\n                    feat_dst = feat_src[: graph.number_of_dst_nodes()]\n\n            h_self = feat_dst\n\n            # Handle the case of graphs without edges\n            if graph.num_edges() == 0:\n                graph.dstdata[\"neigh\"] = tf.cast(\n                    tf.zeros((graph.number_of_dst_nodes(), self._in_src_feats)),\n                    tf.float32,\n                )\n\n            if self._aggre_type == \"mean\":\n                graph.srcdata[\"h\"] = feat_src\n                graph.update_all(fn.copy_u(\"h\", \"m\"), fn.mean(\"m\", \"neigh\"))\n                h_neigh = graph.dstdata[\"neigh\"]\n            elif self._aggre_type == \"gcn\":\n                check_eq_shape(feat)\n                graph.srcdata[\"h\"] = feat_src\n                graph.dstdata[\"h\"] = feat_dst  # same as above if homogeneous\n                graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"neigh\"))\n                # divide in_degrees\n                degs = tf.cast(graph.in_degrees(), tf.float32)\n                h_neigh = (graph.dstdata[\"neigh\"] + graph.dstdata[\"h\"]) / (\n                    tf.expand_dims(degs, -1) + 1\n                )\n            elif self._aggre_type == \"pool\":\n                graph.srcdata[\"h\"] = tf.nn.relu(self.fc_pool(feat_src))\n                graph.update_all(fn.copy_u(\"h\", \"m\"), fn.max(\"m\", \"neigh\"))\n                h_neigh = graph.dstdata[\"neigh\"]\n            elif self._aggre_type == 
\"lstm\":\n                graph.srcdata[\"h\"] = feat_src\n                graph.update_all(fn.copy_u(\"h\", \"m\"), self._lstm_reducer)\n                h_neigh = graph.dstdata[\"neigh\"]\n            else:\n                raise KeyError(\n                    \"Aggregator type {} not recognized.\".format(\n                        self._aggre_type\n                    )\n                )\n            # GraphSAGE GCN does not require fc_self.\n            if self._aggre_type == \"gcn\":\n                rst = self.fc_neigh(h_neigh)\n            else:\n                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)\n            # activation\n            if self.activation is not None:\n                rst = self.activation(rst)\n            # normalization\n            if self.norm is not None:\n                rst = self.norm(rst)\n            return rst\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/conv/sgconv.py",
    "content": "\"\"\"tf Module for Simplifying Graph Convolution layer\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, W0613\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\nfrom .... import function as fn\nfrom ....base import DGLError\n\n\nclass SGConv(layers.Layer):\n    r\"\"\"SGC layer from `Simplifying Graph\n    Convolutional Networks <https://arxiv.org/pdf/1902.07153.pdf>`__\n\n    .. math::\n        H^{K} = (\\tilde{D}^{-1/2} \\tilde{A} \\tilde{D}^{-1/2})^K X \\Theta\n\n    where :math:`\\tilde{A}` is :math:`A` + :math:`I`.\n    Thus the graph input is expected to have self-loop edges added.\n\n    Parameters\n    ----------\n    in_feats : int\n        Number of input features; i.e, the number of dimensions of :math:`X`.\n    out_feats : int\n        Number of output features; i.e, the number of dimensions of :math:`H^{K}`.\n    k : int\n        Number of hops :math:`K`. Defaults:``1``.\n    cached : bool\n        If True, the module would cache\n\n        .. math::\n            (\\tilde{D}^{-\\frac{1}{2}}\\tilde{A}\\tilde{D}^{-\\frac{1}{2}})^K X\\Theta\n\n        at the first forward call. This parameter should only be set to\n        ``True`` in Transductive Learning setting.\n    bias : bool\n        If True, adds a learnable bias to the output. Default: ``True``.\n    norm : callable activation function/layer or None, optional\n        If not None, applies normalization to the updated node features.  Default: ``False``.\n    allow_zero_in_degree : bool, optional\n        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid\n        since no message will be passed to those nodes. This is harmful for some applications\n        causing silent performance regression. This module will raise a DGLError if it detects\n        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check\n        and let the users handle it by themselves. 
Default: ``False``.\n\n    Note\n    ----\n    Zero in-degree nodes will lead to invalid output value. This is because no message\n    will be passed to those nodes, the aggregation function will be applied on empty input.\n    A common practice to avoid this is to add a self-loop for each node in the graph if\n    it is homogeneous, which can be achieved by:\n\n    >>> g = ... # a DGLGraph\n    >>> g = dgl.add_self_loop(g)\n\n    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph\n    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``\n    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.\n    A common practice to handle this is to filter out the nodes with zero-in-degree when use\n    after conv.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import numpy as np\n    >>> import tensorflow as tf\n    >>> from dgl.nn import SGConv\n    >>>\n    >>> with tf.device(\"CPU:0\"):\n    >>>     g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))\n    >>>     g = dgl.add_self_loop(g)\n    >>>     feat = tf.ones((6, 10))\n    >>>     conv = SGConv(10, 2, k=2, cached=True)\n    >>>     res = conv(g, feat)\n    >>>     res\n    <tf.Tensor: shape=(6, 2), dtype=float32, numpy=\n    array([[0.61023676, 0.5246612 ],\n        [0.61023676, 0.5246612 ],\n        [0.61023676, 0.5246612 ],\n        [0.8697353 , 0.7477695 ],\n        [0.60570633, 0.520766  ],\n        [0.6102368 , 0.52466124]], dtype=float32)>\n    \"\"\"\n\n    def __init__(\n        self,\n        in_feats,\n        out_feats,\n        k=1,\n        cached=False,\n        bias=True,\n        norm=None,\n        allow_zero_in_degree=False,\n    ):\n        super(SGConv, self).__init__()\n        self.fc = layers.Dense(out_feats, use_bias=bias)\n        self._cached = cached\n        self._cached_h = None\n        self._k = k\n        self.norm = norm\n        self._allow_zero_in_degree = 
allow_zero_in_degree\n\n    def set_allow_zero_in_degree(self, set_value):\n        r\"\"\"Set allow_zero_in_degree flag.\n\n        Parameters\n        ----------\n        set_value : bool\n            The value to be set to the flag.\n        \"\"\"\n        self._allow_zero_in_degree = set_value\n\n    def call(self, graph, feat):\n        r\"\"\"Compute Simplifying Graph Convolution layer.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor\n            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`\n            is size of input feature, :math:`N` is the number of nodes.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`\n            is size of output feature.\n\n        Raises\n        ------\n        DGLError\n            If there are 0-in-degree nodes in the input graph, it will raise DGLError\n            since no message will be passed to those nodes. This will cause invalid output.\n            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.\n\n        Note\n        ----\n        If ``cached`` is set to True, ``feat`` and ``graph`` should not change during\n        training, or you will get wrong results.\n        \"\"\"\n        with graph.local_scope():\n            if not self._allow_zero_in_degree:\n                if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:\n                    raise DGLError(\n                        \"There are 0-in-degree nodes in the graph, \"\n                        \"output for those nodes will be invalid. \"\n                        \"This is harmful for some applications, \"\n                        \"causing silent performance regression. 
\"\n                        \"Adding self-loop on the input graph by \"\n                        \"calling `g = dgl.add_self_loop(g)` will resolve \"\n                        \"the issue. Setting ``allow_zero_in_degree`` \"\n                        \"to be `True` when constructing this module will \"\n                        \"suppress the check and let the code run.\"\n                    )\n\n            if self._cached_h is not None:\n                feat = self._cached_h\n            else:\n                # compute normalization\n                degs = tf.clip_by_value(\n                    tf.cast(graph.in_degrees(), tf.float32),\n                    clip_value_min=1,\n                    clip_value_max=np.inf,\n                )\n                norm = tf.pow(degs, -0.5)\n                norm = tf.expand_dims(norm, 1)\n                # compute (D^-1/2 A D^-1/2)^k X\n                for _ in range(self._k):\n                    feat = feat * norm\n                    graph.ndata[\"h\"] = feat\n                    graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n                    feat = graph.ndata.pop(\"h\")\n                    feat = feat * norm\n\n                if self.norm is not None:\n                    feat = self.norm(feat)\n\n                # cache feature\n                if self._cached:\n                    self._cached_h = feat\n            return self.fc(feat)\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/glob.py",
    "content": "\"\"\"Tensorflow modules for graph global pooling.\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, W0235\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\nfrom ...readout import (\n    max_nodes,\n    mean_nodes,\n    softmax_nodes,\n    sum_nodes,\n    topk_nodes,\n)\n\n__all__ = [\n    \"SumPooling\",\n    \"AvgPooling\",\n    \"MaxPooling\",\n    \"SortPooling\",\n    \"WeightAndSum\",\n    \"GlobalAttentionPooling\",\n]\n\n\nclass SumPooling(layers.Layer):\n    r\"\"\"Apply sum pooling over the nodes in the graph.\n\n    .. math::\n        r^{(i)} = \\sum_{k=1}^{N_i} x^{(i)}_k\n    \"\"\"\n\n    def __init__(self):\n        super(SumPooling, self).__init__()\n\n    def call(self, graph, feat):\n        r\"\"\"Compute sum pooling.\n\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor\n            The input feature with shape :math:`(N, *)` where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature with shape :math:`(B, *)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        with graph.local_scope():\n            graph.ndata[\"h\"] = feat\n            readout = sum_nodes(graph, \"h\")\n            return readout\n\n\nclass AvgPooling(layers.Layer):\n    r\"\"\"Apply average pooling over the nodes in the graph.\n\n    .. 
math::\n        r^{(i)} = \\frac{1}{N_i}\\sum_{k=1}^{N_i} x^{(i)}_k\n    \"\"\"\n\n    def __init__(self):\n        super(AvgPooling, self).__init__()\n\n    def call(self, graph, feat):\n        r\"\"\"Compute average pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor\n            The input feature with shape :math:`(N, *)` where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature with shape :math:`(B, *)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        with graph.local_scope():\n            graph.ndata[\"h\"] = feat\n            readout = mean_nodes(graph, \"h\")\n            return readout\n\n\nclass MaxPooling(layers.Layer):\n    r\"\"\"Apply max pooling over the nodes in the graph.\n\n    .. math::\n        r^{(i)} = \\max_{k=1}^{N_i}\\left( x^{(i)}_k \\right)\n    \"\"\"\n\n    def __init__(self):\n        super(MaxPooling, self).__init__()\n\n    def call(self, graph, feat):\n        r\"\"\"Compute max pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor\n            The input feature with shape :math:`(N, *)` where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature with shape :math:`(B, *)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        with graph.local_scope():\n            graph.ndata[\"h\"] = feat\n            readout = max_nodes(graph, \"h\")\n            return readout\n\n\nclass SortPooling(layers.Layer):\n    r\"\"\"Sort Pooling from `An End-to-End Deep Learning Architecture for Graph Classification\n    <https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__\n\n    Parameters\n    ----------\n    k : int\n        The number of nodes to hold for each 
graph.\n    \"\"\"\n\n    def __init__(self, k):\n        super(SortPooling, self).__init__()\n        self.k = k\n\n    def call(self, graph, feat):\n        r\"\"\"Compute sort pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor\n            The input node feature with shape :math:`(N, D)` where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature with shape :math:`(B, k * D)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        with graph.local_scope():\n            # Sort the feature of each node in ascending order.\n            feat = tf.sort(feat, -1)\n            graph.ndata[\"h\"] = feat\n            # Sort nodes according to their last features.\n            ret = tf.reshape(\n                topk_nodes(graph, \"h\", self.k, sortby=-1)[0],\n                (-1, self.k * feat.shape[-1]),\n            )\n            return ret\n\n\nclass GlobalAttentionPooling(layers.Layer):\n    r\"\"\"Global Attention Pooling from `Gated Graph Sequence Neural Networks\n    <https://arxiv.org/abs/1511.05493.pdf>`__\n\n    .. 
math::\n        r^{(i)} = \\sum_{k=1}^{N_i}\\mathrm{softmax}\\left(f_{gate}\n        \\left(x^{(i)}_k\\right)\\right) f_{feat}\\left(x^{(i)}_k\\right)\n\n    Parameters\n    ----------\n    gate_nn : tf.layers.Layer\n        A neural network that computes attention scores for each feature.\n    feat_nn : tf.layers.Layer, optional\n        A neural network applied to each feature before combining them\n        with attention scores.\n    \"\"\"\n\n    def __init__(self, gate_nn, feat_nn=None):\n        super(GlobalAttentionPooling, self).__init__()\n        self.gate_nn = gate_nn\n        self.feat_nn = feat_nn\n\n    def call(self, graph, feat):\n        r\"\"\"Compute global attention pooling.\n\n        Parameters\n        ----------\n        graph : DGLGraph\n            The graph.\n        feat : tf.Tensor\n            The input node feature with shape :math:`(N, D)` where\n            :math:`N` is the number of nodes in the graph.\n\n        Returns\n        -------\n        tf.Tensor\n            The output feature with shape :math:`(B, *)`, where\n            :math:`B` refers to the batch size.\n        \"\"\"\n        with graph.local_scope():\n            gate = self.gate_nn(feat)\n            assert (\n                gate.shape[-1] == 1\n            ), \"The output of gate_nn should have size 1 at the last axis.\"\n            feat = self.feat_nn(feat) if self.feat_nn else feat\n\n            graph.ndata[\"gate\"] = gate\n            gate = softmax_nodes(graph, \"gate\")\n            graph.ndata.pop(\"gate\")\n\n            graph.ndata[\"r\"] = feat * gate\n            readout = sum_nodes(graph, \"r\")\n            graph.ndata.pop(\"r\")\n\n            return readout\n\n\nclass WeightAndSum(layers.Layer):\n    \"\"\"Compute importance weights for atoms and perform a weighted sum.\n\n    Parameters\n    ----------\n    in_feats : int\n        Input atom feature size\n    \"\"\"\n\n    def __init__(self, in_feats):\n        super(WeightAndSum, 
self).__init__()\n        self.in_feats = in_feats\n        self.atom_weighting = tf.keras.Sequential(\n            layers.Dense(1), layers.Activation(tf.nn.sigmoid)\n        )\n\n    def call(self, g, feats):\n        \"\"\"Compute molecule representations out of atom representations\n\n        Parameters\n        ----------\n        g : DGLGraph\n            DGLGraph with batch size B for processing multiple molecules in parallel\n        feats : FloatTensor of shape (N, self.in_feats)\n            Representations for all atoms in the molecules\n            * N is the total number of atoms in all molecules\n\n        Returns\n        -------\n        FloatTensor of shape (B, self.in_feats)\n            Representations for B molecules\n        \"\"\"\n        with g.local_scope():\n            g.ndata[\"h\"] = feats\n            g.ndata[\"w\"] = self.atom_weighting(g.ndata[\"h\"])\n            h_g_sum = sum_nodes(g, \"h\", \"w\")\n\n        return h_g_sum\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/hetero.py",
    "content": "\"\"\"Heterograph NN modules\"\"\"\nimport tensorflow as tf\nfrom tensorflow.keras import layers\n\n__all__ = [\"HeteroGraphConv\"]\n\n\nclass HeteroGraphConv(layers.Layer):\n    r\"\"\"A generic module for computing convolution on heterogeneous graphs.\n\n    The heterograph convolution applies sub-modules on their associating\n    relation graphs, which reads the features from source nodes and writes the\n    updated ones to destination nodes. If multiple relations have the same\n    destination node types, their results are aggregated by the specified method.\n    If the relation graph has no edge, the corresponding module will not be called.\n\n    Pseudo-code:\n\n    .. code::\n\n        outputs = {nty : [] for nty in g.dsttypes}\n        # Apply sub-modules on their associating relation graphs in parallel\n        for relation in g.canonical_etypes:\n            stype, etype, dtype = relation\n            dstdata = relation_submodule(g[relation], ...)\n            outputs[dtype].append(dstdata)\n\n        # Aggregate the results for each destination node type\n        rsts = {}\n        for ntype, ntype_outputs in outputs.items():\n            if len(ntype_outputs) != 0:\n                rsts[ntype] = aggregate(ntype_outputs)\n        return rsts\n\n    Examples\n    --------\n\n    Create a heterograph with three types of relations and nodes.\n\n    >>> import dgl\n    >>> g = dgl.heterograph({\n    ...     ('user', 'follows', 'user') : edges1,\n    ...     ('user', 'plays', 'game') : edges2,\n    ...     ('store', 'sells', 'game')  : edges3})\n\n    Create a ``HeteroGraphConv`` that applies different convolution modules to\n    different relations. Note that the modules for ``'follows'`` and ``'plays'``\n    do not share weights.\n\n    >>> import dgl.nn.pytorch as dglnn\n    >>> conv = dglnn.HeteroGraphConv({\n    ...     'follows' : dglnn.GraphConv(...),\n    ...     'plays' : dglnn.GraphConv(...),\n    ...     
'sells' : dglnn.SAGEConv(...)},\n    ...     aggregate='sum')\n\n    Call forward with some ``'user'`` features. This computes new features for both\n    ``'user'`` and ``'game'`` nodes.\n\n    >>> import tensorflow as tf\n    >>> h1 = {'user' : tf.random.normal((g.num_nodes('user'), 5))}\n    >>> h2 = conv(g, h1)\n    >>> print(h2.keys())\n    dict_keys(['user', 'game'])\n\n    Call forward with both ``'user'`` and ``'store'`` features. Because both the\n    ``'plays'`` and ``'sells'`` relations will update the ``'game'`` features,\n    their results are aggregated by the specified method (i.e., summation here).\n\n    >>> f1 = {'user' : ..., 'store' : ...}\n    >>> f2 = conv(g, f1)\n    >>> print(f2.keys())\n    dict_keys(['user', 'game'])\n\n    Call forward with some ``'store'`` features. This only computes new features\n    for ``'game'`` nodes.\n\n    >>> g1 = {'store' : ...}\n    >>> g2 = conv(g, g1)\n    >>> print(g2.keys())\n    dict_keys(['game'])\n\n    Call forward with a pair of inputs is allowed and each submodule will also\n    be invoked with a pair of inputs.\n\n    >>> x_src = {'user' : ..., 'store' : ...}\n    >>> x_dst = {'user' : ..., 'game' : ...}\n    >>> y_dst = conv(g, (x_src, x_dst))\n    >>> print(y_dst.keys())\n    dict_keys(['user', 'game'])\n\n    Notes\n    -----\n\n    HeteroGraphConv requires that there is a module for every ``'etype'`` in an input graph.\n    If you want to apply HeteroGraphConv to a subset of a graph's ``'etypes'``, you must\n    create a new graph using for example :func:`~dgl.edge_type_subgraph()`.\n\n    Parameters\n    ----------\n    mods : dict[str, nn.Module]\n        Modules associated with every edge types. 
The forward function of each\n        module must have a `DGLGraph` object as the first argument, and\n        its second argument is either a tensor object representing the node\n        features or a pair of tensor objects representing the source and destination\n        node features.\n    aggregate : str, callable, optional\n        Method for aggregating node features generated by different relations.\n        Allowed string values are 'sum', 'max', 'min', 'mean', 'stack'.\n        The 'stack' aggregation is performed along the second dimension, whose order\n        is deterministic.\n        Users can also customize the aggregator by providing a callable instance.\n        For example, aggregation by summation is equivalent to the following:\n\n        .. code::\n\n            def my_agg_func(tensors, dsttype):\n                # tensors: a list of tensors to aggregate\n                # dsttype: string name of the destination node type for which the\n                #          aggregation is performed\n                stacked = tf.stack(tensors, axis=0)\n                return tf.reduce_sum(stacked, axis=0)\n\n    Attributes\n    ----------\n    mods : dict[str, nn.Module]\n        Modules associated with every edge type.\n    \"\"\"\n\n    def __init__(self, mods, aggregate=\"sum\"):\n        super(HeteroGraphConv, self).__init__()\n        self.mods = mods\n        # Do not break if graph has 0-in-degree nodes.\n        # Because there is no general rule to add self-loop for heterograph.\n        for _, v in self.mods.items():\n            set_allow_zero_in_degree_fn = getattr(\n                v, \"set_allow_zero_in_degree\", None\n            )\n            if callable(set_allow_zero_in_degree_fn):\n                set_allow_zero_in_degree_fn(True)\n        if isinstance(aggregate, str):\n            self.agg_fn = get_aggregate_fn(aggregate)\n        else:\n            self.agg_fn = aggregate\n\n    def call(self, g, inputs, mod_args=None, 
mod_kwargs=None):\n        \"\"\"Forward computation\n\n        Invoke the forward function with each module and aggregate their results.\n\n        Parameters\n        ----------\n        g : DGLGraph\n            Graph data.\n        inputs : dict[str, Tensor] or pair of dict[str, Tensor]\n            Input node features.\n        mod_args : dict[str, tuple[any]], optional\n            Extra positional arguments for the sub-modules.\n        mod_kwargs : dict[str, dict[str, any]], optional\n            Extra key-word arguments for the sub-modules.\n\n        Returns\n        -------\n        dict[str, Tensor]\n            Output representations for every types of nodes.\n        \"\"\"\n        if mod_args is None:\n            mod_args = {}\n        if mod_kwargs is None:\n            mod_kwargs = {}\n        outputs = {nty: [] for nty in g.dsttypes}\n        if isinstance(inputs, tuple):\n            src_inputs, dst_inputs = inputs\n            for stype, etype, dtype in g.canonical_etypes:\n                rel_graph = g[stype, etype, dtype]\n                if stype not in src_inputs or dtype not in dst_inputs:\n                    continue\n                dstdata = self.mods[etype](\n                    rel_graph,\n                    (src_inputs[stype], dst_inputs[dtype]),\n                    *mod_args.get(etype, ()),\n                    **mod_kwargs.get(etype, {})\n                )\n                outputs[dtype].append(dstdata)\n        else:\n            for stype, etype, dtype in g.canonical_etypes:\n                rel_graph = g[stype, etype, dtype]\n                if stype not in inputs:\n                    continue\n                dstdata = self.mods[etype](\n                    rel_graph,\n                    (inputs[stype], inputs[dtype]),\n                    *mod_args.get(etype, ()),\n                    **mod_kwargs.get(etype, {})\n                )\n                outputs[dtype].append(dstdata)\n        rsts = {}\n        for nty, alist 
in outputs.items():\n            if len(alist) != 0:\n                rsts[nty] = self.agg_fn(alist, nty)\n        return rsts\n\n\ndef get_aggregate_fn(agg):\n    \"\"\"Internal function to get the aggregation function for node data\n    generated from different relations.\n\n    Parameters\n    ----------\n    agg : str\n        Method for aggregating node features generated by different relations.\n        Allowed values are 'sum', 'max', 'min', 'mean', 'stack'.\n\n    Returns\n    -------\n    callable\n        Aggregator function that takes a list of tensors to aggregate\n        and returns one aggregated tensor.\n    \"\"\"\n    if agg == \"sum\":\n        fn = tf.reduce_sum\n    elif agg == \"max\":\n        fn = tf.reduce_max\n    elif agg == \"min\":\n        fn = tf.reduce_min\n    elif agg == \"mean\":\n        fn = tf.reduce_mean\n    elif agg == \"stack\":\n        fn = None  # will not be called\n    else:\n        raise DGLError(\n            \"Invalid cross type aggregator. Must be one of \"\n            '\"sum\", \"max\", \"min\", \"mean\" or \"stack\". But got \"%s\"' % agg\n        )\n    if agg == \"stack\":\n\n        def stack_agg(inputs, dsttype):  # pylint: disable=unused-argument\n            if len(inputs) == 0:\n                return None\n            return tf.stack(inputs, axis=1)\n\n        return stack_agg\n    else:\n\n        def aggfn(inputs, dsttype):  # pylint: disable=unused-argument\n            if len(inputs) == 0:\n                return None\n            stacked = tf.stack(inputs, axis=0)\n            return fn(stacked, axis=0)\n\n        return aggfn\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/softmax.py",
    "content": "\"\"\"tf modules for graph related softmax.\"\"\"\n# pylint: disable= unused-import\nfrom ..functional import edge_softmax\n"
  },
  {
    "path": "python/dgl/nn/tensorflow/utils.py",
    "content": "\"\"\"Utilities for tf NN package\"\"\"\n# pylint: disable=no-member, invalid-name\nimport tensorflow as tf\nfrom tensorflow.keras import layers  # pylint: disable=W0235\n\n\ndef matmul_maybe_select(A, B):\n    \"\"\"Perform Matrix multiplication C = A * B but A could be an integer id vector.\n\n    If A is an integer vector, we treat it as multiplying a one-hot encoded tensor.\n    In this case, the expensive dense matrix multiply can be replaced by a much\n    cheaper index lookup.\n\n    For example,\n    ::\n\n        A = [2, 0, 1],\n        B = [[0.1, 0.2],\n             [0.3, 0.4],\n             [0.5, 0.6]]\n\n    then matmul_maybe_select(A, B) is equivalent to\n    ::\n\n        [[0, 0, 1],     [[0.1, 0.2],\n         [1, 0, 0],  *   [0.3, 0.4],\n         [0, 1, 0]]      [0.5, 0.6]]\n\n    In all other cases, perform a normal matmul.\n\n    Parameters\n    ----------\n    A : tf.Tensor\n        lhs tensor\n    B : tf.Tensor\n        rhs tensor\n\n    Returns\n    -------\n    C : tf.Tensor\n        result tensor\n    \"\"\"\n    if A.dtype == tf.int64 and len(A.shape) == 1:\n        return tf.gather(B, A)\n    else:\n        return tf.matmul(A, B)\n\n\ndef bmm_maybe_select(A, B, index):\n    \"\"\"Slice submatrices of A by the given index and perform bmm.\n\n    B is a 3D tensor of shape (N, D1, D2), which can be viewed as a stack of\n    N matrices of shape (D1, D2). 
The input index is an integer vector of length M.\n    A could be either:\n    (1) a dense tensor of shape (M, D1),\n    (2) an integer vector of length M.\n    The result C is a 2D matrix of shape (M, D2)\n\n    For case (1), C is computed by bmm:\n    ::\n\n        C[i, :] = matmul(A[i, :], B[index[i], :, :])\n\n    For case (2), C is computed by index select:\n    ::\n\n        C[i, :] = B[index[i], A[i], :]\n\n    Parameters\n    ----------\n    A : tf.Tensor\n        lhs tensor\n    B : tf.Tensor\n        rhs tensor\n    index : tf.Tensor\n        index tensor\n\n    Returns\n    -------\n    C : tf.Tensor\n        return tensor\n    \"\"\"\n    if A.dtype == tf.int64 and len(A.shape) == 1:\n        # following is a faster version of B[index, A, :]\n        B = tf.reshape(B, (-1, B.shape[2]))\n        flatidx = index * B.shape[1] + A\n        return tf.gather(B, flatidx)\n    else:\n        BB = tf.gather(B, index)\n        return tf.squeeze(tf.matmul(tf.expand_dims(A, 1), BB), 1)\n\n\nclass Identity(layers.Layer):\n    \"\"\"A placeholder identity operator that is argument-insensitive.\"\"\"\n\n    def call(self, x):\n        \"\"\"Return input\"\"\"\n        return x\n"
  },
  {
    "path": "python/dgl/ops/__init__.py",
    "content": "\"\"\"dgl operator module.\"\"\"\nfrom .edge_softmax import *\nfrom .gather_mm import *\nfrom .sddmm import *\nfrom .segment import *\nfrom .spmm import *\n"
  },
  {
    "path": "python/dgl/ops/edge_softmax.py",
    "content": "\"\"\"dgl edge_softmax operator module.\"\"\"\nfrom ..backend import (\n    astype,\n    edge_softmax as edge_softmax_internal,\n    edge_softmax_hetero as edge_softmax_hetero_internal,\n)\nfrom ..base import ALL, is_all\n\n__all__ = [\"edge_softmax\"]\n\n\ndef edge_softmax(graph, logits, eids=ALL, norm_by=\"dst\"):\n    r\"\"\"Compute softmax over weights of incoming edges for every node.\n\n    For a node :math:`i`, edge softmax is an operation that computes\n\n    .. math::\n      a_{ij} = \\frac{\\exp(z_{ij})}{\\sum_{j\\in\\mathcal{N}(i)}\\exp(z_{ij})}\n\n    where :math:`z_{ij}` is a signal of edge :math:`j\\rightarrow i`, also\n    called logits in the context of softmax. :math:`\\mathcal{N}(i)` is\n    the set of nodes that have an edge to :math:`i`.\n\n    By default edge softmax is normalized by destination nodes(i.e. :math:`ij`\n    are incoming edges of `i` in the formula above). We also support edge\n    softmax normalized by source nodes(i.e. :math:`ij` are outgoing edges of\n    `i` in the formula). The former case corresponds to softmax in GAT and\n    Transformer, and the latter case corresponds to softmax in Capsule network.\n    An example of using edge softmax is in\n    `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ where\n    the attention weights are computed with this operation.\n    Other non-GNN examples using this are\n    `Transformer <https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf>`__,\n    `Capsule <https://arxiv.org/pdf/1710.09829.pdf>`__, etc.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph over which edge softmax will be performed.\n    logits : torch.Tensor or dict of torch.Tensor\n        The input edge feature. Heterogeneous graphs can have dict of tensors where\n        each tensor stores the edge features of the corresponding relation type.\n    eids : torch.Tensor or ALL, optional\n        The IDs of the edges to apply edge softmax. 
If ALL, it will apply edge\n        softmax to all edges in the graph. Default: ALL.\n    norm_by : str, could be `src` or `dst`\n        Normalized by source nodes or destination nodes. Default: `dst`.\n\n    Returns\n    -------\n    Tensor or tuple of tensors\n        Softmax value.\n\n    Notes\n    -----\n        * Input shape: :math:`(E, *, 1)` where * means any number of\n          additional dimensions, :math:`E` equals the length of eids.\n          If the `eids` is ALL, :math:`E` equals the number of edges in\n          the graph.\n        * Return shape: :math:`(E, *, 1)`\n\n    Examples on a homogeneous graph\n    -------------------------------\n    The following example uses PyTorch backend.\n\n    >>> from dgl.nn.functional import edge_softmax\n    >>> import dgl\n    >>> import torch as th\n\n    Create a :code:`DGLGraph` object and initialize its edge features.\n\n    >>> g = dgl.graph((th.tensor([0, 0, 0, 1, 1, 2]), th.tensor([0, 1, 2, 1, 2, 2])))\n    >>> edata = th.ones(6, 1).float()\n    >>> edata\n        tensor([[1.],\n                [1.],\n                [1.],\n                [1.],\n                [1.],\n                [1.]])\n\n    Apply edge softmax over g:\n\n    >>> edge_softmax(g, edata)\n        tensor([[1.0000],\n                [0.5000],\n                [0.3333],\n                [0.5000],\n                [0.3333],\n                [0.3333]])\n\n    Apply edge softmax over g normalized by source nodes:\n\n    >>> edge_softmax(g, edata, norm_by='src')\n        tensor([[0.3333],\n                [0.3333],\n                [0.3333],\n                [0.5000],\n                [0.5000],\n                [1.0000]])\n\n    Apply edge softmax to first 4 edges of g:\n\n    >>> edge_softmax(g, edata[:4], th.Tensor([0,1,2,3]))\n        tensor([[1.0000],\n                [0.5000],\n                [1.0000],\n                [0.5000]])\n\n\n    Examples on a heterogeneous graph\n    ---------------------------------\n\n    Create a 
heterogeneous graph and initialize its edge features.\n\n    >>> hg = dgl.heterograph({\n    ...     ('user', 'follows', 'user'): ([0, 0, 1], [0, 1, 2]),\n    ...     ('developer', 'develops', 'game'): ([0, 1], [0, 1])\n    ...     })\n    >>> edata_follows = th.ones(3, 1).float()\n    >>> edata_develops = th.ones(2, 1).float()\n    >>> edata_dict = {('user', 'follows', 'user'): edata_follows,\n    ... ('developer','develops', 'game'): edata_develops}\n\n    Apply edge softmax over hg normalized by source nodes:\n\n    >>> edge_softmax(hg, edata_dict, norm_by='src')\n        {('developer', 'develops', 'game'): tensor([[1.],\n        [1.]]), ('user', 'follows', 'user'): tensor([[0.5000],\n        [0.5000],\n        [1.0000]])}\n    \"\"\"\n    if not is_all(eids):\n        eids = astype(eids, graph.idtype)\n    if graph._graph.number_of_etypes() == 1:\n        return edge_softmax_internal(\n            graph._graph, logits, eids=eids, norm_by=norm_by\n        )\n    else:\n        logits_list = [None] * graph._graph.number_of_etypes()\n        logits = {graph.to_canonical_etype(k): v for k, v in logits.items()}\n        for rel in graph.canonical_etypes:\n            etid = graph.get_etype_id(rel)\n            logits_list[etid] = logits[rel]\n        logits_tuple = tuple(logits_list)\n        score_tuple = edge_softmax_hetero_internal(\n            graph._graph, eids, norm_by, *logits_tuple\n        )\n        score = {}\n        for rel in graph.canonical_etypes:\n            etid = graph.get_etype_id(rel)\n            score[rel] = score_tuple[etid]\n        return score\n"
  },
  {
    "path": "python/dgl/ops/gather_mm.py",
    "content": "\"\"\"dgl gather_mm operator module.\"\"\"\nfrom .. import backend as F\n\n__all__ = [\"gather_mm\"]\n\n\ndef gather_mm(a, b, *, idx_b):\n    r\"\"\"Gather data according to the given indices and perform matrix multiplication.\n\n    Let the result tensor be ``c``, the operator conducts the following computation:\n\n      c[i] = a[i] @ b[idx_b[i]]\n      , where len(c) == len(idx_b)\n\n\n    Parameters\n    ----------\n    a : Tensor\n        A 2-D tensor of shape ``(N, D1)``\n    b : Tensor\n        A 3-D tensor of shape ``(R, D1, D2)``\n    idx_b : Tensor, optional\n        An 1-D integer tensor of shape ``(N,)``.\n\n    Returns\n    -------\n    Tensor\n        The output dense matrix of shape ``(N, D2)``\n    \"\"\"\n    N, D1 = F.shape(a)\n    R, _, D2 = F.shape(b)\n    if N > 1000000 or D1 > 8 or D2 > 8:\n        # Use segment_mm for large workload\n        import torch\n\n        sorted_idx_b, perm = torch.sort(idx_b)\n        _, rev_perm = torch.sort(perm)\n        sorted_a = torch.index_select(a, 0, perm)\n        pos_l = torch.searchsorted(\n            sorted_idx_b, torch.arange(R, device=a.device)\n        )\n        pos_r = torch.cat(\n            [pos_l[1:], torch.tensor([len(idx_b)], device=a.device)]\n        )\n        seglen = (pos_r - pos_l).cpu()  # XXX(minjie): cause device synchronize\n        return torch.index_select(\n            F.segment_mm(sorted_a, b, seglen), 0, rev_perm\n        )\n    else:\n        return F.gather_mm(a, b, None, idx_b)\n"
  },
  {
    "path": "python/dgl/ops/sddmm.py",
    "content": "\"\"\"dgl sddmm operator module.\"\"\"\nimport sys\nfrom itertools import product\n\nfrom .. import backend as F\nfrom ..backend import (\n    gsddmm as gsddmm_internal,\n    gsddmm_hetero as gsddmm_internal_hetero,\n)\n\n__all__ = [\"gsddmm\", \"copy_u\", \"copy_v\", \"copy_e\"]\n\n\ndef reshape_lhs_rhs(lhs_data, rhs_data):\n    r\"\"\"Expand dims so that there will be no broadcasting issues with different\n    number of dimensions. For example, given two shapes (N, 3, 1), (E, 5, 3, 4)\n    that are valid broadcastable shapes, change them to (N, 1, 3, 1) and\n    (E, 5, 3, 4)\n\n    Parameters\n    ----------\n    lhs_data : tensor or None\n        The left operand, could be None if it's not required by op.\n    rhs_data : tensor or None\n        The right operand, could be None if it's not required by op.\n    \"\"\"\n    lhs_shape = F.shape(lhs_data)\n    rhs_shape = F.shape(rhs_data)\n    if len(lhs_shape) != len(rhs_shape):\n        max_ndims = max(len(lhs_shape), len(rhs_shape))\n        lhs_pad_ndims = max_ndims - len(lhs_shape)\n        rhs_pad_ndims = max_ndims - len(rhs_shape)\n        new_lhs_shape = (lhs_shape[0],) + (1,) * lhs_pad_ndims + lhs_shape[1:]\n        new_rhs_shape = (rhs_shape[0],) + (1,) * rhs_pad_ndims + rhs_shape[1:]\n        lhs_data = F.reshape(lhs_data, new_lhs_shape)\n        rhs_data = F.reshape(rhs_data, new_rhs_shape)\n    return lhs_data, rhs_data\n\n\ndef gsddmm(g, op, lhs_data, rhs_data, lhs_target=\"u\", rhs_target=\"v\"):\n    r\"\"\"Generalized Sampled-Dense-Dense Matrix Multiplication interface.\n    It computes edge features by :attr:`op` lhs features and rhs features.\n\n    .. math::\n\n        x_{e} = \\phi(x_{lhs}, x_{rhs}), \\forall (u,e,v)\\in \\mathcal{G}\n\n    where :math:`x_{e}` is the returned feature on edges and :math:`x_u`,\n    :math:`x_v` refers to :attr:`u`, :attr:`v` respectively. 
:math:`\\phi`\n    is the binary operator :attr:`op`, and :math:`\\mathcal{G}` is the graph\n    we apply gsddmm on: :attr:`g`. :math:`lhs` and :math:`rhs` are one of\n    :math:`u,v,e`'s.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    op : str\n        Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,\n        ``copy_lhs``, ``copy_rhs``.\n    lhs_data : tensor or None\n        The left operand, could be None if it's not required by op.\n    rhs_data : tensor or None\n        The right operand, could be None if it's not required by op.\n    lhs_target: str\n        Choice of ``u``(source), ``e``(edge) or ``v``(destination) for left operand.\n    rhs_target: str\n        Choice of ``u``(source), ``e``(edge) or ``v``(destination) for right operand.\n\n    Returns\n    -------\n    tensor\n        The result tensor.\n    \"\"\"\n    if g._graph.number_of_etypes() == 1:\n        if op not in [\"copy_lhs\", \"copy_rhs\"]:\n            lhs_data, rhs_data = reshape_lhs_rhs(lhs_data, rhs_data)\n        return gsddmm_internal(\n            g._graph, op, lhs_data, rhs_data, lhs_target, rhs_target\n        )\n    else:\n        if op == \"copy_lhs\":\n            rhs_data = [None] * g._graph.number_of_etypes()\n        elif op == \"copy_rhs\":\n            lhs_data = [None] * g._graph.number_of_ntypes()\n        # TODO (Israt): Call reshape_lhs_rhs() on lhs and rhs data to match their dimension\n        # and avoid broadcasting issue. 
Handle the case where different nodes have\n        # different dimensions, and different etypes may need different broadcasting\n        # dims for the same node.\n        lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data))\n        return gsddmm_internal_hetero(\n            g._graph,\n            op,\n            len(lhs_data),\n            lhs_target,\n            rhs_target,\n            *lhs_and_rhs_tuple\n        )\n\n\ndef _gen_sddmm_func(lhs_target, rhs_target, binary_op):\n    name = \"{}_{}_{}\".format(lhs_target, binary_op, rhs_target)\n    target_dict = {\"u\": \"source node\", \"e\": \"edge\", \"v\": \"destination node\"}\n    lhs_str = target_dict[lhs_target]\n    rhs_str = target_dict[rhs_target]\n    docstring = r\"\"\"Generalized SDDMM function.\n    It computes edge features by {op} {lhs} features and {rhs} features.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph\n    x : tensor\n        The {lhs} features.\n    y : tensor\n        The {rhs} features.\n\n    Returns\n    -------\n    tensor\n        The result tensor.\n\n    Notes\n    -----\n    This function supports autograd (computing input gradients given the output gradient). If the\n    feature shape of two input operands do not match, we first broadcasts the features to a unified\n    shape (note that the memory usage will not increase accordingly) and then performs the operation.\n\n    Broadcasting follows NumPy semantics. 
Please see\n    https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html\n    for more details about the NumPy broadcasting semantics.\n    \"\"\".format(\n        op=binary_op, lhs=lhs_str, rhs=rhs_str\n    )\n\n    def func(g, x, y):\n        return gsddmm(\n            g, binary_op, x, y, lhs_target=lhs_target, rhs_target=rhs_target\n        )\n\n    func.__name__ = name\n    func.__doc__ = docstring\n    return func\n\n\ndef _register_sddmm_func():\n    \"\"\"Register sddmm functions\"\"\"\n    target = [\"u\", \"v\", \"e\"]\n    for lhs, rhs in product(target, target):\n        if lhs != rhs:\n            for binary_op in [\"add\", \"sub\", \"mul\", \"div\", \"dot\"]:\n                func = _gen_sddmm_func(lhs, rhs, binary_op)\n                setattr(sys.modules[__name__], func.__name__, func)\n                __all__.append(func.__name__)\n\n\ndef copy_u(g, x):\n    r\"\"\"Generalized SDDMM function that copies source node features to edges.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    x : tensor\n        The source node features.\n\n    Returns\n    -------\n    tensor\n        The result tensor.\n\n    Notes\n    -----\n    This function supports autograd (computing input gradients given the output gradient).\n    \"\"\"\n    return gsddmm(g, \"copy_lhs\", x, None)\n\n\ndef copy_v(g, x):\n    r\"\"\"Generalized SDDMM function that copies destination node features to edges.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    x : tensor\n        The destination node features.\n\n    Returns\n    -------\n    tensor\n        The result tensor.\n\n    Notes\n    -----\n    This function supports autograd (computing input gradients given the output gradient).\n    \"\"\"\n    return gsddmm(g, \"copy_rhs\", None, x)\n\n\n# pylint: disable=unused-argument\ndef copy_e(g, x):\n    r\"\"\"Generalized SDDMM function that copies edge features to edges (returns ``x`` unchanged).\"\"\"\n    return x\n\n\n_register_sddmm_func()\n"
  },
  {
    "path": "python/dgl/ops/segment.py",
    "content": "\"\"\"Segment aggregation operators implemented using DGL graph.\"\"\"\n\nfrom .. import backend as F\nfrom ..base import DGLError\n\n__all__ = [\"segment_reduce\", \"segment_softmax\", \"segment_mm\"]\n\n\ndef segment_reduce(seglen, value, reducer=\"sum\"):\n    \"\"\"Segment reduction operator.\n\n    It aggregates the value tensor along the first dimension by segments.\n    The first argument ``seglen`` stores the length of each segment. Its\n    summation must be equal to the first dimension of the ``value`` tensor.\n    Zero-length segments are allowed.\n\n    Parameters\n    ----------\n    seglen : Tensor\n        Segment lengths.\n    value : Tensor\n        Value to aggregate.\n    reducer : str, optional\n        Aggregation method. Can be 'sum', 'max', 'min', 'mean'.\n\n    Returns\n    -------\n    Tensor\n        Aggregated tensor of shape ``(len(seglen), value.shape[1:])``.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch as th\n    >>> val = th.ones(10, 3)\n    >>> seg = th.tensor([1, 0, 5, 4])  # 4 segments\n    >>> dgl.segment_reduce(seg, val)\n    tensor([[1., 1., 1.],\n            [0., 0., 0.],\n            [5., 5., 5.],\n            [4., 4., 4.]])\n    \"\"\"\n    offsets = F.cumsum(\n        F.cat([F.zeros((1,), F.dtype(seglen), F.context(seglen)), seglen], 0), 0\n    )\n    if reducer == \"mean\":\n        rst = F.segment_reduce(\"sum\", value, offsets)\n        rst_shape = F.shape(rst)\n        z = F.astype(F.clamp(seglen, 1, len(value)), F.dtype(rst))\n        z_shape = (rst_shape[0],) + (1,) * (len(rst_shape) - 1)\n        return rst / F.reshape(z, z_shape)\n    elif reducer in [\"min\", \"sum\", \"max\"]:\n        rst = F.segment_reduce(reducer, value, offsets)\n        if reducer in [\"min\", \"max\"]:\n            rst = F.replace_inf_with_zero(rst)\n        return rst\n    else:\n        raise DGLError(\"reducer {} not recognized.\".format(reducer))\n\n\ndef segment_softmax(seglen, value):\n    
\"\"\"Performa softmax on each segment.\n\n    The first argument ``seglen`` stores the length of each segment. Its\n    summation must be equal to the first dimension of the ``value`` tensor.\n    Zero-length segments are allowed.\n\n    Parameters\n    ----------\n    seglen : Tensor\n        Segment lengths.\n    value : Tensor\n        Value to aggregate.\n\n    Returns\n    -------\n    Tensor\n        Result tensor of the same shape as the ``value`` tensor.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch as th\n    >>> val = th.ones(10, 3)\n    >>> seg = th.tensor([1, 0, 5, 4])  # 4 segments\n    >>> dgl.segment_softmax(seg, val)\n    tensor([[1.0000, 1.0000, 1.0000],\n            [0.2000, 0.2000, 0.2000],\n            [0.2000, 0.2000, 0.2000],\n            [0.2000, 0.2000, 0.2000],\n            [0.2000, 0.2000, 0.2000],\n            [0.2000, 0.2000, 0.2000],\n            [0.2500, 0.2500, 0.2500],\n            [0.2500, 0.2500, 0.2500],\n            [0.2500, 0.2500, 0.2500],\n            [0.2500, 0.2500, 0.2500]])\n    \"\"\"\n    value_max = segment_reduce(seglen, value, reducer=\"max\")\n    value = F.exp(value - F.repeat(value_max, seglen, dim=0))\n    value_sum = segment_reduce(seglen, value, reducer=\"sum\")\n    return value / F.repeat(value_sum, seglen, dim=0)\n\n\ndef segment_mm(a, b, seglen_a):\n    r\"\"\"Performs matrix multiplication according to segments.\n\n    Suppose ``seglen_a == [10, 5, 0, 3]``, the operator will perform\n    four matrix multiplications::\n\n        a[0:10] @ b[0], a[10:15] @ b[1],\n        a[15:15] @ b[2], a[15:18] @ b[3]\n\n    Parameters\n    ----------\n    a : Tensor\n        The left operand, 2-D tensor of shape ``(N, D1)``\n    b : Tensor\n        The right operand, 3-D tensor of shape ``(R, D1, D2)``\n    seglen_a : Tensor\n        An integer tensor of shape ``(R,)``. Each element is the length of segments\n        of input ``a``. 
The summation of all elements must be equal to ``N``.\n\n    Returns\n    -------\n    Tensor\n        The output dense matrix of shape ``(N, D2)``\n    \"\"\"\n    return F.segment_mm(a, b, seglen_a)\n"
  },
  {
    "path": "python/dgl/ops/spmm.py",
    "content": "\"\"\"Internal module for general spmm operators.\"\"\"\nimport sys\n\nfrom .. import backend as F\nfrom ..backend import (\n    gspmm as gspmm_internal,\n    gspmm_hetero as gspmm_internal_hetero,\n)\n\n__all__ = [\"gspmm\"]\n\n\ndef reshape_lhs_rhs(lhs_data, rhs_data):\n    r\"\"\"Expand dims so that there will be no broadcasting issues with different\n    number of dimensions. For example, given two shapes (N, 3, 1), (E, 5, 3, 4)\n    that are valid broadcastable shapes, change them to (N, 1, 3, 1) and\n    (E, 5, 3, 4)\n\n    Parameters\n    ----------\n    lhs_data : tensor or None\n        The left operand, could be None if it's not required by op.\n    rhs_data : tensor or None\n        The right operand, could be None if it's not required by op.\n    \"\"\"\n    lhs_shape = F.shape(lhs_data)\n    rhs_shape = F.shape(rhs_data)\n    if len(lhs_shape) != len(rhs_shape):\n        max_ndims = max(len(lhs_shape), len(rhs_shape))\n        lhs_pad_ndims = max_ndims - len(lhs_shape)\n        rhs_pad_ndims = max_ndims - len(rhs_shape)\n        new_lhs_shape = (lhs_shape[0],) + (1,) * lhs_pad_ndims + lhs_shape[1:]\n        new_rhs_shape = (rhs_shape[0],) + (1,) * rhs_pad_ndims + rhs_shape[1:]\n        lhs_data = F.reshape(lhs_data, new_lhs_shape)\n        rhs_data = F.reshape(rhs_data, new_rhs_shape)\n    return lhs_data, rhs_data\n\n\ndef gspmm(g, op, reduce_op, lhs_data, rhs_data):\n    r\"\"\"Generalized Sparse Matrix Multiplication interface.\n    It fuses two steps into one kernel.\n\n    1. Computes messages by :attr:`op` source node and edge features.\n    2. Aggregate the messages by :attr:`reduce_op` as the features on destination nodes.\n\n    .. math::\n        x_v = \\psi_{(u, v, e)\\in \\mathcal{G}}(\\rho(x_u, x_e))\n\n    where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,\n    :math:`x_e` refers to :attr:`u`, :attr:`e` respectively. 
:math:`\\rho` means binary\n    operator :attr:`op` and :math:`\\psi` means reduce operator :attr:`reduce_op`,\n    :math:`\\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.\n\n    Note that this function does not handle gradients.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    op : str\n        The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,\n        ``copy_lhs``, ``copy_rhs``.\n    reduce_op : str\n        Reduce operator, could be ``sum``, ``max``, ``min``, ``mean``.\n    lhs_data : tensor or None\n        The left operand, could be None if it's not required by the op.\n    rhs_data : tensor or None\n        The right operand, could be None if it's not required by the op.\n\n    Returns\n    -------\n    tensor\n        The result tensor.\n    \"\"\"\n    if g._graph.number_of_etypes() == 1:\n        if op not in [\"copy_lhs\", \"copy_rhs\"]:\n            lhs_data, rhs_data = reshape_lhs_rhs(lhs_data, rhs_data)\n        # With max and min reducers infinity will be returned for zero degree nodes\n        ret = gspmm_internal(\n            g._graph,\n            op,\n            \"sum\" if reduce_op == \"mean\" else reduce_op,\n            lhs_data,\n            rhs_data,\n        )\n    else:\n        # lhs_data or rhs_data is None only in unary functions like ``copy-u`` or ``copy_e``\n        lhs_data = (\n            [None] * g._graph.number_of_ntypes()\n            if lhs_data is None\n            else lhs_data\n        )\n        rhs_data = (\n            [None] * g._graph.number_of_etypes()\n            if rhs_data is None\n            else rhs_data\n        )\n        # TODO (Israt): Call reshape func\n        lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data))\n        ret = gspmm_internal_hetero(\n            g._graph,\n            op,\n            \"sum\" if reduce_op == \"mean\" else reduce_op,\n            len(lhs_data),\n            *lhs_and_rhs_tuple\n        )\n    # TODO (Israt): 
Add support for 'mean' in heterograph\n    # divide in degrees for mean reducer.\n    if reduce_op == \"mean\":\n        ret_shape = F.shape(ret)\n        deg = g.in_degrees()\n        deg = F.astype(F.clamp(deg, 1, max(g.num_edges(), 1)), F.dtype(ret))\n        deg_shape = (ret_shape[0],) + (1,) * (len(ret_shape) - 1)\n        return ret / F.reshape(deg, deg_shape)\n    else:\n        return ret\n\n\ndef _attach_zerodeg_note(docstring, reducer):\n    note1 = \"\"\"\n    The {} function will return zero for nodes with no incoming messages.\"\"\".format(\n        reducer\n    )\n    note2 = \"\"\"\n    This is implemented by replacing all {} values to zero.\n    \"\"\".format(\n        \"infinity\" if reducer == \"min\" else \"negative infinity\"\n    )\n\n    docstring = docstring + note1\n    if reducer in (\"min\", \"max\"):\n        docstring = docstring + note2\n    return docstring\n\n\ndef _gen_spmm_func(binary_op, reduce_op):\n    name = \"u_{}_e_{}\".format(binary_op, reduce_op)\n    docstring = \"\"\"Generalized SpMM function.\n    It fuses two steps into one kernel.\n\n    1. Computes messages by {} source node and edge features.\n    2. Aggregate the messages by {} as the features on destination nodes.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph\n    x : tensor\n        The source node features.\n    y : tensor\n        The edge features.\n\n    Returns\n    -------\n    tensor\n        The result tensor.\n\n    Notes\n    -----\n    This function supports autograd (computing input gradients given the output gradient). If the\n    feature shape of two input operands do not match, we first broadcasts the features to a unified\n    shape (note that the memory usage will not increase accordingly) and then performs the operation.\n\n    Broadcasting follows NumPy semantics. 
Please see\n    https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html\n    for more details about the NumPy broadcasting semantics.\n    \"\"\".format(\n        binary_op, reduce_op\n    )\n    docstring = _attach_zerodeg_note(docstring, reduce_op)\n\n    def func(g, x, y):\n        return gspmm(g, binary_op, reduce_op, x, y)\n\n    func.__name__ = name\n    func.__doc__ = docstring\n    return func\n\n\ndef _gen_copy_reduce_func(binary_op, reduce_op):\n\n    name = \"{}_{}\".format(binary_op, reduce_op)\n    binary_str = {\n        \"copy_u\": \"It copies node feature to edge as the message.\",\n        \"copy_e\": \"It regards edge feature as message.\",\n    }\n    x_str = {\"copy_u\": \"source node\", \"copy_e\": \"edge\"}\n    docstring = lambda binary_op: _attach_zerodeg_note(\n        \"\"\"Generalized SpMM function. {}\n    Then aggregates the message by {} on destination nodes.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph\n    x : tensor\n        The {} features.\n\n    Returns\n    -------\n    tensor\n        The result tensor.\n\n    Notes\n    -----\n    This function supports autograd (computing input gradients given the output gradient).\n    \"\"\".format(\n            binary_str[binary_op], reduce_op, x_str[binary_op]\n        ),\n        reduce_op,\n    )\n\n    def func(g, x):\n        if binary_op == \"copy_u\":\n            return gspmm(g, \"copy_lhs\", reduce_op, x, None)\n        else:\n            return gspmm(g, \"copy_rhs\", reduce_op, None, x)\n\n    func.__name__ = name\n    func.__doc__ = docstring(binary_op)\n    return func\n\n\ndef _register_spmm_func():\n    \"\"\"Register spmm functions\n\n    - Binary operation plus reduction between u and e: u_[]_e_[]\n    - Copy u plus reduction: copy_u_[]\n    - Copy e plus reduction: copy_e_[]\n    \"\"\"\n    for binary_op in [\"add\", \"sub\", \"mul\", \"div\", \"copy_u\", \"copy_e\"]:\n        for reduce_op in [\"sum\", \"max\", \"min\", \"mean\"]:\n 
           if binary_op.startswith(\"copy\"):\n                func = _gen_copy_reduce_func(binary_op, reduce_op)\n            else:\n                func = _gen_spmm_func(binary_op, reduce_op)\n            setattr(sys.modules[__name__], func.__name__, func)\n            __all__.append(func.__name__)\n\n\n_register_spmm_func()\n"
  },
  {
    "path": "python/dgl/optim/__init__.py",
    "content": "\"\"\"dgl optims.\"\"\"\nimport importlib\nimport os\nimport sys\n\nfrom ..backend import backend_name\nfrom ..utils import expand_as_pair\n\n\ndef _load_backend(mod_name):\n    mod = importlib.import_module(\".%s\" % mod_name, __name__)\n    thismod = sys.modules[__name__]\n    for api, obj in mod.__dict__.items():\n        setattr(thismod, api, obj)\n\n\n_load_backend(backend_name)\n"
  },
  {
    "path": "python/dgl/optim/mxnet/__init__.py",
    "content": ""
  },
  {
    "path": "python/dgl/optim/pytorch/__init__.py",
    "content": "\"\"\"dgl sparse optimizer for pytorch.\"\"\"\nfrom .sparse_optim import SparseAdagrad, SparseAdam\n"
  },
  {
    "path": "python/dgl/optim/pytorch/sparse_optim.py",
    "content": "\"\"\"Node embedding optimizers\"\"\"\nimport abc\nfrom abc import abstractmethod\n\nimport torch as th\n\nfrom ...cuda import nccl\nfrom ...nn.pytorch import NodeEmbedding\nfrom ...partition import NDArrayPartition\nfrom ...utils import (\n    create_shared_mem_array,\n    gather_pinned_tensor_rows,\n    get_shared_mem_array,\n    pin_memory_inplace,\n    scatter_pinned_tensor_rows,\n)\n\n\nclass SparseGradOptimizer(abc.ABC):\n    r\"\"\"The abstract sparse optimizer.\n\n    Note: dgl sparse optimizer only work with dgl.NodeEmbedding\n\n    Parameters\n    ----------\n    params : list of NodeEmbedding\n        The list of NodeEmbeddings.\n    lr : float\n        The learning rate.\n    \"\"\"\n\n    def __init__(self, params, lr):\n        self._params = params\n        self._lr = lr\n        self._rank = None\n        self._world_size = None\n        self._shared_cache = {}\n        self._clean_grad = False\n        self._opt_meta = {}\n        self._comm = None\n        self._first_step = True\n        self._device = None\n        # hold released shared memory to let other process to munmap it first\n        # otherwise it will crash the training\n        self.shmem_buffer_holder = []\n\n        assert len(params) > 0, \"Empty parameters\"\n        # if we are using shared memory for communication\n        for emb in params:\n            assert isinstance(\n                emb, NodeEmbedding\n            ), \"DGL SparseOptimizer only supports dgl.nn.NodeEmbedding\"\n\n            if self._rank is None:\n                self._rank = emb.rank\n                self._world_size = emb.world_size\n            else:\n                assert (\n                    self._rank == emb.rank\n                ), \"MultiGPU rank for each embedding should be same.\"\n                assert (\n                    self._world_size == emb.world_size\n                ), \"MultiGPU world_size for each embedding should be same.\"\n        assert not self._rank is 
None\n        assert not self._world_size is None\n\n    def step(self):\n        \"\"\"The step function.\n\n        The step function is invoked at the end of every batch to update embeddings\n        \"\"\"\n        # on the first step, check to see if the grads are on the GPU\n        if self._first_step:\n            for emb in self._params:\n                for _, data in emb._trace:\n                    if data.grad.device.type == \"cuda\":\n                        # create a communicator\n                        if self._device:\n                            assert (\n                                self._device == data.grad.device\n                            ), \"All gradients must be on the same device\"\n                        else:\n                            self._device = data.grad.device\n                    else:\n                        assert (\n                            not self._device\n                        ), \"All gradients must be on the same device\"\n\n            # distributed backend use nccl\n            if self._device and (\n                not th.distributed.is_initialized()\n                or th.distributed.get_backend() == \"nccl\"\n            ):\n                # device is only set if the grads are on a GPU\n                self._comm_setup()\n            else:\n                self._shared_setup()\n            self._first_step = False\n\n        if self._comm:\n            self._comm_step()\n        else:\n            self._shared_step()\n\n    @abstractmethod\n    def setup(self, params):\n        \"\"\"This is function where subclasses can perform any setup they need\n        to. 
It will be called during the first step, and communicators or\n        shared memory will have been setup before this call.\n\n        Parameters\n        ----------\n        params : list of NodeEmbedding\n            The list of NodeEmbeddings.\n        \"\"\"\n\n    def _comm_setup(self):\n        self._comm = True\n\n    def _shared_setup(self):\n        for emb in self._params:\n            emb_name = emb.name\n            if self._rank == 0:  # the master gpu process\n                opt_meta = create_shared_mem_array(\n                    emb_name + \"_opt_meta\",\n                    (self._world_size, self._world_size),\n                    th.int32,\n                ).zero_()\n\n            if self._rank == 0:\n                emb.store.set(emb_name + \"_opt_meta\", emb_name)\n                self._opt_meta[emb_name] = opt_meta\n            elif self._rank > 0:\n                # receive\n                emb.store.wait([emb_name + \"_opt_meta\"])\n                opt_meta = get_shared_mem_array(\n                    emb_name + \"_opt_meta\",\n                    (self._world_size, self._world_size),\n                    th.int32,\n                )\n                self._opt_meta[emb_name] = opt_meta\n\n    def _comm_step(self):\n        with th.no_grad():\n            idx_in = {}\n            grad_in = {}\n            for emb in self._params:  # pylint: disable=too-many-nested-blocks\n                emb_name = emb.name\n                partition = emb.partition\n\n                if not partition:\n                    # use default partitioning\n                    partition = NDArrayPartition(\n                        emb.num_embeddings,\n                        self._world_size if self._world_size > 0 else 1,\n                        mode=\"remainder\",\n                    )\n\n                # we need to combine gradients from multiple forward paths\n                if len(emb._trace) == 0:\n                    idx = th.zeros((0,), dtype=th.long, 
device=self._device)\n                    grad = th.zeros(\n                        (0, emb.embedding_dim),\n                        dtype=th.float32,\n                        device=self._device,\n                    )\n                elif len(emb._trace) == 1:\n                    # the special case where we can use the tensors as is\n                    # without any memcpy's\n                    idx, grad = emb._trace[0]\n                    grad = grad.grad.data\n                else:\n                    idx = []\n                    grad = []\n                    for i, data in emb._trace:\n                        idx.append(i)\n                        grad.append(data.grad.data)\n                    idx = th.cat(idx, dim=0)\n                    grad = th.cat(grad, dim=0)\n\n                (\n                    idx_in[emb_name],\n                    grad_in[emb_name],\n                ) = nccl.sparse_all_to_all_push(idx, grad, partition=partition)\n                if emb.partition:\n                    # if the embedding is partitioned, map back to indexes\n                    # into the local tensor\n                    idx_in[emb_name] = partition.map_to_local(idx_in[emb_name])\n\n            if self._clean_grad:\n                # clean gradient track\n                for emb in self._params:\n                    emb.reset_trace()\n                self._clean_grad = False\n\n            for emb in self._params:\n                emb_name = emb.name\n                idx = idx_in[emb_name]\n                grad = grad_in[emb_name]\n                self.update(idx, grad, emb)\n\n    def _shared_step(self):\n        with th.no_grad():\n            # Frequently alloc and free shared memory to hold intermediate tensor is expensive\n            # We cache shared memory buffers in shared_emb.\n            shared_emb = {emb.name: ([], []) for emb in self._params}\n\n            # Go through all sparse embeddings\n            for emb in self._params:  # pylint: 
disable=too-many-nested-blocks\n                emb_name = emb.name\n\n                # we need to combine gradients from multiple forward paths\n                idx = []\n                grad = []\n                for i, data in emb._trace:\n                    idx.append(i)\n                    grad.append(data.grad.data)\n                # If the sparse embedding is not used in the previous forward step\n                # The idx and grad will be empty, initialize them as empty tensors to\n                # avoid crashing the optimizer step logic.\n                #\n                # Note: we cannot skip the gradient exchange and update steps as other\n                # working processes may send gradient update requests corresponding\n                # to certain embedding to this process.\n                idx = (\n                    th.cat(idx, dim=0)\n                    if len(idx) != 0\n                    else th.zeros((0,), dtype=th.long, device=th.device(\"cpu\"))\n                )\n                grad = (\n                    th.cat(grad, dim=0)\n                    if len(grad) != 0\n                    else th.zeros(\n                        (0, emb.embedding_dim),\n                        dtype=th.float32,\n                        device=th.device(\"cpu\"),\n                    )\n                )\n\n                device = grad.device\n                idx_dtype = idx.dtype\n                grad_dtype = grad.dtype\n                grad_dim = grad.shape[1]\n                if self._world_size > 1:\n                    if emb_name not in self._shared_cache:\n                        self._shared_cache[emb_name] = {}\n\n                    # Each training process takes the resposibility of updating a range\n                    # of node embeddings, thus we can parallel the gradient update.\n                    # The overall progress includes:\n                    #   1. 
In each training process:\n                    #     1.a Deciding which process a node embedding belongs to according\n                    #         to the formula: process_id = node_idx mod num_of_process(N)\n                    #     1.b Split the node index tensor and gradient tensor into N parts\n                    #         according to step 1.\n                    #     1.c Write each node index sub-tensor and gradient sub-tensor into\n                    #         different DGL shared memory buffers.\n                    #   2. Cross training process synchronization\n                    #   3. In each traning process:\n                    #     3.a Collect node index sub-tensors and gradient sub-tensors\n                    #     3.b Do gradient update\n                    #   4. Done\n                    idx_split = th.remainder(idx, self._world_size).long()\n                    for i in range(self._world_size):\n                        mask = idx_split == i\n                        idx_i = idx[mask]\n                        grad_i = grad[mask]\n\n                        if i == self._rank:\n                            shared_emb[emb_name][0].append(idx_i)\n                            shared_emb[emb_name][1].append(grad_i)\n                        else:\n                            # currently nccl does not support Alltoallv operation\n                            # we need to use CPU shared memory to share gradient\n                            # across processes\n                            idx_i = idx_i.to(th.device(\"cpu\"))\n                            grad_i = grad_i.to(th.device(\"cpu\"))\n                            idx_shmem_name = \"idx_{}_{}_{}\".format(\n                                emb_name, self._rank, i\n                            )\n                            grad_shmem_name = \"grad_{}_{}_{}\".format(\n                                emb_name, self._rank, i\n                            )\n\n                            # Create shared 
memory to hold temporary index and gradient tensor for\n                            # cross-process send and recv.\n                            if (\n                                idx_shmem_name\n                                not in self._shared_cache[emb_name]\n                                or self._shared_cache[emb_name][\n                                    idx_shmem_name\n                                ].shape[0]\n                                < idx_i.shape[0]\n                            ):\n\n                                if (\n                                    idx_shmem_name\n                                    in self._shared_cache[emb_name]\n                                ):\n                                    self.shmem_buffer_holder.append(\n                                        self._shared_cache[emb_name][\n                                            idx_shmem_name\n                                        ]\n                                    )\n                                    self.shmem_buffer_holder.append(\n                                        self._shared_cache[emb_name][\n                                            grad_shmem_name\n                                        ]\n                                    )\n\n                                # The total number of buffers is the number of NodeEmbeddings *\n                                # world_size * (world_size - 1). 
The minimun buffer size is 128.\n                                #\n                                # We extend the buffer by idx_i.shape[0] * 2 to avoid\n                                # frequent shared memory allocation.\n                                # The overall buffer cost will be smaller than three times\n                                # the maximum memory requirement for sharing gradients.\n                                buffer_size = (\n                                    128\n                                    if idx_i.shape[0] < 128\n                                    else idx_i.shape[0] * 2\n                                )\n                                idx_shmem = create_shared_mem_array(\n                                    \"{}_{}\".format(idx_shmem_name, buffer_size),\n                                    (buffer_size,),\n                                    idx_dtype,\n                                )\n                                grad_shmem = create_shared_mem_array(\n                                    \"{}_{}\".format(\n                                        grad_shmem_name, buffer_size\n                                    ),\n                                    (buffer_size, grad_dim),\n                                    grad_dtype,\n                                )\n                                self._shared_cache[emb_name][\n                                    idx_shmem_name\n                                ] = idx_shmem\n                                self._shared_cache[emb_name][\n                                    grad_shmem_name\n                                ] = grad_shmem\n\n                            # Fill shared memory with temporal index tensor and gradient tensor\n                            self._shared_cache[emb_name][idx_shmem_name][\n                                : idx_i.shape[0]\n                            ] = idx_i\n                            self._shared_cache[emb_name][grad_shmem_name][\n          
                      : idx_i.shape[0]\n                            ] = grad_i\n                            self._opt_meta[emb_name][self._rank][\n                                i\n                            ] = idx_i.shape[0]\n                else:\n                    shared_emb[emb_name][0].append(idx)\n                    shared_emb[emb_name][1].append(grad)\n\n            # make sure the idx shape is passed to each process through opt_meta\n            if self._world_size > 1:\n                th.distributed.barrier()\n            for emb in self._params:  # pylint: disable=too-many-nested-blocks\n                emb_name = emb.name\n                if self._world_size > 1:\n                    # The first element in shared_emb[emb_name][0] is the local idx\n                    device = shared_emb[emb_name][0][0].device\n                    # gather gradients from all other processes\n                    for i in range(self._world_size):\n                        if i != self._rank:\n                            idx_shmem_name = \"idx_{}_{}_{}\".format(\n                                emb_name, i, self._rank\n                            )\n                            grad_shmem_name = \"grad_{}_{}_{}\".format(\n                                emb_name, i, self._rank\n                            )\n                            size = self._opt_meta[emb_name][i][self._rank]\n\n                            # Retrive shared memory holding the temporal index and gradient\n                            # tensor that is sent to current training process\n                            if (\n                                idx_shmem_name\n                                not in self._shared_cache[emb_name]\n                                or self._shared_cache[emb_name][\n                                    idx_shmem_name\n                                ].shape[0]\n                                < size\n                            ):\n                                
buffer_size = 128 if size < 128 else size * 2\n                                idx_shmem = get_shared_mem_array(\n                                    \"{}_{}\".format(idx_shmem_name, buffer_size),\n                                    (buffer_size,),\n                                    idx_dtype,\n                                )\n                                grad_shmem = get_shared_mem_array(\n                                    \"{}_{}\".format(\n                                        grad_shmem_name, buffer_size\n                                    ),\n                                    (buffer_size, grad_dim),\n                                    grad_dtype,\n                                )\n                                self._shared_cache[emb_name][\n                                    idx_shmem_name\n                                ] = idx_shmem\n                                self._shared_cache[emb_name][\n                                    grad_shmem_name\n                                ] = grad_shmem\n\n                            idx_i = self._shared_cache[emb_name][\n                                idx_shmem_name\n                            ][:size]\n                            grad_i = self._shared_cache[emb_name][\n                                grad_shmem_name\n                            ][:size]\n                            shared_emb[emb_name][0].append(\n                                idx_i.to(device, non_blocking=True)\n                            )\n                            shared_emb[emb_name][1].append(\n                                grad_i.to(device, non_blocking=True)\n                            )\n\n            if self._clean_grad:\n                # clean gradient track\n                for emb in self._params:\n                    emb.reset_trace()\n                self._clean_grad = False\n\n            for emb in self._params:\n                emb_name = emb.name\n\n                idx = 
th.cat(shared_emb[emb_name][0], dim=0)\n                grad = th.cat(shared_emb[emb_name][1], dim=0)\n                self.update(idx, grad, emb)\n\n            # synchronized gradient update\n            if self._world_size > 1:\n                th.distributed.barrier()\n\n    @abstractmethod\n    def update(self, idx, grad, emb):\n        \"\"\"Update embeddings in a sparse manner\n        Sparse embeddings are updated in mini batches. We maintain gradient states for\n        each embedding so they can be updated separately.\n\n        Parameters\n        ----------\n        idx : tensor\n            Index of the embeddings to be updated.\n        grad : tensor\n            Gradient of each embedding.\n        emb : dgl.nn.NodeEmbedding\n            Sparse node embedding to update.\n        \"\"\"\n\n    def zero_grad(self):\n        \"\"\"clean grad cache\"\"\"\n        self._clean_grad = True\n\n    def state_dict(self, **kwargs):  # pylint: disable=unused-argument\n        \"\"\"Return a copy of the whole optimizer states stored in CPU memory.\n        If this is a multi-processing instance, the states will be returned in\n        shared memory. If the underlying embedding is currently stored on\n        multiple GPUs, all processes must call this method in the same order.\n\n        NOTE: This method must be called by all processes sharing the\n        underlying embedding, or it may result in a deadlock.\n\n        Returns\n        -------\n        dictionary of optimizer states\n            The optimizer states stored in CPU memory.\n        \"\"\"\n        return {\n            \"state\": {\n                emb.name: emb._all_get_optm_state() for emb in self._params\n            },\n            \"param_groups\": self.param_groups,\n        }\n\n    def load_state_dict(\n        self, state_dict, **kwargs\n    ):  # pylint: disable=unused-argument\n        \"\"\"Load the optimizer states. 
This method must be called by all\n        processes sharing the underlying embedding with identical\n        :attr:`state_dict`.\n\n        NOTE: This method must be called by all processes sharing the\n        underlying embedding, or it may result in a deadlock.\n\n        Parameters\n        ----------\n        state_dict : dictionary of optimizer states\n            The global states to pull values from.\n        \"\"\"\n        for emb in self._params:\n            emb._all_set_optm_state(state_dict[\"state\"][emb.name])\n        self._set_param_groups(state_dict[\"param_groups\"])\n\n    @property\n    @abstractmethod\n    def param_groups(self):\n        \"\"\"Emulate 'param_groups' of torch.optim.Optimizer.\n        Different from that, the returned 'param_groups' doesn't contain\n        parameters because getting the whole embedding is very expensive.\n        It contains other attributes, e.g., lr, eps, for debugging.\n        \"\"\"\n\n    @abstractmethod\n    def _set_param_groups(self, groups):\n        \"\"\"A helper method to load param_groups from saved state_dict.\"\"\"\n\n\nclass SparseAdagrad(SparseGradOptimizer):\n    r\"\"\"Node embedding optimizer using the Adagrad algorithm.\n\n    This optimizer implements a sparse version of Adagrad algorithm for\n    optimizing :class:`dgl.nn.NodeEmbedding`. 
Being sparse means it only updates\n    the embeddings whose gradients have updates, which are usually a very\n    small portion of the total embeddings.\n\n    Adagrad maintains a :math:`G_{t,i,j}` for every parameter in the embeddings, where\n    :math:`G_{t,i,j}=G_{t-1,i,j} + g_{t,i,j}^2` and :math:`g_{t,i,j}` is the gradient of\n    the dimension :math:`j` of embedding :math:`i` at step :math:`t`.\n\n    NOTE: The support of sparse Adagrad optimizer is experimental.\n\n    Parameters\n    ----------\n    params : list[dgl.nn.NodeEmbedding]\n        The list of dgl.nn.NodeEmbedding.\n    lr : float\n        The learning rate.\n    eps : float, Optional\n        The term added to the denominator to improve numerical stability\n        Default: 1e-10\n\n    Examples\n    --------\n    >>> def initializer(emb):\n            th.nn.init.xavier_uniform_(emb)\n            return emb\n    >>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)\n    >>> optimizer = dgl.optim.SparseAdagrad([emb], lr=0.001)\n    >>> for blocks in dataloader:\n    ...     ...\n    ...     feats = emb(nids, gpu_0)\n    ...     loss = F.sum(feats + 1, 0)\n    ...     loss.backward()\n    ...     
optimizer.step()\n    \"\"\"\n\n    def __init__(self, params, lr, eps=1e-10):\n        super(SparseAdagrad, self).__init__(params, lr)\n        self._eps = eps\n\n        # setup tensors for optimizer states\n        self.setup(self._params)\n\n    def setup(self, params):\n        # We need to register a state sum for each embedding in the kvstore.\n        for emb in params:\n            assert isinstance(\n                emb, NodeEmbedding\n            ), \"SparseAdagrad only supports dgl.nn.NodeEmbedding\"\n\n            emb_name = emb.name\n            if th.device(emb.weight.device) == th.device(\"cpu\"):\n                # if our embedding is on the CPU, our state also has to be\n                if self._rank < 0:\n                    state = th.empty(\n                        emb.weight.shape,\n                        dtype=th.float32,\n                        device=th.device(\"cpu\"),\n                    ).zero_()\n                elif self._rank == 0:\n                    state = create_shared_mem_array(\n                        emb_name + \"_state\", emb.weight.shape, th.float32\n                    ).zero_()\n\n                    if self._world_size > 1:\n                        emb.store.set(emb_name + \"_opt\", emb_name)\n                elif self._rank > 0:\n                    # receive\n                    emb.store.wait([emb_name + \"_opt\"])\n                    state = get_shared_mem_array(\n                        emb_name + \"_state\", emb.weight.shape, th.float32\n                    )\n            else:\n                # distributed state on on gpu\n                state = th.empty(\n                    emb.weight.shape,\n                    dtype=th.float32,\n                    device=emb.weight.device,\n                ).zero_()\n            emb.set_optm_state((state,))\n\n    def update(self, idx, grad, emb):\n        \"\"\"Update embeddings in a sparse manner\n        Sparse embeddings are updated in mini batches. 
We maintain gradient states for\n        each embedding so they can be updated separately.\n\n        Parameters\n        ----------\n        idx : tensor\n            Index of the embeddings to be updated.\n        grad : tensor\n            Gradient of each embedding.\n        emb : dgl.nn.NodeEmbedding\n            Sparse embedding to update.\n        \"\"\"\n        eps = self._eps\n        clr = self._lr\n\n        # the update is non-linear so indices must be unique\n        grad_indices, inverse, cnt = th.unique(\n            idx, return_inverse=True, return_counts=True\n        )\n        grad_values = th.zeros(\n            (grad_indices.shape[0], grad.shape[1]), device=grad.device\n        )\n        grad_values.index_add_(0, inverse, grad)\n        grad_values = grad_values / cnt.unsqueeze(1)\n\n        grad_sum = grad_values * grad_values\n        (state,) = emb.optm_state\n        state_dev = state.device\n        state_idx = grad_indices.to(state_dev)\n        grad_state = state[state_idx].to(grad.device)\n        grad_state += grad_sum\n        state[state_idx] = grad_state.to(state_dev)\n\n        std_values = grad_state.add_(eps).sqrt_()\n        tmp = clr * grad_values / std_values\n        emb.weight[state_idx] -= tmp.to(state_dev)\n\n    @property\n    def param_groups(self):\n        \"\"\"Emulate 'param_groups' of torch.optim.Optimizer.\n        Different from that, the returned 'param_groups' doesn't contain\n        parameters because getting the whole embedding is very expensive.\n        It contains other attributes, e.g., lr, eps, for debugging.\n        \"\"\"\n        return [{\"lr\": self._lr, \"eps\": self._eps}]\n\n    def _set_param_groups(self, groups):\n        \"\"\"A helper method to load param_groups from saved state_dict.\"\"\"\n        self._lr = groups[0][\"lr\"]\n        self._eps = groups[0][\"eps\"]\n\n\nclass SparseAdam(SparseGradOptimizer):\n    r\"\"\"Node embedding optimizer using the Adam algorithm.\n\n    This 
optimizer implements a sparse version of Adam algorithm for\n    optimizing :class:`dgl.nn.NodeEmbedding`. Being sparse means it only\n    updates the embeddings whose gradients have updates, which are usually\n    a very small portion of the total embeddings.\n\n    Adam maintains a :math:`Gm_{t,i,j}` and :math:`Gp_{t,i,j}` for every parameter\n    in the embeddings, where\n    :math:`Gm_{t,i,j}=beta1 * Gm_{t-1,i,j} + (1-beta1) * g_{t,i,j}`,\n    :math:`Gp_{t,i,j}=beta2 * Gp_{t-1,i,j} + (1-beta2) * g_{t,i,j}^2`,\n    :math:`g_{t,i,j} = lr * Gm_{t,i,j} / (1 - beta1^t) / \\sqrt{Gp_{t,i,j} / (1 - beta2^t)}` and\n    :math:`g_{t,i,j}` is the gradient of the dimension :math:`j` of embedding :math:`i`\n    at step :math:`t`.\n\n    NOTE: The support of sparse Adam optimizer is experimental.\n\n    Parameters\n    ----------\n    params : list[dgl.nn.NodeEmbedding]\n        The list of dgl.nn.NodeEmbeddings.\n    lr : float\n        The learning rate.\n    betas : tuple[float, float], Optional\n        Coefficients used for computing running averages of gradient and its square.\n        Default: (0.9, 0.999)\n    eps : float, Optional\n        The term added to the denominator to improve numerical stability\n        Default: 1e-8\n    use_uva : bool, Optional\n        Whether to use pinned memory for storing 'mem' and 'power' parameters,\n        when the embedding is stored on the CPU. This will improve training\n        speed, but will require locking a large number of virtual memory pages.\n        For embeddings which are stored in GPU memory, this setting will have\n        no effect.\n        Default: True if the gradients are generated on the GPU, and False\n        if the gradients are on the CPU.\n    dtype : torch.dtype, Optional\n        The type to store optimizer state with. 
Default: th.float32.\n\n    Examples\n    --------\n    >>> def initializer(emb):\n            th.nn.init.xavier_uniform_(emb)\n            return emb\n    >>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)\n    >>> optimizer = dgl.optim.SparseAdam([emb], lr=0.001)\n    >>> for blocks in dataloader:\n    ...     ...\n    ...     feats = emb(nids, gpu_0)\n    ...     loss = F.sum(feats + 1, 0)\n    ...     loss.backward()\n    ...     optimizer.step()\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr,\n        betas=(0.9, 0.999),\n        eps=1e-08,\n        use_uva=None,\n        dtype=th.float32,\n    ):\n        super(SparseAdam, self).__init__(params, lr)\n        self._lr = lr\n        self._beta1 = betas[0]\n        self._beta2 = betas[1]\n        self._eps = eps\n        self._use_uva = use_uva\n        self._nd_handle = {}\n        self._is_using_uva = {}\n        assert dtype in [th.float16, th.float32], (\n            \"Unsupported dtype {}. 
Valid choices are th.float32 \"\n            \"and th.float32\".format(dtype)\n        )\n        self._dtype = dtype\n\n        # setup tensors for optimizer states\n        self.setup(self._params)\n\n    def _setup_uva(self, name, mem, power):\n        self._is_using_uva[name] = True\n        mem_nd = pin_memory_inplace(mem)\n        power_nd = pin_memory_inplace(power)\n        self._nd_handle[name] = [mem_nd, power_nd]\n\n    def setup(self, params):\n        # We need to register a state sum for each embedding in the kvstore.\n        for emb in params:\n            assert isinstance(\n                emb, NodeEmbedding\n            ), \"SparseAdam only supports dgl.nn.NodeEmbedding\"\n            emb_name = emb.name\n            self._is_using_uva[emb_name] = self._use_uva\n            if th.device(emb.weight.device) == th.device(\"cpu\"):\n                # if our embedding is on the CPU, our state also has to be\n                if self._rank < 0:\n                    state_step = th.empty(\n                        (emb.weight.shape[0],),\n                        dtype=th.int32,\n                        device=th.device(\"cpu\"),\n                    ).zero_()\n                    state_mem = th.empty(\n                        emb.weight.shape,\n                        dtype=self._dtype,\n                        device=th.device(\"cpu\"),\n                    ).zero_()\n                    state_power = th.empty(\n                        emb.weight.shape,\n                        dtype=self._dtype,\n                        device=th.device(\"cpu\"),\n                    ).zero_()\n                elif self._rank == 0:\n                    state_step = create_shared_mem_array(\n                        emb_name + \"_step\", (emb.weight.shape[0],), th.int32\n                    ).zero_()\n                    state_mem = create_shared_mem_array(\n                        emb_name + \"_mem\", emb.weight.shape, self._dtype\n                    ).zero_()\n         
           state_power = create_shared_mem_array(\n                        emb_name + \"_power\", emb.weight.shape, self._dtype\n                    ).zero_()\n\n                    if self._world_size > 1:\n                        emb.store.set(emb_name + \"_opt\", emb_name)\n                elif self._rank > 0:\n                    # receive\n                    emb.store.wait([emb_name + \"_opt\"])\n                    state_step = get_shared_mem_array(\n                        emb_name + \"_step\", (emb.weight.shape[0],), th.int32\n                    )\n                    state_mem = get_shared_mem_array(\n                        emb_name + \"_mem\", emb.weight.shape, self._dtype\n                    )\n                    state_power = get_shared_mem_array(\n                        emb_name + \"_power\", emb.weight.shape, self._dtype\n                    )\n\n                if self._is_using_uva[emb_name]:\n                    # if use_uva has been explicitly set to true, otherwise\n                    # wait until first step to decide\n                    self._setup_uva(emb_name, state_mem, state_power)\n            else:\n                # make sure we don't use UVA when data is on the GPU\n                self._is_using_uva[emb_name] = False\n\n                # distributed state on on gpu\n                state_step = th.empty(\n                    [emb.weight.shape[0]],\n                    dtype=th.int32,\n                    device=emb.weight.device,\n                ).zero_()\n                state_mem = th.empty(\n                    emb.weight.shape,\n                    dtype=self._dtype,\n                    device=emb.weight.device,\n                ).zero_()\n                state_power = th.empty(\n                    emb.weight.shape,\n                    dtype=self._dtype,\n                    device=emb.weight.device,\n                ).zero_()\n            state = (state_step, state_mem, state_power)\n            
emb.set_optm_state(state)\n\n    def update(self, idx, grad, emb):\n        \"\"\"Update embeddings in a sparse manner\n        Sparse embeddings are updated in mini batches. We maintain gradient states for\n        each embedding so they can be updated separately.\n\n        Parameters\n        ----------\n        idx : tensor\n            Index of the embeddings to be updated.\n        grad : tensor\n            Gradient of each embedding.\n        emb : dgl.nn.NodeEmbedding\n            Sparse embedding to update.\n        \"\"\"\n        with th.no_grad():\n            state_step, state_mem, state_power = emb.optm_state\n            exec_dtype = grad.dtype\n            exec_dev = grad.device\n            state_dev = state_step.device\n\n            # whether or not we need to transfer data from the GPU to the CPU\n            # while updating the weights\n            is_d2h = state_dev.type == \"cpu\" and exec_dev.type == \"cuda\"\n\n            # only perform async copies cpu -> gpu, or gpu-> gpu, but block\n            # when copying to the cpu, so as to ensure the copy is finished\n            # before operating on the data on the cpu\n            state_block = is_d2h\n\n            if self._is_using_uva[emb.name] is None and is_d2h:\n                # we should use UVA going forward\n                self._setup_uva(emb.name, state_mem, state_power)\n            elif self._is_using_uva[emb.name] is None:\n                # we shouldn't use UVA going forward\n                self._is_using_uva[emb.name] = False\n\n            use_uva = self._is_using_uva[emb.name]\n\n            beta1 = self._beta1\n            beta2 = self._beta2\n            eps = self._eps\n\n            clr = self._lr\n            # There can be duplicated indices due to sampling.\n            # Thus unique them here and average the gradient here.\n            grad_indices, inverse, cnt = th.unique(\n                idx, return_inverse=True, return_counts=True\n            )\n            
state_idx = grad_indices.to(state_dev)\n            state_step[state_idx] += 1\n            state_step = state_step[state_idx].to(exec_dev)\n\n            if use_uva:\n                orig_mem = gather_pinned_tensor_rows(state_mem, grad_indices)\n                orig_power = gather_pinned_tensor_rows(\n                    state_power, grad_indices\n                )\n            else:\n                orig_mem = state_mem[state_idx].to(exec_dev)\n                orig_power = state_power[state_idx].to(exec_dev)\n            # convert to exec dtype\n            orig_mem = orig_mem.to(dtype=exec_dtype)\n            orig_power = orig_power.to(dtype=exec_dtype)\n\n            grad_values = th.zeros(\n                (grad_indices.shape[0], grad.shape[1]), device=exec_dev\n            )\n            grad_values.index_add_(0, inverse, grad)\n            grad_values = grad_values / cnt.unsqueeze(1)\n\n            grad_mem = grad_values\n            grad_power = grad_values * grad_values\n\n            update_mem = beta1 * orig_mem + (1.0 - beta1) * grad_mem\n            update_power = beta2 * orig_power + (1.0 - beta2) * grad_power\n\n            if use_uva:\n                scatter_pinned_tensor_rows(\n                    state_mem, grad_indices, update_mem.to(dtype=self._dtype)\n                )\n                scatter_pinned_tensor_rows(\n                    state_power,\n                    grad_indices,\n                    update_power.to(dtype=self._dtype),\n                )\n            else:\n                update_mem_dst = update_mem.to(dtype=self._dtype).to(\n                    state_dev, non_blocking=True\n                )\n                update_power_dst = update_power.to(dtype=self._dtype).to(\n                    state_dev, non_blocking=True\n                )\n                if state_block:\n                    # use events to try and overlap CPU and GPU as much as possible\n                    update_event = th.cuda.Event()\n                    
update_event.record()\n\n            update_mem_corr = update_mem / (\n                1.0 - th.pow(th.tensor(beta1, device=exec_dev), state_step)\n            ).unsqueeze(1)\n            update_power_corr = update_power / (\n                1.0 - th.pow(th.tensor(beta2, device=exec_dev), state_step)\n            ).unsqueeze(1)\n            std_values = (\n                clr * update_mem_corr / (th.sqrt(update_power_corr) + eps)\n            )\n            std_values_dst = std_values.to(state_dev, non_blocking=True)\n\n            if state_block:\n                std_event = th.cuda.Event()\n                std_event.record()\n\n            if not use_uva:\n                if state_block:\n                    # wait for our transfers from exec_dev to state_dev to finish\n                    # before we can use them\n                    update_event.wait()\n                state_mem[state_idx] = update_mem_dst\n                state_power[state_idx] = update_power_dst\n\n            if state_block:\n                # wait for the transfer of std_values to finish before we\n                # can use it\n                std_event.wait()\n            emb.weight[state_idx] -= std_values_dst\n\n    @property\n    def param_groups(self):\n        \"\"\"Emulate 'param_groups' of torch.optim.Optimizer.\n        Different from that, the returned 'param_groups' doesn't contain\n        parameters because getting the whole embedding is very expensive.\n        It contains other attributes, e.g., lr, betas, eps, for debugging.\n        \"\"\"\n        return [\n            {\n                \"lr\": self._lr,\n                \"betas\": (self._beta1, self._beta2),\n                \"eps\": self._eps,\n            }\n        ]\n\n    def _set_param_groups(self, groups):\n        \"\"\"A helper method to load param_groups from saved state_dict.\"\"\"\n        self._lr = groups[0][\"lr\"]\n        self._beta1, self._beta2 = groups[0][\"betas\"]\n        self._eps = 
groups[0][\"eps\"]\n"
  },
  {
    "path": "python/dgl/optim/tensorflow/__init__.py",
    "content": ""
  },
  {
    "path": "python/dgl/partition.py",
    "content": "\"\"\"Module for graph partition utilities.\"\"\"\nimport os\nimport re\nimport time\n\nimport numpy as np\n\nfrom . import backend as F, utils\nfrom ._ffi.function import _init_api\nfrom .base import EID, ETYPE, NID, NTYPE\nfrom .heterograph import DGLGraph\nfrom .ndarray import NDArray\nfrom .subgraph import edge_subgraph\n\n__all__ = [\n    \"metis_partition\",\n    \"metis_partition_assignment\",\n    \"partition_graph_with_halo\",\n]\n\n\ndef reorder_nodes(g, new_node_ids):\n    \"\"\"Generate a new graph with new node IDs.\n\n    We assign each node in the input graph with a new node ID. This results in\n    a new graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph\n    new_node_ids : a tensor\n        The new node IDs\n    Returns\n    -------\n    DGLGraph\n        The graph with new node IDs.\n    \"\"\"\n    assert (\n        len(new_node_ids) == g.num_nodes()\n    ), \"The number of new node ids must match #nodes in the graph.\"\n    new_node_ids = utils.toindex(new_node_ids)\n    sorted_ids, idx = F.sort_1d(new_node_ids.tousertensor())\n    assert (\n        F.asnumpy(sorted_ids[0]) == 0\n        and F.asnumpy(sorted_ids[-1]) == g.num_nodes() - 1\n    ), \"The new node IDs are incorrect.\"\n    new_gidx = _CAPI_DGLReorderGraph_Hetero(\n        g._graph, new_node_ids.todgltensor()\n    )\n    new_g = DGLGraph(gidx=new_gidx, ntypes=[\"_N\"], etypes=[\"_E\"])\n    new_g.ndata[\"orig_id\"] = idx\n    return new_g\n\n\ndef _get_halo_heterosubgraph_inner_node(halo_subg):\n    return _CAPI_GetHaloSubgraphInnerNodes_Hetero(halo_subg)\n\n\ndef reshuffle_graph(g, node_part=None):\n    \"\"\"Reshuffle node ids and edge IDs of a graph.\n\n    This function reshuffles nodes and edges in a graph so that all nodes/edges of the same type\n    have contiguous IDs. 
If a graph is partitioned and nodes are assigned to different partitions,\n    all nodes/edges in a partition should\n    get contiguous IDs; within a partition, all nodes/edges of the same type have contiguous IDs.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    node_part : Tensor\n        This is a vector whose length is the same as the number of nodes in the input graph.\n        Each element indicates the partition ID the corresponding node is assigned to.\n\n    Returns\n    -------\n    (DGLGraph, Tensor)\n        The graph whose nodes and edges are reshuffled.\n        The 1D tensor that indicates the partition IDs of the nodes in the reshuffled graph.\n    \"\"\"\n    # In this case, we don't need to reshuffle node IDs and edge IDs.\n    if node_part is None:\n        g.ndata[\"orig_id\"] = F.arange(0, g.num_nodes())\n        g.edata[\"orig_id\"] = F.arange(0, g.num_edges())\n        return g, None\n\n    start = time.time()\n    if node_part is not None:\n        node_part = utils.toindex(node_part)\n        node_part = node_part.tousertensor()\n    if NTYPE in g.ndata:\n        is_hetero = len(F.unique(g.ndata[NTYPE])) > 1\n    else:\n        is_hetero = False\n    if is_hetero:\n        num_node_types = F.max(g.ndata[NTYPE], 0) + 1\n        if node_part is not None:\n            sorted_part, new2old_map = F.sort_1d(\n                node_part * num_node_types + g.ndata[NTYPE]\n            )\n        else:\n            sorted_part, new2old_map = F.sort_1d(g.ndata[NTYPE])\n        sorted_part = F.floor_div(sorted_part, num_node_types)\n    elif node_part is not None:\n        sorted_part, new2old_map = F.sort_1d(node_part)\n    else:\n        g.ndata[\"orig_id\"] = g.ndata[NID]\n        g.edata[\"orig_id\"] = g.edata[EID]\n        return g, None\n\n    new_node_ids = np.zeros((g.num_nodes(),), dtype=np.int64)\n    new_node_ids[F.asnumpy(new2old_map)] = np.arange(0, g.num_nodes())\n    # If the input graph is homogeneous, we 
only need to create an empty array, so that\n    # _CAPI_DGLReassignEdges_Hetero knows how to handle it.\n    etype = (\n        g.edata[ETYPE]\n        if ETYPE in g.edata\n        else F.zeros((0), F.dtype(sorted_part), F.cpu())\n    )\n    g = reorder_nodes(g, new_node_ids)\n    node_part = utils.toindex(sorted_part)\n    # We reassign edges in in-CSR. In this way, after partitioning, we can ensure\n    # that all edges in a partition are in the contiguous ID space.\n    etype_idx = utils.toindex(etype)\n    orig_eids = _CAPI_DGLReassignEdges_Hetero(\n        g._graph, etype_idx.todgltensor(), node_part.todgltensor(), True\n    )\n    orig_eids = utils.toindex(orig_eids)\n    orig_eids = orig_eids.tousertensor()\n    g.edata[\"orig_id\"] = orig_eids\n\n    print(\n        \"Reshuffle nodes and edges: {:.3f} seconds\".format(time.time() - start)\n    )\n    return g, node_part.tousertensor()\n\n\ndef partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False):\n    \"\"\"Partition a graph.\n\n    Based on the given node assignments for each partition, the function splits\n    the input graph into subgraphs. A subgraph may contain HALO nodes which do\n    not belong to the partition of a subgraph but are connected to the nodes\n    in the partition within a fixed number of hops.\n\n    If `reshuffle` is turned on, the function reshuffles node IDs and edge IDs\n    of the input graph before partitioning. After reshuffling, all nodes and edges\n    in a partition fall in a contiguous ID range in the input graph.\n    The partitioned subgraphs have node data 'orig_id', which stores the node IDs\n    in the original input graph.\n\n    Parameters\n    ------------\n    g: DGLGraph\n        The graph to be partitioned\n    node_part: 1D tensor\n        Specify which partition a node is assigned to. The length of this tensor\n        needs to be the same as the number of nodes of the graph. 
Each element\n        indicates the partition ID of a node.\n    extra_cached_hops: int\n        The number of hops a HALO node can be accessed.\n    reshuffle : bool\n        Reshuffle nodes so that nodes in the same partition are in the same ID range.\n\n    Returns\n    --------\n    a dict of DGLGraphs\n        The key is the partition ID and the value is the DGLGraph of the partition.\n    Tensor\n        1D tensor that stores the mapping between the reshuffled node IDs and\n        the original node IDs if 'reshuffle=True'. Otherwise, return None.\n    Tensor\n        1D tensor that stores the mapping between the reshuffled edge IDs and\n        the original edge IDs if 'reshuffle=True'. Otherwise, return None.\n    \"\"\"\n    assert len(node_part) == g.num_nodes()\n    if reshuffle:\n        g, node_part = reshuffle_graph(g, node_part)\n        orig_nids = g.ndata[\"orig_id\"]\n        orig_eids = g.edata[\"orig_id\"]\n\n    node_part = utils.toindex(node_part)\n    start = time.time()\n    subgs = _CAPI_DGLPartitionWithHalo_Hetero(\n        g._graph, node_part.todgltensor(), extra_cached_hops\n    )\n    # g is no longer needed. Free memory.\n    g = None\n    print(\"Split the graph: {:.3f} seconds\".format(time.time() - start))\n    subg_dict = {}\n    node_part = node_part.tousertensor()\n    start = time.time()\n\n    # This function determines whether an edge belongs to a partition.\n    # An edge is assigned to a partition based on its destination node. 
If its destination node\n    # is assigned to a partition, we assign the edge to the partition as well.\n    def get_inner_edge(subg, inner_node):\n        inner_edge = F.zeros((subg.num_edges(),), F.int8, F.cpu())\n        inner_nids = F.nonzero_1d(inner_node)\n        # TODO(zhengda) we need to fix utils.toindex() to avoid the dtype cast below.\n        inner_nids = F.astype(inner_nids, F.int64)\n        inner_eids = subg.in_edges(inner_nids, form=\"eid\")\n        inner_edge = F.scatter_row(\n            inner_edge,\n            inner_eids,\n            F.ones((len(inner_eids),), F.dtype(inner_edge), F.cpu()),\n        )\n        return inner_edge\n\n    # This creates a subgraph from subgraphs returned from the CAPI above.\n    def create_subgraph(subg, induced_nodes, induced_edges, inner_node):\n        subg1 = DGLGraph(gidx=subg.graph, ntypes=[\"_N\"], etypes=[\"_E\"])\n        # If IDs are shuffled, we should shuffle edges. This will help us collect edge data\n        # from the distributed graph after training.\n        if reshuffle:\n            # When we shuffle edges, we need to make sure that the inner edges are assigned with\n            # contiguous edge IDs and their ID range starts with 0. In other words, we want to\n            # place these edge IDs in the front of the edge list. 
To ensure that, we add the IDs\n            # of outer edges with a large value, so we will get the sorted list as we want.\n            max_eid = F.max(induced_edges[0], 0) + 1\n            inner_edge = get_inner_edge(subg1, inner_node)\n            eid = F.astype(induced_edges[0], F.int64) + max_eid * F.astype(\n                inner_edge == 0, F.int64\n            )\n\n            _, index = F.sort_1d(eid)\n            subg1 = edge_subgraph(subg1, index, relabel_nodes=False)\n            subg1.ndata[NID] = induced_nodes[0]\n            subg1.edata[EID] = F.gather_row(induced_edges[0], index)\n        else:\n            subg1.ndata[NID] = induced_nodes[0]\n            subg1.edata[EID] = induced_edges[0]\n        return subg1\n\n    for i, subg in enumerate(subgs):\n        inner_node = _get_halo_heterosubgraph_inner_node(subg)\n        inner_node = F.zerocopy_from_dlpack(inner_node.to_dlpack())\n        subg = create_subgraph(\n            subg, subg.induced_nodes, subg.induced_edges, inner_node\n        )\n        subg.ndata[\"inner_node\"] = inner_node\n        subg.ndata[\"part_id\"] = F.gather_row(node_part, subg.ndata[NID])\n        if reshuffle:\n            subg.ndata[\"orig_id\"] = F.gather_row(orig_nids, subg.ndata[NID])\n            subg.edata[\"orig_id\"] = F.gather_row(orig_eids, subg.edata[EID])\n\n        if extra_cached_hops >= 1:\n            inner_edge = get_inner_edge(subg, inner_node)\n        else:\n            inner_edge = F.ones((subg.num_edges(),), F.int8, F.cpu())\n        subg.edata[\"inner_edge\"] = inner_edge\n        subg_dict[i] = subg\n    print(\"Construct subgraphs: {:.3f} seconds\".format(time.time() - start))\n    if reshuffle:\n        return subg_dict, orig_nids, orig_eids\n    else:\n        return subg_dict, None, None\n\n\ndef get_peak_mem():\n    \"\"\"Get the peak memory size.\n\n    Returns\n    -------\n    float\n        The peak memory size in GB.\n    \"\"\"\n    if not os.path.exists(\"/proc/self/status\"):\n        
return 0.0\n    for line in open(\"/proc/self/status\", \"r\"):\n        if \"VmPeak\" in line:\n            mem = re.findall(r\"\\d+\", line)[0]\n            return int(mem) / 1024 / 1024\n    return 0.0\n\n\ndef metis_partition_assignment(\n    g, k, balance_ntypes=None, balance_edges=False, mode=\"k-way\", objtype=\"cut\"\n):\n    \"\"\"This assigns nodes to different partitions with Metis partitioning algorithm.\n\n    When performing Metis partitioning, we can put some constraint on the partitioning.\n    Currently, it supports two constraints to balance the partitioning. By default, Metis\n    always tries to balance the number of nodes in each partition.\n\n    * `balance_ntypes` balances the number of nodes of different types in each partition.\n    * `balance_edges` balances the number of edges in each partition.\n\n    To balance the node types, a user needs to pass a vector of N elements to indicate\n    the type of each node. N is the number of nodes in the input graph.\n\n    After the partition assignment, we construct partitions.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph to be partitioned\n    k : int\n        The number of partitions.\n    balance_ntypes : tensor\n        Node type of each node\n    balance_edges : bool\n        Indicate whether to balance the edges.\n    mode : str, \"k-way\" or \"recursive\"\n        Whether to use multilevel recursive bisection or multilevel k-way partitioning.\n    objtype : str, \"cut\" or \"vol\"\n        Set the objective as edge-cut minimization or communication volume minimization. 
This\n        argument is used by the Metis algorithm.\n\n    Returns\n    -------\n    a 1-D tensor\n        A vector with each element that indicates the partition ID of a vertex.\n    \"\"\"\n    assert mode in (\n        \"k-way\",\n        \"recursive\",\n    ), \"'mode' can only be 'k-way' or 'recursive'\"\n    assert (\n        g.idtype == F.int64\n    ), \"IdType of graph is required to be int64 for now.\"\n    # METIS works only on symmetric graphs.\n    # The METIS runs on the symmetric graph to generate the node assignment to partitions.\n    start = time.time()\n    sym_gidx = _CAPI_DGLMakeSymmetric_Hetero(g._graph)\n    sym_g = DGLGraph(gidx=sym_gidx)\n    print(\n        \"Convert a graph into a bidirected graph: {:.3f} seconds, peak memory: {:.3f} GB\".format(\n            time.time() - start, get_peak_mem()\n        )\n    )\n    vwgt = []\n    # To balance the node types in each partition, we can take advantage of the vertex weights\n    # in Metis. When vertex weights are provided, Metis will try to generate partitions with\n    # balanced vertex weights. A vertex can be assigned with multiple weights. The vertex weights\n    # are stored in a vector of N * w elements, where N is the number of vertices and w\n    # is the number of weights per vertex. Metis tries to balance the first weight, and then\n    # the second weight, and so on.\n    # When balancing node types, we use the first weight to indicate the first node type.\n    # if a node belongs to the first node type, its weight is set to 1; otherwise, 0.\n    # Similarly, we set the second weight for the second node type and so on. 
The number\n    # of weights is the same as the number of node types.\n    start = time.time()\n    if balance_ntypes is not None:\n        assert (\n            len(balance_ntypes) == g.num_nodes()\n        ), \"The length of balance_ntypes should be equal to #nodes in the graph\"\n        balance_ntypes = F.tensor(balance_ntypes)\n        uniq_ntypes = F.unique(balance_ntypes)\n        for ntype in uniq_ntypes:\n            vwgt.append(F.astype(balance_ntypes == ntype, F.int64))\n\n    # When balancing edges in partitions, we use in-degree as one of the weights.\n    if balance_edges:\n        if balance_ntypes is None:\n            vwgt.append(F.astype(g.in_degrees(), F.int64))\n        else:\n            for ntype in uniq_ntypes:\n                nids = F.asnumpy(F.nonzero_1d(balance_ntypes == ntype))\n                degs = np.zeros((g.num_nodes(),), np.int64)\n                degs[nids] = F.asnumpy(g.in_degrees(nids))\n                vwgt.append(F.zerocopy_from_numpy(degs))\n\n    # The vertex weights have to be stored in a vector.\n    if len(vwgt) > 0:\n        vwgt = F.stack(vwgt, 1)\n        shape = (\n            np.prod(\n                F.shape(vwgt),\n            ),\n        )\n        vwgt = F.reshape(vwgt, shape)\n        vwgt = F.to_dgl_nd(vwgt)\n    else:\n        vwgt = F.zeros((0,), F.int64, F.cpu())\n        vwgt = F.to_dgl_nd(vwgt)\n    print(\n        \"Construct multi-constraint weights: {:.3f} seconds, peak memory: {:.3f} GB\".format(\n            time.time() - start, get_peak_mem()\n        )\n    )\n\n    start = time.time()\n    node_part = _CAPI_DGLMetisPartition_Hetero(\n        sym_g._graph, k, vwgt, mode, (objtype == \"cut\")\n    )\n    print(\n        \"Metis partitioning: {:.3f} seconds, peak memory: {:.3f} GB\".format(\n            time.time() - start, get_peak_mem()\n        )\n    )\n    if len(node_part) == 0:\n        return None\n    else:\n        node_part = utils.toindex(node_part)\n        return 
node_part.tousertensor()\n\n\ndef metis_partition(\n    g,\n    k,\n    extra_cached_hops=0,\n    reshuffle=False,\n    balance_ntypes=None,\n    balance_edges=False,\n    mode=\"k-way\",\n):\n    \"\"\"This is to partition a graph with Metis partitioning.\n\n    Metis assigns vertices to partitions. This API constructs subgraphs with the vertices assigned\n    to the partitions and their incoming edges. A subgraph may contain HALO nodes which do\n    not belong to the partition of a subgraph but are connected to the nodes\n    in the partition within a fixed number of hops.\n\n    When performing Metis partitioning, we can put some constraint on the partitioning.\n    Currently, it supports two constraints to balance the partitioning. By default, Metis\n    always tries to balance the number of nodes in each partition.\n\n    * `balance_ntypes` balances the number of nodes of different types in each partition.\n    * `balance_edges` balances the number of edges in each partition.\n\n    To balance the node types, a user needs to pass a vector of N elements to indicate\n    the type of each node. N is the number of nodes in the input graph.\n\n    If `reshuffle` is turned on, the function reshuffles node IDs and edge IDs\n    of the input graph before partitioning. After reshuffling, all nodes and edges\n    in a partition fall in a contiguous ID range in the input graph.\n    The partitioned subgraphs have node data 'orig_id', which stores the node IDs\n    in the original input graph.\n\n    The partitioned subgraph is stored in DGLGraph. The DGLGraph has the `part_id`\n    node data that indicates the partition a node belongs to. 
The subgraphs do not contain\n    the node/edge data in the input graph.\n\n    Parameters\n    ------------\n    g: DGLGraph\n        The graph to be partitioned\n    k: int\n        The number of partitions.\n    extra_cached_hops: int\n        The number of hops a HALO node can be accessed.\n    reshuffle : bool\n        Reshuffle nodes so that nodes in the same partition are in the same ID range.\n    balance_ntypes : tensor\n        Node type of each node\n    balance_edges : bool\n        Indicate whether to balance the edges.\n    mode : str, \"k-way\" or \"recursive\"\n        Whether to use multilevel recursive bisection or multilevel k-way partitioning.\n\n    Returns\n    --------\n    a dict of DGLGraphs\n        The key is the partition ID and the value is the DGLGraph of the partition.\n    \"\"\"\n    assert mode in (\n        \"k-way\",\n        \"recursive\",\n    ), \"'mode' can only be 'k-way' or 'recursive'\"\n    node_part = metis_partition_assignment(\n        g, k, balance_ntypes, balance_edges, mode\n    )\n    if node_part is None:\n        return None\n\n    # Then we split the original graph into parts based on the METIS partitioning results.\n    return partition_graph_with_halo(\n        g, node_part, extra_cached_hops, reshuffle\n    )[0]\n\n\nclass NDArrayPartition(object):\n    \"\"\"Create a new partition of an NDArray. That is, an object which assigns\n    each row of an NDArray to a specific partition.\n\n    Parameters\n    ----------\n    array_size : int\n        The first dimension of the array being partitioned.\n    num_parts : int\n        The number of parts to divide the array into.\n    mode : String\n        The type of partition. 
Currently, the only valid values are\n        'remainder' and 'range'.\n        'remainder' assigns rows based on remainder when dividing the row id by the\n        number of parts (e.g., i % num_parts).\n        'range' assigns rows based on which part of the range 'part_ranges'\n        they fall into.\n    part_ranges : Tensor or dgl.NDArray, Optional\n        Should only be specified when the mode is 'range'. Should be of the\n        length `num_parts + 1`, and be the exclusive prefix-sum of the number\n        of nodes in each partition. That is, for 3 partitions, we could have\n        the list [0, a, b, 'array_size'], and all rows with index less\n        than 'a' are assigned to partition 0, all rows with index greater than\n        or equal to 'a' and less than 'b' are in partition 1, and all rows\n        with index greater or equal to 'b' are in partition 2. Should have\n        the same context as the partitioned NDArray (i.e., be on the same GPU).\n\n    Examples\n    --------\n\n    A partition of a homogeneous graph `g`, where the vertices are\n    striped across processes can be generated via:\n\n    >>> from dgl.partition import NDArrayPartition\n    >>> part = NDArrayPartition(g.num_nodes(), num_parts, mode='remainder' )\n\n    A range based partition of a homogeneous graph `g`'s nodes, where\n    the nodes are stored in contiguous memory. This converts an existing\n    range based partitioning (e.g. from a\n    dgl.distributed.graph_partition_book.RangePartitionBook)\n    'max_node_map', to an NDArrayPartition 'part'.\n\n    >>> part_range = [0]\n    >>> for part in part_book.metadata():\n    >>>     part_range.append(part_range[-1] + part['num_nodes'])\n    >>> part = NDArrayPartition(g.num_nodes(), num_parts, mode='range',\n    ...                         
part_ranges=part_range)\n    \"\"\"\n\n    def __init__(\n        self, array_size, num_parts, mode=\"remainder\", part_ranges=None\n    ):\n        assert num_parts > 0, 'Invalid \"num_parts\", must be > 0.'\n        if mode == \"remainder\":\n            assert part_ranges is None, (\n                \"When using remainder-based \"\n                'partitioning, \"part_ranges\" should not be specified.'\n            )\n            self._partition = _CAPI_DGLNDArrayPartitionCreateRemainderBased(\n                array_size, num_parts\n            )\n        elif mode == \"range\":\n            assert part_ranges is not None, (\n                \"When using range-based \"\n                'partitioning, \"part_ranges\" must not be None.'\n            )\n            assert part_ranges[0] == 0 and part_ranges[-1] == array_size, (\n                \"part_ranges[0] must be 0, and part_ranges[-1] must be \"\n                '\"array_size\".'\n            )\n            if F.is_tensor(part_ranges):\n                part_ranges = F.zerocopy_to_dgl_ndarray(part_ranges)\n            assert isinstance(part_ranges, NDArray), (\n                '\"part_ranges\" must ' \"be Tensor or dgl.NDArray.\"\n            )\n            self._partition = _CAPI_DGLNDArrayPartitionCreateRangeBased(\n                array_size, num_parts, part_ranges\n            )\n        else:\n            assert False, 'Unknown partition mode \"{}\"'.format(mode)\n        self._array_size = array_size\n        self._num_parts = num_parts\n\n    def num_parts(self):\n        \"\"\"Get the number of partitions.\"\"\"\n        return self._num_parts\n\n    def array_size(self):\n        \"\"\"Get the total size of the first dimension of the partitioned array.\"\"\"\n        return self._array_size\n\n    def get(self):\n        \"\"\"Get the C-handle for this object.\"\"\"\n        return self._partition\n\n    def get_local_indices(self, part, ctx):\n        \"\"\"Get the set of global indices in this 
given partition.\"\"\"\n        return self.map_to_global(\n            F.arange(0, self.local_size(part), ctx=ctx), part\n        )\n\n    def local_size(self, part):\n        \"\"\"Get the number of rows/items assigned to the given part.\"\"\"\n        return _CAPI_DGLNDArrayPartitionGetPartSize(self._partition, part)\n\n    def map_to_local(self, idxs):\n        \"\"\"Convert the set of global indices to local indices\"\"\"\n        return F.zerocopy_from_dgl_ndarray(\n            _CAPI_DGLNDArrayPartitionMapToLocal(\n                self._partition, F.zerocopy_to_dgl_ndarray(idxs)\n            )\n        )\n\n    def map_to_global(self, idxs, part_id):\n        \"\"\"Convert the set of local indices to global indices\"\"\"\n        return F.zerocopy_from_dgl_ndarray(\n            _CAPI_DGLNDArrayPartitionMapToGlobal(\n                self._partition, F.zerocopy_to_dgl_ndarray(idxs), part_id\n            )\n        )\n\n    def generate_permutation(self, idxs):\n        \"\"\"Produce a scheme that maps the given indices to separate partitions\n        and the counts of how many indices are in each partition.\n\n\n        Parameters\n        ----------\n        idxs: torch.Tensor.\n            A tensor with shape (`num_indices`,), representing global indices.\n\n        Return\n        ------\n        torch.Tensor.\n            A tensor with shape (`num_indices`,), representing the permutation\n            to re-order the indices by partition.\n        torch.Tensor.\n            A tensor with shape (`num_partition`,), representing the number of\n            indices per partition.\n\n        Examples\n        --------\n\n        >>> import torch\n        >>> from dgl.partition import NDArrayPartition\n        >>> part = NDArrayPartition(10, 2, mode=\"remainder\")\n        >>> idx = torch.tensor([0, 2, 4, 5, 8, 8, 9], device=\"cuda:0\")\n        >>> perm, splits_sum = part.generate_permutation(idx)\n        >>> perm\n        tensor([0, 1, 2, 4, 5, 3, 6], 
device='cuda:0')\n        >>> splits_sum\n        tensor([5, 2], device='cuda:0')\n        \"\"\"\n        ret = _CAPI_DGLNDArrayPartitionGeneratePermutation(\n            self._partition, F.zerocopy_to_dgl_ndarray(idxs)\n        )\n        return F.zerocopy_from_dgl_ndarray(ret(0)), F.zerocopy_from_dgl_ndarray(\n            ret(1)\n        )\n\n\n_init_api(\"dgl.partition\")\n"
  },
  {
    "path": "python/dgl/propagate.py",
    "content": "\"\"\"Module for message propagation.\"\"\"\nfrom __future__ import absolute_import\n\nfrom . import backend as F, traversal as trv\nfrom .heterograph import DGLGraph\n\n__all__ = [\n    \"prop_nodes\",\n    \"prop_nodes_bfs\",\n    \"prop_nodes_topo\",\n    \"prop_edges\",\n    \"prop_edges_dfs\",\n]\n\n\ndef prop_nodes(\n    graph,\n    nodes_generator,\n    message_func=\"default\",\n    reduce_func=\"default\",\n    apply_node_func=\"default\",\n):\n    \"\"\"Functional method for :func:`dgl.DGLGraph.prop_nodes`.\n\n    Parameters\n    ----------\n    node_generators : generator\n        The generator of node frontiers.\n    message_func : callable, optional\n        The message function.\n    reduce_func : callable, optional\n        The reduce function.\n    apply_node_func : callable, optional\n        The update function.\n\n    See Also\n    --------\n    dgl.DGLGraph.prop_nodes\n    \"\"\"\n    graph.prop_nodes(\n        nodes_generator, message_func, reduce_func, apply_node_func\n    )\n\n\ndef prop_edges(\n    graph,\n    edges_generator,\n    message_func=\"default\",\n    reduce_func=\"default\",\n    apply_node_func=\"default\",\n):\n    \"\"\"Functional method for :func:`dgl.DGLGraph.prop_edges`.\n\n    Parameters\n    ----------\n    edges_generator : generator\n        The generator of edge frontiers.\n    message_func : callable, optional\n        The message function.\n    reduce_func : callable, optional\n        The reduce function.\n    apply_node_func : callable, optional\n        The update function.\n\n    See Also\n    --------\n    dgl.DGLGraph.prop_edges\n    \"\"\"\n    graph.prop_edges(\n        edges_generator, message_func, reduce_func, apply_node_func\n    )\n\n\ndef prop_nodes_bfs(\n    graph,\n    source,\n    message_func,\n    reduce_func,\n    reverse=False,\n    apply_node_func=None,\n):\n    \"\"\"Message propagation using node frontiers generated by BFS.\n\n    Parameters\n    ----------\n    graph : 
DGLGraph\n        The graph object.\n    source : list, tensor of nodes\n        Source nodes.\n    message_func : callable\n        The message function.\n    reduce_func : callable\n        The reduce function.\n    reverse : bool, optional\n        If true, traverse following the in-edge direction.\n    apply_node_func : callable, optional\n        The update function.\n\n    See Also\n    --------\n    dgl.traversal.bfs_nodes_generator\n    \"\"\"\n    assert isinstance(\n        graph, DGLGraph\n    ), \"DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph\"\n    assert (\n        len(graph.canonical_etypes) == 1\n    ), \"prop_nodes_bfs only support homogeneous graph\"\n    # TODO(murphy): Graph traversal currently is only supported on\n    # CPP graphs. Move graph to CPU as a workaround,\n    # which should be fixed in the future.\n    nodes_gen = trv.bfs_nodes_generator(graph.cpu(), source, reverse)\n    nodes_gen = [F.copy_to(frontier, graph.device) for frontier in nodes_gen]\n    prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)\n\n\ndef prop_nodes_topo(\n    graph, message_func, reduce_func, reverse=False, apply_node_func=None\n):\n    \"\"\"Message propagation using node frontiers generated by topological order.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph object.\n    message_func : callable\n        The message function.\n    reduce_func : callable\n        The reduce function.\n    reverse : bool, optional\n        If true, traverse following the in-edge direction.\n    apply_node_func : callable, optional\n        The update function.\n\n    See Also\n    --------\n    dgl.traversal.topological_nodes_generator\n    \"\"\"\n    assert isinstance(\n        graph, DGLGraph\n    ), \"DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph\"\n    assert (\n        len(graph.canonical_etypes) == 1\n    ), \"prop_nodes_topo only support homogeneous graph\"\n    # TODO(murphy): Graph traversal 
currently is only supported on\n    # CPP graphs. Move graph to CPU as a workaround,\n    # which should be fixed in the future.\n    nodes_gen = trv.topological_nodes_generator(graph.cpu(), reverse)\n    nodes_gen = [F.copy_to(frontier, graph.device) for frontier in nodes_gen]\n    prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)\n\n\ndef prop_edges_dfs(\n    graph,\n    source,\n    message_func,\n    reduce_func,\n    reverse=False,\n    has_reverse_edge=False,\n    has_nontree_edge=False,\n    apply_node_func=None,\n):\n    \"\"\"Message propagation using edge frontiers generated by labeled DFS.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph object.\n    source : list, tensor of nodes\n        Source nodes.\n    message_func : callable, optional\n        The message function.\n    reduce_func : callable, optional\n        The reduce function.\n    reverse : bool, optional\n        If true, traverse following the in-edge direction.\n    has_reverse_edge : bool, optional\n        If true, REVERSE edges are included.\n    has_nontree_edge : bool, optional\n        If true, NONTREE edges are included.\n    apply_node_func : callable, optional\n        The update function.\n\n    See Also\n    --------\n    dgl.traversal.dfs_labeled_edges_generator\n    \"\"\"\n    assert isinstance(\n        graph, DGLGraph\n    ), \"DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph\"\n    assert (\n        len(graph.canonical_etypes) == 1\n    ), \"prop_edges_dfs only support homogeneous graph\"\n    # TODO(murphy): Graph traversal currently is only supported on\n    # CPP graphs. 
Move graph to CPU as a workaround,\n    # which should be fixed in the future.\n    edges_gen = trv.dfs_labeled_edges_generator(\n        graph.cpu(),\n        source,\n        reverse,\n        has_reverse_edge,\n        has_nontree_edge,\n        return_labels=False,\n    )\n    edges_gen = [F.copy_to(frontier, graph.device) for frontier in edges_gen]\n    prop_edges(graph, edges_gen, message_func, reduce_func, apply_node_func)\n"
  },
  {
    "path": "python/dgl/random.py",
    "content": "\"\"\"Python interfaces to DGL random number generators.\"\"\"\nimport numpy as np\n\nfrom . import backend as F, ndarray as nd\nfrom ._ffi.function import _init_api\n\n__all__ = [\"seed\"]\n\n\ndef seed(val):\n    \"\"\"Set the random seed of DGL.\n\n    Parameters\n    ----------\n    val : int\n        The seed.\n    \"\"\"\n    _CAPI_SetSeed(val)\n\n\ndef choice(a, size, replace=True, prob=None):  # pylint: disable=invalid-name\n    \"\"\"An equivalent to :func:`numpy.random.choice`.\n\n    Use this function if you:\n\n    * Perform a non-uniform sampling (probability tensor is given).\n    * Sample a small set from a very large population (ratio <5%) uniformly\n      *without* replacement.\n    * Have a backend tensor on hand and does not want to convert it to numpy\n      back and forth.\n\n    Compared to :func:`numpy.random.choice`, it is slower when replace is True\n    and is comparable when replace is False. It wins when the population is\n    very large and the number of draws are quite small (e.g., draw <5%). The\n    reasons are two folds:\n\n    * When ``a`` is a large integer, it avoids creating a large range array as\n      numpy does.\n    * When draw ratio is small, it switches to a hashmap based implementation.\n\n    It out-performs numpy for non-uniform sampling in general cases.\n\n    Parameters\n    ----------\n    a : 1-D tensor or int\n        If an ndarray, a random sample is generated from its elements. If an int,\n        the random sample is generated as if a were F.arange(a)\n    size : int or tuple of ints\n        Output shape. 
E.g., for size ``(m, n, k)``, then ``m * n * k`` samples are drawn.\n    replace : bool, optional\n        If true, sample with replacement.\n    prob : 1-D tensor, optional\n        The probabilities associated with each entry in a.\n        If not given the sample assumes a uniform distribution over all entries in a.\n\n    Returns\n    -------\n    samples : 1-D tensor\n        The generated random samples\n    \"\"\"\n    # TODO(minjie): support RNG as one of the arguments.\n    if isinstance(size, tuple):\n        num = np.prod(size)\n    else:\n        num = size\n\n    if F.is_tensor(a):\n        population = F.shape(a)[0]\n    else:\n        population = a\n\n    if prob is None:\n        prob = nd.NULL[\"int64\"]\n    else:\n        prob = F.zerocopy_to_dgl_ndarray(prob)\n\n    bits = 64  # index array is in 64-bit\n    chosen_idx = _CAPI_Choice(\n        int(num), int(population), prob, bool(replace), bits\n    )\n    chosen_idx = F.zerocopy_from_dgl_ndarray(chosen_idx)\n\n    if F.is_tensor(a):\n        chosen = F.gather_row(a, chosen_idx)\n    else:\n        chosen = chosen_idx\n\n    if isinstance(size, tuple):\n        return F.reshape(chosen, size)\n    else:\n        return chosen\n\n\n_init_api(\"dgl.rng\", __name__)\n"
  },
  {
    "path": "python/dgl/readout.py",
    "content": "\"\"\"Classes and functions for batching multiple graphs together.\"\"\"\nfrom __future__ import absolute_import\n\nfrom . import backend as F\nfrom .base import dgl_warning, DGLError\nfrom .ops import segment\n\n__all__ = [\n    \"readout_nodes\",\n    \"readout_edges\",\n    \"sum_nodes\",\n    \"sum_edges\",\n    \"mean_nodes\",\n    \"mean_edges\",\n    \"max_nodes\",\n    \"max_edges\",\n    \"softmax_nodes\",\n    \"softmax_edges\",\n    \"broadcast_nodes\",\n    \"broadcast_edges\",\n    \"topk_nodes\",\n    \"topk_edges\",\n]\n\n\ndef readout_nodes(graph, feat, weight=None, *, op=\"sum\", ntype=None):\n    \"\"\"Generate a graph-level representation by aggregating node features\n    :attr:`feat`.\n\n    The function is commonly used as a *readout* function on a batch of graphs\n    to generate graph-level representation. Thus, the result tensor shape\n    depends on the batch size of the input graph. Given a graph of batch size\n    :math:`B`, and a feature size of :math:`D`, the result shape will be\n    :math:`(B, D)`, with each row being the aggregated node features of each\n    graph.\n\n    Parameters\n    ----------\n    graph : DGLGraph.\n        Input graph.\n    feat : str\n        Node feature name.\n    weight : str, optional\n        Node weight name. None means aggregating without weights.\n        Otherwise, multiply each node feature by node feature :attr:`weight`\n        before aggregation. The weight feature shape must be compatible with\n        an element-wise multiplication with the feature tensor.\n    op : str, optional\n        Readout operator. Can be 'sum', 'max', 'min', 'mean'.\n    ntype : str, optional\n        Node type. 
Can be omitted if there is only one node type in the graph.\n\n    Returns\n    -------\n    Tensor\n        Result tensor.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch as th\n\n    Create two :class:`~dgl.DGLGraph` objects and initialize their\n    node features.\n\n    >>> g1 = dgl.graph(([0, 1], [1, 0]))              # Graph 1\n    >>> g1.ndata['h'] = th.tensor([1., 2.])\n    >>> g2 = dgl.graph(([0, 1], [1, 2]))              # Graph 2\n    >>> g2.ndata['h'] = th.tensor([1., 2., 3.])\n\n    Sum over one graph:\n\n    >>> dgl.readout_nodes(g1, 'h')\n    tensor([3.])  # 1 + 2\n\n    Sum over a batched graph:\n\n    >>> bg = dgl.batch([g1, g2])\n    >>> dgl.readout_nodes(bg, 'h')\n    tensor([3., 6.])  # [1 + 2, 1 + 2 + 3]\n\n    Weighted sum:\n\n    >>> bg.ndata['w'] = th.tensor([.1, .2, .1, .5, .2])\n    >>> dgl.readout_nodes(bg, 'h', 'w')\n    tensor([.5, 1.7])\n\n    Readout by max:\n\n    >>> dgl.readout_nodes(bg, 'h', op='max')\n    tensor([2., 3.])\n\n    See Also\n    --------\n    readout_edges\n    \"\"\"\n    x = graph.nodes[ntype].data[feat]\n    if weight is not None:\n        x = x * graph.nodes[ntype].data[weight]\n    return segment.segment_reduce(graph.batch_num_nodes(ntype), x, reducer=op)\n\n\ndef readout_edges(graph, feat, weight=None, *, op=\"sum\", etype=None):\n    \"\"\"Sum the edge feature :attr:`feat` in :attr:`graph`, optionally\n    multiplies it by a edge :attr:`weight`.\n\n    The function is commonly used as a *readout* function on a batch of graphs\n    to generate graph-level representation. Thus, the result tensor shape\n    depends on the batch size of the input graph. 
Given a graph of batch size\n    :math:`B`, and a feature size of :math:`D`, the result shape will be\n    :math:`(B, D)`, with each row being the aggregated edge features of each\n    graph.\n\n    Parameters\n    ----------\n    graph : DGLGraph.\n        The input graph.\n    feat : str\n        The edge feature name.\n    weight : str, optional\n        The edge weight feature name. If None, no weighting will be performed,\n        otherwise, weight each edge feature with field :attr:`weight`\n        for summation. The weight feature shape must be compatible with\n        an element-wise multiplication with the feature tensor.\n    op : str, optional\n        Readout operator. Can be 'sum', 'max', 'min', 'mean'.\n    etype : str or (str, str, str), optional\n        The type names of the edges. The allowed type name formats are:\n\n        * ``(str, str, str)`` for source node type, edge type and destination node type.\n        * or one ``str`` edge type name if the name can uniquely identify a\n          triplet format in the graph.\n\n        Can be omitted if the graph has only one type of edges.\n\n    Returns\n    -------\n    Tensor\n        Result tensor.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch as th\n\n    Create two :class:`~dgl.DGLGraph` objects and initialize their\n    edge features.\n\n    >>> g1 = dgl.graph(([0, 1], [1, 0]))              # Graph 1\n    >>> g1.edata['h'] = th.tensor([1., 2.])\n    >>> g2 = dgl.graph(([0, 1], [1, 2]))              # Graph 2\n    >>> g2.edata['h'] = th.tensor([2., 3.])\n\n    Sum over one graph:\n\n    >>> dgl.readout_edges(g1, 'h')\n    tensor([3.])  # 1 + 2\n\n    Sum over a batched graph:\n\n    >>> bg = dgl.batch([g1, g2])\n    >>> dgl.readout_edges(bg, 'h')\n    tensor([3., 5.])  # [1 + 2, 2 + 3]\n\n    Weighted sum:\n\n    >>> bg.edata['w'] = th.tensor([.1, .2, .1, .5])\n    >>> dgl.readout_edges(bg, 'h', 'w')\n    tensor([.5, 1.7])\n\n    Readout by max:\n\n    >>> 
dgl.readout_edges(bg, 'h', op='max')\n    tensor([2., 3.])\n\n    See Also\n    --------\n    readout_nodes\n    \"\"\"\n    x = graph.edges[etype].data[feat]\n    if weight is not None:\n        x = x * graph.edges[etype].data[weight]\n    return segment.segment_reduce(graph.batch_num_edges(etype), x, reducer=op)\n\n\ndef sum_nodes(graph, feat, weight=None, *, ntype=None):\n    \"\"\"Syntax sugar for ``dgl.readout_nodes(graph, feat, weight, ntype=ntype, op='sum')``.\n\n    See Also\n    --------\n    readout_nodes\n    \"\"\"\n    return readout_nodes(graph, feat, weight, ntype=ntype, op=\"sum\")\n\n\ndef sum_edges(graph, feat, weight=None, *, etype=None):\n    \"\"\"Syntax sugar for ``dgl.readout_edges(graph, feat, weight, etype=etype, op='sum')``.\n\n    See Also\n    --------\n    readout_edges\n    \"\"\"\n    return readout_edges(graph, feat, weight, etype=etype, op=\"sum\")\n\n\ndef mean_nodes(graph, feat, weight=None, *, ntype=None):\n    \"\"\"Syntax sugar for ``dgl.readout_nodes(graph, feat, weight, ntype=ntype, op='mean')``.\n\n    See Also\n    --------\n    readout_nodes\n    \"\"\"\n    return readout_nodes(graph, feat, weight, ntype=ntype, op=\"mean\")\n\n\ndef mean_edges(graph, feat, weight=None, *, etype=None):\n    \"\"\"Syntax sugar for ``dgl.readout_edges(graph, feat, weight, etype=etype, op='mean')``.\n\n    See Also\n    --------\n    readout_edges\n    \"\"\"\n    return readout_edges(graph, feat, weight, etype=etype, op=\"mean\")\n\n\ndef max_nodes(graph, feat, weight=None, *, ntype=None):\n    \"\"\"Syntax sugar for ``dgl.readout_nodes(graph, feat, weight, ntype=ntype, op='max')``.\n\n    See Also\n    --------\n    readout_nodes\n    \"\"\"\n    return readout_nodes(graph, feat, weight, ntype=ntype, op=\"max\")\n\n\ndef max_edges(graph, feat, weight=None, *, etype=None):\n    \"\"\"Syntax sugar for ``dgl.readout_edges(graph, feat, weight, etype=etype, op='max')``.\n\n    See Also\n    --------\n    readout_edges\n    \"\"\"\n    return 
readout_edges(graph, feat, weight, etype=etype, op=\"max\")\n\n\ndef softmax_nodes(graph, feat, *, ntype=None):\n    r\"\"\"Perform graph-wise softmax on the node features.\n\n    For each node :math:`v\\in\\mathcal{V}` and its feature :math:`x_v`,\n    calculate its normalized feature as follows:\n\n    .. math::\n        z_v = \\frac{\\exp(x_v)}{\\sum_{u\\in\\mathcal{V}}\\exp(x_u)}\n\n    If the graph is a batch of multiple graphs, each graph computes softmax\n    independently. The result tensor has the same shape as the original node\n    feature.\n\n    Parameters\n    ----------\n    graph : DGLGraph.\n        The input graph.\n    feat : str\n        The node feature name.\n    ntype : str, optional\n        The node type name. Can be omitted if there is only one node type in the graph.\n\n    Returns\n    -------\n    Tensor\n        Result tensor.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch as th\n\n    Create two :class:`~dgl.DGLGraph` objects and initialize their\n    node features.\n\n    >>> g1 = dgl.graph(([0, 1], [1, 0]))              # Graph 1\n    >>> g1.ndata['h'] = th.tensor([1., 1.])\n    >>> g2 = dgl.graph(([0, 1], [1, 2]))              # Graph 2\n    >>> g2.ndata['h'] = th.tensor([1., 1., 1.])\n\n    Softmax over one graph:\n\n    >>> dgl.softmax_nodes(g1, 'h')\n    tensor([.5000, .5000])\n\n    Softmax over a batched graph:\n\n    >>> bg = dgl.batch([g1, g2])\n    >>> dgl.softmax_nodes(bg, 'h')\n    tensor([.5000, .5000, .3333, .3333, .3333])\n\n    See Also\n    --------\n    softmax_edges\n    \"\"\"\n    x = graph.nodes[ntype].data[feat]\n    return segment.segment_softmax(graph.batch_num_nodes(ntype), x)\n\n\ndef softmax_edges(graph, feat, *, etype=None):\n    r\"\"\"Perform graph-wise softmax on the edge features.\n\n    For each edge :math:`e\\in\\mathcal{E}` and its feature :math:`x_e`,\n    calculate its normalized feature as follows:\n\n    .. 
math::\n        z_e = \\frac{\\exp(x_e)}{\\sum_{e'\\in\\mathcal{E}}\\exp(x_{e'})}\n\n    If the graph is a batch of multiple graphs, each graph computes softmax\n    independently. The result tensor has the same shape as the original edge\n    feature.\n\n    Parameters\n    ----------\n    graph : DGLGraph.\n        The input graph.\n    feat : str\n        The edge feature name.\n    etype : str or (str, str, str), optional\n        The type names of the edges. The allowed type name formats are:\n\n        * ``(str, str, str)`` for source node type, edge type and destination node type.\n        * or one ``str`` edge type name if the name can uniquely identify a\n          triplet format in the graph.\n\n        Can be omitted if the graph has only one type of edges.\n\n    Returns\n    -------\n    Tensor\n        Result tensor.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch as th\n\n    Create two :class:`~dgl.DGLGraph` objects and initialize their\n    edge features.\n\n    >>> g1 = dgl.graph(([0, 1], [1, 0]))              # Graph 1\n    >>> g1.edata['h'] = th.tensor([1., 1.])\n    >>> g2 = dgl.graph(([0, 1, 0], [1, 2, 2]))        # Graph 2\n    >>> g2.edata['h'] = th.tensor([1., 1., 1.])\n\n    Softmax over one graph:\n\n    >>> dgl.softmax_edges(g1, 'h')\n    tensor([.5000, .5000])\n\n    Softmax over a batched graph:\n\n    >>> bg = dgl.batch([g1, g2])\n    >>> dgl.softmax_edges(bg, 'h')\n    tensor([.5000, .5000, .3333, .3333, .3333])\n\n    See Also\n    --------\n    softmax_nodes\n    \"\"\"\n    x = graph.edges[etype].data[feat]\n    return segment.segment_softmax(graph.batch_num_edges(etype), x)\n\n\ndef broadcast_nodes(graph, graph_feat, *, ntype=None):\n    \"\"\"Generate a node feature equal to the graph-level feature :attr:`graph_feat`.\n\n    The operation is similar to ``numpy.repeat`` (or ``torch.repeat_interleave``).\n    It is commonly used to normalize node features by a global vector. 
For example,\n    to normalize node features across graph to range :math:`[0~1)`:\n\n    >>> g = dgl.batch([...])  # batch multiple graphs\n    >>> g.ndata['h'] = ...  # some node features\n    >>> h_sum = dgl.broadcast_nodes(g, dgl.sum_nodes(g, 'h'))\n    >>> g.ndata['h'] /= h_sum  # normalize by summation\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph.\n    graph_feat : tensor\n        The feature to broadcast. Tensor shape is :math:`(B, *)` for batched graph,\n        where :math:`B` is the batch size.\n\n    ntype : str, optional\n        Node type. Can be omitted if there is only one node type.\n\n    Returns\n    -------\n    Tensor\n        The node features tensor with shape :math:`(N, *)`, where :math:`N` is the\n        number of nodes.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch as th\n\n    Create two :class:`~dgl.DGLGraph` objects and initialize their\n    node features.\n\n    >>> g1 = dgl.graph(([0], [1]))                    # Graph 1\n    >>> g2 = dgl.graph(([0, 1], [1, 2]))              # Graph 2\n    >>> bg = dgl.batch([g1, g2])\n    >>> feat = th.rand(2, 5)\n    >>> feat\n    tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],\n            [0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])\n\n    Broadcast feature to all nodes in the batched graph, feat[i] is broadcast to nodes\n    in the i-th example in the batch.\n\n    >>> dgl.broadcast_nodes(bg, feat)\n    tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],\n            [0.4325, 0.7710, 0.5541, 0.0544, 0.9368],\n            [0.2721, 0.4629, 0.7269, 0.0724, 0.1014],\n            [0.2721, 0.4629, 0.7269, 0.0724, 0.1014],\n            [0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])\n\n    Broadcast feature to all nodes in the single graph (the feature tensor shape\n    to broadcast should be :math:`(1, *)`).\n\n    >>> feat0 = th.unsqueeze(feat[0], 0)\n    >>> dgl.broadcast_nodes(g1, feat0)\n    tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],\n         
   [0.4325, 0.7710, 0.5541, 0.0544, 0.9368]])\n\n    See Also\n    --------\n    broadcast_edges\n    \"\"\"\n    if F.shape(graph_feat)[0] != graph.batch_size and graph.batch_size == 1:\n        dgl_warning(\n            \"For a single graph, use a tensor of shape (1, *) for graph_feat.\"\n            \" The support of shape (*) will be deprecated.\"\n        )\n        graph_feat = F.unsqueeze(graph_feat, dim=0)\n    return F.repeat(graph_feat, graph.batch_num_nodes(ntype), dim=0)\n\n\ndef broadcast_edges(graph, graph_feat, *, etype=None):\n    \"\"\"Generate an edge feature equal to the graph-level feature :attr:`graph_feat`.\n\n    The operation is similar to ``numpy.repeat`` (or ``torch.repeat_interleave``).\n    It is commonly used to normalize edge features by a global vector. For example,\n    to normalize edge features across graph to range :math:`[0~1)`:\n\n    >>> g = dgl.batch([...])  # batch multiple graphs\n    >>> g.edata['h'] = ...  # some edge features\n    >>> h_sum = dgl.broadcast_edges(g, dgl.sum_edges(g, 'h'))\n    >>> g.edata['h'] /= h_sum  # normalize by summation\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph.\n    graph_feat : tensor\n        The feature to broadcast. Tensor shape is :math:`(B, *)` for batched graph,\n        where :math:`B` is the batch size.\n    etype : str, tuple of str, optional\n        Edge type. 
Can be omitted if there is only one edge type in the graph.\n\n    Returns\n    -------\n    Tensor\n        The edge features tensor with shape :math:`(M, *)`, where :math:`M` is the\n        number of edges.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch as th\n\n    Create two :class:`~dgl.DGLGraph` objects and initialize their\n    edge features.\n\n    >>> g1 = dgl.graph(([0], [1]))                    # Graph 1\n    >>> g2 = dgl.graph(([0, 1], [1, 2]))              # Graph 2\n    >>> bg = dgl.batch([g1, g2])\n    >>> feat = th.rand(2, 5)\n    >>> feat\n    tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],\n            [0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])\n\n    Broadcast feature to all edges in the batched graph, feat[i] is broadcast to edges\n    in the i-th example in the batch.\n\n    >>> dgl.broadcast_edges(bg, feat)\n    tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],\n            [0.2721, 0.4629, 0.7269, 0.0724, 0.1014],\n            [0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])\n\n    Broadcast feature to all edges in the single graph (the feature tensor shape\n    to broadcast should be :math:`(1, *)`).\n\n    >>> feat1 = th.unsqueeze(feat[1], 0)\n    >>> dgl.broadcast_edges(g2, feat1)\n    tensor([[0.2721, 0.4629, 0.7269, 0.0724, 0.1014],\n            [0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])\n\n    See Also\n    --------\n    broadcast_nodes\n    \"\"\"\n    if F.shape(graph_feat)[0] != graph.batch_size and graph.batch_size == 1:\n        dgl_warning(\n            \"For a single graph, use a tensor of shape (1, *) for graph_feat.\"\n            \" The support of shape (*) will be deprecated.\"\n        )\n        graph_feat = F.unsqueeze(graph_feat, dim=0)\n    return F.repeat(graph_feat, graph.batch_num_edges(etype), dim=0)\n\n\nREADOUT_ON_ATTRS = {\n    \"nodes\": (\"ndata\", \"batch_num_nodes\", \"number_of_nodes\"),\n    \"edges\": (\"edata\", \"batch_num_edges\", \"number_of_edges\"),\n}\n\n\ndef 
_topk_torch(keys, k, descending, x):\n    \"\"\"Internal function to take graph-wise top-k node/edge features according to\n    the rank given by keys, this function is PyTorch only.\n\n    Parameters\n    ----------\n    keys : Tensor\n        The key for ranking.\n    k : int\n        The :math:`k` in \"top-:math:`k`\".\n    descending : bool\n        Indicates whether to return the feature corresponding to largest or\n        smallest elements.\n    x : Tensor\n        The padded feature with shape (batch, max_len, *)\n\n    Returns\n    -------\n    sorted_feat : Tensor\n        A tensor with shape :math:`(batch, k, *)`.\n    sorted_idx : Tensor\n        A tensor with shape :math:`(batch, k)`.\n    \"\"\"\n    import torch as th\n\n    batch_size, max_len = x.shape[0], x.shape[1]\n    topk_indices = keys.topk(k, -1, largest=descending)[1]  # (batch_size, k)\n    x = x.view((batch_size * max_len), -1)\n    shift = (\n        th.arange(0, batch_size, device=x.device).view(batch_size, 1) * max_len\n    )\n    topk_indices_ = topk_indices + shift\n    x = x[topk_indices_].view(batch_size, k, -1)\n    return th.masked_fill(x, th.isinf(x), 0), topk_indices\n\n\ndef _topk_on(graph, typestr, feat, k, descending, sortby, ntype_or_etype):\n    \"\"\"Internal function to take graph-wise top-k node/edge features of\n    field :attr:`feat` in :attr:`graph` ranked by keys at given\n    index :attr:`sortby`. 
If :attr:`descending` is set to False, return the\n    k smallest elements instead.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph\n    typestr : str\n        'nodes' or 'edges'\n    feat : str\n        The feature field name.\n    k : int\n        The :math:`k` in \"top-:math:`k`\".\n    descending : bool\n        Controls whether to return the largest or smallest elements,\n         defaults to True.\n    sortby : int\n        The key index we sort :attr:`feat` on, if set to None, we sort\n        the whole :attr:`feat`.\n    ntype_or_etype : str, tuple of str\n        Node/edge type.\n\n    Returns\n    -------\n    sorted_feat : Tensor\n        A tensor with shape :math:`(B, K, D)`, where\n        :math:`B` is the batch size of the input graph.\n    sorted_idx : Tensor\n        A tensor with shape :math:`(B, K)`(:math:`(B, K, D)` if sortby\n        is set to None), where\n        :math:`B` is the batch size of the input graph, :math:`D`\n        is the feature size.\n\n\n    Notes\n    -----\n    If an example has :math:`n` nodes/edges and :math:`n<k`, in the first\n    returned tensor the :math:`n+1` to :math:`k`th rows would be padded\n    with all zero; in the second returned tensor, the behavior of :math:`n+1`\n    to :math:`k`th elements is not defined.\n    \"\"\"\n    _, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]\n    data = getattr(graph, typestr)[ntype_or_etype].data\n    if F.ndim(data[feat]) > 2:\n        raise DGLError(\n            \"Only support {} feature `{}` with dimension less than or\"\n            \" equal to 2\".format(typestr, feat)\n        )\n    feat = data[feat]\n    hidden_size = F.shape(feat)[-1]\n    batch_num_objs = getattr(graph, batch_num_objs_attr)(ntype_or_etype)\n    batch_size = len(batch_num_objs)\n    length = max(max(F.asnumpy(batch_num_objs)), k)\n    fill_val = -float(\"inf\") if descending else float(\"inf\")\n    feat_ = F.pad_packed_tensor(\n        feat, batch_num_objs, fill_val, 
l_min=k\n    )  # (batch_size, l, d)\n\n    if F.backend_name == \"pytorch\" and sortby is not None:\n        # PyTorch's implementation of top-K\n        keys = feat_[..., sortby]  # (batch_size, l)\n        return _topk_torch(keys, k, descending, feat_)\n    else:\n        # Fallback to framework-agnostic implementation of top-K\n        if sortby is not None:\n            keys = F.squeeze(F.slice_axis(feat_, -1, sortby, sortby + 1), -1)\n            order = F.argsort(keys, -1, descending=descending)\n        else:\n            order = F.argsort(feat_, 1, descending=descending)\n        topk_indices = F.slice_axis(order, 1, 0, k)\n\n        if sortby is not None:\n            feat_ = F.reshape(feat_, (batch_size * length, -1))\n            shift = F.repeat(F.arange(0, batch_size) * length, k, -1)\n            shift = F.copy_to(shift, F.context(feat))\n            topk_indices_ = F.reshape(topk_indices, (-1,)) + shift\n        else:\n            feat_ = F.reshape(feat_, (-1,))\n            shift = F.repeat(\n                F.arange(0, batch_size), k * hidden_size, -1\n            ) * length * hidden_size + F.cat(\n                [F.arange(0, hidden_size)] * batch_size * k, -1\n            )\n            shift = F.copy_to(shift, F.context(feat))\n            topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift\n        out = F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1))\n        out = F.replace_inf_with_zero(out)\n        return out, topk_indices\n\n\ndef topk_nodes(graph, feat, k, *, descending=True, sortby=None, ntype=None):\n    \"\"\"Return a graph-level representation by a graph-wise top-k on\n    node features :attr:`feat` in :attr:`graph` by feature at index :attr:`sortby`.\n\n    If :attr:`descending` is set to False, return the k smallest elements instead.\n\n    If :attr:`sortby` is set to None, the function would perform top-k on\n    all dimensions independently, equivalent to calling\n    
:code:`torch.topk(graph.ndata[feat], dim=0)`.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph.\n    feat : str\n        The feature field.\n    k : int\n        The k in \"top-k\"\n    descending : bool\n        Controls whether to return the largest or smallest elements.\n    sortby : int, optional\n        Sort according to which feature. If is None, all features are sorted independently.\n    ntype : str, optional\n        Node type. Can be omitted if there is only one node type in the graph.\n\n    Returns\n    -------\n    sorted_feat : Tensor\n        A tensor with shape :math:`(B, K, D)`, where\n        :math:`B` is the batch size of the input graph.\n    sorted_idx : Tensor\n        A tensor with shape :math:`(B, K)`(:math:`(B, K, D)` if sortby\n        is set to None), where\n        :math:`B` is the batch size of the input graph, :math:`D`\n        is the feature size.\n\n    Notes\n    -----\n    If an example has :math:`n` nodes and :math:`n<k`, the ``sorted_feat``\n    tensor will pad the :math:`n+1` to :math:`k` th rows with zero;\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch as th\n\n    Create two :class:`~dgl.DGLGraph` objects and initialize their\n    node features.\n\n    >>> g1 = dgl.graph(([0, 1], [2, 3]))              # Graph 1\n    >>> g1.ndata['h'] = th.rand(4, 5)\n    >>> g1.ndata['h']\n    tensor([[0.0297, 0.8307, 0.9140, 0.6702, 0.3346],\n            [0.5901, 0.3030, 0.9280, 0.6893, 0.7997],\n            [0.0880, 0.6515, 0.4451, 0.7507, 0.5297],\n            [0.5171, 0.6379, 0.2695, 0.8954, 0.5197]])\n\n    >>> g2 = dgl.graph(([0, 1, 2], [2, 3, 4]))       # Graph 2\n    >>> g2.ndata['h'] = th.rand(5, 5)\n    >>> g2.ndata['h']\n    tensor([[0.3168, 0.3174, 0.5303, 0.0804, 0.3808],\n            [0.1323, 0.2766, 0.4318, 0.6114, 0.1458],\n            [0.1752, 0.9105, 0.5692, 0.8489, 0.0539],\n            [0.1931, 0.4954, 0.3455, 0.3934, 0.0857],\n            [0.5065, 0.5182, 0.5418, 
0.1520, 0.3872]])\n\n    Top-k over node attribute :attr:`h` in a batched graph.\n\n    >>> bg = dgl.batch([g1, g2], ndata=['h'])\n    >>> dgl.topk_nodes(bg, 'h', 3)\n    (tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],\n              [0.5171, 0.6515, 0.9140, 0.7507, 0.5297],\n              [0.0880, 0.6379, 0.4451, 0.6893, 0.5197]],\n             [[0.5065, 0.9105, 0.5692, 0.8489, 0.3872],\n              [0.3168, 0.5182, 0.5418, 0.6114, 0.3808],\n              [0.1931, 0.4954, 0.5303, 0.3934, 0.1458]]]), tensor([[[1, 0, 1, 3, 1],\n              [3, 2, 0, 2, 2],\n              [2, 3, 2, 1, 3]],\n             [[4, 2, 2, 2, 4],\n              [0, 4, 4, 1, 0],\n              [3, 3, 0, 3, 1]]]))\n\n    Top-k over node attribute :attr:`h` along the last dimension in a batched graph.\n    (used in SortPooling)\n\n    >>> dgl.topk_nodes(bg, 'h', 3, sortby=-1)\n    (tensor([[[0.5901, 0.3030, 0.9280, 0.6893, 0.7997],\n              [0.0880, 0.6515, 0.4451, 0.7507, 0.5297],\n              [0.5171, 0.6379, 0.2695, 0.8954, 0.5197]],\n             [[0.5065, 0.5182, 0.5418, 0.1520, 0.3872],\n              [0.3168, 0.3174, 0.5303, 0.0804, 0.3808],\n              [0.1323, 0.2766, 0.4318, 0.6114, 0.1458]]]), tensor([[1, 2, 3],\n             [4, 0, 1]]))\n\n    Top-k over node attribute :attr:`h` in a single graph.\n\n    >>> dgl.topk_nodes(g1, 'h', 3)\n    (tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],\n              [0.5171, 0.6515, 0.9140, 0.7507, 0.5297],\n              [0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]]), tensor([[[1, 0, 1, 3, 1],\n              [3, 2, 0, 2, 2],\n              [2, 3, 2, 1, 3]]]))\n    \"\"\"\n    return _topk_on(\n        graph,\n        \"nodes\",\n        feat,\n        k,\n        descending=descending,\n        sortby=sortby,\n        ntype_or_etype=ntype,\n    )\n\n\ndef topk_edges(graph, feat, k, *, descending=True, sortby=None, etype=None):\n    \"\"\"Return a graph-level representation by a graph-wise top-k\n    on edge features 
:attr:`feat` in :attr:`graph` by feature at index :attr:`sortby`.\n\n    If :attr:`descending` is set to False, return the k smallest elements instead.\n\n    If :attr:`sortby` is set to None, the function would perform top-k on\n    all dimensions independently, equivalent to calling\n    :code:`torch.topk(graph.edata[feat], dim=0)`.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph.\n    feat : str\n        The feature field.\n    k : int\n        The k in \"top-k\"\n    descending : bool\n        Controls whether to return the largest or smallest elements.\n    sortby : int, optional\n        Sort according to which feature. If is None, all features are sorted independently.\n    etype : str, tuple of str, optional\n        Edge type. Can be omitted if there is only one edge type in the graph.\n\n    Returns\n    -------\n    sorted_feat : Tensor\n        A tensor with shape :math:`(B, K, D)`, where\n        :math:`B` is the batch size of the input graph.\n    sorted_idx : Tensor\n        A tensor with shape :math:`(B, K)`(:math:`(B, K, D)` if sortby\n        is set to None), where\n        :math:`B` is the batch size of the input graph, :math:`D`\n        is the feature size.\n\n\n    Notes\n    -----\n    If an example has :math:`n` edges and :math:`n<k`, the ``sorted_feat``\n    tensor will pad the :math:`n+1` to :math:`k` th rows with zero.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch as th\n\n    Create two :class:`~dgl.DGLGraph` objects and initialize their\n    edge features.\n\n    >>> g1 = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))         # Graph 1\n    >>> g1.edata['h'] = th.rand(4, 5)\n    >>> g1.edata['h']\n    tensor([[0.0297, 0.8307, 0.9140, 0.6702, 0.3346],\n            [0.5901, 0.3030, 0.9280, 0.6893, 0.7997],\n            [0.0880, 0.6515, 0.4451, 0.7507, 0.5297],\n            [0.5171, 0.6379, 0.2695, 0.8954, 0.5197]])\n\n    >>> g2 = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]))   # Graph 2\n    
>>> g2.edata['h'] = th.rand(5, 5)\n    >>> g2.edata['h']\n    tensor([[0.3168, 0.3174, 0.5303, 0.0804, 0.3808],\n            [0.1323, 0.2766, 0.4318, 0.6114, 0.1458],\n            [0.1752, 0.9105, 0.5692, 0.8489, 0.0539],\n            [0.1931, 0.4954, 0.3455, 0.3934, 0.0857],\n            [0.5065, 0.5182, 0.5418, 0.1520, 0.3872]])\n\n    Top-k over edge attribute :attr:`h` in a batched graph.\n\n    >>> bg = dgl.batch([g1, g2], edata=['h'])\n    >>> dgl.topk_edges(bg, 'h', 3)\n    (tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],\n              [0.5171, 0.6515, 0.9140, 0.7507, 0.5297],\n              [0.0880, 0.6379, 0.4451, 0.6893, 0.5197]],\n             [[0.5065, 0.9105, 0.5692, 0.8489, 0.3872],\n              [0.3168, 0.5182, 0.5418, 0.6114, 0.3808],\n              [0.1931, 0.4954, 0.5303, 0.3934, 0.1458]]]), tensor([[[1, 0, 1, 3, 1],\n              [3, 2, 0, 2, 2],\n              [2, 3, 2, 1, 3]],\n             [[4, 2, 2, 2, 4],\n              [0, 4, 4, 1, 0],\n              [3, 3, 0, 3, 1]]]))\n\n    Top-k over edge attribute :attr:`h` along index -1 in a batched graph.\n    (used in SortPooling)\n\n    >>> dgl.topk_edges(bg, 'h', 3, sortby=-1)\n    (tensor([[[0.5901, 0.3030, 0.9280, 0.6893, 0.7997],\n              [0.0880, 0.6515, 0.4451, 0.7507, 0.5297],\n              [0.5171, 0.6379, 0.2695, 0.8954, 0.5197]],\n             [[0.5065, 0.5182, 0.5418, 0.1520, 0.3872],\n              [0.3168, 0.3174, 0.5303, 0.0804, 0.3808],\n              [0.1323, 0.2766, 0.4318, 0.6114, 0.1458]]]), tensor([[1, 2, 3],\n             [4, 0, 1]]))\n\n    Top-k over edge attribute :attr:`h` in a single graph.\n\n    >>> dgl.topk_edges(g1, 'h', 3)\n    (tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997],\n              [0.5171, 0.6515, 0.9140, 0.7507, 0.5297],\n              [0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]]), tensor([[[1, 0, 1, 3, 1],\n              [3, 2, 0, 2, 2],\n              [2, 3, 2, 1, 3]]]))\n    \"\"\"\n    return _topk_on(\n        graph,\n        
\"edges\",\n        feat,\n        k,\n        descending=descending,\n        sortby=sortby,\n        ntype_or_etype=etype,\n    )\n"
  },
  {
    "path": "python/dgl/sampling/__init__.py",
    "content": "\"\"\"The ``dgl.sampling`` package contains operators and utilities for\nsampling from a graph via random walks, neighbor sampling, etc. They\nare typically used together with the ``DataLoader`` s in the\n``dgl.dataloading`` package. The user guide :ref:`guide-minibatch`\ngives a holistic explanation on how different components work together.\n\"\"\"\n\nfrom .randomwalks import *\nfrom .pinsage import *\nfrom .neighbor import *\nfrom .labor import *\nfrom .node2vec_randomwalk import *\nfrom .negative import *\nfrom . import utils\n"
  },
  {
    "path": "python/dgl/sampling/labor.py",
    "content": "#\n#   Copyright (c) 2022 by Contributors\n#\n#   Licensed under the Apache License, Version 2.0 (the \"License\");\n#   you may not use this file except in compliance with the License.\n#   You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#   Unless required by applicable law or agreed to in writing, software\n#   distributed under the License is distributed on an \"AS IS\" BASIS,\n#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#   See the License for the specific language governing permissions and\n#   limitations under the License.\n#\n#   Based off of neighbor.py\n#\n\n\n\"\"\"Labor sampling APIs\"\"\"\n\nfrom .. import backend as F, ndarray as nd, utils\nfrom .._ffi.function import _init_api\nfrom ..base import DGLError\nfrom ..heterograph import DGLGraph\nfrom ..random import choice\nfrom .utils import EidExcluder\n\n__all__ = [\"sample_labors\"]\n\n\ndef sample_labors(\n    g,\n    nodes,\n    fanout,\n    edge_dir=\"in\",\n    prob=None,\n    importance_sampling=0,\n    random_seed=None,\n    seed2_contribution=0,\n    copy_ndata=True,\n    copy_edata=True,\n    exclude_edges=None,\n    output_device=None,\n):\n    \"\"\"Sampler that builds computational dependency of node representations via\n    labor sampling for multilayer GNN from the NeurIPS 2023 paper\n    `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs\n    <https://arxiv.org/abs/2210.13339>`__\n\n    This sampler will make every node gather messages from a fixed number of neighbors\n    per edge type. The neighbors are picked uniformly with default parameters. For every vertex t\n    that will be considered to be sampled, there will be a single random variate r_t.\n\n    For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges\n    will be randomly chosen.  
The graph returned will then contain all the nodes in the\n    original graph, but only the sampled edges.\n\n    Node/edge features are not preserved. The original IDs of\n    the sampled edges are stored as the `dgl.EID` feature in the returned graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph, allowed to have multiple node or edge types. Can be either on CPU or GPU.\n    nodes : tensor or dict\n        Node IDs to sample neighbors from.\n\n        This argument can take a single ID tensor or a dictionary of node types and ID tensors.\n        If a single tensor is given, the graph must only have one type of nodes.\n    fanout : int or dict[etype, int]\n        The number of edges to be sampled for each node on each edge type.\n\n        This argument can take a single int or a dictionary of edge types and ints.\n        If a single int is given, DGL will sample this number of edges for each node for\n        every edge type.\n\n        If -1 is given for a single edge type, all the neighboring edges with that edge\n        type will be selected.\n    edge_dir : str, optional\n        Determines whether to sample inbound or outbound edges.\n\n        Can take either ``in`` for inbound edges or ``out`` for outbound edges.\n    prob : str, optional\n        Feature name used as the (unnormalized) probabilities associated with each\n        neighboring edge of a node.  The feature must have only one element for each\n        edge.\n\n        The features must be non-negative floats, and the sum of the features of\n        inbound/outbound edges for every node must be positive (though they don't have\n        to sum up to one).  
Otherwise, the result will be undefined.\n\n        If :attr:`prob` is not None, GPU sampling is not supported.\n    importance_sampling : int, optional\n        Whether to use importance sampling or uniform sampling, use of negative values optimizes\n        importance sampling probabilities until convergence while use of positive values runs\n        optimization steps that many times. If the value is i, then LABOR-i variant is used.\n    random_seed : tensor\n        An int64 tensor with one element.\n\n        The passed random_seed makes it so that for any seed vertex ``s`` and its neighbor ``t``,\n        the rolled random variate ``r_t`` is the same for any call to this function with the same\n        random seed. When sampling as part of the same batch, one would want identical seeds so that\n        LABOR can globally sample. One example is that for heterogenous graphs, there is a single\n        random seed passed for each edge type. This will sample much fewer vertices compared to\n        having unique random seeds for each edge type. If one called this function individually for\n        each edge type for a heterogenous graph with different random seeds, then it would run LABOR\n        locally for each edge type, resulting into a larger number of vertices being sampled.\n\n        If this function is called without a ``random_seed``, we get the random seed by getting a\n        random number from DGL. Use this argument with identical random_seed if multiple calls to\n        this function are used to sample as part of a single batch.\n    seed2_contribution : float, optional\n        A float value between [0, 1) that determines the contribution\n        of the second random seed to generate the random variates for the\n        LABOR sampling algorithm.\n    copy_ndata: bool, optional\n        If True, the node features of the new graph are copied from\n        the original graph. 
If False, the new graph will not have any\n        node features.\n\n        (Default: True)\n    copy_edata: bool, optional\n        If True, the edge features of the new graph are copied from\n        the original graph.  If False, the new graph will not have any\n        edge features.\n\n        (Default: True)\n    exclude_edges: tensor or dict\n        Edge IDs to exclude during sampling neighbors for the seed nodes.\n\n        This argument can take a single ID tensor or a dictionary of edge types and ID tensors.\n        If a single tensor is given, the graph must only have one type of nodes.\n    output_device : Framework-specific device context object, optional\n        The output device.  Default is the same as the input graph.\n\n    Returns\n    -------\n    tuple(DGLGraph, list[Tensor])\n        A sampled subgraph containing only the sampled neighboring edges along with edge weights.\n\n    Notes\n    -----\n    If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as\n    the node or edge features of the original graph and the new graph.\n    As a result, users should avoid performing in-place operations\n    on the node features of the new graph to avoid feature corruption.\n\n    Examples\n    --------\n    Assume that you have the following graph\n\n    >>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))\n\n    And the weights\n\n    >>> g.edata['prob'] = torch.FloatTensor([0., 1., 0., 1., 0., 1.])\n\n    To sample one inbound edge for node 0 and node 1:\n\n    >>> sg = dgl.sampling.sample_labors(g, [0, 1], 1)\n    >>> sg.edges(order='eid')\n    (tensor([1, 0]), tensor([0, 1]))\n    >>> sg.edata[dgl.EID]\n    tensor([2, 0])\n\n    To sample one inbound edge for node 0 and node 1 with probability in edge feature\n    ``prob``:\n\n    >>> sg = dgl.sampling.sample_labors(g, [0, 1], 1, prob='prob')\n    >>> sg.edges(order='eid')\n    (tensor([2, 1]), tensor([0, 1]))\n\n    With ``fanout`` greater than the number of actual 
neighbors and without replacement,\n    DGL will take all neighbors instead:\n\n    >>> sg = dgl.sampling.sample_labors(g, [0, 1], 3)\n    >>> sg.edges(order='eid')\n    (tensor([1, 2, 0, 1]), tensor([0, 0, 1, 1]))\n\n    To exclude certain EID's during sampling for the seed nodes:\n\n    >>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))\n    >>> g_edges = g.all_edges(form='all')``\n    (tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5]))\n    >>> sg = dgl.sampling.sample_labors(g, [0, 1], 3, exclude_edges=[0, 1, 2])\n    >>> sg.all_edges(form='all')\n    (tensor([2, 1]), tensor([0, 1]), tensor([0, 1]))\n    >>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3])\n    tensor([False, False, False])\n    >>> g = dgl.heterograph({\n    ...   ('drug', 'interacts', 'drug'): ([0, 0, 1, 1, 3, 2], [1, 2, 0, 1, 2, 0]),\n    ...   ('drug', 'interacts', 'gene'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]),\n    ...   ('drug', 'treats', 'disease'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0])})\n    >>> g_edges = g.all_edges(form='all', etype=('drug', 'interacts', 'drug'))\n    (tensor([0, 0, 1, 1, 3, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5]))\n    >>> excluded_edges  = {('drug', 'interacts', 'drug'): g_edges[2][:3]}\n    >>> sg = dgl.sampling.sample_labors(g, {'drug':[0, 1]}, 3, exclude_edges=excluded_edges)\n    >>> sg.all_edges(form='all', etype=('drug', 'interacts', 'drug'))\n    (tensor([2, 1]), tensor([0, 1]), tensor([0, 1]))\n    >>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3],etype=('drug', 'interacts', 'drug'))\n    tensor([False, False, False])\n\n    \"\"\"\n    if F.device_type(g.device) == \"cpu\" and not g.is_pinned():\n        frontier, importances = _sample_labors(\n            g,\n            nodes,\n            fanout,\n            edge_dir=edge_dir,\n            prob=prob,\n            importance_sampling=importance_sampling,\n            random_seed=random_seed,\n            
seed2_contribution=seed2_contribution,\n            copy_ndata=copy_ndata,\n            copy_edata=copy_edata,\n            exclude_edges=exclude_edges,\n        )\n    else:\n        frontier, importances = _sample_labors(\n            g,\n            nodes,\n            fanout,\n            edge_dir=edge_dir,\n            prob=prob,\n            importance_sampling=importance_sampling,\n            random_seed=random_seed,\n            seed2_contribution=seed2_contribution,\n            copy_ndata=copy_ndata,\n            copy_edata=copy_edata,\n        )\n        if exclude_edges is not None:\n            eid_excluder = EidExcluder(exclude_edges)\n            frontier, importances = eid_excluder(frontier, importances)\n    if output_device is None:\n        return (frontier, importances)\n    else:\n        return (\n            frontier.to(output_device),\n            list(map(lambda x: x.to(output_device), importances)),\n        )\n\n\ndef _sample_labors(\n    g,\n    nodes,\n    fanout,\n    edge_dir=\"in\",\n    prob=None,\n    importance_sampling=0,\n    random_seed=None,\n    seed2_contribution=0,\n    copy_ndata=True,\n    copy_edata=True,\n    exclude_edges=None,\n):\n    if random_seed is None:\n        random_seed = F.to_dgl_nd(choice(1e18, 1))\n    if not isinstance(nodes, dict):\n        if len(g.ntypes) > 1:\n            raise DGLError(\n                \"Must specify node type when the graph is not homogeneous.\"\n            )\n        nodes = {g.ntypes[0]: nodes}\n\n    nodes = utils.prepare_tensor_dict(g, nodes, \"nodes\")\n    if len(nodes) == 0:\n        raise ValueError(\n            \"Got an empty dictionary in the nodes argument. 
\"\n            \"Please pass in a dictionary with empty tensors as values instead.\"\n        )\n    ctx = utils.to_dgl_context(F.context(next(iter(nodes.values()))))\n    nodes_all_types = []\n    # nids_all_types is needed if one wants labor to work for subgraphs whose vertices have\n    # been renamed and the rolled randoms should be rolled for global vertex ids.\n    # It is disabled for now below by passing empty ndarrays.\n    nids_all_types = [nd.array([], ctx=ctx) for _ in g.ntypes]\n    for ntype in g.ntypes:\n        if ntype in nodes:\n            nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))\n        else:\n            nodes_all_types.append(nd.array([], ctx=ctx))\n\n    if isinstance(fanout, nd.NDArray):\n        fanout_array = fanout\n    else:\n        if not isinstance(fanout, dict):\n            fanout_array = [int(fanout)] * len(g.etypes)\n        else:\n            if len(fanout) != len(g.etypes):\n                raise DGLError(\n                    \"Fan-out must be specified for each edge type \"\n                    \"if a dict is provided.\"\n                )\n            fanout_array = [None] * len(g.etypes)\n            for etype, value in fanout.items():\n                fanout_array[g.get_etype_id(etype)] = value\n        fanout_array = F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64))\n\n    if (\n        isinstance(prob, list)\n        and len(prob) > 0\n        and isinstance(prob[0], nd.NDArray)\n    ):\n        prob_arrays = prob\n    elif prob is None:\n        prob_arrays = [nd.array([], ctx=nd.cpu())] * len(g.etypes)\n    else:\n        prob_arrays = []\n        for etype in g.canonical_etypes:\n            if prob in g.edges[etype].data:\n                prob_arrays.append(F.to_dgl_nd(g.edges[etype].data[prob]))\n            else:\n                prob_arrays.append(nd.array([], ctx=nd.cpu()))\n\n    excluded_edges_all_t = []\n    if exclude_edges is not None:\n        if not isinstance(exclude_edges, dict):\n            
if len(g.etypes) > 1:\n                raise DGLError(\n                    \"Must specify etype when the graph is not homogeneous.\"\n                )\n            exclude_edges = {g.canonical_etypes[0]: exclude_edges}\n        exclude_edges = utils.prepare_tensor_dict(g, exclude_edges, \"edges\")\n        for etype in g.canonical_etypes:\n            if etype in exclude_edges:\n                excluded_edges_all_t.append(F.to_dgl_nd(exclude_edges[etype]))\n            else:\n                excluded_edges_all_t.append(nd.array([], ctx=ctx))\n\n    ret_val = _CAPI_DGLSampleLabors(\n        g._graph,\n        nodes_all_types,\n        fanout_array,\n        edge_dir,\n        prob_arrays,\n        excluded_edges_all_t,\n        importance_sampling,\n        random_seed,\n        seed2_contribution,\n        nids_all_types,\n    )\n    subgidx = ret_val[0]\n    importances = [F.from_dgl_nd(importance) for importance in ret_val[1:]]\n    induced_edges = subgidx.induced_edges\n    ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)\n\n    if copy_ndata:\n        node_frames = utils.extract_node_subframes(g, None)\n        utils.set_new_frames(ret, node_frames=node_frames)\n\n    if copy_edata:\n        edge_frames = utils.extract_edge_subframes(g, induced_edges)\n        utils.set_new_frames(ret, edge_frames=edge_frames)\n\n    return ret, importances\n\n\nDGLGraph.sample_labors = utils.alias_func(sample_labors)\n\n_init_api(\"dgl.sampling.labor\", __name__)\n"
  },
  {
    "path": "python/dgl/sampling/negative.py",
    "content": "\"\"\"Negative sampling APIs\"\"\"\n\nfrom numpy.polynomial import polynomial\n\nfrom .. import backend as F, utils\nfrom .._ffi.function import _init_api\nfrom ..heterograph import DGLGraph\n\n__all__ = [\"global_uniform_negative_sampling\"]\n\n\ndef _calc_redundancy(\n    k_hat, num_edges, num_pairs, r=3\n):  # pylint: disable=invalid-name\n    # pylint: disable=invalid-name\n    # Calculates the number of samples required based on a lower-bound\n    # of the expected number of negative samples, based on N draws from\n    # a binomial distribution.  Solves the following equation for N:\n    #\n    # k_hat = N*p_k - r * np.sqrt(N*p_k*(1-p_k))\n    #\n    # where p_k is the probability that a node pairing is a negative edge\n    # and r is the number of standard deviations to construct the lower bound\n    #\n    # Credits to @zjost\n    p_m = num_edges / num_pairs\n    p_k = 1 - p_m\n\n    a = p_k**2\n    b = -p_k * (2 * k_hat + r**2 * p_m)\n    c = k_hat**2\n\n    poly = polynomial.Polynomial([c, b, a])\n    N = poly.roots()[-1]\n    redundancy = N / k_hat - 1.0\n    return redundancy\n\n\ndef global_uniform_negative_sampling(\n    g,\n    num_samples,\n    exclude_self_loops=True,\n    replace=False,\n    etype=None,\n    redundancy=None,\n):\n    \"\"\"Performs negative sampling, which generate source-destination pairs such that\n    edges with the given type do not exist.\n\n    Specifically, this function takes in an edge type and a number of samples.  It\n    returns two tensors ``src`` and ``dst``, the former in the range of ``[0, num_src)``\n    and the latter in the range of ``[0, num_dst)``, where ``num_src`` and ``num_dst``\n    represents the number of nodes with the source and destination node type respectively.\n    It guarantees that no edge will exist between the corresponding pairs of ``src``\n    with the source node type and ``dst`` with the destination node type.\n\n    .. 
note::\n\n       This negative sampler will try to generate as many negative samples as possible, but\n       it may rarely return less than :attr:`num_samples` negative samples.\n       This is more likely to happen when a graph is so small or dense that not many\n       unique negative samples exist.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    num_samples : int\n        The number of desired negative samples to generate.\n    exclude_self_loops : bool, optional\n        Whether to exclude self-loops from the negative samples.  Only impacts the\n        edge types whose source and destination node types are the same.\n\n        Default: True.\n    replace : bool, optional\n        Whether to sample with replacement.  Setting it to True will make things\n        faster.  (Default: False)\n    etype : str or tuple of str, optional\n        The edge type.  Can be omitted if the graph only has one edge type.\n    redundancy : float, optional\n        Indicates how much more negative samples to actually generate during rejection sampling\n        before finding the unique pairs.\n\n        Increasing it will increase the likelihood of getting :attr:`num_samples` negative\n        samples, but will also take more time and memory.\n\n        (Default: automatically determined by the density of graph)\n\n    Returns\n    -------\n    tuple[Tensor, Tensor]\n        The source and destination pairs.\n\n    Examples\n    --------\n    >>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))\n    >>> dgl.sampling.global_uniform_negative_sampling(g, 3)\n    (tensor([0, 1, 3]), tensor([2, 0, 2]))\n    \"\"\"\n    if etype is None:\n        etype = g.etypes[0]\n    utype, _, vtype = g.to_canonical_etype(etype)\n    exclude_self_loops = exclude_self_loops and (utype == vtype)\n\n    redundancy = _calc_redundancy(\n        num_samples, g.num_edges(etype), g.num_nodes(utype) * g.num_nodes(vtype)\n    )\n\n    etype_id = g.get_etype_id(etype)\n    src, dst = 
_CAPI_DGLGlobalUniformNegativeSampling(\n        g._graph,\n        etype_id,\n        num_samples,\n        3,\n        exclude_self_loops,\n        replace,\n        redundancy,\n    )\n    return F.from_dgl_nd(src), F.from_dgl_nd(dst)\n\n\nDGLGraph.global_uniform_negative_sampling = utils.alias_func(\n    global_uniform_negative_sampling\n)\n\n_init_api(\"dgl.sampling.negative\", __name__)\n"
  },
  {
    "path": "python/dgl/sampling/neighbor.py",
    "content": "\"\"\"Neighbor sampling APIs\"\"\"\n\nimport os\n\nimport torch\n\nfrom .. import backend as F, ndarray as nd, utils\nfrom .._ffi.function import _init_api\nfrom ..base import DGLError, EID\nfrom ..heterograph import DGLBlock, DGLGraph\nfrom .utils import EidExcluder\n\n__all__ = [\n    \"sample_etype_neighbors\",\n    \"sample_neighbors\",\n    \"sample_neighbors_fused\",\n    \"sample_neighbors_biased\",\n    \"select_topk\",\n]\n\n\ndef _prepare_edge_arrays(g, arg):\n    \"\"\"Converts the argument into a list of NDArrays.\n\n    If the argument is already a list of array-like objects, directly do the\n    conversion.\n\n    If the argument is a string, converts g.edata[arg] into a list of NDArrays\n    ordered by the edge types.\n    \"\"\"\n    if isinstance(arg, list) and len(arg) > 0:\n        if isinstance(arg[0], nd.NDArray):\n            return arg\n        else:\n            # The list can have None as placeholders for empty arrays with\n            # undetermined data type.\n            dtype = None\n            ctx = None\n            result = []\n            for entry in arg:\n                if F.is_tensor(entry):\n                    result.append(F.to_dgl_nd(entry))\n                    dtype = F.dtype(entry)\n                    ctx = F.context(entry)\n                else:\n                    result.append(None)\n\n            result = [\n                (\n                    F.to_dgl_nd(F.copy_to(F.tensor([], dtype=dtype), ctx))\n                    if x is None\n                    else x\n                )\n                for x in result\n            ]\n            return result\n    elif arg is None:\n        return [nd.array([], ctx=nd.cpu())] * len(g.etypes)\n    else:\n        arrays = []\n        for etype in g.canonical_etypes:\n            if arg in g.edges[etype].data:\n                arrays.append(F.to_dgl_nd(g.edges[etype].data[arg]))\n            else:\n                arrays.append(nd.array([], ctx=nd.cpu()))\n  
      return arrays\n\n\ndef sample_etype_neighbors(\n    g,\n    nodes,\n    etype_offset,\n    fanout,\n    edge_dir=\"in\",\n    prob=None,\n    exclude_edges=None,\n    replace=False,\n    copy_ndata=True,\n    copy_edata=True,\n    etype_sorted=False,\n    _dist_training=False,\n    output_device=None,\n):\n    \"\"\"Sample neighboring edges of the given nodes and return the induced subgraph.\n\n    For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges\n    will be randomly chosen.  The graph returned will then contain all the nodes in the\n    original graph, but only the sampled edges.\n\n    Node/edge features are not preserved. The original IDs of\n    the sampled edges are stored as the `dgl.EID` feature in the returned graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.  Can only be in CPU. Should only have one node type and one edge type.\n    nodes : tensor or dict\n        Node IDs to sample neighbors from.\n\n        This argument can take a single ID tensor or a dictionary of node types and ID tensors.\n        If a single tensor is given, the graph must only have one type of nodes.\n    etype_offset : list[int]\n        The offset of each edge type ID.\n    fanout : Tensor\n        The number of edges to be sampled for each node per edge type.  Must be a\n        1D tensor with the number of elements same as the number of edge types.\n\n        If -1 is given, all of the neighbors with non-zero probability will be selected.\n    edge_dir : str, optional\n        Determines whether to sample inbound or outbound edges.\n\n        Can take either ``in`` for inbound edges or ``out`` for outbound edges.\n    prob : list[Tensor], optional\n        The (unnormalized) probabilities associated with each neighboring edge of\n        a node.\n\n        The features must be non-negative floats or boolean.  
Otherwise, the\n        result will be undefined.\n    exclude_edges: tensor or dict\n        Edge IDs to exclude during sampling neighbors for the seed nodes.\n\n        This argument can take a single ID tensor or a dictionary of edge types and ID tensors.\n        If a single tensor is given, the graph must only have one type of nodes.\n    replace : bool, optional\n        If True, sample with replacement.\n    copy_ndata: bool, optional\n        If True, the node features of the new graph are copied from\n        the original graph. If False, the new graph will not have any\n        node features.\n\n        (Default: True)\n    copy_edata: bool, optional\n        If True, the edge features of the new graph are copied from\n        the original graph.  If False, the new graph will not have any\n        edge features.\n\n        (Default: True)\n    _dist_training : bool, optional\n        Internal argument.  Do not use.\n\n        (Default: False)\n    etype_sorted: bool, optional\n        A hint telling whether the etypes are already sorted.\n\n        (Default: False)\n    output_device : Framework-specific device context object, optional\n        The output device.  
Default is the same as the input graph.\n\n    Returns\n    -------\n    DGLGraph\n        A sampled subgraph containing only the sampled neighboring edges, with the\n        same device as the input graph.\n\n    Notes\n    -----\n    If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as\n    the node or edge features of the original graph and the new graph.\n    As a result, users should avoid performing in-place operations\n    on the node features of the new graph to avoid feature corruption.\n    \"\"\"\n    if exclude_edges is not None:\n        raise DGLError(\n            \"exclude_edges is not supported for sample_etype_neighbors\"\n        )\n    if g.device != F.cpu():\n        raise DGLError(\"The graph should be in cpu.\")\n    # (BarclayII) because the homogenized graph no longer contains the *name* of edge\n    # types, the fanout argument can no longer be a dict of etypes and ints, as opposed\n    # to sample_neighbors.\n    if not F.is_tensor(fanout):\n        raise DGLError(\"The fanout should be a tensor\")\n    if isinstance(nodes, dict):\n        assert len(nodes) == 1, \"The input graph should not have node types\"\n        nodes = list(nodes.values())[0]\n\n    nodes = utils.prepare_tensor(g, nodes, \"nodes\")\n    device = utils.context_of(nodes)\n    nodes = F.to_dgl_nd(nodes)\n    # treat etypes as int32, it is much cheaper than int64\n    # TODO(xiangsx): int8 can be a better choice.\n    fanout = F.to_dgl_nd(fanout)\n\n    prob_array = _prepare_edge_arrays(g, prob)\n\n    subgidx = _CAPI_DGLSampleNeighborsEType(\n        g._graph,\n        nodes,\n        etype_offset,\n        fanout,\n        edge_dir,\n        prob_array,\n        replace,\n        etype_sorted,\n    )\n    induced_edges = subgidx.induced_edges\n    ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)\n\n    # handle features\n    # (TODO) (BarclayII) DGL distributed fails with bus error, freezes, or other\n    # incomprehensible errors with lazy 
feature copy.\n    # So in distributed training context, we fall back to old behavior where we\n    # only set the edge IDs.\n    if not _dist_training:\n        if copy_ndata:\n            node_frames = utils.extract_node_subframes(g, device)\n            utils.set_new_frames(ret, node_frames=node_frames)\n\n        if copy_edata:\n            edge_frames = utils.extract_edge_subframes(g, induced_edges)\n            utils.set_new_frames(ret, edge_frames=edge_frames)\n    else:\n        for i, etype in enumerate(ret.canonical_etypes):\n            ret.edges[etype].data[EID] = induced_edges[i]\n\n    return ret if output_device is None else ret.to(output_device)\n\n\nDGLGraph.sample_etype_neighbors = utils.alias_func(sample_etype_neighbors)\n\n\ndef sample_neighbors(\n    g,\n    nodes,\n    fanout,\n    edge_dir=\"in\",\n    prob=None,\n    replace=False,\n    copy_ndata=True,\n    copy_edata=True,\n    _dist_training=False,\n    exclude_edges=None,\n    output_device=None,\n):\n    \"\"\"Sample neighboring edges of the given nodes and return the induced subgraph.\n\n    For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges\n    will be randomly chosen.  The graph returned will then contain all the nodes in the\n    original graph, but only the sampled edges.\n\n    Node/edge features are not preserved. The original IDs of\n    the sampled edges are stored as the `dgl.EID` feature in the returned graph.\n\n    GPU sampling is supported for this function. Refer to :ref:`guide-minibatch-gpu-sampling`\n    for more details.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.  
Can be either on CPU or GPU.\n    nodes : tensor or dict\n        Node IDs to sample neighbors from.\n\n        This argument can take a single ID tensor or a dictionary of node types and ID tensors.\n        If a single tensor is given, the graph must only have one type of nodes.\n    fanout : int or dict[etype, int]\n        The number of edges to be sampled for each node on each edge type.\n\n        This argument can take a single int or a dictionary of edge types and ints.\n        If a single int is given, DGL will sample this number of edges for each node for\n        every edge type.\n\n        If -1 is given for a single edge type, all the neighboring edges with that edge\n        type and non-zero probability will be selected.\n    edge_dir : str, optional\n        Determines whether to sample inbound or outbound edges.\n\n        Can take either ``in`` for inbound edges or ``out`` for outbound edges.\n    prob : str, optional\n        Feature name used as the (unnormalized) probabilities associated with each\n        neighboring edge of a node.  The feature must have only one element for each\n        edge.\n\n        The features must be non-negative floats or boolean.  Otherwise, the result\n        will be undefined.\n    exclude_edges: tensor or dict\n        Edge IDs to exclude during sampling neighbors for the seed nodes.\n\n        This argument can take a single ID tensor or a dictionary of edge types and ID tensors.\n        If a single tensor is given, the graph must only have one type of nodes.\n    replace : bool, optional\n        If True, sample with replacement.\n    copy_ndata: bool, optional\n        If True, the node features of the new graph are copied from\n        the original graph. If False, the new graph will not have any\n        node features.\n\n        (Default: True)\n    copy_edata: bool, optional\n        If True, the edge features of the new graph are copied from\n        the original graph.  
If False, the new graph will not have any\n        edge features.\n\n        (Default: True)\n    _dist_training : bool, optional\n        Internal argument.  Do not use.\n\n        (Default: False)\n    output_device : Framework-specific device context object, optional\n        The output device.  Default is the same as the input graph.\n\n    Returns\n    -------\n    DGLGraph\n        A sampled subgraph containing only the sampled neighboring edges.\n\n    Notes\n    -----\n    If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as\n    the node or edge features of the original graph and the new graph.\n    As a result, users should avoid performing in-place operations\n    on the node features of the new graph to avoid feature corruption.\n\n    Examples\n    --------\n    Assume that you have the following graph\n\n    >>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))\n\n    And the weights\n\n    >>> g.edata['prob'] = torch.FloatTensor([0., 1., 0., 1., 0., 1.])\n\n    To sample one inbound edge for node 0 and node 1:\n\n    >>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 1)\n    >>> sg.edges(order='eid')\n    (tensor([1, 0]), tensor([0, 1]))\n    >>> sg.edata[dgl.EID]\n    tensor([2, 0])\n\n    To sample one inbound edge for node 0 and node 1 with probability in edge feature\n    ``prob``:\n\n    >>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 1, prob='prob')\n    >>> sg.edges(order='eid')\n    (tensor([2, 1]), tensor([0, 1]))\n\n    With ``fanout`` greater than the number of actual neighbors and without replacement,\n    DGL will take all neighbors instead:\n\n    >>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3)\n    >>> sg.edges(order='eid')\n    (tensor([1, 2, 0, 1]), tensor([0, 0, 1, 1]))\n\n    To exclude certain EID's during sampling for the seed nodes:\n\n    >>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))\n    >>> g_edges = g.all_edges(form='all')``\n    (tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 
1, 2, 0]), tensor([0, 1, 2, 3, 4, 5]))\n    >>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3, exclude_edges=[0, 1, 2])\n    >>> sg.all_edges(form='all')\n    (tensor([2, 1]), tensor([0, 1]), tensor([0, 1]))\n    >>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3])\n    tensor([False, False, False])\n    >>> g = dgl.heterograph({\n    ...   ('drug', 'interacts', 'drug'): ([0, 0, 1, 1, 3, 2], [1, 2, 0, 1, 2, 0]),\n    ...   ('drug', 'interacts', 'gene'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]),\n    ...   ('drug', 'treats', 'disease'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0])})\n    >>> g_edges = g.all_edges(form='all', etype=('drug', 'interacts', 'drug'))\n    (tensor([0, 0, 1, 1, 3, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5]))\n    >>> excluded_edges  = {('drug', 'interacts', 'drug'): g_edges[2][:3]}\n    >>> sg = dgl.sampling.sample_neighbors(g, {'drug':[0, 1]}, 3, exclude_edges=excluded_edges)\n    >>> sg.all_edges(form='all', etype=('drug', 'interacts', 'drug'))\n    (tensor([2, 1]), tensor([0, 1]), tensor([0, 1]))\n    >>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3],etype=('drug', 'interacts', 'drug'))\n    tensor([False, False, False])\n\n    \"\"\"\n    if F.device_type(g.device) == \"cpu\" and not g.is_pinned():\n        frontier = _sample_neighbors(\n            g,\n            nodes,\n            fanout,\n            edge_dir=edge_dir,\n            prob=prob,\n            replace=replace,\n            copy_ndata=copy_ndata,\n            copy_edata=copy_edata,\n            exclude_edges=exclude_edges,\n        )\n    else:\n        frontier = _sample_neighbors(\n            g,\n            nodes,\n            fanout,\n            edge_dir=edge_dir,\n            prob=prob,\n            replace=replace,\n            copy_ndata=copy_ndata,\n            copy_edata=copy_edata,\n        )\n        if exclude_edges is not None:\n            eid_excluder = EidExcluder(exclude_edges)\n            frontier = eid_excluder(frontier)\n    
return frontier if output_device is None else frontier.to(output_device)\n\n\ndef sample_neighbors_fused(\n    g,\n    nodes,\n    fanout,\n    edge_dir=\"in\",\n    prob=None,\n    replace=False,\n    copy_ndata=True,\n    copy_edata=True,\n    exclude_edges=None,\n    mapping=None,\n):\n    \"\"\"Sample neighboring edges of the given nodes and return the induced subgraph.\n\n    For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges\n    will be randomly chosen.  The graph returned will then contain all the nodes in the\n    original graph, but only the sampled edges. Nodes will be renumbered starting from id 0,\n    which would be new node id of first seed node.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.  Can be either on CPU or GPU.\n    nodes : tensor or dict\n        Node IDs to sample neighbors from.\n\n        This argument can take a single ID tensor or a dictionary of node types and ID tensors.\n        If a single tensor is given, the graph must only have one type of nodes.\n    fanout : int or dict[etype, int]\n        The number of edges to be sampled for each node on each edge type.\n\n        This argument can take a single int or a dictionary of edge types and ints.\n        If a single int is given, DGL will sample this number of edges for each node for\n        every edge type.\n\n        If -1 is given for a single edge type, all the neighboring edges with that edge\n        type and non-zero probability will be selected.\n    edge_dir : str, optional\n        Determines whether to sample inbound or outbound edges.\n\n        Can take either ``in`` for inbound edges or ``out`` for outbound edges.\n    prob : str, optional\n        Feature name used as the (unnormalized) probabilities associated with each\n        neighboring edge of a node.  The feature must have only one element for each\n        edge.\n\n        The features must be non-negative floats or boolean.  
Otherwise, the result\n        will be undefined.\n    exclude_edges: tensor or dict\n        Edge IDs to exclude during sampling neighbors for the seed nodes.\n\n        This argument can take a single ID tensor or a dictionary of edge types and ID tensors.\n        If a single tensor is given, the graph must only have one type of nodes.\n    replace : bool, optional\n        If True, sample with replacement.\n    copy_ndata: bool, optional\n        If True, the node features of the new graph are copied from\n        the original graph. If False, the new graph will not have any\n        node features.\n\n        (Default: True)\n    copy_edata: bool, optional\n        If True, the edge features of the new graph are copied from\n        the original graph.  If False, the new graph will not have any\n        edge features.\n\n        (Default: False)\n\n    mapping : dictionary, optional\n        Used by fused version of NeighborSampler. To avoid constant data allocation\n        provide empty dictionary ({}) that will be allocated once with proper data and reused\n        by each function call\n\n        (Default: None)\n    Returns\n    -------\n    DGLGraph\n        A sampled subgraph containing only the sampled neighboring edges.\n\n    Notes\n    -----\n    If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as\n    the node or edge features of the original graph and the new graph.\n    As a result, users should avoid performing in-place operations\n    on the node features of the new graph to avoid feature corruption.\n\n    \"\"\"\n    if not g.is_pinned():\n        frontier = _sample_neighbors(\n            g,\n            nodes,\n            fanout,\n            edge_dir=edge_dir,\n            prob=prob,\n            replace=replace,\n            copy_ndata=copy_ndata,\n            copy_edata=copy_edata,\n            exclude_edges=exclude_edges,\n            fused=True,\n            mapping=mapping,\n        )\n    else:\n        
frontier = _sample_neighbors(\n            g,\n            nodes,\n            fanout,\n            edge_dir=edge_dir,\n            prob=prob,\n            replace=replace,\n            copy_ndata=copy_ndata,\n            copy_edata=copy_edata,\n            fused=True,\n            mapping=mapping,\n        )\n        if exclude_edges is not None:\n            eid_excluder = EidExcluder(exclude_edges)\n            frontier = eid_excluder(frontier)\n    return frontier\n\n\ndef _sample_neighbors(\n    g,\n    nodes,\n    fanout,\n    edge_dir=\"in\",\n    prob=None,\n    replace=False,\n    copy_ndata=True,\n    copy_edata=True,\n    _dist_training=False,\n    exclude_edges=None,\n    fused=False,\n    mapping=None,\n):\n    if not isinstance(nodes, dict):\n        if len(g.ntypes) > 1:\n            raise DGLError(\n                \"Must specify node type when the graph is not homogeneous.\"\n            )\n        nodes = {g.ntypes[0]: nodes}\n\n    nodes = utils.prepare_tensor_dict(g, nodes, \"nodes\")\n    if len(nodes) == 0:\n        raise ValueError(\n            \"Got an empty dictionary in the nodes argument. 
\"\n            \"Please pass in a dictionary with empty tensors as values instead.\"\n        )\n    device = utils.context_of(nodes)\n    ctx = utils.to_dgl_context(device)\n    nodes_all_types = []\n    for ntype in g.ntypes:\n        if ntype in nodes:\n            nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))\n        else:\n            nodes_all_types.append(nd.array([], ctx=ctx))\n\n    if isinstance(fanout, nd.NDArray):\n        fanout_array = fanout\n    else:\n        if not isinstance(fanout, dict):\n            fanout_array = [int(fanout)] * len(g.etypes)\n        else:\n            if len(fanout) != len(g.etypes):\n                raise DGLError(\n                    \"Fan-out must be specified for each edge type \"\n                    \"if a dict is provided.\"\n                )\n            fanout_array = [None] * len(g.etypes)\n            for etype, value in fanout.items():\n                fanout_array[g.get_etype_id(etype)] = value\n        fanout_array = F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64))\n\n    prob_arrays = _prepare_edge_arrays(g, prob)\n\n    excluded_edges_all_t = []\n    if exclude_edges is not None:\n        if not isinstance(exclude_edges, dict):\n            if len(g.etypes) > 1:\n                raise DGLError(\n                    \"Must specify etype when the graph is not homogeneous.\"\n                )\n            exclude_edges = {g.canonical_etypes[0]: exclude_edges}\n        exclude_edges = utils.prepare_tensor_dict(g, exclude_edges, \"edges\")\n        for etype in g.canonical_etypes:\n            if etype in exclude_edges:\n                excluded_edges_all_t.append(F.to_dgl_nd(exclude_edges[etype]))\n            else:\n                excluded_edges_all_t.append(nd.array([], ctx=ctx))\n\n    if fused:\n        if _dist_training:\n            raise DGLError(\n                \"distributed training not supported in fused sampling\"\n            )\n        cpu = F.device_type(g.device) == \"cpu\"\n        
if isinstance(nodes, dict):\n            for ntype in list(nodes.keys()):\n                if not cpu:\n                    break\n                cpu = cpu and F.device_type(nodes[ntype].device) == \"cpu\"\n        else:\n            cpu = cpu and F.device_type(nodes.device) == \"cpu\"\n        if not cpu or F.backend_name != \"pytorch\":\n            raise DGLError(\n                \"Only PyTorch backend and cpu is supported in fused sampling\"\n            )\n\n        if mapping is None:\n            mapping = {}\n        mapping_name = \"__mapping\" + str(os.getpid())\n        if mapping_name not in mapping.keys():\n            mapping[mapping_name] = [\n                torch.LongTensor(g.num_nodes(ntype)).fill_(-1)\n                for ntype in g.ntypes\n            ]\n\n        subgidx, induced_nodes, induced_edges = _CAPI_DGLSampleNeighborsFused(\n            g._graph,\n            nodes_all_types,\n            [F.to_dgl_nd(m) for m in mapping[mapping_name]],\n            fanout_array,\n            edge_dir,\n            prob_arrays,\n            excluded_edges_all_t,\n            replace,\n        )\n        for mapping_vector, src_nodes in zip(\n            mapping[mapping_name], induced_nodes\n        ):\n            mapping_vector[F.from_dgl_nd(src_nodes).type(F.int64)] = -1\n\n        new_ntypes = (g.ntypes, g.ntypes)\n        ret = DGLBlock(subgidx, new_ntypes, g.etypes)\n        assert ret.is_unibipartite\n\n    else:\n        subgidx = _CAPI_DGLSampleNeighbors(\n            g._graph,\n            nodes_all_types,\n            fanout_array,\n            edge_dir,\n            prob_arrays,\n            excluded_edges_all_t,\n            replace,\n        )\n        ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)\n        induced_edges = subgidx.induced_edges\n\n    # handle features\n    # (TODO) (BarclayII) DGL distributed fails with bus error, freezes, or other\n    # incomprehensible errors with lazy feature copy.\n    # So in distributed 
training context, we fall back to old behavior where we\n    # only set the edge IDs.\n    if not _dist_training:\n        if copy_ndata:\n            if fused:\n                src_node_ids = [F.from_dgl_nd(src) for src in induced_nodes]\n                dst_node_ids = [\n                    utils.toindex(\n                        nodes.get(ntype, []), g._idtype_str\n                    ).tousertensor(ctx=F.to_backend_ctx(g._graph.ctx))\n                    for ntype in g.ntypes\n                ]\n                node_frames = utils.extract_node_subframes_for_block(\n                    g, src_node_ids, dst_node_ids\n                )\n                utils.set_new_frames(ret, node_frames=node_frames)\n            else:\n                node_frames = utils.extract_node_subframes(g, device)\n                utils.set_new_frames(ret, node_frames=node_frames)\n\n        if copy_edata:\n            if fused:\n                edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges]\n                edge_frames = utils.extract_edge_subframes(g, edge_ids)\n                utils.set_new_frames(ret, edge_frames=edge_frames)\n            else:\n                edge_frames = utils.extract_edge_subframes(g, induced_edges)\n                utils.set_new_frames(ret, edge_frames=edge_frames)\n\n    else:\n        for i, etype in enumerate(ret.canonical_etypes):\n            ret.edges[etype].data[EID] = induced_edges[i]\n\n    return ret\n\n\nDGLGraph.sample_neighbors = utils.alias_func(sample_neighbors)\nDGLGraph.sample_neighbors_fused = utils.alias_func(sample_neighbors_fused)\n\n\ndef sample_neighbors_biased(\n    g,\n    nodes,\n    fanout,\n    bias,\n    edge_dir=\"in\",\n    tag_offset_name=\"_TAG_OFFSET\",\n    replace=False,\n    copy_ndata=True,\n    copy_edata=True,\n    output_device=None,\n):\n    r\"\"\"Sample neighboring edges of the given nodes and return the induced subgraph, where each\n    neighbor's probability to be picked is determined by its tag.\n\n    
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges\n    will be randomly chosen.  The graph returned will then contain all the nodes in the\n    original graph, but only the sampled edges.\n\n    This version of neighbor sampling can support the scenario where adjacent nodes with different\n    types have different sampling probability. Each node is assigned an integer (called a *tag*)\n    which represents its type. Tag is an analogue of node type under the framework of homogeneous\n    graphs. Nodes with the same tag share the same probability.\n\n    For example, assume a node has :math:`N+M` neighbors, and :math:`N` of them\n    have tag 0 while :math:`M` of them have tag 1. Assume a node of tag 0 has\n    an unnormalized probability :math:`p` to be picked while a node of tag 1\n    has :math:`q`. This function first chooses a tag according to the\n    unnormalized probability distribution\n    :math:`\\frac{P(tag=0)}{P(tag=1)}=\\frac{Np}{Mq}`, and then run a uniform\n    sampling to get a node of the chosen tag.\n\n    In order to make sampling more efficient, the input graph must have its\n    CSC matrix (or CSR matrix if ``edge_dir='out'``) sorted according to the tag. The API\n    :func:`~dgl.sort_csc_by_tag` and\n    :func:`~dgl.sort_csr_by_tag` are designed for this purpose, which\n    will internally reorder the neighbors by tags so that neighbors of the same tags are\n    stored in a consecutive range. The two APIs will also store the offsets of these ranges\n    in a node feature with :attr:`tag_offset_name` as its name.\n\n    **Please make sure that the CSR (or CSC) matrix of the graph has been sorted before\n    calling this function.**  This function itself will not check whether the\n    input graph is sorted. Note that the input :attr:`tag_offset_name` should\n    be consistent with that in the sorting function.\n\n    Only homogeneous or bipartite graphs are supported. 
For bipartite graphs,\n    the tag offsets of the source nodes when ``edge_dir='in'`` (or the destination\n    nodes when ``edge_dir='out'``) will be used in sampling.\n\n    Node/edge features are not preserved. The original IDs of\n    the sampled edges are stored as the ``dgl.EID`` feature in the returned graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph. Must be homogeneous or bipartite (only one edge type). Must be on CPU.\n    nodes : tensor or list\n        Node IDs to sample neighbors from.\n    fanout : int\n        The number of edges to be sampled for each node on each edge type.\n\n        If -1 is given, all the neighboring edges with non-zero probability will be selected.\n    bias : tensor or list\n        The (unnormalized) probabilities associated with each tag. Its length should be equal\n        to the number of tags.\n\n        Entries of this array must be non-negative floats. Otherwise, the result will be\n        undefined.\n    edge_dir : str, optional\n        Determines whether to sample inbound or outbound edges.\n\n        Can take either ``in`` for inbound edges or ``out`` for outbound edges.\n    tag_offset_name : str, optional\n        The name of the node feature storing tag offsets.\n\n        (Default: \"_TAG_OFFSET\")\n    replace : bool, optional\n        If True, sample with replacement.\n    copy_ndata: bool, optional\n        If True, the node features of the new graph are copied from\n        the original graph. If False, the new graph will not have any\n        node features.\n\n        (Default: True)\n    copy_edata: bool, optional\n        If True, the edge features of the new graph are copied from\n        the original graph.  If False, the new graph will not have any\n        edge features.\n\n        (Default: True)\n    output_device : Framework-specific device context object, optional\n        The output device.  
Default is the same as the input graph.\n\n    Returns\n    -------\n    DGLGraph\n        A sampled subgraph containing only the sampled neighboring edges.  It is on CPU.\n\n    Notes\n    -----\n    If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as\n    the node or edge features of the original graph and the new graph.\n    As a result, users should avoid performing in-place operations\n    on the node features of the new graph to avoid feature corruption.\n\n    See Also\n    --------\n    dgl.sort_csc_by_tag\n    dgl.sort_csr_by_tag\n\n    Examples\n    --------\n    Assume that you have the following graph\n\n    >>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))\n\n    And the tags\n\n    >>> tag = torch.IntTensor([0, 0, 1])\n\n    Sort the graph (necessary!)\n\n    >>> g_sorted = dgl.transforms.sort_csr_by_tag(g, tag)\n    >>> g_sorted.ndata['_TAG_OFFSET']\n    tensor([[0, 1, 2],\n            [0, 2, 2],\n            [0, 1, 2]])\n\n    Set the probability of each tag:\n\n    >>> bias = torch.tensor([1.0, 0.001])\n    >>> # node 2 is almost impossible to be sampled because it has tag 1.\n\n    To sample one out bound edge for node 0 and node 2:\n\n    >>> sg = dgl.sampling.sample_neighbors_biased(g_sorted, [0, 2], 1, bias, edge_dir='out')\n    >>> sg.edges(order='eid')\n    (tensor([0, 2]), tensor([1, 0]))\n    >>> sg.edata[dgl.EID]\n    tensor([0, 5])\n\n    With ``fanout`` greater than the number of actual neighbors and without replacement,\n    DGL will take all neighbors instead:\n\n    >>> sg = dgl.sampling.sample_neighbors_biased(g_sorted, [0, 2], 3, bias, edge_dir='out')\n    >>> sg.edges(order='eid')\n    (tensor([0, 0, 2, 2]), tensor([1, 2, 0, 2]))\n    \"\"\"\n    if isinstance(nodes, list):\n        nodes = F.tensor(nodes)\n    if isinstance(bias, list):\n        bias = F.tensor(bias)\n    device = utils.context_of(nodes)\n\n    nodes_array = F.to_dgl_nd(nodes)\n    bias_array = F.to_dgl_nd(bias)\n    if edge_dir 
== \"in\":\n        tag_offset_array = F.to_dgl_nd(g.dstdata[tag_offset_name])\n    elif edge_dir == \"out\":\n        tag_offset_array = F.to_dgl_nd(g.srcdata[tag_offset_name])\n    else:\n        raise DGLError(\"edge_dir can only be 'in' or 'out'\")\n\n    subgidx = _CAPI_DGLSampleNeighborsBiased(\n        g._graph,\n        nodes_array,\n        fanout,\n        bias_array,\n        tag_offset_array,\n        edge_dir,\n        replace,\n    )\n    induced_edges = subgidx.induced_edges\n    ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)\n\n    if copy_ndata:\n        node_frames = utils.extract_node_subframes(g, device)\n        utils.set_new_frames(ret, node_frames=node_frames)\n\n    if copy_edata:\n        edge_frames = utils.extract_edge_subframes(g, induced_edges)\n        utils.set_new_frames(ret, edge_frames=edge_frames)\n\n    ret.edata[EID] = induced_edges[0]\n    return ret if output_device is None else ret.to(output_device)\n\n\nDGLGraph.sample_neighbors_biased = utils.alias_func(sample_neighbors_biased)\n\n\ndef select_topk(\n    g,\n    k,\n    weight,\n    nodes=None,\n    edge_dir=\"in\",\n    ascending=False,\n    copy_ndata=True,\n    copy_edata=True,\n    output_device=None,\n):\n    \"\"\"Select the neighboring edges with k-largest (or k-smallest) weights of the given\n    nodes and return the induced subgraph.\n\n    For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges\n    with the largest (or smallest when ``ascending == True``) weights will be chosen.\n    The graph returned will then contain all the nodes in the original graph, but only\n    the sampled edges.\n\n    Node/edge features are not preserved. The original IDs of\n    the sampled edges are stored as the `dgl.EID` feature in the returned graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.  
Must be on CPU.\n    k : int or dict[etype, int]\n        The number of edges to be selected for each node on each edge type.\n\n        This argument can take a single int or a dictionary of edge types and ints.\n        If a single int is given, DGL will select this number of edges for each node for\n        every edge type.\n\n        If -1 is given for a single edge type, all the neighboring edges with that edge\n        type will be selected.\n    weight : str\n        Feature name of the weights associated with each edge.  The feature should have only\n        one element for each edge.  The feature can be either int32/64 or float32/64.\n    nodes : tensor or dict, optional\n        Node IDs to sample neighbors from.\n\n        This argument can take a single ID tensor or a dictionary of node types and ID tensors.\n        If a single tensor is given, the graph must only have one type of nodes.\n\n        If None, DGL will select the edges for all nodes.\n    edge_dir : str, optional\n        Determines whether to sample inbound or outbound edges.\n\n        Can take either ``in`` for inbound edges or ``out`` for outbound edges.\n    ascending : bool, optional\n        If True, DGL will return edges with k-smallest weights instead of\n        k-largest weights.\n    copy_ndata: bool, optional\n        If True, the node features of the new graph are copied from\n        the original graph. If False, the new graph will not have any\n        node features.\n\n        (Default: True)\n    copy_edata: bool, optional\n        If True, the edge features of the new graph are copied from\n        the original graph.  If False, the new graph will not have any\n        edge features.\n\n        (Default: True)\n    output_device : Framework-specific device context object, optional\n        The output device.  Default is the same as the input graph.\n\n    Returns\n    -------\n    DGLGraph\n        A sampled subgraph containing only the sampled neighboring edges.  
It is on CPU.\n\n    Notes\n    -----\n    If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as\n    the node or edge features of the original graph and the new graph.\n    As a result, users should avoid performing in-place operations\n    on the node features of the new graph to avoid feature corruption.\n\n    Examples\n    --------\n    >>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))\n    >>> g.edata['weight'] = torch.FloatTensor([0, 1, 0, 1, 0, 1])\n    >>> sg = dgl.sampling.select_topk(g, 1, 'weight')\n    >>> sg.edges(order='eid')\n    (tensor([2, 1, 0]), tensor([0, 1, 2]))\n    \"\"\"\n    # Rectify nodes to a dictionary\n    if nodes is None:\n        nodes = {\n            ntype: F.astype(F.arange(0, g.num_nodes(ntype)), g.idtype)\n            for ntype in g.ntypes\n        }\n    elif not isinstance(nodes, dict):\n        if len(g.ntypes) > 1:\n            raise DGLError(\n                \"Must specify node type when the graph is not homogeneous.\"\n            )\n        nodes = {g.ntypes[0]: nodes}\n    assert g.device == F.cpu(), \"Graph must be on CPU.\"\n\n    # Parse nodes into a list of NDArrays.\n    nodes = utils.prepare_tensor_dict(g, nodes, \"nodes\")\n    device = utils.context_of(nodes)\n    nodes_all_types = []\n    for ntype in g.ntypes:\n        if ntype in nodes:\n            nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))\n        else:\n            nodes_all_types.append(nd.array([], ctx=nd.cpu()))\n\n    if not isinstance(k, dict):\n        k_array = [int(k)] * len(g.etypes)\n    else:\n        if len(k) != len(g.etypes):\n            raise DGLError(\n                \"K value must be specified for each edge type \"\n                \"if a dict is provided.\"\n            )\n        k_array = [None] * len(g.etypes)\n        for etype, value in k.items():\n            k_array[g.get_etype_id(etype)] = value\n    k_array = F.to_dgl_nd(F.tensor(k_array, dtype=F.int64))\n\n    weight_arrays = []\n    
for etype in g.canonical_etypes:\n        if weight in g.edges[etype].data:\n            weight_arrays.append(F.to_dgl_nd(g.edges[etype].data[weight]))\n        else:\n            raise DGLError(\n                'Edge weights \"{}\" do not exist for relation graph \"{}\".'.format(\n                    weight, etype\n                )\n            )\n\n    subgidx = _CAPI_DGLSampleNeighborsTopk(\n        g._graph,\n        nodes_all_types,\n        k_array,\n        edge_dir,\n        weight_arrays,\n        bool(ascending),\n    )\n    induced_edges = subgidx.induced_edges\n    ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)\n\n    # handle features\n    if copy_ndata:\n        node_frames = utils.extract_node_subframes(g, device)\n        utils.set_new_frames(ret, node_frames=node_frames)\n\n    if copy_edata:\n        edge_frames = utils.extract_edge_subframes(g, induced_edges)\n        utils.set_new_frames(ret, edge_frames=edge_frames)\n    return ret if output_device is None else ret.to(output_device)\n\n\nDGLGraph.select_topk = utils.alias_func(select_topk)\n\n_init_api(\"dgl.sampling.neighbor\", __name__)\n"
  },
  {
    "path": "python/dgl/sampling/node2vec_randomwalk.py",
    "content": "\"\"\"Node2vec random walk\"\"\"\n\nfrom .. import backend as F, ndarray as nd, utils\nfrom .._ffi.function import _init_api\n\n# pylint: disable=invalid-name\n\n__all__ = [\"node2vec_random_walk\"]\n\n\ndef node2vec_random_walk(\n    g, nodes, p, q, walk_length, prob=None, return_eids=False\n):\n    \"\"\"\n    Generate random walk traces from an array of starting nodes based on the node2vec model.\n    Paper: `node2vec: Scalable Feature Learning for Networks\n    <https://arxiv.org/abs/1607.00653>`__.\n\n    The returned traces all have length ``walk_length + 1``, where the first node\n    is the starting node itself.\n\n    Note that if a random walk stops in advance, DGL pads the trace with -1 to have the same\n    length.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.  Must be on CPU.\n\n        Note that node2vec only supports homogeneous graphs.\n    nodes : Tensor\n        Node ID tensor from which the random walk traces start.\n\n        The tensor must be on CPU, and must have the same dtype as the ID type\n        of the graph.\n    p: float\n        Likelihood of immediately revisiting a node in the walk.\n    q: float\n        Control parameter to interpolate between breadth-first strategy and depth-first strategy.\n    walk_length: int\n        Length of random walks.\n    prob : str, optional\n        The name of the edge feature tensor on the graph storing the (unnormalized)\n        probabilities associated with each edge for choosing the next node.\n\n        The feature tensor must be non-negative and the sum of the probabilities\n        must be positive for the outbound edges of all nodes (although they don't have\n        to sum up to one).  
The result will be undefined otherwise.\n\n        If omitted, DGL assumes that the neighbors are picked uniformly.\n    return_eids : bool, optional\n        If True, additionally return the edge IDs traversed.\n\n        Default: False.\n\n    Returns\n    -------\n    traces : Tensor\n        A 2-dimensional node ID tensor with shape ``(num_seeds, walk_length + 1)``.\n    eids : Tensor, optional\n        A 2-dimensional edge ID tensor with shape ``(num_seeds, walk_length)``.\n        Only returned if :attr:`return_eids` is True.\n\n    Examples\n    --------\n    >>> g1 = dgl.graph(([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]))\n    >>> dgl.sampling.node2vec_random_walk(g1, [0, 1, 2, 0], 1, 1, walk_length=4)\n    tensor([[0, 1, 3, 0, 1],\n            [1, 2, 0, 1, 3],\n            [2, 0, 1, 3, 0],\n            [0, 1, 2, 0, 1]])\n\n    >>> dgl.sampling.node2vec_random_walk(g1, [0, 1, 2, 0], 1, 1, walk_length=4, return_eids=True)\n    (tensor([[0, 1, 3, 0, 1],\n             [1, 2, 0, 1, 2],\n             [2, 0, 1, 2, 0],\n             [0, 1, 2, 0, 1]]),\n     tensor([[0, 2, 4, 0],\n             [1, 3, 0, 1],\n             [3, 0, 1, 3],\n             [0, 1, 3, 0]]))\n    \"\"\"\n    assert g.device == F.cpu(), \"Graph must be on CPU.\"\n\n    gidx = g._graph\n    nodes = F.to_dgl_nd(utils.prepare_tensor(g, nodes, \"nodes\"))\n\n    if prob is None:\n        prob_nd = nd.array([], ctx=nodes.ctx)\n    else:\n        prob_nd = F.to_dgl_nd(g.edata[prob])\n\n    traces, eids = _CAPI_DGLSamplingNode2vec(\n        gidx, nodes, p, q, walk_length, prob_nd\n    )\n\n    traces = F.from_dgl_nd(traces)\n    eids = F.from_dgl_nd(eids)\n\n    return (traces, eids) if return_eids else traces\n\n\n_init_api(\"dgl.sampling.randomwalks\", __name__)\n"
  },
  {
    "path": "python/dgl/sampling/pinsage.py",
    "content": "\"\"\"PinSAGE sampler & related functions and classes\"\"\"\n\nimport numpy as np\n\nfrom .. import backend as F, convert, utils\nfrom .._ffi.function import _init_api\nfrom .randomwalks import random_walk\n\n\ndef _select_pinsage_neighbors(src, dst, num_samples_per_node, k):\n    \"\"\"Determine the neighbors for PinSAGE algorithm from the given random walk traces.\n\n    This is fusing ``to_simple()``, ``select_topk()``, and counting the number of occurrences\n    together.\n    \"\"\"\n    src = F.to_dgl_nd(src)\n    dst = F.to_dgl_nd(dst)\n    src, dst, counts = _CAPI_DGLSamplingSelectPinSageNeighbors(\n        src, dst, num_samples_per_node, k\n    )\n    src = F.from_dgl_nd(src)\n    dst = F.from_dgl_nd(dst)\n    counts = F.from_dgl_nd(counts)\n    return (src, dst, counts)\n\n\nclass RandomWalkNeighborSampler(object):\n    \"\"\"PinSage-like neighbor sampler extended to any heterogeneous graphs.\n\n    Given a heterogeneous graph and a list of nodes, this callable will generate a homogeneous\n    graph where the neighbors of each given node are the most commonly visited nodes of the\n    same type by multiple random walks starting from that given node.  
Each random walk consists\n    of multiple metapath-based traversals, with a probability of termination after each traversal.\n\n    The edges of the returned homogeneous graph will connect to the given nodes from their most\n    commonly visited nodes, with a feature indicating the number of visits.\n\n    The metapath must have the same beginning and ending node type to make the algorithm work.\n\n    This is a generalization of PinSAGE sampler which only works on bidirectional bipartite\n    graphs.\n\n    UVA and GPU sampling is supported for this sampler.\n    Refer to :ref:`guide-minibatch-gpu-sampling` for more details.\n\n    Parameters\n    ----------\n    G : DGLGraph\n        The graph.\n    num_traversals : int\n        The maximum number of metapath-based traversals for a single random walk.\n\n        Usually considered a hyperparameter.\n    termination_prob : float\n        Termination probability after each metapath-based traversal.\n\n        Usually considered a hyperparameter.\n    num_random_walks : int\n        Number of random walks to try for each given node.\n\n        Usually considered a hyperparameter.\n    num_neighbors : int\n        Number of neighbors (or most commonly visited nodes) to select for each given node.\n    metapath : list[str] or list[tuple[str, str, str]], optional\n        The metapath.\n\n        If not given, DGL assumes that the graph is homogeneous and the metapath consists\n        of one step over the single edge type.\n    weight_column : str, default \"weights\"\n        The name of the edge feature to be stored on the returned graph with the number of\n        visits.\n\n    Examples\n    --------\n    See examples in :any:`PinSAGESampler`.\n    \"\"\"\n\n    def __init__(\n        self,\n        G,\n        num_traversals,\n        termination_prob,\n        num_random_walks,\n        num_neighbors,\n        metapath=None,\n        weight_column=\"weights\",\n    ):\n        self.G = G\n        
self.weight_column = weight_column\n        self.num_random_walks = num_random_walks\n        self.num_neighbors = num_neighbors\n        self.num_traversals = num_traversals\n\n        if metapath is None:\n            if len(G.ntypes) > 1 or len(G.etypes) > 1:\n                raise ValueError(\n                    \"Metapath must be specified if the graph is homogeneous.\"\n                )\n            metapath = [G.canonical_etypes[0]]\n        start_ntype = G.to_canonical_etype(metapath[0])[0]\n        end_ntype = G.to_canonical_etype(metapath[-1])[-1]\n        if start_ntype != end_ntype:\n            raise ValueError(\n                \"The metapath must start and end at the same node type.\"\n            )\n        self.ntype = start_ntype\n\n        self.metapath_hops = len(metapath)\n        self.metapath = metapath\n        self.full_metapath = metapath * num_traversals\n        restart_prob = np.zeros(self.metapath_hops * num_traversals)\n        restart_prob[\n            self.metapath_hops :: self.metapath_hops\n        ] = termination_prob\n        restart_prob = F.tensor(restart_prob, dtype=F.float32)\n        self.restart_prob = F.copy_to(restart_prob, G.device)\n\n    # pylint: disable=no-member\n    def __call__(self, seed_nodes):\n        \"\"\"\n        Parameters\n        ----------\n        seed_nodes : Tensor\n            A tensor of given node IDs of node type ``ntype`` to generate neighbors from.  
The\n            node type ``ntype`` is the beginning and ending node type of the given metapath.\n\n            It must be on the same device as the graph and have the same dtype\n            as the ID type of the graph.\n\n        Returns\n        -------\n        g : DGLGraph\n            A homogeneous graph constructed by selecting neighbors for each given node according\n            to the algorithm above.\n        \"\"\"\n        seed_nodes = utils.prepare_tensor(self.G, seed_nodes, \"seed_nodes\")\n        self.restart_prob = F.copy_to(self.restart_prob, F.context(seed_nodes))\n\n        seed_nodes = F.repeat(seed_nodes, self.num_random_walks, 0)\n        paths, _ = random_walk(\n            self.G,\n            seed_nodes,\n            metapath=self.full_metapath,\n            restart_prob=self.restart_prob,\n        )\n        src = F.reshape(\n            paths[:, self.metapath_hops :: self.metapath_hops], (-1,)\n        )\n        dst = F.repeat(paths[:, 0], self.num_traversals, 0)\n\n        src, dst, counts = _select_pinsage_neighbors(\n            src,\n            dst,\n            (self.num_random_walks * self.num_traversals),\n            self.num_neighbors,\n        )\n        neighbor_graph = convert.heterograph(\n            {(self.ntype, \"_E\", self.ntype): (src, dst)},\n            {self.ntype: self.G.num_nodes(self.ntype)},\n        )\n        neighbor_graph.edata[self.weight_column] = counts\n\n        return neighbor_graph\n\n\nclass PinSAGESampler(RandomWalkNeighborSampler):\n    \"\"\"PinSAGE-like neighbor sampler.\n\n    This callable works on a bidirectional bipartite graph with edge types\n    ``(ntype, fwtype, other_type)`` and ``(other_type, bwtype, ntype)`` (where ``ntype``,\n    ``fwtype``, ``bwtype`` and ``other_type`` could be arbitrary type names).  
It will generate\n    a homogeneous graph of node type ``ntype`` where the neighbors of each given node are the\n    most commonly visited nodes of the same type by multiple random walks starting from that\n    given node.  Each random walk consists of multiple metapath-based traversals, with a\n    probability of termination after each traversal.  The metapath is always ``[fwtype, bwtype]``,\n    walking from node type ``ntype`` to node type ``other_type`` then back to ``ntype``.\n\n    The edges of the returned homogeneous graph will connect to the given nodes from their most\n    commonly visited nodes, with a feature indicating the number of visits.\n\n    UVA and GPU sampling is supported for this sampler.\n    Refer to :ref:`guide-minibatch-gpu-sampling` for more details.\n\n    Parameters\n    ----------\n    G : DGLGraph\n        The bidirectional bipartite graph.\n\n        The graph should only have two node types: ``ntype`` and ``other_type``.\n        The graph should only have two edge types, one connecting from ``ntype`` to\n        ``other_type``, and another connecting from ``other_type`` to ``ntype``.\n    ntype : str\n        The node type for which the graph would be constructed on.\n    other_type : str\n        The other node type.\n    num_traversals : int\n        The maximum number of metapath-based traversals for a single random walk.\n\n        Usually considered a hyperparameter.\n    termination_prob : float\n        Termination probability after each metapath-based traversal.\n\n        Usually considered a hyperparameter.\n    num_random_walks : int\n        Number of random walks to try for each given node.\n\n        Usually considered a hyperparameter.\n    num_neighbors : int\n        Number of neighbors (or most commonly visited nodes) to select for each given node.\n    weight_column : str, default \"weights\"\n        The name of the edge feature to be stored on the returned graph with the number of\n        visits.\n\n    
Examples\n    --------\n    Generate a random bidirectional bipartite graph with 3000 \"A\" nodes and 5000 \"B\" nodes.\n\n    >>> g = scipy.sparse.random(3000, 5000, 0.003)\n    >>> G = dgl.heterograph({\n    ...     ('A', 'AB', 'B'): g.nonzero(),\n    ...     ('B', 'BA', 'A'): g.T.nonzero()})\n\n    Then we create a PinSage neighbor sampler that samples a graph of node type \"A\".  Each\n    node would have (a maximum of) 10 neighbors.\n\n    >>> sampler = dgl.sampling.PinSAGESampler(G, 'A', 'B', 3, 0.5, 200, 10)\n\n    This is how we select the neighbors for node #0, #1 and #2 of type \"A\" according to\n    PinSAGE algorithm:\n\n    >>> seeds = torch.LongTensor([0, 1, 2])\n    >>> frontier = sampler(seeds)\n    >>> frontier.all_edges(form='uv')\n    (tensor([ 230,    0,  802,   47,   50, 1639, 1533,  406, 2110, 2687, 2408, 2823,\n                0,  972, 1230, 1658, 2373, 1289, 1745, 2918, 1818, 1951, 1191, 1089,\n             1282,  566, 2541, 1505, 1022,  812]),\n     tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,\n             2, 2, 2, 2, 2, 2]))\n\n    For an end-to-end example of PinSAGE model, including sampling on multiple layers\n    and computing with the sampled graphs, please refer to our PinSage example\n    in ``examples/pytorch/pinsage``.\n\n    References\n    ----------\n    Graph Convolutional Neural Networks for Web-Scale Recommender Systems\n        Ying et al., 2018, https://arxiv.org/abs/1806.01973\n    \"\"\"\n\n    def __init__(\n        self,\n        G,\n        ntype,\n        other_type,\n        num_traversals,\n        termination_prob,\n        num_random_walks,\n        num_neighbors,\n        weight_column=\"weights\",\n    ):\n        metagraph = G.metagraph()\n        fw_etype = list(metagraph[ntype][other_type])[0]\n        bw_etype = list(metagraph[other_type][ntype])[0]\n        super().__init__(\n            G,\n            num_traversals,\n            termination_prob,\n            
num_random_walks,\n            num_neighbors,\n            metapath=[fw_etype, bw_etype],\n            weight_column=weight_column,\n        )\n\n\n_init_api(\"dgl.sampling.pinsage\", __name__)\n"
  },
  {
    "path": "python/dgl/sampling/randomwalks.py",
    "content": "\"\"\"Random walk routines\n\"\"\"\n\nfrom .. import backend as F, ndarray as nd, utils\nfrom .._ffi.function import _init_api\nfrom ..base import DGLError\n\n__all__ = [\"random_walk\", \"pack_traces\"]\n\n\ndef random_walk(\n    g,\n    nodes,\n    *,\n    metapath=None,\n    length=None,\n    prob=None,\n    restart_prob=None,\n    return_eids=False\n):\n    \"\"\"Generate random walk traces from an array of starting nodes based on the given metapath.\n\n    Each starting node will have one trace generated, which\n\n    1. Start from the given node and set ``t`` to 0.\n    2. Pick and traverse along edge type ``metapath[t]`` from the current node.\n    3. If no edge can be found, halt.  Otherwise, increment ``t`` and go to step 2.\n\n    To generate multiple traces for a single node, you can specify the same node multiple\n    times.\n\n    The returned traces all have length ``len(metapath) + 1``, where the first node\n    is the starting node itself.\n\n    If a random walk stops in advance, DGL pads the trace with -1 to have the same\n    length.\n\n    This function supports the graph on GPU and UVA sampling.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    nodes : Tensor\n        Node ID tensor from which the random walk traces starts.\n\n        The tensor must have the same dtype as the ID type of the graph.\n        The tensor must be on the same device as the graph or\n        on the GPU when the graph is pinned (UVA sampling).\n    metapath : list[str or tuple of str], optional\n        Metapath, specified as a list of edge types.\n\n        Mutually exclusive with :attr:`length`.\n\n        If omitted, DGL assumes that ``g`` only has one node & edge type.  
In this\n        case, the argument ``length`` specifies the length of random walk traces.\n    length : int, optional\n        Length of random walks.\n\n        Mutually exclusive with :attr:`metapath`.\n\n        Only used when :attr:`metapath` is None.\n    prob : str, optional\n        The name of the edge feature tensor on the graph storing the (unnormalized)\n        probabilities associated with each edge for choosing the next node.\n\n        The feature tensor must be non-negative and the sum of the probabilities\n        must be positive for the outbound edges of all nodes (although they don't have\n        to sum up to one).  The result will be undefined otherwise.\n\n        The feature tensor must be on the same device as the graph.\n\n        If omitted, DGL assumes that the neighbors are picked uniformly.\n    restart_prob : float or Tensor, optional\n        Probability to terminate the current trace before each transition.\n\n        If a tensor is given, :attr:`restart_prob` should be on the same device as the graph\n        or on the GPU when the graph is pinned (UVA sampling),\n        and have the same length as :attr:`metapath` or :attr:`length`.\n    return_eids : bool, optional\n        If True, additionally return the edge IDs traversed.\n\n        Default: False.\n\n    Returns\n    -------\n    traces : Tensor\n        A 2-dimensional node ID tensor with shape ``(num_seeds, len(metapath) + 1)`` or\n        ``(num_seeds, length + 1)`` if :attr:`metapath` is None.\n    eids : Tensor, optional\n        A 2-dimensional edge ID tensor with shape ``(num_seeds, len(metapath))`` or\n        ``(num_seeds, length)`` if :attr:`metapath` is None.  
Only returned if\n        :attr:`return_eids` is True.\n    types : Tensor\n        A 1-dimensional node type ID tensor with shape ``(len(metapath) + 1)`` or\n        ``(length + 1)``.\n        The type IDs match the ones in the original graph ``g``.\n\n    Examples\n    --------\n    The following creates a homogeneous graph:\n\n    >>> g1 = dgl.graph(([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]))\n\n    Normal random walk:\n\n    >>> dgl.sampling.random_walk(g1, [0, 1, 2, 0], length=4)\n    (tensor([[0, 1, 2, 0, 1],\n             [1, 3, 0, 1, 3],\n             [2, 0, 1, 3, 0],\n             [0, 1, 2, 0, 1]]), tensor([0, 0, 0, 0, 0]))\n\n    Or returning edge IDs:\n\n    >>> dgl.sampling.random_walk(g1, [0, 1, 2, 0], length=4, return_eids=True)\n    (tensor([[0, 1, 2, 0, 1],\n             [1, 3, 0, 1, 2],\n             [2, 0, 1, 3, 0],\n             [0, 1, 3, 0, 1]]),\n     tensor([[0, 1, 3, 0],\n             [2, 4, 0, 1],\n             [3, 0, 2, 4],\n             [0, 2, 4, 0]]),\n     tensor([0, 0, 0, 0, 0]))\n\n    The first tensor indicates the random walk path for each seed node.\n    The j-th element in the second tensor indicates the node type ID of the j-th node\n    in every path.  In this case, it is returning all 0.\n\n    Random walk with restart:\n\n    >>> dgl.sampling.random_walk(g1, [0, 1, 2, 0], length=4, restart_prob=0.5)\n    (tensor([[ 0, -1, -1, -1, -1],\n             [ 1,  3,  0, -1, -1],\n             [ 2, -1, -1, -1, -1],\n             [ 0, -1, -1, -1, -1]]), tensor([0, 0, 0, 0, 0]))\n\n    Non-uniform random walk:\n\n    >>> g1.edata['p'] = torch.FloatTensor([1, 0, 1, 1, 1])     # disallow going from 1 to 2\n    >>> dgl.sampling.random_walk(g1, [0, 1, 2, 0], length=4, prob='p')\n    (tensor([[0, 1, 3, 0, 1],\n             [1, 3, 0, 1, 3],\n             [2, 0, 1, 3, 0],\n             [0, 1, 3, 0, 1]]), tensor([0, 0, 0, 0, 0]))\n\n    Metapath-based random walk:\n\n    >>> g2 = dgl.heterograph({\n    ...     
('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),\n    ...     ('user', 'view', 'item'): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),\n    ...     ('item', 'viewed-by', 'user'): ([0, 1, 1, 2, 2, 1], [0, 0, 1, 2, 3, 3])})\n    >>> dgl.sampling.random_walk(\n    ...     g2, [0, 1, 2, 0], metapath=['follow', 'view', 'viewed-by'] * 2)\n    (tensor([[0, 1, 1, 1, 2, 2, 3],\n             [1, 3, 1, 1, 2, 2, 2],\n             [2, 0, 1, 1, 3, 1, 1],\n             [0, 1, 1, 0, 1, 1, 3]]), tensor([0, 0, 1, 0, 0, 1, 0]))\n\n    Metapath-based random walk, with restarts only on items (i.e. after traversing a \"view\"\n    relationship):\n\n    >>> dgl.sampling.random_walk(\n    ...     g2, [0, 1, 2, 0], metapath=['follow', 'view', 'viewed-by'] * 2,\n    ...     restart_prob=torch.FloatTensor([0, 0.5, 0, 0, 0.5, 0]))\n    (tensor([[ 0,  1, -1, -1, -1, -1, -1],\n             [ 1,  3,  1,  0,  1,  1,  0],\n             [ 2,  0,  1,  1,  3,  2,  2],\n             [ 0,  1,  1,  3,  0,  0,  0]]), tensor([0, 0, 1, 0, 0, 1, 0]))\n    \"\"\"\n    n_etypes = len(g.canonical_etypes)\n    n_ntypes = len(g.ntypes)\n\n    if metapath is None:\n        if n_etypes > 1 or n_ntypes > 1:\n            raise DGLError(\n                \"metapath not specified and the graph is not homogeneous.\"\n            )\n        if length is None:\n            raise ValueError(\n                \"Please specify either the metapath or the random walk length.\"\n            )\n        metapath = [0] * length\n    else:\n        metapath = [g.get_etype_id(etype) for etype in metapath]\n\n    gidx = g._graph\n    nodes = utils.prepare_tensor(g, nodes, \"nodes\")\n    nodes = F.to_dgl_nd(nodes)\n    # (Xin) Since metapath array is created by us, safe to skip the check\n    #       and keep it on CPU to make max_nodes sanity check easier.\n    metapath = F.to_dgl_nd(F.astype(F.tensor(metapath), g.idtype))\n\n    # Load the probability tensor from the edge frames\n    ctx = utils.to_dgl_context(g.device)\n  
  if prob is None:\n        p_nd = [nd.array([], ctx=ctx) for _ in g.canonical_etypes]\n    else:\n        p_nd = []\n        for etype in g.canonical_etypes:\n            if prob in g.edges[etype].data:\n                prob_nd = F.to_dgl_nd(g.edges[etype].data[prob])\n            else:\n                prob_nd = nd.array([], ctx=ctx)\n            p_nd.append(prob_nd)\n\n    # Actual random walk\n    if restart_prob is None:\n        traces, eids, types = _CAPI_DGLSamplingRandomWalk(\n            gidx, nodes, metapath, p_nd\n        )\n    elif F.is_tensor(restart_prob):\n        restart_prob = F.to_dgl_nd(restart_prob)\n        traces, eids, types = _CAPI_DGLSamplingRandomWalkWithStepwiseRestart(\n            gidx, nodes, metapath, p_nd, restart_prob\n        )\n    elif isinstance(restart_prob, float):\n        traces, eids, types = _CAPI_DGLSamplingRandomWalkWithRestart(\n            gidx, nodes, metapath, p_nd, restart_prob\n        )\n    else:\n        raise TypeError(\"restart_prob should be float or Tensor.\")\n\n    traces = F.from_dgl_nd(traces)\n    types = F.from_dgl_nd(types)\n    eids = F.from_dgl_nd(eids)\n    return (traces, eids, types) if return_eids else (traces, types)\n\n\ndef pack_traces(traces, types):\n    \"\"\"Pack the padded traces returned by ``random_walk()`` into a concatenated array.\n    The padding values (-1) are removed, and the length and offset of each trace are\n    returned along with the concatenated node ID and node type arrays.\n\n    Parameters\n    ----------\n    traces : Tensor\n        A 2-dimensional node ID tensor.  Must be on CPU and either ``int32`` or ``int64``.\n    types : Tensor\n        A 1-dimensional node type ID tensor.  
Must be on CPU and either ``int32`` or ``int64``.\n\n    Returns\n    -------\n    concat_vids : Tensor\n        An array of all node IDs concatenated and padding values removed.\n    concat_types : Tensor\n        An array of node types corresponding for each node in ``concat_vids``.\n        Has the same length as ``concat_vids``.\n    lengths : Tensor\n        Length of each trace in the original traces tensor.\n    offsets : Tensor\n        Offset of each trace in the original traces tensor in the new concatenated tensor.\n\n    Notes\n    -----\n    The returned tensors are on CPU.\n\n    Examples\n    --------\n    >>> g2 = dgl.heterograph({\n    ...     ('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),\n    ...     ('user', 'view', 'item'): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),\n    ...     ('item', 'viewed-by', 'user'): ([0, 1, 1, 2, 2, 1], [0, 0, 1, 2, 3, 3])})\n    >>> traces, types = dgl.sampling.random_walk(\n    ...     g2, [0, 0], metapath=['follow', 'view', 'viewed-by'] * 2,\n    ...     restart_prob=torch.FloatTensor([0, 0.5, 0, 0, 0.5, 0]))\n    >>> traces, types\n    (tensor([[ 0,  1, -1, -1, -1, -1, -1],\n             [ 0,  1,  1,  3,  0,  0,  0]]), tensor([0, 0, 1, 0, 0, 1, 0]))\n    >>> concat_vids, concat_types, lengths, offsets = dgl.sampling.pack_traces(traces, types)\n    >>> concat_vids\n    tensor([0, 1, 0, 1, 1, 3, 0, 0, 0])\n    >>> concat_types\n    tensor([0, 0, 0, 0, 1, 0, 0, 1, 0])\n    >>> lengths\n    tensor([2, 7])\n    >>> offsets\n    tensor([0, 2])\n\n    The first tensor ``concat_vids`` is the concatenation of all paths, i.e. flattened array\n    of ``traces``, excluding all padding values (-1).\n\n    The second tensor ``concat_types`` stands for the node type IDs of all corresponding nodes\n    in the first tensor.\n\n    The third and fourth tensors indicate the length and the offset of each path.  
With these\n    tensors it is easy to obtain the i-th random walk path with:\n\n    >>> vids = concat_vids.split(lengths.tolist())\n    >>> vtypes = concat_types.split(lengths.tolist())\n    >>> vids[1], vtypes[1]\n    (tensor([0, 1, 1, 3, 0, 0, 0]), tensor([0, 0, 1, 0, 0, 1, 0]))\n    \"\"\"\n    assert (\n        F.is_tensor(traces) and F.context(traces) == F.cpu()\n    ), \"traces must be a CPU tensor\"\n    assert (\n        F.is_tensor(types) and F.context(types) == F.cpu()\n    ), \"types must be a CPU tensor\"\n    traces = F.to_dgl_nd(traces)\n    types = F.to_dgl_nd(types)\n\n    concat_vids, concat_types, lengths, offsets = _CAPI_DGLSamplingPackTraces(\n        traces, types\n    )\n\n    concat_vids = F.from_dgl_nd(concat_vids)\n    concat_types = F.from_dgl_nd(concat_types)\n    lengths = F.from_dgl_nd(lengths)\n    offsets = F.from_dgl_nd(offsets)\n\n    return concat_vids, concat_types, lengths, offsets\n\n\n_init_api(\"dgl.sampling.randomwalks\", __name__)\n"
  },
  {
    "path": "python/dgl/sampling/utils.py",
    "content": "\"\"\"Sampling utilities\"\"\"\nfrom collections.abc import Mapping\n\nimport numpy as np\n\nfrom .. import backend as F, transforms, utils\nfrom ..base import EID\n\nfrom ..utils import recursive_apply, recursive_apply_pair\n\n\ndef _locate_eids_to_exclude(frontier_parent_eids, exclude_eids):\n    \"\"\"Find the edges whose IDs in parent graph appeared in exclude_eids.\n\n    Note that both arguments are numpy arrays or numpy dicts.\n    \"\"\"\n    if not isinstance(frontier_parent_eids, Mapping):\n        return np.isin(frontier_parent_eids, exclude_eids).nonzero()[0]\n    result = {}\n    for k, v in frontier_parent_eids.items():\n        if k in exclude_eids:\n            result[k] = np.isin(v, exclude_eids[k]).nonzero()[0]\n    return recursive_apply(result, F.zerocopy_from_numpy)\n\n\nclass EidExcluder(object):\n    \"\"\"Class that finds the edges whose IDs in parent graph appeared in exclude_eids.\n\n    The edge IDs can be both CPU and GPU tensors.\n    \"\"\"\n\n    def __init__(self, exclude_eids):\n        device = None\n        if isinstance(exclude_eids, Mapping):\n            for _, v in exclude_eids.items():\n                if device is None:\n                    device = F.context(v)\n                    break\n        else:\n            device = F.context(exclude_eids)\n        self._exclude_eids = None\n        self._filter = None\n\n        if device == F.cpu():\n            # TODO(nv-dlasalle): Once Filter is implemented for the CPU, we\n            # should just use that irregardless of the device.\n            self._exclude_eids = (\n                recursive_apply(exclude_eids, F.zerocopy_to_numpy)\n                if exclude_eids is not None\n                else None\n            )\n        else:\n            self._filter = recursive_apply(exclude_eids, utils.Filter)\n\n    def _find_indices(self, parent_eids):\n        \"\"\"Find the set of edge indices to remove.\"\"\"\n        if self._exclude_eids is not None:\n       
     parent_eids_np = recursive_apply(parent_eids, F.zerocopy_to_numpy)\n            return _locate_eids_to_exclude(parent_eids_np, self._exclude_eids)\n        else:\n            assert self._filter is not None\n            func = lambda x, y: x.find_included_indices(y)\n            return recursive_apply_pair(self._filter, parent_eids, func)\n\n    def __call__(self, frontier, weights=None):\n        parent_eids = frontier.edata[EID]\n        located_eids = self._find_indices(parent_eids)\n\n        if not isinstance(located_eids, Mapping):\n            # (BarclayII) If frontier already has a EID field and located_eids is empty,\n            # the returned graph will keep EID intact.  Otherwise, EID will change\n            # to the mapping from the new graph to the old frontier.\n            # So we need to test if located_eids is empty, and do the remapping ourselves.\n            if len(located_eids) > 0:\n                frontier = transforms.remove_edges(\n                    frontier, located_eids, store_ids=True\n                )\n                if (\n                    weights is not None\n                    and weights[0].shape[0] == frontier.num_edges()\n                ):\n                    weights[0] = F.gather_row(weights[0], frontier.edata[EID])\n                frontier.edata[EID] = F.gather_row(\n                    parent_eids, frontier.edata[EID]\n                )\n        else:\n            # (BarclayII) remove_edges only accepts removing one type of edges,\n            # so I need to keep track of the edge IDs left one by one.\n            new_eids = parent_eids.copy()\n            for i, (k, v) in enumerate(located_eids.items()):\n                if len(v) > 0:\n                    frontier = transforms.remove_edges(\n                        frontier, v, etype=k, store_ids=True\n                    )\n                    new_eids[k] = F.gather_row(\n                        parent_eids[k], frontier.edges[k].data[EID]\n                   
 )\n                    if weights is not None and weights[i].shape[\n                        0\n                    ] == frontier.num_edges(k):\n                        weights[i] = F.gather_row(\n                            weights[i], frontier.edges[k].data[EID]\n                        )\n            frontier.edata[EID] = new_eids\n        return frontier if weights is None else (frontier, weights)\n"
  },
  {
    "path": "python/dgl/sparse/__init__.py",
    "content": "\"\"\"dgl sparse class.\"\"\"\nimport os\nimport sys\n\nimport torch\n\nfrom .._ffi import libinfo\nfrom .broadcast import *\nfrom .elementwise_op import *\nfrom .elementwise_op_sp import *\nfrom .matmul import *\nfrom .reduction import *  # pylint: disable=W0622\nfrom .sddmm import *\nfrom .softmax import *\nfrom .sparse_matrix import *\nfrom .unary_op import *\n\n\ndef load_dgl_sparse():\n    \"\"\"Load DGL C++ sparse library\"\"\"\n    version = torch.__version__.split(\"+\", maxsplit=1)[0]\n\n    if sys.platform.startswith(\"linux\"):\n        basename = f\"libdgl_sparse_pytorch_{version}.so\"\n    elif sys.platform.startswith(\"darwin\"):\n        basename = f\"libdgl_sparse_pytorch_{version}.dylib\"\n    elif sys.platform.startswith(\"win\"):\n        basename = f\"dgl_sparse_pytorch_{version}.dll\"\n    else:\n        raise NotImplementedError(\"Unsupported system: %s\" % sys.platform)\n\n    dirname = os.path.dirname(libinfo.find_lib_path()[0])\n    path = os.path.join(dirname, \"dgl_sparse\", basename)\n    if not os.path.exists(path):\n        raise FileNotFoundError(f\"Cannot find DGL C++ sparse library at {path}\")\n\n    try:\n        torch.classes.load_library(path)\n    except Exception:  # pylint: disable=W0703\n        raise ImportError(\"Cannot load DGL C++ sparse library\")\n\n\nload_dgl_sparse()\n"
  },
  {
    "path": "python/dgl/sparse/broadcast.py",
    "content": "\"\"\"DGL broadcast operator module.\"\"\"\n\nimport operator\n\nimport torch\n\nfrom .sparse_matrix import SparseMatrix, val_like\n\n\ndef sp_broadcast_v(A: SparseMatrix, v: torch.Tensor, op: str) -> SparseMatrix:\n    \"\"\"Broadcast operator for sparse matrix and vector.\n\n    :attr:`v` is broadcasted to the shape of :attr:`A` and then the operator is\n    applied on the non-zero values of :attr:`A`.\n\n    There are two cases regarding the shape of v:\n\n    1. :attr:`v` is a vector of shape ``(1, A.shape[1])`` or ``(A.shape[1])``.\n    In this case, :attr:`v` is broadcasted on the row dimension of :attr:`A`.\n\n    2. :attr:`v` is a vector of shape ``(A.shape[0], 1)``. In this case,\n    :attr:`v` is broadcasted on the column dimension of :attr:`A`.\n\n    If ``A.val`` takes shape ``(nnz, D)``, then :attr:`v` will be broadcasted on\n    the ``D`` dimension.\n\n    Parameters\n    ----------\n    A: SparseMatrix\n        Sparse matrix\n    v: torch.Tensor\n        Vector\n    op: str\n        Operator in [\"add\", \"sub\", \"mul\", \"truediv\"]\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])\n    >>> val = torch.tensor([10, 20, 30])\n    >>> A = dglsp.spmatrix(indices, val, shape=(3, 4))\n    >>> v = torch.tensor([1, 2, 3, 4])\n    >>> dglsp.sp_broadcast_v(A, v, \"add\")\n    SparseMatrix(indices=tensor([[1, 0, 2],\n                                 [0, 3, 2]]),\n                 values=tensor([11, 24, 33]),\n                 shape=(3, 4), nnz=3)\n\n    >>> v = torch.tensor([1, 2, 3]).view(-1, 1)\n    >>> dglsp.sp_broadcast_v(A, v, \"add\")\n    SparseMatrix(indices=tensor([[1, 0, 2],\n                                 [0, 3, 2]]),\n                 values=tensor([12, 21, 33]),\n                 shape=(3, 4), nnz=3)\n\n    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])\n    >>> val = torch.tensor([[10, 20], [30, 40], [50, 60]])\n    >>> 
A = dglsp.spmatrix(indices, val, shape=(3, 4))\n    >>> v = torch.tensor([1, 2, 3]).view(-1, 1)\n    >>> dglsp.sp_broadcast_v(A, v, \"sub\")\n    SparseMatrix(indices=tensor([[1, 0, 2],\n                                 [0, 3, 2]]),\n                 values=tensor([[ 8, 18],\n                                [29, 39],\n                                [47, 57]]),\n                 shape=(3, 4), nnz=3, val_size=(2,))\n    \"\"\"\n    op = getattr(operator, op)\n    if v.dim() == 1:\n        v = v.view(1, -1)\n\n    shape_error_message = (\n        f\"Dimension mismatch for broadcasting. Got A.shape = {A.shape} and\"\n        f\"v.shape = {v.shape}.\"\n    )\n    assert v.dim() <= 2 and (1 in v.shape), shape_error_message\n    broadcast_dim = None\n    # v can be broadcasted to A if exactly one dimension of v is 1 and the other\n    # is the same as A.\n    for d, (dim1, dim2) in enumerate(zip(A.shape, v.shape)):\n        assert dim2 in (1, dim1), shape_error_message\n        if dim1 != dim2:\n            assert broadcast_dim is None, shape_error_message\n            broadcast_dim = d\n\n    # A and v have the same shape of (1, *) or (*, 1).\n    if broadcast_dim is None:\n        broadcast_dim = 0 if A.shape[0] == 1 else 1\n\n    if broadcast_dim == 0:\n        v = v.view(-1)[A.col]\n    else:\n        v = v.view(-1)[A.row]\n    if A.val.dim() > 1:\n        v = v.view(-1, 1)\n    ret_val = op(A.val, v)\n    return val_like(A, ret_val)\n\n\ndef sp_add_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:\n    \"\"\"Broadcast addition for sparse matrix and vector.\n\n    See the definition of :func:`sp_broadcast_v` for details.\n    \"\"\"\n    return sp_broadcast_v(A, v, \"add\")\n\n\ndef sp_sub_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:\n    \"\"\"Broadcast subtraction for sparse matrix and vector.\n\n    See the definition of :func:`sp_broadcast_v` for details.\n    \"\"\"\n    return sp_broadcast_v(A, v, \"sub\")\n\n\ndef sp_mul_v(A: SparseMatrix, v: 
torch.Tensor) -> SparseMatrix:\n    \"\"\"Broadcast multiply for sparse matrix and vector.\n\n    See the definition of :func:`sp_broadcast_v` for details.\n    \"\"\"\n    return sp_broadcast_v(A, v, \"mul\")\n\n\ndef sp_div_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:\n    \"\"\"Broadcast division for sparse matrix and vector.\n\n    See the definition of :func:`sp_broadcast_v` for details.\n    \"\"\"\n    return sp_broadcast_v(A, v, \"truediv\")\n"
  },
  {
    "path": "python/dgl/sparse/elementwise_op.py",
    "content": "# pylint: disable=anomalous-backslash-in-string\n\"\"\"DGL elementwise operator module.\"\"\"\nfrom typing import Union\n\nfrom .sparse_matrix import SparseMatrix\nfrom .utils import Scalar\n\n__all__ = [\"add\", \"sub\", \"mul\", \"div\", \"power\"]\n\n\ndef add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:\n    r\"\"\"Elementwise addition for ``SparseMatrix``, equivalent to ``A + B``.\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix\n    B : SparseMatrix\n        Sparse matrix\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n    >>> indices = torch.tensor([[1, 0, 2], [0, 1, 2]])\n    >>> val = torch.tensor([10, 20, 30])\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> B = dglsp.diag(torch.arange(1, 4))\n    >>> dglsp.add(A, B)\n    SparseMatrix(indices=tensor([[0, 0, 1, 1, 2],\n                                 [0, 1, 0, 1, 2]]),\n                 values=tensor([1, 20, 10,  2, 33]),\n                 shape=(3, 3), nnz=5)\n    \"\"\"\n    return A + B\n\n\ndef sub(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:\n    r\"\"\"Elementwise subtraction for ``SparseMatrix``, equivalent to ``A - B``.\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix\n    B : SparseMatrix\n        Sparse matrix\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n    >>> indices = torch.tensor([[1, 0, 2], [0, 1, 2]])\n    >>> val = torch.tensor([10, 20, 30])\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> B = dglsp.diag(torch.arange(1, 4))\n    >>> dglsp.sub(A, B)\n    SparseMatrix(indices=tensor([[0, 0, 1, 1, 2],\n                                 [0, 1, 0, 1, 2]]),\n                 values=tensor([-1, 20, 10, -2, 27]),\n                 shape=(3, 3), nnz=5)\n    \"\"\"\n    return A - B\n\n\ndef mul(\n    A: Union[SparseMatrix, Scalar], B: Union[SparseMatrix, Scalar]\n) -> SparseMatrix:\n    
r\"\"\"Elementwise multiplication for ``SparseMatrix``, equivalent to\n    ``A * B``.\n\n    If both :attr:`A` and :attr:`B` are sparse matrices, both of them should be\n    diagonal matrices.\n\n    Parameters\n    ----------\n    A : SparseMatrix or Scalar\n        Sparse matrix or scalar value\n    B : SparseMatrix or Scalar\n        Sparse matrix or scalar value\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])\n    >>> val = torch.tensor([10, 20, 30])\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> dglsp.mul(A, 2)\n    SparseMatrix(indices=tensor([[1, 0, 2],\n                                 [0, 3, 2]]),\n                 values=tensor([20, 40, 60]),\n                 shape=(3, 4), nnz=3)\n\n    >>> D = dglsp.diag(torch.arange(1, 4))\n    >>> dglsp.mul(D, 2)\n    SparseMatrix(indices=tensor([[0, 1, 2],\n                                 [0, 1, 2]]),\n                 values=tensor([2, 4, 6]),\n                 shape=(3, 3), nnz=3)\n\n    >>> D = dglsp.diag(torch.arange(1, 4))\n    >>> dglsp.mul(D, D)\n    SparseMatrix(indices=tensor([[0, 1, 2],\n                                 [0, 1, 2]]),\n                 values=tensor([1, 4, 9]),\n                 shape=(3, 3), nnz=3)\n    \"\"\"\n    return A * B\n\n\ndef div(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:\n    r\"\"\"Elementwise division for ``SparseMatrix``, equivalent to ``A / B``.\n\n    If both :attr:`A` and :attr:`B` are sparse matrices, both of them should be\n    diagonal matrices.\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix\n    B : SparseMatrix or Scalar\n        Sparse matrix or scalar value\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n    >>> A = dglsp.diag(torch.arange(1, 4))\n    >>> B = dglsp.diag(torch.arange(10, 13))\n    >>> dglsp.div(A, B)\n    
SparseMatrix(indices=tensor([[0, 1, 2],\n                                 [0, 1, 2]]),\n                 values=tensor([0.1000, 0.1818, 0.2500]),\n                 shape=(3, 3), nnz=3)\n\n    >>> A = dglsp.diag(torch.arange(1, 4))\n    >>> dglsp.div(A, 2)\n    SparseMatrix(indices=tensor([[0, 1, 2],\n                                 [0, 1, 2]]),\n                 values=tensor([0.5000, 1.0000, 1.5000]),\n                 shape=(3, 3), nnz=3)\n\n    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])\n    >>> val = torch.tensor([1, 2, 3])\n    >>> A = dglsp.spmatrix(indices, val, shape=(3, 4))\n    >>> dglsp.div(A, 2)\n    SparseMatrix(indices=tensor([[1, 0, 2],\n                                 [0, 3, 2]]),\n                 values=tensor([0.5000, 1.0000, 1.5000]),\n                 shape=(3, 4), nnz=3)\n    \"\"\"\n    return A / B\n\n\ndef power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:\n    r\"\"\"Elementwise exponentiation for ``SparseMatrix``, equivalent to\n    ``A ** scalar``.\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix\n    scalar : Scalar\n        Exponent\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])\n    >>> val = torch.tensor([10, 20, 30])\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> dglsp.power(A, 2)\n    SparseMatrix(indices=tensor([[1, 0, 2],\n                                 [0, 3, 2]]),\n                 values=tensor([100, 400, 900]),\n                 shape=(3, 4), nnz=3)\n\n    >>> D = dglsp.diag(torch.arange(1, 4))\n    >>> dglsp.power(D, 2)\n    SparseMatrix(indices=tensor([[0, 1, 2],\n                                 [0, 1, 2]]),\n                 values=tensor([1, 4, 9]),\n                 shape=(3, 3), nnz=3)\n    \"\"\"\n    return A**scalar\n"
  },
  {
    "path": "python/dgl/sparse/elementwise_op_sp.py",
    "content": "\"\"\"DGL elementwise operators for sparse matrix module.\"\"\"\nfrom typing import Union\n\nimport torch\n\nfrom .sparse_matrix import SparseMatrix, val_like\nfrom .utils import is_scalar, Scalar\n\n\ndef spsp_add(A, B):\n    \"\"\"Invoke C++ sparse library for addition\"\"\"\n    return SparseMatrix(\n        torch.ops.dgl_sparse.spsp_add(A.c_sparse_matrix, B.c_sparse_matrix)\n    )\n\n\ndef spsp_mul(A, B):\n    \"\"\"Invoke C++ sparse library for multiplication\"\"\"\n    return SparseMatrix(\n        torch.ops.dgl_sparse.spsp_mul(A.c_sparse_matrix, B.c_sparse_matrix)\n    )\n\n\ndef spsp_div(A, B):\n    \"\"\"Invoke C++ sparse library for division\"\"\"\n    return SparseMatrix(\n        torch.ops.dgl_sparse.spsp_div(A.c_sparse_matrix, B.c_sparse_matrix)\n    )\n\n\ndef sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:\n    \"\"\"Elementwise addition\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix\n    B : SparseMatrix\n        Sparse matrix\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])\n    >>> val = torch.tensor([10, 20, 30])\n    >>> A = dglsp.spmatrix(indices, val, shape=(3, 4))\n    >>> A + A\n    SparseMatrix(indices=tensor([[0, 1, 2],\n                                 [3, 0, 2]]),\n                 values=tensor([40, 20, 60]),\n                 shape=(3, 4), nnz=3)\n    \"\"\"\n    # Python falls back to B.__radd__ then TypeError when NotImplemented is\n    # returned.\n    return spsp_add(A, B) if isinstance(B, SparseMatrix) else NotImplemented\n\n\ndef sp_sub(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:\n    \"\"\"Elementwise subtraction\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix\n    B : SparseMatrix\n        Sparse matrix\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n\n    >>> indices = 
torch.tensor([[1, 0, 2], [0, 3, 2]])\n    >>> val = torch.tensor([10, 20, 30])\n    >>> val2 = torch.tensor([5, 10, 15])\n    >>> A = dglsp.spmatrix(indices, val, shape=(3, 4))\n    >>> B = dglsp.spmatrix(indices, val2, shape=(3, 4))\n    >>> A - B\n    SparseMatrix(indices=tensor([[0, 1, 2],\n                                 [3, 0, 2]]),\n                 values=tensor([10, 5, 15]),\n                 shape=(3, 4), nnz=3)\n    \"\"\"\n    # Python falls back to B.__rsub__ then TypeError when NotImplemented is\n    # returned.\n    return spsp_add(A, -B) if isinstance(B, SparseMatrix) else NotImplemented\n\n\ndef sp_mul(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:\n    \"\"\"Elementwise multiplication\n\n    Note that if both :attr:`A` and :attr:`B` are sparse matrices, both of them\n    need to be diagonal or on CPU.\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        First operand\n    B : SparseMatrix or Scalar\n        Second operand\n\n    Returns\n    -------\n    SparseMatrix\n        Result of A * B\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])\n    >>> val = torch.tensor([1, 2, 3])\n    >>> A = dglsp.spmatrix(indices, val, shape=(3, 4))\n\n    >>> A * 2\n    SparseMatrix(indices=tensor([[1, 0, 2],\n                                 [0, 3, 2]]),\n                 values=tensor([2, 4, 6]),\n                 shape=(3, 4), nnz=3)\n\n    >>> 2 * A\n    SparseMatrix(indices=tensor([[1, 0, 2],\n                                 [0, 3, 2]]),\n                 values=tensor([2, 4, 6]),\n                 shape=(3, 4), nnz=3)\n\n    >>> indices2 = torch.tensor([[2, 0, 1], [0, 3, 2]])\n    >>> val2 = torch.tensor([3, 2, 1])\n    >>> B = dglsp.spmatrix(indices2, val2, shape=(3, 4))\n    >>> A * B\n    SparseMatrix(indices=tensor([[0],\n                                 [3]]),\n                 values=tensor([4]),\n                 shape=(3, 4), nnz=1)\n    \"\"\"\n    if is_scalar(B):\n        
return val_like(A, A.val * B)\n    return spsp_mul(A, B)\n\n\ndef sp_div(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:\n    \"\"\"Elementwise division\n\n    If :attr:`B` is a sparse matrix, both :attr:`A` and :attr:`B` must have the\n    same sparsity. And the returned matrix has the same order of non-zero\n    entries as :attr:`A`.\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        First operand\n    B : SparseMatrix or Scalar\n        Second operand\n\n    Returns\n    -------\n    SparseMatrix\n        Result of A / B\n\n    Examples\n    --------\n    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])\n    >>> val = torch.tensor([1, 2, 3])\n    >>> A = dglsp.spmatrix(indices, val, shape=(3, 4))\n    >>> A / 2\n    SparseMatrix(indices=tensor([[1, 0, 2],\n                                 [0, 3, 2]]),\n                 values=tensor([0.5000, 1.0000, 1.5000]),\n                 shape=(3, 4), nnz=3)\n    \"\"\"\n    if is_scalar(B):\n        return val_like(A, A.val / B)\n    return spsp_div(A, B)\n\n\ndef sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:\n    \"\"\"Take the power of each nonzero element and return a sparse matrix with\n    the result.\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix\n    scalar : float or int\n        Exponent\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])\n    >>> val = torch.tensor([10, 20, 30])\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> A ** 2\n    SparseMatrix(indices=tensor([[1, 0, 2],\n                                 [0, 3, 2]]),\n                 values=tensor([100, 400, 900]),\n                 shape=(3, 4), nnz=3)\n    \"\"\"\n    # Python falls back to scalar.__rpow__ then TypeError when NotImplemented\n    # is returned.\n    return val_like(A, A.val**scalar) if is_scalar(scalar) else NotImplemented\n\n\nSparseMatrix.__add__ = sp_add\nSparseMatrix.__sub__ = 
sp_sub\nSparseMatrix.__mul__ = sp_mul\nSparseMatrix.__rmul__ = sp_mul\nSparseMatrix.__truediv__ = sp_div\nSparseMatrix.__pow__ = sp_power\n"
  },
  {
    "path": "python/dgl/sparse/matmul.py",
    "content": "\"\"\"Matmul ops for SparseMatrix\"\"\"\n# pylint: disable=invalid-name\nfrom typing import Union\n\nimport torch\n\nfrom .sparse_matrix import SparseMatrix\n\n__all__ = [\"spmm\", \"bspmm\", \"spspmm\", \"matmul\"]\n\n\ndef spmm(A: SparseMatrix, X: torch.Tensor) -> torch.Tensor:\n    \"\"\"Multiplies a sparse matrix by a dense matrix, equivalent to ``A @ X``.\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix of shape ``(L, M)`` with scalar values\n    X : torch.Tensor\n        Dense matrix of shape ``(M, N)`` or ``(M)``\n\n    Returns\n    -------\n    torch.Tensor\n        The dense matrix of shape ``(L, N)`` or ``(L)``\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]])\n    >>> val = torch.randn(indices.shape[1])\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> X = torch.randn(2, 3)\n    >>> result = dglsp.spmm(A, X)\n    >>> type(result)\n    <class 'torch.Tensor'>\n    >>> result.shape\n    torch.Size([2, 3])\n    \"\"\"\n    assert isinstance(\n        A, SparseMatrix\n    ), f\"Expect arg1 to be a SparseMatrix object, got {type(A)}.\"\n    assert isinstance(\n        X, torch.Tensor\n    ), f\"Expect arg2 to be a torch.Tensor, got {type(X)}.\"\n\n    return torch.ops.dgl_sparse.spmm(A.c_sparse_matrix, X)\n\n\ndef bspmm(A: SparseMatrix, X: torch.Tensor) -> torch.Tensor:\n    \"\"\"Multiplies a sparse matrix by a dense matrix by batches, equivalent to\n    ``A @ X``.\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix of shape ``(L, M)`` with vector values of length ``K``\n    X : torch.Tensor\n        Dense matrix of shape ``(M, N, K)``\n\n    Returns\n    -------\n    torch.Tensor\n        Dense matrix of shape ``(L, N, K)``\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 2]])\n    >>> val = torch.randn(indices.shape[1], 2)\n    >>> A = dglsp.spmatrix(indices, val, shape=(3, 3))\n    >>> X = torch.randn(3, 3, 2)\n    >>> 
result = dglsp.bspmm(A, X)\n    >>> type(result)\n    <class 'torch.Tensor'>\n    >>> result.shape\n    torch.Size([3, 3, 2])\n    \"\"\"\n    assert isinstance(\n        A, SparseMatrix\n    ), f\"Expect arg1 to be a SparseMatrix object, got {type(A)}.\"\n    assert isinstance(\n        X, torch.Tensor\n    ), f\"Expect arg2 to be a torch.Tensor, got {type(X)}.\"\n    return spmm(A, X)\n\n\ndef spspmm(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:\n    \"\"\"Multiplies a sparse matrix by a sparse matrix, equivalent to ``A @ B``.\n\n    The non-zero values of the two sparse matrices must be 1D.\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix of shape ``(L, M)``\n    B : SparseMatrix\n        Sparse matrix of shape ``(M, N)``\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix of shape ``(L, N)``.\n\n    Examples\n    --------\n\n    >>> indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]])\n    >>> val1 = torch.ones(indices1.shape[1])\n    >>> A = dglsp.spmatrix(indices1, val1)\n    >>> indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]])\n    >>> val2 = torch.ones(indices2.shape[1])\n    >>> B = dglsp.spmatrix(indices2, val2)\n    >>> dglsp.spspmm(A, B)\n    SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],\n                                 [1, 2, 0, 1, 2]]),\n                 values=tensor([1., 1., 1., 1., 1.]),\n                 shape=(2, 3), nnz=5)\n    \"\"\"\n    assert isinstance(\n        A, SparseMatrix\n    ), f\"Expect A1 to be a SparseMatrix object, got {type(A)}.\"\n    assert isinstance(\n        B, SparseMatrix\n    ), f\"Expect A2 to be a SparseMatrix object, got {type(B)}.\"\n\n    return SparseMatrix(\n        torch.ops.dgl_sparse.spspmm(A.c_sparse_matrix, B.c_sparse_matrix)\n    )\n\n\ndef matmul(\n    A: Union[torch.Tensor, SparseMatrix], B: Union[torch.Tensor, SparseMatrix]\n) -> Union[torch.Tensor, SparseMatrix]:\n    \"\"\"Multiplies two dense/sparse matrices, equivalent to ``A @ B``.\n\n    This function does not support 
the case where :attr:`A` is a \\\n    ``torch.Tensor`` and :attr:`B` is a ``SparseMatrix``.\n\n    * If both matrices are torch.Tensor, it calls \\\n        :func:`torch.matmul()`. The result is a dense matrix.\n\n    * If both matrices are sparse, it calls :func:`dgl.sparse.spspmm`. The \\\n        result is a sparse matrix.\n\n    * If :attr:`A` is sparse while :attr:`B` is dense, it calls \\\n        :func:`dgl.sparse.spmm`. The result is a dense matrix.\n\n    * The operator supports batched sparse-dense matrix multiplication. In \\\n        this case, the sparse matrix :attr:`A` should have shape ``(L, M)``, \\\n        where the non-zero values have a batch dimension ``K``. The dense \\\n        matrix :attr:`B` should have shape ``(M, N, K)``. The output \\\n        is a dense matrix of shape ``(L, N, K)``.\n\n    * Sparse-sparse matrix multiplication does not support batched computation.\n\n    Parameters\n    ----------\n    A : torch.Tensor or SparseMatrix\n        The first matrix.\n    B : torch.Tensor or SparseMatrix\n        The second matrix.\n\n    Returns\n    -------\n    torch.Tensor or SparseMatrix\n        The result matrix\n\n    Examples\n    --------\n\n    Multiplies a diagonal matrix with a dense matrix.\n\n    >>> val = torch.randn(3)\n    >>> A = dglsp.diag(val)\n    >>> B = torch.randn(3, 2)\n    >>> result = dglsp.matmul(A, B)\n    >>> type(result)\n    <class 'torch.Tensor'>\n    >>> result.shape\n    torch.Size([3, 2])\n\n    Multiplies a sparse matrix with a dense matrix.\n\n    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]])\n    >>> val = torch.randn(indices.shape[1])\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> X = torch.randn(2, 3)\n    >>> result = dglsp.matmul(A, X)\n    >>> type(result)\n    <class 'torch.Tensor'>\n    >>> result.shape\n    torch.Size([2, 3])\n\n    Multiplies a sparse matrix with a sparse matrix.\n\n    >>> indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]])\n    >>> val1 = 
torch.ones(indices1.shape[1])\n    >>> A = dglsp.spmatrix(indices1, val1)\n    >>> indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]])\n    >>> val2 = torch.ones(indices2.shape[1])\n    >>> B = dglsp.spmatrix(indices2, val2)\n    >>> result = dglsp.matmul(A, B)\n    >>> type(result)\n    <class 'dgl.sparse.sparse_matrix.SparseMatrix'>\n    >>> result.shape\n    (2, 3)\n    \"\"\"\n    assert isinstance(\n        A, (torch.Tensor, SparseMatrix)\n    ), f\"Expect arg1 to be a torch.Tensor or SparseMatrix, got {type(A)}.\"\n    assert isinstance(B, (torch.Tensor, SparseMatrix)), (\n        f\"Expect arg2 to be a torch Tensor or SparseMatrix\"\n        f\"object, got {type(B)}.\"\n    )\n    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):\n        return torch.matmul(A, B)\n    assert not isinstance(A, torch.Tensor), (\n        f\"Expect arg2 to be a torch Tensor if arg 1 is torch Tensor, \"\n        f\"got {type(B)}.\"\n    )\n    if isinstance(B, torch.Tensor):\n        return spmm(A, B)\n    return spspmm(A, B)\n\n\nSparseMatrix.__matmul__ = matmul\n"
  },
  {
    "path": "python/dgl/sparse/reduction.py",
    "content": "\"\"\"DGL sparse matrix reduce operators\"\"\"\n# pylint: disable=W0622\n\nfrom typing import Optional\n\nimport torch\n\nfrom .sparse_matrix import SparseMatrix\n\n\ndef reduce(input: SparseMatrix, dim: Optional[int] = None, rtype: str = \"sum\"):\n    \"\"\"Computes the reduction of non-zero values of the :attr:`input` sparse\n    matrix along the given dimension :attr:`dim`.\n\n    The reduction does not count zero elements. If the row or column to be\n    reduced does not have any non-zero elements, the result will be 0.\n\n    Parameters\n    ----------\n    input : SparseMatrix\n        The input sparse matrix\n    dim : int, optional\n        The dimension to reduce, must be either 0 (by rows) or 1 (by columns)\n        or None (on both rows and columns simultaneously)\n\n        If :attr:`dim` is None, it reduces both the rows and the columns\n        in the sparse matrix, producing a tensor of shape\n        ``input.val.shape[1:]``. Otherwise, it reduces on the row (``dim=0``)\n        or column (``dim=1``) dimension, producing a tensor of shape\n        ``(input.shape[1],) + input.val.shape[1:]`` or\n        ``(input.shape[0],) + input.val.shape[1:]``.\n    rtype: str, optional\n        Reduction type, one of ``['sum', 'smin', 'smax', 'smean', 'sprod']``,\n        representing taking the sum, minimum, maximum, mean, and product of the\n        non-zero elements\n\n    Returns\n    ----------\n    torch.Tensor\n        Reduced tensor\n\n    Examples\n    ----------\n\n    Case1: scalar-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([1, 1, 2])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.reduce(A, rtype='sum')\n    tensor(4)\n    >>> dglsp.reduce(A, 0, 'sum')\n    tensor([2, 0, 2])\n    >>> dglsp.reduce(A, 1, 'sum')\n    tensor([1, 3, 0, 0])\n    >>> dglsp.reduce(A, 0, 'smax')\n    tensor([1, 0, 2])\n    >>> dglsp.reduce(A, 1, 'smin')\n    tensor([1, 1, 0, 
0])\n\n    Case2: vector-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([[1., 2.], [2., 1.], [2., 2.]])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.reduce(A, rtype='sum')\n    tensor([5., 5.])\n    >>> dglsp.reduce(A, 0, 'sum')\n    tensor([[3., 3.],\n            [0., 0.],\n            [2., 2.]])\n    >>> dglsp.reduce(A, 1, 'smin')\n    tensor([[1., 2.],\n            [2., 1.],\n            [0., 0.],\n            [0., 0.]])\n    >>> dglsp.reduce(A, 0, 'smean')\n    tensor([[1.5000, 1.5000],\n            [0.0000, 0.0000],\n            [2.0000, 2.0000]])\n    \"\"\"\n    return torch.ops.dgl_sparse.reduce(input.c_sparse_matrix, rtype, dim)\n\n\ndef sum(input: SparseMatrix, dim: Optional[int] = None):\n    \"\"\"Computes the sum of non-zero values of the :attr:`input` sparse matrix\n    along the given dimension :attr:`dim`.\n\n    Parameters\n    ----------\n    input : SparseMatrix\n        The input sparse matrix\n    dim : int, optional\n        The dimension to reduce, must be either 0 (by rows) or 1 (by columns)\n        or None (on both rows and columns simultaneously)\n\n        If :attr:`dim` is None, it reduces both the rows and the columns\n        in the sparse matrix, producing a tensor of shape\n        ``input.val.shape[1:]``. 
Otherwise, it reduces on the row (``dim=0``)\n        or column (``dim=1``) dimension, producing a tensor of shape\n        ``(input.shape[1],) + input.val.shape[1:]`` or\n        ``(input.shape[0],) + input.val.shape[1:]``.\n\n    Returns\n    ----------\n    torch.Tensor\n        Reduced tensor\n\n    Examples\n    ----------\n\n    Case1: scalar-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([1, 1, 2])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.sum(A)\n    tensor(4)\n    >>> dglsp.sum(A, 0)\n    tensor([2, 0, 2])\n    >>> dglsp.sum(A, 1)\n    tensor([1, 3, 0, 0])\n\n    Case2: vector-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([[1, 2], [2, 1], [2, 2]])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.sum(A)\n    tensor([5, 5])\n    >>> dglsp.sum(A, 0)\n    tensor([[3, 3],\n            [0, 0],\n            [2, 2]])\n    \"\"\"\n    return torch.ops.dgl_sparse.sum(input.c_sparse_matrix, dim)\n\n\ndef smax(input: SparseMatrix, dim: Optional[int] = None):\n    \"\"\"Computes the maximum of non-zero values of the :attr:`input` sparse\n    matrix along the given dimension :attr:`dim`.\n\n    The reduction does not count zero values. If the row or column to be\n    reduced does not have any non-zero value, the result will be 0.\n\n    Parameters\n    ----------\n    input : SparseMatrix\n        The input sparse matrix\n    dim : int, optional\n        The dimension to reduce, must be either 0 (by rows) or 1 (by columns)\n        or None (on both rows and columns simultaneously)\n\n        If :attr:`dim` is None, it reduces both the rows and the columns\n        in the sparse matrix, producing a tensor of shape\n        ``input.val.shape[1:]``. 
Otherwise, it reduces on the row (``dim=0``)\n        or column (``dim=1``) dimension, producing a tensor of shape\n        ``(input.shape[1],) + input.val.shape[1:]`` or\n        ``(input.shape[0],) + input.val.shape[1:]``.\n\n    Returns\n    ----------\n    torch.Tensor\n        Reduced tensor\n\n    Examples\n    ----------\n\n    Case1: scalar-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([1, 1, 2])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.smax(A)\n    tensor(2)\n    >>> dglsp.smax(A, 0)\n    tensor([1, 0, 2])\n    >>> dglsp.smax(A, 1)\n    tensor([1, 2, 0, 0])\n\n    Case2: vector-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([[1, 2], [2, 1], [2, 2]])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.smax(A)\n    tensor([2, 2])\n    >>> dglsp.smax(A, 1)\n    tensor([[1, 2],\n            [2, 2],\n            [0, 0],\n            [0, 0]])\n    \"\"\"\n    return torch.ops.dgl_sparse.smax(input.c_sparse_matrix, dim)\n\n\ndef smin(input: SparseMatrix, dim: Optional[int] = None):\n    \"\"\"Computes the minimum of non-zero values of the :attr:`input` sparse\n    matrix along the given dimension :attr:`dim`.\n\n    The reduction does not count zero values. If the row or column to be reduced\n    does not have any non-zero value, the result will be 0.\n\n    Parameters\n    ----------\n    input : SparseMatrix\n        The input sparse matrix\n    dim : int, optional\n        The dimension to reduce, must be either 0 (by rows) or 1 (by columns)\n        or None (on both rows and columns simultaneously)\n\n        If :attr:`dim` is None, it reduces both the rows and the columns\n        in the sparse matrix, producing a tensor of shape\n        ``input.val.shape[1:]``. 
Otherwise, it reduces on the row (``dim=0``)\n        or column (``dim=1``) dimension, producing a tensor of shape\n        ``(input.shape[1],) + input.val.shape[1:]`` or\n        ``(input.shape[0],) + input.val.shape[1:]``.\n\n    Returns\n    ----------\n    torch.Tensor\n        Reduced tensor\n\n    Examples\n    ----------\n\n    Case1: scalar-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([1, 1, 2])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.smin(A)\n    tensor(1)\n    >>> dglsp.smin(A, 0)\n    tensor([1, 0, 2])\n    >>> dglsp.smin(A, 1)\n    tensor([1, 1, 0, 0])\n\n    Case2: vector-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([[1, 2], [2, 1], [2, 2]])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.smin(A)\n    tensor([1, 1])\n    >>> dglsp.smin(A, 0)\n    tensor([[1, 1],\n            [0, 0],\n            [2, 2]])\n    >>> dglsp.smin(A, 1)\n    tensor([[1, 2],\n            [2, 1],\n            [0, 0],\n            [0, 0]])\n    \"\"\"\n    return torch.ops.dgl_sparse.smin(input.c_sparse_matrix, dim)\n\n\ndef smean(input: SparseMatrix, dim: Optional[int] = None):\n    \"\"\"Computes the mean of non-zero values of the :attr:`input` sparse matrix\n    along the given dimension :attr:`dim`.\n\n    The reduction does not count zero values. If the row or column to be reduced\n    does not have any non-zero value, the result will be 0.\n\n    Parameters\n    ----------\n    input : SparseMatrix\n        The input sparse matrix\n    dim : int, optional\n        The dimension to reduce, must be either 0 (by rows) or 1 (by columns)\n        or None (on both rows and columns simultaneously)\n\n        If :attr:`dim` is None, it reduces both the rows and the columns\n        in the sparse matrix, producing a tensor of shape\n        ``input.val.shape[1:]``. 
Otherwise, it reduces on the row (``dim=0``)\n        or column (``dim=1``) dimension, producing a tensor of shape\n        ``(input.shape[1],) + input.val.shape[1:]`` or\n        ``(input.shape[0],) + input.val.shape[1:]``.\n\n    Returns\n    ----------\n    torch.Tensor\n        Reduced tensor\n\n    Examples\n    ----------\n\n    Case1: scalar-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([1., 1., 2.])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.smean(A)\n    tensor(1.3333)\n    >>> dglsp.smean(A, 0)\n    tensor([1., 0., 2.])\n    >>> dglsp.smean(A, 1)\n    tensor([1.0000, 1.5000, 0.0000, 0.0000])\n\n    Case2: vector-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([[1., 2.], [2., 1.], [2., 2.]])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.smean(A)\n    tensor([1.6667, 1.6667])\n    >>> dglsp.smean(A, 0)\n    tensor([[1.5000, 1.5000],\n            [0.0000, 0.0000],\n            [2.0000, 2.0000]])\n    >>> dglsp.smean(A, 1)\n    tensor([[1.0000, 2.0000],\n            [2.0000, 1.5000],\n            [0.0000, 0.0000],\n            [0.0000, 0.0000]])\n    \"\"\"\n    return torch.ops.dgl_sparse.smean(input.c_sparse_matrix, dim)\n\n\ndef sprod(input: SparseMatrix, dim: Optional[int] = None):\n    \"\"\"Computes the product of non-zero values of the :attr:`input` sparse\n    matrix along the given dimension :attr:`dim`.\n\n    The reduction does not count zero values. 
If the row or column to be reduced\n    does not have any non-zero value, the result will be 0.\n\n    Parameters\n    ----------\n    input : SparseMatrix\n        The input sparse matrix\n    dim : int, optional\n        The dimension to reduce, must be either 0 (by rows) or 1 (by columns)\n        or None (on both rows and columns simultaneously)\n\n        If :attr:`dim` is None, it reduces both the rows and the columns\n        in the sparse matrix, producing a tensor of shape\n        ``input.val.shape[1:]``. Otherwise, it reduces on the row (``dim=0``)\n        or column (``dim=1``) dimension, producing a tensor of shape\n        ``(input.shape[1],) + input.val.shape[1:]`` or\n        ``(input.shape[0],) + input.val.shape[1:]``.\n\n    Returns\n    ----------\n    torch.Tensor\n        Reduced tensor\n\n    Examples\n    ----------\n\n    Case1: scalar-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([1, 1, 2])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.sprod(A)\n    tensor(2)\n    >>> dglsp.sprod(A, 0)\n    tensor([1, 0, 2])\n    >>> dglsp.sprod(A, 1)\n    tensor([1, 2, 0, 0])\n\n    Case2: vector-valued sparse matrix\n\n    >>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])\n    >>> val = torch.tensor([[1, 2], [2, 1], [2, 2]])\n    >>> A = dglsp.spmatrix(indices, val, shape=(4, 3))\n    >>> dglsp.sprod(A)\n    tensor([4, 4])\n    >>> dglsp.sprod(A, 0)\n    tensor([[2, 2],\n            [0, 0],\n            [2, 2]])\n    >>> dglsp.sprod(A, 1)\n    tensor([[1, 2],\n            [4, 2],\n            [0, 0],\n            [0, 0]])\n    \"\"\"\n    return torch.ops.dgl_sparse.sprod(input.c_sparse_matrix, dim)\n\n\nSparseMatrix.reduce = reduce\nSparseMatrix.sum = sum\nSparseMatrix.smax = smax\nSparseMatrix.smin = smin\nSparseMatrix.smean = smean\nSparseMatrix.sprod = sprod\n"
  },
  {
    "path": "python/dgl/sparse/sddmm.py",
    "content": "\"\"\"Sampled Dense-Dense Matrix Multiplication (SDDMM) operator module.\"\"\"\nimport torch\n\nfrom .sparse_matrix import SparseMatrix\n\n__all__ = [\"sddmm\", \"bsddmm\"]\n\n\n# pylint: disable=invalid-name\ndef sddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix:\n    r\"\"\"Sampled-Dense-Dense Matrix Multiplication (SDDMM).\n\n    ``sddmm`` matrix-multiplies two dense matrices :attr:`X1` and :attr:`X2`,\n    then elementwise-multiplies the result with sparse matrix :attr:`A` at the\n    nonzero locations.\n\n    Mathematically ``sddmm`` is formulated as:\n\n    .. math::\n        out = (X1 @ X2) * A\n\n    In particular, :attr:`X1` and :attr:`X2` can be 1-D, then ``X1 @ X2``\n    becomes the outer product of the two vectors (which results in a matrix).\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix of shape ``(L, N)``\n    X1 : torch.Tensor\n        Dense matrix of shape ``(L, M)`` or ``(L,)``\n    X2 : torch.Tensor\n        Dense matrix of shape ``(M, N)`` or ``(N,)``\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix of shape ``(L, N)``\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[1, 1, 2], [2, 3, 3]])\n    >>> val = torch.arange(1, 4).float()\n    >>> A = dglsp.spmatrix(indices, val, (3, 4))\n    >>> X1 = torch.randn(3, 5)\n    >>> X2 = torch.randn(5, 4)\n    >>> dglsp.sddmm(A, X1, X2)\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 3, 3]]),\n                 values=tensor([-1.6585, -3.9714, -0.5406]),\n                 shape=(3, 4), nnz=3)\n    \"\"\"\n    return SparseMatrix(torch.ops.dgl_sparse.sddmm(A.c_sparse_matrix, X1, X2))\n\n\n# pylint: disable=invalid-name\ndef bsddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix:\n    r\"\"\"Sampled-Dense-Dense Matrix Multiplication (SDDMM) by batches.\n\n    ``sddmm`` matrix-multiplies two dense matrices :attr:`X1` and :attr:`X2`,\n    then 
elementwise-multiplies the result with sparse matrix :attr:`A` at the\n    nonzero locations.\n\n    Mathematically ``sddmm`` is formulated as:\n\n    .. math::\n        out = (X1 @ X2) * A\n\n    The batch dimension is the last dimension for input dense matrices. In\n    particular, if the sparse matrix has scalar non-zero values, it will be\n    broadcasted for bsddmm.\n\n    Parameters\n    ----------\n    A : SparseMatrix\n        Sparse matrix of shape ``(L, N)`` with scalar values or vector values of\n        length ``K``\n    X1 : Tensor\n        Dense matrix of shape ``(L, M, K)``\n    X2 : Tensor\n        Dense matrix of shape ``(M, N, K)``\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix of shape ``(L, N)`` with vector values of length ``K``\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[1, 1, 2], [2, 3, 3]])\n    >>> val = torch.arange(1, 4).float()\n    >>> A = dglsp.spmatrix(indices, val, (3, 4))\n    >>> X1 = torch.arange(0, 3 * 5 * 2).view(3, 5, 2).float()\n    >>> X2 = torch.arange(0, 5 * 4 * 2).view(5, 4, 2).float()\n    >>> dglsp.bsddmm(A, X1, X2)\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 3, 3]]),\n                 values=tensor([[1560., 1735.],\n                                [3400., 3770.],\n                                [8400., 9105.]]),\n                 shape=(3, 4), nnz=3, val_size=(2,))\n    \"\"\"\n    return sddmm(A, X1, X2)\n"
  },
  {
    "path": "python/dgl/sparse/softmax.py",
    "content": "\"\"\"Softmax op for SparseMatrix\"\"\"\n# pylint: disable=invalid-name, W0622\n\nimport torch\n\nfrom .sparse_matrix import SparseMatrix\n\n__all__ = [\"softmax\"]\n\n\ndef softmax(input: SparseMatrix, dim: int = 1) -> SparseMatrix:\n    \"\"\"Applies softmax to the non-zero elements of the sparse matrix on the\n    dimension :attr:`dim`. dim = 0 or 1 indicates column-wise or row-wise\n    softmax respectively.\n\n    If :attr:`input.val` takes shape ``(nnz, D)``, then the output matrix\n    :attr:`output` and :attr:`output.val` take the same shape as :attr:`input`\n    and :attr:`input.val`. :attr:`output.val[:, i]` is calculated based on\n    :attr:`input.val[:, i]`.\n\n    Parameters\n    ----------\n    input : SparseMatrix\n        The input sparse matrix\n    dim : int, optional\n        The dimension to apply softmax on: 0 for column-wise, 1 for row-wise.\n        Default: 1\n\n    Returns\n    -------\n    SparseMatrix\n        The output sparse matrix\n\n    Examples\n    --------\n\n    Case1: row-wise softmax on matrix with values of shape (nnz)\n\n    >>> indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])\n    >>> val = torch.tensor([0., 1., 2., 3.])\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> dglsp.softmax(A)\n    SparseMatrix(indices=tensor([[0, 0, 1, 2],\n                                 [1, 2, 2, 0]]),\n                 values=tensor([0.2689, 0.7311, 1.0000, 1.0000]),\n                 shape=(3, 3), nnz=4)\n\n    Case2: row-wise softmax on matrix with values of shape (nnz, D)\n\n    >>> indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])\n    >>> val = torch.tensor([[0., 7.], [1., 3.], [2., 2.], [3., 1.]])\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> dglsp.softmax(A)\n    SparseMatrix(indices=tensor([[0, 0, 1, 2],\n                                 [1, 2, 2, 0]]),\n                 values=tensor([[0.2689, 0.9820],\n                                [0.7311, 0.0180],\n                                [1.0000, 1.0000],\n                                [1.0000, 1.0000]]),\n                 shape=(3, 3), nnz=4, val_size=(2,))\n\n    Case3: 
column-wise softmax on matrix with values of shape (nnz)\n\n    >>> indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])\n    >>> val = torch.tensor([0., 1., 2., 3.])\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> dglsp.softmax(A, 0)\n    SparseMatrix(indices=tensor([[0, 0, 1, 2],\n                                 [1, 2, 2, 0]]),\n                 values=tensor([1.0000, 0.2689, 0.7311, 1.0000]),\n                 shape=(3, 3), nnz=4)\n    \"\"\"\n    return SparseMatrix(\n        torch.ops.dgl_sparse.softmax(input.c_sparse_matrix, dim)\n    )\n\n\nSparseMatrix.softmax = softmax\n"
  },
  {
    "path": "python/dgl/sparse/sparse_matrix.py",
    "content": "\"\"\"DGL sparse matrix module.\"\"\"\n# pylint: disable= invalid-name\nfrom typing import Optional, Tuple\n\nimport torch\n\n\nclass SparseMatrix:\n    r\"\"\"Class for sparse matrix.\"\"\"\n\n    def __init__(self, c_sparse_matrix: torch.ScriptObject):\n        self.c_sparse_matrix = c_sparse_matrix\n\n    def __repr__(self):\n        return _sparse_matrix_str(self)\n\n    @property\n    def val(self) -> torch.Tensor:\n        \"\"\"Returns the values of the non-zero elements.\n\n        Returns\n        -------\n        torch.Tensor\n            Values of the non-zero elements\n        \"\"\"\n        return self.c_sparse_matrix.val()\n\n    @property\n    def shape(self) -> Tuple[int]:\n        \"\"\"Returns the shape of the sparse matrix.\n\n        Returns\n        -------\n        Tuple[int]\n            The shape of the sparse matrix\n        \"\"\"\n        return tuple(self.c_sparse_matrix.shape())\n\n    @property\n    def nnz(self) -> int:\n        \"\"\"Returns the number of non-zero elements in the sparse matrix.\n\n        Returns\n        -------\n        int\n            The number of non-zero elements of the matrix\n        \"\"\"\n        return self.c_sparse_matrix.nnz()\n\n    @property\n    def dtype(self) -> torch.dtype:\n        \"\"\"Returns the data type of the sparse matrix.\n\n        Returns\n        -------\n        torch.dtype\n            Data type of the sparse matrix\n        \"\"\"\n        return self.c_sparse_matrix.val().dtype\n\n    @property\n    def device(self) -> torch.device:\n        \"\"\"Returns the device the sparse matrix is on.\n\n        Returns\n        -------\n        torch.device\n            The device the sparse matrix is on\n        \"\"\"\n        return self.c_sparse_matrix.device()\n\n    @property\n    def row(self) -> torch.Tensor:\n        \"\"\"Returns the row indices of the non-zero elements.\n\n        Returns\n        -------\n        torch.Tensor\n            Row indices of the 
non-zero elements\n        \"\"\"\n        return self.coo()[0]\n\n    @property\n    def col(self) -> torch.Tensor:\n        \"\"\"Returns the column indices of the non-zero elements.\n\n        Returns\n        -------\n        torch.Tensor\n            Column indices of the non-zero elements\n        \"\"\"\n        return self.coo()[1]\n\n    def coo(self) -> Tuple[torch.Tensor, torch.Tensor]:\n        r\"\"\"Returns the coordinate list (COO) representation of the sparse\n        matrix.\n\n        See `COO in Wikipedia <https://en.wikipedia.org/wiki/\n        Sparse_matrix#Coordinate_list_(COO)>`_.\n\n        Returns\n        -------\n        torch.Tensor\n            Row coordinate\n        torch.Tensor\n            Column coordinate\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])\n        >>> A = dglsp.spmatrix(indices)\n        >>> A.coo()\n        (tensor([1, 2, 1]), tensor([2, 4, 3]))\n        \"\"\"\n        return self.c_sparse_matrix.coo()\n\n    def indices(self) -> torch.Tensor:\n        r\"\"\"Returns the coordinate list (COO) representation in one tensor with\n        shape ``(2, nnz)``.\n\n        See `COO in Wikipedia <https://en.wikipedia.org/wiki/\n        Sparse_matrix#Coordinate_list_(COO)>`_.\n\n        Returns\n        -------\n        torch.Tensor\n            Stacked COO tensor with shape ``(2, nnz)``.\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])\n        >>> A = dglsp.spmatrix(indices)\n        >>> A.indices()\n        tensor([[1, 2, 1],\n                [2, 4, 3]])\n        \"\"\"\n        return self.c_sparse_matrix.indices()\n\n    def csr(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        r\"\"\"Returns the compressed sparse row (CSR) representation of the sparse\n        matrix.\n\n        See `CSR in Wikipedia <https://en.wikipedia.org/wiki/\n        Sparse_matrix#Compressed_sparse_row_(CSR, 
_CRS_or_Yale_format)>`_.\n\n        This function also returns value indices as an index tensor, indicating\n        the order of the values of non-zero elements in the CSR representation.\n        A ``None`` value indices array indicates the order of the values stays\n        the same as the values of the SparseMatrix.\n\n        Returns\n        -------\n        torch.Tensor\n            Row indptr\n        torch.Tensor\n            Column indices\n        torch.Tensor\n            Value indices\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])\n        >>> A = dglsp.spmatrix(indices)\n        >>> A.csr()\n        (tensor([0, 0, 2, 3]), tensor([2, 3, 4]), tensor([0, 2, 1]))\n        \"\"\"\n        return self.c_sparse_matrix.csr()\n\n    def csc(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        r\"\"\"Returns the compressed sparse column (CSC) representation of the\n        sparse matrix.\n\n        See `CSC in Wikipedia <https://en.wikipedia.org/wiki/\n        Sparse_matrix#Compressed_sparse_column_(CSC_or_CCS)>`_.\n\n        This function also returns value indices as an index tensor, indicating\n        the order of the values of non-zero elements in the CSC representation.\n        A ``None`` value indices array indicates the order of the values stays\n        the same as the values of the SparseMatrix.\n\n        Returns\n        -------\n        torch.Tensor\n            Column indptr\n        torch.Tensor\n            Row indices\n        torch.Tensor\n            Value indices\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])\n        >>> A = dglsp.spmatrix(indices)\n        >>> A.csc()\n        (tensor([0, 0, 0, 1, 2, 3]), tensor([1, 1, 2]), tensor([0, 2, 1]))\n        \"\"\"\n        return self.c_sparse_matrix.csc()\n\n    def to_dense(self) -> torch.Tensor:\n        \"\"\"Returns a copy in dense matrix format of the sparse matrix.\n\n  
      Returns\n        -------\n        torch.Tensor\n            The copy in dense matrix format\n        \"\"\"\n        row, col = self.coo()\n        val = self.val\n        shape = self.shape + val.shape[1:]\n        mat = torch.zeros(shape, device=self.device, dtype=self.dtype)\n        mat[row, col] = val\n        return mat\n\n    def t(self):\n        \"\"\"Alias of :meth:`transpose()`\"\"\"\n        return self.transpose()\n\n    @property\n    def T(self):  # pylint: disable=C0103\n        \"\"\"Alias of :meth:`transpose()`\"\"\"\n        return self.transpose()\n\n    def transpose(self):\n        \"\"\"Returns the transpose of this sparse matrix.\n\n        Returns\n        -------\n        SparseMatrix\n            The transpose of this sparse matrix.\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 1, 3], [2, 1, 3]])\n        >>> val = torch.tensor([1, 1, 2])\n        >>> A = dglsp.spmatrix(indices, val)\n        >>> A = A.transpose()\n        SparseMatrix(indices=tensor([[2, 1, 3],\n                                     [1, 1, 3]]),\n                     values=tensor([1, 1, 2]),\n                     shape=(4, 4), nnz=3)\n        \"\"\"\n        return SparseMatrix(self.c_sparse_matrix.transpose())\n\n    def to(self, device=None, dtype=None):\n        \"\"\"Performs matrix dtype and/or device conversion. 
If the target device\n        and dtype are already in use, the original matrix will be returned.\n\n        Parameters\n        ----------\n        device : torch.device, optional\n            The target device of the matrix if provided, otherwise the current\n            device will be used\n        dtype : torch.dtype, optional\n            The target data type of the matrix values if provided, otherwise the\n            current data type will be used\n\n        Returns\n        -------\n        SparseMatrix\n            The converted matrix\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 1, 2], [1, 2, 0]])\n        >>> A = dglsp.spmatrix(indices, shape=(3, 4))\n        >>> A.to(device=\"cuda:0\", dtype=torch.int32)\n        SparseMatrix(indices=tensor([[1, 1, 2],\n                                     [1, 2, 0]], device='cuda:0'),\n                     values=tensor([1, 1, 1], device='cuda:0',\n                                   dtype=torch.int32),\n                     shape=(3, 4), nnz=3)\n        \"\"\"\n        if device is None:\n            device = self.device\n        if dtype is None:\n            dtype = self.dtype\n\n        if device == self.device and dtype == self.dtype:\n            return self\n        elif device == self.device:\n            return val_like(self, self.val.to(dtype=dtype))\n        else:\n            # TODO(#5119): Find a better moving strategy instead of always\n            # convert to COO format.\n            row, col = self.coo()\n            row = row.to(device=device)\n            col = col.to(device=device)\n            val = self.val.to(device=device, dtype=dtype)\n            return from_coo(row, col, val, self.shape)\n\n    def cuda(self):\n        \"\"\"Moves the matrix to GPU. If the matrix is already on GPU, the\n        original matrix will be returned. 
If multiple GPU devices exist,\n        ``cuda:0`` will be selected.\n\n        Returns\n        -------\n        SparseMatrix\n            The matrix on GPU\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 1, 2], [1, 2, 0]])\n        >>> A = dglsp.spmatrix(indices, shape=(3, 4))\n        >>> A.cuda()\n        SparseMatrix(indices=tensor([[1, 1, 2],\n                                     [1, 2, 0]], device='cuda:0'),\n                     values=tensor([1., 1., 1.], device='cuda:0'),\n                     shape=(3, 4), nnz=3)\n        \"\"\"\n        return self.to(device=\"cuda\")\n\n    def cpu(self):\n        \"\"\"Moves the matrix to CPU. If the matrix is already on CPU, the\n        original matrix will be returned.\n\n        Returns\n        -------\n        SparseMatrix\n            The matrix on CPU\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 1, 2], [1, 2, 0]]).to(\"cuda\")\n        >>> A = dglsp.spmatrix(indices, shape=(3, 4))\n\n        >>> A.cpu()\n        SparseMatrix(indices=tensor([[1, 1, 2],\n                                     [1, 2, 0]]),\n                     values=tensor([1., 1., 1.]),\n                     shape=(3, 4), nnz=3)\n        \"\"\"\n        return self.to(device=\"cpu\")\n\n    def float(self):\n        \"\"\"Converts the matrix values to float32 data type. 
If the matrix\n        already uses float data type, the original matrix will be returned.\n\n        Returns\n        -------\n        SparseMatrix\n            The matrix with float values\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 1, 2], [1, 2, 0]])\n        >>> val = torch.ones(indices.shape[1]).long()\n        >>> A = dglsp.spmatrix(indices, val, shape=(3, 4))\n        >>> A.float()\n        SparseMatrix(indices=tensor([[1, 1, 2],\n                                     [1, 2, 0]]),\n                     values=tensor([1., 1., 1.]),\n                     shape=(3, 4), nnz=3)\n        \"\"\"\n        return self.to(dtype=torch.float)\n\n    def double(self):\n        \"\"\"Converts the matrix values to double data type. If the matrix already\n        uses double data type, the original matrix will be returned.\n\n        Returns\n        -------\n        SparseMatrix\n            The matrix with double values\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 1, 2], [1, 2, 0]])\n        >>> A = dglsp.spmatrix(indices, shape=(3, 4))\n        >>> A.double()\n        SparseMatrix(indices=tensor([[1, 1, 2],\n                                     [1, 2, 0]]),\n                     values=tensor([1., 1., 1.], dtype=torch.float64),\n                     shape=(3, 4), nnz=3)\n        \"\"\"\n        return self.to(dtype=torch.double)\n\n    def int(self):\n        \"\"\"Converts the matrix values to int32 data type. 
If the matrix already\n        uses int data type, the original matrix will be returned.\n\n        Returns\n        -------\n        SparseMatrix\n            The matrix with int values\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 1, 2], [1, 2, 0]])\n        >>> A = dglsp.spmatrix(indices, shape=(3, 4))\n        >>> A.int()\n        SparseMatrix(indices=tensor([[1, 1, 2],\n                                     [1, 2, 0]]),\n                     values=tensor([1, 1, 1], dtype=torch.int32),\n                     shape=(3, 4), nnz=3)\n        \"\"\"\n        return self.to(dtype=torch.int)\n\n    def long(self):\n        \"\"\"Converts the matrix values to long data type. If the matrix already\n        uses long data type, the original matrix will be returned.\n\n        Returns\n        -------\n        SparseMatrix\n            The matrix with long values\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[1, 1, 2], [1, 2, 0]])\n        >>> A = dglsp.spmatrix(indices, shape=(3, 4))\n        >>> A.long()\n        SparseMatrix(indices=tensor([[1, 1, 2],\n                                     [1, 2, 0]]),\n                     values=tensor([1, 1, 1]),\n                     shape=(3, 4), nnz=3)\n        \"\"\"\n        return self.to(dtype=torch.long)\n\n    def coalesce(self):\n        \"\"\"Returns a coalesced sparse matrix.\n\n        A coalesced sparse matrix satisfies the following properties:\n\n          - the indices of the non-zero elements are unique,\n          - the indices are sorted in lexicographical order.\n\n        The coalescing process will accumulate the non-zero elements of the same\n        indices by summation.\n\n        The function does not support autograd.\n\n        Returns\n        -------\n        SparseMatrix\n            The coalesced sparse matrix\n\n        Examples\n        --------\n        >>> indices = torch.tensor([[1, 0, 0, 0, 1], [1, 1, 1, 2, 2]])\n        >>> val = 
torch.tensor([0, 1, 2, 3, 4])\n        >>> A = dglsp.spmatrix(indices, val)\n        >>> A.coalesce()\n        SparseMatrix(indices=tensor([[0, 0, 1, 1],\n                                     [1, 2, 1, 2]]),\n                     values=tensor([3, 3, 0, 4]),\n                     shape=(2, 3), nnz=4)\n        \"\"\"\n        return SparseMatrix(self.c_sparse_matrix.coalesce())\n\n    def has_duplicate(self):\n        \"\"\"Returns ``True`` if the sparse matrix contains duplicate indices.\n\n        Examples\n        --------\n        >>> indices = torch.tensor([[1, 0, 0, 0, 1], [1, 1, 1, 2, 2]])\n        >>> val = torch.tensor([0, 1, 2, 3, 4])\n        >>> A = dglsp.spmatrix(indices, val)\n        >>> A.has_duplicate()\n        True\n        >>> A.coalesce().has_duplicate()\n        False\n        \"\"\"\n        return self.c_sparse_matrix.has_duplicate()\n\n    def is_diag(self):\n        \"\"\"Returns whether the sparse matrix is a diagonal matrix.\"\"\"\n        return self.c_sparse_matrix.is_diag()\n\n    def index_select(self, dim: int, index: torch.Tensor):\n        \"\"\"Returns a sub-matrix selected according to the given index.\n\n        Parameters\n        ----------\n        dim : int\n            The dim to select from matrix, should be 0 or 1. 
`dim = 0` for\n            rowwise selection and `dim = 1` for columnwise selection.\n        index : torch.Tensor\n            The selection index indicates which IDs from the `dim` should\n            be chosen from the matrix.\n            Note that duplicated ids are allowed.\n\n        The function does not support autograd.\n\n        Returns\n        -------\n        SparseMatrix\n            The sub-matrix which contains selected rows or columns.\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[0, 1, 1, 2, 3, 4], [0, 2, 4, 3, 5, 0]])\n        >>> val = torch.tensor([0, 1, 2, 3, 4, 5])\n        >>> A = dglsp.spmatrix(indices, val)\n\n        Case 1: Select rows by IDs.\n\n        >>> row_ids = torch.tensor([0, 1, 4])\n        >>> A.index_select(0, row_ids)\n        SparseMatrix(indices=tensor([[0, 1, 1, 2],\n                                     [0, 2, 4, 0]]),\n                     values=tensor([0, 1, 2, 5]),\n                     shape=(3, 6), nnz=4)\n\n        Case 2: Select columns by IDs.\n\n        >>> column_ids = torch.tensor([0, 4, 5])\n        >>> A.index_select(1, column_ids)\n        SparseMatrix(indices=tensor([[0, 4, 1, 3],\n                                     [0, 0, 1, 2]]),\n                     values=tensor([0, 5, 2, 4]),\n                     shape=(5, 3), nnz=4)\n        \"\"\"\n        if dim not in (0, 1):\n            raise ValueError(\"The selection dimension should be 0 or 1.\")\n        if isinstance(index, torch.Tensor):\n            return SparseMatrix(self.c_sparse_matrix.index_select(dim, index))\n        raise TypeError(f\"{type(index).__name__} is unsupported input type.\")\n\n    def range_select(self, dim: int, index: slice):\n        \"\"\"Returns a sub-matrix selected according to the given range index.\n\n        Parameters\n        ----------\n        dim : int\n            The dim to select from matrix, should be 0 or 1. 
`dim = 0` for\n            rowwise selection and `dim = 1` for columnwise selection.\n        index : slice\n            The selection slice indicates which IDs from the `dim` should\n            be chosen from the matrix.\n\n        The function does not support autograd.\n\n        Returns\n        -------\n        SparseMatrix\n            The sub-matrix which contains selected rows or columns.\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[0, 1, 1, 2, 3, 4], [0, 2, 4, 3, 5, 0]])\n        >>> val = torch.tensor([0, 1, 2, 3, 4, 5])\n        >>> A = dglsp.spmatrix(indices, val)\n\n        Case 1: Select rows with given slice object.\n\n        >>> A.range_select(0, slice(1, 3))\n        SparseMatrix(indices=tensor([[0, 0, 1],\n                                     [2, 4, 3]]),\n                     values=tensor([1, 2, 3]),\n                     shape=(2, 6), nnz=3)\n\n        Case 2: Select columns with given slice object.\n\n        >>> A.range_select(1, slice(3, 6))\n        SparseMatrix(indices=tensor([[2, 1, 3],\n                                     [0, 1, 2]]),\n                     values=tensor([3, 2, 4]),\n                     shape=(5, 3), nnz=3)\n        \"\"\"\n        if dim not in (0, 1):\n            raise ValueError(\"The selection dimension should be 0 or 1.\")\n        if isinstance(index, slice):\n            if index.step not in (None, 1):\n                raise NotImplementedError(\n                    \"Slice with step other than 1 are not supported yet.\"\n                )\n            start = 0 if index.start is None else index.start\n            end = index.stop\n            return SparseMatrix(\n                self.c_sparse_matrix.range_select(dim, start, end)\n            )\n        raise TypeError(f\"{type(index).__name__} is unsupported input type.\")\n\n    def sample(\n        self,\n        dim: int,\n        fanout: int,\n        ids: Optional[torch.Tensor] = None,\n        replace: Optional[bool] = 
False,\n        bias: Optional[bool] = False,\n    ):\n        \"\"\"Returns a sampled matrix on the given dimension and sample arguments.\n\n        Parameters\n        ----------\n        dim : int\n            The dimension for sampling, should be 0 or 1. `dim = 0` for\n            rowwise selection and `dim = 1` for columnwise selection.\n        fanout : int\n            The number of elements to randomly sample on each row or column.\n        ids : torch.Tensor, optional\n            An optional tensor containing row or column IDs from which to\n            sample elements.\n            NOTE: If `ids` is not provided (i.e., `ids = None`), the function\n            will sample from all rows or columns.\n        replace : bool, optional\n            Indicates whether repeated sampling of the same element is allowed.\n            When `replace = True`, repeated sampling is permitted; when\n            `replace = False`, it is not allowed.\n            NOTE: If `replace = False` and there are fewer elements than\n            `fanout`, all non-zero elements will be sampled.\n        bias : bool, optional\n            A boolean flag indicating whether to enable biasing during sampling.\n            When `bias = True`, the values of the sparse matrix will be used as\n            bias weights.\n\n        The function does not support autograd.\n\n        Returns\n        -------\n        SparseMatrix\n            A submatrix with the same shape as the original matrix, containing\n            the randomly sampled non-zero elements.\n\n        Examples\n        --------\n\n        >>> indices = torch.tensor([[0, 0, 1, 1, 2, 2, 2],\n                                    [0, 2, 0, 1, 0, 1, 2]])\n        >>> val = torch.tensor([0, 1, 2, 3, 4, 5, 6])\n        >>> A = dglsp.spmatrix(indices, val)\n\n        Case 1: Sample rows with the given number and disable repeated sampling.\n\n        >>> row_ids = torch.tensor([0, 2])\n        >>> A.sample(0, 2, row_ids)\n        
SparseMatrix(indices=tensor([[0, 0, 1, 1],\n                                     [0, 2, 0, 2]]),\n                     values=tensor([0, 1, 4, 6]),\n                     shape=(2, 3), nnz=4)\n\n        Case 2: Sample cols with the given number and disable repeated sampling.\n\n        >>> col_ids = torch.tensor([0, 2])\n        >>> A.sample(1, 2, col_ids)\n        SparseMatrix(indices=tensor([[0, 1, 0, 2],\n                                     [0, 0, 1, 1]]),\n                     values=tensor([0, 2, 1, 6]),\n                     shape=(3, 2), nnz=4)\n\n        Case 3: Sample rows with the given number and enable repeated sampling.\n\n        >>> row_ids = torch.tensor([0, 1])\n        >>> A.sample(0, 2, row_ids, True)\n        SparseMatrix(indices=tensor([[0, 0, 1, 1],\n                                     [0, 2, 0, 0]]),\n                     values=tensor([0, 1, 2, 2]),\n                     shape=(2, 3), nnz=4)\n\n        Case 4: Sample cols with the given number and enable repeated sampling.\n\n        >>> col_ids = torch.tensor([0, 1])\n        >>> A.sample(1, 2, col_ids, True)\n        SparseMatrix(indices=tensor([[0, 1, 1, 1],\n                                     [0, 0, 1, 1]]),\n                     values=tensor([0, 2, 3, 3]),\n                     shape=(3, 2), nnz=4)\n        \"\"\"\n        if ids is None:\n            dim_size = self.shape[0] if dim == 0 else self.shape[1]\n            ids = torch.range(\n                0, dim_size, dtype=torch.int64, device=self.device\n            )\n        return SparseMatrix(\n            self.c_sparse_matrix.sample(dim, fanout, ids, replace, bias)\n        )\n\n    def compact(\n        self,\n        dim: int,\n        leading_indices: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"Compact sparse matrix by removing rows or columns without non-zero\n        elements in the sparse matrix and relabeling indices of the dimension.\n\n        This function serves a dual purpose: it allows you to reorganize 
the\n        indices within a specific dimension (rows or columns) of the sparse\n        matrix and, if needed, place certain 'leading_indices' at the beginning\n        of the relabeled dimension.\n\n        In the absence of 'leading_indices' (when it's set to `None`), the order\n        of relabeled indices remains the same as the original order, except that\n        rows or columns without non-zero elements are removed. When\n        'leading_indices' are provided, they are positioned at the start of the\n        relabeled dimension. To be precise, all rows selected by the specified\n        indices will be remapped from 0 to length(indices) - 1. Rows that are not\n        selected and contain any non-zero elements will be positioned after those\n        remapped rows while maintaining their original order.\n\n        This function mimics 'dgl.to_block', a method used to compress a sampled\n        subgraph by eliminating redundant nodes. The 'leading_indices' parameter\n        replicates the behavior of 'include_dst_in_src' in 'dgl.to_block',\n        adding destination node information for message passing.\n        Setting 'leading_indices' to column IDs when relabeling the row\n        dimension, for example, achieves the same effect as including destination\n        nodes in source nodes.\n\n        Parameters\n        ----------\n        dim : int\n            The dimension to relabel. Should be 0 or 1. 
Use `dim = 0` for rowwise\n            relabeling and `dim = 1` for columnwise relabeling.\n        leading_indices : torch.Tensor, optional\n            An optional tensor containing row or column ids that should be placed\n            at the beginning of the relabeled dimension.\n\n        Returns\n        -------\n        Tuple[SparseMatrix, torch.Tensor]\n            A tuple containing the relabeled sparse matrix and the index mapping\n            of the relabeled dimension from the new index to the original index.\n\n        Examples\n        --------\n        >>> indices = torch.tensor([[0, 2],\n                                    [1, 2]])\n        >>> A = dglsp.spmatrix(indices)\n        >>> print(A.to_dense())\n        tensor([[0., 1., 0.],\n                [0., 0., 0.],\n                [0., 0., 1.]])\n\n        Case 1: Compact rows without indices.\n\n        >>> B, original_rows = A.compact(dim=0, leading_indices=None)\n        >>> print(B.to_dense())\n        tensor([[0., 1., 0.],\n                [0., 0., 1.]])\n        >>> print(original_rows)\n        tensor([0, 2])\n\n        Case 2: Compact rows with indices.\n\n        >>> leading = torch.tensor([1, 2])\n        >>> B, original_rows = A.compact(dim=0, leading_indices=leading)\n        >>> print(B.to_dense())\n        tensor([[0., 0., 0.],\n                [0., 0., 1.],\n                [0., 1., 0.]])\n        >>> print(original_rows)\n        tensor([1, 2, 0])\n        \"\"\"\n        mat, idx = torch.ops.dgl_sparse.compact(\n            self.c_sparse_matrix, dim, leading_indices\n        )\n        return SparseMatrix(mat), idx\n\n\ndef spmatrix(\n    indices: torch.Tensor,\n    val: Optional[torch.Tensor] = None,\n    shape: Optional[Tuple[int, int]] = None,\n) -> SparseMatrix:\n    r\"\"\"Creates a sparse matrix from Coordinate format indices.\n\n    Parameters\n    ----------\n    indices : torch.Tensor\n        The indices are the coordinates of the non-zero elements in the matrix,\n        which should have shape of ``(2, 
N)`` where the first row is the row\n        indices and the second row is the column indices of non-zero elements.\n    val : tensor.Tensor, optional\n        The values of shape ``(nnz)`` or ``(nnz, D)``. If None, it will be a\n        tensor of shape ``(nnz)`` filled by 1.\n    shape : tuple[int, int], optional\n        If not specified, it will be inferred from :attr:`row` and :attr:`col`,\n        i.e., ``(row.max() + 1, col.max() + 1)``. Otherwise, :attr:`shape`\n        should be no smaller than this.\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n\n    Case1: Sparse matrix with row and column indices without values.\n\n    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])\n    >>> A = dglsp.spmatrix(indices)\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 4, 3]]),\n                 values=tensor([1., 1., 1.]),\n                 shape=(3, 5), nnz=3)\n    >>> # Specify shape\n    >>> A = dglsp.spmatrix(indices, shape=(5, 5))\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 4, 3]]),\n                 values=tensor([1., 1., 1.]),\n                 shape=(5, 5), nnz=3)\n\n    Case2: Sparse matrix with scalar values.\n\n    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])\n    >>> val = torch.tensor([[1.], [2.], [3.]])\n    >>> A = dglsp.spmatrix(indices, val)\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 4, 3]]),\n                 values=tensor([[1.],\n                                [2.],\n                                [3.]]),\n                 shape=(3, 5), nnz=3, val_size=(1,))\n\n    Case3: Sparse matrix with vector values.\n\n    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])\n    >>> val = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])\n    >>> A = dglsp.spmatrix(indices, val)\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 4, 3]]),\n                 
values=tensor([[1., 1.],\n                                [2., 2.],\n                                [3., 3.]]),\n                 shape=(3, 5), nnz=3, val_size=(2,))\n    \"\"\"\n    if shape is None:\n        shape = (\n            torch.max(indices[0]).item() + 1,\n            torch.max(indices[1]).item() + 1,\n        )\n    if val is None:\n        val = torch.ones(indices.shape[1]).to(indices.device)\n\n    assert (\n        val.dim() <= 2\n    ), \"The values of a SparseMatrix can only be scalars or vectors.\"\n    return SparseMatrix(torch.ops.dgl_sparse.from_coo(indices, val, shape))\n\n\ndef from_coo(\n    row: torch.Tensor,\n    col: torch.Tensor,\n    val: Optional[torch.Tensor] = None,\n    shape: Optional[Tuple[int, int]] = None,\n) -> SparseMatrix:\n    r\"\"\"Creates a sparse matrix from a coordinate list (COO), which stores a list\n    of (row, column, value) tuples.\n\n    See `COO in Wikipedia\n    <https://en.wikipedia.org/wiki/Sparse_matrix#Coordinate_list_(COO)>`_.\n\n    Parameters\n    ----------\n    row : torch.Tensor\n        The row indices of shape ``(nnz)``\n    col : torch.Tensor\n        The column indices of shape ``(nnz)``\n    val : torch.Tensor, optional\n        The values of shape ``(nnz)`` or ``(nnz, D)``. If None, it will be a\n        tensor of shape ``(nnz)`` filled by 1.\n    shape : tuple[int, int], optional\n        If not specified, it will be inferred from :attr:`row` and :attr:`col`,\n        i.e., ``(row.max() + 1, col.max() + 1)``. 
Otherwise, :attr:`shape`\n        should be no smaller than this.\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n\n    Case1: Sparse matrix with row and column indices without values.\n\n    >>> dst = torch.tensor([1, 1, 2])\n    >>> src = torch.tensor([2, 4, 3])\n    >>> A = dglsp.from_coo(dst, src)\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 4, 3]]),\n                 values=tensor([1., 1., 1.]),\n                 shape=(3, 5), nnz=3)\n    >>> # Specify shape\n    >>> A = dglsp.from_coo(dst, src, shape=(5, 5))\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 4, 3]]),\n                 values=tensor([1., 1., 1.]),\n                 shape=(5, 5), nnz=3)\n\n    Case2: Sparse matrix with scalar values.\n\n    >>> dst = torch.tensor([1, 1, 2])\n    >>> src = torch.tensor([2, 4, 3])\n    >>> val = torch.tensor([[1.], [2.], [3.]])\n    >>> A = dglsp.from_coo(dst, src, val)\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 4, 3]]),\n                 values=tensor([[1.],\n                                [2.],\n                                [3.]]),\n                 shape=(3, 5), nnz=3, val_size=(1,))\n\n    Case3: Sparse matrix with vector values.\n\n    >>> dst = torch.tensor([1, 1, 2])\n    >>> src = torch.tensor([2, 4, 3])\n    >>> val = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])\n    >>> A = dglsp.from_coo(dst, src, val)\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 4, 3]]),\n                 values=tensor([[1., 1.],\n                                [2., 2.],\n                                [3., 3.]]),\n                 shape=(3, 5), nnz=3, val_size=(2,))\n    \"\"\"\n    assert row.shape[0] == col.shape[0]\n    return spmatrix(torch.stack([row, col]), val, shape)\n\n\ndef from_csr(\n    indptr: torch.Tensor,\n    indices: torch.Tensor,\n    val: Optional[torch.Tensor] = None,\n    shape: 
Optional[Tuple[int, int]] = None,\n) -> SparseMatrix:\n    r\"\"\"Creates a sparse matrix from compress sparse row (CSR) format.\n\n    See `CSR in Wikipedia <https://en.wikipedia.org/wiki/\n    Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)>`_.\n\n    For row i of the sparse matrix\n\n    - the column indices of the non-zero elements are stored in\n      ``indices[indptr[i]: indptr[i+1]]``\n    - the corresponding values are stored in ``val[indptr[i]: indptr[i+1]]``\n\n    Parameters\n    ----------\n    indptr : torch.Tensor\n        Pointer to the column indices of shape ``(N + 1)``, where ``N`` is the\n        number of rows\n    indices : torch.Tensor\n        The column indices of shape ``(nnz)``\n    val : torch.Tensor, optional\n        The values of shape ``(nnz)`` or ``(nnz, D)``. If None, it will be a\n        tensor of shape ``(nnz)`` filled by 1.\n    shape : tuple[int, int], optional\n        If not specified, it will be inferred from :attr:`indptr` and\n        :attr:`indices`, i.e., ``(len(indptr) - 1, indices.max() + 1)``.\n        Otherwise, :attr:`shape` should be no smaller than this.\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n\n    Case1: Sparse matrix without values\n\n    .. 
code::\n\n        [[0, 1, 0],\n         [0, 0, 1],\n         [1, 1, 1]]\n\n    >>> indptr = torch.tensor([0, 1, 2, 5])\n    >>> indices = torch.tensor([1, 2, 0, 1, 2])\n    >>> A = dglsp.from_csr(indptr, indices)\n    SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],\n                                 [1, 2, 0, 1, 2]]),\n                 values=tensor([1., 1., 1., 1., 1.]),\n                 shape=(3, 3), nnz=5)\n    >>> # Specify shape\n    >>> A = dglsp.from_csr(indptr, indices, shape=(3, 5))\n    SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],\n                                 [1, 2, 0, 1, 2]]),\n                 values=tensor([1., 1., 1., 1., 1.]),\n                 shape=(3, 5), nnz=5)\n\n    Case2: Sparse matrix with scalar/vector values. Following example is with\n    vector data.\n\n    >>> indptr = torch.tensor([0, 1, 2, 5])\n    >>> indices = torch.tensor([1, 2, 0, 1, 2])\n    >>> val = torch.tensor([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])\n    >>> A = dglsp.from_csr(indptr, indices, val)\n    SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],\n                                 [1, 2, 0, 1, 2]]),\n                 values=tensor([[1, 1],\n                                [2, 2],\n                                [3, 3],\n                                [4, 4],\n                                [5, 5]]),\n                 shape=(3, 3), nnz=5, val_size=(2,))\n    \"\"\"\n    if shape is None:\n        shape = (indptr.shape[0] - 1, torch.max(indices) + 1)\n    if val is None:\n        val = torch.ones(indices.shape[0]).to(indptr.device)\n\n    assert (\n        val.dim() <= 2\n    ), \"The values of a SparseMatrix can only be scalars or vectors.\"\n\n    return SparseMatrix(\n        torch.ops.dgl_sparse.from_csr(indptr, indices, val, shape)\n    )\n\n\ndef from_csc(\n    indptr: torch.Tensor,\n    indices: torch.Tensor,\n    val: Optional[torch.Tensor] = None,\n    shape: Optional[Tuple[int, int]] = None,\n) -> SparseMatrix:\n    r\"\"\"Creates a sparse matrix from 
compress sparse column (CSC) format.\n\n    See `CSC in Wikipedia <https://en.wikipedia.org/wiki/\n    Sparse_matrix#Compressed_sparse_column_(CSC_or_CCS)>`_.\n\n    For column i of the sparse matrix\n\n    - the row indices of the non-zero elements are stored in\n      ``indices[indptr[i]: indptr[i+1]]``\n    - the corresponding values are stored in ``val[indptr[i]: indptr[i+1]]``\n\n    Parameters\n    ----------\n    indptr : torch.Tensor\n        Pointer to the row indices of shape N + 1, where N is the\n        number of columns\n    indices : torch.Tensor\n        The row indices of shape nnz\n    val : torch.Tensor, optional\n        The values of shape ``(nnz)`` or ``(nnz, D)``. If None, it will be a\n        tensor of shape ``(nnz)`` filled by 1.\n    shape : tuple[int, int], optional\n        If not specified, it will be inferred from :attr:`indptr` and\n        :attr:`indices`, i.e., ``(indices.max() + 1, len(indptr) - 1)``.\n        Otherwise, :attr:`shape` should be no smaller than this.\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n\n    Case1: Sparse matrix without values\n\n    .. code::\n\n        [[0, 1, 0],\n         [0, 0, 1],\n         [1, 1, 1]]\n\n    >>> indptr = torch.tensor([0, 1, 3, 5])\n    >>> indices = torch.tensor([2, 0, 2, 1, 2])\n    >>> A = dglsp.from_csc(indptr, indices)\n    SparseMatrix(indices=tensor([[2, 0, 2, 1, 2],\n                                 [0, 1, 1, 2, 2]]),\n                 values=tensor([1., 1., 1., 1., 1.]),\n                 shape=(3, 3), nnz=5)\n    >>> # Specify shape\n    >>> A = dglsp.from_csc(indptr, indices, shape=(5, 3))\n    SparseMatrix(indices=tensor([[2, 0, 2, 1, 2],\n                                 [0, 1, 1, 2, 2]]),\n                 values=tensor([1., 1., 1., 1., 1.]),\n                 shape=(5, 3), nnz=5)\n\n    Case2: Sparse matrix with scalar/vector values. 
Following example is with\n    vector data.\n\n    >>> indptr = torch.tensor([0, 1, 3, 5])\n    >>> indices = torch.tensor([2, 0, 2, 1, 2])\n    >>> val = torch.tensor([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])\n    >>> A = dglsp.from_csc(indptr, indices, val)\n    SparseMatrix(indices=tensor([[2, 0, 2, 1, 2],\n                                 [0, 1, 1, 2, 2]]),\n                 values=tensor([[1, 1],\n                                [2, 2],\n                                [3, 3],\n                                [4, 4],\n                                [5, 5]]),\n                 shape=(3, 3), nnz=5, val_size=(2,))\n    \"\"\"\n    if shape is None:\n        shape = (torch.max(indices) + 1, indptr.shape[0] - 1)\n    if val is None:\n        val = torch.ones(indices.shape[0]).to(indptr.device)\n\n    assert (\n        val.dim() <= 2\n    ), \"The values of a SparseMatrix can only be scalars or vectors.\"\n\n    return SparseMatrix(\n        torch.ops.dgl_sparse.from_csc(indptr, indices, val, shape)\n    )\n\n\ndef val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix:\n    \"\"\"Creates a sparse matrix from an existing sparse matrix using new values.\n\n    The new sparse matrix will have the same non-zero indices as the given\n    sparse matrix and use the given values as the new non-zero values.\n\n    Parameters\n    ----------\n    mat : SparseMatrix\n        An existing sparse matrix with non-zero values\n    val : torch.Tensor\n        The new values of the non-zero elements, a tensor of shape ``(nnz)`` or\n        ``(nnz, D)``\n\n    Returns\n    -------\n    SparseMatrix\n        New sparse matrix\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])\n    >>> val = torch.ones(3)\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> A = dglsp.val_like(A, torch.tensor([2, 2, 2]))\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 4, 3]]),\n                 values=tensor([2, 2, 2]),\n     
            shape=(3, 5), nnz=3)\n    \"\"\"\n    assert (\n        val.dim() <= 2\n    ), \"The values of a SparseMatrix can only be scalars or vectors.\"\n\n    return SparseMatrix(torch.ops.dgl_sparse.val_like(mat.c_sparse_matrix, val))\n\n\ndef diag(\n    val: torch.Tensor, shape: Optional[Tuple[int, int]] = None\n) -> SparseMatrix:\n    \"\"\"Creates a sparse matrix based on the diagonal values.\n\n    Parameters\n    ----------\n    val : torch.Tensor\n        Diagonal of the matrix, in shape ``(N)`` or ``(N, D)``\n    shape : tuple[int, int], optional\n        If specified, :attr:`len(val)` must be equal to :attr:`min(shape)`,\n        otherwise, it will be inferred from :attr:`val`, i.e., ``(N, N)``\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n\n    Case1: 5-by-5 diagonal matrix with scaler values on the diagonal\n\n    >>> import torch\n    >>> val = torch.ones(5)\n    >>> dglsp.diag(val)\n    SparseMatrix(indices=tensor([[0, 1, 2, 3, 4],\n                                 [0, 1, 2, 3, 4]]),\n                 values=tensor([1., 1., 1., 1., 1.]),\n                 shape=(5, 5), nnz=5)\n\n    Case2: 5-by-10 diagonal matrix with scaler values on the diagonal\n\n    >>> val = torch.ones(5)\n    >>> dglsp.diag(val, shape=(5, 10))\n    SparseMatrix(indices=tensor([[0, 1, 2, 3, 4],\n                                 [0, 1, 2, 3, 4]]),\n                 values=tensor([1., 1., 1., 1., 1.]),\n                 shape=(5, 10), nnz=5)\n\n    Case3: 5-by-5 diagonal matrix with vector values on the diagonal\n\n    >>> val = torch.randn(5, 3)\n    >>> D = dglsp.diag(val)\n    >>> D.shape\n    (5, 5)\n    >>> D.nnz\n    5\n    \"\"\"\n    assert (\n        val.dim() <= 2\n    ), \"The values of a DiagMatrix can only be scalars or vectors.\"\n    len_val = len(val)\n    if shape is not None:\n        assert len_val == min(shape), (\n            f\"Expect len(val) to be min(shape) for a diagonal matrix, got\"\n            
f\"{len_val} for len(val) and {shape} for shape.\"\n        )\n    else:\n        shape = (len_val, len_val)\n    return SparseMatrix(torch.ops.dgl_sparse.from_diag(val, shape))\n\n\ndef identity(\n    shape: Tuple[int, int],\n    d: Optional[int] = None,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n) -> SparseMatrix:\n    r\"\"\"Creates a sparse matrix with ones on the diagonal and zeros elsewhere.\n\n    Parameters\n    ----------\n    shape : tuple[int, int]\n        Shape of the matrix.\n    d : int, optional\n        If None, the diagonal entries will be scaler 1. Otherwise, the diagonal\n        entries will be a 1-valued tensor of shape ``(d)``.\n    dtype : torch.dtype, optional\n        The data type of the matrix\n    device : torch.device, optional\n        The device of the matrix\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n\n    Case1: 3-by-3 matrix with scaler diagonal values\n\n    .. code::\n\n        [[1, 0, 0],\n         [0, 1, 0],\n         [0, 0, 1]]\n\n    >>> dglsp.identity(shape=(3, 3))\n    SparseMatrix(indices=tensor([[0, 1, 2],\n                                 [0, 1, 2]]),\n                 values=tensor([1., 1., 1.]),\n                 shape=(3, 3), nnz=3)\n\n    Case2: 3-by-5 matrix with scaler diagonal values\n\n    .. 
code::\n\n        [[1, 0, 0, 0, 0],\n         [0, 1, 0, 0, 0],\n         [0, 0, 1, 0, 0]]\n\n    >>> dglsp.identity(shape=(3, 5))\n    SparseMatrix(indices=tensor([[0, 1, 2],\n                                 [0, 1, 2]]),\n                 values=tensor([1., 1., 1.]),\n                 shape=(3, 5), nnz=3)\n\n    Case3: 3-by-3 matrix with vector diagonal values\n\n    >>> dglsp.identity(shape=(3, 3), d=2)\n    SparseMatrix(indices=tensor([[0, 1, 2],\n                                 [0, 1, 2]]),\n                 values=tensor([[1., 1.],\n                                [1., 1.],\n                                [1., 1.]]),\n                 shape=(3, 3), nnz=3, val_size=(2,))\n    \"\"\"\n    len_val = min(shape)\n    if d is None:\n        val_shape = (len_val,)\n    else:\n        val_shape = (len_val, d)\n    val = torch.ones(val_shape, dtype=dtype, device=device)\n    return diag(val, shape)\n\n\ndef from_torch_sparse(torch_sparse_tensor: torch.Tensor) -> SparseMatrix:\n    \"\"\"Creates a sparse matrix from a torch sparse tensor, which can have coo,\n    csr, or csc layout.\n\n    Parameters\n    ----------\n    torch_sparse_tensor : torch.Tensor\n        Torch sparse tensor\n\n    Returns\n    -------\n    SparseMatrix\n        Sparse matrix\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])\n    >>> val = torch.ones(3)\n    >>> torch_coo = torch.sparse_coo_tensor(indices, val)\n    >>> dglsp.from_torch_sparse(torch_coo)\n    SparseMatrix(indices=tensor([[1, 1, 2],\n                                 [2, 4, 3]]),\n                 values=tensor([1., 1., 1.]),\n                 shape=(3, 5), nnz=3)\n    \"\"\"\n    assert torch_sparse_tensor.layout in (\n        torch.sparse_coo,\n        torch.sparse_csr,\n        torch.sparse_csc,\n    ), (\n        f\"Cannot convert Pytorch sparse tensor with layout \"\n        f\"{torch_sparse_tensor.layout} to DGL sparse.\"\n    )\n    if torch_sparse_tensor.layout == 
torch.sparse_coo:\n        # Use ._indices() and ._values() to access uncoalesced indices and\n        # values.\n        return spmatrix(\n            torch_sparse_tensor._indices(),\n            torch_sparse_tensor._values(),\n            torch_sparse_tensor.shape[:2],\n        )\n    elif torch_sparse_tensor.layout == torch.sparse_csr:\n        return from_csr(\n            torch_sparse_tensor.crow_indices(),\n            torch_sparse_tensor.col_indices(),\n            torch_sparse_tensor.values(),\n            torch_sparse_tensor.shape[:2],\n        )\n    else:\n        return from_csc(\n            torch_sparse_tensor.ccol_indices(),\n            torch_sparse_tensor.row_indices(),\n            torch_sparse_tensor.values(),\n            torch_sparse_tensor.shape[:2],\n        )\n\n\ndef to_torch_sparse_coo(spmat: SparseMatrix) -> torch.Tensor:\n    \"\"\"Creates a torch sparse coo tensor from a sparse matrix.\n\n    Parameters\n    ----------\n    spmat : SparseMatrix\n        Sparse matrix\n\n    Returns\n    -------\n    torch.Tensor\n        torch tensor with torch.sparse_coo layout\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])\n    >>> val = torch.ones(3)\n    >>> spmat = dglsp.spmatrix(indices, val)\n    >>> dglsp.to_torch_sparse_coo(spmat)\n    tensor(indices=tensor([[1, 1, 2],\n                           [2, 4, 3]]),\n           values=tensor([1., 1., 1.]),\n           size=(3, 5), nnz=3, layout=torch.sparse_coo)\n    \"\"\"\n    shape = spmat.shape\n    if spmat.val.dim() > 1:\n        shape += spmat.val.shape[1:]\n    return torch.sparse_coo_tensor(spmat.indices(), spmat.val, shape)\n\n\ndef to_torch_sparse_csr(spmat: SparseMatrix) -> torch.Tensor:\n    \"\"\"Creates a torch sparse csr tensor from a sparse matrix.\n\n    Note that converting a sparse matrix to torch csr tensor could change the\n    order of non-zero values.\n\n    Parameters\n    ----------\n    spmat : SparseMatrix\n        Sparse matrix\n\n  
  Returns\n    -------\n    torch.Tensor\n        Torch tensor with torch.sparse_csr layout\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])\n    >>> val = torch.arange(3)\n    >>> spmat = dglsp.spmatrix(indices, val)\n    >>> dglsp.to_torch_sparse_csr(spmat)\n    tensor(crow_indices=tensor([0, 0, 2, 3]),\n           col_indices=tensor([2, 3, 4]),\n           values=tensor([0, 2, 1]), size=(3, 5), nnz=3,\n           layout=torch.sparse_csr)\n    \"\"\"\n    shape = spmat.shape\n    if spmat.val.dim() > 1:\n        shape += spmat.val.shape[1:]\n    indptr, indices, value_indices = spmat.csr()\n    val = spmat.val\n    if value_indices is not None:\n        val = val[value_indices]\n    return torch.sparse_csr_tensor(indptr, indices, val, shape)\n\n\ndef to_torch_sparse_csc(spmat: SparseMatrix) -> torch.Tensor:\n    \"\"\"Creates a torch sparse csc tensor from a sparse matrix.\n\n    Note that converting a sparse matrix to torch csc tensor could change the\n    order of non-zero values.\n\n    Parameters\n    ----------\n    spmat : SparseMatrix\n        Sparse matrix\n\n    Returns\n    -------\n    torch.Tensor\n        Torch tensor with torch.sparse_csc layout\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])\n    >>> val = torch.arange(3)\n    >>> spmat = dglsp.spmatrix(indices, val)\n    >>> dglsp.to_torch_sparse_csc(spmat)\n    tensor(ccol_indices=tensor([0, 0, 0, 1, 2, 3]),\n           row_indices=tensor([1, 1, 2]),\n           values=tensor([0, 2, 1]), size=(3, 5), nnz=3,\n           layout=torch.sparse_csc)\n    \"\"\"\n    shape = spmat.shape\n    if spmat.val.dim() > 1:\n        shape += spmat.val.shape[1:]\n    indptr, indices, value_indices = spmat.csc()\n    val = spmat.val\n    if value_indices is not None:\n        val = val[value_indices]\n    return torch.sparse_csc_tensor(indptr, indices, val, shape)\n\n\ndef _sparse_matrix_str(spmat: SparseMatrix) -> str:\n    
\"\"\"Internal function for converting a sparse matrix to string\n    representation.\n    \"\"\"\n    indices_str = str(torch.stack(spmat.coo()))\n    values_str = str(spmat.val)\n    meta_str = f\"shape={spmat.shape}, nnz={spmat.nnz}\"\n    if spmat.val.dim() > 1:\n        val_size = tuple(spmat.val.shape[1:])\n        meta_str += f\", val_size={val_size}\"\n    prefix = f\"{type(spmat).__name__}(\"\n\n    def _add_indent(_str, indent):\n        lines = _str.split(\"\\n\")\n        lines = [lines[0]] + [\" \" * indent + line for line in lines[1:]]\n        return \"\\n\".join(lines)\n\n    final_str = (\n        \"indices=\"\n        + _add_indent(indices_str, len(\"indices=\"))\n        + \",\\n\"\n        + \"values=\"\n        + _add_indent(values_str, len(\"values=\"))\n        + \",\\n\"\n        + meta_str\n        + \")\"\n    )\n    final_str = prefix + _add_indent(final_str, len(prefix))\n    return final_str\n"
  },
  {
    "path": "python/dgl/sparse/unary_op.py",
    "content": "\"\"\"DGL unary operators for sparse matrix module.\"\"\"\nfrom .sparse_matrix import diag, SparseMatrix, val_like\n\n\ndef neg(A: SparseMatrix) -> SparseMatrix:\n    \"\"\"Returns a new sparse matrix with the negation of the original nonzero\n    values, equivalent to ``-A``.\n\n    Returns\n    -------\n    SparseMatrix\n        Negation of the sparse matrix\n\n    Examples\n    --------\n\n    >>> indices = torch.tensor([[1, 1, 3], [1, 2, 3]])\n    >>> val = torch.tensor([1., 1., 2.])\n    >>> A = dglsp.spmatrix(indices, val)\n    >>> A = -A\n    SparseMatrix(indices=tensor([[1, 1, 3],\n                                 [1, 2, 3]]),\n                 values=tensor([-1., -1., -2.]),\n                 shape=(4, 4), nnz=3)\n    \"\"\"\n    return val_like(A, -A.val)\n\n\ndef inv(A: SparseMatrix) -> SparseMatrix:\n    \"\"\"Returns the inverse of the sparse matrix.\n\n    This function only supports square diagonal matrices with scalar nonzero\n    values.\n\n    Returns\n    -------\n    SparseMatrix\n        Inverse of the sparse matrix\n\n    Examples\n    --------\n\n    >>> val = torch.arange(1, 4).float()\n    >>> D = dglsp.diag(val)\n    >>> D.inv()\n    SparseMatrix(indices=tensor([[0, 1, 2],\n                                 [0, 1, 2]]),\n                 values=tensor([1.0000, 0.5000, 0.3333]),\n                 shape=(3, 3), nnz=3)\n    \"\"\"\n    num_rows, num_cols = A.shape\n    assert A.is_diag(), \"Non-diagonal sparse matrix does not support inversion.\"\n    assert num_rows == num_cols, f\"Expect a square matrix, got shape {A.shape}\"\n    assert len(A.val.shape) == 1, \"inv only supports 1D nonzero val\"\n\n    return diag(1.0 / A.val, A.shape)\n\n\nSparseMatrix.neg = neg\nSparseMatrix.__neg__ = neg\nSparseMatrix.inv = inv\n"
  },
  {
    "path": "python/dgl/sparse/utils.py",
    "content": "\"\"\"Utilities for DGL sparse module.\"\"\"\nfrom numbers import Number\nfrom typing import Union\n\nimport torch\n\n\ndef is_scalar(x):\n    \"\"\"Check if the input is a scalar.\"\"\"\n    return isinstance(x, Number) or (torch.is_tensor(x) and x.dim() == 0)\n\n\n# Scalar type annotation\nScalar = Union[Number, torch.Tensor]\n"
  },
  {
    "path": "python/dgl/storages/__init__.py",
    "content": "\"\"\"Feature storage classes for DataLoading\"\"\"\nfrom .. import backend as F\nfrom .base import *\nfrom .numpy import *\n\n# Defines the name TensorStorage\nif F.get_preferred_backend() == \"pytorch\":\n    from .pytorch_tensor import PyTorchTensorStorage as TensorStorage\nelse:\n    from .tensor import BaseTensorStorage as TensorStorage\n"
  },
  {
    "path": "python/dgl/storages/base.py",
    "content": "\"\"\"Base classes and functionalities for feature storages.\"\"\"\n\nimport threading\n\nSTORAGE_WRAPPERS = {}\n\n\ndef register_storage_wrapper(type_):\n    \"\"\"Decorator that associates a type to a ``FeatureStorage`` object.\"\"\"\n\n    def deco(cls):\n        STORAGE_WRAPPERS[type_] = cls\n        return cls\n\n    return deco\n\n\ndef wrap_storage(storage):\n    \"\"\"Wrap an object into a FeatureStorage as specified by the ``register_storage_wrapper``\n    decorators.\n    \"\"\"\n    for type_, storage_cls in STORAGE_WRAPPERS.items():\n        if isinstance(storage, type_):\n            return storage_cls(storage)\n\n    assert isinstance(\n        storage, FeatureStorage\n    ), \"The frame column must be a tensor or a FeatureStorage object, got {}\".format(\n        type(storage)\n    )\n    return storage\n\n\nclass _FuncWrapper(object):\n    def __init__(self, func):\n        self.func = func\n\n    def __call__(self, buf, *args):\n        buf[0] = self.func(*args)\n\n\nclass ThreadedFuture(object):\n    \"\"\"Wraps a function into a future asynchronously executed by a Python\n    ``threading.Thread`.  The function is being executed upon instantiation of\n    this object.\n    \"\"\"\n\n    def __init__(self, target, args):\n        self.buf = [None]\n\n        thread = threading.Thread(\n            target=_FuncWrapper(target),\n            args=[self.buf] + list(args),\n            daemon=True,\n        )\n        thread.start()\n        self.thread = thread\n\n    def wait(self):\n        \"\"\"Blocks the current thread until the result becomes available and returns it.\"\"\"\n        self.thread.join()\n        return self.buf[0]\n\n\nclass FeatureStorage(object):\n    \"\"\"Feature storage object which should support a fetch() operation.  
It is the\n    counterpart of a tensor for homogeneous graphs, or a dict of tensors for heterogeneous\n    graphs where the keys are node/edge types.\n    \"\"\"\n\n    def requires_ddp(self):\n        \"\"\"Whether the FeatureStorage requires the DataLoader to set use_ddp.\"\"\"\n        return False\n\n    def fetch(self, indices, device, pin_memory=False, **kwargs):\n        \"\"\"Retrieve the features at the given indices.\n\n        If :attr:`indices` is a tensor, this is equivalent to\n\n        .. code::\n\n           storage[indices]\n\n        If :attr:`indices` is a dict of tensors, this is equivalent to\n\n        .. code::\n\n           {k: storage[k][indices[k]] for k in indices.keys()}\n\n        The subclasses can choose to utilize or ignore the flag :attr:`pin_memory`\n        depending on the underlying framework.\n        \"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "python/dgl/storages/numpy.py",
    "content": "\"\"\"Feature storage for ``numpy.memmap`` object.\"\"\"\nimport numpy as np\n\nfrom .. import backend as F\nfrom .base import FeatureStorage, register_storage_wrapper, ThreadedFuture\n\n\n@register_storage_wrapper(np.memmap)\nclass NumpyStorage(FeatureStorage):\n    \"\"\"FeatureStorage that asynchronously reads features from a ``numpy.memmap`` object.\"\"\"\n\n    def __init__(self, arr):\n        self.arr = arr\n\n    # pylint: disable=unused-argument\n    def _fetch(self, indices, device, pin_memory=False):\n        result = F.zerocopy_from_numpy(self.arr[indices])\n        result = F.copy_to(result, device)\n        return result\n\n    # pylint: disable=unused-argument\n    def fetch(self, indices, device, pin_memory=False, **kwargs):\n        return ThreadedFuture(\n            target=self._fetch, args=(indices, device, pin_memory)\n        )\n"
  },
  {
    "path": "python/dgl/storages/pytorch_tensor.py",
    "content": "\"\"\"Feature storages for PyTorch tensors.\"\"\"\n\nimport torch\n\nfrom ..utils import gather_pinned_tensor_rows\nfrom .base import register_storage_wrapper\nfrom .tensor import BaseTensorStorage\n\n\ndef _fetch_cpu(indices, tensor, feature_shape, device, pin_memory, **kwargs):\n    result = torch.empty(\n        indices.shape[0],\n        *feature_shape,\n        dtype=tensor.dtype,\n        pin_memory=pin_memory,\n    )\n    torch.index_select(tensor, 0, indices, out=result)\n    kwargs[\"non_blocking\"] = pin_memory\n    result = result.to(device, **kwargs)\n    return result\n\n\ndef _fetch_cuda(indices, tensor, device, **kwargs):\n    return torch.index_select(tensor, 0, indices).to(device, **kwargs)\n\n\n@register_storage_wrapper(torch.Tensor)\nclass PyTorchTensorStorage(BaseTensorStorage):\n    \"\"\"Feature storages for slicing a PyTorch tensor.\"\"\"\n\n    def fetch(self, indices, device, pin_memory=False, **kwargs):\n        device = torch.device(device)\n        storage_device_type = self.storage.device.type\n        indices_device_type = indices.device.type\n        if storage_device_type != \"cuda\":\n            if indices_device_type == \"cuda\":\n                if self.storage.is_pinned():\n                    return gather_pinned_tensor_rows(self.storage, indices)\n                else:\n                    raise ValueError(\n                        f\"Got indices on device {indices.device} whereas the feature tensor \"\n                        f\"is on {self.storage.device}. 
Please either (1) move the graph \"\n                        f\"to GPU with to() method, or (2) pin the graph with \"\n                        f\"pin_memory_() method.\"\n                    )\n            # CPU to CPU or CUDA - use pin_memory and async transfer if possible\n            else:\n                return _fetch_cpu(\n                    indices,\n                    self.storage,\n                    self.storage.shape[1:],\n                    device,\n                    pin_memory,\n                    **kwargs,\n                )\n        else:\n            # CUDA to CUDA or CPU\n            return _fetch_cuda(indices, self.storage, device, **kwargs)\n"
  },
  {
    "path": "python/dgl/storages/tensor.py",
    "content": "\"\"\"Feature storages for tensors across different frameworks.\"\"\"\nfrom .. import backend as F\nfrom .base import FeatureStorage\n\n\nclass BaseTensorStorage(FeatureStorage):\n    \"\"\"FeatureStorage that synchronously slices features from a tensor and transfers\n    it to the given device.\n    \"\"\"\n\n    def __init__(self, tensor):\n        self.storage = tensor\n\n    def fetch(\n        self, indices, device, pin_memory=False, **kwargs\n    ):  # pylint: disable=unused-argument\n        return F.copy_to(F.gather_row(self.storage, indices), device, **kwargs)\n"
  },
  {
    "path": "python/dgl/subgraph.py",
    "content": "\"\"\"Functions for extracting subgraphs.\n\nThe module only contains functions for extracting subgraphs deterministically.\nFor stochastic subgraph extraction, please see functions under :mod:`dgl.sampling`.\n\"\"\"\nfrom collections.abc import Mapping\n\nfrom . import backend as F, graph_index, heterograph_index, utils\nfrom ._ffi.function import _init_api\nfrom .base import DGLError\nfrom .heterograph import DGLGraph\nfrom .utils import context_of, recursive_apply\n\n__all__ = [\n    \"node_subgraph\",\n    \"edge_subgraph\",\n    \"node_type_subgraph\",\n    \"edge_type_subgraph\",\n    \"in_subgraph\",\n    \"out_subgraph\",\n    \"khop_in_subgraph\",\n    \"khop_out_subgraph\",\n]\n\n\ndef node_subgraph(\n    graph, nodes, *, relabel_nodes=True, store_ids=True, output_device=None\n):\n    \"\"\"Return a subgraph induced on the given nodes.\n\n    A node-induced subgraph is a graph with edges whose endpoints are both in the\n    specified node set. In addition to extracting the subgraph, DGL also copies\n    the features of the extracted nodes and edges to the resulting graph. The copy\n    is *lazy* and incurs data movement only when needed.\n\n    If the graph is heterogeneous, DGL extracts a subgraph per relation and composes\n    them as the resulting graph. Thus, the resulting graph has the same set of relations\n    as the input one.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph to extract subgraphs from.\n    nodes : nodes or dict[str, nodes]\n        The nodes to form the subgraph, which cannot have any duplicate value. The result\n        will be undefined otherwise. The allowed nodes formats are:\n\n        * Int Tensor: Each element is a node ID. 
The tensor must have the same device type\n          and ID data type as the graph's.\n        * iterable[int]: Each element is a node ID.\n        * Bool Tensor: Each :math:`i^{th}` element is a bool flag indicating whether\n          node :math:`i` is in the subgraph.\n\n        If the graph is homogeneous, one can directly pass the above formats.\n        Otherwise, the argument must be a dictionary with keys being node types\n        and values being the node IDs in the above formats.\n    relabel_nodes : bool, optional\n        If True, the extracted subgraph will only have the nodes in the specified node set\n        and it will relabel the nodes in order.\n    store_ids : bool, optional\n        If True, it will store the raw IDs of the extracted edges in the ``edata`` of the\n        resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will\n        also store the raw IDs of the specified nodes in the ``ndata`` of the resulting\n        graph under name ``dgl.NID``.\n    output_device : Framework-specific device context object, optional\n        The output device.  Default is the same as the input graph.\n\n    Returns\n    -------\n    G : DGLGraph\n        The subgraph.\n\n    Notes\n    -----\n\n    This function discards the batch information. 
Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    Extract a subgraph from a homogeneous graph.\n\n    >>> g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]))  # 5-node cycle\n    >>> sg = dgl.node_subgraph(g, [0, 1, 4])\n    >>> sg\n    Graph(num_nodes=3, num_edges=2,\n          ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> sg.edges()\n    (tensor([0, 2]), tensor([1, 0]))\n    >>> sg.ndata[dgl.NID]  # original node IDs\n    tensor([0, 1, 4])\n    >>> sg.edata[dgl.EID]  # original edge IDs\n    tensor([0, 4])\n\n    Specify nodes using a boolean mask.\n\n    >>> nodes = torch.tensor([True, True, False, False, True])  # choose nodes [0, 1, 4]\n    >>> dgl.node_subgraph(g, nodes)\n    Graph(num_nodes=3, num_edges=2,\n          ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n\n    The resulting subgraph also copies features from the parent graph.\n\n    >>> g.ndata['x'] = torch.arange(10).view(5, 2)\n    >>> sg = dgl.node_subgraph(g, [0, 1, 4])\n    >>> sg\n    Graph(num_nodes=3, num_edges=2,\n          ndata_schemes={'x': Scheme(shape=(2,), dtype=torch.int64),\n                         '_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> sg.ndata['x']\n    tensor([[0, 1],\n            [2, 3],\n            [8, 9]])\n\n    Extract a subgraph from a hetergeneous graph.\n\n    >>> g = dgl.heterograph({\n    >>>     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),\n    >>>     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])\n    >>> })\n    >>> sub_g = dgl.node_subgraph(g, 
{'user': [1, 2]})\n    >>> sub_g\n    Graph(num_nodes={'game': 0, 'user': 2},\n          num_edges={('user', 'follows', 'user'): 2, ('user', 'plays', 'game'): 0},\n          metagraph=[('user', 'user', 'follows'), ('user', 'game', 'plays')])\n\n    See Also\n    --------\n    edge_subgraph\n    \"\"\"\n    if graph.is_block:\n        raise DGLError(\"Extracting subgraph from a block graph is not allowed.\")\n    if not isinstance(nodes, Mapping):\n        assert (\n            len(graph.ntypes) == 1\n        ), \"need a dict of node type and IDs for graph with multiple node types\"\n        nodes = {graph.ntypes[0]: nodes}\n\n    def _process_nodes(ntype, v):\n        if F.is_tensor(v) and F.dtype(v) == F.bool:\n            return F.astype(\n                F.nonzero_1d(F.copy_to(v, graph.device)), graph.idtype\n            )\n        else:\n            return utils.prepare_tensor(graph, v, 'nodes[\"{}\"]'.format(ntype))\n\n    nodes = {ntype: _process_nodes(ntype, v) for ntype, v in nodes.items()}\n    device = context_of(nodes)\n\n    induced_nodes = [\n        nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device))\n        for ntype in graph.ntypes\n    ]\n    sgi = graph._graph.node_subgraph(induced_nodes)\n    induced_edges = sgi.induced_edges\n    if not relabel_nodes:\n        sgi = graph._graph.edge_subgraph(induced_edges, True)\n    # (BarclayII) should not write induced_nodes = sgi.induced_nodes due to the same\n    # bug in #1453.\n    induced_nodes_or_device = induced_nodes if relabel_nodes else device\n    subg = _create_hetero_subgraph(\n        graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids\n    )\n    return subg if output_device is None else subg.to(output_device)\n\n\nDGLGraph.subgraph = utils.alias_func(node_subgraph)\n\n\ndef edge_subgraph(\n    graph, edges, *, relabel_nodes=True, store_ids=True, output_device=None\n):\n    \"\"\"Return a subgraph induced on the given edges.\n\n    An edge-induced subgraph is 
equivalent to creating a new graph using the given\n    edges. In addition to extracting the subgraph, DGL also copies the features\n    of the extracted nodes and edges to the resulting graph. The copy is *lazy*\n    and incurs data movement only when needed.\n\n    If the graph is heterogeneous, DGL extracts a subgraph per relation and composes\n    them as the resulting graph. Thus, the resulting graph has the same set of relations\n    as the input one.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph to extract the subgraph from.\n    edges : edges or dict[(str, str, str), edges]\n        The edges to form the subgraph. The allowed edges formats are:\n\n        * Int Tensor: Each element is an edge ID. The tensor must have the same device type\n          and ID data type as the graph's.\n        * iterable[int]: Each element is an edge ID.\n        * Bool Tensor: Each :math:`i^{th}` element is a bool flag indicating whether\n          edge :math:`i` is in the subgraph.\n\n        If the graph is homogeneous, one can directly pass the above formats.\n        Otherwise, the argument must be a dictionary with keys being edge types\n        and values being the edge IDs in the above formats.\n    relabel_nodes : bool, optional\n        If True, it will remove the isolated nodes and relabel the incident nodes in the\n        extracted subgraph.\n    store_ids : bool, optional\n        If True, it will store the raw IDs of the extracted edges in the ``edata`` of the\n        resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will\n        also store the raw IDs of the incident nodes in the ``ndata`` of the resulting\n        graph under name ``dgl.NID``.\n    output_device : Framework-specific device context object, optional\n        The output device.  
Default is the same as the input graph.\n\n    Returns\n    -------\n    G : DGLGraph\n        The subgraph.\n\n    Notes\n    -----\n\n    This function discards the batch information. Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    Extract a subgraph from a homogeneous graph.\n\n    >>> g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]))  # 5-node cycle\n    >>> sg = dgl.edge_subgraph(g, [0, 4])\n    >>> sg\n    Graph(num_nodes=3, num_edges=2,\n          ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> sg.edges()\n    (tensor([0, 1]), tensor([2, 0]))\n    >>> sg.ndata[dgl.NID]  # original node IDs\n    tensor([0, 4, 1])\n    >>> sg.edata[dgl.EID]  # original edge IDs\n    tensor([0, 4])\n\n    Extract a subgraph without node relabeling.\n\n    >>> sg = dgl.edge_subgraph(g, [0, 4], relabel_nodes=False)\n    >>> sg\n    Graph(num_nodes=5, num_edges=2,\n          ndata_schemes={}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> sg.edges()\n    (tensor([0, 4]), tensor([1, 0]))\n\n    Specify edges using a boolean mask.\n\n    >>> nodes = torch.tensor([True, False, False, False, True])  # choose edges [0, 4]\n    >>> dgl.edge_subgraph(g, nodes)\n    Graph(num_nodes=3, num_edges=2,\n          ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n\n    The resulting subgraph also copies features from the parent graph.\n\n    >>> g.ndata['x'] = torch.arange(10).view(5, 2)\n    >>> sg = dgl.edge_subgraph(g, [0, 4])\n    >>> sg\n    Graph(num_nodes=3, num_edges=2,\n          ndata_schemes={'x': Scheme(shape=(2,), dtype=torch.int64),\n    
                     '_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> sg.ndata[dgl.NID]\n    tensor([0, 4, 1])\n    >>> sg.ndata['x']\n    tensor([[0, 1],\n            [8, 9],\n            [2, 3]])\n\n    Extract a subgraph from a heterogeneous graph.\n\n    >>> g = dgl.heterograph({\n    >>>     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),\n    >>>     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])\n    >>> })\n    >>> sub_g = dgl.edge_subgraph(g, {('user', 'follows', 'user'): [1, 2],\n    ...                               ('user', 'plays', 'game'): [2]})\n    >>> print(sub_g)\n    Graph(num_nodes={'game': 1, 'user': 2},\n          num_edges={('user', 'follows', 'user'): 2, ('user', 'plays', 'game'): 1},\n          metagraph=[('user', 'user', 'follows'), ('user', 'game', 'plays')])\n\n    See Also\n    --------\n    node_subgraph\n    \"\"\"\n    if graph.is_block and relabel_nodes:\n        raise DGLError(\"Extracting subgraph from a block graph is not allowed.\")\n    if not isinstance(edges, Mapping):\n        assert (\n            len(graph.canonical_etypes) == 1\n        ), \"need a dict of edge type and IDs for graph with multiple edge types\"\n        edges = {graph.canonical_etypes[0]: edges}\n\n    def _process_edges(etype, e):\n        if F.is_tensor(e) and F.dtype(e) == F.bool:\n            return F.astype(\n                F.nonzero_1d(F.copy_to(e, graph.device)), graph.idtype\n            )\n        else:\n            return utils.prepare_tensor(graph, e, 'edges[\"{}\"]'.format(etype))\n\n    edges = {graph.to_canonical_etype(etype): e for etype, e in edges.items()}\n    edges = {etype: _process_edges(etype, e) for etype, e in edges.items()}\n    device = context_of(edges)\n    induced_edges = [\n        edges.get(cetype, F.copy_to(F.tensor([], graph.idtype), device))\n        for cetype in graph.canonical_etypes\n    ]\n\n    sgi = 
graph._graph.edge_subgraph(induced_edges, not relabel_nodes)\n    induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device\n    subg = _create_hetero_subgraph(\n        graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids\n    )\n    return subg if output_device is None else subg.to(output_device)\n\n\nDGLGraph.edge_subgraph = utils.alias_func(edge_subgraph)\n\n\ndef in_subgraph(\n    graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None\n):\n    \"\"\"Return the subgraph induced on the inbound edges of all the edge types of the\n    given nodes.\n\n    An in subgraph is equivalent to creating a new graph using the incoming edges of the\n    given nodes. In addition to extracting the subgraph, DGL also copies the features of\n    the extracted nodes and edges to the resulting graph. The copy is *lazy* and incurs\n    data movement only when needed.\n\n    If the graph is heterogeneous, DGL extracts a subgraph per relation and composes\n    them as the resulting graph. Thus, the resulting graph has the same set of relations\n    as the input one.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The input graph.\n    nodes : nodes or dict[str, nodes]\n        The nodes to form the subgraph, which cannot have any duplicate value. The result\n        will be undefined otherwise. The allowed nodes formats are:\n\n        * Int Tensor: Each element is a node ID. 
The tensor must have the same device type\n          and ID data type as the graph's.\n        * iterable[int]: Each element is a node ID.\n\n        If the graph is homogeneous, one can directly pass the above formats.\n        Otherwise, the argument must be a dictionary with keys being node types\n        and values being the node IDs in the above formats.\n    relabel_nodes : bool, optional\n        If True, it will remove the isolated nodes and relabel the rest nodes in the\n        extracted subgraph.\n    store_ids : bool, optional\n        If True, it will store the raw IDs of the extracted edges in the ``edata`` of the\n        resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will\n        also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting\n        graph under name ``dgl.NID``.\n    output_device : Framework-specific device context object, optional\n        The output device.  Default is the same as the input graph.\n\n    Returns\n    -------\n    DGLGraph\n        The subgraph.\n\n    Notes\n    -----\n\n    This function discards the batch information. 
Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    Extract a subgraph from a homogeneous graph.\n\n    >>> g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]))  # 5-node cycle\n    >>> g.edata['w'] = torch.arange(10).view(5, 2)\n    >>> sg = dgl.in_subgraph(g, [2, 0])\n    >>> sg\n    Graph(num_nodes=5, num_edges=2,\n          ndata_schemes={}\n          edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),\n                         '_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> sg.edges()\n    (tensor([1, 4]), tensor([2, 0]))\n    >>> sg.edata[dgl.EID]  # original edge IDs\n    tensor([1, 4])\n    >>> sg.edata['w']  # also extract the features\n    tensor([[2, 3],\n            [8, 9]])\n\n    Extract a subgraph with node relabeling.\n\n    >>> sg = dgl.in_subgraph(g, [2, 0], relabel_nodes=True)\n    >>> sg\n    Graph(num_nodes=4, num_edges=2,\n          ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),\n                         '_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> sg.edges()\n    (tensor([1, 3]), tensor([2, 0]))\n    >>> sg.edata[dgl.EID]  # original edge IDs\n    tensor([1, 4])\n    >>> sg.ndata[dgl.NID]  # original node IDs\n    tensor([0, 1, 2, 4])\n\n    Extract a subgraph from a heterogeneous graph.\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),\n    ...     
('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])})\n    >>> sub_g = g.in_subgraph({'user': [2], 'game': [2]})\n    >>> sub_g\n    Graph(num_nodes={'game': 3, 'user': 3},\n          num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},\n          metagraph=[('user', 'game', 'plays'), ('user', 'user', 'follows')])\n\n    See also\n    --------\n    out_subgraph\n    \"\"\"\n    if graph.is_block:\n        raise DGLError(\"Extracting subgraph of a block graph is not allowed.\")\n    if not isinstance(nodes, dict):\n        if len(graph.ntypes) > 1:\n            raise DGLError(\n                \"Must specify node type when the graph is not homogeneous.\"\n            )\n        nodes = {graph.ntypes[0]: nodes}\n    nodes = utils.prepare_tensor_dict(graph, nodes, \"nodes\")\n    device = context_of(nodes)\n    nodes_all_types = [\n        F.to_dgl_nd(\n            nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device))\n        )\n        for ntype in graph.ntypes\n    ]\n\n    sgi = _CAPI_DGLInSubgraph(graph._graph, nodes_all_types, relabel_nodes)\n    induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device\n    induced_edges = sgi.induced_edges\n    subg = _create_hetero_subgraph(\n        graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids\n    )\n    return subg if output_device is None else subg.to(output_device)\n\n\nDGLGraph.in_subgraph = utils.alias_func(in_subgraph)\n\n\ndef out_subgraph(\n    graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None\n):\n    \"\"\"Return the subgraph induced on the outbound edges of all the edge types of the\n    given nodes.\n\n    An out subgraph is equivalent to creating a new graph using the outgoing edges of\n    the given nodes. In addition to extracting the subgraph, DGL also copies the features\n    of the extracted nodes and edges to the resulting graph. 
The copy is *lazy* and incurs\n    data movement only when needed.\n\n    If the graph is heterogeneous, DGL extracts a subgraph per relation and composes\n    them as the resulting graph. Thus, the resulting graph has the same set of relations\n    as the input one.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The input graph.\n    nodes : nodes or dict[str, nodes]\n        The nodes to form the subgraph, which cannot have any duplicate value. The result\n        will be undefined otherwise. The allowed nodes formats are:\n\n        * Int Tensor: Each element is a node ID. The tensor must have the same device type\n          and ID data type as the graph's.\n        * iterable[int]: Each element is a node ID.\n\n        If the graph is homogeneous, one can directly pass the above formats.\n        Otherwise, the argument must be a dictionary with keys being node types\n        and values being the node IDs in the above formats.\n    relabel_nodes : bool, optional\n        If True, it will remove the isolated nodes and relabel the rest nodes in the\n        extracted subgraph.\n    store_ids : bool, optional\n        If True, it will store the raw IDs of the extracted edges in the ``edata`` of the\n        resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will\n        also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting\n        graph under name ``dgl.NID``.\n    output_device : Framework-specific device context object, optional\n        The output device.  Default is the same as the input graph.\n\n    Returns\n    -------\n    DGLGraph\n        The subgraph.\n\n    Notes\n    -----\n\n    This function discards the batch information. 
Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    Extract a subgraph from a homogeneous graph.\n\n    >>> g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]))  # 5-node cycle\n    >>> g.edata['w'] = torch.arange(10).view(5, 2)\n    >>> sg = dgl.out_subgraph(g, [2, 0])\n    >>> sg\n    Graph(num_nodes=5, num_edges=2,\n          ndata_schemes={}\n          edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),\n                         '_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> sg.edges()\n    (tensor([2, 0]), tensor([3, 1]))\n    >>> sg.edata[dgl.EID]  # original edge IDs\n    tensor([2, 0])\n    >>> sg.edata['w']  # also extract the features\n    tensor([[4, 5],\n            [0, 1]])\n\n    Extract a subgraph with node labeling.\n\n    >>> sg = dgl.out_subgraph(g, [2, 0], relabel_nodes=True)\n    >>> sg\n    Graph(num_nodes=4, num_edges=2,\n          ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),\n                         '_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> sg.edges()\n    (tensor([2, 0]), tensor([3, 1]))\n    >>> sg.edata[dgl.EID]  # original edge IDs\n    tensor([2, 0])\n    >>> sg.ndata[dgl.NID]  # original node IDs\n    tensor([0, 1, 2, 3])\n\n    Extract a subgraph from a heterogeneous graph.\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),\n    ...     
('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])})\n    >>> sub_g = g.out_subgraph({'user': [1]})\n    >>> sub_g\n    Graph(num_nodes={'game': 3, 'user': 3},\n          num_edges={('user', 'plays', 'game'): 2, ('user', 'follows', 'user'): 2},\n          metagraph=[('user', 'game', 'plays'), ('user', 'user', 'follows')])\n\n    See also\n    --------\n    in_subgraph\n    \"\"\"\n    if graph.is_block:\n        raise DGLError(\"Extracting subgraph of a block graph is not allowed.\")\n    if not isinstance(nodes, dict):\n        if len(graph.ntypes) > 1:\n            raise DGLError(\n                \"Must specify node type when the graph is not homogeneous.\"\n            )\n        nodes = {graph.ntypes[0]: nodes}\n    nodes = utils.prepare_tensor_dict(graph, nodes, \"nodes\")\n    device = context_of(nodes)\n    nodes_all_types = [\n        F.to_dgl_nd(\n            nodes.get(ntype, F.copy_to(F.tensor([], graph.idtype), device))\n        )\n        for ntype in graph.ntypes\n    ]\n\n    sgi = _CAPI_DGLOutSubgraph(graph._graph, nodes_all_types, relabel_nodes)\n    induced_nodes_or_device = sgi.induced_nodes if relabel_nodes else device\n    induced_edges = sgi.induced_edges\n    subg = _create_hetero_subgraph(\n        graph, sgi, induced_nodes_or_device, induced_edges, store_ids=store_ids\n    )\n    return subg if output_device is None else subg.to(output_device)\n\n\nDGLGraph.out_subgraph = utils.alias_func(out_subgraph)\n\n\ndef khop_in_subgraph(\n    graph, nodes, k, *, relabel_nodes=True, store_ids=True, output_device=None\n):\n    \"\"\"Return the subgraph induced by k-hop in-neighborhood of the specified node(s).\n\n    We can expand a set of nodes by including the predecessors of them. From a\n    specified node set, a k-hop in subgraph is obtained by first repeating the node set\n    expansion for k times and then creating a node induced subgraph. 
In addition to\n    extracting the subgraph, DGL also copies the features of the extracted nodes and\n    edges to the resulting graph. The copy is *lazy* and incurs data movement only\n    when needed.\n\n    If the graph is heterogeneous, DGL extracts a subgraph per relation and composes\n    them as the resulting graph. Thus the resulting graph has the same set of relations\n    as the input one.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The input graph.\n    nodes : nodes or dict[str, nodes]\n        The starting node(s) to expand, which cannot have any duplicate value. The result\n        will be undefined otherwise. The allowed formats are:\n\n        * Int: ID of a single node.\n        * Int Tensor: Each element is a node ID. The tensor must have the same device\n          type and ID data type as the graph's.\n        * iterable[int]: Each element is a node ID.\n\n        If the graph is homogeneous, one can directly pass the above formats.\n        Otherwise, the argument must be a dictionary with keys being node types\n        and values being the node IDs in the above formats.\n    k : int\n        The number of hops.\n    relabel_nodes : bool, optional\n        If True, it will remove the isolated nodes and relabel the rest nodes in the\n        extracted subgraph.\n    store_ids : bool, optional\n        If True, it will store the raw IDs of the extracted edges in the ``edata`` of the\n        resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will\n        also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting\n        graph under name ``dgl.NID``.\n    output_device : Framework-specific device context object, optional\n        The output device.  Default is the same as the input graph.\n\n    Returns\n    -------\n    DGLGraph\n        The subgraph.\n    Tensor or dict[str, Tensor], optional\n        The new IDs of the input :attr:`nodes` after node relabeling. 
This is returned\n        only when :attr:`relabel_nodes` is True. It is in the same form as :attr:`nodes`.\n\n    Notes\n    -----\n\n    When k is 1, the result subgraph is different from the one obtained by\n    :func:`dgl.in_subgraph`. The 1-hop in subgraph also includes the edges\n    among the neighborhood.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    Extract a two-hop subgraph from a homogeneous graph.\n\n    >>> g = dgl.graph(([1, 1, 2, 3, 4], [0, 2, 0, 4, 2]))\n    >>> g.edata['w'] = torch.arange(10).view(5, 2)\n    >>> sg, inverse_indices = dgl.khop_in_subgraph(g, 0, k=2)\n    >>> sg\n    Graph(num_nodes=4, num_edges=4,\n          ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),\n                         '_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> sg.edges()\n    (tensor([1, 1, 2, 3]), tensor([0, 2, 0, 2]))\n    >>> sg.edata[dgl.EID]  # original edge IDs\n    tensor([0, 1, 2, 4])\n    >>> sg.edata['w']  # also extract the features\n    tensor([[0, 1],\n            [2, 3],\n            [4, 5],\n            [8, 9]])\n    >>> inverse_indices\n    tensor([0])\n\n    Extract a subgraph from a heterogeneous graph.\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),\n    ...     
('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])})\n    >>> sg, inverse_indices = dgl.khop_in_subgraph(g, {'game': 0}, k=2)\n    >>> sg\n    Graph(num_nodes={'game': 1, 'user': 2},\n          num_edges={('user', 'follows', 'user'): 1, ('user', 'plays', 'game'): 2},\n          metagraph=[('user', 'user', 'follows'), ('user', 'game', 'plays')])\n    >>> inverse_indices\n    {'game': tensor([0])}\n\n    See also\n    --------\n    khop_out_subgraph\n    \"\"\"\n    if graph.is_block:\n        raise DGLError(\"Extracting subgraph of a block graph is not allowed.\")\n\n    is_mapping = isinstance(nodes, Mapping)\n    if not is_mapping:\n        assert (\n            len(graph.ntypes) == 1\n        ), \"need a dict of node type and IDs for graph with multiple node types\"\n        nodes = {graph.ntypes[0]: nodes}\n\n    for nty, nty_nodes in nodes.items():\n        nodes[nty] = utils.prepare_tensor(\n            graph, nty_nodes, 'nodes[\"{}\"]'.format(nty)\n        )\n\n    last_hop_nodes = nodes\n    k_hop_nodes_ = [last_hop_nodes]\n    device = context_of(nodes)\n    place_holder = F.copy_to(F.tensor([], dtype=graph.idtype), device)\n    for _ in range(k):\n        current_hop_nodes = {nty: [] for nty in graph.ntypes}\n        for cetype in graph.canonical_etypes:\n            srctype, _, dsttype = cetype\n            in_nbrs, _ = graph.in_edges(\n                last_hop_nodes.get(dsttype, place_holder), etype=cetype\n            )\n            current_hop_nodes[srctype].append(in_nbrs)\n        for nty in graph.ntypes:\n            if len(current_hop_nodes[nty]) == 0:\n                current_hop_nodes[nty] = place_holder\n                continue\n            current_hop_nodes[nty] = F.unique(\n                F.cat(current_hop_nodes[nty], dim=0)\n            )\n        k_hop_nodes_.append(current_hop_nodes)\n        last_hop_nodes = current_hop_nodes\n\n    k_hop_nodes = dict()\n    inverse_indices = dict()\n    for nty in graph.ntypes:\n        
k_hop_nodes[nty], inverse_indices[nty] = F.unique(\n            F.cat(\n                [\n                    hop_nodes.get(nty, place_holder)\n                    for hop_nodes in k_hop_nodes_\n                ],\n                dim=0,\n            ),\n            return_inverse=True,\n        )\n\n    sub_g = node_subgraph(\n        graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids\n    )\n    if output_device is not None:\n        sub_g = sub_g.to(output_device)\n    if relabel_nodes:\n        if is_mapping:\n            seed_inverse_indices = dict()\n            for nty in nodes:\n                seed_inverse_indices[nty] = F.slice_axis(\n                    inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])\n                )\n        else:\n            seed_inverse_indices = F.slice_axis(\n                inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])\n            )\n        if output_device is not None:\n            seed_inverse_indices = recursive_apply(\n                seed_inverse_indices, lambda x: F.copy_to(x, output_device)\n            )\n        return sub_g, seed_inverse_indices\n    else:\n        return sub_g\n\n\nDGLGraph.khop_in_subgraph = utils.alias_func(khop_in_subgraph)\n\n\ndef khop_out_subgraph(\n    graph, nodes, k, *, relabel_nodes=True, store_ids=True, output_device=None\n):\n    \"\"\"Return the subgraph induced by k-hop out-neighborhood of the specified node(s).\n\n    We can expand a set of nodes by including the successors of them. From a\n    specified node set, a k-hop out subgraph is obtained by first repeating the node set\n    expansion for k times and then creating a node induced subgraph. In addition to\n    extracting the subgraph, DGL also copies the features of the extracted nodes and\n    edges to the resulting graph. 
The copy is *lazy* and incurs data movement only\n    when needed.\n\n    If the graph is heterogeneous, DGL extracts a subgraph per relation and composes\n    them as the resulting graph. Thus the resulting graph has the same set of relations\n    as the input one.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The input graph.\n    nodes : nodes or dict[str, nodes]\n        The starting node(s) to expand, which cannot have any duplicate value. The result\n        will be undefined otherwise. The allowed formats are:\n\n        * Int: ID of a single node.\n        * Int Tensor: Each element is a node ID. The tensor must have the same device\n          type and ID data type as the graph's.\n        * iterable[int]: Each element is a node ID.\n\n        If the graph is homogeneous, one can directly pass the above formats.\n        Otherwise, the argument must be a dictionary with keys being node types\n        and values being the node IDs in the above formats.\n    k : int\n        The number of hops.\n    relabel_nodes : bool, optional\n        If True, it will remove the isolated nodes and relabel the rest nodes in the\n        extracted subgraph.\n    store_ids : bool, optional\n        If True, it will store the raw IDs of the extracted edges in the ``edata`` of the\n        resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will\n        also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting\n        graph under name ``dgl.NID``.\n    output_device : Framework-specific device context object, optional\n        The output device.  Default is the same as the input graph.\n\n    Returns\n    -------\n    DGLGraph\n        The subgraph.\n    Tensor or dict[str, Tensor], optional\n        The new IDs of the input :attr:`nodes` after node relabeling. This is returned\n        only when :attr:`relabel_nodes` is True. 
It is in the same form as :attr:`nodes`.\n\n    Notes\n    -----\n\n    When k is 1, the result subgraph is different from the one obtained by\n    :func:`dgl.out_subgraph`. The 1-hop out subgraph also includes the edges\n    among the neighborhood.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    Extract a two-hop subgraph from a homogeneous graph.\n\n    >>> g = dgl.graph(([0, 2, 0, 4, 2], [1, 1, 2, 3, 4]))\n    >>> g.edata['w'] = torch.arange(10).view(5, 2)\n    >>> sg, inverse_indices = dgl.khop_out_subgraph(g, 0, k=2)\n    >>> sg\n    Graph(num_nodes=4, num_edges=4,\n          ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),\n                         '_ID': Scheme(shape=(), dtype=torch.int64)})\n    >>> sg.edges()\n    (tensor([0, 0, 2, 2]), tensor([1, 2, 1, 3]))\n    >>> sg.edata[dgl.EID]  # original edge IDs\n    tensor([0, 2, 1, 4])\n    >>> sg.edata['w']  # also extract the features\n    tensor([[0, 1],\n            [4, 5],\n            [2, 3],\n            [8, 9]])\n    >>> inverse_indices\n    tensor([0])\n\n    Extract a subgraph from a heterogeneous graph.\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),\n    ...     
('user', 'follows', 'user'): ([0, 1], [1, 3])})\n    >>> sg, inverse_indices = dgl.khop_out_subgraph(g, {'user': 0}, k=2)\n    >>> sg\n    Graph(num_nodes={'game': 2, 'user': 3},\n          num_edges={('user', 'follows', 'user'): 2, ('user', 'plays', 'game'): 2},\n          metagraph=[('user', 'user', 'follows'), ('user', 'game', 'plays')])\n    >>> inverse_indices\n    {'user': tensor([0])}\n\n    See also\n    --------\n    khop_in_subgraph\n    \"\"\"\n    if graph.is_block:\n        raise DGLError(\"Extracting subgraph of a block graph is not allowed.\")\n\n    is_mapping = isinstance(nodes, Mapping)\n    if not is_mapping:\n        assert (\n            len(graph.ntypes) == 1\n        ), \"need a dict of node type and IDs for graph with multiple node types\"\n        nodes = {graph.ntypes[0]: nodes}\n\n    for nty, nty_nodes in nodes.items():\n        nodes[nty] = utils.prepare_tensor(\n            graph, nty_nodes, 'nodes[\"{}\"]'.format(nty)\n        )\n\n    last_hop_nodes = nodes\n    k_hop_nodes_ = [last_hop_nodes]\n    device = context_of(nodes)\n    place_holder = F.copy_to(F.tensor([], dtype=graph.idtype), device)\n    for _ in range(k):\n        current_hop_nodes = {nty: [] for nty in graph.ntypes}\n        for cetype in graph.canonical_etypes:\n            srctype, _, dsttype = cetype\n            _, out_nbrs = graph.out_edges(\n                last_hop_nodes.get(srctype, place_holder), etype=cetype\n            )\n            current_hop_nodes[dsttype].append(out_nbrs)\n        for nty in graph.ntypes:\n            if len(current_hop_nodes[nty]) == 0:\n                current_hop_nodes[nty] = place_holder\n                continue\n            current_hop_nodes[nty] = F.unique(\n                F.cat(current_hop_nodes[nty], dim=0)\n            )\n        k_hop_nodes_.append(current_hop_nodes)\n        last_hop_nodes = current_hop_nodes\n\n    k_hop_nodes = dict()\n    inverse_indices = dict()\n    for nty in graph.ntypes:\n        k_hop_nodes[nty], 
inverse_indices[nty] = F.unique(\n            F.cat(\n                [\n                    hop_nodes.get(nty, place_holder)\n                    for hop_nodes in k_hop_nodes_\n                ],\n                dim=0,\n            ),\n            return_inverse=True,\n        )\n\n    sub_g = node_subgraph(\n        graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids\n    )\n    if output_device is not None:\n        sub_g = sub_g.to(output_device)\n    if relabel_nodes:\n        if is_mapping:\n            seed_inverse_indices = dict()\n            for nty in nodes:\n                seed_inverse_indices[nty] = F.slice_axis(\n                    inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])\n                )\n        else:\n            seed_inverse_indices = F.slice_axis(\n                inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])\n            )\n        if output_device is not None:\n            seed_inverse_indices = recursive_apply(\n                seed_inverse_indices, lambda x: F.copy_to(x, output_device)\n            )\n        return sub_g, seed_inverse_indices\n    else:\n        return sub_g\n\n\nDGLGraph.khop_out_subgraph = utils.alias_func(khop_out_subgraph)\n\n\ndef node_type_subgraph(graph, ntypes, output_device=None):\n    \"\"\"Return the subgraph induced on given node types.\n\n    A node-type-induced subgraph contains all the nodes of the given subset of\n    the node types of a graph and any edges whose endpoints are both in this subset.\n    In addition to extracting the subgraph, DGL also copies the features of the\n    extracted nodes and edges to the resulting graph.\n    The copy is *lazy* and incurs data movement only when needed.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph to extract subgraphs from.\n    ntypes : list[str]\n        The type names of the nodes in the subgraph.\n    output_device : Framework-specific device context object, optional\n        The 
output device.  Default is the same as the input graph.\n\n    Returns\n    -------\n    G : DGLGraph\n        The subgraph.\n\n    Notes\n    -----\n\n    This function discards the batch information. Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    Instantiate a heterograph.\n\n    >>> g = dgl.heterograph({\n    >>>     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),\n    >>>     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])\n    >>> })\n    >>> # Set node features\n    >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])\n\n    Get subgraphs.\n\n    >>> sub_g = g.node_type_subgraph(['user'])\n    >>> print(sub_g)\n    Graph(num_nodes=3, num_edges=3,\n          ndata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)}\n          edata_schemes={})\n\n    Get the extracted node features.\n\n    >>> sub_g.nodes['user'].data['h']\n    tensor([[0.],\n            [1.],\n            [2.]])\n\n    See Also\n    --------\n    edge_type_subgraph\n    \"\"\"\n    ntid = [graph.get_ntype_id(ntype) for ntype in ntypes]\n    stids, dtids, etids = graph._graph.metagraph.edges(\"eid\")\n    stids, dtids, etids = stids.tonumpy(), dtids.tonumpy(), etids.tonumpy()\n    etypes = []\n    for stid, dtid, etid in zip(stids, dtids, etids):\n        if stid in ntid and dtid in ntid:\n            etypes.append(graph.canonical_etypes[etid])\n    if len(etypes) == 0:\n        raise DGLError(\"There are no edges among nodes of the specified types.\")\n    return edge_type_subgraph(graph, etypes, output_device=output_device)\n\n\nDGLGraph.node_type_subgraph = utils.alias_func(node_type_subgraph)\n\n\ndef edge_type_subgraph(graph, etypes, output_device=None):\n    \"\"\"Return the subgraph induced on given edge types.\n\n 
   An edge-type-induced subgraph contains all the edges of the given subset of\n    the edge types of a graph. It also contains all nodes of a particular type\n    if some nodes of the type are incident to these edges.\n    In addition to extracting the subgraph, DGL also copies the features of the\n    extracted nodes and edges to the resulting graph.\n    The copy is *lazy* and incurs data movement only when needed.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph to extract subgraphs from.\n    etypes : list[str] or list[(str, str, str)]\n        The type names of the edges in the subgraph. The allowed type name\n        formats are:\n\n        * ``(str, str, str)`` for source node type, edge type and destination node type.\n        * or one ``str`` for the edge type name  if the name can uniquely identify a\n          triplet format in the graph.\n    output_device : Framework-specific device context object, optional\n        The output device.  Default is the same as the input graph.\n\n    Returns\n    -------\n    G : DGLGraph\n        The subgraph.\n\n    Notes\n    -----\n\n    This function discards the batch information. 
Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    Instantiate a heterograph.\n\n    >>> g = dgl.heterograph({\n    >>>     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),\n    >>>     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])\n    >>> })\n    >>> # Set edge features\n    >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])\n\n    Get subgraphs.\n\n    >>> sub_g = g.edge_type_subgraph(['follows'])\n    >>> sub_g\n    Graph(num_nodes=3, num_edges=3,\n          ndata_schemes={}\n          edata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)})\n\n    Get the shared edge features.\n\n    >>> sub_g.edges['follows'].data['h']\n    tensor([[0.],\n            [1.],\n            [2.]])\n\n    See Also\n    --------\n    node_type_subgraph\n    \"\"\"\n    etype_ids = [graph.get_etype_id(etype) for etype in etypes]\n    # meta graph is homogeneous graph, still using int64\n    meta_src, meta_dst, _ = graph._graph.metagraph.find_edges(\n        utils.toindex(etype_ids, \"int64\")\n    )\n    rel_graphs = [graph._graph.get_relation_graph(i) for i in etype_ids]\n    meta_src = meta_src.tonumpy()\n    meta_dst = meta_dst.tonumpy()\n    ntypes_invmap = {n: i for i, n in enumerate(set(meta_src) | set(meta_dst))}\n    mapped_meta_src = [ntypes_invmap[v] for v in meta_src]\n    mapped_meta_dst = [ntypes_invmap[v] for v in meta_dst]\n    node_frames = [graph._node_frames[i] for i in ntypes_invmap]\n    edge_frames = [graph._edge_frames[i] for i in etype_ids]\n    induced_ntypes = [graph._ntypes[i] for i in ntypes_invmap]\n    induced_etypes = [\n        graph._etypes[i] for i in etype_ids\n    ]  # get the \"name\" of edge type\n    num_nodes_per_induced_type = [\n        graph.num_nodes(ntype) for 
ntype in induced_ntypes\n    ]\n\n    metagraph = graph_index.from_edge_list(\n        (mapped_meta_src, mapped_meta_dst), True\n    )\n    # num_nodes_per_type should be int64\n    hgidx = heterograph_index.create_heterograph_from_relations(\n        metagraph,\n        rel_graphs,\n        utils.toindex(num_nodes_per_induced_type, \"int64\"),\n    )\n    hg = DGLGraph(\n        hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames\n    )\n    return hg if output_device is None else hg.to(output_device)\n\n\nDGLGraph.edge_type_subgraph = utils.alias_func(edge_type_subgraph)\n\n#################### Internal functions ####################\n\n\ndef _create_hetero_subgraph(\n    parent,\n    sgi,\n    induced_nodes_or_device,\n    induced_edges_or_device,\n    store_ids=True,\n):\n    \"\"\"Internal function to create a subgraph.\n\n    Parameters\n    ----------\n    parent : DGLGraph\n        The parent DGLGraph.\n    sgi : HeteroSubgraphIndex\n        Subgraph object returned by CAPI.\n    induced_nodes_or_device : list[Tensor] or device or None\n        Induced node IDs or the device. Will store it as the dgl.NID ndata unless it\n        is None, which means the induced node IDs are the same as the parent node IDs.\n        If a device is given, the features will be copied to the given device.\n    induced_edges_or_device : list[Tensor] or device or None\n        Induced edge IDs. 
Will store it as the dgl.EID ndata unless it\n        is None, which means the induced edge IDs are the same as the parent edge IDs.\n        If a device is given, the features will be copied to the given device.\n    store_ids : bool\n        If True and induced_nodes is not None, it will store the raw IDs of the extracted\n        nodes in the ``ndata`` of the resulting graph under name ``dgl.NID``.\n        If True and induced_edges is not None, it will store the raw IDs of the extracted\n        edges in the ``edata`` of the resulting graph under name ``dgl.EID``.\n\n    Returns\n    -------\n    DGLGraph\n        Graph\n    \"\"\"\n    # (BarclayII) Giving a device argument to induced_nodes_or_device is necessary for\n    # UVA subgraphing, where the node features are not sliced but the device changed.\n    # Not having this will give us a subgraph on GPU but node features on CPU if we don't\n    # relabel the nodes.\n    node_frames = utils.extract_node_subframes(\n        parent, induced_nodes_or_device, store_ids\n    )\n    edge_frames = utils.extract_edge_subframes(\n        parent, induced_edges_or_device, store_ids\n    )\n    hsg = DGLGraph(sgi.graph, parent.ntypes, parent.etypes)\n    utils.set_new_frames(hsg, node_frames=node_frames, edge_frames=edge_frames)\n    return hsg\n\n\n_init_api(\"dgl.subgraph\")\n"
  },
  {
    "path": "python/dgl/transforms/__init__.py",
    "content": "\"\"\"Transform for structures and features\"\"\"\nfrom .functional import *\nfrom .module import *\nfrom .to_block import *\n"
  },
  {
    "path": "python/dgl/transforms/functional.py",
    "content": "##\n#   Copyright 2019-2021 Contributors\n#\n#   Licensed under the Apache License, Version 2.0 (the \"License\");\n#   you may not use this file except in compliance with the License.\n#   You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#   Unless required by applicable law or agreed to in writing, software\n#   distributed under the License is distributed on an \"AS IS\" BASIS,\n#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#   See the License for the specific language governing permissions and\n#   limitations under the License.\n#\n\"\"\"Functional interface for transform\"\"\"\n# pylint: disable= too-many-lines\n\nimport copy\nfrom collections.abc import Iterable, Mapping\n\nimport numpy as np\nimport scipy.sparse as sparse\nimport scipy.sparse.linalg\n\nfrom ..utils import version\n\ntry:\n    import torch as th\nexcept ImportError:\n    pass\n\nfrom .. import (\n    backend as F,\n    batch,\n    convert,\n    function,\n    ndarray as nd,\n    subgraph,\n    utils,\n)\nfrom .._ffi.function import _init_api\nfrom ..base import dgl_warning, DGLError, EID, NID\nfrom ..frame import Frame\nfrom ..heterograph import DGLGraph\nfrom ..heterograph_index import (\n    create_heterograph_from_relations,\n    create_metagraph_index,\n)\nfrom ..partition import (\n    metis_partition,\n    metis_partition_assignment,\n    partition_graph_with_halo,\n)\nfrom ..sampling.neighbor import sample_neighbors\n\n__all__ = [\n    \"line_graph\",\n    \"khop_adj\",\n    \"khop_graph\",\n    \"reverse\",\n    \"to_bidirected\",\n    \"add_reverse_edges\",\n    \"laplacian_lambda_max\",\n    \"knn_graph\",\n    \"segmented_knn_graph\",\n    \"add_edges\",\n    \"add_nodes\",\n    \"remove_edges\",\n    \"remove_nodes\",\n    \"add_self_loop\",\n    \"remove_self_loop\",\n    \"metapath_reachable_graph\",\n    \"compact_graphs\",\n    \"to_simple\",\n    \"to_simple_graph\",\n    
\"sort_csr_by_tag\",\n    \"sort_csc_by_tag\",\n    \"metis_partition_assignment\",\n    \"partition_graph_with_halo\",\n    \"metis_partition\",\n    \"adj_product_graph\",\n    \"adj_sum_graph\",\n    \"reorder_graph\",\n    \"norm_by_dst\",\n    \"radius_graph\",\n    \"random_walk_pe\",\n    \"laplacian_pe\",\n    \"lap_pe\",\n    \"to_bfloat16\",\n    \"to_half\",\n    \"to_float\",\n    \"to_double\",\n    \"double_radius_node_labeling\",\n    \"shortest_dist\",\n    \"svd_pe\",\n]\n\n\ndef pairwise_squared_distance(x):\n    \"\"\"\n    x : (n_samples, n_points, dims)\n    return : (n_samples, n_points, n_points)\n    \"\"\"\n    x2s = F.sum(x * x, -1, True)\n    # assuming that __matmul__ is always implemented (true for PyTorch, MXNet and Chainer)\n    return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2)\n\n\n# pylint: disable=invalid-name\ndef knn_graph(\n    x, k, algorithm=\"bruteforce-blas\", dist=\"euclidean\", exclude_self=False\n):\n    r\"\"\"Construct a graph from a set of points according to k-nearest-neighbor (KNN)\n    and return.\n\n    The function transforms the coordinates/features of a point set\n    into a directed homogeneous graph. The coordinates of the point\n    set is specified as a matrix whose rows correspond to points and\n    columns correspond to coordinate/feature dimensions.\n\n    The nodes of the returned graph correspond to the points, where the predecessors\n    of each point are its k-nearest neighbors measured by the chosen distance.\n\n    If :attr:`x` is a 3D tensor, then each submatrix will be transformed\n    into a separate graph. DGL then composes the graphs into a large batched\n    graph of multiple (:math:`shape(x)[0]`) connected components.\n\n    See :doc:`the benchmark <../api/python/knn_benchmark>` for a complete benchmark result.\n\n    Parameters\n    ----------\n    x : Tensor\n        The point coordinates. 
It can be either on CPU or GPU.\n\n        * If is 2D, ``x[i]`` corresponds to the i-th node in the KNN graph.\n\n        * If is 3D, ``x[i]`` corresponds to the i-th KNN graph and\n          ``x[i][j]`` corresponds to the j-th node in the i-th KNN graph.\n    k : int\n        The number of nearest neighbors per node.\n    algorithm : str, optional\n        Algorithm used to compute the k-nearest neighbors.\n\n        * 'bruteforce-blas' will first compute the distance matrix\n          using BLAS matrix multiplication operation provided by\n          backend frameworks. Then use topk algorithm to get\n          k-nearest neighbors. This method is fast when the point\n          set is small but has :math:`O(N^2)` memory complexity where\n          :math:`N` is the number of points.\n\n        * 'bruteforce' will compute distances pair by pair and\n          directly select the k-nearest neighbors during distance\n          computation. This method is slower than 'bruteforce-blas'\n          but has less memory overhead (i.e., :math:`O(Nk)` where :math:`N`\n          is the number of points, :math:`k` is the number of nearest\n          neighbors per node) since we do not need to store all distances.\n\n        * 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'\n          but use shared memory in CUDA devices for buffer. This method is\n          faster than 'bruteforce' when the dimension of input points\n          is not large. This method is only available on CUDA device.\n\n        * 'kd-tree' will use the kd-tree algorithm (CPU only).\n          This method is suitable for low-dimensional data (e.g. 3D\n          point clouds)\n\n        * 'nn-descent' is an approximate approach from paper\n          `Efficient k-nearest neighbor graph construction for generic similarity\n          measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. 
This method\n          will search for nearest neighbor candidates in \"neighbors' neighbors\".\n\n        (default: 'bruteforce-blas')\n    dist : str, optional\n        The distance metric used to compute distance between points. It can be the following\n        metrics:\n        * 'euclidean': Use Euclidean distance (L2 norm) :math:`\\sqrt{\\sum_{i} (x_{i} - y_{i})^{2}}`.\n        * 'cosine': Use cosine distance.\n        (default: 'euclidean')\n    exclude_self : bool, optional\n        If True, the output graph will not contain self loop edges, and each node will not\n        be counted as one of its own k neighbors.  If False, the output graph will contain\n        self loop edges, and a node will be counted as one of its own k neighbors.\n\n    Returns\n    -------\n    DGLGraph\n        The constructed graph. The node IDs are in the same order as :attr:`x`.\n\n    Examples\n    --------\n\n    The following examples use PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    When :attr:`x` is a 2D tensor, a single KNN graph is constructed.\n\n    >>> x = torch.tensor([[0.0, 0.0, 1.0],\n    ...                   [1.0, 0.5, 0.5],\n    ...                   [0.5, 0.2, 0.2],\n    ...                   [0.3, 0.2, 0.4]])\n    >>> knn_g = dgl.knn_graph(x, 2)  # Each node has two predecessors\n    >>> knn_g.edges()\n    (tensor([0, 1, 2, 2, 2, 3, 3, 3]), tensor([0, 1, 1, 2, 3, 0, 2, 3]))\n\n    When :attr:`x` is a 3D tensor, DGL constructs multiple KNN graphs and\n    and then composes them into a graph of multiple connected components.\n\n    >>> x1 = torch.tensor([[0.0, 0.0, 1.0],\n    ...                    [1.0, 0.5, 0.5],\n    ...                    [0.5, 0.2, 0.2],\n    ...                    [0.3, 0.2, 0.4]])\n    >>> x2 = torch.tensor([[0.0, 1.0, 1.0],\n    ...                    [0.3, 0.3, 0.3],\n    ...                    [0.4, 0.4, 1.0],\n    ...                    
[0.3, 0.8, 0.2]])\n    >>> x = torch.stack([x1, x2], dim=0)\n    >>> knn_g = dgl.knn_graph(x, 2)  # Each node has two predecessors\n    >>> knn_g.edges()\n    (tensor([0, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7]),\n     tensor([0, 1, 1, 2, 3, 0, 2, 3, 4, 5, 6, 7, 4, 6, 5, 7]))\n    \"\"\"\n    if exclude_self:\n        # add 1 to k, for the self edge, since it will be removed\n        k = k + 1\n\n    # check invalid k\n    if k <= 0:\n        raise DGLError(\"Invalid k value. expect k > 0, got k = {}\".format(k))\n\n    # check empty point set\n    x_size = tuple(F.shape(x))\n    if x_size[0] == 0:\n        raise DGLError(\"Find empty point set\")\n\n    d = F.ndim(x)\n    x_seg = x_size[0] * [x_size[1]] if d == 3 else [x_size[0]]\n    if algorithm == \"bruteforce-blas\":\n        result = _knn_graph_blas(x, k, dist=dist)\n    else:\n        if d == 3:\n            x = F.reshape(x, (x_size[0] * x_size[1], x_size[2]))\n        out = knn(k, x, x_seg, algorithm=algorithm, dist=dist)\n        row, col = out[1], out[0]\n        result = convert.graph((row, col))\n\n    if d == 3:\n        # set batch information if x is 3D\n        num_nodes = F.tensor(x_seg, dtype=F.int64).to(F.context(x))\n        result.set_batch_num_nodes(num_nodes)\n        # if any segment is too small for k, all algorithms reduce k for all segments\n        clamped_k = min(k, np.min(x_seg))\n        result.set_batch_num_edges(clamped_k * num_nodes)\n\n    if exclude_self:\n        # remove_self_loop will update batch_num_edges as needed\n        result = remove_self_loop(result)\n\n        # If there were more than k(+1) coincident points, there may not have been self loops on\n        # all nodes, in which case there would still be one too many out edges on some nodes.\n        # However, if every node had a self edge, the common case, every node would still have the\n        # same degree as each other, so we can check that condition easily.\n        # The -1 is for the self edge 
removal.\n        clamped_k = min(k, np.min(x_seg)) - 1\n        if result.num_edges() != clamped_k * result.num_nodes():\n            # edges on any nodes with too high degree should all be length zero,\n            # so pick an arbitrary one to remove from each such node\n            degrees = result.in_degrees()\n            node_indices = F.nonzero_1d(degrees > clamped_k)\n            edges_to_remove_graph = sample_neighbors(\n                result, node_indices, 1, edge_dir=\"in\"\n            )\n            edge_ids = edges_to_remove_graph.edata[EID]\n            result = remove_edges(result, edge_ids)\n\n    return result\n\n\ndef _knn_graph_blas(x, k, dist=\"euclidean\"):\n    r\"\"\"Construct a graph from a set of points according to k-nearest-neighbor (KNN).\n\n    This function first compute the distance matrix using BLAS matrix multiplication\n    operation provided by backend frameworks. Then use topk algorithm to get\n    k-nearest neighbors.\n\n    Parameters\n    ----------\n    x : Tensor\n        The point coordinates. It can be either on CPU or GPU.\n\n        * If is 2D, ``x[i]`` corresponds to the i-th node in the KNN graph.\n\n        * If is 3D, ``x[i]`` corresponds to the i-th KNN graph and\n          ``x[i][j]`` corresponds to the j-th node in the i-th KNN graph.\n    k : int\n        The number of nearest neighbors per node.\n    dist : str, optional\n        The distance metric used to compute distance between points. 
It can be the following\n        metrics:\n        * 'euclidean': Use Euclidean distance (L2 norm) :math:`\\sqrt{\\sum_{i} (x_{i} - y_{i})^{2}}`.\n        * 'cosine': Use cosine distance.\n        (default: 'euclidean')\n    \"\"\"\n    if F.ndim(x) == 2:\n        x = F.unsqueeze(x, 0)\n    n_samples, n_points, _ = F.shape(x)\n\n    if k > n_points:\n        dgl_warning(\n            \"'k' should be less than or equal to the number of points in 'x'\"\n            \"expect k <= {0}, got k = {1}, use k = {0}\".format(n_points, k)\n        )\n        k = n_points\n\n    # if use cosine distance, normalize input points first\n    # thus we can use euclidean distance to find knn equivalently.\n    if dist == \"cosine\":\n        l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=2, keepdims=True))\n        x = x / (l2_norm(x) + 1e-5)\n\n    ctx = F.context(x)\n    dist = pairwise_squared_distance(x)\n    k_indices = F.astype(F.argtopk(dist, k, 2, descending=False), F.int64)\n    # index offset for each sample\n    offset = F.arange(0, n_samples, ctx=ctx) * n_points\n    offset = F.unsqueeze(offset, 1)\n    src = F.reshape(k_indices, (n_samples, n_points * k))\n    src = F.unsqueeze(src, 0) + offset\n    dst = F.repeat(F.arange(0, n_points, ctx=ctx), k, dim=0)\n    dst = F.unsqueeze(dst, 0) + offset\n    return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))\n\n\n# pylint: disable=invalid-name\ndef segmented_knn_graph(\n    x,\n    k,\n    segs,\n    algorithm=\"bruteforce-blas\",\n    dist=\"euclidean\",\n    exclude_self=False,\n):\n    r\"\"\"Construct multiple graphs from multiple sets of points according to\n    k-nearest-neighbor (KNN) and return.\n\n    Compared with :func:`dgl.knn_graph`, this allows multiple point sets with\n    different capacity. The points from different sets are stored contiguously\n    in the :attr:`x` tensor.\n    :attr:`segs` specifies the number of points in each point set. 
The\n    function constructs a KNN graph for each point set, where the predecessors\n    of each point are its k-nearest neighbors measured by the Euclidean distance.\n    DGL then composes all KNN graphs\n    into a batched graph with multiple (:math:`len(segs)`) connected components.\n\n    Parameters\n    ----------\n    x : Tensor\n        Coordinates/features of points. Must be 2D. It can be either on CPU or GPU.\n    k : int\n        The number of nearest neighbors per node.\n    segs : list[int]\n        Number of points in each point set. The numbers in :attr:`segs`\n        must sum up to the number of rows in :attr:`x`.\n    algorithm : str, optional\n        Algorithm used to compute the k-nearest neighbors.\n\n        * 'bruteforce-blas' will first compute the distance matrix\n          using BLAS matrix multiplication operation provided by\n          backend frameworks. Then use topk algorithm to get\n          k-nearest neighbors. This method is fast when the point\n          set is small but has :math:`O(N^2)` memory complexity where\n          :math:`N` is the number of points.\n\n        * 'bruteforce' will compute distances pair by pair and\n          directly select the k-nearest neighbors during distance\n          computation. This method is slower than 'bruteforce-blas'\n          but has less memory overhead (i.e., :math:`O(Nk)` where :math:`N`\n          is the number of points, :math:`k` is the number of nearest\n          neighbors per node) since we do not need to store all distances.\n\n        * 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'\n          but use shared memory in CUDA devices for buffer. This method is\n          faster than 'bruteforce' when the dimension of input points\n          is not large. This method is only available on CUDA device.\n\n        * 'kd-tree' will use the kd-tree algorithm (CPU only).\n          This method is suitable for low-dimensional data (e.g. 
3D\n          point clouds)\n\n        * 'nn-descent' is an approximate approach from paper\n          `Efficient k-nearest neighbor graph construction for generic similarity\n          measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method\n          will search for nearest neighbor candidates in \"neighbors' neighbors\".\n\n        (default: 'bruteforce-blas')\n    dist : str, optional\n        The distance metric used to compute distance between points. It can be the following\n        metrics:\n        * 'euclidean': Use Euclidean distance (L2 norm) :math:`\\sqrt{\\sum_{i} (x_{i} - y_{i})^{2}}`.\n        * 'cosine': Use cosine distance.\n        (default: 'euclidean')\n    exclude_self : bool, optional\n        If True, the output graph will not contain self loop edges, and each node will not\n        be counted as one of its own k neighbors.  If False, the output graph will contain\n        self loop edges, and a node will be counted as one of its own k neighbors.\n\n    Returns\n    -------\n    DGLGraph\n        The batched graph. The node IDs are in the same order as :attr:`x`.\n\n    Examples\n    --------\n\n    The following examples use PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    In the example below, the first point set has three points\n    and the second point set has four points.\n\n    >>> # Features/coordinates of the first point set\n    >>> x1 = torch.tensor([[0.0, 0.5, 0.2],\n    ...                    [0.1, 0.3, 0.2],\n    ...                    [0.4, 0.2, 0.2]])\n    >>> # Features/coordinates of the second point set\n    >>> x2 = torch.tensor([[0.3, 0.2, 0.1],\n    ...                    [0.5, 0.2, 0.3],\n    ...                    [0.1, 0.1, 0.2],\n    ...                    
[0.6, 0.3, 0.3]])\n    >>> x = torch.cat([x1, x2], dim=0)\n    >>> segs = [x1.shape[0], x2.shape[0]]\n    >>> knn_g = dgl.segmented_knn_graph(x, 2, segs)\n    >>> knn_g.edges()\n    (tensor([0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]),\n     tensor([0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6]))\n    \"\"\"\n    if exclude_self:\n        # add 1 to k, for the self edge, since it will be removed\n        k = k + 1\n\n    # check invalid k\n    if k <= 0:\n        raise DGLError(\"Invalid k value. expect k > 0, got k = {}\".format(k))\n\n    # check empty point set\n    if F.shape(x)[0] == 0:\n        raise DGLError(\"Find empty point set\")\n\n    if algorithm == \"bruteforce-blas\":\n        result = _segmented_knn_graph_blas(x, k, segs, dist=dist)\n    else:\n        out = knn(k, x, segs, algorithm=algorithm, dist=dist)\n        row, col = out[1], out[0]\n        result = convert.graph((row, col))\n\n    num_nodes = F.tensor(segs, dtype=F.int64).to(F.context(x))\n    result.set_batch_num_nodes(num_nodes)\n    # if any segment is too small for k, all algorithms reduce k for all segments\n    clamped_k = min(k, np.min(segs))\n    result.set_batch_num_edges(clamped_k * num_nodes)\n\n    if exclude_self:\n        # remove_self_loop will update batch_num_edges as needed\n        result = remove_self_loop(result)\n\n        # If there were more than k(+1) coincident points, there may not have been self loops on\n        # all nodes, in which case there would still be one too many out edges on some nodes.\n        # However, if every node had a self edge, the common case, every node would still have the\n        # same degree as each other, so we can check that condition easily.\n        # The -1 is for the self edge removal.\n        clamped_k = min(k, np.min(segs)) - 1\n        if result.num_edges() != clamped_k * result.num_nodes():\n            # edges on any nodes with too high degree should all be length zero,\n            # so pick an arbitrary one to remove from each 
such node\n            degrees = result.in_degrees()\n            node_indices = F.nonzero_1d(degrees > clamped_k)\n            edges_to_remove_graph = sample_neighbors(\n                result, node_indices, 1, edge_dir=\"in\"\n            )\n            edge_ids = edges_to_remove_graph.edata[EID]\n            result = remove_edges(result, edge_ids)\n\n    return result\n\n\ndef _segmented_knn_graph_blas(x, k, segs, dist=\"euclidean\"):\n    r\"\"\"Construct multiple graphs from multiple sets of points according to\n    k-nearest-neighbor (KNN).\n\n    This function first compute the distance matrix using BLAS matrix multiplication\n    operation provided by backend frameworks. Then use topk algorithm to get\n    k-nearest neighbors.\n\n    Parameters\n    ----------\n    x : Tensor\n        Coordinates/features of points. Must be 2D. It can be either on CPU or GPU.\n    k : int\n        The number of nearest neighbors per node.\n    segs : list[int]\n        Number of points in each point set. The numbers in :attr:`segs`\n        must sum up to the number of rows in :attr:`x`.\n    dist : str, optional\n        The distance metric used to compute distance between points. 
It can be the following\n        metrics:\n        * 'euclidean': Use Euclidean distance (L2 norm) :math:`\\sqrt{\\sum_{i} (x_{i} - y_{i})^{2}}`.\n        * 'cosine': Use cosine distance.\n        (default: 'euclidean')\n    \"\"\"\n    # if use cosine distance, normalize input points first\n    # thus we can use euclidean distance to find knn equivalently.\n    if dist == \"cosine\":\n        l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True))\n        x = x / (l2_norm(x) + 1e-5)\n\n    n_total_points, _ = F.shape(x)\n    offset = np.insert(np.cumsum(segs), 0, 0)\n    min_seg_size = np.min(segs)\n    if k > min_seg_size:\n        dgl_warning(\n            \"'k' should be less than or equal to the number of points in 'x'\"\n            \"expect k <= {0}, got k = {1}, use k = {0}\".format(min_seg_size, k)\n        )\n        k = min_seg_size\n\n    h_list = F.split(x, segs, 0)\n    src = [\n        F.argtopk(pairwise_squared_distance(h_g), k, 1, descending=False)\n        + int(offset[i])\n        for i, h_g in enumerate(h_list)\n    ]\n    src = F.cat(src, 0)\n    ctx = F.context(x)\n    dst = F.repeat(F.arange(0, n_total_points, ctx=ctx), k, dim=0)\n    return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))\n\n\ndef _nndescent_knn_graph(\n    x,\n    k,\n    segs,\n    num_iters=None,\n    max_candidates=None,\n    delta=0.001,\n    sample_rate=0.5,\n    dist=\"euclidean\",\n):\n    r\"\"\"Construct multiple graphs from multiple sets of points according to\n    **approximate** k-nearest-neighbor using NN-descent algorithm from paper\n    `Efficient k-nearest neighbor graph construction for generic similarity\n    measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_.\n\n    Parameters\n    ----------\n    x : Tensor\n        Coordinates/features of points. Must be 2D. It can be either on CPU or GPU.\n    k : int\n        The number of nearest neighbors per node.\n    segs : list[int]\n        Number of points in each point set. 
The numbers in :attr:`segs`\n        must sum up to the number of rows in :attr:`x`.\n    num_iters : int, optional\n        The maximum number of NN-descent iterations to perform. A value will be\n        chosen based on the size of input by default.\n        (Default: None)\n    max_candidates : int, optional\n        The maximum number of candidates to be considered during one iteration.\n        Larger values will provide more accurate search results later, but\n        potentially at non-negligible computation cost. A value will be chosen\n        based on the number of neighbors by default.\n        (Default: None)\n    delta : float, optional\n        A value controls the early abort. This function will abort if\n        :math:`k * N * delta > c`, where :math:`N` is the number of points,\n        :math:`c` is the number of updates during last iteration.\n        (Default: 0.001)\n    sample_rate : float, optional\n        A value controls how many candidates sampled. It should be a float value\n        between 0 and 1. Larger values will provide higher accuracy and converge\n        speed but with higher time cost.\n        (Default: 0.5)\n    dist : str, optional\n        The distance metric used to compute distance between points. It can be the following\n        metrics:\n        * 'euclidean': Use Euclidean distance (L2 norm) :math:`\\sqrt{\\sum_{i} (x_{i} - y_{i})^{2}}`.\n        * 'cosine': Use cosine distance.\n        (default: 'euclidean')\n\n    Returns\n    -------\n    DGLGraph\n        The graph. 
The node IDs are in the same order as :attr:`x`.\n    \"\"\"\n    num_points, _ = F.shape(x)\n    if isinstance(segs, (tuple, list)):\n        segs = F.tensor(segs)\n    segs = F.copy_to(segs, F.context(x))\n\n    if max_candidates is None:\n        max_candidates = min(60, k)\n    if num_iters is None:\n        num_iters = max(10, int(round(np.log2(num_points))))\n    max_candidates = int(sample_rate * max_candidates)\n\n    # if use cosine distance, normalize input points first\n    # thus we can use euclidean distance to find knn equivalently.\n    if dist == \"cosine\":\n        l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True))\n        x = x / (l2_norm(x) + 1e-5)\n\n    # k must less than or equal to min(segs)\n    if k > F.min(segs, dim=0):\n        raise DGLError(\n            \"'k' must be less than or equal to the number of points in 'x'\"\n            \"expect 'k' <= {}, got 'k' = {}\".format(F.min(segs, dim=0), k)\n        )\n    if delta < 0 or delta > 1:\n        raise DGLError(\"'delta' must in [0, 1], got 'delta' = {}\".format(delta))\n\n    offset = F.zeros((F.shape(segs)[0] + 1,), F.dtype(segs), F.context(segs))\n    offset[1:] = F.cumsum(segs, dim=0)\n    out = F.zeros((2, num_points * k), F.dtype(segs), F.context(segs))\n\n    # points, offsets, out, k, num_iters, max_candidates, delta\n    _CAPI_DGLNNDescent(\n        F.to_dgl_nd(x),\n        F.to_dgl_nd(offset),\n        F.zerocopy_to_dgl_ndarray_for_write(out),\n        k,\n        num_iters,\n        max_candidates,\n        delta,\n    )\n    return out\n\n\ndef knn(\n    k, x, x_segs, y=None, y_segs=None, algorithm=\"bruteforce\", dist=\"euclidean\"\n):\n    r\"\"\"For each element in each segment in :attr:`y`, find :attr:`k` nearest\n    points in the same segment in :attr:`x`. If :attr:`y` is None, perform a self-query\n    over :attr:`x`.\n\n    This function allows multiple point sets with different capacity. 
The points\n    from different sets are stored contiguously in the :attr:`x` and :attr:`y` tensor.\n    :attr:`x_segs` and :attr:`y_segs` specifies the number of points in each point set.\n\n    Parameters\n    ----------\n    k : int\n        The number of nearest neighbors per node.\n    x : Tensor\n        The point coordinates in x. It can be either on CPU or GPU (must be the\n        same as :attr:`y`). Must be 2D.\n    x_segs : Union[List[int], Tensor]\n        Number of points in each point set in :attr:`x`. The numbers in :attr:`x_segs`\n        must sum up to the number of rows in :attr:`x`.\n    y : Tensor, optional\n        The point coordinates in y. It can be either on CPU or GPU (must be the\n        same as :attr:`x`). Must be 2D.\n        (default: None)\n    y_segs : Union[List[int], Tensor], optional\n        Number of points in each point set in :attr:`y`. The numbers in :attr:`y_segs`\n        must sum up to the number of rows in :attr:`y`.\n        (default: None)\n    algorithm : str, optional\n        Algorithm used to compute the k-nearest neighbors.\n\n        * 'bruteforce' will compute distances pair by pair and\n          directly select the k-nearest neighbors during distance\n          computation. This method is slower than 'bruteforce-blas'\n          but has less memory overhead (i.e., :math:`O(Nk)` where :math:`N`\n          is the number of points, :math:`k` is the number of nearest\n          neighbors per node) since we do not need to store all distances.\n\n        * 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'\n          but use shared memory in CUDA devices for buffer. This method is\n          faster than 'bruteforce' when the dimension of input points\n          is not large. This method is only available on CUDA device.\n\n        * 'kd-tree' will use the kd-tree algorithm (CPU only).\n          This method is suitable for low-dimensional data (e.g. 
3D\n          point clouds)\n\n        * 'nn-descent' is an approximate approach from paper\n          `Efficient k-nearest neighbor graph construction for generic similarity\n          measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method\n          will search for nearest neighbor candidates in \"neighbors' neighbors\".\n\n        Note: Currently, 'nn-descent' only supports self-query cases, i.e. :attr:`y` is None.\n        (default: 'bruteforce')\n    dist : str, optional\n        The distance metric used to compute distance between points. It can be the following\n        metrics:\n        * 'euclidean': Use Euclidean distance (L2 norm) :math:`\\sqrt{\\sum_{i} (x_{i} - y_{i})^{2}}`.\n        * 'cosine': Use cosine distance.\n        (default: 'euclidean')\n\n    Returns\n    -------\n    Tensor\n        Tensor with size `(2, k * num_points(y))`\n        The first subtensor contains point indices in :attr:`y`. The second subtensor contains\n        point indices in :attr:`x`\n    \"\"\"\n    # TODO(lygztq) add support for querying different point sets using nn-descent.\n    if algorithm == \"nn-descent\":\n        if y is not None or y_segs is not None:\n            raise DGLError(\n                \"Currently 'nn-descent' only supports self-query cases.\"\n            )\n        return _nndescent_knn_graph(x, k, x_segs, dist=dist)\n\n    # self query\n    if y is None:\n        y = x\n        y_segs = x_segs\n\n    assert F.context(x) == F.context(y)\n    if isinstance(x_segs, (tuple, list)):\n        x_segs = F.tensor(x_segs)\n    if isinstance(y_segs, (tuple, list)):\n        y_segs = F.tensor(y_segs)\n    x_segs = F.copy_to(x_segs, F.context(x))\n    y_segs = F.copy_to(y_segs, F.context(y))\n\n    # k should be less than or equal to min(x_segs)\n    min_num_points = F.min(x_segs, dim=0)\n    if k > min_num_points:\n        dgl_warning(\n            \"'k' should be less than or equal to the number of points in 'x'\"\n            \"expect 
k <= {0}, got k = {1}, use k = {0}\".format(\n                min_num_points, k\n            )\n        )\n        k = F.as_scalar(min_num_points)\n\n    # invalid k\n    if k <= 0:\n        raise DGLError(\"Invalid k value. expect k > 0, got k = {}\".format(k))\n\n    # empty point set\n    if F.shape(x)[0] == 0 or F.shape(y)[0] == 0:\n        raise DGLError(\"Find empty point set\")\n\n    dist = dist.lower()\n    dist_metric_list = [\"euclidean\", \"cosine\"]\n    if dist not in dist_metric_list:\n        raise DGLError(\n            \"Only {} are supported for distance\"\n            \"computation, got {}\".format(dist_metric_list, dist)\n        )\n\n    x_offset = F.zeros(\n        (F.shape(x_segs)[0] + 1,), F.dtype(x_segs), F.context(x_segs)\n    )\n    x_offset[1:] = F.cumsum(x_segs, dim=0)\n    y_offset = F.zeros(\n        (F.shape(y_segs)[0] + 1,), F.dtype(y_segs), F.context(y_segs)\n    )\n    y_offset[1:] = F.cumsum(y_segs, dim=0)\n\n    out = F.zeros((2, F.shape(y)[0] * k), F.dtype(x_segs), F.context(x_segs))\n\n    # if use cosine distance, normalize input points first\n    # thus we can use euclidean distance to find knn equivalently.\n    if dist == \"cosine\":\n        l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True))\n        x = x / (l2_norm(x) + 1e-5)\n        y = y / (l2_norm(y) + 1e-5)\n\n    _CAPI_DGLKNN(\n        F.to_dgl_nd(x),\n        F.to_dgl_nd(x_offset),\n        F.to_dgl_nd(y),\n        F.to_dgl_nd(y_offset),\n        k,\n        F.zerocopy_to_dgl_ndarray_for_write(out),\n        algorithm,\n    )\n    return out\n\n\ndef to_bidirected(g, copy_ndata=False, readonly=None):\n    r\"\"\"Convert the graph to a bi-directional simple graph and return.\n\n    For an input graph :math:`G`, return a new graph :math:`G'` such that an edge\n    :math:`(u, v)\\in G'` exists if and only if there exists an edge\n    :math:`(v, u)\\in G`. 
The resulting graph :math:`G'` is a simple graph,\n    meaning there is no parallel edge.\n\n    The operation only works for edges whose two endpoints belong to the same node type.\n    DGL will raise error if the input graph is heterogeneous and contains edges\n    with different types of endpoints.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    copy_ndata: bool, optional\n        If True, the node features of the bidirected graph are copied from the\n        original graph. If False, the bidirected graph will not have any node features.\n        (Default: False)\n    readonly : bool\n        **DEPRECATED**.\n\n    Returns\n    -------\n    DGLGraph\n        The bidirected graph\n\n    Notes\n    -----\n    If :attr:`copy_ndata` is True, the resulting graph will share the node feature\n    tensors with the input graph. Hence, users should try to avoid in-place operations\n    which will be visible to both graphs.\n\n    This function discards the batch information. Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    The following examples use PyTorch backend.\n\n    >>> import dgl\n    >>> import torch as th\n    >>> g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))\n    >>> bg1 = dgl.to_bidirected(g)\n    >>> bg1.edges()\n    (tensor([0, 1, 2, 1, 2, 0]), tensor([1, 2, 0, 0, 1, 2]))\n\n    The graph already have i->j and j->i\n\n    >>> g = dgl.graph((th.tensor([0, 1, 2, 0]), th.tensor([1, 2, 0, 2])))\n    >>> bg1 = dgl.to_bidirected(g)\n    >>> bg1.edges()\n    (tensor([0, 1, 2, 1, 2, 0]), tensor([1, 2, 0, 0, 1, 2]))\n\n    **Heterogeneous graphs with Multiple Edge Types**\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'wins', 'user'): (th.tensor([0, 2, 0, 2]), th.tensor([1, 1, 2, 0])),\n    ...     
('user', 'follows', 'user'): (th.tensor([1, 2, 1]), th.tensor([2, 1, 1]))\n    ... })\n    >>> bg1 = dgl.to_bidirected(g)\n    >>> bg1.edges(etype='wins')\n    (tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 2, 0, 1]))\n    >>> bg1.edges(etype='follows')\n    (tensor([1, 1, 2]), tensor([1, 2, 1]))\n    \"\"\"\n    if readonly is not None:\n        dgl_warning(\n            \"Parameter readonly is deprecated\"\n            \"There will be no difference between readonly and non-readonly DGLGraph\"\n        )\n\n    for c_etype in g.canonical_etypes:\n        if c_etype[0] != c_etype[2]:\n            assert False, (\n                \"to_bidirected is not well defined for \"\n                \"unidirectional bipartite graphs\"\n                \", but {} is unidirectional bipartite\".format(c_etype)\n            )\n\n    g = add_reverse_edges(g, copy_ndata=copy_ndata, copy_edata=False)\n    g = to_simple(\n        g, return_counts=None, copy_ndata=copy_ndata, copy_edata=False\n    )\n    return g\n\n\ndef add_reverse_edges(\n    g,\n    readonly=None,\n    copy_ndata=True,\n    copy_edata=False,\n    ignore_bipartite=False,\n    exclude_self=True,\n):\n    r\"\"\"Add a reversed edge for each edge in the input graph and return a new graph.\n\n    For a graph with edges :math:`(i_1, j_1), \\cdots, (i_n, j_n)`, this\n    function creates a new graph with edges\n    :math:`(i_1, j_1), \\cdots, (i_n, j_n), (j_1, i_1), \\cdots, (j_n, i_n)`.\n\n    The returned graph may have duplicate edges. To create a bidirected graph without\n    duplicate edges, use :func:`to_bidirected`.\n\n    The operation only works for edges whose two endpoints belong to the same node type.\n    DGL will raise error if the input graph is heterogeneous and contains edges\n    with different types of endpoints. 
If :attr:`ignore_bipartite` is true, DGL will\n    ignore those edges instead.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    readonly : bool, default to be True\n        Deprecated. There will be no difference between readonly and non-readonly\n    copy_ndata: bool, optional\n        If True, the node features of the new graph are copied from\n        the original graph. If False, the new graph will not have any\n        node features.\n\n        (Default: True)\n    copy_edata: bool, optional\n        If True, the features of the reversed edges will be identical to\n        the original ones.\n\n        If False, the new graph will not have any edge features.\n\n        (Default: False)\n    ignore_bipartite: bool, optional\n        If True, unidirectional bipartite graphs are ignored and\n        no error is raised. If False, an error will be raised if\n        an edge type of the input heterogeneous graph is for a unidirectional\n        bipartite graph.\n    exclude_self: bool, optional\n        If True, it does not add reverse edges for self-loops, which is likely\n        meaningless in most cases.\n\n    Returns\n    -------\n    DGLGraph\n        The graph with reversed edges added.\n\n    Notes\n    -----\n    If :attr:`copy_ndata` is True, the resulting graph will share the node feature\n    tensors with the input graph. Hence, users should try to avoid in-place operations\n    which will be visible to both graphs. On the contrary, the two graphs do not share\n    the same edge feature storage.\n\n    This function discards the batch information. 
Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    **Homogeneous graphs**\n\n    >>> g = dgl.graph((th.tensor([0, 0]), th.tensor([0, 1])))\n    >>> bg1 = dgl.add_reverse_edges(g)\n    >>> bg1.edges()\n    (tensor([0, 0, 0, 1]), tensor([0, 1, 0, 0]))\n\n    **Heterogeneous graphs**\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'wins', 'user'): (th.tensor([0, 2, 0, 2, 2]), th.tensor([1, 1, 2, 1, 0])),\n    ...     ('user', 'plays', 'game'): (th.tensor([1, 2, 1]), th.tensor([2, 1, 1])),\n    ...     ('user', 'follows', 'user'): (th.tensor([1, 2, 1]), th.tensor([0, 0, 0]))\n    ... })\n    >>> g.nodes['game'].data['hv'] = th.ones(3, 1)\n    >>> g.edges['wins'].data['h'] = th.tensor([0, 1, 2, 3, 4])\n\n    The :func:`add_reverse_edges` operation is applied to the edge type\n    ``('user', 'wins', 'user')`` and the edge type ``('user', 'follows', 'user')``.\n    The edge type ``('user', 'plays', 'game')`` is ignored.  
Both the node features and\n    edge features are shared.\n\n    >>> bg = dgl.add_reverse_edges(g, copy_ndata=True,\n                               copy_edata=True, ignore_bipartite=True)\n    >>> bg.edges(('user', 'wins', 'user'))\n    (tensor([0, 2, 0, 2, 2, 1, 1, 2, 1, 0]), tensor([1, 1, 2, 1, 0, 0, 2, 0, 2, 2]))\n    >>> bg.edges(('user', 'follows', 'user'))\n    (tensor([1, 2, 1, 0, 0, 0]), tensor([0, 0, 0, 1, 2, 1]))\n    >>> bg.edges(('user', 'plays', 'game'))\n    (th.tensor([1, 2, 1]), th.tensor([2, 1, 1]))\n    >>> bg.nodes['game'].data['hv']\n    tensor([0, 0, 0])\n    >>> bg.edges[('user', 'wins', 'user')].data['h']\n    th.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])\n    \"\"\"\n    if readonly is not None:\n        dgl_warning(\n            \"Parameter readonly is deprecated\"\n            \"There will be no difference between readonly and non-readonly DGLGraph\"\n        )\n\n    # get node cnt for each ntype\n    num_nodes_dict = {}\n    for ntype in g.ntypes:\n        num_nodes_dict[ntype] = g.num_nodes(ntype)\n\n    canonical_etypes = g.canonical_etypes\n    num_nodes_dict = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}\n    subgs = {}\n    rev_eids = {}\n\n    def add_for_etype(etype):\n        u, v = g.edges(form=\"uv\", order=\"eid\", etype=etype)\n        rev_u, rev_v = v, u\n        eid = F.copy_to(F.arange(0, g.num_edges(etype)), g.device)\n        if exclude_self:\n            self_loop_mask = F.equal(rev_u, rev_v)\n            non_self_loop_mask = F.logical_not(self_loop_mask)\n            rev_u = F.boolean_mask(rev_u, non_self_loop_mask)\n            rev_v = F.boolean_mask(rev_v, non_self_loop_mask)\n            non_self_loop_eid = F.boolean_mask(eid, non_self_loop_mask)\n            rev_eids[etype] = F.cat([eid, non_self_loop_eid], 0)\n        else:\n            rev_eids[etype] = F.cat([eid, eid], 0)\n        subgs[etype] = (F.cat([u, rev_u], dim=0), F.cat([v, rev_v], dim=0))\n\n    # fast path\n    if ignore_bipartite is False:\n        
for c_etype in canonical_etypes:\n            if c_etype[0] != c_etype[2]:\n                assert False, (\n                    \"add_reverse_edges is not well defined for \"\n                    \"unidirectional bipartite graphs\"\n                    \", but {} is unidirectional bipartite\".format(c_etype)\n                )\n            add_for_etype(c_etype)\n\n        new_g = convert.heterograph(subgs, num_nodes_dict=num_nodes_dict)\n    else:\n        for c_etype in canonical_etypes:\n            if c_etype[0] != c_etype[2]:\n                u, v = g.edges(form=\"uv\", order=\"eid\", etype=c_etype)\n                subgs[c_etype] = (u, v)\n            else:\n                add_for_etype(c_etype)\n\n        new_g = convert.heterograph(subgs, num_nodes_dict=num_nodes_dict)\n\n    # handle features\n    if copy_ndata:\n        node_frames = utils.extract_node_subframes(g, None)\n        utils.set_new_frames(new_g, node_frames=node_frames)\n\n    if copy_edata:\n        # find indices\n        eids = []\n        for c_etype in canonical_etypes:\n            if c_etype[0] != c_etype[2]:\n                eids.append(\n                    F.copy_to(F.arange(0, g.num_edges(c_etype)), new_g.device)\n                )\n            else:\n                eids.append(rev_eids[c_etype])\n\n        edge_frames = utils.extract_edge_subframes(g, eids)\n        utils.set_new_frames(new_g, edge_frames=edge_frames)\n\n    return new_g\n\n\ndef line_graph(g, backtracking=True, shared=False):\n    \"\"\"Return the line graph of this graph.\n\n    The line graph ``L(G)`` of a given graph ``G`` is defined as another graph where\n    the nodes in ``L(G)`` correspond to the edges in ``G``.  For any pair of edges ``(u, v)``\n    and ``(v, w)`` in ``G``, the corresponding node of edge ``(u, v)`` in ``L(G)`` will\n    have an edge connecting to the corresponding node of edge ``(v, w)``.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        Input graph.  
Must be homogeneous.\n    backtracking : bool, optional\n        If False, the line graph node corresponding to edge ``(u, v)`` will not have\n        an edge connecting to the line graph node corresponding to edge ``(v, u)``.\n\n        Default: True.\n    shared : bool, optional\n        Whether to copy the edge features of the original graph as the node features\n        of the result line graph.\n\n    Returns\n    -------\n    G : DGLGraph\n        The line graph of this graph.\n\n    Notes\n    -----\n    * If :attr:`shared` is True, the node features of the resulting graph share the same\n      storage with the edge features of the input graph. Hence, users should try to\n      avoid in-place operations which will be visible to both graphs.\n    * The function supports input graph on GPU but copies it to CPU during computation.\n    * This function discards the batch information. Please use\n      :func:`dgl.DGLGraph.set_batch_num_nodes`\n      and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n      to maintain the information.\n\n    Examples\n    --------\n    Assume that the graph has the following adjacency matrix: ::\n\n       A = [[0, 0, 1],\n            [1, 0, 1],\n            [1, 1, 0]]\n\n    >>> g = dgl.graph(([0, 1, 1, 2, 2],[2, 0, 2, 0, 1]), 'user', 'follows')\n    >>> lg = g.line_graph()\n    >>> lg\n    Graph(num_nodes=5, num_edges=8,\n    ndata_schemes={}\n    edata_schemes={})\n    >>> lg.edges()\n    (tensor([0, 0, 1, 2, 2, 3, 4, 4]), tensor([3, 4, 0, 3, 4, 0, 1, 2]))\n    >>> lg = g.line_graph(backtracking=False)\n    >>> lg\n    Graph(num_nodes=5, num_edges=4,\n    ndata_schemes={}\n    edata_schemes={})\n    >>> lg.edges()\n    (tensor([0, 1, 2, 4]), tensor([4, 0, 3, 1]))\n    \"\"\"\n    assert g.is_homogeneous, \"only homogeneous graph is supported\"\n\n    dev = g.device\n    lg = DGLGraph(\n        _CAPI_DGLHeteroLineGraph(g._graph.copy_to(nd.cpu()), backtracking)\n    )\n    lg = lg.to(dev)\n    if shared:\n      
  new_frames = utils.extract_edge_subframes(g, None)\n        utils.set_new_frames(lg, node_frames=new_frames)\n\n    return lg\n\n\nDGLGraph.line_graph = utils.alias_func(line_graph)\n\n\ndef khop_adj(g, k):\n    \"\"\"Return the matrix of :math:`A^k` where :math:`A` is the adjacency matrix of the graph\n    :math:`g`.\n\n    The returned matrix is a 32-bit float dense matrix on CPU. The graph must be homogeneous.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    k : int\n        The :math:`k` in :math:`A^k`.\n\n    Returns\n    -------\n    Tensor\n        The returned tensor.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> g = dgl.graph(([0,1,2,3,4,0,1,2,3,4], [0,1,2,3,4,1,2,3,4,0]))\n    >>> dgl.khop_adj(g, 1)\n    tensor([[1., 1., 0., 0., 0.],\n            [0., 1., 1., 0., 0.],\n            [0., 0., 1., 1., 0.],\n            [0., 0., 0., 1., 1.],\n            [1., 0., 0., 0., 1.]])\n    >>> dgl.khop_adj(g, 3)\n    tensor([[1., 3., 3., 1., 0.],\n            [0., 1., 3., 3., 1.],\n            [1., 0., 1., 3., 3.],\n            [3., 1., 0., 1., 3.],\n            [3., 3., 1., 0., 1.]])\n    \"\"\"\n    assert g.is_homogeneous, \"only homogeneous graph is supported\"\n    adj_k = (\n        g.adj_external(transpose=False, scipy_fmt=g.formats()[\"created\"][0])\n        ** k\n    )\n    return F.tensor(adj_k.todense().astype(np.float32))\n\n\ndef khop_graph(g, k, copy_ndata=True):\n    \"\"\"Return the graph whose edges connect the :attr:`k`-hop neighbors of the original graph.\n\n    More specifically, an edge from node ``u`` and node ``v`` exists in the new graph if\n    and only if a path with length :attr:`k` exists from node ``u`` to node ``v`` in the\n    original graph.\n\n    The adjacency matrix of the returned graph is :math:`A^k`\n    (where :math:`A` is the adjacency matrix of :math:`g`).\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    k : int\n        The :math:`k` in `k`-hop 
graph.\n    copy_ndata: bool, optional\n        If True, the node features of the new graph are copied from the\n        original graph.\n\n        If False, the new graph will not have any node features.\n\n        (Default: True)\n\n    Returns\n    -------\n    DGLGraph\n        The returned graph.\n\n    Notes\n    -----\n    If :attr:`copy_ndata` is True, the resulting graph will share the node feature\n    tensors with the input graph. Hence, users should try to avoid in-place operations\n    which will be visible to both graphs.\n\n    This function discards the batch information. Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n\n    Below gives an easy example:\n\n    >>> import dgl\n    >>> g = dgl.graph(([0, 1], [1, 2]))\n    >>> g_2 = dgl.transforms.khop_graph(g, 2)\n    >>> print(g_2.edges())\n    (tensor([0]), tensor([2]))\n\n    A more complicated example:\n\n    >>> import dgl\n    >>> g = dgl.graph(([0,1,2,3,4,0,1,2,3,4], [0,1,2,3,4,1,2,3,4,0]))\n    >>> dgl.khop_graph(g, 1)\n    DGLGraph(num_nodes=5, num_edges=10,\n             ndata_schemes={}\n             edata_schemes={})\n    >>> dgl.khop_graph(g, 3)\n    DGLGraph(num_nodes=5, num_edges=40,\n             ndata_schemes={}\n             edata_schemes={})\n    \"\"\"\n    assert g.is_homogeneous, \"only homogeneous graph is supported\"\n    n = g.num_nodes()\n    adj_k = (\n        g.adj_external(transpose=False, scipy_fmt=g.formats()[\"created\"][0])\n        ** k\n    )\n    adj_k = adj_k.tocoo()\n    multiplicity = adj_k.data\n    row = np.repeat(adj_k.row, multiplicity)\n    col = np.repeat(adj_k.col, multiplicity)\n    # TODO(zihao): we should support creating multi-graph from scipy sparse matrix\n    # in the future.\n    new_g = convert.graph(\n        (row, col), num_nodes=n, idtype=g.idtype, device=g.device\n    )\n\n    # handle ndata\n    
if copy_ndata:\n        node_frames = utils.extract_node_subframes(g, None)\n        utils.set_new_frames(new_g, node_frames=node_frames)\n\n    return new_g\n\n\ndef reverse(\n    g, copy_ndata=True, copy_edata=False, *, share_ndata=None, share_edata=None\n):\n    r\"\"\"Return a new graph with every edges being the reverse ones in the input graph.\n\n    The reverse (also called converse, transpose) of a graph with edges\n    :math:`(i_1, j_1), (i_2, j_2), \\cdots` of type ``(U, E, V)`` is a new graph with edges\n    :math:`(j_1, i_1), (j_2, i_2), \\cdots` of type ``(V, E, U)``.\n\n    The returned graph shares the data structure with the original graph, i.e. dgl.reverse\n    will not create extra storage for the reversed graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    copy_ndata: bool, optional\n        If True, the node features of the reversed graph are copied from the\n        original graph. If False, the reversed graph will not have any node features.\n        (Default: True)\n    copy_edata: bool, optional\n        If True, the edge features of the reversed graph are copied from the\n        original graph. If False, the reversed graph will not have any edge features.\n        (Default: False)\n\n    Return\n    ------\n    DGLGraph\n        The reversed graph.\n\n    Notes\n    -----\n    If :attr:`copy_ndata` or :attr:`copy_edata` is True,\n    the resulting graph will share the node or edge feature\n    tensors with the input graph. Hence, users should try to avoid in-place operations\n    which will be visible to both graphs.\n\n    This function discards the batch information. 
Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    **Homogeneous graphs**\n\n    Create a graph to reverse.\n\n    >>> import dgl\n    >>> import torch as th\n    >>> g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))\n    >>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]])\n    >>> g.edata['h'] = th.tensor([[3.], [4.], [5.]])\n\n    Reverse the graph.\n\n    >>> rg = dgl.reverse(g, copy_edata=True)\n    >>> rg.ndata['h']\n    tensor([[0.],\n            [1.],\n            [2.]])\n\n    The i-th edge in the reversed graph corresponds to the i-th edge in the\n    original graph. When :attr:`copy_edata` is True, they have the same features.\n\n    >>> rg.edges()\n    (tensor([1, 2, 0]), tensor([0, 1, 2]))\n    >>> rg.edata['h']\n    tensor([[3.],\n            [4.],\n            [5.]])\n\n    **Heterogeneous graphs**\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'follows', 'user'): (th.tensor([0, 2]), th.tensor([1, 2])),\n    ...     ('user', 'plays', 'game'): (th.tensor([1, 2, 1]), th.tensor([2, 1, 1]))\n    ... 
})\n    >>> g.nodes['game'].data['hv'] = th.ones(3, 1)\n    >>> g.edges['plays'].data['he'] = th.zeros(3, 1)\n\n    The resulting graph will have edge types\n    ``('user', 'follows', 'user')`` and ``('game', 'plays', 'user')``.\n\n    >>> rg = dgl.reverse(g, copy_ndata=True)\n    >>> rg\n    Graph(num_nodes={'game': 3, 'user': 3},\n          num_edges={('user', 'follows', 'user'): 2, ('game', 'plays', 'user'): 3},\n          metagraph=[('user', 'user'), ('game', 'user')])\n    >>> rg.edges(etype='follows')\n    (tensor([1, 2]), tensor([0, 2]))\n    >>> rg.edges(etype='plays')\n    (tensor([2, 1, 1]), tensor([1, 2, 1]))\n    >>> rg.nodes['game'].data['hv']\n    tensor([[1.],\n            [1.],\n            [1.]])\n    >>> rg.edges['plays'].data\n    {}\n    \"\"\"\n    if share_ndata is not None:\n        dgl_warning(\"share_ndata argument has been renamed to copy_ndata.\")\n        copy_ndata = share_ndata\n    if share_edata is not None:\n        dgl_warning(\"share_edata argument has been renamed to copy_edata.\")\n        copy_edata = share_edata\n    if g.is_block:\n        # TODO(0.5 release, xiangsx) need to handle BLOCK\n        # currently reversing a block results in undefined behavior\n        raise DGLError(\"Reversing a block graph is not supported.\")\n    gidx = g._graph.reverse()\n    new_g = DGLGraph(gidx, g.ntypes, g.etypes)\n\n    # handle ndata\n    if copy_ndata:\n        # for each ntype\n        for ntype in g.ntypes:\n            new_g.nodes[ntype].data.update(g.nodes[ntype].data)\n\n    # handle edata\n    if copy_edata:\n        # for each etype\n        for utype, etype, vtype in g.canonical_etypes:\n            new_g.edges[vtype, etype, utype].data.update(\n                g.edges[utype, etype, vtype].data\n            )\n\n    return new_g\n\n\nDGLGraph.reverse = utils.alias_func(reverse)\n\n\ndef to_simple_graph(g):\n    \"\"\"Convert the graph to a simple graph with no multi-edge.\n\n    DEPRECATED: renamed to dgl.to_simple\n\n    
Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n\n    Returns\n    -------\n    DGLGraph\n        A simple graph.\n\n    Notes\n    -----\n\n    This function discards the batch information. Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n    \"\"\"\n    dgl_warning(\"dgl.to_simple_graph is renamed to dgl.to_simple in v0.5.\")\n    return to_simple(g)\n\n\ndef laplacian_lambda_max(g):\n    \"\"\"Return the largest eigenvalue of the normalized symmetric Laplacian of a graph.\n\n    If the graph is batched from multiple graphs, return the list of the largest eigenvalue\n    for each graph instead.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph, it must be a bi-directed homogeneous graph, i.e., every edge\n        should have an accompanied reverse edge in the graph.\n        The graph can be batched from multiple graphs.\n\n    Returns\n    -------\n    list[float]\n        A list where the i-th item indicates the largest eigenvalue\n        of i-th graph in :attr:`g`.\n\n        In the case where the function takes a single graph, it will return a list\n        consisting of a single element.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> g = dgl.graph(([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], [1, 2, 3, 4, 0, 4, 0, 1, 2, 3]))\n    >>> dgl.laplacian_lambda_max(g)\n    [1.809016994374948]\n    \"\"\"\n    g_arr = batch.unbatch(g)\n    rst = []\n    for g_i in g_arr:\n        n = g_i.num_nodes()\n        adj = g_i.adj_external(\n            transpose=True, scipy_fmt=g_i.formats()[\"created\"][0]\n        ).astype(float)\n        norm = sparse.diags(\n            F.asnumpy(g_i.in_degrees()).clip(1) ** -0.5, dtype=float\n        )\n        laplacian = sparse.eye(n) - norm * adj * norm\n        rst.append(\n            scipy.sparse.linalg.eigs(\n                laplacian, 1, which=\"LM\", 
return_eigenvectors=False\n            )[0].real\n        )\n    return rst\n\n\ndef metapath_reachable_graph(g, metapath):\n    \"\"\"Return a graph where the successors of any node ``u`` are nodes reachable from ``u`` by\n    the given metapath.\n\n    If the beginning node type ``s`` and ending node type ``t`` are the same, it will return\n    a homogeneous graph with node type ``s = t``.  Otherwise, a unidirectional bipartite graph\n    with source node type ``s`` and destination node type ``t`` is returned.\n\n    In both cases, two nodes ``u`` and ``v`` will be connected with an edge ``(u, v)`` if\n    there exists one path matching the metapath from ``u`` to ``v``.\n\n    The result graph keeps the node set of type ``s`` and ``t`` in the original graph even if\n    they might have no neighbor.\n\n    The features of the source/destination node type in the original graph would be copied to\n    the new graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph\n    metapath : list[str or tuple of str]\n        Metapath in the form of a list of edge types\n\n    Returns\n    -------\n    DGLGraph\n        A homogeneous or unidirectional bipartite graph. It will be on CPU regardless of\n        whether the input graph is on CPU or GPU.\n\n    Notes\n    -----\n\n    This function discards the batch information. Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    >>> g = dgl.heterograph({\n    ...     ('A', 'AB', 'B'): ([0, 1, 2], [1, 2, 3]),\n    ...     
('B', 'BA', 'A'): ([1, 2, 3], [0, 1, 2])})\n    >>> new_g = dgl.metapath_reachable_graph(g, ['AB', 'BA'])\n    >>> new_g.edges(order='eid')\n    (tensor([0, 1, 2]), tensor([0, 1, 2]))\n    \"\"\"\n    adj = 1\n    for etype in metapath:\n        adj = adj * g.adj_external(\n            etype=etype, scipy_fmt=\"csr\", transpose=False\n        )\n\n    adj = (adj != 0).tocsr()\n    srctype = g.to_canonical_etype(metapath[0])[0]\n    dsttype = g.to_canonical_etype(metapath[-1])[2]\n    new_g = convert.heterograph(\n        {(srctype, \"_E\", dsttype): adj.nonzero()},\n        {srctype: adj.shape[0], dsttype: adj.shape[1]},\n        idtype=g.idtype,\n        device=g.device,\n    )\n\n    # copy srcnode features\n    new_g.nodes[srctype].data.update(g.nodes[srctype].data)\n    # copy dstnode features\n    if srctype != dsttype:\n        new_g.nodes[dsttype].data.update(g.nodes[dsttype].data)\n\n    return new_g\n\n\ndef add_nodes(g, num, data=None, ntype=None):\n    r\"\"\"Add the given number of nodes to the graph and return a new graph.\n\n    The new nodes will have IDs starting from ``g.num_nodes(ntype)``.\n\n    Parameters\n    ----------\n    num : int\n        The number of nodes to add.\n    data : dict[str, Tensor], optional\n        Feature data of the added nodes. The keys are feature names\n        while the values are feature data.\n    ntype : str, optional\n        The node type name. Can be omitted if there is\n        only one type of nodes in the graph.\n\n    Return\n    ------\n    DGLGraph\n        The graph with newly added nodes.\n\n    Notes\n    -----\n    * For features in :attr:`g` but not in :attr:`data`,\n      DGL assigns zero features for the newly added nodes.\n    * For feature in :attr:`data` but not in :attr:`g`, DGL assigns zero features\n      for the existing nodes in the graph.\n    * This function discards the batch information. 
Please use\n      :func:`dgl.DGLGraph.set_batch_num_nodes`\n      and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n      to maintain the information.\n\n    Examples\n    --------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    **Homogeneous Graphs**\n\n    >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n    >>> g.num_nodes()\n    3\n    >>> g = dgl.add_nodes(g, 2)\n    >>> g.num_nodes()\n    5\n\n    If the graph has some node features and new nodes are added without\n    features, their features will be filled with zeros.\n\n    >>> g.ndata['h'] = torch.ones(5, 1)\n    >>> g = dgl.add_nodes(g, 1)\n    >>> g.ndata['h']\n    tensor([[1.], [1.], [1.], [1.], [1.], [0.]])\n\n    Assign features for the new nodes.\n\n    >>> g = dgl.add_nodes(g, 1, {'h': torch.ones(1, 1), 'w': torch.ones(1, 1)})\n    >>> g.ndata['h']\n    tensor([[1.], [1.], [1.], [1.], [1.], [0.], [1.]])\n\n    Since :attr:`data` contains new feature fields, the features for existing nodes\n    will be filled with zeros.\n\n    >>> g.ndata['w']\n    tensor([[0.], [0.], [0.], [0.], [0.], [0.], [1.]])\n\n    **Heterogeneous Graphs**\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n    ...                                 torch.tensor([0, 0, 1, 1])),\n    ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n    ...                                         torch.tensor([0, 1]))\n    ...     
})\n    >>> g.num_nodes('user')\n    3\n    >>> g = dgl.add_nodes(g, 2, ntype='user')\n    >>> g.num_nodes('user')\n    5\n\n    See Also\n    --------\n    remove_nodes\n    add_edges\n    remove_edges\n    \"\"\"\n    g = g.clone()\n    g.add_nodes(num, data=data, ntype=ntype)\n    return g\n\n\ndef add_edges(g, u, v, data=None, etype=None):\n    r\"\"\"Add the edges to the graph and return a new graph.\n\n    The i-th new edge will be from ``u[i]`` to ``v[i]``.  The IDs of the new\n    edges will start from ``g.num_edges(etype)``.\n\n    Parameters\n    ----------\n    u : int, Tensor or iterable[int]\n        Source node IDs, ``u[i]`` gives the source node for the i-th new edge.\n    v : int, Tensor or iterable[int]\n        Destination node IDs, ``v[i]`` gives the destination node for the i-th new edge.\n    data : dict[str, Tensor], optional\n        Feature data of the added edges. The keys are feature names\n        while the values are feature data.\n    etype : str or (str, str, str), optional\n        The type names of the edges. The allowed type name formats are:\n\n        * ``(str, str, str)`` for source node type, edge type and destination node type.\n        * or one ``str`` edge type name if the name can uniquely identify a\n          triplet format in the graph.\n\n        Can be omitted if the graph has only one type of edges.\n\n    Return\n    ------\n    DGLGraph\n        The graph with newly added edges.\n\n    Notes\n    -----\n    * If the end nodes of the given edges do not exist in :attr:`g`,\n      :func:`dgl.add_nodes` is invoked to add those nodes.\n      The node features of the new nodes will be filled with zeros.\n    * For features in :attr:`g` but not in :attr:`data`,\n      DGL assigns zero features for the newly added edges.\n    * For features in :attr:`data` but not in :attr:`g`, DGL assigns zero features\n      for the existing edges in the graph.\n    * This function discards the batch information. 
Please use\n      :func:`dgl.DGLGraph.set_batch_num_nodes`\n      and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n      to maintain the information.\n\n    Examples\n    --------\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    **Homogeneous Graphs**\n\n    >>> g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))\n    >>> g.num_edges()\n    2\n    >>> g = dgl.add_edges(g, torch.tensor([1, 3]), torch.tensor([0, 1]))\n    >>> g.num_edges()\n    4\n\n    Since ``u`` or ``v`` contains a non-existing node ID, the nodes are\n    added implicitly.\n\n    >>> g.num_nodes()\n    4\n\n    If the graph has some edge features and new edges are added without\n    features, their features will be filled with zeros.\n\n    >>> g.edata['h'] = torch.ones(4, 1)\n    >>> g = dgl.add_edges(g, torch.tensor([1]), torch.tensor([1]))\n    >>> g.edata['h']\n    tensor([[1.], [1.], [1.], [1.], [0.]])\n\n    You can also assign features for the new edges in adding new edges.\n\n    >>> g = dgl.add_edges(g, torch.tensor([0, 0]), torch.tensor([2, 2]),\n    ...                   {'h': torch.tensor([[1.], [2.]]), 'w': torch.ones(2, 1)})\n    >>> g.edata['h']\n    tensor([[1.], [1.], [1.], [1.], [0.], [1.], [2.]])\n\n    Since :attr:`data` contains new feature fields, the features for old edges\n    will be filled with zeros.\n\n    >>> g.edata['w']\n    tensor([[0.], [0.], [0.], [0.], [0.], [1.], [1.]])\n\n    **Heterogeneous Graphs**\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n    ...                                 torch.tensor([0, 0, 1, 1])),\n    ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n    ...                                         torch.tensor([0, 1]))\n    ...     
})\n    >>> g.num_edges('plays')\n    4\n    >>> g = dgl.add_edges(g, torch.tensor([3]), torch.tensor([3]), etype='plays')\n    >>> g.num_edges('plays')\n    5\n\n    See Also\n    --------\n    add_nodes\n    remove_nodes\n    remove_edges\n    \"\"\"\n    g = g.clone()\n    g.add_edges(u, v, data=data, etype=etype)\n    return g\n\n\ndef remove_edges(g, eids, etype=None, store_ids=False):\n    r\"\"\"Remove the specified edges and return a new graph.\n\n    Also delete the features of the edges. The edges must exist in the graph.\n    The resulting graph has the same number of nodes as the input one,\n    even if some nodes become isolated after the edge removal.\n\n    Parameters\n    ----------\n    eids : int, Tensor, iterable[int]\n        The IDs of the edges to remove.\n    etype : str or (str, str, str), optional\n        The type names of the edges. The allowed type name formats are:\n\n        * ``(str, str, str)`` for source node type, edge type and destination node type.\n        * or one ``str`` edge type name if the name can uniquely identify a\n          triplet format in the graph.\n\n        Can be omitted if the graph has only one type of edges.\n    store_ids : bool, optional\n        If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``\n        and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,\n        respectively.\n\n    Return\n    ------\n    DGLGraph\n        The graph with edges deleted.\n\n    Notes\n    -----\n    This function preserves the batch information.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n\n    **Homogeneous Graphs**\n\n    >>> g = dgl.graph((torch.tensor([0, 0, 2]), torch.tensor([0, 1, 2])))\n    >>> g.edata['he'] = torch.arange(3).float().reshape(-1, 1)\n    >>> g = dgl.remove_edges(g, torch.tensor([0, 1]))\n    >>> g\n    Graph(num_nodes=3, num_edges=1,\n        ndata_schemes={}\n        edata_schemes={'he': 
Scheme(shape=(1,), dtype=torch.float32)})\n    >>> g.edges('all')\n    (tensor([2]), tensor([2]), tensor([0]))\n    >>> g.edata['he']\n    tensor([[2.]])\n\n    **Heterogeneous Graphs**\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n    ...                                 torch.tensor([0, 0, 1, 1])),\n    ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n    ...                                         torch.tensor([0, 1]))\n    ...     })\n    >>> g = dgl.remove_edges(g, torch.tensor([0, 1]), 'plays')\n    >>> g.edges('all', etype='plays')\n    (tensor([1, 2]), tensor([1, 1]), tensor([0, 1]))\n\n    See Also\n    --------\n    add_nodes\n    add_edges\n    remove_nodes\n    \"\"\"\n    g = g.clone()\n    g.remove_edges(eids, etype=etype, store_ids=store_ids)\n    return g\n\n\ndef remove_nodes(g, nids, ntype=None, store_ids=False):\n    r\"\"\"Remove the specified nodes and return a new graph.\n\n    Also delete the features. Edges that connect from/to the nodes will be\n    removed as well. After the removal, DGL re-labels the remaining nodes and edges\n    with IDs from 0.\n\n    Parameters\n    ----------\n    nids : int, Tensor, iterable[int]\n        The nodes to be removed.\n    ntype : str, optional\n        The type of the nodes to remove. Can be omitted if there is\n        only one node type in the graph.\n    store_ids : bool, optional\n        If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata``\n        and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``,\n        respectively.\n\n    Return\n    ------\n    DGLGraph\n        The graph with nodes deleted.\n\n    Notes\n    -----\n\n    This function discards the batch information. 
Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> import torch\n\n    **Homogeneous Graphs**\n\n    >>> g = dgl.graph((torch.tensor([0, 0, 2]), torch.tensor([0, 1, 2])))\n    >>> g.ndata['hv'] = torch.arange(3).float().reshape(-1, 1)\n    >>> g.edata['he'] = torch.arange(3).float().reshape(-1, 1)\n    >>> g = dgl.remove_nodes(g, torch.tensor([0, 1]))\n    >>> g\n    Graph(num_nodes=1, num_edges=1,\n        ndata_schemes={'hv': Scheme(shape=(1,), dtype=torch.float32)}\n        edata_schemes={'he': Scheme(shape=(1,), dtype=torch.float32)})\n    >>> g.ndata['hv']\n    tensor([[2.]])\n    >>> g.edata['he']\n    tensor([[2.]])\n\n    **Heterogeneous Graphs**\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),\n    ...                                 torch.tensor([0, 0, 1, 1])),\n    ...     ('developer', 'develops', 'game'): (torch.tensor([0, 1]),\n    ...                                         torch.tensor([0, 1]))\n    ...     })\n    >>> g = dgl.remove_nodes(g, torch.tensor([0, 1]), ntype='game')\n    >>> g.num_nodes('user')\n    3\n    >>> g.num_nodes('game')\n    0\n    >>> g.num_edges('plays')\n    0\n\n    See Also\n    --------\n    add_nodes\n    add_edges\n    remove_edges\n    \"\"\"\n    g = g.clone()\n    g.remove_nodes(nids, ntype=ntype, store_ids=store_ids)\n    return g\n\n\ndef add_self_loop(g, edge_feat_names=None, fill_data=1.0, etype=None):\n    r\"\"\"Add self-loops for each node in the graph and return a new graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    edge_feat_names : list[str], optional\n        The names of the self-loop features to apply `fill_data`. If None, it will apply `fill_data`\n        to all self-loop features. 
Default: None.\n    fill_data : int, float or str, optional\n        The value to fill the self-loop features. Default: 1.\n\n        * If ``fill_data`` is ``int`` or ``float``, self-loop features will be directly given by\n          ``fill_data``.\n        * if ``fill_data`` is ``str``, self-loop features will be generated by aggregating the\n          features of the incoming edges of the corresponding nodes. The supported aggregation are:\n          ``'mean'``, ``'sum'``, ``'max'``, ``'min'``.\n    etype : str or (str, str, str), optional\n        The type names of the edges. The allowed type name formats are:\n\n        * ``(str, str, str)`` for source node type, edge type and destination node type.\n        * or one ``str`` edge type name if the name can uniquely identify a\n          triplet format in the graph.\n\n        Can be omitted if the graph has only one type of edges.\n\n    Return\n    ------\n    DGLGraph\n        The graph with self-loops.\n\n    Notes\n    -----\n    * The function only supports homogeneous graphs or heterogeneous graphs but\n      the relation graph specified by the :attr:`etype` argument is homogeneous.\n    * The function adds self-loops regardless of whether they already exist or not.\n      If one wishes to have exactly one self-loop for every node,\n      call :func:`remove_self_loop` before invoking :func:`add_self_loop`.\n    * This function discards the batch information. 
Please use\n      :func:`dgl.DGLGraph.set_batch_num_nodes`\n      and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n      to maintain the information.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n\n    **Homogeneous Graphs**\n\n    >>> g = dgl.graph((torch.tensor([0, 0, 2]), torch.tensor([2, 1, 0])))\n    >>> g.ndata['hv'] = torch.arange(3).float().reshape(-1, 1)\n    >>> g.edata['he'] = torch.arange(3).float().reshape(-1, 1)\n    >>> g = dgl.add_self_loop(g, fill_data='sum')\n    >>> g\n    Graph(num_nodes=3, num_edges=6,\n        ndata_schemes={'hv': Scheme(shape=(1,), dtype=torch.float32)}\n        edata_schemes={'he': Scheme(shape=(1,), dtype=torch.float32)})\n    >>> g.edata['he']\n    tensor([[0.],\n            [1.],\n            [2.],\n            [2.],\n            [1.],\n            [0.]])\n\n    **Heterogeneous Graphs**\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'follows', 'user'): (torch.tensor([1, 2]),\n    ...                                   torch.tensor([0, 1])),\n    ...     ('user', 'plays', 'game'): (torch.tensor([0, 1]),\n    ...                                 
torch.tensor([0, 1]))})\n    >>> g = dgl.add_self_loop(g, etype='follows')\n    >>> g\n    Graph(num_nodes={'user': 3, 'game': 2},\n          num_edges={('user', 'plays', 'game'): 2, ('user', 'follows', 'user'): 5},\n          metagraph=[('user', 'user'), ('user', 'game')])\n    \"\"\"\n    etype = g.to_canonical_etype(etype)\n    data = {}\n    reduce_funcs = {\n        \"sum\": function.sum,\n        \"mean\": function.mean,\n        \"max\": function.max,\n        \"min\": function.min,\n    }\n\n    if edge_feat_names is None:\n        edge_feat_names = g.edges[etype].data.keys()\n\n    if etype[0] != etype[2]:\n        raise DGLError(\n            \"add_self_loop does not support unidirectional bipartite graphs: {}.\"\n            \"Please make sure the types of head node and tail node are identical.\"\n            \"\".format(etype)\n        )\n\n    for feat_name in edge_feat_names:\n        if isinstance(fill_data, (int, float)):\n            dtype = g.edges[etype].data[feat_name].dtype\n            dshape = g.edges[etype].data[feat_name].shape\n            tmp_fill_data = F.copy_to(\n                F.astype(F.tensor([fill_data]), dtype), g.device\n            )\n            if len(dshape) > 1:\n                data[feat_name] = (\n                    F.zeros(\n                        (g.num_nodes(etype[0]), *dshape[1:]), dtype, g.device\n                    )\n                    + tmp_fill_data\n                )\n            else:\n                data[feat_name] = (\n                    F.zeros((g.num_nodes(etype[0]),), dtype, g.device)\n                    + tmp_fill_data\n                )\n\n        elif isinstance(fill_data, str):\n            if fill_data not in reduce_funcs.keys():\n                raise DGLError(\"Unsupported aggregation: {}\".format(fill_data))\n            reducer = reduce_funcs[fill_data]\n            with g.local_scope():\n                g.update_all(\n                    function.copy_e(feat_name, \"h\"),\n                 
   reducer(\"h\", \"h\"),\n                    etype=etype,\n                )\n                data[feat_name] = g.nodes[etype[0]].data[\"h\"]\n\n    nodes = g.nodes(etype[0])\n    if len(data):\n        new_g = add_edges(g, nodes, nodes, data=data, etype=etype)\n    else:\n        new_g = add_edges(g, nodes, nodes, etype=etype)\n    return new_g\n\n\nDGLGraph.add_self_loop = utils.alias_func(add_self_loop)\n\n\ndef remove_self_loop(g, etype=None):\n    r\"\"\"Remove self-loops for each node in the graph and return a new graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    etype : str or (str, str, str), optional\n        The type names of the edges. The allowed type name formats are:\n\n        * ``(str, str, str)`` for source node type, edge type and destination node type.\n        * or one ``str`` edge type name if the name can uniquely identify a\n          triplet format in the graph.\n\n        Can be omitted if the graph has only one type of edges.\n\n    Notes\n    -----\n    If a node has multiple self-loops, remove them all. Do nothing for nodes without\n    self-loops.\n\n    This function preserves the batch information.\n\n    Examples\n    ---------\n\n    >>> import dgl\n    >>> import torch\n\n    **Homogeneous Graphs**\n\n    >>> g = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([1, 0, 0, 2])))\n    >>> g.edata['he'] = torch.arange(4).float().reshape(-1, 1)\n    >>> g = dgl.remove_self_loop(g)\n    >>> g\n    Graph(num_nodes=3, num_edges=2,\n        edata_schemes={'he': Scheme(shape=(2,), dtype=torch.float32)})\n    >>> g.edata['he']\n    tensor([[0.],[3.]])\n\n    **Heterogeneous Graphs**\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'follows', 'user'): (torch.tensor([0, 1, 1, 1, 2]),\n    ...                                   torch.tensor([0, 0, 1, 1, 1])),\n    ...     ('user', 'plays', 'game'): (torch.tensor([0, 1]),\n    ...                                 torch.tensor([0, 1]))\n    ...     
})\n    >>> g = dgl.remove_self_loop(g, etype='follows')\n    >>> g.num_nodes('user')\n    3\n    >>> g.num_nodes('game')\n    2\n    >>> g.num_edges('follows')\n    2\n    >>> g.num_edges('plays')\n    2\n\n    See Also\n    --------\n    add_self_loop\n    \"\"\"\n    etype = g.to_canonical_etype(etype)\n    if etype[0] != etype[2]:\n        raise DGLError(\n            \"remove_self_loop does not support unidirectional bipartite graphs: {}.\"\n            \"Please make sure the types of head node and tail node are identical.\"\n            \"\".format(etype)\n        )\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=etype)\n    self_loop_eids = F.tensor(F.nonzero_1d(u == v), dtype=F.dtype(u))\n    new_g = remove_edges(g, self_loop_eids, etype=etype)\n    return new_g\n\n\nDGLGraph.remove_self_loop = utils.alias_func(remove_self_loop)\n\n\ndef compact_graphs(\n    graphs, always_preserve=None, copy_ndata=True, copy_edata=True\n):\n    \"\"\"Given a list of graphs with the same set of nodes, find and eliminate the common\n    isolated nodes across all graphs.\n\n    This function requires the graphs to have the same set of nodes (i.e. the node types\n    must be the same, and the number of nodes of each node type must be the same).  
The\n    metagraph does not have to be the same.\n\n    It finds all the nodes that have zero in-degree and zero out-degree in all the given\n    graphs, and eliminates them from all the graphs.\n\n    Useful for graph sampling where you have a giant graph but you only wish to perform\n    message passing on a smaller graph with a (tiny) subset of nodes.\n\n    Parameters\n    ----------\n    graphs : DGLGraph or list[DGLGraph]\n        The graph, or list of graphs.\n\n        All graphs must be on the same devices.\n\n        All graphs must have the same set of nodes.\n    always_preserve : Tensor or dict[str, Tensor], optional\n        If a dict of node types and node ID tensors is given, the nodes of given\n        node types would not be removed, regardless of whether they are isolated.\n\n        If a Tensor is given, DGL assumes that all the graphs have one (same) node type.\n    copy_ndata: bool, optional\n        If True, the node features of the returned graphs are copied from the\n        original graphs.\n\n        If False, the returned graphs will not have any node features.\n\n        (Default: True)\n    copy_edata: bool, optional\n        If True, the edge features of the returned graphs are copied from the\n        original graphs.\n\n        If False, the returned graphs will not have any edge features.\n\n        (Default: True)\n\n    Returns\n    -------\n    DGLGraph or list[DGLGraph]\n        The compacted graph or list of compacted graphs.\n\n        Each returned graph would have a feature ``dgl.NID`` containing the mapping\n        of node IDs for each type from the compacted graph(s) to the original graph(s).\n        Note that the mapping is the same for all the compacted graphs.\n\n        All the returned graphs are on CPU.\n\n    Notes\n    -----\n    This function currently requires that the same node type of all graphs should have\n    the same node type ID, i.e. 
the node types are *ordered* the same.\n\n    If :attr:`copy_edata` is True, the resulting graph will share the edge feature\n    tensors with the input graph. Hence, users should try to avoid in-place operations\n    which will be visible to both graphs.\n\n    This function discards the batch information. Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    The following code constructs a bipartite graph with 20 users and 10 games, but\n    only user #1 and #3, as well as game #3 and #5, have connections:\n\n    >>> g = dgl.heterograph({('user', 'plays', 'game'): ([1, 3], [3, 5])},\n    >>>                      {'user': 20, 'game': 10})\n\n    The following would compact the graph above to another bipartite graph with only\n    two users and two games.\n\n    >>> new_g = dgl.compact_graphs(g)\n    >>> new_g.ndata[dgl.NID]\n    {'user': tensor([1, 3]), 'game': tensor([3, 5])}\n\n    The mapping tells us that only user #1 and #3 as well as game #3 and #5 are kept.\n    Furthermore, the first user and second user in the compacted graph maps to\n    user #1 and #3 in the original graph.  Games are similar.\n\n    One can verify that the edge connections are kept the same in the compacted graph.\n\n    >>> new_g.edges(form='all', order='eid', etype='plays')\n    (tensor([0, 1]), tensor([0, 1]), tensor([0, 1]))\n\n    When compacting multiple graphs, nodes that do not have any connections in any\n    of the given graphs are removed.  
So if you compact ``g`` and the following ``g2``\n    graphs together:\n\n    >>> g2 = dgl.heterograph({('user', 'plays', 'game'): ([1, 6], [6, 8])},\n    >>>                      {'user': 20, 'game': 10})\n    >>> new_g, new_g2 = dgl.compact_graphs([g, g2])\n    >>> new_g.ndata[dgl.NID]\n    {'user': tensor([1, 3, 6]), 'game': tensor([3, 5, 6, 8])}\n\n    Then one can see that user #1 from both graphs, users #3 from the first graph, as\n    well as user #6 from the second graph, are kept.  Games are similar.\n\n    Similarly, one can also verify the connections:\n\n    >>> new_g.edges(form='all', order='eid', etype='plays')\n    (tensor([0, 1]), tensor([0, 1]), tensor([0, 1]))\n    >>> new_g2.edges(form='all', order='eid', etype='plays')\n    (tensor([0, 2]), tensor([2, 3]), tensor([0, 1]))\n    \"\"\"\n    return_single = False\n    if not isinstance(graphs, Iterable):\n        graphs = [graphs]\n        return_single = True\n    if len(graphs) == 0:\n        return []\n    if graphs[0].is_block:\n        raise DGLError(\"Compacting a block graph is not allowed.\")\n\n    # Ensure the node types are ordered the same.\n    # TODO(BarclayII): we ideally need to remove this constraint.\n    ntypes = graphs[0].ntypes\n    idtype = graphs[0].idtype\n    device = graphs[0].device\n    for g in graphs:\n        assert ntypes == g.ntypes, (\n            \"All graphs should have the same node types in the same order, got %s and %s\"\n            % ntypes,\n            g.ntypes,\n        )\n        assert (\n            idtype == g.idtype\n        ), \"Expect graph data type to be {}, but got {}\".format(\n            idtype, g.idtype\n        )\n        assert device == g.device, (\n            \"All graphs must be on the same devices.\"\n            \"Expect graph device to be {}, but got {}\".format(device, g.device)\n        )\n\n    # Process the dictionary or tensor of \"always preserve\" nodes\n    if always_preserve is None:\n        always_preserve = {}\n    elif 
not isinstance(always_preserve, Mapping):\n        if len(ntypes) > 1:\n            raise ValueError(\n                \"Node type must be given if multiple node types exist.\"\n            )\n        always_preserve = {ntypes[0]: always_preserve}\n\n    always_preserve = utils.prepare_tensor_dict(\n        graphs[0], always_preserve, \"always_preserve\"\n    )\n    always_preserve_nd = []\n    for ntype in ntypes:\n        nodes = always_preserve.get(ntype, None)\n        if nodes is None:\n            nodes = F.copy_to(F.tensor([], idtype), device)\n        always_preserve_nd.append(F.to_dgl_nd(nodes))\n\n    # Compact and construct heterographs\n    new_graph_indexes, induced_nodes = _CAPI_DGLCompactGraphs(\n        [g._graph for g in graphs], always_preserve_nd\n    )\n    induced_nodes = [F.from_dgl_nd(nodes) for nodes in induced_nodes]\n\n    new_graphs = [\n        DGLGraph(new_graph_index, graph.ntypes, graph.etypes)\n        for new_graph_index, graph in zip(new_graph_indexes, graphs)\n    ]\n\n    if copy_ndata:\n        for g, new_g in zip(graphs, new_graphs):\n            node_frames = utils.extract_node_subframes(g, induced_nodes)\n            utils.set_new_frames(new_g, node_frames=node_frames)\n    if copy_edata:\n        for g, new_g in zip(graphs, new_graphs):\n            edge_frames = utils.extract_edge_subframes(g, None)\n            utils.set_new_frames(new_g, edge_frames=edge_frames)\n\n    if return_single:\n        new_graphs = new_graphs[0]\n\n    return new_graphs\n\n\ndef _coalesce_edge_frame(g, edge_maps, counts, aggregator):\n    r\"\"\"Coalesce edge features of duplicate edges via given aggregator in g.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    edge_maps : List[Tensor]\n        The edge mapping corresponding to each edge type in g.\n    counts : List[Tensor]\n        The number of duplicated edges from the original graph for each edge type.\n    aggregator : str\n        Indicates how to 
coalesce edge features, could be ``arbitrary``, ``sum``\n        or ``mean``.\n\n    Returns\n    -------\n    List[Frame]\n        The frames corresponding to each edge type.\n    \"\"\"\n    if aggregator == \"arbitrary\":\n        eids = []\n        for i in range(len(g.canonical_etypes)):\n            feat_idx = F.asnumpy(edge_maps[i])\n            _, indices = np.unique(feat_idx, return_index=True)\n            eids.append(F.zerocopy_from_numpy(indices))\n\n        edge_frames = utils.extract_edge_subframes(g, eids)\n    elif aggregator in [\"sum\", \"mean\"]:\n        edge_frames = []\n        for i in range(len(g.canonical_etypes)):\n            feat_idx = edge_maps[i]\n            _, indices = np.unique(F.asnumpy(feat_idx), return_index=True)\n            _num_rows = len(indices)\n            _data = {}\n            for key, col in g._edge_frames[i]._columns.items():\n                data = col.data\n                new_data = F.scatter_add(data, feat_idx, _num_rows)\n                if aggregator == \"mean\":\n                    norm = F.astype(counts[i], F.dtype(data))\n                    norm = F.reshape(\n                        norm, (F.shape(norm)[0],) + (1,) * (F.ndim(data) - 1)\n                    )\n                    new_data /= norm\n                _data[key] = new_data\n\n            newf = Frame(data=_data, num_rows=_num_rows)\n            edge_frames.append(newf)\n    else:\n        raise DGLError(\n            \"Aggregator {} not regonized, cannot coalesce edge feature in the \"\n            \"specified way\".format(aggregator)\n        )\n    return edge_frames\n\n\ndef to_simple(\n    g,\n    return_counts=\"count\",\n    writeback_mapping=False,\n    copy_ndata=True,\n    copy_edata=False,\n    aggregator=\"arbitrary\",\n):\n    r\"\"\"Convert a graph to a simple graph without parallel edges and return.\n\n    For a heterogeneous graph with multiple edge types, DGL treats edges with the same\n    edge type and endpoints as parallel 
edges and removes them.\n    Optionally, one can get the number of parallel edges by specifying the\n    :attr:`return_counts` argument. To get a mapping from the edge IDs in the\n    input graph to the edge IDs in the resulting graph, set :attr:`writeback_mapping`\n    to true.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.  Must be on CPU.\n    return_counts : str, optional\n        If given, the count of each edge in the original graph\n        will be stored as edge features under the name\n        ``return_counts``.  The old features with the same name will be replaced.\n\n        (Default: \"count\")\n    writeback_mapping: bool, optional\n        If True, return an extra write-back mapping for each edge\n        type. The write-back mapping is a tensor recording\n        the mapping from the edge IDs in the input graph to\n        the edge IDs in the result graph. If the graph is\n        heterogeneous, DGL returns a dictionary of edge types and such\n        tensors.\n\n        If False, only the simple graph is returned.\n\n        (Default: False)\n    copy_ndata: bool, optional\n        If True, the node features of the simple graph are copied\n        from the original graph.\n\n        If False, the simple graph will not have any node features.\n\n        (Default: True)\n    copy_edata: bool, optional\n        If True, the edge features of the simple graph are copied\n        from the original graph. 
If there exists duplicate edges between\n        two nodes (u, v), the feature of the edge is the aggregation\n        of edge feature of duplicate edges.\n\n        If False, the simple graph will not have any edge features.\n\n        (Default: False)\n    aggregator: str, optional\n        Indicate how to coalesce edge feature of duplicate edges.\n        If ``arbitrary``, select one of the duplicate edges' feature.\n        If ``sum``, compute the summation of duplicate edges' feature.\n        If ``mean``, compute the average of duplicate edges' feature.\n\n        (Default: ``arbitrary``)\n\n    Returns\n    -------\n    DGLGraph\n        The graph.\n    tensor or dict of tensor\n        The writeback mapping. Only when ``writeback_mapping`` is True.\n\n    Notes\n    -----\n    If :attr:`copy_ndata` is True, the resulting graph will share the node feature\n    tensors with the input graph. Hence, users should try to avoid in-place operations\n    which will be visible to both graphs.\n\n    This function discards the batch information. Please use\n    :func:`dgl.DGLGraph.set_batch_num_nodes`\n    and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph\n    to maintain the information.\n\n    Examples\n    --------\n    **Homogeneous Graphs**\n\n    Create a graph for demonstrating to_simple API.\n    In the original graph, there are multiple edges between 1 and 2.\n\n    >>> import dgl\n    >>> import torch as th\n    >>> g = dgl.graph((th.tensor([0, 1, 2, 1]), th.tensor([1, 2, 0, 2])))\n    >>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]])\n    >>> g.edata['h'] = th.tensor([[3.], [4.], [5.], [6.]])\n\n    Convert the graph to a simple graph. 
The return counts is\n    stored in the edge feature 'cnt' and the writeback mapping\n    is returned in a tensor.\n\n    >>> sg, wm = dgl.to_simple(g, return_counts='cnt', writeback_mapping=True)\n    >>> sg.ndata['h']\n    tensor([[0.],\n            [1.],\n            [2.]])\n    >>> u, v, eid = sg.edges(form='all')\n    >>> u\n    tensor([0, 1, 2])\n    >>> v\n    tensor([1, 2, 0])\n    >>> eid\n    tensor([0, 1, 2])\n    >>> sg.edata['cnt']\n    tensor([1, 2, 1])\n    >>> wm\n    tensor([0, 1, 2, 1])\n    >>> 'h' in sg.edata\n    False\n\n    **Heterogeneous Graphs**\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'wins', 'user'): (th.tensor([0, 2, 0, 2, 2]), th.tensor([1, 1, 2, 1, 0])),\n    ...     ('user', 'plays', 'game'): (th.tensor([1, 2, 1]), th.tensor([2, 1, 1]))\n    ... })\n    >>> g.nodes['game'].data['hv'] = th.ones(3, 1)\n    >>> g.edges['plays'].data['he'] = th.zeros(3, 1)\n\n    The return counts is stored in the default edge feature 'count' for each edge type.\n\n    >>> sg, wm = dgl.to_simple(g, copy_ndata=False, writeback_mapping=True)\n    >>> sg\n    Graph(num_nodes={'game': 3, 'user': 3},\n          num_edges={('user', 'wins', 'user'): 4, ('game', 'plays', 'user'): 3},\n          metagraph=[('user', 'user'), ('game', 'user')])\n    >>> sg.edges(etype='wins')\n    (tensor([0, 2, 0, 2]), tensor([1, 1, 2, 0]))\n    >>> wm[('user', 'wins', 'user')]\n    tensor([0, 1, 2, 1, 3])\n    >>> sg.edges(etype='plays')\n    (tensor([2, 1, 1]), tensor([1, 2, 1]))\n    >>> wm[('user', 'plays', 'game')]\n    tensor([0, 1, 2])\n    >>> 'hv' in sg.nodes['game'].data\n    False\n    >>> 'he' in sg.edges['plays'].data\n    False\n    >>> sg.edata['count']\n    {('user', 'wins', 'user'): tensor([1, 2, 1, 1])\n     ('user', 'plays', 'game'): tensor([1, 1, 1])}\n    \"\"\"\n    assert g.device == F.cpu(), \"the graph must be on CPU\"\n    if g.is_block:\n        raise DGLError(\"Cannot convert a block graph to a simple graph.\")\n    simple_graph_index, 
counts, edge_maps = _CAPI_DGLToSimpleHetero(g._graph)\n    simple_graph = DGLGraph(simple_graph_index, g.ntypes, g.etypes)\n    counts = [F.from_dgl_nd(count) for count in counts]\n    edge_maps = [F.from_dgl_nd(edge_map) for edge_map in edge_maps]\n\n    if copy_ndata:\n        node_frames = utils.extract_node_subframes(g, None)\n        utils.set_new_frames(simple_graph, node_frames=node_frames)\n    if copy_edata:\n        new_edge_frames = _coalesce_edge_frame(g, edge_maps, counts, aggregator)\n        utils.set_new_frames(simple_graph, edge_frames=new_edge_frames)\n\n    if return_counts is not None:\n        for count, canonical_etype in zip(counts, g.canonical_etypes):\n            simple_graph.edges[canonical_etype].data[return_counts] = count\n\n    if writeback_mapping:\n        # single edge type\n        if len(edge_maps) == 1:\n            return simple_graph, edge_maps[0]\n        # multiple edge type\n        else:\n            wb_map = {}\n            for edge_map, canonical_etype in zip(edge_maps, g.canonical_etypes):\n                wb_map[canonical_etype] = edge_map\n            return simple_graph, wb_map\n\n    return simple_graph\n\n\nDGLGraph.to_simple = utils.alias_func(to_simple)\n\n\ndef _unitgraph_less_than_int32(g):\n    \"\"\"Check if a graph with only one edge type has at most 2 ** 31 - 1\n    nodes and edges.\n    \"\"\"\n    num_edges = g.num_edges()\n    num_nodes = max(g.num_nodes(g.ntypes[0]), g.num_nodes(g.ntypes[-1]))\n    return max(num_nodes, num_edges) <= (1 << 31) - 1\n\n\ndef adj_product_graph(A, B, weight_name, etype=\"_E\"):\n    r\"\"\"Create a weighted graph whose adjacency matrix is the product of\n    the adjacency matrices of the given two graphs.\n\n    Namely, given two weighted graphs :attr:`A` and :attr:`B`, whose rows\n    represent source nodes and columns represent destination nodes, this function\n    returns a new graph whose weighted adjacency matrix is\n    :math:`\\mathrm{adj}(A) \\times 
\\mathrm{adj}(B)`.\n\n    The two graphs must be simple graphs, and must have only one edge type.\n    Moreover, the number of nodes of the destination node type of :attr:`A` must\n    be the same as the number of nodes of the source node type of :attr:`B`.\n\n    The source node type of the returned graph will be the same as the source\n    node type of graph :attr:`A`.  The destination node type of the returned\n    graph will be the same as the destination node type of graph :attr:`B`.\n    If the two node types are the same, the returned graph will be homogeneous.\n    Otherwise, it will be a bipartite graph.\n\n    Unlike ``scipy``, if an edge in the result graph has zero weight, it will\n    not be removed from the graph.\n\n    Notes\n    -----\n    This function works on both CPU and GPU.  For GPU, the number of nodes and\n    edges must be less than the maximum of ``int32`` (i.e. ``2 ** 31 - 1``) due\n    to restriction of cuSPARSE.\n\n    The edge weights returned by this function is differentiable w.r.t. the\n    input edge weights.\n\n    If the graph format is restricted, both graphs must have CSR available.\n\n    Parameters\n    ----------\n    A : DGLGraph\n        The graph as left operand.\n    B : DGLGraph\n        The graph as right operand.\n    weight_name : str\n        The feature name of edge weight of both graphs.\n\n        The corresponding edge feature must be scalar.\n    etype : str, optional\n        The edge type of the returned graph.\n\n    Returns\n    -------\n    DGLGraph\n        The new graph.  The edge weight of the returned graph will have the\n        same feature name as :attr:`weight_name`.\n\n    Examples\n    --------\n    The following shows weighted adjacency matrix multiplication between two\n    bipartite graphs.  
You can also perform this between two homogeneous\n    graphs, or one homogeneous graph and one bipartite graph, as long as the\n    numbers of nodes of the same type match.\n\n    >>> A = dgl.heterograph({\n    ...     ('A', 'AB', 'B'): ([2, 2, 0, 2, 0, 1], [2, 1, 0, 0, 2, 2])},\n    ...     num_nodes_dict={'A': 3, 'B': 4})\n    >>> B = dgl.heterograph({\n    ...     ('B', 'BA', 'A'): ([0, 3, 2, 1, 3, 3], [1, 2, 0, 2, 1, 0])},\n    ...     num_nodes_dict={'A': 3, 'B': 4})\n\n    If your graph is a multigraph, you will need to call :func:`dgl.to_simple`\n    to convert it into a simple graph first.\n\n    >>> A = dgl.to_simple(A)\n    >>> B = dgl.to_simple(B)\n\n    Initialize learnable edge weights.\n\n    >>> A.edata['w'] = torch.randn(6).requires_grad_()\n    >>> B.edata['w'] = torch.randn(6).requires_grad_()\n\n    Take the product.\n\n    >>> C = dgl.adj_product_graph(A, B, 'w')\n    >>> C.edges()\n    (tensor([0, 0, 1, 2, 2, 2]), tensor([0, 1, 0, 0, 2, 1]))\n\n    >>> C.edata['w']\n    tensor([0.6906, 0.2002, 0.0591, 0.3672, 0.1066, 0.1328],\n           grad_fn=<CSRMMBackward>)\n\n    Note that this function is differentiable:\n\n    >>> C.edata['w'].sum().backward()\n    >>> A.edata['w'].grad\n    tensor([0.7153, 0.2775, 0.7141, 0.7141, 0.7153, 0.7153])\n\n    >>> B.edata['w'].grad\n    tensor([0.4664, 0.0000, 1.5614, 0.3840, 0.0000, 0.0000])\n\n    If the source node type of the left operand is the same as the destination\n    node type of the right operand, this function returns a homogeneous graph:\n\n    >>> C.ntypes\n    ['A']\n\n    Otherwise, it returns a bipartite graph instead:\n\n    >>> A = dgl.heterograph({\n    ...     ('A', 'AB', 'B'): ([2, 2, 0, 2, 0, 1], [2, 1, 0, 0, 2, 2])},\n    ...     num_nodes_dict={'A': 3, 'B': 4})\n    >>> B = dgl.heterograph({\n    ...     ('B', 'BC', 'C'): ([0, 3, 2, 1, 3, 3], [1, 2, 0, 2, 1, 0])},\n    ...     
num_nodes_dict={'C': 3, 'B': 4})\n    >>> A.edata['w'] = torch.randn(6).requires_grad_()\n    >>> B.edata['w'] = torch.randn(6).requires_grad_()\n    >>> C = dgl.adj_product_graph(A, B, 'w')\n    >>> C.ntypes\n    ['A', 'C']\n    \"\"\"\n    srctype, _, _ = A.canonical_etypes[0]\n    _, _, dsttype = B.canonical_etypes[0]\n    num_vtypes = 1 if srctype == dsttype else 2\n    ntypes = [srctype] if num_vtypes == 1 else [srctype, dsttype]\n\n    if A.device != F.cpu():\n        if not (\n            _unitgraph_less_than_int32(A) and _unitgraph_less_than_int32(B)\n        ):\n            raise ValueError(\n                \"For GPU graphs the number of nodes and edges must be less than 2 ** 31 - 1.\"\n            )\n\n    C_gidx, C_weights = F.csrmm(\n        A._graph,\n        A.edata[weight_name],\n        B._graph,\n        B.edata[weight_name],\n        num_vtypes,\n    )\n    num_nodes_dict = {\n        srctype: A.num_nodes(srctype),\n        dsttype: B.num_nodes(dsttype),\n    }\n    C_metagraph, ntypes, etypes, _ = create_metagraph_index(\n        ntypes, [(srctype, etype, dsttype)]\n    )\n    num_nodes_per_type = [num_nodes_dict[ntype] for ntype in ntypes]\n    C_gidx = create_heterograph_from_relations(\n        C_metagraph, [C_gidx], utils.toindex(num_nodes_per_type)\n    )\n\n    C = DGLGraph(C_gidx, ntypes, etypes)\n    C.edata[weight_name] = C_weights\n    return C\n\n\ndef adj_sum_graph(graphs, weight_name):\n    r\"\"\"Create a weighted graph whose adjacency matrix is the sum of the\n    adjacency matrices of the given graphs, whose rows represent source nodes\n    and columns represent destination nodes.\n\n    All the graphs must be simple graphs, and must have only one edge type.\n    They also must have the same metagraph, i.e. have the same source node type\n    and the same destination node type.  
Moreover, the number of nodes for every\n    graph must also be the same.\n\n    The metagraph of the returned graph will be the same as the input graphs.\n\n    Unlike ``scipy``, if an edge in the result graph has zero weight, it will\n    not be removed from the graph.\n\n    Notes\n    -----\n    This function works on both CPU and GPU.  For GPU, the number of nodes and\n    edges must be less than the maximum of ``int32`` (i.e. ``2 ** 31 - 1``) due\n    to restriction of cuSPARSE.\n\n    The edge weights returned by this function is differentiable w.r.t. the\n    input edge weights.\n\n    If the graph format is restricted, both graphs must have CSR available.\n\n    Parameters\n    ----------\n    graphs : list[DGLGraph]\n        The list of graphs.  Must have at least one element.\n    weight_name : str\n        The feature name of edge weight of both graphs.\n\n        The corresponding edge feature must be scalar.\n\n    Returns\n    -------\n    DGLGraph\n        The new graph.  The edge weight of the returned graph will have the\n        same feature name as :attr:`weight_name`.\n\n    Examples\n    --------\n    The following shows weighted adjacency matrix summation between two\n    bipartite graphs.  You can also perform this between homogeneous graphs.\n\n    >>> A = dgl.heterograph(\n    ...     {('A', 'AB', 'B'): ([2, 2, 0, 2, 0, 1], [2, 1, 0, 0, 2, 2])},\n    ...     num_nodes_dict={'A': 3, 'B': 4})\n    >>> B = dgl.heterograph(\n    ...     {('A', 'AB', 'B'): ([1, 2, 0, 2, 1, 0], [0, 3, 2, 1, 3, 3])},\n    ...     
num_nodes_dict={'A': 3, 'B': 4})\n    >>> A.edata['w'] = torch.randn(6).requires_grad_()\n    >>> B.edata['w'] = torch.randn(6).requires_grad_()\n\n    If your graph is a multigraph, call :func:`dgl.to_simple`\n    to convert it into a simple graph first.\n\n    >>> A = dgl.to_simple(A)\n    >>> B = dgl.to_simple(B)\n\n    Initialize learnable edge weights.\n\n    >>> A.edata['w'] = torch.randn(6).requires_grad_()\n    >>> B.edata['w'] = torch.randn(6).requires_grad_()\n\n    Take the sum.\n\n    >>> C = dgl.adj_sum_graph([A, B], 'w')\n    >>> C.edges()\n    (tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 2]),\n     tensor([0, 2, 3, 2, 0, 3, 0, 1, 2, 3]))\n\n    Note that this function is differentiable:\n\n    >>> C.edata['w'].sum().backward()\n    >>> A.edata['w'].grad\n    tensor([1., 1., 1., 1., 1., 1.])\n\n    >>> B.edata['w'].grad\n    tensor([1., 1., 1., 1., 1., 1.])\n    \"\"\"\n    if len(graphs) == 0:\n        raise ValueError(\"The list of graphs must not be empty.\")\n\n    if graphs[0].device != F.cpu():\n        if not all(_unitgraph_less_than_int32(A) for A in graphs):\n            raise ValueError(\n                \"For GPU graphs the number of nodes and edges must be less than 2 ** 31 - 1.\"\n            )\n    metagraph = graphs[0]._graph.metagraph\n    num_nodes = utils.toindex(\n        [\n            graphs[0]._graph.num_nodes(i)\n            for i in range(graphs[0]._graph.number_of_ntypes())\n        ]\n    )\n    weights = [A.edata[weight_name] for A in graphs]\n    gidxs = [A._graph for A in graphs]\n    C_gidx, C_weights = F.csrsum(gidxs, weights)\n    C_gidx = create_heterograph_from_relations(metagraph, [C_gidx], num_nodes)\n\n    C = DGLGraph(C_gidx, graphs[0].ntypes, graphs[0].etypes)\n    C.edata[weight_name] = C_weights\n    return C\n\n\ndef sort_csr_by_tag(g, tag, tag_offset_name=\"_TAG_OFFSET\", tag_type=\"node\"):\n    r\"\"\"Return a new graph whose CSR matrix is sorted by the given tag.\n\n    Sort the internal CSR matrix of the graph so 
that the adjacency list of each node,\n    which contains the out-edges, is sorted by the tag of the out-neighbors.\n    After sorting, edges sharing the same tag will be arranged in a consecutive range in\n    a node's adjacency list. Following is an example:\n\n        Consider a graph as follows::\n\n            0 -> 0, 1, 2, 3, 4\n            1 -> 0, 1, 2\n\n        Given node tags ``[1, 1, 0, 2, 0]``, each node's adjacency list\n        will be sorted as follows::\n\n            0 -> 2, 4, 0, 1, 3\n            1 -> 2, 0, 1\n\n        Given edge tags ``[1, 1, 0, 2, 0, 1, 1, 0]`` has the same effect\n        as above node tags.\n\n    The function will also return the starting offsets of the tag\n    segments in a tensor of shape :math:`(N, max\\_tag+2)`. For node ``i``,\n    its out-edges connecting to node tag ``j`` are stored between\n    ``tag_offsets[i][j]`` ~ ``tag_offsets[i][j+1]``. Since the offsets\n    can be viewed as node data, we store them in the\n    ``ndata`` of the returned graph. Users can specify the\n    ndata name by the :attr:`tag_offset_name` argument.\n\n    Note that the function will not change the edge IDs nor\n    how the edge features are stored. The input graph must\n    allow CSR format. The graph must be on CPU.\n\n    If the input graph is heterogeneous, it must have only one edge\n    type and two node types (i.e., source and destination node types).\n    In this case, the provided node tags are for the destination nodes,\n    and the tag offsets are stored in the source node data.\n\n    The sorted graph and the calculated tag offsets are needed by\n    certain operators that consider node tags. 
See\n    :func:`~dgl.sampling.sample_neighbors_biased` for an example.\n\n    Parameters\n    ------------\n    g : DGLGraph\n        The input graph.\n    tag : Tensor\n        Integer tensor of shape :math:`(N,)`, :math:`N` being the number\n        of (destination) nodes or edges.\n    tag_offset_name : str\n        The name of the node feature to store tag offsets.\n    tag_type : str\n        Tag type which could be ``node`` or ``edge``.\n\n    Returns\n    -------\n    g_sorted : DGLGraph\n        A new graph whose CSR is sorted. The node/edge features of the\n        input graph is shallow-copied over.\n\n        - ``g_sorted.ndata[tag_offset_name]`` : Tensor of shape :math:`(N, max\\_tag + 2)`.\n        - If ``g`` is heterogeneous, get from ``g_sorted.srcdata``.\n\n    Examples\n    -----------\n\n    ``tag_type`` is ``node``.\n\n    >>> import dgl\n    >>> import torch\n\n    >>> g = dgl.graph(([0,0,0,0,0,1,1,1],[0,1,2,3,4,0,1,2]))\n    >>> g.adj_external(scipy_fmt='csr').nonzero()\n    (array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32),\n     array([0, 1, 2, 3, 4, 0, 1, 2], dtype=int32))\n    >>> tag = torch.IntTensor([1,1,0,2,0])\n    >>> g_sorted = dgl.sort_csr_by_tag(g, tag)\n    >>> g_sorted.adj_external(scipy_fmt='csr').nonzero()\n    (array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32),\n     array([2, 4, 0, 1, 3, 2, 0, 1], dtype=int32))\n    >>> g_sorted.ndata['_TAG_OFFSET']\n    tensor([[0, 2, 4, 5],\n            [0, 1, 3, 3],\n            [0, 0, 0, 0],\n            [0, 0, 0, 0],\n            [0, 0, 0, 0]])\n\n    ``tag_type`` is ``edge``.\n\n    >>> g = dgl.graph(([0,0,0,0,0,1,1,1],[0,1,2,3,4,0,1,2]))\n    >>> g.edges()\n    (tensor([0, 0, 0, 0, 0, 1, 1, 1]), tensor([0, 1, 2, 3, 4, 0, 1, 2]))\n    >>> tag = torch.tensor([1, 1, 0, 2, 0, 1, 1, 0])\n    >>> g_sorted = dgl.sort_csr_by_tag(g, tag, tag_type='edge')\n    >>> g_sorted.adj_external(scipy_fmt='csr').nonzero()\n    (array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32), array([2, 4, 0, 1, 3, 2, 0, 1], 
dtype=int32))\n    >>> g_sorted.srcdata['_TAG_OFFSET']\n    tensor([[0, 2, 4, 5],\n            [0, 1, 3, 3],\n            [0, 0, 0, 0],\n            [0, 0, 0, 0],\n            [0, 0, 0, 0]])\n\n    See Also\n    --------\n    dgl.sampling.sample_neighbors_biased\n    \"\"\"\n    if len(g.etypes) > 1:\n        raise DGLError(\"Only support homograph and bipartite graph\")\n    assert tag_type in [\n        \"node\",\n        \"edge\",\n    ], \"tag_type should be either 'node' or 'edge'.\"\n    if tag_type == \"node\":\n        _, dst = g.edges()\n        tag = F.gather_row(tag, F.tensor(dst))\n    assert len(tag) == g.num_edges()\n    num_tags = int(F.asnumpy(F.max(tag, 0))) + 1\n    tag_arr = F.zerocopy_to_dgl_ndarray(tag)\n    new_g = g.clone()\n    new_g._graph, tag_pos_arr = _CAPI_DGLHeteroSortOutEdges(\n        g._graph, tag_arr, num_tags\n    )\n    new_g.srcdata[tag_offset_name] = F.from_dgl_nd(tag_pos_arr)\n    return new_g\n\n\ndef sort_csc_by_tag(g, tag, tag_offset_name=\"_TAG_OFFSET\", tag_type=\"node\"):\n    r\"\"\"Return a new graph whose CSC matrix is sorted by the given tag.\n\n    Sort the internal CSC matrix of the graph so that the adjacency list of each node\n    , which contains the in-edges, is sorted by the tag of the in-neighbors.\n    After sorting, edges sharing the same tag will be arranged in a consecutive range in\n    a node's adjacency list. Following is an example:\n\n\n        Consider a graph as follows::\n\n            0 <- 0, 1, 2, 3, 4\n            1 <- 0, 1, 2\n\n        Given node tags ``[1, 1, 0, 2, 0]``, each node's adjacency list\n        will be sorted as follows::\n\n            0 <- 2, 4, 0, 1, 3\n            1 <- 2, 0, 1\n\n        Given edge tags ``[1, 1, 0, 2, 0, 1, 1, 0]`` has the same effect\n        as above node tags.\n\n    The function will also return the starting offsets of the tag\n    segments in a tensor of shape :math:`(N, max\\_tag+2)`. 
For a node ``i``,\n    its in-edges connecting to node tag ``j`` are stored between\n    ``tag_offsets[i][j]`` ~ ``tag_offsets[i][j+1]``. Since the offsets\n    can be viewed as node data, we store them in the\n    ``ndata`` of the returned graph. Users can specify the\n    ndata name by the ``tag_offset_name`` argument.\n\n    Note that the function will not change the edge IDs nor\n    how the edge features are stored. The input graph must\n    allow CSC format. The graph must be on CPU.\n\n    If the input graph is heterogeneous, it must have only one edge\n    type and two node types (i.e., source and destination node types).\n    In this case, the provided node tags are for the source nodes,\n    and the tag offsets are stored in the destination node data.\n\n    The sorted graph and the calculated tag offsets are needed by\n    certain operators that consider node tags. See :func:`~dgl.sampling.sample_neighbors_biased`\n    for an example.\n\n    Parameters\n    ------------\n    g : DGLGraph\n        The input graph.\n    tag : Tensor\n        Integer tensor of shape :math:`(N,)`, :math:`N` being the number\n        of (source) nodes or edges.\n    tag_offset_name : str\n        The name of the node feature to store tag offsets.\n    tag_type : str\n        Tag type which could be ``node`` or ``edge``.\n\n    Returns\n    -------\n    g_sorted : DGLGraph\n        A new graph whose CSC matrix is sorted. 
The node/edge features of the\n        input graph are shallow-copied over.\n\n        - ``g_sorted.ndata[tag_offset_name]`` : Tensor of shape :math:`(N, max\\_tag + 2)`.\n        - If ``g`` is heterogeneous, get from ``g_sorted.dstdata``.\n\n    Examples\n    -----------\n\n    ``tag_type`` is ``node``.\n\n    >>> import dgl\n    >>> import torch\n    >>> g = dgl.graph(([0,1,2,3,4,0,1,2],[0,0,0,0,0,1,1,1]))\n    >>> g.adj_external(scipy_fmt='csr', transpose=True).nonzero()\n    (array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32),\n     array([0, 1, 2, 3, 4, 0, 1, 2], dtype=int32))\n    >>> tag = torch.IntTensor([1,1,0,2,0])\n    >>> g_sorted = dgl.sort_csc_by_tag(g, tag)\n    >>> g_sorted.adj_external(scipy_fmt='csr', transpose=True).nonzero()\n    (array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32),\n     array([2, 4, 0, 1, 3, 2, 0, 1], dtype=int32))\n    >>> g_sorted.ndata['_TAG_OFFSET']\n    tensor([[0, 2, 4, 5],\n            [0, 1, 3, 3],\n            [0, 0, 0, 0],\n            [0, 0, 0, 0],\n            [0, 0, 0, 0]])\n\n    ``tag_type`` is ``edge``.\n\n    >>> g = dgl.graph(([0,1,2,3,4,0,1,2],[0,0,0,0,0,1,1,1]))\n    >>> tag = torch.tensor([1, 1, 0, 2, 0, 1, 1, 0])\n    >>> g_sorted = dgl.sort_csc_by_tag(g, tag, tag_type='edge')\n    >>> g_sorted.adj_external(scipy_fmt='csr', transpose=True).nonzero()\n    (array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32), array([2, 4, 0, 1, 3, 2, 0, 1], dtype=int32))\n    >>> g_sorted.dstdata['_TAG_OFFSET']\n    tensor([[0, 2, 4, 5],\n            [0, 1, 3, 3],\n            [0, 0, 0, 0],\n            [0, 0, 0, 0],\n            [0, 0, 0, 0]])\n\n    See Also\n    --------\n    dgl.sampling.sample_neighbors_biased\n    \"\"\"\n    if len(g.etypes) > 1:\n        raise DGLError(\"Only support homograph and bipartite graph\")\n    assert tag_type in [\n        \"node\",\n        \"edge\",\n    ], \"tag_type should be either 'node' or 'edge'.\"\n    if tag_type == \"node\":\n        src, _ = g.edges()\n        tag = F.gather_row(tag, 
F.tensor(src))\n    assert len(tag) == g.num_edges()\n    num_tags = int(F.asnumpy(F.max(tag, 0))) + 1\n    tag_arr = F.zerocopy_to_dgl_ndarray(tag)\n    new_g = g.clone()\n    new_g._graph, tag_pos_arr = _CAPI_DGLHeteroSortInEdges(\n        g._graph, tag_arr, num_tags\n    )\n    new_g.dstdata[tag_offset_name] = F.from_dgl_nd(tag_pos_arr)\n    return new_g\n\n\ndef reorder_graph(\n    g,\n    node_permute_algo=None,\n    edge_permute_algo=\"src\",\n    store_ids=True,\n    permute_config=None,\n):\n    r\"\"\"Return a new graph with nodes and edges re-ordered/re-labeled\n    according to the specified permute algorithm.\n\n    Support homogeneous graph only for the moment.\n\n    The re-ordering has two steps: first re-order nodes and then re-order edges.\n\n    For node permutation, users can re-order by the :attr:`node_permute_algo`\n    argument. For edge permutation, users can re-arrange edges according to their\n    source nodes or destination nodes by the :attr:`edge_permute_algo` argument.\n    Some of the permutation algorithms are only implemented on CPU, so if the\n    input graph is on GPU, it will be copied to CPU first. The storage order of\n    the node and edge features in the graph is permuted accordingly.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The homogeneous graph.\n    node_permute_algo: str, optional\n        The permutation algorithm to re-order nodes. 
If given, the options are ``rcmk`` or\n        ``metis`` or ``custom``.\n\n        * ``None``: Keep the current node order.\n        * ``rcmk``: Use the `Reverse Cuthill–McKee <https://docs.scipy.org/doc/scipy/reference/\n          generated/scipy.sparse.csgraph.reverse_cuthill_mckee.html#\n          scipy-sparse-csgraph-reverse-cuthill-mckee>`__ from ``scipy`` to generate nodes\n          permutation.\n        * ``metis``: Use the :func:`~dgl.metis_partition_assignment` function\n          to partition the input graph, which gives a cluster assignment of each node.\n          DGL then sorts the assignment array so the new node order will put nodes of\n          the same cluster together. Please note that the generated nodes permutation\n          of ``metis`` is non-deterministic due to algorithm's nature.\n        * ``custom``: Reorder the graph according to the user-provided node permutation\n          array (provided in :attr:`permute_config`).\n    edge_permute_algo: str, optional\n        The permutation algorithm to reorder edges. Options are ``src`` or ``dst`` or\n        ``custom``. 
``src`` is the default value.\n\n        * ``src``: Edges are arranged according to their source nodes.\n        * ``dst``: Edges are arranged according to their destination nodes.\n        * ``custom``: Edges are arranged according to the user-provided edge permutation\n          array (provided in :attr:`permute_config`).\n    store_ids: bool, optional\n        If True, DGL will store the original node and edge IDs in the ndata and edata\n        of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.\n    permute_config: dict, optional\n        Additional key-value config data for the specified permutation algorithm.\n\n        * For ``rcmk``, this argument is not required.\n        * For ``metis``, users should specify the number of partitions ``k`` (e.g.,\n          ``permute_config={'k':10}`` to partition the graph to 10 clusters).\n        * For ``custom`` node reordering, users should provide a node permutation\n          array ``nodes_perm``. The array must be an integer list or a tensor with\n          the same device of the input graph.\n        * For ``custom`` edge reordering, users should provide an edge permutation\n          array ``edges_perm``. 
The array must be an integer list or a tensor with\n          the same device of the input graph.\n\n    Returns\n    -------\n    DGLGraph\n        The re-ordered graph.\n\n    Examples\n    --------\n    >>> import dgl\n    >>> import torch\n    >>> g = dgl.graph((torch.tensor([0, 1, 2, 3, 4]), torch.tensor([2, 2, 3, 2, 3])))\n    >>> g.ndata['h'] = torch.arange(g.num_nodes() * 2).view(g.num_nodes(), 2)\n    >>> g.edata['w'] = torch.arange(g.num_edges() * 1).view(g.num_edges(), 1)\n    >>> g.ndata\n    {'h': tensor([[0, 1],\n            [2, 3],\n            [4, 5],\n            [6, 7],\n            [8, 9]])}\n    >>> g.edata\n    {'w': tensor([[0],\n            [1],\n            [2],\n            [3],\n            [4]])}\n\n    Reorder according to ``'rcmk'`` permute algorithm.\n\n    >>> rg = dgl.reorder_graph(g, node_permute_algo='rcmk')\n    >>> rg.ndata\n    {'h': tensor([[8, 9],\n            [6, 7],\n            [2, 3],\n            [4, 5],\n            [0, 1]]), '_ID': tensor([4, 3, 1, 2, 0])}\n    >>> rg.edata\n    {'w': tensor([[4],\n            [3],\n            [1],\n            [2],\n            [0]]), '_ID': tensor([4, 3, 1, 2, 0])}\n\n    Reorder according to ``'metis'`` permute algorithm.\n\n    >>> rg = dgl.reorder_graph(g, node_permute_algo='metis', permute_config={'k':2})\n    >>> rg.ndata\n    {'h': tensor([[4, 5],\n            [2, 3],\n            [0, 1],\n            [8, 9],\n            [6, 7]]), '_ID': tensor([2, 1, 0, 4, 3])}\n    >>> rg.edata\n    {'w': tensor([[2],\n            [1],\n            [0],\n            [4],\n            [3]]), '_ID': tensor([2, 1, 0, 4, 3])}\n\n    Reorder according to ``'custom'`` permute algorithm with user-provided nodes_perm.\n\n    >>> rg = dgl.reorder_graph(g, node_permute_algo='custom',\n    ...                        
permute_config={'nodes_perm': [3, 2, 0, 4, 1]})\n    >>> rg.ndata\n    {'h': tensor([[6, 7],\n            [4, 5],\n            [0, 1],\n            [8, 9],\n            [2, 3]]), '_ID': tensor([3, 2, 0, 4, 1])}\n    >>> rg.edata\n    {'w': tensor([[3],\n            [2],\n            [0],\n            [4],\n            [1]]), '_ID': tensor([3, 2, 0, 4, 1])}\n\n    Reorder nodes according to ``'rcmk'`` and reorder edges according to ``dst``\n    edge permute algorithm.\n\n    >>> rg = dgl.reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='dst')\n    >>> print(rg.ndata)\n    {'h': tensor([[8, 9],\n            [6, 7],\n            [2, 3],\n            [4, 5],\n            [0, 1]]), '_ID': tensor([4, 3, 1, 2, 0])}\n    >>> print(rg.edata)\n    {'w': tensor([[4],\n            [2],\n            [3],\n            [1],\n            [0]]), '_ID': tensor([4, 2, 3, 1, 0])}\n\n    Nodes are not reordered but edges are reordered according to ``'custom'`` permute\n    algorithm with user-provided edges_perm.\n\n    >>> rg = dgl.reorder_graph(g, edge_permute_algo='custom',\n    ...                        permute_config={'edges_perm': [1, 2, 3, 4, 0]})\n    >>> print(rg.ndata)\n    {'h': tensor([[0, 1],\n            [2, 3],\n            [4, 5],\n            [6, 7],\n            [8, 9]]), '_ID': tensor([0, 1, 2, 3, 4])}\n    >>> print(rg.edata)\n    {'w': tensor([[1],\n            [2],\n            [3],\n            [4],\n            [0]]), '_ID': tensor([1, 2, 3, 4, 0])}\n    \"\"\"\n    # sanity checks\n    if not g.is_homogeneous:\n        raise DGLError(\"Only homogeneous graphs are supported.\")\n    expected_node_algo = [\"rcmk\", \"metis\", \"custom\"]\n    if (\n        node_permute_algo is not None\n        and node_permute_algo not in expected_node_algo\n    ):\n        raise DGLError(\n            \"Unexpected node_permute_algo is specified: {}. 
Expected algos: {}\".format(\n                node_permute_algo, expected_node_algo\n            )\n        )\n    expected_edge_algo = [\"src\", \"dst\", \"custom\"]\n    if edge_permute_algo not in expected_edge_algo:\n        raise DGLError(\n            \"Unexpected edge_permute_algo is specified: {}. Expected algos: {}\".format(\n                edge_permute_algo, expected_edge_algo\n            )\n        )\n\n    g.edata[\"__orig__\"] = F.arange(0, g.num_edges(), g.idtype, g.device)\n\n    # reorder nodes\n    if node_permute_algo == \"rcmk\":\n        nodes_perm = rcmk_perm(g)\n        rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)\n    elif node_permute_algo == \"metis\":\n        if permute_config is None or \"k\" not in permute_config:\n            raise DGLError(\n                \"Partition parts 'k' is required for metis. Please specify in permute_config.\"\n            )\n        nodes_perm = metis_perm(g, permute_config[\"k\"])\n        rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)\n    elif node_permute_algo == \"custom\":\n        if permute_config is None or \"nodes_perm\" not in permute_config:\n            raise DGLError(\n                \"node_permute_algo is specified as custom, but no 'nodes_perm' is specified in \\\n                    permute_config.\"\n            )\n        nodes_perm = permute_config[\"nodes_perm\"]\n        if len(nodes_perm) != g.num_nodes():\n            raise DGLError(\n                \"Length of 'nodes_perm' ({}) does not \\\n                    match graph num_nodes ({}).\".format(\n                    len(nodes_perm), g.num_nodes()\n                )\n            )\n        rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)\n    else:\n        nodes_perm = F.arange(0, g.num_nodes(), g.idtype, g.device)\n        rg = g.clone()\n\n    if store_ids:\n        rg.ndata[NID] = F.copy_to(F.tensor(nodes_perm, g.idtype), g.device)\n\n    g.edata.pop(\"__orig__\")\n\n    # reorder 
edges\n    if edge_permute_algo == \"src\":\n        edges_perm = np.argsort(F.asnumpy(rg.edges()[0]))\n        rg = subgraph.edge_subgraph(\n            rg, edges_perm, relabel_nodes=False, store_ids=False\n        )\n    elif edge_permute_algo == \"dst\":\n        edges_perm = np.argsort(F.asnumpy(rg.edges()[1]))\n        rg = subgraph.edge_subgraph(\n            rg, edges_perm, relabel_nodes=False, store_ids=False\n        )\n    elif edge_permute_algo == \"custom\":\n        if permute_config is None or \"edges_perm\" not in permute_config:\n            raise DGLError(\n                \"edge_permute_algo is specified as custom, but no 'edges_perm' is specified in \\\n                    permute_config.\"\n            )\n        edges_perm = permute_config[\"edges_perm\"]\n        # First revert the edge reorder caused by node reorder and then\n        # apply user-provided edge permutation\n        rev_id = F.argsort(rg.edata[\"__orig__\"], 0, False)\n        edges_perm = F.astype(\n            F.gather_row(rev_id, F.tensor(edges_perm)), rg.idtype\n        )\n        rg = subgraph.edge_subgraph(\n            rg, edges_perm, relabel_nodes=False, store_ids=False\n        )\n\n    if store_ids:\n        rg.edata[EID] = rg.edata.pop(\"__orig__\")\n\n    return rg\n\n\nDGLGraph.reorder_graph = utils.alias_func(reorder_graph)\n\n\ndef metis_perm(g, k):\n    r\"\"\"Return nodes permutation according to ``'metis'`` algorithm.\n\n    For internal use.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The homogeneous graph.\n    k: int\n        The partition parts number.\n\n    Returns\n    -------\n    iterable[int]\n        The nodes permutation.\n    \"\"\"\n    pids = metis_partition_assignment(\n        g if g.device == F.cpu() else g.to(F.cpu()), k\n    )\n    pids = F.asnumpy(pids)\n    return np.argsort(pids).copy()\n\n\ndef rcmk_perm(g):\n    r\"\"\"Return nodes permutation according to ``'rcmk'`` algorithm.\n\n    For internal use.\n\n    
Parameters\n    ----------\n    g : DGLGraph\n        The homogeneous graph.\n\n    Returns\n    -------\n    iterable[int]\n        The nodes permutation.\n    \"\"\"\n    fmat = \"csr\"\n    allowed_fmats = sum(g.formats().values(), [])\n    if fmat not in allowed_fmats:\n        g = g.formats(allowed_fmats + [fmat])\n    csr_adj = g.adj_external(scipy_fmt=fmat)\n    perm = sparse.csgraph.reverse_cuthill_mckee(csr_adj)\n    return perm.copy()\n\n\ndef norm_by_dst(g, etype=None):\n    r\"\"\"Calculate normalization coefficient per edge based on destination node degree.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    etype : str or (str, str, str), optional\n        The type of the edges to calculate. The allowed edge type formats are:\n\n        * ``(str, str, str)`` for source node type, edge type and destination node type.\n        * or one ``str`` edge type name if the name can uniquely identify a\n          triplet format in the graph.\n\n        It can be omitted if the graph has a single edge type.\n\n    Returns\n    -------\n    1D Tensor\n        The normalization coefficient of the edges.\n\n    Examples\n    --------\n\n    >>> import dgl\n    >>> g = dgl.graph(([0, 1, 1], [1, 1, 2]))\n    >>> print(dgl.norm_by_dst(g))\n    tensor([0.5000, 0.5000, 1.0000])\n    \"\"\"\n    _, v, _ = g.edges(form=\"all\", etype=etype)\n    _, inv_index, count = F.unique(v, return_inverse=True, return_counts=True)\n    deg = F.astype(count[inv_index], F.float32)\n    norm = 1.0 / deg\n    norm = F.replace_inf_with_zero(norm)\n\n    return norm\n\n\ndef radius_graph(\n    x,\n    r,\n    p=2,\n    self_loop=False,\n    compute_mode=\"donot_use_mm_for_euclid_dist\",\n    get_distances=False,\n):\n    r\"\"\"Construct a graph from a set of points with neighbors within given distance.\n\n    The function transforms the coordinates/features of a point set\n    into a bidirected homogeneous graph. 
The coordinates of the point\n    set is specified as a matrix whose rows correspond to points and\n    columns correspond to coordinate/feature dimensions.\n\n    The nodes of the returned graph correspond to the points, where the neighbors\n    of each point are within given distance.\n\n    The function requires the PyTorch backend.\n\n    Parameters\n    ----------\n    x : Tensor\n        The point coordinates. It can be either on CPU or GPU.\n        Device of the point coordinates specifies device of the radius graph and\n        ``x[i]`` corresponds to the i-th node in the radius graph.\n    r : float\n        Radius of the neighbors.\n    p : float, optional\n        Power parameter for the Minkowski metric. When :attr:`p = 1` it is the\n        equivalent of Manhattan distance (L1 norm) and Euclidean distance\n        (L2 norm) for :attr:`p = 2`.\n\n        (default: 2)\n    self_loop : bool, optional\n        Whether the radius graph will contain self-loops.\n\n        (default: False)\n    compute_mode : str, optional\n        ``use_mm_for_euclid_dist_if_necessary`` - will use matrix multiplication\n        approach to calculate euclidean distance (p = 2) if P > 25 or R > 25\n        ``use_mm_for_euclid_dist`` - will always use matrix multiplication\n        approach to calculate euclidean distance (p = 2)\n        ``donot_use_mm_for_euclid_dist`` - will never use matrix multiplication\n        approach to calculate euclidean distance (p = 2).\n\n        (default: donot_use_mm_for_euclid_dist)\n    get_distances : bool, optional\n        Whether to return the distances for the corresponding edges in the\n        radius graph.\n\n        (default: False)\n\n    Returns\n    -------\n    DGLGraph\n        The constructed graph. The node IDs are in the same order as :attr:`x`.\n    torch.Tensor, optional\n        The distances for the edges in the constructed graph. 
The distances are\n        in the same order as edge IDs.\n\n    Examples\n    --------\n\n    The following examples use PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n\n    >>> x = torch.tensor([[0.0, 0.0, 1.0],\n    ...                   [1.0, 0.5, 0.5],\n    ...                   [0.5, 0.2, 0.2],\n    ...                   [0.3, 0.2, 0.4]])\n    >>> r_g = dgl.radius_graph(x, 0.75)  # Each node has neighbors within 0.75 distance\n    >>> r_g.edges()\n    (tensor([0, 1, 2, 2, 3, 3]), tensor([3, 2, 1, 3, 0, 2]))\n\n    When :attr:`get_distances` is True, function returns the radius graph and\n    distances for the corresponding edges.\n\n    >>> x = torch.tensor([[0.0, 0.0, 1.0],\n    ...                   [1.0, 0.5, 0.5],\n    ...                   [0.5, 0.2, 0.2],\n    ...                   [0.3, 0.2, 0.4]])\n    >>> r_g, dist = dgl.radius_graph(x, 0.75, get_distances=True)\n    >>> r_g.edges()\n    (tensor([0, 1, 2, 2, 3, 3]), tensor([3, 2, 1, 3, 0, 2]))\n    >>> dist\n    tensor([[0.7000],\n            [0.6557],\n            [0.6557],\n            [0.2828],\n            [0.7000],\n            [0.2828]])\n    \"\"\"\n    # check invalid r\n    if r <= 0:\n        raise DGLError(\"Invalid r value. 
expect r > 0, got r = {}\".format(r))\n\n    # check empty point set\n    if F.shape(x)[0] == 0:\n        raise DGLError(\"Find empty point set\")\n\n    distances = th.cdist(x, x, p=p, compute_mode=compute_mode)\n\n    if not self_loop:\n        distances.fill_diagonal_(r + 1)\n\n    edges = th.nonzero(distances <= r, as_tuple=True)\n\n    g = convert.graph(edges, num_nodes=x.shape[0], device=x.device)\n\n    if get_distances:\n        distances = distances[edges].unsqueeze(-1)\n\n        return g, distances\n\n    return g\n\n\ndef random_walk_pe(g, k, eweight_name=None):\n    r\"\"\"Random Walk Positional Encoding, as introduced in\n    `Graph Neural Networks with Learnable Structural and Positional Representations\n    <https://arxiv.org/abs/2110.07875>`__\n\n    This function computes the random walk positional encodings as landing probabilities\n    from 1-step to k-step, starting from each node to itself.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph. Must be homogeneous.\n    k : int\n        The number of random walk steps. The paper found the best value to be 16 and 20\n        for two experiments.\n    eweight_name : str, optional\n        The name to retrieve the edge weights. 
Default: None, not using the edge weights.\n\n    Returns\n    -------\n    Tensor\n        The random walk positional encodings of shape :math:`(N, k)`, where :math:`N` is the\n        number of nodes in the input graph.\n\n    Example\n    -------\n    >>> import dgl\n    >>> g = dgl.graph(([0,1,1], [1,1,0]))\n    >>> dgl.random_walk_pe(g, 2)\n    tensor([[0.0000, 0.5000],\n            [0.5000, 0.7500]])\n    \"\"\"\n    N = g.num_nodes()  # number of nodes\n    M = g.num_edges()  # number of edges\n    A = g.adj_external(scipy_fmt=\"csr\")  # adjacency matrix\n    if eweight_name is not None:\n        # add edge weights if required\n        W = sparse.csr_matrix(\n            (g.edata[eweight_name].squeeze(), g.find_edges(list(range(M)))),\n            shape=(N, N),\n        )\n        A = A.multiply(W)\n    # 1-step transition probability\n    if version.parse(scipy.__version__) < version.parse(\"1.11.0\"):\n        RW = np.array(A / (A.sum(1) + 1e-30))\n    else:\n        # Sparse matrix divided by a dense array returns a sparse matrix in\n        # scipy since 1.11.0.\n        RW = (A / (A.sum(1) + 1e-30)).toarray()\n\n    # Iterate for k steps\n    PE = [F.astype(F.tensor(np.array(RW.diagonal())), F.float32)]\n    RW_power = RW\n    for _ in range(k - 1):\n        RW_power = RW_power @ RW\n        PE.append(F.astype(F.tensor(np.array(RW_power.diagonal())), F.float32))\n    PE = F.stack(PE, dim=-1)\n\n    return PE\n\n\ndef lap_pe(g, k, padding=False, return_eigval=False):\n    r\"\"\"Laplacian Positional Encoding, as introduced in\n    `Benchmarking Graph Neural Networks\n    <https://arxiv.org/abs/2003.00982>`__\n\n    This function computes the laplacian positional encodings as the\n    k smallest non-trivial eigenvectors.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph. 
Must be homogeneous and bidirected.\n    k : int\n        Number of smallest non-trivial eigenvectors to use for positional\n        encoding.\n    padding : bool, optional\n        If False, raise an exception when k>=n. Otherwise, add zero paddings\n        in the end of eigenvectors and 'nan' paddings in the end of eigenvalues\n        when k>=n. Default: False. n is the number of nodes in the given graph.\n    return_eigval : bool, optional\n        If True, return laplacian eigenvalues together with eigenvectors.\n        Otherwise, return laplacian eigenvectors only.\n        Default: False.\n\n    Returns\n    -------\n    Tensor or (Tensor, Tensor)\n        Return the laplacian positional encodings of shape :math:`(N, k)`,\n        where :math:`N` is the number of nodes in the input graph, when\n        :attr:`return_eigval` is False. The eigenvalues of shape :math:`N` is\n        additionally returned as the second element when :attr:`return_eigval`\n        is True.\n\n    Example\n    -------\n    >>> import dgl\n    >>> g = dgl.graph(([0,1,2,3,1,2,3,0], [1,2,3,0,0,1,2,3]))\n    >>> dgl.lap_pe(g, 2)\n    tensor([[ 7.0711e-01, -6.4921e-17],\n            [ 3.0483e-16, -7.0711e-01],\n            [-7.0711e-01, -2.4910e-16],\n            [ 9.9288e-17,  7.0711e-01]])\n    >>> dgl.lap_pe(g, 5, padding=True)\n    tensor([[ 7.0711e-01, -6.4921e-17,  5.0000e-01,  0.0000e+00,  0.0000e+00],\n            [ 3.0483e-16, -7.0711e-01, -5.0000e-01,  0.0000e+00,  0.0000e+00],\n            [-7.0711e-01, -2.4910e-16,  5.0000e-01,  0.0000e+00,  0.0000e+00],\n            [ 9.9288e-17,  7.0711e-01, -5.0000e-01,  0.0000e+00,  0.0000e+00]])\n    >>> dgl.lap_pe(g, 5, padding=True, return_eigval=True)\n    (tensor([[-7.0711e-01,  6.4921e-17, -5.0000e-01,  0.0000e+00,  0.0000e+00],\n             [-3.0483e-16,  7.0711e-01,  5.0000e-01,  0.0000e+00,  0.0000e+00],\n             [ 7.0711e-01,  2.4910e-16, -5.0000e-01,  0.0000e+00,  0.0000e+00],\n             [-9.9288e-17, -7.0711e-01,  
5.0000e-01,  0.0000e+00,  0.0000e+00]]),\n     tensor([1., 1., 2., nan, nan]))\n    \"\"\"\n    # check for the \"k < n\" constraint\n    n = g.num_nodes()\n    if not padding and n <= k:\n        assert (\n            \"the number of eigenvectors k must be smaller than the number of \"\n            + f\"nodes n, {k} and {n} detected.\"\n        )\n\n    # get laplacian matrix as I - D^-0.5 * A * D^-0.5\n    A = g.adj_external(scipy_fmt=\"csr\")  # adjacency matrix\n    N = sparse.diags(\n        F.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float\n    )  # D^-1/2\n    L = sparse.eye(g.num_nodes()) - N * A * N\n\n    # select eigenvectors with smaller eigenvalues O(n + klogk)\n    if k + 1 < n - 1:\n        # Use scipy if k + 1 < n - 1 for memory efficiency.\n        EigVal, EigVec = scipy.sparse.linalg.eigs(\n            L, k=k + 1, which=\"SR\", ncv=4 * k, tol=1e-2\n        )\n        max_freqs = k\n        topk_indices = EigVal.argsort()[1:]\n    else:\n        # Fallback to numpy since scipy.sparse do not support this case.\n        EigVal, EigVec = np.linalg.eig(L.toarray())\n        max_freqs = min(n - 1, k)\n        kpartition_indices = np.argpartition(EigVal, max_freqs)[: max_freqs + 1]\n        topk_eigvals = EigVal[kpartition_indices]\n        topk_indices = kpartition_indices[topk_eigvals.argsort()][1:]\n\n    # Since scipy may return complex value, to avoid crashing in NN code,\n    # convert them to real number.\n    topk_EigVal = EigVal[topk_indices].real\n    topk_EigVec = EigVec[:, topk_indices].real\n    eigvals = F.tensor(topk_EigVal, dtype=F.float32)\n\n    # get random flip signs\n    rand_sign = 2 * (np.random.rand(max_freqs) > 0.5) - 1.0\n    PE = F.astype(F.tensor(rand_sign * topk_EigVec), F.float32)\n\n    # add paddings\n    if n <= k:\n        temp_EigVec = F.zeros(\n            [n, k - n + 1], dtype=F.float32, ctx=F.context(PE)\n        )\n        PE = F.cat([PE, temp_EigVec], dim=1)\n        temp_EigVal = F.tensor(np.full(k - n + 1, 
np.nan), F.float32)\n        eigvals = F.cat([eigvals, temp_EigVal], dim=0)\n\n    if return_eigval:\n        return PE, eigvals\n    return PE\n\n\ndef laplacian_pe(g, k, padding=False, return_eigval=False):\n    r\"\"\"Alias of `dgl.lap_pe`.\"\"\"\n    dgl_warning(\"dgl.laplacian_pe will be deprecated. Use dgl.lap_pe please.\")\n    return lap_pe(g, k, padding, return_eigval)\n\n\ndef to_bfloat16(g):\n    r\"\"\"Cast this graph to use bfloat16 for any\n    floating-point edge and node feature data.\n\n    A shallow copy is returned so that the original graph is not modified.\n    Feature tensors that are not floating-point will not be modified.\n\n    Returns\n    -------\n    DGLGraph\n        Clone of graph with the feature data converted to float16.\n    \"\"\"\n    ret = copy.copy(g)\n    ret._edge_frames = [frame.bfloat16() for frame in ret._edge_frames]\n    ret._node_frames = [frame.bfloat16() for frame in ret._node_frames]\n    return ret\n\n\ndef to_half(g):\n    r\"\"\"Cast this graph to use float16 (half-precision) for any\n    floating-point edge and node feature data.\n\n    A shallow copy is returned so that the original graph is not modified.\n    Feature tensors that are not floating-point will not be modified.\n\n    Returns\n    -------\n    DGLGraph\n        Clone of graph with the feature data converted to float16.\n    \"\"\"\n    ret = copy.copy(g)\n    ret._edge_frames = [frame.half() for frame in ret._edge_frames]\n    ret._node_frames = [frame.half() for frame in ret._node_frames]\n    return ret\n\n\ndef to_float(g):\n    r\"\"\"Cast this graph to use float32 (single-precision) for any\n    floating-point edge and node feature data.\n\n    A shallow copy is returned so that the original graph is not modified.\n    Feature tensors that are not floating-point will not be modified.\n\n    Returns\n    -------\n    DGLGraph\n        Clone of graph with the feature data converted to float32.\n    \"\"\"\n    ret = copy.copy(g)\n    
ret._edge_frames = [frame.float() for frame in ret._edge_frames]\n    ret._node_frames = [frame.float() for frame in ret._node_frames]\n    return ret\n\n\ndef to_double(g):\n    r\"\"\"Cast this graph to use float64 (double-precision) for any\n    floating-point edge and node feature data.\n\n    A shallow copy is returned so that the original graph is not modified.\n    Feature tensors that are not floating-point will not be modified.\n\n    Returns\n    -------\n    DGLGraph\n        Clone of graph with the feature data converted to float64.\n    \"\"\"\n    ret = copy.copy(g)\n    ret._edge_frames = [frame.double() for frame in ret._edge_frames]\n    ret._node_frames = [frame.double() for frame in ret._node_frames]\n    return ret\n\n\ndef double_radius_node_labeling(g, src, dst):\n    r\"\"\"Double Radius Node Labeling, as introduced in `Link Prediction\n    Based on Graph Neural Networks <https://arxiv.org/abs/1802.09691>`__.\n\n    This function computes the double radius node labeling for each node to mark\n    nodes' different roles in an enclosing subgraph, given a target link.\n\n    The node labels of source :math:`s` and destination :math:`t` are set to 1 and\n    those of unreachable nodes from source or destination are set to 0. The labels\n    of other nodes :math:`l` are defined according to the following hash function:\n\n    :math:`l = 1 + min(d_s, d_t) + (d//2)[(d//2) + (d%2) - 1]`\n\n    where :math:`d_s` and :math:`d_t` denote the shortest distance to the source and\n    the target, respectively. :math:`d = d_s + d_t`.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph.\n    src : int\n        The source node ID of the target link.\n    dst : int\n        The destination node ID of the target link.\n\n    Returns\n    -------\n    Tensor\n        Labels of all nodes. 
The tensor is of shape :math:`(N,)`, where\n        :math:`N` is the number of nodes in the input graph.\n\n    Example\n    -------\n    >>> import dgl\n\n    >>> g = dgl.graph(([0,0,0,0,1,1,2,4], [1,2,3,6,3,4,4,5]))\n    >>> dgl.double_radius_node_labeling(g, 0, 1)\n    tensor([1, 1, 3, 2, 3, 7, 0])\n    \"\"\"\n    adj = g.adj_external(scipy_fmt=\"csr\")\n    src, dst = (dst, src) if src > dst else (src, dst)\n\n    idx = list(range(src)) + list(range(src + 1, adj.shape[0]))\n    adj_wo_src = adj[idx, :][:, idx]\n\n    idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))\n    adj_wo_dst = adj[idx, :][:, idx]\n\n    # distance to the source node\n    ds = sparse.csgraph.shortest_path(\n        adj_wo_dst, directed=False, unweighted=True, indices=src\n    )\n    ds = np.insert(ds, dst, 0, axis=0)\n    # distance to the destination node\n    dt = sparse.csgraph.shortest_path(\n        adj_wo_src, directed=False, unweighted=True, indices=dst - 1\n    )\n    dt = np.insert(dt, src, 0, axis=0)\n\n    d = ds + dt\n    # suppress invalid value (nan) warnings\n    with np.errstate(invalid=\"ignore\"):\n        z = 1 + np.stack([ds, dt]).min(axis=0) + d // 2 * (d // 2 + d % 2 - 1)\n    z[src] = 1\n    z[dst] = 1\n    z[np.isnan(z)] = 0  # unreachable nodes\n\n    return F.tensor(z, F.int64)\n\n\ndef shortest_dist(g, root=None, return_paths=False):\n    r\"\"\"Compute shortest distance and paths on the given graph.\n\n    Only unweighted cases are supported. Only directed paths (in which the\n    edges are all oriented in the same direction) are considered effective.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The input graph. Must be homogeneous.\n    root : int, optional\n        Given a root node ID, it returns the shortest distance and paths\n        (optional) between the root node and all the nodes. If None, it returns\n        the results for all node pairs. 
Default: None.\n    return_paths : bool, optional\n        If True, it returns the shortest paths corresponding to the shortest\n        distances. Default: False.\n\n    Returns\n    -------\n    dist : Tensor\n        The shortest distance tensor.\n\n        * If :attr:`root` is a node ID, it is a tensor of shape :math:`(N,)`,\n          where :math:`N` is the number of nodes. :attr:`dist[j]` gives the\n          shortest distance from :attr:`root` to node :attr:`j`.\n        * Otherwise, it is a tensor of shape :math:`(N, N)`. :attr:`dist[i][j]`\n          gives the shortest distance from node :attr:`i` to node :attr:`j`.\n        * The distance values of unreachable node pairs are filled with -1.\n    paths : Tensor, optional\n        The shortest path tensor. It is only returned when :attr:`return_paths`\n        is True.\n\n        * If :attr:`root` is a node ID, it is a tensor of shape :math:`(N, L)`,\n          where :math:`L` is the length of the longest path. :attr:`path[j]` is\n          the shortest path from node :attr:`root` to node :attr:`j`.\n        * Otherwise, it is a tensor of shape :math:`(N, N, L)`.\n          :attr:`path[i][j]` is the shortest path from node :attr:`i` to node\n          :attr:`j`.\n        * Each path is a vector that consists of edge IDs with paddings of -1\n          at the end.\n        * Shortest path between a node and itself is a vector filled with -1's.\n\n    Example\n    -------\n    >>> import dgl\n\n    >>> g = dgl.graph(([0, 1, 1, 2], [2, 0, 3, 3]))\n    >>> dgl.shortest_dist(g, root=0)\n    tensor([ 0,  -1,  1, 2])\n    >>> dist, paths = dgl.shortest_dist(g, root=None, return_paths=True)\n    >>> print(dist)\n    tensor([[ 0, -1,  1,  2],\n            [ 1,  0,  2,  1],\n            [-1, -1,  0,  1],\n            [-1, -1, -1,  0]])\n    >>> print(paths)\n    tensor([[[-1, -1],\n             [-1, -1],\n             [ 0, -1],\n             [ 0,  3]],\n    <BLANKLINE>\n            [[ 1, -1],\n             [-1, -1],\n 
            [ 1,  0],\n             [ 2, -1]],\n    <BLANKLINE>\n            [[-1, -1],\n             [-1, -1],\n             [-1, -1],\n             [ 3, -1]],\n    <BLANKLINE>\n            [[-1, -1],\n             [-1, -1],\n             [-1, -1],\n             [-1, -1]]])\n    \"\"\"\n    if root is None:\n        dist, pred = sparse.csgraph.shortest_path(\n            g.adj_external(scipy_fmt=\"csr\"),\n            return_predecessors=True,\n            unweighted=True,\n            directed=True,\n        )\n    else:\n        dist, pred = sparse.csgraph.dijkstra(\n            g.adj_external(scipy_fmt=\"csr\"),\n            directed=True,\n            indices=root,\n            return_predecessors=True,\n            unweighted=True,\n        )\n    dist[np.isinf(dist)] = -1\n\n    if not return_paths:\n        return F.copy_to(F.tensor(dist, dtype=F.int64), g.device)\n\n    def _get_nodes(pred, i, j):\n        r\"\"\"return node IDs of a path from i to j given predecessors\"\"\"\n        if i == j:\n            return []\n        prev = pred[j]\n        nodes = [j, prev]\n        while prev != i:\n            prev = pred[prev]\n            nodes.append(prev)\n        nodes.reverse()\n\n        return nodes\n\n    # construct paths with given predecessors\n    max_len = int(dist[~np.isinf(dist)].max())\n    N = g.num_nodes()\n    roots = list(range(N)) if root is None else [root]\n    paths = np.ones([len(roots), N, max_len], dtype=np.int64) * -1\n    masks, u, v = [], [], []\n    for i in roots:\n        pred_ = pred[i] if root is None else pred\n        masks_i = np.zeros([N, max_len], dtype=bool)\n        for j in range(N):\n            if pred_[j] < 0:\n                continue\n            nodes = _get_nodes(pred_, i, j)\n            u.extend(nodes[:-1])\n            v.extend(nodes[1:])\n            if nodes:\n                masks_i[j, : len(nodes) - 1] = True\n        masks.append(masks_i)\n    masks = np.stack(masks, axis=0)\n\n    u, v = np.array(u), 
np.array(v)\n    edge_ids = g.edge_ids(u, v)\n    paths[masks] = F.asnumpy(edge_ids)\n    if root is not None:\n        paths = paths[0]\n\n    return F.copy_to(F.tensor(dist, dtype=F.int64), g.device), F.copy_to(\n        F.tensor(paths, dtype=F.int64), g.device\n    )\n\n\ndef svd_pe(g, k, padding=False, random_flip=True):\n    r\"\"\"SVD-based Positional Encoding, as introduced in\n    `Global Self-Attention as a Replacement for Graph Convolution\n    <https://arxiv.org/pdf/2108.03348.pdf>`__\n\n    This function computes the largest :math:`k` singular values and\n    corresponding left and right singular vectors to form positional encodings.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        A DGLGraph to be encoded, which must be a homogeneous one.\n    k : int\n        Number of largest singular values and corresponding singular vectors\n        used for positional encoding.\n    padding : bool, optional\n        If False, raise an error when :math:`k > N`,\n        where :math:`N` is the number of nodes in :attr:`g`.\n        If True, add zero paddings in the end of encoding vectors when\n        :math:`k > N`.\n        Default : False.\n    random_flip : bool, optional\n        If True, randomly flip the signs of encoding vectors.\n        Proposed to be activated during training for better generalization.\n        Default : True.\n\n    Returns\n    -------\n    Tensor\n        Return SVD-based positional encodings of shape :math:`(N, 2k)`.\n\n    Example\n    -------\n    >>> import dgl\n\n    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))\n    >>> dgl.svd_pe(g, k=2, padding=False, random_flip=True)\n    tensor([[-6.3246e-01, -1.1373e-07, -6.3246e-01,  0.0000e+00],\n            [-6.3246e-01,  7.6512e-01, -6.3246e-01, -7.6512e-01],\n            [ 6.3246e-01,  4.7287e-01,  6.3246e-01, -4.7287e-01],\n            [-6.3246e-01, -7.6512e-01, -6.3246e-01,  7.6512e-01],\n            [ 6.3246e-01, -4.7287e-01,  6.3246e-01,  
4.7287e-01]])\n    \"\"\"\n    n = g.num_nodes()\n    if not padding and n < k:\n        raise ValueError(\n            \"The number of singular values k must be no greater than the \"\n            \"number of nodes n, but \" + f\"got {k} and {n} respectively.\"\n        )\n    a = g.adj_external(ctx=g.device, scipy_fmt=\"coo\").toarray()\n    u, d, vh = scipy.linalg.svd(a)\n    v = vh.transpose()\n    m = min(n, k)\n    topm_u = u[:, 0:m]\n    topm_v = v[:, 0:m]\n    topm_sqrt_d = sparse.diags(np.sqrt(d[0:m]))\n    encoding = np.concatenate(\n        ((topm_u @ topm_sqrt_d), (topm_v @ topm_sqrt_d)), axis=1\n    )\n    # randomly flip row vectors\n    if random_flip:\n        rand_sign = 2 * (np.random.rand(n) > 0.5) - 1\n        flipped_encoding = F.tensor(\n            rand_sign[:, np.newaxis] * encoding, dtype=F.float32\n        )\n    else:\n        flipped_encoding = F.tensor(encoding, dtype=F.float32)\n\n    if n < k:\n        zero_padding = F.zeros(\n            [n, 2 * (k - n)], dtype=F.float32, ctx=F.context(flipped_encoding)\n        )\n        flipped_encoding = F.cat([flipped_encoding, zero_padding], dim=1)\n\n    return flipped_encoding\n\n\n_init_api(\"dgl.transform\", __name__)\n"
  },
  {
    "path": "python/dgl/transforms/module.py",
    "content": "##\n#   Copyright 2019-2021 Contributors\n#\n#   Licensed under the Apache License, Version 2.0 (the \"License\");\n#   you may not use this file except in compliance with the License.\n#   You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#   Unless required by applicable law or agreed to in writing, software\n#   distributed under the License is distributed on an \"AS IS\" BASIS,\n#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#   See the License for the specific language governing permissions and\n#   limitations under the License.\n#\n\"\"\"Modules for transform\"\"\"\n# pylint: disable= no-member, arguments-differ, invalid-name, missing-function-docstring\n\nfrom scipy.linalg import expm\n\nfrom .. import backend as F, convert, function as fn, utils\nfrom ..base import dgl_warning, DGLError\nfrom . import functional\n\ntry:\n    import torch\n    from torch.distributions import Bernoulli\nexcept ImportError:\n    pass\n\n__all__ = [\n    \"BaseTransform\",\n    \"RowFeatNormalizer\",\n    \"FeatMask\",\n    \"RandomWalkPE\",\n    \"LaplacianPE\",\n    \"LapPE\",\n    \"AddSelfLoop\",\n    \"RemoveSelfLoop\",\n    \"AddReverse\",\n    \"ToSimple\",\n    \"LineGraph\",\n    \"KHopGraph\",\n    \"AddMetaPaths\",\n    \"Compose\",\n    \"GCNNorm\",\n    \"PPR\",\n    \"HeatKernel\",\n    \"GDC\",\n    \"NodeShuffle\",\n    \"DropNode\",\n    \"DropEdge\",\n    \"AddEdge\",\n    \"SIGNDiffusion\",\n    \"ToLevi\",\n    \"SVDPE\",\n]\n\n\ndef update_graph_structure(g, data_dict, copy_edata=True):\n    r\"\"\"Update the structure of a graph.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph to update.\n    data_dict : graph data\n        The dictionary data for constructing a heterogeneous graph.\n    copy_edata : bool\n        If True, it will copy the edge features to the updated graph.\n\n    Returns\n    -------\n    DGLGraph\n        The updated 
graph.\n    \"\"\"\n    device = g.device\n    idtype = g.idtype\n    num_nodes_dict = dict()\n\n    for ntype in g.ntypes:\n        num_nodes_dict[ntype] = g.num_nodes(ntype)\n\n    new_g = convert.heterograph(\n        data_dict, num_nodes_dict=num_nodes_dict, idtype=idtype, device=device\n    )\n\n    # Copy features\n    for ntype in g.ntypes:\n        for key, feat in g.nodes[ntype].data.items():\n            new_g.nodes[ntype].data[key] = feat\n\n    if copy_edata:\n        for c_etype in g.canonical_etypes:\n            for key, feat in g.edges[c_etype].data.items():\n                new_g.edges[c_etype].data[key] = feat\n\n    return new_g\n\n\nclass BaseTransform:\n    r\"\"\"An abstract class for writing transforms.\"\"\"\n\n    def __call__(self, g):\n        raise NotImplementedError\n\n    def __repr__(self):\n        return self.__class__.__name__ + \"()\"\n\n\nclass RowFeatNormalizer(BaseTransform):\n    r\"\"\"\n    Row-normalizes the features given in ``node_feat_names`` and ``edge_feat_names``.\n\n    The row normalization formular is:\n\n    .. math::\n      x = \\frac{x}{\\sum_i x_i}\n\n    where :math:`x` denotes a row of the feature tensor.\n\n    Parameters\n    ----------\n    subtract_min: bool\n        If True, the minimum value of whole feature tensor will be subtracted before normalization.\n        Default: False.\n        Subtraction will make all values non-negative. If all values are negative, after\n        normalisation, the sum of each row of the feature tensor will be 1.\n    node_feat_names : list[str], optional\n        The names of the node feature tensors to be row-normalized. Default: `None`, which will\n        not normalize any node feature tensor.\n    edge_feat_names : list[str], optional\n        The names of the edge feature tensors to be row-normalized. 
Default: `None`, which will\n        not normalize any edge feature tensor.\n\n    Example\n    -------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import RowFeatNormalizer\n\n    Case1: Row normalize features of a homogeneous graph.\n\n    >>> transform = RowFeatNormalizer(subtract_min=True,\n    ...                               node_feat_names=['h'], edge_feat_names=['w'])\n    >>> g = dgl.rand_graph(5, 20)\n    >>> g.ndata['h'] = torch.randn((g.num_nodes(), 5))\n    >>> g.edata['w'] = torch.randn((g.num_edges(), 5))\n    >>> g = transform(g)\n    >>> print(g.ndata['h'].sum(1))\n    tensor([1., 1., 1., 1., 1.])\n    >>> print(g.edata['w'].sum(1))\n    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.,\n            1., 1., 1., 1., 1., 1., 1., 1., 1.,\n            1., 1.])\n\n    Case2: Row normalize features of a heterogeneous graph.\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),\n    ...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))\n    ... })\n    >>> g.ndata['h'] = {'game': torch.randn(2, 5), 'player': torch.randn(3, 5)}\n    >>> g.edata['w'] = {\n    ...     ('user', 'follows', 'user'): torch.randn(2, 5),\n    ...     ('player', 'plays', 'game'): torch.randn(2, 5)\n    ... }\n    >>> g = transform(g)\n    >>> print(g.ndata['h']['game'].sum(1), g.ndata['h']['player'].sum(1))\n    tensor([1., 1.]) tensor([1., 1., 1.])\n    >>> print(g.edata['w'][('user', 'follows', 'user')].sum(1),\n    ...     
g.edata['w'][('player', 'plays', 'game')].sum(1))\n    tensor([1., 1.]) tensor([1., 1.])\n    \"\"\"\n\n    def __init__(\n        self, subtract_min=False, node_feat_names=None, edge_feat_names=None\n    ):\n        self.node_feat_names = (\n            [] if node_feat_names is None else node_feat_names\n        )\n        self.edge_feat_names = (\n            [] if edge_feat_names is None else edge_feat_names\n        )\n        self.subtract_min = subtract_min\n\n    def row_normalize(self, feat):\n        r\"\"\"\n\n        Description\n        -----------\n        Row-normalize the given feature.\n\n        Parameters\n        ----------\n        feat : Tensor\n            The feature to be normalized.\n\n        Returns\n        -------\n        Tensor\n            The normalized feature.\n        \"\"\"\n        if self.subtract_min:\n            feat = feat - feat.min()\n        feat.div_(feat.sum(dim=-1, keepdim=True).clamp_(min=1.0))\n        return feat\n\n    def __call__(self, g):\n        for node_feat_name in self.node_feat_names:\n            if isinstance(g.ndata[node_feat_name], torch.Tensor):\n                g.ndata[node_feat_name] = self.row_normalize(\n                    g.ndata[node_feat_name]\n                )\n            else:\n                for ntype in g.ndata[node_feat_name].keys():\n                    g.nodes[ntype].data[node_feat_name] = self.row_normalize(\n                        g.nodes[ntype].data[node_feat_name]\n                    )\n\n        for edge_feat_name in self.edge_feat_names:\n            if isinstance(g.edata[edge_feat_name], torch.Tensor):\n                g.edata[edge_feat_name] = self.row_normalize(\n                    g.edata[edge_feat_name]\n                )\n            else:\n                for etype in g.edata[edge_feat_name].keys():\n                    g.edges[etype].data[edge_feat_name] = self.row_normalize(\n                        g.edges[etype].data[edge_feat_name]\n                    )\n\n    
    return g\n\n\nclass FeatMask(BaseTransform):\n    r\"\"\"Randomly mask columns of the node and edge feature tensors, as described in `Graph\n    Contrastive Learning with Augmentations <https://arxiv.org/abs/2010.13902>`__.\n\n    Parameters\n    ----------\n    p : float, optional\n        Probability of masking a column of a feature tensor. Default: `0.5`.\n    node_feat_names : list[str], optional\n        The names of the node feature tensors to be masked. Default: `None`, which will\n        not mask any node feature tensor.\n    edge_feat_names : list[str], optional\n        The names of the edge features to be masked. Default: `None`, which will not mask\n        any edge feature tensor.\n\n    Example\n    -------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import FeatMask\n\n    Case1 : Mask node and edge feature tensors of a homogeneous graph.\n\n    >>> transform = FeatMask(node_feat_names=['h'], edge_feat_names=['w'])\n    >>> g = dgl.rand_graph(5, 10)\n    >>> g.ndata['h'] = torch.ones((g.num_nodes(), 10))\n    >>> g.edata['w'] = torch.ones((g.num_edges(), 10))\n\n    >>> g = transform(g)\n    >>> print(g.ndata['h'])\n    tensor([[0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],\n            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],\n            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],\n            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],\n            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.]])\n    >>> print(g.edata['w'])\n    tensor([[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],\n            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],\n            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],\n            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],\n            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],\n            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],\n            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],\n            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],\n            [1., 1., 0., 1., 0., 
1., 0., 0., 0., 1.],\n            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.]])\n\n    Case2 : Mask node and edge feature tensors of a heterogeneous graph.\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),\n    ...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))\n    ... })\n    >>> g.ndata['h'] = {'game': torch.ones(2, 5), 'player': torch.ones(3, 5)}\n    >>> g.edata['w'] = {('user', 'follows', 'user'): torch.ones(2, 5)}\n    >>> print(g.ndata['h']['game'])\n    tensor([[1., 1., 1., 1., 1.],\n            [1., 1., 1., 1., 1.]])\n    >>> print(g.edata['w'][('user', 'follows', 'user')])\n    tensor([[1., 1., 1., 1., 1.],\n            [1., 1., 1., 1., 1.]])\n    >>> g = transform(g)\n    >>> print(g.ndata['h']['game'])\n    tensor([[1., 1., 0., 1., 0.],\n            [1., 1., 0., 1., 0.]])\n    >>> print(g.edata['w'][('user', 'follows', 'user')])\n    tensor([[0., 1., 0., 1., 0.],\n            [0., 1., 0., 1., 0.]])\n    \"\"\"\n\n    def __init__(self, p=0.5, node_feat_names=None, edge_feat_names=None):\n        self.p = p\n        self.node_feat_names = (\n            [] if node_feat_names is None else node_feat_names\n        )\n        self.edge_feat_names = (\n            [] if edge_feat_names is None else edge_feat_names\n        )\n        self.dist = Bernoulli(p)\n\n    def __call__(self, g):\n        # Fast path\n        if self.p == 0:\n            return g\n\n        for node_feat_name in self.node_feat_names:\n            if isinstance(g.ndata[node_feat_name], torch.Tensor):\n                feat_mask = self.dist.sample(\n                    torch.Size(\n                        [\n                            g.ndata[node_feat_name].shape[-1],\n                        ]\n                    )\n                )\n                g.ndata[node_feat_name][:, feat_mask.bool().to(g.device)] = 0\n\n            else:\n                for ntype in 
g.ndata[node_feat_name].keys():\n                    mask_shape = g.ndata[node_feat_name][ntype].shape[-1]\n                    feat_mask = self.dist.sample(\n                        torch.Size(\n                            [\n                                mask_shape,\n                            ]\n                        )\n                    )\n                    g.ndata[node_feat_name][ntype][\n                        :, feat_mask.bool().to(g.device)\n                    ] = 0\n\n        for edge_feat_name in self.edge_feat_names:\n            if isinstance(g.edata[edge_feat_name], torch.Tensor):\n                feat_mask = self.dist.sample(\n                    torch.Size(\n                        [\n                            g.edata[edge_feat_name].shape[-1],\n                        ]\n                    )\n                )\n                g.edata[edge_feat_name][:, feat_mask.bool().to(g.device)] = 0\n\n            else:\n                for etype in g.edata[edge_feat_name].keys():\n                    mask_shape = g.edata[edge_feat_name][etype].shape[-1]\n                    feat_mask = self.dist.sample(\n                        torch.Size(\n                            [\n                                mask_shape,\n                            ]\n                        )\n                    )\n                    g.edata[edge_feat_name][etype][\n                        :, feat_mask.bool().to(g.device)\n                    ] = 0\n        return g\n\n\nclass RandomWalkPE(BaseTransform):\n    r\"\"\"Random Walk Positional Encoding, as introduced in\n    `Graph Neural Networks with Learnable Structural and Positional Representations\n    <https://arxiv.org/abs/2110.07875>`__\n\n    This module only works for homogeneous graphs.\n\n    Parameters\n    ----------\n    k : int\n        Number of random walk steps. 
The paper found the best value to be 16 and 20\n        for two experiments.\n    feat_name : str, optional\n        Name to store the computed positional encodings in ndata.\n    eweight_name : str, optional\n        Name to retrieve the edge weights. Default: None, not using the edge weights.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> from dgl import RandomWalkPE\n\n    >>> transform = RandomWalkPE(k=2)\n    >>> g = dgl.graph(([0, 1, 1], [1, 1, 0]))\n    >>> g = transform(g)\n    >>> print(g.ndata['PE'])\n    tensor([[0.0000, 0.5000],\n            [0.5000, 0.7500]])\n    \"\"\"\n\n    def __init__(self, k, feat_name=\"PE\", eweight_name=None):\n        self.k = k\n        self.feat_name = feat_name\n        self.eweight_name = eweight_name\n\n    def __call__(self, g):\n        PE = functional.random_walk_pe(\n            g, k=self.k, eweight_name=self.eweight_name\n        )\n        g.ndata[self.feat_name] = F.copy_to(PE, g.device)\n\n        return g\n\n\nclass LapPE(BaseTransform):\n    r\"\"\"Laplacian Positional Encoding, as introduced in\n    `Benchmarking Graph Neural Networks\n    <https://arxiv.org/abs/2003.00982>`__\n\n    This module only works for homogeneous bidirected graphs.\n\n    Parameters\n    ----------\n    k : int\n        Number of smallest non-trivial eigenvectors to use for positional encoding.\n    feat_name : str, optional\n        Name to store the computed positional encodings in ndata.\n    eigval_name : str, optional\n        If None, store laplacian eigenvectors only. Otherwise, it's the name to\n        store corresponding laplacian eigenvalues in ndata. Default: None.\n    padding : bool, optional\n        If False, raise an exception when k>=n.\n        Otherwise, add zero paddings in the end of eigenvectors and 'nan'\n        paddings in the end of eigenvalues when k>=n. 
Default: False.\n        n is the number of nodes in the given graph.\n\n    Example\n    -------\n    >>> import dgl\n    >>> from dgl import LapPE\n    >>> transform1 = LapPE(k=3)\n    >>> transform2 = LapPE(k=5, padding=True)\n    >>> transform3 = LapPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)\n    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))\n    >>> g1 = transform1(g)\n    >>> print(g1.ndata['PE'])\n    tensor([[ 0.6325,  0.1039,  0.3489],\n            [-0.5117,  0.2826,  0.6095],\n            [ 0.1954,  0.6254, -0.5923],\n            [-0.5117, -0.4508, -0.3938],\n            [ 0.1954, -0.5612,  0.0278]])\n    >>> g2 = transform2(g)\n    >>> print(g2.ndata['PE'])\n    tensor([[-0.6325, -0.1039,  0.3489, -0.2530,  0.0000],\n            [ 0.5117, -0.2826,  0.6095,  0.4731,  0.0000],\n            [-0.1954, -0.6254, -0.5923, -0.1361,  0.0000],\n            [ 0.5117,  0.4508, -0.3938, -0.6295,  0.0000],\n            [-0.1954,  0.5612,  0.0278,  0.5454,  0.0000]])\n    >>> g3 = transform3(g)\n    >>> print(g3.ndata['eigval'])\n    tensor([[0.6910, 0.6910, 1.8090, 1.8090,    nan],\n            [0.6910, 0.6910, 1.8090, 1.8090,    nan],\n            [0.6910, 0.6910, 1.8090, 1.8090,    nan],\n            [0.6910, 0.6910, 1.8090, 1.8090,    nan],\n            [0.6910, 0.6910, 1.8090, 1.8090,    nan]])\n    >>> print(g3.ndata['eigvec'])\n    tensor([[ 0.6325, -0.1039,  0.3489,  0.2530,  0.0000],\n            [-0.5117, -0.2826,  0.6095, -0.4731,  0.0000],\n            [ 0.1954, -0.6254, -0.5923,  0.1361,  0.0000],\n            [-0.5117,  0.4508, -0.3938,  0.6295,  0.0000],\n            [ 0.1954,  0.5612,  0.0278, -0.5454,  0.0000]])\n    \"\"\"\n\n    def __init__(self, k, feat_name=\"PE\", eigval_name=None, padding=False):\n        self.k = k\n        self.feat_name = feat_name\n        self.eigval_name = eigval_name\n        self.padding = padding\n\n    def __call__(self, g):\n        if self.eigval_name:\n            PE, eigval 
= functional.lap_pe(\n                g, k=self.k, padding=self.padding, return_eigval=True\n            )\n            eigval = F.repeat(F.reshape(eigval, [1, -1]), g.num_nodes(), dim=0)\n            g.ndata[self.eigval_name] = F.copy_to(eigval, g.device)\n        else:\n            PE = functional.lap_pe(g, k=self.k, padding=self.padding)\n        g.ndata[self.feat_name] = F.copy_to(PE, g.device)\n\n        return g\n\n\nclass LaplacianPE(LapPE):\n    r\"\"\"Alias of `LapPE`.\"\"\"\n\n    def __init__(self, k, feat_name=\"PE\", eigval_name=None, padding=False):\n        super().__init__(k, feat_name, eigval_name, padding)\n        dgl_warning(\"LaplacianPE will be deprecated. Use LapPE please.\")\n\n\nclass AddSelfLoop(BaseTransform):\n    r\"\"\"Add self-loops for each node in the graph and return a new graph.\n\n    For heterogeneous graphs, self-loops are added only for edge types with same\n    source and destination node types.\n\n    Parameters\n    ----------\n    allow_duplicate : bool, optional\n        If False, it will first remove self-loops to prevent duplicate self-loops.\n    new_etypes : bool, optional\n        If True, it will add an edge type 'self' per node type, which holds self-loops.\n    edge_feat_names : list[str], optional\n        The names of the self-loop features to apply `fill_data`. If None, it\n        will apply `fill_data` to all self-loop features. Default: None.\n    fill_data : int, float or str, optional\n        The value to fill the self-loop features. Default: 1.\n\n        * If ``fill_data`` is ``int`` or ``float``, self-loop features will be directly given by\n          ``fill_data``.\n        * if ``fill_data`` is ``str``, self-loop features will be generated by aggregating the\n          features of the incoming edges of the corresponding nodes. 
The supported aggregations are:\n          ``'mean'``, ``'sum'``, ``'max'``, ``'min'``.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import AddSelfLoop\n\n    Case1: Add self-loops for a homogeneous graph\n\n    >>> transform = AddSelfLoop(fill_data='sum')\n    >>> g = dgl.graph(([0, 0, 2], [2, 1, 0]))\n    >>> g.edata['he'] = torch.arange(3).float().reshape(-1, 1)\n    >>> new_g = transform(g)\n    >>> print(new_g.edges())\n    (tensor([0, 0, 2, 0, 1, 2]), tensor([2, 1, 0, 0, 1, 2]))\n    >>> print(new_g.edata['he'])\n    tensor([[0.],\n            [1.],\n            [2.],\n            [2.],\n            [1.],\n            [0.]])\n\n    Case2: Add self-loops for a heterogeneous graph\n\n    >>> transform = AddSelfLoop(fill_data='sum')\n    >>> g = dgl.heterograph({\n    ...     ('user', 'follows', 'user'): (torch.tensor([1, 2]),\n    ...                                   torch.tensor([0, 1])),\n    ...     ('user', 'plays', 'game'): (torch.tensor([0, 1]),\n    ...                                 torch.tensor([0, 1]))})\n    >>> g.edata['feat'] = {('user', 'follows', 'user'): torch.randn(2, 5),\n    ...                    ('user', 'plays', 'game'): torch.randn(2, 5)}\n    >>> g.edata['feat1'] = {('user', 'follows', 'user'): torch.randn(2, 15),\n    ...                     
('user', 'plays', 'game'): torch.randn(2, 15)}\n    >>> new_g = transform(g)\n    >>> print(new_g.edges(etype='plays'))\n    (tensor([0, 1]), tensor([0, 1]))\n    >>> print(new_g.edges(etype='follows'))\n    (tensor([1, 2, 0, 1, 2]), tensor([0, 1, 0, 1, 2]))\n    >>> print(new_g.edata['feat'][('user', 'follows', 'user')].shape)\n    torch.Size([5, 5])\n\n    Case3: Add self-etypes for a heterogeneous graph\n\n    >>> transform = AddSelfLoop(new_etypes=True)\n    >>> new_g = transform(g)\n    >>> print(new_g.edges(etype='follows'))\n    (tensor([1, 2, 0, 1, 2]), tensor([0, 1, 0, 1, 2]))\n    >>> print(new_g.edges(etype=('game', 'self', 'game')))\n    (tensor([0, 1]), tensor([0, 1]))\n    \"\"\"\n\n    def __init__(\n        self,\n        allow_duplicate=False,\n        new_etypes=False,\n        edge_feat_names=None,\n        fill_data=1.0,\n    ):\n        self.allow_duplicate = allow_duplicate\n        self.new_etypes = new_etypes\n        self.edge_feat_names = edge_feat_names\n        self.fill_data = fill_data\n\n    def transform_etype(self, c_etype, g):\n        r\"\"\"\n\n        Description\n        -----------\n        Transform the graph corresponding to a canonical edge type.\n\n        Parameters\n        ----------\n        c_etype : tuple of str\n            A canonical edge type.\n        g : DGLGraph\n            The graph.\n\n        Returns\n        -------\n        DGLGraph\n            The transformed graph.\n        \"\"\"\n        utype, _, vtype = c_etype\n        if utype != vtype:\n            return g\n\n        if not self.allow_duplicate:\n            g = functional.remove_self_loop(g, etype=c_etype)\n        return functional.add_self_loop(\n            g,\n            edge_feat_names=self.edge_feat_names,\n            fill_data=self.fill_data,\n            etype=c_etype,\n        )\n\n    def __call__(self, g):\n        for c_etype in g.canonical_etypes:\n            g = self.transform_etype(c_etype, g)\n\n        if 
self.new_etypes:\n            device = g.device\n            idtype = g.idtype\n            data_dict = dict()\n\n            # Add self etypes\n            for ntype in g.ntypes:\n                nids = F.arange(0, g.num_nodes(ntype), idtype, device)\n                data_dict[(ntype, \"self\", ntype)] = (nids, nids)\n\n            # Copy edges\n            for c_etype in g.canonical_etypes:\n                data_dict[c_etype] = g.edges(etype=c_etype)\n\n            g = update_graph_structure(g, data_dict)\n\n        return g\n\n\nclass RemoveSelfLoop(BaseTransform):\n    r\"\"\"Remove self-loops for each node in the graph and return a new graph.\n\n    For heterogeneous graphs, this operation only applies to edge types with same\n    source and destination node types.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> from dgl import RemoveSelfLoop\n\n    Case1: Remove self-loops for a homogeneous graph\n\n    >>> transform = RemoveSelfLoop()\n    >>> g = dgl.graph(([1, 1], [1, 2]))\n    >>> new_g = transform(g)\n    >>> print(new_g.edges())\n    (tensor([1]), tensor([2]))\n\n    Case2: Remove self-loops for a heterogeneous graph\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'plays', 'game'): ([0, 1], [1, 1]),\n    ...     ('user', 'follows', 'user'): ([1, 2], [2, 2])\n    ... 
})\n    >>> new_g = transform(g)\n    >>> print(new_g.edges(etype='plays'))\n    (tensor([0, 1]), tensor([1, 1]))\n    >>> print(new_g.edges(etype='follows'))\n    (tensor([1]), tensor([2]))\n    \"\"\"\n\n    def transform_etype(self, c_etype, g):\n        r\"\"\"Transform the graph corresponding to a canonical edge type.\n\n        Parameters\n        ----------\n        c_etype : tuple of str\n            A canonical edge type.\n        g : DGLGraph\n            The graph.\n\n        Returns\n        -------\n        DGLGraph\n            The transformed graph.\n        \"\"\"\n        utype, _, vtype = c_etype\n        if utype == vtype:\n            g = functional.remove_self_loop(g, etype=c_etype)\n        return g\n\n    def __call__(self, g):\n        for c_etype in g.canonical_etypes:\n            g = self.transform_etype(c_etype, g)\n        return g\n\n\nclass AddReverse(BaseTransform):\n    r\"\"\"Add a reverse edge :math:`(i,j)` for each edge :math:`(j,i)` in the input graph and\n    return a new graph.\n\n    For a heterogeneous graph, it adds a \"reverse\" edge type for each edge type\n    to hold the reverse edges. For example, for a canonical edge type ('A', 'r', 'B'),\n    it adds a canonical edge type ('B', 'rev_r', 'A').\n\n    Parameters\n    ----------\n    copy_edata : bool, optional\n        If True, the features of the reverse edges will be identical to the original ones.\n    sym_new_etype : bool, optional\n        If False, it will not add a reverse edge type if the source and destination node type\n        in a canonical edge type are identical. 
Instead, it will directly add edges to the\n        original edge type.\n\n    Example\n    -------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import AddReverse\n\n    Case1: Add reverse edges for a homogeneous graph\n\n    >>> transform = AddReverse()\n    >>> g = dgl.graph(([0], [1]))\n    >>> g.edata['w'] = torch.ones(1, 2)\n    >>> new_g = transform(g)\n    >>> print(new_g.edges())\n    (tensor([0, 1]), tensor([1, 0]))\n    >>> print(new_g.edata['w'])\n    tensor([[1., 1.],\n            [0., 0.]])\n\n    Case2: Add reverse edges for a homogeneous graph and copy edata\n\n    >>> transform = AddReverse(copy_edata=True)\n    >>> new_g = transform(g)\n    >>> print(new_g.edata['w'])\n    tensor([[1., 1.],\n            [1., 1.]])\n\n    Case3: Add reverse edges for a heterogeneous graph\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'plays', 'game'): ([0, 1], [1, 1]),\n    ...     ('user', 'follows', 'user'): ([1, 2], [2, 2])\n    ... 
})\n    >>> new_g = transform(g)\n    >>> print(new_g.canonical_etypes)\n    [('game', 'rev_plays', 'user'), ('user', 'follows', 'user'), ('user', 'plays', 'game')]\n    >>> print(new_g.edges(etype='rev_plays'))\n    (tensor([1, 1]), tensor([0, 1]))\n    >>> print(new_g.edges(etype='follows'))\n    (tensor([1, 2, 2, 2]), tensor([2, 2, 1, 2]))\n    \"\"\"\n\n    def __init__(self, copy_edata=False, sym_new_etype=False):\n        self.copy_edata = copy_edata\n        self.sym_new_etype = sym_new_etype\n\n    def transform_symmetric_etype(self, c_etype, g, data_dict):\n        r\"\"\"Transform the graph corresponding to a symmetric canonical edge type.\n\n        Parameters\n        ----------\n        c_etype : tuple of str\n            A canonical edge type.\n        g : DGLGraph\n            The graph.\n        data_dict : dict\n            The edge data to update.\n        \"\"\"\n        if self.sym_new_etype:\n            self.transform_asymmetric_etype(c_etype, g, data_dict)\n        else:\n            src, dst = g.edges(etype=c_etype)\n            src, dst = F.cat([src, dst], dim=0), F.cat([dst, src], dim=0)\n            data_dict[c_etype] = (src, dst)\n\n    def transform_asymmetric_etype(self, c_etype, g, data_dict):\n        r\"\"\"Transform the graph corresponding to an asymmetric canonical edge type.\n\n        Parameters\n        ----------\n        c_etype : tuple of str\n            A canonical edge type.\n        g : DGLGraph\n            The graph.\n        data_dict : dict\n            The edge data to update.\n        \"\"\"\n        utype, etype, vtype = c_etype\n        src, dst = g.edges(etype=c_etype)\n        data_dict.update(\n            {\n                c_etype: (src, dst),\n                (vtype, \"rev_{}\".format(etype), utype): (dst, src),\n            }\n        )\n\n    def transform_etype(self, c_etype, g, data_dict):\n        r\"\"\"Transform the graph corresponding to a canonical edge type.\n\n        Parameters\n        
----------\n        c_etype : tuple of str\n            A canonical edge type.\n        g : DGLGraph\n            The graph.\n        data_dict : dict\n            The edge data to update.\n        \"\"\"\n        utype, _, vtype = c_etype\n        if utype == vtype:\n            self.transform_symmetric_etype(c_etype, g, data_dict)\n        else:\n            self.transform_asymmetric_etype(c_etype, g, data_dict)\n\n    def __call__(self, g):\n        data_dict = dict()\n        for c_etype in g.canonical_etypes:\n            self.transform_etype(c_etype, g, data_dict)\n        new_g = update_graph_structure(g, data_dict, copy_edata=False)\n\n        # Copy and expand edata\n        for c_etype in g.canonical_etypes:\n            utype, etype, vtype = c_etype\n            if utype != vtype or self.sym_new_etype:\n                rev_c_etype = (vtype, \"rev_{}\".format(etype), utype)\n                for key, feat in g.edges[c_etype].data.items():\n                    new_g.edges[c_etype].data[key] = feat\n                    if self.copy_edata:\n                        new_g.edges[rev_c_etype].data[key] = feat\n            else:\n                for key, feat in g.edges[c_etype].data.items():\n                    new_feat = (\n                        feat\n                        if self.copy_edata\n                        else F.zeros(\n                            F.shape(feat), F.dtype(feat), F.context(feat)\n                        )\n                    )\n                    new_g.edges[c_etype].data[key] = F.cat(\n                        [feat, new_feat], dim=0\n                    )\n\n        return new_g\n\n\nclass ToSimple(BaseTransform):\n    r\"\"\"Convert a graph to a simple graph without parallel edges and return a new graph.\n\n    Parameters\n    ----------\n    return_counts : str, optional\n        The edge feature name to hold the edge count in the original graph.\n    aggregator : str, optional\n        The way to coalesce features of duplicate 
edges.\n\n        * ``'arbitrary'``: select arbitrarily from one of the duplicate edges\n        * ``'sum'``: take the sum over the duplicate edges\n        * ``'mean'``: take the mean over the duplicate edges\n\n    Example\n    -------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import ToSimple\n\n    Case1: Convert a homogeneous graph to a simple graph\n\n    >>> transform = ToSimple()\n    >>> g = dgl.graph(([0, 1, 1], [1, 2, 2]))\n    >>> g.edata['w'] = torch.tensor([[0.1], [0.2], [0.3]])\n    >>> sg = transform(g)\n    >>> print(sg.edges())\n    (tensor([0, 1]), tensor([1, 2]))\n    >>> print(sg.edata['count'])\n    tensor([1, 2])\n    >>> print(sg.edata['w'])\n    tensor([[0.1000], [0.2000]])\n\n    Case2: Convert a heterogeneous graph to a simple graph\n\n    >>> g = dgl.heterograph({\n    ...     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2]),\n    ...     ('user', 'plays', 'game'): ([0, 1, 0], [1, 1, 1])\n    ... })\n    >>> sg = transform(g)\n    >>> print(sg.edges(etype='follows'))\n    (tensor([0, 1]), tensor([1, 2]))\n    >>> print(sg.edges(etype='plays'))\n    (tensor([0, 1]), tensor([1, 1]))\n    \"\"\"\n\n    def __init__(self, return_counts=\"count\", aggregator=\"arbitrary\"):\n        self.return_counts = return_counts\n        self.aggregator = aggregator\n\n    def __call__(self, g):\n        return functional.to_simple(\n            g,\n            return_counts=self.return_counts,\n            copy_edata=True,\n            aggregator=self.aggregator,\n        )\n\n\nclass LineGraph(BaseTransform):\n    r\"\"\"Return the line graph of the input graph.\n\n    The line graph :math:`L(G)` of a given graph :math:`G` is a graph where\n    the nodes in :math:`L(G)` correspond to the edges in :math:`G`. 
For a pair\n    of edges :math:`(u, v)` and :math:`(v, w)` in :math:`G`, there will be an\n    edge from the node corresponding to :math:`(u, v)` to the node corresponding to\n    :math:`(v, w)` in :math:`L(G)`.\n\n    This module only works for homogeneous graphs.\n\n    Parameters\n    ----------\n    backtracking : bool, optional\n        If False, there will be no edge from the line graph node corresponding to\n        :math:`(u, v)` to the line graph node corresponding to :math:`(v, u)`.\n        Default: True.\n\n    Example\n    -------\n\n    The following example uses PyTorch backend.\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import LineGraph\n\n    Case1: Backtracking is True\n\n    >>> transform = LineGraph()\n    >>> g = dgl.graph(([0, 1, 1], [1, 0, 2]))\n    >>> g.ndata['h'] = torch.tensor([[0.], [1.], [2.]])\n    >>> g.edata['w'] = torch.tensor([[0.], [0.1], [0.2]])\n    >>> new_g = transform(g)\n    >>> print(new_g)\n    Graph(num_nodes=3, num_edges=3,\n          ndata_schemes={'w': Scheme(shape=(1,), dtype=torch.float32)}\n          edata_schemes={})\n    >>> print(new_g.edges())\n    (tensor([0, 0, 1]), tensor([1, 2, 0]))\n\n    Case2: Backtracking is False\n\n    >>> transform = LineGraph(backtracking=False)\n    >>> new_g = transform(g)\n    >>> print(new_g.edges())\n    (tensor([0]), tensor([2]))\n    \"\"\"\n\n    def __init__(self, backtracking=True):\n        self.backtracking = backtracking\n\n    def __call__(self, g):\n        return functional.line_graph(\n            g, backtracking=self.backtracking, shared=True\n        )\n\n\nclass KHopGraph(BaseTransform):\n    r\"\"\"Return the graph whose edges connect the :math:`k`-hop neighbors of the original graph.\n\n    This module only works for homogeneous graphs.\n\n    Parameters\n    ----------\n    k : int\n        The number of hops.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> from dgl import KHopGraph\n\n    >>> transform = KHopGraph(2)\n    >>> g = dgl.graph(([0, 1], [1, 
2]))\n    >>> new_g = transform(g)\n    >>> print(new_g.edges())\n    (tensor([0]), tensor([2]))\n    \"\"\"\n\n    def __init__(self, k):\n        self.k = k\n\n    def __call__(self, g):\n        return functional.khop_graph(g, self.k)\n\n\nclass AddMetaPaths(BaseTransform):\n    r\"\"\"Add new edges to an input graph based on given metapaths, as described in\n    `Heterogeneous Graph Attention Network <https://arxiv.org/abs/1903.07293>`__.\n\n    Formally, a metapath is a path of the form\n\n    .. math::\n\n        \\mathcal{V}_1 \\xrightarrow{R_1} \\mathcal{V}_2 \\xrightarrow{R_2} \\ldots\n        \\xrightarrow{R_{\\ell-1}} \\mathcal{V}_{\\ell}\n\n    in which :math:`\\mathcal{V}_i` represents a node type and :math:`\\xrightarrow{R_j}`\n    represents a relation type connecting its two adjacent node types. The adjacency matrix\n    corresponding to the metapath is obtained by sequential multiplication of adjacency matrices\n    along the metapath.\n\n    Parameters\n    ----------\n    metapaths : dict[str, list]\n        The metapaths to add, mapping a metapath name to a metapath. For example,\n        :attr:`{'co-author': [('person', 'author', 'paper'), ('paper', 'authored by', 'person')]}`\n    keep_orig_edges : bool, optional\n        If True, it will keep the edges of the original graph. Otherwise, it will drop them.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> from dgl import AddMetaPaths\n\n    >>> transform = AddMetaPaths({\n    ...     'accepted': [('person', 'author', 'paper'), ('paper', 'accepted', 'venue')],\n    ...     'rejected': [('person', 'author', 'paper'), ('paper', 'rejected', 'venue')]\n    ... })\n    >>> g = dgl.heterograph({\n    ...     ('person', 'author', 'paper'): ([0, 0, 1], [1, 2, 2]),\n    ...     ('paper', 'accepted', 'venue'): ([1], [0]),\n    ...     ('paper', 'rejected', 'venue'): ([2], [1])\n    ... 
})\n    >>> new_g = transform(g)\n    >>> print(new_g.edges(etype=('person', 'accepted', 'venue')))\n    (tensor([0]), tensor([0]))\n    >>> print(new_g.edges(etype=('person', 'rejected', 'venue')))\n    (tensor([0, 1]), tensor([1, 1]))\n    \"\"\"\n\n    def __init__(self, metapaths, keep_orig_edges=True):\n        self.metapaths = metapaths\n        self.keep_orig_edges = keep_orig_edges\n\n    def __call__(self, g):\n        data_dict = dict()\n\n        for meta_etype, metapath in self.metapaths.items():\n            meta_g = functional.metapath_reachable_graph(g, metapath)\n            u_type = metapath[0][0]\n            v_type = metapath[-1][-1]\n            data_dict[(u_type, meta_etype, v_type)] = meta_g.edges()\n\n        if self.keep_orig_edges:\n            for c_etype in g.canonical_etypes:\n                data_dict[c_etype] = g.edges(etype=c_etype)\n            new_g = update_graph_structure(g, data_dict, copy_edata=True)\n        else:\n            new_g = update_graph_structure(g, data_dict, copy_edata=False)\n\n        return new_g\n\n\nclass Compose(BaseTransform):\n    r\"\"\"Create a transform composed of multiple transforms in sequence.\n\n    Parameters\n    ----------\n    transforms : list of Callable\n        A list of transform objects to apply in order. 
A transform object should inherit\n        :class:`~dgl.BaseTransform` and implement :func:`~dgl.BaseTransform.__call__`.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> from dgl import transforms as T\n\n    >>> g = dgl.graph(([0, 0], [1, 1]))\n    >>> transform = T.Compose([T.ToSimple(), T.AddReverse()])\n    >>> new_g = transform(g)\n    >>> print(new_g.edges())\n    (tensor([0, 1]), tensor([1, 0]))\n    \"\"\"\n\n    def __init__(self, transforms):\n        self.transforms = transforms\n\n    def __call__(self, g):\n        for transform in self.transforms:\n            g = transform(g)\n        return g\n\n    def __repr__(self):\n        args = [\"  \" + str(transform) for transform in self.transforms]\n        return self.__class__.__name__ + \"([\\n\" + \",\\n\".join(args) + \"\\n])\"\n\n\nclass GCNNorm(BaseTransform):\n    r\"\"\"Apply symmetric adjacency normalization to an input graph and save the result edge\n    weights, as described in `Semi-Supervised Classification with Graph Convolutional Networks\n    <https://arxiv.org/abs/1609.02907>`__.\n\n    For a heterogeneous graph, this only applies to symmetric canonical edge types, whose source\n    and destination node types are identical.\n\n    Parameters\n    ----------\n    eweight_name : str, optional\n        :attr:`edata` name to retrieve and store edge weights. 
The edge weights are optional.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import GCNNorm\n    >>> transform = GCNNorm()\n    >>> g = dgl.graph(([0, 1, 2], [0, 0, 1]))\n\n    Case1: Transform an unweighted graph\n\n    >>> g = transform(g)\n    >>> print(g.edata['w'])\n    tensor([0.5000, 0.7071, 0.0000])\n\n    Case2: Transform a weighted graph\n\n    >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3])\n    >>> g = transform(g)\n    >>> print(g.edata['w'])\n    tensor([0.3333, 0.6667, 0.0000])\n    \"\"\"\n\n    def __init__(self, eweight_name=\"w\"):\n        self.eweight_name = eweight_name\n\n    def calc_etype(self, c_etype, g):\n        r\"\"\"\n\n        Description\n        -----------\n        Get edge weights for an edge type.\n        \"\"\"\n        ntype = c_etype[0]\n        with g.local_scope():\n            if self.eweight_name in g.edges[c_etype].data:\n                g.update_all(\n                    fn.copy_e(self.eweight_name, \"m\"),\n                    fn.sum(\"m\", \"deg\"),\n                    etype=c_etype,\n                )\n                deg_inv_sqrt = 1.0 / F.sqrt(g.nodes[ntype].data[\"deg\"])\n                g.nodes[ntype].data[\"w\"] = F.replace_inf_with_zero(deg_inv_sqrt)\n                g.apply_edges(\n                    lambda edge: {\n                        \"w\": edge.src[\"w\"]\n                        * edge.data[self.eweight_name]\n                        * edge.dst[\"w\"]\n                    },\n                    etype=c_etype,\n                )\n            else:\n                deg = g.in_degrees(etype=c_etype)\n                deg_inv_sqrt = 1.0 / F.sqrt(F.astype(deg, F.float32))\n                g.nodes[ntype].data[\"w\"] = F.replace_inf_with_zero(deg_inv_sqrt)\n                g.apply_edges(\n                    lambda edges: {\"w\": edges.src[\"w\"] * edges.dst[\"w\"]},\n                    etype=c_etype,\n                )\n            return 
g.edges[c_etype].data[\"w\"]\n\n    def __call__(self, g):\n        result = dict()\n        for c_etype in g.canonical_etypes:\n            utype, _, vtype = c_etype\n            if utype == vtype:\n                result[c_etype] = self.calc_etype(c_etype, g)\n\n        for c_etype, eweight in result.items():\n            g.edges[c_etype].data[self.eweight_name] = eweight\n        return g\n\n\nclass PPR(BaseTransform):\n    r\"\"\"Apply personalized PageRank (PPR) to an input graph for diffusion, as introduced in\n    `The pagerank citation ranking: Bringing order to the web\n    <http://ilpubs.stanford.edu:8090/422/>`__.\n\n    A sparsification will be applied to the weighted adjacency matrix after diffusion.\n    Specifically, edges whose weight is below a threshold will be dropped.\n\n    This module only works for homogeneous graphs.\n\n    Parameters\n    ----------\n    alpha : float, optional\n        Restart probability, which commonly lies in :math:`[0.05, 0.2]`.\n    eweight_name : str, optional\n        :attr:`edata` name to retrieve and store edge weights. If it does\n        not exist in an input graph, this module initializes a weight of 1\n        for all edges. The edge weights should be a tensor of shape :math:`(E)`,\n        where E is the number of edges.\n    eps : float, optional\n        The threshold to preserve edges in sparsification after diffusion. Edges of a\n        weight smaller than eps will be dropped.\n    avg_degree : int, optional\n        The desired average node degree of the result graph. 
This is the other way to\n        control the sparsity of the result graph and will only be effective if\n        :attr:`eps` is not given.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import PPR\n\n    >>> transform = PPR(avg_degree=2)\n    >>> g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]))\n    >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])\n    >>> new_g = transform(g)\n    >>> print(new_g.edata['w'])\n    tensor([0.1500, 0.1500, 0.1500, 0.0255, 0.0163, 0.1500, 0.0638, 0.0383, 0.1500,\n            0.0510, 0.0217, 0.1500])\n    \"\"\"\n\n    def __init__(self, alpha=0.15, eweight_name=\"w\", eps=None, avg_degree=5):\n        self.alpha = alpha\n        self.eweight_name = eweight_name\n        self.eps = eps\n        self.avg_degree = avg_degree\n\n    def get_eps(self, num_nodes, mat):\n        r\"\"\"Get the threshold for graph sparsification.\"\"\"\n        if self.eps is None:\n            # Infer from self.avg_degree\n            if self.avg_degree > num_nodes:\n                return float(\"-inf\")\n            sorted_weights = torch.sort(mat.flatten(), descending=True).values\n            return sorted_weights[self.avg_degree * num_nodes - 1]\n        else:\n            return self.eps\n\n    def __call__(self, g):\n        # Step1: PPR diffusion\n        # (α - 1) A\n        device = g.device\n        eweight = (self.alpha - 1) * g.edata.get(\n            self.eweight_name, F.ones((g.num_edges(),), F.float32, device)\n        )\n        num_nodes = g.num_nodes()\n        mat = F.zeros((num_nodes, num_nodes), F.float32, device)\n        src, dst = g.edges()\n        src, dst = F.astype(src, F.int64), F.astype(dst, F.int64)\n        mat[dst, src] = eweight\n        # I_n + (α - 1) A\n        nids = F.astype(g.nodes(), F.int64)\n        mat[nids, nids] = mat[nids, nids] + 1\n        # α (I_n + (α - 1) A)^-1\n        diff_mat = self.alpha * F.inverse(mat)\n\n        # Step2: sparsification\n        
num_nodes = g.num_nodes()\n        eps = self.get_eps(num_nodes, diff_mat)\n        dst, src = (diff_mat >= eps).nonzero(as_tuple=False).t()\n        data_dict = {g.canonical_etypes[0]: (src, dst)}\n        new_g = update_graph_structure(g, data_dict, copy_edata=False)\n        new_g.edata[self.eweight_name] = diff_mat[dst, src]\n\n        return new_g\n\n\ndef is_bidirected(g):\n    \"\"\"Return whether the graph is a bidirected graph.\n\n    A graph is bidirected if for any edge :math:`(u, v)` in :math:`G` with weight :math:`w`,\n    there exists an edge :math:`(v, u)` in :math:`G` with the same weight.\n    \"\"\"\n    src, dst = g.edges()\n    num_nodes = g.num_nodes()\n\n    # Sort first by src then dst\n    idx_src_dst = src * num_nodes + dst\n    perm_src_dst = F.argsort(idx_src_dst, dim=0, descending=False)\n    src1, dst1 = src[perm_src_dst], dst[perm_src_dst]\n\n    # Sort first by dst then src\n    idx_dst_src = dst * num_nodes + src\n    perm_dst_src = F.argsort(idx_dst_src, dim=0, descending=False)\n    src2, dst2 = src[perm_dst_src], dst[perm_dst_src]\n\n    return F.allclose(src1, dst2) and F.allclose(src2, dst1)\n\n\n# pylint: disable=C0103\nclass HeatKernel(BaseTransform):\n    r\"\"\"Apply heat kernel to an input graph for diffusion, as introduced in\n    `Diffusion kernels on graphs and other discrete structures\n    <https://www.ml.cmu.edu/research/dap-papers/kondor-diffusion-kernels.pdf>`__.\n\n    A sparsification will be applied to the weighted adjacency matrix after diffusion.\n    Specifically, edges whose weight is below a threshold will be dropped.\n\n    This module only works for homogeneous graphs.\n\n    Parameters\n    ----------\n    t : float, optional\n        Diffusion time, which commonly lies in :math:`[2, 10]`.\n    eweight_name : str, optional\n        :attr:`edata` name to retrieve and store edge weights. If it does\n        not exist in an input graph, this module initializes a weight of 1\n        for all edges. 
The edge weights should be a tensor of shape :math:`(E)`,\n        where E is the number of edges.\n    eps : float, optional\n        The threshold to preserve edges in sparsification after diffusion. Edges of a\n        weight smaller than eps will be dropped.\n    avg_degree : int, optional\n        The desired average node degree of the result graph. This is the other way to\n        control the sparsity of the result graph and will only be effective if\n        :attr:`eps` is not given.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import HeatKernel\n\n    >>> transform = HeatKernel(avg_degree=2)\n    >>> g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]))\n    >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])\n    >>> new_g = transform(g)\n    >>> print(new_g.edata['w'])\n    tensor([0.1353, 0.1353, 0.1353, 0.0541, 0.0406, 0.1353, 0.1353, 0.0812, 0.1353,\n            0.1083, 0.0541, 0.1353])\n    \"\"\"\n\n    def __init__(self, t=2.0, eweight_name=\"w\", eps=None, avg_degree=5):\n        self.t = t\n        self.eweight_name = eweight_name\n        self.eps = eps\n        self.avg_degree = avg_degree\n\n    def get_eps(self, num_nodes, mat):\n        r\"\"\"Get the threshold for graph sparsification.\"\"\"\n        if self.eps is None:\n            # Infer from self.avg_degree\n            if self.avg_degree > num_nodes:\n                return float(\"-inf\")\n            sorted_weights = torch.sort(mat.flatten(), descending=True).values\n            return sorted_weights[self.avg_degree * num_nodes - 1]\n        else:\n            return self.eps\n\n    def __call__(self, g):\n        # Step1: heat kernel diffusion\n        # t A\n        device = g.device\n        eweight = self.t * g.edata.get(\n            self.eweight_name, F.ones((g.num_edges(),), F.float32, device)\n        )\n        num_nodes = g.num_nodes()\n        mat = F.zeros((num_nodes, num_nodes), F.float32, device)\n        src, dst = 
g.edges()\n        src, dst = F.astype(src, F.int64), F.astype(dst, F.int64)\n        mat[dst, src] = eweight\n        # t (A - I_n)\n        nids = F.astype(g.nodes(), F.int64)\n        mat[nids, nids] = mat[nids, nids] - self.t\n\n        if is_bidirected(g):\n            e, V = torch.linalg.eigh(mat, UPLO=\"U\")\n            diff_mat = V @ torch.diag(e.exp()) @ V.t()\n        else:\n            diff_mat_np = expm(mat.cpu().numpy())\n            diff_mat = torch.Tensor(diff_mat_np).to(device)\n\n        # Step2: sparsification\n        num_nodes = g.num_nodes()\n        eps = self.get_eps(num_nodes, diff_mat)\n        dst, src = (diff_mat >= eps).nonzero(as_tuple=False).t()\n        data_dict = {g.canonical_etypes[0]: (src, dst)}\n        new_g = update_graph_structure(g, data_dict, copy_edata=False)\n        new_g.edata[self.eweight_name] = diff_mat[dst, src]\n\n        return new_g\n\n\nclass GDC(BaseTransform):\n    r\"\"\"Apply graph diffusion convolution (GDC) to an input graph, as introduced in\n    `Diffusion Improves Graph Learning <https://www.in.tum.de/daml/gdc/>`__.\n\n    A sparsification will be applied to the weighted adjacency matrix after diffusion.\n    Specifically, edges whose weight is below a threshold will be dropped.\n\n    This module only works for homogeneous graphs.\n\n    Parameters\n    ----------\n    coefs : list[float], optional\n        List of coefficients. :math:`\\theta_k` for each power of the adjacency matrix.\n    eweight_name : str, optional\n        :attr:`edata` name to retrieve and store edge weights. If it does\n        not exist in an input graph, this module initializes a weight of 1\n        for all edges. The edge weights should be a tensor of shape :math:`(E)`,\n        where E is the number of edges.\n    eps : float, optional\n        The threshold to preserve edges in sparsification after diffusion. 
Edges of a\n        weight smaller than eps will be dropped.\n    avg_degree : int, optional\n        The desired average node degree of the result graph. This is the other way to\n        control the sparsity of the result graph and will only be effective if\n        :attr:`eps` is not given.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import GDC\n\n    >>> transform = GDC([0.3, 0.2, 0.1], avg_degree=2)\n    >>> g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]))\n    >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])\n    >>> new_g = transform(g)\n    >>> print(new_g.edata['w'])\n    tensor([0.3000, 0.3000, 0.0200, 0.3000, 0.0400, 0.3000, 0.1000, 0.0600, 0.3000,\n            0.0800, 0.0200, 0.3000])\n    \"\"\"\n\n    def __init__(self, coefs, eweight_name=\"w\", eps=None, avg_degree=5):\n        self.coefs = coefs\n        self.eweight_name = eweight_name\n        self.eps = eps\n        self.avg_degree = avg_degree\n\n    def get_eps(self, num_nodes, mat):\n        r\"\"\"Get the threshold for graph sparsification.\"\"\"\n        if self.eps is None:\n            # Infer from self.avg_degree\n            if self.avg_degree > num_nodes:\n                return float(\"-inf\")\n            sorted_weights = torch.sort(mat.flatten(), descending=True).values\n            return sorted_weights[self.avg_degree * num_nodes - 1]\n        else:\n            return self.eps\n\n    def __call__(self, g):\n        # Step1: diffusion\n        # A\n        device = g.device\n        eweight = g.edata.get(\n            self.eweight_name, F.ones((g.num_edges(),), F.float32, device)\n        )\n        num_nodes = g.num_nodes()\n        adj = F.zeros((num_nodes, num_nodes), F.float32, device)\n        src, dst = g.edges()\n        src, dst = F.astype(src, F.int64), F.astype(dst, F.int64)\n        adj[dst, src] = eweight\n\n        # theta_0 I_n\n        mat = torch.eye(num_nodes, device=device)\n        diff_mat = self.coefs[0] 
* mat\n        # add theta_k A^k\n        for coef in self.coefs[1:]:\n            mat = mat @ adj\n            diff_mat += coef * mat\n\n        # Step2: sparsification\n        num_nodes = g.num_nodes()\n        eps = self.get_eps(num_nodes, diff_mat)\n        dst, src = (diff_mat >= eps).nonzero(as_tuple=False).t()\n        data_dict = {g.canonical_etypes[0]: (src, dst)}\n        new_g = update_graph_structure(g, data_dict, copy_edata=False)\n        new_g.edata[self.eweight_name] = diff_mat[dst, src]\n\n        return new_g\n\n\nclass NodeShuffle(BaseTransform):\n    r\"\"\"Randomly shuffle the nodes.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import NodeShuffle\n\n    >>> transform = NodeShuffle()\n    >>> g = dgl.graph(([0, 1], [1, 2]))\n    >>> g.ndata['h1'] = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])\n    >>> g.ndata['h2'] = torch.tensor([[7., 8.], [9., 10.], [11., 12.]])\n    >>> g = transform(g)\n    >>> print(g.ndata['h1'])\n    tensor([[5., 6.],\n            [3., 4.],\n            [1., 2.]])\n    >>> print(g.ndata['h2'])\n    tensor([[11., 12.],\n            [ 9., 10.],\n            [ 7.,  8.]])\n    \"\"\"\n\n    def __call__(self, g):\n        g = g.clone()\n        for ntype in g.ntypes:\n            nids = F.astype(g.nodes(ntype), F.int64)\n            perm = F.rand_shuffle(nids)\n            for key, feat in g.nodes[ntype].data.items():\n                g.nodes[ntype].data[key] = feat[perm]\n        return g\n\n\n# pylint: disable=C0103\nclass DropNode(BaseTransform):\n    r\"\"\"Randomly drop nodes, as described in\n    `Graph Contrastive Learning with Augmentations <https://arxiv.org/abs/2010.13902>`__.\n\n    Parameters\n    ----------\n    p : float, optional\n        Probability of a node to be dropped.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import DropNode\n\n    >>> transform = DropNode()\n    >>> g = dgl.rand_graph(5, 20)\n    >>> 
g.ndata['h'] = torch.arange(g.num_nodes())\n    >>> g.edata['h'] = torch.arange(g.num_edges())\n    >>> new_g = transform(g)\n    >>> print(new_g)\n    Graph(num_nodes=3, num_edges=7,\n          ndata_schemes={'h': Scheme(shape=(), dtype=torch.int64)}\n          edata_schemes={'h': Scheme(shape=(), dtype=torch.int64)})\n    >>> print(new_g.ndata['h'])\n    tensor([0, 1, 2])\n    >>> print(new_g.edata['h'])\n    tensor([0, 6, 14, 5, 17, 3, 11])\n    \"\"\"\n\n    def __init__(self, p=0.5):\n        self.p = p\n        self.dist = Bernoulli(p)\n\n    def __call__(self, g):\n        g = g.clone()\n\n        # Fast path\n        if self.p == 0:\n            return g\n\n        for ntype in g.ntypes:\n            samples = self.dist.sample(torch.Size([g.num_nodes(ntype)]))\n            nids_to_remove = g.nodes(ntype)[samples.bool().to(g.device)]\n            g.remove_nodes(nids_to_remove, ntype=ntype)\n        return g\n\n\n# pylint: disable=C0103\nclass DropEdge(BaseTransform):\n    r\"\"\"Randomly drop edges, as described in\n    `DropEdge: Towards Deep Graph Convolutional Networks on Node Classification\n    <https://arxiv.org/abs/1907.10903>`__ and `Graph Contrastive Learning with Augmentations\n    <https://arxiv.org/abs/2010.13902>`__.\n\n    Parameters\n    ----------\n    p : float, optional\n        Probability of an edge to be dropped.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import DropEdge\n\n    >>> transform = DropEdge()\n    >>> g = dgl.rand_graph(5, 20)\n    >>> g.edata['h'] = torch.arange(g.num_edges())\n    >>> new_g = transform(g)\n    >>> print(new_g)\n    Graph(num_nodes=5, num_edges=12,\n          ndata_schemes={}\n          edata_schemes={'h': Scheme(shape=(), dtype=torch.int64)})\n    >>> print(new_g.edata['h'])\n    tensor([0, 1, 3, 7, 8, 10, 11, 12, 13, 15, 18, 19])\n    \"\"\"\n\n    def __init__(self, p=0.5):\n        self.p = p\n        self.dist = Bernoulli(p)\n\n    def __call__(self, g):\n  
      g = g.clone()\n\n        # Fast path\n        if self.p == 0:\n            return g\n\n        for c_etype in g.canonical_etypes:\n            samples = self.dist.sample(torch.Size([g.num_edges(c_etype)]))\n            eids_to_remove = g.edges(form=\"eid\", etype=c_etype)[\n                samples.bool().to(g.device)\n            ]\n            g.remove_edges(eids_to_remove, etype=c_etype)\n        return g\n\n\nclass AddEdge(BaseTransform):\n    r\"\"\"Randomly add edges, as described in `Graph Contrastive Learning with Augmentations\n    <https://arxiv.org/abs/2010.13902>`__.\n\n    Parameters\n    ----------\n    ratio : float, optional\n        Number of edges to add divided by the number of existing edges.\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> from dgl import AddEdge\n\n    >>> transform = AddEdge()\n    >>> g = dgl.rand_graph(5, 20)\n    >>> new_g = transform(g)\n    >>> print(new_g.num_edges())\n    24\n    \"\"\"\n\n    def __init__(self, ratio=0.2):\n        self.ratio = ratio\n\n    def __call__(self, g):\n        # Fast path\n        if self.ratio == 0.0:\n            return g\n\n        device = g.device\n        idtype = g.idtype\n        g = g.clone()\n        for c_etype in g.canonical_etypes:\n            utype, _, vtype = c_etype\n            num_edges_to_add = int(g.num_edges(c_etype) * self.ratio)\n            src = F.randint(\n                [num_edges_to_add],\n                idtype,\n                device,\n                low=0,\n                high=g.num_nodes(utype),\n            )\n            dst = F.randint(\n                [num_edges_to_add],\n                idtype,\n                device,\n                low=0,\n                high=g.num_nodes(vtype),\n            )\n            g.add_edges(src, dst, etype=c_etype)\n        return g\n\n\nclass SIGNDiffusion(BaseTransform):\n    r\"\"\"The diffusion operator from `SIGN: Scalable Inception Graph Neural Networks\n    
<https://arxiv.org/abs/2004.11198>`__\n\n    It performs node feature diffusion with :math:`TX, \\cdots, T^{k}X`, where :math:`T`\n    is a diffusion matrix and :math:`X` is the input node features.\n\n    Specifically, this module provides four options for :math:`T`.\n\n    **raw**: raw adjacency matrix :math:`A`\n\n    **rw**: random walk (row-normalized) adjacency matrix :math:`D^{-1}A`, where\n    :math:`D` is the degree matrix.\n\n    **gcn**: symmetrically normalized adjacency matrix used by\n    `GCN <https://arxiv.org/abs/1609.02907>`__, :math:`D^{-1/2}AD^{-1/2}`\n\n    **ppr**: approximate personalized PageRank used by\n    `APPNP <https://arxiv.org/abs/1810.05997>`__\n\n    .. math::\n        H^{0} &= X\n\n        H^{l+1} &= (1-\\alpha)\\left(D^{-1/2}AD^{-1/2} H^{l}\\right) + \\alpha X\n\n    This module only works for homogeneous graphs.\n\n    Parameters\n    ----------\n    k : int\n        The maximum number of times for node feature diffusion.\n    in_feat_name : str, optional\n        :attr:`g.ndata[{in_feat_name}]` should store the input node features. Default: 'feat'\n    out_feat_name : str, optional\n        :attr:`g.ndata[{out_feat_name}_i]` will store the result of diffusing\n        input node features for i times. Default: 'out_feat'\n    eweight_name : str, optional\n        Name to retrieve edge weights from :attr:`g.edata`. Default: None,\n        treating the graph as unweighted.\n    diffuse_op : str, optional\n        The diffusion operator to use, which can be 'raw', 'rw', 'gcn', or 'ppr'.\n        Default: 'raw'\n    alpha : float, optional\n        Restart probability if :attr:`diffuse_op` is :attr:`'ppr'`,\n        which commonly lies in :math:`[0.05, 0.2]`. 
Default: 0.2\n\n    Example\n    -------\n\n    >>> import dgl\n    >>> import torch\n    >>> from dgl import SIGNDiffusion\n\n    >>> transform = SIGNDiffusion(k=2, eweight_name='w')\n    >>> num_nodes = 5\n    >>> num_edges = 20\n    >>> g = dgl.rand_graph(num_nodes, num_edges)\n    >>> g.ndata['feat'] = torch.randn(num_nodes, 10)\n    >>> g.edata['w'] = torch.randn(num_edges)\n    >>> transform(g)\n    Graph(num_nodes=5, num_edges=20,\n          ndata_schemes={'feat': Scheme(shape=(10,), dtype=torch.float32),\n                         'out_feat_1': Scheme(shape=(10,), dtype=torch.float32),\n                         'out_feat_2': Scheme(shape=(10,), dtype=torch.float32)}\n          edata_schemes={'w': Scheme(shape=(), dtype=torch.float32)})\n    \"\"\"\n\n    def __init__(\n        self,\n        k,\n        in_feat_name=\"feat\",\n        out_feat_name=\"out_feat\",\n        eweight_name=None,\n        diffuse_op=\"raw\",\n        alpha=0.2,\n    ):\n        self.k = k\n        self.in_feat_name = in_feat_name\n        self.out_feat_name = out_feat_name\n        self.eweight_name = eweight_name\n        self.diffuse_op = diffuse_op\n        self.alpha = alpha\n\n        if diffuse_op == \"raw\":\n            self.diffuse = self.raw\n        elif diffuse_op == \"rw\":\n            self.diffuse = self.rw\n        elif diffuse_op == \"gcn\":\n            self.diffuse = self.gcn\n        elif diffuse_op == \"ppr\":\n            self.diffuse = self.ppr\n        else:\n            raise DGLError(\n                \"Expect diffuse_op to be from ['raw', 'rw', 'gcn', 'ppr'], \\\n                got {}\".format(\n                    diffuse_op\n                )\n            )\n\n    def __call__(self, g):\n        feat_list = self.diffuse(g)\n\n        for i in range(1, self.k + 1):\n            g.ndata[self.out_feat_name + \"_\" + str(i)] = feat_list[i - 1]\n        return g\n\n    def raw(self, g):\n        use_eweight = False\n        if (self.eweight_name is not 
None) and self.eweight_name in g.edata:\n            use_eweight = True\n\n        feat_list = []\n        with g.local_scope():\n            if use_eweight:\n                message_func = fn.u_mul_e(\n                    self.in_feat_name, self.eweight_name, \"m\"\n                )\n            else:\n                message_func = fn.copy_u(self.in_feat_name, \"m\")\n            for _ in range(self.k):\n                g.update_all(message_func, fn.sum(\"m\", self.in_feat_name))\n                feat_list.append(g.ndata[self.in_feat_name])\n        return feat_list\n\n    def rw(self, g):\n        use_eweight = False\n        if (self.eweight_name is not None) and self.eweight_name in g.edata:\n            use_eweight = True\n\n        feat_list = []\n        with g.local_scope():\n            g.ndata[\"h\"] = g.ndata[self.in_feat_name]\n            if use_eweight:\n                message_func = fn.u_mul_e(\"h\", self.eweight_name, \"m\")\n                reduce_func = fn.sum(\"m\", \"h\")\n                # Compute the diagonal entries of D from the weighted A\n                g.update_all(\n                    fn.copy_e(self.eweight_name, \"m\"), fn.sum(\"m\", \"z\")\n                )\n            else:\n                message_func = fn.copy_u(\"h\", \"m\")\n                reduce_func = fn.mean(\"m\", \"h\")\n\n            for _ in range(self.k):\n                g.update_all(message_func, reduce_func)\n                if use_eweight:\n                    g.ndata[\"h\"] = g.ndata[\"h\"] / F.reshape(\n                        g.ndata[\"z\"], (g.num_nodes(), 1)\n                    )\n                feat_list.append(g.ndata[\"h\"])\n        return feat_list\n\n    def gcn(self, g):\n        feat_list = []\n        with g.local_scope():\n            if self.eweight_name is None:\n                eweight_name = \"w\"\n                if eweight_name in g.edata:\n                    g.edata.pop(eweight_name)\n            else:\n                eweight_name = 
self.eweight_name\n\n            transform = GCNNorm(eweight_name=eweight_name)\n            transform(g)\n\n            for _ in range(self.k):\n                g.update_all(\n                    fn.u_mul_e(self.in_feat_name, eweight_name, \"m\"),\n                    fn.sum(\"m\", self.in_feat_name),\n                )\n                feat_list.append(g.ndata[self.in_feat_name])\n        return feat_list\n\n    def ppr(self, g):\n        feat_list = []\n        with g.local_scope():\n            if self.eweight_name is None:\n                eweight_name = \"w\"\n                if eweight_name in g.edata:\n                    g.edata.pop(eweight_name)\n            else:\n                eweight_name = self.eweight_name\n            transform = GCNNorm(eweight_name=eweight_name)\n            transform(g)\n\n            in_feat = g.ndata[self.in_feat_name]\n            for _ in range(self.k):\n                g.update_all(\n                    fn.u_mul_e(self.in_feat_name, eweight_name, \"m\"),\n                    fn.sum(\"m\", self.in_feat_name),\n                )\n                g.ndata[self.in_feat_name] = (1 - self.alpha) * g.ndata[\n                    self.in_feat_name\n                ] + self.alpha * in_feat\n                feat_list.append(g.ndata[self.in_feat_name])\n        return feat_list\n\n\nclass ToLevi(BaseTransform):\n    r\"\"\"This function transforms the original graph to its heterogeneous Levi graph,\n    by converting edges to intermediate nodes. Only homogeneous directed graphs are\n    supported.\n\n    Example\n    -------\n    >>> import dgl\n    >>> import torch as th\n    >>> from dgl import ToLevi\n\n    >>> transform = ToLevi()\n    >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))\n    >>> g.ndata['h'] = th.randn((g.num_nodes(), 2))\n    >>> g.edata['w'] = th.randn((g.num_edges(), 2))\n    >>> lg = transform(g)\n    >>> lg\n    Graph(num_nodes={'edge': 4, 'node': 4},\n          num_edges={('edge', 'e2n', 'node'): 4,\n                     
('node', 'n2e', 'edge'): 4},\n          metagraph=[('edge', 'node', 'e2n'),\n                     ('node', 'edge', 'n2e')])\n    >>> lg.nodes('node')\n    tensor([0, 1, 2, 3])\n    >>> lg.nodes('edge')\n    tensor([0, 1, 2, 3])\n    >>> lg.nodes['node'].data['h'].shape\n    torch.Size([4, 2])\n    >>> lg.nodes['edge'].data['w'].shape\n    torch.Size([4, 2])\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def __call__(self, g):\n        r\"\"\"\n        Parameters\n        ----------\n        g : DGLGraph\n            The input graph, should be a homogeneous directed graph.\n\n        Returns\n        -------\n        DGLGraph\n            The Levi graph of input, will be a heterogeneous graph, where nodes of\n            ntypes ``'node'`` and ``'edge'`` have corresponding IDs of nodes and edges\n            in the original graph. Edge features of the input graph are copied to\n            corresponding new nodes of ntype ``'edge'``.\n        \"\"\"\n        device = g.device\n        idtype = g.idtype\n\n        edge_list = g.edges()\n        n2e = edge_list[0], F.arange(0, g.num_edges(), idtype, device)\n        e2n = F.arange(0, g.num_edges(), idtype, device), edge_list[1]\n        graph_data = {\n            (\"node\", \"n2e\", \"edge\"): n2e,\n            (\"edge\", \"e2n\", \"node\"): e2n,\n        }\n        levi_g = convert.heterograph(graph_data, idtype=idtype, device=device)\n\n        # Copy ndata and edata\n        # Since the node types in dgl.heterograph are in alphabetical order\n        # ('edge' < 'node'), edge_frames should be in front of node_frames.\n        node_frames = utils.extract_node_subframes(g, nodes_or_device=device)\n        edge_frames = utils.extract_edge_subframes(g, edges_or_device=device)\n        utils.set_new_frames(levi_g, node_frames=edge_frames + node_frames)\n\n        return levi_g\n\n\nclass SVDPE(BaseTransform):\n    r\"\"\"SVD-based Positional Encoding, as introduced in\n    `Global Self-Attention as a 
Replacement for Graph Convolution\n    <https://arxiv.org/pdf/2108.03348.pdf>`__\n\n    This function computes the largest :math:`k` singular values and\n    corresponding left and right singular vectors to form positional encodings,\n    which could be stored in ndata.\n\n    Parameters\n    ----------\n    k : int\n        Number of largest singular values and corresponding singular vectors\n        used for positional encoding.\n    feat_name : str, optional\n        Name to store the computed positional encodings in ndata.\n        Default : ``svd_pe``\n    padding : bool, optional\n        If False, raise an error when :math:`k > N`,\n        where :math:`N` is the number of nodes in :attr:`g`.\n        If True, add zero paddings in the end of encodings when :math:`k > N`.\n        Default : False.\n    random_flip : bool, optional\n        If True, randomly flip the signs of encoding vectors.\n        Proposed to be activated during training for better generalization.\n        Default : True.\n\n    Example\n    -------\n    >>> import dgl\n    >>> from dgl import SVDPE\n\n    >>> transform = SVDPE(k=2, feat_name=\"svd_pe\")\n    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))\n    >>> g_ = transform(g)\n    >>> print(g_.ndata['svd_pe'])\n    tensor([[-6.3246e-01, -1.1373e-07, -6.3246e-01,  0.0000e+00],\n            [-6.3246e-01,  7.6512e-01, -6.3246e-01, -7.6512e-01],\n            [ 6.3246e-01,  4.7287e-01,  6.3246e-01, -4.7287e-01],\n            [-6.3246e-01, -7.6512e-01, -6.3246e-01,  7.6512e-01],\n            [ 6.3246e-01, -4.7287e-01,  6.3246e-01,  4.7287e-01]])\n    \"\"\"\n\n    def __init__(self, k, feat_name=\"svd_pe\", padding=False, random_flip=True):\n        self.k = k\n        self.feat_name = feat_name\n        self.padding = padding\n        self.random_flip = random_flip\n\n    def __call__(self, g):\n        encoding = functional.svd_pe(\n            g, k=self.k, padding=self.padding, random_flip=self.random_flip\n        
)\n        g.ndata[self.feat_name] = F.copy_to(encoding, g.device)\n\n        return g\n"
  },
  {
    "path": "python/dgl/transforms/to_block.py",
    "content": "#   Copyright (c) 2023, DGL Team\n#\n#   Licensed under the Apache License, Version 2.0 (the \"License\");\n#   you may not use this file except in compliance with the License.\n#   You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#   Unless required by applicable law or agreed to in writing, software\n#   distributed under the License is distributed on an \"AS IS\" BASIS,\n#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#   See the License for the specific language governing permissions and\n#   limitations under the License.\n\n\"\"\"To block method.\"\"\"\n\nfrom collections import defaultdict\nfrom collections.abc import Mapping\n\nfrom .. import backend as F, utils\nfrom ..base import DGLError\nfrom ..heterograph import DGLBlock\nfrom .._ffi.capi import *\n\n__all__ = [\"to_block\"]\n\n\ndef to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None):\n    \"\"\"Convert a graph into a bipartite-structured *block* for message passing.\n\n    A block is a graph consisting of two sets of nodes: the\n    *source* nodes and *destination* nodes.  The source and destination nodes can have multiple\n    node types.  All the edges connect from source nodes to destination nodes.\n\n    Specifically, the source nodes and destination nodes will have the same node types as the\n    ones in the original graph.  DGL maps each edge ``(u, v)`` with edge type\n    ``(utype, etype, vtype)`` in the original graph to the edge with type\n    ``etype`` connecting from node ID ``u`` of type ``utype`` in the source side to node\n    ID ``v`` of type ``vtype`` in the destination side.\n\n    For blocks returned by :func:`to_block`, the destination nodes of the block will only\n    contain the nodes that have at least one inbound edge of any type.  
The source nodes\n    of the block will only contain the nodes that appear in the destination nodes, as well\n    as the nodes that have at least one outbound edge connecting to one of the destination nodes.\n\n    The destination nodes are specified by the :attr:`dst_nodes` argument if it is not None.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.  Can be either on CPU or GPU.\n    dst_nodes : Tensor or dict[str, Tensor], optional\n        The list of destination nodes.\n\n        If a tensor is given, the graph must have only one node type.\n\n        If given, it must be a superset of all the nodes that have at least one inbound\n        edge.  An error will be raised otherwise.\n    include_dst_in_src : bool\n        If False, do not include destination nodes in source nodes.\n\n        (Default: True)\n\n    src_nodes : Tensor or dict[str, Tensor], optional\n        The list of source nodes (and prefixed by destination nodes if\n        `include_dst_in_src` is True).\n\n        If a tensor is given, the graph must have only one node type.\n\n    Returns\n    -------\n    DGLBlock\n        The new graph describing the block.\n\n        The node IDs induced for each type in both sides would be stored in feature\n        ``dgl.NID``.\n\n        The edge IDs induced for each type would be stored in feature ``dgl.EID``.\n\n    Raises\n    ------\n    DGLError\n        If :attr:`dst_nodes` is specified but it is not a superset of all the nodes that\n        have at least one inbound edge.\n    ValueError\n        If :attr:`dst_nodes` is not None, and :attr:`g` and :attr:`dst_nodes`\n        are not in the same context.\n\n    Notes\n    -----\n    :func:`to_block` is most commonly used in customizing neighborhood sampling\n    for stochastic training on a large graph.  
Please refer to the user guide\n    :ref:`guide-minibatch` for a more thorough discussion about the methodology\n    of stochastic training.\n\n    See also :func:`create_block` for more flexible construction of blocks.\n\n    Examples\n    --------\n    Converting a homogeneous graph to a block as described above:\n\n    >>> g = dgl.graph(([1, 2], [2, 3]))\n    >>> block = dgl.to_block(g, torch.LongTensor([3, 2]))\n\n    The destination nodes would be exactly the same as the ones given: [3, 2].\n\n    >>> induced_dst = block.dstdata[dgl.NID]\n    >>> induced_dst\n    tensor([3, 2])\n\n    The first few source nodes would also be exactly the same as\n    the ones given.  The rest of the nodes are the ones necessary for message passing\n    into nodes 3, 2.  This means that the node 1 would be included.\n\n    >>> induced_src = block.srcdata[dgl.NID]\n    >>> induced_src\n    tensor([3, 2, 1])\n\n    You can notice that the first two nodes are identical to the given nodes as well as\n    the destination nodes.\n\n    The induced edges can also be obtained by the following:\n\n    >>> block.edata[dgl.EID]\n    tensor([2, 1])\n\n    This indicates that edge (2, 3) and (1, 2) are included in the result graph.  You can\n    verify that the first edge in the block indeed maps to the edge (2, 3), and the\n    second edge in the block indeed maps to the edge (1, 2):\n\n    >>> src, dst = block.edges(order='eid')\n    >>> induced_src[src], induced_dst[dst]\n    (tensor([2, 1]), tensor([3, 2]))\n\n    The destination nodes specified must be a superset of the nodes that have edges connecting\n    to them.  
For example, the following will raise an error since the destination nodes\n    do not contain node 3, which has an edge connecting to it.\n\n    >>> g = dgl.graph(([1, 2], [2, 3]))\n    >>> dgl.to_block(g, torch.LongTensor([2]))     # error\n\n    Converting a heterogeneous graph to a block is similar, except that when specifying\n    the destination nodes, you have to give a dict:\n\n    >>> g = dgl.heterograph({('A', '_E', 'B'): ([1, 2], [2, 3])})\n\n    If you don't specify any node of type A on the destination side, the node type ``A``\n    in the block would have zero nodes on the destination side.\n\n    >>> block = dgl.to_block(g, {'B': torch.LongTensor([3, 2])})\n    >>> block.number_of_dst_nodes('A')\n    0\n    >>> block.number_of_dst_nodes('B')\n    2\n    >>> block.dstnodes['B'].data[dgl.NID]\n    tensor([3, 2])\n\n    The source side would contain all the nodes on the destination side:\n\n    >>> block.srcnodes['B'].data[dgl.NID]\n    tensor([3, 2])\n\n    As well as all the nodes that have connections to the nodes on the destination side:\n\n    >>> block.srcnodes['A'].data[dgl.NID]\n    tensor([2, 1])\n\n    See also\n    --------\n    create_block\n    \"\"\"\n    if dst_nodes is None:\n        # Find all nodes that appeared as destinations\n        dst_nodes = defaultdict(list)\n        for etype in g.canonical_etypes:\n            _, dst = g.edges(etype=etype)\n            dst_nodes[etype[2]].append(dst)\n        dst_nodes = {\n            ntype: F.unique(F.cat(values, 0))\n            for ntype, values in dst_nodes.items()\n        }\n    elif not isinstance(dst_nodes, Mapping):\n        # dst_nodes is a Tensor, check if the g has only one type.\n        if len(g.ntypes) > 1:\n            raise DGLError(\n                \"Graph has more than one node type; please specify a dict for dst_nodes.\"\n            )\n        dst_nodes = {g.ntypes[0]: dst_nodes}\n\n    dst_node_ids = [\n        utils.toindex(dst_nodes.get(ntype, []), 
g._idtype_str).tousertensor(\n            ctx=F.to_backend_ctx(g._graph.ctx)\n        )\n        for ntype in g.ntypes\n    ]\n    dst_node_ids_nd = [F.to_dgl_nd(nodes) for nodes in dst_node_ids]\n\n    for d in dst_node_ids_nd:\n        if g._graph.ctx != d.ctx:\n            raise ValueError(\"g and dst_nodes need to have the same context.\")\n\n    src_node_ids = None\n    src_node_ids_nd = None\n    if src_nodes is not None and not isinstance(src_nodes, Mapping):\n        # src_nodes is a Tensor, check if the g has only one type.\n        if len(g.ntypes) > 1:\n            raise DGLError(\n                \"Graph has more than one node type; please specify a dict for src_nodes.\"\n            )\n        src_nodes = {g.ntypes[0]: src_nodes}\n        src_node_ids = [\n            F.copy_to(\n                F.tensor(src_nodes.get(ntype, []), dtype=g.idtype),\n                F.to_backend_ctx(g._graph.ctx),\n            )\n            for ntype in g.ntypes\n        ]\n        src_node_ids_nd = [F.to_dgl_nd(nodes) for nodes in src_node_ids]\n\n        for d in src_node_ids_nd:\n            if g._graph.ctx != d.ctx:\n                raise ValueError(\n                    \"g and src_nodes need to have the same context.\"\n                )\n    else:\n        # use an empty list to signal we need to generate it\n        src_node_ids_nd = []\n\n    new_graph_index, src_nodes_ids_nd, induced_edges_nd = _CAPI_DGLToBlock(\n        g._graph, dst_node_ids_nd, include_dst_in_src, src_node_ids_nd\n    )\n\n    # The new graph duplicates the original node types to SRC and DST sets.\n    new_ntypes = (g.ntypes, g.ntypes)\n    new_graph = DGLBlock(new_graph_index, new_ntypes, g.etypes)\n    assert new_graph.is_unibipartite  # sanity check\n\n    src_node_ids = [F.from_dgl_nd(src) for src in src_nodes_ids_nd]\n    edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges_nd]\n\n    node_frames = utils.extract_node_subframes_for_block(\n        g, src_node_ids, dst_node_ids\n    
)\n    edge_frames = utils.extract_edge_subframes(g, edge_ids)\n    utils.set_new_frames(\n        new_graph, node_frames=node_frames, edge_frames=edge_frames\n    )\n\n    return new_graph\n"
  },
  {
    "path": "python/dgl/traversal.py",
    "content": "\"\"\"Module for graph traversal methods.\"\"\"\nfrom __future__ import absolute_import\n\nfrom . import backend as F, utils\nfrom ._ffi.function import _init_api\nfrom .heterograph import DGLGraph\n\n__all__ = [\n    \"bfs_nodes_generator\",\n    \"bfs_edges_generator\",\n    \"topological_nodes_generator\",\n    \"dfs_edges_generator\",\n    \"dfs_labeled_edges_generator\",\n]\n\n\ndef bfs_nodes_generator(graph, source, reverse=False):\n    \"\"\"Node frontiers generator using breadth-first search.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph object.\n    source : list, tensor of nodes\n        Source nodes.\n    reverse : bool, default False\n        If True, traverse following the in-edge direction.\n\n    Returns\n    -------\n    list of node frontiers\n        Each node frontier is a list or tensor of node ids.\n\n    Examples\n    --------\n    Given a graph (directed, edges from small node id to large):\n    ::\n\n              2 - 4\n             / \\\\\n        0 - 1 - 3 - 5\n\n    >>> g = dgl.graph(([0, 1, 1, 2, 2, 3], [1, 2, 3, 3, 4, 5]))\n    >>> list(dgl.bfs_nodes_generator(g, 0))\n    [tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])]\n    \"\"\"\n    assert isinstance(\n        graph, DGLGraph\n    ), \"DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph\"\n    assert (\n        len(graph.canonical_etypes) == 1\n    ), \"bfs_nodes_generator only support homogeneous graph\"\n    # Workaround before support for GPU graph\n    gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))\n    source = utils.toindex(source, dtype=graph._idtype_str)\n    ret = _CAPI_DGLBFSNodes_v2(gidx, source.todgltensor(), reverse)\n    all_nodes = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()\n    # TODO(minjie): how to support directly creating python list\n    sections = utils.toindex(ret(1)).tonumpy().tolist()\n    node_frontiers = F.split(all_nodes, sections, dim=0)\n    return 
node_frontiers\n\n\ndef bfs_edges_generator(graph, source, reverse=False):\n    \"\"\"Edges frontiers generator using breadth-first search.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph object.\n    source : list, tensor of nodes\n        Source nodes.\n    reverse : bool, default False\n        If True, traverse following the in-edge direction.\n\n    Returns\n    -------\n    list of edge frontiers\n        Each edge frontier is a list or tensor of edge ids.\n\n    Examples\n    --------\n    Given a graph (directed, edges from small node id to large, sorted\n    in lexicographical order of source-destination node id tuple):\n    ::\n\n              2 - 4\n             / \\\\\n        0 - 1 - 3 - 5\n\n    >>> g = dgl.graph(([0, 1, 1, 2, 2, 3], [1, 2, 3, 3, 4, 5]))\n    >>> list(dgl.bfs_edges_generator(g, 0))\n    [tensor([0]), tensor([1, 2]), tensor([4, 5])]\n    \"\"\"\n    assert isinstance(\n        graph, DGLGraph\n    ), \"DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph\"\n    assert (\n        len(graph.canonical_etypes) == 1\n    ), \"bfs_edges_generator only support homogeneous graph\"\n    # Workaround before support for GPU graph\n    gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))\n    source = utils.toindex(source, dtype=graph._idtype_str)\n    ret = _CAPI_DGLBFSEdges_v2(gidx, source.todgltensor(), reverse)\n    all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()\n    # TODO(minjie): how to support directly creating python list\n    sections = utils.toindex(ret(1)).tonumpy().tolist()\n    edge_frontiers = F.split(all_edges, sections, dim=0)\n    return edge_frontiers\n\n\ndef topological_nodes_generator(graph, reverse=False):\n    \"\"\"Node frontiers generator using topological traversal.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph object.\n    reverse : bool, optional\n        If True, traverse following the in-edge direction.\n\n    Returns\n    
-------\n    list of node frontiers\n        Each node frontier is a list or tensor of node ids.\n\n    Examples\n    --------\n    Given a graph (directed, edges from small node id to large):\n    ::\n\n              2 - 4\n             / \\\\\n        0 - 1 - 3 - 5\n\n    >>> g = dgl.graph(([0, 1, 1, 2, 2, 3], [1, 2, 3, 3, 4, 5]))\n    >>> list(dgl.topological_nodes_generator(g))\n    [tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])]\n    \"\"\"\n    assert isinstance(\n        graph, DGLGraph\n    ), \"DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph\"\n    assert (\n        len(graph.canonical_etypes) == 1\n    ), \"topological_nodes_generator only support homogeneous graph\"\n    # Workaround before support for GPU graph\n    gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))\n    ret = _CAPI_DGLTopologicalNodes_v2(gidx, reverse)\n    all_nodes = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()\n    # TODO(minjie): how to support directly creating python list\n    sections = utils.toindex(ret(1)).tonumpy().tolist()\n    return F.split(all_nodes, sections, dim=0)\n\n\ndef dfs_edges_generator(graph, source, reverse=False):\n    \"\"\"Edge frontiers generator using depth-first-search (DFS).\n\n    Multiple source nodes can be specified to start the DFS traversal. One\n    needs to make sure that each source node belongs to different connected\n    component, so the frontiers can be easily merged. 
Otherwise, the behavior\n    is undefined.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph object.\n    source : list, tensor of nodes\n        Source nodes.\n    reverse : bool, optional\n        If True, traverse following the in-edge direction.\n\n    Returns\n    -------\n    list of edge frontiers\n        Each edge frontier is a list or tensor of edge ids.\n\n    Examples\n    --------\n    Given a graph (directed, edges from small node id to large):\n    ::\n\n              2 - 4\n             / \\\\\n        0 - 1 - 3 - 5\n\n    Edge addition order [(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]\n\n    >>> g = dgl.graph(([0, 1, 1, 2, 2, 3], [1, 2, 3, 3, 4, 5]))\n    >>> list(dgl.dfs_edges_generator(g, 0))\n    [tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4])]\n    \"\"\"\n    assert isinstance(\n        graph, DGLGraph\n    ), \"DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph\"\n    assert (\n        len(graph.canonical_etypes) == 1\n    ), \"dfs_edges_generator only support homogeneous graph\"\n    # Workaround before support for GPU graph\n    gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))\n    source = utils.toindex(source, dtype=graph._idtype_str)\n    ret = _CAPI_DGLDFSEdges_v2(gidx, source.todgltensor(), reverse)\n    all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()\n    # TODO(minjie): how to support directly creating python list\n    sections = utils.toindex(ret(1)).tonumpy().tolist()\n    return F.split(all_edges, sections, dim=0)\n\n\ndef dfs_labeled_edges_generator(\n    graph,\n    source,\n    reverse=False,\n    has_reverse_edge=False,\n    has_nontree_edge=False,\n    return_labels=True,\n):\n    \"\"\"Produce edges in a depth-first-search (DFS) labeled by type.\n\n    There are three labels: FORWARD(0), REVERSE(1), NONTREE(2)\n\n    A FORWARD edge is one in which `u` has been visited but `v` has not. 
A\n    REVERSE edge is one in which both `u` and `v` have been visited and the\n    edge is in the DFS tree. A NONTREE edge is one in which both `u` and `v`\n    have been visited but the edge is NOT in the DFS tree.\n\n    See ``networkx``'s :func:`dfs_labeled_edges\n    <networkx.algorithms.traversal.depth_first_search.dfs_labeled_edges>`\n    for more details.\n\n    Multiple source nodes can be specified to start the DFS traversal. One\n    needs to make sure that each source node belongs to different connected\n    component, so the frontiers can be easily merged. Otherwise, the behavior\n    is undefined.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph object.\n    source : list, tensor of nodes\n        Source nodes.\n    reverse : bool, optional\n        If true, traverse following the in-edge direction.\n    has_reverse_edge : bool, optional\n        True to include reverse edges.\n    has_nontree_edge : bool, optional\n        True to include nontree edges.\n    return_labels : bool, optional\n        True to return the labels of each edge.\n\n    Returns\n    -------\n    list of edge frontiers\n        Each edge frontier is a list or tensor of edge ids.\n    list of list of int\n        Label of each edge, organized in the same order as the edge frontiers.\n\n    Examples\n    --------\n    Given a graph (directed, edges from small node id to large):\n    ::\n\n              2 - 4\n             / \\\\\n        0 - 1 - 3 - 5\n\n    Edge addition order [(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]\n\n    >>> g = dgl.graph(([0, 1, 1, 2, 2, 3], [1, 2, 3, 3, 4, 5]))\n    >>> list(dgl.dfs_labeled_edges_generator(g, 0, has_nontree_edge=True))\n    (tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4]), tensor([2])),\n    (tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([2]))\n    \"\"\"\n    assert isinstance(\n        graph, DGLGraph\n    ), \"DGLHeteroGraph is merged with DGLGraph, Please use 
DGLGraph\"\n    assert (\n        len(graph.canonical_etypes) == 1\n    ), \"dfs_labeled_edges_generator only support homogeneous graph\"\n    # Workaround before support for GPU graph\n    gidx = graph._graph.copy_to(utils.to_dgl_context(F.cpu()))\n    source = utils.toindex(source, dtype=graph._idtype_str)\n    ret = _CAPI_DGLDFSLabeledEdges_v2(\n        gidx,\n        source.todgltensor(),\n        reverse,\n        has_reverse_edge,\n        has_nontree_edge,\n        return_labels,\n    )\n    all_edges = utils.toindex(ret(0), dtype=graph._idtype_str).tousertensor()\n    # TODO(minjie): how to support directly creating python list\n    if return_labels:\n        all_labels = utils.toindex(ret(1)).tousertensor()\n        sections = utils.toindex(ret(2)).tonumpy().tolist()\n        return (\n            F.split(all_edges, sections, dim=0),\n            F.split(all_labels, sections, dim=0),\n        )\n    else:\n        sections = utils.toindex(ret(1)).tonumpy().tolist()\n        return F.split(all_edges, sections, dim=0)\n\n\n_init_api(\"dgl.traversal\")\n"
  },
  {
    "path": "python/dgl/udf.py",
    "content": "\"\"\"User-defined function related data structures.\"\"\"\nfrom __future__ import absolute_import\n\n\nclass EdgeBatch(object):\n    \"\"\"The class that can represent a batch of edges.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        Graph object.\n    eid : Tensor\n        Edge IDs.\n    etype : (str, str, str)\n        Edge type.\n    src_data : dict[str, Tensor]\n        Src node features.\n    edge_data : dict[str, Tensor]\n        Edge features.\n    dst_data : dict[str, Tensor]\n        Dst node features.\n    \"\"\"\n\n    def __init__(self, graph, eid, etype, src_data, edge_data, dst_data):\n        self._graph = graph\n        self._eid = eid\n        self._etype = etype\n        self._src_data = src_data\n        self._edge_data = edge_data\n        self._dst_data = dst_data\n\n    @property\n    def src(self):\n        \"\"\"Return a view of the source node features for the edges in the batch.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> # Instantiate a graph and set a node feature 'h'.\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))\n        >>> g.ndata['h'] = torch.ones(2, 1)\n\n        >>> # Define a UDF that retrieves the source node features for edges.\n        >>> def edge_udf(edges):\n        >>>     # edges.src['h'] is a tensor of shape (E, 1),\n        >>>     # where E is the number of edges in the batch.\n        >>>     return {'src': edges.src['h']}\n\n        >>> # Copy features from source nodes to edges.\n        >>> g.apply_edges(edge_udf)\n        >>> g.edata['src']\n        tensor([[1.],\n                [1.],\n                [1.]])\n\n        >>> # Use edge UDF in message passing, which is equivalent to\n        >>> # dgl.function.copy_u.\n        >>> import dgl.function as fn\n        >>> g.update_all(edge_udf, fn.sum('src', 'h'))\n        >>> g.ndata['h']\n 
       tensor([[1.],\n                [2.]])\n        \"\"\"\n        return self._src_data\n\n    @property\n    def dst(self):\n        \"\"\"Return a view of the destination node features for the edges in the batch.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> # Instantiate a graph and set a node feature 'h'.\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))\n        >>> g.ndata['h'] = torch.tensor([[0.], [1.]])\n\n        >>> # Define a UDF that retrieves the destination node features for\n        >>> # edges.\n        >>> def edge_udf(edges):\n        >>>     # edges.dst['h'] is a tensor of shape (E, 1),\n        >>>     # where E is the number of edges in the batch.\n        >>>     return {'dst': edges.dst['h']}\n\n        >>> # Copy features from destination nodes to edges.\n        >>> g.apply_edges(edge_udf)\n        >>> g.edata['dst']\n        tensor([[1.],\n                [1.],\n                [0.]])\n\n        >>> # Use edge UDF in message passing.\n        >>> import dgl.function as fn\n        >>> g.update_all(edge_udf, fn.sum('dst', 'h'))\n        >>> g.ndata['h']\n        tensor([[0.],\n                [2.]])\n        \"\"\"\n        return self._dst_data\n\n    @property\n    def data(self):\n        \"\"\"Return a view of the edge features for the edges in the batch.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> # Instantiate a graph and set an edge feature 'h'.\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))\n        >>> g.edata['h'] = torch.tensor([[1.], [1.], [1.]])\n\n        >>> # Define a UDF that retrieves the feature 'h' for all edges.\n        >>> def edge_udf(edges):\n        >>>     # edges.data['h'] is a tensor of shape (E, 1),\n        >>>     # where 
E is the number of edges in the batch.\n        >>>     return {'data': edges.data['h']}\n\n        >>> # Make a copy of the feature with name 'data'.\n        >>> g.apply_edges(edge_udf)\n        >>> g.edata['data']\n        tensor([[1.],\n                [1.],\n                [1.]])\n\n        >>> # Use edge UDF in message passing, which is equivalent to\n        >>> # dgl.function.copy_e.\n        >>> import dgl.function as fn\n        >>> g.update_all(edge_udf, fn.sum('data', 'h'))\n        >>> g.ndata['h']\n        tensor([[1.],\n                [2.]])\n        \"\"\"\n        return self._edge_data\n\n    def edges(self):\n        \"\"\"Return the edges in the batch.\n\n        Returns\n        -------\n        (U, V, EID) : (Tensor, Tensor, Tensor)\n            The edges in the batch. For each :math:`i`, :math:`(U[i], V[i])` is\n            an edge from :math:`U[i]` to :math:`V[i]` with ID :math:`EID[i]`.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> # Instantiate a graph.\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))\n\n        >>> # Define a UDF that retrieves and concatenates the end nodes of the\n        >>> # edges.\n        >>> def edge_udf(edges):\n        >>>     src, dst, _ = edges.edges()\n        >>>     return {'uv': torch.stack([src, dst], dim=1).float()}\n\n        >>> # Create a feature 'uv' with the end nodes of the edges.\n        >>> g.apply_edges(edge_udf)\n        >>> g.edata['uv']\n        tensor([[0., 1.],\n                [1., 1.],\n                [1., 0.]])\n\n        >>> # Use edge UDF in message passing.\n        >>> import dgl.function as fn\n        >>> g.update_all(edge_udf, fn.sum('uv', 'h'))\n        >>> g.ndata['h']\n        tensor([[1., 0.],\n                [1., 2.]])\n        \"\"\"\n        u, v = self._graph.find_edges(self._eid, etype=self.canonical_etype)\n        return u, 
v, self._eid\n\n    def batch_size(self):\n        \"\"\"Return the number of edges in the batch.\n\n        Returns\n        -------\n        int\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> # Instantiate a graph.\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))\n\n        >>> # Define a UDF that returns one for each edge.\n        >>> def edge_udf(edges):\n        >>>     return {'h': torch.ones(edges.batch_size(), 1)}\n\n        >>> # Creates a feature 'h'.\n        >>> g.apply_edges(edge_udf)\n        >>> g.edata['h']\n        tensor([[1.],\n                [1.],\n                [1.]])\n\n        >>> # Use edge UDF in message passing.\n        >>> import dgl.function as fn\n        >>> g.update_all(edge_udf, fn.sum('h', 'h'))\n        >>> g.ndata['h']\n        tensor([[1.],\n                [2.]])\n        \"\"\"\n        return len(self._eid)\n\n    def __len__(self):\n        \"\"\"Return the number of edges in this edge batch.\n\n        Returns\n        -------\n        int\n        \"\"\"\n        return self.batch_size()\n\n    @property\n    def canonical_etype(self):\n        \"\"\"Return the canonical edge type (i.e. 
triplet of source, edge, and\n        destination node type) for this edge batch.\"\"\"\n        return self._etype\n\n\nclass NodeBatch(object):\n    \"\"\"The class to represent a batch of nodes.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        Graph object.\n    nodes : Tensor\n        Node ids.\n    ntype : str, optional\n        The node type of this node batch.\n    data : dict[str, Tensor]\n        Node feature data.\n    msgs : dict[str, Tensor], optional\n        Messages data.\n    \"\"\"\n\n    def __init__(self, graph, nodes, ntype, data, msgs=None):\n        self._graph = graph\n        self._nodes = nodes\n        self._ntype = ntype\n        self._data = data\n        self._msgs = msgs\n\n    @property\n    def data(self):\n        \"\"\"Return a view of the node features for the nodes in the batch.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> # Instantiate a graph and set a feature 'h'.\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))\n        >>> g.ndata['h'] = torch.ones(2, 1)\n\n        >>> # Define a UDF that computes the sum of the messages received and\n        >>> # the original feature for each node.\n        >>> def node_udf(nodes):\n        >>>     # nodes.data['h'] is a tensor of shape (N, 1),\n        >>>     # nodes.mailbox['m'] is a tensor of shape (N, D, 1),\n        >>>     # where N is the number of nodes in the batch, D is the number\n        >>>     # of messages received per node for this node batch.\n        >>>     return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)}\n\n        >>> # Use node UDF in message passing.\n        >>> import dgl.function as fn\n        >>> g.update_all(fn.copy_u('h', 'm'), node_udf)\n        >>> g.ndata['h']\n        tensor([[2.],\n                [3.]])\n        \"\"\"\n        return self._data\n\n    @property\n    def mailbox(self):\n  
      \"\"\"Return a view of the messages received.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> # Instantiate a graph and set a feature 'h'.\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))\n        >>> g.ndata['h'] = torch.ones(2, 1)\n\n        >>> # Define a UDF that computes the sum of the messages received and\n        >>> # the original feature for each node.\n        >>> def node_udf(nodes):\n        >>>     # nodes.data['h'] is a tensor of shape (N, 1),\n        >>>     # nodes.mailbox['m'] is a tensor of shape (N, D, 1),\n        >>>     # where N is the number of nodes in the batch, D is the number\n        >>>     # of messages received per node for this node batch.\n        >>>     return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)}\n\n        >>> # Use node UDF in message passing.\n        >>> import dgl.function as fn\n        >>> g.update_all(fn.copy_u('h', 'm'), node_udf)\n        >>> g.ndata['h']\n        tensor([[2.],\n                [3.]])\n        \"\"\"\n        return self._msgs\n\n    def nodes(self):\n        \"\"\"Return the nodes in the batch.\n\n        Returns\n        -------\n        NID : Tensor\n            The IDs of the nodes in the batch. 
:math:`NID[i]` gives the ID of\n            the i-th node.\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> # Instantiate a graph and set a feature 'h'.\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))\n        >>> g.ndata['h'] = torch.ones(2, 1)\n\n        >>> # Define a UDF that computes the sum of the messages received and\n        >>> # the original ID for each node.\n        >>> def node_udf(nodes):\n        >>>     # nodes.nodes() is a tensor of shape (N),\n        >>>     # nodes.mailbox['m'] is a tensor of shape (N, D, 1),\n        >>>     # where N is the number of nodes in the batch, D is the number\n        >>>     # of messages received per node for this node batch.\n        >>>     return {'h': nodes.nodes().unsqueeze(-1).float()\n        >>>         + nodes.mailbox['m'].sum(1)}\n\n        >>> # Use node UDF in message passing.\n        >>> import dgl.function as fn\n        >>> g.update_all(fn.copy_u('h', 'm'), node_udf)\n        >>> g.ndata['h']\n        tensor([[1.],\n                [3.]])\n        \"\"\"\n        return self._nodes\n\n    def batch_size(self):\n        \"\"\"Return the number of nodes in the batch.\n\n        Returns\n        -------\n        int\n\n        Examples\n        --------\n        The following example uses PyTorch backend.\n\n        >>> import dgl\n        >>> import torch\n\n        >>> # Instantiate a graph.\n        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))\n        >>> g.ndata['h'] = torch.ones(2, 1)\n\n        >>> # Define a UDF that computes the sum of the messages received for\n        >>> # each node and increments the result by 1.\n        >>> def node_udf(nodes):\n        >>>     return {'h': torch.ones(nodes.batch_size(), 1)\n        >>>         + nodes.mailbox['m'].sum(1)}\n\n        >>> # Use node UDF in message passing.\n        >>> import 
dgl.function as fn\n        >>> g.update_all(fn.copy_u('h', 'm'), node_udf)\n        >>> g.ndata['h']\n        tensor([[2.],\n                [3.]])\n        \"\"\"\n        return len(self._nodes)\n\n    def __len__(self):\n        \"\"\"Return the number of nodes in this node batch.\n\n        Returns\n        -------\n        int\n        \"\"\"\n        return self.batch_size()\n\n    @property\n    def ntype(self):\n        \"\"\"Return the node type of this node batch, if available.\"\"\"\n        return self._ntype\n"
  },
  {
    "path": "python/dgl/utils/__init__.py",
    "content": "\"\"\"Internal utilities.\"\"\"\nfrom .checks import *\nfrom .data import *\nfrom .exception import *\nfrom .filter import *\nfrom .internal import *\nfrom .pin_memory import *\nfrom .shared_mem import *\n\ntry:\n    from packaging import version\nexcept ImportError:\n    # If packaging isn't installed, try and use the vendored copy in setuptools\n    from setuptools.extern.packaging import version\n"
  },
  {
    "path": "python/dgl/utils/checks.py",
    "content": "\"\"\"Checking and logging utilities.\"\"\"\n# pylint: disable=invalid-name\nfrom __future__ import absolute_import, division\n\nfrom collections.abc import Mapping\n\nfrom .. import backend as F\nfrom .._ffi.function import _init_api\nfrom ..base import DGLError\n\n\ndef prepare_tensor(g, data, name):\n    \"\"\"Convert the data to ID tensor and check its ID type and context.\n\n    If the data is already in tensor type, raise error if its ID type\n    and context does not match the graph's.\n    Otherwise, convert it to tensor type of the graph's ID type and\n    ctx and return.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        Graph.\n    data : int, iterable of int, tensor\n        Data.\n    name : str\n        Name of the data.\n\n    Returns\n    -------\n    Tensor\n        Data in tensor object.\n    \"\"\"\n    if F.is_tensor(data):\n        if F.dtype(data) != g.idtype:\n            raise DGLError(\n                f'Expect argument \"{name}\" to have data type {g.idtype}. '\n                f\"But got {F.dtype(data)}.\"\n            )\n        if F.context(data) != g.device and not g.is_pinned():\n            raise DGLError(\n                f'Expect argument \"{name}\" to have device {g.device}. 
'\n                f\"But got {F.context(data)}.\"\n            )\n        ret = data\n    else:\n        data = F.tensor(data)\n        if not (\n            F.ndim(data) > 0 and F.shape(data)[0] == 0\n        ) and F.dtype(  # empty tensor\n            data\n        ) not in (\n            F.int32,\n            F.int64,\n        ):\n            raise DGLError(\n                'Expect argument \"{}\" to have data type int32 or int64,'\n                \" but got {}.\".format(name, F.dtype(data))\n            )\n        ret = F.copy_to(F.astype(data, g.idtype), g.device)\n\n    if F.ndim(ret) == 0:\n        ret = F.unsqueeze(ret, 0)\n    if F.ndim(ret) > 1:\n        raise DGLError(\n            'Expect a 1-D tensor for argument \"{}\". But got {}.'.format(\n                name, ret\n            )\n        )\n    return ret\n\n\ndef prepare_tensor_dict(g, data, name):\n    \"\"\"Convert a dictionary of data to a dictionary of ID tensors.\n\n    Calls ``prepare_tensor`` on each key-value pair.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        Graph.\n    data : dict[str, (int, iterable of int, tensor)]\n        Data dict.\n    name : str\n        Name of the data.\n\n    Returns\n    -------\n    dict[str, tensor]\n    \"\"\"\n    return {\n        key: prepare_tensor(g, val, '{}[\"{}\"]'.format(name, key))\n        for key, val in data.items()\n    }\n\n\ndef prepare_tensor_or_dict(g, data, name):\n    \"\"\"Convert data to either a tensor or a dictionary depending on input type.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        Graph.\n    data : dict[str, (int, iterable of int, tensor)]\n        Data dict.\n    name : str\n        Name of the data.\n\n    Returns\n    -------\n    tensor or dict[str, tensor]\n    \"\"\"\n    return (\n        prepare_tensor_dict(g, data, name)\n        if isinstance(data, Mapping)\n        else prepare_tensor(g, data, name)\n    )\n\n\ndef parse_edges_arg_to_eid(g, edges, etid, argname=\"edges\"):\n    
\"\"\"Parse the :attr:`edges` argument and return an edge ID tensor.\n\n    The resulting edge ID tensor has the same ID type and device of :attr:`g`.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        Graph\n    edges : pair of Tensor, Tensor, iterable[int]\n        Argument for specifying edges.\n    etid : int\n        Edge type ID.\n    argname : str, optional\n        Argument name.\n\n    Returns\n    -------\n    Tensor\n        Edge ID tensor\n    \"\"\"\n    if isinstance(edges, tuple):\n        u, v = edges\n        u = prepare_tensor(g, u, \"{}[0]\".format(argname))\n        v = prepare_tensor(g, v, \"{}[1]\".format(argname))\n        eid = g.edge_ids(u, v, etype=g.canonical_etypes[etid])\n    else:\n        eid = prepare_tensor(g, edges, argname)\n    return eid\n\n\ndef check_all_same_idtype(glist, name):\n    \"\"\"Check all the graphs have the same idtype.\"\"\"\n    if len(glist) == 0:\n        return\n    idtype = glist[0].idtype\n    for i, g in enumerate(glist):\n        if g.idtype != idtype:\n            raise DGLError(\n                \"Expect {}[{}] to have {} type ID, but got {}.\".format(\n                    name, i, idtype, g.idtype\n                )\n            )\n\n\ndef check_device(data, device):\n    \"\"\"Check if data is on the target device.\n\n    Parameters\n    ----------\n    data : Tensor or dict[str, Tensor]\n    device: Backend device.\n\n    Returns\n    -------\n    Bool: True if the data is on the target device.\n    \"\"\"\n    if isinstance(data, dict):\n        for v in data.values():\n            if v.device != device:\n                return False\n    elif data.device != device:\n        return False\n    return True\n\n\ndef check_all_same_device(glist, name):\n    \"\"\"Check all the graphs have the same device.\"\"\"\n    if len(glist) == 0:\n        return\n    device = glist[0].device\n    for i, g in enumerate(glist):\n        if g.device != device:\n            raise DGLError(\n                
\"Expect {}[{}] to be on device {}, but got {}.\".format(\n                    name, i, device, g.device\n                )\n            )\n\n\ndef check_all_same_schema(schemas, name):\n    \"\"\"Check the list of schemas are the same.\"\"\"\n    if len(schemas) == 0:\n        return\n\n    for i, schema in enumerate(schemas):\n        if schema != schemas[0]:\n            raise DGLError(\n                \"Expect all graphs to have the same schema on {}, \"\n                \"but graph {} got\\n\\t{}\\nwhich is different from\\n\\t{}.\".format(\n                    name, i, schema, schemas[0]\n                )\n            )\n\n\ndef check_all_same_schema_for_keys(schemas, keys, name):\n    \"\"\"Check the list of schemas are the same on the given keys.\"\"\"\n    if len(schemas) == 0:\n        return\n\n    head = None\n    keys = set(keys)\n    for i, schema in enumerate(schemas):\n        if not keys.issubset(schema.keys()):\n            raise DGLError(\n                \"Expect all graphs to have keys {} on {}, \"\n                \"but graph {} got keys {}.\".format(keys, name, i, schema.keys())\n            )\n\n        if head is None:\n            head = {k: schema[k] for k in keys}\n        else:\n            target = {k: schema[k] for k in keys}\n            if target != head:\n                raise DGLError(\n                    \"Expect all graphs to have the same schema for keys {} on {}, \"\n                    \"but graph {} got \\n\\t{}\\n which is different from\\n\\t{}.\".format(\n                        keys, name, i, target, head\n                    )\n                )\n\n\ndef check_valid_idtype(idtype):\n    \"\"\"Check whether the value of the idtype argument is valid (int32/int64)\n\n    Parameters\n    ----------\n    idtype : data type\n        The framework object of a data type.\n    \"\"\"\n    if idtype not in [None, F.int32, F.int64]:\n        raise DGLError(\n            \"Expect idtype to be a framework object of int32/int64, 
\"\n            \"got {}\".format(idtype)\n        )\n\n\ndef is_sorted_srcdst(src, dst, num_src=None, num_dst=None):\n    \"\"\"Checks whether an edge list is in ascending src-major order (e.g., first\n    sorted by ``src`` and then by ``dst``).\n\n    Parameters\n    ----------\n    src : IdArray\n        The tensor of source nodes for each edge.\n    dst : IdArray\n        The tensor of destination nodes for each edge.\n    num_src : int, optional\n        The number of source nodes.\n    num_dst : int, optional\n        The number of destination nodes.\n\n    Returns\n    -------\n    bool, bool\n        Whether ``src`` is in ascending order, and whether ``dst`` is\n        in ascending order with respect to ``src``.\n    \"\"\"\n    # for some versions of MXNET and TensorFlow, num_src and num_dst get\n    # incorrectly marked as floats, so force them as integers here\n    if num_src is None:\n        num_src = int(F.as_scalar(F.max(src, dim=0) + 1))\n    if num_dst is None:\n        num_dst = int(F.as_scalar(F.max(dst, dim=0) + 1))\n\n    src = F.zerocopy_to_dgl_ndarray(src)\n    dst = F.zerocopy_to_dgl_ndarray(dst)\n    sorted_status = _CAPI_DGLCOOIsSorted(src, dst, num_src, num_dst)\n\n    row_sorted = sorted_status > 0\n    col_sorted = sorted_status > 1\n\n    return row_sorted, col_sorted\n\n\n_init_api(\"dgl.utils.checks\")\n"
  },
  {
    "path": "python/dgl/utils/data.py",
    "content": "\"\"\"Data utilities.\"\"\"\n\nfrom collections import namedtuple\n\nimport networkx as nx\nimport scipy as sp\n\nfrom .. import backend as F\nfrom ..base import DGLError\nfrom . import checks\n\n\ndef elist2tensor(elist, idtype):\n    \"\"\"Function to convert an edge list to edge tensors.\n\n    Parameters\n    ----------\n    elist : iterable of int pairs\n        List of (src, dst) node ID pairs.\n    idtype : int32, int64, optional\n        Integer ID type. Must be int32 or int64.\n\n    Returns\n    -------\n    (Tensor, Tensor)\n        Edge tensors.\n    \"\"\"\n    if len(elist) == 0:\n        u, v = [], []\n    else:\n        u, v = zip(*elist)\n        u = list(u)\n        v = list(v)\n    return F.tensor(u, idtype), F.tensor(v, idtype)\n\n\ndef scipy2tensor(spmat, idtype):\n    \"\"\"Function to convert a scipy matrix to a sparse adjacency matrix tuple.\n\n    Note that the data array of the scipy matrix is discarded.\n\n    Parameters\n    ----------\n    spmat : scipy.sparse.spmatrix\n        SciPy sparse matrix.\n    idtype : int32, int64, optional\n        Integer ID type. 
Must be int32 or int64.\n\n    Returns\n    -------\n    (str, tuple[Tensor])\n        A tuple containing the format as well as the list of tensors representing\n        the sparse matrix.\n    \"\"\"\n    if spmat.format in [\"csr\", \"csc\"]:\n        indptr = F.tensor(spmat.indptr, idtype)\n        indices = F.tensor(spmat.indices, idtype)\n        data = F.tensor([], idtype)\n        return SparseAdjTuple(spmat.format, (indptr, indices, data))\n    else:\n        spmat = spmat.tocoo()\n        row = F.tensor(spmat.row, idtype)\n        col = F.tensor(spmat.col, idtype)\n        return SparseAdjTuple(\"coo\", (row, col))\n\n\ndef networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):\n    \"\"\"Function to convert a networkx graph to edge tensors.\n\n    Parameters\n    ----------\n    nx_graph : nx.Graph\n        NetworkX graph.\n    idtype : int32, int64, optional\n        Integer ID type. Must be int32 or int64.\n    edge_id_attr_name : str, optional\n        Key name for edge ids in the NetworkX graph. If not found, we\n        will consider the graph not to have pre-specified edge ids. 
(Default: None)\n\n    Returns\n    -------\n    (Tensor, Tensor)\n        Edge tensors.\n    \"\"\"\n    if not nx_graph.is_directed():\n        nx_graph = nx_graph.to_directed()\n\n    # Relabel nodes using consecutive integers\n    nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering=\"sorted\")\n    has_edge_id = edge_id_attr_name is not None\n\n    if has_edge_id:\n        num_edges = nx_graph.number_of_edges()\n        src = [0] * num_edges\n        dst = [0] * num_edges\n        for u, v, attr in nx_graph.edges(data=True):\n            eid = int(attr[edge_id_attr_name])\n            if eid < 0 or eid >= nx_graph.number_of_edges():\n                raise DGLError(\n                    \"Expect edge IDs to be a non-negative integer smaller than {:d}, \"\n                    \"got {:d}\".format(num_edges, eid)\n                )\n            src[eid] = u\n            dst[eid] = v\n    else:\n        src = []\n        dst = []\n        for e in nx_graph.edges:\n            src.append(e[0])\n            dst.append(e[1])\n    src = F.tensor(src, idtype)\n    dst = F.tensor(dst, idtype)\n    return src, dst\n\n\nSparseAdjTuple = namedtuple(\"SparseAdjTuple\", [\"format\", \"arrays\"])\n\n\ndef graphdata2tensors(\n    data, idtype=None, bipartite=False, infer_node_count=True, **kwargs\n):\n    \"\"\"Function to convert various types of data to edge tensors and infer\n    the number of nodes.\n\n    Parameters\n    ----------\n    data : graph data\n        Various kinds of graph data.  Possible data types are:\n\n        - ``(row, col)``\n        - ``('coo', (row, col))``\n        - ``('csr', (indptr, indices, edge_ids))``\n        - ``('csc', (indptr, indices, edge_ids))``\n        - SciPy sparse matrix\n        - NetworkX graph\n    idtype : int32, int64, optional\n        Integer ID type. 
If None, try infer from the data and if fail use\n        int64.\n    bipartite : bool, optional\n        Whether infer number of nodes of a bipartite graph --\n        num_src and num_dst can be different.\n    infer_node_count : bool, optional\n        Whether infer number of nodes at all. If False, num_src and num_dst\n        are returned as None.\n    kwargs\n\n        - edge_id_attr_name : The name (str) of the edge attribute that stores the edge\n          IDs in the NetworkX graph.\n        - top_map : The dictionary mapping the original IDs of the source nodes to the\n          new ones.\n        - bottom_map : The dictionary mapping the original IDs of the destination nodes\n          to the new ones.\n\n    Returns\n    -------\n    data : SparseAdjTuple\n        A tuple with the sparse matrix format and the adjacency matrix tensors.\n    num_src : int\n        Number of source nodes.\n    num_dst : int\n        Number of destination nodes.\n    \"\"\"\n    # Convert tuple to SparseAdjTuple\n    if isinstance(data, tuple):\n        if not isinstance(data[0], str):\n            # (row, col) format, convert to ('coo', (row, col))\n            data = (\"coo\", data)\n        data = SparseAdjTuple(*data)\n\n    if idtype is None and not (\n        isinstance(data, SparseAdjTuple) and F.is_tensor(data.arrays[0])\n    ):\n        # preferred default idtype is int64\n        # if data is tensor and idtype is None, infer the idtype from tensor\n        idtype = F.int64\n    checks.check_valid_idtype(idtype)\n\n    if isinstance(data, SparseAdjTuple) and (\n        not all(F.is_tensor(a) for a in data.arrays)\n    ):\n        # (Iterable, Iterable) type data, convert it to (Tensor, Tensor)\n        if len(data.arrays[0]) == 0:\n            # force idtype for empty list\n            data = SparseAdjTuple(\n                data.format, tuple(F.tensor(a, idtype) for a in data.arrays)\n            )\n        else:\n            # convert the iterable to tensor and 
keep its native data type so we can check\n            # its validity later\n            data = SparseAdjTuple(\n                data.format, tuple(F.tensor(a) for a in data.arrays)\n            )\n\n    num_src, num_dst = None, None\n    if isinstance(data, SparseAdjTuple):\n        if idtype is not None:\n            data = SparseAdjTuple(\n                data.format, tuple(F.astype(a, idtype) for a in data.arrays)\n            )\n        if infer_node_count:\n            num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)\n    elif isinstance(data, list):\n        src, dst = elist2tensor(data, idtype)\n        data = SparseAdjTuple(\"coo\", (src, dst))\n        if infer_node_count:\n            num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)\n    elif isinstance(data, sp.sparse.spmatrix):\n        # We can get scipy matrix's number of rows and columns easily.\n        if infer_node_count:\n            num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)\n        data = scipy2tensor(data, idtype)\n    elif isinstance(data, nx.Graph):\n        # We can get networkx graph's number of sources and destinations easily.\n        if infer_node_count:\n            num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)\n        edge_id_attr_name = kwargs.get(\"edge_id_attr_name\", None)\n        if bipartite:\n            top_map = kwargs.get(\"top_map\")\n            bottom_map = kwargs.get(\"bottom_map\")\n            src, dst = networkxbipartite2tensors(\n                data,\n                idtype,\n                top_map=top_map,\n                bottom_map=bottom_map,\n                edge_id_attr_name=edge_id_attr_name,\n            )\n        else:\n            src, dst = networkx2tensor(\n                data, idtype, edge_id_attr_name=edge_id_attr_name\n            )\n        data = SparseAdjTuple(\"coo\", (src, dst))\n    else:\n        raise DGLError(\"Unsupported graph data type:\", type(data))\n\n    return data, 
num_src, num_dst\n\n\ndef networkxbipartite2tensors(\n    nx_graph, idtype, top_map, bottom_map, edge_id_attr_name=None\n):\n    \"\"\"Function to convert a networkx bipartite to edge tensors.\n\n    Parameters\n    ----------\n    nx_graph : nx.Graph\n        NetworkX graph. It must follow the bipartite graph convention of networkx.\n        Each node has an attribute ``bipartite`` with values 0 and 1 indicating\n        which set it belongs to.\n    top_map : dict\n        The dictionary mapping the original node labels to the node IDs for the source type.\n    bottom_map : dict\n        The dictionary mapping the original node labels to the node IDs for the destination type.\n    idtype : int32, int64, optional\n        Integer ID type. Must be int32 or int64.\n    edge_id_attr_name : str, optional\n        Key name for edge ids in the NetworkX graph. If not found, we\n        will consider the graph not to have pre-specified edge ids. (Default: None)\n\n    Returns\n    -------\n    (Tensor, Tensor)\n        Edge tensors.\n    \"\"\"\n    has_edge_id = edge_id_attr_name is not None\n\n    if has_edge_id:\n        num_edges = nx_graph.number_of_edges()\n        src = [0] * num_edges\n        dst = [0] * num_edges\n        for u, v, attr in nx_graph.edges(data=True):\n            if u not in top_map:\n                raise DGLError(\n                    \"Expect the node {} to have attribute bipartite=0 \"\n                    \"with edge {}\".format(u, (u, v))\n                )\n            if v not in bottom_map:\n                raise DGLError(\n                    \"Expect the node {} to have attribute bipartite=1 \"\n                    \"with edge {}\".format(v, (u, v))\n                )\n            eid = int(attr[edge_id_attr_name])\n            if eid < 0 or eid >= nx_graph.number_of_edges():\n                raise DGLError(\n                    \"Expect edge IDs to be a non-negative integer smaller than {:d}, \"\n                    \"got 
{:d}\".format(num_edges, eid)\n                )\n            src[eid] = top_map[u]\n            dst[eid] = bottom_map[v]\n    else:\n        src = []\n        dst = []\n        for e in nx_graph.edges:\n            u, v = e[0], e[1]\n            if u not in top_map:\n                raise DGLError(\n                    \"Expect the node {} to have attribute bipartite=0 \"\n                    \"with edge {}\".format(u, (u, v))\n                )\n            if v not in bottom_map:\n                raise DGLError(\n                    \"Expect the node {} to have attribute bipartite=1 \"\n                    \"with edge {}\".format(v, (u, v))\n                )\n            src.append(top_map[u])\n            dst.append(bottom_map[v])\n    src = F.tensor(src, dtype=idtype)\n    dst = F.tensor(dst, dtype=idtype)\n    return src, dst\n\n\ndef infer_num_nodes(data, bipartite=False):\n    \"\"\"Function for inferring the number of nodes.\n\n    Parameters\n    ----------\n    data : graph data\n        Supported types are:\n\n        * SparseTuple ``(sparse_fmt, arrays)`` where ``arrays`` can be either ``(src, dst)`` or\n          ``(indptr, indices, data)``.\n        * SciPy matrix.\n        * NetworkX graph.\n    bipartite : bool, optional\n        Whether infer number of nodes of a bipartite graph --\n        num_src and num_dst can be different.\n\n    Returns\n    -------\n    num_src : int\n        Number of source nodes.\n    num_dst : int\n        Number of destination nodes.\n\n    or\n\n    None\n        If the inference failed.\n    \"\"\"\n    if isinstance(data, tuple) and len(data) == 2:\n        if not isinstance(data[0], str):\n            raise TypeError(\n                \"Expected sparse format as a str, but got %s\" % type(data[0])\n            )\n\n        if data[0] == \"coo\":\n            # ('coo', (src, dst)) format\n            u, v = data[1]\n            nsrc = F.as_scalar(F.max(u, dim=0)) + 1 if len(u) > 0 else 0\n            ndst = 
F.as_scalar(F.max(v, dim=0)) + 1 if len(v) > 0 else 0\n        elif data[0] == \"csr\":\n            # ('csr', (indptr, indices, eids)) format\n            indptr, indices, _ = data[1]\n            nsrc = F.shape(indptr)[0] - 1\n            ndst = (\n                F.as_scalar(F.max(indices, dim=0)) + 1\n                if len(indices) > 0\n                else 0\n            )\n        elif data[0] == \"csc\":\n            # ('csc', (indptr, indices, eids)) format\n            indptr, indices, _ = data[1]\n            ndst = F.shape(indptr)[0] - 1\n            nsrc = (\n                F.as_scalar(F.max(indices, dim=0)) + 1\n                if len(indices) > 0\n                else 0\n            )\n        else:\n            raise ValueError(\"unknown format %s\" % data[0])\n    elif isinstance(data, sp.sparse.spmatrix):\n        nsrc, ndst = data.shape[0], data.shape[1]\n    elif isinstance(data, nx.Graph):\n        if data.number_of_nodes() == 0:\n            nsrc = ndst = 0\n        elif not bipartite:\n            nsrc = ndst = data.number_of_nodes()\n        else:\n            nsrc = len(\n                {n for n, d in data.nodes(data=True) if d[\"bipartite\"] == 0}\n            )\n            ndst = data.number_of_nodes() - nsrc\n    else:\n        return None\n    if not bipartite:\n        nsrc = ndst = max(nsrc, ndst)\n    return nsrc, ndst\n\n\ndef to_device(data, device):\n    \"\"\"Transfer the tensor or dictionary of tensors to the given device.\n\n    Nothing will happen if the device of the original tensor is the same as target device.\n\n    Parameters\n    ----------\n    data : Tensor or dict[str, Tensor]\n        The data.\n    device : device\n        The target device.\n\n    Returns\n    -------\n    Tensor or dict[str, Tensor]\n        The output data.\n    \"\"\"\n    if isinstance(data, dict):\n        return {k: F.copy_to(v, device) for k, v in data.items()}\n    else:\n        return F.copy_to(data, device)\n"
  },
  {
    "path": "python/dgl/utils/exception.py",
    "content": "\"\"\"Exception wrapper classes to properly display exceptions under multithreading or\nmultiprocessing.\n\"\"\"\nimport sys\nimport traceback\n\n# The following code is borrowed from PyTorch.  Basically when a subprocess or thread\n# throws an exception, you will need to wrap the exception with ExceptionWrapper class\n# and put it in the queue you are normally retrieving from.\n\n# NOTE [ Python Traceback Reference Cycle Problem ]\n#\n# When using sys.exc_info(), it is important to **not** store the exc_info[2],\n# which is the traceback, because otherwise you will run into the traceback\n# reference cycle problem, i.e., the traceback holding reference to the frame,\n# and the frame (which holds reference to all the object in its temporary scope)\n# holding reference the traceback.\n\n\nclass KeyErrorMessage(str):\n    r\"\"\"str subclass that returns itself in repr\"\"\"\n\n    def __repr__(self):  # pylint: disable=invalid-repr-returned\n        return self\n\n\nclass ExceptionWrapper(object):\n    r\"\"\"Wraps an exception plus traceback to communicate across threads\"\"\"\n\n    def __init__(self, exc_info=None, where=\"in background\"):\n        # It is important that we don't store exc_info, see\n        # NOTE [ Python Traceback Reference Cycle Problem ]\n        if exc_info is None:\n            exc_info = sys.exc_info()\n        self.exc_type = exc_info[0]\n        self.exc_msg = \"\".join(traceback.format_exception(*exc_info))\n        self.where = where\n\n    def reraise(self):\n        r\"\"\"Reraises the wrapped exception in the current thread\"\"\"\n        # Format a message such as: \"Caught ValueError in DataLoader worker\n        # process 2. Original Traceback:\", followed by the traceback.\n        msg = \"Caught {} {}.\\nOriginal {}\".format(\n            self.exc_type.__name__, self.where, self.exc_msg\n        )\n        if self.exc_type == KeyError:\n            # KeyError calls repr() on its argument (usually a dict key). 
This\n            # makes stack traces unreadable. It will not be changed in Python\n            # (https://bugs.python.org/issue2651), so we work around it.\n            msg = KeyErrorMessage(msg)\n        elif getattr(self.exc_type, \"message\", None):\n            # Some exceptions have first argument as non-str but explicitly\n            # have message field\n            raise self.exc_type(message=msg)\n        try:\n            exception = self.exc_type(msg)\n        except TypeError:\n            # If the exception takes multiple arguments, don't try to\n            # instantiate since we don't know how to\n            raise RuntimeError(msg) from None\n        raise exception\n"
  },
  {
    "path": "python/dgl/utils/filter.py",
    "content": "\"\"\"Utilities for finding overlap or missing items in arrays.\"\"\"\n\nfrom .. import backend as F\nfrom .._ffi.function import _init_api\n\n\nclass Filter(object):\n    \"\"\"Class used to either find the subset of IDs that are in this\n    filter, or the subset of IDs that are not in this filter\n    given a second set of IDs.\n\n    Examples\n    --------\n    >>> import torch as th\n    >>> from dgl.utils import Filter\n    >>> f = Filter(th.tensor([3,2,9], device=th.device('cuda')))\n    >>> f.find_included_indices(th.tensor([0,2,8,9], device=th.device('cuda')))\n    tensor([1,3])\n    >>> f.find_excluded_indices(th.tensor([0,2,8,9], device=th.device('cuda')))\n    tensor([0,2], device='cuda')\n    \"\"\"\n\n    def __init__(self, ids):\n        \"\"\"Create a new filter from a given set of IDs. This currently is only\n        implemented for the GPU.\n\n        Parameters\n        ----------\n        ids : IdArray\n            The unique set of IDs to keep in the filter.\n        \"\"\"\n        self._filter = _CAPI_DGLFilterCreateFromSet(\n            F.zerocopy_to_dgl_ndarray(ids)\n        )\n\n    def find_included_indices(self, test):\n        \"\"\"Find the index of the IDs in `test` that are in this filter.\n\n        Parameters\n        ----------\n        test : IdArray\n            The set of IDs to to test with.\n\n        Returns\n        -------\n        IdArray\n            The index of IDs in `test` that are also in this filter.\n        \"\"\"\n        return F.zerocopy_from_dgl_ndarray(\n            _CAPI_DGLFilterFindIncludedIndices(\n                self._filter, F.zerocopy_to_dgl_ndarray(test)\n            )\n        )\n\n    def find_excluded_indices(self, test):\n        \"\"\"Find the index of the IDs in `test` that are not in this filter.\n\n        Parameters\n        ----------\n        test : IdArray\n            The set of IDs to to test with.\n\n        Returns\n        -------\n        IdArray\n            The 
index of IDs in `test` that are not in this filter.\n        \"\"\"\n        return F.zerocopy_from_dgl_ndarray(\n            _CAPI_DGLFilterFindExcludedIndices(\n                self._filter, F.zerocopy_to_dgl_ndarray(test)\n            )\n        )\n\n\n_init_api(\"dgl.utils.filter\")\n"
  },
  {
    "path": "python/dgl/utils/internal.py",
    "content": "\"\"\"Internal utilities.\"\"\"\nfrom __future__ import absolute_import, division\n\nimport glob\nimport os\nfrom collections import defaultdict\nfrom collections.abc import Iterable, Mapping, Sequence\nfrom functools import wraps\n\nimport numpy as np\n\nfrom .. import backend as F, ndarray as nd\nfrom .._ffi.function import _init_api\nfrom ..base import dgl_warning, DGLError, EID, NID\n\n\ndef is_listlike(data):\n    \"\"\"Return if the data is a sequence but not a string.\"\"\"\n    return isinstance(data, Sequence) and not isinstance(data, str)\n\n\nclass InconsistentDtypeException(DGLError):\n    \"\"\"Exception class for inconsistent dtype between graph and tensor\"\"\"\n\n    def __init__(self, msg=\"\", *args, **kwargs):  # pylint: disable=W1113\n        prefix_message = \"DGL now requires the input tensor to have\\\n            the same dtype as the graph index's dtype(which you can get by g.idype). \"\n        super().__init__(prefix_message + msg, *args, **kwargs)\n\n\nclass Index(object):\n    \"\"\"Index class that can be easily converted to list/tensor.\"\"\"\n\n    def __init__(self, data, dtype=\"int64\"):\n        assert dtype in [\"int32\", \"int64\"]\n        self.dtype = dtype\n        self._initialize_data(data)\n\n    def _initialize_data(self, data):\n        self._pydata = None  # a numpy type data\n        self._user_tensor_data = dict()  # dictionary of user tensors\n        self._dgl_tensor_data = None  # a dgl ndarray\n        self._slice_data = None  # a slice type data\n        self._dispatch(data)\n\n    def __iter__(self):\n        for i in self.tonumpy():\n            yield int(i)\n\n    def __len__(self):\n        if self._slice_data is not None:\n            slc = self._slice_data\n            return slc.stop - slc.start\n        elif self._pydata is not None:\n            return len(self._pydata)\n        elif len(self._user_tensor_data) > 0:\n            data = next(iter(self._user_tensor_data.values()))\n        
    return len(data)\n        else:\n            return len(self._dgl_tensor_data)\n\n    def __getitem__(self, i):\n        return int(self.tonumpy()[i])\n\n    def _dispatch(self, data):\n        \"\"\"Store data based on its type.\"\"\"\n        if F.is_tensor(data):\n            if F.dtype(data) != F.data_type_dict[self.dtype]:\n                raise InconsistentDtypeException(\n                    \"Index data specified as %s, but got: %s\"\n                    % (self.dtype, F.reverse_data_type_dict[F.dtype(data)])\n                )\n            if len(F.shape(data)) > 1:\n                raise InconsistentDtypeException(\n                    \"Index data must be 1D int32/int64 vector,\\\n                    but got shape: %s\"\n                    % str(F.shape(data))\n                )\n            if len(F.shape(data)) == 0:\n                # a tensor of one int\n                self._dispatch(int(data))\n            else:\n                self._user_tensor_data[F.context(data)] = data\n        elif isinstance(data, nd.NDArray):\n            if not (data.dtype == self.dtype and len(data.shape) == 1):\n                raise InconsistentDtypeException(\n                    \"Index data must be 1D %s vector, but got: %s\"\n                    % (self.dtype, data.dtype)\n                )\n            self._dgl_tensor_data = data\n        elif isinstance(data, slice):\n            # save it in the _pydata temporarily; materialize it if `tonumpy` is called\n            assert (\n                data.step == 1 or data.step is None\n            ), \"step for slice type must be 1\"\n            self._slice_data = slice(data.start, data.stop)\n        else:\n            try:\n                data = np.asarray(data, dtype=self.dtype)\n            except Exception:  # pylint: disable=broad-except\n                raise DGLError(\"Error index data: %s\" % str(data))\n            if data.ndim == 0:  # scalar array\n                data = np.expand_dims(data, 0)\n     
       elif data.ndim != 1:\n                raise DGLError(\n                    \"Index data must be 1D int64 vector,\"\n                    \" but got: %s\" % str(data)\n                )\n            self._pydata = data\n            self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(\n                self._pydata\n            )\n\n    def tonumpy(self):\n        \"\"\"Convert to a numpy ndarray.\"\"\"\n        if self._pydata is None:\n            if self._slice_data is not None:\n                slc = self._slice_data\n                self._pydata = np.arange(slc.start, slc.stop).astype(self.dtype)\n            elif self._dgl_tensor_data is not None:\n                self._pydata = self._dgl_tensor_data.asnumpy()\n            else:\n                data = self.tousertensor()\n                self._pydata = F.zerocopy_to_numpy(data)\n        return self._pydata\n\n    def tousertensor(self, ctx=None):\n        \"\"\"Convert to user tensor (defined in `backend`).\"\"\"\n        if ctx is None:\n            ctx = F.cpu()\n        if len(self._user_tensor_data) == 0:\n            if self._dgl_tensor_data is not None:\n                # zero copy from dgl tensor\n                dlpack = self._dgl_tensor_data.to_dlpack()\n                self._user_tensor_data[F.cpu()] = F.zerocopy_from_dlpack(dlpack)\n            else:\n                # zero copy from numpy array\n                self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(\n                    self.tonumpy()\n                )\n        if ctx not in self._user_tensor_data:\n            # copy from cpu to another device\n            data = next(iter(self._user_tensor_data.values()))\n            self._user_tensor_data[ctx] = F.copy_to(data, ctx)\n        return self._user_tensor_data[ctx]\n\n    def todgltensor(self):\n        \"\"\"Convert to dgl.NDArray.\"\"\"\n        if self._dgl_tensor_data is None:\n            # zero copy from user tensor\n            tsor = self.tousertensor()\n           
 dlpack = F.zerocopy_to_dlpack(tsor)\n            self._dgl_tensor_data = nd.from_dlpack(dlpack)\n        return self._dgl_tensor_data\n\n    def slice_data(self):\n        \"\"\"Return the internal slice data.\n\n        If this index is not initialized from slice, the return will be None.\n        \"\"\"\n        return self._slice_data\n\n    def is_slice(self, start, stop):\n        \"\"\"Check if Index wraps a slice data with given start and stop\"\"\"\n        return self._slice_data == slice(start, stop)\n\n    def __getstate__(self):\n        if self._slice_data is not None:\n            # the index can be represented by a slice\n            return self._slice_data, self.dtype\n        else:\n            return self.tousertensor(), self.dtype\n\n    def __setstate__(self, state):\n        # Pickle compatibility check\n        # TODO: we should store a storage version number in later releases.\n        if isinstance(state, tuple) and len(state) == 2:\n            # post-0.4.4\n            data, self.dtype = state\n            self._initialize_data(data)\n        else:\n            # pre-0.4.3\n            dgl_warning(\n                \"The object is pickled before 0.4.3.  
Setting dtype of graph to int64\"\n            )\n            self.dtype = \"int64\"\n            self._initialize_data(state)\n\n    def get_items(self, index):\n        \"\"\"Return values at given positions of an Index\n\n        Parameters\n        ----------\n        index: utils.Index\n\n        Returns\n        -------\n        utils.Index\n            The values at the given position.\n        \"\"\"\n        if self._slice_data is not None and self._slice_data.start == 0:\n            # short-cut for identical mapping\n            # NOTE: we don't check for out-of-bound error\n            return index\n        elif index._slice_data is None:\n            # the provided index is not a slice\n            tensor = self.tousertensor()\n            index = index.tousertensor()\n            # TODO(Allen): Change F.gather_row to dgl operation\n            return Index(F.gather_row(tensor, index), self.dtype)\n        elif self._slice_data is None:\n            # the current index is not a slice but the provided is a slice\n            tensor = self.tousertensor()\n            index = index._slice_data\n            # TODO(Allen): Change F.narrow_row to dgl operation\n            return Index(\n                F.astype(\n                    F.narrow_row(tensor, index.start, index.stop),\n                    F.data_type_dict[self.dtype],\n                ),\n                self.dtype,\n            )\n        else:\n            # both self and index wrap a slice object, then return another\n            # Index wrapping a slice\n            start = self._slice_data.start\n            index = index._slice_data\n            return Index(\n                slice(start + index.start, start + index.stop), self.dtype\n            )\n\n    def set_items(self, index, value):\n        \"\"\"Set values at given positions of an Index. 
Set is not done in place,\n        instead, a new Index object will be returned.\n\n        Parameters\n        ----------\n        index: utils.Index\n            Positions to set values\n        value: int or utils.Index\n            Values to set. If value is an integer, then all positions are set\n            to the same value\n\n        Returns\n        -------\n        utils.Index\n            The new values.\n        \"\"\"\n        tensor = self.tousertensor()\n        index = index.tousertensor()\n        if isinstance(value, int):\n            value = F.full_1d(len(index), value, dtype=F.int64, ctx=F.cpu())\n        else:\n            value = value.tousertensor()\n        return Index(F.scatter_row(tensor, index, value), self.dtype)\n\n    def append_zeros(self, num):\n        \"\"\"Append zeros to an Index\n\n        Parameters\n        ----------\n        num: int\n            number of zeros to append\n        \"\"\"\n        if num == 0:\n            return self\n        new_items = F.zeros((num,), dtype=F.int64, ctx=F.cpu())\n        if len(self) == 0:\n            return Index(new_items, self.dtype)\n        else:\n            tensor = self.tousertensor()\n            tensor = F.cat((tensor, new_items), dim=0)\n            return Index(tensor, self.dtype)\n\n    def nonzero(self):\n        \"\"\"Return the nonzero positions\"\"\"\n        tensor = self.tousertensor()\n        mask = F.nonzero_1d(tensor != 0)\n        return Index(mask, self.dtype)\n\n    def has_nonzero(self):\n        \"\"\"Check if there is any nonzero value in this Index\"\"\"\n        tensor = self.tousertensor()\n        return F.sum(tensor, 0) > 0\n\n\ndef toindex(data, dtype=\"int64\"):\n    \"\"\"Convert the given data to Index object.\n\n    Parameters\n    ----------\n    data : index data\n        Data to create the index.\n\n    Returns\n    -------\n    Index\n        The index object.\n\n    See Also\n    --------\n    Index\n    \"\"\"\n    return data if 
isinstance(data, Index) else Index(data, dtype)\n\n\ndef zero_index(size, dtype=\"int64\"):\n    \"\"\"Create a index with provided size initialized to zero\n\n    Parameters\n    ----------\n    size: int\n    \"\"\"\n    return Index(\n        F.zeros((size,), dtype=F.data_type_dict[dtype], ctx=F.cpu()),\n        dtype=dtype,\n    )\n\n\ndef set_diff(ar1, ar2):\n    \"\"\"Find the set difference of two index arrays.\n    Return the unique values in ar1 that are not in ar2.\n\n    Parameters\n    ----------\n    ar1: utils.Index\n        Input index array.\n\n    ar2: utils.Index\n        Input comparison index array.\n\n    Returns\n    -------\n    setdiff:\n        Array of values in ar1 that are not in ar2.\n    \"\"\"\n    ar1_np = ar1.tonumpy()\n    ar2_np = ar2.tonumpy()\n    setdiff = np.setdiff1d(ar1_np, ar2_np)\n    setdiff = toindex(setdiff)\n    return setdiff\n\n\nclass LazyDict(Mapping):\n    \"\"\"A readonly dictionary that does not materialize the storage.\"\"\"\n\n    def __init__(self, fn, keys):\n        self._fn = fn\n        self._keys = keys\n\n    def __getitem__(self, key):\n        if key not in self._keys:\n            raise KeyError(key)\n        return self._fn(key)\n\n    def __contains__(self, key):\n        return key in self._keys\n\n    def __iter__(self):\n        return iter(self._keys)\n\n    def __len__(self):\n        return len(self._keys)\n\n    def keys(self):\n        return self._keys\n\n\nclass HybridDict(Mapping):\n    \"\"\"A readonly dictonary that merges several dict-like (python dict, LazyDict).\n\n    If there are duplicate keys, early keys have priority over latter ones.\n    \"\"\"\n\n    def __init__(self, *dict_like_list):\n        self._dict_like_list = dict_like_list\n        self._keys = set()\n        for obj in dict_like_list:\n            self._keys.update(obj.keys())\n\n    def keys(self):\n        return self._keys\n\n    def __getitem__(self, key):\n        for obj in self._dict_like_list:\n            
if key in obj:\n                return obj[key]\n        raise KeyError(key)\n\n    def __contains__(self, key):\n        return key in self.keys()\n\n    def __iter__(self):\n        return iter(self.keys())\n\n    def __len__(self):\n        return len(self.keys())\n\n\nclass ReadOnlyDict(Mapping):\n    \"\"\"A readonly dictionary wrapper.\"\"\"\n\n    def __init__(self, dict_like):\n        self._dict_like = dict_like\n\n    def keys(self):\n        return self._dict_like.keys()\n\n    def __getitem__(self, key):\n        return self._dict_like[key]\n\n    def __contains__(self, key):\n        return key in self._dict_like\n\n    def __iter__(self):\n        return iter(self._dict_like)\n\n    def __len__(self):\n        return len(self._dict_like)\n\n\ndef build_relabel_map(x, is_sorted=False):\n    \"\"\"Relabel the input ids to continuous ids that starts from zero.\n\n    Ids are assigned new ids according to their ascending order.\n\n    Examples\n    --------\n    >>> x = [1, 5, 3, 6]\n    >>> n2o, o2n = build_relabel_map(x)\n    >>> n2o\n    [1, 3, 5, 6]\n    >>> o2n\n    [n/a, 0, n/a, 1, n/a, 2, 3]\n\n    \"n/a\" will be filled with 0\n\n    Parameters\n    ----------\n    x : Index\n        The input ids.\n    is_sorted : bool, default=False\n        Whether the input has already been unique and sorted.\n\n    Returns\n    -------\n    new_to_old : tensor\n        The mapping from new id to old id.\n    old_to_new : tensor\n        The mapping from old id to new id. 
It is a vector of length MAX(x).\n        One can use advanced indexing to convert an old id tensor to a\n        new id tensor: new_id = old_to_new[old_id]\n    \"\"\"\n    x = x.tousertensor()\n    if not is_sorted:\n        unique_x, _ = F.sort_1d(F.unique(x))\n    else:\n        unique_x = x\n    map_len = int(F.asnumpy(F.max(unique_x, dim=0))) + 1\n    old_to_new = F.zeros((map_len,), dtype=F.int64, ctx=F.cpu())\n    old_to_new = F.scatter_row(old_to_new, unique_x, F.arange(0, len(unique_x)))\n    return unique_x, old_to_new\n\n\ndef build_relabel_dict(x):\n    \"\"\"Relabel the input ids to continuous ids that starts from zero.\n\n    The new id follows the order of the given node id list.\n\n    Parameters\n    ----------\n    x : list\n      The input ids.\n\n    Returns\n    -------\n    relabel_dict : dict\n      Dict from old id to new id.\n    \"\"\"\n    relabel_dict = {}\n    for i, v in enumerate(x):\n        relabel_dict[v] = i\n    return relabel_dict\n\n\nclass CtxCachedObject(object):\n    \"\"\"A wrapper to cache object generated by different context.\n\n    Note: such wrapper may incur significant overhead if the wrapped object is very light.\n\n    Parameters\n    ----------\n    generator : callable\n        A callable function that can create the object given ctx as the only argument.\n    \"\"\"\n\n    def __init__(self, generator):\n        self._generator = generator\n        self._ctx_dict = {}\n\n    def __call__(self, ctx):\n        if ctx not in self._ctx_dict:\n            self._ctx_dict[ctx] = self._generator(ctx)\n        return self._ctx_dict[ctx]\n\n\ndef cached_member(cache, prefix):\n    \"\"\"A member function decorator to memorize the result.\n\n    Note that the member function cannot support kwargs after being decorated.\n    The member function must be functional. Otherwise, the behavior is undefined.\n\n    Parameters\n    ----------\n    cache : str\n        The cache name. 
The cache should be a dictionary attribute\n        in the class object.\n    prefix : str\n        The key prefix to save the result of the function.\n    \"\"\"\n\n    def _creator(func):\n        @wraps(func)\n        def wrapper(self, *args, **kwargs):\n            dic = getattr(self, cache)\n            key = \"%s-%s-%s\" % (\n                prefix,\n                \"-\".join([str(a) for a in args]),\n                \"-\".join([str(k) + \":\" + str(v) for k, v in kwargs.items()]),\n            )\n            if key not in dic:\n                dic[key] = func(self, *args, **kwargs)\n            return dic[key]\n\n        return wrapper\n\n    return _creator\n\n\ndef is_dict_like(obj):\n    \"\"\"Return true if the object can be treated as a dictionary.\"\"\"\n    return isinstance(obj, Mapping)\n\n\ndef reorder(dict_like, index):\n    \"\"\"Reorder each column in the dict according to the index.\n\n    Parameters\n    ----------\n    dict_like : dict of tensors\n        The dict to be reordered.\n    index : dgl.utils.Index\n        The reorder index.\n    \"\"\"\n    new_dict = {}\n    for key, val in dict_like.items():\n        idx_ctx = index.tousertensor(F.context(val))\n        new_dict[key] = F.gather_row(val, idx_ctx)\n    return new_dict\n\n\ndef reorder_index(idx, order):\n    \"\"\"Reorder the idx according to the given order\n\n    Parameters\n    ----------\n    idx : utils.Index\n        The index to be reordered.\n    order : utils.Index\n        The order to follow.\n    \"\"\"\n    idx = idx.tousertensor()\n    order = order.tousertensor()\n    new_idx = F.gather_row(idx, order)\n    return toindex(new_idx)\n\n\ndef is_iterable(obj):\n    \"\"\"Return true if the object is an iterable.\"\"\"\n    return isinstance(obj, Iterable)\n\n\ndef to_dgl_context(ctx):\n    \"\"\"Convert a backend context to DGLContext\"\"\"\n    device_type = nd.DGLContext.STR2MASK[F.device_type(ctx)]\n    device_id = F.device_id(ctx)\n    return 
nd.DGLContext(device_type, device_id)\n\n\ndef to_nbits_int(tensor, nbits):\n    \"\"\"Change the dtype of integer tensor\n    The dtype of returned tensor uses nbits, nbits can only be 32 or 64\n    \"\"\"\n    assert nbits in (32, 64), \"nbits can either be 32 or 64\"\n    if nbits == 32:\n        return F.astype(tensor, F.int32)\n    else:\n        return F.astype(tensor, F.int64)\n\n\ndef make_invmap(array, use_numpy=True):\n    \"\"\"Find the unique elements of the array and return another array with indices\n    to the array of unique elements.\"\"\"\n    if use_numpy:\n        uniques = np.unique(array)\n    else:\n        uniques = list(set(array))\n    invmap = {x: i for i, x in enumerate(uniques)}\n    remapped = np.asarray([invmap[x] for x in array])\n    return uniques, invmap, remapped\n\n\ndef expand_as_pair(input_, g=None):\n    \"\"\"Return a pair of same element if the input is not a pair.\n\n    If the graph is a block, obtain the feature of destination nodes from the source nodes.\n\n    Parameters\n    ----------\n    input_ : Tensor, dict[str, Tensor], or their pairs\n        The input features\n    g : DGLGraph or None\n        The graph.\n\n        If None, skip checking if the graph is a block.\n\n    Returns\n    -------\n    tuple[Tensor, Tensor] or tuple[dict[str, Tensor], dict[str, Tensor]]\n        The features for input and output nodes\n    \"\"\"\n    if isinstance(input_, tuple):\n        return input_\n    elif g is not None and g.is_block:\n        if isinstance(input_, Mapping):\n            input_dst = {\n                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))\n                for k, v in input_.items()\n            }\n        else:\n            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())\n        return input_, input_dst\n    else:\n        return input_, input_\n\n\ndef check_eq_shape(input_):\n    \"\"\"If input_ is a pair of features, check if the feature shape of source\n    nodes is equal to the 
feature shape of destination nodes.\n    \"\"\"\n    srcdata, dstdata = expand_as_pair(input_)\n    src_feat_shape = tuple(F.shape(srcdata))[1:]\n    dst_feat_shape = tuple(F.shape(dstdata))[1:]\n    if src_feat_shape != dst_feat_shape:\n        raise DGLError(\n            \"The feature shape of source nodes: {} \\\n            should be equal to the feature shape of destination \\\n            nodes: {}.\".format(\n                src_feat_shape, dst_feat_shape\n            )\n        )\n\n\ndef retry_method_with_fix(fix_method):\n    \"\"\"Decorator that executes a fix method before retrying again when the decorated method\n    fails once with any exception.\n\n    If the decorated method fails again, the execution fails with that exception.\n\n    Notes\n    -----\n    This decorator only works on class methods, and the fix function must also be a class method.\n    It would not work on functions.\n\n    Parameters\n    ----------\n    fix_func : callable\n        The fix method to execute.  It should not accept any arguments.  
Its return values are\n        ignored.\n    \"\"\"\n\n    def _creator(func):\n        @wraps(func)\n        def wrapper(self, *args, **kwargs):\n            # pylint: disable=W0703,bare-except\n            try:\n                return func(self, *args, **kwargs)\n            except:\n                fix_method(self)\n                return func(self, *args, **kwargs)\n\n        return wrapper\n\n    return _creator\n\n\ndef group_as_dict(pairs):\n    \"\"\"Combines a list of key-value pairs to a dictionary of keys and value lists.\n\n    Does not require the pairs to be sorted by keys.\n\n    Parameters\n    ----------\n    pairs : iterable\n        Iterable of key-value pairs\n\n    Returns\n    -------\n    dict\n        The dictionary of keys and value lists.\n    \"\"\"\n    dic = defaultdict(list)\n    for key, value in pairs:\n        dic[key].append(value)\n    return dic\n\n\nclass FlattenedDict(object):\n    \"\"\"Iterates over each item in a dictionary of groups.\n\n    Parameters\n    ----------\n    groups : dict\n        The item groups.\n\n    Examples\n    --------\n    >>> groups = FlattenedDict({'a': [1, 3], 'b': [2, 5, 8], 'c': [7]})\n    >>> list(groups)\n    [('a', 1), ('a', 3), ('b', 2), ('b', 5), ('b', 8), ('c', 7)]\n    >>> groups[2]\n    ('b', 2)\n    >>> len(groups)\n    6\n    \"\"\"\n\n    def __init__(self, groups):\n        self._groups = groups\n        group_sizes = {k: len(v) for k, v in groups.items()}\n        self._group_keys, self._group_sizes = zip(*group_sizes.items())\n        self._group_offsets = np.insert(np.cumsum(self._group_sizes), 0, 0)\n        # TODO: this is faster (37s -> 21s per epoch compared to searchsorted in GCMC) but takes\n        # O(E) memory.\n        self._idx_to_group = np.zeros(self._group_offsets[-1], dtype=\"int32\")\n        for i in range(len(self._groups)):\n            self._idx_to_group[\n                self._group_offsets[i] : self._group_offsets[i + 1]\n            ] = i\n\n    def 
__len__(self):\n        \"\"\"Return the total number of items.\"\"\"\n        return self._group_offsets[-1]\n\n    def __iter__(self):\n        \"\"\"Return the iterator of all items with the key of its original group.\"\"\"\n        for i, k in enumerate(self._group_keys):\n            for j in range(self._group_sizes[i]):\n                yield k, self._groups[k][j]\n\n    def __getitem__(self, idx):\n        \"\"\"Return the item at the given position with the key of its original group.\"\"\"\n        i = self._idx_to_group[idx]\n        k = self._group_keys[i]\n        j = idx - self._group_offsets[i]\n        g = self._groups[k]\n        return k, g[j]\n\n\ndef maybe_flatten_dict(data):\n    \"\"\"Return a FlattenedDict if the input is a Mapping, or the data itself otherwise.\"\"\"\n    return FlattenedDict(data) if isinstance(data, Mapping) else data\n\n\ndef compensate(ids, origin_ids):\n    \"\"\"computing the compensate set of ids from origin_ids\n\n    Note: ids should be a subset of origin_ids.\n    Any of ids and origin_ids can be non-consecutive,\n    and origin_ids should be sorted.\n\n    Example:\n    >>> ids = th.Tensor([0, 2, 4])\n    >>> origin_ids = th.Tensor([0, 1, 2, 4, 5])\n    >>> compensate(ids, origin_ids)\n    th.Tensor([1, 5])\n    \"\"\"\n    # trick here, eid_0 or nid_0 can be 0.\n    mask = F.scatter_row(\n        origin_ids,\n        F.copy_to(F.tensor(0, dtype=F.int64), F.context(origin_ids)),\n        F.copy_to(\n            F.tensor(1, dtype=F.dtype(origin_ids)), F.context(origin_ids)\n        ),\n    )\n    mask = F.scatter_row(\n        mask, ids, F.full_1d(len(ids), 0, F.dtype(ids), F.context(ids))\n    )\n    return F.tensor(F.nonzero_1d(mask), dtype=F.dtype(ids))\n\n\ndef relabel(x):\n    \"\"\"Relabel the input ids to continuous ids that starts from zero.\n\n    Ids are assigned new ids according to their ascending order.\n\n    Examples\n    --------\n    >>> x = [1, 5, 3, 6]\n    >>> n2o, o2n = build_relabel_map(x)\n    
>>> n2o\n    [1, 3, 5, 6]\n    >>> o2n\n    [n/a, 0, n/a, 1, n/a, 2, 3]\n\n    \"n/a\" will be filled with 0\n\n    Parameters\n    ----------\n    x : Tensor\n        ID tensor.\n\n    Returns\n    -------\n    new_to_old : Tensor\n        The mapping from new id to old id.\n    old_to_new : Tensor\n        The mapping from old id to new id. It is a vector of length MAX(x).\n        One can use advanced indexing to convert an old id tensor to a\n        new id tensor: new_id = old_to_new[old_id]\n    \"\"\"\n    unique_x = F.unique(x)\n    map_len = F.as_scalar(F.max(unique_x, dim=0)) + 1\n    ctx = F.context(x)\n    dtype = F.dtype(x)\n    old_to_new = F.zeros((map_len,), dtype=dtype, ctx=ctx)\n    old_to_new = F.scatter_row(\n        old_to_new, unique_x, F.copy_to(F.arange(0, len(unique_x), dtype), ctx)\n    )\n    return unique_x, old_to_new\n\n\ndef extract_node_subframes(graph, nodes_or_device, store_ids=True):\n    \"\"\"Extract node features of the given nodes from :attr:`graph`\n    and return them in frames on the given device.\n\n    Note that this function does not perform actual tensor memory copy but using `Frame.subframe`\n    to get the features. 
If :attr:`nodes` is None, it performs a shallow copy of the\n    original node frames that only copies the dictionary structure but not the tensor\n    contents.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph to extract features from.\n    nodes : list[Tensor] or device or None\n        Node IDs or device.\n        If a list, the list length must be equal to the number of node types\n        in the graph.\n        If None, the whole frame is shallow-copied.\n    store_ids : bool\n        If True, the returned frames will store :attr:`nodes` in the ``dgl.NID`` field\n        unless it is None.\n\n    Returns\n    -------\n    list[Frame]\n        Extracted node frames.\n    \"\"\"\n    if nodes_or_device is None:\n        node_frames = [nf.clone() for nf in graph._node_frames]\n    elif is_listlike(nodes_or_device):\n        node_frames = []\n        for i, ind_nodes in enumerate(nodes_or_device):\n            subf = graph._node_frames[i].subframe(ind_nodes)\n            if store_ids:\n                subf[NID] = ind_nodes\n            node_frames.append(subf)\n    else:  # device object\n        node_frames = [nf.to(nodes_or_device) for nf in graph._node_frames]\n    return node_frames\n\n\ndef extract_node_subframes_for_block(graph, srcnodes, dstnodes):\n    \"\"\"Extract the input node features and output node features of the given nodes from\n    :attr:`graph` and return them in frames ready for a block.\n\n    Note that this function does not perform actual tensor memory copy but using `Frame.subframe`\n    to get the features. If :attr:`srcnodes` or :attr:`dstnodes` is None, it performs a\n    shallow copy of the original node frames that only copies the dictionary structure\n    but not the tensor contents.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph to extract features from.\n    srcnodes : list[Tensor]\n        Input node IDs. 
The list length must be equal to the number of node types\n        in the graph. The returned frames store the node IDs in the ``dgl.NID`` field.\n    dstnodes : list[Tensor]\n        Output node IDs. The list length must be equal to the number of node types\n        in the graph. The returned frames store the node IDs in the ``dgl.NID`` field.\n\n    Returns\n    -------\n    list[Frame]\n        Extracted node frames.\n    \"\"\"\n    node_frames = []\n    for i, ind_nodes in enumerate(srcnodes):\n        subf = graph._node_frames[i].subframe(ind_nodes)\n        subf[NID] = ind_nodes\n        node_frames.append(subf)\n    for i, ind_nodes in enumerate(dstnodes):\n        subf = graph._node_frames[i].subframe(ind_nodes)\n        subf[NID] = ind_nodes\n        node_frames.append(subf)\n    return node_frames\n\n\ndef extract_edge_subframes(graph, edges_or_device, store_ids=True):\n    \"\"\"Extract edge features of the given edges from :attr:`graph`\n    and return them in frames.\n\n    Note that this function does not perform actual tensor memory copy but using `Frame.subframe`\n    to get the features. 
If :attr:`edges` is None, it performs a shallow copy of the\n    original edge frames that only copies the dictionary structure but not the tensor\n    contents.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph to extract features from.\n    edges_or_device : list[Tensor] or device or None\n        Edge IDs.\n        If a list, the list length must be equal to the number of edge types\n        in the graph.\n        If None, the whole frame is shallow-copied.\n    store_ids : bool\n        If True, the returned frames will store :attr:`edges` in the ``dgl.EID`` field\n        unless it is None.\n\n    Returns\n    -------\n    list[Frame]\n        Extracted edge frames.\n    \"\"\"\n    if edges_or_device is None:\n        edge_frames = [nf.clone() for nf in graph._edge_frames]\n    elif is_listlike(edges_or_device):\n        edge_frames = []\n        for i, ind_edges in enumerate(edges_or_device):\n            subf = graph._edge_frames[i].subframe(ind_edges)\n            if store_ids:\n                subf[EID] = ind_edges\n            edge_frames.append(subf)\n    else:  # device object\n        edge_frames = [nf.to(edges_or_device) for nf in graph._edge_frames]\n    return edge_frames\n\n\ndef set_new_frames(graph, *, node_frames=None, edge_frames=None):\n    \"\"\"Set the node and edge frames of a given graph to new ones.\n\n    Parameters\n    ----------\n    graph : DGLGraph\n        The graph whose node and edge frames are to be updated.\n    node_frames : list[Frame], optional\n        New node frames.\n\n        Default is None, where the node frames are not updated.\n    edge_frames : list[Frame], optional\n        New edge frames\n\n        Default is None, where the edge frames are not updated.\n    \"\"\"\n    if node_frames is not None:\n        assert len(node_frames) == len(\n            graph.ntypes\n        ), \"[BUG] number of node frames different from number of node types\"\n        graph._node_frames = node_frames\n   
 if edge_frames is not None:\n        assert len(edge_frames) == len(\n            graph.etypes\n        ), \"[BUG] number of edge frames different from number of edge types\"\n        graph._edge_frames = edge_frames\n\n\ndef set_num_threads(num_threads):\n    \"\"\"Set the number of OMP threads in the process.\n\n    Parameters\n    ----------\n    num_threads : int\n        The number of OMP threads in the process.\n    \"\"\"\n    _CAPI_DGLSetOMPThreads(num_threads)\n\n\ndef get_num_threads():\n    \"\"\"Get the number of OMP threads in the process\"\"\"\n    return _CAPI_DGLGetOMPThreads()\n\n\ndef get_numa_nodes_cores():\n    \"\"\"Returns numa nodes info, format:\n    {<node_id>: [(<core_id>, [<sibling_thread_id_0>, <sibling_thread_id_1>, ...]), ...], ...}\n    E.g.: {0: [(0, [0, 4]), (1, [1, 5])], 1: [(2, [2, 6]), (3, [3, 7])]}\n\n    If not available, returns {}\n    \"\"\"\n    numa_node_paths = glob.glob(\"/sys/devices/system/node/node[0-9]*\")\n\n    if not numa_node_paths:\n        return {}\n\n    nodes = {}\n    try:\n        for node_path in numa_node_paths:\n            numa_node_id = int(os.path.basename(node_path)[4:])\n\n            thread_siblings = {}\n            for cpu_dir in glob.glob(os.path.join(node_path, \"cpu[0-9]*\")):\n                cpu_id = int(os.path.basename(cpu_dir)[3:])\n\n                with open(\n                    os.path.join(cpu_dir, \"topology\", \"core_id\")\n                ) as core_id_file:\n                    core_id = int(core_id_file.read().strip())\n                    if core_id in thread_siblings:\n                        thread_siblings[core_id].append(cpu_id)\n                    else:\n                        thread_siblings[core_id] = [cpu_id]\n\n            nodes[numa_node_id] = sorted(\n                [(k, sorted(v)) for k, v in thread_siblings.items()]\n            )\n\n    except (OSError, ValueError, IndexError, IOError):\n        dgl_warning(\"Failed to read NUMA info\")\n        return {}\n\n  
  return nodes\n\n\ndef alias_func(func):\n    \"\"\"Return an alias function with proper docstring.\"\"\"\n\n    @wraps(func)\n    def _fn(*args, **kwargs):\n        return func(*args, **kwargs)\n\n    _fn.__doc__ = \"\"\"Alias of :func:`dgl.{}`.\"\"\".format(func.__name__)\n    return _fn\n\n\ndef apply_each(data, fn, *args, **kwargs):\n    \"\"\"Apply a function to every element in a container.\n\n    If the input data is a list or any sequence other than a string, returns a list\n    whose elements are the same elements applied with the given function.\n\n    If the input data is a dict or any mapping, returns a dict whose keys are the same\n    and values are the elements applied with the given function.\n\n    The first argument of the function will be passed with the individual elements from\n    the input data, followed by the arguments in :attr:`args` and :attr:`kwargs`.\n\n    Parameters\n    ----------\n    data : any\n        Any object.\n    fn : callable\n        Any function.\n    args, kwargs :\n        Additional arguments and keyword-arguments passed to the function.\n\n    Examples\n    --------\n    Applying a ReLU function to a dictionary of tensors:\n\n    >>> h = {k: torch.randn(3) for k in ['A', 'B', 'C']}\n    >>> h = apply_each(h, torch.nn.functional.relu)\n    >>> assert all((v >= 0).all() for v in h.values())\n    \"\"\"\n    if isinstance(data, Mapping):\n        return {k: fn(v, *args, **kwargs) for k, v in data.items()}\n    elif is_listlike(data):\n        return [fn(v, *args, **kwargs) for v in data]\n    else:\n        return fn(data, *args, **kwargs)\n\n\ndef recursive_apply(data, fn, *args, **kwargs):\n    \"\"\"Recursively apply a function to every element in a container.\n\n    If the input data is a list or any sequence other than a string, returns a list\n    whose elements are the same elements applied with the given function.\n\n    If the input data is a dict or any mapping, returns a dict whose keys are the same\n    and 
values are the elements applied with the given function.\n\n    If the input data is a nested container, the result will have the same nested\n    structure where each element is transformed recursively.\n\n    The first argument of the function will be passed with the individual elements from\n    the input data, followed by the arguments in :attr:`args` and :attr:`kwargs`.\n\n    Parameters\n    ----------\n    data : any\n        Any object.\n    fn : callable\n        Any function.\n    args, kwargs :\n        Additional arguments and keyword-arguments passed to the function.\n\n    Examples\n    --------\n    Applying a ReLU function to a dictionary of tensors:\n\n    >>> h = {k: torch.randn(3) for k in ['A', 'B', 'C']}\n    >>> h = recursive_apply(h, torch.nn.functional.relu)\n    >>> assert all((v >= 0).all() for v in h.values())\n    \"\"\"\n    if isinstance(data, Mapping):\n        return {\n            k: recursive_apply(v, fn, *args, **kwargs) for k, v in data.items()\n        }\n    elif isinstance(data, tuple):\n        return tuple(recursive_apply(v, fn, *args, **kwargs) for v in data)\n    elif is_listlike(data):\n        return [recursive_apply(v, fn, *args, **kwargs) for v in data]\n    else:\n        return fn(data, *args, **kwargs)\n\n\ndef recursive_apply_pair(data1, data2, fn, *args, **kwargs):\n    \"\"\"Recursively apply a function to every pair of elements in two containers with the\n    same nested structure.\n    \"\"\"\n    if isinstance(data1, Mapping) and isinstance(data2, Mapping):\n        return {\n            k: recursive_apply_pair(data1[k], data2[k], fn, *args, **kwargs)\n            for k in data1.keys()\n        }\n    elif isinstance(data1, tuple) and isinstance(data2, tuple):\n        return tuple(\n            recursive_apply_pair(x, y, fn, *args, **kwargs)\n            for x, y in zip(data1, data2)\n        )\n    elif is_listlike(data1) and is_listlike(data2):\n        return [\n            recursive_apply_pair(x, y, fn, 
*args, **kwargs)\n            for x, y in zip(data1, data2)\n        ]\n    else:\n        return fn(data1, data2, *args, **kwargs)\n\n\ndef context_of(data):\n    \"\"\"Return the device of the data which can be either a tensor or a list/dict of tensors.\"\"\"\n    if isinstance(data, Mapping):\n        return F.context(next(iter(data.values())))\n    elif is_listlike(data):\n        return F.context(next(iter(data)))\n    else:\n        return F.context(data)\n\n\ndef dtype_of(data):\n    \"\"\"Return the dtype of the data which can be either a tensor or a dict of tensors.\"\"\"\n    return F.dtype(\n        next(iter(data.values())) if isinstance(data, Mapping) else data\n    )\n\n\n_init_api(\"dgl.utils.internal\")\n"
  },
  {
    "path": "python/dgl/utils/pin_memory.py",
    "content": "\"\"\"Utility functions related to pinned memory tensors.\"\"\"\n\nfrom .. import backend as F\nfrom .._ffi.function import _init_api\nfrom ..base import DGLError\n\n\ndef pin_memory_inplace(tensor):\n    \"\"\"Register the tensor into pinned memory in-place (i.e. without copying).\n    Users are required to save the returned dgl.ndarray object to avoid being unpinned.\n\n    Parameters\n    ----------\n    tensor : Tensor\n        The tensor to be pinned.\n\n    Returns\n    -------\n    dgl.ndarray\n        The dgl.ndarray object that holds the pinning status and shares the same\n        underlying data with the tensor.\n    \"\"\"\n    if F.backend_name in [\"mxnet\", \"tensorflow\"]:\n        raise DGLError(\n            \"The {} backend does not support pinning \"\n            \"tensors in-place.\".format(F.backend_name)\n        )\n\n    # needs to be writable to allow in-place modification\n    try:\n        nd_array = F.zerocopy_to_dgl_ndarray_for_write(tensor)\n        nd_array.pin_memory_()\n        return nd_array\n    except Exception as e:\n        raise DGLError(\"Failed to pin memory in-place due to: {}\".format(e))\n\n\ndef gather_pinned_tensor_rows(tensor, rows):\n    \"\"\"Directly gather rows from a CPU tensor given an indices array on CUDA devices,\n    and returns the result on the same CUDA device without copying.\n\n    Parameters\n    ----------\n    tensor : Tensor\n        The tensor.  Must be in pinned memory.\n    rows : Tensor\n        The rows to gather.  
Must be a CUDA tensor.\n\n    Returns\n    -------\n    Tensor\n        The result with the same device as :attr:`rows`.\n    \"\"\"\n    return F.from_dgl_nd(\n        _CAPI_DGLIndexSelectCPUFromGPU(F.to_dgl_nd(tensor), F.to_dgl_nd(rows))\n    )\n\n\ndef scatter_pinned_tensor_rows(dest, rows, source):\n    \"\"\"Directly scatter rows from a GPU tensor given an indices array on CUDA devices,\n    to a pinned tensor on the CPU.\n\n    Parameters\n    ----------\n    dest : Tensor\n        The tensor on the CPU to scatter rows to. Must be in pinned memory.\n    rows : Tensor\n        The rows to scatter. Must be a CUDA tensor with unique entries.\n    source : Tensor\n        The tensor on the GPU to scatter rows from.\n    \"\"\"\n    _CAPI_DGLIndexScatterGPUToCPU(\n        F.to_dgl_nd(dest), F.to_dgl_nd(rows), F.to_dgl_nd(source)\n    )\n\n\n_init_api(\"dgl.ndarray.uvm\", __name__)\n"
  },
  {
    "path": "python/dgl/utils/shared_mem.py",
    "content": "\"\"\"Shared memory utilities.\n\nFor compatibility with older code that uses ``dgl.utils.shared_mem`` namespace; the\ncontent has been moved to ``dgl.ndarray`` module.\n\"\"\"\nfrom ..ndarray import (  # pylint: disable=unused-import\n    create_shared_mem_array,\n    get_shared_mem_array,\n)\n"
  },
  {
    "path": "python/dgl/view.py",
    "content": "\"\"\"Views of DGLGraph.\"\"\"\nfrom __future__ import absolute_import\n\nfrom collections import defaultdict, namedtuple\nfrom collections.abc import MutableMapping\n\nfrom . import backend as F\nfrom .base import ALL, DGLError\nfrom .frame import LazyFeature\n\nNodeSpace = namedtuple(\"NodeSpace\", [\"data\"])\nEdgeSpace = namedtuple(\"EdgeSpace\", [\"data\"])\n\n\nclass HeteroNodeView(object):\n    \"\"\"A NodeView class to act as G.nodes for a DGLGraph.\"\"\"\n\n    __slots__ = [\"_graph\", \"_typeid_getter\"]\n\n    def __init__(self, graph, typeid_getter):\n        self._graph = graph\n        self._typeid_getter = typeid_getter\n\n    def __getitem__(self, key):\n        if isinstance(key, slice):\n            # slice\n            if not (\n                key.start is None and key.stop is None and key.step is None\n            ):\n                raise DGLError('Currently only full slice \":\" is supported')\n            nodes = ALL\n            ntype = None\n        elif isinstance(key, tuple):\n            nodes, ntype = key\n        elif key is None or isinstance(key, str):\n            nodes = ALL\n            ntype = key\n        else:\n            nodes = key\n            ntype = None\n        ntid = self._typeid_getter(ntype)\n        return NodeSpace(\n            data=HeteroNodeDataView(self._graph, ntype, ntid, nodes)\n        )\n\n    def __call__(self, ntype=None):\n        \"\"\"Return the nodes.\"\"\"\n        ntid = self._typeid_getter(ntype)\n        ret = F.arange(\n            0,\n            self._graph._graph.num_nodes(ntid),\n            dtype=self._graph.idtype,\n            ctx=self._graph.device,\n        )\n        return ret\n\n\nclass HeteroNodeDataView(MutableMapping):\n    \"\"\"The data view class when G.ndata[ntype] is called.\"\"\"\n\n    __slots__ = [\"_graph\", \"_ntype\", \"_ntid\", \"_nodes\"]\n\n    def __init__(self, graph, ntype, ntid, nodes):\n        self._graph = graph\n        self._ntype = ntype\n  
      self._ntid = ntid\n        self._nodes = nodes\n\n    def __getitem__(self, key):\n        if isinstance(self._ntype, list):\n            ret = {}\n            for (i, ntype) in enumerate(self._ntype):\n                value = self._graph._get_n_repr(self._ntid[i], self._nodes).get(\n                    key, None\n                )\n                if value is not None:\n                    ret[ntype] = value\n            return ret\n        else:\n            return self._graph._get_n_repr(self._ntid, self._nodes)[key]\n\n    def __setitem__(self, key, val):\n        if isinstance(val, LazyFeature):\n            self._graph._node_frames[self._ntid][key] = val\n        elif isinstance(self._ntype, list):\n            assert isinstance(val, dict), (\n                \"Current HeteroNodeDataView has multiple node types, \"\n                \"please passing the node type and the corresponding data through a dict.\"\n            )\n\n            for (ntype, data) in val.items():\n                ntid = self._graph.get_ntype_id(ntype)\n                self._graph._set_n_repr(ntid, self._nodes, {key: data})\n        else:\n            assert isinstance(val, dict) is False, (\n                \"The HeteroNodeDataView has only one node type. 
\"\n                \"please pass a tensor directly\"\n            )\n            self._graph._set_n_repr(self._ntid, self._nodes, {key: val})\n\n    def __delitem__(self, key):\n        if isinstance(self._ntype, list):\n            for ntid in self._ntid:\n                if self._graph._get_n_repr(ntid, ALL).get(key, None) is None:\n                    continue\n                self._graph._pop_n_repr(ntid, key)\n        else:\n            self._graph._pop_n_repr(self._ntid, key)\n\n    def _transpose(self, as_dict=False):\n        if isinstance(self._ntype, list):\n            ret = defaultdict(dict)\n            for (i, ntype) in enumerate(self._ntype):\n                data = self._graph._get_n_repr(self._ntid[i], self._nodes)\n                for key in self._graph._node_frames[self._ntid[i]]:\n                    ret[key][ntype] = data[key]\n        else:\n            ret = self._graph._get_n_repr(self._ntid, self._nodes)\n            if as_dict:\n                ret = {\n                    key: ret[key]\n                    for key in self._graph._node_frames[self._ntid]\n                }\n        return ret\n\n    def __len__(self):\n        return len(self._transpose())\n\n    def __iter__(self):\n        return iter(self._transpose())\n\n    def keys(self):\n        return self._transpose().keys()\n\n    def values(self):\n        return self._transpose().values()\n\n    def __repr__(self):\n        return repr(self._transpose(as_dict=True))\n\n\nclass HeteroEdgeView(object):\n    \"\"\"A EdgeView class to act as G.edges for a DGLGraph.\"\"\"\n\n    __slots__ = [\"_graph\"]\n\n    def __init__(self, graph):\n        self._graph = graph\n\n    def __getitem__(self, key):\n        if isinstance(key, slice):\n            # slice\n            if not (\n                key.start is None and key.stop is None and key.step is None\n            ):\n                raise DGLError('Currently only full slice \":\" is supported')\n            edges = ALL\n         
   etype = None\n        elif key is None:\n            edges = ALL\n            etype = None\n        elif isinstance(key, tuple):\n            if len(key) == 3:\n                edges = ALL\n                etype = key\n            else:\n                edges = key\n                etype = None\n        elif isinstance(key, str):\n            edges = ALL\n            etype = key\n        else:\n            edges = key\n            etype = None\n        return EdgeSpace(data=HeteroEdgeDataView(self._graph, etype, edges))\n\n    def __call__(self, *args, **kwargs):\n        \"\"\"Return all the edges.\"\"\"\n        return self._graph.all_edges(*args, **kwargs)\n\n\nclass HeteroEdgeDataView(MutableMapping):\n    \"\"\"The data view class when G.edata[etype] is called.\"\"\"\n\n    __slots__ = [\"_graph\", \"_etype\", \"_etid\", \"_edges\"]\n\n    def __init__(self, graph, etype, edges):\n        self._graph = graph\n        self._etype = etype\n        self._etid = (\n            [self._graph.get_etype_id(t) for t in etype]\n            if isinstance(etype, list)\n            else self._graph.get_etype_id(etype)\n        )\n        self._edges = edges\n\n    def __getitem__(self, key):\n        if isinstance(self._etype, list):\n            ret = {}\n            for (i, etype) in enumerate(self._etype):\n                value = self._graph._get_e_repr(self._etid[i], self._edges).get(\n                    key, None\n                )\n                if value is not None:\n                    ret[etype] = value\n            return ret\n        else:\n            return self._graph._get_e_repr(self._etid, self._edges)[key]\n\n    def __setitem__(self, key, val):\n        if isinstance(val, LazyFeature):\n            self._graph._edge_frames[self._etid][key] = val\n        elif isinstance(self._etype, list):\n            assert isinstance(val, dict), (\n                \"Current HeteroEdgeDataView has multiple edge types, \"\n                \"please pass the edge 
type and the corresponding data through a dict.\"\n            )\n\n            for (etype, data) in val.items():\n                etid = self._graph.get_etype_id(etype)\n                self._graph._set_e_repr(etid, self._edges, {key: data})\n        else:\n            assert isinstance(val, dict) is False, (\n                \"The HeteroEdgeDataView has only one edge type. \"\n                \"please pass a tensor directly\"\n            )\n            self._graph._set_e_repr(self._etid, self._edges, {key: val})\n\n    def __delitem__(self, key):\n        if isinstance(self._etype, list):\n            for etid in self._etid:\n                if self._graph._get_e_repr(etid, ALL).get(key, None) is None:\n                    continue\n                self._graph._pop_e_repr(etid, key)\n        else:\n            self._graph._pop_e_repr(self._etid, key)\n\n    def _transpose(self, as_dict=False):\n        if isinstance(self._etype, list):\n            ret = defaultdict(dict)\n            for (i, etype) in enumerate(self._etype):\n                data = self._graph._get_e_repr(self._etid[i], self._edges)\n                for key in self._graph._edge_frames[self._etid[i]]:\n                    ret[key][etype] = data[key]\n        else:\n            ret = self._graph._get_e_repr(self._etid, self._edges)\n            if as_dict:\n                ret = {\n                    key: ret[key]\n                    for key in self._graph._edge_frames[self._etid]\n                }\n        return ret\n\n    def __len__(self):\n        return len(self._transpose())\n\n    def __iter__(self):\n        return iter(self._transpose())\n\n    def keys(self):\n        return self._transpose().keys()\n\n    def values(self):\n        return self._transpose().values()\n\n    def __repr__(self):\n        return repr(self._transpose(as_dict=True))\n"
  },
  {
    "path": "python/setup.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\nimport glob\nimport os\nimport shutil\nimport sys\nimport sysconfig\n\nfrom setuptools import find_packages, setup\nfrom setuptools.dist import Distribution\nfrom setuptools.extension import Extension\n\n\nclass BinaryDistribution(Distribution):\n    def has_ext_modules(self):\n        return True\n\n\nCURRENT_DIR = os.path.dirname(__file__)\n\n\ndef get_lib_path():\n    \"\"\"Get library path, name and version\"\"\"\n    # We can not import `libinfo.py` in setup.py directly since __init__.py\n    # Will be invoked which introduces dependences\n    libinfo_py = os.path.join(CURRENT_DIR, \"./dgl/_ffi/libinfo.py\")\n    libinfo = {\"__file__\": libinfo_py}\n    exec(\n        compile(open(libinfo_py, \"rb\").read(), libinfo_py, \"exec\"),\n        libinfo,\n        libinfo,\n    )\n    version = libinfo[\"__version__\"]\n\n    lib_path = libinfo[\"find_lib_path\"]()\n    libs = [lib_path[0]]\n\n    return libs, version\n\n\ndef get_lib_pattern(lib_name):\n    if sys.platform.startswith(\"linux\"):\n        lib_pattern = f\"lib{lib_name}_*.so\"\n    elif sys.platform.startswith(\"darwin\"):\n        lib_pattern = f\"lib{lib_name}_*.dylib\"\n    elif sys.platform.startswith(\"win\"):\n        lib_pattern = f\"{lib_name}_*.dll\"\n    else:\n        raise NotImplementedError(\"Unsupported system: %s\" % sys.platform)\n    return lib_pattern\n\n\nLIBS, VERSION = get_lib_path()\nBACKENDS = [\"pytorch\"]\n\n\ndef remove_lib(lib_name):\n    for lib_path in glob.glob(\n        os.path.join(CURRENT_DIR, \"dgl\", lib_name, get_lib_pattern(lib_name))\n    ):\n        try:\n            os.remove(lib_path)\n        except BaseException:\n            pass\n\n\ndef cleanup():\n    # Wheel cleanup\n    try:\n        os.remove(\"MANIFEST.in\")\n    except BaseException:\n        pass\n\n    for path in LIBS:\n        _, libname = os.path.split(path)\n        try:\n            os.remove(os.path.join(\"dgl\", libname))\n    
    except BaseException:\n            pass\n    for backend in BACKENDS:\n        remove_lib(\"tensoradapter\")\n\n        if backend == \"pytorch\":\n            remove_lib(\"dgl_sparse\")\n            remove_lib(\"graphbolt\")\n\n    # Remove build artifacts.\n    dir_to_remove = [\"build\", \"dgl.egg-info\"]\n    for dir_ in dir_to_remove:\n        print(f\"Removing {dir_}\")\n        if os.path.isdir(dir_):\n            shutil.rmtree(dir_)\n\n\ndef config_cython():\n    \"\"\"Try to configure cython and return cython configuration\"\"\"\n    if sys.platform.startswith(\"win\"):\n        print(\n            \"WARNING: Cython is not supported on Windows, will compile without cython module\"\n        )\n        return []\n    sys_cflags = sysconfig.get_config_var(\"CFLAGS\")\n\n    if \"i386\" in sys_cflags and \"x86_64\" in sys_cflags:\n        print(\n            \"WARNING: Cython library may not be compiled correctly with both i386 and x64\"\n        )\n        return []\n    try:\n        from Cython.Build import cythonize\n\n        # from setuptools.extension import Extension\n        if sys.version_info >= (3, 0):\n            subdir = \"_cy3\"\n        else:\n            subdir = \"_cy2\"\n        ret = []\n        path = \"dgl/_ffi/_cython\"\n        library_dirs = [\"dgl\", \"../build/Release\", \"../build\"]\n        libraries = [\"dgl\"]\n        for fn in os.listdir(path):\n            if not fn.endswith(\".pyx\"):\n                continue\n            ret.append(\n                Extension(\n                    \"dgl._ffi.%s.%s\" % (subdir, fn[:-4]),\n                    [\"dgl/_ffi/_cython/%s\" % fn],\n                    include_dirs=[\n                        \"../include/\",\n                        \"../third_party/dmlc-core/include\",\n                        \"../third_party/dlpack/include\",\n                    ],\n                    library_dirs=library_dirs,\n                    libraries=libraries,\n                    # Crashes 
without this flag with GCC 5.3.1\n                    extra_compile_args=[\"-std=c++17\"],\n                    language=\"c++\",\n                )\n            )\n        return cythonize(\n            ret, force=True, compiler_directives={\"language_level\": \"3\"}\n        )\n    except ImportError:\n        print(\n            \"WARNING: Cython is not installed, will compile without cython module\"\n        )\n        return []\n\n\ndef copy_lib(lib_name, backend=\"\"):\n    for lib_path in glob.glob(\n        os.path.join(dir_, lib_name, backend, get_lib_pattern(lib_name))\n    ):\n        lib_file_name = os.path.basename(lib_path)\n        dst_dir_ = os.path.join(CURRENT_DIR, \"dgl\", lib_name, backend)\n        os.makedirs(\n            dst_dir_,\n            exist_ok=True,\n        )\n        shutil.copy(\n            os.path.join(dir_, lib_name, backend, lib_file_name),\n            dst_dir_,\n        )\n        fo.write(f\"include dgl/{lib_name}/{backend}/{lib_file_name}\\n\")\n\n\ninclude_libs = False\nwheel_include_libs = False\nif \"bdist_wheel\" in sys.argv or os.getenv(\"CONDA_BUILD\"):\n    wheel_include_libs = True\nelif \"clean\" in sys.argv:\n    cleanup()\nelse:\n    include_libs = True\n\nsetup_kwargs = {}\n\n# For bdist_wheel only\nif wheel_include_libs:\n    with open(\"MANIFEST.in\", \"w\") as fo:\n        for path in LIBS:\n            shutil.copy(path, os.path.join(CURRENT_DIR, \"dgl\"))\n            dir_, libname = os.path.split(path)\n            fo.write(\"include dgl/%s\\n\" % libname)\n\n        for backend in BACKENDS:\n            copy_lib(\"tensoradapter\", backend)\n            if backend == \"pytorch\":\n                copy_lib(\"dgl_sparse\")\n                copy_lib(\"graphbolt\")\n    setup_kwargs = {\"include_package_data\": True}\n\n\ndef get_lib_file_path(lib_name, backend=\"\"):\n    return (\n        f\"dgl/{lib_name}/{backend}\",\n        glob.glob(\n            os.path.join(\n                
os.path.dirname(os.path.relpath(path, CURRENT_DIR)),\n                lib_name,\n                backend,\n                get_lib_pattern(lib_name),\n            )\n        ),\n    )\n\n\n# For source tree setup\n# Conda build also includes the binary library\nif include_libs:\n    rpath = [os.path.relpath(path, CURRENT_DIR) for path in LIBS]\n    data_files = [(\"dgl\", rpath)]\n    for path in LIBS:\n        for backend in BACKENDS:\n            data_files.append(get_lib_file_path(\"tensoradapter\", backend))\n            if backend == \"pytorch\":\n                data_files.append(get_lib_file_path(\"dgl_sparse\"))\n                data_files.append(get_lib_file_path(\"graphbolt\"))\n    setup_kwargs = {\"include_package_data\": True, \"data_files\": data_files}\n\n# Configure dependencies.\ninstall_requires = [\n    \"networkx>=2.1\",\n    \"numpy>=1.14.0\",\n    \"packaging\",\n    \"pandas\",\n    \"psutil>=5.8.0\",\n    \"pydantic>=2.0\",\n    \"pyyaml\",\n    \"requests>=2.19.0\",\n    \"scipy>=1.1.0\",\n    \"tqdm\",\n]\n\nsetup(\n    name=\"dgl\" + os.getenv(\"DGL_PACKAGE_SUFFIX\", \"\"),\n    version=VERSION,\n    description=\"Deep Graph Library\",\n    zip_safe=False,\n    maintainer=\"DGL Team\",\n    maintainer_email=\"wmjlyjemaine@gmail.com\",\n    packages=find_packages(),\n    install_requires=install_requires,\n    url=\"https://github.com/dmlc/dgl\",\n    distclass=BinaryDistribution,\n    ext_modules=config_cython(),\n    classifiers=[\n        \"Development Status :: 3 - Alpha\",\n        \"Programming Language :: Python :: 3\",\n        \"License :: OSI Approved :: Apache Software License\",\n    ],\n    license=\"APACHE\",\n    **setup_kwargs,\n)\n\nif wheel_include_libs:\n    cleanup()\n"
  },
  {
    "path": "python/update_version.py",
    "content": "\"\"\"\nThis is the global script that set the version information of DGL.\nThis script runs and update all the locations that related to versions\nList of affected files:\n- dgl-root/python/dgl/_ffi/libinfo.py\n- dgl-root/include/dgl/runtime/c_runtime_api.h\n- dgl-root/conda/dgl/meta.yaml\n\"\"\"\nimport os\nimport re\n\n# current version\n# We use the version of the incoming release for code\n# that is under development\n# The environment variable DGL_PRERELEASE is the prerelase suffix\n# (usually \"aYYMMDD\")\n# The environment variable DGL_VERSION_SUFFIX is the local version label\n# suffix for indicating CPU and CUDA versions as in PEP 440 (e.g. \"+cu102\")\n__version__ = \"2.5\" + os.getenv(\"DGL_PRERELEASE\", \"\")\n__version__ += os.getenv(\"DGL_VERSION_SUFFIX\", \"\")\nprint(__version__)\n\n# Implementations\n\n\ndef update(file_name, pattern, repl):\n    update = []\n    hit_counter = 0\n    need_update = False\n    for l in open(file_name):\n        result = re.findall(pattern, l)\n        if result:\n            assert len(result) == 1\n            hit_counter += 1\n            if result[0] != repl:\n                l = re.sub(pattern, repl, l)\n                need_update = True\n                print(\"%s: %s->%s\" % (file_name, result[0], repl))\n            else:\n                print(\"%s: version is already %s\" % (file_name, repl))\n\n        update.append(l)\n    if hit_counter != 1:\n        raise RuntimeError(\"Cannot find version in %s\" % file_name)\n\n    if need_update:\n        with open(file_name, \"w\") as output_file:\n            for l in update:\n                output_file.write(l)\n\n\ndef main():\n    curr_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))\n    proj_root = os.path.abspath(os.path.join(curr_dir, \"..\"))\n    # python path\n    update(\n        os.path.join(proj_root, \"python\", \"dgl\", \"_ffi\", \"libinfo.py\"),\n        r\"(?<=__version__ = \\\")[.0-9a-z+_]+\",\n        
__version__,\n    )\n    # C++ header\n    update(\n        os.path.join(proj_root, \"include\", \"dgl\", \"runtime\", \"c_runtime_api.h\"),\n        '(?<=DGL_VERSION \")[.0-9a-z+_]+',\n        __version__,\n    )\n    # conda\n    for path in [\"dgl\"]:\n        update(\n            os.path.join(proj_root, \"conda\", path, \"meta.yaml\"),\n            \"(?<=version: )[.0-9a-z+_]+\",\n            __version__,\n        )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "readthedocs.yml",
    "content": "build:\n  image: latest\n\nformats: []\n\npython:\n  version: 3.6\n  use_system_site_packages: true\n  setup_py_install: false\n"
  },
  {
    "path": "script/build_dgl.sh",
    "content": "#!/bin/bash\n\nset -e\n\nusage() {\ncat << EOF\nusage: bash $0 OPTIONS\nexamples:\n  Start a CPU only build: bash $0 -c\n  Start a CUDA build: bash $0 -g\n  Build incrementally: bash $0\n  Remove all intermediate output and restart a CPU only build: bash $0 -c -r\n  Build with extra cmake arguments: bash $0 -c -e '-DBUILD_TORCH=ON'\n\nBuild DGL. By default, build incrementally on top of the current state.\n\nOPTIONS:\n  -h           Show this message.\n  -c           Restart CPU only build.\n  -e           Extra arguments of cmake.\n  -g           Restart CUDA build.\n  -r           Remove all intermediate output.\n  -t           Type of the build: dev, dogfood or release (default: dev).\nEOF\n}\n\n# Parse flags.\nwhile getopts \"ce:ghrt:\" flag; do\n  if [[ ${flag} == \"c\" ]]; then\n    cuda=\"OFF\"\n  elif [[ ${flag} == \"e\" ]]; then\n    extra_args=${OPTARG}\n  elif [[ ${flag} == \"g\" ]]; then\n    cuda=\"ON\"\n  elif [[ ${flag} == \"r\" ]]; then\n    remove=\"YES\"\n  elif [[ ${flag} == \"t\" ]]; then\n    build_type=${OPTARG}\n  elif [[ ${flag} == \"h\" ]]; then\n    usage\n    exit 0\n  else\n    usage\n    exit 1\n  fi\ndone\n\nif [[ -z ${DGL_HOME} ]]; then\n  echo \"ERROR: Please make sure environment variable DGL_HOME is set correctly.\"\n  exit 1\nfi\n\nif [[ ! 
${PWD} == ${DGL_HOME} ]]; then\n  echo \"ERROR: This script only works properly from DGL root directory.\"\n  echo \" Current: ${PWD}\"\n  echo \"DGL_HOME: ${DGL_HOME}\"\n  exit 1\nfi\n\nif [[ ${remove} == \"YES\" ]]; then\n  rm -rf build\n  rm -rf graphbolt/build\n  rm -rf dgl_sparse/build\n  rm -rf tensoradapter/pytorch/build\nfi\n\nif [[ -z ${build_type} ]]; then\n  build_type=\"dev\"\nfi\n\nif [[ -z ${cuda} ]]; then\n  if [[ -d build ]]; then\n    cd build\n  else\n    echo \"ERROR: No existing build status found, unable to build incrementally.\"\n    usage\n    exit 1\n  fi\nelse\n  mkdir -p build\n  cd build\n  cmake -DBUILD_TYPE=${build_type} -DUSE_CUDA=${cuda} ${extra_args} ..\nfi\n\nif [[ ${PWD} == \"${DGL_HOME}/build\" ]]; then\n  make -j\nelse\n  echo \"ERROR: unexpected working directory.\"\n  echo \" Current: ${PWD}\"\n  echo \"Expected: ${DGL_HOME}/build\"\nfi\nexit 0\n"
  },
  {
    "path": "script/build_doc.sh",
    "content": "#!/bin/bash\n\nset -e\n\nusage() {\ncat << EOF\nusage: bash $0 OPTIONS\nexamples:\n  Build doc with PyTorch-backend: bash $0 -p\n  Build doc with MXNet-backend: bash $0 -m\n  Build doc with TensorFlow-backend: bash $0 -t\n  Build incrementally with PyTorch-backend: bash $0\n  Remove all outputs and restart a PyTorch build: bash $0 -p -r\n\nBuild DGL documentation. By default, build incrementally on top of the current state.\n\nOPTIONS:\n  -h           Show this message.\n  -p           Build doc with PyTorch backend.\n  -m           Build doc with MXNet backend.\n  -t           Build doc with TensorFlow backend.\n  -r           Remove all outputs.\nEOF\n}\n\nbackend=\"pytorch\"\n\n# Parse flags.\nwhile getopts \"hpmtr\" flag; do\n  if [[ ${flag} == \"p\" ]]; then\n    backend=\"pytorch\"\n  elif [[ ${flag} == \"m\" ]]; then\n    backend=\"mxnet\"\n  elif [[ ${flag} == \"t\" ]]; then\n    backend=\"tensorflow\"\n  elif [[ ${flag} == \"r\" ]]; then\n    remove=\"YES\"\n  elif [[ ${flag} == \"h\" ]]; then\n    usage\n    exit 0\n  else\n    usage\n    exit 1\n  fi\ndone\n\nif [[ -z ${DGL_HOME} ]]; then\n  echo \"ERROR: Please make sure environment variable DGL_HOME is set correctly.\"\n  exit 1\nfi\n\nif [[ ! ${PWD} == ${DGL_HOME} ]]; then\n  echo \"ERROR: This script only works properly from DGL root directory.\"\n  echo \" Current: ${PWD}\"\n  echo \"DGL_HOME: ${DGL_HOME}\"\n  exit 1\nfi\n\ncd ${DGL_HOME}/docs\n\nif [[ ${remove} == \"YES\" ]]; then\n  bash clean.sh\nfi\n\nexport DGLBACKEND=$backend\nexport DGL_LIBRARY_PATH=${DGL_HOME}/build\nexport PYTHONPATH=${DGL_HOME}/python:$PYTHONPATH\n\nmake $backend\n\nexit 0\n"
  },
  {
    "path": "script/create_dev_conda_env.sh",
    "content": "#!/bin/bash\n\nreadonly CUDA_VERSIONS=\"11.8,12.1,12.4\"\nreadonly TORCH_VERSION=\"2.1.0\"\nreadonly PYTHON_VERSION=\"3.10\"\n\nusage() {\ncat << EOF\nusage: bash $0 OPTIONS\nexamples:\n  bash $0 -c\n  bash $0 -g 12.1\n  bash $0 -g 12.1 -p 3.10\n  bash $0 -g 12.1 -p 3.10 -t 2.1.0\n  bash $0 -c -n dgl-dev-cpu\n\nCreate a developement environment for DGL developers.\n\nOPTIONS:\n  -h           Show this message.\n  -c           Create dev environment in CPU mode.\n  -d           Only display environment YAML file instead of creating it.\n  -f           Force creation of environment (removing a previously existing \n               environment of the same name).\n  -g           Create dev environment in GPU mode with specified CUDA version,\n               supported: ${CUDA_VERSIONS}.\n  -n           Specify the name of the environment.\n  -o           Save environment YAML file to specified path.\n  -p           Create dev environment based on specified python version.\n  -s           Run silently which indicates always 'yes' for any confirmation.\n  -t           Create dev environment based on specified PyTorch version such\n               as '2.0.0'.\nEOF\n}\n\nvalidate() {\n  values=$(echo \"$1\" | tr \",\" \"\\n\")\n  for value in ${values}\n  do\n    if [[ \"${value}\" == $2 ]]; then\n      return 0\n    fi\n  done\n  return 1\n}\n\nconfirm() {\n  echo \"Continue? [yes/no]:\"\n  read confirm\n  if [[ ! 
${confirm} == \"yes\" ]]; then\n    exit 0\n  fi\n}\n\n# Parse flags.\nwhile getopts \"cdfg:hn:o:p:st:\" flag; do\n  case \"${flag}\" in\n    c)\n      cpu=1\n      ;;\n    d)\n      dry_run=1\n      ;;\n    f)\n      force_create=1\n      ;;\n    g)\n      cuda_version=${OPTARG}\n      ;;\n    h)\n      usage\n      exit 0\n      ;;\n    n)\n      name=${OPTARG}\n      ;;\n    o)\n      output_path=${OPTARG}\n      ;;\n    p)\n      python_version=${OPTARG}\n      ;;\n    s)\n      always_yes=1\n      ;;\n    t)\n      torch_version=${OPTARG}\n      ;;\n    :)\n      echo \"Error: -${OPTARG} requires an argument.\"\n      exit 1\n      ;;\n    *)\n      usage\n      exit 1\n      ;;\n  esac\ndone\n\nif [[ -n ${cuda_version} && ${cpu} -eq 1 ]]; then\n  echo \"Only one mode can be specified.\"\n  exit 1\nfi\n\nif [[ -z ${cuda_version} && -z ${cpu} ]]; then\n  usage\n  exit 1\nfi\n\nif [[ -z \"${torch_version}\" ]]; then\n  torch_version=${TORCH_VERSION}\nfi\n\n# Set up CPU mode.\nif [[ ${cpu} -eq 1 ]]; then\n  torchversion=${torch_version}\"+cpu\"\n  if [[ -z \"${name}\" ]]; then\n    name=\"dgl-dev-cpu\"\n  fi\nfi\n\n# Set up GPU mode.\nif [[ -n ${cuda_version} ]]; then\n  if ! 
validate ${CUDA_VERSIONS} ${cuda_version}; then\n    echo \"Error: Invalid CUDA version.\"\n    usage\n    exit 1\n  fi\n\n  echo \"Confirm the installed CUDA version matches the specified one.\"\n  [[ -n \"${always_yes}\" ]] || confirm\n\n  torchversion=${torch_version}\"+cu\"${cuda_version//[-._]/}\n  if [[ -z \"${name}\" ]]; then\n    name=\"dgl-dev-gpu-\"${cuda_version//[-._]/}\n  fi\nfi\n\n# Set python version.\nif [[ -z \"${python_version}\" ]]; then\n  python_version=${PYTHON_VERSION}\nfi\n\necho \"Confirm you are excuting the script from your DGL root directory.\"\necho \"Current working directory: ${PWD}\"\n[[ -n \"${always_yes}\" ]] || confirm\n\n# Prepare the conda environment yaml file.\nrand=$(echo \"${RANDOM}\" | md5sum | head -c 20)\nmkdir -p /tmp/${rand}\nyaml_path=\"/tmp/${rand}/dgl_dev.yml\"\ncp script/dgl_dev.yml.template ${yaml_path}\nsed -i \"s|__NAME__|${name}|g\" ${yaml_path}\nsed -i \"s|__PYTHON_VERSION__|${python_version}|g\" ${yaml_path}\nsed -i \"s|__TORCH_VERSION__|${torchversion}|g\" ${yaml_path}\nsed -i \"s|__DGL_HOME__|${PWD}|g\" ${yaml_path}\n\n# Ask for final confirmation.\necho \"--------------------------------------------------\"\ncat ${yaml_path}\necho \"--------------------------------------------------\"\necho \"Create a conda enviroment with the config?\"\n[[ -n \"${always_yes}\" ]] || confirm\n\n# Save YAML file to specified path\nif [[ -n \"${output_path}\" ]]; then\n  cp ${yaml_path} ${output_path}\n  echo \"Environment YAML file has been saved to ${output_path}.\"\nfi\n\n# Create conda environment.\nif [[ -z \"${dry_run}\" ]]; then\n  conda_args=\"\"\n  if [[ -n \"${force_create}\" ]]; then\n    conda_args=\"${conda_args} --force \"\n  fi\n  conda env create -f ${yaml_path} ${conda_args}\nelse\n  echo \"Running in dry mode, so creation of conda environment is skipped.\"\nfi\n\n# Clean up created tmp conda environment yaml file.\nrm -rf /tmp/${rand}\nexit 0\n"
  },
  {
    "path": "script/dgl_dev.yml.template",
    "content": "name: __NAME__\nchannels:\n  - conda-forge\n  - defaults\ndependencies:\n  - libstdcxx-ng>=9.5.0\n  - python=__PYTHON_VERSION__\n  - pip\n  - graphviz\n  - pandoc\n  - pygraphviz\n  - pip:\n    - --find-links https://download.pytorch.org/whl/torch/\n    - cmake>=3.18\n    - cython\n    - filelock\n    - matplotlib\n    - networkx\n    - nltk\n    - nose\n    - numpy\n    - ogb\n    - pandas\n    - psutil\n    - pyarrow\n    - pydantic>=2.0\n    - pytest\n    - pyyaml\n    - rdflib\n    - requests[security]\n    - scikit-learn\n    - scipy\n    - torch==__TORCH_VERSION__\n    - torcheval\n    - torchmetrics\n    - torch_geometric\n    - tqdm\n    - boto3 # AWS SDK for python\n    - sphinx\n    - sphinx-gallery\n    - sphinx_rtd_theme\n    - sphinx_copybutton\n    - sphinxemoji\n    - nbsphinx\n    - nbsphinx-link\n    - pillow\n    - seaborn\n    - jupyter_http_over_ws\n    - ufmt\n    - clang-format\n    - pylint\n    - lintrunner\n    - jupyterlab\n    - ipywidgets\n    - expecttest\nvariables:\n  DGL_HOME: __DGL_HOME__\n"
  },
  {
    "path": "script/run_pytest.sh",
    "content": "#!/bin/bash\n\nset -e\n\nusage() {\ncat << EOF\nusage: bash $0 OPTIONS TARGETS\nexamples:\n  Run python tests on CPU: bash $0 -c tests/compute/test_subgraph.py\n  Run python tests on GPU: bash $0 -g tests/compute/test_subgraph.py\n\nRun DGL python tests.\n\nOPTIONS:\n  -h           Show this message.\n  -c           Run python tests on CPU.\n  -g           Run python tests on GPU.\nEOF\n}\n\n# Parse flags.\nwhile getopts \"cgh\" flag; do\n  if [[ ${flag} == \"c\" ]]; then\n    device=\"cpu\"\n  elif [[ ${flag} == \"g\" ]]; then\n    device=\"gpu\"\n  elif [[ ${flag} == \"h\" ]]; then\n    usage\n    exit 0\n  else\n    usage\n    exit 1\n  fi\ndone\n\nif [[ -z ${DGL_HOME} ]]; then\n  echo \"ERROR: Please make sure environment variable DGL_HOME is set correctly.\"\n  exit 1\nfi\n\nif [[ ! ${PWD} == ${DGL_HOME} ]]; then\n  echo \"ERROR: This script only works properly from DGL root directory.\"\n  echo \" Current: ${PWD}\"\n  echo \"DGL_HOME: ${DGL_HOME}\"\n  exit 1\nfi\n\nif [[ -z ${device} ]]; then\n  echo \"ERROR: Test device unspecified.\"\n  usage\n  exit 1\nfi\n\n# Reset the index for non-option arguments.\nshift $(($OPTIND-1))\n\nexport DGLBACKEND=pytorch\nexport DGL_LIBRARY_PATH=${DGL_HOME}/build\nexport PYTHONPATH=${DGL_HOME}/python:${DGL_HOME}/tests:${DGL_HOME}/tests/python/pytorch/graphbolt:$PYTHONPATH\nexport DGLTESTDEV=${device}\nexport DGL_DOWNLOAD_DIR=${DGL_HOME}/_download\n\nif [[ -z $@ ]]; then\n  echo \"ERROR: Missing test targets.\"\n  usage\n  exit 1\nfi\n\npython3 -m pytest -v $@\n"
  },
  {
    "path": "src/api/api_container.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file api/api_container.cc\n * @brief Runtime container APIs. (reference: tvm/src/api/api_lang.cc)\n */\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/registry.h>\n\nnamespace dgl {\nnamespace runtime {\n\nDGL_REGISTER_GLOBAL(\"_List\").set_body([](DGLArgs args, DGLRetValue* rv) {\n  auto ret_obj = std::make_shared<runtime::ListObject>();\n  for (int i = 0; i < args.size(); ++i) {\n    ret_obj->data.push_back(args[i].obj_sptr());\n  }\n  *rv = ret_obj;\n});\n\nDGL_REGISTER_GLOBAL(\"_ListGetItem\").set_body([](DGLArgs args, DGLRetValue* rv) {\n  auto& sptr = args[0].obj_sptr();\n  CHECK(sptr->is_type<ListObject>());\n  auto* o = static_cast<const ListObject*>(sptr.get());\n  int64_t i = args[1];\n  CHECK_LT(i, o->data.size()) << \"list out of bound\";\n  *rv = o->data[i];\n});\n\nDGL_REGISTER_GLOBAL(\"_ListSize\").set_body([](DGLArgs args, DGLRetValue* rv) {\n  auto& sptr = args[0].obj_sptr();\n  CHECK(sptr->is_type<ListObject>());\n  auto* o = static_cast<const ListObject*>(sptr.get());\n  *rv = static_cast<int64_t>(o->data.size());\n});\n\nDGL_REGISTER_GLOBAL(\"_Map\").set_body([](DGLArgs args, DGLRetValue* rv) {\n  CHECK_EQ(args.size() % 2, 0);\n  if (args.size() != 0 && args[0].type_code() == kStr) {\n    // StrMap\n    StrMapObject::ContainerType data;\n    for (int i = 0; i < args.size(); i += 2) {\n      CHECK(args[i].type_code() == kStr) << \"The key of the map must be string\";\n      CHECK(args[i + 1].type_code() == kObjectHandle)\n          << \"The value of the map must be an object type\";\n      data.emplace(std::make_pair(\n          args[i].operator std::string(), args[i + 1].obj_sptr()));\n    }\n    auto obj = std::make_shared<StrMapObject>();\n    obj->data = std::move(data);\n    *rv = obj;\n  } else {\n    // object container\n    MapObject::ContainerType data;\n    for (int i = 0; i < args.size(); i += 
2) {\n      CHECK(args[i].type_code() == kObjectHandle)\n          << \"The key of the map must be an object type\";\n      CHECK(args[i + 1].type_code() == kObjectHandle)\n          << \"The value of the map must be an object type\";\n      data.emplace(std::make_pair(args[i].obj_sptr(), args[i + 1].obj_sptr()));\n    }\n    auto obj = std::make_shared<MapObject>();\n    obj->data = std::move(data);\n    *rv = obj;\n  }\n});\n\nDGL_REGISTER_GLOBAL(\"_EmptyStrMap\").set_body([](DGLArgs args, DGLRetValue* rv) {\n  StrMapObject::ContainerType data;\n  auto obj = std::make_shared<StrMapObject>();\n  obj->data = std::move(data);\n  *rv = obj;\n});\n\nDGL_REGISTER_GLOBAL(\"_MapSize\").set_body([](DGLArgs args, DGLRetValue* rv) {\n  auto& sptr = args[0].obj_sptr();\n  if (sptr->is_type<MapObject>()) {\n    auto* o = static_cast<const MapObject*>(sptr.get());\n    *rv = static_cast<int64_t>(o->data.size());\n  } else {\n    CHECK(sptr->is_type<StrMapObject>());\n    auto* o = static_cast<const StrMapObject*>(sptr.get());\n    *rv = static_cast<int64_t>(o->data.size());\n  }\n});\n\nDGL_REGISTER_GLOBAL(\"_MapGetItem\").set_body([](DGLArgs args, DGLRetValue* rv) {\n  auto& sptr = args[0].obj_sptr();\n  if (sptr->is_type<MapObject>()) {\n    auto* o = static_cast<const MapObject*>(sptr.get());\n    auto it = o->data.find(args[1].obj_sptr());\n    CHECK(it != o->data.end()) << \"cannot find the key in the map\";\n    *rv = (*it).second;\n  } else {\n    CHECK(sptr->is_type<StrMapObject>());\n    auto* o = static_cast<const StrMapObject*>(sptr.get());\n    auto it = o->data.find(args[1].operator std::string());\n    CHECK(it != o->data.end()) << \"cannot find the key in the map\";\n    *rv = (*it).second;\n  }\n});\n\nDGL_REGISTER_GLOBAL(\"_MapItems\").set_body([](DGLArgs args, DGLRetValue* rv) {\n  auto& sptr = args[0].obj_sptr();\n  if (sptr->is_type<MapObject>()) {\n    auto* o = static_cast<const MapObject*>(sptr.get());\n    auto rkvs = std::make_shared<ListObject>();\n   
 for (const auto& kv : o->data) {\n      rkvs->data.push_back(kv.first);\n      rkvs->data.push_back(kv.second);\n    }\n    *rv = rkvs;\n  } else {\n    CHECK(sptr->is_type<StrMapObject>());\n    auto* o = static_cast<const StrMapObject*>(sptr.get());\n    auto rkvs = std::make_shared<ListObject>();\n    for (const auto& kv : o->data) {\n      rkvs->data.push_back(MakeValue(kv.first));\n      rkvs->data.push_back(kv.second);\n    }\n    *rv = rkvs;\n  }\n});\n\nDGL_REGISTER_GLOBAL(\"_MapCount\").set_body([](DGLArgs args, DGLRetValue* rv) {\n  auto& sptr = args[0].obj_sptr();\n  if (sptr->is_type<MapObject>()) {\n    auto* o = static_cast<const MapObject*>(sptr.get());\n    *rv = static_cast<int64_t>(o->data.count(args[1].obj_sptr()));\n  } else {\n    CHECK(sptr->is_type<StrMapObject>());\n    auto* o = static_cast<const StrMapObject*>(sptr.get());\n    *rv = static_cast<int64_t>(o->data.count(args[1].operator std::string()));\n  }\n});\n\nDGL_REGISTER_GLOBAL(\"_Value\").set_body([](DGLArgs args, DGLRetValue* rv) {\n  *rv = MakeValue(args[0]);\n});\n\nDGL_REGISTER_GLOBAL(\"_ValueGet\").set_body([](DGLArgs args, DGLRetValue* rv) {\n  auto& sptr = args[0].obj_sptr();\n  CHECK(sptr->is_type<ValueObject>());\n  auto* o = static_cast<const ValueObject*>(sptr.get());\n  *rv = o->data;\n});\n\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/api/api_test.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file api/api_test.cc\n * @brief C APIs for testing FFI\n */\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/registry.h>\n\n#include <thread>\n\nnamespace dgl {\nnamespace runtime {\n\n// Register an internal API for testing python callback.\n// It receives two arguments:\n//   - The python callback function.\n//   - The argument to pass to the python callback\n// It returns what python callback returns\nDGL_REGISTER_GLOBAL(\"_TestPythonCallback\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      LOG(INFO) << \"Inside C API\";\n      PackedFunc fn = args[0];\n      DGLArgs cb_args(args.values + 1, args.type_codes + 1, 1);\n      fn.CallPacked(cb_args, rv);\n    });\n\n// Register an internal API for testing python callback.\n// It receives two arguments:\n//   - The python callback function.\n//   - The argument to pass to the python callback\n// It returns what python callback returns\n//\n// The API runs the python callback in a separate thread to test\n// python GIL is properly released.\nDGL_REGISTER_GLOBAL(\"_TestPythonCallbackThread\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      LOG(INFO) << \"Inside C API\";\n      PackedFunc fn = args[0];\n      auto thr = std::make_shared<std::thread>([fn, args, rv]() {\n        LOG(INFO) << \"Callback thread \" << std::this_thread::get_id();\n        DGLArgs cb_args(args.values + 1, args.type_codes + 1, 1);\n        fn.CallPacked(cb_args, rv);\n      });\n      thr->join();\n    });\n\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/arith.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/arith.h\n * @brief Arithmetic functors\n */\n#ifndef DGL_ARRAY_ARITH_H_\n#define DGL_ARRAY_ARITH_H_\n\n#ifdef __CUDACC__\n#define DGLDEVICE __device__\n#define DGLINLINE __forceinline__\n#else\n#define DGLDEVICE\n#define DGLINLINE inline\n#endif  // __CUDACC__\n\nnamespace dgl {\nnamespace aten {\nnamespace arith {\n\nstruct Add {\n  template <typename T>\n  static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {\n    return t1 + t2;\n  }\n};\n\nstruct Sub {\n  template <typename T>\n  static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {\n    return t1 - t2;\n  }\n};\n\nstruct Mul {\n  template <typename T>\n  static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {\n    return t1 * t2;\n  }\n};\n\nstruct Div {\n  template <typename T>\n  static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {\n    return t1 / t2;\n  }\n};\n\nstruct Mod {\n  template <typename T>\n  static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {\n    return t1 % t2;\n  }\n};\n\nstruct GT {\n  template <typename T>\n  static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {\n    return t1 > t2;\n  }\n};\n\nstruct LT {\n  template <typename T>\n  static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {\n    return t1 < t2;\n  }\n};\n\nstruct GE {\n  template <typename T>\n  static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {\n    return t1 >= t2;\n  }\n};\n\nstruct LE {\n  template <typename T>\n  static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {\n    return t1 <= t2;\n  }\n};\n\nstruct EQ {\n  template <typename T>\n  static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {\n    return t1 == t2;\n  }\n};\n\nstruct NE {\n  template <typename T>\n  static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {\n    return t1 != t2;\n  }\n};\n\nstruct Neg {\n  template <typename T>\n  static DGLINLINE DGLDEVICE T Call(const T& t1) {\n 
   return -t1;\n  }\n};\n\n}  // namespace arith\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_ARITH_H_\n"
  },
  {
    "path": "src/array/array.cc",
    "content": "/**\n *  Copyright (c) 2019-2022 by Contributors\n * @file array/array.cc\n * @brief DGL array utilities implementation\n */\n#include <dgl/array.h>\n#include <dgl/bcast.h>\n#include <dgl/graph_traversal.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/shared_mem.h>\n\n#include <sstream>\n\n#include \"../c_api_common.h\"\n#include \"./arith.h\"\n#include \"./array_op.h\"\n#include \"./kernel_decl.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace aten {\n\nIdArray NewIdArray(int64_t length, DGLContext ctx, uint8_t nbits) {\n  return IdArray::Empty({length}, DGLDataType{kDGLInt, nbits, 1}, ctx);\n}\n\nFloatArray NewFloatArray(int64_t length, DGLContext ctx, uint8_t nbits) {\n  return FloatArray::Empty({length}, DGLDataType{kDGLFloat, nbits, 1}, ctx);\n}\n\nIdArray Clone(IdArray arr) {\n  IdArray ret = NewIdArray(arr->shape[0], arr->ctx, arr->dtype.bits);\n  ret.CopyFrom(arr);\n  return ret;\n}\n\nIdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx) {\n  IdArray ret;\n  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, \"Range\", {\n    if (nbits == 32) {\n      ret = impl::Range<XPU, int32_t>(low, high, ctx);\n    } else if (nbits == 64) {\n      ret = impl::Range<XPU, int64_t>(low, high, ctx);\n    } else {\n      LOG(FATAL) << \"Only int32 or int64 is supported.\";\n    }\n  });\n  return ret;\n}\n\nIdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx) {\n  IdArray ret;\n  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, \"Full\", {\n    if (nbits == 32) {\n      ret = impl::Full<XPU, int32_t>(val, length, ctx);\n    } else if (nbits == 64) {\n      ret = impl::Full<XPU, int64_t>(val, length, ctx);\n    } else {\n      LOG(FATAL) << \"Only int32 or int64 is supported.\";\n    }\n  });\n  return ret;\n}\n\ntemplate <typename DType>\nNDArray Full(DType val, int64_t length, DGLContext ctx) {\n  NDArray ret;\n  
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, \"Full\", {\n    ret = impl::Full<XPU, DType>(val, length, ctx);\n  });\n  return ret;\n}\n\ntemplate NDArray Full<int32_t>(int32_t val, int64_t length, DGLContext ctx);\ntemplate NDArray Full<int64_t>(int64_t val, int64_t length, DGLContext ctx);\ntemplate NDArray Full<float>(float val, int64_t length, DGLContext ctx);\ntemplate NDArray Full<double>(double val, int64_t length, DGLContext ctx);\n\nIdArray AsNumBits(IdArray arr, uint8_t bits) {\n  CHECK(bits == 32 || bits == 64)\n      << \"Invalid ID type. Must be int32 or int64, but got int\"\n      << static_cast<int>(bits) << \".\";\n  if (arr->dtype.bits == bits) return arr;\n  if (arr.NumElements() == 0) return NewIdArray(arr->shape[0], arr->ctx, bits);\n  IdArray ret;\n  ATEN_XPU_SWITCH_CUDA(arr->ctx.device_type, XPU, \"AsNumBits\", {\n    ATEN_ID_TYPE_SWITCH(\n        arr->dtype, IdType, { ret = impl::AsNumBits<XPU, IdType>(arr, bits); });\n  });\n  return ret;\n}\n\nIdArray HStack(IdArray lhs, IdArray rhs) {\n  IdArray ret;\n  CHECK_SAME_CONTEXT(lhs, rhs);\n  CHECK_SAME_DTYPE(lhs, rhs);\n  CHECK_EQ(lhs->shape[0], rhs->shape[0]);\n  auto device = runtime::DeviceAPI::Get(lhs->ctx);\n  const auto& ctx = lhs->ctx;\n  ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {\n    const int64_t len = lhs->shape[0];\n    ret = NewIdArray(2 * len, lhs->ctx, lhs->dtype.bits);\n    device->CopyDataFromTo(\n        lhs.Ptr<IdType>(), 0, ret.Ptr<IdType>(), 0, len * sizeof(IdType), ctx,\n        ctx, lhs->dtype);\n    device->CopyDataFromTo(\n        rhs.Ptr<IdType>(), 0, ret.Ptr<IdType>(), len * sizeof(IdType),\n        len * sizeof(IdType), ctx, ctx, lhs->dtype);\n  });\n  return ret;\n}\n\nNDArray IndexSelect(NDArray array, IdArray index) {\n  NDArray ret;\n  CHECK_GE(array->ndim, 1) << \"Only support array with at least 1 dimension\";\n  CHECK_EQ(index->ndim, 1) << \"Index array must be an 1D array.\";\n  // if array is not pinned, index has the same context as array\n  // if array is 
pinned, op dispatching depends on the context of index\n  CHECK_VALID_CONTEXT(array, index);\n  ATEN_XPU_SWITCH_CUDA(index->ctx.device_type, XPU, \"IndexSelect\", {\n    ATEN_DTYPE_SWITCH(array->dtype, DType, \"values\", {\n      ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {\n        ret = impl::IndexSelect<XPU, DType, IdType>(array, index);\n      });\n    });\n  });\n  return ret;\n}\n\ntemplate <typename ValueType>\nValueType IndexSelect(NDArray array, int64_t index) {\n  CHECK_EQ(array->ndim, 1) << \"Only support select values from 1D array.\";\n  CHECK(index >= 0 && index < array.NumElements())\n      << \"Index \" << index << \" is out of bound.\";\n  ValueType ret = 0;\n  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, \"IndexSelect\", {\n    ATEN_DTYPE_SWITCH(array->dtype, DType, \"values\", {\n      ret = impl::IndexSelect<XPU, DType>(array, index);\n    });\n  });\n  return ret;\n}\ntemplate int32_t IndexSelect<int32_t>(NDArray array, int64_t index);\ntemplate int64_t IndexSelect<int64_t>(NDArray array, int64_t index);\ntemplate uint32_t IndexSelect<uint32_t>(NDArray array, int64_t index);\ntemplate uint64_t IndexSelect<uint64_t>(NDArray array, int64_t index);\ntemplate float IndexSelect<float>(NDArray array, int64_t index);\ntemplate double IndexSelect<double>(NDArray array, int64_t index);\n\nNDArray IndexSelect(NDArray array, int64_t start, int64_t end) {\n  CHECK_EQ(array->ndim, 1) << \"Only support select values from 1D array.\";\n  CHECK(start >= 0 && start < array.NumElements())\n      << \"Index \" << start << \" is out of bound.\";\n  CHECK(end >= 0 && end <= array.NumElements())\n      << \"Index \" << end << \" is out of bound.\";\n  CHECK_LE(start, end);\n  auto device = runtime::DeviceAPI::Get(array->ctx);\n  const int64_t len = end - start;\n  NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);\n  ATEN_DTYPE_SWITCH(array->dtype, DType, \"values\", {\n    device->CopyDataFromTo(\n        array->data, start * sizeof(DType), 
ret->data, 0, len * sizeof(DType),\n        array->ctx, ret->ctx, array->dtype);\n  });\n  return ret;\n}\n\nNDArray Scatter(NDArray array, IdArray indices) {\n  NDArray ret;\n  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, \"Scatter\", {\n    ATEN_DTYPE_SWITCH(array->dtype, DType, \"values\", {\n      ATEN_ID_TYPE_SWITCH(indices->dtype, IdType, {\n        ret = impl::Scatter<XPU, DType, IdType>(array, indices);\n      });\n    });\n  });\n  return ret;\n}\n\nvoid Scatter_(IdArray index, NDArray value, NDArray out) {\n  CHECK_SAME_DTYPE(value, out);\n  CHECK_SAME_CONTEXT(index, value);\n  CHECK_SAME_CONTEXT(index, out);\n  CHECK_EQ(value->shape[0], index->shape[0]);\n  if (index->shape[0] == 0) return;\n  ATEN_XPU_SWITCH_CUDA(value->ctx.device_type, XPU, \"Scatter_\", {\n    ATEN_DTYPE_SWITCH(value->dtype, DType, \"values\", {\n      ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {\n        impl::Scatter_<XPU, DType, IdType>(index, value, out);\n      });\n    });\n  });\n}\n\nNDArray Repeat(NDArray array, IdArray repeats) {\n  NDArray ret;\n  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, \"Repeat\", {\n    ATEN_DTYPE_SWITCH(array->dtype, DType, \"values\", {\n      ATEN_ID_TYPE_SWITCH(repeats->dtype, IdType, {\n        ret = impl::Repeat<XPU, DType, IdType>(array, repeats);\n      });\n    });\n  });\n  return ret;\n}\n\nIdArray Relabel_(const std::vector<IdArray>& arrays) {\n  IdArray ret;\n  ATEN_XPU_SWITCH_CUDA(arrays[0]->ctx.device_type, XPU, \"Relabel_\", {\n    ATEN_ID_TYPE_SWITCH(arrays[0]->dtype, IdType, {\n      ret = impl::Relabel_<XPU, IdType>(arrays);\n    });\n  });\n  return ret;\n}\n\nNDArray Concat(const std::vector<IdArray>& arrays) {\n  IdArray ret;\n\n  int64_t len = 0, offset = 0;\n  for (size_t i = 0; i < arrays.size(); ++i) {\n    len += arrays[i]->shape[0];\n    CHECK_SAME_DTYPE(arrays[0], arrays[i]);\n    CHECK_SAME_CONTEXT(arrays[0], arrays[i]);\n  }\n\n  NDArray ret_arr = NDArray::Empty({len}, arrays[0]->dtype, arrays[0]->ctx);\n\n  auto device 
= runtime::DeviceAPI::Get(arrays[0]->ctx);\n  for (size_t i = 0; i < arrays.size(); ++i) {\n    ATEN_DTYPE_SWITCH(arrays[i]->dtype, DType, \"array\", {\n      device->CopyDataFromTo(\n          static_cast<DType*>(arrays[i]->data), 0,\n          static_cast<DType*>(ret_arr->data), offset,\n          arrays[i]->shape[0] * sizeof(DType), arrays[i]->ctx, ret_arr->ctx,\n          arrays[i]->dtype);\n\n      offset += arrays[i]->shape[0] * sizeof(DType);\n    });\n  }\n\n  return ret_arr;\n}\n\ntemplate <typename ValueType>\nstd::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) {\n  std::tuple<NDArray, IdArray, IdArray> ret;\n  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, \"Pack\", {\n    ATEN_DTYPE_SWITCH(array->dtype, DType, \"array\", {\n      ret = impl::Pack<XPU, DType>(array, static_cast<DType>(pad_value));\n    });\n  });\n  return ret;\n}\n\ntemplate std::tuple<NDArray, IdArray, IdArray> Pack<int32_t>(NDArray, int32_t);\ntemplate std::tuple<NDArray, IdArray, IdArray> Pack<int64_t>(NDArray, int64_t);\ntemplate std::tuple<NDArray, IdArray, IdArray> Pack<uint32_t>(\n    NDArray, uint32_t);\ntemplate std::tuple<NDArray, IdArray, IdArray> Pack<uint64_t>(\n    NDArray, uint64_t);\ntemplate std::tuple<NDArray, IdArray, IdArray> Pack<float>(NDArray, float);\ntemplate std::tuple<NDArray, IdArray, IdArray> Pack<double>(NDArray, double);\n\nstd::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {\n  std::pair<NDArray, IdArray> ret;\n  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, \"ConcatSlices\", {\n    ATEN_DTYPE_SWITCH(array->dtype, DType, \"array\", {\n      ATEN_ID_TYPE_SWITCH(lengths->dtype, IdType, {\n        ret = impl::ConcatSlices<XPU, DType, IdType>(array, lengths);\n      });\n    });\n  });\n  return ret;\n}\n\nIdArray CumSum(IdArray array, bool prepend_zero) {\n  IdArray ret;\n  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, \"CumSum\", {\n    ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {\n      ret = 
impl::CumSum<XPU, IdType>(array, prepend_zero);\n    });\n  });\n  return ret;\n}\n\nIdArray NonZero(NDArray array) {\n  IdArray ret;\n  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, \"NonZero\", {\n    ATEN_ID_TYPE_SWITCH(\n        array->dtype, DType, { ret = impl::NonZero<XPU, DType>(array); });\n  });\n  return ret;\n}\n\nstd::pair<IdArray, IdArray> Sort(IdArray array, const int num_bits) {\n  if (array.NumElements() == 0) {\n    IdArray idx = NewIdArray(0, array->ctx, 64);\n    return std::make_pair(array, idx);\n  }\n  std::pair<IdArray, IdArray> ret;\n  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, \"Sort\", {\n    ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {\n      ret = impl::Sort<XPU, IdType>(array, num_bits);\n    });\n  });\n  return ret;\n}\n\nstd::string ToDebugString(NDArray array) {\n  std::ostringstream oss;\n  NDArray a = array.CopyTo(DGLContext{kDGLCPU, 0});\n  oss << \"array([\";\n  ATEN_DTYPE_SWITCH(a->dtype, DType, \"array\", {\n    for (int64_t i = 0; i < std::min<int64_t>(a.NumElements(), 10L); ++i) {\n      oss << a.Ptr<DType>()[i] << \", \";\n    }\n  });\n  if (a.NumElements() > 10) oss << \"...\";\n  oss << \"], dtype=\" << array->dtype << \", ctx=\" << array->ctx << \")\";\n  return oss.str();\n}\n\n///////////////////////// CSR routines //////////////////////////\n\nbool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {\n  CHECK(row >= 0 && row < csr.num_rows) << \"Invalid row index: \" << row;\n  CHECK(col >= 0 && col < csr.num_cols) << \"Invalid col index: \" << col;\n  bool ret = false;\n  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, \"CSRIsNonZero\", {\n    ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);\n  });\n  return ret;\n}\n\nNDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {\n  NDArray ret;\n  CHECK_SAME_DTYPE(csr.indices, row);\n  CHECK_SAME_DTYPE(csr.indices, col);\n  CHECK_SAME_CONTEXT(row, col);\n  ATEN_CSR_SWITCH_CUDA_UVA(csr, row, XPU, IdType, \"CSRIsNonZero\", {\n    ret = 
impl::CSRIsNonZero<XPU, IdType>(csr, row, col);\n  });\n  return ret;\n}\n\nbool CSRHasDuplicate(CSRMatrix csr) {\n  bool ret = false;\n  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, \"CSRHasDuplicate\", {\n    ret = impl::CSRHasDuplicate<XPU, IdType>(csr);\n  });\n  return ret;\n}\n\nint64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {\n  CHECK(row >= 0 && row < csr.num_rows) << \"Invalid row index: \" << row;\n  int64_t ret = 0;\n  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, \"CSRGetRowNNZ\", {\n    ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);\n  });\n  return ret;\n}\n\nNDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) {\n  NDArray ret;\n  CHECK_SAME_DTYPE(csr.indices, row);\n  ATEN_CSR_SWITCH_CUDA_UVA(csr, row, XPU, IdType, \"CSRGetRowNNZ\", {\n    ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);\n  });\n  return ret;\n}\n\nNDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {\n  CHECK(row >= 0 && row < csr.num_rows) << \"Invalid row index: \" << row;\n  NDArray ret;\n  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, \"CSRGetRowColumnIndices\", {\n    ret = impl::CSRGetRowColumnIndices<XPU, IdType>(csr, row);\n  });\n  return ret;\n}\n\nNDArray CSRGetRowData(CSRMatrix csr, int64_t row) {\n  CHECK(row >= 0 && row < csr.num_rows) << \"Invalid row index: \" << row;\n  NDArray ret;\n  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, \"CSRGetRowData\", {\n    ret = impl::CSRGetRowData<XPU, IdType>(csr, row);\n  });\n  return ret;\n}\n\nbool CSRIsSorted(CSRMatrix csr) {\n  if (csr.indices->shape[0] <= 1) return true;\n  bool ret = false;\n  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, \"CSRIsSorted\", {\n    ret = impl::CSRIsSorted<XPU, IdType>(csr);\n  });\n  return ret;\n}\n\nNDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {\n  NDArray ret;\n  CHECK_SAME_DTYPE(csr.indices, rows);\n  CHECK_SAME_DTYPE(csr.indices, cols);\n  CHECK_SAME_CONTEXT(rows, cols);\n  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, \"CSRGetData\", {\n    ret = impl::CSRGetData<XPU, IdType>(csr, rows, 
cols);\n  });\n  return ret;\n}\n\ntemplate <typename DType>\nNDArray CSRGetData(\n    CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) {\n  NDArray ret;\n  CHECK_SAME_DTYPE(csr.indices, rows);\n  CHECK_SAME_DTYPE(csr.indices, cols);\n  CHECK_SAME_CONTEXT(rows, cols);\n  CHECK_SAME_CONTEXT(rows, weights);\n  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, \"CSRGetData\", {\n    ret =\n        impl::CSRGetData<XPU, IdType, DType>(csr, rows, cols, weights, filler);\n  });\n  return ret;\n}\n\nruntime::NDArray CSRGetFloatingData(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,\n    runtime::NDArray weights, double filler) {\n  if (weights->dtype.bits == 64) {\n    return CSRGetData<double>(csr, rows, cols, weights, filler);\n  } else {\n    CHECK(weights->dtype.bits == 32)\n        << \"CSRGetFloatingData only supports 32 or 64 bits floating number\";\n    return CSRGetData<float>(csr, rows, cols, weights, filler);\n  }\n}\n\ntemplate NDArray CSRGetData<float>(\n    CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler);\ntemplate NDArray CSRGetData<double>(\n    CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler);\n\nstd::vector<NDArray> CSRGetDataAndIndices(\n    CSRMatrix csr, NDArray rows, NDArray cols) {\n  CHECK_SAME_DTYPE(csr.indices, rows);\n  CHECK_SAME_DTYPE(csr.indices, cols);\n  CHECK_SAME_CONTEXT(rows, cols);\n  std::vector<NDArray> ret;\n  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, \"CSRGetDataAndIndices\", {\n    ret = impl::CSRGetDataAndIndices<XPU, IdType>(csr, rows, cols);\n  });\n  return ret;\n}\n\nCSRMatrix CSRTranspose(CSRMatrix csr) {\n  CSRMatrix ret;\n  ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, \"CSRTranspose\", {\n    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {\n      ret = impl::CSRTranspose<XPU, IdType>(csr);\n    });\n  });\n  return ret;\n}\n\nCOOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {\n  COOMatrix ret;\n  if 
(data_as_order) {\n    ATEN_XPU_SWITCH_CUDA(\n        csr.indptr->ctx.device_type, XPU, \"CSRToCOODataAsOrder\", {\n          ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {\n            ret = impl::CSRToCOODataAsOrder<XPU, IdType>(csr);\n          });\n        });\n  } else {\n    ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, \"CSRToCOO\", {\n      ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {\n        ret = impl::CSRToCOO<XPU, IdType>(csr);\n      });\n    });\n  }\n  return ret;\n}\n\nCSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {\n  CHECK(start >= 0 && start < csr.num_rows) << \"Invalid start index: \" << start;\n  CHECK(end >= 0 && end <= csr.num_rows) << \"Invalid end index: \" << end;\n  CHECK_GE(end, start);\n  CSRMatrix ret;\n  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, \"CSRSliceRows\", {\n    ret = impl::CSRSliceRows<XPU, IdType>(csr, start, end);\n  });\n  return ret;\n}\n\nCSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {\n  CHECK_SAME_DTYPE(csr.indices, rows);\n  CSRMatrix ret;\n  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, \"CSRSliceRows\", {\n    ret = impl::CSRSliceRows<XPU, IdType>(csr, rows);\n  });\n  return ret;\n}\n\nCSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {\n  CHECK_SAME_DTYPE(csr.indices, rows);\n  CHECK_SAME_DTYPE(csr.indices, cols);\n  CHECK_SAME_CONTEXT(rows, cols);\n  CSRMatrix ret;\n  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, \"CSRSliceMatrix\", {\n    ret = impl::CSRSliceMatrix<XPU, IdType>(csr, rows, cols);\n  });\n  return ret;\n}\n\nvoid CSRSort_(CSRMatrix* csr) {\n  if (csr->sorted) return;\n  ATEN_CSR_SWITCH_CUDA(\n      *csr, XPU, IdType, \"CSRSort_\", { impl::CSRSort_<XPU, IdType>(csr); });\n}\n\nstd::pair<CSRMatrix, NDArray> CSRSortByTag(\n    const CSRMatrix& csr, IdArray tag, int64_t num_tags) {\n  CHECK_EQ(csr.indices->shape[0], tag->shape[0])\n      << \"The length of the tag array should be equal to the number of \"\n         \"non-zero data.\";\n  
CHECK_SAME_CONTEXT(csr.indices, tag);\n  CHECK_INT(tag, \"tag\");\n  std::pair<CSRMatrix, NDArray> ret;\n  ATEN_CSR_SWITCH(csr, XPU, IdType, \"CSRSortByTag\", {\n    ATEN_ID_TYPE_SWITCH(tag->dtype, TagType, {\n      ret = impl::CSRSortByTag<XPU, IdType, TagType>(csr, tag, num_tags);\n    });\n  });\n  return ret;\n}\n\nCSRMatrix CSRReorder(\n    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {\n  CSRMatrix ret;\n  ATEN_CSR_SWITCH(csr, XPU, IdType, \"CSRReorder\", {\n    ret = impl::CSRReorder<XPU, IdType>(csr, new_row_ids, new_col_ids);\n  });\n  return ret;\n}\n\nCSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {\n  CSRMatrix ret;\n  ATEN_CSR_SWITCH(csr, XPU, IdType, \"CSRRemove\", {\n    ret = impl::CSRRemove<XPU, IdType>(csr, entries);\n  });\n  return ret;\n}\n\nstd::pair<COOMatrix, FloatArray> CSRLaborSampling(\n    CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,\n    int importance_sampling, IdArray random_seed, float seed2_contribution,\n    IdArray NIDs) {\n  std::pair<COOMatrix, FloatArray> ret;\n  ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, \"CSRLaborSampling\", {\n    const auto dtype =\n        IsNullArray(prob) ? 
DGLDataTypeTraits<float>::dtype : prob->dtype;\n    ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, \"probability\", {\n      ret = impl::CSRLaborSampling<XPU, IdType, FloatType>(\n          mat, rows, num_samples, prob, importance_sampling, random_seed,\n          seed2_contribution, NIDs);\n    });\n  });\n  return ret;\n}\n\nCOOMatrix CSRRowWiseSampling(\n    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,\n    bool replace) {\n  COOMatrix ret;\n  if (IsNullArray(prob_or_mask)) {\n    ATEN_CSR_SWITCH_CUDA_UVA(\n        mat, rows, XPU, IdType, \"CSRRowWiseSamplingUniform\", {\n          ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(\n              mat, rows, num_samples, replace);\n        });\n  } else {\n    // prob_or_mask is pinned and rows on GPU is valid\n    CHECK_VALID_CONTEXT(prob_or_mask, rows);\n    ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, \"CSRRowWiseSampling\", {\n      CHECK(!(prob_or_mask->dtype.bits == 8 && XPU == kDGLCUDA))\n          << \"GPU sampling with masks is currently not supported yet.\";\n      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(\n          prob_or_mask->dtype, FloatType, \"probability or mask\", {\n            ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(\n                mat, rows, num_samples, prob_or_mask, replace);\n          });\n    });\n  }\n  return ret;\n}\n\ntemplate <typename IdType, bool map_seed_nodes>\nstd::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(\n    CSRMatrix mat, IdArray rows, IdArray seed_mapping,\n    std::vector<IdType>* new_seed_nodes, int64_t num_samples,\n    NDArray prob_or_mask, bool replace) {\n  std::pair<CSRMatrix, IdArray> ret;\n  if (IsNullArray(prob_or_mask)) {\n    ATEN_XPU_SWITCH(\n        rows->ctx.device_type, XPU, \"CSRRowWiseSamplingUniformFused\", {\n          ret =\n              impl::CSRRowWiseSamplingUniformFused<XPU, IdType, map_seed_nodes>(\n                  mat, rows, seed_mapping, new_seed_nodes, num_samples,\n                  replace);\n   
     });\n  } else {\n    CHECK_VALID_CONTEXT(prob_or_mask, rows);\n    ATEN_XPU_SWITCH(rows->ctx.device_type, XPU, \"CSRRowWiseSamplingFused\", {\n      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(\n          prob_or_mask->dtype, FloatType, \"probability or mask\", {\n            ret = impl::CSRRowWiseSamplingFused<\n                XPU, IdType, FloatType, map_seed_nodes>(\n                mat, rows, seed_mapping, new_seed_nodes, num_samples,\n                prob_or_mask, replace);\n          });\n    });\n  }\n  return ret;\n}\n\ntemplate std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int64_t, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);\n\ntemplate std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int64_t, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);\n\ntemplate std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int32_t, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);\n\ntemplate std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int32_t, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);\n\nCOOMatrix CSRRowWisePerEtypeSampling(\n    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_samples,\n    const std::vector<NDArray>& prob_or_mask, bool replace,\n    bool rowwise_etype_sorted) {\n  COOMatrix ret;\n  CHECK(prob_or_mask.size() > 0) << \"probability or mask array is empty\";\n  ATEN_CSR_SWITCH(mat, XPU, IdType, \"CSRRowWisePerEtypeSampling\", {\n    if (std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray)) {\n      ret = impl::CSRRowWisePerEtypeSamplingUniform<XPU, IdType>(\n          mat, rows, eid2etype_offset, num_samples, replace,\n          rowwise_etype_sorted);\n    } else {\n      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(\n          prob_or_mask[0]->dtype, DType, \"probability or mask\", {\n            ret = 
impl::CSRRowWisePerEtypeSampling<XPU, IdType, DType>(\n                mat, rows, eid2etype_offset, num_samples, prob_or_mask, replace,\n                rowwise_etype_sorted);\n          });\n    }\n  });\n  return ret;\n}\n\nCOOMatrix CSRRowWiseTopk(\n    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {\n  COOMatrix ret;\n  ATEN_CSR_SWITCH(mat, XPU, IdType, \"CSRRowWiseTopk\", {\n    ATEN_DTYPE_SWITCH(weight->dtype, DType, \"weight\", {\n      ret = impl::CSRRowWiseTopk<XPU, IdType, DType>(\n          mat, rows, k, weight, ascending);\n    });\n  });\n  return ret;\n}\n\nCOOMatrix CSRRowWiseSamplingBiased(\n    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,\n    FloatArray bias, bool replace) {\n  COOMatrix ret;\n  ATEN_CSR_SWITCH(mat, XPU, IdType, \"CSRRowWiseSamplingBiased\", {\n    ATEN_FLOAT_TYPE_SWITCH(bias->dtype, FloatType, \"bias\", {\n      ret = impl::CSRRowWiseSamplingBiased<XPU, IdType, FloatType>(\n          mat, rows, num_samples, tag_offset, bias, replace);\n    });\n  });\n  return ret;\n}\n\nstd::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(\n    const CSRMatrix& csr, int64_t num_samples, int num_trials,\n    bool exclude_self_loops, bool replace, double redundancy) {\n  CHECK_GT(num_samples, 0) << \"Number of samples must be positive\";\n  CHECK_GT(num_trials, 0) << \"Number of sampling trials must be positive\";\n  std::pair<IdArray, IdArray> result;\n  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, \"CSRGlobalUniformNegativeSampling\", {\n    result = impl::CSRGlobalUniformNegativeSampling<XPU, IdType>(\n        csr, num_samples, num_trials, exclude_self_loops, replace, redundancy);\n  });\n  return result;\n}\n\nCSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {\n  CSRMatrix ret;\n  CHECK_GT(csrs.size(), 1)\n      << \"UnionCsr creates a union of multiple CSRMatrixes\";\n  // sanity check\n  for (size_t i = 1; i < csrs.size(); ++i) {\n    CHECK_EQ(csrs[0].num_rows, csrs[i].num_rows)\n    
    << \"UnionCsr requires both CSRMatrix have same number of rows\";\n    CHECK_EQ(csrs[0].num_cols, csrs[i].num_cols)\n        << \"UnionCsr requires both CSRMatrix have same number of cols\";\n    CHECK_SAME_CONTEXT(csrs[0].indptr, csrs[i].indptr);\n    CHECK_SAME_DTYPE(csrs[0].indptr, csrs[i].indptr);\n  }\n\n  ATEN_CSR_SWITCH(csrs[0], XPU, IdType, \"UnionCsr\", {\n    ret = impl::UnionCsr<XPU, IdType>(csrs);\n  });\n  return ret;\n}\n\nstd::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr) {\n  std::tuple<CSRMatrix, IdArray, IdArray> ret;\n\n  CSRMatrix sorted_csr = (CSRIsSorted(csr)) ? csr : CSRSort(csr);\n  ATEN_CSR_SWITCH(csr, XPU, IdType, \"CSRToSimple\", {\n    ret = impl::CSRToSimple<XPU, IdType>(sorted_csr);\n  });\n  return ret;\n}\n\n///////////////////////// COO routines //////////////////////////\n\nbool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {\n  bool ret = false;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOIsNonZero\", {\n    ret = impl::COOIsNonZero<XPU, IdType>(coo, row, col);\n  });\n  return ret;\n}\n\nNDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {\n  NDArray ret;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOIsNonZero\", {\n    ret = impl::COOIsNonZero<XPU, IdType>(coo, row, col);\n  });\n  return ret;\n}\n\nbool COOHasDuplicate(COOMatrix coo) {\n  bool ret = false;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOHasDuplicate\", {\n    ret = impl::COOHasDuplicate<XPU, IdType>(coo);\n  });\n  return ret;\n}\n\nint64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {\n  int64_t ret = 0;\n  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, \"COOGetRowNNZ\", {\n    ret = impl::COOGetRowNNZ<XPU, IdType>(coo, row);\n  });\n  return ret;\n}\n\nNDArray COOGetRowNNZ(COOMatrix coo, NDArray row) {\n  NDArray ret;\n  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, \"COOGetRowNNZ\", {\n    ret = impl::COOGetRowNNZ<XPU, IdType>(coo, row);\n  });\n  return ret;\n}\n\nstd::pair<NDArray, NDArray> COOGetRowDataAndIndices(\n    COOMatrix coo, 
int64_t row) {\n  std::pair<NDArray, NDArray> ret;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOGetRowDataAndIndices\", {\n    ret = impl::COOGetRowDataAndIndices<XPU, IdType>(coo, row);\n  });\n  return ret;\n}\n\nstd::vector<NDArray> COOGetDataAndIndices(\n    COOMatrix coo, NDArray rows, NDArray cols) {\n  std::vector<NDArray> ret;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOGetDataAndIndices\", {\n    ret = impl::COOGetDataAndIndices<XPU, IdType>(coo, rows, cols);\n  });\n  return ret;\n}\n\nNDArray COOGetData(COOMatrix coo, NDArray rows, NDArray cols) {\n  NDArray ret;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOGetData\", {\n    ret = impl::COOGetData<XPU, IdType>(coo, rows, cols);\n  });\n  return ret;\n}\n\nCOOMatrix COOTranspose(COOMatrix coo) {\n  return COOMatrix(coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data);\n}\n\nCSRMatrix COOToCSR(COOMatrix coo) {\n  CSRMatrix ret;\n  ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, \"COOToCSR\", {\n    ATEN_ID_TYPE_SWITCH(\n        coo.row->dtype, IdType, { ret = impl::COOToCSR<XPU, IdType>(coo); });\n  });\n  return ret;\n}\n\nCOOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {\n  COOMatrix ret;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOSliceRows\", {\n    ret = impl::COOSliceRows<XPU, IdType>(coo, start, end);\n  });\n  return ret;\n}\n\nCOOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {\n  COOMatrix ret;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOSliceRows\", {\n    ret = impl::COOSliceRows<XPU, IdType>(coo, rows);\n  });\n  return ret;\n}\n\nCOOMatrix COOSliceMatrix(COOMatrix coo, NDArray rows, NDArray cols) {\n  COOMatrix ret;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOSliceMatrix\", {\n    ret = impl::COOSliceMatrix<XPU, IdType>(coo, rows, cols);\n  });\n  return ret;\n}\n\nvoid COOSort_(COOMatrix* mat, bool sort_column) {\n  if ((mat->row_sorted && !sort_column) || mat->col_sorted) return;\n  ATEN_XPU_SWITCH_CUDA(mat->row->ctx.device_type, XPU, \"COOSort_\", {\n    
ATEN_ID_TYPE_SWITCH(mat->row->dtype, IdType, {\n      impl::COOSort_<XPU, IdType>(mat, sort_column);\n    });\n  });\n}\n\nstd::pair<bool, bool> COOIsSorted(COOMatrix coo) {\n  if (coo.row->shape[0] <= 1) return {true, true};\n  std::pair<bool, bool> ret;\n  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, \"COOIsSorted\", {\n    ret = impl::COOIsSorted<XPU, IdType>(coo);\n  });\n  return ret;\n}\n\nCOOMatrix COOReorder(\n    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {\n  COOMatrix ret;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOReorder\", {\n    ret = impl::COOReorder<XPU, IdType>(coo, new_row_ids, new_col_ids);\n  });\n  return ret;\n}\n\nCOOMatrix COORemove(COOMatrix coo, IdArray entries) {\n  COOMatrix ret;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COORemove\", {\n    ret = impl::COORemove<XPU, IdType>(coo, entries);\n  });\n  return ret;\n}\n\nstd::pair<COOMatrix, FloatArray> COOLaborSampling(\n    COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,\n    int importance_sampling, IdArray random_seed, float seed2_contribution,\n    IdArray NIDs) {\n  std::pair<COOMatrix, FloatArray> ret;\n  ATEN_COO_SWITCH(mat, XPU, IdType, \"COOLaborSampling\", {\n    const auto dtype =\n        IsNullArray(prob) ? 
DGLDataTypeTraits<float>::dtype : prob->dtype;\n    ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, \"probability\", {\n      ret = impl::COOLaborSampling<XPU, IdType, FloatType>(\n          mat, rows, num_samples, prob, importance_sampling, random_seed,\n          seed2_contribution, NIDs);\n    });\n  });\n  return ret;\n}\n\nCOOMatrix COORowWiseSampling(\n    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,\n    bool replace) {\n  COOMatrix ret;\n  ATEN_COO_SWITCH(mat, XPU, IdType, \"COORowWiseSampling\", {\n    if (IsNullArray(prob_or_mask)) {\n      ret = impl::COORowWiseSamplingUniform<XPU, IdType>(\n          mat, rows, num_samples, replace);\n    } else {\n      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(\n          prob_or_mask->dtype, DType, \"probability or mask\", {\n            ret = impl::COORowWiseSampling<XPU, IdType, DType>(\n                mat, rows, num_samples, prob_or_mask, replace);\n          });\n    }\n  });\n  return ret;\n}\n\nCOOMatrix COORowWisePerEtypeSampling(\n    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_samples,\n    const std::vector<NDArray>& prob_or_mask, bool replace) {\n  COOMatrix ret;\n  CHECK(prob_or_mask.size() > 0) << \"probability or mask array is empty\";\n  ATEN_COO_SWITCH(mat, XPU, IdType, \"COORowWisePerEtypeSampling\", {\n    if (std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray)) {\n      ret = impl::COORowWisePerEtypeSamplingUniform<XPU, IdType>(\n          mat, rows, eid2etype_offset, num_samples, replace);\n    } else {\n      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(\n          prob_or_mask[0]->dtype, DType, \"probability or mask\", {\n            ret = impl::COORowWisePerEtypeSampling<XPU, IdType, DType>(\n                mat, rows, eid2etype_offset, num_samples, prob_or_mask,\n                replace);\n          });\n    }\n  });\n  return ret;\n}\n\nCOOMatrix COORowWiseTopk(\n    COOMatrix mat, IdArray rows, int64_t k, 
FloatArray weight, bool ascending) {\n  COOMatrix ret;\n  ATEN_COO_SWITCH(mat, XPU, IdType, \"COORowWiseTopk\", {\n    ATEN_DTYPE_SWITCH(weight->dtype, DType, \"weight\", {\n      ret = impl::COORowWiseTopk<XPU, IdType, DType>(\n          mat, rows, k, weight, ascending);\n    });\n  });\n  return ret;\n}\n\nstd::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {\n  std::pair<COOMatrix, IdArray> ret;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOCoalesce\", {\n    ret = impl::COOCoalesce<XPU, IdType>(coo);\n  });\n  return ret;\n}\n\nCOOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {\n  COOMatrix ret;\n  ATEN_XPU_SWITCH_CUDA(coos[0].row->ctx.device_type, XPU, \"DisjointUnionCoo\", {\n    ATEN_ID_TYPE_SWITCH(coos[0].row->dtype, IdType, {\n      ret = impl::DisjointUnionCoo<XPU, IdType>(coos);\n    });\n  });\n  return ret;\n}\n\nCOOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking) {\n  COOMatrix ret;\n  ATEN_COO_SWITCH(coo, XPU, IdType, \"COOLineGraph\", {\n    ret = impl::COOLineGraph<XPU, IdType>(coo, backtracking);\n  });\n  return ret;\n}\n\nCOOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {\n  COOMatrix ret;\n  CHECK_GT(coos.size(), 1)\n      << \"UnionCoo creates a union of multiple COOMatrixes\";\n  // sanity check\n  for (size_t i = 1; i < coos.size(); ++i) {\n    CHECK_EQ(coos[0].num_rows, coos[i].num_rows)\n        << \"UnionCoo requires both COOMatrix have same number of rows\";\n    CHECK_EQ(coos[0].num_cols, coos[i].num_cols)\n        << \"UnionCoo requires both COOMatrix have same number of cols\";\n    CHECK_SAME_CONTEXT(coos[0].row, coos[i].row);\n    CHECK_SAME_DTYPE(coos[0].row, coos[i].row);\n  }\n\n  // we assume the number of coos is not large in common cases\n  std::vector<IdArray> coo_row;\n  std::vector<IdArray> coo_col;\n  bool has_data = false;\n\n  for (size_t i = 0; i < coos.size(); ++i) {\n    coo_row.push_back(coos[i].row);\n    coo_col.push_back(coos[i].col);\n    has_data |= COOHasData(coos[i]);\n  }\n\n  
IdArray row = Concat(coo_row);\n  IdArray col = Concat(coo_col);\n  IdArray data = NullArray();\n\n  if (has_data) {\n    std::vector<IdArray> eid_data;\n    eid_data.push_back(\n        COOHasData(coos[0]) ? coos[0].data\n                            : Range(\n                                  0, coos[0].row->shape[0],\n                                  coos[0].row->dtype.bits, coos[0].row->ctx));\n    int64_t num_edges = coos[0].row->shape[0];\n    for (size_t i = 1; i < coos.size(); ++i) {\n      eid_data.push_back(\n          COOHasData(coos[i])\n              ? coos[i].data + num_edges\n              : Range(\n                    num_edges, num_edges + coos[i].row->shape[0],\n                    coos[i].row->dtype.bits, coos[i].row->ctx));\n      num_edges += coos[i].row->shape[0];\n    }\n\n    data = Concat(eid_data);\n  }\n\n  return COOMatrix(\n      coos[0].num_rows, coos[0].num_cols, row, col, data, false, false);\n}\n\nstd::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo) {\n  // coo column sorted\n  const COOMatrix sorted_coo = COOSort(coo, true);\n  const IdArray eids_shuffled =\n      COOHasData(sorted_coo)\n          ? 
sorted_coo.data\n          : Range(\n                0, sorted_coo.row->shape[0], sorted_coo.row->dtype.bits,\n                sorted_coo.row->ctx);\n  const auto& coalesced_result = COOCoalesce(sorted_coo);\n  const COOMatrix& coalesced_adj = coalesced_result.first;\n  const IdArray& count = coalesced_result.second;\n\n  /**\n   * eids_shuffled actually already contains the mapping from old edge space to\n   * the new one:\n   *\n   * * eids_shuffled[0:count[0]] indicates the original edge IDs that coalesced\n   * into new edge #0.\n   * * eids_shuffled[count[0]:count[0] + count[1]] indicates those that\n   * coalesced into new edge #1.\n   * * eids_shuffled[count[0] + count[1]:count[0] + count[1] + count[2]]\n   * indicates those that coalesced into new edge #2.\n   * * etc.\n   *\n   * Here, we need to translate eids_shuffled to an array \"eids_remapped\" such\n   * that eids_remapped[i] indicates the new edge ID the old edge #i is mapped\n   * to.  The translation can simply be achieved by (in numpy code):\n   *\n   *     new_eid_for_eids_shuffled = np.arange(len(count)).repeat(count)\n   *     eids_remapped = np.zeros_like(new_eid_for_eids_shuffled)\n   *     eids_remapped[eids_shuffled] = new_eid_for_eids_shuffled\n   */\n  const IdArray new_eids = Range(\n      0, coalesced_adj.row->shape[0], coalesced_adj.row->dtype.bits,\n      coalesced_adj.row->ctx);\n  const IdArray eids_remapped = Scatter(Repeat(new_eids, count), eids_shuffled);\n\n  COOMatrix ret = COOMatrix(\n      coalesced_adj.num_rows, coalesced_adj.num_cols, coalesced_adj.row,\n      coalesced_adj.col, NullArray(), true, true);\n  return std::make_tuple(ret, count, eids_remapped);\n}\n\n///////////////////////// Graph Traverse routines //////////////////////////\nFrontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {\n  Frontiers ret;\n  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)\n      << \"Graph and source should in the same device context\";\n  
CHECK_EQ(csr.indices->dtype, source->dtype)\n      << \"Graph and source should in the same dtype\";\n  CHECK_EQ(csr.num_rows, csr.num_cols)\n      << \"Graph traversal can only work on square-shaped CSR.\";\n  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, \"BFSNodesFrontiers\", {\n    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {\n      ret = impl::BFSNodesFrontiers<XPU, IdType>(csr, source);\n    });\n  });\n  return ret;\n}\n\nFrontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {\n  Frontiers ret;\n  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)\n      << \"Graph and source should in the same device context\";\n  CHECK_EQ(csr.indices->dtype, source->dtype)\n      << \"Graph and source should in the same dtype\";\n  CHECK_EQ(csr.num_rows, csr.num_cols)\n      << \"Graph traversal can only work on square-shaped CSR.\";\n  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, \"BFSEdgesFrontiers\", {\n    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {\n      ret = impl::BFSEdgesFrontiers<XPU, IdType>(csr, source);\n    });\n  });\n  return ret;\n}\n\nFrontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {\n  Frontiers ret;\n  CHECK_EQ(csr.num_rows, csr.num_cols)\n      << \"Graph traversal can only work on square-shaped CSR.\";\n  ATEN_XPU_SWITCH(\n      csr.indptr->ctx.device_type, XPU, \"TopologicalNodesFrontiers\", {\n        ATEN_ID_TYPE_SWITCH(csr.indices->dtype, IdType, {\n          ret = impl::TopologicalNodesFrontiers<XPU, IdType>(csr);\n        });\n      });\n  return ret;\n}\n\nFrontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {\n  Frontiers ret;\n  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)\n      << \"Graph and source should in the same device context\";\n  CHECK_EQ(csr.indices->dtype, source->dtype)\n      << \"Graph and source should in the same dtype\";\n  CHECK_EQ(csr.num_rows, csr.num_cols)\n      << \"Graph traversal can only work on square-shaped CSR.\";\n  
ATEN_XPU_SWITCH(source->ctx.device_type, XPU, \"DGLDFSEdges\", {\n    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {\n      ret = impl::DGLDFSEdges<XPU, IdType>(csr, source);\n    });\n  });\n  return ret;\n}\n\nFrontiers DGLDFSLabeledEdges(\n    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,\n    const bool has_nontree_edge, const bool return_labels) {\n  Frontiers ret;\n  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)\n      << \"Graph and source should in the same device context\";\n  CHECK_EQ(csr.indices->dtype, source->dtype)\n      << \"Graph and source should in the same dtype\";\n  CHECK_EQ(csr.num_rows, csr.num_cols)\n      << \"Graph traversal can only work on square-shaped CSR.\";\n  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, \"DGLDFSLabeledEdges\", {\n    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {\n      ret = impl::DGLDFSLabeledEdges<XPU, IdType>(\n          csr, source, has_reverse_edge, has_nontree_edge, return_labels);\n    });\n  });\n  return ret;\n}\n\nvoid CSRSpMM(\n    const std::string& op, const std::string& reduce, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {\n  const auto& bcast = CalcBcastOff(op, ufeat, efeat);\n\n  ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, \"SpMM\", {\n    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, \"Feature data\", {\n        SpMMCsr<XPU, IdType, Dtype>(\n            op, reduce, bcast, csr, ufeat, efeat, out, out_aux);\n      });\n    });\n  });\n}\n\nvoid CSRSpMM(\n    const char* op, const char* reduce, const CSRMatrix& csr, NDArray ufeat,\n    NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {\n  CSRSpMM(\n      std::string(op), std::string(reduce), csr, ufeat, efeat, out, out_aux);\n}\n\nvoid CSRSDDMM(\n    const std::string& op, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,\n    NDArray out, int lhs_target, int rhs_target) {\n  
const auto& bcast = CalcBcastOff(op, ufeat, efeat);\n\n  ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, \"SDDMM\", {\n    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, \"Feature data\", {\n        SDDMMCsr<XPU, IdType, Dtype>(\n            op, bcast, csr, ufeat, efeat, out, lhs_target, rhs_target);\n      });\n    });\n  });\n}\n\nvoid CSRSDDMM(\n    const char* op, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,\n    NDArray out, int lhs_target, int rhs_target) {\n  return CSRSDDMM(\n      std::string(op), csr, ufeat, efeat, out, lhs_target, rhs_target);\n}\n\nvoid COOSpMM(\n    const std::string& op, const std::string& reduce, const COOMatrix& coo,\n    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {\n  const auto& bcast = CalcBcastOff(op, ufeat, efeat);\n\n  ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, \"SpMM\", {\n    ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, \"Feature data\", {\n        SpMMCoo<XPU, IdType, Dtype>(\n            op, reduce, bcast, coo, ufeat, efeat, out, out_aux);\n      });\n    });\n  });\n}\n\nvoid COOSpMM(\n    const char* op, const char* reduce, const COOMatrix& coo, NDArray ufeat,\n    NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {\n  COOSpMM(\n      std::string(op), std::string(reduce), coo, ufeat, efeat, out, out_aux);\n}\n\nvoid COOSDDMM(\n    const std::string& op, const COOMatrix& coo, NDArray ufeat, NDArray efeat,\n    NDArray out, int lhs_target, int rhs_target) {\n  const auto& bcast = CalcBcastOff(op, ufeat, efeat);\n\n  ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, \"SDDMM\", {\n    ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, \"Feature data\", {\n        SDDMMCoo<XPU, IdType, Dtype>(\n            op, bcast, coo, ufeat, efeat, out, lhs_target, rhs_target);\n      });\n    });\n 
 });\n}\n\nvoid COOSDDMM(\n    const char* op, const COOMatrix& coo, NDArray ufeat, NDArray efeat,\n    NDArray out, int lhs_target, int rhs_target) {\n  COOSDDMM(std::string(op), coo, ufeat, efeat, out, lhs_target, rhs_target);\n}\n\n///////////////////////// C APIs /////////////////////////\nDGL_REGISTER_GLOBAL(\"ndarray._CAPI_DGLSparseMatrixGetFormat\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      SparseMatrixRef spmat = args[0];\n      *rv = spmat->format;\n    });\n\nDGL_REGISTER_GLOBAL(\"ndarray._CAPI_DGLSparseMatrixGetNumRows\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      SparseMatrixRef spmat = args[0];\n      *rv = spmat->num_rows;\n    });\n\nDGL_REGISTER_GLOBAL(\"ndarray._CAPI_DGLSparseMatrixGetNumCols\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      SparseMatrixRef spmat = args[0];\n      *rv = spmat->num_cols;\n    });\n\nDGL_REGISTER_GLOBAL(\"ndarray._CAPI_DGLSparseMatrixGetIndices\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      SparseMatrixRef spmat = args[0];\n      const int64_t i = args[1];\n      *rv = spmat->indices[i];\n    });\n\nDGL_REGISTER_GLOBAL(\"ndarray._CAPI_DGLSparseMatrixGetFlags\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      SparseMatrixRef spmat = args[0];\n      List<Value> flags;\n      for (bool flg : spmat->flags) {\n        flags.push_back(Value(MakeValue(flg)));\n      }\n      *rv = flags;\n    });\n\nDGL_REGISTER_GLOBAL(\"ndarray._CAPI_DGLCreateSparseMatrix\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t format = args[0];\n      const int64_t nrows = args[1];\n      const int64_t ncols = args[2];\n      const List<Value> indices = args[3];\n      const List<Value> flags = args[4];\n      std::shared_ptr<SparseMatrix> spmat(new SparseMatrix(\n          format, nrows, ncols, ListValueToVector<IdArray>(indices),\n          ListValueToVector<bool>(flags)));\n      *rv = SparseMatrixRef(spmat);\n    
});\n\nDGL_REGISTER_GLOBAL(\"ndarray._CAPI_DGLExistSharedMemArray\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const std::string name = args[0];\n#ifndef _WIN32\n      *rv = SharedMemory::Exist(name);\n#else\n      *rv = false;\n#endif  // _WIN32\n    });\n\nDGL_REGISTER_GLOBAL(\"ndarray._CAPI_DGLArrayCastToSigned\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArray array = args[0];\n      CHECK_EQ(array->dtype.code, kDGLUInt);\n      std::vector<int64_t> shape(array->shape, array->shape + array->ndim);\n      DGLDataType dtype = array->dtype;\n      dtype.code = kDGLInt;\n      *rv = array.CreateView(shape, dtype, 0);\n    });\n\n}  // namespace aten\n}  // namespace dgl\n\nstd::ostream& operator<<(std::ostream& os, dgl::runtime::NDArray array) {\n  return os << dgl::aten::ToDebugString(array);\n}\n"
  },
  {
    "path": "src/array/array_arith.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/array_aritch.cc\n * @brief DGL array arithmetic operations\n */\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/ndarray.h>\n\n#include \"../c_api_common.h\"\n#include \"./arith.h\"\n#include \"./array_op.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace aten {\n\n// Generate operators with both operations being NDArrays.\n#define BINARY_ELEMENT_OP(name, op)                                  \\\n  IdArray name(IdArray lhs, IdArray rhs) {                           \\\n    IdArray ret;                                                     \\\n    CHECK_SAME_DTYPE(lhs, rhs);                                      \\\n    CHECK_SAME_CONTEXT(lhs, rhs);                                    \\\n    ATEN_XPU_SWITCH_CUDA(lhs->ctx.device_type, XPU, #name, {         \\\n      ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {                      \\\n        ret = impl::BinaryElewise<XPU, IdType, arith::op>(lhs, rhs); \\\n      });                                                            \\\n    });                                                              \\\n    return ret;                                                      \\\n  }\n\n// Generate operators with only lhs being NDArray.\n#define BINARY_ELEMENT_OP_L(name, op)                                \\\n  IdArray name(IdArray lhs, int64_t rhs) {                           \\\n    IdArray ret;                                                     \\\n    ATEN_XPU_SWITCH_CUDA(lhs->ctx.device_type, XPU, #name, {         \\\n      ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {                      \\\n        ret = impl::BinaryElewise<XPU, IdType, arith::op>(lhs, rhs); \\\n      });                                                            \\\n    });                                                              \\\n    return ret;                                                      \\\n  }\n\n// Generate 
operators with only rhs being NDArray.\n#define BINARY_ELEMENT_OP_R(name, op)                                \\\n  IdArray name(int64_t lhs, IdArray rhs) {                           \\\n    IdArray ret;                                                     \\\n    ATEN_XPU_SWITCH_CUDA(rhs->ctx.device_type, XPU, #name, {         \\\n      ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, {                      \\\n        ret = impl::BinaryElewise<XPU, IdType, arith::op>(lhs, rhs); \\\n      });                                                            \\\n    });                                                              \\\n    return ret;                                                      \\\n  }\n\n// Generate unary operators taking a single NDArray.\n#define UNARY_ELEMENT_OP(name, op)                             \\\n  IdArray name(IdArray lhs) {                                  \\\n    IdArray ret;                                               \\\n    ATEN_XPU_SWITCH_CUDA(lhs->ctx.device_type, XPU, #name, {   \\\n      ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {                \\\n        ret = impl::UnaryElewise<XPU, IdType, arith::op>(lhs); \\\n      });                                                      \\\n    });                                                        \\\n    return ret;                                                \\\n  }\n\nBINARY_ELEMENT_OP(Add, Add)\nBINARY_ELEMENT_OP(Sub, Sub)\nBINARY_ELEMENT_OP(Mul, Mul)\nBINARY_ELEMENT_OP(Div, Div)\nBINARY_ELEMENT_OP(Mod, Mod)\nBINARY_ELEMENT_OP(GT, GT)\nBINARY_ELEMENT_OP(LT, LT)\nBINARY_ELEMENT_OP(GE, GE)\nBINARY_ELEMENT_OP(LE, LE)\nBINARY_ELEMENT_OP(EQ, EQ)\nBINARY_ELEMENT_OP(NE, NE)\n\nBINARY_ELEMENT_OP_L(Add, Add)\nBINARY_ELEMENT_OP_L(Sub, Sub)\nBINARY_ELEMENT_OP_L(Mul, Mul)\nBINARY_ELEMENT_OP_L(Div, Div)\nBINARY_ELEMENT_OP_L(Mod, Mod)\nBINARY_ELEMENT_OP_L(GT, GT)\nBINARY_ELEMENT_OP_L(LT, LT)\nBINARY_ELEMENT_OP_L(GE, GE)\nBINARY_ELEMENT_OP_L(LE, LE)\nBINARY_ELEMENT_OP_L(EQ, EQ)\nBINARY_ELEMENT_OP_L(NE, 
NE)\n\nBINARY_ELEMENT_OP_R(Add, Add)\nBINARY_ELEMENT_OP_R(Sub, Sub)\nBINARY_ELEMENT_OP_R(Mul, Mul)\nBINARY_ELEMENT_OP_R(Div, Div)\nBINARY_ELEMENT_OP_R(Mod, Mod)\nBINARY_ELEMENT_OP_R(GT, GT)\nBINARY_ELEMENT_OP_R(LT, LT)\nBINARY_ELEMENT_OP_R(GE, GE)\nBINARY_ELEMENT_OP_R(LE, LE)\nBINARY_ELEMENT_OP_R(EQ, EQ)\nBINARY_ELEMENT_OP_R(NE, NE)\n\nUNARY_ELEMENT_OP(Neg, Neg)\n\n}  // namespace aten\n}  // namespace dgl\n\n///////////////// Operator overloading for NDArray /////////////////\nNDArray operator+(const NDArray& lhs, const NDArray& rhs) {\n  return dgl::aten::Add(lhs, rhs);\n}\nNDArray operator-(const NDArray& lhs, const NDArray& rhs) {\n  return dgl::aten::Sub(lhs, rhs);\n}\nNDArray operator*(const NDArray& lhs, const NDArray& rhs) {\n  return dgl::aten::Mul(lhs, rhs);\n}\nNDArray operator/(const NDArray& lhs, const NDArray& rhs) {\n  return dgl::aten::Div(lhs, rhs);\n}\nNDArray operator%(const NDArray& lhs, const NDArray& rhs) {\n  return dgl::aten::Mod(lhs, rhs);\n}\nNDArray operator+(const NDArray& lhs, int64_t rhs) {\n  return dgl::aten::Add(lhs, rhs);\n}\nNDArray operator-(const NDArray& lhs, int64_t rhs) {\n  return dgl::aten::Sub(lhs, rhs);\n}\nNDArray operator*(const NDArray& lhs, int64_t rhs) {\n  return dgl::aten::Mul(lhs, rhs);\n}\nNDArray operator/(const NDArray& lhs, int64_t rhs) {\n  return dgl::aten::Div(lhs, rhs);\n}\nNDArray operator%(const NDArray& lhs, int64_t rhs) {\n  return dgl::aten::Mod(lhs, rhs);\n}\nNDArray operator+(int64_t lhs, const NDArray& rhs) {\n  return dgl::aten::Add(lhs, rhs);\n}\nNDArray operator-(int64_t lhs, const NDArray& rhs) {\n  return dgl::aten::Sub(lhs, rhs);\n}\nNDArray operator*(int64_t lhs, const NDArray& rhs) {\n  return dgl::aten::Mul(lhs, rhs);\n}\nNDArray operator/(int64_t lhs, const NDArray& rhs) {\n  return dgl::aten::Div(lhs, rhs);\n}\nNDArray operator%(int64_t lhs, const NDArray& rhs) {\n  return dgl::aten::Mod(lhs, rhs);\n}\nNDArray operator-(const NDArray& array) { return dgl::aten::Neg(array); }\n\nNDArray 
operator>(const NDArray& lhs, const NDArray& rhs) {\n  return dgl::aten::GT(lhs, rhs);\n}\nNDArray operator<(const NDArray& lhs, const NDArray& rhs) {\n  return dgl::aten::LT(lhs, rhs);\n}\nNDArray operator>=(const NDArray& lhs, const NDArray& rhs) {\n  return dgl::aten::GE(lhs, rhs);\n}\nNDArray operator<=(const NDArray& lhs, const NDArray& rhs) {\n  return dgl::aten::LE(lhs, rhs);\n}\nNDArray operator==(const NDArray& lhs, const NDArray& rhs) {\n  return dgl::aten::EQ(lhs, rhs);\n}\nNDArray operator!=(const NDArray& lhs, const NDArray& rhs) {\n  return dgl::aten::NE(lhs, rhs);\n}\nNDArray operator>(const NDArray& lhs, int64_t rhs) {\n  return dgl::aten::GT(lhs, rhs);\n}\nNDArray operator<(const NDArray& lhs, int64_t rhs) {\n  return dgl::aten::LT(lhs, rhs);\n}\nNDArray operator>=(const NDArray& lhs, int64_t rhs) {\n  return dgl::aten::GE(lhs, rhs);\n}\nNDArray operator<=(const NDArray& lhs, int64_t rhs) {\n  return dgl::aten::LE(lhs, rhs);\n}\nNDArray operator==(const NDArray& lhs, int64_t rhs) {\n  return dgl::aten::EQ(lhs, rhs);\n}\nNDArray operator!=(const NDArray& lhs, int64_t rhs) {\n  return dgl::aten::NE(lhs, rhs);\n}\nNDArray operator>(int64_t lhs, const NDArray& rhs) {\n  return dgl::aten::GT(lhs, rhs);\n}\nNDArray operator<(int64_t lhs, const NDArray& rhs) {\n  return dgl::aten::LT(lhs, rhs);\n}\nNDArray operator>=(int64_t lhs, const NDArray& rhs) {\n  return dgl::aten::GE(lhs, rhs);\n}\nNDArray operator<=(int64_t lhs, const NDArray& rhs) {\n  return dgl::aten::LE(lhs, rhs);\n}\nNDArray operator==(int64_t lhs, const NDArray& rhs) {\n  return dgl::aten::EQ(lhs, rhs);\n}\nNDArray operator!=(int64_t lhs, const NDArray& rhs) {\n  return dgl::aten::NE(lhs, rhs);\n}\n"
  },
  {
    "path": "src/array/array_op.h",
    "content": "/**\n *  Copyright (c) 2019-2022 by Contributors\n * @file array/array_op.h\n * @brief Array operator templates\n */\n#ifndef DGL_ARRAY_ARRAY_OP_H_\n#define DGL_ARRAY_ARRAY_OP_H_\n\n#include <dgl/array.h>\n#include <dgl/graph_traversal.h>\n\n#include <tuple>\n#include <utility>\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray Full(IdType val, int64_t length, DGLContext ctx);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray Range(IdType low, IdType high, DGLContext ctx);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray AsNumBits(IdArray arr, uint8_t bits);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename Op>\nIdArray BinaryElewise(IdArray lhs, IdArray rhs);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename Op>\nIdArray BinaryElewise(IdArray lhs, IdType rhs);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename Op>\nIdArray BinaryElewise(IdType lhs, IdArray rhs);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename Op>\nIdArray UnaryElewise(IdArray array);\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nNDArray IndexSelect(NDArray array, IdArray index);\n\ntemplate <DGLDeviceType XPU, typename DType>\nDType IndexSelect(NDArray array, int64_t index);\n\ntemplate <DGLDeviceType XPU, typename DType>\nIdArray NonZero(BoolArray bool_arr);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray NonZero(NDArray array);\n\ntemplate <DGLDeviceType XPU, typename DType>\nstd::pair<IdArray, IdArray> Sort(IdArray array, int num_bits);\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nNDArray Scatter(NDArray array, IdArray indices);\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nvoid Scatter_(IdArray index, NDArray value, NDArray out);\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nNDArray Repeat(NDArray array, IdArray repeats);\n\ntemplate <DGLDeviceType XPU, 
typename IdType>\nIdArray Relabel_(const std::vector<IdArray>& arrays);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray Concat(const std::vector<IdArray>& arrays);\n\ntemplate <DGLDeviceType XPU, typename DType>\nstd::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value);\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nstd::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray CumSum(IdArray array, bool prepend_zero);\n\n// sparse arrays\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nruntime::NDArray CSRIsNonZero(\n    CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool CSRHasDuplicate(CSRMatrix csr);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nint64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nruntime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nruntime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nruntime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool CSRIsSorted(CSRMatrix csr);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename DType>\nruntime::NDArray CSRGetData(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,\n    bool return_eids, runtime::NDArray weights, DType filler);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename DType>\nruntime::NDArray CSRGetData(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,\n    runtime::NDArray weights, DType filler) {\n  return CSRGetData<XPU, IdType, DType>(\n      csr, rows, cols, false, weights, filler);\n}\n\ntemplate <DGLDeviceType XPU, typename 
IdType>\nNDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {\n  return CSRGetData<XPU, IdType, IdType>(\n      csr, rows, cols, true, NullArray(rows->dtype), -1);\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::vector<runtime::NDArray> CSRGetDataAndIndices(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRTranspose(CSRMatrix csr);\n\n// Convert CSR to COO\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix CSRToCOO(CSRMatrix csr);\n\n// Convert CSR to COO using data array as order\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix CSRToCOODataAsOrder(CSRMatrix csr);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRSliceMatrix(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid CSRSort_(CSRMatrix* csr);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename TagType>\nstd::pair<CSRMatrix, NDArray> CSRSortByTag(\n    const CSRMatrix& csr, IdArray tag_array, int64_t num_tags);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRReorder(\n    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOReorder(\n    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename FloatType>\nstd::pair<COOMatrix, FloatArray> CSRLaborSampling(\n    CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,\n    int importance_sampling, IdArray random_seed, float seed2_contribution,\n 
   IdArray NIDs);\n\n// FloatType is the type of probability data.\ntemplate <DGLDeviceType XPU, typename IdType, typename DType>\nCOOMatrix CSRRowWiseSampling(\n    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,\n    bool replace);\n\n// FloatType is the type of probability data.\ntemplate <\n    DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes>\nstd::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(\n    CSRMatrix mat, IdArray rows, IdArray seed_mapping,\n    std::vector<IdxType>* new_seed_nodes, int64_t num_samples,\n    NDArray prob_or_mask, bool replace);\n\n// FloatType is the type of probability data.\ntemplate <DGLDeviceType XPU, typename IdType, typename DType>\nCOOMatrix CSRRowWisePerEtypeSampling(\n    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_samples,\n    const std::vector<NDArray>& prob_or_mask, bool replace,\n    bool rowwise_etype_sorted);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix CSRRowWiseSamplingUniform(\n    CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);\n\ntemplate <DGLDeviceType XPU, typename IdType, bool map_seed_nodes>\nstd::pair<CSRMatrix, IdArray> CSRRowWiseSamplingUniformFused(\n    CSRMatrix mat, IdArray rows, IdArray seed_mapping,\n    std::vector<IdType>* new_seed_nodes, int64_t num_samples, bool replace);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix CSRRowWisePerEtypeSamplingUniform(\n    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_samples, bool replace,\n    bool rowwise_etype_sorted);\n\n// FloatType is the type of weight data.\ntemplate <DGLDeviceType XPU, typename IdType, typename DType>\nCOOMatrix CSRRowWiseTopk(\n    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename FloatType>\nCOOMatrix CSRRowWiseSamplingBiased(\n  
  CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,\n    FloatArray bias, bool replace);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(\n    const CSRMatrix& csr, int64_t num_samples, int num_trials,\n    bool exclude_self_loops, bool replace, double redundancy);\n\n// Union CSRMatrixes\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);\n\n////////////////////////////////////////////////////////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nruntime::NDArray COOIsNonZero(\n    COOMatrix coo, runtime::NDArray row, runtime::NDArray col);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool COOHasDuplicate(COOMatrix coo);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nint64_t COOGetRowNNZ(COOMatrix coo, int64_t row);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nruntime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(\n    COOMatrix coo, int64_t row);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::vector<runtime::NDArray> COOGetDataAndIndices(\n    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nruntime::NDArray COOGetData(\n    COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOTranspose(COOMatrix coo);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix COOToCSR(COOMatrix coo);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t 
end);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOSliceMatrix(\n    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid COOSort_(COOMatrix* mat, bool sort_column);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<bool, bool> COOIsSorted(COOMatrix coo);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COORemove(COOMatrix coo, IdArray entries);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename FloatType>\nstd::pair<COOMatrix, FloatArray> COOLaborSampling(\n    COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,\n    int importance_sampling, IdArray random_seed, float seed2_contribution,\n    IdArray NIDs);\n\n// FloatType is the type of probability data.\ntemplate <DGLDeviceType XPU, typename IdType, typename DType>\nCOOMatrix COORowWiseSampling(\n    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,\n    bool replace);\n\n// FloatType is the type of probability data.\ntemplate <DGLDeviceType XPU, typename IdType, typename DType>\nCOOMatrix COORowWisePerEtypeSampling(\n    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_samples,\n    const std::vector<NDArray>& prob_or_mask, bool replace);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COORowWiseSamplingUniform(\n    COOMatrix mat, IdArray rows, int64_t num_samples, bool replace);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COORowWisePerEtypeSamplingUniform(\n    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const 
std::vector<int64_t>& num_samples, bool replace);\n\n// FloatType is the type of weight data.\ntemplate <DGLDeviceType XPU, typename IdType, typename FloatType>\nCOOMatrix COORowWiseTopk(\n    COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);\n\n///////////////////////// Graph Traverse routines //////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFrontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFrontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFrontiers TopologicalNodesFrontiers(const CSRMatrix& csr);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFrontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFrontiers DGLDFSLabeledEdges(\n    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,\n    const bool has_nontree_edge, const bool return_labels);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_ARRAY_OP_H_\n"
  },
  {
    "path": "src/array/check.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/check.h\n * @brief DGL check utilities\n */\n#ifndef DGL_ARRAY_CHECK_H_\n#define DGL_ARRAY_CHECK_H_\n\n#include <dgl/array.h>\n#include <dgl/runtime/ndarray.h>\n\n#include <string>\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\n\n// Check whether the given arguments have the same context.\ninline void CheckCtx(\n    const DGLContext& ctx, const std::vector<NDArray>& arrays,\n    const std::vector<std::string>& names) {\n  for (size_t i = 0; i < arrays.size(); ++i) {\n    if (IsNullArray(arrays[i])) continue;\n    CHECK_EQ(ctx, arrays[i]->ctx)\n        << \"Expected device context \" << ctx << \". But got \" << arrays[i]->ctx\n        << \" for \" << names[i] << \".\";\n  }\n}\n\n// Check whether input tensors are contiguous.\ninline void CheckContiguous(\n    const std::vector<NDArray>& arrays, const std::vector<std::string>& names) {\n  for (size_t i = 0; i < arrays.size(); ++i) {\n    if (IsNullArray(arrays[i])) continue;\n    CHECK(arrays[i].IsContiguous())\n        << \"Expect \" << names[i] << \" to be a contiguous tensor\";\n  }\n}\n\n// Check whether input tensors have valid shape.\ninline void CheckShape(\n    const std::vector<uint64_t>& gdim, const std::vector<int>& uev_idx,\n    const std::vector<NDArray>& arrays, const std::vector<std::string>& names) {\n  for (size_t i = 0; i < arrays.size(); ++i) {\n    if (IsNullArray(arrays[i])) continue;\n    CHECK_GE(arrays[i]->ndim, 2)\n        << \"Expect \" << names[i] << \" to have ndim >= 2, \"\n        << \"Note that for scalar feature we expand its \"\n        << \"dimension with an additional dimension of \"\n        << \"length one.\";\n    CHECK_EQ(gdim[uev_idx[i]], arrays[i]->shape[0])\n        << \"Expect \" << names[i] << \" to have size \" << gdim[uev_idx[i]]\n        << \" on the first dimension, \"\n        << \"but got \" << arrays[i]->shape[0];\n  }\n}\n\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // 
DGL_ARRAY_CHECK_H_\n"
  },
  {
    "path": "src/array/cpu/array_cumsum.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/array_cumsum.cc\n * @brief Array cumsum CPU implementation\n */\n#include <dgl/array.h>\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray CumSum(IdArray array, bool prepend_zero) {\n  const int64_t len = array.NumElements();\n  if (len == 0)\n    return !prepend_zero ? array\n                         : aten::Full(0, 1, array->dtype.bits, array->ctx);\n  if (prepend_zero) {\n    IdArray ret = aten::NewIdArray(len + 1, array->ctx, array->dtype.bits);\n    const IdType* in_d = array.Ptr<IdType>();\n    IdType* out_d = ret.Ptr<IdType>();\n    out_d[0] = 0;\n    for (int64_t i = 0; i < len; ++i) out_d[i + 1] = out_d[i] + in_d[i];\n    return ret;\n  } else {\n    IdArray ret = aten::NewIdArray(len, array->ctx, array->dtype.bits);\n    const IdType* in_d = array.Ptr<IdType>();\n    IdType* out_d = ret.Ptr<IdType>();\n    out_d[0] = in_d[0];\n    for (int64_t i = 1; i < len; ++i) out_d[i] = out_d[i - 1] + in_d[i];\n    return ret;\n  }\n}\n\ntemplate IdArray CumSum<kDGLCPU, int32_t>(IdArray, bool);\ntemplate IdArray CumSum<kDGLCPU, int64_t>(IdArray, bool);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/array_index_select.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cpu/array_index_select.cc\n * @brief Array index select CPU implementation\n */\n#include <dgl/array.h>\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nNDArray IndexSelect(NDArray array, IdArray index) {\n  CHECK_EQ(array->shape[0], array.NumElements())\n      << \"Only support tensor\"\n      << \" whose first dimension equals number of elements, e.g. (5,), (5, 1)\";\n\n  const DType* array_data = static_cast<DType*>(array->data);\n  const IdType* idx_data = static_cast<IdType*>(index->data);\n  const int64_t arr_len = array->shape[0];\n  const int64_t len = index->shape[0];\n  NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);\n  DType* ret_data = static_cast<DType*>(ret->data);\n  for (int64_t i = 0; i < len; ++i) {\n    CHECK_LT(idx_data[i], arr_len) << \"Index out of range.\";\n    ret_data[i] = array_data[idx_data[i]];\n  }\n  return ret;\n}\n\ntemplate NDArray IndexSelect<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCPU, float, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCPU, float, int64_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCPU, double, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCPU, double, int64_t>(NDArray, IdArray);\n\ntemplate <DGLDeviceType XPU, typename DType>\nDType IndexSelect(NDArray array, int64_t index) {\n  const DType* data = static_cast<DType*>(array->data);\n  return data[index];\n}\n\ntemplate int32_t IndexSelect<kDGLCPU, int32_t>(NDArray array, int64_t index);\ntemplate int64_t IndexSelect<kDGLCPU, int64_t>(NDArray array, int64_t 
index);\ntemplate float IndexSelect<kDGLCPU, float>(NDArray array, int64_t index);\ntemplate double IndexSelect<kDGLCPU, double>(NDArray array, int64_t index);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/array_nonzero.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/array_nonzero.cc\n * @brief Array nonzero CPU implementation\n */\n#include <dgl/array.h>\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray NonZero(IdArray array) {\n  std::vector<int64_t> ret;\n  const IdType* data = array.Ptr<IdType>();\n  for (int64_t i = 0; i < array->shape[0]; ++i)\n    if (data[i] != 0) ret.push_back(i);\n  return NDArray::FromVector(ret, array->ctx);\n}\n\ntemplate IdArray NonZero<kDGLCPU, int32_t>(IdArray);\ntemplate IdArray NonZero<kDGLCPU, int64_t>(IdArray);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/array_op_impl.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cpu/array_op_impl.cc\n * @brief Array operator CPU implementation\n */\n#include <dgl/array.h>\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <numeric>\n\n#include \"../arith.h\"\n\nnamespace dgl {\nusing runtime::NDArray;\nusing runtime::parallel_for;\nnamespace aten {\nnamespace impl {\n\n///////////////////////////// AsNumBits /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray AsNumBits(IdArray arr, uint8_t bits) {\n  CHECK(bits == 32 || bits == 64) << \"invalid number of integer bits\";\n  if (sizeof(IdType) * 8 == bits) {\n    return arr;\n  }\n  const int64_t len = arr->shape[0];\n  IdArray ret = NewIdArray(len, arr->ctx, bits);\n  const IdType* arr_data = static_cast<IdType*>(arr->data);\n  if (bits == 32) {\n    int32_t* ret_data = static_cast<int32_t*>(ret->data);\n    for (int64_t i = 0; i < len; ++i) {\n      ret_data[i] = arr_data[i];\n    }\n  } else {\n    int64_t* ret_data = static_cast<int64_t*>(ret->data);\n    for (int64_t i = 0; i < len; ++i) {\n      ret_data[i] = arr_data[i];\n    }\n  }\n  return ret;\n}\n\ntemplate IdArray AsNumBits<kDGLCPU, int32_t>(IdArray arr, uint8_t bits);\ntemplate IdArray AsNumBits<kDGLCPU, int64_t>(IdArray arr, uint8_t bits);\n\n///////////////////////////// BinaryElewise /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType, typename Op>\nIdArray BinaryElewise(IdArray lhs, IdArray rhs) {\n  IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);\n  const IdType* lhs_data = static_cast<IdType*>(lhs->data);\n  const IdType* rhs_data = static_cast<IdType*>(rhs->data);\n  IdType* ret_data = static_cast<IdType*>(ret->data);\n  // TODO(BarclayII): this usually incurs lots of overhead in thread spawning,\n  // scheduling, etc., especially since the workload is very light.  
Need to\n  // replace with parallel_for.\n  for (int64_t i = 0; i < lhs->shape[0]; i++) {\n    ret_data[i] = Op::Call(lhs_data[i], rhs_data[i]);\n  }\n  return ret;\n}\n\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray 
BinaryElewise<kDGLCPU, int64_t, arith::LE>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(\n    IdArray lhs, IdArray rhs);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename Op>\nIdArray BinaryElewise(IdArray lhs, IdType rhs) {\n  IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);\n  const IdType* lhs_data = static_cast<IdType*>(lhs->data);\n  IdType* ret_data = static_cast<IdType*>(ret->data);\n  // TODO(BarclayII): this usually incurs lots of overhead in thread spawning,\n  // scheduling, etc., especially since the workload is very light.  Need to\n  // replace with parallel_for.\n  for (int64_t i = 0; i < lhs->shape[0]; i++) {\n    ret_data[i] = Op::Call(lhs_data[i], rhs);\n  }\n  return ret;\n}\n\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(\n    IdArray lhs, 
int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(\n    IdArray lhs, int64_t rhs);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename Op>\nIdArray BinaryElewise(IdType lhs, IdArray rhs) {\n  IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);\n  const IdType* rhs_data = static_cast<IdType*>(rhs->data);\n  IdType* ret_data = static_cast<IdType*>(ret->data);\n  // TODO(BarclayII): this usually incurs lots of overhead in thread spawning,\n  // scheduling, etc., especially since the workload is very light.  
Need to\n  // replace with parallel_for.\n  for (int64_t i = 0; i < rhs->shape[0]; i++) {\n    ret_data[i] = Op::Call(lhs, rhs_data[i]);\n  }\n  return ret;\n}\n\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, 
int64_t, arith::LE>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(\n    int64_t lhs, IdArray rhs);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename Op>\nIdArray UnaryElewise(IdArray lhs) {\n  IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);\n  const IdType* lhs_data = static_cast<IdType*>(lhs->data);\n  IdType* ret_data = static_cast<IdType*>(ret->data);\n  // TODO(BarclayII): this usually incurs lots of overhead in thread spawning,\n  // scheduling, etc., especially since the workload is very light.  Need to\n  // replace with parallel_for.\n  for (int64_t i = 0; i < lhs->shape[0]; i++) {\n    ret_data[i] = Op::Call(lhs_data[i]);\n  }\n  return ret;\n}\n\ntemplate IdArray UnaryElewise<kDGLCPU, int32_t, arith::Neg>(IdArray lhs);\ntemplate IdArray UnaryElewise<kDGLCPU, int64_t, arith::Neg>(IdArray lhs);\n\n///////////////////////////// Full /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename DType>\nNDArray Full(DType val, int64_t length, DGLContext ctx) {\n  NDArray ret = NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx);\n  DType* ret_data = static_cast<DType*>(ret->data);\n  std::fill(ret_data, ret_data + length, val);\n  return ret;\n}\n\ntemplate NDArray Full<kDGLCPU, int32_t>(\n    int32_t val, int64_t length, DGLContext ctx);\ntemplate NDArray Full<kDGLCPU, int64_t>(\n    int64_t val, int64_t length, DGLContext ctx);\ntemplate NDArray Full<kDGLCPU, float>(\n    float val, int64_t length, DGLContext ctx);\ntemplate NDArray Full<kDGLCPU, double>(\n    double val, int64_t length, DGLContext ctx);\n\n///////////////////////////// Range /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray Range(IdType low, IdType high, DGLContext ctx) {\n  CHECK(high >= low) << \"high must be bigger than low\";\n  IdArray ret = NewIdArray(high - 
low, ctx, sizeof(IdType) * 8);\n  IdType* ret_data = static_cast<IdType*>(ret->data);\n  std::iota(ret_data, ret_data + high - low, low);\n  return ret;\n}\n\ntemplate IdArray Range<kDGLCPU, int32_t>(int32_t, int32_t, DGLContext);\ntemplate IdArray Range<kDGLCPU, int64_t>(int64_t, int64_t, DGLContext);\n\n///////////////////////////// Relabel_ /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray Relabel_(const std::vector<IdArray>& arrays) {\n  // build map & relabel\n  IdType newid = 0;\n  std::unordered_map<IdType, IdType> oldv2newv;\n  for (IdArray arr : arrays) {\n    for (int64_t i = 0; i < arr->shape[0]; ++i) {\n      const IdType id = static_cast<IdType*>(arr->data)[i];\n      if (!oldv2newv.count(id)) {\n        oldv2newv[id] = newid++;\n      }\n      static_cast<IdType*>(arr->data)[i] = oldv2newv[id];\n    }\n  }\n  // map array\n  IdArray maparr =\n      NewIdArray(newid, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);\n  IdType* maparr_data = static_cast<IdType*>(maparr->data);\n  for (const auto& kv : oldv2newv) {\n    maparr_data[kv.second] = kv.first;\n  }\n  return maparr;\n}\n\ntemplate IdArray Relabel_<kDGLCPU, int32_t>(const std::vector<IdArray>& arrays);\ntemplate IdArray Relabel_<kDGLCPU, int64_t>(const std::vector<IdArray>& arrays);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/array_pack.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cpu/array_pack.cc\n * @brief Array pack CPU implementation\n */\n#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <tuple>\n#include <utility>\n\nnamespace dgl {\nusing runtime::NDArray;\nusing runtime::parallel_for;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nstd::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {\n  const int64_t rows = lengths->shape[0];\n  const int64_t cols = (array->ndim == 1 ? array->shape[0] : array->shape[1]);\n  const int64_t stride = (array->ndim == 1 ? 0 : cols);\n  const DType *array_data = static_cast<DType *>(array->data);\n  const IdType *length_data = static_cast<IdType *>(lengths->data);\n\n  IdArray offsets = NewIdArray(rows, array->ctx, sizeof(IdType) * 8);\n  IdType *offsets_data = static_cast<IdType *>(offsets->data);\n  for (int64_t i = 0; i < rows; ++i)\n    offsets_data[i] = (i == 0 ? 
0 : length_data[i - 1] + offsets_data[i - 1]);\n  const int64_t total_length = offsets_data[rows - 1] + length_data[rows - 1];\n\n  NDArray concat = NDArray::Empty({total_length}, array->dtype, array->ctx);\n  DType *concat_data = static_cast<DType *>(concat->data);\n\n  parallel_for(0, rows, [=](size_t b, size_t e) {\n    for (auto i = b; i < e; ++i) {\n      for (int64_t j = 0; j < length_data[i]; ++j)\n        concat_data[offsets_data[i] + j] = array_data[i * stride + j];\n    }\n  });\n\n  return std::make_pair(concat, offsets);\n}\n\ntemplate std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int32_t>(\n    NDArray, IdArray);\ntemplate std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int32_t>(\n    NDArray, IdArray);\ntemplate std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int32_t>(\n    NDArray, IdArray);\ntemplate std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int32_t>(\n    NDArray, IdArray);\ntemplate std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int64_t>(\n    NDArray, IdArray);\ntemplate std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int64_t>(\n    NDArray, IdArray);\ntemplate std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int64_t>(\n    NDArray, IdArray);\ntemplate std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int64_t>(\n    NDArray, IdArray);\n\ntemplate <DGLDeviceType XPU, typename DType>\nstd::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {\n  CHECK_NDIM(array, 2, \"array\");\n  const DType *array_data = static_cast<DType *>(array->data);\n  const int64_t rows = array->shape[0];\n  const int64_t cols = array->shape[1];\n\n  IdArray length = NewIdArray(rows, array->ctx);\n  int64_t *length_data = static_cast<int64_t *>(length->data);\n  parallel_for(0, rows, [=](size_t b, size_t e) {\n    for (auto i = b; i < e; ++i) {\n      int64_t j;\n      for (j = 0; j < cols; ++j) {\n        const DType val = array_data[i * cols + j];\n        if (val == 
pad_value) break;\n      }\n      length_data[i] = j;\n    }\n  });\n\n  auto ret = ConcatSlices<XPU, DType, int64_t>(array, length);\n  return std::make_tuple(ret.first, length, ret.second);\n}\n\ntemplate std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int32_t>(\n    NDArray, int32_t);\ntemplate std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int64_t>(\n    NDArray, int64_t);\ntemplate std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, float>(\n    NDArray, float);\ntemplate std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, double>(\n    NDArray, double);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/array_repeat.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/array_repeat.cc\n * @brief Array repeat CPU implementation\n */\n#include <dgl/array.h>\n\n#include <algorithm>\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nNDArray Repeat(NDArray array, IdArray repeats) {\n  CHECK(array->shape[0] == repeats->shape[0])\n      << \"shape of array and repeats mismatch\";\n\n  const int64_t len = array->shape[0];\n  const DType *array_data = static_cast<DType *>(array->data);\n  const IdType *repeats_data = static_cast<IdType *>(repeats->data);\n\n  IdType num_elements = 0;\n  for (int64_t i = 0; i < len; ++i) num_elements += repeats_data[i];\n\n  NDArray result = NDArray::Empty({num_elements}, array->dtype, array->ctx);\n  DType *result_data = static_cast<DType *>(result->data);\n  IdType curr = 0;\n  for (int64_t i = 0; i < len; ++i) {\n    std::fill(\n        result_data + curr, result_data + curr + repeats_data[i],\n        array_data[i]);\n    curr += repeats_data[i];\n  }\n\n  return result;\n}\n\ntemplate NDArray Repeat<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);\ntemplate NDArray Repeat<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);\ntemplate NDArray Repeat<kDGLCPU, float, int32_t>(NDArray, IdArray);\ntemplate NDArray Repeat<kDGLCPU, double, int32_t>(NDArray, IdArray);\ntemplate NDArray Repeat<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);\ntemplate NDArray Repeat<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);\ntemplate NDArray Repeat<kDGLCPU, float, int64_t>(NDArray, IdArray);\ntemplate NDArray Repeat<kDGLCPU, double, int64_t>(NDArray, IdArray);\n\n};  // namespace impl\n};  // namespace aten\n};  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/array_scatter.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cpu/array_scatter.cc\n * @brief Array scatter CPU implementation\n */\n#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nNDArray Scatter(NDArray array, IdArray indices) {\n  NDArray result =\n      NDArray::Empty({indices->shape[0]}, array->dtype, array->ctx);\n\n  const DType *array_data = static_cast<DType *>(array->data);\n  const IdType *indices_data = static_cast<IdType *>(indices->data);\n  DType *result_data = static_cast<DType *>(result->data);\n\n  for (int64_t i = 0; i < indices->shape[0]; ++i)\n    result_data[indices_data[i]] = array_data[i];\n\n  return result;\n}\n\ntemplate NDArray Scatter<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);\ntemplate NDArray Scatter<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);\ntemplate NDArray Scatter<kDGLCPU, float, int32_t>(NDArray, IdArray);\ntemplate NDArray Scatter<kDGLCPU, double, int32_t>(NDArray, IdArray);\ntemplate NDArray Scatter<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);\ntemplate NDArray Scatter<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);\ntemplate NDArray Scatter<kDGLCPU, float, int64_t>(NDArray, IdArray);\ntemplate NDArray Scatter<kDGLCPU, double, int64_t>(NDArray, IdArray);\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nvoid Scatter_(IdArray index, NDArray value, NDArray out) {\n  const int64_t len = index->shape[0];\n  const IdType *idx = index.Ptr<IdType>();\n  const DType *val = value.Ptr<DType>();\n  DType *outd = out.Ptr<DType>();\n  runtime::parallel_for(0, len, [&](size_t b, size_t e) {\n    for (auto i = b; i < e; ++i) {\n      outd[idx[i]] = val[i];\n    }\n  });\n}\n\ntemplate void Scatter_<kDGLCPU, int32_t, int32_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCPU, int64_t, int32_t>(IdArray, NDArray, NDArray);\ntemplate void 
Scatter_<kDGLCPU, float, int32_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCPU, double, int32_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCPU, int32_t, int64_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCPU, int64_t, int64_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCPU, float, int64_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCPU, double, int64_t>(IdArray, NDArray, NDArray);\n\n};  // namespace impl\n};  // namespace aten\n};  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/array_sort.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/array_sort.cc\n * @brief Array sort CPU implementation\n */\n#include <dgl/array.h>\n#ifdef PARALLEL_ALGORITHMS\n#include <parallel/algorithm>\n#endif\n#include <algorithm>\n#include <iterator>\n\nnamespace {\n\ntemplate <typename V1, typename V2>\nstruct PairRef {\n  PairRef() = delete;\n  PairRef(const PairRef& other) = default;\n  PairRef(PairRef&& other) = default;\n  PairRef(V1* const r, V2* const c) : row(r), col(c) {}\n\n  PairRef& operator=(const PairRef& other) {\n    *row = *other.row;\n    *col = *other.col;\n    return *this;\n  }\n  PairRef& operator=(const std::pair<V1, V2>& val) {\n    *row = std::get<0>(val);\n    *col = std::get<1>(val);\n    return *this;\n  }\n\n  operator std::pair<V1, V2>() const { return std::make_pair(*row, *col); }\n\n  void Swap(const PairRef& other) const {\n    std::swap(*row, *other.row);\n    std::swap(*col, *other.col);\n  }\n\n  V1* row;\n  V2* col;\n};\n\nusing std::swap;\ntemplate <typename V1, typename V2>\nvoid swap(const PairRef<V1, V2>& r1, const PairRef<V1, V2>& r2) {\n  r1.Swap(r2);\n}\n\ntemplate <typename V1, typename V2>\nstruct PairIterator\n    : public std::iterator<\n          std::random_access_iterator_tag, std::pair<V1, V2>, std::ptrdiff_t,\n          std::pair<V1*, V2*>, PairRef<V1, V2>> {\n  PairIterator() = default;\n  PairIterator(const PairIterator& other) = default;\n  PairIterator(PairIterator&& other) = default;\n  PairIterator(V1* r, V2* c) : row(r), col(c) {}\n\n  PairIterator& operator=(const PairIterator& other) = default;\n  PairIterator& operator=(PairIterator&& other) = default;\n  ~PairIterator() = default;\n\n  bool operator==(const PairIterator& other) const { return row == other.row; }\n\n  bool operator!=(const PairIterator& other) const { return row != other.row; }\n\n  bool operator<(const PairIterator& other) const { return row < other.row; }\n\n  bool operator>(const PairIterator& other) const { 
return row > other.row; }\n\n  bool operator<=(const PairIterator& other) const { return row <= other.row; }\n\n  bool operator>=(const PairIterator& other) const { return row >= other.row; }\n\n  PairIterator& operator+=(const std::ptrdiff_t& movement) {\n    row += movement;\n    col += movement;\n    return *this;\n  }\n\n  PairIterator& operator-=(const std::ptrdiff_t& movement) {\n    row -= movement;\n    col -= movement;\n    return *this;\n  }\n\n  PairIterator& operator++() { return operator+=(1); }\n\n  PairIterator& operator--() { return operator-=(1); }\n\n  PairIterator operator++(int) {\n    PairIterator ret(*this);\n    operator++();\n    return ret;\n  }\n\n  PairIterator operator--(int) {\n    PairIterator ret(*this);\n    operator--();\n    return ret;\n  }\n\n  PairIterator operator+(const std::ptrdiff_t& movement) const {\n    PairIterator ret(*this);\n    ret += movement;\n    return ret;\n  }\n\n  PairIterator operator-(const std::ptrdiff_t& movement) const {\n    PairIterator ret(*this);\n    ret -= movement;\n    return ret;\n  }\n\n  std::ptrdiff_t operator-(const PairIterator& other) const {\n    return row - other.row;\n  }\n\n  PairRef<V1, V2> operator*() const { return PairRef<V1, V2>(row, col); }\n  PairRef<V1, V2> operator*() { return PairRef<V1, V2>(row, col); }\n\n  // required for random access iterators in VS2019\n  PairRef<V1, V2> operator[](size_t offset) const {\n    return PairRef<V1, V2>(row + offset, col + offset);\n  }\n\n  V1* row;\n  V2* col;\n};\n\n}  // namespace\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {\n  const int64_t nitem = array->shape[0];\n  IdArray val = array.Clone();\n  IdArray idx = aten::Range(0, nitem, 64, array->ctx);\n  IdType* val_data = val.Ptr<IdType>();\n  int64_t* idx_data = idx.Ptr<int64_t>();\n  typedef std::pair<IdType, int64_t> Pair;\n#ifdef 
PARALLEL_ALGORITHMS\n  __gnu_parallel::sort(\n#else\n  std::sort(\n#endif\n      PairIterator<IdType, int64_t>(val_data, idx_data),\n      PairIterator<IdType, int64_t>(val_data, idx_data) + nitem,\n      [](const Pair& a, const Pair& b) {\n        return std::get<0>(a) < std::get<0>(b);\n      });\n  return std::make_pair(val, idx);\n}\n\ntemplate std::pair<IdArray, IdArray> Sort<kDGLCPU, int32_t>(\n    IdArray, int num_bits);\ntemplate std::pair<IdArray, IdArray> Sort<kDGLCPU, int64_t>(\n    IdArray, int num_bits);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/array_utils.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file dgl/array_utils.h\n * @brief Utility classes and functions for DGL arrays.\n */\n#ifndef DGL_ARRAY_CPU_ARRAY_UTILS_H_\n#define DGL_ARRAY_CPU_ARRAY_UTILS_H_\n\n#include <dgl/aten/types.h>\n#include <tsl/robin_map.h>\n\n#include <unordered_map>\n#include <utility>\n#include <vector>\n\n#include \"../../c_api_common.h\"\n\nnamespace dgl {\nnamespace aten {\n\n/**\n * @brief A hashmap that maps each ids in the given array to new ids starting\n * from zero.\n *\n * Useful for relabeling integers and finding unique integers.\n *\n * Usually faster than std::unordered_map in existence checking.\n */\ntemplate <typename IdType>\nclass IdHashMap {\n public:\n  // default ctor\n  IdHashMap() : filter_(kFilterSize, false) {}\n\n  // Construct the hashmap using the given id array.\n  // The id array could contain duplicates.\n  // If the id array has no duplicates, the array will be relabeled to\n  // consecutive integers starting from 0.\n  explicit IdHashMap(IdArray ids) : filter_(kFilterSize, false) {\n    oldv2newv_.reserve(ids->shape[0]);\n    Update(ids);\n  }\n\n  // copy ctor\n  IdHashMap(const IdHashMap& other) = default;\n\n  void Reserve(const int64_t size) { oldv2newv_.reserve(size); }\n\n  // Update the hashmap with given id array.\n  // The id array could contain duplicates.\n  void Update(IdArray ids) {\n    const IdType* ids_data = static_cast<IdType*>(ids->data);\n    const int64_t len = ids->shape[0];\n    for (int64_t i = 0; i < len; ++i) {\n      const IdType id = ids_data[i];\n      // Insertion will not happen if the key already exists.\n      oldv2newv_.insert({id, oldv2newv_.size()});\n      filter_[id & kFilterMask] = true;\n    }\n  }\n\n  // Return true if the given id is contained in this hashmap.\n  bool Contains(IdType id) const {\n    return filter_[id & kFilterMask] && oldv2newv_.count(id);\n  }\n\n  // Return the new id of the given id. 
If the given id is not contained\n  // in the hash map, returns the default_val instead.\n  IdType Map(IdType id, IdType default_val) const {\n    if (filter_[id & kFilterMask]) {\n      auto it = oldv2newv_.find(id);\n      return (it == oldv2newv_.end()) ? default_val : it->second;\n    } else {\n      return default_val;\n    }\n  }\n\n  // Return the new id of each id in the given array.\n  IdArray Map(IdArray ids, IdType default_val) const {\n    const IdType* ids_data = static_cast<IdType*>(ids->data);\n    const int64_t len = ids->shape[0];\n    IdArray values = NewIdArray(len, ids->ctx, ids->dtype.bits);\n    IdType* values_data = static_cast<IdType*>(values->data);\n    for (int64_t i = 0; i < len; ++i)\n      values_data[i] = Map(ids_data[i], default_val);\n    return values;\n  }\n\n  // Return all the old ids collected so far, ordered by new id.\n  IdArray Values() const {\n    IdArray values = NewIdArray(\n        oldv2newv_.size(), DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);\n    IdType* values_data = static_cast<IdType*>(values->data);\n    for (auto pair : oldv2newv_) values_data[pair.second] = pair.first;\n    return values;\n  }\n\n  inline size_t Size() const { return oldv2newv_.size(); }\n\n private:\n  static constexpr int32_t kFilterMask = 0xFFFFFF;\n  static constexpr int32_t kFilterSize = kFilterMask + 1;\n  // This bitmap is used as a bloom filter to remove some lookups.\n  // Hashtable is very slow. Using bloom filter can significantly speed up\n  // lookups.\n  std::vector<bool> filter_;\n  // The hashmap from old vid to new vid\n  tsl::robin_map<IdType, IdType> oldv2newv_;\n};\n\n/**\n * @brief Hash type for building maps/sets with pairs as keys.\n */\nstruct PairHash {\n  template <class T1, class T2>\n  std::size_t operator()(const std::pair<T1, T2>& pair) const {\n    return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);\n  }\n};\n\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CPU_ARRAY_UTILS_H_\n"
  },
  {
    "path": "src/array/cpu/concurrent_id_hash_map.cc",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file array/cpu/concurrent_id_hash_map.cc\n * @brief Class about id hash map\n */\n\n#include \"concurrent_id_hash_map.h\"\n\n#ifdef _MSC_VER\n#include <intrin.h>\n#endif  // _MSC_VER\n\n#include <dgl/array.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <cmath>\n#include <numeric>\n\nusing namespace dgl::runtime;\n\nnamespace {\nstatic constexpr int64_t kEmptyKey = -1;\nstatic constexpr int kGrainSize = 256;\n\n// The formula is established from experience which is used\n// to get the hashmap size from the input array size.\ninline size_t GetMapSize(size_t num) {\n  size_t capacity = 1;\n  return capacity << static_cast<size_t>(1 + std::log2(num * 3));\n}\n}  // namespace\n\nnamespace dgl {\nnamespace aten {\n\ntemplate <typename IdType>\nIdType ConcurrentIdHashMap<IdType>::CompareAndSwap(\n    IdType* ptr, IdType old_val, IdType new_val) {\n#ifdef _MSC_VER\n  if (sizeof(IdType) == 4) {\n    return _InterlockedCompareExchange(\n        reinterpret_cast<LONG*>(ptr), new_val, old_val);\n  } else if (sizeof(IdType) == 8) {\n    return _InterlockedCompareExchange64(\n        reinterpret_cast<LONGLONG*>(ptr), new_val, old_val);\n  } else {\n    LOG(FATAL) << \"ID can only be int32 or int64\";\n  }\n#elif __GNUC__  // _MSC_VER\n  return __sync_val_compare_and_swap(ptr, old_val, new_val);\n#else           // _MSC_VER\n#error \"CompareAndSwap is not supported on this platform.\"\n#endif  // _MSC_VER\n}\n\ntemplate <typename IdType>\nConcurrentIdHashMap<IdType>::ConcurrentIdHashMap() : mask_(0) {\n  // Used to deallocate the memory in hash_map_ with device api\n  // when the pointer is freed.\n  auto deleter = [](Mapping* mappings) {\n    if (mappings != nullptr) {\n      DGLContext ctx = DGLContext{kDGLCPU, 0};\n      auto device = DeviceAPI::Get(ctx);\n      device->FreeWorkspace(ctx, mappings);\n    }\n  };\n  hash_map_ = {nullptr, deleter};\n}\n\ntemplate 
<typename IdType>\nIdArray ConcurrentIdHashMap<IdType>::Init(\n    const IdArray& ids, size_t num_seeds) {\n  CHECK_EQ(ids.defined(), true);\n  const IdType* ids_data = ids.Ptr<IdType>();\n  const size_t num_ids = static_cast<size_t>(ids->shape[0]);\n  // Make sure `ids` is not 0 dim.\n  CHECK_GE(num_seeds, 0);\n  CHECK_GE(num_ids, num_seeds);\n  size_t capacity = GetMapSize(num_ids);\n  mask_ = static_cast<IdType>(capacity - 1);\n\n  auto ctx = DGLContext{kDGLCPU, 0};\n  auto device = DeviceAPI::Get(ctx);\n  hash_map_.reset(static_cast<Mapping*>(\n      device->AllocWorkspace(ctx, sizeof(Mapping) * capacity)));\n  memset(hash_map_.get(), -1, sizeof(Mapping) * capacity);\n\n  // This code block is to fill the ids into hash_map_.\n  IdArray unique_ids = NewIdArray(num_ids, ctx, sizeof(IdType) * 8);\n  IdType* unique_ids_data = unique_ids.Ptr<IdType>();\n  // Fill in the first `num_seeds` ids.\n  parallel_for(0, num_seeds, kGrainSize, [&](int64_t s, int64_t e) {\n    for (int64_t i = s; i < e; i++) {\n      InsertAndSet(ids_data[i], static_cast<IdType>(i));\n    }\n  });\n  // Place the first `num_seeds` ids.\n  device->CopyDataFromTo(\n      ids_data, 0, unique_ids_data, 0, sizeof(IdType) * num_seeds, ctx, ctx,\n      ids->dtype);\n\n  // An auxiliary array indicates whether the corresponding elements\n  // are inserted into hash map or not. Use `int16_t` instead of `bool` as\n  // vector<bool> is unsafe when updating different elements from different\n  // threads. 
See https://en.cppreference.com/w/cpp/container#Thread_safety.\n  std::vector<int16_t> valid(num_ids);\n  auto thread_num = compute_num_threads(0, num_ids, kGrainSize);\n  std::vector<size_t> block_offset(thread_num + 1, 0);\n  // Insert all elements in this loop.\n  parallel_for(num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {\n    size_t count = 0;\n    for (int64_t i = s; i < e; i++) {\n      valid[i] = Insert(ids_data[i]);\n      count += valid[i];\n    }\n    block_offset[omp_get_thread_num() + 1] = count;\n  });\n\n  // Get ExclusiveSum of each block.\n  std::partial_sum(\n      block_offset.begin() + 1, block_offset.end(), block_offset.begin() + 1);\n  unique_ids->shape[0] = num_seeds + block_offset.back();\n\n  // Get unique array from ids and set value for hash map.\n  parallel_for(num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {\n    auto tid = omp_get_thread_num();\n    auto pos = block_offset[tid] + num_seeds;\n    for (int64_t i = s; i < e; i++) {\n      if (valid[i]) {\n        unique_ids_data[pos] = ids_data[i];\n        Set(ids_data[i], pos);\n        pos = pos + 1;\n      }\n    }\n  });\n  return unique_ids;\n}\n\ntemplate <typename IdType>\nIdArray ConcurrentIdHashMap<IdType>::MapIds(const IdArray& ids) const {\n  CHECK_EQ(ids.defined(), true);\n  const IdType* ids_data = ids.Ptr<IdType>();\n  const size_t num_ids = static_cast<size_t>(ids->shape[0]);\n  CHECK_GT(num_ids, 0);\n\n  DGLContext ctx = DGLContext{kDGLCPU, 0};\n  IdArray new_ids = NewIdArray(num_ids, ctx, sizeof(IdType) * 8);\n  IdType* values_data = new_ids.Ptr<IdType>();\n\n  parallel_for(0, num_ids, kGrainSize, [&](int64_t s, int64_t e) {\n    for (int64_t i = s; i < e; i++) {\n      values_data[i] = MapId(ids_data[i]);\n    }\n  });\n  return new_ids;\n}\n\ntemplate <typename IdType>\ninline void ConcurrentIdHashMap<IdType>::Next(\n    IdType* pos, IdType* delta) const {\n  // Use Quadric probing.\n  *pos = (*pos + (*delta) * (*delta)) & mask_;\n  *delta = 
*delta + 1;\n}\n\ntemplate <typename IdType>\ninline IdType ConcurrentIdHashMap<IdType>::MapId(IdType id) const {\n  IdType pos = (id & mask_), delta = 1;\n  IdType empty_key = static_cast<IdType>(kEmptyKey);\n  while (hash_map_[pos].key != empty_key && hash_map_[pos].key != id) {\n    Next(&pos, &delta);\n  }\n  return hash_map_[pos].value;\n}\n\ntemplate <typename IdType>\nbool ConcurrentIdHashMap<IdType>::Insert(IdType id) {\n  IdType pos = (id & mask_), delta = 1;\n  InsertState state = AttemptInsertAt(pos, id);\n  while (state == InsertState::OCCUPIED) {\n    Next(&pos, &delta);\n    state = AttemptInsertAt(pos, id);\n  }\n\n  return state == InsertState::INSERTED;\n}\n\ntemplate <typename IdType>\ninline void ConcurrentIdHashMap<IdType>::Set(IdType key, IdType value) {\n  IdType pos = (key & mask_), delta = 1;\n  while (hash_map_[pos].key != key) {\n    Next(&pos, &delta);\n  }\n\n  hash_map_[pos].value = value;\n}\n\ntemplate <typename IdType>\ninline void ConcurrentIdHashMap<IdType>::InsertAndSet(IdType id, IdType value) {\n  IdType pos = (id & mask_), delta = 1;\n  while (AttemptInsertAt(pos, id) == InsertState::OCCUPIED) {\n    Next(&pos, &delta);\n  }\n\n  hash_map_[pos].value = value;\n}\n\ntemplate <typename IdType>\ninline typename ConcurrentIdHashMap<IdType>::InsertState\nConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {\n  IdType empty_key = static_cast<IdType>(kEmptyKey);\n  IdType old_val = CompareAndSwap(&(hash_map_[pos].key), empty_key, key);\n  if (old_val == empty_key) {\n    return InsertState::INSERTED;\n  } else if (old_val == key) {\n    return InsertState::EXISTED;\n  } else {\n    return InsertState::OCCUPIED;\n  }\n}\n\ntemplate class ConcurrentIdHashMap<int32_t>;\ntemplate class ConcurrentIdHashMap<int64_t>;\n\ntemplate <typename IdType>\nbool BoolCompareAndSwap(IdType* ptr) {\n#ifdef _MSC_VER\n  if (sizeof(IdType) == 4) {\n    return _InterlockedCompareExchange(reinterpret_cast<LONG*>(ptr), 0, -1) ==\n           
-1;\n  } else if (sizeof(IdType) == 8) {\n    return _InterlockedCompareExchange64(\n               reinterpret_cast<LONGLONG*>(ptr), 0, -1) == -1;\n  } else {\n    LOG(FATAL) << \"ID can only be int32 or int64\";\n  }\n#elif __GNUC__  // _MSC_VER\n  return __sync_bool_compare_and_swap(ptr, -1, 0);\n#else           // _MSC_VER\n#error \"CompareAndSwap is not supported on this platform.\"\n#endif  // _MSC_VER\n}\n\ntemplate bool BoolCompareAndSwap<int32_t>(int32_t*);\ntemplate bool BoolCompareAndSwap<int64_t>(int64_t*);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/concurrent_id_hash_map.h",
    "content": "/**\n *  Copyright (c) 2023 by Contributors\n * @file array/cpu/concurrent_id_hash_map.h\n * @brief Class about concurrent id hash map\n */\n\n#ifndef DGL_ARRAY_CPU_CONCURRENT_ID_HASH_MAP_H_\n#define DGL_ARRAY_CPU_CONCURRENT_ID_HASH_MAP_H_\n\n#include <dgl/aten/types.h>\n\n#include <functional>\n#include <memory>\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\n\n/**\n * @brief A CPU targeted hashmap for mapping duplicate and non-consecutive ids\n * in the provided array to unique and consecutive ones. It utilizes\n * multi-threading to accelerate the insert and search speed. Currently it is\n * only designed to be used in `ToBlockCpu` for optimizing, so it only support\n * key insertions once with Init function, and it does not support key deletion.\n *\n * The hash map should be prepared in two phases before using. With the first\n * being creating the hashmap, and then initialize it with an id array which is\n * divided into 2 parts: [`seed ids`, `sampled ids`]. `Seed ids` refer to\n * a set ids chosen as the input for sampling process and `sampled ids` are the\n * ids new sampled from the process (note the the `seed ids` might also be\n * sampled in the process and included in the `sampled ids`). In result `seed\n * ids` are mapped to [0, num_seed_ids) and `sampled ids` to [num_seed_ids,\n * num_unique_ids). 
Notice that mapping order is stable for `seed ids` while not\n * for the `sampled ids`.\n *\n * For example, for an array `A` having 4 seed ids with following entries:\n * [99, 98, 100, 97, 97, 101, 101, 102, 101]\n * Create the hashmap `H` with:\n * `H = ConcurrentIdHashMap()` (1)\n * And Init it with:\n * `U = H.Init(A)` (2)  (U is an id array used to store the unqiue\n * ids in A).\n * Then `U` should be (U is not exclusive as the overall mapping is not stable):\n * [99, 98, 100, 97, 102, 101]\n * And the hashmap should generate following mappings:\n *  * [\n *   {key: 99, value: 0},\n *   {key: 98, value: 1},\n *   {key: 100, value: 2},\n *   {key: 97, value: 3},\n *   {key: 102, value: 4},\n *   {key: 101, value: 5}\n * ]\n * Search the hashmap with array `I`=[98, 99, 102]:\n * R = H.Map(I) (3)\n * R should be:\n * [1, 0, 4]\n **/\ntemplate <typename IdType>\nclass ConcurrentIdHashMap {\n private:\n  /**\n   * @brief The result state of an attempt to insert.\n   */\n  enum class InsertState {\n    OCCUPIED,  // Indicates that the space where an insertion is being\n               // attempted is already occupied by another element.\n    EXISTED,  // Indicates that the element being inserted already exists in the\n              // map, and thus no insertion is performed.\n    INSERTED  // Indicates that the insertion was successful and a new element\n              // was added to the map.\n  };\n\n public:\n  /**\n   * @brief An entry in the hashtable.\n   */\n  struct Mapping {\n    /**\n     * @brief The ID of the item inserted.\n     */\n    IdType key;\n    /**\n     * @brief The value of the item inserted.\n     */\n    IdType value;\n  };\n\n  /**\n   * @brief Cross platform CAS operation.\n   * It is an atomic operation that compares the contents of a memory\n   * location with a given value and, only if they are the same, modifies\n   * the contents of that memory location to a new given value.\n   *\n   * @param ptr The pointer to the object to test and 
modify .\n   * @param old_val The value expected to be found in `ptr`.\n   * @param new_val The value to store in `ptr` if it is as expected.\n   *\n   * @return Old value pointed by the `ptr`.\n   */\n  static IdType CompareAndSwap(IdType* ptr, IdType old_val, IdType new_val);\n\n  ConcurrentIdHashMap();\n\n  ConcurrentIdHashMap(const ConcurrentIdHashMap& other) = delete;\n  ConcurrentIdHashMap& operator=(const ConcurrentIdHashMap& other) = delete;\n\n  /**\n   * @brief Initialize the hashmap with an array of ids. The first `num_seeds`\n   * ids are unique and must be mapped to a contiguous array starting\n   * from 0. The left can be duplicated and the mapping result is not stable.\n   *\n   * @param ids The array of the ids to be inserted.\n   * @param num_seeds The number of seed ids.\n   *\n   * @return Unique ids from the input `ids`.\n   */\n  IdArray Init(const IdArray& ids, size_t num_seeds);\n\n  /**\n   * @brief Find mappings of given keys.\n   *\n   * @param ids The keys to map for.\n   *\n   * @return Mapping results corresponding to `ids`.\n   */\n  IdArray MapIds(const IdArray& ids) const;\n\n private:\n  /**\n   * @brief Get the next position and delta for probing.\n   *\n   * @param[in,out] pos Calculate the next position with quadric probing.\n   * @param[in,out] delta Calculate the next delta by adding 1.\n   */\n  inline void Next(IdType* pos, IdType* delta) const;\n\n  /**\n   * @brief Find the mapping of a given key.\n   *\n   * @param id The key to map for.\n   *\n   * @return Mapping result corresponding to `id`.\n   */\n  inline IdType MapId(const IdType id) const;\n\n  /**\n   * @brief Insert an id into the hash map.\n   *\n   * @param id The id to be inserted.\n   *\n   * @return Whether the `id` is inserted or not.\n   */\n  inline bool Insert(IdType id);\n\n  /**\n   * @brief Set the value for the key in the hash map.\n   *\n   * @param key The key to set for.\n   * @param value The value to be set for the `key`.\n   *\n   * @warning 
Key must exist.\n   */\n  inline void Set(IdType key, IdType value);\n\n  /**\n   * @brief Insert a key into the hash map.\n   *\n   * @param id The key to be inserted.\n   * @param value The value to be set for the `key`.\n   *\n   */\n  inline void InsertAndSet(IdType key, IdType value);\n\n  /**\n   * @brief Attempt to insert the key into the hash map at the given position.\n   *\n   * @param pos The position in the hash map to be inserted at.\n   * @param key The key to be inserted.\n   *\n   * @return The state of the insertion.\n   */\n  inline InsertState AttemptInsertAt(int64_t pos, IdType key);\n\n private:\n  /**\n   * @brief Hash maps which is used to store all elements.\n   */\n  std::unique_ptr<Mapping[], std::function<void(Mapping*)>> hash_map_;\n\n  /**\n   * @brief Mask which is assisted to get the position in the table\n   * for a key by performing `&` operation with it.\n   */\n  IdType mask_;\n};\n\ntemplate <typename IdType>\nbool BoolCompareAndSwap(IdType* ptr);\n\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CPU_CONCURRENT_ID_HASH_MAP_H_\n"
  },
  {
    "path": "src/array/cpu/coo_coalesce.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cpu/coo_coalesce.cc\n * @brief COO coalescing\n */\n\n#include <dgl/array.h>\n\n#include <vector>\n\nnamespace dgl {\n\nnamespace aten {\n\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {\n  const int64_t nnz = coo.row->shape[0];\n  const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);\n  const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);\n\n  if (!coo.row_sorted || !coo.col_sorted) coo = COOSort(coo, true);\n\n  std::vector<IdType> new_row, new_col, count;\n  IdType prev_row = -1, prev_col = -1;\n  for (int64_t i = 0; i < nnz; ++i) {\n    const IdType curr_row = coo_row_data[i];\n    const IdType curr_col = coo_col_data[i];\n    if (curr_row == prev_row && curr_col == prev_col) {\n      ++count[count.size() - 1];\n    } else {\n      new_row.push_back(curr_row);\n      new_col.push_back(curr_col);\n      count.push_back(1);\n      prev_row = curr_row;\n      prev_col = curr_col;\n    }\n  }\n\n  COOMatrix coo_result = COOMatrix{\n      coo.num_rows,\n      coo.num_cols,\n      NDArray::FromVector(new_row),\n      NDArray::FromVector(new_col),\n      NullArray(),\n      true};\n  return std::make_pair(coo_result, NDArray::FromVector(count));\n}\n\ntemplate std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int32_t>(COOMatrix);\ntemplate std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int64_t>(COOMatrix);\n\n};  // namespace impl\n};  // namespace aten\n};  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/coo_linegraph.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/coo_line_graph.cc\n * @brief COO LineGraph\n */\n\n#include <dgl/array.h>\n\n#include <algorithm>\n#include <iterator>\n#include <numeric>\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking) {\n  const int64_t nnz = coo.row->shape[0];\n  IdType* coo_row = coo.row.Ptr<IdType>();\n  IdType* coo_col = coo.col.Ptr<IdType>();\n  IdArray data = COOHasData(coo)\n                     ? coo.data\n                     : Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);\n  IdType* data_data = data.Ptr<IdType>();\n  std::vector<IdType> new_row;\n  std::vector<IdType> new_col;\n\n  for (int64_t i = 0; i < nnz; ++i) {\n    IdType u = coo_row[i];\n    IdType v = coo_col[i];\n    for (int64_t j = 0; j < nnz; ++j) {\n      // no self-loop\n      if (i == j) continue;\n\n      // succ_u == v\n      // if not backtracking succ_u != u\n      if (v == coo_row[j] && (backtracking || u != coo_col[j])) {\n        new_row.push_back(data_data[i]);\n        new_col.push_back(data_data[j]);\n      }\n    }\n  }\n\n  COOMatrix res = COOMatrix(\n      nnz, nnz, NDArray::FromVector(new_row), NDArray::FromVector(new_col),\n      NullArray(), false, false);\n  return res;\n}\n\ntemplate COOMatrix COOLineGraph<kDGLCPU, int32_t>(\n    const COOMatrix& coo, bool backtracking);\ntemplate COOMatrix COOLineGraph<kDGLCPU, int64_t>(\n    const COOMatrix& coo, bool backtracking);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/coo_remove.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/coo_remove.cc\n * @brief COO matrix remove entries CPU implementation\n */\n#include <dgl/array.h>\n\n#include <utility>\n#include <vector>\n\n#include \"array_utils.h\"\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\nnamespace {\n\n/** @brief COORemove implementation for COOMatrix with default consecutive edge\n * IDs */\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid COORemoveConsecutive(\n    COOMatrix coo, IdArray entries, std::vector<IdType> *new_rows,\n    std::vector<IdType> *new_cols, std::vector<IdType> *new_eids) {\n  const int64_t nnz = coo.row->shape[0];\n  const int64_t n_entries = entries->shape[0];\n  const IdType *row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *col_data = static_cast<IdType *>(coo.col->data);\n  const IdType *entry_data = static_cast<IdType *>(entries->data);\n\n  std::vector<IdType> entry_data_sorted(entry_data, entry_data + n_entries);\n  std::sort(entry_data_sorted.begin(), entry_data_sorted.end());\n\n  int64_t j = 0;\n  for (int64_t i = 0; i < nnz; ++i) {\n    if (j < n_entries && entry_data_sorted[j] == i) {\n      // Move on to the next different entry\n      while (j < n_entries && entry_data_sorted[j] == i) ++j;\n      continue;\n    }\n    new_rows->push_back(row_data[i]);\n    new_cols->push_back(col_data[i]);\n    new_eids->push_back(i);\n  }\n}\n\n/** @brief COORemove implementation for COOMatrix with shuffled edge IDs */\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid COORemoveShuffled(\n    COOMatrix coo, IdArray entries, std::vector<IdType> *new_rows,\n    std::vector<IdType> *new_cols, std::vector<IdType> *new_eids) {\n  const int64_t nnz = coo.row->shape[0];\n  const IdType *row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *col_data = static_cast<IdType *>(coo.col->data);\n  const IdType *eid_data = static_cast<IdType *>(coo.data->data);\n\n  
IdHashMap<IdType> eid_map(entries);\n\n  for (int64_t i = 0; i < nnz; ++i) {\n    const IdType eid = eid_data[i];\n    if (eid_map.Contains(eid)) continue;\n    new_rows->push_back(row_data[i]);\n    new_cols->push_back(col_data[i]);\n    new_eids->push_back(eid);\n  }\n}\n\n};  // namespace\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COORemove(COOMatrix coo, IdArray entries) {\n  const int64_t nnz = coo.row->shape[0];\n  const int64_t n_entries = entries->shape[0];\n  if (n_entries == 0) return coo;\n\n  std::vector<IdType> new_rows, new_cols, new_eids;\n  new_rows.reserve(nnz - n_entries);\n  new_cols.reserve(nnz - n_entries);\n  new_eids.reserve(nnz - n_entries);\n\n  if (COOHasData(coo))\n    COORemoveShuffled<XPU, IdType>(\n        coo, entries, &new_rows, &new_cols, &new_eids);\n  else\n    // Removing from COO ordered by eid has more efficient implementation.\n    COORemoveConsecutive<XPU, IdType>(\n        coo, entries, &new_rows, &new_cols, &new_eids);\n\n  return COOMatrix(\n      coo.num_rows, coo.num_cols, IdArray::FromVector(new_rows),\n      IdArray::FromVector(new_cols), IdArray::FromVector(new_eids));\n}\n\ntemplate COOMatrix COORemove<kDGLCPU, int32_t>(COOMatrix coo, IdArray entries);\ntemplate COOMatrix COORemove<kDGLCPU, int64_t>(COOMatrix coo, IdArray entries);\n\n};  // namespace impl\n};  // namespace aten\n};  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/coo_sort.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/coo_sort.cc\n * @brief COO sorting\n */\n#include <dgl/array.h>\n#ifdef PARALLEL_ALGORITHMS\n#include <parallel/algorithm>\n#endif\n#include <algorithm>\n#include <iterator>\n#include <numeric>\n#include <tuple>\n#include <vector>\n\nnamespace {\n\ntemplate <typename IdType>\nstruct TupleRef {\n  TupleRef() = delete;\n  TupleRef(const TupleRef& other) = default;\n  TupleRef(TupleRef&& other) = default;\n  TupleRef(IdType* const r, IdType* const c, IdType* const d)\n      : row(r), col(c), data(d) {}\n\n  TupleRef& operator=(const TupleRef& other) {\n    *row = *other.row;\n    *col = *other.col;\n    *data = *other.data;\n    return *this;\n  }\n  TupleRef& operator=(const std::tuple<IdType, IdType, IdType>& val) {\n    *row = std::get<0>(val);\n    *col = std::get<1>(val);\n    *data = std::get<2>(val);\n    return *this;\n  }\n\n  operator std::tuple<IdType, IdType, IdType>() const {\n    return std::make_tuple(*row, *col, *data);\n  }\n\n  void Swap(const TupleRef& other) const {\n    std::swap(*row, *other.row);\n    std::swap(*col, *other.col);\n    std::swap(*data, *other.data);\n  }\n\n  IdType *row, *col, *data;\n};\n\nusing std::swap;\ntemplate <typename IdType>\nvoid swap(const TupleRef<IdType>& r1, const TupleRef<IdType>& r2) {\n  r1.Swap(r2);\n}\n\ntemplate <typename IdType>\nstruct CooIterator\n    : public std::iterator<\n          std::random_access_iterator_tag, std::tuple<IdType, IdType, IdType>,\n          std::ptrdiff_t, std::tuple<IdType*, IdType*, IdType*>,\n          TupleRef<IdType>> {\n  CooIterator() = default;\n  CooIterator(const CooIterator& other) = default;\n  CooIterator(CooIterator&& other) = default;\n  CooIterator(IdType* r, IdType* c, IdType* d) : row(r), col(c), data(d) {}\n\n  CooIterator& operator=(const CooIterator& other) = default;\n  CooIterator& operator=(CooIterator&& other) = default;\n  ~CooIterator() = default;\n\n  bool operator==(const 
CooIterator& other) const { return row == other.row; }\n\n  bool operator!=(const CooIterator& other) const { return row != other.row; }\n\n  bool operator<(const CooIterator& other) const { return row < other.row; }\n\n  bool operator>(const CooIterator& other) const { return row > other.row; }\n\n  bool operator<=(const CooIterator& other) const { return row <= other.row; }\n\n  bool operator>=(const CooIterator& other) const { return row >= other.row; }\n\n  CooIterator& operator+=(const std::ptrdiff_t& movement) {\n    row += movement;\n    col += movement;\n    data += movement;\n    return *this;\n  }\n\n  CooIterator& operator-=(const std::ptrdiff_t& movement) {\n    row -= movement;\n    col -= movement;\n    data -= movement;\n    return *this;\n  }\n\n  CooIterator& operator++() { return operator+=(1); }\n\n  CooIterator& operator--() { return operator-=(1); }\n\n  CooIterator operator++(int) {\n    CooIterator ret(*this);\n    operator++();\n    return ret;\n  }\n\n  CooIterator operator--(int) {\n    CooIterator ret(*this);\n    operator--();\n    return ret;\n  }\n\n  CooIterator operator+(const std::ptrdiff_t& movement) const {\n    CooIterator ret(*this);\n    ret += movement;\n    return ret;\n  }\n\n  CooIterator operator-(const std::ptrdiff_t& movement) const {\n    CooIterator ret(*this);\n    ret -= movement;\n    return ret;\n  }\n\n  std::ptrdiff_t operator-(const CooIterator& other) const {\n    return row - other.row;\n  }\n\n  TupleRef<IdType> operator*() const {\n    return TupleRef<IdType>(row, col, data);\n  }\n  TupleRef<IdType> operator*() { return TupleRef<IdType>(row, col, data); }\n\n  // required for random access iterators in VS2019\n  TupleRef<IdType> operator[](size_t offset) const {\n    return TupleRef<IdType>(row + offset, col + offset, data + offset);\n  }\n\n  IdType *row, *col, *data;\n};\n\n}  // namespace\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\n///////////////////////////// COOSort_ 
/////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid COOSort_(COOMatrix* coo, bool sort_column) {\n  const int64_t nnz = coo->row->shape[0];\n  IdType* coo_row = coo->row.Ptr<IdType>();\n  IdType* coo_col = coo->col.Ptr<IdType>();\n  if (!COOHasData(*coo))\n    coo->data = aten::Range(0, nnz, coo->row->dtype.bits, coo->row->ctx);\n  IdType* coo_data = coo->data.Ptr<IdType>();\n\n  typedef std::tuple<IdType, IdType, IdType> Tuple;\n\n  // Arg sort\n  if (sort_column) {\n#ifdef PARALLEL_ALGORITHMS\n    __gnu_parallel::sort(\n#else\n    std::sort(\n#endif\n        CooIterator<IdType>(coo_row, coo_col, coo_data),\n        CooIterator<IdType>(coo_row, coo_col, coo_data) + nnz,\n        [](const Tuple& a, const Tuple& b) {\n          return (std::get<0>(a) != std::get<0>(b))\n                     ? (std::get<0>(a) < std::get<0>(b))\n                     : (std::get<1>(a) < std::get<1>(b));\n        });\n  } else {\n#ifdef PARALLEL_ALGORITHMS\n    __gnu_parallel::sort(\n#else\n    std::sort(\n#endif\n        CooIterator<IdType>(coo_row, coo_col, coo_data),\n        CooIterator<IdType>(coo_row, coo_col, coo_data) + nnz,\n        [](const Tuple& a, const Tuple& b) {\n          return std::get<0>(a) < std::get<0>(b);\n        });\n  }\n\n  coo->row_sorted = true;\n  coo->col_sorted = sort_column;\n}\n\ntemplate void COOSort_<kDGLCPU, int32_t>(COOMatrix*, bool);\ntemplate void COOSort_<kDGLCPU, int64_t>(COOMatrix*, bool);\n\n///////////////////////////// COOIsSorted /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<bool, bool> COOIsSorted(COOMatrix coo) {\n  const int64_t nnz = coo.row->shape[0];\n  IdType* row = coo.row.Ptr<IdType>();\n  IdType* col = coo.col.Ptr<IdType>();\n  bool row_sorted = true;\n  bool col_sorted = true;\n  for (int64_t i = 1; row_sorted && i < nnz; ++i) {\n    row_sorted = (row[i - 1] <= row[i]);\n    col_sorted = col_sorted && (row[i - 1] < row[i] || col[i - 1] <= col[i]);\n  }\n 
 if (!row_sorted) col_sorted = false;\n  return {row_sorted, col_sorted};\n}\n\ntemplate std::pair<bool, bool> COOIsSorted<kDGLCPU, int32_t>(COOMatrix coo);\ntemplate std::pair<bool, bool> COOIsSorted<kDGLCPU, int64_t>(COOMatrix coo);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/csr_get_data.cc",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file array/cpu/csr_get_data.cc\n * @brief Retrieve entries of a CSR matrix\n */\n#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <numeric>\n#include <unordered_set>\n#include <vector>\n\n#include \"array_utils.h\"\n\nnamespace dgl {\n\nusing runtime::NDArray;\nusing runtime::parallel_for;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid CollectDataFromSorted(\n    const IdType* indices_data, const IdType* data, const IdType start,\n    const IdType end, const IdType col, std::vector<IdType>* ret_vec) {\n  const IdType* start_ptr = indices_data + start;\n  const IdType* end_ptr = indices_data + end;\n  auto it = std::lower_bound(start_ptr, end_ptr, col);\n  // This might be a multi-graph. We need to collect all of the matched\n  // columns.\n  for (; it != end_ptr; it++) {\n    // If the col exist\n    if (*it == col) {\n      IdType idx = it - indices_data;\n      ret_vec->push_back(data ? data[idx] : idx);\n    } else {\n      // If we find a column that is different, we can stop searching now.\n      break;\n    }\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType, typename DType>\nNDArray CSRGetData(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, DType filler) {\n  const int64_t rowlen = rows->shape[0];\n  const int64_t collen = cols->shape[0];\n\n  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))\n      << \"Invalid row and col id array.\";\n\n  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;\n  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 
0 : 1;\n  const IdType* row_data = static_cast<IdType*>(rows->data);\n  const IdType* col_data = static_cast<IdType*>(cols->data);\n\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);\n  const IdType* data =\n      CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;\n\n  const int64_t retlen = std::max(rowlen, collen);\n  const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>();\n  if (return_eids)\n    BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype)\n        << \"DType does not match row's dtype.\";\n\n  NDArray ret = Full(filler, retlen, rows->ctx);\n  DType* ret_data = ret.Ptr<DType>();\n\n  // NOTE: In most cases, the input csr is already sorted. If not, we might need\n  // to\n  //   consider sorting it especially when the number of (row, col) pairs is\n  //   large. Need more benchmarks to justify the choice.\n\n  if (csr.sorted) {\n    // use binary search on each row\n    parallel_for(0, retlen, [&](size_t b, size_t e) {\n      for (auto p = b; p < e; ++p) {\n        const IdType row_id = row_data[p * row_stride],\n                     col_id = col_data[p * col_stride];\n        CHECK(row_id >= 0 && row_id < csr.num_rows)\n            << \"Invalid row index: \" << row_id;\n        CHECK(col_id >= 0 && col_id < csr.num_cols)\n            << \"Invalid col index: \" << col_id;\n        const IdType* start_ptr = indices_data + indptr_data[row_id];\n        const IdType* end_ptr = indices_data + indptr_data[row_id + 1];\n        auto it = std::lower_bound(start_ptr, end_ptr, col_id);\n        if (it != end_ptr && *it == col_id) {\n          const IdType idx = it - indices_data;\n          IdType eid = data ? data[idx] : idx;\n          ret_data[p] = return_eids ? 
eid : weight_data[eid];\n        }\n      }\n    });\n  } else {\n    // linear search on each row\n    parallel_for(0, retlen, [&](size_t b, size_t e) {\n      for (auto p = b; p < e; ++p) {\n        const IdType row_id = row_data[p * row_stride],\n                     col_id = col_data[p * col_stride];\n        CHECK(row_id >= 0 && row_id < csr.num_rows)\n            << \"Invalid row index: \" << row_id;\n        CHECK(col_id >= 0 && col_id < csr.num_cols)\n            << \"Invalid col index: \" << col_id;\n        for (IdType idx = indptr_data[row_id]; idx < indptr_data[row_id + 1];\n             ++idx) {\n          if (indices_data[idx] == col_id) {\n            IdType eid = data ? data[idx] : idx;\n            ret_data[p] = return_eids ? eid : weight_data[eid];\n            break;\n          }\n        }\n      }\n    });\n  }\n  return ret;\n}\n\ntemplate NDArray CSRGetData<kDGLCPU, int32_t, float>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, float filler);\ntemplate NDArray CSRGetData<kDGLCPU, int64_t, float>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, float filler);\ntemplate NDArray CSRGetData<kDGLCPU, int32_t, double>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, double filler);\ntemplate NDArray CSRGetData<kDGLCPU, int64_t, double>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, double filler);\n\n// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)\ntemplate NDArray CSRGetData<kDGLCPU, int32_t, int32_t>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, int32_t filler);\ntemplate NDArray CSRGetData<kDGLCPU, int64_t, int64_t>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, int64_t filler);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/csr_mm.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cpu/csr_mm.cc\n * @brief CSR Matrix Multiplication\n */\n\n#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n#include <tsl/robin_map.h>\n#include <tsl/robin_set.h>\n\n#include <vector>\n\n#include \"array_utils.h\"\n\nnamespace dgl {\n\nusing dgl::runtime::NDArray;\nusing dgl::runtime::parallel_for;\n\nnamespace aten {\n\nnamespace {\n\n// TODO(BarclayII): avoid using map for sorted CSRs\ntemplate <typename IdType>\nvoid CountNNZPerRow(\n    const IdType* A_indptr, const IdType* A_indices, const IdType* B_indptr,\n    const IdType* B_indices, IdType* C_indptr_data, int64_t M) {\n  parallel_for(0, M, [=](size_t b, size_t e) {\n    for (auto i = b; i < e; ++i) {\n      tsl::robin_set<IdType> set;\n      for (IdType u = A_indptr[i]; u < A_indptr[i + 1]; ++u) {\n        IdType w = A_indices[u];\n        for (IdType v = B_indptr[w]; v < B_indptr[w + 1]; ++v)\n          set.insert(B_indices[v]);\n      }\n      C_indptr_data[i] = set.size();\n    }\n  });\n}\n\ntemplate <typename IdType>\nint64_t ComputeIndptrInPlace(IdType* C_indptr_data, int64_t M) {\n  int64_t nnz = 0;\n  IdType len = 0;\n  for (IdType i = 0; i < M; ++i) {\n    len = C_indptr_data[i];\n    C_indptr_data[i] = nnz;\n    nnz += len;\n  }\n  C_indptr_data[M] = nnz;\n  return nnz;\n}\n\ntemplate <typename IdType, typename DType>\nvoid ComputeIndicesAndData(\n    const IdType* A_indptr, const IdType* A_indices, const IdType* A_eids,\n    const DType* A_data, const IdType* B_indptr, const IdType* B_indices,\n    const IdType* B_eids, const DType* B_data, const IdType* C_indptr_data,\n    IdType* C_indices_data, DType* C_weights_data, int64_t M) {\n  parallel_for(0, M, [=](size_t b, size_t e) {\n    for (auto i = b; i < e; ++i) {\n      tsl::robin_map<IdType, DType> map;\n      for (IdType u = A_indptr[i]; u < A_indptr[i + 1]; ++u) {\n        IdType w = A_indices[u];\n        DType vA = A_data[A_eids ? 
A_eids[u] : u];\n        for (IdType v = B_indptr[w]; v < B_indptr[w + 1]; ++v) {\n          IdType t = B_indices[v];\n          DType vB = B_data[B_eids ? B_eids[v] : v];\n          map[t] += vA * vB;\n        }\n      }\n\n      IdType v = C_indptr_data[i];\n      for (auto it : map) {\n        C_indices_data[v] = it.first;\n        C_weights_data[v] = it.second;\n        ++v;\n      }\n    }\n  });\n}\n\n};  // namespace\n\ntemplate <int XPU, typename IdType, typename DType>\nstd::pair<CSRMatrix, NDArray> CSRMM(\n    const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,\n    NDArray B_weights) {\n  CHECK_EQ(A.num_cols, B.num_rows)\n      << \"A's number of columns must equal to B's number of rows\";\n  const bool A_has_eid = !IsNullArray(A.data);\n  const bool B_has_eid = !IsNullArray(B.data);\n  const IdType* A_indptr = A.indptr.Ptr<IdType>();\n  const IdType* A_indices = A.indices.Ptr<IdType>();\n  const IdType* A_eids = A_has_eid ? A.data.Ptr<IdType>() : nullptr;\n  const IdType* B_indptr = B.indptr.Ptr<IdType>();\n  const IdType* B_indices = B.indices.Ptr<IdType>();\n  const IdType* B_eids = B_has_eid ? 
B.data.Ptr<IdType>() : nullptr;\n  const DType* A_data = A_weights.Ptr<DType>();\n  const DType* B_data = B_weights.Ptr<DType>();\n  const int64_t M = A.num_rows;\n  const int64_t P = B.num_cols;\n\n  IdArray C_indptr = IdArray::Empty({M + 1}, A.indptr->dtype, A.indptr->ctx);\n  IdType* C_indptr_data = C_indptr.Ptr<IdType>();\n\n  CountNNZPerRow<IdType>(\n      A_indptr, A_indices, B_indptr, B_indices, C_indptr_data, M);\n  int64_t nnz = ComputeIndptrInPlace<IdType>(C_indptr_data, M);\n  // Allocate indices and weights array\n  IdArray C_indices = IdArray::Empty({nnz}, A.indices->dtype, A.indices->ctx);\n  NDArray C_weights = NDArray::Empty({nnz}, A_weights->dtype, A_weights->ctx);\n  IdType* C_indices_data = C_indices.Ptr<IdType>();\n  DType* C_weights_data = C_weights.Ptr<DType>();\n\n  ComputeIndicesAndData<IdType, DType>(\n      A_indptr, A_indices, A_eids, A_data, B_indptr, B_indices, B_eids, B_data,\n      C_indptr_data, C_indices_data, C_weights_data, M);\n\n  return {\n      CSRMatrix(\n          M, P, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),\n      C_weights};\n}\n\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, float>(\n    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, float>(\n    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, double>(\n    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, double>(\n    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\n\n};  // namespace aten\n};  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/csr_remove.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/csr_remove.cc\n * @brief CSR matrix remove entries CPU implementation\n */\n#include <dgl/array.h>\n\n#include <algorithm>\n#include <utility>\n#include <vector>\n\n#include \"array_utils.h\"\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\nnamespace {\n\n/** @brief CSRRemove implementation for CSRMatrix with default consecutive edge\n * IDs, i.e. the edge ID of an entry is its position in the indices array */\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid CSRRemoveConsecutive(\n    CSRMatrix csr, IdArray entries, std::vector<IdType> *new_indptr,\n    std::vector<IdType> *new_indices, std::vector<IdType> *new_eids) {\n  CHECK_SAME_DTYPE(csr.indices, entries);\n  const int64_t n_entries = entries->shape[0];\n  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);\n  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);\n  const IdType *entry_data = static_cast<IdType *>(entries->data);\n\n  // Sort a copy of the IDs to remove so one forward-moving cursor can match\n  // them against the ascending positions.\n  std::vector<IdType> entry_data_sorted(entry_data, entry_data + n_entries);\n  std::sort(entry_data_sorted.begin(), entry_data_sorted.end());\n\n  int64_t k = 0;  // cursor into entry_data_sorted\n  new_indptr->push_back(0);\n  for (int64_t i = 0; i < csr.num_rows; ++i) {\n    for (IdType j = indptr_data[i]; j < indptr_data[i + 1]; ++j) {\n      if (k < n_entries && entry_data_sorted[k] == j) {\n        // Move on to the next different entry\n        while (k < n_entries && entry_data_sorted[k] == j) ++k;\n        continue;\n      }\n      new_indices->push_back(indices_data[j]);\n      // A kept entry retains its original edge ID, which is its position j\n      // (not the removal cursor k).\n      new_eids->push_back(j);\n    }\n    new_indptr->push_back(new_indices->size());\n  }\n}\n\n/** @brief CSRRemove implementation for CSRMatrix with shuffled edge IDs */\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid CSRRemoveShuffled(\n    CSRMatrix csr, IdArray entries, std::vector<IdType> *new_indptr,\n    std::vector<IdType> *new_indices, std::vector<IdType> *new_eids) {\n  CHECK_SAME_DTYPE(csr.indices, entries);\n  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);\n  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);\n  const IdType *eid_data = static_cast<IdType *>(csr.data->data);\n\n  IdHashMap<IdType> eid_map(entries);\n\n  new_indptr->push_back(0);\n  for (int64_t i = 0; i < csr.num_rows; ++i) {\n    for (IdType j = indptr_data[i]; j < indptr_data[i + 1]; ++j) {\n      const IdType eid = eid_data ? eid_data[j] : j;\n      if (eid_map.Contains(eid)) continue;\n      new_indices->push_back(indices_data[j]);\n      new_eids->push_back(eid);\n    }\n    new_indptr->push_back(new_indices->size());\n  }\n}\n\n};  // namespace\n\n/**\n * @brief Remove the entries whose edge IDs are listed in entries.\n * @param csr The input CSR matrix.\n * @param entries Edge IDs of the entries to remove; must have the same dtype\n *        as csr.indices.\n * @return csr itself if entries is empty; otherwise a new CSRMatrix holding\n *         the rebuilt indptr and the remaining indices and edge IDs.\n */\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {\n  CHECK_SAME_DTYPE(csr.indices, entries);\n  const int64_t nnz = csr.indices->shape[0];\n  const int64_t n_entries = entries->shape[0];\n  if (n_entries == 0) return csr;\n\n  // entries may hold duplicates, so nnz - n_entries is only a lower-bound\n  // estimate and can be negative; clamp it before using it as a capacity hint.\n  const int64_t n_kept = std::max<int64_t>(nnz - n_entries, 0);\n  std::vector<IdType> new_indptr, new_indices, new_eids;\n  // indptr always receives one leading 0 plus one offset per row.\n  new_indptr.reserve(csr.num_rows + 1);\n  new_indices.reserve(n_kept);\n  new_eids.reserve(n_kept);\n\n  if (CSRHasData(csr))\n    CSRRemoveShuffled<XPU, IdType>(\n        csr, entries, &new_indptr, &new_indices, &new_eids);\n  else\n    // Removing from CSR ordered by eid has more efficient implementation\n    CSRRemoveConsecutive<XPU, IdType>(\n        csr, entries, &new_indptr, &new_indices, &new_eids);\n\n  return CSRMatrix(\n      csr.num_rows, csr.num_cols, IdArray::FromVector(new_indptr),\n      IdArray::FromVector(new_indices), IdArray::FromVector(new_eids));\n}\n\ntemplate CSRMatrix CSRRemove<kDGLCPU, int32_t>(CSRMatrix csr, IdArray entries);\ntemplate CSRMatrix CSRRemove<kDGLCPU, int64_t>(CSRMatrix csr, IdArray entries);\n\n};  // namespace impl\n};  // namespace aten\n};  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/csr_sort.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/csr_sort.cc\n * @brief CSR sorting\n */\n#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <algorithm>\n#include <numeric>\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\n///////////////////////////// CSRIsSorted /////////////////////////////\ntemplate <DGLDeviceType XPU, typename IdType>\nbool CSRIsSorted(CSRMatrix csr) {\n  const IdType *indptr = csr.indptr.Ptr<IdType>();\n  const IdType *indices = csr.indices.Ptr<IdType>();\n  return runtime::parallel_reduce(\n      0, csr.num_rows, 1, 1,\n      [indptr, indices](size_t b, size_t e, bool ident) {\n        for (size_t row = b; row < e; ++row) {\n          for (IdType i = indptr[row] + 1; i < indptr[row + 1]; ++i) {\n            if (indices[i - 1] > indices[i]) return false;\n          }\n        }\n        return ident;\n      },\n      [](bool a, bool b) { return a && b; });\n}\n\ntemplate bool CSRIsSorted<kDGLCPU, int64_t>(CSRMatrix csr);\ntemplate bool CSRIsSorted<kDGLCPU, int32_t>(CSRMatrix csr);\n\n///////////////////////////// CSRSort /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid CSRSort_(CSRMatrix *csr) {\n  typedef std::pair<IdType, IdType> ShufflePair;\n  const int64_t num_rows = csr->num_rows;\n  const int64_t nnz = csr->indices->shape[0];\n  const IdType *indptr_data = static_cast<IdType *>(csr->indptr->data);\n  IdType *indices_data = static_cast<IdType *>(csr->indices->data);\n\n  if (CSRIsSorted(*csr)) {\n    csr->sorted = true;\n    return;\n  }\n\n  if (!CSRHasData(*csr)) {\n    csr->data = aten::Range(0, nnz, csr->indptr->dtype.bits, csr->indptr->ctx);\n  }\n  IdType *eid_data = static_cast<IdType *>(csr->data->data);\n\n  runtime::parallel_for(0, num_rows, [=](size_t b, size_t e) {\n    for (auto row = b; row < e; ++row) {\n      const int64_t num_cols = indptr_data[row + 1] - indptr_data[row];\n      std::vector<ShufflePair> 
reorder_vec(num_cols);\n      IdType *col = indices_data + indptr_data[row];\n      IdType *eid = eid_data + indptr_data[row];\n\n      for (int64_t i = 0; i < num_cols; i++) {\n        reorder_vec[i].first = col[i];\n        reorder_vec[i].second = eid[i];\n      }\n      std::sort(\n          reorder_vec.begin(), reorder_vec.end(),\n          [](const ShufflePair &e1, const ShufflePair &e2) {\n            return e1.first < e2.first;\n          });\n      for (int64_t i = 0; i < num_cols; i++) {\n        col[i] = reorder_vec[i].first;\n        eid[i] = reorder_vec[i].second;\n      }\n    }\n  });\n\n  csr->sorted = true;\n}\n\ntemplate void CSRSort_<kDGLCPU, int64_t>(CSRMatrix *csr);\ntemplate void CSRSort_<kDGLCPU, int32_t>(CSRMatrix *csr);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename TagType>\nstd::pair<CSRMatrix, NDArray> CSRSortByTag(\n    const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) {\n  const auto indptr_data = static_cast<const IdType *>(csr.indptr->data);\n  const auto indices_data = static_cast<const IdType *>(csr.indices->data);\n  const auto eid_data = aten::CSRHasData(csr)\n                            ? 
static_cast<const IdType *>(csr.data->data)\n                            : nullptr;\n  const auto tag_data = static_cast<const TagType *>(tag_array->data);\n  const int64_t num_rows = csr.num_rows;\n\n  NDArray tag_pos = NDArray::Empty(\n      {csr.num_rows, num_tags + 1}, csr.indptr->dtype, csr.indptr->ctx);\n  auto tag_pos_data = static_cast<IdType *>(tag_pos->data);\n  std::fill(tag_pos_data, tag_pos_data + csr.num_rows * (num_tags + 1), 0);\n\n  aten::CSRMatrix output(\n      csr.num_rows, csr.num_cols, csr.indptr.Clone(), csr.indices.Clone(),\n      NDArray::Empty(\n          {csr.indices->shape[0]}, csr.indices->dtype, csr.indices->ctx),\n      csr.sorted);\n\n  auto out_indices_data = static_cast<IdType *>(output.indices->data);\n  auto out_eid_data = static_cast<IdType *>(output.data->data);\n\n  runtime::parallel_for(0, num_rows, [&](size_t b, size_t e) {\n    for (auto src = b; src < e; ++src) {\n      const IdType start = indptr_data[src];\n      const IdType end = indptr_data[src + 1];\n\n      auto tag_pos_row = tag_pos_data + src * (num_tags + 1);\n      std::vector<IdType> pointer(num_tags, 0);\n\n      for (IdType ptr = start; ptr < end; ++ptr) {\n        const IdType eid = eid_data ? eid_data[ptr] : ptr;\n        const TagType tag = tag_data[eid];\n        CHECK_LT(tag, num_tags);\n        ++tag_pos_row[tag + 1];\n      }  // count\n\n      for (TagType tag = 1; tag <= num_tags; ++tag) {\n        tag_pos_row[tag] += tag_pos_row[tag - 1];\n      }  // cumulate\n\n      for (IdType ptr = start; ptr < end; ++ptr) {\n        const IdType dst = indices_data[ptr];\n        const IdType eid = eid_data ? 
eid_data[ptr] : ptr;\n        const TagType tag = tag_data[eid];\n        const IdType offset = tag_pos_row[tag] + pointer[tag];\n        CHECK_LT(offset, tag_pos_row[tag + 1]);\n        ++pointer[tag];\n\n        out_indices_data[start + offset] = dst;\n        out_eid_data[start + offset] = eid;\n      }\n    }\n  });\n  output.sorted = false;\n  return std::make_pair(output, tag_pos);\n}\n\ntemplate std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int64_t>(\n    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);\ntemplate std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int32_t>(\n    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);\ntemplate std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int64_t>(\n    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);\ntemplate std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int32_t>(\n    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/csr_sum.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cpu/csr_sum.cc\n * @brief CSR Summation\n */\n\n#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n#include <tsl/robin_map.h>\n#include <tsl/robin_set.h>\n\n#include <vector>\n\n#include \"array_utils.h\"\n\nnamespace dgl {\n\nusing dgl::runtime::NDArray;\n\nnamespace aten {\n\nnamespace {\n\n// TODO(BarclayII): avoid using map for sorted CSRs\ntemplate <typename IdType>\nvoid CountNNZPerRow(\n    const std::vector<const IdType*>& A_indptr,\n    const std::vector<const IdType*>& A_indices, IdType* C_indptr_data,\n    int64_t M) {\n  int64_t n = A_indptr.size();\n\n  runtime::parallel_for(0, M, [=](size_t b, size_t e) {\n    for (size_t i = b; i < e; ++i) {\n      tsl::robin_set<IdType> set;\n      for (int64_t k = 0; k < n; ++k) {\n        for (IdType u = A_indptr[k][i]; u < A_indptr[k][i + 1]; ++u)\n          set.insert(A_indices[k][u]);\n      }\n      C_indptr_data[i] = set.size();\n    }\n  });\n}\n\ntemplate <typename IdType>\nint64_t ComputeIndptrInPlace(IdType* C_indptr_data, int64_t M) {\n  int64_t nnz = 0;\n  IdType len = 0;\n  for (IdType i = 0; i < M; ++i) {\n    len = C_indptr_data[i];\n    C_indptr_data[i] = nnz;\n    nnz += len;\n  }\n  C_indptr_data[M] = nnz;\n  return nnz;\n}\n\ntemplate <typename IdType, typename DType>\nvoid ComputeIndicesAndData(\n    const std::vector<const IdType*>& A_indptr,\n    const std::vector<const IdType*>& A_indices,\n    const std::vector<const IdType*>& A_eids,\n    const std::vector<const DType*>& A_data, const IdType* C_indptr_data,\n    IdType* C_indices_data, DType* C_weights_data, int64_t M) {\n  int64_t n = A_indptr.size();\n  runtime::parallel_for(0, M, [=](size_t b, size_t e) {\n    for (auto i = b; i < e; ++i) {\n      tsl::robin_map<IdType, DType> map;\n      for (int64_t k = 0; k < n; ++k) {\n        for (IdType u = A_indptr[k][i]; u < A_indptr[k][i + 1]; ++u) {\n          IdType kA = A_indices[k][u];\n          DType vA 
= A_data[k][A_eids[k] ? A_eids[k][u] : u];\n          map[kA] += vA;\n        }\n      }\n      IdType j = C_indptr_data[i];\n      for (auto it : map) {\n        C_indices_data[j] = it.first;\n        C_weights_data[j] = it.second;\n        ++j;\n      }\n    }\n  });\n}\n\n};  // namespace\n\ntemplate <int XPU, typename IdType, typename DType>\nstd::pair<CSRMatrix, NDArray> CSRSum(\n    const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights) {\n  CHECK(A.size() > 0) << \"List of matrices can't be empty.\";\n  CHECK_EQ(A.size(), A_weights.size())\n      << \"List of matrices and weights must have same length\";\n  const int64_t M = A[0].num_rows;\n  const int64_t N = A[0].num_cols;\n  const int64_t n = A.size();\n\n  std::vector<bool> A_has_eid(n);\n  std::vector<const IdType*> A_indptr(n);\n  std::vector<const IdType*> A_indices(n);\n  std::vector<const IdType*> A_eids(n);\n  std::vector<const DType*> A_data(n);\n\n  for (int64_t i = 0; i < n; ++i) {\n    const CSRMatrix& csr = A[i];\n    const NDArray& data = A_weights[i];\n    A_has_eid[i] = !IsNullArray(csr.data);\n    A_indptr[i] = csr.indptr.Ptr<IdType>();\n    A_indices[i] = csr.indices.Ptr<IdType>();\n    A_eids[i] = A_has_eid[i] ? 
csr.data.Ptr<IdType>() : nullptr;\n    A_data[i] = data.Ptr<DType>();\n  }\n\n  IdArray C_indptr =\n      IdArray::Empty({M + 1}, A[0].indptr->dtype, A[0].indptr->ctx);\n  IdType* C_indptr_data = C_indptr.Ptr<IdType>();\n\n  CountNNZPerRow<IdType>(A_indptr, A_indices, C_indptr_data, M);\n  IdType nnz = ComputeIndptrInPlace<IdType>(C_indptr_data, M);\n  // Allocate indices and weights array\n  IdArray C_indices =\n      IdArray::Empty({nnz}, A[0].indices->dtype, A[0].indices->ctx);\n  NDArray C_weights =\n      NDArray::Empty({nnz}, A_weights[0]->dtype, A_weights[0]->ctx);\n  IdType* C_indices_data = C_indices.Ptr<IdType>();\n  DType* C_weights_data = C_weights.Ptr<DType>();\n  ComputeIndicesAndData<IdType, DType>(\n      A_indptr, A_indices, A_eids, A_data, C_indptr_data, C_indices_data,\n      C_weights_data, M);\n\n  return {\n      CSRMatrix(\n          M, N, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),\n      C_weights};\n}\n\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int32_t, float>(\n    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int64_t, float>(\n    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int32_t, double>(\n    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int64_t, double>(\n    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);\n\n};  // namespace aten\n};  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/csr_to_simple.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/csr_to_simple.cc\n * @brief Convert a CSR matrix to a simple one by merging duplicate entries\n */\n#include <dgl/array.h>\n\n#include <algorithm>\n#include <numeric>\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {\n  if (!csr.sorted) csr = CSRSort(csr);\n\n  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);\n  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);\n\n  std::vector<IdType> indptr;\n  std::vector<IdType> indices;\n  std::vector<IdType> count;\n  indptr.resize(csr.indptr->shape[0]);\n  indptr[0] = 0;\n\n  for (int64_t i = 1; i < csr.indptr->shape[0]; ++i) {\n    if (indptr_data[i - 1] == indptr_data[i]) {\n      indptr[i] = indptr[i - 1];\n      continue;\n    }\n\n    int64_t cnt = 1;\n    int64_t dup_cnt = 1;\n    indices.push_back(indices_data[indptr_data[i - 1]]);\n    for (int64_t j = indptr_data[i - 1] + 1; j < indptr_data[i]; ++j) {\n      if (indices_data[j - 1] == indices_data[j]) {\n        ++dup_cnt;\n        continue;\n      }\n      count.push_back(dup_cnt);\n      dup_cnt = 1;\n      indices.push_back(indices_data[j]);\n      ++cnt;\n    }\n    count.push_back(dup_cnt);\n    indptr[i] = indptr[i - 1] + cnt;\n  }\n\n  CSRMatrix res_csr = CSRMatrix(\n      csr.num_rows, csr.num_cols, IdArray::FromVector(indptr),\n      IdArray::FromVector(indices), NullArray(), true);\n\n  const IdArray &edge_count = IdArray::FromVector(count);\n  const IdArray new_eids =\n      Range(0, res_csr.indices->shape[0], sizeof(IdType) * 8, csr.indptr->ctx);\n  const IdArray eids_remapped =\n      CSRHasData(csr) ? Scatter(Repeat(new_eids, edge_count), csr.data)\n                      : Repeat(new_eids, edge_count);\n\n  return std::make_tuple(res_csr, edge_count, eids_remapped);\n}\n\ntemplate std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int32_t>(\n    CSRMatrix);\ntemplate std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int64_t>(\n    CSRMatrix);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/csr_union.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/csr_union.cc\n * @brief CSR union\n */\n#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <algorithm>\n#include <iterator>\n#include <numeric>\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix UnionCsr(const std::vector<CSRMatrix> &csrs) {\n  std::vector<IdType> res_indptr;\n  std::vector<IdType> res_indices;\n  std::vector<IdType> res_data;\n\n  // some preprocess\n  // we assume the number of csrs is not large in common cases\n  std::vector<IdArray> data;\n  std::vector<IdType *> data_data;\n  std::vector<IdType *> indptr_data;\n  std::vector<IdType *> indices_data;\n  int64_t num_edges = 0;\n  bool sorted = true;\n  for (size_t i = 0; i < csrs.size(); ++i) {\n    //  eids of csrs[0] remains unchanged\n    //  eids of csrs[1] will be increased by number of edges of csrs[0], etc.\n    data.push_back(\n        CSRHasData(csrs[i])\n            ? 
csrs[i].data + num_edges\n            : Range(\n                  num_edges, num_edges + csrs[i].indices->shape[0],\n                  csrs[i].indptr->dtype.bits, csrs[i].indptr->ctx));\n    data_data.push_back(data[i].Ptr<IdType>());\n    indptr_data.push_back(csrs[i].indptr.Ptr<IdType>());\n    indices_data.push_back(csrs[i].indices.Ptr<IdType>());\n    num_edges += csrs[i].indices->shape[0];\n    sorted &= csrs[i].sorted;\n  }\n\n  res_indptr.resize(csrs[0].num_rows + 1);\n  res_indices.resize(num_edges);\n  res_data.resize(num_edges);\n  res_indptr[0] = 0;\n\n  if (sorted) {  // all csrs are sorted\n#pragma omp for\n    for (int64_t i = 1; i <= csrs[0].num_rows; ++i) {\n      std::vector<int64_t> indices_off;\n      res_indptr[i] = indptr_data[0][i];\n\n      indices_off.push_back(indptr_data[0][i - 1]);\n      for (size_t j = 1; j < csrs.size(); ++j) {\n        res_indptr[i] += indptr_data[j][i];\n        indices_off.push_back(indptr_data[j][i - 1]);\n      }\n\n      IdType off = res_indptr[i - 1];\n      while (off < res_indptr[i]) {\n        IdType min = csrs[0].num_cols + 1;\n        int64_t min_idx = -1;\n        for (size_t j = 0; j < csrs.size(); ++j) {\n          if (indices_off[j] < indptr_data[j][i]) {\n            if (min <= indices_data[j][indices_off[j]]) {\n              continue;\n            } else {\n              min = indices_data[j][indices_off[j]];\n              min_idx = j;\n            }\n          }  // for check out of bound\n        }    // for\n        res_indices[off] = min;\n        res_data[off] = data_data[min_idx][indices_off[min_idx]];\n        indices_off[min_idx] += 1;\n        ++off;\n      }     // while\n    }       // omp for\n  } else {  // some csrs are not sorted\n#pragma omp for\n    for (int64_t i = 1; i <= csrs[0].num_rows; ++i) {\n      IdType off = res_indptr[i - 1];\n      res_indptr[i] = 0;\n\n      for (size_t j = 0; j < csrs.size(); ++j) {\n        std::memcpy(\n            &res_indices[off], 
&indices_data[j][indptr_data[j][i - 1]],\n            sizeof(IdType) * (indptr_data[j][i] - indptr_data[j][i - 1]));\n        std::memcpy(\n            &res_data[off], &data_data[j][indptr_data[j][i - 1]],\n            sizeof(IdType) * (indptr_data[j][i] - indptr_data[j][i - 1]));\n        off += indptr_data[j][i] - indptr_data[j][i - 1];\n      }\n      res_indptr[i] = off;\n    }  // omp for\n  }\n\n  return CSRMatrix(\n      csrs[0].num_rows, csrs[0].num_cols, IdArray::FromVector(res_indptr),\n      IdArray::FromVector(res_indices), IdArray::FromVector(res_data), sorted);\n}\n\ntemplate CSRMatrix UnionCsr<kDGLCPU, int64_t>(const std::vector<CSRMatrix> &);\ntemplate CSRMatrix UnionCsr<kDGLCPU, int32_t>(const std::vector<CSRMatrix> &);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/disjoint_union.cc",
    "content": "/**\n *   Copyright (c) 2022, NVIDIA CORPORATION.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file array/cpu/disjoint_union.cc\n * @brief Disjoint union CPU implementation.\n */\n\n#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <tuple>\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(\n    const std::vector<COOMatrix>& coos) {\n  IdArray prefix_src_arr =\n      NewIdArray(coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits);\n  IdArray prefix_dst_arr =\n      NewIdArray(coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits);\n  IdArray prefix_elm_arr =\n      NewIdArray(coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits);\n\n  auto prefix_src = prefix_src_arr.Ptr<IdType>();\n  auto prefix_dst = prefix_dst_arr.Ptr<IdType>();\n  auto prefix_elm = prefix_elm_arr.Ptr<IdType>();\n\n  dgl::runtime::parallel_for(0, coos.size(), [&](IdType b, IdType e) {\n    for (IdType i = b; i < e; ++i) {\n      prefix_src[i] = coos[i].num_rows;\n      prefix_dst[i] = coos[i].num_cols;\n      prefix_elm[i] = coos[i].row->shape[0];\n    }\n  });\n\n  return std::make_tuple(\n      CumSum(prefix_src_arr, true), CumSum(prefix_dst_arr, true),\n      CumSum(prefix_elm_arr, true));\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix 
DisjointUnionCoo(const std::vector<COOMatrix>& coos) {\n  bool has_data = false;\n  bool row_sorted = true;\n  bool col_sorted = true;\n  // check if data index array\n  for (size_t i = 0; i < coos.size(); ++i) {\n    CHECK_SAME_DTYPE(coos[0].row, coos[i].row);\n    CHECK_SAME_CONTEXT(coos[0].row, coos[i].row);\n    has_data |= COOHasData(coos[i]);\n  }\n\n  auto prefixes = _ComputePrefixSums<XPU, IdType>(coos);\n  auto prefix_src = static_cast<IdArray>(std::get<0>(prefixes)).Ptr<IdType>();\n  auto prefix_dst = static_cast<IdArray>(std::get<1>(prefixes)).Ptr<IdType>();\n  auto prefix_elm = static_cast<IdArray>(std::get<2>(prefixes)).Ptr<IdType>();\n\n  IdArray result_src = NewIdArray(\n      prefix_elm[coos.size()], coos[0].row->ctx, coos[0].row->dtype.bits);\n  IdArray result_dst = NewIdArray(\n      prefix_elm[coos.size()], coos[0].col->ctx, coos[0].col->dtype.bits);\n  IdArray result_dat = NullArray();\n  if (has_data) {\n    result_dat = NewIdArray(\n        prefix_elm[coos.size()], coos[0].row->ctx, coos[0].row->dtype.bits);\n  }\n\n  auto res_src_data = result_src.Ptr<IdType>();\n  auto res_dst_data = result_dst.Ptr<IdType>();\n  auto res_dat_data = result_dat.Ptr<IdType>();\n\n  // 32 is a number obtained from experience. 
If a user set the grain size\n  // explicitly via env, use that value instead.\n  size_t grain_size = dgl::runtime::DefaultGrainSizeT(32)();\n  dgl::runtime::parallel_for(\n      0, coos.size(), grain_size, [&](IdType b, IdType e) {\n        for (IdType i = b; i < e; ++i) {\n          const aten::COOMatrix& coo = coos[i];\n          if (!coo.row_sorted) row_sorted = false;\n          if (!coo.col_sorted) col_sorted = false;\n\n          auto edges_src = coo.row.Ptr<IdType>();\n          auto edges_dst = coo.col.Ptr<IdType>();\n          auto edges_dat = coo.data.Ptr<IdType>();\n\n          for (IdType j = 0; j < coo.row->shape[0]; j++) {\n            res_src_data[prefix_elm[i] + j] = edges_src[j] + prefix_src[i];\n          }\n\n          for (IdType j = 0; j < coo.row->shape[0]; j++) {\n            res_dst_data[prefix_elm[i] + j] = edges_dst[j] + prefix_dst[i];\n          }\n\n          if (has_data) {\n            for (IdType j = 0; j < coo.row->shape[0]; j++) {\n              const auto d = (!COOHasData(coo)) ? j : edges_dat[j];\n              res_dat_data[prefix_elm[i] + j] = d + prefix_elm[i];\n            }\n          }\n        }\n      });\n  return COOMatrix(\n      prefix_src[coos.size()], prefix_dst[coos.size()], result_src, result_dst,\n      result_dat, row_sorted, col_sorted);\n}\n\ntemplate COOMatrix DisjointUnionCoo<kDGLCPU, int32_t>(\n    const std::vector<COOMatrix>& coos);\ntemplate COOMatrix DisjointUnionCoo<kDGLCPU, int64_t>(\n    const std::vector<COOMatrix>& coos);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/gather_mm.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/gather_mm.cc\n * @brief GatherMM C APIs and definitions.\n */\n#include \"./gather_mm.h\"\n\n#include <dgl/array.h>\n\nnamespace dgl {\nnamespace aten {\n\n/** @brief Generalized SegmentMM. */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SegmentMM(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans) {\n  LOG(FATAL) << \"Unsupported CPU kernel for SegmentMM.\";\n}\n\ntemplate <int XPU, typename IdType, typename DType>\nvoid SegmentMMBackwardB(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {\n  LOG(FATAL) << \"Unsupported CPU kernel for SegmentMMBackwardB.\";\n}\n\n/** @brief Generalized GatherMM. */\ntemplate <int XPU, typename IdType, typename DType>\nvoid GatherMM(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b) {\n  LOG(FATAL) << \"Unsupported CPU kernel for GatherMM.\";\n}\n\n/** @brief Generalized GatherMM_scatter. 
*/\ntemplate <int XPU, typename IdType, typename DType>\nvoid GatherMMScatter(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c) {\n  LOG(FATAL) << \"Unsupported CPU kernel for GatherMM.\";\n}\n\ntemplate void GatherMM<kDGLCPU, int32_t, BFloat16>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\ntemplate void GatherMM<kDGLCPU, int64_t, BFloat16>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\ntemplate void GatherMM<kDGLCPU, int32_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\ntemplate void GatherMM<kDGLCPU, int64_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\ntemplate void GatherMM<kDGLCPU, int32_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\ntemplate void GatherMM<kDGLCPU, int64_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\n\ntemplate void GatherMMScatter<kDGLCPU, int32_t, BFloat16>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\ntemplate void GatherMMScatter<kDGLCPU, int64_t, BFloat16>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\ntemplate void GatherMMScatter<kDGLCPU, int32_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\ntemplate void GatherMMScatter<kDGLCPU, int64_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\ntemplate void GatherMMScatter<kDGLCPU, int32_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray 
idx_a,\n    const NDArray idx_b, const NDArray idx_c);\ntemplate void GatherMMScatter<kDGLCPU, int64_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\n\ntemplate void SegmentMM<kDGLCPU, int32_t, BFloat16>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\ntemplate void SegmentMM<kDGLCPU, int64_t, BFloat16>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\ntemplate void SegmentMM<kDGLCPU, int32_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\ntemplate void SegmentMM<kDGLCPU, int64_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\ntemplate void SegmentMM<kDGLCPU, int32_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\ntemplate void SegmentMM<kDGLCPU, int64_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\n\ntemplate void SegmentMMBackwardB<kDGLCPU, int32_t, BFloat16>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\ntemplate void SegmentMMBackwardB<kDGLCPU, int64_t, BFloat16>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\ntemplate void SegmentMMBackwardB<kDGLCPU, int32_t, float>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\ntemplate void SegmentMMBackwardB<kDGLCPU, int64_t, float>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\ntemplate void SegmentMMBackwardB<kDGLCPU, int32_t, double>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\ntemplate void SegmentMMBackwardB<kDGLCPU, int64_t, double>(\n    const NDArray A, const NDArray dC, NDArray dB, const 
NDArray seglen);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/gather_mm.h",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file array/cpu/gather_mm.h\n * @brief GATHER_MM CPU kernel function header.\n */\n#ifndef DGL_ARRAY_CPU_GATHER_MM_H_\n#define DGL_ARRAY_CPU_GATHER_MM_H_\n\n#include <dgl/array.h>\n#include <dgl/bcast.h>\n\n#include <utility>\n\nnamespace dgl {\nnamespace aten {\nnamespace cpu {\n\ntemplate <typename DType>\nvoid transpose(const DType *in, DType *out, const int N, const int M) {\n#pragma omp parallel for\n  for (int n = 0; n < N * M; n++) {\n    int i = n / N;\n    int j = n % N;\n    out[n] = in[M * j + i];\n  }\n}\n\ntemplate <typename DType>\nvoid matmul(\n    const DType *A, const DType *B, DType *C, const int M, const int N,\n    const int K) {\n#pragma omp parallel\n  {\n    int i, j, k;\n#pragma omp for\n    for (i = 0; i < M; i++) {\n      for (j = 0; j < N; j++) {\n        DType local_accum = 0;\n        for (k = 0; k < K; k++) {\n          local_accum += A[i * K + k] * B[k * N + j];\n        }\n        C[i * N + j] = local_accum;\n      }\n    }\n  }\n}\n\n/**\n * @brief CPU kernel of Gather_mm. 
The input matrix A is expected to be\n *        sorted according to relation type.\n * @param A The input dense matrix of dimension m x k\n * @param B The input dense matrix of dimension k x n\n * @param C The output dense matrix of dimension m x n\n * @param A_dim1_per_rel The number of rows in each relation in A\n * @param B_dim1_per_rel The number of rows in each relation in B\n * @param a_trans Matrix A to be transposed\n * @param b_trans Matrix B to be transposed\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid gatherMM_SortedEtype(\n    const NDArray A, const NDArray B, NDArray C, const NDArray A_dim1_per_rel,\n    const NDArray B_dim1_per_rel, bool a_trans, bool b_trans) {\n  assert(A_dim1_per_rel.NumElements() == B_dim1_per_rel.NumElements());\n  int64_t num_rel = A_dim1_per_rel.NumElements();\n  const DType *A_data = A.Ptr<DType>();\n  const DType *B_data = B.Ptr<DType>();\n  const IdType *A_rel_data = A_dim1_per_rel.Ptr<IdType>();\n  const IdType *B_rel_data = B_dim1_per_rel.Ptr<IdType>();\n  DType *C_data = C.Ptr<DType>();\n\n  int64_t A_offset = 0, B_offset = 0, C_offset = 0;\n  int64_t m, n, k, h_col, w_row;\n  for (int etype = 0; etype < num_rel; ++etype) {\n    assert(\n        (a_trans)                  ? A_rel_data[etype]\n        : A->shape[1] == (b_trans) ? 
B->shape[1]\n                                   : B_rel_data[etype]);\n    m = A_rel_data[etype];  // rows of A\n    n = B->shape[1];        // cols of B\n    k = B_rel_data[etype];  // rows of B == cols of A\n\n    NDArray A_trans, B_trans;\n    if (a_trans) {\n      A_trans = NDArray::Empty({m * k}, A->dtype, A->ctx);\n      transpose<DType>(\n          A_data + A_offset, static_cast<DType *>(A_trans->data), m, k);\n    }\n    if (b_trans) {\n      B_trans = NDArray::Empty({k * n}, B->dtype, B->ctx);\n      transpose<DType>(\n          B_data + B_offset, static_cast<DType *>(B_trans->data), k, n);\n    }\n    if (a_trans || b_trans) {\n      int64_t tmp = k;\n      if (a_trans) std::swap(m, k);\n      if (b_trans) {\n        k = tmp;\n        std::swap(n, k);\n      }\n    }\n    matmul<DType>(\n        (a_trans) ? static_cast<DType *>(A_trans->data) : A_data + A_offset,\n        (b_trans) ? static_cast<DType *>(B_trans->data) : B_data + B_offset,\n        C_data + C_offset, m, n, k);\n    A_offset += m * k;\n    B_offset += k * n;\n    C_offset += m * n;\n  }\n}\n\n}  // namespace cpu\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CPU_GATHER_MM_H_\n"
  },
  {
    "path": "src/array/cpu/labor_pick.h",
    "content": "/**\n *   Copyright (c) 2022, NVIDIA Corporation\n *   Copyright (c) 2022, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file array/cpu/labor_pick.h\n * @brief Template implementation for layerwise pick operators.\n */\n\n#ifndef DGL_ARRAY_CPU_LABOR_PICK_H_\n#define DGL_ARRAY_CPU_LABOR_PICK_H_\n\n#include <dgl/array.h>\n#include <dgl/random.h>\n#include <dgl/runtime/parallel_for.h>\n#include <dmlc/omp.h>\n#include <tsl/robin_map.h>\n\n#include <algorithm>\n#include <cmath>\n#include <functional>\n#include <memory>\n#include <numeric>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"../../random/continuous_seed.h\"\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\nusing dgl::random::continuous_seed;\n\ntemplate <typename K, typename V>\nusing map_t = tsl::robin_map<K, V>;\ntemplate <typename iterator>\nauto& mutable_value_ref(iterator it) {\n  return it.value();\n}\n\nconstexpr double eps = 0.0001;\n\ntemplate <typename IdxType, typename FloatType>\nauto compute_importance_sampling_probabilities(\n    DGLContext ctx, DGLDataType dtype, const IdxType max_degree,\n    const IdxType num_rows, const int importance_sampling, const bool weighted,\n    const IdxType* rows_data, const IdxType* indptr, const FloatType* A,\n    const IdxType* indices, const IdxType num_picks, const FloatType* ds,\n    
FloatType* cs) {\n  constexpr FloatType ONE = 1;\n  // ps stands for \\pi in arXiv:2210.13339\n  FloatArray ps_array = NDArray::Empty({max_degree + 1}, dtype, ctx);\n  FloatType* ps = ps_array.Ptr<FloatType>();\n\n  double prev_ex_nodes = max_degree * num_rows;\n\n  map_t<IdxType, FloatType> hop_map, hop_map2;\n  for (int iters = 0; iters < importance_sampling || importance_sampling < 0;\n       iters++) {\n    // NOTE(mfbalin) When the graph is unweighted, the first c values in\n    // the first iteration can be computed in O(1) as k / d where k is fanout\n    // and d is the degree.\n\n    // If the graph is weighted, the first c values are computed in the inner\n    // for loop instead. Therefore the importance_sampling argument should be\n    // increased by one in the caller.\n\n    // The later iterations will have correct c values so the if block will be\n    // executed.\n\n    if (!weighted || iters) {\n      hop_map2.clear();\n      for (int64_t i = 0; i < num_rows; ++i) {\n        const FloatType c = cs[i];\n        const IdxType rid = rows_data[i];\n        for (auto j = indptr[rid]; j < indptr[rid + 1]; j++) {\n          const auto ct = c * (weighted && iters == 1 ? 
A[j] : 1);\n          auto itb = hop_map2.emplace(indices[j], ct);\n          if (!itb.second) {\n            mutable_value_ref(itb.first) = std::max(ct, itb.first->second);\n          }\n        }\n      }\n      if (hop_map.empty())\n        hop_map = std::move(hop_map2);\n      else\n        // Update the pi array according to Eq 18.\n        for (auto it : hop_map2) hop_map[it.first] *= it.second;\n    }\n\n    // Compute c_s according to Equation (15), (17) is slower because sorting is\n    // required.\n    for (int64_t i = 0; i < num_rows; ++i) {\n      const IdxType rid = rows_data[i];\n      const auto d = indptr[rid + 1] - indptr[rid];\n      if (d == 0) continue;\n\n      const auto k = std::min(num_picks, d);\n\n      if (hop_map.empty()) {  // weighted first iter, pi = A\n        for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)\n          ps[j - indptr[rid]] = A[j];\n      } else {\n        for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)\n          ps[j - indptr[rid]] = hop_map[indices[j]];\n      }\n\n      // stands for RHS of Equation (22) in arXiv:2210.13339 after moving the\n      // other terms without c_s to RHS.\n      double var_target = ds[i] * ds[i] / k;\n      if (weighted) {\n        var_target -= ds[i] * ds[i] / d;\n        for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)\n          var_target += A[j] * A[j];\n      }\n      FloatType c = cs[i];\n      // stands for left handside of Equation (22) in arXiv:2210.13339 after\n      // moving the other terms without c_s to RHS.\n      double var_1;\n      // Compute c_s in Equation (22) via fixed-point iteration.\n      do {\n        var_1 = 0;\n        if (weighted) {\n          for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)\n            // The check for zero is necessary for numerical stability\n            var_1 += A[j] > 0\n                         ? 
A[j] * A[j] / std::min(ONE, c * ps[j - indptr[rid]])\n                         : 0;\n        } else {\n          for (auto j = indptr[rid]; j < indptr[rid + 1]; j++)\n            var_1 += ONE / std::min(ONE, c * ps[j - indptr[rid]]);\n        }\n\n        c *= var_1 / var_target;\n      } while (std::min(var_1, var_target) / std::max(var_1, var_target) <\n               1 - eps);\n\n      cs[i] = c;\n    }\n\n    // Check convergence\n    if (!weighted || iters) {\n      double cur_ex_nodes = 0;\n      for (auto it : hop_map) cur_ex_nodes += std::min((FloatType)1, it.second);\n      if (cur_ex_nodes / prev_ex_nodes >= 1 - eps) break;\n      prev_ex_nodes = cur_ex_nodes;\n    }\n  }\n\n  return hop_map;\n}\n\n// Template for picking non-zero values row-wise.\ntemplate <typename IdxType, typename FloatType>\nstd::pair<COOMatrix, FloatArray> CSRLaborPick(\n    CSRMatrix mat, IdArray rows, int64_t num_picks, FloatArray prob,\n    int importance_sampling, IdArray random_seed_arr, float seed2_contribution,\n    IdArray NIDs) {\n  using namespace aten;\n  const IdxType* indptr = mat.indptr.Ptr<IdxType>();\n  const IdxType* indices = mat.indices.Ptr<IdxType>();\n  const IdxType* data = CSRHasData(mat) ? mat.data.Ptr<IdxType>() : nullptr;\n  const IdxType* rows_data = rows.Ptr<IdxType>();\n  const IdxType* nids = IsNullArray(NIDs) ? nullptr : NIDs.Ptr<IdxType>();\n  const auto num_rows = rows->shape[0];\n  const auto& ctx = mat.indptr->ctx;\n\n  const bool weighted = !IsNullArray(prob);\n  // O(1) c computation not possible, so one more iteration is needed.\n  if (importance_sampling >= 0) importance_sampling += weighted;\n  // A stands for the same notation in arXiv:2210.13339, i.e. 
the edge weights.\n  auto A_arr = prob;\n  FloatType* A = A_arr.Ptr<FloatType>();\n  constexpr FloatType ONE = 1;\n\n  constexpr auto dtype = DGLDataTypeTraits<FloatType>::dtype;\n\n  // cs stands for c_s in arXiv:2210.13339\n  FloatArray cs_array = NDArray::Empty({num_rows}, dtype, ctx);\n  FloatType* cs = cs_array.Ptr<FloatType>();\n  // ds stands for A_{*s} in arXiv:2210.13339\n  FloatArray ds_array = NDArray::Empty({num_rows}, dtype, ctx);\n  FloatType* ds = ds_array.Ptr<FloatType>();\n\n  IdxType max_degree = 1;\n  IdxType hop_size = 0;\n  for (int64_t i = 0; i < num_rows; ++i) {\n    const IdxType rid = rows_data[i];\n    const auto act_degree = indptr[rid + 1] - indptr[rid];\n    max_degree = std::max(act_degree, max_degree);\n    double d = weighted\n                   ? std::accumulate(A + indptr[rid], A + indptr[rid + 1], 0.0)\n                   : act_degree;\n    // O(1) c computation, samples more than needed for weighted case, mentioned\n    // in the sentence between (10) and (11) in arXiv:2210.13339\n    cs[i] = num_picks / d;\n    ds[i] = d;\n    hop_size += act_degree;\n  }\n\n  map_t<IdxType, FloatType> hop_map;\n\n  if (importance_sampling)\n    hop_map = compute_importance_sampling_probabilities<IdxType, FloatType>(\n        ctx, dtype, max_degree, num_rows, importance_sampling, weighted,\n        rows_data, indptr, A, indices, (IdxType)num_picks, ds, cs);\n\n  constexpr auto vidtype = DGLDataTypeTraits<IdxType>::dtype;\n\n  IdArray picked_row = NDArray::Empty({hop_size}, vidtype, ctx);\n  IdArray picked_col = NDArray::Empty({hop_size}, vidtype, ctx);\n  IdArray picked_idx = NDArray::Empty({hop_size}, vidtype, ctx);\n  FloatArray picked_imp = importance_sampling\n                              ? 
NDArray::Empty({hop_size}, dtype, ctx)\n                              : NullArray();\n  IdxType* picked_rdata = picked_row.Ptr<IdxType>();\n  IdxType* picked_cdata = picked_col.Ptr<IdxType>();\n  IdxType* picked_idata = picked_idx.Ptr<IdxType>();\n  FloatType* picked_imp_data = picked_imp.Ptr<FloatType>();\n\n  const continuous_seed random_seed =\n      IsNullArray(random_seed_arr)\n          ? continuous_seed(RandomEngine::ThreadLocal()->RandInt(1000000000))\n          : continuous_seed(random_seed_arr, seed2_contribution);\n\n  // compute number of edges first and do sampling\n  IdxType num_edges = 0;\n  for (int64_t i = 0; i < num_rows; i++) {\n    const IdxType rid = rows_data[i];\n    const auto c = cs[i];\n\n    FloatType norm_inv_p = 0;\n    const auto off = num_edges;\n    for (auto j = indptr[rid]; j < indptr[rid + 1]; j++) {\n      const auto v = indices[j];\n      const uint64_t t = nids ? nids[v] : v;  // t in the paper\n      // rolled random number r_t is a function of the random_seed and t\n      const auto rnd = random_seed.uniform(t);\n      const auto w = (weighted ? A[j] : 1);\n      // if hop_map is initialized, get ps from there, otherwise get it from the\n      // alternative.\n      const auto ps = std::min(\n          ONE, importance_sampling - weighted ? c * hop_map[v] : c * w);\n      if (rnd <= ps) {\n        picked_rdata[num_edges] = rid;\n        picked_cdata[num_edges] = v;\n        picked_idata[num_edges] = data ? 
data[j] : j;\n        if (importance_sampling) {\n          const auto edge_weight = w / ps;\n          norm_inv_p += edge_weight;\n          picked_imp_data[num_edges] = edge_weight;\n        }\n        num_edges++;\n      }\n    }\n\n    if (importance_sampling) {\n      const auto norm_factor = (num_edges - off) / norm_inv_p;\n      for (auto i = off; i < num_edges; i++)\n        // so that fn.mean can be used\n        picked_imp_data[i] *= norm_factor;\n    }\n  }\n\n  picked_row = picked_row.CreateView({num_edges}, picked_row->dtype);\n  picked_col = picked_col.CreateView({num_edges}, picked_col->dtype);\n  picked_idx = picked_idx.CreateView({num_edges}, picked_idx->dtype);\n  if (importance_sampling)\n    picked_imp = picked_imp.CreateView({num_edges}, picked_imp->dtype);\n\n  return std::make_pair(\n      COOMatrix(mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx),\n      picked_imp);\n}\n\n// Template for picking non-zero values row-wise. The implementation first\n// slices out the corresponding rows and then converts it to CSR format. 
It then\n// performs row-wise pick on the CSR matrix and rectifies the returned results.\ntemplate <typename IdxType, typename FloatType>\nstd::pair<COOMatrix, FloatArray> COOLaborPick(\n    COOMatrix mat, IdArray rows, int64_t num_picks, FloatArray prob,\n    int importance_sampling, IdArray random_seed, float seed2_contribution,\n    IdArray NIDs) {\n  using namespace aten;\n  const auto& csr = COOToCSR(COOSliceRows(mat, rows));\n  const IdArray new_rows =\n      Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);\n  const auto&& picked_importances = CSRLaborPick<IdxType, FloatType>(\n      csr, new_rows, num_picks, prob, importance_sampling, random_seed,\n      seed2_contribution, NIDs);\n  const auto& picked = picked_importances.first;\n  const auto& importances = picked_importances.second;\n  return std::make_pair(\n      COOMatrix(\n          mat.num_rows, mat.num_cols,\n          IndexSelect(\n              rows, picked.row),  // map the row index to the correct one\n          picked.col, picked.data),\n      importances);\n}\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CPU_LABOR_PICK_H_\n"
  },
  {
    "path": "src/array/cpu/labor_sampling.cc",
    "content": "/*!\n *   Copyright (c) 2022, NVIDIA Corporation\n *   Copyright (c) 2022, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * \\file array/cpu/labor_sampling.cc\n * \\brief labor sampling\n */\n#include \"./labor_pick.h\"\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\n/////////////////////////////// CSR ///////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdxType, typename FloatType>\nstd::pair<COOMatrix, FloatArray> CSRLaborSampling(\n    CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,\n    int importance_sampling, IdArray random_seed, float seed2_contribution,\n    IdArray NIDs) {\n  return CSRLaborPick<IdxType, FloatType>(\n      mat, rows, num_samples, prob, importance_sampling, random_seed,\n      seed2_contribution, NIDs);\n}\n\ntemplate std::pair<COOMatrix, FloatArray>\nCSRLaborSampling<kDGLCPU, int32_t, float>(\n    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\ntemplate std::pair<COOMatrix, FloatArray>\nCSRLaborSampling<kDGLCPU, int64_t, float>(\n    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\ntemplate std::pair<COOMatrix, FloatArray>\nCSRLaborSampling<kDGLCPU, int32_t, double>(\n    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\ntemplate std::pair<COOMatrix, 
FloatArray>\nCSRLaborSampling<kDGLCPU, int64_t, double>(\n    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\n\n/////////////////////////////// COO ///////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdxType, typename FloatType>\nstd::pair<COOMatrix, FloatArray> COOLaborSampling(\n    COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,\n    int importance_sampling, IdArray random_seed, float seed2_contribution,\n    IdArray NIDs) {\n  return COOLaborPick<IdxType, FloatType>(\n      mat, rows, num_samples, prob, importance_sampling, random_seed,\n      seed2_contribution, NIDs);\n}\n\ntemplate std::pair<COOMatrix, FloatArray>\nCOOLaborSampling<kDGLCPU, int32_t, float>(\n    COOMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\ntemplate std::pair<COOMatrix, FloatArray>\nCOOLaborSampling<kDGLCPU, int64_t, float>(\n    COOMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\ntemplate std::pair<COOMatrix, FloatArray>\nCOOLaborSampling<kDGLCPU, int32_t, double>(\n    COOMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\ntemplate std::pair<COOMatrix, FloatArray>\nCOOLaborSampling<kDGLCPU, int64_t, double>(\n    COOMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/negative_sampling.cc",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file array/cpu/negative_sampling.cc\n * @brief Uniform negative sampling on CSR.\n */\n\n#include <dgl/array.h>\n#include <dgl/array_iterator.h>\n#include <dgl/random.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <algorithm>\n#include <utility>\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(\n    const CSRMatrix& csr, int64_t num_samples, int num_trials,\n    bool exclude_self_loops, bool replace, double redundancy) {\n  const int64_t num_row = csr.num_rows;\n  const int64_t num_col = csr.num_cols;\n  const int64_t num_actual_samples =\n      static_cast<int64_t>(num_samples * (1 + redundancy));\n  IdArray row = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);\n  IdArray col = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);\n  IdType* row_data = row.Ptr<IdType>();\n  IdType* col_data = col.Ptr<IdType>();\n\n  parallel_for(0, num_actual_samples, 1, [&](int64_t b, int64_t e) {\n    for (int64_t i = b; i < e; ++i) {\n      for (int trial = 0; trial < num_trials; ++trial) {\n        IdType u = RandomEngine::ThreadLocal()->RandInt(num_row);\n        IdType v = RandomEngine::ThreadLocal()->RandInt(num_col);\n        if (!(exclude_self_loops && (u == v)) && !CSRIsNonZero(csr, u, v)) {\n          row_data[i] = u;\n          col_data[i] = v;\n          break;\n        }\n      }\n    }\n  });\n\n  PairIterator<IdType> begin(row_data, col_data);\n  PairIterator<IdType> end = std::remove_if(\n      begin, begin + num_actual_samples,\n      [](const std::pair<IdType, IdType>& val) { return val.first == -1; });\n  if (!replace) {\n    std::sort(\n        begin, end,\n        [](const std::pair<IdType, IdType>& a,\n           const std::pair<IdType, IdType>& b) {\n          return a.first < b.first ||\n                 (a.first == b.first && 
a.second < b.second);\n        });\n    end = std::unique(begin, end);\n  }\n  int64_t num_sampled =\n      std::min(static_cast<int64_t>(end - begin), num_samples);\n  return {\n      row.CreateView({num_sampled}, row->dtype),\n      col.CreateView({num_sampled}, col->dtype)};\n}\n\ntemplate std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<\n    kDGLCPU, int32_t>(const CSRMatrix&, int64_t, int, bool, bool, double);\ntemplate std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<\n    kDGLCPU, int64_t>(const CSRMatrix&, int64_t, int, bool, bool, double);\n\n};  // namespace impl\n};  // namespace aten\n};  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/rowwise_pick.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/rowwise_pick.h\n * @brief Template implementation for rowwise pick operators.\n */\n#ifndef DGL_ARRAY_CPU_ROWWISE_PICK_H_\n#define DGL_ARRAY_CPU_ROWWISE_PICK_H_\n\n#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n#include <dmlc/omp.h>\n\n#include <algorithm>\n#include <functional>\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\n// User-defined function for picking elements from one row.\n//\n// The column indices of the given row are stored in\n//   [col + off, col + off + len)\n//\n// Similarly, the data indices are stored in\n//   [data + off, data + off + len)\n// Data index pointer could be NULL, which means data[i] == i\n//\n// *ATTENTION*: This function will be invoked concurrently. Please make sure\n// it is thread-safe.\n//\n// @param rowid The row to pick from.\n// @param off Starting offset of this row.\n// @param len NNZ of the row.\n// @param num_picks Number of picks on the row.\n// @param col Pointer of the column indices.\n// @param data Pointer of the data indices.\n// @param out_idx Picked indices in [off, off + len).\ntemplate <typename IdxType>\nusing PickFn = std::function<void(\n    IdxType rowid, IdxType off, IdxType len, IdxType num_picks,\n    const IdxType* col, const IdxType* data, IdxType* out_idx)>;\n\n// User-defined function for determining the number of elements to pick from one\n// row.\n//\n// The column indices of the given row are stored in\n//   [col + off, col + off + len)\n//\n// Similarly, the data indices are stored in\n//   [data + off, data + off + len)\n// Data index pointer could be NULL, which means data[i] == i\n//\n// *ATTENTION*: This function will be invoked concurrently. 
Please make sure\n// it is thread-safe.\n//\n// @param rowid The row to pick from.\n// @param off Starting offset of this row.\n// @param len NNZ of the row.\n// @param col Pointer of the column indices.\n// @param data Pointer of the data indices.\ntemplate <typename IdxType>\nusing NumPicksFn = std::function<IdxType(\n    IdxType rowid, IdxType off, IdxType len, const IdxType* col,\n    const IdxType* data)>;\n\n// User-defined function for picking elements from a range within a row.\n//\n// The column index of the i-th element of the range is\n//   indices[off + et_idx[et_offset + i]], where i is in [0, et_len)\n//\n// Similarly, its data index is\n//   eid[off + et_idx[et_offset + i]]\n// The eid pointer could be NULL, which means the data index is\n//   off + et_idx[et_offset + i]\n//\n// *ATTENTION*: This function will be invoked concurrently. Please make sure\n// it is thread-safe.\n//\n// @param off Starting offset of this row.\n// @param et_offset Starting offset of this range.\n// @param cur_et The edge type.\n// @param et_len Length of the range.\n// @param et_idx A map from local idx to column id.\n// @param et_eid Edge-type-specific id array.\n// @param eid Pointer of the homogenized edge id array.\n// @param out_idx Picked indices relative to et_offset, i.e. in [0, et_len).\ntemplate <typename IdxType>\nusing EtypeRangePickFn = std::function<void(\n    IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,\n    const std::vector<IdxType>& et_idx, const std::vector<IdxType>& et_eid,\n    const IdxType* eid, IdxType* out_idx)>;\n\ntemplate <typename IdxType, bool map_seed_nodes>\nstd::pair<CSRMatrix, IdArray> CSRRowWisePickFused(\n    CSRMatrix mat, IdArray rows, IdArray seed_mapping,\n    std::vector<IdxType>* new_seed_nodes, int64_t num_picks, bool replace,\n    PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {\n  using namespace aten;\n\n  const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);\n  const IdxType* indices = 
static_cast<IdxType*>(mat.indices->data);\n  const IdxType* data =\n      CSRHasData(mat) ? static_cast<IdxType*>(mat.data->data) : nullptr;\n  const IdxType* rows_data = static_cast<IdxType*>(rows->data);\n  const int64_t num_rows = rows->shape[0];\n  const auto& ctx = mat.indptr->ctx;\n  const auto& idtype = mat.indptr->dtype;\n  IdxType* seed_mapping_data = nullptr;\n  if (map_seed_nodes) seed_mapping_data = seed_mapping.Ptr<IdxType>();\n\n  const int num_threads = runtime::compute_num_threads(0, num_rows, 1);\n  std::vector<int64_t> global_prefix(num_threads + 1, 0);\n\n  IdArray picked_col, picked_idx, picked_coo_rows;\n\n  IdArray block_csr_indptr = IdArray::Empty({num_rows + 1}, idtype, ctx);\n  IdxType* block_csr_indptr_data = block_csr_indptr.Ptr<IdxType>();\n\n#pragma omp parallel num_threads(num_threads)\n  {\n    const int thread_id = omp_get_thread_num();\n\n    const int64_t start_i =\n        thread_id * (num_rows / num_threads) +\n        std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);\n    const int64_t end_i =\n        (thread_id + 1) * (num_rows / num_threads) +\n        std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);\n    assert(thread_id + 1 < num_threads || end_i == num_rows);\n\n    const int64_t num_local = end_i - start_i;\n\n    std::unique_ptr<int64_t[]> local_prefix(new int64_t[num_local + 1]);\n    local_prefix[0] = 0;\n    for (int64_t i = start_i; i < end_i; ++i) {\n      // build prefix-sum\n      const int64_t local_i = i - start_i;\n      const IdxType rid = rows_data[i];\n      if (map_seed_nodes) seed_mapping_data[rid] = i;\n\n      IdxType len = num_picks_fn(\n          rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);\n      local_prefix[local_i + 1] = local_prefix[local_i] + len;\n    }\n    global_prefix[thread_id + 1] = local_prefix[num_local];\n\n#pragma omp barrier\n#pragma omp master\n    {\n      for (int t = 0; t < num_threads; ++t) {\n        global_prefix[t + 1] 
+= global_prefix[t];\n      }\n      picked_col = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);\n      picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);\n      picked_coo_rows =\n          IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);\n    }\n\n#pragma omp barrier\n    IdxType* picked_cdata = picked_col.Ptr<IdxType>();\n    IdxType* picked_idata = picked_idx.Ptr<IdxType>();\n    IdxType* picked_rows = picked_coo_rows.Ptr<IdxType>();\n\n    const IdxType thread_offset = global_prefix[thread_id];\n\n    for (int64_t i = start_i; i < end_i; ++i) {\n      const IdxType rid = rows_data[i];\n      const int64_t local_i = i - start_i;\n      block_csr_indptr_data[i] = local_prefix[local_i] + thread_offset;\n\n      const IdxType off = indptr[rid];\n      const IdxType len = indptr[rid + 1] - off;\n      if (len == 0) continue;\n\n      const int64_t row_offset = local_prefix[local_i] + thread_offset;\n      const int64_t num_picks =\n          local_prefix[local_i + 1] + thread_offset - row_offset;\n\n      pick_fn(\n          rid, off, len, num_picks, indices, data, picked_idata + row_offset);\n      for (int64_t j = 0; j < num_picks; ++j) {\n        const IdxType picked = picked_idata[row_offset + j];\n        picked_cdata[row_offset + j] = indices[picked];\n        picked_idata[row_offset + j] = data ? data[picked] : picked;\n        picked_rows[row_offset + j] = i;\n      }\n    }\n  }\n  block_csr_indptr_data[num_rows] = global_prefix.back();\n\n  const IdxType num_cols = picked_col->shape[0];\n  if (map_seed_nodes) {\n    (*new_seed_nodes).resize(num_rows);\n    memcpy((*new_seed_nodes).data(), rows_data, sizeof(IdxType) * num_rows);\n  }\n\n  return std::make_pair(\n      CSRMatrix(num_rows, num_cols, block_csr_indptr, picked_col, picked_idx),\n      picked_coo_rows);\n}\n\n// Template for picking non-zero values row-wise. 
The implementation utilizes\n// OpenMP parallelization on rows because each row performs computation\n// independently.\ntemplate <typename IdxType>\nCOOMatrix CSRRowWisePick(\n    CSRMatrix mat, IdArray rows, int64_t num_picks, bool replace,\n    PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {\n  using namespace aten;\n  const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);\n  const IdxType* indices = static_cast<IdxType*>(mat.indices->data);\n  const IdxType* data =\n      CSRHasData(mat) ? static_cast<IdxType*>(mat.data->data) : nullptr;\n  const IdxType* rows_data = static_cast<IdxType*>(rows->data);\n  const int64_t num_rows = rows->shape[0];\n  const auto& ctx = mat.indptr->ctx;\n  const auto& idtype = mat.indptr->dtype;\n\n  // To leverage OMP parallelization, we create two arrays to store\n  // picked src and dst indices. Each array is of length num_rows * num_picks.\n  // For rows whose nnz < num_picks, the indices are padded with -1.\n  //\n  // We check whether all the given rows\n  // have at least num_picks number of nnz when replace is false.\n  //\n  // If the check holds, remove -1 elements by remove_if operation, which simply\n  // moves valid elements to the head of arrays and create a view of the\n  // original array. The implementation consumes a little extra memory than the\n  // actual requirement.\n  //\n  // Otherwise, directly use the row and col arrays to construct the result COO\n  // matrix.\n  //\n  // [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism\n  // is more\n  //   significant. 
(minjie)\n\n  // Do not use omp_get_max_threads() since that doesn't work for compiling\n  // without OpenMP.\n  const int num_threads = runtime::compute_num_threads(0, num_rows, 1);\n  std::vector<int64_t> global_prefix(num_threads + 1, 0);\n\n  // TODO(BarclayII) Using OMP parallel directly instead of using\n  // runtime::parallel_for does not handle exceptions well (directly aborts when\n  // an exception pops up). It runs faster though because there is less\n  // scheduling.  Need to handle exceptions better.\n  IdArray picked_row, picked_col, picked_idx;\n#pragma omp parallel num_threads(num_threads)\n  {\n    const int thread_id = omp_get_thread_num();\n\n    const int64_t start_i =\n        thread_id * (num_rows / num_threads) +\n        std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);\n    const int64_t end_i =\n        (thread_id + 1) * (num_rows / num_threads) +\n        std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);\n    assert(thread_id + 1 < num_threads || end_i == num_rows);\n\n    const int64_t num_local = end_i - start_i;\n\n    // make sure we don't have to pay initialization cost\n    std::unique_ptr<int64_t[]> local_prefix(new int64_t[num_local + 1]);\n    local_prefix[0] = 0;\n    for (int64_t i = start_i; i < end_i; ++i) {\n      // build prefix-sum\n      const int64_t local_i = i - start_i;\n      const IdxType rid = rows_data[i];\n      IdxType len = num_picks_fn(\n          rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);\n      local_prefix[local_i + 1] = local_prefix[local_i] + len;\n    }\n    global_prefix[thread_id + 1] = local_prefix[num_local];\n\n#pragma omp barrier\n#pragma omp master\n    {\n      for (int t = 0; t < num_threads; ++t) {\n        global_prefix[t + 1] += global_prefix[t];\n      }\n      picked_row = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);\n      picked_col = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);\n      picked_idx = 
IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);\n    }\n\n#pragma omp barrier\n    IdxType* picked_rdata = picked_row.Ptr<IdxType>();\n    IdxType* picked_cdata = picked_col.Ptr<IdxType>();\n    IdxType* picked_idata = picked_idx.Ptr<IdxType>();\n\n    const IdxType thread_offset = global_prefix[thread_id];\n\n    for (int64_t i = start_i; i < end_i; ++i) {\n      const IdxType rid = rows_data[i];\n\n      const IdxType off = indptr[rid];\n      const IdxType len = indptr[rid + 1] - off;\n      if (len == 0) continue;\n\n      const int64_t local_i = i - start_i;\n      const int64_t row_offset = thread_offset + local_prefix[local_i];\n      const int64_t num_picks =\n          thread_offset + local_prefix[local_i + 1] - row_offset;\n\n      pick_fn(\n          rid, off, len, num_picks, indices, data, picked_idata + row_offset);\n      for (int64_t j = 0; j < num_picks; ++j) {\n        const IdxType picked = picked_idata[row_offset + j];\n        picked_rdata[row_offset + j] = rid;\n        picked_cdata[row_offset + j] = indices[picked];\n        picked_idata[row_offset + j] = data ? data[picked] : picked;\n      }\n    }\n  }\n\n  const int64_t new_len = global_prefix.back();\n\n  return COOMatrix(\n      mat.num_rows, mat.num_cols,\n      picked_row.CreateView({new_len}, picked_row->dtype),\n      picked_col.CreateView({new_len}, picked_row->dtype),\n      picked_idx.CreateView({new_len}, picked_row->dtype));\n}\n\n// Template for picking non-zero values row-wise. 
The implementation utilizes\n// OpenMP parallelization on rows because each row performs computation\n// independently.\ntemplate <typename IdxType, typename DType>\nCOOMatrix CSRRowWisePerEtypePick(\n    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_picks, bool replace,\n    bool rowwise_etype_sorted, EtypeRangePickFn<IdxType> pick_fn,\n    const std::vector<NDArray>& prob_or_mask) {\n  using namespace aten;\n  const IdxType* indptr = mat.indptr.Ptr<IdxType>();\n  const IdxType* indices = mat.indices.Ptr<IdxType>();\n  const IdxType* eid = CSRHasData(mat) ? mat.data.Ptr<IdxType>() : nullptr;\n  const IdxType* rows_data = rows.Ptr<IdxType>();\n  const int64_t num_rows = rows->shape[0];\n  const auto& ctx = mat.indptr->ctx;\n  const int64_t num_etypes = num_picks.size();\n  const bool has_probs = (prob_or_mask.size() > 0);\n  std::vector<IdArray> picked_rows(rows->shape[0]);\n  std::vector<IdArray> picked_cols(rows->shape[0]);\n  std::vector<IdArray> picked_idxs(rows->shape[0]);\n\n  // Check if the number of picks have the same value.\n  // If so, we can potentially speed up if we have a node with total number of\n  // neighbors less than the given number of picks with replace=False.\n  bool same_num_pick = true;\n  int64_t num_pick_value = num_picks[0];\n  for (int64_t num_pick : num_picks) {\n    if (num_pick_value != num_pick) {\n      same_num_pick = false;\n      break;\n    }\n  }\n\n  runtime::parallel_for(0, num_rows, [&](size_t b, size_t e) {\n    for (size_t i = b; i < e; ++i) {\n      const IdxType rid = rows_data[i];\n      CHECK_LT(rid, mat.num_rows);\n      const IdxType off = indptr[rid];\n      const IdxType len = indptr[rid + 1] - off;\n\n      // do something here\n      if (len == 0) {\n        picked_rows[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);\n        picked_cols[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);\n        picked_idxs[i] = NewIdArray(0, ctx, sizeof(IdxType) * 
8);\n        continue;\n      }\n\n      // fast path\n      if (same_num_pick && len <= num_pick_value && !replace) {\n        IdArray rows = Full(rid, len, sizeof(IdxType) * 8, ctx);\n        IdArray cols = Full(-1, len, sizeof(IdxType) * 8, ctx);\n        IdArray idx = Full(-1, len, sizeof(IdxType) * 8, ctx);\n        IdxType* cdata = cols.Ptr<IdxType>();\n        IdxType* idata = idx.Ptr<IdxType>();\n\n        int64_t k = 0;\n        for (int64_t j = 0; j < len; ++j) {\n          const IdxType homogenized_eid = eid ? eid[off + j] : off + j;\n          auto it = std::upper_bound(\n              eid2etype_offset.begin(), eid2etype_offset.end(),\n              homogenized_eid);\n          const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;\n          const IdxType heterogenized_eid =\n              homogenized_eid - eid2etype_offset[heterogenized_etype];\n\n          if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {\n            // No probability array, select all\n            cdata[k] = indices[off + j];\n            idata[k] = homogenized_eid;\n            ++k;\n          } else {\n            // Select the entries with non-zero probability\n            const NDArray& p = prob_or_mask[heterogenized_etype];\n            const DType* pdata = p.Ptr<DType>();\n            if (pdata[heterogenized_eid] > 0) {\n              cdata[k] = indices[off + j];\n              idata[k] = homogenized_eid;\n              ++k;\n            }\n          }\n        }\n\n        picked_rows[i] = rows.CreateView({k}, rows->dtype);\n        picked_cols[i] = cols.CreateView({k}, cols->dtype);\n        picked_idxs[i] = idx.CreateView({k}, idx->dtype);\n      } else {\n        // need to do per edge type sample\n        std::vector<IdxType> rows;\n        std::vector<IdxType> cols;\n        std::vector<IdxType> idx;\n\n        std::vector<IdxType> et(len);\n        std::vector<IdxType> et_idx(len);\n        std::vector<IdxType> et_eid(len);\n        
std::iota(et_idx.begin(), et_idx.end(), 0);\n        for (int64_t j = 0; j < len; ++j) {\n          const IdxType homogenized_eid = eid ? eid[off + j] : off + j;\n          auto it = std::upper_bound(\n              eid2etype_offset.begin(), eid2etype_offset.end(),\n              homogenized_eid);\n          const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;\n          const IdxType heterogenized_eid =\n              homogenized_eid - eid2etype_offset[heterogenized_etype];\n          et[j] = heterogenized_etype;\n          et_eid[j] = heterogenized_eid;\n        }\n        if (!rowwise_etype_sorted)  // edge types are not pre-sorted within the\n                                    // row, so sort them here\n          std::sort(\n              et_idx.begin(), et_idx.end(),\n              [&et](IdxType i1, IdxType i2) { return et[i1] < et[i2]; });\n        CHECK_LT(et[et_idx[len - 1]], num_etypes)\n            << \"etype values exceed the number of fanouts\";\n\n        IdxType cur_et = et[et_idx[0]];\n        int64_t et_offset = 0;\n        int64_t et_len = 1;\n        for (int64_t j = 0; j < len; ++j) {\n          CHECK((j + 1 == len) || (et[et_idx[j]] <= et[et_idx[j + 1]]))\n              << \"Edge type is not sorted. Please sort in advance or specify \"\n                 \"'rowwise_etype_sorted' as false.\";\n          if ((j + 1 == len) || cur_et != et[et_idx[j + 1]]) {\n            // 1 end of the current etype\n            // 2 end of the row\n            // random pick for current etype\n            if ((num_picks[cur_et] == -1) ||\n                (et_len <= num_picks[cur_et] && !replace)) {\n              // fast path, select all\n              for (int64_t k = 0; k < et_len; ++k) {\n                const IdxType eid_offset = off + et_idx[et_offset + k];\n                const IdxType homogenized_eid =\n                    eid ? 
eid[eid_offset] : eid_offset;\n                auto it = std::upper_bound(\n                    eid2etype_offset.begin(), eid2etype_offset.end(),\n                    homogenized_eid);\n                const IdxType heterogenized_etype =\n                    it - eid2etype_offset.begin() - 1;\n                const IdxType heterogenized_eid =\n                    homogenized_eid - eid2etype_offset[heterogenized_etype];\n\n                if (!has_probs ||\n                    IsNullArray(prob_or_mask[heterogenized_etype])) {\n                  // No probability, select all\n                  rows.push_back(rid);\n                  cols.push_back(indices[eid_offset]);\n                  idx.push_back(homogenized_eid);\n                } else {\n                  // Select the entries with non-zero probability\n                  const NDArray& p = prob_or_mask[heterogenized_etype];\n                  const DType* pdata = p.Ptr<DType>();\n                  if (pdata[heterogenized_eid] > 0) {\n                    rows.push_back(rid);\n                    cols.push_back(indices[eid_offset]);\n                    idx.push_back(homogenized_eid);\n                  }\n                }\n              }\n            } else {\n              IdArray picked_idx =\n                  Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);\n              IdxType* picked_idata = picked_idx.Ptr<IdxType>();\n\n              // need call random pick\n              pick_fn(\n                  off, et_offset, cur_et, et_len, et_idx, et_eid, eid,\n                  picked_idata);\n              for (int64_t k = 0; k < num_picks[cur_et]; ++k) {\n                const IdxType picked = picked_idata[k];\n                if (picked == -1) continue;\n                rows.push_back(rid);\n                cols.push_back(indices[off + et_idx[et_offset + picked]]);\n                if (eid) {\n                  idx.push_back(eid[off + et_idx[et_offset + picked]]);\n                } else {\n       
           idx.push_back(off + et_idx[et_offset + picked]);\n                }\n              }\n            }\n\n            if (j + 1 == len) break;\n            // next etype\n            cur_et = et[et_idx[j + 1]];\n            et_offset = j + 1;\n            et_len = 1;\n          } else {\n            et_len++;\n          }\n        }\n\n        picked_rows[i] = VecToIdArray(rows, sizeof(IdxType) * 8, ctx);\n        picked_cols[i] = VecToIdArray(cols, sizeof(IdxType) * 8, ctx);\n        picked_idxs[i] = VecToIdArray(idx, sizeof(IdxType) * 8, ctx);\n      }  // end processing one row\n\n      CHECK_EQ(picked_rows[i]->shape[0], picked_cols[i]->shape[0]);\n      CHECK_EQ(picked_rows[i]->shape[0], picked_idxs[i]->shape[0]);\n    }  // end processing all rows\n  });\n\n  IdArray picked_row = Concat(picked_rows);\n  IdArray picked_col = Concat(picked_cols);\n  IdArray picked_idx = Concat(picked_idxs);\n  return COOMatrix(\n      mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);\n}\n\n// Template for picking non-zero values row-wise. The implementation first\n// slices out the corresponding rows and then converts it to CSR format. It then\n// performs row-wise pick on the CSR matrix and rectifies the returned results.\ntemplate <typename IdxType>\nCOOMatrix COORowWisePick(\n    COOMatrix mat, IdArray rows, int64_t num_picks, bool replace,\n    PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {\n  using namespace aten;\n  const auto& csr = COOToCSR(COOSliceRows(mat, rows));\n  const IdArray new_rows =\n      Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);\n  const auto& picked = CSRRowWisePick<IdxType>(\n      csr, new_rows, num_picks, replace, pick_fn, num_picks_fn);\n  return COOMatrix(\n      mat.num_rows, mat.num_cols,\n      IndexSelect(rows, picked.row),  // map the row index to the correct one\n      picked.col, picked.data);\n}\n\n// Template for picking non-zero values row-wise. 
The implementation first\n// slices out the corresponding rows and then converts it to CSR format. It then\n// performs row-wise pick on the CSR matrix and rectifies the returned results.\ntemplate <typename IdxType, typename DType>\nCOOMatrix COORowWisePerEtypePick(\n    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_picks, bool replace,\n    EtypeRangePickFn<IdxType> pick_fn,\n    const std::vector<NDArray>& prob_or_mask) {\n  using namespace aten;\n  const auto& csr = COOToCSR(COOSliceRows(mat, rows));\n  const IdArray new_rows =\n      Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);\n  const auto& picked = CSRRowWisePerEtypePick<IdxType, DType>(\n      csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn,\n      prob_or_mask);\n  return COOMatrix(\n      mat.num_rows, mat.num_cols,\n      IndexSelect(rows, picked.row),  // map the row index to the correct one\n      picked.col, picked.data);\n}\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CPU_ROWWISE_PICK_H_\n"
  },
  {
    "path": "src/array/cpu/rowwise_sampling.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/rowwise_sampling.cc\n * @brief rowwise sampling\n */\n#include <dgl/random.h>\n\n#include <numeric>\n\n#include \"./rowwise_pick.h\"\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\nnamespace {\n// Equivalent to numpy expression: array[idx[off:off + len]]\ntemplate <typename IdxType, typename FloatType>\ninline FloatArray DoubleSlice(\n    FloatArray array, const IdxType* idx_data, IdxType off, IdxType len) {\n  const FloatType* array_data = static_cast<FloatType*>(array->data);\n  FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx);\n  FloatType* ret_data = static_cast<FloatType*>(ret->data);\n  for (int64_t j = 0; j < len; ++j) {\n    if (idx_data)\n      ret_data[j] = array_data[idx_data[off + j]];\n    else\n      ret_data[j] = array_data[off + j];\n  }\n  return ret;\n}\n\ntemplate <typename IdxType, typename DType>\ninline NumPicksFn<IdxType> GetSamplingNumPicksFn(\n    int64_t num_samples, NDArray prob_or_mask, bool replace) {\n  NumPicksFn<IdxType> num_picks_fn = [prob_or_mask, num_samples, replace](\n                                         IdxType rowid, IdxType off,\n                                         IdxType len, const IdxType* col,\n                                         const IdxType* data) {\n    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;\n    const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>();\n    IdxType nnz = 0;\n    for (IdxType i = off; i < off + len; ++i) {\n      const IdxType eid = data ? data[i] : i;\n      if (prob_or_mask_data[eid] > 0) {\n        ++nnz;\n      }\n    }\n\n    if (replace) {\n      return static_cast<IdxType>(nnz == 0 ? 
0 : max_num_picks);\n    } else {\n      return std::min(static_cast<IdxType>(max_num_picks), nnz);\n    }\n  };\n  return num_picks_fn;\n}\n\ntemplate <typename IdxType, typename DType>\ninline PickFn<IdxType> GetSamplingPickFn(\n    int64_t num_samples, NDArray prob_or_mask, bool replace) {\n  PickFn<IdxType> pick_fn = [prob_or_mask, num_samples, replace](\n                                IdxType rowid, IdxType off, IdxType len,\n                                IdxType num_picks, const IdxType* col,\n                                const IdxType* data, IdxType* out_idx) {\n    NDArray prob_or_mask_selected =\n        DoubleSlice<IdxType, DType>(prob_or_mask, data, off, len);\n    RandomEngine::ThreadLocal()->Choice<IdxType, DType>(\n        num_picks, prob_or_mask_selected, out_idx, replace);\n    for (int64_t j = 0; j < num_picks; ++j) {\n      out_idx[j] += off;\n    }\n  };\n  return pick_fn;\n}\n\ntemplate <typename IdxType, typename FloatType>\ninline EtypeRangePickFn<IdxType> GetSamplingRangePickFn(\n    const std::vector<int64_t>& num_samples,\n    const std::vector<FloatArray>& prob, bool replace) {\n  EtypeRangePickFn<IdxType> pick_fn =\n      [prob, num_samples, replace](\n          IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,\n          const std::vector<IdxType>& et_idx,\n          const std::vector<IdxType>& et_eid, const IdxType* eid,\n          IdxType* out_idx) {\n        const FloatArray& p = prob[cur_et];\n        const FloatType* p_data = IsNullArray(p) ? nullptr : p.Ptr<FloatType>();\n        FloatArray probs = FloatArray::Empty({et_len}, p->dtype, p->ctx);\n        FloatType* probs_data = probs.Ptr<FloatType>();\n        for (int64_t j = 0; j < et_len; ++j) {\n          const IdxType cur_eid = et_eid[et_idx[et_offset + j]];\n          probs_data[j] = p_data ? 
p_data[cur_eid] : static_cast<FloatType>(1.);\n        }\n\n        RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(\n            num_samples[cur_et], probs, out_idx, replace);\n      };\n  return pick_fn;\n}\n\ntemplate <typename IdxType>\ninline NumPicksFn<IdxType> GetSamplingUniformNumPicksFn(\n    int64_t num_samples, bool replace) {\n  NumPicksFn<IdxType> num_picks_fn = [num_samples, replace](\n                                         IdxType rowid, IdxType off,\n                                         IdxType len, const IdxType* col,\n                                         const IdxType* data) {\n    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;\n    if (replace) {\n      return static_cast<IdxType>(len == 0 ? 0 : max_num_picks);\n    } else {\n      return std::min(static_cast<IdxType>(max_num_picks), len);\n    }\n  };\n  return num_picks_fn;\n}\n\ntemplate <typename IdxType>\ninline PickFn<IdxType> GetSamplingUniformPickFn(\n    int64_t num_samples, bool replace) {\n  PickFn<IdxType> pick_fn = [num_samples, replace](\n                                IdxType rowid, IdxType off, IdxType len,\n                                IdxType num_picks, const IdxType* col,\n                                const IdxType* data, IdxType* out_idx) {\n    RandomEngine::ThreadLocal()->UniformChoice<IdxType>(\n        num_picks, len, out_idx, replace);\n    for (int64_t j = 0; j < num_picks; ++j) {\n      out_idx[j] += off;\n    }\n  };\n  return pick_fn;\n}\n\ntemplate <typename IdxType>\ninline EtypeRangePickFn<IdxType> GetSamplingUniformRangePickFn(\n    const std::vector<int64_t>& num_samples, bool replace) {\n  EtypeRangePickFn<IdxType> pick_fn =\n      [num_samples, replace](\n          IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,\n          const std::vector<IdxType>& et_idx,\n          const std::vector<IdxType>& et_eid, const IdxType* data,\n          IdxType* out_idx) {\n        
RandomEngine::ThreadLocal()->UniformChoice<IdxType>(\n            num_samples[cur_et], et_len, out_idx, replace);\n      };\n  return pick_fn;\n}\n\ntemplate <typename IdxType, typename FloatType>\ninline NumPicksFn<IdxType> GetSamplingBiasedNumPicksFn(\n    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {\n  NumPicksFn<IdxType> num_picks_fn = [num_samples, split, bias, replace](\n                                         IdxType rowid, IdxType off,\n                                         IdxType len, const IdxType* col,\n                                         const IdxType* data) {\n    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;\n    const int64_t num_tags = split->shape[1] - 1;\n    const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];\n    const FloatType* bias_data = bias.Ptr<FloatType>();\n    IdxType nnz = 0;\n    for (int64_t j = 0; j < num_tags; ++j) {\n      if (bias_data[j] > 0) {\n        nnz += tag_offset[j + 1] - tag_offset[j];\n      }\n    }\n\n    if (replace) {\n      return static_cast<IdxType>(nnz == 0 ? 
0 : max_num_picks);\n    } else {\n      return std::min(static_cast<IdxType>(max_num_picks), nnz);\n    }\n  };\n  return num_picks_fn;\n}\n\ntemplate <typename IdxType, typename FloatType>\ninline PickFn<IdxType> GetSamplingBiasedPickFn(\n    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {\n  PickFn<IdxType> pick_fn = [num_samples, split, bias, replace](\n                                IdxType rowid, IdxType off, IdxType len,\n                                IdxType num_picks, const IdxType* col,\n                                const IdxType* data, IdxType* out_idx) {\n    const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];\n    RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(\n        num_picks, tag_offset, bias, out_idx, replace);\n    for (int64_t j = 0; j < num_picks; ++j) {\n      out_idx[j] += off;\n    }\n  };\n  return pick_fn;\n}\n\n}  // namespace\n\n/////////////////////////////// CSR ///////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdxType, typename DType>\nCOOMatrix CSRRowWiseSampling(\n    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,\n    bool replace) {\n  // If num_samples is -1, select all neighbors without replacement.\n  replace = (replace && num_samples != -1);\n  CHECK(prob_or_mask.defined());\n  auto num_picks_fn =\n      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);\n  auto pick_fn =\n      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);\n  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);\n}\n\ntemplate COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate 
COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, int8_t>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, int8_t>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\n\ntemplate <\n    DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes>\nstd::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(\n    CSRMatrix mat, IdArray rows, IdArray seed_mapping,\n    std::vector<IdxType>* new_seed_nodes, int64_t num_samples,\n    NDArray prob_or_mask, bool replace) {\n  // If num_samples is -1, select all neighbors without replacement.\n  replace = (replace && num_samples != -1);\n  CHECK(prob_or_mask.defined());\n  auto num_picks_fn =\n      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);\n  auto pick_fn =\n      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);\n  return CSRRowWisePickFused<IdxType, map_seed_nodes>(\n      mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn,\n      num_picks_fn);\n}\n\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int32_t, float, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int64_t, float, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int32_t, double, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, 
IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int64_t, double, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int32_t, int8_t, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int64_t, int8_t, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int32_t, uint8_t, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int64_t, uint8_t, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);\n\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int32_t, float, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int64_t, float, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int32_t, double, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int64_t, double, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int32_t, int8_t, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int64_t, int8_t, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, 
IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int32_t, uint8_t, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingFused<kDGLCPU, int64_t, uint8_t, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);\n\ntemplate <DGLDeviceType XPU, typename IdxType, typename DType>\nCOOMatrix CSRRowWisePerEtypeSampling(\n    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_samples,\n    const std::vector<NDArray>& prob_or_mask, bool replace,\n    bool rowwise_etype_sorted) {\n  CHECK(prob_or_mask.size() == num_samples.size())\n      << \"the number of probability tensors does not match the number of edge \"\n         \"types.\";\n  for (auto& p : prob_or_mask) CHECK(p.defined());\n  auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(\n      num_samples, prob_or_mask, replace);\n  return CSRRowWisePerEtypePick<IdxType, DType>(\n      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,\n      pick_fn, prob_or_mask);\n}\n\ntemplate COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(\n    CSRMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);\ntemplate COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(\n    CSRMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);\ntemplate COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(\n    CSRMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);\ntemplate COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(\n    CSRMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);\ntemplate COOMatrix 
CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(\n    CSRMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);\ntemplate COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(\n    CSRMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);\ntemplate COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(\n    CSRMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);\ntemplate COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(\n    CSRMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nCOOMatrix CSRRowWiseSamplingUniform(\n    CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace) {\n  // If num_samples is -1, select all neighbors without replacement.\n  replace = (replace && num_samples != -1);\n  auto num_picks_fn =\n      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);\n  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);\n  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);\n}\n\ntemplate COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(\n    CSRMatrix, IdArray, int64_t, bool);\ntemplate COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(\n    CSRMatrix, IdArray, int64_t, bool);\n\ntemplate <DGLDeviceType XPU, typename IdxType, bool map_seed_nodes>\nstd::pair<CSRMatrix, IdArray> CSRRowWiseSamplingUniformFused(\n    CSRMatrix mat, IdArray rows, IdArray seed_mapping,\n    std::vector<IdxType>* new_seed_nodes, int64_t num_samples, bool replace) {\n  // If num_samples is -1, select all neighbors without replacement.\n  replace = (replace && num_samples != -1);\n  auto num_picks_fn =\n      
GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);\n  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);\n  return CSRRowWisePickFused<IdxType, map_seed_nodes>(\n      mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn,\n      num_picks_fn);\n}\n\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingUniformFused<kDGLCPU, int32_t, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingUniformFused<kDGLCPU, int64_t, true>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingUniformFused<kDGLCPU, int32_t, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, bool);\ntemplate std::pair<CSRMatrix, IdArray>\nCSRRowWiseSamplingUniformFused<kDGLCPU, int64_t, false>(\n    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, bool);\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nCOOMatrix CSRRowWisePerEtypeSamplingUniform(\n    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_samples, bool replace,\n    bool rowwise_etype_sorted) {\n  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);\n  return CSRRowWisePerEtypePick<IdxType, float>(\n      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,\n      pick_fn, {});\n}\n\ntemplate COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(\n    CSRMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, bool, bool);\ntemplate COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(\n    CSRMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, bool, bool);\n\ntemplate <DGLDeviceType XPU, typename IdxType, typename FloatType>\nCOOMatrix CSRRowWiseSamplingBiased(\n    CSRMatrix mat, IdArray rows, int64_t num_samples, 
NDArray tag_offset,\n    FloatArray bias, bool replace) {\n  // If num_samples is -1, select all neighbors without replacement.\n  replace = (replace && num_samples != -1);\n  auto num_picks_fn = GetSamplingBiasedNumPicksFn<IdxType, FloatType>(\n      num_samples, tag_offset, bias, replace);\n  auto pick_fn = GetSamplingBiasedPickFn<IdxType, FloatType>(\n      num_samples, tag_offset, bias, replace);\n  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);\n}\n\ntemplate COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(\n    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);\n\ntemplate COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(\n    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);\n\ntemplate COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(\n    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);\n\ntemplate COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(\n    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);\n\n/////////////////////////////// COO ///////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdxType, typename DType>\nCOOMatrix COORowWiseSampling(\n    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,\n    bool replace) {\n  // If num_samples is -1, select all neighbors without replacement.\n  replace = (replace && num_samples != -1);\n  CHECK(prob_or_mask.defined());\n  auto num_picks_fn =\n      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);\n  auto pick_fn =\n      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);\n  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);\n}\n\ntemplate COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix 
COORowWiseSampling<kDGLCPU, int32_t, double>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseSampling<kDGLCPU, int32_t, int8_t>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseSampling<kDGLCPU, int64_t, int8_t>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseSampling<kDGLCPU, int32_t, uint8_t>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseSampling<kDGLCPU, int64_t, uint8_t>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\n\ntemplate <DGLDeviceType XPU, typename IdxType, typename DType>\nCOOMatrix COORowWisePerEtypeSampling(\n    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_samples,\n    const std::vector<NDArray>& prob_or_mask, bool replace) {\n  CHECK(prob_or_mask.size() == num_samples.size())\n      << \"the number of probability tensors do not match the number of edge \"\n         \"types.\";\n  for (auto& p : prob_or_mask) CHECK(p.defined());\n  auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(\n      num_samples, prob_or_mask, replace);\n  return COORowWisePerEtypePick<IdxType, DType>(\n      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, prob_or_mask);\n}\n\ntemplate COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(\n    COOMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);\ntemplate COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(\n    COOMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);\ntemplate COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(\n    COOMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const 
std::vector<NDArray>&, bool);\ntemplate COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(\n    COOMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);\ntemplate COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(\n    COOMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);\ntemplate COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(\n    COOMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);\ntemplate COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(\n    COOMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);\ntemplate COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(\n    COOMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nCOOMatrix COORowWiseSamplingUniform(\n    COOMatrix mat, IdArray rows, int64_t num_samples, bool replace) {\n  // If num_samples is -1, select all neighbors without replacement.\n  replace = (replace && num_samples != -1);\n  auto num_picks_fn =\n      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);\n  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);\n  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);\n}\n\ntemplate COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(\n    COOMatrix, IdArray, int64_t, bool);\ntemplate COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(\n    COOMatrix, IdArray, int64_t, bool);\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nCOOMatrix COORowWisePerEtypeSamplingUniform(\n    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& num_samples, bool replace) 
{\n  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);\n  return COORowWisePerEtypePick<IdxType, float>(\n      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, {});\n}\n\ntemplate COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(\n    COOMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, bool);\ntemplate COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(\n    COOMatrix, IdArray, const std::vector<int64_t>&,\n    const std::vector<int64_t>&, bool);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/rowwise_topk.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/rowwise_topk.cc\n * @brief rowwise topk\n */\n#include <algorithm>\n#include <numeric>\n\n#include \"./rowwise_pick.h\"\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\nnamespace {\n\ntemplate <typename IdxType>\ninline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) {\n  NumPicksFn<IdxType> num_picks_fn = [k](IdxType rowid, IdxType off,\n                                         IdxType len, const IdxType* col,\n                                         const IdxType* data) {\n    const int64_t max_num_picks = (k == -1) ? len : k;\n    return std::min(static_cast<IdxType>(max_num_picks), len);\n  };\n  return num_picks_fn;\n}\n\ntemplate <typename IdxType, typename DType>\ninline PickFn<IdxType> GetTopkPickFn(NDArray weight, bool ascending) {\n  const DType* wdata = static_cast<DType*>(weight->data);\n  PickFn<IdxType> pick_fn = [ascending, wdata](\n                                IdxType rowid, IdxType off, IdxType len,\n                                IdxType num_picks, const IdxType* col,\n                                const IdxType* data, IdxType* out_idx) {\n    std::function<bool(IdxType, IdxType)> compare_fn;\n    if (ascending) {\n      if (data) {\n        compare_fn = [wdata, data](IdxType i, IdxType j) {\n          return wdata[data[i]] < wdata[data[j]];\n        };\n      } else {\n        compare_fn = [wdata](IdxType i, IdxType j) {\n          return wdata[i] < wdata[j];\n        };\n      }\n    } else {\n      if (data) {\n        compare_fn = [wdata, data](IdxType i, IdxType j) {\n          return wdata[data[i]] > wdata[data[j]];\n        };\n      } else {\n        compare_fn = [wdata](IdxType i, IdxType j) {\n          return wdata[i] > wdata[j];\n        };\n      }\n    }\n\n    std::vector<IdxType> idx(len);\n    std::iota(idx.begin(), idx.end(), off);\n    std::sort(idx.begin(), idx.end(), compare_fn);\n    for (int64_t j = 0; j < num_picks; ++j) {\n      
out_idx[j] = idx[j];\n    }\n  };\n\n  return pick_fn;\n}\n\n}  // namespace\n\ntemplate <DGLDeviceType XPU, typename IdxType, typename DType>\nCOOMatrix CSRRowWiseTopk(\n    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {\n  auto num_picks_fn = GetTopkNumPicksFn<IdxType>(k);\n  auto pick_fn = GetTopkPickFn<IdxType, DType>(weight, ascending);\n  return CSRRowWisePick(mat, rows, k, false, pick_fn, num_picks_fn);\n}\n\ntemplate COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int32_t>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int32_t>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int64_t>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int64_t>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, float>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, float>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, double>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, double>(\n    CSRMatrix, IdArray, int64_t, NDArray, bool);\n\ntemplate <DGLDeviceType XPU, typename IdxType, typename DType>\nCOOMatrix COORowWiseTopk(\n    COOMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {\n  auto num_picks_fn = GetTopkNumPicksFn<IdxType>(k);\n  auto pick_fn = GetTopkPickFn<IdxType, DType>(weight, ascending);\n  return COORowWisePick(mat, rows, k, false, pick_fn, num_picks_fn);\n}\n\ntemplate COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int32_t>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int32_t>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseTopk<kDGLCPU, int32_t, 
int64_t>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int64_t>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseTopk<kDGLCPU, int32_t, float>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseTopk<kDGLCPU, int64_t, float>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseTopk<kDGLCPU, int32_t, double>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\ntemplate COOMatrix COORowWiseTopk<kDGLCPU, int64_t, double>(\n    COOMatrix, IdArray, int64_t, NDArray, bool);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/sddmm.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file aten/cpu/sddmm.cc\n * @brief SDDMM C APIs and definitions.\n */\n#include \"./sddmm.h\"\n\n#include <dgl/array.h>\n\nnamespace dgl {\nnamespace aten {\n\n#define SWITCH_RHS(rhs_target, RhsTarget, ...)             \\\n  do {                                                     \\\n    if ((rhs_target) == 0) {                               \\\n      constexpr int RhsTarget = 0;                         \\\n      { __VA_ARGS__ }                                      \\\n    } else if ((rhs_target) == 1) {                        \\\n      constexpr int RhsTarget = 1;                         \\\n      { __VA_ARGS__ }                                      \\\n    } else if ((rhs_target) == 2) {                        \\\n      constexpr int RhsTarget = 2;                         \\\n      { __VA_ARGS__ }                                      \\\n    } else {                                               \\\n      LOG(INFO) << \"Invalid rhs target: \" << (rhs_target); \\\n    }                                                      \\\n  } while (0)\n\n#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) 
\\\n  do {                                                                   \\\n    if ((lhs_target) == 0) {                                             \\\n      constexpr int LhsTarget = 0;                                       \\\n      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \\\n    } else if ((lhs_target) == 1) {                                      \\\n      constexpr int LhsTarget = 1;                                       \\\n      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \\\n    } else if ((lhs_target) == 2) {                                      \\\n      constexpr int LhsTarget = 2;                                       \\\n      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \\\n    } else {                                                             \\\n      LOG(INFO) << \"Invalid lhs target: \" << (lhs_target);               \\\n    }                                                                    \\\n  } while (0)\n\n/** @brief Generalized SDDMM on Csr format. */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCsr(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {\n  SWITCH_OP(op, Op, {\n    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {\n      cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(\n          bcast, csr, lhs, rhs, out);\n    });\n  });\n}\n\n/** @brief Generalized SDDMM on Csr format with Heterograph support. 
*/\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCsrHetero(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,\n    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,\n    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,\n    const std::vector<dgl_type_t>& rhs_nid) {\n  SWITCH_OP(op, Op, {\n    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {\n      /* Call  SDDMM for each relation type */\n      for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {\n        CSRMatrix csr = vec_csr[etype];\n        NDArray lhs = vec_lhs[lhs_nid[etype]];\n        NDArray rhs = vec_rhs[rhs_nid[etype]];\n        NDArray out = vec_out[etype];\n        cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(\n            bcast, csr, lhs, rhs, out);\n      }\n    });\n  });\n}\n\ntemplate void SDDMMCsr<kDGLCPU, int32_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCsr<kDGLCPU, int64_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCsr<kDGLCPU, int32_t, float>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCsr<kDGLCPU, int64_t, float>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCsr<kDGLCPU, int32_t, double>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCsr<kDGLCPU, int64_t, double>(\n    const 
std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\n\ntemplate void SDDMMCsrHetero<kDGLCPU, int32_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCsrHetero<kDGLCPU, int64_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCsrHetero<kDGLCPU, int32_t, float>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCsrHetero<kDGLCPU, int64_t, float>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCsrHetero<kDGLCPU, int32_t, double>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& 
out_eid);\ntemplate void SDDMMCsrHetero<kDGLCPU, int64_t, double>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\n\n/** @brief Generalized SDDMM on Coo format. */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCoo(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {\n  SWITCH_OP(op, Op, {\n    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {\n      cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(\n          bcast, coo, lhs, rhs, out);\n    });\n  });\n}\n\n/** @brief Generalized SDDMM on Coo format with Heterograph support. */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCooHetero(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,\n    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,\n    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,\n    const std::vector<dgl_type_t>& rhs_nid) {\n  SWITCH_OP(op, Op, {\n    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {\n      /* Call  SDDMM for each relation type */\n      for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {\n        COOMatrix coo = vec_coo[etype];\n        NDArray lhs = vec_lhs[lhs_nid[etype]];\n        NDArray rhs = vec_rhs[rhs_nid[etype]];\n        NDArray out = vec_out[etype];\n        cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(\n            bcast, coo, lhs, rhs, out);\n      }\n    });\n  });\n}\n\ntemplate void SDDMMCoo<kDGLCPU, int32_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray 
lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCoo<kDGLCPU, int64_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCoo<kDGLCPU, int32_t, float>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCoo<kDGLCPU, int64_t, float>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCoo<kDGLCPU, int32_t, double>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCoo<kDGLCPU, int64_t, double>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\n\ntemplate void SDDMMCooHetero<kDGLCPU, int32_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCooHetero<kDGLCPU, int64_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCooHetero<kDGLCPU, int32_t, float>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& 
lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCooHetero<kDGLCPU, int64_t, float>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCooHetero<kDGLCPU, int32_t, double>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCooHetero<kDGLCPU, int64_t, double>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/sddmm.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/sddmm.h\n * @brief SDDMM CPU kernel function header.\n */\n#ifndef DGL_ARRAY_CPU_SDDMM_H_\n#define DGL_ARRAY_CPU_SDDMM_H_\n\n#include <dgl/array.h>\n#include <dgl/bcast.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include \"../selector.h\"\n\nnamespace dgl {\nnamespace aten {\nnamespace cpu {\n\n/**\n * @brief CPU kernel of g-SDDMM on Csr format.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param lhs The left hand side operand feature.\n * @param rhs The right hand size operand feature.\n * @param out The result feature on edges.\n * @note it uses node parallel strategy, different threads are responsible\n *       for the computation of different nodes.\n */\ntemplate <\n    typename IdType, typename DType, typename Op, int LhsTarget = 0,\n    int RhsTarget = 2>\nvoid SDDMMCsr(\n    const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs,\n    NDArray out) {\n  const bool has_idx = !IsNullArray(csr.data);\n  const IdType* indptr = csr.indptr.Ptr<IdType>();\n  const IdType* indices = csr.indices.Ptr<IdType>();\n  const IdType* edges = csr.data.Ptr<IdType>();\n  const DType* X = lhs.Ptr<DType>();\n  const DType* Y = rhs.Ptr<DType>();\n  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,\n                rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;\n  DType* O = out.Ptr<DType>();\n  runtime::parallel_for(0, csr.num_rows, [=](IdType b, IdType e) {\n    for (auto rid = b; rid < e; ++rid) {\n      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];\n      for (IdType j = row_start; j < row_end; ++j) {\n        const IdType cid = indices[j];\n        const IdType eid = has_idx ? edges[j] : j;\n        DType* out_off = O + eid * dim;\n        for (int64_t k = 0; k < dim; ++k) {\n          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;\n          const int64_t rhs_add = bcast.use_bcast ? 
bcast.rhs_offset[k] : k;\n          const DType* lhs_off =\n              Op::use_lhs\n                  ? X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim +\n                        lhs_add * reduce_size\n                  : nullptr;\n          const DType* rhs_off =\n              Op::use_rhs\n                  ? Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim +\n                        rhs_add * reduce_size\n                  : nullptr;\n          out_off[k] = Op::Call(lhs_off, rhs_off, reduce_size);\n        }\n      }\n    }\n  });\n}\n\n/**\n * @brief CPU kernel of g-SDDMM on Coo format.\n * @param bcast Broadcast information.\n * @param coo The COO matrix.\n * @param lhs The left hand side operand feature.\n * @param rhs The right hand side operand feature.\n * @param out The result feature on edges.\n * @note it uses edge parallel strategy, different threads are responsible\n *       for the computation of different edges.\n */\ntemplate <\n    typename IdType, typename DType, typename Op, int LhsTarget = 0,\n    int RhsTarget = 2>\nvoid SDDMMCoo(\n    const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,\n    NDArray out) {\n  const bool has_idx = !IsNullArray(coo.data);\n  const IdType* row = coo.row.Ptr<IdType>();\n  const IdType* col = coo.col.Ptr<IdType>();\n  const IdType* edges = coo.data.Ptr<IdType>();\n  const DType* X = lhs.Ptr<DType>();\n  const DType* Y = rhs.Ptr<DType>();\n  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,\n                rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;\n  DType* O = out.Ptr<DType>();\n#pragma omp parallel for\n  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {\n    const IdType rid = row[i];\n    const IdType cid = col[i];\n    const IdType eid = has_idx ? edges[i] : i;\n    DType* out_off = O + eid * dim;\n    for (int64_t k = 0; k < dim; ++k) {\n      const int64_t lhs_add = bcast.use_bcast ? 
bcast.lhs_offset[k] : k;\n      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;\n      const DType* lhs_off =\n          Op::use_lhs ? X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim +\n                            lhs_add * reduce_size\n                      : nullptr;\n      const DType* rhs_off =\n          Op::use_rhs ? Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim +\n                            rhs_add * reduce_size\n                      : nullptr;\n      out_off[k] = Op::Call(lhs_off, rhs_off, bcast.reduce_size);\n    }\n  }\n}\n\nnamespace op {\n\n////////////////////////// binary operators on CPU /////////////////////////////\ntemplate <typename DType>\nstruct Add {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  inline static DType Call(\n      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {\n    return *lhs_off + *rhs_off;\n  }\n};\n\ntemplate <typename DType>\nstruct Sub {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  inline static DType Call(\n      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {\n    return *lhs_off - *rhs_off;\n  }\n};\n\ntemplate <typename DType>\nstruct Mul {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  inline static DType Call(\n      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {\n    return *lhs_off * *rhs_off;\n  }\n};\n\ntemplate <typename DType>\nstruct Div {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  inline static DType Call(\n      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {\n    return *lhs_off / *rhs_off;\n  }\n};\n\ntemplate <typename DType>\nstruct CopyLhs {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = false;\n  inline static DType Call(\n      const DType* lhs_off, const DType*, int64_t len = 1) {\n    return *lhs_off;\n  
}\n};\n\ntemplate <typename DType>\nstruct CopyRhs {\n  static constexpr bool use_lhs = false;\n  static constexpr bool use_rhs = true;\n  inline static DType Call(\n      const DType*, const DType* rhs_off, int64_t len = 1) {\n    return *rhs_off;\n  }\n};\n\ntemplate <typename DType>\nstruct Dot {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  inline static DType Call(\n      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {\n    DType rst = 0;\n    for (int64_t l = 0; l < len; ++l) {\n      rst += lhs_off[l] * rhs_off[l];\n    }\n    return rst;\n  }\n};\n\n#define SWITCH_OP(op, Op, ...)                                   \\\n  do {                                                           \\\n    if ((op) == \"add\") {                                         \\\n      typedef dgl::aten::cpu::op::Add<DType> Op;                 \\\n      { __VA_ARGS__ }                                            \\\n    } else if ((op) == \"sub\") {                                  \\\n      typedef dgl::aten::cpu::op::Sub<DType> Op;                 \\\n      { __VA_ARGS__ }                                            \\\n    } else if ((op) == \"mul\") {                                  \\\n      typedef dgl::aten::cpu::op::Mul<DType> Op;                 \\\n      { __VA_ARGS__ }                                            \\\n    } else if ((op) == \"div\") {                                  \\\n      typedef dgl::aten::cpu::op::Div<DType> Op;                 \\\n      { __VA_ARGS__ }                                            \\\n    } else if ((op) == \"copy_lhs\") {                             \\\n      typedef dgl::aten::cpu::op::CopyLhs<DType> Op;             \\\n      { __VA_ARGS__ }                                            \\\n    } else if ((op) == \"copy_rhs\") {                             \\\n      typedef dgl::aten::cpu::op::CopyRhs<DType> Op;             \\\n      { __VA_ARGS__ }                                      
      \\\n    } else if ((op) == \"dot\") {                                  \\\n      typedef dgl::aten::cpu::op::Dot<DType> Op;                 \\\n      { __VA_ARGS__ }                                            \\\n    } else {                                                     \\\n      LOG(FATAL) << \"Unsupported SDDMM binary operator: \" << op; \\\n    }                                                            \\\n  } while (0)\n\n}  // namespace op\n\n}  // namespace cpu\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CPU_SDDMM_H_\n"
  },
  {
    "path": "src/array/cpu/segment_reduce.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file kernel/cpu/segment_reduce.cc\n * @brief Segment reduce C APIs and definitions.\n */\n#include \"./segment_reduce.h\"\n\n#include <dgl/array.h>\n\n#include <string>\n\n#include \"./spmm_binary_ops.h\"\n\nnamespace dgl {\nnamespace aten {\n\n/** @brief Segment Reduce operator. */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SegmentReduce(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg) {\n  if (op == \"sum\") {\n    cpu::SegmentSum<IdType, DType>(feat, offsets, out);\n  } else if (op == \"max\" || op == \"min\") {\n    if (op == \"max\") {\n      cpu::SegmentCmp<IdType, DType, cpu::op::Max<DType>>(\n          feat, offsets, out, arg);\n    } else {\n      cpu::SegmentCmp<IdType, DType, cpu::op::Min<DType>>(\n          feat, offsets, out, arg);\n    }\n  } else {\n    LOG(FATAL) << \"Unsupported reduce function \" << op;\n  }\n}\n\n/** @brief Scatter Add.*/\ntemplate <int XPU, typename IdType, typename DType>\nvoid ScatterAdd(NDArray feat, NDArray idx, NDArray out) {\n  cpu::ScatterAdd<IdType, DType>(feat, idx, out);\n}\n\n/** @brief Update gradients for reduce operator max/min on heterogeneous\n * graph.*/\ntemplate <int XPU, typename IdType, typename DType>\nvoid UpdateGradMinMax_hetero(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {\n  cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);\n}\n\n/** @brief Backward function of segment cmp.*/\ntemplate <int XPU, typename IdType, typename DType>\nvoid BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {\n  cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);\n}\n\ntemplate void SegmentReduce<kDGLCPU, int32_t, BFloat16>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray 
arg);\ntemplate void SegmentReduce<kDGLCPU, int64_t, BFloat16>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\ntemplate void SegmentReduce<kDGLCPU, int32_t, float>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\ntemplate void SegmentReduce<kDGLCPU, int64_t, float>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\ntemplate void SegmentReduce<kDGLCPU, int32_t, double>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\ntemplate void SegmentReduce<kDGLCPU, int64_t, double>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\n\ntemplate <>\nvoid ScatterAdd<kDGLCPU, int32_t, BFloat16>(\n    NDArray feat, NDArray idx, NDArray out) {\n  LOG(FATAL) << \"Unsupported CPU kernel for ScatterAdd for BF16.\";\n}\ntemplate <>\nvoid ScatterAdd<kDGLCPU, int64_t, BFloat16>(\n    NDArray feat, NDArray idx, NDArray out) {\n  LOG(FATAL) << \"Unsupported CPU kernel for ScatterAdd for BF16.\";\n}\ntemplate void ScatterAdd<kDGLCPU, int32_t, float>(\n    NDArray feat, NDArray idx, NDArray out);\ntemplate void ScatterAdd<kDGLCPU, int64_t, float>(\n    NDArray feat, NDArray idx, NDArray out);\ntemplate void ScatterAdd<kDGLCPU, int32_t, double>(\n    NDArray feat, NDArray idx, NDArray out);\ntemplate void ScatterAdd<kDGLCPU, int64_t, double>(\n    NDArray feat, NDArray arg, NDArray out);\n\ntemplate <>\nvoid UpdateGradMinMax_hetero<kDGLCPU, int32_t, BFloat16>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {\n  LOG(FATAL) << \"Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.\";\n}\ntemplate <>\nvoid UpdateGradMinMax_hetero<kDGLCPU, int64_t, BFloat16>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const 
std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {\n  LOG(FATAL) << \"Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.\";\n}\ntemplate void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\ntemplate void UpdateGradMinMax_hetero<kDGLCPU, int64_t, float>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\ntemplate void UpdateGradMinMax_hetero<kDGLCPU, int32_t, double>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\ntemplate void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\n\ntemplate void BackwardSegmentCmp<kDGLCPU, int32_t, BFloat16>(\n    NDArray feat, NDArray arg, NDArray out);\ntemplate void BackwardSegmentCmp<kDGLCPU, int64_t, BFloat16>(\n    NDArray feat, NDArray arg, NDArray out);\ntemplate void BackwardSegmentCmp<kDGLCPU, int32_t, float>(\n    NDArray feat, NDArray arg, NDArray out);\ntemplate void BackwardSegmentCmp<kDGLCPU, int64_t, float>(\n    NDArray feat, NDArray arg, NDArray out);\ntemplate void BackwardSegmentCmp<kDGLCPU, int32_t, double>(\n    NDArray feat, NDArray arg, NDArray out);\ntemplate void BackwardSegmentCmp<kDGLCPU, int64_t, double>(\n    NDArray feat, NDArray arg, NDArray out);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/segment_reduce.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/spmm.h\n * @brief Segment reduce kernel function header.\n */\n#ifndef DGL_ARRAY_CPU_SEGMENT_REDUCE_H_\n#define DGL_ARRAY_CPU_SEGMENT_REDUCE_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <string>\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\nnamespace cpu {\n\n/**\n * @brief CPU kernel of segment sum.\n * @param feat The input tensor.\n * @param offsets The offset tensor storing the ranges of segments.\n * @param out The output tensor.\n */\ntemplate <typename IdType, typename DType>\nvoid SegmentSum(NDArray feat, NDArray offsets, NDArray out) {\n  if (std::is_same<DType, BFloat16>::value)\n    LOG(FATAL) << \"Unsupported CPU kernel for SegmentSum for BF16.\";\n  int n = out->shape[0];\n  int dim = 1;\n  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];\n  const DType* feat_data = feat.Ptr<DType>();\n  const IdType* offsets_data = offsets.Ptr<IdType>();\n  DType* out_data = out.Ptr<DType>();\n  runtime::parallel_for(0, n, [=](int b, int e) {\n    for (auto i = b; i < e; ++i) {\n      for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {\n        for (int k = 0; k < dim; ++k) {\n          out_data[i * dim + k] += feat_data[j * dim + k];\n        }\n      }\n    }\n  });\n}\n\n/**\n * @brief CPU kernel of segment min/max.\n * @param feat The input tensor.\n * @param offsets The offset tensor storing the ranges of segments.\n * @param out The output tensor.\n * @param arg An auxiliary tensor storing the argmin/max information\n *        used in backward phase.\n */\ntemplate <typename IdType, typename DType, typename Cmp>\nvoid SegmentCmp(NDArray feat, NDArray offsets, NDArray out, NDArray arg) {\n  int n = out->shape[0];\n  int dim = 1;\n  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];\n  const DType* feat_data = feat.Ptr<DType>();\n  const IdType* offsets_data = 
offsets.Ptr<IdType>();\n  DType* out_data = out.Ptr<DType>();\n  IdType* arg_data = arg.Ptr<IdType>();\n  std::fill(out_data, out_data + out.NumElements(), Cmp::zero);\n  std::fill(arg_data, arg_data + arg.NumElements(), -1);\n  runtime::parallel_for(0, n, [=](int b, int e) {\n    for (auto i = b; i < e; ++i) {\n      for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {\n        for (int k = 0; k < dim; ++k) {\n          const DType val = feat_data[j * dim + k];\n          if (Cmp::Call(out_data[i * dim + k], val)) {\n            out_data[i * dim + k] = val;\n            arg_data[i * dim + k] = j;\n          }\n        }\n      }\n    }\n  });\n}\n\n/**\n * @brief CPU kernel of Scatter Add (on first dimension) operator.\n * @note math equation: out[idx[i], *] += feat[i, *]\n * @param feat The input tensor.\n * @param idx The indices tensor.\n * @param out The output tensor.\n */\ntemplate <typename IdType, typename DType>\nvoid ScatterAdd(NDArray feat, NDArray idx, NDArray out) {\n  int n = feat->shape[0];\n  int dim = 1;\n  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];\n  const DType* feat_data = feat.Ptr<DType>();\n  const IdType* idx_data = idx.Ptr<IdType>();\n  DType* out_data = out.Ptr<DType>();\n#pragma omp parallel for\n  for (int i = 0; i < n; ++i) {\n    const int write_row = idx_data[i];\n    for (int k = 0; k < dim; ++k) {\n#pragma omp atomic\n      out_data[write_row * dim + k] += feat_data[i * dim + k];\n    }\n  }\n}\n\n/**\n * @brief CPU kernel to update gradients for reduce op max/min\n * @param graph The input heterogeneous graph.\n * @param op The binary operator, could be `copy_lhs` or `copy_rhs`.\n * @param list_feat List of the input tensors.\n * @param list_idx  List of the indices tensors.\n * @param list_idx_types List of the node- or edge-type tensors.\n * @param list_out List of the output tensors.\n */\ntemplate <typename IdType, typename DType>\nvoid UpdateGradMinMax_hetero(\n    HeteroGraphPtr graph, const std::string& 
op,\n    const std::vector<NDArray>& list_feat, const std::vector<NDArray>& list_idx,\n    const std::vector<NDArray>& list_idx_types,\n    std::vector<NDArray>* list_out) {\n  if (op == \"copy_lhs\" || op == \"copy_rhs\") {\n    std::vector<std::vector<dgl_id_t>> src_dst_ntypes(\n        graph->NumVertexTypes(), std::vector<dgl_id_t>());\n\n    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {\n      auto pair = graph->meta_graph()->FindEdge(etype);\n      const dgl_id_t dst_ntype = pair.first;  // graph is reversed\n      const dgl_id_t src_ntype = pair.second;\n      auto same_src_dst_ntype = std::find(\n          std::begin(src_dst_ntypes[dst_ntype]),\n          std::end(src_dst_ntypes[dst_ntype]), src_ntype);\n      // if op is \"copy_lhs\", relation type with same src and dst node type will\n      // be updated once\n      if (op == \"copy_lhs\" &&\n          same_src_dst_ntype != std::end(src_dst_ntypes[dst_ntype]))\n        continue;\n      src_dst_ntypes[dst_ntype].push_back(src_ntype);\n      const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();\n      const IdType* idx_data = list_idx[dst_ntype].Ptr<IdType>();\n      const IdType* idx_type_data = list_idx_types[dst_ntype].Ptr<IdType>();\n      int type = (op == \"copy_lhs\") ? 
src_ntype : etype;\n      DType* out_data = (*list_out)[type].Ptr<DType>();\n      int dim = 1;\n      for (int i = 1; i < (*list_out)[type]->ndim; ++i)\n        dim *= (*list_out)[type]->shape[i];\n      int n = list_feat[dst_ntype]->shape[0];\n#pragma omp parallel for\n      for (int i = 0; i < n; ++i) {\n        for (int k = 0; k < dim; ++k) {\n          if (type == idx_type_data[i * dim + k]) {\n            const int write_row = idx_data[i * dim + k];\n#pragma omp atomic\n            out_data[write_row * dim + k] +=\n                feat_data[i * dim + k];  // feat = dZ\n          }\n        }\n      }\n    }\n  } else {\n    LOG(FATAL) << \"Unsupported binary operator: \" << op;\n  }\n}\n\n/**\n * @brief CPU kernel of backward phase of segment min/max.\n * @note math equation: out[arg[i, k], k] = feat[i, k]\n * @param feat The input tensor.\n * @param arg The argmin/argmax tensor.\n * @param out The output tensor.\n */\ntemplate <typename IdType, typename DType>\nvoid BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {\n  int n = feat->shape[0];\n  int dim = 1;\n  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];\n  const DType* feat_data = feat.Ptr<DType>();\n  const IdType* arg_data = arg.Ptr<IdType>();\n  DType* out_data = out.Ptr<DType>();\n  runtime::parallel_for(0, n, [=](int b, int e) {\n    for (auto i = b; i < e; ++i) {\n      for (int k = 0; k < dim; ++k) {\n        int write_row = arg_data[i * dim + k];\n        if (write_row >= 0)\n          out_data[write_row * dim + k] = feat_data[i * dim + k];\n      }\n    }\n  });\n}\n\n}  // namespace cpu\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CPU_SEGMENT_REDUCE_H_\n"
  },
  {
    "path": "src/array/cpu/spmat_op_impl_coo.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cpu/spmat_op_impl.cc\n * @brief CPU implementation of COO sparse matrix operators\n */\n#include <dgl/runtime/parallel_for.h>\n#include <dmlc/omp.h>\n\n#include <numeric>\n#include <tuple>\n#include <unordered_map>\n#include <unordered_set>\n#include <vector>\n\n#include \"array_utils.h\"\n\nnamespace dgl {\n\nusing runtime::NDArray;\nusing runtime::parallel_for;\n\nnamespace aten {\nnamespace impl {\n\n/**\n * TODO(BarclayII):\n * For row-major sorted COOs, we have faster implementation with binary search,\n * sorted search, etc.  Later we should benchmark how much we can gain with\n * sorted COOs on hypersparse graphs.\n */\n\n///////////////////////////// COOIsNonZero /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {\n  CHECK(row >= 0 && row < coo.num_rows) << \"Invalid row index: \" << row;\n  CHECK(col >= 0 && col < coo.num_cols) << \"Invalid col index: \" << col;\n  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);\n  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {\n    if (coo_row_data[i] == row && coo_col_data[i] == col) return true;\n  }\n  return false;\n}\n\ntemplate bool COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);\ntemplate bool COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {\n  const auto rowlen = row->shape[0];\n  const auto collen = col->shape[0];\n  const auto rstlen = std::max(rowlen, collen);\n  NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);\n  IdType *rst_data = static_cast<IdType *>(rst->data);\n  const IdType *row_data = static_cast<IdType *>(row->data);\n  const IdType *col_data = static_cast<IdType *>(col->data);\n  const int64_t 
row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;\n  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;\n  const int64_t kmax = std::max(rowlen, collen);\n  parallel_for(0, kmax, [=](size_t b, size_t e) {\n    for (auto k = b; k < e; ++k) {\n      int64_t i = row_stride * k;\n      int64_t j = col_stride * k;\n      rst_data[k] =\n          COOIsNonZero<XPU, IdType>(coo, row_data[i], col_data[j]) ? 1 : 0;\n    }\n  });\n  return rst;\n}\n\ntemplate NDArray COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, NDArray, NDArray);\ntemplate NDArray COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, NDArray, NDArray);\n\n///////////////////////////// COOHasDuplicate /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool COOHasDuplicate(COOMatrix coo) {\n  std::unordered_set<std::pair<IdType, IdType>, PairHash> hashmap;\n  const IdType *src_data = static_cast<IdType *>(coo.row->data);\n  const IdType *dst_data = static_cast<IdType *>(coo.col->data);\n  const auto nnz = coo.row->shape[0];\n  for (IdType eid = 0; eid < nnz; ++eid) {\n    const auto &p = std::make_pair(src_data[eid], dst_data[eid]);\n    if (hashmap.count(p)) {\n      return true;\n    } else {\n      hashmap.insert(p);\n    }\n  }\n  return false;\n}\n\ntemplate bool COOHasDuplicate<kDGLCPU, int32_t>(COOMatrix coo);\ntemplate bool COOHasDuplicate<kDGLCPU, int64_t>(COOMatrix coo);\n\n///////////////////////////// COOGetRowNNZ /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nint64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {\n  CHECK(row >= 0 && row < coo.num_rows) << \"Invalid row index: \" << row;\n  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);\n  int64_t result = 0;\n  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {\n    if (coo_row_data[i] == row) ++result;\n  }\n  return result;\n}\n\ntemplate int64_t COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, int64_t);\ntemplate int64_t COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, 
int64_t);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {\n  CHECK_SAME_DTYPE(coo.col, rows);\n  const auto len = rows->shape[0];\n  const IdType *vid_data = static_cast<IdType *>(rows->data);\n  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);\n  IdType *rst_data = static_cast<IdType *>(rst->data);\n#pragma omp parallel for\n  for (int64_t i = 0; i < len; ++i) {\n    rst_data[i] = COOGetRowNNZ<XPU, IdType>(coo, vid_data[i]);\n  }\n  return rst;\n}\n\ntemplate NDArray COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, NDArray);\ntemplate NDArray COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, NDArray);\n\n////////////////////////// COOGetRowDataAndIndices /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<NDArray, NDArray> COOGetRowDataAndIndices(\n    COOMatrix coo, int64_t row) {\n  CHECK(row >= 0 && row < coo.num_rows) << \"Invalid row index: \" << row;\n\n  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);\n  const IdType *coo_data =\n      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;\n\n  std::vector<IdType> indices;\n  std::vector<IdType> data;\n\n  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {\n    if (coo_row_data[i] == row) {\n      indices.push_back(coo_col_data[i]);\n      data.push_back(coo_data ? 
coo_data[i] : i);\n    }\n  }\n\n  return std::make_pair(\n      NDArray::FromVector(data), NDArray::FromVector(indices));\n}\n\ntemplate std::pair<NDArray, NDArray> COOGetRowDataAndIndices<kDGLCPU, int32_t>(\n    COOMatrix, int64_t);\ntemplate std::pair<NDArray, NDArray> COOGetRowDataAndIndices<kDGLCPU, int64_t>(\n    COOMatrix, int64_t);\n\n///////////////////////////// COOGetData /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {\n  const int64_t rowlen = rows->shape[0];\n  const int64_t collen = cols->shape[0];\n  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))\n      << \"Invalid row and col Id array:\" << rows << \" \" << cols;\n  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;\n  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;\n  const IdType *row_data = rows.Ptr<IdType>();\n  const IdType *col_data = cols.Ptr<IdType>();\n\n  const IdType *coo_row = coo.row.Ptr<IdType>();\n  const IdType *coo_col = coo.col.Ptr<IdType>();\n  const IdType *data = COOHasData(coo) ? coo.data.Ptr<IdType>() : nullptr;\n  const int64_t nnz = coo.row->shape[0];\n\n  const int64_t retlen = std::max(rowlen, collen);\n  IdArray ret = Full(-1, retlen, rows->dtype.bits, rows->ctx);\n  IdType *ret_data = ret.Ptr<IdType>();\n\n  // TODO(minjie): We might need to consider sorting the COO beforehand\n  // especially when the number of (row, col) pairs is large. 
Need more\n  // benchmarks to justify the choice.\n\n  if (coo.row_sorted) {\n    parallel_for(0, retlen, [&](size_t b, size_t e) {\n      for (auto p = b; p < e; ++p) {\n        const IdType row_id = row_data[p * row_stride],\n                     col_id = col_data[p * col_stride];\n        auto it = std::lower_bound(coo_row, coo_row + nnz, row_id);\n        for (; it < coo_row + nnz && *it == row_id; ++it) {\n          const auto idx = it - coo_row;\n          if (coo_col[idx] == col_id) {\n            ret_data[p] = data ? data[idx] : idx;\n            break;\n          }\n        }\n      }\n    });\n  } else {\n#pragma omp parallel for\n    for (int64_t p = 0; p < retlen; ++p) {\n      const IdType row_id = row_data[p * row_stride],\n                   col_id = col_data[p * col_stride];\n      for (int64_t idx = 0; idx < nnz; ++idx) {\n        if (coo_row[idx] == row_id && coo_col[idx] == col_id) {\n          ret_data[p] = data ? data[idx] : idx;\n          break;\n        }\n      }\n    }\n  }\n\n  return ret;\n}\n\ntemplate IdArray COOGetData<kDGLCPU, int32_t>(COOMatrix, IdArray, IdArray);\ntemplate IdArray COOGetData<kDGLCPU, int64_t>(COOMatrix, IdArray, IdArray);\n\n///////////////////////////// COOGetDataAndIndices /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::vector<NDArray> COOGetDataAndIndices(\n    COOMatrix coo, NDArray rows, NDArray cols) {\n  CHECK_SAME_DTYPE(coo.col, rows);\n  CHECK_SAME_DTYPE(coo.col, cols);\n  const int64_t rowlen = rows->shape[0];\n  const int64_t collen = cols->shape[0];\n  const int64_t len = std::max(rowlen, collen);\n\n  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))\n      << \"Invalid row and col id array.\";\n\n  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;\n  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 
0 : 1;\n  const IdType *row_data = static_cast<IdType *>(rows->data);\n  const IdType *col_data = static_cast<IdType *>(cols->data);\n\n  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);\n  const IdType *data =\n      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;\n\n  std::vector<IdType> ret_rows, ret_cols;\n  std::vector<IdType> ret_data;\n  ret_rows.reserve(len);\n  ret_cols.reserve(len);\n  ret_data.reserve(len);\n\n  // NOTE(BarclayII): With a small number of lookups, linear scan is faster.\n  // The threshold 200 comes from benchmarking both algorithms on a P3.8x\n  // instance. I also tried sorting plus binary search.  The speed gain is only\n  // significant for medium-sized graphs and lookups, so I didn't include it.\n  if (len >= 200) {\n    // TODO(BarclayII) Ideally we would want to cache this object.  However I'm\n    // not sure what is the best way to do so since this object is valid for CPU\n    // only.\n    std::unordered_multimap<std::pair<IdType, IdType>, IdType, PairHash>\n        pair_map;\n    pair_map.reserve(coo.row->shape[0]);\n    for (int64_t k = 0; k < coo.row->shape[0]; ++k)\n      pair_map.emplace(\n          std::make_pair(coo_row_data[k], coo_col_data[k]), data ? 
data[k] : k);\n\n    for (int64_t i = 0, j = 0; i < rowlen && j < collen;\n         i += row_stride, j += col_stride) {\n      const IdType row_id = row_data[i], col_id = col_data[j];\n      CHECK(row_id >= 0 && row_id < coo.num_rows)\n          << \"Invalid row index: \" << row_id;\n      CHECK(col_id >= 0 && col_id < coo.num_cols)\n          << \"Invalid col index: \" << col_id;\n      auto range = pair_map.equal_range({row_id, col_id});\n      for (auto it = range.first; it != range.second; ++it) {\n        ret_rows.push_back(row_id);\n        ret_cols.push_back(col_id);\n        ret_data.push_back(it->second);\n      }\n    }\n  } else {\n    for (int64_t i = 0, j = 0; i < rowlen && j < collen;\n         i += row_stride, j += col_stride) {\n      const IdType row_id = row_data[i], col_id = col_data[j];\n      CHECK(row_id >= 0 && row_id < coo.num_rows)\n          << \"Invalid row index: \" << row_id;\n      CHECK(col_id >= 0 && col_id < coo.num_cols)\n          << \"Invalid col index: \" << col_id;\n      for (int64_t k = 0; k < coo.row->shape[0]; ++k) {\n        if (coo_row_data[k] == row_id && coo_col_data[k] == col_id) {\n          ret_rows.push_back(row_id);\n          ret_cols.push_back(col_id);\n          ret_data.push_back(data ? 
data[k] : k);\n        }\n      }\n    }\n  }\n\n  return {\n      NDArray::FromVector(ret_rows), NDArray::FromVector(ret_cols),\n      NDArray::FromVector(ret_data)};\n}\n\ntemplate std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int32_t>(\n    COOMatrix coo, NDArray rows, NDArray cols);\ntemplate std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int64_t>(\n    COOMatrix coo, NDArray rows, NDArray cols);\n\n///////////////////////////// COOTranspose /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOTranspose(COOMatrix coo) {\n  return COOMatrix{coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data};\n}\n\ntemplate COOMatrix COOTranspose<kDGLCPU, int32_t>(COOMatrix coo);\ntemplate COOMatrix COOTranspose<kDGLCPU, int64_t>(COOMatrix coo);\n\n///////////////////////////// COOToCSR /////////////////////////////\nnamespace {\n\ntemplate <class IdType>\nCSRMatrix SortedCOOToCSR(const COOMatrix &coo) {\n  const int64_t N = coo.num_rows;\n  const int64_t NNZ = coo.row->shape[0];\n  const IdType *const row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *const data =\n      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;\n\n  NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx);\n  NDArray ret_indices = coo.col;\n  NDArray ret_data = data == nullptr\n                         ? NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx)\n                         : coo.data;\n\n  // compute indptr\n  IdType *const Bp = static_cast<IdType *>(ret_indptr->data);\n  Bp[0] = 0;\n\n  IdType *const fill_data =\n      data ? 
nullptr : static_cast<IdType *>(ret_data->data);\n\n  if (NNZ > 0) {\n    auto num_threads = omp_get_max_threads();\n    parallel_for(0, num_threads, [&](int b, int e) {\n      for (auto thread_id = b; thread_id < e; ++thread_id) {\n        // We partition the set the of non-zeros among the threads\n        const int64_t nz_chunk = (NNZ + num_threads - 1) / num_threads;\n        const int64_t nz_start = thread_id * nz_chunk;\n        const int64_t nz_end = std::min(NNZ, nz_start + nz_chunk);\n\n        // Each thread searchs the row array for a change, and marks it's\n        // location in Bp. Threads, other than the first, start at the last\n        // index covered by the previous, in order to detect changes in the row\n        // array between thread partitions. This means that each thread after\n        // the first, searches the range [nz_start-1, nz_end). That is,\n        // if we had 10 non-zeros, and 4 threads, the indexes searched by each\n        // thread would be:\n        // 0: [0, 1, 2]\n        // 1: [2, 3, 4, 5]\n        // 2: [5, 6, 7, 8]\n        // 3: [8, 9]\n        //\n        // That way, if the row array were [0, 0, 1, 2, 2, 2, 4, 5, 5, 6], each\n        // change in row would be captured by one thread:\n        //\n        // 0: [0, 0, 1] - row 0\n        // 1: [1, 2, 2, 2] - row 1\n        // 2: [2, 4, 5, 5] - rows 2, 3, and 4\n        // 3: [5, 6] - rows 5 and 6\n        //\n        int64_t row = 0;\n        if (nz_start < nz_end) {\n          row = nz_start == 0 ? 
0 : row_data[nz_start - 1];\n          for (int64_t i = nz_start; i < nz_end; ++i) {\n            while (row != row_data[i]) {\n              ++row;\n              Bp[row] = i;\n            }\n          }\n\n          // We will not detect the row change for the last row, nor any empty\n          // rows at the end of the matrix, so the last active thread needs\n          // mark all remaining rows in Bp with NNZ.\n          if (nz_end == NNZ) {\n            while (row < N) {\n              ++row;\n              Bp[row] = NNZ;\n            }\n          }\n\n          if (fill_data) {\n            // TODO(minjie): Many of our current implementation assumes that CSR\n            // must have\n            //   a data array. This is a temporary workaround. Remove this\n            //   after:\n            //   - The old immutable graph implementation is deprecated.\n            //   - The old binary reduce kernel is deprecated.\n            std::iota(fill_data + nz_start, fill_data + nz_end, nz_start);\n          }\n        }\n      }\n    });\n  } else {\n    std::fill(Bp, Bp + N + 1, 0);\n  }\n\n  return CSRMatrix(\n      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,\n      coo.col_sorted);\n}\n\ntemplate <class IdType>\nCSRMatrix UnSortedSparseCOOToCSR(const COOMatrix &coo) {\n  // Unsigned version of the original integer index data type.\n  // It avoids overflow in (N + num_threads) and (n_start + n_chunk) below.\n  typedef typename std::make_unsigned<IdType>::type UIdType;\n\n  const UIdType N = coo.num_rows;\n  const int64_t NNZ = coo.row->shape[0];\n  const IdType *const row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *const col_data = static_cast<IdType *>(coo.col->data);\n  const IdType *const data =\n      COOHasData(coo) ? 
static_cast<IdType *>(coo.data->data) : nullptr;\n\n  NDArray ret_indptr = NDArray::Empty(\n      {static_cast<int64_t>(N) + 1}, coo.row->dtype, coo.row->ctx);\n  NDArray ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);\n  NDArray ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);\n  IdType *const Bp = static_cast<IdType *>(ret_indptr->data);\n  Bp[N] = 0;\n  IdType *const Bi = static_cast<IdType *>(ret_indices->data);\n  IdType *const Bx = static_cast<IdType *>(ret_data->data);\n\n  // store sorted data and original index.\n  NDArray sorted_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);\n  NDArray sorted_data_pos = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);\n  IdType *const Sx = static_cast<IdType *>(sorted_data->data);\n  IdType *const Si = static_cast<IdType *>(sorted_data_pos->data);\n\n  // Lower number of threads if cost of parallelization is grater than gain\n  // from making calculation parallel.\n  const int64_t min_chunk_size = 1000;\n  const int64_t num_threads_for_batch = 2 + (NNZ + N) / min_chunk_size;\n  const int num_threads_required = std::min(\n      static_cast<int64_t>(omp_get_max_threads()), num_threads_for_batch);\n\n  // record row_idx in each thread.\n  std::vector<std::vector<int64_t>> p_sum(\n      num_threads_required, std::vector<int64_t>(num_threads_required));\n\n#pragma omp parallel num_threads(num_threads_required)\n  {\n    const int num_threads = omp_get_num_threads();\n    const int thread_id = omp_get_thread_num();\n    CHECK_LT(thread_id, num_threads);\n\n    const int64_t nz_chunk = (NNZ + num_threads - 1) / num_threads;\n    const int64_t nz_start = thread_id * nz_chunk;\n    const int64_t nz_end = std::min(NNZ, nz_start + nz_chunk);\n\n    const UIdType n_chunk = (N + num_threads - 1) / num_threads;\n    const UIdType n_start = thread_id * n_chunk;\n    const UIdType n_end = std::min(N, n_start + n_chunk);\n\n    for (auto i = n_start; i < n_end; ++i) {\n      Bp[i] = 0;\n  
  }\n\n    // iterate on NNZ data and count row_idx.\n    for (auto i = nz_start; i < nz_end; ++i) {\n      const IdType row_idx = row_data[i];\n      const IdType row_thread_id = row_idx / n_chunk;\n      ++p_sum[thread_id][row_thread_id];\n    }\n\n#pragma omp barrier\n#pragma omp master\n    // accumulate row_idx.\n    {\n      int64_t cum = 0;\n      for (int j = 0; j < num_threads; ++j) {\n        for (int i = 0; i < num_threads; ++i) {\n          auto tmp = p_sum[i][j];\n          p_sum[i][j] = cum;\n          cum += tmp;\n        }\n      }\n      CHECK_EQ(cum, NNZ);\n    }\n#pragma omp barrier\n    const int64_t i_start = p_sum[0][thread_id];\n    const int64_t i_end =\n        thread_id + 1 == num_threads ? NNZ : p_sum[0][thread_id + 1];\n#pragma omp barrier\n\n    // sort data by row_idx and place into Sx/Si.\n    auto &data_pos = p_sum[thread_id];\n    for (auto i = nz_start; i < nz_end; ++i) {\n      const IdType row_idx = row_data[i];\n      const IdType row_thread_id = row_idx / n_chunk;\n      const int64_t pos = data_pos[row_thread_id]++;\n      Sx[pos] = data == nullptr ? 
i : data[i];\n      Si[pos] = i;\n    }\n\n#pragma omp barrier\n\n    // Now we're able to do coo2csr on sorted data in each thread in parallel.\n    // compute data number on each row_idx.\n    for (auto i = i_start; i < i_end; ++i) {\n      const UIdType row_idx = row_data[Si[i]];\n      ++Bp[row_idx + 1];\n    }\n\n    // accumulate on each row\n    IdType cumsum = i_start;\n    for (auto i = n_start + 1; i <= n_end; ++i) {\n      const auto tmp = Bp[i];\n      Bp[i] = cumsum;\n      cumsum += tmp;\n    }\n\n    // update Bi/Bp/Bx\n    for (auto i = i_start; i < i_end; ++i) {\n      const UIdType row_idx = row_data[Si[i]];\n      const int64_t dest = (Bp[row_idx + 1]++);\n      Bi[dest] = col_data[Si[i]];\n      Bx[dest] = Sx[i];\n    }\n  }\n  return CSRMatrix(\n      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,\n      coo.col_sorted);\n}\n\ntemplate <class IdType>\nCSRMatrix UnSortedDenseCOOToCSR(const COOMatrix &coo) {\n  // Unsigned version of the original integer index data type.\n  // It avoids overflow in (N + num_threads) and (n_start + n_chunk) below.\n  typedef typename std::make_unsigned<IdType>::type UIdType;\n\n  const UIdType N = coo.num_rows;\n  const int64_t NNZ = coo.row->shape[0];\n  const IdType *const row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *const col_data = static_cast<IdType *>(coo.col->data);\n  const IdType *const data =\n      COOHasData(coo) ? 
static_cast<IdType *>(coo.data->data) : nullptr;\n\n  NDArray ret_indptr = NDArray::Empty(\n      {static_cast<int64_t>(N) + 1}, coo.row->dtype, coo.row->ctx);\n  NDArray ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);\n  NDArray ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);\n  IdType *const Bp = static_cast<IdType *>(ret_indptr->data);\n  Bp[0] = 0;\n  IdType *const Bi = static_cast<IdType *>(ret_indices->data);\n  IdType *const Bx = static_cast<IdType *>(ret_data->data);\n\n  // the offset within each row, that each thread will write to\n  std::vector<std::vector<IdType>> local_ptrs;\n  std::vector<int64_t> thread_prefixsum;\n\n#pragma omp parallel\n  {\n    const int num_threads = omp_get_num_threads();\n    const int thread_id = omp_get_thread_num();\n    CHECK_LT(thread_id, num_threads);\n\n    const int64_t nz_chunk = (NNZ + num_threads - 1) / num_threads;\n    const int64_t nz_start = thread_id * nz_chunk;\n    const int64_t nz_end = std::min(NNZ, nz_start + nz_chunk);\n\n    const UIdType n_chunk = (N + num_threads - 1) / num_threads;\n    const UIdType n_start = thread_id * n_chunk;\n    const UIdType n_end = std::min(N, n_start + n_chunk);\n\n#pragma omp master\n    {\n      local_ptrs.resize(num_threads);\n      thread_prefixsum.resize(num_threads + 1);\n    }\n\n#pragma omp barrier\n    local_ptrs[thread_id].resize(N, 0);\n\n    for (int64_t i = nz_start; i < nz_end; ++i) {\n      ++local_ptrs[thread_id][row_data[i]];\n    }\n\n#pragma omp barrier\n    // compute prefixsum in parallel\n    int64_t sum = 0;\n    for (UIdType i = n_start; i < n_end; ++i) {\n      IdType tmp = 0;\n      for (int j = 0; j < num_threads; ++j) {\n        auto previous = local_ptrs[j][i];\n        local_ptrs[j][i] = tmp;\n        tmp += previous;\n      }\n      sum += tmp;\n      Bp[i + 1] = sum;\n    }\n    thread_prefixsum[thread_id + 1] = sum;\n\n#pragma omp barrier\n#pragma omp master\n    {\n      for (int i = 0; i < num_threads; ++i) 
{\n        thread_prefixsum[i + 1] += thread_prefixsum[i];\n      }\n      CHECK_EQ(thread_prefixsum[num_threads], NNZ);\n    }\n#pragma omp barrier\n\n    sum = thread_prefixsum[thread_id];\n    for (UIdType i = n_start; i < n_end; ++i) {\n      Bp[i + 1] += sum;\n    }\n\n#pragma omp barrier\n    for (int64_t i = nz_start; i < nz_end; ++i) {\n      const IdType r = row_data[i];\n      const int64_t index = Bp[r] + local_ptrs[thread_id][r]++;\n      Bi[index] = col_data[i];\n      Bx[index] = data ? data[i] : i;\n    }\n  }\n  CHECK_EQ(Bp[N], NNZ);\n\n  return CSRMatrix(\n      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,\n      coo.col_sorted);\n}\n\n// complexity: time O(NNZ), space O(1)\ntemplate <typename IdType>\nCSRMatrix UnSortedSmallCOOToCSR(COOMatrix coo) {\n  const int64_t N = coo.num_rows;\n  const int64_t NNZ = coo.row->shape[0];\n  const IdType *row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *col_data = static_cast<IdType *>(coo.col->data);\n  const IdType *data =\n      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;\n  NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx);\n  NDArray ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);\n  NDArray ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);\n  IdType *Bp = static_cast<IdType *>(ret_indptr->data);\n  IdType *Bi = static_cast<IdType *>(ret_indices->data);\n  IdType *Bx = static_cast<IdType *>(ret_data->data);\n\n  // Count elements in each row\n  std::fill(Bp, Bp + N, 0);\n  for (int64_t i = 0; i < NNZ; ++i) {\n    Bp[row_data[i]]++;\n  }\n\n  // Convert to indexes\n  for (IdType i = 0, cumsum = 0; i < N; ++i) {\n    const IdType temp = Bp[i];\n    Bp[i] = cumsum;\n    cumsum += temp;\n  }\n\n  for (int64_t i = 0; i < NNZ; ++i) {\n    const IdType r = row_data[i];\n    Bi[Bp[r]] = col_data[i];\n    Bx[Bp[r]] = data ? 
data[i] : i;\n    Bp[r]++;\n  }\n\n  // Restore the indptr\n  for (int64_t i = N; i > 0; --i) {\n    Bp[i] = Bp[i - 1];\n  }\n  Bp[0] = 0;\n\n  return CSRMatrix(\n      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,\n      coo.col_sorted);\n}\n\nenum class COOToCSRAlg {\n  sorted = 0,\n  unsortedSmall,\n  unsortedSparse,\n  unsortedDense\n};\n\n/**\n * Chose COO to CSR format conversion algorithm for given COO matrix according\n * to heuristic based on measured performance.\n *\n * Implementation and complexity details. N: num_nodes, NNZ: num_edges, P:\n * num_threads.\n *   1. If row is sorted in COO, SortedCOOToCSR<> is applied. Time: O(NNZ/P),\n * space: O(1).\n *   2 If row is NOT sorted in COO and graph is small (small number of NNZ),\n * UnSortedSmallCOOToCSR<> is applied. Time: O(NNZ), space O(N).\n *   3 If row is NOT sorted in COO and graph is sparse (low average degree),\n * UnSortedSparseCOOToCSR<> is applied. Time: O(NNZ/P + N/P + P^2),\n * space O(NNZ + P^2).\n *   4. If row is NOT sorted in COO and graph is dense (medium/high average\n * degree), UnSortedDenseCOOToCSR<> is applied. 
Time: O(NNZ/P + N/P),\n * space O(NNZ + N*P).\n *\n * Note:\n *   If you change this function, change also _TestCOOToCSRAlgs in\n * tests/cpp/test_spmat_coo.cc\n */\ntemplate <typename IdType>\ninline COOToCSRAlg WhichCOOToCSR(const COOMatrix &coo) {\n  if (coo.row_sorted) {\n    return COOToCSRAlg::sorted;\n  } else {\n#ifdef _WIN32\n    // On Windows omp_get_max_threads() gives larger value than later OMP can\n    // spawn.\n    int64_t num_threads;\n#pragma omp parallel\n#pragma master\n    { num_threads = omp_get_num_threads(); }\n#else\n    const int64_t num_threads = omp_get_max_threads();\n#endif\n    const int64_t N = coo.num_rows;\n    const int64_t NNZ = coo.row->shape[0];\n    // Parameters below are heuristically chosen according to measured\n    // performance.\n    const int64_t type_scale = sizeof(IdType) >> 1;\n    const int64_t small = 50 * num_threads * type_scale * type_scale;\n    if (NNZ < small || num_threads == 1) {\n      // For relatively small number of non zero elements cost of spread\n      // algorithm between threads is bigger than improvements from using\n      // many cores\n      return COOToCSRAlg::unsortedSmall;\n    } else if (type_scale * NNZ < num_threads * N) {\n      // For relatively small number of non zero elements in matrix, sparse\n      // parallel version of algorithm is more efficient than dense.\n      return COOToCSRAlg::unsortedSparse;\n    }\n    return COOToCSRAlg::unsortedDense;\n  }\n}\n\n}  // namespace\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix COOToCSR(COOMatrix coo) {\n  CHECK_NO_OVERFLOW(coo.row->dtype, coo.row->shape[0]);\n  switch (WhichCOOToCSR<IdType>(coo)) {\n    case COOToCSRAlg::sorted:\n      return SortedCOOToCSR<IdType>(coo);\n    case COOToCSRAlg::unsortedSmall:\n    default:\n      return UnSortedSmallCOOToCSR<IdType>(coo);\n    case COOToCSRAlg::unsortedSparse:\n      return UnSortedSparseCOOToCSR<IdType>(coo);\n    case COOToCSRAlg::unsortedDense:\n      return 
UnSortedDenseCOOToCSR<IdType>(coo);\n  }\n}\n\ntemplate CSRMatrix COOToCSR<kDGLCPU, int32_t>(COOMatrix coo);\ntemplate CSRMatrix COOToCSR<kDGLCPU, int64_t>(COOMatrix coo);\n\n///////////////////////////// COOSliceRows /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {\n  // TODO(minjie): use binary search when coo.row_sorted is true\n  CHECK(start >= 0 && start < coo.num_rows) << \"Invalid start row \" << start;\n  CHECK(end > 0 && end <= coo.num_rows) << \"Invalid end row \" << end;\n\n  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);\n  const IdType *coo_data =\n      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;\n\n  std::vector<IdType> ret_row, ret_col;\n  std::vector<IdType> ret_data;\n\n  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {\n    const IdType row_id = coo_row_data[i];\n    const IdType col_id = coo_col_data[i];\n    if (row_id < end && row_id >= start) {\n      ret_row.push_back(row_id - start);\n      ret_col.push_back(col_id);\n      ret_data.push_back(coo_data ? coo_data[i] : i);\n    }\n  }\n  return COOMatrix(\n      end - start, coo.num_cols, NDArray::FromVector(ret_row),\n      NDArray::FromVector(ret_col), NDArray::FromVector(ret_data),\n      coo.row_sorted, coo.col_sorted);\n}\n\ntemplate COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);\ntemplate COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {\n  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);\n  const IdType *coo_data =\n      COOHasData(coo) ? 
static_cast<IdType *>(coo.data->data) : nullptr;\n\n  std::vector<IdType> ret_row, ret_col;\n  std::vector<IdType> ret_data;\n\n  IdHashMap<IdType> hashmap(rows);\n\n  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {\n    const IdType row_id = coo_row_data[i];\n    const IdType col_id = coo_col_data[i];\n    const IdType mapped_row_id = hashmap.Map(row_id, -1);\n    if (mapped_row_id != -1) {\n      ret_row.push_back(mapped_row_id);\n      ret_col.push_back(col_id);\n      ret_data.push_back(coo_data ? coo_data[i] : i);\n    }\n  }\n\n  return COOMatrix{\n      rows->shape[0],\n      coo.num_cols,\n      NDArray::FromVector(ret_row),\n      NDArray::FromVector(ret_col),\n      NDArray::FromVector(ret_data),\n      coo.row_sorted,\n      coo.col_sorted};\n}\n\ntemplate COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, NDArray);\ntemplate COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, NDArray);\n\n///////////////////////////// COOSliceMatrix /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOSliceMatrix(\n    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols) {\n  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);\n  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);\n  const IdType *coo_data =\n      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;\n\n  IdHashMap<IdType> row_map(rows), col_map(cols);\n\n  std::vector<IdType> ret_row, ret_col;\n  std::vector<IdType> ret_data;\n\n  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {\n    const IdType row_id = coo_row_data[i];\n    const IdType col_id = coo_col_data[i];\n    const IdType mapped_row_id = row_map.Map(row_id, -1);\n    if (mapped_row_id != -1) {\n      const IdType mapped_col_id = col_map.Map(col_id, -1);\n      if (mapped_col_id != -1) {\n        ret_row.push_back(mapped_row_id);\n        ret_col.push_back(mapped_col_id);\n        ret_data.push_back(coo_data ? 
coo_data[i] : i);\n      }\n    }\n  }\n\n  return COOMatrix(\n      rows->shape[0], cols->shape[0], NDArray::FromVector(ret_row),\n      NDArray::FromVector(ret_col), NDArray::FromVector(ret_data),\n      coo.row_sorted, coo.col_sorted);\n}\n\ntemplate COOMatrix COOSliceMatrix<kDGLCPU, int32_t>(\n    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);\ntemplate COOMatrix COOSliceMatrix<kDGLCPU, int64_t>(\n    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);\n\n///////////////////////////// COOReorder /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix COOReorder(\n    COOMatrix coo, runtime::NDArray new_row_id_arr,\n    runtime::NDArray new_col_id_arr) {\n  CHECK_SAME_DTYPE(coo.row, new_row_id_arr);\n  CHECK_SAME_DTYPE(coo.col, new_col_id_arr);\n\n  // Input COO\n  const IdType *in_rows = static_cast<IdType *>(coo.row->data);\n  const IdType *in_cols = static_cast<IdType *>(coo.col->data);\n  int64_t num_rows = coo.num_rows;\n  int64_t num_cols = coo.num_cols;\n  int64_t nnz = coo.row->shape[0];\n  CHECK_EQ(num_rows, new_row_id_arr->shape[0])\n      << \"The new row Id array needs to be the same as the number of rows of \"\n         \"COO\";\n  CHECK_EQ(num_cols, new_col_id_arr->shape[0])\n      << \"The new col Id array needs to be the same as the number of cols of \"\n         \"COO\";\n\n  // New row/col Ids.\n  const IdType *new_row_ids = static_cast<IdType *>(new_row_id_arr->data);\n  const IdType *new_col_ids = static_cast<IdType *>(new_col_id_arr->data);\n\n  // Output COO\n  NDArray out_row_arr = NDArray::Empty({nnz}, coo.row->dtype, coo.row->ctx);\n  NDArray out_col_arr = NDArray::Empty({nnz}, coo.col->dtype, coo.col->ctx);\n  NDArray out_data_arr = COOHasData(coo) ? 
coo.data : NullArray();\n  IdType *out_row = static_cast<IdType *>(out_row_arr->data);\n  IdType *out_col = static_cast<IdType *>(out_col_arr->data);\n\n  parallel_for(0, nnz, [=](size_t b, size_t e) {\n    for (auto i = b; i < e; ++i) {\n      out_row[i] = new_row_ids[in_rows[i]];\n      out_col[i] = new_col_ids[in_cols[i]];\n    }\n  });\n  return COOMatrix(num_rows, num_cols, out_row_arr, out_col_arr, out_data_arr);\n}\n\ntemplate COOMatrix COOReorder<kDGLCPU, int64_t>(\n    COOMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);\ntemplate COOMatrix COOReorder<kDGLCPU, int32_t>(\n    COOMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/spmat_op_impl_csr.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cpu/spmat_op_impl_csr.cc\n * @brief CSR matrix operator CPU implementation\n */\n#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <atomic>\n#include <numeric>\n#include <unordered_set>\n#include <vector>\n\n#include \"array_utils.h\"\n\nnamespace dgl {\n\nusing runtime::NDArray;\nusing runtime::parallel_for;\n\nnamespace aten {\nnamespace impl {\n\n///////////////////////////// CSRIsNonZero /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);\n  if (csr.sorted) {\n    const IdType* start = indices_data + indptr_data[row];\n    const IdType* end = indices_data + indptr_data[row + 1];\n    return std::binary_search(start, end, col);\n  } else {\n    for (IdType i = indptr_data[row]; i < indptr_data[row + 1]; ++i) {\n      if (indices_data[i] == col) {\n        return true;\n      }\n    }\n  }\n  return false;\n}\n\ntemplate bool CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);\ntemplate bool CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {\n  const auto rowlen = row->shape[0];\n  const auto collen = col->shape[0];\n  const auto rstlen = std::max(rowlen, collen);\n  NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);\n  IdType* rst_data = static_cast<IdType*>(rst->data);\n  const IdType* row_data = static_cast<IdType*>(row->data);\n  const IdType* col_data = static_cast<IdType*>(col->data);\n  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;\n  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 
0 : 1;\n  runtime::parallel_for(\n      0, std::max(rowlen, collen), 1, [=](int64_t b, int64_t e) {\n        int64_t i = (row_stride == 0) ? 0 : b;\n        int64_t j = (col_stride == 0) ? 0 : b;\n        for (int64_t k = b; i < e && j < e;\n             i += row_stride, j += col_stride, ++k)\n          rst_data[k] =\n              CSRIsNonZero<XPU, IdType>(csr, row_data[i], col_data[j]) ? 1 : 0;\n      });\n  return rst;\n}\n\ntemplate NDArray CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, NDArray, NDArray);\ntemplate NDArray CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, NDArray, NDArray);\n\n///////////////////////////// CSRHasDuplicate /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool CSRHasDuplicate(CSRMatrix csr) {\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);\n  for (IdType src = 0; src < csr.num_rows; ++src) {\n    std::unordered_set<IdType> hashmap;\n    for (IdType eid = indptr_data[src]; eid < indptr_data[src + 1]; ++eid) {\n      const IdType dst = indices_data[eid];\n      if (hashmap.count(dst)) {\n        return true;\n      } else {\n        hashmap.insert(dst);\n      }\n    }\n  }\n  return false;\n}\n\ntemplate bool CSRHasDuplicate<kDGLCPU, int32_t>(CSRMatrix csr);\ntemplate bool CSRHasDuplicate<kDGLCPU, int64_t>(CSRMatrix csr);\n\n///////////////////////////// CSRGetRowNNZ /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nint64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  return indptr_data[row + 1] - indptr_data[row];\n}\n\ntemplate int64_t CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, int64_t);\ntemplate int64_t CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, int64_t);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {\n  CHECK_SAME_DTYPE(csr.indices, rows);\n  const auto len = 
rows->shape[0];\n  const IdType* vid_data = static_cast<IdType*>(rows->data);\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);\n  IdType* rst_data = static_cast<IdType*>(rst->data);\n  for (int64_t i = 0; i < len; ++i) {\n    const auto vid = vid_data[i];\n    rst_data[i] = indptr_data[vid + 1] - indptr_data[vid];\n  }\n  return rst;\n}\n\ntemplate NDArray CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, NDArray);\ntemplate NDArray CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, NDArray);\n\n/////////////////////////// CSRGetRowColumnIndices /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {\n  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  const int64_t offset = indptr_data[row] * sizeof(IdType);\n  return csr.indices.CreateView({len}, csr.indices->dtype, offset);\n}\n\ntemplate NDArray CSRGetRowColumnIndices<kDGLCPU, int32_t>(CSRMatrix, int64_t);\ntemplate NDArray CSRGetRowColumnIndices<kDGLCPU, int64_t>(CSRMatrix, int64_t);\n\n///////////////////////////// CSRGetRowData /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray CSRGetRowData(CSRMatrix csr, int64_t row) {\n  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  const int64_t offset = indptr_data[row] * sizeof(IdType);\n  if (CSRHasData(csr))\n    return csr.data.CreateView({len}, csr.data->dtype, offset);\n  else\n    return aten::Range(\n        offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);\n}\n\ntemplate NDArray CSRGetRowData<kDGLCPU, int32_t>(CSRMatrix, int64_t);\ntemplate NDArray CSRGetRowData<kDGLCPU, int64_t>(CSRMatrix, int64_t);\n\n///////////////////////////// CSRGetData 
/////////////////////////////\n///////////////////////////// CSRGetDataAndIndices /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid CollectDataIndicesFromSorted(\n    const IdType* indices_data, const IdType* data, const IdType start,\n    const IdType end, const IdType col, std::vector<IdType>* col_vec,\n    std::vector<IdType>* ret_vec) {\n  const IdType* start_ptr = indices_data + start;\n  const IdType* end_ptr = indices_data + end;\n  auto it = std::lower_bound(start_ptr, end_ptr, col);\n  // This might be a multi-graph. We need to collect all of the matched\n  // columns.\n  for (; it != end_ptr; it++) {\n    // If the col exist\n    if (*it == col) {\n      IdType idx = it - indices_data;\n      col_vec->push_back(indices_data[idx]);\n      ret_vec->push_back(data[idx]);\n    } else {\n      // If we find a column that is different, we can stop searching now.\n      break;\n    }\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::vector<NDArray> CSRGetDataAndIndices(\n    CSRMatrix csr, NDArray rows, NDArray cols) {\n  // TODO(minjie): more efficient implementation for matrix without duplicate\n  // entries\n  const int64_t rowlen = rows->shape[0];\n  const int64_t collen = cols->shape[0];\n\n  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))\n      << \"Invalid row and col id array.\";\n\n  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;\n  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;\n  const IdType* row_data = static_cast<IdType*>(rows->data);\n  const IdType* col_data = static_cast<IdType*>(cols->data);\n\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);\n  const IdType* data =\n      CSRHasData(csr) ? 
static_cast<IdType*>(csr.data->data) : nullptr;\n\n  std::vector<IdType> ret_rows, ret_cols;\n  std::vector<IdType> ret_data;\n\n  for (int64_t i = 0, j = 0; i < rowlen && j < collen;\n       i += row_stride, j += col_stride) {\n    const IdType row_id = row_data[i], col_id = col_data[j];\n    CHECK(row_id >= 0 && row_id < csr.num_rows)\n        << \"Invalid row index: \" << row_id;\n    CHECK(col_id >= 0 && col_id < csr.num_cols)\n        << \"Invalid col index: \" << col_id;\n    if (csr.sorted) {\n      // Here we collect col indices and data.\n      CollectDataIndicesFromSorted<XPU, IdType>(\n          indices_data, data, indptr_data[row_id], indptr_data[row_id + 1],\n          col_id, &ret_cols, &ret_data);\n      // We need to add row Ids.\n      while (ret_rows.size() < ret_data.size()) {\n        ret_rows.push_back(row_id);\n      }\n    } else {\n      for (IdType i = indptr_data[row_id]; i < indptr_data[row_id + 1]; ++i) {\n        if (indices_data[i] == col_id) {\n          ret_rows.push_back(row_id);\n          ret_cols.push_back(col_id);\n          ret_data.push_back(data ? 
data[i] : i);\n        }\n      }\n    }\n  }\n\n  return {\n      NDArray::FromVector(ret_rows, csr.indptr->ctx),\n      NDArray::FromVector(ret_cols, csr.indptr->ctx),\n      NDArray::FromVector(ret_data, csr.data->ctx)};\n}\n\ntemplate std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int32_t>(\n    CSRMatrix csr, NDArray rows, NDArray cols);\ntemplate std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int64_t>(\n    CSRMatrix csr, NDArray rows, NDArray cols);\n\n///////////////////////////// CSRTranspose /////////////////////////////\n\n// for a matrix of shape (N, M) and NNZ\n// complexity: time O(NNZ + max(N, M)), space O(1)\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRTranspose(CSRMatrix csr) {\n  const int64_t N = csr.num_rows;\n  const int64_t M = csr.num_cols;\n  const int64_t nnz = csr.indices->shape[0];\n  const IdType* Ap = static_cast<IdType*>(csr.indptr->data);\n  const IdType* Aj = static_cast<IdType*>(csr.indices->data);\n  const IdType* Ax =\n      CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;\n  NDArray ret_indptr =\n      NDArray::Empty({M + 1}, csr.indptr->dtype, csr.indptr->ctx);\n  NDArray ret_indices =\n      NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);\n  NDArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);\n  IdType* Bp = static_cast<IdType*>(ret_indptr->data);\n  IdType* Bi = static_cast<IdType*>(ret_indices->data);\n  IdType* Bx = static_cast<IdType*>(ret_data->data);\n\n  std::fill(Bp, Bp + M, 0);\n\n  for (int64_t j = 0; j < nnz; ++j) {\n    Bp[Aj[j]]++;\n  }\n\n  // cumsum\n  for (int64_t i = 0, cumsum = 0; i < M; ++i) {\n    const IdType temp = Bp[i];\n    Bp[i] = cumsum;\n    cumsum += temp;\n  }\n  Bp[M] = nnz;\n\n  for (int64_t i = 0; i < N; ++i) {\n    for (IdType j = Ap[i]; j < Ap[i + 1]; ++j) {\n      const IdType dst = Aj[j];\n      Bi[Bp[dst]] = i;\n      Bx[Bp[dst]] = Ax ? 
Ax[j] : j;\n      Bp[dst]++;\n    }\n  }\n\n  // correct the indptr\n  for (int64_t i = 0, last = 0; i <= M; ++i) {\n    IdType temp = Bp[i];\n    Bp[i] = last;\n    last = temp;\n  }\n\n  return CSRMatrix{\n      csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data};\n}\n\ntemplate CSRMatrix CSRTranspose<kDGLCPU, int32_t>(CSRMatrix csr);\ntemplate CSRMatrix CSRTranspose<kDGLCPU, int64_t>(CSRMatrix csr);\n\n///////////////////////////// CSRToCOO /////////////////////////////\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix CSRToCOO(CSRMatrix csr) {\n  const int64_t nnz = csr.indices->shape[0];\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);\n  IdType* ret_row_data = static_cast<IdType*>(ret_row->data);\n  parallel_for(0, csr.indptr->shape[0] - 1, 10000, [=](int64_t b, int64_t e) {\n    for (auto i = b; i < e; ++i) {\n      std::fill(\n          ret_row_data + indptr_data[i], ret_row_data + indptr_data[i + 1], i);\n    }\n  });\n  return COOMatrix(\n      csr.num_rows, csr.num_cols, ret_row, csr.indices, csr.data, true,\n      csr.sorted);\n}\n\ntemplate COOMatrix CSRToCOO<kDGLCPU, int32_t>(CSRMatrix csr);\ntemplate COOMatrix CSRToCOO<kDGLCPU, int64_t>(CSRMatrix csr);\n\n// complexity: time O(NNZ), space O(1)\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {\n  const int64_t N = csr.num_rows;\n  const int64_t M = csr.num_cols;\n  const int64_t nnz = csr.indices->shape[0];\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);\n  // data array should have the same type as the indices arrays\n  const IdType* data =\n      CSRHasData(csr) ? 
static_cast<IdType*>(csr.data->data) : nullptr;\n  NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);\n  NDArray ret_col = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);\n  IdType* ret_row_data = static_cast<IdType*>(ret_row->data);\n  IdType* ret_col_data = static_cast<IdType*>(ret_col->data);\n  // scatter using the indices in the data array\n  parallel_for(0, N, 10000, [=](int64_t b, int64_t e) {\n    for (auto row = b; row < e; ++row) {\n      for (IdType j = indptr_data[row]; j < indptr_data[row + 1]; ++j) {\n        const IdType col = indices_data[j];\n        ret_row_data[data ? data[j] : j] = row;\n        ret_col_data[data ? data[j] : j] = col;\n      }\n    }\n  });\n  return COOMatrix(N, M, ret_row, ret_col);\n}\n\ntemplate COOMatrix CSRToCOODataAsOrder<kDGLCPU, int32_t>(CSRMatrix csr);\ntemplate COOMatrix CSRToCOODataAsOrder<kDGLCPU, int64_t>(CSRMatrix csr);\n\n///////////////////////////// CSRSliceRows /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {\n  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);\n  const int64_t num_rows = end - start;\n  const int64_t nnz = indptr[end] - indptr[start];\n  IdArray ret_indptr =\n      IdArray::Empty({num_rows + 1}, csr.indptr->dtype, csr.indices->ctx);\n  IdType* r_indptr = static_cast<IdType*>(ret_indptr->data);\n  for (int64_t i = start; i < end + 1; ++i) {\n    r_indptr[i - start] = indptr[i] - indptr[start];\n  }\n  // indices and data can be view arrays\n  IdArray ret_indices = csr.indices.CreateView(\n      {nnz}, csr.indices->dtype, indptr[start] * sizeof(IdType));\n  IdArray ret_data;\n  if (CSRHasData(csr))\n    ret_data = csr.data.CreateView(\n        {nnz}, csr.data->dtype, indptr[start] * sizeof(IdType));\n  else\n    ret_data = aten::Range(\n        indptr[start], indptr[end], csr.indptr->dtype.bits, csr.indptr->ctx);\n  return CSRMatrix(\n      
num_rows, csr.num_cols, ret_indptr, ret_indices, ret_data, csr.sorted);\n}\n\ntemplate CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);\ntemplate CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {\n  CHECK_SAME_DTYPE(csr.indices, rows);\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);\n  const IdType* data =\n      CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;\n  const auto len = rows->shape[0];\n  const IdType* rows_data = static_cast<IdType*>(rows->data);\n  int64_t nnz = 0;\n\n  CSRMatrix ret;\n  ret.num_rows = len;\n  ret.num_cols = csr.num_cols;\n  ret.indptr = NDArray::Empty({len + 1}, csr.indptr->dtype, csr.indices->ctx);\n\n  IdType* ret_indptr_data = static_cast<IdType*>(ret.indptr->data);\n  ret_indptr_data[0] = 0;\n\n  std::vector<IdType> sums;\n\n  std::atomic_flag err_flag = ATOMIC_FLAG_INIT;\n  bool err = false;\n  std::stringstream err_msg_stream;\n\n// Perform two-round parallel prefix sum using OpenMP\n#pragma omp parallel\n  {\n    int64_t tid = omp_get_thread_num();\n    int64_t num_threads = omp_get_num_threads();\n\n#pragma omp single\n    {\n      sums.resize(num_threads + 1);\n      sums[0] = 0;\n    }\n\n    int64_t sum = 0;\n\n// First round of parallel prefix sum. 
All threads perform local prefix sums.\n#pragma omp for schedule(static) nowait\n    for (int64_t i = 0; i < len; ++i) {\n      int64_t rid = rows_data[i];\n      if (rid >= csr.num_rows) {\n        if (!err_flag.test_and_set()) {\n          err_msg_stream << \"expect row ID \" << rid\n                         << \" to be less than number of rows \" << csr.num_rows;\n          err = true;\n        }\n      } else {\n        sum += indptr_data[rid + 1] - indptr_data[rid];\n        ret_indptr_data[i + 1] = sum;\n      }\n    }\n    sums[tid + 1] = sum;\n#pragma omp barrier\n\n#pragma omp single\n    {\n      for (int64_t i = 1; i < num_threads; ++i) sums[i] += sums[i - 1];\n    }\n\n    int64_t offset = sums[tid];\n\n// Second round of parallel prefix sum. Update the local prefix sums.\n#pragma omp for schedule(static)\n    for (int64_t i = 0; i < len; ++i) ret_indptr_data[i + 1] += offset;\n  }\n  if (err) {\n    LOG(FATAL) << err_msg_stream.str();\n    return ret;\n  }\n\n  // After the prefix sum, the last element of ret_indptr_data holds the\n  // sum of all elements\n  nnz = ret_indptr_data[len];\n\n  ret.indices = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);\n  ret.data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);\n  ret.sorted = csr.sorted;\n\n  IdType* ret_indices_data = static_cast<IdType*>(ret.indices->data);\n  IdType* ret_data = static_cast<IdType*>(ret.data->data);\n\n  parallel_for(0, len, [=](int64_t b, int64_t e) {\n    for (auto i = b; i < e; ++i) {\n      const IdType rid = rows_data[i];\n      // note: zero is allowed\n      std::copy(\n          indices_data + indptr_data[rid], indices_data + indptr_data[rid + 1],\n          ret_indices_data + ret_indptr_data[i]);\n      if (data)\n        std::copy(\n            data + indptr_data[rid], data + indptr_data[rid + 1],\n            ret_data + ret_indptr_data[i]);\n      else\n        std::iota(\n            ret_data + ret_indptr_data[i], ret_data + ret_indptr_data[i + 
1],\n            indptr_data[rid]);\n    }\n  });\n  return ret;\n}\n\ntemplate CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix, NDArray);\ntemplate CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix, NDArray);\n\n///////////////////////////// CSRSliceMatrix /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRSliceMatrix(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {\n  IdHashMap<IdType> hashmap(cols);\n  const int64_t new_nrows = rows->shape[0];\n  const int64_t new_ncols = cols->shape[0];\n  const IdType* rows_data = static_cast<IdType*>(rows->data);\n  const bool has_data = CSRHasData(csr);\n\n  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);\n  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);\n  const IdType* data =\n      has_data ? static_cast<IdType*>(csr.data->data) : nullptr;\n\n  std::vector<IdType> sub_indptr, sub_indices;\n  std::vector<IdType> sub_data;\n  sub_indptr.resize(new_nrows + 1, 0);\n  const IdType kInvalidId = new_ncols + 1;\n  for (int64_t i = 0; i < new_nrows; ++i) {\n    // NOTE: newi == i\n    const IdType oldi = rows_data[i];\n    CHECK(oldi >= 0 && oldi < csr.num_rows) << \"Invalid row index: \" << oldi;\n    for (IdType p = indptr_data[oldi]; p < indptr_data[oldi + 1]; ++p) {\n      const IdType oldj = indices_data[p];\n      const IdType newj = hashmap.Map(oldj, kInvalidId);\n      if (newj != kInvalidId) {\n        ++sub_indptr[i];\n        sub_indices.push_back(newj);\n        sub_data.push_back(has_data ? 
data[p] : p);\n      }\n    }\n  }\n\n  // cumsum sub_indptr\n  for (int64_t i = 0, cumsum = 0; i < new_nrows; ++i) {\n    const IdType temp = sub_indptr[i];\n    sub_indptr[i] = cumsum;\n    cumsum += temp;\n  }\n  sub_indptr[new_nrows] = sub_indices.size();\n\n  const int64_t nnz = sub_data.size();\n  NDArray sub_data_arr =\n      NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);\n  IdType* ptr = static_cast<IdType*>(sub_data_arr->data);\n  std::copy(sub_data.begin(), sub_data.end(), ptr);\n  return CSRMatrix{\n      new_nrows, new_ncols, NDArray::FromVector(sub_indptr, csr.indptr->ctx),\n      NDArray::FromVector(sub_indices, csr.indptr->ctx), sub_data_arr};\n}\n\ntemplate CSRMatrix CSRSliceMatrix<kDGLCPU, int32_t>(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);\ntemplate CSRMatrix CSRSliceMatrix<kDGLCPU, int64_t>(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);\n\n///////////////////////////// CSRReorder /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRReorder(\n    CSRMatrix csr, runtime::NDArray new_row_id_arr,\n    runtime::NDArray new_col_id_arr) {\n  CHECK_SAME_DTYPE(csr.indices, new_row_id_arr);\n  CHECK_SAME_DTYPE(csr.indices, new_col_id_arr);\n\n  // Input CSR\n  const IdType* in_indptr = static_cast<IdType*>(csr.indptr->data);\n  const IdType* in_indices = static_cast<IdType*>(csr.indices->data);\n  const IdType* in_data = static_cast<IdType*>(csr.data->data);\n  int64_t num_rows = csr.num_rows;\n  int64_t num_cols = csr.num_cols;\n  int64_t nnz = csr.indices->shape[0];\n  CHECK_EQ(nnz, in_indptr[num_rows]);\n  CHECK_EQ(num_rows, new_row_id_arr->shape[0])\n      << \"The new row Id array needs to be the same as the number of rows of \"\n         \"CSR\";\n  CHECK_EQ(num_cols, new_col_id_arr->shape[0])\n      << \"The new col Id array needs to be the same as the number of cols of \"\n         \"CSR\";\n\n  // New row/col Ids.\n  const IdType* new_row_ids = 
static_cast<IdType*>(new_row_id_arr->data);\n  const IdType* new_col_ids = static_cast<IdType*>(new_col_id_arr->data);\n\n  // Output CSR\n  NDArray out_indptr_arr =\n      NDArray::Empty({num_rows + 1}, csr.indptr->dtype, csr.indptr->ctx);\n  NDArray out_indices_arr =\n      NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);\n  NDArray out_data_arr = NDArray::Empty({nnz}, csr.data->dtype, csr.data->ctx);\n  IdType* out_indptr = static_cast<IdType*>(out_indptr_arr->data);\n  IdType* out_indices = static_cast<IdType*>(out_indices_arr->data);\n  IdType* out_data = static_cast<IdType*>(out_data_arr->data);\n\n  // Compute the length of rows for the new matrix.\n  std::vector<IdType> new_row_lens(num_rows, -1);\n  parallel_for(0, num_rows, [=, &new_row_lens](size_t b, size_t e) {\n    for (auto i = b; i < e; ++i) {\n      int64_t new_row_id = new_row_ids[i];\n      new_row_lens[new_row_id] = in_indptr[i + 1] - in_indptr[i];\n    }\n  });\n  // Compute the starting location of each row in the new matrix.\n  out_indptr[0] = 0;\n  // This is sequential. 
It should be pretty fast.\n  for (int64_t i = 0; i < num_rows; i++) {\n    CHECK_GE(new_row_lens[i], 0);\n    out_indptr[i + 1] = out_indptr[i] + new_row_lens[i];\n  }\n  CHECK_EQ(out_indptr[num_rows], nnz);\n  // Copy indieces and data with the new order.\n  // Here I iterate rows in the order of the old matrix.\n  parallel_for(0, num_rows, [=](size_t b, size_t e) {\n    for (auto i = b; i < e; ++i) {\n      const IdType* in_row = in_indices + in_indptr[i];\n      const IdType* in_row_data = in_data + in_indptr[i];\n\n      int64_t new_row_id = new_row_ids[i];\n      IdType* out_row = out_indices + out_indptr[new_row_id];\n      IdType* out_row_data = out_data + out_indptr[new_row_id];\n\n      int64_t row_len = new_row_lens[new_row_id];\n      // Here I iterate col indices in a row in the order of the old matrix.\n      for (int64_t j = 0; j < row_len; j++) {\n        out_row[j] = new_col_ids[in_row[j]];\n        out_row_data[j] = in_row_data[j];\n      }\n      // TODO(zhengda) maybe we should sort the column indices.\n    }\n  });\n  return CSRMatrix(\n      num_rows, num_cols, out_indptr_arr, out_indices_arr, out_data_arr);\n}\n\ntemplate CSRMatrix CSRReorder<kDGLCPU, int64_t>(\n    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);\ntemplate CSRMatrix CSRReorder<kDGLCPU, int32_t>(\n    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/spmm.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file kernel/cpu/spmm.cc\n * @brief SPMM C APIs and definitions.\n */\n#include \"./spmm.h\"\n\n#include <dgl/array.h>\n\nnamespace dgl {\nnamespace aten {\n\n/** @brief Generalized SpMM on Csr format. */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SpMMCsr(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux) {\n  const int64_t dim = bcast.out_len;\n  if (reduce == \"sum\") {\n    SWITCH_OP(op, Op, {\n      cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);\n    });\n  } else if (reduce == \"max\" || reduce == \"min\") {\n    SWITCH_OP(op, Op, {\n      DType* out_off = out.Ptr<DType>();\n      if (reduce == \"max\") {\n        std::fill(\n            out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);\n        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(\n            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);\n      } else {\n        std::fill(\n            out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);\n        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(\n            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);\n      }\n    });\n  } else {\n    LOG(FATAL) << \"Unsupported SpMM reducer: \" << reduce;\n  }\n}\n\n/** @brief Generalized SpMM on Csr format. 
*/\ntemplate <int XPU, typename IdType, typename DType>\nvoid SpMMCsrHetero(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr,\n    const std::vector<NDArray>& vec_ufeat,\n    const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_node_tids,\n    const std::vector<dgl_type_t>& out_node_tids) {\n  const int64_t dim = bcast.out_len;\n  if (reduce == \"sum\") {\n    SWITCH_OP(op, Op, {\n      /* Call  SpMM for each relation type */\n      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {\n        const dgl_type_t src_id = ufeat_node_tids[etype];\n        const dgl_type_t dst_id = out_node_tids[etype];\n        CSRMatrix csr = vec_csr[etype];\n        NDArray ufeat =\n            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];\n        NDArray efeat =\n            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];\n        NDArray out = (*vec_out)[dst_id];\n        cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);\n      }\n    });\n  } else if (reduce == \"max\" || reduce == \"min\") {\n    SWITCH_OP(op, Op, {\n      std::vector<bool> updated((*vec_out).size(), false);\n      // TODO(Israt): use vector updated to fill(out...) 
too\n      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {\n        DType* out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();\n        if (reduce == \"max\")\n          std::fill(\n              out_off, out_off + vec_csr[etype].num_rows * dim,\n              cpu::op::Max<DType>::zero);\n        else\n          std::fill(\n              out_off, out_off + vec_csr[etype].num_rows * dim,\n              cpu::op::Min<DType>::zero);\n        const dgl_type_t dst_id = out_node_tids[etype];\n        if (!updated[dst_id]) {\n          updated[dst_id] = true;\n          if (Op::use_lhs) {\n            IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();\n            std::fill(\n                argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);\n          }\n          if (Op::use_rhs) {\n            IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();\n            std::fill(\n                arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);\n          }\n        }\n      }\n      /* Call  SpMM for each relation type */\n      for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {\n        const dgl_type_t src_id = ufeat_node_tids[etype];\n        const dgl_type_t dst_id = out_node_tids[etype];\n        CSRMatrix csr = vec_csr[etype];\n        NDArray ufeat =\n            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];\n        NDArray efeat =\n            (vec_efeat.size() == 0) ? 
NullArray() : vec_efeat[etype];\n        NDArray out = (*vec_out)[dst_id];\n        if (reduce == \"max\") {\n          cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(\n              bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],\n              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],\n              (*out_aux)[3][dst_id], src_id, etype);\n        } else {\n          cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(\n              bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],\n              (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],\n              (*out_aux)[3][dst_id], src_id, etype);\n        }\n      }\n    });\n  } else {\n    LOG(FATAL) << \"Unsupported SpMM reducer: \" << reduce;\n  }\n}\n\ntemplate void SpMMCsr<kDGLCPU, int32_t, BFloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCsr<kDGLCPU, int64_t, BFloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCsr<kDGLCPU, int32_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCsr<kDGLCPU, int64_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCsr<kDGLCPU, int32_t, double>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCsr<kDGLCPU, int64_t, double>(\n    const 
std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\n\ntemplate void SpMMCsrHetero<kDGLCPU, int32_t, BFloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_node_tids,\n    const std::vector<dgl_type_t>& out_node_tids);\ntemplate void SpMMCsrHetero<kDGLCPU, int64_t, BFloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_node_tids,\n    const std::vector<dgl_type_t>& out_node_tids);\ntemplate void SpMMCsrHetero<kDGLCPU, int32_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_node_tids,\n    const std::vector<dgl_type_t>& out_node_tids);\ntemplate void SpMMCsrHetero<kDGLCPU, int64_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_node_tids,\n    const std::vector<dgl_type_t>& out_node_tids);\ntemplate void SpMMCsrHetero<kDGLCPU, int32_t, double>(\n    const std::string& op, const std::string& 
reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_node_tids,\n    const std::vector<dgl_type_t>& out_node_tids);\ntemplate void SpMMCsrHetero<kDGLCPU, int64_t, double>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_node_tids,\n    const std::vector<dgl_type_t>& out_node_tids);\n\n/** @brief Edge_softmax_csr forward op on Csr format. */\ntemplate <int XPU, typename IdType, typename DType>\nvoid Edge_softmax_csr_forward(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out) {\n  SWITCH_OP(op, Op, {\n    cpu::Edge_softmax_csr_forward<IdType, DType, Op>(\n        bcast, csr, ufeat, efeat, out);\n  });\n}\n\n/** @brief Edge_softmax_csr backward op on Csr format. 
*/\ntemplate <int XPU, typename IdType, typename DType>\nvoid Edge_softmax_csr_backward(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray out, NDArray sds, NDArray back_out) {\n  SWITCH_OP(op, Op, {\n    cpu::Edge_softmax_csr_backward<IdType, DType, Op>(\n        bcast, csr, out, sds, back_out);\n  });\n}\ntemplate void Edge_softmax_csr_forward<kDGLCPU, int32_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\ntemplate void Edge_softmax_csr_forward<kDGLCPU, int64_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\ntemplate void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\ntemplate void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\ntemplate void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\ntemplate void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\n\ntemplate void Edge_softmax_csr_backward<kDGLCPU, int32_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\ntemplate void Edge_softmax_csr_backward<kDGLCPU, int64_t, BFloat16>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\ntemplate void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(\n    const std::string& op, const BcastOff& bcast, const 
CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\ntemplate void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\ntemplate void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\ntemplate void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\n\n/** @brief Generalized SpMM on Coo format. */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SpMMCoo(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux) {\n  if (reduce == \"sum\") {\n    SWITCH_OP(op, Op, {\n      cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);\n    });\n  } else if (reduce == \"max\" || reduce == \"min\") {\n    SWITCH_OP(op, Op, {\n      if (reduce == \"max\")\n        cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(\n            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);\n      else\n        cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(\n            bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);\n    });\n  } else {\n    LOG(FATAL) << \"Unsupported SpMM reducer: \" << reduce;\n  }\n}\n\ntemplate void SpMMCoo<kDGLCPU, int32_t, BFloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCoo<kDGLCPU, int64_t, BFloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    
std::vector<NDArray> out_aux);\ntemplate void SpMMCoo<kDGLCPU, int32_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCoo<kDGLCPU, int64_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCoo<kDGLCPU, int32_t, double>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCoo<kDGLCPU, int64_t, double>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/spmm.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/spmm.h\n * @brief SPMM CPU kernel function header.\n */\n#ifndef DGL_ARRAY_CPU_SPMM_H_\n#define DGL_ARRAY_CPU_SPMM_H_\n\n#include <dgl/array.h>\n#include <dgl/bcast.h>\n#include <dgl/runtime/config.h>\n#include <dgl/runtime/parallel_for.h>\n#include <math.h>\n\n#include <algorithm>\n#include <limits>\n#include <memory>\n#include <vector>\n\n#include \"spmm_binary_ops.h\"\n#if !defined(_WIN32)\n#ifdef USE_LIBXSMM\n#include \"spmm_blocking_libxsmm.h\"\n#endif  // USE_LIBXSMM\n#endif  // _WIN32\nnamespace dgl {\nnamespace aten {\nnamespace cpu {\n\ntemplate <typename DType>\nusing AccType = typename std::conditional<\n    std::is_same<DType, BFloat16>::value, float, DType>::type;\n\n/**\n * @brief Naive CPU kernel of SpMM on Csr format.\n * @param cpu_spec JIT'ed kernel\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param X The feature on source nodes.\n * @param W The feature on edges.\n * @param O The result feature on destination nodes.\n * @note it uses node parallel strategy, different threads are responsible\n *       for the computation of different nodes.\n */\ntemplate <typename IdType, typename DType, typename Op>\ntypename std::enable_if<!std::is_same<DType, BFloat16>::value, void>::type\nSpMMSumCsrNaive(\n    const BcastOff& bcast, const CSRMatrix& csr, const DType* X, const DType* W,\n    DType* O) {\n  const bool has_idx = !IsNullArray(csr.data);\n  const IdType* indptr = csr.indptr.Ptr<IdType>();\n  const IdType* indices = csr.indices.Ptr<IdType>();\n  const IdType* edges = csr.data.Ptr<IdType>();\n  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;\n  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {\n    for (auto rid = b; rid < e; ++rid) {\n      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];\n      DType* out_off = O + rid * dim;\n      for (IdType j = row_start; j < row_end; ++j) 
{\n        const IdType cid = indices[j];\n        const IdType eid = has_idx ? edges[j] : j;\n        for (int64_t k = 0; k < dim; ++k) {\n          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;\n          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;\n          const DType* lhs_off =\n              Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;\n          const DType* rhs_off =\n              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;\n          out_off[k] += Op::Call(lhs_off, rhs_off);\n        }\n      }\n    }\n  });\n}\n\n// Naive implementation with additional accumulator, which prevents accuracy\n// degradation in less precise data types, like bfloat16.\ntemplate <typename IdType, typename DType, typename Op>\ntypename std::enable_if<std::is_same<DType, BFloat16>::value, void>::type\nSpMMSumCsrNaive(\n    const BcastOff& bcast, const CSRMatrix& csr, const DType* X, const DType* W,\n    DType* O) {\n  const bool has_idx = !IsNullArray(csr.data);\n  const IdType* indptr = csr.indptr.Ptr<IdType>();\n  const IdType* indices = csr.indices.Ptr<IdType>();\n  const IdType* edges = csr.data.Ptr<IdType>();\n  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;\n  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {\n    for (auto rid = b; rid < e; ++rid) {\n      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];\n      DType* out_off = O + rid * dim;\n      for (int64_t k = 0; k < dim; ++k) {\n        AccType<DType> acc = 0.;\n        for (IdType j = row_start; j < row_end; ++j) {\n          const IdType cid = indices[j];\n          const IdType eid = has_idx ? edges[j] : j;\n          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;\n          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;\n          const DType* lhs_off =\n              Op::use_lhs ? 
X + cid * lhs_dim + lhs_add : nullptr;\n          const DType* rhs_off =\n              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;\n          acc += Op::Call(lhs_off, rhs_off);\n        }\n        out_off[k] += acc;\n      }\n    }\n  });\n}\n\n/**\n * @brief CPU kernel of SpMM on Csr format.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result feature on destination nodes.\n * @note it uses node parallel strategy, different threads are responsible\n *       for the computation of different nodes.\n */\ntemplate <typename IdType, typename DType, typename Op>\nvoid SpMMSumCsr(\n    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,\n    NDArray out) {\n  const bool has_idx = !IsNullArray(csr.data);\n  const IdType* indptr = csr.indptr.Ptr<IdType>();\n  const IdType* indices = csr.indices.Ptr<IdType>();\n  const IdType* edges = csr.data.Ptr<IdType>();\n  const DType* X = ufeat.Ptr<DType>();\n  const DType* W = efeat.Ptr<DType>();\n  DType* O = out.Ptr<DType>();\n  CHECK_NOTNULL(indptr);\n  CHECK_NOTNULL(O);\n  if (Op::use_lhs) {\n    CHECK_NOTNULL(indices);\n    CHECK_NOTNULL(X);\n  }\n  if (Op::use_rhs) {\n    if (has_idx) CHECK_NOTNULL(edges);\n    CHECK_NOTNULL(W);\n  }\n#if !defined(_WIN32)\n#ifdef USE_LIBXSMM\n  int cpu_id = libxsmm_cpuid_x86();\n  const bool no_libxsmm =\n      bcast.use_bcast || std::is_same<DType, double>::value ||\n      (std::is_same<DType, BFloat16>::value && cpu_id < LIBXSMM_X86_AVX512) ||\n      !dgl::runtime::Config::Global()->IsLibxsmmAvailable();\n  if (!no_libxsmm) {\n    SpMMSumCsrLibxsmm<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);\n  } else {\n#endif  // USE_LIBXSMM\n#endif  // _WIN32\n    SpMMSumCsrNaive<IdType, DType, Op>(bcast, csr, X, W, O);\n#if !defined(_WIN32)\n#ifdef USE_LIBXSMM\n  }\n#endif  // USE_LIBXSMM\n#endif  // _WIN32\n}\n\n/**\n * @brief CPU kernel 
of SpMM on Coo format.\n * @param bcast Broadcast information.\n * @param coo The Coo matrix.\n * @param ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result feature on destination nodes.\n * @note it uses edge parallel strategy, different threads are responsible\n *       for the computation of different edges. To avoid possible data hazard,\n *       we use atomic operators in the reduction phase.\n */\ntemplate <typename IdType, typename DType, typename Op>\ntypename std::enable_if<!std::is_same<DType, BFloat16>::value, void>::type\nSpMMSumCoo(\n    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,\n    NDArray out) {\n  const bool has_idx = !IsNullArray(coo.data);\n  const IdType* row = coo.row.Ptr<IdType>();\n  const IdType* col = coo.col.Ptr<IdType>();\n  const IdType* edges = coo.data.Ptr<IdType>();\n  const DType* X = ufeat.Ptr<DType>();\n  const DType* W = efeat.Ptr<DType>();\n  int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;\n  DType* O = out.Ptr<DType>();\n  const int64_t nnz = coo.row->shape[0];\n  // fill zero elements\n  memset(O, 0, out.GetSize());\n  // spmm\n#pragma omp parallel for\n  for (IdType i = 0; i < nnz; ++i) {\n    const IdType rid = row[i];\n    const IdType cid = col[i];\n    const IdType eid = has_idx ? edges[i] : i;\n    DType* out_off = O + cid * dim;\n    for (int64_t k = 0; k < dim; ++k) {\n      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;\n      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;\n      const DType* lhs_off =\n          Op::use_lhs ? X + rid * lhs_dim + lhs_add : nullptr;\n      const DType* rhs_off =\n          Op::use_rhs ? 
W + eid * rhs_dim + rhs_add : nullptr;\n      const DType val = Op::Call(lhs_off, rhs_off);\n      if (val != 0) {\n#pragma omp atomic\n        out_off[k] += val;\n      }\n    }\n  }\n}\n\ntemplate <typename IdType, typename DType, typename Op>\ntypename std::enable_if<std::is_same<DType, BFloat16>::value, void>::type\nSpMMSumCoo(\n    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,\n    NDArray out) {\n  LOG(FATAL) << \"Unsupported CPU kernel for SpMMSumCoo for BF16.\";\n}\n\n/**\n * @brief CPU kernel of SpMM-Min/Max on Csr format.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result feature on destination nodes.\n * @param argu Arg-Min/Max on source nodes, which refers the source node indices\n *        correspond to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n *        reducer.\n * @param arge Arg-Min/Max on edges, which refers to the edge indices\n *        corresponding to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n *        reducer.\n * @note It uses node parallel strategy, different threads are responsible for\n *       the computation of different nodes.\n * @note The result will contain infinity for zero-degree nodes.\n */\ntemplate <typename IdType, typename DType, typename Op, typename Cmp>\nvoid SpMMCmpCsr(\n    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,\n    NDArray out, NDArray argu, NDArray arge) {\n  const bool has_idx = !IsNullArray(csr.data);\n  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);\n  const IdType* indices = static_cast<IdType*>(csr.indices->data);\n  const IdType* edges =\n      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;\n  const DType* X = Op::use_lhs ? 
static_cast<DType*>(ufeat->data) : nullptr;\n  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;\n  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,\n                rhs_dim = bcast.rhs_len;\n  DType* O = static_cast<DType*>(out->data);\n  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;\n  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;\n  CHECK_NOTNULL(indptr);\n  CHECK_NOTNULL(O);\n  if (Op::use_lhs) {\n    CHECK_NOTNULL(indices);\n    CHECK_NOTNULL(X);\n    CHECK_NOTNULL(argX);\n  }\n  if (Op::use_rhs) {\n    if (has_idx) CHECK_NOTNULL(edges);\n    CHECK_NOTNULL(W);\n    CHECK_NOTNULL(argW);\n  }\n#if !defined(_WIN32)\n#ifdef USE_LIBXSMM\n  int cpu_id = libxsmm_cpuid_x86();\n  const bool no_libxsmm = bcast.use_bcast ||\n                          std::is_same<DType, double>::value ||\n                          cpu_id < LIBXSMM_X86_AVX512 ||\n                          !dgl::runtime::Config::Global()->IsLibxsmmAvailable();\n  if (!no_libxsmm) {\n    SpMMCmpCsrLibxsmm<IdType, DType, Op, Cmp>(\n        bcast, csr, ufeat, efeat, out, argu, arge);\n  } else {\n#endif  // USE_LIBXSMM\n#endif  // _WIN32\n\n    runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {\n      for (auto rid = b; rid < e; ++rid) {\n        const IdType row_start = indptr[rid], row_end = indptr[rid + 1];\n        DType* out_off = O + rid * dim;\n        IdType* argx_off = argX + rid * dim;\n        IdType* argw_off = argW + rid * dim;\n        for (IdType j = row_start; j < row_end; ++j) {\n          const IdType cid = indices[j];\n          const IdType eid = has_idx ? edges[j] : j;\n          for (int64_t k = 0; k < dim; ++k) {\n            const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;\n            const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;\n            const DType* lhs_off =\n                Op::use_lhs ? 
X + cid * lhs_dim + lhs_add : nullptr;\n            const DType* rhs_off =\n                Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;\n            const DType val = Op::Call(lhs_off, rhs_off);\n            if (Cmp::Call(out_off[k], val)) {\n              out_off[k] = val;\n              if (Op::use_lhs) argx_off[k] = cid;\n              if (Op::use_rhs) argw_off[k] = eid;\n            }\n          }\n        }\n      }\n    });\n#if !defined(_WIN32)\n#ifdef USE_LIBXSMM\n  }\n#endif  // USE_LIBXSMM\n#endif  // _WIN32\n}\n\n/**\n * @brief CPU kernel of SpMM-Min/Max on Csr format.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result feature on destination nodes.\n * @param argu Arg-Min/Max on source nodes, which refers the source node indices\n *        correspond to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n *        reducer.\n * @param arge Arg-Min/Max on edges. which refers the source node indices\n *        correspond to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n *        reducer.\n * @param argu_ntype Node type of the arg-Min/Max on source nodes, which refers\n *        the source node types correspond to the minimum/maximum values of\n *        reduction result on destination nodes. It's useful in computing\n *        gradients of Min/Max reducer.\n * @param arge_etype Edge-type of the arg-Min/Max on edges. which refers the\n *        source node indices correspond to the minimum/maximum values of\n *        reduction result on destination nodes. 
It's useful in computing\n *        gradients of Min/Max reducer.\n * @param src_type Node type of the source nodes of an etype\n * @param etype Edge type\n */\ntemplate <typename IdType, typename DType, typename Op, typename Cmp>\nvoid SpMMCmpCsrHetero(\n    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,\n    NDArray out, NDArray argu, NDArray arge, NDArray argu_ntype,\n    NDArray arge_etype, const int ntype, const int etype) {\n  const bool has_idx = !IsNullArray(csr.data);\n  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);\n  const IdType* indices = static_cast<IdType*>(csr.indices->data);\n  const IdType* edges =\n      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;\n  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;\n  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;\n  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,\n                rhs_dim = bcast.rhs_len;\n  DType* O = static_cast<DType*>(out->data);\n  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;\n  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;\n  IdType* argX_ntype =\n      Op::use_lhs ? static_cast<IdType*>(argu_ntype->data) : nullptr;\n  IdType* argW_etype =\n      Op::use_rhs ? static_cast<IdType*>(arge_etype->data) : nullptr;\n  CHECK_NOTNULL(indptr);\n  CHECK_NOTNULL(O);\n  if (Op::use_lhs) {\n    CHECK_NOTNULL(indices);\n    CHECK_NOTNULL(X);\n    CHECK_NOTNULL(argX);\n  }\n  if (Op::use_rhs) {\n    if (has_idx) CHECK_NOTNULL(edges);\n    CHECK_NOTNULL(W);\n    CHECK_NOTNULL(argW);\n  }\n  // TODO(Israt): Use LIBXSMM. 
Homogeneous graph uses LIBXSMM when enabled.\n  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {\n    for (auto rid = b; rid < e; ++rid) {\n      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];\n      DType* out_off = O + rid * dim;\n      IdType* argx_off = argX + rid * dim;\n      IdType* argw_off = argW + rid * dim;\n      IdType* argx_ntype = argX_ntype + rid * dim;\n      IdType* argw_etype = argW_etype + rid * dim;\n      for (IdType j = row_start; j < row_end; ++j) {\n        const IdType cid = indices[j];\n        const IdType eid = has_idx ? edges[j] : j;\n        for (int64_t k = 0; k < dim; ++k) {\n          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;\n          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;\n          const DType* lhs_off =\n              Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;\n          const DType* rhs_off =\n              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;\n          const DType val = Op::Call(lhs_off, rhs_off);\n          if (Cmp::Call(out_off[k], val)) {\n            out_off[k] = val;\n            if (Op::use_lhs) {\n              argx_off[k] = cid;\n              argx_ntype[k] = ntype;\n            }\n            if (Op::use_rhs) {\n              argw_off[k] = eid;\n              argw_etype[k] = etype;\n            }\n          }\n        }\n      }\n    }\n  });\n}\n\n/**\n * @brief CPU kernel of SpMM-Min/Max on Coo format.\n * @param bcast Broadcast information.\n * @param coo The Coo matrix.\n * @param ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result feature on destination nodes.\n * @param argu Arg-Min/Max on source nodes, which refers the source node indices\n *        correspond to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n *        reducer.\n * @param arge Arg-Min/Max on edges. 
which refers the source node indices\n *        correspond to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n *        reducer.\n * @note it uses node parallel strategy, different threads are responsible for\n *       the computation of different nodes. To avoid possible data hazard, we\n *       use atomic operators in the reduction phase.\n * @note The result will contain infinity for zero-degree nodes.\n */\ntemplate <typename IdType, typename DType, typename Op, typename Cmp>\nvoid SpMMCmpCoo(\n    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,\n    NDArray out, NDArray argu, NDArray arge) {\n  const bool has_idx = !IsNullArray(coo.data);\n  const IdType* row = static_cast<IdType*>(coo.row->data);\n  const IdType* col = static_cast<IdType*>(coo.col->data);\n  const IdType* edges =\n      has_idx ? static_cast<IdType*>(coo.data->data) : nullptr;\n  const DType* X = Op::use_lhs ? static_cast<DType*>(ufeat->data) : nullptr;\n  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;\n  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,\n                rhs_dim = bcast.rhs_len;\n  DType* O = static_cast<DType*>(out->data);\n  IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;\n  IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;\n  const int64_t nnz = coo.row->shape[0];\n  // fill zero elements\n  std::fill(O, O + out.NumElements(), Cmp::zero);\n  // spmm\n#pragma omp parallel for\n  for (IdType i = 0; i < nnz; ++i) {\n    const IdType rid = row[i];\n    const IdType cid = col[i];\n    const IdType eid = has_idx ? edges[i] : i;\n    DType* out_off = O + cid * dim;\n    IdType* argx_off = Op::use_lhs ? argX + cid * dim : nullptr;\n    IdType* argw_off = Op::use_rhs ? argW + cid * dim : nullptr;\n    for (int64_t k = 0; k < dim; ++k) {\n      const int64_t lhs_add = bcast.use_bcast ? 
bcast.lhs_offset[k] : k;\n      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;\n      const DType* lhs_off =\n          Op::use_lhs ? X + rid * lhs_dim + lhs_add : nullptr;\n      const DType* rhs_off =\n          Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;\n      const DType val = Op::Call(lhs_off, rhs_off);\n#pragma omp critical\n      if (Cmp::Call(out_off[k], val)) {\n        out_off[k] = val;\n        if (Op::use_lhs) argx_off[k] = rid;\n        if (Op::use_rhs) argw_off[k] = eid;\n      }\n    }\n  }\n}\n\n/**\n * @brief CPU kernel of Edge_softmax_csr_forward on Csr format.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result of edge_softmax_forward.\n */\ntemplate <typename IdType, typename DType, typename Op>\nvoid Edge_softmax_csr_forward(\n    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,\n    NDArray out) {\n  const bool has_idx = !IsNullArray(csr.data);\n  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);\n  const IdType* edges =\n      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;\n  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;\n  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;\n  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {\n    for (auto rid = b; rid < e; ++rid) {\n      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];\n      std::vector<AccType<DType>> data_e(row_end - row_start, 0);\n      std::vector<IdType> num(row_end - row_start, 0);\n      for (int64_t k = 0; k < dim; ++k) {\n        DType max_v = -std::numeric_limits<DType>::infinity();\n        for (IdType j = row_start; j < row_end; ++j) {\n          const IdType eid = has_idx ? edges[j] : j;\n          const int64_t rhs_add = bcast.use_bcast ? 
bcast.rhs_offset[k] : k;\n          const DType* rhs_off =\n              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;\n          data_e[j - row_start] = *rhs_off;\n          num[j - row_start] = eid * rhs_dim + rhs_add;\n          max_v = std::max<DType>(max_v, (*rhs_off));\n        }\n        DType exp_sum = 0;\n        for (auto& element : data_e) {\n          element -= max_v;\n          element = std::exp(element);\n          exp_sum += element;\n        }\n        for (int i = 0; i < row_end - row_start; i++) {\n          out.Ptr<DType>()[num[i]] = data_e[i] / exp_sum;\n        }\n      }\n    }\n  });\n}\n\n/**\n * @brief CPU kernel of Edge_softmax_csr_backward on Csr format.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param out The result of forward.\n * @param sds The result of gradient * out.\n * @param back_out The result of edge_softmax_backward.\n */\ntemplate <typename IdType, typename DType, typename Op>\nvoid Edge_softmax_csr_backward(\n    const BcastOff& bcast, const CSRMatrix& csr, NDArray out, NDArray sds,\n    NDArray back_out) {\n  typedef typename std::conditional<\n      std::is_same<DType, BFloat16>::value, float, DType>::type AccType;\n  const bool has_idx = !IsNullArray(csr.data);\n  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);\n  const IdType* edges =\n      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;\n  const DType* W_out = Op::use_rhs ? static_cast<DType*>(out->data) : nullptr;\n  const DType* W_sds = Op::use_rhs ? static_cast<DType*>(sds->data) : nullptr;\n  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;\n  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {\n    for (auto rid = b; rid < e; ++rid) {\n      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];\n      for (int64_t k = 0; k < dim; ++k) {\n        AccType sum_sds = 0;\n        for (IdType j = row_start; j < row_end; ++j) {\n          const IdType eid = has_idx ? 
edges[j] : j;\n          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;\n          const DType* rhs_off_sds =\n              Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;\n          sum_sds += (*rhs_off_sds);\n        }\n        for (IdType j = row_start; j < row_end; ++j) {\n          const IdType eid = has_idx ? edges[j] : j;\n          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;\n          const DType* rhs_off_out =\n              Op::use_rhs ? W_out + eid * rhs_dim + rhs_add : nullptr;\n          const DType* rhs_off_sds =\n              Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;\n          back_out.Ptr<DType>()[eid * rhs_dim + rhs_add] =\n              (*rhs_off_sds) - sum_sds * (*rhs_off_out);\n        }\n      }\n    }\n  });\n}\n\n}  // namespace cpu\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CPU_SPMM_H_\n"
  },
  {
    "path": "src/array/cpu/spmm_binary_ops.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/spmm_binary_ops.h\n * @brief SPMM CPU Binary ops.\n */\n#ifndef DGL_ARRAY_CPU_SPMM_BINARY_OPS_H_\n#define DGL_ARRAY_CPU_SPMM_BINARY_OPS_H_\n#include <dgl/array.h>\n#include <dgl/bcast.h>\n\n#include <limits>\nnamespace dgl {\nnamespace aten {\nnamespace cpu {\nnamespace op {\n\n//////////////////////////////// binary operators on CPU\n///////////////////////////////////\ntemplate <typename DType>\nstruct Add {\n  typedef DType type;\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  inline static DType Call(const DType* lhs_off, const DType* rhs_off) {\n    return *lhs_off + *rhs_off;\n  }\n};\ntemplate <typename DType>\nconstexpr bool Add<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool Add<DType>::use_rhs;\n\ntemplate <typename DType>\nstruct Sub {\n  typedef DType type;\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  inline static DType Call(const DType* lhs_off, const DType* rhs_off) {\n    return *lhs_off - *rhs_off;\n  }\n};\ntemplate <typename DType>\nconstexpr bool Sub<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool Sub<DType>::use_rhs;\n\ntemplate <typename DType>\nstruct Mul {\n  typedef DType type;\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  inline static DType Call(const DType* lhs_off, const DType* rhs_off) {\n    return *lhs_off * *rhs_off;\n  }\n};\ntemplate <typename DType>\nconstexpr bool Mul<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool Mul<DType>::use_rhs;\n\ntemplate <typename DType>\nstruct Div {\n  typedef DType type;\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  inline static DType Call(const DType* lhs_off, const DType* rhs_off) {\n    return *lhs_off / *rhs_off;\n  }\n};\ntemplate <typename DType>\nconstexpr bool Div<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool 
Div<DType>::use_rhs;\n\ntemplate <typename DType>\nstruct CopyLhs {\n  typedef DType type;\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = false;\n  inline static DType Call(const DType* lhs_off, const DType*) {\n    return *lhs_off;\n  }\n};\ntemplate <typename DType>\nconstexpr bool CopyLhs<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool CopyLhs<DType>::use_rhs;\n\ntemplate <typename DType>\nstruct CopyRhs {\n  typedef DType type;\n  static constexpr bool use_lhs = false;\n  static constexpr bool use_rhs = true;\n  inline static DType Call(const DType*, const DType* rhs_off) {\n    return *rhs_off;\n  }\n};\ntemplate <typename DType>\nconstexpr bool CopyRhs<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool CopyRhs<DType>::use_rhs;\n\n//////////////////////////////// Reduce operators on CPU\n///////////////////////////////////\ntemplate <typename DType>\nconstexpr DType MinDType() {\n  if (std::is_same<DType, BFloat16>::value)\n    return BFloat16::Min();\n  else\n    return -std::numeric_limits<DType>::infinity();\n}\n\ntemplate <typename DType>\nstruct Max {\n  typedef DType type;\n  static constexpr DType zero = MinDType<DType>();\n  // return true if accum should be replaced\n  inline static DType Call(DType accum, DType val) { return accum < val; }\n};\ntemplate <typename DType>\nconstexpr DType Max<DType>::zero;\n\ntemplate <typename DType>\nconstexpr DType MaxDType() {\n  if (std::is_same<DType, BFloat16>::value)\n    return BFloat16::Max();\n  else\n    return std::numeric_limits<DType>::infinity();\n}\n\ntemplate <typename DType>\nstruct Min {\n  typedef DType type;\n  static constexpr DType zero = MaxDType<DType>();\n  // return true if accum should be replaced\n  inline static DType Call(DType accum, DType val) { return accum > val; }\n};\ntemplate <typename DType>\nconstexpr DType Min<DType>::zero;\n\n#define SWITCH_OP(op, Op, ...)                                  
\\\n  do {                                                          \\\n    if ((op) == \"add\") {                                        \\\n      typedef dgl::aten::cpu::op::Add<DType> Op;                \\\n      { __VA_ARGS__ }                                           \\\n    } else if ((op) == \"sub\") {                                 \\\n      typedef dgl::aten::cpu::op::Sub<DType> Op;                \\\n      { __VA_ARGS__ }                                           \\\n    } else if ((op) == \"mul\") {                                 \\\n      typedef dgl::aten::cpu::op::Mul<DType> Op;                \\\n      { __VA_ARGS__ }                                           \\\n    } else if ((op) == \"div\") {                                 \\\n      typedef dgl::aten::cpu::op::Div<DType> Op;                \\\n      { __VA_ARGS__ }                                           \\\n    } else if ((op) == \"copy_lhs\") {                            \\\n      typedef dgl::aten::cpu::op::CopyLhs<DType> Op;            \\\n      { __VA_ARGS__ }                                           \\\n    } else if ((op) == \"copy_rhs\") {                            \\\n      typedef dgl::aten::cpu::op::CopyRhs<DType> Op;            \\\n      { __VA_ARGS__ }                                           \\\n    } else {                                                    \\\n      LOG(FATAL) << \"Unsupported SpMM binary operator: \" << op; \\\n    }                                                           \\\n  } while (0)\n\n}  // namespace op\n\n}  // namespace cpu\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CPU_SPMM_BINARY_OPS_H_\n"
  },
  {
    "path": "src/array/cpu/spmm_blocking_libxsmm.h",
    "content": "/**\n *  Copyright (c) 2021 Intel Corporation\n * @file array/cpu/spmm.h\n * @brief SPMM CPU kernel function header.\n * @author Sanchit Misra <sanchit.misra@intel.com>,\n *         Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>,\n *         Vasimuddin Md <vasimuddin.md@intel.com>,\n *         Sasikanth Avancha <sasikanth.avancha@intel.com>\n */\n#ifndef DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_\n#define DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_\n\n#include <dgl/array.h>\n#include <dgl/bcast.h>\n#include <dmlc/logging.h>\n\n#include <algorithm>\n\n#if !defined(_WIN32)\n#ifdef USE_LIBXSMM\n#include <libxsmm_source.h>\n#include <unistd.h>\n#ifdef DEBUG\n#include <x86intrin.h>\n#endif  // DEBUG\n#include <dmlc/omp.h>\n\n#define NUM_BLOCKS_PER_THREAD 20\n#define BLOCKING_HEURISTIC_PARAM 500\n\nnamespace dgl {\nnamespace aten {\nnamespace cpu {\n\ntemplate <typename IdType, typename DType>\nstruct CSRMatrixInternal {\n  IdType num_rows;\n  IdType num_cols;\n  IdType *indptr;\n  IdType *indices;\n  DType *data;\n};\n\nint32_t GetLLCSize() {\n#ifdef _SC_LEVEL3_CACHE_SIZE\n  int32_t cache_size = sysconf(_SC_LEVEL3_CACHE_SIZE);\n  if (cache_size < 0) cache_size = DGL_CPU_LLC_SIZE;\n#else\n  int32_t cache_size = DGL_CPU_LLC_SIZE;\n#endif\n  return cache_size;\n}\n\n/**\n * @brief Tile the CSR matrix to roughly make sure that the column tiles and\n *        corresponding neighbor features fit into LLC and the row tiles\n *        are assigned to OMP threads.\n * @param csr The Csr matrix.\n * @param block_csr_array The array containing csr matrices of all blocks.\n * @param num_M_blocks Number of blocks to create along the rows of adjacency\n *        matrix.\n * @param num_K_blocks Number of blocks to create along the columns of adjacency\n *        matrix.\n * @param M_block_size block size along the rows of adjacency matrix.\n * @param K_block_size block size along the columns of adjacency matrix.\n * @param use_lhs Whether to use lhs.\n * @param use_rhs 
Whether to use rhs.\n */\ntemplate <typename IdType>\ninline void SpMMCreateBlocks(\n    const CSRMatrix &csr, CSRMatrixInternal<IdType, IdType> *block_csr_array,\n    IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size,\n    IdType K_block_size, bool use_lhs, bool use_rhs) {\n  const IdType M = csr.num_rows;\n  const IdType K = csr.num_cols;\n  IdType *indptr = csr.indptr.Ptr<IdType>();\n  IdType *indices = csr.indices.Ptr<IdType>();\n  IdType *edges = csr.data.Ptr<IdType>();\n  CHECK_NOTNULL(indptr);\n  if (use_lhs) CHECK_NOTNULL(indices);\n  if (use_rhs) CHECK_NOTNULL(edges);\n\n  if (num_K_blocks > 1) {\n    IdType *indptr_block_buf = reinterpret_cast<IdType *>(aligned_alloc(\n        64, (M_block_size + 1) * num_M_blocks * num_K_blocks * sizeof(IdType)));\n    IdType *indices_block_buf = nullptr;\n    if (use_lhs) {\n      indices_block_buf = reinterpret_cast<IdType *>(\n          aligned_alloc(64, indptr[M] * sizeof(IdType)));\n    }\n    IdType *edges_block_buf = nullptr;\n    if (use_rhs) {\n      edges_block_buf = reinterpret_cast<IdType *>(\n          aligned_alloc(64, indptr[M] * sizeof(IdType)));\n    }\n\n#pragma omp parallel\n    {\n      IdType *my_cur_col_id = reinterpret_cast<IdType *>(\n          aligned_alloc(64, 2 * M_block_size * sizeof(IdType)));\n\n#pragma omp for\n      for (IdType m = 0; m < num_M_blocks; m++) {\n        const IdType M_start = m * M_block_size;\n        const IdType M_end = std::min((m + 1) * M_block_size, M);\n        const IdType nnz = indptr[M_end] - indptr[M_start];\n\n        IdType cur_indices_id = 0;\n        IdType *my_indices_block_buf, *my_edges_block_buf;\n        if (use_lhs) my_indices_block_buf = indices_block_buf + indptr[M_start];\n        if (use_rhs) my_edges_block_buf = edges_block_buf + indptr[M_start];\n\n        for (IdType i = M_start; i < M_end; i++) {\n          my_cur_col_id[(i - M_start) * 2] = indptr[i];\n          my_cur_col_id[(i - M_start) * 2 + 1] = indptr[i + 1];\n        }\n       
 for (IdType k = 0; k < num_K_blocks; k++) {\n          const IdType K_start = k * K_block_size;\n          const IdType K_end = std::min((k + 1) * K_block_size, K);\n          CSRMatrixInternal<IdType, IdType> cur_csr;\n          cur_csr.num_rows = M_end - M_start;\n          cur_csr.num_cols = K_end - K_start;\n          // Create csr_ij\n          IdType *cur_csr_indptr =\n              indptr_block_buf + (m * num_K_blocks + k) * (M_block_size + 1);\n          IdType *cur_csr_indices = nullptr, *cur_csr_edges = nullptr;\n          if (use_lhs) cur_csr_indices = my_indices_block_buf + cur_indices_id;\n          if (use_rhs) cur_csr_edges = my_edges_block_buf + cur_indices_id;\n          IdType cur_nnz = 0;\n          for (IdType i = M_start; i < M_end; i++) {\n            const IdType row_start = my_cur_col_id[(i - M_start) * 2];\n            const IdType row_end = my_cur_col_id[(i - M_start) * 2 + 1];\n            cur_csr_indptr[i - M_start] = cur_nnz;\n            IdType eid;\n            for (eid = row_start; eid < row_end; eid++) {\n              const IdType src = indices[eid];\n              const IdType edge = edges[eid];\n              if (src >= K_end) {\n                break;\n              }\n              CHECK_LT(cur_indices_id + cur_nnz, nnz);\n              if (use_lhs) cur_csr_indices[cur_nnz] = src;\n              if (use_rhs) cur_csr_edges[cur_nnz] = edge;\n              cur_nnz++;\n            }\n            my_cur_col_id[(i - M_start) * 2] = eid;\n          }\n          cur_csr_indptr[cur_csr.num_rows] = cur_nnz;\n          cur_indices_id += cur_nnz;\n          cur_csr.indptr = cur_csr_indptr;\n          if (use_lhs) cur_csr.indices = cur_csr_indices;\n          if (use_rhs) cur_csr.data = cur_csr_edges;\n          block_csr_array[m * num_K_blocks + k] = cur_csr;\n        }\n        CHECK_EQ(nnz, cur_indices_id);\n      }\n      free(my_cur_col_id);\n    }\n  } else {\n    for (IdType m = 0; m < num_M_blocks; m++) {\n      const IdType 
M_start = m * M_block_size;\n      const IdType M_end = std::min((m + 1) * M_block_size, M);\n\n      CSRMatrixInternal<IdType, IdType> cur_csr;\n      cur_csr.num_rows = M_end - M_start;\n      cur_csr.num_cols = K;\n      cur_csr.indptr = indptr + M_start;\n      cur_csr.indices = indices;\n      cur_csr.data = edges;\n\n      block_csr_array[m] = cur_csr;\n    }\n  }\n}\n\n/**\n * @brief Create libxsmm kernel.\n * @param has_idx For the edge features, are there indices available.\n * @param N Feature size.\n * @param redop_flag Flag specifying the reduction operation.\n * @param is_cmp Is the reduction operation a compare operation.\n * @note libxsmm_dispatch_meltw_opreduce_vecs_idx creates a JIT'ed kernel.\n *       Given a node u, the kernel performs an elementwise \"Op\" on the\n *       features of the neighbors and/or the edges incident on u.\n *       Subsequently, it performs an elementwise \"Redop\" on all such\n *       features created and stores into the feature of node u.\n *       It uses a SIMD and a cache efficient design and also provides\n *       support to enable software prefetching if needed. For IdType,\n *       it supports INT32 and INT64. For DType, it supports BF16 and FP32.\n *       It supports all the \"Ops\" and \"Redops\" supported by DGL. 
Once a\n *       kernel is generated by libxsmm_dispatch_meltw_opreduce_vecs_idx,\n *       it is cached for the entire duration of the execution of a program\n *       so that subsequently if the kernel is needed again, it just returns\n *       the cached copy.\n */\ntemplate <typename IdType, typename DType, typename Op>\ninline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel(\n    bool has_idx, IdType N, libxsmm_meltw_opreduce_vecs_flags redop_flag,\n    bool is_cmp) {\n  int _ld = N;\n  libxsmm_meltw_opreduce_vecs_flags opredop_flags;\n  // First, set the Op in the opredop_flags\n  if (std::is_same<Op, op::Add<DType>>::value) {\n    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_ADD;\n  } else if (std::is_same<Op, op::Sub<DType>>::value) {\n    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_SUB;\n  } else if (std::is_same<Op, op::Mul<DType>>::value) {\n    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_MUL;\n  } else if (std::is_same<Op, op::Div<DType>>::value) {\n    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_DIV;\n  } else if (std::is_same<Op, op::CopyLhs<DType>>::value) {\n    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY;\n  } else if (std::is_same<Op, op::CopyRhs<DType>>::value) {\n    opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY;\n  }\n  // Second, set which of lhs or rhs is considered first and second operand.\n  // This is needed since libxsmm assumes that the copy operation always copies\n  // the first operand. So, if we need to copy rhs, we need to set that as the\n  // first operand. 
For rhs, we also set whether to use implicit indices or\n  // provided indices.\n  // TODO(Steve): fix this long line in a separate PR.\n  if (std::is_same<Op, op::CopyLhs<DType>>::value) {\n    opredop_flags =\n        (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN);  // NOLINT\n  } else if (std::is_same<Op, op::CopyRhs<DType>>::value) {\n    opredop_flags =\n        (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIN_VECIDX);  // NOLINT\n    if (!has_idx) {\n      opredop_flags =\n          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VECIDX);  // NOLINT\n    }\n  } else {\n    opredop_flags =\n        (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN);  // NOLINT\n    if (has_idx) {\n      opredop_flags =\n          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_INDEXED_VEC);  // NOLINT\n    } else {\n      opredop_flags =\n          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VEC);  // NOLINT\n    }\n  }\n  // Third, we set the Redop in the opredop_flags\n  opredop_flags =\n      (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | redop_flag);\n  // Fourth, in case of Cmp Redop, set whether to record argmax/argmin for\n  // lhs/rhs\n  if (is_cmp) {\n    if (Op::use_lhs) {\n      opredop_flags =\n          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_0);  // NOLINT\n    }\n    if (Op::use_rhs) {\n      opredop_flags =\n          (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_1);  // NOLINT\n    }\n  }\n  libxsmm_meltwfunction_opreduce_vecs_idx kernel = nullptr;\n  if (std::is_same<DType, float>::value) {\n    kernel = 
libxsmm_dispatch_meltw_opreduce_vecs_idx(\n        N, &_ld, &_ld, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,\n        (sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32,\n        opredop_flags, 0);\n  } else {  // assume bf16\n    kernel = libxsmm_dispatch_meltw_opreduce_vecs_idx(\n        N, &_ld, &_ld, LIBXSMM_DATATYPE_BF16, LIBXSMM_DATATYPE_BF16,\n        (sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32,\n        opredop_flags, 0);\n  }\n\n  if (kernel == nullptr) {\n    LOG(FATAL) << \"Failed to generate libxsmm kernel for the SpMM operation.\"\n                  \"To disable libxsmm, use dgl.use_libxsmm(false).\";\n  }\n  return kernel;\n}\n\n/**\n * @brief Use libxsmm to perform SpMM-Sum on all blocks.\n * @param block_csr_array The array containing csr matrices of all blocks.\n * @param B The feature on source nodes.\n * @param E The feature on edges.\n * @param C The result feature on destination nodes.\n * @param has_idx For the edge features, are there indices available.\n * @param N Feature size.\n * @param num_M_blocks Number of blocks to create along the rows of adjacency\n *        matrix.\n * @param num_K_blocks Number of blocks to create along the columns of adjacency\n *        matrix.\n * @param M_block_size block size along the rows of adjacency matrix.\n * @param kernel The libxsmm kernel.\n */\ntemplate <typename IdType, typename DType>\ninline void SpMMBlockwiseOpSum(\n    CSRMatrixInternal<IdType, IdType> *block_csr_array, const DType *B,\n    const DType *E, DType *C, bool has_idx, IdType N, IdType num_M_blocks,\n    IdType num_K_blocks, IdType M_block_size,\n    libxsmm_meltwfunction_opreduce_vecs_idx kernel) {\n  const DType *in_matrix1 = B;\n  const DType *in_matrix2 = E;\n  DType *output = C;\n#pragma omp parallel\n  {\n    for (IdType k = 0; k < num_K_blocks; k++) {\n#pragma omp for schedule(dynamic)\n      for (IdType m = 0; m < num_M_blocks; m++) {\n        CSRMatrixInternal<IdType, IdType> cur_csr 
=\n            block_csr_array[m * num_K_blocks + k];\n\n        const IdType M_start = m * M_block_size;\n        for (IdType i = 0; i < cur_csr.num_rows; i++) {\n          const IdType row_start = cur_csr.indptr[i];\n          const IdType row_end = cur_csr.indptr[i + 1];\n          const IdType dst = i + M_start;\n\n          libxsmm_meltw_opreduce_vecs_idx_param params;\n          params.n = row_end - row_start;\n          params.indices = &cur_csr.indices[row_start];\n          params.in_matrix = in_matrix1;\n          params.out_vec = &output[dst * N];\n          params.scale_vals = nullptr;\n          if (has_idx) {\n            params.in_matrix2 = in_matrix2;\n            params.indices2 = &cur_csr.data[row_start];\n          } else {\n            params.in_matrix2 = &in_matrix2[row_start * N];\n          }\n          kernel(&params);\n        }\n      }\n    }\n  }\n}\n\n/**\n * @brief Use libxsmm to perform SpMM-Max/Min on all blocks.\n * @param block_csr_array The array containing csr matrices of all blocks.\n * @param B The feature on source nodes.\n * @param E The feature on edges.\n * @param C The result feature on destination nodes.\n * @param argB Arg-Min/Max on source nodes.\n * @param argE Arg-Min/Max on edges.\n * @param has_idx For the edge features, are there indices available.\n * @param N Feature size.\n * @param num_M_blocks Number of blocks to create along the rows of adjacency\n *        matrix.\n * @param num_K_blocks Number of blocks to create along the columns of adjacency\n *        matrix.\n * @param M_block_size block size along the rows of adjacency matrix.\n * @param kernel The libxsmm kernel.\n */\ntemplate <typename IdType, typename DType, typename Op, typename Cmp>\ninline void SpMMBlockwiseOpCmp(\n    CSRMatrixInternal<IdType, IdType> *block_csr_array, const DType *B,\n    const DType *E, DType *C, IdType *argB, IdType *argE, bool has_idx,\n    IdType N, IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size,\n    
libxsmm_meltwfunction_opreduce_vecs_idx kernel) {\n  const DType *in_matrix1 = B;\n  const DType *in_matrix2 = E;\n  DType *output = C;\n  IdType *out_matrix1 = argB;\n  IdType *out_matrix2 = argE;\n\n#pragma omp parallel\n  {\n    for (IdType k = 0; k < num_K_blocks; k++) {\n#pragma omp for schedule(dynamic)\n      for (IdType m = 0; m < num_M_blocks; m++) {\n        CSRMatrixInternal<IdType, IdType> cur_csr =\n            block_csr_array[m * num_K_blocks + k];\n\n        const IdType M_start = m * M_block_size;\n        for (IdType i = 0; i < cur_csr.num_rows; i++) {\n          const IdType row_start = cur_csr.indptr[i];\n          const IdType row_end = cur_csr.indptr[i + 1];\n          const IdType dst = i + M_start;\n\n          libxsmm_meltw_opreduce_vecs_idx_param params;\n          params.n = row_end - row_start;\n          params.indices = &cur_csr.indices[row_start];\n          params.in_matrix = in_matrix1;\n          params.out_vec = &output[dst * N];\n          params.argop_off_vec_0 = &out_matrix1[dst * N];\n          params.argop_off_vec_1 = &out_matrix2[dst * N];\n          params.scale_vals = nullptr;\n          if (has_idx) {\n            params.in_matrix2 = in_matrix2;\n            params.indices2 = &cur_csr.data[row_start];\n          } else {\n            params.in_matrix2 = &in_matrix2[row_start * N];\n          }\n          kernel(&params);\n        }\n      }\n    }\n  }\n}\n\n/**\n * @brief Free the tiled CSR matrix data.\n * @param block_csr_array The array containing csr matrices of all blocks.\n * @param num_M_blocks Number of blocks to create along the rows of adjacency\n *        matrix.\n * @param num_K_blocks Number of blocks to create along the columns of adjacency\n *        matrix.\n * @param use_lhs Whether to use lhs.\n * @param use_rhs Whether to use rhs.\n */\ntemplate <typename IdType>\ninline void SpMMFreeBlocks(\n    CSRMatrixInternal<IdType, IdType> *block_csr_array, IdType num_M_blocks,\n    IdType num_K_blocks, bool 
use_lhs, bool use_rhs) {\n  if (num_K_blocks > 1) {\n    free(block_csr_array[0].indptr);\n    if (use_lhs) free(block_csr_array[0].indices);\n    if (use_rhs) free(block_csr_array[0].data);\n  }\n  free(block_csr_array);\n}\n\n/**\n * @brief Optimized CPU kernel of SpMM-Sum/Max/Min on Csr format.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result feature on destination nodes.\n * @param argu Arg-Min/Max on source nodes.\n * @param arge Arg-Min/Max on edges.\n * @note it uses libxsmm, blocking and dynamic thread scheduling.\n */\ntemplate <typename IdType, typename DType, typename Op, typename Redop>\nvoid SpMMRedopCsrOpt(\n    const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,\n    NDArray out, NDArray argu, NDArray arge) {\n  int32_t llc_size = GetLLCSize();\n\n#ifdef DEBUG\n  uint64_t startTick, endTick;\n  startTick = __rdtsc();\n#endif  // DEBUG\n\n  const bool has_idx = !IsNullArray(csr.data);\n\n  DType *C = out.Ptr<DType>();\n  const DType *B = ufeat.Ptr<DType>();\n  const DType *E = efeat.Ptr<DType>();\n  IdType *argB, *argE;\n  if (std::is_same<Redop, op::Max<DType>>::value ||\n      std::is_same<Redop, op::Min<DType>>::value) {\n    argB = argu.Ptr<IdType>();\n    argE = arge.Ptr<IdType>();\n  }\n\n  const int nthreads = omp_get_max_threads();\n  const IdType M = csr.num_rows;\n  const IdType N = bcast.out_len;\n  const IdType K = csr.num_cols;\n  const IdType *indptr = csr.indptr.Ptr<IdType>();\n  CHECK_NOTNULL(indptr);\n  const IdType total_nnz = indptr[M];\n  if (M <= 0 || K <= 0 || N <= 0 || total_nnz <= 0) return;\n\n  const double avg_degree = total_nnz * 1.0 / M;\n  const double nnz_prob = avg_degree / K;\n\n  IdType K_block_size = std::min(\n      (int64_t)K,\n      (int64_t)(llc_size / (N * sizeof(DType) * nnz_prob * BLOCKING_HEURISTIC_PARAM)));  // NOLINT\n  IdType M_block_size = M / 
(nthreads * NUM_BLOCKS_PER_THREAD);\n  if (M_block_size == 0) M_block_size = 1;\n  if (K_block_size == 0) K_block_size = 1;\n\n  IdType num_M_blocks = (M + M_block_size - 1) / M_block_size;\n  IdType num_K_blocks = (K + K_block_size - 1) / K_block_size;\n\n  CSRMatrixInternal<IdType, IdType> *block_csr_array =\n      (CSRMatrixInternal<IdType, IdType> *)aligned_alloc(\n          64, sizeof(CSRMatrixInternal<IdType, IdType>) * num_M_blocks *\n                  num_K_blocks);\n\n#ifdef DEBUG\n  endTick = __rdtsc();\n  if (std::is_same<Redop, op::Max<DType>>::value) {\n    LOG(INFO) << \"Redop = Max\";\n  } else if (std::is_same<Redop, op::Min<DType>>::value) {\n    LOG(INFO) << \"Redop = Min\";\n  } else if (std::is_same<Redop, op::Add<DType>>::value) {\n    LOG(INFO) << \"Redop = Add\";\n  }\n  LOG(INFO) << \"nthreads = \" << nthreads << \", llc_size = \" << llc_size;\n  LOG(INFO) << \"M = \" << M << \", K = \" << K << \", N = \" << N;\n  LOG(INFO) << \"use_lhs = \" << Op::use_lhs << \", use_rhs = \" << Op::use_rhs;\n  LOG(INFO) << \"total_nnz = \" << total_nnz << \", avg_degree = \" << avg_degree;\n  LOG(INFO) << \"has_idx = \" << has_idx;\n  LOG(INFO) << \"nnz_prob = \" << nnz_prob;\n  LOG(INFO) << \"K_block_size = \" << K_block_size\n            << \", M_block_size = \" << M_block_size;\n  LOG(INFO) << \"num_K_blocks = \" << num_K_blocks\n            << \", num_M_blocks = \" << num_M_blocks;\n  LOG(INFO) << \"stage0 ticks = \" << (endTick - startTick);\n  startTick = __rdtsc();\n#endif  // DEBUG\n\n  SpMMCreateBlocks(\n      csr, block_csr_array, num_M_blocks, num_K_blocks, M_block_size,\n      K_block_size, Op::use_lhs, Op::use_rhs);\n\n#ifdef DEBUG\n  endTick = __rdtsc();\n  LOG(INFO) << \"stage1 ticks = \" << (endTick - startTick);\n  startTick = __rdtsc();\n#endif  // DEBUG\n\n  libxsmm_meltwfunction_opreduce_vecs_idx kernel = nullptr;\n  if (std::is_same<Redop, op::Max<DType>>::value) {\n    kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(\n        
has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MAX, true);\n  } else if (std::is_same<Redop, op::Min<DType>>::value) {\n    kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(\n        has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MIN, true);\n  } else if (std::is_same<Redop, op::Add<DType>>::value) {\n    kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(\n        has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_SUM, false);\n  }\n\n#ifdef DEBUG\n  endTick = __rdtsc();\n  LOG(INFO) << \"stage2 ticks = \" << (endTick - startTick);\n  startTick = __rdtsc();\n#endif  // DEBUG\n\n  if (std::is_same<Redop, op::Max<DType>>::value ||\n      std::is_same<Redop, op::Min<DType>>::value) {\n    SpMMBlockwiseOpCmp<IdType, DType, Op, Redop>(\n        block_csr_array, B, E, C, argB, argE, has_idx, N, num_M_blocks,\n        num_K_blocks, M_block_size, kernel);\n  } else {\n    SpMMBlockwiseOpSum(\n        block_csr_array, B, E, C, has_idx, N, num_M_blocks, num_K_blocks,\n        M_block_size, kernel);\n  }\n\n#ifdef DEBUG\n  endTick = __rdtsc();\n  LOG(INFO) << \"stage3 ticks = \" << (endTick - startTick);\n  startTick = __rdtsc();\n#endif  // DEBUG\n\n  SpMMFreeBlocks(\n      block_csr_array, num_M_blocks, num_K_blocks, Op::use_lhs, Op::use_rhs);\n\n#ifdef DEBUG\n  endTick = __rdtsc();\n  LOG(INFO) << \"stage4 ticks = \" << (endTick - startTick);\n#endif  // DEBUG\n}\n\n/**\n * @brief Optimized CPU kernel of SpMM-Sum on Csr format.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result feature on destination nodes.\n * @note it uses libxsmm, blocking and dynamic thread scheduling.\n */\ntemplate <typename IdType, typename DType, typename Op>\nvoid SpMMSumCsrLibxsmm(\n    const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,\n    NDArray out) {\n  NDArray dummy;\n  SpMMRedopCsrOpt<IdType, DType, Op, op::Add<DType>>(\n   
   bcast, csr, ufeat, efeat, out, dummy, dummy);\n}\n\n/**\n * @brief Optimized CPU kernel of SpMM-Min/Max on Csr format.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result feature on destination nodes.\n * @param argu Arg-Min/Max on source nodes.\n * @param arge Arg-Min/Max on edges.\n * @note it uses libxsmm, blocking and dynamic thread scheduling.\n */\ntemplate <typename IdType, typename DType, typename Op, typename Cmp>\nvoid SpMMCmpCsrLibxsmm(\n    const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,\n    NDArray out, NDArray argu, NDArray arge) {\n  SpMMRedopCsrOpt<IdType, DType, Op, Cmp>(\n      bcast, csr, ufeat, efeat, out, argu, arge);\n}\n\n}  // namespace cpu\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // USE_LIBXSMM\n#endif  // _WIN32\n\n#endif  // DGL_ARRAY_CPU_SPMM_BLOCKING_LIBXSMM_H_\n"
  },
  {
    "path": "src/array/cpu/traversal.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/traversal.cc\n * @brief Graph traversal implementation\n */\n\n#include \"./traversal.h\"\n\n#include <dgl/graph_traversal.h>\n\n#include <algorithm>\n#include <queue>\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\nnamespace {\n// A utility view class to wrap a vector into a queue.\ntemplate <typename DType>\nstruct VectorQueueWrapper {\n  std::vector<DType>* vec;\n  size_t head = 0;\n\n  explicit VectorQueueWrapper(std::vector<DType>* vec) : vec(vec) {}\n\n  void push(const DType& elem) { vec->push_back(elem); }\n\n  DType top() const { return vec->operator[](head); }\n\n  void pop() { ++head; }\n\n  bool empty() const { return head == vec->size(); }\n\n  size_t size() const { return vec->size() - head; }\n};\n\n// Internal function to merge multiple traversal traces into one ndarray.\n// It is similar to zip the vectors together.\ntemplate <typename DType>\nIdArray MergeMultipleTraversals(const std::vector<std::vector<DType>>& traces) {\n  int64_t max_len = 0, total_len = 0;\n  for (size_t i = 0; i < traces.size(); ++i) {\n    const int64_t tracelen = traces[i].size();\n    max_len = std::max(max_len, tracelen);\n    total_len += traces[i].size();\n  }\n  IdArray ret = IdArray::Empty(\n      {total_len}, DGLDataType{kDGLInt, sizeof(DType) * 8, 1},\n      DGLContext{kDGLCPU, 0});\n  DType* ret_data = static_cast<DType*>(ret->data);\n  for (int64_t i = 0; i < max_len; ++i) {\n    for (size_t j = 0; j < traces.size(); ++j) {\n      const int64_t tracelen = traces[j].size();\n      if (i >= tracelen) {\n        continue;\n      }\n      *(ret_data++) = traces[j][i];\n    }\n  }\n  return ret;\n}\n\n// Internal function to compute sections if multiple traversal traces\n// are merged into one ndarray.\ntemplate <typename DType>\nIdArray ComputeMergedSections(const std::vector<std::vector<DType>>& traces) {\n  int64_t max_len = 0;\n  for (size_t i = 0; i < traces.size(); ++i) {\n    
const int64_t tracelen = traces[i].size();\n    max_len = std::max(max_len, tracelen);\n  }\n  IdArray ret = IdArray::Empty(\n      {max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  int64_t* ret_data = static_cast<int64_t*>(ret->data);\n  for (int64_t i = 0; i < max_len; ++i) {\n    int64_t sec_len = 0;\n    for (size_t j = 0; j < traces.size(); ++j) {\n      const int64_t tracelen = traces[j].size();\n      if (i < tracelen) {\n        ++sec_len;\n      }\n    }\n    *(ret_data++) = sec_len;\n  }\n  return ret;\n}\n\n}  // namespace\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFrontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {\n  std::vector<IdType> ids;\n  std::vector<int64_t> sections;\n  VectorQueueWrapper<IdType> queue(&ids);\n  auto visit = [&](const int64_t v) {};\n  auto make_frontier = [&]() {\n    if (!queue.empty()) {\n      // do not push zero-length frontier\n      sections.push_back(queue.size());\n    }\n  };\n  BFSTraverseNodes<IdType>(csr, source, &queue, visit, make_frontier);\n\n  Frontiers front;\n  front.ids = VecToIdArray(ids, sizeof(IdType) * 8);\n  front.sections = VecToIdArray(sections, sizeof(int64_t) * 8);\n  return front;\n}\n\ntemplate Frontiers BFSNodesFrontiers<kDGLCPU, int32_t>(\n    const CSRMatrix&, IdArray);\ntemplate Frontiers BFSNodesFrontiers<kDGLCPU, int64_t>(\n    const CSRMatrix&, IdArray);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFrontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {\n  std::vector<IdType> ids;\n  std::vector<int64_t> sections;\n  // NOTE: std::queue has no top() method.\n  std::vector<IdType> nodes;\n  VectorQueueWrapper<IdType> queue(&nodes);\n  auto visit = [&](const IdType e) { ids.push_back(e); };\n  bool first_frontier = true;\n  auto make_frontier = [&] {\n    if (first_frontier) {\n      first_frontier = false;  // do not push the first section when doing edges\n    } else if (!queue.empty()) {\n      // do not push zero-length frontier\n 
     sections.push_back(queue.size());\n    }\n  };\n  BFSTraverseEdges<IdType>(csr, source, &queue, visit, make_frontier);\n\n  Frontiers front;\n  front.ids = VecToIdArray(ids, sizeof(IdType) * 8);\n  front.sections = VecToIdArray(sections, sizeof(int64_t) * 8);\n  return front;\n}\n\ntemplate Frontiers BFSEdgesFrontiers<kDGLCPU, int32_t>(\n    const CSRMatrix&, IdArray);\ntemplate Frontiers BFSEdgesFrontiers<kDGLCPU, int64_t>(\n    const CSRMatrix&, IdArray);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFrontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {\n  std::vector<IdType> ids;\n  std::vector<int64_t> sections;\n  VectorQueueWrapper<IdType> queue(&ids);\n  auto visit = [&](const uint64_t v) {};\n  auto make_frontier = [&]() {\n    if (!queue.empty()) {\n      // do not push zero-length frontier\n      sections.push_back(queue.size());\n    }\n  };\n  TopologicalNodes<IdType>(csr, &queue, visit, make_frontier);\n\n  Frontiers front;\n  front.ids = VecToIdArray(ids, sizeof(IdType) * 8);\n  front.sections = VecToIdArray(sections, sizeof(int64_t) * 8);\n  return front;\n}\n\ntemplate Frontiers TopologicalNodesFrontiers<kDGLCPU, int32_t>(\n    const CSRMatrix&);\ntemplate Frontiers TopologicalNodesFrontiers<kDGLCPU, int64_t>(\n    const CSRMatrix&);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFrontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {\n  const int64_t len = source->shape[0];\n  const IdType* src_data = static_cast<IdType*>(source->data);\n  std::vector<std::vector<IdType>> edges(len);\n\n  for (int64_t i = 0; i < len; ++i) {\n    auto visit = [&](IdType e, int tag) { edges[i].push_back(e); };\n    DFSLabeledEdges<IdType>(csr, src_data[i], false, false, visit);\n  }\n\n  Frontiers front;\n  front.ids = MergeMultipleTraversals(edges);\n  front.sections = ComputeMergedSections(edges);\n  return front;\n}\n\ntemplate Frontiers DGLDFSEdges<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);\ntemplate Frontiers DGLDFSEdges<kDGLCPU, 
int64_t>(const CSRMatrix&, IdArray);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFrontiers DGLDFSLabeledEdges(\n    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,\n    const bool has_nontree_edge, const bool return_labels) {\n  const int64_t len = source->shape[0];\n  const IdType* src_data = static_cast<IdType*>(source->data);\n  std::vector<std::vector<IdType>> edges(len);\n  std::vector<std::vector<int64_t>> tags;\n\n  if (return_labels) {\n    tags.resize(len);\n  }\n\n  for (int64_t i = 0; i < len; ++i) {\n    auto visit = [&](IdType e, int64_t tag) {\n      edges[i].push_back(e);\n      if (return_labels) {\n        tags[i].push_back(tag);\n      }\n    };\n    DFSLabeledEdges<IdType>(\n        csr, src_data[i], has_reverse_edge, has_nontree_edge, visit);\n  }\n\n  Frontiers front;\n  front.ids = MergeMultipleTraversals(edges);\n  front.sections = ComputeMergedSections(edges);\n  if (return_labels) {\n    front.tags = MergeMultipleTraversals(tags);\n  }\n\n  return front;\n}\n\ntemplate Frontiers DGLDFSLabeledEdges<kDGLCPU, int32_t>(\n    const CSRMatrix&, IdArray, const bool, const bool, const bool);\ntemplate Frontiers DGLDFSLabeledEdges<kDGLCPU, int64_t>(\n    const CSRMatrix&, IdArray, const bool, const bool, const bool);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cpu/traversal.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/traversal.h\n * @brief Graph traversal routines.\n *\n * Traversal routines generate frontiers. Frontiers can be node frontiers or\n * edge frontiers depending on the traversal function. Each frontier is a list\n * of nodes/edges (specified by their ids). An optional tag can be specified for\n * each node/edge (represented by an int value).\n */\n#ifndef DGL_ARRAY_CPU_TRAVERSAL_H_\n#define DGL_ARRAY_CPU_TRAVERSAL_H_\n\n#include <dgl/graph_interface.h>\n\n#include <stack>\n#include <tuple>\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\n/**\n * @brief Traverse the graph in a breadth-first-search (BFS) order.\n *\n * The queue object must suffice following interface:\n *   Members:\n *   void push(IdType);  // push one node\n *   IdType top();       // get the first node\n *   void pop();           // pop one node\n *   bool empty();         // return true if the queue is empty\n *   size_t size();        // return the size of the queue\n * For example, std::queue<IdType> is a valid queue type.\n *\n * The visit function must be compatible with following interface:\n *   void (*visit)(IdType );\n *\n * The frontier function must be compatible with following interface:\n *   void (*make_frontier)(void);\n *\n * @param graph The graph.\n * @param sources Source nodes.\n * @param reversed If true, BFS follows the in-edge direction\n * @param queue The queue used to do bfs.\n * @param visit The function to call when a node is visited.\n * @param make_frontier The function to indicate that a new froniter can be\n * made;\n */\ntemplate <\n    typename IdType, typename Queue, typename VisitFn, typename FrontierFn>\nvoid BFSTraverseNodes(\n    const CSRMatrix &csr, IdArray source, Queue *queue, VisitFn visit,\n    FrontierFn make_frontier) {\n  const int64_t len = source->shape[0];\n  const IdType *src_data = static_cast<IdType *>(source->data);\n\n  const IdType 
*indptr_data = static_cast<IdType *>(csr.indptr->data);\n  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);\n  const int64_t num_nodes = csr.num_rows;\n  std::vector<bool> visited(num_nodes);\n  for (int64_t i = 0; i < len; ++i) {\n    const IdType u = src_data[i];\n    visited[u] = true;\n    visit(u);\n    queue->push(u);\n  }\n  make_frontier();\n\n  while (!queue->empty()) {\n    const size_t size = queue->size();\n    for (size_t i = 0; i < size; ++i) {\n      const IdType u = queue->top();\n      queue->pop();\n      for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {\n        auto v = indices_data[idx];\n        if (!visited[v]) {\n          visited[v] = true;\n          visit(v);\n          queue->push(v);\n        }\n      }\n    }\n    make_frontier();\n  }\n}\n\n/**\n * @brief Traverse the graph in a breadth-first-search (BFS) order, returning\n *        the edges of the BFS tree.\n *\n * The queue object must suffice following interface:\n *   Members:\n *   void push(IdType);  // push one node\n *   IdType top();       // get the first node\n *   void pop();           // pop one node\n *   bool empty();         // return true if the queue is empty\n *   size_t size();        // return the size of the queue\n * For example, std::queue<IdType> is a valid queue type.\n *\n * The visit function must be compatible with following interface:\n *   void (*visit)(IdType );\n *\n * The frontier function must be compatible with following interface:\n *   void (*make_frontier)(void);\n *\n * @param graph The graph.\n * @param sources Source nodes.\n * @param reversed If true, BFS follows the in-edge direction\n * @param queue The queue used to do bfs.\n * @param visit The function to call when a node is visited.\n *        The argument would be edge ID.\n * @param make_frontier The function to indicate that a new frontier can be\n * made;\n */\ntemplate <\n    typename IdType, typename Queue, typename VisitFn, typename 
FrontierFn>\nvoid BFSTraverseEdges(\n    const CSRMatrix &csr, IdArray source, Queue *queue, VisitFn visit,\n    FrontierFn make_frontier) {\n  const int64_t len = source->shape[0];\n  const IdType *src_data = static_cast<IdType *>(source->data);\n\n  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);\n  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);\n  const IdType *eid_data = static_cast<IdType *>(csr.data->data);\n\n  const int64_t num_nodes = csr.num_rows;\n  std::vector<bool> visited(num_nodes);\n  for (int64_t i = 0; i < len; ++i) {\n    const IdType u = src_data[i];\n    visited[u] = true;\n    queue->push(u);\n  }\n  make_frontier();\n\n  while (!queue->empty()) {\n    const size_t size = queue->size();\n    for (size_t i = 0; i < size; ++i) {\n      const IdType u = queue->top();\n      queue->pop();\n      for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {\n        auto e = eid_data ? eid_data[idx] : idx;\n        const IdType v = indices_data[idx];\n        if (!visited[v]) {\n          visited[v] = true;\n          visit(e);\n          queue->push(v);\n        }\n      }\n    }\n    make_frontier();\n  }\n}\n\n/**\n * @brief Traverse the graph in topological order.\n *\n * The queue object must suffice following interface:\n *   Members:\n *   void push(IdType);  // push one node\n *   IdType top();       // get the first node\n *   void pop();           // pop one node\n *   bool empty();         // return true if the queue is empty\n *   size_t size();        // return the size of the queue\n * For example, std::queue<IdType> is a valid queue type.\n *\n * The visit function must be compatible with following interface:\n *   void (*visit)(IdType );\n *\n * The frontier function must be compatible with following interface:\n *   void (*make_frontier)(void);\n *\n * @param graph The graph.\n * @param reversed If true, follows the in-edge direction\n * @param queue The queue used to do bfs.\n * 
@param visit The function to call when a node is visited.\n * @param make_frontier The function to indicate that a new froniter can be\n * made;\n */\ntemplate <\n    typename IdType, typename Queue, typename VisitFn, typename FrontierFn>\nvoid TopologicalNodes(\n    const CSRMatrix &csr, Queue *queue, VisitFn visit,\n    FrontierFn make_frontier) {\n  int64_t num_visited_nodes = 0;\n  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);\n  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);\n\n  const int64_t num_nodes = csr.num_rows;\n  const int64_t num_edges = csr.indices->shape[0];\n  std::vector<int64_t> degrees(num_nodes, 0);\n  for (int64_t eid = 0; eid < num_edges; ++eid) {\n    degrees[indices_data[eid]]++;\n  }\n\n  for (int64_t vid = 0; vid < num_nodes; ++vid) {\n    if (degrees[vid] == 0) {\n      visit(vid);\n      queue->push(static_cast<IdType>(vid));\n      ++num_visited_nodes;\n    }\n  }\n  make_frontier();\n\n  while (!queue->empty()) {\n    const size_t size = queue->size();\n    for (size_t i = 0; i < size; ++i) {\n      const IdType u = queue->top();\n      queue->pop();\n      for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {\n        const IdType v = indices_data[idx];\n        if (--(degrees[v]) == 0) {\n          visit(v);\n          queue->push(v);\n          ++num_visited_nodes;\n        }\n      }\n    }\n    make_frontier();\n  }\n\n  if (num_visited_nodes != num_nodes) {\n    LOG(FATAL)\n        << \"Error in topological traversal: loop detected in the given graph.\";\n  }\n}\n\n/** @brief Tags for ``DFSEdges``. */\nenum DFSEdgeTag {\n  kForward = 0,\n  kReverse,\n  kNonTree,\n};\n/**\n * @brief Traverse the graph in a depth-first-search (DFS) order.\n *\n * The traversal visit edges in its DFS order. 
Edges have three tags:\n * FORWARD(0), REVERSE(1), NONTREE(2)\n *\n * A FORWARD edge is one in which `u` has been visisted but `v` has not.\n * A REVERSE edge is one in which both `u` and `v` have been visisted and the\n * edge is in the DFS tree. A NONTREE edge is one in which both `u` and `v` have\n * been visisted but the edge is NOT in the DFS tree.\n *\n * @param source Source node.\n * @param reversed If true, DFS follows the in-edge direction\n * @param has_reverse_edge If true, REVERSE edges are included\n * @param has_nontree_edge If true, NONTREE edges are included\n * @param visit The function to call when an edge is visited; the edge id and\n * its tag will be given as the arguments.\n */\ntemplate <typename IdType, typename VisitFn>\nvoid DFSLabeledEdges(\n    const CSRMatrix &csr, IdType source, bool has_reverse_edge,\n    bool has_nontree_edge, VisitFn visit) {\n  const int64_t num_nodes = csr.num_rows;\n  CHECK_GE(num_nodes, source)\n      << \"source \" << source << \" is out of range [0,\" << num_nodes << \"]\";\n  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);\n  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);\n  const IdType *eid_data = static_cast<IdType *>(csr.data->data);\n\n  if (indptr_data[source + 1] - indptr_data[source] == 0) {\n    // no out-going edges from the source node\n    return;\n  }\n\n  typedef std::tuple<IdType, size_t, bool> StackEntry;\n  std::stack<StackEntry> stack;\n  std::vector<bool> visited(num_nodes);\n  visited[source] = true;\n  stack.push(std::make_tuple(source, 0, false));\n  IdType u = 0;\n  int64_t i = 0;\n  bool on_tree = false;\n\n  while (!stack.empty()) {\n    std::tie(u, i, on_tree) = stack.top();\n    const IdType v = indices_data[indptr_data[u] + i];\n    const IdType uv =\n        eid_data ? 
eid_data[indptr_data[u] + i] : indptr_data[u] + i;\n    if (visited[v]) {\n      if (!on_tree && has_nontree_edge) {\n        visit(uv, kNonTree);\n      } else if (on_tree && has_reverse_edge) {\n        visit(uv, kReverse);\n      }\n      stack.pop();\n      // find next one.\n      if (indptr_data[u] + i < indptr_data[u + 1] - 1) {\n        stack.push(std::make_tuple(u, i + 1, false));\n      }\n    } else {\n      visited[v] = true;\n      std::get<2>(stack.top()) = true;\n      visit(uv, kForward);\n      // expand\n      if (indptr_data[v] < indptr_data[v + 1]) {\n        stack.push(std::make_tuple(v, 0, false));\n      }\n    }\n  }\n}\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CPU_TRAVERSAL_H_\n"
  },
  {
    "path": "src/array/cuda/array_cumsum.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/array_cumsum.cu\n * @brief Array cumsum GPU implementation\n */\n#include <dgl/array.h>\n\n#include <cub/cub.cuh>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray CumSum(IdArray array, bool prepend_zero) {\n  const int64_t len = array.NumElements();\n  if (len == 0)\n    return !prepend_zero ? array\n                         : aten::Full(0, 1, array->dtype.bits, array->ctx);\n\n  auto device = runtime::DeviceAPI::Get(array->ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const IdType* in_d = array.Ptr<IdType>();\n  IdArray ret;\n  IdType* out_d = nullptr;\n  if (prepend_zero) {\n    ret = aten::Full(0, len + 1, array->dtype.bits, array->ctx);\n    out_d = ret.Ptr<IdType>() + 1;\n  } else {\n    ret = aten::NewIdArray(len, array->ctx, array->dtype.bits);\n    out_d = ret.Ptr<IdType>();\n  }\n  // Allocate workspace\n  size_t workspace_size = 0;\n  CUDA_CALL(cub::DeviceScan::InclusiveSum(\n      nullptr, workspace_size, in_d, out_d, len, stream));\n  void* workspace = device->AllocWorkspace(array->ctx, workspace_size);\n\n  // Compute cumsum\n  CUDA_CALL(cub::DeviceScan::InclusiveSum(\n      workspace, workspace_size, in_d, out_d, len, stream));\n\n  device->FreeWorkspace(array->ctx, workspace);\n\n  return ret;\n}\n\ntemplate IdArray CumSum<kDGLCUDA, int32_t>(IdArray, bool);\ntemplate IdArray CumSum<kDGLCUDA, int64_t>(IdArray, bool);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/array_index_select.cu",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cpu/array_index_select.cu\n * @brief Array index select GPU implementation\n */\n#include <dgl/array.h>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./array_index_select.cuh\"\n#include \"./utils.h\"\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nNDArray IndexSelect(NDArray array, IdArray index) {\n  const int64_t arr_len = array->shape[0];\n  const int64_t len = index->shape[0];\n  int64_t num_feat = 1;\n  std::vector<int64_t> shape{len};\n  for (int d = 1; d < array->ndim; ++d) {\n    num_feat *= array->shape[d];\n    shape.emplace_back(array->shape[d]);\n  }\n\n  // use index->ctx for pinned array\n  NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx);\n  if (len == 0 || arr_len * num_feat == 0) return ret;\n  DType* ret_data = static_cast<DType*>(ret->data);\n\n  const DType* array_data = static_cast<DType*>(cuda::GetDevicePointer(array));\n  const IdType* idx_data = static_cast<IdType*>(index->data);\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  if (num_feat == 1) {\n    const int nt = cuda::FindNumThreads(len);\n    const int nb = (len + nt - 1) / nt;\n    CUDA_KERNEL_CALL(\n        IndexSelectSingleKernel, nb, nt, 0, stream, array_data, idx_data, len,\n        arr_len, ret_data);\n  } else {\n    dim3 block(256, 1);\n    while (static_cast<int64_t>(block.x) >= 2 * num_feat) {\n      block.x /= 2;\n      block.y *= 2;\n    }\n    const dim3 grid((len + block.y - 1) / block.y);\n    CUDA_KERNEL_CALL(\n        IndexSelectMultiKernel, grid, block, 0, stream, array_data, num_feat,\n        idx_data, len, arr_len, ret_data);\n  }\n  return ret;\n}\n\ntemplate NDArray IndexSelect<kDGLCUDA, int32_t, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCUDA, int32_t, int64_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCUDA, int64_t, 
int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);\n#if BF16_ENABLED\ntemplate NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int32_t>(\n    NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int64_t>(\n    NDArray, IdArray);\n#endif  // BF16_ENABLED\ntemplate NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCUDA, double, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelect<kDGLCUDA, double, int64_t>(NDArray, IdArray);\n\ntemplate <DGLDeviceType XPU, typename DType>\nDType IndexSelect(NDArray array, int64_t index) {\n  auto device = runtime::DeviceAPI::Get(array->ctx);\n  DType ret = static_cast<DType>(0.0f);\n  device->CopyDataFromTo(\n      static_cast<DType*>(array->data) + index, 0, &ret, 0, sizeof(DType),\n      array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);\n  return ret;\n}\n\ntemplate int32_t IndexSelect<kDGLCUDA, int32_t>(NDArray array, int64_t index);\ntemplate int64_t IndexSelect<kDGLCUDA, int64_t>(NDArray array, int64_t index);\ntemplate uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);\ntemplate uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);\ntemplate __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);\n#if BF16_ENABLED\ntemplate __nv_bfloat16 IndexSelect<kDGLCUDA, __nv_bfloat16>(\n    NDArray array, int64_t index);\n#endif  // BF16_ENABLED\ntemplate float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);\ntemplate double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/array_index_select.cuh",
    "content": "/**\n *  Copyright (c) 2021-2022 by Contributors\n * @file array/cuda/array_index_select.cuh\n * @brief Array index select GPU kernel implementation\n */\n\n#ifndef DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_\n#define DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\ntemplate <typename DType, typename IdType>\n__global__ void IndexSelectSingleKernel(\n    const DType* array, const IdType* index, const int64_t length,\n    const int64_t arr_len, DType* out, const int64_t* perm = nullptr) {\n  int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    assert(index[tx] >= 0 && index[tx] < arr_len);\n    const auto out_row = perm ? perm[tx] : tx;\n    out[out_row] = array[index[tx]];\n    tx += stride_x;\n  }\n}\n\ntemplate <typename DType, typename IdType>\n__global__ void IndexSelectMultiKernel(\n    const DType* const array, const int64_t num_feat, const IdType* const index,\n    const int64_t length, const int64_t arr_len, DType* const out,\n    const int64_t* perm = nullptr) {\n  int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;\n\n  const int64_t stride = blockDim.y * gridDim.x;\n\n  while (out_row_index < length) {\n    int64_t col = threadIdx.x;\n    const int64_t in_row = index[out_row_index];\n    assert(in_row >= 0 && in_row < arr_len);\n    const auto out_row = perm ? 
perm[out_row_index] : out_row_index;\n    while (col < num_feat) {\n      out[out_row * num_feat + col] = array[in_row * num_feat + col];\n      col += blockDim.x;\n    }\n    out_row_index += stride;\n  }\n}\n\ntemplate <typename DType, typename IdType>\n__global__ void IndexScatterSingleKernel(\n    const DType* array, const IdType* index, const int64_t length,\n    const int64_t arr_len, DType* out) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    assert(index[tx] >= 0 && index[tx] < arr_len);\n    out[index[tx]] = array[tx];\n    tx += stride_x;\n  }\n}\n\ntemplate <typename DType, typename IdType>\n__global__ void IndexScatterMultiKernel(\n    const DType* const array, const int64_t num_feat, const IdType* const index,\n    const int64_t length, const int64_t arr_len, DType* const out) {\n  int64_t in_row = blockIdx.x * blockDim.y + threadIdx.y;\n\n  const int64_t stride = blockDim.y * gridDim.x;\n\n  while (in_row < length) {\n    int64_t col = threadIdx.x;\n    const int64_t out_row = index[in_row];\n    assert(out_row >= 0 && out_row < arr_len);\n    while (col < num_feat) {\n      out[out_row * num_feat + col] = array[in_row * num_feat + col];\n      col += blockDim.x;\n    }\n    in_row += stride;\n  }\n}\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_\n"
  },
  {
    "path": "src/array/cuda/array_nonzero.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/array_nonzero.cc\n * @brief Array nonzero CPU implementation\n */\n\n#include <dgl/array.h>\n\n#include <cub/cub.cuh>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <typename IdType>\nstruct IsNonZeroIndex {\n  explicit IsNonZeroIndex(const IdType* array) : array_(array) {}\n\n  __device__ bool operator()(const int64_t index) { return array_[index] != 0; }\n\n  const IdType* array_;\n};\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray NonZero(IdArray array) {\n  const auto& ctx = array->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n\n  const int64_t len = array->shape[0];\n  IdArray ret = NewIdArray(len, ctx, 64);\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  const IdType* const in_data = static_cast<const IdType*>(array->data);\n  int64_t* const out_data = static_cast<int64_t*>(ret->data);\n\n  IsNonZeroIndex<IdType> comp(in_data);\n  cub::CountingInputIterator<int64_t> counter(0);\n\n  // room for cub to output on GPU\n  int64_t* d_num_nonzeros =\n      static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));\n\n  size_t temp_size = 0;\n  CUDA_CALL(cub::DeviceSelect::If(\n      nullptr, temp_size, counter, out_data, d_num_nonzeros, len, comp,\n      stream));\n  void* temp = device->AllocWorkspace(ctx, temp_size);\n  CUDA_CALL(cub::DeviceSelect::If(\n      temp, temp_size, counter, out_data, d_num_nonzeros, len, comp, stream));\n  device->FreeWorkspace(ctx, temp);\n\n  // copy number of selected elements from GPU to CPU\n  int64_t num_nonzeros = cuda::GetCUDAScalar(device, ctx, d_num_nonzeros);\n  device->FreeWorkspace(ctx, d_num_nonzeros);\n  device->StreamSync(ctx, stream);\n\n  // truncate array to size\n  return ret.CreateView({num_nonzeros}, ret->dtype, 0);\n}\n\ntemplate IdArray NonZero<kDGLCUDA, 
int32_t>(IdArray);\ntemplate IdArray NonZero<kDGLCUDA, int64_t>(IdArray);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/array_op_impl.cu",
    "content": "/**\n *  Copyright (c) 2020-2021 by Contributors\n * @file array/cuda/array_op_impl.cu\n * @brief Array operator GPU implementation\n */\n#include <dgl/array.h>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"../../runtime/cuda/cuda_hashtable.cuh\"\n#include \"../arith.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\nusing runtime::NDArray;\nusing namespace runtime::cuda;\nnamespace aten {\nnamespace impl {\n\n///////////////////////////// BinaryElewise /////////////////////////////\n\ntemplate <typename IdType, typename Op>\n__global__ void _BinaryElewiseKernel(\n    const IdType* lhs, const IdType* rhs, IdType* out, int64_t length) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    out[tx] = Op::Call(lhs[tx], rhs[tx]);\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType, typename Op>\nIdArray BinaryElewise(IdArray lhs, IdArray rhs) {\n  const int64_t len = lhs->shape[0];\n  IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);\n  const IdType* lhs_data = static_cast<IdType*>(lhs->data);\n  const IdType* rhs_data = static_cast<IdType*>(rhs->data);\n  IdType* ret_data = static_cast<IdType*>(ret->data);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int nt = cuda::FindNumThreads(len);\n  int nb = (len + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      (_BinaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs_data, rhs_data,\n      ret_data, len);\n  return ret;\n}\n\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(\n    IdArray 
lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(\n    IdArray lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(\n    IdArray lhs, IdArray rhs);\n\ntemplate <typename IdType, typename Op>\n__global__ void _BinaryElewiseKernel(\n    const IdType* lhs, IdType rhs, IdType* out, int64_t length) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    out[tx] = Op::Call(lhs[tx], rhs);\n    tx += stride_x;\n  }\n}\n\ntemplate 
<DGLDeviceType XPU, typename IdType, typename Op>\nIdArray BinaryElewise(IdArray lhs, IdType rhs) {\n  const int64_t len = lhs->shape[0];\n  IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);\n  const IdType* lhs_data = static_cast<IdType*>(lhs->data);\n  IdType* ret_data = static_cast<IdType*>(ret->data);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int nt = cuda::FindNumThreads(len);\n  int nb = (len + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      (_BinaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs_data, rhs,\n      ret_data, len);\n  return ret;\n}\n\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(\n    IdArray lhs, int32_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(\n 
   IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(\n    IdArray lhs, int64_t rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(\n    IdArray lhs, int64_t rhs);\n\ntemplate <typename IdType, typename Op>\n__global__ void _BinaryElewiseKernel(\n    IdType lhs, const IdType* rhs, IdType* out, int64_t length) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    out[tx] = Op::Call(lhs, rhs[tx]);\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType, typename Op>\nIdArray BinaryElewise(IdType lhs, IdArray rhs) {\n  const int64_t len = rhs->shape[0];\n  IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);\n  const IdType* rhs_data = static_cast<IdType*>(rhs->data);\n  IdType* ret_data = static_cast<IdType*>(ret->data);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int nt = cuda::FindNumThreads(len);\n  int nb = (len + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      (_BinaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs, rhs_data,\n      ret_data, len);\n  return ret;\n}\n\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(\n    
int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(\n    int32_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(\n    int64_t lhs, IdArray rhs);\ntemplate IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(\n    int64_t lhs, IdArray rhs);\n\ntemplate <typename IdType, typename Op>\n__global__ void _UnaryElewiseKernel(\n    const IdType* lhs, IdType* out, int64_t length) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n  
  out[tx] = Op::Call(lhs[tx]);\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType, typename Op>\nIdArray UnaryElewise(IdArray lhs) {\n  const int64_t len = lhs->shape[0];\n  IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);\n  const IdType* lhs_data = static_cast<IdType*>(lhs->data);\n  IdType* ret_data = static_cast<IdType*>(ret->data);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int nt = cuda::FindNumThreads(len);\n  int nb = (len + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      (_UnaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs_data, ret_data,\n      len);\n  return ret;\n}\n\ntemplate IdArray UnaryElewise<kDGLCUDA, int32_t, arith::Neg>(IdArray lhs);\ntemplate IdArray UnaryElewise<kDGLCUDA, int64_t, arith::Neg>(IdArray lhs);\n\n///////////////////////////// Full /////////////////////////////\n\ntemplate <typename DType>\n__global__ void _FullKernel(DType* out, int64_t length, DType val) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    out[tx] = val;\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename DType>\nNDArray Full(DType val, int64_t length, DGLContext ctx) {\n  NDArray ret = NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx);\n  DType* ret_data = static_cast<DType*>(ret->data);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int nt = cuda::FindNumThreads(length);\n  int nb = (length + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      (_FullKernel<DType>), nb, nt, 0, stream, ret_data, length, val);\n  return ret;\n}\n\ntemplate IdArray Full<kDGLCUDA, int32_t>(\n    int32_t val, int64_t length, DGLContext ctx);\ntemplate IdArray Full<kDGLCUDA, int64_t>(\n    int64_t val, int64_t length, DGLContext ctx);\ntemplate IdArray Full<kDGLCUDA, __half>(\n    __half val, int64_t length, DGLContext ctx);\n#if BF16_ENABLED\ntemplate IdArray Full<kDGLCUDA, __nv_bfloat16>(\n    __nv_bfloat16 val, 
int64_t length, DGLContext ctx);\n#endif  // BF16_ENABLED\ntemplate IdArray Full<kDGLCUDA, float>(\n    float val, int64_t length, DGLContext ctx);\ntemplate IdArray Full<kDGLCUDA, double>(\n    double val, int64_t length, DGLContext ctx);\n\n///////////////////////////// Range /////////////////////////////\n\ntemplate <typename IdType>\n__global__ void _RangeKernel(IdType* out, IdType low, IdType length) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    out[tx] = low + tx;\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray Range(IdType low, IdType high, DGLContext ctx) {\n  CHECK(high >= low) << \"high must be bigger than low\";\n  const IdType length = high - low;\n  IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8);\n  if (length == 0) return ret;\n  IdType* ret_data = static_cast<IdType*>(ret->data);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int nt = cuda::FindNumThreads(length);\n  int nb = (length + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      (_RangeKernel<IdType>), nb, nt, 0, stream, ret_data, low, length);\n  return ret;\n}\n\ntemplate IdArray Range<kDGLCUDA, int32_t>(int32_t, int32_t, DGLContext);\ntemplate IdArray Range<kDGLCUDA, int64_t>(int64_t, int64_t, DGLContext);\n\n///////////////////////////// Relabel_ //////////////////////////////\n\ntemplate <typename IdType>\n__global__ void _RelabelKernel(\n    IdType* out, int64_t length, DeviceOrderedHashTable<IdType> table) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n\n  while (tx < length) {\n    out[tx] = table.Search(out[tx])->local;\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray Relabel_(const std::vector<IdArray>& arrays) {\n  IdArray all_nodes = Concat(arrays);\n  const int64_t total_length = all_nodes->shape[0];\n\n  if (total_length == 0) {\n    return all_nodes;\n  }\n\n  
const auto& ctx = arrays[0]->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  // build node maps and get the induced nodes\n  OrderedHashTable<IdType> node_map(total_length, ctx, stream);\n  int64_t num_induced = 0;\n  int64_t* num_induced_device =\n      static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));\n  IdArray induced_nodes = NewIdArray(total_length, ctx, sizeof(IdType) * 8);\n\n  CUDA_CALL(cudaMemsetAsync(\n      num_induced_device, 0, sizeof(*num_induced_device), stream));\n\n  node_map.FillWithDuplicates(\n      all_nodes.Ptr<IdType>(), all_nodes->shape[0], induced_nodes.Ptr<IdType>(),\n      num_induced_device, stream);\n  // copy using the internal current stream\n  device->CopyDataFromTo(\n      num_induced_device, 0, &num_induced, 0, sizeof(num_induced), ctx,\n      DGLContext{kDGLCPU, 0}, DGLDataType{kDGLInt, 64, 1});\n\n  device->StreamSync(ctx, stream);\n  device->FreeWorkspace(ctx, num_induced_device);\n\n  // resize the induced nodes\n  induced_nodes->shape[0] = num_induced;\n\n  // relabel\n  const int nt = 128;\n  for (IdArray arr : arrays) {\n    const int64_t length = arr->shape[0];\n    int nb = (length + nt - 1) / nt;\n    CUDA_KERNEL_CALL(\n        (_RelabelKernel<IdType>), nb, nt, 0, stream, arr.Ptr<IdType>(), length,\n        node_map.DeviceHandle());\n  }\n\n  return induced_nodes;\n}\n\ntemplate IdArray Relabel_<kDGLCUDA, int32_t>(\n    const std::vector<IdArray>& arrays);\ntemplate IdArray Relabel_<kDGLCUDA, int64_t>(\n    const std::vector<IdArray>& arrays);\n\n///////////////////////////// AsNumBits /////////////////////////////\n\ntemplate <typename InType, typename OutType>\n__global__ void _CastKernel(const InType* in, OutType* out, size_t length) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    out[tx] = in[tx];\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, 
typename IdType>\nIdArray AsNumBits(IdArray arr, uint8_t bits) {\n  const std::vector<int64_t> shape(arr->shape, arr->shape + arr->ndim);\n  IdArray ret = IdArray::Empty(shape, DGLDataType{kDGLInt, bits, 1}, arr->ctx);\n  const int64_t length = ret.NumElements();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int nt = cuda::FindNumThreads(length);\n  int nb = (length + nt - 1) / nt;\n  if (bits == 32) {\n    CUDA_KERNEL_CALL(\n        (_CastKernel<IdType, int32_t>), nb, nt, 0, stream,\n        static_cast<IdType*>(arr->data), static_cast<int32_t*>(ret->data),\n        length);\n  } else {\n    CUDA_KERNEL_CALL(\n        (_CastKernel<IdType, int64_t>), nb, nt, 0, stream,\n        static_cast<IdType*>(arr->data), static_cast<int64_t*>(ret->data),\n        length);\n  }\n  return ret;\n}\n\ntemplate IdArray AsNumBits<kDGLCUDA, int32_t>(IdArray arr, uint8_t bits);\ntemplate IdArray AsNumBits<kDGLCUDA, int64_t>(IdArray arr, uint8_t bits);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/array_scatter.cu",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cuda/array_scatter.cu\n * @brief Array scatter GPU implementation\n */\n#include <dgl/array.h>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <typename DType, typename IdType>\n__global__ void _ScatterKernel(\n    const IdType* index, const DType* value, int64_t length, DType* out) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    out[index[tx]] = value[tx];\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename DType, typename IdType>\nvoid Scatter_(IdArray index, NDArray value, NDArray out) {\n  const int64_t len = index->shape[0];\n  const IdType* idx = index.Ptr<IdType>();\n  const DType* val = value.Ptr<DType>();\n  DType* outd = out.Ptr<DType>();\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const int nt = cuda::FindNumThreads(len);\n  const int nb = (len + nt - 1) / nt;\n  CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream, idx, val, len, outd);\n}\n\ntemplate void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);\n#if BF16_ENABLED\ntemplate void Scatter_<kDGLCUDA, __nv_bfloat16, int32_t>(\n    IdArray, NDArray, NDArray);\n#endif  // BF16_ENABLED\ntemplate void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);\n#if BF16_ENABLED\ntemplate void Scatter_<kDGLCUDA, 
__nv_bfloat16, int64_t>(\n    IdArray, NDArray, NDArray);\n#endif  // BF16_ENABLED\ntemplate void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);\ntemplate void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);\n\n};  // namespace impl\n};  // namespace aten\n};  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/array_sort.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/array_sort.cu\n * @brief Array sort GPU implementation\n */\n#include <dgl/array.h>\n\n#include <cub/cub.cuh>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {\n  const auto& ctx = array->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  const int64_t nitems = array->shape[0];\n  IdArray orig_idx = Range(0, nitems, 64, ctx);\n  IdArray sorted_array = NewIdArray(nitems, ctx, array->dtype.bits);\n  IdArray sorted_idx = NewIdArray(nitems, ctx, 64);\n\n  const IdType* keys_in = array.Ptr<IdType>();\n  const int64_t* values_in = orig_idx.Ptr<int64_t>();\n  IdType* keys_out = sorted_array.Ptr<IdType>();\n  int64_t* values_out = sorted_idx.Ptr<int64_t>();\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  if (num_bits == 0) {\n    num_bits = sizeof(IdType) * 8;\n  }\n\n  // Allocate workspace\n  size_t workspace_size = 0;\n  CUDA_CALL(cub::DeviceRadixSort::SortPairs(\n      nullptr, workspace_size, keys_in, keys_out, values_in, values_out, nitems,\n      0, num_bits, stream));\n  void* workspace = device->AllocWorkspace(ctx, workspace_size);\n\n  // Compute\n  CUDA_CALL(cub::DeviceRadixSort::SortPairs(\n      workspace, workspace_size, keys_in, keys_out, values_in, values_out,\n      nitems, 0, num_bits, stream));\n\n  device->FreeWorkspace(ctx, workspace);\n\n  return std::make_pair(sorted_array, sorted_idx);\n}\n\ntemplate std::pair<IdArray, IdArray> Sort<kDGLCUDA, int32_t>(\n    IdArray, int num_bits);\ntemplate std::pair<IdArray, IdArray> Sort<kDGLCUDA, int64_t>(\n    IdArray, int num_bits);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/atomic.cuh",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file array/cuda/atomic.cuh\n * @brief Atomic functions\n */\n#ifndef DGL_ARRAY_CUDA_ATOMIC_CUH_\n#define DGL_ARRAY_CUDA_ATOMIC_CUH_\n\n#include <cuda_runtime.h>\n\n#include <cassert>\n#include <cstdint>\n#include <cstdio>\n\n#include \"bf16.cuh\"\n#include \"fp16.cuh\"\n\n#if __CUDA_ARCH__ >= 600\n#include <cuda_fp16.h>\n#endif\n\nnamespace dgl {\nnamespace aten {\nnamespace cuda {\n\n// Type trait for selecting code type\ntemplate <int Bytes>\nstruct Code {};\n\ntemplate <>\nstruct Code<2> {\n  typedef unsigned short int Type;  // NOLINT\n};\n\ntemplate <>\nstruct Code<4> {\n  typedef unsigned int Type;  // NOLINT\n};\n\ntemplate <>\nstruct Code<8> {\n  typedef unsigned long long int Type;  // NOLINT\n};\n\n// Helper class for converting to/from atomicCAS compatible types.\ntemplate <typename T>\nstruct Cast {\n  typedef typename Code<sizeof(T)>::Type Type;\n  static __device__ __forceinline__ Type Encode(T val) {\n    return static_cast<Type>(val);\n  }\n  static __device__ __forceinline__ T Decode(Type code) {\n    return static_cast<T>(code);\n  }\n};\n\ntemplate <>\nstruct Cast<half> {\n  typedef Code<sizeof(half)>::Type Type;\n  static __device__ __forceinline__ Type Encode(half val) {\n    return __half_as_ushort(val);\n  }\n  static __device__ __forceinline__ half Decode(Type code) {\n    return __ushort_as_half(code);\n  }\n};\n\n#if BF16_ENABLED\ntemplate <>\nstruct Cast<__nv_bfloat16> {\n  typedef Code<sizeof(__nv_bfloat16)>::Type Type;\n  static __device__ __forceinline__ Type Encode(__nv_bfloat16 val) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n    return __bfloat16_as_ushort(val);\n#else\n    printf(\n        \"Atomic operations are not supported for bfloat16 (BF16) \"\n        \"on GPUs with compute capability less than 8.0.\\n\");\n    __trap();\n    return static_cast<Type>(0);\n#endif\n  }\n  static __device__ __forceinline__ __nv_bfloat16 Decode(Type code) {\n#if 
defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n    return __ushort_as_bfloat16(code);\n#else\n    printf(\n        \"Atomic operations are not supported for bfloat16 (BF16) \"\n        \"on GPUs with compute capability less than 8.0.\\n\");\n    __trap();\n    return static_cast<__nv_bfloat16>(0.0f);\n#endif\n  }\n};\n#endif  // BF16_ENABLED\n\ntemplate <>\nstruct Cast<float> {\n  typedef Code<sizeof(float)>::Type Type;\n  static __device__ __forceinline__ Type Encode(float val) {\n    return __float_as_uint(val);\n  }\n  static __device__ __forceinline__ float Decode(Type code) {\n    return __uint_as_float(code);\n  }\n};\n\ntemplate <>\nstruct Cast<double> {\n  typedef Code<sizeof(double)>::Type Type;\n  static __device__ __forceinline__ Type Encode(double val) {\n    return __double_as_longlong(val);\n  }\n  static __device__ __forceinline__ double Decode(Type code) {\n    return __longlong_as_double(code);\n  }\n};\n\nstatic __device__ __forceinline__ unsigned short int atomicCASshort(  // NOLINT\n    unsigned short int* address,                                      // NOLINT\n    unsigned short int compare,                                       // NOLINT\n    unsigned short int val) {                                         // NOLINT\n  static_assert(CUDART_VERSION >= 10000, \"Requires at least CUDA 10\");\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)\n  return atomicCAS(address, compare, val);\n#else\n  (void)address;\n  (void)compare;\n  (void)val;\n  printf(\n      \"Atomic operations are not supported for half precision (FP16) \"\n      \"on this GPU.\\n\");\n  __trap();\n  return val;\n#endif  // (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)\n}\n\n#define DEFINE_ATOMIC(NAME)                                   \\\n  template <typename T>                                       \\\n  __device__ __forceinline__ T Atomic##NAME(T* addr, T val) { \\\n    typedef typename Cast<T>::Type CT;                        \\\n    CT* addr_as_ui = 
reinterpret_cast<CT*>(addr);             \\\n    CT old = *addr_as_ui;                                     \\\n    CT assumed = old;                                         \\\n    do {                                                      \\\n      assumed = old;                                          \\\n      old = atomicCAS(                                        \\\n          addr_as_ui, assumed,                                \\\n          Cast<T>::Encode(OP(val, Cast<T>::Decode(old))));    \\\n    } while (assumed != old);                                 \\\n    return Cast<T>::Decode(old);                              \\\n  }\n\n#define DEFINE_ATOMIC_16BIT(NAME, dtype)                           \\\n  template <>                                                      \\\n  __device__ __forceinline__ dtype Atomic##NAME<dtype>(            \\\n      dtype * addr, dtype val) {                                   \\\n    typedef uint16_t CT;                                           \\\n    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                  \\\n    CT old = *addr_as_ui;                                          \\\n    CT assumed = old;                                              \\\n    do {                                                           \\\n      assumed = old;                                               \\\n      old = atomicCASshort(                                        \\\n          addr_as_ui, assumed,                                     \\\n          Cast<dtype>::Encode(OP(val, Cast<dtype>::Decode(old)))); \\\n    } while (assumed != old);                                      \\\n    return Cast<dtype>::Decode(old);                               \\\n  }\n\n#define OP(a, b) max(a, b)\nDEFINE_ATOMIC(Max)\nDEFINE_ATOMIC_16BIT(Max, half)\n#if BF16_ENABLED\nDEFINE_ATOMIC_16BIT(Max, __nv_bfloat16)\n#endif  // BF16_ENABLED\n#undef OP\n\n#define OP(a, b) min(a, b)\nDEFINE_ATOMIC(Min)\nDEFINE_ATOMIC_16BIT(Min, half)\n#if 
BF16_ENABLED\nDEFINE_ATOMIC_16BIT(Min, __nv_bfloat16)\n#endif  // BF16_ENABLED\n#undef OP\n\n#define OP(a, b) a + b\nDEFINE_ATOMIC(Add)\n#undef OP\n\n/**\n * @brief Performs an atomic compare-and-swap on 64 bit integers. That is,\n * it reads the word `old` at the memory location `address`, computes\n * `(old == compare ? val : old)` , and stores the result back to memory at\n * the same address.\n *\n * @param address The address to perform the atomic operation on.\n * @param compare The value to compare to.\n * @param val The new value to conditionally store.\n *\n * @return The old value at the address.\n */\ninline __device__ int64_t\nAtomicCAS(int64_t* const address, const int64_t compare, const int64_t val) {\n  // match the type of \"::atomicCAS\", so ignore lint warning\n  using Type = unsigned long long int;  // NOLINT\n\n  static_assert(sizeof(Type) == sizeof(*address), \"Type width must match\");\n\n  return atomicCAS(\n      reinterpret_cast<Type*>(address), static_cast<Type>(compare),\n      static_cast<Type>(val));\n}\n\n/**\n * @brief Performs an atomic compare-and-swap on 32 bit integers. That is,\n * it reads the word `old` at the memory location `address`, computes\n * `(old == compare ? 
val : old)` , and stores the result back to memory at\n * the same address.\n *\n * @param address The address to perform the atomic operation on.\n * @param compare The value to compare to.\n * @param val The new value to conditionally store.\n *\n * @return The old value at the address.\n */\ninline __device__ int32_t\nAtomicCAS(int32_t* const address, const int32_t compare, const int32_t val) {\n  // match the type of \"::atomicCAS\", so ignore lint warning\n  using Type = int;  // NOLINT\n\n  static_assert(sizeof(Type) == sizeof(*address), \"Type width must match\");\n\n  return atomicCAS(\n      reinterpret_cast<Type*>(address), static_cast<Type>(compare),\n      static_cast<Type>(val));\n}\n\ninline __device__ int64_t AtomicMax(int64_t* const address, const int64_t val) {\n  // match the type of \"::atomicMax\", so ignore lint warning\n  using Type = unsigned long long int;  // NOLINT\n\n  static_assert(sizeof(Type) == sizeof(*address), \"Type width must match\");\n\n  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));\n}\n\ninline __device__ int32_t AtomicMax(int32_t* const address, const int32_t val) {\n  // match the type of \"::atomicMax\", so ignore lint warning\n  using Type = int;  // NOLINT\n\n  static_assert(sizeof(Type) == sizeof(*address), \"Type width must match\");\n\n  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));\n}\n\ntemplate <>\n__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {\n#if __CUDA_ARCH__ >= 200\n  return atomicAdd(addr, val);\n#else\n  typedef float T;\n  typedef typename Cast<T>::Type CT;\n  CT* addr_as_ui = reinterpret_cast<CT*>(addr);\n  CT old = *addr_as_ui;\n  CT assumed = old;\n  do {\n    assumed = old;\n    old = atomicCAS(\n        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));\n  } while (assumed != old);\n  return Cast<T>::Decode(old);\n#endif  // __CUDA_ARCH__\n}\n\ntemplate <>\n__device__ __forceinline__ double 
AtomicAdd<double>(double* addr, double val) {\n#if __CUDA_ARCH__ >= 600\n  return atomicAdd(addr, val);\n#else\n  typedef double T;\n  typedef typename Cast<T>::Type CT;\n  CT* addr_as_ui = reinterpret_cast<CT*>(addr);\n  CT old = *addr_as_ui;\n  CT assumed = old;\n  do {\n    assumed = old;\n    old = atomicCAS(\n        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));\n  } while (assumed != old);\n  return Cast<T>::Decode(old);\n#endif\n}\n\n#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000\ntemplate <>\n__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {\n// make sure we have half support\n#if __CUDA_ARCH__ >= 700\n  return atomicAdd(addr, val);\n#else\n  (void)addr;\n  (void)val;\n  printf(\n      \"Atomic operations are not supported for half precision (FP16) \"\n      \"on this GPU.\\n\");\n  __trap();\n  return val;\n#endif  // __CUDA_ARCH__ >= 700\n}\n#endif  // defined(CUDART_VERSION) && CUDART_VERSION >= 10000\n\n#if BF16_ENABLED\ntemplate <>\n__device__ __forceinline__ __nv_bfloat16\nAtomicAdd<__nv_bfloat16>(__nv_bfloat16* addr, __nv_bfloat16 val) {\n// make sure we have bfloat16 support\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n  return atomicAdd(addr, val);\n#else\n  (void)addr;\n  (void)val;\n  printf(\n      \"Atomic operations are not supported for bfloat16 (BF16) \"\n      \"on GPUs with compute capability less than 8.0.\\n\");\n  __trap();\n  return val;\n#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n}\n#endif  // BF16_ENABLED\n\n}  // namespace cuda\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CUDA_ATOMIC_CUH_\n"
  },
  {
    "path": "src/array/cuda/bf16.cuh",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n *\n *  Licensed under the Apache License, Version 2.0 (the \"License\");\n *  you may not use this file except in compliance with the License.\n *  You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n *  Unless required by applicable law or agreed to in writing, software\n *  distributed under the License is distributed on an \"AS IS\" BASIS,\n *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *  See the License for the specific language governing permissions and\n *  limitations under the License.\n *\n * @file array/cuda/bf16.cuh\n * @brief bfloat16 related functions.\n */\n#ifndef DGL_ARRAY_CUDA_BF16_CUH_\n#define DGL_ARRAY_CUDA_BF16_CUH_\n\n#if BF16_ENABLED\n#include <cuda_bf16.h>\n\n#include <algorithm>\n\nstatic __device__ __forceinline__ __nv_bfloat16\nmax(__nv_bfloat16 a, __nv_bfloat16 b) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n  return __hmax(a, b);\n#else\n  return __nv_bfloat16(max(float(a), float(b)));  // NOLINT\n#endif\n}\n\nstatic __device__ __forceinline__ __nv_bfloat16\nmin(__nv_bfloat16 a, __nv_bfloat16 b) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n  return __hmin(a, b);\n#else\n  return __nv_bfloat16(min(float(a), float(b)));  // NOLINT\n#endif\n}\n\n#ifdef __CUDACC__\n// Arithmetic BF16 operations for architecture >= 8.0 are already defined in\n// cuda_bf16.h\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)\n// CUDA 12.2 adds \"emulated\" support for older architectures.\n#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)\n__device__ __forceinline__ __nv_bfloat16\noperator+(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {\n  return __nv_bfloat16(float(lh) + float(rh));  // NOLINT\n}\n__device__ __forceinline__ __nv_bfloat16\noperator-(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {\n  return __nv_bfloat16(float(lh) - float(rh));  // NOLINT\n}\n__device__ 
__forceinline__ __nv_bfloat16\noperator*(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {\n  return __nv_bfloat16(float(lh) * float(rh));  // NOLINT\n}\n__device__ __forceinline__ __nv_bfloat16\noperator/(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {\n  return __nv_bfloat16(float(lh) / float(rh));  // NOLINT\n}\n\n__device__ __forceinline__ __nv_bfloat16& operator+=(\n    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT\n  lh = __nv_bfloat16(float(lh) + float(rh));       // NOLINT\n  return lh;\n}\n__device__ __forceinline__ __nv_bfloat16& operator-=(\n    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT\n  lh = __nv_bfloat16(float(lh) - float(rh));       // NOLINT\n  return lh;\n}\n__device__ __forceinline__ __nv_bfloat16& operator*=(\n    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT\n  lh = __nv_bfloat16(float(lh) * float(rh));       // NOLINT\n  return lh;\n}\n__device__ __forceinline__ __nv_bfloat16& operator/=(\n    __nv_bfloat16& lh, const __nv_bfloat16& rh) {  // NOLINT\n  lh = __nv_bfloat16(float(lh) / float(rh));       // NOLINT\n  return lh;\n}\n\n__device__ __forceinline__ __nv_bfloat16& operator++(\n    __nv_bfloat16& h) {                // NOLINT\n  h = __nv_bfloat16(float(h) + 1.0f);  // NOLINT\n  return h;\n}\n__device__ __forceinline__ __nv_bfloat16& operator--(\n    __nv_bfloat16& h) {                // NOLINT\n  h = __nv_bfloat16(float(h) - 1.0f);  // NOLINT\n  return h;\n}\n__device__ __forceinline__ __nv_bfloat16\noperator++(__nv_bfloat16& h, int) {  // NOLINT\n  __nv_bfloat16 ret = h;\n  h = __nv_bfloat16(float(h) + 1.0f);  // NOLINT\n  return ret;\n}\n__device__ __forceinline__ __nv_bfloat16\noperator--(__nv_bfloat16& h, int) {  // NOLINT\n  __nv_bfloat16 ret = h;\n  h = __nv_bfloat16(float(h) - 1.0f);  // NOLINT\n  return ret;\n}\n\n__device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16& h) {\n  return h;\n}\n__device__ __forceinline__ __nv_bfloat16 operator-(const __nv_bfloat16& h) {\n  
return __nv_bfloat16(-float(h));  // NOLINT\n}\n\n__device__ __forceinline__ bool operator==(\n    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {\n  return float(lh) == float(rh);  // NOLINT\n}\n__device__ __forceinline__ bool operator!=(\n    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {\n  return float(lh) != float(rh);  // NOLINT\n}\n__device__ __forceinline__ bool operator>(\n    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {\n  return float(lh) > float(rh);  // NOLINT\n}\n__device__ __forceinline__ bool operator<(\n    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {\n  return float(lh) < float(rh);  // NOLINT\n}\n__device__ __forceinline__ bool operator>=(\n    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {\n  return float(lh) >= float(rh);  // NOLINT\n}\n__device__ __forceinline__ bool operator<=(\n    const __nv_bfloat16& lh, const __nv_bfloat16& rh) {\n  return float(lh) <= float(rh);  // NOLINT\n}\n#endif  // defined(CUDART_VERSION) && (CUDART_VERSION < 12020)\n#endif  // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)\n#endif  // __CUDACC__\n\n#endif  // BF16_ENABLED\n\n#endif  // DGL_ARRAY_CUDA_BF16_CUH_\n"
  },
  {
    "path": "src/array/cuda/coo2csr.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/coo2csr.cu\n * @brief COO2CSR\n */\n#include <dgl/array.h>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\n\nusing runtime::NDArray;\n\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix COOToCSR(COOMatrix coo) {\n  LOG(FATAL) << \"Unreachable code.\";\n  return {};\n}\n\ntemplate <>\nCSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) {\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  // allocate cusparse handle if needed\n  if (!thr_entry->cusparse_handle) {\n    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));\n  }\n  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));\n\n  bool row_sorted = coo.row_sorted;\n  bool col_sorted = coo.col_sorted;\n  if (!row_sorted) {\n    // we only need to sort the rows to perform conversion\n    coo = COOSort(coo, false);\n    col_sorted = coo.col_sorted;\n  }\n\n  const int64_t nnz = coo.row->shape[0];\n  CHECK_NO_OVERFLOW(coo.row->dtype, nnz);\n  // TODO(minjie): Many of our current implementations assume that CSR must have\n  //   a data array. This is a temporary workaround. 
Remove this after:\n  //   - The old immutable graph implementation is deprecated.\n  //   - The old binary reduce kernel is deprecated.\n  if (!COOHasData(coo))\n    coo.data = aten::Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);\n\n  NDArray indptr =\n      aten::NewIdArray(coo.num_rows + 1, coo.row->ctx, coo.row->dtype.bits);\n  int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);\n  CUSPARSE_CALL(cusparseXcoo2csr(\n      thr_entry->cusparse_handle, coo.row.Ptr<int32_t>(), nnz, coo.num_rows,\n      indptr_ptr, CUSPARSE_INDEX_BASE_ZERO));\n\n  return CSRMatrix(\n      coo.num_rows, coo.num_cols, indptr, coo.col, coo.data, col_sorted);\n}\n\n/**\n * @brief Search for the insertion positions for needle in the hay.\n *\n * The hay is a list of sorted elements and the result is the insertion position\n * of each needle so that the insertion still gives sorted order.\n *\n * It essentially perform binary search to find upper bound for each needle\n * elements.\n *\n * For example:\n * hay = [0, 0, 1, 2, 2]\n * needle = [0, 1, 2, 3]\n * then,\n * out = [2, 3, 5, 5]\n */\ntemplate <typename IdType>\n__global__ void _SortedSearchKernelUpperBound(\n    const IdType* hay, int64_t hay_size, const IdType* needles,\n    int64_t num_needles, IdType* pos) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n  while (tx < num_needles) {\n    const IdType ele = needles[tx];\n    // binary search\n    IdType lo = 0, hi = hay_size;\n    while (lo < hi) {\n      IdType mid = (lo + hi) >> 1;\n      if (hay[mid] <= ele) {\n        lo = mid + 1;\n      } else {\n        hi = mid;\n      }\n    }\n    pos[tx] = lo;\n    tx += stride_x;\n  }\n}\n\ntemplate <>\nCSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo) {\n  const auto& ctx = coo.row->ctx;\n  const auto nbits = coo.row->dtype.bits;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  bool row_sorted = coo.row_sorted;\n  bool col_sorted = coo.col_sorted;\n  if 
(!row_sorted) {\n    coo = COOSort(coo, false);\n    col_sorted = coo.col_sorted;\n  }\n\n  const int64_t nnz = coo.row->shape[0];\n  // TODO(minjie): Many of our current implementation assumes that CSR must have\n  //   a data array. This is a temporary workaround. Remove this after:\n  //   - The old immutable graph implementation is deprecated.\n  //   - The old binary reduce kernel is deprecated.\n  if (!COOHasData(coo))\n    coo.data = aten::Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);\n\n  IdArray rowids = Range(0, coo.num_rows, nbits, ctx);\n  const int nt = cuda::FindNumThreads(coo.num_rows);\n  const int nb = (coo.num_rows + nt - 1) / nt;\n  IdArray indptr = Full(0, coo.num_rows + 1, nbits, ctx);\n  CUDA_KERNEL_CALL(\n      _SortedSearchKernelUpperBound, nb, nt, 0, stream, coo.row.Ptr<int64_t>(),\n      nnz, rowids.Ptr<int64_t>(), coo.num_rows, indptr.Ptr<int64_t>() + 1);\n\n  return CSRMatrix(\n      coo.num_rows, coo.num_cols, indptr, coo.col, coo.data, col_sorted);\n}\n\ntemplate CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo);\ntemplate CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/coo_sort.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/coo_sort.cc\n * @brief Sort COO index\n */\n#include <dgl/array.h>\n\n#include \"../../c_api_common.h\"\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\n\nusing runtime::NDArray;\n\nnamespace aten {\nnamespace impl {\n\n///////////////////////////// COOSort_ /////////////////////////////\n\n/**\n * @brief Encode row and column IDs into a single scalar per edge.\n *\n * @tparam IdType The type to encode as.\n * @param row The row (src) IDs per edge.\n * @param col The column (dst) IDs per edge.\n * @param nnz The number of edges.\n * @param col_bits The number of bits used to encode the destination. The row\n * information is packed into the remaining bits.\n * @param key The encoded edges (output).\n */\ntemplate <typename IdType>\n__global__ void _COOEncodeEdgesKernel(\n    const IdType* const row, const IdType* const col, const int64_t nnz,\n    const int col_bits, IdType* const key) {\n  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;\n\n  if (tx < nnz) {\n    key[tx] = row[tx] << col_bits | col[tx];\n  }\n}\n\n/**\n * @brief Decode row and column IDs from the encoded edges.\n *\n * @tparam IdType The type the edges are encoded as.\n * @param key The encoded edges.\n * @param nnz The number of edges.\n * @param col_bits The number of bits used to store the column/dst ID.\n * @param row The row (src) IDs per edge (output).\n * @param col The col (dst) IDs per edge (output).\n */\ntemplate <typename IdType>\n__global__ void _COODecodeEdgesKernel(\n    const IdType* const key, const int64_t nnz, const int col_bits,\n    IdType* const row, IdType* const col) {\n  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;\n\n  if (tx < nnz) {\n    const IdType k = key[tx];\n    row[tx] = k >> col_bits;\n    col[tx] = k & ((1 << col_bits) - 1);\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid 
COOSort_(COOMatrix* coo, bool sort_column) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const int row_bits = cuda::_NumberOfBits(coo->num_rows);\n\n  const int64_t nnz = coo->row->shape[0];\n  if (sort_column) {\n    const int col_bits = cuda::_NumberOfBits(coo->num_cols);\n    const int num_bits = row_bits + col_bits;\n\n    const int nt = 256;\n    const int nb = (nnz + nt - 1) / nt;\n    CHECK(static_cast<int64_t>(nb) * nt >= nnz);\n\n    IdArray pos = aten::NewIdArray(nnz, coo->row->ctx, coo->row->dtype.bits);\n\n    CUDA_KERNEL_CALL(\n        _COOEncodeEdgesKernel, nb, nt, 0, stream, coo->row.Ptr<IdType>(),\n        coo->col.Ptr<IdType>(), nnz, col_bits, pos.Ptr<IdType>());\n\n    auto sorted = Sort(pos, num_bits);\n\n    CUDA_KERNEL_CALL(\n        _COODecodeEdgesKernel, nb, nt, 0, stream, sorted.first.Ptr<IdType>(),\n        nnz, col_bits, coo->row.Ptr<IdType>(), coo->col.Ptr<IdType>());\n\n    if (aten::COOHasData(*coo))\n      coo->data = IndexSelect(coo->data, sorted.second);\n    else\n      coo->data = AsNumBits(sorted.second, coo->row->dtype.bits);\n    coo->row_sorted = coo->col_sorted = true;\n  } else {\n    const int num_bits = row_bits;\n\n    auto sorted = Sort(coo->row, num_bits);\n\n    coo->row = sorted.first;\n    coo->col = IndexSelect(coo->col, sorted.second);\n\n    if (aten::COOHasData(*coo))\n      coo->data = IndexSelect(coo->data, sorted.second);\n    else\n      coo->data = AsNumBits(sorted.second, coo->row->dtype.bits);\n    coo->row_sorted = true;\n  }\n}\n\ntemplate void COOSort_<kDGLCUDA, int32_t>(COOMatrix* coo, bool sort_column);\ntemplate void COOSort_<kDGLCUDA, int64_t>(COOMatrix* coo, bool sort_column);\n\n///////////////////////////// COOIsSorted /////////////////////////////\n\ntemplate <typename IdType>\n__global__ void _COOIsSortedKernel(\n    const IdType* row, const IdType* col, int64_t nnz, int8_t* row_sorted,\n    int8_t* col_sorted) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  const int 
stride_x = gridDim.x * blockDim.x;\n  while (tx < nnz) {\n    if (tx == 0) {\n      row_sorted[0] = 1;\n      col_sorted[0] = 1;\n    } else {\n      row_sorted[tx] = static_cast<int8_t>(row[tx - 1] <= row[tx]);\n      col_sorted[tx] =\n          static_cast<int8_t>(row[tx - 1] < row[tx] || col[tx - 1] <= col[tx]);\n    }\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<bool, bool> COOIsSorted(COOMatrix coo) {\n  const int64_t nnz = coo.row->shape[0];\n  const auto& ctx = coo.row->ctx;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  auto device = runtime::DeviceAPI::Get(ctx);\n  // We allocate a workspace of 2*nnz bytes. It wastes a little bit memory but\n  // should be fine.\n  int8_t* row_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz));\n  int8_t* col_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz));\n  const int nt = cuda::FindNumThreads(nnz);\n  const int nb = (nnz + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      _COOIsSortedKernel, nb, nt, 0, stream, coo.row.Ptr<IdType>(),\n      coo.col.Ptr<IdType>(), nnz, row_flags, col_flags);\n\n  const bool row_sorted = cuda::AllTrue(row_flags, nnz, ctx);\n  const bool col_sorted =\n      row_sorted ? cuda::AllTrue(col_flags, nnz, ctx) : false;\n\n  device->FreeWorkspace(ctx, row_flags);\n  device->FreeWorkspace(ctx, col_flags);\n\n  return {row_sorted, col_sorted};\n}\n\ntemplate std::pair<bool, bool> COOIsSorted<kDGLCUDA, int32_t>(COOMatrix coo);\ntemplate std::pair<bool, bool> COOIsSorted<kDGLCUDA, int64_t>(COOMatrix coo);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/csr2coo.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/csr2coo.cc\n * @brief CSR2COO\n */\n#include <dgl/array.h>\n#include <thrust/iterator/constant_iterator.h>\n#include <thrust/iterator/counting_iterator.h>\n#include <thrust/iterator/transform_iterator.h>\n\n#include <cub/cub.cuh>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\n\nusing runtime::NDArray;\n\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix CSRToCOO(CSRMatrix csr) {\n  LOG(FATAL) << \"Unreachable codes\";\n  return {};\n}\n\ntemplate <>\nCOOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  // allocate cusparse handle if needed\n  if (!thr_entry->cusparse_handle) {\n    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));\n  }\n  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));\n\n  NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data;\n  const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);\n  NDArray row =\n      aten::NewIdArray(indices->shape[0], indptr->ctx, indptr->dtype.bits);\n  int32_t* row_ptr = static_cast<int32_t*>(row->data);\n\n  CUSPARSE_CALL(cusparseXcsr2coo(\n      thr_entry->cusparse_handle, indptr_ptr, indices->shape[0], csr.num_rows,\n      row_ptr, CUSPARSE_INDEX_BASE_ZERO));\n\n  return COOMatrix(\n      csr.num_rows, csr.num_cols, row, indices, data, true, csr.sorted);\n}\n\nstruct RepeatIndex {\n  template <typename IdType>\n  __host__ __device__ auto operator()(IdType i) {\n    return thrust::make_constant_iterator(i);\n  }\n};\n\ntemplate <typename IdType>\nstruct OutputBufferIndexer {\n  const IdType* indptr;\n  IdType* buffer;\n  __host__ __device__ auto operator()(IdType i) { return buffer + indptr[i]; }\n};\n\ntemplate <typename IdType>\nstruct AdjacentDifference {\n  const 
IdType* indptr;\n  __host__ __device__ auto operator()(IdType i) {\n    return indptr[i + 1] - indptr[i];\n  }\n};\n\ntemplate <>\nCOOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {\n  const auto& ctx = csr.indptr->ctx;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  const int64_t nnz = csr.indices->shape[0];\n  const auto nbits = csr.indptr->dtype.bits;\n  IdArray ret_row = NewIdArray(nnz, ctx, nbits);\n\n  runtime::CUDAWorkspaceAllocator allocator(csr.indptr->ctx);\n  thrust::counting_iterator<int64_t> iota(0);\n\n  auto input_buffer = thrust::make_transform_iterator(iota, RepeatIndex{});\n  auto output_buffer = thrust::make_transform_iterator(\n      iota, OutputBufferIndexer<int64_t>{\n                csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>()});\n  auto buffer_sizes = thrust::make_transform_iterator(\n      iota, AdjacentDifference<int64_t>{csr.indptr.Ptr<int64_t>()});\n\n  constexpr int64_t max_copy_at_once = std::numeric_limits<int32_t>::max();\n  for (int64_t i = 0; i < csr.num_rows; i += max_copy_at_once) {\n    std::size_t temp_storage_bytes = 0;\n    CUDA_CALL(cub::DeviceCopy::Batched(\n        nullptr, temp_storage_bytes, input_buffer + i, output_buffer + i,\n        buffer_sizes + i, std::min(csr.num_rows - i, max_copy_at_once),\n        stream));\n\n    auto temp = allocator.alloc_unique<char>(temp_storage_bytes);\n\n    CUDA_CALL(cub::DeviceCopy::Batched(\n        temp.get(), temp_storage_bytes, input_buffer + i, output_buffer + i,\n        buffer_sizes + i, std::min(csr.num_rows - i, max_copy_at_once),\n        stream));\n  }\n\n  return COOMatrix(\n      csr.num_rows, csr.num_cols, ret_row, csr.indices, csr.data, true,\n      csr.sorted);\n}\n\ntemplate COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr);\ntemplate COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {\n  LOG(FATAL) << \"Unreachable codes\";\n  return 
{};\n}\n\ntemplate <>\nCOOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {\n  COOMatrix coo = CSRToCOO<kDGLCUDA, int32_t>(csr);\n  if (aten::IsNullArray(coo.data)) return coo;\n\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  auto device = runtime::DeviceAPI::Get(coo.row->ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  // allocate cusparse handle if needed\n  if (!thr_entry->cusparse_handle) {\n    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));\n  }\n  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));\n\n  NDArray row = coo.row, col = coo.col, data = coo.data;\n  int32_t* row_ptr = static_cast<int32_t*>(row->data);\n  int32_t* col_ptr = static_cast<int32_t*>(col->data);\n  int32_t* data_ptr = static_cast<int32_t*>(data->data);\n\n  size_t workspace_size = 0;\n  CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt(\n      thr_entry->cusparse_handle, coo.num_rows, coo.num_cols, row->shape[0],\n      data_ptr, row_ptr, &workspace_size));\n  void* workspace = device->AllocWorkspace(row->ctx, workspace_size);\n  CUSPARSE_CALL(cusparseXcoosortByRow(\n      thr_entry->cusparse_handle, coo.num_rows, coo.num_cols, row->shape[0],\n      data_ptr, row_ptr, col_ptr, workspace));\n  device->FreeWorkspace(row->ctx, workspace);\n\n  // The row and column field have already been reordered according\n  // to data, thus the data field will be deprecated.\n  coo.data = aten::NullArray();\n  coo.row_sorted = false;\n  coo.col_sorted = false;\n  return coo;\n}\n\ntemplate <>\nCOOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr) {\n  COOMatrix coo = CSRToCOO<kDGLCUDA, int64_t>(csr);\n  if (aten::IsNullArray(coo.data)) return coo;\n  const auto& sorted = Sort(coo.data);\n\n  coo.row = IndexSelect(coo.row, sorted.second);\n  coo.col = IndexSelect(coo.col, sorted.second);\n\n  // The row and column field have already been reordered according\n  // to data, thus the data field will be deprecated.\n  
coo.data = aten::NullArray();\n  coo.row_sorted = false;\n  coo.col_sorted = false;\n  return coo;\n}\n\ntemplate COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr);\ntemplate COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/csr_get_data.cu",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file array/cuda/csr_get_data.cu\n * @brief Retrieve entries of a CSR matrix\n */\n#include <dgl/array.h>\n\n#include <numeric>\n#include <unordered_set>\n#include <vector>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\n\nusing runtime::NDArray;\n\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType, typename DType>\nNDArray CSRGetData(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, DType filler) {\n  const int64_t rowlen = rows->shape[0];\n  const int64_t collen = cols->shape[0];\n\n  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))\n      << \"Invalid row and col id array.\";\n\n  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;\n  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;\n\n  const int64_t rstlen = std::max(rowlen, collen);\n  IdArray rst = NDArray::Empty({rstlen}, weights->dtype, rows->ctx);\n  if (rstlen == 0) return rst;\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const int nt = cuda::FindNumThreads(rstlen);\n  const int nb = (rstlen + nt - 1) / nt;\n  if (return_eids)\n    BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype)\n        << \"DType does not match row's dtype.\";\n\n  const IdType* indptr_data =\n      static_cast<IdType*>(cuda::GetDevicePointer(csr.indptr));\n  const IdType* indices_data =\n      static_cast<IdType*>(cuda::GetDevicePointer(csr.indices));\n  const IdType* data_data =\n      CSRHasData(csr) ? static_cast<IdType*>(cuda::GetDevicePointer(csr.data))\n                      : nullptr;\n\n  // TODO(minjie): use binary search for sorted csr\n  CUDA_KERNEL_CALL(\n      cuda::_LinearSearchKernel, nb, nt, 0, stream, indptr_data, indices_data,\n      data_data, rows.Ptr<IdType>(), cols.Ptr<IdType>(), row_stride, col_stride,\n      rstlen, return_eids ? 
nullptr : weights.Ptr<DType>(), filler,\n      rst.Ptr<DType>());\n  return rst;\n}\n\ntemplate NDArray CSRGetData<kDGLCUDA, int32_t, __half>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, __half filler);\ntemplate NDArray CSRGetData<kDGLCUDA, int64_t, __half>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, __half filler);\n#if BF16_ENABLED\ntemplate NDArray CSRGetData<kDGLCUDA, int32_t, __nv_bfloat16>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, __nv_bfloat16 filler);\ntemplate NDArray CSRGetData<kDGLCUDA, int64_t, __nv_bfloat16>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, __nv_bfloat16 filler);\n#endif  // BF16_ENABLED\ntemplate NDArray CSRGetData<kDGLCUDA, int32_t, float>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, float filler);\ntemplate NDArray CSRGetData<kDGLCUDA, int64_t, float>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, float filler);\ntemplate NDArray CSRGetData<kDGLCUDA, int32_t, double>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, double filler);\ntemplate NDArray CSRGetData<kDGLCUDA, int64_t, double>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, double filler);\n\n// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)\ntemplate NDArray CSRGetData<kDGLCUDA, int32_t, int32_t>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, int32_t filler);\ntemplate NDArray CSRGetData<kDGLCUDA, int64_t, int64_t>(\n    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,\n    NDArray weights, int64_t filler);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/csr_mm.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/csr_mm.cu\n * @brief SpSpMM/SpGEMM C APIs and definitions.\n */\n#include <dgl/array.h>\n#include <dgl/runtime/device_api.h>\n\n#include <limits>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./cusparse_dispatcher.cuh\"\n#include \"./functor.cuh\"\nnamespace dgl {\n\nusing namespace dgl::runtime;\n\nnamespace aten {\nnamespace cusparse {\n\n#if CUDART_VERSION >= 12000\n\n/** @brief Cusparse implementation of SpGEMM on Csr format for CUDA 12.0+ */\ntemplate <typename DType, typename IdType>\nstd::pair<CSRMatrix, NDArray> CusparseSpgemm(\n    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,\n    const NDArray B_weights_array) {\n  // We use Spgemm (SpSpMM) to perform following operation:\n  // C = A x B, where A, B and C are sparse matrices in csr format.\n  const int nnzA = A.indices->shape[0];\n  const int nnzB = B.indices->shape[0];\n  const DType alpha = 1.0;\n  const DType beta = 0.0;\n  auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;\n  auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;\n  // device\n  auto ctx = A.indptr->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const DType* A_weights = A_weights_array.Ptr<DType>();\n  const DType* B_weights = B_weights_array.Ptr<DType>();\n  // allocate cusparse handle if needed\n  if (!thr_entry->cusparse_handle) {\n    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));\n  }\n  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));\n  // all one data array\n  cusparseSpMatDescr_t matA, matB, matC;\n  IdArray dC_csrOffsets =\n      IdArray::Empty({A.num_rows + 1}, A.indptr->dtype, A.indptr->ctx);\n  IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();\n  constexpr auto idtype = cusparse_idtype<IdType>::value;\n  constexpr auto dtype = 
cuda_dtype<DType>::value;\n  // Create sparse matrix A, B and C in CSR format\n  CUSPARSE_CALL(cusparseCreateCsr(\n      &matA, A.num_rows, A.num_cols, nnzA, A.indptr.Ptr<IdType>(),\n      A.indices.Ptr<IdType>(),\n      // cusparseCreateCsr only accepts non-const pointers.\n      const_cast<DType*>(A_weights), idtype, idtype, CUSPARSE_INDEX_BASE_ZERO,\n      dtype));\n  CUSPARSE_CALL(cusparseCreateCsr(\n      &matB, B.num_rows, B.num_cols, nnzB, B.indptr.Ptr<IdType>(),\n      B.indices.Ptr<IdType>(),\n      // cusparseCreateCsr only accepts non-const pointers.\n      const_cast<DType*>(B_weights), idtype, idtype, CUSPARSE_INDEX_BASE_ZERO,\n      dtype));\n  CUSPARSE_CALL(cusparseCreateCsr(\n      &matC, A.num_rows, B.num_cols, 0, dC_csrOffsets_data, nullptr, nullptr,\n      idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));\n  // SpGEMM Computation\n  cusparseSpGEMMDescr_t spgemmDesc;\n  cusparseSpGEMMAlg_t alg = CUSPARSE_SPGEMM_DEFAULT;\n\n  CUSPARSE_CALL(cusparseSpGEMM_createDescr(&spgemmDesc));\n  size_t workspace_size1 = 0, workspace_size2 = 0, workspace_size3 = 0;\n  // ask bufferSize1 bytes for external memory\n  CUSPARSE_CALL(cusparseSpGEMM_workEstimation(\n      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n      matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));\n  void* workspace1 = (device->AllocWorkspace(ctx, workspace_size1));\n  // inspect the matrices A and B to understand the memory requirement\n  cusparseStatus_t e = cusparseSpGEMM_workEstimation(\n      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n      matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1);\n  // CUSPARSE_SPGEMM_DEFAULT does not support getting num_prods > 2^31 - 1\n  // and throws insufficient memory error within workEstimation call\n  if (e == CUSPARSE_STATUS_INSUFFICIENT_RESOURCES) {\n    // fall back to ALG2 to estimate num_prods\n    alg = CUSPARSE_SPGEMM_ALG2;\n    device->FreeWorkspace(ctx, workspace1);\n    // 
rerun cusparseSpGEMM_workEstimation\n    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(\n        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n        matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));\n    workspace1 = (device->AllocWorkspace(ctx, workspace_size1));\n    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(\n        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n        matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1));\n  } else {\n    CHECK(e == CUSPARSE_STATUS_SUCCESS) << \"CUSPARSE ERROR in SpGEMM: \" << e;\n  }\n\n  // get the number of intermediate products required for SpGEMM compute\n  // num_prods indicates device memory consumption for SpGEMM if using ALG2/3\n  int64_t num_prods;\n  CUSPARSE_CALL(cusparseSpGEMM_getNumProducts(spgemmDesc, &num_prods));\n\n  // assume free GPU mem at least ~15G for below heuristics to work\n  // user-defined medium problem size (below will use DEFAULT)\n  int64_t MEDIUM_NUM_PRODUCTS = 400000000;  // 400*1000*1000;\n  // user-defined large problem size (above will use ALG3)\n  int64_t LARGE_NUM_PRODUCTS = 800000000;  // 800*1000*1000;\n\n  // switch to ALG2/ALG3 for medium & large problem size\n  if (alg == CUSPARSE_SPGEMM_DEFAULT && num_prods > MEDIUM_NUM_PRODUCTS) {\n    // use ALG3 for very large problem\n    alg = num_prods > LARGE_NUM_PRODUCTS ? 
CUSPARSE_SPGEMM_ALG3\n                                         : CUSPARSE_SPGEMM_ALG2;\n\n    device->FreeWorkspace(ctx, workspace1);\n    // rerun cusparseSpGEMM_workEstimation\n    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(\n        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n        matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));\n    workspace1 = (device->AllocWorkspace(ctx, workspace_size1));\n    CUSPARSE_CALL(cusparseSpGEMM_workEstimation(\n        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n        matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1));\n  } else if (alg == CUSPARSE_SPGEMM_ALG2 && num_prods > LARGE_NUM_PRODUCTS) {\n    // no need to rerun cusparseSpGEMM_workEstimation between ALG2 and ALG3\n    alg = CUSPARSE_SPGEMM_ALG3;\n  }\n\n  if (alg == CUSPARSE_SPGEMM_ALG2 || alg == CUSPARSE_SPGEMM_ALG3) {\n    // estimate memory for ALG2/ALG3; note chunk_fraction is only used by ALG3\n    // reduce chunk_fraction if crash due to mem., but it trades off speed\n    float chunk_fraction = num_prods < 4 * LARGE_NUM_PRODUCTS ? 
0.15 : 0.05;\n    CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(\n        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n        matC, dtype, alg, spgemmDesc, chunk_fraction, &workspace_size3, NULL,\n        NULL));\n    void* workspace3 = (device->AllocWorkspace(ctx, workspace_size3));\n    CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(\n        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n        matC, dtype, alg, spgemmDesc, chunk_fraction, &workspace_size3,\n        workspace3, &workspace_size2));\n    device->FreeWorkspace(ctx, workspace3);\n  } else {\n    CUSPARSE_CALL(cusparseSpGEMM_compute(\n        thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n        matC, dtype, alg, spgemmDesc, &workspace_size2, NULL));\n  }\n  // ask bufferSize2 bytes for external memory\n  void* workspace2 = device->AllocWorkspace(ctx, workspace_size2);\n  // compute the intermediate product of A * B\n  CUSPARSE_CALL(cusparseSpGEMM_compute(\n      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n      matC, dtype, alg, spgemmDesc, &workspace_size2, workspace2));\n  // get matrix C non-zero entries C_nnz1\n  int64_t C_num_rows1, C_num_cols1, C_nnz1;\n  CUSPARSE_CALL(\n      cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_nnz1));\n  IdArray dC_columns = IdArray::Empty({C_nnz1}, A.indptr->dtype, A.indptr->ctx);\n  NDArray dC_weights =\n      NDArray::Empty({C_nnz1}, A_weights_array->dtype, A.indptr->ctx);\n  IdType* dC_columns_data = dC_columns.Ptr<IdType>();\n  DType* dC_weights_data = dC_weights.Ptr<DType>();\n  // update matC with the new pointers\n  CUSPARSE_CALL(cusparseCsrSetPointers(\n      matC, dC_csrOffsets_data, dC_columns_data, dC_weights_data));\n  // copy the final products to the matrix C\n  CUSPARSE_CALL(cusparseSpGEMM_copy(\n      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n      matC, dtype, alg, spgemmDesc));\n\n  
device->FreeWorkspace(ctx, workspace1);\n  device->FreeWorkspace(ctx, workspace2);\n  // destroy matrix/vector descriptors\n  CUSPARSE_CALL(cusparseSpGEMM_destroyDescr(spgemmDesc));\n  CUSPARSE_CALL(cusparseDestroySpMat(matA));\n  CUSPARSE_CALL(cusparseDestroySpMat(matB));\n  CUSPARSE_CALL(cusparseDestroySpMat(matC));\n  return {\n      CSRMatrix(\n          A.num_rows, B.num_cols, dC_csrOffsets, dC_columns,\n          NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx)),\n      dC_weights};\n}\n\n#else  // CUDART_VERSION < 12000\n\n/** @brief Cusparse implementation of SpGEMM on Csr format for older CUDA\n * versions */\ntemplate <typename DType, typename IdType>\nstd::pair<CSRMatrix, NDArray> CusparseSpgemm(\n    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,\n    const NDArray B_weights_array) {\n  int nnzC;\n  csrgemm2Info_t info = nullptr;\n  size_t workspace_size;\n  const DType alpha = 1.;\n  const int nnzA = A.indices->shape[0];\n  const int nnzB = B.indices->shape[0];\n  const int m = A.num_rows;\n  const int n = A.num_cols;\n  const int k = B.num_cols;\n  auto ctx = A.indptr->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  auto idtype = A.indptr->dtype;\n  auto dtype = A_weights_array->dtype;\n  const DType* A_weights = A_weights_array.Ptr<DType>();\n  const DType* B_weights = B_weights_array.Ptr<DType>();\n  if (!thr_entry->cusparse_handle) {\n    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));\n  }\n  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));\n  CUSPARSE_CALL(cusparseSetPointerMode(\n      thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST));\n\n  CUSPARSE_CALL(cusparseCreateCsrgemm2Info(&info));\n\n  cusparseMatDescr_t matA, matB, matC, matD;\n  CUSPARSE_CALL(cusparseCreateMatDescr(&matA));\n  CUSPARSE_CALL(cusparseCreateMatDescr(&matB));\n  
CUSPARSE_CALL(cusparseCreateMatDescr(&matC));\n  CUSPARSE_CALL(cusparseCreateMatDescr(&matD));  // needed even if D is null\n\n  CUSPARSE_CALL(CSRGEMM<DType>::bufferSizeExt(\n      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA,\n      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB,\n      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,\n      nullptr, nullptr, info, &workspace_size));\n\n  void* workspace = device->AllocWorkspace(ctx, workspace_size);\n  IdArray C_indptr = IdArray::Empty({m + 1}, idtype, ctx);\n  CUSPARSE_CALL(CSRGEMM<DType>::nnz(\n      thr_entry->cusparse_handle, m, n, k, matA, nnzA, A.indptr.Ptr<IdType>(),\n      A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),\n      B.indices.Ptr<IdType>(), matD, 0, nullptr, nullptr, matC,\n      C_indptr.Ptr<IdType>(), &nnzC, info, workspace));\n\n  IdArray C_indices = IdArray::Empty({nnzC}, idtype, ctx);\n  NDArray C_weights = NDArray::Empty({nnzC}, dtype, ctx);\n  CUSPARSE_CALL(CSRGEMM<DType>::compute(\n      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA, A_weights,\n      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB, B_weights,\n      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,\n      nullptr, nullptr, nullptr, matC, C_weights.Ptr<DType>(),\n      C_indptr.Ptr<IdType>(), C_indices.Ptr<IdType>(), info, workspace));\n\n  device->FreeWorkspace(ctx, workspace);\n  CUSPARSE_CALL(cusparseDestroyCsrgemm2Info(info));\n  CUSPARSE_CALL(cusparseDestroyMatDescr(matA));\n  CUSPARSE_CALL(cusparseDestroyMatDescr(matB));\n  CUSPARSE_CALL(cusparseDestroyMatDescr(matC));\n  CUSPARSE_CALL(cusparseDestroyMatDescr(matD));\n\n  return {\n      CSRMatrix(\n          m, k, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),\n      C_weights};\n}\n\n#endif  // CUDART_VERSION >= 12000\n}  // namespace cusparse\n\ntemplate <int XPU, typename IdType, typename DType>\nstd::pair<CSRMatrix, NDArray> CSRMM(\n    const 
CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,\n    NDArray B_weights) {\n  auto ctx = A.indptr->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  CSRMatrix newA, newB;\n  bool cast = false;\n\n  // Cast 64 bit indices to 32 bit.\n  if (A.indptr->dtype.bits == 64) {\n    newA = CSRMatrix(\n        A.num_rows, A.num_cols, AsNumBits(A.indptr, 32),\n        AsNumBits(A.indices, 32), AsNumBits(A.data, 32));\n    newB = CSRMatrix(\n        B.num_rows, B.num_cols, AsNumBits(B.indptr, 32),\n        AsNumBits(B.indices, 32), AsNumBits(B.data, 32));\n    cast = true;\n  }\n\n  // Reorder weights if A or B has edge IDs\n  NDArray newA_weights, newB_weights;\n  if (CSRHasData(A)) newA_weights = IndexSelect(A_weights, A.data);\n  if (CSRHasData(B)) newB_weights = IndexSelect(B_weights, B.data);\n\n  auto result = cusparse::CusparseSpgemm<DType, int32_t>(\n      cast ? newA : A, CSRHasData(A) ? newA_weights : A_weights,\n      cast ? newB : B, CSRHasData(B) ? newB_weights : B_weights);\n\n  // Cast 32 bit indices back to 64 bit if necessary\n  if (cast) {\n    CSRMatrix C = result.first;\n    return {\n        CSRMatrix(\n            C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),\n            AsNumBits(C.indices, 64), AsNumBits(C.data, 64)),\n        result.second};\n  } else {\n    return result;\n  }\n}\n\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(\n    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(\n    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\n#if BF16_ENABLED\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\n#endif  // BF16_ENABLED\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(\n    
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(\n    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, double>(\n    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\ntemplate std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, double>(\n    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/csr_sort.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/csr_sort.cc\n * @brief Sort CSR index\n */\n#include <dgl/array.h>\n\n#include <cub/cub.cuh>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\n\nusing runtime::NDArray;\n\nnamespace aten {\nnamespace impl {\n\n/**\n * @brief Check whether each row is sorted.\n */\ntemplate <typename IdType>\n__global__ void _SegmentIsSorted(\n    const IdType* indptr, const IdType* indices, int64_t num_rows,\n    int8_t* flags) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n  while (tx < num_rows) {\n    bool f = true;\n    for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) {\n      f = (indices[i - 1] <= indices[i]);\n    }\n    flags[tx] = static_cast<int8_t>(f);\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool CSRIsSorted(CSRMatrix csr) {\n  const auto& ctx = csr.indptr->ctx;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  auto device = runtime::DeviceAPI::Get(ctx);\n  // We allocate a workspace of num_rows bytes. 
It wastes a little bit memory\n  // but should be fine.\n  int8_t* flags =\n      static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));\n  const int nt = cuda::FindNumThreads(csr.num_rows);\n  const int nb = (csr.num_rows + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      _SegmentIsSorted, nb, nt, 0, stream, csr.indptr.Ptr<IdType>(),\n      csr.indices.Ptr<IdType>(), csr.num_rows, flags);\n  bool ret = cuda::AllTrue(flags, csr.num_rows, ctx);\n  device->FreeWorkspace(ctx, flags);\n  return ret;\n}\n\ntemplate bool CSRIsSorted<kDGLCUDA, int32_t>(CSRMatrix csr);\ntemplate bool CSRIsSorted<kDGLCUDA, int64_t>(CSRMatrix csr);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid CSRSort_(CSRMatrix* csr) {\n  LOG(FATAL) << \"Unreachable codes\";\n}\n\ntemplate <>\nvoid CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) {\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  // allocate cusparse handle if needed\n  if (!thr_entry->cusparse_handle) {\n    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));\n  }\n  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));\n\n  NDArray indptr = csr->indptr;\n  NDArray indices = csr->indices;\n  const auto& ctx = indptr->ctx;\n  const int64_t nnz = indices->shape[0];\n  if (!aten::CSRHasData(*csr))\n    csr->data = aten::Range(0, nnz, indices->dtype.bits, ctx);\n  NDArray data = csr->data;\n\n  size_t workspace_size = 0;\n  CUSPARSE_CALL(cusparseXcsrsort_bufferSizeExt(\n      thr_entry->cusparse_handle, csr->num_rows, csr->num_cols, nnz,\n      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(), &workspace_size));\n  void* workspace = device->AllocWorkspace(ctx, workspace_size);\n\n  cusparseMatDescr_t descr;\n  CUSPARSE_CALL(cusparseCreateMatDescr(&descr));\n  CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));\n  CUSPARSE_CALL(cusparseSetMatIndexBase(descr, 
CUSPARSE_INDEX_BASE_ZERO));\n  CUSPARSE_CALL(cusparseXcsrsort(\n      thr_entry->cusparse_handle, csr->num_rows, csr->num_cols, nnz, descr,\n      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(), data.Ptr<int32_t>(),\n      workspace));\n\n  csr->sorted = true;\n\n  // free resources\n  CUSPARSE_CALL(cusparseDestroyMatDescr(descr));\n  device->FreeWorkspace(ctx, workspace);\n}\n\ntemplate <>\nvoid CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);\n\n  const auto& ctx = csr->indptr->ctx;\n  const int64_t nnz = csr->indices->shape[0];\n  const auto nbits = csr->indptr->dtype.bits;\n  if (!aten::CSRHasData(*csr)) csr->data = aten::Range(0, nnz, nbits, ctx);\n\n  IdArray new_indices = csr->indices.Clone();\n  IdArray new_data = csr->data.Clone();\n\n  const int64_t* offsets = csr->indptr.Ptr<int64_t>();\n  const int64_t* key_in = csr->indices.Ptr<int64_t>();\n  int64_t* key_out = new_indices.Ptr<int64_t>();\n  const int64_t* value_in = csr->data.Ptr<int64_t>();\n  int64_t* value_out = new_data.Ptr<int64_t>();\n\n  // Allocate workspace\n  size_t workspace_size = 0;\n  CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(\n      nullptr, workspace_size, key_in, key_out, value_in, value_out, nnz,\n      csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t) * 8, stream));\n  void* workspace = device->AllocWorkspace(ctx, workspace_size);\n\n  // Compute\n  CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(\n      workspace, workspace_size, key_in, key_out, value_in, value_out, nnz,\n      csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t) * 8, stream));\n\n  csr->sorted = true;\n  csr->indices = new_indices;\n  csr->data = new_data;\n\n  // free resources\n  device->FreeWorkspace(ctx, workspace);\n}\n\ntemplate void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr);\ntemplate void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr);\n\n}  // namespace impl\n}  // 
namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/csr_sum.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/spmm.cu\n * @brief SpGEAM C APIs and definitions.\n */\n#include <dgl/array.h>\n#include <dgl/runtime/device_api.h>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./cusparse_dispatcher.cuh\"\n#include \"./functor.cuh\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\n\nnamespace aten {\nnamespace cusparse {\n\n/** Cusparse implementation of SpSum on Csr format. */\ntemplate <typename DType, typename IdType>\nstd::pair<CSRMatrix, NDArray> CusparseCsrgeam2(\n    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,\n    const NDArray B_weights_array) {\n  const int m = A.num_rows;\n  const int n = A.num_cols;\n  const int nnzA = A.indices->shape[0];\n  const int nnzB = B.indices->shape[0];\n  int nnzC;\n  const DType alpha = 1.0;\n  const DType beta = 1.0;\n  auto ctx = A.indptr->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const DType* A_weights = A_weights_array.Ptr<DType>();\n  const DType* B_weights = B_weights_array.Ptr<DType>();\n  // allocate cusparse handle if needed\n  if (!thr_entry->cusparse_handle)\n    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));\n  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));\n\n  cusparseMatDescr_t matA, matB, matC;\n  CUSPARSE_CALL(cusparseCreateMatDescr(&matA));\n  CUSPARSE_CALL(cusparseCreateMatDescr(&matB));\n  CUSPARSE_CALL(cusparseCreateMatDescr(&matC));\n\n  cusparseSetPointerMode(\n      thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST);\n  size_t workspace_size = 0;\n  /* prepare output C */\n  IdArray dC_csrOffsets = IdArray::Empty({m + 1}, A.indptr->dtype, ctx);\n  IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();\n  IdArray dC_columns;\n  NDArray dC_weights;\n  IdType* dC_columns_data = dC_columns.Ptr<IdType>();\n  DType* 
dC_weights_data = dC_weights.Ptr<DType>();\n  /* prepare buffer */\n  CUSPARSE_CALL(CSRGEAM<DType>::bufferSizeExt(\n      thr_entry->cusparse_handle, m, n, &alpha, matA, nnzA, A_weights,\n      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), &beta, matB, nnzB,\n      B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), matC,\n      dC_weights_data, dC_csrOffsets_data, dC_columns_data, &workspace_size));\n\n  void* workspace = device->AllocWorkspace(ctx, workspace_size);\n  CUSPARSE_CALL(CSRGEAM<DType>::nnz(\n      thr_entry->cusparse_handle, m, n, matA, nnzA, A.indptr.Ptr<IdType>(),\n      A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),\n      B.indices.Ptr<IdType>(), matC, dC_csrOffsets_data, &nnzC, workspace));\n\n  dC_columns = IdArray::Empty({nnzC}, A.indptr->dtype, ctx);\n  dC_weights = NDArray::Empty({nnzC}, A_weights_array->dtype, ctx);\n  dC_columns_data = dC_columns.Ptr<IdType>();\n  dC_weights_data = dC_weights.Ptr<DType>();\n\n  CUSPARSE_CALL(CSRGEAM<DType>::compute(\n      thr_entry->cusparse_handle, m, n, &alpha, matA, nnzA, A_weights,\n      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), &beta, matB, nnzB,\n      B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), matC,\n      dC_weights_data, dC_csrOffsets_data, dC_columns_data, workspace));\n\n  device->FreeWorkspace(ctx, workspace);\n  // destroy matrix/vector descriptors\n  CUSPARSE_CALL(cusparseDestroyMatDescr(matA));\n  CUSPARSE_CALL(cusparseDestroyMatDescr(matB));\n  CUSPARSE_CALL(cusparseDestroyMatDescr(matC));\n  return {\n      CSRMatrix(\n          A.num_rows, A.num_cols, dC_csrOffsets, dC_columns,\n          NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx), true),\n      dC_weights};\n}\n}  // namespace cusparse\n\ntemplate <int XPU, typename IdType, typename DType>\nstd::pair<CSRMatrix, NDArray> CSRSum(\n    const std::vector<CSRMatrix>& As, const std::vector<NDArray>& A_weights) {\n  const int64_t M = As[0].num_rows;\n  const int64_t N = 
As[0].num_cols;\n  const int64_t n = As.size();\n\n  // Cast 64 bit indices to 32 bit\n  std::vector<CSRMatrix> newAs;\n  newAs.reserve(n);\n  bool cast = false;\n  if (As[0].indptr->dtype.bits == 64) {\n    for (int i = 0; i < n; ++i)\n      newAs.emplace_back(\n          As[i].num_rows, As[i].num_cols, AsNumBits(As[i].indptr, 32),\n          AsNumBits(As[i].indices, 32), AsNumBits(As[i].data, 32));\n    cast = true;\n  } else {\n    for (int i = 0; i < n; ++i) newAs.push_back(As[i]);\n  }\n\n  // cuSPARSE csrgeam2 requires the CSR to be sorted.\n  // TODO(BarclayII): ideally the sorted CSR should be cached but I'm not sure\n  // how to do it.\n  for (int i = 0; i < n; ++i) {\n    if (!newAs[i].sorted) newAs[i] = CSRSort(newAs[i]);\n  }\n\n  // Reorder weights if A[i] has edge IDs\n  std::vector<NDArray> A_weights_reordered(n);\n  for (int i = 0; i < n; ++i) {\n    if (CSRHasData(newAs[i]))\n      A_weights_reordered[i] = IndexSelect(A_weights[i], newAs[i].data);\n    else\n      A_weights_reordered[i] = A_weights[i];\n  }\n\n  // Loop and sum\n  auto result = std::make_pair(\n      CSRMatrix(\n          newAs[0].num_rows, newAs[0].num_cols, newAs[0].indptr,\n          newAs[0].indices,\n          NullArray(newAs[0].indptr->dtype, newAs[0].indptr->ctx)),\n      A_weights_reordered[0]);  // Weights already reordered so we don't need\n                                // As[0].data\n  for (int64_t i = 1; i < n; ++i)\n    result = cusparse::CusparseCsrgeam2<DType, int32_t>(\n        result.first, result.second, newAs[i], A_weights_reordered[i]);\n\n  // Cast 32 bit indices back to 64 bit if necessary\n  if (cast) {\n    CSRMatrix C = result.first;\n    return {\n        CSRMatrix(\n            C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),\n            AsNumBits(C.indices, 64), AsNumBits(C.data, 64), true),\n        result.second};\n  } else {\n    return result;\n  }\n}\n\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __half>(\n    const 
std::vector<CSRMatrix>&, const std::vector<NDArray>&);\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __half>(\n    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);\n#if BF16_ENABLED\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);\n#endif  // BF16_ENABLED\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, float>(\n    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, float>(\n    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, double>(\n    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);\ntemplate std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, double>(\n    const std::vector<CSRMatrix>&, const std::vector<NDArray>&);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/csr_transpose.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/csr_transpose.cc\n * @brief CSR transpose (convert to CSC)\n */\n#include <dgl/array.h>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n\nnamespace dgl {\n\nusing runtime::NDArray;\n\nnamespace aten {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRTranspose(CSRMatrix csr) {\n  LOG(FATAL) << \"Unreachable codes\";\n  return {};\n}\n\ntemplate <>\nCSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {\n#if CUDART_VERSION < 12000\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  // allocate cusparse handle if needed\n  if (!thr_entry->cusparse_handle) {\n    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));\n  }\n  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));\n\n  NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data;\n  const int64_t nnz = indices->shape[0];\n  const auto& ctx = indptr->ctx;\n  const auto bits = indptr->dtype.bits;\n  if (aten::IsNullArray(data)) data = aten::Range(0, nnz, bits, ctx);\n  const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);\n  const int32_t* indices_ptr = static_cast<int32_t*>(indices->data);\n  const void* data_ptr = data->data;\n\n  // (BarclayII) csr2csc doesn't seem to clear the content of cscColPtr if nnz\n  // == 0. 
We need to do it ourselves.\n  NDArray t_indptr = aten::Full(0, csr.num_cols + 1, bits, ctx);\n  NDArray t_indices = aten::NewIdArray(nnz, ctx, bits);\n  NDArray t_data = aten::NewIdArray(nnz, ctx, bits);\n  int32_t* t_indptr_ptr = static_cast<int32_t*>(t_indptr->data);\n  int32_t* t_indices_ptr = static_cast<int32_t*>(t_indices->data);\n  void* t_data_ptr = t_data->data;\n\n#if CUDART_VERSION >= 10010\n  auto device = runtime::DeviceAPI::Get(csr.indptr->ctx);\n  // workspace\n  size_t workspace_size;\n  CUSPARSE_CALL(cusparseCsr2cscEx2_bufferSize(\n      thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz, data_ptr,\n      indptr_ptr, indices_ptr, t_data_ptr, t_indptr_ptr, t_indices_ptr,\n      CUDA_R_32F, CUSPARSE_ACTION_NUMERIC, CUSPARSE_INDEX_BASE_ZERO,\n      CUSPARSE_CSR2CSC_ALG1,  // see cusparse doc for reference\n      &workspace_size));\n  void* workspace = device->AllocWorkspace(ctx, workspace_size);\n  CUSPARSE_CALL(cusparseCsr2cscEx2(\n      thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz, data_ptr,\n      indptr_ptr, indices_ptr, t_data_ptr, t_indptr_ptr, t_indices_ptr,\n      CUDA_R_32F, CUSPARSE_ACTION_NUMERIC, CUSPARSE_INDEX_BASE_ZERO,\n      CUSPARSE_CSR2CSC_ALG1,  // see cusparse doc for reference\n      workspace));\n  device->FreeWorkspace(ctx, workspace);\n#else\n  CUSPARSE_CALL(cusparseScsr2csc(\n      thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz,\n      static_cast<const float*>(data_ptr), indptr_ptr, indices_ptr,\n      static_cast<float*>(t_data_ptr), t_indices_ptr, t_indptr_ptr,\n      CUSPARSE_ACTION_NUMERIC, CUSPARSE_INDEX_BASE_ZERO));\n#endif\n\n  return CSRMatrix(\n      csr.num_cols, csr.num_rows, t_indptr, t_indices, t_data, false);\n#else\n  return COOToCSR(COOTranspose(CSRToCOO(csr, false)));\n#endif\n}\n\ntemplate <>\nCSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr) {\n  return COOToCSR(COOTranspose(CSRToCOO(csr, false)));\n}\n\ntemplate CSRMatrix CSRTranspose<kDGLCUDA, 
int32_t>(CSRMatrix csr);\ntemplate CSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/cuda_filter.cu",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file array/cuda/cuda_filter.cc\n * @brief Object for selecting items in a set, or selecting items not in a set.\n */\n\n#include <dgl/runtime/device_api.h>\n\n#include <cub/cub.cuh>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"../../runtime/cuda/cuda_hashtable.cuh\"\n#include \"../filter.h\"\n\nusing namespace dgl::runtime::cuda;\n\nnamespace dgl {\nnamespace array {\n\nnamespace {\n\ntemplate <typename IdType, bool include>\n__global__ void _IsInKernel(\n    DeviceOrderedHashTable<IdType> table, const IdType* const array,\n    const int64_t size, IdType* const mark) {\n  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;\n  if (idx < size) {\n    mark[idx] = table.Contains(array[idx]) ^ (!include);\n  }\n}\n\ntemplate <typename IdType>\n__global__ void _InsertKernel(\n    const IdType* const prefix, const int64_t size, IdType* const result) {\n  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;\n  if (idx < size) {\n    if (prefix[idx] != prefix[idx + 1]) {\n      result[prefix[idx]] = idx;\n    }\n  }\n}\n\ntemplate <typename IdType, bool include>\nIdArray _PerformFilter(const OrderedHashTable<IdType>& table, IdArray test) {\n  const auto& ctx = test->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  const int64_t size = test->shape[0];\n  cudaStream_t cudaStream = runtime::getCurrentCUDAStream();\n\n  if (size == 0) {\n    return test;\n  }\n\n  // we need two arrays: 1) to act as a prefixsum\n  // for the number of entries that will be inserted, and\n  // 2) to collect the included items.\n  IdType* prefix = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, sizeof(IdType) * (size + 1)));\n\n  // will resize down later\n  IdArray result = aten::NewIdArray(size, ctx, sizeof(IdType) * 8);\n\n  // mark each index based on it's existence in the hashtable\n  {\n    const dim3 block(256);\n    const dim3 grid((size + block.x - 1) / block.x);\n\n    
CUDA_KERNEL_CALL(\n        (_IsInKernel<IdType, include>), grid, block, 0, cudaStream,\n        table.DeviceHandle(), static_cast<const IdType*>(test->data), size,\n        prefix);\n  }\n\n  // generate prefix-sum\n  {\n    size_t workspace_bytes;\n    CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n        nullptr, workspace_bytes, static_cast<IdType*>(nullptr),\n        static_cast<IdType*>(nullptr), size + 1, cudaStream));\n    void* workspace = device->AllocWorkspace(ctx, workspace_bytes);\n\n    CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n        workspace, workspace_bytes, prefix, prefix, size + 1, cudaStream));\n    device->FreeWorkspace(ctx, workspace);\n  }\n\n  // copy number using the internal current stream;\n  IdType num_unique;\n  device->CopyDataFromTo(\n      prefix + size, 0, &num_unique, 0, sizeof(num_unique), ctx,\n      DGLContext{kDGLCPU, 0}, test->dtype);\n\n  // insert items into set\n  {\n    const dim3 block(256);\n    const dim3 grid((size + block.x - 1) / block.x);\n\n    CUDA_KERNEL_CALL(\n        _InsertKernel, grid, block, 0, cudaStream, prefix, size,\n        static_cast<IdType*>(result->data));\n  }\n  device->FreeWorkspace(ctx, prefix);\n\n  return result.CreateView({num_unique}, result->dtype);\n}\n\ntemplate <typename IdType>\nclass CudaFilterSet : public Filter {\n public:\n  explicit CudaFilterSet(IdArray array)\n      : table_(array->shape[0], array->ctx, runtime::getCurrentCUDAStream()) {\n    cudaStream_t cudaStream = runtime::getCurrentCUDAStream();\n    table_.FillWithUnique(\n        static_cast<const IdType*>(array->data), array->shape[0], cudaStream);\n  }\n\n  IdArray find_included_indices(IdArray test) override {\n    return _PerformFilter<IdType, true>(table_, test);\n  }\n\n  IdArray find_excluded_indices(IdArray test) override {\n    return _PerformFilter<IdType, false>(table_, test);\n  }\n\n private:\n  OrderedHashTable<IdType> table_;\n};\n\n}  // namespace\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFilterRef 
CreateSetFilter(IdArray set) {\n  return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set));\n}\n\ntemplate FilterRef CreateSetFilter<kDGLCUDA, int32_t>(IdArray set);\ntemplate FilterRef CreateSetFilter<kDGLCUDA, int64_t>(IdArray set);\n\n}  // namespace array\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/cusparse_dispatcher.cuh",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/dispatcher.cuh\n * @brief Templates to dispatch into different cuSPARSE routines based on the\n * type argument.\n */\n#ifndef DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_\n#define DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_\n\n#include <cusparse.h>\n#include <dgl/runtime/c_runtime_api.h>\n\n#include \"bf16.cuh\"\n#include \"fp16.cuh\"\n\nnamespace dgl {\nnamespace aten {\n\n/** @brief cusparseXcsrgemm dispatcher */\ntemplate <typename DType>\nstruct CSRGEMM {\n  template <typename... Args>\n  static inline cusparseStatus_t bufferSizeExt(Args... args) {\n    BUG_IF_FAIL(false) << \"This piece of code should not be reached.\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t nnz(Args... args) {\n    return cusparseXcsrgemm2Nnz(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t compute(Args... args) {\n    BUG_IF_FAIL(false) << \"This piece of code should not be reached.\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n};\n\ntemplate <>\nstruct CSRGEMM<__half> {\n  template <typename... Args>\n  static inline cusparseStatus_t bufferSizeExt(Args... args) {\n    // TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a\n    // different implementation would be required.\n    LOG(FATAL) << \"CSRGEMM::bufferSizeExt does not support dtype half (FP16).\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t nnz(Args... args) {\n    return cusparseXcsrgemm2Nnz(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t compute(Args... 
args) {\n    // TODO(ndickson): There is no cusparseHcsrgemm2, so a different\n    // implementation would be required.\n    LOG(FATAL) << \"CSRGEMM::compute does not support dtype half (FP16).\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n};\n\n#if BF16_ENABLED\ntemplate <>\nstruct CSRGEMM<__nv_bfloat16> {\n  template <typename... Args>\n  static inline cusparseStatus_t bufferSizeExt(Args... args) {\n    // TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a\n    // different implementation would be required.\n    LOG(FATAL)\n        << \"CSRGEMM::bufferSizeExt does not support dtype bfloat16 (BF16).\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t nnz(Args... args) {\n    return cusparseXcsrgemm2Nnz(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t compute(Args... args) {\n    // TODO(ndickson): There is no cusparseHcsrgemm2, so a different\n    // implementation would be required.\n    LOG(FATAL) << \"CSRGEMM::compute does not support dtype bfloat16 (BF16).\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n};\n#endif  // BF16_ENABLED\n\ntemplate <>\nstruct CSRGEMM<float> {\n  template <typename... Args>\n  static inline cusparseStatus_t bufferSizeExt(Args... args) {\n    return cusparseScsrgemm2_bufferSizeExt(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t nnz(Args... args) {\n    return cusparseXcsrgemm2Nnz(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t compute(Args... args) {\n    return cusparseScsrgemm2(args...);\n  }\n};\n\ntemplate <>\nstruct CSRGEMM<double> {\n  template <typename... Args>\n  static inline cusparseStatus_t bufferSizeExt(Args... args) {\n    return cusparseDcsrgemm2_bufferSizeExt(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t nnz(Args... 
args) {\n    return cusparseXcsrgemm2Nnz(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t compute(Args... args) {\n    return cusparseDcsrgemm2(args...);\n  }\n};\n\n/** @brief cusparseXcsrgeam dispatcher */\ntemplate <typename DType>\nstruct CSRGEAM {\n  template <typename... Args>\n  static inline cusparseStatus_t bufferSizeExt(Args... args) {\n    BUG_IF_FAIL(false) << \"This piece of code should not be reached.\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t nnz(Args... args) {\n    return cusparseXcsrgeam2Nnz(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t compute(Args... args) {\n    BUG_IF_FAIL(false) << \"This piece of code should not be reached.\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n};\n\ntemplate <>\nstruct CSRGEAM<__half> {\n  template <typename... Args>\n  static inline cusparseStatus_t bufferSizeExt(Args... args) {\n    // TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a\n    // different implementation would be required.\n    LOG(FATAL) << \"CSRGEAM::bufferSizeExt does not support dtype half (FP16).\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t nnz(Args... args) {\n    return cusparseXcsrgeam2Nnz(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t compute(Args... args) {\n    // TODO(ndickson): There is no cusparseHcsrgeam2, so a different\n    // implementation would be required.\n    LOG(FATAL) << \"CSRGEAM::compute does not support dtype half (FP16).\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n};\n\n#if BF16_ENABLED\ntemplate <>\nstruct CSRGEAM<__nv_bfloat16> {\n  template <typename... Args>\n  static inline cusparseStatus_t bufferSizeExt(Args... 
args) {\n    // TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a\n    // different implementation would be required.\n    LOG(FATAL)\n        << \"CSRGEAM::bufferSizeExt does not support dtype bfloat16 (BF16).\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t nnz(Args... args) {\n    return cusparseXcsrgeam2Nnz(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t compute(Args... args) {\n    // TODO(ndickson): There is no cusparseHcsrgeam2, so a different\n    // implementation would be required.\n    LOG(FATAL) << \"CSRGEAM::compute does not support dtype bfloat16 (BF16).\";\n    return static_cast<cusparseStatus_t>(0);\n  }\n};\n#endif  // BF16_ENABLED\n\ntemplate <>\nstruct CSRGEAM<float> {\n  template <typename... Args>\n  static inline cusparseStatus_t bufferSizeExt(Args... args) {\n    return cusparseScsrgeam2_bufferSizeExt(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t nnz(Args... args) {\n    return cusparseXcsrgeam2Nnz(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t compute(Args... args) {\n    return cusparseScsrgeam2(args...);\n  }\n};\n\ntemplate <>\nstruct CSRGEAM<double> {\n  template <typename... Args>\n  static inline cusparseStatus_t bufferSizeExt(Args... args) {\n    return cusparseDcsrgeam2_bufferSizeExt(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t nnz(Args... args) {\n    return cusparseXcsrgeam2Nnz(args...);\n  }\n\n  template <typename... Args>\n  static inline cusparseStatus_t compute(Args... args) {\n    return cusparseDcsrgeam2(args...);\n  }\n};\n\n};  // namespace aten\n};  // namespace dgl\n\n#endif  // DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_\n"
  },
  {
    "path": "src/array/cuda/disjoint_union.cu",
    "content": "/**\n *   Copyright (c) 2022, NVIDIA CORPORATION.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file array/gpu/disjoint_union.cu\n * @brief Disjoint union GPU implementation.\n */\n\n#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <tuple>\n#include <vector>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <typename IdType>\n__global__ void _DisjointUnionKernel(\n    IdType** arrs, IdType* prefix, IdType* offset, IdType* out, int64_t n_arrs,\n    int n_elms) {\n  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n  while (tx < n_elms) {\n    IdType i = dgl::cuda::_UpperBound(offset, n_arrs, tx) - 1;\n    if (arrs[i] == NULL) {\n      out[tx] = tx;\n    } else {\n      IdType j = tx - offset[i];\n      out[tx] = arrs[i][j] + prefix[i];\n    }\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(\n    const std::vector<COOMatrix>& coos) {\n  IdType n = coos.size(), nbits = coos[0].row->dtype.bits;\n  IdArray n_rows = NewIdArray(n, CPU, nbits);\n  IdArray n_cols = NewIdArray(n, CPU, nbits);\n  IdArray n_elms = NewIdArray(n, CPU, nbits);\n\n  IdType* n_rows_data = n_rows.Ptr<IdType>();\n  IdType* n_cols_data = 
n_cols.Ptr<IdType>();\n  IdType* n_elms_data = n_elms.Ptr<IdType>();\n\n  dgl::runtime::parallel_for(0, coos.size(), [&](IdType b, IdType e) {\n    for (IdType i = b; i < e; ++i) {\n      n_rows_data[i] = coos[i].num_rows;\n      n_cols_data[i] = coos[i].num_cols;\n      n_elms_data[i] = coos[i].row->shape[0];\n    }\n  });\n\n  return std::make_tuple(\n      CumSum(n_rows.CopyTo(coos[0].row->ctx), true),\n      CumSum(n_cols.CopyTo(coos[0].row->ctx), true),\n      CumSum(n_elms.CopyTo(coos[0].row->ctx), true));\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid _Merge(\n    IdType** arrs, IdType* prefix, IdType* offset, IdType* out, int64_t n_arrs,\n    int n_elms, DGLContext ctx, DGLDataType dtype, cudaStream_t stream) {\n  auto device = runtime::DeviceAPI::Get(ctx);\n  int nt = 256;\n  int nb = (n_elms + nt - 1) / nt;\n\n  IdType** arrs_dev = static_cast<IdType**>(\n      device->AllocWorkspace(ctx, n_arrs * sizeof(IdType*)));\n\n  device->CopyDataFromTo(\n      arrs, 0, arrs_dev, 0, sizeof(IdType*) * n_arrs, DGLContext{kDGLCPU, 0},\n      ctx, dtype);\n\n  CUDA_KERNEL_CALL(\n      _DisjointUnionKernel, nb, nt, 0, stream, arrs_dev, prefix, offset, out,\n      n_arrs, n_elms);\n\n  device->FreeWorkspace(ctx, arrs_dev);\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  auto device = runtime::DeviceAPI::Get(coos[0].row->ctx);\n  uint64_t src_offset = 0, dst_offset = 0;\n  bool has_data = false;\n  bool row_sorted = true;\n  bool col_sorted = true;\n\n  // check if data index array\n  for (size_t i = 0; i < coos.size(); ++i) {\n    CHECK_SAME_DTYPE(coos[0].row, coos[i].row);\n    CHECK_SAME_CONTEXT(coos[0].row, coos[i].row);\n    has_data |= COOHasData(coos[i]);\n  }\n\n  auto prefixes = _ComputePrefixSums<XPU, IdType>(coos);\n  auto prefix_src = static_cast<IdType*>(std::get<0>(prefixes)->data);\n  auto prefix_dst = 
static_cast<IdType*>(std::get<1>(prefixes)->data);\n  auto prefix_elm = static_cast<IdType*>(std::get<2>(prefixes)->data);\n\n  std::unique_ptr<IdType*[]> rows(new IdType*[coos.size()]);\n  std::unique_ptr<IdType*[]> cols(new IdType*[coos.size()]);\n  std::unique_ptr<IdType*[]> data(new IdType*[coos.size()]);\n\n  for (size_t i = 0; i < coos.size(); i++) {\n    row_sorted &= coos[i].row_sorted;\n    col_sorted &= coos[i].col_sorted;\n    rows[i] = coos[i].row.Ptr<IdType>();\n    cols[i] = coos[i].col.Ptr<IdType>();\n    data[i] = coos[i].data.Ptr<IdType>();\n  }\n\n  auto ctx = coos[0].row->ctx;\n  auto dtype = coos[0].row->dtype;\n\n  IdType n_elements = 0;\n  device->CopyDataFromTo(\n      &prefix_elm[coos.size()], 0, &n_elements, 0, sizeof(IdType),\n      coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->dtype);\n\n  device->CopyDataFromTo(\n      &prefix_src[coos.size()], 0, &src_offset, 0, sizeof(IdType),\n      coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->dtype);\n\n  device->CopyDataFromTo(\n      &prefix_dst[coos.size()], 0, &dst_offset, 0, sizeof(IdType),\n      coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->dtype);\n\n  // Union src array\n  IdArray result_src =\n      NewIdArray(n_elements, coos[0].row->ctx, coos[0].row->dtype.bits);\n  _Merge<XPU, IdType>(\n      rows.get(), prefix_src, prefix_elm, result_src.Ptr<IdType>(), coos.size(),\n      n_elements, ctx, dtype, stream);\n\n  // Union dst array\n  IdArray result_dst =\n      NewIdArray(n_elements, coos[0].col->ctx, coos[0].col->dtype.bits);\n  _Merge<XPU, IdType>(\n      cols.get(), prefix_dst, prefix_elm, result_dst.Ptr<IdType>(), coos.size(),\n      n_elements, ctx, dtype, stream);\n\n  // Union data array if exists and fetch number of elements\n  IdArray result_dat = NullArray();\n  if (has_data) {\n    result_dat =\n        NewIdArray(n_elements, coos[0].row->ctx, coos[0].row->dtype.bits);\n    _Merge<XPU, IdType>(\n        data.get(), prefix_elm, prefix_elm, 
result_dat.Ptr<IdType>(),\n        coos.size(), n_elements, ctx, dtype, stream);\n  }\n\n  return COOMatrix(\n      src_offset, dst_offset, result_src, result_dst, result_dat, row_sorted,\n      col_sorted);\n}\n\ntemplate COOMatrix DisjointUnionCoo<kDGLCUDA, int32_t>(\n    const std::vector<COOMatrix>& coos);\ntemplate COOMatrix DisjointUnionCoo<kDGLCUDA, int64_t>(\n    const std::vector<COOMatrix>& coos);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/fp16.cuh",
    "content": "/**\n *  Copyright (c) 2020-2022 by Contributors\n *\n *  Licensed under the Apache License, Version 2.0 (the \"License\");\n *  you may not use this file except in compliance with the License.\n *  You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n *  Unless required by applicable law or agreed to in writing, software\n *  distributed under the License is distributed on an \"AS IS\" BASIS,\n *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *  See the License for the specific language governing permissions and\n *  limitations under the License.\n *\n * @file array/cuda/fp16.cuh\n * @brief float16 related functions.\n * @note this file is modified from TVM project:\n *       https://github.com/apache/tvm/blob/e561007f0c330e3d14c2bc8a3ef40fb741db9004/src/target/source/literal/cuda_half_t.h.\n */\n#ifndef DGL_ARRAY_CUDA_FP16_CUH_\n#define DGL_ARRAY_CUDA_FP16_CUH_\n\n#include <cuda_fp16.h>\n\n#include <algorithm>\n\nstatic __device__ __forceinline__ half max(half a, half b) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530\n  return __hgt(__half(a), __half(b)) ? a : b;\n#else\n  return __half(max(float(a), float(b)));  // NOLINT\n#endif\n}\n\nstatic __device__ __forceinline__ half min(half a, half b) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530\n  return __hlt(__half(a), __half(b)) ? 
a : b;\n#else\n  return __half(min(float(a), float(b)));  // NOLINT\n#endif\n}\n\n#ifdef __CUDACC__\n// Arithmetic FP16 operations for architecture >= 5.3 are already defined in\n// cuda_fp16.h\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530)\n// CUDA 12.2 adds \"emulated\" support for older architectures.\n#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)\n__device__ __forceinline__ __half\noperator+(const __half& lh, const __half& rh) {\n  return __half(float(lh) + float(rh));  // NOLINT\n}\n__device__ __forceinline__ __half\noperator-(const __half& lh, const __half& rh) {\n  return __half(float(lh) - float(rh));  // NOLINT\n}\n__device__ __forceinline__ __half\noperator*(const __half& lh, const __half& rh) {\n  return __half(float(lh) * float(rh));  // NOLINT\n}\n__device__ __forceinline__ __half\noperator/(const __half& lh, const __half& rh) {\n  return __half(float(lh) / float(rh));  // NOLINT\n}\n\n__device__ __forceinline__ __half& operator+=(\n    __half& lh, const __half& rh) {    // NOLINT\n  lh = __half(float(lh) + float(rh));  // NOLINT\n  return lh;\n}\n__device__ __forceinline__ __half& operator-=(\n    __half& lh, const __half& rh) {    // NOLINT\n  lh = __half(float(lh) - float(rh));  // NOLINT\n  return lh;\n}\n__device__ __forceinline__ __half& operator*=(\n    __half& lh, const __half& rh) {    // NOLINT\n  lh = __half(float(lh) * float(rh));  // NOLINT\n  return lh;\n}\n__device__ __forceinline__ __half& operator/=(\n    __half& lh, const __half& rh) {    // NOLINT\n  lh = __half(float(lh) / float(rh));  // NOLINT\n  return lh;\n}\n\n__device__ __forceinline__ __half& operator++(__half& h) {  // NOLINT\n  h = __half(float(h) + 1.0f);                              // NOLINT\n  return h;\n}\n__device__ __forceinline__ __half& operator--(__half& h) {  // NOLINT\n  h = __half(float(h) - 1.0f);                              // NOLINT\n  return h;\n}\n__device__ __forceinline__ __half operator++(__half& h, int) {  // NOLINT\n  __half ret = 
h;\n  h = __half(float(h) + 1.0f);  // NOLINT\n  return ret;\n}\n__device__ __forceinline__ __half operator--(__half& h, int) {  // NOLINT\n  __half ret = h;\n  h = __half(float(h) - 1.0f);  // NOLINT\n  return ret;\n}\n\n__device__ __forceinline__ __half operator+(const __half& h) { return h; }\n__device__ __forceinline__ __half operator-(const __half& h) {\n  return __half(-float(h));  // NOLINT\n}\n\n__device__ __forceinline__ bool operator==(const __half& lh, const __half& rh) {\n  return float(lh) == float(rh);  // NOLINT\n}\n__device__ __forceinline__ bool operator!=(const __half& lh, const __half& rh) {\n  return float(lh) != float(rh);  // NOLINT\n}\n__device__ __forceinline__ bool operator>(const __half& lh, const __half& rh) {\n  return float(lh) > float(rh);  // NOLINT\n}\n__device__ __forceinline__ bool operator<(const __half& lh, const __half& rh) {\n  return float(lh) < float(rh);  // NOLINT\n}\n__device__ __forceinline__ bool operator>=(const __half& lh, const __half& rh) {\n  return float(lh) >= float(rh);  // NOLINT\n}\n__device__ __forceinline__ bool operator<=(const __half& lh, const __half& rh) {\n  return float(lh) <= float(rh);  // NOLINT\n}\n#endif  // defined(CUDART_VERSION) && (CUDART_VERSION < 12020)\n#endif  // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530)\n#endif  // __CUDACC__\n\n#endif  // DGL_ARRAY_CUDA_FP16_CUH_\n"
  },
  {
    "path": "src/array/cuda/functor.cuh",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/functor.cuh\n * @brief Functors for template on CUDA\n */\n#ifndef DGL_ARRAY_CUDA_FUNCTOR_CUH_\n#define DGL_ARRAY_CUDA_FUNCTOR_CUH_\n\n#include <cmath>\n#include <limits>\n\n#include \"./atomic.cuh\"\n#include \"./fp16.cuh\"\n#include \"bf16.cuh\"\n\nnamespace dgl {\nnamespace aten {\nnamespace cuda {\n\n/////////////////////////// CUDA binary operators //////////////////////////////\nnamespace binary {\ntemplate <typename DType>\nstruct Add {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  static constexpr bool reduce_last_dim = false;\n  static __device__ __forceinline__ DType\n  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {\n    return lhs[0] + rhs[0];\n  }\n};\ntemplate <typename DType>\nconstexpr bool Add<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool Add<DType>::use_rhs;\ntemplate <typename DType>\nconstexpr bool Add<DType>::reduce_last_dim;\n\ntemplate <typename DType>\nstruct Sub {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  static constexpr bool reduce_last_dim = false;\n  static __device__ __forceinline__ DType\n  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {\n    return lhs[0] - rhs[0];\n  }\n};\ntemplate <typename DType>\nconstexpr bool Sub<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool Sub<DType>::use_rhs;\ntemplate <typename DType>\nconstexpr bool Sub<DType>::reduce_last_dim;\n\ntemplate <typename DType>\nstruct Mul {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  static constexpr bool reduce_last_dim = false;\n  static __device__ __forceinline__ DType\n  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {\n    return lhs[0] * rhs[0];\n  }\n};\ntemplate <typename DType>\nconstexpr bool Mul<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool Mul<DType>::use_rhs;\ntemplate <typename 
DType>\nconstexpr bool Mul<DType>::reduce_last_dim;\n\ntemplate <typename DType>\nstruct Div {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  static constexpr bool reduce_last_dim = false;\n  static __device__ __forceinline__ DType\n  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {\n    return lhs[0] / rhs[0];\n  }\n};\ntemplate <typename DType>\nconstexpr bool Div<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool Div<DType>::use_rhs;\ntemplate <typename DType>\nconstexpr bool Div<DType>::reduce_last_dim;\n\ntemplate <typename DType>\nstruct CopyLhs {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = false;\n  static constexpr bool reduce_last_dim = false;\n  static __device__ __forceinline__ DType\n  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {\n    return lhs[0];\n  }\n};\ntemplate <typename DType>\nconstexpr bool CopyLhs<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool CopyLhs<DType>::use_rhs;\ntemplate <typename DType>\nconstexpr bool CopyLhs<DType>::reduce_last_dim;\n\ntemplate <typename DType>\nstruct CopyRhs {\n  static constexpr bool use_lhs = false;\n  static constexpr bool use_rhs = true;\n  static constexpr bool reduce_last_dim = false;\n  static __device__ __forceinline__ DType\n  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {\n    return rhs[0];\n  }\n};\ntemplate <typename DType>\nconstexpr bool CopyRhs<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool CopyRhs<DType>::use_rhs;\ntemplate <typename DType>\nconstexpr bool CopyRhs<DType>::reduce_last_dim;\n\ntemplate <typename DType>\nstruct Dot {\n  static constexpr bool use_lhs = true;\n  static constexpr bool use_rhs = true;\n  static constexpr bool reduce_last_dim = true;\n  static __device__ __forceinline__ DType\n  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {\n    DType rst = static_cast<DType>(0.0f);\n    for (int64_t i = 0; i < len; ++i) {\n      rst 
+= lhs[i] * rhs[i];\n    }\n    return rst;\n  }\n};\ntemplate <typename DType>\nconstexpr bool Dot<DType>::use_lhs;\ntemplate <typename DType>\nconstexpr bool Dot<DType>::use_rhs;\ntemplate <typename DType>\nconstexpr bool Dot<DType>::reduce_last_dim;\n\n}  // end of namespace binary\n\n/////////////////////////// CUDA reduce operators //////////////////////////////\nnamespace reduce {\ntemplate <typename Idx, typename DType, bool atomic>\nstruct _Sum {\n  static constexpr __host__ __device__ __forceinline__ DType zero() {\n    return 0.;\n  }\n  static constexpr bool require_arg = false;\n  static __device__ __forceinline__ void Call(\n      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,\n      Idx eid) {\n    if (!atomic) {\n      *out_buf += val;\n    } else {\n      cuda::AtomicAdd(out_buf, val);\n    }\n  }\n  static __device__ __forceinline__ void Call(\n      DType *out_buf, Idx *arg_buf, DType val, Idx id) {\n    if (!atomic) {\n      *out_buf += val;\n    } else {\n      cuda::AtomicAdd(out_buf, val);\n    }\n  }\n  static __device__ __forceinline__ void CallArg(\n      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,\n      Idx uid, Idx eid) {}\n};\n\ntemplate <typename Idx, typename DType, bool atomic = false>\nstruct Sum : _Sum<Idx, DType, atomic> {};\n\ntemplate <typename Idx, bool atomic>\nstruct Sum<Idx, __half, atomic> : _Sum<Idx, __half, atomic> {\n  static constexpr __host__ __device__ __forceinline__ __half zero() {\n    return __float2half_rn(0.);\n  }\n  static __device__ __forceinline__ void Call(\n      __half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __half val, Idx uid, Idx eid) {\n    _Sum<Idx, __half, atomic>::Call(\n        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      __half *out_buf, Idx *arg_buf, __half val, Idx id) {\n    _Sum<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);\n  }\n  // sometimes we have to use float in 
reduction for better precision\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __half val, Idx uid, Idx eid) {\n    _Sum<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,\n        static_cast<float>(val), uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_buf, __half val, Idx id) {\n    _Sum<Idx, float, atomic>::Call(out_buf, arg_buf,\n        static_cast<float>(val), id);\n  }\n};\n\n#if BF16_ENABLED\ntemplate <typename Idx, bool atomic>\nstruct Sum<Idx, __nv_bfloat16, atomic> : _Sum<Idx, __nv_bfloat16, atomic> {\n  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {\n    return __float2bfloat16_rn(0.);\n  }\n  static __device__ __forceinline__ void Call(\n      __nv_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __nv_bfloat16 val, Idx uid, Idx eid) {\n    _Sum<Idx, __nv_bfloat16, atomic>::Call(\n        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      __nv_bfloat16 *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {\n    _Sum<Idx, __nv_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);\n  }\n  // sometimes we have to use float in reduction for better precision\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __nv_bfloat16 val, Idx uid, Idx eid) {\n    _Sum<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,\n        static_cast<float>(val), uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {\n    _Sum<Idx, float, atomic>::Call(out_buf, arg_buf,\n        static_cast<float>(val), id);\n  }\n};\n#endif  // BF16_ENABLED\n\ntemplate <typename Idx, typename DType, bool atomic>\nstruct _Max {\n  static constexpr __host__ __device__ __forceinline__ DType zero() {\n    return -std::numeric_limits<DType>::infinity();\n  }\n  
static constexpr bool require_arg = true;\n  static __device__ __forceinline__ void Call(\n      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,\n      Idx eid) {\n    if (!atomic) {\n      if (*out_buf < val) {\n        *out_buf = val;\n        *arg_u_buf = uid;\n        *arg_e_buf = eid;\n      }\n    } else {\n      cuda::AtomicMax(out_buf, val);\n    }\n  }\n  static __device__ __forceinline__ void Call(\n      DType *out_buf, Idx *arg_buf, DType val, Idx id) {\n    if (!atomic) {\n      if (*out_buf < val) {\n        *out_buf = val;\n        *arg_buf = id;\n      }\n    } else {\n      cuda::AtomicMax(out_buf, val);\n    }\n  }\n  static __device__ __forceinline__ void CallArg(\n      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,\n      Idx uid, Idx eid) {\n    if (atomic) {\n      if (val == val_ref) {\n        if (arg_u_buf) arg_u_buf[fid] = uid;\n        if (arg_e_buf) arg_e_buf[fid] = eid;\n      }\n    }\n  }\n};\n\ntemplate <typename Idx, typename DType, bool atomic = false>\nstruct Max : _Max<Idx, DType, atomic> {};\n\ntemplate <typename Idx, bool atomic>\nstruct Max<Idx, __half, atomic> : _Max<Idx, __half, atomic> {\n  static constexpr __host__ __device__ __forceinline__ __half zero() {\n    return __float2half_rn(-6.550400e+04f);\n  }\n  static __device__ __forceinline__ void Call(\n      __half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __half val, Idx uid, Idx eid) {\n    _Max<Idx, __half, atomic>::Call(\n        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      __half *out_buf, Idx *arg_buf, __half val, Idx id) {\n    _Max<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);\n  }\n  // sometimes we have to use float in reduction for better precision\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __half val, Idx uid, Idx eid) {\n    _Max<Idx, float, atomic>::Call(out_buf, arg_u_buf, 
arg_e_buf,\n        static_cast<float>(val), uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_buf, __half val, Idx id) {\n    _Max<Idx, float, atomic>::Call(out_buf, arg_buf,\n        static_cast<float>(val), id);\n  }\n};\n\n#if BF16_ENABLED\ntemplate <typename Idx, bool atomic>\nstruct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> {\n  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {\n    return __float2bfloat16_rn(-std::numeric_limits<float>::infinity());\n  }\n  static __device__ __forceinline__ void Call(\n      __nv_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __nv_bfloat16 val, Idx uid, Idx eid) {\n    _Max<Idx, __nv_bfloat16, atomic>::Call(\n        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      __nv_bfloat16 *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {\n    _Max<Idx, __nv_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);\n  }\n  // sometimes we have to use float in reduction for better precision\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __nv_bfloat16 val, Idx uid, Idx eid) {\n    _Max<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,\n        static_cast<float>(val), uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {\n    _Max<Idx, float, atomic>::Call(out_buf, arg_buf,\n        static_cast<float>(val), id);\n  }\n};\n#endif  // BF16_ENABLED\n\ntemplate <typename Idx, typename DType, bool atomic>\nstruct _Min {\n  static constexpr __host__ __device__ __forceinline__ DType zero() {\n    return std::numeric_limits<DType>::infinity();\n  }\n  static constexpr bool require_arg = true;\n  static __device__ __forceinline__ void Call(\n      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,\n      Idx eid) {\n    if 
(!atomic) {\n      if (*out_buf > val) {\n        *out_buf = val;\n        *arg_u_buf = uid;\n        *arg_e_buf = eid;\n      }\n    } else {\n      cuda::AtomicMin(out_buf, val);\n    }\n  }\n  static __device__ __forceinline__ void Call(\n      DType *out_buf, Idx *arg_buf, DType val, Idx id) {\n    if (!atomic) {\n      if (*out_buf > val) {\n        *out_buf = val;\n        *arg_buf = id;\n      }\n    } else {\n      cuda::AtomicMin(out_buf, val);\n    }\n  }\n  static __device__ __forceinline__ void CallArg(\n      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,\n      Idx uid, Idx eid) {\n    if (atomic) {\n      if (val == val_ref) {\n        if (arg_u_buf) arg_u_buf[fid] = uid;\n        if (arg_e_buf) arg_e_buf[fid] = eid;\n      }\n    }\n  }\n};\n\ntemplate <typename Idx, typename DType, bool atomic = false>\nstruct Min : _Min<Idx, DType, atomic> {};\n\ntemplate <typename Idx, bool atomic>\nstruct Min<Idx, __half, atomic> : _Min<Idx, __half, atomic> {\n  static constexpr __host__ __device__ __forceinline__ __half zero() {\n    return __float2half_rn(6.550400e+04f);\n  }\n  static __device__ __forceinline__ void Call(\n      __half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __half val, Idx uid, Idx eid) {\n    _Min<Idx, __half, atomic>::Call(\n        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      __half *out_buf, Idx *arg_buf, __half val, Idx id) {\n    _Min<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);\n  }\n  // sometimes we have to use float in reduction for better precision\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __half val, Idx uid, Idx eid) {\n    _Min<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,\n        static_cast<float>(val), uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_buf, __half val, Idx id) {\n    _Min<Idx, float, 
atomic>::Call(out_buf, arg_buf,\n        static_cast<float>(val), id);\n  }\n};\n\n#if BF16_ENABLED\ntemplate <typename Idx, bool atomic>\nstruct Min<Idx, __nv_bfloat16, atomic> : _Min<Idx, __nv_bfloat16, atomic> {\n  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {\n    return __float2bfloat16_rn(std::numeric_limits<float>::infinity());\n  }\n  static __device__ __forceinline__ void Call(\n      __nv_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __nv_bfloat16 val, Idx uid, Idx eid) {\n    _Min<Idx, __nv_bfloat16, atomic>::Call(\n        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      __nv_bfloat16 *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {\n    _Min<Idx, __nv_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);\n  }\n  // sometimes we have to use float in reduction for better precision\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,\n      __nv_bfloat16 val, Idx uid, Idx eid) {\n    _Min<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,\n        static_cast<float>(val), uid, eid);\n  }\n  static __device__ __forceinline__ void Call(\n      float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {\n    _Min<Idx, float, atomic>::Call(out_buf, arg_buf,\n        static_cast<float>(val), id);\n  }\n};\n#endif  // BF16_ENABLED\n\n}  // namespace reduce\n\n}  // namespace cuda\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CUDA_FUNCTOR_CUH_\n"
  },
  {
    "path": "src/array/cuda/gather_mm.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/gather_mm.cu\n * @brief GatherMM C APIs and definitions.\n */\n#include <dgl/array.h>\n\n#include <algorithm>  // std::swap\n\n#include \"./atomic.cuh\"\n#include \"./functor.cuh\"\n#include \"./utils.h\"\n\nnamespace dgl {\nusing namespace cuda;\nnamespace aten {\n\nnamespace {\n\n/** @brief Call cuBLAS GEMM API for dense matmul operation for float and double.\n */\ntemplate <typename DType>\ncublasStatus_t cublasGemm(\n    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,\n    int m, int n, int k, const DType* alpha, const DType* A, int lda,\n    const DType* B, int ldb, const DType* beta, DType* C, int ldc) {\n  LOG(INFO) << \"Not supported dtype\";\n  return CUBLAS_STATUS_EXECUTION_FAILED;\n}\n\ntemplate <>\ncublasStatus_t cublasGemm<__half>(\n    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,\n    int m, int n, int k, const __half* alpha, const __half* A, int lda,\n    const __half* B, int ldb, const __half* beta, __half* C, int ldc) {\n  return cublasHgemm(\n      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);\n}\n\n#if BF16_ENABLED\ntemplate <>\ncublasStatus_t cublasGemm<__nv_bfloat16>(\n    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,\n    int m, int n, int k, const __nv_bfloat16* alpha, const __nv_bfloat16* A,\n    int lda, const __nv_bfloat16* B, int ldb, const __nv_bfloat16* beta,\n    __nv_bfloat16* C, int ldc) {\n  float alpha_float = __bfloat162float(*alpha);\n  float beta_float = __bfloat162float(*beta);\n  return cublasGemmEx(\n      handle, transa, transb, m, n, k, &alpha_float, A, CUDA_R_16BF, lda, B,\n      CUDA_R_16BF, ldb, &beta_float, C, CUDA_R_16BF, ldc, CUBLAS_COMPUTE_32F,\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP);\n}\n#endif  // BF16_ENABLED\n\ntemplate <>\ncublasStatus_t cublasGemm<float>(\n    cublasHandle_t handle, cublasOperation_t transa, 
cublasOperation_t transb,\n    int m, int n, int k, const float* alpha, const float* A, int lda,\n    const float* B, int ldb, const float* beta, float* C, int ldc) {\n  return cublasSgemm(\n      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);\n}\n\ntemplate <>\ncublasStatus_t cublasGemm<double>(\n    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,\n    int m, int n, int k, const double* alpha, const double* A, int lda,\n    const double* B, int ldb, const double* beta, double* C, int ldc) {\n  return cublasDgemm(\n      handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);\n}\n\n}  // namespace\n\nnamespace cuda {\n\n/**\n * @note Each row of A multiplies a segment of matrix of B of dimension in_len *\n * outlen. One warp is assigned to process one row of A. Each WARP sequentially\n * multiplies one element of A and a row of B to compute partial result of the\n * output. A is loaded in shared memory in a coalesced way. Output matrix is\n * loaded in registers. B should get benefit from L2 cache.\n */\ntemplate <typename Idx, typename DType>\n__global__ void GatherMMScatterKernel(\n    const DType* __restrict__ A, const DType* __restrict__ B,\n    DType* __restrict__ C, const Idx* __restrict__ idx_a,\n    const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,\n    const int64_t num_rows, const int64_t in_len, const int64_t out_len) {\n  unsigned int tId = threadIdx.x;\n  unsigned int laneId = tId & 31;\n  unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);\n  unsigned int warpId = gId >> 5;\n  unsigned int row = warpId;\n  if (row < num_rows) {\n    const unsigned int local_row =\n        row & 3;  // hardcoded for TB size 128 (4 warps)\n    const Idx cur_rowA = (idx_a) ? idx_a[row] : row;\n    const Idx cur_rowB = (idx_b) ? idx_b[row] : row;\n    const Idx cur_rowC = (idx_c) ? 
idx_c[row] : row;\n    const Idx B_offset = cur_rowB * in_len * out_len;\n    const int sh_a_tile = 64;\n    __shared__ DType sh_A[4 * sh_a_tile];\n    int a_tile = sh_a_tile;\n    for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {\n      if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;\n      // Load A in shared mem in a coalesced way\n      for (unsigned int l = laneId; l < a_tile; l += 32)\n        sh_A[local_row * sh_a_tile + l] = A[cur_rowA * in_len + (k_start + l)];\n      __syncwarp();\n\n      for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {\n        DType out_reg = static_cast<DType>(0.0f);  // thread private\n        const unsigned int l = laneId;\n        if (l < out_len) {\n          // iterate over elements of a row of A\n          for (unsigned int i = 0; i < a_tile; i++) {\n            const DType a_val = sh_A[local_row * sh_a_tile + i];\n            // iterate over elements of a row of B in parallel\n            out_reg +=\n                a_val * B[B_offset + ((i + k_start) * out_len + (outloop + l))];\n          }\n          if (idx_c) {\n            AtomicAdd(C + cur_rowC * out_len + (outloop + l), out_reg);\n          } else {\n            C[cur_rowC * out_len + (outloop + l)] += out_reg;\n          }\n        }\n      }\n    }\n  }\n}\n\n/**\n * @note Output matrix is accumulated via atomic operations. Rest of the\n * strategies are similar to GatherMMKernel. One warp is assigned to process one\n * row of A. Each WARP sequentially multiplies one element of A and a row of B\n * to compute partial result of the output. A is loaded in shared memory in a\n * coalesced way. 
B should get benefit from L2 cache.\n */\ntemplate <typename Idx, typename DType>\n__global__ void GatherMMScatterKernel2(\n    const DType* __restrict__ A, const DType* __restrict__ B,\n    DType* __restrict__ C, const Idx* __restrict__ idx_a,\n    const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,\n    const int64_t num_rows, const int64_t in_len, const int64_t out_len) {\n  unsigned int tId = threadIdx.x;\n  unsigned int laneId = tId & 31;\n  unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);\n  unsigned int warpId = gId >> 5;\n  unsigned int row = warpId;\n  if (row < num_rows) {\n    const unsigned int local_row =\n        row & 3;  // hardcoded for TB size 128 (4 warps)\n    const Idx row_a = (idx_a) ? idx_a[row] : row;\n    const Idx row_b = (idx_b) ? idx_b[row] : row;\n    const Idx row_c = (idx_c) ? idx_c[row] : row;\n    const Idx C_offset = row_c * in_len * out_len;\n    const int sh_a_tile = 64;\n    __shared__ DType sh_A[4 * sh_a_tile];\n    int a_tile = sh_a_tile;\n    for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {\n      if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;\n      /* Load A in shared mem in a coalesced way */\n      for (unsigned int l = laneId; l < a_tile; l += 32)\n        sh_A[local_row * sh_a_tile + l] = A[row_a * in_len + (k_start + l)];\n      __syncwarp();\n\n      for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {\n        DType out_reg = static_cast<DType>(0.0f);  // thread private\n        const unsigned int l = laneId;\n        if (l < out_len) {\n          const DType b_val = B[row_b * out_len + (outloop + l)];\n          /* iterate over elements of a row of A */\n          for (unsigned int i = 0; i < a_tile; i++) {\n            const DType a_val = sh_A[local_row * sh_a_tile + i];\n            const Idx C_idx =\n                C_offset + ((i + k_start) * out_len + (outloop + l));\n            AtomicAdd(C + C_idx, a_val * b_val);\n          }\n        }\n      
}\n    }\n  }\n}\n\n}  // namespace cuda\n\n/**\n * @brief Implementation of Gather_mm operator. The input matrix A is\n *        expected to be sorted according to relation type.\n * @param A The input dense matrix of dimension m x k\n * @param B The input dense matrix of dimension k x n\n * @param C The output dense matrix of dimension m x n\n * @param seglen_A The input vector of size R. Each element\n *        is the length of segments of input ``A``\n * @param a_trans Matrix A to be transposed\n * @param b_trans Matrix B to be transposed\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SegmentMM(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans) {\n  auto device = runtime::DeviceAPI::Get(A->ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const DType* A_data = A.Ptr<DType>();\n  const DType* B_data = B.Ptr<DType>();\n  const IdType* seglen_A_data = seglen_A.Ptr<IdType>();\n  DType* C_data = C.Ptr<DType>();\n  int64_t A_offset = 0, B_offset = 0, C_offset = 0;\n  int64_t m, n, k;\n  int64_t num_rel = seglen_A.NumElements();\n  DType alpha = 1., beta = 0.;\n\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  if (!thr_entry->cublas_handle)\n    CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));\n  CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));\n\n  IdType m_offset = 0;\n  for (IdType etype = 0; etype < num_rel; ++etype) {\n    m = seglen_A_data[etype];  // rows of A\n    CHECK_LE(m_offset + m, A->shape[0])\n        << \"Segment index out of bound of A->shape[0].\";\n    n = B->shape[2];  // cols of B\n    k = B->shape[1];  // cols of A == rows of B\n    int ldb = n, lda = k, ldc = n;\n    cublasOperation_t transB = CUBLAS_OP_N;\n    cublasOperation_t transA = CUBLAS_OP_N;\n    if (b_trans) {\n      transB = CUBLAS_OP_T;\n      ldb = n, lda = n, ldc = k;\n      std::swap(n, k);\n    }\n    CUBLAS_CALL(cublasGemm<DType>(\n        
thr_entry->cublas_handle, transB, transA, n, m, k, &alpha,\n        B_data + B_offset, ldb, A_data + A_offset, lda, &beta,\n        C_data + C_offset, ldc));\n    A_offset += m * k;\n    B_offset += k * n;\n    C_offset += m * n;\n    m_offset += m;\n  }\n}\n\ntemplate <int XPU, typename IdType, typename DType>\nvoid SegmentMMBackwardB(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {\n  auto device = runtime::DeviceAPI::Get(A->ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const DType* A_data = A.Ptr<DType>();\n  const DType* dC_data = dC.Ptr<DType>();\n  const IdType* seglen_data = seglen.Ptr<IdType>();\n  DType* dB_data = dB.Ptr<DType>();\n  int64_t A_offset = 0, dC_offset = 0, dB_offset = 0;\n  int64_t m, n, k;\n  int64_t num_rel = seglen.NumElements();\n  DType alpha = 1., beta = 0.;\n\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  if (!thr_entry->cublas_handle)\n    CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));\n  CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));\n\n  IdType k_offset = 0;\n  for (IdType etype = 0; etype < num_rel; ++etype) {\n    m = dC->shape[1];\n    n = A->shape[1];\n    k = seglen_data[etype];\n    CHECK_LE(k_offset + k, A->shape[0])\n        << \"Segement index out of bound of A->shape[0].\";\n    int lddC = m, ldA = n, lddB = m;\n    cublasOperation_t trans_dC = CUBLAS_OP_N;\n    cublasOperation_t trans_A = CUBLAS_OP_T;\n    CUBLAS_CALL(cublasGemm<DType>(\n        thr_entry->cublas_handle, trans_dC, trans_A, m, n, k, &alpha,\n        dC_data + dC_offset, lddC, A_data + A_offset, ldA, &beta,\n        dB_data + dB_offset, lddB));\n    dC_offset += m * k;\n    A_offset += n * k;\n    dB_offset += m * n;\n    k_offset += k;\n  }\n}\n\n/**\n * @brief Implementation of Gather_mm operator. 
The input matrix A is\n *        expected to be sorted according to relation type.\n * @param A The input dense matrix of dimension m x k\n * @param B The input dense matrix of dimension k x n\n * @param C The output dense matrix of dimension m x n\n * @param idx_a The input vector to gather left hand operand on\n * @param idx_b The input vector to gather right hand operand on\n */\n\ntemplate <int XPU, typename IdType, typename DType>\nvoid GatherMM(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b) {\n  auto device = runtime::DeviceAPI::Get(A->ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int64_t out_len = B->shape[2];  // cols of B\n  int64_t in_len = A->shape[1];   // cols of A\n  const int64_t tot_num_rows = A->shape[0];\n  const int ntx = 128;\n  const int warp_size = 32;\n  const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);\n  const dim3 nblks(nbx);\n  const dim3 nthrs(ntx);\n  CUDA_KERNEL_CALL(\n      (cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,\n      A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),\n      idx_b.Ptr<IdType>(), nullptr, tot_num_rows, in_len, out_len);\n}\n\n/**\n * @brief Implementation of Gather_mm operator. 
The input matrix A is\n *        expected to be sorted according to relation type.\n * @param A The input dense matrix of dimension m x k\n * @param B The input dense matrix of dimension k x n\n * @param C The output dense matrix of dimension m x n\n * @param idx_a The input vector to gather left hand operand on\n * @param idx_b The input vector to gather right hand operand on\n * @param idx_c The input vector to gather output operand on\n * @param num_rel The number of idx types in idx_b\n * @param a_trans Matrix A to be transposed\n * @param b_trans Matrix B to be transposed\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid GatherMMScatter(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c) {\n  auto device = runtime::DeviceAPI::Get(A->ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const IdType* idx_c_data = idx_c.Ptr<IdType>();\n  int64_t out_len = (B->ndim == 2) ? B->shape[1] : B->shape[2];  // cols of B\n  int64_t in_len = A->shape[1];                                  // cols of A\n  int64_t tot_num_rows = A->shape[0];\n  const int ntx = 128;\n  const int warp_size = 32;\n  const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);\n  const dim3 nblks(nbx);\n  const dim3 nthrs(ntx);\n  if (B->ndim == 3) {\n    CUDA_KERNEL_CALL(\n        (cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,\n        A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),\n        idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(), tot_num_rows, in_len,\n        out_len);\n  } else {\n    // Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i]\n    // This kernel accesses rows of A in a transposed way w/o explicitly\n    // converting A\n    CUDA_KERNEL_CALL(\n        (cuda::GatherMMScatterKernel2<IdType, DType>), nblks, nthrs, 0, stream,\n        A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),\n        idx_b.Ptr<IdType>(), 
idx_c.Ptr<IdType>(), tot_num_rows, in_len,\n        out_len);\n  }\n}\n\ntemplate void GatherMM<kDGLCUDA, int32_t, __half>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\ntemplate void GatherMM<kDGLCUDA, int64_t, __half>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\n#if BF16_ENABLED\ntemplate void GatherMM<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\ntemplate void GatherMM<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\n#endif  // BF16_ENABLED\ntemplate void GatherMM<kDGLCUDA, int32_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\ntemplate void GatherMM<kDGLCUDA, int64_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\ntemplate void GatherMM<kDGLCUDA, int32_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\ntemplate void GatherMM<kDGLCUDA, int64_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b);\n\ntemplate void GatherMMScatter<kDGLCUDA, int32_t, __half>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\ntemplate void GatherMMScatter<kDGLCUDA, int64_t, __half>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\n#if BF16_ENABLED\ntemplate void GatherMMScatter<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\ntemplate void GatherMMScatter<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const NDArray A, const NDArray B, 
NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\n#endif  // BF16_ENABLED\ntemplate void GatherMMScatter<kDGLCUDA, int32_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\ntemplate void GatherMMScatter<kDGLCUDA, int64_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\ntemplate void GatherMMScatter<kDGLCUDA, int32_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\ntemplate void GatherMMScatter<kDGLCUDA, int64_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\n\ntemplate void SegmentMM<kDGLCUDA, int32_t, __half>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\ntemplate void SegmentMM<kDGLCUDA, int64_t, __half>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\n#if BF16_ENABLED\ntemplate void SegmentMM<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\ntemplate void SegmentMM<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\n#endif  // BF16_ENABLED\ntemplate void SegmentMM<kDGLCUDA, int32_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\ntemplate void SegmentMM<kDGLCUDA, int64_t, float>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\ntemplate void SegmentMM<kDGLCUDA, int32_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool 
b_trans);\ntemplate void SegmentMM<kDGLCUDA, int64_t, double>(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\n\ntemplate void SegmentMMBackwardB<kDGLCUDA, int32_t, __half>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\ntemplate void SegmentMMBackwardB<kDGLCUDA, int64_t, __half>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\n#if BF16_ENABLED\ntemplate void SegmentMMBackwardB<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\ntemplate void SegmentMMBackwardB<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\n#endif  // BF16_ENABLED\ntemplate void SegmentMMBackwardB<kDGLCUDA, int32_t, float>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\ntemplate void SegmentMMBackwardB<kDGLCUDA, int64_t, float>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\ntemplate void SegmentMMBackwardB<kDGLCUDA, int32_t, double>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\ntemplate void SegmentMMBackwardB<kDGLCUDA, int64_t, double>(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/ge_spmm.cuh",
    "content": "/**\n * Copyright (c) 2020 by Contributors\n * @file array/cuda/ge_spmm.cuh\n * @brief GE-SpMM CUDA kernel function header.\n */\n#ifndef DGL_ARRAY_CUDA_GE_SPMM_CUH_\n#define DGL_ARRAY_CUDA_GE_SPMM_CUH_\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n#include \"atomic.cuh\"\n#include \"macro.cuh\"\n\nnamespace dgl {\n\nusing namespace cuda;\n\nnamespace aten {\nnamespace cuda {\n\n/**\n * @brief CUDA kernel of GE-SpMM on Csr.\n * @note GE-SpMM: https://arxiv.org/pdf/2007.03179.pdf\n *       The grid dimension x and y are reordered for better performance.\n */\ntemplate <typename Idx, typename DType, typename BinaryOp>\n__global__ void GESpMMKernel(\n    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,\n    DType* __restrict__ out, const Idx* __restrict__ indptr,\n    const Idx* __restrict__ indices, const int64_t num_rows,\n    const int64_t num_cols, const int64_t feat_len) {\n  const Idx rid =\n      blockIdx.x * blockDim.y + threadIdx.y;        // over vertices dimension\n  const Idx fid = (blockIdx.y * 64) + threadIdx.x;  // over feature dimension\n\n  if (rid < num_rows && fid < feat_len) {\n    const Idx low = __ldg(indptr + rid), high = __ldg(indptr + rid + 1);\n    DType accum_0 = 0., accum_1 = 0.;\n\n    if (blockIdx.y != gridDim.y - 1) {  // fid + 32 < feat_len\n      for (Idx left = low; left < high; left += 32) {\n        if (left + 32 <= high) {\n#pragma unroll\n          for (Idx i = 0; i < 32; ++i) {\n            const Idx eid = left + i;\n            const Idx cid = __ldg(indices + eid);\n            const Idx offset = feat_len * cid + fid;\n            if (BinaryOp::use_rhs) {\n              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);\n              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);\n            } else {\n              accum_0 += ufeat[offset];\n              accum_1 += ufeat[offset + 32];\n            }\n          }\n        } else {\n          for 
(Idx i = 0; left + i < high; ++i) {\n            const Idx eid = left + i;\n            const Idx cid = __ldg(indices + eid);\n            const Idx offset = feat_len * cid + fid;\n            if (BinaryOp::use_rhs) {\n              accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid);\n              accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid);\n            } else {\n              accum_0 += ufeat[offset];\n              accum_1 += ufeat[offset + 32];\n            }\n          }\n        }\n\n        out[feat_len * rid + fid] = accum_0;\n        out[feat_len * rid + fid + 32] = accum_1;\n      }\n    } else {\n      const Idx fid_0 = fid < feat_len ? fid : 0,\n                fid_1 = fid + 32 < feat_len ? fid + 32 : 0;\n      for (int left = low; left < high; left += 32) {\n        if (left + 32 <= high) {\n#pragma unroll\n          for (int i = 0; i < 32; ++i) {\n            const Idx eid = left + i;\n            const Idx cid = __ldg(indices + eid);\n            const Idx offset = feat_len * cid;\n            if (BinaryOp::use_rhs) {\n              accum_0 += BinaryOp::Call(ufeat + offset + fid_0, efeat + eid);\n              accum_1 += BinaryOp::Call(ufeat + offset + fid_1, efeat + eid);\n            } else {\n              accum_0 += ufeat[offset + fid_0];\n              accum_1 += ufeat[offset + fid_1];\n            }\n          }\n        } else {\n          for (int i = 0; i + left < high; ++i) {\n            const Idx eid = left + i;\n            const Idx cid = __ldg(indices + eid);\n            const Idx offset = feat_len * cid;\n            if (BinaryOp::use_rhs) {\n              accum_0 += BinaryOp::Call(ufeat + offset + fid_0, efeat + eid);\n              accum_1 += BinaryOp::Call(ufeat + offset + fid_1, efeat + eid);\n            } else {\n              accum_0 += ufeat[offset + fid_0];\n              accum_1 += ufeat[offset + fid_1];\n            }\n          }\n        }\n\n        out[feat_len * rid + fid] = accum_0;\n        if 
(fid + 32 < feat_len) out[feat_len * rid + fid + 32] = accum_1;\n      }\n    }\n  }\n}\n\ntemplate <typename Idx, typename DType, typename BinaryOp>\nvoid GESpMMCsr(\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    int64_t feat_len) {\n  const Idx* indptr = csr.indptr.Ptr<Idx>();\n  const Idx* indices = csr.indices.Ptr<Idx>();\n  const DType* ufeat_data = ufeat.Ptr<DType>();\n  const DType* efeat_data = efeat.Ptr<DType>();\n  DType* out_data = out.Ptr<DType>();\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  const int ntx = 32;\n  const int nty = 32;\n  const int nby = (feat_len + (ntx * 2) - 1) / (ntx * 2);\n  const int nbx = (csr.num_rows + nty - 1) / nty;\n  const dim3 nblks(nbx, nby);\n  const dim3 nthrs(ntx, nty);\n  const int sh_mem_size = 0;\n\n  CUDA_KERNEL_CALL(\n      (GESpMMKernel<Idx, DType, BinaryOp>), nblks, nthrs, sh_mem_size, stream,\n      ufeat_data, efeat_data, out_data, indptr, indices, csr.num_rows,\n      csr.num_cols, feat_len);\n}\n\n}  // namespace cuda\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CUDA_GE_SPMM_CUH_\n"
  },
  {
    "path": "src/array/cuda/labor_sampling.cu",
    "content": "/*!\n *   Copyright (c) 2022, NVIDIA Corporation\n *   Copyright (c) 2022, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file array/cuda/labor_sampling.cu\n * @brief labor sampling\n */\n\n#include <dgl/aten/coo.h>\n#include <dgl/random.h>\n#include <dgl/runtime/device_api.h>\n#include <thrust/binary_search.h>\n#include <thrust/copy.h>\n#include <thrust/execution_policy.h>\n#include <thrust/gather.h>\n#include <thrust/iterator/constant_iterator.h>\n#include <thrust/iterator/counting_iterator.h>\n#include <thrust/iterator/zip_iterator.h>\n#include <thrust/reduce.h>\n#include <thrust/shuffle.h>\n#include <thrust/transform.h>\n#include <thrust/zip_function.h>\n\n#include <algorithm>\n#include <cub/cub.cuh>  // NOLINT\n#include <limits>\n#include <numeric>\n#include <type_traits>\n#include <utility>\n\n#include \"../../array/cuda/utils.h\"\n#include \"../../random/continuous_seed.h\"\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./functor.cuh\"\n#include \"./spmm.cuh\"\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\nusing dgl::random::continuous_seed;\n\nconstexpr int BLOCK_SIZE = 128;\nconstexpr int CTA_SIZE = 128;\nconstexpr double eps = 0.0001;\n\nnamespace {\n\ntemplate <typename IdType>\nstruct TransformOp {\n  const IdType* idx_coo;\n  const IdType* rows;\n  const IdType* indptr;\n  const IdType* 
subindptr;\n  const IdType* indices;\n  const IdType* data_arr;\n  bool is_pinned;\n  __host__ __device__ auto operator()(IdType idx) {\n    const auto in_row = idx_coo[idx];\n    const auto row = rows[in_row];\n    const auto in_idx = indptr[in_row] + idx - subindptr[in_row];\n    const auto u = indices[is_pinned ? idx : in_idx];\n    const auto data = data_arr ? data_arr[in_idx] : in_idx;\n    return thrust::make_tuple(row, u, data);\n  }\n};\n\ntemplate <\n    typename IdType, typename FloatType, typename probs_t, typename A_t,\n    typename B_t>\nstruct TransformOpImp {\n  probs_t probs;\n  A_t A;\n  B_t B;\n  const IdType* idx_coo;\n  const IdType* rows;\n  const FloatType* cs;\n  const IdType* indptr;\n  const IdType* subindptr;\n  const IdType* indices;\n  const IdType* data_arr;\n  bool is_pinned;\n  __host__ __device__ auto operator()(IdType idx) {\n    const auto ps = probs[idx];\n    const auto in_row = idx_coo[idx];\n    const auto c = cs[in_row];\n    const auto row = rows[in_row];\n    const auto in_idx = indptr[in_row] + idx - subindptr[in_row];\n    const auto u = indices[is_pinned ? idx : in_idx];\n    const auto w = A[in_idx];\n    const auto w2 = B[in_idx];\n    const auto data = data_arr ? 
data_arr[in_idx] : in_idx;\n    return thrust::make_tuple(\n        in_row, row, u, data, w / min((FloatType)1, c * w2 * ps));\n  }\n};\n\ntemplate <typename FloatType>\nstruct StencilOp {\n  const FloatType* cs;\n  template <typename IdType>\n  __host__ __device__ auto operator()(\n      IdType in_row, FloatType ps, FloatType rnd) {\n    return rnd <= cs[in_row] * ps;\n  }\n};\n\ntemplate <typename IdType, typename FloatType, typename ps_t, typename A_t>\nstruct StencilOpFused {\n  const continuous_seed seed;\n  const IdType* idx_coo;\n  const FloatType* cs;\n  const ps_t probs;\n  const A_t A;\n  const IdType* subindptr;\n  const IdType* indptr;\n  const IdType* indices;\n  const IdType* nids;\n  bool is_pinned;\n  __device__ auto operator()(IdType idx) {\n    const auto in_row = idx_coo[idx];\n    const auto ps = probs[idx];\n    IdType rofs = idx - subindptr[in_row];\n    const auto in_idx = indptr[in_row] + rofs;\n    const auto u = indices[is_pinned ? idx : in_idx];\n    const auto t = nids ? nids[u] : u;  // t in the paper\n    // rolled random number r_t is a function of the random_seed and t\n    const float rnd = seed.uniform(t);\n    return rnd <= cs[in_row] * A[in_idx] * ps;\n  }\n};\n\ntemplate <typename IdType, typename FloatType>\nstruct TransformOpMean {\n  const IdType* ds;\n  const FloatType* ws;\n  __host__ __device__ auto operator()(IdType idx, FloatType ps) {\n    return ps * ds[idx] / ws[idx];\n  }\n};\n\nstruct TransformOpMinWith1 {\n  template <typename FloatType>\n  __host__ __device__ auto operator()(FloatType x) {\n    return min((FloatType)1, x);\n  }\n};\n\ntemplate <typename IdType>\nstruct IndptrFunc {\n  const IdType* indptr;\n  const IdType* in_deg;\n  __host__ __device__ auto operator()(IdType row) {\n    return indptr[row] + (in_deg ? 
in_deg[row] : 0);\n  }\n};\n\ntemplate <typename FloatType>\nstruct SquareFunc {\n  __host__ __device__ auto operator()(FloatType x) {\n    return thrust::make_tuple(x, x * x);\n  }\n};\n\nstruct TupleSum {\n  template <typename T>\n  __host__ __device__ T operator()(const T& a, const T& b) const {\n    return thrust::make_tuple(\n        thrust::get<0>(a) + thrust::get<0>(b),\n        thrust::get<1>(a) + thrust::get<1>(b));\n  }\n};\n\ntemplate <typename IdType, typename FloatType>\nstruct DegreeFunc {\n  const IdType num_picks;\n  const IdType* rows;\n  const IdType* indptr;\n  IdType* in_deg;\n  IdType* inrow_indptr;\n  FloatType* cs;\n  __host__ __device__ auto operator()(IdType tIdx) {\n    const auto out_row = rows[tIdx];\n    const auto indptr_val = indptr[out_row];\n    const auto d = indptr[out_row + 1] - indptr_val;\n    in_deg[tIdx] = d;\n    inrow_indptr[tIdx] = indptr_val;\n    cs[tIdx] = num_picks / (FloatType)d;\n  }\n};\n\ntemplate <typename IdType, typename FloatType>\n__global__ void _CSRRowWiseOneHopExtractorKernel(\n    const continuous_seed seed, const IdType hop_size,\n    const IdType* const indptr, const IdType* const subindptr,\n    const IdType* const indices, const IdType* const idx_coo,\n    const IdType* const nids, const FloatType* const A, FloatType* const rands,\n    IdType* const hop, FloatType* const A_l) {\n  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n\n  while (tx < hop_size) {\n    IdType rpos = idx_coo[tx];\n    IdType rofs = tx - subindptr[rpos];\n    const auto in_idx = indptr[rpos] + rofs;\n    const auto not_pinned = indices != hop;\n    const auto u = indices[not_pinned ? in_idx : tx];\n    if (not_pinned) hop[tx] = u;\n    const auto t = nids ? 
nids[u] : u;\n    if (A) A_l[tx] = A[in_idx];\n    // rolled random number r_t is a function of the random_seed and t\n    rands[tx] = (FloatType)seed.uniform(t);\n    tx += stride_x;\n  }\n}\n\nconstexpr int CACHE_LINE_SIZE = 128;\n\ntemplate <typename IdType>\nstruct AlignmentFunc {\n  static_assert(CACHE_LINE_SIZE % sizeof(IdType) == 0);\n  const IdType* in_deg;\n  const int64_t* perm;\n  IdType num_rows;\n  __host__ __device__ auto operator()(IdType row) {\n    constexpr int num_elements = CACHE_LINE_SIZE / sizeof(IdType);\n    return in_deg[perm ? perm[row % num_rows] : row] + num_elements - 1;\n  }\n};\n\ntemplate <typename IdType>\n__global__ void _CSRRowWiseOneHopExtractorAlignedKernel(\n    const IdType hop_size, const IdType num_rows, const IdType* const indptr,\n    const IdType* const subindptr, const IdType* const subindptr_aligned,\n    const IdType* const indices, IdType* const hop, const int64_t* const perm) {\n  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n\n  while (tx < hop_size) {\n    const IdType rpos_ =\n        dgl::cuda::_UpperBound(subindptr_aligned, num_rows, tx) - 1;\n    const IdType rpos = perm ? 
perm[rpos_] : rpos_;\n    const auto out_row = subindptr[rpos];\n    const auto d = subindptr[rpos + 1] - out_row;\n    const int offset =\n        ((uint64_t)(indices + indptr[rpos] - subindptr_aligned[rpos_]) %\n         CACHE_LINE_SIZE) /\n        sizeof(IdType);\n    const IdType rofs = tx - subindptr_aligned[rpos_] - offset;\n    if (rofs >= 0 && rofs < d) {\n      const auto in_idx = indptr[rpos] + rofs;\n      assert((uint64_t)(indices + in_idx - tx) % CACHE_LINE_SIZE == 0);\n      const auto u = indices[in_idx];\n      hop[out_row + rofs] = u;\n    }\n    tx += stride_x;\n  }\n}\n\ntemplate <typename IdType, typename FloatType, int BLOCK_CTAS, int TILE_SIZE>\n__global__ void _CSRRowWiseLayerSampleDegreeKernel(\n    const IdType num_picks, const IdType num_rows, FloatType* const cs,\n    const FloatType* const ds, const FloatType* const d2s,\n    const IdType* const indptr, const FloatType* const probs,\n    const FloatType* const A, const IdType* const subindptr) {\n  typedef cub::BlockReduce<FloatType, BLOCK_SIZE> BlockReduce;\n  __shared__ typename BlockReduce::TempStorage temp_storage;\n  __shared__ FloatType var_1_bcast[BLOCK_CTAS];\n\n  // we assign one warp per row\n  assert(blockDim.x == CTA_SIZE);\n  assert(blockDim.y == BLOCK_CTAS);\n\n  IdType out_row = blockIdx.x * TILE_SIZE + threadIdx.y;\n  const auto last_row =\n      min(static_cast<IdType>(blockIdx.x + 1) * TILE_SIZE, num_rows);\n\n  constexpr FloatType ONE = 1;\n\n  while (out_row < last_row) {\n    const auto in_row_start = indptr[out_row];\n    const auto out_row_start = subindptr[out_row];\n\n    const IdType degree = subindptr[out_row + 1] - out_row_start;\n\n    if (degree > 0) {\n      // stands for k in in arXiv:2210.13339, i.e. fanout\n      const auto k = min(num_picks, degree);\n      // slightly better than NS\n      const FloatType d_ = ds ? 
ds[out_row] : degree;\n      // stands for right handside of Equation (22) in arXiv:2210.13339\n      FloatType var_target =\n          d_ * d_ / k + (ds ? d2s[out_row] - d_ * d_ / degree : 0);\n\n      auto c = cs[out_row];\n      const int num_valid = min(degree, (IdType)CTA_SIZE);\n      // stands for left handside of Equation (22) in arXiv:2210.13339\n      FloatType var_1;\n      do {\n        var_1 = 0;\n        if (A) {\n          for (int idx = threadIdx.x; idx < degree; idx += CTA_SIZE) {\n            const auto w = A[in_row_start + idx];\n            const auto ps = probs ? probs[out_row_start + idx] : w;\n            var_1 += w > 0 ? w * w / min(ONE, c * ps) : 0;\n          }\n        } else {\n          for (int idx = threadIdx.x; idx < degree; idx += CTA_SIZE) {\n            const auto ps = probs[out_row_start + idx];\n            var_1 += 1 / min(ONE, c * ps);\n          }\n        }\n        var_1 = BlockReduce(temp_storage).Sum(var_1, num_valid);\n        if (threadIdx.x == 0) var_1_bcast[threadIdx.y] = var_1;\n        __syncthreads();\n        var_1 = var_1_bcast[threadIdx.y];\n\n        c *= var_1 / var_target;\n      } while (min(var_1, var_target) / max(var_1, var_target) < 1 - eps);\n\n      if (threadIdx.x == 0) cs[out_row] = c;\n    }\n\n    out_row += BLOCK_CTAS;\n  }\n}\n\n}  // namespace\n\ntemplate <typename IdType>\nint log_size(const IdType size) {\n  if (size <= 0) return 0;\n  for (int i = 0; i < static_cast<int>(sizeof(IdType)) * 8; i++)\n    if (((size - 1) >> i) == 0) return i;\n  return sizeof(IdType) * 8;\n}\n\ntemplate <typename IdType, typename FloatType, typename exec_policy_t>\nvoid compute_importance_sampling_probabilities(\n    CSRMatrix mat, const IdType hop_size, cudaStream_t stream,\n    const continuous_seed seed, const IdType num_rows, const IdType* indptr,\n    const IdType* subindptr, const IdType* indices, IdArray idx_coo_arr,\n    const IdType* nids,\n    FloatArray cs_arr,  // holds the computed cs values, has 
size num_rows\n    const bool weighted, const FloatType* A, const FloatType* ds,\n    const FloatType* d2s, const IdType num_picks, DGLContext ctx,\n    const runtime::CUDAWorkspaceAllocator& allocator,\n    const exec_policy_t& exec_policy, const int importance_sampling,\n    IdType* hop_1,  // holds the contiguous one-hop neighborhood, has size |E|\n    FloatType* rands,  // holds the rolled random numbers r_t for each edge, has\n                       // size |E|\n    FloatType* probs_found) {  // holds the computed pi_t values for each edge,\n                               // has size |E|\n  auto device = runtime::DeviceAPI::Get(ctx);\n  auto idx_coo = idx_coo_arr.Ptr<IdType>();\n  auto cs = cs_arr.Ptr<FloatType>();\n  FloatArray A_l_arr = weighted\n                           ? NewFloatArray(hop_size, ctx, sizeof(FloatType) * 8)\n                           : NullArray();\n  auto A_l = A_l_arr.Ptr<FloatType>();\n\n  const int max_log_num_vertices = log_size(mat.num_cols);\n\n  {  // extracts the onehop neighborhood cols to a contiguous range into hop_1\n    const dim3 block(BLOCK_SIZE);\n    const dim3 grid((hop_size + BLOCK_SIZE - 1) / BLOCK_SIZE);\n    CUDA_KERNEL_CALL(\n        (_CSRRowWiseOneHopExtractorKernel<IdType, FloatType>), grid, block, 0,\n        stream, seed, hop_size, indptr, subindptr, indices, idx_coo, nids,\n        weighted ? A : nullptr, rands, hop_1, A_l);\n  }\n  int64_t hop_uniq_size = 0;\n  IdArray hop_new_arr = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);\n  auto hop_new = hop_new_arr.Ptr<IdType>();\n  auto hop_unique = allocator.alloc_unique<IdType>(hop_size);\n  // After this block, hop_unique holds the unique set of one-hop neighborhood\n  // and hop_new holds the relabeled hop_1, idx_coo already holds relabeled\n  // destination. 
hop_unique[hop_new] == hop_1 holds\n  {\n    auto hop_2 = allocator.alloc_unique<IdType>(hop_size);\n    auto hop_3 = allocator.alloc_unique<IdType>(hop_size);\n\n    device->CopyDataFromTo(\n        hop_1, 0, hop_2.get(), 0, sizeof(IdType) * hop_size, ctx, ctx,\n        mat.indptr->dtype);\n\n    cub::DoubleBuffer<IdType> hop_b(hop_2.get(), hop_3.get());\n\n    {\n      std::size_t temp_storage_bytes = 0;\n      CUDA_CALL(cub::DeviceRadixSort::SortKeys(\n          nullptr, temp_storage_bytes, hop_b, hop_size, 0, max_log_num_vertices,\n          stream));\n\n      auto temp = allocator.alloc_unique<char>(temp_storage_bytes);\n\n      CUDA_CALL(cub::DeviceRadixSort::SortKeys(\n          temp.get(), temp_storage_bytes, hop_b, hop_size, 0,\n          max_log_num_vertices, stream));\n    }\n\n    auto hop_counts = allocator.alloc_unique<IdType>(hop_size + 1);\n    auto hop_unique_size = allocator.alloc_unique<int64_t>(1);\n\n    {\n      std::size_t temp_storage_bytes = 0;\n      CUDA_CALL(cub::DeviceRunLengthEncode::Encode(\n          nullptr, temp_storage_bytes, hop_b.Current(), hop_unique.get(),\n          hop_counts.get(), hop_unique_size.get(), hop_size, stream));\n\n      auto temp = allocator.alloc_unique<char>(temp_storage_bytes);\n\n      CUDA_CALL(cub::DeviceRunLengthEncode::Encode(\n          temp.get(), temp_storage_bytes, hop_b.Current(), hop_unique.get(),\n          hop_counts.get(), hop_unique_size.get(), hop_size, stream));\n\n      device->CopyDataFromTo(\n          hop_unique_size.get(), 0, &hop_uniq_size, 0, sizeof(hop_uniq_size),\n          ctx, DGLContext{kDGLCPU, 0}, mat.indptr->dtype);\n    }\n\n    thrust::lower_bound(\n        exec_policy, hop_unique.get(), hop_unique.get() + hop_uniq_size, hop_1,\n        hop_1 + hop_size, hop_new);\n  }\n\n  // @todo Consider creating a CSC because the SpMV will be done multiple times.\n  COOMatrix rmat(\n      num_rows, hop_uniq_size, idx_coo_arr, hop_new_arr, NullArray(), true,\n      mat.sorted);\n\n  
BcastOff bcast_off;\n  bcast_off.use_bcast = false;\n  bcast_off.out_len = 1;\n  bcast_off.lhs_len = 1;\n  bcast_off.rhs_len = 1;\n\n  FloatArray probs_arr =\n      NewFloatArray(hop_uniq_size, ctx, sizeof(FloatType) * 8);\n  auto probs_1 = probs_arr.Ptr<FloatType>();\n  FloatArray probs_arr_2 =\n      NewFloatArray(hop_uniq_size, ctx, sizeof(FloatType) * 8);\n  auto probs = probs_arr_2.Ptr<FloatType>();\n  auto arg_u = NewIdArray(hop_uniq_size, ctx, sizeof(IdType) * 8);\n  auto arg_e = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);\n\n  double prev_ex_nodes = hop_uniq_size;\n\n  for (int iters = 0; iters < importance_sampling || importance_sampling < 0;\n       iters++) {\n    if (weighted && iters == 0) {\n      cuda::SpMMCoo<\n          IdType, FloatType, cuda::binary::Mul<FloatType>,\n          cuda::reduce::Max<IdType, FloatType, true>>(\n          bcast_off, rmat, cs_arr, A_l_arr, probs_arr_2, arg_u, arg_e);\n    } else {\n      cuda::SpMMCoo<\n          IdType, FloatType, cuda::binary::CopyLhs<FloatType>,\n          cuda::reduce::Max<IdType, FloatType, true>>(\n          bcast_off, rmat, cs_arr, NullArray(), iters ? probs_arr : probs_arr_2,\n          arg_u, arg_e);\n    }\n\n    if (iters)\n      thrust::transform(\n          exec_policy, probs_1, probs_1 + hop_uniq_size, probs, probs,\n          thrust::multiplies<FloatType>{});\n\n    thrust::gather(\n        exec_policy, hop_new, hop_new + hop_size, probs, probs_found);\n\n    {\n      constexpr int BLOCK_CTAS = BLOCK_SIZE / CTA_SIZE;\n      // the number of rows each thread block will cover\n      constexpr int TILE_SIZE = BLOCK_CTAS;\n      const dim3 block(CTA_SIZE, BLOCK_CTAS);\n      const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);\n      CUDA_KERNEL_CALL(\n          (_CSRRowWiseLayerSampleDegreeKernel<\n              IdType, FloatType, BLOCK_CTAS, TILE_SIZE>),\n          grid, block, 0, stream, (IdType)num_picks, num_rows, cs,\n          weighted ? ds : nullptr, weighted ? 
d2s : nullptr, indptr,\n          probs_found, A, subindptr);\n    }\n\n    {\n      auto probs_min_1 =\n          thrust::make_transform_iterator(probs, TransformOpMinWith1{});\n      const double cur_ex_nodes = thrust::reduce(\n          exec_policy, probs_min_1, probs_min_1 + hop_uniq_size, 0.0);\n      if (cur_ex_nodes / prev_ex_nodes >= 1 - eps) break;\n      prev_ex_nodes = cur_ex_nodes;\n    }\n  }\n}\n\n/////////////////////////////// CSR ///////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType, typename FloatType>\nstd::pair<COOMatrix, FloatArray> CSRLaborSampling(\n    CSRMatrix mat, IdArray rows_arr, const int64_t num_picks,\n    FloatArray prob_arr, const int importance_sampling, IdArray random_seed_arr,\n    float seed2_contribution, IdArray NIDs) {\n  const bool weighted = !IsNullArray(prob_arr);\n\n  const auto& ctx = rows_arr->ctx;\n\n  runtime::CUDAWorkspaceAllocator allocator(ctx);\n\n  const auto stream = runtime::getCurrentCUDAStream();\n  const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);\n\n  auto device = runtime::DeviceAPI::Get(ctx);\n\n  const IdType num_rows = rows_arr->shape[0];\n  IdType* const rows = rows_arr.Ptr<IdType>();\n  IdType* const nids = IsNullArray(NIDs) ? nullptr : NIDs.Ptr<IdType>();\n  FloatType* const A = prob_arr.Ptr<FloatType>();\n\n  IdType* const indptr_ = mat.indptr.Ptr<IdType>();\n  IdType* const indices_ = mat.indices.Ptr<IdType>();\n  IdType* const data = CSRHasData(mat) ? 
mat.data.Ptr<IdType>() : nullptr;\n\n  // Read indptr only once in case it is pinned and access is slow.\n  auto indptr = allocator.alloc_unique<IdType>(num_rows);\n  // compute in-degrees\n  auto in_deg = allocator.alloc_unique<IdType>(num_rows + 1);\n  // cs stands for c_s in arXiv:2210.13339\n  FloatArray cs_arr = NewFloatArray(num_rows, ctx, sizeof(FloatType) * 8);\n  auto cs = cs_arr.Ptr<FloatType>();\n  // ds stands for A_{*s} in arXiv:2210.13339\n  FloatArray ds_arr = weighted\n                          ? NewFloatArray(num_rows, ctx, sizeof(FloatType) * 8)\n                          : NullArray();\n  auto ds = ds_arr.Ptr<FloatType>();\n  // d2s stands for (A^2)_{*s} in arXiv:2210.13339, ^2 is elementwise.\n  FloatArray d2s_arr = weighted\n                           ? NewFloatArray(num_rows, ctx, sizeof(FloatType) * 8)\n                           : NullArray();\n  auto d2s = d2s_arr.Ptr<FloatType>();\n\n  thrust::counting_iterator<IdType> iota(0);\n  thrust::for_each(\n      exec_policy, iota, iota + num_rows,\n      DegreeFunc<IdType, FloatType>{\n          (IdType)num_picks, rows, indptr_, in_deg.get(), indptr.get(), cs});\n\n  if (weighted) {\n    auto b_offsets = thrust::make_transform_iterator(\n        iota, IndptrFunc<IdType>{indptr.get(), nullptr});\n    auto e_offsets = thrust::make_transform_iterator(\n        iota, IndptrFunc<IdType>{indptr.get(), in_deg.get()});\n\n    auto A_A2 = thrust::make_transform_iterator(A, SquareFunc<FloatType>{});\n    auto ds_d2s = thrust::make_zip_iterator(ds, d2s);\n\n    size_t prefix_temp_size = 0;\n    CUDA_CALL(cub::DeviceSegmentedReduce::Reduce(\n        nullptr, prefix_temp_size, A_A2, ds_d2s, num_rows, b_offsets, e_offsets,\n        TupleSum{}, thrust::make_tuple((FloatType)0, (FloatType)0), stream));\n    auto temp = allocator.alloc_unique<char>(prefix_temp_size);\n    CUDA_CALL(cub::DeviceSegmentedReduce::Reduce(\n        temp.get(), prefix_temp_size, A_A2, ds_d2s, num_rows, b_offsets,\n        e_offsets, 
TupleSum{}, thrust::make_tuple((FloatType)0, (FloatType)0),\n        stream));\n  }\n\n  // fill subindptr\n  IdArray subindptr_arr = NewIdArray(num_rows + 1, ctx, sizeof(IdType) * 8);\n  auto subindptr = subindptr_arr.Ptr<IdType>();\n\n  IdType hop_size;\n  {\n    size_t prefix_temp_size = 0;\n    CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n        nullptr, prefix_temp_size, in_deg.get(), subindptr, num_rows + 1,\n        stream));\n    auto temp = allocator.alloc_unique<char>(prefix_temp_size);\n    CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n        temp.get(), prefix_temp_size, in_deg.get(), subindptr, num_rows + 1,\n        stream));\n\n    device->CopyDataFromTo(\n        subindptr, num_rows * sizeof(hop_size), &hop_size, 0, sizeof(hop_size),\n        ctx, DGLContext{kDGLCPU, 0}, mat.indptr->dtype);\n  }\n  IdArray hop_arr = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);\n  CSRMatrix smat(\n      num_rows, mat.num_cols, subindptr_arr, hop_arr, NullArray(), mat.sorted);\n  // @todo Consider fusing CSRToCOO into StencilOpFused kernel\n  auto smatcoo = CSRToCOO(smat, false);\n\n  auto idx_coo_arr = smatcoo.row;\n  auto idx_coo = idx_coo_arr.Ptr<IdType>();\n\n  auto hop_1 = hop_arr.Ptr<IdType>();\n  const bool is_pinned = mat.indices.IsPinned();\n  if (is_pinned) {\n    const auto res = Sort(rows_arr, log_size(mat.num_rows));\n    const int64_t* perm = static_cast<int64_t*>(res.second->data);\n\n    IdType hop_size;  // Shadows the original one as this is temporary\n    auto subindptr_aligned = allocator.alloc_unique<IdType>(num_rows + 1);\n    {\n      auto modified_in_deg = thrust::make_transform_iterator(\n          iota, AlignmentFunc<IdType>{in_deg.get(), perm, num_rows});\n      size_t prefix_temp_size = 0;\n      CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n          nullptr, prefix_temp_size, modified_in_deg, subindptr_aligned.get(),\n          num_rows + 1, stream));\n      auto temp = allocator.alloc_unique<char>(prefix_temp_size);\n      
CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n          temp.get(), prefix_temp_size, modified_in_deg,\n          subindptr_aligned.get(), num_rows + 1, stream));\n\n      device->CopyDataFromTo(\n          subindptr_aligned.get(), num_rows * sizeof(hop_size), &hop_size, 0,\n          sizeof(hop_size), ctx, DGLContext{kDGLCPU, 0}, mat.indptr->dtype);\n    }\n    const dim3 block(BLOCK_SIZE);\n    const dim3 grid((hop_size + BLOCK_SIZE - 1) / BLOCK_SIZE);\n    CUDA_KERNEL_CALL(\n        (_CSRRowWiseOneHopExtractorAlignedKernel<IdType>), grid, block, 0,\n        stream, hop_size, num_rows, indptr.get(), subindptr,\n        subindptr_aligned.get(), indices_, hop_1, perm);\n  }\n  const auto indices = is_pinned ? hop_1 : indices_;\n\n  auto rands =\n      allocator.alloc_unique<FloatType>(importance_sampling ? hop_size : 1);\n  auto probs_found =\n      allocator.alloc_unique<FloatType>(importance_sampling ? hop_size : 1);\n\n  if (weighted) {\n    // Recompute c for weighted graphs.\n    constexpr int BLOCK_CTAS = BLOCK_SIZE / CTA_SIZE;\n    // the number of rows each thread block will cover\n    constexpr int TILE_SIZE = BLOCK_CTAS;\n    const dim3 block(CTA_SIZE, BLOCK_CTAS);\n    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);\n    CUDA_KERNEL_CALL(\n        (_CSRRowWiseLayerSampleDegreeKernel<\n            IdType, FloatType, BLOCK_CTAS, TILE_SIZE>),\n        grid, block, 0, stream, (IdType)num_picks, num_rows, cs, ds, d2s,\n        indptr.get(), nullptr, A, subindptr);\n  }\n\n  const continuous_seed random_seed =\n      IsNullArray(random_seed_arr)\n          ? 
continuous_seed(RandomEngine::ThreadLocal()->RandInt(1000000000))\n          : continuous_seed(random_seed_arr, seed2_contribution);\n\n  if (importance_sampling)\n    compute_importance_sampling_probabilities<\n        IdType, FloatType, decltype(exec_policy)>(\n        mat, hop_size, stream, random_seed, num_rows, indptr.get(), subindptr,\n        indices, idx_coo_arr, nids, cs_arr, weighted, A, ds, d2s,\n        (IdType)num_picks, ctx, allocator, exec_policy, importance_sampling,\n        hop_1, rands.get(), probs_found.get());\n\n  IdArray picked_row = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);\n  IdArray picked_col = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);\n  IdArray picked_idx = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);\n  FloatArray picked_imp =\n      importance_sampling || weighted\n          ? NewFloatArray(hop_size, ctx, sizeof(FloatType) * 8)\n          : NullArray();\n\n  IdType* const picked_row_data = picked_row.Ptr<IdType>();\n  IdType* const picked_col_data = picked_col.Ptr<IdType>();\n  IdType* const picked_idx_data = picked_idx.Ptr<IdType>();\n  FloatType* const picked_imp_data = picked_imp.Ptr<FloatType>();\n\n  auto picked_inrow = allocator.alloc_unique<IdType>(\n      importance_sampling || weighted ? 
hop_size : 1);\n\n  // Sample edges here\n  IdType num_edges;\n  {\n    thrust::constant_iterator<FloatType> one(1);\n    if (importance_sampling) {\n      auto output = thrust::make_zip_iterator(\n          picked_inrow.get(), picked_row_data, picked_col_data, picked_idx_data,\n          picked_imp_data);\n      if (weighted) {\n        auto transformed_output = thrust::make_transform_output_iterator(\n            output,\n            TransformOpImp<\n                IdType, FloatType, FloatType*, FloatType*, decltype(one)>{\n                probs_found.get(), A, one, idx_coo, rows, cs, indptr.get(),\n                subindptr, indices, data, is_pinned});\n        auto stencil =\n            thrust::make_zip_iterator(idx_coo, probs_found.get(), rands.get());\n        num_edges =\n            thrust::copy_if(\n                exec_policy, iota, iota + hop_size, stencil, transformed_output,\n                thrust::make_zip_function(StencilOp<FloatType>{cs})) -\n            transformed_output;\n      } else {\n        auto transformed_output = thrust::make_transform_output_iterator(\n            output,\n            TransformOpImp<\n                IdType, FloatType, FloatType*, decltype(one), decltype(one)>{\n                probs_found.get(), one, one, idx_coo, rows, cs, indptr.get(),\n                subindptr, indices, data, is_pinned});\n        auto stencil =\n            thrust::make_zip_iterator(idx_coo, probs_found.get(), rands.get());\n        num_edges =\n            thrust::copy_if(\n                exec_policy, iota, iota + hop_size, stencil, transformed_output,\n                thrust::make_zip_function(StencilOp<FloatType>{cs})) -\n            transformed_output;\n      }\n    } else {\n      if (weighted) {\n        auto output = thrust::make_zip_iterator(\n            picked_inrow.get(), picked_row_data, picked_col_data,\n            picked_idx_data, picked_imp_data);\n        auto transformed_output = thrust::make_transform_output_iterator(\n       
     output,\n            TransformOpImp<\n                IdType, FloatType, decltype(one), FloatType*, FloatType*>{\n                one, A, A, idx_coo, rows, cs, indptr.get(), subindptr, indices,\n                data, is_pinned});\n        const auto pred =\n            StencilOpFused<IdType, FloatType, decltype(one), FloatType*>{\n                random_seed, idx_coo,      cs,      one,  A,\n                subindptr,   indptr.get(), indices, nids, is_pinned};\n        num_edges = thrust::copy_if(\n                        exec_policy, iota, iota + hop_size, iota,\n                        transformed_output, pred) -\n                    transformed_output;\n      } else {\n        auto output = thrust::make_zip_iterator(\n            picked_row_data, picked_col_data, picked_idx_data);\n        auto transformed_output = thrust::make_transform_output_iterator(\n            output, TransformOp<IdType>{\n                        idx_coo, rows, indptr.get(), subindptr, indices, data,\n                        is_pinned});\n        const auto pred =\n            StencilOpFused<IdType, FloatType, decltype(one), decltype(one)>{\n                random_seed, idx_coo,      cs,      one,  one,\n                subindptr,   indptr.get(), indices, nids, is_pinned};\n        num_edges = thrust::copy_if(\n                        exec_policy, iota, iota + hop_size, iota,\n                        transformed_output, pred) -\n                    transformed_output;\n      }\n    }\n  }\n\n  // Normalize edge weights here\n  if (importance_sampling || weighted) {\n    thrust::constant_iterator<IdType> one(1);\n    // contains degree information\n    auto ds = allocator.alloc_unique<IdType>(num_rows);\n    // contains sum of edge weights\n    auto ws = allocator.alloc_unique<FloatType>(num_rows);\n    // contains degree information only for vertices with nonzero degree\n    auto ds_2 = allocator.alloc_unique<IdType>(num_rows);\n    // contains sum of edge weights only for vertices 
with nonzero degree\n    auto ws_2 = allocator.alloc_unique<FloatType>(num_rows);\n    auto output_ = thrust::make_zip_iterator(ds.get(), ws.get());\n    // contains row ids only for vertices with nonzero degree\n    auto keys = allocator.alloc_unique<IdType>(num_rows);\n    auto input = thrust::make_zip_iterator(one, picked_imp_data);\n    auto new_end = thrust::reduce_by_key(\n        exec_policy, picked_inrow.get(), picked_inrow.get() + num_edges, input,\n        keys.get(), output_, thrust::equal_to<IdType>{}, TupleSum{});\n    {\n      thrust::constant_iterator<IdType> zero_int(0);\n      thrust::constant_iterator<FloatType> zero_float(0);\n      auto input = thrust::make_zip_iterator(zero_int, zero_float);\n      auto output = thrust::make_zip_iterator(ds_2.get(), ws_2.get());\n      thrust::copy(exec_policy, input, input + num_rows, output);\n      {\n        const auto num_rows_2 = new_end.first - keys.get();\n        thrust::scatter(\n            exec_policy, output_, output_ + num_rows_2, keys.get(), output);\n      }\n    }\n    {\n      auto input =\n          thrust::make_zip_iterator(picked_inrow.get(), picked_imp_data);\n      auto transformed_input = thrust::make_transform_iterator(\n          input, thrust::make_zip_function(TransformOpMean<IdType, FloatType>{\n                     ds_2.get(), ws_2.get()}));\n      thrust::copy(\n          exec_policy, transformed_input, transformed_input + num_edges,\n          picked_imp_data);\n    }\n  }\n\n  picked_row = picked_row.CreateView({num_edges}, picked_row->dtype);\n  picked_col = picked_col.CreateView({num_edges}, picked_col->dtype);\n  picked_idx = picked_idx.CreateView({num_edges}, picked_idx->dtype);\n  if (importance_sampling || weighted)\n    picked_imp = picked_imp.CreateView({num_edges}, picked_imp->dtype);\n\n  return std::make_pair(\n      COOMatrix(mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx),\n      picked_imp);\n}\n\ntemplate std::pair<COOMatrix, 
FloatArray>\nCSRLaborSampling<kDGLCUDA, int32_t, float>(\n    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\ntemplate std::pair<COOMatrix, FloatArray>\nCSRLaborSampling<kDGLCUDA, int64_t, float>(\n    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\ntemplate std::pair<COOMatrix, FloatArray>\nCSRLaborSampling<kDGLCUDA, int32_t, double>(\n    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\ntemplate std::pair<COOMatrix, FloatArray>\nCSRLaborSampling<kDGLCUDA, int64_t, double>(\n    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/macro.cuh",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/macro.cuh\n * @brief Macro to call SPMM/SDDMM cuda kernels.\n */\n#ifndef DGL_ARRAY_CUDA_MACRO_CUH_\n#define DGL_ARRAY_CUDA_MACRO_CUH_\n\n///////////////////////// Dispatchers //////////////////////////\n\n/* Macro used for switching between broadcasting and non-broadcasting kernels.\n * It also copies the auxiliary information for calculating broadcasting offsets\n * to GPU.\n */\n#define BCAST_IDX_CTX_SWITCH(BCAST, EDGE_MAP, CTX, LHS_OFF, RHS_OFF, ...)     \\\n  do {                                                                        \\\n    const BcastOff &info = (BCAST);                                           \\\n    if (!info.use_bcast) {                                                    \\\n      constexpr bool UseBcast = false;                                        \\\n      if ((EDGE_MAP)) {                                                       \\\n        constexpr bool UseIdx = true;                                         \\\n        { __VA_ARGS__ }                                                       \\\n      } else {                                                                \\\n        constexpr bool UseIdx = false;                                        \\\n        { __VA_ARGS__ }                                                       \\\n      }                                                                       \\\n    } else {                                                                  \\\n      constexpr bool UseBcast = true;                                         \\\n      const DGLContext ctx = (CTX);                                           \\\n      const auto device = runtime::DeviceAPI::Get(ctx);                       \\\n      (LHS_OFF) = static_cast<int64_t *>(device->AllocWorkspace(              \\\n          ctx, sizeof(int64_t) * info.lhs_offset.size()));                    \\\n      CUDA_CALL(cudaMemcpy(                               
                    \\\n          (LHS_OFF), &info.lhs_offset[0],                                     \\\n          sizeof(int64_t) * info.lhs_offset.size(), cudaMemcpyHostToDevice)); \\\n      (RHS_OFF) = static_cast<int64_t *>(device->AllocWorkspace(              \\\n          ctx, sizeof(int64_t) * info.rhs_offset.size()));                    \\\n      CUDA_CALL(cudaMemcpy(                                                   \\\n          (RHS_OFF), &info.rhs_offset[0],                                     \\\n          sizeof(int64_t) * info.rhs_offset.size(), cudaMemcpyHostToDevice)); \\\n      if ((EDGE_MAP)) {                                                       \\\n        constexpr bool UseIdx = true;                                         \\\n        { __VA_ARGS__ }                                                       \\\n      } else {                                                                \\\n        constexpr bool UseIdx = false;                                        \\\n        { __VA_ARGS__ }                                                       \\\n      }                                                                       \\\n      device->FreeWorkspace(ctx, (LHS_OFF));                                  \\\n      device->FreeWorkspace(ctx, (RHS_OFF));                                  \\\n    }                                                                         \\\n  } while (0)\n\n#endif  // DGL_ARRAY_CUDA_MACRO_CUH_\n"
  },
  {
    "path": "src/array/cuda/negative_sampling.cu",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file array/cuda/negative_sampling.cu\n * @brief rowwise sampling\n */\n\n#include <curand_kernel.h>\n#include <dgl/array.h>\n#include <dgl/array_iterator.h>\n#include <dgl/random.h>\n\n#include <cub/cub.cuh>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\nnamespace {\n\ntemplate <typename IdType>\n__global__ void _GlobalUniformNegativeSamplingKernel(\n    const IdType* __restrict__ indptr, const IdType* __restrict__ indices,\n    IdType* __restrict__ row, IdType* __restrict__ col, int64_t num_row,\n    int64_t num_col, int64_t num_samples, int num_trials,\n    bool exclude_self_loops, int32_t random_seed) {\n  int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n\n  curandStatePhilox4_32_10_t\n      rng;  // this allows generating 4 32-bit ints at a time\n  curand_init(random_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);\n\n  while (tx < num_samples) {\n    for (int i = 0; i < num_trials; ++i) {\n      uint4 result = curand4(&rng);\n      // Turns out that result.x is always 0 with the above RNG.\n      uint64_t y_hi = result.y >> 16;\n      uint64_t y_lo = result.y & 0xFFFF;\n      uint64_t z = static_cast<uint64_t>(result.z);\n      uint64_t w = static_cast<uint64_t>(result.w);\n      int64_t u = static_cast<int64_t>(((y_lo << 32L) | z) % num_row);\n      int64_t v = static_cast<int64_t>(((y_hi << 32L) | w) % num_col);\n\n      if (exclude_self_loops && (u == v)) continue;\n\n      // binary search of v among indptr[u:u+1]\n      int64_t b = indptr[u], e = indptr[u + 1] - 1;\n      bool found = false;\n      while (b <= e) {\n        int64_t m = (b + e) / 2;\n        if (indices[m] == v) {\n          found = true;\n          break;\n        } else if (indices[m] < v) {\n          b = m + 1;\n        } else {\n          e = m - 
1;\n        }\n      }\n\n      if (!found) {\n        row[tx] = u;\n        col[tx] = v;\n        break;\n      }\n    }\n\n    tx += stride_x;\n  }\n}\n\ntemplate <typename DType>\nstruct IsNotMinusOne {\n  __device__ __forceinline__ bool operator()(const std::pair<DType, DType>& a) {\n    return a.first != -1;\n  }\n};\n\n/**\n * @brief Sort ordered pairs in ascending order, using \\a tmp_major and \\a\n * tmp_minor as temporary buffers, each with \\a n elements.\n */\ntemplate <typename IdType>\nvoid SortOrderedPairs(\n    runtime::DeviceAPI* device, DGLContext ctx, IdType* major, IdType* minor,\n    IdType* tmp_major, IdType* tmp_minor, int64_t n, cudaStream_t stream) {\n  // Sort ordered pairs in lexicographical order by two radix sorts since\n  // cub's radix sorts are stable.\n  // We need a 2*n auxiliary storage to store the results from the first radix\n  // sort.\n  size_t s1 = 0, s2 = 0;\n  void* tmp1 = nullptr;\n  void* tmp2 = nullptr;\n\n  // Radix sort by minor key first, reorder the major key in the process.\n  CUDA_CALL(cub::DeviceRadixSort::SortPairs(\n      tmp1, s1, minor, tmp_minor, major, tmp_major, n, 0, sizeof(IdType) * 8,\n      stream));\n  tmp1 = device->AllocWorkspace(ctx, s1);\n  CUDA_CALL(cub::DeviceRadixSort::SortPairs(\n      tmp1, s1, minor, tmp_minor, major, tmp_major, n, 0, sizeof(IdType) * 8,\n      stream));\n\n  // Radix sort by major key next.\n  CUDA_CALL(cub::DeviceRadixSort::SortPairs(\n      tmp2, s2, tmp_major, major, tmp_minor, minor, n, 0, sizeof(IdType) * 8,\n      stream));\n  tmp2 = (s2 > s1) ? 
device->AllocWorkspace(ctx, s2)\n                   : tmp1;  // reuse buffer if s2 <= s1\n  CUDA_CALL(cub::DeviceRadixSort::SortPairs(\n      tmp2, s2, tmp_major, major, tmp_minor, minor, n, 0, sizeof(IdType) * 8,\n      stream));\n\n  if (tmp1 != tmp2) device->FreeWorkspace(ctx, tmp2);\n  device->FreeWorkspace(ctx, tmp1);\n}\n\n};  // namespace\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(\n    const CSRMatrix& csr, int64_t num_samples, int num_trials,\n    bool exclude_self_loops, bool replace, double redundancy) {\n  auto ctx = csr.indptr->ctx;\n  auto dtype = csr.indptr->dtype;\n  const int64_t num_row = csr.num_rows;\n  const int64_t num_col = csr.num_cols;\n  const int64_t num_actual_samples =\n      static_cast<int64_t>(num_samples * (1 + redundancy));\n  IdArray row = Full<IdType>(-1, num_actual_samples, ctx);\n  IdArray col = Full<IdType>(-1, num_actual_samples, ctx);\n  IdArray out_row = IdArray::Empty({num_actual_samples}, dtype, ctx);\n  IdArray out_col = IdArray::Empty({num_actual_samples}, dtype, ctx);\n  IdType* row_data = row.Ptr<IdType>();\n  IdType* col_data = col.Ptr<IdType>();\n  IdType* out_row_data = out_row.Ptr<IdType>();\n  IdType* out_col_data = out_col.Ptr<IdType>();\n  auto device = runtime::DeviceAPI::Get(ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const int nt = cuda::FindNumThreads(num_actual_samples);\n  const int nb = (num_actual_samples + nt - 1) / nt;\n  std::pair<IdArray, IdArray> result;\n  int64_t num_out;\n\n  CUDA_KERNEL_CALL(\n      _GlobalUniformNegativeSamplingKernel, nb, nt, 0, stream,\n      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), row_data, col_data,\n      num_row, num_col, num_actual_samples, num_trials, exclude_self_loops,\n      RandomEngine::ThreadLocal()->RandInt32());\n\n  size_t tmp_size = 0;\n  int64_t* num_out_cuda =\n      static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));\n  
IsNotMinusOne<IdType> op;\n  PairIterator<IdType> begin(row_data, col_data);\n  PairIterator<IdType> out_begin(out_row_data, out_col_data);\n  CUDA_CALL(cub::DeviceSelect::If(\n      nullptr, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op,\n      stream));\n  void* tmp = device->AllocWorkspace(ctx, tmp_size);\n  CUDA_CALL(cub::DeviceSelect::If(\n      tmp, tmp_size, begin, out_begin, num_out_cuda, num_actual_samples, op,\n      stream));\n  num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);\n\n  if (!replace) {\n    IdArray unique_row = IdArray::Empty({num_out}, dtype, ctx);\n    IdArray unique_col = IdArray::Empty({num_out}, dtype, ctx);\n    IdType* unique_row_data = unique_row.Ptr<IdType>();\n    IdType* unique_col_data = unique_col.Ptr<IdType>();\n    PairIterator<IdType> unique_begin(unique_row_data, unique_col_data);\n\n    SortOrderedPairs(\n        device, ctx, out_row_data, out_col_data, unique_row_data,\n        unique_col_data, num_out, stream);\n\n    size_t tmp_size_unique = 0;\n    void* tmp_unique = nullptr;\n    CUDA_CALL(cub::DeviceSelect::Unique(\n        nullptr, tmp_size_unique, out_begin, unique_begin, num_out_cuda,\n        num_out, stream));\n    tmp_unique = (tmp_size_unique > tmp_size)\n                     ? 
device->AllocWorkspace(ctx, tmp_size_unique)\n                     : tmp;  // reuse buffer\n    CUDA_CALL(cub::DeviceSelect::Unique(\n        tmp_unique, tmp_size_unique, out_begin, unique_begin, num_out_cuda,\n        num_out, stream));\n    num_out = cuda::GetCUDAScalar(device, ctx, num_out_cuda);\n\n    num_out = std::min(num_samples, num_out);\n    result = {\n        unique_row.CreateView({num_out}, dtype),\n        unique_col.CreateView({num_out}, dtype)};\n\n    if (tmp_unique != tmp) device->FreeWorkspace(ctx, tmp_unique);\n  } else {\n    num_out = std::min(num_samples, num_out);\n    result = {\n        out_row.CreateView({num_out}, dtype),\n        out_col.CreateView({num_out}, dtype)};\n  }\n\n  device->FreeWorkspace(ctx, tmp);\n  device->FreeWorkspace(ctx, num_out_cuda);\n  return result;\n}\n\ntemplate std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<\n    kDGLCUDA, int32_t>(const CSRMatrix&, int64_t, int, bool, bool, double);\ntemplate std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<\n    kDGLCUDA, int64_t>(const CSRMatrix&, int64_t, int, bool, bool, double);\n\n};  // namespace impl\n};  // namespace aten\n};  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/rowwise_sampling.cu",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file array/cuda/rowwise_sampling.cu\n * @brief uniform rowwise sampling\n */\n\n#include <curand_kernel.h>\n#include <dgl/random.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/tensordispatch.h>\n\n#include <cub/cub.cuh>\n#include <numeric>\n\n#include \"../../array/cuda/atomic.cuh\"\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nusing namespace dgl::cuda;\nusing namespace dgl::aten::cuda;\nusing TensorDispatcher = dgl::runtime::TensorDispatcher;\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\nnamespace {\n\nconstexpr int BLOCK_SIZE = 128;\n\n/**\n * @brief Compute the size of each row in the sampled CSR, without replacement.\n *\n * @tparam IdType The type of node and edge indexes.\n * @param num_picks The number of non-zero entries to pick per row.\n * @param num_rows The number of rows to pick.\n * @param in_rows The set of rows to pick.\n * @param in_ptr The index where each row's edges start.\n * @param out_deg The size of each row in the sampled matrix, as indexed by\n * `in_rows` (output).\n */\ntemplate <typename IdType>\n__global__ void _CSRRowWiseSampleDegreeKernel(\n    const int64_t num_picks, const int64_t num_rows,\n    const IdType* const in_rows, const IdType* const in_ptr,\n    IdType* const out_deg) {\n  const int tIdx = threadIdx.x + blockIdx.x * blockDim.x;\n\n  if (tIdx < num_rows) {\n    const int in_row = in_rows[tIdx];\n    const int out_row = tIdx;\n    out_deg[out_row] = min(\n        static_cast<IdType>(num_picks), in_ptr[in_row + 1] - in_ptr[in_row]);\n\n    if (out_row == num_rows - 1) {\n      // make the prefixsum work\n      out_deg[num_rows] = 0;\n    }\n  }\n}\n\n/**\n * @brief Compute the size of each row in the sampled CSR, with replacement.\n *\n * @tparam IdType The type of node and edge indexes.\n * @param num_picks The number of non-zero entries to pick per row.\n * @param num_rows The number of rows to pick.\n 
* @param in_rows The set of rows to pick.\n * @param in_ptr The index where each row's edges start.\n * @param out_deg The size of each row in the sampled matrix, as indexed by\n * `in_rows` (output).\n */\ntemplate <typename IdType>\n__global__ void _CSRRowWiseSampleDegreeReplaceKernel(\n    const int64_t num_picks, const int64_t num_rows,\n    const IdType* const in_rows, const IdType* const in_ptr,\n    IdType* const out_deg) {\n  const int tIdx = threadIdx.x + blockIdx.x * blockDim.x;\n\n  if (tIdx < num_rows) {\n    const int64_t in_row = in_rows[tIdx];\n    const int64_t out_row = tIdx;\n\n    if (in_ptr[in_row + 1] - in_ptr[in_row] == 0) {\n      out_deg[out_row] = 0;\n    } else {\n      out_deg[out_row] = static_cast<IdType>(num_picks);\n    }\n\n    if (out_row == num_rows - 1) {\n      // make the prefixsum work\n      out_deg[num_rows] = 0;\n    }\n  }\n}\n\n/**\n * @brief Perform row-wise uniform sampling on a CSR matrix,\n * and generate a COO matrix, without replacement.\n *\n * @tparam IdType The ID type used for matrices.\n * @tparam TILE_SIZE The number of rows covered by each threadblock.\n * @param rand_seed The random seed to use.\n * @param num_picks The number of non-zeros to pick per row.\n * @param num_rows The number of rows to pick.\n * @param in_rows The set of rows to pick.\n * @param in_ptr The indptr array of the input CSR.\n * @param in_index The indices array of the input CSR.\n * @param data The data array of the input CSR.\n * @param out_ptr The offset to write each row to in the output COO.\n * @param out_rows The rows of the output COO (output).\n * @param out_cols The columns of the output COO (output).\n * @param out_idxs The data array of the output COO (output).\n */\ntemplate <typename IdType, int TILE_SIZE>\n__global__ void _CSRRowWiseSampleUniformKernel(\n    const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,\n    const IdType* const in_rows, const IdType* const in_ptr,\n    const IdType* const 
in_index, const IdType* const data,\n    const IdType* const out_ptr, IdType* const out_rows, IdType* const out_cols,\n    IdType* const out_idxs) {\n  // we assign one warp per row\n  assert(blockDim.x == BLOCK_SIZE);\n\n  int64_t out_row = blockIdx.x * TILE_SIZE;\n  const int64_t last_row =\n      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);\n\n  curandStatePhilox4_32_10_t rng;\n  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);\n\n  while (out_row < last_row) {\n    const int64_t row = in_rows[out_row];\n    const int64_t in_row_start = in_ptr[row];\n    const int64_t deg = in_ptr[row + 1] - in_row_start;\n    const int64_t out_row_start = out_ptr[out_row];\n\n    if (deg <= num_picks) {\n      // just copy row when there is not enough nodes to sample.\n      for (int idx = threadIdx.x; idx < deg; idx += BLOCK_SIZE) {\n        const IdType in_idx = in_row_start + idx;\n        out_rows[out_row_start + idx] = row;\n        out_cols[out_row_start + idx] = in_index[in_idx];\n        out_idxs[out_row_start + idx] = data ? 
data[in_idx] : in_idx;\n      }\n    } else {\n      // generate permutation list via reservoir algorithm\n      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {\n        out_idxs[out_row_start + idx] = idx;\n      }\n      __syncthreads();\n\n      for (int idx = num_picks + threadIdx.x; idx < deg; idx += BLOCK_SIZE) {\n        const int num = curand(&rng) % (idx + 1);\n        if (num < num_picks) {\n          // use max so as to achieve the replacement order the serial\n          // algorithm would have\n          AtomicMax(out_idxs + out_row_start + num, idx);\n        }\n      }\n      __syncthreads();\n\n      // copy permutation over\n      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {\n        const IdType perm_idx = out_idxs[out_row_start + idx] + in_row_start;\n        out_rows[out_row_start + idx] = row;\n        out_cols[out_row_start + idx] = in_index[perm_idx];\n        out_idxs[out_row_start + idx] = data ? data[perm_idx] : perm_idx;\n      }\n    }\n    out_row += 1;\n  }\n}\n\n/**\n * @brief Perform row-wise uniform sampling on a CSR matrix,\n * and generate a COO matrix, with replacement.\n *\n * @tparam IdType The ID type used for matrices.\n * @tparam TILE_SIZE The number of rows covered by each threadblock.\n * @param rand_seed The random seed to use.\n * @param num_picks The number of non-zeros to pick per row.\n * @param num_rows The number of rows to pick.\n * @param in_rows The set of rows to pick.\n * @param in_ptr The indptr array of the input CSR.\n * @param in_index The indices array of the input CSR.\n * @param data The data array of the input CSR.\n * @param out_ptr The offset to write each row to in the output COO.\n * @param out_rows The rows of the output COO (output).\n * @param out_cols The columns of the output COO (output).\n * @param out_idxs The data array of the output COO (output).\n */\ntemplate <typename IdType, int TILE_SIZE>\n__global__ void _CSRRowWiseSampleUniformReplaceKernel(\n   
 const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,\n    const IdType* const in_rows, const IdType* const in_ptr,\n    const IdType* const in_index, const IdType* const data,\n    const IdType* const out_ptr, IdType* const out_rows, IdType* const out_cols,\n    IdType* const out_idxs) {\n  // we assign one warp per row\n  assert(blockDim.x == BLOCK_SIZE);\n\n  int64_t out_row = blockIdx.x * TILE_SIZE;\n  const int64_t last_row =\n      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);\n\n  curandStatePhilox4_32_10_t rng;\n  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);\n\n  while (out_row < last_row) {\n    const int64_t row = in_rows[out_row];\n    const int64_t in_row_start = in_ptr[row];\n    const int64_t out_row_start = out_ptr[out_row];\n    const int64_t deg = in_ptr[row + 1] - in_row_start;\n\n    if (deg > 0) {\n      // each thread then blindly copies in rows only if deg > 0.\n      for (int idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {\n        const int64_t edge = curand(&rng) % deg;\n        const int64_t out_idx = out_row_start + idx;\n        out_rows[out_idx] = row;\n        out_cols[out_idx] = in_index[in_row_start + edge];\n        out_idxs[out_idx] =\n            data ? 
data[in_row_start + edge] : in_row_start + edge;\n      }\n    }\n    out_row += 1;\n  }\n}\n\n}  // namespace\n\n///////////////////////////// CSR sampling //////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix _CSRRowWiseSamplingUniform(\n    CSRMatrix mat, IdArray rows, const int64_t num_picks, const bool replace) {\n  const auto& ctx = rows->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  const int64_t num_rows = rows->shape[0];\n  const IdType* const slice_rows = static_cast<const IdType*>(rows->data);\n\n  IdArray picked_row =\n      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);\n  IdArray picked_col =\n      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);\n  IdArray picked_idx =\n      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);\n  IdType* const out_rows = static_cast<IdType*>(picked_row->data);\n  IdType* const out_cols = static_cast<IdType*>(picked_col->data);\n  IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);\n\n  const IdType* in_ptr = static_cast<IdType*>(GetDevicePointer(mat.indptr));\n  const IdType* in_cols = static_cast<IdType*>(GetDevicePointer(mat.indices));\n  const IdType* data = CSRHasData(mat)\n                           ? 
static_cast<IdType*>(GetDevicePointer(mat.data))\n                           : nullptr;\n\n  // compute degree\n  IdType* out_deg = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));\n  if (replace) {\n    const dim3 block(512);\n    const dim3 grid((num_rows + block.x - 1) / block.x);\n    CUDA_KERNEL_CALL(\n        _CSRRowWiseSampleDegreeReplaceKernel, grid, block, 0, stream, num_picks,\n        num_rows, slice_rows, in_ptr, out_deg);\n  } else {\n    const dim3 block(512);\n    const dim3 grid((num_rows + block.x - 1) / block.x);\n    CUDA_KERNEL_CALL(\n        _CSRRowWiseSampleDegreeKernel, grid, block, 0, stream, num_picks,\n        num_rows, slice_rows, in_ptr, out_deg);\n  }\n\n  // fill out_ptr\n  IdType* out_ptr = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));\n  size_t prefix_temp_size = 0;\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      nullptr, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));\n  void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      prefix_temp, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));\n  device->FreeWorkspace(ctx, prefix_temp);\n  device->FreeWorkspace(ctx, out_deg);\n\n  cudaEvent_t copyEvent;\n  CUDA_CALL(cudaEventCreate(&copyEvent));\n\n  NDArray new_len_tensor;\n  if (TensorDispatcher::Global()->IsAvailable()) {\n    new_len_tensor = NDArray::PinnedEmpty(\n        {1}, DGLDataTypeTraits<IdType>::dtype, DGLContext{kDGLCPU, 0});\n  } else {\n    // use pageable memory, it will unnecessarily block but be functional\n    new_len_tensor = NDArray::Empty(\n        {1}, DGLDataTypeTraits<IdType>::dtype, DGLContext{kDGLCPU, 0});\n  }\n\n  // copy using the internal current stream\n  CUDA_CALL(cudaMemcpyAsync(\n      new_len_tensor->data, out_ptr + num_rows, sizeof(IdType),\n      cudaMemcpyDeviceToHost, stream));\n  CUDA_CALL(cudaEventRecord(copyEvent, 
stream));\n\n  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);\n\n  // select edges\n  // the number of rows each thread block will cover\n  constexpr int TILE_SIZE = 128 / BLOCK_SIZE;\n  if (replace) {  // with replacement\n    const dim3 block(BLOCK_SIZE);\n    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);\n    CUDA_KERNEL_CALL(\n        (_CSRRowWiseSampleUniformReplaceKernel<IdType, TILE_SIZE>), grid, block,\n        0, stream, random_seed, num_picks, num_rows, slice_rows, in_ptr,\n        in_cols, data, out_ptr, out_rows, out_cols, out_idxs);\n  } else {  // without replacement\n    const dim3 block(BLOCK_SIZE);\n    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);\n    CUDA_KERNEL_CALL(\n        (_CSRRowWiseSampleUniformKernel<IdType, TILE_SIZE>), grid, block, 0,\n        stream, random_seed, num_picks, num_rows, slice_rows, in_ptr, in_cols,\n        data, out_ptr, out_rows, out_cols, out_idxs);\n  }\n  device->FreeWorkspace(ctx, out_ptr);\n\n  // wait for copying `new_len` to finish\n  CUDA_CALL(cudaEventSynchronize(copyEvent));\n  CUDA_CALL(cudaEventDestroy(copyEvent));\n\n  const IdType new_len = static_cast<const IdType*>(new_len_tensor->data)[0];\n  picked_row = picked_row.CreateView({new_len}, picked_row->dtype);\n  picked_col = picked_col.CreateView({new_len}, picked_col->dtype);\n  picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);\n\n  return COOMatrix(\n      mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCOOMatrix CSRRowWiseSamplingUniform(\n    CSRMatrix mat, IdArray rows, const int64_t num_picks, const bool replace) {\n  if (num_picks == -1) {\n    // Basically this is UnitGraph::InEdges().\n    COOMatrix coo = CSRToCOO(CSRSliceRows(mat, rows), false);\n    IdArray sliced_rows = IndexSelect(rows, coo.row);\n    return COOMatrix(\n        mat.num_rows, mat.num_cols, sliced_rows, coo.col, coo.data);\n  } else {\n    
return _CSRRowWiseSamplingUniform<XPU, IdType>(\n        mat, rows, num_picks, replace);\n  }\n}\n\ntemplate COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int32_t>(\n    CSRMatrix, IdArray, int64_t, bool);\ntemplate COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int64_t>(\n    CSRMatrix, IdArray, int64_t, bool);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/rowwise_sampling_prob.cu",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file array/cuda/rowwise_sampling_prob.cu\n * @brief weighted rowwise sampling. The degree computing kernels and\n * host-side functions are partially borrowed from the uniform rowwise\n * sampling code rowwise_sampling.cu.\n * @author pengqirong (OPPO), dlasalle and Xin from Nvidia.\n */\n#include <curand_kernel.h>\n#include <dgl/random.h>\n#include <dgl/runtime/device_api.h>\n\n#include <cub/cub.cuh>\n#include <numeric>\n\n#include \"../../array/cuda/atomic.cuh\"\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\n// require CUB 1.17 to use DeviceSegmentedSort\nstatic_assert(\n    CUB_VERSION >= 101700, \"Require CUB >= 1.17 to use DeviceSegmentedSort\");\n\nnamespace dgl {\nusing namespace cuda;\nusing namespace aten::cuda;\nnamespace aten {\nnamespace impl {\n\nnamespace {\n\nconstexpr int BLOCK_SIZE = 128;\n\n/**\n * @brief Compute the size of each row in the sampled CSR, without replacement.\n * temp_deg is calculated for rows with deg > num_picks.\n * For these rows, we will calculate their A-Res values and sort them to get\n * top-num_picks.\n *\n * @tparam IdType The type of node and edge indexes.\n * @param num_picks The number of non-zero entries to pick per row.\n * @param num_rows The number of rows to pick.\n * @param in_rows The set of rows to pick.\n * @param in_ptr The index where each row's edges start.\n * @param out_deg The size of each row in the sampled matrix, as indexed by\n * `in_rows` (output).\n * @param temp_deg The size of each row in the input matrix, as indexed by\n * `in_rows` (output).\n */\ntemplate <typename IdType>\n__global__ void _CSRRowWiseSampleDegreeKernel(\n    const int64_t num_picks, const int64_t num_rows,\n    const IdType* const in_rows, const IdType* const in_ptr,\n    IdType* const out_deg, IdType* const temp_deg) {\n  const int64_t tIdx = threadIdx.x + blockIdx.x * blockDim.x;\n\n  if (tIdx < num_rows) {\n    const int64_t in_row = 
in_rows[tIdx];\n    const int64_t out_row = tIdx;\n    const IdType deg = in_ptr[in_row + 1] - in_ptr[in_row];\n    // temp_deg is used to generate ares_ptr\n    temp_deg[out_row] = deg > static_cast<IdType>(num_picks) ? deg : 0;\n    out_deg[out_row] = min(static_cast<IdType>(num_picks), deg);\n\n    if (out_row == num_rows - 1) {\n      // make the prefixsum work\n      out_deg[num_rows] = 0;\n      temp_deg[num_rows] = 0;\n    }\n  }\n}\n\n/**\n * @brief Compute the size of each row in the sampled CSR, with replacement.\n * We need the actual in degree of each row to store CDF values.\n *\n * @tparam IdType The type of node and edge indexes.\n * @param num_picks The number of non-zero entries to pick per row.\n * @param num_rows The number of rows to pick.\n * @param in_rows The set of rows to pick.\n * @param in_ptr The index where each row's edges start.\n * @param out_deg The size of each row in the sampled matrix, as indexed by\n * `in_rows` (output).\n * @param temp_deg The size of each row in the input matrix, as indexed by\n * `in_rows` (output).\n */\ntemplate <typename IdType>\n__global__ void _CSRRowWiseSampleDegreeReplaceKernel(\n    const int64_t num_picks, const int64_t num_rows,\n    const IdType* const in_rows, const IdType* const in_ptr,\n    IdType* const out_deg, IdType* const temp_deg) {\n  const int64_t tIdx = threadIdx.x + blockIdx.x * blockDim.x;\n\n  if (tIdx < num_rows) {\n    const int64_t in_row = in_rows[tIdx];\n    const int64_t out_row = tIdx;\n    const IdType deg = in_ptr[in_row + 1] - in_ptr[in_row];\n    temp_deg[out_row] = deg;\n    out_deg[out_row] = deg == 0 ? 
0 : static_cast<IdType>(num_picks);\n\n    if (out_row == num_rows - 1) {\n      // make the prefixsum work\n      out_deg[num_rows] = 0;\n      temp_deg[num_rows] = 0;\n    }\n  }\n}\n\n/**\n * @brief Equivalent to numpy expression: array[idx[off:off + len]]\n *\n * @tparam IdType The ID type used for indices.\n * @tparam FloatType The float type used for array values.\n * @param array The array to be selected.\n * @param idx_data The index mapping array.\n * @param idx The index of the value to be selected.\n * @param offset The offset to start.\n * @param out The selected value (output).\n */\ntemplate <typename IdType, typename FloatType>\n__device__ void _DoubleSlice(\n    const FloatType* const array, const IdType* const idx_data,\n    const IdType idx, const IdType offset, FloatType* const out) {\n  if (idx_data) {\n    *out = array[idx_data[offset + idx]];\n  } else {\n    *out = array[offset + idx];\n  }\n}\n\n/**\n * @brief Compute A-Res value. A-Res value needs to be calculated only if deg\n * is greater than num_picks in weighted rowwise sampling without replacement.\n *\n * @tparam IdType The ID type used for matrices.\n * @tparam FloatType The Float type used for matrices.\n * @tparam TILE_SIZE The number of rows covered by each threadblock.\n * @param rand_seed The random seed to use.\n * @param num_picks The number of non-zeros to pick per row.\n * @param num_rows The number of rows to pick.\n * @param in_rows The set of rows to pick.\n * @param in_ptr The indptr array of the input CSR.\n * @param data The data array of the input CSR.\n * @param prob The probability array of the input CSR.\n * @param ares_ptr The offset to write each row to in the A-res array.\n * @param ares_idxs The A-Res value corresponding index array, the index of\n * input CSR (output).\n * @param ares The A-Res value array (output).\n * @author pengqirong (OPPO)\n */\ntemplate <typename IdType, typename FloatType, int TILE_SIZE>\n__global__ void _CSRAResValueKernel(\n    const 
uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,\n    const IdType* const in_rows, const IdType* const in_ptr,\n    const IdType* const data, const FloatType* const prob,\n    const IdType* const ares_ptr, IdType* const ares_idxs,\n    FloatType* const ares) {\n  int64_t out_row = blockIdx.x * TILE_SIZE;\n  const int64_t last_row =\n      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);\n\n  curandStatePhilox4_32_10_t rng;\n  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);\n\n  while (out_row < last_row) {\n    const int64_t row = in_rows[out_row];\n    const int64_t in_row_start = in_ptr[row];\n    const int64_t deg = in_ptr[row + 1] - in_row_start;\n    // A-Res value needs to be calculated only if deg is greater than num_picks\n    // in weighted rowwise sampling without replacement\n    if (deg > num_picks) {\n      const int64_t ares_row_start = ares_ptr[out_row];\n\n      for (int64_t idx = threadIdx.x; idx < deg; idx += BLOCK_SIZE) {\n        const int64_t in_idx = in_row_start + idx;\n        const int64_t ares_idx = ares_row_start + idx;\n        FloatType item_prob;\n        _DoubleSlice<IdType, FloatType>(\n            prob, data, idx, in_row_start, &item_prob);\n        // compute A-Res value\n        ares[ares_idx] = static_cast<FloatType>(\n            __powf(curand_uniform(&rng), 1.0f / item_prob));\n        ares_idxs[ares_idx] = static_cast<IdType>(in_idx);\n      }\n    }\n    out_row += 1;\n  }\n}\n\n/**\n * @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO\n * matrix, without replacement. 
After sorting, we select top-num_picks items.\n *\n * @tparam IdType The ID type used for matrices.\n * @tparam FloatType The Float type used for matrices.\n * @tparam TILE_SIZE The number of rows covered by each threadblock.\n * @param num_picks The number of non-zeros to pick per row.\n * @param num_rows The number of rows to pick.\n * @param in_rows The set of rows to pick.\n * @param in_ptr The indptr array of the input CSR.\n * @param in_cols The columns array of the input CSR.\n * @param data The data array of the input CSR.\n * @param out_ptr The offset to write each row to in the output COO.\n * @param ares_ptr The offset to write each row to in the ares array.\n * @param sort_ares_idxs The sorted A-Res value corresponding index array, the\n * index of input CSR.\n * @param out_rows The rows of the output COO (output).\n * @param out_cols The columns of the output COO (output).\n * @param out_idxs The data array of the output COO (output).\n * @author pengqirong (OPPO)\n */\ntemplate <typename IdType, typename FloatType, int TILE_SIZE>\n__global__ void _CSRRowWiseSampleKernel(\n    const int64_t num_picks, const int64_t num_rows,\n    const IdType* const in_rows, const IdType* const in_ptr,\n    const IdType* const in_cols, const IdType* const data,\n    const IdType* const out_ptr, const IdType* const ares_ptr,\n    const IdType* const sort_ares_idxs, IdType* const out_rows,\n    IdType* const out_cols, IdType* const out_idxs) {\n  // we assign one warp per row\n  assert(blockDim.x == BLOCK_SIZE);\n\n  int64_t out_row = blockIdx.x * TILE_SIZE;\n  const int64_t last_row =\n      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);\n\n  while (out_row < last_row) {\n    const int64_t row = in_rows[out_row];\n    const int64_t in_row_start = in_ptr[row];\n    const int64_t out_row_start = out_ptr[out_row];\n    const int64_t deg = in_ptr[row + 1] - in_row_start;\n\n    if (deg > num_picks) {\n      const int64_t ares_row_start = 
ares_ptr[out_row];\n      for (int64_t idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {\n        // get in and out index, the in_idx is one of top num_picks A-Res value\n        // corresponding index in input CSR.\n        const int64_t out_idx = out_row_start + idx;\n        const int64_t ares_idx = ares_row_start + idx;\n        const int64_t in_idx = sort_ares_idxs[ares_idx];\n        // copy permutation over\n        out_rows[out_idx] = static_cast<IdType>(row);\n        out_cols[out_idx] = in_cols[in_idx];\n        out_idxs[out_idx] = static_cast<IdType>(data ? data[in_idx] : in_idx);\n      }\n    } else {\n      for (int64_t idx = threadIdx.x; idx < deg; idx += BLOCK_SIZE) {\n        // get in and out index\n        const int64_t out_idx = out_row_start + idx;\n        const int64_t in_idx = in_row_start + idx;\n        // copy permutation over\n        out_rows[out_idx] = static_cast<IdType>(row);\n        out_cols[out_idx] = in_cols[in_idx];\n        out_idxs[out_idx] = static_cast<IdType>(data ? data[in_idx] : in_idx);\n      }\n    }\n    out_row += 1;\n  }\n}\n\n// A stateful callback functor that maintains a running prefix to be applied\n// during consecutive scan operations.\ntemplate <typename FloatType>\nstruct BlockPrefixCallbackOp {\n  // Running prefix\n  FloatType running_total;\n  // Constructor\n  __device__ BlockPrefixCallbackOp(FloatType running_total)\n      : running_total(running_total) {}\n  // Callback operator to be entered by the first warp of threads in the block.\n  // Thread-0 is responsible for returning a value for seeding the block-wide\n  // scan.\n  __device__ FloatType operator()(FloatType block_aggregate) {\n    FloatType old_prefix = running_total;\n    running_total += block_aggregate;\n    return old_prefix;\n  }\n};\n\n/**\n * @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO\n * matrix, with replacement. 
We store the CDF (unnormalized) of all neighbors of\n * a row in global memory and use binary search to find inverse indices as\n * selected items.\n *\n * @tparam IdType The ID type used for matrices.\n * @tparam FloatType The Float type used for matrices.\n * @tparam TILE_SIZE The number of rows covered by each threadblock.\n * @param rand_seed The random seed to use.\n * @param num_picks The number of non-zeros to pick per row.\n * @param num_rows The number of rows to pick.\n * @param in_rows The set of rows to pick.\n * @param in_ptr The indptr array of the input CSR.\n * @param in_cols The columns array of the input CSR.\n * @param data The data array of the input CSR.\n * @param prob The probability array of the input CSR.\n * @param out_ptr The offset to write each row to in the output COO.\n * @param cdf_ptr The offset of each cdf segment.\n * @param cdf The global buffer to store cdf segments.\n * @param out_rows The rows of the output COO (output).\n * @param out_cols The columns of the output COO (output).\n * @param out_idxs The data array of the output COO (output).\n * @author pengqirong (OPPO)\n */\ntemplate <typename IdType, typename FloatType, int TILE_SIZE>\n__global__ void _CSRRowWiseSampleReplaceKernel(\n    const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,\n    const IdType* const in_rows, const IdType* const in_ptr,\n    const IdType* const in_cols, const IdType* const data,\n    const FloatType* const prob, const IdType* const out_ptr,\n    const IdType* const cdf_ptr, FloatType* const cdf, IdType* const out_rows,\n    IdType* const out_cols, IdType* const out_idxs) {\n  // we assign one warp per row\n  assert(blockDim.x == BLOCK_SIZE);\n\n  int64_t out_row = blockIdx.x * TILE_SIZE;\n  const int64_t last_row =\n      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);\n\n  curandStatePhilox4_32_10_t rng;\n  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);\n\n  while (out_row < 
last_row) {\n    const int64_t row = in_rows[out_row];\n    const int64_t in_row_start = in_ptr[row];\n    const int64_t out_row_start = out_ptr[out_row];\n    const int64_t cdf_row_start = cdf_ptr[out_row];\n    const int64_t deg = in_ptr[row + 1] - in_row_start;\n    const FloatType MIN_THREAD_DATA = static_cast<FloatType>(0.0f);\n\n    if (deg > 0) {\n      // Specialize BlockScan for a 1D block of BLOCK_SIZE threads\n      typedef cub::BlockScan<FloatType, BLOCK_SIZE> BlockScan;\n      // Allocate shared memory for BlockScan\n      __shared__ typename BlockScan::TempStorage temp_storage;\n      // Initialize running total\n      BlockPrefixCallbackOp<FloatType> prefix_op(MIN_THREAD_DATA);\n\n      int64_t max_iter = (1 + (deg - 1) / BLOCK_SIZE) * BLOCK_SIZE;\n      // Have the block iterate over segments of items\n      for (int64_t idx = threadIdx.x; idx < max_iter; idx += BLOCK_SIZE) {\n        // Load a segment of consecutive items that are blocked across threads\n        FloatType thread_data;\n        if (idx < deg)\n          _DoubleSlice<IdType, FloatType>(\n              prob, data, idx, in_row_start, &thread_data);\n        else\n          thread_data = MIN_THREAD_DATA;\n        thread_data = max(thread_data, MIN_THREAD_DATA);\n        // Collectively compute the block-wide inclusive prefix sum\n        BlockScan(temp_storage)\n            .InclusiveSum(thread_data, thread_data, prefix_op);\n        __syncthreads();\n\n        // Store scanned items to cdf array\n        if (idx < deg) {\n          cdf[cdf_row_start + idx] = thread_data;\n        }\n      }\n      __syncthreads();\n\n      for (int64_t idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {\n        // get random value\n        FloatType sum = cdf[cdf_row_start + deg - 1];\n        FloatType rand = static_cast<FloatType>(curand_uniform(&rng) * sum);\n        // get the offset of the first value within cdf array which is greater\n        // than random value.\n        int64_t item = 
cub::UpperBound<FloatType*, int64_t, FloatType>(\n            &cdf[cdf_row_start], deg, rand);\n        item = min(item, deg - 1);\n        // get in and out index\n        const int64_t in_idx = in_row_start + item;\n        const int64_t out_idx = out_row_start + idx;\n        // copy permutation over\n        out_rows[out_idx] = static_cast<IdType>(row);\n        out_cols[out_idx] = in_cols[in_idx];\n        out_idxs[out_idx] = static_cast<IdType>(data ? data[in_idx] : in_idx);\n      }\n    }\n    out_row += 1;\n  }\n}\n\ntemplate <typename IdType, typename DType, typename BoolType>\n__global__ void _GenerateFlagsKernel(\n    int64_t n, const IdType* idx, const DType* values, DType criteria,\n    BoolType* output) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n  while (tx < n) {\n    output[tx] = (values[idx ? idx[tx] : tx] != criteria);\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType, typename DType, typename MaskGen>\nCOOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {\n  using namespace dgl::cuda;\n\n  const auto idtype = coo.row->dtype;\n  const auto ctx = coo.row->ctx;\n  const int64_t nnz = coo.row->shape[0];\n  const IdType* row = coo.row.Ptr<IdType>();\n  const IdType* col = coo.col.Ptr<IdType>();\n  const IdArray& eid =\n      COOHasData(coo) ? 
coo.data : Range(0, nnz, sizeof(IdType) * 8, ctx);\n  const IdType* data = coo.data.Ptr<IdType>();\n  IdArray new_row = IdArray::Empty({nnz}, idtype, ctx);\n  IdArray new_col = IdArray::Empty({nnz}, idtype, ctx);\n  IdArray new_eid = IdArray::Empty({nnz}, idtype, ctx);\n  IdType* new_row_data = new_row.Ptr<IdType>();\n  IdType* new_col_data = new_col.Ptr<IdType>();\n  IdType* new_eid_data = new_eid.Ptr<IdType>();\n  auto stream = runtime::getCurrentCUDAStream();\n  auto device = runtime::DeviceAPI::Get(ctx);\n\n  int8_t* flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz));\n  int nt = dgl::cuda::FindNumThreads(nnz);\n  int64_t nb = (nnz + nt - 1) / nt;\n\n  maskgen(nb, nt, stream, nnz, data, flags);\n\n  int64_t* rst =\n      static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));\n  MaskSelect(device, ctx, row, flags, new_row_data, nnz, rst, stream);\n  MaskSelect(device, ctx, col, flags, new_col_data, nnz, rst, stream);\n  MaskSelect(device, ctx, data, flags, new_eid_data, nnz, rst, stream);\n\n  int64_t new_len = GetCUDAScalar(device, ctx, rst);\n\n  device->FreeWorkspace(ctx, flags);\n  device->FreeWorkspace(ctx, rst);\n  return COOMatrix(\n      coo.num_rows, coo.num_cols, new_row.CreateView({new_len}, idtype, 0),\n      new_col.CreateView({new_len}, idtype, 0),\n      new_eid.CreateView({new_len}, idtype, 0));\n}\n\ntemplate <DGLDeviceType XPU, typename IdType, typename DType>\nCOOMatrix _COORemoveIf(\n    const COOMatrix& coo, const NDArray& values, DType criteria) {\n  const DType* val = values.Ptr<DType>();\n  auto maskgen = [val, criteria](\n                     int nb, int nt, cudaStream_t stream, int64_t nnz,\n                     const IdType* data, int8_t* flags) {\n    CUDA_KERNEL_CALL(\n        (_GenerateFlagsKernel<IdType, DType, int8_t>), nb, nt, 0, stream, nnz,\n        data, val, criteria, flags);\n  };\n  return COOGeneralRemoveIf<XPU, IdType, DType, decltype(maskgen)>(\n      coo, maskgen);\n}\n\n}  // 
namespace\n\n/////////////////////////////// CSR ///////////////////////////////\n\n/**\n * @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO\n * matrix. Use CDF sampling algorithm for with replacement:\n *   1) Calculate the CDF of all neighbor's prob.\n *   2) For each [0, num_picks), generate a rand ~ U(0, 1). Use binary search to\n *      find its index in the CDF array as a chosen item.\n * Use A-Res sampling algorithm for without replacement:\n *   1) For rows with deg > num_picks, calculate A-Res values for all neighbors.\n *   2) Sort the A-Res array and select top-num_picks as chosen items.\n *\n * @tparam XPU The device type used for matrices.\n * @tparam IdType The ID type used for matrices.\n * @tparam FloatType The Float type used for matrices.\n * @param mat The CSR matrix.\n * @param rows The set of rows to pick.\n * @param num_picks The number of non-zeros to pick per row.\n * @param prob The probability array of the input CSR.\n * @param replace Is replacement sampling?\n * @author pengqirong (OPPO), dlasalle and Xin from Nvidia.\n */\ntemplate <DGLDeviceType XPU, typename IdType, typename FloatType>\nCOOMatrix _CSRRowWiseSampling(\n    const CSRMatrix& mat, const IdArray& rows, int64_t num_picks,\n    const FloatArray& prob, bool replace) {\n  const auto& ctx = rows->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  const int64_t num_rows = rows->shape[0];\n  const IdType* const slice_rows = static_cast<const IdType*>(rows->data);\n\n  IdArray picked_row =\n      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);\n  IdArray picked_col =\n      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);\n  IdArray picked_idx =\n      NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);\n  IdType* const out_rows = static_cast<IdType*>(picked_row->data);\n  IdType* const out_cols = static_cast<IdType*>(picked_col->data);\n  IdType* const out_idxs 
= static_cast<IdType*>(picked_idx->data);\n\n  const IdType* in_ptr = static_cast<IdType*>(GetDevicePointer(mat.indptr));\n  const IdType* in_cols = static_cast<IdType*>(GetDevicePointer(mat.indices));\n  const IdType* data = CSRHasData(mat)\n                           ? static_cast<IdType*>(GetDevicePointer(mat.data))\n                           : nullptr;\n  const FloatType* prob_data = static_cast<FloatType*>(GetDevicePointer(prob));\n\n  // compute degree\n  // out_deg: the size of each row in the sampled matrix\n  // temp_deg: the size of each row we will manipulate in sampling\n  //    1) for w/o replacement: in degree if it's greater than num_picks else 0\n  //    2) for w/ replacement: in degree\n  IdType* out_deg = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));\n  IdType* temp_deg = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));\n  if (replace) {\n    const dim3 block(512);\n    const dim3 grid((num_rows + block.x - 1) / block.x);\n    CUDA_KERNEL_CALL(\n        _CSRRowWiseSampleDegreeReplaceKernel, grid, block, 0, stream, num_picks,\n        num_rows, slice_rows, in_ptr, out_deg, temp_deg);\n  } else {\n    const dim3 block(512);\n    const dim3 grid((num_rows + block.x - 1) / block.x);\n    CUDA_KERNEL_CALL(\n        _CSRRowWiseSampleDegreeKernel, grid, block, 0, stream, num_picks,\n        num_rows, slice_rows, in_ptr, out_deg, temp_deg);\n  }\n\n  // fill temp_ptr\n  IdType* temp_ptr = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));\n  size_t prefix_temp_size = 0;\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      nullptr, prefix_temp_size, temp_deg, temp_ptr, num_rows + 1, stream));\n  void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      prefix_temp, prefix_temp_size, temp_deg, temp_ptr, num_rows + 1, stream));\n  device->FreeWorkspace(ctx, prefix_temp);\n 
 device->FreeWorkspace(ctx, temp_deg);\n\n  // TODO(Xin): The copy here is too small, and the overhead of creating\n  // cuda events cannot be ignored. Just use synchronized copy.\n  IdType temp_len;\n  // copy using the internal current stream.\n  device->CopyDataFromTo(\n      temp_ptr, num_rows * sizeof(temp_len), &temp_len, 0, sizeof(temp_len),\n      ctx, DGLContext{kDGLCPU, 0}, mat.indptr->dtype);\n  device->StreamSync(ctx, stream);\n\n  // fill out_ptr\n  IdType* out_ptr = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));\n  prefix_temp_size = 0;\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      nullptr, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));\n  prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      prefix_temp, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));\n  device->FreeWorkspace(ctx, prefix_temp);\n  device->FreeWorkspace(ctx, out_deg);\n\n  cudaEvent_t copyEvent;\n  CUDA_CALL(cudaEventCreate(&copyEvent));\n  // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and\n  // wait on a cudaevent\n  IdType new_len;\n  // copy using the internal current stream.\n  device->CopyDataFromTo(\n      out_ptr, num_rows * sizeof(new_len), &new_len, 0, sizeof(new_len), ctx,\n      DGLContext{kDGLCPU, 0}, mat.indptr->dtype);\n  CUDA_CALL(cudaEventRecord(copyEvent, stream));\n\n  // allocate workspace\n  // 1) for w/ replacement, it's a global buffer to store cdf segments (one\n  // segment for each row).\n  // 2) for w/o replacement, it's used to store a-res segments (one segment for\n  // each row with degree > num_picks)\n  FloatType* temp = static_cast<FloatType*>(\n      device->AllocWorkspace(ctx, temp_len * sizeof(FloatType)));\n\n  const uint64_t rand_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);\n\n  // select edges\n  // the number of rows each thread block will cover\n  constexpr int TILE_SIZE = 
128 / BLOCK_SIZE;\n  if (replace) {  // with replacement.\n    const dim3 block(BLOCK_SIZE);\n    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);\n    CUDA_KERNEL_CALL(\n        (_CSRRowWiseSampleReplaceKernel<IdType, FloatType, TILE_SIZE>), grid,\n        block, 0, stream, rand_seed, num_picks, num_rows, slice_rows, in_ptr,\n        in_cols, data, prob_data, out_ptr, temp_ptr, temp, out_rows, out_cols,\n        out_idxs);\n    device->FreeWorkspace(ctx, temp);\n  } else {  // without replacement\n    IdType* temp_idxs = static_cast<IdType*>(\n        device->AllocWorkspace(ctx, (temp_len) * sizeof(IdType)));\n\n    // Compute A-Res value. A-Res value needs to be calculated only if deg\n    // is greater than num_picks in weighted rowwise sampling without\n    // replacement.\n    const dim3 block(BLOCK_SIZE);\n    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);\n    CUDA_KERNEL_CALL(\n        (_CSRAResValueKernel<IdType, FloatType, TILE_SIZE>), grid, block, 0,\n        stream, rand_seed, num_picks, num_rows, slice_rows, in_ptr, data,\n        prob_data, temp_ptr, temp_idxs, temp);\n\n    // sort A-Res value array.\n    FloatType* sort_temp = static_cast<FloatType*>(\n        device->AllocWorkspace(ctx, temp_len * sizeof(FloatType)));\n    IdType* sort_temp_idxs = static_cast<IdType*>(\n        device->AllocWorkspace(ctx, temp_len * sizeof(IdType)));\n\n    cub::DoubleBuffer<FloatType> sort_keys(temp, sort_temp);\n    cub::DoubleBuffer<IdType> sort_values(temp_idxs, sort_temp_idxs);\n\n    void* d_temp_storage = nullptr;\n    size_t temp_storage_bytes = 0;\n    CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending(\n        d_temp_storage, temp_storage_bytes, sort_keys, sort_values, temp_len,\n        num_rows, temp_ptr, temp_ptr + 1, stream));\n    d_temp_storage = device->AllocWorkspace(ctx, temp_storage_bytes);\n    CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending(\n        d_temp_storage, temp_storage_bytes, sort_keys, sort_values, 
temp_len,\n        num_rows, temp_ptr, temp_ptr + 1, stream));\n    device->FreeWorkspace(ctx, d_temp_storage);\n    device->FreeWorkspace(ctx, temp);\n    device->FreeWorkspace(ctx, temp_idxs);\n    device->FreeWorkspace(ctx, sort_temp);\n    device->FreeWorkspace(ctx, sort_temp_idxs);\n\n    // select top-num_picks as results\n    CUDA_KERNEL_CALL(\n        (_CSRRowWiseSampleKernel<IdType, FloatType, TILE_SIZE>), grid, block, 0,\n        stream, num_picks, num_rows, slice_rows, in_ptr, in_cols, data, out_ptr,\n        temp_ptr, sort_values.Current(), out_rows, out_cols, out_idxs);\n  }\n\n  device->FreeWorkspace(ctx, temp_ptr);\n  device->FreeWorkspace(ctx, out_ptr);\n\n  // wait for copying `new_len` to finish\n  CUDA_CALL(cudaEventSynchronize(copyEvent));\n  CUDA_CALL(cudaEventDestroy(copyEvent));\n\n  picked_row = picked_row.CreateView({new_len}, picked_row->dtype);\n  picked_col = picked_col.CreateView({new_len}, picked_col->dtype);\n  picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);\n\n  return COOMatrix(\n      mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);\n}\n\ntemplate <DGLDeviceType XPU, typename IdType, typename DType>\nCOOMatrix CSRRowWiseSampling(\n    CSRMatrix mat, IdArray rows, int64_t num_picks, FloatArray prob,\n    bool replace) {\n  COOMatrix result;\n  if (num_picks == -1) {\n    // Basically this is UnitGraph::InEdges().\n    COOMatrix coo = CSRToCOO(CSRSliceRows(mat, rows), false);\n    IdArray sliced_rows = IndexSelect(rows, coo.row);\n    result =\n        COOMatrix(mat.num_rows, mat.num_cols, sliced_rows, coo.col, coo.data);\n  } else {\n    result = _CSRRowWiseSampling<XPU, IdType, DType>(\n        mat, rows, num_picks, prob, replace);\n  }\n  // NOTE(BarclayII): I'm removing the entries with zero probability after\n  // sampling. 
Is there a better way?\n  return _COORemoveIf<XPU, IdType, DType>(result, prob, static_cast<DType>(0));\n}\n\ntemplate COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, float>(\n    CSRMatrix, IdArray, int64_t, FloatArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, float>(\n    CSRMatrix, IdArray, int64_t, FloatArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, double>(\n    CSRMatrix, IdArray, int64_t, FloatArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, double>(\n    CSRMatrix, IdArray, int64_t, FloatArray, bool);\n// These are not being called, but we instantiate them anyway to prevent missing\n// symbols in Debug build\ntemplate COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, int8_t>(\n    CSRMatrix, IdArray, int64_t, FloatArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, int8_t>(\n    CSRMatrix, IdArray, int64_t, FloatArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, uint8_t>(\n    CSRMatrix, IdArray, int64_t, FloatArray, bool);\ntemplate COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, uint8_t>(\n    CSRMatrix, IdArray, int64_t, FloatArray, bool);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/sddmm.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/sddmm.cu\n * @brief SDDMM C APIs and definitions.\n */\n#include <dgl/array.h>\n\n#include \"./functor.cuh\"\n#include \"./sddmm.cuh\"\n\nnamespace dgl {\nnamespace aten {\n\n/**\n * @brief CUDA implementation of g-SDDMM on Csr format.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCsr(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {\n  SWITCH_OP(op, Op, {\n    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {\n      cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(\n          bcast, csr, lhs, rhs, out);\n    });\n  });\n}\n\n/**\n * @brief CUDA implementation of g-SDDMM on Coo format.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCoo(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {\n  SWITCH_OP(op, Op, {\n    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {\n      cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(\n          bcast, coo, lhs, rhs, out);\n    });\n  });\n}\n\ntemplate void SDDMMCsr<kDGLCUDA, int32_t, __half>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCsr<kDGLCUDA, int64_t, __half>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\n#if BF16_ENABLED\ntemplate void SDDMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const std::string& op, const BcastOff& bcast, 
const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\n#endif  // BF16_ENABLED\ntemplate void SDDMMCsr<kDGLCUDA, int32_t, float>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCsr<kDGLCUDA, int64_t, float>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCsr<kDGLCUDA, int32_t, double>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCsr<kDGLCUDA, int64_t, double>(\n    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\n\ntemplate void SDDMMCoo<kDGLCUDA, int32_t, __half>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCoo<kDGLCUDA, int64_t, __half>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\n#if BF16_ENABLED\ntemplate void SDDMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\n#endif  // BF16_ENABLED\ntemplate void SDDMMCoo<kDGLCUDA, int32_t, float>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int 
rhs_target);\ntemplate void SDDMMCoo<kDGLCUDA, int64_t, float>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCoo<kDGLCUDA, int32_t, double>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\ntemplate void SDDMMCoo<kDGLCUDA, int64_t, double>(\n    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/sddmm.cuh",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/sddmm.cuh\n * @brief SDDMM CUDA kernel function header.\n */\n#ifndef DGL_ARRAY_CUDA_SDDMM_CUH_\n#define DGL_ARRAY_CUDA_SDDMM_CUH_\n\n#include <dgl/bcast.h>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"../selector.h\"\n#include \"./functor.cuh\"\n#include \"./utils.h\"\n#include \"atomic.cuh\"\n#include \"bf16.cuh\"\n#include \"fp16.cuh\"\n#include \"functor.cuh\"\n#include \"macro.cuh\"\n\nnamespace dgl {\n\nusing namespace cuda;\n\nnamespace aten {\nnamespace cuda {\n\n#define SWITCH_OP(op, Op, ...)                                        \\\n  do {                                                                \\\n    if ((op) == \"add\") {                                              \\\n      typedef cuda::binary::Add<DType> Op;                            \\\n      { __VA_ARGS__ }                                                 \\\n    } else if ((op) == \"sub\") {                                       \\\n      typedef cuda::binary::Sub<DType> Op;                            \\\n      { __VA_ARGS__ }                                                 \\\n    } else if ((op) == \"mul\") {                                       \\\n      typedef cuda::binary::Mul<DType> Op;                            \\\n      { __VA_ARGS__ }                                                 \\\n    } else if ((op) == \"div\") {                                       \\\n      typedef cuda::binary::Div<DType> Op;                            \\\n      { __VA_ARGS__ }                                                 \\\n    } else if ((op) == \"copy_lhs\") {                                  \\\n      typedef cuda::binary::CopyLhs<DType> Op;                        \\\n      { __VA_ARGS__ }                                                 \\\n    } else if ((op) == \"copy_rhs\") {                                  \\\n      typedef cuda::binary::CopyRhs<DType> Op;                        \\\n      { 
__VA_ARGS__ }                                                 \\\n    } else if ((op) == \"dot\") {                                       \\\n      typedef cuda::binary::Dot<DType> Op;                            \\\n      { __VA_ARGS__ }                                                 \\\n    } else {                                                          \\\n      LOG(FATAL) << \"Unsupported SpMM/SDDMM binary operator: \" << op; \\\n    }                                                                 \\\n  } while (0)\n\n#define SWITCH_RHS(rhs_target, RhsTarget, ...)             \\\n  do {                                                     \\\n    if ((rhs_target) == 0) {                               \\\n      constexpr int RhsTarget = 0;                         \\\n      { __VA_ARGS__ }                                      \\\n    } else if ((rhs_target) == 1) {                        \\\n      constexpr int RhsTarget = 1;                         \\\n      { __VA_ARGS__ }                                      \\\n    } else if ((rhs_target) == 2) {                        \\\n      constexpr int RhsTarget = 2;                         \\\n      { __VA_ARGS__ }                                      \\\n    } else {                                               \\\n      LOG(INFO) << \"Invalid rhs target: \" << (rhs_target); \\\n    }                                                      \\\n  } while (0)\n\n#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) 
\\\n  do {                                                                   \\\n    if ((lhs_target) == 0) {                                             \\\n      constexpr int LhsTarget = 0;                                       \\\n      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \\\n    } else if ((lhs_target) == 1) {                                      \\\n      constexpr int LhsTarget = 1;                                       \\\n      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \\\n    } else if ((lhs_target) == 2) {                                      \\\n      constexpr int LhsTarget = 2;                                       \\\n      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \\\n    } else {                                                             \\\n      LOG(INFO) << \"Invalid lhs target: \" << (lhs_target);               \\\n    }                                                                    \\\n  } while (0)\n\nconstexpr unsigned int full_mask = 0xffffffff;\n\n/**\n * @brief CUDA kernel of g-SDDMM on Coo format.\n * @note it uses edge parallel strategy, different threadblocks (on y-axis)\n *       is responsible for the computation on different edges. 
Threadblocks\n *       on the x-axis are responsible for the computation on different\n * positions in feature dimension.\n */\ntemplate <\n    typename Idx, typename DType, typename BinaryOp, bool UseBcast = false,\n    bool UseIdx = false, int LhsTarget = 0, int RhsTarget = 2>\n__global__ void SDDMMCooKernel(\n    const DType* __restrict__ lhs, const DType* __restrict__ rhs,\n    DType* __restrict__ out, const Idx* __restrict__ row,\n    const Idx* __restrict__ col, const Idx* __restrict__ edge_map, int64_t N,\n    int64_t M, int64_t E, int64_t reduce_size,\n    const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,\n    int64_t lhs_len, int64_t rhs_len, int64_t out_len) {\n  // SDDMM with COO.\n  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;\n  const Idx stride_y = blockDim.y * gridDim.y;\n  while (ty < E) {\n    const Idx src = _ldg(row + ty);\n    const Idx dst = _ldg(col + ty);\n    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;\n    const DType* lhsoff =\n        BinaryOp::use_lhs\n            ? (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len)\n            : nullptr;\n    const DType* rhsoff =\n        BinaryOp::use_rhs\n            ? (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len)\n            : nullptr;\n    DType* outoff = out + eid * out_len;\n    int tx = blockIdx.x * blockDim.x + threadIdx.x;\n    const int stride_x = blockDim.x * gridDim.x;\n    while (tx < out_len) {\n      const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;\n      const Idx rhs_add = UseBcast ? 
rhs_off[tx] : tx;\n      DType val = BinaryOp::Call(\n          lhsoff + lhs_add * reduce_size, rhsoff + rhs_add * reduce_size,\n          reduce_size);\n      outoff[tx] = val;\n      tx += stride_x;\n    }\n    ty += stride_y;\n  }\n}\n\n/**\n * @brief CUDA kernel of SDDMM-dot on Coo format, accelerated with tree\n * reduction.\n * @note it uses edge parallel strategy, different threadblocks (on y-axis)\n *       is responsible for the computation on different edges. Threadblocks\n *       on the x-axis are responsible for the computation on different\n * positions in feature dimension.\n */\ntemplate <\n    typename Idx, typename DType, bool UseBcast = false, bool UseIdx = false,\n    int LhsTarget = 0, int RhsTarget = 2>\n__global__ void SDDMMCooTreeReduceKernel(\n    const DType* __restrict__ lhs, const DType* __restrict__ rhs,\n    DType* __restrict__ out, const Idx* __restrict__ row,\n    const Idx* __restrict__ col, const Idx* __restrict__ edge_map, int64_t N,\n    int64_t M, int64_t E, int64_t reduce_size,\n    const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,\n    int64_t lhs_len, int64_t rhs_len, int64_t out_len) {\n  Idx ty = blockIdx.x * blockDim.y + threadIdx.y;\n  if (ty < E) {\n    const Idx src = _ldg(row + ty);\n    const Idx dst = _ldg(col + ty);\n    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;\n    const DType* lhsoff =\n        lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len;\n    const DType* rhsoff =\n        rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len;\n    DType* outoff = out + eid * out_len;\n    int tx = threadIdx.x;  // tx < 32\n    for (int i = blockIdx.y; i < out_len;\n         i += gridDim.y) {  // over output feature dimension\n      const Idx lhs_add = UseBcast ? __ldg(lhs_off + i) : i;\n      const Idx rhs_add = UseBcast ? 
__ldg(rhs_off + i) : i;\n      DType val = reduce::Sum<Idx, DType>::zero();\n      for (int j = tx; j < reduce_size; j += 64) {\n        val += lhsoff[lhs_add * reduce_size + j] *\n               rhsoff[rhs_add * reduce_size + j];\n        if (j + 32 < reduce_size)\n          val += lhsoff[lhs_add * reduce_size + j + 32] *\n                 rhsoff[rhs_add * reduce_size + j + 32];\n      }\n#pragma unroll\n      for (int offset = 16; offset > 0; offset /= 2)\n        val += __shfl_down_sync(full_mask, val, offset);\n      if (tx == 0) outoff[i] = val;\n    }\n  }\n}\n\n// Binary search the row_offsets to find the source node of the edge id.\ntemplate <typename Idx>\n__device__ __forceinline__ Idx\nBinarySearchSrc(const Idx* array, Idx length, Idx eid) {\n  Idx lo = 0, hi = length - 1;\n  while (lo < hi) {\n    Idx mid = (lo + hi) >> 1;\n    if (_ldg(array + mid) <= eid) {\n      lo = mid + 1;\n    } else {\n      hi = mid;\n    }\n  }\n  // INVARIANT: lo == hi\n  if (_ldg(array + hi) == eid) {\n    return hi;\n  } else {\n    return hi - 1;\n  }\n}\n\n/**\n * @brief CUDA kernel of g-SDDMM on Csr format.\n * @note it uses edge parallel strategy, different threadblocks (on y-axis)\n *       is responsible for the computation on different edges. Threadblocks\n *       on the x-axis are responsible for the computation on different\n * positions in feature dimension. 
To efficiently find the source node idx and\n * destination node index of an given edge on Csr format, it uses binary search\n * (time complexity O(log N)).\n */\ntemplate <\n    typename Idx, typename DType, typename BinaryOp, bool UseBcast = false,\n    bool UseIdx = false, int LhsTarget = 0, int RhsTarget = 2>\n__global__ void SDDMMCsrKernel(\n    const DType* __restrict__ lhs, const DType* __restrict__ rhs,\n    DType* __restrict__ out, const Idx* __restrict__ indptr,\n    const Idx* __restrict__ indices, const Idx* __restrict__ edge_map,\n    int64_t N, int64_t M, int64_t E, int64_t reduce_size,\n    const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,\n    int64_t lhs_len, int64_t rhs_len, int64_t out_len) {\n  // SDDMM with Csr.\n  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;\n  const Idx stride_y = blockDim.y * gridDim.y;\n  while (ty < E) {\n    const Idx src = BinarySearchSrc<Idx>(indptr, N + 1, ty);\n    const Idx dst = _ldg(indices + ty);\n    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;\n    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;\n    const int64_t stride_x = blockDim.x * gridDim.x;\n    const DType* lhsoff =\n        BinaryOp::use_lhs\n            ? (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len)\n            : nullptr;\n    const DType* rhsoff =\n        BinaryOp::use_rhs\n            ? (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len)\n            : nullptr;\n    DType* outoff = out + eid * out_len;\n    while (tx < out_len) {\n      const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;\n      const Idx rhs_add = UseBcast ? 
rhs_off[tx] : tx;\n      DType val = BinaryOp::Call(\n          lhsoff + lhs_add * reduce_size, rhsoff + rhs_add * reduce_size,\n          reduce_size);\n      outoff[tx] = val;\n      tx += stride_x;\n    }\n    ty += stride_y;\n  }\n}\n\n/**\n * @brief CUDA implementation of g-SDDMM on Coo format.\n * @param bcast Broadcast information.\n * @param coo The Coo matrix.\n * @param lhs The left hand side operand feature.\n * @param rhs The right hand size operand feature.\n * @param out The result feature on edges.\n */\ntemplate <\n    typename Idx, typename DType, typename Op, int LhsTarget = 0,\n    int RhsTarget = 2>\nvoid SDDMMCoo(\n    const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,\n    NDArray out) {\n  const Idx* row = coo.row.Ptr<Idx>();\n  const Idx* col = coo.col.Ptr<Idx>();\n  const Idx* edge_map = coo.data.Ptr<Idx>();\n  const DType* lhs_data = lhs.Ptr<DType>();\n  const DType* rhs_data = rhs.Ptr<DType>();\n  DType* out_data = out.Ptr<DType>();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  int64_t *lhs_off = nullptr, *rhs_off = nullptr;\n  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;\n  int64_t reduce_dim = bcast.reduce_size;\n\n  const int64_t nnz = coo.row->shape[0];\n  const bool use_idx = !IsNullArray(coo.data);\n\n  if (std::is_same<Op, binary::Dot<DType> >::value && reduce_dim >= 32) {\n    const int ntx = 32;  // on feature dimension\n    const int nty = 8;   // on out dimension\n    const int nbx = (nnz + nty - 1) / nty;\n    const int nby = FindNumBlocks<'y'>(len);\n    const dim3 nblks(nbx, nby);\n    const dim3 nthrs(ntx, nty);\n    BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {\n      CUDA_KERNEL_CALL(\n          (SDDMMCooTreeReduceKernel<\n              Idx, DType, UseBcast, UseIdx, LhsTarget, RhsTarget>),\n          nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, row, col,\n          edge_map, coo.num_rows, coo.num_cols, nnz, reduce_dim, 
lhs_off,\n          rhs_off, lhs_len, rhs_len, len);\n    });\n  } else {\n    const int ntx = FindNumThreads(len);\n    const int nty = CUDA_MAX_NUM_THREADS / ntx;\n    const int nbx = (len + ntx - 1) / ntx;\n    const int nby = FindNumBlocks<'y'>((nnz + nty - 1) / nty);\n    const dim3 nblks(nbx, nby);\n    const dim3 nthrs(ntx, nty);\n    BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {\n      CUDA_KERNEL_CALL(\n          (SDDMMCooKernel<\n              Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),\n          nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, row, col,\n          edge_map, coo.num_rows, coo.num_cols, nnz, reduce_dim, lhs_off,\n          rhs_off, lhs_len, rhs_len, len);\n    });\n  }\n}\n\n/**\n * @brief CUDA implementation of g-SDDMM on Csr format.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param lhs The left hand side operand feature.\n * @param rhs The right hand size operand feature.\n * @param out The result feature on edges.\n */\ntemplate <\n    typename Idx, typename DType, typename Op, int LhsTarget = 0,\n    int RhsTarget = 2>\nvoid SDDMMCsr(\n    const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs,\n    NDArray out) {\n  const Idx* indptr = csr.indptr.Ptr<Idx>();\n  const Idx* indices = csr.indices.Ptr<Idx>();\n  const Idx* edge_map = csr.data.Ptr<Idx>();\n  const DType* lhs_data = lhs.Ptr<DType>();\n  const DType* rhs_data = rhs.Ptr<DType>();\n  DType* out_data = out.Ptr<DType>();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];\n\n  int64_t *lhs_off = nullptr, *rhs_off = nullptr;\n  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;\n  int64_t reduce_dim = bcast.reduce_size;\n\n  const int ntx = FindNumThreads(len);\n  const int nty = CUDA_MAX_NUM_THREADS / ntx;\n  const int nbx = (len + ntx - 1) / ntx;\n  const int nby = 
FindNumBlocks<'y'>((E + nty - 1) / nty);\n  const dim3 nblks(nbx, nby);\n  const dim3 nthrs(ntx, nty);\n  const bool use_idx = !IsNullArray(csr.data);\n\n  BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {\n    CUDA_KERNEL_CALL(\n        (SDDMMCsrKernel<\n            Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),\n        nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, indptr, indices,\n        edge_map, N, M, E, reduce_dim, lhs_off, rhs_off, lhs_len, rhs_len, len);\n  });\n}\n\n}  // namespace cuda\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CUDA_SDDMM_CUH_\n"
  },
  {
    "path": "src/array/cuda/sddmm_hetero_coo.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/sddmm.cu\n * @brief SDDMM C APIs and definitions.\n */\n#include <dgl/array.h>\n\n#include \"./sddmm.cuh\"\n\nnamespace dgl {\nnamespace aten {\n\n/**\n * @brief CUDA implementation of g-SDDMM on heterograph using\n    Csr format.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCooHetero(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,\n    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,\n    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,\n    const std::vector<dgl_type_t>& rhs_eid) {\n  SWITCH_OP(op, Op, {\n    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {\n      /* Call SDDMM CUDA kernel for each relation type sequentially */\n      for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {\n        COOMatrix coo = vec_coo[etype];\n        NDArray lhs = vec_lhs[lhs_eid[etype]];\n        NDArray rhs = vec_rhs[rhs_eid[etype]];\n        NDArray out = vec_out[etype];\n        cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(\n            bcast, coo, lhs, rhs, out);\n      }\n    });\n  });\n}\n\ntemplate void SDDMMCooHetero<kDGLCUDA, int32_t, __half>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCooHetero<kDGLCUDA, int64_t, __half>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& 
out_eid);\n#if BF16_ENABLED\ntemplate void SDDMMCooHetero<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCooHetero<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\n#endif  // BF16_ENABLED\ntemplate void SDDMMCooHetero<kDGLCUDA, int32_t, float>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCooHetero<kDGLCUDA, int64_t, float>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCooHetero<kDGLCUDA, int32_t, double>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCooHetero<kDGLCUDA, int64_t, double>(\n    const 
std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/sddmm_hetero_csr.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/sddmm.cu\n * @brief SDDMM C APIs and definitions.\n */\n#include <dgl/array.h>\n\n#include \"./sddmm.cuh\"\n\nnamespace dgl {\nnamespace aten {\n\n/**\n * @brief CUDA implementation of g-SDDMM on heterograph using Csr format.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCsrHetero(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,\n    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,\n    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,\n    const std::vector<dgl_type_t>& rhs_eid) {\n  SWITCH_OP(op, Op, {\n    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {\n      /* Call SDDMM CUDA kernel for each relation type sequentially */\n      for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {\n        CSRMatrix csr = vec_csr[etype];\n        NDArray lhs = vec_lhs[lhs_eid[etype]];\n        NDArray rhs = vec_rhs[rhs_eid[etype]];\n        NDArray out = vec_out[etype];\n        cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(\n            bcast, csr, lhs, rhs, out);\n      }\n    });\n  });\n}\n\ntemplate void SDDMMCsrHetero<kDGLCUDA, int32_t, __half>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCsrHetero<kDGLCUDA, int64_t, __half>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& 
out_eid);\n#if BF16_ENABLED\ntemplate void SDDMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\n#endif  // BF16_ENABLED\ntemplate void SDDMMCsrHetero<kDGLCUDA, int32_t, float>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCsrHetero<kDGLCUDA, int64_t, float>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCsrHetero<kDGLCUDA, int32_t, double>(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\ntemplate void SDDMMCsrHetero<kDGLCUDA, int64_t, double>(\n    const 
std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,\n    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target, const std::vector<dgl_type_t>& in_eid,\n    const std::vector<dgl_type_t>& out_eid);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/segment_reduce.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/segment_reduce.cu\n * @brief Segment reduce C APIs and definitions.\n */\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n\n#include \"./functor.cuh\"\n#include \"./segment_reduce.cuh\"\n#include \"./utils.h\"\n\nnamespace dgl {\n\nusing namespace cuda;\n\nnamespace aten {\n\ntemplate <int XPU, typename IdType, typename DType>\nvoid SegmentReduce(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg) {\n  if (op == \"sum\") {\n    cuda::SegmentReduce<IdType, DType, cuda::reduce::Sum<IdType, DType>>(\n        feat, offsets, out, arg);\n  } else if (op == \"max\") {\n    cuda::SegmentReduce<IdType, DType, cuda::reduce::Max<IdType, DType>>(\n        feat, offsets, out, arg);\n  } else if (op == \"min\") {\n    cuda::SegmentReduce<IdType, DType, cuda::reduce::Min<IdType, DType>>(\n        feat, offsets, out, arg);\n  } else {\n    LOG(FATAL) << \"Not implemented\";\n  }\n}\n\ntemplate <int XPU, typename IdType, typename DType>\nvoid ScatterAdd(NDArray feat, NDArray idx, NDArray out) {\n  cuda::ScatterAdd<IdType, DType>(feat, idx, out);\n}\n\ntemplate <int XPU, typename IdType, typename DType>\nvoid UpdateGradMinMax_hetero(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {\n  cuda::UpdateGradMinMax_hetero<IdType, DType>(\n      g, op, feat, idx, idx_etype, out);\n}\n\ntemplate <int XPU, typename IdType, typename DType>\nvoid BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {\n  cuda::BackwardSegmentCmp<IdType, DType>(feat, arg, out);\n}\n\ntemplate void SegmentReduce<kDGLCUDA, int32_t, __half>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\ntemplate void SegmentReduce<kDGLCUDA, int64_t, __half>(\n    const std::string& op, NDArray feat, NDArray 
offsets, NDArray out,\n    NDArray arg);\n#if BF16_ENABLED\ntemplate void SegmentReduce<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\ntemplate void SegmentReduce<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\n#endif  // BF16_ENABLED\ntemplate void SegmentReduce<kDGLCUDA, int32_t, float>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\ntemplate void SegmentReduce<kDGLCUDA, int64_t, float>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\ntemplate void SegmentReduce<kDGLCUDA, int32_t, double>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\ntemplate void SegmentReduce<kDGLCUDA, int64_t, double>(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\n\ntemplate void ScatterAdd<kDGLCUDA, int32_t, __half>(\n    NDArray feat, NDArray idx, NDArray out);\ntemplate void ScatterAdd<kDGLCUDA, int64_t, __half>(\n    NDArray feat, NDArray idx, NDArray out);\n#if BF16_ENABLED\ntemplate void ScatterAdd<kDGLCUDA, int32_t, __nv_bfloat16>(\n    NDArray feat, NDArray idx, NDArray out);\ntemplate void ScatterAdd<kDGLCUDA, int64_t, __nv_bfloat16>(\n    NDArray feat, NDArray idx, NDArray out);\n#endif  // BF16_ENABLED\ntemplate void ScatterAdd<kDGLCUDA, int32_t, float>(\n    NDArray feat, NDArray idx, NDArray out);\ntemplate void ScatterAdd<kDGLCUDA, int64_t, float>(\n    NDArray feat, NDArray idx, NDArray out);\ntemplate void ScatterAdd<kDGLCUDA, int32_t, double>(\n    NDArray feat, NDArray idx, NDArray out);\ntemplate void ScatterAdd<kDGLCUDA, int64_t, double>(\n    NDArray feat, NDArray idx, NDArray out);\n\ntemplate void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __half>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& 
feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\ntemplate void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, __half>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\n#if BF16_ENABLED\ntemplate void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\ntemplate void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\n#endif  // BF16_ENABLED\ntemplate void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, float>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\ntemplate void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, float>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\ntemplate void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, double>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\ntemplate void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, double>(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, 
std::vector<NDArray>* out);\n\ntemplate void BackwardSegmentCmp<kDGLCUDA, int32_t, __half>(\n    NDArray feat, NDArray arg, NDArray out);\ntemplate void BackwardSegmentCmp<kDGLCUDA, int64_t, __half>(\n    NDArray feat, NDArray arg, NDArray out);\n#if BF16_ENABLED\ntemplate void BackwardSegmentCmp<kDGLCUDA, int32_t, __nv_bfloat16>(\n    NDArray feat, NDArray arg, NDArray out);\ntemplate void BackwardSegmentCmp<kDGLCUDA, int64_t, __nv_bfloat16>(\n    NDArray feat, NDArray arg, NDArray out);\n#endif  // BF16_ENABLED\ntemplate void BackwardSegmentCmp<kDGLCUDA, int32_t, float>(\n    NDArray feat, NDArray arg, NDArray out);\ntemplate void BackwardSegmentCmp<kDGLCUDA, int64_t, float>(\n    NDArray feat, NDArray arg, NDArray out);\ntemplate void BackwardSegmentCmp<kDGLCUDA, int32_t, double>(\n    NDArray feat, NDArray arg, NDArray out);\ntemplate void BackwardSegmentCmp<kDGLCUDA, int64_t, double>(\n    NDArray feat, NDArray arg, NDArray out);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/segment_reduce.cuh",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/segment_reduce.cuh\n * @brief Segment reduce kernel function header.\n */\n#ifndef DGL_ARRAY_CUDA_SEGMENT_REDUCE_CUH_\n#define DGL_ARRAY_CUDA_SEGMENT_REDUCE_CUH_\n\n#include <string>\n#include <vector>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./atomic.cuh\"\n#include \"./utils.h\"\n\nnamespace dgl {\n\nusing namespace cuda;\nusing namespace runtime;\n\nnamespace aten {\nnamespace cuda {\n\n/**\n * @brief CUDA kernel of segment reduce.\n * @note each blockthread is responsible for aggregation on a row\n *       in the result tensor.\n */\ntemplate <typename IdType, typename DType, typename ReduceOp>\n__global__ void SegmentReduceKernel(\n    const DType* feat, const IdType* offsets, DType* out, IdType* arg,\n    int64_t n, int64_t dim) {\n  for (int row = blockIdx.x; row < n; row += gridDim.x) {\n    int col = blockIdx.y * blockDim.x + threadIdx.x;\n    while (col < dim) {\n      typename accum_dtype<DType>::type local_accum = ReduceOp::zero();\n      IdType local_arg = -1;\n      for (IdType i = offsets[row]; i < offsets[row + 1]; ++i) {\n        ReduceOp::Call(&local_accum, &local_arg, feat[i * dim + col], i);\n      }\n      out[row * dim + col] = static_cast<DType>(local_accum);\n      if (ReduceOp::require_arg) arg[row * dim + col] = local_arg;\n      col += gridDim.y * blockDim.x;\n    }\n  }\n}\n\n/**\n * @brief CUDA kernel of scatter add.\n * @note each blockthread is responsible for adding a row in feature tensor\n *       to a target row in output tensor.\n */\ntemplate <typename IdType, typename DType>\n__global__ void ScatterAddKernel(\n    const DType* feat, const IdType* idx, DType* out, int64_t n, int64_t dim) {\n  for (int row = blockIdx.x; row < n; row += gridDim.x) {\n    const int write_row = idx[row];\n    int col = blockIdx.y * blockDim.x + threadIdx.x;\n    while (col < dim) {\n      cuda::AtomicAdd(out + write_row * dim + col, feat[row * dim + 
col]);\n      col += gridDim.y * blockDim.x;\n    }\n  }\n}\n\n/**\n * @brief CUDA kernel to update gradients for reduce op max/min\n * @note each WARP (group of 32 threads) is responsible for adding a row in\n * feature tensor to a target row in output tensor.\n */\n\ntemplate <typename IdType, typename DType>\n__global__ void UpdateGradMinMaxHeteroKernel(\n    const DType* feat, const IdType* idx, const IdType* idx_type, DType* out,\n    int64_t n, int64_t dim, int type) {\n  unsigned int tId = threadIdx.x;\n  unsigned int laneId = tId & 31;\n  unsigned int gId = blockIdx.x * blockDim.x + threadIdx.x;\n  unsigned int warpId = gId >> 5;\n  unsigned int warp_size = 32;\n  unsigned int row = warpId;\n\n  while (row < n) {\n    for (unsigned int col = laneId; col < dim; col += warp_size) {\n      if (type == idx_type[row * dim + col]) {\n        const int write_row = idx[row * dim + col];\n        cuda::AtomicAdd(out + write_row * dim + col, feat[row * dim + col]);\n      }\n    }\n    row += blockDim.x * gridDim.x;\n  }\n}\n\n/**\n * @brief CUDA kernel of backward phase in segment min/max.\n * @note each blockthread is responsible for writing a row in the\n *       result gradient tensor by lookup the ArgMin/Max for index information.\n */\ntemplate <typename IdType, typename DType>\n__global__ void BackwardSegmentCmpKernel(\n    const DType* feat, const IdType* arg, DType* out, int64_t n, int64_t dim) {\n  for (int row = blockIdx.x; row < n; row += gridDim.x) {\n    int col = blockIdx.y * blockDim.x + threadIdx.x;\n    while (col < dim) {\n      int write_row = arg[row * dim + col];\n      if (write_row >= 0) {\n        out[write_row * dim + col] = feat[row * dim + col];\n      }\n      col += gridDim.y * blockDim.x;\n    }\n  }\n}\n\n/**\n * @brief CUDA implementation of forward phase of Segment Reduce.\n * @param feat The input tensor.\n * @param offsets The offsets tensor.\n * @param out The output tensor.\n * @param arg An auxiliary tensor storing ArgMax/Min 
information,\n */\ntemplate <typename IdType, typename DType, typename ReduceOp>\nvoid SegmentReduce(NDArray feat, NDArray offsets, NDArray out, NDArray arg) {\n  const DType* feat_data = feat.Ptr<DType>();\n  const IdType* offsets_data = offsets.Ptr<IdType>();\n  DType* out_data = out.Ptr<DType>();\n  IdType* arg_data = arg.Ptr<IdType>();\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int64_t n = out->shape[0];\n  int64_t dim = 1;\n  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];\n\n  const int nbx = FindNumBlocks<'x'>(n);\n  const int ntx = FindNumThreads(dim);\n  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);\n  const int nty = 1;\n  const dim3 nblks(nbx, nby);\n  const dim3 nthrs(ntx, nty);\n  // TODO(zihao): try cub's DeviceSegmentedReduce and compare the performance.\n  CUDA_KERNEL_CALL(\n      (SegmentReduceKernel<IdType, DType, ReduceOp>), nblks, nthrs, 0, stream,\n      feat_data, offsets_data, out_data, arg_data, n, dim);\n}\n\n/**\n * @brief CUDA implementation of Scatter Add (on first dimension).\n * @note math equation: out[idx[i], *] += feat[i, *]\n * @param feat The input tensor.\n * @param idx The indices tensor.\n * @param out The output tensor.\n */\ntemplate <typename IdType, typename DType>\nvoid ScatterAdd(NDArray feat, NDArray idx, NDArray out) {\n  const DType* feat_data = feat.Ptr<DType>();\n  const IdType* idx_data = idx.Ptr<IdType>();\n  DType* out_data = out.Ptr<DType>();\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int64_t n = feat->shape[0];\n  int64_t dim = 1;\n  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];\n\n  const int nbx = FindNumBlocks<'x'>(n);\n  const int ntx = FindNumThreads(dim);\n  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);\n  const int nty = 1;\n  const dim3 nblks(nbx, nby);\n  const dim3 nthrs(ntx, nty);\n  CUDA_KERNEL_CALL(\n      (ScatterAddKernel<IdType, DType>), nblks, nthrs, 0, stream, feat_data,\n      idx_data, out_data, n, 
dim);\n}\n\n/**\n * @brief CUDA implementation to update gradients for reduce op max/min\n * @param graph The input heterogeneous graph.\n * @param op The binary operator, could be `copy_u`, `copy_e'.\n * @param list_feat List of the input tensors.\n * @param list_idx  List of the indices tensors.\n * @param list_idx_etype List of the node- or edge-type tensors.\n * @param list_out List of the output tensors.\n */\ntemplate <typename IdType, typename DType>\nvoid UpdateGradMinMax_hetero(\n    const HeteroGraphPtr& graph, const std::string& op,\n    const std::vector<NDArray>& list_feat, const std::vector<NDArray>& list_idx,\n    const std::vector<NDArray>& list_idx_types,\n    std::vector<NDArray>* list_out) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  if (op == \"copy_lhs\" || op == \"copy_rhs\") {\n    std::vector<std::vector<dgl_id_t>> src_dst_ntypes(\n        graph->NumVertexTypes(), std::vector<dgl_id_t>());\n    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {\n      auto pair = graph->meta_graph()->FindEdge(etype);\n      const dgl_id_t dst_ntype = pair.first;  // graph is reversed\n      const dgl_id_t src_ntype = pair.second;\n      auto same_src_dst_ntype = std::find(\n          std::begin(src_dst_ntypes[dst_ntype]),\n          std::end(src_dst_ntypes[dst_ntype]), src_ntype);\n      // if op is \"copy_lhs\", relation type with same src and dst node type will\n      // be updated once\n      if (op == \"copy_lhs\" &&\n          same_src_dst_ntype != std::end(src_dst_ntypes[dst_ntype]))\n        continue;\n      src_dst_ntypes[dst_ntype].push_back(src_ntype);\n      const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();\n      const IdType* idx_data = list_idx[dst_ntype].Ptr<IdType>();\n      const IdType* idx_type_data = list_idx_types[dst_ntype].Ptr<IdType>();\n      int type = (op == \"copy_lhs\") ? 
src_ntype : etype;\n      DType* out_data = (*list_out)[type].Ptr<DType>();\n      int dim = 1;\n      for (int i = 1; i < (*list_out)[type]->ndim; ++i)\n        dim *= (*list_out)[type]->shape[i];\n      int n = list_feat[dst_ntype]->shape[0];\n      const int th_per_row = 32;\n      const int ntx = 128;\n      const int nbx = FindNumBlocks<'x'>((n * th_per_row + ntx - 1) / ntx);\n      const dim3 nblks(nbx);\n      const dim3 nthrs(ntx);\n      CUDA_KERNEL_CALL(\n          (UpdateGradMinMaxHeteroKernel<IdType, DType>), nblks, nthrs, 0,\n          stream, feat_data, idx_data, idx_type_data, out_data, n, dim, type);\n    }\n  }\n}\n\n/**\n * @brief CUDA implementation of backward phase of Segment Reduce with Min/Max\n *        reducer.\n * @note math equation: out[arg[i, k], k] = feat[i, k]\n * @param feat The input\n *       tensor.\n * @param arg The ArgMin/Max information, used for indexing.\n * @param out The output tensor.\n */\ntemplate <typename IdType, typename DType>\nvoid BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {\n  const DType* feat_data = feat.Ptr<DType>();\n  const IdType* arg_data = arg.Ptr<IdType>();\n  DType* out_data = out.Ptr<DType>();\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int64_t n = feat->shape[0];\n  int64_t dim = 1;\n  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];\n\n  const int nbx = FindNumBlocks<'x'>(n);\n  const int ntx = FindNumThreads(dim);\n  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);\n  const int nty = 1;\n  const dim3 nblks(nbx, nby);\n  const dim3 nthrs(ntx, nty);\n  CUDA_KERNEL_CALL(\n      (BackwardSegmentCmpKernel<IdType, DType>), nblks, nthrs, 0, stream,\n      feat_data, arg_data, out_data, n, dim);\n}\n\n}  // namespace cuda\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CUDA_SEGMENT_REDUCE_CUH_\n"
  },
  {
    "path": "src/array/cuda/spmat_op_impl_coo.cu",
    "content": "/**\n *  Copyright (c) 2021 by contributors.\n * @file array/cuda/spmat_op_impl_coo.cu\n * @brief COO operator GPU implementation\n */\n#include <dgl/array.h>\n\n#include <numeric>\n#include <unordered_set>\n#include <vector>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./atomic.cuh\"\n#include \"./utils.h\"\n\nnamespace dgl {\n\nusing runtime::NDArray;\nusing namespace cuda;\n\nnamespace aten {\nnamespace impl {\n\ntemplate <typename IdType>\n__device__ void _warpReduce(volatile IdType* sdata, IdType tid) {\n  sdata[tid] += sdata[tid + 32];\n  sdata[tid] += sdata[tid + 16];\n  sdata[tid] += sdata[tid + 8];\n  sdata[tid] += sdata[tid + 4];\n  sdata[tid] += sdata[tid + 2];\n  sdata[tid] += sdata[tid + 1];\n}\n\ntemplate <typename IdType>\n__global__ void _COOGetRowNNZKernel(\n    const IdType* __restrict__ row_indices, IdType* __restrict__ glb_cnt,\n    const int64_t row_query, IdType nnz) {\n  __shared__ IdType local_cnt[1024];\n  IdType tx = threadIdx.x;\n  IdType bx = blockIdx.x;\n  local_cnt[tx] = 0;\n  IdType start = bx * blockDim.x;\n  while (start < nnz) {\n    if (start + tx < nnz)\n      local_cnt[tx] = (row_indices[start + tx] == row_query);\n    __syncthreads();\n    if (tx < 512) {\n      local_cnt[tx] += local_cnt[tx + 512];\n      __syncthreads();\n    }\n    if (tx < 256) {\n      local_cnt[tx] += local_cnt[tx + 256];\n      __syncthreads();\n    }\n    if (tx < 128) {\n      local_cnt[tx] += local_cnt[tx + 128];\n      __syncthreads();\n    }\n    if (tx < 64) {\n      local_cnt[tx] += local_cnt[tx + 64];\n      __syncthreads();\n    }\n    if (tx < 32) {\n      _warpReduce(local_cnt, tx);\n    }\n    if (tx == 0) {\n      cuda::AtomicAdd(glb_cnt, local_cnt[tx]);\n    }\n    start += blockDim.x * gridDim.x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nint64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const auto& ctx = coo.row->ctx;\n  IdType nnz 
= coo.row->shape[0];\n  IdType nt = 1024;\n  IdType nb = dgl::cuda::FindNumBlocks<'x'>((nnz + nt - 1) / nt);\n  NDArray rst = NDArray::Empty({1}, coo.row->dtype, coo.row->ctx);\n  _Fill(rst.Ptr<IdType>(), 1, IdType(0));\n  CUDA_KERNEL_CALL(\n      _COOGetRowNNZKernel, nb, nt, 0, stream, coo.row.Ptr<IdType>(),\n      rst.Ptr<IdType>(), row, nnz);\n  rst = rst.CopyTo(DGLContext{kDGLCPU, 0});\n  return *rst.Ptr<IdType>();\n}\n\ntemplate int64_t COOGetRowNNZ<kDGLCUDA, int32_t>(COOMatrix, int64_t);\ntemplate int64_t COOGetRowNNZ<kDGLCUDA, int64_t>(COOMatrix, int64_t);\n\ntemplate <typename IdType>\n__global__ void _COOGetAllRowNNZKernel(\n    const IdType* __restrict__ row_indices, IdType* __restrict__ glb_cnts,\n    IdType nnz) {\n  IdType eid = blockIdx.x * blockDim.x + threadIdx.x;\n  while (eid < nnz) {\n    IdType row = row_indices[eid];\n    cuda::AtomicAdd(glb_cnts + row, IdType(1));\n    eid += blockDim.x * gridDim.x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const auto& ctx = coo.row->ctx;\n  IdType nnz = coo.row->shape[0];\n  IdType num_rows = coo.num_rows;\n  IdType num_queries = rows->shape[0];\n  if (num_queries == 1) {\n    auto rows_cpu = rows.CopyTo(DGLContext{kDGLCPU, 0});\n    int64_t row = *rows_cpu.Ptr<IdType>();\n    IdType nt = 1024;\n    IdType nb = dgl::cuda::FindNumBlocks<'x'>((nnz + nt - 1) / nt);\n    NDArray rst = NDArray::Empty({1}, coo.row->dtype, coo.row->ctx);\n    _Fill(rst.Ptr<IdType>(), 1, IdType(0));\n    CUDA_KERNEL_CALL(\n        _COOGetRowNNZKernel, nb, nt, 0, stream, coo.row.Ptr<IdType>(),\n        rst.Ptr<IdType>(), row, nnz);\n    return rst;\n  } else {\n    IdType nt = 1024;\n    IdType nb = dgl::cuda::FindNumBlocks<'x'>((nnz + nt - 1) / nt);\n    NDArray in_degrees = NDArray::Empty({num_rows}, rows->dtype, rows->ctx);\n    _Fill(in_degrees.Ptr<IdType>(), num_rows, IdType(0));\n    
CUDA_KERNEL_CALL(\n        _COOGetAllRowNNZKernel, nb, nt, 0, stream, coo.row.Ptr<IdType>(),\n        in_degrees.Ptr<IdType>(), nnz);\n    return IndexSelect(in_degrees, rows);\n  }\n}\n\ntemplate NDArray COOGetRowNNZ<kDGLCUDA, int32_t>(COOMatrix, NDArray);\ntemplate NDArray COOGetRowNNZ<kDGLCUDA, int64_t>(COOMatrix, NDArray);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/spmat_op_impl_csr.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/spmat_op_impl_csr.cu\n * @brief CSR operator CPU implementation\n */\n#include <dgl/array.h>\n#include <thrust/execution_policy.h>\n#include <thrust/for_each.h>\n\n#include <cub/cub.cuh>\n#include <numeric>\n#include <unordered_set>\n#include <vector>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./atomic.cuh\"\n#include \"./utils.h\"\n\nnamespace dgl {\n\nusing runtime::NDArray;\nusing namespace cuda;\n\nnamespace aten {\nnamespace impl {\n\n///////////////////////////// CSRIsNonZero /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const auto& ctx = csr.indptr->ctx;\n  IdArray rows = aten::VecToIdArray<int64_t>({row}, sizeof(IdType) * 8, ctx);\n  IdArray cols = aten::VecToIdArray<int64_t>({col}, sizeof(IdType) * 8, ctx);\n  rows = rows.CopyTo(ctx);\n  cols = cols.CopyTo(ctx);\n  IdArray out = aten::NewIdArray(1, ctx, sizeof(IdType) * 8);\n  const IdType* data = nullptr;\n  // TODO(minjie): use binary search for sorted csr\n  CUDA_KERNEL_CALL(\n      dgl::cuda::_LinearSearchKernel, 1, 1, 0, stream, csr.indptr.Ptr<IdType>(),\n      csr.indices.Ptr<IdType>(), data, rows.Ptr<IdType>(), cols.Ptr<IdType>(),\n      1, 1, 1, static_cast<IdType*>(nullptr), static_cast<IdType>(-1),\n      out.Ptr<IdType>());\n  out = out.CopyTo(DGLContext{kDGLCPU, 0});\n  return *out.Ptr<IdType>() != -1;\n}\n\ntemplate bool CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);\ntemplate bool CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {\n  const auto rowlen = row->shape[0];\n  const auto collen = col->shape[0];\n  const auto rstlen = std::max(rowlen, collen);\n  NDArray rst = NDArray::Empty({rstlen}, 
row->dtype, row->ctx);\n  if (rstlen == 0) return rst;\n  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;\n  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const int nt = dgl::cuda::FindNumThreads(rstlen);\n  const int nb = (rstlen + nt - 1) / nt;\n  const IdType* data = nullptr;\n  const IdType* indptr_data =\n      static_cast<IdType*>(GetDevicePointer(csr.indptr));\n  const IdType* indices_data =\n      static_cast<IdType*>(GetDevicePointer(csr.indices));\n  // TODO(minjie): use binary search for sorted csr\n  CUDA_KERNEL_CALL(\n      dgl::cuda::_LinearSearchKernel, nb, nt, 0, stream, indptr_data,\n      indices_data, data, row.Ptr<IdType>(), col.Ptr<IdType>(), row_stride,\n      col_stride, rstlen, static_cast<IdType*>(nullptr),\n      static_cast<IdType>(-1), rst.Ptr<IdType>());\n  return rst != -1;\n}\n\ntemplate NDArray CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, NDArray, NDArray);\ntemplate NDArray CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, NDArray, NDArray);\n\n///////////////////////////// CSRHasDuplicate /////////////////////////////\n\n/**\n * @brief Check whether each row does not have any duplicate entries.\n * Assume the CSR is sorted.\n */\ntemplate <typename IdType>\n__global__ void _SegmentHasNoDuplicate(\n    const IdType* indptr, const IdType* indices, int64_t num_rows,\n    int8_t* flags) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n  while (tx < num_rows) {\n    bool f = true;\n    for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) {\n      f = (indices[i - 1] != indices[i]);\n    }\n    flags[tx] = static_cast<int8_t>(f);\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nbool CSRHasDuplicate(CSRMatrix csr) {\n  if (!csr.sorted) csr = CSRSort(csr);\n  const auto& ctx = csr.indptr->ctx;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  auto 
device = runtime::DeviceAPI::Get(ctx);\n  // We allocate a workspace of num_rows bytes. It wastes a little bit memory\n  // but should be fine.\n  int8_t* flags =\n      static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));\n  const int nt = dgl::cuda::FindNumThreads(csr.num_rows);\n  const int nb = (csr.num_rows + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      _SegmentHasNoDuplicate, nb, nt, 0, stream, csr.indptr.Ptr<IdType>(),\n      csr.indices.Ptr<IdType>(), csr.num_rows, flags);\n  bool ret = dgl::cuda::AllTrue(flags, csr.num_rows, ctx);\n  device->FreeWorkspace(ctx, flags);\n  return !ret;\n}\n\ntemplate bool CSRHasDuplicate<kDGLCUDA, int32_t>(CSRMatrix csr);\ntemplate bool CSRHasDuplicate<kDGLCUDA, int64_t>(CSRMatrix csr);\n\n///////////////////////////// CSRGetRowNNZ /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nint64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {\n  const IdType cur = aten::IndexSelect<IdType>(csr.indptr, row);\n  const IdType next = aten::IndexSelect<IdType>(csr.indptr, row + 1);\n  return next - cur;\n}\n\ntemplate int64_t CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, int64_t);\ntemplate int64_t CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, int64_t);\n\ntemplate <typename IdType>\n__global__ void _CSRGetRowNNZKernel(\n    const IdType* vid, const IdType* indptr, IdType* out, int64_t length) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    const IdType vv = vid[tx];\n    out[tx] = indptr[vv + 1] - indptr[vv];\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const auto len = rows->shape[0];\n  const IdType* vid_data = rows.Ptr<IdType>();\n  const IdType* indptr_data =\n      static_cast<IdType*>(GetDevicePointer(csr.indptr));\n  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);\n  
IdType* rst_data = static_cast<IdType*>(rst->data);\n  const int nt = dgl::cuda::FindNumThreads(len);\n  const int nb = (len + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      _CSRGetRowNNZKernel, nb, nt, 0, stream, vid_data, indptr_data, rst_data,\n      len);\n  return rst;\n}\n\ntemplate NDArray CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, NDArray);\ntemplate NDArray CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, NDArray);\n\n////////////////////////// CSRGetRowColumnIndices //////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {\n  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);\n  const int64_t offset =\n      aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType);\n  return csr.indices.CreateView({len}, csr.indices->dtype, offset);\n}\n\ntemplate NDArray CSRGetRowColumnIndices<kDGLCUDA, int32_t>(CSRMatrix, int64_t);\ntemplate NDArray CSRGetRowColumnIndices<kDGLCUDA, int64_t>(CSRMatrix, int64_t);\n\n///////////////////////////// CSRGetRowData /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nNDArray CSRGetRowData(CSRMatrix csr, int64_t row) {\n  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);\n  const int64_t offset =\n      aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType);\n  if (aten::CSRHasData(csr))\n    return csr.data.CreateView({len}, csr.data->dtype, offset);\n  else\n    return aten::Range(\n        offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);\n}\n\ntemplate NDArray CSRGetRowData<kDGLCUDA, int32_t>(CSRMatrix, int64_t);\ntemplate NDArray CSRGetRowData<kDGLCUDA, int64_t>(CSRMatrix, int64_t);\n\n///////////////////////////// CSRSliceRows /////////////////////////////\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {\n  const int64_t num_rows = end - start;\n  const IdType st_pos = aten::IndexSelect<IdType>(csr.indptr, 
start);\n  const IdType ed_pos = aten::IndexSelect<IdType>(csr.indptr, end);\n  const IdType nnz = ed_pos - st_pos;\n  IdArray ret_indptr = aten::IndexSelect(csr.indptr, start, end + 1) - st_pos;\n  // indices and data can be view arrays\n  IdArray ret_indices = csr.indices.CreateView(\n      {nnz}, csr.indices->dtype, st_pos * sizeof(IdType));\n  IdArray ret_data;\n  if (CSRHasData(csr))\n    ret_data =\n        csr.data.CreateView({nnz}, csr.data->dtype, st_pos * sizeof(IdType));\n  else\n    ret_data =\n        aten::Range(st_pos, ed_pos, csr.indptr->dtype.bits, csr.indptr->ctx);\n  return CSRMatrix(\n      num_rows, csr.num_cols, ret_indptr, ret_indices, ret_data, csr.sorted);\n}\n\ntemplate CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);\ntemplate CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);\n\n/**\n * @brief Copy data segment to output buffers\n *\n * For the i^th row r = row[i], copy the data from indptr[r] ~ indptr[r+1]\n * to the out_data from out_indptr[i] ~ out_indptr[i+1]\n *\n * If the provided `data` array is nullptr, write the read index to the\n * out_data.\n *\n */\ntemplate <typename IdType, typename DType>\n__global__ void _SegmentCopyKernel(\n    const IdType* indptr, const DType* data, const IdType* row, int64_t length,\n    int64_t n_row, const IdType* out_indptr, DType* out_data) {\n  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    IdType rpos = dgl::cuda::_UpperBound(out_indptr, n_row, tx) - 1;\n    IdType rofs = tx - out_indptr[rpos];\n    const IdType u = row[rpos];\n    out_data[tx] = data ? 
data[indptr[u] + rofs] : indptr[u] + rofs;\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const int64_t len = rows->shape[0];\n  IdArray ret_indptr = aten::CumSum(aten::CSRGetRowNNZ(csr, rows), true);\n  const int64_t nnz = aten::IndexSelect<IdType>(ret_indptr, len);\n\n  const int nt = 256;  // for better GPU usage of small invocations\n  const int nb = (nnz + nt - 1) / nt;\n\n  // Copy indices.\n  IdArray ret_indices = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);\n\n  const IdType* indptr_data =\n      static_cast<IdType*>(GetDevicePointer(csr.indptr));\n  const IdType* indices_data =\n      static_cast<IdType*>(GetDevicePointer(csr.indices));\n  const IdType* data_data =\n      CSRHasData(csr) ? static_cast<IdType*>(GetDevicePointer(csr.data))\n                      : nullptr;\n\n  CUDA_KERNEL_CALL(\n      _SegmentCopyKernel, nb, nt, 0, stream, indptr_data, indices_data,\n      rows.Ptr<IdType>(), nnz, len, ret_indptr.Ptr<IdType>(),\n      ret_indices.Ptr<IdType>());\n  // Copy data.\n  IdArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);\n  CUDA_KERNEL_CALL(\n      _SegmentCopyKernel, nb, nt, 0, stream, indptr_data, data_data,\n      rows.Ptr<IdType>(), nnz, len, ret_indptr.Ptr<IdType>(),\n      ret_data.Ptr<IdType>());\n  return CSRMatrix(\n      len, csr.num_cols, ret_indptr, ret_indices, ret_data, csr.sorted);\n}\n\ntemplate CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, NDArray);\ntemplate CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, NDArray);\n\n///////////////////////////// CSRGetDataAndIndices /////////////////////////////\n\n/**\n * @brief Generate a 0-1 mask for each index that hits the provided (row, col)\n *        index.\n *\n * Examples:\n * Given a CSR matrix (with duplicate entries) as follows:\n * [[0, 1, 2, 0, 0],\n *  [1, 0, 0, 0, 0],\n *  [0, 0, 1, 1, 0],\n 
*  [0, 0, 0, 0, 0]]\n * Given rows: [0, 1], cols: [0, 2, 3]\n * The result mask is: [0, 1, 1, 1, 0, 0]\n */\ntemplate <typename IdType>\n__global__ void _SegmentMaskKernel(\n    const IdType* indptr, const IdType* indices, const IdType* row,\n    const IdType* col, int64_t row_stride, int64_t col_stride, int64_t length,\n    IdType* mask) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    int rpos = tx * row_stride, cpos = tx * col_stride;\n    const IdType r = row[rpos], c = col[cpos];\n    for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {\n      if (indices[i] == c) {\n        mask[i] = 1;\n      }\n    }\n    tx += stride_x;\n  }\n}\n\n/**\n * @brief Search for the insertion positions for needle in the hay.\n *\n * The hay is a list of sorted elements and the result is the insertion position\n * of each needle so that the insertion still gives sorted order.\n *\n * It essentially perform binary search to find lower bound for each needle\n * elements. Require the largest elements in the hay is larger than the given\n * needle elements. Commonly used in searching for row IDs of a given set of\n * coordinates.\n */\ntemplate <typename IdType>\n__global__ void _SortedSearchKernel(\n    const IdType* hay, int64_t hay_size, const IdType* needles,\n    int64_t num_needles, IdType* pos) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n  while (tx < num_needles) {\n    const IdType ele = needles[tx];\n    // binary search\n    IdType lo = 0, hi = hay_size - 1;\n    while (lo < hi) {\n      IdType mid = (lo + hi) >> 1;\n      if (hay[mid] <= ele) {\n        lo = mid + 1;\n      } else {\n        hi = mid;\n      }\n    }\n    pos[tx] = (hay[hi] == ele) ? 
hi : hi - 1;\n    tx += stride_x;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::vector<NDArray> CSRGetDataAndIndices(\n    CSRMatrix csr, NDArray row, NDArray col) {\n  const auto rowlen = row->shape[0];\n  const auto collen = col->shape[0];\n  const auto len = std::max(rowlen, collen);\n  if (len == 0) return {NullArray(), NullArray(), NullArray()};\n\n  const auto& ctx = row->ctx;\n  const auto nbits = row->dtype.bits;\n  const int64_t nnz = csr.indices->shape[0];\n  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;\n  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  const IdType* indptr_data =\n      static_cast<IdType*>(GetDevicePointer(csr.indptr));\n  const IdType* indices_data =\n      static_cast<IdType*>(GetDevicePointer(csr.indices));\n\n  // Generate a 0-1 mask for matched (row, col) positions.\n  IdArray mask = Full(0, nnz, nbits, ctx);\n  const int nt = dgl::cuda::FindNumThreads(len);\n  const int nb = (len + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      _SegmentMaskKernel, nb, nt, 0, stream, indptr_data, indices_data,\n      row.Ptr<IdType>(), col.Ptr<IdType>(), row_stride, col_stride, len,\n      mask.Ptr<IdType>());\n\n  IdArray idx = AsNumBits(NonZero(mask), nbits);\n  if (idx->shape[0] == 0)\n    // No data. Return three empty arrays.\n    return {idx, idx, idx};\n\n  // Search for row index\n  IdArray ret_row = NewIdArray(idx->shape[0], ctx, nbits);\n  const int nt2 = dgl::cuda::FindNumThreads(idx->shape[0]);\n  const int nb2 = (idx->shape[0] + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      _SortedSearchKernel, nb2, nt2, 0, stream, indptr_data, csr.num_rows,\n      idx.Ptr<IdType>(), idx->shape[0], ret_row.Ptr<IdType>());\n\n  // Column & data can be obtained by index select.\n  IdArray ret_col = IndexSelect(csr.indices, idx);\n  IdArray ret_data = CSRHasData(csr) ? 
IndexSelect(csr.data, idx) : idx;\n  return {ret_row, ret_col, ret_data};\n}\n\ntemplate std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int32_t>(\n    CSRMatrix csr, NDArray rows, NDArray cols);\ntemplate std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int64_t>(\n    CSRMatrix csr, NDArray rows, NDArray cols);\n\n///////////////////////////// CSRSliceMatrix /////////////////////////////\n\nint64_t _UpPower(int64_t numel) {\n  uint64_t ret = 1 << static_cast<uint64_t>(std::log2(numel) + 1);\n  return ret;\n}\n\n/**\n * @brief Thomas Wang's 32 bit Mix Function.\n * Source link: https://gist.github.com/badboy/6267743\n */\n__device__ inline uint32_t _Hash32Shift(uint32_t key) {\n  key = ~key + (key << 15);\n  key = key ^ (key >> 12);\n  key = key + (key << 2);\n  key = key ^ (key >> 4);\n  key = key * 2057;\n  key = key ^ (key >> 16);\n  return key;\n}\n\n/**\n * @brief Thomas Wang's 64 bit Mix Function.\n * Source link: https://gist.github.com/badboy/6267743\n */\n__device__ inline uint64_t _Hash64Shift(uint64_t key) {\n  key = (~key) + (key << 21);\n  key = key ^ (key >> 24);\n  key = (key + (key << 3)) + (key << 8);\n  key = key ^ (key >> 14);\n  key = (key + (key << 2)) + (key << 4);\n  key = key ^ (key >> 28);\n  key = key + (key << 31);\n  return key;\n}\n\n/**\n * @brief A hashmap designed for CSRSliceMatrix, similar in function to set. For\n * performance, it can only be created and called in the cuda kernel.\n */\ntemplate <typename IdType>\nstruct NodeQueryHashmap {\n  __device__ inline NodeQueryHashmap(IdType* Kptr, size_t numel)\n      : kptr_(Kptr), capacity_(numel) {}\n\n  /**\n   * @brief Insert a key. 
It must be called by cuda threads.\n   *\n   * @param key The key to be inserted.\n   */\n  __device__ inline void Insert(IdType key) {\n    uint32_t delta = 1;\n    uint32_t pos = Hash(key);\n    IdType prev = dgl::aten::cuda::AtomicCAS(&kptr_[pos], kEmptyKey_, key);\n    while (prev != key && prev != kEmptyKey_) {\n      pos = Hash(pos + delta);\n      delta += 1;\n      prev = dgl::aten::cuda::AtomicCAS(&kptr_[pos], kEmptyKey_, key);\n    }\n  }\n\n  /**\n   * @brief Check whether a key exists within the hashtable. It must be called\n   * by cuda threads.\n   *\n   * @param key The key to check for.\n   * @return True if the key exists in the hashtable.\n   */\n  __device__ inline bool Query(IdType key) {\n    uint32_t delta = 1;\n    uint32_t pos = Hash(key);\n    while (true) {\n      if (kptr_[pos] == key) return true;\n      if (kptr_[pos] == kEmptyKey_) return false;\n      pos = Hash(pos + delta);\n      delta += 1;\n    }\n    return false;\n  }\n\n  __device__ inline uint32_t Hash(int32_t key) {\n    return _Hash32Shift(key) & (capacity_ - 1);\n  }\n\n  __device__ inline uint32_t Hash(uint32_t key) {\n    return _Hash32Shift(key) & (capacity_ - 1);\n  }\n\n  __device__ inline uint32_t Hash(int64_t key) {\n    return static_cast<uint32_t>(_Hash64Shift(key)) & (capacity_ - 1);\n  }\n\n  __device__ inline uint32_t Hash(uint64_t key) {\n    return static_cast<uint32_t>(_Hash64Shift(key)) & (capacity_ - 1);\n  }\n\n  IdType kEmptyKey_{-1};\n  IdType* kptr_;\n  uint32_t capacity_{0};\n};\n\n/**\n * @brief Generate a 0-1 mask for each index whose column is in the provided\n * hashmap. 
It also counts the number of masked values per row.\n *\n * @tparam IdType The ID type used for matrices.\n * @tparam WARP_SIZE The number of cuda threads in a cuda warp.\n * @tparam BLOCK_WARPS The number of warps in a cuda block.\n * @tparam TILE_SIZE The number of rows covered by each threadblock.\n */\ntemplate <typename IdType, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>\n__global__ void _SegmentMaskColKernel(\n    const IdType* indptr, const IdType* indices, int64_t num_rows,\n    IdType* hashmap_buffer, int64_t buffer_size, IdType* mask, IdType* count) {\n  assert(blockDim.x == WARP_SIZE);\n  assert(blockDim.y == BLOCK_WARPS);\n\n  int warp_id = threadIdx.y;\n  int laneid = threadIdx.x;\n  IdType out_row = blockIdx.x * TILE_SIZE + threadIdx.y;\n  IdType last_row =\n      min(static_cast<IdType>((blockIdx.x + 1) * TILE_SIZE),\n          static_cast<IdType>(num_rows));\n\n  NodeQueryHashmap<IdType> hashmap(hashmap_buffer, buffer_size);\n  typedef cub::WarpReduce<IdType> WarpReduce;\n  __shared__ typename WarpReduce::TempStorage temp_storage[BLOCK_WARPS];\n\n  while (out_row < last_row) {\n    IdType local_count = 0;\n    IdType in_row_start = indptr[out_row];\n    IdType in_row_end = indptr[out_row + 1];\n    for (int idx = in_row_start + laneid; idx < in_row_end; idx += WARP_SIZE) {\n      bool is_in = hashmap.Query(indices[idx]);\n      if (is_in) {\n        local_count += 1;\n        mask[idx] = 1;\n      }\n    }\n    IdType reduce_count = WarpReduce(temp_storage[warp_id]).Sum(local_count);\n    if (laneid == 0) {\n      count[out_row] = reduce_count;\n    }\n    out_row += BLOCK_WARPS;\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nCSRMatrix CSRSliceMatrix(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const auto& ctx = rows->ctx;\n  const auto& dtype = rows->dtype;\n  const auto nbits = dtype.bits;\n  const int64_t new_nrows = rows->shape[0];\n  const int64_t 
new_ncols = cols->shape[0];\n\n  if (new_nrows == 0 || new_ncols == 0)\n    return CSRMatrix(\n        new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),\n        NullArray(dtype, ctx), NullArray(dtype, ctx));\n\n  // First slice rows\n  csr = CSRSliceRows(csr, rows);\n\n  if (csr.indices->shape[0] == 0)\n    return CSRMatrix(\n        new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),\n        NullArray(dtype, ctx), NullArray(dtype, ctx));\n\n  // Generate a 0-1 mask for matched (row, col) positions.\n  IdArray mask = Full(0, csr.indices->shape[0], nbits, ctx);\n  // A count for how many masked values per row.\n  IdArray count = NewIdArray(csr.num_rows, ctx, nbits);\n  CUDA_CALL(\n      cudaMemset(count.Ptr<IdType>(), 0, sizeof(IdType) * (csr.num_rows)));\n\n  // Generate a NodeQueryHashmap buffer. The key of the hashmap is col.\n  // For performance, the load factor of the hashmap is in (0.25, 0.5);\n  // Because num_cols is usually less than 1 Million (on GPU), the\n  // memory overhead is not significant (less than 31MB) at a low load factor.\n  int64_t buffer_size = _UpPower(new_ncols) * 2;\n  IdArray hashmap_buffer = Full(-1, buffer_size, nbits, ctx);\n\n  using it = thrust::counting_iterator<int64_t>;\n  runtime::CUDAWorkspaceAllocator allocator(ctx);\n  const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);\n  thrust::for_each(\n      exec_policy, it(0), it(new_ncols),\n      [key = cols.Ptr<IdType>(), buffer = hashmap_buffer.Ptr<IdType>(),\n       buffer_size] __device__(int64_t i) {\n        NodeQueryHashmap<IdType> hashmap(buffer, buffer_size);\n        hashmap.Insert(key[i]);\n      });\n\n  const IdType* indptr_data =\n      static_cast<IdType*>(GetDevicePointer(csr.indptr));\n  const IdType* indices_data =\n      static_cast<IdType*>(GetDevicePointer(csr.indices));\n\n  // Execute SegmentMaskColKernel\n  const int64_t num_rows = csr.num_rows;\n  constexpr int WARP_SIZE = 32;\n  // With a simple fine-tuning, TILE_SIZE=16 
gives a good performance.\n  constexpr int TILE_SIZE = 16;\n  constexpr int BLOCK_WARPS = CUDA_MAX_NUM_THREADS / WARP_SIZE;\n  IdType nb =\n      dgl::cuda::FindNumBlocks<'x'>((num_rows + TILE_SIZE - 1) / TILE_SIZE);\n  const dim3 nthrs(WARP_SIZE, BLOCK_WARPS);\n  const dim3 nblks(nb);\n  CUDA_KERNEL_CALL(\n      (_SegmentMaskColKernel<IdType, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>), nblks,\n      nthrs, 0, stream, indptr_data, indices_data, num_rows,\n      hashmap_buffer.Ptr<IdType>(), buffer_size, mask.Ptr<IdType>(),\n      count.Ptr<IdType>());\n\n  IdArray idx = AsNumBits(NonZero(mask), nbits);\n  if (idx->shape[0] == 0)\n    return CSRMatrix(\n        new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),\n        NullArray(dtype, ctx), NullArray(dtype, ctx));\n\n  // Indptr needs to be adjusted according to the new nnz per row.\n  IdArray ret_indptr = CumSum(count, true);\n\n  // Column & data can be obtained by index select.\n  IdArray ret_col = IndexSelect(csr.indices, idx);\n  IdArray ret_data = CSRHasData(csr) ? IndexSelect(csr.data, idx) : idx;\n\n  // Relabel column\n  IdArray col_hash = NewIdArray(csr.num_cols, ctx, nbits);\n  Scatter_(cols, Range(0, cols->shape[0], nbits, ctx), col_hash);\n  ret_col = IndexSelect(col_hash, ret_col);\n\n  return CSRMatrix(new_nrows, new_ncols, ret_indptr, ret_col, ret_data);\n}\n\ntemplate CSRMatrix CSRSliceMatrix<kDGLCUDA, int32_t>(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);\ntemplate CSRMatrix CSRSliceMatrix<kDGLCUDA, int64_t>(\n    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/spmm.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/spmm.cu\n * @brief SPMM C APIs and definitions.\n */\n#include <dgl/array.h>\n\n#include <cstdlib>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./functor.cuh\"\n#include \"./ge_spmm.cuh\"\n#include \"./spmm.cuh\"\n\nnamespace dgl {\n\nusing namespace cuda;\n\nnamespace aten {\n\n/**\n * @brief CUDA implementation of g-SpMM on Csr format.\n * @note use cusparse if the reduce operator is `sum` and there is\n *       no broadcast, use dgl's kernel in other cases.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SpMMCsr(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux) {\n  bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0];\n  bool use_efeat = op != \"copy_lhs\";\n  bool use_deterministic_alg_only = false;\n  if (NULL != std::getenv(\"USE_DETERMINISTIC_ALG\"))\n    use_deterministic_alg_only = true;\n\n  if (reduce == \"sum\") {\n    bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);\n    if (op == \"copy_lhs\" && cusparse_available<DType, IdType>(more_nnz)) {\n      // cusparse\n      int64_t x_length = 1;\n      for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];\n      CusparseCsrmm2<DType, IdType>(\n          ufeat->ctx, csr, static_cast<DType*>(ufeat->data), nullptr,\n          static_cast<DType*>(out->data), x_length, use_deterministic_alg_only);\n    } else if (\n        op == \"mul\" && is_scalar_efeat &&\n        cusparse_available<DType, IdType>(more_nnz)) {\n      // cusparse\n      int64_t x_length = 1;\n      for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];\n      if (!IsNullArray(csr.data)) {\n        efeat = IndexSelect(efeat, csr.data);\n      }\n      CusparseCsrmm2<DType, IdType>(\n          ufeat->ctx, csr, 
static_cast<DType*>(ufeat->data),\n          static_cast<DType*>(efeat->data), static_cast<DType*>(out->data),\n          x_length, use_deterministic_alg_only);\n    } else {  // general kernel\n      SWITCH_OP(op, Op, {\n        cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(\n            bcast, csr, ufeat, efeat, out, NullArray(), NullArray());\n      });\n    }\n  } else if (reduce == \"max\") {\n    SWITCH_OP(op, Op, {\n      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(\n          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);\n    });\n  } else if (reduce == \"min\") {\n    SWITCH_OP(op, Op, {\n      cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(\n          bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);\n    });\n  } else {\n    LOG(FATAL) << \"Not implemented\";\n  }\n}\n\n/**\n * @brief CUDA implementation of g-SpMM on Coo format.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SpMMCoo(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux) {\n  if (reduce == \"sum\") {\n    SWITCH_OP(op, Op, {\n      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> >(\n          bcast, coo, ufeat, efeat, out, NullArray(), NullArray());\n    });\n  } else if (reduce == \"max\") {\n    SWITCH_OP(op, Op, {\n      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> >(\n          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);\n    });\n  } else if (reduce == \"min\") {\n    SWITCH_OP(op, Op, {\n      cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> >(\n          bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);\n    });\n  } else {\n    LOG(FATAL) << \"Not implemented\";\n  }\n}\n\ntemplate void SpMMCsr<kDGLCUDA, int32_t, __half>(\n    const std::string& op, const 
std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCsr<kDGLCUDA, int64_t, __half>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\n#if BF16_ENABLED\ntemplate void SpMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\n#endif  // BF16_ENABLED\ntemplate void SpMMCsr<kDGLCUDA, int32_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCsr<kDGLCUDA, int64_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCsr<kDGLCUDA, int32_t, double>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCsr<kDGLCUDA, int64_t, double>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\n\ntemplate void SpMMCoo<kDGLCUDA, int32_t, __half>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const 
COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCoo<kDGLCUDA, int64_t, __half>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\n#if BF16_ENABLED\ntemplate void SpMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\n#endif  // BF16_ENABLED\ntemplate void SpMMCoo<kDGLCUDA, int32_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCoo<kDGLCUDA, int64_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCoo<kDGLCUDA, int32_t, double>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\ntemplate void SpMMCoo<kDGLCUDA, int64_t, double>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/spmm.cuh",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/spmm.cuh\n * @brief SPMM CUDA kernel function header.\n */\n#ifndef DGL_ARRAY_CUDA_SPMM_CUH_\n#define DGL_ARRAY_CUDA_SPMM_CUH_\n\n#include <dgl/bcast.h>\n\n#include <limits>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n#include \"atomic.cuh\"\n#include \"bf16.cuh\"\n#include \"fp16.cuh\"\n#include \"macro.cuh\"\n\nnamespace dgl {\n\nusing namespace cuda;\n\nnamespace aten {\n\n/**\n * @brief Determine whether cusparse SpMM function is applicable.\n */\ntemplate <typename DType, typename IdType>\ninline bool cusparse_available(bool more_nnz_than_matrix_size) {\n#if CUDART_VERSION < 11000\n  if (std::is_same<IdType, int>::value &&\n      (std::is_same<DType, float>::value || std::is_same<DType, double>::value))\n    return true;\n  return false;\n#else\n  if (std::is_same<DType, __half>::value ||\n      std::is_same<DType, __nv_bfloat16>::value)\n    return false;  // cusparse's SpMM on fp16 is slow, temporally disabled.\n  // If the CSR matrix has more NNZ than matrix size, we should not use\n  // cuSPARSE 11.1.\n  return !more_nnz_than_matrix_size;\n#endif\n}\n\nnamespace {\n\n/** @brief Call cuBLAS geam API for transpose operation for float and double. 
*/\ntemplate <typename DType>\ncublasStatus_t Xgeam(\n    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,\n    int m, int n, const DType* alpha, const DType* A, int lda,\n    const DType* beta, const DType* B, int ldb, DType* C, int ldc) {\n  LOG(FATAL) << \"Not supported dtype\";\n  return CUBLAS_STATUS_EXECUTION_FAILED;\n}\n\ntemplate <>\ncublasStatus_t Xgeam<__half>(\n    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,\n    int m, int n, const __half* alpha, const __half* A, int lda,\n    const __half* beta, const __half* B, int ldb, __half* C, int ldc) {\n  // TODO(ndickson): There is no cublasHgeam, so a different\n  // implementation would be required.\n  LOG(FATAL) << \"Xgeam does not support dtype half (FP16)\";\n  return CUBLAS_STATUS_EXECUTION_FAILED;\n}\n\n#if BF16_ENABLED\ntemplate <>\ncublasStatus_t Xgeam<__nv_bfloat16>(\n    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,\n    int m, int n, const __nv_bfloat16* alpha, const __nv_bfloat16* A, int lda,\n    const __nv_bfloat16* beta, const __nv_bfloat16* B, int ldb,\n    __nv_bfloat16* C, int ldc) {\n  // TODO(ndickson): There is no cublasHgeam, so a different\n  // implementation would be required.\n  LOG(FATAL) << \"Xgeam does not support dtype bfloat16 (BF16)\";\n  return CUBLAS_STATUS_EXECUTION_FAILED;\n}\n#endif  // BF16_ENABLED\n\ntemplate <>\ncublasStatus_t Xgeam<float>(\n    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,\n    int m, int n, const float* alpha, const float* A, int lda,\n    const float* beta, const float* B, int ldb, float* C, int ldc) {\n  return cublasSgeam(\n      handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);\n}\n\ntemplate <>\ncublasStatus_t Xgeam<double>(\n    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,\n    int m, int n, const double* alpha, const double* A, int lda,\n    const double* beta, const double* B, int 
ldb, double* C, int ldc) {\n  return cublasDgeam(\n      handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);\n}\n\n/**\n * @brief Transpose operator kernel implementation.\n * @note not efficient but it's not a bottleneck, used for float16 dtype.\n */\ntemplate <typename DType>\n__global__ void _TransposeKernel(\n    const DType* __restrict__ in, DType* __restrict__ out, int n, int m) {\n  int i = blockIdx.x;\n  for (int j = threadIdx.x; j < m; j += blockDim.x)\n    out[i * m + j] = in[j * n + i];\n}\n\n/**\n * @brief Transpose the input matrix.\n * @param row number of rows of input matrix.\n * @param col number of columns of input matrix.\n */\ntemplate <typename DType>\nvoid _Transpose(const DType* in, DType* out, int row, int col) {\n  DType alpha = 1., beta = 0.;\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  if (!thr_entry->cublas_handle)\n    CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));\n  CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));\n  CUBLAS_CALL(Xgeam<DType>(\n      thr_entry->cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, row, col, &alpha, in,\n      col, &beta, nullptr, row, out, row));\n}\n\n/**\n * @brief Transpose the input matrix for data type half.\n * @note cuBLAS has no geam API for half data type, fallback to our kernel.\n */\ntemplate <>\nvoid _Transpose<__half>(const __half* in, __half* out, int row, int col) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int nt = FindNumThreads(row);\n  int nb = col;\n  CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);\n}\n\n#if BF16_ENABLED\n/**\n * @brief Transpose the input matrix for data type bfloat16.\n * @note cuBLAS has no geam API for bf16 data type, fallback to our kernel.\n */\ntemplate <>\nvoid _Transpose<__nv_bfloat16>(\n    const __nv_bfloat16* in, __nv_bfloat16* out, int row, int col) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  
int nt = FindNumThreads(row);\n  int nb = col;\n  CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);\n}\n#endif  // BF16_ENABLED\n\n#if CUDART_VERSION < 11000\ntemplate <typename DType>\ncusparseStatus_t Xcsrmm2(\n    cusparseHandle_t handle, cusparseOperation_t transA,\n    cusparseOperation_t transB, int m, int n, int k, int nnz,\n    const DType* alpha, const cusparseMatDescr_t descrA, const DType* csrValA,\n    const int* csrRowPtrA, const int* csrColIndA, const DType* B, int ldb,\n    const DType* beta, DType* C, int ldc) {\n  LOG(INFO) << \"Not supported dtype\";\n  return CUSPARSE_STATUS_EXECUTION_FAILED;\n}\n\ntemplate <>\ncusparseStatus_t Xcsrmm2<float>(\n    cusparseHandle_t handle, cusparseOperation_t transA,\n    cusparseOperation_t transB, int m, int n, int k, int nnz,\n    const float* alpha, const cusparseMatDescr_t descrA, const float* csrValA,\n    const int* csrRowPtrA, const int* csrColIndA, const float* B, int ldb,\n    const float* beta, float* C, int ldc) {\n  return cusparseScsrmm2(\n      handle, transA, transB, m, n, k, nnz, alpha, descrA, csrValA, csrRowPtrA,\n      csrColIndA, B, ldb, beta, C, ldc);\n}\n\ntemplate <>\ncusparseStatus_t Xcsrmm2<double>(\n    cusparseHandle_t handle, cusparseOperation_t transA,\n    cusparseOperation_t transB, int m, int n, int k, int nnz,\n    const double* alpha, const cusparseMatDescr_t descrA, const double* csrValA,\n    const int* csrRowPtrA, const int* csrColIndA, const double* B, int ldb,\n    const double* beta, double* C, int ldc) {\n  return cusparseDcsrmm2(\n      handle, transA, transB, m, n, k, nnz, alpha, descrA, csrValA, csrRowPtrA,\n      csrColIndA, B, ldb, beta, C, ldc);\n}\n#endif\n\n/** Cusparse implementation of SpMM on Csr format. 
*/\ntemplate <typename DType, typename IdType>\nvoid CusparseCsrmm2(\n    const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,\n    const DType* A_data, DType* C_data, int x_length,\n    bool use_deterministic_alg_only = false) {\n  // We use csrmm2 to perform following operation:\n  // C = A x B, where A is a sparse matrix in csr format, B is the dense matrix\n  // for node feature tensor. However, since cusparse only supports\n  // column-major, while our tensor is stored in row-major, the actual\n  // computation is: C = trans(A x trans(B)). Currently, we use cublasXgeam to\n  // implement transposition and allocate intermediate workspace memory for\n  // this.\n  const int m = csr.num_rows;\n  const int n = x_length;\n  const int k = csr.num_cols;\n  const int nnz = csr.indices->shape[0];\n  const DType alpha = 1.0;\n  const DType beta = 0.0;\n  // device\n  auto device = runtime::DeviceAPI::Get(ctx);\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  // allocate cusparse handle if needed\n  if (!thr_entry->cusparse_handle) {\n    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));\n  }\n  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));\n  // all one data array\n  DType* valptr = nullptr;\n  if (!A_data) {\n    valptr =\n        static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));\n    _Fill(valptr, nnz, static_cast<DType>(1.));\n  }\n#if CUDART_VERSION >= 11000\n  cusparseSpMatDescr_t matA;\n  cusparseDnMatDescr_t matB, matC;\n  constexpr auto dtype = cuda_dtype<DType>::value;\n  constexpr auto idtype = cusparse_idtype<IdType>::value;\n  CUSPARSE_CALL(cusparseCreateCsr(\n      &matA, m, k, nnz, static_cast<IdType*>(csr.indptr->data),\n      static_cast<IdType*>(csr.indices->data),\n      const_cast<DType*>(valptr ? 
valptr : A_data), idtype, idtype,\n      CUSPARSE_INDEX_BASE_ZERO, dtype));\n  CUSPARSE_CALL(cusparseCreateDnMat(\n      &matB, k, n, n, const_cast<DType*>(B_data), dtype, CUSPARSE_ORDER_ROW));\n  CUSPARSE_CALL(\n      cusparseCreateDnMat(&matC, m, n, n, C_data, dtype, CUSPARSE_ORDER_ROW));\n\n  auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;\n  auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;\n  size_t workspace_size;\n  cusparseSpMMAlg_t spmm_alg = use_deterministic_alg_only\n                                   ? CUSPARSE_SPMM_CSR_ALG3\n                                   : CUSPARSE_SPMM_CSR_ALG2;\n  CUSPARSE_CALL(cusparseSpMM_bufferSize(\n      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n      matC, dtype, spmm_alg, &workspace_size));\n  void* workspace = device->AllocWorkspace(ctx, workspace_size);\n  CUSPARSE_CALL(cusparseSpMM(\n      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n      matC, dtype, spmm_alg, workspace));\n  device->FreeWorkspace(ctx, workspace);\n\n  CUSPARSE_CALL(cusparseDestroySpMat(matA));\n  CUSPARSE_CALL(cusparseDestroyDnMat(matB));\n  CUSPARSE_CALL(cusparseDestroyDnMat(matC));\n#else\n  // allocate matrix for temporary transposed output\n  DType* trans_out =\n      static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));\n\n  cusparseMatDescr_t descr;\n  CUSPARSE_CALL(cusparseCreateMatDescr(&descr));\n  CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));\n  CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));\n  CUSPARSE_CALL(Xcsrmm2<DType>(\n      thr_entry->cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,\n      CUSPARSE_OPERATION_TRANSPOSE, m, n, k, nnz, &alpha, descr,\n      (valptr) ? 
valptr : A_data, static_cast<int32_t*>(csr.indptr->data),\n      static_cast<int32_t*>(csr.indices->data), B_data, n, &beta, trans_out,\n      m));\n  CUSPARSE_CALL(cusparseDestroyMatDescr(descr));\n  // transpose the output matrix\n  _Transpose(trans_out, C_data, n, m);\n  device->FreeWorkspace(ctx, trans_out);\n#endif\n  if (valptr) device->FreeWorkspace(ctx, valptr);\n}\n\n/** Cusparse implementation of SpMM on Csr format. */\ntemplate <typename DType, typename IdType>\nvoid CusparseCsrmm2Hetero(\n    const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,\n    const DType* A_data, DType* C_data, int64_t x_length, cudaStream_t strm_id,\n    bool use_deterministic_alg_only = false) {\n  // We use csrmm2 to perform following operation:\n  // C = A x B, where A is a sparse matrix in csr format, B is the dense matrix\n  // for node feature tensor. However, since cusparse only supports\n  // column-major, while our tensor is stored in row-major, the actual\n  // computation is: C = trans(A x trans(B)). 
Currently, we use cublasXgeam to\n  // implement transposition and allocate intermediate workspace memory for\n  // this.\n  int int_maxlimit = std::numeric_limits<int>::max();\n  CHECK_GE(int_maxlimit, (csr.num_rows));\n  CHECK_GE(int_maxlimit, csr.num_cols);\n  CHECK_GE(int_maxlimit, csr.indices->shape[0]);\n  const int m = csr.num_rows;\n  const int n = x_length;\n  const int k = csr.num_cols;\n  const int nnz = csr.indices->shape[0];\n  const DType alpha = 1.0;\n  const DType beta = 1.0;\n  // device\n  auto device = runtime::DeviceAPI::Get(ctx);\n  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();\n  // allocate cusparse handle if needed\n  if (!thr_entry->cusparse_handle) {\n    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));\n  }\n  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, strm_id));\n  // all one data array\n  DType* valptr = nullptr;\n  if (!A_data) {\n    valptr =\n        static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));\n    _Fill(valptr, nnz, static_cast<DType>(1.));\n  }\n#if CUDART_VERSION >= 11000\n  cusparseSpMatDescr_t matA;\n  cusparseDnMatDescr_t matB, matC;\n  constexpr auto dtype = cuda_dtype<DType>::value;\n  constexpr auto idtype = cusparse_idtype<IdType>::value;\n  CUSPARSE_CALL(cusparseCreateCsr(\n      &matA, m, k, nnz, static_cast<IdType*>(csr.indptr->data),\n      static_cast<IdType*>(csr.indices->data),\n      const_cast<DType*>(valptr ? valptr : A_data), idtype, idtype,\n      CUSPARSE_INDEX_BASE_ZERO, dtype));\n  CUSPARSE_CALL(cusparseCreateDnMat(\n      &matB, k, n, n, const_cast<DType*>(B_data), dtype, CUSPARSE_ORDER_ROW));\n  CUSPARSE_CALL(\n      cusparseCreateDnMat(&matC, m, n, n, C_data, dtype, CUSPARSE_ORDER_ROW));\n\n  auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;\n  auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;\n  size_t workspace_size;\n  cusparseSpMMAlg_t spmm_alg = use_deterministic_alg_only\n                                   ? 
CUSPARSE_SPMM_CSR_ALG3\n                                   : CUSPARSE_SPMM_CSR_ALG2;\n  CUSPARSE_CALL(cusparseSpMM_bufferSize(\n      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n      matC, dtype, spmm_alg, &workspace_size));\n  void* workspace = device->AllocWorkspace(ctx, workspace_size);\n  CUSPARSE_CALL(cusparseSpMM(\n      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,\n      matC, dtype, spmm_alg, workspace));\n  device->FreeWorkspace(ctx, workspace);\n\n  CUSPARSE_CALL(cusparseDestroySpMat(matA));\n  CUSPARSE_CALL(cusparseDestroyDnMat(matB));\n  CUSPARSE_CALL(cusparseDestroyDnMat(matC));\n#else\n  cusparseMatDescr_t descr;\n  CUSPARSE_CALL(cusparseCreateMatDescr(&descr));\n  CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));\n  CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));\n  CHECK_EQ(sizeof(IdType), sizeof(int32_t));\n  CUSPARSE_CALL(Xcsrmm2<DType>(\n      thr_entry->cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,\n      CUSPARSE_OPERATION_TRANSPOSE, m, n, k, nnz, &alpha, descr,\n      (valptr) ? valptr : A_data, static_cast<int32_t*>(csr.indptr->data),\n      static_cast<int32_t*>(csr.indices->data), B_data, n, &beta, C_data, m));\n  CUSPARSE_CALL(cusparseDestroyMatDescr(descr));\n#endif\n  if (valptr) device->FreeWorkspace(ctx, valptr);\n}\n\n}  // namespace\n\n#define SWITCH_OP(op, Op, ...)                                  
\\\n  do {                                                          \\\n    if ((op) == \"add\") {                                        \\\n      typedef cuda::binary::Add<DType> Op;                      \\\n      { __VA_ARGS__ }                                           \\\n    } else if ((op) == \"sub\") {                                 \\\n      typedef cuda::binary::Sub<DType> Op;                      \\\n      { __VA_ARGS__ }                                           \\\n    } else if ((op) == \"mul\") {                                 \\\n      typedef cuda::binary::Mul<DType> Op;                      \\\n      { __VA_ARGS__ }                                           \\\n    } else if ((op) == \"div\") {                                 \\\n      typedef cuda::binary::Div<DType> Op;                      \\\n      { __VA_ARGS__ }                                           \\\n    } else if ((op) == \"copy_lhs\") {                            \\\n      typedef cuda::binary::CopyLhs<DType> Op;                  \\\n      { __VA_ARGS__ }                                           \\\n    } else if ((op) == \"copy_rhs\") {                            \\\n      typedef cuda::binary::CopyRhs<DType> Op;                  \\\n      { __VA_ARGS__ }                                           \\\n    } else {                                                    \\\n      LOG(FATAL) << \"Unsupported SpMM binary operator: \" << op; \\\n    }                                                           \\\n  } while (0)\n\nnamespace cuda {\n\n/**\n * @brief CUDA kernel of g-SpMM on Coo format.\n * @note it uses edge parallel strategy, different threadblocks (on y-axis)\n *       is responsible for the computation on different edges. Threadblocks\n *       on the x-axis are responsible for the computation on different\n * positions in feature dimension. 
To avoid possible data hazards, it uses\n * atomic operators for reduction.\n */\ntemplate <\n    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,\n    bool UseBcast = false, bool UseIdx = false>\n__global__ void SpMMCooKernel(\n    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,\n    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,\n    const Idx* __restrict__ row, const Idx* __restrict__ col,\n    const Idx* __restrict__ edge_map, int64_t N, int64_t M, int64_t E,\n    const int64_t* __restrict__ ubcast_off,\n    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,\n    int64_t efeat_len, int64_t out_len) {\n  // SPMM with COO.\n  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;\n  const Idx stride_y = blockDim.y * gridDim.y;\n  while (ty < E) {\n    const Idx src = _ldg(row + ty);\n    const Idx dst = _ldg(col + ty);\n    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;\n    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;\n    const int64_t stride_x = blockDim.x * gridDim.x;\n    const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len) : nullptr;\n    const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;\n    DType* outoff = out + dst * out_len;\n    while (tx < out_len) {\n      const int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;\n      const int64_t rhs_add = UseBcast ? ebcast_off[tx] : tx;\n      DType val = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);\n      Idx* arguoff = nullptr;  // arguoff is not used in SpMMCoo.\n      Idx* argeoff = nullptr;  // argeoff is not used in SpMMCoo.\n      ReduceOp::Call(outoff + tx, arguoff, argeoff, val, src, eid);\n      tx += stride_x;\n    }\n    ty += stride_y;\n  }\n}\n\n/**\n * @brief CUDA kernel to compute argu and arge in g-SpMM on Coo format.\n * @note it uses edge parallel strategy, different threadblocks (on y-axis)\n *       is responsible for the computation on different edges. 
Threadblocks\n *       on the x-axis are responsible for the computation on different\n * positions in feature dimension.\n */\ntemplate <\n    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,\n    bool UseBcast = false, bool UseIdx = false>\n__global__ void ArgSpMMCooKernel(\n    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,\n    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,\n    const Idx* __restrict__ row, const Idx* __restrict__ col,\n    const Idx* __restrict__ edge_map, int64_t N, int64_t M, int64_t E,\n    const int64_t* __restrict__ ubcast_off,\n    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,\n    int64_t efeat_len, int64_t out_len) {\n  // SPMM with COO arg max/min.\n  Idx ty = blockIdx.y * blockDim.y + threadIdx.y;\n  const Idx stride_y = blockDim.y * gridDim.y;\n  while (ty < E) {\n    const Idx src = _ldg(row + ty);\n    const Idx dst = _ldg(col + ty);\n    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;\n    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;\n    const int64_t stride_x = blockDim.x * gridDim.x;\n    const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len) : nullptr;\n    const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;\n    const DType* outoff = out + dst * out_len;\n    Idx* arguoff = BinaryOp::use_lhs ? (arg_u + dst * out_len) : nullptr;\n    Idx* argeoff = BinaryOp::use_rhs ? (arg_e + dst * out_len) : nullptr;\n    while (tx < out_len) {\n      int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;\n      int64_t rhs_add = UseBcast ? 
ebcast_off[tx] : tx;\n      DType val = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);\n      ReduceOp::CallArg(tx, arguoff, argeoff, val, outoff[tx], src, eid);\n      tx += stride_x;\n    }\n    ty += stride_y;\n  }\n}\n\n/**\n * @brief CUDA kernel of g-SpMM on Csr format.\n * @note it uses node parallel strategy, different threadblocks (on y-axis)\n *       is responsible for the computation on different destination nodes.\n *       Threadblocks on the x-axis are responsible for the computation on\n *       different positions in feature dimension.\n */\ntemplate <\n    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,\n    bool UseBcast = false, bool UseIdx = false>\n__global__ void SpMMCsrKernel(\n    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,\n    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,\n    const Idx* __restrict__ indptr, const Idx* __restrict__ indices,\n    const Idx* __restrict__ edge_map, int64_t num_rows, int64_t num_cols,\n    const int64_t* __restrict__ ubcast_off,\n    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,\n    int64_t efeat_len, int64_t out_len) {\n  // SPMM with CSR.\n  int ty = blockIdx.x * blockDim.y + threadIdx.y;\n  const Idx stride_y = blockDim.y * gridDim.x;\n  const int stride_x = blockDim.x * gridDim.y;\n  while (ty < num_rows) {\n    int tx = blockIdx.y * blockDim.x + threadIdx.x;\n    while (tx < out_len) {\n      typename accum_dtype<DType>::type local_accum = ReduceOp::zero();\n      Idx local_argu = 0, local_arge = 0;\n      const int lhs_add = UseBcast ? ubcast_off[tx] : tx;\n      const int rhs_add = UseBcast ? ebcast_off[tx] : tx;\n      for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) {\n        const Idx eid = UseIdx ? _ldg(edge_map + i) : i;\n        const Idx cid = _ldg(indices + i);\n        const DType* uoff =\n            BinaryOp::use_lhs ? 
(ufeat + cid * ufeat_len) : nullptr;\n        const DType* eoff =\n            BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;\n        DType out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);\n        ReduceOp::Call(&local_accum, &local_argu, &local_arge, out, cid, eid);\n      }\n      // The use of += is to compute cross-type reducing on heterogeneous graph\n      // when reduce op is `sum`.\n      //     C = SpMM(SpA, B) + C\n      // Separate kernel `SpMMCmpCsrHeteroKernel` is used for max- and\n      // min-reducer. It does not affect the output on homogeneous graph as\n      // `out` is initialized to zero.\n      out[ty * out_len + tx] += static_cast<DType>(local_accum);\n      if (ReduceOp::require_arg && BinaryOp::use_lhs)\n        arg_u[ty * out_len + tx] = local_argu;\n      if (ReduceOp::require_arg && BinaryOp::use_rhs)\n        arg_e[ty * out_len + tx] = local_arge;\n      tx += stride_x;\n    }\n    ty += stride_y;\n  }\n}\n\n/**\n * @brief CUDA kernel of SpMM-Min/Max on Csr format.\n * @note it uses node parallel strategy, different threadblocks (on y-axis)\n *       is responsible for the computation on different destination nodes.\n *       Threadblocks on the x-axis are responsible for the computation on\n *       different positions in feature dimension.\n */\ntemplate <\n    typename Idx, typename DType, typename BinaryOp, typename ReduceOp,\n    bool UseBcast = false, bool UseIdx = false>\n__global__ void SpMMCmpCsrHeteroKernel(\n    const DType* __restrict__ ufeat, const DType* __restrict__ efeat,\n    DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,\n    Idx* __restrict__ arg_u_ntype, Idx* __restrict__ arg_e_etype,\n    const Idx* __restrict__ indptr, const Idx* __restrict__ indices,\n    const Idx* __restrict__ edge_map, int64_t num_rows, int64_t num_cols,\n    const int64_t* __restrict__ ubcast_off,\n    const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,\n    int64_t efeat_len, int64_t 
out_len, const int src_type, const int etype) {\n  // SPMM with CSR.\n  int ty = blockIdx.y * blockDim.y + threadIdx.y;\n  const Idx stride_y = blockDim.y * gridDim.y;\n  const int stride_x = blockDim.x * gridDim.x;\n  while (ty < num_rows) {\n    int tx = blockIdx.x * blockDim.x + threadIdx.x;\n    while (tx < out_len) {\n      using accum_type = typename accum_dtype<DType>::type;\n      accum_type local_accum =\n          static_cast<accum_type>(out[ty * out_len + tx]);  // ReduceOp::zero();\n      Idx local_argu = 0, local_arge = 0;\n      const int lhs_add = UseBcast ? ubcast_off[tx] : tx;\n      const int rhs_add = UseBcast ? ebcast_off[tx] : tx;\n      for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) {\n        const Idx eid = UseIdx ? _ldg(edge_map + i) : i;\n        const Idx cid = _ldg(indices + i);\n        const DType* uoff =\n            BinaryOp::use_lhs ? (ufeat + cid * ufeat_len) : nullptr;\n        const DType* eoff =\n            BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;\n        DType tmp_out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);\n        ReduceOp::Call(\n            &local_accum, &local_argu, &local_arge, tmp_out, cid, eid);\n      }\n      // Update output only when max/min values are different that original\n      // output\n      DType new_out = static_cast<DType>(local_accum);\n      if (out[ty * out_len + tx] != new_out) {\n        out[ty * out_len + tx] = new_out;\n        if (ReduceOp::require_arg && BinaryOp::use_lhs) {\n          arg_u[ty * out_len + tx] = local_argu;\n          arg_u_ntype[ty * out_len + tx] = src_type;\n        }\n        if (ReduceOp::require_arg && BinaryOp::use_rhs) {\n          arg_e[ty * out_len + tx] = local_arge;\n          arg_e_etype[ty * out_len + tx] = etype;\n        }\n      }\n      tx += stride_x;\n    }\n    ty += stride_y;\n  }\n}\n\n/**\n * @brief CUDA implementation of g-SpMM on Coo format.\n * @param bcast Broadcast information.\n * @param coo The Coo matrix.\n * @param 
ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result feature on destination nodes.\n * @param argu Arg-Min/Max on source nodes, which refers the source node indices\n *        correspond to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n * reducer.\n * @param arge Arg-Min/Max on edges. which refers the source node indices\n *        correspond to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n * reducer.\n */\ntemplate <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>\nvoid SpMMCoo(\n    const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,\n    NDArray out, NDArray argu, NDArray arge) {\n  /**\n   * TODO(Xin): Disable half precision for SpMMCoo due to the round-off error.\n   * We should use fp32 for the accumulation but it's hard to modify the\n   * current implementation.\n   */\n#if BF16_ENABLED\n  if (std::is_same<DType, __half>::value ||\n      std::is_same<DType, __nv_bfloat16>::value)\n#else\n  if (std::is_same<DType, __half>::value)\n#endif  // BF16_ENABLED\n    LOG(FATAL) << \"SpMMCoo doesn't support half precision fow now. 
\"\n               << \"Please use SpMMCsr instead by allowing the graph \"\n               << \"materialize CSR/CSC formats.\";\n  const Idx *row = coo.row.Ptr<Idx>(), *col = coo.col.Ptr<Idx>(),\n            *edge_map = coo.data.Ptr<Idx>();\n  const DType *ufeat_data = ufeat.Ptr<DType>(),\n              *efeat_data = efeat.Ptr<DType>();\n  DType* out_data = out.Ptr<DType>();\n  Idx *argu_data = argu.Ptr<Idx>(), *arge_data = arge.Ptr<Idx>();\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const int64_t N = coo.num_rows, M = coo.num_cols, E = coo.row->shape[0];\n\n  int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;\n  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;\n\n  int64_t out_size = out.NumElements();\n  const int nt = FindNumThreads(out_size);\n  const int nb = (out_size + nt - 1) / nt;\n  CUDA_KERNEL_CALL(\n      _FillKernel, nb, nt, 0, stream, out_data, out_size, ReduceOp::zero());\n\n  const int ntx = FindNumThreads(len);\n  const int nty = CUDA_MAX_NUM_THREADS / ntx;\n  const int nbx = (len + ntx - 1) / ntx;\n  const int nby = FindNumBlocks<'y'>((E + nty - 1) / nty);\n  const dim3 nblks(nbx, nby);\n  const dim3 nthrs(ntx, nty);\n  const bool use_idx = !IsNullArray(coo.data);\n\n  BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {\n    CUDA_KERNEL_CALL(\n        (SpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),\n        nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,\n        arge_data, row, col, edge_map, N, M, E, ubcast_off, ebcast_off, lhs_len,\n        rhs_len, len);\n    if (ReduceOp::require_arg) {\n      CUDA_KERNEL_CALL(\n          (ArgSpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),\n          nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,\n          arge_data, row, col, edge_map, N, M, E, ubcast_off, ebcast_off,\n          lhs_len, rhs_len, len);\n    }\n  });\n}\n\n/**\n * @brief CUDA 
implementation of g-SpMM on Csr format.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result feature on destination nodes.\n * @param argu Arg-Min/Max on source nodes, which refers the source node indices\n *        correspond to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n * reducer.\n * @param arge Arg-Min/Max on edges. which refers the source node indices\n *        correspond to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n * reducer.\n */\ntemplate <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>\nvoid SpMMCsr(\n    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,\n    NDArray out, NDArray argu, NDArray arge) {\n  const Idx* indptr = csr.indptr.Ptr<Idx>();\n  const Idx* indices = csr.indices.Ptr<Idx>();\n  const Idx* edge_map = csr.data.Ptr<Idx>();\n  const DType* ufeat_data = ufeat.Ptr<DType>();\n  const DType* efeat_data = efeat.Ptr<DType>();\n  DType* out_data = out.Ptr<DType>();\n  Idx* argu_data = argu.Ptr<Idx>();\n  Idx* arge_data = arge.Ptr<Idx>();\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;\n  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;\n  const int ntx = FindNumThreads(len);\n  const int nty = CUDA_MAX_NUM_THREADS / ntx;\n  const int nby = (len + ntx - 1) / ntx;\n  const int nbx = FindNumBlocks<'x'>((csr.num_rows + nty - 1) / nty);\n  const dim3 nblks(nbx, nby);\n  const dim3 nthrs(ntx, nty);\n  const bool use_idx = !IsNullArray(csr.data);\n\n  BCAST_IDX_CTX_SWITCH(\n      bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off,\n      {CUDA_KERNEL_CALL(\n          (SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, 
UseBcast, UseIdx>),\n          nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,\n          arge_data, indptr, indices, edge_map, csr.num_rows, csr.num_cols,\n          ubcast_off, ebcast_off, lhs_len, rhs_len, len)});\n}\n\n/**\n * @brief CUDA kernel of SpMM-Min/Max on Csr format on heterogeneous graph.\n * @param bcast Broadcast information.\n * @param csr The Csr matrix.\n * @param ufeat The feature on source nodes.\n * @param efeat The feature on edges.\n * @param out The result feature on destination nodes.\n * @param argu Arg-Min/Max on source nodes, which refers the source node indices\n *        correspond to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n * reducer.\n * @param arge Arg-Min/Max on edges. which refers the source node indices\n *        correspond to the minimum/maximum values of reduction result on\n *        destination nodes. It's useful in computing gradients of Min/Max\n * reducer.\n * @param argu_ntype Node type of the arg-Min/Max on source nodes, which refers\n * the source node types correspond to the minimum/maximum values of reduction\n * result on destination nodes. It's useful in computing gradients of Min/Max\n * reducer.\n * @param arge_etype Edge-type of the arg-Min/Max on edges. which refers the\n * source node indices correspond to the minimum/maximum values of reduction\n * result on destination nodes. 
It's useful in computing gradients of Min/Max\n * reducer.\n * @param src_type Node type of the source nodes of an etype\n * @param etype Edge type\n */\ntemplate <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>\nvoid SpMMCmpCsrHetero(\n    const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,\n    NDArray out, NDArray argu, NDArray arge, NDArray argu_ntype,\n    NDArray arge_etype, const int src_type, const int etype) {\n  const Idx* indptr = csr.indptr.Ptr<Idx>();\n  const Idx* indices = csr.indices.Ptr<Idx>();\n  const Idx* edge_map = csr.data.Ptr<Idx>();\n  const DType* ufeat_data = ufeat.Ptr<DType>();\n  const DType* efeat_data = efeat.Ptr<DType>();\n  DType* out_data = out.Ptr<DType>();\n  Idx* argu_data = argu.Ptr<Idx>();\n  Idx* arge_data = arge.Ptr<Idx>();\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;\n  int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;\n  const int ntx = FindNumThreads(len);\n  const int nty = CUDA_MAX_NUM_THREADS / ntx;\n  const int nbx = (len + ntx - 1) / ntx;\n  const int nby = FindNumBlocks<'y'>((csr.num_rows + nty - 1) / nty);\n  const dim3 nblks(nbx, nby);\n  const dim3 nthrs(ntx, nty);\n  const bool use_idx = !IsNullArray(csr.data);\n\n  BCAST_IDX_CTX_SWITCH(\n      bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off,\n      {CUDA_KERNEL_CALL(\n          (SpMMCmpCsrHeteroKernel<\n              Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),\n          nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,\n          arge_data, static_cast<Idx*>(argu_ntype->data),\n          static_cast<Idx*>(arge_etype->data), indptr, indices, edge_map,\n          csr.num_rows, csr.num_cols, ubcast_off, ebcast_off, lhs_len, rhs_len,\n          len, src_type, etype)});\n}\n\n}  // namespace cuda\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CUDA_SPMM_CUH_\n"
  },
  {
    "path": "src/array/cuda/spmm_hetero.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/spmm.cu\n * @brief SPMM C APIs and definitions.\n */\n#include <dgl/array.h>\n\n#include <cstdlib>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./functor.cuh\"\n#include \"./ge_spmm.cuh\"\n#include \"./spmm.cuh\"\n\nnamespace dgl {\n\nusing namespace cuda;\n\nnamespace aten {\n\n/**\n * @brief CUDA implementation of g-SpMM on Csr format.\n * @note use cusparse if the reduce operator is `sum` and there is\n *       no broadcast, use dgl's kernel in other cases.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SpMMCsrHetero(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr,\n    const std::vector<NDArray>& vec_ufeat,\n    const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id\n    const std::vector<dgl_type_t>& out_ntids) {  // output node type id\n  bool is_scalar_efeat =\n      vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];\n  bool use_efeat = op != \"copy_lhs\";\n  auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);\n  std::vector<DType*> trans_out((*vec_out).size(), NULL);\n  bool use_deterministic_alg_only = false;\n  if (NULL != std::getenv(\"USE_DETERMINISTIC_ALG\"))\n    use_deterministic_alg_only = true;\n\n  bool use_legacy_cusparsemm =\n      (CUDART_VERSION < 11000) && (reduce == \"sum\") &&\n      // legacy cuSPARSE does not care about NNZ, hence the argument \"false\".\n      ((op == \"copy_lhs\" && cusparse_available<DType, IdType>(false)) ||\n       (op == \"mul\" && is_scalar_efeat &&\n        cusparse_available<DType, IdType>(false)));\n  // Create temporary output buffer to store non-transposed output\n  if (use_legacy_cusparsemm) {\n    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {\n      
const int m = (*vec_out)[ntype]->shape[0];\n      const int n = (*vec_out)[ntype]->shape[1];\n      if (m == 0) continue;\n      DType* out = static_cast<DType*>(device->AllocWorkspace(\n          vec_csr[0].indptr->ctx, m * n * sizeof(DType)));\n      CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType)));\n      trans_out[ntype] = out;\n    }\n  }\n  // Check shape of ufeat for all relation type and compute feature size\n  int64_t x_length = 1;\n  for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {\n    NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];\n    NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];\n    CHECK_EQ(ufeat->ndim, next_ufeat->ndim)\n        << \"Input features have different shapes\";\n    for (int i = 1; i < ufeat->ndim; ++i) {\n      if (ufeat->shape[i] != next_ufeat->shape[i]) {\n        if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)\n          LOG(FATAL) << \"Homogenized message passing on heterogeneous graphs \"\n                        \"does not support \"\n                     << \"automatic broadcasting.  
Please manually broadcast it \"\n                        \"before calling \"\n                     << \"message passing functions.\";\n        else\n          LOG(FATAL) << \"Input features have different shapes.\";\n        return;\n      }\n\n      if (etype == 0) x_length *= ufeat->shape[i];\n    }\n  }\n  // TODO(Israt): Can python do the following initializations while creating the\n  // tensors?\n  if (reduce == \"max\" || reduce == \"min\") {\n    const int64_t dim = bcast.out_len;\n    std::vector<bool> updated((*vec_out).size(), false);\n    for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {\n      DType* out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();\n      if (reduce == \"max\")\n        _Fill(\n            out_off, vec_csr[etype].num_rows * dim,\n            cuda::reduce::Max<IdType, DType>::zero());\n      else  // min\n        _Fill(\n            out_off, vec_csr[etype].num_rows * dim,\n            cuda::reduce::Min<IdType, DType>::zero());\n      const dgl_type_t dst_id = out_ntids[etype];\n      if (!updated[dst_id]) {\n        updated[dst_id] = true;\n        if (op == \"copy_lhs\") {\n          IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();\n          _Fill(\n              argu_ntype, vec_csr[etype].num_rows * dim,\n              static_cast<IdType>(-1));\n        }\n        if (op == \"copy_rhs\") {\n          IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();\n          _Fill(\n              arge_etype, vec_csr[etype].num_rows * dim,\n              static_cast<IdType>(-1));\n        }\n      }\n    }\n  }\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {\n    const dgl_type_t src_id = ufeat_ntids[etype];\n    const dgl_type_t dst_id = out_ntids[etype];\n    CSRMatrix csr = vec_csr[etype];\n    if (reduce == \"sum\") {\n      bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);\n      /* Call  SpMM for each 
relation type */\n      if (op == \"copy_lhs\" &&\n          cusparse_available<DType, IdType>(more_nnz)) {  // cusparse\n        /* If CUDA is less than 11.0, put the output in trans_out for later\n         * transposition */\n        DType* out = (CUDART_VERSION < 11000)\n                         ? trans_out[dst_id]\n                         : static_cast<DType*>((*vec_out)[dst_id]->data);\n        CusparseCsrmm2Hetero<DType, IdType>(\n            csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),\n            nullptr, out, x_length, stream, use_deterministic_alg_only);\n      } else if (\n          op == \"mul\" && is_scalar_efeat &&\n          cusparse_available<DType, IdType>(more_nnz)) {  // cusparse\n        NDArray efeat = vec_efeat[etype];\n        if (!IsNullArray(csr.data)) efeat = IndexSelect(efeat, csr.data);\n        CusparseCsrmm2Hetero<DType, IdType>(\n            csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),\n            static_cast<DType*>(efeat->data),\n            // TODO(Israt): Change (*vec_out) to trans_out to support CUDA\n            // version < 11\n            static_cast<DType*>((*vec_out)[dst_id]->data), x_length, stream,\n            use_deterministic_alg_only);\n      } else {  // general kernel\n        NDArray ufeat =\n            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];\n        NDArray efeat =\n            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];\n        SWITCH_OP(op, Op, {\n          cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType>>(\n              bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(),\n              NullArray());\n        });\n      }\n    } else if (reduce == \"max\") {\n      SWITCH_OP(op, Op, {\n        NDArray ufeat =\n            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];\n        NDArray efeat =\n            (vec_efeat.size() == 0) ? 
NullArray() : vec_efeat[etype];\n        cuda::SpMMCmpCsrHetero<\n            IdType, DType, Op, cuda::reduce::Max<IdType, DType>>(\n            bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],\n            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],\n            src_id, etype);\n      });\n    } else if (reduce == \"min\") {\n      SWITCH_OP(op, Op, {\n        NDArray ufeat =\n            (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];\n        NDArray efeat =\n            (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];\n        cuda::SpMMCmpCsrHetero<\n            IdType, DType, Op, cuda::reduce::Min<IdType, DType>>(\n            bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],\n            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],\n            src_id, etype);\n      });\n    } else {\n      LOG(FATAL) << \"Not implemented\";\n    }\n  }\n\n  if (use_legacy_cusparsemm) {\n    // transpose output\n    for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {\n      const int m = (*vec_out)[ntype]->shape[0];\n      const int n = (*vec_out)[ntype]->shape[1];\n      if (m == 0) continue;\n      DType* C_data = static_cast<DType*>((*vec_out)[ntype]->data);\n      _Transpose(trans_out[ntype], C_data, n, m);\n      device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);\n    }\n  }\n}\n\ntemplate void SpMMCsrHetero<kDGLCUDA, int32_t, __half>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_ntids,\n    const std::vector<dgl_type_t>& out_ntids);\ntemplate void SpMMCsrHetero<kDGLCUDA, int64_t, __half>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_ntids,\n    const std::vector<dgl_type_t>& out_ntids);\n#if BF16_ENABLED\ntemplate void SpMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_ntids,\n    const std::vector<dgl_type_t>& out_ntids);\ntemplate void SpMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_ntids,\n    const std::vector<dgl_type_t>& out_ntids);\n#endif  // BF16_ENABLED\ntemplate void SpMMCsrHetero<kDGLCUDA, int32_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_ntids,\n    const std::vector<dgl_type_t>& out_ntids);\ntemplate void SpMMCsrHetero<kDGLCUDA, int64_t, float>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_ntids,\n    const 
std::vector<dgl_type_t>& out_ntids);\ntemplate void SpMMCsrHetero<kDGLCUDA, int32_t, double>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_ntids,\n    const std::vector<dgl_type_t>& out_ntids);\ntemplate void SpMMCsrHetero<kDGLCUDA, int64_t, double>(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_ntids,\n    const std::vector<dgl_type_t>& out_ntids);\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/utils.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/utils.cu\n * @brief Utilities for CUDA kernels.\n */\n\n#include <cub/cub.cuh>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"./utils.h\"\n\nnamespace dgl {\nnamespace cuda {\n\nbool AllTrue(int8_t* flags, int64_t length, const DGLContext& ctx) {\n  auto device = runtime::DeviceAPI::Get(ctx);\n  int8_t* rst = static_cast<int8_t*>(device->AllocWorkspace(ctx, 1));\n  // Call CUB's reduction\n  size_t workspace_size = 0;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  CUDA_CALL(cub::DeviceReduce::Min(\n      nullptr, workspace_size, flags, rst, length, stream));\n  void* workspace = device->AllocWorkspace(ctx, workspace_size);\n  CUDA_CALL(cub::DeviceReduce::Min(\n      workspace, workspace_size, flags, rst, length, stream));\n  int8_t cpu_rst = GetCUDAScalar(device, ctx, rst);\n  device->FreeWorkspace(ctx, workspace);\n  device->FreeWorkspace(ctx, rst);\n  return cpu_rst == 1;\n}\n\n}  // namespace cuda\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/utils.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cuda/utils.h\n * @brief Utilities for CUDA kernels.\n */\n#ifndef DGL_ARRAY_CUDA_UTILS_H_\n#define DGL_ARRAY_CUDA_UTILS_H_\n\n#include <dgl/runtime/c_runtime_api.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/ndarray.h>\n#include <dmlc/logging.h>\n\n#include <cub/cub.cuh>\n#include <type_traits>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n\nnamespace dgl {\nnamespace cuda {\n\n#define CUDA_MAX_NUM_BLOCKS_X 0x7FFFFFFF\n#define CUDA_MAX_NUM_BLOCKS_Y 0xFFFF\n#define CUDA_MAX_NUM_BLOCKS_Z 0xFFFF\n// The max number of threads per block\n#define CUDA_MAX_NUM_THREADS 256\n\n/** @brief Calculate the number of threads needed given the dimension length.\n *\n * It finds the biggest number that is smaller than min(dim, max_nthrs)\n * and is also power of two.\n */\ninline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) {\n  CHECK_GE(dim, 0);\n  if (dim == 0) return 1;\n  int ret = max_nthrs;\n  while (ret > dim) {\n    ret = ret >> 1;\n  }\n  return ret;\n}\n\ntemplate <typename T>\nint _NumberOfBits(const T& range) {\n  if (range <= 1) {\n    // ranges of 0 or 1 require no bits to store\n    return 0;\n  }\n\n  int bits = 1;\n  const auto urange = static_cast<std::make_unsigned_t<T>>(range);\n  while (bits < static_cast<int>(sizeof(T) * 8) && (1ull << bits) < urange) {\n    ++bits;\n  }\n\n  if (bits < static_cast<int>(sizeof(T) * 8)) {\n    CHECK_EQ((range - 1) >> bits, 0);\n  }\n  CHECK_NE((range - 1) >> (bits - 1), 0);\n\n  return bits;\n}\n\n/**\n * @brief Find number of blocks is smaller than nblks and max_nblks\n * on the given axis ('x', 'y' or 'z').\n */\ntemplate <char axis>\ninline int FindNumBlocks(int nblks, int max_nblks = -1) {\n  int default_max_nblks = -1;\n  switch (axis) {\n    case 'x':\n      default_max_nblks = CUDA_MAX_NUM_BLOCKS_X;\n      break;\n    case 'y':\n      default_max_nblks = CUDA_MAX_NUM_BLOCKS_Y;\n      break;\n    case 
'z':\n      default_max_nblks = CUDA_MAX_NUM_BLOCKS_Z;\n      break;\n    default:\n      LOG(FATAL) << \"Axis \" << axis << \" not recognized\";\n      break;\n  }\n  if (max_nblks == -1) max_nblks = default_max_nblks;\n  CHECK_NE(nblks, 0);\n  if (nblks < max_nblks) return nblks;\n  return max_nblks;\n}\n\ntemplate <typename T>\n__device__ __forceinline__ T _ldg(T* addr) {\n#if __CUDA_ARCH__ >= 350\n  return __ldg(addr);\n#else\n  return *addr;\n#endif\n}\n\n/**\n * @brief Return true if the given bool flag array is all true.\n * The input bool array is in int8_t type so it is aligned with byte address.\n *\n * @param flags The bool array.\n * @param length The length.\n * @param ctx Device context.\n * @return True if all the flags are true.\n */\nbool AllTrue(int8_t* flags, int64_t length, const DGLContext& ctx);\n\n/**\n * @brief CUDA Kernel of filling the vector started from ptr of size length\n *        with val.\n * @note internal use only.\n */\ntemplate <typename DType>\n__global__ void _FillKernel(DType* ptr, size_t length, DType val) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    ptr[tx] = val;\n    tx += stride_x;\n  }\n}\n\n/** @brief Fill the vector started from ptr of size length with val */\ntemplate <typename DType>\nvoid _Fill(DType* ptr, size_t length, DType val) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  int nt = FindNumThreads(length);\n  int nb =\n      (length + nt - 1) / nt;  // on x-axis, no need to worry about upperbound.\n  CUDA_KERNEL_CALL(cuda::_FillKernel, nb, nt, 0, stream, ptr, length, val);\n}\n\n/**\n * @brief Search adjacency list linearly for each (row, col) pair and\n * write the data under the matched position in the indices array to the output.\n *\n * If there is no match, the value in \\c filler is written.\n * If there are multiple matches, only the first match is written.\n * If the given data array is null, write the matched 
position to the output.\n */\ntemplate <typename IdType, typename DType>\n__global__ void _LinearSearchKernel(\n    const IdType* indptr, const IdType* indices, const IdType* data,\n    const IdType* row, const IdType* col, int64_t row_stride,\n    int64_t col_stride, int64_t length, const DType* weights, DType filler,\n    DType* out) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    int rpos = tx * row_stride, cpos = tx * col_stride;\n    IdType v = -1;\n    const IdType r = row[rpos], c = col[cpos];\n    for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {\n      if (indices[i] == c) {\n        v = data ? data[i] : i;\n        break;\n      }\n    }\n    if (v == -1) {\n      out[tx] = filler;\n    } else {\n      // The casts here are to be able to handle DType being __half.\n      // GCC treats int64_t as a distinct type from long long, so\n      // without the explcit cast to long long, it errors out saying\n      // that the implicit cast results in an ambiguous choice of\n      // constructor for __half.\n      // The using statement is to avoid a linter error about using\n      // long or long long.\n      using LongLong = long long;  // NOLINT\n      out[tx] = weights ? 
weights[v] : DType(LongLong(v));\n    }\n    tx += stride_x;\n  }\n}\n\n#if BF16_ENABLED\n/**\n * @brief Specialization for bf16 because conversion from long long to bfloat16\n * doesn't exist before SM80.\n */\ntemplate <typename IdType>\n__global__ void _LinearSearchKernel(\n    const IdType* indptr, const IdType* indices, const IdType* data,\n    const IdType* row, const IdType* col, int64_t row_stride,\n    int64_t col_stride, int64_t length, const __nv_bfloat16* weights,\n    __nv_bfloat16 filler, __nv_bfloat16* out) {\n  int tx = blockIdx.x * blockDim.x + threadIdx.x;\n  const int stride_x = gridDim.x * blockDim.x;\n  while (tx < length) {\n    int rpos = tx * row_stride, cpos = tx * col_stride;\n    IdType v = -1;\n    const IdType r = row[rpos], c = col[cpos];\n    for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {\n      if (indices[i] == c) {\n        v = data ? data[i] : i;\n        break;\n      }\n    }\n    if (v == -1) {\n      out[tx] = filler;\n    } else {\n      // If the result is saved in bf16, it should be fine to convert it to\n      // float first\n      out[tx] = weights ? weights[v] : __nv_bfloat16(static_cast<float>(v));\n    }\n    tx += stride_x;\n  }\n}\n#endif  // BF16_ENABLED\n\ntemplate <typename DType>\ninline DType GetCUDAScalar(\n    runtime::DeviceAPI* device_api, DGLContext ctx, const DType* cuda_ptr) {\n  DType result;\n  device_api->CopyDataFromTo(\n      cuda_ptr, 0, &result, 0, sizeof(result), ctx, DGLContext{kDGLCPU, 0},\n      DGLDataTypeTraits<DType>::dtype);\n  return result;\n}\n\n/**\n * @brief Given a sorted array and a value this function returns the index\n * of the first element which compares greater than value.\n *\n * This function assumes 0-based index\n * @param A: ascending sorted array\n * @param n: size of the A\n * @param x: value to search in A\n * @return index, i, of the first element st. A[i]>x. 
If x>=A[n-1] returns n.\n * if x<A[0] then it returns 0.\n */\ntemplate <typename IdType>\n__device__ IdType _UpperBound(const IdType* A, int64_t n, IdType x) {\n  IdType l = 0, r = n, m = 0;\n  while (l < r) {\n    m = l + (r - l) / 2;\n    if (x >= A[m]) {\n      l = m + 1;\n    } else {\n      r = m;\n    }\n  }\n  return l;\n}\n\n/**\n * @brief Given a sorted array and a value this function returns the index\n * of the element that is equal to x. If no such element exists, returns n.\n *\n * This function assumes 0-based index\n * @param A: ascending sorted array\n * @param n: size of the A\n * @param x: value to search in A\n * @return index, i, st. A[i]==x. If no such index exists, returns 'n'.\n */\ntemplate <typename IdType>\n__device__ IdType _BinarySearch(const IdType* A, int64_t n, IdType x) {\n  IdType l = 0, r = n - 1, m = 0;\n  while (l <= r) {\n    m = l + (r - l) / 2;\n    if (A[m] == x) {\n      return m;\n    }\n    if (A[m] < x) {\n      l = m + 1;\n    } else {\n      r = m - 1;\n    }\n  }\n  return n;  // not found\n}\n\ntemplate <typename DType, typename BoolType>\nvoid MaskSelect(\n    runtime::DeviceAPI* device, const DGLContext& ctx, const DType* input,\n    const BoolType* mask, DType* output, int64_t n, int64_t* rst,\n    cudaStream_t stream) {\n  size_t workspace_size = 0;\n  CUDA_CALL(cub::DeviceSelect::Flagged(\n      nullptr, workspace_size, input, mask, output, rst, n, stream));\n  void* workspace = device->AllocWorkspace(ctx, workspace_size);\n  CUDA_CALL(cub::DeviceSelect::Flagged(\n      workspace, workspace_size, input, mask, output, rst, n, stream));\n  device->FreeWorkspace(ctx, workspace);\n}\n\ninline void* GetDevicePointer(runtime::NDArray array) {\n  void* ptr = array->data;\n  if (array.IsPinned()) {\n    CUDA_CALL(cudaHostGetDevicePointer(&ptr, ptr, 0));\n  }\n  return ptr;\n}\n\n}  // namespace cuda\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CUDA_UTILS_H_\n"
  },
  {
    "path": "src/array/cuda/uvm/array_index_select_uvm.cu",
    "content": "/**\n *  Copyright (c) 2019-2022 by Contributors\n * @file array/cuda/uvm/array_index_select_uvm.cu\n * @brief Array index select GPU implementation\n */\n#include <dgl/array.h>\n\n#include \"../../../runtime/cuda/cuda_common.h\"\n#include \"../array_index_select.cuh\"\n#include \"../utils.h\"\n#include \"./array_index_select_uvm.cuh\"\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace aten {\nnamespace impl {\n\ntemplate <typename DType, typename IdType>\nNDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const int64_t arr_len = array->shape[0];\n  const int64_t len = index->shape[0];\n  int64_t num_feat = 1;\n  std::vector<int64_t> shape{len};\n\n  CHECK(array.IsPinned());\n  const DType* array_data = static_cast<DType*>(cuda::GetDevicePointer(array));\n  CHECK_EQ(index->ctx.device_type, kDGLCUDA);\n\n  for (int d = 1; d < array->ndim; ++d) {\n    num_feat *= array->shape[d];\n    shape.emplace_back(array->shape[d]);\n  }\n\n  NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx);\n  if (len == 0 || arr_len * num_feat == 0) return ret;\n  DType* ret_data = static_cast<DType*>(ret->data);\n\n  auto res = Sort(index, cuda::_NumberOfBits(arr_len));\n  const IdType* idx_data = static_cast<IdType*>(res.first->data);\n  const int64_t* perm_data = static_cast<int64_t*>(res.second->data);\n\n  if (num_feat == 1) {\n    const int nt = cuda::FindNumThreads(len);\n    const int nb = (len + nt - 1) / nt;\n    CUDA_KERNEL_CALL(\n        IndexSelectSingleKernel, nb, nt, 0, stream, array_data, idx_data, len,\n        arr_len, ret_data, perm_data);\n  } else {\n    dim3 block(256, 1);\n    while (static_cast<int64_t>(block.x) >= 2 * num_feat) {\n      block.x /= 2;\n      block.y *= 2;\n    }\n    const dim3 grid((len + block.y - 1) / block.y);\n    if (num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE) {\n      CUDA_KERNEL_CALL(\n          IndexSelectMultiKernel, grid, block, 0, 
stream, array_data, num_feat,\n          idx_data, len, arr_len, ret_data, perm_data);\n    } else {\n      CUDA_KERNEL_CALL(\n          IndexSelectMultiKernelAligned, grid, block, 0, stream, array_data,\n          num_feat, idx_data, len, arr_len, ret_data, perm_data);\n    }\n  }\n  return ret;\n}\n\n// floating point types are treated as their equal width integer types\ntemplate NDArray IndexSelectCPUFromGPU<int8_t, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelectCPUFromGPU<int8_t, int64_t>(NDArray, IdArray);\ntemplate NDArray IndexSelectCPUFromGPU<int16_t, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelectCPUFromGPU<int16_t, int64_t>(NDArray, IdArray);\ntemplate NDArray IndexSelectCPUFromGPU<int32_t, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelectCPUFromGPU<int32_t, int64_t>(NDArray, IdArray);\ntemplate NDArray IndexSelectCPUFromGPU<int64_t, int32_t>(NDArray, IdArray);\ntemplate NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray);\n\ntemplate <typename DType, typename IdType>\nvoid IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const DType* source_data = static_cast<DType*>(source->data);\n  const IdType* idx_data = static_cast<IdType*>(index->data);\n  const int64_t arr_len = dest->shape[0];\n  const int64_t len = index->shape[0];\n  int64_t num_feat = 1;\n  std::vector<int64_t> shape{len};\n\n  CHECK(dest.IsPinned());\n  DType* dest_data = static_cast<DType*>(cuda::GetDevicePointer(dest));\n  CHECK_EQ(index->ctx.device_type, kDGLCUDA);\n  CHECK_EQ(source->ctx.device_type, kDGLCUDA);\n\n  for (int d = 1; d < source->ndim; ++d) {\n    num_feat *= source->shape[d];\n  }\n\n  if (len == 0) return;\n\n  if (num_feat == 1) {\n    const int nt = cuda::FindNumThreads(len);\n    const int nb = (len + nt - 1) / nt;\n    CUDA_KERNEL_CALL(\n        IndexScatterSingleKernel, nb, nt, 0, stream, source_data, idx_data, len,\n        arr_len, dest_data);\n 
 } else {\n    dim3 block(256, 1);\n    while (static_cast<int64_t>(block.x) >= 2 * num_feat) {\n      block.x /= 2;\n      block.y *= 2;\n    }\n    const dim3 grid((len + block.y - 1) / block.y);\n    CUDA_KERNEL_CALL(\n        IndexScatterMultiKernel, grid, block, 0, stream, source_data, num_feat,\n        idx_data, len, arr_len, dest_data);\n  }\n}\n\n// floating point types are treated as their equal width integer types\ntemplate void IndexScatterGPUToCPU<int8_t, int32_t>(NDArray, IdArray, NDArray);\ntemplate void IndexScatterGPUToCPU<int8_t, int64_t>(NDArray, IdArray, NDArray);\ntemplate void IndexScatterGPUToCPU<int16_t, int32_t>(NDArray, IdArray, NDArray);\ntemplate void IndexScatterGPUToCPU<int16_t, int64_t>(NDArray, IdArray, NDArray);\ntemplate void IndexScatterGPUToCPU<int32_t, int32_t>(NDArray, IdArray, NDArray);\ntemplate void IndexScatterGPUToCPU<int32_t, int64_t>(NDArray, IdArray, NDArray);\ntemplate void IndexScatterGPUToCPU<int64_t, int32_t>(NDArray, IdArray, NDArray);\ntemplate void IndexScatterGPUToCPU<int64_t, int64_t>(NDArray, IdArray, NDArray);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/cuda/uvm/array_index_select_uvm.cuh",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file array/cuda/uvm/array_index_select_uvm.cuh\n * @brief Array index select GPU kernel implementation\n */\n\n#ifndef DGL_ARRAY_CUDA_UVM_ARRAY_INDEX_SELECT_UVM_CUH_\n#define DGL_ARRAY_CUDA_UVM_ARRAY_INDEX_SELECT_UVM_CUH_\n\n#define CACHE_LINE_SIZE 128\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\n/**\n *  This is a cross-device access version of IndexSelectMultiKernel.\n *  Since the memory access over PCIe is more sensitive to the\n *  data access alignment (cacheline), we need a separate version here.\n */\ntemplate <typename DType, typename IdType>\n__global__ void IndexSelectMultiKernelAligned(\n    const DType* const array, const int64_t num_feat, const IdType* const index,\n    const int64_t length, const int64_t arr_len, DType* const out,\n    const int64_t* perm = nullptr) {\n  int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;\n\n  const int64_t stride = blockDim.y * gridDim.x;\n\n  while (out_row_index < length) {\n    int64_t col = threadIdx.x;\n    const int64_t in_row = index[out_row_index];\n    assert(in_row >= 0 && in_row < arr_len);\n    const int64_t idx_offset =\n        ((uint64_t)(&array[in_row * num_feat]) % CACHE_LINE_SIZE) /\n        sizeof(DType);\n    col = col - idx_offset;\n    const auto out_row = perm ? perm[out_row_index] : out_row_index;\n    while (col < num_feat) {\n      if (col >= 0)\n        out[out_row * num_feat + col] = array[in_row * num_feat + col];\n      col += blockDim.x;\n    }\n    out_row_index += stride;\n  }\n}\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_CUDA_UVM_ARRAY_INDEX_SELECT_UVM_CUH_\n"
  },
  {
    "path": "src/array/filter.cc",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file array/filter.cc\n * @brief Object for selecting items in a set, or selecting items not in a set.\n */\n\n#include \"./filter.h\"\n\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/packed_func.h>\n#include <dgl/runtime/registry.h>\n\nnamespace dgl {\nnamespace array {\n\nusing namespace dgl::runtime;\n\ntemplate <DGLDeviceType XPU, typename IdType>\nFilterRef CreateSetFilter(IdArray set);\n\nDGL_REGISTER_GLOBAL(\"utils.filter._CAPI_DGLFilterCreateFromSet\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      IdArray array = args[0];\n      auto ctx = array->ctx;\n      // TODO(nv-dlasalle): Implement CPU version.\n      if (ctx.device_type == kDGLCUDA) {\n#ifdef DGL_USE_CUDA\n        ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {\n          *rv = CreateSetFilter<kDGLCUDA, IdType>(array);\n        });\n#else\n        LOG(FATAL) << \"GPU support not compiled.\";\n#endif\n      } else {\n        LOG(FATAL) << \"CPU support not yet implemented.\";\n      }\n    });\n\nDGL_REGISTER_GLOBAL(\"utils.filter._CAPI_DGLFilterFindIncludedIndices\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      FilterRef filter = args[0];\n      IdArray array = args[1];\n      *rv = filter->find_included_indices(array);\n    });\n\nDGL_REGISTER_GLOBAL(\"utils.filter._CAPI_DGLFilterFindExcludedIndices\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      FilterRef filter = args[0];\n      IdArray array = args[1];\n      *rv = filter->find_excluded_indices(array);\n    });\n\n}  // namespace array\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/filter.h",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file array/filter.h\n * @brief Object for selecting items in a set, or selecting items not in a set.\n */\n\n#ifndef DGL_ARRAY_FILTER_H_\n#define DGL_ARRAY_FILTER_H_\n\n#include <dgl/array.h>\n#include <dgl/runtime/object.h>\n\nnamespace dgl {\nnamespace array {\n\nclass Filter : public runtime::Object {\n public:\n  static constexpr const char* _type_key = \"array.Filter\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(Filter, Object);\n\n  /**\n   * @brief From the test set of items, get the index of those which are\n   * included by this filter.\n   *\n   * @param test The set of items to check for.\n   *\n   * @return The indices of the items from `test` that are selected by\n   * this filter.\n   */\n  virtual IdArray find_included_indices(IdArray test) = 0;\n\n  /**\n   * @brief From the test set of items, get the indices of those which are\n   * excluded by this filter.\n   *\n   * @param test The set of items to check for.\n   *\n   * @return The indices of the items from `test` that are not selected by this\n   * filter.\n   */\n  virtual IdArray find_excluded_indices(IdArray test) = 0;\n};\n\nDGL_DEFINE_OBJECT_REF(FilterRef, Filter);\n\n}  // namespace array\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_FILTER_H_\n"
  },
  {
    "path": "src/array/kernel.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/kernel.cc\n * @brief New kernels\n */\n#include <dgl/base_heterograph.h>\n#include <dgl/packed_func_ext.h>\n\n#include \"../c_api_common.h\"\n#include \"./check.h\"\n#include \"kernel_decl.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace aten {\nnamespace {}  // namespace\n\n/** @brief Generalized Sparse Matrix-Matrix Multiplication. */\nvoid SpMM(\n    const std::string& op, const std::string& reduce, HeteroGraphPtr graph,\n    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {\n  // TODO(zihao): format tuning\n  SparseFormat format = graph->SelectFormat(0, CSC_CODE);\n  const auto& bcast = CalcBcastOff(op, ufeat, efeat);\n\n  ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, \"SpMM\", {\n    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, \"Feature data\", {\n        if (format == SparseFormat::kCSC) {\n          SpMMCsr<XPU, IdType, Dtype>(\n              op, reduce, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out,\n              out_aux);\n        } else if (format == SparseFormat::kCOO) {\n          SpMMCoo<XPU, IdType, Dtype>(\n              op, reduce, bcast, graph->GetCOOMatrix(0), ufeat, efeat, out,\n              out_aux);\n        } else {\n          LOG(FATAL) << \"SpMM only supports CSC and COO formats\";\n        }\n      });\n    });\n  });\n}\n\n/** @brief Generalized segmented dense Matrix-Matrix Multiplication. 
*/\nvoid SegmentMM(\n    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,\n    bool A_trans, bool B_trans) {\n  CHECK_EQ(A->ndim, 2) << \"segment_mm expects a 2D tensor for the first input.\";\n  CHECK_EQ(B->ndim, 3)\n      << \"segment_mm expects a 3D tensor for the second input.\";\n  CHECK(!A_trans);\n  if (B_trans) {\n    CHECK_EQ(A->shape[1], B->shape[2])\n        << \"segment_mm expects A.shape[1] == B.shape[2] when B_trans=True\";\n  } else {\n    CHECK_EQ(A->shape[1], B->shape[1])\n        << \"segment_mm expects A.shape[1] == B.shape[1]\";\n  }\n  CHECK_EQ(B->shape[0], seglen_A.NumElements())\n      << \"segment_mm expects len(seglen_A) == B.shape[0]\";\n  CHECK_EQ(seglen_A->ctx.device_type, kDGLCPU)\n      << \"segment_mm expects seglen_A to be on CPU.\";\n  CHECK(A->ctx == B->ctx)\n      << \"segment_mm expects A and B to be of the same device\";\n  ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, \"SegmentMM\", {\n    ATEN_ID_TYPE_SWITCH(seglen_A->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, \"Feature data\", {\n        SegmentMM<XPU, IdType, Dtype>(A, B, C, seglen_A, A_trans, B_trans);\n      });\n    });\n  });\n}\n\nvoid SegmentMMBackwardB(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {\n  CHECK_EQ(A->ndim, 2) << \"segment_mm_backward operator expects a 2D tensor \"\n                          \"for the first input.\";\n  CHECK_EQ(dC->ndim, 2) << \"segment_mm_backward operator expects a 2D tensor \"\n                           \"for the second input.\";\n  CHECK_EQ(seglen->ctx.device_type, kDGLCPU)\n      << \"segment_mm expects seglen to be on CPU.\";\n  ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, \"SegmentMMBackwardB\", {\n    ATEN_ID_TYPE_SWITCH(seglen->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, \"Feature data\", {\n        SegmentMMBackwardB<XPU, IdType, Dtype>(A, dC, dB, seglen);\n      });\n    });\n  });\n}\n\n/** @brief Generalized 
Dense Matrix-Matrix Multiplication according to relation\n * types. */\nvoid GatherMM(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b) {\n  CHECK_EQ(A->ndim, 2)\n      << \"gather_mm operator expects a 2D tensor for the first input.\";\n  CHECK_EQ(B->ndim, 3)\n      << \"gather_mm operator expects a 3D tensor for the second input.\";\n  CHECK(A->ctx == B->ctx)\n      << \"gather_mm expects all arguments to be on the same device.\";\n  if (aten::IsNullArray(idx_a)) {\n    CHECK_EQ(A->shape[0], idx_b->shape[0])\n        << \"gather_mm expects len(idx_b) == A.shape[0] when idx_a is None.\";\n    CHECK(A->ctx == idx_b->ctx)\n        << \"gather_mm expects all arguments to be on the same device.\";\n  } else if (aten::IsNullArray(idx_b)) {\n    CHECK_EQ(B->shape[0], idx_a->shape[0])\n        << \"gather_mm expects len(idx_a) == B.shape[0] when idx_b is None.\";\n    CHECK(A->ctx == idx_a->ctx)\n        << \"gather_mm expects all arguments to be on the same device.\";\n  } else {\n    CHECK_EQ(idx_a->shape[0], idx_b->shape[0])\n        << \"gather_mm expects len(idx_a) == len(idx_b) when both idx_a and \"\n           \"idx_b are given.\";\n    CHECK(A->ctx == idx_a->ctx && A->ctx == idx_b->ctx)\n        << \"gather_mm expects all arguments to be on the same device.\";\n  }\n  const auto idtype = aten::IsNullArray(idx_a) ? idx_b->dtype : idx_a->dtype;\n  ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, \"GatherMM\", {\n    ATEN_ID_TYPE_SWITCH(idtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, \"Feature data\", {\n        GatherMM<XPU, IdType, Dtype>(A, B, C, idx_a, idx_b);\n      });\n    });\n  });\n}\n\n/** @brief Generalized Dense Matrix-Matrix Multiplication according to relation\n * types. 
*/\nvoid GatherMMScatter(\n    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c) {\n  CHECK_EQ(A->ndim, 2)\n      << \"gather_mm_scatter expects a 2D tensor for the first input.\";\n  CHECK(A->ctx == B->ctx)\n      << \"gather_mm_scatter expects all arguments to be on the same device.\";\n  if (!aten::IsNullArray(idx_c))\n    CHECK(A->ctx == idx_c->ctx)\n        << \"gather_mm_scatter expects all arguments to be on the same device.\";\n  if (aten::IsNullArray(idx_a) && !aten::IsNullArray(idx_b)) {\n    CHECK_EQ(A->shape[0], idx_b->shape[0])\n        << \"gather_mm_scatter expects len(idx_b) == A.shape[0] when idx_a is \"\n           \"None.\";\n    CHECK(A->ctx == idx_b->ctx)\n        << \"gather_mm_scatter expects all arguments to be on the same device.\";\n  } else if (aten::IsNullArray(idx_b) && !aten::IsNullArray(idx_a)) {\n    CHECK_EQ(B->shape[0], idx_a->shape[0])\n        << \"gather_mm_scatter expects len(idx_a) == B.shape[0] when idx_b is \"\n           \"None.\";\n    CHECK(A->ctx == idx_a->ctx)\n        << \"gather_mm_scatter expects all arguments to be on the same device.\";\n  } else if (!aten::IsNullArray(idx_b) && !aten::IsNullArray(idx_a)) {\n    CHECK_EQ(idx_a->shape[0], idx_b->shape[0])\n        << \"gather_mm_scatter expects len(idx_a) == len(idx_b) \"\n        << \"when both idx_a and idx_b are given.\";\n    CHECK(A->ctx == idx_a->ctx && A->ctx == idx_b->ctx)\n        << \"gather_mm_scatter expects all arguments to be on the same device.\";\n  }\n  ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, \"GatherMM\", {\n    ATEN_ID_TYPE_SWITCH(idx_c->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, \"Feature data\", {\n        GatherMMScatter<XPU, IdType, Dtype>(A, B, C, idx_a, idx_b, idx_c);\n      });\n    });\n  });\n}\n\n/** @brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph\n * support. 
*/\nvoid SpMMHetero(\n    const std::string& op, const std::string& reduce, HeteroGraphPtr graph,\n    const std::vector<NDArray>& ufeat_vec,\n    const std::vector<NDArray>& efeat_vec, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux) {\n  SparseFormat format = graph->SelectFormat(0, CSC_CODE);\n\n  std::vector<CSRMatrix> vec_graph;\n  std::vector<dgl_type_t> ufeat_eid;\n  std::vector<dgl_type_t> efeat_eid;\n  std::vector<dgl_type_t> out_eid;\n  auto pair = graph->meta_graph()->FindEdge(0);  // first etype\n  NDArray ufeat_etype0 =\n      (ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[pair.first];\n  NDArray efeat_etype0 = (efeat_vec.size() == 0) ? NullArray() : efeat_vec[0];\n  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {\n    vec_graph.push_back(graph->GetCSCMatrix(etype));\n    auto pair = graph->meta_graph()->FindEdge(etype);\n    ufeat_eid.push_back(pair.first);\n    efeat_eid.push_back(etype);\n    out_eid.push_back(pair.second);\n    if (ufeat_etype0->shape[1] != ufeat_vec[pair.first]->shape[1])\n      LOG(FATAL) << \"Column width of the input node features of all etypes \"\n                    \"must be same.\";\n    if (efeat_etype0->shape[1] != efeat_vec[etype]->shape[1])\n      LOG(FATAL) << \"Column width of the input edge features of all etypes \"\n                    \"must be same.\";\n  }\n  const auto& bcast = CalcBcastOff(op, ufeat_etype0, efeat_etype0);\n\n  ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, \"SpMM\", {\n    ATEN_ID_TYPE_SWITCH(\n        graph->DataType(), IdType, {\n          ATEN_FLOAT_TYPE_SWITCH_16BITS(\n              (*out)[out_eid[0]]->dtype, Dtype, XPU, \"Feature data\", {\n                if (format == SparseFormat::kCSC) {\n                  SpMMCsrHetero<XPU, IdType, Dtype>(\n                      op, reduce, bcast, vec_graph, ufeat_vec, efeat_vec, out,\n                      out_aux, ufeat_eid, out_eid);\n                } else {\n                  // 
TODO(Israt): Add support for COO format\n                  LOG(FATAL)\n                      << \"SpMM only supports CSC format for graphs with number \"\n                      << \"of relation types > 1\";\n                }\n              });\n        });\n  });\n}\n\n/** @brief Generalized Sampled Dense-Dense Matrix Multiplication. */\nvoid SDDMM(\n    const std::string& op, HeteroGraphPtr graph, NDArray lhs, NDArray rhs,\n    NDArray out, int lhs_target, int rhs_target) {\n  // TODO(zihao): format tuning\n  SparseFormat format = graph->SelectFormat(0, COO_CODE);\n  const auto& bcast = CalcBcastOff(op, lhs, rhs);\n\n  ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, \"SDDMM\", {\n    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, \"Feature data\", {\n        if (format == SparseFormat::kCSR) {\n          SDDMMCsr<XPU, IdType, Dtype>(\n              op, bcast, graph->GetCSRMatrix(0), lhs, rhs, out, lhs_target,\n              rhs_target);\n        } else if (format == SparseFormat::kCOO) {\n          SDDMMCoo<XPU, IdType, Dtype>(\n              op, bcast, graph->GetCOOMatrix(0), lhs, rhs, out, lhs_target,\n              rhs_target);\n        } else {\n          LOG(FATAL) << \"SDDMM only supports CSR and COO formats\";\n        }\n      });\n    });\n  });\n}\n\n/**\n * @brief Find the src/dst/etype id based on the target 'u', 'v' or 'e'.\n *\n * @param graph The input graph.\n * @param target 'u', 'v' or 'e'. The target of the lhs or rhs data of an etype.\n * @param etype Relation type of the input graph.\n */\nint get_typeid_by_target(HeteroGraphPtr graph, int target, dgl_type_t etype) {\n  auto pair = graph->meta_graph()->FindEdge(etype);\n  if (target == 0) return pair.first;\n  if (target == 2) return pair.second;\n  return etype;\n}\n\n/** @brief Generalized Sampled Dense-Dense Matrix Multiplication. 
*/\nvoid SDDMMHetero(\n    const std::string& op, HeteroGraphPtr graph, std::vector<NDArray> lhs,\n    std::vector<NDArray> rhs, std::vector<NDArray> out, int lhs_target,\n    int rhs_target) {\n  SparseFormat format = graph->SelectFormat(0, COO_CODE);\n\n  std::vector<dgl_type_t> lhs_eid;\n  std::vector<dgl_type_t> rhs_eid;\n  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {\n    lhs_eid.push_back(get_typeid_by_target(graph, lhs_target, etype));\n    rhs_eid.push_back(get_typeid_by_target(graph, rhs_target, etype));\n  }\n  const auto& bcast = CalcBcastOff(op, lhs[lhs_eid[0]], rhs[rhs_eid[0]]);\n\n  ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, \"SDDMM\", {\n    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(\n          out[rhs_eid[0]]->dtype, Dtype, XPU, \"Feature data\", {\n            if (format == SparseFormat::kCSR) {\n              std::vector<CSRMatrix> vec_csr;\n              for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes();\n                   ++etype) {\n                vec_csr.push_back(graph->GetCSRMatrix(etype));\n              }\n              SDDMMCsrHetero<XPU, IdType, Dtype>(\n                  op, bcast, vec_csr, lhs, rhs, out, lhs_target, rhs_target,\n                  lhs_eid, rhs_eid);\n            } else if (format == SparseFormat::kCOO) {\n              std::vector<COOMatrix> vec_coo;\n              for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes();\n                   ++etype) {\n                vec_coo.push_back(graph->GetCOOMatrix(etype));\n              }\n              SDDMMCooHetero<XPU, IdType, Dtype>(\n                  op, bcast, vec_coo, lhs, rhs, out, lhs_target, rhs_target,\n                  lhs_eid, rhs_eid);\n            } else {\n              LOG(FATAL) << \"SDDMM only supports CSR and COO formats\";\n            }\n          });\n    });\n  });\n}\n\n/** @brief Generalized Edge_softmax op for forward */\nvoid Edge_softmax_forward(\n    const 
std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat,\n    NDArray out) {\n  // TODO(zhejiang): add gpu op for edge_softmax\n  const auto& bcast = CalcBcastOff(op, ufeat, efeat);\n\n  ATEN_XPU_SWITCH(graph->Context().device_type, XPU, \"edge_softmax\", {\n    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(\n          out->dtype, Dtype, XPU, \"edge_softmax out data\", {\n            Edge_softmax_csr_forward<XPU, IdType, Dtype>(\n                op, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out);\n          });\n    });\n  });\n}\n\n/** @brief Generalized Edge_softmax op for backward */\nvoid Edge_softmax_backward(\n    const std::string& op, HeteroGraphPtr graph, NDArray out, NDArray sds,\n    NDArray back_out, NDArray ufeat) {\n  // TODO(zhejiang): add gpu op for edge_softmax\n  const auto& bcast = CalcBcastOff(op, ufeat, sds);\n\n  ATEN_XPU_SWITCH(graph->Context().device_type, XPU, \"edge_softmax_back\", {\n    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(\n          out->dtype, Dtype, XPU, \"edge_softmax out data_back\", {\n            Edge_softmax_csr_backward<XPU, IdType, Dtype>(\n                op, bcast, graph->GetCSCMatrix(0), out, sds, back_out);\n          });\n    });\n  });\n}\n\nNDArray GetEdgeMapping(HeteroGraphRef graph) {\n  SparseFormat format = graph->SelectFormat(0, CSC_CODE);\n  if (format == SparseFormat::kCSC) {\n    return graph.sptr()->GetCSCMatrix(0).data;\n  } else {\n    return NullArray();\n  }\n}\n\n/** @brief Segment reduce dispatch function. 
*/\nvoid SegmentReduceDispatch(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg) {\n  ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, \"SegmentReduce\", {\n    ATEN_ID_TYPE_SWITCH(offsets->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, \"Feature data\", {\n        SegmentReduce<XPU, IdType, Dtype>(op, feat, offsets, out, arg);\n      });\n    });\n  });\n}\n\n/** @brief Scatter Add (on first dimension) dispatch function. */\nvoid ScatterAddDispatch(NDArray feat, NDArray idx, NDArray out) {\n  ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, \"ScatterAdd\", {\n    ATEN_ID_TYPE_SWITCH(idx->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, \"Feature data\", {\n        ScatterAdd<XPU, IdType, Dtype>(feat, idx, out);\n      });\n    });\n  });\n}\n\n/** @brief Update gradients (reduce op max/min) dispatch function on\n * heterogeneous graph. */\nvoid UpdateGradMinMaxDispatchHetero(\n    const HeteroGraphPtr& graph, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {\n  auto pair = graph->meta_graph()->FindEdge(0);  // checking the first etype\n  auto src_id = pair.first;\n  ATEN_XPU_SWITCH_CUDA(feat[src_id]->ctx.device_type, XPU, \"ScatterAdd\", {\n    ATEN_ID_TYPE_SWITCH(idx[src_id]->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(\n          feat[src_id]->dtype, Dtype, XPU, \"Feature data\", {\n            UpdateGradMinMax_hetero<XPU, IdType, Dtype>(\n                graph, op, feat, idx, idx_etype, out);\n          });\n    });\n  });\n}\n\n/** @brief Backward segment cmp dispatch function.*/\nvoid BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) {\n  ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, \"BackwardSegmentCmp\", {\n    ATEN_ID_TYPE_SWITCH(arg->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, 
XPU, \"Feature data\", {\n        BackwardSegmentCmp<XPU, IdType, Dtype>(feat, arg, out);\n      });\n    });\n  });\n}\n\nstd::pair<CSRMatrix, NDArray> CSRMM(\n    CSRMatrix A, NDArray A_weights, CSRMatrix B, NDArray B_weights) {\n  CHECK_EQ(A.num_cols, B.num_rows)\n      << \"The number of nodes of destination node type of the first graph must \"\n         \"be the \"\n         \"same as the number of nodes of source node type of the second graph.\";\n  CheckCtx(\n      A.indptr->ctx, {A_weights, B_weights},\n      {\"A's edge weights\", \"B's edge weights\"});\n  CHECK_EQ(A.indptr->ctx, B.indptr->ctx) << \"Device of two graphs must match.\";\n  CHECK_EQ(A.indptr->dtype, B.indptr->dtype)\n      << \"ID types of two graphs must match.\";\n  CHECK_EQ(A_weights->dtype, B_weights->dtype)\n      << \"Data types of two edge weights must match.\";\n\n  std::pair<CSRMatrix, NDArray> ret;\n  ATEN_XPU_SWITCH_CUDA(A.indptr->ctx.device_type, XPU, \"CSRMM\", {\n    ATEN_ID_TYPE_SWITCH(A.indptr->dtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, \"Edge weights\", {\n        ret = CSRMM<XPU, IdType, DType>(A, A_weights, B, B_weights);\n      });\n    });\n  });\n  return ret;\n}\n\nstd::pair<CSRMatrix, NDArray> CSRSum(\n    const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights) {\n  CHECK(A.size() > 0) << \"The list of graphs must not be empty.\";\n  CHECK_EQ(A.size(), A_weights.size())\n      << \"The list of edge weights must have the same length as the list of \"\n         \"graphs.\";\n  const auto ctx = A[0].indptr->ctx;\n  const auto idtype = A[0].indptr->dtype;\n  const auto dtype = A_weights[0]->dtype;\n  const auto num_rows = A[0].num_rows;\n  const auto num_cols = A[0].num_cols;\n  for (size_t i = 0; i < A.size(); ++i) {\n    CHECK_EQ(A[i].indptr->ctx, ctx)\n        << \"The devices of all graphs must be equal.\";\n    CHECK_EQ(A[i].indptr->dtype, idtype)\n        << \"The ID types of all graphs must be equal.\";\n    
CHECK_EQ(A[i].indices->shape[0], A_weights[i]->shape[0])\n        << \"Shape of edge weights does not match the number of edges.\";\n    CHECK_EQ(A_weights[i]->ctx, ctx) << \"The devices of edge weights must be \"\n                                        \"the same as that of the graphs.\";\n    CHECK_EQ(A_weights[i]->dtype, dtype)\n        << \"The data types of all edge weights must be equal.\";\n    CHECK_EQ(A[i].num_rows, num_rows)\n        << \"Graphs must have the same number of nodes.\";\n    CHECK_EQ(A[i].num_cols, num_cols)\n        << \"Graphs must have the same number of nodes.\";\n  }\n\n  std::pair<CSRMatrix, NDArray> ret;\n  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, \"CSRSum\", {\n    ATEN_ID_TYPE_SWITCH(idtype, IdType, {\n      ATEN_FLOAT_TYPE_SWITCH(dtype, DType, \"Edge weights\", {\n        ret = CSRSum<XPU, IdType, DType>(A, A_weights);\n      });\n    });\n  });\n  return ret;\n}\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelSpMM\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef graph = args[0];\n      const std::string op = args[1];\n      const std::string reduce_op = args[2];\n      NDArray U = args[3];\n      NDArray E = args[4];\n      NDArray V = args[5];\n      NDArray ArgU = args[6];\n      NDArray ArgE = args[7];\n      CheckCtx(\n          graph->Context(), {U, E, V, ArgU, ArgE},\n          {\"U_data\", \"E_data\", \"out\", \"Arg_U\", \"Arg_E\"});\n      CheckContiguous(\n          {U, E, V, ArgU, ArgE}, {\"U_data\", \"E_data\", \"out\", \"Arg_U\", \"Arg_E\"});\n      CHECK_EQ(graph->NumEdgeTypes(), 1);\n      auto pair =\n          graph->meta_graph()->FindEdge(0);  // only one etype in the graph.\n      const dgl_type_t src_vtype = pair.first;\n      const dgl_type_t dst_vtype = pair.second;\n      CheckShape(\n          {graph->NumVertices(src_vtype), graph->NumEdges(0),\n           graph->NumVertices(dst_vtype)},\n          {0, 1, 2, 2, 2}, {U, E, V, ArgU, ArgE},\n          {\"U_data\", \"E_data\", 
\"out\", \"Arg_U\", \"Arg_E\"});\n      SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE});\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelGATHERMM\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArray A = args[0];\n      NDArray B = args[1];\n      NDArray C = args[2];\n      NDArray idx_a = args[3];\n      NDArray idx_b = args[4];\n      GatherMM(A, B, C, idx_a, idx_b);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelGATHERMMSCATTER\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArray A = args[0];\n      NDArray B = args[1];\n      NDArray C = args[2];\n      NDArray idx_a = args[3];\n      NDArray idx_b = args[4];\n      NDArray idx_c = args[5];\n      GatherMMScatter(A, B, C, idx_a, idx_b, idx_c);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelSEGMENTMM\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArray A = args[0];\n      NDArray B = args[1];\n      NDArray C = args[2];\n      NDArray seglen_A = args[3];\n      bool A_trans = args[4];\n      bool B_trans = args[5];\n      SegmentMM(A, B, C, seglen_A, A_trans, B_trans);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelSEGMENTMMBackwardB\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArray A = args[0];\n      NDArray dC = args[1];\n      NDArray dB = args[2];\n      NDArray seglen = args[3];\n      SegmentMMBackwardB(A, dC, dB, seglen);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelEdge_softmax_forward\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef graph = args[0];\n      const std::string op = args[1];\n      NDArray U = args[2];\n      NDArray E = args[3];\n      NDArray V = args[4];\n      Edge_softmax_forward(op, graph.sptr(), U, E, V);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelEdge_softmax_backward\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef graph = args[0];\n      const std::string op = args[1];\n      NDArray out = 
args[2];\n      NDArray sds = args[3];\n      NDArray back_out = args[4];\n      NDArray ufeat = args[5];\n      Edge_softmax_backward(op, graph.sptr(), out, sds, back_out, ufeat);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelSpMMHetero\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef graph = args[0];\n      const std::string op = args[1];\n      const std::string reduce_op = args[2];\n      List<Value> list_U = args[3];\n      List<Value> list_E = args[4];\n      List<Value> list_V = args[5];\n      List<Value> list_ArgU = args[6];\n      List<Value> list_ArgE = args[7];\n      List<Value> list_ArgU_ntype = args[8];\n      List<Value> list_ArgE_etype = args[9];\n      std::vector<std::vector<NDArray>> Arg_vec;  // ArgU + ArgE\n      for (int i = 0; i < 4; ++i) {  // ArgU + ArgE + ArgU_ntype + ArgE_etype\n        Arg_vec.push_back(std::vector<NDArray>());\n      }\n      std::vector<NDArray> U_vec = ListValueToVector<NDArray>(list_U);\n      std::vector<NDArray> V_vec = ListValueToVector<NDArray>(list_V);\n      std::vector<NDArray> E_vec = ListValueToVector<NDArray>(list_E);\n      Arg_vec[0] = ListValueToVector<NDArray>(list_ArgU);\n      Arg_vec[1] = ListValueToVector<NDArray>(list_ArgE);\n      Arg_vec[2] = ListValueToVector<NDArray>(list_ArgU_ntype);\n      Arg_vec[3] = ListValueToVector<NDArray>(list_ArgE_etype);\n      for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {\n        auto pair = graph->meta_graph()->FindEdge(etype);\n        const dgl_id_t src_id = pair.first;\n        const dgl_id_t dst_id = pair.second;\n        NDArray U = (U_vec.size() == 0) ? NullArray() : U_vec[src_id];\n        NDArray E = (E_vec.size() == 0) ? 
NullArray() : E_vec[etype];\n        CheckCtx(\n            graph->Context(),\n            {U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},\n            {\"U_data\", \"E_data\", \"out\", \"Arg_U\", \"Arg_E\"});\n        CheckContiguous(\n            {U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},\n            {\"U_data\", \"E_data\", \"out\", \"Arg_U\", \"Arg_E\"});\n      }\n      SpMMHetero(op, reduce_op, graph.sptr(), U_vec, E_vec, &V_vec, &Arg_vec);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelSDDMM\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef graph = args[0];\n      const std::string op = args[1];\n      NDArray lhs = args[2];\n      NDArray rhs = args[3];\n      NDArray out = args[4];\n      int lhs_target = args[5];\n      int rhs_target = args[6];\n      CheckCtx(graph->Context(), {lhs, rhs, out}, {\"lhs\", \"rhs\", \"out\"});\n      CheckContiguous({lhs, rhs, out}, {\"lhs\", \"rhs\", \"out\"});\n      CHECK_EQ(graph->NumEdgeTypes(), 1);\n      auto pair =\n          graph->meta_graph()->FindEdge(0);  // only one etype in the graph.\n      const dgl_type_t src_vtype = pair.first;\n      const dgl_type_t dst_vtype = pair.second;\n\n      CheckShape(\n          {graph->NumVertices(src_vtype), graph->NumEdges(0),\n           graph->NumVertices(dst_vtype)},\n          {lhs_target, rhs_target, 1}, {lhs, rhs, out},\n          {\"U_data\", \"E_data\", \"V_data\"});\n      SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelSDDMMHetero\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef graph = args[0];\n      const std::string op = args[1];\n      List<Value> list_lhs = args[2];\n      List<Value> list_rhs = args[3];\n      List<Value> list_out = args[4];\n      int lhs_target = args[5];\n      int rhs_target = args[6];\n      std::vector<NDArray> vec_lhs;\n      std::vector<NDArray> vec_rhs;\n      
std::vector<NDArray> vec_out;\n\n      vec_lhs.reserve(list_lhs.size());\n      vec_rhs.reserve(list_rhs.size());\n      vec_out.reserve(list_out.size());\n\n      for (Value val : list_lhs) {\n        vec_lhs.push_back(val->data);\n      }\n      for (Value val : list_rhs) {\n        vec_rhs.push_back(val->data);\n      }\n      for (Value val : list_out) {\n        vec_out.push_back(val->data);\n      }\n      SDDMMHetero(\n          op, graph.sptr(), vec_lhs, vec_rhs, vec_out, lhs_target, rhs_target);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelSegmentReduce\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const std::string op = args[0];\n      NDArray feat = args[1];\n      NDArray offsets = args[2];\n      NDArray out = args[3];\n      NDArray arg = args[4];\n      CheckCtx(feat->ctx, {feat, offsets, out}, {\"feat\", \"offsets\", \"out\"});\n      CheckContiguous({feat, offsets, out}, {\"feat\", \"offsets\", \"out\"});\n      SegmentReduceDispatch(op, feat, offsets, out, arg);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelScatterAdd\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArray feat = args[0];\n      NDArray idx = args[1];\n      NDArray out = args[2];\n      CheckCtx(feat->ctx, {feat, idx, out}, {\"feat\", \"idx\", \"out\"});\n      CheckContiguous({feat, idx, out}, {\"feat\", \"idx\", \"out\"});\n      ScatterAddDispatch(feat, idx, out);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelUpdateGradMinMaxHetero\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef graph = args[0];\n      const std::string op = args[1];\n      List<Value> list_feat = args[2];\n      List<Value> list_idx = args[3];\n      List<Value> list_idx_etype = args[4];\n      List<Value> list_out = args[5];\n      std::vector<NDArray> vec_feat = ListValueToVector<NDArray>(list_feat);\n      std::vector<NDArray> vec_idx = ListValueToVector<NDArray>(list_idx);\n      std::vector<NDArray> vec_idx_etype =\n   
       ListValueToVector<NDArray>(list_idx_etype);\n      std::vector<NDArray> vec_out = ListValueToVector<NDArray>(list_out);\n      // CheckCtx(feat->ctx, {feat, idx, out}, {\"feat\", \"idx\", \"out\"});\n      // CheckContiguous({feat, idx, out}, {\"feat\", \"idx\", \"out\"});\n      UpdateGradMinMaxDispatchHetero(\n          graph.sptr(), op, vec_feat, vec_idx, vec_idx_etype, &vec_out);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelBwdSegmentCmp\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArray feat = args[0];\n      NDArray arg = args[1];\n      NDArray out = args[2];\n      CheckCtx(feat->ctx, {feat, arg, out}, {\"feat\", \"arg\", \"out\"});\n      CheckContiguous({feat, arg, out}, {\"feat\", \"arg\", \"out\"});\n      BackwardSegmentCmpDispatch(feat, arg, out);\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLKernelGetEdgeMapping\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef graph = args[0];\n      *rv = GetEdgeMapping(graph);\n    });\n\n/**\n * @brief Sparse matrix multiplication with graph interface.\n *\n * @param A_ref The left operand.\n * @param A_weights The edge weights of graph A.\n * @param B_ref The right operand.\n * @param B_weights The edge weights of graph B.\n * @param num_vtypes The number of vertex types of the graph to be returned.\n * @return A pair consisting of the new graph as well as its edge weights.\n */\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLCSRMM\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const HeteroGraphRef A_ref = args[0];\n      NDArray A_weights = args[1];\n      const HeteroGraphRef B_ref = args[2];\n      NDArray B_weights = args[3];\n      int num_vtypes = args[4];\n\n      const HeteroGraphPtr A = A_ref.sptr();\n      const HeteroGraphPtr B = B_ref.sptr();\n      CHECK_EQ(A->NumEdgeTypes(), 1)\n          << \"The first graph must have only one edge type.\";\n      CHECK_EQ(B->NumEdgeTypes(), 1)\n          << \"The second graph must have only 
one edge type.\";\n      const auto A_csr = A->GetCSRMatrix(0);\n      const auto B_csr = B->GetCSRMatrix(0);\n      auto result = CSRMM(A_csr, A_weights, B_csr, B_weights);\n\n      List<ObjectRef> ret;\n      ret.push_back(\n          HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));\n      ret.push_back(Value(MakeValue(result.second)));\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLCSRSum\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      List<HeteroGraphRef> A_refs = args[0];\n      List<Value> A_weights = args[1];\n\n      std::vector<NDArray> weights = ListValueToVector<NDArray>(A_weights);\n      std::vector<CSRMatrix> mats;\n      mats.reserve(A_refs.size());\n      int num_vtypes = 0;\n      for (auto A_ref : A_refs) {\n        const HeteroGraphPtr A = A_ref.sptr();\n        CHECK_EQ(A->NumEdgeTypes(), 1)\n            << \"Graphs must have only one edge type.\";\n        mats.push_back(A->GetCSRMatrix(0));\n        if (num_vtypes == 0) num_vtypes = A->NumVertexTypes();\n      }\n      auto result = CSRSum(mats, weights);\n\n      List<ObjectRef> ret;\n      ret.push_back(\n          HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));\n      ret.push_back(Value(MakeValue(result.second)));\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLCSRMask\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const HeteroGraphRef A_ref = args[0];\n      NDArray A_weights = args[1];\n      const HeteroGraphRef B_ref = args[2];\n\n      const HeteroGraphPtr A = A_ref.sptr();\n      const HeteroGraphPtr B = B_ref.sptr();\n      CHECK_EQ(A->NumEdgeTypes(), 1)\n          << \"Both graphs must have only one edge type.\";\n      CHECK_EQ(B->NumEdgeTypes(), 1)\n          << \"Both graphs must have only one edge type.\";\n      const CSRMatrix& A_csr = A->GetCSRMatrix(0);\n      const COOMatrix& B_coo = B->GetCOOMatrix(0);\n      CHECK_EQ(A_csr.num_rows, B_coo.num_rows)\n          << 
\"Both graphs must have the same number of nodes.\";\n      CHECK_EQ(A_csr.num_cols, B_coo.num_cols)\n          << \"Both graphs must have the same number of nodes.\";\n\n      NDArray result;\n      ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, \"Edge weights\", {\n        result =\n            aten::CSRGetData<DType>(A_csr, B_coo.row, B_coo.col, A_weights, 0.);\n      });\n      *rv = result;\n    });\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/kernel_decl.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/kernel_decl.h\n * @brief Sparse matrix format-specific operator declarations.\n */\n#ifndef DGL_ARRAY_KERNEL_DECL_H_\n#define DGL_ARRAY_KERNEL_DECL_H_\n\n#include <dgl/base_heterograph.h>\n#include <dgl/bcast.h>\n#include <dgl/runtime/ndarray.h>\n\n#include <string>\n#include <utility>\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\n\n/**\n * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SpMMCsr(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const aten::CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\n\n/**\n * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format\n * with heterograph support.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SpMMCsrHetero(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,\n    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,\n    std::vector<std::vector<NDArray>>* out_aux,\n    const std::vector<dgl_type_t>& ufeat_eid,\n    const std::vector<dgl_type_t>& out_eid);\n/**\n * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SpMMCoo(\n    const std::string& op, const std::string& reduce, const BcastOff& bcast,\n    const aten::COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,\n    std::vector<NDArray> out_aux);\n\n/**\n * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCsr(\n    const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\n/**\n * 
@brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format\n * with heterograph support.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCsrHetero(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,\n    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,\n    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& ufeat_eid,\n    const std::vector<dgl_type_t>& out_eid);\n\n/**\n * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCoo(\n    const std::string& op, const BcastOff& bcast, const aten::COOMatrix& coo,\n    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);\n\n/**\n * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format\n * with heterograph support.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SDDMMCooHetero(\n    const std::string& op, const BcastOff& bcast,\n    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,\n    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,\n    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,\n    const std::vector<dgl_type_t>& rhs_eid);\n\n/**\n * @brief Generalized Dense Matrix-Matrix Multiplication according to relation\n * types.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid GatherMM(\n    const NDArray A, const NDArray B, NDArray out, const NDArray idx_a,\n    const NDArray idx_b);\n\n/**\n * @brief Generalized Dense Matrix-Matrix Multiplication according to relation\n * types.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid GatherMMScatter(\n    const NDArray A, const NDArray B, NDArray out, const NDArray idx_a,\n    const NDArray idx_b, const NDArray idx_c);\n\n/**\n * @brief Generalized segmented dense Matrix-Matrix 
Multiplication.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SegmentMM(\n    const NDArray A, const NDArray B, NDArray out, const NDArray seglen_A,\n    bool a_trans, bool b_trans);\n\ntemplate <int XPU, typename IdType, typename DType>\nvoid SegmentMMBackwardB(\n    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);\n\n/**\n * @brief Segment reduce.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid SegmentReduce(\n    const std::string& op, NDArray feat, NDArray offsets, NDArray out,\n    NDArray arg);\n\n/**\n * @brief Scatter Add on first dimension.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid ScatterAdd(NDArray feat, NDArray idx, NDArray out);\n\n/**\n * @brief Update gradients for reduce operator max and min on first dimension.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid UpdateGradMinMax_hetero(\n    const HeteroGraphPtr& g, const std::string& op,\n    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,\n    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);\n\n/**\n * @brief Backward function of segment cmp.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out);\n\n/**\n * @brief Sparse-sparse matrix multiplication\n *\n * @param A The left operand.\n * @param A_weights The weights of matrix as a 1D tensor.\n * @param B The right operand.\n * @param B_weights The weights of matrix as a 1D tensor.\n *\n * @note GPU implementation will cast the indices to 32 bit.\n * @note The zero entries in the result are not removed.\n * @note The CSR matrix should not have duplicate entries.\n */\ntemplate <int XPU, typename IdType, typename DType>\nstd::pair<CSRMatrix, NDArray> CSRMM(\n    const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,\n    NDArray B_weights);\n\n/**\n * @brief Sparse-sparse matrix summation.\n *\n * @param A The sparse matrices with the same size.\n * 
@param A_weights The weights of each sparse matrix as a 1D tensor.\n *\n * @note GPU implementation will cast the indices to 32 bit.\n * @note The zero entries in the result are not removed.\n * @note The CSR matrix should not have duplicate entries.\n */\ntemplate <int XPU, typename IdType, typename DType>\nstd::pair<CSRMatrix, NDArray> CSRSum(\n    const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights);\n\n/**\n * @brief Edge_softmax_csr forward function on Csr format.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid Edge_softmax_csr_forward(\n    const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\n/**\n * @brief Edge_softmax_csr backward function on Csr format.\n */\ntemplate <int XPU, typename IdType, typename DType>\nvoid Edge_softmax_csr_backward(\n    const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,\n    NDArray ufeat, NDArray efeat, NDArray out);\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_KERNEL_DECL_H_\n"
  },
  {
    "path": "src/array/libra_partition.cc",
    "content": "/**\n *  Copyright (c) 2021 Intel Corporation\n *\n *  @file distgnn/partition/main_Libra.py\n *  @brief Libra - Vertex-cut based graph partitioner for distirbuted training\n *  @author Vasimuddin Md <vasimuddin.md@intel.com>,\n *          Guixiang Ma <guixiang.ma@intel.com>\n *          Sanchit Misra <sanchit.misra@intel.com>,\n *          Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>,\n *          Sasikanth Avancha <sasikanth.avancha@intel.com>\n *          Nesreen K. Ahmed <nesreen.k.ahmed@intel.com>\n */\n\n#include <dgl/base_heterograph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/random.h>\n#include <dgl/runtime/parallel_for.h>\n#include <dmlc/omp.h>\n#include <stdint.h>\n\n#include <vector>\n\n#include \"../c_api_common.h\"\n#include \"./check.h\"\n#include \"kernel_decl.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace aten {\n\ntemplate <typename IdType>\nint32_t Ver2partition(IdType in_val, int64_t *node_map, int32_t num_parts) {\n  int32_t pos = 0;\n  for (int32_t p = 0; p < num_parts; p++) {\n    if (in_val < node_map[p]) return pos;\n    pos = pos + 1;\n  }\n  LOG(FATAL) << \"Error: Unexpected output in Ver2partition!\";\n  return -1;\n}\n\n/**\n * @brief Identifies the lead loaded partition/community for a given edge\n * assignment.\n */\nint32_t LeastLoad(int64_t *community_edges, int32_t nc) {\n  std::vector<int> loc;\n  int32_t min = 1e9;\n  for (int32_t i = 0; i < nc; i++) {\n    if (community_edges[i] < min) {\n      min = community_edges[i];\n    }\n  }\n  for (int32_t i = 0; i < nc; i++) {\n    if (community_edges[i] == min) {\n      loc.push_back(i);\n    }\n  }\n\n  int32_t r = RandomEngine::ThreadLocal()->RandInt(loc.size());\n  CHECK(loc[r] < nc);\n  return loc[r];\n}\n\n/**\n * @brief Libra - vertexcut based graph partitioning.\n * It takes list of edges from input DGL graph and distributed them among nc\n * partitions During edge distribution, Libra assign a given edge to a partition\n * based 
on the end vertices, in doing so, it tries to minimized the splitting\n * of the graph vertices. In case of conflict Libra assigns an edge to the least\n * loaded partition/community.\n * @param[in] nc Number of partitions/communities\n * @param[in] node_degree per node degree\n * @param[in] edgenum_unassigned node degree\n * @param[out] community_weights weight of the created partitions\n * @param[in] u src nodes\n * @param[in] v dst nodes\n * @param[out] w weight per edge\n * @param[out] out partition assignment of the edges\n * @param[in] N_n number of nodes in the input graph\n * @param[in] N_e number of edges in the input graph\n * @param[in] prefix output/partition storage location\n */\ntemplate <typename IdType, typename IdType2>\nvoid LibraVertexCut(\n    int32_t nc, NDArray node_degree, NDArray edgenum_unassigned,\n    NDArray community_weights, NDArray u, NDArray v, NDArray w, NDArray out,\n    int64_t N_n, int64_t N_e, const std::string &prefix) {\n  int32_t *out_ptr = out.Ptr<int32_t>();\n  IdType2 *node_degree_ptr = node_degree.Ptr<IdType2>();\n  IdType2 *edgenum_unassigned_ptr = edgenum_unassigned.Ptr<IdType2>();\n  IdType *u_ptr = u.Ptr<IdType>();\n  IdType *v_ptr = v.Ptr<IdType>();\n  int64_t *w_ptr = w.Ptr<int64_t>();\n  int64_t *community_weights_ptr = community_weights.Ptr<int64_t>();\n\n  std::vector<std::vector<int32_t> > node_assignments(N_n);\n  std::vector<IdType2> replication_list;\n  // local allocations\n  int64_t *community_edges = new int64_t[nc]();\n  int64_t *cache = new int64_t[nc]();\n\n  int64_t meter = static_cast<int>(N_e / 100);\n  for (int64_t i = 0; i < N_e; i++) {\n    IdType u = u_ptr[i];   // edge end vertex 1\n    IdType v = v_ptr[i];   // edge end vertex 2\n    int64_t w = w_ptr[i];  // edge weight\n\n    CHECK(u < N_n);\n    CHECK(v < N_n);\n\n    if (i % meter == 0) {\n      fprintf(stderr, \".\");\n      fflush(0);\n    }\n\n    if (node_assignments[u].size() == 0 && node_assignments[v].size() == 0) {\n      int32_t c 
= LeastLoad(community_edges, nc);\n      out_ptr[i] = c;\n      CHECK_LT(c, nc);\n\n      community_edges[c]++;\n      community_weights_ptr[c] = community_weights_ptr[c] + w;\n      node_assignments[u].push_back(c);\n      if (u != v) node_assignments[v].push_back(c);\n\n      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))\n          << \"[bug] 1. generated splits (u) are greater than nc!\";\n      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))\n          << \"[bug] 1. generated splits (v) are greater than nc!\";\n      edgenum_unassigned_ptr[u]--;\n      edgenum_unassigned_ptr[v]--;\n    } else if (\n        node_assignments[u].size() != 0 && node_assignments[v].size() == 0) {\n      for (uint32_t j = 0; j < node_assignments[u].size(); j++) {\n        int32_t cind = node_assignments[u][j];\n        cache[j] = community_edges[cind];\n      }\n      int32_t cindex = LeastLoad(cache, node_assignments[u].size());\n      int32_t c = node_assignments[u][cindex];\n      out_ptr[i] = c;\n      community_edges[c]++;\n      community_weights_ptr[c] = community_weights_ptr[c] + w;\n\n      node_assignments[v].push_back(c);\n      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))\n          << \"[bug] 2. generated splits (v) are greater than nc!\";\n      edgenum_unassigned_ptr[u]--;\n      edgenum_unassigned_ptr[v]--;\n    } else if (\n        node_assignments[v].size() != 0 && node_assignments[u].size() == 0) {\n      for (uint32_t j = 0; j < node_assignments[v].size(); j++) {\n        int32_t cind = node_assignments[v][j];\n        cache[j] = community_edges[cind];\n      }\n      int32_t cindex = LeastLoad(cache, node_assignments[v].size());\n      int32_t c = node_assignments[v][cindex];\n      CHECK(c < nc) << \"[bug] 2. 
partition greater than nc !!\";\n      out_ptr[i] = c;\n\n      community_edges[c]++;\n      community_weights_ptr[c] = community_weights_ptr[c] + w;\n\n      node_assignments[u].push_back(c);\n      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))\n          << \"[bug] 3. generated splits (u) are greater than nc!\";\n      edgenum_unassigned_ptr[u]--;\n      edgenum_unassigned_ptr[v]--;\n    } else {\n      std::vector<int> setv(nc), intersetv;\n      for (int32_t j = 0; j < nc; j++) setv[j] = 0;\n      int32_t interset = 0;\n\n      CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))\n          << \"[bug] 4. generated splits (u) are greater than nc!\";\n      CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))\n          << \"[bug] 4. generated splits (v) are greater than nc!\";\n      for (size_t j = 0; j < node_assignments[v].size(); j++) {\n        CHECK(node_assignments[v][j] < nc)\n            << \"[bug] 4. Part assigned (v) greater than nc!\";\n        setv[node_assignments[v][j]]++;\n      }\n\n      for (size_t j = 0; j < node_assignments[u].size(); j++) {\n        CHECK(node_assignments[u][j] < nc)\n            << \"[bug] 4. Part assigned (u) greater than nc!\";\n        setv[node_assignments[u][j]]++;\n      }\n\n      for (int32_t j = 0; j < nc; j++) {\n        CHECK(setv[j] <= 2) << \"[bug] 4. unexpected computed value !!!\";\n        if (setv[j] == 2) {\n          interset++;\n          intersetv.push_back(j);\n        }\n      }\n      if (interset) {\n        for (size_t j = 0; j < intersetv.size(); j++) {\n          int32_t cind = intersetv[j];\n          cache[j] = community_edges[cind];\n        }\n        int32_t cindex = LeastLoad(cache, intersetv.size());\n        int32_t c = intersetv[cindex];\n        CHECK(c < nc) << \"[bug] 4. 
partition greater than nc !!\";\n        out_ptr[i] = c;\n        community_edges[c]++;\n        community_weights_ptr[c] = community_weights_ptr[c] + w;\n        edgenum_unassigned_ptr[u]--;\n        edgenum_unassigned_ptr[v]--;\n      } else {\n        if (node_degree_ptr[u] < node_degree_ptr[v]) {\n          for (uint32_t j = 0; j < node_assignments[u].size(); j++) {\n            int32_t cind = node_assignments[u][j];\n            cache[j] = community_edges[cind];\n          }\n          int32_t cindex = LeastLoad(cache, node_assignments[u].size());\n          int32_t c = node_assignments[u][cindex];\n          CHECK(c < nc) << \"[bug] 5. partition greater than nc !!\";\n          out_ptr[i] = c;\n          community_edges[c]++;\n          community_weights_ptr[c] = community_weights_ptr[c] + w;\n\n          for (uint32_t j = 0; j < node_assignments[v].size(); j++) {\n            CHECK(node_assignments[v][j] != c)\n                << \"[bug] 5. duplicate partition (v) assignment !!\";\n          }\n\n          node_assignments[v].push_back(c);\n          CHECK(node_assignments[v].size() <= static_cast<size_t>(nc))\n              << \"[bug] 5. generated splits (v) greater than nc!!\";\n          replication_list.push_back(v);\n          edgenum_unassigned_ptr[u]--;\n          edgenum_unassigned_ptr[v]--;\n        } else {\n          for (uint32_t j = 0; j < node_assignments[v].size(); j++) {\n            int32_t cind = node_assignments[v][j];\n            cache[j] = community_edges[cind];\n          }\n          int32_t cindex = LeastLoad(cache, node_assignments[v].size());\n          int32_t c = node_assignments[v][cindex];\n          CHECK(c < nc) << \"[bug] 6. 
partition greater than nc !!\";\n          out_ptr[i] = c;\n          community_edges[c]++;\n          community_weights_ptr[c] = community_weights_ptr[c] + w;\n          for (uint32_t j = 0; j < node_assignments[u].size(); j++) {\n            CHECK(node_assignments[u][j] != c)\n                << \"[bug] 6. duplicate partition (u) assignment !!\";\n          }\n          if (u != v) node_assignments[u].push_back(c);\n\n          CHECK(node_assignments[u].size() <= static_cast<size_t>(nc))\n              << \"[bug] 6. generated splits (u) greater than nc!!\";\n          replication_list.push_back(u);\n          edgenum_unassigned_ptr[u]--;\n          edgenum_unassigned_ptr[v]--;\n        }\n      }\n    }\n  }\n  delete cache;\n\n  for (int64_t c = 0; c < nc; c++) {\n    std::string path = prefix + \"/community\" + std::to_string(c) + \".txt\";\n\n    FILE *fp = fopen(path.c_str(), \"w\");\n    CHECK_NE(fp, static_cast<FILE *>(NULL))\n        << \"Error: can not open file: \" << path.c_str();\n\n    for (int64_t i = 0; i < N_e; i++) {\n      if (out_ptr[i] == c)\n        fprintf(\n            fp, \"%ld,%ld,%ld\\n\", static_cast<int64_t>(u_ptr[i]),\n            static_cast<int64_t>(v_ptr[i]), w_ptr[i]);\n    }\n    fclose(fp);\n  }\n\n  std::string path = prefix + \"/replicationlist.csv\";\n  FILE *fp = fopen(path.c_str(), \"w\");\n  CHECK_NE(fp, static_cast<FILE *>(NULL))\n      << \"Error: can not open file: \" << path.c_str();\n\n  fprintf(fp, \"## The Indices of Nodes that are replicated :: Header\");\n  printf(\"\\nTotal replication: %ld\\n\", replication_list.size());\n\n  for (uint64_t i = 0; i < replication_list.size(); i++)\n    fprintf(fp, \"%ld\\n\", static_cast<int64_t>(replication_list[i]));\n\n  printf(\"Community weights:\\n\");\n  for (int64_t c = 0; c < nc; c++) printf(\"%ld \", community_weights_ptr[c]);\n  printf(\"\\n\");\n\n  printf(\"Community edges:\\n\");\n  for (int64_t c = 0; c < nc; c++) printf(\"%ld \", community_edges[c]);\n  
printf(\"\\n\");\n\n  delete[] community_edges;\n  fclose(fp);\n}\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLLibraVertexCut\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      int32_t nc = args[0];\n      NDArray node_degree = args[1];\n      NDArray edgenum_unassigned = args[2];\n      NDArray community_weights = args[3];\n      NDArray u = args[4];\n      NDArray v = args[5];\n      NDArray w = args[6];\n      NDArray out = args[7];\n      int64_t N = args[8];\n      int64_t N_e = args[9];\n      std::string prefix = args[10];\n\n      ATEN_ID_TYPE_SWITCH(node_degree->dtype, IdType2, {\n        ATEN_ID_TYPE_SWITCH(u->dtype, IdType, {\n          LibraVertexCut<IdType, IdType2>(\n              nc, node_degree, edgenum_unassigned, community_weights, u, v, w,\n              out, N, N_e, prefix);\n        });\n      });\n    });\n\n/**\n * @brief\n * 1. Builds dictionary (ldt) for assigning local node IDs to nodes in the\n *    partitions\n * 2. Builds dictionary (gdt) for storing copies (local ID) of split nodes\n *    These dictionaries will be used in the subsequesnt stages to setup\n *    tracking of split nodes copies across the partition, setting up partition\n *    `ndata` dictionaries.\n * @param[out] a local src node ID of an edge in a partition\n * @param[out] b local dst node ID of an edge in a partition\n * @param[-] indices temporary memory, keeps track of global node ID to local\n *           node ID in a partition\n * @param[out] ldt_key per partition dict for storing global and local node IDs\n *             (consecutive)\n * @param[out] gdt_key global dict for storing number of local nodes (or split\n *             nodes) for a given global node ID\n * @param[out] gdt_value global dict, stores local node IDs (due to split)\n *             across partitions for a given global node ID\n * @param[out] node_map keeps track of range of local node IDs (consecutive)\n *             given to the nodes in the partitions\n * @param[in, out] offset start 
of the range of local node IDs for this\n *                 partition\n * @param[in] nc number of partitions/communities\n * @param[in] c current partition number\n * @param[in] fsize size of pre-allocated\n *            memory tensor\n * @param[in] prefix input Libra partition file location\n */\nList<Value> Libra2dglBuildDict(\n    NDArray a, NDArray b, NDArray indices, NDArray ldt_key, NDArray gdt_key,\n    NDArray gdt_value, NDArray node_map, NDArray offset, int32_t nc, int32_t c,\n    int64_t fsize, const std::string &prefix) {\n  int64_t *indices_ptr = indices.Ptr<int64_t>();  // 1D temp array\n  int64_t *ldt_key_ptr =\n      ldt_key.Ptr<int64_t>();  // 1D local nodes <-> global nodes\n  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();  // 1D #split copies per node\n  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor\n  int64_t *node_map_ptr = node_map.Ptr<int64_t>();    // 1D tensor\n  int64_t *offset_ptr = offset.Ptr<int64_t>();        // 1D tensor\n  int32_t width = nc;\n\n  int64_t *a_ptr = a.Ptr<int64_t>();  // stores local src and dst node ID,\n  int64_t *b_ptr = b.Ptr<int64_t>();  // to create the partition graph\n\n  int64_t N_n = indices->shape[0];\n  int64_t num_nodes = ldt_key->shape[0];\n\n  for (int64_t i = 0; i < N_n; i++) {\n    indices_ptr[i] = -100;\n  }\n\n  int64_t pos = 0;\n  int64_t edge = 0;\n  std::string path = prefix + \"/community\" + std::to_string(c) + \".txt\";\n  FILE *fp = fopen(path.c_str(), \"r\");\n  CHECK_NE(fp, static_cast<FILE *>(NULL))\n      << \"Error: can not open file: \" << path.c_str();\n\n  while (!feof(fp) && edge < fsize) {\n    int64_t u, v;\n    float w;\n    CHECK_EQ(\n        fscanf(fp, \"%ld,%ld,%f\\n\", &u, &v, &w),\n        3);  // reading an edge - the src and dst global node IDs\n\n    if (indices_ptr[u] ==\n        -100) {  // if already not assigned a local node ID, local node ID is\n      ldt_key_ptr[pos] = u;    // already assigned for this global node ID\n      CHECK(pos < num_nodes);  
// Sanity check\n      indices_ptr[u] =\n          pos++;  // consecutive local node ID for a given global node ID\n    }\n    if (indices_ptr[v] == -100) {  // if already not assigned a local node ID\n      ldt_key_ptr[pos] = v;\n      CHECK(pos < num_nodes);  // Sanity check\n      indices_ptr[v] = pos++;\n    }\n    a_ptr[edge] = indices_ptr[u];    // new local ID for an edge\n    b_ptr[edge++] = indices_ptr[v];  // new local ID for an edge\n  }\n  CHECK(edge <= fsize)\n      << \"[Bug] memory allocated for #edges per partition is not enough.\";\n  fclose(fp);\n\n  List<Value> ret;\n  ret.push_back(Value(\n      MakeValue(pos)));  // returns total number of nodes in this partition\n  ret.push_back(Value(\n      MakeValue(edge)));  // returns total number of edges in this partition\n\n  for (int64_t i = 0; i < pos; i++) {\n    int64_t u = ldt_key_ptr[i];  // global node ID\n    // int64_t  v   = indices_ptr[u];\n    int64_t v = i;  // local node ID\n    int64_t *ind =\n        &gdt_key_ptr[u];  // global dict, total number of local node IDs (an\n                          // offset) as of now for a given global node ID\n    int64_t *ptr = gdt_value_ptr + u * width;\n    ptr[*ind] =\n        offset_ptr[0] + v;  // stores a local node ID for the global node ID\n    (*ind)++;\n    CHECK_NE(v, -100);\n    CHECK(*ind <= nc);\n  }\n  node_map_ptr[c] =\n      offset_ptr[0] +\n      pos;  // since local node IDs for a partition are consecutive,\n            // we maintain the range of local node IDs like this\n  offset_ptr[0] += pos;\n\n  return ret;\n}\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLLibra2dglBuildDict\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      NDArray a = args[0];\n      NDArray b = args[1];\n      NDArray indices = args[2];\n      NDArray ldt_key = args[3];\n      NDArray gdt_key = args[4];\n      NDArray gdt_value = args[5];\n      NDArray node_map = args[6];\n      NDArray offset = args[7];\n      int32_t nc = args[8];\n      int32_t c = 
args[9];\n      int64_t fsize = args[10];\n      std::string prefix = args[11];\n      List<Value> ret = Libra2dglBuildDict(\n          a, b, indices, ldt_key, gdt_key, gdt_value, node_map, offset, nc, c,\n          fsize, prefix);\n      *rv = ret;\n    });\n\n/**\n * @brief sets up the 1-level tree among the clones of the split-nodes.\n * @param[in] gdt_key global dict for assigning consecutive node IDs to nodes\n *            across all the partitions\n * @param[in] gdt_value global dict for assigning consecutive node IDs to nodes\n *            across all the partition\n * @param[out] lrtensor keeps the root node ID of 1-level tree\n * @param[in] nc number of partitions/communities\n * @param[in] Nn number of nodes in the input graph\n */\nvoid Libra2dglSetLR(\n    NDArray gdt_key, NDArray gdt_value, NDArray lrtensor, int32_t nc,\n    int64_t Nn) {\n  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();      // 1D tensor\n  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor\n  int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>();    // 1D tensor\n\n  int32_t width = nc;\n  int64_t cnt = 0;\n  int64_t avg_split_copy = 0, scnt = 0;\n\n  for (int64_t i = 0; i < Nn; i++) {\n    if (gdt_key_ptr[i] <= 0) {\n      cnt++;\n    } else {\n      int32_t val = RandomEngine::ThreadLocal()->RandInt(gdt_key_ptr[i]);\n      CHECK(val >= 0 && val < gdt_key_ptr[i]);\n      CHECK(gdt_key_ptr[i] <= nc);\n\n      int64_t *ptr = gdt_value_ptr + i * width;\n      lrtensor_ptr[i] = ptr[val];\n    }\n    if (gdt_key_ptr[i] > 1) {\n      avg_split_copy += gdt_key_ptr[i];\n      scnt++;\n    }\n  }\n}\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLLibra2dglSetLR\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      NDArray gdt_key = args[0];\n      NDArray gdt_value = args[1];\n      NDArray lrtensor = args[2];\n      int32_t nc = args[3];\n      int64_t Nn = args[4];\n\n      Libra2dglSetLR(gdt_key, gdt_value, lrtensor, nc, Nn);\n    });\n\n/**\n * @brief For each node in a 
partition, it creates a list of remote clone IDs;\n *        also, for each node in a partition, it gathers the data (feats, label,\n *        train, test) from input graph.\n * @param[out] feat node features in current partition c.\n * @param[in] gfeat input graph node features.\n * @param[out] adj list of node IDs of remote clones.\n * @param[out] inner_node marks whether a node is split or not.\n * @param[in] ldt_key per partition dict for tracking global to local node IDs\n * @param[out] gdt_key global dict for storing number of local nodes (or split\n *             nodes) for a given global node ID\n * @param[out] gdt_value global\n *             dict, stores local node IDs (due to split) across partitions for\n *             a given global node ID.\n * @param[in] node_map keeps track of range of local node IDs (consecutive)\n *            given to the nodes in the partitions.\n * @param[out] lr 1-level tree marking for local split nodes.\n * @param[in] lrtensor global (all the partitions) 1-level tree.\n * @param[in] num_nodes number of nodes in current partition.\n * @param[in] nc number of partitions/communities.\n * @param[in] c current partition/community.\n * @param[in] feat_size node feature vector size.\n * @param[out] labels local (for this partition) labels.\n * @param[out] trainm local (for this partition) training nodes.\n * @param[out] testm local (for this partition) testing nodes.\n * @param[out] valm local (for this partition) validation nodes.\n * @param[in] glabels global (input graph) labels.\n * @param[in] gtrainm global (input graph) training nodes.\n * @param[in] gtestm global (input graph) testing nodes.\n * @param[in] gvalm global (input graph) validation nodes.\n * @param[in] Nn number of nodes in the input graph.\n */\ntemplate <typename IdType, typename IdType2, typename DType>\nvoid Libra2dglBuildAdjlist(\n    NDArray feat, NDArray gfeat, NDArray adj, NDArray inner_node,\n    NDArray ldt_key, NDArray gdt_key, NDArray gdt_value, 
NDArray node_map,\n    NDArray lr, NDArray lrtensor, int64_t num_nodes, int32_t nc, int32_t c,\n    int32_t feat_size, NDArray labels, NDArray trainm, NDArray testm,\n    NDArray valm, NDArray glabels, NDArray gtrainm, NDArray gtestm,\n    NDArray gvalm, int64_t Nn) {\n  DType *feat_ptr = feat.Ptr<DType>();    // 2D tensor\n  DType *gfeat_ptr = gfeat.Ptr<DType>();  // 2D tensor\n  int64_t *adj_ptr = adj.Ptr<int64_t>();  // 2D tensor\n  int32_t *inner_node_ptr = inner_node.Ptr<int32_t>();\n  int64_t *ldt_key_ptr = ldt_key.Ptr<int64_t>();\n  int64_t *gdt_key_ptr = gdt_key.Ptr<int64_t>();\n  int64_t *gdt_value_ptr = gdt_value.Ptr<int64_t>();  // 2D tensor\n  int64_t *node_map_ptr = node_map.Ptr<int64_t>();\n  int64_t *lr_ptr = lr.Ptr<int64_t>();\n  int64_t *lrtensor_ptr = lrtensor.Ptr<int64_t>();\n  int32_t width = nc - 1;\n\n  runtime::parallel_for(0, num_nodes, [&](int64_t s, int64_t e) {\n    for (int64_t i = s; i < e; i++) {\n      int64_t k = ldt_key_ptr[i];\n      int64_t v = i;\n      int64_t ind = gdt_key_ptr[k];\n\n      int64_t *adj_ptr_ptr = adj_ptr + v * width;\n      if (ind == 1) {\n        for (int32_t j = 0; j < width; j++) adj_ptr_ptr[j] = -1;\n        inner_node_ptr[i] = 1;\n        lr_ptr[i] = -200;\n      } else {\n        lr_ptr[i] = lrtensor_ptr[k];\n        int64_t *ptr = gdt_value_ptr + k * nc;\n        int64_t pos = 0;\n        CHECK(ind <= nc);\n        int32_t flg = 0;\n        for (int64_t j = 0; j < ind; j++) {\n          if (ptr[j] == lr_ptr[i]) flg = 1;\n          if (c != Ver2partition<int64_t>(ptr[j], node_map_ptr, nc))\n            adj_ptr_ptr[pos++] = ptr[j];\n        }\n        CHECK_EQ(flg, 1);\n        CHECK(pos == ind - 1);\n        for (; pos < width; pos++) adj_ptr_ptr[pos] = -1;\n        inner_node_ptr[i] = 0;\n      }\n    }\n  });\n\n  // gather\n  runtime::parallel_for(0, num_nodes, [&](int64_t s, int64_t e) {\n    for (int64_t i = s; i < e; i++) {\n      int64_t k = ldt_key_ptr[i];\n      int64_t ind = i * feat_size;\n     
 DType *optr = gfeat_ptr + ind;\n      DType *iptr = feat_ptr + k * feat_size;\n\n      for (int32_t j = 0; j < feat_size; j++) optr[j] = iptr[j];\n    }\n\n    IdType *labels_ptr = labels.Ptr<IdType>();\n    IdType *glabels_ptr = glabels.Ptr<IdType>();\n    IdType2 *trainm_ptr = trainm.Ptr<IdType2>();\n    IdType2 *gtrainm_ptr = gtrainm.Ptr<IdType2>();\n    IdType2 *testm_ptr = testm.Ptr<IdType2>();\n    IdType2 *gtestm_ptr = gtestm.Ptr<IdType2>();\n    IdType2 *valm_ptr = valm.Ptr<IdType2>();\n    IdType2 *gvalm_ptr = gvalm.Ptr<IdType2>();\n\n    for (int64_t i = 0; i < num_nodes; i++) {\n      int64_t k = ldt_key_ptr[i];\n      CHECK(k >= 0 && k < Nn);\n      glabels_ptr[i] = labels_ptr[k];\n      gtrainm_ptr[i] = trainm_ptr[k];\n      gtestm_ptr[i] = testm_ptr[k];\n      gvalm_ptr[i] = valm_ptr[k];\n    }\n  });\n}\n\nDGL_REGISTER_GLOBAL(\"sparse._CAPI_DGLLibra2dglBuildAdjlist\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      NDArray feat = args[0];\n      NDArray gfeat = args[1];\n      NDArray adj = args[2];\n      NDArray inner_node = args[3];\n      NDArray ldt_key = args[4];\n      NDArray gdt_key = args[5];\n      NDArray gdt_value = args[6];\n      NDArray node_map = args[7];\n      NDArray lr = args[8];\n      NDArray lrtensor = args[9];\n      int64_t num_nodes = args[10];\n      int32_t nc = args[11];\n      int32_t c = args[12];\n      int32_t feat_size = args[13];\n      NDArray labels = args[14];\n      NDArray trainm = args[15];\n      NDArray testm = args[16];\n      NDArray valm = args[17];\n      NDArray glabels = args[18];\n      NDArray gtrainm = args[19];\n      NDArray gtestm = args[20];\n      NDArray gvalm = args[21];\n      int64_t Nn = args[22];\n\n      ATEN_FLOAT_TYPE_SWITCH(feat->dtype, DType, \"Features\", {\n        ATEN_ID_TYPE_SWITCH(trainm->dtype, IdType2, {\n          ATEN_ID_BITS_SWITCH((glabels->dtype).bits, IdType, {\n            Libra2dglBuildAdjlist<IdType, IdType2, DType>(\n                feat, gfeat, adj, 
inner_node, ldt_key, gdt_key, gdt_value,\n                node_map, lr, lrtensor, num_nodes, nc, c, feat_size, labels,\n                trainm, testm, valm, glabels, gtrainm, gtestm, gvalm, Nn);\n          });\n        });\n      });\n    });\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/selector.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/selector.h\n * @brief Selector functions to select among src/edge/dst attributes.\n */\n#ifndef DGL_ARRAY_SELECTOR_H_\n#define DGL_ARRAY_SELECTOR_H_\n\n#include <dmlc/logging.h>\n\nnamespace dgl {\n\nnamespace {\n\n#ifdef __CUDACC__\n#define DGLDEVICE __device__\n#define DGLINLINE __forceinline__\n#else\n#define DGLDEVICE\n#define DGLINLINE inline\n#endif  // __CUDACC__\n\n}  // namespace\n\n/**\n * @brief Select among src/edge/dst feature/idx.\n * @note the integer argument target specifies which target\n *       to choose, 0: src, 1: edge, 2: dst.\n */\ntemplate <int target>\nstruct Selector {\n  template <typename T>\n  static DGLDEVICE DGLINLINE T Call(T src, T edge, T dst) {\n    LOG(INFO) << \"Target \" << target << \" not recognized.\";\n    return src;\n  }\n};\n\ntemplate <>\ntemplate <typename T>\nDGLDEVICE DGLINLINE T Selector<0>::Call(T src, T edge, T dst) {\n  return src;\n}\n\ntemplate <>\ntemplate <typename T>\nDGLDEVICE DGLINLINE T Selector<1>::Call(T src, T edge, T dst) {\n  return edge;\n}\n\ntemplate <>\ntemplate <typename T>\nDGLDEVICE DGLINLINE T Selector<2>::Call(T src, T edge, T dst) {\n  return dst;\n}\n\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_SELECTOR_H_\n"
  },
  {
    "path": "src/array/union_partition.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file array/cpu/coo_union_partition.cc\n * @brief COO union and partition\n */\n#include <dgl/array.h>\n\n#include <vector>\n\nnamespace dgl {\nnamespace aten {\n///////////////////////// COO Based Operations/////////////////////////\nstd::vector<COOMatrix> DisjointPartitionCooBySizes(\n    const COOMatrix &coo, const uint64_t batch_size,\n    const std::vector<uint64_t> &edge_cumsum,\n    const std::vector<uint64_t> &src_vertex_cumsum,\n    const std::vector<uint64_t> &dst_vertex_cumsum) {\n  CHECK_EQ(edge_cumsum.size(), batch_size + 1);\n  CHECK_EQ(src_vertex_cumsum.size(), batch_size + 1);\n  CHECK_EQ(dst_vertex_cumsum.size(), batch_size + 1);\n  std::vector<COOMatrix> ret;\n  ret.resize(batch_size);\n\n  for (size_t g = 0; g < batch_size; ++g) {\n    IdArray result_src =\n        IndexSelect(coo.row, edge_cumsum[g], edge_cumsum[g + 1]) -\n        src_vertex_cumsum[g];\n    IdArray result_dst =\n        IndexSelect(coo.col, edge_cumsum[g], edge_cumsum[g + 1]) -\n        dst_vertex_cumsum[g];\n    IdArray result_data = NullArray();\n    // has data index array\n    if (COOHasData(coo)) {\n      result_data = IndexSelect(coo.data, edge_cumsum[g], edge_cumsum[g + 1]) -\n                    edge_cumsum[g];\n    }\n\n    COOMatrix sub_coo = COOMatrix(\n        src_vertex_cumsum[g + 1] - src_vertex_cumsum[g],\n        dst_vertex_cumsum[g + 1] - dst_vertex_cumsum[g], result_src, result_dst,\n        result_data, coo.row_sorted, coo.col_sorted);\n    ret[g] = sub_coo;\n  }\n\n  return ret;\n}\n\nCOOMatrix COOSliceContiguousChunk(\n    const COOMatrix &coo, const std::vector<uint64_t> &edge_range,\n    const std::vector<uint64_t> &src_vertex_range,\n    const std::vector<uint64_t> &dst_vertex_range) {\n  IdArray result_src = NullArray(coo.row->dtype, coo.row->ctx);\n  IdArray result_dst = NullArray(coo.row->dtype, coo.row->ctx);\n  if (edge_range[1] != edge_range[0]) {\n    // The chunk has edges\n    
result_src = IndexSelect(coo.row, edge_range[0], edge_range[1]) -\n                 src_vertex_range[0];\n    result_dst = IndexSelect(coo.col, edge_range[0], edge_range[1]) -\n                 dst_vertex_range[0];\n  }\n\n  IdArray result_data = NullArray();\n  // has data index array\n  if (COOHasData(coo)) {\n    result_data =\n        IndexSelect(coo.data, edge_range[0], edge_range[1]) - edge_range[0];\n  }\n\n  COOMatrix sub_coo = COOMatrix(\n      src_vertex_range[1] - src_vertex_range[0],\n      dst_vertex_range[1] - dst_vertex_range[0], result_src, result_dst,\n      result_data, coo.row_sorted, coo.col_sorted);\n\n  return sub_coo;\n}\n\n///////////////////////// CSR Based Operations/////////////////////////\nCSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix> &csrs) {\n  uint64_t src_offset = 0, dst_offset = 0;\n  int64_t indices_offset = 0;\n  bool has_data = false;\n  bool sorted = true;\n\n  // check if data index array\n  for (size_t i = 0; i < csrs.size(); ++i) {\n    CHECK_SAME_DTYPE(csrs[0].indptr, csrs[i].indptr);\n    CHECK_SAME_CONTEXT(csrs[0].indices, csrs[i].indices);\n    has_data |= CSRHasData(csrs[i]);\n  }\n\n  std::vector<IdArray> res_indptr;\n  std::vector<IdArray> res_indices;\n  std::vector<IdArray> res_data;\n  res_indptr.resize(csrs.size());\n  res_indices.resize(csrs.size());\n\n  for (size_t i = 0; i < csrs.size(); ++i) {\n    const aten::CSRMatrix &csr = csrs[i];\n    sorted &= csr.sorted;\n    IdArray indptr = csr.indptr + indices_offset;\n    IdArray indices = csr.indices + dst_offset;\n    if (i > 0) indptr = IndexSelect(indptr, 1, indptr->shape[0]);\n    res_indptr[i] = indptr;\n    res_indices[i] = indices;\n    src_offset += csr.num_rows;\n    dst_offset += csr.num_cols;\n\n    // any one of input csr has data index array\n    if (has_data) {\n      IdArray edges_data;\n      if (CSRHasData(csr) == false) {\n        edges_data = Range(\n            indices_offset, indices_offset + csr.indices->shape[0],\n            
csr.indices->dtype.bits, csr.indices->ctx);\n      } else {\n        edges_data = csr.data + indices_offset;\n      }\n      res_data.push_back(edges_data);\n      indices_offset += csr.indices->shape[0];\n    }\n  }\n\n  IdArray result_indptr = Concat(res_indptr);\n  IdArray result_indices = Concat(res_indices);\n  IdArray result_data = has_data ? Concat(res_data) : NullArray();\n\n  return CSRMatrix(\n      src_offset, dst_offset, result_indptr, result_indices, result_data,\n      sorted);\n}\n\nstd::vector<CSRMatrix> DisjointPartitionCsrBySizes(\n    const CSRMatrix &csr, const uint64_t batch_size,\n    const std::vector<uint64_t> &edge_cumsum,\n    const std::vector<uint64_t> &src_vertex_cumsum,\n    const std::vector<uint64_t> &dst_vertex_cumsum) {\n  CHECK_EQ(edge_cumsum.size(), batch_size + 1);\n  CHECK_EQ(src_vertex_cumsum.size(), batch_size + 1);\n  CHECK_EQ(dst_vertex_cumsum.size(), batch_size + 1);\n  std::vector<CSRMatrix> ret;\n  ret.resize(batch_size);\n\n  for (size_t g = 0; g < batch_size; ++g) {\n    uint64_t num_src = src_vertex_cumsum[g + 1] - src_vertex_cumsum[g];\n    IdArray result_indptr;\n    if (g == 0) {\n      result_indptr =\n          IndexSelect(csr.indptr, 0, src_vertex_cumsum[1] + 1) - edge_cumsum[0];\n    } else {\n      result_indptr =\n          IndexSelect(\n              csr.indptr, src_vertex_cumsum[g], src_vertex_cumsum[g + 1] + 1) -\n          edge_cumsum[g];\n    }\n\n    IdArray result_indices =\n        IndexSelect(csr.indices, edge_cumsum[g], edge_cumsum[g + 1]) -\n        dst_vertex_cumsum[g];\n\n    IdArray result_data = NullArray();\n    // has data index array\n    if (CSRHasData(csr)) {\n      result_data = IndexSelect(csr.data, edge_cumsum[g], edge_cumsum[g + 1]) -\n                    edge_cumsum[g];\n    }\n\n    CSRMatrix sub_csr = CSRMatrix(\n        num_src, dst_vertex_cumsum[g + 1] - dst_vertex_cumsum[g], result_indptr,\n        result_indices, result_data, csr.sorted);\n    ret[g] = sub_csr;\n  }\n\n  return 
ret;\n}\n\nCSRMatrix CSRSliceContiguousChunk(\n    const CSRMatrix &csr, const std::vector<uint64_t> &edge_range,\n    const std::vector<uint64_t> &src_vertex_range,\n    const std::vector<uint64_t> &dst_vertex_range) {\n  int64_t indptr_len = src_vertex_range[1] - src_vertex_range[0] + 1;\n  IdArray result_indptr =\n      Full(0, indptr_len, csr.indptr->dtype.bits, csr.indptr->ctx);\n  IdArray result_indices = NullArray(csr.indptr->dtype, csr.indptr->ctx);\n  IdArray result_data = NullArray();\n  if (edge_range[1] != edge_range[0]) {\n    // The chunk has edges\n    result_indptr =\n        IndexSelect(csr.indptr, src_vertex_range[0], src_vertex_range[1] + 1) -\n        edge_range[0];\n    result_indices = IndexSelect(csr.indices, edge_range[0], edge_range[1]) -\n                     dst_vertex_range[0];\n    if (CSRHasData(csr)) {\n      result_data =\n          IndexSelect(csr.data, edge_range[0], edge_range[1]) - edge_range[0];\n    }\n  }\n\n  CSRMatrix sub_csr = CSRMatrix(\n      src_vertex_range[1] - src_vertex_range[0],\n      dst_vertex_range[1] - dst_vertex_range[0], result_indptr, result_indices,\n      result_data, csr.sorted);\n\n  return sub_csr;\n}\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/uvm_array.cc",
    "content": "/**\n *  Copyright (c) 2019-2022 by Contributors\n * @file array/uvm_array.cc\n * @brief DGL array utilities implementation\n */\n#include <dgl/array.h>\n\n#include <sstream>\n\n#include \"../c_api_common.h\"\n#include \"./uvm_array_op.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace aten {\n\nNDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {\n#ifdef DGL_USE_CUDA\n  CHECK(array.IsPinned()) << \"Input array must be in pinned memory.\";\n  CHECK_EQ(index->ctx.device_type, kDGLCUDA) << \"Index must be on the GPU.\";\n  CHECK_GE(array->ndim, 1) << \"Input array must have at least 1 dimension.\";\n  CHECK_EQ(index->ndim, 1) << \"Index must be a 1D array.\";\n\n  ATEN_DTYPE_BITS_ONLY_SWITCH(array->dtype, DType, \"values\", {\n    ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {\n      return impl::IndexSelectCPUFromGPU<DType, IdType>(array, index);\n    });\n  });\n#endif\n  LOG(FATAL) << \"IndexSelectCPUFromGPU requires CUDA.\";\n  // Should be unreachable\n  return NDArray{};\n}\n\nvoid IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {\n#ifdef DGL_USE_CUDA\n  CHECK(dest.IsPinned()) << \"Destination array must be in pinned memory.\";\n  CHECK_EQ(index->ctx.device_type, kDGLCUDA) << \"Index must be on the GPU.\";\n  CHECK_EQ(source->ctx.device_type, kDGLCUDA)\n      << \"Source array must be on the GPU.\";\n  CHECK_EQ(dest->dtype, source->dtype) << \"Destination array and source \"\n                                          \"array must have the same dtype.\";\n  CHECK_GE(dest->ndim, 1)\n      << \"Destination array must have at least 1 dimension.\";\n  CHECK_EQ(index->ndim, 1) << \"Index must be a 1D array.\";\n\n  ATEN_DTYPE_BITS_ONLY_SWITCH(source->dtype, DType, \"values\", {\n    ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {\n      impl::IndexScatterGPUToCPU<DType, IdType>(dest, index, source);\n    });\n  });\n#else\n  LOG(FATAL) << \"IndexScatterGPUToCPU requires 
CUDA.\";\n#endif\n}\n\nDGL_REGISTER_GLOBAL(\"ndarray.uvm._CAPI_DGLIndexSelectCPUFromGPU\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArray array = args[0];\n      IdArray index = args[1];\n      *rv = IndexSelectCPUFromGPU(array, index);\n    });\n\nDGL_REGISTER_GLOBAL(\"ndarray.uvm._CAPI_DGLIndexScatterGPUToCPU\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArray dest = args[0];\n      IdArray index = args[1];\n      NDArray source = args[2];\n      IndexScatterGPUToCPU(dest, index, source);\n    });\n\n}  // namespace aten\n}  // namespace dgl\n"
  },
  {
    "path": "src/array/uvm_array_op.h",
    "content": "/**\n *  Copyright (c) 2019-2022 by Contributors\n * @file array/uvm_array_op.h\n * @brief Array operator templates\n */\n#ifndef DGL_ARRAY_UVM_ARRAY_OP_H_\n#define DGL_ARRAY_UVM_ARRAY_OP_H_\n\n#include <dgl/array.h>\n\n#include <utility>\n\nnamespace dgl {\nnamespace aten {\nnamespace impl {\n\n// Take CPU array and GPU index, and then index with GPU.\ntemplate <typename DType, typename IdType>\nNDArray IndexSelectCPUFromGPU(NDArray array, IdArray index);\n\ntemplate <typename DType, typename IdType>\nvoid IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source);\n\n}  // namespace impl\n}  // namespace aten\n}  // namespace dgl\n\n#endif  // DGL_ARRAY_UVM_ARRAY_OP_H_\n"
  },
  {
    "path": "src/bcast.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file kernel/bcast.h\n * @brief Broadcast related function implementations.\n */\n#include <dgl/bcast.h>\n#include <dmlc/logging.h>\n\n#include <algorithm>\n\nnamespace dgl {\n\nnamespace {\n/**\n * @brief Determine whether use broadcasting or not, given the operator\n *        type, lhs array and rhs array.\n */\nbool UseBcast(const std::string& op, NDArray lhs, NDArray rhs) {\n  if (op == \"copy_lhs\" || op == \"copy_rhs\")\n    return false;  // broadcasting is not required for copy_u/copy_e\n  if (lhs->ndim != rhs->ndim) return true;\n  for (int i = 1; i < lhs->ndim; ++i) {\n    if (lhs->shape[i] != rhs->shape[i]) return true;\n  }\n  return false;\n}\n\n}  // namespace\n\n/**\n * @brief: Compute broadcast and auxiliary information given operator\n *         and operands for kernel computation.\n * @note: Expect lhs, rhs to have ndim >= 2 and the shape of lhs/rhs\n *        valid for the op computation.\n */\nBcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs) {\n  BcastOff rst;\n  rst.lhs_len = 1;\n  rst.rhs_len = 1;\n  for (int i = 1; i < lhs->ndim; ++i) rst.lhs_len *= lhs->shape[i];\n  for (int i = 1; i < rhs->ndim; ++i) rst.rhs_len *= rhs->shape[i];\n  rst.use_bcast = UseBcast(op, lhs, rhs);\n  rst.reduce_size = 1;  // defaults to 1, except for the case op == 'dot'.\n  if (rst.use_bcast) {\n    const int max_ndim = std::max(lhs->ndim, rhs->ndim) - 1;\n    int out_len = 1, j = 0;\n    if (op == \"dot\") {\n      rst.reduce_size = lhs->shape[lhs->ndim - 1];  // set reduce_size for dot.\n      ++j;  // do not consider reduce axis in computing lhs_offset and\n            // rhs_offset.\n    }\n    int stride_l = 1, stride_r = 1;\n    rst.lhs_offset.push_back(0);  // lhs_offset[0] is always 0\n    rst.rhs_offset.push_back(0);  // rhs_offset[0] is always 0\n    for (; j < max_ndim; ++j) {   // iterate the axis from back to front.\n      // dl refers to the size of lhs array in the 
current axis, likewise for\n      // dr.\n      const int dl =\n          (lhs->ndim - 1 - j < 1) ? 1 : lhs->shape[lhs->ndim - 1 - j];\n      const int dr =\n          (rhs->ndim - 1 - j < 1) ? 1 : rhs->shape[rhs->ndim - 1 - j];\n      for (int i = 1; i < std::max(dl, dr); ++i) {\n        for (int k = 0; k < out_len; ++k) {\n          /* Explanation:\n           * if current dimension is not broadcast dimension for lhs array\n           *   lhs_offset[i * out_len + k] = lhs_offset[k] + i * stride_l\n           * else\n           *   lhs_offset[i * out_len + k] = lhs_offset[k]\n           * likewise for rhs_offset.\n           */\n          rst.lhs_offset.push_back(rst.lhs_offset[k] + i * (i < dl) * stride_l);\n          rst.rhs_offset.push_back(rst.rhs_offset[k] + i * (i < dr) * stride_r);\n        }\n      }\n      out_len *= std::max(dl, dr);\n      stride_l *= dl;\n      stride_r *= dr;\n    }\n    rst.out_len = out_len;\n  } else {\n    rst.out_len = (op == \"copy_rhs\") ? rst.rhs_len : rst.lhs_len;\n    if (op == \"dot\") {\n      // set reduce_size for dot.\n      rst.reduce_size = lhs->shape[lhs->ndim - 1];\n      // out_len is divided by reduce_size in dot.\n      rst.out_len /= rst.reduce_size;\n    }\n  }\n  return rst;\n}\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/c_api_common.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file c_api_common.cc\n * @brief DGL C API common implementations\n */\n#include \"c_api_common.h\"\n\n#include <dgl/graph_interface.h>\n\nusing dgl::runtime::DGLArgs;\nusing dgl::runtime::DGLArgValue;\nusing dgl::runtime::DGLRetValue;\nusing dgl::runtime::NDArray;\nusing dgl::runtime::PackedFunc;\n\nnamespace dgl {\n\nPackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {\n  auto body = [vec](DGLArgs args, DGLRetValue* rv) {\n    const uint64_t which = args[0];\n    if (which >= vec.size()) {\n      LOG(FATAL) << \"invalid choice\";\n    } else {\n      *rv = std::move(vec[which]);\n    }\n  };\n  return PackedFunc(body);\n}\n\nPackedFunc ConvertEdgeArrayToPackedFunc(const EdgeArray& ea) {\n  auto body = [ea](DGLArgs args, DGLRetValue* rv) {\n    const int which = args[0];\n    if (which == 0) {\n      *rv = std::move(ea.src);\n    } else if (which == 1) {\n      *rv = std::move(ea.dst);\n    } else if (which == 2) {\n      *rv = std::move(ea.id);\n    } else {\n      LOG(FATAL) << \"invalid choice\";\n    }\n  };\n  return PackedFunc(body);\n}\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/c_api_common.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file c_api_common.h\n * @brief DGL C API common util functions\n */\n#ifndef DGL_C_API_COMMON_H_\n#define DGL_C_API_COMMON_H_\n\n#include <dgl/array.h>\n#include <dgl/graph_interface.h>\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/packed_func.h>\n#include <dgl/runtime/registry.h>\n\n#include <algorithm>\n#include <string>\n#include <utility>\n#include <vector>\n\nnamespace dgl {\n\n// Communicator handler type\ntypedef void* CommunicatorHandle;\n\n// KVstore message handler type\ntypedef void* KVMsgHandle;\n\n/**\n * @brief Convert a vector of NDArray to PackedFunc.\n */\ndgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(\n    const std::vector<dgl::runtime::NDArray>& vec);\n\n/**\n * @brief Copy a vector to an NDArray.\n *\n * The data type of the NDArray will be IdType, which must be an integer type.\n * The element type (DType) of the vector must be convertible to IdType.\n */\ntemplate <typename IdType, typename DType>\ndgl::runtime::NDArray CopyVectorToNDArray(const std::vector<DType>& vec) {\n  using dgl::runtime::NDArray;\n  const int64_t len = vec.size();\n  NDArray a = NDArray::Empty(\n      {len}, DGLDataType{kDGLInt, sizeof(IdType) * 8, 1},\n      DGLContext{kDGLCPU, 0});\n  std::copy(vec.begin(), vec.end(), static_cast<IdType*>(a->data));\n  return a;\n}\n\nruntime::PackedFunc ConvertEdgeArrayToPackedFunc(const EdgeArray& ea);\n\n}  // namespace dgl\n\n#endif  // DGL_C_API_COMMON_H_\n"
  },
  {
    "path": "src/geometry/cpu/geometry_op_impl.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file geometry/cpu/geometry_op_impl.cc\n * @brief Geometry operator CPU implementation\n */\n#include <dgl/random.h>\n\n#include <numeric>\n#include <utility>\n#include <vector>\n\n#include \"../geometry_op.h\"\n\nnamespace dgl {\nusing runtime::NDArray;\nnamespace geometry {\nnamespace impl {\n\n/** @brief Knuth shuffle algorithm */\ntemplate <typename IdType>\nvoid IndexShuffle(IdType *idxs, int64_t num_elems) {\n  for (int64_t i = num_elems - 1; i > 0; --i) {\n    int64_t j = dgl::RandomEngine::ThreadLocal()->RandInt(i);\n    std::swap(idxs[i], idxs[j]);\n  }\n}\ntemplate void IndexShuffle<int32_t>(int32_t *idxs, int64_t num_elems);\ntemplate void IndexShuffle<int64_t>(int64_t *idxs, int64_t num_elems);\n\n/** @brief Groupwise index shuffle algorithm. This function will perform shuffle\n * in subarrays indicated by group index. The group index is similar to indptr\n * in CSRMatrix.\n *\n * @param group_idxs group index array.\n * @param idxs index array for shuffle.\n * @param num_groups_idxs length of group_idxs\n * @param num_elems length of idxs\n */\ntemplate <typename IdType>\nvoid GroupIndexShuffle(\n    const IdType *group_idxs, IdType *idxs, int64_t num_groups_idxs,\n    int64_t num_elems) {\n  if (num_groups_idxs < 2) return;  // empty idxs array\n  CHECK_LE(group_idxs[num_groups_idxs - 1], num_elems)\n      << \"group_idxs out of range\";\n  for (int64_t i = 0; i < num_groups_idxs - 1; ++i) {\n    auto subarray_len = group_idxs[i + 1] - group_idxs[i];\n    IndexShuffle(idxs + group_idxs[i], subarray_len);\n  }\n}\ntemplate void GroupIndexShuffle<int32_t>(\n    const int32_t *group_idxs, int32_t *idxs, int64_t num_groups_idxs,\n    int64_t num_elems);\ntemplate void GroupIndexShuffle<int64_t>(\n    const int64_t *group_idxs, int64_t *idxs, int64_t num_groups_idxs,\n    int64_t num_elems);\n\ntemplate <typename IdType>\nIdArray RandomPerm(int64_t num_nodes) {\n  IdArray perm =\n      
aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);\n  IdType *perm_data = static_cast<IdType *>(perm->data);\n  std::iota(perm_data, perm_data + num_nodes, 0);\n  IndexShuffle(perm_data, num_nodes);\n  return perm;\n}\n\ntemplate <typename IdType>\nIdArray GroupRandomPerm(\n    const IdType *group_idxs, int64_t num_group_idxs, int64_t num_nodes) {\n  IdArray perm =\n      aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);\n  IdType *perm_data = static_cast<IdType *>(perm->data);\n  std::iota(perm_data, perm_data + num_nodes, 0);\n  GroupIndexShuffle(group_idxs, perm_data, num_group_idxs, num_nodes);\n  return perm;\n}\n\n/**\n * @brief Farthest Point Sampler without the need to compute all pairs of\n * distance.\n *\n * The input array has shape (N, d), where N is the number of points, and d is\n * the dimension. It consists of a (flatten) batch of point clouds.\n *\n * In each batch, the algorithm starts with the sample index specified by\n * ``start_idx``. Then for each point, we maintain the minimum to-sample\n * distance. Finally, we pick the point with the maximum such distance. 
This\n * process will be repeated for ``sample_points`` - 1 times.\n */\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid FarthestPointSampler(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result) {\n  const FloatType *array_data = static_cast<FloatType *>(array->data);\n  const int64_t point_in_batch = array->shape[0] / batch_size;\n  const int64_t dim = array->shape[1];\n\n  // distance\n  FloatType *dist_data = static_cast<FloatType *>(dist->data);\n\n  // sample for each cloud in the batch\n  IdType *start_idx_data = static_cast<IdType *>(start_idx->data);\n\n  // return value\n  IdType *ret_data = static_cast<IdType *>(result->data);\n\n  int64_t array_start = 0, ret_start = 0;\n  // loop for each point cloud sample in this batch\n  for (auto b = 0; b < batch_size; b++) {\n    // random init start sample\n    int64_t sample_idx = (int64_t)start_idx_data[b];\n    ret_data[ret_start] = (IdType)(sample_idx);\n\n    // sample the rest `sample_points - 1` points\n    for (auto i = 0; i < sample_points - 1; i++) {\n      // re-init distance and the argmax\n      int64_t dist_argmax = 0;\n      FloatType dist_max = -1;\n\n      // update the distance\n      for (auto j = 0; j < point_in_batch; j++) {\n        // compute the distance on dimensions\n        FloatType one_dist = 0;\n        for (auto d = 0; d < dim; d++) {\n          FloatType tmp = array_data[(array_start + j) * dim + d] -\n                          array_data[(array_start + sample_idx) * dim + d];\n          one_dist += tmp * tmp;\n        }\n\n        // for each out-of-set point, keep its nearest to-the-set distance\n        if (i == 0 || dist_data[j] > one_dist) {\n          dist_data[j] = one_dist;\n        }\n        // look for the farthest sample\n        if (dist_data[j] > dist_max) {\n          dist_argmax = j;\n          dist_max = dist_data[j];\n        }\n      }\n      // sample the `dist_argmax`-th 
point\n      sample_idx = dist_argmax;\n      ret_data[ret_start + i + 1] = (IdType)(sample_idx);\n    }\n\n    array_start += point_in_batch;\n    ret_start += sample_points;\n  }\n}\ntemplate void FarthestPointSampler<kDGLCPU, float, int32_t>(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result);\ntemplate void FarthestPointSampler<kDGLCPU, float, int64_t>(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result);\ntemplate void FarthestPointSampler<kDGLCPU, double, int32_t>(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result);\ntemplate void FarthestPointSampler<kDGLCPU, double, int64_t>(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result);\n\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid WeightedNeighborMatching(\n    const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {\n  const int64_t num_nodes = result->shape[0];\n  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);\n  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);\n  IdType *result_data = static_cast<IdType *>(result->data);\n  FloatType *weight_data = static_cast<FloatType *>(weight->data);\n\n  // build node visiting order\n  IdArray vis_order = RandomPerm<IdType>(num_nodes);\n  IdType *vis_order_data = static_cast<IdType *>(vis_order->data);\n\n  for (int64_t n = 0; n < num_nodes; ++n) {\n    auto u = vis_order_data[n];\n\n    // if marked\n    if (result_data[u] >= 0) continue;\n\n    auto v_max = u;\n    FloatType weight_max = 0;\n\n    for (auto e = indptr_data[u]; e < indptr_data[u + 1]; ++e) {\n      auto v = indices_data[e];\n      if (result_data[v] >= 0) continue;\n      if (weight_data[e] >= weight_max) {\n        v_max = v;\n        weight_max = weight_data[e];\n  
    }\n    }\n    result_data[u] = std::min(u, v_max);\n    result_data[v_max] = result_data[u];\n  }\n}\ntemplate void WeightedNeighborMatching<kDGLCPU, float, int32_t>(\n    const aten::CSRMatrix &csr, const NDArray weight, IdArray result);\ntemplate void WeightedNeighborMatching<kDGLCPU, float, int64_t>(\n    const aten::CSRMatrix &csr, const NDArray weight, IdArray result);\ntemplate void WeightedNeighborMatching<kDGLCPU, double, int32_t>(\n    const aten::CSRMatrix &csr, const NDArray weight, IdArray result);\ntemplate void WeightedNeighborMatching<kDGLCPU, double, int64_t>(\n    const aten::CSRMatrix &csr, const NDArray weight, IdArray result);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {\n  const int64_t num_nodes = result->shape[0];\n  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);\n  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);\n  IdType *result_data = static_cast<IdType *>(result->data);\n\n  // build vis order\n  IdArray u_vis_order = RandomPerm<IdType>(num_nodes);\n  IdType *u_vis_order_data = static_cast<IdType *>(u_vis_order->data);\n  IdArray v_vis_order = GroupRandomPerm<IdType>(\n      indptr_data, csr.indptr->shape[0], csr.indices->shape[0]);\n  IdType *v_vis_order_data = static_cast<IdType *>(v_vis_order->data);\n\n  for (int64_t n = 0; n < num_nodes; ++n) {\n    auto u = u_vis_order_data[n];\n\n    // if marked\n    if (result_data[u] >= 0) continue;\n\n    result_data[u] = u;\n\n    for (auto e = indptr_data[u]; e < indptr_data[u + 1]; ++e) {\n      auto v = indices_data[v_vis_order_data[e]];\n      if (result_data[v] >= 0) continue;\n      result_data[u] = std::min(u, v);\n      result_data[v] = result_data[u];\n      break;\n    }\n  }\n}\ntemplate void NeighborMatching<kDGLCPU, int32_t>(\n    const aten::CSRMatrix &csr, IdArray result);\ntemplate void NeighborMatching<kDGLCPU, int64_t>(\n    const aten::CSRMatrix &csr, 
IdArray result);\n\n}  // namespace impl\n}  // namespace geometry\n}  // namespace dgl\n"
  },
  {
    "path": "src/geometry/cuda/edge_coarsening_impl.cu",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file geometry/cuda/edge_coarsening_impl.cu\n * @brief Edge coarsening CUDA implementation\n */\n#include <curand_kernel.h>\n#include <dgl/array.h>\n#include <dgl/random.h>\n#include <dmlc/thread_local.h>\n\n#include <cstdint>\n\n#include \"../../array/cuda/utils.h\"\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"../geometry_op.h\"\n\n#define BLOCKS(N, T) (N + T - 1) / T\n\nnamespace dgl {\nnamespace geometry {\nnamespace impl {\n\nconstexpr float BLUE_P = 0.53406;\nconstexpr int BLUE = -1;\nconstexpr int RED = -2;\nconstexpr int EMPTY_IDX = -1;\n\n__device__ bool done_d;\n__global__ void init_done_kernel() { done_d = true; }\n\n__global__ void generate_uniform_kernel(\n    float *ret_values, size_t num, uint64_t seed) {\n  size_t id = blockIdx.x * blockDim.x + threadIdx.x;\n  if (id < num) {\n    curandState state;\n    curand_init(seed, id, 0, &state);\n    ret_values[id] = curand_uniform(&state);\n  }\n}\n\ntemplate <typename IdType>\n__global__ void colorize_kernel(\n    const float *prop, int64_t num_elem, IdType *result) {\n  const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx < num_elem) {\n    if (result[idx] < 0) {  // if unmatched\n      result[idx] = (prop[idx] > BLUE_P) ? 
RED : BLUE;\n      done_d = false;\n    }\n  }\n}\n\ntemplate <typename FloatType, typename IdType>\n__global__ void weighted_propose_kernel(\n    const IdType *indptr, const IdType *indices, const FloatType *weights,\n    int64_t num_elem, IdType *proposal, IdType *result) {\n  const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx < num_elem) {\n    if (result[idx] != BLUE) return;\n\n    bool has_unmatched_neighbor = false;\n    FloatType weight_max = 0.;\n    IdType v_max = EMPTY_IDX;\n\n    for (IdType i = indptr[idx]; i < indptr[idx + 1]; ++i) {\n      auto v = indices[i];\n\n      if (result[v] < 0) has_unmatched_neighbor = true;\n      if (result[v] == RED && weights[i] >= weight_max) {\n        v_max = v;\n        weight_max = weights[i];\n      }\n    }\n\n    proposal[idx] = v_max;\n    if (!has_unmatched_neighbor) result[idx] = idx;\n  }\n}\n\ntemplate <typename FloatType, typename IdType>\n__global__ void weighted_respond_kernel(\n    const IdType *indptr, const IdType *indices, const FloatType *weights,\n    int64_t num_elem, IdType *proposal, IdType *result) {\n  const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx < num_elem) {\n    if (result[idx] != RED) return;\n\n    bool has_unmatched_neighbors = false;\n    IdType v_max = -1;\n    FloatType weight_max = 0.;\n\n    for (IdType i = indptr[idx]; i < indptr[idx + 1]; ++i) {\n      auto v = indices[i];\n\n      if (result[v] < 0) {\n        has_unmatched_neighbors = true;\n      }\n      if (result[v] == BLUE && proposal[v] == idx && weights[i] >= weight_max) {\n        v_max = v;\n        weight_max = weights[i];\n      }\n    }\n    if (v_max >= 0) {\n      result[v_max] = min(idx, v_max);\n      result[idx] = min(idx, v_max);\n    }\n\n    if (!has_unmatched_neighbors) result[idx] = idx;\n  }\n}\n\n/** @brief The colorize procedure. 
This procedure randomly marks unmarked\n * nodes with BLUE(-1) and RED(-2) and checks whether the node matching\n * process has finished.\n */\ntemplate <typename IdType>\nbool Colorize(IdType *result_data, int64_t num_nodes, float *const prop) {\n  // initial done signal\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, stream);\n\n  // generate color prop for each node\n  uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);\n  auto num_threads = cuda::FindNumThreads(num_nodes);\n  auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));\n  CUDA_KERNEL_CALL(\n      generate_uniform_kernel, num_blocks, num_threads, 0, stream, prop,\n      num_nodes, seed);\n\n  // call kernel\n  CUDA_KERNEL_CALL(\n      colorize_kernel, num_blocks, num_threads, 0, stream, prop, num_nodes,\n      result_data);\n  bool done_h = false;\n  CUDA_CALL(cudaMemcpyFromSymbol(\n      &done_h, done_d, sizeof(done_h), 0, cudaMemcpyDeviceToHost));\n  return done_h;\n}\n\n/** @brief Weighted neighbor matching procedure (GPU version).\n * This implementation is from `A GPU Algorithm for Greedy Graph Matching\n * <http://www.staff.science.uu.nl/~bisse101/Articles/match12.pdf>`__\n *\n * This algorithm has three parts: colorize, propose and respond.\n * In colorize procedure, each unmarked node will be marked as BLUE or\n * RED randomly. If all nodes are marked, finish and return.\n * In propose procedure, each BLUE node will propose to the RED\n * neighbor with the largest weight (or randomly choose one if without weight).\n * If all its neighbors are marked, mark this node with its id.\n * In respond procedure, each RED node will respond to the BLUE neighbor\n * that has proposed to it and has the largest weight. If all neighbors\n * are marked, mark this node with its id. 
Else match this (BLUE, RED) node\n * pair and mark them with the smaller id between them.\n */\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid WeightedNeighborMatching(\n    const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const auto &ctx = result->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  device->SetDevice(ctx);\n\n  // create proposal tensor\n  const int64_t num_nodes = result->shape[0];\n  IdArray proposal = aten::Full(-1, num_nodes, sizeof(IdType) * 8, ctx);\n\n  // get data ptrs\n  IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);\n  IdType *indices_data = static_cast<IdType *>(csr.indices->data);\n  IdType *result_data = static_cast<IdType *>(result->data);\n  IdType *proposal_data = static_cast<IdType *>(proposal->data);\n  FloatType *weight_data = static_cast<FloatType *>(weight->data);\n\n  // allocate workspace for prop used in Colorize()\n  float *prop = static_cast<float *>(\n      device->AllocWorkspace(ctx, num_nodes * sizeof(float)));\n\n  auto num_threads = cuda::FindNumThreads(num_nodes);\n  auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));\n  while (!Colorize<IdType>(result_data, num_nodes, prop)) {\n    CUDA_KERNEL_CALL(\n        weighted_propose_kernel, num_blocks, num_threads, 0, stream,\n        indptr_data, indices_data, weight_data, num_nodes, proposal_data,\n        result_data);\n    CUDA_KERNEL_CALL(\n        weighted_respond_kernel, num_blocks, num_threads, 0, stream,\n        indptr_data, indices_data, weight_data, num_nodes, proposal_data,\n        result_data);\n  }\n  device->FreeWorkspace(ctx, prop);\n}\ntemplate void WeightedNeighborMatching<kDGLCUDA, float, int32_t>(\n    const aten::CSRMatrix &csr, const NDArray weight, IdArray result);\ntemplate void WeightedNeighborMatching<kDGLCUDA, float, int64_t>(\n    const aten::CSRMatrix &csr, const NDArray weight, IdArray 
result);\ntemplate void WeightedNeighborMatching<kDGLCUDA, double, int32_t>(\n    const aten::CSRMatrix &csr, const NDArray weight, IdArray result);\ntemplate void WeightedNeighborMatching<kDGLCUDA, double, int64_t>(\n    const aten::CSRMatrix &csr, const NDArray weight, IdArray result);\n\n/** @brief Unweighted neighbor matching procedure (GPU version).\n * Instead of directly sample neighbors, we assign each neighbor\n * with a random weight. We use random weight for 2 reasons:\n *  1. Random sample for each node in GPU is expensive. Although\n *     we can perform a global group-wise (neighborhood of each\n *     node as a group) random permutation as in CPU version,\n *     it still cost too much compared to directly using random weights.\n *  2. Graph is sparse, thus neighborhood of each node is small,\n *     which is suitable for GPU implementation.\n */\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {\n  const int64_t num_edges = csr.indices->shape[0];\n  const auto &ctx = result->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  device->SetDevice(ctx);\n\n  // generate random weights\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  NDArray weight = NDArray::Empty(\n      {num_edges}, DGLDataType{kDGLFloat, sizeof(float) * 8, 1}, ctx);\n  float *weight_data = static_cast<float *>(weight->data);\n  uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);\n  auto num_threads = cuda::FindNumThreads(num_edges);\n  auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_edges, num_threads));\n  CUDA_KERNEL_CALL(\n      generate_uniform_kernel, num_blocks, num_threads, 0, stream, weight_data,\n      num_edges, seed);\n\n  WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result);\n}\ntemplate void NeighborMatching<kDGLCUDA, int32_t>(\n    const aten::CSRMatrix &csr, IdArray result);\ntemplate void NeighborMatching<kDGLCUDA, int64_t>(\n    const aten::CSRMatrix 
&csr, IdArray result);\n\n}  // namespace impl\n}  // namespace geometry\n}  // namespace dgl\n"
  },
  {
    "path": "src/geometry/cuda/geometry_op_impl.cu",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file geometry/cuda/geometry_op_impl.cu\n * @brief Geometry operator CUDA implementation\n */\n#include <dgl/array.h>\n\n#include \"../../c_api_common.h\"\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"../geometry_op.h\"\n\n#define THREADS 1024\n\nnamespace dgl {\nnamespace geometry {\nnamespace impl {\n\n/**\n * @brief Farthest Point Sampler without the need to compute all pairs of\n * distance.\n *\n * The input array has shape (N, d), where N is the number of points, and d is\n * the dimension. It consists of a (flattened) batch of point clouds.\n *\n * In each batch, the algorithm starts with the sample index specified by\n * ``start_idx``. Then for each point, we maintain the minimum to-sample\n * distance. Finally, we pick the point with the maximum such distance. This\n * process will be repeated for ``sample_points`` - 1 times.\n */\ntemplate <typename FloatType, typename IdType>\n__global__ void fps_kernel(\n    const FloatType* array_data, const int64_t batch_size,\n    const int64_t sample_points, const int64_t point_in_batch,\n    const int64_t dim, const IdType* start_idx, FloatType* dist_data,\n    IdType* ret_data) {\n  const int64_t thread_idx = threadIdx.x;\n  const int64_t batch_idx = blockIdx.x;\n\n  const int64_t array_start = point_in_batch * batch_idx;\n  const int64_t ret_start = sample_points * batch_idx;\n\n  __shared__ FloatType dist_max_ht[THREADS];\n  __shared__ int64_t dist_argmax_ht[THREADS];\n\n  // start with random initialization\n  if (thread_idx == 0) {\n    ret_data[ret_start] = (IdType)(start_idx[batch_idx]);\n  }\n\n  // sample the rest `sample_points - 1` points\n  for (auto i = 0; i < sample_points - 1; i++) {\n    __syncthreads();\n\n    // the last sampled point\n    int64_t sample_idx = (int64_t)(ret_data[ret_start + i]);\n    dist_argmax_ht[thread_idx] = 0;\n    dist_max_ht[thread_idx] = (FloatType)(-1.);\n\n    // multi-thread distance 
calculation\n    for (auto j = thread_idx; j < point_in_batch; j += THREADS) {\n      FloatType one_dist = (FloatType)(0.);\n      for (auto d = 0; d < dim; d++) {\n        FloatType tmp = array_data[(array_start + j) * dim + d] -\n                        array_data[(array_start + sample_idx) * dim + d];\n        one_dist += tmp * tmp;\n      }\n\n      if (i == 0 || dist_data[array_start + j] > one_dist) {\n        dist_data[array_start + j] = one_dist;\n      }\n\n      if (dist_data[array_start + j] > dist_max_ht[thread_idx]) {\n        dist_argmax_ht[thread_idx] = j;\n        dist_max_ht[thread_idx] = dist_data[array_start + j];\n      }\n    }\n\n    __syncthreads();\n\n    if (thread_idx == 0) {\n      FloatType best = dist_max_ht[0];\n      int64_t best_idx = dist_argmax_ht[0];\n      for (auto j = 1; j < THREADS; j++) {\n        if (dist_max_ht[j] > best) {\n          best = dist_max_ht[j];\n          best_idx = dist_argmax_ht[j];\n        }\n      }\n      ret_data[ret_start + i + 1] = (IdType)(best_idx);\n    }\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid FarthestPointSampler(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  const FloatType* array_data = static_cast<FloatType*>(array->data);\n\n  const int64_t point_in_batch = array->shape[0] / batch_size;\n  const int64_t dim = array->shape[1];\n\n  // return value\n  IdType* ret_data = static_cast<IdType*>(result->data);\n\n  // distance\n  FloatType* dist_data = static_cast<FloatType*>(dist->data);\n\n  // sample for each cloud in the batch\n  IdType* start_idx_data = static_cast<IdType*>(start_idx->data);\n  CUDA_CALL(cudaSetDevice(array->ctx.device_id));\n\n  CUDA_KERNEL_CALL(\n      fps_kernel, batch_size, THREADS, 0, stream, array_data, batch_size,\n      sample_points, point_in_batch, dim, start_idx_data, dist_data, 
ret_data);\n}\n\ntemplate void FarthestPointSampler<kDGLCUDA, float, int32_t>(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result);\ntemplate void FarthestPointSampler<kDGLCUDA, float, int64_t>(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result);\ntemplate void FarthestPointSampler<kDGLCUDA, double, int32_t>(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result);\ntemplate void FarthestPointSampler<kDGLCUDA, double, int64_t>(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result);\n\n}  // namespace impl\n}  // namespace geometry\n}  // namespace dgl\n"
  },
  {
    "path": "src/geometry/geometry.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file geometry/geometry.cc\n * @brief DGL geometry utilities implementation\n */\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/ndarray.h>\n\n#include \"../array/check.h\"\n#include \"../c_api_common.h\"\n#include \"./geometry_op.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace geometry {\n\nvoid FarthestPointSampler(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result) {\n  CHECK_EQ(array->ctx, result->ctx)\n      << \"Array and the result should be on the same device.\";\n  CHECK_EQ(array->shape[0], dist->shape[0])\n      << \"Shape of array and dist mismatch\";\n  CHECK_EQ(start_idx->shape[0], batch_size)\n      << \"Shape of start_idx and batch_size mismatch\";\n  CHECK_EQ(result->shape[0], batch_size * sample_points)\n      << \"Invalid shape of result\";\n\n  ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, \"values\", {\n    ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {\n      ATEN_XPU_SWITCH_CUDA(\n          array->ctx.device_type, XPU, \"FarthestPointSampler\", {\n            impl::FarthestPointSampler<XPU, FloatType, IdType>(\n                array, batch_size, sample_points, dist, start_idx, result);\n          });\n    });\n  });\n}\n\nvoid NeighborMatching(\n    HeteroGraphPtr graph, const NDArray weight, IdArray result) {\n  if (!aten::IsNullArray(weight)) {\n    ATEN_XPU_SWITCH_CUDA(\n        graph->Context().device_type, XPU, \"NeighborMatching\", {\n          ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, \"weight\", {\n            ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {\n              impl::WeightedNeighborMatching<XPU, FloatType, IdType>(\n                  graph->GetCSRMatrix(0), weight, result);\n            });\n          });\n        });\n  } else {\n    ATEN_XPU_SWITCH_CUDA(\n        graph->Context().device_type, XPU, 
\"NeighborMatching\", {\n          ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {\n            impl::NeighborMatching<XPU, IdType>(graph->GetCSRMatrix(0), result);\n          });\n        });\n  }\n}\n\n///////////////////////// C APIs /////////////////////////\n\nDGL_REGISTER_GLOBAL(\"geometry._CAPI_FarthestPointSampler\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const NDArray data = args[0];\n      const int64_t batch_size = args[1];\n      const int64_t sample_points = args[2];\n      NDArray dist = args[3];\n      IdArray start_idx = args[4];\n      IdArray result = args[5];\n\n      FarthestPointSampler(\n          data, batch_size, sample_points, dist, start_idx, result);\n    });\n\nDGL_REGISTER_GLOBAL(\"geometry._CAPI_NeighborMatching\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef graph = args[0];\n      const NDArray weight = args[1];\n      IdArray result = args[2];\n\n      // sanity check\n      aten::CheckCtx(\n          graph->Context(), {weight, result}, {\"edge_weight\", \"result\"});\n      aten::CheckContiguous({weight, result}, {\"edge_weight\", \"result\"});\n      CHECK_EQ(graph->NumEdgeTypes(), 1)\n          << \"homogeneous graph has only one edge type\";\n      CHECK_EQ(result->ndim, 1) << \"result should be an 1D tensor.\";\n      auto pair = graph->meta_graph()->FindEdge(0);\n      const dgl_type_t node_type = pair.first;\n      CHECK_EQ(graph->NumVertices(node_type), result->shape[0])\n          << \"The number of nodes should be the same as the length of result \"\n             \"tensor.\";\n      if (!aten::IsNullArray(weight)) {\n        CHECK_EQ(weight->ndim, 1) << \"weight should be an 1D tensor.\";\n        CHECK_EQ(graph->NumEdges(0), weight->shape[0])\n            << \"number of edges in graph should be the same \"\n            << \"as the length of edge weight tensor.\";\n      }\n\n      // call implementation\n      NeighborMatching(graph.sptr(), weight, result);\n    });\n\n}  // 
namespace geometry\n}  // namespace dgl\n"
  },
  {
    "path": "src/geometry/geometry_op.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file geometry/geometry_op.h\n * @brief Geometry operator templates\n */\n#ifndef DGL_GEOMETRY_GEOMETRY_OP_H_\n#define DGL_GEOMETRY_GEOMETRY_OP_H_\n\n#include <dgl/array.h>\n\nnamespace dgl {\nnamespace geometry {\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid FarthestPointSampler(\n    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,\n    IdArray start_idx, IdArray result);\n\n/** @brief Implementation of weighted neighbor matching process of edge\n * coarsening used in Metis and Graclus for homogeneous graph coarsening. This\n * procedure keeps picking an unmarked vertex and matching it with one of its\n * unmarked neighbors (that maximizes its edge weight) until no match can be\n * done.\n */\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid WeightedNeighborMatching(\n    const aten::CSRMatrix &csr, const NDArray weight, IdArray result);\n\n/** @brief Implementation of neighbor matching process of edge coarsening used\n * in Metis and Graclus for homogeneous graph coarsening. This procedure keeps\n * picking an unmarked vertex and matching it with one of its unmarked neighbors\n * (chosen at random) until no match can be done.\n */\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid NeighborMatching(const aten::CSRMatrix &csr, IdArray result);\n\n}  // namespace impl\n}  // namespace geometry\n}  // namespace dgl\n\n#endif  // DGL_GEOMETRY_GEOMETRY_OP_H_\n"
  },
  {
    "path": "src/graph/creators.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file graph/creators.cc\n * @brief Functions for constructing graphs.\n */\n#include \"./heterograph.h\"\nusing namespace dgl::runtime;\n\nnamespace dgl {\n\n// creator implementation\nHeteroGraphPtr CreateHeteroGraph(\n    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,\n    const std::vector<int64_t>& num_nodes_per_type) {\n  return HeteroGraphPtr(\n      new HeteroGraph(meta_graph, rel_graphs, num_nodes_per_type));\n}\n\nHeteroGraphPtr CreateFromCOO(\n    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,\n    IdArray col, bool row_sorted, bool col_sorted, dgl_format_code_t formats) {\n  auto unit_g = UnitGraph::CreateFromCOO(\n      num_vtypes, num_src, num_dst, row, col, row_sorted, col_sorted, formats);\n  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));\n}\n\nHeteroGraphPtr CreateFromCOO(\n    int64_t num_vtypes, const aten::COOMatrix& mat, dgl_format_code_t formats) {\n  auto unit_g = UnitGraph::CreateFromCOO(num_vtypes, mat, formats);\n  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));\n}\n\nHeteroGraphPtr CreateFromCSR(\n    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,\n    IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {\n  auto unit_g = UnitGraph::CreateFromCSR(\n      num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats);\n  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));\n}\n\nHeteroGraphPtr CreateFromCSR(\n    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {\n  auto unit_g = UnitGraph::CreateFromCSR(num_vtypes, mat, formats);\n  auto ret = HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));\n  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));\n}\n\nHeteroGraphPtr CreateFromCSC(\n    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,\n    IdArray 
indices, IdArray edge_ids, dgl_format_code_t formats) {\n  auto unit_g = UnitGraph::CreateFromCSC(\n      num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats);\n  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));\n}\n\nHeteroGraphPtr CreateFromCSC(\n    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {\n  auto unit_g = UnitGraph::CreateFromCSC(num_vtypes, mat, formats);\n  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));\n}\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/gk_ops.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file graph/gk_ops.cc\n * @brief Graph operation implemented in GKlib\n */\n\n#if !defined(_WIN32)\n#include <GKlib.h>\n#endif  // !defined(_WIN32)\n\n#include <dgl/graph_op.h>\n\nnamespace dgl {\n\n#if !defined(_WIN32)\n\n/**\n * Convert DGL CSR to GKLib CSR.\n * GKLib CSR actually stores a CSR object and a CSC object of a graph.\n * @param mat the DGL CSR matrix.\n * @param is_row the input DGL matrix is CSR or CSC.\n * @return a GKLib CSR.\n */\ngk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row) {\n  // TODO(zhengda) The conversion will be zero-copy in the future.\n  CHECK_EQ(mat.indptr->dtype.bits, sizeof(dgl_id_t) * CHAR_BIT);\n  CHECK_EQ(mat.indices->dtype.bits, sizeof(dgl_id_t) * CHAR_BIT);\n  const dgl_id_t *indptr = static_cast<dgl_id_t *>(mat.indptr->data);\n  const dgl_id_t *indices = static_cast<dgl_id_t *>(mat.indices->data);\n\n  gk_csr_t *gk_csr = gk_csr_Create();\n  gk_csr->nrows = mat.num_rows;\n  gk_csr->ncols = mat.num_cols;\n  uint64_t nnz = mat.indices->shape[0];\n  auto gk_indptr = gk_csr->rowptr;\n  auto gk_indices = gk_csr->rowind;\n  size_t num_ptrs;\n  if (is_row) {\n    num_ptrs = gk_csr->nrows + 1;\n    gk_indptr = gk_csr->rowptr = gk_zmalloc(\n        gk_csr->nrows + 1,\n        const_cast<char *>(\"gk_csr_ExtractPartition: rowptr\"));\n    gk_indices = gk_csr->rowind =\n        gk_imalloc(nnz, const_cast<char *>(\"gk_csr_ExtractPartition: rowind\"));\n  } else {\n    num_ptrs = gk_csr->ncols + 1;\n    gk_indptr = gk_csr->colptr = gk_zmalloc(\n        gk_csr->ncols + 1,\n        const_cast<char *>(\"gk_csr_ExtractPartition: colptr\"));\n    gk_indices = gk_csr->colind =\n        gk_imalloc(nnz, const_cast<char *>(\"gk_csr_ExtractPartition: colind\"));\n  }\n\n  for (size_t i = 0; i < num_ptrs; i++) {\n    gk_indptr[i] = indptr[i];\n  }\n  for (size_t i = 0; i < nnz; i++) {\n    gk_indices[i] = indices[i];\n  }\n  return gk_csr;\n}\n\n/**\n * Convert GKLib CSR 
to DGL CSR.\n * GKLib CSR actually stores a CSR object and a CSC object of a graph.\n * @param gk_csr the GKLib CSR.\n * @param is_row specify whether to convert the CSR or CSC object of GKLib CSR.\n * @return a DGL CSR matrix.\n */\naten::CSRMatrix Convert2DGLCsr(gk_csr_t *gk_csr, bool is_row) {\n  // TODO(zhengda) The conversion will be zero-copy in the future.\n  size_t num_ptrs;\n  size_t nnz;\n  auto gk_indptr = gk_csr->rowptr;\n  auto gk_indices = gk_csr->rowind;\n  if (is_row) {\n    num_ptrs = gk_csr->nrows + 1;\n    nnz = gk_csr->rowptr[num_ptrs - 1];\n    gk_indptr = gk_csr->rowptr;\n    gk_indices = gk_csr->rowind;\n  } else {\n    num_ptrs = gk_csr->ncols + 1;\n    nnz = gk_csr->colptr[num_ptrs - 1];\n    gk_indptr = gk_csr->colptr;\n    gk_indices = gk_csr->colind;\n  }\n\n  IdArray indptr_arr = aten::NewIdArray(num_ptrs);\n  IdArray indices_arr = aten::NewIdArray(nnz);\n  IdArray eids_arr = aten::NewIdArray(nnz);\n\n  dgl_id_t *indptr = static_cast<dgl_id_t *>(indptr_arr->data);\n  dgl_id_t *indices = static_cast<dgl_id_t *>(indices_arr->data);\n  dgl_id_t *eids = static_cast<dgl_id_t *>(eids_arr->data);\n  for (size_t i = 0; i < num_ptrs; i++) {\n    indptr[i] = gk_indptr[i];\n  }\n  for (size_t i = 0; i < nnz; i++) {\n    indices[i] = gk_indices[i];\n    eids[i] = i;\n  }\n\n  return aten::CSRMatrix(\n      gk_csr->nrows, gk_csr->ncols, indptr_arr, indices_arr, eids_arr);\n}\n\n#endif  // !defined(_WIN32)\n\nGraphPtr GraphOp::ToBidirectedSimpleImmutableGraph(ImmutableGraphPtr ig) {\n#if !defined(_WIN32)\n  // TODO(zhengda) should we get whatever CSR exists in the graph.\n  CSRPtr csr = ig->GetInCSR();\n  gk_csr_t *gk_csr = Convert2GKCsr(csr->ToCSRMatrix(), true);\n  gk_csr_t *sym_gk_csr = gk_csr_MakeSymmetric(gk_csr, GK_CSR_SYM_SUM);\n  auto mat = Convert2DGLCsr(sym_gk_csr, true);\n  gk_csr_Free(&gk_csr);\n  gk_csr_Free(&sym_gk_csr);\n\n  // This is a symmetric graph now. 
The in-csr and out-csr are the same.\n  csr = CSRPtr(new CSR(mat.indptr, mat.indices, mat.data));\n  return GraphPtr(new ImmutableGraph(csr, csr));\n#else\n  return GraphPtr();\n#endif  // !defined(_WIN32)\n}\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/graph.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/graph.cc\n * @brief DGL graph index implementation\n */\n#include <dgl/graph.h>\n#include <dgl/sampler.h>\n\n#include <algorithm>\n#include <functional>\n#include <set>\n#include <tuple>\n#include <unordered_map>\n\n#include \"../c_api_common.h\"\n\nnamespace dgl {\n\nGraph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes) {\n  CHECK(aten::IsValidIdArray(src_ids));\n  CHECK(aten::IsValidIdArray(dst_ids));\n  this->AddVertices(num_nodes);\n  num_edges_ = src_ids->shape[0];\n  CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0])\n      << \"vectors in COO must have the same length\";\n  const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_ids->data);\n  const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_ids->data);\n  all_edges_src_.reserve(num_edges_);\n  all_edges_dst_.reserve(num_edges_);\n  for (uint64_t i = 0; i < num_edges_; i++) {\n    auto src = src_data[i];\n    auto dst = dst_data[i];\n    CHECK(HasVertex(src) && HasVertex(dst))\n        << \"Invalid vertices: src=\" << src << \" dst=\" << dst;\n\n    adjlist_[src].succ.push_back(dst);\n    adjlist_[src].edge_id.push_back(i);\n    reverse_adjlist_[dst].succ.push_back(src);\n    reverse_adjlist_[dst].edge_id.push_back(i);\n\n    all_edges_src_.push_back(src);\n    all_edges_dst_.push_back(dst);\n  }\n}\n\nbool Graph::IsMultigraph() const {\n  if (num_edges_ <= 1) {\n    return false;\n  }\n\n  typedef std::pair<int64_t, int64_t> Pair;\n  std::vector<Pair> pairs;\n  pairs.reserve(num_edges_);\n  for (uint64_t eid = 0; eid < num_edges_; ++eid) {\n    pairs.emplace_back(all_edges_src_[eid], all_edges_dst_[eid]);\n  }\n  // sort according to src and dst ids\n  std::sort(pairs.begin(), pairs.end(), [](const Pair& t1, const Pair& t2) {\n    return std::get<0>(t1) < std::get<0>(t2) ||\n           (std::get<0>(t1) == std::get<0>(t2) &&\n            std::get<1>(t1) < std::get<1>(t2));\n  });\n  for (uint64_t eid = 0; eid 
< num_edges_ - 1; ++eid) {\n    // As src and dst are all sorted, we only need to compare i and i+1\n    if (std::get<0>(pairs[eid]) == std::get<0>(pairs[eid + 1]) &&\n        std::get<1>(pairs[eid]) == std::get<1>(pairs[eid + 1]))\n      return true;\n  }\n\n  return false;\n}\n\nvoid Graph::AddVertices(uint64_t num_vertices) {\n  CHECK(!read_only_) << \"Graph is read-only. Mutations are not allowed.\";\n  adjlist_.resize(adjlist_.size() + num_vertices);\n  reverse_adjlist_.resize(reverse_adjlist_.size() + num_vertices);\n}\n\nvoid Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {\n  CHECK(!read_only_) << \"Graph is read-only. Mutations are not allowed.\";\n  CHECK(HasVertex(src) && HasVertex(dst))\n      << \"Invalid vertices: src=\" << src << \" dst=\" << dst;\n\n  dgl_id_t eid = num_edges_++;\n\n  adjlist_[src].succ.push_back(dst);\n  adjlist_[src].edge_id.push_back(eid);\n  reverse_adjlist_[dst].succ.push_back(src);\n  reverse_adjlist_[dst].edge_id.push_back(eid);\n\n  all_edges_src_.push_back(src);\n  all_edges_dst_.push_back(dst);\n}\n\nvoid Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {\n  CHECK(!read_only_) << \"Graph is read-only. 
Mutations are not allowed.\";\n  CHECK(aten::IsValidIdArray(src_ids)) << \"Invalid src id array.\";\n  CHECK(aten::IsValidIdArray(dst_ids)) << \"Invalid dst id array.\";\n  const auto srclen = src_ids->shape[0];\n  const auto dstlen = dst_ids->shape[0];\n  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);\n  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);\n  if (srclen == 1) {\n    // one-many\n    for (int64_t i = 0; i < dstlen; ++i) {\n      AddEdge(src_data[0], dst_data[i]);\n    }\n  } else if (dstlen == 1) {\n    // many-one\n    for (int64_t i = 0; i < srclen; ++i) {\n      AddEdge(src_data[i], dst_data[0]);\n    }\n  } else {\n    // many-many\n    CHECK(srclen == dstlen) << \"Invalid src and dst id array.\";\n    for (int64_t i = 0; i < srclen; ++i) {\n      AddEdge(src_data[i], dst_data[i]);\n    }\n  }\n}\n\nBoolArray Graph::HasVertices(IdArray vids) const {\n  CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n  const auto len = vids->shape[0];\n  BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx);\n  const int64_t* vid_data = static_cast<int64_t*>(vids->data);\n  int64_t* rst_data = static_cast<int64_t*>(rst->data);\n  const int64_t nverts = NumVertices();\n  for (int64_t i = 0; i < len; ++i) {\n    rst_data[i] = (vid_data[i] < nverts && vid_data[i] >= 0) ? 
1 : 0;\n  }\n  return rst;\n}\n\n// O(E)\nbool Graph::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {\n  if (!HasVertex(src) || !HasVertex(dst)) return false;\n  const auto& succ = adjlist_[src].succ;\n  return std::find(succ.begin(), succ.end(), dst) != succ.end();\n}\n\n// O(E*k) pretty slow\nBoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {\n  CHECK(aten::IsValidIdArray(src_ids)) << \"Invalid src id array.\";\n  CHECK(aten::IsValidIdArray(dst_ids)) << \"Invalid dst id array.\";\n  const auto srclen = src_ids->shape[0];\n  const auto dstlen = dst_ids->shape[0];\n  const auto rstlen = std::max(srclen, dstlen);\n  BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);\n  int64_t* rst_data = static_cast<int64_t*>(rst->data);\n  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);\n  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);\n  if (srclen == 1) {\n    // one-many\n    for (int64_t i = 0; i < dstlen; ++i) {\n      rst_data[i] = HasEdgeBetween(src_data[0], dst_data[i]) ? 1 : 0;\n    }\n  } else if (dstlen == 1) {\n    // many-one\n    for (int64_t i = 0; i < srclen; ++i) {\n      rst_data[i] = HasEdgeBetween(src_data[i], dst_data[0]) ? 1 : 0;\n    }\n  } else {\n    // many-many\n    CHECK(srclen == dstlen) << \"Invalid src and dst id array.\";\n    for (int64_t i = 0; i < srclen; ++i) {\n      rst_data[i] = HasEdgeBetween(src_data[i], dst_data[i]) ? 
1 : 0;\n    }\n  }\n  return rst;\n}\n\n// The data is copy-out; support zero-copy?\nIdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {\n  CHECK(HasVertex(vid)) << \"invalid vertex: \" << vid;\n  CHECK(radius >= 1) << \"invalid radius: \" << radius;\n  std::set<dgl_id_t> vset;\n\n  for (auto& it : reverse_adjlist_[vid].succ) vset.insert(it);\n\n  const int64_t len = vset.size();\n  IdArray rst = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  int64_t* rst_data = static_cast<int64_t*>(rst->data);\n\n  std::copy(vset.begin(), vset.end(), rst_data);\n  return rst;\n}\n\n// The data is copy-out; support zero-copy?\nIdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {\n  CHECK(HasVertex(vid)) << \"invalid vertex: \" << vid;\n  CHECK(radius >= 1) << \"invalid radius: \" << radius;\n  std::set<dgl_id_t> vset;\n\n  for (auto& it : adjlist_[vid].succ) vset.insert(it);\n\n  const int64_t len = vset.size();\n  IdArray rst = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  int64_t* rst_data = static_cast<int64_t*>(rst->data);\n\n  std::copy(vset.begin(), vset.end(), rst_data);\n  return rst;\n}\n\n// O(E)\nIdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {\n  CHECK(HasVertex(src) && HasVertex(dst))\n      << \"invalid edge: \" << src << \" -> \" << dst;\n\n  const auto& succ = adjlist_[src].succ;\n  std::vector<dgl_id_t> edgelist;\n\n  for (size_t i = 0; i < succ.size(); ++i) {\n    if (succ[i] == dst) edgelist.push_back(adjlist_[src].edge_id[i]);\n  }\n\n  // FIXME: signed?  
Also it seems that we are using int64_t everywhere...\n  const int64_t len = edgelist.size();\n  IdArray rst = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  // FIXME: signed?\n  int64_t* rst_data = static_cast<int64_t*>(rst->data);\n\n  std::copy(edgelist.begin(), edgelist.end(), rst_data);\n\n  return rst;\n}\n\n// O(E*k) pretty slow\nEdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {\n  CHECK(aten::IsValidIdArray(src_ids)) << \"Invalid src id array.\";\n  CHECK(aten::IsValidIdArray(dst_ids)) << \"Invalid dst id array.\";\n  const auto srclen = src_ids->shape[0];\n  const auto dstlen = dst_ids->shape[0];\n  int64_t i, j;\n\n  CHECK((srclen == dstlen) || (srclen == 1) || (dstlen == 1))\n      << \"Invalid src and dst id array.\";\n\n  const int64_t src_stride = (srclen == 1 && dstlen != 1) ? 0 : 1;\n  const int64_t dst_stride = (dstlen == 1 && srclen != 1) ? 0 : 1;\n  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);\n  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);\n\n  std::vector<dgl_id_t> src, dst, eid;\n\n  for (i = 0, j = 0; i < srclen && j < dstlen;\n       i += src_stride, j += dst_stride) {\n    const dgl_id_t src_id = src_data[i], dst_id = dst_data[j];\n    CHECK(HasVertex(src_id) && HasVertex(dst_id))\n        << \"invalid edge: \" << src_id << \" -> \" << dst_id;\n    const auto& succ = adjlist_[src_id].succ;\n    for (size_t k = 0; k < succ.size(); ++k) {\n      if (succ[k] == dst_id) {\n        src.push_back(src_id);\n        dst.push_back(dst_id);\n        eid.push_back(adjlist_[src_id].edge_id[k]);\n      }\n    }\n  }\n\n  int64_t rstlen = src.size();\n  IdArray rst_src = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);\n  IdArray rst_dst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);\n  IdArray rst_eid = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);\n  int64_t* rst_src_data = static_cast<int64_t*>(rst_src->data);\n  int64_t* 
rst_dst_data = static_cast<int64_t*>(rst_dst->data);\n  int64_t* rst_eid_data = static_cast<int64_t*>(rst_eid->data);\n\n  std::copy(src.begin(), src.end(), rst_src_data);\n  std::copy(dst.begin(), dst.end(), rst_dst_data);\n  std::copy(eid.begin(), eid.end(), rst_eid_data);\n\n  return EdgeArray{rst_src, rst_dst, rst_eid};\n}\n\nEdgeArray Graph::FindEdges(IdArray eids) const {\n  CHECK(aten::IsValidIdArray(eids)) << \"Invalid edge id array\";\n  int64_t len = eids->shape[0];\n\n  IdArray rst_src = IdArray::Empty({len}, eids->dtype, eids->ctx);\n  IdArray rst_dst = IdArray::Empty({len}, eids->dtype, eids->ctx);\n  IdArray rst_eid = IdArray::Empty({len}, eids->dtype, eids->ctx);\n  int64_t* eid_data = static_cast<int64_t*>(eids->data);\n  int64_t* rst_src_data = static_cast<int64_t*>(rst_src->data);\n  int64_t* rst_dst_data = static_cast<int64_t*>(rst_dst->data);\n  int64_t* rst_eid_data = static_cast<int64_t*>(rst_eid->data);\n\n  for (uint64_t i = 0; i < (uint64_t)len; ++i) {\n    dgl_id_t eid = eid_data[i];\n    if (eid >= num_edges_) LOG(FATAL) << \"invalid edge id:\" << eid;\n\n    rst_src_data[i] = all_edges_src_[eid];\n    rst_dst_data[i] = all_edges_dst_[eid];\n    rst_eid_data[i] = eid;\n  }\n\n  return EdgeArray{rst_src, rst_dst, rst_eid};\n}\n\n// O(E)\nEdgeArray Graph::InEdges(dgl_id_t vid) const {\n  CHECK(HasVertex(vid)) << \"invalid vertex: \" << vid;\n  const int64_t len = reverse_adjlist_[vid].succ.size();\n  IdArray src = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  IdArray dst = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  IdArray eid = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  int64_t* src_data = static_cast<int64_t*>(src->data);\n  int64_t* dst_data = static_cast<int64_t*>(dst->data);\n  int64_t* eid_data = static_cast<int64_t*>(eid->data);\n  for (int64_t i = 0; i < len; ++i) {\n    src_data[i] = 
reverse_adjlist_[vid].succ[i];\n    eid_data[i] = reverse_adjlist_[vid].edge_id[i];\n  }\n  std::fill(dst_data, dst_data + len, vid);\n  return EdgeArray{src, dst, eid};\n}\n\n// O(E)\nEdgeArray Graph::InEdges(IdArray vids) const {\n  CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n  const auto len = vids->shape[0];\n  const int64_t* vid_data = static_cast<int64_t*>(vids->data);\n  int64_t rstlen = 0;\n  for (int64_t i = 0; i < len; ++i) {\n    CHECK(HasVertex(vid_data[i])) << \"Invalid vertex: \" << vid_data[i];\n    rstlen += reverse_adjlist_[vid_data[i]].succ.size();\n  }\n  IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);\n  IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);\n  IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);\n  int64_t* src_ptr = static_cast<int64_t*>(src->data);\n  int64_t* dst_ptr = static_cast<int64_t*>(dst->data);\n  int64_t* eid_ptr = static_cast<int64_t*>(eid->data);\n  for (int64_t i = 0; i < len; ++i) {\n    const auto& pred = reverse_adjlist_[vid_data[i]].succ;\n    const auto& eids = reverse_adjlist_[vid_data[i]].edge_id;\n    for (size_t j = 0; j < pred.size(); ++j) {\n      *(src_ptr++) = pred[j];\n      *(dst_ptr++) = vid_data[i];\n      *(eid_ptr++) = eids[j];\n    }\n  }\n  return EdgeArray{src, dst, eid};\n}\n\n// O(E)\nEdgeArray Graph::OutEdges(dgl_id_t vid) const {\n  CHECK(HasVertex(vid)) << \"invalid vertex: \" << vid;\n  const int64_t len = adjlist_[vid].succ.size();\n  IdArray src = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  IdArray dst = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  IdArray eid = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  int64_t* src_data = static_cast<int64_t*>(src->data);\n  int64_t* dst_data = static_cast<int64_t*>(dst->data);\n  int64_t* eid_data = static_cast<int64_t*>(eid->data);\n  for (int64_t i = 0; i < 
len; ++i) {\n    dst_data[i] = adjlist_[vid].succ[i];\n    eid_data[i] = adjlist_[vid].edge_id[i];\n  }\n  std::fill(src_data, src_data + len, vid);\n  return EdgeArray{src, dst, eid};\n}\n\n// O(E)\nEdgeArray Graph::OutEdges(IdArray vids) const {\n  CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n  const auto len = vids->shape[0];\n  const int64_t* vid_data = static_cast<int64_t*>(vids->data);\n  int64_t rstlen = 0;\n  for (int64_t i = 0; i < len; ++i) {\n    CHECK(HasVertex(vid_data[i])) << \"Invalid vertex: \" << vid_data[i];\n    rstlen += adjlist_[vid_data[i]].succ.size();\n  }\n  IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);\n  IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);\n  IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);\n  int64_t* src_ptr = static_cast<int64_t*>(src->data);\n  int64_t* dst_ptr = static_cast<int64_t*>(dst->data);\n  int64_t* eid_ptr = static_cast<int64_t*>(eid->data);\n  for (int64_t i = 0; i < len; ++i) {\n    const auto& succ = adjlist_[vid_data[i]].succ;\n    const auto& eids = adjlist_[vid_data[i]].edge_id;\n    for (size_t j = 0; j < succ.size(); ++j) {\n      *(src_ptr++) = vid_data[i];\n      *(dst_ptr++) = succ[j];\n      *(eid_ptr++) = eids[j];\n    }\n  }\n  return EdgeArray{src, dst, eid};\n}\n\n// O(E*log(E)) if sort is required; otherwise, O(E)\nEdgeArray Graph::Edges(const std::string& order) const {\n  const int64_t len = num_edges_;\n  IdArray src = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  IdArray dst = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  IdArray eid = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n\n  if (order == \"srcdst\") {\n    typedef std::tuple<int64_t, int64_t, int64_t> Tuple;\n    std::vector<Tuple> tuples;\n    tuples.reserve(len);\n    for (uint64_t eid = 0; eid < num_edges_; ++eid) {\n      
tuples.emplace_back(all_edges_src_[eid], all_edges_dst_[eid], eid);\n    }\n    // sort according to src and dst ids\n    std::sort(\n        tuples.begin(), tuples.end(), [](const Tuple& t1, const Tuple& t2) {\n          return std::get<0>(t1) < std::get<0>(t2) ||\n                 (std::get<0>(t1) == std::get<0>(t2) &&\n                  std::get<1>(t1) < std::get<1>(t2));\n        });\n\n    // make return arrays\n    int64_t* src_ptr = static_cast<int64_t*>(src->data);\n    int64_t* dst_ptr = static_cast<int64_t*>(dst->data);\n    int64_t* eid_ptr = static_cast<int64_t*>(eid->data);\n    for (size_t i = 0; i < tuples.size(); ++i) {\n      src_ptr[i] = std::get<0>(tuples[i]);\n      dst_ptr[i] = std::get<1>(tuples[i]);\n      eid_ptr[i] = std::get<2>(tuples[i]);\n    }\n  } else {\n    int64_t* src_ptr = static_cast<int64_t*>(src->data);\n    int64_t* dst_ptr = static_cast<int64_t*>(dst->data);\n    int64_t* eid_ptr = static_cast<int64_t*>(eid->data);\n    std::copy(all_edges_src_.begin(), all_edges_src_.end(), src_ptr);\n    std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), dst_ptr);\n    for (uint64_t eid = 0; eid < num_edges_; ++eid) {\n      eid_ptr[eid] = eid;\n    }\n  }\n\n  return EdgeArray{src, dst, eid};\n}\n\n// O(V)\nDegreeArray Graph::InDegrees(IdArray vids) const {\n  CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n  const auto len = vids->shape[0];\n  const int64_t* vid_data = static_cast<int64_t*>(vids->data);\n  DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);\n  int64_t* rst_data = static_cast<int64_t*>(rst->data);\n  for (int64_t i = 0; i < len; ++i) {\n    const auto vid = vid_data[i];\n    CHECK(HasVertex(vid)) << \"Invalid vertex: \" << vid;\n    rst_data[i] = reverse_adjlist_[vid].succ.size();\n  }\n  return rst;\n}\n\n// O(V)\nDegreeArray Graph::OutDegrees(IdArray vids) const {\n  CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n  const auto len = vids->shape[0];\n  const 
int64_t* vid_data = static_cast<int64_t*>(vids->data);\n  DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);\n  int64_t* rst_data = static_cast<int64_t*>(rst->data);\n  for (int64_t i = 0; i < len; ++i) {\n    const auto vid = vid_data[i];\n    CHECK(HasVertex(vid)) << \"Invalid vertex: \" << vid;\n    rst_data[i] = adjlist_[vid].succ.size();\n  }\n  return rst;\n}\n\nSubgraph Graph::VertexSubgraph(IdArray vids) const {\n  CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n  const auto len = vids->shape[0];\n  std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;\n  std::vector<dgl_id_t> edges;\n  const int64_t* vid_data = static_cast<int64_t*>(vids->data);\n  for (int64_t i = 0; i < len; ++i) {\n    oldv2newv[vid_data[i]] = i;\n  }\n  Subgraph rst;\n  rst.graph = std::make_shared<Graph>();\n  rst.induced_vertices = vids;\n  rst.graph->AddVertices(len);\n  for (int64_t i = 0; i < len; ++i) {\n    const dgl_id_t oldvid = vid_data[i];\n    const dgl_id_t newvid = i;\n    for (size_t j = 0; j < adjlist_[oldvid].succ.size(); ++j) {\n      const dgl_id_t oldsucc = adjlist_[oldvid].succ[j];\n      if (oldv2newv.count(oldsucc)) {\n        const dgl_id_t newsucc = oldv2newv[oldsucc];\n        edges.push_back(adjlist_[oldvid].edge_id[j]);\n        rst.graph->AddEdge(newvid, newsucc);\n      }\n    }\n  }\n  rst.induced_edges = IdArray::Empty(\n      {static_cast<int64_t>(edges.size())}, vids->dtype, vids->ctx);\n  std::copy(\n      edges.begin(), edges.end(),\n      static_cast<int64_t*>(rst.induced_edges->data));\n  return rst;\n}\n\nSubgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {\n  CHECK(aten::IsValidIdArray(eids)) << \"Invalid edge id array.\";\n  const auto len = eids->shape[0];\n  std::vector<dgl_id_t> nodes;\n  const int64_t* eid_data = static_cast<int64_t*>(eids->data);\n\n  Subgraph rst;\n  if (!preserve_nodes) {\n    std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;\n\n    for (int64_t i = 0; i < len; ++i) {\n 
     const dgl_id_t src_id = all_edges_src_[eid_data[i]];\n      const dgl_id_t dst_id = all_edges_dst_[eid_data[i]];\n      if (oldv2newv.insert(std::make_pair(src_id, oldv2newv.size())).second)\n        nodes.push_back(src_id);\n      if (oldv2newv.insert(std::make_pair(dst_id, oldv2newv.size())).second)\n        nodes.push_back(dst_id);\n    }\n\n    rst.graph = std::make_shared<Graph>();\n    rst.induced_edges = eids;\n    rst.graph->AddVertices(nodes.size());\n\n    for (int64_t i = 0; i < len; ++i) {\n      const dgl_id_t src_id = all_edges_src_[eid_data[i]];\n      const dgl_id_t dst_id = all_edges_dst_[eid_data[i]];\n      rst.graph->AddEdge(oldv2newv[src_id], oldv2newv[dst_id]);\n    }\n\n    rst.induced_vertices = IdArray::Empty(\n        {static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);\n    std::copy(\n        nodes.begin(), nodes.end(),\n        static_cast<int64_t*>(rst.induced_vertices->data));\n  } else {\n    rst.graph = std::make_shared<Graph>();\n    rst.induced_edges = eids;\n    rst.graph->AddVertices(NumVertices());\n\n    for (int64_t i = 0; i < len; ++i) {\n      dgl_id_t src_id = all_edges_src_[eid_data[i]];\n      dgl_id_t dst_id = all_edges_dst_[eid_data[i]];\n      rst.graph->AddEdge(src_id, dst_id);\n    }\n\n    for (uint64_t i = 0; i < NumVertices(); ++i) nodes.push_back(i);\n\n    rst.induced_vertices = IdArray::Empty(\n        {static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);\n    std::copy(\n        nodes.begin(), nodes.end(),\n        static_cast<int64_t*>(rst.induced_vertices->data));\n  }\n\n  return rst;\n}\n\nstd::vector<IdArray> Graph::GetAdj(\n    bool transpose, const std::string& fmt) const {\n  uint64_t num_edges = NumEdges();\n  uint64_t num_nodes = NumVertices();\n  if (fmt == \"coo\") {\n    IdArray idx = IdArray::Empty(\n        {2 * static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},\n        DGLContext{kDGLCPU, 0});\n    int64_t* idx_data = static_cast<int64_t*>(idx->data);\n    
if (transpose) {\n      std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data);\n      std::copy(\n          all_edges_dst_.begin(), all_edges_dst_.end(), idx_data + num_edges);\n    } else {\n      std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data);\n      std::copy(\n          all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges);\n    }\n    IdArray eid = IdArray::Empty(\n        {static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},\n        DGLContext{kDGLCPU, 0});\n    int64_t* eid_data = static_cast<int64_t*>(eid->data);\n    for (uint64_t eid = 0; eid < num_edges; ++eid) {\n      eid_data[eid] = eid;\n    }\n    return std::vector<IdArray>{idx, eid};\n  } else if (fmt == \"csr\") {\n    IdArray indptr = IdArray::Empty(\n        {static_cast<int64_t>(num_nodes) + 1}, DGLDataType{kDGLInt, 64, 1},\n        DGLContext{kDGLCPU, 0});\n    IdArray indices = IdArray::Empty(\n        {static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},\n        DGLContext{kDGLCPU, 0});\n    IdArray eid = IdArray::Empty(\n        {static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},\n        DGLContext{kDGLCPU, 0});\n    int64_t* indptr_data = static_cast<int64_t*>(indptr->data);\n    int64_t* indices_data = static_cast<int64_t*>(indices->data);\n    int64_t* eid_data = static_cast<int64_t*>(eid->data);\n    const AdjacencyList* adjlist;\n    if (transpose) {\n      // Out-edges.\n      adjlist = &adjlist_;\n    } else {\n      // In-edges.\n      adjlist = &reverse_adjlist_;\n    }\n    indptr_data[0] = 0;\n    for (size_t i = 0; i < adjlist->size(); i++) {\n      indptr_data[i + 1] = indptr_data[i] + adjlist->at(i).succ.size();\n      std::copy(\n          adjlist->at(i).succ.begin(), adjlist->at(i).succ.end(),\n          indices_data + indptr_data[i]);\n      std::copy(\n          adjlist->at(i).edge_id.begin(), adjlist->at(i).edge_id.end(),\n          eid_data + indptr_data[i]);\n    }\n    return 
std::vector<IdArray>{indptr, indices, eid};\n  } else {\n    LOG(FATAL) << \"unsupported format\";\n    return std::vector<IdArray>();\n  }\n}\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/graph_apis.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/graph.cc\n * @brief DGL graph index APIs\n */\n#include <dgl/graph.h>\n#include <dgl/graph_op.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/nodeflow.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/sampler.h>\n\n#include \"../c_api_common.h\"\n\nusing dgl::runtime::DGLArgs;\nusing dgl::runtime::DGLArgValue;\nusing dgl::runtime::DGLRetValue;\nusing dgl::runtime::NDArray;\nusing dgl::runtime::PackedFunc;\n\nnamespace dgl {\n\n///////////////////////////// Graph API ///////////////////////////////////\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphCreateMutable\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = GraphRef(Graph::Create());\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphCreate\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const IdArray src_ids = args[0];\n      const IdArray dst_ids = args[1];\n      const int64_t num_nodes = args[2];\n      const bool readonly = args[3];\n      if (readonly) {\n        *rv = GraphRef(\n            ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids));\n      } else {\n        *rv = GraphRef(Graph::CreateFromCOO(num_nodes, src_ids, dst_ids));\n      }\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphCSRCreate\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const IdArray indptr = args[0];\n      const IdArray indices = args[1];\n      const std::string edge_dir = args[2];\n\n      IdArray edge_ids = IdArray::Empty(\n          {indices->shape[0]}, DGLDataType{kDGLInt, 64, 1},\n          DGLContext{kDGLCPU, 0});\n      int64_t* edge_data = static_cast<int64_t*>(edge_ids->data);\n      for (int64_t i = 0; i < edge_ids->shape[0]; i++) edge_data[i] = i;\n      *rv = GraphRef(\n          ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphCSRCreateMMap\")\n    .set_body([](DGLArgs args, 
DGLRetValue* rv) {\n      const std::string shared_mem_name = args[0];\n      *rv = GraphRef(ImmutableGraph::CreateFromCSR(shared_mem_name));\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphAddVertices\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      uint64_t num_vertices = args[1];\n      g->AddVertices(num_vertices);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphAddEdge\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const dgl_id_t src = args[1];\n      const dgl_id_t dst = args[2];\n      g->AddEdge(src, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphAddEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray src = args[1];\n      const IdArray dst = args[2];\n      g->AddEdges(src, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphClear\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      g->Clear();\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphIsMultigraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      *rv = g->IsMultigraph();\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphIsReadonly\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      *rv = g->IsReadonly();\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphNumVertices\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      *rv = static_cast<int64_t>(g->NumVertices());\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphNumEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      *rv = static_cast<int64_t>(g->NumEdges());\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphHasVertex\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const 
dgl_id_t vid = args[1];\n      *rv = g->HasVertex(vid);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphHasVertices\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray vids = args[1];\n      *rv = g->HasVertices(vids);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphHasEdgeBetween\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const dgl_id_t src = args[1];\n      const dgl_id_t dst = args[2];\n      *rv = g->HasEdgeBetween(src, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphHasEdgesBetween\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray src = args[1];\n      const IdArray dst = args[2];\n      *rv = g->HasEdgesBetween(src, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphPredecessors\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const dgl_id_t vid = args[1];\n      const uint64_t radius = args[2];\n      *rv = g->Predecessors(vid, radius);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphSuccessors\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const dgl_id_t vid = args[1];\n      const uint64_t radius = args[2];\n      *rv = g->Successors(vid, radius);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphEdgeId\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const dgl_id_t src = args[1];\n      const dgl_id_t dst = args[2];\n      *rv = g->EdgeId(src, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphEdgeIds\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray src = args[1];\n      const IdArray dst = args[2];\n      *rv = ConvertEdgeArrayToPackedFunc(g->EdgeIds(src, dst));\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphFindEdge\")\n  
  .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const dgl_id_t eid = args[1];\n      const auto& pair = g->FindEdge(eid);\n      *rv = PackedFunc([pair](DGLArgs args, DGLRetValue* rv) {\n        const int choice = args[0];\n        const int64_t ret = (choice == 0 ? pair.first : pair.second);\n        *rv = ret;\n      });\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphFindEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray eids = args[1];\n      *rv = ConvertEdgeArrayToPackedFunc(g->FindEdges(eids));\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphInEdges_1\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const dgl_id_t vid = args[1];\n      *rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vid));\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphInEdges_2\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray vids = args[1];\n      *rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vids));\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphOutEdges_1\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const dgl_id_t vid = args[1];\n      *rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vid));\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphOutEdges_2\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray vids = args[1];\n      *rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vids));\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      std::string order = args[1];\n      *rv = ConvertEdgeArrayToPackedFunc(g->Edges(order));\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphInDegree\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      
GraphRef g = args[0];\n      const dgl_id_t vid = args[1];\n      *rv = static_cast<int64_t>(g->InDegree(vid));\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphInDegrees\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray vids = args[1];\n      *rv = g->InDegrees(vids);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphOutDegree\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const dgl_id_t vid = args[1];\n      *rv = static_cast<int64_t>(g->OutDegree(vid));\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphOutDegrees\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray vids = args[1];\n      *rv = g->OutDegrees(vids);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphVertexSubgraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray vids = args[1];\n      std::shared_ptr<Subgraph> subg(new Subgraph(g->VertexSubgraph(vids)));\n      *rv = SubgraphRef(subg);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphEdgeSubgraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray eids = args[1];\n      bool preserve_nodes = args[2];\n      std::shared_ptr<Subgraph> subg(\n          new Subgraph(g->EdgeSubgraph(eids, preserve_nodes)));\n      *rv = SubgraphRef(subg);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphGetAdj\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      bool transpose = args[1];\n      std::string format = args[2];\n      auto res = g->GetAdj(transpose, format);\n      *rv = ConvertNDArrayVectorToPackedFunc(res);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphContext\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      *rv = g->Context();\n    
});\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphNumBits\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      *rv = g->NumBits();\n    });\n\n// Subgraph C APIs\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLSubgraphGetGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      SubgraphRef subg = args[0];\n      *rv = GraphRef(subg->graph);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLSubgraphGetInducedVertices\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      SubgraphRef subg = args[0];\n      *rv = subg->induced_vertices;\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLSubgraphGetInducedEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      SubgraphRef subg = args[0];\n      *rv = subg->induced_edges;\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLSortAdj\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      g->SortCSR();\n    });\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/graph_op.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/graph.cc\n * @brief Graph operation implementation\n */\n#include <dgl/array.h>\n#include <dgl/graph_op.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <algorithm>\n\n#include \"../c_api_common.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace {\n// generate consecutive dgl ids\nclass RangeIter : public std::iterator<std::input_iterator_tag, dgl_id_t> {\n public:\n  explicit RangeIter(dgl_id_t from) : cur_(from) {}\n\n  RangeIter& operator++() {\n    ++cur_;\n    return *this;\n  }\n\n  RangeIter operator++(int) {\n    RangeIter retval = *this;\n    ++cur_;\n    return retval;\n  }\n  bool operator==(RangeIter other) const { return cur_ == other.cur_; }\n  bool operator!=(RangeIter other) const { return cur_ != other.cur_; }\n  dgl_id_t operator*() const { return cur_; }\n\n private:\n  dgl_id_t cur_;\n};\n\nbool IsMutable(GraphPtr g) {\n  MutableGraphPtr mg = std::dynamic_pointer_cast<Graph>(g);\n  return mg != nullptr;\n}\n\n}  // namespace\n\nGraphPtr GraphOp::Reverse(GraphPtr g) {\n  ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g);\n  CHECK(ig) << \"Reverse is only supported on immutable graph\";\n  return ig->Reverse();\n}\n\nGraphPtr GraphOp::LineGraph(GraphPtr g, bool backtracking) {\n  MutableGraphPtr mg = std::dynamic_pointer_cast<Graph>(g);\n  CHECK(mg) << \"Line graph transformation is only supported on mutable graph\";\n  MutableGraphPtr lg = Graph::Create();\n  lg->AddVertices(g->NumEdges());\n  for (size_t i = 0; i < mg->all_edges_src_.size(); ++i) {\n    const auto u = mg->all_edges_src_[i];\n    const auto v = mg->all_edges_dst_[i];\n    for (size_t j = 0; j < mg->adjlist_[v].succ.size(); ++j) {\n      if (backtracking || (!backtracking && mg->adjlist_[v].succ[j] != u)) {\n        lg->AddEdge(i, mg->adjlist_[v].edge_id[j]);\n      
}\n    }\n  }\n  return lg;\n}\n\nGraphPtr GraphOp::DisjointUnion(std::vector<GraphPtr> graphs) {\n  CHECK_GT(graphs.size(), 0) << \"Input graph list is empty\";\n  if (IsMutable(graphs[0])) {\n    // Disjoint union of a list of mutable graph inputs. The result is\n    // also a mutable graph.\n    MutableGraphPtr rst = Graph::Create();\n    uint64_t cumsum = 0;\n    for (GraphPtr gr : graphs) {\n      MutableGraphPtr mg = std::dynamic_pointer_cast<Graph>(gr);\n      CHECK(mg) << \"All the input graphs should be mutable graphs.\";\n      rst->AddVertices(gr->NumVertices());\n      for (uint64_t i = 0; i < gr->NumEdges(); ++i) {\n        // TODO(minjie): quite ugly to expose internal members\n        rst->AddEdge(\n            mg->all_edges_src_[i] + cumsum, mg->all_edges_dst_[i] + cumsum);\n      }\n      cumsum += gr->NumVertices();\n    }\n    return rst;\n  } else {\n    // Disjoint union of a list of immutable graph inputs. The result is\n    // also an immutable graph.\n    int64_t num_nodes = 0;\n    int64_t num_edges = 0;\n    for (auto gr : graphs) {\n      num_nodes += gr->NumVertices();\n      num_edges += gr->NumEdges();\n    }\n    IdArray indptr_arr = aten::NewIdArray(num_nodes + 1);\n    IdArray indices_arr = aten::NewIdArray(num_edges);\n    IdArray edge_ids_arr = aten::NewIdArray(num_edges);\n    dgl_id_t* indptr = static_cast<dgl_id_t*>(indptr_arr->data);\n    dgl_id_t* indices = static_cast<dgl_id_t*>(indices_arr->data);\n    dgl_id_t* edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data);\n\n    indptr[0] = 0;\n    dgl_id_t cum_num_nodes = 0;\n    dgl_id_t cum_num_edges = 0;\n    for (auto g : graphs) {\n      ImmutableGraphPtr gr = std::dynamic_pointer_cast<ImmutableGraph>(g);\n      CHECK(gr) << \"All the input graphs should be immutable graphs.\";\n      // TODO(minjie): why in csr?\n      const CSRPtr g_csrptr = gr->GetInCSR();\n      const uint64_t g_num_nodes = g_csrptr->NumVertices();\n      const uint64_t g_num_edges = 
g_csrptr->NumEdges();\n      dgl_id_t* g_indptr = static_cast<dgl_id_t*>(g_csrptr->indptr()->data);\n      dgl_id_t* g_indices = static_cast<dgl_id_t*>(g_csrptr->indices()->data);\n      dgl_id_t* g_edge_ids = static_cast<dgl_id_t*>(g_csrptr->edge_ids()->data);\n      for (dgl_id_t i = 1; i < g_num_nodes + 1; ++i) {\n        indptr[cum_num_nodes + i] = g_indptr[i] + cum_num_edges;\n      }\n      for (dgl_id_t i = 0; i < g_num_edges; ++i) {\n        indices[cum_num_edges + i] = g_indices[i] + cum_num_nodes;\n      }\n\n      for (dgl_id_t i = 0; i < g_num_edges; ++i) {\n        edge_ids[cum_num_edges + i] = g_edge_ids[i] + cum_num_edges;\n      }\n      cum_num_nodes += g_num_nodes;\n      cum_num_edges += g_num_edges;\n    }\n\n    return ImmutableGraph::CreateFromCSR(\n        indptr_arr, indices_arr, edge_ids_arr, \"in\");\n  }\n}\n\nstd::vector<GraphPtr> GraphOp::DisjointPartitionByNum(\n    GraphPtr graph, int64_t num) {\n  CHECK(num != 0 && graph->NumVertices() % num == 0)\n      << \"Number of partitions must evenly divide the number of nodes.\";\n  IdArray sizes = IdArray::Empty(\n      {num}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  int64_t* sizes_data = static_cast<int64_t*>(sizes->data);\n  std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);\n  return DisjointPartitionBySizes(graph, sizes);\n}\n\nstd::vector<GraphPtr> GraphOp::DisjointPartitionBySizes(\n    GraphPtr batched_graph, IdArray sizes) {\n  const int64_t len = sizes->shape[0];\n  const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);\n  std::vector<int64_t> cumsum;\n  cumsum.push_back(0);\n  for (int64_t i = 0; i < len; ++i) {\n    cumsum.push_back(cumsum[i] + sizes_data[i]);\n  }\n  CHECK_EQ(cumsum[len], batched_graph->NumVertices())\n      << \"Sum of the given sizes must equal to the number of nodes.\";\n\n  std::vector<GraphPtr> rst;\n  if (IsMutable(batched_graph)) {\n    // Input is a mutable graph. 
Partition it into several mutable graphs.\n    MutableGraphPtr graph = std::dynamic_pointer_cast<Graph>(batched_graph);\n    dgl_id_t node_offset = 0, edge_offset = 0;\n    for (int64_t i = 0; i < len; ++i) {\n      MutableGraphPtr mg = Graph::Create();\n      // TODO(minjie): quite ugly to expose internal members\n      // copy adj\n      mg->adjlist_.insert(\n          mg->adjlist_.end(), graph->adjlist_.begin() + node_offset,\n          graph->adjlist_.begin() + node_offset + sizes_data[i]);\n      mg->reverse_adjlist_.insert(\n          mg->reverse_adjlist_.end(),\n          graph->reverse_adjlist_.begin() + node_offset,\n          graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);\n      // relabel adjs\n      size_t num_edges = 0;\n      for (auto& elist : mg->adjlist_) {\n        for (size_t j = 0; j < elist.succ.size(); ++j) {\n          elist.succ[j] -= node_offset;\n          elist.edge_id[j] -= edge_offset;\n        }\n        num_edges += elist.succ.size();\n      }\n      for (auto& elist : mg->reverse_adjlist_) {\n        for (size_t j = 0; j < elist.succ.size(); ++j) {\n          elist.succ[j] -= node_offset;\n          elist.edge_id[j] -= edge_offset;\n        }\n      }\n      // copy edges\n      mg->all_edges_src_.reserve(num_edges);\n      mg->all_edges_dst_.reserve(num_edges);\n      mg->num_edges_ = num_edges;\n      for (size_t j = edge_offset; j < edge_offset + num_edges; ++j) {\n        mg->all_edges_src_.push_back(graph->all_edges_src_[j] - node_offset);\n        mg->all_edges_dst_.push_back(graph->all_edges_dst_[j] - node_offset);\n      }\n      // push to rst\n      rst.push_back(mg);\n      // update offset\n      CHECK_EQ(rst[i]->NumVertices(), sizes_data[i]);\n      CHECK_EQ(rst[i]->NumEdges(), num_edges);\n      node_offset += sizes_data[i];\n      edge_offset += num_edges;\n    }\n  } else {\n    // Input is an immutable graph. 
Partition it into several multiple graphs.\n    ImmutableGraphPtr graph =\n        std::dynamic_pointer_cast<ImmutableGraph>(batched_graph);\n    // TODO(minjie): why in csr?\n    CSRPtr in_csr_ptr = graph->GetInCSR();\n    const dgl_id_t* indptr = static_cast<dgl_id_t*>(in_csr_ptr->indptr()->data);\n    const dgl_id_t* indices =\n        static_cast<dgl_id_t*>(in_csr_ptr->indices()->data);\n    const dgl_id_t* edge_ids =\n        static_cast<dgl_id_t*>(in_csr_ptr->edge_ids()->data);\n    dgl_id_t cum_sum_edges = 0;\n    for (int64_t i = 0; i < len; ++i) {\n      const int64_t start_pos = cumsum[i];\n      const int64_t end_pos = cumsum[i + 1];\n      const int64_t g_num_nodes = sizes_data[i];\n      const int64_t g_num_edges = indptr[end_pos] - indptr[start_pos];\n      IdArray indptr_arr = aten::NewIdArray(g_num_nodes + 1);\n      IdArray indices_arr = aten::NewIdArray(g_num_edges);\n      IdArray edge_ids_arr = aten::NewIdArray(g_num_edges);\n      dgl_id_t* g_indptr = static_cast<dgl_id_t*>(indptr_arr->data);\n      dgl_id_t* g_indices = static_cast<dgl_id_t*>(indices_arr->data);\n      dgl_id_t* g_edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data);\n\n      const dgl_id_t idoff = indptr[start_pos];\n      g_indptr[0] = 0;\n      for (int l = start_pos + 1; l < end_pos + 1; ++l) {\n        g_indptr[l - start_pos] = indptr[l] - indptr[start_pos];\n      }\n\n      for (dgl_id_t j = indptr[start_pos]; j < indptr[end_pos]; ++j) {\n        g_indices[j - idoff] = indices[j] - cumsum[i];\n      }\n\n      for (dgl_id_t k = indptr[start_pos]; k < indptr[end_pos]; ++k) {\n        g_edge_ids[k - idoff] = edge_ids[k] - cum_sum_edges;\n      }\n\n      cum_sum_edges += g_num_edges;\n      rst.push_back(ImmutableGraph::CreateFromCSR(\n          indptr_arr, indices_arr, edge_ids_arr, \"in\"));\n    }\n  }\n  return rst;\n}\n\nIdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) {\n  CHECK(aten::IsValidIdArray(parent_vids)) << \"Invalid parent id 
array.\";\n  CHECK(aten::IsValidIdArray(query)) << \"Invalid query id array.\";\n  const auto parent_len = parent_vids->shape[0];\n  const auto query_len = query->shape[0];\n  const dgl_id_t* parent_data = static_cast<dgl_id_t*>(parent_vids->data);\n  const dgl_id_t* query_data = static_cast<dgl_id_t*>(query->data);\n  IdArray rst = IdArray::Empty(\n      {query_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);\n\n  const bool is_sorted = std::is_sorted(parent_data, parent_data + parent_len);\n  if (is_sorted) {\n    runtime::parallel_for(0, query_len, [&](size_t b, size_t e) {\n      for (auto i = b; i < e; ++i) {\n        const dgl_id_t id = query_data[i];\n        const auto it = std::find(parent_data, parent_data + parent_len, id);\n        // If the vertex Id doesn't exist, the vid in the subgraph is -1.\n        if (it != parent_data + parent_len) {\n          rst_data[i] = it - parent_data;\n        } else {\n          rst_data[i] = -1;\n        }\n      }\n    });\n  } else {\n    std::unordered_map<dgl_id_t, dgl_id_t> parent_map;\n    for (int64_t i = 0; i < parent_len; i++) {\n      const dgl_id_t id = parent_data[i];\n      parent_map[id] = i;\n    }\n    runtime::parallel_for(0, query_len, [&](size_t b, size_t e) {\n      for (auto i = b; i < e; ++i) {\n        const dgl_id_t id = query_data[i];\n        auto it = parent_map.find(id);\n        // If the vertex Id doesn't exist, the vid in the subgraph is -1.\n        if (it != parent_map.end()) {\n          rst_data[i] = it->second;\n        } else {\n          rst_data[i] = -1;\n        }\n      }\n    });\n  }\n  return rst;\n}\n\nIdArray GraphOp::ExpandIds(IdArray ids, IdArray offset) {\n  const auto id_len = ids->shape[0];\n  const auto off_len = offset->shape[0];\n  CHECK_EQ(id_len + 1, off_len);\n  const dgl_id_t* id_data = static_cast<dgl_id_t*>(ids->data);\n  const dgl_id_t* off_data = static_cast<dgl_id_t*>(offset->data);\n  
const int64_t len = off_data[off_len - 1];\n  IdArray rst = IdArray::Empty(\n      {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);\n  for (int64_t i = 0; i < id_len; i++) {\n    const int64_t local_len = off_data[i + 1] - off_data[i];\n    for (int64_t j = 0; j < local_len; j++) {\n      rst_data[off_data[i] + j] = id_data[i];\n    }\n  }\n  return rst;\n}\n\nGraphPtr GraphOp::ToSimpleGraph(GraphPtr graph) {\n  std::vector<dgl_id_t> indptr(graph->NumVertices() + 1), indices;\n  indptr[0] = 0;\n  for (dgl_id_t src = 0; src < graph->NumVertices(); ++src) {\n    std::unordered_set<dgl_id_t> hashmap;\n    for (const dgl_id_t dst : graph->SuccVec(src)) {\n      if (!hashmap.count(dst)) {\n        indices.push_back(dst);\n        hashmap.insert(dst);\n      }\n    }\n    indptr[src + 1] = indices.size();\n  }\n  CSRPtr csr(new CSR(\n      graph->NumVertices(), indices.size(), indptr.begin(), indices.begin(),\n      RangeIter(0)));\n  return std::make_shared<ImmutableGraph>(csr);\n}\n\nGraphPtr GraphOp::ToBidirectedMutableGraph(GraphPtr g) {\n  std::unordered_map<int, std::unordered_map<int, int>> n_e;\n  for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {\n    for (const dgl_id_t v : g->SuccVec(u)) {\n      n_e[u][v]++;\n    }\n  }\n\n  GraphPtr bg = Graph::Create();\n  bg->AddVertices(g->NumVertices());\n  for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {\n    for (dgl_id_t v = u; v < g->NumVertices(); ++v) {\n      const auto new_n_e = std::max(n_e[u][v], n_e[v][u]);\n      if (new_n_e > 0) {\n        IdArray us = aten::NewIdArray(new_n_e);\n        dgl_id_t* us_data = static_cast<dgl_id_t*>(us->data);\n        std::fill(us_data, us_data + new_n_e, u);\n        if (u == v) {\n          bg->AddEdges(us, us);\n        } else {\n          IdArray vs = aten::NewIdArray(new_n_e);\n          dgl_id_t* vs_data = static_cast<dgl_id_t*>(vs->data);\n          std::fill(vs_data, vs_data + new_n_e, v);\n         
 bg->AddEdges(us, vs);\n          bg->AddEdges(vs, us);\n        }\n      }\n    }\n  }\n  return bg;\n}\n\nGraphPtr GraphOp::ToBidirectedImmutableGraph(GraphPtr g) {\n  std::unordered_map<int, std::unordered_map<int, int>> n_e;\n  for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {\n    for (const dgl_id_t v : g->SuccVec(u)) {\n      n_e[u][v]++;\n    }\n  }\n\n  std::vector<dgl_id_t> srcs, dsts;\n  for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {\n    std::unordered_set<dgl_id_t> hashmap;\n    std::vector<dgl_id_t> nbrs;\n    for (const dgl_id_t v : g->PredVec(u)) {\n      if (!hashmap.count(v)) {\n        nbrs.push_back(v);\n        hashmap.insert(v);\n      }\n    }\n    for (const dgl_id_t v : g->SuccVec(u)) {\n      if (!hashmap.count(v)) {\n        nbrs.push_back(v);\n        hashmap.insert(v);\n      }\n    }\n    for (const dgl_id_t v : nbrs) {\n      const auto new_n_e = std::max(n_e[u][v], n_e[v][u]);\n      for (int i = 0; i < new_n_e; ++i) {\n        srcs.push_back(v);\n        dsts.push_back(u);\n      }\n    }\n  }\n\n  IdArray srcs_array = aten::VecToIdArray(srcs);\n  IdArray dsts_array = aten::VecToIdArray(dsts);\n  return ImmutableGraph::CreateFromCOO(\n      g->NumVertices(), srcs_array, dsts_array);\n}\n\nHaloSubgraph GraphOp::GetSubgraphWithHalo(\n    GraphPtr g, IdArray nodes, int num_hops) {\n  const dgl_id_t* nid = static_cast<dgl_id_t*>(nodes->data);\n  const auto id_len = nodes->shape[0];\n  // A map containing all nodes in the subgraph.\n  // The key is the old node Id; the value indicates whether the node is an\n  // inner node.\n  std::unordered_map<dgl_id_t, bool> all_nodes;\n  // The old Ids of all nodes. We want to preserve the order of the nodes in the\n  // vector. 
The first few nodes are the inner nodes in the subgraph.\n  std::vector<dgl_id_t> old_node_ids(nid, nid + id_len);\n  std::vector<std::vector<dgl_id_t>> outer_nodes(num_hops);\n  for (int64_t i = 0; i < id_len; i++) all_nodes[nid[i]] = true;\n  auto orig_nodes = all_nodes;\n\n  std::vector<dgl_id_t> edge_src, edge_dst, edge_eid;\n\n  // When we deal with in-edges, we need to do two things:\n  // * find the edges inside the partition and the edges between partitions.\n  // * find the nodes outside the partition that connect the partition.\n  EdgeArray in_edges = g->InEdges(nodes);\n  auto src = in_edges.src;\n  auto dst = in_edges.dst;\n  auto eid = in_edges.id;\n  auto num_edges = eid->shape[0];\n  const dgl_id_t* src_data = static_cast<dgl_id_t*>(src->data);\n  const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst->data);\n  const dgl_id_t* eid_data = static_cast<dgl_id_t*>(eid->data);\n  for (int64_t i = 0; i < num_edges; i++) {\n    // We check if the source node is in the original node.\n    auto it1 = orig_nodes.find(src_data[i]);\n    if (it1 != orig_nodes.end() || num_hops > 0) {\n      edge_src.push_back(src_data[i]);\n      edge_dst.push_back(dst_data[i]);\n      edge_eid.push_back(eid_data[i]);\n    }\n    // We need to expand only if the node hasn't been seen before.\n    auto it = all_nodes.find(src_data[i]);\n    if (it == all_nodes.end() && num_hops > 0) {\n      all_nodes[src_data[i]] = false;\n      old_node_ids.push_back(src_data[i]);\n      outer_nodes[0].push_back(src_data[i]);\n    }\n  }\n\n  // Now we need to traverse the graph with the in-edges to access nodes\n  // and edges more hops away.\n  for (int k = 1; k < num_hops; k++) {\n    const std::vector<dgl_id_t>& nodes = outer_nodes[k - 1];\n    EdgeArray in_edges = g->InEdges(aten::VecToIdArray(nodes));\n    auto src = in_edges.src;\n    auto dst = in_edges.dst;\n    auto eid = in_edges.id;\n    auto num_edges = eid->shape[0];\n    const dgl_id_t* src_data = 
static_cast<dgl_id_t*>(src->data);\n    const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst->data);\n    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(eid->data);\n    for (int64_t i = 0; i < num_edges; i++) {\n      edge_src.push_back(src_data[i]);\n      edge_dst.push_back(dst_data[i]);\n      edge_eid.push_back(eid_data[i]);\n      // If we haven't seen this node.\n      auto it = all_nodes.find(src_data[i]);\n      if (it == all_nodes.end()) {\n        all_nodes[src_data[i]] = false;\n        old_node_ids.push_back(src_data[i]);\n        outer_nodes[k].push_back(src_data[i]);\n      }\n    }\n  }\n\n  // We assign new Ids to the nodes in the subgraph. We ensure that the HALO\n  // nodes are behind the input nodes.\n  std::unordered_map<dgl_id_t, dgl_id_t> old2new;\n  for (size_t i = 0; i < old_node_ids.size(); i++) {\n    old2new[old_node_ids[i]] = i;\n  }\n\n  num_edges = edge_src.size();\n  IdArray new_src = IdArray::Empty(\n      {num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  IdArray new_dst = IdArray::Empty(\n      {num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  dgl_id_t* new_src_data = static_cast<dgl_id_t*>(new_src->data);\n  dgl_id_t* new_dst_data = static_cast<dgl_id_t*>(new_dst->data);\n  for (size_t i = 0; i < edge_src.size(); i++) {\n    new_src_data[i] = old2new[edge_src[i]];\n    new_dst_data[i] = old2new[edge_dst[i]];\n  }\n\n  std::vector<int> inner_nodes(old_node_ids.size());\n  for (size_t i = 0; i < old_node_ids.size(); i++) {\n    dgl_id_t old_nid = old_node_ids[i];\n    inner_nodes[i] = all_nodes[old_nid];\n  }\n\n  GraphPtr subg =\n      ImmutableGraph::CreateFromCOO(old_node_ids.size(), new_src, new_dst);\n  HaloSubgraph halo_subg;\n  halo_subg.graph = subg;\n  halo_subg.induced_vertices = aten::VecToIdArray(old_node_ids);\n  halo_subg.induced_edges = aten::VecToIdArray(edge_eid);\n  // TODO(zhengda) we need to switch to 8 bytes afterwards.\n  halo_subg.inner_nodes = 
aten::VecToIdArray<int>(inner_nodes, 32);\n  return halo_subg;\n}\n\nGraphPtr GraphOp::ReorderImmutableGraph(\n    ImmutableGraphPtr ig, IdArray new_order) {\n  CSRPtr in_csr, out_csr;\n  COOPtr coo;\n  // We only need to reorder one of the graph structures.\n  if (ig->HasInCSR()) {\n    in_csr = ig->GetInCSR();\n    auto csrmat = in_csr->ToCSRMatrix();\n    auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);\n    in_csr =\n        CSRPtr(new CSR(new_csrmat.indptr, new_csrmat.indices, new_csrmat.data));\n  } else if (ig->HasOutCSR()) {\n    out_csr = ig->GetOutCSR();\n    auto csrmat = out_csr->ToCSRMatrix();\n    auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);\n    out_csr =\n        CSRPtr(new CSR(new_csrmat.indptr, new_csrmat.indices, new_csrmat.data));\n  } else {\n    coo = ig->GetCOO();\n    auto coomat = coo->ToCOOMatrix();\n    auto new_coomat = aten::COOReorder(coomat, new_order, new_order);\n    coo = COOPtr(new COO(ig->NumVertices(), new_coomat.row, new_coomat.col));\n  }\n  if (in_csr || out_csr)\n    return GraphPtr(new ImmutableGraph(in_csr, out_csr));\n  else\n    return GraphPtr(new ImmutableGraph(coo));\n}\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLPartitionWithHalo\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef graph = args[0];\n      IdArray node_parts = args[1];\n      int num_hops = args[2];\n\n      const dgl_id_t* part_data = static_cast<dgl_id_t*>(node_parts->data);\n      int64_t num_nodes = node_parts->shape[0];\n      std::unordered_map<int, std::vector<dgl_id_t>> part_map;\n      for (int64_t i = 0; i < num_nodes; i++) {\n        dgl_id_t part_id = part_data[i];\n        auto it = part_map.find(part_id);\n        if (it == part_map.end()) {\n          std::vector<dgl_id_t> vec;\n          vec.push_back(i);\n          part_map[part_id] = vec;\n        } else {\n          it->second.push_back(i);\n        }\n      }\n      std::vector<int> part_ids;\n      
std::vector<std::vector<dgl_id_t>> part_nodes;\n      int max_part_id = 0;\n      for (auto it = part_map.begin(); it != part_map.end(); it++) {\n        max_part_id = std::max(it->first, max_part_id);\n        part_ids.push_back(it->first);\n        part_nodes.push_back(it->second);\n      }\n      auto graph_ptr = std::dynamic_pointer_cast<ImmutableGraph>(graph.sptr());\n      CHECK(graph_ptr) << \"The input graph has to be an immutable graph\";\n      // When we construct subgraphs, we only access in-edges.\n      // We need to make sure the in-CSR exists. Otherwise, we'll\n      // try to construct in-CSR in openmp for loop, which will lead\n      // to some unexpected results.\n      graph_ptr->GetInCSR();\n      std::vector<std::shared_ptr<HaloSubgraph>> subgs(max_part_id + 1);\n      int num_partitions = part_nodes.size();\n      runtime::parallel_for(0, num_partitions, [&](size_t b, size_t e) {\n        for (auto i = b; i < e; ++i) {\n          auto nodes = aten::VecToIdArray(part_nodes[i]);\n          HaloSubgraph subg =\n              GraphOp::GetSubgraphWithHalo(graph_ptr, nodes, num_hops);\n          std::shared_ptr<HaloSubgraph> subg_ptr(new HaloSubgraph(subg));\n          int part_id = part_ids[i];\n          subgs[part_id] = subg_ptr;\n        }\n      });\n      List<SubgraphRef> ret_list;\n      for (size_t i = 0; i < subgs.size(); i++) {\n        ret_list.push_back(SubgraphRef(subgs[i]));\n      }\n      *rv = ret_list;\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGetSubgraphWithHalo\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef graph = args[0];\n      IdArray nodes = args[1];\n      int num_hops = args[2];\n      HaloSubgraph subg =\n          GraphOp::GetSubgraphWithHalo(graph.sptr(), nodes, num_hops);\n      std::shared_ptr<HaloSubgraph> subg_ptr(new HaloSubgraph(subg));\n      *rv = SubgraphRef(subg_ptr);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_GetHaloSubgraphInnerNodes\")\n    .set_body([](DGLArgs 
args, DGLRetValue* rv) {\n      SubgraphRef g = args[0];\n      auto gptr = std::dynamic_pointer_cast<HaloSubgraph>(g.sptr());\n      CHECK(gptr) << \"The input graph has to be immutable graph\";\n      *rv = gptr->inner_nodes;\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLDisjointUnion\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      List<GraphRef> graphs = args[0];\n      std::vector<GraphPtr> ptrs(graphs.size());\n      for (size_t i = 0; i < graphs.size(); ++i) {\n        ptrs[i] = graphs[i].sptr();\n      }\n      *rv = GraphOp::DisjointUnion(ptrs);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLDisjointPartitionByNum\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      int64_t num = args[1];\n      const auto& ret = GraphOp::DisjointPartitionByNum(g.sptr(), num);\n      List<GraphRef> ret_list;\n      for (GraphPtr gp : ret) {\n        ret_list.push_back(GraphRef(gp));\n      }\n      *rv = ret_list;\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLDisjointPartitionBySizes\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray sizes = args[1];\n      const auto& ret = GraphOp::DisjointPartitionBySizes(g.sptr(), sizes);\n      List<GraphRef> ret_list;\n      for (GraphPtr gp : ret) {\n        ret_list.push_back(GraphRef(gp));\n      }\n      *rv = ret_list;\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLGraphLineGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      bool backtracking = args[1];\n      *rv = GraphOp::LineGraph(g.sptr(), backtracking);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLToImmutable\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      *rv = ImmutableGraph::ToImmutable(g.sptr());\n    });\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLToSimpleGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = 
args[0];\n      *rv = GraphOp::ToSimpleGraph(g.sptr());\n    });\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLToBidirectedMutableGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      *rv = GraphOp::ToBidirectedMutableGraph(g.sptr());\n    });\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLReorderGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray new_order = args[1];\n      auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());\n      CHECK(gptr) << \"The input graph has to be immutable graph\";\n      *rv = GraphOp::ReorderImmutableGraph(gptr, new_order);\n    });\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLReassignEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef graph = args[0];\n      bool is_incsr = args[1];\n      auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(graph.sptr());\n      CHECK(gptr) << \"We can only reassign edge Ids on immutable graphs\";\n      CSRPtr csr = is_incsr ? 
gptr->GetInCSR() : gptr->GetOutCSR();\n      auto csrmat = csr->ToCSRMatrix();\n      int64_t num_edges = csrmat.data->shape[0];\n      IdArray new_data =\n          IdArray::Empty({num_edges}, csrmat.data->dtype, csrmat.data->ctx);\n      // Return the original edge Ids.\n      *rv = new_data;\n      // TODO(zhengda) I need to invalidate out-CSR and COO.\n\n      // Generate new edge Ids.\n      // TODO(zhengda) after assignment, we actually don't need to store them\n      // physically.\n      ATEN_ID_TYPE_SWITCH(new_data->dtype, IdType, {\n        IdType* typed_new_data = static_cast<IdType*>(new_data->data);\n        IdType* typed_data = static_cast<IdType*>(csrmat.data->data);\n        for (int64_t i = 0; i < num_edges; i++) {\n          typed_new_data[i] = typed_data[i];\n          typed_data[i] = i;\n        }\n      });\n    });\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLToBidirectedImmutableGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      auto gptr = g.sptr();\n      auto immutable_g = std::dynamic_pointer_cast<ImmutableGraph>(gptr);\n      GraphPtr ret;\n      // For immutable graphs, we can try a faster version.\n      if (immutable_g) {\n        ret = GraphOp::ToBidirectedSimpleImmutableGraph(immutable_g);\n      }\n      // If the above option doesn't work, we call a general implementation.\n      if (!ret) {\n        ret = GraphOp::ToBidirectedImmutableGraph(gptr);\n      }\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLMapSubgraphNID\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const IdArray parent_vids = args[0];\n      const IdArray query = args[1];\n      *rv = GraphOp::MapParentIdToSubgraphId(parent_vids, query);\n    });\n\ntemplate <class IdType>\nIdArray MapIds(\n    IdArray ids, IdArray range_starts, IdArray range_ends, IdArray typed_map,\n    int num_parts, int num_types) {\n  int64_t num_ids = ids->shape[0];\n  int64_t num_ranges = 
range_starts->shape[0];\n  IdArray ret = IdArray::Empty({num_ids * 2}, ids->dtype, ids->ctx);\n\n  const IdType* range_start_data = static_cast<IdType*>(range_starts->data);\n  const IdType* range_end_data = static_cast<IdType*>(range_ends->data);\n  const IdType* ids_data = static_cast<IdType*>(ids->data);\n  const IdType* typed_map_data = static_cast<IdType*>(typed_map->data);\n  IdType* types_data = static_cast<IdType*>(ret->data);\n  IdType* per_type_ids_data = static_cast<IdType*>(ret->data) + num_ids;\n  runtime::parallel_for(0, ids->shape[0], [&](size_t b, size_t e) {\n    for (auto i = b; i < e; ++i) {\n      IdType id = ids_data[i];\n      auto it =\n          std::lower_bound(range_end_data, range_end_data + num_ranges, id);\n      // The range must exist.\n      BUG_IF_FAIL(it != range_end_data + num_ranges);\n      size_t range_id = it - range_end_data;\n      int type_id = range_id % num_types;\n      types_data[i] = type_id;\n      int part_id = range_id / num_types;\n      BUG_IF_FAIL(part_id < num_parts);\n      if (part_id == 0) {\n        per_type_ids_data[i] = id - range_start_data[range_id];\n      } else {\n        per_type_ids_data[i] =\n            id - range_start_data[range_id] +\n            typed_map_data[num_parts * type_id + part_id - 1];\n      }\n    }\n  });\n  return ret;\n}\n\nDGL_REGISTER_GLOBAL(\"distributed.id_map._CAPI_DGLHeteroMapIds\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const IdArray ids = args[0];\n      const IdArray range_starts = args[1];\n      const IdArray range_ends = args[2];\n      const IdArray typed_map = args[3];\n      int num_parts = args[4];\n      int num_types = args[5];\n      int num_ranges = range_starts->shape[0];\n\n      CHECK_EQ(range_starts->dtype.bits, ids->dtype.bits);\n      CHECK_EQ(range_ends->dtype.bits, ids->dtype.bits);\n      CHECK_EQ(typed_map->dtype.bits, ids->dtype.bits);\n      CHECK_EQ(num_ranges, num_parts * num_types);\n      CHECK_EQ(num_ranges, 
range_ends->shape[0]);\n\n      IdArray ret;\n      ATEN_ID_TYPE_SWITCH(ids->dtype, IdType, {\n        ret = MapIds<IdType>(\n            ids, range_starts, range_ends, typed_map, num_parts, num_types);\n      });\n      *rv = ret;\n    });\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/graph_traversal.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/traversal.cc\n * @brief Graph traversal implementation\n */\n#include <dgl/graph_traversal.h>\n#include <dgl/packed_func_ext.h>\n\n#include \"../c_api_common.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace traverse {\n\nDGL_REGISTER_GLOBAL(\"traversal._CAPI_DGLBFSNodes_v2\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef g = args[0];\n      const IdArray src = args[1];\n      bool reversed = args[2];\n      aten::CSRMatrix csr;\n      if (reversed) {\n        csr = g.sptr()->GetCSCMatrix(0);\n      } else {\n        csr = g.sptr()->GetCSRMatrix(0);\n      }\n      const auto& front = aten::BFSNodesFrontiers(csr, src);\n      *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});\n    });\n\nDGL_REGISTER_GLOBAL(\"traversal._CAPI_DGLBFSEdges_v2\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef g = args[0];\n      const IdArray src = args[1];\n      bool reversed = args[2];\n      aten::CSRMatrix csr;\n      if (reversed) {\n        csr = g.sptr()->GetCSCMatrix(0);\n      } else {\n        csr = g.sptr()->GetCSRMatrix(0);\n      }\n\n      const auto& front = aten::BFSEdgesFrontiers(csr, src);\n      *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});\n    });\n\nDGL_REGISTER_GLOBAL(\"traversal._CAPI_DGLTopologicalNodes_v2\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef g = args[0];\n      bool reversed = args[1];\n      aten::CSRMatrix csr;\n      if (reversed) {\n        csr = g.sptr()->GetCSCMatrix(0);\n      } else {\n        csr = g.sptr()->GetCSRMatrix(0);\n      }\n\n      const auto& front = aten::TopologicalNodesFrontiers(csr);\n      *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});\n    });\n\nDGL_REGISTER_GLOBAL(\"traversal._CAPI_DGLDFSEdges_v2\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef g = args[0];\n   
   const IdArray source = args[1];\n      const bool reversed = args[2];\n      CHECK(aten::IsValidIdArray(source)) << \"Invalid source node id array.\";\n      aten::CSRMatrix csr;\n      if (reversed) {\n        csr = g.sptr()->GetCSCMatrix(0);\n      } else {\n        csr = g.sptr()->GetCSRMatrix(0);\n      }\n      const auto& front = aten::DGLDFSEdges(csr, source);\n      *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});\n    });\n\nDGL_REGISTER_GLOBAL(\"traversal._CAPI_DGLDFSLabeledEdges_v2\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef g = args[0];\n      const IdArray source = args[1];\n      const bool reversed = args[2];\n      const bool has_reverse_edge = args[3];\n      const bool has_nontree_edge = args[4];\n      const bool return_labels = args[5];\n      aten::CSRMatrix csr;\n      if (reversed) {\n        csr = g.sptr()->GetCSCMatrix(0);\n      } else {\n        csr = g.sptr()->GetCSRMatrix(0);\n      }\n\n      const auto& front = aten::DGLDFSLabeledEdges(\n          csr, source, has_reverse_edge, has_nontree_edge, return_labels);\n\n      if (return_labels) {\n        *rv = ConvertNDArrayVectorToPackedFunc(\n            {front.ids, front.tags, front.sections});\n      } else {\n        *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});\n      }\n    });\n\n}  // namespace traverse\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/heterograph.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/heterograph.cc\n * @brief Heterograph implementation\n */\n#include \"./heterograph.h\"\n\n#include <dgl/array.h>\n#include <dgl/graph_serializer.h>\n#include <dgl/immutable_graph.h>\n#include <dmlc/memory_io.h>\n\n#include <memory>\n#include <tuple>\n#include <utility>\n#include <vector>\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace {\n\nusing dgl::ImmutableGraph;\n\nHeteroSubgraph EdgeSubgraphPreserveNodes(\n    const HeteroGraph* hg, const std::vector<IdArray>& eids) {\n  CHECK_EQ(eids.size(), hg->NumEdgeTypes())\n      << \"Invalid input: the input list size must be the same as the number of \"\n         \"edge type.\";\n  HeteroSubgraph ret;\n  ret.induced_vertices.resize(hg->NumVertexTypes());\n  ret.induced_edges = eids;\n  // When preserve_nodes is true, simply compute EdgeSubgraph for each bipartite\n  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());\n  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {\n    auto pair = hg->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    const auto& rel_vsg =\n        hg->GetRelationGraph(etype)->EdgeSubgraph({eids[etype]}, true);\n    subrels[etype] = rel_vsg.graph;\n    ret.induced_vertices[src_vtype] = rel_vsg.induced_vertices[0];\n    ret.induced_vertices[dst_vtype] = rel_vsg.induced_vertices[1];\n  }\n  ret.graph = HeteroGraphPtr(\n      new HeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType()));\n  return ret;\n}\n\nHeteroSubgraph EdgeSubgraphNoPreserveNodes(\n    const HeteroGraph* hg, const std::vector<IdArray>& eids) {\n  // TODO(minjie): In general, all relabeling should be separated with subgraph\n  //   operations.\n  CHECK_EQ(eids.size(), hg->NumEdgeTypes())\n      << \"Invalid input: the input list size must be the same as the number of \"\n         \"edge type.\";\n  HeteroSubgraph ret;\n  
ret.induced_vertices.resize(hg->NumVertexTypes());\n  ret.induced_edges = eids;\n  // NOTE(minjie): EdgeSubgraph when preserve_nodes is false is quite\n  // complicated in heterograph. This is because we need to make sure bipartite\n  // graphs that incident on the same vertex type must have the same ID space.\n  // For example, suppose we have following heterograph:\n  //\n  // Meta graph: A -> B -> C\n  // UnitGraph graphs:\n  // * A -> B: (0, 0), (0, 1)\n  // * B -> C: (1, 0), (1, 1)\n  //\n  // Suppose for A->B, we only keep edge (0, 0), while for B->C we only keep (1,\n  // 0). We need to make sure that in the result subgraph, node type B still has\n  // two nodes. This means we cannot simply compute EdgeSubgraph for B->C which\n  // will relabel node#1 of type B to be node #0.\n  //\n  // One implementation is as follows:\n  // (1) For each bipartite graph, slice out the edges using the given eids.\n  // (2) Make a dictionary map<vtype, vector<IdArray>>, where the key is the\n  // vertex type\n  //     and the value is the incident nodes from the bipartite graphs that has\n  //     the vertex type as either srctype or dsttype.\n  // (3) Then for each vertex type, use aten::Relabel_ on its vector<IdArray>.\n  //     aten::Relabel_ computes the union of the vertex sets and relabel\n  //     the unique elements from zero. 
The returned mapping array is the final\n  //     induced vertex set for that vertex type.\n  // (4) Use the relabeled edges to construct the bipartite graph.\n  // step (1) & (2)\n  std::vector<EdgeArray> subedges(hg->NumEdgeTypes());\n  std::vector<std::vector<IdArray>> vtype2incnodes(hg->NumVertexTypes());\n  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {\n    auto pair = hg->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    auto earray = hg->GetRelationGraph(etype)->FindEdges(0, eids[etype]);\n    vtype2incnodes[src_vtype].push_back(earray.src);\n    vtype2incnodes[dst_vtype].push_back(earray.dst);\n    subedges[etype] = earray;\n  }\n  // step (3)\n  std::vector<int64_t> num_vertices_per_type(hg->NumVertexTypes());\n  for (dgl_type_t vtype = 0; vtype < hg->NumVertexTypes(); ++vtype) {\n    ret.induced_vertices[vtype] = aten::Relabel_(vtype2incnodes[vtype]);\n    num_vertices_per_type[vtype] = ret.induced_vertices[vtype]->shape[0];\n  }\n  // step (4)\n  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());\n  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {\n    auto pair = hg->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    subrels[etype] = UnitGraph::CreateFromCOO(\n        (src_vtype == dst_vtype) ? 
1 : 2,\n        ret.induced_vertices[src_vtype]->shape[0],\n        ret.induced_vertices[dst_vtype]->shape[0], subedges[etype].src,\n        subedges[etype].dst);\n  }\n  ret.graph = HeteroGraphPtr(new HeteroGraph(\n      hg->meta_graph(), subrels, std::move(num_vertices_per_type)));\n  return ret;\n}\n\nvoid HeteroGraphSanityCheck(\n    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {\n  // Sanity check\n  CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());\n  CHECK(!rel_graphs.empty()) << \"Empty heterograph is not allowed.\";\n  // all relation graphs must have only one edge type\n  for (const auto& rg : rel_graphs) {\n    CHECK_EQ(rg->NumEdgeTypes(), 1)\n        << \"Each relation graph must have only one edge type.\";\n  }\n  auto ctx = rel_graphs[0]->Context();\n  for (const auto& rg : rel_graphs) {\n    CHECK_EQ(rg->Context(), ctx)\n        << \"Each relation graph must have the same context.\";\n  }\n}\n\nstd::vector<int64_t> InferNumVerticesPerType(\n    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {\n  // create num verts per type\n  std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1);\n\n  EdgeArray etype_array = meta_graph->Edges();\n  dgl_type_t* srctypes = static_cast<dgl_type_t*>(etype_array.src->data);\n  dgl_type_t* dsttypes = static_cast<dgl_type_t*>(etype_array.dst->data);\n  dgl_type_t* etypes = static_cast<dgl_type_t*>(etype_array.id->data);\n  for (size_t i = 0; i < meta_graph->NumEdges(); ++i) {\n    dgl_type_t srctype = srctypes[i];\n    dgl_type_t dsttype = dsttypes[i];\n    dgl_type_t etype = etypes[i];\n    const auto& rg = rel_graphs[etype];\n    const auto sty = 0;\n    const auto dty = rg->NumVertexTypes() == 1 ? 
0 : 1;\n    size_t nv;\n\n    // # nodes of source type\n    nv = rg->NumVertices(sty);\n    if (num_verts_per_type[srctype] < 0)\n      num_verts_per_type[srctype] = nv;\n    else\n      CHECK_EQ(num_verts_per_type[srctype], nv)\n          << \"Mismatch number of vertices for vertex type \" << srctype;\n    // # nodes of destination type\n    nv = rg->NumVertices(dty);\n    if (num_verts_per_type[dsttype] < 0)\n      num_verts_per_type[dsttype] = nv;\n    else\n      CHECK_EQ(num_verts_per_type[dsttype], nv)\n          << \"Mismatch number of vertices for vertex type \" << dsttype;\n  }\n  return num_verts_per_type;\n}\n\nstd::vector<UnitGraphPtr> CastToUnitGraphs(\n    const std::vector<HeteroGraphPtr>& rel_graphs) {\n  std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size());\n  for (size_t i = 0; i < rel_graphs.size(); ++i) {\n    HeteroGraphPtr relg = rel_graphs[i];\n    if (std::dynamic_pointer_cast<UnitGraph>(relg)) {\n      relation_graphs[i] = std::dynamic_pointer_cast<UnitGraph>(relg);\n    } else {\n      relation_graphs[i] = CHECK_NOTNULL(\n          std::dynamic_pointer_cast<UnitGraph>(relg->GetRelationGraph(0)));\n    }\n  }\n  return relation_graphs;\n}\n\n}  // namespace\n\nHeteroGraph::HeteroGraph(\n    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,\n    const std::vector<int64_t>& num_nodes_per_type)\n    : BaseHeteroGraph(meta_graph) {\n  if (num_nodes_per_type.size() == 0)\n    num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);\n  else\n    num_verts_per_type_ = num_nodes_per_type;\n  HeteroGraphSanityCheck(meta_graph, rel_graphs);\n  relation_graphs_ = CastToUnitGraphs(rel_graphs);\n}\n\nbool HeteroGraph::IsMultigraph() const {\n  for (const auto& hg : relation_graphs_) {\n    if (hg->IsMultigraph()) {\n      return true;\n    }\n  }\n  return false;\n}\n\nBoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {\n  CHECK(aten::IsValidIdArray(vids)) << \"Invalid id array input\";\n  
return aten::LT(vids, NumVertices(vtype));\n}\n\nHeteroSubgraph HeteroGraph::VertexSubgraph(\n    const std::vector<IdArray>& vids) const {\n  CHECK_EQ(vids.size(), NumVertexTypes())\n      << \"Invalid input: the input list size must be the same as the number of \"\n         \"vertex types.\";\n  HeteroSubgraph ret;\n  ret.induced_vertices = vids;\n  std::vector<int64_t> num_vertices_per_type(NumVertexTypes());\n  for (dgl_type_t vtype = 0; vtype < NumVertexTypes(); ++vtype)\n    num_vertices_per_type[vtype] = vids[vtype]->shape[0];\n  ret.induced_edges.resize(NumEdgeTypes());\n  std::vector<HeteroGraphPtr> subrels(NumEdgeTypes());\n  for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {\n    auto pair = meta_graph_->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    const std::vector<IdArray> rel_vids =\n        (src_vtype == dst_vtype)\n            ? std::vector<IdArray>({vids[src_vtype]})\n            : std::vector<IdArray>({vids[src_vtype], vids[dst_vtype]});\n    const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph(rel_vids);\n    subrels[etype] = rel_vsg.graph;\n    ret.induced_edges[etype] = rel_vsg.induced_edges[0];\n  }\n  ret.graph = HeteroGraphPtr(\n      new HeteroGraph(meta_graph_, subrels, std::move(num_vertices_per_type)));\n  return ret;\n}\n\nHeteroSubgraph HeteroGraph::EdgeSubgraph(\n    const std::vector<IdArray>& eids, bool preserve_nodes) const {\n  if (preserve_nodes) {\n    return EdgeSubgraphPreserveNodes(this, eids);\n  } else {\n    return EdgeSubgraphNoPreserveNodes(this, eids);\n  }\n}\n\nHeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {\n  auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);\n  CHECK_NOTNULL(hgindex);\n  std::vector<HeteroGraphPtr> rel_graphs;\n  for (auto g : hgindex->relation_graphs_) {\n    rel_graphs.push_back(UnitGraph::AsNumBits(g, bits));\n  }\n  return HeteroGraphPtr(new HeteroGraph(\n      
hgindex->meta_graph_, rel_graphs, hgindex->num_verts_per_type_));\n}\n\nHeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {\n  if (ctx == g->Context()) {\n    return g;\n  }\n  auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);\n  CHECK_NOTNULL(hgindex);\n  std::vector<HeteroGraphPtr> rel_graphs;\n  for (auto g : hgindex->relation_graphs_) {\n    rel_graphs.push_back(UnitGraph::CopyTo(g, ctx));\n  }\n  return HeteroGraphPtr(new HeteroGraph(\n      hgindex->meta_graph_, rel_graphs, hgindex->num_verts_per_type_));\n}\n\nHeteroGraphPtr HeteroGraph::PinMemory(HeteroGraphPtr g) {\n  auto casted_ptr = std::dynamic_pointer_cast<HeteroGraph>(g);\n  CHECK_NOTNULL(casted_ptr);\n  auto relation_graphs = casted_ptr->relation_graphs_;\n\n  auto it = std::find_if_not(\n      relation_graphs.begin(), relation_graphs.end(),\n      [](auto& underlying_g) { return underlying_g->IsPinned(); });\n  // All underlying relation graphs are pinned, return the input hetero-graph\n  // directly.\n  if (it == relation_graphs.end()) return g;\n\n  std::vector<HeteroGraphPtr> pinned_relation_graphs(relation_graphs.size());\n  for (size_t i = 0; i < pinned_relation_graphs.size(); ++i) {\n    if (!relation_graphs[i]->IsPinned()) {\n      pinned_relation_graphs[i] = relation_graphs[i]->PinMemory();\n    } else {\n      pinned_relation_graphs[i] = relation_graphs[i];\n    }\n  }\n  return HeteroGraphPtr(new HeteroGraph(\n      casted_ptr->meta_graph_, pinned_relation_graphs,\n      casted_ptr->num_verts_per_type_));\n}\n\nvoid HeteroGraph::PinMemory_() {\n  for (auto g : relation_graphs_) g->PinMemory_();\n}\n\nvoid HeteroGraph::UnpinMemory_() {\n  for (auto g : relation_graphs_) g->UnpinMemory_();\n}\n\nvoid HeteroGraph::RecordStream(DGLStreamHandle stream) {\n  for (auto g : relation_graphs_) g->RecordStream(stream);\n}\n\nstd::string HeteroGraph::SharedMemName() const {\n  return shared_mem_ ? 
shared_mem_->GetName() : \"\";\n}\n\nHeteroGraphPtr HeteroGraph::CopyToSharedMem(\n    HeteroGraphPtr g, const std::string& name,\n    const std::vector<std::string>& ntypes,\n    const std::vector<std::string>& etypes, const std::set<std::string>& fmts) {\n  // TODO(JJ): Raise error when calling shared_memory if graph index is on gpu\n  auto hg = std::dynamic_pointer_cast<HeteroGraph>(g);\n  CHECK_NOTNULL(hg);\n  if (hg->SharedMemName() == name) return g;\n\n  // Copy buffer to share memory\n  auto mem = std::make_shared<SharedMemory>(name);\n  auto mem_buf = mem->CreateNew(SHARED_MEM_METAINFO_SIZE_MAX);\n  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);\n  SharedMemManager shm(name, &strm);\n\n  bool has_coo = fmts.find(\"coo\") != fmts.end();\n  bool has_csr = fmts.find(\"csr\") != fmts.end();\n  bool has_csc = fmts.find(\"csc\") != fmts.end();\n  shm.Write(g->NumBits());\n  shm.Write(has_coo);\n  shm.Write(has_csr);\n  shm.Write(has_csc);\n  shm.Write(ImmutableGraph::ToImmutable(hg->meta_graph_));\n  shm.Write(hg->num_verts_per_type_);\n\n  std::vector<HeteroGraphPtr> relgraphs(g->NumEdgeTypes());\n\n  for (dgl_type_t etype = 0; etype < g->NumEdgeTypes(); ++etype) {\n    auto src_dst_type = g->GetEndpointTypes(etype);\n    int num_vtypes = (src_dst_type.first == src_dst_type.second ? 
1 : 2);\n    aten::COOMatrix coo;\n    aten::CSRMatrix csr, csc;\n    std::string prefix = name + \"_\" + std::to_string(etype);\n    if (has_coo) {\n      coo = shm.CopyToSharedMem(hg->GetCOOMatrix(etype), prefix + \"_coo\");\n    }\n    if (has_csr) {\n      csr = shm.CopyToSharedMem(hg->GetCSRMatrix(etype), prefix + \"_csr\");\n    }\n    if (has_csc) {\n      csc = shm.CopyToSharedMem(hg->GetCSCMatrix(etype), prefix + \"_csc\");\n    }\n    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(\n        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);\n  }\n\n  auto ret = std::shared_ptr<HeteroGraph>(\n      new HeteroGraph(hg->meta_graph_, relgraphs, hg->num_verts_per_type_));\n  ret->shared_mem_ = mem;\n\n  shm.Write(ntypes);\n  shm.Write(etypes);\n  return ret;\n}\n\nstd::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>\nHeteroGraph::CreateFromSharedMem(const std::string& name) {\n  bool exist = SharedMemory::Exist(name);\n  if (!exist) {\n    return std::make_tuple(\n        nullptr, std::vector<std::string>(), std::vector<std::string>());\n  }\n  auto mem = std::make_shared<SharedMemory>(name);\n  auto mem_buf = mem->Open(SHARED_MEM_METAINFO_SIZE_MAX);\n  dmlc::MemoryFixedSizeStream strm(mem_buf, SHARED_MEM_METAINFO_SIZE_MAX);\n  SharedMemManager shm(name, &strm);\n\n  uint8_t nbits;\n  CHECK(shm.Read(&nbits)) << \"invalid nbits (unit8_t)\";\n\n  bool has_coo, has_csr, has_csc;\n  CHECK(shm.Read(&has_coo)) << \"invalid nbits (unit8_t)\";\n  CHECK(shm.Read(&has_csr)) << \"invalid csr (unit8_t)\";\n  CHECK(shm.Read(&has_csc)) << \"invalid csc (unit8_t)\";\n\n  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();\n  CHECK(shm.Read(&meta_imgraph)) << \"Invalid meta graph\";\n  GraphPtr metagraph = meta_imgraph;\n\n  std::vector<int64_t> num_verts_per_type;\n  CHECK(shm.Read(&num_verts_per_type)) << \"Invalid number of vertices per type\";\n\n  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());\n  for (dgl_type_t 
etype = 0; etype < metagraph->NumEdges(); ++etype) {\n    auto src_dst = metagraph->FindEdge(etype);\n    int num_vtypes = (src_dst.first == src_dst.second) ? 1 : 2;\n    aten::COOMatrix coo;\n    aten::CSRMatrix csr, csc;\n    std::string prefix = name + \"_\" + std::to_string(etype);\n    if (has_coo) {\n      shm.CreateFromSharedMem(&coo, prefix + \"_coo\");\n    }\n    if (has_csr) {\n      shm.CreateFromSharedMem(&csr, prefix + \"_csr\");\n    }\n    if (has_csc) {\n      shm.CreateFromSharedMem(&csc, prefix + \"_csc\");\n    }\n\n    relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(\n        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);\n  }\n\n  auto ret =\n      std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);\n  ret->shared_mem_ = mem;\n\n  std::vector<std::string> ntypes;\n  std::vector<std::string> etypes;\n  CHECK(shm.Read(&ntypes)) << \"invalid ntypes\";\n  CHECK(shm.Read(&etypes)) << \"invalid etypes\";\n  return std::make_tuple(ret, ntypes, etypes);\n}\n\nHeteroGraphPtr HeteroGraph::GetGraphInFormat(dgl_format_code_t formats) const {\n  std::vector<HeteroGraphPtr> format_rels(NumEdgeTypes());\n  for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {\n    auto relgraph =\n        std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype));\n    format_rels[etype] = relgraph->GetGraphInFormat(formats);\n  }\n  return HeteroGraphPtr(\n      new HeteroGraph(meta_graph_, format_rels, NumVerticesPerType()));\n}\n\nFlattenedHeteroGraphPtr HeteroGraph::Flatten(\n    const std::vector<dgl_type_t>& etypes) const {\n  const int64_t bits = NumBits();\n  if (bits == 32) {\n    return FlattenImpl<int32_t>(etypes);\n  } else {\n    return FlattenImpl<int64_t>(etypes);\n  }\n}\n\ntemplate <class IdType>\nFlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(\n    const std::vector<dgl_type_t>& etypes) const {\n  std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;\n  size_t src_nodes = 0, dst_nodes = 0;\n  
std::vector<dgl_type_t> induced_srctype, induced_dsttype;\n  std::vector<IdType> induced_srcid, induced_dstid;\n  std::vector<dgl_type_t> srctype_set, dsttype_set;\n\n  // XXXtype_offsets contain the mapping from node type and number of nodes\n  // after this loop.\n  for (dgl_type_t etype : etypes) {\n    auto src_dsttype = meta_graph_->FindEdge(etype);\n    dgl_type_t srctype = src_dsttype.first;\n    dgl_type_t dsttype = src_dsttype.second;\n    size_t num_srctype_nodes = NumVertices(srctype);\n    size_t num_dsttype_nodes = NumVertices(dsttype);\n\n    if (srctype_offsets.count(srctype) == 0) {\n      srctype_offsets[srctype] = num_srctype_nodes;\n      srctype_set.push_back(srctype);\n    }\n    if (dsttype_offsets.count(dsttype) == 0) {\n      dsttype_offsets[dsttype] = num_dsttype_nodes;\n      dsttype_set.push_back(dsttype);\n    }\n  }\n  // Sort the node types so that we can compare the sets and decide whether a\n  // homogeneous graph should be returned.\n  std::sort(srctype_set.begin(), srctype_set.end());\n  std::sort(dsttype_set.begin(), dsttype_set.end());\n  bool homograph =\n      (srctype_set.size() == dsttype_set.size()) &&\n      std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());\n\n  // XXXtype_offsets contain the mapping from node type to node ID offsets after\n  // these two loops.\n  for (size_t i = 0; i < srctype_set.size(); ++i) {\n    dgl_type_t ntype = srctype_set[i];\n    size_t num_nodes = srctype_offsets[ntype];\n    srctype_offsets[ntype] = src_nodes;\n    src_nodes += num_nodes;\n    for (size_t j = 0; j < num_nodes; ++j) {\n      induced_srctype.push_back(ntype);\n      induced_srcid.push_back(j);\n    }\n  }\n  for (size_t i = 0; i < dsttype_set.size(); ++i) {\n    dgl_type_t ntype = dsttype_set[i];\n    size_t num_nodes = dsttype_offsets[ntype];\n    dsttype_offsets[ntype] = dst_nodes;\n    dst_nodes += num_nodes;\n    for (size_t j = 0; j < num_nodes; ++j) {\n      induced_dsttype.push_back(ntype);\n      
induced_dstid.push_back(j);\n    }\n  }\n\n  // TODO(minjie): Using concat operations cause many fragmented memory.\n  //   Need to optimize it in the future.\n  std::vector<IdArray> src_arrs, dst_arrs, eid_arrs, induced_etypes;\n  src_arrs.reserve(etypes.size());\n  dst_arrs.reserve(etypes.size());\n  eid_arrs.reserve(etypes.size());\n  induced_etypes.reserve(etypes.size());\n  for (dgl_type_t etype : etypes) {\n    auto src_dsttype = meta_graph_->FindEdge(etype);\n    dgl_type_t srctype = src_dsttype.first;\n    dgl_type_t dsttype = src_dsttype.second;\n    size_t srctype_offset = srctype_offsets[srctype];\n    size_t dsttype_offset = dsttype_offsets[dsttype];\n\n    EdgeArray edges = Edges(etype);\n    size_t num_edges = NumEdges(etype);\n    src_arrs.push_back(edges.src + srctype_offset);\n    dst_arrs.push_back(edges.dst + dsttype_offset);\n    eid_arrs.push_back(edges.id);\n    induced_etypes.push_back(\n        aten::Full(etype, num_edges, NumBits(), Context()));\n  }\n\n  HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(\n      homograph ? 
1 : 2, src_nodes, dst_nodes, aten::Concat(src_arrs),\n      aten::Concat(dst_arrs));\n\n  // Sanity check\n  CHECK_EQ(gptr->Context(), Context());\n  CHECK_EQ(gptr->NumBits(), NumBits());\n\n  FlattenedHeteroGraph* result = new FlattenedHeteroGraph;\n  result->graph = HeteroGraphRef(\n      HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr})));\n  result->induced_srctype =\n      aten::VecToIdArray(induced_srctype).CopyTo(Context());\n  result->induced_srctype_set =\n      aten::VecToIdArray(srctype_set).CopyTo(Context());\n  result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context());\n  result->induced_etype = aten::Concat(induced_etypes);\n  result->induced_etype_set = aten::VecToIdArray(etypes).CopyTo(Context());\n  result->induced_eid = aten::Concat(eid_arrs);\n  result->induced_dsttype =\n      aten::VecToIdArray(induced_dsttype).CopyTo(Context());\n  result->induced_dsttype_set =\n      aten::VecToIdArray(dsttype_set).CopyTo(Context());\n  result->induced_dstid = aten::VecToIdArray(induced_dstid).CopyTo(Context());\n  return FlattenedHeteroGraphPtr(result);\n}\n\nconstexpr uint64_t kDGLSerialize_HeteroGraph = 0xDD589FBE35224ABF;\n\nbool HeteroGraph::Load(dmlc::Stream* fs) {\n  uint64_t magicNum;\n  CHECK(fs->Read(&magicNum)) << \"Invalid Magic Number\";\n  CHECK_EQ(magicNum, kDGLSerialize_HeteroGraph) << \"Invalid HeteroGraph Data\";\n  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();\n  CHECK(fs->Read(&meta_imgraph)) << \"Invalid meta graph\";\n  meta_graph_ = meta_imgraph;\n  CHECK(fs->Read(&relation_graphs_)) << \"Invalid relation_graphs_\";\n  CHECK(fs->Read(&num_verts_per_type_)) << \"Invalid num_verts_per_type_\";\n  return true;\n}\n\nvoid HeteroGraph::Save(dmlc::Stream* fs) const {\n  fs->Write(kDGLSerialize_HeteroGraph);\n  auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());\n  fs->Write(meta_graph_ptr);\n  fs->Write(relation_graphs_);\n  fs->Write(num_verts_per_type_);\n}\n\nGraphPtr 
HeteroGraph::AsImmutableGraph() const {\n  CHECK(NumVertexTypes() == 1) << \"graph has more than one node types\";\n  CHECK(NumEdgeTypes() == 1) << \"graph has more than one edge types\";\n  auto unit_graph =\n      CHECK_NOTNULL(std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0)));\n  return unit_graph->AsImmutableGraph();\n}\n\nHeteroGraphPtr HeteroGraph::LineGraph(bool backtracking) const {\n  CHECK_EQ(1, meta_graph_->NumEdges())\n      << \"Only support Homogeneous graph now (one edge type)\";\n  CHECK_EQ(1, meta_graph_->NumVertices())\n      << \"Only support Homogeneous graph now (one node type)\";\n  CHECK_EQ(1, relation_graphs_.size()) << \"Only support Homogeneous graph now\";\n  UnitGraphPtr ug = relation_graphs_[0];\n\n  const auto& ulg = ug->LineGraph(backtracking);\n  std::vector<HeteroGraphPtr> rel_graph = {ulg};\n  std::vector<int64_t> num_nodes_per_type = {\n      static_cast<int64_t>(ulg->NumVertices(0))};\n  return HeteroGraphPtr(\n      new HeteroGraph(meta_graph_, rel_graph, std::move(num_nodes_per_type)));\n}\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/heterograph.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/heterograph.h\n * @brief Heterograph\n */\n\n#ifndef DGL_GRAPH_HETEROGRAPH_H_\n#define DGL_GRAPH_HETEROGRAPH_H_\n\n#include <dgl/base_heterograph.h>\n#include <dgl/lazy.h>\n#include <dgl/runtime/shared_mem.h>\n\n#include <memory>\n#include <set>\n#include <string>\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"./unit_graph.h\"\n#include \"shared_mem_manager.h\"\n\nnamespace dgl {\n\n/** @brief Heterograph */\nclass HeteroGraph : public BaseHeteroGraph {\n public:\n  HeteroGraph(\n      GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,\n      const std::vector<int64_t>& num_nodes_per_type = {});\n\n  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {\n    CHECK_LT(etype, meta_graph_->NumEdges()) << \"Invalid edge type: \" << etype;\n    return relation_graphs_[etype];\n  }\n\n  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {\n    LOG(FATAL) << \"Bipartite graph is not mutable.\";\n  }\n\n  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {\n    LOG(FATAL) << \"Bipartite graph is not mutable.\";\n  }\n\n  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {\n    LOG(FATAL) << \"Bipartite graph is not mutable.\";\n  }\n\n  void Clear() override { LOG(FATAL) << \"Bipartite graph is not mutable.\"; }\n\n  DGLDataType DataType() const override {\n    return relation_graphs_[0]->DataType();\n  }\n\n  DGLContext Context() const override { return relation_graphs_[0]->Context(); }\n\n  bool IsPinned() const override { return relation_graphs_[0]->IsPinned(); }\n\n  uint8_t NumBits() const override { return relation_graphs_[0]->NumBits(); }\n\n  bool IsMultigraph() const override;\n\n  bool IsReadonly() const override { return true; }\n\n  uint64_t NumVertices(dgl_type_t vtype) const override {\n    CHECK(meta_graph_->HasVertex(vtype)) << \"Invalid vertex type: \" << vtype;\n    
return num_verts_per_type_[vtype];\n  }\n\n  inline std::vector<int64_t> NumVerticesPerType() const override {\n    return num_verts_per_type_;\n  }\n\n  uint64_t NumEdges(dgl_type_t etype) const override {\n    return GetRelationGraph(etype)->NumEdges(0);\n  }\n\n  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {\n    return vid < NumVertices(vtype);\n  }\n\n  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;\n\n  bool HasEdgeBetween(\n      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {\n    return GetRelationGraph(etype)->HasEdgeBetween(0, src, dst);\n  }\n\n  BoolArray HasEdgesBetween(\n      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {\n    return GetRelationGraph(etype)->HasEdgesBetween(0, src_ids, dst_ids);\n  }\n\n  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {\n    return GetRelationGraph(etype)->Predecessors(0, dst);\n  }\n\n  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {\n    return GetRelationGraph(etype)->Successors(0, src);\n  }\n\n  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {\n    return GetRelationGraph(etype)->EdgeId(0, src, dst);\n  }\n\n  EdgeArray EdgeIdsAll(\n      dgl_type_t etype, IdArray src, IdArray dst) const override {\n    return GetRelationGraph(etype)->EdgeIdsAll(0, src, dst);\n  }\n\n  IdArray EdgeIdsOne(\n      dgl_type_t etype, IdArray src, IdArray dst) const override {\n    return GetRelationGraph(etype)->EdgeIdsOne(0, src, dst);\n  }\n\n  std::pair<dgl_id_t, dgl_id_t> FindEdge(\n      dgl_type_t etype, dgl_id_t eid) const override {\n    return GetRelationGraph(etype)->FindEdge(0, eid);\n  }\n\n  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {\n    return GetRelationGraph(etype)->FindEdges(0, eids);\n  }\n\n  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {\n    return GetRelationGraph(etype)->InEdges(0, vid);\n  }\n\n  EdgeArray 
InEdges(dgl_type_t etype, IdArray vids) const override {\n    return GetRelationGraph(etype)->InEdges(0, vids);\n  }\n\n  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {\n    return GetRelationGraph(etype)->OutEdges(0, vid);\n  }\n\n  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {\n    return GetRelationGraph(etype)->OutEdges(0, vids);\n  }\n\n  EdgeArray Edges(\n      dgl_type_t etype, const std::string& order = \"\") const override {\n    return GetRelationGraph(etype)->Edges(0, order);\n  }\n\n  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {\n    return GetRelationGraph(etype)->InDegree(0, vid);\n  }\n\n  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {\n    return GetRelationGraph(etype)->InDegrees(0, vids);\n  }\n\n  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {\n    return GetRelationGraph(etype)->OutDegree(0, vid);\n  }\n\n  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {\n    return GetRelationGraph(etype)->OutDegrees(0, vids);\n  }\n\n  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {\n    return GetRelationGraph(etype)->SuccVec(0, vid);\n  }\n\n  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {\n    return GetRelationGraph(etype)->OutEdgeVec(0, vid);\n  }\n\n  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {\n    return GetRelationGraph(etype)->PredVec(0, vid);\n  }\n\n  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {\n    return GetRelationGraph(etype)->InEdgeVec(0, vid);\n  }\n\n  std::vector<IdArray> GetAdj(\n      dgl_type_t etype, bool transpose, const std::string& fmt) const override {\n    return GetRelationGraph(etype)->GetAdj(0, transpose, fmt);\n  }\n\n  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {\n    return GetRelationGraph(etype)->GetCOOMatrix(0);\n  }\n\n  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const 
override {\n    return GetRelationGraph(etype)->GetCSCMatrix(0);\n  }\n\n  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {\n    return GetRelationGraph(etype)->GetCSRMatrix(0);\n  }\n\n  SparseFormat SelectFormat(\n      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {\n    return GetRelationGraph(etype)->SelectFormat(0, preferred_formats);\n  }\n\n  dgl_format_code_t GetAllowedFormats() const override {\n    return GetRelationGraph(0)->GetAllowedFormats();\n  }\n\n  dgl_format_code_t GetCreatedFormats() const override {\n    return GetRelationGraph(0)->GetCreatedFormats();\n  }\n\n  HeteroSubgraph VertexSubgraph(\n      const std::vector<IdArray>& vids) const override;\n\n  HeteroSubgraph EdgeSubgraph(\n      const std::vector<IdArray>& eids,\n      bool preserve_nodes = false) const override;\n\n  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;\n\n  FlattenedHeteroGraphPtr Flatten(\n      const std::vector<dgl_type_t>& etypes) const override;\n\n  GraphPtr AsImmutableGraph() const override;\n\n  /** @return Load HeteroGraph from stream, using CSRMatrix*/\n  bool Load(dmlc::Stream* fs);\n\n  /** @return Save HeteroGraph to stream, using CSRMatrix */\n  void Save(dmlc::Stream* fs) const;\n\n  /** @brief Convert the graph to use the given number of bits for storage */\n  static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);\n\n  /** @brief Copy the data to another context */\n  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext& ctx);\n\n  /**\n   * @brief Pin all relation graphs of the current graph.\n   * @note The graph will be pinned inplace. Behavior depends on the current\n   * context, kDGLCPU: will be pinned; IsPinned: directly return; kDGLCUDA:\n   * invalid, will throw an error. 
The context check is deferred to pinning the\n   * NDArray.\n   */\n  void PinMemory_() override;\n\n  /**\n   * @brief Unpin all relation graphs of the current graph.\n   * @note The graph will be unpinned inplace. Behavior depends on the current\n   * context, IsPinned: will be unpinned; others: directly return. The context\n   * check is deferred to unpinning the NDArray.\n   */\n  void UnpinMemory_();\n\n  /**\n   * @brief Copy the current graph to pinned memory managed by\n   *     PyTorch CachingHostAllocator for each relation graph.\n   * @note If any of the underlying relation graphs are already pinned, the\n   *     function will utilize their existing copies. If all of them are\n   *     pinned, the function will return the original input hetero-graph\n   *     directly.\n   */\n  static HeteroGraphPtr PinMemory(HeteroGraphPtr g);\n\n  /**\n   * @brief Record stream for this graph.\n   * @param stream The stream that is using the graph\n   */\n  void RecordStream(DGLStreamHandle stream) override;\n\n  /**\n   * @brief Copy the data to shared memory.\n   *\n   * Also save names of node types and edge types of the HeteroGraph object to\n   * shared memory\n   */\n  static HeteroGraphPtr CopyToSharedMem(\n      HeteroGraphPtr g, const std::string& name,\n      const std::vector<std::string>& ntypes,\n      const std::vector<std::string>& etypes,\n      const std::set<std::string>& fmts);\n\n  /**\n   * @brief Create a heterograph from\n   *\n   * @return the HeteroGraphPtr, names of node types, names of edge types\n   */\n  static std::tuple<\n      HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>\n  CreateFromSharedMem(const std::string& name);\n\n  /** @brief Creat a LineGraph of self */\n  HeteroGraphPtr LineGraph(bool backtracking) const;\n\n  const std::vector<UnitGraphPtr>& relation_graphs() const {\n    return relation_graphs_;\n  }\n\n private:\n  // To create empty class\n  friend class Serializer;\n\n  // Empty Constructor, only 
for serializer\n  HeteroGraph() : BaseHeteroGraph() {}\n\n  /** @brief A map from edge type to unit graph */\n  std::vector<UnitGraphPtr> relation_graphs_;\n\n  /** @brief A map from vert type to the number of verts in the type */\n  std::vector<int64_t> num_verts_per_type_;\n\n  /** @brief The shared memory object for meta info*/\n  std::shared_ptr<runtime::SharedMemory> shared_mem_;\n\n  /**\n   * @brief The name of the shared memory. Return empty string if it is not in\n   * shared memory.\n   */\n  std::string SharedMemName() const;\n\n  /**\n   * @brief template class for Flatten operation\n   *\n   * @tparam IdType Graph's index data type, can be int32_t or int64_t\n   * @param etypes vector of etypes to be falttened\n   * @return pointer of FlattenedHeteroGraphh\n   */\n  template <class IdType>\n  FlattenedHeteroGraphPtr FlattenImpl(\n      const std::vector<dgl_type_t>& etypes) const;\n};\n\n}  // namespace dgl\n\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, dgl::HeteroGraph, true);\n}  // namespace dmlc\n\n#endif  // DGL_GRAPH_HETEROGRAPH_H_\n"
  },
  {
    "path": "src/graph/heterograph_capi.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file graph/heterograph_capi.cc\n * @brief Heterograph CAPI bindings.\n */\n#include <dgl/array.h>\n#include <dgl/aten/coo.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/c_runtime_api.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <set>\n\n#include \"../c_api_common.h\"\n#include \"./heterograph.h\"\n#include \"unit_graph.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\n\n///////////////////////// Unitgraph functions /////////////////////////\n\n// XXX(minjie): Ideally, Unitgraph should be invisible to python side\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      int64_t nvtypes = args[0];\n      int64_t num_src = args[1];\n      int64_t num_dst = args[2];\n      IdArray row = args[3];\n      IdArray col = args[4];\n      List<Value> formats = args[5];\n      bool row_sorted = args[6];\n      bool col_sorted = args[7];\n      std::vector<SparseFormat> formats_vec;\n      for (Value val : formats) {\n        std::string fmt = val->data;\n        formats_vec.push_back(ParseSparseFormat(fmt));\n      }\n      const auto code = SparseFormatsToCode(formats_vec);\n      auto hgptr = CreateFromCOO(\n          nvtypes, num_src, num_dst, row, col, row_sorted, col_sorted, code);\n      *rv = HeteroGraphRef(hgptr);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      int64_t nvtypes = args[0];\n      int64_t num_src = args[1];\n      int64_t num_dst = args[2];\n      IdArray indptr = args[3];\n      IdArray indices = args[4];\n      IdArray edge_ids = args[5];\n      List<Value> formats = args[6];\n      bool transpose = args[7];\n      std::vector<SparseFormat> formats_vec;\n      for (Value val : formats) {\n        std::string 
fmt = val->data;\n        formats_vec.push_back(ParseSparseFormat(fmt));\n      }\n      const auto code = SparseFormatsToCode(formats_vec);\n      if (!transpose) {\n        auto hgptr = CreateFromCSR(\n            nvtypes, num_src, num_dst, indptr, indices, edge_ids, code);\n        *rv = HeteroGraphRef(hgptr);\n      } else {\n        auto hgptr = CreateFromCSC(\n            nvtypes, num_src, num_dst, indptr, indices, edge_ids, code);\n        *rv = HeteroGraphRef(hgptr);\n      }\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroCreateHeteroGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef meta_graph = args[0];\n      List<HeteroGraphRef> rel_graphs = args[1];\n      std::vector<HeteroGraphPtr> rel_ptrs;\n      rel_ptrs.reserve(rel_graphs.size());\n      for (const auto& ref : rel_graphs) {\n        rel_ptrs.push_back(ref.sptr());\n      }\n      auto hgptr = CreateHeteroGraph(meta_graph.sptr(), rel_ptrs);\n      *rv = HeteroGraphRef(hgptr);\n    });\n\nDGL_REGISTER_GLOBAL(\n    \"heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNodes\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef meta_graph = args[0];\n      List<HeteroGraphRef> rel_graphs = args[1];\n      IdArray num_nodes_per_type = args[2];\n      std::vector<HeteroGraphPtr> rel_ptrs;\n      rel_ptrs.reserve(rel_graphs.size());\n      for (const auto& ref : rel_graphs) {\n        rel_ptrs.push_back(ref.sptr());\n      }\n      auto hgptr = CreateHeteroGraph(\n          meta_graph.sptr(), rel_ptrs, num_nodes_per_type.ToVector<int64_t>());\n      *rv = HeteroGraphRef(hgptr);\n    });\n\n///////////////////////// HeteroGraph member functions /////////////////////////\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroGetMetaGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      *rv = hg->meta_graph();\n    
});\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroIsMetaGraphUniBipartite\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      GraphPtr mg = hg->meta_graph();\n      *rv = mg->IsUniBipartite();\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroGetRelationGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      CHECK_LE(etype, hg->NumEdgeTypes()) << \"invalid edge type \" << etype;\n      auto unit_graph = hg->GetRelationGraph(etype);\n      auto meta_graph = unit_graph->meta_graph();\n      auto hgptr = CreateHeteroGraph(\n          meta_graph, {unit_graph}, unit_graph->NumVerticesPerType());\n      *rv = HeteroGraphRef(hgptr);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroGetFlattenedGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      List<Value> etypes = args[1];\n      std::vector<dgl_id_t> etypes_vec;\n      for (Value val : etypes) {\n        // (gq) have to decompose it into two statements because of a weird MSVC\n        // internal error\n        dgl_id_t id = val->data;\n        etypes_vec.push_back(id);\n      }\n\n      *rv = FlattenedHeteroGraphRef(hg->Flatten(etypes_vec));\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroAddVertices\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t vtype = args[1];\n      int64_t num = args[2];\n      hg->AddVertices(vtype, num);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroAddEdge\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      dgl_id_t src = args[2];\n      dgl_id_t dst = args[3];\n      hg->AddEdge(etype, src, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroAddEdges\")\n    
.set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      IdArray src = args[2];\n      IdArray dst = args[3];\n      hg->AddEdges(etype, src, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroClear\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      hg->Clear();\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroDataType\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      *rv = hg->DataType();\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroContext\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      *rv = hg->Context();\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroIsPinned\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      *rv = hg->IsPinned();\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroNumBits\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      *rv = hg->NumBits();\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroIsMultigraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      *rv = hg->IsMultigraph();\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroIsReadonly\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      *rv = hg->IsReadonly();\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroNumVertices\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t vtype = args[1];\n      *rv = static_cast<int64_t>(hg->NumVertices(vtype));\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroNumEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = 
args[0];\n      dgl_type_t etype = args[1];\n      *rv = static_cast<int64_t>(hg->NumEdges(etype));\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroHasVertex\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t vtype = args[1];\n      dgl_id_t vid = args[2];\n      *rv = hg->HasVertex(vtype, vid);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroHasVertices\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t vtype = args[1];\n      IdArray vids = args[2];\n      *rv = hg->HasVertices(vtype, vids);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroHasEdgeBetween\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      dgl_id_t src = args[2];\n      dgl_id_t dst = args[3];\n      *rv = hg->HasEdgeBetween(etype, src, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroHasEdgesBetween\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      IdArray src = args[2];\n      IdArray dst = args[3];\n      *rv = hg->HasEdgesBetween(etype, src, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroPredecessors\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      dgl_id_t dst = args[2];\n      *rv = hg->Predecessors(etype, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroSuccessors\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      dgl_id_t src = args[2];\n      *rv = hg->Successors(etype, src);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroEdgeId\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = 
args[0];\n      dgl_type_t etype = args[1];\n      dgl_id_t src = args[2];\n      dgl_id_t dst = args[3];\n      *rv = hg->EdgeId(etype, src, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroEdgeIdsAll\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      IdArray src = args[2];\n      IdArray dst = args[3];\n      const auto& ret = hg->EdgeIdsAll(etype, src, dst);\n      *rv = ConvertEdgeArrayToPackedFunc(ret);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroEdgeIdsOne\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      IdArray src = args[2];\n      IdArray dst = args[3];\n      *rv = hg->EdgeIdsOne(etype, src, dst);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroFindEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      IdArray eids = args[2];\n      const auto& ret = hg->FindEdges(etype, eids);\n      *rv = ConvertEdgeArrayToPackedFunc(ret);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroInEdges_1\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      dgl_id_t vid = args[2];\n      const auto& ret = hg->InEdges(etype, vid);\n      *rv = ConvertEdgeArrayToPackedFunc(ret);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroInEdges_2\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      IdArray vids = args[2];\n      const auto& ret = hg->InEdges(etype, vids);\n      *rv = ConvertEdgeArrayToPackedFunc(ret);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroOutEdges_1\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      
dgl_type_t etype = args[1];\n      dgl_id_t vid = args[2];\n      const auto& ret = hg->OutEdges(etype, vid);\n      *rv = ConvertEdgeArrayToPackedFunc(ret);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroOutEdges_2\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      IdArray vids = args[2];\n      const auto& ret = hg->OutEdges(etype, vids);\n      *rv = ConvertEdgeArrayToPackedFunc(ret);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      std::string order = args[2];\n      const auto& ret = hg->Edges(etype, order);\n      *rv = ConvertEdgeArrayToPackedFunc(ret);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroInDegree\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      dgl_id_t vid = args[2];\n      *rv = static_cast<int64_t>(hg->InDegree(etype, vid));\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroInDegrees\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      IdArray vids = args[2];\n      *rv = hg->InDegrees(etype, vids);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroOutDegree\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      dgl_id_t vid = args[2];\n      *rv = static_cast<int64_t>(hg->OutDegree(etype, vid));\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroOutDegrees\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      IdArray vids = args[2];\n      *rv = hg->OutDegrees(etype, vids);\n    
});\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroGetAdj\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      bool transpose = args[2];\n      std::string fmt = args[3];\n      *rv = ConvertNDArrayVectorToPackedFunc(hg->GetAdj(etype, transpose, fmt));\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroVertexSubgraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      List<Value> vids = args[1];\n      std::vector<IdArray> vid_vec;\n      vid_vec.reserve(vids.size());\n      for (Value val : vids) {\n        vid_vec.push_back(val->data);\n      }\n      std::shared_ptr<HeteroSubgraph> subg(\n          new HeteroSubgraph(hg->VertexSubgraph(vid_vec)));\n      *rv = HeteroSubgraphRef(subg);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroEdgeSubgraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      List<Value> eids = args[1];\n      bool preserve_nodes = args[2];\n      std::vector<IdArray> eid_vec;\n      eid_vec.reserve(eids.size());\n      for (Value val : eids) {\n        eid_vec.push_back(val->data);\n      }\n      std::shared_ptr<HeteroSubgraph> subg(\n          new HeteroSubgraph(hg->EdgeSubgraph(eid_vec, preserve_nodes)));\n      *rv = HeteroSubgraphRef(subg);\n    });\n\n///////////////////////// HeteroSubgraph members /////////////////////////\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroSubgraphGetGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroSubgraphRef subg = args[0];\n      *rv = HeteroGraphRef(subg->graph);\n    });\n\nDGL_REGISTER_GLOBAL(\n    \"heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroSubgraphRef subg = args[0];\n      List<Value> induced_verts;\n      for (IdArray arr : subg->induced_vertices) {\n        
induced_verts.push_back(Value(MakeValue(arr)));\n      }\n      *rv = induced_verts;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroSubgraphGetInducedEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroSubgraphRef subg = args[0];\n      List<Value> induced_edges;\n      for (IdArray arr : subg->induced_edges) {\n        induced_edges.push_back(Value(MakeValue(arr)));\n      }\n      *rv = induced_edges;\n    });\n\n///////////////////////// Global functions and algorithms\n////////////////////////////\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroAsNumBits\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      int bits = args[1];\n      HeteroGraphPtr bhg_ptr = hg.sptr();\n      auto hg_ptr = std::dynamic_pointer_cast<HeteroGraph>(bhg_ptr);\n      HeteroGraphPtr hg_new;\n      if (hg_ptr) {\n        hg_new = HeteroGraph::AsNumBits(hg_ptr, bits);\n      } else {\n        hg_new = UnitGraph::AsNumBits(bhg_ptr, bits);\n      }\n      *rv = HeteroGraphRef(hg_new);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroCopyTo\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      int device_type = args[1];\n      int device_id = args[2];\n      DGLContext ctx;\n      ctx.device_type = static_cast<DGLDeviceType>(device_type);\n      ctx.device_id = device_id;\n      HeteroGraphPtr hg_new = HeteroGraph::CopyTo(hg.sptr(), ctx);\n      *rv = HeteroGraphRef(hg_new);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroPinMemory\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      HeteroGraphPtr hg_new = HeteroGraph::PinMemory(hg.sptr());\n      *rv = HeteroGraphRef(hg_new);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroPinMemory_\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      auto hgindex = 
std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());\n      hgindex->PinMemory_();\n      *rv = hg;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroUnpinMemory_\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());\n      hgindex->UnpinMemory_();\n      *rv = hg;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroRecordStream\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      DGLStreamHandle stream = args[1];\n      auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());\n      hgindex->RecordStream(stream);\n      *rv = hg;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroCopyToSharedMem\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      std::string name = args[1];\n      List<Value> ntypes = args[2];\n      List<Value> etypes = args[3];\n      List<Value> fmts = args[4];\n      auto ntypes_vec = ListValueToVector<std::string>(ntypes);\n      auto etypes_vec = ListValueToVector<std::string>(etypes);\n      std::set<std::string> fmts_set;\n      for (const auto& fmt : fmts) {\n        std::string fmt_data = fmt->data;\n        fmts_set.insert(fmt_data);\n      }\n      auto hg_share = HeteroGraph::CopyToSharedMem(\n          hg.sptr(), name, ntypes_vec, etypes_vec, fmts_set);\n      *rv = HeteroGraphRef(hg_share);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroCreateFromSharedMem\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      std::string name = args[0];\n      HeteroGraphPtr hg;\n      std::vector<std::string> ntypes;\n      std::vector<std::string> etypes;\n      std::tie(hg, ntypes, etypes) = HeteroGraph::CreateFromSharedMem(name);\n      List<Value> ntypes_list;\n      List<Value> etypes_list;\n      for (const auto& ntype : ntypes)\n        
ntypes_list.push_back(Value(MakeValue(ntype)));\n      for (const auto& etype : etypes)\n        etypes_list.push_back(Value(MakeValue(etype)));\n      List<ObjectRef> ret;\n      ret.push_back(HeteroGraphRef(hg));\n      ret.push_back(ntypes_list);\n      ret.push_back(etypes_list);\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroJointUnion\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef meta_graph = args[0];\n      List<HeteroGraphRef> component_graphs = args[1];\n      CHECK(component_graphs.size() > 1)\n          << \"Expect graph list to have at least two graphs\";\n      std::vector<HeteroGraphPtr> component_ptrs;\n      component_ptrs.reserve(component_graphs.size());\n      const int64_t bits = component_graphs[0]->NumBits();\n      const DGLContext ctx = component_graphs[0]->Context();\n      for (const auto& component : component_graphs) {\n        component_ptrs.push_back(component.sptr());\n        CHECK_EQ(component->NumBits(), bits)\n            << \"Expect graphs to joint union have the same index dtype(int\"\n            << bits << \"), but got int\" << component->NumBits();\n        CHECK_EQ(component->Context(), ctx)\n            << \"Expect graphs to joint union have the same context\" << ctx\n            << \"), but got \" << component->Context();\n      }\n\n      auto hgptr = JointUnionHeteroGraph(meta_graph.sptr(), component_ptrs);\n      *rv = HeteroGraphRef(hgptr);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroDisjointUnion_v2\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef meta_graph = args[0];\n      List<HeteroGraphRef> component_graphs = args[1];\n      CHECK(component_graphs.size() > 0)\n          << \"Expect graph list has at least one graph\";\n      std::vector<HeteroGraphPtr> component_ptrs;\n      component_ptrs.reserve(component_graphs.size());\n      const int64_t bits = component_graphs[0]->NumBits();\n      const DGLContext 
ctx = component_graphs[0]->Context();\n      for (const auto& component : component_graphs) {\n        component_ptrs.push_back(component.sptr());\n        CHECK_EQ(component->NumBits(), bits)\n            << \"Expect graphs to batch have the same index dtype(int\" << bits\n            << \"), but got int\" << component->NumBits();\n        CHECK_EQ(component->Context(), ctx)\n            << \"Expect graphs to batch have the same context\" << ctx\n            << \"), but got \" << component->Context();\n      }\n\n      auto hgptr = DisjointUnionHeteroGraph2(meta_graph.sptr(), component_ptrs);\n      *rv = HeteroGraphRef(hgptr);\n    });\n\nDGL_REGISTER_GLOBAL(\n    \"heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes_v2\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      const IdArray vertex_sizes = args[1];\n      const IdArray edge_sizes = args[2];\n      std::vector<HeteroGraphPtr> ret;\n      ret = DisjointPartitionHeteroBySizes2(\n          hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes);\n      List<HeteroGraphRef> ret_list;\n      for (HeteroGraphPtr hgptr : ret) {\n        ret_list.push_back(HeteroGraphRef(hgptr));\n      }\n      *rv = ret_list;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      const IdArray vertex_sizes = args[1];\n      const IdArray edge_sizes = args[2];\n      const int64_t bits = hg->NumBits();\n      std::vector<HeteroGraphPtr> ret;\n      ATEN_ID_BITS_SWITCH(bits, IdType, {\n        ret = DisjointPartitionHeteroBySizes<IdType>(\n            hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes);\n      });\n      List<HeteroGraphRef> ret_list;\n      for (HeteroGraphPtr hgptr : ret) {\n        ret_list.push_back(HeteroGraphRef(hgptr));\n      }\n      *rv = ret_list;\n    
});\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroSlice\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      const IdArray num_nodes_per_type = args[1];\n      const IdArray start_nid_per_type = args[2];\n      const IdArray num_edges_per_type = args[3];\n      const IdArray start_eid_per_type = args[4];\n      auto hgptr = SliceHeteroGraph(\n          hg->meta_graph(), hg.sptr(), num_nodes_per_type, start_nid_per_type,\n          num_edges_per_type, start_eid_per_type);\n      *rv = HeteroGraphRef(hgptr);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroGetCreatedFormats\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      List<Value> format_list;\n      dgl_format_code_t code = hg->GetRelationGraph(0)->GetCreatedFormats();\n      for (auto format : CodeToSparseFormats(code)) {\n        format_list.push_back(Value(MakeValue(ToStringSparseFormat(format))));\n      }\n      *rv = format_list;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroGetAllowedFormats\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      List<Value> format_list;\n      dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats();\n      for (auto format : CodeToSparseFormats(code)) {\n        format_list.push_back(Value(MakeValue(ToStringSparseFormat(format))));\n      }\n      *rv = format_list;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroCreateFormat\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats();\n      auto get_format_f = [&](size_t etype_b, size_t etype_e) {\n        for (auto etype = etype_b; etype < etype_e; ++etype) {\n          auto bg =\n              std::dynamic_pointer_cast<UnitGraph>(hg->GetRelationGraph(etype));\n          for (auto format : 
CodeToSparseFormats(code)) bg->GetFormat(format);\n        }\n      };\n\n#if !(defined(DGL_USE_CUDA))\n      runtime::parallel_for(0, hg->NumEdgeTypes(), get_format_f);\n#else\n      get_format_f(0, hg->NumEdgeTypes());\n#endif\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroGetFormatGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      List<Value> formats = args[1];\n      std::vector<SparseFormat> formats_vec;\n      for (Value val : formats) {\n        std::string fmt = val->data;\n        formats_vec.push_back(ParseSparseFormat(fmt));\n      }\n      auto hgptr = hg->GetGraphInFormat(SparseFormatsToCode(formats_vec));\n      *rv = HeteroGraphRef(hgptr);\n    });\n\nDGL_REGISTER_GLOBAL(\"subgraph._CAPI_DGLInSubgraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      const auto& nodes = ListValueToVector<IdArray>(args[1]);\n      bool relabel_nodes = args[2];\n      std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph);\n      *ret = InEdgeGraph(hg.sptr(), nodes, relabel_nodes);\n      *rv = HeteroGraphRef(ret);\n    });\n\nDGL_REGISTER_GLOBAL(\"subgraph._CAPI_DGLOutSubgraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      const auto& nodes = ListValueToVector<IdArray>(args[1]);\n      bool relabel_nodes = args[2];\n      std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph);\n      *ret = OutEdgeGraph(hg.sptr(), nodes, relabel_nodes);\n      *rv = HeteroGraphRef(ret);\n    });\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLAsImmutableGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      *rv = GraphRef(hg->AsImmutableGraph());\n    });\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLHeteroSortOutEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      NDArray tag = args[1];\n      int64_t num_tag = args[2];\n\n   
   CHECK_EQ(hg->Context().device_type, kDGLCPU)\n          << \"Only support sorting by tag on cpu\";\n      CHECK(aten::IsValidIdArray(tag));\n      CHECK_EQ(tag->ctx.device_type, kDGLCPU)\n          << \"Only support sorting by tag on cpu\";\n\n      const auto csr = hg->GetCSRMatrix(0);\n\n      NDArray tag_pos = aten::NullArray();\n      aten::CSRMatrix output;\n      std::tie(output, tag_pos) = aten::CSRSortByTag(csr, tag, num_tag);\n      HeteroGraphPtr output_hg =\n          CreateFromCSR(hg->NumVertexTypes(), output, ALL_CODE);\n      List<ObjectRef> ret;\n      ret.push_back(HeteroGraphRef(output_hg));\n      ret.push_back(Value(MakeValue(tag_pos)));\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLHeteroSortInEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      NDArray tag = args[1];\n      int64_t num_tag = args[2];\n\n      CHECK_EQ(hg->Context().device_type, kDGLCPU)\n          << \"Only support sorting by tag on cpu\";\n      CHECK(aten::IsValidIdArray(tag));\n      CHECK_EQ(tag->ctx.device_type, kDGLCPU)\n          << \"Only support sorting by tag on cpu\";\n\n      const auto csc = hg->GetCSCMatrix(0);\n\n      NDArray tag_pos = aten::NullArray();\n      aten::CSRMatrix output;\n      std::tie(output, tag_pos) = aten::CSRSortByTag(csc, tag, num_tag);\n\n      HeteroGraphPtr output_hg =\n          CreateFromCSC(hg->NumVertexTypes(), output, ALL_CODE);\n      List<ObjectRef> ret;\n      ret.push_back(HeteroGraphRef(output_hg));\n      ret.push_back(Value(MakeValue(tag_pos)));\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph._CAPI_DGLFindSrcDstNtypes\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef metagraph = args[0];\n      std::unordered_set<uint64_t> dst_set;\n      std::unordered_set<uint64_t> src_set;\n\n      for (uint64_t eid = 0; eid < metagraph->NumEdges(); ++eid) {\n        auto edge = metagraph->FindEdge(eid);\n        auto src = 
edge.first;\n        auto dst = edge.second;\n        dst_set.insert(dst);\n        src_set.insert(src);\n      }\n\n      List<Value> srclist, dstlist;\n      List<List<Value>> ret_list;\n      for (uint64_t nid = 0; nid < metagraph->NumVertices(); ++nid) {\n        auto is_dst = dst_set.count(nid);\n        auto is_src = src_set.count(nid);\n        if (is_dst && is_src)\n          return;\n        else if (is_dst)\n          dstlist.push_back(Value(MakeValue(static_cast<int64_t>(nid))));\n        else\n          // If a node type is isolated, put it in srctype as defined in the\n          // Python docstring.\n          srclist.push_back(Value(MakeValue(static_cast<int64_t>(nid))));\n      }\n      ret_list.push_back(srclist);\n      ret_list.push_back(dstlist);\n      *rv = ret_list;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroReverse\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      CHECK_GT(hg->NumEdgeTypes(), 0);\n      auto g = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());\n      std::vector<HeteroGraphPtr> rev_ugs;\n      const auto& ugs = g->relation_graphs();\n      rev_ugs.resize(ugs.size());\n\n      for (size_t i = 0; i < ugs.size(); ++i) {\n        const auto& rev_ug = ugs[i]->Reverse();\n        rev_ugs[i] = rev_ug;\n      }\n      // node types are not changed\n      const auto& num_nodes = g->NumVerticesPerType();\n      const auto& meta_edges = hg->meta_graph()->Edges(\"eid\");\n      // reverse the metagraph\n      const auto& rev_meta = ImmutableGraph::CreateFromCOO(\n          hg->meta_graph()->NumVertices(), meta_edges.dst, meta_edges.src);\n      *rv = CreateHeteroGraph(rev_meta, rev_ugs, num_nodes);\n    });\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/immutable_graph.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/immutable_graph.cc\n * @brief DGL immutable graph index implementation\n */\n\n#include <dgl/base_heterograph.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/smart_ptr_serializer.h>\n#include <dmlc/io.h>\n#include <dmlc/type_traits.h>\n#include <string.h>\n\n#include <bitset>\n#include <numeric>\n#include <tuple>\n\n#include \"../c_api_common.h\"\n#include \"heterograph.h\"\n#include \"unit_graph.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace {\ninline std::string GetSharedMemName(\n    const std::string &name, const std::string &edge_dir) {\n  return name + \"_\" + edge_dir;\n}\n\n/**\n * The metadata of a graph index that are needed for shared-memory graph.\n */\nstruct GraphIndexMetadata {\n  int64_t num_nodes;\n  int64_t num_edges;\n  bool has_in_csr;\n  bool has_out_csr;\n  bool has_coo;\n};\n\n/**\n * Serialize the metadata of a graph index and place it in a shared-memory\n * tensor. 
In this way, another process can reconstruct a GraphIndex from a\n * shared-memory tensor.\n */\nNDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) {\n#ifndef _WIN32\n  GraphIndexMetadata meta;\n  meta.num_nodes = gidx->NumVertices();\n  meta.num_edges = gidx->NumEdges();\n  meta.has_in_csr = gidx->HasInCSR();\n  meta.has_out_csr = gidx->HasOutCSR();\n  meta.has_coo = false;\n\n  NDArray meta_arr = NDArray::EmptyShared(\n      name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1}, DGLContext{kDGLCPU, 0},\n      true);\n  memcpy(meta_arr->data, &meta, sizeof(meta));\n  return meta_arr;\n#else\n  LOG(FATAL) << \"CSR graph doesn't support shared memory in Windows yet\";\n  return NDArray();\n#endif  // _WIN32\n}\n\n/**\n * Deserialize the metadata of a graph index.\n */\nGraphIndexMetadata DeserializeMetadata(const std::string &name) {\n  GraphIndexMetadata meta;\n#ifndef _WIN32\n  NDArray meta_arr = NDArray::EmptyShared(\n      name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1}, DGLContext{kDGLCPU, 0},\n      false);\n  memcpy(&meta, meta_arr->data, sizeof(meta));\n#else\n  LOG(FATAL) << \"CSR graph doesn't support shared memory in Windows yet\";\n#endif  // _WIN32\n  return meta;\n}\n\nstd::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(\n    const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges,\n    bool is_create) {\n#ifndef _WIN32\n  const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);\n\n  IdArray sm_array = IdArray::EmptyShared(\n      shared_mem_name, {file_size}, DGLDataType{kDGLInt, 8, 1},\n      DGLContext{kDGLCPU, 0}, is_create);\n  // Create views from the shared memory array. 
Note that we don't need to save\n  //   the sm_array because the refcount is maintained by the view arrays.\n  IdArray indptr =\n      sm_array.CreateView({num_verts + 1}, DGLDataType{kDGLInt, 64, 1});\n  IdArray indices = sm_array.CreateView(\n      {num_edges}, DGLDataType{kDGLInt, 64, 1},\n      (num_verts + 1) * sizeof(dgl_id_t));\n  IdArray edge_ids = sm_array.CreateView(\n      {num_edges}, DGLDataType{kDGLInt, 64, 1},\n      (num_verts + 1 + num_edges) * sizeof(dgl_id_t));\n  return std::make_tuple(indptr, indices, edge_ids);\n#else\n  LOG(FATAL) << \"CSR graph doesn't support shared memory in Windows yet\";\n  return {};\n#endif  // _WIN32\n}\n}  // namespace\n\n//////////////////////////////////////////////////////////\n//\n// CSR graph implementation\n//\n//////////////////////////////////////////////////////////\n\nCSR::CSR(int64_t num_vertices, int64_t num_edges) {\n  CHECK(!(num_vertices == 0 && num_edges != 0));\n  adj_ = aten::CSRMatrix{\n      num_vertices, num_vertices, aten::NewIdArray(num_vertices + 1),\n      aten::NewIdArray(num_edges), aten::NewIdArray(num_edges)};\n  adj_.sorted = false;\n}\n\nCSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {\n  CHECK(aten::IsValidIdArray(indptr));\n  CHECK(aten::IsValidIdArray(indices));\n  CHECK(aten::IsValidIdArray(edge_ids));\n  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);\n  const int64_t N = indptr->shape[0] - 1;\n  adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};\n  adj_.sorted = false;\n}\n\nCSR::CSR(\n    IdArray indptr, IdArray indices, IdArray edge_ids,\n    const std::string &shared_mem_name)\n    : shared_mem_name_(shared_mem_name) {\n  CHECK(aten::IsValidIdArray(indptr));\n  CHECK(aten::IsValidIdArray(indices));\n  CHECK(aten::IsValidIdArray(edge_ids));\n  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);\n  const int64_t num_verts = indptr->shape[0] - 1;\n  const int64_t num_edges = indices->shape[0];\n  adj_.num_rows = num_verts;\n  adj_.num_cols = num_verts;\n  
std::tie(adj_.indptr, adj_.indices, adj_.data) =\n      MapFromSharedMemory(shared_mem_name, num_verts, num_edges, true);\n  // copy the given data into the shared memory arrays\n  adj_.indptr.CopyFrom(indptr);\n  adj_.indices.CopyFrom(indices);\n  adj_.data.CopyFrom(edge_ids);\n  adj_.sorted = false;\n}\n\nCSR::CSR(\n    const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges)\n    : shared_mem_name_(shared_mem_name) {\n  CHECK(!(num_verts == 0 && num_edges != 0));\n  adj_.num_rows = num_verts;\n  adj_.num_cols = num_verts;\n  std::tie(adj_.indptr, adj_.indices, adj_.data) =\n      MapFromSharedMemory(shared_mem_name, num_verts, num_edges, false);\n  adj_.sorted = false;\n}\n\nbool CSR::IsMultigraph() const { return aten::CSRHasDuplicate(adj_); }\n\nEdgeArray CSR::OutEdges(dgl_id_t vid) const {\n  CHECK(HasVertex(vid)) << \"invalid vertex: \" << vid;\n  IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);\n  IdArray ret_eid = aten::CSRGetRowData(adj_, vid);\n  IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);\n  return EdgeArray{ret_src, ret_dst, ret_eid};\n}\n\nEdgeArray CSR::OutEdges(IdArray vids) const {\n  CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n  auto csrsubmat = aten::CSRSliceRows(adj_, vids);\n  auto coosubmat = aten::CSRToCOO(csrsubmat, false);\n  // Note that the row id in the csr submat is relabled, so\n  // we need to recover it using an index select.\n  auto row = aten::IndexSelect(vids, coosubmat.row);\n  return EdgeArray{row, coosubmat.col, coosubmat.data};\n}\n\nDegreeArray CSR::OutDegrees(IdArray vids) const {\n  CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n  return aten::CSRGetRowNNZ(adj_, vids);\n}\n\nbool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {\n  CHECK(HasVertex(src)) << \"Invalid vertex id: \" << src;\n  CHECK(HasVertex(dst)) << \"Invalid vertex id: \" << dst;\n  return aten::CSRIsNonZero(adj_, src, dst);\n}\n\nBoolArray 
CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {\n  CHECK(aten::IsValidIdArray(src_ids)) << \"Invalid vertex id array.\";\n  CHECK(aten::IsValidIdArray(dst_ids)) << \"Invalid vertex id array.\";\n  return aten::CSRIsNonZero(adj_, src_ids, dst_ids);\n}\n\nIdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const {\n  CHECK(HasVertex(vid)) << \"invalid vertex: \" << vid;\n  CHECK(radius == 1) << \"invalid radius: \" << radius;\n  return aten::CSRGetRowColumnIndices(adj_, vid);\n}\n\nIdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {\n  CHECK(HasVertex(src)) << \"invalid vertex: \" << src;\n  CHECK(HasVertex(dst)) << \"invalid vertex: \" << dst;\n  return aten::CSRGetAllData(adj_, src, dst);\n}\n\nEdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {\n  const auto &arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);\n  return EdgeArray{arrs[0], arrs[1], arrs[2]};\n}\n\nEdgeArray CSR::Edges(const std::string &order) const {\n  CHECK(order.empty() || order == std::string(\"srcdst\"))\n      << \"CSR only support Edges of order \\\"srcdst\\\",\"\n      << \" but got \\\"\" << order << \"\\\".\";\n  const auto &coo = aten::CSRToCOO(adj_, false);\n  return EdgeArray{coo.row, coo.col, coo.data};\n}\n\nSubgraph CSR::VertexSubgraph(IdArray vids) const {\n  CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n  const auto &submat = aten::CSRSliceMatrix(adj_, vids, vids);\n  IdArray sub_eids =\n      aten::Range(0, submat.data->shape[0], NumBits(), Context());\n  CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids));\n  subcsr->adj_.sorted = this->adj_.sorted;\n  Subgraph subg;\n  subg.graph = subcsr;\n  subg.induced_vertices = vids;\n  subg.induced_edges = submat.data;\n  return subg;\n}\n\nCSRPtr CSR::Transpose() const {\n  const auto &trans = aten::CSRTranspose(adj_);\n  return CSRPtr(new CSR(trans.indptr, trans.indices, trans.data));\n}\n\nCOOPtr CSR::ToCOO() const {\n  const auto &coo = aten::CSRToCOO(adj_, 
true);\n  return COOPtr(new COO(NumVertices(), coo.row, coo.col));\n}\n\nCSR CSR::CopyTo(const DGLContext &ctx) const {\n  if (Context() == ctx) {\n    return *this;\n  } else {\n    CSR ret(\n        adj_.indptr.CopyTo(ctx), adj_.indices.CopyTo(ctx),\n        adj_.data.CopyTo(ctx));\n    return ret;\n  }\n}\n\nCSR CSR::CopyToSharedMem(const std::string &name) const {\n  if (IsSharedMem()) {\n    CHECK(name == shared_mem_name_);\n    return *this;\n  } else {\n    // TODO(zhengda) we need to set sorted_ properly.\n    return CSR(adj_.indptr, adj_.indices, adj_.data, name);\n  }\n}\n\nCSR CSR::AsNumBits(uint8_t bits) const {\n  if (NumBits() == bits) {\n    return *this;\n  } else {\n    CSR ret(\n        aten::AsNumBits(adj_.indptr, bits), aten::AsNumBits(adj_.indices, bits),\n        aten::AsNumBits(adj_.data, bits));\n    return ret;\n  }\n}\n\nDGLIdIters CSR::SuccVec(dgl_id_t vid) const {\n  // TODO(minjie): This still assumes the data type and device context\n  //   of this graph. Should fix later.\n  const dgl_id_t *indptr_data = static_cast<dgl_id_t *>(adj_.indptr->data);\n  const dgl_id_t *indices_data = static_cast<dgl_id_t *>(adj_.indices->data);\n  const dgl_id_t start = indptr_data[vid];\n  const dgl_id_t end = indptr_data[vid + 1];\n  return DGLIdIters(indices_data + start, indices_data + end);\n}\n\nDGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {\n  // TODO(minjie): This still assumes the data type and device context\n  //   of this graph. 
Should fix later.\n  const dgl_id_t *indptr_data = static_cast<dgl_id_t *>(adj_.indptr->data);\n  const dgl_id_t *eid_data = static_cast<dgl_id_t *>(adj_.data->data);\n  const dgl_id_t start = indptr_data[vid];\n  const dgl_id_t end = indptr_data[vid + 1];\n  return DGLIdIters(eid_data + start, eid_data + end);\n}\n\nbool CSR::Load(dmlc::Stream *fs) {\n  fs->Read(const_cast<dgl::aten::CSRMatrix *>(&adj_));\n  return true;\n}\n\nvoid CSR::Save(dmlc::Stream *fs) const { fs->Write(adj_); }\n\n//////////////////////////////////////////////////////////\n//\n// COO graph implementation\n//\n//////////////////////////////////////////////////////////\nCOO::COO(\n    int64_t num_vertices, IdArray src, IdArray dst, bool row_sorted,\n    bool col_sorted) {\n  CHECK(aten::IsValidIdArray(src));\n  CHECK(aten::IsValidIdArray(dst));\n  CHECK_EQ(src->shape[0], dst->shape[0]);\n  adj_ = aten::COOMatrix{num_vertices,      num_vertices, src,       dst,\n                         aten::NullArray(), row_sorted,   col_sorted};\n}\n\nbool COO::IsMultigraph() const { return aten::COOHasDuplicate(adj_); }\n\nstd::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {\n  CHECK(eid < NumEdges()) << \"Invalid edge id: \" << eid;\n  const dgl_id_t src = aten::IndexSelect<dgl_id_t>(adj_.row, eid);\n  const dgl_id_t dst = aten::IndexSelect<dgl_id_t>(adj_.col, eid);\n  return std::pair<dgl_id_t, dgl_id_t>(src, dst);\n}\n\nEdgeArray COO::FindEdges(IdArray eids) const {\n  CHECK(aten::IsValidIdArray(eids)) << \"Invalid edge id array\";\n  BUG_IF_FAIL(aten::IsNullArray(adj_.data))\n      << \"FindEdges requires the internal COO matrix not having EIDs.\";\n  return EdgeArray{\n      aten::IndexSelect(adj_.row, eids), aten::IndexSelect(adj_.col, eids),\n      eids};\n}\n\nEdgeArray COO::Edges(const std::string &order) const {\n  CHECK(order.empty() || order == std::string(\"eid\"))\n      << \"COO only support Edges of order \\\"eid\\\", but got \\\"\" << order\n      << \"\\\".\";\n  IdArray 
rst_eid = aten::Range(0, NumEdges(), NumBits(), Context());\n  return EdgeArray{adj_.row, adj_.col, rst_eid};\n}\n\nSubgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {\n  CHECK(aten::IsValidIdArray(eids)) << \"Invalid edge id array.\";\n  COOPtr subcoo;\n  IdArray induced_nodes;\n  if (!preserve_nodes) {\n    IdArray new_src = aten::IndexSelect(adj_.row, eids);\n    IdArray new_dst = aten::IndexSelect(adj_.col, eids);\n    induced_nodes = aten::Relabel_({new_src, new_dst});\n    const auto new_nnodes = induced_nodes->shape[0];\n    subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst));\n  } else {\n    IdArray new_src = aten::IndexSelect(adj_.row, eids);\n    IdArray new_dst = aten::IndexSelect(adj_.col, eids);\n    induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context());\n    subcoo = COOPtr(new COO(NumVertices(), new_src, new_dst));\n  }\n  Subgraph subg;\n  subg.graph = subcoo;\n  subg.induced_vertices = induced_nodes;\n  subg.induced_edges = eids;\n  return subg;\n}\n\nCSRPtr COO::ToCSR() const {\n  const auto &csr = aten::COOToCSR(adj_);\n  return CSRPtr(new CSR(csr.indptr, csr.indices, csr.data));\n}\n\nCOO COO::CopyTo(const DGLContext &ctx) const {\n  if (Context() == ctx) {\n    return *this;\n  } else {\n    COO ret(NumVertices(), adj_.row.CopyTo(ctx), adj_.col.CopyTo(ctx));\n    return ret;\n  }\n}\n\nCOO COO::CopyToSharedMem(const std::string &name) const {\n  LOG(FATAL) << \"COO doesn't supprt shared memory yet\";\n  return COO();\n}\n\nCOO COO::AsNumBits(uint8_t bits) const {\n  if (NumBits() == bits) {\n    return *this;\n  } else {\n    COO ret(\n        NumVertices(), aten::AsNumBits(adj_.row, bits),\n        aten::AsNumBits(adj_.col, bits));\n    return ret;\n  }\n}\n\n//////////////////////////////////////////////////////////\n//\n// immutable graph implementation\n//\n//////////////////////////////////////////////////////////\n\nBoolArray ImmutableGraph::HasVertices(IdArray vids) const {\n  
CHECK(aten::IsValidIdArray(vids)) << \"Invalid id array input\";\n  return aten::LT(vids, NumVertices());\n}\n\nCSRPtr ImmutableGraph::GetInCSR() const {\n  if (!in_csr_) {\n    if (out_csr_) {\n      const_cast<ImmutableGraph *>(this)->in_csr_ = out_csr_->Transpose();\n      if (out_csr_->IsSharedMem())\n        LOG(WARNING)\n            << \"We just construct an in-CSR from a shared-memory out CSR. \"\n            << \"It may dramatically increase memory consumption.\";\n    } else {\n      CHECK(coo_) << \"None of CSR, COO exist\";\n      const_cast<ImmutableGraph *>(this)->in_csr_ = coo_->Transpose()->ToCSR();\n    }\n  }\n  return in_csr_;\n}\n\n/** @brief Return out csr. If not exist, transpose the other one.*/\nCSRPtr ImmutableGraph::GetOutCSR() const {\n  if (!out_csr_) {\n    if (in_csr_) {\n      const_cast<ImmutableGraph *>(this)->out_csr_ = in_csr_->Transpose();\n      if (in_csr_->IsSharedMem())\n        LOG(WARNING)\n            << \"We just construct an out-CSR from a shared-memory in CSR. \"\n            << \"It may dramatically increase memory consumption.\";\n    } else {\n      CHECK(coo_) << \"None of CSR, COO exist\";\n      const_cast<ImmutableGraph *>(this)->out_csr_ = coo_->ToCSR();\n    }\n  }\n  return out_csr_;\n}\n\n/** @brief Return coo. 
If not exist, create from csr.*/\nCOOPtr ImmutableGraph::GetCOO() const {\n  if (!coo_) {\n    if (in_csr_) {\n      const_cast<ImmutableGraph *>(this)->coo_ = in_csr_->ToCOO()->Transpose();\n    } else {\n      CHECK(out_csr_) << \"Both CSR are missing.\";\n      const_cast<ImmutableGraph *>(this)->coo_ = out_csr_->ToCOO();\n    }\n  }\n  return coo_;\n}\n\nEdgeArray ImmutableGraph::Edges(const std::string &order) const {\n  if (order.empty()) {\n    // arbitrary order\n    if (in_csr_) {\n      // transpose\n      const auto &edges = in_csr_->Edges(order);\n      return EdgeArray{edges.dst, edges.src, edges.id};\n    } else {\n      return AnyGraph()->Edges(order);\n    }\n  } else if (order == std::string(\"srcdst\")) {\n    // TODO(minjie): CSR only guarantees \"src\" to be sorted.\n    //   Maybe we should relax this requirement?\n    return GetOutCSR()->Edges(order);\n  } else if (order == std::string(\"eid\")) {\n    return GetCOO()->Edges(order);\n  } else {\n    LOG(FATAL) << \"Unsupported order request: \" << order;\n  }\n  return {};\n}\n\nSubgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {\n  // We prefer to generate a subgraph from out-csr.\n  auto sg = GetOutCSR()->VertexSubgraph(vids);\n  CSRPtr subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);\n  sg.graph = GraphPtr(new ImmutableGraph(subcsr));\n  return sg;\n}\n\nSubgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {\n  auto sg = GetCOO()->EdgeSubgraph(eids, preserve_nodes);\n  COOPtr subcoo = std::dynamic_pointer_cast<COO>(sg.graph);\n  sg.graph = GraphPtr(new ImmutableGraph(subcoo));\n  return sg;\n}\n\nstd::vector<IdArray> ImmutableGraph::GetAdj(\n    bool transpose, const std::string &fmt) const {\n  // TODO(minjie): Our current semantics of adjacency matrix is row for dst\n  // nodes and col for\n  //   src nodes. Therefore, we need to flip the transpose flag. For example,\n  //   transpose=False is equal to in edge CSR. 
We have this behavior because\n  //   previously we use framework's SPMM and we don't cache reverse adj. This\n  //   is not intuitive and also not consistent with networkx's\n  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should\n  //   change the behavior and make row for src and col for dst.\n  if (fmt == std::string(\"csr\")) {\n    return transpose ? GetOutCSR()->GetAdj(false, \"csr\")\n                     : GetInCSR()->GetAdj(false, \"csr\");\n  } else if (fmt == std::string(\"coo\")) {\n    return GetCOO()->GetAdj(!transpose, fmt);\n  } else {\n    LOG(FATAL) << \"unsupported adjacency matrix format: \" << fmt;\n    return {};\n  }\n}\n\nImmutableGraphPtr ImmutableGraph::CreateFromCSR(\n    IdArray indptr, IdArray indices, IdArray edge_ids,\n    const std::string &edge_dir) {\n  CSRPtr csr(new CSR(indptr, indices, edge_ids));\n  if (edge_dir == \"in\") {\n    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));\n  } else if (edge_dir == \"out\") {\n    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));\n  } else {\n    LOG(FATAL) << \"Unknown edge direction: \" << edge_dir;\n    return ImmutableGraphPtr();\n  }\n}\n\nImmutableGraphPtr ImmutableGraph::CreateFromCSR(const std::string &name) {\n  // If the shared memory graph index doesn't exist, we return null directly.\n#ifndef _WIN32\n  if (!SharedMemory::Exist(GetSharedMemName(name, \"meta\"))) {\n    return nullptr;\n  }\n#endif  // _WIN32\n  GraphIndexMetadata meta = DeserializeMetadata(GetSharedMemName(name, \"meta\"));\n  CSRPtr in_csr, out_csr;\n  if (meta.has_in_csr) {\n    in_csr = CSRPtr(\n        new CSR(GetSharedMemName(name, \"in\"), meta.num_nodes, meta.num_edges));\n  }\n  if (meta.has_out_csr) {\n    out_csr = CSRPtr(\n        new CSR(GetSharedMemName(name, \"out\"), meta.num_nodes, meta.num_edges));\n  }\n  return ImmutableGraphPtr(new ImmutableGraph(in_csr, out_csr, name));\n}\n\nImmutableGraphPtr ImmutableGraph::CreateFromCOO(\n    int64_t 
num_vertices, IdArray src, IdArray dst, bool row_sorted,\n    bool col_sorted) {\n  COOPtr coo(new COO(num_vertices, src, dst, row_sorted, col_sorted));\n  return std::make_shared<ImmutableGraph>(coo);\n}\n\nImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {\n  ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);\n  if (ig) {\n    return ig;\n  } else {\n    const auto &adj = graph->GetAdj(true, \"csr\");\n    CSRPtr csr(new CSR(adj[0], adj[1], adj[2]));\n    return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], \"out\");\n  }\n}\n\nImmutableGraphPtr ImmutableGraph::CopyTo(\n    ImmutableGraphPtr g, const DGLContext &ctx) {\n  if (ctx == g->Context()) {\n    return g;\n  }\n  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,\n  //   we make sure that this graph (on CPU) has materialized CSR,\n  //   and then copy them to other context (usually GPU). This should\n  //   be fixed later.\n  CSRPtr new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyTo(ctx)));\n  CSRPtr new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyTo(ctx)));\n  return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));\n}\n\nImmutableGraphPtr ImmutableGraph::CopyToSharedMem(\n    ImmutableGraphPtr g, const std::string &name) {\n  CSRPtr new_incsr, new_outcsr;\n  std::string shared_mem_name = GetSharedMemName(name, \"in\");\n  new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));\n\n  shared_mem_name = GetSharedMemName(name, \"out\");\n  new_outcsr =\n      CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));\n\n  auto new_g =\n      ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));\n  new_g->serialized_shared_meta_ =\n      SerializeMetadata(new_g, GetSharedMemName(name, \"meta\"));\n  return new_g;\n}\n\nImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {\n  if (g->NumBits() == bits) {\n    return g;\n  } else {\n    // TODO(minjie): since we don't 
have int32 operations,\n    //   we make sure that this graph (on CPU) has materialized CSR,\n    //   and then copy them to other context (usually GPU). This should\n    //   be fixed later.\n    CSRPtr new_incsr = CSRPtr(new CSR(g->GetInCSR()->AsNumBits(bits)));\n    CSRPtr new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->AsNumBits(bits)));\n    return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));\n  }\n}\n\nImmutableGraphPtr ImmutableGraph::Reverse() const {\n  if (coo_) {\n    return ImmutableGraphPtr(\n        new ImmutableGraph(out_csr_, in_csr_, coo_->Transpose()));\n  } else {\n    return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));\n  }\n}\n\nconstexpr uint64_t kDGLSerialize_ImGraph = 0xDD3c5FFE20046ABF;\n\n/** @return Load HeteroGraph from stream, using OutCSR Matrix*/\nbool ImmutableGraph::Load(dmlc::Stream *fs) {\n  uint64_t magicNum;\n  aten::CSRMatrix out_csr_matrix;\n  CHECK(fs->Read(&magicNum)) << \"Invalid Magic Number\";\n  CHECK_EQ(magicNum, kDGLSerialize_ImGraph)\n      << \"Invalid ImmutableGraph Magic Number\";\n  CHECK(fs->Read(&out_csr_)) << \"Invalid csr matrix\";\n  return true;\n}\n\n/** @return Save HeteroGraph to stream, using OutCSR Matrix */\nvoid ImmutableGraph::Save(dmlc::Stream *fs) const {\n  fs->Write(kDGLSerialize_ImGraph);\n  fs->Write(GetOutCSR());\n}\n\nHeteroGraphPtr ImmutableGraph::AsHeteroGraph() const {\n  aten::CSRMatrix in_csr, out_csr;\n  aten::COOMatrix coo;\n\n  if (in_csr_) in_csr = GetInCSR()->ToCSRMatrix();\n  if (out_csr_) out_csr = GetOutCSR()->ToCSRMatrix();\n  if (coo_) coo = GetCOO()->ToCOOMatrix();\n\n  auto g = UnitGraph::CreateUnitGraphFrom(\n      1, in_csr, out_csr, coo, in_csr_ != nullptr, out_csr_ != nullptr,\n      coo_ != nullptr);\n  return HeteroGraphPtr(new HeteroGraph(g->meta_graph(), {g}));\n}\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLAsHeteroGraph\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphRef g = args[0];\n      ImmutableGraphPtr ig =\n          
std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());\n      CHECK(ig) << \"graph is not readonly\";\n      *rv = HeteroGraphRef(ig->AsHeteroGraph());\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLImmutableGraphCopyTo\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphRef g = args[0];\n      const int device_type = args[1];\n      const int device_id = args[2];\n      DGLContext ctx;\n      ctx.device_type = static_cast<DGLDeviceType>(device_type);\n      ctx.device_id = device_id;\n      ImmutableGraphPtr ig =\n          CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));\n      *rv = ImmutableGraph::CopyTo(ig, ctx);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLImmutableGraphCopyToSharedMem\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphRef g = args[0];\n      std::string name = args[1];\n      ImmutableGraphPtr ig =\n          CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));\n      *rv = ImmutableGraph::CopyToSharedMem(ig, name);\n    });\n\nDGL_REGISTER_GLOBAL(\"graph_index._CAPI_DGLImmutableGraphAsNumBits\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphRef g = args[0];\n      int bits = args[1];\n      ImmutableGraphPtr ig =\n          CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));\n      *rv = ImmutableGraph::AsNumBits(ig, bits);\n    });\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/metis_partition.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file graph/metis_partition.cc\n * @brief Call Metis partitioning\n */\n\n#include <dgl/graph_op.h>\n#include <dgl/packed_func_ext.h>\n#include <metis.h>\n\n#include \"../c_api_common.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\n\n#if !defined(_WIN32)\n\nIdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {\n  // The index type of Metis needs to be compatible with DGL index type.\n  CHECK_EQ(sizeof(idx_t), sizeof(dgl_id_t));\n  ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g);\n  CHECK(ig) << \"The input graph must be an immutable graph.\";\n  // This is a symmetric graph, so in-csr and out-csr are the same.\n  const auto mat = ig->GetInCSR()->ToCSRMatrix();\n\n  idx_t nvtxs = g->NumVertices();\n  idx_t ncon = 1;  // # balacing constraints.\n  idx_t *xadj = static_cast<idx_t *>(mat.indptr->data);\n  idx_t *adjncy = static_cast<idx_t *>(mat.indices->data);\n  idx_t nparts = k;\n  IdArray part_arr = aten::NewIdArray(nvtxs);\n  idx_t objval = 0;\n  idx_t *part = static_cast<idx_t *>(part_arr->data);\n\n  int64_t vwgt_len = vwgt_arr->shape[0];\n  CHECK_EQ(sizeof(idx_t), vwgt_arr->dtype.bits / 8)\n      << \"The vertex weight array doesn't have right type\";\n  CHECK(vwgt_len % g->NumVertices() == 0)\n      << \"The vertex weight array doesn't have right number of elements\";\n  idx_t *vwgt = NULL;\n  if (vwgt_len > 0) {\n    ncon = vwgt_len / g->NumVertices();\n    vwgt = static_cast<idx_t *>(vwgt_arr->data);\n  }\n\n  idx_t options[METIS_NOPTIONS];\n  METIS_SetDefaultOptions(options);\n  options[METIS_OPTION_ONDISK] = 1;\n  options[METIS_OPTION_NITER] = 1;\n  options[METIS_OPTION_NIPARTS] = 1;\n  options[METIS_OPTION_DROPEDGES] = 1;\n\n  if (obj_cut) {\n    options[METIS_OPTION_OBJTYPE] = METIS_OBJTYPE_CUT;\n  } else {\n    options[METIS_OPTION_OBJTYPE] = METIS_OBJTYPE_VOL;\n  }\n\n  int ret = METIS_PartGraphKway(\n      &nvtxs,  // The number of 
vertices\n      &ncon,   // The number of balancing constraints.\n      xadj,    // indptr\n      adjncy,  // indices\n      vwgt,    // the weights of the vertices\n      NULL,    // The size of the vertices for computing\n      // the total communication volume\n      NULL,     // The weights of the edges\n      &nparts,  // The number of partitions.\n      NULL,     // the desired weight for each partition and constraint\n      NULL,     // the allowed load imbalance tolerance\n      options,  // the array of options\n      &objval,  // the edge-cut or the total communication volume of\n      // the partitioning solution\n      part);\n\n  if (obj_cut) {\n    LOG(INFO) << \"Partition a graph with \" << g->NumVertices() << \" nodes and \"\n              << g->NumEdges() << \" edges into \" << k << \" parts and \"\n              << \"get \" << objval << \" edge cuts\";\n  } else {\n    LOG(INFO) << \"Partition a graph with \" << g->NumVertices() << \" nodes and \"\n              << g->NumEdges() << \" edges into \" << k << \" parts and \"\n              << \"the communication volume is \" << objval;\n  }\n\n  switch (ret) {\n    case METIS_OK:\n      return part_arr;\n    case METIS_ERROR_INPUT:\n      LOG(FATAL) << \"Error in Metis partitioning: input error\";\n    case METIS_ERROR_MEMORY:\n      LOG(FATAL) << \"Error in Metis partitioning: cannot allocate memory\";\n    default:\n      LOG(FATAL) << \"Error in Metis partitioning: other errors\";\n  }\n  // return an array of 0 elements to indicate the error.\n  return aten::NullArray();\n}\n\n#endif  // !defined(_WIN32)\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLMetisPartition\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphRef g = args[0];\n      int k = args[1];\n      NDArray vwgt = args[2];\n      bool obj_cut = args[3];\n#if !defined(_WIN32)\n      *rv = MetisPartition(g.sptr(), k, vwgt, obj_cut);\n#else\n      LOG(FATAL) << \"Metis partition does not support Windows.\";\n#endif  // 
!defined(_WIN32)\n    });\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/nodeflow.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/nodeflow.cc\n * @brief DGL NodeFlow related functions.\n */\n\n#include <dgl/immutable_graph.h>\n#include <dgl/nodeflow.h>\n#include <dgl/packed_func_ext.h>\n\n#include <string>\n\n#include \"../c_api_common.h\"\n\nusing dgl::runtime::DGLArgs;\nusing dgl::runtime::DGLArgValue;\nusing dgl::runtime::DGLRetValue;\nusing dgl::runtime::PackedFunc;\n\nnamespace dgl {\n\nstd::vector<IdArray> GetNodeFlowSlice(\n    const ImmutableGraph &graph, const std::string &fmt, size_t layer0_size,\n    size_t layer1_start, size_t layer1_end, bool remap) {\n  CHECK_GE(layer1_start, layer0_size);\n  if (fmt == std::string(\"csr\")) {\n    dgl_id_t first_vid = layer1_start - layer0_size;\n    auto csr = aten::CSRSliceRows(\n        graph.GetInCSR()->ToCSRMatrix(), layer1_start, layer1_end);\n    if (remap) {\n      dgl_id_t *eid_data = static_cast<dgl_id_t *>(csr.data->data);\n      const dgl_id_t first_eid = eid_data[0];\n      IdArray new_indices = aten::Sub(csr.indices, first_vid);\n      IdArray new_data = aten::Sub(csr.data, first_eid);\n      return {csr.indptr, new_indices, new_data};\n    } else {\n      return {csr.indptr, csr.indices, csr.data};\n    }\n  } else if (fmt == std::string(\"coo\")) {\n    auto csr = graph.GetInCSR()->ToCSRMatrix();\n    const dgl_id_t *indptr = static_cast<dgl_id_t *>(csr.indptr->data);\n    const dgl_id_t *indices = static_cast<dgl_id_t *>(csr.indices->data);\n    const dgl_id_t *edge_ids = static_cast<dgl_id_t *>(csr.data->data);\n    int64_t nnz = indptr[layer1_end] - indptr[layer1_start];\n    IdArray idx = aten::NewIdArray(2 * nnz);\n    IdArray eid = aten::NewIdArray(nnz);\n    int64_t *idx_data = static_cast<int64_t *>(idx->data);\n    dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);\n    size_t num_edges = 0;\n    for (size_t i = layer1_start; i < layer1_end; i++) {\n      for (dgl_id_t j = indptr[i]; j < indptr[i + 1]; j++) {\n        // These nodes are 
all in a layer. We need to remap them to the node id\n        // local to the layer.\n        idx_data[num_edges] = remap ? i - layer1_start : i;\n        num_edges++;\n      }\n    }\n    CHECK_EQ(num_edges, nnz);\n    if (remap) {\n      size_t edge_start = indptr[layer1_start];\n      dgl_id_t first_eid = edge_ids[edge_start];\n      dgl_id_t first_vid = layer1_start - layer0_size;\n      for (int64_t i = 0; i < nnz; i++) {\n        CHECK_GE(indices[edge_start + i], first_vid);\n        idx_data[nnz + i] = indices[edge_start + i] - first_vid;\n        eid_data[i] = edge_ids[edge_start + i] - first_eid;\n      }\n    } else {\n      std::copy(\n          indices + indptr[layer1_start], indices + indptr[layer1_end],\n          idx_data + nnz);\n      std::copy(\n          edge_ids + indptr[layer1_start], edge_ids + indptr[layer1_end],\n          eid_data);\n    }\n    return std::vector<IdArray>{idx, eid};\n  } else {\n    LOG(FATAL) << \"unsupported adjacency matrix format\";\n    return {};\n  }\n}\n\nDGL_REGISTER_GLOBAL(\"_deprecate.nodeflow._CAPI_NodeFlowGetBlockAdj\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphRef g = args[0];\n      std::string format = args[1];\n      int64_t layer0_size = args[2];\n      int64_t start = args[3];\n      int64_t end = args[4];\n      const bool remap = args[5];\n      auto ig =\n          CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));\n      auto res = GetNodeFlowSlice(*ig, format, layer0_size, start, end, remap);\n      *rv = ConvertNDArrayVectorToPackedFunc(res);\n    });\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/pickle.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file graph/pickle.cc\n * @brief Functions for pickle and unpickle a graph\n */\n#include <dgl/graph_serializer.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dmlc/memory_io.h>\n\n#include \"../c_api_common.h\"\n#include \"./heterograph.h\"\n#include \"unit_graph.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\n\nHeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {\n  HeteroPickleStates states;\n  states.version = 2;\n  dmlc::MemoryStringStream ofs(&states.meta);\n  dmlc::Stream *strm = &ofs;\n  strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));\n  strm->Write(graph->NumVerticesPerType());\n  strm->Write(graph->IsPinned());\n  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {\n    SparseFormat fmt = graph->SelectFormat(etype, ALL_CODE);\n    switch (fmt) {\n      case SparseFormat::kCOO: {\n        strm->Write(SparseFormat::kCOO);\n        const auto &coo = graph->GetCOOMatrix(etype);\n        strm->Write(coo.row_sorted);\n        strm->Write(coo.col_sorted);\n        states.arrays.push_back(coo.row);\n        states.arrays.push_back(coo.col);\n        break;\n      }\n      case SparseFormat::kCSR:\n      case SparseFormat::kCSC: {\n        strm->Write(SparseFormat::kCSR);\n        const auto &csr = graph->GetCSRMatrix(etype);\n        strm->Write(csr.sorted);\n        states.arrays.push_back(csr.indptr);\n        states.arrays.push_back(csr.indices);\n        states.arrays.push_back(csr.data);\n        break;\n      }\n      default:\n        LOG(FATAL) << \"Unsupported sparse format.\";\n    }\n  }\n  return states;\n}\n\nHeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) {\n  HeteroPickleStates states;\n  states.version = 2;\n  dmlc::MemoryStringStream ofs(&states.meta);\n  dmlc::Stream *strm = &ofs;\n  strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));\n  
strm->Write(graph->NumVerticesPerType());\n  strm->Write(graph->IsPinned());\n  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {\n    auto created_formats = graph->GetCreatedFormats();\n    auto allowed_formats = graph->GetAllowedFormats();\n    strm->Write(created_formats);\n    strm->Write(allowed_formats);\n    if (created_formats & COO_CODE) {\n      const auto &coo = graph->GetCOOMatrix(etype);\n      strm->Write(coo.row_sorted);\n      strm->Write(coo.col_sorted);\n      states.arrays.push_back(coo.row);\n      states.arrays.push_back(coo.col);\n    }\n    if (created_formats & CSR_CODE) {\n      const auto &csr = graph->GetCSRMatrix(etype);\n      strm->Write(csr.sorted);\n      states.arrays.push_back(csr.indptr);\n      states.arrays.push_back(csr.indices);\n      states.arrays.push_back(csr.data);\n    }\n    if (created_formats & CSC_CODE) {\n      const auto &csc = graph->GetCSCMatrix(etype);\n      strm->Write(csc.sorted);\n      states.arrays.push_back(csc.indptr);\n      states.arrays.push_back(csc.indices);\n      states.arrays.push_back(csc.data);\n    }\n  }\n  return states;\n}\n\nHeteroGraphPtr HeteroUnpickle(const HeteroPickleStates &states) {\n  char *buf = const_cast<char *>(states.meta.c_str());  // a readonly stream?\n  dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());\n  dmlc::Stream *strm = &ifs;\n  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();\n  CHECK(strm->Read(&meta_imgraph)) << \"Invalid meta graph\";\n  GraphPtr metagraph = meta_imgraph;\n  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());\n  std::vector<int64_t> num_nodes_per_type;\n  CHECK(strm->Read(&num_nodes_per_type)) << \"Invalid num_nodes_per_type\";\n  bool is_pinned = false;\n  if (states.version > 1) {\n    CHECK(strm->Read(&is_pinned)) << \"Invalid flag 'is_pinned'\";\n  }\n\n  auto array_itr = states.arrays.begin();\n  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {\n    const auto &pair = 
metagraph->FindEdge(etype);\n    const dgl_type_t srctype = pair.first;\n    const dgl_type_t dsttype = pair.second;\n    const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;\n    int64_t num_src = num_nodes_per_type[srctype];\n    int64_t num_dst = num_nodes_per_type[dsttype];\n    SparseFormat fmt;\n    CHECK(strm->Read(&fmt)) << \"Invalid SparseFormat\";\n    HeteroGraphPtr relgraph;\n    switch (fmt) {\n      case SparseFormat::kCOO: {\n        CHECK_GE(states.arrays.end() - array_itr, 2);\n        const auto &row = *(array_itr++);\n        const auto &col = *(array_itr++);\n        bool rsorted;\n        bool csorted;\n        CHECK(strm->Read(&rsorted)) << \"Invalid flag 'rsorted'\";\n        CHECK(strm->Read(&csorted)) << \"Invalid flag 'csorted'\";\n        auto coo = aten::COOMatrix(\n            num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);\n        // TODO(zihao) fix\n        relgraph = CreateFromCOO(num_vtypes, coo, ALL_CODE);\n        break;\n      }\n      case SparseFormat::kCSR: {\n        CHECK_GE(states.arrays.end() - array_itr, 3);\n        const auto &indptr = *(array_itr++);\n        const auto &indices = *(array_itr++);\n        const auto &edge_id = *(array_itr++);\n        bool sorted;\n        CHECK(strm->Read(&sorted)) << \"Invalid flag 'sorted'\";\n        auto csr =\n            aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);\n        // TODO(zihao) fix\n        relgraph = CreateFromCSR(num_vtypes, csr, ALL_CODE);\n        break;\n      }\n      case SparseFormat::kCSC:\n      default:\n        LOG(FATAL) << \"Unsupported sparse format.\";\n    }\n    relgraphs[etype] = relgraph;\n  }\n  auto graph = CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);\n  if (is_pinned) {\n    graph->PinMemory_();\n  }\n  return graph;\n}\n\n// For backward compatibility\nHeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates &states) {\n  const auto metagraph = states.metagraph;\n  const auto 
&num_nodes_per_type = states.num_nodes_per_type;\n  CHECK_EQ(states.adjs.size(), metagraph->NumEdges());\n  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());\n  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {\n    const auto &pair = metagraph->FindEdge(etype);\n    const dgl_type_t srctype = pair.first;\n    const dgl_type_t dsttype = pair.second;\n    const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;\n    const SparseFormat fmt =\n        static_cast<SparseFormat>(states.adjs[etype]->format);\n    switch (fmt) {\n      case SparseFormat::kCOO:\n        relgraphs[etype] = UnitGraph::CreateFromCOO(\n            num_vtypes, aten::COOMatrix(*states.adjs[etype]));\n        break;\n      case SparseFormat::kCSR:\n        relgraphs[etype] = UnitGraph::CreateFromCSR(\n            num_vtypes, aten::CSRMatrix(*states.adjs[etype]));\n        break;\n      case SparseFormat::kCSC:\n      default:\n        LOG(FATAL) << \"Unsupported sparse format.\";\n    }\n  }\n  return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);\n}\n\nHeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {\n  char *buf = const_cast<char *>(states.meta.c_str());  // a readonly stream?\n  dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());\n  dmlc::Stream *strm = &ifs;\n  auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();\n  CHECK(strm->Read(&meta_imgraph)) << \"Invalid meta graph\";\n  GraphPtr metagraph = meta_imgraph;\n  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());\n  std::vector<int64_t> num_nodes_per_type;\n  CHECK(strm->Read(&num_nodes_per_type)) << \"Invalid num_nodes_per_type\";\n  bool is_pinned = false;\n  if (states.version > 1) {\n    CHECK(strm->Read(&is_pinned)) << \"Invalid flag 'is_pinned'\";\n  }\n\n  auto array_itr = states.arrays.begin();\n  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {\n    const auto &pair = metagraph->FindEdge(etype);\n    const dgl_type_t 
srctype = pair.first;\n    const dgl_type_t dsttype = pair.second;\n    const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;\n    int64_t num_src = num_nodes_per_type[srctype];\n    int64_t num_dst = num_nodes_per_type[dsttype];\n\n    dgl_format_code_t created_formats, allowed_formats;\n    CHECK(strm->Read(&created_formats)) << \"Invalid code for created formats\";\n    CHECK(strm->Read(&allowed_formats)) << \"Invalid code for allowed formats\";\n    aten::COOMatrix coo;\n    aten::CSRMatrix csr;\n    aten::CSRMatrix csc;\n    bool has_coo = (created_formats & COO_CODE);\n    bool has_csr = (created_formats & CSR_CODE);\n    bool has_csc = (created_formats & CSC_CODE);\n\n    if (created_formats & COO_CODE) {\n      CHECK_GE(states.arrays.end() - array_itr, 2);\n      const auto &row = *(array_itr++);\n      const auto &col = *(array_itr++);\n      bool rsorted;\n      bool csorted;\n      CHECK(strm->Read(&rsorted)) << \"Invalid flag 'rsorted'\";\n      CHECK(strm->Read(&csorted)) << \"Invalid flag 'csorted'\";\n      coo = aten::COOMatrix(\n          num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);\n    }\n    if (created_formats & CSR_CODE) {\n      CHECK_GE(states.arrays.end() - array_itr, 3);\n      const auto &indptr = *(array_itr++);\n      const auto &indices = *(array_itr++);\n      const auto &edge_id = *(array_itr++);\n      bool sorted;\n      CHECK(strm->Read(&sorted)) << \"Invalid flag 'sorted'\";\n      csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);\n    }\n    if (created_formats & CSC_CODE) {\n      CHECK_GE(states.arrays.end() - array_itr, 3);\n      const auto &indptr = *(array_itr++);\n      const auto &indices = *(array_itr++);\n      const auto &edge_id = *(array_itr++);\n      bool sorted;\n      CHECK(strm->Read(&sorted)) << \"Invalid flag 'sorted'\";\n      csc = aten::CSRMatrix(num_dst, num_src, indptr, indices, edge_id, sorted);\n    }\n    relgraphs[etype] = 
UnitGraph::CreateUnitGraphFrom(\n        num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo, allowed_formats);\n  }\n  auto graph = CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);\n  if (is_pinned) {\n    graph->PinMemory_();\n  }\n  return graph;\n}\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroPickleStatesRef st = args[0];\n      *rv = st->version;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroPickleStatesRef st = args[0];\n      DGLByteArray buf;\n      buf.data = st->meta.c_str();\n      buf.size = st->meta.size();\n      *rv = buf;\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroPickleStatesGetArrays\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroPickleStatesRef st = args[0];\n      *rv = ConvertNDArrayVectorToPackedFunc(st->arrays);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroPickleStatesGetArraysNum\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroPickleStatesRef st = args[0];\n      *rv = static_cast<int64_t>(st->arrays.size());\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLCreateHeteroPickleStates\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      const int version = args[0];\n      std::string meta = args[1];\n      const List<Value> arrays = args[2];\n      std::shared_ptr<HeteroPickleStates> st(new HeteroPickleStates);\n      st->version = version == 0 ? 
1 : version;\n      st->meta = meta;\n      st->arrays.reserve(arrays.size());\n      for (const auto &ref : arrays) {\n        st->arrays.push_back(ref->data);\n      }\n      *rv = HeteroPickleStatesRef(st);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroPickle\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef ref = args[0];\n      std::shared_ptr<HeteroPickleStates> st(new HeteroPickleStates);\n      *st = HeteroPickle(ref.sptr());\n      *rv = HeteroPickleStatesRef(st);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroForkingPickle\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef ref = args[0];\n      std::shared_ptr<HeteroPickleStates> st(new HeteroPickleStates);\n      *st = HeteroForkingPickle(ref.sptr());\n      *rv = HeteroPickleStatesRef(st);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroUnpickle\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroPickleStatesRef ref = args[0];\n      HeteroGraphPtr graph;\n      switch (ref->version) {\n        case 0:\n          graph = HeteroUnpickleOld(*ref.sptr());\n          break;\n        case 1:\n        case 2:\n          graph = HeteroUnpickle(*ref.sptr());\n          break;\n        default:\n          LOG(FATAL) << \"Version can only be 0 or 1 or 2.\";\n      }\n      *rv = HeteroGraphRef(graph);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLHeteroForkingUnpickle\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroPickleStatesRef ref = args[0];\n      HeteroGraphPtr graph = HeteroForkingUnpickle(*ref.sptr());\n      *rv = HeteroGraphRef(graph);\n    });\n\nDGL_REGISTER_GLOBAL(\"heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphRef metagraph = args[0];\n      IdArray num_nodes_per_type = args[1];\n      List<SparseMatrixRef> adjs = args[2];\n      std::shared_ptr<HeteroPickleStates> 
st(new HeteroPickleStates);\n      st->version = 0;\n      st->metagraph = metagraph.sptr();\n      st->num_nodes_per_type = num_nodes_per_type.ToVector<int64_t>();\n      st->adjs.reserve(adjs.size());\n      for (const auto &ref : adjs) st->adjs.push_back(ref.sptr());\n      *rv = HeteroPickleStatesRef(st);\n    });\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampler.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/sampler.cc\n * @brief DGL sampler implementation\n */\n#include <dgl/array.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/random.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/parallel_for.h>\n#include <dgl/sampler.h>\n#include <dmlc/omp.h>\n\n#include <algorithm>\n#include <cmath>\n#include <cstdlib>\n#include <numeric>\n\n#include \"../c_api_common.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\n\nnamespace {\n/**\n * ArrayHeap is used to sample elements from vector\n */\ntemplate <typename ValueType>\nclass ArrayHeap {\n public:\n  explicit ArrayHeap(const std::vector<ValueType> &prob) {\n    vec_size_ = prob.size();\n    bit_len_ = ceil(log2(vec_size_));\n    limit_ = 1UL << bit_len_;\n    // allocate twice the size\n    heap_.resize(limit_ << 1, 0);\n    // allocate the leaves\n    for (size_t i = limit_; i < vec_size_ + limit_; ++i) {\n      heap_[i] = prob[i - limit_];\n    }\n    // iterate up the tree (this is O(m))\n    for (int i = bit_len_ - 1; i >= 0; --i) {\n      for (size_t j = (1UL << i); j < (1UL << (i + 1)); ++j) {\n        heap_[j] = heap_[j << 1] + heap_[(j << 1) + 1];\n      }\n    }\n  }\n  ~ArrayHeap() {}\n\n  /**\n   * Remove term from index (this costs O(log m) steps)\n   */\n  void Delete(size_t index) {\n    size_t i = index + limit_;\n    heap_[i] = 0;\n    i /= 2;\n    for (int j = bit_len_ - 1; j >= 0; --j) {\n      // Using heap_[i] = heap_[i] - w will loss some precision in float.\n      // Using addition to re-calculate the weight layer by layer.\n      heap_[i] = heap_[i << 1] + heap_[(i << 1) + 1];\n      i /= 2;\n    }\n  }\n\n  /**\n   * Add value w to index (this costs O(log m) steps)\n   */\n  void Add(size_t index, ValueType w) {\n    size_t i = index + limit_;\n    for (int j = bit_len_; j >= 0; --j) {\n      heap_[i] += w;\n      i = i >> 1;\n    }\n  }\n\n  /**\n   * Sample from 
arrayHeap\n   */\n  size_t Sample() {\n    // heap_ is empty\n    ValueType xi = heap_[1] * RandomEngine::ThreadLocal()->Uniform<float>();\n    size_t i = 1;\n    while (i < limit_) {\n      i = i << 1;\n      if (xi >= heap_[i]) {\n        xi -= heap_[i];\n        i += 1;\n      }\n    }\n    return i - limit_;\n  }\n\n  /**\n   * Sample a vector by given the size n\n   */\n  size_t SampleWithoutReplacement(size_t n, std::vector<size_t> *samples) {\n    // sample n elements\n    size_t i = 0;\n    for (; i < n; ++i) {\n      // heap is empty\n      if (heap_[1] == 0) {\n        break;\n      }\n      samples->at(i) = this->Sample();\n      this->Delete(samples->at(i));\n    }\n\n    return i;\n  }\n\n private:\n  size_t vec_size_;  // sample size\n  int bit_len_;      // bit size\n  size_t limit_;\n  std::vector<ValueType> heap_;\n};\n\n///////////////////////// Samplers //////////////////////////\nclass EdgeSamplerObject : public Object {\n public:\n  EdgeSamplerObject(\n      const GraphPtr gptr, IdArray seed_edges, const int64_t batch_size,\n      const int64_t num_workers, const bool replacement, const bool reset,\n      const std::string neg_mode, const int64_t neg_sample_size,\n      const int64_t chunk_size, const bool exclude_positive,\n      const bool check_false_neg, IdArray relations) {\n    gptr_ = gptr;\n    seed_edges_ = seed_edges;\n    relations_ = relations;\n\n    batch_size_ = batch_size;\n    num_workers_ = num_workers;\n    replacement_ = replacement;\n    reset_ = reset;\n    neg_mode_ = neg_mode;\n    neg_sample_size_ = neg_sample_size;\n    exclude_positive_ = exclude_positive;\n    check_false_neg_ = check_false_neg;\n    chunk_size_ = chunk_size;\n  }\n\n  ~EdgeSamplerObject() {}\n\n  virtual void Fetch(DGLRetValue *rv) = 0;\n  virtual void Reset() = 0;\n\n protected:\n  virtual void randomSample(\n      size_t set_size, size_t num, std::vector<size_t> *out) = 0;\n  virtual void randomSample(\n      size_t set_size, size_t num, const 
std::vector<size_t> &exclude,\n      std::vector<size_t> *out) = 0;\n\n  NegSubgraph genNegEdgeSubgraph(\n      const Subgraph &pos_subg, const std::string &neg_mode,\n      int64_t neg_sample_size, bool exclude_positive, bool check_false_neg);\n  NegSubgraph genChunkedNegEdgeSubgraph(\n      const Subgraph &pos_subg, const std::string &neg_mode,\n      int64_t neg_sample_size, bool exclude_positive, bool check_false_neg);\n\n  GraphPtr gptr_;\n  IdArray seed_edges_;\n  IdArray relations_;\n\n  int64_t batch_size_;\n  int64_t num_workers_;\n  bool replacement_;\n  int64_t reset_;\n  std::string neg_mode_;\n  int64_t neg_sample_size_;\n  bool exclude_positive_;\n  bool check_false_neg_;\n  int64_t chunk_size_;\n};\n\n/**\n * Uniformly sample integers from [0, set_size) without replacement.\n */\nvoid RandomSample(size_t set_size, size_t num, std::vector<size_t> *out) {\n  if (num < set_size) {\n    std::unordered_set<size_t> sampled_idxs;\n    while (sampled_idxs.size() < num) {\n      sampled_idxs.insert(RandomEngine::ThreadLocal()->RandInt(set_size));\n    }\n    out->insert(out->end(), sampled_idxs.begin(), sampled_idxs.end());\n  } else {\n    // If we need to sample all elements in the set, we don't need to\n    // generate random numbers.\n    for (size_t i = 0; i < set_size; i++) out->push_back(i);\n  }\n}\n\nvoid RandomSample(\n    size_t set_size, size_t num, const std::vector<size_t> &exclude,\n    std::vector<size_t> *out) {\n  std::unordered_map<size_t, int> sampled_idxs;\n  for (auto v : exclude) {\n    sampled_idxs.insert(std::pair<size_t, int>(v, 0));\n  }\n  if (num + exclude.size() < set_size) {\n    while (sampled_idxs.size() < num + exclude.size()) {\n      size_t rand = RandomEngine::ThreadLocal()->RandInt(set_size);\n      sampled_idxs.insert(std::pair<size_t, int>(rand, 1));\n    }\n    for (auto it = sampled_idxs.begin(); it != sampled_idxs.end(); it++) {\n      if (it->second) {\n        out->push_back(it->first);\n      }\n    }\n  } else 
{\n    // If we need to sample all elements in the set, we don't need to\n    // generate random numbers.\n    for (size_t i = 0; i < set_size; i++) {\n      // If the element doesn't exist in exclude.\n      if (sampled_idxs.find(i) == sampled_idxs.end()) {\n        out->push_back(i);\n      }\n    }\n  }\n}\n\n/**\n * For a sparse array whose non-zeros are represented by nz_idxs,\n * negate the sparse array and outputs the non-zeros in the negated array.\n */\nvoid NegateArray(\n    const std::vector<size_t> &nz_idxs, size_t arr_size,\n    std::vector<size_t> *out) {\n  // nz_idxs must have been sorted.\n  auto it = nz_idxs.begin();\n  size_t i = 0;\n  CHECK_GT(arr_size, nz_idxs.back());\n  for (; i < arr_size && it != nz_idxs.end(); i++) {\n    if (*it == i) {\n      it++;\n      continue;\n    }\n    out->push_back(i);\n  }\n  for (; i < arr_size; i++) {\n    out->push_back(i);\n  }\n}\n\n/**\n * Uniform sample vertices from a list of vertices.\n */\nvoid GetUniformSample(\n    const dgl_id_t *edge_id_list, const dgl_id_t *vid_list,\n    const size_t ver_len, const size_t max_num_neighbor,\n    std::vector<dgl_id_t> *out_ver, std::vector<dgl_id_t> *out_edge) {\n  // Copy vid_list to output\n  if (ver_len <= max_num_neighbor) {\n    out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len);\n    out_edge->insert(out_edge->end(), edge_id_list, edge_id_list + ver_len);\n    return;\n  }\n  // If we just sample a small number of elements from a large neighbor list.\n  std::vector<size_t> sorted_idxs;\n  if (ver_len > max_num_neighbor * 2) {\n    sorted_idxs.reserve(max_num_neighbor);\n    RandomSample(ver_len, max_num_neighbor, &sorted_idxs);\n    std::sort(sorted_idxs.begin(), sorted_idxs.end());\n  } else {\n    std::vector<size_t> negate;\n    negate.reserve(ver_len - max_num_neighbor);\n    RandomSample(ver_len, ver_len - max_num_neighbor, &negate);\n    std::sort(negate.begin(), negate.end());\n    NegateArray(negate, ver_len, &sorted_idxs);\n  }\n  // 
verify the result.\n  CHECK_EQ(sorted_idxs.size(), max_num_neighbor);\n  for (size_t i = 1; i < sorted_idxs.size(); i++) {\n    CHECK_GT(sorted_idxs[i], sorted_idxs[i - 1]);\n  }\n  for (auto idx : sorted_idxs) {\n    out_ver->push_back(vid_list[idx]);\n    out_edge->push_back(edge_id_list[idx]);\n  }\n}\n\n/**\n * Non-uniform sample via ArrayHeap\n *\n * @param probability Transition probability on the entire graph, indexed by\n * edge ID\n */\ntemplate <typename ValueType>\nvoid GetNonUniformSample(\n    const ValueType *probability, const dgl_id_t *edge_id_list,\n    const dgl_id_t *vid_list, const size_t ver_len,\n    const size_t max_num_neighbor, std::vector<dgl_id_t> *out_ver,\n    std::vector<dgl_id_t> *out_edge) {\n  // Copy vid_list to output\n  if (ver_len <= max_num_neighbor) {\n    out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len);\n    out_edge->insert(out_edge->end(), edge_id_list, edge_id_list + ver_len);\n    return;\n  }\n  // Make sample\n  std::vector<size_t> sp_index(max_num_neighbor);\n  std::vector<ValueType> sp_prob(ver_len);\n  for (size_t i = 0; i < ver_len; ++i) {\n    sp_prob[i] = probability[edge_id_list[i]];\n  }\n  ArrayHeap<ValueType> arrayHeap(sp_prob);\n  arrayHeap.SampleWithoutReplacement(max_num_neighbor, &sp_index);\n  out_ver->resize(max_num_neighbor);\n  out_edge->resize(max_num_neighbor);\n  for (size_t i = 0; i < max_num_neighbor; ++i) {\n    size_t idx = sp_index[i];\n    out_ver->at(i) = vid_list[idx];\n    out_edge->at(i) = edge_id_list[idx];\n  }\n  sort(out_ver->begin(), out_ver->end());\n  sort(out_edge->begin(), out_edge->end());\n}\n\n/**\n * Used for subgraph sampling\n */\nstruct neigh_list {\n  std::vector<dgl_id_t> neighs;\n  std::vector<dgl_id_t> edges;\n  neigh_list(\n      const std::vector<dgl_id_t> &_neighs, const std::vector<dgl_id_t> &_edges)\n      : neighs(_neighs), edges(_edges) {}\n};\n\nstruct neighbor_info {\n  dgl_id_t id;\n  size_t pos;\n  size_t num_edges;\n\n  neighbor_info(dgl_id_t 
id, size_t pos, size_t num_edges) {\n    this->id = id;\n    this->pos = pos;\n    this->num_edges = num_edges;\n  }\n};\n\nNodeFlow ConstructNodeFlow(\n    std::vector<dgl_id_t> neighbor_list, std::vector<dgl_id_t> edge_list,\n    std::vector<size_t> layer_offsets,\n    std::vector<std::pair<dgl_id_t, int>> *sub_vers,\n    std::vector<neighbor_info> *neigh_pos, const std::string &edge_type,\n    int64_t num_edges, int num_hops) {\n  NodeFlow nf = NodeFlow::Create();\n  uint64_t num_vertices = sub_vers->size();\n  nf->node_mapping = aten::NewIdArray(num_vertices);\n  nf->edge_mapping = aten::NewIdArray(num_edges);\n  nf->layer_offsets = aten::NewIdArray(num_hops + 1);\n  nf->flow_offsets = aten::NewIdArray(num_hops);\n\n  dgl_id_t *node_map_data = static_cast<dgl_id_t *>(nf->node_mapping->data);\n  dgl_id_t *layer_off_data = static_cast<dgl_id_t *>(nf->layer_offsets->data);\n  dgl_id_t *flow_off_data = static_cast<dgl_id_t *>(nf->flow_offsets->data);\n  dgl_id_t *edge_map_data = static_cast<dgl_id_t *>(nf->edge_mapping->data);\n\n  // Construct sub_csr_graph, we treat nodeflow as multigraph by default\n  auto subg_csr = CSRPtr(new CSR(num_vertices, num_edges));\n  dgl_id_t *indptr_out = static_cast<dgl_id_t *>(subg_csr->indptr()->data);\n  dgl_id_t *col_list_out = static_cast<dgl_id_t *>(subg_csr->indices()->data);\n  dgl_id_t *eid_out = static_cast<dgl_id_t *>(subg_csr->edge_ids()->data);\n  size_t collected_nedges = 0;\n\n  // The data from the previous steps:\n  // * node data: sub_vers (vid, layer), neigh_pos,\n  // * edge data: neighbor_list, edge_list, probability.\n  // * layer_offsets: the offset in sub_vers.\n  dgl_id_t ver_id = 0;\n  std::vector<std::unordered_map<dgl_id_t, dgl_id_t>> layer_ver_maps;\n  layer_ver_maps.resize(num_hops);\n  size_t out_node_idx = 0;\n  for (int layer_id = num_hops - 1; layer_id >= 0; layer_id--) {\n    // We sort the vertices in a layer so that we don't need to sort the\n    // neighbor Ids after remap to a subgraph. 
However, we don't need to sort\n    // the first layer because we want the order of the nodes in the first layer\n    // is the same as the input seed nodes.\n    if (layer_id > 0) {\n      std::sort(\n          sub_vers->begin() + layer_offsets[layer_id],\n          sub_vers->begin() + layer_offsets[layer_id + 1],\n          [](const std::pair<dgl_id_t, dgl_id_t> &a1,\n             const std::pair<dgl_id_t, dgl_id_t> &a2) {\n            return a1.first < a2.first;\n          });\n    }\n\n    // Save the sampled vertices and its layer Id.\n    for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1];\n         i++) {\n      node_map_data[out_node_idx++] = sub_vers->at(i).first;\n      layer_ver_maps[layer_id].insert(\n          std::pair<dgl_id_t, dgl_id_t>(sub_vers->at(i).first, ver_id++));\n      CHECK_EQ(sub_vers->at(i).second, layer_id);\n    }\n  }\n  CHECK(out_node_idx == num_vertices);\n\n  // sampling algorithms have to start from the seed nodes, so the seed nodes\n  // are in the first layer and the input nodes are in the last layer. When we\n  // expose the sampled graph to a Python user, we say the input nodes are in\n  // the first layer and the seed nodes are in the last layer. 
Thus, when we\n  // copy sampled results to a CSR, we need to reverse the order of layers.\n  std::fill(indptr_out, indptr_out + num_vertices + 1, 0);\n  size_t row_idx = layer_offsets[num_hops] - layer_offsets[num_hops - 1];\n  layer_off_data[0] = 0;\n  layer_off_data[1] = layer_offsets[num_hops] - layer_offsets[num_hops - 1];\n  int out_layer_idx = 1;\n  for (int layer_id = num_hops - 2; layer_id >= 0; layer_id--) {\n    // Because we don't sort the vertices in the first layer above, we can't\n    // sort the neighbor positions of the vertices in the first layer either.\n    if (layer_id > 0) {\n      std::sort(\n          neigh_pos->begin() + layer_offsets[layer_id],\n          neigh_pos->begin() + layer_offsets[layer_id + 1],\n          [](const neighbor_info &a1, const neighbor_info &a2) {\n            return a1.id < a2.id;\n          });\n    }\n\n    for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1];\n         i++) {\n      dgl_id_t dst_id = sub_vers->at(i).first;\n      CHECK_EQ(dst_id, neigh_pos->at(i).id);\n      size_t pos = neigh_pos->at(i).pos;\n      CHECK_LE(pos, neighbor_list.size());\n      const size_t nedges = neigh_pos->at(i).num_edges;\n      if (neighbor_list.empty()) CHECK_EQ(nedges, 0);\n\n      // We need to map the Ids of the neighbors to the subgraph.\n      auto neigh_it = neighbor_list.begin() + pos;\n      for (size_t i = 0; i < nedges; i++) {\n        dgl_id_t neigh = *(neigh_it + i);\n        CHECK(\n            layer_ver_maps[layer_id + 1].find(neigh) !=\n            layer_ver_maps[layer_id + 1].end());\n        col_list_out[collected_nedges + i] =\n            layer_ver_maps[layer_id + 1][neigh];\n      }\n      // We can simply copy the edge Ids.\n      std::copy_n(\n          edge_list.begin() + pos, nedges, edge_map_data + collected_nedges);\n      collected_nedges += nedges;\n      indptr_out[row_idx + 1] = indptr_out[row_idx] + nedges;\n      row_idx++;\n    }\n    layer_off_data[out_layer_idx + 1] = 
layer_off_data[out_layer_idx] +\n                                        layer_offsets[layer_id + 1] -\n                                        layer_offsets[layer_id];\n    out_layer_idx++;\n  }\n  CHECK_EQ(row_idx, num_vertices);\n  CHECK_EQ(indptr_out[row_idx], num_edges);\n  CHECK_EQ(out_layer_idx, num_hops);\n  CHECK_EQ(layer_off_data[out_layer_idx], num_vertices);\n\n  // Copy flow offsets.\n  flow_off_data[0] = 0;\n  int out_flow_idx = 0;\n  for (size_t i = 0; i < layer_offsets.size() - 2; i++) {\n    size_t num_edges =\n        indptr_out[layer_off_data[i + 2]] - indptr_out[layer_off_data[i + 1]];\n    flow_off_data[out_flow_idx + 1] = flow_off_data[out_flow_idx] + num_edges;\n    out_flow_idx++;\n  }\n  CHECK(out_flow_idx == num_hops - 1);\n  CHECK(flow_off_data[num_hops - 1] == static_cast<uint64_t>(num_edges));\n\n  std::iota(eid_out, eid_out + num_edges, 0);\n\n  if (edge_type == std::string(\"in\")) {\n    nf->graph = GraphPtr(new ImmutableGraph(subg_csr, nullptr));\n  } else {\n    nf->graph = GraphPtr(new ImmutableGraph(nullptr, subg_csr));\n  }\n\n  return nf;\n}\n\ntemplate <typename ValueType>\nNodeFlow SampleSubgraph(\n    const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,\n    const ValueType *probability, const std::string &edge_type, int num_hops,\n    size_t num_neighbor, const bool add_self_loop) {\n  CHECK_EQ(graph->NumBits(), 64) << \"32 bit graph is not supported yet\";\n  const size_t num_seeds = seeds.size();\n  auto orig_csr = edge_type == \"in\" ? 
graph->GetInCSR() : graph->GetOutCSR();\n  const dgl_id_t *val_list =\n      static_cast<dgl_id_t *>(orig_csr->edge_ids()->data);\n  const dgl_id_t *col_list = static_cast<dgl_id_t *>(orig_csr->indices()->data);\n  const dgl_id_t *indptr = static_cast<dgl_id_t *>(orig_csr->indptr()->data);\n\n  std::unordered_set<dgl_id_t> sub_ver_map;  // The vertex Ids in a layer.\n  std::vector<std::pair<dgl_id_t, int>> sub_vers;\n  sub_vers.reserve(num_seeds * 10);\n  // add seed vertices\n  for (size_t i = 0; i < num_seeds; ++i) {\n    auto ret = sub_ver_map.insert(seeds[i]);\n    // If the vertex is inserted successfully.\n    if (ret.second) {\n      sub_vers.emplace_back(seeds[i], 0);\n    }\n  }\n  std::vector<dgl_id_t> tmp_sampled_src_list;\n  std::vector<dgl_id_t> tmp_sampled_edge_list;\n  // ver_id, position\n  std::vector<neighbor_info> neigh_pos;\n  neigh_pos.reserve(num_seeds);\n  std::vector<dgl_id_t> neighbor_list;\n  std::vector<dgl_id_t> edge_list;\n  std::vector<size_t> layer_offsets(num_hops + 1);\n  int64_t num_edges = 0;\n\n  layer_offsets[0] = 0;\n  layer_offsets[1] = sub_vers.size();\n  for (int layer_id = 1; layer_id < num_hops; layer_id++) {\n    // We need to avoid resampling the same node in a layer, but we allow a node\n    // to be resampled in multiple layers. We use `sub_ver_map` to keep track of\n    // sampled nodes in a layer, and clear it when entering a new layer.\n    sub_ver_map.clear();\n    // Previous iteration collects all nodes in sub_vers, which are collected\n    // in the previous layer. 
sub_vers is used both as a node collection and a\n    // queue.\n    for (size_t idx = layer_offsets[layer_id - 1];\n         idx < layer_offsets[layer_id]; idx++) {\n      dgl_id_t dst_id = sub_vers[idx].first;\n      const int cur_node_level = sub_vers[idx].second;\n\n      tmp_sampled_src_list.clear();\n      tmp_sampled_edge_list.clear();\n      dgl_id_t ver_len = *(indptr + dst_id + 1) - *(indptr + dst_id);\n      if (probability == nullptr) {  // uniform-sample\n        GetUniformSample(\n            val_list + *(indptr + dst_id), col_list + *(indptr + dst_id),\n            ver_len, num_neighbor, &tmp_sampled_src_list,\n            &tmp_sampled_edge_list);\n      } else {  // non-uniform-sample\n        GetNonUniformSample(\n            probability, val_list + *(indptr + dst_id),\n            col_list + *(indptr + dst_id), ver_len, num_neighbor,\n            &tmp_sampled_src_list, &tmp_sampled_edge_list);\n      }\n      // If we need to add self loop and it doesn't exist in the sampled\n      // neighbor list.\n      if (add_self_loop &&\n          std::find(\n              tmp_sampled_src_list.begin(), tmp_sampled_src_list.end(),\n              dst_id) == tmp_sampled_src_list.end()) {\n        tmp_sampled_src_list.push_back(dst_id);\n        const dgl_id_t *src_list = col_list + *(indptr + dst_id);\n        const dgl_id_t *eid_list = val_list + *(indptr + dst_id);\n        // TODO(zhengda) this operation has O(N) complexity. 
It can be pretty\n        // slow.\n        const dgl_id_t *src = std::find(src_list, src_list + ver_len, dst_id);\n        // If there doesn't exist a self loop in the graph.\n        // we have to add -1 as the edge id for the self-loop edge.\n        if (src == src_list + ver_len)\n          tmp_sampled_edge_list.push_back(-1);\n        else\n          tmp_sampled_edge_list.push_back(eid_list[src - src_list]);\n      }\n      CHECK_EQ(tmp_sampled_src_list.size(), tmp_sampled_edge_list.size());\n      neigh_pos.emplace_back(\n          dst_id, neighbor_list.size(), tmp_sampled_src_list.size());\n      // Then push the vertices\n      for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {\n        neighbor_list.push_back(tmp_sampled_src_list[i]);\n      }\n      // Finally we push the edge list\n      for (size_t i = 0; i < tmp_sampled_edge_list.size(); ++i) {\n        edge_list.push_back(tmp_sampled_edge_list[i]);\n      }\n      num_edges += tmp_sampled_src_list.size();\n      for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {\n        // We need to add the neighbor in the hashtable here. This ensures that\n        // the vertex in the queue is unique. 
If we see a vertex before, we\n        // don't need to add it to the queue again.\n        auto ret = sub_ver_map.insert(tmp_sampled_src_list[i]);\n        // If the sampled neighbor is inserted to the map successfully.\n        if (ret.second) {\n          sub_vers.emplace_back(tmp_sampled_src_list[i], cur_node_level + 1);\n        }\n      }\n    }\n    layer_offsets[layer_id + 1] = layer_offsets[layer_id] + sub_ver_map.size();\n    CHECK_EQ(layer_offsets[layer_id + 1], sub_vers.size());\n  }\n\n  return ConstructNodeFlow(\n      neighbor_list, edge_list, layer_offsets, &sub_vers, &neigh_pos, edge_type,\n      num_edges, num_hops);\n}\n\n}  // namespace\n\nDGL_REGISTER_GLOBAL(\"_deprecate.nodeflow._CAPI_NodeFlowGetGraph\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      NodeFlow nflow = args[0];\n      *rv = nflow->graph;\n    });\n\nDGL_REGISTER_GLOBAL(\"_deprecate.nodeflow._CAPI_NodeFlowGetNodeMapping\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      NodeFlow nflow = args[0];\n      *rv = nflow->node_mapping;\n    });\n\nDGL_REGISTER_GLOBAL(\"_deprecate.nodeflow._CAPI_NodeFlowGetEdgeMapping\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      NodeFlow nflow = args[0];\n      *rv = nflow->edge_mapping;\n    });\n\nDGL_REGISTER_GLOBAL(\"_deprecate.nodeflow._CAPI_NodeFlowGetLayerOffsets\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      NodeFlow nflow = args[0];\n      *rv = nflow->layer_offsets;\n    });\n\nDGL_REGISTER_GLOBAL(\"_deprecate.nodeflow._CAPI_NodeFlowGetBlockOffsets\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      NodeFlow nflow = args[0];\n      *rv = nflow->flow_offsets;\n    });\n\ntemplate <typename ValueType>\nNodeFlow SamplerOp::NeighborSample(\n    const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,\n    const std::string &edge_type, int num_hops, int expand_factor,\n    const bool add_self_loop, const ValueType *probability) {\n  return SampleSubgraph(\n      graph, seeds, 
probability, edge_type, num_hops + 1, expand_factor,\n      add_self_loop);\n}\n\nnamespace {\nvoid ConstructLayers(\n    const dgl_id_t *indptr, const dgl_id_t *indices,\n    const std::vector<dgl_id_t> &seed_array, IdArray layer_sizes,\n    std::vector<dgl_id_t> *layer_offsets, std::vector<dgl_id_t> *node_mapping,\n    std::vector<int64_t> *actl_layer_sizes, std::vector<float> *probabilities) {\n  /**\n   * Given a graph and a collection of seed nodes, this function constructs\n   * NodeFlow layers via uniform layer-wise sampling, and return the resultant\n   * layers and their corresponding probabilities.\n   */\n  std::copy(\n      seed_array.begin(), seed_array.end(), std::back_inserter(*node_mapping));\n  actl_layer_sizes->push_back(node_mapping->size());\n  probabilities->insert(probabilities->end(), node_mapping->size(), 1);\n  const int64_t *layer_sizes_data = static_cast<int64_t *>(layer_sizes->data);\n  const int64_t num_layers = layer_sizes->shape[0];\n\n  size_t curr = 0;\n  size_t next = node_mapping->size();\n  for (int64_t i = num_layers - 1; i >= 0; --i) {\n    const int64_t layer_size = layer_sizes_data[i];\n    std::unordered_set<dgl_id_t> candidate_set;\n    for (auto j = curr; j != next; ++j) {\n      auto src = (*node_mapping)[j];\n      candidate_set.insert(indices + indptr[src], indices + indptr[src + 1]);\n    }\n\n    std::vector<dgl_id_t> candidate_vector;\n    std::copy(\n        candidate_set.begin(), candidate_set.end(),\n        std::back_inserter(candidate_vector));\n\n    std::unordered_map<dgl_id_t, size_t> n_occurrences;\n    auto n_candidates = candidate_vector.size();\n    for (int64_t j = 0; j != layer_size; ++j) {\n      auto dst =\n          candidate_vector[RandomEngine::ThreadLocal()->RandInt(n_candidates)];\n      if (!n_occurrences.insert(std::make_pair(dst, 1)).second) {\n        ++n_occurrences[dst];\n      }\n    }\n\n    for (auto const &pair : n_occurrences) {\n      node_mapping->push_back(pair.first);\n      float 
p = pair.second * n_candidates / static_cast<float>(layer_size);\n      probabilities->push_back(p);\n    }\n\n    actl_layer_sizes->push_back(node_mapping->size() - next);\n    curr = next;\n    next = node_mapping->size();\n  }\n  std::reverse(node_mapping->begin(), node_mapping->end());\n  std::reverse(actl_layer_sizes->begin(), actl_layer_sizes->end());\n  layer_offsets->push_back(0);\n  for (const auto &size : *actl_layer_sizes) {\n    layer_offsets->push_back(size + layer_offsets->back());\n  }\n}\n\nvoid ConstructFlows(\n    const dgl_id_t *indptr, const dgl_id_t *indices, const dgl_id_t *eids,\n    const std::vector<dgl_id_t> &node_mapping,\n    const std::vector<int64_t> &actl_layer_sizes,\n    std::vector<dgl_id_t> *sub_indptr, std::vector<dgl_id_t> *sub_indices,\n    std::vector<dgl_id_t> *sub_eids, std::vector<dgl_id_t> *flow_offsets,\n    std::vector<dgl_id_t> *edge_mapping) {\n  /**\n   * Given a graph and a sequence of NodeFlow layers, this function constructs\n   * dense subgraphs (flows) between consecutive layers.\n   */\n  auto n_flows = actl_layer_sizes.size() - 1;\n  for (int64_t i = 0; i < actl_layer_sizes.front() + 1; i++)\n    sub_indptr->push_back(0);\n  flow_offsets->push_back(0);\n  int64_t first = 0;\n  for (size_t i = 0; i < n_flows; ++i) {\n    auto src_size = actl_layer_sizes[i];\n    std::unordered_map<dgl_id_t, dgl_id_t> source_map;\n    for (int64_t j = 0; j < src_size; ++j) {\n      source_map.insert(std::make_pair(node_mapping[first + j], first + j));\n    }\n    auto dst_size = actl_layer_sizes[i + 1];\n    for (int64_t j = 0; j < dst_size; ++j) {\n      auto dst = node_mapping[first + src_size + j];\n      typedef std::pair<dgl_id_t, dgl_id_t> id_pair;\n      std::vector<id_pair> neighbor_indices;\n      for (dgl_id_t k = indptr[dst]; k < indptr[dst + 1]; ++k) {\n        // TODO(gaiyu): accelerate hash table lookup\n        auto ret = source_map.find(indices[k]);\n        if (ret != source_map.end()) {\n          
neighbor_indices.push_back(std::make_pair(ret->second, eids[k]));\n        }\n      }\n      auto cmp = [](const id_pair p, const id_pair q) -> bool {\n        return p.first < q.first;\n      };\n      std::sort(neighbor_indices.begin(), neighbor_indices.end(), cmp);\n      for (const auto &pair : neighbor_indices) {\n        sub_indices->push_back(pair.first);\n        edge_mapping->push_back(pair.second);\n      }\n      sub_indptr->push_back(sub_indices->size());\n    }\n    flow_offsets->push_back(sub_indices->size());\n    first += src_size;\n  }\n  sub_eids->resize(sub_indices->size());\n  std::iota(sub_eids->begin(), sub_eids->end(), 0);\n}\n}  // namespace\n\nNodeFlow SamplerOp::LayerUniformSample(\n    const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,\n    const std::string &neighbor_type, IdArray layer_sizes) {\n  const auto g_csr =\n      neighbor_type == \"in\" ? graph->GetInCSR() : graph->GetOutCSR();\n  const dgl_id_t *indptr = static_cast<dgl_id_t *>(g_csr->indptr()->data);\n  const dgl_id_t *indices = static_cast<dgl_id_t *>(g_csr->indices()->data);\n  const dgl_id_t *eids = static_cast<dgl_id_t *>(g_csr->edge_ids()->data);\n\n  std::vector<dgl_id_t> layer_offsets;\n  std::vector<dgl_id_t> node_mapping;\n  std::vector<int64_t> actl_layer_sizes;\n  std::vector<float> probabilities;\n  ConstructLayers(\n      indptr, indices, seeds, layer_sizes, &layer_offsets, &node_mapping,\n      &actl_layer_sizes, &probabilities);\n\n  std::vector<dgl_id_t> sub_indptr, sub_indices, sub_edge_ids;\n  std::vector<dgl_id_t> flow_offsets;\n  std::vector<dgl_id_t> edge_mapping;\n  ConstructFlows(\n      indptr, indices, eids, node_mapping, actl_layer_sizes, &sub_indptr,\n      &sub_indices, &sub_edge_ids, &flow_offsets, &edge_mapping);\n  // sanity check\n  CHECK_GT(sub_indptr.size(), 0);\n  CHECK_EQ(sub_indptr[0], 0);\n  CHECK_EQ(sub_indptr.back(), sub_indices.size());\n  CHECK_EQ(sub_indices.size(), sub_edge_ids.size());\n\n  NodeFlow nf = 
NodeFlow::Create();\n  auto sub_csr = CSRPtr(new CSR(\n      aten::VecToIdArray(sub_indptr), aten::VecToIdArray(sub_indices),\n      aten::VecToIdArray(sub_edge_ids)));\n\n  if (neighbor_type == std::string(\"in\")) {\n    nf->graph = GraphPtr(new ImmutableGraph(sub_csr, nullptr));\n  } else {\n    nf->graph = GraphPtr(new ImmutableGraph(nullptr, sub_csr));\n  }\n\n  nf->node_mapping = aten::VecToIdArray(node_mapping);\n  nf->edge_mapping = aten::VecToIdArray(edge_mapping);\n  nf->layer_offsets = aten::VecToIdArray(layer_offsets);\n  nf->flow_offsets = aten::VecToIdArray(flow_offsets);\n\n  return nf;\n}\n\nvoid BuildCsr(const ImmutableGraph &g, const std::string neigh_type) {\n  if (neigh_type == \"in\") {\n    auto csr = g.GetInCSR();\n    assert(csr);\n  } else if (neigh_type == \"out\") {\n    auto csr = g.GetOutCSR();\n    assert(csr);\n  } else {\n    LOG(FATAL) << \"We don't support sample from neighbor type \" << neigh_type;\n  }\n}\n\ntemplate <typename ValueType>\nstd::vector<NodeFlow> NeighborSamplingImpl(\n    const ImmutableGraphPtr gptr, const IdArray seed_nodes,\n    const int64_t batch_start_id, const int64_t batch_size,\n    const int64_t max_num_workers, const int64_t expand_factor,\n    const int64_t num_hops, const std::string neigh_type,\n    const bool add_self_loop, const ValueType *probability) {\n  // process args\n  CHECK(aten::IsValidIdArray(seed_nodes));\n  const dgl_id_t *seed_nodes_data = static_cast<dgl_id_t *>(seed_nodes->data);\n  const int64_t num_seeds = seed_nodes->shape[0];\n  const int64_t num_workers = std::min(\n      max_num_workers,\n      (num_seeds + batch_size - 1) / batch_size - batch_start_id);\n  // We need to make sure we have the right CSR before we enter parallel\n  // sampling.\n  BuildCsr(*gptr, neigh_type);\n  // generate node flows\n  std::vector<NodeFlow> nflows(num_workers);\n  runtime::parallel_for(0, num_workers, [&](size_t b, size_t e) {\n    for (auto i = b; i < e; ++i) {\n      // create per-worker seed 
nodes.\n      const int64_t start = (batch_start_id + i) * batch_size;\n      const int64_t end = std::min(start + batch_size, num_seeds);\n      // TODO(minjie): the vector allocation/copy is unnecessary\n      std::vector<dgl_id_t> worker_seeds(end - start);\n      std::copy(\n          seed_nodes_data + start, seed_nodes_data + end, worker_seeds.begin());\n      nflows[i] = SamplerOp::NeighborSample(\n          gptr.get(), worker_seeds, neigh_type, num_hops, expand_factor,\n          add_self_loop, probability);\n    }\n  });\n  return nflows;\n}\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_UniformSampling\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      // arguments\n      const GraphRef g = args[0];\n      const IdArray seed_nodes = args[1];\n      const int64_t batch_start_id = args[2];\n      const int64_t batch_size = args[3];\n      const int64_t max_num_workers = args[4];\n      const int64_t expand_factor = args[5];\n      const int64_t num_hops = args[6];\n      const std::string neigh_type = args[7];\n      const bool add_self_loop = args[8];\n\n      auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());\n      CHECK(gptr) << \"sampling isn't implemented in mutable graph\";\n\n      CHECK(aten::IsValidIdArray(seed_nodes));\n      CHECK_EQ(seed_nodes->ctx.device_type, kDGLCPU)\n          << \"UniformSampler only support CPU sampling\";\n\n      std::vector<NodeFlow> nflows = NeighborSamplingImpl<float>(\n          gptr, seed_nodes, batch_start_id, batch_size, max_num_workers,\n          expand_factor, num_hops, neigh_type, add_self_loop, nullptr);\n\n      *rv = List<NodeFlow>(nflows);\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_NeighborSampling\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      // arguments\n      const GraphRef g = args[0];\n      const IdArray seed_nodes = args[1];\n      const int64_t batch_start_id = args[2];\n      const int64_t batch_size = args[3];\n      const int64_t max_num_workers = 
args[4];\n      const int64_t expand_factor = args[5];\n      const int64_t num_hops = args[6];\n      const std::string neigh_type = args[7];\n      const bool add_self_loop = args[8];\n      const NDArray probability = args[9];\n\n      auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());\n      CHECK(gptr) << \"sampling isn't implemented in mutable graph\";\n\n      CHECK(aten::IsValidIdArray(seed_nodes));\n      CHECK_EQ(seed_nodes->ctx.device_type, kDGLCPU)\n          << \"NeighborSampler only support CPU sampling\";\n\n      std::vector<NodeFlow> nflows;\n\n      CHECK(probability->dtype.code == kDGLFloat)\n          << \"transition probability must be float\";\n      CHECK(probability->ndim == 1)\n          << \"transition probability must be a 1-dimensional vector\";\n      CHECK_EQ(probability->ctx.device_type, kDGLCPU)\n          << \"NeighborSampling only support CPU sampling\";\n\n      ATEN_FLOAT_TYPE_SWITCH(\n          probability->dtype, FloatType, \"transition probability\", {\n            const FloatType *prob;\n\n            if (aten::IsNullArray(probability)) {\n              prob = nullptr;\n            } else {\n              CHECK(\n                  probability->shape[0] ==\n                  static_cast<int64_t>(gptr->NumEdges()))\n                  << \"transition probability must have same number of elements \"\n                     \"as edges\";\n              CHECK(probability.IsContiguous())\n                  << \"transition probability must be contiguous tensor\";\n              prob = static_cast<const FloatType *>(probability->data);\n            }\n\n            nflows = NeighborSamplingImpl(\n                gptr, seed_nodes, batch_start_id, batch_size, max_num_workers,\n                expand_factor, num_hops, neigh_type, add_self_loop, prob);\n          });\n\n      *rv = List<NodeFlow>(nflows);\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_LayerSampling\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      
// arguments\n      GraphRef g = args[0];\n      const IdArray seed_nodes = args[1];\n      const int64_t batch_start_id = args[2];\n      const int64_t batch_size = args[3];\n      const int64_t max_num_workers = args[4];\n      const IdArray layer_sizes = args[5];\n      const std::string neigh_type = args[6];\n      // process args\n      auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());\n      CHECK(gptr) << \"sampling isn't implemented in mutable graph\";\n      CHECK(aten::IsValidIdArray(seed_nodes));\n      CHECK_EQ(seed_nodes->ctx.device_type, kDGLCPU)\n          << \"LayerSampler only support CPU sampling\";\n\n      CHECK(aten::IsValidIdArray(layer_sizes));\n      CHECK_EQ(layer_sizes->ctx.device_type, kDGLCPU)\n          << \"LayerSampler only support CPU sampling\";\n\n      const dgl_id_t *seed_nodes_data =\n          static_cast<dgl_id_t *>(seed_nodes->data);\n      const int64_t num_seeds = seed_nodes->shape[0];\n      const int64_t num_workers = std::min(\n          max_num_workers,\n          (num_seeds + batch_size - 1) / batch_size - batch_start_id);\n      // We need to make sure we have the right CSR before we enter parallel\n      // sampling.\n      BuildCsr(*gptr, neigh_type);\n      // generate node flows\n      std::vector<NodeFlow> nflows(num_workers);\n      runtime::parallel_for(0, num_workers, [&](size_t b, size_t e) {\n        for (auto i = b; i < e; ++i) {\n          // create per-worker seed nodes.\n          const int64_t start = (batch_start_id + i) * batch_size;\n          const int64_t end = std::min(start + batch_size, num_seeds);\n          // TODO(minjie): the vector allocation/copy is unnecessary\n          std::vector<dgl_id_t> worker_seeds(end - start);\n          std::copy(\n              seed_nodes_data + start, seed_nodes_data + end,\n              worker_seeds.begin());\n          nflows[i] = SamplerOp::LayerUniformSample(\n              gptr.get(), worker_seeds, neigh_type, layer_sizes);\n        }\n    
  });\n      *rv = List<NodeFlow>(nflows);\n    });\n\nnamespace {\n\nvoid BuildCoo(const ImmutableGraph &g) {\n  auto coo = g.GetCOO();\n  assert(coo);\n}\n\ndgl_id_t global2local_map(\n    dgl_id_t global_id, std::unordered_map<dgl_id_t, dgl_id_t> *map) {\n  auto it = map->find(global_id);\n  if (it == map->end()) {\n    dgl_id_t local_id = map->size();\n    map->insert(std::pair<dgl_id_t, dgl_id_t>(global_id, local_id));\n    return local_id;\n  } else {\n    return it->second;\n  }\n}\n\ninline bool IsNegativeHeadMode(const std::string &mode) {\n  return mode == \"head\";\n}\n\nIdArray GetGlobalVid(IdArray induced_nid, IdArray subg_nid) {\n  IdArray gnid =\n      IdArray::Empty({subg_nid->shape[0]}, subg_nid->dtype, subg_nid->ctx);\n  const dgl_id_t *induced_nid_data = static_cast<dgl_id_t *>(induced_nid->data);\n  const dgl_id_t *subg_nid_data = static_cast<dgl_id_t *>(subg_nid->data);\n  dgl_id_t *gnid_data = static_cast<dgl_id_t *>(gnid->data);\n  for (int64_t i = 0; i < subg_nid->shape[0]; i++) {\n    gnid_data[i] = induced_nid_data[subg_nid_data[i]];\n  }\n  return gnid;\n}\n\nIdArray CheckExistence(\n    GraphPtr gptr, IdArray neg_src, IdArray neg_dst, IdArray induced_nid) {\n  return gptr->HasEdgesBetween(\n      GetGlobalVid(induced_nid, neg_src), GetGlobalVid(induced_nid, neg_dst));\n}\n\nIdArray CheckExistence(\n    GraphPtr gptr, IdArray relations, IdArray neg_src, IdArray neg_dst,\n    IdArray induced_nid, IdArray neg_eid) {\n  neg_src = GetGlobalVid(induced_nid, neg_src);\n  neg_dst = GetGlobalVid(induced_nid, neg_dst);\n  BoolArray exist = gptr->HasEdgesBetween(neg_src, neg_dst);\n  dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data);\n  dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data);\n  dgl_id_t *neg_eid_data = static_cast<dgl_id_t *>(neg_eid->data);\n  dgl_id_t *relation_data = static_cast<dgl_id_t *>(relations->data);\n  // TODO(zhengda) is this right?\n  dgl_id_t *exist_data = static_cast<dgl_id_t 
*>(exist->data);\n  int64_t num_neg_edges = neg_src->shape[0];\n  for (int64_t i = 0; i < num_neg_edges; i++) {\n    // If the edge doesn't exist, we don't need to do anything.\n    if (!exist_data[i]) continue;\n    // If the edge exists, we need to double check if the relations match.\n    // If they match, this negative edge isn't really a negative edge.\n    dgl_id_t eid1 = neg_eid_data[i];\n    dgl_id_t orig_neg_rel1 = relation_data[eid1];\n    IdArray eids = gptr->EdgeId(neg_src_data[i], neg_dst_data[i]);\n    dgl_id_t *eid_data = static_cast<dgl_id_t *>(eids->data);\n    int64_t num_edges_between = eids->shape[0];\n    bool same_rel = false;\n    for (int64_t j = 0; j < num_edges_between; j++) {\n      dgl_id_t neg_rel1 = relation_data[eid_data[j]];\n      if (neg_rel1 == orig_neg_rel1) {\n        same_rel = true;\n        break;\n      }\n    }\n    exist_data[i] = same_rel;\n  }\n  return exist;\n}\n\nstd::vector<dgl_id_t> Global2Local(\n    const std::vector<size_t> &ids,\n    const std::unordered_map<dgl_id_t, dgl_id_t> &map) {\n  std::vector<dgl_id_t> local_ids(ids.size());\n  for (size_t i = 0; i < ids.size(); i++) {\n    auto it = map.find(ids[i]);\n    assert(it != map.end());\n    local_ids[i] = it->second;\n  }\n  return local_ids;\n}\n\nNegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(\n    const Subgraph &pos_subg, const std::string &neg_mode,\n    int64_t neg_sample_size, bool exclude_positive, bool check_false_neg) {\n  int64_t num_tot_nodes = gptr_->NumVertices();\n  if (neg_sample_size > num_tot_nodes) neg_sample_size = num_tot_nodes;\n  std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, \"coo\");\n  IdArray coo = adj[0];\n  int64_t num_pos_edges = coo->shape[0] / 2;\n  int64_t num_neg_edges = num_pos_edges * neg_sample_size;\n  IdArray neg_dst = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);\n  IdArray neg_src = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);\n  IdArray induced_neg_eid =\n      
IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);\n\n  // These are vids in the positive subgraph.\n  const dgl_id_t *dst_data = static_cast<const dgl_id_t *>(coo->data);\n  const dgl_id_t *src_data =\n      static_cast<const dgl_id_t *>(coo->data) + num_pos_edges;\n  const dgl_id_t *induced_vid_data =\n      static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data);\n  const dgl_id_t *induced_eid_data =\n      static_cast<const dgl_id_t *>(pos_subg.induced_edges->data);\n  size_t num_pos_nodes = pos_subg.graph->NumVertices();\n  std::vector<size_t> pos_nodes(\n      induced_vid_data, induced_vid_data + num_pos_nodes);\n\n  dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data);\n  dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data);\n  dgl_id_t *induced_neg_eid_data =\n      static_cast<dgl_id_t *>(induced_neg_eid->data);\n\n  const dgl_id_t *unchanged;\n  dgl_id_t *neg_unchanged;\n  dgl_id_t *neg_changed;\n  if (IsNegativeHeadMode(neg_mode)) {\n    unchanged = dst_data;\n    neg_unchanged = neg_dst_data;\n    neg_changed = neg_src_data;\n  } else {\n    unchanged = src_data;\n    neg_unchanged = neg_src_data;\n    neg_changed = neg_dst_data;\n  }\n\n  std::unordered_map<dgl_id_t, dgl_id_t> neg_map;\n  std::vector<dgl_id_t> local_pos_vids;\n  local_pos_vids.reserve(num_pos_edges);\n\n  std::vector<size_t> neg_vids;\n  neg_vids.reserve(neg_sample_size);\n  // If we don't exclude positive edges, we are actually sampling more than\n  // the total number of nodes in the graph.\n  if (!exclude_positive && neg_sample_size >= num_tot_nodes) {\n    // We add all nodes as negative nodes.\n    for (int64_t i = 0; i < num_tot_nodes; i++) {\n      neg_vids.push_back(i);\n      neg_map[i] = i;\n    }\n\n    // Get all nodes in the positive side.\n    for (int64_t i = 0; i < num_pos_edges; i++) {\n      dgl_id_t vid = induced_vid_data[unchanged[i]];\n      local_pos_vids.push_back(neg_map[vid]);\n    }\n    // There is no guarantee that the nodes 
in the vector are unique.\n    std::sort(local_pos_vids.begin(), local_pos_vids.end());\n    auto it = std::unique(local_pos_vids.begin(), local_pos_vids.end());\n    local_pos_vids.resize(it - local_pos_vids.begin());\n  } else {\n    // Collect nodes in the positive side.\n    dgl_id_t local_vid = 0;\n    for (int64_t i = 0; i < num_pos_edges; i++) {\n      dgl_id_t vid = induced_vid_data[unchanged[i]];\n      auto it = neg_map.find(vid);\n      if (it == neg_map.end()) {\n        local_pos_vids.push_back(local_vid);\n        neg_map.insert(std::pair<dgl_id_t, dgl_id_t>(vid, local_vid++));\n      }\n    }\n  }\n\n  int64_t prev_neg_offset = 0;\n  for (int64_t i = 0; i < num_pos_edges; i++) {\n    size_t neg_idx = i * neg_sample_size;\n\n    std::vector<size_t> neighbors;\n    DGLIdIters neigh_it;\n    if (IsNegativeHeadMode(neg_mode)) {\n      neigh_it = gptr_->PredVec(induced_vid_data[unchanged[i]]);\n    } else {\n      neigh_it = gptr_->SuccVec(induced_vid_data[unchanged[i]]);\n    }\n\n    // If the number of negative nodes is smaller than the number of total nodes\n    // in the graph.\n    if (exclude_positive && neg_sample_size < num_tot_nodes) {\n      std::vector<size_t> exclude;\n      for (auto it = neigh_it.begin(); it != neigh_it.end(); it++) {\n        dgl_id_t global_vid = *it;\n        exclude.push_back(global_vid);\n      }\n      prev_neg_offset = neg_vids.size();\n      randomSample(num_tot_nodes, neg_sample_size, exclude, &neg_vids);\n      assert(\n          static_cast<size_t>(prev_neg_offset + neg_sample_size) ==\n          neg_vids.size());\n    } else if (neg_sample_size < num_tot_nodes) {\n      prev_neg_offset = neg_vids.size();\n      randomSample(num_tot_nodes, neg_sample_size, &neg_vids);\n      assert(\n          static_cast<size_t>(prev_neg_offset + neg_sample_size) ==\n          neg_vids.size());\n    } else if (exclude_positive) {\n      LOG(FATAL) << \"We can't exclude positive edges\"\n                    \"when sampling 
negative edges with all nodes.\";\n    } else {\n      // We don't need to do anything here.\n      // In this case, every edge has the same negative edges. That is,\n      // neg_vids contains all nodes of the graph. They have been generated\n      // before the for loop.\n    }\n\n    dgl_id_t global_unchanged = induced_vid_data[unchanged[i]];\n    dgl_id_t local_unchanged = global2local_map(global_unchanged, &neg_map);\n\n    for (int64_t j = 0; j < neg_sample_size; j++) {\n      neg_unchanged[neg_idx + j] = local_unchanged;\n      dgl_id_t local_changed =\n          global2local_map(neg_vids[j + prev_neg_offset], &neg_map);\n      neg_changed[neg_idx + j] = local_changed;\n      // induced negative eid references to the positive one.\n      induced_neg_eid_data[neg_idx + j] = induced_eid_data[i];\n    }\n  }\n\n  // Now we know the number of vertices in the negative graph.\n  int64_t num_neg_nodes = neg_map.size();\n  IdArray induced_neg_vid =\n      IdArray::Empty({num_neg_nodes}, coo->dtype, coo->ctx);\n  dgl_id_t *induced_neg_vid_data =\n      static_cast<dgl_id_t *>(induced_neg_vid->data);\n  for (auto it = neg_map.begin(); it != neg_map.end(); it++) {\n    induced_neg_vid_data[it->second] = it->first;\n  }\n\n  NegSubgraph neg_subg;\n  // We sample negative vertices without replacement.\n  // There shouldn't be duplicated edges.\n  COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst));\n  neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo));\n  neg_subg.induced_vertices = induced_neg_vid;\n  neg_subg.induced_edges = induced_neg_eid;\n\n  if (IsNegativeHeadMode(neg_mode)) {\n    neg_subg.head_nid = aten::VecToIdArray(Global2Local(neg_vids, neg_map));\n    neg_subg.tail_nid = aten::VecToIdArray(local_pos_vids);\n  } else {\n    neg_subg.head_nid = aten::VecToIdArray(local_pos_vids);\n    neg_subg.tail_nid = aten::VecToIdArray(Global2Local(neg_vids, neg_map));\n  }\n  // TODO(zhengda) we should provide an array of 1s if exclude_positive\n  if 
(check_false_neg) {\n    if (aten::IsNullArray(relations_)) {\n      neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid);\n    } else {\n      neg_subg.exist = CheckExistence(\n          gptr_, relations_, neg_src, neg_dst, induced_neg_vid,\n          induced_neg_eid);\n    }\n  }\n  return neg_subg;\n}\n\nNegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(\n    const Subgraph &pos_subg, const std::string &neg_mode,\n    int64_t neg_sample_size, bool exclude_positive, bool check_false_neg) {\n  int64_t num_tot_nodes = gptr_->NumVertices();\n  std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, \"coo\");\n  IdArray coo = adj[0];\n  int64_t num_pos_edges = coo->shape[0] / 2;\n  if (neg_sample_size > num_tot_nodes) neg_sample_size = num_tot_nodes;\n\n  int64_t chunk_size = chunk_size_;\n  CHECK_GT(chunk_size, 0) << \"chunk size has to be positive\";\n  // If num_pos_edges isn't divisible by chunk_size, the actual number of chunks\n  // is num_chunks + 1 and the last chunk size is last_chunk_size.\n  // Otherwise, the actual number of chunks is num_chunks, the last chunk size\n  // is 0.\n  int64_t num_chunks = num_pos_edges / chunk_size;\n  int64_t last_chunk_size = num_pos_edges - num_chunks * chunk_size;\n\n  // The number of negative edges.\n  int64_t num_neg_edges = neg_sample_size * chunk_size * num_chunks;\n  int64_t num_neg_edges_last_chunk = neg_sample_size * last_chunk_size;\n  int64_t num_all_neg_edges = num_neg_edges + num_neg_edges_last_chunk;\n\n  // We should include the last chunk.\n  if (last_chunk_size > 0) num_chunks++;\n\n  IdArray neg_dst = IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx);\n  IdArray neg_src = IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx);\n  IdArray induced_neg_eid =\n      IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx);\n\n  // These are vids in the positive subgraph.\n  const dgl_id_t *dst_data = static_cast<const dgl_id_t *>(coo->data);\n  const dgl_id_t *src_data 
=\n      static_cast<const dgl_id_t *>(coo->data) + num_pos_edges;\n  const dgl_id_t *induced_vid_data =\n      static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data);\n  const dgl_id_t *induced_eid_data =\n      static_cast<const dgl_id_t *>(pos_subg.induced_edges->data);\n  int64_t num_pos_nodes = pos_subg.graph->NumVertices();\n  std::vector<dgl_id_t> pos_nodes(\n      induced_vid_data, induced_vid_data + num_pos_nodes);\n\n  dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data);\n  dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data);\n  dgl_id_t *induced_neg_eid_data =\n      static_cast<dgl_id_t *>(induced_neg_eid->data);\n\n  const dgl_id_t *unchanged;\n  dgl_id_t *neg_unchanged;\n  dgl_id_t *neg_changed;\n  if (IsNegativeHeadMode(neg_mode)) {\n    unchanged = dst_data;\n    neg_unchanged = neg_dst_data;\n    neg_changed = neg_src_data;\n  } else {\n    unchanged = src_data;\n    neg_unchanged = neg_src_data;\n    neg_changed = neg_dst_data;\n  }\n\n  // We first sample all negative edges.\n  std::vector<size_t> global_neg_vids;\n  std::vector<size_t> local_neg_vids;\n  randomSample(num_tot_nodes, num_chunks * neg_sample_size, &global_neg_vids);\n  CHECK_EQ(num_chunks * neg_sample_size, global_neg_vids.size());\n\n  std::unordered_map<dgl_id_t, dgl_id_t> neg_map;\n  dgl_id_t local_vid = 0;\n\n  // Collect nodes in the positive side.\n  std::vector<dgl_id_t> local_pos_vids;\n  local_pos_vids.reserve(num_pos_edges);\n  for (int64_t i = 0; i < num_pos_edges; i++) {\n    dgl_id_t vid = induced_vid_data[unchanged[i]];\n    auto it = neg_map.find(vid);\n    if (it == neg_map.end()) {\n      local_pos_vids.push_back(local_vid);\n      neg_map.insert(std::pair<dgl_id_t, dgl_id_t>(vid, local_vid++));\n    }\n  }\n\n  // We should map the global negative nodes to local Ids in advance\n  // to reduce computation overhead.\n  local_neg_vids.resize(global_neg_vids.size());\n  for (size_t i = 0; i < global_neg_vids.size(); i++) {\n    
local_neg_vids[i] = global2local_map(global_neg_vids[i], &neg_map);\n  }\n\n  for (int64_t i_chunk = 0; i_chunk < num_chunks; i_chunk++) {\n    // for each chunk.\n    int64_t neg_idx = neg_sample_size * chunk_size * i_chunk;\n    int64_t pos_edge_idx = chunk_size * i_chunk;\n    int64_t neg_node_idx = neg_sample_size * i_chunk;\n    // The actual chunk size. It'll be different for the last chunk.\n    int64_t chunk_size1;\n    if (i_chunk == num_chunks - 1 && last_chunk_size > 0)\n      chunk_size1 = last_chunk_size;\n    else\n      chunk_size1 = chunk_size;\n\n    for (int64_t in_chunk = 0; in_chunk != chunk_size1; ++in_chunk) {\n      // For each positive node in a chunk.\n      dgl_id_t global_unchanged =\n          induced_vid_data[unchanged[pos_edge_idx + in_chunk]];\n      dgl_id_t local_unchanged = global2local_map(global_unchanged, &neg_map);\n      for (int64_t j = 0; j < neg_sample_size; ++j) {\n        neg_unchanged[neg_idx] = local_unchanged;\n        neg_changed[neg_idx] = local_neg_vids[neg_node_idx + j];\n        induced_neg_eid_data[neg_idx] =\n            induced_eid_data[pos_edge_idx + in_chunk];\n        neg_idx++;\n      }\n    }\n  }\n\n  // Now we know the number of vertices in the negative graph.\n  int64_t num_neg_nodes = neg_map.size();\n  IdArray induced_neg_vid =\n      IdArray::Empty({num_neg_nodes}, coo->dtype, coo->ctx);\n  dgl_id_t *induced_neg_vid_data =\n      static_cast<dgl_id_t *>(induced_neg_vid->data);\n  for (auto it = neg_map.begin(); it != neg_map.end(); it++) {\n    induced_neg_vid_data[it->second] = it->first;\n  }\n\n  NegSubgraph neg_subg;\n  // We sample negative vertices without replacement.\n  // There shouldn't be duplicated edges.\n  COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst));\n  neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo));\n  neg_subg.induced_vertices = induced_neg_vid;\n  neg_subg.induced_edges = induced_neg_eid;\n  if (IsNegativeHeadMode(neg_mode)) {\n    neg_subg.head_nid =\n        
aten::VecToIdArray(Global2Local(global_neg_vids, neg_map));\n    neg_subg.tail_nid = aten::VecToIdArray(local_pos_vids);\n  } else {\n    neg_subg.head_nid = aten::VecToIdArray(local_pos_vids);\n    neg_subg.tail_nid =\n        aten::VecToIdArray(Global2Local(global_neg_vids, neg_map));\n  }\n  if (check_false_neg) {\n    if (aten::IsNullArray(relations_)) {\n      neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid);\n    } else {\n      neg_subg.exist = CheckExistence(\n          gptr_, relations_, neg_src, neg_dst, induced_neg_vid,\n          induced_neg_eid);\n    }\n  }\n  return neg_subg;\n}\n\ninline SubgraphRef ConvertRef(const Subgraph &subg) {\n  return SubgraphRef(std::shared_ptr<Subgraph>(new Subgraph(subg)));\n}\n\ninline SubgraphRef ConvertRef(const NegSubgraph &subg) {\n  return SubgraphRef(std::shared_ptr<Subgraph>(new NegSubgraph(subg)));\n}\n\n}  // namespace\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_GetNegEdgeExistence\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      SubgraphRef g = args[0];\n      auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());\n      *rv = gptr->exist;\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_GetEdgeSubgraphHead\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      SubgraphRef g = args[0];\n      auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());\n      *rv = gptr->head_nid;\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_GetEdgeSubgraphTail\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      SubgraphRef g = args[0];\n      auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());\n      *rv = gptr->tail_nid;\n    });\n\nclass UniformEdgeSamplerObject : public EdgeSamplerObject {\n public:\n  explicit UniformEdgeSamplerObject(\n      const GraphPtr gptr, IdArray seed_edges, const int64_t batch_size,\n      const int64_t num_workers, const bool replacement, const bool reset,\n      const std::string neg_mode, const int64_t neg_sample_size,\n     
 const int64_t chunk_size, const bool exclude_positive,\n      const bool check_false_neg, IdArray relations)\n      : EdgeSamplerObject(\n            gptr, seed_edges, batch_size, num_workers, replacement, reset,\n            neg_mode, neg_sample_size, chunk_size, exclude_positive,\n            check_false_neg, relations) {\n    batch_curr_id_ = 0;\n    num_seeds_ = seed_edges->shape[0];\n    max_batch_id_ = (num_seeds_ + batch_size - 1) / batch_size;\n\n    // TODO(song): Tricky thing here to make sure gptr_ has coo cache\n    gptr_->FindEdge(0);\n  }\n  ~UniformEdgeSamplerObject() {}\n\n  void Fetch(DGLRetValue *rv) {\n    const int64_t num_workers =\n        std::min(num_workers_, max_batch_id_ - batch_curr_id_);\n    // generate subgraphs.\n    std::vector<SubgraphRef> positive_subgs(num_workers);\n    std::vector<SubgraphRef> negative_subgs(num_workers);\n\n    runtime::parallel_for(0, num_workers, [&](size_t b, size_t e) {\n      for (auto i = b; i < e; ++i) {\n        const int64_t start = (batch_curr_id_ + i) * batch_size_;\n        const int64_t end = std::min(start + batch_size_, num_seeds_);\n        const int64_t num_edges = end - start;\n        IdArray worker_seeds;\n\n        if (replacement_ == false) {\n          worker_seeds = seed_edges_.CreateView(\n              {num_edges}, DGLDataType{kDGLInt, 64, 1},\n              sizeof(dgl_id_t) * start);\n        } else {\n          std::vector<dgl_id_t> seeds;\n          const dgl_id_t *seed_edge_ids =\n              static_cast<const dgl_id_t *>(seed_edges_->data);\n          // sampling of each edge is a standalone event\n          for (int64_t i = 0; i < num_edges; ++i) {\n            int64_t seed = static_cast<const int64_t>(\n                RandomEngine::ThreadLocal()->RandInt(num_seeds_));\n            seeds.push_back(seed_edge_ids[seed]);\n          }\n\n          worker_seeds = aten::VecToIdArray(seeds, seed_edges_->dtype.bits);\n        }\n\n        EdgeArray arr = 
gptr_->FindEdges(worker_seeds);\n        const dgl_id_t *src_ids = static_cast<const dgl_id_t *>(arr.src->data);\n        const dgl_id_t *dst_ids = static_cast<const dgl_id_t *>(arr.dst->data);\n        std::vector<dgl_id_t> src_vec(src_ids, src_ids + num_edges);\n        std::vector<dgl_id_t> dst_vec(dst_ids, dst_ids + num_edges);\n        // TODO(zhengda) what if there are duplicates in the src and dst\n        // vectors.\n\n        Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false);\n        positive_subgs[i] = ConvertRef(subg);\n        // For chunked negative sampling, we accept \"chunk-head\" for corrupting\n        // head nodes and \"chunk-tail\" for corrupting tail nodes.\n        if (neg_mode_.substr(0, 5) == \"chunk\") {\n          NegSubgraph neg_subg = genChunkedNegEdgeSubgraph(\n              subg, neg_mode_.substr(6), neg_sample_size_, exclude_positive_,\n              check_false_neg_);\n          negative_subgs[i] = ConvertRef(neg_subg);\n        } else if (neg_mode_ == \"head\" || neg_mode_ == \"tail\") {\n          NegSubgraph neg_subg = genNegEdgeSubgraph(\n              subg, neg_mode_, neg_sample_size_, exclude_positive_,\n              check_false_neg_);\n          negative_subgs[i] = ConvertRef(neg_subg);\n        }\n      }\n    });\n    if (neg_mode_.size() > 0) {\n      positive_subgs.insert(\n          positive_subgs.end(), negative_subgs.begin(), negative_subgs.end());\n    }\n    batch_curr_id_ += num_workers;\n\n    if (batch_curr_id_ >= max_batch_id_ && reset_ == true) {\n      Reset();\n    }\n\n    *rv = List<SubgraphRef>(positive_subgs);\n  }\n\n  void Reset() {\n    batch_curr_id_ = 0;\n    if (replacement_ == false) {\n      // Now we should shuffle the data and reset the sampler.\n      dgl_id_t *seed_ids = static_cast<dgl_id_t *>(seed_edges_->data);\n      std::shuffle(\n          seed_ids, seed_ids + seed_edges_->shape[0],\n          std::default_random_engine());\n    }\n  }\n\n  
DGL_DECLARE_OBJECT_TYPE_INFO(UniformEdgeSamplerObject, Object);\n\n private:\n  void randomSample(size_t set_size, size_t num, std::vector<size_t> *out) {\n    RandomSample(set_size, num, out);\n  }\n\n  void randomSample(\n      size_t set_size, size_t num, const std::vector<size_t> &exclude,\n      std::vector<size_t> *out) {\n    RandomSample(set_size, num, exclude, out);\n  }\n\n  int64_t batch_curr_id_;\n  int64_t max_batch_id_;\n  int64_t num_seeds_;\n};\n\nclass UniformEdgeSampler : public ObjectRef {\n public:\n  UniformEdgeSampler() {}\n  explicit UniformEdgeSampler(std::shared_ptr<runtime::Object> obj)\n      : ObjectRef(obj) {}\n\n  UniformEdgeSamplerObject *operator->() const {\n    return static_cast<UniformEdgeSamplerObject *>(obj_.get());\n  }\n\n  std::shared_ptr<UniformEdgeSamplerObject> sptr() const {\n    return CHECK_NOTNULL(\n        std::dynamic_pointer_cast<UniformEdgeSamplerObject>(obj_));\n  }\n\n  operator bool() const { return this->defined(); }\n  using ContainerType = UniformEdgeSamplerObject;\n};\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_CreateUniformEdgeSampler\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      // arguments\n      GraphRef g = args[0];\n      IdArray seed_edges = args[1];\n      const int64_t batch_size = args[2];\n      const int64_t max_num_workers = args[3];\n      const bool replacement = args[4];\n      const bool reset = args[5];\n      const std::string neg_mode = args[6];\n      const int neg_sample_size = args[7];\n      const bool exclude_positive = args[8];\n      const bool check_false_neg = args[9];\n      IdArray relations = args[10];\n      const int64_t chunk_size = args[11];\n      // process args\n      auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());\n      CHECK(gptr) << \"sampling isn't implemented in mutable graph\";\n      CHECK(aten::IsValidIdArray(seed_edges));\n      CHECK_EQ(seed_edges->ctx.device_type, kDGLCPU)\n          << \"UniformEdgeSampler only support CPU 
sampling\";\n\n      if (relations->shape[0] > 0) {\n        CHECK(aten::IsValidIdArray(relations));\n        CHECK_EQ(relations->ctx.device_type, kDGLCPU)\n            << \"WeightedEdgeSampler only support CPU sampling\";\n      }\n      BuildCoo(*gptr);\n\n      auto o = std::make_shared<UniformEdgeSamplerObject>(\n          gptr, seed_edges, batch_size, max_num_workers, replacement, reset,\n          neg_mode, neg_sample_size, chunk_size, exclude_positive,\n          check_false_neg, relations);\n      *rv = o;\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_FetchUniformEdgeSample\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      UniformEdgeSampler sampler = args[0];\n      sampler->Fetch(rv);\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_ResetUniformEdgeSample\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      UniformEdgeSampler sampler = args[0];\n      sampler->Reset();\n    });\n\ntemplate <typename ValueType>\nclass WeightedEdgeSamplerObject : public EdgeSamplerObject {\n public:\n  explicit WeightedEdgeSamplerObject(\n      const GraphPtr gptr, IdArray seed_edges, NDArray edge_weight,\n      NDArray node_weight, const int64_t batch_size, const int64_t num_workers,\n      const bool replacement, const bool reset, const std::string neg_mode,\n      const int64_t neg_sample_size, const int64_t chunk_size,\n      const bool exclude_positive, const bool check_false_neg,\n      IdArray relations)\n      : EdgeSamplerObject(\n            gptr, seed_edges, batch_size, num_workers, replacement, reset,\n            neg_mode, neg_sample_size, chunk_size, exclude_positive,\n            check_false_neg, relations) {\n    const int64_t num_edges = edge_weight->shape[0];\n    const ValueType *edge_prob =\n        static_cast<const ValueType *>(edge_weight->data);\n    std::vector<ValueType> eprob(num_edges);\n    for (int64_t i = 0; i < num_edges; ++i) {\n      eprob[i] = edge_prob[i];\n    }\n    edge_selector_ = 
std::make_shared<ArrayHeap<ValueType>>(eprob);\n    edge_weight_ = edge_weight;\n\n    const size_t num_nodes = node_weight->shape[0];\n    if (num_nodes == 0) {\n      node_selector_ = nullptr;\n    } else {\n      const ValueType *node_prob =\n          static_cast<const ValueType *>(node_weight->data);\n      std::vector<ValueType> nprob(num_nodes);\n      for (size_t i = 0; i < num_nodes; ++i) {\n        nprob[i] = node_prob[i];\n      }\n      node_selector_ = std::make_shared<ArrayHeap<ValueType>>(nprob);\n    }\n\n    curr_batch_id_ = 0;\n    // handle int64 overflow here\n    max_batch_id_ = (num_edges + batch_size - 1) / batch_size;\n    // TODO(song): Tricky thing here to make sure gptr_ has coo cache\n    gptr_->FindEdge(0);\n  }\n\n  ~WeightedEdgeSamplerObject() {}\n\n  void Fetch(DGLRetValue *rv) {\n    const int64_t num_workers =\n        std::min(num_workers_, max_batch_id_ - curr_batch_id_);\n    // generate subgraphs.\n    std::vector<SubgraphRef> positive_subgs(num_workers);\n    std::vector<SubgraphRef> negative_subgs(num_workers);\n\n#pragma omp parallel for\n    for (int i = 0; i < num_workers; i++) {\n      const dgl_id_t *seed_edge_ids =\n          static_cast<const dgl_id_t *>(seed_edges_->data);\n      std::vector<size_t> edge_ids(batch_size_);\n\n      if (replacement_ == false) {\n        size_t n = batch_size_;\n        size_t num_ids = 0;\n#pragma omp critical\n        { num_ids = edge_selector_->SampleWithoutReplacement(n, &edge_ids); }\n        edge_ids.resize(num_ids);\n        for (size_t i = 0; i < num_ids; ++i) {\n          edge_ids[i] = seed_edge_ids[edge_ids[i]];\n        }\n      } else {\n        // sampling of each edge is a standalone event\n        for (int i = 0; i < batch_size_; ++i) {\n          size_t edge_id = edge_selector_->Sample();\n          edge_ids[i] = seed_edge_ids[edge_id];\n        }\n      }\n\n      auto worker_seeds = aten::VecToIdArray(edge_ids, seed_edges_->dtype.bits);\n\n      EdgeArray arr = 
gptr_->FindEdges(worker_seeds);\n      const dgl_id_t *src_ids = static_cast<const dgl_id_t *>(arr.src->data);\n      const dgl_id_t *dst_ids = static_cast<const dgl_id_t *>(arr.dst->data);\n      std::vector<dgl_id_t> src_vec(src_ids, src_ids + batch_size_);\n      std::vector<dgl_id_t> dst_vec(dst_ids, dst_ids + batch_size_);\n      // TODO(zhengda) what if there are duplicates in the src and dst vectors.\n      Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false);\n      positive_subgs[i] = ConvertRef(subg);\n      // For chunked negative sampling, we accept \"chunk-head\" for corrupting\n      // head nodes and \"chunk-tail\" for corrupting tail nodes.\n      if (neg_mode_.substr(0, 5) == \"chunk\") {\n        NegSubgraph neg_subg = genChunkedNegEdgeSubgraph(\n            subg, neg_mode_.substr(6), neg_sample_size_, exclude_positive_,\n            check_false_neg_);\n        negative_subgs[i] = ConvertRef(neg_subg);\n      } else if (neg_mode_ == \"head\" || neg_mode_ == \"tail\") {\n        NegSubgraph neg_subg = genNegEdgeSubgraph(\n            subg, neg_mode_, neg_sample_size_, exclude_positive_,\n            check_false_neg_);\n        negative_subgs[i] = ConvertRef(neg_subg);\n      }\n    }\n    curr_batch_id_ += num_workers;\n\n    if (curr_batch_id_ >= max_batch_id_ && reset_ == true) {\n      Reset();\n    }\n\n    if (neg_mode_.size() > 0) {\n      positive_subgs.insert(\n          positive_subgs.end(), negative_subgs.begin(), negative_subgs.end());\n    }\n    *rv = List<SubgraphRef>(positive_subgs);\n  }\n\n  void Reset() {\n    curr_batch_id_ = 0;\n    if (replacement_ == false) {\n      const int64_t num_edges = edge_weight_->shape[0];\n      const ValueType *edge_prob =\n          static_cast<const ValueType *>(edge_weight_->data);\n      std::vector<ValueType> eprob(num_edges);\n      for (int64_t i = 0; i < num_edges; ++i) {\n        eprob[i] = edge_prob[i];\n      }\n\n      // rebuild the edge_selector_\n      edge_selector_ = 
std::make_shared<ArrayHeap<ValueType>>(eprob);\n    }\n  }\n\n  DGL_DECLARE_OBJECT_TYPE_INFO(WeightedEdgeSamplerObject<ValueType>, Object);\n\n private:\n  void randomSample(size_t set_size, size_t num, std::vector<size_t> *out) {\n    if (num < set_size) {\n      std::unordered_set<size_t> sampled_idxs;\n      while (sampled_idxs.size() < num) {\n        if (node_selector_ == nullptr) {\n          sampled_idxs.insert(RandomEngine::ThreadLocal()->RandInt(set_size));\n        } else {\n          size_t id = node_selector_->Sample();\n          sampled_idxs.insert(id);\n        }\n      }\n\n      out->insert(out->end(), sampled_idxs.begin(), sampled_idxs.end());\n    } else {\n      // If we need to sample all elements in the set, we don't need to\n      // generate random numbers.\n      for (size_t i = 0; i < set_size; i++) out->push_back(i);\n    }\n  }\n\n  void randomSample(\n      size_t set_size, size_t num, const std::vector<size_t> &exclude,\n      std::vector<size_t> *out) {\n    std::unordered_map<size_t, int> sampled_idxs;\n    for (auto v : exclude) {\n      sampled_idxs.insert(std::pair<size_t, int>(v, 0));\n    }\n    if (num + exclude.size() < set_size) {\n      while (sampled_idxs.size() < num + exclude.size()) {\n        size_t rand;\n        if (node_selector_ == nullptr) {\n          rand = RandomEngine::ThreadLocal()->RandInt(set_size);\n        } else {\n          rand = node_selector_->Sample();\n        }\n        sampled_idxs.insert(std::pair<size_t, int>(rand, 1));\n      }\n      for (auto it = sampled_idxs.begin(); it != sampled_idxs.end(); it++) {\n        if (it->second) {\n          out->push_back(it->first);\n        }\n      }\n    } else {\n      // If we need to sample all elements in the set, we don't need to\n      // generate random numbers.\n      for (size_t i = 0; i < set_size; i++) {\n        // If the element doesn't exist in exclude.\n        if (sampled_idxs.find(i) == sampled_idxs.end()) {\n          out->push_back(i);\n 
       }\n      }\n    }\n  }\n\n private:\n  std::shared_ptr<ArrayHeap<ValueType>> edge_selector_;\n  std::shared_ptr<ArrayHeap<ValueType>> node_selector_;\n\n  NDArray edge_weight_;\n  int64_t curr_batch_id_;\n  int64_t max_batch_id_;\n};\n\ntemplate class WeightedEdgeSamplerObject<float>;\n\nclass FloatWeightedEdgeSampler : public ObjectRef {\n public:\n  FloatWeightedEdgeSampler() {}\n  explicit FloatWeightedEdgeSampler(std::shared_ptr<runtime::Object> obj)\n      : ObjectRef(obj) {}\n\n  WeightedEdgeSamplerObject<float> *operator->() const {\n    return static_cast<WeightedEdgeSamplerObject<float> *>(obj_.get());\n  }\n\n  std::shared_ptr<WeightedEdgeSamplerObject<float>> sptr() const {\n    return CHECK_NOTNULL(\n        std::dynamic_pointer_cast<WeightedEdgeSamplerObject<float>>(obj_));\n  }\n\n  operator bool() const { return this->defined(); }\n  using ContainerType = WeightedEdgeSamplerObject<float>;\n};\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_CreateWeightedEdgeSampler\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      // arguments\n      GraphRef g = args[0];\n      IdArray seed_edges = args[1];\n      NDArray edge_weight = args[2];\n      NDArray node_weight = args[3];\n      const int64_t batch_size = args[4];\n      const int64_t max_num_workers = args[5];\n      const bool replacement = args[6];\n      const bool reset = args[7];\n      const std::string neg_mode = args[8];\n      const int64_t neg_sample_size = args[9];\n      const bool exclude_positive = args[10];\n      const bool check_false_neg = args[11];\n      IdArray relations = args[12];\n      const int64_t chunk_size = args[13];\n\n      auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());\n      CHECK(gptr) << \"sampling isn't implemented in mutable graph\";\n      CHECK(aten::IsValidIdArray(seed_edges));\n      CHECK_EQ(seed_edges->ctx.device_type, kDGLCPU)\n          << \"WeightedEdgeSampler only support CPU sampling\";\n      CHECK(edge_weight->dtype.code == 
kDGLFloat)\n          << \"edge_weight should be FloatType\";\n      CHECK(edge_weight->dtype.bits == 32)\n          << \"WeightedEdgeSampler only support float weight\";\n      CHECK_EQ(edge_weight->ctx.device_type, kDGLCPU)\n          << \"WeightedEdgeSampler only support CPU sampling\";\n      if (node_weight->shape[0] > 0) {\n        CHECK(node_weight->dtype.code == kDGLFloat)\n            << \"node_weight should be FloatType\";\n        CHECK(node_weight->dtype.bits == 32)\n            << \"WeightedEdgeSampler only support float weight\";\n        CHECK_EQ(node_weight->ctx.device_type, kDGLCPU)\n            << \"WeightedEdgeSampler only support CPU sampling\";\n      }\n      if (relations->shape[0] > 0) {\n        CHECK(aten::IsValidIdArray(relations));\n        CHECK_EQ(relations->ctx.device_type, kDGLCPU)\n            << \"WeightedEdgeSampler only support CPU sampling\";\n      }\n      BuildCoo(*gptr);\n\n      const int64_t num_seeds = seed_edges->shape[0];\n      const int64_t num_workers =\n          std::min(max_num_workers, (num_seeds + batch_size - 1) / batch_size);\n\n      auto o = std::make_shared<WeightedEdgeSamplerObject<float>>(\n          gptr, seed_edges, edge_weight, node_weight, batch_size, num_workers,\n          replacement, reset, neg_mode, neg_sample_size, chunk_size,\n          exclude_positive, check_false_neg, relations);\n      *rv = o;\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_FetchWeightedEdgeSample\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      FloatWeightedEdgeSampler sampler = args[0];\n      sampler->Fetch(rv);\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling._CAPI_ResetWeightedEdgeSample\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      FloatWeightedEdgeSampler sampler = args[0];\n      sampler->Reset();\n    });\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/negative/global_uniform.cc",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file graph/sampling/negative/global_uniform.cc\n * @brief Global uniform negative sampling.\n */\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/sampling/negative.h>\n\n#include <utility>\n\n#include \"../../../c_api_common.h\"\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace dgl {\nnamespace sampling {\n\n/**\n * @brief Run global uniform negative sampling on one edge type of a graph.\n *\n * Selects CSC or CSR for the edge type, passes the matrix through CSRSort_,\n * and forwards every sampling argument unchanged to\n * CSRGlobalUniformNegativeSampling. When CSC is selected the two returned\n * arrays are swapped before they are handed back to the caller.\n *\n * @return The pair of ID arrays produced by the CSR kernel. If the selected\n * format is neither CSC nor CSR, LOG(FATAL) is raised.\n */\nstd::pair<IdArray, IdArray> GlobalUniformNegativeSampling(\n    HeteroGraphPtr hg, dgl_type_t etype, int64_t num_samples, int num_trials,\n    bool exclude_self_loops, bool replace, double redundancy) {\n  auto format = hg->SelectFormat(etype, CSC_CODE | CSR_CODE);\n  if (format == SparseFormat::kCSC) {\n    CSRMatrix csc = hg->GetCSCMatrix(etype);\n    CSRSort_(&csc);\n    std::pair<IdArray, IdArray> result = CSRGlobalUniformNegativeSampling(\n        csc, num_samples, num_trials, exclude_self_loops, replace, redundancy);\n    // reverse the pair since it is CSC\n    return {result.second, result.first};\n  } else if (format == SparseFormat::kCSR) {\n    CSRMatrix csr = hg->GetCSRMatrix(etype);\n    CSRSort_(&csr);\n    return CSRGlobalUniformNegativeSampling(\n        csr, num_samples, num_trials, exclude_self_loops, replace, redundancy);\n  } else {\n    LOG(FATAL)\n        << \"COO format is not supported in global uniform negative sampling\";\n    return {IdArray(), IdArray()};\n  }\n}\n\nDGL_REGISTER_GLOBAL(\"sampling.negative._CAPI_DGLGlobalUniformNegativeSampling\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      dgl_type_t etype = args[1];\n      // Edge type IDs are zero-based, so NumEdgeTypes() itself is not a\n      // valid ID.\n      CHECK_LT(etype, hg->NumEdgeTypes()) << \"invalid edge type \" << etype;\n      int64_t num_samples = args[2];\n      int num_trials = args[3];\n      bool exclude_self_loops = args[4];\n      bool replace = args[5];\n      double redundancy = args[6];\n      List<Value> result;\n      std::pair<IdArray, IdArray> ret = GlobalUniformNegativeSampling(\n          hg.sptr(), etype, num_samples, num_trials, exclude_self_loops,\n          replace, redundancy);\n      result.push_back(Value(MakeValue(ret.first)));\n      result.push_back(Value(MakeValue(ret.second)));\n      *rv = result;\n    });\n\n}  // namespace sampling\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/neighbor/neighbor.cc",
    "content": "/**\n *  Copyright (c) 2020-2022 by Contributors\n * @file graph/sampling/neighbor.cc\n * @brief Definition of neighborhood-based sampler APIs.\n */\n\n#include <dgl/array.h>\n#include <dgl/aten/macro.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/parallel_for.h>\n#include <dgl/sampling/neighbor.h>\n\n#include <tuple>\n#include <utility>\n\n#include \"../../../array/cpu/concurrent_id_hash_map.h\"\n#include \"../../../c_api_common.h\"\n#include \"../../unit_graph.h\"\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace dgl {\nnamespace sampling {\n\ntemplate <typename IdType>\nvoid ExcludeCertainEdgesFused(\n    std::vector<CSRMatrix>* sampled_graphs, std::vector<IdArray>* induced_edges,\n    std::vector<IdArray>* sampled_coo_rows,\n    const std::vector<IdArray>& exclude_edges,\n    std::vector<FloatArray>* weights = nullptr) {\n  int etypes = (*sampled_graphs).size();\n  std::vector<IdArray> remain_induced_edges(etypes);\n  std::vector<IdArray> remain_indptrs(etypes);\n  std::vector<IdArray> remain_indices(etypes);\n  std::vector<IdArray> remain_coo_rows(etypes);\n  std::vector<FloatArray> remain_weights(etypes);\n  for (int etype = 0; etype < etypes; ++etype) {\n    if (exclude_edges[etype].GetSize() == 0 ||\n        (*sampled_graphs)[etype].num_rows == 0) {\n      remain_induced_edges[etype] = (*induced_edges)[etype];\n      if (weights) remain_weights[etype] = (*weights)[etype];\n      continue;\n    }\n    const auto dtype = weights && (*weights)[etype]->shape[0]\n                           ? 
(*weights)[etype]->dtype\n                           : DGLDataType{kDGLFloat, 8 * sizeof(float), 1};\n    ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, \"weights\", {\n      IdType* indptr = (*sampled_graphs)[etype].indptr.Ptr<IdType>();\n      IdType* indices = (*sampled_graphs)[etype].indices.Ptr<IdType>();\n      IdType* coo_rows = (*sampled_coo_rows)[etype].Ptr<IdType>();\n      IdType* induced_edges_data = (*induced_edges)[etype].Ptr<IdType>();\n      FloatType* weights_data = weights && (*weights)[etype]->shape[0]\n                                    ? (*weights)[etype].Ptr<FloatType>()\n                                    : nullptr;\n      const IdType exclude_edges_len = exclude_edges[etype]->shape[0];\n      std::sort(\n          exclude_edges[etype].Ptr<IdType>(),\n          exclude_edges[etype].Ptr<IdType>() + exclude_edges_len);\n      const IdType* exclude_edges_data = exclude_edges[etype].Ptr<IdType>();\n      IdType outIndices = 0;\n      for (IdType row = 0; row < (*sampled_graphs)[etype].indptr->shape[0] - 1;\n           ++row) {\n        auto tmp_row = indptr[row];\n        if (outIndices != indptr[row]) indptr[row] = outIndices;\n        for (IdType col = tmp_row; col < indptr[row + 1]; ++col) {\n          if (!std::binary_search(\n                  exclude_edges_data, exclude_edges_data + exclude_edges_len,\n                  induced_edges_data[col])) {\n            indices[outIndices] = indices[col];\n            induced_edges_data[outIndices] = induced_edges_data[col];\n            coo_rows[outIndices] = coo_rows[col];\n            if (weights_data) weights_data[outIndices] = weights_data[col];\n            ++outIndices;\n          }\n        }\n      }\n      indptr[(*sampled_graphs)[etype].indptr->shape[0] - 1] = outIndices;\n      remain_induced_edges[etype] =\n          aten::IndexSelect((*induced_edges)[etype], 0, outIndices);\n      remain_weights[etype] =\n          weights_data ? 
aten::IndexSelect((*weights)[etype], 0, outIndices)\n                       : NullArray();\n      remain_indices[etype] =\n          aten::IndexSelect((*sampled_graphs)[etype].indices, 0, outIndices);\n      (*sampled_coo_rows)[etype] =\n          aten::IndexSelect((*sampled_coo_rows)[etype], 0, outIndices);\n      (*sampled_graphs)[etype] = CSRMatrix(\n          (*sampled_graphs)[etype].num_rows, outIndices,\n          (*sampled_graphs)[etype].indptr, remain_indices[etype],\n          remain_induced_edges[etype]);\n    });\n  }\n}\n\nstd::pair<HeteroSubgraph, std::vector<FloatArray>> ExcludeCertainEdges(\n    const HeteroSubgraph& sg, const std::vector<IdArray>& exclude_edges,\n    const std::vector<FloatArray>* weights = nullptr) {\n  HeteroGraphPtr hg_view = HeteroGraphRef(sg.graph).sptr();\n  std::vector<IdArray> remain_induced_edges(hg_view->NumEdgeTypes());\n  std::vector<IdArray> remain_edges(hg_view->NumEdgeTypes());\n  std::vector<FloatArray> remain_weights(hg_view->NumEdgeTypes());\n\n  for (dgl_type_t etype = 0; etype < hg_view->NumEdgeTypes(); ++etype) {\n    IdArray edge_ids = Range(\n        0, sg.induced_edges[etype]->shape[0],\n        sg.induced_edges[etype]->dtype.bits, sg.induced_edges[etype]->ctx);\n    if (exclude_edges[etype].GetSize() == 0 || edge_ids.GetSize() == 0) {\n      remain_edges[etype] = edge_ids;\n      remain_induced_edges[etype] = sg.induced_edges[etype];\n      if (weights) remain_weights[etype] = (*weights)[etype];\n      continue;\n    }\n    ATEN_ID_TYPE_SWITCH(hg_view->DataType(), IdType, {\n      const auto dtype = weights && (*weights)[etype]->shape[0]\n                             ? 
(*weights)[etype]->dtype\n                             : DGLDataType{kDGLFloat, 8 * sizeof(float), 1};\n      ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, \"weights\", {\n        IdType* idx_data = edge_ids.Ptr<IdType>();\n        IdType* induced_edges_data = sg.induced_edges[etype].Ptr<IdType>();\n        FloatType* weights_data = weights && (*weights)[etype]->shape[0]\n                                      ? (*weights)[etype].Ptr<FloatType>()\n                                      : nullptr;\n        const IdType exclude_edges_len = exclude_edges[etype]->shape[0];\n        std::sort(\n            exclude_edges[etype].Ptr<IdType>(),\n            exclude_edges[etype].Ptr<IdType>() + exclude_edges_len);\n        const IdType* exclude_edges_data = exclude_edges[etype].Ptr<IdType>();\n        IdType outId = 0;\n        for (IdType i = 0; i != sg.induced_edges[etype]->shape[0]; ++i) {\n          // the following binary search is the bottleneck, excluding weights\n          // together with edges should almost be free.\n          if (!std::binary_search(\n                  exclude_edges_data, exclude_edges_data + exclude_edges_len,\n                  induced_edges_data[i])) {\n            induced_edges_data[outId] = induced_edges_data[i];\n            idx_data[outId] = idx_data[i];\n            if (weights_data) weights_data[outId] = weights_data[i];\n            ++outId;\n          }\n        }\n        remain_edges[etype] = aten::IndexSelect(edge_ids, 0, outId);\n        remain_induced_edges[etype] =\n            aten::IndexSelect(sg.induced_edges[etype], 0, outId);\n        remain_weights[etype] =\n            weights_data ? 
aten::IndexSelect((*weights)[etype], 0, outId)\n                         : NullArray();\n      });\n    });\n  }\n  HeteroSubgraph subg = hg_view->EdgeSubgraph(remain_edges, true);\n  subg.induced_edges = std::move(remain_induced_edges);\n  return std::make_pair(subg, remain_weights);\n}\n\nstd::pair<HeteroSubgraph, std::vector<FloatArray>> SampleLabors(\n    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,\n    const std::vector<int64_t>& fanouts, EdgeDir dir,\n    const std::vector<FloatArray>& prob,\n    const std::vector<IdArray>& exclude_edges, const int importance_sampling,\n    const IdArray random_seed, const float seed2_contribution,\n    const std::vector<IdArray>& NIDs) {\n  // sanity check\n  CHECK_EQ(nodes.size(), hg->NumVertexTypes())\n      << \"Number of node ID tensors must match the number of node types.\";\n  CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())\n      << \"Number of fanout values must match the number of edge types.\";\n\n  DGLContext ctx = aten::GetContextOf(nodes);\n\n  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());\n  std::vector<FloatArray> subimportances(hg->NumEdgeTypes());\n  std::vector<IdArray> induced_edges(hg->NumEdgeTypes());\n  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {\n    auto pair = hg->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    const IdArray nodes_ntype =\n        nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];\n    const IdArray NIDs_ntype =\n        NIDs[(dir == EdgeDir::kIn) ? 
src_vtype : dst_vtype];\n    const int64_t num_nodes = nodes_ntype->shape[0];\n    if (num_nodes == 0 || fanouts[etype] == 0) {\n      // Nothing to sample for this etype, create a placeholder relation graph\n      subrels[etype] = UnitGraph::Empty(\n          hg->GetRelationGraph(etype)->NumVertexTypes(),\n          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),\n          hg->DataType(), ctx);\n      induced_edges[etype] = aten::NullArray(hg->DataType(), ctx);\n      subimportances[etype] = NullArray();\n    } else {\n      // sample from one relation graph\n      auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;\n      auto avail_fmt = hg->SelectFormat(etype, req_fmt);\n      COOMatrix sampled_coo;\n      FloatArray importances;\n      const int64_t fanout =\n          fanouts[etype] >= 0\n              ? fanouts[etype]\n              : std::max(\n                    hg->NumVertices(dst_vtype), hg->NumVertices(src_vtype));\n      switch (avail_fmt) {\n        case SparseFormat::kCOO:\n          if (dir == EdgeDir::kIn) {\n            auto fs = aten::COOLaborSampling(\n                aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,\n                fanout, prob[etype], importance_sampling, random_seed,\n                seed2_contribution, NIDs_ntype);\n            sampled_coo = aten::COOTranspose(fs.first);\n            importances = fs.second;\n          } else {\n            std::tie(sampled_coo, importances) = aten::COOLaborSampling(\n                hg->GetCOOMatrix(etype), nodes_ntype, fanout, prob[etype],\n                importance_sampling, random_seed, seed2_contribution,\n                NIDs_ntype);\n          }\n          break;\n        case SparseFormat::kCSR:\n          CHECK(dir == EdgeDir::kOut)\n              << \"Cannot sample out edges on CSC matrix.\";\n          std::tie(sampled_coo, importances) = aten::CSRLaborSampling(\n              hg->GetCSRMatrix(etype), nodes_ntype, fanout, prob[etype],\n              
importance_sampling, random_seed, seed2_contribution, NIDs_ntype);\n          break;\n        case SparseFormat::kCSC:\n          CHECK(dir == EdgeDir::kIn) << \"Cannot sample in edges on CSR matrix.\";\n          std::tie(sampled_coo, importances) = aten::CSRLaborSampling(\n              hg->GetCSCMatrix(etype), nodes_ntype, fanout, prob[etype],\n              importance_sampling, random_seed, seed2_contribution, NIDs_ntype);\n          sampled_coo = aten::COOTranspose(sampled_coo);\n          break;\n        default:\n          LOG(FATAL) << \"Unsupported sparse format.\";\n      }\n      subrels[etype] = UnitGraph::CreateFromCOO(\n          hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,\n          sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);\n      subimportances[etype] = importances;\n      induced_edges[etype] = sampled_coo.data;\n    }\n  }\n\n  HeteroSubgraph ret;\n  ret.graph =\n      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());\n  ret.induced_vertices.resize(hg->NumVertexTypes());\n  ret.induced_edges = std::move(induced_edges);\n\n  if (!exclude_edges.empty())\n    return ExcludeCertainEdges(ret, exclude_edges, &subimportances);\n\n  return std::make_pair(ret, std::move(subimportances));\n}\n\nHeteroSubgraph SampleNeighbors(\n    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,\n    const std::vector<int64_t>& fanouts, EdgeDir dir,\n    const std::vector<NDArray>& prob_or_mask,\n    const std::vector<IdArray>& exclude_edges, bool replace) {\n  // sanity check\n  CHECK_EQ(nodes.size(), hg->NumVertexTypes())\n      << \"Number of node ID tensors must match the number of node types.\";\n  CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())\n      << \"Number of fanout values must match the number of edge types.\";\n  CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes())\n      << \"Number of probability tensors must match the number of edge types.\";\n\n  DGLContext ctx = 
aten::GetContextOf(nodes);\n\n  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());\n  std::vector<IdArray> induced_edges(hg->NumEdgeTypes());\n  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {\n    auto pair = hg->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    const IdArray nodes_ntype =\n        nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];\n    const int64_t num_nodes = nodes_ntype->shape[0];\n\n    if (num_nodes == 0 || fanouts[etype] == 0) {\n      // Nothing to sample for this etype, create a placeholder relation graph\n      subrels[etype] = UnitGraph::Empty(\n          hg->GetRelationGraph(etype)->NumVertexTypes(),\n          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),\n          hg->DataType(), ctx);\n      induced_edges[etype] = aten::NullArray(hg->DataType(), ctx);\n    } else {\n      COOMatrix sampled_coo;\n      // sample from one relation graph\n      auto req_fmt = (dir == EdgeDir::kOut) ? 
CSR_CODE : CSC_CODE;\n      auto avail_fmt = hg->SelectFormat(etype, req_fmt);\n      switch (avail_fmt) {\n        case SparseFormat::kCOO:\n          if (dir == EdgeDir::kIn) {\n            sampled_coo = aten::COOTranspose(aten::COORowWiseSampling(\n                aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,\n                fanouts[etype], prob_or_mask[etype], replace));\n          } else {\n            sampled_coo = aten::COORowWiseSampling(\n                hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype],\n                prob_or_mask[etype], replace);\n          }\n          break;\n        case SparseFormat::kCSR:\n          CHECK(dir == EdgeDir::kOut)\n              << \"Cannot sample out edges on CSC matrix.\";\n          sampled_coo = aten::CSRRowWiseSampling(\n              hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype],\n              prob_or_mask[etype], replace);\n          break;\n        case SparseFormat::kCSC:\n          CHECK(dir == EdgeDir::kIn) << \"Cannot sample in edges on CSR matrix.\";\n          sampled_coo = aten::CSRRowWiseSampling(\n              hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype],\n              prob_or_mask[etype], replace);\n          sampled_coo = aten::COOTranspose(sampled_coo);\n          break;\n        default:\n          LOG(FATAL) << \"Unsupported sparse format.\";\n      }\n\n      subrels[etype] = UnitGraph::CreateFromCOO(\n          hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,\n          sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);\n      induced_edges[etype] = sampled_coo.data;\n    }\n  }\n\n  HeteroSubgraph ret;\n  ret.graph =\n      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());\n  ret.induced_vertices.resize(hg->NumVertexTypes());\n  ret.induced_edges = std::move(induced_edges);\n  if (!exclude_edges.empty()) {\n    return ExcludeCertainEdges(ret, exclude_edges).first;\n  }\n  return ret;\n}\n\ntemplate <typename 
IdType>\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>\nSampleNeighborsFused(\n    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,\n    const std::vector<IdArray>& mapping, const std::vector<int64_t>& fanouts,\n    EdgeDir dir, const std::vector<NDArray>& prob_or_mask,\n    const std::vector<IdArray>& exclude_edges, bool replace) {\n  CHECK_EQ(nodes.size(), hg->NumVertexTypes())\n      << \"Number of node ID tensors must match the number of node types.\";\n  CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())\n      << \"Number of fanout values must match the number of edge types.\";\n  CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes())\n      << \"Number of probability tensors must match the number of edge types.\";\n\n  DGLContext ctx = aten::GetContextOf(nodes);\n\n  std::vector<CSRMatrix> sampled_graphs;\n  std::vector<IdArray> sampled_coo_rows;\n  std::vector<IdArray> induced_edges;\n  std::vector<IdArray> induced_vertices;\n  std::vector<int64_t> num_nodes_per_type;\n  std::vector<std::vector<IdType>> new_nodes_vec(hg->NumVertexTypes());\n  std::vector<int> seed_nodes_mapped(hg->NumVertexTypes(), 0);\n\n  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {\n    auto pair = hg->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    const dgl_type_t rhs_node_type =\n        (dir == EdgeDir::kOut) ? 
src_vtype : dst_vtype;\n    const IdArray nodes_ntype = nodes[rhs_node_type];\n    const int64_t num_nodes = nodes_ntype->shape[0];\n\n    if (num_nodes == 0 || fanouts[etype] == 0) {\n      // Nothing to sample for this etype, create a placeholder\n      sampled_graphs.push_back(CSRMatrix());\n      sampled_coo_rows.push_back(IdArray());\n      induced_edges.push_back(aten::NullArray(hg->DataType(), ctx));\n    } else {\n      bool map_seed_nodes = !seed_nodes_mapped[rhs_node_type];\n      // sample from one relation graph\n      std::pair<CSRMatrix, IdArray> sampled_graph;\n      auto sampling_fn = map_seed_nodes\n                             ? aten::CSRRowWiseSamplingFused<IdType, true>\n                             : aten::CSRRowWiseSamplingFused<IdType, false>;\n      auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;\n      auto avail_fmt = hg->SelectFormat(etype, req_fmt);\n      switch (avail_fmt) {\n        case SparseFormat::kCSR:\n          CHECK(dir == EdgeDir::kOut)\n              << \"Cannot sample out edges on CSC matrix.\";\n          // In heterographs nodes of two diffrent types can be connected\n          // therefore two diffrent mappings and node vectors are needed\n          sampled_graph = sampling_fn(\n              hg->GetCSRMatrix(etype), nodes_ntype, mapping[src_vtype],\n              &new_nodes_vec[src_vtype], fanouts[etype], prob_or_mask[etype],\n              replace);\n          break;\n        case SparseFormat::kCSC:\n          CHECK(dir == EdgeDir::kIn) << \"Cannot sample in edges on CSR matrix.\";\n          sampled_graph = sampling_fn(\n              hg->GetCSCMatrix(etype), nodes_ntype, mapping[dst_vtype],\n              &new_nodes_vec[dst_vtype], fanouts[etype], prob_or_mask[etype],\n              replace);\n          break;\n        default:\n          LOG(FATAL) << \"Unsupported sparse format.\";\n      }\n      seed_nodes_mapped[rhs_node_type]++;\n      sampled_graphs.push_back(sampled_graph.first);\n      if 
(sampled_graph.first.data.defined())\n        induced_edges.push_back(sampled_graph.first.data);\n      else\n        induced_edges.push_back(\n            aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));\n      sampled_coo_rows.push_back(sampled_graph.second);\n    }\n  }\n\n  if (!exclude_edges.empty()) {\n    ExcludeCertainEdgesFused<IdType>(\n        &sampled_graphs, &induced_edges, &sampled_coo_rows, exclude_edges);\n    for (size_t i = 0; i < hg->NumEdgeTypes(); i++) {\n      if (sampled_graphs[i].data.defined())\n        induced_edges[i] = std::move(sampled_graphs[i].data);\n      else\n        induced_edges[i] =\n            aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx);\n    }\n  }\n\n  // map indices\n  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {\n    auto pair = hg->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    const dgl_type_t lhs_node_type =\n        (dir == EdgeDir::kIn) ? 
src_vtype : dst_vtype;\n    if (sampled_graphs[etype].num_cols != 0) {\n      auto num_cols = sampled_graphs[etype].num_cols;\n      int num_threads_col = runtime::compute_num_threads(0, num_cols, 1);\n      std::vector<IdType> global_prefix_col(num_threads_col + 1, 0);\n      std::vector<std::vector<IdType>> src_nodes_local(num_threads_col);\n      IdType* mapping_data_dst = mapping[lhs_node_type].Ptr<IdType>();\n      IdType* cdata = sampled_graphs[etype].indices.Ptr<IdType>();\n#pragma omp parallel num_threads(num_threads_col)\n      {\n        const int thread_id = omp_get_thread_num();\n        num_threads_col = omp_get_num_threads();\n\n        const int64_t start_i =\n            thread_id * (num_cols / num_threads_col) +\n            std::min(\n                static_cast<int64_t>(thread_id), num_cols % num_threads_col);\n        const int64_t end_i = (thread_id + 1) * (num_cols / num_threads_col) +\n                              std::min(\n                                  static_cast<int64_t>(thread_id + 1),\n                                  num_cols % num_threads_col);\n        assert(thread_id + 1 < num_threads_col || end_i == num_cols);\n        for (int64_t i = start_i; i < end_i; ++i) {\n          int64_t picked_idx = cdata[i];\n          bool spot_claimed =\n              BoolCompareAndSwap<IdType>(&mapping_data_dst[picked_idx]);\n          if (spot_claimed) src_nodes_local[thread_id].push_back(picked_idx);\n        }\n        global_prefix_col[thread_id + 1] = src_nodes_local[thread_id].size();\n\n#pragma omp barrier\n#pragma omp master\n        {\n          global_prefix_col[0] = new_nodes_vec[lhs_node_type].size();\n          for (int t = 0; t < num_threads_col; ++t) {\n            global_prefix_col[t + 1] += global_prefix_col[t];\n          }\n        }\n\n#pragma omp barrier\n        int64_t mapping_shift = global_prefix_col[thread_id];\n        for (size_t i = 0; i < src_nodes_local[thread_id].size(); ++i)\n          
mapping_data_dst[src_nodes_local[thread_id][i]] = mapping_shift + i;\n\n#pragma omp barrier\n        for (int64_t i = start_i; i < end_i; ++i) {\n          IdType picked_idx = cdata[i];\n          IdType mapped_idx = mapping_data_dst[picked_idx];\n          cdata[i] = mapped_idx;\n        }\n      }\n      IdType offset = new_nodes_vec[lhs_node_type].size();\n      new_nodes_vec[lhs_node_type].resize(global_prefix_col.back());\n      for (int thread_id = 0; thread_id < num_threads_col; ++thread_id) {\n        memcpy(\n            new_nodes_vec[lhs_node_type].data() + offset,\n            &src_nodes_local[thread_id][0],\n            src_nodes_local[thread_id].size() * sizeof(IdType));\n        offset += src_nodes_local[thread_id].size();\n      }\n    }\n  }\n\n  // counting how many nodes of each ntype were sampled\n  num_nodes_per_type.resize(2 * hg->NumVertexTypes());\n  for (size_t i = 0; i < hg->NumVertexTypes(); i++) {\n    num_nodes_per_type[i] = new_nodes_vec[i].size();\n    num_nodes_per_type[hg->NumVertexTypes() + i] = nodes[i]->shape[0];\n    induced_vertices.push_back(\n        VecToIdArray(new_nodes_vec[i], sizeof(IdType) * 8));\n  }\n\n  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());\n  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {\n    auto pair = hg->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    if (sampled_graphs[etype].num_rows == 0) {\n      subrels[etype] = UnitGraph::Empty(\n          2, new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0],\n          hg->DataType(), ctx);\n    } else {\n      CSRMatrix graph = sampled_graphs[etype];\n      if (dir == EdgeDir::kOut) {\n        subrels[etype] = UnitGraph::CreateFromCSRAndCOO(\n            2,\n            CSRMatrix(\n                nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(),\n                graph.indptr, graph.indices,\n                Range(\n                  
  0, graph.indices->shape[0], graph.indices->dtype.bits,\n                    ctx)),\n            COOMatrix(\n                nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(),\n                sampled_coo_rows[etype], graph.indices),\n            ALL_CODE);\n      } else {\n        subrels[etype] = UnitGraph::CreateFromCSCAndCOO(\n            2,\n            CSRMatrix(\n                nodes[dst_vtype]->shape[0], new_nodes_vec[src_vtype].size(),\n                graph.indptr, graph.indices,\n                Range(\n                    0, graph.indices->shape[0], graph.indices->dtype.bits,\n                    ctx)),\n            COOMatrix(\n                new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0],\n                graph.indices, sampled_coo_rows[etype]),\n            ALL_CODE);\n      }\n    }\n  }\n\n  HeteroSubgraph ret;\n\n  const auto meta_graph = hg->meta_graph();\n  const EdgeArray etypes = meta_graph->Edges(\"eid\");\n  const IdArray new_dst = Add(etypes.dst, hg->NumVertexTypes());\n\n  const auto new_meta_graph = ImmutableGraph::CreateFromCOO(\n      hg->NumVertexTypes() * 2, etypes.src, new_dst);\n\n  HeteroGraphPtr new_graph =\n      CreateHeteroGraph(new_meta_graph, subrels, num_nodes_per_type);\n  return std::make_tuple(new_graph, induced_edges, induced_vertices);\n}\n\ntemplate std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>\nSampleNeighborsFused<int64_t>(\n    const HeteroGraphPtr, const std::vector<IdArray>&,\n    const std::vector<IdArray>&, const std::vector<int64_t>&, EdgeDir,\n    const std::vector<NDArray>&, const std::vector<IdArray>&, bool);\n\ntemplate std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>\nSampleNeighborsFused<int32_t>(\n    const HeteroGraphPtr, const std::vector<IdArray>&,\n    const std::vector<IdArray>&, const std::vector<int64_t>&, EdgeDir,\n    const std::vector<NDArray>&, const std::vector<IdArray>&, bool);\n\nHeteroSubgraph SampleNeighborsEType(\n   
 const HeteroGraphPtr hg, const IdArray nodes,\n    const std::vector<int64_t>& eid2etype_offset,\n    const std::vector<int64_t>& fanouts, EdgeDir dir,\n    const std::vector<FloatArray>& prob, bool replace,\n    bool rowwise_etype_sorted) {\n  CHECK_EQ(1, hg->NumVertexTypes())\n      << \"SampleNeighborsEType only work with homogeneous graph\";\n  CHECK_EQ(1, hg->NumEdgeTypes())\n      << \"SampleNeighborsEType only work with homogeneous graph\";\n\n  std::vector<HeteroGraphPtr> subrels(1);\n  std::vector<IdArray> induced_edges(1);\n  const int64_t num_nodes = nodes->shape[0];\n  dgl_type_t etype = 0;\n  const dgl_type_t src_vtype = 0;\n  const dgl_type_t dst_vtype = 0;\n\n  bool same_fanout = true;\n  int64_t fanout_value = fanouts[0];\n  for (auto fanout : fanouts) {\n    if (fanout != fanout_value) {\n      same_fanout = false;\n      break;\n    }\n  }\n\n  if (num_nodes == 0 || (same_fanout && fanout_value == 0)) {\n    subrels[etype] = UnitGraph::Empty(\n        1, hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),\n        hg->DataType(), hg->Context());\n    induced_edges[etype] = aten::NullArray();\n  } else {\n    COOMatrix sampled_coo;\n    // sample from graph\n    // the edge type is stored in etypes\n    auto req_fmt = (dir == EdgeDir::kOut) ? 
CSR_CODE : CSC_CODE;\n    auto avail_fmt = hg->SelectFormat(etype, req_fmt);\n    switch (avail_fmt) {\n      case SparseFormat::kCOO:\n        if (dir == EdgeDir::kIn) {\n          sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling(\n              aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes,\n              eid2etype_offset, fanouts, prob, replace));\n        } else {\n          sampled_coo = aten::COORowWisePerEtypeSampling(\n              hg->GetCOOMatrix(etype), nodes, eid2etype_offset, fanouts, prob,\n              replace);\n        }\n        break;\n      case SparseFormat::kCSR:\n        CHECK(dir == EdgeDir::kOut) << \"Cannot sample out edges on CSC matrix.\";\n        sampled_coo = aten::CSRRowWisePerEtypeSampling(\n            hg->GetCSRMatrix(etype), nodes, eid2etype_offset, fanouts, prob,\n            replace, rowwise_etype_sorted);\n        break;\n      case SparseFormat::kCSC:\n        CHECK(dir == EdgeDir::kIn) << \"Cannot sample in edges on CSR matrix.\";\n        sampled_coo = aten::CSRRowWisePerEtypeSampling(\n            hg->GetCSCMatrix(etype), nodes, eid2etype_offset, fanouts, prob,\n            replace, rowwise_etype_sorted);\n        sampled_coo = aten::COOTranspose(sampled_coo);\n        break;\n      default:\n        LOG(FATAL) << \"Unsupported sparse format.\";\n    }\n\n    subrels[etype] = UnitGraph::CreateFromCOO(\n        1, sampled_coo.num_rows, sampled_coo.num_cols, sampled_coo.row,\n        sampled_coo.col);\n    induced_edges[etype] = sampled_coo.data;\n  }\n\n  HeteroSubgraph ret;\n  ret.graph =\n      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());\n  ret.induced_vertices.resize(hg->NumVertexTypes());\n  ret.induced_edges = std::move(induced_edges);\n  return ret;\n}\n\nHeteroSubgraph SampleNeighborsTopk(\n    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,\n    const std::vector<int64_t>& k, EdgeDir dir,\n    const std::vector<FloatArray>& weight, bool ascending) {\n  
// sanity check\n  CHECK_EQ(nodes.size(), hg->NumVertexTypes())\n      << \"Number of node ID tensors must match the number of node types.\";\n  CHECK_EQ(k.size(), hg->NumEdgeTypes())\n      << \"Number of k values must match the number of edge types.\";\n  CHECK_EQ(weight.size(), hg->NumEdgeTypes())\n      << \"Number of weight tensors must match the number of edge types.\";\n\n  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());\n  std::vector<IdArray> induced_edges(hg->NumEdgeTypes());\n  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {\n    auto pair = hg->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    const IdArray nodes_ntype =\n        nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];\n    const int64_t num_nodes = nodes_ntype->shape[0];\n    if (num_nodes == 0 || k[etype] == 0) {\n      // Nothing to sample for this etype, create a placeholder relation graph\n      subrels[etype] = UnitGraph::Empty(\n          hg->GetRelationGraph(etype)->NumVertexTypes(),\n          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),\n          hg->DataType(), hg->Context());\n      induced_edges[etype] = aten::NullArray();\n    } else {\n      // sample from one relation graph\n      auto req_fmt = (dir == EdgeDir::kOut) ? 
CSR_CODE : CSC_CODE;\n      auto avail_fmt = hg->SelectFormat(etype, req_fmt);\n      COOMatrix sampled_coo;\n      switch (avail_fmt) {\n        case SparseFormat::kCOO:\n          if (dir == EdgeDir::kIn) {\n            sampled_coo = aten::COOTranspose(aten::COORowWiseTopk(\n                aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,\n                k[etype], weight[etype], ascending));\n          } else {\n            sampled_coo = aten::COORowWiseTopk(\n                hg->GetCOOMatrix(etype), nodes_ntype, k[etype], weight[etype],\n                ascending);\n          }\n          break;\n        case SparseFormat::kCSR:\n          CHECK(dir == EdgeDir::kOut)\n              << \"Cannot sample out edges on CSC matrix.\";\n          sampled_coo = aten::CSRRowWiseTopk(\n              hg->GetCSRMatrix(etype), nodes_ntype, k[etype], weight[etype],\n              ascending);\n          break;\n        case SparseFormat::kCSC:\n          CHECK(dir == EdgeDir::kIn) << \"Cannot sample in edges on CSR matrix.\";\n          sampled_coo = aten::CSRRowWiseTopk(\n              hg->GetCSCMatrix(etype), nodes_ntype, k[etype], weight[etype],\n              ascending);\n          sampled_coo = aten::COOTranspose(sampled_coo);\n          break;\n        default:\n          LOG(FATAL) << \"Unsupported sparse format.\";\n      }\n      subrels[etype] = UnitGraph::CreateFromCOO(\n          hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,\n          sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);\n      induced_edges[etype] = sampled_coo.data;\n    }\n  }\n\n  HeteroSubgraph ret;\n  ret.graph =\n      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());\n  ret.induced_vertices.resize(hg->NumVertexTypes());\n  ret.induced_edges = std::move(induced_edges);\n  return ret;\n}\n\nHeteroSubgraph SampleNeighborsBiased(\n    const HeteroGraphPtr hg, const IdArray& nodes, const int64_t fanout,\n    const NDArray& bias, const 
NDArray& tag_offset, const EdgeDir dir,\n    const bool replace) {\n  CHECK_EQ(hg->NumEdgeTypes(), 1)\n      << \"Only homogeneous or bipartite graphs are supported\";\n  auto pair = hg->meta_graph()->FindEdge(0);\n  const dgl_type_t src_vtype = pair.first;\n  const dgl_type_t dst_vtype = pair.second;\n  const dgl_type_t nodes_ntype = (dir == EdgeDir::kOut) ? src_vtype : dst_vtype;\n\n  // sanity check\n  CHECK_EQ(tag_offset->ndim, 2)\n      << \"The shape of tag_offset should be [num_nodes, num_tags + 1]\";\n  CHECK_EQ(tag_offset->shape[0], hg->NumVertices(nodes_ntype))\n      << \"The shape of tag_offset should be [num_nodes, num_tags + 1]\";\n  CHECK_EQ(tag_offset->shape[1], bias->shape[0] + 1)\n      << \"The sizes of tag_offset and bias are inconsistent\";\n\n  const int64_t num_nodes = nodes->shape[0];\n  HeteroGraphPtr subrel;\n  IdArray induced_edges;\n  const dgl_type_t etype = 0;\n  if (num_nodes == 0 || fanout == 0) {\n    // Nothing to sample for this etype, create a placeholder relation graph\n    subrel = UnitGraph::Empty(\n        hg->GetRelationGraph(etype)->NumVertexTypes(),\n        hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype), hg->DataType(),\n        hg->Context());\n    induced_edges = aten::NullArray();\n  } else {\n    // sample from one relation graph\n    const auto req_fmt = (dir == EdgeDir::kOut) ? 
CSR_CODE : CSC_CODE;\n    const auto created_fmt = hg->GetCreatedFormats();\n    COOMatrix sampled_coo;\n\n    switch (req_fmt) {\n      case CSR_CODE:\n        CHECK(created_fmt & CSR_CODE) << \"A sorted CSR Matrix is required.\";\n        sampled_coo = aten::CSRRowWiseSamplingBiased(\n            hg->GetCSRMatrix(etype), nodes, fanout, tag_offset, bias, replace);\n        break;\n      case CSC_CODE:\n        CHECK(created_fmt & CSC_CODE) << \"A sorted CSC Matrix is required.\";\n        sampled_coo = aten::CSRRowWiseSamplingBiased(\n            hg->GetCSCMatrix(etype), nodes, fanout, tag_offset, bias, replace);\n        sampled_coo = aten::COOTranspose(sampled_coo);\n        break;\n      default:\n        LOG(FATAL) << \"Unsupported sparse format.\";\n    }\n    subrel = UnitGraph::CreateFromCOO(\n        hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,\n        sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);\n    induced_edges = sampled_coo.data;\n  }\n\n  HeteroSubgraph ret;\n  ret.graph =\n      CreateHeteroGraph(hg->meta_graph(), {subrel}, hg->NumVerticesPerType());\n  ret.induced_vertices.resize(hg->NumVertexTypes());\n  ret.induced_edges = {induced_edges};\n  return ret;\n}\n\nDGL_REGISTER_GLOBAL(\"sampling.neighbor._CAPI_DGLSampleNeighborsEType\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      IdArray nodes = args[1];\n      const std::vector<int64_t>& eid2etype_offset =\n          ListValueToVector<int64_t>(args[2]);\n      IdArray fanout = args[3];\n      const std::string dir_str = args[4];\n      const auto& prob = ListValueToVector<FloatArray>(args[5]);\n      const bool replace = args[6];\n      const bool rowwise_etype_sorted = args[7];\n\n      CHECK(dir_str == \"in\" || dir_str == \"out\")\n          << \"Invalid edge direction. Must be \\\"in\\\" or \\\"out\\\".\";\n      EdgeDir dir = (dir_str == \"in\") ? 
EdgeDir::kIn : EdgeDir::kOut;\n      CHECK_INT64(fanout, \"fanout\");\n      std::vector<int64_t> fanout_vec = fanout.ToVector<int64_t>();\n\n      std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);\n      *subg = sampling::SampleNeighborsEType(\n          hg.sptr(), nodes, eid2etype_offset, fanout_vec, dir, prob, replace,\n          rowwise_etype_sorted);\n      *rv = HeteroSubgraphRef(subg);\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling.labor._CAPI_DGLSampleLabors\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      const auto& nodes = ListValueToVector<IdArray>(args[1]);\n      IdArray fanouts_array = args[2];\n      const auto& fanouts = fanouts_array.ToVector<int64_t>();\n      const std::string dir_str = args[3];\n      const auto& prob = ListValueToVector<FloatArray>(args[4]);\n      const auto& exclude_edges = ListValueToVector<IdArray>(args[5]);\n      const int importance_sampling = args[6];\n      const IdArray random_seed = args[7];\n      const double seed2_contribution = args[8];\n      const auto& NIDs = ListValueToVector<IdArray>(args[9]);\n\n      CHECK(dir_str == \"in\" || dir_str == \"out\")\n          << \"Invalid edge direction. Must be \\\"in\\\" or \\\"out\\\".\";\n      EdgeDir dir = (dir_str == \"in\") ? 
EdgeDir::kIn : EdgeDir::kOut;\n\n      std::shared_ptr<HeteroSubgraph> subg_ptr(new HeteroSubgraph);\n\n      auto&& subg_importances = sampling::SampleLabors(\n          hg.sptr(), nodes, fanouts, dir, prob, exclude_edges,\n          importance_sampling, random_seed, seed2_contribution, NIDs);\n      *subg_ptr = subg_importances.first;\n      List<Value> ret_val;\n      ret_val.push_back(Value(subg_ptr));\n      for (auto& imp : subg_importances.second)\n        ret_val.push_back(Value(MakeValue(imp)));\n\n      *rv = ret_val;\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling.neighbor._CAPI_DGLSampleNeighbors\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      const auto& nodes = ListValueToVector<IdArray>(args[1]);\n      IdArray fanouts_array = args[2];\n      const auto& fanouts = fanouts_array.ToVector<int64_t>();\n      const std::string dir_str = args[3];\n      const auto& prob_or_mask = ListValueToVector<NDArray>(args[4]);\n      const auto& exclude_edges = ListValueToVector<IdArray>(args[5]);\n      const bool replace = args[6];\n\n      CHECK(dir_str == \"in\" || dir_str == \"out\")\n          << \"Invalid edge direction. Must be \\\"in\\\" or \\\"out\\\".\";\n      EdgeDir dir = (dir_str == \"in\") ? 
EdgeDir::kIn : EdgeDir::kOut;\n\n      std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);\n      *subg = sampling::SampleNeighbors(\n          hg.sptr(), nodes, fanouts, dir, prob_or_mask, exclude_edges, replace);\n\n      *rv = HeteroSubgraphRef(subg);\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling.neighbor._CAPI_DGLSampleNeighborsFused\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      const auto& nodes = ListValueToVector<IdArray>(args[1]);\n      auto mapping = ListValueToVector<IdArray>(args[2]);\n      IdArray fanouts_array = args[3];\n      const auto& fanouts = fanouts_array.ToVector<int64_t>();\n      const std::string dir_str = args[4];\n      const auto& prob_or_mask = ListValueToVector<NDArray>(args[5]);\n      const auto& exclude_edges = ListValueToVector<IdArray>(args[6]);\n      const bool replace = args[7];\n\n      CHECK(dir_str == \"in\" || dir_str == \"out\")\n          << \"Invalid edge direction. Must be \\\"in\\\" or \\\"out\\\".\";\n      EdgeDir dir = (dir_str == \"in\") ? 
EdgeDir::kIn : EdgeDir::kOut;\n\n      HeteroGraphPtr new_graph;\n      std::vector<IdArray> induced_edges;\n      std::vector<IdArray> induced_vertices;\n\n      ATEN_ID_TYPE_SWITCH(hg->DataType(), IdType, {\n        std::tie(new_graph, induced_edges, induced_vertices) =\n            SampleNeighborsFused<IdType>(\n                hg.sptr(), nodes, mapping, fanouts, dir, prob_or_mask,\n                exclude_edges, replace);\n      });\n\n      List<Value> lhs_nodes_ref;\n      for (IdArray& array : induced_vertices)\n        lhs_nodes_ref.push_back(Value(MakeValue(array)));\n      List<Value> induced_edges_ref;\n      for (IdArray& array : induced_edges)\n        induced_edges_ref.push_back(Value(MakeValue(array)));\n      List<ObjectRef> ret;\n      ret.push_back(HeteroGraphRef(new_graph));\n      ret.push_back(lhs_nodes_ref);\n      ret.push_back(induced_edges_ref);\n\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling.neighbor._CAPI_DGLSampleNeighborsTopk\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      const auto& nodes = ListValueToVector<IdArray>(args[1]);\n      IdArray k_array = args[2];\n      const auto& k = k_array.ToVector<int64_t>();\n      const std::string dir_str = args[3];\n      const auto& weight = ListValueToVector<FloatArray>(args[4]);\n      const bool ascending = args[5];\n\n      CHECK(dir_str == \"in\" || dir_str == \"out\")\n          << \"Invalid edge direction. Must be \\\"in\\\" or \\\"out\\\".\";\n      EdgeDir dir = (dir_str == \"in\") ? 
EdgeDir::kIn : EdgeDir::kOut;\n\n      std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);\n      *subg = sampling::SampleNeighborsTopk(\n          hg.sptr(), nodes, k, dir, weight, ascending);\n\n      *rv = HeteroGraphRef(subg);\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling.neighbor._CAPI_DGLSampleNeighborsBiased\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      const IdArray nodes = args[1];\n      const int64_t fanout = args[2];\n      const NDArray bias = args[3];\n      const NDArray tag_offset = args[4];\n      const std::string dir_str = args[5];\n      const bool replace = args[6];\n\n      CHECK(dir_str == \"in\" || dir_str == \"out\")\n          << \"Invalid edge direction. Must be \\\"in\\\" or \\\"out\\\".\";\n      EdgeDir dir = (dir_str == \"in\") ? EdgeDir::kIn : EdgeDir::kOut;\n\n      std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);\n      *subg = sampling::SampleNeighborsBiased(\n          hg.sptr(), nodes, fanout, bias, tag_offset, dir, replace);\n\n      *rv = HeteroGraphRef(subg);\n    });\n\n}  // namespace sampling\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/frequency_hashmap.cu",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file graph/sampling/frequency_hashmap.cu\n * @brief frequency hashmap - used to select top-k frequency edges of each node\n */\n\n#include <algorithm>\n#include <cub/cub.cuh>  // NOLINT\n#include <tuple>\n#include <utility>\n\n#include \"../../../array/cuda/atomic.cuh\"\n#include \"../../../runtime/cuda/cuda_common.h\"\n#include \"frequency_hashmap.cuh\"\n\nnamespace dgl {\n\nnamespace sampling {\n\nnamespace impl {\n\nnamespace {\n\nint64_t _table_size(const int64_t num, const int64_t scale) {\n  /**\n   * Calculate the number of buckets in the hashtable. To guarantee we can\n   * fill the hashtable in the worst case, we must use a number of buckets which\n   * is a power of two.\n   * https://en.wikipedia.org/wiki/Quadratic_probing#Limitations\n   */\n  const int64_t next_pow2 = 1 << static_cast<int64_t>(1 + std::log2(num >> 1));\n  return next_pow2 << scale;\n}\n\ntemplate <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>\n__global__ void _init_edge_table(void *edge_hashmap, int64_t edges_len) {\n  using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;\n  auto edge_hashmap_t = static_cast<EdgeItem *>(edge_hashmap);\n  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;\n  int64_t last_idx = start_idx + TILE_SIZE;\n#pragma unroll(4)\n  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {\n    if (idx < edges_len) {\n      EdgeItem *edge = (edge_hashmap_t + idx);\n      edge->src = static_cast<IdxType>(-1);\n      edge->cnt = static_cast<IdxType>(0);\n    }\n  }\n}\n\ntemplate <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>\n__global__ void _count_frequency(\n    const IdxType *src_data, const int64_t num_edges,\n    const int64_t num_edges_per_node, IdxType *edge_blocks_prefix,\n    bool *is_first_position, DeviceEdgeHashmap<IdxType> device_edge_hashmap) {\n  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;\n  int64_t last_idx = start_idx + TILE_SIZE;\n\n  
IdxType count = 0;\n  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {\n    if (idx < num_edges) {\n      IdxType src = src_data[idx];\n      if (src == static_cast<IdxType>(-1)) {\n        continue;\n      }\n      IdxType dst_idx = (idx / num_edges_per_node);\n      if (device_edge_hashmap.InsertEdge(src, dst_idx) == 0) {\n        is_first_position[idx] = true;\n        ++count;\n      }\n    }\n  }\n\n  using BlockReduce = typename cub::BlockReduce<IdxType, BLOCK_SIZE>;\n  __shared__ typename BlockReduce::TempStorage temp_space;\n\n  count = BlockReduce(temp_space).Sum(count);\n  if (threadIdx.x == 0) {\n    edge_blocks_prefix[blockIdx.x] = count;\n    if (blockIdx.x == 0) {\n      edge_blocks_prefix[gridDim.x] = 0;\n    }\n  }\n}\n\n/**\n * This structure is used with cub's block-level prefixscan in order to\n * keep a running sum as items are iteratively processed.\n */\ntemplate <typename T>\nstruct BlockPrefixCallbackOp {\n  T _running_total;\n\n  __device__ BlockPrefixCallbackOp(const T running_total)\n      : _running_total(running_total) {}\n\n  __device__ T operator()(const T block_aggregate) {\n    const T old_prefix = _running_total;\n    _running_total += block_aggregate;\n    return old_prefix;\n  }\n};\n\ntemplate <typename IdxType, typename Idx64Type, int BLOCK_SIZE, int TILE_SIZE>\n__global__ void _compact_frequency(\n    const IdxType *src_data, const IdxType *dst_data, const int64_t num_edges,\n    const int64_t num_edges_per_node, const IdxType *edge_blocks_prefix,\n    const bool *is_first_position, IdxType *num_unique_each_node,\n    IdxType *unique_src_edges, Idx64Type *unique_frequency,\n    DeviceEdgeHashmap<IdxType> device_edge_hashmap) {\n  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;\n  int64_t last_idx = start_idx + TILE_SIZE;\n  const IdxType block_offset = edge_blocks_prefix[blockIdx.x];\n\n  using BlockScan = typename cub::BlockScan<IdxType, BLOCK_SIZE>;\n  __shared__ typename BlockScan::TempStorage 
temp_space;\n  BlockPrefixCallbackOp<IdxType> prefix_op(0);\n\n  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {\n    IdxType flag = 0;\n    if (idx < num_edges) {\n      IdxType src = src_data[idx];\n      IdxType dst_idx = (idx / num_edges_per_node);\n      if (idx % num_edges_per_node == 0) {\n        num_unique_each_node[dst_idx] =\n            device_edge_hashmap.GetDstCount(dst_idx);\n      }\n      if (is_first_position[idx] == true) {\n        flag = 1;\n      }\n      BlockScan(temp_space).ExclusiveSum(flag, flag, prefix_op);\n      __syncthreads();\n      if (is_first_position[idx] == true) {\n        const IdxType pos = (block_offset + flag);\n        unique_src_edges[pos] = src;\n        if (sizeof(IdxType) != sizeof(Idx64Type) &&\n            sizeof(IdxType) == 4) {  // if IdxType is a 32-bit data\n          unique_frequency[pos] =\n              ((static_cast<Idx64Type>(num_edges / num_edges_per_node - dst_idx)\n                << 32) |\n               device_edge_hashmap.GetEdgeCount(src, dst_idx));\n        } else {\n          unique_frequency[pos] =\n              device_edge_hashmap.GetEdgeCount(src, dst_idx);\n        }\n      }\n    }\n  }\n}\n\ntemplate <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>\n__global__ void _get_pick_num(\n    IdxType *num_unique_each_node, const int64_t num_pick,\n    const int64_t num_dst_nodes) {\n  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;\n  int64_t last_idx = start_idx + TILE_SIZE;\n#pragma unroll(4)\n  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {\n    if (idx < num_dst_nodes) {\n      IdxType &num_unique = num_unique_each_node[idx];\n      num_unique = min(num_unique, static_cast<IdxType>(num_pick));\n    }\n  }\n}\n\ntemplate <typename IdxType, typename Idx64Type, int BLOCK_SIZE, int TILE_SIZE>\n__global__ void _pick_data(\n    const Idx64Type *unique_frequency, const IdxType *unique_src_edges,\n    const IdxType *unique_input_offsets, const IdxType 
*dst_data,\n    const int64_t num_edges_per_node, const int64_t num_dst_nodes,\n    const int64_t num_edges, const IdxType *unique_output_offsets,\n    IdxType *output_src, IdxType *output_dst, IdxType *output_frequency) {\n  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;\n  int64_t last_idx = start_idx + TILE_SIZE;\n\n  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {\n    if (idx < num_dst_nodes) {\n      const int64_t dst_pos = (idx * num_edges_per_node);\n      assert(dst_pos < num_edges);\n      const IdxType dst = dst_data[dst_pos];\n      const IdxType last_output_offset = unique_output_offsets[idx + 1];\n      assert(\n          (last_output_offset - unique_output_offsets[idx]) <=\n          (unique_input_offsets[idx + 1] - unique_input_offsets[idx]));\n      for (IdxType output_idx = unique_output_offsets[idx],\n                   input_idx = unique_input_offsets[idx];\n           output_idx < last_output_offset; ++output_idx, ++input_idx) {\n        output_src[output_idx] = unique_src_edges[input_idx];\n        output_dst[output_idx] = dst;\n        output_frequency[output_idx] =\n            static_cast<IdxType>(unique_frequency[input_idx]);\n      }\n    }\n  }\n}\n\n}  // namespace\n\n// return the old cnt of this edge\ntemplate <typename IdxType>\ninline __device__ IdxType DeviceEdgeHashmap<IdxType>::InsertEdge(\n    const IdxType &src, const IdxType &dst_idx) {\n  IdxType start_off = dst_idx * _num_items_each_dst;\n  IdxType pos = EdgeHash(src);\n  IdxType delta = 1;\n  IdxType old_cnt = static_cast<IdxType>(-1);\n  while (true) {\n    IdxType old_src = dgl::aten::cuda::AtomicCAS(\n        &_edge_hashmap[start_off + pos].src, static_cast<IdxType>(-1), src);\n    if (old_src == static_cast<IdxType>(-1) || old_src == src) {\n      // first insert\n      old_cnt = dgl::aten::cuda::AtomicAdd(\n          &_edge_hashmap[start_off + pos].cnt, static_cast<IdxType>(1));\n      if (old_src == static_cast<IdxType>(-1)) {\n        
assert(dst_idx < _num_dst);\n        dgl::aten::cuda::AtomicAdd(\n            &_dst_unique_edges[dst_idx], static_cast<IdxType>(1));\n      }\n      break;\n    }\n    pos = EdgeHash(pos + delta);\n    delta += 1;\n  }\n  return old_cnt;\n}\n\ntemplate <typename IdxType>\ninline __device__ IdxType\nDeviceEdgeHashmap<IdxType>::GetDstCount(const IdxType &dst_idx) {\n  return _dst_unique_edges[dst_idx];\n}\n\ntemplate <typename IdxType>\ninline __device__ IdxType DeviceEdgeHashmap<IdxType>::GetEdgeCount(\n    const IdxType &src, const IdxType &dst_idx) {\n  IdxType start_off = dst_idx * _num_items_each_dst;\n  IdxType pos = EdgeHash(src);\n  IdxType delta = 1;\n  while (_edge_hashmap[start_off + pos].src != src) {\n    pos = EdgeHash(pos + delta);\n    delta += 1;\n  }\n  return _edge_hashmap[start_off + pos].cnt;\n}\n\ntemplate <typename IdxType>\nFrequencyHashmap<IdxType>::FrequencyHashmap(\n    int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx,\n    cudaStream_t stream, int64_t edge_table_scale) {\n  _ctx = ctx;\n  _stream = stream;\n  num_items_each_dst = _table_size(num_items_each_dst, edge_table_scale);\n  auto device = dgl::runtime::DeviceAPI::Get(_ctx);\n  auto dst_unique_edges = static_cast<IdxType *>(\n      device->AllocWorkspace(_ctx, (num_dst) * sizeof(IdxType)));\n  auto edge_hashmap = static_cast<EdgeItem *>(device->AllocWorkspace(\n      _ctx, (num_dst * num_items_each_dst) * sizeof(EdgeItem)));\n  constexpr int BLOCK_SIZE = 256;\n  constexpr int TILE_SIZE = BLOCK_SIZE * 8;\n  dim3 block(BLOCK_SIZE);\n  dim3 grid((num_dst * num_items_each_dst + TILE_SIZE - 1) / TILE_SIZE);\n  CUDA_CALL(cudaMemset(dst_unique_edges, 0, (num_dst) * sizeof(IdxType)));\n  CUDA_KERNEL_CALL(\n      (_init_edge_table<IdxType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0,\n      _stream, edge_hashmap, (num_dst * num_items_each_dst));\n  _device_edge_hashmap = new DeviceEdgeHashmap<IdxType>(\n      num_dst, num_items_each_dst, dst_unique_edges, edge_hashmap);\n  
_dst_unique_edges = dst_unique_edges;\n  _edge_hashmap = edge_hashmap;\n}\n\ntemplate <typename IdxType>\nFrequencyHashmap<IdxType>::~FrequencyHashmap() {\n  auto device = dgl::runtime::DeviceAPI::Get(_ctx);\n  delete _device_edge_hashmap;\n  _device_edge_hashmap = nullptr;\n  device->FreeWorkspace(_ctx, _dst_unique_edges);\n  _dst_unique_edges = nullptr;\n  device->FreeWorkspace(_ctx, _edge_hashmap);\n  _edge_hashmap = nullptr;\n}\n\ntemplate <typename IdxType>\nstd::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(\n    const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,\n    const int64_t num_edges, const int64_t num_edges_per_node,\n    const int64_t num_pick) {\n  using Idx64Type = int64_t;\n  const int64_t num_dst_nodes = (num_edges / num_edges_per_node);\n  constexpr int BLOCK_SIZE = 256;\n  // XXX: a experienced value, best performance in GV100\n  constexpr int TILE_SIZE = BLOCK_SIZE * 32;\n  const dim3 block(BLOCK_SIZE);\n  const dim3 edges_grid((num_edges + TILE_SIZE - 1) / TILE_SIZE);\n  auto device = dgl::runtime::DeviceAPI::Get(_ctx);\n  const IdxType num_edge_blocks = static_cast<IdxType>(edges_grid.x);\n  IdxType num_unique_edges = 0;\n\n  // to mark if this position of edges is the first inserting position for\n  // _edge_hashmap\n  bool *is_first_position = static_cast<bool *>(\n      device->AllocWorkspace(_ctx, sizeof(bool) * (num_edges)));\n  CUDA_CALL(cudaMemset(is_first_position, 0, sizeof(bool) * (num_edges)));\n  // double space to use ExclusiveSum\n  auto edge_blocks_prefix_data = static_cast<IdxType *>(device->AllocWorkspace(\n      _ctx, 2 * sizeof(IdxType) * (num_edge_blocks + 1)));\n  IdxType *edge_blocks_prefix = edge_blocks_prefix_data;\n  IdxType *edge_blocks_prefix_alternate =\n      (edge_blocks_prefix_data + (num_edge_blocks + 1));\n  // triple space to use ExclusiveSum and unique_output_offsets\n  auto num_unique_each_node_data = static_cast<IdxType *>(\n      device->AllocWorkspace(_ctx, 3 * 
sizeof(IdxType) * (num_dst_nodes + 1)));\n  IdxType *num_unique_each_node = num_unique_each_node_data;\n  IdxType *num_unique_each_node_alternate =\n      (num_unique_each_node_data + (num_dst_nodes + 1));\n  IdxType *unique_output_offsets =\n      (num_unique_each_node_data + 2 * (num_dst_nodes + 1));\n\n  // 1. Scan the all edges and count the unique edges and unique edges for each\n  // dst node\n  CUDA_KERNEL_CALL(\n      (_count_frequency<IdxType, BLOCK_SIZE, TILE_SIZE>), edges_grid, block, 0,\n      _stream, src_data, num_edges, num_edges_per_node, edge_blocks_prefix,\n      is_first_position, *_device_edge_hashmap);\n\n  // 2. Compact the unique edges frequency\n  // 2.1 ExclusiveSum the edge_blocks_prefix\n  void *d_temp_storage = nullptr;\n  size_t temp_storage_bytes = 0;\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      d_temp_storage, temp_storage_bytes, edge_blocks_prefix,\n      edge_blocks_prefix_alternate, num_edge_blocks + 1, _stream));\n  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      d_temp_storage, temp_storage_bytes, edge_blocks_prefix,\n      edge_blocks_prefix_alternate, num_edge_blocks + 1, _stream));\n  device->FreeWorkspace(_ctx, d_temp_storage);\n  std::swap(edge_blocks_prefix, edge_blocks_prefix_alternate);\n  device->CopyDataFromTo(\n      &edge_blocks_prefix[num_edge_blocks], 0, &num_unique_edges, 0,\n      sizeof(num_unique_edges), _ctx, DGLContext{kDGLCPU, 0}, dtype);\n  device->StreamSync(_ctx, _stream);\n  // 2.2 Allocate the data of unique edges and frequency\n  // double space to use SegmentedRadixSort\n  auto unique_src_edges_data = static_cast<IdxType *>(\n      device->AllocWorkspace(_ctx, 2 * sizeof(IdxType) * (num_unique_edges)));\n  IdxType *unique_src_edges = unique_src_edges_data;\n  IdxType *unique_src_edges_alternate =\n      unique_src_edges_data + num_unique_edges;\n  // double space to use SegmentedRadixSort\n  auto unique_frequency_data = 
static_cast<Idx64Type *>(\n      device->AllocWorkspace(_ctx, 2 * sizeof(Idx64Type) * (num_unique_edges)));\n  Idx64Type *unique_frequency = unique_frequency_data;\n  Idx64Type *unique_frequency_alternate =\n      unique_frequency_data + num_unique_edges;\n  // 2.3 Compact the unique edges and their frequency\n  CUDA_KERNEL_CALL(\n      (_compact_frequency<IdxType, Idx64Type, BLOCK_SIZE, TILE_SIZE>),\n      edges_grid, block, 0, _stream, src_data, dst_data, num_edges,\n      num_edges_per_node, edge_blocks_prefix, is_first_position,\n      num_unique_each_node, unique_src_edges, unique_frequency,\n      *_device_edge_hashmap);\n\n  // 3. SegmentedRadixSort the unique edges and unique_frequency\n  // 3.1 ExclusiveSum the num_unique_each_node\n  d_temp_storage = nullptr;\n  temp_storage_bytes = 0;\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      d_temp_storage, temp_storage_bytes, num_unique_each_node,\n      num_unique_each_node_alternate, num_dst_nodes + 1, _stream));\n  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      d_temp_storage, temp_storage_bytes, num_unique_each_node,\n      num_unique_each_node_alternate, num_dst_nodes + 1, _stream));\n  device->FreeWorkspace(_ctx, d_temp_storage);\n  // 3.2 SegmentedRadixSort the unique_src_edges and unique_frequency\n  // Create a set of DoubleBuffers to wrap pairs of device pointers\n  cub::DoubleBuffer<Idx64Type> d_unique_frequency(\n      unique_frequency, unique_frequency_alternate);\n  cub::DoubleBuffer<IdxType> d_unique_src_edges(\n      unique_src_edges, unique_src_edges_alternate);\n  // Determine temporary device storage requirements\n  d_temp_storage = nullptr;\n  temp_storage_bytes = 0;\n  // the DeviceRadixSort is faster than DeviceSegmentedRadixSort,\n  // especially when num_dst_nodes is large (about ~10000)\n  if (dtype.bits == 32) {\n    CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(\n        d_temp_storage, temp_storage_bytes, 
d_unique_frequency,\n        d_unique_src_edges, num_unique_edges, 0, sizeof(Idx64Type) * 8,\n        _stream));\n  } else {\n    CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairsDescending(\n        d_temp_storage, temp_storage_bytes, d_unique_frequency,\n        d_unique_src_edges, num_unique_edges, num_dst_nodes,\n        num_unique_each_node_alternate, num_unique_each_node_alternate + 1, 0,\n        sizeof(Idx64Type) * 8, _stream));\n  }\n  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);\n  if (dtype.bits == 32) {\n    CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(\n        d_temp_storage, temp_storage_bytes, d_unique_frequency,\n        d_unique_src_edges, num_unique_edges, 0, sizeof(Idx64Type) * 8,\n        _stream));\n  } else {\n    CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairsDescending(\n        d_temp_storage, temp_storage_bytes, d_unique_frequency,\n        d_unique_src_edges, num_unique_edges, num_dst_nodes,\n        num_unique_each_node_alternate, num_unique_each_node_alternate + 1, 0,\n        sizeof(Idx64Type) * 8, _stream));\n  }\n  device->FreeWorkspace(_ctx, d_temp_storage);\n\n  // 4. 
Get the final pick number for each dst node\n  // 4.1 Reset the min(num_pick, num_unique_each_node) to num_unique_each_node\n  constexpr int NODE_TILE_SIZE = BLOCK_SIZE * 2;\n  const dim3 nodes_grid((num_dst_nodes + NODE_TILE_SIZE - 1) / NODE_TILE_SIZE);\n  CUDA_KERNEL_CALL(\n      (_get_pick_num<IdxType, BLOCK_SIZE, NODE_TILE_SIZE>), nodes_grid, block,\n      0, _stream, num_unique_each_node, num_pick, num_dst_nodes);\n  // 4.2 ExclusiveSum the new num_unique_each_node as unique_output_offsets\n  // use unique_output_offsets;\n  d_temp_storage = nullptr;\n  temp_storage_bytes = 0;\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      d_temp_storage, temp_storage_bytes, num_unique_each_node,\n      unique_output_offsets, num_dst_nodes + 1, _stream));\n  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      d_temp_storage, temp_storage_bytes, num_unique_each_node,\n      unique_output_offsets, num_dst_nodes + 1, _stream));\n  device->FreeWorkspace(_ctx, d_temp_storage);\n\n  // 5. 
Pick the data to result\n  IdxType num_output = 0;\n  device->CopyDataFromTo(\n      &unique_output_offsets[num_dst_nodes], 0, &num_output, 0,\n      sizeof(num_output), _ctx, DGLContext{kDGLCPU, 0}, dtype);\n  device->StreamSync(_ctx, _stream);\n\n  IdArray res_src =\n      IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);\n  IdArray res_dst =\n      IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);\n  IdArray res_cnt =\n      IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);\n  CUDA_KERNEL_CALL(\n      (_pick_data<IdxType, Idx64Type, BLOCK_SIZE, NODE_TILE_SIZE>), nodes_grid,\n      block, 0, _stream, d_unique_frequency.Current(),\n      d_unique_src_edges.Current(), num_unique_each_node_alternate, dst_data,\n      num_edges_per_node, num_dst_nodes, num_edges, unique_output_offsets,\n      res_src.Ptr<IdxType>(), res_dst.Ptr<IdxType>(), res_cnt.Ptr<IdxType>());\n\n  device->FreeWorkspace(_ctx, is_first_position);\n  device->FreeWorkspace(_ctx, edge_blocks_prefix_data);\n  device->FreeWorkspace(_ctx, num_unique_each_node_data);\n  device->FreeWorkspace(_ctx, unique_src_edges_data);\n  device->FreeWorkspace(_ctx, unique_frequency_data);\n\n  return std::make_tuple(res_src, res_dst, res_cnt);\n}\n\ntemplate class FrequencyHashmap<int64_t>;\n\ntemplate class FrequencyHashmap<int32_t>;\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/frequency_hashmap.cuh",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file graph/sampling/frequency_hashmap.cuh\n * @brief frequency hashmap - used to select top-k frequency edges of each node\n */\n\n#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_\n#define DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_\n\n#include <dgl/array.h>\n#include <dgl/runtime/device_api.h>\n\n#include <tuple>\n\nnamespace dgl {\nnamespace sampling {\nnamespace impl {\n\ntemplate <typename IdxType>\nclass DeviceEdgeHashmap {\n public:\n  struct EdgeItem {\n    IdxType src;\n    IdxType cnt;\n  };\n  DeviceEdgeHashmap() = delete;\n  DeviceEdgeHashmap(\n      int64_t num_dst, int64_t num_items_each_dst, IdxType *dst_unique_edges,\n      EdgeItem *edge_hashmap)\n      : _num_dst(num_dst),\n        _num_items_each_dst(num_items_each_dst),\n        _dst_unique_edges(dst_unique_edges),\n        _edge_hashmap(edge_hashmap) {}\n  // return the old cnt of this edge\n  inline __device__ IdxType\n  InsertEdge(const IdxType &src, const IdxType &dst_idx);\n  inline __device__ IdxType GetDstCount(const IdxType &dst_idx);\n  inline __device__ IdxType\n  GetEdgeCount(const IdxType &src, const IdxType &dst_idx);\n\n private:\n  int64_t _num_dst;\n  int64_t _num_items_each_dst;\n  IdxType *_dst_unique_edges;\n  EdgeItem *_edge_hashmap;\n\n  inline __device__ IdxType EdgeHash(const IdxType &id) const {\n    return id % _num_items_each_dst;\n  }\n};\n\ntemplate <typename IdxType>\nclass FrequencyHashmap {\n public:\n  static constexpr int64_t kDefaultEdgeTableScale = 3;\n  FrequencyHashmap() = delete;\n  FrequencyHashmap(\n      int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx,\n      cudaStream_t stream, int64_t edge_table_scale = kDefaultEdgeTableScale);\n  ~FrequencyHashmap();\n  using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;\n  std::tuple<IdArray, IdArray, IdArray> Topk(\n      const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,\n      const 
int64_t num_edges, const int64_t num_edges_per_node,\n      const int64_t num_pick);\n\n private:\n  DGLContext _ctx;\n  cudaStream_t _stream;\n  DeviceEdgeHashmap<IdxType> *_device_edge_hashmap;\n  IdxType *_dst_unique_edges;\n  EdgeItem *_edge_hashmap;\n};\n\n};  // namespace impl\n};  // namespace sampling\n};  // namespace dgl\n\n#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/get_node_types_cpu.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/sampling/get_node_types_cpu.cc\n * @brief DGL sampler - CPU lookup of the node types visited along a metapath\n */\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n\n#include <utility>\n\n#include \"randomwalks_impl.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace sampling {\n\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nTypeArray GetNodeTypesFromMetapath(\n    const HeteroGraphPtr hg, const TypeArray metapath) {\n  uint64_t num_etypes = metapath->shape[0];\n  TypeArray result = TypeArray::Empty(\n      {metapath->shape[0] + 1}, metapath->dtype, metapath->ctx);\n\n  const IdxType *metapath_data = static_cast<IdxType *>(metapath->data);\n  IdxType *result_data = static_cast<IdxType *>(result->data);\n\n  dgl_type_t curr_type = hg->GetEndpointTypes(metapath_data[0]).first;\n  result_data[0] = curr_type;\n\n  for (uint64_t i = 0; i < num_etypes; ++i) {\n    auto src_dst_type = hg->GetEndpointTypes(metapath_data[i]);\n    dgl_type_t srctype = src_dst_type.first;\n    dgl_type_t dsttype = src_dst_type.second;\n\n    if (srctype != curr_type) {\n      LOG(FATAL) << \"source of edge type #\" << i\n                 << \" does not match destination of edge type #\" << i - 1;\n      return result;\n    }\n    curr_type = dsttype;\n    result_data[i + 1] = dsttype;\n  }\n  return result;\n}\n\ntemplate TypeArray GetNodeTypesFromMetapath<kDGLCPU, int32_t>(\n    const HeteroGraphPtr hg, const TypeArray metapath);\ntemplate TypeArray GetNodeTypesFromMetapath<kDGLCPU, int64_t>(\n    const HeteroGraphPtr hg, const TypeArray metapath);\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/get_node_types_gpu.cu",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file graph/sampling/get_node_types_gpu.cu\n * @brief DGL sampler\n */\n\n#include <cuda_runtime.h>\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/runtime/device_api.h>\n\n#include <utility>\n\n#include \"randomwalks_impl.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace sampling {\n\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nTypeArray GetNodeTypesFromMetapath(\n    const HeteroGraphPtr hg, const TypeArray metapath) {\n  uint64_t num_etypes = metapath->shape[0];\n\n  auto cpu_ctx = DGLContext{kDGLCPU, 0};\n  auto metapath_ctx = metapath->ctx;\n  auto stream = DeviceAPI::Get(metapath_ctx)->GetStream();\n\n  TypeArray h_result =\n      TypeArray::Empty({metapath->shape[0] + 1}, metapath->dtype, cpu_ctx);\n  auto h_result_data = h_result.Ptr<IdxType>();\n\n  auto h_metapath = metapath.CopyTo(cpu_ctx);\n  DeviceAPI::Get(metapath_ctx)->StreamSync(metapath_ctx, stream);\n  const IdxType *h_metapath_data = h_metapath.Ptr<IdxType>();\n\n  dgl_type_t curr_type = hg->GetEndpointTypes(h_metapath_data[0]).first;\n  h_result_data[0] = curr_type;\n\n  for (uint64_t i = 0; i < num_etypes; ++i) {\n    auto src_dst_type = hg->GetEndpointTypes(h_metapath_data[i]);\n    dgl_type_t srctype = src_dst_type.first;\n    dgl_type_t dsttype = src_dst_type.second;\n\n    if (srctype != curr_type) {\n      LOG(FATAL) << \"source of edge type #\" << i\n                 << \" does not match destination of edge type #\" << i - 1;\n    }\n    curr_type = dsttype;\n    h_result_data[i + 1] = dsttype;\n  }\n\n  auto result = h_result.CopyTo(metapath->ctx);\n  DeviceAPI::Get(metapath_ctx)->StreamSync(metapath_ctx, stream);\n  return result;\n}\n\ntemplate TypeArray GetNodeTypesFromMetapath<kDGLCUDA, int32_t>(\n    const HeteroGraphPtr hg, const TypeArray metapath);\ntemplate TypeArray GetNodeTypesFromMetapath<kDGLCUDA, int64_t>(\n    
const HeteroGraphPtr hg, const TypeArray metapath);\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/metapath_randomwalk.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/sampler/generic_randomwalk_cpu.h\n * @brief DGL sampler - templated implementation definition of random walks on\n *     CPU.\n */\n\n#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_\n#define DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/random.h>\n\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"randomwalks_cpu.h\"\n#include \"randomwalks_impl.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace sampling {\n\nnamespace impl {\n\nnamespace {\n\ntemplate <typename IdxType>\nusing TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>;\n\n/**\n * @brief Select one successor of metapath-based random walk, given the path\n *     generated so far.\n *\n * @param data The path generated so far, of type \\c IdxType.\n * @param curr The last node ID generated.\n * @param len The number of nodes generated so far.  
Note that the seed node is\n *     always included as \\c data[0], and the successors start from \\c data[1].\n *\n * @param edges_by_type Vector of results from \\c GetAdj() by edge type.\n * @param metapath_data Edge types of given metapath.\n * @param prob Transition probability per edge type.\n * @param terminate Predicate for terminating the current random walk path.\n *\n * @return A tuple of ID of next successor (-1 if not exist), the last traversed\n *     edge ID, as well as whether to terminate.\n */\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(\n    IdxType *data, dgl_id_t curr, int64_t len,\n    const std::vector<CSRMatrix> &edges_by_type,\n    const std::vector<bool> &csr_has_data, const IdxType *metapath_data,\n    const std::vector<FloatArray> &prob,\n    TerminatePredicate<IdxType> terminate) {\n  dgl_type_t etype = metapath_data[len];\n\n  // Note that since the selection of successors is very lightweight (especially\n  // in the uniform case), we want to reduce the overheads (even from object\n  // copies or object construction) as much as possible. Using Successors()\n  // slows down by 2x. Using OutEdges() slows down by 10x.\n  const CSRMatrix &csr = edges_by_type[etype];\n  const IdxType *offsets = csr.indptr.Ptr<IdxType>();\n  const IdxType *all_succ = csr.indices.Ptr<IdxType>();\n  const IdxType *all_eids =\n      csr_has_data[etype] ? csr.data.Ptr<IdxType>() : nullptr;\n  const IdxType *succ = all_succ + offsets[curr];\n  const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;\n\n  const int64_t size = offsets[curr + 1] - offsets[curr];\n  if (size == 0) return std::make_tuple(-1, -1, true);\n\n  // Use a reference to the original array instead of copying. 
This avoids\n  // updating the ref counts atomically from different threads and avoids cache\n  // ping-ponging in the tight loop.\n  const FloatArray &prob_etype = prob[etype];\n  IdxType idx = 0;\n  if (IsNullArray(prob_etype)) {\n    // empty probability array; assume uniform\n    idx = RandomEngine::ThreadLocal()->RandInt(size);\n  } else {\n    ATEN_FLOAT_TYPE_SWITCH(prob_etype->dtype, DType, \"probability\", {\n      FloatArray prob_selected =\n          FloatArray::Empty({size}, prob_etype->dtype, prob_etype->ctx);\n      DType *prob_selected_data = prob_selected.Ptr<DType>();\n      const DType *prob_etype_data = prob_etype.Ptr<DType>();\n      for (int64_t j = 0; j < size; ++j)\n        prob_selected_data[j] =\n            prob_etype_data[eids ? eids[j] : j + offsets[curr]];\n      idx = RandomEngine::ThreadLocal()->Choice<IdxType>(prob_selected);\n    });\n  }\n  dgl_id_t eid = eids ? eids[idx] : (idx + offsets[curr]);\n\n  return std::make_tuple(succ[idx], eid, terminate(data, curr, len));\n}\n\n/**\n * @brief Select one successor of metapath-based random walk, given the path\n *     generated so far specifically for the uniform probability distribution.\n *\n * @param data The path generated so far, of type \\c IdxType.\n * @param curr The last node ID generated.\n * @param len The number of nodes generated so far.  Note that the seed node is\n *     always included as \\c data[0], and the successors start from \\c data[1].\n *\n * @param edges_by_type Vector of results from \\c GetAdj() by edge type.\n * @param metapath_data Edge types of given metapath.\n * @param prob Transition probability per edge type, for this special case this\n *     will be a NullArray.\n * @param terminate Predicate for terminating the current random walk path.\n *\n * @return A pair of ID of next successor (-1 if not exist), as well as whether\n *     to terminate. 
\\note This function is called only if all the probability\n *     arrays are null.\n */\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(\n    IdxType *data, dgl_id_t curr, int64_t len,\n    const std::vector<CSRMatrix> &edges_by_type,\n    const std::vector<bool> &csr_has_data, const IdxType *metapath_data,\n    const std::vector<FloatArray> &prob,\n    TerminatePredicate<IdxType> terminate) {\n  dgl_type_t etype = metapath_data[len];\n\n  // Note that since the selection of successors is very lightweight (especially\n  // in the uniform case), we want to reduce the overheads (even from object\n  // copies or object construction) as much as possible. Using Successors()\n  // slows down by 2x. Using OutEdges() slows down by 10x.\n  const CSRMatrix &csr = edges_by_type[etype];\n  const IdxType *offsets = csr.indptr.Ptr<IdxType>();\n  const IdxType *all_succ = csr.indices.Ptr<IdxType>();\n  const IdxType *all_eids =\n      csr_has_data[etype] ? csr.data.Ptr<IdxType>() : nullptr;\n  const IdxType *succ = all_succ + offsets[curr];\n  const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;\n\n  const int64_t size = offsets[curr + 1] - offsets[curr];\n  if (size == 0) return std::make_tuple(-1, -1, true);\n\n  IdxType idx = 0;\n  // Guaranteed uniform distribution\n  idx = RandomEngine::ThreadLocal()->RandInt(size);\n  dgl_id_t eid = eids ? eids[idx] : (idx + offsets[curr]);\n\n  return std::make_tuple(succ[idx], eid, terminate(data, curr, len));\n}\n\n/**\n * @brief Metapath-based random walk.\n * @param hg The heterograph.\n * @param seeds A 1D array of seed nodes, with the type the source type of the\n *     first edge type in the metapath.\n * @param metapath A 1D array of edge types representing the metapath.\n * @param prob A vector of 1D float arrays, indicating the transition\n *     probability of each edge by edge type.  
An empty float array assumes\n *     uniform transition.\n * @param terminate Predicate for terminating a random walk path.\n * @return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs,\n *     and A 2D array of shape (len(seeds), len(metapath)) with edge IDs.\n */\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::pair<IdArray, IdArray> MetapathBasedRandomWalk(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob,\n    TerminatePredicate<IdxType> terminate) {\n  int64_t max_num_steps = metapath->shape[0];\n  const IdxType *metapath_data = static_cast<IdxType *>(metapath->data);\n  const int64_t begin_ntype =\n      hg->meta_graph()->FindEdge(metapath_data[0]).first;\n  const int64_t max_nodes = hg->NumVertices(begin_ntype);\n\n  // Prefetch all edges.\n  // This forces the heterograph to materialize all OutCSR's before the OpenMP\n  // loop; otherwise data races will happen.\n  // TODO(BarclayII): should we later on materialize COO/CSR/CSC anyway unless\n  // told otherwise?\n  int64_t num_etypes = hg->NumEdgeTypes();\n  std::vector<CSRMatrix> edges_by_type(num_etypes);\n  std::vector<bool> csr_has_data(num_etypes);\n  for (int64_t etype = 0; etype < num_etypes; ++etype) {\n    const CSRMatrix &csr = hg->GetCSRMatrix(etype);\n    edges_by_type[etype] = csr;\n    csr_has_data[etype] = CSRHasData(csr);\n  }\n\n  // Hoist the check for Uniform vs Non uniform edge distribution\n  // to avoid putting it on the hot path\n  bool isUniform = true;\n  for (const auto &etype_prob : prob) {\n    if (!IsNullArray(etype_prob)) {\n      isUniform = false;\n      break;\n    }\n  }\n  if (!isUniform) {\n    StepFunc<IdxType> step = [&edges_by_type, &csr_has_data, metapath_data,\n                              &prob, terminate](\n                                 IdxType *data, dgl_id_t curr, int64_t len) {\n      return MetapathRandomWalkStep<XPU, IdxType>(\n          data, curr, len, 
edges_by_type, csr_has_data, metapath_data, prob,\n          terminate);\n    };\n    return GenericRandomWalk<XPU, IdxType>(\n        seeds, max_num_steps, step, max_nodes);\n  } else {\n    StepFunc<IdxType> step = [&edges_by_type, &csr_has_data, metapath_data,\n                              &prob, terminate](\n                                 IdxType *data, dgl_id_t curr, int64_t len) {\n      return MetapathRandomWalkStepUniform<XPU, IdxType>(\n          data, curr, len, edges_by_type, csr_has_data, metapath_data, prob,\n          terminate);\n    };\n    return GenericRandomWalk<XPU, IdxType>(\n        seeds, max_num_steps, step, max_nodes);\n  }\n}\n\n};  // namespace\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};  // namespace dgl\n\n#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/node2vec.cc",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file graph/sampling/node2vec.cc\n * @brief Dispatcher of DGL node2vec random walks\n */\n\n#include <dgl/array.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n\n#include \"../../../c_api_common.h\"\n#include \"node2vec_impl.h\"\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace dgl {\n\nnamespace sampling {\n\nnamespace {\n\nvoid CheckNode2vecInputs(\n    const HeteroGraphPtr hg, const IdArray seeds, const double p,\n    const double q, const int64_t walk_length, const FloatArray &prob) {\n  CHECK_INT(seeds, \"seeds\");\n  CHECK_NDIM(seeds, 1, \"seeds\");\n  CHECK_FLOAT(prob, \"probability\");\n  CHECK_NDIM(prob, 1, \"probability\");\n}\n\nstd::pair<IdArray, IdArray> Node2vec(\n    const HeteroGraphPtr hg, const IdArray seeds, const double p,\n    const double q, const int64_t walk_length, const FloatArray &prob) {\n  CheckNode2vecInputs(hg, seeds, p, q, walk_length, prob);\n\n  std::pair<IdArray, IdArray> result;\n  ATEN_XPU_SWITCH(hg->Context().device_type, XPU, \"Node2vec\", {\n    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {\n      result = impl::Node2vec<XPU, IdxType>(hg, seeds, p, q, walk_length, prob);\n    });\n  });\n\n  return result;\n}\n\nDGL_REGISTER_GLOBAL(\"sampling.randomwalks._CAPI_DGLSamplingNode2vec\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef hg = args[0];\n      IdArray seeds = args[1];\n      double p = args[2];\n      double q = args[3];\n      int64_t walk_length = args[4];\n      FloatArray prob = args[5];\n\n      auto result =\n          sampling::Node2vec(hg.sptr(), seeds, p, q, walk_length, prob);\n\n      List<Value> ret;\n      ret.push_back(Value(MakeValue(result.first)));\n      ret.push_back(Value(MakeValue(result.second)));\n      *rv = ret;\n    });\n\n}  // namespace\n\n}  // namespace sampling\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/node2vec_cpu.cc",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file graph/sampling/node2vec_cpu.cc\n * @brief DGL sampler - CPU implementation of node2vec random walk with OpenMP\n */\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n\n#include <utility>\n\n#include \"node2vec_randomwalk.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace sampling {\n\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::pair<IdArray, IdArray> Node2vec(\n    const HeteroGraphPtr hg, const IdArray seeds, const double p,\n    const double q, const int64_t walk_length, const FloatArray &prob) {\n  TerminatePredicate<IdxType> terminate = [](IdxType *data, dgl_id_t curr,\n                                             int64_t len) { return false; };\n\n  return Node2vecRandomWalk<XPU, IdxType>(\n      hg, seeds, p, q, walk_length, prob, terminate);\n}\n\ntemplate std::pair<IdArray, IdArray> Node2vec<kDGLCPU, int32_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const double p,\n    const double q, const int64_t walk_length, const FloatArray &prob);\ntemplate std::pair<IdArray, IdArray> Node2vec<kDGLCPU, int64_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const double p,\n    const double q, const int64_t walk_length, const FloatArray &prob);\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/node2vec_impl.h",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file graph/sampling/node2vec_impl.h\n * @brief DGL sampler - templated implementation definition of node2vec random\n * walks\n */\n\n#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_IMPL_H_\n#define DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_IMPL_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n\n#include <functional>\n#include <tuple>\n#include <utility>\n#include <vector>\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace sampling {\n\nnamespace impl {\n\n/**\n * @brief Node2vec random walk.\n * @param hg The heterograph.\n * @param seeds A 1D integer array of seed node IDs.\n * @param p Float, indicating likelihood of immediately revisiting a node in the\n * walk.\n * @param q Float, control parameter to interpolate between breadth-first\n * strategy and depth-first strategy.\n * @param walk_length Int, length of walk.\n * @param prob A 1D float array indicating the transition probability of each\n *        edge.  An empty float array assumes uniform transition.\n * @return A pair of arrays: a 2D array of shape (len(seeds), walk_length + 1)\n * with node IDs, and an array with the traversed edge IDs.  The paths that\n * terminated early are padded with -1.\n */\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::pair<IdArray, IdArray> Node2vec(\n    const HeteroGraphPtr hg, const IdArray seeds, const double p,\n    const double q, const int64_t walk_length, const FloatArray &prob);\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};  // namespace dgl\n\n#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_IMPL_H_\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/node2vec_randomwalk.h",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file graph/sampling/node2vec_randomwalk.cc\n * @brief DGL sampler - CPU implementation of node2vec random walk.\n */\n\n#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_RANDOMWALK_H_\n#define DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_RANDOMWALK_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/random.h>\n\n#include <algorithm>\n#include <cmath>\n#include <functional>\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"metapath_randomwalk.h\"  // for TerminatePredicate\n#include \"node2vec_impl.h\"\n#include \"randomwalks_cpu.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace sampling {\n\nnamespace impl {\n\nnamespace {\n\ntemplate <typename IdxType>\nbool has_edge_between(const CSRMatrix &csr, dgl_id_t u, dgl_id_t v) {\n  const IdxType *offsets = csr.indptr.Ptr<IdxType>();\n  const IdxType *all_succ = csr.indices.Ptr<IdxType>();\n  const IdxType *u_succ = all_succ + offsets[u];\n  const int64_t size = offsets[u + 1] - offsets[u];\n\n  if (csr.sorted)\n    return std::binary_search(u_succ, u_succ + size, v);\n  else\n    return std::find(u_succ, u_succ + size, v) != u_succ + size;\n}\n\n/**\n * @brief Node2vec random walk step function\n * @param data The path generated so far, of type \\c IdxType.\n * @param curr The last node ID generated.\n * @param pre The last last node ID generated\n * @param p Float, indicating likelihood of immediately revisiting a node in the\n *        walk.\n * @param q Float, control parameter to interpolate between breadth-first\n *        strategy and depth-first strategy.\n * @param len The number of nodes generated so far.  
Note that the seed node is\n * always included as \\c data[0], and the successors start from \\c data[1].\n * @param csr The CSR matrix\n * @param prob Transition probability\n * @param terminate Predicate for terminating the current random walk path.\n * @return A tuple of ID of next successor (-1 if not exist), the edge ID\n * traversed, as well as whether to terminate.\n */\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::tuple<dgl_id_t, dgl_id_t, bool> Node2vecRandomWalkStep(\n    IdxType *data, dgl_id_t curr, dgl_id_t pre, const double p, const double q,\n    int64_t len, const CSRMatrix &csr, bool csr_has_data,\n    const FloatArray &probs, TerminatePredicate<IdxType> terminate) {\n  const IdxType *offsets = csr.indptr.Ptr<IdxType>();\n  const IdxType *all_succ = csr.indices.Ptr<IdxType>();\n  const IdxType *all_eids = csr_has_data ? csr.data.Ptr<IdxType>() : nullptr;\n  const IdxType *succ = all_succ + offsets[curr];\n  const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;\n\n  const int64_t size = offsets[curr + 1] - offsets[curr];\n\n  // Isolated node\n  if (size == 0) return std::make_tuple(-1, -1, true);\n\n  IdxType idx = 0;\n\n  // Normalize the weights to compute rejection probabilities\n  double max_prob = std::max({1 / p, 1.0, 1 / q});\n  // rejection prob for back to the previous node\n  double prob0 = 1 / p / max_prob;\n  // rejection prob for visiting the node with the distance of 1 between the\n  // previous node\n  double prob1 = 1 / max_prob;\n  // rejection prob for visiting the node with the distance of 2 between the\n  // previous node\n  double prob2 = 1 / q / max_prob;\n  dgl_id_t next_node;\n  double r;  // rejection probability.\n  if (IsNullArray(probs)) {\n    if (len == 0) {\n      idx = RandomEngine::ThreadLocal()->RandInt(size);\n      next_node = succ[idx];\n    } else {\n      while (true) {\n        idx = RandomEngine::ThreadLocal()->RandInt(size);\n        r = RandomEngine::ThreadLocal()->Uniform(0., 
1.);\n        next_node = succ[idx];\n        if (next_node == pre) {\n          if (r < prob0) break;\n        } else if (has_edge_between<IdxType>(csr, next_node, pre)) {\n          if (r < prob1) break;\n        } else if (r < prob2) {\n          break;\n        }\n      }\n    }\n  } else {\n    FloatArray prob_selected;\n    ATEN_FLOAT_TYPE_SWITCH(probs->dtype, DType, \"probability\", {\n      prob_selected = FloatArray::Empty({size}, probs->dtype, probs->ctx);\n      DType *prob_selected_data = prob_selected.Ptr<DType>();\n      const DType *prob_etype_data = probs.Ptr<DType>();\n      for (int64_t j = 0; j < size; ++j)\n        prob_selected_data[j] =\n            prob_etype_data[eids ? eids[j] : j + offsets[curr]];\n    });\n\n    if (len == 0) {\n      idx = RandomEngine::ThreadLocal()->Choice<IdxType>(prob_selected);\n      next_node = succ[idx];\n    } else {\n      while (true) {\n        idx = RandomEngine::ThreadLocal()->Choice<IdxType>(prob_selected);\n        r = RandomEngine::ThreadLocal()->Uniform(0., 1.);\n        next_node = succ[idx];\n        if (next_node == pre) {\n          if (r < prob0) break;\n        } else if (has_edge_between<IdxType>(csr, next_node, pre)) {\n          if (r < prob1) break;\n        } else if (r < prob2) {\n          break;\n        }\n      }\n    }\n  }\n  dgl_id_t eid = eids ? 
eids[idx] : (idx + offsets[curr]);\n\n  return std::make_tuple(next_node, eid, terminate(data, next_node, len));\n}\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::pair<IdArray, IdArray> Node2vecRandomWalk(\n    const HeteroGraphPtr g, const IdArray seeds, const double p, const double q,\n    const int64_t max_num_steps, const FloatArray &prob,\n    TerminatePredicate<IdxType> terminate) {\n  const CSRMatrix &edges = g->GetCSRMatrix(0);  // homogeneous graph.\n  bool csr_has_data = CSRHasData(edges);\n\n  StepFunc<IdxType> step = [&edges, csr_has_data, &prob, p, q, terminate](\n                               IdxType *data, dgl_id_t curr, int64_t len) {\n    dgl_id_t pre = (len != 0) ? data[len - 1] : curr;\n    return Node2vecRandomWalkStep<XPU, IdxType>(\n        data, curr, pre, p, q, len, edges, csr_has_data, prob, terminate);\n  };\n\n  return GenericRandomWalk<XPU, IdxType>(\n      seeds, max_num_steps, step, g->NumVertices(0));\n}\n\n};  // namespace\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};      // namespace dgl\n#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_RANDOMWALK_H_\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/randomwalk_cpu.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/sampling/randomwalk_cpu.cc\n * @brief DGL sampler - CPU implementation of metapath-based random walk with\n * OpenMP\n */\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/runtime/device_api.h>\n\n#include <algorithm>\n#include <utility>\n#include <vector>\n\n#include \"metapath_randomwalk.h\"\n#include \"randomwalks_cpu.h\"\n#include \"randomwalks_impl.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace sampling {\n\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::pair<IdArray, IdArray> RandomWalk(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob) {\n  TerminatePredicate<IdxType> terminate = [](IdxType *data, dgl_id_t curr,\n                                             int64_t len) { return false; };\n\n  return MetapathBasedRandomWalk<XPU, IdxType>(\n      hg, seeds, metapath, prob, terminate);\n}\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(\n    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,\n    const int64_t k) {\n  CHECK(src->ctx.device_type == kDGLCPU) << \"IdArray needs be on CPU!\";\n  int64_t len = src->shape[0] / num_samples_per_node;\n  IdxType *src_data = src.Ptr<IdxType>();\n  const IdxType *dst_data = dst.Ptr<IdxType>();\n  std::vector<IdxType> res_src_vec, res_dst_vec, res_cnt_vec;\n  for (int64_t i = 0; i < len; ++i) {\n    int64_t start_idx = (i * num_samples_per_node);\n    int64_t end_idx = (start_idx + num_samples_per_node);\n    IdxType dst_node = dst_data[start_idx];\n    std::sort(src_data + start_idx, src_data + end_idx);\n    int64_t cnt = 0;\n    std::vector<std::pair<IdxType, IdxType>> vec;\n    for (int64_t j = start_idx; j < end_idx; ++j) {\n      if ((j != start_idx) && (src_data[j] != src_data[j - 1])) 
{\n        if (src_data[j - 1] != -1) {\n          vec.emplace_back(std::make_pair(cnt, src_data[j - 1]));\n        }\n        cnt = 0;\n      }\n      ++cnt;\n    }\n    // add last count\n    if (src_data[end_idx - 1] != -1) {\n      vec.emplace_back(std::make_pair(cnt, src_data[end_idx - 1]));\n    }\n    std::sort(\n        vec.begin(), vec.end(), std::greater<std::pair<IdxType, IdxType>>());\n    int64_t len = std::min(vec.size(), static_cast<size_t>(k));\n    for (int64_t j = 0; j < len; ++j) {\n      auto pair_item = vec[j];\n      res_src_vec.emplace_back(pair_item.second);\n      res_dst_vec.emplace_back(dst_node);\n      res_cnt_vec.emplace_back(pair_item.first);\n    }\n  }\n  IdArray res_src = IdArray::Empty(\n      {static_cast<int64_t>(res_src_vec.size())}, src->dtype, src->ctx);\n  IdArray res_dst = IdArray::Empty(\n      {static_cast<int64_t>(res_dst_vec.size())}, dst->dtype, dst->ctx);\n  IdArray res_cnt = IdArray::Empty(\n      {static_cast<int64_t>(res_cnt_vec.size())}, src->dtype, src->ctx);\n\n  // copy data from vector to NDArray\n  auto device = runtime::DeviceAPI::Get(src->ctx);\n  device->CopyDataFromTo(\n      static_cast<IdxType *>(res_src_vec.data()), 0, res_src.Ptr<IdxType>(), 0,\n      sizeof(IdxType) * res_src_vec.size(), DGLContext{kDGLCPU, 0},\n      res_src->ctx, res_src->dtype);\n  device->CopyDataFromTo(\n      static_cast<IdxType *>(res_dst_vec.data()), 0, res_dst.Ptr<IdxType>(), 0,\n      sizeof(IdxType) * res_dst_vec.size(), DGLContext{kDGLCPU, 0},\n      res_dst->ctx, res_dst->dtype);\n  device->CopyDataFromTo(\n      static_cast<IdxType *>(res_cnt_vec.data()), 0, res_cnt.Ptr<IdxType>(), 0,\n      sizeof(IdxType) * res_cnt_vec.size(), DGLContext{kDGLCPU, 0},\n      res_cnt->ctx, res_cnt->dtype);\n\n  return std::make_tuple(res_src, res_dst, res_cnt);\n}\n\ntemplate std::pair<IdArray, IdArray> RandomWalk<kDGLCPU, int32_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const 
std::vector<FloatArray> &prob);\ntemplate std::pair<IdArray, IdArray> RandomWalk<kDGLCPU, int64_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob);\n\ntemplate std::tuple<IdArray, IdArray, IdArray>\nSelectPinSageNeighbors<kDGLCPU, int32_t>(\n    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,\n    const int64_t k);\ntemplate std::tuple<IdArray, IdArray, IdArray>\nSelectPinSageNeighbors<kDGLCPU, int64_t>(\n    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,\n    const int64_t k);\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/randomwalk_gpu.cu",
    "content": "/**\n *  Copyright (c) 2021-2022 by Contributors\n * @file graph/sampling/randomwalk_gpu.cu\n * @brief CUDA random walk sampleing\n */\n\n#include <curand_kernel.h>\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/random.h>\n#include <dgl/runtime/device_api.h>\n\n#include <cub/cub.cuh>\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"../../../runtime/cuda/cuda_common.h\"\n#include \"frequency_hashmap.cuh\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace sampling {\n\nnamespace impl {\n\nnamespace {\n\ntemplate <typename IdType>\nstruct GraphKernelData {\n  const IdType *in_ptr;\n  const IdType *in_cols;\n  const IdType *data;\n};\n\ntemplate <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>\n__global__ void _RandomWalkKernel(\n    const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,\n    const IdType *metapath_data, const uint64_t max_num_steps,\n    const GraphKernelData<IdType> *graphs, const FloatType *restart_prob_data,\n    const int64_t restart_prob_size, const int64_t max_nodes,\n    IdType *out_traces_data, IdType *out_eids_data) {\n  assert(BLOCK_SIZE == blockDim.x);\n  int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x;\n  int64_t last_idx =\n      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds);\n  int64_t trace_length = (max_num_steps + 1);\n  curandState rng;\n  // reference:\n  //     https://docs.nvidia.com/cuda/curand/device-api-overview.html#performance-notes\n  curand_init(rand_seed + idx, 0, 0, &rng);\n\n  while (idx < last_idx) {\n    IdType curr = seed_data[idx];\n    assert(curr < max_nodes);\n    IdType *traces_data_ptr = &out_traces_data[idx * trace_length];\n    IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];\n    *(traces_data_ptr++) = curr;\n    int64_t step_idx;\n    for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {\n      IdType metapath_id = 
metapath_data[step_idx];\n      const GraphKernelData<IdType> &graph = graphs[metapath_id];\n      const int64_t in_row_start = graph.in_ptr[curr];\n      const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];\n      if (deg == 0) {  // the degree is zero\n        break;\n      }\n      const int64_t num = curand(&rng) % deg;\n      IdType pick = graph.in_cols[in_row_start + num];\n      IdType eid =\n          (graph.data ? graph.data[in_row_start + num] : in_row_start + num);\n      *traces_data_ptr = pick;\n      *eids_data_ptr = eid;\n      if ((restart_prob_size > 1) &&\n          (curand_uniform(&rng) < restart_prob_data[step_idx])) {\n        break;\n      } else if (\n          (restart_prob_size == 1) &&\n          (curand_uniform(&rng) < restart_prob_data[0])) {\n        break;\n      }\n      ++traces_data_ptr;\n      ++eids_data_ptr;\n      curr = pick;\n    }\n    for (; step_idx < max_num_steps; ++step_idx) {\n      *(traces_data_ptr++) = -1;\n      *(eids_data_ptr++) = -1;\n    }\n    idx += BLOCK_SIZE;\n  }\n}\n\ntemplate <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>\n__global__ void _RandomWalkBiasedKernel(\n    const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,\n    const IdType *metapath_data, const uint64_t max_num_steps,\n    const GraphKernelData<IdType> *graphs, const FloatType **probs,\n    const FloatType **prob_sums, const FloatType *restart_prob_data,\n    const int64_t restart_prob_size, const int64_t max_nodes,\n    IdType *out_traces_data, IdType *out_eids_data) {\n  assert(BLOCK_SIZE == blockDim.x);\n  int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x;\n  int64_t last_idx =\n      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds);\n  int64_t trace_length = (max_num_steps + 1);\n  curandState rng;\n  // reference:\n  //     https://docs.nvidia.com/cuda/curand/device-api-overview.html#performance-notes\n  curand_init(rand_seed + idx, 0, 0, &rng);\n\n  while (idx < 
last_idx) {\n    IdType curr = seed_data[idx];\n    assert(curr < max_nodes);\n    IdType *traces_data_ptr = &out_traces_data[idx * trace_length];\n    IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];\n    *(traces_data_ptr++) = curr;\n    int64_t step_idx;\n    for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {\n      IdType metapath_id = metapath_data[step_idx];\n      const GraphKernelData<IdType> &graph = graphs[metapath_id];\n      const int64_t in_row_start = graph.in_ptr[curr];\n      const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];\n      if (deg == 0) {  // the degree is zero\n        break;\n      }\n\n      // randomly select by weight\n      const FloatType *prob_sum = prob_sums[metapath_id];\n      const FloatType *prob = probs[metapath_id];\n      int64_t num;\n      if (prob == nullptr) {\n        num = curand(&rng) % deg;\n      } else {\n        auto rnd_sum_w = prob_sum[curr] * curand_uniform(&rng);\n        FloatType sum_w{0.};\n        for (num = 0; num < deg; ++num) {\n          sum_w += prob[in_row_start + num];\n          if (sum_w >= rnd_sum_w) break;\n        }\n      }\n\n      IdType pick = graph.in_cols[in_row_start + num];\n      IdType eid =\n          (graph.data ? 
graph.data[in_row_start + num] : in_row_start + num);\n      *traces_data_ptr = pick;\n      *eids_data_ptr = eid;\n      if ((restart_prob_size > 1) &&\n          (curand_uniform(&rng) < restart_prob_data[step_idx])) {\n        break;\n      } else if (\n          (restart_prob_size == 1) &&\n          (curand_uniform(&rng) < restart_prob_data[0])) {\n        break;\n      }\n      ++traces_data_ptr;\n      ++eids_data_ptr;\n      curr = pick;\n    }\n    for (; step_idx < max_num_steps; ++step_idx) {\n      *(traces_data_ptr++) = -1;\n      *(eids_data_ptr++) = -1;\n    }\n    idx += BLOCK_SIZE;\n  }\n}\n\n}  // namespace\n\n// random walk for uniform choice\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<IdArray, IdArray> RandomWalkUniform(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    FloatArray restart_prob) {\n  const int64_t max_num_steps = metapath->shape[0];\n  const IdType *metapath_data = static_cast<IdType *>(metapath->data);\n  const int64_t begin_ntype =\n      hg->meta_graph()->FindEdge(metapath_data[0]).first;\n  const int64_t max_nodes = hg->NumVertices(begin_ntype);\n  int64_t num_etypes = hg->NumEdgeTypes();\n  auto ctx = seeds->ctx;\n\n  const IdType *seed_data = static_cast<const IdType *>(seeds->data);\n  CHECK(seeds->ndim == 1) << \"seeds shape is not one dimension.\";\n  const int64_t num_seeds = seeds->shape[0];\n  int64_t trace_length = max_num_steps + 1;\n  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, ctx);\n  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, ctx);\n  IdType *traces_data = traces.Ptr<IdType>();\n  IdType *eids_data = eids.Ptr<IdType>();\n\n  std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);\n  for (int64_t etype = 0; etype < num_etypes; ++etype) {\n    const CSRMatrix &csr = hg->GetCSRMatrix(etype);\n    h_graphs[etype].in_ptr = static_cast<const IdType *>(csr.indptr->data);\n    h_graphs[etype].in_cols = 
static_cast<const IdType *>(csr.indices->data);\n    h_graphs[etype].data =\n        (CSRHasData(csr) ? static_cast<const IdType *>(csr.data->data)\n                         : nullptr);\n  }\n  // use cuda stream from local thread\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  auto device = DeviceAPI::Get(ctx);\n  auto d_graphs = static_cast<GraphKernelData<IdType> *>(device->AllocWorkspace(\n      ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));\n  // copy graph metadata pointers to GPU\n  device->CopyDataFromTo(\n      h_graphs.data(), 0, d_graphs, 0,\n      (num_etypes) * sizeof(GraphKernelData<IdType>), DGLContext{kDGLCPU, 0},\n      ctx, hg->GetCSRMatrix(0).indptr->dtype);\n  // copy metapath to GPU\n  auto d_metapath = metapath.CopyTo(ctx);\n  const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);\n\n  constexpr int BLOCK_SIZE = 256;\n  constexpr int TILE_SIZE = BLOCK_SIZE * 4;\n  dim3 block(256);\n  dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);\n  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);\n  ATEN_FLOAT_TYPE_SWITCH(\n      restart_prob->dtype, FloatType, \"random walk GPU kernel\", {\n        CHECK(restart_prob->ctx.device_type == kDGLCUDA)\n            << \"restart prob should be in GPU.\";\n        CHECK(restart_prob->ndim == 1) << \"restart prob dimension should be 1.\";\n        const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();\n        const int64_t restart_prob_size = restart_prob->shape[0];\n        CUDA_KERNEL_CALL(\n            (_RandomWalkKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), grid,\n            block, 0, stream, random_seed, seed_data, num_seeds,\n            d_metapath_data, max_num_steps, d_graphs, restart_prob_data,\n            restart_prob_size, max_nodes, traces_data, eids_data);\n      });\n\n  device->FreeWorkspace(ctx, d_graphs);\n  return std::make_pair(traces, eids);\n}\n\n/**\n * @brief Random walk for biased choice. 
We use inverse transform sampling to\n * choose the next step.\n */\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nstd::pair<IdArray, IdArray> RandomWalkBiased(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, FloatArray restart_prob) {\n  const int64_t max_num_steps = metapath->shape[0];\n  const IdType *metapath_data = static_cast<IdType *>(metapath->data);\n  const int64_t begin_ntype =\n      hg->meta_graph()->FindEdge(metapath_data[0]).first;\n  const int64_t max_nodes = hg->NumVertices(begin_ntype);\n  int64_t num_etypes = hg->NumEdgeTypes();\n  auto ctx = seeds->ctx;\n\n  const IdType *seed_data = static_cast<const IdType *>(seeds->data);\n  CHECK(seeds->ndim == 1) << \"seeds shape is not one dimension.\";\n  const int64_t num_seeds = seeds->shape[0];\n  int64_t trace_length = max_num_steps + 1;\n  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, ctx);\n  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, ctx);\n  IdType *traces_data = traces.Ptr<IdType>();\n  IdType *eids_data = eids.Ptr<IdType>();\n\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  auto device = DeviceAPI::Get(ctx);\n  // new probs and prob sums pointers\n  assert(num_etypes == static_cast<int64_t>(prob.size()));\n  std::unique_ptr<FloatType *[]> probs(new FloatType *[prob.size()]);\n  std::unique_ptr<FloatType *[]> prob_sums(new FloatType *[prob.size()]);\n  std::vector<FloatArray> prob_sums_arr;\n  prob_sums_arr.reserve(prob.size());\n\n  // graphs\n  std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);\n  for (int64_t etype = 0; etype < num_etypes; ++etype) {\n    const CSRMatrix &csr = hg->GetCSRMatrix(etype);\n    h_graphs[etype].in_ptr = static_cast<const IdType *>(csr.indptr->data);\n    h_graphs[etype].in_cols = static_cast<const IdType *>(csr.indices->data);\n    h_graphs[etype].data =\n        (CSRHasData(csr) ? 
static_cast<const IdType *>(csr.data->data)\n                         : nullptr);\n\n    int64_t num_segments = csr.indptr->shape[0] - 1;\n    // will handle empty probs in the kernel\n    if (IsNullArray(prob[etype])) {\n      probs[etype] = nullptr;\n      prob_sums[etype] = nullptr;\n      continue;\n    }\n    probs[etype] = prob[etype].Ptr<FloatType>();\n    prob_sums_arr.push_back(\n        FloatArray::Empty({num_segments}, prob[etype]->dtype, ctx));\n    prob_sums[etype] = prob_sums_arr[etype].Ptr<FloatType>();\n\n    // calculate the sum of the neighbor weights\n    const IdType *d_offsets = static_cast<const IdType *>(csr.indptr->data);\n    size_t temp_storage_size = 0;\n    CUDA_CALL(cub::DeviceSegmentedReduce::Sum(\n        nullptr, temp_storage_size, probs[etype], prob_sums[etype],\n        num_segments, d_offsets, d_offsets + 1, stream));\n    void *temp_storage = device->AllocWorkspace(ctx, temp_storage_size);\n    CUDA_CALL(cub::DeviceSegmentedReduce::Sum(\n        temp_storage, temp_storage_size, probs[etype], prob_sums[etype],\n        num_segments, d_offsets, d_offsets + 1, stream));\n    device->FreeWorkspace(ctx, temp_storage);\n  }\n\n  // copy graph metadata pointers to GPU\n  auto d_graphs = static_cast<GraphKernelData<IdType> *>(device->AllocWorkspace(\n      ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));\n  device->CopyDataFromTo(\n      h_graphs.data(), 0, d_graphs, 0,\n      (num_etypes) * sizeof(GraphKernelData<IdType>), DGLContext{kDGLCPU, 0},\n      ctx, hg->GetCSRMatrix(0).indptr->dtype);\n  // copy probs pointers to GPU\n  const FloatType **probs_dev = static_cast<const FloatType **>(\n      device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));\n  device->CopyDataFromTo(\n      probs.get(), 0, probs_dev, 0, (num_etypes) * sizeof(FloatType *),\n      DGLContext{kDGLCPU, 0}, ctx, prob[0]->dtype);\n  // copy probs_sum pointers to GPU\n  const FloatType **prob_sums_dev = static_cast<const FloatType **>(\n      
device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));\n  device->CopyDataFromTo(\n      prob_sums.get(), 0, prob_sums_dev, 0, (num_etypes) * sizeof(FloatType *),\n      DGLContext{kDGLCPU, 0}, ctx, prob[0]->dtype);\n  // copy metapath to GPU\n  auto d_metapath = metapath.CopyTo(ctx);\n  const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);\n\n  constexpr int BLOCK_SIZE = 256;\n  constexpr int TILE_SIZE = BLOCK_SIZE * 4;\n  dim3 block(256);\n  dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);\n  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);\n  CHECK(restart_prob->ctx.device_type == kDGLCUDA)\n      << \"restart prob should be in GPU.\";\n  CHECK(restart_prob->ndim == 1) << \"restart prob dimension should be 1.\";\n  const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();\n  const int64_t restart_prob_size = restart_prob->shape[0];\n  CUDA_KERNEL_CALL(\n      (_RandomWalkBiasedKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), grid,\n      block, 0, stream, random_seed, seed_data, num_seeds, d_metapath_data,\n      max_num_steps, d_graphs, probs_dev, prob_sums_dev, restart_prob_data,\n      restart_prob_size, max_nodes, traces_data, eids_data);\n\n  device->FreeWorkspace(ctx, d_graphs);\n  device->FreeWorkspace(ctx, probs_dev);\n  device->FreeWorkspace(ctx, prob_sums_dev);\n  return std::make_pair(traces, eids);\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<IdArray, IdArray> RandomWalk(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob) {\n  bool isUniform = true;\n  for (const auto &etype_prob : prob) {\n    if (!IsNullArray(etype_prob)) {\n      isUniform = false;\n      break;\n    }\n  }\n\n  auto restart_prob =\n      NDArray::Empty({0}, DGLDataType{kDGLFloat, 32, 1}, DGLContext{XPU, 0});\n  if (!isUniform) {\n    std::pair<IdArray, IdArray> ret;\n    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, 
\"probability\", {\n      ret = RandomWalkBiased<XPU, FloatType, IdType>(\n          hg, seeds, metapath, prob, restart_prob);\n    });\n    return ret;\n  } else {\n    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<IdArray, IdArray> RandomWalkWithRestart(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, double restart_prob) {\n  bool isUniform = true;\n  for (const auto &etype_prob : prob) {\n    if (!IsNullArray(etype_prob)) {\n      isUniform = false;\n      break;\n    }\n  }\n\n  auto device_ctx = seeds->ctx;\n  auto restart_prob_array =\n      NDArray::Empty({1}, DGLDataType{kDGLFloat, 64, 1}, device_ctx);\n  auto device = dgl::runtime::DeviceAPI::Get(device_ctx);\n\n  // use cuda stream from local thread\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  device->CopyDataFromTo(\n      &restart_prob, 0, restart_prob_array.Ptr<double>(), 0, sizeof(double),\n      DGLContext{kDGLCPU, 0}, device_ctx, restart_prob_array->dtype);\n  device->StreamSync(device_ctx, stream);\n\n  if (!isUniform) {\n    std::pair<IdArray, IdArray> ret;\n    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, \"probability\", {\n      ret = RandomWalkBiased<XPU, FloatType, IdType>(\n          hg, seeds, metapath, prob, restart_prob_array);\n    });\n    return ret;\n  } else {\n    return RandomWalkUniform<XPU, IdType>(\n        hg, seeds, metapath, restart_prob_array);\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, FloatArray restart_prob) {\n  bool isUniform = true;\n  for (const auto &etype_prob : prob) {\n    if (!IsNullArray(etype_prob)) {\n      isUniform = false;\n      break;\n    }\n  }\n\n  if (!isUniform) {\n    
std::pair<IdArray, IdArray> ret;\n    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, \"probability\", {\n      ret = RandomWalkBiased<XPU, FloatType, IdType>(\n          hg, seeds, metapath, prob, restart_prob);\n    });\n    return ret;\n  } else {\n    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(\n    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,\n    const int64_t k) {\n  CHECK(src->ctx.device_type == kDGLCUDA) << \"IdArray needs be on GPU!\";\n  const IdxType *src_data = src.Ptr<IdxType>();\n  const IdxType *dst_data = dst.Ptr<IdxType>();\n  const int64_t num_dst_nodes = (dst->shape[0] / num_samples_per_node);\n  auto ctx = src->ctx;\n  // use cuda stream from local thread\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  auto frequency_hashmap = FrequencyHashmap<IdxType>(\n      num_dst_nodes, num_samples_per_node, ctx, stream);\n  auto ret = frequency_hashmap.Topk(\n      src_data, dst_data, src->dtype, src->shape[0], num_samples_per_node, k);\n  return ret;\n}\n\ntemplate std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int32_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob);\ntemplate std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int64_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob);\n\ntemplate std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int32_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, double restart_prob);\ntemplate std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int64_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, double 
restart_prob);\n\ntemplate std::pair<IdArray, IdArray>\nRandomWalkWithStepwiseRestart<kDGLCUDA, int32_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, FloatArray restart_prob);\ntemplate std::pair<IdArray, IdArray>\nRandomWalkWithStepwiseRestart<kDGLCUDA, int64_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, FloatArray restart_prob);\n\ntemplate std::tuple<IdArray, IdArray, IdArray>\nSelectPinSageNeighbors<kDGLCUDA, int32_t>(\n    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,\n    const int64_t k);\ntemplate std::tuple<IdArray, IdArray, IdArray>\nSelectPinSageNeighbors<kDGLCUDA, int64_t>(\n    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,\n    const int64_t k);\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/randomwalk_with_restart_cpu.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/sampling/randomwalk_with_restart_cpu.cc\n * @brief DGL sampler - CPU implementation of metapath-based random walk with\n * restart with OpenMP\n */\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/random.h>\n\n#include <utility>\n#include <vector>\n\n#include \"metapath_randomwalk.h\"\n#include \"randomwalks_cpu.h\"\n#include \"randomwalks_impl.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace sampling {\n\nnamespace impl {\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::pair<IdArray, IdArray> RandomWalkWithRestart(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, double restart_prob) {\n  TerminatePredicate<IdxType> terminate =\n      [restart_prob](IdxType *data, dgl_id_t curr, int64_t len) {\n        return RandomEngine::ThreadLocal()->Uniform<double>() < restart_prob;\n      };\n  return MetapathBasedRandomWalk<XPU, IdxType>(\n      hg, seeds, metapath, prob, terminate);\n}\n\ntemplate std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCPU, int32_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, double restart_prob);\ntemplate std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCPU, int64_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, double restart_prob);\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, FloatArray restart_prob) {\n  std::pair<IdArray, IdArray> result;\n\n  ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, DType, \"restart probability\", {\n    DType *restart_prob_data = static_cast<DType 
*>(restart_prob->data);\n    TerminatePredicate<IdxType> terminate =\n        [restart_prob_data](IdxType *data, dgl_id_t curr, int64_t len) {\n          return RandomEngine::ThreadLocal()->Uniform<DType>() <\n                 restart_prob_data[len];\n        };\n    result = MetapathBasedRandomWalk<XPU, IdxType>(\n        hg, seeds, metapath, prob, terminate);\n  });\n\n  return result;\n}\n\ntemplate std::pair<IdArray, IdArray>\nRandomWalkWithStepwiseRestart<kDGLCPU, int32_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, FloatArray restart_prob);\ntemplate std::pair<IdArray, IdArray>\nRandomWalkWithStepwiseRestart<kDGLCPU, int64_t>(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, FloatArray restart_prob);\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/randomwalks.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/sampling/randomwalks.cc\n * @brief Dispatcher of different DGL random walks by device type\n */\n\n#include <dgl/array.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/sampling/randomwalks.h>\n\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"../../../c_api_common.h\"\n#include \"randomwalks_impl.h\"\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace dgl {\n\nnamespace sampling {\n\nnamespace {\n\nvoid CheckRandomWalkInputs(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob) {\n  CHECK_INT(seeds, \"seeds\");\n  CHECK_INT(metapath, \"metapath\");\n  CHECK_NDIM(seeds, 1, \"seeds\");\n  CHECK_NDIM(metapath, 1, \"metapath\");\n  // (Xin): metapath is copied to GPU in CUDA random walk code\n  // CHECK_SAME_CONTEXT(seeds, metapath);\n\n  if (hg->IsPinned()) {\n    CHECK_EQ(seeds->ctx.device_type, kDGLCUDA)\n        << \"Expected seeds (\" << seeds->ctx << \")\"\n        << \" to be on the GPU when the graph is pinned.\";\n  } else if (hg->Context() != seeds->ctx) {\n    LOG(FATAL) << \"Expected seeds (\" << seeds->ctx << \")\"\n               << \" to have the same \"\n               << \"context as graph (\" << hg->Context() << \").\";\n  }\n  for (uint64_t i = 0; i < prob.size(); ++i) {\n    FloatArray p = prob[i];\n    CHECK_EQ(hg->Context(), p->ctx)\n        << \"Expected prob (\" << p->ctx << \")\"\n        << \" to have the same \"\n        << \"context as graph (\" << hg->Context() << \").\";\n    CHECK_FLOAT(p, \"probability\");\n    if (p.GetSize() != 0) {\n      CHECK_EQ(hg->IsPinned(), p.IsPinned())\n          << \"The prob array should have the same pinning status as the graph\";\n      CHECK_NDIM(p, 1, \"probability\");\n    }\n  }\n}\n\n};  // namespace\n\nstd::tuple<IdArray, IdArray, TypeArray> RandomWalk(\n    const HeteroGraphPtr 
hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob) {\n  CheckRandomWalkInputs(hg, seeds, metapath, prob);\n\n  TypeArray vtypes;\n  std::pair<IdArray, IdArray> result;\n  ATEN_XPU_SWITCH_CUDA(seeds->ctx.device_type, XPU, \"RandomWalk\", {\n    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {\n      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);\n      result = impl::RandomWalk<XPU, IdxType>(hg, seeds, metapath, prob);\n    });\n  });\n\n  return std::make_tuple(result.first, result.second, vtypes);\n}\n\nstd::tuple<IdArray, IdArray, TypeArray> RandomWalkWithRestart(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, double restart_prob) {\n  CheckRandomWalkInputs(hg, seeds, metapath, prob);\n  CHECK(restart_prob >= 0 && restart_prob < 1)\n      << \"restart probability must belong to [0, 1)\";\n\n  TypeArray vtypes;\n  std::pair<IdArray, IdArray> result;\n  ATEN_XPU_SWITCH_CUDA(seeds->ctx.device_type, XPU, \"RandomWalkWithRestart\", {\n    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {\n      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);\n      result = impl::RandomWalkWithRestart<XPU, IdxType>(\n          hg, seeds, metapath, prob, restart_prob);\n    });\n  });\n\n  return std::make_tuple(result.first, result.second, vtypes);\n}\n\nstd::tuple<IdArray, IdArray, TypeArray> RandomWalkWithStepwiseRestart(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, FloatArray restart_prob) {\n  CheckRandomWalkInputs(hg, seeds, metapath, prob);\n  // TODO(BarclayII): check the elements of restart probability\n\n  TypeArray vtypes;\n  std::pair<IdArray, IdArray> result;\n  ATEN_XPU_SWITCH_CUDA(\n      seeds->ctx.device_type, XPU, \"RandomWalkWithStepwiseRestart\", {\n        ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {\n          vtypes = 
impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);\n          result = impl::RandomWalkWithStepwiseRestart<XPU, IdxType>(\n              hg, seeds, metapath, prob, restart_prob);\n        });\n      });\n\n  return std::make_tuple(result.first, result.second, vtypes);\n}\n\nstd::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(\n    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,\n    const int64_t k) {\n  assert(\n      (src->ndim == 1) && (dst->ndim == 1) &&\n      (src->shape[0] % num_samples_per_node == 0) &&\n      (src->shape[0] == dst->shape[0]));\n  std::tuple<IdArray, IdArray, IdArray> result;\n\n  ATEN_XPU_SWITCH_CUDA((src->ctx).device_type, XPU, \"SelectPinSageNeighbors\", {\n    ATEN_ID_TYPE_SWITCH(src->dtype, IdxType, {\n      result = impl::SelectPinSageNeighbors<XPU, IdxType>(\n          src, dst, num_samples_per_node, k);\n    });\n  });\n\n  return result;\n}\n\n};  // namespace sampling\n\nDGL_REGISTER_GLOBAL(\"sampling.randomwalks._CAPI_DGLSamplingRandomWalk\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef hg = args[0];\n      IdArray seeds = args[1];\n      TypeArray metapath = args[2];\n      List<Value> prob = args[3];\n\n      const auto &prob_vec = ListValueToVector<FloatArray>(prob);\n\n      auto result = sampling::RandomWalk(hg.sptr(), seeds, metapath, prob_vec);\n      List<Value> ret;\n      ret.push_back(Value(MakeValue(std::get<0>(result))));\n      ret.push_back(Value(MakeValue(std::get<1>(result))));\n      ret.push_back(Value(MakeValue(std::get<2>(result))));\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling.pinsage._CAPI_DGLSamplingSelectPinSageNeighbors\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      IdArray src = args[0];\n      IdArray dst = args[1];\n      int64_t num_travelsals = static_cast<int64_t>(args[2]);\n      int64_t k = static_cast<int64_t>(args[3]);\n\n      auto result =\n          sampling::SelectPinSageNeighbors(src, dst, 
num_travelsals, k);\n\n      List<Value> ret;\n      ret.push_back(Value(MakeValue(std::get<0>(result))));\n      ret.push_back(Value(MakeValue(std::get<1>(result))));\n      ret.push_back(Value(MakeValue(std::get<2>(result))));\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\n    \"sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithRestart\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef hg = args[0];\n      IdArray seeds = args[1];\n      TypeArray metapath = args[2];\n      List<Value> prob = args[3];\n      double restart_prob = args[4];\n\n      const auto &prob_vec = ListValueToVector<FloatArray>(prob);\n\n      auto result = sampling::RandomWalkWithRestart(\n          hg.sptr(), seeds, metapath, prob_vec, restart_prob);\n      List<Value> ret;\n      ret.push_back(Value(MakeValue(std::get<0>(result))));\n      ret.push_back(Value(MakeValue(std::get<1>(result))));\n      ret.push_back(Value(MakeValue(std::get<2>(result))));\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\n    \"sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithStepwiseRestart\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef hg = args[0];\n      IdArray seeds = args[1];\n      TypeArray metapath = args[2];\n      List<Value> prob = args[3];\n      FloatArray restart_prob = args[4];\n\n      const auto &prob_vec = ListValueToVector<FloatArray>(prob);\n\n      auto result = sampling::RandomWalkWithStepwiseRestart(\n          hg.sptr(), seeds, metapath, prob_vec, restart_prob);\n      List<Value> ret;\n      ret.push_back(Value(MakeValue(std::get<0>(result))));\n      ret.push_back(Value(MakeValue(std::get<1>(result))));\n      ret.push_back(Value(MakeValue(std::get<2>(result))));\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\"sampling.randomwalks._CAPI_DGLSamplingPackTraces\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      IdArray vids = args[0];\n      TypeArray vtypes = args[1];\n\n      IdArray concat_vids, 
concat_vtypes, lengths, offsets;\n      std::tie(concat_vids, lengths, offsets) = Pack(vids, -1);\n      std::tie(concat_vtypes, std::ignore) = ConcatSlices(vtypes, lengths);\n\n      List<Value> ret;\n      ret.push_back(Value(MakeValue(concat_vids)));\n      ret.push_back(Value(MakeValue(concat_vtypes)));\n      ret.push_back(Value(MakeValue(lengths)));\n      ret.push_back(Value(MakeValue(offsets)));\n      *rv = ret;\n    });\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/randomwalks_cpu.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/sampler/generic_randomwalk_cpu.h\n * @brief DGL sampler - templated implementation definition of random walks on\n * CPU\n */\n\n#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_\n#define DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include <tuple>\n#include <utility>\n\n#include \"randomwalks_impl.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace sampling {\n\nnamespace impl {\n\nnamespace {\n\n/**\n * @brief Generic Random Walk.\n * @param seeds A 1D array of seed nodes, with the type the source type of the\n * first edge type in the metapath.\n * @param max_num_steps The maximum number of steps of a random walk path.\n * @param step The random walk step function with type \\c StepFunc.\n * @param max_nodes Throws an error if one of the values in \\c seeds exceeds\n * this argument.\n * @return A 2D array of shape (len(seeds), max_num_steps + 1) with node IDs.\n * @note The graph itself should be bounded in the closure of \\c step.\n */\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::pair<IdArray, IdArray> GenericRandomWalk(\n    const IdArray seeds, int64_t max_num_steps, StepFunc<IdxType> step,\n    int64_t max_nodes) {\n  int64_t num_seeds = seeds->shape[0];\n  int64_t trace_length = max_num_steps + 1;\n  IdArray traces =\n      IdArray::Empty({num_seeds, trace_length}, seeds->dtype, seeds->ctx);\n  IdArray eids =\n      IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, seeds->ctx);\n\n  const IdxType *seed_data = seeds.Ptr<IdxType>();\n  IdxType *traces_data = traces.Ptr<IdxType>();\n  IdxType *eids_data = eids.Ptr<IdxType>();\n\n  runtime::parallel_for(0, num_seeds, [&](size_t seed_begin, size_t seed_end) {\n    for (auto seed_id = seed_begin; seed_id < seed_end; seed_id++) {\n      int64_t i;\n      dgl_id_t 
curr = seed_data[seed_id];\n      traces_data[seed_id * trace_length] = curr;\n\n      CHECK_LT(curr, max_nodes)\n          << \"Seed node ID exceeds the maximum number of nodes.\";\n\n      for (i = 0; i < max_num_steps; ++i) {\n        const auto &succ = step(traces_data + seed_id * trace_length, curr, i);\n        traces_data[seed_id * trace_length + i + 1] = curr = std::get<0>(succ);\n        eids_data[seed_id * max_num_steps + i] = std::get<1>(succ);\n        if (std::get<2>(succ)) break;\n      }\n\n      for (; i < max_num_steps; ++i) {\n        traces_data[seed_id * trace_length + i + 1] = -1;\n        eids_data[seed_id * max_num_steps + i] = -1;\n      }\n    }\n  });\n\n  return std::make_pair(traces, eids);\n}\n\n};  // namespace\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};  // namespace dgl\n\n#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_\n"
  },
  {
    "path": "src/graph/sampling/randomwalks/randomwalks_impl.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/sampling/randomwalks_impl.h\n * @brief DGL sampler - templated implementation definition of random walks\n */\n\n#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_\n#define DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n\n#include <functional>\n#include <tuple>\n#include <utility>\n#include <vector>\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace sampling {\n\nnamespace impl {\n\n/**\n * @brief Random walk step function\n */\ntemplate <typename IdxType>\nusing StepFunc = std::function<\n    //        ID        Edge ID   terminate?\n    std::tuple<dgl_id_t, dgl_id_t, bool>(\n        IdxType *,  // node IDs generated so far\n        dgl_id_t,   // last node ID\n        int64_t)>;  // # of steps\n\n/**\n * @brief Get the node types traversed by the metapath.\n * @return A 1D array of shape (len(metapath) + 1,) with node type IDs.\n */\ntemplate <DGLDeviceType XPU, typename IdxType>\nTypeArray GetNodeTypesFromMetapath(\n    const HeteroGraphPtr hg, const TypeArray metapath);\n\n/**\n * @brief Metapath-based random walk.\n * @param hg The heterograph.\n * @param seeds A 1D array of seed nodes, with the type the source type of the\n * first edge type in the metapath.\n * @param metapath A 1D array of edge types\n * representing the metapath.\n * @param prob A vector of 1D float arrays,\n * indicating the transition probability of each edge by edge type.  An empty\n * float array assumes uniform transition.\n * @return A 2D array of shape\n * (len(seeds), len(metapath) + 1) with node IDs.  The paths that terminated\n * early are padded with -1. A 2D array of shape (len(seeds), len(metapath))\n * with edge IDs.  The paths that terminated early are padded with -1. 
\\note\n * This function should be called together with GetNodeTypesFromMetapath to\n *       determine the node type of each node in the random walk traces.\n */\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::pair<IdArray, IdArray> RandomWalk(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob);\n\n/**\n * @brief Metapath-based random walk with restart probability.\n * @param hg The heterograph.\n * @param seeds A 1D array of seed nodes, with the type the source type of the\n * first edge type in the metapath.\n * @param metapath A 1D array of edge types\n * representing the metapath.\n * @param prob A vector of 1D float arrays,\n * indicating the transition probability of each edge by edge type.  An empty\n * float array assumes uniform transition.\n * @param restart_prob Restart\n * probability\n * @return A 2D array of shape (len(seeds), len(metapath) + 1) with\n * node IDs.  The paths that terminated early are padded with -1. A 2D array of\n * shape (len(seeds), len(metapath)) with edge IDs.  The paths that terminated\n * early are padded with -1. \\note This function should be called together with\n * GetNodeTypesFromMetapath to determine the node type of each node in the\n * random walk traces.\n */\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::pair<IdArray, IdArray> RandomWalkWithRestart(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, double restart_prob);\n\n/**\n * @brief Metapath-based random walk with stepwise restart probability.  
Useful\n *        for PinSAGE-like models.\n * @param hg The heterograph.\n * @param seeds A 1D array of seed nodes, with the type the source type of the\n * first edge type in the metapath.\n * @param metapath A 1D array of edge types\n * representing the metapath.\n * @param prob A vector of 1D float arrays,\n * indicating the transition probability of each edge by edge type.  An empty\n * float array assumes uniform transition.\n * @param restart_prob Restart\n * probability array which has the same number of elements as \\c metapath,\n * indicating the probability to terminate after transition.\n * @return A 2D array\n * of shape (len(seeds), len(metapath) + 1) with node IDs.  The paths that\n * terminated early are padded with -1. A 2D array of shape (len(seeds),\n * len(metapath)) with edge IDs.  The paths that terminated early are padded\n * with -1. \\note This function should be called together with\n * GetNodeTypesFromMetapath to determine the node type of each node in the\n * random walk traces.\n */\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(\n    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,\n    const std::vector<FloatArray> &prob, FloatArray restart_prob);\n\ntemplate <DGLDeviceType XPU, typename IdxType>\nstd::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(\n    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,\n    const int64_t k);\n\n};  // namespace impl\n\n};  // namespace sampling\n\n};  // namespace dgl\n\n#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_\n"
  },
  {
    "path": "src/graph/serialize/dglgraph_data.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/serialize/dglgraph_data.h\n * @brief Graph serialization header\n */\n#ifndef DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_\n#define DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_\n\n#include <dgl/array.h>\n#include <dgl/graph.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/object.h>\n#include <dmlc/io.h>\n#include <dmlc/type_traits.h>\n\n#include <algorithm>\n#include <iostream>\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"../../c_api_common.h\"\n\nusing dgl::ImmutableGraph;\nusing dgl::runtime::NDArray;\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace serialize {\n\ntypedef std::pair<std::string, NDArray> NamedTensor;\n\nclass GraphDataObject : public runtime::Object {\n public:\n  ImmutableGraphPtr gptr;\n  std::vector<NamedTensor> node_tensors;\n  std::vector<NamedTensor> edge_tensors;\n  static constexpr const char *_type_key = \"graph_serialize.GraphData\";\n\n  void SetData(\n      ImmutableGraphPtr gptr, Map<std::string, Value> node_tensors,\n      Map<std::string, Value> edge_tensors);\n\n  void Save(dmlc::Stream *fs) const;\n\n  bool Load(dmlc::Stream *fs);\n\n  DGL_DECLARE_OBJECT_TYPE_INFO(GraphDataObject, runtime::Object);\n};\n\nclass GraphData : public runtime::ObjectRef {\n public:\n  DGL_DEFINE_OBJECT_REF_METHODS(GraphData, runtime::ObjectRef, GraphDataObject);\n\n  /** @brief create a new GraphData reference */\n  static GraphData Create() {\n    return GraphData(std::make_shared<GraphDataObject>());\n  }\n};\n\nImmutableGraphPtr ToImmutableGraph(GraphPtr g);\n\n}  // namespace serialize\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_\n"
  },
  {
    "path": "src/graph/serialize/dglgraph_serialize.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/serialize/graph_serialize.cc\n * @brief Graph serialization implementation\n *\n * The storage structure is\n * {\n *   // MetaData Section\n *   uint64_t kDGLSerializeMagic\n *   uint64_t kVersion\n *   uint64_t GraphType\n *   ** Reserved Area till 4kB **\n *\n *   dgl_id_t num_graphs\n *   vector<dgl_id_t> graph_indices (start address of each graph)\n *   vector<dgl_id_t> nodes_num_list (list of number of nodes for each graph)\n *   vector<dgl_id_t> edges_num_list (list of number of edges for each graph)\n *\n *   vector<GraphData> graph_datas;\n *\n * }\n *\n * Storage of GraphData is\n * {\n *   // Everything uses in csr\n *   NDArray indptr\n *   NDArray indices\n *   NDArray edge_ids\n *   vector<pair<string, NDArray>> node_tensors;\n *   vector<pair<string, NDArray>> edge_tensors;\n * }\n *\n */\n#include <dgl/aten/coo.h>\n#include <dgl/graph_op.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/object.h>\n#include <dmlc/io.h>\n#include <dmlc/type_traits.h>\n\n#include <algorithm>\n#include <iostream>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"graph_serialize.h\"\n\nusing namespace dgl::runtime;\n\nusing dgl::COO;\nusing dgl::COOPtr;\nusing dgl::ImmutableGraph;\nusing dgl::runtime::NDArray;\nusing dgl::serialize::GraphData;\nusing dgl::serialize::GraphDataObject;\nusing dmlc::SeekStream;\nusing std::vector;\n\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);\n}\n\nnamespace dgl {\nnamespace serialize {\n\nbool SaveDGLGraphs(\n    std::string filename, List<GraphData> graph_data,\n    std::vector<NamedTensor> labels_list) {\n  auto fs = std::unique_ptr<SeekStream>(dynamic_cast<SeekStream *>(\n      SeekStream::Create(filename.c_str(), \"w\", true)));\n  CHECK(fs) << \"File name \" << filename << \" is not a valid local file name\";\n\n  // Write DGL MetaData\n  const uint64_t kVersion 
= 1;\n  fs->Write(kDGLSerializeMagic);\n  fs->Write(kVersion);\n  fs->Write(GraphType::kImmutableGraph);\n  fs->Seek(4096);\n\n  // Write Graph Meta Data\n  dgl_id_t num_graph = graph_data.size();\n\n  std::vector<dgl_id_t> graph_indices(num_graph);\n  std::vector<int64_t> nodes_num_list(num_graph);\n  std::vector<int64_t> edges_num_list(num_graph);\n\n  for (uint64_t i = 0; i < num_graph; ++i) {\n    nodes_num_list[i] = graph_data[i]->gptr->NumVertices();\n    edges_num_list[i] = graph_data[i]->gptr->NumEdges();\n  }\n  // Reserve spaces for graph indices\n  fs->Write(num_graph);\n  dgl_id_t indices_start_ptr = fs->Tell();\n  fs->Write(graph_indices);\n  fs->Write(nodes_num_list);\n  fs->Write(edges_num_list);\n  fs->Write(labels_list);\n\n  // Write GraphData\n  for (uint64_t i = 0; i < num_graph; ++i) {\n    graph_indices[i] = fs->Tell();\n    GraphDataObject gdata = *graph_data[i].as<GraphDataObject>();\n    fs->Write(gdata);\n  }\n\n  fs->Seek(indices_start_ptr);\n  fs->Write(graph_indices);\n\n  return true;\n}\n\nStorageMetaData LoadDGLGraphs(\n    const std::string &filename, std::vector<dgl_id_t> idx_list,\n    bool onlyMeta) {\n  auto fs = std::unique_ptr<SeekStream>(\n      SeekStream::CreateForRead(filename.c_str(), true));\n  CHECK(fs) << \"Filename is invalid\";\n  // Read DGL MetaData\n  uint64_t magicNum, graphType, version;\n  fs->Read(&magicNum);\n  fs->Read(&version);\n  fs->Read(&graphType);\n  fs->Seek(4096);\n\n  CHECK_EQ(magicNum, kDGLSerializeMagic) << \"Invalid DGL files\";\n  CHECK_EQ(version, 1) << \"Invalid DGL files\";\n  StorageMetaData metadata = StorageMetaData::Create();\n  // Read Graph MetaData\n  dgl_id_t num_graph;\n  CHECK(fs->Read(&num_graph)) << \"Invalid num of graph\";\n  std::vector<dgl_id_t> graph_indices;\n  std::vector<int64_t> nodes_num_list;\n  std::vector<int64_t> edges_num_list;\n  std::vector<NamedTensor> labels_list;\n\n  CHECK(fs->Read(&graph_indices)) << \"Invalid graph indices\";\n  
CHECK(fs->Read(&nodes_num_list)) << \"Invalid node num list\";\n  CHECK(fs->Read(&edges_num_list)) << \"Invalid edge num list\";\n  CHECK(fs->Read(&labels_list)) << \"Invalid label list\";\n\n  metadata->SetMetaData(num_graph, nodes_num_list, edges_num_list, labels_list);\n\n  std::vector<GraphData> gdata_refs;\n\n  // Early Return\n  if (onlyMeta) {\n    return metadata;\n  }\n\n  if (idx_list.empty()) {\n    // Read All Graphs\n    gdata_refs.reserve(num_graph);\n    for (uint64_t i = 0; i < num_graph; ++i) {\n      GraphData gdata = GraphData::Create();\n      GraphDataObject *gdata_ptr =\n          const_cast<GraphDataObject *>(gdata.as<GraphDataObject>());\n      fs->Read(gdata_ptr);\n      gdata_refs.push_back(gdata);\n    }\n  } else {\n    // Read Selected Graphss\n    gdata_refs.reserve(idx_list.size());\n    // Would be better if idx_list is sorted. However the returned the graphs\n    // should be the same order as the idx_list\n    for (uint64_t i = 0; i < idx_list.size(); ++i) {\n      auto gid = idx_list[i];\n      CHECK((gid < graph_indices.size()) && (gid >= 0))\n          << \"ID \" << gid\n          << \" in idx_list is out of bound. 
Please check your idx_list.\";\n      fs->Seek(graph_indices[gid]);\n      GraphData gdata = GraphData::Create();\n      GraphDataObject *gdata_ptr =\n          const_cast<GraphDataObject *>(gdata.as<GraphDataObject>());\n      fs->Read(gdata_ptr);\n      gdata_refs.push_back(gdata);\n    }\n  }\n  metadata->SetGraphData(gdata_refs);\n  return metadata;\n}\n\nvoid GraphDataObject::SetData(\n    ImmutableGraphPtr gptr, Map<std::string, Value> node_tensors,\n    Map<std::string, Value> edge_tensors) {\n  this->gptr = gptr;\n\n  for (auto kv : node_tensors) {\n    std::string name = kv.first;\n    Value v = kv.second;\n    NDArray ndarray = static_cast<NDArray>(v->data);\n    this->node_tensors.emplace_back(name, ndarray);\n  }\n  for (auto kv : edge_tensors) {\n    std::string &name = kv.first;\n    Value v = kv.second;\n    const NDArray &ndarray = static_cast<NDArray>(v->data);\n    this->edge_tensors.emplace_back(name, ndarray);\n  }\n}\n\nvoid GraphDataObject::Save(dmlc::Stream *fs) const {\n  // Using in csr for storage\n  const CSRPtr g_csr = this->gptr->GetInCSR();\n  fs->Write(g_csr->indptr());\n  fs->Write(g_csr->indices());\n  fs->Write(g_csr->edge_ids());\n  fs->Write(node_tensors);\n  fs->Write(edge_tensors);\n}\n\nbool GraphDataObject::Load(dmlc::Stream *fs) {\n  NDArray indptr, indices, edge_ids;\n  fs->Read(&indptr);\n  fs->Read(&indices);\n  fs->Read(&edge_ids);\n  this->gptr = ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, \"in\");\n\n  fs->Read(&this->node_tensors);\n  fs->Read(&this->edge_tensors);\n  return true;\n}\n\nImmutableGraphPtr BatchLoadedGraphs(std::vector<GraphData> gdata_list) {\n  std::vector<GraphPtr> gptrs;\n  gptrs.reserve(gdata_list.size());\n  for (auto gdata : gdata_list) {\n    gptrs.push_back(static_cast<GraphPtr>(gdata->gptr));\n  }\n  ImmutableGraphPtr imGPtr =\n      std::dynamic_pointer_cast<ImmutableGraph>(GraphOp::DisjointUnion(gptrs));\n  return imGPtr;\n}\n\nImmutableGraphPtr ToImmutableGraph(GraphPtr g) {\n  
ImmutableGraphPtr imgr = std::dynamic_pointer_cast<ImmutableGraph>(g);\n  if (imgr) {\n    return imgr;\n  } else {\n    MutableGraphPtr mgr = std::dynamic_pointer_cast<Graph>(g);\n    CHECK(mgr) << \"Invalid Graph Pointer\";\n    EdgeArray earray = mgr->Edges(\"eid\");\n    IdArray srcs_array = earray.src;\n    IdArray dsts_array = earray.dst;\n\n    bool row_sorted, col_sorted;\n    std::tie(row_sorted, col_sorted) = COOIsSorted(aten::COOMatrix(\n        mgr->NumVertices(), mgr->NumVertices(), srcs_array, dsts_array));\n\n    ImmutableGraphPtr imgptr = ImmutableGraph::CreateFromCOO(\n        mgr->NumVertices(), srcs_array, dsts_array, row_sorted, col_sorted);\n    return imgptr;\n  }\n}\n\nvoid StorageMetaDataObject::SetMetaData(\n    dgl_id_t num_graph, std::vector<int64_t> nodes_num_list,\n    std::vector<int64_t> edges_num_list, std::vector<NamedTensor> labels_list) {\n  this->num_graph = num_graph;\n  this->nodes_num_list = Value(MakeValue(aten::VecToIdArray(nodes_num_list)));\n  this->edges_num_list = Value(MakeValue(aten::VecToIdArray(edges_num_list)));\n  for (auto kv : labels_list) {\n    this->labels_list.Set(kv.first, Value(MakeValue(kv.second)));\n  }\n}\n\nvoid StorageMetaDataObject::SetGraphData(std::vector<GraphData> gdata) {\n  this->graph_data = List<GraphData>(gdata);\n}\n\n}  // namespace serialize\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/serialize/dglstream.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/serialize/dglstream.h\n * @brief Graph serialization header\n */\n#ifndef DGL_GRAPH_SERIALIZE_DGLSTREAM_H_\n#define DGL_GRAPH_SERIALIZE_DGLSTREAM_H_\n\n#include <dgl/aten/spmat.h>\n#include <dmlc/io.h>\n#include <dmlc/type_traits.h>\n\n#include <memory>\n\nnamespace dgl {\nnamespace serialize {\n\n/**\n * @brief DGLStream counts the bytes that already written into the\n * underlying stream.\n */\nclass DGLStream : public dmlc::Stream {\n public:\n  /** @brief create a new DGLStream instance */\n  static DGLStream *Create(\n      const char *uri, const char *const flag, bool allow_null,\n      dgl_format_code_t formats) {\n    return new DGLStream(uri, flag, allow_null, formats);\n  }\n\n  size_t Read(void *ptr, size_t size) override {\n    return strm_->Read(ptr, size);\n  }\n\n  void Write(const void *ptr, size_t size) override {\n    count_ += size;\n    strm_->Write(ptr, size);\n  }\n\n  using dmlc::Stream::Read;\n  using dmlc::Stream::Write;\n\n  bool IsValid() { return strm_.get(); }\n\n  uint64_t Count() const { return count_; }\n\n  uint64_t FormatsToSave() const { return formats_to_save_; }\n\n private:\n  DGLStream(\n      const char *uri, const char *const flag, bool allow_null,\n      dgl_format_code_t formats)\n      : strm_(dmlc::Stream::Create(uri, flag, allow_null)),\n        formats_to_save_(formats) {}\n  // stream for serialization\n  std::unique_ptr<dmlc::Stream> strm_;\n  // size of already written to stream\n  uint64_t count_ = 0;\n  // formats to use when saving graph\n  const dgl_format_code_t formats_to_save_ = ANY_CODE;\n};\n}  // namespace serialize\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_SERIALIZE_DGLSTREAM_H_\n"
  },
  {
    "path": "src/graph/serialize/graph_serialize.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/serialize/graph_serialize.cc\n * @brief Graph serialization implementation\n *\n * The storage structure is\n * {\n *   // MetaData Section\n *   uint64_t kDGLSerializeMagic\n *   uint64_t kVersion\n *   uint64_t GraphType\n *   ** Reserved Area till 4kB **\n *\n *   dgl_id_t num_graphs\n *   vector<dgl_id_t> graph_indices (start address of each graph)\n *   vector<dgl_id_t> nodes_num_list (list of number of nodes for each graph)\n *   vector<dgl_id_t> edges_num_list (list of number of edges for each graph)\n *\n *   vector<GraphData> graph_datas;\n *\n * }\n *\n * Storage of GraphData is\n * {\n *   // Everything uses in csr\n *   NDArray indptr\n *   NDArray indices\n *   NDArray edge_ids\n *   vector<pair<string, NDArray>> node_tensors;\n *   vector<pair<string, NDArray>> edge_tensors;\n * }\n *\n */\n#include \"graph_serialize.h\"\n\n#include <dgl/graph_op.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/object.h>\n#include <dmlc/io.h>\n#include <dmlc/logging.h>\n#include <dmlc/type_traits.h>\n\n#include <algorithm>\n#include <iostream>\n#include <string>\n#include <utility>\n#include <vector>\n\nusing namespace dgl::runtime;\n\nusing dgl::COO;\nusing dgl::COOPtr;\nusing dgl::ImmutableGraph;\nusing dgl::runtime::NDArray;\nusing dgl::serialize::GraphData;\nusing dgl::serialize::GraphDataObject;\nusing dmlc::SeekStream;\nusing dmlc::Stream;\nusing std::vector;\n\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);\n}\n\nnamespace dgl {\nnamespace serialize {\n\nDGL_REGISTER_GLOBAL(\"data.graph_serialize._CAPI_MakeGraphData\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphRef gptr = args[0];\n      ImmutableGraphPtr imGPtr = ToImmutableGraph(gptr.sptr());\n      Map<std::string, Value> node_tensors = args[1];\n      Map<std::string, Value> edge_tensors = args[2];\n      GraphData gd = 
GraphData::Create();\n      gd->SetData(imGPtr, node_tensors, edge_tensors);\n      *rv = gd;\n    });\n\nDGL_REGISTER_GLOBAL(\"data.graph_serialize._CAPI_SaveDGLGraphs_V0\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      std::string filename = args[0];\n      List<GraphData> graph_data = args[1];\n      Map<std::string, Value> labels = args[2];\n      std::vector<NamedTensor> labels_list;\n      for (auto kv : labels) {\n        std::string name = kv.first;\n        Value v = kv.second;\n        NDArray ndarray = static_cast<NDArray>(v->data);\n        labels_list.emplace_back(name, ndarray);\n      }\n      SaveDGLGraphs(filename, graph_data, labels_list);\n    });\n\nDGL_REGISTER_GLOBAL(\"data.graph_serialize._CAPI_GDataGraphHandle\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphData gdata = args[0];\n      *rv = gdata->gptr;\n    });\n\nDGL_REGISTER_GLOBAL(\"data.graph_serialize._CAPI_GDataNodeTensors\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphData gdata = args[0];\n      Map<std::string, Value> rvmap;\n      for (auto kv : gdata->node_tensors) {\n        rvmap.Set(kv.first, Value(MakeValue(kv.second)));\n      }\n      *rv = rvmap;\n    });\n\nDGL_REGISTER_GLOBAL(\"data.graph_serialize._CAPI_GDataEdgeTensors\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphData gdata = args[0];\n      Map<std::string, Value> rvmap;\n      for (auto kv : gdata->edge_tensors) {\n        rvmap.Set(kv.first, Value(MakeValue(kv.second)));\n      }\n      *rv = rvmap;\n    });\n\nuint64_t GetFileVersion(const std::string &filename) {\n  auto fs = std::unique_ptr<SeekStream>(\n      SeekStream::CreateForRead(filename.c_str(), false));\n  CHECK(fs) << \"File \" << filename << \" not found\";\n  uint64_t magicNum, version;\n  fs->Read(&magicNum);\n  fs->Read(&version);\n  CHECK_EQ(magicNum, kDGLSerializeMagic) << \"Invalid DGL files\";\n  return 
version;\n}\n\nDGL_REGISTER_GLOBAL(\"data.graph_serialize._CAPI_GetFileVersion\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      std::string filename = args[0];\n      *rv = static_cast<int64_t>(GetFileVersion(filename));\n    });\n\nDGL_REGISTER_GLOBAL(\"data.graph_serialize._CAPI_LoadGraphFiles_V1\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      std::string filename = args[0];\n      List<Value> idxs = args[1];\n      bool onlyMeta = args[2];\n      auto idx_list = ListValueToVector<dgl_id_t>(idxs);\n      *rv = LoadDGLGraphs(filename, idx_list, onlyMeta);\n    });\n\nDGL_REGISTER_GLOBAL(\"data.graph_serialize._CAPI_DGLAsHeteroGraph\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      GraphRef g = args[0];\n      ImmutableGraphPtr ig =\n          std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());\n      CHECK(ig) << \"graph is not readonly\";\n      *rv = HeteroGraphRef(ig->AsHeteroGraph());\n    });\n\nDGL_REGISTER_GLOBAL(\"data.graph_serialize._CAPI_LoadGraphFiles_V2\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      std::string filename = args[0];\n      List<Value> idxs = args[1];\n      auto idx_list = ListValueToVector<dgl_id_t>(idxs);\n      *rv = List<HeteroGraphData>(LoadHeteroGraphs(filename, idx_list));\n    });\n\n}  // namespace serialize\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/serialize/graph_serialize.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/serialize/graph_serialize.h\n * @brief Graph serialization header\n */\n#ifndef DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_\n#define DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_\n\n#include <dgl/array.h>\n#include <dgl/graph.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/object.h>\n#include <dmlc/io.h>\n#include <dmlc/type_traits.h>\n\n#include <algorithm>\n#include <iostream>\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"../../c_api_common.h\"\n#include \"dglgraph_data.h\"\n#include \"heterograph_data.h\"\n\nusing dgl::ImmutableGraph;\nusing dgl::runtime::NDArray;\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace serialize {\n\nenum GraphType : uint64_t {\n  kMutableGraph = 0ull,\n  kImmutableGraph = 1ull,\n  kHeteroGraph = 2ull\n};\n\nconstexpr uint64_t kDGLSerializeMagic = 0xDD2E4FF046B4A13F;\n\nclass StorageMetaDataObject : public runtime::Object {\n public:\n  // For saving DGLGraph\n  dgl_id_t num_graph;\n  Value nodes_num_list;\n  Value edges_num_list;\n  Map<std::string, Value> labels_list;\n  List<GraphData> graph_data;\n\n  static constexpr const char *_type_key = \"graph_serialize.StorageMetaData\";\n\n  void SetMetaData(\n      dgl_id_t num_graph, std::vector<int64_t> nodes_num_list,\n      std::vector<int64_t> edges_num_list,\n      std::vector<NamedTensor> labels_list);\n\n  void SetGraphData(std::vector<GraphData> gdata);\n\n  void VisitAttrs(AttrVisitor *v) final {\n    v->Visit(\"num_graph\", &num_graph);\n    v->Visit(\"nodes_num_list\", &nodes_num_list);\n    v->Visit(\"edges_num_list\", &edges_num_list);\n    v->Visit(\"labels\", &labels_list);\n    v->Visit(\"graph_data\", &graph_data);\n  }\n\n  DGL_DECLARE_OBJECT_TYPE_INFO(StorageMetaDataObject, runtime::Object);\n};\n\nclass StorageMetaData : public 
runtime::ObjectRef {\n public:\n  DGL_DEFINE_OBJECT_REF_METHODS(\n      StorageMetaData, runtime::ObjectRef, StorageMetaDataObject);\n\n  /** @brief create a new StorageMetaData reference */\n  static StorageMetaData Create() {\n    return StorageMetaData(std::make_shared<StorageMetaDataObject>());\n  }\n};\n\nStorageMetaData LoadDGLGraphFiles(\n    const std::string &filename, std::vector<dgl_id_t> idx_list, bool onlyMeta);\n\nStorageMetaData LoadDGLGraphs(\n    const std::string &filename, std::vector<dgl_id_t> idx_list, bool onlyMeta);\n\nbool SaveDGLGraphs(\n    std::string filename, List<GraphData> graph_data,\n    std::vector<NamedTensor> labels_list);\n\nstd::vector<HeteroGraphData> LoadHeteroGraphs(\n    const std::string &filename, std::vector<dgl_id_t> idx_list);\n\nImmutableGraphPtr ToImmutableGraph(GraphPtr g);\n\n}  // namespace serialize\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_\n"
  },
  {
    "path": "src/graph/serialize/heterograph_data.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/serialize/heterograph_data.h\n * @brief Graph serialization header\n */\n#ifndef DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_\n#define DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_\n\n#include <dgl/array.h>\n#include <dgl/graph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/object.h>\n#include <dmlc/io.h>\n#include <dmlc/type_traits.h>\n\n#include <algorithm>\n#include <iostream>\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"../../c_api_common.h\"\n#include \"../heterograph.h\"\n\nusing dgl::runtime::NDArray;\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace serialize {\n\ntypedef std::pair<std::string, NDArray> NamedTensor;\nclass HeteroGraphDataObject : public runtime::Object {\n public:\n  std::shared_ptr<HeteroGraph> gptr;\n  std::vector<std::vector<NamedTensor>> node_tensors;\n  std::vector<std::vector<NamedTensor>> edge_tensors;\n  std::vector<std::string> etype_names;\n  std::vector<std::string> ntype_names;\n\n  static constexpr const char *_type_key =\n      \"heterograph_serialize.HeteroGraphData\";\n\n  HeteroGraphDataObject() {}\n\n  HeteroGraphDataObject(\n      HeteroGraphPtr gptr, List<Map<std::string, Value>> ndata,\n      List<Map<std::string, Value>> edata, List<Value> ntype_names,\n      List<Value> etype_names) {\n    this->gptr = std::dynamic_pointer_cast<HeteroGraph>(gptr);\n    CHECK_NOTNULL(this->gptr);\n    for (auto nd_dict : ndata) {\n      node_tensors.emplace_back();\n      for (auto kv : nd_dict) {\n        auto last = &node_tensors.back();\n        NDArray ndarray = kv.second->data;\n        last->emplace_back(kv.first, ndarray);\n      }\n    }\n    for (auto nd_dict : edata) {\n      edge_tensors.emplace_back();\n      for (auto kv : nd_dict) {\n        auto last = &edge_tensors.back();\n        NDArray ndarray = kv.second->data;\n  
      last->emplace_back(kv.first, ndarray);\n      }\n    }\n\n    this->ntype_names = ListValueToVector<std::string>(ntype_names);\n    this->etype_names = ListValueToVector<std::string>(etype_names);\n  }\n\n  void Save(dmlc::Stream *fs) const {\n    fs->Write(gptr);\n    fs->Write(node_tensors);\n    fs->Write(edge_tensors);\n    fs->Write(ntype_names);\n    fs->Write(etype_names);\n  }\n\n  bool Load(dmlc::Stream *fs) {\n    fs->Read(&gptr);\n    fs->Read(&node_tensors);\n    fs->Read(&edge_tensors);\n    fs->Read(&ntype_names);\n    fs->Read(&etype_names);\n    return true;\n  }\n\n  DGL_DECLARE_OBJECT_TYPE_INFO(HeteroGraphDataObject, runtime::Object);\n};\n\nclass HeteroGraphData : public runtime::ObjectRef {\n public:\n  DGL_DEFINE_OBJECT_REF_METHODS(\n      HeteroGraphData, runtime::ObjectRef, HeteroGraphDataObject);\n\n  /** @brief create a new HeteroGraphData reference */\n  static HeteroGraphData Create(\n      HeteroGraphPtr gptr, List<Map<std::string, Value>> node_tensors,\n      List<Map<std::string, Value>> edge_tensors, List<Value> ntype_names,\n      List<Value> etype_names) {\n    return HeteroGraphData(std::make_shared<HeteroGraphDataObject>(\n        gptr, node_tensors, edge_tensors, ntype_names, etype_names));\n  }\n\n  /** @brief create an empty HeteroGraphData reference */\n  static HeteroGraphData Create() {\n    return HeteroGraphData(std::make_shared<HeteroGraphDataObject>());\n  }\n};\n}  // namespace serialize\n}  // namespace dgl\n\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, dgl::serialize::HeteroGraphDataObject, true);\n}\n\n#endif  // DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_\n"
  },
  {
    "path": "src/graph/serialize/heterograph_serialize.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/serialize/heterograph_serialize.cc\n * @brief DGLHeteroGraph serialization implementation\n *\n * The storage structure is\n * {\n *   // MetaData Section\n *   uint64_t kDGLSerializeMagic\n *   uint64_t kVersion = 2\n *   uint64_t GraphType = kDGLHeteroGraph\n *   dgl_id_t num_graphs\n *   ** Reserved Area till 4kB **\n *\n *   uint64_t gdata_start_pos (This stores the start position of graph_data,\n * which is used to skip label dict part if unnecessary)\n *   vector<pair<string, NDArray>> label_dict (To store the dict[str, NDArray])\n *\n *   vector<HeteroGraphData> graph_datas;\n *   vector<dgl_id_t> graph_indices (start address of each graph)\n *   uint64_t size_of_graph_indices_vector (Used to seek to graph_indices\n * vector)\n *\n * }\n *\n * Storage of HeteroGraphData is\n * {\n *   HeteroGraphPtr ptr;\n *   vector<vector<pair<string, NDArray>>> node_tensors;\n *   vector<vector<pair<string, NDArray>>> edge_tensors;\n *   vector<string> ntype_name;\n *   vector<string> etype_name;\n * }\n *\n */\n#include <dgl/graph_op.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/object.h>\n#include <dmlc/io.h>\n#include <dmlc/type_traits.h>\n\n#include <algorithm>\n#include <array>\n#include <iostream>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"../heterograph.h\"\n#include \"./dglstream.h\"\n#include \"./graph_serialize.h\"\n#include \"dmlc/memory_io.h\"\n\nnamespace dgl {\nnamespace serialize {\n\nusing namespace dgl::runtime;\nusing dmlc::SeekStream;\nusing dmlc::Stream;\nusing dmlc::io::FileSystem;\nusing dmlc::io::URI;\n\nbool SaveHeteroGraphs(\n    std::string filename, List<HeteroGraphData> hdata,\n    const std::vector<NamedTensor> &nd_list, dgl_format_code_t formats) {\n  auto fs = std::unique_ptr<DGLStream>(\n      DGLStream::Create(filename.c_str(), \"w\", false, formats));\n  CHECK(fs->IsValid()) << \"File 
name \" << filename << \" is not a valid name\";\n\n  // Write DGL MetaData\n  const uint64_t kVersion = 2;\n  std::array<char, 4096> meta_buffer;\n\n  // Write metadata into char buffer with size 4096\n  dmlc::MemoryFixedSizeStream meta_fs_(meta_buffer.data(), 4096);\n  auto meta_fs = static_cast<Stream *>(&meta_fs_);\n  meta_fs->Write(kDGLSerializeMagic);\n  meta_fs->Write(kVersion);\n  meta_fs->Write(GraphType::kHeteroGraph);\n  uint64_t num_graph = hdata.size();\n  meta_fs->Write(num_graph);\n\n  // Write metadata into files\n  fs->Write(meta_buffer.data(), 4096);\n\n  // Calculate label dict binary size\n  std::string labels_blob;\n  dmlc::MemoryStringStream label_fs_(&labels_blob);\n  auto label_fs = static_cast<Stream *>(&label_fs_);\n  label_fs->Write(nd_list);\n\n  uint64_t gdata_start_pos =\n      fs->Count() + sizeof(uint64_t) + labels_blob.size();\n\n  // Write start position of gdata, which can be skipped when only reading gdata\n  // And label dict\n  fs->Write(gdata_start_pos);\n  fs->Write(labels_blob.c_str(), labels_blob.size());\n\n  std::vector<uint64_t> graph_indices(num_graph);\n\n  // Write HeteroGraphData\n  for (uint64_t i = 0; i < num_graph; ++i) {\n    graph_indices[i] = fs->Count();\n    auto gdata = hdata[i].sptr();\n    fs->Write(gdata);\n  }\n\n  // Write indptr into string to count size\n  std::string indptr_blob;\n  dmlc::MemoryStringStream indptr_fs_(&indptr_blob);\n  auto indptr_fs = static_cast<Stream *>(&indptr_fs_);\n  indptr_fs->Write(graph_indices);\n\n  uint64_t indptr_buffer_size = indptr_blob.size();\n  fs->Write(indptr_blob);\n  fs->Write(indptr_buffer_size);\n\n  return true;\n}\n\nstd::vector<HeteroGraphData> LoadHeteroGraphs(\n    const std::string &filename, std::vector<dgl_id_t> idx_list) {\n  auto fs = std::unique_ptr<SeekStream>(\n      SeekStream::CreateForRead(filename.c_str(), false));\n  CHECK(fs) << \"File name \" << filename << \" is not a valid name\";\n  // Read DGL MetaData\n  uint64_t magicNum, graphType, 
version, num_graph;\n  fs->Read(&magicNum);\n  fs->Read(&version);\n  fs->Read(&graphType);\n  CHECK(fs->Read(&num_graph)) << \"Invalid num of graph\";\n  fs->Seek(4096);\n\n  CHECK_EQ(magicNum, kDGLSerializeMagic) << \"Invalid DGL files\";\n  CHECK_EQ(version, 2) << \"Invalid GraphType\";\n  CHECK_EQ(graphType, GraphType::kHeteroGraph) << \"Invalid GraphType\";\n\n  uint64_t gdata_start_pos;\n  fs->Read(&gdata_start_pos);\n  // Skip labels part\n  fs->Seek(gdata_start_pos);\n\n  std::vector<HeteroGraphData> gdata_refs;\n  if (idx_list.empty()) {\n    // Read All Graphs\n    gdata_refs.reserve(num_graph);\n    for (uint64_t i = 0; i < num_graph; ++i) {\n      HeteroGraphData gdata = HeteroGraphData::Create();\n      auto hetero_data = gdata.sptr();\n      fs->Read(&hetero_data);\n      gdata_refs.push_back(gdata);\n    }\n  } else {\n    uint64_t gdata_start_pos = fs->Tell();\n    // Read Selected Graphs\n    gdata_refs.reserve(idx_list.size());\n    URI uri(filename.c_str());\n    uint64_t filesize = FileSystem::GetInstance(uri)->GetPathInfo(uri).size;\n    fs->Seek(filesize - sizeof(uint64_t));\n    uint64_t indptr_buffer_size;\n    fs->Read(&indptr_buffer_size);\n\n    std::vector<uint64_t> graph_indices(num_graph);\n    fs->Seek(filesize - sizeof(uint64_t) - indptr_buffer_size);\n    fs->Read(&graph_indices);\n\n    fs->Seek(gdata_start_pos);\n    // Would be better if idx_list is sorted. However, the returned graphs\n    // should be in the same order as the idx_list\n    for (uint64_t i = 0; i < idx_list.size(); ++i) {\n      auto gid = idx_list[i];\n      CHECK((gid < graph_indices.size()) && (gid >= 0))\n          << \"ID \" << gid\n          << \" in idx_list is out of bound. 
Please check your idx_list.\";\n      fs->Seek(graph_indices[gid]);\n      HeteroGraphData gdata = HeteroGraphData::Create();\n      auto hetero_data = gdata.sptr();\n      fs->Read(&hetero_data);\n      gdata_refs.push_back(gdata);\n    }\n  }\n\n  return gdata_refs;\n}\n\nstd::vector<NamedTensor> LoadLabels_V2(const std::string &filename) {\n  auto fs = std::unique_ptr<SeekStream>(\n      SeekStream::CreateForRead(filename.c_str(), false));\n  CHECK(fs) << \"File name \" << filename << \" is not a valid name\";\n  // Read DGL MetaData\n  uint64_t magicNum, graphType, version, num_graph;\n  fs->Read(&magicNum);\n  fs->Read(&version);\n  fs->Read(&graphType);\n  CHECK(fs->Read(&num_graph)) << \"Invalid num of graph\";\n  fs->Seek(4096);\n\n  uint64_t gdata_start_pos;\n  fs->Read(&gdata_start_pos);\n\n  std::vector<NamedTensor> labels_list;\n  fs->Read(&labels_list);\n\n  return labels_list;\n}\n\nDGL_REGISTER_GLOBAL(\"data.heterograph_serialize._CAPI_MakeHeteroGraphData\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef hg = args[0];\n      List<Map<std::string, Value>> ndata = args[1];\n      List<Map<std::string, Value>> edata = args[2];\n      List<Value> ntype_names = args[3];\n      List<Value> etype_names = args[4];\n      *rv = HeteroGraphData::Create(\n          hg.sptr(), ndata, edata, ntype_names, etype_names);\n    });\n\nDGL_REGISTER_GLOBAL(\"data.heterograph_serialize._CAPI_SaveHeteroGraphData\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      std::string filename = args[0];\n      List<HeteroGraphData> hgdata = args[1];\n      Map<std::string, Value> nd_map = args[2];\n      List<Value> formats = args[3];\n      std::vector<SparseFormat> formats_vec;\n      for (const auto &val : formats) {\n        formats_vec.push_back(ParseSparseFormat(val->data));\n      }\n      const auto formats_code = SparseFormatsToCode(formats_vec);\n      std::vector<NamedTensor> nd_list;\n      for (auto kv : nd_map) {\n        NDArray 
ndarray = static_cast<NDArray>(kv.second->data);\n        nd_list.emplace_back(kv.first, ndarray);\n      }\n      *rv = dgl::serialize::SaveHeteroGraphs(\n          filename, hgdata, nd_list, formats_code);\n    });\n\nDGL_REGISTER_GLOBAL(\n    \"data.heterograph_serialize._CAPI_GetGindexFromHeteroGraphData\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphData hdata = args[0];\n      *rv = HeteroGraphRef(hdata->gptr);\n    });\n\nDGL_REGISTER_GLOBAL(\n    \"data.heterograph_serialize._CAPI_GetEtypesFromHeteroGraphData\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphData hdata = args[0];\n      List<Value> etype_names;\n      for (const auto &name : hdata->etype_names) {\n        etype_names.push_back(Value(MakeValue(name)));\n      }\n      *rv = etype_names;\n    });\n\nDGL_REGISTER_GLOBAL(\n    \"data.heterograph_serialize._CAPI_GetNtypesFromHeteroGraphData\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphData hdata = args[0];\n      List<Value> ntype_names;\n      for (auto name : hdata->ntype_names) {\n        ntype_names.push_back(Value(MakeValue(name)));\n      }\n      *rv = ntype_names;\n    });\n\nDGL_REGISTER_GLOBAL(\n    \"data.heterograph_serialize._CAPI_GetNDataFromHeteroGraphData\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphData hdata = args[0];\n      List<List<Value>> ntensors;\n      for (auto tensor_list : hdata->node_tensors) {\n        List<Value> nlist;\n        for (const auto &kv : tensor_list) {\n          nlist.push_back(Value(MakeValue(kv.first)));\n          nlist.push_back(Value(MakeValue(kv.second)));\n        }\n        ntensors.push_back(nlist);\n      }\n      *rv = ntensors;\n    });\n\nDGL_REGISTER_GLOBAL(\n    \"data.heterograph_serialize._CAPI_GetEDataFromHeteroGraphData\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphData hdata = args[0];\n      List<List<Value>> etensors;\n      for (auto tensor_list : 
hdata->edge_tensors) {\n        List<Value> elist;\n        for (const auto &kv : tensor_list) {\n          elist.push_back(Value(MakeValue(kv.first)));\n          elist.push_back(Value(MakeValue(kv.second)));\n        }\n        etensors.push_back(elist);\n      }\n      *rv = etensors;\n    });\n\nDGL_REGISTER_GLOBAL(\"data.graph_serialize._CAPI_LoadLabels_V2\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      std::string filename = args[0];\n      auto labels_list = LoadLabels_V2(filename);\n      Map<std::string, Value> rvmap;\n      for (auto kv : labels_list) {\n        rvmap.Set(kv.first, Value(MakeValue(kv.second)));\n      }\n      *rv = rvmap;\n    });\n\n}  // namespace serialize\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/serialize/tensor_serialize.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/serialize/tensor_serialize.cc\n * @brief Graph serialization implementation\n */\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/object.h>\n#include <dmlc/io.h>\n\n#include \"../../c_api_common.h\"\n\nusing namespace dgl::runtime;\nusing dmlc::SeekStream;\n\nnamespace dgl {\nnamespace serialize {\n\ntypedef std::pair<std::string, NDArray> NamedTensor;\n\nconstexpr uint64_t kDGLSerialize_Tensors = 0xDD5A9FBE3FA2443F;\n\nDGL_REGISTER_GLOBAL(\"data.tensor_serialize._CAPI_SaveNDArrayDict\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      std::string filename = args[0];\n      auto fs = std::unique_ptr<dmlc::Stream>(\n          dmlc::Stream::Create(filename.c_str(), \"w\"));\n      CHECK(fs) << \"Filename is invalid\";\n      fs->Write(kDGLSerialize_Tensors);\n      bool empty_dict = args[2];\n      Map<std::string, Value> nd_dict;\n      if (!empty_dict) {\n        nd_dict = args[1];\n      }\n      std::vector<NamedTensor> namedTensors;\n      fs->Write(static_cast<uint64_t>(nd_dict.size()));\n      for (auto kv : nd_dict) {\n        NDArray ndarray = static_cast<NDArray>(kv.second->data);\n        namedTensors.emplace_back(kv.first, ndarray);\n      }\n      fs->Write(namedTensors);\n      *rv = true;\n    });\n\nDGL_REGISTER_GLOBAL(\"data.tensor_serialize._CAPI_LoadNDArrayDict\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      std::string filename = args[0];\n      auto fs = std::unique_ptr<dmlc::Stream>(\n          dmlc::Stream::Create(filename.c_str(), \"r\"));\n      CHECK(fs) << \"Filename is invalid or file doesn't exists\";\n      uint64_t magincNum, num_elements;\n      CHECK(fs->Read(&magincNum)) << \"Invalid file\";\n      CHECK_EQ(magincNum, kDGLSerialize_Tensors) << \"Invalid DGL tensor file\";\n      CHECK(fs->Read(&num_elements)) << \"Invalid num of elements\";\n      
Map<std::string, Value> nd_dict;\n      std::vector<NamedTensor> namedTensors;\n      fs->Read(&namedTensors);\n      for (auto kv : namedTensors) {\n        Value ndarray = Value(MakeValue(kv.second));\n        nd_dict.Set(kv.first, ndarray);\n      }\n      *rv = nd_dict;\n    });\n\n}  // namespace serialize\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/serialize/zerocopy_serializer.cc",
    "content": "/**\n *  Copyright (c) 2020-2022 by Contributors\n * @file graph/serialize/zerocopy_serializer.cc\n * @brief serializer implementation.\n */\n\n#include <dgl/zerocopy_serializer.h>\n\n#include \"dgl/runtime/ndarray.h\"\n#include \"dmlc/memory_io.h\"\n\nnamespace dgl {\n\nusing dgl::runtime::NDArray;\n\nNDArray CreateNDArrayFromRawData(\n    std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx, void* raw) {\n  return NDArray::CreateFromRaw(shape, dtype, ctx, raw, true);\n}\n\nvoid StreamWithBuffer::PushNDArray(const NDArray& tensor) {\n#ifndef _WIN32\n  this->Write(tensor->ndim);\n  this->Write(tensor->dtype);\n  int ndim = tensor->ndim;\n  this->WriteArray(tensor->shape, ndim);\n  CHECK(tensor.IsContiguous())\n      << \"StreamWithBuffer only supports contiguous tensor\";\n  CHECK_EQ(tensor->byte_offset, 0)\n      << \"StreamWithBuffer only supports zero byte offset tensor\";\n  int type_bytes = tensor->dtype.bits / 8;\n  int64_t num_elems = 1;\n  for (int i = 0; i < ndim; ++i) {\n    num_elems *= tensor->shape[i];\n  }\n  int64_t data_byte_size = type_bytes * num_elems;\n\n  auto mem = tensor.GetSharedMem();\n  if (send_to_remote_ || !mem) {\n    // If the stream is for remote communication or the data is not stored in\n    // shared memory, serialize the data content as a buffer.\n    this->Write<bool>(false);\n    // If this is a null ndarray, we will not push it into the underlying\n    // buffer_list\n    if (data_byte_size != 0) {\n      buffer_list_.emplace_back(tensor, tensor->data, data_byte_size);\n    }\n  } else {\n    CHECK(mem) << \"Tried to send non-shared-memroy tensor to local \"\n                  \"StreamWithBuffer\";\n    // Serialize only the shared memory name.\n    this->Write<bool>(true);\n    this->Write(mem->GetName());\n  }\n#else\n  LOG(FATAL) << \"StreamWithBuffer is not supported on windows\";\n#endif  // _WIN32\n  return;\n}\n\nNDArray StreamWithBuffer::PopNDArray() {\n#ifndef _WIN32\n  int ndim;\n  
DGLDataType dtype;\n\n  CHECK(this->Read(&ndim)) << \"Invalid DGLArray file format\";\n  CHECK(this->Read(&dtype)) << \"Invalid DGLArray file format\";\n\n  std::vector<int64_t> shape(ndim);\n  if (ndim != 0) {\n    CHECK(this->ReadArray(&shape[0], ndim)) << \"Invalid DGLArray file format\";\n  }\n\n  DGLContext cpu_ctx;\n  cpu_ctx.device_type = kDGLCPU;\n  cpu_ctx.device_id = 0;\n\n  bool is_shared_mem;\n  CHECK(this->Read(&is_shared_mem)) << \"Invalid stream read\";\n  std::string sharedmem_name;\n  if (is_shared_mem) {\n    CHECK(!send_to_remote_) << \"Invalid attempt to deserialize from shared \"\n                               \"memory with send_to_remote=true\";\n    CHECK(this->Read(&sharedmem_name)) << \"Invalid stream read\";\n    return NDArray::EmptyShared(sharedmem_name, shape, dtype, cpu_ctx, false);\n  } else {\n    CHECK(send_to_remote_) << \"Invalid attempt to deserialize from raw data \"\n                              \"pointer with send_to_remote=false\";\n    NDArray ret;\n    if (ndim == 0 || shape[0] == 0) {\n      // Mean this is a null ndarray\n      ret = CreateNDArrayFromRawData(shape, dtype, cpu_ctx, nullptr);\n    } else {\n      ret = CreateNDArrayFromRawData(\n          shape, dtype, cpu_ctx, buffer_list_.front().data);\n      buffer_list_.pop_front();\n    }\n    return ret;\n  }\n#else\n  LOG(FATAL) << \"StreamWithBuffer is not supported on windows\";\n  return NDArray();\n#endif  // _WIN32\n}\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/shared_mem_manager.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/shared_mem_manager.cc\n * @brief DGL shared mem manager implementation\n */\n#include \"shared_mem_manager.h\"\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/random.h>\n#include <dgl/runtime/container.h>\n#include <dgl/sampler.h>\n#include <dmlc/io.h>\n#include <dmlc/memory_io.h>\n\n#include <algorithm>\n#include <array>\n#include <cmath>\n#include <cstdlib>\n#include <numeric>\n#include <vector>\n\n#include \"../c_api_common.h\"\n#include \"heterograph.h\"\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace dgl {\n\ntemplate <>\nNDArray SharedMemManager::CopyToSharedMem<NDArray>(\n    const NDArray &data, std::string name) {\n  DGLContext ctx = {kDGLCPU, 0};\n  std::vector<int64_t> shape(data->shape, data->shape + data->ndim);\n  strm_->Write(data->ndim);\n  strm_->Write(data->dtype);\n  int ndim = data->ndim;\n  strm_->WriteArray(data->shape, ndim);\n\n  bool is_null = IsNullArray(data);\n  strm_->Write(is_null);\n  if (is_null) {\n    return data;\n  } else {\n    auto nd =\n        NDArray::EmptyShared(graph_name_ + name, shape, data->dtype, ctx, true);\n    nd.CopyFrom(data);\n    return nd;\n  }\n}\n\ntemplate <>\nCSRMatrix SharedMemManager::CopyToSharedMem<CSRMatrix>(\n    const CSRMatrix &csr, std::string name) {\n  auto indptr_shared_mem = CopyToSharedMem(csr.indptr, name + \"_indptr\");\n  auto indices_shared_mem = CopyToSharedMem(csr.indices, name + \"_indices\");\n  auto data_shared_mem = CopyToSharedMem(csr.data, name + \"_data\");\n  strm_->Write(csr.num_rows);\n  strm_->Write(csr.num_cols);\n  strm_->Write(csr.sorted);\n  return CSRMatrix(\n      csr.num_rows, csr.num_cols, indptr_shared_mem, indices_shared_mem,\n      data_shared_mem, csr.sorted);\n}\n\ntemplate <>\nCOOMatrix SharedMemManager::CopyToSharedMem<COOMatrix>(\n    const COOMatrix &coo, std::string name) {\n 
 auto row_shared_mem = CopyToSharedMem(coo.row, name + \"_row\");\n  auto col_shared_mem = CopyToSharedMem(coo.col, name + \"_col\");\n  auto data_shared_mem = CopyToSharedMem(coo.data, name + \"_data\");\n  strm_->Write(coo.num_rows);\n  strm_->Write(coo.num_cols);\n  strm_->Write(coo.row_sorted);\n  strm_->Write(coo.col_sorted);\n  return COOMatrix(\n      coo.num_rows, coo.num_cols, row_shared_mem, col_shared_mem,\n      data_shared_mem, coo.row_sorted, coo.col_sorted);\n}\n\ntemplate <>\nbool SharedMemManager::CreateFromSharedMem<NDArray>(\n    NDArray *nd, std::string name) {\n  int ndim;\n  DGLContext ctx = {kDGLCPU, 0};\n  DGLDataType dtype;\n\n  CHECK(this->Read(&ndim)) << \"Invalid DGLArray file format\";\n  CHECK(this->Read(&dtype)) << \"Invalid DGLArray file format\";\n\n  std::vector<int64_t> shape(ndim);\n  if (ndim != 0) {\n    CHECK(this->ReadArray(&shape[0], ndim)) << \"Invalid DGLArray file format\";\n  }\n  bool is_null;\n  this->Read(&is_null);\n  if (is_null) {\n    *nd = NDArray::Empty(shape, dtype, ctx);\n  } else {\n    *nd = NDArray::EmptyShared(graph_name_ + name, shape, dtype, ctx, false);\n  }\n  return true;\n}\n\ntemplate <>\nbool SharedMemManager::CreateFromSharedMem<COOMatrix>(\n    COOMatrix *coo, std::string name) {\n  CreateFromSharedMem(&coo->row, name + \"_row\");\n  CreateFromSharedMem(&coo->col, name + \"_col\");\n  CreateFromSharedMem(&coo->data, name + \"_data\");\n  strm_->Read(&coo->num_rows);\n  strm_->Read(&coo->num_cols);\n  strm_->Read(&coo->row_sorted);\n  strm_->Read(&coo->col_sorted);\n  return true;\n}\n\ntemplate <>\nbool SharedMemManager::CreateFromSharedMem<CSRMatrix>(\n    CSRMatrix *csr, std::string name) {\n  CreateFromSharedMem(&csr->indptr, name + \"_indptr\");\n  CreateFromSharedMem(&csr->indices, name + \"_indices\");\n  CreateFromSharedMem(&csr->data, name + \"_data\");\n  strm_->Read(&csr->num_rows);\n  strm_->Read(&csr->num_cols);\n  strm_->Read(&csr->sorted);\n  return true;\n}\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/shared_mem_manager.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/shared_mem_manager.h\n * @brief DGL shared mem manager APIs\n */\n\n#ifndef DGL_GRAPH_SHARED_MEM_MANAGER_H_\n#define DGL_GRAPH_SHARED_MEM_MANAGER_H_\n\n#include <dgl/array.h>\n#include <dmlc/io.h>\n#include <dmlc/memory_io.h>\n\n#include <algorithm>\n#include <array>\n#include <cmath>\n#include <cstdlib>\n#include <memory>\n#include <numeric>\n#include <string>\n\nnamespace dgl {\n\nusing dgl::runtime::SharedMemory;\n\nconst size_t SHARED_MEM_METAINFO_SIZE_MAX = 1024 * 32;\n\n// Utility class to copy objects to shared memory and record metadata\nclass SharedMemManager : public dmlc::Stream {\n public:\n  explicit SharedMemManager(std::string graph_name, dmlc::Stream* strm)\n      : graph_name_(graph_name), strm_(strm) {}\n\n  template <typename T>\n  T CopyToSharedMem(const T& data, std::string name);\n\n  template <typename T>\n  bool CreateFromSharedMem(T* out_data, std::string name);\n\n  // delegate methods to strm_\n  virtual size_t Read(void* ptr, size_t size) { return strm_->Read(ptr, size); }\n  virtual void Write(const void* ptr, size_t size) { strm_->Write(ptr, size); }\n\n  using dmlc::Stream::Read;\n  using dmlc::Stream::Write;\n\n private:\n  std::string graph_name_;\n  dmlc::Stream* strm_;\n};\n\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_SHARED_MEM_MANAGER_H_\n"
  },
  {
    "path": "src/graph/subgraph.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file graph/subgraph.cc\n * @brief Functions for extracting subgraphs.\n */\n#include \"./heterograph.h\"\nusing namespace dgl::runtime;\n\nnamespace dgl {\n\nHeteroSubgraph InEdgeGraphRelabelNodes(\n    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {\n  CHECK_EQ(vids.size(), graph->NumVertexTypes())\n      << \"Invalid input: the input list size must be the same as the number of \"\n         \"vertex types.\";\n  std::vector<IdArray> eids(graph->NumEdgeTypes());\n  DGLContext ctx = aten::GetContextOf(vids);\n  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {\n    auto pair = graph->meta_graph()->FindEdge(etype);\n    const dgl_type_t dst_vtype = pair.second;\n    if (aten::IsNullArray(vids[dst_vtype])) {\n      eids[etype] = IdArray::Empty({0}, graph->DataType(), ctx);\n    } else {\n      const auto& earr = graph->InEdges(etype, {vids[dst_vtype]});\n      eids[etype] = earr.id;\n    }\n  }\n  return graph->EdgeSubgraph(eids, false);\n}\n\nHeteroSubgraph InEdgeGraphNoRelabelNodes(\n    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {\n  // TODO(mufei): This should also use EdgeSubgraph once it is supported for CSR\n  // graphs\n  CHECK_EQ(vids.size(), graph->NumVertexTypes())\n      << \"Invalid input: the input list size must be the same as the number of \"\n         \"vertex types.\";\n  std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());\n  std::vector<IdArray> induced_edges(graph->NumEdgeTypes());\n  DGLContext ctx = aten::GetContextOf(vids);\n  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {\n    auto pair = graph->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    auto relgraph = graph->GetRelationGraph(etype);\n    if (aten::IsNullArray(vids[dst_vtype])) {\n      // create a placeholder graph\n      subrels[etype] = UnitGraph::Empty(\n     
     relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),\n          graph->NumVertices(dst_vtype), graph->DataType(), ctx);\n      induced_edges[etype] =\n          IdArray::Empty({0}, graph->DataType(), graph->Context());\n    } else {\n      const auto& earr = graph->InEdges(etype, {vids[dst_vtype]});\n      subrels[etype] = UnitGraph::CreateFromCOO(\n          relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),\n          graph->NumVertices(dst_vtype), earr.src, earr.dst);\n      induced_edges[etype] = earr.id;\n    }\n  }\n  HeteroSubgraph ret;\n  ret.graph = CreateHeteroGraph(\n      graph->meta_graph(), subrels, graph->NumVerticesPerType());\n  ret.induced_edges = std::move(induced_edges);\n  return ret;\n}\n\nHeteroSubgraph InEdgeGraph(\n    const HeteroGraphPtr graph, const std::vector<IdArray>& vids,\n    bool relabel_nodes) {\n  if (relabel_nodes) {\n    return InEdgeGraphRelabelNodes(graph, vids);\n  } else {\n    return InEdgeGraphNoRelabelNodes(graph, vids);\n  }\n}\n\nHeteroSubgraph OutEdgeGraphRelabelNodes(\n    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {\n  CHECK_EQ(vids.size(), graph->NumVertexTypes())\n      << \"Invalid input: the input list size must be the same as the number of \"\n         \"vertex types.\";\n  std::vector<IdArray> eids(graph->NumEdgeTypes());\n  DGLContext ctx = aten::GetContextOf(vids);\n  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {\n    auto pair = graph->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    if (aten::IsNullArray(vids[src_vtype])) {\n      eids[etype] = IdArray::Empty({0}, graph->DataType(), ctx);\n    } else {\n      const auto& earr = graph->OutEdges(etype, {vids[src_vtype]});\n      eids[etype] = earr.id;\n    }\n  }\n  return graph->EdgeSubgraph(eids, false);\n}\n\nHeteroSubgraph OutEdgeGraphNoRelabelNodes(\n    const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {\n  // TODO(mufei): This should also use 
EdgeSubgraph once it is supported for CSR\n  // graphs\n  CHECK_EQ(vids.size(), graph->NumVertexTypes())\n      << \"Invalid input: the input list size must be the same as the number of \"\n         \"vertex types.\";\n  std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());\n  std::vector<IdArray> induced_edges(graph->NumEdgeTypes());\n  DGLContext ctx = aten::GetContextOf(vids);\n  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {\n    auto pair = graph->meta_graph()->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    auto relgraph = graph->GetRelationGraph(etype);\n    if (aten::IsNullArray(vids[src_vtype])) {\n      // create a placeholder graph\n      subrels[etype] = UnitGraph::Empty(\n          relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),\n          graph->NumVertices(dst_vtype), graph->DataType(), ctx);\n      induced_edges[etype] =\n          IdArray::Empty({0}, graph->DataType(), graph->Context());\n    } else {\n      const auto& earr = graph->OutEdges(etype, {vids[src_vtype]});\n      subrels[etype] = UnitGraph::CreateFromCOO(\n          relgraph->NumVertexTypes(), graph->NumVertices(src_vtype),\n          graph->NumVertices(dst_vtype), earr.src, earr.dst);\n      induced_edges[etype] = earr.id;\n    }\n  }\n  HeteroSubgraph ret;\n  ret.graph = CreateHeteroGraph(\n      graph->meta_graph(), subrels, graph->NumVerticesPerType());\n  ret.induced_edges = std::move(induced_edges);\n  return ret;\n}\n\nHeteroSubgraph OutEdgeGraph(\n    const HeteroGraphPtr graph, const std::vector<IdArray>& vids,\n    bool relabel_nodes) {\n  if (relabel_nodes) {\n    return OutEdgeGraphRelabelNodes(graph, vids);\n  } else {\n    return OutEdgeGraphNoRelabelNodes(graph, vids);\n  }\n}\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/compact.cc",
    "content": "/**\n *  Copyright 2019-2021 Contributors\n *\n *  Licensed under the Apache License, Version 2.0 (the \"License\");\n *  you may not use this file except in compliance with the License.\n *  You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n *  Unless required by applicable law or agreed to in writing, software\n *  distributed under the License is distributed on an \"AS IS\" BASIS,\n *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *  See the License for the specific language governing permissions and\n *  limitations under the License.\n *\n * @file graph/transform/compact.cc\n * @brief Compact graph implementation\n */\n\n#include \"compact.h\"\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/registry.h>\n#include <dgl/transform.h>\n\n#include <utility>\n#include <vector>\n\n#include \"../../c_api_common.h\"\n#include \"../unit_graph.h\"\n// TODO(BarclayII): currently CompactGraphs depend on IdHashMap implementation\n// which only works on CPU.  Should fix later to make it device agnostic.\n#include \"../../array/cpu/array_utils.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace transform {\n\nnamespace {\n\ntemplate <typename IdType>\nstd::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphsCPU(\n    const std::vector<HeteroGraphPtr> &graphs,\n    const std::vector<IdArray> &always_preserve) {\n  // TODO(BarclayII): check whether the node space and metagraph of each graph\n  // is the same. 
Step 1: Collect the nodes that have connections for each type.\n  const int64_t num_ntypes = graphs[0]->NumVertexTypes();\n  std::vector<aten::IdHashMap<IdType>> hashmaps(num_ntypes);\n  std::vector<std::vector<EdgeArray>> all_edges(\n      graphs.size());  // all_edges[i][etype]\n\n  std::vector<int64_t> max_vertex_cnt(num_ntypes, 0);\n  for (size_t i = 0; i < graphs.size(); ++i) {\n    const HeteroGraphPtr curr_graph = graphs[i];\n    const int64_t num_etypes = curr_graph->NumEdgeTypes();\n\n    for (IdType etype = 0; etype < num_etypes; ++etype) {\n      IdType srctype, dsttype;\n      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);\n\n      const int64_t n_edges = curr_graph->NumEdges(etype);\n      max_vertex_cnt[srctype] += n_edges;\n      max_vertex_cnt[dsttype] += n_edges;\n    }\n  }\n\n  // Reserve the space for hash maps ahead of time to avoid rehashing\n  for (size_t i = 0; i < static_cast<size_t>(num_ntypes); ++i) {\n    if (i < always_preserve.size())\n      hashmaps[i].Reserve(always_preserve[i]->shape[0] + max_vertex_cnt[i]);\n    else\n      hashmaps[i].Reserve(max_vertex_cnt[i]);\n  }\n\n  for (size_t i = 0; i < always_preserve.size(); ++i) {\n    hashmaps[i].Update(always_preserve[i]);\n  }\n\n  for (size_t i = 0; i < graphs.size(); ++i) {\n    const HeteroGraphPtr curr_graph = graphs[i];\n    const int64_t num_etypes = curr_graph->NumEdgeTypes();\n\n    all_edges[i].reserve(num_etypes);\n    for (IdType etype = 0; etype < num_etypes; ++etype) {\n      IdType srctype, dsttype;\n      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);\n\n      const EdgeArray edges = curr_graph->Edges(etype, \"eid\");\n\n      hashmaps[srctype].Update(edges.src);\n      hashmaps[dsttype].Update(edges.dst);\n\n      all_edges[i].push_back(edges);\n    }\n  }\n\n  // Step 2: Relabel the nodes for each type to a smaller ID space and save the\n  // mapping.\n  std::vector<IdArray> induced_nodes(num_ntypes);\n  std::vector<int64_t> 
num_induced_nodes(num_ntypes);\n  for (int64_t i = 0; i < num_ntypes; ++i) {\n    induced_nodes[i] = hashmaps[i].Values();\n    num_induced_nodes[i] = hashmaps[i].Size();\n  }\n\n  // Step 3: Remap the edges of each graph.\n  std::vector<HeteroGraphPtr> new_graphs;\n  for (size_t i = 0; i < graphs.size(); ++i) {\n    std::vector<HeteroGraphPtr> rel_graphs;\n    const HeteroGraphPtr curr_graph = graphs[i];\n    const auto meta_graph = curr_graph->meta_graph();\n    const int64_t num_etypes = curr_graph->NumEdgeTypes();\n\n    for (IdType etype = 0; etype < num_etypes; ++etype) {\n      IdType srctype, dsttype;\n      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);\n      const EdgeArray &edges = all_edges[i][etype];\n\n      const IdArray mapped_rows = hashmaps[srctype].Map(edges.src, -1);\n      const IdArray mapped_cols = hashmaps[dsttype].Map(edges.dst, -1);\n\n      rel_graphs.push_back(UnitGraph::CreateFromCOO(\n          srctype == dsttype ? 1 : 2, induced_nodes[srctype]->shape[0],\n          induced_nodes[dsttype]->shape[0], mapped_rows, mapped_cols));\n    }\n\n    new_graphs.push_back(\n        CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));\n  }\n\n  return std::make_pair(new_graphs, induced_nodes);\n}\n\n};  // namespace\n\ntemplate <>\nstd::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>\nCompactGraphs<kDGLCPU, int32_t>(\n    const std::vector<HeteroGraphPtr> &graphs,\n    const std::vector<IdArray> &always_preserve) {\n  return CompactGraphsCPU<int32_t>(graphs, always_preserve);\n}\n\ntemplate <>\nstd::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>\nCompactGraphs<kDGLCPU, int64_t>(\n    const std::vector<HeteroGraphPtr> &graphs,\n    const std::vector<IdArray> &always_preserve) {\n  return CompactGraphsCPU<int64_t>(graphs, always_preserve);\n}\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLCompactGraphs\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      List<HeteroGraphRef> graph_refs = args[0];\n 
     List<Value> always_preserve_refs = args[1];\n\n      std::vector<HeteroGraphPtr> graphs;\n      std::vector<IdArray> always_preserve;\n      for (HeteroGraphRef gref : graph_refs) graphs.push_back(gref.sptr());\n      for (Value array : always_preserve_refs)\n        always_preserve.push_back(array->data);\n\n      // TODO(BarclayII): check for all IdArrays\n      CHECK(graphs[0]->DataType() == always_preserve[0]->dtype)\n          << \"data type mismatch.\";\n\n      std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> result_pair;\n\n      ATEN_XPU_SWITCH_CUDA(\n          graphs[0]->Context().device_type, XPU, \"CompactGraphs\", {\n            ATEN_ID_TYPE_SWITCH(graphs[0]->DataType(), IdType, {\n              result_pair = CompactGraphs<XPU, IdType>(graphs, always_preserve);\n            });\n          });\n\n      List<HeteroGraphRef> compacted_graph_refs;\n      List<Value> induced_nodes;\n\n      for (const HeteroGraphPtr &g : result_pair.first)\n        compacted_graph_refs.push_back(HeteroGraphRef(g));\n      for (const IdArray &ids : result_pair.second)\n        induced_nodes.push_back(Value(MakeValue(ids)));\n\n      List<ObjectRef> result;\n      result.push_back(compacted_graph_refs);\n      result.push_back(induced_nodes);\n\n      *rv = result;\n    });\n\n};  // namespace transform\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/compact.h",
    "content": "/**\n *  Copyright 2021 Contributors\n *\n *  Licensed under the Apache License, Version 2.0 (the \"License\");\n *  you may not use this file except in compliance with the License.\n *  You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n *  Unless required by applicable law or agreed to in writing, software\n *  distributed under the License is distributed on an \"AS IS\" BASIS,\n *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *  See the License for the specific language governing permissions and\n *  limitations under the License.\n *\n * @file graph/transform/compact.h\n * @brief Functions to find and eliminate the common isolated nodes across\n * all given graphs with the same set of nodes.\n */\n\n#ifndef DGL_GRAPH_TRANSFORM_COMPACT_H_\n#define DGL_GRAPH_TRANSFORM_COMPACT_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n\n#include <utility>\n#include <vector>\n\nnamespace dgl {\nnamespace transform {\n\n/**\n * @brief Given a list of graphs with the same set of nodes, find and eliminate\n * the common isolated nodes across all graphs.\n *\n * @tparam XPU The type of device to operate on.\n * @tparam IdType The type to use as an index.\n * @param graphs The list of graphs to be compacted.\n * @param always_preserve The vector of nodes to be preserved.\n *\n * @return The vector of compacted graphs and the vector of induced nodes.\n */\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphs(\n    const std::vector<HeteroGraphPtr> &graphs,\n    const std::vector<IdArray> &always_preserve);\n\n}  // namespace transform\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_TRANSFORM_COMPACT_H_\n"
  },
  {
    "path": "src/graph/transform/cpu/kdtree_ndarray_adapter.h",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file graph/transform/cpu/kdtree_ndarray_adapter.h\n * @brief NDArray adapter for nanoflann, without\n *        duplicating the storage\n */\n#ifndef DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_\n#define DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_\n\n#include <dgl/array.h>\n#include <dmlc/logging.h>\n\n#include <nanoflann.hpp>\n\n#include \"../../../c_api_common.h\"\n\nnamespace dgl {\nnamespace transform {\nnamespace knn_utils {\n\n/**\n * @brief A simple 2D NDArray adapter for nanoflann, without duplicating the\n *        storage.\n *\n * @tparam FloatType: The type of the point coordinates (typically, double or\n *         float).\n * @tparam IdType: The type for indices in the KD-tree index (typically,\n *         size_t of int)\n * @tparam FeatureDim: If set to > 0, it specifies a compile-time fixed\n *         dimensionality for the points in the data set, allowing more compiler\n *         optimizations.\n * @tparam Dist: The distance metric to use: nanoflann::metric_L1,\n           nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc.\n * @note The spelling of dgl's adapter (\"adapter\") is different from naneflann\n *       (\"adaptor\")\n */\ntemplate <\n    typename FloatType, typename IdType, int FeatureDim = -1,\n    typename Dist = nanoflann::metric_L2>\nclass KDTreeNDArrayAdapter {\n public:\n  using self_type = KDTreeNDArrayAdapter<FloatType, IdType, FeatureDim, Dist>;\n  using metric_type =\n      typename Dist::template traits<FloatType, self_type>::distance_t;\n  using index_type = nanoflann::KDTreeSingleIndexAdaptor<\n      metric_type, self_type, FeatureDim, IdType>;\n\n  KDTreeNDArrayAdapter(\n      const size_t /* dims */, const NDArray data_points,\n      const int leaf_max_size = 10)\n      : data_(data_points) {\n    CHECK(data_points->shape[0] != 0 && data_points->shape[1] != 0)\n        << \"Tensor containing input data point set must be 2D.\";\n    const size_t 
dims = data_points->shape[1];\n    CHECK(!(FeatureDim > 0 && static_cast<int>(dims) != FeatureDim))\n        << \"Data set feature dimension does not match the 'FeatureDim' \"\n        << \"template argument.\";\n    index_ = new index_type(\n        static_cast<int>(dims), *this,\n        nanoflann::KDTreeSingleIndexAdaptorParams(leaf_max_size));\n    index_->buildIndex();\n  }\n\n  ~KDTreeNDArrayAdapter() { delete index_; }\n\n  index_type* GetIndex() { return index_; }\n\n  /**\n   * @brief Query for the \\a num_closest points to a given point\n   *  Note that this is a short-cut method for GetIndex()->findNeighbors().\n   */\n  void query(\n      const FloatType* query_pt, const size_t num_closest, IdType* out_idxs,\n      FloatType* out_dists) const {\n    nanoflann::KNNResultSet<FloatType, IdType> resultSet(num_closest);\n    resultSet.init(out_idxs, out_dists);\n    index_->findNeighbors(resultSet, query_pt, nanoflann::SearchParams());\n  }\n\n  /** @brief Interface expected by KDTreeSingleIndexAdaptor */\n  const self_type& derived() const { return *this; }\n\n  /** @brief Interface expected by KDTreeSingleIndexAdaptor */\n  self_type& derived() { return *this; }\n\n  /**\n   * @brief Interface expected by KDTreeSingleIndexAdaptor,\n   *  return the number of data points\n   */\n  size_t kdtree_get_point_count() const { return data_->shape[0]; }\n\n  /**\n   * @brief Interface expected by KDTreeSingleIndexAdaptor,\n   *  return the dim'th component of the idx'th point\n   */\n  FloatType kdtree_get_pt(const size_t idx, const size_t dim) const {\n    return data_.Ptr<FloatType>()[idx * data_->shape[1] + dim];\n  }\n\n  /**\n   * @brief Interface expected by KDTreeSingleIndexAdaptor.\n   *  Optional bounding-box computation: return false to\n   *  default to a standard bbox computation loop.\n   *\n   */\n  template <typename BBOX>\n  bool kdtree_get_bbox(BBOX& /* bb */) const {\n    return false;\n  }\n\n private:\n  index_type* index_;   // The kd tree 
index\n  const NDArray data_;  // data points\n};\n\n}  // namespace knn_utils\n}  // namespace transform\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_\n"
  },
  {
    "path": "src/graph/transform/cpu/knn.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/transform/cpu/knn.cc\n * @brief k-nearest-neighbor (KNN) implementation\n */\n\n#include \"../knn.h\"\n\n#include <dgl/random.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/parallel_for.h>\n#include <dmlc/omp.h>\n\n#include <algorithm>\n#include <limits>\n#include <tuple>\n#include <vector>\n\n#include \"kdtree_ndarray_adapter.h\"\n\nusing namespace dgl::runtime;\nusing namespace dgl::transform::knn_utils;\nnamespace dgl {\nnamespace transform {\nnamespace impl {\n\n// This value is directly from pynndescent\nstatic constexpr int NN_DESCENT_BLOCK_SIZE = 16384;\n\n/**\n * @brief Compute Euclidean distance between two vectors, return positive\n *  infinite value if the intermediate distance is greater than the worst\n *  distance.\n */\ntemplate <typename FloatType, typename IdType>\nFloatType EuclideanDistWithCheck(\n    const FloatType* vec1, const FloatType* vec2, int64_t dim,\n    FloatType worst_dist = std::numeric_limits<FloatType>::max()) {\n  FloatType dist = 0;\n  bool early_stop = false;\n\n  for (IdType idx = 0; idx < dim; ++idx) {\n    dist += (vec1[idx] - vec2[idx]) * (vec1[idx] - vec2[idx]);\n    if (dist > worst_dist) {\n      early_stop = true;\n      break;\n    }\n  }\n\n  if (early_stop) {\n    return std::numeric_limits<FloatType>::max();\n  } else {\n    return dist;\n  }\n}\n\n/** @brief Compute Euclidean distance between two vectors */\ntemplate <typename FloatType, typename IdType>\nFloatType EuclideanDist(\n    const FloatType* vec1, const FloatType* vec2, int64_t dim) {\n  FloatType dist = 0;\n\n  for (IdType idx = 0; idx < dim; ++idx) {\n    dist += (vec1[idx] - vec2[idx]) * (vec1[idx] - vec2[idx]);\n  }\n\n  return dist;\n}\n\n/** @brief Insert a new element into a heap */\ntemplate <typename FloatType, typename IdType>\nvoid HeapInsert(\n    IdType* out, FloatType* dist, IdType new_id, FloatType new_dist, int k,\n    bool check_repeat = false) {\n  
if (new_dist > dist[0]) return;\n\n  // check if we have it\n  if (check_repeat) {\n    for (IdType i = 0; i < k; ++i) {\n      if (out[i] == new_id) return;\n    }\n  }\n\n  IdType left_idx = 0, right_idx = 0, curr_idx = 0, swap_idx = 0;\n  dist[0] = new_dist;\n  out[0] = new_id;\n  while (true) {\n    left_idx = 2 * curr_idx + 1;\n    right_idx = left_idx + 1;\n    swap_idx = curr_idx;\n    if (left_idx < k && dist[left_idx] > dist[swap_idx]) {\n      swap_idx = left_idx;\n    }\n    if (right_idx < k && dist[right_idx] > dist[swap_idx]) {\n      swap_idx = right_idx;\n    }\n    if (swap_idx != curr_idx) {\n      std::swap(dist[curr_idx], dist[swap_idx]);\n      std::swap(out[curr_idx], out[swap_idx]);\n      curr_idx = swap_idx;\n    } else {\n      break;\n    }\n  }\n}\n\n/** @brief Insert a new element and its flag into heap, return 1 if insert\n * successfully */\ntemplate <typename FloatType, typename IdType>\nint FlaggedHeapInsert(\n    IdType* out, FloatType* dist, bool* flag, IdType new_id, FloatType new_dist,\n    bool new_flag, int k, bool check_repeat = false) {\n  if (new_dist > dist[0]) return 0;\n\n  if (check_repeat) {\n    for (IdType i = 0; i < k; ++i) {\n      if (out[i] == new_id) return 0;\n    }\n  }\n\n  IdType left_idx = 0, right_idx = 0, curr_idx = 0, swap_idx = 0;\n  dist[0] = new_dist;\n  out[0] = new_id;\n  flag[0] = new_flag;\n  while (true) {\n    left_idx = 2 * curr_idx + 1;\n    right_idx = left_idx + 1;\n    swap_idx = curr_idx;\n    if (left_idx < k && dist[left_idx] > dist[swap_idx]) {\n      swap_idx = left_idx;\n    }\n    if (right_idx < k && dist[right_idx] > dist[swap_idx]) {\n      swap_idx = right_idx;\n    }\n    if (swap_idx != curr_idx) {\n      std::swap(dist[curr_idx], dist[swap_idx]);\n      std::swap(out[curr_idx], out[swap_idx]);\n      std::swap(flag[curr_idx], flag[swap_idx]);\n      curr_idx = swap_idx;\n    } else {\n      break;\n    }\n  }\n  return 1;\n}\n\n/** @brief Build heap for each point. 
Used by NN-descent */\ntemplate <typename FloatType, typename IdType>\nvoid BuildHeap(IdType* index, FloatType* dist, int k) {\n  for (int i = k / 2 - 1; i >= 0; --i) {\n    IdType idx = i;\n    while (true) {\n      IdType largest = idx;\n      IdType left = idx * 2 + 1;\n      IdType right = left + 1;\n      if (left < k && dist[left] > dist[largest]) {\n        largest = left;\n      }\n      if (right < k && dist[right] > dist[largest]) {\n        largest = right;\n      }\n      if (largest != idx) {\n        std::swap(index[largest], index[idx]);\n        std::swap(dist[largest], dist[idx]);\n        idx = largest;\n      } else {\n        break;\n      }\n    }\n  }\n}\n\n/**\n * @brief Neighbor update process in NN-descent. The distance between\n *  two points are computed. If this new distance is less than any worst\n *  distance of these two points, we update the neighborhood of that point.\n */\ntemplate <typename FloatType, typename IdType>\nint UpdateNeighbors(\n    IdType* neighbors, FloatType* dists, const FloatType* points, bool* flags,\n    IdType c1, IdType c2, IdType point_start, int64_t feature_size, int k) {\n  IdType c1_local = c1 - point_start, c2_local = c2 - point_start;\n  FloatType worst_c1_dist = dists[c1_local * k];\n  FloatType worst_c2_dist = dists[c2_local * k];\n  FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(\n      points + c1 * feature_size, points + c2 * feature_size, feature_size,\n      std::max(worst_c1_dist, worst_c2_dist));\n\n  int num_updates = 0;\n  if (new_dist < worst_c1_dist) {\n    ++num_updates;\n#pragma omp critical\n    {\n      FlaggedHeapInsert<FloatType, IdType>(\n          neighbors + c1 * k, dists + c1_local * k, flags + c1_local * k, c2,\n          new_dist, true, k, true);\n    }\n  }\n  if (new_dist < worst_c2_dist) {\n    ++num_updates;\n#pragma omp critical\n    {\n      FlaggedHeapInsert<FloatType, IdType>(\n          neighbors + c2 * k, dists + c2_local * k, flags + c2_local * k, c1,\n 
         new_dist, true, k, true);\n    }\n  }\n  return num_updates;\n}\n\n/** @brief The kd-tree implementation of K-Nearest Neighbors */\ntemplate <typename FloatType, typename IdType>\nvoid KdTreeKNN(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result) {\n  const int64_t batch_size = data_offsets->shape[0] - 1;\n  const int64_t feature_size = data_points->shape[1];\n  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();\n  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();\n  const FloatType* query_points_data = query_points.Ptr<FloatType>();\n  IdType* query_out = result.Ptr<IdType>();\n  IdType* data_out = query_out + k * query_points->shape[0];\n\n  for (int64_t b = 0; b < batch_size; ++b) {\n    auto d_offset = data_offsets_data[b];\n    auto d_length = data_offsets_data[b + 1] - d_offset;\n    auto q_offset = query_offsets_data[b];\n    auto q_length = query_offsets_data[b + 1] - q_offset;\n    auto out_offset = k * q_offset;\n\n    // create view for each segment\n    const NDArray current_data_points =\n        const_cast<NDArray*>(&data_points)\n            ->CreateView(\n                {d_length, feature_size}, data_points->dtype,\n                d_offset * feature_size * sizeof(FloatType));\n    const FloatType* current_query_pts_data =\n        query_points_data + q_offset * feature_size;\n\n    KDTreeNDArrayAdapter<FloatType, IdType> kdtree(\n        feature_size, current_data_points);\n\n    // query\n    parallel_for(0, q_length, [&](IdType b, IdType e) {\n      for (auto q = b; q < e; ++q) {\n        std::vector<IdType> out_buffer(k);\n        std::vector<FloatType> out_dist_buffer(k);\n\n        auto curr_out_offset = k * q + out_offset;\n        const FloatType* q_point = current_query_pts_data + q * feature_size;\n        size_t num_matches = kdtree.GetIndex()->knnSearch(\n            q_point, k, 
out_buffer.data(), out_dist_buffer.data());\n\n        for (size_t i = 0; i < num_matches; ++i) {\n          query_out[curr_out_offset] = q + q_offset;\n          data_out[curr_out_offset] = out_buffer[i] + d_offset;\n          curr_out_offset++;\n        }\n      }\n    });\n  }\n}\n\ntemplate <typename FloatType, typename IdType>\nvoid BruteForceKNN(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result) {\n  const int64_t batch_size = data_offsets->shape[0] - 1;\n  const int64_t feature_size = data_points->shape[1];\n  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();\n  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();\n  const FloatType* data_points_data = data_points.Ptr<FloatType>();\n  const FloatType* query_points_data = query_points.Ptr<FloatType>();\n  IdType* query_out = result.Ptr<IdType>();\n  IdType* data_out = query_out + k * query_points->shape[0];\n\n  for (int64_t b = 0; b < batch_size; ++b) {\n    IdType d_start = data_offsets_data[b], d_end = data_offsets_data[b + 1];\n    IdType q_start = query_offsets_data[b], q_end = query_offsets_data[b + 1];\n\n    std::vector<FloatType> dist_buffer(k);\n\n    parallel_for(q_start, q_end, [&](IdType b, IdType e) {\n      for (auto q_idx = b; q_idx < e; ++q_idx) {\n        std::vector<FloatType> dist_buffer(k);\n        for (IdType k_idx = 0; k_idx < k; ++k_idx) {\n          query_out[q_idx * k + k_idx] = q_idx;\n          dist_buffer[k_idx] = std::numeric_limits<FloatType>::max();\n        }\n        FloatType worst_dist = std::numeric_limits<FloatType>::max();\n\n        for (IdType d_idx = d_start; d_idx < d_end; ++d_idx) {\n          FloatType tmp_dist = EuclideanDistWithCheck<FloatType, IdType>(\n              query_points_data + q_idx * feature_size,\n              data_points_data + d_idx * feature_size, feature_size,\n              worst_dist);\n\n          if (tmp_dist 
== std::numeric_limits<FloatType>::max()) {\n            continue;\n          }\n\n          IdType out_offset = q_idx * k;\n          HeapInsert<FloatType, IdType>(\n              data_out + out_offset, dist_buffer.data(), d_idx, tmp_dist, k);\n          worst_dist = dist_buffer[0];\n        }\n      }\n    });\n  }\n}\n}  // namespace impl\n\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid KNN(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result, const std::string& algorithm) {\n  if (algorithm == std::string(\"kd-tree\")) {\n    impl::KdTreeKNN<FloatType, IdType>(\n        data_points, data_offsets, query_points, query_offsets, k, result);\n  } else if (algorithm == std::string(\"bruteforce\")) {\n    impl::BruteForceKNN<FloatType, IdType>(\n        data_points, data_offsets, query_points, query_offsets, k, result);\n  } else {\n    LOG(FATAL) << \"Algorithm \" << algorithm << \" is not supported on CPU\";\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid NNDescent(\n    const NDArray& points, const IdArray& offsets, IdArray result, const int k,\n    const int num_iters, const int num_candidates, const double delta) {\n  using nnd_updates_t =\n      std::vector<std::vector<std::tuple<IdType, IdType, FloatType>>>;\n  const auto& ctx = points->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  const int64_t num_nodes = points->shape[0];\n  const int64_t batch_size = offsets->shape[0] - 1;\n  const int64_t feature_size = points->shape[1];\n  const IdType* offsets_data = offsets.Ptr<IdType>();\n  const FloatType* points_data = points.Ptr<FloatType>();\n\n  IdType* central_nodes = result.Ptr<IdType>();\n  IdType* neighbors = central_nodes + k * num_nodes;\n  int64_t max_segment_size = 0;\n\n  // find max segment\n  for (IdType b = 0; b < batch_size; ++b) {\n    if (max_segment_size < offsets_data[b + 
1] - offsets_data[b])\n      max_segment_size = offsets_data[b + 1] - offsets_data[b];\n  }\n\n  // allocate memory for candidate, sampling pool, distance and flag\n  IdType* new_candidates = static_cast<IdType*>(device->AllocWorkspace(\n      ctx, max_segment_size * num_candidates * sizeof(IdType)));\n  IdType* old_candidates = static_cast<IdType*>(device->AllocWorkspace(\n      ctx, max_segment_size * num_candidates * sizeof(IdType)));\n  FloatType* new_candidates_dists =\n      static_cast<FloatType*>(device->AllocWorkspace(\n          ctx, max_segment_size * num_candidates * sizeof(FloatType)));\n  FloatType* old_candidates_dists =\n      static_cast<FloatType*>(device->AllocWorkspace(\n          ctx, max_segment_size * num_candidates * sizeof(FloatType)));\n  FloatType* neighbors_dists = static_cast<FloatType*>(\n      device->AllocWorkspace(ctx, max_segment_size * k * sizeof(FloatType)));\n  bool* flags = static_cast<bool*>(\n      device->AllocWorkspace(ctx, max_segment_size * k * sizeof(bool)));\n\n  for (IdType b = 0; b < batch_size; ++b) {\n    IdType point_idx_start = offsets_data[b],\n           point_idx_end = offsets_data[b + 1];\n    IdType segment_size = point_idx_end - point_idx_start;\n\n    // random initialization\n    runtime::parallel_for(\n        point_idx_start, point_idx_end, [&](size_t b, size_t e) {\n          for (auto i = b; i < e; ++i) {\n            IdType local_idx = i - point_idx_start;\n\n            dgl::RandomEngine::ThreadLocal()->UniformChoice<IdType>(\n                k, segment_size, neighbors + i * k, false);\n\n            for (IdType n = 0; n < k; ++n) {\n              central_nodes[i * k + n] = i;\n              neighbors[i * k + n] += point_idx_start;\n              flags[local_idx * k + n] = true;\n              neighbors_dists[local_idx * k + n] =\n                  impl::EuclideanDist<FloatType, IdType>(\n                      points_data + i * feature_size,\n                      points_data + neighbors[i * k + n] * 
feature_size,\n                      feature_size);\n            }\n            impl::BuildHeap<FloatType, IdType>(\n                neighbors + i * k, neighbors_dists + local_idx * k, k);\n          }\n        });\n\n    size_t num_updates = 0;\n    for (int iter = 0; iter < num_iters; ++iter) {\n      num_updates = 0;\n\n      // initialize candidates array as empty value\n      runtime::parallel_for(\n          point_idx_start, point_idx_end, [&](size_t b, size_t e) {\n            for (auto i = b; i < e; ++i) {\n              IdType local_idx = i - point_idx_start;\n              for (IdType c = 0; c < num_candidates; ++c) {\n                new_candidates[local_idx * num_candidates + c] = num_nodes;\n                old_candidates[local_idx * num_candidates + c] = num_nodes;\n                new_candidates_dists[local_idx * num_candidates + c] =\n                    std::numeric_limits<FloatType>::max();\n                old_candidates_dists[local_idx * num_candidates + c] =\n                    std::numeric_limits<FloatType>::max();\n              }\n            }\n          });\n\n      // randomly select neighbors as candidates\n      int num_threads = omp_get_max_threads();\n      runtime::parallel_for(0, num_threads, [&](IdType b, IdType e) {\n        for (auto tid = b; tid < e; ++tid) {\n          for (IdType i = point_idx_start; i < point_idx_end; ++i) {\n            IdType local_idx = i - point_idx_start;\n            for (IdType n = 0; n < k; ++n) {\n              IdType neighbor_idx = neighbors[i * k + n];\n              bool is_new = flags[local_idx * k + n];\n              IdType local_neighbor_idx = neighbor_idx - point_idx_start;\n              FloatType random_dist =\n                  dgl::RandomEngine::ThreadLocal()->Uniform<FloatType>();\n\n              if (is_new) {\n                if (local_idx % num_threads == tid) {\n                  impl::HeapInsert<FloatType, IdType>(\n                      new_candidates + local_idx * 
num_candidates,\n                      new_candidates_dists + local_idx * num_candidates,\n                      neighbor_idx, random_dist, num_candidates, true);\n                }\n                if (local_neighbor_idx % num_threads == tid) {\n                  impl::HeapInsert<FloatType, IdType>(\n                      new_candidates + local_neighbor_idx * num_candidates,\n                      new_candidates_dists +\n                          local_neighbor_idx * num_candidates,\n                      i, random_dist, num_candidates, true);\n                }\n              } else {\n                if (local_idx % num_threads == tid) {\n                  impl::HeapInsert<FloatType, IdType>(\n                      old_candidates + local_idx * num_candidates,\n                      old_candidates_dists + local_idx * num_candidates,\n                      neighbor_idx, random_dist, num_candidates, true);\n                }\n                if (local_neighbor_idx % num_threads == tid) {\n                  impl::HeapInsert<FloatType, IdType>(\n                      old_candidates + local_neighbor_idx * num_candidates,\n                      old_candidates_dists +\n                          local_neighbor_idx * num_candidates,\n                      i, random_dist, num_candidates, true);\n                }\n              }\n            }\n          }\n        }\n      });\n\n      // mark all elements in new_candidates as false\n      runtime::parallel_for(\n          point_idx_start, point_idx_end, [&](size_t b, size_t e) {\n            for (auto i = b; i < e; ++i) {\n              IdType local_idx = i - point_idx_start;\n              for (IdType n = 0; n < k; ++n) {\n                IdType n_idx = neighbors[i * k + n];\n\n                for (IdType c = 0; c < num_candidates; ++c) {\n                  if (new_candidates[local_idx * num_candidates + c] == n_idx) {\n                    flags[local_idx * k + n] = false;\n                    break;\n                  
}\n                }\n              }\n            }\n          });\n\n      // update neighbors block by block\n      for (IdType block_start = point_idx_start; block_start < point_idx_end;\n           block_start += impl::NN_DESCENT_BLOCK_SIZE) {\n        IdType block_end =\n            std::min(point_idx_end, block_start + impl::NN_DESCENT_BLOCK_SIZE);\n        IdType block_size = block_end - block_start;\n        nnd_updates_t updates(block_size);\n\n        // generate updates\n        runtime::parallel_for(block_start, block_end, [&](size_t b, size_t e) {\n          for (auto i = b; i < e; ++i) {\n            IdType local_idx = i - point_idx_start;\n\n            for (IdType c1 = 0; c1 < num_candidates; ++c1) {\n              IdType new_c1 = new_candidates[local_idx * num_candidates + c1];\n              if (new_c1 == num_nodes) continue;\n              IdType c1_local = new_c1 - point_idx_start;\n\n              // new-new\n              for (IdType c2 = c1; c2 < num_candidates; ++c2) {\n                IdType new_c2 = new_candidates[local_idx * num_candidates + c2];\n                if (new_c2 == num_nodes) continue;\n                IdType c2_local = new_c2 - point_idx_start;\n\n                FloatType worst_c1_dist = neighbors_dists[c1_local * k];\n                FloatType worst_c2_dist = neighbors_dists[c2_local * k];\n                FloatType new_dist =\n                    impl::EuclideanDistWithCheck<FloatType, IdType>(\n                        points_data + new_c1 * feature_size,\n                        points_data + new_c2 * feature_size, feature_size,\n                        std::max(worst_c1_dist, worst_c2_dist));\n\n                if (new_dist < worst_c1_dist || new_dist < worst_c2_dist) {\n                  updates[i - block_start].push_back(\n                      std::make_tuple(new_c1, new_c2, new_dist));\n                }\n              }\n\n              // new-old\n              for (IdType c2 = 0; c2 < num_candidates; ++c2) {\n    
            IdType old_c2 = old_candidates[local_idx * num_candidates + c2];\n                if (old_c2 == num_nodes) continue;\n                IdType c2_local = old_c2 - point_idx_start;\n\n                FloatType worst_c1_dist = neighbors_dists[c1_local * k];\n                FloatType worst_c2_dist = neighbors_dists[c2_local * k];\n                FloatType new_dist =\n                    impl::EuclideanDistWithCheck<FloatType, IdType>(\n                        points_data + new_c1 * feature_size,\n                        points_data + old_c2 * feature_size, feature_size,\n                        std::max(worst_c1_dist, worst_c2_dist));\n\n                if (new_dist < worst_c1_dist || new_dist < worst_c2_dist) {\n                  updates[i - block_start].push_back(\n                      std::make_tuple(new_c1, old_c2, new_dist));\n                }\n              }\n            }\n          }\n        });\n\n        int tid;\n#pragma omp parallel private(tid, num_threads) reduction(+ : num_updates)\n        {\n          tid = omp_get_thread_num();\n          num_threads = omp_get_num_threads();\n          for (IdType i = 0; i < block_size; ++i) {\n            for (const auto& u : updates[i]) {\n              IdType p1, p2;\n              FloatType d;\n              std::tie(p1, p2, d) = u;\n              IdType p1_local = p1 - point_idx_start;\n              IdType p2_local = p2 - point_idx_start;\n\n              if (p1 % num_threads == tid) {\n                num_updates += impl::FlaggedHeapInsert<FloatType, IdType>(\n                    neighbors + p1 * k, neighbors_dists + p1_local * k,\n                    flags + p1_local * k, p2, d, true, k, true);\n              }\n              if (p2 % num_threads == tid) {\n                num_updates += impl::FlaggedHeapInsert<FloatType, IdType>(\n                    neighbors + p2 * k, neighbors_dists + p2_local * k,\n                    flags + p2_local * k, p1, d, true, k, true);\n              }\n         
   }\n          }\n        }\n      }\n\n      // early abort\n      if (num_updates <= static_cast<size_t>(delta * k * segment_size)) {\n        break;\n      }\n    }\n  }\n\n  device->FreeWorkspace(ctx, new_candidates);\n  device->FreeWorkspace(ctx, old_candidates);\n  device->FreeWorkspace(ctx, new_candidates_dists);\n  device->FreeWorkspace(ctx, old_candidates_dists);\n  device->FreeWorkspace(ctx, neighbors_dists);\n  device->FreeWorkspace(ctx, flags);\n}\n\ntemplate void KNN<kDGLCPU, float, int32_t>(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result, const std::string& algorithm);\ntemplate void KNN<kDGLCPU, float, int64_t>(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result, const std::string& algorithm);\ntemplate void KNN<kDGLCPU, double, int32_t>(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result, const std::string& algorithm);\ntemplate void KNN<kDGLCPU, double, int64_t>(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result, const std::string& algorithm);\n\ntemplate void NNDescent<kDGLCPU, float, int32_t>(\n    const NDArray& points, const IdArray& offsets, IdArray result, const int k,\n    const int num_iters, const int num_candidates, const double delta);\ntemplate void NNDescent<kDGLCPU, float, int64_t>(\n    const NDArray& points, const IdArray& offsets, IdArray result, const int k,\n    const int num_iters, const int num_candidates, const double delta);\ntemplate void NNDescent<kDGLCPU, double, int32_t>(\n    const NDArray& points, const IdArray& offsets, IdArray result, const int k,\n    const int num_iters, const int 
num_candidates, const double delta);\ntemplate void NNDescent<kDGLCPU, double, int64_t>(\n    const NDArray& points, const IdArray& offsets, IdArray result, const int k,\n    const int num_iters, const int num_candidates, const double delta);\n}  // namespace transform\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/cuda/cuda_compact_graph.cu",
    "content": "/**\n *  Copyright 2021 Contributors\n *\n *  Licensed under the Apache License, Version 2.0 (the \"License\");\n *  you may not use this file except in compliance with the License.\n *  You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n *  Unless required by applicable law or agreed to in writing, software\n *  distributed under the License is distributed on an \"AS IS\" BASIS,\n *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *  See the License for the specific language governing permissions and\n *  limitations under the License.\n *\n * @file graph/transform/cuda/cuda_compact_graph.cu\n * @brief Functions to find and eliminate the common isolated nodes across\n * all given graphs with the same set of nodes.\n */\n\n#include <cuda_runtime.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/runtime/device_api.h>\n\n#include <algorithm>\n#include <memory>\n#include <utility>\n\n#include \"../../../runtime/cuda/cuda_common.h\"\n#include \"../../heterograph.h\"\n#include \"../compact.h\"\n#include \"cuda_map_edges.cuh\"\n\nusing namespace dgl::aten;\nusing namespace dgl::runtime::cuda;\nusing namespace dgl::transform::cuda;\n\nnamespace dgl {\nnamespace transform {\n\nnamespace {\n\n/**\n * @brief This function builds node maps for each node type, preserving the\n * order of the input nodes. 
Here it is assumed the nodes are not unique,\n * and thus a unique list is generated.\n *\n * @param input_nodes The set of input nodes.\n * @param node_maps The node maps to be constructed.\n * @param count_unique_device The number of unique nodes (on the GPU).\n * @param unique_nodes_device The unique nodes (on the GPU).\n * @param stream The stream to operate on.\n */\ntemplate <typename IdType>\nvoid BuildNodeMaps(\n    const std::vector<IdArray> &input_nodes,\n    DeviceNodeMap<IdType> *const node_maps, int64_t *const count_unique_device,\n    std::vector<IdArray> *const unique_nodes_device, cudaStream_t stream) {\n  const int64_t num_ntypes = static_cast<int64_t>(input_nodes.size());\n\n  CUDA_CALL(cudaMemsetAsync(\n      count_unique_device, 0, num_ntypes * sizeof(*count_unique_device),\n      stream));\n\n  // possibly duplicated nodes\n  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n    const IdArray &nodes = input_nodes[ntype];\n    if (nodes->shape[0] > 0) {\n      CHECK_EQ(nodes->ctx.device_type, kDGLCUDA);\n      node_maps->LhsHashTable(ntype).FillWithDuplicates(\n          nodes.Ptr<IdType>(), nodes->shape[0],\n          (*unique_nodes_device)[ntype].Ptr<IdType>(),\n          count_unique_device + ntype, stream);\n    }\n  }\n}\n\ntemplate <typename IdType>\nstd::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphsGPU(\n    const std::vector<HeteroGraphPtr> &graphs,\n    const std::vector<IdArray> &always_preserve) {\n  const auto &ctx = graphs[0]->Context();\n  auto device = runtime::DeviceAPI::Get(ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  CHECK_EQ(ctx.device_type, kDGLCUDA);\n\n  // Step 1: Collect the nodes that has connections for each type.\n  const uint64_t num_ntypes = graphs[0]->NumVertexTypes();\n  std::vector<std::vector<EdgeArray>> all_edges(\n      graphs.size());  // all_edges[i][etype]\n\n  // count the number of nodes per type\n  std::vector<int64_t> max_vertex_cnt(num_ntypes, 0);\n  
for (size_t i = 0; i < graphs.size(); ++i) {\n    const HeteroGraphPtr curr_graph = graphs[i];\n    const int64_t num_etypes = curr_graph->NumEdgeTypes();\n\n    for (IdType etype = 0; etype < num_etypes; ++etype) {\n      IdType srctype, dsttype;\n      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);\n\n      const int64_t n_edges = curr_graph->NumEdges(etype);\n      max_vertex_cnt[srctype] += n_edges;\n      max_vertex_cnt[dsttype] += n_edges;\n    }\n  }\n\n  for (size_t i = 0; i < always_preserve.size(); ++i) {\n    max_vertex_cnt[i] += always_preserve[i]->shape[0];\n  }\n\n  // gather all nodes\n  std::vector<IdArray> all_nodes(num_ntypes);\n  std::vector<int64_t> node_offsets(num_ntypes, 0);\n\n  for (uint64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n    all_nodes[ntype] =\n        NewIdArray(max_vertex_cnt[ntype], ctx, sizeof(IdType) * 8);\n    // copy the nodes in always_preserve\n    if (ntype < always_preserve.size() &&\n        always_preserve[ntype]->shape[0] > 0) {\n      device->CopyDataFromTo(\n          always_preserve[ntype].Ptr<IdType>(), 0,\n          all_nodes[ntype].Ptr<IdType>(), node_offsets[ntype],\n          sizeof(IdType) * always_preserve[ntype]->shape[0],\n          always_preserve[ntype]->ctx, all_nodes[ntype]->ctx,\n          always_preserve[ntype]->dtype);\n      node_offsets[ntype] += sizeof(IdType) * always_preserve[ntype]->shape[0];\n    }\n  }\n\n  for (size_t i = 0; i < graphs.size(); ++i) {\n    const HeteroGraphPtr curr_graph = graphs[i];\n    const int64_t num_etypes = curr_graph->NumEdgeTypes();\n\n    all_edges[i].reserve(num_etypes);\n    for (int64_t etype = 0; etype < num_etypes; ++etype) {\n      dgl_type_t srctype, dsttype;\n      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);\n\n      const EdgeArray edges = curr_graph->Edges(etype, \"eid\");\n\n      if (edges.src.defined()) {\n        device->CopyDataFromTo(\n            edges.src.Ptr<IdType>(), 0, 
all_nodes[srctype].Ptr<IdType>(),\n            node_offsets[srctype], sizeof(IdType) * edges.src->shape[0],\n            edges.src->ctx, all_nodes[srctype]->ctx, edges.src->dtype);\n        node_offsets[srctype] += sizeof(IdType) * edges.src->shape[0];\n      }\n      if (edges.dst.defined()) {\n        device->CopyDataFromTo(\n            edges.dst.Ptr<IdType>(), 0, all_nodes[dsttype].Ptr<IdType>(),\n            node_offsets[dsttype], sizeof(IdType) * edges.dst->shape[0],\n            edges.dst->ctx, all_nodes[dsttype]->ctx, edges.dst->dtype);\n        node_offsets[dsttype] += sizeof(IdType) * edges.dst->shape[0];\n      }\n      all_edges[i].push_back(edges);\n    }\n  }\n\n  // Step 2: Relabel the nodes for each type to a smaller ID space\n  //         using BuildNodeMaps\n\n  // allocate space for map creation\n  // the hashmap on GPU\n  DeviceNodeMap<IdType> node_maps(max_vertex_cnt, 0, ctx, stream);\n  // number of unique nodes per type on CPU\n  std::vector<int64_t> num_induced_nodes(num_ntypes);\n  // number of unique nodes per type on GPU\n  int64_t *count_unique_device = static_cast<int64_t *>(\n      device->AllocWorkspace(ctx, sizeof(int64_t) * num_ntypes));\n  // the set of unique nodes per type\n  std::vector<IdArray> induced_nodes(num_ntypes);\n  for (uint64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n    induced_nodes[ntype] =\n        NewIdArray(max_vertex_cnt[ntype], ctx, sizeof(IdType) * 8);\n  }\n\n  BuildNodeMaps(\n      all_nodes, &node_maps, count_unique_device, &induced_nodes, stream);\n\n  device->CopyDataFromTo(\n      count_unique_device, 0, num_induced_nodes.data(), 0,\n      sizeof(*num_induced_nodes.data()) * num_ntypes, ctx,\n      DGLContext{kDGLCPU, 0}, DGLDataType{kDGLInt, 64, 1});\n  device->StreamSync(ctx, stream);\n\n  // wait for the node counts to finish transferring\n  device->FreeWorkspace(ctx, count_unique_device);\n\n  // resize induced nodes\n  for (uint64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n    
induced_nodes[ntype]->shape[0] = num_induced_nodes[ntype];\n  }\n\n  // Step 3: Remap the edges of each graph using MapEdges\n  std::vector<HeteroGraphPtr> new_graphs;\n  for (size_t i = 0; i < graphs.size(); ++i) {\n    const HeteroGraphPtr curr_graph = graphs[i];\n    const auto meta_graph = curr_graph->meta_graph();\n    const int64_t num_etypes = curr_graph->NumEdgeTypes();\n\n    std::vector<HeteroGraphPtr> rel_graphs;\n    rel_graphs.reserve(num_etypes);\n\n    std::vector<IdArray> new_src;\n    std::vector<IdArray> new_dst;\n    std::tie(new_src, new_dst) =\n        MapEdges(curr_graph, all_edges[i], node_maps, stream);\n\n    for (IdType etype = 0; etype < num_etypes; ++etype) {\n      IdType srctype, dsttype;\n      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);\n\n      rel_graphs.push_back(UnitGraph::CreateFromCOO(\n          srctype == dsttype ? 1 : 2, induced_nodes[srctype]->shape[0],\n          induced_nodes[dsttype]->shape[0], new_src[etype], new_dst[etype]));\n    }\n\n    new_graphs.push_back(\n        CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));\n  }\n\n  return std::make_pair(new_graphs, induced_nodes);\n}\n\n}  // namespace\n\ntemplate <>\nstd::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>\nCompactGraphs<kDGLCUDA, int32_t>(\n    const std::vector<HeteroGraphPtr> &graphs,\n    const std::vector<IdArray> &always_preserve) {\n  return CompactGraphsGPU<int32_t>(graphs, always_preserve);\n}\n\ntemplate <>\nstd::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>\nCompactGraphs<kDGLCUDA, int64_t>(\n    const std::vector<HeteroGraphPtr> &graphs,\n    const std::vector<IdArray> &always_preserve) {\n  return CompactGraphsGPU<int64_t>(graphs, always_preserve);\n}\n\n}  // namespace transform\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/cuda/cuda_map_edges.cuh",
    "content": "/**\n *  Copyright 2020-2022 Contributors\n *\n *  Licensed under the Apache License, Version 2.0 (the \"License\");\n *  you may not use this file except in compliance with the License.\n *  You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n *  Unless required by applicable law or agreed to in writing, software\n *  distributed under the License is distributed on an \"AS IS\" BASIS,\n *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *  See the License for the specific language governing permissions and\n *  limitations under the License.\n *\n * @file graph/transform/cuda/cuda_map_edges.cuh\n * @brief Device level functions for mapping edges.\n */\n\n#ifndef DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_\n#define DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_\n\n#include <dgl/runtime/c_runtime_api.h>\n#include <dgl/base_heterograph.h>\n#include <cuda_runtime.h>\n#include <dgl/runtime/c_runtime_api.h>\n\n#include <algorithm>\n#include <memory>\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"../../../runtime/cuda/cuda_common.h\"\n#include \"../../../runtime/cuda/cuda_hashtable.cuh\"\n\nusing namespace dgl::aten;\nusing namespace dgl::runtime::cuda;\n\nnamespace dgl {\nnamespace transform {\n\nnamespace cuda {\n\ntemplate <typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>\n__device__ void map_vertex_ids(\n    const IdType* const global, IdType* const new_global,\n    const IdType num_vertices, const DeviceOrderedHashTable<IdType>& table) {\n  assert(BLOCK_SIZE == blockDim.x);\n\n  using Mapping = typename OrderedHashTable<IdType>::Mapping;\n\n  const IdType tile_start = TILE_SIZE * blockIdx.x;\n  const IdType tile_end = min(TILE_SIZE * (blockIdx.x + 1), num_vertices);\n\n  for (IdType idx = threadIdx.x + tile_start; idx < tile_end;\n       idx += BLOCK_SIZE) {\n    const Mapping& mapping = *table.Search(global[idx]);\n    new_global[idx] = 
mapping.local;\n  }\n}\n\n/**\n * @brief Generate mapped edge endpoint ids.\n *\n * @tparam IdType The type of id.\n * @tparam BLOCK_SIZE The size of each thread block.\n * @tparam TILE_SIZE The number of edges to process per thread block.\n * @param global_srcs_device The source ids to map.\n * @param new_global_srcs_device The mapped source ids (output).\n * @param global_dsts_device The destination ids to map.\n * @param new_global_dsts_device The mapped destination ids (output).\n * @param num_edges The number of edges to map.\n * @param src_mapping The mapping of source ids.\n * @param src_hash_size The size of source id hash table/mapping.\n * @param dst_mapping The mapping of destination ids.\n * @param dst_hash_size The size of destination id hash table/mapping.\n */\ntemplate <typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>\n__global__ void map_edge_ids(\n    const IdType* const global_srcs_device,\n    IdType* const new_global_srcs_device,\n    const IdType* const global_dsts_device,\n    IdType* const new_global_dsts_device, const IdType num_edges,\n    DeviceOrderedHashTable<IdType> src_mapping,\n    DeviceOrderedHashTable<IdType> dst_mapping) {\n  assert(BLOCK_SIZE == blockDim.x);\n  assert(2 == gridDim.y);\n\n  if (blockIdx.y == 0) {\n    map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>(\n        global_srcs_device, new_global_srcs_device, num_edges, src_mapping);\n  } else {\n    map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>(\n        global_dsts_device, new_global_dsts_device, num_edges, dst_mapping);\n  }\n}\n\n/**\n * @brief Device level node maps for each node type.\n *\n * @param num_nodes Number of nodes per type.\n * @param offset When offset is set to 0, LhsHashTable is identical to\n *        RhsHashTable. 
Or set to num_nodes.size()/2 to use separated\n *        LhsHashTable and RhsHashTable.\n * @param ctx The DGL context.\n * @param stream The stream to operate on.\n */\ntemplate <typename IdType>\nclass DeviceNodeMap {\n public:\n  using Mapping = typename OrderedHashTable<IdType>::Mapping;\n\n  DeviceNodeMap(\n      const std::vector<int64_t>& num_nodes, const int64_t offset,\n      DGLContext ctx, cudaStream_t stream)\n      : num_types_(num_nodes.size()),\n        rhs_offset_(offset),\n        hash_tables_(),\n        ctx_(ctx) {\n    auto device = runtime::DeviceAPI::Get(ctx);\n\n    hash_tables_.reserve(num_types_);\n    for (int64_t i = 0; i < num_types_; ++i) {\n      hash_tables_.emplace_back(\n          new OrderedHashTable<IdType>(num_nodes[i], ctx_, stream));\n    }\n  }\n\n  OrderedHashTable<IdType>& LhsHashTable(const size_t index) {\n    return HashData(index);\n  }\n\n  OrderedHashTable<IdType>& RhsHashTable(const size_t index) {\n    return HashData(index + rhs_offset_);\n  }\n\n  const OrderedHashTable<IdType>& LhsHashTable(const size_t index) const {\n    return HashData(index);\n  }\n\n  const OrderedHashTable<IdType>& RhsHashTable(const size_t index) const {\n    return HashData(index + rhs_offset_);\n  }\n\n  IdType LhsHashSize(const size_t index) const { return HashSize(index); }\n\n  IdType RhsHashSize(const size_t index) const {\n    return HashSize(rhs_offset_ + index);\n  }\n\n  size_t Size() const { return hash_tables_.size(); }\n\n private:\n  int64_t num_types_;\n  size_t rhs_offset_;\n  std::vector<std::unique_ptr<OrderedHashTable<IdType>>> hash_tables_;\n  DGLContext ctx_;\n\n  inline OrderedHashTable<IdType>& HashData(const size_t index) {\n    CHECK_LT(index, hash_tables_.size());\n    return *hash_tables_[index];\n  }\n\n  inline const OrderedHashTable<IdType>& HashData(const size_t index) const {\n    CHECK_LT(index, hash_tables_.size());\n    return *hash_tables_[index];\n  }\n\n  inline IdType HashSize(const size_t index) const 
{\n    return HashData(index).size();\n  }\n};\n\ntemplate <typename IdType>\ninline size_t RoundUpDiv(const IdType num, const size_t divisor) {\n  return static_cast<IdType>(num / divisor) + (num % divisor == 0 ? 0 : 1);\n}\n\ntemplate <typename IdType>\ninline IdType RoundUp(const IdType num, const size_t unit) {\n  return RoundUpDiv(num, unit) * unit;\n}\n\ntemplate <typename IdType>\nstd::tuple<std::vector<IdArray>, std::vector<IdArray>> MapEdges(\n    HeteroGraphPtr graph, const std::vector<EdgeArray>& edge_sets,\n    const DeviceNodeMap<IdType>& node_map, cudaStream_t stream) {\n  constexpr const int BLOCK_SIZE = 128;\n  constexpr const size_t TILE_SIZE = 1024;\n\n  const auto& ctx = graph->Context();\n\n  std::vector<IdArray> new_lhs;\n  new_lhs.reserve(edge_sets.size());\n  std::vector<IdArray> new_rhs;\n  new_rhs.reserve(edge_sets.size());\n\n  // The next performance optimization here is to perform mapping of all edge\n  // types in a single kernel launch.\n  const int64_t num_edge_sets = static_cast<int64_t>(edge_sets.size());\n  for (int64_t etype = 0; etype < num_edge_sets; ++etype) {\n    const EdgeArray& edges = edge_sets[etype];\n    if (edges.id.defined() && edges.src->shape[0] > 0) {\n      const int64_t num_edges = edges.src->shape[0];\n\n      new_lhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType) * 8));\n      new_rhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType) * 8));\n\n      const auto src_dst_types = graph->GetEndpointTypes(etype);\n      const int src_type = src_dst_types.first;\n      const int dst_type = src_dst_types.second;\n\n      const dim3 grid(RoundUpDiv(num_edges, TILE_SIZE), 2);\n      const dim3 block(BLOCK_SIZE);\n\n      // map the srcs\n      CUDA_KERNEL_CALL(\n          (map_edge_ids<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,\n          edges.src.Ptr<IdType>(), new_lhs.back().Ptr<IdType>(),\n          edges.dst.Ptr<IdType>(), new_rhs.back().Ptr<IdType>(), num_edges,\n          
node_map.LhsHashTable(src_type).DeviceHandle(),\n          node_map.RhsHashTable(dst_type).DeviceHandle());\n    } else {\n      new_lhs.emplace_back(\n          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));\n      new_rhs.emplace_back(\n          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));\n    }\n  }\n\n  return std::tuple<std::vector<IdArray>, std::vector<IdArray>>(\n      std::move(new_lhs), std::move(new_rhs));\n}\n\n}  // namespace cuda\n}  // namespace transform\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_\n"
  },
  {
    "path": "src/graph/transform/cuda/cuda_to_block.cu",
    "content": "/**\n *  Copyright 2020-2021 Contributors\n *\n *  Licensed under the Apache License, Version 2.0 (the \"License\");\n *  you may not use this file except in compliance with the License.\n *  You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n *  Unless required by applicable law or agreed to in writing, software\n *  distributed under the License is distributed on an \"AS IS\" BASIS,\n *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *  See the License for the specific language governing permissions and\n *  limitations under the License.\n *\n * @file graph/transform/cuda/cuda_to_block.cu\n * @brief Functions to convert a set of edges into a graph block with local\n * ids.\n *\n * Tested via python wrapper: python/dgl/path/to/to_block.py\n */\n\n#include <cuda_runtime.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/tensordispatch.h>\n\n#include <algorithm>\n#include <memory>\n#include <utility>\n\n#include \"../../../runtime/cuda/cuda_common.h\"\n#include \"../../heterograph.h\"\n#include \"../to_block.h\"\n#include \"cuda_map_edges.cuh\"\n\nusing namespace dgl::aten;\nusing namespace dgl::runtime::cuda;\nusing namespace dgl::transform::cuda;\nusing TensorDispatcher = dgl::runtime::TensorDispatcher;\n\nnamespace dgl {\nnamespace transform {\n\nnamespace {\n\ntemplate <typename IdType>\nclass DeviceNodeMapMaker {\n public:\n  explicit DeviceNodeMapMaker(const std::vector<int64_t>& maxNodesPerType)\n      : max_num_nodes_(0) {\n    max_num_nodes_ =\n        *std::max_element(maxNodesPerType.begin(), maxNodesPerType.end());\n  }\n\n  /**\n   * @brief This function builds node maps for each node type, preserving the\n   * order of the input nodes. 
Here it is assumed the lhs_nodes are not unique,\n   * and thus a unique list is generated.\n   *\n   * @param lhs_nodes The set of source input nodes.\n   * @param rhs_nodes The set of destination input nodes.\n   * @param node_maps The node maps to be constructed.\n   * @param count_lhs_device The number of unique source nodes (on the GPU).\n   * @param lhs_device The unique source nodes (on the GPU).\n   * @param stream The stream to operate on.\n   */\n  void Make(\n      const std::vector<IdArray>& lhs_nodes,\n      const std::vector<IdArray>& rhs_nodes,\n      DeviceNodeMap<IdType>* const node_maps, int64_t* const count_lhs_device,\n      std::vector<IdArray>* const lhs_device, cudaStream_t stream) {\n    const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size();\n\n    CUDA_CALL(cudaMemsetAsync(\n        count_lhs_device, 0, num_ntypes * sizeof(*count_lhs_device), stream));\n\n    // possibly duplicate lhs nodes\n    const int64_t lhs_num_ntypes = static_cast<int64_t>(lhs_nodes.size());\n    for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) {\n      const IdArray& nodes = lhs_nodes[ntype];\n      if (nodes->shape[0] > 0) {\n        CHECK_EQ(nodes->ctx.device_type, kDGLCUDA);\n        node_maps->LhsHashTable(ntype).FillWithDuplicates(\n            nodes.Ptr<IdType>(), nodes->shape[0],\n            (*lhs_device)[ntype].Ptr<IdType>(), count_lhs_device + ntype,\n            stream);\n      }\n    }\n\n    // unique rhs nodes\n    const int64_t rhs_num_ntypes = static_cast<int64_t>(rhs_nodes.size());\n    for (int64_t ntype = 0; ntype < rhs_num_ntypes; ++ntype) {\n      const IdArray& nodes = rhs_nodes[ntype];\n      if (nodes->shape[0] > 0) {\n        node_maps->RhsHashTable(ntype).FillWithUnique(\n            nodes.Ptr<IdType>(), nodes->shape[0], stream);\n      }\n    }\n  }\n\n  /**\n   * @brief This function builds node maps for each node type, preserving the\n   * order of the input nodes. 
Here it is assumed both lhs_nodes and rhs_nodes\n   * are unique.\n   *\n   * @param lhs_nodes The set of source input nodes.\n   * @param rhs_nodes The set of destination input nodes.\n   * @param node_maps The node maps to be constructed.\n   * @param stream The stream to operate on.\n   */\n  void Make(\n      const std::vector<IdArray>& lhs_nodes,\n      const std::vector<IdArray>& rhs_nodes,\n      DeviceNodeMap<IdType>* const node_maps, cudaStream_t stream) {\n    const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size();\n\n    // unique lhs nodes\n    const int64_t lhs_num_ntypes = static_cast<int64_t>(lhs_nodes.size());\n    for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) {\n      const IdArray& nodes = lhs_nodes[ntype];\n      if (nodes->shape[0] > 0) {\n        CHECK_EQ(nodes->ctx.device_type, kDGLCUDA);\n        node_maps->LhsHashTable(ntype).FillWithUnique(\n            nodes.Ptr<IdType>(), nodes->shape[0], stream);\n      }\n    }\n\n    // unique rhs nodes\n    const int64_t rhs_num_ntypes = static_cast<int64_t>(rhs_nodes.size());\n    for (int64_t ntype = 0; ntype < rhs_num_ntypes; ++ntype) {\n      const IdArray& nodes = rhs_nodes[ntype];\n      if (nodes->shape[0] > 0) {\n        node_maps->RhsHashTable(ntype).FillWithUnique(\n            nodes.Ptr<IdType>(), nodes->shape[0], stream);\n      }\n    }\n  }\n\n private:\n  IdType max_num_nodes_;\n};\n\ntemplate <typename IdType>\nstruct CUDAIdsMapper {\n  std::tuple<std::vector<IdArray>, std::vector<IdArray>> operator()(\n      const HeteroGraphPtr& graph, bool include_rhs_in_lhs, int64_t num_ntypes,\n      const DGLContext& ctx, const std::vector<int64_t>& maxNodesPerType,\n      const std::vector<EdgeArray>& edge_arrays,\n      const std::vector<IdArray>& src_nodes,\n      const std::vector<IdArray>& rhs_nodes,\n      std::vector<IdArray>* const lhs_nodes_ptr,\n      std::vector<int64_t>* const num_nodes_per_type_ptr) {\n    std::vector<IdArray>& lhs_nodes = *lhs_nodes_ptr;\n    
std::vector<int64_t>& num_nodes_per_type = *num_nodes_per_type_ptr;\n    const bool generate_lhs_nodes = lhs_nodes.empty();\n    auto device = runtime::DeviceAPI::Get(ctx);\n    cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n    // Allocate space for map creation process.\n    DeviceNodeMapMaker<IdType> maker(maxNodesPerType);\n    DeviceNodeMap<IdType> node_maps(maxNodesPerType, num_ntypes, ctx, stream);\n    if (generate_lhs_nodes) {\n      lhs_nodes.reserve(num_ntypes);\n      for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n        lhs_nodes.emplace_back(\n            NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8));\n      }\n    }\n\n    cudaEvent_t copyEvent;\n    NDArray new_len_tensor;\n    // Populate the mappings.\n    if (generate_lhs_nodes) {\n      int64_t* count_lhs_device = static_cast<int64_t*>(\n          device->AllocWorkspace(ctx, sizeof(int64_t) * num_ntypes * 2));\n\n      maker.Make(\n          src_nodes, rhs_nodes, &node_maps, count_lhs_device, &lhs_nodes,\n          stream);\n\n      CUDA_CALL(cudaEventCreate(&copyEvent));\n      if (TensorDispatcher::Global()->IsAvailable()) {\n        new_len_tensor = NDArray::PinnedEmpty(\n            {num_ntypes}, DGLDataTypeTraits<int64_t>::dtype,\n            DGLContext{kDGLCPU, 0});\n      } else {\n        // use pageable memory, it will unnecessarily block but be functional\n        new_len_tensor = NDArray::Empty(\n            {num_ntypes}, DGLDataTypeTraits<int64_t>::dtype,\n            DGLContext{kDGLCPU, 0});\n      }\n      CUDA_CALL(cudaMemcpyAsync(\n          new_len_tensor->data, count_lhs_device,\n          sizeof(*num_nodes_per_type.data()) * num_ntypes,\n          cudaMemcpyDeviceToHost, stream));\n      CUDA_CALL(cudaEventRecord(copyEvent, stream));\n\n      device->FreeWorkspace(ctx, count_lhs_device);\n    } else {\n      maker.Make(lhs_nodes, rhs_nodes, &node_maps, stream);\n\n      for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n        
num_nodes_per_type[ntype] = lhs_nodes[ntype]->shape[0];\n      }\n    }\n    // Map node numberings from global to local, and build pointer for CSR.\n    auto ret = MapEdges(graph, edge_arrays, node_maps, stream);\n\n    if (generate_lhs_nodes) {\n      // wait for the previous copy\n      CUDA_CALL(cudaEventSynchronize(copyEvent));\n      CUDA_CALL(cudaEventDestroy(copyEvent));\n\n      // Resize lhs nodes.\n      for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n        num_nodes_per_type[ntype] =\n            static_cast<int64_t*>(new_len_tensor->data)[ntype];\n        lhs_nodes[ntype]->shape[0] = num_nodes_per_type[ntype];\n      }\n    }\n\n    return ret;\n  }\n};\n\ntemplate <typename IdType>\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU(\n    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes_ptr) {\n  return dgl::transform::ProcessToBlock<IdType>(\n      graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes_ptr,\n      CUDAIdsMapper<IdType>());\n}\n\n}  // namespace\n\n// Use explicit names to get around MSVC's broken mangling that thinks the\n// following two functions are the same. Using template<> fails to export the\n// symbols.\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>>\n// ToBlock<kDGLCUDA, int32_t>\nToBlockGPU32(\n    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes) {\n  return ToBlockGPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);\n}\n\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>>\n// ToBlock<kDGLCUDA, int64_t>\nToBlockGPU64(\n    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes) {\n  return ToBlockGPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);\n}\n\n}  // namespace transform\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/cuda/knn.cu",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file graph/transform/cuda/knn.cu\n * @brief k-nearest-neighbor (KNN) implementation (cuda)\n */\n\n#include <curand_kernel.h>\n#include <dgl/array.h>\n#include <dgl/random.h>\n#include <dgl/runtime/device_api.h>\n\n#include <algorithm>\n#include <cub/cub.cuh>  // NOLINT\n#include <limits>\n#include <string>\n#include <type_traits>\n#include <vector>\n\n#include \"../../../array/cuda/utils.h\"\n#include \"../../../runtime/cuda/cuda_common.h\"\n#include \"../knn.h\"\n\nnamespace dgl {\nnamespace transform {\nnamespace impl {\n\n/**\n * @brief Given input `size`, find the smallest value\n * greater or equal to `size` that is a multiple of `align`.\n *\n * e.g. Pow2Align(17, 4) = 20, Pow2Align(17, 8) = 24\n */\ntemplate <typename Type>\nstatic __host__ __device__ std::enable_if_t<std::is_unsigned<Type>::value, Type>\nPow2Align(Type size, Type align) {\n  if (align <= 1 || size <= 0) return size;\n  return ((size - 1) | (align - 1)) + 1;\n}\n\n/**\n * @brief Utility class used to avoid linker errors with extern\n *  unsized shared memory arrays with templated type\n */\ntemplate <typename Type>\nstruct SharedMemory {\n  __device__ inline operator Type*() {\n    extern __shared__ int __smem[];\n    return reinterpret_cast<Type*>(__smem);\n  }\n\n  __device__ inline operator const Type*() const {\n    extern __shared__ int __smem[];\n    return reinterpret_cast<Type*>(__smem);\n  }\n};\n\n// specialize for double to avoid unaligned memory\n// access compile errors\ntemplate <>\nstruct SharedMemory<double> {\n  __device__ inline operator double*() {\n    extern __shared__ double __smem_d[];\n    return reinterpret_cast<double*>(__smem_d);\n  }\n\n  __device__ inline operator const double*() const {\n    extern __shared__ double __smem_d[];\n    return reinterpret_cast<double*>(__smem_d);\n  }\n};\n\n/** @brief Compute Euclidean distance between two vectors in a cuda kernel */\ntemplate <typename FloatType, 
typename IdType>\n__device__ FloatType\nEuclideanDist(const FloatType* vec1, const FloatType* vec2, const int64_t dim) {\n  FloatType dist = 0;\n  IdType idx = 0;\n  for (; idx < dim - 3; idx += 4) {\n    FloatType diff0 = vec1[idx] - vec2[idx];\n    FloatType diff1 = vec1[idx + 1] - vec2[idx + 1];\n    FloatType diff2 = vec1[idx + 2] - vec2[idx + 2];\n    FloatType diff3 = vec1[idx + 3] - vec2[idx + 3];\n\n    dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;\n  }\n\n  for (; idx < dim; ++idx) {\n    FloatType diff = vec1[idx] - vec2[idx];\n    dist += diff * diff;\n  }\n\n  return dist;\n}\n\n/**\n * @brief Compute Euclidean distance between two vectors in a cuda kernel,\n *  return positive infinite value if the intermediate distance is greater\n *  than the worst distance.\n */\ntemplate <typename FloatType, typename IdType>\n__device__ FloatType EuclideanDistWithCheck(\n    const FloatType* vec1, const FloatType* vec2, const int64_t dim,\n    const FloatType worst_dist) {\n  FloatType dist = 0;\n  IdType idx = 0;\n  bool early_stop = false;\n\n  for (; idx < dim - 3; idx += 4) {\n    FloatType diff0 = vec1[idx] - vec2[idx];\n    FloatType diff1 = vec1[idx + 1] - vec2[idx + 1];\n    FloatType diff2 = vec1[idx + 2] - vec2[idx + 2];\n    FloatType diff3 = vec1[idx + 3] - vec2[idx + 3];\n\n    dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;\n    if (dist > worst_dist) {\n      early_stop = true;\n      idx = dim;\n      break;\n    }\n  }\n\n  for (; idx < dim; ++idx) {\n    FloatType diff = vec1[idx] - vec2[idx];\n    dist += diff * diff;\n    if (dist > worst_dist) {\n      early_stop = true;\n      break;\n    }\n  }\n\n  if (early_stop) {\n    return std::numeric_limits<FloatType>::max();\n  } else {\n    return dist;\n  }\n}\n\ntemplate <typename FloatType, typename IdType>\n__device__ void BuildHeap(IdType* indices, FloatType* dists, int size) {\n  for (int i = size / 2 - 1; i >= 0; --i) {\n    IdType idx = i;\n    
while (true) {\n      IdType largest = idx;\n      IdType left = idx * 2 + 1;\n      IdType right = left + 1;\n      if (left < size && dists[left] > dists[largest]) {\n        largest = left;\n      }\n      if (right < size && dists[right] > dists[largest]) {\n        largest = right;\n      }\n      if (largest != idx) {\n        IdType tmp_idx = indices[largest];\n        indices[largest] = indices[idx];\n        indices[idx] = tmp_idx;\n\n        FloatType tmp_dist = dists[largest];\n        dists[largest] = dists[idx];\n        dists[idx] = tmp_dist;\n        idx = largest;\n      } else {\n        break;\n      }\n    }\n  }\n}\n\ntemplate <typename FloatType, typename IdType>\n__device__ void HeapInsert(\n    IdType* indices, FloatType* dist, IdType new_idx, FloatType new_dist,\n    int size, bool check_repeat = false) {\n  if (new_dist > dist[0]) return;\n\n  // check if we have it\n  if (check_repeat) {\n    for (IdType i = 0; i < size; ++i) {\n      if (indices[i] == new_idx) return;\n    }\n  }\n\n  IdType left = 0, right = 0, idx = 0, largest = 0;\n  dist[0] = new_dist;\n  indices[0] = new_idx;\n  while (true) {\n    left = idx * 2 + 1;\n    right = left + 1;\n    if (left < size && dist[left] > dist[largest]) {\n      largest = left;\n    }\n    if (right < size && dist[right] > dist[largest]) {\n      largest = right;\n    }\n    if (largest != idx) {\n      IdType tmp_idx = indices[idx];\n      indices[idx] = indices[largest];\n      indices[largest] = tmp_idx;\n\n      FloatType tmp_dist = dist[idx];\n      dist[idx] = dist[largest];\n      dist[largest] = tmp_dist;\n\n      idx = largest;\n    } else {\n      break;\n    }\n  }\n}\n\ntemplate <typename FloatType, typename IdType>\n__device__ bool FlaggedHeapInsert(\n    IdType* indices, FloatType* dist, bool* flags, IdType new_idx,\n    FloatType new_dist, bool new_flag, int size, bool check_repeat = false) {\n  if (new_dist > dist[0]) return false;\n\n  // check if we have it\n  if (check_repeat) 
{\n    for (IdType i = 0; i < size; ++i) {\n      if (indices[i] == new_idx) return false;\n    }\n  }\n\n  IdType left = 0, right = 0, idx = 0, largest = 0;\n  dist[0] = new_dist;\n  indices[0] = new_idx;\n  flags[0] = new_flag;\n  while (true) {\n    left = idx * 2 + 1;\n    right = left + 1;\n    if (left < size && dist[left] > dist[largest]) {\n      largest = left;\n    }\n    if (right < size && dist[right] > dist[largest]) {\n      largest = right;\n    }\n    if (largest != idx) {\n      IdType tmp_idx = indices[idx];\n      indices[idx] = indices[largest];\n      indices[largest] = tmp_idx;\n\n      FloatType tmp_dist = dist[idx];\n      dist[idx] = dist[largest];\n      dist[largest] = tmp_dist;\n\n      bool tmp_flag = flags[idx];\n      flags[idx] = flags[largest];\n      flags[largest] = tmp_flag;\n\n      idx = largest;\n    } else {\n      break;\n    }\n  }\n  return true;\n}\n\n/**\n * @brief Brute force kNN kernel. Compute distance for each pair of input points\n * and get the result directly (without a distance matrix).\n */\ntemplate <typename FloatType, typename IdType>\n__global__ void BruteforceKnnKernel(\n    const FloatType* data_points, const IdType* data_offsets,\n    const FloatType* query_points, const IdType* query_offsets, const int k,\n    FloatType* dists, IdType* query_out, IdType* data_out,\n    const int64_t num_batches, const int64_t feature_size) {\n  const IdType q_idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (q_idx >= query_offsets[num_batches]) return;\n  IdType batch_idx = 0;\n  for (IdType b = 0; b < num_batches + 1; ++b) {\n    if (query_offsets[b] > q_idx) {\n      batch_idx = b - 1;\n      break;\n    }\n  }\n  const IdType data_start = data_offsets[batch_idx],\n               data_end = data_offsets[batch_idx + 1];\n\n  for (IdType k_idx = 0; k_idx < k; ++k_idx) {\n    query_out[q_idx * k + k_idx] = q_idx;\n    dists[q_idx * k + k_idx] = std::numeric_limits<FloatType>::max();\n  }\n  FloatType worst_dist = 
std::numeric_limits<FloatType>::max();\n\n  for (IdType d_idx = data_start; d_idx < data_end; ++d_idx) {\n    FloatType tmp_dist = EuclideanDistWithCheck<FloatType, IdType>(\n        query_points + q_idx * feature_size, data_points + d_idx * feature_size,\n        feature_size, worst_dist);\n\n    IdType out_offset = q_idx * k;\n    HeapInsert<FloatType, IdType>(\n        data_out + out_offset, dists + out_offset, d_idx, tmp_dist, k);\n    worst_dist = dists[q_idx * k];\n  }\n}\n\n/**\n * @brief Same as BruteforceKnnKernel, but use shared memory as buffer.\n *  This kernel divides query points and data points into blocks. For each\n *  query block, it will make a loop over all data blocks and compute distances.\n *  This kernel is faster when the dimension of input points is not large.\n */\ntemplate <typename FloatType, typename IdType>\n__global__ void BruteforceKnnShareKernel(\n    const FloatType* data_points, const IdType* data_offsets,\n    const FloatType* query_points, const IdType* query_offsets,\n    const IdType* block_batch_id, const IdType* local_block_id, const int k,\n    FloatType* dists, IdType* query_out, IdType* data_out,\n    const int64_t num_batches, const int64_t feature_size) {\n  const IdType block_idx = static_cast<IdType>(blockIdx.x);\n  const IdType block_size = static_cast<IdType>(blockDim.x);\n  const IdType batch_idx = block_batch_id[block_idx];\n  const IdType local_bid = local_block_id[block_idx];\n  const IdType query_start = query_offsets[batch_idx] + block_size * local_bid;\n  const IdType query_end =\n      min(query_start + block_size, query_offsets[batch_idx + 1]);\n  if (query_start >= query_end) return;\n  const IdType query_idx = query_start + threadIdx.x;\n  const IdType data_start = data_offsets[batch_idx];\n  const IdType data_end = data_offsets[batch_idx + 1];\n\n  // shared memory: points in block + distance buffer + result buffer\n  FloatType* data_buff = SharedMemory<FloatType>();\n  FloatType* query_buff = data_buff 
+ block_size * feature_size;\n  FloatType* dist_buff = query_buff + block_size * feature_size;\n  IdType* res_buff = reinterpret_cast<IdType*>(Pow2Align<uint64_t>(\n      reinterpret_cast<uint64_t>(dist_buff + block_size * k), sizeof(IdType)));\n  FloatType worst_dist = std::numeric_limits<FloatType>::max();\n\n  // initialize dist buff with inf value\n  for (auto i = 0; i < k; ++i) {\n    dist_buff[threadIdx.x + i * block_size] =\n        std::numeric_limits<FloatType>::max();\n  }\n\n  // load query data to shared memory\n  // TODO(tianqi): could be better here to exploit coalesce global memory\n  // access.\n  if (query_idx < query_end) {\n    for (auto i = 0; i < feature_size; ++i) {\n      // to avoid bank conflict, we use transpose here\n      query_buff[threadIdx.x + i * block_size] =\n          query_points[query_idx * feature_size + i];\n    }\n  }\n\n  // perform computation on each tile\n  for (auto tile_start = data_start; tile_start < data_end;\n       tile_start += block_size) {\n    // each thread load one data point into the shared memory\n    IdType load_idx = tile_start + threadIdx.x;\n    if (load_idx < data_end) {\n      for (auto i = 0; i < feature_size; ++i) {\n        data_buff[threadIdx.x * feature_size + i] =\n            data_points[load_idx * feature_size + i];\n      }\n    }\n    __syncthreads();\n\n    // compute distance for one tile\n    IdType true_block_size = min(data_end - tile_start, block_size);\n    if (query_idx < query_end) {\n      for (IdType d_idx = 0; d_idx < true_block_size; ++d_idx) {\n        FloatType tmp_dist = 0;\n        bool early_stop = false;\n        IdType dim_idx = 0;\n\n        for (; dim_idx < feature_size - 3; dim_idx += 4) {\n          FloatType diff0 = query_buff[threadIdx.x + block_size * (dim_idx)] -\n                            data_buff[d_idx * feature_size + dim_idx];\n          FloatType diff1 =\n              query_buff[threadIdx.x + block_size * (dim_idx + 1)] -\n              data_buff[d_idx * 
feature_size + dim_idx + 1];\n          FloatType diff2 =\n              query_buff[threadIdx.x + block_size * (dim_idx + 2)] -\n              data_buff[d_idx * feature_size + dim_idx + 2];\n          FloatType diff3 =\n              query_buff[threadIdx.x + block_size * (dim_idx + 3)] -\n              data_buff[d_idx * feature_size + dim_idx + 3];\n\n          tmp_dist +=\n              diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;\n\n          if (tmp_dist > worst_dist) {\n            early_stop = true;\n            dim_idx = feature_size;\n            break;\n          }\n        }\n\n        for (; dim_idx < feature_size; ++dim_idx) {\n          const FloatType diff =\n              query_buff[threadIdx.x + dim_idx * block_size] -\n              data_buff[d_idx * feature_size + dim_idx];\n          tmp_dist += diff * diff;\n\n          if (tmp_dist > worst_dist) {\n            early_stop = true;\n            break;\n          }\n        }\n\n        if (early_stop) continue;\n\n        HeapInsert<FloatType, IdType>(\n            res_buff + threadIdx.x * k, dist_buff + threadIdx.x * k,\n            d_idx + tile_start, tmp_dist, k);\n        worst_dist = dist_buff[threadIdx.x * k];\n      }\n    }\n    __syncthreads();\n  }\n\n  // copy result to global memory\n  if (query_idx < query_end) {\n    for (auto i = 0; i < k; ++i) {\n      dists[query_idx * k + i] = dist_buff[threadIdx.x * k + i];\n      data_out[query_idx * k + i] = res_buff[threadIdx.x * k + i];\n      query_out[query_idx * k + i] = query_idx;\n    }\n  }\n}\n\n/** @brief determine the number of blocks for each segment */\ntemplate <typename IdType>\n__global__ void GetNumBlockPerSegment(\n    const IdType* offsets, IdType* out, const int64_t batch_size,\n    const int64_t block_size) {\n  const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx < batch_size) {\n    out[idx] = (offsets[idx + 1] - offsets[idx] - 1) / block_size + 1;\n  }\n}\n\n/** @brief Get the batch 
index and local index in segment for each block */\ntemplate <typename IdType>\n__global__ void GetBlockInfo(\n    const IdType* num_block_prefixsum, IdType* block_batch_id,\n    IdType* local_block_id, size_t batch_size, size_t num_blocks) {\n  const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;\n  IdType i = 0;\n\n  if (idx < num_blocks) {\n    for (; i < batch_size; ++i) {\n      if (num_block_prefixsum[i] > idx) break;\n    }\n    i--;\n    block_batch_id[idx] = i;\n    local_block_id[idx] = idx - num_block_prefixsum[i];\n  }\n}\n\n/**\n * @brief Brute force kNN. Compute distance for each pair of input points and\n * get the result directly (without a distance matrix).\n *\n * @tparam FloatType The type of input points.\n * @tparam IdType The type of id.\n * @param data_points NDArray of dataset points.\n * @param data_offsets offsets of point index in data points.\n * @param query_points NDArray of query points\n * @param query_offsets offsets of point index in query points.\n * @param k the number of nearest points\n * @param result output array\n */\ntemplate <typename FloatType, typename IdType>\nvoid BruteForceKNNCuda(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const auto& ctx = data_points->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  const int64_t batch_size = data_offsets->shape[0] - 1;\n  const int64_t feature_size = data_points->shape[1];\n  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();\n  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();\n  const FloatType* data_points_data = data_points.Ptr<FloatType>();\n  const FloatType* query_points_data = query_points.Ptr<FloatType>();\n  IdType* query_out = result.Ptr<IdType>();\n  IdType* data_out = query_out + k * query_points->shape[0];\n\n  FloatType* dists = 
static_cast<FloatType*>(device->AllocWorkspace(\n      ctx, k * query_points->shape[0] * sizeof(FloatType)));\n\n  const int64_t block_size = cuda::FindNumThreads(query_points->shape[0]);\n  const int64_t num_blocks = (query_points->shape[0] - 1) / block_size + 1;\n  CUDA_KERNEL_CALL(\n      BruteforceKnnKernel, num_blocks, block_size, 0, stream, data_points_data,\n      data_offsets_data, query_points_data, query_offsets_data, k, dists,\n      query_out, data_out, batch_size, feature_size);\n\n  device->FreeWorkspace(ctx, dists);\n}\n\n/**\n * @brief Brute force kNN with shared memory.\n *  This function divides query points and data points into blocks. For each\n *  query block, it will make a loop over all data blocks and compute distances.\n *  It will be faster when the dimension of input points is not large.\n *\n * @tparam FloatType The type of input points.\n * @tparam IdType The type of id.\n * @param data_points NDArray of dataset points.\n * @param data_offsets offsets of point index in data points.\n * @param query_points NDArray of query points\n * @param query_offsets offsets of point index in query points.\n * @param k the number of nearest points\n * @param result output array\n */\ntemplate <typename FloatType, typename IdType>\nvoid BruteForceKNNSharedCuda(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const auto& ctx = data_points->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  const int64_t batch_size = data_offsets->shape[0] - 1;\n  const int64_t feature_size = data_points->shape[1];\n  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();\n  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();\n  const FloatType* data_points_data = data_points.Ptr<FloatType>();\n  const FloatType* query_points_data = query_points.Ptr<FloatType>();\n  IdType* 
query_out = result.Ptr<IdType>();\n  IdType* data_out = query_out + k * query_points->shape[0];\n  constexpr size_t smem_align = std::max(sizeof(IdType), sizeof(FloatType));\n\n  // get max shared memory per block in bytes\n  // determine block size according to this value\n  int max_sharedmem_per_block = 0;\n  CUDA_CALL(cudaDeviceGetAttribute(\n      &max_sharedmem_per_block, cudaDevAttrMaxSharedMemoryPerBlock,\n      ctx.device_id));\n  const int64_t single_shared_mem = static_cast<int64_t>(Pow2Align<size_t>(\n      (k + 2 * feature_size) * sizeof(FloatType) + k * sizeof(IdType),\n      smem_align));\n\n  const int64_t block_size =\n      cuda::FindNumThreads(max_sharedmem_per_block / single_shared_mem);\n\n  // Determine the number of blocks. We first get the number of blocks for each\n  // segment. Then we get the block id offset via prefix sum.\n  IdType* num_block_per_segment = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, batch_size * sizeof(IdType)));\n  IdType* num_block_prefixsum = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, batch_size * sizeof(IdType)));\n\n  // block size for GetNumBlockPerSegment computation\n  int64_t temp_block_size = cuda::FindNumThreads(batch_size);\n  int64_t temp_num_blocks = (batch_size - 1) / temp_block_size + 1;\n  CUDA_KERNEL_CALL(\n      GetNumBlockPerSegment, temp_num_blocks, temp_block_size, 0, stream,\n      query_offsets_data, num_block_per_segment, batch_size, block_size);\n  size_t prefix_temp_size = 0;\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      nullptr, prefix_temp_size, num_block_per_segment, num_block_prefixsum,\n      batch_size, stream));\n  void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      prefix_temp, prefix_temp_size, num_block_per_segment, num_block_prefixsum,\n      batch_size, stream));\n  device->FreeWorkspace(ctx, prefix_temp);\n\n  // wait for results\n  CUDA_CALL(cudaStreamSynchronize(stream));\n\n  
int64_t num_blocks = 0, final_elem = 0,\n          copyoffset = (batch_size - 1) * sizeof(IdType);\n  device->CopyDataFromTo(\n      num_block_prefixsum, copyoffset, &num_blocks, 0, sizeof(IdType), ctx,\n      DGLContext{kDGLCPU, 0}, query_offsets->dtype);\n  device->CopyDataFromTo(\n      num_block_per_segment, copyoffset, &final_elem, 0, sizeof(IdType), ctx,\n      DGLContext{kDGLCPU, 0}, query_offsets->dtype);\n  num_blocks += final_elem;\n  device->FreeWorkspace(ctx, num_block_per_segment);\n\n  // get batch id and local id in segment\n  temp_block_size = cuda::FindNumThreads(num_blocks);\n  temp_num_blocks = (num_blocks - 1) / temp_block_size + 1;\n  IdType* block_batch_id = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, num_blocks * sizeof(IdType)));\n  IdType* local_block_id = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, num_blocks * sizeof(IdType)));\n  CUDA_KERNEL_CALL(\n      GetBlockInfo, temp_num_blocks, temp_block_size, 0, stream,\n      num_block_prefixsum, block_batch_id, local_block_id, batch_size,\n      num_blocks);\n\n  FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace(\n      ctx, k * query_points->shape[0] * sizeof(FloatType)));\n  CUDA_KERNEL_CALL(\n      BruteforceKnnShareKernel, num_blocks, block_size,\n      single_shared_mem * block_size, stream, data_points_data,\n      data_offsets_data, query_points_data, query_offsets_data, block_batch_id,\n      local_block_id, k, dists, query_out, data_out, batch_size, feature_size);\n\n  device->FreeWorkspace(ctx, num_block_prefixsum);\n  device->FreeWorkspace(ctx, dists);\n  device->FreeWorkspace(ctx, local_block_id);\n  device->FreeWorkspace(ctx, block_batch_id);\n}\n\n/** @brief Setup rng state for nn-descent */\n__global__ void SetupRngKernel(\n    curandState* states, const uint64_t seed, const size_t n) {\n  size_t id = blockIdx.x * blockDim.x + threadIdx.x;\n  if (id < n) {\n    curand_init(seed, id, 0, states + id);\n  }\n}\n\n/**\n * @brief Randomly 
initialize neighbors (sampling without replacement)\n * for each nodes\n */\ntemplate <typename FloatType, typename IdType>\n__global__ void RandomInitNeighborsKernel(\n    const FloatType* points, const IdType* offsets, IdType* central_nodes,\n    IdType* neighbors, FloatType* dists, bool* flags, const int k,\n    const int64_t feature_size, const int64_t batch_size, const uint64_t seed) {\n  const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;\n  IdType batch_idx = 0;\n  if (point_idx >= offsets[batch_size]) return;\n  curandState state;\n  curand_init(seed, point_idx, 0, &state);\n\n  // find the segment location in the input batch\n  for (IdType b = 0; b < batch_size + 1; ++b) {\n    if (offsets[b] > point_idx) {\n      batch_idx = b - 1;\n      break;\n    }\n  }\n\n  const IdType segment_size = offsets[batch_idx + 1] - offsets[batch_idx];\n  IdType* current_neighbors = neighbors + point_idx * k;\n  IdType* current_central_nodes = central_nodes + point_idx * k;\n  bool* current_flags = flags + point_idx * k;\n  FloatType* current_dists = dists + point_idx * k;\n  IdType segment_start = offsets[batch_idx];\n\n  // reservoir sampling\n  for (IdType i = 0; i < k; ++i) {\n    current_neighbors[i] = i + segment_start;\n    current_central_nodes[i] = point_idx;\n  }\n  for (IdType i = k; i < segment_size; ++i) {\n    const IdType j = static_cast<IdType>(curand(&state) % (i + 1));\n    if (j < k) current_neighbors[j] = i + segment_start;\n  }\n\n  // compute distances and set flags\n  for (IdType i = 0; i < k; ++i) {\n    current_flags[i] = true;\n    current_dists[i] = EuclideanDist<FloatType, IdType>(\n        points + point_idx * feature_size,\n        points + current_neighbors[i] * feature_size, feature_size);\n  }\n\n  // build heap\n  BuildHeap<FloatType, IdType>(neighbors + point_idx * k, current_dists, k);\n}\n\n/**\n * @brief Randomly select candidates from current knn and reverse-knn graph for\n *        nn-descent.\n */\ntemplate <typename 
IdType>\n__global__ void FindCandidatesKernel(\n    const IdType* offsets, IdType* new_candidates, IdType* old_candidates,\n    IdType* neighbors, bool* flags, const uint64_t seed,\n    const int64_t batch_size, const int num_candidates, const int k) {\n  const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;\n  IdType batch_idx = 0;\n  if (point_idx >= offsets[batch_size]) return;\n  curandState state;\n  curand_init(seed, point_idx, 0, &state);\n\n  // find the segment location in the input batch\n  for (IdType b = 0; b < batch_size + 1; ++b) {\n    if (offsets[b] > point_idx) {\n      batch_idx = b - 1;\n      break;\n    }\n  }\n\n  IdType segment_start = offsets[batch_idx],\n         segment_end = offsets[batch_idx + 1];\n  IdType* current_neighbors = neighbors + point_idx * k;\n  bool* current_flags = flags + point_idx * k;\n\n  // reset candidates\n  IdType* new_candidates_ptr =\n      new_candidates + point_idx * (num_candidates + 1);\n  IdType* old_candidates_ptr =\n      old_candidates + point_idx * (num_candidates + 1);\n  new_candidates_ptr[0] = 0;\n  old_candidates_ptr[0] = 0;\n\n  // select candidates from current knn graph\n  // here we use candidate[0] for reservoir sampling temporarily\n  for (IdType i = 0; i < k; ++i) {\n    IdType candidate = current_neighbors[i];\n    IdType* candidate_array =\n        current_flags[i] ? 
new_candidates_ptr : old_candidates_ptr;\n    IdType curr_num = candidate_array[0];\n    IdType* candidate_data = candidate_array + 1;\n\n    // reservoir sampling\n    if (curr_num < num_candidates) {\n      candidate_data[curr_num] = candidate;\n    } else {\n      IdType pos = static_cast<IdType>(curand(&state) % (curr_num + 1));\n      if (pos < num_candidates) candidate_data[pos] = candidate;\n    }\n    ++candidate_array[0];\n  }\n\n  // select candidates from current reverse knn graph\n  // here we use candidate[0] for reservoir sampling temporarily\n  IdType index_start = segment_start * k, index_end = segment_end * k;\n  for (IdType i = index_start; i < index_end; ++i) {\n    if (neighbors[i] == point_idx) {\n      IdType reverse_candidate = (i - index_start) / k + segment_start;\n      IdType* candidate_array =\n          flags[i] ? new_candidates_ptr : old_candidates_ptr;\n      IdType curr_num = candidate_array[0];\n      IdType* candidate_data = candidate_array + 1;\n\n      // reservoir sampling\n      if (curr_num < num_candidates) {\n        candidate_data[curr_num] = reverse_candidate;\n      } else {\n        IdType pos = static_cast<IdType>(curand(&state) % (curr_num + 1));\n        if (pos < num_candidates) candidate_data[pos] = reverse_candidate;\n      }\n      ++candidate_array[0];\n    }\n  }\n\n  // set candidate[0] back to length\n  if (new_candidates_ptr[0] > num_candidates)\n    new_candidates_ptr[0] = num_candidates;\n  if (old_candidates_ptr[0] > num_candidates)\n    old_candidates_ptr[0] = num_candidates;\n\n  // mark new_candidates as old\n  IdType num_new_candidates = new_candidates_ptr[0];\n  for (IdType i = 0; i < k; ++i) {\n    IdType neighbor_idx = current_neighbors[i];\n\n    if (current_flags[i]) {\n      for (IdType j = 1; j < num_new_candidates + 1; ++j) {\n        if (new_candidates_ptr[j] == neighbor_idx) {\n          current_flags[i] = false;\n          break;\n        }\n      }\n    }\n  }\n}\n\n/** @brief Update knn 
graph according to selected candidates for nn-descent */\ntemplate <typename FloatType, typename IdType>\n__global__ void UpdateNeighborsKernel(\n    const FloatType* points, const IdType* offsets, IdType* neighbors,\n    IdType* new_candidates, IdType* old_candidates, FloatType* distances,\n    bool* flags, IdType* num_updates, const int64_t batch_size,\n    const int num_candidates, const int k, const int64_t feature_size) {\n  const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (point_idx >= offsets[batch_size]) return;\n  IdType* current_neighbors = neighbors + point_idx * k;\n  bool* current_flags = flags + point_idx * k;\n  FloatType* current_dists = distances + point_idx * k;\n  IdType* new_candidates_ptr =\n      new_candidates + point_idx * (num_candidates + 1);\n  IdType* old_candidates_ptr =\n      old_candidates + point_idx * (num_candidates + 1);\n  IdType num_new_candidates = new_candidates_ptr[0];\n  IdType num_old_candidates = old_candidates_ptr[0];\n  IdType current_num_updates = 0;\n\n  // process new candidates\n  for (IdType i = 1; i <= num_new_candidates; ++i) {\n    IdType new_c = new_candidates_ptr[i];\n\n    // new/old candidates of the current new candidate\n    IdType* twohop_new_ptr = new_candidates + new_c * (num_candidates + 1);\n    IdType* twohop_old_ptr = old_candidates + new_c * (num_candidates + 1);\n    IdType num_twohop_new = twohop_new_ptr[0];\n    IdType num_twohop_old = twohop_old_ptr[0];\n    FloatType worst_dist = current_dists[0];\n\n    // new - new\n    for (IdType j = 1; j <= num_twohop_new; ++j) {\n      IdType twohop_new_c = twohop_new_ptr[j];\n      FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(\n          points + point_idx * feature_size,\n          points + twohop_new_c * feature_size, feature_size, worst_dist);\n\n      if (FlaggedHeapInsert<FloatType, IdType>(\n              current_neighbors, current_dists, current_flags, twohop_new_c,\n              new_dist, true, k, true)) 
{\n        ++current_num_updates;\n        worst_dist = current_dists[0];\n      }\n    }\n\n    // new - old\n    for (IdType j = 1; j <= num_twohop_old; ++j) {\n      IdType twohop_old_c = twohop_old_ptr[j];\n      FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(\n          points + point_idx * feature_size,\n          points + twohop_old_c * feature_size, feature_size, worst_dist);\n\n      if (FlaggedHeapInsert<FloatType, IdType>(\n              current_neighbors, current_dists, current_flags, twohop_old_c,\n              new_dist, true, k, true)) {\n        ++current_num_updates;\n        worst_dist = current_dists[0];\n      }\n    }\n  }\n\n  // process old candidates\n  for (IdType i = 1; i <= num_old_candidates; ++i) {\n    IdType old_c = old_candidates_ptr[i];\n\n    // new candidates of the current old candidate\n    IdType* twohop_new_ptr = new_candidates + old_c * (num_candidates + 1);\n    IdType num_twohop_new = twohop_new_ptr[0];\n    FloatType worst_dist = current_dists[0];\n\n    // old - new\n    for (IdType j = 1; j <= num_twohop_new; ++j) {\n      IdType twohop_new_c = twohop_new_ptr[j];\n      FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(\n          points + point_idx * feature_size,\n          points + twohop_new_c * feature_size, feature_size, worst_dist);\n\n      if (FlaggedHeapInsert<FloatType, IdType>(\n              current_neighbors, current_dists, current_flags, twohop_new_c,\n              new_dist, true, k, true)) {\n        ++current_num_updates;\n        worst_dist = current_dists[0];\n      }\n    }\n  }\n\n  num_updates[point_idx] = current_num_updates;\n}\n\n}  // namespace impl\n\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid KNN(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result, const std::string& algorithm) {\n  if (algorithm == std::string(\"bruteforce\")) 
{\n    impl::BruteForceKNNCuda<FloatType, IdType>(\n        data_points, data_offsets, query_points, query_offsets, k, result);\n  } else if (algorithm == std::string(\"bruteforce-sharemem\")) {\n    impl::BruteForceKNNSharedCuda<FloatType, IdType>(\n        data_points, data_offsets, query_points, query_offsets, k, result);\n  } else {\n    LOG(FATAL) << \"Algorithm \" << algorithm << \" is not supported on CUDA.\";\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid NNDescent(\n    const NDArray& points, const IdArray& offsets, IdArray result, const int k,\n    const int num_iters, const int num_candidates, const double delta) {\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n  const auto& ctx = points->ctx;\n  auto device = runtime::DeviceAPI::Get(ctx);\n  const int64_t num_nodes = points->shape[0];\n  const int64_t feature_size = points->shape[1];\n  const int64_t batch_size = offsets->shape[0] - 1;\n  const IdType* offsets_data = offsets.Ptr<IdType>();\n  const FloatType* points_data = points.Ptr<FloatType>();\n\n  IdType* central_nodes = result.Ptr<IdType>();\n  IdType* neighbors = central_nodes + k * num_nodes;\n  uint64_t seed;\n  int warp_size = 0;\n  CUDA_CALL(\n      cudaDeviceGetAttribute(&warp_size, cudaDevAttrWarpSize, ctx.device_id));\n  // We don't need large block sizes, since there's not much inter-thread\n  // communication\n  int64_t block_size = warp_size;\n  int64_t num_blocks = (num_nodes - 1) / block_size + 1;\n\n  // allocate space for candidates, distances and flags\n  // we use the first element in candidate array to represent length\n  IdType* new_candidates = static_cast<IdType*>(device->AllocWorkspace(\n      ctx, num_nodes * (num_candidates + 1) * sizeof(IdType)));\n  IdType* old_candidates = static_cast<IdType*>(device->AllocWorkspace(\n      ctx, num_nodes * (num_candidates + 1) * sizeof(IdType)));\n  IdType* num_updates = static_cast<IdType*>(\n      device->AllocWorkspace(ctx, num_nodes * 
sizeof(IdType)));\n  FloatType* distances = static_cast<FloatType*>(\n      device->AllocWorkspace(ctx, num_nodes * k * sizeof(IdType)));\n  bool* flags = static_cast<bool*>(\n      device->AllocWorkspace(ctx, num_nodes * k * sizeof(IdType)));\n\n  size_t sum_temp_size = 0;\n  IdType total_num_updates = 0;\n  IdType* total_num_updates_d =\n      static_cast<IdType*>(device->AllocWorkspace(ctx, sizeof(IdType)));\n\n  CUDA_CALL(cub::DeviceReduce::Sum(\n      nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes,\n      stream));\n  IdType* sum_temp_storage =\n      static_cast<IdType*>(device->AllocWorkspace(ctx, sum_temp_size));\n\n  // random initialize neighbors\n  seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>(\n      std::numeric_limits<uint64_t>::max());\n  CUDA_KERNEL_CALL(\n      impl::RandomInitNeighborsKernel, num_blocks, block_size, 0, stream,\n      points_data, offsets_data, central_nodes, neighbors, distances, flags, k,\n      feature_size, batch_size, seed);\n\n  for (int i = 0; i < num_iters; ++i) {\n    // select candidates\n    seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>(\n        std::numeric_limits<uint64_t>::max());\n    CUDA_KERNEL_CALL(\n        impl::FindCandidatesKernel, num_blocks, block_size, 0, stream,\n        offsets_data, new_candidates, old_candidates, neighbors, flags, seed,\n        batch_size, num_candidates, k);\n\n    // update\n    CUDA_KERNEL_CALL(\n        impl::UpdateNeighborsKernel, num_blocks, block_size, 0, stream,\n        points_data, offsets_data, neighbors, new_candidates, old_candidates,\n        distances, flags, num_updates, batch_size, num_candidates, k,\n        feature_size);\n\n    total_num_updates = 0;\n    CUDA_CALL(cub::DeviceReduce::Sum(\n        sum_temp_storage, sum_temp_size, num_updates, total_num_updates_d,\n        num_nodes, stream));\n    device->CopyDataFromTo(\n        total_num_updates_d, 0, &total_num_updates, 0, sizeof(IdType), ctx,\n        DGLContext{kDGLCPU, 0}, 
offsets->dtype);\n\n    if (total_num_updates <= static_cast<IdType>(delta * k * num_nodes)) {\n      break;\n    }\n  }\n\n  device->FreeWorkspace(ctx, new_candidates);\n  device->FreeWorkspace(ctx, old_candidates);\n  device->FreeWorkspace(ctx, num_updates);\n  device->FreeWorkspace(ctx, distances);\n  device->FreeWorkspace(ctx, flags);\n  device->FreeWorkspace(ctx, total_num_updates_d);\n  device->FreeWorkspace(ctx, sum_temp_storage);\n}\n\ntemplate void KNN<kDGLCUDA, float, int32_t>(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result, const std::string& algorithm);\ntemplate void KNN<kDGLCUDA, float, int64_t>(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result, const std::string& algorithm);\ntemplate void KNN<kDGLCUDA, double, int32_t>(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result, const std::string& algorithm);\ntemplate void KNN<kDGLCUDA, double, int64_t>(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result, const std::string& algorithm);\n\ntemplate void NNDescent<kDGLCUDA, float, int32_t>(\n    const NDArray& points, const IdArray& offsets, IdArray result, const int k,\n    const int num_iters, const int num_candidates, const double delta);\ntemplate void NNDescent<kDGLCUDA, float, int64_t>(\n    const NDArray& points, const IdArray& offsets, IdArray result, const int k,\n    const int num_iters, const int num_candidates, const double delta);\ntemplate void NNDescent<kDGLCUDA, double, int32_t>(\n    const NDArray& points, const IdArray& offsets, IdArray result, const int k,\n    const int num_iters, const int num_candidates, 
const double delta);\ntemplate void NNDescent<kDGLCUDA, double, int64_t>(\n    const NDArray& points, const IdArray& offsets, IdArray result, const int k,\n    const int num_iters, const int num_candidates, const double delta);\n\n}  // namespace transform\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/knn.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/transform/knn.cc\n * @brief k-nearest-neighbor (KNN) interface\n */\n\n#include \"knn.h\"\n\n#include <dgl/runtime/packed_func.h>\n#include <dgl/runtime/registry.h>\n\n#include \"../../array/check.h\"\n\nusing namespace dgl::runtime;\nnamespace dgl {\nnamespace transform {\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLKNN\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const NDArray data_points = args[0];\n      const IdArray data_offsets = args[1];\n      const NDArray query_points = args[2];\n      const IdArray query_offsets = args[3];\n      const int k = args[4];\n      IdArray result = args[5];\n      const std::string algorithm = args[6];\n\n      aten::CheckContiguous(\n          {data_points, data_offsets, query_points, query_offsets, result},\n          {\"data_points\", \"data_offsets\", \"query_points\", \"query_offsets\",\n           \"result\"});\n      aten::CheckCtx(\n          data_points->ctx, {data_offsets, query_points, query_offsets, result},\n          {\"data_offsets\", \"query_points\", \"query_offsets\", \"result\"});\n\n      ATEN_XPU_SWITCH_CUDA(data_points->ctx.device_type, XPU, \"KNN\", {\n        ATEN_FLOAT_TYPE_SWITCH(data_points->dtype, FloatType, \"data_points\", {\n          ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {\n            KNN<XPU, FloatType, IdType>(\n                data_points, data_offsets, query_points, query_offsets, k,\n                result, algorithm);\n          });\n        });\n      });\n    });\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLNNDescent\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const NDArray points = args[0];\n      const IdArray offsets = args[1];\n      const IdArray result = args[2];\n      const int k = args[3];\n      const int num_iters = args[4];\n      const int num_candidates = args[5];\n      const double delta = args[6];\n\n      aten::CheckContiguous(\n          {points, offsets, 
result}, {\"points\", \"offsets\", \"result\"});\n      aten::CheckCtx(\n          points->ctx, {points, offsets, result},\n          {\"points\", \"offsets\", \"result\"});\n\n      ATEN_XPU_SWITCH_CUDA(points->ctx.device_type, XPU, \"NNDescent\", {\n        ATEN_FLOAT_TYPE_SWITCH(points->dtype, FloatType, \"points\", {\n          ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {\n            NNDescent<XPU, FloatType, IdType>(\n                points, offsets, result, k, num_iters, num_candidates, delta);\n          });\n        });\n      });\n    });\n\n}  // namespace transform\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/knn.h",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file graph/transform/knn.h\n * @brief k-nearest-neighbor (KNN) implementation\n */\n\n#ifndef DGL_GRAPH_TRANSFORM_KNN_H_\n#define DGL_GRAPH_TRANSFORM_KNN_H_\n\n#include <dgl/array.h>\n\n#include <string>\n\nnamespace dgl {\nnamespace transform {\n\n/**\n * @brief For each point in each segment in \\a query_points, find \\a k nearest\n *        points in the same segment in \\a data_points. \\a data_offsets and \\a\n *        query_offsets determine the start index of each segment in \\a\n *        data_points and \\a query_points.\n *\n * @param data_points dataset points.\n * @param data_offsets offsets of point index in \\a data_points.\n * @param query_points query points.\n * @param query_offsets offsets of point index in \\a query_points.\n * @param k the number of nearest points.\n * @param result output array. A 2D tensor indicating the index  relation\n *        between \\a query_points and \\a data_points.\n * @param algorithm algorithm used to compute the k-nearest neighbors.\n */\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid KNN(\n    const NDArray& data_points, const IdArray& data_offsets,\n    const NDArray& query_points, const IdArray& query_offsets, const int k,\n    IdArray result, const std::string& algorithm);\n\n/**\n * @brief For each input point, find \\a k approximate nearest points in the same\n *        segment using NN-descent algorithm.\n *\n * @param points input points.\n * @param offsets offsets of point index.\n * @param result output array. 
A 2D tensor indicating the index relation between\n *        points.\n * @param k the number of nearest points.\n * @param num_iters The maximum number of NN-descent iterations to perform.\n * @param num_candidates The maximum number of candidates to be considered\n *        during one iteration.\n * @param delta A value controls the early abort.\n */\ntemplate <DGLDeviceType XPU, typename FloatType, typename IdType>\nvoid NNDescent(\n    const NDArray& points, const IdArray& offsets, IdArray result, const int k,\n    const int num_iters, const int num_candidates, const double delta);\n\n}  // namespace transform\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_TRANSFORM_KNN_H_\n"
  },
  {
    "path": "src/graph/transform/line_graph.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file graph/transform/line_graph.cc\n * @brief Line graph implementation\n */\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/transform.h>\n\n#include <utility>\n#include <vector>\n\n#include \"../../c_api_common.h\"\n#include \"../heterograph.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace transform {\n\n/**\n * @brief Create Line Graph.\n * @param hg Graph.\n * @param backtracking whether the pair of (v, u) (u, v) edges are treated as\n *        linked.\n * @return The Line Graph.\n */\nHeteroGraphPtr CreateLineGraph(HeteroGraphPtr hg, bool backtracking) {\n  const auto hgp = std::dynamic_pointer_cast<HeteroGraph>(hg);\n  return hgp->LineGraph(backtracking);\n}\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLHeteroLineGraph\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      HeteroGraphRef hg = args[0];\n      bool backtracking = args[1];\n\n      auto hgptr = CreateLineGraph(hg.sptr(), backtracking);\n      *rv = HeteroGraphRef(hgptr);\n    });\n\n};  // namespace transform\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/metis_partition_hetero.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file graph/metis_partition.cc\n * @brief Call Metis partitioning\n */\n\n#include <dgl/base_heterograph.h>\n#include <dgl/packed_func_ext.h>\n#include <metis.h>\n\n#include \"../heterograph.h\"\n#include \"../unit_graph.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\n\nnamespace transform {\n\n#if !defined(_WIN32)\n\nIdArray MetisPartition(\n    UnitGraphPtr g, int k, NDArray vwgt_arr, const std::string &mode,\n    bool obj_cut) {\n  // Mode can only be \"k-way\" or \"recursive\"\n  CHECK(mode == \"k-way\" || mode == \"recursive\")\n      << \"mode can only be \\\"k-way\\\" or \\\"recursive\\\"\";\n  // The index type of Metis needs to be compatible with DGL index type.\n  CHECK_EQ(sizeof(idx_t), sizeof(int64_t))\n      << \"Metis only supports int64 graph for now\";\n  // This is a symmetric graph, so in-csr and out-csr are the same.\n  const auto mat = g->GetCSCMatrix(0);\n  //   const auto mat = g->GetInCSR()->ToCSRMatrix();\n\n  idx_t nvtxs = g->NumVertices(0);\n  idx_t ncon = 1;  // # balacing constraints.\n  idx_t *xadj = static_cast<idx_t *>(mat.indptr->data);\n  idx_t *adjncy = static_cast<idx_t *>(mat.indices->data);\n  idx_t nparts = k;\n  IdArray part_arr = aten::NewIdArray(nvtxs);\n  idx_t objval = 0;\n  idx_t *part = static_cast<idx_t *>(part_arr->data);\n\n  int64_t vwgt_len = vwgt_arr->shape[0];\n  CHECK_EQ(sizeof(idx_t), vwgt_arr->dtype.bits / 8)\n      << \"The vertex weight array doesn't have right type\";\n  CHECK(vwgt_len % g->NumVertices(0) == 0)\n      << \"The vertex weight array doesn't have right number of elements\";\n  idx_t *vwgt = NULL;\n  if (vwgt_len > 0) {\n    ncon = vwgt_len / g->NumVertices(0);\n    vwgt = static_cast<idx_t *>(vwgt_arr->data);\n  }\n\n  auto partition_func =\n      (mode == \"k-way\") ? 
METIS_PartGraphKway : METIS_PartGraphRecursive;\n\n  idx_t options[METIS_NOPTIONS];\n  METIS_SetDefaultOptions(options);\n  options[METIS_OPTION_ONDISK] = 1;\n  options[METIS_OPTION_NITER] = 1;\n  options[METIS_OPTION_NIPARTS] = 1;\n  options[METIS_OPTION_DROPEDGES] = 1;\n\n  if (obj_cut) {\n    options[METIS_OPTION_OBJTYPE] = METIS_OBJTYPE_CUT;\n  } else {\n    options[METIS_OPTION_OBJTYPE] = METIS_OBJTYPE_VOL;\n  }\n\n  int ret = partition_func(\n      &nvtxs,  // The number of vertices\n      &ncon,   // The number of balancing constraints.\n      xadj,    // indptr\n      adjncy,  // indices\n      vwgt,    // the weights of the vertices\n      NULL,    // The size of the vertices for computing\n      // the total communication volume\n      NULL,     // The weights of the edges\n      &nparts,  // The number of partitions.\n      NULL,     // the desired weight for each partition and constraint\n      NULL,     // the allowed load imbalance tolerance\n      options,  // the array of options\n      &objval,  // the edge-cut or the total communication volume of\n      // the partitioning solution\n      part);\n\n  if (obj_cut) {\n    LOG(INFO) << \"Partition a graph with \" << g->NumVertices(0) << \" nodes and \"\n              << g->NumEdges(0) << \" edges into \" << k << \" parts and \"\n              << \"get \" << objval << \" edge cuts\";\n  } else {\n    LOG(INFO) << \"Partition a graph with \" << g->NumVertices(0) << \" nodes and \"\n              << g->NumEdges(0) << \" edges into \" << k << \" parts and \"\n              << \"the communication volume is \" << objval;\n  }\n\n  switch (ret) {\n    case METIS_OK:\n      return part_arr;\n    case METIS_ERROR_INPUT:\n      LOG(FATAL) << \"Error in Metis partitioning: input error\";\n    case METIS_ERROR_MEMORY:\n      LOG(FATAL) << \"Error in Metis partitioning: cannot allocate memory\";\n    default:\n      LOG(FATAL) << \"Error in Metis partitioning: other errors\";\n  }\n  // return an array of 0 
elements to indicate the error.\n  return aten::NullArray();\n}\n\n#endif  // !defined(_WIN32)\n\nDGL_REGISTER_GLOBAL(\"partition._CAPI_DGLMetisPartition_Hetero\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef g = args[0];\n      auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());\n      CHECK(hgptr) << \"Invalid HeteroGraph object\";\n      CHECK_EQ(hgptr->relation_graphs().size(), 1)\n          << \"Metis partition only supports HomoGraph\";\n      auto ugptr = hgptr->relation_graphs()[0];\n      int k = args[1];\n      NDArray vwgt = args[2];\n      std::string mode = args[3];\n      bool obj_cut = args[4];\n#if !defined(_WIN32)\n      *rv = MetisPartition(ugptr, k, vwgt, mode, obj_cut);\n#else\n      LOG(FATAL) << \"Metis partition does not support Windows.\";\n#endif  // !defined(_WIN32)\n    });\n}  // namespace transform\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/partition_hetero.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file graph/metis_partition.cc\n * @brief Call Metis partitioning\n */\n\n#include <dgl/base_heterograph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/parallel_for.h>\n\n#include \"../heterograph.h\"\n#include \"../unit_graph.h\"\n\n#if !defined(_WIN32)\n#include <GKlib.h>\n#endif  // !defined(_WIN32)\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\n\n#if !defined(_WIN32)\ngk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row);\naten::CSRMatrix Convert2DGLCsr(gk_csr_t *gk_csr, bool is_row);\n#endif  // !defined(_WIN32)\n\nnamespace transform {\n\nclass HaloHeteroSubgraph : public HeteroSubgraph {\n public:\n  std::vector<IdArray> inner_nodes;\n};\n\nHeteroGraphPtr ReorderUnitGraph(UnitGraphPtr ug, IdArray new_order) {\n  auto format = ug->GetCreatedFormats();\n  // We only need to reorder one of the graph structure.\n  if (format & CSC_CODE) {\n    auto cscmat = ug->GetCSCMatrix(0);\n    auto new_cscmat = aten::CSRReorder(cscmat, new_order, new_order);\n    return UnitGraph::CreateFromCSC(\n        ug->NumVertexTypes(), new_cscmat, ug->GetAllowedFormats());\n  } else if (format & CSR_CODE) {\n    auto csrmat = ug->GetCSRMatrix(0);\n    auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);\n    return UnitGraph::CreateFromCSR(\n        ug->NumVertexTypes(), new_csrmat, ug->GetAllowedFormats());\n  } else {\n    auto coomat = ug->GetCOOMatrix(0);\n    auto new_coomat = aten::COOReorder(coomat, new_order, new_order);\n    return UnitGraph::CreateFromCOO(\n        ug->NumVertexTypes(), new_coomat, ug->GetAllowedFormats());\n  }\n}\n\nHaloHeteroSubgraph GetSubgraphWithHalo(\n    std::shared_ptr<HeteroGraph> hg, IdArray nodes, int num_hops) {\n  CHECK_EQ(hg->NumBits(), 64) << \"halo subgraph only supports 64bits graph\";\n  CHECK_EQ(hg->relation_graphs().size(), 1)\n      << \"halo subgraph only supports homogeneous graph\";\n  CHECK_EQ(nodes->dtype.bits, 64)\n      << 
\"halo subgraph only supports 64bits nodes tensor\";\n  const dgl_id_t *nid = static_cast<dgl_id_t *>(nodes->data);\n  const auto id_len = nodes->shape[0];\n  // A map contains all nodes in the subgraph.\n  // The key is the old node Ids, the value indicates whether a node is a inner\n  // node.\n  std::unordered_map<dgl_id_t, bool> all_nodes;\n  // The old Ids of all nodes. We want to preserve the order of the nodes in the\n  // vector. The first few nodes are the inner nodes in the subgraph.\n  std::vector<dgl_id_t> old_node_ids(nid, nid + id_len);\n  std::vector<std::vector<dgl_id_t>> outer_nodes(num_hops);\n  for (int64_t i = 0; i < id_len; i++) all_nodes[nid[i]] = true;\n  auto orig_nodes = all_nodes;\n\n  std::vector<dgl_id_t> edge_src, edge_dst, edge_eid;\n\n  // When we deal with in-edges, we need to do two things:\n  // * find the edges inside the partition and the edges between partitions.\n  // * find the nodes outside the partition that connect the partition.\n  EdgeArray in_edges = hg->InEdges(0, nodes);\n  auto src = in_edges.src;\n  auto dst = in_edges.dst;\n  auto eid = in_edges.id;\n  auto num_edges = eid->shape[0];\n  const dgl_id_t *src_data = static_cast<dgl_id_t *>(src->data);\n  const dgl_id_t *dst_data = static_cast<dgl_id_t *>(dst->data);\n  const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);\n  for (int64_t i = 0; i < num_edges; i++) {\n    // We check if the source node is in the original node.\n    auto it1 = orig_nodes.find(src_data[i]);\n    if (it1 != orig_nodes.end() || num_hops > 0) {\n      edge_src.push_back(src_data[i]);\n      edge_dst.push_back(dst_data[i]);\n      edge_eid.push_back(eid_data[i]);\n    }\n    // We need to expand only if the node hasn't been seen before.\n    auto it = all_nodes.find(src_data[i]);\n    if (it == all_nodes.end() && num_hops > 0) {\n      all_nodes[src_data[i]] = false;\n      old_node_ids.push_back(src_data[i]);\n      outer_nodes[0].push_back(src_data[i]);\n    }\n  }\n\n  // Now we 
need to traverse the graph with the in-edges to access nodes\n  // and edges more hops away.\n  for (int k = 1; k < num_hops; k++) {\n    const std::vector<dgl_id_t> &nodes = outer_nodes[k - 1];\n    EdgeArray in_edges = hg->InEdges(0, aten::VecToIdArray(nodes));\n    auto src = in_edges.src;\n    auto dst = in_edges.dst;\n    auto eid = in_edges.id;\n    auto num_edges = eid->shape[0];\n    const dgl_id_t *src_data = static_cast<dgl_id_t *>(src->data);\n    const dgl_id_t *dst_data = static_cast<dgl_id_t *>(dst->data);\n    const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);\n    for (int64_t i = 0; i < num_edges; i++) {\n      auto it1 = orig_nodes.find(src_data[i]);\n      // If the source node is in the partition, we have got this edge when we\n      // iterate over the out-edges above.\n      if (it1 == orig_nodes.end()) {\n        edge_src.push_back(src_data[i]);\n        edge_dst.push_back(dst_data[i]);\n        edge_eid.push_back(eid_data[i]);\n      }\n      // If we haven't seen this node.\n      auto it = all_nodes.find(src_data[i]);\n      if (it == all_nodes.end()) {\n        all_nodes[src_data[i]] = false;\n        old_node_ids.push_back(src_data[i]);\n        outer_nodes[k].push_back(src_data[i]);\n      }\n    }\n  }\n\n  if (num_hops > 0) {\n    EdgeArray out_edges = hg->OutEdges(0, nodes);\n    auto src = out_edges.src;\n    auto dst = out_edges.dst;\n    auto eid = out_edges.id;\n    auto num_edges = eid->shape[0];\n    const dgl_id_t *src_data = static_cast<dgl_id_t *>(src->data);\n    const dgl_id_t *dst_data = static_cast<dgl_id_t *>(dst->data);\n    const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);\n    for (int64_t i = 0; i < num_edges; i++) {\n      // If the outer edge isn't in the partition.\n      auto it1 = orig_nodes.find(dst_data[i]);\n      if (it1 == orig_nodes.end()) {\n        edge_src.push_back(src_data[i]);\n        edge_dst.push_back(dst_data[i]);\n        edge_eid.push_back(eid_data[i]);\n      }\n      
// We don't expand along the out-edges.\n      auto it = all_nodes.find(dst_data[i]);\n      if (it == all_nodes.end()) {\n        all_nodes[dst_data[i]] = false;\n        old_node_ids.push_back(dst_data[i]);\n      }\n    }\n  }\n\n  // We assign new Ids to the nodes in the subgraph. We ensure that the HALO\n  // nodes are behind the input nodes.\n  std::unordered_map<dgl_id_t, dgl_id_t> old2new;\n  for (size_t i = 0; i < old_node_ids.size(); i++) {\n    old2new[old_node_ids[i]] = i;\n  }\n\n  num_edges = edge_src.size();\n  IdArray new_src = IdArray::Empty(\n      {num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  IdArray new_dst = IdArray::Empty(\n      {num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  dgl_id_t *new_src_data = static_cast<dgl_id_t *>(new_src->data);\n  dgl_id_t *new_dst_data = static_cast<dgl_id_t *>(new_dst->data);\n  for (size_t i = 0; i < edge_src.size(); i++) {\n    new_src_data[i] = old2new[edge_src[i]];\n    new_dst_data[i] = old2new[edge_dst[i]];\n  }\n\n  std::vector<int> inner_nodes(old_node_ids.size());\n  for (size_t i = 0; i < old_node_ids.size(); i++) {\n    dgl_id_t old_nid = old_node_ids[i];\n    inner_nodes[i] = all_nodes[old_nid];\n  }\n  aten::COOMatrix coo(\n      old_node_ids.size(), old_node_ids.size(), new_src, new_dst);\n  HeteroGraphPtr ugptr = UnitGraph::CreateFromCOO(1, coo);\n  HeteroGraphPtr subg = CreateHeteroGraph(hg->meta_graph(), {ugptr});\n  HaloHeteroSubgraph halo_subg;\n  halo_subg.graph = subg;\n  halo_subg.induced_vertices = {aten::VecToIdArray(old_node_ids)};\n  halo_subg.induced_edges = {aten::VecToIdArray(edge_eid)};\n  // TODO(zhengda) we need to switch to 8 bytes afterwards.\n  halo_subg.inner_nodes = {aten::VecToIdArray<int>(inner_nodes, 32)};\n  return halo_subg;\n}\n\nDGL_REGISTER_GLOBAL(\"partition._CAPI_DGLReorderGraph_Hetero\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef g = args[0];\n      auto hgptr = 
std::dynamic_pointer_cast<HeteroGraph>(g.sptr());\n      CHECK(hgptr) << \"Invalid HeteroGraph object\";\n      CHECK_EQ(hgptr->relation_graphs().size(), 1)\n          << \"Reorder only supports HomoGraph\";\n      auto ugptr = hgptr->relation_graphs()[0];\n      const IdArray new_order = args[1];\n      auto reorder_ugptr = ReorderUnitGraph(ugptr, new_order);\n      std::vector<HeteroGraphPtr> rel_graphs = {reorder_ugptr};\n      *rv = HeteroGraphRef(std::make_shared<HeteroGraph>(\n          hgptr->meta_graph(), rel_graphs, hgptr->NumVerticesPerType()));\n    });\n\nDGL_REGISTER_GLOBAL(\"partition._CAPI_DGLPartitionWithHalo_Hetero\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef g = args[0];\n      auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());\n      CHECK(hgptr) << \"Invalid HeteroGraph object\";\n      CHECK_EQ(hgptr->relation_graphs().size(), 1)\n          << \"Metis partition only supports HomoGraph\";\n      auto ugptr = hgptr->relation_graphs()[0];\n\n      IdArray node_parts = args[1];\n      int num_hops = args[2];\n\n      CHECK_EQ(node_parts->dtype.bits, 64)\n          << \"Only supports 64bits tensor for now\";\n\n      const int64_t *part_data = static_cast<int64_t *>(node_parts->data);\n      int64_t num_nodes = node_parts->shape[0];\n      std::unordered_map<int, std::vector<int64_t>> part_map;\n      for (int64_t i = 0; i < num_nodes; i++) {\n        dgl_id_t part_id = part_data[i];\n        auto it = part_map.find(part_id);\n        if (it == part_map.end()) {\n          std::vector<int64_t> vec;\n          vec.push_back(i);\n          part_map[part_id] = vec;\n        } else {\n          it->second.push_back(i);\n        }\n      }\n      std::vector<int> part_ids;\n      std::vector<std::vector<int64_t>> part_nodes;\n      int max_part_id = 0;\n      for (auto it = part_map.begin(); it != part_map.end(); it++) {\n        max_part_id = std::max(it->first, max_part_id);\n        
part_ids.push_back(it->first);\n        part_nodes.push_back(it->second);\n      }\n      // When we construct subgraphs, we need to access both in-edges and\n      // out-edges. We need to make sure the in-CSR and out-CSR exist.\n      // Otherwise, we'll try to construct in-CSR and out-CSR in openmp for\n      // loop, which will lead to some unexpected results.\n      ugptr->GetInCSR();\n      ugptr->GetOutCSR();\n      std::vector<std::shared_ptr<HaloHeteroSubgraph>> subgs(max_part_id + 1);\n      int num_partitions = part_nodes.size();\n      runtime::parallel_for(0, num_partitions, [&](int b, int e) {\n        for (auto i = b; i < e; i++) {\n          auto nodes = aten::VecToIdArray(part_nodes[i]);\n          HaloHeteroSubgraph subg = GetSubgraphWithHalo(hgptr, nodes, num_hops);\n          std::shared_ptr<HaloHeteroSubgraph> subg_ptr(\n              new HaloHeteroSubgraph(subg));\n          int part_id = part_ids[i];\n          subgs[part_id] = subg_ptr;\n        }\n      });\n      List<HeteroSubgraphRef> ret_list;\n      for (size_t i = 0; i < subgs.size(); i++) {\n        ret_list.push_back(HeteroSubgraphRef(subgs[i]));\n      }\n      *rv = ret_list;\n    });\n\ntemplate <class IdType>\nstruct EdgeProperty {\n  IdType eid;\n  int64_t idx;\n  int part_id;\n};\n\n// Reassign edge IDs so that all edges in a partition have contiguous edge IDs.\n// The original edge IDs are returned.\nDGL_REGISTER_GLOBAL(\"partition._CAPI_DGLReassignEdges_Hetero\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef g = args[0];\n      auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());\n      CHECK(hgptr) << \"Invalid HeteroGraph object\";\n      CHECK_EQ(hgptr->relation_graphs().size(), 1)\n          << \"Reorder only supports HomoGraph\";\n      auto ugptr = hgptr->relation_graphs()[0];\n      IdArray etype = args[1];\n      IdArray part_id = args[2];\n      bool is_incsr = args[3];\n      auto csrmat = is_incsr ? 
ugptr->GetCSCMatrix(0) : ugptr->GetCSRMatrix(0);\n      int64_t num_edges = csrmat.data->shape[0];\n      int64_t num_rows = csrmat.indptr->shape[0] - 1;\n      IdArray new_data =\n          IdArray::Empty({num_edges}, csrmat.data->dtype, csrmat.data->ctx);\n      // Return the original edge Ids.\n      *rv = new_data;\n\n      // Generate new edge Ids.\n      ATEN_ID_TYPE_SWITCH(new_data->dtype, IdType, {\n        CHECK(etype->dtype.bits == sizeof(IdType) * 8);\n        CHECK(part_id->dtype.bits == sizeof(IdType) * 8);\n        const IdType *part_id_data = static_cast<IdType *>(part_id->data);\n        const IdType *etype_data = static_cast<IdType *>(etype->data);\n        const IdType *indptr_data = static_cast<IdType *>(csrmat.indptr->data);\n        IdType *typed_data = static_cast<IdType *>(csrmat.data->data);\n        IdType *typed_new_data = static_cast<IdType *>(new_data->data);\n        std::vector<EdgeProperty<IdType>> indexed_eids(num_edges);\n        for (int64_t i = 0; i < num_rows; i++) {\n          for (int64_t j = indptr_data[i]; j < indptr_data[i + 1]; j++) {\n            indexed_eids[j].eid = typed_data[j];\n            indexed_eids[j].idx = j;\n            indexed_eids[j].part_id = part_id_data[i];\n          }\n        }\n        auto comp = [etype_data](\n                        const EdgeProperty<IdType> &a,\n                        const EdgeProperty<IdType> &b) {\n          if (a.part_id == b.part_id) {\n            return etype_data[a.eid] < etype_data[b.eid];\n          } else {\n            return a.part_id < b.part_id;\n          }\n        };\n        // We only need to sort the edges if the input graph has multiple\n        // relations. 
If it's a homogeneous graph, we'll just assign edge Ids\n        // based on its previous order.\n        if (etype->shape[0] > 0) {\n          std::sort(indexed_eids.begin(), indexed_eids.end(), comp);\n        }\n        for (int64_t new_eid = 0; new_eid < num_edges; new_eid++) {\n          int64_t orig_idx = indexed_eids[new_eid].idx;\n          typed_new_data[new_eid] = typed_data[orig_idx];\n          typed_data[orig_idx] = new_eid;\n        }\n      });\n      ugptr->InvalidateCSR();\n      ugptr->InvalidateCOO();\n    });\n\nDGL_REGISTER_GLOBAL(\"partition._CAPI_GetHaloSubgraphInnerNodes_Hetero\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroSubgraphRef g = args[0];\n      auto gptr = std::dynamic_pointer_cast<HaloHeteroSubgraph>(g.sptr());\n      CHECK(gptr) << \"The input graph has to be HaloHeteroSubgraph\";\n      *rv = gptr->inner_nodes[0];\n    });\n\nDGL_REGISTER_GLOBAL(\"partition._CAPI_DGLMakeSymmetric_Hetero\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      HeteroGraphRef g = args[0];\n      auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());\n      CHECK(hgptr) << \"Invalid HeteroGraph object\";\n      CHECK_EQ(hgptr->relation_graphs().size(), 1)\n          << \"Metis partition only supports homogeneous graph\";\n      auto ugptr = hgptr->relation_graphs()[0];\n\n#if !defined(_WIN32)\n      // TODO(zhengda) should we get whatever CSR exists in the graph.\n      gk_csr_t *gk_csr = Convert2GKCsr(ugptr->GetCSCMatrix(0), true);\n      gk_csr_t *sym_gk_csr = gk_csr_MakeSymmetric(gk_csr, GK_CSR_SYM_SUM);\n      auto mat = Convert2DGLCsr(sym_gk_csr, true);\n      gk_csr_Free(&gk_csr);\n      gk_csr_Free(&sym_gk_csr);\n\n      auto new_ugptr = UnitGraph::CreateFromCSC(\n          ugptr->NumVertexTypes(), mat, ugptr->GetAllowedFormats());\n      std::vector<HeteroGraphPtr> rel_graphs = {new_ugptr};\n      *rv = HeteroGraphRef(std::make_shared<HeteroGraph>(\n          hgptr->meta_graph(), rel_graphs, 
hgptr->NumVerticesPerType()));\n#else\n      LOG(FATAL) << \"The fast version of making symmetric graph is not \"\n                    \"supported in Windows.\";\n#endif  // !defined(_WIN32)\n    });\n\n}  // namespace transform\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/remove_edges.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/transform/remove_edges.cc\n * @brief Remove edges.\n */\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/registry.h>\n#include <dgl/transform.h>\n\n#include <tuple>\n#include <utility>\n#include <vector>\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace transform {\n\nstd::pair<HeteroGraphPtr, std::vector<IdArray>> RemoveEdges(\n    const HeteroGraphPtr graph, const std::vector<IdArray> &eids) {\n  std::vector<IdArray> induced_eids;\n  std::vector<HeteroGraphPtr> rel_graphs;\n  const int64_t num_etypes = graph->NumEdgeTypes();\n\n  for (int64_t etype = 0; etype < num_etypes; ++etype) {\n    const SparseFormat fmt = graph->SelectFormat(etype, COO_CODE);\n    const auto src_dst_types = graph->GetEndpointTypes(etype);\n    const dgl_type_t srctype = src_dst_types.first;\n    const dgl_type_t dsttype = src_dst_types.second;\n    const int num_ntypes_rel = (srctype == dsttype) ? 
1 : 2;\n    HeteroGraphPtr new_rel_graph;\n    IdArray induced_eids_rel;\n\n    if (fmt == SparseFormat::kCOO) {\n      const COOMatrix &coo = graph->GetCOOMatrix(etype);\n      const COOMatrix &result = COORemove(coo, eids[etype]);\n      new_rel_graph = CreateFromCOO(\n          num_ntypes_rel, result.num_rows, result.num_cols, result.row,\n          result.col);\n      induced_eids_rel = result.data;\n    } else if (fmt == SparseFormat::kCSR) {\n      const CSRMatrix &csr = graph->GetCSRMatrix(etype);\n      const CSRMatrix &result = CSRRemove(csr, eids[etype]);\n      new_rel_graph = CreateFromCSR(\n          num_ntypes_rel, result.num_rows, result.num_cols, result.indptr,\n          result.indices,\n          // TODO(BarclayII): make CSR support null eid array\n          Range(\n              0, result.indices->shape[0], result.indices->dtype.bits,\n              result.indices->ctx));\n      induced_eids_rel = result.data;\n    } else if (fmt == SparseFormat::kCSC) {\n      const CSRMatrix &csc = graph->GetCSCMatrix(etype);\n      const CSRMatrix &result = CSRRemove(csc, eids[etype]);\n      new_rel_graph = CreateFromCSC(\n          num_ntypes_rel, result.num_rows, result.num_cols, result.indptr,\n          result.indices,\n          // TODO(BarclayII): make CSR support null eid array\n          Range(\n              0, result.indices->shape[0], result.indices->dtype.bits,\n              result.indices->ctx));\n      induced_eids_rel = result.data;\n    }\n\n    rel_graphs.push_back(new_rel_graph);\n    induced_eids.push_back(induced_eids_rel);\n  }\n\n  const HeteroGraphPtr new_graph = CreateHeteroGraph(\n      graph->meta_graph(), rel_graphs, graph->NumVerticesPerType());\n  return std::make_pair(new_graph, induced_eids);\n}\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLRemoveEdges\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      const HeteroGraphRef graph_ref = args[0];\n      const std::vector<IdArray> &eids = 
ListValueToVector<IdArray>(args[1]);\n\n      HeteroGraphPtr new_graph;\n      std::vector<IdArray> induced_eids;\n      std::tie(new_graph, induced_eids) = RemoveEdges(graph_ref.sptr(), eids);\n\n      List<Value> induced_eids_ref;\n      for (IdArray &array : induced_eids)\n        induced_eids_ref.push_back(Value(MakeValue(array)));\n\n      List<ObjectRef> ret;\n      ret.push_back(HeteroGraphRef(new_graph));\n      ret.push_back(induced_eids_ref);\n\n      *rv = ret;\n    });\n\n};  // namespace transform\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/to_block.cc",
    "content": "/**\n *  Copyright 2019-2021 Contributors\n *\n *  Licensed under the Apache License, Version 2.0 (the \"License\");\n *  you may not use this file except in compliance with the License.\n *  You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n *  Unless required by applicable law or agreed to in writing, software\n *  distributed under the License is distributed on an \"AS IS\" BASIS,\n *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *  See the License for the specific language governing permissions and\n *  limitations under the License.\n *\n * @file graph/transform/to_block.cc\n * @brief Convert a graph to a bipartite-structured graph.\n *\n * Tested via python wrapper: python/dgl/path/to/to_block.py\n */\n\n#include \"to_block.h\"\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/registry.h>\n#include <dgl/transform.h>\n\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"../../array/cpu/concurrent_id_hash_map.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace transform {\n\nnamespace {\n\ntemplate <typename IdType>\nstruct CPUIdsMapper {\n  std::tuple<std::vector<IdArray>, std::vector<IdArray>> operator()(\n      const HeteroGraphPtr &graph, bool include_rhs_in_lhs, int64_t num_ntypes,\n      const DGLContext &ctx, const std::vector<int64_t> &max_nodes_per_type,\n      const std::vector<EdgeArray> &edge_arrays,\n      const std::vector<IdArray> &src_nodes,\n      const std::vector<IdArray> &rhs_nodes,\n      std::vector<IdArray> *const lhs_nodes_ptr,\n      std::vector<int64_t> *const num_nodes_per_type_ptr) {\n    std::vector<IdArray> &lhs_nodes = *lhs_nodes_ptr;\n    std::vector<int64_t> &num_nodes_per_type = 
*num_nodes_per_type_ptr;\n\n    const bool generate_lhs_nodes = lhs_nodes.empty();\n    if (generate_lhs_nodes) {\n      lhs_nodes.reserve(num_ntypes);\n    }\n\n    std::vector<ConcurrentIdHashMap<IdType>> lhs_nodes_map(num_ntypes);\n    std::vector<ConcurrentIdHashMap<IdType>> rhs_nodes_map(num_ntypes);\n    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n      IdArray unique_ids =\n          aten::NullArray(DGLDataTypeTraits<IdType>::dtype, ctx);\n      if (!aten::IsNullArray(src_nodes[ntype])) {\n        auto num_seeds = include_rhs_in_lhs ? rhs_nodes[ntype]->shape[0] : 0;\n        unique_ids = lhs_nodes_map[ntype].Init(src_nodes[ntype], num_seeds);\n      }\n      if (generate_lhs_nodes) {\n        num_nodes_per_type[ntype] = unique_ids->shape[0];\n        lhs_nodes.emplace_back(unique_ids);\n      }\n    }\n\n    // Skip rhs mapping construction to save efforts when rhs is already\n    // contained in lhs.\n    if (!include_rhs_in_lhs) {\n      for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n        if (!aten::IsNullArray(rhs_nodes[ntype])) {\n          rhs_nodes_map[ntype].Init(\n              rhs_nodes[ntype], rhs_nodes[ntype]->shape[0]);\n        }\n      }\n    }\n\n    // Map node numberings from global to local, and build pointer for CSR.\n    std::vector<IdArray> new_lhs;\n    std::vector<IdArray> new_rhs;\n    new_lhs.reserve(edge_arrays.size());\n    new_rhs.reserve(edge_arrays.size());\n    const int64_t num_etypes = static_cast<int64_t>(edge_arrays.size());\n    for (int64_t etype = 0; etype < num_etypes; ++etype) {\n      const EdgeArray &edges = edge_arrays[etype];\n      if (edges.id.defined() && !aten::IsNullArray(edges.src)) {\n        const auto src_dst_types = graph->GetEndpointTypes(etype);\n        const int src_type = src_dst_types.first;\n        const int dst_type = src_dst_types.second;\n        new_lhs.emplace_back(lhs_nodes_map[src_type].MapIds(edges.src));\n        if (include_rhs_in_lhs) {\n          
new_rhs.emplace_back(lhs_nodes_map[dst_type].MapIds(edges.dst));\n        } else {\n          new_rhs.emplace_back(rhs_nodes_map[dst_type].MapIds(edges.dst));\n        }\n      } else {\n        new_lhs.emplace_back(\n            aten::NullArray(DGLDataTypeTraits<IdType>::dtype, ctx));\n        new_rhs.emplace_back(\n            aten::NullArray(DGLDataTypeTraits<IdType>::dtype, ctx));\n      }\n    }\n    return std::tuple<std::vector<IdArray>, std::vector<IdArray>>(\n        std::move(new_lhs), std::move(new_rhs));\n  }\n};\n\n// Since partial specialization is not allowed for functions, use this as an\n// intermediate for ToBlock where XPU = kDGLCPU.\ntemplate <typename IdType>\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockCPU(\n    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr) {\n  return dgl::transform::ProcessToBlock<IdType>(\n      graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes_ptr,\n      CPUIdsMapper<IdType>());\n}\n\n}  // namespace\n\ntemplate <typename IdType>\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>> ProcessToBlock(\n    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr,\n    IdsMapper &&ids_mapper) {\n  std::vector<IdArray> &lhs_nodes = *lhs_nodes_ptr;\n  const bool generate_lhs_nodes = lhs_nodes.empty();\n\n  const auto &ctx = graph->Context();\n  auto device = runtime::DeviceAPI::Get(ctx);\n\n  // Since DST nodes are included in SRC nodes, a common requirement is to fetch\n  // the DST node features from the SRC nodes features. To avoid expensive\n  // sparse lookup, the function assures that the DST nodes in both SRC and DST\n  // sets have the same ids. 
As a result, given the node feature tensor ``X`` of\n  // type ``utype``, the following code finds the corresponding DST node\n  // features of type ``vtype``:\n\n  const int64_t num_etypes = graph->NumEdgeTypes();\n  const int64_t num_ntypes = graph->NumVertexTypes();\n\n  CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes))\n      << \"rhs_nodes not given for every node type\";\n\n  std::vector<EdgeArray> edge_arrays(num_etypes);\n  for (int64_t etype = 0; etype < num_etypes; ++etype) {\n    const auto src_dst_types = graph->GetEndpointTypes(etype);\n    const dgl_type_t dsttype = src_dst_types.second;\n    if (!aten::IsNullArray(rhs_nodes[dsttype])) {\n      edge_arrays[etype] = graph->Edges(etype);\n    }\n  }\n\n  // Count lhs and rhs nodes.\n  std::vector<int64_t> maxNodesPerType(num_ntypes * 2, 0);\n  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n    maxNodesPerType[ntype + num_ntypes] += rhs_nodes[ntype]->shape[0];\n\n    if (generate_lhs_nodes) {\n      if (include_rhs_in_lhs) {\n        maxNodesPerType[ntype] += rhs_nodes[ntype]->shape[0];\n      }\n    } else {\n      maxNodesPerType[ntype] += lhs_nodes[ntype]->shape[0];\n    }\n  }\n  if (generate_lhs_nodes) {\n    // We don't have lhs_nodes, see we need to count inbound edges to get an\n    // upper bound.\n    for (int64_t etype = 0; etype < num_etypes; ++etype) {\n      const auto src_dst_types = graph->GetEndpointTypes(etype);\n      const dgl_type_t srctype = src_dst_types.first;\n      if (edge_arrays[etype].src.defined()) {\n        maxNodesPerType[srctype] += edge_arrays[etype].src->shape[0];\n      }\n    }\n  }\n\n  // Gather lhs_nodes.\n  std::vector<IdArray> src_nodes(num_ntypes);\n  if (generate_lhs_nodes) {\n    std::vector<int64_t> src_node_offsets(num_ntypes, 0);\n    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n      src_nodes[ntype] =\n          NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8);\n      if (include_rhs_in_lhs) {\n        // Place rhs 
nodes first.\n        device->CopyDataFromTo(\n            rhs_nodes[ntype].Ptr<IdType>(), 0, src_nodes[ntype].Ptr<IdType>(),\n            src_node_offsets[ntype],\n            sizeof(IdType) * rhs_nodes[ntype]->shape[0], rhs_nodes[ntype]->ctx,\n            src_nodes[ntype]->ctx, rhs_nodes[ntype]->dtype);\n        src_node_offsets[ntype] += sizeof(IdType) * rhs_nodes[ntype]->shape[0];\n      }\n    }\n    for (int64_t etype = 0; etype < num_etypes; ++etype) {\n      const auto src_dst_types = graph->GetEndpointTypes(etype);\n      const dgl_type_t srctype = src_dst_types.first;\n      if (edge_arrays[etype].src.defined()) {\n        device->CopyDataFromTo(\n            edge_arrays[etype].src.Ptr<IdType>(), 0,\n            src_nodes[srctype].Ptr<IdType>(), src_node_offsets[srctype],\n            sizeof(IdType) * edge_arrays[etype].src->shape[0],\n            rhs_nodes[srctype]->ctx, src_nodes[srctype]->ctx,\n            rhs_nodes[srctype]->dtype);\n\n        src_node_offsets[srctype] +=\n            sizeof(IdType) * edge_arrays[etype].src->shape[0];\n      }\n    }\n  } else {\n    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n      src_nodes[ntype] = lhs_nodes[ntype];\n    }\n  }\n\n  std::vector<int64_t> num_nodes_per_type(num_ntypes * 2);\n  // Populate RHS nodes from what we already know.\n  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {\n    num_nodes_per_type[num_ntypes + ntype] = rhs_nodes[ntype]->shape[0];\n  }\n\n  std::vector<IdArray> new_lhs;\n  std::vector<IdArray> new_rhs;\n  std::tie(new_lhs, new_rhs) = ids_mapper(\n      graph, include_rhs_in_lhs, num_ntypes, ctx, maxNodesPerType, edge_arrays,\n      src_nodes, rhs_nodes, lhs_nodes_ptr, &num_nodes_per_type);\n\n  std::vector<IdArray> induced_edges;\n  induced_edges.reserve(num_etypes);\n  for (int64_t etype = 0; etype < num_etypes; ++etype) {\n    if (edge_arrays[etype].id.defined()) {\n      induced_edges.push_back(edge_arrays[etype].id);\n    } else {\n      
induced_edges.push_back(\n          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));\n    }\n  }\n\n  // Build metagraph.\n  const auto meta_graph = graph->meta_graph();\n  const EdgeArray etypes = meta_graph->Edges(\"eid\");\n  const IdArray new_dst = Add(etypes.dst, num_ntypes);\n  const auto new_meta_graph =\n      ImmutableGraph::CreateFromCOO(num_ntypes * 2, etypes.src, new_dst);\n\n  // Allocate vector for graph relations while GPU is busy.\n  std::vector<HeteroGraphPtr> rel_graphs;\n  rel_graphs.reserve(num_etypes);\n\n  // Build the heterograph.\n  for (int64_t etype = 0; etype < num_etypes; ++etype) {\n    const auto src_dst_types = graph->GetEndpointTypes(etype);\n    const dgl_type_t srctype = src_dst_types.first;\n    const dgl_type_t dsttype = src_dst_types.second;\n\n    if (rhs_nodes[dsttype]->shape[0] == 0) {\n      // No rhs nodes are given for this edge type. Create an empty graph.\n      rel_graphs.push_back(CreateFromCOO(\n          2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0],\n          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx),\n          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx)));\n    } else {\n      rel_graphs.push_back(CreateFromCOO(\n          2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0],\n          new_lhs[etype], new_rhs[etype]));\n    }\n  }\n\n  HeteroGraphPtr new_graph =\n      CreateHeteroGraph(new_meta_graph, rel_graphs, num_nodes_per_type);\n\n  // Return the new graph, the new src nodes, and new edges.\n  return std::make_tuple(new_graph, induced_edges);\n}\n\ntemplate std::tuple<HeteroGraphPtr, std::vector<IdArray>>\nProcessToBlock<int32_t>(\n    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr,\n    IdsMapper &&get_maping_ids);\n\ntemplate std::tuple<HeteroGraphPtr, std::vector<IdArray>>\nProcessToBlock<int64_t>(\n    HeteroGraphPtr graph, 
const std::vector<IdArray> &rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr,\n    IdsMapper &&get_maping_ids);\n\ntemplate <>\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCPU, int32_t>(\n    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {\n  return ToBlockCPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);\n}\n\ntemplate <>\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCPU, int64_t>(\n    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {\n  return ToBlockCPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);\n}\n\n#ifdef DGL_USE_CUDA\n\n// Forward declaration of GPU ToBlock implementations - actual implementation is\n// in\n// ./cuda/cuda_to_block.cu\n// This is to get around the broken name mangling in VS2019 CL 16.5.5 +\n// CUDA 11.3 which complains that the two template specializations have the same\n// signature.\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU32(\n    HeteroGraphPtr, const std::vector<IdArray> &, bool,\n    std::vector<IdArray> *const);\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU64(\n    HeteroGraphPtr, const std::vector<IdArray> &, bool,\n    std::vector<IdArray> *const);\n\ntemplate <>\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCUDA, int32_t>(\n    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {\n  return ToBlockGPU32(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);\n}\n\ntemplate <>\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCUDA, int64_t>(\n    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {\n  return ToBlockGPU64(graph, rhs_nodes, include_rhs_in_lhs, 
lhs_nodes);\n}\n\n#endif  // DGL_USE_CUDA\n\nDGL_REGISTER_GLOBAL(\"capi._CAPI_DGLToBlock\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      const HeteroGraphRef graph_ref = args[0];\n      const std::vector<IdArray> &rhs_nodes =\n          ListValueToVector<IdArray>(args[1]);\n      const bool include_rhs_in_lhs = args[2];\n      std::vector<IdArray> lhs_nodes = ListValueToVector<IdArray>(args[3]);\n\n      HeteroGraphPtr new_graph;\n      std::vector<IdArray> induced_edges;\n\n      ATEN_XPU_SWITCH_CUDA(graph_ref->Context().device_type, XPU, \"ToBlock\", {\n        ATEN_ID_TYPE_SWITCH(graph_ref->DataType(), IdType, {\n          std::tie(new_graph, induced_edges) = ToBlock<XPU, IdType>(\n              graph_ref.sptr(), rhs_nodes, include_rhs_in_lhs, &lhs_nodes);\n        });\n      });\n\n      List<Value> lhs_nodes_ref;\n      for (IdArray &array : lhs_nodes)\n        lhs_nodes_ref.push_back(Value(MakeValue(array)));\n      List<Value> induced_edges_ref;\n      for (IdArray &array : induced_edges)\n        induced_edges_ref.push_back(Value(MakeValue(array)));\n\n      List<ObjectRef> ret;\n      ret.push_back(HeteroGraphRef(new_graph));\n      ret.push_back(lhs_nodes_ref);\n      ret.push_back(induced_edges_ref);\n\n      *rv = ret;\n    });\n\n};  // namespace transform\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/to_block.h",
    "content": "/**\n *  Copyright 2021 Contributors\n *\n *  Licensed under the Apache License, Version 2.0 (the \"License\");\n *  you may not use this file except in compliance with the License.\n *  You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n *  Unless required by applicable law or agreed to in writing, software\n *  distributed under the License is distributed on an \"AS IS\" BASIS,\n *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *  See the License for the specific language governing permissions and\n *  limitations under the License.\n *\n * @file graph/transform/to_block.h\n * @brief Functions to convert a set of edges into a graph block with local\n * ids.\n */\n\n#ifndef DGL_GRAPH_TRANSFORM_TO_BLOCK_H_\n#define DGL_GRAPH_TRANSFORM_TO_BLOCK_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n\n#include <functional>\n#include <tuple>\n#include <vector>\n\nnamespace dgl {\nnamespace transform {\n\n/** @brief Mapper used in block generation which maps left and right Id arrays\n * in the original MFG to new arrays with continuous numbers.\n */\nusing IdsMapper =\n    std::function<std::tuple<std::vector<IdArray>, std::vector<IdArray>>(\n        const HeteroGraphPtr&, bool, int64_t, const DGLContext&,\n        const std::vector<int64_t>&, const std::vector<EdgeArray>&,\n        const std::vector<IdArray>&, const std::vector<IdArray>&,\n        std::vector<IdArray>* const, std::vector<int64_t>* const)>;\n\n/**\n * @brief Create a graph block from the set of\n * src and dst nodes (lhs and rhs respectively).\n *\n * @tparam XPU The type of device to operate on.\n * @tparam IdType The type to use as an index.\n * @param graph The graph from which to extract the block.\n * @param rhs_nodes The destination nodes of the block.\n * @param include_rhs_in_lhs Whether or not to include the\n * destination nodes of the block in the sources nodes.\n * @param [in/out] 
lhs_nodes The source nodes of the block.\n *\n * @return The block and the induced edges.\n */\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock(\n    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray>* lhs_nodes);\n\n/**\n * @brief A wrapper function shared by CPU and GPU ```ToBlock```\n * which deals with their common preprocessing and postprocessing work.\n *\n * @tparam IdType The type to use as an index.\n * @param graph The graph from which to extract the block.\n * @param rhs_nodes The destination nodes of the block.\n * @param include_rhs_in_lhs Whether or not to include the\n * destination nodes of the block in the source nodes.\n * @param [in/out] lhs_nodes_ptr The source nodes of the block.\n * @param get_maping_ids The function to get mapped ids from original ids.\n *\n * @return The block and the induced edges.\n */\ntemplate <typename IdType>\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>> ProcessToBlock(\n    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,\n    bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes_ptr,\n    IdsMapper&& get_maping_ids);\n\n}  // namespace transform\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_TRANSFORM_TO_BLOCK_H_\n"
  },
  {
    "path": "src/graph/transform/to_simple.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/transform/to_simple.cc\n * @brief Convert multigraphs to simple graphs\n */\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/transform.h>\n\n#include <utility>\n#include <vector>\n\n#include \"../../c_api_common.h\"\n#include \"../heterograph.h\"\n#include \"../unit_graph.h\"\n\nnamespace dgl {\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace transform {\n\nstd::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>\nToSimpleGraph(const HeteroGraphPtr graph) {\n  const int64_t num_etypes = graph->NumEdgeTypes();\n  const auto metagraph = graph->meta_graph();\n  const auto &ugs =\n      std::dynamic_pointer_cast<HeteroGraph>(graph)->relation_graphs();\n\n  std::vector<IdArray> counts(num_etypes), edge_maps(num_etypes);\n  std::vector<HeteroGraphPtr> rel_graphs(num_etypes);\n\n  for (int64_t etype = 0; etype < num_etypes; ++etype) {\n    const auto result = ugs[etype]->ToSimple();\n    std::tie(rel_graphs[etype], counts[etype], edge_maps[etype]) = result;\n  }\n\n  const HeteroGraphPtr result =\n      CreateHeteroGraph(metagraph, rel_graphs, graph->NumVerticesPerType());\n\n  return std::make_tuple(result, counts, edge_maps);\n}\n\nDGL_REGISTER_GLOBAL(\"transform._CAPI_DGLToSimpleHetero\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      const HeteroGraphRef graph_ref = args[0];\n\n      const auto result = ToSimpleGraph(graph_ref.sptr());\n\n      List<Value> counts, edge_maps;\n      for (const IdArray &count : std::get<1>(result))\n        counts.push_back(Value(MakeValue(count)));\n      for (const IdArray &edge_map : std::get<2>(result))\n        edge_maps.push_back(Value(MakeValue(edge_map)));\n\n      List<ObjectRef> ret;\n      ret.push_back(HeteroGraphRef(std::get<0>(result)));\n      ret.push_back(counts);\n      ret.push_back(edge_maps);\n\n      *rv = ret;\n    });\n\n};  // 
namespace transform\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/graph/transform/union_partition.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file graph/transform/union_partition.cc\n * @brief Functions for partition, union multiple graphs.\n */\n#include \"../heterograph.h\"\nusing namespace dgl::runtime;\n\nnamespace dgl {\n\nHeteroGraphPtr JointUnionHeteroGraph(\n    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {\n  CHECK_GT(component_graphs.size(), 0)\n      << \"Input graph list has at least two graphs\";\n  std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());\n  std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);\n\n  // Loop over all canonical etypes\n  for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {\n    auto pair = meta_graph->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    uint64_t num_src_v = component_graphs[0]->NumVertices(src_vtype);\n    uint64_t num_dst_v = component_graphs[0]->NumVertices(dst_vtype);\n    HeteroGraphPtr rgptr = nullptr;\n\n    // ALL = CSC | CSR | COO\n    const dgl_format_code_t code =\n        component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();\n\n    // get common format\n    for (size_t i = 0; i < component_graphs.size(); ++i) {\n      const auto& cg = component_graphs[i];\n      CHECK_EQ(num_src_v, component_graphs[i]->NumVertices(src_vtype))\n          << \"Input graph[\" << i\n          << \"] should have same number of src vertices as input graph[0]\";\n      CHECK_EQ(num_dst_v, component_graphs[i]->NumVertices(dst_vtype))\n          << \"Input graph[\" << i\n          << \"] should have same number of dst vertices as input graph[0]\";\n\n      const dgl_format_code_t curr_code =\n          cg->GetRelationGraph(etype)->GetAllowedFormats();\n      if (curr_code != code)\n        LOG(FATAL) << \"All components should have the same formats\";\n    }\n\n    // prefer COO\n    if (FORMAT_HAS_COO(code)) {\n      std::vector<aten::COOMatrix> 
coos;\n      for (size_t i = 0; i < component_graphs.size(); ++i) {\n        const auto& cg = component_graphs[i];\n        aten::COOMatrix coo = cg->GetCOOMatrix(etype);\n        coos.push_back(coo);\n      }\n\n      aten::COOMatrix res = aten::UnionCoo(coos);\n      rgptr =\n          UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);\n    } else if (FORMAT_HAS_CSR(code)) {\n      std::vector<aten::CSRMatrix> csrs;\n      for (size_t i = 0; i < component_graphs.size(); ++i) {\n        const auto& cg = component_graphs[i];\n        aten::CSRMatrix csr = cg->GetCSRMatrix(etype);\n        csrs.push_back(csr);\n      }\n\n      aten::CSRMatrix res = aten::UnionCsr(csrs);\n      rgptr =\n          UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);\n    } else if (FORMAT_HAS_CSC(code)) {\n      // CSR and CSC have the same storage format, i.e. CSRMatrix\n      std::vector<aten::CSRMatrix> cscs;\n      for (size_t i = 0; i < component_graphs.size(); ++i) {\n        const auto& cg = component_graphs[i];\n        aten::CSRMatrix csc = cg->GetCSCMatrix(etype);\n        cscs.push_back(csc);\n      }\n\n      aten::CSRMatrix res = aten::UnionCsr(cscs);\n      rgptr =\n          UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 
1 : 2, res, code);\n    }\n\n    rel_graphs[etype] = rgptr;\n    num_nodes_per_type[src_vtype] = num_src_v;\n    num_nodes_per_type[dst_vtype] = num_dst_v;\n  }\n\n  return CreateHeteroGraph(\n      meta_graph, rel_graphs, std::move(num_nodes_per_type));\n}\n\nHeteroGraphPtr DisjointUnionHeteroGraph2(\n    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {\n  CHECK_GT(component_graphs.size(), 0) << \"Input graph list is empty\";\n  std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());\n  std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);\n\n  // Loop over all ntypes\n  for (dgl_type_t vtype = 0; vtype < meta_graph->NumVertices(); ++vtype) {\n    uint64_t offset = 0;\n    for (const auto& cg : component_graphs) offset += cg->NumVertices(vtype);\n    num_nodes_per_type[vtype] = offset;\n  }\n\n  // Loop over all canonical etypes\n  for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {\n    auto pair = meta_graph->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    HeteroGraphPtr rgptr = nullptr;\n\n    const dgl_format_code_t code =\n        component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();\n    // do some preprocess\n    for (const auto& cg : component_graphs) {\n      const dgl_format_code_t cur_code =\n          cg->GetRelationGraph(etype)->GetAllowedFormats();\n      if (cur_code != code)\n        LOG(FATAL) << \"All components should have the same formats\";\n    }\n\n    // prefer COO\n    if (FORMAT_HAS_COO(code)) {\n      std::vector<aten::COOMatrix> coos;\n      for (const auto& cg : component_graphs) {\n        aten::COOMatrix coo = cg->GetCOOMatrix(etype);\n        coos.push_back(coo);\n      }\n\n      aten::COOMatrix res = aten::DisjointUnionCoo(coos);\n\n      rgptr =\n          UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 
1 : 2, res, code);\n    } else if (FORMAT_HAS_CSR(code)) {\n      std::vector<aten::CSRMatrix> csrs;\n      for (const auto& cg : component_graphs) {\n        aten::CSRMatrix csr = cg->GetCSRMatrix(etype);\n        csrs.push_back(csr);\n      }\n\n      aten::CSRMatrix res = aten::DisjointUnionCsr(csrs);\n\n      rgptr =\n          UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);\n    } else if (FORMAT_HAS_CSC(code)) {\n      // CSR and CSC have the same storage format, i.e. CSRMatrix\n      std::vector<aten::CSRMatrix> cscs;\n      for (const auto& cg : component_graphs) {\n        aten::CSRMatrix csc = cg->GetCSCMatrix(etype);\n        cscs.push_back(csc);\n      }\n\n      aten::CSRMatrix res = aten::DisjointUnionCsr(cscs);\n      rgptr =\n          UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 1 : 2, res, code);\n    }\n    rel_graphs[etype] = rgptr;\n  }\n\n  return CreateHeteroGraph(\n      meta_graph, rel_graphs, std::move(num_nodes_per_type));\n}\n\nstd::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(\n    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,\n    IdArray edge_sizes) {\n  // Sanity check for vertex sizes\n  CHECK_EQ(vertex_sizes->dtype.bits, 64)\n      << \"dtype of vertex_sizes should be int64\";\n  CHECK_EQ(edge_sizes->dtype.bits, 64) << \"dtype of edge_sizes should be int64\";\n  const uint64_t len_vertex_sizes = vertex_sizes->shape[0];\n  const uint64_t* vertex_sizes_data =\n      static_cast<uint64_t*>(vertex_sizes->data);\n  const uint64_t num_vertex_types = meta_graph->NumVertices();\n  const uint64_t batch_size = len_vertex_sizes / num_vertex_types;\n\n  // Map vertex type to the corresponding node cum sum\n  std::vector<std::vector<uint64_t>> vertex_cumsum;\n  vertex_cumsum.resize(num_vertex_types);\n  // Loop over all vertex types\n  for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {\n    vertex_cumsum[vtype].push_back(0);\n    for (uint64_t g = 0; g < batch_size; 
++g) {\n      // We've flattened the number of vertices in the batch for all types\n      vertex_cumsum[vtype].push_back(\n          vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);\n    }\n    CHECK_EQ(\n        vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))\n        << \"Sum of the given sizes must equal to the number of nodes for type \"\n        << vtype;\n  }\n\n  // Sanity check for edge sizes\n  const uint64_t* edge_sizes_data = static_cast<uint64_t*>(edge_sizes->data);\n  const uint64_t num_edge_types = meta_graph->NumEdges();\n  // Map edge type to the corresponding edge cum sum\n  std::vector<std::vector<uint64_t>> edge_cumsum;\n  edge_cumsum.resize(num_edge_types);\n  // Loop over all edge types\n  for (uint64_t etype = 0; etype < num_edge_types; ++etype) {\n    edge_cumsum[etype].push_back(0);\n    for (uint64_t g = 0; g < batch_size; ++g) {\n      // We've flattened the number of edges in the batch for all types\n      edge_cumsum[etype].push_back(\n          edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);\n    }\n    CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))\n        << \"Sum of the given sizes must equal to the number of edges for type \"\n        << etype;\n  }\n\n  // Construct relation graphs for unbatched graphs\n  std::vector<std::vector<HeteroGraphPtr>> rel_graphs;\n  rel_graphs.resize(batch_size);\n  // Loop over all edge types\n  auto code = batched_graph->GetRelationGraph(0)->GetAllowedFormats();\n\n  if (FORMAT_HAS_COO(code)) {\n    for (uint64_t etype = 0; etype < num_edge_types; ++etype) {\n      auto pair = meta_graph->FindEdge(etype);\n      const dgl_type_t src_vtype = pair.first;\n      const dgl_type_t dst_vtype = pair.second;\n      aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype);\n      auto res = aten::DisjointPartitionCooBySizes(\n          coo, batch_size, edge_cumsum[etype], vertex_cumsum[src_vtype],\n          
vertex_cumsum[dst_vtype]);\n      for (uint64_t g = 0; g < batch_size; ++g) {\n        HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(\n            (src_vtype == dst_vtype) ? 1 : 2, res[g], code);\n        rel_graphs[g].push_back(rgptr);\n      }\n    }\n  } else if (FORMAT_HAS_CSR(code)) {\n    for (uint64_t etype = 0; etype < num_edge_types; ++etype) {\n      auto pair = meta_graph->FindEdge(etype);\n      const dgl_type_t src_vtype = pair.first;\n      const dgl_type_t dst_vtype = pair.second;\n      aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype);\n      auto res = aten::DisjointPartitionCsrBySizes(\n          csr, batch_size, edge_cumsum[etype], vertex_cumsum[src_vtype],\n          vertex_cumsum[dst_vtype]);\n      for (uint64_t g = 0; g < batch_size; ++g) {\n        HeteroGraphPtr rgptr = UnitGraph::CreateFromCSR(\n            (src_vtype == dst_vtype) ? 1 : 2, res[g], code);\n        rel_graphs[g].push_back(rgptr);\n      }\n    }\n  } else if (FORMAT_HAS_CSC(code)) {\n    for (uint64_t etype = 0; etype < num_edge_types; ++etype) {\n      auto pair = meta_graph->FindEdge(etype);\n      const dgl_type_t src_vtype = pair.first;\n      const dgl_type_t dst_vtype = pair.second;\n      // CSR and CSC have the same storage format, i.e. CSRMatrix\n      aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype);\n      auto res = aten::DisjointPartitionCsrBySizes(\n          csc, batch_size, edge_cumsum[etype], vertex_cumsum[dst_vtype],\n          vertex_cumsum[src_vtype]);\n      for (uint64_t g = 0; g < batch_size; ++g) {\n        HeteroGraphPtr rgptr = UnitGraph::CreateFromCSC(\n            (src_vtype == dst_vtype) ? 
1 : 2, res[g], code);\n        rel_graphs[g].push_back(rgptr);\n      }\n    }\n  }\n\n  std::vector<HeteroGraphPtr> rst;\n  std::vector<int64_t> num_nodes_per_type(num_vertex_types);\n  for (uint64_t g = 0; g < batch_size; ++g) {\n    for (uint64_t i = 0; i < num_vertex_types; ++i)\n      num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];\n    rst.push_back(\n        CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));\n  }\n  return rst;\n}\n\nHeteroGraphPtr SliceHeteroGraph(\n    GraphPtr meta_graph, HeteroGraphPtr batched_graph,\n    IdArray num_nodes_per_type, IdArray start_nid_per_type,\n    IdArray num_edges_per_type, IdArray start_eid_per_type) {\n  std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());\n\n  const uint64_t* start_nid_per_type_data =\n      static_cast<uint64_t*>(start_nid_per_type->data);\n  const uint64_t* num_nodes_per_type_data =\n      static_cast<uint64_t*>(num_nodes_per_type->data);\n  const uint64_t* start_eid_per_type_data =\n      static_cast<uint64_t*>(start_eid_per_type->data);\n  const uint64_t* num_edges_per_type_data =\n      static_cast<uint64_t*>(num_edges_per_type->data);\n\n  // Map vertex type to the corresponding node range\n  const uint64_t num_vertex_types = meta_graph->NumVertices();\n  std::vector<std::vector<uint64_t>> vertex_range;\n  vertex_range.resize(num_vertex_types);\n  // Loop over all vertex types\n  for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {\n    vertex_range[vtype].push_back(start_nid_per_type_data[vtype]);\n    vertex_range[vtype].push_back(\n        start_nid_per_type_data[vtype] + num_nodes_per_type_data[vtype]);\n  }\n\n  // Loop over all canonical etypes\n  for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {\n    auto pair = meta_graph->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    HeteroGraphPtr rgptr = nullptr;\n    const dgl_format_code_t code =\n        
batched_graph->GetRelationGraph(etype)->GetAllowedFormats();\n\n    // handle graph without edges\n    std::vector<uint64_t> edge_range;\n    edge_range.push_back(start_eid_per_type_data[etype]);\n    edge_range.push_back(\n        start_eid_per_type_data[etype] + num_edges_per_type_data[etype]);\n\n    // prefer COO\n    if (FORMAT_HAS_COO(code)) {\n      aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype);\n      aten::COOMatrix res = aten::COOSliceContiguousChunk(\n          coo, edge_range, vertex_range[src_vtype], vertex_range[dst_vtype]);\n      rgptr =\n          UnitGraph::CreateFromCOO((src_vtype == dst_vtype) ? 1 : 2, res, code);\n    } else if (FORMAT_HAS_CSR(code)) {\n      aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype);\n      aten::CSRMatrix res = aten::CSRSliceContiguousChunk(\n          csr, edge_range, vertex_range[src_vtype], vertex_range[dst_vtype]);\n      rgptr =\n          UnitGraph::CreateFromCSR((src_vtype == dst_vtype) ? 1 : 2, res, code);\n    } else if (FORMAT_HAS_CSC(code)) {\n      // CSR and CSC have the same storage format, i.e. CSRMatrix\n      aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype);\n      aten::CSRMatrix res = aten::CSRSliceContiguousChunk(\n          csc, edge_range, vertex_range[dst_vtype], vertex_range[src_vtype]);\n      rgptr =\n          UnitGraph::CreateFromCSC((src_vtype == dst_vtype) ? 
1 : 2, res, code);\n    }\n\n    rel_graphs[etype] = rgptr;\n  }\n\n  return CreateHeteroGraph(\n      meta_graph, rel_graphs, num_nodes_per_type.ToVector<int64_t>());\n}\n\ntemplate <class IdType>\nstd::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(\n    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,\n    IdArray edge_sizes) {\n  // Sanity check for vertex sizes\n  const uint64_t len_vertex_sizes = vertex_sizes->shape[0];\n  const uint64_t* vertex_sizes_data =\n      static_cast<uint64_t*>(vertex_sizes->data);\n  const uint64_t num_vertex_types = meta_graph->NumVertices();\n  const uint64_t batch_size = len_vertex_sizes / num_vertex_types;\n  // Map vertex type to the corresponding node cum sum\n  std::vector<std::vector<uint64_t>> vertex_cumsum;\n  vertex_cumsum.resize(num_vertex_types);\n  // Loop over all vertex types\n  for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {\n    vertex_cumsum[vtype].push_back(0);\n    for (uint64_t g = 0; g < batch_size; ++g) {\n      // We've flattened the number of vertices in the batch for all types\n      vertex_cumsum[vtype].push_back(\n          vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);\n    }\n    CHECK_EQ(\n        vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))\n        << \"Sum of the given sizes must equal to the number of nodes for type \"\n        << vtype;\n  }\n\n  // Sanity check for edge sizes\n  const uint64_t* edge_sizes_data = static_cast<uint64_t*>(edge_sizes->data);\n  const uint64_t num_edge_types = meta_graph->NumEdges();\n  // Map edge type to the corresponding edge cum sum\n  std::vector<std::vector<uint64_t>> edge_cumsum;\n  edge_cumsum.resize(num_edge_types);\n  // Loop over all edge types\n  for (uint64_t etype = 0; etype < num_edge_types; ++etype) {\n    edge_cumsum[etype].push_back(0);\n    for (uint64_t g = 0; g < batch_size; ++g) {\n      // We've flattened the number of edges in the batch for all types\n  
    edge_cumsum[etype].push_back(\n          edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);\n    }\n    CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))\n        << \"Sum of the given sizes must equal to the number of edges for type \"\n        << etype;\n  }\n\n  // Construct relation graphs for unbatched graphs\n  std::vector<std::vector<HeteroGraphPtr>> rel_graphs;\n  rel_graphs.resize(batch_size);\n  // Loop over all edge types\n  for (uint64_t etype = 0; etype < num_edge_types; ++etype) {\n    auto pair = meta_graph->FindEdge(etype);\n    const dgl_type_t src_vtype = pair.first;\n    const dgl_type_t dst_vtype = pair.second;\n    EdgeArray edges = batched_graph->Edges(etype);\n    const IdType* edges_src_data = static_cast<const IdType*>(edges.src->data);\n    const IdType* edges_dst_data = static_cast<const IdType*>(edges.dst->data);\n    // Loop over all graphs to be unbatched\n    for (uint64_t g = 0; g < batch_size; ++g) {\n      std::vector<IdType> result_src, result_dst;\n      // Loop over the chunk of edges for the specified graph and edge type\n      for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1];\n           ++e) {\n        // TODO(mufei): Should use array operations to implement this.\n        result_src.push_back(edges_src_data[e] - vertex_cumsum[src_vtype][g]);\n        result_dst.push_back(edges_dst_data[e] - vertex_cumsum[dst_vtype][g]);\n      }\n      HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(\n          (src_vtype == dst_vtype) ? 
1 : 2,\n          vertex_sizes_data[src_vtype * batch_size + g],\n          vertex_sizes_data[dst_vtype * batch_size + g],\n          aten::VecToIdArray(result_src, sizeof(IdType) * 8),\n          aten::VecToIdArray(result_dst, sizeof(IdType) * 8));\n      rel_graphs[g].push_back(rgptr);\n    }\n  }\n\n  std::vector<HeteroGraphPtr> rst;\n  std::vector<int64_t> num_nodes_per_type(num_vertex_types);\n  for (uint64_t g = 0; g < batch_size; ++g) {\n    for (uint64_t i = 0; i < num_vertex_types; ++i)\n      num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];\n    rst.push_back(\n        CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));\n  }\n  return rst;\n}\n\ntemplate std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int32_t>(\n    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,\n    IdArray edge_sizes);\n\ntemplate std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes<int64_t>(\n    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,\n    IdArray edge_sizes);\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/traversal.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/traversal.cc\n * @brief Graph traversal implementation\n */\n#include \"./traversal.h\"\n\n#include <dgl/packed_func_ext.h>\n\n#include <algorithm>\n#include <queue>\n\n#include \"../c_api_common.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace traverse {\nnamespace {\n// A utility view class to wrap a vector into a queue.\ntemplate <typename DType>\nstruct VectorQueueWrapper {\n  std::vector<DType>* vec;\n  size_t head = 0;\n\n  explicit VectorQueueWrapper(std::vector<DType>* vec) : vec(vec) {}\n\n  void push(const DType& elem) { vec->push_back(elem); }\n\n  DType top() const { return vec->operator[](head); }\n\n  void pop() { ++head; }\n\n  bool empty() const { return head == vec->size(); }\n\n  size_t size() const { return vec->size() - head; }\n};\n\n// Internal function to merge multiple traversal traces into one ndarray.\n// It is similar to zip the vectors together.\ntemplate <typename DType>\nIdArray MergeMultipleTraversals(const std::vector<std::vector<DType>>& traces) {\n  int64_t max_len = 0, total_len = 0;\n  for (size_t i = 0; i < traces.size(); ++i) {\n    const int64_t tracelen = traces[i].size();\n    max_len = std::max(max_len, tracelen);\n    total_len += traces[i].size();\n  }\n  IdArray ret = IdArray::Empty(\n      {total_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  int64_t* ret_data = static_cast<int64_t*>(ret->data);\n  for (int64_t i = 0; i < max_len; ++i) {\n    for (size_t j = 0; j < traces.size(); ++j) {\n      const int64_t tracelen = traces[j].size();\n      if (i >= tracelen) {\n        continue;\n      }\n      *(ret_data++) = traces[j][i];\n    }\n  }\n  return ret;\n}\n\n// Internal function to compute sections if multiple traversal traces\n// are merged into one ndarray.\ntemplate <typename DType>\nIdArray ComputeMergedSections(const std::vector<std::vector<DType>>& traces) {\n  int64_t max_len = 0;\n  for (size_t i = 0; i 
< traces.size(); ++i) {\n    const int64_t tracelen = traces[i].size();\n    max_len = std::max(max_len, tracelen);\n  }\n  IdArray ret = IdArray::Empty(\n      {max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});\n  int64_t* ret_data = static_cast<int64_t*>(ret->data);\n  for (int64_t i = 0; i < max_len; ++i) {\n    int64_t sec_len = 0;\n    for (size_t j = 0; j < traces.size(); ++j) {\n      const int64_t tracelen = traces[j].size();\n      if (i < tracelen) {\n        ++sec_len;\n      }\n    }\n    *(ret_data++) = sec_len;\n  }\n  return ret;\n}\n\n}  // namespace\n\n/**\n * @brief Class for representing frontiers.\n *\n * Each frontier is a list of nodes/edges (specified by their ids).\n * An optional tag can be specified on each node/edge (represented by an int\n * value).\n */\nstruct Frontiers {\n  /** @brief a vector store for the nodes/edges in all the frontiers */\n  std::vector<dgl_id_t> ids;\n\n  /**\n   * @brief a vector store for node/edge tags. Empty if no tags are requested\n   */\n  std::vector<int64_t> tags;\n\n  /** @brief a section vector to indicate each frontier */\n  std::vector<int64_t> sections;\n};\n\nFrontiers BFSNodesFrontiers(\n    const GraphInterface& graph, IdArray source, bool reversed) {\n  Frontiers front;\n  VectorQueueWrapper<dgl_id_t> queue(&front.ids);\n  auto visit = [&](const dgl_id_t v) {};\n  auto make_frontier = [&]() {\n    if (!queue.empty()) {\n      // do not push zero-length frontier\n      front.sections.push_back(queue.size());\n    }\n  };\n  BFSNodes(graph, source, reversed, &queue, visit, make_frontier);\n  return front;\n}\n\nDGL_REGISTER_GLOBAL(\"traversal._CAPI_DGLBFSNodes\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray src = args[1];\n      bool reversed = args[2];\n      const auto& front = BFSNodesFrontiers(*(g.sptr()), src, reversed);\n      IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);\n      IdArray sections = 
CopyVectorToNDArray<int64_t>(front.sections);\n      *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});\n    });\n\nFrontiers BFSEdgesFrontiers(\n    const GraphInterface& graph, IdArray source, bool reversed) {\n  Frontiers front;\n  // NOTE: std::queue has no top() method.\n  std::vector<dgl_id_t> nodes;\n  VectorQueueWrapper<dgl_id_t> queue(&nodes);\n  auto visit = [&](const dgl_id_t e) { front.ids.push_back(e); };\n  bool first_frontier = true;\n  auto make_frontier = [&] {\n    if (first_frontier) {\n      first_frontier = false;  // do not push the first section when doing edges\n    } else if (!queue.empty()) {\n      // do not push zero-length frontier\n      front.sections.push_back(queue.size());\n    }\n  };\n  BFSEdges(graph, source, reversed, &queue, visit, make_frontier);\n  return front;\n}\n\nDGL_REGISTER_GLOBAL(\"traversal._CAPI_DGLBFSEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray src = args[1];\n      bool reversed = args[2];\n      const auto& front = BFSEdgesFrontiers(*(g.sptr()), src, reversed);\n      IdArray edge_ids = CopyVectorToNDArray<int64_t>(front.ids);\n      IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);\n      *rv = ConvertNDArrayVectorToPackedFunc({edge_ids, sections});\n    });\n\nFrontiers TopologicalNodesFrontiers(\n    const GraphInterface& graph, bool reversed) {\n  Frontiers front;\n  VectorQueueWrapper<dgl_id_t> queue(&front.ids);\n  auto visit = [&](const dgl_id_t v) {};\n  auto make_frontier = [&]() {\n    if (!queue.empty()) {\n      // do not push zero-length frontier\n      front.sections.push_back(queue.size());\n    }\n  };\n  TopologicalNodes(graph, reversed, &queue, visit, make_frontier);\n  return front;\n}\n\nDGL_REGISTER_GLOBAL(\"traversal._CAPI_DGLTopologicalNodes\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      bool reversed = args[1];\n      const auto& front = 
TopologicalNodesFrontiers(*g.sptr(), reversed);\n      IdArray node_ids = CopyVectorToNDArray<int64_t>(front.ids);\n      IdArray sections = CopyVectorToNDArray<int64_t>(front.sections);\n      *rv = ConvertNDArrayVectorToPackedFunc({node_ids, sections});\n    });\n\nDGL_REGISTER_GLOBAL(\"traversal._CAPI_DGLDFSEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray source = args[1];\n      const bool reversed = args[2];\n      CHECK(aten::IsValidIdArray(source)) << \"Invalid source node id array.\";\n      const int64_t len = source->shape[0];\n      const int64_t* src_data = static_cast<int64_t*>(source->data);\n      std::vector<std::vector<dgl_id_t>> edges(len);\n      for (int64_t i = 0; i < len; ++i) {\n        auto visit = [&](dgl_id_t e, int tag) { edges[i].push_back(e); };\n        DFSLabeledEdges(*g.sptr(), src_data[i], reversed, false, false, visit);\n      }\n      IdArray ids = MergeMultipleTraversals(edges);\n      IdArray sections = ComputeMergedSections(edges);\n      *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});\n    });\n\nDGL_REGISTER_GLOBAL(\"traversal._CAPI_DGLDFSLabeledEdges\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      GraphRef g = args[0];\n      const IdArray source = args[1];\n      const bool reversed = args[2];\n      const bool has_reverse_edge = args[3];\n      const bool has_nontree_edge = args[4];\n      const bool return_labels = args[5];\n\n      CHECK(aten::IsValidIdArray(source)) << \"Invalid source node id array.\";\n      const int64_t len = source->shape[0];\n      const int64_t* src_data = static_cast<int64_t*>(source->data);\n\n      std::vector<std::vector<dgl_id_t>> edges(len);\n      std::vector<std::vector<int64_t>> tags;\n      if (return_labels) {\n        tags.resize(len);\n      }\n      for (int64_t i = 0; i < len; ++i) {\n        auto visit = [&](dgl_id_t e, int tag) {\n          edges[i].push_back(e);\n          if (return_labels) {\n   
         tags[i].push_back(tag);\n          }\n        };\n        DFSLabeledEdges(\n            *g.sptr(), src_data[i], reversed, has_reverse_edge,\n            has_nontree_edge, visit);\n      }\n\n      IdArray ids = MergeMultipleTraversals(edges);\n      IdArray sections = ComputeMergedSections(edges);\n      if (return_labels) {\n        IdArray labels = MergeMultipleTraversals(tags);\n        *rv = ConvertNDArrayVectorToPackedFunc({ids, labels, sections});\n      } else {\n        *rv = ConvertNDArrayVectorToPackedFunc({ids, sections});\n      }\n    });\n\n}  // namespace traverse\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/traversal.h",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file graph/traversal.h\n * @brief Graph traversal routines.\n *\n * Traversal routines generate frontiers. Frontiers can be node frontiers or\n * edge frontiers depending on the traversal function. Each frontier is a list\n * of nodes/edges (specified by their ids). An optional tag can be specified for\n * each node/edge (represented by an int value).\n */\n#ifndef DGL_GRAPH_TRAVERSAL_H_\n#define DGL_GRAPH_TRAVERSAL_H_\n\n#include <dgl/graph_interface.h>\n\n#include <stack>\n#include <tuple>\n#include <vector>\n\nnamespace dgl {\nnamespace traverse {\n\n/**\n * @brief Traverse the graph in a breadth-first-search (BFS) order.\n *\n * The queue object must satisfy the following interface:\n *   Members:\n *   void push(dgl_id_t);  // push one node\n *   dgl_id_t top();       // get the first node\n *   void pop();           // pop one node\n *   bool empty();         // return true if the queue is empty\n *   size_t size();        // return the size of the queue\n * Note that std::queue<dgl_id_t> does not qualify: it has front(), not top().\n *\n * The visit function must be compatible with following interface:\n *   void (*visit)(dgl_id_t );\n *\n * The frontier function must be compatible with following interface:\n *   void (*make_frontier)(void);\n *\n * @param graph The graph.\n * @param source Source nodes.\n * @param reversed If true, BFS follows the in-edge direction.\n * @param queue The queue used to do bfs.\n * @param visit The function to call when a node is visited.\n * @param make_frontier The function to indicate that a new frontier can be\n *        made.\n */\ntemplate <typename Queue, typename VisitFn, typename FrontierFn>\nvoid BFSNodes(\n    const GraphInterface& graph, IdArray source, bool reversed, Queue* queue,\n    VisitFn visit, FrontierFn make_frontier) {\n  const int64_t len = source->shape[0];\n  const int64_t* src_data = static_cast<int64_t*>(source->data);\n\n  std::vector<bool> 
visited(graph.NumVertices());\n  for (int64_t i = 0; i < len; ++i) {\n    const dgl_id_t u = src_data[i];\n    visited[u] = true;\n    visit(u);\n    queue->push(u);\n  }\n  make_frontier();\n\n  const auto neighbor_iter =\n      reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;\n  while (!queue->empty()) {\n    const size_t size = queue->size();\n    for (size_t i = 0; i < size; ++i) {\n      const dgl_id_t u = queue->top();\n      queue->pop();\n      for (auto v : (graph.*neighbor_iter)(u)) {\n        if (!visited[v]) {\n          visited[v] = true;\n          visit(v);\n          queue->push(v);\n        }\n      }\n    }\n    make_frontier();\n  }\n}\n\n/**\n * @brief Traverse the graph in a breadth-first-search (BFS) order, returning\n *        the edges of the BFS tree.\n *\n * The queue object must satisfy the following interface:\n *   Members:\n *   void push(dgl_id_t);  // push one node\n *   dgl_id_t top();       // get the first node\n *   void pop();           // pop one node\n *   bool empty();         // return true if the queue is empty\n *   size_t size();        // return the size of the queue\n * Note that std::queue<dgl_id_t> does not qualify: it has front(), not top().\n *\n * The visit function must be compatible with following interface:\n *   void (*visit)(dgl_id_t );\n *\n * The frontier function must be compatible with following interface:\n *   void (*make_frontier)(void);\n *\n * @param graph The graph.\n * @param source Source nodes.\n * @param reversed If true, BFS follows the in-edge direction.\n * @param queue The queue used to do bfs.\n * @param visit The function to call when a node is visited.\n *        The argument would be edge ID.\n * @param make_frontier The function to indicate that a new frontier can be\n *        made.\n */\ntemplate <typename Queue, typename VisitFn, typename FrontierFn>\nvoid BFSEdges(\n    const GraphInterface& graph, IdArray source, bool reversed, Queue* queue,\n    VisitFn visit, FrontierFn make_frontier) 
{\n  const int64_t len = source->shape[0];\n  const int64_t* src_data = static_cast<int64_t*>(source->data);\n\n  std::vector<bool> visited(graph.NumVertices());\n  for (int64_t i = 0; i < len; ++i) {\n    const dgl_id_t u = src_data[i];\n    visited[u] = true;\n    queue->push(u);\n  }\n  make_frontier();\n\n  const auto neighbor_iter =\n      reversed ? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;\n  while (!queue->empty()) {\n    const size_t size = queue->size();\n    for (size_t i = 0; i < size; ++i) {\n      const dgl_id_t u = queue->top();\n      queue->pop();\n      for (auto e : (graph.*neighbor_iter)(u)) {\n        const auto uv = graph.FindEdge(e);\n        const dgl_id_t v = (reversed ? uv.first : uv.second);\n        if (!visited[v]) {\n          visited[v] = true;\n          visit(e);\n          queue->push(v);\n        }\n      }\n    }\n    make_frontier();\n  }\n}\n\n/**\n * @brief Traverse the graph in topological order.\n *\n * The queue object must satisfy the following interface:\n *   Members:\n *   void push(dgl_id_t);  // push one node\n *   dgl_id_t top();       // get the first node\n *   void pop();           // pop one node\n *   bool empty();         // return true if the queue is empty\n *   size_t size();        // return the size of the queue\n * Note that std::queue<dgl_id_t> does not qualify: it has front(), not top().\n *\n * The visit function must be compatible with following interface:\n *   void (*visit)(dgl_id_t );\n *\n * The frontier function must be compatible with following interface:\n *   void (*make_frontier)(void);\n *\n * @param graph The graph.\n * @param reversed If true, follows the in-edge direction.\n * @param queue The queue used to do bfs.\n * @param visit The function to call when a node is visited.\n * @param make_frontier The function to indicate that a new frontier can be\n *        made.\n */\ntemplate <typename Queue, typename VisitFn, typename FrontierFn>\nvoid TopologicalNodes(\n    const GraphInterface& graph, 
bool reversed, Queue* queue, VisitFn visit,\n    FrontierFn make_frontier) {\n  const auto get_degree =\n      reversed ? &GraphInterface::OutDegree : &GraphInterface::InDegree;\n  const auto neighbor_iter =\n      reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;\n  uint64_t num_visited_nodes = 0;\n  std::vector<uint64_t> degrees(graph.NumVertices(), 0);\n  for (dgl_id_t vid = 0; vid < graph.NumVertices(); ++vid) {\n    degrees[vid] = (graph.*get_degree)(vid);\n    if (degrees[vid] == 0) {\n      visit(vid);\n      queue->push(vid);\n      ++num_visited_nodes;\n    }\n  }\n  make_frontier();\n\n  while (!queue->empty()) {\n    const size_t size = queue->size();\n    for (size_t i = 0; i < size; ++i) {\n      const dgl_id_t u = queue->top();\n      queue->pop();\n      for (auto v : (graph.*neighbor_iter)(u)) {\n        if (--(degrees[v]) == 0) {\n          visit(v);\n          queue->push(v);\n          ++num_visited_nodes;\n        }\n      }\n    }\n    make_frontier();\n  }\n\n  if (num_visited_nodes != graph.NumVertices()) {\n    LOG(FATAL)\n        << \"Error in topological traversal: loop detected in the given graph.\";\n  }\n}\n\n/** @brief Tags for ``DFSEdges``. */\nenum DFSEdgeTag {\n  kForward = 0,\n  kReverse,\n  kNonTree,\n};\n/**\n * @brief Traverse the graph in a depth-first-search (DFS) order.\n *\n * The traversal visits edges in its DFS order. Edges have three tags:\n * FORWARD(0), REVERSE(1), NONTREE(2).\n *\n * A FORWARD edge is one in which `u` has been visited but `v` has not.\n * A REVERSE edge is one in which both `u` and `v` have been visited and the\n * edge is in the DFS tree. 
A NONTREE edge is one in which both `u` and `v` have\n * been visited but the edge is NOT in the DFS tree.\n *\n * @param graph The graph.\n * @param source Source node.\n * @param reversed If true, DFS follows the in-edge direction.\n * @param has_reverse_edge If true, REVERSE edges are included.\n * @param has_nontree_edge If true, NONTREE edges are included.\n * @param visit The function to call when an edge is visited; the edge id and\n *        its tag will be given as the arguments.\n */\ntemplate <typename VisitFn>\nvoid DFSLabeledEdges(\n    const GraphInterface& graph, dgl_id_t source, bool reversed,\n    bool has_reverse_edge, bool has_nontree_edge, VisitFn visit) {\n  const auto succ =\n      reversed ? &GraphInterface::PredVec : &GraphInterface::SuccVec;\n  const auto out_edge =\n      reversed ? &GraphInterface::InEdgeVec : &GraphInterface::OutEdgeVec;\n\n  if ((graph.*succ)(source).size() == 0) {\n    // no out-going edges from the source node\n    return;\n  }\n\n  typedef std::tuple<dgl_id_t, size_t, bool> StackEntry;\n  std::stack<StackEntry> stack;\n  std::vector<bool> visited(graph.NumVertices());\n  visited[source] = true;\n  stack.push(std::make_tuple(source, 0, false));\n  dgl_id_t u = 0;\n  size_t i = 0;\n  bool on_tree = false;\n\n  while (!stack.empty()) {\n    std::tie(u, i, on_tree) = stack.top();\n    const dgl_id_t v = (graph.*succ)(u)[i];\n    const dgl_id_t uv = (graph.*out_edge)(u)[i];\n    if (visited[v]) {\n      if (!on_tree && has_nontree_edge) {\n        visit(uv, kNonTree);\n      } else if (on_tree && has_reverse_edge) {\n        visit(uv, kReverse);\n      }\n      stack.pop();\n      // find next one.\n      if (i < (graph.*succ)(u).size() - 1) {\n        stack.push(std::make_tuple(u, i + 1, false));\n      }\n    } else {\n      visited[v] = true;\n      std::get<2>(stack.top()) = true;\n      visit(uv, kForward);\n      // expand\n      if ((graph.*succ)(v).size() > 0) {\n        stack.push(std::make_tuple(v, 0, false));\n      }\n    }\n  }\n}\n\n} 
 // namespace traverse\n}  // namespace dgl\n\n#endif  // DGL_GRAPH_TRAVERSAL_H_\n"
  },
  {
    "path": "src/graph/unit_graph.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/unit_graph.cc\n * @brief UnitGraph graph implementation\n */\n#include \"./unit_graph.h\"\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/lazy.h>\n\n#include \"../c_api_common.h\"\n#include \"./serialize/dglstream.h\"\n\nnamespace dgl {\n\nnamespace {\n\nusing namespace dgl::aten;\n\n// create metagraph of one node type\ninline GraphPtr CreateUnitGraphMetaGraph1() {\n  // a self-loop edge 0->0\n  std::vector<int64_t> row_vec(1, 0);\n  std::vector<int64_t> col_vec(1, 0);\n  IdArray row = aten::VecToIdArray(row_vec);\n  IdArray col = aten::VecToIdArray(col_vec);\n  GraphPtr g = ImmutableGraph::CreateFromCOO(1, row, col);\n  return g;\n}\n\n// create metagraph of two node types\ninline GraphPtr CreateUnitGraphMetaGraph2() {\n  // an edge 0->1\n  std::vector<int64_t> row_vec(1, 0);\n  std::vector<int64_t> col_vec(1, 1);\n  IdArray row = aten::VecToIdArray(row_vec);\n  IdArray col = aten::VecToIdArray(col_vec);\n  GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col);\n  return g;\n}\n\ninline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {\n  static GraphPtr mg1 = CreateUnitGraphMetaGraph1();\n  static GraphPtr mg2 = CreateUnitGraphMetaGraph2();\n  if (num_vtypes == 1)\n    return mg1;\n  else if (num_vtypes == 2)\n    return mg2;\n  else\n    LOG(FATAL) << \"Invalid number of vertex types. 
Must be 1 or 2.\";\n  return {};\n}\n\n};  // namespace\n\n//////////////////////////////////////////////////////////\n//\n// COO graph implementation\n//\n//////////////////////////////////////////////////////////\n\nclass UnitGraph::COO : public BaseHeteroGraph {\n public:\n  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src,\n      IdArray dst, bool row_sorted = false, bool col_sorted = false)\n      : BaseHeteroGraph(metagraph) {\n    CHECK(aten::IsValidIdArray(src));\n    CHECK(aten::IsValidIdArray(dst));\n    CHECK_EQ(src->shape[0], dst->shape[0])\n        << \"Input arrays should have the same length.\";\n    adj_ = aten::COOMatrix{num_src,     num_dst,    src,       dst,\n                           NullArray(), row_sorted, col_sorted};\n  }\n\n  COO(GraphPtr metagraph, const aten::COOMatrix& coo)\n      : BaseHeteroGraph(metagraph), adj_(coo) {\n    // Data index should not be inherited. Edges in COO format are always\n    // assigned ids from 0 to num_edges - 1.\n    CHECK(!COOHasData(coo)) << \"[BUG] COO should not contain data.\";\n    adj_.data = aten::NullArray();\n  }\n\n  COO() {\n    // set magic num_rows/num_cols to mark it as undefined\n    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is\n    // supported\n    adj_.num_rows = -1;\n    adj_.num_cols = -1;\n  };\n\n  bool defined() const { return (adj_.num_rows >= 0) && (adj_.num_cols >= 0); }\n\n  inline dgl_type_t SrcType() const { return 0; }\n\n  inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }\n\n  inline dgl_type_t EdgeType() const { return 0; }\n\n  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {\n    LOG(FATAL) << \"The method shouldn't be called for UnitGraph graph. 
\"\n               << \"The relation graph is simply this graph itself.\";\n    return {};\n  }\n\n  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {\n    LOG(FATAL) << \"UnitGraph graph is not mutable.\";\n  }\n\n  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {\n    LOG(FATAL) << \"UnitGraph graph is not mutable.\";\n  }\n\n  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {\n    LOG(FATAL) << \"UnitGraph graph is not mutable.\";\n  }\n\n  void Clear() override { LOG(FATAL) << \"UnitGraph graph is not mutable.\"; }\n\n  DGLDataType DataType() const override { return adj_.row->dtype; }\n\n  DGLContext Context() const override { return adj_.row->ctx; }\n\n  bool IsPinned() const override { return adj_.is_pinned; }\n\n  uint8_t NumBits() const override { return adj_.row->dtype.bits; }\n\n  COO AsNumBits(uint8_t bits) const {\n    if (NumBits() == bits) return *this;\n\n    COO ret(\n        meta_graph_, adj_.num_rows, adj_.num_cols,\n        aten::AsNumBits(adj_.row, bits), aten::AsNumBits(adj_.col, bits));\n    return ret;\n  }\n\n  COO CopyTo(const DGLContext& ctx) const {\n    if (Context() == ctx) return *this;\n    return COO(meta_graph_, adj_.CopyTo(ctx));\n  }\n\n  /**\n   * @brief Copy the adj_ to pinned memory.\n   * @return COOMatrix of the COO graph.\n   */\n  COO PinMemory() {\n    if (adj_.is_pinned) return *this;\n    return COO(meta_graph_, adj_.PinMemory());\n  }\n\n  /** @brief Pin the adj_: COOMatrix of the COO graph. */\n  void PinMemory_() { adj_.PinMemory_(); }\n\n  /** @brief Unpin the adj_: COOMatrix of the COO graph. */\n  void UnpinMemory_() { adj_.UnpinMemory_(); }\n\n  /** @brief Record stream for the adj_: COOMatrix of the COO graph. 
*/\n  void RecordStream(DGLStreamHandle stream) override {\n    adj_.RecordStream(stream);\n  }\n\n  bool IsMultigraph() const override { return aten::COOHasDuplicate(adj_); }\n\n  bool IsReadonly() const override { return true; }\n\n  uint64_t NumVertices(dgl_type_t vtype) const override {\n    if (vtype == SrcType()) {\n      return adj_.num_rows;\n    } else if (vtype == DstType()) {\n      return adj_.num_cols;\n    } else {\n      LOG(FATAL) << \"Invalid vertex type: \" << vtype;\n      return 0;\n    }\n  }\n\n  uint64_t NumEdges(dgl_type_t etype) const override {\n    return adj_.row->shape[0];\n  }\n\n  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {\n    return vid < NumVertices(vtype);\n  }\n\n  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {\n    LOG(FATAL) << \"Not enabled for COO graph\";\n    return {};\n  }\n\n  bool HasEdgeBetween(\n      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {\n    CHECK(HasVertex(SrcType(), src)) << \"Invalid src vertex id: \" << src;\n    CHECK(HasVertex(DstType(), dst)) << \"Invalid dst vertex id: \" << dst;\n    return aten::COOIsNonZero(adj_, src, dst);\n  }\n\n  BoolArray HasEdgesBetween(\n      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {\n    CHECK(aten::IsValidIdArray(src_ids)) << \"Invalid vertex id array.\";\n    CHECK(aten::IsValidIdArray(dst_ids)) << \"Invalid vertex id array.\";\n    return aten::COOIsNonZero(adj_, src_ids, dst_ids);\n  }\n\n  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {\n    CHECK(HasVertex(DstType(), dst)) << \"Invalid dst vertex id: \" << dst;\n    return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;\n  }\n\n  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {\n    CHECK(HasVertex(SrcType(), src)) << \"Invalid src vertex id: \" << src;\n    return aten::COOGetRowDataAndIndices(adj_, src).second;\n  }\n\n  IdArray EdgeId(dgl_type_t etype, dgl_id_t 
src, dgl_id_t dst) const override {\n    CHECK(HasVertex(SrcType(), src)) << \"Invalid src vertex id: \" << src;\n    CHECK(HasVertex(DstType(), dst)) << \"Invalid dst vertex id: \" << dst;\n    return aten::COOGetAllData(adj_, src, dst);\n  }\n\n  EdgeArray EdgeIdsAll(\n      dgl_type_t etype, IdArray src, IdArray dst) const override {\n    CHECK(aten::IsValidIdArray(src)) << \"Invalid vertex id array.\";\n    CHECK(aten::IsValidIdArray(dst)) << \"Invalid vertex id array.\";\n    const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);\n    return EdgeArray{arrs[0], arrs[1], arrs[2]};\n  }\n\n  IdArray EdgeIdsOne(\n      dgl_type_t etype, IdArray src, IdArray dst) const override {\n    return aten::COOGetData(adj_, src, dst);\n  }\n\n  std::pair<dgl_id_t, dgl_id_t> FindEdge(\n      dgl_type_t etype, dgl_id_t eid) const override {\n    CHECK(eid < NumEdges(etype)) << \"Invalid edge id: \" << eid;\n    const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);\n    const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);\n    return std::pair<dgl_id_t, dgl_id_t>(src, dst);\n  }\n\n  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {\n    CHECK(aten::IsValidIdArray(eids)) << \"Invalid edge id array\";\n    BUG_IF_FAIL(aten::IsNullArray(adj_.data))\n        << \"FindEdges requires the internal COO matrix not having EIDs.\";\n    return EdgeArray{\n        aten::IndexSelect(adj_.row, eids), aten::IndexSelect(adj_.col, eids),\n        eids};\n  }\n\n  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {\n    IdArray ret_src, ret_eid;\n    std::tie(ret_eid, ret_src) =\n        aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), vid);\n    IdArray ret_dst =\n        aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);\n    return EdgeArray{ret_src, ret_dst, ret_eid};\n  }\n\n  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {\n    CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id 
array.\";\n    auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);\n    auto row = aten::IndexSelect(vids, coosubmat.row);\n    return EdgeArray{coosubmat.col, row, coosubmat.data};\n  }\n\n  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {\n    IdArray ret_dst, ret_eid;\n    std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);\n    IdArray ret_src =\n        aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);\n    return EdgeArray{ret_src, ret_dst, ret_eid};\n  }\n\n  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {\n    CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n    auto coosubmat = aten::COOSliceRows(adj_, vids);\n    auto row = aten::IndexSelect(vids, coosubmat.row);\n    return EdgeArray{row, coosubmat.col, coosubmat.data};\n  }\n\n  EdgeArray Edges(\n      dgl_type_t etype, const std::string& order = \"\") const override {\n    CHECK(order.empty() || order == std::string(\"eid\"))\n        << \"COO only support Edges of order \\\"eid\\\", but got \\\"\" << order\n        << \"\\\".\";\n    IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());\n    return EdgeArray{adj_.row, adj_.col, rst_eid};\n  }\n\n  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {\n    CHECK(HasVertex(DstType(), vid)) << \"Invalid dst vertex id: \" << vid;\n    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);\n  }\n\n  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {\n    CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);\n  }\n\n  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {\n    CHECK(HasVertex(SrcType(), vid)) << \"Invalid src vertex id: \" << vid;\n    return aten::COOGetRowNNZ(adj_, vid);\n  }\n\n  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {\n    CHECK(aten::IsValidIdArray(vids)) << 
\"Invalid vertex id array.\";\n    return aten::COOGetRowNNZ(adj_, vids);\n  }\n\n  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {\n    LOG(INFO) << \"Not enabled for COO graph.\";\n    return {};\n  }\n\n  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {\n    LOG(INFO) << \"Not enabled for COO graph.\";\n    return {};\n  }\n\n  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {\n    LOG(INFO) << \"Not enabled for COO graph.\";\n    return {};\n  }\n\n  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {\n    LOG(INFO) << \"Not enabled for COO graph.\";\n    return {};\n  }\n\n  std::vector<IdArray> GetAdj(\n      dgl_type_t etype, bool transpose, const std::string& fmt) const override {\n    CHECK(fmt == \"coo\") << \"Not valid adj format request.\";\n    if (transpose) {\n      return {aten::HStack(adj_.col, adj_.row)};\n    } else {\n      return {aten::HStack(adj_.row, adj_.col)};\n    }\n  }\n\n  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override { return adj_; }\n\n  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {\n    LOG(FATAL) << \"Not enabled for COO graph\";\n    return aten::CSRMatrix();\n  }\n\n  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {\n    LOG(FATAL) << \"Not enabled for COO graph\";\n    return aten::CSRMatrix();\n  }\n\n  SparseFormat SelectFormat(\n      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {\n    LOG(FATAL) << \"Not enabled for COO graph\";\n    return SparseFormat::kCOO;\n  }\n\n  dgl_format_code_t GetAllowedFormats() const override {\n    LOG(FATAL) << \"Not enabled for COO graph\";\n    return 0;\n  }\n\n  dgl_format_code_t GetCreatedFormats() const override {\n    LOG(FATAL) << \"Not enabled for COO graph\";\n    return 0;\n  }\n\n  HeteroSubgraph VertexSubgraph(\n      const std::vector<IdArray>& vids) const override {\n    CHECK_EQ(vids.size(), NumVertexTypes())\n        << 
\"Number of vertex types mismatch\";\n    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];\n    CHECK(aten::IsValidIdArray(srcvids)) << \"Invalid vertex id array.\";\n    CHECK(aten::IsValidIdArray(dstvids)) << \"Invalid vertex id array.\";\n    HeteroSubgraph subg;\n    const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);\n    DGLContext ctx = aten::GetContextOf(vids);\n    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);\n    subg.graph = std::make_shared<COO>(\n        meta_graph(), submat.num_rows, submat.num_cols, submat.row, submat.col);\n    subg.induced_vertices = vids;\n    subg.induced_edges.emplace_back(submat.data);\n    return subg;\n  }\n\n  HeteroSubgraph EdgeSubgraph(\n      const std::vector<IdArray>& eids,\n      bool preserve_nodes = false) const override {\n    CHECK_EQ(eids.size(), 1) << \"Edge type number mismatch.\";\n    HeteroSubgraph subg;\n    if (!preserve_nodes) {\n      IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);\n      IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);\n      subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));\n      subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));\n      const auto new_nsrc = subg.induced_vertices[0]->shape[0];\n      const auto new_ndst = subg.induced_vertices[1]->shape[0];\n      subg.graph = std::make_shared<COO>(\n          meta_graph(), new_nsrc, new_ndst, new_src, new_dst);\n      subg.induced_edges = eids;\n    } else {\n      IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);\n      IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);\n      subg.induced_vertices.emplace_back(\n          aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));\n      subg.induced_vertices.emplace_back(\n          aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));\n      subg.graph = std::make_shared<COO>(\n          meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src,\n    
      new_dst);\n      subg.induced_edges = eids;\n    }\n    return subg;\n  }\n\n  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {\n    LOG(FATAL) << \"Not enabled for COO graph.\";\n    return nullptr;\n  }\n\n  aten::COOMatrix adj() const { return adj_; }\n\n  /**\n   * @brief Determines whether the graph is \"hypersparse\", i.e. having\n   * significantly more nodes than edges.\n   */\n  bool IsHypersparse() const {\n    return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&\n           (NumVertices(SrcType()) > 1000000);\n  }\n\n  bool Load(dmlc::Stream* fs) {\n    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();\n    CHECK(fs->Read(&meta_imgraph)) << \"Invalid meta graph\";\n    meta_graph_ = meta_imgraph;\n    CHECK(fs->Read(&adj_)) << \"Invalid adj matrix\";\n    return true;\n  }\n  void Save(dmlc::Stream* fs) const {\n    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());\n    fs->Write(meta_graph_ptr);\n    fs->Write(adj_);\n  }\n\n private:\n  friend class Serializer;\n\n  /** @brief internal adjacency matrix. 
Data array is empty */\n  aten::COOMatrix adj_;\n};\n\n//////////////////////////////////////////////////////////\n//\n// CSR graph implementation\n//\n//////////////////////////////////////////////////////////\n\n/** @brief CSR graph */\nclass UnitGraph::CSR : public BaseHeteroGraph {\n public:\n  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray indptr,\n      IdArray indices, IdArray edge_ids)\n      : BaseHeteroGraph(metagraph) {\n    CHECK(aten::IsValidIdArray(indptr));\n    CHECK(aten::IsValidIdArray(indices));\n    if (aten::IsValidIdArray(edge_ids))\n      CHECK(\n          (indices->shape[0] == edge_ids->shape[0]) ||\n          aten::IsNullArray(edge_ids))\n          << \"edge id arrays should have the same length as indices if not \"\n             \"empty\";\n    CHECK_EQ(num_src, indptr->shape[0] - 1)\n        << \"number of nodes do not match the length of indptr minus 1.\";\n\n    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};\n  }\n\n  CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)\n      : BaseHeteroGraph(metagraph), adj_(csr) {}\n\n  CSR() {\n    // set magic num_rows/num_cols to mark it as undefined\n    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is\n    // supported\n    adj_.num_rows = -1;\n    adj_.num_cols = -1;\n  };\n\n  bool defined() const { return (adj_.num_rows >= 0) || (adj_.num_cols >= 0); }\n\n  inline dgl_type_t SrcType() const { return 0; }\n\n  inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }\n\n  inline dgl_type_t EdgeType() const { return 0; }\n\n  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {\n    LOG(FATAL) << \"The method shouldn't be called for UnitGraph graph. 
\"\n               << \"The relation graph is simply this graph itself.\";\n    return {};\n  }\n\n  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {\n    LOG(FATAL) << \"UnitGraph graph is not mutable.\";\n  }\n\n  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {\n    LOG(FATAL) << \"UnitGraph graph is not mutable.\";\n  }\n\n  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {\n    LOG(FATAL) << \"UnitGraph graph is not mutable.\";\n  }\n\n  void Clear() override { LOG(FATAL) << \"UnitGraph graph is not mutable.\"; }\n\n  DGLDataType DataType() const override { return adj_.indices->dtype; }\n\n  DGLContext Context() const override { return adj_.indices->ctx; }\n\n  bool IsPinned() const override { return adj_.is_pinned; }\n\n  uint8_t NumBits() const override { return adj_.indices->dtype.bits; }\n\n  CSR AsNumBits(uint8_t bits) const {\n    if (NumBits() == bits) {\n      return *this;\n    } else {\n      CSR ret(\n          meta_graph_, adj_.num_rows, adj_.num_cols,\n          aten::AsNumBits(adj_.indptr, bits),\n          aten::AsNumBits(adj_.indices, bits),\n          aten::AsNumBits(adj_.data, bits));\n      return ret;\n    }\n  }\n\n  CSR CopyTo(const DGLContext& ctx) const {\n    if (Context() == ctx) {\n      return *this;\n    } else {\n      return CSR(meta_graph_, adj_.CopyTo(ctx));\n    }\n  }\n\n  /**\n   * @brief Copy the adj_ to pinned memory.\n   * @return CSRMatrix of the CSR graph.\n   */\n  CSR PinMemory() {\n    if (adj_.is_pinned) return *this;\n    return CSR(meta_graph_, adj_.PinMemory());\n  }\n\n  /** @brief Pin the adj_: CSRMatrix of the CSR graph. */\n  void PinMemory_() { adj_.PinMemory_(); }\n\n  /** @brief Unpin the adj_: CSRMatrix of the CSR graph. */\n  void UnpinMemory_() { adj_.UnpinMemory_(); }\n\n  /** @brief Record stream for the adj_: CSRMatrix of the CSR graph. 
*/\n  void RecordStream(DGLStreamHandle stream) override {\n    adj_.RecordStream(stream);\n  }\n\n  bool IsMultigraph() const override { return aten::CSRHasDuplicate(adj_); }\n\n  bool IsReadonly() const override { return true; }\n\n  uint64_t NumVertices(dgl_type_t vtype) const override {\n    if (vtype == SrcType()) {\n      return adj_.num_rows;\n    } else if (vtype == DstType()) {\n      return adj_.num_cols;\n    } else {\n      LOG(FATAL) << \"Invalid vertex type: \" << vtype;\n      return 0;\n    }\n  }\n\n  uint64_t NumEdges(dgl_type_t etype) const override {\n    return adj_.indices->shape[0];\n  }\n\n  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {\n    return vid < NumVertices(vtype);\n  }\n\n  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {\n    LOG(FATAL) << \"Not enabled for COO graph\";\n    return {};\n  }\n\n  bool HasEdgeBetween(\n      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {\n    CHECK(HasVertex(SrcType(), src)) << \"Invalid src vertex id: \" << src;\n    CHECK(HasVertex(DstType(), dst)) << \"Invalid dst vertex id: \" << dst;\n    return aten::CSRIsNonZero(adj_, src, dst);\n  }\n\n  BoolArray HasEdgesBetween(\n      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {\n    CHECK(aten::IsValidIdArray(src_ids)) << \"Invalid vertex id array.\";\n    CHECK(aten::IsValidIdArray(dst_ids)) << \"Invalid vertex id array.\";\n    return aten::CSRIsNonZero(adj_, src_ids, dst_ids);\n  }\n\n  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {\n    LOG(INFO) << \"Not enabled for CSR graph.\";\n    return {};\n  }\n\n  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {\n    CHECK(HasVertex(SrcType(), src)) << \"Invalid src vertex id: \" << src;\n    return aten::CSRGetRowColumnIndices(adj_, src);\n  }\n\n  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {\n    CHECK(HasVertex(SrcType(), src)) << \"Invalid src vertex 
id: \" << src;\n    CHECK(HasVertex(DstType(), dst)) << \"Invalid dst vertex id: \" << dst;\n    return aten::CSRGetAllData(adj_, src, dst);\n  }\n\n  EdgeArray EdgeIdsAll(\n      dgl_type_t etype, IdArray src, IdArray dst) const override {\n    CHECK(aten::IsValidIdArray(src)) << \"Invalid vertex id array.\";\n    CHECK(aten::IsValidIdArray(dst)) << \"Invalid vertex id array.\";\n    const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);\n    return EdgeArray{arrs[0], arrs[1], arrs[2]};\n  }\n\n  IdArray EdgeIdsOne(\n      dgl_type_t etype, IdArray src, IdArray dst) const override {\n    return aten::CSRGetData(adj_, src, dst);\n  }\n\n  std::pair<dgl_id_t, dgl_id_t> FindEdge(\n      dgl_type_t etype, dgl_id_t eid) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph.\";\n    return {};\n  }\n\n  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph.\";\n    return {};\n  }\n\n  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph.\";\n    return {};\n  }\n\n  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph.\";\n    return {};\n  }\n\n  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {\n    CHECK(HasVertex(SrcType(), vid)) << \"Invalid src vertex id: \" << vid;\n    IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);\n    IdArray ret_eid = aten::CSRGetRowData(adj_, vid);\n    IdArray ret_src =\n        aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);\n    return EdgeArray{ret_src, ret_dst, ret_eid};\n  }\n\n  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {\n    CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n    auto csrsubmat = aten::CSRSliceRows(adj_, vids);\n    auto coosubmat = aten::CSRToCOO(csrsubmat, false);\n    // Note that the row id in the csr submat is relabled, so\n    // 
we need to recover it using an index select.\n    auto row = aten::IndexSelect(vids, coosubmat.row);\n    return EdgeArray{row, coosubmat.col, coosubmat.data};\n  }\n\n  EdgeArray Edges(\n      dgl_type_t etype, const std::string& order = \"\") const override {\n    CHECK(order.empty() || order == std::string(\"srcdst\"))\n        << \"CSR only support Edges of order \\\"srcdst\\\",\"\n        << \" but got \\\"\" << order << \"\\\".\";\n    auto coo = aten::CSRToCOO(adj_, false);\n    if (order == std::string(\"srcdst\")) {\n      // make sure the coo is sorted if an order is requested\n      coo = aten::COOSort(coo, true);\n    }\n    return EdgeArray{coo.row, coo.col, coo.data};\n  }\n\n  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph.\";\n    return {};\n  }\n\n  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph.\";\n    return {};\n  }\n\n  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {\n    CHECK(HasVertex(SrcType(), vid)) << \"Invalid src vertex id: \" << vid;\n    return aten::CSRGetRowNNZ(adj_, vid);\n  }\n\n  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {\n    CHECK(aten::IsValidIdArray(vids)) << \"Invalid vertex id array.\";\n    return aten::CSRGetRowNNZ(adj_, vids);\n  }\n\n  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {\n    // TODO(minjie): This still assumes the data type and device context\n    //   of this graph. 
Should fix later.\n    CHECK_EQ(NumBits(), 64);\n    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);\n    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);\n    const dgl_id_t start = indptr_data[vid];\n    const dgl_id_t end = indptr_data[vid + 1];\n    return DGLIdIters(indices_data + start, indices_data + end);\n  }\n\n  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {\n    // TODO(minjie): This still assumes the data type and device context\n    //   of this graph. Should fix later.\n    const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);\n    const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);\n    const int32_t start = indptr_data[vid];\n    const int32_t end = indptr_data[vid + 1];\n    return DGLIdIters32(indices_data + start, indices_data + end);\n  }\n\n  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {\n    // TODO(minjie): This still assumes the data type and device context\n    //   of this graph. 
Should fix later.\n    CHECK_EQ(NumBits(), 64);\n    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);\n    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);\n    const dgl_id_t start = indptr_data[vid];\n    const dgl_id_t end = indptr_data[vid + 1];\n    return DGLIdIters(eid_data + start, eid_data + end);\n  }\n\n  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph.\";\n    return {};\n  }\n\n  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph.\";\n    return {};\n  }\n\n  std::vector<IdArray> GetAdj(\n      dgl_type_t etype, bool transpose, const std::string& fmt) const override {\n    CHECK(!transpose && fmt == \"csr\") << \"Not valid adj format request.\";\n    return {adj_.indptr, adj_.indices, adj_.data};\n  }\n\n  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph\";\n    return aten::COOMatrix();\n  }\n\n  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph\";\n    return aten::CSRMatrix();\n  }\n\n  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override { return adj_; }\n\n  SparseFormat SelectFormat(\n      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph\";\n    return SparseFormat::kCSR;\n  }\n\n  dgl_format_code_t GetAllowedFormats() const override {\n    LOG(FATAL) << \"Not enabled for COO graph\";\n    return 0;\n  }\n\n  dgl_format_code_t GetCreatedFormats() const override {\n    LOG(FATAL) << \"Not enabled for CSR graph\";\n    return 0;\n  }\n\n  HeteroSubgraph VertexSubgraph(\n      const std::vector<IdArray>& vids) const override {\n    CHECK_EQ(vids.size(), NumVertexTypes())\n        << \"Number of vertex types mismatch\";\n    auto srcvids = vids[SrcType()], dstvids = 
vids[DstType()];\n    CHECK(aten::IsValidIdArray(srcvids)) << \"Invalid vertex id array.\";\n    CHECK(aten::IsValidIdArray(dstvids)) << \"Invalid vertex id array.\";\n    HeteroSubgraph subg;\n    const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);\n    DGLContext ctx = aten::GetContextOf(vids);\n    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);\n    subg.graph = std::make_shared<CSR>(\n        meta_graph(), submat.num_rows, submat.num_cols, submat.indptr,\n        submat.indices, sub_eids);\n    subg.induced_vertices = vids;\n    subg.induced_edges.emplace_back(submat.data);\n    return subg;\n  }\n\n  HeteroSubgraph EdgeSubgraph(\n      const std::vector<IdArray>& eids,\n      bool preserve_nodes = false) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph.\";\n    return {};\n  }\n\n  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {\n    LOG(FATAL) << \"Not enabled for CSR graph.\";\n    return nullptr;\n  }\n\n  aten::CSRMatrix adj() const { return adj_; }\n\n  bool Load(dmlc::Stream* fs) {\n    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();\n    CHECK(fs->Read(&meta_imgraph)) << \"Invalid meta graph\";\n    meta_graph_ = meta_imgraph;\n    CHECK(fs->Read(&adj_)) << \"Invalid adj matrix\";\n    return true;\n  }\n  void Save(dmlc::Stream* fs) const {\n    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());\n    fs->Write(meta_graph_ptr);\n    fs->Write(adj_);\n  }\n\n private:\n  friend class Serializer;\n\n  /** @brief internal adjacency matrix. 
Data array stores edge ids */\n  aten::CSRMatrix adj_;\n};\n\n//////////////////////////////////////////////////////////\n//\n// unit graph implementation\n//\n//////////////////////////////////////////////////////////\n\nDGLDataType UnitGraph::DataType() const { return GetAny()->DataType(); }\n\nDGLContext UnitGraph::Context() const { return GetAny()->Context(); }\n\nbool UnitGraph::IsPinned() const { return GetAny()->IsPinned(); }\n\nuint8_t UnitGraph::NumBits() const { return GetAny()->NumBits(); }\n\nbool UnitGraph::IsMultigraph() const {\n  const SparseFormat fmt = SelectFormat(CSC_CODE);\n  const auto ptr = GetFormat(fmt);\n  return ptr->IsMultigraph();\n}\n\nuint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {\n  const SparseFormat fmt = SelectFormat(ALL_CODE);\n  const auto ptr = GetFormat(fmt);\n  // TODO(BarclayII): we have a lot of special handling for CSC.\n  // Need to have a UnitGraph::CSC backend instead.\n  if (fmt == SparseFormat::kCSC)\n    vtype = (vtype == SrcType()) ? DstType() : SrcType();\n  return ptr->NumVertices(vtype);\n}\n\nuint64_t UnitGraph::NumEdges(dgl_type_t etype) const {\n  return GetAny()->NumEdges(etype);\n}\n\nbool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {\n  const SparseFormat fmt = SelectFormat(ALL_CODE);\n  const auto ptr = GetFormat(fmt);\n  if (fmt == SparseFormat::kCSC)\n    vtype = (vtype == SrcType()) ? 
DstType() : SrcType();\n  return ptr->HasVertex(vtype, vid);\n}\n\nBoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {\n  CHECK(aten::IsValidIdArray(vids)) << \"Invalid id array input\";\n  return aten::LT(vids, NumVertices(vtype));\n}\n\nbool UnitGraph::HasEdgeBetween(\n    dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {\n  const SparseFormat fmt = SelectFormat(CSC_CODE);\n  const auto ptr = GetFormat(fmt);\n  if (fmt == SparseFormat::kCSC)\n    return ptr->HasEdgeBetween(etype, dst, src);\n  else\n    return ptr->HasEdgeBetween(etype, src, dst);\n}\n\nBoolArray UnitGraph::HasEdgesBetween(\n    dgl_type_t etype, IdArray src, IdArray dst) const {\n  const SparseFormat fmt = SelectFormat(CSC_CODE);\n  const auto ptr = GetFormat(fmt);\n  if (fmt == SparseFormat::kCSC)\n    return ptr->HasEdgesBetween(etype, dst, src);\n  else\n    return ptr->HasEdgesBetween(etype, src, dst);\n}\n\nIdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {\n  const SparseFormat fmt = SelectFormat(CSC_CODE);\n  const auto ptr = GetFormat(fmt);\n  if (fmt == SparseFormat::kCSC)\n    return ptr->Successors(etype, dst);\n  else\n    return ptr->Predecessors(etype, dst);\n}\n\nIdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {\n  const SparseFormat fmt = SelectFormat(CSR_CODE);\n  const auto ptr = GetFormat(fmt);\n  return ptr->Successors(etype, src);\n}\n\nIdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {\n  const SparseFormat fmt = SelectFormat(CSR_CODE);\n  const auto ptr = GetFormat(fmt);\n  if (fmt == SparseFormat::kCSC)\n    return ptr->EdgeId(etype, dst, src);\n  else\n    return ptr->EdgeId(etype, src, dst);\n}\n\nEdgeArray UnitGraph::EdgeIdsAll(\n    dgl_type_t etype, IdArray src, IdArray dst) const {\n  const SparseFormat fmt = SelectFormat(CSR_CODE);\n  const auto ptr = GetFormat(fmt);\n  if (fmt == SparseFormat::kCSC) {\n    EdgeArray edges = ptr->EdgeIdsAll(etype, dst, src);\n    return 
EdgeArray{edges.dst, edges.src, edges.id};\n  } else {\n    return ptr->EdgeIdsAll(etype, src, dst);\n  }\n}\n\nIdArray UnitGraph::EdgeIdsOne(\n    dgl_type_t etype, IdArray src, IdArray dst) const {\n  const SparseFormat fmt = SelectFormat(CSR_CODE);\n  const auto ptr = GetFormat(fmt);\n  if (fmt == SparseFormat::kCSC) {\n    return ptr->EdgeIdsOne(etype, dst, src);\n  } else {\n    return ptr->EdgeIdsOne(etype, src, dst);\n  }\n}\n\nstd::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(\n    dgl_type_t etype, dgl_id_t eid) const {\n  const SparseFormat fmt = SelectFormat(COO_CODE);\n  const auto ptr = GetFormat(fmt);\n  return ptr->FindEdge(etype, eid);\n}\n\nEdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {\n  const SparseFormat fmt = SelectFormat(COO_CODE);\n  const auto ptr = GetFormat(fmt);\n  return ptr->FindEdges(etype, eids);\n}\n\nEdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {\n  const SparseFormat fmt = SelectFormat(CSC_CODE);\n  const auto ptr = GetFormat(fmt);\n  if (fmt == SparseFormat::kCSC) {\n    const EdgeArray& ret = ptr->OutEdges(etype, vid);\n    return {ret.dst, ret.src, ret.id};\n  } else {\n    return ptr->InEdges(etype, vid);\n  }\n}\n\nEdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {\n  const SparseFormat fmt = SelectFormat(CSC_CODE);\n  const auto ptr = GetFormat(fmt);\n  if (fmt == SparseFormat::kCSC) {\n    const EdgeArray& ret = ptr->OutEdges(etype, vids);\n    return {ret.dst, ret.src, ret.id};\n  } else {\n    return ptr->InEdges(etype, vids);\n  }\n}\n\nEdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {\n  const SparseFormat fmt = SelectFormat(CSR_CODE);\n  const auto ptr = GetFormat(fmt);\n  return ptr->OutEdges(etype, vid);\n}\n\nEdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {\n  const SparseFormat fmt = SelectFormat(CSR_CODE);\n  const auto ptr = GetFormat(fmt);\n  return ptr->OutEdges(etype, vids);\n}\n\nEdgeArray 
UnitGraph::Edges(dgl_type_t etype, const std::string& order) const {\n  SparseFormat fmt;\n  if (order == std::string(\"eid\")) {\n    fmt = SelectFormat(COO_CODE);\n  } else if (order.empty()) {\n    // arbitrary order\n    fmt = SelectFormat(ALL_CODE);\n  } else if (order == std::string(\"srcdst\")) {\n    fmt = SelectFormat(CSR_CODE);\n  } else {\n    LOG(FATAL) << \"Unsupported order request: \" << order;\n    return {};\n  }\n\n  const auto& edges = GetFormat(fmt)->Edges(etype, order);\n  if (fmt == SparseFormat::kCSC)\n    return EdgeArray{edges.dst, edges.src, edges.id};\n  else\n    return edges;\n}\n\nuint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {\n  SparseFormat fmt = SelectFormat(CSC_CODE);\n  const auto ptr = GetFormat(fmt);\n  CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO)\n      << \"In degree cannot be computed as neither CSC nor COO format is \"\n         \"allowed for this graph. Please enable one of them at least.\";\n  return fmt == SparseFormat::kCSC ? ptr->OutDegree(etype, vid)\n                                   : ptr->InDegree(etype, vid);\n}\n\nDegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {\n  SparseFormat fmt = SelectFormat(CSC_CODE);\n  const auto ptr = GetFormat(fmt);\n  CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO)\n      << \"In degree cannot be computed as neither CSC nor COO format is \"\n         \"allowed for this graph. Please enable one of them at least.\";\n  return fmt == SparseFormat::kCSC ? ptr->OutDegrees(etype, vids)\n                                   : ptr->InDegrees(etype, vids);\n}\n\nuint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {\n  SparseFormat fmt = SelectFormat(CSR_CODE);\n  const auto ptr = GetFormat(fmt);\n  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)\n      << \"Out degree cannot be computed as neither CSR nor COO format is \"\n         \"allowed for this graph. 
Please enable one of them at least.\";\n  return ptr->OutDegree(etype, vid);\n}\n\nDegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {\n  SparseFormat fmt = SelectFormat(CSR_CODE);\n  const auto ptr = GetFormat(fmt);\n  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)\n      << \"Out degree cannot be computed as neither CSR nor COO format is \"\n         \"allowed for this graph. Please enable one of them at least.\";\n  return ptr->OutDegrees(etype, vids);\n}\n\nDGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {\n  SparseFormat fmt = SelectFormat(CSR_CODE);\n  const auto ptr = GetFormat(fmt);\n  return ptr->SuccVec(etype, vid);\n}\n\nDGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {\n  SparseFormat fmt = SelectFormat(CSR_CODE);\n  const auto ptr = std::dynamic_pointer_cast<CSR>(GetFormat(fmt));\n  CHECK_NOTNULL(ptr);\n  return ptr->SuccVec32(etype, vid);\n}\n\nDGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {\n  SparseFormat fmt = SelectFormat(CSR_CODE);\n  const auto ptr = GetFormat(fmt);\n  return ptr->OutEdgeVec(etype, vid);\n}\n\nDGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {\n  SparseFormat fmt = SelectFormat(CSC_CODE);\n  const auto ptr = GetFormat(fmt);\n  if (fmt == SparseFormat::kCSC)\n    return ptr->SuccVec(etype, vid);\n  else\n    return ptr->PredVec(etype, vid);\n}\n\nDGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {\n  SparseFormat fmt = SelectFormat(CSC_CODE);\n  const auto ptr = GetFormat(fmt);\n  if (fmt == SparseFormat::kCSC)\n    return ptr->OutEdgeVec(etype, vid);\n  else\n    return ptr->InEdgeVec(etype, vid);\n}\n\nstd::vector<IdArray> UnitGraph::GetAdj(\n    dgl_type_t etype, bool transpose, const std::string& fmt) const {\n  // TODO(minjie): Our current semantics of adjacency matrix is row for dst\n  // nodes and col for src nodes. 
Therefore, we need to flip the transpose flag.\n  // For example,\n  //   transpose=False is equal to in edge CSR. We have this behavior because\n  //   previously we use framework's SPMM and we don't cache reverse adj. This\n  //   is not intuitive and also not consistent with networkx's\n  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should\n  //   change the behavior and make row for src and col for dst.\n  if (fmt == std::string(\"csr\")) {\n    return !transpose ? GetOutCSR()->GetAdj(etype, false, \"csr\")\n                      : GetInCSR()->GetAdj(etype, false, \"csr\");\n  } else if (fmt == std::string(\"coo\")) {\n    return GetCOO()->GetAdj(etype, transpose, fmt);\n  } else {\n    LOG(FATAL) << \"unsupported adjacency matrix format: \" << fmt;\n    return {};\n  }\n}\n\nHeteroSubgraph UnitGraph::VertexSubgraph(\n    const std::vector<IdArray>& vids) const {\n  // We prefer to generate a subgraph from out-csr.\n  SparseFormat fmt = SelectFormat(CSR_CODE);\n  HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids);\n  HeteroSubgraph ret;\n\n  CSRPtr subcsr = nullptr;\n  CSRPtr subcsc = nullptr;\n  COOPtr subcoo = nullptr;\n  switch (fmt) {\n    case SparseFormat::kCSR:\n      subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);\n      break;\n    case SparseFormat::kCSC:\n      subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);\n      break;\n    case SparseFormat::kCOO:\n      subcoo = std::dynamic_pointer_cast<COO>(sg.graph);\n      break;\n    default:\n      LOG(FATAL) << \"[BUG] unsupported format \" << static_cast<int>(fmt);\n      return ret;\n  }\n\n  ret.graph =\n      HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));\n  ret.induced_vertices = std::move(sg.induced_vertices);\n  ret.induced_edges = std::move(sg.induced_edges);\n  return ret;\n}\n\nHeteroSubgraph UnitGraph::EdgeSubgraph(\n    const std::vector<IdArray>& eids, bool preserve_nodes) const {\n  SparseFormat fmt = SelectFormat(COO_CODE);\n  
auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);\n  HeteroSubgraph ret;\n\n  CSRPtr subcsr = nullptr;\n  CSRPtr subcsc = nullptr;\n  COOPtr subcoo = nullptr;\n  switch (fmt) {\n    case SparseFormat::kCSR:\n      subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);\n      break;\n    case SparseFormat::kCSC:\n      subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);\n      break;\n    case SparseFormat::kCOO:\n      subcoo = std::dynamic_pointer_cast<COO>(sg.graph);\n      break;\n    default:\n      LOG(FATAL) << \"[BUG] unsupported format \" << static_cast<int>(fmt);\n      return ret;\n  }\n\n  ret.graph =\n      HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));\n  ret.induced_vertices = std::move(sg.induced_vertices);\n  ret.induced_edges = std::move(sg.induced_edges);\n  return ret;\n}\n\nHeteroGraphPtr UnitGraph::CreateFromCOO(\n    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,\n    IdArray col, bool row_sorted, bool col_sorted, dgl_format_code_t formats) {\n  CHECK(num_vtypes == 1 || num_vtypes == 2);\n  if (num_vtypes == 1) CHECK_EQ(num_src, num_dst);\n  auto mg = CreateUnitGraphMetaGraph(num_vtypes);\n  COOPtr coo(new COO(mg, num_src, num_dst, row, col, row_sorted, col_sorted));\n\n  return HeteroGraphPtr(new UnitGraph(mg, nullptr, nullptr, coo, formats));\n}\n\nHeteroGraphPtr UnitGraph::CreateFromCOO(\n    int64_t num_vtypes, const aten::COOMatrix& mat, dgl_format_code_t formats) {\n  CHECK(num_vtypes == 1 || num_vtypes == 2);\n  if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols);\n  auto mg = CreateUnitGraphMetaGraph(num_vtypes);\n  COOPtr coo(new COO(mg, mat));\n\n  return HeteroGraphPtr(new UnitGraph(mg, nullptr, nullptr, coo, formats));\n}\n\nHeteroGraphPtr UnitGraph::CreateFromCSR(\n    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,\n    IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {\n  CHECK(num_vtypes == 1 || num_vtypes == 2);\n  if (num_vtypes == 1) 
CHECK_EQ(num_src, num_dst);\n  auto mg = CreateUnitGraphMetaGraph(num_vtypes);\n  CSRPtr csr(new CSR(mg, num_src, num_dst, indptr, indices, edge_ids));\n  return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats));\n}\n\nHeteroGraphPtr UnitGraph::CreateFromCSR(\n    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {\n  CHECK(num_vtypes == 1 || num_vtypes == 2);\n  if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols);\n  auto mg = CreateUnitGraphMetaGraph(num_vtypes);\n  CSRPtr csr(new CSR(mg, mat));\n  return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats));\n}\n\nHeteroGraphPtr UnitGraph::CreateFromCSRAndCOO(\n    int64_t num_vtypes, const aten::CSRMatrix& csr, const aten::COOMatrix& coo,\n    dgl_format_code_t formats) {\n  CHECK(num_vtypes == 1 || num_vtypes == 2);\n  CHECK_EQ(coo.num_rows, csr.num_rows);\n  CHECK_EQ(coo.num_cols, csr.num_cols);\n  if (num_vtypes == 1) {\n    CHECK_EQ(csr.num_rows, csr.num_cols);\n  }\n  auto mg = CreateUnitGraphMetaGraph(num_vtypes);\n  CSRPtr csrPtr(new CSR(mg, csr));\n  COOPtr cooPtr(new COO(mg, coo));\n  return HeteroGraphPtr(new UnitGraph(mg, nullptr, csrPtr, cooPtr, formats));\n}\n\nHeteroGraphPtr UnitGraph::CreateFromCSC(\n    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,\n    IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {\n  CHECK(num_vtypes == 1 || num_vtypes == 2);\n  if (num_vtypes == 1) CHECK_EQ(num_src, num_dst);\n  auto mg = CreateUnitGraphMetaGraph(num_vtypes);\n  CSRPtr csc(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids));\n  return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats));\n}\n\nHeteroGraphPtr UnitGraph::CreateFromCSC(\n    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {\n  CHECK(num_vtypes == 1 || num_vtypes == 2);\n  if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols);\n  auto mg = CreateUnitGraphMetaGraph(num_vtypes);\n  CSRPtr csc(new 
CSR(mg, mat));\n  return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats));\n}\n\nHeteroGraphPtr UnitGraph::CreateFromCSCAndCOO(\n    int64_t num_vtypes, const aten::CSRMatrix& csc, const aten::COOMatrix& coo,\n    dgl_format_code_t formats) {\n  CHECK(num_vtypes == 1 || num_vtypes == 2);\n  CHECK_EQ(coo.num_rows, csc.num_cols);\n  CHECK_EQ(coo.num_cols, csc.num_rows);\n  if (num_vtypes == 1) {\n    CHECK_EQ(csc.num_rows, csc.num_cols);\n  }\n  auto mg = CreateUnitGraphMetaGraph(num_vtypes);\n  CSRPtr cscPtr(new CSR(mg, csc));\n  COOPtr cooPtr(new COO(mg, coo));\n  return HeteroGraphPtr(new UnitGraph(mg, cscPtr, nullptr, cooPtr, formats));\n}\n\nHeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {\n  if (g->NumBits() == bits) {\n    return g;\n  } else {\n    auto bg = std::dynamic_pointer_cast<UnitGraph>(g);\n    CHECK_NOTNULL(bg);\n    CSRPtr new_incsr = (bg->in_csr_->defined())\n                           ? CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits)))\n                           : nullptr;\n    CSRPtr new_outcsr = (bg->out_csr_->defined())\n                            ? CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits)))\n                            : nullptr;\n    COOPtr new_coo = (bg->coo_->defined())\n                         ? COOPtr(new COO(bg->coo_->AsNumBits(bits)))\n                         : nullptr;\n    return HeteroGraphPtr(new UnitGraph(\n        g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));\n  }\n}\n\nHeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {\n  if (ctx == g->Context()) {\n    return g;\n  } else {\n    auto bg = std::dynamic_pointer_cast<UnitGraph>(g);\n    CHECK_NOTNULL(bg);\n    CSRPtr new_incsr = (bg->in_csr_->defined())\n                           ? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx)))\n                           : nullptr;\n    CSRPtr new_outcsr = (bg->out_csr_->defined())\n                            ? 
CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx)))\n                            : nullptr;\n    COOPtr new_coo = (bg->coo_->defined())\n                         ? COOPtr(new COO(bg->coo_->CopyTo(ctx)))\n                         : nullptr;\n    return HeteroGraphPtr(new UnitGraph(\n        g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));\n  }\n}\n\nHeteroGraphPtr UnitGraph::PinMemory() {\n  CSRPtr pinned_in_csr, pinned_out_csr;\n  COOPtr pinned_coo;\n  if (this->in_csr_->defined() && this->in_csr_->IsPinned()) {\n    pinned_in_csr = this->in_csr_;\n  } else if (this->in_csr_->defined()) {\n    pinned_in_csr = CSRPtr(new CSR(this->in_csr_->PinMemory()));\n  } else {\n    pinned_in_csr = nullptr;\n  }\n\n  if (this->out_csr_->defined() && this->out_csr_->IsPinned()) {\n    pinned_out_csr = this->out_csr_;\n  } else if (this->out_csr_->defined()) {\n    pinned_out_csr = CSRPtr(new CSR(this->out_csr_->PinMemory()));\n  } else {\n    pinned_out_csr = nullptr;\n  }\n\n  if (this->coo_->defined() && this->coo_->IsPinned()) {\n    pinned_coo = this->coo_;\n  } else if (this->coo_->defined()) {\n    pinned_coo = COOPtr(new COO(this->coo_->PinMemory()));\n  } else {\n    pinned_coo = nullptr;\n  }\n\n  return HeteroGraphPtr(new UnitGraph(\n      meta_graph(), pinned_in_csr, pinned_out_csr, pinned_coo, this->formats_));\n}\n\nvoid UnitGraph::PinMemory_() {\n  if (this->in_csr_->defined()) this->in_csr_->PinMemory_();\n  if (this->out_csr_->defined()) this->out_csr_->PinMemory_();\n  if (this->coo_->defined()) this->coo_->PinMemory_();\n}\n\nvoid UnitGraph::UnpinMemory_() {\n  if (this->in_csr_->defined()) this->in_csr_->UnpinMemory_();\n  if (this->out_csr_->defined()) this->out_csr_->UnpinMemory_();\n  if (this->coo_->defined()) this->coo_->UnpinMemory_();\n}\n\nvoid UnitGraph::RecordStream(DGLStreamHandle stream) {\n  if (this->in_csr_->defined()) this->in_csr_->RecordStream(stream);\n  if (this->out_csr_->defined()) this->out_csr_->RecordStream(stream);\n  if 
(this->coo_->defined()) this->coo_->RecordStream(stream);\n  this->recorded_streams.push_back(stream);\n}\n\nvoid UnitGraph::InvalidateCSR() { this->out_csr_ = CSRPtr(new CSR()); }\n\nvoid UnitGraph::InvalidateCSC() { this->in_csr_ = CSRPtr(new CSR()); }\n\nvoid UnitGraph::InvalidateCOO() { this->coo_ = COOPtr(new COO()); }\n\nUnitGraph::UnitGraph(\n    GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,\n    dgl_format_code_t formats)\n    : BaseHeteroGraph(metagraph),\n      in_csr_(in_csr),\n      out_csr_(out_csr),\n      coo_(coo) {\n  if (!in_csr_) {\n    in_csr_ = CSRPtr(new CSR());\n  }\n  if (!out_csr_) {\n    out_csr_ = CSRPtr(new CSR());\n  }\n  if (!coo_) {\n    coo_ = COOPtr(new COO());\n  }\n  formats_ = formats;\n  dgl_format_code_t created = GetCreatedFormats();\n  if ((formats | created) != formats)\n    LOG(FATAL) << \"Graph created from formats: \" << CodeToStr(created)\n               << \", which is not compatible with available formats: \"\n               << CodeToStr(formats);\n  CHECK(GetAny()) << \"At least one graph structure should exist.\";\n}\n\nHeteroGraphPtr UnitGraph::CreateUnitGraphFrom(\n    int num_vtypes, const aten::CSRMatrix& in_csr,\n    const aten::CSRMatrix& out_csr, const aten::COOMatrix& coo, bool has_in_csr,\n    bool has_out_csr, bool has_coo, dgl_format_code_t formats) {\n  auto mg = CreateUnitGraphMetaGraph(num_vtypes);\n\n  CSRPtr in_csr_ptr = nullptr;\n  CSRPtr out_csr_ptr = nullptr;\n  COOPtr coo_ptr = nullptr;\n\n  if (has_in_csr)\n    in_csr_ptr = CSRPtr(new CSR(mg, in_csr));\n  else\n    in_csr_ptr = CSRPtr(new CSR());\n  if (has_out_csr)\n    out_csr_ptr = CSRPtr(new CSR(mg, out_csr));\n  else\n    out_csr_ptr = CSRPtr(new CSR());\n  if (has_coo)\n    coo_ptr = COOPtr(new COO(mg, coo));\n  else\n    coo_ptr = COOPtr(new COO());\n\n  return HeteroGraphPtr(\n      new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, formats));\n}\n\nUnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {\n  if 
(inplace)\n    if (!(formats_ & CSC_CODE))\n      LOG(FATAL) << \"The graph have restricted sparse format \"\n                 << CodeToStr(formats_) << \", cannot create CSC matrix.\";\n  CSRPtr ret = in_csr_;\n  // Prefers converting from COO since it is parallelized.\n  // TODO(BarclayII): need benchmarking.\n  if (!in_csr_->defined()) {\n    if (coo_->defined()) {\n      const auto& newadj = aten::COOToCSR(aten::COOTranspose(coo_->adj()));\n\n      if (inplace)\n        *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);\n      else\n        ret = std::make_shared<CSR>(meta_graph(), newadj);\n    } else {\n      CHECK(out_csr_->defined()) << \"None of CSR, COO exist\";\n      const auto& newadj = aten::CSRTranspose(out_csr_->adj());\n\n      if (inplace)\n        *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);\n      else\n        ret = std::make_shared<CSR>(meta_graph(), newadj);\n    }\n    if (inplace) {\n      if (IsPinned()) in_csr_->PinMemory_();\n      for (auto stream : recorded_streams) in_csr_->RecordStream(stream);\n    }\n  }\n  return ret;\n}\n\n/** @brief Return out csr. 
If not exist, transpose the other one.*/\nUnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {\n  if (inplace)\n    if (!(formats_ & CSR_CODE))\n      LOG(FATAL) << \"The graph have restricted sparse format \"\n                 << CodeToStr(formats_) << \", cannot create CSR matrix.\";\n  CSRPtr ret = out_csr_;\n  // Prefers converting from COO since it is parallelized.\n  // TODO(BarclayII): need benchmarking.\n  if (!out_csr_->defined()) {\n    if (coo_->defined()) {\n      const auto& newadj = aten::COOToCSR(coo_->adj());\n\n      if (inplace)\n        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);\n      else\n        ret = std::make_shared<CSR>(meta_graph(), newadj);\n    } else {\n      CHECK(in_csr_->defined()) << \"None of CSR, COO exist\";\n      const auto& newadj = aten::CSRTranspose(in_csr_->adj());\n\n      if (inplace)\n        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);\n      else\n        ret = std::make_shared<CSR>(meta_graph(), newadj);\n    }\n    if (inplace) {\n      if (IsPinned()) out_csr_->PinMemory_();\n      for (auto stream : recorded_streams) out_csr_->RecordStream(stream);\n    }\n  }\n  return ret;\n}\n\n/** @brief Return coo. 
If not exist, create from csr.*/\nUnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {\n  if (inplace)\n    if (!(formats_ & COO_CODE))\n      LOG(FATAL) << \"The graph have restricted sparse format \"\n                 << CodeToStr(formats_) << \", cannot create COO matrix.\";\n  COOPtr ret = coo_;\n  if (!coo_->defined()) {\n    if (in_csr_->defined()) {\n      const auto& newadj =\n          aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));\n\n      if (inplace)\n        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);\n      else\n        ret = std::make_shared<COO>(meta_graph(), newadj);\n    } else {\n      CHECK(out_csr_->defined()) << \"Both CSR are missing.\";\n      const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);\n\n      if (inplace)\n        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);\n      else\n        ret = std::make_shared<COO>(meta_graph(), newadj);\n    }\n    if (inplace) {\n      if (IsPinned()) coo_->PinMemory_();\n      for (auto stream : recorded_streams) coo_->RecordStream(stream);\n    }\n  }\n  return ret;\n}\n\naten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {\n  return GetInCSR()->adj();\n}\n\naten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const {\n  return GetOutCSR()->adj();\n}\n\naten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {\n  return GetCOO()->adj();\n}\n\nHeteroGraphPtr UnitGraph::GetAny() const {\n  if (in_csr_->defined()) {\n    return in_csr_;\n  } else if (out_csr_->defined()) {\n    return out_csr_;\n  } else {\n    return coo_;\n  }\n}\n\ndgl_format_code_t UnitGraph::GetCreatedFormats() const {\n  dgl_format_code_t ret = 0;\n  if (in_csr_->defined()) ret |= CSC_CODE;\n  if (out_csr_->defined()) ret |= CSR_CODE;\n  if (coo_->defined()) ret |= COO_CODE;\n  return ret;\n}\n\ndgl_format_code_t UnitGraph::GetAllowedFormats() const { return formats_; }\n\nHeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {\n 
 switch (format) {\n    case SparseFormat::kCSR:\n      return GetOutCSR();\n    case SparseFormat::kCSC:\n      return GetInCSR();\n    default:\n      return GetCOO();\n  }\n}\n\nHeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {\n  // Get the created formats.\n  auto created_formats = GetCreatedFormats();\n  // Get the intersection of formats and created_formats.\n  auto intersection = formats & created_formats;\n\n  // If the intersection of formats and created_formats is not empty.\n  // The format(s) in the intersection will be retained.\n  if (intersection != 0) {\n    COOPtr coo_ptr = COO_CODE & intersection ? GetCOO(false) : nullptr;\n    CSRPtr in_csr_ptr = CSC_CODE & intersection ? GetInCSR(false) : nullptr;\n    CSRPtr out_csr_ptr = CSR_CODE & intersection ? GetOutCSR(false) : nullptr;\n\n    return HeteroGraphPtr(\n        new UnitGraph(meta_graph_, in_csr_ptr, out_csr_ptr, coo_ptr, formats));\n  }\n\n  // If the intersection of formats and created_formats is empty.\n  // Create a format in the order of COO -> CSR -> CSC.\n  int64_t num_vtypes = NumVertexTypes();\n  if (COO_CODE & formats)\n    return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);\n  if (CSR_CODE & formats)\n    return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats);\n  return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);\n}\n\nSparseFormat UnitGraph::SelectFormat(\n    dgl_format_code_t preferred_formats) const {\n  dgl_format_code_t common = preferred_formats & formats_;\n  dgl_format_code_t created = GetCreatedFormats();\n  if (common & created) return DecodeFormat(common & created);\n\n  // NOTE(zihao): hypersparse is currently disabled since many CUDA operators on\n  // COO have not been implemented yet. 
if (coo_->defined() &&\n  // coo_->IsHypersparse())  // only allow coo for hypersparse graph.\n  //   return SparseFormat::kCOO;\n  if (common) return DecodeFormat(common);\n  return DecodeFormat(created);\n}\n\nGraphPtr UnitGraph::AsImmutableGraph() const {\n  CHECK(NumVertexTypes() == 1) << \"not a homogeneous graph\";\n  dgl::CSRPtr in_csr_ptr = nullptr, out_csr_ptr = nullptr;\n  dgl::COOPtr coo_ptr = nullptr;\n  if (in_csr_->defined()) {\n    aten::CSRMatrix csc = GetCSCMatrix(0);\n    in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data));\n  }\n  if (out_csr_->defined()) {\n    aten::CSRMatrix csr = GetCSRMatrix(0);\n    out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data));\n  }\n  if (coo_->defined()) {\n    aten::COOMatrix coo = GetCOOMatrix(0);\n    if (!COOHasData(coo)) {\n      coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), coo.row, coo.col));\n    } else {\n      IdArray new_src = Scatter(coo.row, coo.data);\n      IdArray new_dst = Scatter(coo.col, coo.data);\n      coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), new_src, new_dst));\n    }\n  }\n  return GraphPtr(new dgl::ImmutableGraph(in_csr_ptr, out_csr_ptr, coo_ptr));\n}\n\nHeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const {\n  // TODO(xiangsx) currently we only support homogeneous graph\n  auto fmt = SelectFormat(ALL_CODE);\n  switch (fmt) {\n    case SparseFormat::kCOO: {\n      return CreateFromCOO(1, aten::COOLineGraph(coo_->adj(), backtracking));\n    }\n    case SparseFormat::kCSR: {\n      const aten::CSRMatrix csr = GetCSRMatrix(0);\n      const aten::COOMatrix coo =\n          aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);\n      return CreateFromCOO(1, coo);\n    }\n    case SparseFormat::kCSC: {\n      const aten::CSRMatrix csc = GetCSCMatrix(0);\n      const aten::CSRMatrix csr = aten::CSRTranspose(csc);\n      const aten::COOMatrix coo =\n          aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);\n     
 return CreateFromCOO(1, coo);\n    }\n    default:\n      LOG(FATAL) << \"None of CSC, CSR, COO exist\";\n      break;\n  }\n  return nullptr;\n}\n\nconstexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;\n\nbool UnitGraph::Load(dmlc::Stream* fs) {\n  uint64_t magicNum;\n  CHECK(fs->Read(&magicNum)) << \"Invalid Magic Number\";\n  CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << \"Invalid UnitGraph Data\";\n\n  int64_t save_format_code, formats_code;\n  CHECK(fs->Read(&save_format_code)) << \"Invalid format\";\n  CHECK(fs->Read(&formats_code)) << \"Invalid format\";\n  dgl_format_code_t save_formats = ANY_CODE;\n  if (save_format_code >> 32) {\n    save_formats =\n        static_cast<dgl_format_code_t>(0xffffffff & save_format_code);\n  } else {\n    save_formats =\n        SparseFormatsToCode({static_cast<SparseFormat>(save_format_code)});\n  }\n  if (formats_code >> 32) {\n    formats_ = static_cast<dgl_format_code_t>(0xffffffff & formats_code);\n  } else {\n    // NOTE(zihao): to be compatible with old formats.\n    switch (formats_code & 0xffffffff) {\n      case 0:\n        formats_ = ALL_CODE;\n        break;\n      case 1:\n        formats_ = COO_CODE;\n        break;\n      case 2:\n        formats_ = CSR_CODE;\n        break;\n      case 3:\n        formats_ = CSC_CODE;\n        break;\n      default:\n        LOG(FATAL) << \"Load graph failed, formats code \" << formats_code\n                   << \"not recognized.\";\n    }\n  }\n\n  if (save_formats & COO_CODE) {\n    fs->Read(&coo_);\n  }\n  if (save_formats & CSR_CODE) {\n    fs->Read(&out_csr_);\n  }\n  if (save_formats & CSC_CODE) {\n    fs->Read(&in_csr_);\n  }\n  if (!coo_ && !out_csr_ && !in_csr_) {\n    LOG(FATAL) << \"unsupported format code\";\n  }\n\n  if (!in_csr_) {\n    in_csr_ = CSRPtr(new CSR());\n  }\n  if (!out_csr_) {\n    out_csr_ = CSRPtr(new CSR());\n  }\n  if (!coo_) {\n    coo_ = COOPtr(new COO());\n  }\n\n  meta_graph_ = GetAny()->meta_graph();\n\n  return 
true;\n}\n\nvoid UnitGraph::Save(dmlc::Stream* fs) const {\n  fs->Write(kDGLSerialize_UnitGraphMagic);\n  // Didn't write UnitGraph::meta_graph_, since it's included in the underlying\n  // sparse matrix\n  auto save_formats = SparseFormatsToCode({SelectFormat(ALL_CODE)});\n  auto fstream = dynamic_cast<dgl::serialize::DGLStream*>(fs);\n  if (fstream) {\n    auto formats = fstream->FormatsToSave();\n    save_formats = formats == ANY_CODE\n                       ? SparseFormatsToCode({SelectFormat(ALL_CODE)})\n                       : formats;\n  }\n  fs->Write(static_cast<int64_t>(save_formats | 0x100000000));\n  fs->Write(static_cast<int64_t>(formats_ | 0x100000000));\n  if (save_formats & COO_CODE) {\n    fs->Write(GetCOO());\n  }\n  if (save_formats & CSR_CODE) {\n    fs->Write(GetOutCSR());\n  }\n  if (save_formats & CSC_CODE) {\n    fs->Write(GetInCSR());\n  }\n}\n\nUnitGraphPtr UnitGraph::Reverse() const {\n  CSRPtr new_incsr = out_csr_, new_outcsr = in_csr_;\n  COOPtr new_coo = nullptr;\n  if (coo_->defined()) {\n    new_coo =\n        COOPtr(new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj())));\n  }\n\n  return UnitGraphPtr(\n      new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo));\n}\n\nstd::tuple<UnitGraphPtr, IdArray, IdArray> UnitGraph::ToSimple() const {\n  CSRPtr new_incsr = nullptr, new_outcsr = nullptr;\n  COOPtr new_coo = nullptr;\n  IdArray count;\n  IdArray edge_map;\n\n  auto avail_fmt = SelectFormat(ALL_CODE);\n  switch (avail_fmt) {\n    case SparseFormat::kCOO: {\n      auto ret = aten::COOToSimple(GetCOO()->adj());\n      count = std::get<1>(ret);\n      edge_map = std::get<2>(ret);\n      new_coo = COOPtr(new COO(meta_graph(), std::get<0>(ret)));\n      break;\n    }\n    case SparseFormat::kCSR: {\n      auto ret = aten::CSRToSimple(GetOutCSR()->adj());\n      count = std::get<1>(ret);\n      edge_map = std::get<2>(ret);\n      new_outcsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));\n      break;\n    }\n    case 
SparseFormat::kCSC: {\n      auto ret = aten::CSRToSimple(GetInCSR()->adj());\n      count = std::get<1>(ret);\n      edge_map = std::get<2>(ret);\n      new_incsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));\n      break;\n    }\n    default:\n      LOG(FATAL) << \"At lease one of COO, CSR or CSC adj should exist.\";\n      break;\n  }\n\n  return std::make_tuple(\n      UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)),\n      count, edge_map);\n}\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/graph/unit_graph.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph/unit_graph.h\n * @brief UnitGraph graph\n */\n\n#ifndef DGL_GRAPH_UNIT_GRAPH_H_\n#define DGL_GRAPH_UNIT_GRAPH_H_\n\n#include <dgl/array.h>\n#include <dgl/base_heterograph.h>\n#include <dgl/lazy.h>\n#include <dmlc/io.h>\n#include <dmlc/type_traits.h>\n\n#include <memory>\n#include <string>\n#include <tuple>\n#include <utility>\n#include <vector>\n\n#include \"../c_api_common.h\"\n\nnamespace dgl {\n\nclass HeteroGraph;\nclass UnitGraph;\ntypedef std::shared_ptr<UnitGraph> UnitGraphPtr;\n\n/**\n * @brief UnitGraph graph\n *\n * UnitGraph graph is a special type of heterograph which\n * (1) Have two types of nodes: \"Src\" and \"Dst\". All the edges are\n *     from \"Src\" type nodes to \"Dst\" type nodes, so there is no edge among\n *     nodes of the same type. Thus, its metagraph has two nodes and one edge\n *     between them.\n * (2) Have only one type of nodes and edges. Thus, its metagraph has one node\n *     and one self-loop edge.\n */\nclass UnitGraph : public BaseHeteroGraph {\n public:\n  // internal data structure\n  class COO;\n  class CSR;\n  typedef std::shared_ptr<COO> COOPtr;\n  typedef std::shared_ptr<CSR> CSRPtr;\n\n  inline dgl_type_t SrcType() const { return 0; }\n\n  inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }\n\n  inline dgl_type_t EdgeType() const { return 0; }\n\n  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {\n    LOG(FATAL) << \"The method shouldn't be called for UnitGraph graph. 
\"\n               << \"The relation graph is simply this graph itself.\";\n    return {};\n  }\n\n  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {\n    LOG(FATAL) << \"UnitGraph graph is not mutable.\";\n  }\n\n  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {\n    LOG(FATAL) << \"UnitGraph graph is not mutable.\";\n  }\n\n  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {\n    LOG(FATAL) << \"UnitGraph graph is not mutable.\";\n  }\n\n  void Clear() override { LOG(FATAL) << \"UnitGraph graph is not mutable.\"; }\n\n  DGLDataType DataType() const override;\n\n  DGLContext Context() const override;\n\n  bool IsPinned() const override;\n\n  uint8_t NumBits() const override;\n\n  bool IsMultigraph() const override;\n\n  bool IsReadonly() const override { return true; }\n\n  uint64_t NumVertices(dgl_type_t vtype) const override;\n\n  inline std::vector<int64_t> NumVerticesPerType() const override {\n    std::vector<int64_t> num_nodes_per_type;\n    for (dgl_type_t vtype = 0; vtype < NumVertexTypes(); ++vtype)\n      num_nodes_per_type.push_back(NumVertices(vtype));\n    return num_nodes_per_type;\n  }\n\n  uint64_t NumEdges(dgl_type_t etype) const override;\n\n  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override;\n\n  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;\n\n  bool HasEdgeBetween(\n      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;\n\n  BoolArray HasEdgesBetween(\n      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override;\n\n  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override;\n\n  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override;\n\n  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;\n\n  EdgeArray EdgeIdsAll(\n      dgl_type_t etype, IdArray src, IdArray dst) const override;\n\n  IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override;\n\n 
 std::pair<dgl_id_t, dgl_id_t> FindEdge(\n      dgl_type_t etype, dgl_id_t eid) const override;\n\n  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override;\n\n  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override;\n\n  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override;\n\n  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override;\n\n  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override;\n\n  EdgeArray Edges(\n      dgl_type_t etype, const std::string& order = \"\") const override;\n\n  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override;\n\n  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override;\n\n  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override;\n\n  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override;\n\n  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override;\n\n  // 32bit version functions, patch for SuccVec\n  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) const;\n\n  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;\n\n  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override;\n\n  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;\n\n  std::vector<IdArray> GetAdj(\n      dgl_type_t etype, bool transpose, const std::string& fmt) const override;\n\n  HeteroSubgraph VertexSubgraph(\n      const std::vector<IdArray>& vids) const override;\n\n  HeteroSubgraph EdgeSubgraph(\n      const std::vector<IdArray>& eids,\n      bool preserve_nodes = false) const override;\n\n  // creators\n  /** @brief Create a graph with no edges */\n  static HeteroGraphPtr Empty(\n      int64_t num_vtypes, int64_t num_src, int64_t num_dst, DGLDataType dtype,\n      DGLContext ctx) {\n    IdArray row = IdArray::Empty({0}, dtype, ctx);\n    IdArray col = IdArray::Empty({0}, dtype, ctx);\n    return CreateFromCOO(num_vtypes, num_src, num_dst, row, col);\n  }\n\n  /** @brief Create a graph 
from COO arrays */\n  static HeteroGraphPtr CreateFromCOO(\n      int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,\n      IdArray col, bool row_sorted = false, bool col_sorted = false,\n      dgl_format_code_t formats = ALL_CODE);\n\n  static HeteroGraphPtr CreateFromCOO(\n      int64_t num_vtypes, const aten::COOMatrix& mat,\n      dgl_format_code_t formats = ALL_CODE);\n\n  /** @brief Create a graph from (out) CSR arrays */\n  static HeteroGraphPtr CreateFromCSR(\n      int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,\n      IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);\n\n  static HeteroGraphPtr CreateFromCSR(\n      int64_t num_vtypes, const aten::CSRMatrix& mat,\n      dgl_format_code_t formats = ALL_CODE);\n\n  /** @brief Create a graph from (out) CSR and COO arrays, both representing the\n   * same graph */\n  static HeteroGraphPtr CreateFromCSRAndCOO(\n      int64_t num_vtypes, const aten::CSRMatrix& csr,\n      const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE);\n\n  /** @brief Create a graph from (in) CSC arrays */\n  static HeteroGraphPtr CreateFromCSC(\n      int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,\n      IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);\n\n  static HeteroGraphPtr CreateFromCSC(\n      int64_t num_vtypes, const aten::CSRMatrix& mat,\n      dgl_format_code_t formats = ALL_CODE);\n\n  /** @brief Create a graph from (in) CSC and COO arrays, both representing the\n   * same graph */\n  static HeteroGraphPtr CreateFromCSCAndCOO(\n      int64_t num_vtypes, const aten::CSRMatrix& csc,\n      const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE);\n\n  /** @brief Convert the graph to use the given number of bits for storage */\n  static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);\n\n  /** @brief Copy the data to another context */\n  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const 
DGLContext& ctx);\n\n  /**\n   * @brief Pin the in_csr_, out_csr_ and coo_ of the current graph.\n   * @note The graph will be pinned inplace. Behavior depends on the current\n   * context, kDGLCPU: will be pinned; IsPinned: directly return; kDGLCUDA:\n   * invalid, will throw an error. The context check is deferred to pinning the\n   * NDArray.\n   */\n  void PinMemory_() override;\n\n  /**\n   * @brief Unpin the in_csr_, out_csr_ and coo_ of the current graph.\n   * @note The graph will be unpinned inplace. Behavior depends on the current\n   * context, IsPinned: will be unpinned; others: directly return. The context\n   * check is deferred to unpinning the NDArray.\n   */\n  void UnpinMemory_();\n\n  /**\n   * @brief Create a copy of the current graph in pinned memory.\n   * @note The graph will be pinned outplace through PyTorch\n   *     CachingHostAllocator, if available. Otherwise, an error will be thrown.\n   *     If any of the underlying structures (incsr, outcsr, coo) are already\n   *     pinned, the function will simply use its original copy.\n   */\n  HeteroGraphPtr PinMemory();\n\n  /**\n   * @brief Record stream for this graph.\n   * @param stream The stream that is using the graph\n   */\n  void RecordStream(DGLStreamHandle stream) override;\n\n  /**\n   * @brief Create in-edge CSR format of the unit graph.\n   * @param inplace if true and the in-edge CSR format does not exist, the\n   * created format will be cached in this object unless the format is\n   * restricted.\n   * @return Return the in-edge CSR format. Create from other format if not\n   * exist.\n   */\n  CSRPtr GetInCSR(bool inplace = true) const;\n\n  /**\n   * @brief Create out-edge CSR format of the unit graph.\n   * @param inplace if true and the out-edge CSR format does not exist, the\n   * created format will be cached in this object unless the format is\n   * restricted.\n   * @return Return the out-edge CSR format. 
Create from other format if not\n   * exist.\n   */\n  CSRPtr GetOutCSR(bool inplace = true) const;\n\n  /**\n   * @brief Create COO format of the unit graph.\n   * @param inplace if true and the COO format does not exist, the created\n   *                format will be cached in this object unless the format is\n   * restricted.\n   * @return Return the COO format. Create from other format if not exist.\n   */\n  COOPtr GetCOO(bool inplace = true) const;\n\n  /** @return Return the COO matrix form */\n  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override;\n\n  /** @return Return the in-edge CSC in the matrix form */\n  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override;\n\n  /** @return Return the out-edge CSR in the matrix form */\n  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override;\n\n  SparseFormat SelectFormat(\n      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {\n    return SelectFormat(preferred_formats);\n  }\n\n  /**\n   * @brief Return the graph in the given format. Perform format conversion if\n   * the requested format does not exist.\n   *\n   * @return A graph in the requested format.\n   */\n  HeteroGraphPtr GetFormat(SparseFormat format) const;\n\n  dgl_format_code_t GetCreatedFormats() const override;\n\n  dgl_format_code_t GetAllowedFormats() const override;\n\n  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;\n\n  /** @brief Load UnitGraph from stream, using CSRMatrix*/\n  bool Load(dmlc::Stream* fs);\n\n  /** @brief Save UnitGraph to stream, using CSRMatrix */\n  void Save(dmlc::Stream* fs) const;\n\n  /** @brief Create a LineGraph of self */\n  HeteroGraphPtr LineGraph(bool backtracking) const;\n\n  /** @return the reversed graph */\n  UnitGraphPtr Reverse() const;\n\n  /** @return the simple (no-multi-edge) graph\n   *          the count recording the number of duplicated edges from the\n   * original graph. 
the edge mapping from the edge IDs of original graph to\n   * those of the returned graph.\n   */\n  std::tuple<UnitGraphPtr, IdArray, IdArray> ToSimple() const;\n\n  void InvalidateCSR();\n\n  void InvalidateCSC();\n\n  void InvalidateCOO();\n\n private:\n  friend class Serializer;\n  friend class HeteroGraph;\n  friend class ImmutableGraph;\n  friend HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);\n\n  // private empty constructor\n  UnitGraph() {}\n\n  /**\n   * @brief constructor\n   * @param metagraph metagraph\n   * @param in_csr in edge csr\n   * @param out_csr out edge csr\n   * @param coo coo\n   */\n  UnitGraph(\n      GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,\n      dgl_format_code_t formats = ALL_CODE);\n\n  /**\n   * @brief constructor\n   * @param num_vtypes number of vertex types (1 or 2)\n   * @param metagraph metagraph\n   * @param in_csr in edge csr\n   * @param out_csr out edge csr\n   * @param coo coo\n   * @param has_in_csr whether in_csr is valid\n   * @param has_out_csr whether out_csr is valid\n   * @param has_coo whether coo is valid\n   */\n  static HeteroGraphPtr CreateUnitGraphFrom(\n      int num_vtypes, const aten::CSRMatrix& in_csr,\n      const aten::CSRMatrix& out_csr, const aten::COOMatrix& coo,\n      bool has_in_csr, bool has_out_csr, bool has_coo,\n      dgl_format_code_t formats = ALL_CODE);\n\n  /** @return Return any existing format. */\n  HeteroGraphPtr GetAny() const;\n\n  /**\n   * @brief Determine which format to use with a preference.\n   *\n   * If the storage of unit graph is \"locked\", i.e. 
no conversion is allowed,\n   * then it will return the locked format.\n   *\n   * Otherwise, it will return whatever DGL thinks is the most appropriate given\n   * the arguments.\n   */\n  SparseFormat SelectFormat(dgl_format_code_t preferred_formats) const;\n\n  /** @return Whether the graph is hypersparse */\n  bool IsHypersparse() const;\n\n  GraphPtr AsImmutableGraph() const override;\n\n  // Graph stored in different format. We use an on-demand strategy: the format\n  // is only materialized if the operation that is suitable for it is invoked.\n  /** @brief CSR graph that stores reverse edges */\n  CSRPtr in_csr_;\n  /** @brief CSR representation */\n  CSRPtr out_csr_;\n  /** @brief COO representation */\n  COOPtr coo_;\n  /**\n   * @brief Storage format restriction.\n   */\n  dgl_format_code_t formats_;\n  /** @brief which streams have recorded the graph */\n  std::vector<DGLStreamHandle> recorded_streams;\n};\n\n};  // namespace dgl\n\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph, true);\nDMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph::CSR, true);\nDMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph::COO, true);\n}  // namespace dmlc\n\n#endif  // DGL_GRAPH_UNIT_GRAPH_H_\n"
  },
  {
    "path": "src/partition/cuda/partition_op.cu",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file ndarray_partition.h\n * @brief Operations on partition implemented in CUDA.\n */\n\n#include <dgl/runtime/device_api.h>\n\n#include <cub/cub.cuh>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n#include \"../../runtime/workspace.h\"\n#include \"../partition_op.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace partition {\nnamespace impl {\n\nnamespace {\n\n/**\n * @brief Kernel to map global element IDs to partition IDs by remainder.\n *\n * @tparam IdType The type of ID.\n * @param global The global element IDs.\n * @param num_elements The number of element IDs.\n * @param num_parts The number of partitions.\n * @param part_id The mapped partition ID (outupt).\n */\ntemplate <typename IdType>\n__global__ void _MapProcByRemainderKernel(\n    const IdType* const global, const int64_t num_elements,\n    const int64_t num_parts, IdType* const part_id) {\n  assert(num_elements <= gridDim.x * blockDim.x);\n  const int64_t idx =\n      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;\n\n  if (idx < num_elements) {\n    part_id[idx] = global[idx] % num_parts;\n  }\n}\n\n/**\n * @brief Kernel to map global element IDs to partition IDs, using a bit-mask.\n * The number of partitions must be a power a two.\n *\n * @tparam IdType The type of ID.\n * @param global The global element IDs.\n * @param num_elements The number of element IDs.\n * @param mask The bit-mask with 1's for each bit to keep from the element ID to\n * extract the partition ID (e.g., an 8 partition mask would be 0x07).\n * @param part_id The mapped partition ID (outupt).\n */\ntemplate <typename IdType>\n__global__ void _MapProcByMaskRemainderKernel(\n    const IdType* const global, const int64_t num_elements, const IdType mask,\n    IdType* const part_id) {\n  assert(num_elements <= gridDim.x * blockDim.x);\n  const int64_t idx =\n      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;\n\n  if 
(idx < num_elements) {\n    part_id[idx] = global[idx] & mask;\n  }\n}\n\n/**\n * @brief Kernel to map global element IDs to local element IDs.\n *\n * @tparam IdType The type of ID.\n * @param global The global element IDs.\n * @param num_elements The number of IDs.\n * @param num_parts The number of partitions.\n * @param local The local element IDs (output).\n */\ntemplate <typename IdType>\n__global__ void _MapLocalIndexByRemainderKernel(\n    const IdType* const global, const int64_t num_elements, const int num_parts,\n    IdType* const local) {\n  assert(num_elements <= gridDim.x * blockDim.x);\n  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;\n\n  if (idx < num_elements) {\n    local[idx] = global[idx] / num_parts;\n  }\n}\n\n/**\n * @brief Kernel to map local element IDs within a partition to their global\n * IDs, using the remainder over the number of partitions.\n *\n * @tparam IdType The type of ID.\n * @param local The local element IDs.\n * @param part_id The partition to map local elements from.\n * @param num_elements The number of elements to map.\n * @param num_parts The number of partitions.\n * @param global The global element IDs (output).\n */\ntemplate <typename IdType>\n__global__ void _MapGlobalIndexByRemainderKernel(\n    const IdType* const local, const int part_id, const int64_t num_elements,\n    const int num_parts, IdType* const global) {\n  assert(num_elements <= gridDim.x * blockDim.x);\n  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;\n\n  assert(part_id < num_parts);\n\n  if (idx < num_elements) {\n    global[idx] = (local[idx] * num_parts) + part_id;\n  }\n}\n\n/**\n * @brief Device function to perform a binary search to find to which partition\n * a given ID belongs.\n *\n * @tparam RangeType The type of range.\n * @param range The prefix-sum of IDs assigned to partitions.\n * @param num_parts The number of partitions.\n * @param target The element ID to find the partition of.\n *\n * @return The 
partition.\n */\ntemplate <typename RangeType>\n__device__ RangeType _SearchRange(\n    const RangeType* const range, const int num_parts, const RangeType target) {\n  int start = 0;\n  int end = num_parts;\n  int cur = (end + start) / 2;\n\n  assert(range[0] == 0);\n  assert(target < range[num_parts]);\n\n  while (start + 1 < end) {\n    if (target < range[cur]) {\n      end = cur;\n    } else {\n      start = cur;\n    }\n    cur = (start + end) / 2;\n  }\n\n  return cur;\n}\n\n/**\n * @brief Kernel to map element IDs to partition IDs.\n *\n * @tparam IdType The type of element ID.\n * @tparam RangeType The type of the range.\n * @param range The prefix-sum of IDs assigned to partitions.\n * @param global The global element IDs.\n * @param num_elements The number of element IDs.\n * @param num_parts The number of partitions.\n * @param part_id The partition ID assigned to each element (output).\n */\ntemplate <typename IdType, typename RangeType>\n__global__ void _MapProcByRangeKernel(\n    const RangeType* const range, const IdType* const global,\n    const int64_t num_elements, const int64_t num_parts,\n    IdType* const part_id) {\n  assert(num_elements <= gridDim.x * blockDim.x);\n  const int64_t idx =\n      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;\n\n  // rely on caching to load the range into L1 cache\n  if (idx < num_elements) {\n    part_id[idx] = static_cast<IdType>(_SearchRange(\n        range, static_cast<int>(num_parts),\n        static_cast<RangeType>(global[idx])));\n  }\n}\n\n/**\n * @brief Kernel to map global element IDs to their ID within their respective\n * partition.\n *\n * @tparam IdType The type of element ID.\n * @tparam RangeType The type of the range.\n * @param range The prefix-sum of IDs assigned to partitions.\n * @param global The global element IDs.\n * @param num_elements The number of elements.\n * @param num_parts The number of partitions.\n * @param local The local element IDs (output).\n */\ntemplate 
<typename IdType, typename RangeType>\n__global__ void _MapLocalIndexByRangeKernel(\n    const RangeType* const range, const IdType* const global,\n    const int64_t num_elements, const int num_parts, IdType* const local) {\n  assert(num_elements <= gridDim.x * blockDim.x);\n  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;\n\n  // rely on caching to load the range into L1 cache\n  if (idx < num_elements) {\n    const int proc = _SearchRange(\n        range, static_cast<int>(num_parts),\n        static_cast<RangeType>(global[idx]));\n    local[idx] = global[idx] - range[proc];\n  }\n}\n\n/**\n * @brief Kernel to map local element IDs within a partition to their global\n * IDs.\n *\n * @tparam IdType The type of ID.\n * @tparam RangeType The type of the range.\n * @param range The prefix-sum of IDs assigend to partitions.\n * @param local The local element IDs.\n * @param part_id The partition to map local elements from.\n * @param num_elements The number of elements to map.\n * @param num_parts The number of partitions.\n * @param global The global element IDs (output).\n */\ntemplate <typename IdType, typename RangeType>\n__global__ void _MapGlobalIndexByRangeKernel(\n    const RangeType* const range, const IdType* const local, const int part_id,\n    const int64_t num_elements, const int num_parts, IdType* const global) {\n  assert(num_elements <= gridDim.x * blockDim.x);\n  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;\n\n  assert(part_id < num_parts);\n\n  // rely on caching to load the range into L1 cache\n  if (idx < num_elements) {\n    global[idx] = local[idx] + range[part_id];\n  }\n}\n}  // namespace\n\n// Remainder Based Partition Operations\n\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<IdArray, NDArray> GeneratePermutationFromRemainder(\n    int64_t array_size, int num_parts, IdArray in_idx) {\n  std::pair<IdArray, NDArray> result;\n\n  const auto& ctx = in_idx->ctx;\n  auto device = DeviceAPI::Get(ctx);\n  
cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  const int64_t num_in = in_idx->shape[0];\n\n  CHECK_GE(num_parts, 1) << \"The number of partitions (\" << num_parts\n                         << \") must be at least 1.\";\n  if (num_parts == 1) {\n    // no permutation\n    result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);\n    result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);\n\n    return result;\n  }\n\n  result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);\n  result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);\n  int64_t* out_counts = static_cast<int64_t*>(result.second->data);\n  if (num_in == 0) {\n    // now that we've zero'd out_counts, nothing left to do for an empty\n    // mapping\n    return result;\n  }\n\n  const int64_t part_bits =\n      static_cast<int64_t>(std::ceil(std::log2(num_parts)));\n\n  // First, generate a mapping of indexes to processors\n  Workspace<IdType> proc_id_in(device, ctx, num_in);\n  {\n    const dim3 block(256);\n    const dim3 grid((num_in + block.x - 1) / block.x);\n\n    if (num_parts < (1 << part_bits)) {\n      // num_parts is not a power of 2\n      CUDA_KERNEL_CALL(\n          _MapProcByRemainderKernel, grid, block, 0, stream,\n          static_cast<const IdType*>(in_idx->data), num_in, num_parts,\n          proc_id_in.get());\n    } else {\n      // num_parts is a power of 2\n      CUDA_KERNEL_CALL(\n          _MapProcByMaskRemainderKernel, grid, block, 0, stream,\n          static_cast<const IdType*>(in_idx->data), num_in,\n          static_cast<IdType>(num_parts - 1),  // bit mask\n          proc_id_in.get());\n    }\n  }\n\n  // then create a permutation array that groups processors together by\n  // performing a radix sort\n  Workspace<IdType> proc_id_out(device, ctx, num_in);\n  IdType* perm_out = static_cast<IdType*>(result.first->data);\n  {\n    IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);\n\n    size_t 
sort_workspace_size;\n    CUDA_CALL(cub::DeviceRadixSort::SortPairs(\n        nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),\n        static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,\n        stream));\n\n    Workspace<void> sort_workspace(device, ctx, sort_workspace_size);\n    CUDA_CALL(cub::DeviceRadixSort::SortPairs(\n        sort_workspace.get(), sort_workspace_size, proc_id_in.get(),\n        proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,\n        num_in, 0, part_bits, stream));\n  }\n  // explicitly free so workspace can be re-used\n  proc_id_in.free();\n\n  // perform a histogram and then prefixsum on the sorted proc_id vector\n\n  // Count the number of values to be sent to each processor\n  {\n    using AtomicCount = unsigned long long;  // NOLINT\n    static_assert(\n        sizeof(AtomicCount) == sizeof(*out_counts),\n        \"AtomicCount must be the same width as int64_t for atomicAdd \"\n        \"in cub::DeviceHistogram::HistogramEven() to work\");\n\n    // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,\n    // add a compile time check against the cub version to allow\n    // num_in > (2 << 31).\n    CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))\n        << \"number of values to insert into histogram must be less than max \"\n           \"value of int.\";\n\n    size_t hist_workspace_size;\n    CUDA_CALL(cub::DeviceHistogram::HistogramEven(\n        nullptr, hist_workspace_size, proc_id_out.get(),\n        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,\n        static_cast<IdType>(0), static_cast<IdType>(num_parts),\n        static_cast<int>(num_in), stream));\n\n    Workspace<void> hist_workspace(device, ctx, hist_workspace_size);\n    CUDA_CALL(cub::DeviceHistogram::HistogramEven(\n        hist_workspace.get(), hist_workspace_size, proc_id_out.get(),\n        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,\n        
static_cast<IdType>(0), static_cast<IdType>(num_parts),\n        static_cast<int>(num_in), stream));\n  }\n\n  return result;\n}\n\ntemplate std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<\n    kDGLCUDA, int32_t>(int64_t array_size, int num_parts, IdArray in_idx);\ntemplate std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<\n    kDGLCUDA, int64_t>(int64_t array_size, int num_parts, IdArray in_idx);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray MapToLocalFromRemainder(const int num_parts, IdArray global_idx) {\n  const auto& ctx = global_idx->ctx;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  if (num_parts > 1) {\n    IdArray local_idx =\n        aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);\n\n    const dim3 block(128);\n    const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);\n\n    CUDA_KERNEL_CALL(\n        _MapLocalIndexByRemainderKernel, grid, block, 0, stream,\n        static_cast<const IdType*>(global_idx->data), global_idx->shape[0],\n        num_parts, static_cast<IdType*>(local_idx->data));\n\n    return local_idx;\n  } else {\n    // no mapping to be done\n    return global_idx;\n  }\n}\n\ntemplate IdArray MapToLocalFromRemainder<kDGLCUDA, int32_t>(\n    int num_parts, IdArray in_idx);\ntemplate IdArray MapToLocalFromRemainder<kDGLCUDA, int64_t>(\n    int num_parts, IdArray in_idx);\n\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray MapToGlobalFromRemainder(\n    const int num_parts, IdArray local_idx, const int part_id) {\n  CHECK_LT(part_id, num_parts)\n      << \"Invalid partition id \" << part_id << \"/\" << num_parts;\n  CHECK_GE(part_id, 0) << \"Invalid partition id \" << part_id << \"/\"\n                       << num_parts;\n\n  const auto& ctx = local_idx->ctx;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  if (num_parts > 1) {\n    IdArray global_idx =\n        aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);\n\n    const 
dim3 block(128);\n    const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);\n\n    CUDA_KERNEL_CALL(\n        _MapGlobalIndexByRemainderKernel, grid, block, 0, stream,\n        static_cast<const IdType*>(local_idx->data), part_id,\n        global_idx->shape[0], num_parts,\n        static_cast<IdType*>(global_idx->data));\n\n    return global_idx;\n  } else {\n    // no mapping to be done\n    return local_idx;\n  }\n}\n\ntemplate IdArray MapToGlobalFromRemainder<kDGLCUDA, int32_t>(\n    int num_parts, IdArray in_idx, int part_id);\ntemplate IdArray MapToGlobalFromRemainder<kDGLCUDA, int64_t>(\n    int num_parts, IdArray in_idx, int part_id);\n\n// Range Based Partition Operations\n\ntemplate <DGLDeviceType XPU, typename IdType, typename RangeType>\nstd::pair<IdArray, NDArray> GeneratePermutationFromRange(\n    int64_t array_size, int num_parts, IdArray range, IdArray in_idx) {\n  std::pair<IdArray, NDArray> result;\n\n  const auto& ctx = in_idx->ctx;\n  auto device = DeviceAPI::Get(ctx);\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  const int64_t num_in = in_idx->shape[0];\n\n  CHECK_GE(num_parts, 1) << \"The number of partitions (\" << num_parts\n                         << \") must be at least 1.\";\n  if (num_parts == 1) {\n    // no permutation\n    result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);\n    result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);\n\n    return result;\n  }\n\n  result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);\n  result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);\n  int64_t* out_counts = static_cast<int64_t*>(result.second->data);\n  if (num_in == 0) {\n    // now that we've zero'd out_counts, nothing left to do for an empty\n    // mapping\n    return result;\n  }\n\n  const int64_t part_bits =\n      static_cast<int64_t>(std::ceil(std::log2(num_parts)));\n\n  // First, generate a mapping of indexes to processors\n  Workspace<IdType> 
proc_id_in(device, ctx, num_in);\n  {\n    const dim3 block(256);\n    const dim3 grid((num_in + block.x - 1) / block.x);\n\n    CUDA_KERNEL_CALL(\n        _MapProcByRangeKernel, grid, block, 0, stream,\n        static_cast<const RangeType*>(range->data),\n        static_cast<const IdType*>(in_idx->data), num_in, num_parts,\n        proc_id_in.get());\n  }\n\n  // then create a permutation array that groups processors together by\n  // performing a radix sort\n  Workspace<IdType> proc_id_out(device, ctx, num_in);\n  IdType* perm_out = static_cast<IdType*>(result.first->data);\n  {\n    IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);\n\n    size_t sort_workspace_size;\n    CUDA_CALL(cub::DeviceRadixSort::SortPairs(\n        nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),\n        static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,\n        stream));\n\n    Workspace<void> sort_workspace(device, ctx, sort_workspace_size);\n    CUDA_CALL(cub::DeviceRadixSort::SortPairs(\n        sort_workspace.get(), sort_workspace_size, proc_id_in.get(),\n        proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,\n        num_in, 0, part_bits, stream));\n  }\n  // explicitly free so workspace can be re-used\n  proc_id_in.free();\n\n  // perform a histogram and then prefixsum on the sorted proc_id vector\n\n  // Count the number of values to be sent to each processor\n  {\n    using AtomicCount = unsigned long long;  // NOLINT\n    static_assert(\n        sizeof(AtomicCount) == sizeof(*out_counts),\n        \"AtomicCount must be the same width as int64_t for atomicAdd \"\n        \"in cub::DeviceHistogram::HistogramEven() to work\");\n\n    // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,\n    // add a compile time check against the cub version to allow\n    // num_in > (2 << 31).\n    CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))\n        << \"number of values to 
insert into histogram must be less than max \"\n           \"value of int.\";\n\n    size_t hist_workspace_size;\n    CUDA_CALL(cub::DeviceHistogram::HistogramEven(\n        nullptr, hist_workspace_size, proc_id_out.get(),\n        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,\n        static_cast<IdType>(0), static_cast<IdType>(num_parts),\n        static_cast<int>(num_in), stream));\n\n    Workspace<void> hist_workspace(device, ctx, hist_workspace_size);\n    CUDA_CALL(cub::DeviceHistogram::HistogramEven(\n        hist_workspace.get(), hist_workspace_size, proc_id_out.get(),\n        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,\n        static_cast<IdType>(0), static_cast<IdType>(num_parts),\n        static_cast<int>(num_in), stream));\n  }\n\n  return result;\n}\n\ntemplate std::pair<IdArray, IdArray>\nGeneratePermutationFromRange<kDGLCUDA, int32_t, int32_t>(\n    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);\ntemplate std::pair<IdArray, IdArray>\nGeneratePermutationFromRange<kDGLCUDA, int64_t, int32_t>(\n    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);\ntemplate std::pair<IdArray, IdArray>\nGeneratePermutationFromRange<kDGLCUDA, int32_t, int64_t>(\n    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);\ntemplate std::pair<IdArray, IdArray>\nGeneratePermutationFromRange<kDGLCUDA, int64_t, int64_t>(\n    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename RangeType>\nIdArray MapToLocalFromRange(\n    const int num_parts, IdArray range, IdArray global_idx) {\n  const auto& ctx = global_idx->ctx;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  if (num_parts > 1 && global_idx->shape[0] > 0) {\n    IdArray local_idx =\n        aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);\n\n    const dim3 block(128);\n    const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);\n\n    
CUDA_KERNEL_CALL(\n        _MapLocalIndexByRangeKernel, grid, block, 0, stream,\n        static_cast<const RangeType*>(range->data),\n        static_cast<const IdType*>(global_idx->data), global_idx->shape[0],\n        num_parts, static_cast<IdType*>(local_idx->data));\n\n    return local_idx;\n  } else {\n    // no mapping to be done\n    return global_idx;\n  }\n}\n\ntemplate IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int32_t>(\n    int num_parts, IdArray range, IdArray in_idx);\ntemplate IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int32_t>(\n    int num_parts, IdArray range, IdArray in_idx);\ntemplate IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int64_t>(\n    int num_parts, IdArray range, IdArray in_idx);\ntemplate IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int64_t>(\n    int num_parts, IdArray range, IdArray in_idx);\n\ntemplate <DGLDeviceType XPU, typename IdType, typename RangeType>\nIdArray MapToGlobalFromRange(\n    const int num_parts, IdArray range, IdArray local_idx, const int part_id) {\n  CHECK_LT(part_id, num_parts)\n      << \"Invalid partition id \" << part_id << \"/\" << num_parts;\n  CHECK_GE(part_id, 0) << \"Invalid partition id \" << part_id << \"/\"\n                       << num_parts;\n\n  const auto& ctx = local_idx->ctx;\n  cudaStream_t stream = runtime::getCurrentCUDAStream();\n\n  if (num_parts > 1 && local_idx->shape[0] > 0) {\n    IdArray global_idx =\n        aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);\n\n    const dim3 block(128);\n    const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);\n\n    CUDA_KERNEL_CALL(\n        _MapGlobalIndexByRangeKernel, grid, block, 0, stream,\n        static_cast<const RangeType*>(range->data),\n        static_cast<const IdType*>(local_idx->data), part_id,\n        global_idx->shape[0], num_parts,\n        static_cast<IdType*>(global_idx->data));\n\n    return global_idx;\n  } else {\n    // no mapping to be done\n    return local_idx;\n  }\n}\n\ntemplate 
IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int32_t>(\n    int num_parts, IdArray range, IdArray in_idx, int part_id);\ntemplate IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int32_t>(\n    int num_parts, IdArray range, IdArray in_idx, int part_id);\ntemplate IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int64_t>(\n    int num_parts, IdArray range, IdArray in_idx, int part_id);\ntemplate IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int64_t>(\n    int num_parts, IdArray range, IdArray in_idx, int part_id);\n\n}  // namespace impl\n}  // namespace partition\n}  // namespace dgl\n"
  },
  {
    "path": "src/partition/ndarray_partition.cc",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file ndarray_partition.cc\n * @brief DGL utilities for working with the partitioned NDArrays\n */\n\n#include \"ndarray_partition.h\"\n\n#include <dgl/runtime/packed_func.h>\n#include <dgl/runtime/registry.h>\n\n#include <memory>\n#include <utility>\n\n#include \"../c_api_common.h\"\n#include \"partition_op.h\"\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace partition {\n\nNDArrayPartition::NDArrayPartition(\n    const int64_t array_size, const int num_parts)\n    : array_size_(array_size), num_parts_(num_parts) {}\n\nint64_t NDArrayPartition::ArraySize() const { return array_size_; }\n\nint NDArrayPartition::NumParts() const { return num_parts_; }\n\nclass RemainderPartition : public NDArrayPartition {\n public:\n  RemainderPartition(const int64_t array_size, const int num_parts)\n      : NDArrayPartition(array_size, num_parts) {\n    // do nothing\n  }\n\n  std::pair<IdArray, NDArray> GeneratePermutation(\n      IdArray in_idx) const override {\n#ifdef DGL_USE_CUDA\n    auto ctx = in_idx->ctx;\n    if (ctx.device_type == kDGLCUDA) {\n      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {\n        return impl::GeneratePermutationFromRemainder<kDGLCUDA, IdType>(\n            ArraySize(), NumParts(), in_idx);\n      });\n    }\n#endif\n\n    LOG(FATAL) << \"Remainder based partitioning for the CPU is not yet \"\n                  \"implemented.\";\n    // should be unreachable\n    return std::pair<IdArray, NDArray>{};\n  }\n\n  IdArray MapToLocal(IdArray in_idx) const override {\n#ifdef DGL_USE_CUDA\n    auto ctx = in_idx->ctx;\n    if (ctx.device_type == kDGLCUDA) {\n      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {\n        return impl::MapToLocalFromRemainder<kDGLCUDA, IdType>(\n            NumParts(), in_idx);\n      });\n    }\n#endif\n\n    LOG(FATAL) << \"Remainder based partitioning for the CPU is not yet \"\n                  \"implemented.\";\n    // should be unreachable\n    
return IdArray{};\n  }\n\n  IdArray MapToGlobal(IdArray in_idx, const int part_id) const override {\n#ifdef DGL_USE_CUDA\n    auto ctx = in_idx->ctx;\n    if (ctx.device_type == kDGLCUDA) {\n      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {\n        return impl::MapToGlobalFromRemainder<kDGLCUDA, IdType>(\n            NumParts(), in_idx, part_id);\n      });\n    }\n#endif\n\n    LOG(FATAL) << \"Remainder based partitioning for the CPU is not yet \"\n                  \"implemented.\";\n    // should be unreachable\n    return IdArray{};\n  }\n\n  int64_t PartSize(const int part_id) const override {\n    CHECK_LT(part_id, NumParts()) << \"Invalid part ID (\" << part_id\n                                  << \") for \"\n                                     \"partition of size \"\n                                  << NumParts() << \".\";\n    return ArraySize() / NumParts() + (part_id < ArraySize() % NumParts());\n  }\n};\n\nclass RangePartition : public NDArrayPartition {\n public:\n  RangePartition(const int64_t array_size, const int num_parts, IdArray range)\n      : NDArrayPartition(array_size, num_parts),\n        range_(range),\n        // We also need a copy of the range on the CPU, to compute partition\n        // sizes. We require the input range on the GPU, as if we have multiple\n        // GPUs, we can't know which is the proper one to copy the array to, but\n        // we have only one CPU context, and can safely copy the array to that.\n        range_cpu_(range.CopyTo(DGLContext{kDGLCPU, 0})) {\n    auto ctx = range->ctx;\n    if (ctx.device_type != kDGLCUDA) {\n      LOG(FATAL) << \"The range for an NDArrayPartition is only supported \"\n                    \" on GPUs. 
Transfer the range to the target device before \"\n                    \"creating the partition.\";\n    }\n  }\n\n  std::pair<IdArray, NDArray> GeneratePermutation(\n      IdArray in_idx) const override {\n#ifdef DGL_USE_CUDA\n    auto ctx = in_idx->ctx;\n    if (ctx.device_type == kDGLCUDA) {\n      if (ctx.device_type != range_->ctx.device_type ||\n          ctx.device_id != range_->ctx.device_id) {\n        LOG(FATAL) << \"The range for the NDArrayPartition and the input \"\n                      \"array must be on the same device: \"\n                   << ctx << \" vs. \" << range_->ctx;\n      }\n      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {\n        ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {\n          return impl::GeneratePermutationFromRange<\n              kDGLCUDA, IdType, RangeType>(\n              ArraySize(), NumParts(), range_, in_idx);\n        });\n      });\n    }\n#endif\n\n    LOG(FATAL) << \"Remainder based partitioning for the CPU is not yet \"\n                  \"implemented.\";\n    // should be unreachable\n    return std::pair<IdArray, NDArray>{};\n  }\n\n  IdArray MapToLocal(IdArray in_idx) const override {\n#ifdef DGL_USE_CUDA\n    auto ctx = in_idx->ctx;\n    if (ctx.device_type == kDGLCUDA) {\n      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {\n        ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {\n          return impl::MapToLocalFromRange<kDGLCUDA, IdType, RangeType>(\n              NumParts(), range_, in_idx);\n        });\n      });\n    }\n#endif\n\n    LOG(FATAL) << \"Remainder based partitioning for the CPU is not yet \"\n                  \"implemented.\";\n    // should be unreachable\n    return IdArray{};\n  }\n\n  IdArray MapToGlobal(IdArray in_idx, const int part_id) const override {\n#ifdef DGL_USE_CUDA\n    auto ctx = in_idx->ctx;\n    if (ctx.device_type == kDGLCUDA) {\n      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {\n        ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {\n          return 
impl::MapToGlobalFromRange<kDGLCUDA, IdType, RangeType>(\n              NumParts(), range_, in_idx, part_id);\n        });\n      });\n    }\n#endif\n\n    LOG(FATAL) << \"Remainder based partitioning for the CPU is not yet \"\n                  \"implemented.\";\n    // should be unreachable\n    return IdArray{};\n  }\n\n  int64_t PartSize(const int part_id) const override {\n    CHECK_LT(part_id, NumParts()) << \"Invalid part ID (\" << part_id\n                                  << \") for \"\n                                     \"partition of size \"\n                                  << NumParts() << \".\";\n    int64_t part_size = -1;\n    ATEN_ID_TYPE_SWITCH(range_cpu_->dtype, RangeType, {\n      const RangeType* const ptr =\n          static_cast<const RangeType*>(range_cpu_->data);\n      part_size = ptr[part_id + 1] - ptr[part_id];\n    });\n    return part_size;\n  }\n\n private:\n  IdArray range_;\n  IdArray range_cpu_;\n};\n\nNDArrayPartitionRef CreatePartitionRemainderBased(\n    const int64_t array_size, const int num_parts) {\n  return NDArrayPartitionRef(\n      std::make_shared<RemainderPartition>(array_size, num_parts));\n}\n\nNDArrayPartitionRef CreatePartitionRangeBased(\n    const int64_t array_size, const int num_parts, IdArray range) {\n  return NDArrayPartitionRef(\n      std::make_shared<RangePartition>(array_size, num_parts, range));\n}\n\nDGL_REGISTER_GLOBAL(\"partition._CAPI_DGLNDArrayPartitionCreateRemainderBased\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      int64_t array_size = args[0];\n      int num_parts = args[1];\n\n      *rv = CreatePartitionRemainderBased(array_size, num_parts);\n    });\n\nDGL_REGISTER_GLOBAL(\"partition._CAPI_DGLNDArrayPartitionCreateRangeBased\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int64_t array_size = args[0];\n      const int num_parts = args[1];\n      IdArray range = args[2];\n\n      *rv = CreatePartitionRangeBased(array_size, num_parts, range);\n    
});\n\nDGL_REGISTER_GLOBAL(\"partition._CAPI_DGLNDArrayPartitionGetPartSize\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArrayPartitionRef part = args[0];\n      int part_id = args[1];\n\n      *rv = part->PartSize(part_id);\n    });\n\nDGL_REGISTER_GLOBAL(\"partition._CAPI_DGLNDArrayPartitionMapToLocal\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArrayPartitionRef part = args[0];\n      IdArray idxs = args[1];\n\n      *rv = part->MapToLocal(idxs);\n    });\n\nDGL_REGISTER_GLOBAL(\"partition._CAPI_DGLNDArrayPartitionMapToGlobal\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArrayPartitionRef part = args[0];\n      IdArray idxs = args[1];\n      const int part_id = args[2];\n\n      *rv = part->MapToGlobal(idxs, part_id);\n    });\n\nDGL_REGISTER_GLOBAL(\"partition._CAPI_DGLNDArrayPartitionGeneratePermutation\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArrayPartitionRef part = args[0];\n      IdArray idxs = args[1];\n\n      std::pair<IdArray, NDArray> part_perm = part->GeneratePermutation(idxs);\n      *rv =\n          ConvertNDArrayVectorToPackedFunc({part_perm.first, part_perm.second});\n    });\n\n}  // namespace partition\n}  // namespace dgl\n"
  },
  {
    "path": "src/partition/ndarray_partition.h",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file ndarray_partition.h\n * @brief DGL utilities for working with the partitioned NDArrays\n */\n\n#ifndef DGL_PARTITION_NDARRAY_PARTITION_H_\n#define DGL_PARTITION_NDARRAY_PARTITION_H_\n\n#include <dgl/array.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/object.h>\n\n#include <utility>\n\nnamespace dgl {\nnamespace partition {\n\n/**\n * @brief The top-level partition class. Specific types of partitions should be\n * sub-classes of this.\n */\nclass NDArrayPartition : public runtime::Object {\n public:\n  /**\n   * @brief Create a new partition.\n   *\n   * @param array_size The first dimension of the partitioned array.\n   * @param num_parts The number parts to the array is split into.\n   */\n  NDArrayPartition(int64_t array_size, int num_parts);\n\n  virtual ~NDArrayPartition() = default;\n\n  static constexpr const char* _type_key = \"partition.NDArrayPartition\";\n\n  DGL_DECLARE_OBJECT_TYPE_INFO(NDArrayPartition, Object);\n\n  /**\n   * @brief Create a mapping for the given indices to different partitions,\n   * and a count of the number of indices per part.\n   *\n   * A prefix-sum of the counts, can be used to select the continuous sets of\n   * indices destined for each part.\n   *\n   * @param in_idx The input indices to map.\n   *\n   * @return A pair containing 0) the permutation to re-order the indices by\n   * partition, 1) the number of indices per partition (int64_t).\n   */\n  virtual std::pair<IdArray, NDArray> GeneratePermutation(\n      IdArray in_idx) const = 0;\n\n  /**\n   * @brief Generate the local indices (the numbering within each processor)\n   * from a set of global indices.\n   *\n   * @param in_idx The global indices.\n   *\n   * @return The local indices.\n   */\n  virtual IdArray MapToLocal(IdArray in_idx) const = 0;\n\n  /**\n   * @brief Generate the global indices (the numbering unique across all\n   * processors) from a set of local indices.\n   *\n   
* @param in_idx The local indices.\n   * @param part_id The part id.\n   *\n   * @return The global indices.\n   */\n  virtual IdArray MapToGlobal(IdArray in_idx, int part_id) const = 0;\n\n  /**\n   * @brief Get the number of rows/items assigned to the given part.\n   *\n   * @param part_id The part id.\n   *\n   * @return The size.\n   */\n  virtual int64_t PartSize(int part_id) const = 0;\n\n  /**\n   * @brief Get the first dimension of the partitioned array.\n   *\n   * @return The size.\n   */\n  int64_t ArraySize() const;\n\n  /**\n   * @brief Get the number of parts in this partition.\n   *\n   * @return The number of parts.\n   */\n  int NumParts() const;\n\n private:\n  int64_t array_size_;\n  int num_parts_;\n};\n\nDGL_DEFINE_OBJECT_REF(NDArrayPartitionRef, NDArrayPartition);\n\n/**\n * @brief Create a new partition object, using the remainder of the row id\n * divided by the number of parts, to assign rows to parts.\n *\n * @param array_size The first dimension of the array.\n * @param num_parts The number of parts.\n *\n * @return The partition object.\n */\nNDArrayPartitionRef CreatePartitionRemainderBased(\n    int64_t array_size, int num_parts);\n\n/**\n * @brief Create a new partition object, using the range (exclusive prefix-sum)\n * provided to identify which rows belong to which partitions.\n *\n * @param array_size The size of the partitioned array.\n * @param num_parts The number of parts the array is partitioned into.\n * @param range The exclusive prefix-sum of the number of rows owned by each\n * partition. The first value must be zero, and the last value must be the\n * total number of rows. It should be of length `num_parts+1`.\n *\n * @return The partition object.\n */\nNDArrayPartitionRef CreatePartitionRangeBased(\n    int64_t array_size, int num_parts, IdArray range);\n\n}  // namespace partition\n}  // namespace dgl\n\n#endif  // DGL_PARTITION_NDARRAY_PARTITION_H_\n"
  },
  {
    "path": "src/partition/partition_op.h",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file ndarray_partition.h\n * @brief DGL utilities for working with the partitioned NDArrays\n */\n\n#ifndef DGL_PARTITION_PARTITION_OP_H_\n#define DGL_PARTITION_PARTITION_OP_H_\n\n#include <dgl/array.h>\n\n#include <utility>\n\nnamespace dgl {\nnamespace partition {\nnamespace impl {\n\n/**\n * @brief Create a permutation that groups indices by the part id when used for\n * slicing, via the remainder. That is, for the input indices A, find I\n * such that A[I] is grouped by part ID.\n *\n * For example, if we have the set of indices [3, 9, 2, 4, 1, 7] and two\n * partitions, the permutation vector would be [2, 3, 0, 1, 4, 5].\n *\n * @tparam XPU The type of device to run on.\n * @tparam IdType The type of the index.\n * @param array_size The total size of the partitioned array.\n * @param num_parts The number parts the array id divided into.\n * @param in_idx The array of indices to group by part id.\n *\n * @return The permutation to group the indices by part id, and the number of\n * indices in each part.\n */\ntemplate <DGLDeviceType XPU, typename IdType>\nstd::pair<IdArray, IdArray> GeneratePermutationFromRemainder(\n    int64_t array_size, int num_parts, IdArray in_idx);\n\n/**\n * @brief Generate the set of local indices from the global indices, using\n * remainder. That is, for each index `i` in `global_idx`, the local index\n * is computed as `global_idx[i] / num_parts`.\n *\n * @tparam XPU The type of device to run on.\n * @tparam IdType The type of the index.\n * @param num_parts The number parts the array id divided into.\n * @param global_idx The array of global indices to map.\n *\n * @return The array of local indices.\n */\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray MapToLocalFromRemainder(int num_parts, IdArray global_idx);\n\n/**\n * @brief Generate the set of global indices from the local indices, using\n * remainder. 
That is, for each index `i` in `local_idx`, the global index\n * is computed as `local_idx[i] * num_parts + part_id`.\n *\n * @tparam XPU The type of device to run on.\n * @tparam IdType The type of the index.\n * @param num_parts The number parts the array id divided into.\n * @param local_idx The array of local indices to map.\n * @param part_id The id of the current part.\n *\n * @return The array of global indices.\n */\ntemplate <DGLDeviceType XPU, typename IdType>\nIdArray MapToGlobalFromRemainder(int num_parts, IdArray local_idx, int part_id);\n\n/**\n * @brief Create a permutation that groups indices by the part id when used for\n * slicing. That is, for the input indices A, find I such that A[I] is grouped\n * by part ID.\n *\n * For example, if we have a range of [0, 5, 10] and the set of indices\n * [3, 9, 2, 4, 1, 7], the permutation vector would be [0, 2, 3, 4, 1, 5].\n *\n * @tparam XPU The type of device to run on.\n * @tparam IdType The type of the index.\n * @tparam RangeType THe type of the range.\n * @param array_size The total size of the partitioned array.\n * @param num_parts The number parts the array id divided into.\n * @param range The exclusive prefix-sum, representing the range of rows\n * assigned to each partition. Must be on the same context as `in_idx`.\n * @param in_idx The array of indices to group by part id.\n *\n * @return The permutation to group the indices by part id, and the number of\n * indices in each part.\n */\ntemplate <DGLDeviceType XPU, typename IdType, typename RangeType>\nstd::pair<IdArray, IdArray> GeneratePermutationFromRange(\n    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);\n\n/**\n * @brief Generate the set of local indices from the global indices, using\n * remainder. 
That is, for each index `i` in `global_idx`, the local index\n * is computed as `global_idx[i] / num_parts`.\n *\n * @tparam XPU The type of device to run on.\n * @tparam IdType The type of the index.\n * @tparam RangeType THe type of the range.\n * @param num_parts The number parts the array id divided into.\n * @param range The exclusive prefix-sum, representing the range of rows\n * assigned to each partition. Must be on the same context as `global_idx`.\n * @param global_idx The array of global indices to map.\n *\n * @return The array of local indices.\n */\ntemplate <DGLDeviceType XPU, typename IdType, typename RangeType>\nIdArray MapToLocalFromRange(int num_parts, IdArray range, IdArray global_idx);\n\n/**\n * @brief Generate the set of global indices from the local indices, using\n * remainder. That is, for each index `i` in `local_idx`, the global index\n * is computed as `local_idx[i] * num_parts + part_id`.\n *\n * @tparam XPU The type of device to run on.\n * @tparam IdType The type of the index.\n * @tparam RangeType THe type of the range.\n * @param num_parts The number parts the array id divided into.\n * @param range The exclusive prefix-sum, representing the range of rows\n * assigned to each partition. Must be on the same context as `local_idx`.\n * @param local_idx The array of local indices to map.\n * @param part_id The id of the current part.\n *\n * @return The array of global indices.\n */\ntemplate <DGLDeviceType XPU, typename IdType, typename RangeType>\nIdArray MapToGlobalFromRange(\n    int num_parts, IdArray range, IdArray local_idx, int part_id);\n\n}  // namespace impl\n}  // namespace partition\n}  // namespace dgl\n\n#endif  // DGL_PARTITION_PARTITION_OP_H_\n"
  },
  {
    "path": "src/random/continuous_seed.h",
    "content": "/*!\n *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)\n *   All rights reserved.\n *\n *   Licensed under the Apache License, Version 2.0 (the \"License\");\n *   you may not use this file except in compliance with the License.\n *   You may obtain a copy of the License at\n *\n *       http://www.apache.org/licenses/LICENSE-2.0\n *\n *   Unless required by applicable law or agreed to in writing, software\n *   distributed under the License is distributed on an \"AS IS\" BASIS,\n *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *   See the License for the specific language governing permissions and\n *   limitations under the License.\n *\n * @file dgl/continuous_seed.h\n * @brief CPU and CUDA implementation for continuous random seeds\n */\n#ifndef DGL_RANDOM_CONTINUOUS_SEED_H_\n#define DGL_RANDOM_CONTINUOUS_SEED_H_\n\n#include <dgl/array.h>\n\n#include <cmath>\n\n#ifdef __NVCC__\n#include <curand_kernel.h>\n#else\n#include <random>\n\n#include \"pcg_random.hpp\"\n#endif  // __CUDA_ARCH__\n\n#ifndef M_SQRT1_2\n#define M_SQRT1_2 0.707106781186547524401\n#endif  // M_SQRT1_2\n\nnamespace dgl {\nnamespace random {\n\nclass continuous_seed {\n  uint64_t s[2];\n  float c[2];\n\n public:\n  /* implicit */ continuous_seed(const int64_t seed) {  // NOLINT\n    s[0] = s[1] = seed;\n    c[0] = c[1] = 0;\n  }\n\n  continuous_seed(IdArray seed_arr, float r) {\n    auto seed = seed_arr.Ptr<int64_t>();\n    s[0] = seed[0];\n    s[1] = seed[seed_arr->shape[0] - 1];\n    const auto pi = std::acos(-1.0);\n    c[0] = std::cos(pi * r / 2);\n    c[1] = std::sin(pi * r / 2);\n  }\n\n#ifdef __CUDA_ARCH__\n  __device__ inline float uniform(const uint64_t t) const {\n    const uint64_t kCurandSeed = 999961;  // Could be any random number.\n    curandStatePhilox4_32_10_t rng;\n    curand_init(kCurandSeed, s[0], t, &rng);\n    float rnd;\n    if (s[0] != s[1]) {\n      rnd = c[0] * curand_normal(&rng);\n      
curand_init(kCurandSeed, s[1], t, &rng);\n      rnd += c[1] * curand_normal(&rng);\n      rnd = normcdff(rnd);\n    } else {\n      rnd = curand_uniform(&rng);\n    }\n    return rnd;\n  }\n#else\n  inline float uniform(const uint64_t t) const {\n    pcg32 ng0(s[0], t);\n    float rnd;\n    if (s[0] != s[1]) {\n      std::normal_distribution<float> norm;\n      rnd = c[0] * norm(ng0);\n      pcg32 ng1(s[1], t);\n      norm.reset();\n      rnd += c[1] * norm(ng1);\n      rnd = std::erfc(-rnd * static_cast<float>(M_SQRT1_2)) / 2.0f;\n    } else {\n      std::uniform_real_distribution<float> uni;\n      rnd = uni(ng0);\n    }\n    return rnd;\n  }\n#endif  // __CUDA_ARCH__\n};\n\n}  // namespace random\n}  // namespace dgl\n\n#endif  // DGL_RANDOM_CONTINUOUS_SEED_H_\n"
  },
  {
    "path": "src/random/cpu/choice.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file random/choice.cc\n * @brief Non-uniform discrete sampling implementation\n */\n\n#include <dgl/array.h>\n#include <dgl/random.h>\n\n#include <numeric>\n#include <vector>\n\n#include \"sample_utils.h\"\n\nnamespace dgl {\n\ntemplate <typename IdxType>\nIdxType RandomEngine::Choice(FloatArray prob) {\n  IdxType ret = 0;\n  ATEN_FLOAT_TYPE_SWITCH(prob->dtype, ValueType, \"probability\", {\n    // TODO(minjie): allow choosing different sampling algorithms\n    utils::TreeSampler<IdxType, ValueType, true> sampler(this, prob);\n    ret = sampler.Draw();\n  });\n  return ret;\n}\n\ntemplate int32_t RandomEngine::Choice<int32_t>(FloatArray);\ntemplate int64_t RandomEngine::Choice<int64_t>(FloatArray);\n\ntemplate <typename IdxType, typename FloatType>\nvoid RandomEngine::Choice(\n    IdxType num, FloatArray prob, IdxType* out, bool replace) {\n  const IdxType N = prob->shape[0];\n  if (!replace)\n    CHECK_LE(num, N)\n        << \"Cannot take more sample than population when 'replace=false'\";\n  if (num == N && !replace) std::iota(out, out + num, 0);\n\n  utils::BaseSampler<IdxType>* sampler = nullptr;\n  if (replace) {\n    sampler = new utils::TreeSampler<IdxType, FloatType, true>(this, prob);\n  } else {\n    sampler = new utils::TreeSampler<IdxType, FloatType, false>(this, prob);\n  }\n  for (IdxType i = 0; i < num; ++i) out[i] = sampler->Draw();\n  delete sampler;\n}\n\ntemplate void RandomEngine::Choice<int32_t, float>(\n    int32_t num, FloatArray prob, int32_t* out, bool replace);\ntemplate void RandomEngine::Choice<int64_t, float>(\n    int64_t num, FloatArray prob, int64_t* out, bool replace);\ntemplate void RandomEngine::Choice<int32_t, double>(\n    int32_t num, FloatArray prob, int32_t* out, bool replace);\ntemplate void RandomEngine::Choice<int64_t, double>(\n    int64_t num, FloatArray prob, int64_t* out, bool replace);\ntemplate void RandomEngine::Choice<int32_t, int8_t>(\n    int32_t 
num, FloatArray prob, int32_t* out, bool replace);\ntemplate void RandomEngine::Choice<int64_t, int8_t>(\n    int64_t num, FloatArray prob, int64_t* out, bool replace);\ntemplate void RandomEngine::Choice<int32_t, uint8_t>(\n    int32_t num, FloatArray prob, int32_t* out, bool replace);\ntemplate void RandomEngine::Choice<int64_t, uint8_t>(\n    int64_t num, FloatArray prob, int64_t* out, bool replace);\n\ntemplate <typename IdxType>\nvoid RandomEngine::UniformChoice(\n    IdxType num, IdxType population, IdxType* out, bool replace) {\n  CHECK_GE(num, 0) << \"The numbers to sample should be non-negative.\";\n  CHECK_GE(population, 0) << \"The population size should be non-negative.\";\n  if (!replace)\n    CHECK_LE(num, population)\n        << \"Cannot take more sample than population when 'replace=false'\";\n  if (replace) {\n    for (IdxType i = 0; i < num; ++i) out[i] = RandInt(population);\n  } else {\n    if (num <\n        population / 10) {  // TODO(minjie): may need a better threshold here\n      // If the set of numbers is small (fewer than 64), use linear search to\n      // verify uniqueness; this operation is cheaper for CPU.\n      if (num && num < 64) {\n        *out = RandInt(population);\n        auto b = out + 1;\n        auto e = b + num - 1;\n        while (b != e) {\n          // put the new value at the end\n          *b = RandInt(population);\n          // Check that the new value doesn't already exist in the current\n          // range [out, b); otherwise draw a new value until the range holds\n          // only unique elements.\n          auto it = std::find(out, b, *b);\n          if (it != b) continue;\n          ++b;\n        }\n\n      } else {\n        // use hash set\n        // In the best scenario, time complexity is O(num), i.e., no conflict.\n        //\n        // Let k be num / population, the expected number of extra sampling\n        // steps is roughly k^2 / (1-k) * population, which means in the worst\n        // case scenario, the time complexity is 
O(population^2). In practice,\n        // we use 1/10 since std::unordered_set is pretty slow.\n        std::unordered_set<IdxType> selected;\n        while (static_cast<IdxType>(selected.size()) < num) {\n          selected.insert(RandInt(population));\n        }\n        std::copy(selected.begin(), selected.end(), out);\n      }\n\n    } else {\n      // In this case, `num >= population / 10`. To reduce the computation\n      // overhead, we should reduce the number of random number generations.\n      // Even though the reservoir algorithm is more memory efficient (it has\n      // O(num) memory complexity), it generates O(population) random numbers,\n      // which is computationally expensive. This algorithm has memory\n      // complexity of O(population) but generates much fewer random numbers\n      // O(num). In the case of `num >= population/10`, we don't need to worry\n      // about memory complexity because `num` is usually small. So is\n      // `population`. Allocating a small piece of memory is very efficient.\n      std::vector<IdxType> seq(population);\n      for (size_t i = 0; i < seq.size(); i++) seq[i] = i;\n      for (IdxType i = 0; i < num; i++) {\n        IdxType j = RandInt(i, population);\n        std::swap(seq[i], seq[j]);\n      }\n      // Save the randomly sampled numbers.\n      for (IdxType i = 0; i < num; i++) {\n        out[i] = seq[i];\n      }\n    }\n  }\n}\n\ntemplate void RandomEngine::UniformChoice<int32_t>(\n    int32_t num, int32_t population, int32_t* out, bool replace);\ntemplate void RandomEngine::UniformChoice<int64_t>(\n    int64_t num, int64_t population, int64_t* out, bool replace);\n\ntemplate <typename IdxType, typename FloatType>\nvoid RandomEngine::BiasedChoice(\n    IdxType num, const IdxType* split, FloatArray bias, IdxType* out,\n    bool replace) {\n  const int64_t num_tags = bias->shape[0];\n  const FloatType* bias_data = static_cast<FloatType*>(bias->data);\n  IdxType total_node_num = 0;\n  FloatArray prob = 
NDArray::Empty({num_tags}, bias->dtype, bias->ctx);\n  FloatType* prob_data = static_cast<FloatType*>(prob->data);\n  for (int64_t tag = 0; tag < num_tags; ++tag) {\n    int64_t tag_num_nodes = split[tag + 1] - split[tag];\n    total_node_num += tag_num_nodes;\n    FloatType tag_bias = bias_data[tag];\n    prob_data[tag] = tag_num_nodes * tag_bias;\n  }\n  if (replace) {\n    auto sampler = utils::TreeSampler<IdxType, FloatType, true>(this, prob);\n    for (IdxType i = 0; i < num; ++i) {\n      const int64_t tag = sampler.Draw();\n      const IdxType tag_num_nodes = split[tag + 1] - split[tag];\n      out[i] = RandInt(tag_num_nodes) + split[tag];\n    }\n  } else {\n    utils::TreeSampler<int64_t, FloatType, false> sampler(\n        this, prob, bias_data);\n    CHECK_GE(total_node_num, num)\n        << \"Cannot take more sample than population when 'replace=false'\";\n    // we use hash set here. Maybe in the future we should support reservoir\n    // algorithm\n    std::vector<std::unordered_set<IdxType>> selected(num_tags);\n    for (IdxType i = 0; i < num; ++i) {\n      const int64_t tag = sampler.Draw();\n      bool inserted = false;\n      const IdxType tag_num_nodes = split[tag + 1] - split[tag];\n      IdxType selected_node;\n      while (!inserted) {\n        CHECK_LT(selected[tag].size(), tag_num_nodes)\n            << \"Cannot take more sample than population when 'replace=false'\";\n        selected_node = RandInt(tag_num_nodes);\n        inserted = selected[tag].insert(selected_node).second;\n      }\n      out[i] = selected_node + split[tag];\n    }\n  }\n}\n\ntemplate void RandomEngine::BiasedChoice<int32_t, float>(\n    int32_t, const int32_t*, FloatArray, int32_t*, bool);\ntemplate void RandomEngine::BiasedChoice<int32_t, double>(\n    int32_t, const int32_t*, FloatArray, int32_t*, bool);\ntemplate void RandomEngine::BiasedChoice<int64_t, float>(\n    int64_t, const int64_t*, FloatArray, int64_t*, bool);\ntemplate void 
RandomEngine::BiasedChoice<int64_t, double>(\n    int64_t, const int64_t*, FloatArray, int64_t*, bool);\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/random/cpu/sample_utils.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file dgl/sample_utils.h\n * @brief Sampling utilities\n */\n#ifndef DGL_RANDOM_CPU_SAMPLE_UTILS_H_\n#define DGL_RANDOM_CPU_SAMPLE_UTILS_H_\n\n#include <dgl/array.h>\n#include <dgl/random.h>\n\n#include <algorithm>\n#include <cmath>\n#include <cstdlib>\n#include <limits>\n#include <numeric>\n#include <queue>\n#include <utility>\n#include <vector>\n\nnamespace dgl {\nnamespace utils {\n\n/** @brief Base sampler class */\ntemplate <typename Idx>\nclass BaseSampler {\n public:\n  virtual ~BaseSampler() = default;\n  /** @brief Draw one integer sample */\n  virtual Idx Draw() {\n    LOG(INFO) << \"Not implemented yet.\";\n    return 0;\n  }\n};\n\n// (BarclayII 2022.9.20) Changing the internal data type of probabilities to\n// double since we are using non-uniform sampling to sample on boolean masks,\n// where False represents probability 0.  DType could be uint8 in this case,\n// which will give incorrect arithmetic results due to overflowing and/or\n// integer division.\n\n/**\n * AliasSampler is used to sample elements from a given discrete categorical\n * distribution. 
Algorithm: Alias\n * Method(https://en.wikipedia.org/wiki/Alias_method) Sampler building\n * complexity: O(n) Sample w/ replacement complexity: O(1) Sample w/o\n * replacement complexity: O(log n)\n */\ntemplate <typename Idx, typename DType, bool replace>\nclass AliasSampler : public BaseSampler<Idx> {\n private:\n  RandomEngine *re;\n  Idx N;\n  double accum, taken;    // accumulated likelihood\n  std::vector<Idx> K;     // alias table\n  std::vector<double> U;  // probability table\n  FloatArray _prob;       // category distribution\n  std::vector<bool>\n      used;  // indicate availability, activated when replace=false;\n  std::vector<Idx> id_mapping;  // index mapping, activated when replace=false;\n\n  inline Idx Map(Idx x) const {  // Map consecutive indices to unused elements\n    if (replace)\n      return x;\n    else\n      return id_mapping[x];\n  }\n\n  void Reconstruct(FloatArray prob) {  // Reconstruct alias table\n    const int64_t prob_size = prob->shape[0];\n    const DType *prob_data = prob.Ptr<DType>();\n    N = 0;\n    accum = 0.;\n    taken = 0.;\n    if (!replace) id_mapping.clear();\n    for (Idx i = 0; i < prob_size; ++i)\n      if (!used[i]) {\n        N++;\n        accum += prob_data[i];\n        if (!replace) id_mapping.push_back(i);\n      }\n    if (N == 0)\n      LOG(FATAL)\n          << \"Cannot take more sample than population when 'replace=false'\";\n    K.resize(N);\n    U.resize(N);\n    double avg = accum / static_cast<double>(N);\n    std::fill(U.begin(), U.end(), avg);  // initialize U\n    std::queue<std::pair<Idx, double> > under, over;\n    for (Idx i = 0; i < N; ++i) {\n      double p = prob_data[Map(i)];\n      if (p > avg)\n        over.push(std::make_pair(i, p));\n      else\n        under.push(std::make_pair(i, p));\n      K[i] = i;  // initialize K\n    }\n    while (!under.empty() && !over.empty()) {\n      auto u_pair = under.front(), o_pair = over.front();\n      Idx i_u = u_pair.first, i_o = o_pair.first;\n      
double p_u = u_pair.second, p_o = o_pair.second;\n      K[i_u] = i_o;\n      U[i_u] = p_u;\n      if (p_o + p_u > 2 * avg)\n        over.push(std::make_pair(i_o, p_o + p_u - avg));\n      else if (p_o + p_u < 2 * avg)\n        under.push(std::make_pair(i_o, p_o + p_u - avg));\n      under.pop();\n      over.pop();\n    }\n  }\n\n public:\n  void ResetState(FloatArray prob) {\n    used.resize(prob->shape[0]);\n    if (!replace) _prob = prob;\n    std::fill(used.begin(), used.end(), false);\n    Reconstruct(prob);\n  }\n\n  explicit AliasSampler(RandomEngine *re, FloatArray prob) : re(re) {\n    ResetState(prob);\n  }\n\n  ~AliasSampler() {}\n\n  Idx Draw() {\n    if (!replace) {\n      const DType *_prob_data = _prob.Ptr<DType>();\n      if (2 * taken >= accum) Reconstruct(_prob);\n      if (accum <= 0) return -1;\n      // accum changes after Reconstruct(), so avg should be computed after\n      // that.\n      double avg = accum / N;\n      while (true) {\n        double dice = re->Uniform<double>(0, N);\n        Idx i = static_cast<Idx>(dice), rst;\n        double p = (dice - i) * avg;\n        if (p <= U[i]) {\n          rst = Map(i);\n        } else {\n          rst = Map(K[i]);\n        }\n        double cap = _prob_data[rst];\n        if (!used[rst]) {\n          used[rst] = true;\n          taken += cap;\n          return rst;\n        }\n      }\n    }\n    if (accum <= 0) return -1;\n    double avg = accum / N;\n    double dice = re->Uniform<double>(0, N);\n    Idx i = static_cast<Idx>(dice);\n    double p = (dice - i) * avg;\n    if (p <= U[i])\n      return Map(i);\n    else\n      return Map(K[i]);\n  }\n};\n\n/**\n * CDFSampler is used to sample elements from a given discrete categorical\n * distribution. Algorithm: create a cumulative distribution function and\n * conduct binary search for sampling. 
Reference:\n * https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804\n * Sampler building complexity: O(n)\n * Sample w/ and w/o replacement complexity: O(log n)\n */\ntemplate <typename Idx, typename DType, bool replace>\nclass CDFSampler : public BaseSampler<Idx> {\n private:\n  RandomEngine *re;\n  Idx N;\n  double accum, taken;\n  FloatArray _prob;         // categorical distribution\n  std::vector<double> cdf;  // cumulative distribution function\n  std::vector<bool>\n      used;  // indicate availability, activated when replace=false;\n  std::vector<Idx>\n      id_mapping;  // indicate index mapping, activated when replace=false;\n\n  inline Idx Map(Idx x) const {  // Map consecutive indices to unused elements\n    if (replace)\n      return x;\n    else\n      return id_mapping[x];\n  }\n\n  void Reconstruct(FloatArray prob) {  // Reconstruct CDF\n    int64_t prob_size = prob->shape[0];\n    const DType *prob_data = prob.Ptr<DType>();\n    N = 0;\n    accum = 0.;\n    taken = 0.;\n    if (!replace) id_mapping.clear();\n    cdf.clear();\n    cdf.push_back(0);\n    for (Idx i = 0; i < prob_size; ++i)\n      if (!used[i]) {\n        N++;\n        accum += prob_data[i];\n        if (!replace) id_mapping.push_back(i);\n        cdf.push_back(accum);\n      }\n    if (N == 0)\n      LOG(FATAL)\n          << \"Cannot take more sample than population when 'replace=false'\";\n  }\n\n public:\n  void ResetState(FloatArray prob) {\n    used.resize(prob->shape[0]);\n    if (!replace) _prob = prob;\n    std::fill(used.begin(), used.end(), false);\n    Reconstruct(prob);\n  }\n\n  explicit CDFSampler(RandomEngine *re, FloatArray prob) : re(re) {\n    ResetState(prob);\n  }\n\n  ~CDFSampler() {}\n\n  Idx Draw() {\n    double eps = std::numeric_limits<double>::min();\n    if (!replace) {\n      const DType *_prob_data = _prob.Ptr<DType>();\n      if (2 * taken >= accum) Reconstruct(_prob);\n      if (accum <= 0) return -1;\n      while (true) {\n        
double p = std::max(re->Uniform<double>(0., accum), eps);\n        Idx rst =\n            Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);\n        double cap = static_cast<double>(_prob_data[rst]);\n        if (!used[rst]) {\n          used[rst] = true;\n          taken += cap;\n          return rst;\n        }\n      }\n    }\n    if (accum <= 0) return -1;\n    double p = std::max(re->Uniform<double>(0., accum), eps);\n    return Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);\n  }\n};\n\n/**\n * TreeSampler is used to sample elements from a given discrete categorical\n * distribution. Algorithm: create a heap that stores accumulated likelihood of\n * its leaf descendents. Reference: https://blog.smola.org/post/1016514759\n * Sampler building complexity: O(n)\n * Sample w/ and w/o replacement complexity: O(log n)\n */\ntemplate <typename Idx, typename DType, bool replace>\nclass TreeSampler : public BaseSampler<Idx> {\n private:\n  RandomEngine *re;\n  std::vector<double> weight;  // accumulated likelihood of subtrees.\n  int64_t N;\n  int64_t num_leafs;\n  const DType *decrease;\n\n public:\n  void ResetState(FloatArray prob) {\n    int64_t prob_size = prob->shape[0];\n    const DType *prob_data = prob.Ptr<DType>();\n    std::fill(weight.begin(), weight.end(), 0);\n    for (int64_t i = 0; i < prob_size; ++i)\n      weight[num_leafs + i] = prob_data[i];\n    for (int64_t i = num_leafs - 1; i >= 1; --i)\n      weight[i] = weight[i * 2] + weight[i * 2 + 1];\n  }\n\n  explicit TreeSampler(\n      RandomEngine *re, FloatArray prob, const DType *decrease = nullptr)\n      : re(re), decrease(decrease) {\n    num_leafs = 1;\n    while (num_leafs < prob->shape[0]) num_leafs *= 2;\n    N = num_leafs * 2;\n    weight.resize(N);\n    ResetState(prob);\n  }\n\n  /* Pick an element from the given distribution and update the tree.\n   *\n   * The parameter decrease is an array of which the length is the number of\n   * categories. 
Every time an element in the category x is picked, the weight\n   * of this category is subtracted by decrease[x]. It is used to support the\n   * case where a category might contain multiple candidates and decrease[x] is\n   * the weight of one candidate of the category x.\n   *\n   * When decrease == nullptr, it means there is only one candidate in each\n   * category and will directly set the weight of the chosen category as 0.\n   *\n   */\n  Idx Draw() {\n    if (weight[1] <= 0) return -1;\n    int64_t cur = 1;\n    double p = re->Uniform<double>(0, weight[cur]);\n    double accum = 0.;\n    while (cur < num_leafs) {\n      double w_l = weight[cur * 2], w_r = weight[cur * 2 + 1];\n      double pivot = accum + w_l;\n      // w_r > 0 can suppress some numerical problems.\n      Idx shift = static_cast<Idx>(p > pivot && w_r > 0);\n      cur = cur * 2 + shift;\n      if (shift == 1) accum = pivot;\n    }\n    Idx rst = cur - num_leafs;\n    if (!replace) {\n      while (cur >= 1) {\n        if (cur >= num_leafs)\n          weight[cur] =\n              this->decrease\n                  ? weight[cur] - static_cast<double>(this->decrease[rst])\n                  : 0.;\n        else\n          weight[cur] = weight[cur * 2] + weight[cur * 2 + 1];\n        cur /= 2;\n      }\n    }\n    return rst;\n  }\n};\n\n};  // namespace utils\n};  // namespace dgl\n\n#endif  // DGL_RANDOM_CPU_SAMPLE_UTILS_H_\n"
  },
  {
    "path": "src/random/random.cc",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file random.cc\n * @brief Random number generator interfaces\n */\n\n#include <dgl/array.h>\n#include <dgl/random.h>\n#include <dgl/runtime/packed_func.h>\n#include <dgl/runtime/parallel_for.h>\n#include <dgl/runtime/registry.h>\n#include <dmlc/omp.h>\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\n\nDGL_REGISTER_GLOBAL(\"rng._CAPI_SetSeed\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      const int seed = args[0];\n\n      runtime::parallel_for(0, omp_get_max_threads(), [&](size_t b, size_t e) {\n        for (auto i = b; i < e; ++i) {\n          RandomEngine::ThreadLocal()->SetSeed(seed);\n        }\n      });\n    });\n\nDGL_REGISTER_GLOBAL(\"rng._CAPI_Choice\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      const int64_t num = args[0];\n      const int64_t population = args[1];\n      const NDArray prob = args[2];\n      const bool replace = args[3];\n      const int bits = args[4];\n      CHECK(bits == 32 || bits == 64)\n          << \"Supported bit widths are 32 and 64, but got \" << bits << \".\";\n      if (aten::IsNullArray(prob)) {\n        if (bits == 32) {\n          *rv = RandomEngine::ThreadLocal()->UniformChoice<int32_t>(\n              num, population, replace);\n        } else {\n          *rv = RandomEngine::ThreadLocal()->UniformChoice<int64_t>(\n              num, population, replace);\n        }\n      } else {\n        if (bits == 32) {\n          ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, \"probability\", {\n            *rv = RandomEngine::ThreadLocal()->Choice<int32_t, FloatType>(\n                num, prob, replace);\n          });\n        } else {\n          ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, \"probability\", {\n            *rv = RandomEngine::ThreadLocal()->Choice<int64_t, FloatType>(\n                num, prob, replace);\n          });\n        }\n      }\n    });\n\n};  // namespace dgl\n"
  },
  {
    "path": "src/rpc/network/common.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file common.cc\n * @brief This file provide basic facilities for string\n * to make programming convenient.\n */\n#include \"common.h\"\n\n#include <stdarg.h>\n#include <stdio.h>\n\nusing std::string;\n\nnamespace dgl {\nnamespace network {\n\n// In most cases, delim contains only one character. In this case, we\n// use CalculateReserveForVector to count the number of elements should\n// be reserved in result vector, and thus optimize SplitStringUsing.\nstatic int CalculateReserveForVector(\n    const std::string& full, const char* delim) {\n  int count = 0;\n  if (delim[0] != '\\0' && delim[1] == '\\0') {\n    // Optimize the common case where delim is a single character.\n    char c = delim[0];\n    const char* p = full.data();\n    const char* end = p + full.size();\n    while (p != end) {\n      if (*p == c) {  // This could be optimized with hasless(v,1) trick.\n        ++p;\n      } else {\n        while (++p != end && *p != c) {\n          // Skip to the next occurence of the delimiter.\n        }\n        ++count;\n      }\n    }\n  }\n  return count;\n}\n\nvoid SplitStringUsing(\n    const std::string& full, const char* delim,\n    std::vector<std::string>* result) {\n  CHECK(delim != NULL);\n  CHECK(result != NULL);\n  result->reserve(CalculateReserveForVector(full, delim));\n  back_insert_iterator<std::vector<std::string> > it(*result);\n  SplitStringToIteratorUsing(full, delim, &it);\n}\n\nvoid SplitStringToSetUsing(\n    const std::string& full, const char* delim, std::set<std::string>* result) {\n  CHECK(delim != NULL);\n  CHECK(result != NULL);\n  simple_insert_iterator<std::set<std::string> > it(result);\n  SplitStringToIteratorUsing(full, delim, &it);\n}\n\nstatic void StringAppendV(string* dst, const char* format, va_list ap) {\n  // First try with a small fixed size buffer\n  char space[1024];\n  // It's possible for methods that use a va_list to invalidate\n  // the data in it upon 
use.  The fix is to make a copy\n  // of the structure before using it and use that copy instead.\n  va_list backup_ap;\n  va_copy(backup_ap, ap);\n  int result = vsnprintf(space, sizeof(space), format, backup_ap);\n  va_end(backup_ap);\n\n  if ((result >= 0) && (result < static_cast<int>(sizeof(space)))) {\n    // It fit\n    dst->append(space, result);\n    return;\n  }\n\n  // Repeatedly increase buffer size until it fits\n  int length = sizeof(space);\n  while (true) {\n    if (result < 0) {\n      // Older behavior: just try doubling the buffer size\n      length *= 2;\n    } else {\n      // We need exactly \"result+1\" characters\n      length = result + 1;\n    }\n    char* buf = new char[length];\n\n    // Restore the va_list before we use it again\n    va_copy(backup_ap, ap);\n    result = vsnprintf(buf, length, format, backup_ap);\n    va_end(backup_ap);\n\n    if ((result >= 0) && (result < length)) {\n      // It fit\n      dst->append(buf, result);\n      delete[] buf;\n      return;\n    }\n    delete[] buf;\n  }\n}\n\nstring StringPrintf(const char* format, ...) {\n  va_list ap;\n  va_start(ap, format);\n  string result;\n  StringAppendV(&result, format, ap);\n  va_end(ap);\n  return result;\n}\n\nvoid SStringPrintf(string* dst, const char* format, ...) {\n  va_list ap;\n  va_start(ap, format);\n  dst->clear();\n  StringAppendV(dst, format, ap);\n  va_end(ap);\n}\n\nvoid StringAppendF(string* dst, const char* format, ...) {\n  va_list ap;\n  va_start(ap, format);\n  StringAppendV(dst, format, ap);\n  va_end(ap);\n}\n\n}  // namespace network\n}  // namespace dgl\n"
  },
  {
    "path": "src/rpc/network/common.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file common.h\n * @brief This file provide basic facilities for string\n * to make programming convenient.\n */\n#ifndef DGL_RPC_NETWORK_COMMON_H_\n#define DGL_RPC_NETWORK_COMMON_H_\n\n#include <dmlc/logging.h>\n\n#include <set>\n#include <string>\n#include <vector>\n\nnamespace dgl {\nnamespace network {\n\n//------------------------------------------------------------------------------\n// Subdivide string |full| into substrings according to delimitors\n// given in |delim|.  |delim| should pointing to a string including\n// one or more characters.  Each character is considerred a possible\n// delimitor. For example:\n//\n//   vector<string> substrings;\n//   SplitStringUsing(\"apple orange\\tbanana\", \"\\t \", &substrings);\n//\n// results in three substrings:\n//\n//   substrings.size() == 3\n//   substrings[0] == \"apple\"\n//   substrings[1] == \"orange\"\n//   substrings[2] == \"banana\"\n//------------------------------------------------------------------------------\n\nvoid SplitStringUsing(\n    const std::string& full, const char* delim,\n    std::vector<std::string>* result);\n\n// This function has the same semnatic as SplitStringUsing.  
Results\n// are saved in an STL set container.\nvoid SplitStringToSetUsing(\n    const std::string& full, const char* delim, std::set<std::string>* result);\n\ntemplate <typename T>\nstruct simple_insert_iterator {\n  explicit simple_insert_iterator(T* t) : t_(t) {}\n\n  simple_insert_iterator<T>& operator=(const typename T::value_type& value) {\n    t_->insert(value);\n    return *this;\n  }\n\n  simple_insert_iterator<T>& operator*() { return *this; }\n  simple_insert_iterator<T>& operator++() { return *this; }\n  simple_insert_iterator<T>& operator++(int placeholder) { return *this; }\n\n  T* t_;\n};\n\ntemplate <typename T>\nstruct back_insert_iterator {\n  explicit back_insert_iterator(T& t) : t_(t) {}\n\n  back_insert_iterator<T>& operator=(const typename T::value_type& value) {\n    t_.push_back(value);\n    return *this;\n  }\n\n  back_insert_iterator<T>& operator*() { return *this; }\n  back_insert_iterator<T>& operator++() { return *this; }\n  back_insert_iterator<T> operator++(int placeholder) { return *this; }\n\n  T& t_;\n};\n\ntemplate <typename StringType, typename ITR>\nstatic inline void SplitStringToIteratorUsing(\n    const StringType& full, const char* delim, ITR* result) {\n  CHECK_NOTNULL(delim);\n  // Optimize the common case where delim is a single character.\n  if (delim[0] != '\\0' && delim[1] == '\\0') {\n    char c = delim[0];\n    const char* p = full.data();\n    const char* end = p + full.size();\n    while (p != end) {\n      if (*p == c) {\n        ++p;\n      } else {\n        const char* start = p;\n        while (++p != end && *p != c) {\n          // Skip to the next occurence of the delimiter.\n        }\n        *(*result)++ = StringType(start, p - start);\n      }\n    }\n    return;\n  }\n\n  std::string::size_type begin_index, end_index;\n  begin_index = full.find_first_not_of(delim);\n  while (begin_index != std::string::npos) {\n    end_index = full.find_first_of(delim, begin_index);\n    if (end_index == 
std::string::npos) {\n      *(*result)++ = full.substr(begin_index);\n      return;\n    }\n    *(*result)++ = full.substr(begin_index, (end_index - begin_index));\n    begin_index = full.find_first_not_of(delim, end_index);\n  }\n}\n\n//------------------------------------------------------------------------------\n// StringPrintf:\n//\n// For example:\n//\n//  std::string str = StringPrintf(\"%d\", 1);    /* str = \"1\"  */\n//  SStringPrintf(&str, \"%d\", 2);               /* str = \"2\"  */\n//  StringAppendF(&str, \"%d\", 3);               /* str = \"23\" */\n//------------------------------------------------------------------------------\n\nstd::string StringPrintf(const char* format, ...);\nvoid SStringPrintf(std::string* dst, const char* format, ...);\nvoid StringAppendF(std::string* dst, const char* format, ...);\n\n}  // namespace network\n}  // namespace dgl\n\n#endif  // DGL_RPC_NETWORK_COMMON_H_\n"
  },
  {
    "path": "src/rpc/network/communicator.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file communicator.h\n * @brief Communicator for DGL distributed training.\n */\n#ifndef DGL_RPC_NETWORK_COMMUNICATOR_H_\n#define DGL_RPC_NETWORK_COMMUNICATOR_H_\n\n#include <dmlc/logging.h>\n\n#include <string>\n\n#include \"msg_queue.h\"\n\nnamespace dgl {\nnamespace network {\n\n/**\n * @brief Network Sender for DGL distributed training.\n *\n * Sender is an abstract class that defines a set of APIs for sending binary\n * data message over network. It can be implemented by different underlying\n * networking libraries such TCP socket and MPI. One Sender can connect to\n * multiple receivers and it can send data to specified receiver via receiver's\n * ID.\n */\nclass Sender {\n public:\n  /**\n   * @brief Sender constructor\n   * @param queue_size size (bytes) of message queue.\n   * @param max_thread_count size of thread pool. 0 for no limit\n   * Note that, the queue_size parameter is optional.\n   */\n  explicit Sender(int64_t queue_size = 0, int max_thread_count = 0) {\n    CHECK_GE(queue_size, 0);\n    CHECK_GE(max_thread_count, 0);\n    queue_size_ = queue_size;\n    max_thread_count_ = max_thread_count;\n  }\n\n  virtual ~Sender() {}\n\n  /**\n   * @brief Send data to specified Receiver.\n   * @param msg data message\n   * @param recv_id receiver's ID\n   * @return Status code\n   *\n   * (1) The send is non-blocking. There is no guarantee that the message has\n   * been physically sent out when the function returns. (2) The communicator\n   * will assume the responsibility of the given message. (3) The API is\n   * multi-thread safe. (4) Messages sent to the same receiver are guaranteed to\n   * be received in the same order. There is no guarantee for messages sent to\n   * different receivers.\n   */\n  virtual STATUS Send(Message msg, int recv_id) = 0;\n\n protected:\n  /**\n   * @brief Size of message queue\n   */\n  int64_t queue_size_;\n  /**\n   * @brief Size of thread pool. 
0 for no limit\n   */\n  int max_thread_count_;\n};\n\n/**\n * @brief Network Receiver for DGL distributed training.\n *\n * Receiver is an abstract class that defines a set of APIs for receiving binary\n * data message over network. It can be implemented by different underlying\n * networking libraries such as TCP socket and MPI. One Receiver can connect\n * with multiple Senders and it can receive data from multiple Senders\n * concurrently.\n */\nclass Receiver {\n public:\n  /**\n   * @brief Receiver constructor\n   * @param queue_size size of message queue.\n   * @param max_thread_count size of thread pool. 0 for no limit\n   * Note that, the queue_size parameter is optional.\n   */\n  explicit Receiver(int64_t queue_size = 0, int max_thread_count = 0) {\n    if (queue_size < 0) {\n      LOG(FATAL) << \"queue_size cannot be a negative number.\";\n    }\n    CHECK_GE(max_thread_count, 0);\n    queue_size_ = queue_size;\n    max_thread_count_ = max_thread_count;\n  }\n\n  virtual ~Receiver() {}\n\n  /**\n   * @brief Recv data from Sender\n   * @param msg pointer of data message\n   * @param send_id which sender current msg comes from\n   * @param timeout The timeout value in milliseconds. If zero, wait\n   * indefinitely.\n   * @return Status code\n   *\n   * (1) The Recv() API is thread-safe.\n   * (2) Memory allocated by communicator but will not own it after the function\n   * returns.\n   */\n  virtual STATUS Recv(Message* msg, int* send_id, int timeout = 0) = 0;\n\n  /**\n   * @brief Recv data from a specified Sender\n   * @param msg pointer of data message\n   * @param send_id sender's ID\n   * @param timeout The timeout value in milliseconds. 
If zero, wait\n   * indefinitely.\n   * @return Status code\n   *\n   * (1) The RecvFrom() API is thread-safe.\n   * (2) Memory is allocated by the communicator, but the communicator will not\n   * own it after the function returns.\n   */\n  virtual STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) = 0;\n\n protected:\n  /**\n   * @brief Size of message queue\n   */\n  int64_t queue_size_;\n  /**\n   * @brief Size of thread pool. 0 for no limit\n   */\n  int max_thread_count_;\n};\n\n}  // namespace network\n}  // namespace dgl\n\n#endif  // DGL_RPC_NETWORK_COMMUNICATOR_H_\n"
  },
  {
    "path": "src/rpc/network/msg_queue.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file msg_queue.cc\n * @brief Message queue for DGL distributed training.\n */\n#include \"msg_queue.h\"\n\n#include <dmlc/logging.h>\n\n#include <cstring>\n\nnamespace dgl {\nnamespace network {\n\nusing std::string;\n\nMessageQueue::MessageQueue(int64_t queue_size, int num_producers) {\n  CHECK_GE(queue_size, 0);\n  CHECK_GE(num_producers, 0);\n  queue_size_ = queue_size;\n  free_size_ = queue_size;\n  num_producers_ = num_producers;\n}\n\nSTATUS MessageQueue::Add(Message msg, bool is_blocking) {\n  // check if message is too long to fit into the queue\n  if (msg.size > queue_size_) {\n    LOG(WARNING) << \"Message is larger than the queue.\";\n    return MSG_GT_SIZE;\n  }\n  if (msg.size <= 0) {\n    LOG(WARNING) << \"Message size (\" << msg.size << \") is negative or zero.\";\n    return MSG_LE_ZERO;\n  }\n  std::unique_lock<std::mutex> lock(mutex_);\n  if (finished_producers_.size() >= num_producers_) {\n    return QUEUE_CLOSE;\n  }\n  if (msg.size > free_size_ && !is_blocking) {\n    return QUEUE_FULL;\n  }\n  cond_not_full_.wait(lock, [&]() { return msg.size <= free_size_; });\n  // Add data pointer to queue\n  queue_.push(msg);\n  free_size_ -= msg.size;\n  // not empty signal\n  cond_not_empty_.notify_one();\n\n  return ADD_SUCCESS;\n}\n\nSTATUS MessageQueue::Remove(Message* msg, bool is_blocking) {\n  std::unique_lock<std::mutex> lock(mutex_);\n  if (queue_.empty()) {\n    if (!is_blocking) {\n      return QUEUE_EMPTY;\n    }\n    if (finished_producers_.size() >= num_producers_) {\n      return QUEUE_CLOSE;\n    }\n  }\n\n  cond_not_empty_.wait(\n      lock, [this] { return !queue_.empty() || exit_flag_.load(); });\n  if (finished_producers_.size() >= num_producers_ && queue_.empty()) {\n    return QUEUE_CLOSE;\n  }\n\n  Message old_msg = queue_.front();\n  queue_.pop();\n  msg->data = old_msg.data;\n  msg->size = old_msg.size;\n  msg->receiver_id = old_msg.receiver_id;\n  msg->deallocator = 
old_msg.deallocator;\n  free_size_ += old_msg.size;\n  cond_not_full_.notify_one();\n\n  return REMOVE_SUCCESS;\n}\n\nvoid MessageQueue::SignalFinished(int producer_id) {\n  std::lock_guard<std::mutex> lock(mutex_);\n  finished_producers_.insert(producer_id);\n  // if all producers have finished, consumers should be\n  // woken up to get this signal\n  if (finished_producers_.size() >= num_producers_) {\n    exit_flag_.store(true);\n    cond_not_empty_.notify_all();\n  }\n}\n\nbool MessageQueue::Empty() const {\n  std::lock_guard<std::mutex> lock(mutex_);\n  return queue_.size() == 0;\n}\n\nbool MessageQueue::EmptyAndNoMoreAdd() const {\n  std::lock_guard<std::mutex> lock(mutex_);\n  return queue_.size() == 0 && finished_producers_.size() >= num_producers_;\n}\n\n}  // namespace network\n}  // namespace dgl\n"
  },
  {
    "path": "src/rpc/network/msg_queue.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file msg_queue.h\n * @brief Message queue for DGL distributed training.\n */\n#ifndef DGL_RPC_NETWORK_MSG_QUEUE_H_\n#define DGL_RPC_NETWORK_MSG_QUEUE_H_\n\n#include <dgl/runtime/ndarray.h>\n\n#include <atomic>\n#include <condition_variable>\n#include <functional>\n#include <mutex>\n#include <queue>\n#include <set>\n#include <string>\n#include <utility>  // for pair\n\nnamespace dgl {\nnamespace network {\n\ntypedef int STATUS;\n\n/**\n * @brief Status code of message queue\n */\n#define ADD_SUCCESS 3400     // Add message successfully\n#define MSG_GT_SIZE 3401     // Message size beyond queue size\n#define MSG_LE_ZERO 3402     // Message size is not a positive number\n#define QUEUE_CLOSE 3403     // Cannot add message when queue is closed\n#define QUEUE_FULL 3404      // Cannot add message when queue is full\n#define REMOVE_SUCCESS 3405  // Remove message successfully\n#define QUEUE_EMPTY 3406     // Cannot remove when queue is empty\n\n/**\n * @brief Message used by network communicator and message queue.\n */\nstruct Message {\n  /**\n   * @brief Constructor\n   */\n  Message() {}\n\n  /**\n   * @brief Constructor\n   */\n  Message(char* data_ptr, int64_t data_size)\n      : data(data_ptr), size(data_size) {}\n\n  /**\n   * @brief message data\n   */\n  char* data;\n  /**\n   * @brief message size in bytes\n   */\n  int64_t size;\n  /**\n   * @brief message receiver id\n   */\n  int receiver_id = -1;\n  /**\n   * @brief user-defined deallocator, which can be nullptr\n   */\n  std::function<void(Message*)> deallocator = nullptr;\n};\n\n/**\n * @brief Free memory buffer of message\n */\ninline void DefaultMessageDeleter(Message* msg) { delete[] msg->data; }\n\n/**\n * @brief Message Queue for network communication.\n *\n * MessageQueue is FIFO queue that adopts producer/consumer model for data\n * message. It supports one or more producer threads and one or more consumer\n * threads. 
Producers invokes Add() to push data message into the queue, and\n * consumers invokes Remove() to pop data message from queue. Add() and Remove()\n * use two condition variables to synchronize producer threads and consumer\n * threads. Each producer invokes SignalFinished(producer_id) to claim that it\n * is about to finish, where producer_id is an integer uniquely identify a\n * producer thread. This signaling mechanism prevents consumers from waiting\n * after all producers have finished their jobs.\n *\n * MessageQueue is thread-safe.\n *\n */\nclass MessageQueue {\n public:\n  /**\n   * @brief MessageQueue constructor\n   * @param queue_size size (bytes) of message queue\n   * @param num_producers number of producers, use 1 by default\n   */\n  explicit MessageQueue(\n      int64_t queue_size /* in bytes */, int num_producers = 1);\n\n  /**\n   * @brief MessageQueue deconstructor\n   */\n  ~MessageQueue() {}\n\n  /**\n   * @brief Add message to the queue\n   * @param msg data message\n   * @param is_blocking Blocking if cannot add, else return\n   * @return Status code\n   */\n  STATUS Add(Message msg, bool is_blocking = true);\n\n  /**\n   * @brief Remove message from the queue\n   * @param msg pointer of data msg\n   * @param is_blocking Blocking if cannot remove, else return\n   * @return Status code\n   */\n  STATUS Remove(Message* msg, bool is_blocking = true);\n\n  /**\n   * @brief Signal that producer producer_id will no longer produce anything\n   * @param producer_id An integer uniquely to identify a producer thread\n   */\n  void SignalFinished(int producer_id);\n\n  /**\n   * @return true if queue is empty.\n   */\n  bool Empty() const;\n\n  /**\n   * @return true if queue is empty and all num_producers have signaled.\n   */\n  bool EmptyAndNoMoreAdd() const;\n\n protected:\n  /**\n   * @brief message queue\n   */\n  std::queue<Message> queue_;\n\n  /**\n   * @brief Size of the queue in bytes\n   */\n  int64_t queue_size_;\n\n  /**\n   * @brief Free 
size of the queue\n   */\n  int64_t free_size_;\n\n  /**\n   * @brief Used to check all producers will no longer produce anything\n   */\n  size_t num_producers_;\n\n  /**\n   * @brief Store finished producer id\n   */\n  std::set<int /* producer_id */> finished_producers_;\n\n  /**\n   * @brief Condition when producer should wait\n   */\n  std::condition_variable cond_not_full_;\n\n  /**\n   * @brief Condition when consumer should wait\n   */\n  std::condition_variable cond_not_empty_;\n\n  /**\n   * @brief Signal for exit wait\n   */\n  std::atomic<bool> exit_flag_{false};\n\n  /**\n   * @brief Protect all above data and conditions\n   */\n  mutable std::mutex mutex_;\n};\n\n}  // namespace network\n}  // namespace dgl\n\n#endif  // DGL_RPC_NETWORK_MSG_QUEUE_H_\n"
  },
  {
    "path": "src/rpc/network/socket_communicator.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file communicator.cc\n * @brief SocketCommunicator for DGL distributed training.\n */\n#include \"socket_communicator.h\"\n\n#include <dmlc/logging.h>\n#include <stdlib.h>\n#include <string.h>\n#include <time.h>\n\n#include <memory>\n\n#include \"../../c_api_common.h\"\n#include \"socket_pool.h\"\n\n#ifdef _WIN32\n#include <windows.h>\n#else  // !_WIN32\n#include <unistd.h>\n#endif  // _WIN32\n\nnamespace dgl {\nnamespace network {\n\n/////////////////////////////////////// SocketSender\n//////////////////////////////////////////////\n\nbool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {\n  if (recv_id < 0) {\n    LOG(FATAL) << \"recv_id cannot be a negative number.\";\n  }\n  std::vector<std::string> substring;\n  std::vector<std::string> ip_and_port;\n  SplitStringUsing(addr, \"//\", &substring);\n  // Check address format\n  if (substring[0] != \"tcp:\" || substring.size() != 2) {\n    LOG(FATAL) << \"Incorrect address format:\" << addr\n               << \" Please provide right address format, \"\n               << \"e.g, 'tcp://127.0.0.1:50051'. \";\n  }\n  // Get IP and port\n  SplitStringUsing(substring[1], \":\", &ip_and_port);\n  if (ip_and_port.size() != 2) {\n    LOG(FATAL) << \"Incorrect address format:\" << addr\n               << \" Please provide right address format, \"\n               << \"e.g, 'tcp://127.0.0.1:50051'. 
\";\n  }\n  IPAddr address;\n  address.ip = ip_and_port[0];\n  address.port = std::stoi(ip_and_port[1]);\n  receiver_addrs_[recv_id] = address;\n\n  return true;\n}\n\nbool SocketSender::ConnectReceiverFinalize(const int max_try_times) {\n  // Create N sockets for Receiver\n  int receiver_count = static_cast<int>(receiver_addrs_.size());\n  if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {\n    max_thread_count_ = receiver_count;\n  }\n  sockets_.resize(max_thread_count_);\n  for (const auto& r : receiver_addrs_) {\n    int receiver_id = r.first;\n    int thread_id = receiver_id % max_thread_count_;\n    sockets_[thread_id][receiver_id] = std::make_shared<TCPSocket>();\n    TCPSocket* client_socket = sockets_[thread_id][receiver_id].get();\n    bool bo = false;\n    int try_count = 0;\n    const char* ip = r.second.ip.c_str();\n    int port = r.second.port;\n    while (bo == false && try_count < max_try_times) {\n      if (client_socket->Connect(ip, port)) {\n        bo = true;\n      } else {\n        if (try_count % 200 == 0 && try_count != 0) {\n          // every 600 seconds show this message\n          LOG(INFO) << \"Trying to connect receiver: \" << ip << \":\" << port;\n        }\n        try_count++;\n        std::this_thread::sleep_for(std::chrono::seconds(3));\n      }\n    }\n    if (bo == false) {\n      return bo;\n    }\n  }\n\n  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {\n    msg_queue_.push_back(std::make_shared<MessageQueue>(queue_size_));\n    // Create a new thread for this socket connection\n    threads_.push_back(std::make_shared<std::thread>(\n        SendLoop, sockets_[thread_id], msg_queue_[thread_id]));\n  }\n\n  return true;\n}\n\nvoid SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) {\n  std::shared_ptr<std::string> zerocopy_blob(new std::string());\n  StreamWithBuffer zc_write_strm(zerocopy_blob.get(), true);\n  zc_write_strm.Write(msg);\n  int32_t nonempty_ndarray_count = 
zc_write_strm.buffer_list().size();\n  zerocopy_blob->append(\n      reinterpret_cast<char*>(&nonempty_ndarray_count), sizeof(int32_t));\n  Message rpc_meta_msg;\n  rpc_meta_msg.data = const_cast<char*>(zerocopy_blob->data());\n  rpc_meta_msg.size = zerocopy_blob->size();\n  rpc_meta_msg.deallocator = [zerocopy_blob](Message*) {};\n  CHECK_EQ(Send(rpc_meta_msg, recv_id), ADD_SUCCESS);\n  // send real ndarray data\n  for (auto ptr : zc_write_strm.buffer_list()) {\n    Message ndarray_data_msg;\n    ndarray_data_msg.data = reinterpret_cast<char*>(ptr.data);\n    if (ptr.size == 0) {\n      LOG(FATAL) << \"Cannot send a empty NDArray.\";\n    }\n    ndarray_data_msg.size = ptr.size;\n    NDArray tensor = ptr.tensor;\n    ndarray_data_msg.deallocator = [tensor](Message*) {};\n    CHECK_EQ(Send(ndarray_data_msg, recv_id), ADD_SUCCESS);\n  }\n}\n\nSTATUS SocketSender::Send(Message msg, int recv_id) {\n  CHECK_NOTNULL(msg.data);\n  CHECK_GT(msg.size, 0);\n  CHECK_GE(recv_id, 0);\n  msg.receiver_id = recv_id;\n  // Add data message to message queue\n  STATUS code = msg_queue_[recv_id % max_thread_count_]->Add(msg);\n  return code;\n}\n\nvoid SocketSender::Finalize() {\n  // Send a signal to tell the msg_queue to finish its job\n  for (int i = 0; i < max_thread_count_; ++i) {\n    // wait until queue is empty\n    auto& mq = msg_queue_[i];\n    while (mq->Empty() == false) {\n      std::this_thread::sleep_for(std::chrono::seconds(1));\n    }\n    // All queues have only one producer, which is main thread, so\n    // the producerID argument here should be zero.\n    mq->SignalFinished(0);\n  }\n  // Block main thread until all socket-threads finish their jobs\n  for (auto& thread : threads_) {\n    thread->join();\n  }\n  // Clear all sockets\n  for (auto& group_sockets_ : sockets_) {\n    for (auto& socket : group_sockets_) {\n      socket.second->Close();\n    }\n  }\n}\n\nvoid SendCore(Message msg, TCPSocket* socket) {\n  // First send the size\n  // If exit == true, we 
will send zero size to reciever\n  int64_t sent_bytes = 0;\n  while (static_cast<size_t>(sent_bytes) < sizeof(int64_t)) {\n    int64_t max_len = sizeof(int64_t) - sent_bytes;\n    int64_t tmp =\n        socket->Send(reinterpret_cast<char*>(&msg.size) + sent_bytes, max_len);\n    CHECK_NE(tmp, -1);\n    sent_bytes += tmp;\n  }\n  // Then send the data\n  sent_bytes = 0;\n  while (sent_bytes < msg.size) {\n    int64_t max_len = msg.size - sent_bytes;\n    int64_t tmp = socket->Send(msg.data + sent_bytes, max_len);\n    CHECK_NE(tmp, -1);\n    sent_bytes += tmp;\n  }\n  // delete msg\n  if (msg.deallocator != nullptr) {\n    msg.deallocator(&msg);\n  }\n}\n\nvoid SocketSender::SendLoop(\n    std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets,\n    std::shared_ptr<MessageQueue> queue) {\n  for (;;) {\n    Message msg;\n    STATUS code = queue->Remove(&msg);\n    if (code == QUEUE_CLOSE) {\n      msg.size = 0;  // send an end-signal to receiver\n      for (auto& socket : sockets) {\n        SendCore(msg, socket.second.get());\n      }\n      break;\n    }\n    SendCore(msg, sockets[msg.receiver_id].get());\n  }\n}\n\n/////////////////////////////////////// SocketReceiver\n//////////////////////////////////////////////\nbool SocketReceiver::Wait(const std::string& addr, int num_sender) {\n  CHECK_GT(num_sender, 0);\n  std::vector<std::string> substring;\n  std::vector<std::string> ip_and_port;\n  SplitStringUsing(addr, \"//\", &substring);\n  // Check address format\n  if (substring[0] != \"tcp:\" || substring.size() != 2) {\n    LOG(FATAL) << \"Incorrect address format:\" << addr\n               << \" Please provide right address format, \"\n               << \"e.g, 'tcp://127.0.0.1:50051'. 
\";\n  }\n  // Get IP and port\n  SplitStringUsing(substring[1], \":\", &ip_and_port);\n  if (ip_and_port.size() != 2) {\n    LOG(FATAL) << \"Incorrect address format:\" << addr\n               << \" Please provide right address format, \"\n               << \"e.g, 'tcp://127.0.0.1:50051'. \";\n  }\n  std::string ip = ip_and_port[0];\n  int port = stoi(ip_and_port[1]);\n  // Initialize message queue for each connection\n  num_sender_ = num_sender;\n#ifdef USE_EPOLL\n  if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) {\n    max_thread_count_ = num_sender_;\n  }\n#else\n  max_thread_count_ = num_sender_;\n#endif\n  // Initialize socket and socket-thread\n  server_socket_ = new TCPSocket();\n  // Bind socket\n  if (server_socket_->Bind(ip.c_str(), port) == false) {\n    LOG(FATAL) << \"Cannot bind to \" << ip << \":\" << port;\n  }\n\n  // Listen\n  if (server_socket_->Listen(kMaxConnection) == false) {\n    LOG(FATAL) << \"Cannot listen on \" << ip << \":\" << port;\n  }\n  // Accept all sender sockets\n  std::string accept_ip;\n  int accept_port;\n  sockets_.resize(max_thread_count_);\n  for (int i = 0; i < num_sender_; ++i) {\n    int thread_id = i % max_thread_count_;\n    auto socket = std::make_shared<TCPSocket>();\n    sockets_[thread_id][i] = socket;\n    msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);\n    if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) ==\n        false) {\n      LOG(WARNING) << \"Error on accept socket.\";\n      return false;\n    }\n  }\n  mq_iter_ = msg_queue_.begin();\n\n  for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {\n    // create new thread for each socket\n    threads_.push_back(std::make_shared<std::thread>(\n        RecvLoop, sockets_[thread_id], msg_queue_, &queue_sem_));\n  }\n\n  return true;\n}\n\nrpc::RPCStatus SocketReceiver::Recv(rpc::RPCMessage* msg, int timeout) {\n  Message rpc_meta_msg;\n  int send_id;\n  auto status = Recv(&rpc_meta_msg, &send_id, 
timeout);\n  if (status == QUEUE_EMPTY) {\n    DLOG(WARNING) << \"Timed out when trying to receive rpc meta data after \"\n                  << timeout << \" milliseconds.\";\n    return rpc::kRPCTimeOut;\n  }\n  CHECK_EQ(status, REMOVE_SUCCESS);\n  char* count_ptr = rpc_meta_msg.data + rpc_meta_msg.size - sizeof(int32_t);\n  int32_t nonempty_ndarray_count = *(reinterpret_cast<int32_t*>(count_ptr));\n  // Recv real ndarray data\n  std::vector<void*> buffer_list(nonempty_ndarray_count);\n  for (int i = 0; i < nonempty_ndarray_count; ++i) {\n    Message ndarray_data_msg;\n    // As meta message has been received, data message is always expected unless\n    // connection is closed.\n    STATUS status;\n    do {\n      status = RecvFrom(&ndarray_data_msg, send_id, timeout);\n      if (status == QUEUE_EMPTY) {\n        DLOG(WARNING)\n            << \"Timed out when trying to receive rpc ndarray data after \"\n            << timeout << \" milliseconds.\";\n      }\n    } while (status == QUEUE_EMPTY);\n    CHECK_EQ(status, REMOVE_SUCCESS);\n    buffer_list[i] = ndarray_data_msg.data;\n  }\n  StreamWithBuffer zc_read_strm(\n      rpc_meta_msg.data, rpc_meta_msg.size - sizeof(int32_t), buffer_list);\n  zc_read_strm.Read(msg);\n  rpc_meta_msg.deallocator(&rpc_meta_msg);\n  return rpc::kRPCSuccess;\n}\n\nSTATUS SocketReceiver::Recv(Message* msg, int* send_id, int timeout) {\n  // queue_sem_ is a semaphore indicating how many elements in multiple\n  // message queues.\n  // When calling queue_sem_.Wait(), this Recv will be suspended until\n  // queue_sem_ > 0 or specified timeout expires, decrease queue_sem_ by 1,\n  // then start to fetch a message.\n  if (!queue_sem_.TimedWait(timeout)) {\n    return QUEUE_EMPTY;\n  }\n  for (;;) {\n    for (; mq_iter_ != msg_queue_.end(); ++mq_iter_) {\n      STATUS code = mq_iter_->second->Remove(msg, false);\n      if (code == QUEUE_EMPTY) {\n        continue;  // jump to the next queue\n      } else {\n        *send_id = 
mq_iter_->first;\n        ++mq_iter_;\n        return code;\n      }\n    }\n    mq_iter_ = msg_queue_.begin();\n  }\n  LOG(ERROR)\n      << \"Failed to remove message from queue due to unexpected queue status.\";\n  return QUEUE_CLOSE;\n}\n\nSTATUS SocketReceiver::RecvFrom(Message* msg, int send_id, int timeout) {\n  // Get message from specified message queue\n  if (!queue_sem_.TimedWait(timeout)) {\n    return QUEUE_EMPTY;\n  }\n  STATUS code = msg_queue_[send_id]->Remove(msg);\n  return code;\n}\n\nvoid SocketReceiver::Finalize() {\n  // Send a signal to tell the message queue to finish its job\n  for (auto& mq : msg_queue_) {\n    // wait until queue is empty\n    while (mq.second->Empty() == false) {\n      std::this_thread::sleep_for(std::chrono::seconds(1));\n    }\n    mq.second->SignalFinished(mq.first);\n  }\n  // Block main thread until all socket-threads finish their jobs\n  for (auto& thread : threads_) {\n    thread->join();\n  }\n  // Clear all sockets\n  for (auto& group_sockets : sockets_) {\n    for (auto& socket : group_sockets) {\n      socket.second->Close();\n    }\n  }\n  server_socket_->Close();\n  delete server_socket_;\n}\n\nint64_t RecvDataSize(TCPSocket* socket) {\n  int64_t received_bytes = 0;\n  int64_t data_size = 0;\n  while (static_cast<size_t>(received_bytes) < sizeof(int64_t)) {\n    int64_t max_len = sizeof(int64_t) - received_bytes;\n    int64_t tmp = socket->Receive(\n        reinterpret_cast<char*>(&data_size) + received_bytes, max_len);\n    if (tmp == -1) {\n      if (received_bytes > 0) {\n        // We want to finish reading full data_size\n        continue;\n      }\n      return -1;\n    }\n    received_bytes += tmp;\n  }\n  return data_size;\n}\n\nvoid RecvData(\n    TCPSocket* socket, char* buffer, const int64_t& data_size,\n    int64_t* received_bytes) {\n  while (*received_bytes < data_size) {\n    int64_t max_len = data_size - *received_bytes;\n    int64_t tmp = socket->Receive(buffer + *received_bytes, max_len);\n 
   if (tmp == -1) {\n      // Socket not ready, no more data to read\n      return;\n    }\n    *received_bytes += tmp;\n  }\n}\n\nvoid SocketReceiver::RecvLoop(\n    std::unordered_map<\n        int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>\n        sockets,\n    std::unordered_map<\n        int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>\n        queues,\n    runtime::Semaphore* queue_sem) {\n  std::unordered_map<int, std::unique_ptr<RecvContext>> recv_contexts;\n  SocketPool socket_pool;\n  for (auto& socket : sockets) {\n    auto& sender_id = socket.first;\n    socket_pool.AddSocket(socket.second, sender_id);\n    recv_contexts[sender_id] = std::unique_ptr<RecvContext>(new RecvContext());\n  }\n\n  // Main loop to receive messages\n  for (;;) {\n    int sender_id;\n    // Get active socket using epoll\n    std::shared_ptr<TCPSocket> socket = socket_pool.GetActiveSocket(&sender_id);\n    if (queues[sender_id]->EmptyAndNoMoreAdd()) {\n      // This sender has already stopped\n      if (socket_pool.RemoveSocket(socket) == 0) {\n        return;\n      }\n      continue;\n    }\n\n    // Nonblocking socket might be interrupted at any point. 
So we need to\n    // store the partially received data\n    std::unique_ptr<RecvContext>& ctx = recv_contexts[sender_id];\n    int64_t& data_size = ctx->data_size;\n    int64_t& received_bytes = ctx->received_bytes;\n    char*& buffer = ctx->buffer;\n\n    if (data_size == -1) {\n      // This is a new message, so receive the data size first\n      data_size = RecvDataSize(socket.get());\n      if (data_size > 0) {\n        try {\n          buffer = new char[data_size];\n        } catch (const std::bad_alloc&) {\n          LOG(FATAL) << \"Cannot allocate enough memory for message, \"\n                     << \"(message size: \" << data_size << \")\";\n        }\n        received_bytes = 0;\n      } else if (data_size == 0) {\n        // Received stop signal\n        if (socket_pool.RemoveSocket(socket) == 0) {\n          return;\n        }\n      }\n    }\n\n    RecvData(socket.get(), buffer, data_size, &received_bytes);\n    if (received_bytes >= data_size) {\n      // Full data received, create Message and push to queue\n      Message msg;\n      msg.data = buffer;\n      msg.size = data_size;\n      msg.deallocator = DefaultMessageDeleter;\n      queues[sender_id]->Add(msg);\n\n      // Reset recv context\n      data_size = -1;\n\n      // Signal queue semaphore\n      queue_sem->Post();\n    }\n  }\n}\n\n}  // namespace network\n}  // namespace dgl\n"
  },
  {
    "path": "src/rpc/network/socket_communicator.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file communicator.h\n * @brief SocketCommunicator for DGL distributed training.\n */\n#ifndef DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_\n#define DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_\n\n#include <memory>\n#include <string>\n#include <thread>\n#include <unordered_map>\n#include <vector>\n\n#include \"../../runtime/semaphore_wrapper.h\"\n#include \"../rpc_msg.h\"\n#include \"common.h\"\n#include \"communicator.h\"\n#include \"msg_queue.h\"\n#include \"tcp_socket.h\"\n\nnamespace dgl {\nnamespace network {\n\nstatic constexpr int kTimeOut =\n    10 * 60;  // 10 minutes (in seconds) for socket timeout\nstatic constexpr int kMaxConnection = 1024;  // maximal connection: 1024\n\n/**\n * @breif Networking address\n */\nstruct IPAddr {\n  std::string ip;\n  int port;\n};\n\n/**\n * @brief SocketSender for DGL distributed training.\n *\n * SocketSender is the communicator implemented by tcp socket.\n */\nclass SocketSender : public Sender {\n public:\n  /**\n   * @brief Sender constructor\n   * @param queue_size size of message queue\n   * @param max_thread_count size of thread pool. 0 for no limit\n   */\n  SocketSender(int64_t queue_size, int max_thread_count)\n      : Sender(queue_size, max_thread_count) {}\n\n  /**\n   * @brief Connect to a receiver.\n   *\n   * When there are multiple receivers to be connected, application will call\n   * `ConnectReceiver` for each and then call `ConnectReceiverFinalize` to make\n   * sure that either all the connections are successfully established or some\n   * of them fail.\n   *\n   * @param addr Networking address, e.g., 'tcp://127.0.0.1:50091'\n   * @param recv_id receiver's ID\n   * @return True for success and False for fail\n   *\n   * The function is *not* thread-safe; only one thread can invoke this API.\n   */\n  bool ConnectReceiver(const std::string& addr, int recv_id);\n\n  /**\n   * @brief Finalize the action to connect to receivers. 
Make sure that either\n   *        all connections are successfully established or connection fails.\n   * @return True for success and False for fail\n   *\n   * The function is *not* thread-safe; only one thread can invoke this API.\n   */\n  bool ConnectReceiverFinalize(const int max_try_times);\n\n  /**\n   * @brief Send RPCMessage to specified Receiver.\n   * @param msg data message\n   * @param recv_id receiver's ID\n   */\n  void Send(const rpc::RPCMessage& msg, int recv_id);\n\n  /**\n   * @brief Finalize TPSender\n   */\n  void Finalize();\n\n  /**\n   * @brief Send data to specified Receiver. Actually pushing message to message\n   * queue.\n   * @param msg data message.\n   * @param recv_id receiver's ID.\n   * @return Status code.\n   *\n   * (1) The send is non-blocking. There is no guarantee that the message has\n   * been physically sent out when the function returns. (2) The communicator\n   * will assume the responsibility of the given message. (3) The API is\n   * multi-thread safe. (4) Messages sent to the same receiver are guaranteed to\n   * be received in the same order. 
There is no guarantee for messages sent to\n   * different receivers.\n   */\n  STATUS Send(Message msg, int recv_id) override;\n\n private:\n  /**\n   * @brief socket for each connection of receiver\n   */\n  std::vector<\n      std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>>>\n      sockets_;\n\n  /**\n   * @brief receivers' address\n   */\n  std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;\n\n  /**\n   * @brief message queue for each thread\n   */\n  std::vector<std::shared_ptr<MessageQueue>> msg_queue_;\n\n  /**\n   * @brief Independent thread\n   */\n  std::vector<std::shared_ptr<std::thread>> threads_;\n\n  /**\n   * @brief Send-loop for each thread\n   * @param sockets TCPSockets for current thread\n   * @param queue message_queue for current thread\n   *\n   * Note that, the SendLoop will finish its loop-job and exit thread\n   * when the main thread invokes Signal() API on the message queue.\n   */\n  static void SendLoop(\n      std::unordered_map<\n          int /* Receiver (virtual) ID */, std::shared_ptr<TCPSocket>>\n          sockets,\n      std::shared_ptr<MessageQueue> queue);\n};\n\n/**\n * @brief SocketReceiver for DGL distributed training.\n *\n * SocketReceiver is the communicator implemented by tcp socket.\n */\nclass SocketReceiver : public Receiver {\n public:\n  /**\n   * @brief Receiver constructor\n   * @param queue_size size of message queue.\n   * @param max_thread_count size of thread pool. 
0 for no limit\n   */\n  SocketReceiver(int64_t queue_size, int max_thread_count)\n      : Receiver(queue_size, max_thread_count) {}\n\n  /**\n   * @brief Wait for all the Senders to connect\n   * @param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'\n   * @param num_sender total number of Senders\n   * @return True for success and False for fail\n   *\n   * Wait() is not thread-safe and only one thread can invoke this API.\n   */\n  bool Wait(const std::string& addr, int num_sender);\n\n  /**\n   * @brief Recv RPCMessage from Sender. Actually removing data from queue.\n   * @param msg pointer of RPCmessage\n   * @param timeout The timeout value in milliseconds. If zero, wait\n   * indefinitely.\n   * @return RPCStatus: kRPCSuccess or kRPCTimeOut.\n   */\n  rpc::RPCStatus Recv(rpc::RPCMessage* msg, int timeout);\n\n  /**\n   * @brief Recv data from Sender. Actually removing data from msg_queue.\n   * @param msg pointer of data message\n   * @param send_id which sender current msg comes from\n   * @param timeout The timeout value in milliseconds. If zero, wait\n   * indefinitely.\n   * @return Status code\n   *\n   * (1) The Recv() API is thread-safe.\n   * (2) Memory allocated by communicator but will not own it after the function\n   * returns.\n   */\n  STATUS Recv(Message* msg, int* send_id, int timeout = 0) override;\n\n  /**\n   * @brief Recv data from a specified Sender. Actually removing data from\n   * msg_queue.\n   * @param msg pointer of data message.\n   * @param send_id sender's ID\n   * @param timeout The timeout value in milliseconds. 
If zero, wait\n   * indefinitely.\n   * @return Status code\n   *\n   * (1) The RecvFrom() API is thread-safe.\n   * (2) Memory allocated by communicator but will not own it after the function\n   * returns.\n   */\n  STATUS RecvFrom(Message* msg, int send_id, int timeout = 0) override;\n\n  /**\n   * @brief Finalize SocketReceiver\n   *\n   * Finalize() is not thread-safe and only one thread can invoke this API.\n   */\n  void Finalize();\n\n private:\n  struct RecvContext {\n    int64_t data_size = -1;\n    int64_t received_bytes = 0;\n    char* buffer = nullptr;\n  };\n  /**\n   * @brief number of sender\n   */\n  int num_sender_;\n\n  /**\n   * @brief server socket for listening connections\n   */\n  TCPSocket* server_socket_;\n\n  /**\n   * @brief socket for each client connections\n   */\n  std::vector<std::unordered_map<\n      int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>>\n      sockets_;\n\n  /**\n   * @brief Message queue for each socket connection\n   */\n  std::unordered_map<\n      int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>>\n      msg_queue_;\n  std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_;\n\n  /**\n   * @brief Independent thread\n   */\n  std::vector<std::shared_ptr<std::thread>> threads_;\n\n  /**\n   * @brief queue_sem_ semaphore to indicate number of messages in multiple\n   * message queues to prevent busy wait of Recv\n   */\n  runtime::Semaphore queue_sem_;\n\n  /**\n   * @brief Recv-loop for each thread\n   * @param sockets client sockets of current thread\n   * @param queues message queues of current thread\n   *\n   * Note that, the RecvLoop will finish its loop-job and exit thread\n   * when the main thread invokes Signal() API on the message queue.\n   */\n  static void RecvLoop(\n      std::unordered_map<\n          int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>>\n          sockets,\n      std::unordered_map<\n          int /* Sender (virtual) ID */, 
std::shared_ptr<MessageQueue>>\n          queues,\n      runtime::Semaphore* queue_sem);\n};\n\n}  // namespace network\n}  // namespace dgl\n\n#endif  // DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_\n"
  },
  {
    "path": "src/rpc/network/socket_pool.cc",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file socket_pool.cc\n * @brief Socket pool of nonblocking sockets for DGL distributed training.\n */\n#include \"socket_pool.h\"\n\n#include <dmlc/logging.h>\n\n#include \"tcp_socket.h\"\n\n#ifdef USE_EPOLL\n#include <sys/epoll.h>\n#endif\n\nnamespace dgl {\nnamespace network {\n\nSocketPool::SocketPool() {\n#ifdef USE_EPOLL\n  epfd_ = epoll_create1(0);\n  if (epfd_ < 0) {\n    LOG(FATAL) << \"SocketPool cannot create epfd\";\n  }\n#endif\n}\n\nvoid SocketPool::AddSocket(\n    std::shared_ptr<TCPSocket> socket, int socket_id, int events) {\n  int fd = socket->Socket();\n  tcp_sockets_[fd] = socket;\n  socket_ids_[fd] = socket_id;\n\n#ifdef USE_EPOLL\n  epoll_event e;\n  e.data.fd = fd;\n  if (events == READ) {\n    e.events = EPOLLIN;\n  } else if (events == WRITE) {\n    e.events = EPOLLOUT;\n  } else if (events == READ + WRITE) {\n    e.events = EPOLLIN | EPOLLOUT;\n  }\n  if (epoll_ctl(epfd_, EPOLL_CTL_ADD, fd, &e) < 0) {\n    LOG(FATAL) << \"SocketPool cannot add socket\";\n  }\n  socket->SetNonBlocking(true);\n#else\n  if (tcp_sockets_.size() > 1) {\n    LOG(FATAL) << \"SocketPool supports only one socket if not use epoll.\"\n                  \"Please turn on USE_EPOLL on building\";\n  }\n#endif\n}\n\nsize_t SocketPool::RemoveSocket(std::shared_ptr<TCPSocket> socket) {\n  int fd = socket->Socket();\n  socket_ids_.erase(fd);\n  tcp_sockets_.erase(fd);\n#ifdef USE_EPOLL\n  epoll_ctl(epfd_, EPOLL_CTL_DEL, fd, NULL);\n#endif\n  return socket_ids_.size();\n}\n\nSocketPool::~SocketPool() {\n#ifdef USE_EPOLL\n  for (auto& id : socket_ids_) {\n    int fd = id.first;\n    epoll_ctl(epfd_, EPOLL_CTL_DEL, fd, NULL);\n  }\n#endif\n}\n\nstd::shared_ptr<TCPSocket> SocketPool::GetActiveSocket(int* socket_id) {\n  if (socket_ids_.empty()) {\n    return nullptr;\n  }\n\n  for (;;) {\n    while (pending_fds_.empty()) {\n      Wait();\n    }\n    int fd = pending_fds_.front();\n    pending_fds_.pop();\n\n    // 
Check if this socket is not removed\n    if (socket_ids_.find(fd) != socket_ids_.end()) {\n      *socket_id = socket_ids_[fd];\n      return tcp_sockets_[fd];\n    }\n  }\n\n  return nullptr;\n}\n\nvoid SocketPool::Wait() {\n#ifdef USE_EPOLL\n  static const int MAX_EVENTS = 10;\n  epoll_event events[MAX_EVENTS];\n  int nfd = epoll_wait(epfd_, events, MAX_EVENTS, -1 /*Timeout*/);\n  for (int i = 0; i < nfd; ++i) {\n    pending_fds_.push(events[i].data.fd);\n  }\n#else\n  pending_fds_.push(tcp_sockets_.begin()->second->Socket());\n#endif\n}\n\n}  // namespace network\n}  // namespace dgl\n"
  },
  {
    "path": "src/rpc/network/socket_pool.h",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file socket_pool.h\n * @brief Socket pool of nonblocking sockets for DGL distributed training.\n */\n#ifndef DGL_RPC_NETWORK_SOCKET_POOL_H_\n#define DGL_RPC_NETWORK_SOCKET_POOL_H_\n\n#include <memory>\n#include <queue>\n#include <unordered_map>\n\nnamespace dgl {\nnamespace network {\n\nclass TCPSocket;\n\n/**\n * @brief SocketPool maintains a group of nonblocking sockets, and can provide\n * active sockets.\n * Currently SocketPool is based on epoll, a scalable I/O event notification\n * mechanism in Linux operating system.\n */\nclass SocketPool {\n public:\n  /**\n   * @brief socket mode read/receive\n   */\n  static const int READ = 1;\n  /**\n   * @brief socket mode write/send\n   */\n  static const int WRITE = 2;\n  /**\n   * @brief SocketPool constructor\n   */\n  SocketPool();\n\n  /**\n   * @brief Add a socket to SocketPool\n   * @param socket tcp socket to add\n   * @param socket_id receiver/sender id of the socket\n   * @param events READ, WRITE or READ + WRITE\n   */\n  void AddSocket(\n      std::shared_ptr<TCPSocket> socket, int socket_id, int events = READ);\n\n  /**\n   * @brief Remove socket from SocketPool\n   * @param socket tcp socket to remove\n   * @return number of remaing sockets in the pool\n   */\n  size_t RemoveSocket(std::shared_ptr<TCPSocket> socket);\n\n  /**\n   * @brief SocketPool destructor\n   */\n  ~SocketPool();\n\n  /**\n   * @brief Get current active socket. 
This is a blocking method\n   * @param socket_id output parameter of the socket_id of active socket\n   * @return active TCPSocket\n   */\n  std::shared_ptr<TCPSocket> GetActiveSocket(int* socket_id);\n\n private:\n  /**\n   * @brief Wait for event notification\n   */\n  void Wait();\n\n  /**\n   * @brief map from fd to TCPSocket\n   */\n  std::unordered_map<int, std::shared_ptr<TCPSocket>> tcp_sockets_;\n\n  /**\n   * @brief map from fd to socket_id\n   */\n  std::unordered_map<int, int> socket_ids_;\n\n  /**\n   * @brief fd for epoll base\n   */\n  int epfd_;\n\n  /**\n   * @brief queue for current active fds\n   */\n  std::queue<int> pending_fds_;\n};\n\n}  // namespace network\n}  // namespace dgl\n\n#endif  // DGL_RPC_NETWORK_SOCKET_POOL_H_\n"
  },
  {
    "path": "src/rpc/network/tcp_socket.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file tcp_socket.cc\n * @brief TCP socket for DGL distributed training.\n */\n#include \"tcp_socket.h\"\n\n#include <dmlc/logging.h>\n\n#ifndef _WIN32\n#include <arpa/inet.h>\n#include <fcntl.h>\n#include <netdb.h>\n#include <netinet/in.h>\n#include <sys/socket.h>\n#include <unistd.h>\n#endif  // !_WIN32\n#include <errno.h>\n#include <string.h>\n\nnamespace dgl {\nnamespace network {\n\ntypedef struct sockaddr_in SAI;\ntypedef struct sockaddr SA;\n\nTCPSocket::TCPSocket() {\n  // init socket\n  socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);\n  if (socket_ < 0) {\n    LOG(FATAL) << \"Can't create new socket. Error: \" << strerror(errno);\n  }\n#ifndef _WIN32\n  // This is to make sure the same port can be reused right after the socket is\n  // closed.\n  int enable = 1;\n  if (setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)) < 0) {\n    LOG(WARNING) << \"cannot make the socket reusable. Error: \"\n                 << strerror(errno);\n  }\n#endif  // _WIN32\n}\n\nTCPSocket::~TCPSocket() { Close(); }\n\nbool TCPSocket::Connect(const char *ip, int port) {\n  SAI sa_server;\n  sa_server.sin_family = AF_INET;\n  sa_server.sin_port = htons(port);\n\n  int retval = 0;\n  do {  // retry if EINTR failure appears\n    if (0 < inet_pton(AF_INET, ip, &sa_server.sin_addr) &&\n        0 <= (retval = connect(\n                  socket_, reinterpret_cast<SA *>(&sa_server),\n                  sizeof(sa_server)))) {\n      return true;\n    }\n  } while (retval == -1 && errno == EINTR);\n\n  return false;\n}\n\nbool TCPSocket::Bind(const char *ip, int port) {\n  SAI sa_server;\n  sa_server.sin_family = AF_INET;\n  sa_server.sin_port = htons(port);\n  int ret = 0;\n  ret = inet_pton(AF_INET, ip, &sa_server.sin_addr);\n  if (ret == 0) {\n    LOG(ERROR) << \"Invalid IP: \" << ip;\n    return false;\n  } else if (ret < 0) {\n    LOG(ERROR) << \"Failed to convert [\" << ip\n               << \"] to 
binary form, error: \" << strerror(errno);\n    return false;\n  }\n  do {  // retry if EINTR failure appears\n    if (0 <=\n        (ret = bind(\n             socket_, reinterpret_cast<SA *>(&sa_server), sizeof(sa_server)))) {\n      return true;\n    }\n  } while (ret == -1 && errno == EINTR);\n\n  LOG(ERROR) << \"Failed bind on \" << ip << \":\" << port\n             << \" , error: \" << strerror(errno);\n  return false;\n}\n\nbool TCPSocket::Listen(int max_connection) {\n  int retval;\n  do {  // retry if EINTR failure appears\n    if (0 <= (retval = listen(socket_, max_connection))) {\n      return true;\n    }\n  } while (retval == -1 && errno == EINTR);\n\n  LOG(ERROR) << \"Failed listen on socket fd: \" << socket_\n             << \" , error: \" << strerror(errno);\n  return false;\n}\n\nbool TCPSocket::Accept(TCPSocket *socket, std::string *ip, int *port) {\n  int sock_client;\n  SAI sa_client;\n  socklen_t len = sizeof(sa_client);\n\n  do {  // retry if EINTR failure appears\n    sock_client = accept(socket_, reinterpret_cast<SA *>(&sa_client), &len);\n  } while (sock_client == -1 && errno == EINTR);\n\n  if (sock_client < 0) {\n    LOG(ERROR) << \"Failed accept connection on \" << *ip << \":\" << *port\n               << \", error: \" << strerror(errno)\n               << (errno == EAGAIN ? \" SO_RCVTIMEO timeout reached\" : \"\");\n    return false;\n  }\n\n  char tmp[INET_ADDRSTRLEN];\n  const char *ip_client =\n      inet_ntop(AF_INET, &sa_client.sin_addr, tmp, sizeof(tmp));\n  CHECK(ip_client != nullptr);\n  ip->assign(ip_client);\n  *port = ntohs(sa_client.sin_port);\n  socket->socket_ = sock_client;\n\n  return true;\n}\n\n#ifdef _WIN32\nbool TCPSocket::SetNonBlocking(bool flag) {\n  int result;\n  u_long argp = flag ? 
1 : 0;\n\n  // XXX Non-blocking Windows Sockets apparently has tons of issues:\n  // http://www.sockets.com/winsock.htm#Overview_BlockingNonBlocking\n  // Since SetBlocking() is not used at all, I'm leaving a default\n  // implementation here.  But be warned that this is not fully tested.\n  if ((result = ioctlsocket(socket_, FIONBIO, &argp)) != NO_ERROR) {\n    LOG(ERROR) << \"Failed to set socket status.\";\n    return false;\n  }\n  return true;\n}\n#else   // !_WIN32\nbool TCPSocket::SetNonBlocking(bool flag) {\n  int opts;\n\n  if ((opts = fcntl(socket_, F_GETFL)) < 0) {\n    LOG(ERROR) << \"Failed to get socket status.\";\n    return false;\n  }\n\n  if (flag) {\n    opts |= O_NONBLOCK;\n  } else {\n    opts &= ~O_NONBLOCK;\n  }\n\n  if (fcntl(socket_, F_SETFL, opts) < 0) {\n    LOG(ERROR) << \"Failed to set socket status.\";\n    return false;\n  }\n\n  return true;\n}\n#endif  // _WIN32\n\nvoid TCPSocket::SetTimeout(int timeout) {\n#ifdef _WIN32\n  timeout = timeout * 1000;  // WIN API accepts milliseconds\n  setsockopt(\n      socket_, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&timeout),\n      sizeof(timeout));\n#else   // !_WIN32\n  struct timeval tv;\n  tv.tv_sec = timeout;\n  tv.tv_usec = 0;\n  setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));\n#endif  // _WIN32\n}\n\nbool TCPSocket::ShutDown(int ways) { return 0 == shutdown(socket_, ways); }\n\nvoid TCPSocket::Close() {\n  if (socket_ >= 0) {\n#ifdef _WIN32\n    CHECK_EQ(0, closesocket(socket_));\n#else   // !_WIN32\n    CHECK_EQ(0, close(socket_));\n#endif  // _WIN32\n    socket_ = -1;\n  }\n}\n\nint64_t TCPSocket::Send(const char *data, int64_t len_data) {\n  int64_t number_send;\n\n  do {  // retry if EINTR failure appears\n    number_send = send(socket_, data, len_data, 0);\n  } while (number_send == -1 && errno == EINTR);\n  if (number_send == -1) {\n    LOG(ERROR) << \"send error: \" << strerror(errno);\n  }\n\n  return number_send;\n}\n\nint64_t TCPSocket::Receive(char *buffer, 
int64_t size_buffer) {\n  int64_t number_recv;\n\n  do {  // retry if EINTR failure appears\n    number_recv = recv(socket_, buffer, size_buffer, 0);\n  } while (number_recv == -1 && errno == EINTR);\n  if (number_recv == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {\n    LOG(ERROR) << \"recv error: \" << strerror(errno);\n  }\n\n  return number_recv;\n}\n\nint TCPSocket::Socket() const { return socket_; }\n\n}  // namespace network\n}  // namespace dgl\n"
  },
  {
    "path": "src/rpc/network/tcp_socket.h",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file tcp_socket.h\n * @brief TCP socket for DGL distributed training.\n */\n#ifndef DGL_RPC_NETWORK_TCP_SOCKET_H_\n#define DGL_RPC_NETWORK_TCP_SOCKET_H_\n\n#ifdef _WIN32\n#include <winsock2.h>\n#include <ws2tcpip.h>\n\n#pragma comment(lib, \"Ws2_32.lib\")\n#else  // !_WIN32\n#include <sys/socket.h>\n#endif  // _WIN32\n#include <string>\n\nnamespace dgl {\nnamespace network {\n\n/**\n * @brief TCPSocket is a simple wrapper around a socket.\n * It supports only TCP connections.\n */\nclass TCPSocket {\n public:\n  /**\n   * @brief TCPSocket constructor\n   */\n  TCPSocket();\n\n  /**\n   * @brief TCPSocket deconstructor\n   */\n  ~TCPSocket();\n\n  /**\n   * @brief Connect to a given server address\n   * @param ip ip address\n   * @param port end port\n   * @return true for success and false for failure\n   */\n  bool Connect(const char* ip, int port);\n\n  /**\n   * @brief Bind on the given IP and PORT\n   * @param ip ip address\n   * @param port end port\n   * @return true for success and false for failure\n   */\n  bool Bind(const char* ip, int port);\n\n  /**\n   * @brief listen for remote connection\n   * @param max_connection maximal connection\n   * @return true for success and false for failure\n   */\n  bool Listen(int max_connection);\n\n  /**\n   * @brief wait doe a new connection\n   * @param socket new SOCKET will be stored to socket\n   * @param ip_client new IP will be stored to ip_client\n   * @param port_client new PORT will be stored to port_client\n   * @return true for success and false for failure\n   */\n  bool Accept(TCPSocket* socket, std::string* ip_client, int* port_client);\n\n  /**\n   * @brief SetNonBlocking() is needed refering to this example of epoll:\n   * http://www.kernel.org/doc/man-pages/online/pages/man4/epoll.4.html\n   * @param flag true for nonblocking, false for blocking\n   * @return true for success and false for failure\n   */\n  bool SetNonBlocking(bool 
flag);\n\n  /**\n   * @brief Set timeout for socket\n   * @param timeout seconds timeout\n   */\n  void SetTimeout(int timeout);\n\n  /**\n   * @brief Shut down one or both halves of the connection.\n   * @param ways ways for shutdown\n   * If ways is SHUT_RD, further receives are disallowed.\n   * If ways is SHUT_WR, further sends are disallowed.\n   * If ways is SHUT_RDWR, further sends and receives are disallowed.\n   * @return true for success and false for failure\n   */\n  bool ShutDown(int ways);\n\n  /**\n   * @brief close socket.\n   */\n  void Close();\n\n  /**\n   * @brief Send data.\n   * @param data data for sending\n   * @param len_data length of data\n   * @return return number of bytes sent if OK, -1 on error\n   */\n  int64_t Send(const char* data, int64_t len_data);\n\n  /**\n   * @brief Receive data.\n   * @param buffer buffer for receving\n   * @param size_buffer size of buffer\n   * @return return number of bytes received if OK, -1 on error\n   */\n  int64_t Receive(char* buffer, int64_t size_buffer);\n\n  /**\n   * @brief Get socket's file descriptor\n   * @return socket's file descriptor\n   */\n  int Socket() const;\n\n private:\n  /**\n   * @brief socket's file descriptor\n   */\n  int socket_;\n};\n\n}  // namespace network\n}  // namespace dgl\n\n#endif  // DGL_RPC_NETWORK_TCP_SOCKET_H_\n"
  },
  {
    "path": "src/rpc/rpc.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file rpc/rpc.cc\n * @brief Implementation of RPC utilities used by both server and client sides.\n */\n#if defined(__linux__)\n#include \"./rpc.h\"\n\n#include <dgl/array.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/random.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/parallel_for.h>\n#include <dgl/zerocopy_serializer.h>\n#include <unistd.h>\n\n#include <csignal>\n#include <future>\n\n#include \"../c_api_common.h\"\n#include \"../runtime/resource_manager.h\"\n\nusing dgl::network::StringPrintf;\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace rpc {\n\n// Borrow from PyTorch\n\nconst char kSocketIfnameEnvVar[] = \"TP_SOCKET_IFNAME\";\nconst char kDefaultUvAddress[] = \"127.0.0.1\";\n\nRPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {\n  RPCContext::getInstance()->sender->Send(msg, target_id);\n  return kRPCSuccess;\n}\n\nRPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {\n  static constexpr int32_t retry_timeout = 5 * 1000;  // milliseconds\n  RPCStatus status;\n  const int32_t real_timeout = timeout == 0 ? retry_timeout : timeout;\n  do {\n    status = RPCContext::getInstance()->receiver->Recv(msg, real_timeout);\n    if (status == kRPCTimeOut) {\n      static const std::string log_str = [real_timeout, timeout]() {\n        std::ostringstream oss;\n        oss << \"Recv RPCMessage timeout in \" << real_timeout << \" ms.\"\n            << (timeout == 0 ? 
\" Retrying ...\" : \"\");\n        return oss.str();\n      }();\n      DLOG(WARNING) << log_str;\n    }\n  } while (timeout == 0 && status == kRPCTimeOut);\n  return status;\n}\n\n//////////////////////////// C APIs ////////////////////////////\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCReset\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::Reset(); });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCCreateSender\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      int64_t msg_queue_size = args[0];\n      int max_thread_count = args[1];\n      RPCContext::getInstance()->sender.reset(\n          new network::SocketSender(msg_queue_size, max_thread_count));\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCCreateReceiver\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      int64_t msg_queue_size = args[0];\n      int max_thread_count = args[1];\n      RPCContext::getInstance()->receiver.reset(\n          new network::SocketReceiver(msg_queue_size, max_thread_count));\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCFinalizeSender\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      RPCContext::getInstance()->sender->Finalize();\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCFinalizeReceiver\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      RPCContext::getInstance()->receiver->Finalize();\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCWaitForSenders\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      std::string ip = args[0];\n      int port = args[1];\n      int num_sender = args[2];\n      std::string addr;\n      addr = StringPrintf(\"tcp://%s:%d\", ip.c_str(), port);\n      if (RPCContext::getInstance()->receiver->Wait(addr, num_sender) ==\n          false) {\n        LOG(FATAL) << \"Wait sender socket failed.\";\n      }\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCConnectReceiver\")\n    .set_body([](DGLArgs args, 
DGLRetValue* rv) {\n      std::string ip = args[0];\n      int port = args[1];\n      int recv_id = args[2];\n      std::string addr;\n      addr = StringPrintf(\"tcp://%s:%d\", ip.c_str(), port);\n      *rv = RPCContext::getInstance()->sender->ConnectReceiver(addr, recv_id);\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int max_try_times = args[0];\n      *rv = RPCContext::getInstance()->sender->ConnectReceiverFinalize(\n          max_try_times);\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCSetRank\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t rank = args[0];\n      RPCContext::getInstance()->rank = rank;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCGetRank\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = RPCContext::getInstance()->rank;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCSetNumServer\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t num_servers = args[0];\n      *rv = RPCContext::getInstance()->num_servers = num_servers;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCGetNumServer\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = RPCContext::getInstance()->num_servers;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCSetNumClient\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t num_clients = args[0];\n      *rv = RPCContext::getInstance()->num_clients = num_clients;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCGetNumClient\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = RPCContext::getInstance()->num_clients;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t num_servers = args[0];\n      *rv = 
RPCContext::getInstance()->num_servers_per_machine = num_servers;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = RPCContext::getInstance()->num_servers_per_machine;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCIncrMsgSeq\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = (RPCContext::getInstance()->msg_seq)++;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCGetMsgSeq\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = RPCContext::getInstance()->msg_seq;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCSetMsgSeq\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int64_t msg_seq = args[0];\n      RPCContext::getInstance()->msg_seq = msg_seq;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCGetBarrierCount\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t group_id = args[0];\n      auto&& cnt = RPCContext::getInstance()->barrier_count;\n      if (cnt.find(group_id) == cnt.end()) {\n        cnt.emplace(group_id, 0x0);\n      }\n      *rv = cnt[group_id];\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCSetBarrierCount\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t count = args[0];\n      const int32_t group_id = args[1];\n      RPCContext::getInstance()->barrier_count[group_id] = count;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCGetMachineID\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = RPCContext::getInstance()->machine_id;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCSetMachineID\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t machine_id = args[0];\n      RPCContext::getInstance()->machine_id = machine_id;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCGetNumMachines\")\n    .set_body([](DGLArgs 
args, DGLRetValue* rv) {\n      *rv = RPCContext::getInstance()->num_machines;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCSetNumMachines\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t num_machines = args[0];\n      RPCContext::getInstance()->num_machines = num_machines;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCSendRPCMessage\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      RPCMessageRef msg = args[0];\n      const int32_t target_id = args[1];\n      *rv = SendRPCMessage(*(msg.sptr()), target_id);\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCRecvRPCMessage\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      int32_t timeout = args[0];\n      RPCMessageRef msg = args[1];\n      *rv = RecvRPCMessage(msg.sptr().get(), timeout);\n    });\n\n//////////////////////////// RPCMessage ////////////////////////////\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      std::shared_ptr<RPCMessage> rst(new RPCMessage);\n      *rv = rst;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCCreateRPCMessage\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      std::shared_ptr<RPCMessage> rst(new RPCMessage);\n      rst->service_id = args[0];\n      rst->msg_seq = args[1];\n      rst->client_id = args[2];\n      rst->server_id = args[3];\n      const std::string data =\n          args[4];  // directly assigning string value raises errors :(\n      rst->data = data;\n      rst->tensors = ListValueToVector<NDArray>(args[5]);\n      rst->group_id = args[6];\n      *rv = rst;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCMessageGetServiceId\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const RPCMessageRef msg = args[0];\n      *rv = msg->service_id;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq\")\n    .set_body([](DGLArgs 
args, DGLRetValue* rv) {\n      const RPCMessageRef msg = args[0];\n      *rv = msg->msg_seq;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCMessageGetClientId\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const RPCMessageRef msg = args[0];\n      *rv = msg->client_id;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCMessageGetServerId\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const RPCMessageRef msg = args[0];\n      *rv = msg->server_id;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCMessageGetData\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const RPCMessageRef msg = args[0];\n      DGLByteArray barr{msg->data.c_str(), msg->data.size()};\n      *rv = barr;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCMessageGetTensors\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const RPCMessageRef msg = args[0];\n      List<Value> ret;\n      for (size_t i = 0; i < msg->tensors.size(); ++i) {\n        ret.push_back(Value(MakeValue(msg->tensors[i])));\n      }\n      *rv = ret;\n    });\n\n#if defined(__linux__)\n/**\n * @brief The signal handler.\n * @param s signal\n */\nvoid SigHandler(int s) {\n  LOG(INFO) << \"\\nUser pressed Ctrl+C, Exiting\";\n  CleanupResources();\n  exit(1);\n}\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCHandleSignal\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      // Ctrl+C handler\n      struct sigaction sigHandler;\n      sigHandler.sa_handler = SigHandler;\n      sigemptyset(&sigHandler.sa_mask);\n      sigHandler.sa_flags = 0;\n      sigaction(SIGINT, &sigHandler, nullptr);\n      sigaction(SIGTERM, &sigHandler, nullptr);\n    });\n#endif\n\n//////////////////////////// ServerState ////////////////////////////\n\nDGL_REGISTER_GLOBAL(\"distributed.server_state._CAPI_DGLRPCGetServerState\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      auto st = RPCContext::getInstance()->server_state;\n      
if (st.get() == nullptr) {\n        RPCContext::getInstance()->server_state =\n            std::make_shared<ServerState>();\n      }\n      *rv = st;\n    });\n\n//////////////////////////// KVStore ////////////////////////////\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      NDArray ID = args[0];\n      NDArray part_id = args[1];\n      int local_machine_id = args[2];\n      int64_t* ID_data = static_cast<int64_t*>(ID->data);\n      int64_t* part_id_data = static_cast<int64_t*>(part_id->data);\n      int64_t ID_size = ID.GetSize() / sizeof(int64_t);\n      std::vector<int64_t> global_id;\n      for (int64_t i = 0; i < ID_size; ++i) {\n        if (part_id_data[i] == local_machine_id) {\n          global_id.push_back(ID_data[i]);\n        }\n      }\n      NDArray res_tensor = dgl::aten::VecToIdArray<int64_t>(global_id);\n      *rv = res_tensor;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCFastPull\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      // Input\n      std::string name = args[0];\n      int local_machine_id = args[1];\n      int machine_count = args[2];\n      int group_count = args[3];\n      int client_id = args[4];\n      int service_id = args[5];\n      int64_t msg_seq = args[6];\n      std::string pickle_data = args[7];\n      NDArray ID = args[8];\n      NDArray part_id = args[9];\n      NDArray local_id = args[10];\n      NDArray local_data = args[11];\n      // Data\n      dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);\n      dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);\n      dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);\n      dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);\n      char* local_data_char = static_cast<char*>(local_data->data);\n      std::vector<dgl_id_t> local_ids;\n      std::vector<dgl_id_t> local_ids_orginal;\n      std::vector<int64_t> local_data_shape;\n  
    std::vector<std::vector<dgl_id_t>> remote_ids(machine_count);\n      std::vector<std::vector<dgl_id_t>> remote_ids_original(machine_count);\n      // Get row size (in bytes)\n      int row_size = 1;\n      for (int i = 0; i < local_data->ndim; ++i) {\n        local_data_shape.push_back(local_data->shape[i]);\n        if (i != 0) {\n          row_size *= local_data->shape[i];\n        }\n      }\n      row_size *= (local_data->dtype.bits / 8);\n      size_t data_size = local_data.GetSize();\n      CHECK_GT(local_data_shape.size(), 0);\n      CHECK_EQ(row_size * local_data_shape[0], data_size);\n      // Get local id (used in local machine) and\n      // remote id (send to remote machine)\n      dgl_id_t idx = 0;\n      for (dgl_id_t i = 0; i < ID_size; ++i) {\n        dgl_id_t p_id = part_id_data[i];\n        if (static_cast<int>(p_id) == local_machine_id) {\n          dgl_id_t l_id = local_id_data[idx++];\n          CHECK_LT(l_id, local_data_shape[0]);\n          CHECK_GE(l_id, 0);\n          local_ids.push_back(l_id);\n          local_ids_orginal.push_back(i);\n        } else {\n          CHECK_LT(p_id, machine_count) << \"Invalid partition ID.\";\n          dgl_id_t id = ID_data[i];\n          remote_ids[p_id].push_back(id);\n          remote_ids_original[p_id].push_back(i);\n        }\n      }\n      // Send remote id\n      int msg_count = 0;\n      for (size_t i = 0; i < remote_ids.size(); ++i) {\n        if (remote_ids[i].size() != 0) {\n          RPCMessage msg;\n          msg.service_id = service_id;\n          msg.msg_seq = msg_seq;\n          msg.client_id = client_id;\n          int lower = i * group_count;\n          int upper = (i + 1) * group_count;\n          msg.server_id =\n              dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);\n          msg.data = pickle_data;\n          NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);\n          msg.tensors.push_back(tensor);\n          msg.group_id = 
RPCContext::getInstance()->group_id;\n          SendRPCMessage(msg, msg.server_id);\n          msg_count++;\n        }\n      }\n      local_data_shape[0] = ID_size;\n      NDArray res_tensor = NDArray::Empty(\n          local_data_shape, local_data->dtype, DGLContext{kDGLCPU, 0});\n      char* return_data = static_cast<char*>(res_tensor->data);\n      // Copy local data\n      parallel_for(0, local_ids.size(), [&](size_t b, size_t e) {\n        for (auto i = b; i < e; ++i) {\n          CHECK_GE(\n              ID_size * row_size, local_ids_orginal[i] * row_size + row_size);\n          CHECK_GE(data_size, local_ids[i] * row_size + row_size);\n          CHECK_GE(local_ids[i], 0);\n          memcpy(\n              return_data + local_ids_orginal[i] * row_size,\n              local_data_char + local_ids[i] * row_size, row_size);\n        }\n      });\n      // Recv remote message\n      int recv_cnt = 0;\n      while (recv_cnt < msg_count) {\n        RPCMessage msg;\n        auto status = RecvRPCMessage(&msg, 0);\n        CHECK_EQ(status, kRPCSuccess);\n        ++recv_cnt;\n        int part_id = msg.server_id / group_count;\n        char* data_char = static_cast<char*>(msg.tensors[0]->data);\n        dgl_id_t id_size = remote_ids[part_id].size();\n        for (size_t n = 0; n < id_size; ++n) {\n          memcpy(\n              return_data + remote_ids_original[part_id][n] * row_size,\n              data_char + n * row_size, row_size);\n        }\n      }\n      *rv = res_tensor;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCGetGroupID\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = RPCContext::getInstance()->group_id;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCSetGroupID\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t group_id = args[0];\n      RPCContext::getInstance()->group_id = group_id;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCMessageGetGroupId\")\n    
.set_body([](DGLArgs args, DGLRetValue* rv) {\n      const RPCMessageRef msg = args[0];\n      *rv = msg->group_id;\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCRegisterClient\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t client_id = args[0];\n      const int32_t group_id = args[1];\n      *rv = RPCContext::getInstance()->RegisterClient(client_id, group_id);\n    });\n\nDGL_REGISTER_GLOBAL(\"distributed.rpc._CAPI_DGLRPCGetClient\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const int32_t client_id = args[0];\n      const int32_t group_id = args[1];\n      *rv = RPCContext::getInstance()->GetClient(client_id, group_id);\n    });\n\n}  // namespace rpc\n}  // namespace dgl\n\n#endif\n"
  },
  {
    "path": "src/rpc/rpc.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file rpc/rpc.h\n * @brief Common headers for remote process call (RPC).\n */\n#ifndef DGL_RPC_RPC_H_\n#define DGL_RPC_RPC_H_\n\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/object.h>\n#include <dgl/zerocopy_serializer.h>\n#include <dmlc/thread_local.h>\n\n#include <cstdint>\n#include <deque>\n#include <memory>\n#include <mutex>\n#include <string>\n#include <unordered_map>\n#include <vector>\n\n#include \"./network/common.h\"\n#include \"./rpc_msg.h\"\n#include \"./server_state.h\"\n#include \"network/socket_communicator.h\"\n\nnamespace dgl {\nnamespace rpc {\n\nstruct RPCContext;\n\n// Communicator handler type\ntypedef void* CommunicatorHandle;\n\n/** @brief Context information for RPC communication */\nstruct RPCContext {\n  /**\n   * @brief Rank of this process.\n   *\n   * If the process is a client, this is equal to client ID. Otherwise, the\n   * process is a server and this is equal to server ID.\n   */\n  int32_t rank = -1;\n\n  /**\n   * @brief Cuurent machine ID\n   */\n  int32_t machine_id = -1;\n\n  /**\n   * @brief Total number of machines.\n   */\n  int32_t num_machines = 0;\n\n  /**\n   * @brief Message sequence number.\n   */\n  std::atomic<int64_t> msg_seq{0};\n\n  /**\n   * @brief Total number of server.\n   */\n  int32_t num_servers = 0;\n\n  /**\n   * @brief Total number of client.\n   */\n  int32_t num_clients = 0;\n\n  /**\n   * @brief Current barrier count\n   */\n  std::unordered_map<int32_t, int32_t> barrier_count;\n\n  /**\n   * @brief Total number of server per machine.\n   */\n  int32_t num_servers_per_machine = 0;\n\n  /**\n   * @brief Sender communicator.\n   */\n  std::shared_ptr<network::SocketSender> sender;\n\n  /**\n   * @brief Receiver communicator.\n   */\n  std::shared_ptr<network::SocketReceiver> receiver;\n\n  /**\n   * @brief Server state data.\n   *\n   * If the process is a server, this stores necessary\n   * server-side data. 
Otherwise, the process is a client and it stores a cache\n   * of the server co-located with the client (if available). When the client\n   * invokes a RPC to the co-located server, it can thus perform computation\n   * locally without an actual remote call.\n   */\n  std::shared_ptr<ServerState> server_state;\n\n  /**\n   * @brief Cuurent group ID\n   */\n  int32_t group_id = -1;\n  int32_t curr_client_id = -1;\n  std::unordered_map<int32_t, std::unordered_map<int32_t, int32_t>> clients_;\n\n  /** @brief Get the RPC context singleton */\n  static RPCContext* getInstance() {\n    static RPCContext ctx;\n    return &ctx;\n  }\n\n  /** @brief Reset the RPC context */\n  static void Reset() {\n    auto* t = getInstance();\n    t->rank = -1;\n    t->machine_id = -1;\n    t->num_machines = 0;\n    t->msg_seq = 0;\n    t->num_servers = 0;\n    t->num_clients = 0;\n    t->barrier_count.clear();\n    t->num_servers_per_machine = 0;\n    t->sender.reset();\n    t->receiver.reset();\n    t->server_state.reset();\n    t->group_id = -1;\n    t->curr_client_id = -1;\n    t->clients_.clear();\n  }\n\n  int32_t RegisterClient(int32_t client_id, int32_t group_id) {\n    auto&& m = clients_[group_id];\n    if (m.find(client_id) != m.end()) {\n      return -1;\n    }\n    m[client_id] = ++curr_client_id;\n    return curr_client_id;\n  }\n\n  int32_t GetClient(int32_t client_id, int32_t group_id) const {\n    if (clients_.find(group_id) == clients_.end()) {\n      return -1;\n    }\n    const auto& m = clients_.at(group_id);\n    if (m.find(client_id) == m.end()) {\n      return -1;\n    }\n    return m.at(client_id);\n  }\n};\n\n/**\n * @brief Send out one RPC message.\n *\n * The operation is non-blocking -- it does not guarantee the payloads have\n * reached the target or even have left the sender process. 
However,\n * all the payloads (i.e., data and arrays) can be safely freed after this\n * function returns.\n *\n * The data buffer in the requst will be copied to internal buffer for actual\n * transmission, while no memory copy for tensor payloads (a.k.a. zero-copy).\n * The underlying sending threads will hold references to the tensors until\n * the contents have been transmitted.\n *\n * @param msg RPC message to send\n * @return status flag\n */\nRPCStatus SendRPCMessage(const RPCMessage& msg);\n\n/**\n * @brief Receive one RPC message.\n *\n * The operation is blocking -- it returns when it receives any message\n *\n * @param msg The received message\n * @param timeout The timeout value in milliseconds. If zero, wait indefinitely.\n * @return status flag\n */\nRPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout = 0);\n\n}  // namespace rpc\n}  // namespace dgl\n\n#endif  // DGL_RPC_RPC_H_\n"
  },
  {
    "path": "src/rpc/rpc_msg.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file rpc/rpc_msg.h\n * @brief Common headers for remote process call (RPC).\n */\n#ifndef DGL_RPC_RPC_MSG_H_\n#define DGL_RPC_RPC_MSG_H_\n\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/object.h>\n#include <dgl/zerocopy_serializer.h>\n\n#include <string>\n#include <vector>\n\nnamespace dgl {\nnamespace rpc {\n\n/** @brief RPC message data structure\n *\n * This structure is exposed to Python and can be used as argument or return\n * value in C API.\n */\nstruct RPCMessage : public runtime::Object {\n  /** @brief Service ID */\n  int32_t service_id;\n\n  /** @brief Sequence number of this message. */\n  int64_t msg_seq;\n\n  /** @brief Client ID. */\n  int32_t client_id;\n\n  /** @brief Server ID. */\n  int32_t server_id;\n\n  /** @brief Payload buffer carried by this request.*/\n  std::string data;\n\n  /** @brief Extra payloads in the form of tensors.*/\n  std::vector<runtime::NDArray> tensors;\n\n  /** @brief Group ID. 
*/\n  int32_t group_id{0};\n\n  bool Load(dmlc::Stream* stream) {\n    stream->Read(&service_id);\n    stream->Read(&msg_seq);\n    stream->Read(&client_id);\n    stream->Read(&server_id);\n    stream->Read(&data);\n    stream->Read(&tensors);\n    stream->Read(&group_id);\n    return true;\n  }\n\n  void Save(dmlc::Stream* stream) const {\n    stream->Write(service_id);\n    stream->Write(msg_seq);\n    stream->Write(client_id);\n    stream->Write(server_id);\n    stream->Write(data);\n    stream->Write(tensors);\n    stream->Write(group_id);\n  }\n\n  static constexpr const char* _type_key = \"rpc.RPCMessage\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(RPCMessage, runtime::Object);\n};\n\nDGL_DEFINE_OBJECT_REF(RPCMessageRef, RPCMessage);\n\n/** @brief RPC status flag */\nenum RPCStatus {\n  kRPCSuccess = 0,\n  kRPCTimeOut,\n};\n\n}  // namespace rpc\n}  // namespace dgl\n\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, dgl::rpc::RPCMessage, true);\n}  // namespace dmlc\n\n#endif  // DGL_RPC_RPC_MSG_H_\n"
  },
  {
    "path": "src/rpc/server_state.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file rpc/server_state.h\n * @brief Implementation of RPC utilities used by both server and client sides.\n */\n\n#ifndef DGL_RPC_SERVER_STATE_H_\n#define DGL_RPC_SERVER_STATE_H_\n\n#include <dgl/base_heterograph.h>\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/object.h>\n\n#include <string>\n#include <unordered_map>\n\nnamespace dgl {\nnamespace rpc {\n\n/**\n * @brief Data stored in one DGL server.\n *\n * In a distributed setting, DGL partitions all data associated with the graph\n * (e.g., node and edge features, graph structure, etc.) to multiple partitions,\n * each handled by one DGL server. Hence, the ServerState class includes all\n * the data associated with a graph partition.\n *\n * Under some setup, users may want to deploy servers in a heterogeneous way\n * -- servers are further divided into special groups for fetching/updating\n * node/edge data and for sampling/querying on graph structure respectively.\n * In this case, the ServerState can be configured to include only node/edge\n * data or graph structure.\n *\n * Each machine can have multiple server and client processes, but only one\n * server is the *master* server while all the others are backup servers. 
All\n * clients and backup servers share the state of the master server via shared\n * memory, which means the ServerState class must be serializable and large\n * bulk data (e.g., node/edge features) must be stored in NDArray to leverage\n * shared memory.\n */\nstruct ServerState : public runtime::Object {\n  /** @brief Key value store for NDArray data */\n  std::unordered_map<std::string, runtime::NDArray> kv_store;\n\n  /** @brief Graph structure of one partition */\n  HeteroGraphPtr graph;\n\n  /** @brief Total number of nodes */\n  int64_t total_num_nodes = 0;\n\n  /** @brief Total number of edges */\n  int64_t total_num_edges = 0;\n\n  static constexpr const char* _type_key = \"server_state.ServerState\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(ServerState, runtime::Object);\n};\nDGL_DEFINE_OBJECT_REF(ServerStateRef, ServerState);\n\n}  // namespace rpc\n}  // namespace dgl\n\n#endif  // DGL_RPC_SERVER_STATE_H_\n"
  },
  {
    "path": "src/runtime/c_object_api.cc",
    "content": "/**\n *  Copyright (c) 2016 by Contributors\n * Implementation of C API (reference: tvm/src/api/c_api.cc)\n * @file c_api.cc\n */\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/c_object_api.h>\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/object.h>\n#include <dmlc/base.h>\n#include <dmlc/logging.h>\n#include <dmlc/thread_local.h>\n\n#include <exception>\n#include <string>\n#include <vector>\n\n#include \"runtime_base.h\"\n\n/** @brief entry to to easily hold returning information */\nstruct DGLAPIThreadLocalEntry {\n  /** @brief result holder for returning strings */\n  std::vector<std::string> ret_vec_str;\n  /** @brief result holder for returning string pointers */\n  std::vector<const char*> ret_vec_charp;\n  /** @brief result holder for retruning string */\n  std::string ret_str;\n};\n\nusing namespace dgl::runtime;\n\n/** @brief Thread local store that can be used to hold return values. */\ntypedef dmlc::ThreadLocalStore<DGLAPIThreadLocalEntry> DGLAPIThreadLocalStore;\n\nusing DGLAPIObject = std::shared_ptr<Object>;\n\nstruct APIAttrGetter : public AttrVisitor {\n  std::string skey;\n  DGLRetValue* ret;\n  bool found_object_ref{false};\n\n  void Visit(const char* key, double* value) final {\n    if (skey == key) *ret = value[0];\n  }\n  void Visit(const char* key, int64_t* value) final {\n    if (skey == key) *ret = value[0];\n  }\n  void Visit(const char* key, uint64_t* value) final {\n    CHECK_LE(\n        value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))\n        << \"cannot return too big constant\";\n    if (skey == key) *ret = static_cast<int64_t>(value[0]);\n  }\n  void Visit(const char* key, int* value) final {\n    if (skey == key) *ret = static_cast<int64_t>(value[0]);\n  }\n  void Visit(const char* key, bool* value) final {\n    if (skey == key) *ret = static_cast<int64_t>(value[0]);\n  }\n  void Visit(const char* key, std::string* value) final {\n    if (skey == key) *ret = value[0];\n  }\n  
void Visit(const char* key, ObjectRef* value) final {\n    if (skey == key) {\n      *ret = value[0];\n      found_object_ref = true;\n    }\n  }\n  void Visit(const char* key, NDArray* value) final {\n    if (skey == key) *ret = value[0];\n  }\n};\n\nstruct APIAttrDir : public AttrVisitor {\n  std::vector<std::string>* names;\n\n  void Visit(const char* key, double* value) final { names->push_back(key); }\n  void Visit(const char* key, int64_t* value) final { names->push_back(key); }\n  void Visit(const char* key, uint64_t* value) final { names->push_back(key); }\n  void Visit(const char* key, bool* value) final { names->push_back(key); }\n  void Visit(const char* key, int* value) final { names->push_back(key); }\n  void Visit(const char* key, std::string* value) final {\n    names->push_back(key);\n  }\n  void Visit(const char* key, ObjectRef* value) final { names->push_back(key); }\n  void Visit(const char* key, NDArray* value) final { names->push_back(key); }\n};\n\nint DGLObjectFree(ObjectHandle handle) {\n  API_BEGIN();\n  delete static_cast<DGLAPIObject*>(handle);\n  API_END();\n}\n\nint DGLObjectTypeKey2Index(const char* type_key, int* out_index) {\n  API_BEGIN();\n  *out_index = static_cast<int>(Object::TypeKey2Index(type_key));\n  API_END();\n}\n\nint DGLObjectGetTypeIndex(ObjectHandle handle, int* out_index) {\n  API_BEGIN();\n  *out_index =\n      static_cast<int>((*static_cast<DGLAPIObject*>(handle))->type_index());\n  API_END();\n}\n\nint DGLObjectGetAttr(\n    ObjectHandle handle, const char* key, DGLValue* ret_val, int* ret_type_code,\n    int* ret_success) {\n  API_BEGIN();\n  DGLRetValue rv;\n  APIAttrGetter getter;\n  getter.skey = key;\n  getter.ret = &rv;\n  DGLAPIObject* tobject = static_cast<DGLAPIObject*>(handle);\n  if (getter.skey == \"type_key\") {\n    ret_val->v_str = (*tobject)->type_key();\n    *ret_type_code = kStr;\n    *ret_success = 1;\n  } else {\n    (*tobject)->VisitAttrs(&getter);\n    *ret_success = getter.found_object_ref || 
rv.type_code() != kNull;\n    if (rv.type_code() == kStr || rv.type_code() == kDGLDataType) {\n      DGLAPIThreadLocalEntry* e = DGLAPIThreadLocalStore::Get();\n      e->ret_str = rv.operator std::string();\n      *ret_type_code = kStr;\n      ret_val->v_str = e->ret_str.c_str();\n    } else {\n      rv.MoveToCHost(ret_val, ret_type_code);\n    }\n  }\n  API_END();\n}\n\nint DGLObjectListAttrNames(\n    ObjectHandle handle, int* out_size, const char*** out_array) {\n  DGLAPIThreadLocalEntry* ret = DGLAPIThreadLocalStore::Get();\n  API_BEGIN();\n  ret->ret_vec_str.clear();\n  DGLAPIObject* tobject = static_cast<DGLAPIObject*>(handle);\n  APIAttrDir dir;\n  dir.names = &(ret->ret_vec_str);\n  (*tobject)->VisitAttrs(&dir);\n  ret->ret_vec_charp.clear();\n  for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {\n    ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());\n  }\n  *out_array = dmlc::BeginPtr(ret->ret_vec_charp);\n  *out_size = static_cast<int>(ret->ret_vec_str.size());\n  API_END();\n}\n"
  },
  {
    "path": "src/runtime/c_runtime_api.cc",
    "content": "/**\n *  Copyright (c) 2016-2022 by Contributors\n * @file c_runtime_api.cc\n * @brief Runtime API implementation\n */\n#include <dgl/runtime/c_backend_api.h>\n#include <dgl/runtime/c_runtime_api.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/module.h>\n#include <dgl/runtime/packed_func.h>\n#include <dgl/runtime/registry.h>\n#include <dgl/runtime/tensordispatch.h>\n#include <dmlc/thread_local.h>\n\n#include <algorithm>\n#include <array>\n#include <cstdlib>\n#include <string>\n\n#include \"runtime_base.h\"\n\nnamespace dgl {\nnamespace runtime {\n\n/**\n * @brief The name of Device API factory.\n * @param type The device type.\n */\ninline std::string DeviceName(int type) {\n  switch (type) {\n    case kDGLCPU:\n      return \"cpu\";\n    case kDGLCUDA:\n      return \"cuda\";\n    // add more device here once supported\n    default:\n      LOG(FATAL) << \"unknown type =\" << type;\n      return \"Unknown\";\n  }\n}\n\nclass DeviceAPIManager {\n public:\n  static const int kMaxDeviceAPI = 32;\n  // Get API\n  static DeviceAPI* Get(const DGLContext& ctx) { return Get(ctx.device_type); }\n  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {\n    return Global()->GetAPI(dev_type, allow_missing);\n  }\n\n private:\n  std::array<DeviceAPI*, kMaxDeviceAPI> api_;\n  DeviceAPI* rpc_api_{nullptr};\n  std::mutex mutex_;\n  // constructor\n  DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }\n  // Global static variable.\n  static DeviceAPIManager* Global() {\n    static DeviceAPIManager inst;\n    return &inst;\n  }\n  // Get or initialize API.\n  DeviceAPI* GetAPI(int type, bool allow_missing) {\n    if (type < kRPCSessMask) {\n      if (api_[type] != nullptr) return api_[type];\n      std::lock_guard<std::mutex> lock(mutex_);\n      if (api_[type] != nullptr) return api_[type];\n      api_[type] = GetAPI(DeviceName(type), allow_missing);\n      return api_[type];\n    } else {\n      if (rpc_api_ != nullptr) 
return rpc_api_;\n      std::lock_guard<std::mutex> lock(mutex_);\n      if (rpc_api_ != nullptr) return rpc_api_;\n      rpc_api_ = GetAPI(\"rpc\", allow_missing);\n      return rpc_api_;\n    }\n  }\n  DeviceAPI* GetAPI(const std::string name, bool allow_missing) {\n    std::string factory = \"device_api.\" + name;\n    auto* f = Registry::Get(factory);\n    if (f == nullptr) {\n      CHECK(allow_missing)\n          << \"Device API \" << name\n          << \" is not enabled. Please install the cuda version of dgl.\";\n      return nullptr;\n    }\n    void* ptr = (*f)();\n    return static_cast<DeviceAPI*>(ptr);\n  }\n};\n\nDeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {\n  return DeviceAPIManager::Get(\n      static_cast<int>(ctx.device_type), allow_missing);\n}\n\nDeviceAPI* DeviceAPI::Get(DGLDeviceType dev_type, bool allow_missing) {\n  return DeviceAPIManager::Get(static_cast<int>(dev_type), allow_missing);\n}\n\nvoid* DeviceAPI::AllocWorkspace(\n    DGLContext ctx, size_t size, DGLDataType type_hint) {\n  return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);\n}\n\nvoid DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) {\n  FreeDataSpace(ctx, ptr);\n}\n\nDGLStreamHandle DeviceAPI::CreateStream(DGLContext ctx) {\n  LOG(FATAL) << \"Device does not support stream api.\";\n  return 0;\n}\n\nvoid DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {\n  LOG(FATAL) << \"Device does not support stream api.\";\n}\n\nvoid DeviceAPI::SyncStreamFromTo(\n    DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {\n  LOG(FATAL) << \"Device does not support stream api.\";\n}\n\nbool DeviceAPI::PinData(void* ptr, size_t nbytes) {\n  LOG(FATAL) << \"Device does not support cudaHostRegister api.\";\n  return false;\n}\n\nvoid* DeviceAPI::AllocPinnedDataSpace(\n    size_t nbytes, void** ctx, void** deleter) {\n  LOG(FATAL) << \"Device does not support cudaHostAlloc api.\";\n  return nullptr;\n}\n\nvoid 
DeviceAPI::FreePinnedDataSpace(void** deleter) {\n  LOG(FATAL) << \"Device does not support cudaHostFree api.\";\n}\n\nvoid DeviceAPI::UnpinData(void* ptr) {\n  LOG(FATAL) << \"Device does not support cudaHostUnregister api.\";\n}\n}  // namespace runtime\n}  // namespace dgl\n\nusing namespace dgl::runtime;\n\nstruct DGLRuntimeEntry {\n  std::string ret_str;\n  std::string last_error;\n  DGLByteArray ret_bytes;\n};\n\ntypedef dmlc::ThreadLocalStore<DGLRuntimeEntry> DGLAPIRuntimeStore;\n\nconst char* DGLGetLastError() {\n  return DGLAPIRuntimeStore::Get()->last_error.c_str();\n}\n\nvoid DGLAPISetLastError(const char* msg) {\n#ifndef _LIBCPP_SGX_CONFIG\n  DGLAPIRuntimeStore::Get()->last_error = msg;\n#else\n  sgx::OCallPackedFunc(\"__sgx_set_last_error__\", msg);\n#endif\n}\n\nint DGLModLoadFromFile(\n    const char* file_name, const char* format, DGLModuleHandle* out) {\n  API_BEGIN();\n  Module m = Module::LoadFromFile(file_name, format);\n  *out = new Module(m);\n  API_END();\n}\n\nint DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep) {\n  API_BEGIN();\n  static_cast<Module*>(mod)->Import(*static_cast<Module*>(dep));\n  API_END();\n}\n\nint DGLModGetFunction(\n    DGLModuleHandle mod, const char* func_name, int query_imports,\n    DGLFunctionHandle* func) {\n  API_BEGIN();\n  PackedFunc pf =\n      static_cast<Module*>(mod)->GetFunction(func_name, query_imports != 0);\n  if (pf != nullptr) {\n    *func = new PackedFunc(pf);\n  } else {\n    *func = nullptr;\n  }\n  API_END();\n}\n\nint DGLModFree(DGLModuleHandle mod) {\n  API_BEGIN();\n  delete static_cast<Module*>(mod);\n  API_END();\n}\n\nint DGLBackendGetFuncFromEnv(\n    void* mod_node, const char* func_name, DGLFunctionHandle* func) {\n  API_BEGIN();\n  *func =\n      (DGLFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(\n          func_name));\n  API_END();\n}\n\nvoid* DGLBackendAllocWorkspace(\n    int device_type, int device_id, uint64_t size, int dtype_code_hint,\n    int 
dtype_bits_hint) {\n  DGLContext ctx;\n  ctx.device_type = static_cast<DGLDeviceType>(device_type);\n  ctx.device_id = device_id;\n\n  DGLDataType type_hint;\n  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);\n  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);\n  type_hint.lanes = 1;\n\n  return DeviceAPIManager::Get(ctx)->AllocWorkspace(\n      ctx, static_cast<size_t>(size), type_hint);\n}\n\nint DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr) {\n  DGLContext ctx;\n  ctx.device_type = static_cast<DGLDeviceType>(device_type);\n  ctx.device_id = device_id;\n  DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);\n  return 0;\n}\n\nint DGLBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {\n  if (*handle == nullptr) {\n    *handle = reinterpret_cast<void*>(1);\n    return (*f)(cdata);\n  }\n  return 0;\n}\n\nint DGLFuncFree(DGLFunctionHandle func) {\n  API_BEGIN();\n  delete static_cast<PackedFunc*>(func);\n  API_END();\n}\n\nint DGLFuncCall(\n    DGLFunctionHandle func, DGLValue* args, int* arg_type_codes, int num_args,\n    DGLValue* ret_val, int* ret_type_code) {\n  API_BEGIN();\n  DGLRetValue rv;\n  (*static_cast<const PackedFunc*>(func))\n      .CallPacked(DGLArgs(args, arg_type_codes, num_args), &rv);\n  // handle return string.\n  if (rv.type_code() == kStr || rv.type_code() == kDGLDataType ||\n      rv.type_code() == kBytes) {\n    DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();\n    if (rv.type_code() != kDGLDataType) {\n      e->ret_str = *rv.ptr<std::string>();\n    } else {\n      e->ret_str = rv.operator std::string();\n    }\n    if (rv.type_code() == kBytes) {\n      e->ret_bytes.data = e->ret_str.c_str();\n      e->ret_bytes.size = e->ret_str.length();\n      *ret_type_code = kBytes;\n      ret_val->v_handle = &(e->ret_bytes);\n    } else {\n      *ret_type_code = kStr;\n      ret_val->v_str = e->ret_str.c_str();\n    }\n  } else {\n    rv.MoveToCHost(ret_val, 
ret_type_code);\n  }\n  API_END();\n}\n\nint DGLCFuncSetReturn(\n    DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret) {\n  API_BEGIN();\n  CHECK_EQ(num_ret, 1);\n  DGLRetValue* rv = static_cast<DGLRetValue*>(ret);\n  *rv = DGLArgValue(value[0], type_code[0]);\n  API_END();\n}\n\nint DGLFuncCreateFromCFunc(\n    DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,\n    DGLFunctionHandle* out) {\n  API_BEGIN();\n  if (fin == nullptr) {\n    *out =\n        new PackedFunc([func, resource_handle](DGLArgs args, DGLRetValue* rv) {\n          int ret = func(\n              (DGLValue*)args.values, (int*)args.type_codes,  // NOLINT(*)\n              args.num_args, rv, resource_handle);\n          if (ret != 0) {\n            std::string err = \"DGLCall CFunc Error:\\n\";\n            err += DGLGetLastError();\n            throw dmlc::Error(err);\n          }\n        });\n  } else {\n    // wrap it in a shared_ptr, with fin as deleter.\n    // so fin will be called when the lambda went out of scope.\n    std::shared_ptr<void> rpack(resource_handle, fin);\n    *out = new PackedFunc([func, rpack](DGLArgs args, DGLRetValue* rv) {\n      int ret = func(\n          (DGLValue*)args.values, (int*)args.type_codes,  // NOLINT(*)\n          args.num_args, rv, rpack.get());\n      if (ret != 0) {\n        std::string err = \"DGLCall CFunc Error:\\n\";\n        err += DGLGetLastError();\n        throw dmlc::Error(err);\n      }\n    });\n  }\n  API_END();\n}\n\nint DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {\n  API_BEGIN();\n  DGLContext ctx;\n  ctx.device_type = static_cast<DGLDeviceType>(device_type);\n  ctx.device_id = device_id;\n  *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);\n  API_END();\n}\n\nint DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {\n  API_BEGIN();\n  DGLContext ctx;\n  ctx.device_type = static_cast<DGLDeviceType>(device_type);\n  ctx.device_id = device_id;\n  
DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);\n  API_END();\n}\n\nint DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {\n  API_BEGIN();\n  DGLContext ctx;\n  ctx.device_type = static_cast<DGLDeviceType>(device_type);\n  ctx.device_id = device_id;\n  DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);\n  API_END();\n}\n\nint DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) {\n  API_BEGIN();\n  DGLContext ctx;\n  ctx.device_type = static_cast<DGLDeviceType>(device_type);\n  ctx.device_id = device_id;\n  *stream = DeviceAPIManager::Get(ctx)->GetStream();\n  API_END();\n}\n\nint DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {\n  API_BEGIN();\n  DGLContext ctx;\n  ctx.device_type = static_cast<DGLDeviceType>(device_type);\n  ctx.device_id = device_id;\n  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);\n  API_END();\n}\n\nint DGLStreamStreamSynchronize(\n    int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst) {\n  API_BEGIN();\n  DGLContext ctx;\n  ctx.device_type = static_cast<DGLDeviceType>(device_type);\n  ctx.device_id = device_id;\n  DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);\n  API_END();\n}\n\nint DGLCbArgToReturn(DGLValue* value, int code) {\n  API_BEGIN();\n  dgl::runtime::DGLRetValue rv;\n  rv = dgl::runtime::DGLArgValue(*value, code);\n  int tcode;\n  rv.MoveToCHost(value, &tcode);\n  CHECK_EQ(tcode, code);\n  API_END();\n}\n\nint DGLLoadTensorAdapter(const char* path) {\n  return TensorDispatcher::Global()->Load(path) ? 
0 : -1;\n}\n\n// set device api\nDGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)\n    .set_body([](DGLArgs args, DGLRetValue* ret) {\n      DGLContext ctx;\n      ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());\n      ctx.device_id = args[1];\n      DeviceAPIManager::Get(ctx)->SetDevice(ctx);\n    });\n\n// get device attribute\nDGL_REGISTER_GLOBAL(\"_GetDeviceAttr\")\n    .set_body([](DGLArgs args, DGLRetValue* ret) {\n      DGLContext ctx;\n      ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());\n      ctx.device_id = args[1];\n\n      DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());\n      if (kind == kExist) {\n        DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);\n        if (api != nullptr) {\n          api->GetAttr(ctx, kind, ret);\n        } else {\n          *ret = 0;\n        }\n      } else {\n        DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);\n      }\n    });\n"
  },
  {
    "path": "src/runtime/config.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file runtime/config.cc\n * @brief DGL runtime config\n */\n\n#include <dgl/runtime/config.h>\n#include <dgl/runtime/registry.h>\n#if !defined(_WIN32) && defined(USE_LIBXSMM)\n#include <libxsmm_source.h>\n#endif\n\nusing namespace dgl::runtime;\n\nnamespace dgl {\nnamespace runtime {\n\nConfig::Config() {\n#if !defined(_WIN32) && defined(USE_LIBXSMM)\n  int cpu_id = libxsmm_cpuid_x86();\n  // Enable libxsmm on AVX machines by default\n  libxsmm_ = LIBXSMM_X86_AVX2 <= cpu_id && cpu_id <= LIBXSMM_X86_ALLFEAT;\n#else\n  libxsmm_ = false;\n#endif\n}\n\nvoid Config::EnableLibxsmm(bool b) { libxsmm_ = b; }\n\nbool Config::IsLibxsmmAvailable() const { return libxsmm_; }\n\nDGL_REGISTER_GLOBAL(\"global_config._CAPI_DGLConfigSetLibxsmm\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      bool use_libxsmm = args[0];\n      dgl::runtime::Config::Global()->EnableLibxsmm(use_libxsmm);\n    });\n\nDGL_REGISTER_GLOBAL(\"global_config._CAPI_DGLConfigGetLibxsmm\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = dgl::runtime::Config::Global()->IsLibxsmmAvailable();\n    });\n\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/cpu_device_api.cc",
    "content": "/**\n *  Copyright (c) 2016-2022 by Contributors\n * @file cpu_device_api.cc\n */\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/registry.h>\n#include <dgl/runtime/tensordispatch.h>\n#include <dmlc/logging.h>\n#include <dmlc/thread_local.h>\n\n#include <cstdlib>\n#include <cstring>\n\n#include \"workspace_pool.h\"\n\nnamespace dgl {\nnamespace runtime {\nclass CPUDeviceAPI final : public DeviceAPI {\n public:\n  void SetDevice(DGLContext ctx) final {}\n  void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) final {\n    if (kind == kExist) {\n      *rv = 1;\n    }\n  }\n  void* AllocDataSpace(\n      DGLContext ctx, size_t nbytes, size_t alignment,\n      DGLDataType type_hint) final {\n    TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n    if (tensor_dispatcher->IsAvailable())\n      return tensor_dispatcher->CPUAllocWorkspace(nbytes);\n\n    void* ptr;\n#if _MSC_VER || defined(__MINGW32__)\n    ptr = _aligned_malloc(nbytes, alignment);\n    if (ptr == nullptr) throw std::bad_alloc();\n#elif defined(_LIBCPP_SGX_CONFIG)\n    ptr = memalign(alignment, nbytes);\n    if (ptr == nullptr) throw std::bad_alloc();\n#else\n    int ret = posix_memalign(&ptr, alignment, nbytes);\n    if (ret != 0) throw std::bad_alloc();\n#endif\n    return ptr;\n  }\n\n  void FreeDataSpace(DGLContext ctx, void* ptr) final {\n    TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n    if (tensor_dispatcher->IsAvailable())\n      return tensor_dispatcher->CPUFreeWorkspace(ptr);\n\n#if _MSC_VER || defined(__MINGW32__)\n    _aligned_free(ptr);\n#else\n    free(ptr);\n#endif\n  }\n\n  void CopyDataFromTo(\n      const void* from, size_t from_offset, void* to, size_t to_offset,\n      size_t size, DGLContext ctx_from, DGLContext ctx_to,\n      DGLDataType type_hint) final {\n    memcpy(\n        static_cast<char*>(to) + to_offset,\n        static_cast<const char*>(from) + from_offset, size);\n  }\n\n  void 
RecordedCopyDataFromTo(\n      void* from, size_t from_offset, void* to, size_t to_offset, size_t size,\n      DGLContext ctx_from, DGLContext ctx_to, DGLDataType type_hint,\n      void* pytorch_ctx) final {\n    BUG_IF_FAIL(false) << \"This piece of code should not be reached.\";\n  }\n\n  DGLStreamHandle CreateStream(DGLContext) final { return nullptr; }\n\n  void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {}\n\n  void* AllocWorkspace(\n      DGLContext ctx, size_t size, DGLDataType type_hint) final;\n  void FreeWorkspace(DGLContext ctx, void* data) final;\n\n  static const std::shared_ptr<CPUDeviceAPI>& Global() {\n    static std::shared_ptr<CPUDeviceAPI> inst =\n        std::make_shared<CPUDeviceAPI>();\n    return inst;\n  }\n};\n\nstruct CPUWorkspacePool : public WorkspacePool {\n  CPUWorkspacePool() : WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}\n};\n\nvoid* CPUDeviceAPI::AllocWorkspace(\n    DGLContext ctx, size_t size, DGLDataType type_hint) {\n  TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n  if (tensor_dispatcher->IsAvailable()) {\n    return tensor_dispatcher->CPUAllocWorkspace(size);\n  }\n\n  return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(\n      ctx, size);\n}\n\nvoid CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) {\n  TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n  if (tensor_dispatcher->IsAvailable()) {\n    return tensor_dispatcher->CPUFreeWorkspace(data);\n  }\n\n  dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);\n}\n\nDGL_REGISTER_GLOBAL(\"device_api.cpu\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      DeviceAPI* ptr = CPUDeviceAPI::Global().get();\n      *rv = static_cast<void*>(ptr);\n    });\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/cuda/cuda_common.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file cuda_common.h\n * @brief Common utilities for CUDA\n */\n#ifndef DGL_RUNTIME_CUDA_CUDA_COMMON_H_\n#define DGL_RUNTIME_CUDA_CUDA_COMMON_H_\n\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n#include <curand.h>\n#include <cusparse.h>\n#include <dgl/runtime/packed_func.h>\n\n#include <memory>\n#include <string>\n\n#include \"../workspace_pool.h\"\n\nnamespace dgl {\nnamespace runtime {\n\n/*\n  How to use this class to get a nonblocking thrust execution policy that uses\n  DGL's memory pool and the current cuda stream\n\n  runtime::CUDAWorkspaceAllocator allocator(ctx);\n  const auto stream = runtime::getCurrentCUDAStream();\n  const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);\n\n  now, one can pass exec_policy to thrust functions\n\n  to get an integer array of size 1000 whose lifetime is managed by unique_ptr,\n  use: auto int_array = allocator.alloc_unique<int>(1000); int_array.get() gives\n  the raw pointer.\n*/\nclass CUDAWorkspaceAllocator {\n  DGLContext ctx;\n\n public:\n  typedef char value_type;\n\n  void operator()(void* ptr) const {\n    runtime::DeviceAPI::Get(ctx)->FreeWorkspace(ctx, ptr);\n  }\n\n  explicit CUDAWorkspaceAllocator(DGLContext ctx) : ctx(ctx) {}\n\n  CUDAWorkspaceAllocator& operator=(const CUDAWorkspaceAllocator&) = default;\n\n  template <typename T>\n  std::unique_ptr<T, CUDAWorkspaceAllocator> alloc_unique(\n      std::size_t size) const {\n    return std::unique_ptr<T, CUDAWorkspaceAllocator>(\n        reinterpret_cast<T*>(runtime::DeviceAPI::Get(ctx)->AllocWorkspace(\n            ctx, sizeof(T) * size)),\n        *this);\n  }\n\n  char* allocate(std::ptrdiff_t size) const {\n    return reinterpret_cast<char*>(\n        runtime::DeviceAPI::Get(ctx)->AllocWorkspace(ctx, size));\n  }\n\n  void deallocate(char* ptr, std::size_t) const {\n    runtime::DeviceAPI::Get(ctx)->FreeWorkspace(ctx, ptr);\n  }\n};\n\ntemplate <typename T>\ninline bool 
is_zero(T size) {\n  return size == 0;\n}\n\ntemplate <>\ninline bool is_zero<dim3>(dim3 size) {\n  return size.x == 0 || size.y == 0 || size.z == 0;\n}\n\n#define CUDA_DRIVER_CALL(x)                                             \\\n  {                                                                     \\\n    CUresult result = x;                                                \\\n    if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \\\n      const char* msg;                                                  \\\n      cuGetErrorName(result, &msg);                                     \\\n      LOG(FATAL) << \"CUDAError: \" #x \" failed with error: \" << msg;     \\\n    }                                                                   \\\n  }\n\n#define CUDA_CALL(func)                                      \\\n  {                                                          \\\n    cudaError_t e = (func);                                  \\\n    CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \\\n        << \"CUDA: \" << cudaGetErrorString(e);                \\\n  }\n\n#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...)            
\\\n  {                                                                           \\\n    if (!dgl::runtime::is_zero((nblks)) && !dgl::runtime::is_zero((nthrs))) { \\\n      (kernel)<<<(nblks), (nthrs), (shmem), (stream)>>>(__VA_ARGS__);         \\\n      cudaError_t e = cudaGetLastError();                                     \\\n      CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading)                \\\n          << \"CUDA kernel launch error: \" << cudaGetErrorString(e);           \\\n    }                                                                         \\\n  }\n\n#define CUSPARSE_CALL(func)                                         \\\n  {                                                                 \\\n    cusparseStatus_t e = (func);                                    \\\n    CHECK(e == CUSPARSE_STATUS_SUCCESS) << \"CUSPARSE ERROR: \" << e; \\\n  }\n\n#define CUBLAS_CALL(func)                                       \\\n  {                                                             \\\n    cublasStatus_t e = (func);                                  \\\n    CHECK(e == CUBLAS_STATUS_SUCCESS) << \"CUBLAS ERROR: \" << e; \\\n  }\n\n#define CURAND_CALL(func)                                                      \\\n  {                                                                            \\\n    curandStatus_t e = (func);                                                 \\\n    CHECK(e == CURAND_STATUS_SUCCESS)                                          \\\n        << \"CURAND Error: \" << dgl::runtime::curandGetErrorString(e) << \" at \" \\\n        << __FILE__ << \":\" << __LINE__;                                        \\\n  }\n\ninline const char* curandGetErrorString(curandStatus_t error) {\n  switch (error) {\n    case CURAND_STATUS_SUCCESS:\n      return \"CURAND_STATUS_SUCCESS\";\n    case CURAND_STATUS_VERSION_MISMATCH:\n      return \"CURAND_STATUS_VERSION_MISMATCH\";\n    case CURAND_STATUS_NOT_INITIALIZED:\n      return 
\"CURAND_STATUS_NOT_INITIALIZED\";\n    case CURAND_STATUS_ALLOCATION_FAILED:\n      return \"CURAND_STATUS_ALLOCATION_FAILED\";\n    case CURAND_STATUS_TYPE_ERROR:\n      return \"CURAND_STATUS_TYPE_ERROR\";\n    case CURAND_STATUS_OUT_OF_RANGE:\n      return \"CURAND_STATUS_OUT_OF_RANGE\";\n    case CURAND_STATUS_LENGTH_NOT_MULTIPLE:\n      return \"CURAND_STATUS_LENGTH_NOT_MULTIPLE\";\n    case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:\n      return \"CURAND_STATUS_DOUBLE_PRECISION_REQUIRED\";\n    case CURAND_STATUS_LAUNCH_FAILURE:\n      return \"CURAND_STATUS_LAUNCH_FAILURE\";\n    case CURAND_STATUS_PREEXISTING_FAILURE:\n      return \"CURAND_STATUS_PREEXISTING_FAILURE\";\n    case CURAND_STATUS_INITIALIZATION_FAILED:\n      return \"CURAND_STATUS_INITIALIZATION_FAILED\";\n    case CURAND_STATUS_ARCH_MISMATCH:\n      return \"CURAND_STATUS_ARCH_MISMATCH\";\n    case CURAND_STATUS_INTERNAL_ERROR:\n      return \"CURAND_STATUS_INTERNAL_ERROR\";\n  }\n  // To suppress compiler warning.\n  return \"Unrecognized curand error string\";\n}\n\n/**\n * @brief Cast data type to cudaDataType_t.\n */\ntemplate <typename T>\nstruct cuda_dtype {\n  static constexpr cudaDataType_t value = CUDA_R_32F;\n};\n\ntemplate <>\nstruct cuda_dtype<__half> {\n  static constexpr cudaDataType_t value = CUDA_R_16F;\n};\n\n#if BF16_ENABLED\ntemplate <>\nstruct cuda_dtype<__nv_bfloat16> {\n  static constexpr cudaDataType_t value = CUDA_R_16BF;\n};\n#endif  // BF16_ENABLED\n\ntemplate <>\nstruct cuda_dtype<float> {\n  static constexpr cudaDataType_t value = CUDA_R_32F;\n};\n\ntemplate <>\nstruct cuda_dtype<double> {\n  static constexpr cudaDataType_t value = CUDA_R_64F;\n};\n\n/*\n * \\brief Accumulator type for SpMM.\n */\ntemplate <typename T>\nstruct accum_dtype {\n  typedef float type;\n};\n\ntemplate <>\nstruct accum_dtype<__half> {\n  typedef float type;\n};\n\n#if BF16_ENABLED\ntemplate <>\nstruct accum_dtype<__nv_bfloat16> {\n  typedef float type;\n};\n#endif  // 
BF16_ENABLED\n\ntemplate <>\nstruct accum_dtype<float> {\n  typedef float type;\n};\n\ntemplate <>\nstruct accum_dtype<double> {\n  typedef double type;\n};\n\n#if CUDART_VERSION >= 11000\n/**\n * @brief Cast index data type to cusparseIndexType_t.\n */\ntemplate <typename T>\nstruct cusparse_idtype {\n  static constexpr cusparseIndexType_t value = CUSPARSE_INDEX_32I;\n};\n\ntemplate <>\nstruct cusparse_idtype<int32_t> {\n  static constexpr cusparseIndexType_t value = CUSPARSE_INDEX_32I;\n};\n\ntemplate <>\nstruct cusparse_idtype<int64_t> {\n  static constexpr cusparseIndexType_t value = CUSPARSE_INDEX_64I;\n};\n#endif\n\n/** @brief Thread local workspace */\nclass CUDAThreadEntry {\n public:\n  /** @brief The cusparse handler */\n  cusparseHandle_t cusparse_handle{nullptr};\n  /** @brief The cublas handler */\n  cublasHandle_t cublas_handle{nullptr};\n  /** @brief thread local pool*/\n  WorkspacePool pool;\n  /** @brief constructor */\n  CUDAThreadEntry();\n  // get the threadlocal workspace\n  static CUDAThreadEntry* ThreadLocal();\n};\n\n/** @brief Get the current CUDA stream */\ncudaStream_t getCurrentCUDAStream();\n}  // namespace runtime\n}  // namespace dgl\n#endif  // DGL_RUNTIME_CUDA_CUDA_COMMON_H_\n"
  },
  {
    "path": "src/runtime/cuda/cuda_device_api.cc",
    "content": "/**\n *  Copyright (c) 2017-2022 by Contributors\n * @file cuda_device_api.cc\n * @brief GPU specific API\n */\n#include <cuda_runtime.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/registry.h>\n#include <dgl/runtime/tensordispatch.h>\n#include <dmlc/thread_local.h>\n\n#include \"cuda_common.h\"\n\nnamespace dgl {\nnamespace runtime {\n\nclass CUDADeviceAPI final : public DeviceAPI {\n public:\n  CUDADeviceAPI() {\n    int count;\n    auto err = cudaGetDeviceCount(&count);\n    switch (err) {\n      case cudaSuccess:\n        break;\n      default:\n        count = 0;\n        cudaGetLastError();\n    }\n    is_available_ = count > 0;\n  }\n\n  bool IsAvailable() final { return is_available_; }\n\n  void SetDevice(DGLContext ctx) final {\n    CUDA_CALL(cudaSetDevice(ctx.device_id));\n  }\n  void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) final {\n    int value = 0;\n    switch (kind) {\n      case kExist:\n        value =\n            (cudaDeviceGetAttribute(\n                 &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) ==\n             cudaSuccess);\n        break;\n      case kMaxThreadsPerBlock: {\n        CUDA_CALL(cudaDeviceGetAttribute(\n            &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));\n        break;\n      }\n      case kWarpSize: {\n        CUDA_CALL(\n            cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, ctx.device_id));\n        break;\n      }\n      case kMaxSharedMemoryPerBlock: {\n        CUDA_CALL(cudaDeviceGetAttribute(\n            &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id));\n        break;\n      }\n      case kComputeVersion: {\n        std::ostringstream os;\n        CUDA_CALL(cudaDeviceGetAttribute(\n            &value, cudaDevAttrComputeCapabilityMajor, ctx.device_id));\n        os << value << \".\";\n        CUDA_CALL(cudaDeviceGetAttribute(\n            &value, cudaDevAttrComputeCapabilityMinor, ctx.device_id));\n        os << value;\n 
       *rv = os.str();\n        return;\n      }\n      case kDeviceName: {\n        cudaDeviceProp props;\n        CUDA_CALL(cudaGetDeviceProperties(&props, ctx.device_id));\n        *rv = std::string(props.name);\n        return;\n      }\n      case kMaxClockRate: {\n        CUDA_CALL(cudaDeviceGetAttribute(\n            &value, cudaDevAttrClockRate, ctx.device_id));\n        break;\n      }\n      case kMultiProcessorCount: {\n        CUDA_CALL(cudaDeviceGetAttribute(\n            &value, cudaDevAttrMultiProcessorCount, ctx.device_id));\n        break;\n      }\n      case kMaxThreadDimensions: {\n        int dims[3];\n        CUDA_CALL(cudaDeviceGetAttribute(\n            &dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id));\n        CUDA_CALL(cudaDeviceGetAttribute(\n            &dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id));\n        CUDA_CALL(cudaDeviceGetAttribute(\n            &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));\n\n        std::stringstream ss;  // use json string to return multiple int values;\n        ss << \"[\" << dims[0] << \", \" << dims[1] << \", \" << dims[2] << \"]\";\n        *rv = ss.str();\n        return;\n      }\n    }\n    *rv = value;\n  }\n  void* AllocDataSpace(\n      DGLContext ctx, size_t nbytes, size_t alignment,\n      DGLDataType type_hint) final {\n    SetDevice(ctx);\n    // Redirect to PyTorch's allocator when available.\n    TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n    if (tensor_dispatcher->IsAvailable()) {\n      return tensor_dispatcher->CUDAAllocWorkspace(\n          nbytes, getCurrentCUDAStream());\n    }\n    CHECK_EQ(256 % alignment, 0U) << \"CUDA space is aligned at 256 bytes\";\n    void* ret;\n    CUDA_CALL(cudaMalloc(&ret, nbytes));\n    return ret;\n  }\n\n  void FreeDataSpace(DGLContext ctx, void* ptr) final {\n    SetDevice(ctx);\n    TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n    if (tensor_dispatcher->IsAvailable()) {\n      return 
tensor_dispatcher->CUDAFreeWorkspace(ptr);\n    }\n    CUDA_CALL(cudaFree(ptr));\n  }\n\n  void CopyDataFromTo(\n      const void* from, size_t from_offset, void* to, size_t to_offset,\n      size_t size, DGLContext ctx_from, DGLContext ctx_to,\n      DGLDataType type_hint, DGLStreamHandle stream) {\n    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);\n    from = static_cast<const char*>(from) + from_offset;\n    to = static_cast<char*>(to) + to_offset;\n    if (ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCUDA) {\n      CUDA_CALL(cudaSetDevice(ctx_from.device_id));\n      if (ctx_from.device_id == ctx_to.device_id) {\n        GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);\n      } else {\n        CUDA_CALL(cudaMemcpyPeerAsync(\n            to, ctx_to.device_id, from, ctx_from.device_id, size, cu_stream));\n      }\n    } else if (\n        ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCPU) {\n      CUDA_CALL(cudaSetDevice(ctx_from.device_id));\n      GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);\n    } else if (\n        ctx_from.device_type == kDGLCPU && ctx_to.device_type == kDGLCUDA) {\n      CUDA_CALL(cudaSetDevice(ctx_to.device_id));\n      GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);\n    } else {\n      LOG(FATAL) << \"expect copy from/to GPU or between GPU\";\n    }\n  }\n\n  void CopyDataFromTo(\n      const void* from, size_t from_offset, void* to, size_t to_offset,\n      size_t size, DGLContext ctx_from, DGLContext ctx_to,\n      DGLDataType type_hint) final {\n    auto stream = GetStream();\n    CopyDataFromTo(\n        from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint,\n        stream);\n  }\n\n  // To ensure correct behavior, `record_event` must be invoked anytime a\n  // pointer from PyTorch CachingHostAllocator is used in a cudaMemcpyAsync\n  // call. 
It provides a way to re-use freed pinned (page-locked) memory\n  // allocations and avoid device sync due to cudaFreeHost calls.\n  void RecordedCopyDataFromTo(\n      void* from, size_t from_offset, void* to, size_t to_offset, size_t size,\n      DGLContext ctx_from, DGLContext ctx_to, DGLDataType type_hint,\n      void* pytorch_ctx) final {\n    auto stream = GetStream();\n    CopyDataFromTo(\n        from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint,\n        stream);\n    auto tensor_dispatcher = TensorDispatcher::Global();\n    if (tensor_dispatcher->IsAvailable()) {\n      auto custream = static_cast<cudaStream_t>(stream);\n      void* ptr = ctx_to.device_type == kDGLCPU ? to : from;\n      int id =\n          ctx_to.device_type == kDGLCPU ? ctx_from.device_id : ctx_to.device_id;\n      tensor_dispatcher->CUDARecordHostAlloc(ptr, pytorch_ctx, custream, id);\n    }\n  }\n\n  DGLStreamHandle CreateStream(DGLContext ctx) {\n    CUDA_CALL(cudaSetDevice(ctx.device_id));\n    cudaStream_t retval;\n    // make sure the legacy default stream won't block on this stream\n    CUDA_CALL(cudaStreamCreateWithFlags(&retval, cudaStreamNonBlocking));\n    return static_cast<DGLStreamHandle>(retval);\n  }\n\n  void FreeStream(DGLContext ctx, DGLStreamHandle stream) {\n    CUDA_CALL(cudaSetDevice(ctx.device_id));\n    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);\n    CUDA_CALL(cudaStreamDestroy(cu_stream));\n  }\n\n  void SyncStreamFromTo(\n      DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {\n    CUDA_CALL(cudaSetDevice(ctx.device_id));\n    cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);\n    cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);\n    cudaEvent_t evt;\n    CUDA_CALL(cudaEventCreate(&evt));\n    CUDA_CALL(cudaEventRecord(evt, src_stream));\n    CUDA_CALL(cudaStreamWaitEvent(dst_stream, evt, 0));\n    CUDA_CALL(cudaEventDestroy(evt));\n  }\n\n  void StreamSync(DGLContext ctx, 
DGLStreamHandle stream) final {\n    CUDA_CALL(cudaSetDevice(ctx.device_id));\n    CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));\n  }\n\n  /** NOTE: If the backend is PyTorch, we will use PyTorch's stream management,\n   *        so just avoid calling our SetStream/CreateStream unless\n   *        you really need advanced stream control.\n   * TODO(Xin): Redirect this to PyTorch or remove it.\n   * PyTorch allows external CUDA streams to be set as current since v1.11.\n   */\n  void SetStream(DGLContext ctx, DGLStreamHandle stream) final {}\n\n  DGLStreamHandle GetStream() const final {\n    return static_cast<DGLStreamHandle>(getCurrentCUDAStream());\n  }\n\n  /** NOTE: cudaHostRegister can be called from an arbitrary GPU device,\n   *        so we don't need to specify a ctx.\n   *        The pinned memory can be seen by all CUDA contexts,\n   *        not just the one that performed the allocation\n   */\n  bool PinData(void* ptr, size_t nbytes) override {\n    // prevent users from pinning empty tensors or graphs\n    if (ptr == nullptr || nbytes == 0) return false;\n    TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n    // Minimize the pinned memory pool allocated by backend (via tensoradapter)\n    // to preserve enough memory for DGL inherited in-place pin-memory operation\n    if (tensor_dispatcher->IsAvailable()) {\n      tensor_dispatcher->CUDAHostAllocatorEmptyCache();\n    }\n    CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));\n    return true;\n  }\n\n  void UnpinData(void* ptr) {\n    if (ptr == nullptr) return;\n    CUDA_CALL(cudaHostUnregister(ptr));\n  }\n\n  void* AllocPinnedDataSpace(\n      size_t nbytes, void** ctx, void** deleter) override {\n    // prevent pinning empty tensors or graphs\n    if (nbytes == 0) return nullptr;\n    TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n    CHECK(tensor_dispatcher->IsAvailable())\n        << \"CachingHostAllocator is not 
available in the current backend \"\n           \"PyTorch. Please update the PyTorch version to 1.11+\";\n    return tensor_dispatcher->CUDAAllocHostWorkspace(nbytes, ctx, deleter);\n  }\n\n  void FreePinnedDataSpace(void** deleter) override {\n    TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n    CHECK(tensor_dispatcher->IsAvailable())\n        << \"CachingHostAllocator is not available in the current backend \"\n           \"PyTorch. Please update the PyTorch version to 1.11+\";\n    tensor_dispatcher->CUDAFreeHostWorkspace(deleter);\n  }\n\n  bool IsPinned(const void* ptr) override {\n    // can't be a pinned tensor if CUDA context is unavailable.\n    if (!is_available_) return false;\n\n    cudaPointerAttributes attr;\n    cudaError_t status = cudaPointerGetAttributes(&attr, ptr);\n    bool result = false;\n\n    switch (status) {\n      case cudaErrorInvalidValue:\n        // might be a normal CPU tensor in CUDA 10.2-\n        cudaGetLastError();  // clear error\n        break;\n      case cudaSuccess:\n        result = (attr.type == cudaMemoryTypeHost);\n        break;\n      case cudaErrorInitializationError:\n      case cudaErrorNoDevice:\n      case cudaErrorInsufficientDriver:\n      case cudaErrorInvalidDevice:\n        // We don't want to fail in these particular cases since this function\n        // can be called when users only want to run on CPU even if CUDA API is\n        // enabled, or in a forked subprocess where CUDA context cannot be\n        // initialized.  
So we just mark the CUDA context to unavailable and\n        // return.\n        is_available_ = false;\n        cudaGetLastError();  // clear error\n        break;\n      default:\n        LOG(FATAL) << \"error while determining memory status: \"\n                   << cudaGetErrorString(status);\n        break;\n    }\n\n    return result;\n  }\n\n  void* AllocWorkspace(\n      DGLContext ctx, size_t size, DGLDataType type_hint) final {\n    SetDevice(ctx);\n    // Redirect to PyTorch's allocator when available.\n    TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n    if (tensor_dispatcher->IsAvailable())\n      return tensor_dispatcher->CUDAAllocWorkspace(\n          size, getCurrentCUDAStream());\n\n    return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);\n  }\n\n  void FreeWorkspace(DGLContext ctx, void* data) final {\n    SetDevice(ctx);\n    TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n    if (tensor_dispatcher->IsAvailable())\n      return tensor_dispatcher->CUDAFreeWorkspace(data);\n\n    CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);\n  }\n\n  static const std::shared_ptr<CUDADeviceAPI>& Global() {\n    static std::shared_ptr<CUDADeviceAPI> inst =\n        std::make_shared<CUDADeviceAPI>();\n    return inst;\n  }\n\n private:\n  static void GPUCopy(\n      const void* from, void* to, size_t size, cudaMemcpyKind kind,\n      cudaStream_t stream) {\n    CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));\n    if (stream == 0 && kind == cudaMemcpyDeviceToHost) {\n      // only wait for the copy, when it's on the default stream, and it's to\n      // host memory\n      CUDA_CALL(cudaStreamSynchronize(stream));\n    }\n  }\n\n  bool is_available_ = true;\n};\n\ntypedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;\n\nCUDAThreadEntry::CUDAThreadEntry() : pool(kDGLCUDA, CUDADeviceAPI::Global()) {}\n\nCUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {\n  return 
CUDAThreadStore::Get();\n}\n\ncudaStream_t getCurrentCUDAStream() {\n  TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n  if (tensor_dispatcher->IsAvailable())\n    return tensor_dispatcher->CUDAGetCurrentStream();\n  else  // return the default stream when TA is not available\n    return nullptr;\n}\n\nDGL_REGISTER_GLOBAL(\"device_api.cuda\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      DeviceAPI* ptr = CUDADeviceAPI::Global().get();\n      *rv = static_cast<void*>(ptr);\n    });\n\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/cuda/cuda_hashtable.cu",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file runtime/cuda/cuda_device_common.cuh\n * @brief Device level functions for within cuda kernels.\n */\n\n#include <cassert>\n#include <cub/cub.cuh>  // NOLINT\n\n#include \"../../array/cuda/atomic.cuh\"\n#include \"cuda_common.h\"\n#include \"cuda_hashtable.cuh\"\n\nusing namespace dgl::aten::cuda;\n\nnamespace dgl {\nnamespace runtime {\nnamespace cuda {\n\nnamespace {\n\nconstexpr static const int BLOCK_SIZE = 256;\nconstexpr static const size_t TILE_SIZE = 1024;\n\n/**\n * @brief This is the mutable version of the DeviceOrderedHashTable, for use in\n * inserting elements into the hashtable.\n *\n * @tparam IdType The type of ID to store in the hashtable.\n */\ntemplate <typename IdType>\nclass MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {\n public:\n  typedef typename DeviceOrderedHashTable<IdType>::Mapping* Iterator;\n  static constexpr IdType kEmptyKey = DeviceOrderedHashTable<IdType>::kEmptyKey;\n\n  /**\n   * @brief Create a new mutable hashtable for use on the device.\n   *\n   * @param hostTable The original hash table on the host.\n   */\n  explicit MutableDeviceOrderedHashTable(\n      OrderedHashTable<IdType>* const hostTable)\n      : DeviceOrderedHashTable<IdType>(hostTable->DeviceHandle()) {}\n\n  /**\n   * @brief Find the mutable mapping of a given key within the hash table.\n   *\n   * WARNING: The key must exist within the hashtable. 
Searching for a key not\n   * in the hashtable is undefined behavior.\n   *\n   * @param id The key to search for.\n   *\n   * @return The mapping.\n   */\n  inline __device__ Iterator Search(const IdType id) {\n    const IdType pos = SearchForPosition(id);\n\n    return GetMutable(pos);\n  }\n\n  /**\n   * @brief Attempt to insert into the hash table at a specific location.\n   *\n   * @param pos The position to insert at.\n   * @param id The ID to insert into the hash table.\n   * @param index The original index of the item being inserted.\n   *\n   * @return True, if the insertion was successful.\n   */\n  inline __device__ bool AttemptInsertAt(\n      const size_t pos, const IdType id, const size_t index) {\n    const IdType key = AtomicCAS(&GetMutable(pos)->key, kEmptyKey, id);\n    if (key == kEmptyKey || key == id) {\n      // we either set a match key, or found a matching key, so then place the\n      // minimum index in position. Match the type of atomicMin, so ignore\n      // linting\n      atomicMin(\n          reinterpret_cast<unsigned long long*>(  // NOLINT\n              &GetMutable(pos)->index),\n          static_cast<unsigned long long>(index));  // NOLINT\n      return true;\n    } else {\n      // we need to search elsewhere\n      return false;\n    }\n  }\n\n  /**\n   * @brief Insert key-index pair into the hashtable.\n   *\n   * @param id The ID to insert.\n   * @param index The index at which the ID occurred.\n   *\n   * @return An iterator to the inserted mapping.\n   */\n  inline __device__ Iterator Insert(const IdType id, const size_t index) {\n    size_t pos = Hash(id);\n\n    // quadratically probe for an empty slot or matching entry\n    IdType delta = 1;\n    while (!AttemptInsertAt(pos, id, index)) {\n      pos = Hash(pos + delta);\n      delta += 1;\n    }\n\n    return GetMutable(pos);\n  }\n\n private:\n  /**\n   * @brief Get a mutable iterator to the given bucket in the hashtable.\n   *\n   * @param pos The given bucket.\n   *\n   * @return 
The iterator.\n   */\n  inline __device__ Iterator GetMutable(const size_t pos) {\n    assert(pos < this->size_);\n    // The parent class Device is read-only, but we ensure this can only be\n    // constructed from a mutable version of OrderedHashTable, making this\n    // a safe cast to perform.\n    return const_cast<Iterator>(this->table_ + pos);\n  }\n};\n\n/**\n * @brief Calculate the number of buckets in the hashtable. To guarantee we can\n * fill the hashtable in the worst case, we must use a number of buckets which\n * is a power of two.\n * https://en.wikipedia.org/wiki/Quadratic_probing#Limitations\n *\n * @param num The number of items to insert (should be an upper bound on the\n * number of unique keys).\n * @param scale The power of two larger the number of buckets should be than the\n * unique keys.\n *\n * @return The number of buckets the table should contain.\n */\nsize_t TableSize(const size_t num, const int scale) {\n  const size_t next_pow2 = 1 << static_cast<size_t>(1 + std::log2(num >> 1));\n  return next_pow2 << scale;\n}\n\n/**\n * @brief This structure is used with cub's block-level prefixscan in order to\n * keep a running sum as items are iteratively processed.\n *\n * @tparam IdType The type to perform the prefixsum on.\n */\ntemplate <typename IdType>\nstruct BlockPrefixCallbackOp {\n  IdType running_total_;\n\n  __device__ BlockPrefixCallbackOp(const IdType running_total)\n      : running_total_(running_total) {}\n\n  __device__ IdType operator()(const IdType block_aggregate) {\n    const IdType old_prefix = running_total_;\n    running_total_ += block_aggregate;\n    return old_prefix;\n  }\n};\n\n}  // namespace\n\n/**\n * @brief This generates a hash map where the keys are the global item numbers,\n * and the values are indexes, and inputs may have duplicates.\n *\n * @tparam IdType The type of id.\n * @tparam BLOCK_SIZE The size of the thread block.\n * @tparam TILE_SIZE The number of entries each thread block will process.\n * 
@param items The items to insert.\n * @param num_items The number of items to insert.\n * @param table The hash table.\n */\ntemplate <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>\n__global__ void generate_hashmap_duplicates(\n    const IdType* const items, const int64_t num_items,\n    MutableDeviceOrderedHashTable<IdType> table) {\n  assert(BLOCK_SIZE == blockDim.x);\n\n  const size_t block_start = TILE_SIZE * blockIdx.x;\n  const size_t block_end = TILE_SIZE * (blockIdx.x + 1);\n\n#pragma unroll\n  for (size_t index = threadIdx.x + block_start; index < block_end;\n       index += BLOCK_SIZE) {\n    if (index < num_items) {\n      table.Insert(items[index], index);\n    }\n  }\n}\n\n/**\n * @brief This generates a hash map where the keys are the global item numbers,\n * and the values are indexes, and all inputs are unique.\n *\n * @tparam IdType The type of id.\n * @tparam BLOCK_SIZE The size of the thread block.\n * @tparam TILE_SIZE The number of entries each thread block will process.\n * @param items The unique items to insert.\n * @param num_items The number of items to insert.\n * @param table The hash table.\n */\ntemplate <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>\n__global__ void generate_hashmap_unique(\n    const IdType* const items, const int64_t num_items,\n    MutableDeviceOrderedHashTable<IdType> table) {\n  assert(BLOCK_SIZE == blockDim.x);\n\n  using Iterator = typename MutableDeviceOrderedHashTable<IdType>::Iterator;\n\n  const size_t block_start = TILE_SIZE * blockIdx.x;\n  const size_t block_end = TILE_SIZE * (blockIdx.x + 1);\n\n#pragma unroll\n  for (size_t index = threadIdx.x + block_start; index < block_end;\n       index += BLOCK_SIZE) {\n    if (index < num_items) {\n      const Iterator pos = table.Insert(items[index], index);\n\n      // since we are only inserting unique items, we know their local id\n      // will be equal to their index\n      pos->local = static_cast<IdType>(index);\n    }\n  }\n}\n\n/**\n * 
@brief This counts the number of nodes inserted per thread block.\n *\n * @tparam IdType The type of id.\n * @tparam BLOCK_SIZE The size of the thread block.\n * @tparam TILE_SIZE The number of entries each thread block will process.\n * @param items The nodes to insert.\n * @param num_items The number of nodes to insert.\n * @param table The hash table.\n * @param num_unique The number of nodes inserted into the hash table per thread\n * block.\n */\ntemplate <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>\n__global__ void count_hashmap(\n    const IdType* items, const size_t num_items,\n    DeviceOrderedHashTable<IdType> table, IdType* const num_unique) {\n  assert(BLOCK_SIZE == blockDim.x);\n\n  using BlockReduce = typename cub::BlockReduce<IdType, BLOCK_SIZE>;\n  using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;\n\n  const size_t block_start = TILE_SIZE * blockIdx.x;\n  const size_t block_end = TILE_SIZE * (blockIdx.x + 1);\n\n  IdType count = 0;\n\n#pragma unroll\n  for (size_t index = threadIdx.x + block_start; index < block_end;\n       index += BLOCK_SIZE) {\n    if (index < num_items) {\n      const Mapping& mapping = *table.Search(items[index]);\n      if (mapping.index == index) {\n        ++count;\n      }\n    }\n  }\n\n  __shared__ typename BlockReduce::TempStorage temp_space;\n\n  count = BlockReduce(temp_space).Sum(count);\n\n  if (threadIdx.x == 0) {\n    num_unique[blockIdx.x] = count;\n    if (blockIdx.x == 0) {\n      num_unique[gridDim.x] = 0;\n    }\n  }\n}\n\n/**\n * @brief Update the local numbering of elements in the hashmap.\n *\n * @tparam IdType The type of id.\n * @tparam BLOCK_SIZE The size of the thread blocks.\n * @tparam TILE_SIZE The number of elements each thread block works on.\n * @param items The set of non-unique items to update from.\n * @param num_items The number of non-unique items.\n * @param table The hash table.\n * @param num_items_prefix The number of unique items preceding each thread\n * 
block.\n * @param unique_items The set of unique items (output).\n * @param num_unique_items The number of unique items (output).\n */\ntemplate <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>\n__global__ void compact_hashmap(\n    const IdType* const items, const size_t num_items,\n    MutableDeviceOrderedHashTable<IdType> table,\n    const IdType* const num_items_prefix, IdType* const unique_items,\n    int64_t* const num_unique_items) {\n  assert(BLOCK_SIZE == blockDim.x);\n\n  using FlagType = uint16_t;\n  using BlockScan = typename cub::BlockScan<FlagType, BLOCK_SIZE>;\n  using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;\n\n  constexpr const int32_t VALS_PER_THREAD = TILE_SIZE / BLOCK_SIZE;\n\n  __shared__ typename BlockScan::TempStorage temp_space;\n\n  const IdType offset = num_items_prefix[blockIdx.x];\n\n  BlockPrefixCallbackOp<FlagType> prefix_op(0);\n\n  // count successful placements\n  for (int32_t i = 0; i < VALS_PER_THREAD; ++i) {\n    const IdType index = threadIdx.x + i * BLOCK_SIZE + blockIdx.x * TILE_SIZE;\n\n    FlagType flag;\n    Mapping* kv;\n    if (index < num_items) {\n      kv = table.Search(items[index]);\n      flag = kv->index == index;\n    } else {\n      flag = 0;\n    }\n\n    if (!flag) {\n      kv = nullptr;\n    }\n\n    BlockScan(temp_space).ExclusiveSum(flag, flag, prefix_op);\n    __syncthreads();\n\n    if (kv) {\n      const IdType pos = offset + flag;\n      kv->local = pos;\n      unique_items[pos] = items[index];\n    }\n  }\n\n  if (threadIdx.x == 0 && blockIdx.x == 0) {\n    *num_unique_items = num_items_prefix[gridDim.x];\n  }\n}\n\n// DeviceOrderedHashTable implementation\n\ntemplate <typename IdType>\nDeviceOrderedHashTable<IdType>::DeviceOrderedHashTable(\n    const Mapping* const table, const size_t size)\n    : table_(table), size_(size) {}\n\ntemplate <typename IdType>\nDeviceOrderedHashTable<IdType> OrderedHashTable<IdType>::DeviceHandle() const {\n  return 
DeviceOrderedHashTable<IdType>(table_, size_);\n}\n\n// OrderedHashTable implementation\n\ntemplate <typename IdType>\nOrderedHashTable<IdType>::OrderedHashTable(\n    const size_t size, DGLContext ctx, cudaStream_t stream, const int scale)\n    : table_(nullptr), size_(TableSize(size, scale)), ctx_(ctx) {\n  // make sure we will have at least as many buckets as items.\n  CHECK_GT(scale, 0);\n\n  auto device = runtime::DeviceAPI::Get(ctx_);\n  table_ = static_cast<Mapping*>(\n      device->AllocWorkspace(ctx_, sizeof(Mapping) * size_));\n\n  CUDA_CALL(cudaMemsetAsync(\n      table_, DeviceOrderedHashTable<IdType>::kEmptyKey,\n      sizeof(Mapping) * size_, stream));\n}\n\ntemplate <typename IdType>\nOrderedHashTable<IdType>::~OrderedHashTable() {\n  auto device = runtime::DeviceAPI::Get(ctx_);\n  device->FreeWorkspace(ctx_, table_);\n}\n\ntemplate <typename IdType>\nvoid OrderedHashTable<IdType>::FillWithDuplicates(\n    const IdType* const input, const size_t num_input, IdType* const unique,\n    int64_t* const num_unique, cudaStream_t stream) {\n  auto device = runtime::DeviceAPI::Get(ctx_);\n\n  const int64_t num_tiles = (num_input + TILE_SIZE - 1) / TILE_SIZE;\n\n  const dim3 grid(num_tiles);\n  const dim3 block(BLOCK_SIZE);\n\n  auto device_table = MutableDeviceOrderedHashTable<IdType>(this);\n\n  CUDA_KERNEL_CALL(\n      (generate_hashmap_duplicates<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block,\n      0, stream, input, num_input, device_table);\n\n  IdType* item_prefix = static_cast<IdType*>(\n      device->AllocWorkspace(ctx_, sizeof(IdType) * (num_input + 1)));\n\n  CUDA_KERNEL_CALL(\n      (count_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,\n      input, num_input, device_table, item_prefix);\n\n  size_t workspace_bytes;\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      nullptr, workspace_bytes, static_cast<IdType*>(nullptr),\n      static_cast<IdType*>(nullptr), grid.x + 1, stream));\n  void* workspace = device->AllocWorkspace(ctx_, 
workspace_bytes);\n\n  CUDA_CALL(cub::DeviceScan::ExclusiveSum(\n      workspace, workspace_bytes, item_prefix, item_prefix, grid.x + 1,\n      stream));\n  device->FreeWorkspace(ctx_, workspace);\n\n  CUDA_KERNEL_CALL(\n      (compact_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,\n      input, num_input, device_table, item_prefix, unique, num_unique);\n  device->FreeWorkspace(ctx_, item_prefix);\n}\n\ntemplate <typename IdType>\nvoid OrderedHashTable<IdType>::FillWithUnique(\n    const IdType* const input, const size_t num_input, cudaStream_t stream) {\n  const int64_t num_tiles = (num_input + TILE_SIZE - 1) / TILE_SIZE;\n\n  const dim3 grid(num_tiles);\n  const dim3 block(BLOCK_SIZE);\n\n  auto device_table = MutableDeviceOrderedHashTable<IdType>(this);\n\n  CUDA_KERNEL_CALL(\n      (generate_hashmap_unique<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0,\n      stream, input, num_input, device_table);\n}\n\ntemplate class OrderedHashTable<int32_t>;\ntemplate class OrderedHashTable<int64_t>;\n\n}  // namespace cuda\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/cuda/cuda_hashtable.cuh",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file runtime/cuda/cuda_device_common.cuh\n * @brief Device level functions for within cuda kernels.\n */\n\n#ifndef DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_\n#define DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_\n\n#include <dgl/runtime/c_runtime_api.h>\n\n#include \"cuda_common.h\"\n#include \"cuda_runtime.h\"\n\nnamespace dgl {\nnamespace runtime {\nnamespace cuda {\n\ntemplate <typename>\nclass OrderedHashTable;\n\n/**\n * @brief A device-side handle for a GPU hashtable for mapping items to the\n * first index at which they appear in the provided data array.\n *\n * For any ID array A, one can view it as a mapping from the index `i`\n * (continuous integer range from zero) to its element `A[i]`. This hashtable\n * serves as a reverse mapping, i.e., from element `A[i]` to its index `i`.\n * Quadratic probing is used for collision resolution. See\n * DeviceOrderedHashTable's documentation for how the Mapping structure is\n * used.\n *\n * The hash table should be used in two phases, with the first being populating\n * the hash table with the OrderedHashTable object, and then generating this\n * handle from it. 
This object can then be used to search the hash table,\n * to find mappings, from within CUDA code.\n *\n * If a device-side handle is created from a hash table with the following\n * entries:\n * [\n *   {key: 0, local: 0, index: 0},\n *   {key: 3, local: 1, index: 1},\n *   {key: 2, local: 2, index: 2},\n *   {key: 8, local: 3, index: 4},\n *   {key: 4, local: 4, index: 5},\n *   {key: 1, local: 5, index: 8}\n * ]\n * The array [0, 3, 2, 0, 8, 4, 3, 2, 1, 8] could have `Search()` called on\n * each id, to be mapped via:\n * ```\n * __global__ void map(int32_t * array,\n *                     size_t size,\n *                     DeviceOrderedHashTable<int32_t> table) {\n *   int idx = threadIdx.x + blockIdx.x*blockDim.x;\n *   if (idx < size) {\n *     array[idx] = table.Search(array[idx])->local;\n *   }\n * }\n * ```\n * to get the remapped array:\n * [0, 1, 2, 0, 3, 4, 1, 2, 5, 3]\n *\n * @tparam IdType The type of the IDs.\n */\ntemplate <typename IdType>\nclass DeviceOrderedHashTable {\n public:\n  /**\n   * @brief An entry in the hashtable.\n   */\n  struct Mapping {\n    /**\n     * @brief The ID of the item inserted.\n     */\n    IdType key;\n    /**\n     * @brief The index of the item in the unique list.\n     */\n    IdType local;\n    /**\n     * @brief The index of the item when inserted into the hashtable (e.g.,\n     * the index within the array passed into FillWithDuplicates()).\n     */\n    int64_t index;\n  };\n\n  typedef const Mapping* ConstIterator;\n\n  DeviceOrderedHashTable(const DeviceOrderedHashTable& other) = default;\n  DeviceOrderedHashTable& operator=(const DeviceOrderedHashTable& other) =\n      default;\n\n  /**\n   * @brief Find the non-mutable mapping of a given key within the hash table.\n   *\n   * WARNING: The key must exist within the hashtable. 
Searching for a key not\n   * in the hashtable is undefined behavior.\n   *\n   * @param id The key to search for.\n   *\n   * @return An iterator to the mapping.\n   */\n  inline __device__ ConstIterator Search(const IdType id) const {\n    const IdType pos = SearchForPosition(id);\n\n    return &table_[pos];\n  }\n\n  /**\n   * @brief Check whether a key exists within the hashtable.\n   *\n   * @param id The key to check for.\n   *\n   * @return True if the key exists in the hashtable.\n   */\n  inline __device__ bool Contains(const IdType id) const {\n    IdType pos = Hash(id);\n\n    IdType delta = 1;\n    while (table_[pos].key != kEmptyKey) {\n      if (table_[pos].key == id) {\n        return true;\n      }\n      pos = Hash(pos + delta);\n      delta += 1;\n    }\n    return false;\n  }\n\n protected:\n  // Must be uniform bytes for memset to work\n  static constexpr IdType kEmptyKey = static_cast<IdType>(-1);\n\n  const Mapping* table_;\n  size_t size_;\n\n  /**\n   * @brief Create a new device-side handle to the hash table.\n   *\n   * @param table The table stored in GPU memory.\n   * @param size The size of the table.\n   */\n  explicit DeviceOrderedHashTable(const Mapping* table, size_t size);\n\n  /**\n   * @brief Search for an item in the hash table which is known to exist.\n   *\n   * WARNING: If the ID searched for does not exist within the hashtable, this\n   * function will never return.\n   *\n   * @param id The ID of the item to search for.\n   *\n   * @return The position of the item in the hashtable.\n   */\n  inline __device__ IdType SearchForPosition(const IdType id) const {\n    IdType pos = Hash(id);\n\n    // quadratically probe for the matching entry\n    IdType delta = 1;\n    while (table_[pos].key != id) {\n      assert(table_[pos].key != kEmptyKey);\n      pos = Hash(pos + delta);\n      delta += 1;\n    }\n    assert(pos < size_);\n\n    return pos;\n  }\n\n  /**\n   * @brief Hash an ID to a position in the hash table.\n   *\n   * 
@param id The ID to hash.\n   *\n   * @return The hash.\n   */\n  inline __device__ size_t Hash(const IdType id) const { return id % size_; }\n\n  friend class OrderedHashTable<IdType>;\n};\n\n/**\n * @brief A host-side handle for a GPU hashtable for mapping items to the\n * first index at which they appear in the provided data array. This host-side\n * handle is responsible for allocating and freeing the GPU memory of the\n * hashtable.\n *\n * For any ID array A, one can view it as a mapping from the index `i`\n * (continuous integer range from zero) to its element `A[i]`. This hashtable\n * serves as a reverse mapping, i.e., from element `A[i]` to its index `i`.\n * Quadratic probing is used for collision resolution.\n *\n * The hash table should be used in two phases, the first is filling the hash\n * table via 'FillWithDuplicates()' or 'FillWithUnique()'. Then, the\n * 'DeviceHandle()' method can be called, to get a version suitable for\n * searching from device and kernel functions.\n *\n * If 'FillWithDuplicates()' was called with an array of:\n * [0, 3, 2, 0, 8, 4, 3, 2, 1, 8]\n *\n * The resulting entries in the hash-table would be:\n * [\n *   {key: 0, local: 0, index: 0},\n *   {key: 3, local: 1, index: 1},\n *   {key: 2, local: 2, index: 2},\n *   {key: 8, local: 3, index: 4},\n *   {key: 4, local: 4, index: 5},\n *   {key: 1, local: 5, index: 8}\n * ]\n *\n * @tparam IdType The type of the IDs.\n */\ntemplate <typename IdType>\nclass OrderedHashTable {\n public:\n  static constexpr int kDefaultScale = 3;\n\n  using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;\n\n  /**\n   * @brief Create a new ordered hash table. 
The amount of GPU memory\n   * consumed by the resulting hashtable is O(`size` * 2^`scale`).\n   *\n   * @param size The number of items to insert into the hashtable.\n   * @param ctx The device context to store the hashtable on.\n   * @param scale The power of two times larger the number of buckets should\n   * be than the number of items.\n   * @param stream The stream to use for initializing the hashtable.\n   */\n  OrderedHashTable(\n      const size_t size, DGLContext ctx, cudaStream_t stream,\n      const int scale = kDefaultScale);\n\n  /**\n   * @brief Cleanup after the hashtable.\n   */\n  ~OrderedHashTable();\n\n  // Disable copying\n  OrderedHashTable(const OrderedHashTable& other) = delete;\n  OrderedHashTable& operator=(const OrderedHashTable& other) = delete;\n\n  /**\n   * @brief Fill the hashtable with the array containing possibly duplicate\n   * IDs.\n   *\n   * @param input The array of IDs to insert.\n   * @param num_input The number of IDs to insert.\n   * @param unique The list of unique IDs inserted.\n   * @param num_unique The number of unique IDs inserted.\n   * @param stream The stream to perform operations on.\n   */\n  void FillWithDuplicates(\n      const IdType* const input, const size_t num_input, IdType* const unique,\n      int64_t* const num_unique, cudaStream_t stream);\n\n  /**\n   * @brief Fill the hashtable with an array of unique keys.\n   *\n   * @param input The array of unique IDs.\n   * @param num_input The number of keys.\n   * @param stream The stream to perform operations on.\n   */\n  void FillWithUnique(\n      const IdType* const input, const size_t num_input, cudaStream_t stream);\n\n  /**\n   * @brief Get a version of the hashtable usable from device functions.\n   *\n   * @return This hashtable.\n   */\n  DeviceOrderedHashTable<IdType> DeviceHandle() const;\n\n private:\n  Mapping* table_;\n  size_t size_;\n  DGLContext ctx_;\n};\n\n}  // namespace cuda\n}  // namespace runtime\n}  // namespace dgl\n\n#endif  // 
DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_\n"
  },
  {
    "path": "src/runtime/cuda/gpu_cache.cu",
    "content": "/*!\n *  Copyright (c) 2022 by Contributors\n *\n *  Licensed under the Apache License, Version 2.0 (the \"License\");\n *  you may not use this file except in compliance with the License.\n *  You may obtain a copy of the License at\n *\n *      http://www.apache.org/licenses/LICENSE-2.0\n *\n *  Unless required by applicable law or agreed to in writing, software\n *  distributed under the License is distributed on an \"AS IS\" BASIS,\n *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n *  See the License for the specific language governing permissions and\n *  limitations under the License.\n *\n * \\file gpu_cache.cu\n * \\brief Implementation of wrapper HugeCTR gpu_cache routines.\n */\n\n#ifndef DGL_RUNTIME_CUDA_GPU_CACHE_H_\n#define DGL_RUNTIME_CUDA_GPU_CACHE_H_\n\n#include <cuda_runtime.h>\n#include <dgl/array.h>\n#include <dgl/aten/array_ops.h>\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/container.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/object.h>\n#include <dgl/runtime/registry.h>\n\n#include <nv_gpu_cache.hpp>\n\n#include \"../../runtime/cuda/cuda_common.h\"\n\nnamespace dgl {\nnamespace runtime {\nnamespace cuda {\n\ntemplate <typename key_t>\nclass GpuCache : public runtime::Object {\n  constexpr static int set_associativity = 2;\n  constexpr static int WARP_SIZE = 32;\n  constexpr static int bucket_size = WARP_SIZE * set_associativity;\n  using gpu_cache_t = gpu_cache::gpu_cache<\n      key_t, uint64_t, std::numeric_limits<key_t>::max(), set_associativity,\n      WARP_SIZE>;\n\n public:\n  static constexpr const char *_type_key =\n      sizeof(key_t) == 4 ? 
\"cuda.GpuCache32\" : \"cuda.GpuCache64\";\n  DGL_DECLARE_OBJECT_TYPE_INFO(GpuCache, Object);\n\n  GpuCache(size_t num_items, size_t num_feats)\n      : num_feats(num_feats),\n        cache(std::make_unique<gpu_cache_t>(\n            (num_items + bucket_size - 1) / bucket_size, num_feats)) {\n    CUDA_CALL(cudaGetDevice(&cuda_device));\n  }\n\n  std::tuple<NDArray, IdArray, IdArray> Query(IdArray keys) {\n    const auto &ctx = keys->ctx;\n    cudaStream_t stream = dgl::runtime::getCurrentCUDAStream();\n    auto device = dgl::runtime::DeviceAPI::Get(ctx);\n    CHECK_EQ(ctx.device_type, kDGLCUDA)\n        << \"The keys should be on a CUDA device\";\n    CHECK_EQ(ctx.device_id, cuda_device)\n        << \"The keys should be on the correct CUDA device\";\n    CHECK_EQ(keys->ndim, 1)\n        << \"The tensor of requested indices must be of dimension one.\";\n    NDArray values = NDArray::Empty(\n        {keys->shape[0], (int64_t)num_feats}, DGLDataType{kDGLFloat, 32, 1},\n        ctx);\n    IdArray missing_index = aten::NewIdArray(keys->shape[0], ctx, 64);\n    IdArray missing_keys =\n        aten::NewIdArray(keys->shape[0], ctx, sizeof(key_t) * 8);\n    size_t *missing_len =\n        static_cast<size_t *>(device->AllocWorkspace(ctx, sizeof(size_t)));\n    cache->Query(\n        static_cast<const key_t *>(keys->data), keys->shape[0],\n        static_cast<float *>(values->data),\n        static_cast<uint64_t *>(missing_index->data),\n        static_cast<key_t *>(missing_keys->data), missing_len, stream);\n    size_t missing_len_host;\n    device->CopyDataFromTo(\n        missing_len, 0, &missing_len_host, 0, sizeof(missing_len_host), ctx,\n        DGLContext{kDGLCPU, 0}, keys->dtype);\n    device->FreeWorkspace(ctx, missing_len);\n    missing_index = missing_index.CreateView(\n        {(int64_t)missing_len_host}, missing_index->dtype);\n    missing_keys =\n        missing_keys.CreateView({(int64_t)missing_len_host}, keys->dtype);\n    return std::make_tuple(values, 
missing_index, missing_keys);\n  }\n\n  void Replace(IdArray keys, NDArray values) {\n    cudaStream_t stream = dgl::runtime::getCurrentCUDAStream();\n    CHECK_EQ(keys->ctx.device_type, kDGLCUDA)\n        << \"The keys should be on a CUDA device\";\n    CHECK_EQ(keys->ctx.device_id, cuda_device)\n        << \"The keys should be on the correct CUDA device\";\n    CHECK_EQ(values->ctx.device_type, kDGLCUDA)\n        << \"The values should be on a CUDA device\";\n    CHECK_EQ(values->ctx.device_id, cuda_device)\n        << \"The values should be on the correct CUDA device\";\n    CHECK_EQ(keys->shape[0], values->shape[0])\n        << \"First dimensions of keys and values must match\";\n    CHECK_EQ(values->shape[1], num_feats) << \"Embedding dimension must match\";\n    cache->Replace(\n        static_cast<const key_t *>(keys->data), keys->shape[0],\n        static_cast<const float *>(values->data), stream);\n  }\n\n private:\n  size_t num_feats;\n  std::unique_ptr<gpu_cache_t> cache;\n  int cuda_device;\n};\n\nstatic_assert(sizeof(unsigned int) == 4);\nDGL_DEFINE_OBJECT_REF(GpuCacheRef32, GpuCache<unsigned int>);\n// The cu file in HugeCTR gpu cache uses unsigned int and long long.\n// Changing to int64_t results in a mismatch of template arguments.\nstatic_assert(sizeof(long long) == 8);                      // NOLINT\nDGL_DEFINE_OBJECT_REF(GpuCacheRef64, GpuCache<long long>);  // NOLINT\n\n/* CAPI **********************************************************************/\n\nusing namespace dgl::runtime;\n\nDGL_REGISTER_GLOBAL(\"cuda._CAPI_DGLGpuCacheCreate\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      const size_t num_items = args[0];\n      const size_t num_feats = args[1];\n      const int num_bits = args[2];\n\n      if (num_bits == 32)\n        *rv = GpuCacheRef32(\n            std::make_shared<GpuCache<unsigned int>>(num_items, num_feats));\n      else\n        *rv = GpuCacheRef64(std::make_shared<GpuCache<long long>>(  // NOLINT\n            
num_items, num_feats));\n    });\n\nDGL_REGISTER_GLOBAL(\"cuda._CAPI_DGLGpuCacheQuery\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      IdArray keys = args[1];\n\n      List<ObjectRef> ret;\n      if (keys->dtype.bits == 32) {\n        GpuCacheRef32 cache = args[0];\n        auto result = cache->Query(keys);\n\n        ret.push_back(Value(MakeValue(std::get<0>(result))));\n        ret.push_back(Value(MakeValue(std::get<1>(result))));\n        ret.push_back(Value(MakeValue(std::get<2>(result))));\n      } else {\n        GpuCacheRef64 cache = args[0];\n        auto result = cache->Query(keys);\n\n        ret.push_back(Value(MakeValue(std::get<0>(result))));\n        ret.push_back(Value(MakeValue(std::get<1>(result))));\n        ret.push_back(Value(MakeValue(std::get<2>(result))));\n      }\n\n      *rv = ret;\n    });\n\nDGL_REGISTER_GLOBAL(\"cuda._CAPI_DGLGpuCacheReplace\")\n    .set_body([](DGLArgs args, DGLRetValue *rv) {\n      IdArray keys = args[1];\n      NDArray values = args[2];\n\n      if (keys->dtype.bits == 32) {\n        GpuCacheRef32 cache = args[0];\n        cache->Replace(keys, values);\n      } else {\n        GpuCacheRef64 cache = args[0];\n        cache->Replace(keys, values);\n      }\n\n      *rv = List<ObjectRef>{};\n    });\n\n}  // namespace cuda\n}  // namespace runtime\n}  // namespace dgl\n\n#endif\n"
  },
  {
    "path": "src/runtime/dlpack_convert.cc",
    "content": "/**\n *  Copyright (c) 2022 by Contributors\n * @file src/runtime/dlpack_convert.cc\n * @brief Conversion between NDArray and DLPack.\n */\n#include <dgl/runtime/c_runtime_api.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/dlpack_convert.h>\n#include <dgl/runtime/ndarray.h>\n#include <dlpack/dlpack.h>\n\n#include <cstdint>\n\n#include \"runtime_base.h\"\n\n// deleter for arrays used by DLPack exporter\nextern \"C\" void NDArrayDLPackDeleter(DLManagedTensor* tensor);\n\nnamespace dgl {\nnamespace runtime {\n\nvoid NDArrayDLPackDeleter(DLManagedTensor* tensor) {\n  static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef();\n  delete tensor;\n}\n\ninline DGLContext ToDGLContext(const DLDevice& device) {\n  DGLContext ctx;\n  ctx.device_type = static_cast<DGLDeviceType>(device.device_type);\n  ctx.device_id = device.device_id;\n  return ctx;\n}\n\ninline DGLDataType ToDGLDataType(const DLDataType& src) {\n  DGLDataType ret;\n  ret.code = src.code;\n  ret.bits = src.bits;\n  ret.lanes = src.lanes;\n  return ret;\n}\n\ninline DLDevice ToDLDevice(const DGLContext& ctx) {\n  DLDevice device;\n  device.device_type = static_cast<DLDeviceType>(ctx.device_type);\n  device.device_id = ctx.device_id;\n  return device;\n}\n\ninline DLDataType ToDLDataType(const DGLDataType& src) {\n  DLDataType ret;\n  ret.code = src.code;\n  ret.bits = src.bits;\n  ret.lanes = src.lanes;\n  return ret;\n}\n\nNDArray DLPackConvert::FromDLPack(DLManagedTensor* tensor) {\n  NDArray::Container* data = new NDArray::Container();\n  data->deleter = DLPackConvert::DLPackDeleter;\n  data->manager_ctx = tensor;\n  data->dl_tensor.data = tensor->dl_tensor.data;\n  data->dl_tensor.ctx = ToDGLContext(tensor->dl_tensor.device);\n  data->dl_tensor.ndim = tensor->dl_tensor.ndim;\n  data->dl_tensor.dtype = ToDGLDataType(tensor->dl_tensor.dtype);\n  data->dl_tensor.shape = tensor->dl_tensor.shape;\n  data->dl_tensor.strides = tensor->dl_tensor.strides;\n  
data->dl_tensor.byte_offset = tensor->dl_tensor.byte_offset;\n\n  return NDArray(data);\n}\n\nvoid DLPackConvert::DLPackDeleter(NDArray::Container* ptr) {\n  // if the array is pinned by dgl, unpin it before freeing\n  if (ptr->pinned_by_dgl_) NDArray::UnpinContainer(ptr);\n  DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx);\n  if (tensor->deleter != nullptr) {\n    (*tensor->deleter)(tensor);\n  }\n  delete ptr;\n}\n\nDLManagedTensor* ContainerToDLPack(NDArray::Container* from) {\n  CHECK(from != nullptr);\n  DLManagedTensor* ret = new DLManagedTensor();\n  ret->dl_tensor.data = from->dl_tensor.data;\n  ret->dl_tensor.device = ToDLDevice(from->dl_tensor.ctx);\n  ret->dl_tensor.ndim = from->dl_tensor.ndim;\n  ret->dl_tensor.dtype = ToDLDataType(from->dl_tensor.dtype);\n  ret->dl_tensor.shape = from->dl_tensor.shape;\n  ret->dl_tensor.strides = from->dl_tensor.strides;\n  ret->dl_tensor.byte_offset = from->dl_tensor.byte_offset;\n\n  ret->manager_ctx = from;\n  from->IncRef();\n  ret->deleter = NDArrayDLPackDeleter;\n  return ret;\n}\n\nDLManagedTensor* DLPackConvert::ToDLPack(const NDArray& from) {\n  return ContainerToDLPack(from.data_);\n}\n\n}  // namespace runtime\n}  // namespace dgl\n\nusing namespace dgl::runtime;\n\nvoid DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor) {\n  (*(dltensor->deleter))(dltensor);\n}\n\ninline bool IsAligned(const void* ptr, std::uintptr_t alignment) noexcept {\n  auto iptr = reinterpret_cast<std::uintptr_t>(ptr);\n  return !(iptr % alignment);\n}\n\nint DGLArrayFromDLPack(DLManagedTensor* from, DGLArrayHandle* out) {\n  API_BEGIN();\n  *out = NDArray::Internal::MoveAsDGLArray(DLPackConvert::FromDLPack(from));\n  API_END();\n}\n\nint DGLArrayToDLPack(\n    DGLArrayHandle from, DLManagedTensor** out, int alignment) {\n  API_BEGIN();\n  auto* nd_container = reinterpret_cast<NDArray::Container*>(from);\n  DGLArray* nd = &(nd_container->dl_tensor);\n  // If the source DGLArray is not aligned, we should 
create a new aligned one\n  if (alignment != 0 && !IsAligned(nd->data, alignment)) {\n    std::vector<int64_t> shape_vec(nd->shape, nd->shape + nd->ndim);\n    NDArray copy_ndarray = NDArray::Empty(shape_vec, nd->dtype, nd->ctx);\n    copy_ndarray.CopyFrom(nd);\n    *out = DLPackConvert::ToDLPack(copy_ndarray);\n  } else {\n    *out = ContainerToDLPack(nd_container);\n  }\n  API_END();\n}\n"
  },
  {
    "path": "src/runtime/dso_module.cc",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file dso_module.cc\n * @brief Module to load from dynamic shared library.\n */\n#include <dgl/runtime/module.h>\n#include <dgl/runtime/packed_func.h>\n#include <dgl/runtime/registry.h>\n\n#include \"module_util.h\"\n\n#if defined(_WIN32)\n#include <windows.h>\n#else\n#include <dlfcn.h>\n#endif\n\nnamespace dgl {\nnamespace runtime {\n\n// Module to load from dynamic shared library.\n// This is the default module DGL used for host-side AOT\nclass DSOModuleNode final : public ModuleNode {\n public:\n  ~DSOModuleNode() {\n    if (lib_handle_) Unload();\n  }\n\n  const char* type_key() const final { return \"dso\"; }\n\n  PackedFunc GetFunction(\n      const std::string& name,\n      const std::shared_ptr<ModuleNode>& sptr_to_self) final {\n    BackendPackedCFunc faddr;\n    if (name == runtime::symbol::dgl_module_main) {\n      const char* entry_name = reinterpret_cast<const char*>(\n          GetSymbol(runtime::symbol::dgl_module_main));\n      CHECK(entry_name != nullptr)\n          << \"Symbol \" << runtime::symbol::dgl_module_main\n          << \" is not presented\";\n      faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));\n    } else {\n      faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));\n    }\n    if (faddr == nullptr) return PackedFunc();\n    return WrapPackedFunc(faddr, sptr_to_self);\n  }\n\n  void Init(const std::string& name) {\n    Load(name);\n    if (auto* ctx_addr = reinterpret_cast<void**>(\n            GetSymbol(runtime::symbol::dgl_module_ctx))) {\n      *ctx_addr = this;\n    }\n    InitContextFunctions(\n        [this](const char* fname) { return GetSymbol(fname); });\n    // Load the imported modules\n    const char* dev_mblob = reinterpret_cast<const char*>(\n        GetSymbol(runtime::symbol::dgl_dev_mblob));\n    if (dev_mblob != nullptr) {\n      ImportModuleBlob(dev_mblob, &imports_);\n    }\n  }\n\n private:\n  // Platform 
dependent handling.\n#if defined(_WIN32)\n  // library handle\n  HMODULE lib_handle_{nullptr};\n  // Load the library\n  void Load(const std::string& name) {\n    // use wstring version that is needed by LLVM.\n    std::wstring wname(name.begin(), name.end());\n    lib_handle_ = LoadLibraryW(wname.c_str());\n    CHECK(lib_handle_ != nullptr)\n        << \"Failed to load dynamic shared library \" << name;\n  }\n  void* GetSymbol(const char* name) {\n    return reinterpret_cast<void*>(\n        GetProcAddress(lib_handle_, (LPCSTR)name));  // NOLINT(*)\n  }\n  void Unload() { FreeLibrary(lib_handle_); }\n#else\n  // Library handle\n  void* lib_handle_{nullptr};\n  // load the library\n  void Load(const std::string& name) {\n    lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);\n    CHECK(lib_handle_ != nullptr)\n        << \"Failed to load dynamic shared library \" << name << \" \" << dlerror();\n  }\n  void* GetSymbol(const char* name) { return dlsym(lib_handle_, name); }\n  void Unload() { dlclose(lib_handle_); }\n#endif\n};\n\nDGL_REGISTER_GLOBAL(\"module.loadfile_so\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();\n      n->Init(args[0]);\n      *rv = runtime::Module(n);\n    });\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/file_util.cc",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file file_util.cc\n */\n#include \"file_util.h\"\n\n#include <dgl/runtime/serializer.h>\n#include <dmlc/json.h>\n#include <dmlc/logging.h>\n\n#include <fstream>\n#include <unordered_map>\n#include <vector>\n\nnamespace dgl {\nnamespace runtime {\n\nvoid FunctionInfo::Save(dmlc::JSONWriter* writer) const {\n  std::vector<std::string> sarg_types(arg_types.size());\n  for (size_t i = 0; i < arg_types.size(); ++i) {\n    sarg_types[i] = DGLDataType2String(arg_types[i]);\n  }\n  writer->BeginObject();\n  writer->WriteObjectKeyValue(\"name\", name);\n  writer->WriteObjectKeyValue(\"arg_types\", sarg_types);\n  writer->WriteObjectKeyValue(\"thread_axis_tags\", thread_axis_tags);\n  writer->EndObject();\n}\n\nvoid FunctionInfo::Load(dmlc::JSONReader* reader) {\n  dmlc::JSONObjectReadHelper helper;\n  std::vector<std::string> sarg_types;\n  helper.DeclareField(\"name\", &name);\n  helper.DeclareField(\"arg_types\", &sarg_types);\n  helper.DeclareField(\"thread_axis_tags\", &thread_axis_tags);\n  helper.ReadAllFields(reader);\n  arg_types.resize(sarg_types.size());\n  for (size_t i = 0; i < arg_types.size(); ++i) {\n    arg_types[i] = String2DGLDataType(sarg_types[i]);\n  }\n}\n\nvoid FunctionInfo::Save(dmlc::Stream* writer) const {\n  writer->Write(name);\n  writer->Write(arg_types);\n  writer->Write(thread_axis_tags);\n}\n\nbool FunctionInfo::Load(dmlc::Stream* reader) {\n  if (!reader->Read(&name)) return false;\n  if (!reader->Read(&arg_types)) return false;\n  if (!reader->Read(&thread_axis_tags)) return false;\n  return true;\n}\n\nstd::string GetFileFormat(\n    const std::string& file_name, const std::string& format) {\n  std::string fmt = format;\n  if (fmt.length() == 0) {\n    if (file_name.find(\".signed.so\") != std::string::npos) return \"sgx\";\n    size_t pos = file_name.find_last_of(\".\");\n    if (pos != std::string::npos) {\n      return file_name.substr(pos + 1, file_name.length() - pos - 
1);\n    } else {\n      return \"\";\n    }\n  } else {\n    return format;\n  }\n}\n\nstd::string GetCacheDir() {\n  char* env_cache_dir;\n  if ((env_cache_dir = getenv(\"DGL_CACHE_DIR\"))) return env_cache_dir;\n  if ((env_cache_dir = getenv(\"XDG_CACHE_HOME\"))) {\n    return std::string(env_cache_dir) + \"/dgl\";\n  }\n  if ((env_cache_dir = getenv(\"HOME\"))) {\n    return std::string(env_cache_dir) + \"/.cache/dgl\";\n  }\n  return \".\";\n}\n\nstd::string GetFileBasename(const std::string& file_name) {\n  size_t last_slash = file_name.find_last_of(\"/\");\n  if (last_slash == std::string::npos) return file_name;\n  return file_name.substr(last_slash + 1);\n}\n\nstd::string GetMetaFilePath(const std::string& file_name) {\n  size_t pos = file_name.find_last_of(\".\");\n  if (pos != std::string::npos) {\n    return file_name.substr(0, pos) + \".dgl_meta.json\";\n  } else {\n    return file_name + \".dgl_meta.json\";\n  }\n}\n\nvoid LoadBinaryFromFile(const std::string& file_name, std::string* data) {\n  std::ifstream fs(file_name, std::ios::in | std::ios::binary);\n  CHECK(!fs.fail()) << \"Cannot open \" << file_name;\n  // get its size:\n  fs.seekg(0, std::ios::end);\n  size_t size = static_cast<size_t>(fs.tellg());\n  fs.seekg(0, std::ios::beg);\n  data->resize(size);\n  fs.read(&(*data)[0], size);\n}\n\nvoid SaveBinaryToFile(const std::string& file_name, const std::string& data) {\n  std::ofstream fs(file_name, std::ios::out | std::ios::binary);\n  CHECK(!fs.fail()) << \"Cannot open \" << file_name;\n  fs.write(&data[0], data.length());\n}\n\nvoid SaveMetaDataToFile(\n    const std::string& file_name,\n    const std::unordered_map<std::string, FunctionInfo>& fmap) {\n  std::string version = \"0.1.0\";\n  std::ofstream fs(file_name.c_str());\n  CHECK(!fs.fail()) << \"Cannot open file \" << file_name;\n  dmlc::JSONWriter writer(&fs);\n  writer.BeginObject();\n  writer.WriteObjectKeyValue(\"dgl_version\", version);\n  writer.WriteObjectKeyValue(\"func_info\", 
fmap);\n  writer.EndObject();\n  fs.close();\n}\n\nvoid LoadMetaDataFromFile(\n    const std::string& file_name,\n    std::unordered_map<std::string, FunctionInfo>* fmap) {\n  std::ifstream fs(file_name.c_str());\n  CHECK(!fs.fail()) << \"Cannot open file \" << file_name;\n  std::string version;\n  dmlc::JSONReader reader(&fs);\n  dmlc::JSONObjectReadHelper helper;\n  helper.DeclareField(\"dgl_version\", &version);\n  helper.DeclareField(\"func_info\", fmap);\n  helper.ReadAllFields(&reader);\n  fs.close();\n}\n\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/file_util.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file file_util.h\n * @brief Minimum file manipulation util for runtime.\n */\n#ifndef DGL_RUNTIME_FILE_UTIL_H_\n#define DGL_RUNTIME_FILE_UTIL_H_\n\n#include <string>\n#include <unordered_map>\n\n#include \"meta_data.h\"\n\nnamespace dgl {\nnamespace runtime {\n/**\n * @brief Get file format from given file name or format argument.\n * @param file_name The name of the file.\n * @param format The format of the file.\n */\nstd::string GetFileFormat(\n    const std::string& file_name, const std::string& format);\n\n/**\n * @return the directory in which DGL stores cached files.\n *         May be set using DGL_CACHE_DIR; defaults to system locations.\n */\nstd::string GetCacheDir();\n\n/**\n * @brief Get meta file path given file name.\n * @param file_name The name of the file.\n */\nstd::string GetMetaFilePath(const std::string& file_name);\n\n/**\n * @brief Get file basename (i.e. without leading directories)\n * @param file_name The name of the file.\n * @return the base name\n */\nstd::string GetFileBasename(const std::string& file_name);\n\n/**\n * @brief Load binary file into an in-memory buffer.\n * @param file_name The name of the file.\n * @param data The data to be loaded.\n */\nvoid LoadBinaryFromFile(const std::string& file_name, std::string* data);\n\n/**\n * @brief Save an in-memory binary buffer to a file.\n * @param file_name The name of the file.\n * @param data The binary data to be saved.\n */\nvoid SaveBinaryToFile(const std::string& file_name, const std::string& data);\n\n/**\n * @brief Save meta data to file.\n * @param file_name The name of the file.\n * @param fmap The function info map.\n */\nvoid SaveMetaDataToFile(\n    const std::string& file_name,\n    const std::unordered_map<std::string, FunctionInfo>& fmap);\n\n/**\n * @brief Load meta data from file.\n * @param file_name The name of the file.\n * @param fmap The function info map.\n */\nvoid LoadMetaDataFromFile(\n  
  const std::string& file_name,\n    std::unordered_map<std::string, FunctionInfo>* fmap);\n}  // namespace runtime\n}  // namespace dgl\n#endif  // DGL_RUNTIME_FILE_UTIL_H_\n"
  },
  {
    "path": "src/runtime/meta_data.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file meta_data.h\n * @brief Meta data related utilities\n */\n#ifndef DGL_RUNTIME_META_DATA_H_\n#define DGL_RUNTIME_META_DATA_H_\n\n#include <dgl/runtime/packed_func.h>\n#include <dmlc/io.h>\n#include <dmlc/json.h>\n\n#include <string>\n#include <vector>\n\n#include \"runtime_base.h\"\n\nnamespace dgl {\nnamespace runtime {\n\n/** @brief function information needed by device */\nstruct FunctionInfo {\n  std::string name;\n  std::vector<DGLDataType> arg_types;\n  std::vector<std::string> thread_axis_tags;\n\n  void Save(dmlc::JSONWriter *writer) const;\n  void Load(dmlc::JSONReader *reader);\n  void Save(dmlc::Stream *writer) const;\n  bool Load(dmlc::Stream *reader);\n};\n}  // namespace runtime\n}  // namespace dgl\n\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, ::dgl::runtime::FunctionInfo, true);\n}  // namespace dmlc\n#endif  // DGL_RUNTIME_META_DATA_H_\n"
  },
  {
    "path": "src/runtime/module.cc",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file module.cc\n * @brief DGL module system\n */\n#include <dgl/runtime/module.h>\n#include <dgl/runtime/packed_func.h>\n#include <dgl/runtime/registry.h>\n\n#include <cstring>\n#include <unordered_set>\n#ifndef _LIBCPP_SGX_CONFIG\n#include \"file_util.h\"\n#endif\n\nnamespace dgl {\nnamespace runtime {\n\nvoid Module::Import(Module other) {\n  // specially handle rpc\n  if (!std::strcmp((*this)->type_key(), \"rpc\")) {\n    static const PackedFunc* fimport_ = nullptr;\n    if (fimport_ == nullptr) {\n      fimport_ = runtime::Registry::Get(\"rpc._ImportRemoteModule\");\n      CHECK(fimport_ != nullptr);\n    }\n    (*fimport_)(*this, other);\n    return;\n  }\n  // cyclic detection.\n  std::unordered_set<const ModuleNode*> visited{other.node_.get()};\n  std::vector<const ModuleNode*> stack{other.node_.get()};\n  while (!stack.empty()) {\n    const ModuleNode* n = stack.back();\n    stack.pop_back();\n    for (const Module& m : n->imports_) {\n      const ModuleNode* next = m.node_.get();\n      if (visited.count(next)) continue;\n      visited.insert(next);\n      stack.push_back(next);\n    }\n  }\n  CHECK(!visited.count(node_.get()))\n      << \"Cyclic dependency detected during import\";\n  node_->imports_.emplace_back(std::move(other));\n}\n\nModule Module::LoadFromFile(\n    const std::string& file_name, const std::string& format) {\n#ifndef _LIBCPP_SGX_CONFIG\n  std::string fmt = GetFileFormat(file_name, format);\n  CHECK(fmt.length() != 0) << \"Cannot deduce format of file \" << file_name;\n  if (fmt == \"dll\" || fmt == \"dylib\" || fmt == \"dso\") {\n    fmt = \"so\";\n  }\n  std::string load_f_name = \"module.loadfile_\" + fmt;\n  const PackedFunc* f = Registry::Get(load_f_name);\n  CHECK(f != nullptr) << \"Loader of \" << format << \"(\" << load_f_name\n                      << \") is not presented.\";\n  Module m = (*f)(file_name, format);\n  return m;\n#else\n  LOG(FATAL) << \"SGX does 
not support LoadFromFile\";\n#endif\n}\n\nvoid ModuleNode::SaveToFile(\n    const std::string& file_name, const std::string& format) {\n  LOG(FATAL) << \"Module[\" << type_key() << \"] does not support SaveToFile\";\n}\n\nvoid ModuleNode::SaveToBinary(dmlc::Stream* stream) {\n  LOG(FATAL) << \"Module[\" << type_key() << \"] does not support SaveToBinary\";\n}\n\nstd::string ModuleNode::GetSource(const std::string& format) {\n  LOG(FATAL) << \"Module[\" << type_key() << \"] does not support GetSource\";\n  return \"\";\n}\n\nconst PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {\n  auto it = import_cache_.find(name);\n  if (it != import_cache_.end()) return it->second.get();\n  PackedFunc pf;\n  for (Module& m : this->imports_) {\n    pf = m.GetFunction(name, false);\n    if (pf != nullptr) break;\n  }\n  if (pf == nullptr) {\n    const PackedFunc* f = Registry::Get(name);\n    CHECK(f != nullptr) << \"Cannot find function \" << name\n                        << \" in the imported modules or global registry\";\n    return f;\n  } else {\n    std::unique_ptr<PackedFunc> f(new PackedFunc(pf));\n    import_cache_[name] = std::move(f);\n    return import_cache_.at(name).get();\n  }\n}\n\nbool RuntimeEnabled(const std::string& target) {\n  std::string f_name;\n  if (target == \"cpu\") {\n    return true;\n  } else if (target == \"cuda\" || target == \"gpu\") {\n    f_name = \"device_api.cuda\";\n  } else if (target == \"cl\" || target == \"opencl\" || target == \"sdaccel\") {\n    f_name = \"device_api.opencl\";\n  } else if (target == \"gl\" || target == \"opengl\") {\n    f_name = \"device_api.opengl\";\n  } else if (target == \"mtl\" || target == \"metal\") {\n    f_name = \"device_api.metal\";\n  } else if (target == \"vulkan\") {\n    f_name = \"device_api.vulkan\";\n  } else if (target == \"stackvm\") {\n    f_name = \"codegen.build_stackvm\";\n  } else if (target == \"rpc\") {\n    f_name = \"device_api.rpc\";\n  } else if (target == \"vpi\" || 
target == \"verilog\") {\n    f_name = \"device_api.vpi\";\n  } else if (target.length() >= 5 && target.substr(0, 5) == \"nvptx\") {\n    f_name = \"device_api.cuda\";\n  } else if (target.length() >= 4 && target.substr(0, 4) == \"rocm\") {\n    f_name = \"device_api.rocm\";\n  } else if (target.length() >= 4 && target.substr(0, 4) == \"llvm\") {\n    const PackedFunc* pf =\n        runtime::Registry::Get(\"codegen.llvm_target_enabled\");\n    if (pf == nullptr) return false;\n    return (*pf)(target);\n  } else {\n    LOG(FATAL) << \"Unknown optional runtime \" << target;\n  }\n  return runtime::Registry::Get(f_name) != nullptr;\n}\n\nDGL_REGISTER_GLOBAL(\"module._Enabled\")\n    .set_body([](DGLArgs args, DGLRetValue* ret) {\n      *ret = RuntimeEnabled(args[0]);\n    });\n\nDGL_REGISTER_GLOBAL(\"module._GetSource\")\n    .set_body([](DGLArgs args, DGLRetValue* ret) {\n      *ret = args[0].operator Module()->GetSource(args[1]);\n    });\n\nDGL_REGISTER_GLOBAL(\"module._ImportsSize\")\n    .set_body([](DGLArgs args, DGLRetValue* ret) {\n      *ret = static_cast<int64_t>(args[0].operator Module()->imports().size());\n    });\n\nDGL_REGISTER_GLOBAL(\"module._GetImport\")\n    .set_body([](DGLArgs args, DGLRetValue* ret) {\n      *ret = args[0].operator Module()->imports().at(args[1].operator int());\n    });\n\nDGL_REGISTER_GLOBAL(\"module._GetTypeKey\")\n    .set_body([](DGLArgs args, DGLRetValue* ret) {\n      *ret = std::string(args[0].operator Module()->type_key());\n    });\n\nDGL_REGISTER_GLOBAL(\"module._LoadFromFile\")\n    .set_body([](DGLArgs args, DGLRetValue* ret) {\n      *ret = Module::LoadFromFile(args[0], args[1]);\n    });\n\nDGL_REGISTER_GLOBAL(\"module._SaveToFile\")\n    .set_body([](DGLArgs args, DGLRetValue* ret) {\n      args[0].operator Module()->SaveToFile(args[1], args[2]);\n    });\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/module_util.cc",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file module_util.cc\n * @brief Utilities for module.\n */\n#ifndef _LIBCPP_SGX_CONFIG\n#include <dmlc/memory_io.h>\n#endif\n#include <dgl/runtime/module.h>\n#include <dgl/runtime/registry.h>\n\n#include <memory>\n#include <string>\n\n#include \"module_util.h\"\n\nnamespace dgl {\nnamespace runtime {\n\nvoid ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {\n#ifndef _LIBCPP_SGX_CONFIG\n  CHECK(mblob != nullptr);\n  uint64_t nbytes = 0;\n  for (size_t i = 0; i < sizeof(nbytes); ++i) {\n    uint64_t c = mblob[i];\n    nbytes |= (c & 0xffUL) << (i * 8);\n  }\n  dmlc::MemoryFixedSizeStream fs(\n      const_cast<char*>(mblob + sizeof(nbytes)), static_cast<size_t>(nbytes));\n  dmlc::Stream* stream = &fs;\n  uint64_t size;\n  CHECK(stream->Read(&size));\n  for (uint64_t i = 0; i < size; ++i) {\n    std::string tkey;\n    CHECK(stream->Read(&tkey));\n    std::string fkey = \"module.loadbinary_\" + tkey;\n    const PackedFunc* f = Registry::Get(fkey);\n    CHECK(f != nullptr) << \"Loader of \" << tkey << \"(\" << fkey\n                        << \") is not presented.\";\n    Module m = (*f)(static_cast<void*>(stream));\n    mlist->push_back(m);\n  }\n#else\n  LOG(FATAL) << \"SGX does not support ImportModuleBlob\";\n#endif\n}\n\nPackedFunc WrapPackedFunc(\n    BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& sptr_to_self) {\n  return PackedFunc([faddr, sptr_to_self](DGLArgs args, DGLRetValue* rv) {\n    int ret = (*faddr)(\n        const_cast<DGLValue*>(args.values), const_cast<int*>(args.type_codes),\n        args.num_args);\n    CHECK_EQ(ret, 0) << DGLGetLastError();\n  });\n}\n\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/module_util.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file module_util.h\n * @brief Helper utilities for module building\n */\n#ifndef DGL_RUNTIME_MODULE_UTIL_H_\n#define DGL_RUNTIME_MODULE_UTIL_H_\n\n#include <dgl/runtime/c_backend_api.h>\n#include <dgl/runtime/c_runtime_api.h>\n#include <dgl/runtime/module.h>\n\n#include <memory>\n#include <vector>\n\nextern \"C\" {\n// Function signature for generated packed function in shared library\ntypedef int (*BackendPackedCFunc)(void* args, int* type_codes, int num_args);\n}  // extern \"C\"\n\nnamespace dgl {\nnamespace runtime {\n/**\n * @brief Wrap a BackendPackedCFunc to packed function.\n * @param faddr The function address\n * @param mptr The module pointer node.\n */\nPackedFunc WrapPackedFunc(\n    BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr);\n/**\n * @brief Load and append module blob to module list\n * @param mblob The module blob.\n * @param module_list The module list to append to\n */\nvoid ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);\n\n/**\n * @brief Utility to initialize conext function symbols during startup\n * @param flookup A symbol lookup function.\n * @tparam FLookup a function of signature string->void*\n */\ntemplate <typename FLookup>\nvoid InitContextFunctions(FLookup flookup) {\n#define DGL_INIT_CONTEXT_FUNC(FuncName)                                      \\\n  if (auto* fp =                                                             \\\n          reinterpret_cast<decltype(&FuncName)*>(flookup(\"__\" #FuncName))) { \\\n    *fp = FuncName;                                                          \\\n  }\n  // Initialize the functions\n  DGL_INIT_CONTEXT_FUNC(DGLFuncCall);\n  DGL_INIT_CONTEXT_FUNC(DGLAPISetLastError);\n  DGL_INIT_CONTEXT_FUNC(DGLBackendGetFuncFromEnv);\n  DGL_INIT_CONTEXT_FUNC(DGLBackendAllocWorkspace);\n  DGL_INIT_CONTEXT_FUNC(DGLBackendFreeWorkspace);\n  DGL_INIT_CONTEXT_FUNC(DGLBackendParallelLaunch);\n  
DGL_INIT_CONTEXT_FUNC(DGLBackendParallelBarrier);\n\n#undef DGL_INIT_CONTEXT_FUNC\n}\n}  // namespace runtime\n}  // namespace dgl\n#endif  // DGL_RUNTIME_MODULE_UTIL_H_\n"
  },
  {
    "path": "src/runtime/ndarray.cc",
    "content": "/**\n *  Copyright (c) 2017-2022 by Contributors\n * @file ndarray.cc\n * @brief NDArray container infratructure.\n */\n#include <dgl/runtime/c_runtime_api.h>\n#include <dgl/runtime/device_api.h>\n#include <dgl/runtime/ndarray.h>\n#include <dgl/runtime/shared_mem.h>\n#include <dgl/runtime/tensordispatch.h>\n#include <dgl/zerocopy_serializer.h>\n#include <dmlc/logging.h>\n#include <string.h>\n\n#include \"runtime_base.h\"\n\nnamespace dgl {\n\nconstexpr DGLDataType DGLDataTypeTraits<int8_t>::dtype;\nconstexpr DGLDataType DGLDataTypeTraits<uint8_t>::dtype;\nconstexpr DGLDataType DGLDataTypeTraits<int16_t>::dtype;\nconstexpr DGLDataType DGLDataTypeTraits<int32_t>::dtype;\nconstexpr DGLDataType DGLDataTypeTraits<int64_t>::dtype;\nconstexpr DGLDataType DGLDataTypeTraits<uint32_t>::dtype;\nconstexpr DGLDataType DGLDataTypeTraits<uint64_t>::dtype;\n#ifdef DGL_USE_CUDA\nconstexpr DGLDataType DGLDataTypeTraits<__half>::dtype;\n#if BF16_ENABLED\nconstexpr DGLDataType DGLDataTypeTraits<__nv_bfloat16>::dtype;\n#endif  // BF16_ENABLED\n#endif  // DGL_USE_CUDA\nconstexpr DGLDataType DGLDataTypeTraits<float>::dtype;\nconstexpr DGLDataType DGLDataTypeTraits<double>::dtype;\n\nnamespace runtime {\n\ninline void VerifyDataType(DGLDataType dtype) {\n  CHECK_GE(dtype.lanes, 1);\n  if (dtype.code == kDGLFloat) {\n    CHECK_EQ(dtype.bits % 8, 0);\n  } else {\n    CHECK_EQ(dtype.bits % 8, 0);\n  }\n  CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);\n}\n\ninline size_t GetDataSize(const DGLArray& arr) {\n  size_t size = 1;\n  for (dgl_index_t i = 0; i < arr.ndim; ++i) {\n    size *= arr.shape[i];\n  }\n  size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8;\n  return size;\n}\n\ninline size_t GetDataAlignment(const DGLArray& arr) {\n  size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;\n  if (align < kAllocAlignment) return kAllocAlignment;\n  return align;\n}\n\nvoid NDArray::Internal::DefaultDeleter(NDArray::Container* ptr) {\n  using dgl::runtime::NDArray;\n  if 
(ptr->manager_ctx != nullptr) {\n    static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef();\n  } else if (ptr->mem) {\n    ptr->mem = nullptr;\n  } else if (ptr->dl_tensor.data != nullptr) {\n    // if the array is still pinned before freeing, unpin it.\n    if (ptr->pinned_by_dgl_) UnpinContainer(ptr);\n    if (ptr->pinned_by_pytorch_) {\n      DeviceAPI::Get(kDGLCUDA)->FreePinnedDataSpace(\n          &(ptr->pytorch_raw_deleter_));\n      CHECK(ptr->pytorch_raw_deleter_ == nullptr);\n      ptr->pinned_by_pytorch_ = false;\n      ptr->pytorch_ctx_ = nullptr;\n    } else {\n      dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)\n          ->FreeDataSpace(ptr->dl_tensor.ctx, ptr->dl_tensor.data);\n    }\n  }\n  delete ptr;\n}\n\nNDArray NDArray::Internal::Create(\n    std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx) {\n  VerifyDataType(dtype);\n  // critical zone\n  NDArray::Container* data = new NDArray::Container();\n  data->deleter = DefaultDeleter;\n  NDArray ret(data);\n  ret.data_ = data;\n  // RAII now in effect\n  // setup shape\n  data->shape_ = std::move(shape);\n  data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);\n  data->dl_tensor.ndim = static_cast<int>(data->shape_.size());\n  // setup stride (this should be optional, but some framework\n  //   does not support NULL stride and thus will crash the program).\n  data->stride_.resize(data->dl_tensor.ndim, 1);\n  for (int i = data->dl_tensor.ndim - 2; i >= 0; --i) {\n    data->stride_[i] = data->shape_[i + 1] * data->stride_[i + 1];\n  }\n  data->dl_tensor.strides = dmlc::BeginPtr(data->stride_);\n  // setup dtype\n  data->dl_tensor.dtype = dtype;\n  // setup ctx\n  data->dl_tensor.ctx = ctx;\n  return ret;\n}\n\nDGLArray* NDArray::Internal::MoveAsDGLArray(NDArray arr) {\n  DGLArray* tensor = reinterpret_cast<DGLArray*>(arr.data_);\n  CHECK(tensor == const_cast<DGLArray*>(arr.operator->()));\n  arr.data_ = nullptr;\n  return tensor;\n}\n\nsize_t NDArray::GetSize() const { return 
GetDataSize(data_->dl_tensor); }\n\nint64_t NDArray::NumElements() const {\n  if (data_->dl_tensor.ndim == 0) return 0;\n  int64_t size = 1;\n  for (int i = 0; i < data_->dl_tensor.ndim; ++i) {\n    size *= data_->dl_tensor.shape[i];\n  }\n  return size;\n}\n\nbool NDArray::IsContiguous() const {\n  CHECK(data_ != nullptr);\n  if (data_->dl_tensor.strides == nullptr) return true;\n\n  // See https://github.com/dmlc/dgl/issues/2118 and PyTorch's\n  // compute_contiguous() implementation\n  int64_t z = 1;\n  for (int64_t i = data_->dl_tensor.ndim - 1; i >= 0; --i) {\n    if (data_->dl_tensor.shape[i] != 1) {\n      if (data_->dl_tensor.strides[i] == z)\n        z *= data_->dl_tensor.shape[i];\n      else\n        return false;\n    }\n  }\n  return true;\n}\n\nNDArray NDArray::CreateView(\n    std::vector<int64_t> shape, DGLDataType dtype, int64_t offset) {\n  CHECK(data_ != nullptr);\n  CHECK(IsContiguous()) << \"Can only create view for compact tensor\";\n  NDArray ret = Internal::Create(shape, dtype, data_->dl_tensor.ctx);\n  ret.data_->dl_tensor.byte_offset = this->data_->dl_tensor.byte_offset;\n  size_t curr_size = GetDataSize(this->data_->dl_tensor);\n  size_t view_size = GetDataSize(ret.data_->dl_tensor);\n  CHECK_LE(view_size, curr_size)\n      << \"Tries to create a view that has bigger memory than current one\";\n  // increase ref count\n  this->data_->IncRef();\n  ret.data_->manager_ctx = this->data_;\n  ret.data_->dl_tensor.data =\n      static_cast<char*>(this->data_->dl_tensor.data) + offset;\n  return ret;\n}\n\nNDArray NDArray::EmptyShared(\n    const std::string& name, std::vector<int64_t> shape, DGLDataType dtype,\n    DGLContext ctx, bool is_create) {\n  NDArray ret = Internal::Create(shape, dtype, ctx);\n  size_t size = GetDataSize(ret.data_->dl_tensor);\n  auto mem = std::make_shared<SharedMemory>(name);\n  if (is_create) {\n    ret.data_->dl_tensor.data = mem->CreateNew(size);\n  } else {\n    ret.data_->dl_tensor.data = mem->Open(size);\n  
}\n\n  ret.data_->mem = mem;\n  return ret;\n}\n\nNDArray NDArray::Empty(\n    std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx) {\n  NDArray ret = Internal::Create(shape, dtype, ctx);\n  size_t size = GetDataSize(ret.data_->dl_tensor);\n  size_t alignment = GetDataAlignment(ret.data_->dl_tensor);\n  if (size > 0)\n    ret.data_->dl_tensor.data = DeviceAPI::Get(ret->ctx)->AllocDataSpace(\n        ret->ctx, size, alignment, ret->dtype);\n  return ret;\n}\n\nvoid NDArray::CopyFromTo(DGLArray* from, DGLArray* to) {\n  size_t from_size = GetDataSize(*from);\n  size_t to_size = GetDataSize(*to);\n  CHECK_EQ(from_size, to_size)\n      << \"DGLArrayCopyFromTo: The size must exactly match\";\n\n  CHECK(\n      from->ctx.device_type == to->ctx.device_type ||\n      from->ctx.device_type == kDGLCPU || to->ctx.device_type == kDGLCPU)\n      << \"Can not copy across different ctx types directly\";\n\n  // Use the context that is *not* a cpu context to get the correct device\n  // api manager.\n  DGLContext ctx = from->ctx.device_type != kDGLCPU ? 
from->ctx : to->ctx;\n\n  // default: local current cuda stream\n  DeviceAPI::Get(ctx)->CopyDataFromTo(\n      from->data, static_cast<size_t>(from->byte_offset), to->data,\n      static_cast<size_t>(to->byte_offset), from_size, from->ctx, to->ctx,\n      from->dtype);\n}\n\nvoid NDArray::RecordedCopyFromTo(\n    DGLArray* from, DGLArray* to, void* pytorch_ctx) {\n  size_t from_size = GetDataSize(*from);\n  size_t to_size = GetDataSize(*to);\n  CHECK_EQ(from_size, to_size)\n      << \"DGLArrayCopyFromTo: The size must exactly match.\";\n\n  CHECK(from->ctx.device_type != to->ctx.device_type)\n      << \"Recoding event is only called for the copy between CPU and GPU.\";\n\n  CHECK(from->ctx.device_type == kDGLCUDA || to->ctx.device_type == kDGLCUDA)\n      << \"At least one CUDA ctx needs to be involved.\";\n\n  DeviceAPI::Get(kDGLCUDA)->RecordedCopyDataFromTo(\n      from->data, static_cast<size_t>(from->byte_offset), to->data,\n      static_cast<size_t>(to->byte_offset), from_size, from->ctx, to->ctx,\n      from->dtype, pytorch_ctx);\n}\n\nNDArray NDArray::PinnedEmpty(\n    std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx) {\n  CHECK_EQ(ctx.device_type, kDGLCPU) << \"Only NDArray on CPU can be pinned\";\n  NDArray ret = Internal::Create(shape, dtype, ctx);\n  size_t size = GetDataSize(ret.data_->dl_tensor);\n  if (size > 0) {\n    ret.data_->dl_tensor.data = DeviceAPI::Get(kDGLCUDA)->AllocPinnedDataSpace(\n        size, &(ret.data_->pytorch_ctx_), &(ret.data_->pytorch_raw_deleter_));\n    CHECK(\n        ret.data_->pytorch_ctx_ != nullptr &&\n        ret.data_->pytorch_raw_deleter_ != nullptr)\n        << \"The allocation failed in PyTorch's CachingHostAllocator. 
\"\n        << \"The returned context pointer is \" << ret.data_->pytorch_ctx_\n        << \" and the function deleter is \" << ret.data_->pytorch_raw_deleter_;\n    ret.data_->pinned_by_pytorch_ = true;\n  }\n  return ret;\n}\n\nvoid NDArray::PinContainer(NDArray::Container* ptr) {\n  if (IsContainerPinned(ptr)) return;\n  auto* tensor = &(ptr->dl_tensor);\n  CHECK_EQ(tensor->ctx.device_type, kDGLCPU)\n      << \"Only NDArray on CPU can be pinned\";\n  ptr->pinned_by_dgl_ =\n      DeviceAPI::Get(kDGLCUDA)->PinData(tensor->data, GetDataSize(*tensor));\n}\n\nvoid NDArray::UnpinContainer(NDArray::Container* ptr) {\n  auto container_is_pinned = IsContainerPinned(ptr);\n  // The tensor may be pinned outside of DGL via a different CUDA API,\n  // so we cannot unpin it with cudaHostUnregister.\n  CHECK(ptr->pinned_by_dgl_ || !container_is_pinned)\n      << \"Cannot unpin a tensor that is pinned outside of DGL.\";\n  // 1. not pinned, do nothing\n  if (!container_is_pinned) return;\n  // 2. pinned by DGL, unpin it\n  DeviceAPI::Get(kDGLCUDA)->UnpinData(ptr->dl_tensor.data);\n  ptr->pinned_by_dgl_ = false;\n}\n\nvoid NDArray::RecordStream(DGLArray* tensor, DGLStreamHandle stream) {\n  TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();\n  CHECK(tensor_dispatcher->IsAvailable())\n      << \"RecordStream only works when TensorAdapter is available.\";\n  CHECK_EQ(tensor->ctx.device_type, kDGLCUDA)\n      << \"RecordStream only works with GPU tensors.\";\n\n  tensor_dispatcher->RecordStream(tensor->data, stream, tensor->ctx.device_id);\n}\n\ntemplate <typename T>\nNDArray NDArray::FromVector(const std::vector<T>& vec, DGLContext ctx) {\n  const DGLDataType dtype = DGLDataTypeTraits<T>::dtype;\n  int64_t size = static_cast<int64_t>(vec.size());\n  NDArray ret = NDArray::Empty({size}, dtype, ctx);\n  DeviceAPI::Get(ctx)->CopyDataFromTo(\n      vec.data(), 0, static_cast<T*>(ret->data), 0, size * sizeof(T),\n      DGLContext{kDGLCPU, 0}, ctx, dtype);\n  return 
ret;\n}\n\nNDArray NDArray::CreateFromRaw(\n    const std::vector<int64_t>& shape, DGLDataType dtype, DGLContext ctx,\n    void* raw, bool auto_free) {\n  NDArray ret = Internal::Create(shape, dtype, ctx);\n  ret.data_->dl_tensor.data = raw;\n  if (!auto_free) ret.data_->deleter = nullptr;\n  return ret;\n}\n\n// export specializations\ntemplate NDArray NDArray::FromVector<int32_t>(\n    const std::vector<int32_t>&, DGLContext);\ntemplate NDArray NDArray::FromVector<int64_t>(\n    const std::vector<int64_t>&, DGLContext);\ntemplate NDArray NDArray::FromVector<uint32_t>(\n    const std::vector<uint32_t>&, DGLContext);\ntemplate NDArray NDArray::FromVector<uint64_t>(\n    const std::vector<uint64_t>&, DGLContext);\ntemplate NDArray NDArray::FromVector<float>(\n    const std::vector<float>&, DGLContext);\ntemplate NDArray NDArray::FromVector<double>(\n    const std::vector<double>&, DGLContext);\n\ntemplate <typename T>\nstd::vector<T> NDArray::ToVector() const {\n  const DGLDataType dtype = DGLDataTypeTraits<T>::dtype;\n  CHECK(data_->dl_tensor.ndim == 1)\n      << \"ToVector() only supported for 1D arrays\";\n  CHECK(data_->dl_tensor.dtype == dtype) << \"dtype mismatch\";\n\n  int64_t size = data_->dl_tensor.shape[0];\n  std::vector<T> vec(size);\n  const DGLContext& ctx = data_->dl_tensor.ctx;\n  DeviceAPI::Get(ctx)->CopyDataFromTo(\n      static_cast<T*>(data_->dl_tensor.data), 0, vec.data(), 0,\n      size * sizeof(T), ctx, DGLContext{kDGLCPU, 0}, dtype);\n  return vec;\n}\n\ntemplate std::vector<int32_t> NDArray::ToVector<int32_t>() const;\ntemplate std::vector<int64_t> NDArray::ToVector<int64_t>() const;\ntemplate std::vector<uint32_t> NDArray::ToVector<uint32_t>() const;\ntemplate std::vector<uint64_t> NDArray::ToVector<uint64_t>() const;\ntemplate std::vector<float> NDArray::ToVector<float>() const;\ntemplate std::vector<double> NDArray::ToVector<double>() const;\n\nstd::shared_ptr<SharedMemory> NDArray::GetSharedMem() const {\n  return 
this->data_->mem;\n}\n\nbool NDArray::IsContainerPinned(NDArray::Container* ptr) {\n  if (ptr->pinned_by_dgl_ || ptr->pinned_by_pytorch_) return true;\n  auto* tensor = &(ptr->dl_tensor);\n  // Can only be pinned if on CPU...\n  if (tensor->ctx.device_type != kDGLCPU) return false;\n  // ... and CUDA device API is enabled, and the tensor is indeed in pinned\n  // memory.\n  auto device = DeviceAPI::Get(kDGLCUDA, true);\n  return device && device->IsPinned(tensor->data);\n}\n\nvoid NDArray::Save(dmlc::Stream* strm) const {\n  auto zc_strm = dynamic_cast<StreamWithBuffer*>(strm);\n  if (zc_strm) {\n    zc_strm->PushNDArray(*this);\n    return;\n  }\n  SaveDGLArray(strm, const_cast<DGLArray*>(operator->()));\n}\n\nbool NDArray::Load(dmlc::Stream* strm) {\n  auto zc_strm = dynamic_cast<StreamWithBuffer*>(strm);\n  if (zc_strm) {\n    *this = zc_strm->PopNDArray();\n    return true;\n  }\n  uint64_t header, reserved;\n  CHECK(strm->Read(&header)) << \"Invalid DGLArray file format\";\n  CHECK(strm->Read(&reserved)) << \"Invalid DGLArray file format\";\n  CHECK(header == kDGLNDArrayMagic) << \"Invalid DGLArray file format\";\n  DGLContext ctx;\n  int ndim;\n  DGLDataType dtype;\n  CHECK(strm->Read(&ctx)) << \"Invalid DGLArray file format\";\n  CHECK(strm->Read(&ndim)) << \"Invalid DGLArray file format\";\n  CHECK(strm->Read(&dtype)) << \"Invalid DGLArray file format\";\n  CHECK_EQ(ctx.device_type, kDGLCPU)\n      << \"Invalid DGLArray context: can only save as CPU tensor\";\n  std::vector<int64_t> shape(ndim);\n  if (ndim != 0) {\n    CHECK(strm->ReadArray(&shape[0], ndim)) << \"Invalid DGLArray file format\";\n  }\n  NDArray ret = NDArray::Empty(shape, dtype, ctx);\n  int64_t num_elems = 1;\n  int elem_bytes = (ret->dtype.bits + 7) / 8;\n  for (int i = 0; i < ret->ndim; ++i) {\n    num_elems *= ret->shape[i];\n  }\n  int64_t data_byte_size;\n  CHECK(strm->Read(&data_byte_size)) << \"Invalid DGLArray file format\";\n  CHECK(data_byte_size == num_elems * elem_bytes)\n      
<< \"Invalid DGLArray file format\";\n  if (data_byte_size != 0) {\n    // strm->Read will return the total number of elements successfully read.\n    // Therefore if data_byte_size is zero, the CHECK below would fail.\n    CHECK(strm->Read(ret->data, data_byte_size))\n        << \"Invalid DGLArray file format\";\n  }\n  if (!DMLC_IO_NO_ENDIAN_SWAP) {\n    dmlc::ByteSwap(ret->data, elem_bytes, num_elems);\n  }\n  *this = ret;\n  return true;\n}\n\n}  // namespace runtime\n}  // namespace dgl\n\nusing namespace dgl::runtime;\n\nint DGLArrayAlloc(\n    const dgl_index_t* shape, int ndim, int dtype_code, int dtype_bits,\n    int dtype_lanes, int device_type, int device_id, DGLArrayHandle* out) {\n  API_BEGIN();\n  DGLDataType dtype;\n  dtype.code = static_cast<uint8_t>(dtype_code);\n  dtype.bits = static_cast<uint8_t>(dtype_bits);\n  dtype.lanes = static_cast<uint16_t>(dtype_lanes);\n  DGLContext ctx;\n  ctx.device_type = static_cast<DGLDeviceType>(device_type);\n  ctx.device_id = device_id;\n  *out = NDArray::Internal::MoveAsDGLArray(\n      NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, ctx));\n  API_END();\n}\n\nint DGLArrayAllocSharedMem(\n    const char* mem_name, const dgl_index_t* shape, int ndim, int dtype_code,\n    int dtype_bits, int dtype_lanes, bool is_create, DGLArrayHandle* out) {\n  API_BEGIN();\n  DGLDataType dtype;\n  dtype.code = static_cast<uint8_t>(dtype_code);\n  dtype.bits = static_cast<uint8_t>(dtype_bits);\n  dtype.lanes = static_cast<uint16_t>(dtype_lanes);\n  std::vector<int64_t> shape_vec(shape, shape + ndim);\n  NDArray arr = NDArray::EmptyShared(\n      mem_name, shape_vec, dtype, DGLContext{kDGLCPU, 0}, is_create);\n  *out = NDArray::Internal::MoveAsDGLArray(arr);\n  API_END();\n}\n\nint DGLArrayFree(DGLArrayHandle handle) {\n  API_BEGIN();\n  reinterpret_cast<NDArray::Container*>(handle)->DecRef();\n  API_END();\n}\n\nint DGLArrayCopyFromTo(DGLArrayHandle from, DGLArrayHandle to) {\n  API_BEGIN();\n  
NDArray::CopyFromTo(from, to);\n  API_END();\n}\n\nint DGLArrayCopyFromBytes(DGLArrayHandle handle, void* data, size_t nbytes) {\n  API_BEGIN();\n  DGLContext cpu_ctx;\n  cpu_ctx.device_type = kDGLCPU;\n  cpu_ctx.device_id = 0;\n  size_t arr_size = GetDataSize(*handle);\n  CHECK_EQ(arr_size, nbytes) << \"DGLArrayCopyFromBytes: size mismatch\";\n  DeviceAPI::Get(handle->ctx)\n      ->CopyDataFromTo(\n          data, 0, handle->data, static_cast<size_t>(handle->byte_offset),\n          nbytes, cpu_ctx, handle->ctx, handle->dtype);\n  API_END();\n}\n\nint DGLArrayCopyToBytes(DGLArrayHandle handle, void* data, size_t nbytes) {\n  API_BEGIN();\n  DGLContext cpu_ctx;\n  cpu_ctx.device_type = kDGLCPU;\n  cpu_ctx.device_id = 0;\n  size_t arr_size = GetDataSize(*handle);\n  CHECK_EQ(arr_size, nbytes) << \"DGLArrayCopyToBytes: size mismatch\";\n  DeviceAPI::Get(handle->ctx)\n      ->CopyDataFromTo(\n          handle->data, static_cast<size_t>(handle->byte_offset), data, 0,\n          nbytes, handle->ctx, cpu_ctx, handle->dtype);\n  API_END();\n}\n\nint DGLArrayPinData(DGLArrayHandle handle, DGLContext ctx) {\n  API_BEGIN();\n  auto* nd_container = reinterpret_cast<NDArray::Container*>(handle);\n  NDArray::PinContainer(nd_container);\n  API_END();\n}\n\nint DGLArrayUnpinData(DGLArrayHandle handle, DGLContext ctx) {\n  API_BEGIN();\n  auto* nd_container = reinterpret_cast<NDArray::Container*>(handle);\n  NDArray::UnpinContainer(nd_container);\n  API_END();\n}\n\nint DGLArrayRecordStream(DGLArrayHandle handle, DGLStreamHandle stream) {\n  API_BEGIN();\n  NDArray::RecordStream(handle, stream);\n  API_END();\n}\n"
  },
  {
    "path": "src/runtime/object.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file runtime/object.cc\n * @brief Implementation of runtime object APIs.\n */\n#include <dgl/runtime/object.h>\n\n#include <atomic>\n#include <memory>\n#include <mutex>\n#include <unordered_map>\n\nnamespace dgl {\nnamespace runtime {\n\nnamespace {\n// single manager of operator information.\nstruct TypeManager {\n  // mutex to avoid registration from multiple threads.\n  // recursive is needed for trigger(which calls UpdateAttrMap)\n  std::mutex mutex;\n  std::atomic<uint32_t> type_counter{0};\n  std::unordered_map<std::string, uint32_t> key2index;\n  std::vector<std::string> index2key;\n  // get singleton of the\n  static TypeManager* Global() {\n    static TypeManager inst;\n    return &inst;\n  }\n};\n}  // namespace\n\nbool Object::_DerivedFrom(uint32_t tid) const {\n  static uint32_t tindex = TypeKey2Index(Object::_type_key);\n  return tid == tindex;\n}\n\n// this is slow, usually caller always hold the result in a static variable.\nuint32_t Object::TypeKey2Index(const char* key) {\n  TypeManager* t = TypeManager::Global();\n  std::lock_guard<std::mutex> lock(t->mutex);\n  std::string skey = key;\n  auto it = t->key2index.find(skey);\n  if (it != t->key2index.end()) {\n    return it->second;\n  }\n  uint32_t tid = ++(t->type_counter);\n  t->key2index[skey] = tid;\n  t->index2key.push_back(skey);\n  return tid;\n}\n\nconst char* Object::TypeIndex2Key(uint32_t index) {\n  TypeManager* t = TypeManager::Global();\n  std::lock_guard<std::mutex> lock(t->mutex);\n  CHECK_NE(index, 0);\n  return t->index2key.at(index - 1).c_str();\n}\n\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/pack_args.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file pack_args.h\n * @brief Utility to pack DGLArgs to other type-erased fution calling\n * convention.\n *\n *  Two type erased function signatures are supported.\n *   - cuda_style(void** args, int num_args);\n *      - Pack everything by address\n *   - metal_style(void** buffers, int num_buffers,\n *                 union_32bit args[N], int num_args);\n *      - Pack buffer by address, pack rest parameter into 32bit union buffer.\n */\n#ifndef DGL_RUNTIME_PACK_ARGS_H_\n#define DGL_RUNTIME_PACK_ARGS_H_\n\n#include <dgl/runtime/c_runtime_api.h>\n#include <dgl/runtime/packed_func.h>\n\n#include <cstring>\n#include <vector>\n\nnamespace dgl {\nnamespace runtime {\n/**\n * @brief argument union type of 32bit.\n * Choose 32 bit because most GPU API do not work well with 64 bit.\n */\nunion ArgUnion {\n  int32_t v_int32;\n  uint32_t v_uint32;\n  float v_float32;\n};\n/**\n * @brief Create a packed function from void addr types.\n *\n * @param f with signiture (DGLArgs args, DGLRetValue* rv, void* void_args)\n * @param arg_types The arguments type information.\n * @tparam F the function type\n *\n * @return The wrapped packed function.\n */\ntemplate <typename F>\ninline PackedFunc PackFuncVoidAddr(\n    F f, const std::vector<DGLDataType>& arg_types);\n/**\n * @brief Create a packed function that from function only packs buffer\n * arguments.\n *\n * @param f with signiture (DGLArgs args, DGLRetValue* rv, ArgUnion* pack_args)\n * @param arg_types The arguments type information.\n * @tparam F the function type\n *\n * @return The wrapped packed function.\n */\ntemplate <typename F>\ninline PackedFunc PackFuncNonBufferArg(\n    F f, const std::vector<DGLDataType>& arg_types);\n/**\n * @brief Create a packed function that from function that takes a packed\n * arguments.\n *\n * @param f with signature (DGLArgs args, DGLRetValue* rv, void* pack_args,\n * size_t nbytes)\n * @param arg_types The arguments that 
wish to get from\n * @tparam F the function type\n *\n * @return The wrapped packed function.\n */\ntemplate <typename F>\ninline PackedFunc PackFuncPackedArg(\n    F f, const std::vector<DGLDataType>& arg_types);\n/**\n * @brief Extract number of buffer argument from the argument types.\n * @param arg_types The argument types.\n * @return number of buffer arguments\n */\ninline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types);\n\n// implementations details\nnamespace detail {\ntemplate <typename T, int kSize>\nclass TempArray {\n public:\n  explicit TempArray(int size) {}\n  T* data() { return data_; }\n\n private:\n  T data_[kSize];\n};\ntemplate <typename T>\nclass TempArray<T, 0> {\n public:\n  explicit TempArray(int size) : data_(size) {}\n  T* data() { return data_.data(); }\n\n private:\n  std::vector<T> data_;\n};\n\n/** @brief conversion code used in void arg. */\nenum ArgConvertCode {\n  INT64_TO_INT64,\n  INT64_TO_INT32,\n  INT64_TO_UINT32,\n  FLOAT64_TO_FLOAT32,\n  FLOAT64_TO_FLOAT64,\n  HANDLE_TO_HANDLE\n};\n\ninline ArgConvertCode GetArgConvertCode(DGLDataType t) {\n  CHECK_EQ(t.lanes, 1U)\n      << \"Cannot pass vector type argument to devic function for now\";\n  if (t.code == kDGLInt) {\n    if (t.bits == 64U) return INT64_TO_INT64;\n    if (t.bits == 32U) return INT64_TO_INT32;\n  } else if (t.code == kDGLUInt) {\n    if (t.bits == 32U) return INT64_TO_UINT32;\n  } else if (t.code == kDGLFloat) {\n    if (t.bits == 64U) return FLOAT64_TO_FLOAT64;\n    if (t.bits == 32U) return FLOAT64_TO_FLOAT32;\n  } else if (t.code == kHandle) {\n    return HANDLE_TO_HANDLE;\n  }\n  LOG(FATAL) << \"Cannot handle \" << t << \" as device function argument\";\n  return HANDLE_TO_HANDLE;\n}\n\ntemplate <int N, typename F>\ninline PackedFunc PackFuncVoidAddr_(\n    F f, const std::vector<ArgConvertCode>& codes) {\n  int num_args = static_cast<int>(codes.size());\n  auto ret = [f, codes, num_args](DGLArgs args, DGLRetValue* ret) {\n    TempArray<void*, 
N> addr_(num_args);\n    TempArray<ArgUnion, N> holder_(num_args);\n    void** addr = addr_.data();\n    ArgUnion* holder = holder_.data();\n    for (int i = 0; i < num_args; ++i) {\n      switch (codes[i]) {\n        case INT64_TO_INT64:\n        case FLOAT64_TO_FLOAT64:\n        case HANDLE_TO_HANDLE: {\n          addr[i] = (void*)&(args.values[i]);  // NOLINT(*)\n          break;\n        }\n        case INT64_TO_INT32: {\n          holder[i].v_int32 = static_cast<int32_t>(args.values[i].v_int64);\n          addr[i] = &(holder[i]);\n          break;\n        }\n        case INT64_TO_UINT32: {\n          holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);\n          addr[i] = &(holder[i]);\n          break;\n        }\n        case FLOAT64_TO_FLOAT32: {\n          holder[i].v_float32 = static_cast<float>(args.values[i].v_float64);\n          addr[i] = &(holder[i]);\n          break;\n        }\n      }\n    }\n    f(args, ret, addr);\n  };\n  return PackedFunc(ret);\n}\n\ntemplate <int N, typename F>\ninline PackedFunc PackFuncNonBufferArg_(\n    F f, int base, const std::vector<ArgConvertCode>& codes) {\n  int num_args = static_cast<int>(codes.size());\n  auto ret = [f, codes, base, num_args](DGLArgs args, DGLRetValue* ret) {\n    TempArray<ArgUnion, N> holder_(num_args);\n    ArgUnion* holder = holder_.data();\n    for (int i = 0; i < num_args; ++i) {\n      switch (codes[i]) {\n        case INT64_TO_INT64:\n        case FLOAT64_TO_FLOAT64: {\n          LOG(FATAL) << \"Donot support 64bit argument to device function\";\n          break;\n        }\n        case INT64_TO_INT32: {\n          holder[i].v_int32 =\n              static_cast<int32_t>(args.values[base + i].v_int64);\n          break;\n        }\n        case INT64_TO_UINT32: {\n          holder[i].v_uint32 =\n              static_cast<uint32_t>(args.values[base + i].v_int64);\n          break;\n        }\n        case FLOAT64_TO_FLOAT32: {\n          holder[i].v_float32 =\n             
 static_cast<float>(args.values[base + i].v_float64);\n          break;\n        }\n        case HANDLE_TO_HANDLE: {\n          LOG(FATAL) << \"not reached\";\n          break;\n        }\n      }\n    }\n    f(args, ret, holder);\n  };\n  return PackedFunc(ret);\n}\n\ntemplate <int N, typename F>\ninline PackedFunc PackFuncPackedArg_(\n    F f, const std::vector<ArgConvertCode>& codes) {\n  int num_args = static_cast<int>(codes.size());\n  auto ret = [f, codes, num_args](DGLArgs args, DGLRetValue* ret) {\n    TempArray<uint64_t, N> pack_(num_args);\n    int32_t* pack = reinterpret_cast<int32_t*>(pack_.data());\n    int32_t* ptr = pack;\n    static_assert(sizeof(DGLValue) == 8, \"invariant\");\n    static_assert(sizeof(void*) % sizeof(int32_t) == 0, \"invariant\");\n    for (int i = 0; i < num_args; ++i) {\n      switch (codes[i]) {\n        case HANDLE_TO_HANDLE: {\n          std::memcpy(ptr, &(args.values[i].v_handle), sizeof(void*));\n          ptr += sizeof(void*) / sizeof(int32_t);\n          break;\n        }\n        case INT64_TO_INT64:\n        case FLOAT64_TO_FLOAT64: {\n          std::memcpy(ptr, &args.values[i], sizeof(DGLValue));\n          ptr += 2;\n          break;\n        }\n        case INT64_TO_INT32: {\n          *ptr = static_cast<int32_t>(args.values[i].v_int64);\n          ++ptr;\n          break;\n        }\n        case INT64_TO_UINT32: {\n          *reinterpret_cast<uint32_t*>(ptr) =\n              static_cast<uint32_t>(args.values[i].v_int64);\n          ++ptr;\n          break;\n        }\n        case FLOAT64_TO_FLOAT32: {\n          *reinterpret_cast<float*>(ptr) =\n              static_cast<float>(args.values[i].v_float64);\n          ++ptr;\n          break;\n        }\n        default: {\n          LOG(FATAL) << \"not reached\";\n          break;\n        }\n      }\n    }\n    f(args, ret, pack, (ptr - pack) * sizeof(int32_t));\n  };\n  return PackedFunc(ret);\n}\n}  // namespace detail\n\ntemplate <typename F>\ninline PackedFunc 
PackFuncVoidAddr(\n    F f, const std::vector<DGLDataType>& arg_types) {\n  std::vector<detail::ArgConvertCode> codes(arg_types.size());\n  for (size_t i = 0; i < arg_types.size(); ++i) {\n    codes[i] = detail::GetArgConvertCode(arg_types[i]);\n  }\n  size_t num_void_args = arg_types.size();\n  // specialization\n  if (num_void_args <= 4) {\n    return detail::PackFuncVoidAddr_<4>(f, codes);\n  } else if (num_void_args <= 8) {\n    return detail::PackFuncVoidAddr_<8>(f, codes);\n  } else {\n    return detail::PackFuncVoidAddr_<0>(f, codes);\n  }\n}\n\ninline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types) {\n  size_t base = arg_types.size();\n  for (size_t i = 0; i < arg_types.size(); ++i) {\n    if (arg_types[i].code != kHandle) {\n      base = i;\n      break;\n    }\n  }\n  for (size_t i = base; i < arg_types.size(); ++i) {\n    CHECK(arg_types[i].code != kHandle)\n        << \"Device function need to be organized\";\n  }\n  return base;\n}\n\ntemplate <typename F>\ninline PackedFunc PackFuncNonBufferArg(\n    F f, const std::vector<DGLDataType>& arg_types) {\n  size_t num_buffer = NumBufferArgs(arg_types);\n  std::vector<detail::ArgConvertCode> codes;\n  for (size_t i = num_buffer; i < arg_types.size(); ++i) {\n    codes.push_back(detail::GetArgConvertCode(arg_types[i]));\n  }\n  int base = static_cast<int>(num_buffer);\n  size_t nargs = codes.size();\n  // specialization\n  if (nargs <= 4) {\n    return detail::PackFuncNonBufferArg_<4>(f, base, codes);\n  } else {\n    return detail::PackFuncNonBufferArg_<0>(f, base, codes);\n  }\n}\n\ntemplate <typename F>\ninline PackedFunc PackFuncPackedArg(\n    F f, const std::vector<DGLDataType>& arg_types) {\n  std::vector<detail::ArgConvertCode> codes;\n  for (size_t i = 0; i < arg_types.size(); ++i) {\n    codes.push_back(detail::GetArgConvertCode(arg_types[i]));\n  }\n  size_t nargs = codes.size();\n  // specialization\n  if (nargs <= 4) {\n    return detail::PackFuncPackedArg_<4>(f, codes);\n  } 
else {\n    return detail::PackFuncPackedArg_<0>(f, codes);\n  }\n}\n}  // namespace runtime\n}  // namespace dgl\n#endif  // DGL_RUNTIME_PACK_ARGS_H_\n"
  },
  {
    "path": "src/runtime/parallel_for.cpp",
    "content": "/**\n *  Copyright (c) 2016 by Contributors\n * Implementation of C API (reference: tvm/src/api/c_api.cc)\n * @file c_api.cc\n */\n\nnamespace dgl {\nnamespace runtime {\nDefaultGrainSizeT default_grain_size;\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/registry.cc",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file registry.cc\n * @brief The global registry of packed function.\n */\n#include <dgl/runtime/registry.h>\n#include <dmlc/logging.h>\n#include <dmlc/thread_local.h>\n\n#include <array>\n#include <memory>\n#include <mutex>\n#include <unordered_map>\n\n#include \"runtime_base.h\"\n\nnamespace dgl {\nnamespace runtime {\n\nstruct Registry::Manager {\n  // map storing the functions.\n  // We delibrately used raw pointer\n  // This is because PackedFunc can contain callbacks into the host\n  // languge(python) and the resource can become invalid because of\n  // indeterminstic order of destruction. The resources will only be recycled\n  // during program exit.\n  std::unordered_map<std::string, Registry*> fmap;\n  // vtable for extension type\n  std::array<ExtTypeVTable, kExtEnd> ext_vtable;\n  // mutex\n  std::mutex mutex;\n\n  Manager() {\n    for (auto& x : ext_vtable) {\n      x.destroy = nullptr;\n    }\n  }\n\n  static Manager* Global() {\n    static Manager inst;\n    return &inst;\n  }\n};\n\nRegistry& Registry::set_body(PackedFunc f) {  // NOLINT(*)\n  func_ = f;\n  return *this;\n}\n\nRegistry& Registry::Register(\n    const std::string& name, bool override) {  // NOLINT(*)\n  Manager* m = Manager::Global();\n  std::lock_guard<std::mutex> lock(m->mutex);\n  auto it = m->fmap.find(name);\n  if (it == m->fmap.end()) {\n    Registry* r = new Registry();\n    r->name_ = name;\n    m->fmap[name] = r;\n    return *r;\n  } else {\n    CHECK(override) << \"Global PackedFunc \" << name << \" is already registered\";\n    return *it->second;\n  }\n}\n\nbool Registry::Remove(const std::string& name) {\n  Manager* m = Manager::Global();\n  std::lock_guard<std::mutex> lock(m->mutex);\n  auto it = m->fmap.find(name);\n  if (it == m->fmap.end()) return false;\n  m->fmap.erase(it);\n  return true;\n}\n\nconst PackedFunc* Registry::Get(const std::string& name) {\n  Manager* m = Manager::Global();\n  
std::lock_guard<std::mutex> lock(m->mutex);\n  auto it = m->fmap.find(name);\n  if (it == m->fmap.end()) return nullptr;\n  return &(it->second->func_);\n}\n\nstd::vector<std::string> Registry::ListNames() {\n  Manager* m = Manager::Global();\n  std::lock_guard<std::mutex> lock(m->mutex);\n  std::vector<std::string> keys;\n  keys.reserve(m->fmap.size());\n  for (const auto& kv : m->fmap) {\n    keys.push_back(kv.first);\n  }\n  return keys;\n}\n\nExtTypeVTable* ExtTypeVTable::Get(int type_code) {\n  CHECK(type_code > kExtBegin && type_code < kExtEnd);\n  Registry::Manager* m = Registry::Manager::Global();\n  ExtTypeVTable* vt = &(m->ext_vtable[type_code]);\n  CHECK(vt->destroy != nullptr) << \"Extension type not registered\";\n  return vt;\n}\n\nExtTypeVTable* ExtTypeVTable::RegisterInternal(\n    int type_code, const ExtTypeVTable& vt) {\n  CHECK(type_code > kExtBegin && type_code < kExtEnd);\n  Registry::Manager* m = Registry::Manager::Global();\n  std::lock_guard<std::mutex> lock(m->mutex);\n  ExtTypeVTable* pvt = &(m->ext_vtable[type_code]);\n  pvt[0] = vt;\n  return pvt;\n}\n}  // namespace runtime\n}  // namespace dgl\n\n/** @brief entry to to easily hold returning information */\nstruct DGLFuncThreadLocalEntry {\n  /** @brief result holder for returning strings */\n  std::vector<std::string> ret_vec_str;\n  /** @brief result holder for returning string pointers */\n  std::vector<const char*> ret_vec_charp;\n};\n\n/** @brief Thread local store that can be used to hold return values. 
*/\ntypedef dmlc::ThreadLocalStore<DGLFuncThreadLocalEntry> DGLFuncThreadLocalStore;\n\nint DGLExtTypeFree(void* handle, int type_code) {\n  API_BEGIN();\n  dgl::runtime::ExtTypeVTable::Get(type_code)->destroy(handle);\n  API_END();\n}\n\nint DGLFuncRegisterGlobal(const char* name, DGLFunctionHandle f, int override) {\n  API_BEGIN();\n  dgl::runtime::Registry::Register(name, override != 0)\n      .set_body(*static_cast<dgl::runtime::PackedFunc*>(f));\n  API_END();\n}\n\nint DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) {\n  API_BEGIN();\n  const dgl::runtime::PackedFunc* fp = dgl::runtime::Registry::Get(name);\n  if (fp != nullptr) {\n    *out = new dgl::runtime::PackedFunc(*fp);  // NOLINT(*)\n  } else {\n    *out = nullptr;\n  }\n  API_END();\n}\n\nint DGLFuncListGlobalNames(int* out_size, const char*** out_array) {\n  API_BEGIN();\n  DGLFuncThreadLocalEntry* ret = DGLFuncThreadLocalStore::Get();\n  ret->ret_vec_str = dgl::runtime::Registry::ListNames();\n  ret->ret_vec_charp.clear();\n  for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {\n    ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());\n  }\n  *out_array = dmlc::BeginPtr(ret->ret_vec_charp);\n  *out_size = static_cast<int>(ret->ret_vec_str.size());\n  API_END();\n}\n"
  },
  {
    "path": "src/runtime/resource_manager.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file resource_manager.cc\n * @brief Manage the resources.\n */\n\n#include \"resource_manager.h\"\n\n#include <dmlc/logging.h>\n\n#include <utility>\n\nnamespace dgl {\nnamespace runtime {\n\n/**\n * The runtime allocates resources during the computation. Some of the resources\n * cannot be destroyed after the process exits especially when the process\n * doesn't exits normally. We need to keep track of the resources in the system\n * and clean them up properly.\n */\nclass ResourceManager {\n  std::unordered_map<std::string, std::shared_ptr<Resource>> resources;\n\n public:\n  void Add(const std::string &key, std::shared_ptr<Resource> resource) {\n    auto it = resources.find(key);\n    CHECK(it == resources.end()) << key << \" already exists\";\n    resources.insert(\n        std::pair<std::string, std::shared_ptr<Resource>>(key, resource));\n  }\n\n  void Erase(const std::string &key) { resources.erase(key); }\n\n  void Cleanup() {\n    for (auto it = resources.begin(); it != resources.end(); it++) {\n      it->second->Destroy();\n    }\n    resources.clear();\n  }\n};\n\nstatic ResourceManager manager;\n\nvoid AddResource(const std::string &key, std::shared_ptr<Resource> resource) {\n  manager.Add(key, resource);\n}\n\nvoid DeleteResource(const std::string &key) { manager.Erase(key); }\n\nvoid CleanupResources() { manager.Cleanup(); }\n\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/resource_manager.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file resource_manager.h\n * @brief Manage the resources in the runtime system.\n */\n#ifndef DGL_RUNTIME_RESOURCE_MANAGER_H_\n#define DGL_RUNTIME_RESOURCE_MANAGER_H_\n\n#include <memory>\n#include <string>\n#include <unordered_map>\n\nnamespace dgl {\nnamespace runtime {\n\n/**\n * A class that provides the interface to describe a resource that can be\n * managed by a resource manager. Some of the resources cannot be free'd\n * automatically when the process exits, especially when the process doesn't\n * exit normally. One example is shared memory. We can keep track of this kind\n * of resources and manage them properly.\n */\nclass Resource {\n public:\n  virtual ~Resource() {}\n\n  virtual void Destroy() = 0;\n};\n\n// Add resource.\nvoid AddResource(const std::string &key, std::shared_ptr<Resource> resource);\n\n// Delete resource.\nvoid DeleteResource(const std::string &key);\n\n// Clean up all resources.\nvoid CleanupResources();\n\n}  // namespace runtime\n}  // namespace dgl\n\n#endif  // DGL_RUNTIME_RESOURCE_MANAGER_H_\n"
  },
  {
    "path": "src/runtime/runtime_base.h",
    "content": "/**\n *  Copyright (c) 2016 by Contributors\n * @file runtime_base.h\n * @brief Base of all C APIs\n */\n#ifndef DGL_RUNTIME_RUNTIME_BASE_H_\n#define DGL_RUNTIME_RUNTIME_BASE_H_\n\n#include <dgl/runtime/c_runtime_api.h>\n\n#include <stdexcept>\n\n/** @brief  macro to guard beginning and end section of all functions */\n#define API_BEGIN() try {\n/** @brief every function starts with API_BEGIN();\n     and finishes with API_END() or API_END_HANDLE_ERROR */\n#define API_END()                           \\\n  }                                         \\\n  catch (std::runtime_error & _except_) {   \\\n    return DGLAPIHandleException(_except_); \\\n  }                                         \\\n  return 0;  // NOLINT(*)\n/**\n * @brief every function starts with API_BEGIN();\n *   and finishes with API_END() or API_END_HANDLE_ERROR\n *   The finally clause contains procedure to cleanup states when an error\n * happens.\n */\n#define API_END_HANDLE_ERROR(Finalize)      \\\n  }                                         \\\n  catch (std::runtime_error & _except_) {   \\\n    Finalize;                               \\\n    return DGLAPIHandleException(_except_); \\\n  }                                         \\\n  return 0;  // NOLINT(*)\n\n/**\n * @brief handle exception throwed out\n * @param e the exception\n * @return the return value of API after exception is handled\n */\ninline int DGLAPIHandleException(const std::runtime_error &e) {\n  DGLAPISetLastError(e.what());\n  return -1;\n}\n\n#endif  // DGL_RUNTIME_RUNTIME_BASE_H_\n"
  },
  {
    "path": "src/runtime/semaphore_wrapper.cc",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file semaphore_wrapper.cc\n * @brief A simple corss platform semaphore wrapper\n */\n#include \"semaphore_wrapper.h\"\n\n#include <dmlc/logging.h>\n\n#ifndef _WIN32\n#include <errno.h>\n#include <time.h>\n#include <unistd.h>\n#endif\n\nnamespace dgl {\nnamespace runtime {\n\n#ifdef _WIN32\n\nSemaphore::Semaphore() {\n  sem_ = CreateSemaphore(nullptr, 0, INT_MAX, nullptr);\n  if (!sem_) {\n    LOG(FATAL) << \"Cannot create semaphore\";\n  }\n}\n\nvoid Semaphore::Wait() { WaitForSingleObject(sem_, INFINITE); }\n\nbool Semaphore::TimedWait(int) {\n  // Timed wait is not supported on WIN32.\n  Wait();\n  return true;\n}\n\nvoid Semaphore::Post() { ReleaseSemaphore(sem_, 1, nullptr); }\n\n#else\n\nSemaphore::Semaphore() { sem_init(&sem_, 0, 0); }\n\nvoid Semaphore::Wait() { sem_wait(&sem_); }\n\nbool Semaphore::TimedWait(int timeout) {\n  // sem_timedwait does not exist in Mac OS.\n#ifdef __APPLE__\n  DLOG(WARNING) << \"Timeout is not supported in semaphore's wait on Mac OS.\";\n  Wait();\n#else\n  // zero timeout means wait infinitely\n  if (timeout == 0) {\n    DLOG(WARNING) << \"Will wait infinitely on semaphore until posted.\";\n    Wait();\n    return true;\n  }\n  timespec ts;\n  if (clock_gettime(CLOCK_REALTIME, &ts) != 0) {\n    LOG(ERROR) << \"Failed to get current time via clock_gettime. 
Errno: \"\n               << errno;\n    return false;\n  }\n  ts.tv_sec += timeout / MILLISECONDS_PER_SECOND;\n  ts.tv_nsec +=\n      (timeout % MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND;\n  if (ts.tv_nsec >= NANOSECONDS_PER_SECOND) {\n    ts.tv_nsec -= NANOSECONDS_PER_SECOND;\n    ++ts.tv_sec;\n  }\n  int ret = 0;\n  while ((ret = sem_timedwait(&sem_, &ts) != 0) && errno == EINTR) {\n    continue;\n  }\n  if (ret != 0) {\n    if (errno == ETIMEDOUT) {\n      DLOG(WARNING) << \"sem_timedwait timed out after \" << timeout\n                    << \" milliseconds.\";\n    } else {\n      LOG(ERROR) << \"sem_timedwait returns unexpectedly. Errno: \" << errno;\n    }\n    return false;\n  }\n#endif\n\n  return true;\n}\n\nvoid Semaphore::Post() { sem_post(&sem_); }\n\n#endif\n\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/semaphore_wrapper.h",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file semaphore_wrapper.h\n * @brief A simple corss platform semaphore wrapper\n */\n#ifndef DGL_RUNTIME_SEMAPHORE_WRAPPER_H_\n#define DGL_RUNTIME_SEMAPHORE_WRAPPER_H_\n\n#ifdef _WIN32\n#include <windows.h>\n#else\n#include <semaphore.h>\n#endif\n\nnamespace dgl {\nnamespace runtime {\n\n/**\n * @brief A simple crossplatform Semaphore wrapper\n */\nclass Semaphore {\n public:\n  /**\n   * @brief Semaphore constructor\n   */\n  Semaphore();\n  /**\n   * @brief blocking wait, decrease semaphore by 1\n   */\n  void Wait();\n  /**\n   * @brief timed wait, decrease semaphore by 1 or returns if times out\n   * @param timeout The timeout value in milliseconds. If zero, wait\n   * indefinitely.\n   */\n  bool TimedWait(int timeout);\n  /**\n   * @brief increase semaphore by 1\n   */\n  void Post();\n\n private:\n#ifdef _WIN32\n  HANDLE sem_;\n#else\n  sem_t sem_;\n#endif\n  enum {\n    MILLISECONDS_PER_SECOND = 1000,\n    NANOSECONDS_PER_MILLISECOND = 1000 * 1000,\n    NANOSECONDS_PER_SECOND = 1000 * 1000 * 1000\n  };\n};\n\n}  // namespace runtime\n}  // namespace dgl\n\n#endif  // DGL_RUNTIME_SEMAPHORE_WRAPPER_H_\n"
  },
  {
    "path": "src/runtime/shared_mem.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file shared_mem.cc\n * @brief Shared memory management.\n */\n#ifndef _WIN32\n#include <fcntl.h>\n#include <sys/mman.h>\n#include <unistd.h>\n#endif\n#include <dgl/runtime/shared_mem.h>\n#include <dmlc/logging.h>\n#include <stdio.h>\n#include <string.h>\n\n#include \"resource_manager.h\"\n\nnamespace dgl {\nnamespace runtime {\n\n/**\n * Shared memory is a resource that cannot be cleaned up if the process doesn't\n * exit normally. We'll manage the resource with ResourceManager.\n */\nclass SharedMemoryResource : public Resource {\n  std::string name;\n\n public:\n  explicit SharedMemoryResource(const std::string &name) { this->name = name; }\n\n  void Destroy() {\n    // LOG(INFO) << \"remove \" << name << \" for shared memory\";\n#ifndef _WIN32\n    shm_unlink(name.c_str());\n#else  // _WIN32\n    // NOTHING; Windows automatically removes the shared memory object once all\n    // handles are unmapped.\n#endif\n  }\n};\n\nSharedMemory::SharedMemory(const std::string &name) {\n  this->name = name;\n  this->own_ = false;\n#ifndef _WIN32\n  this->fd_ = -1;\n#else\n  this->handle_ = nullptr;\n#endif\n  this->ptr_ = nullptr;\n  this->size_ = 0;\n}\n\nSharedMemory::~SharedMemory() {\n#ifndef _WIN32\n  if (ptr_ && size_ != 0) CHECK(munmap(ptr_, size_) != -1) << strerror(errno);\n  if (fd_ != -1) close(fd_);\n  if (own_) {\n    // LOG(INFO) << \"remove \" << name << \" for shared memory\";\n    if (name != \"\") {\n      shm_unlink(name.c_str());\n      // The resource has been deleted. 
We don't need to keep track of it any\n      // more.\n      DeleteResource(name);\n    }\n  }\n#else\n  if (ptr_) CHECK(UnmapViewOfFile(ptr_)) << \"Win32 Error: \" << GetLastError();\n  if (handle_) CloseHandle(handle_);\n    // Windows do not need a separate shm_unlink step.\n#endif  // _WIN32\n}\n\nvoid *SharedMemory::CreateNew(size_t sz) {\n#ifndef _WIN32\n  this->own_ = true;\n\n  // We need to create a shared-memory file.\n  // TODO(zhengda) we need to report error if the shared-memory file exists.\n  int flag = O_RDWR | O_CREAT;\n  fd_ = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR);\n  CHECK_NE(fd_, -1) << \"fail to open \" << name << \": \" << strerror(errno);\n  // Shared memory cannot be deleted if the process exits abnormally in Linux.\n  AddResource(name, std::shared_ptr<Resource>(new SharedMemoryResource(name)));\n  auto res = ftruncate(fd_, sz);\n  CHECK_NE(res, -1) << \"Failed to truncate the file. \" << strerror(errno);\n  ptr_ = mmap(NULL, sz, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);\n  CHECK_NE(ptr_, MAP_FAILED)\n      << \"Failed to map shared memory. 
mmap failed with error \"\n      << strerror(errno);\n  this->size_ = sz;\n  return ptr_;\n#else\n  handle_ = CreateFileMapping(\n      INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE,\n      static_cast<DWORD>(sz >> 32), static_cast<DWORD>(sz & 0xFFFFFFFF),\n      name.c_str());\n  CHECK(handle_ != nullptr)\n      << \"fail to open \" << name << \", Win32 error: \" << GetLastError();\n  ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz);\n  if (ptr_ == nullptr) {\n    LOG(FATAL) << \"Memory mapping failed, Win32 error: \" << GetLastError();\n    CloseHandle(handle_);\n    return nullptr;\n  }\n  this->size_ = sz;\n  return ptr_;\n#endif  // _WIN32\n}\n\nvoid *SharedMemory::Open(size_t sz) {\n#ifndef _WIN32\n  int flag = O_RDWR;\n  fd_ = shm_open(name.c_str(), flag, S_IRUSR | S_IWUSR);\n  CHECK_NE(fd_, -1) << \"fail to open \" << name << \": \" << strerror(errno);\n  ptr_ = mmap(NULL, sz, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);\n  CHECK_NE(ptr_, MAP_FAILED)\n      << \"Failed to map shared memory. 
mmap failed with error \"\n      << strerror(errno);\n  this->size_ = sz;\n  return ptr_;\n#else\n  handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name.c_str());\n  CHECK(handle_ != nullptr)\n      << \"fail to open \" << name << \", Win32 Error: \" << GetLastError();\n  ptr_ = MapViewOfFile(handle_, FILE_MAP_ALL_ACCESS, 0, 0, sz);\n  if (ptr_ == nullptr) {\n    LOG(FATAL) << \"Memory mapping failed, Win32 error: \" << GetLastError();\n    CloseHandle(handle_);\n    return nullptr;\n  }\n  this->size_ = sz;\n  return ptr_;\n#endif  // _WIN32\n}\n\nbool SharedMemory::Exist(const std::string &name) {\n#ifndef _WIN32\n  int fd = shm_open(name.c_str(), O_RDONLY, S_IRUSR | S_IWUSR);\n  if (fd >= 0) {\n    close(fd);\n    return true;\n  } else {\n    return false;\n  }\n#else\n  HANDLE handle = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name.c_str());\n  if (handle != nullptr) {\n    CloseHandle(handle);\n    return true;\n  } else {\n    return false;\n  }\n#endif  // _WIN32\n}\n\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/system_lib_module.cc",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file system_lib_module.cc\n * @brief SystemLib module.\n */\n#include <dgl/runtime/c_backend_api.h>\n#include <dgl/runtime/registry.h>\n\n#include <mutex>\n\n#include \"module_util.h\"\n\nnamespace dgl {\nnamespace runtime {\n\nclass SystemLibModuleNode : public ModuleNode {\n public:\n  SystemLibModuleNode() = default;\n\n  const char* type_key() const final { return \"system_lib\"; }\n\n  PackedFunc GetFunction(\n      const std::string& name,\n      const std::shared_ptr<ModuleNode>& sptr_to_self) final {\n    std::lock_guard<std::mutex> lock(mutex_);\n\n    if (module_blob_ != nullptr) {\n      // If we previously recorded submodules, load them now.\n      ImportModuleBlob(reinterpret_cast<const char*>(module_blob_), &imports_);\n      module_blob_ = nullptr;\n    }\n\n    auto it = tbl_.find(name);\n    if (it != tbl_.end()) {\n      return WrapPackedFunc(\n          reinterpret_cast<BackendPackedCFunc>(it->second), sptr_to_self);\n    } else {\n      return PackedFunc();\n    }\n  }\n\n  void RegisterSymbol(const std::string& name, void* ptr) {\n    std::lock_guard<std::mutex> lock(mutex_);\n    if (name == symbol::dgl_module_ctx) {\n      void** ctx_addr = reinterpret_cast<void**>(ptr);\n      *ctx_addr = this;\n    } else if (name == symbol::dgl_dev_mblob) {\n      // Record pointer to content of submodules to be loaded.\n      // We defer loading submodules to the first call to GetFunction().\n      // The reason is that RegisterSymbol() gets called when initializing the\n      // syslib (i.e. library loading time), and the registeries aren't ready\n      // yet. 
Therefore, we might not have the functionality to load submodules\n      // now.\n      CHECK(module_blob_ == nullptr) << \"Resetting mobule blob?\";\n      module_blob_ = ptr;\n    } else {\n      auto it = tbl_.find(name);\n      if (it != tbl_.end() && ptr != it->second) {\n        LOG(WARNING) << \"SystemLib symbol \" << name\n                     << \" get overriden to a different address \" << ptr << \"->\"\n                     << it->second;\n      }\n      tbl_[name] = ptr;\n    }\n  }\n\n  static const std::shared_ptr<SystemLibModuleNode>& Global() {\n    static std::shared_ptr<SystemLibModuleNode> inst =\n        std::make_shared<SystemLibModuleNode>();\n    return inst;\n  }\n\n private:\n  // Internal mutex\n  std::mutex mutex_;\n  // Internal symbol table\n  std::unordered_map<std::string, void*> tbl_;\n  // Module blob to be imported\n  void* module_blob_{nullptr};\n};\n\nDGL_REGISTER_GLOBAL(\"module._GetSystemLib\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = runtime::Module(SystemLibModuleNode::Global());\n    });\n}  // namespace runtime\n}  // namespace dgl\n\nint DGLBackendRegisterSystemLibSymbol(const char* name, void* ptr) {\n  dgl::runtime::SystemLibModuleNode::Global()->RegisterSymbol(name, ptr);\n  return 0;\n}\n"
  },
  {
    "path": "src/runtime/tensordispatch.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file runtime/tensordispatch.cc\n * @brief Adapter library caller\n */\n\n#include <dgl/packed_func_ext.h>\n#include <dgl/runtime/registry.h>\n#include <dgl/runtime/tensordispatch.h>\n#if defined(WIN32) || defined(_WIN32)\n#include <windows.h>\n#else  // !WIN32\n#include <dlfcn.h>\n#endif  // WIN32\n#include <cstring>\n\nnamespace dgl {\nnamespace runtime {\n\nconstexpr const char *TensorDispatcher::names_[];\n\nbool TensorDispatcher::Load(const char *path) {\n  CHECK(!available_) << \"The tensor adapter can only load once.\";\n\n  if (path == nullptr || strlen(path) == 0)\n    // does not have dispatcher library; all operators fall back to DGL's\n    // implementation\n    return false;\n\n#if defined(WIN32) || defined(_WIN32)\n  handle_ = LoadLibrary(path);\n\n  if (!handle_) return false;\n\n  for (int i = 0; i < num_entries_; ++i) {\n    entrypoints_[i] =\n        reinterpret_cast<void *>(GetProcAddress(handle_, names_[i]));\n    CHECK(entrypoints_[i]) << \"cannot locate symbol \" << names_[i];\n  }\n#else   // !WIN32\n  handle_ = dlopen(path, RTLD_LAZY);\n\n  if (!handle_) {\n    DLOG(WARNING)\n        << \"Could not open file: \" << dlerror()\n        << \". This does not affect DGL's but might impact its performance.\";\n    return false;\n  }\n\n  for (int i = 0; i < num_entries_; ++i) {\n    entrypoints_[i] = dlsym(handle_, names_[i]);\n    CHECK(entrypoints_[i]) << \"cannot locate symbol \" << names_[i];\n  }\n#endif  // WIN32\n\n  available_ = true;\n  return true;\n}\n\nTensorDispatcher::~TensorDispatcher() {\n  if (handle_) {\n#if defined(WIN32) || defined(_WIN32)\n    FreeLibrary(handle_);\n#else   // !WIN32\n    dlclose(handle_);\n#endif  // WIN32\n  }\n}\n\n};  // namespace runtime\n};  // namespace dgl\n"
  },
  {
    "path": "src/runtime/thread_pool.cc",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file thread_pool.cc\n * @brief Threadpool for multi-threading runtime.\n */\n#include <dgl/runtime/c_backend_api.h>\n#include <dgl/runtime/c_runtime_api.h>\n#include <dgl/runtime/packed_func.h>\n#include <dgl/runtime/registry.h>\n#include <dgl/runtime/threading_backend.h>\n#include <dmlc/logging.h>\n#include <dmlc/thread_local.h>\n\n#include <algorithm>\n#include <atomic>\n#include <condition_variable>\n#include <cstring>\n#include <memory>\n#include <mutex>\n#include <sstream>\n#include <string>\n#include <thread>\n#include <vector>\n\nconst constexpr int kL1CacheBytes = 64;\n\nnamespace dgl {\nnamespace runtime {\n\n// stride in the page, fit to cache line.\nconstexpr int kSyncStride = 64 / sizeof(std::atomic<int>);\n\n/**\n * @brief Thread local master environment.\n */\nclass ParallelLauncher {\n public:\n  // Reset the the task request.\n  void Init(\n      FDGLParallelLambda flambda, void* cdata, int num_task, bool need_sync) {\n    num_pending_.store(num_task);\n    this->cdata = cdata;\n    this->flambda = flambda;\n    this->env.num_task = num_task;\n    has_error_.store(false);\n    // reshape\n    if (static_cast<size_t>(num_task) > par_errors_.size()) {\n      par_errors_.resize(num_task + 1);\n      if (need_sync) {\n        delete[] sync_counter_;\n        sync_counter_ = new std::atomic<int>[num_task * kSyncStride];\n      }\n    }\n    if (need_sync) {\n      for (int i = 0; i < num_task; ++i) {\n        sync_counter_[i * kSyncStride].store(0, std::memory_order_relaxed);\n      }\n      this->env.sync_handle = sync_counter_;\n    } else {\n      this->env.sync_handle = nullptr;\n    }\n  }\n  ~ParallelLauncher() { delete[] sync_counter_; }\n  // Wait n jobs to finish\n  int WaitForJobs() {\n    while (num_pending_.load() != 0) {\n      dgl::runtime::threading::YieldThread();\n    }\n    if (!has_error_.load()) return 0;\n    // the following is intended to use string due to\n    // 
security issue raised in SGX backend\n    std::string err(\"\");\n    for (size_t i = 0; i < par_errors_.size(); ++i) {\n      if (par_errors_[i].length() != 0) {\n        err += \"Task \" + std::to_string(i) + \" error: \" + par_errors_[i] + '\\n';\n        par_errors_[i].clear();\n      }\n    }\n    DGLAPISetLastError(err.c_str());\n    return -1;\n  }\n  // Signal that one job has finished.\n  void SignalJobError(int task_id) {\n    num_pending_.fetch_sub(1);\n    par_errors_[task_id] = DGLGetLastError();\n    has_error_.store(true);\n  }\n  // Signal that one job has finished.\n  void SignalJobFinish() { num_pending_.fetch_sub(1); }\n  // Get thread local version of the store.\n  static ParallelLauncher* ThreadLocal() {\n    return dmlc::ThreadLocalStore<ParallelLauncher>::Get();\n  }\n  // The parallel lambda\n  FDGLParallelLambda flambda;\n  // The closure data\n  void* cdata;\n  // Local env\n  DGLParallelGroupEnv env;\n  // Whether this thread is worker of the pool.\n  // used to prevent recursive launch.\n  bool is_worker{false};\n\n private:\n  // The pending jobs.\n  std::atomic<int32_t> num_pending_;\n  // Whether error has been countered.\n  std::atomic<bool> has_error_;\n  // The counter page.\n  std::atomic<int32_t>* sync_counter_{nullptr};\n  // The error message\n  std::vector<std::string> par_errors_;\n};\n\n/** @brief Lock-free single-producer-single-consumer queue for each thread */\nclass SpscTaskQueue {\n public:\n  /** @brief The task entry */\n  struct Task {\n    ParallelLauncher* launcher;\n    int32_t task_id;\n  };\n\n  SpscTaskQueue() : buffer_(new Task[kRingSize]), head_(0), tail_(0) {}\n\n  ~SpscTaskQueue() { delete[] buffer_; }\n\n  /**\n   * @brief Push a task into the queue and notify the comsumer if it is on wait.\n   * @param input The task to be dequeued.\n   */\n  void Push(const Task& input) {\n    while (!Enqueue(input)) {\n      dgl::runtime::threading::YieldThread();\n    }\n    if (pending_.fetch_add(1) == -1) {\n      
std::unique_lock<std::mutex> lock(mutex_);\n      cv_.notify_one();\n    }\n  }\n\n  /**\n   * @brief Pop a task out of the queue and condition wait if no tasks.\n   * @param output The pointer to the task to be dequeued.\n   * @param spin_count The number of iterations to spin before sleep.\n   * @return Whether pop is successful (true) or we need to exit now (false).\n   */\n  bool Pop(Task* output, uint32_t spin_count = 300000) {\n    // Busy wait a bit when the queue is empty.\n    // If a new task comes to the queue quickly, this wait avoid the worker from\n    // sleeping. The default spin count is set by following the typical omp\n    // convention\n    for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) {\n      dgl::runtime::threading::YieldThread();\n    }\n    if (pending_.fetch_sub(1) == 0) {\n      std::unique_lock<std::mutex> lock(mutex_);\n      cv_.wait(\n          lock, [this] { return pending_.load() >= 0 || exit_now_.load(); });\n    }\n    if (exit_now_.load(std::memory_order_relaxed)) {\n      return false;\n    }\n    const uint32_t head = head_.load(std::memory_order_relaxed);\n    // sanity check if the queue is empty\n    CHECK(tail_.load(std::memory_order_acquire) != head);\n    *output = buffer_[head];\n    head_.store((head + 1) % kRingSize, std::memory_order_release);\n    return true;\n  }\n\n  /**\n   * @brief Signal to terminate the worker.\n   */\n  void SignalForKill() {\n    std::lock_guard<std::mutex> lock(mutex_);\n    exit_now_.store(true);\n    cv_.notify_all();\n  }\n\n protected:\n  /**\n   * @brief Lock-free enqueue.\n   * @param input The task to be enqueued.\n   * @return Whether the task is enqueued.\n   */\n  bool Enqueue(const Task& input) {\n    if (exit_now_.load(std::memory_order_relaxed)) return false;\n\n    const uint32_t tail = tail_.load(std::memory_order_relaxed);\n\n    if ((tail + 1) % kRingSize != (head_.load(std::memory_order_acquire))) {\n      buffer_[tail] = input;\n      tail_.store((tail 
+ 1) % kRingSize, std::memory_order_release);\n      return true;\n    }\n    return false;\n  }\n\n  // the cache line paddings are used for avoid false sharing between atomic\n  // variables\n  typedef char cache_line_pad_t[kL1CacheBytes];\n  cache_line_pad_t pad0_;\n  // size of the queue, the queue can host size_ - 1 items at most\n  // define it as a constant for better compiler optimization\n  static constexpr const int kRingSize = 2;\n  // pointer to access the item\n  Task* const buffer_;\n\n  cache_line_pad_t pad1_;\n  // queue head, where one gets a task from the queue\n  std::atomic<uint32_t> head_;\n\n  cache_line_pad_t pad2_;\n  // queue tail, when one puts a task to the queue\n  std::atomic<uint32_t> tail_;\n\n  cache_line_pad_t pad3_;\n  // pending tasks in the queue\n  std::atomic<int8_t> pending_{0};\n\n  cache_line_pad_t pad4_;\n  // signal for exit now\n  std::atomic<bool> exit_now_{false};\n\n  // internal mutex\n  std::mutex mutex_;\n  // cv for consumer\n  std::condition_variable cv_;\n};\n\n// The thread pool\nclass ThreadPool {\n public:\n  ThreadPool() : num_workers_(dgl::runtime::threading::MaxConcurrency()) {\n    for (int i = 0; i < num_workers_; ++i) {\n      // The SpscTaskQueue only hosts ONE item at a time\n      queues_.emplace_back(std::unique_ptr<SpscTaskQueue>(new SpscTaskQueue()));\n    }\n    threads_ = std::unique_ptr<dgl::runtime::threading::ThreadGroup>(\n        new dgl::runtime::threading::ThreadGroup(\n            num_workers_, [this](int worker_id) { this->RunWorker(worker_id); },\n            exclude_worker0_ /* include_main_thread */));\n    num_workers_used_ =\n        threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_);\n  }\n  ~ThreadPool() {\n    for (std::unique_ptr<SpscTaskQueue>& q : queues_) {\n      q->SignalForKill();\n    }\n    threads_.reset();\n  }\n  int Launch(\n      FDGLParallelLambda flambda, void* cdata, int num_task, int need_sync) {\n    ParallelLauncher* launcher = 
ParallelLauncher::ThreadLocal();\n    CHECK(!launcher->is_worker) << \"Cannot launch parallel job inside worker, \"\n                                   \"consider fuse then parallel\";\n    if (num_task == 0) {\n      num_task = num_workers_used_;\n    }\n    if (need_sync != 0) {\n      CHECK_LE(num_task, num_workers_used_)\n          << \"Request parallel sync task larger than number of threads used \"\n          << \" workers=\" << num_workers_used_ << \" request=\" << num_task;\n    }\n    launcher->Init(flambda, cdata, num_task, need_sync != 0);\n    SpscTaskQueue::Task tsk;\n    tsk.launcher = launcher;\n    // if worker0 is taken by the master, queues_[0] is abandoned\n    for (int i = exclude_worker0_; i < num_task; ++i) {\n      tsk.task_id = i;\n      queues_[i]->Push(tsk);\n    }\n    // use the master thread to run task 0\n    if (exclude_worker0_) {\n      DGLParallelGroupEnv* penv = &(tsk.launcher->env);\n      if ((*tsk.launcher->flambda)(0, penv, cdata) == 0) {\n        tsk.launcher->SignalJobFinish();\n      } else {\n        tsk.launcher->SignalJobError(tsk.task_id);\n      }\n    }\n    int res = launcher->WaitForJobs();\n    return res;\n  }\n\n  static ThreadPool* ThreadLocal() {\n    return dmlc::ThreadLocalStore<ThreadPool>::Get();\n  }\n\n  void UpdateWorkerConfiguration(\n      threading::ThreadGroup::AffinityMode mode, int nthreads) {\n    // this will also reset the affinity of the ThreadGroup\n    // may use less than the MaxConcurrency number of workers\n    num_workers_used_ = threads_->Configure(mode, nthreads, exclude_worker0_);\n    // if MaxConcurrency restricted the number of workers (e.g., due to\n    // hyperthreading), respect the restriction\n    num_workers_used_ = std::min(num_workers_, num_workers_used_);\n  }\n\n private:\n  // Internal worker function.\n  void RunWorker(int worker_id) {\n    SpscTaskQueue* queue = queues_[worker_id].get();\n    SpscTaskQueue::Task task;\n    ParallelLauncher::ThreadLocal()->is_worker = 
true;\n    while (queue->Pop(&task)) {\n      CHECK(task.launcher != nullptr);\n      DGLParallelGroupEnv* penv = &(task.launcher->env);\n      void* cdata = task.launcher->cdata;\n      if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {\n        task.launcher->SignalJobFinish();\n      } else {\n        task.launcher->SignalJobError(task.task_id);\n      }\n    }\n  }\n  int num_workers_;\n  // number of workers used (can be restricted with affinity pref)\n  int num_workers_used_;\n  // if excluding worker 0 and using master to run task 0\n#ifndef _LIBCPP_SGX_CONFIG\n  bool exclude_worker0_{true};\n#else\n  bool exclude_worker0_{false};\n#endif\n  std::vector<std::unique_ptr<SpscTaskQueue> > queues_;\n  std::unique_ptr<dgl::runtime::threading::ThreadGroup> threads_;\n};\n\nDGL_REGISTER_GLOBAL(\"runtime.config_threadpool\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      threading::ThreadGroup::AffinityMode mode =\n          static_cast<threading::ThreadGroup::AffinityMode>(\n              static_cast<int>(args[0]));\n      int nthreads = args[1];\n      ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads);\n    });\n\n}  // namespace runtime\n}  // namespace dgl\n\nint DGLBackendParallelLaunch(\n    FDGLParallelLambda flambda, void* cdata, int num_task) {\n  int res = dgl::runtime::ThreadPool::ThreadLocal()->Launch(\n      flambda, cdata, num_task, 1);\n  return res;\n}\n\nint DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv) {\n  using dgl::runtime::kSyncStride;\n  int num_task = penv->num_task;\n  std::atomic<int>* sync_counter =\n      reinterpret_cast<std::atomic<int>*>(penv->sync_handle);\n  int old_counter = sync_counter[task_id * kSyncStride].fetch_add(\n      1, std::memory_order_release);\n  for (int i = 0; i < num_task; ++i) {\n    if (i != task_id) {\n      while (sync_counter[i * kSyncStride].load(std::memory_order_relaxed) <=\n             old_counter) {\n        
dgl::runtime::threading::YieldThread();\n      }\n    }\n  }\n  std::atomic_thread_fence(std::memory_order_acquire);\n  return 0;\n}\n"
  },
  {
    "path": "src/runtime/thread_storage_scope.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file thread_storage_scope.h\n * @brief Extract thread axis configuration from DGLArgs.\n */\n#ifndef DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_\n#define DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_\n\n#include <dgl/runtime/packed_func.h>\n\n#include <string>\n#include <vector>\n\nnamespace dgl {\nnamespace runtime {\n\n/**\n * @brief Memory hierachy rank in the storage system\n * @note The global rank and shared rank have one to one\n *       correspondence to the thread rank.\n */\nenum class StorageRank {\n  /** @brief global memory */\n  kGlobal = 0,\n  /** @brief shared memory among thread group */\n  kShared = 1,\n  /**\n   * @brief reserved for warp memory.\n   *  This is only used by programming model.\n   *  There is no such memory usually in GPU.\n   *  Instead, we can simulate it by registers and shuffle.\n   */\n  kWarp = 2,\n  /** @brief thread local memory */\n  kLocal = 3\n};\n\n/**\n * @param thread_scope_rank The thread scope rank\n * @return default storage rank given the thread scope\n */\ninline StorageRank DefaultStorageRank(int thread_scope_rank) {\n  switch (thread_scope_rank) {\n    case -1:\n      return StorageRank::kGlobal;\n    case 0:\n      return StorageRank::kShared;\n    case 1:\n      return StorageRank::kLocal;\n    default: {\n      LOG(FATAL) << \"unknown rank\";\n      return StorageRank::kGlobal;\n    }\n  }\n}\n\n/** @brief class to represent storage scope */\nstruct StorageScope {\n  /** @brief The rank of the storage */\n  StorageRank rank{StorageRank::kGlobal};\n  /** @brief tag for special purpose memory. 
*/\n  std::string tag;\n  // comparator\n  inline bool operator==(const StorageScope& other) const {\n    return rank == other.rank && tag == other.tag;\n  }\n  inline bool operator!=(const StorageScope& other) const {\n    return !(*this == other);\n  }\n  inline std::string to_string() const {\n    std::string ret;\n    switch (rank) {\n      case StorageRank::kGlobal:\n        return \"global\" + tag;\n      case StorageRank::kShared:\n        return \"shared\" + tag;\n      case StorageRank::kWarp:\n        return \"warp\" + tag;\n      case StorageRank::kLocal:\n        return \"local\" + tag;\n      default:\n        LOG(FATAL) << \"unknown storage scope\";\n        return \"\";\n    }\n  }\n  /**\n   * @brief make storage scope from string\n   * @param s The string to be parsed.\n   * @return The storage scope.\n   */\n  static StorageScope make(const std::string& s) {\n    StorageScope r;\n    if (s.compare(0, 6, \"global\") == 0) {\n      r.rank = StorageRank::kGlobal;\n      r.tag = s.substr(6, std::string::npos);\n    } else if (s.compare(0, 6, \"shared\") == 0) {\n      r.rank = StorageRank::kShared;\n      r.tag = s.substr(6, std::string::npos);\n    } else if (s.compare(0, 4, \"warp\") == 0) {\n      r.rank = StorageRank::kWarp;\n      r.tag = s.substr(4, std::string::npos);\n    } else if (s.compare(0, 5, \"local\") == 0) {\n      r.rank = StorageRank::kLocal;\n      r.tag = s.substr(5, std::string::npos);\n    } else {\n      LOG(FATAL) << \"unknown storage scope \" << s;\n    }\n    return r;\n  }\n};\n\n/** @brief class to represent thread scope */\nstruct ThreadScope {\n  /** @brief The rank of thread scope */\n  int rank{0};\n  /** @brief the dimension index under the rank */\n  int dim_index{0};\n  /**\n   * @brief make storage scope from string\n   * @param s The string to be parsed.\n   * @return The storage scope.\n   */\n  static ThreadScope make(const std::string& s) {\n    ThreadScope r;\n    if (s == \"vthread\" || s == \"cthread\") {\n  
    // virtual thread at the same level as local\n      r.rank = 1;\n      r.dim_index = -1;\n    } else if (s.compare(0, 9, \"blockIdx.\") == 0) {\n      r.rank = 0;\n      r.dim_index = static_cast<int>(s[9] - 'x');\n    } else if (s.compare(0, 10, \"threadIdx.\") == 0) {\n      r.rank = 1;\n      r.dim_index = static_cast<int>(s[10] - 'x');\n    } else {\n      LOG(FATAL) << \"Unknown threadscope \" << s;\n    }\n    return r;\n  }\n};\n\n/** @brief workload speccification */\nstruct ThreadWorkLoad {\n  // array, first three are thread configuration.\n  size_t work_size[6];\n  /**\n   * @param i The block dimension.\n   * @return i-th block dim\n   */\n  inline size_t block_dim(size_t i) const { return work_size[i + 3]; }\n  /**\n   * @param i The grid dimension.\n   * @return i-th grid dim\n   */\n  inline size_t grid_dim(size_t i) const { return work_size[i]; }\n};\n/** @brief Thread axis configuration */\nclass ThreadAxisConfig {\n public:\n  void Init(size_t base, const std::vector<std::string>& thread_axis_tags) {\n    base_ = base;\n    std::vector<bool> filled(6, false);\n    for (size_t i = 0; i < thread_axis_tags.size(); ++i) {\n      const std::string& tag = thread_axis_tags[i];\n      ThreadScope ts = ThreadScope::make(tag);\n      arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);\n      filled[ts.rank * 3 + ts.dim_index] = true;\n    }\n    work_dim_ = 1;\n    for (int i = 0; i < 3; ++i) {\n      if (filled[i] || filled[i + 3]) {\n        work_dim_ = i + 1;\n      }\n    }\n  }\n  // extract workload from arguments.\n  ThreadWorkLoad Extract(DGLArgs x) const {\n    ThreadWorkLoad w;\n    std::fill(w.work_size, w.work_size + 6, 1);\n    for (size_t i = 0; i < arg_index_map_.size(); ++i) {\n      w.work_size[arg_index_map_[i]] =\n          static_cast<size_t>(x.values[base_ + i].v_int64);\n    }\n    return w;\n  }\n  // return the work dim\n  size_t work_dim() const { return work_dim_; }\n\n private:\n  /** @brief base axis */\n  size_t base_;\n  
/** @brief The worker dimension */\n  size_t work_dim_;\n  /** @brief The index mapping. */\n  std::vector<uint32_t> arg_index_map_;\n};\n\n}  // namespace runtime\n}  // namespace dgl\n\nnamespace std {\ntemplate <>\nstruct hash<::dgl::runtime::StorageScope> {\n  std::size_t operator()(const ::dgl::runtime::StorageScope& k) const {\n    return static_cast<size_t>(k.rank);\n  }\n};\n}  // namespace std\n#endif  // DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_\n"
  },
  {
    "path": "src/runtime/threading_backend.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file threading_backend.cc\n * @brief Native threading backend\n */\n#include <dgl/runtime/threading_backend.h>\n#include <dmlc/logging.h>\n\n#include <algorithm>\n#include <thread>\n#if defined(__linux__) || defined(__ANDROID__)\n#include <fstream>\n#else\n#endif\n#if defined(__linux__)\n#include <sched.h>\n#endif\n\nnamespace dgl {\nnamespace runtime {\nnamespace threading {\n\nclass ThreadGroup::Impl {\n public:\n  Impl(\n      int num_workers, std::function<void(int)> worker_callback,\n      bool exclude_worker0)\n      : num_workers_(num_workers) {\n    CHECK_GE(num_workers, 1)\n        << \"Requested a non-positive number of worker threads.\";\n    for (int i = exclude_worker0; i < num_workers_; ++i) {\n      threads_.emplace_back([worker_callback, i] { worker_callback(i); });\n    }\n    InitSortedOrder();\n  }\n  ~Impl() { Join(); }\n\n  void Join() {\n    for (auto &t : threads_) {\n      if (t.joinable()) t.join();\n    }\n  }\n\n  int Configure(AffinityMode mode, int nthreads, bool exclude_worker0) {\n    int num_workers_used = 0;\n    if (mode == kLittle) {\n      num_workers_used = little_count_;\n    } else if (mode == kBig) {\n      num_workers_used = big_count_;\n    } else {\n      // use default\n      num_workers_used = threading::MaxConcurrency();\n    }\n    // if a specific number was given, use that\n    if (nthreads) {\n      num_workers_used = nthreads;\n    }\n    // if MaxConcurrency restricted the number of workers (e.g., due to\n    // hyperthreading), respect the restriction. 
On CPUs with N logical cores\n    // and N/2 physical cores this will set affinity to the first N/2 logical\n    // ones.\n    num_workers_used = std::min(num_workers_, num_workers_used);\n\n    const char *val = getenv(\"DGL_BIND_THREADS\");\n    if (val == nullptr || atoi(val) == 1) {\n      // Do not set affinity if there are more workers than found cores\n      if (sorted_order_.size() >= static_cast<unsigned int>(num_workers_)) {\n        SetAffinity(exclude_worker0, mode == kLittle);\n      } else {\n        LOG(WARNING)\n            << \"The thread affinity cannot be set when the number of workers\"\n            << \"is larger than the number of available cores in the system.\";\n      }\n    }\n    return num_workers_used;\n  }\n\n private:\n  // bind worker threads to disjoint cores\n  // if worker 0 is offloaded to master, i.e. exclude_worker0 is true,\n  // the master thread is bound to core 0.\n  void SetAffinity(bool exclude_worker0, bool reverse = false) {\n#if defined(__ANDROID__)\n#ifndef CPU_SET\n#define CPU_SETSIZE 1024\n#define __NCPUBITS (8 * sizeof(uint64_t))\n    typedef struct {\n      uint64_t __bits[CPU_SETSIZE / __NCPUBITS];\n    } cpu_set_t;\n\n#define CPU_SET(cpu, cpusetp) \\\n  ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS)))\n#define CPU_ZERO(cpusetp) memset((cpusetp), 0, sizeof(cpu_set_t))\n#endif\n#endif\n#if defined(__linux__) || defined(__ANDROID__)\n    CHECK_GE(sorted_order_.size(), num_workers_);\n\n    for (unsigned i = 0; i < threads_.size(); ++i) {\n      unsigned core_id;\n      if (reverse) {\n        core_id =\n            sorted_order_[sorted_order_.size() - (i + exclude_worker0) - 1];\n      } else {\n        core_id = sorted_order_[i + exclude_worker0];\n      }\n      cpu_set_t cpuset;\n      CPU_ZERO(&cpuset);\n      CPU_SET(core_id, &cpuset);\n#if defined(__ANDROID__)\n      sched_setaffinity(\n          threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset);\n#else\n      
pthread_setaffinity_np(\n          threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset);\n#endif\n    }\n    if (exclude_worker0) {  // bind the master thread to core 0\n      cpu_set_t cpuset;\n      CPU_ZERO(&cpuset);\n      if (reverse) {\n        CPU_SET(sorted_order_[sorted_order_.size() - 1], &cpuset);\n      } else {\n        CPU_SET(sorted_order_[0], &cpuset);\n      }\n#if defined(__ANDROID__)\n      sched_setaffinity(pthread_self(), sizeof(cpu_set_t), &cpuset);\n#else\n      pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);\n#endif\n    }\n#endif\n  }\n\n  void InitSortedOrder() {\n    unsigned int threads = std::thread::hardware_concurrency();\n    std::vector<std::pair<unsigned int, int64_t> > max_freqs;\n\n    for (unsigned int i = 0; i < threads; ++i) {\n      int64_t cur_freq = 0;\n#if defined(__linux__) || defined(__ANDROID__)\n      std::ostringstream filepath;\n      filepath << \"/sys/devices/system/cpu/cpu\" << i\n               << \"/cpufreq/cpuinfo_max_freq\";\n      std::ifstream ifs(filepath.str());\n      if (!ifs.fail()) {\n        if (!(ifs >> cur_freq)) {\n          cur_freq = -1;\n        }\n        ifs.close();\n      }\n#endif\n      max_freqs.push_back(std::make_pair(i, cur_freq));\n    }\n\n    auto fcmpbyfreq = [](const std::pair<unsigned int, int64_t> &a,\n                         const std::pair<unsigned int, int64_t> &b) {\n      return a.second == b.second ? 
a.first < b.first : a.second > b.second;\n    };\n    std::sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq);\n    int64_t big_freq = max_freqs.begin()->second;\n    int64_t little_freq = max_freqs.rbegin()->second;\n    for (auto it = max_freqs.begin(); it != max_freqs.end(); it++) {\n      sorted_order_.push_back(it->first);\n      if (big_freq == it->second) {\n        big_count_++;\n      }\n      if (big_freq != little_freq && little_freq == it->second) {\n        little_count_++;\n      }\n    }\n    if (big_count_ + little_count_ != static_cast<int>(sorted_order_.size())) {\n      LOG(WARNING) << \"more than two frequencies detected!\";\n    }\n  }\n\n  int num_workers_;\n  std::vector<std::thread> threads_;\n  std::vector<unsigned int> sorted_order_;\n  int big_count_ = 0;\n  int little_count_ = 0;\n};\n\nThreadGroup::ThreadGroup(\n    int num_workers, std::function<void(int)> worker_callback,\n    bool exclude_worker0)\n    : impl_(new ThreadGroup::Impl(\n          num_workers, worker_callback, exclude_worker0)) {}\nThreadGroup::~ThreadGroup() { delete impl_; }\nvoid ThreadGroup::Join() { impl_->Join(); }\n\nint ThreadGroup::Configure(\n    AffinityMode mode, int nthreads, bool exclude_worker0) {\n  return impl_->Configure(mode, nthreads, exclude_worker0);\n}\n\nvoid YieldThread() { std::this_thread::yield(); }\n\nint MaxConcurrency() {\n  int max_concurrency = 1;\n  const char *val = getenv(\"DGL_NUM_THREADS\");\n  if (val == nullptr) {\n    val = getenv(\"OMP_NUM_THREADS\");\n  }\n  if (val != nullptr) {\n    max_concurrency = atoi(val);\n  } else {\n    max_concurrency = std::thread::hardware_concurrency();\n#if defined(_M_X64) || defined(__x86_64__)\n    max_concurrency /= 2;  // ignore hyper-threading\n#endif\n  }\n  return std::max(max_concurrency, 1);\n}\n\n}  // namespace threading\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/utils.cc",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file utils.cc\n * @brief DGL util functions\n */\n\n#include <dgl/aten/coo.h>\n#include <dgl/packed_func_ext.h>\n#include <dmlc/omp.h>\n\n#include <utility>\n\n#include \"../array/array_op.h\"\n#include \"../c_api_common.h\"\n\nusing namespace dgl::runtime;\nusing namespace dgl::aten::impl;\n\nnamespace dgl {\n\nDGL_REGISTER_GLOBAL(\"utils.internal._CAPI_DGLSetOMPThreads\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      int num_threads = args[0];\n      omp_set_num_threads(num_threads);\n    });\n\nDGL_REGISTER_GLOBAL(\"utils.internal._CAPI_DGLGetOMPThreads\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      *rv = omp_get_max_threads();\n    });\n\nDGL_REGISTER_GLOBAL(\"utils.checks._CAPI_DGLCOOIsSorted\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      IdArray src = args[0];\n      IdArray dst = args[1];\n      int64_t num_src = args[2];\n      int64_t num_dst = args[3];\n\n      bool row_sorted, col_sorted;\n      std::tie(row_sorted, col_sorted) =\n          COOIsSorted(aten::COOMatrix(num_src, num_dst, src, dst));\n\n      // make sure col_sorted is only true when row_sorted is true\n      assert(!(!row_sorted && col_sorted));\n\n      // 0 for unosrted, 1 for row sorted, 2 for row and col sorted\n      int64_t sorted_status = row_sorted + col_sorted;\n      *rv = sorted_status;\n    });\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/workspace.h",
    "content": "/**\n *  Copyright (c) 2021 by Contributors\n * @file ndarray_partition.h\n * @brief Operations on partition implemented in CUDA.\n */\n\n#ifndef DGL_RUNTIME_WORKSPACE_H_\n#define DGL_RUNTIME_WORKSPACE_H_\n\n#include <dgl/runtime/device_api.h>\n\n#include <cassert>\n\nnamespace dgl {\nnamespace runtime {\n\ntemplate <typename T>\nclass Workspace {\n public:\n  Workspace(DeviceAPI* device, DGLContext ctx, const size_t size)\n      : device_(device),\n        ctx_(ctx),\n        size_(size * sizeof(T)),\n        ptr_(static_cast<T*>(device_->AllocWorkspace(ctx_, size_))) {}\n\n  ~Workspace() {\n    if (*this) {\n      free();\n    }\n  }\n\n  operator bool() const { return ptr_ != nullptr; }\n\n  T* get() {\n    assert(size_ == 0 || *this);\n    return ptr_;\n  }\n\n  T const* get() const {\n    assert(size_ == 0 || *this);\n    return ptr_;\n  }\n\n  void free() {\n    assert(size_ == 0 || *this);\n    device_->FreeWorkspace(ctx_, ptr_);\n    ptr_ = nullptr;\n  }\n\n private:\n  DeviceAPI* device_;\n  DGLContext ctx_;\n  size_t size_;\n  T* ptr_;\n};\n\ntemplate <>\nclass Workspace<void> {\n public:\n  Workspace(DeviceAPI* device, DGLContext ctx, const size_t size)\n      : device_(device),\n        ctx_(ctx),\n        size_(size),\n        ptr_(static_cast<void*>(device_->AllocWorkspace(ctx_, size_))) {}\n\n  ~Workspace() {\n    if (*this) {\n      free();\n    }\n  }\n\n  operator bool() const { return ptr_ != nullptr; }\n\n  void* get() {\n    assert(size_ == 0 || *this);\n    return ptr_;\n  }\n\n  void const* get() const {\n    assert(size_ == 0 || *this);\n    return ptr_;\n  }\n\n  void free() {\n    assert(size_ == 0 || *this);\n    device_->FreeWorkspace(ctx_, ptr_);\n    ptr_ = nullptr;\n  }\n\n private:\n  DeviceAPI* device_;\n  DGLContext ctx_;\n  size_t size_;\n  void* ptr_;\n};\n\n}  // namespace runtime\n}  // namespace dgl\n\n#endif  // DGL_RUNTIME_WORKSPACE_H_\n"
  },
  {
    "path": "src/runtime/workspace_pool.cc",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file workspace_pool.h\n * @brief Workspace pool utility.\n */\n#include \"workspace_pool.h\"\n\n#include <memory>\n\nnamespace dgl {\nnamespace runtime {\n\n// page size.\nconstexpr size_t kWorkspacePageSize = 4 << 10;\n\nclass WorkspacePool::Pool {\n public:\n  // constructor\n  Pool() {\n    // safe guard header on each list.\n    Entry e;\n    e.data = nullptr;\n    e.size = 0;\n    free_list_.push_back(e);\n    allocated_.push_back(e);\n  }\n  // allocate from pool\n  void* Alloc(DGLContext ctx, DeviceAPI* device, size_t nbytes) {\n    // Allocate align to page.\n    nbytes = (nbytes + (kWorkspacePageSize - 1)) / kWorkspacePageSize *\n             kWorkspacePageSize;\n    if (nbytes == 0) nbytes = kWorkspacePageSize;\n    Entry e;\n    DGLDataType type;\n    type.code = kDGLUInt;\n    type.bits = 8;\n    type.lanes = 1;\n    if (free_list_.size() == 2) {\n      e = free_list_.back();\n      free_list_.pop_back();\n      if (e.size < nbytes) {\n        // resize the page\n        device->FreeDataSpace(ctx, e.data);\n        e.data =\n            device->AllocDataSpace(ctx, nbytes, kTempAllocaAlignment, type);\n        e.size = nbytes;\n      }\n    } else if (free_list_.size() == 1) {\n      e.data = device->AllocDataSpace(ctx, nbytes, kTempAllocaAlignment, type);\n      e.size = nbytes;\n    } else {\n      if (free_list_.back().size >= nbytes) {\n        // find smallest fit\n        auto it = free_list_.end() - 2;\n        for (; it->size >= nbytes; --it) {\n        }\n        e = *(it + 1);\n        free_list_.erase(it + 1);\n      } else {\n        // resize the page\n        e = free_list_.back();\n        free_list_.pop_back();\n        device->FreeDataSpace(ctx, e.data);\n        e.data =\n            device->AllocDataSpace(ctx, nbytes, kTempAllocaAlignment, type);\n        e.size = nbytes;\n      }\n    }\n    allocated_.push_back(e);\n    return e.data;\n  }\n  // free resource back to 
pool\n  void Free(void* data) {\n    Entry e;\n    if (allocated_.back().data == data) {\n      // quick path, last allocated.\n      e = allocated_.back();\n      allocated_.pop_back();\n    } else {\n      int index = static_cast<int>(allocated_.size()) - 2;\n      for (; index > 0 && allocated_[index].data != data; --index) {\n      }\n      CHECK_GT(index, 0) << \"trying to free things that has not been allocated\";\n      e = allocated_[index];\n      allocated_.erase(allocated_.begin() + index);\n    }\n    if (free_list_.back().size < e.size) {\n      free_list_.push_back(e);\n    } else if (free_list_.size() == 2) {\n      free_list_.push_back(free_list_.back());\n      free_list_[1] = e;\n    } else {\n      size_t i = free_list_.size() - 1;\n      free_list_.resize(free_list_.size() + 1);\n      for (; e.size < free_list_[i].size; --i) {\n        free_list_[i + 1] = free_list_[i];\n      }\n      free_list_[i + 1] = e;\n    }\n  }\n  // Release all resources\n  void Release(DGLContext ctx, DeviceAPI* device) {\n    CHECK_EQ(allocated_.size(), 1);\n    for (size_t i = 1; i < free_list_.size(); ++i) {\n      device->FreeDataSpace(ctx, free_list_[i].data);\n    }\n    free_list_.clear();\n  }\n\n private:\n  /** @brief a single entry in the pool */\n  struct Entry {\n    void* data;\n    size_t size;\n  };\n  /** @brief List of free items, sorted from small to big size */\n  std::vector<Entry> free_list_;\n  /** @brief List of allocated items */\n  std::vector<Entry> allocated_;\n};\n\nWorkspacePool::WorkspacePool(\n    DGLDeviceType device_type, std::shared_ptr<DeviceAPI> device)\n    : device_type_(device_type), device_(device) {}\n\nWorkspacePool::~WorkspacePool() {\n  /**\n   * Note that the following code will cause Segmentation fault with MXNet.\n   * Since we're phasing out MXNet, it's acceptable to keep it as it is.\n   * Commenting out the following code will cause memory leak.\n   */\n  for (size_t i = 0; i < array_.size(); ++i) {\n    if 
(array_[i] != nullptr) {\n      DGLContext ctx;\n      ctx.device_type = device_type_;\n      ctx.device_id = static_cast<int>(i);\n      array_[i]->Release(ctx, device_.get());\n      delete array_[i];\n    }\n  }\n}\n\nvoid* WorkspacePool::AllocWorkspace(DGLContext ctx, size_t size) {\n  if (static_cast<size_t>(ctx.device_id) >= array_.size()) {\n    array_.resize(ctx.device_id + 1, nullptr);\n  }\n  if (array_[ctx.device_id] == nullptr) {\n    array_[ctx.device_id] = new Pool();\n  }\n  return array_[ctx.device_id]->Alloc(ctx, device_.get(), size);\n}\n\nvoid WorkspacePool::FreeWorkspace(DGLContext ctx, void* ptr) {\n  CHECK(\n      static_cast<size_t>(ctx.device_id) < array_.size() &&\n      array_[ctx.device_id] != nullptr);\n  array_[ctx.device_id]->Free(ptr);\n}\n\n}  // namespace runtime\n}  // namespace dgl\n"
  },
  {
    "path": "src/runtime/workspace_pool.h",
    "content": "/**\n *  Copyright (c) 2017 by Contributors\n * @file workspace_pool.h\n * @brief Workspace pool utility.\n */\n#ifndef DGL_RUNTIME_WORKSPACE_POOL_H_\n#define DGL_RUNTIME_WORKSPACE_POOL_H_\n\n#include <dgl/runtime/device_api.h>\n\n#include <memory>\n#include <vector>\n\nnamespace dgl {\nnamespace runtime {\n/**\n * @brief A workspace pool to manage\n *\n *  \\note We have the following assumption about backend temporal\n *   workspace allocation, and will optimize for such assumption,\n *   some of these assumptions can be enforced by the compiler.\n *\n *  - Only a few allocation will happen, and space will be released after use.\n *  - The release order is usually in reverse order of allocate\n *  - Repeative pattern of same allocations over different runs.\n */\nclass WorkspacePool {\n public:\n  /**\n   * @brief Create pool with specific device type and device.\n   * @param device_type The device type.\n   * @param device The device API.\n   */\n  WorkspacePool(DGLDeviceType device_type, std::shared_ptr<DeviceAPI> device);\n  /** @brief destructor */\n  ~WorkspacePool();\n  /**\n   * @brief Allocate temporal workspace.\n   * @param ctx The context of allocation.\n   * @param size The size to be allocated.\n   */\n  void* AllocWorkspace(DGLContext ctx, size_t size);\n  /**\n   * @brief Free temporal workspace in backend execution.\n   *\n   * @param ctx The context of allocation.\n   * @param ptr The pointer to be freed.\n   */\n  void FreeWorkspace(DGLContext ctx, void* ptr);\n\n private:\n  class Pool;\n  /** @brief pool of device local array */\n  std::vector<Pool*> array_;\n  /** @brief device type this pool support */\n  DGLDeviceType device_type_;\n  /** @brief The device API */\n  std::shared_ptr<DeviceAPI> device_;\n};\n\n}  // namespace runtime\n}  // namespace dgl\n#endif  // DGL_RUNTIME_WORKSPACE_POOL_H_\n"
  },
  {
    "path": "src/scheduler/scheduler.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file scheduler/scheduler.cc\n * @brief DGL Scheduler implementation\n */\n#include <dgl/scheduler.h>\n\n#include <unordered_map>\n#include <vector>\n\nnamespace dgl {\nnamespace sched {\n\ntemplate <class IdType>\nstd::vector<IdArray> DegreeBucketing(\n    const IdArray& msg_ids, const IdArray& vids, const IdArray& recv_ids) {\n  auto n_msgs = msg_ids->shape[0];\n\n  const IdType* vid_data = static_cast<IdType*>(vids->data);\n  const IdType* msg_id_data = static_cast<IdType*>(msg_ids->data);\n  const IdType* recv_id_data = static_cast<IdType*>(recv_ids->data);\n\n  // in edge: dst->msgs\n  std::unordered_map<IdType, std::vector<IdType>> in_edges;\n  for (IdType i = 0; i < n_msgs; ++i) {\n    in_edges[vid_data[i]].push_back(msg_id_data[i]);\n  }\n\n  // bkt: deg->dsts\n  std::unordered_map<IdType, std::vector<IdType>> bkt;\n  for (const auto& it : in_edges) {\n    bkt[it.second.size()].push_back(it.first);\n  }\n\n  std::unordered_set<IdType> zero_deg_nodes;\n  for (IdType i = 0; i < recv_ids->shape[0]; ++i) {\n    if (in_edges.find(recv_id_data[i]) == in_edges.end()) {\n      zero_deg_nodes.insert(recv_id_data[i]);\n    }\n  }\n  auto n_zero_deg = zero_deg_nodes.size();\n\n  // calc output size\n  IdType n_deg = bkt.size();\n  IdType n_dst = in_edges.size();\n  IdType n_mid_sec = bkt.size();  // zero deg won't affect message size\n  if (n_zero_deg > 0) {\n    n_deg += 1;\n    n_dst += n_zero_deg;\n  }\n\n  // initialize output\n  IdArray degs = IdArray::Empty({n_deg}, vids->dtype, vids->ctx);\n  IdArray nids = IdArray::Empty({n_dst}, vids->dtype, vids->ctx);\n  IdArray nid_section = IdArray::Empty({n_deg}, vids->dtype, vids->ctx);\n  IdArray mids = IdArray::Empty({n_msgs}, vids->dtype, vids->ctx);\n  IdArray mid_section = IdArray::Empty({n_mid_sec}, vids->dtype, vids->ctx);\n  IdType* deg_ptr = static_cast<IdType*>(degs->data);\n  IdType* nid_ptr = static_cast<IdType*>(nids->data);\n  IdType* nsec_ptr 
= static_cast<IdType*>(nid_section->data);\n  IdType* mid_ptr = static_cast<IdType*>(mids->data);\n  IdType* msec_ptr = static_cast<IdType*>(mid_section->data);\n\n  // fill in bucketing ordering\n  for (const auto& it : bkt) {  // for each bucket\n    const IdType deg = it.first;\n    const IdType bucket_size = it.second.size();\n    *deg_ptr++ = deg;\n    *nsec_ptr++ = bucket_size;\n    *msec_ptr++ = deg * bucket_size;\n    for (const auto dst : it.second) {  // for each dst in this bucket\n      *nid_ptr++ = dst;\n      for (const auto mid : in_edges[dst]) {  // for each in edge of dst\n        *mid_ptr++ = mid;\n      }\n    }\n  }\n\n  if (n_zero_deg > 0) {\n    *deg_ptr = 0;\n    *nsec_ptr = n_zero_deg;\n    for (const auto dst : zero_deg_nodes) {\n      *nid_ptr++ = dst;\n    }\n  }\n\n  std::vector<IdArray> ret;\n  ret.push_back(std::move(degs));\n  ret.push_back(std::move(nids));\n  ret.push_back(std::move(nid_section));\n  ret.push_back(std::move(mids));\n  ret.push_back(std::move(mid_section));\n\n  return ret;\n}\n\ntemplate std::vector<IdArray> DegreeBucketing<int32_t>(\n    const IdArray& msg_ids, const IdArray& vids, const IdArray& recv_ids);\n\ntemplate std::vector<IdArray> DegreeBucketing<int64_t>(\n    const IdArray& msg_ids, const IdArray& vids, const IdArray& recv_ids);\n\ntemplate <class IdType>\nstd::vector<IdArray> GroupEdgeByNodeDegree(\n    const IdArray& uids, const IdArray& vids, const IdArray& eids) {\n  auto n_edge = eids->shape[0];\n  const IdType* eid_data = static_cast<IdType*>(eids->data);\n  const IdType* uid_data = static_cast<IdType*>(uids->data);\n  const IdType* vid_data = static_cast<IdType*>(vids->data);\n\n  // node2edge: group_by nodes uid -> (eid, the other end vid)\n  std::unordered_map<IdType, std::vector<std::pair<IdType, IdType>>> node2edge;\n  for (IdType i = 0; i < n_edge; ++i) {\n    node2edge[uid_data[i]].emplace_back(eid_data[i], vid_data[i]);\n  }\n\n  // bkt: deg -> group_by node uid\n  
std::unordered_map<IdType, std::vector<IdType>> bkt;\n  for (const auto& it : node2edge) {\n    bkt[it.second.size()].push_back(it.first);\n  }\n\n  // number of unique degree\n  IdType n_deg = bkt.size();\n\n  // initialize output\n  IdArray degs = IdArray::Empty({n_deg}, eids->dtype, eids->ctx);\n  IdArray new_uids = IdArray::Empty({n_edge}, uids->dtype, uids->ctx);\n  IdArray new_vids = IdArray::Empty({n_edge}, vids->dtype, vids->ctx);\n  IdArray new_eids = IdArray::Empty({n_edge}, eids->dtype, eids->ctx);\n  IdArray sections = IdArray::Empty({n_deg}, eids->dtype, eids->ctx);\n  IdType* deg_ptr = static_cast<IdType*>(degs->data);\n  IdType* uid_ptr = static_cast<IdType*>(new_uids->data);\n  IdType* vid_ptr = static_cast<IdType*>(new_vids->data);\n  IdType* eid_ptr = static_cast<IdType*>(new_eids->data);\n  IdType* sec_ptr = static_cast<IdType*>(sections->data);\n\n  // fill in bucketing ordering\n  for (const auto& it : bkt) {  // for each bucket\n    // degree of this bucket\n    const IdType deg = it.first;\n    // number of edges in this bucket\n    const IdType bucket_size = it.second.size();\n    *deg_ptr++ = deg;\n    *sec_ptr++ = deg * bucket_size;\n    for (const auto u : it.second) {           // for uid in this bucket\n      for (const auto& pair : node2edge[u]) {  // for each edge of uid\n        *uid_ptr++ = u;\n        *vid_ptr++ = pair.second;\n        *eid_ptr++ = pair.first;\n      }\n    }\n  }\n\n  std::vector<IdArray> ret;\n  ret.push_back(std::move(degs));\n  ret.push_back(std::move(new_uids));\n  ret.push_back(std::move(new_vids));\n  ret.push_back(std::move(new_eids));\n  ret.push_back(std::move(sections));\n\n  return ret;\n}\n\ntemplate std::vector<IdArray> GroupEdgeByNodeDegree<int32_t>(\n    const IdArray& uids, const IdArray& vids, const IdArray& eids);\n\ntemplate std::vector<IdArray> GroupEdgeByNodeDegree<int64_t>(\n    const IdArray& uids, const IdArray& vids, const IdArray& eids);\n\n}  // namespace sched\n\n}  // namespace dgl\n"
  },
  {
    "path": "src/scheduler/scheduler_apis.cc",
    "content": "/**\n *  Copyright (c) 2018 by Contributors\n * @file scheduler/scheduler_apis.cc\n * @brief DGL scheduler APIs\n */\n#include <dgl/array.h>\n#include <dgl/graph.h>\n#include <dgl/scheduler.h>\n\n#include \"../array/cpu/array_utils.h\"\n#include \"../c_api_common.h\"\n\nusing dgl::runtime::DGLArgs;\nusing dgl::runtime::DGLRetValue;\nusing dgl::runtime::NDArray;\n\nnamespace dgl {\n\nDGL_REGISTER_GLOBAL(\n    \"_deprecate.runtime.degree_bucketing._CAPI_DGLDegreeBucketing\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const IdArray msg_ids = args[0];\n      const IdArray vids = args[1];\n      const IdArray nids = args[2];\n      CHECK_SAME_DTYPE(msg_ids, vids);\n      CHECK_SAME_DTYPE(msg_ids, nids);\n      ATEN_ID_TYPE_SWITCH(msg_ids->dtype, IdType, {\n        *rv = ConvertNDArrayVectorToPackedFunc(\n            sched::DegreeBucketing<IdType>(msg_ids, vids, nids));\n      });\n    });\n\nDGL_REGISTER_GLOBAL(\n    \"_deprecate.runtime.degree_bucketing._CAPI_DGLGroupEdgeByNodeDegree\")\n    .set_body([](DGLArgs args, DGLRetValue* rv) {\n      const IdArray uids = args[0];\n      const IdArray vids = args[1];\n      const IdArray eids = args[2];\n      CHECK_SAME_DTYPE(uids, vids);\n      CHECK_SAME_DTYPE(uids, eids);\n      ATEN_ID_TYPE_SWITCH(uids->dtype, IdType, {\n        *rv = ConvertNDArrayVectorToPackedFunc(\n            sched::GroupEdgeByNodeDegree<IdType>(uids, vids, eids));\n      });\n    });\n\n}  // namespace dgl\n"
  },
  {
    "path": "tensoradapter/include/tensoradapter.h",
    "content": "/**\n *  Copyright (c) 2020-2022 by Contributors\n * @file tensoradapter.h\n * @brief Header file for functions exposed by the adapter library.\n *\n * Functions in this library must be exported with extern \"C\" so that DGL can\n * locate them with dlsym(3) (or GetProcAddress on Windows).\n */\n\n#ifndef TENSORADAPTER_H_\n#define TENSORADAPTER_H_\n\n#ifdef DGL_USE_CUDA\n#include <cuda_runtime.h>\n#endif  // DGL_USE_CUDA\n\nnamespace tensoradapter {\n\nextern \"C\" {\n\n/**\n * @brief Allocate a piece of CPU memory via\n * PyTorch's CPUAllocator\n *\n * @param nbytes The size to be allocated.\n * @return Pointer to the allocated memory.\n */\nvoid* CPURawAlloc(size_t nbytes);\n\n/**\n * @brief Free the CPU memory.\n *\n * @param ptr Pointer to the memory to be freed.\n */\nvoid CPURawDelete(void* ptr);\n\n#ifdef DGL_USE_CUDA\n/**\n * @brief Allocate a piece of GPU memory via\n * PyTorch's THCCachingAllocator.\n *\n * @param nbytes The size to be allocated.\n * @param stream The stream to be allocated on.\n * @return Pointer to the allocated memory.\n */\nvoid* CUDARawAlloc(size_t nbytes, cudaStream_t stream);\n\n/**\n * @brief Free the GPU memory.\n *\n * @param ptr Pointer to the memory to be freed.\n */\nvoid CUDARawDelete(void* ptr);\n\n/**\n * @brief Get the current CUDA stream.\n */\ncudaStream_t CUDACurrentStream();\n\n/**\n * @brief Let the caching allocator know which streams are using this tensor.\n *\n * @param ptr Pointer of the tensor to be recorded.\n * @param stream The stream that is using this tensor.\n * @param device_id Device of the tensor.\n */\nvoid RecordStream(void* ptr, cudaStream_t stream, int device_id);\n\n/**\n * @brief Allocate a piece of pinned CPU memory via\n *     PyTorch's CachingHostAllocator.\n *\n * @param nbytes The size to be allocated.\n * @param ctx Pointer to the PyTorch storage ctx ptr returned from the\n *     allocator.\n * @param deleter Pointer to the delete function ptr returned from the\n *     
allocator.\n * @return Raw pointer to the allocated memory.\n */\nvoid* CUDARawHostAlloc(size_t nbytes, void** ctx, void** raw_deleter);\n\n/**\n * @brief 'Free' the pinned CPU memory via\n *     inserting the memory block back to the free list.\n *\n * @param deleter Pointer to the delete function ptr returned from the\n *     allocator.\n */\nvoid CUDARawHostDelete(void** raw_deleter);\n\n/**\n * @brief 'Record' a CUDA stream (usually from a copy kernel) for the pinned\n *     memory via PyTorch's CachingHostAllocator.\n *\n * @param data Pointer of the tensor to be recorded.\n * @param ctx PyTorch storage ctx ptr returned from the allocator.\n * @param stream The stream that currently consumes this tensor.\n * @param device_id Device of the tensor.\n */\nvoid CUDARecordHostAlloc(\n    void* data, void* ctx, cudaStream_t stream, int device_id);\n\n/**\n * @brief Release cached pinned memory allocations via cudaHostFree.\n */\nvoid CUDAHostAllocatorEmptyCache();\n\n#endif  // DGL_USE_CUDA\n}\n\n};  // namespace tensoradapter\n\n#endif  // TENSORADAPTER_H_\n"
  },
  {
    "path": "tensoradapter/include/tensoradapter_exports.h",
    "content": "/**\n *  Copyright (c) 2020 by Contributors\n * @file tensoradapter_exports.h\n * @brief Header file for functions exposed by the adapter library.\n */\n\n#ifndef TENSORADAPTER_EXPORTS_H_\n#define TENSORADAPTER_EXPORTS_H_\n\n#if defined(WIN32) || defined(_WIN32)\n#define TA_EXPORTS __declspec(dllexport)\n#else\n#define TA_EXPORTS\n#endif\n\n#endif  // TENSORADAPTER_EXPORTS_H_\n"
  },
  {
    "path": "tensoradapter/pytorch/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.5)\nproject(tensoradapter_pytorch C CXX)\n\n# Find PyTorch cmake files and PyTorch versions with the python interpreter $PYTHON_INTERP\n# (\"python3\" or \"python\" if empty)\nif(NOT PYTHON_INTERP)\n  find_program(PYTHON_INTERP NAMES python3 python)\nendif()\nmessage(STATUS \"Using Python interpreter: ${PYTHON_INTERP}\")\nfile(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/find_cmake.py FIND_CMAKE_PY)\nexecute_process(\n  COMMAND ${PYTHON_INTERP} ${FIND_CMAKE_PY}\n  OUTPUT_VARIABLE TORCH_PREFIX_VER\n  OUTPUT_STRIP_TRAILING_WHITESPACE)\nmessage(STATUS \"find_cmake.py output: ${TORCH_PREFIX_VER}\")\nlist(GET TORCH_PREFIX_VER 0 TORCH_PREFIX)\nlist(GET TORCH_PREFIX_VER 1 TORCH_VER)\nmessage(STATUS \"Configuring for PyTorch ${TORCH_VER}\")\n\nif(USE_CUDA)\n  add_definitions(-DDGL_USE_CUDA)\nendif()\n\nset(Torch_DIR \"${TORCH_PREFIX}/Torch\")\nmessage(STATUS \"Setting directory to ${Torch_DIR}\")\nfind_package(Torch REQUIRED)\nset(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} ${TORCH_C_FLAGS}\")\nset(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}\")\nset(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb\")\nset(TORCH_TARGET_NAME \"tensoradapter_pytorch_${TORCH_VER}\")\nfile(GLOB TA_TORCH_SRC *.cpp)\nadd_library(${TORCH_TARGET_NAME} SHARED \"${TA_TORCH_SRC}\")\n\n# use the library name rather than the path\nset(TENSORADAPTER_TORCH_LIBS torch)\n\nmessage(STATUS \"tensoradapter found PyTorch includes: ${TORCH_INCLUDE_DIRS}\")\nmessage(STATUS \"tensoradapter found PyTorch lib: ${TENSORADAPTER_TORCH_LIBS}\")\n\ntarget_include_directories(\n  ${TORCH_TARGET_NAME} PRIVATE \"${CMAKE_CURRENT_SOURCE_DIR}/../include\")\ntarget_include_directories(\n  ${TORCH_TARGET_NAME} PRIVATE \"${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/dlpack/include\")\ntarget_include_directories(\n  ${TORCH_TARGET_NAME} PRIVATE \"${TORCH_INCLUDE_DIRS}\")\ntarget_link_libraries(${TORCH_TARGET_NAME} PRIVATE 
\"${TENSORADAPTER_TORCH_LIBS}\")\nset_property(TARGET ${TORCH_TARGET_NAME} PROPERTY CXX_STANDARD 17)\nmessage(STATUS \"Configured target ${TORCH_TARGET_NAME}\")\n"
  },
  {
    "path": "tensoradapter/pytorch/build.bat",
    "content": "REM Helper script to build tensor adapter libraries for PyTorch\n@ECHO OFF\nSETLOCAL EnableDelayedExpansion\n\nMD \"%BINDIR%\\tensoradapter\\pytorch\"\nDEL /S /Q build\nMD build\nPUSHD build\n\nIF x%1x == xx GOTO single\n\nFOR %%X IN (%*) DO (\n\tDEL /S /Q *\n\t\"%CMAKE_COMMAND%\" -DCMAKE_CONFIGURATION_TYPES=Release -DCUDA_TOOLKIT_ROOT_DIR=\"%CUDA_TOOLKIT_ROOT_DIR%\" -DTORCH_CUDA_ARCH_LIST=%TORCH_CUDA_ARCH_LIST% -DUSE_CUDA=%USE_CUDA% -DPYTHON_INTERP=%%X .. -G \"Visual Studio 16 2019\" || EXIT /B 1\n\tmsbuild tensoradapter_pytorch.sln /m /nr:false || EXIT /B 1\n\tCOPY /Y Release\\*.dll \"%BINDIR%\\tensoradapter\\pytorch\" || EXIT /B 1\n)\n\nGOTO end\n\n:single\n\nDEL /S /Q *\n\"%CMAKE_COMMAND%\" -DCMAKE_CONFIGURATION_TYPES=Release -DCUDA_TOOLKIT_ROOT_DIR=\"%CUDA_TOOLKIT_ROOT_DIR%\" -DTORCH_CUDA_ARCH_LIST=%TORCH_CUDA_ARCH_LIST% -DUSE_CUDA=%USE_CUDA% .. -G \"Visual Studio 16 2019\" || EXIT /B 1\nmsbuild tensoradapter_pytorch.sln /m /nr:false || EXIT /B 1\nCOPY /Y Release\\*.dll \"%BINDIR%\\tensoradapter\\pytorch\" || EXIT /B 1\n\n:end\nPOPD\n\nENDLOCAL\n"
  },
  {
    "path": "tensoradapter/pytorch/build.sh",
    "content": "#!/bin/bash\n# Helper script to build tensor adapter libraries for PyTorch\nset -e\n\nmkdir -p build\nmkdir -p $BINDIR/tensoradapter/pytorch\ncd build\n\nif [ $(uname) = 'Darwin' ]; then\n\tCPSOURCE=*.dylib\nelse\n\tCPSOURCE=*.so\nfi\n\nCMAKE_FLAGS=\"-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST -DUSE_CUDA=$USE_CUDA\"\n\nif [ $# -eq 0 ]; then\n\t$CMAKE_COMMAND $CMAKE_FLAGS ..\n\tmake -j\n\tcp -v $CPSOURCE $BINDIR/tensoradapter/pytorch\nelse\n\tfor PYTHON_INTERP in $@; do\n\t\tTORCH_VER=$($PYTHON_INTERP -c 'import torch; print(torch.__version__.split(\"+\")[0])')\n\t\tmkdir -p $TORCH_VER\n\t\tcd $TORCH_VER\n\t\t$CMAKE_COMMAND $CMAKE_FLAGS -DPYTHON_INTERP=$PYTHON_INTERP ../..\n\t\tmake -j\n\t\tcp -v $CPSOURCE $BINDIR/tensoradapter/pytorch\n\t\tcd ..\n\tdone\nfi\n"
  },
  {
    "path": "tensoradapter/pytorch/find_cmake.py",
    "content": "import os\n\nimport torch\n\ncmake_prefix_path = getattr(\n    torch.utils,\n    \"cmake_prefix_path\",\n    os.path.join(os.path.dirname(torch.__file__), \"share\", \"cmake\"),\n)\nversion = torch.__version__.split(\"+\")[0]\nprint(\";\".join([cmake_prefix_path, version]))\n"
  },
  {
    "path": "tensoradapter/pytorch/torch.cpp",
    "content": "/**\n *  Copyright (c) 2020-2022 by Contributors\n * @file torch/torch.cpp\n * @brief Implementation of PyTorch adapter library.\n */\n\n#include <c10/core/CPUAllocator.h>\n#include <tensoradapter_exports.h>\n#ifdef DGL_USE_CUDA\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/CachingHostAllocator.h>\n#include <c10/cuda/CUDACachingAllocator.h>\n#include <c10/cuda/CUDAStream.h>\n#include <cuda_runtime.h>\n#endif  // DGL_USE_CUDA\n\nnamespace tensoradapter {\n\nextern \"C\" {\n\nTA_EXPORTS void* CPURawAlloc(size_t nbytes) {\n  return c10::GetCPUAllocator()->raw_allocate(nbytes);\n}\n\nTA_EXPORTS void CPURawDelete(void* ptr) {\n  c10::GetCPUAllocator()->raw_deallocate(ptr);\n}\n\n#ifdef DGL_USE_CUDA\nTA_EXPORTS void* CUDARawAlloc(size_t nbytes, cudaStream_t stream) {\n  at::globalContext().lazyInitDevice(at::kCUDA);\n  return c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(nbytes, stream);\n}\n\nTA_EXPORTS void CUDARawDelete(void* ptr) {\n  c10::cuda::CUDACachingAllocator::raw_delete(ptr);\n}\n\nTA_EXPORTS cudaStream_t CUDACurrentStream() {\n  return at::cuda::getCurrentCUDAStream();\n}\n\nTA_EXPORTS void RecordStream(void* ptr, cudaStream_t stream, int device_id) {\n  c10::DataPtr data_ptr{\n      ptr, ptr, c10::cuda::CUDACachingAllocator::get()->raw_deleter(),\n      c10::Device(c10::DeviceType::CUDA, device_id)};\n  c10::cuda::CUDACachingAllocator::recordStream(\n      data_ptr,\n      // getStreamFromExternal doesn't exist before PyTorch 1.10, just copy it\n      // here\n      c10::cuda::CUDAStream(\n          c10::cuda::CUDAStream::UNCHECKED,\n          c10::Stream(\n              c10::Stream::UNSAFE,\n              c10::Device(c10::DeviceType::CUDA, device_id),\n              reinterpret_cast<int64_t>(stream))));\n  data_ptr.release_context();\n}\n\nclass CUDAHostDeleter {\n public:\n  explicit CUDAHostDeleter(std::unique_ptr<void, c10::DeleterFnPtr> ptr)\n      : ptr_(std::move(ptr)) {}\n\n private:\n  std::unique_ptr<void, 
c10::DeleterFnPtr> ptr_;\n};\n\nTA_EXPORTS void* CUDARawHostAlloc(\n    size_t nbytes, void** ctx, void** raw_deleter) {\n  auto data_ptr = at::cuda::getCachingHostAllocator()->allocate(nbytes);\n  auto raw = data_ptr.get();\n  // Return the raw ctx ptr for recording event.\n  *ctx = data_ptr.get_context();\n\n  // Transfer ownership to raw_deleter.\n  auto* data_deleter = new CUDAHostDeleter(data_ptr.move_context());\n  *raw_deleter = static_cast<void*>(data_deleter);\n  return raw;\n}\n\n// Designated CUDAHostDeleter for CUDARawHostAlloc.\nTA_EXPORTS void CUDARawHostDelete(void** raw_deleter) {\n  delete static_cast<CUDAHostDeleter*>(*raw_deleter);\n  *raw_deleter = nullptr;\n}\n\nTA_EXPORTS void CUDARecordHostAlloc(\n    void* ptr, void* ctx, cudaStream_t stream, int device_id) {\n  at::cuda::CachingHostAllocator_recordEvent(\n      ptr, ctx,\n      c10::cuda::CUDAStream(\n          c10::cuda::CUDAStream::UNCHECKED,\n          c10::Stream(\n              c10::Stream::UNSAFE,\n              c10::Device(c10::DeviceType::CUDA, device_id),\n              reinterpret_cast<int64_t>(stream))));\n}\n\nTA_EXPORTS void CUDAHostAllocatorEmptyCache() {\n  at::cuda::CachingHostAllocator_emptyCache();\n}\n#endif  // DGL_USE_CUDA\n};\n\n};  // namespace tensoradapter\n"
  },
  {
    "path": "tests/README.md",
    "content": "Unit test\n===\n\n## Python Unittest\nThe code organization goes as follows:\n\n* `backend`: Additional unified tensor interface for supported frameworks.\n  The functions there are only used in unit tests, not DGL itself.  Note that\n  the code there is not a unit test by itself.\n* `compute`: All framework-agnostic computation-related unit tests go there.\n* `${DGLBACKEND}` (e.g. `pytorch` and `mxnet`): All framework-specific\n  computation-related unit tests go there.\n* `graph_index`: All unit tests for C++ graph structure implementation go\n  there.  The Python API being tested in this directory, if any, should be\n  as minimal as possible (usually simple wrappers of corresponding C++\n  functions).\n* `lint`: Pylint-related files.\n* `scripts`: Automated test scripts for CI.\n\n## C++ Unittest\nCompile with unittest by executing the command below\n```\n# Assume current directory is the root directory of dgl, and googletest submodule is initialized\nbash script/build_dgl.sh -c -r\n./runUnitTests\n```\n"
  },
  {
    "path": "tests/backend/__init__.py",
    "content": "import importlib\nimport os\nimport sys\n\nimport numpy as np\n\nfrom dgl.backend import *\nfrom dgl.nn import *\n\nfrom . import backend_unittest\n\nmod = importlib.import_module(\".%s\" % backend_name, __name__)\nthismod = sys.modules[__name__]\n\nfor api in backend_unittest.__dict__.keys():\n    if api.startswith(\"__\"):\n        continue\n    elif callable(mod.__dict__[api]):\n        # Tensor APIs used in unit tests MUST be supported across all backends\n        globals()[api] = mod.__dict__[api]\n\n# Tensor creation with default dtype and context\n\n_zeros = zeros\n_ones = ones\n_randn = randn\n_tensor = tensor\n_arange = arange\n_full = full\n_full_1d = full_1d\n_softmax = softmax\n_default_context_str = os.getenv(\"DGLTESTDEV\", \"cpu\")\n_context_dict = {\n    \"cpu\": cpu(),\n    \"gpu\": cuda(),\n}\n_default_context = _context_dict[_default_context_str]\n\n\ndef ctx():\n    return _default_context\n\n\ndef gpu_ctx():\n    return _default_context_str == \"gpu\"\n\n\ndef zeros(shape, dtype=float32, ctx=_default_context):\n    return _zeros(shape, dtype, ctx)\n\n\ndef ones(shape, dtype=float32, ctx=_default_context):\n    return _ones(shape, dtype, ctx)\n\n\ndef randn(shape):\n    return copy_to(_randn(shape), _default_context)\n\n\ndef tensor(data, dtype=None):\n    return copy_to(_tensor(data, dtype), _default_context)\n\n\ndef arange(start, stop, dtype=int64, ctx=None):\n    return _arange(\n        start, stop, dtype, ctx if ctx is not None else _default_context\n    )\n\n\ndef full(shape, fill_value, dtype, ctx=_default_context):\n    return _full(shape, fill_value, dtype, ctx)\n\n\ndef full_1d(length, fill_value, dtype, ctx=_default_context):\n    return _full_1d(length, fill_value, dtype, ctx)\n\n\ndef softmax(x, dim):\n    return _softmax(x, dim)\n"
  },
  {
    "path": "tests/backend/backend_unittest.py",
    "content": "\"\"\"This file defines the unified tensor framework interface required by DGL\nunit testing, other than the ones used in the framework itself.\n\"\"\"\n\n###############################################################################\n# Tensor, data type and context interfaces\n\n\ndef cuda():\n    \"\"\"Context object for CUDA.\"\"\"\n    pass\n\n\ndef is_cuda_available():\n    \"\"\"Check whether CUDA is available.\"\"\"\n    pass\n\n\n###############################################################################\n# Tensor functions on feature data\n# --------------------------------\n# These functions are performance critical, so it's better to have efficient\n# implementation in each framework.\n\n\ndef array_equal(a, b):\n    \"\"\"Check whether the two tensors are *exactly* equal.\"\"\"\n    pass\n\n\ndef allclose(a, b, rtol=1e-4, atol=1e-4):\n    \"\"\"Check whether the two tensors are numerically close to each other.\"\"\"\n    pass\n\n\ndef randn(shape):\n    \"\"\"Generate a tensor with elements from standard normal distribution.\"\"\"\n    pass\n\n\ndef full(shape, fill_value, dtype, ctx):\n    pass\n\n\ndef narrow_row_set(x, start, stop, new):\n    \"\"\"Set a slice of the given tensor to a new value.\"\"\"\n    pass\n\n\ndef sparse_to_numpy(x):\n    \"\"\"Convert a sparse tensor to a numpy array.\"\"\"\n    pass\n\n\ndef clone(x):\n    pass\n\n\ndef reduce_sum(x):\n    \"\"\"Sums all the elements into a single scalar.\"\"\"\n    pass\n\n\ndef softmax(x, dim):\n    \"\"\"Softmax Operation on Tensors\"\"\"\n    pass\n\n\ndef spmm(x, y):\n    \"\"\"Sparse dense matrix multiply\"\"\"\n    pass\n\n\ndef add(a, b):\n    \"\"\"Compute a + b\"\"\"\n    pass\n\n\ndef sub(a, b):\n    \"\"\"Compute a - b\"\"\"\n    pass\n\n\ndef mul(a, b):\n    \"\"\"Compute a * b\"\"\"\n    pass\n\n\ndef div(a, b):\n    \"\"\"Compute a / b\"\"\"\n    pass\n\n\ndef sum(x, dim, keepdims=False):\n    \"\"\"Computes the sum of array elements over given axes\"\"\"\n 
   pass\n\n\ndef max(x, dim):\n    \"\"\"Computes the max of array elements over given axes\"\"\"\n    pass\n\n\ndef min(x, dim):\n    \"\"\"Computes the min of array elements over given axes\"\"\"\n    pass\n\n\ndef prod(x, dim):\n    \"\"\"Computes the prod of array elements over given axes\"\"\"\n    pass\n\n\ndef matmul(a, b):\n    \"\"\"Compute Matrix Multiplication between a and b\"\"\"\n    pass\n\n\ndef dot(a, b):\n    \"\"\"Compute Dot between a and b\"\"\"\n    pass\n\n\ndef abs(a):\n    \"\"\"Compute the absolute value of a\"\"\"\n    pass\n\n\ndef seed(a):\n    \"\"\"Set seed to for random generator\"\"\"\n    pass\n\n\n###############################################################################\n# Tensor functions used *only* on index tensor\n# ----------------\n# These operators are light-weighted, so it is acceptable to fallback to\n# numpy operators if currently missing in the framework. Ideally in the future,\n# DGL should contain all the operations on index, so this set of operators\n# should be gradually removed.\n\n###############################################################################\n# Other interfaces\n# ----------------\n# These are not related to tensors. Some of them are temporary workarounds that\n# should be included in DGL in the future.\n"
  },
  {
    "path": "tests/backend/mxnet/__init__.py",
    "content": "from __future__ import absolute_import\n\nimport mxnet as mx\nimport mxnet.ndarray as nd\nimport numpy as np\n\n\ndef cuda():\n    return mx.gpu()\n\n\ndef is_cuda_available():\n    # TODO: Does MXNet have a convenient function to test GPU availability/compilation?\n    try:\n        a = nd.array([1, 2, 3], ctx=mx.gpu())\n        return True\n    except mx.MXNetError:\n        return False\n\n\ndef array_equal(a, b):\n    return nd.equal(a, b).asnumpy().all()\n\n\ndef allclose(a, b, rtol=1e-4, atol=1e-4):\n    return np.allclose(a.asnumpy(), b.asnumpy(), rtol=rtol, atol=atol)\n\n\ndef randn(shape):\n    return nd.random.randn(*shape)\n\n\ndef full(shape, fill_value, dtype, ctx):\n    return nd.full(shape, fill_value, dtype=dtype, ctx=ctx)\n\n\ndef narrow_row_set(x, start, stop, new):\n    x[start:stop] = new\n\n\ndef sparse_to_numpy(x):\n    return x.asscipy().todense().A\n\n\ndef clone(x):\n    return x.copy()\n\n\ndef reduce_sum(x):\n    return x.sum()\n\n\ndef softmax(x, dim):\n    return nd.softmax(x, axis=dim)\n\n\ndef spmm(x, y):\n    return nd.dot(x, y)\n\n\ndef add(a, b):\n    return a + b\n\n\ndef sub(a, b):\n    return a - b\n\n\ndef mul(a, b):\n    return a * b\n\n\ndef div(a, b):\n    return a / b\n\n\ndef sum(x, dim, keepdims=False):\n    return x.sum(dim, keepdims=keepdims)\n\n\ndef max(x, dim):\n    return x.max(dim)\n\n\ndef min(x, dim):\n    return x.min(dim)\n\n\ndef prod(x, dim):\n    return x.prod(dim)\n\n\ndef matmul(a, b):\n    return nd.dot(a, b)\n\n\ndef dot(a, b):\n    return nd.sum(mul(a, b), axis=-1)\n\n\ndef abs(a):\n    return nd.abs(a)\n\n\ndef seed(a):\n    return mx.random.seed(a)\n"
  },
  {
    "path": "tests/backend/pytorch/__init__.py",
    "content": "from __future__ import absolute_import\n\nimport torch as th\n\n\ndef cuda():\n    return th.device(\"cuda:0\")\n\n\ndef is_cuda_available():\n    return th.cuda.is_available()\n\n\ndef array_equal(a, b):\n    return th.equal(a.cpu(), b.cpu())\n\n\ndef allclose(a, b, rtol=1e-4, atol=1e-4):\n    return th.allclose(a.float().cpu(), b.float().cpu(), rtol=rtol, atol=atol)\n\n\ndef randn(shape):\n    return th.randn(*shape)\n\n\ndef full(shape, fill_value, dtype, ctx):\n    return th.full(shape, fill_value, dtype=dtype, device=ctx)\n\n\ndef narrow_row_set(x, start, stop, new):\n    x[start:stop] = new\n\n\ndef sparse_to_numpy(x):\n    return x.to_dense().numpy()\n\n\ndef clone(x):\n    return x.clone()\n\n\ndef reduce_sum(x):\n    return x.sum()\n\n\ndef softmax(x, dim):\n    return th.softmax(x, dim)\n\n\ndef spmm(x, y):\n    return th.spmm(x, y)\n\n\ndef add(a, b):\n    return a + b\n\n\ndef sub(a, b):\n    return a - b\n\n\ndef mul(a, b):\n    return a * b\n\n\ndef div(a, b):\n    return a / b\n\n\ndef sum(x, dim, keepdims=False):\n    return x.sum(dim, keepdims=keepdims)\n\n\ndef max(x, dim):\n    return x.max(dim)[0]\n\n\ndef min(x, dim):\n    return x.min(dim)[0]\n\n\ndef prod(x, dim):\n    return x.prod(dim)\n\n\ndef matmul(a, b):\n    return a @ b\n\n\ndef dot(a, b):\n    return sum(mul(a, b), dim=-1)\n\n\ndef abs(a):\n    return a.abs()\n\n\ndef seed(a):\n    return th.manual_seed(a)\n"
  },
  {
    "path": "tests/backend/tensorflow/__init__.py",
    "content": "from __future__ import absolute_import\n\nimport numpy as np\nimport tensorflow as tf\nfrom scipy.sparse import coo_matrix\n\n\ndef cuda():\n    return \"/gpu:0\"\n\n\ndef is_cuda_available():\n    return tf.test.is_gpu_available(cuda_only=True)\n\n\ndef array_equal(a, b):\n    return np.array_equal(a.numpy(), b.numpy())\n\n\ndef allclose(a, b, rtol=1e-4, atol=1e-4):\n    return np.allclose(\n        tf.convert_to_tensor(a).numpy(),\n        tf.convert_to_tensor(b).numpy(),\n        rtol=rtol,\n        atol=atol,\n    )\n\n\ndef randn(shape):\n    return tf.random.normal(shape)\n\n\ndef full(shape, fill_value, dtype, ctx):\n    with tf.device(ctx):\n        t = tf.constant(fill_value, shape=shape, dtype=dtype)\n    return t\n\n\ndef narrow_row_set(x, start, stop, new):\n    # x[start:stop] = new\n    raise NotImplementedError(\"TF doesn't support inplace update\")\n\n\ndef sparse_to_numpy(x):\n    # tf.sparse.to_dense assume sorted indices, need to turn off validate_indices in our cases\n    return tf.sparse.to_dense(x, validate_indices=False).numpy()\n\n\ndef clone(x):\n    return tf.identity(x)\n\n\ndef reduce_sum(x):\n    return tf.reduce_sum(x)\n\n\ndef softmax(x, dim):\n    return tf.math.softmax(x, axis=dim)\n\n\ndef spmm(x, y):\n    return tf.sparse.sparse_dense_matmul(x, y)\n\n\ndef add(a, b):\n    return a + b\n\n\ndef sub(a, b):\n    return a - b\n\n\ndef mul(a, b):\n    return a * b\n\n\ndef div(a, b):\n    return a / b\n\n\ndef sum(x, dim, keepdims=False):\n    return tf.reduce_sum(x, axis=dim, keepdims=keepdims)\n\n\ndef max(x, dim):\n    return tf.reduce_max(x, axis=dim)\n\n\ndef min(x, dim):\n    return tf.reduce_min(x, axis=dim)\n\n\ndef prod(x, dim):\n    return tf.reduce_prod(x, axis=dim)\n\n\ndef matmul(a, b):\n    return tf.linalg.matmul(a, b)\n\n\ndef dot(a, b):\n    return sum(mul(a, b), dim=-1)\n\n\ndef abs(a):\n    return tf.abs(a)\n\n\ndef seed(a):\n    return tf.random.set_seed(a)\n"
  },
  {
    "path": "tests/cpp/common.h",
    "content": "#ifndef TEST_COMMON_H_\n#define TEST_COMMON_H_\n\n#include <dgl/runtime/ndarray.h>\n\nstatic constexpr DGLContext CTX = DGLContext{kDGLCPU, 0};\nstatic constexpr DGLContext CPU = DGLContext{kDGLCPU, 0};\n#ifdef DGL_USE_CUDA\nstatic constexpr DGLContext GPU = DGLContext{kDGLCUDA, 0};\n#endif\n\ntemplate <typename T>\ninline T* Ptr(dgl::runtime::NDArray nd) {\n  return static_cast<T*>(nd->data);\n}\n\ninline int64_t* PI64(dgl::runtime::NDArray nd) {\n  return static_cast<int64_t*>(nd->data);\n}\n\ninline int32_t* PI32(dgl::runtime::NDArray nd) {\n  return static_cast<int32_t*>(nd->data);\n}\n\ninline int64_t Len(dgl::runtime::NDArray nd) { return nd->shape[0]; }\n\ntemplate <typename T>\ninline bool ArrayEQ(dgl::runtime::NDArray a1, dgl::runtime::NDArray a2) {\n  if (a1->ndim != a2->ndim) return false;\n  if (a1->dtype != a2->dtype) return false;\n  if (a1->ctx != a2->ctx) return false;\n  if (a1.NumElements() != a2.NumElements()) return false;\n  if (a1.NumElements() == 0) return true;\n  int64_t num = 1;\n  for (int i = 0; i < a1->ndim; ++i) {\n    if (a1->shape[i] != a2->shape[i]) return false;\n    num *= a1->shape[i];\n  }\n  a1 = a1.CopyTo(CPU);\n  a2 = a2.CopyTo(CPU);\n  for (int64_t i = 0; i < num; ++i)\n    if (static_cast<T*>(a1->data)[i] != static_cast<T*>(a2->data)[i])\n      return false;\n  return true;\n}\n\ntemplate <typename T>\ninline bool IsInArray(dgl::runtime::NDArray a, T x) {\n  if (!a.defined() || a->shape[0] == 0) return false;\n  for (int64_t i = 0; i < a->shape[0]; ++i) {\n    if (x == static_cast<T*>(a->data)[i]) return true;\n  }\n  return false;\n}\n\n#endif  // TEST_COMMON_H_\n"
  },
  {
    "path": "tests/cpp/graph_index_test.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file graph_index_test.cc\n * @brief Test GraphIndex\n */\n#include <dgl/graph.h>\n#include <gtest/gtest.h>\n\nTEST(GraphTest, TestNumVertices) {\n  dgl::Graph g;\n  g.AddVertices(10);\n  ASSERT_EQ(g.NumVertices(), 10);\n};\n"
  },
  {
    "path": "tests/cpp/message_queue_test.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file msg_queue.cc\n * @brief Message queue for DGL distributed training.\n */\n#include <gtest/gtest.h>\n\n#include <string>\n#include <thread>\n#include <vector>\n\n#include \"../src/rpc/network/msg_queue.h\"\n\nusing dgl::network::Message;\nusing dgl::network::MessageQueue;\nusing std::string;\n\nTEST(MessageQueueTest, AddRemove) {\n  MessageQueue queue(5, 1);  // size:5, num_of_producer:1\n  // msg 1\n  std::string str_1(\"111\");\n  Message msg_1 = {const_cast<char*>(str_1.data()), 3};\n  EXPECT_EQ(queue.Add(msg_1), ADD_SUCCESS);\n  // msg 2\n  std::string str_2(\"22\");\n  Message msg_2 = {const_cast<char*>(str_2.data()), 2};\n  EXPECT_EQ(queue.Add(msg_2), ADD_SUCCESS);\n  // msg 3\n  std::string str_3(\"xxxx\");\n  Message msg_3 = {const_cast<char*>(str_3.data()), 4};\n  EXPECT_EQ(queue.Add(msg_3, false), QUEUE_FULL);\n  // msg 4\n  Message msg_4;\n  EXPECT_EQ(queue.Remove(&msg_4), REMOVE_SUCCESS);\n  EXPECT_EQ(string(msg_4.data, msg_4.size), string(\"111\"));\n  // msg 5\n  Message msg_5;\n  EXPECT_EQ(queue.Remove(&msg_5), REMOVE_SUCCESS);\n  EXPECT_EQ(string(msg_5.data, msg_5.size), string(\"22\"));\n  // msg 6\n  std::string str_6(\"33333\");\n  Message msg_6 = {const_cast<char*>(str_6.data()), 5};\n  EXPECT_EQ(queue.Add(msg_6), ADD_SUCCESS);\n  // msg 7\n  Message msg_7;\n  EXPECT_EQ(queue.Remove(&msg_7), REMOVE_SUCCESS);\n  EXPECT_EQ(string(msg_7.data, msg_7.size), string(\"33333\"));\n  // msg 8\n  Message msg_8;\n  EXPECT_EQ(queue.Remove(&msg_8, false), QUEUE_EMPTY);  // non-blocking remove\n  // msg 9\n  std::string str_9(\"666666\");\n  Message msg_9 = {const_cast<char*>(str_9.data()), 6};\n  EXPECT_EQ(queue.Add(msg_9), MSG_GT_SIZE);  // exceed queue size\n  // msg 10\n  std::string str_10(\"55555\");\n  Message msg_10 = {const_cast<char*>(str_10.data()), 5};\n  EXPECT_EQ(queue.Add(msg_10), ADD_SUCCESS);\n  // msg 11\n  Message msg_11;\n  EXPECT_EQ(queue.Remove(&msg_11), 
REMOVE_SUCCESS);\n}\n\nTEST(MessageQueueTest, EmptyAndNoMoreAdd) {\n  MessageQueue queue(5, 2);  // size:5, num_of_producer:2\n  EXPECT_EQ(queue.EmptyAndNoMoreAdd(), false);\n  EXPECT_EQ(queue.Empty(), true);\n  queue.SignalFinished(1);\n  queue.SignalFinished(1);\n  EXPECT_EQ(queue.EmptyAndNoMoreAdd(), false);\n  queue.SignalFinished(2);\n  EXPECT_EQ(queue.EmptyAndNoMoreAdd(), true);\n}\n\nconst int kNumOfProducer = 100;\nconst int kNumOfMessage = 100;\n\nstd::string str_apple(\"apple\");\n\nvoid start_add(MessageQueue* queue, int id) {\n  for (int i = 0; i < kNumOfMessage; ++i) {\n    Message msg = {const_cast<char*>(str_apple.data()), 5};\n    EXPECT_EQ(queue->Add(msg), ADD_SUCCESS);\n  }\n  queue->SignalFinished(id);\n}\n\nTEST(MessageQueueTest, MultiThread) {\n  MessageQueue queue(100000, kNumOfProducer);\n  EXPECT_EQ(queue.EmptyAndNoMoreAdd(), false);\n  EXPECT_EQ(queue.Empty(), true);\n  std::vector<std::thread> thread_pool(kNumOfProducer);\n  for (int i = 0; i < kNumOfProducer; ++i) {\n    thread_pool[i] = std::thread(start_add, &queue, i);\n  }\n  for (int i = 0; i < kNumOfProducer * kNumOfMessage; ++i) {\n    Message msg;\n    EXPECT_EQ(queue.Remove(&msg), REMOVE_SUCCESS);\n    EXPECT_EQ(string(msg.data, msg.size), string(\"apple\"));\n  }\n  for (int i = 0; i < kNumOfProducer; ++i) {\n    thread_pool[i].join();\n  }\n  EXPECT_EQ(queue.EmptyAndNoMoreAdd(), true);\n}\n"
  },
  {
    "path": "tests/cpp/socket_communicator_test.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file socket_communicator_test.cc\n * @brief Test SocketCommunicator\n */\n#include \"../src/rpc/network/socket_communicator.h\"\n\n#include <gtest/gtest.h>\n#include <stdlib.h>\n#include <string.h>\n#include <time.h>\n\n#include <chrono>\n#include <fstream>\n#include <streambuf>\n#include <string>\n#include <thread>\n#include <vector>\n\n#include \"../src/rpc/network/msg_queue.h\"\n\nusing std::string;\n\nusing dgl::network::DefaultMessageDeleter;\nusing dgl::network::Message;\nusing dgl::network::SocketReceiver;\nusing dgl::network::SocketSender;\n\nconst int64_t kQueueSize = 500 * 1024;\nconst int kThreadNum = 2;\nconst int kMaxTryTimes = 1024;\n\n#ifndef WIN32\n\nconst int kNumSender = 3;\nconst int kNumReceiver = 3;\nconst int kNumMessage = 10;\n\nconst char* ip_addr[] = {\n    \"tcp://127.0.0.1:50091\", \"tcp://127.0.0.1:50092\", \"tcp://127.0.0.1:50093\"};\n\nstatic void start_client();\nstatic void start_server(int id);\n\nTEST(SocketCommunicatorTest, SendAndRecv) {\n  // start 10 client\n  std::vector<std::thread> client_thread(kNumSender);\n  for (int i = 0; i < kNumSender; ++i) {\n    client_thread[i] = std::thread(start_client);\n  }\n  // start 10 server\n  std::vector<std::thread> server_thread(kNumReceiver);\n  for (int i = 0; i < kNumReceiver; ++i) {\n    server_thread[i] = std::thread(start_server, i);\n  }\n  for (int i = 0; i < kNumSender; ++i) {\n    client_thread[i].join();\n  }\n  for (int i = 0; i < kNumReceiver; ++i) {\n    server_thread[i].join();\n  }\n}\n\nTEST(SocketCommunicatorTest, SendAndRecvTimeout) {\n  std::atomic_bool stop{false};\n  // start 1 client, connect to 1 server, send 2 messsage\n  auto client = std::thread([&stop]() {\n    SocketSender sender(kQueueSize, kThreadNum);\n    sender.ConnectReceiver(ip_addr[0], 0);\n    sender.ConnectReceiverFinalize(kMaxTryTimes);\n    for (int i = 0; i < 2; ++i) {\n      char* str_data = new char[9];\n      memcpy(str_data, 
\"123456789\", 9);\n      Message msg = {str_data, 9};\n      msg.deallocator = DefaultMessageDeleter;\n      EXPECT_EQ(sender.Send(msg, 0), ADD_SUCCESS);\n    }\n    while (!stop) {\n    }\n    sender.Finalize();\n  });\n  // start 1 server, accept 1 client, receive 2 message\n  auto server = std::thread([&stop]() {\n    SocketReceiver receiver(kQueueSize, kThreadNum);\n    receiver.Wait(ip_addr[0], 1);\n    Message msg;\n    int recv_id;\n    // receive 1st message\n    EXPECT_EQ(receiver.RecvFrom(&msg, 0, 0), REMOVE_SUCCESS);\n    EXPECT_EQ(string(msg.data, msg.size), string(\"123456789\"));\n    msg.deallocator(&msg);\n    // receive 2nd message\n    EXPECT_EQ(receiver.Recv(&msg, &recv_id, 0), REMOVE_SUCCESS);\n    EXPECT_EQ(string(msg.data, msg.size), string(\"123456789\"));\n    msg.deallocator(&msg);\n    // timed out\n    EXPECT_EQ(receiver.RecvFrom(&msg, 0, 1000), QUEUE_EMPTY);\n    EXPECT_EQ(receiver.Recv(&msg, &recv_id, 1000), QUEUE_EMPTY);\n    stop = true;\n    receiver.Finalize();\n  });\n  // join\n  client.join();\n  server.join();\n}\n\nvoid start_client() {\n  SocketSender sender(kQueueSize, kThreadNum);\n  for (int i = 0; i < kNumReceiver; ++i) {\n    sender.ConnectReceiver(ip_addr[i], i);\n  }\n  sender.ConnectReceiverFinalize(kMaxTryTimes);\n  for (int i = 0; i < kNumMessage; ++i) {\n    for (int n = 0; n < kNumReceiver; ++n) {\n      char* str_data = new char[9];\n      memcpy(str_data, \"123456789\", 9);\n      Message msg = {str_data, 9};\n      msg.deallocator = DefaultMessageDeleter;\n      EXPECT_EQ(sender.Send(msg, n), ADD_SUCCESS);\n    }\n  }\n  for (int i = 0; i < kNumMessage; ++i) {\n    for (int n = 0; n < kNumReceiver; ++n) {\n      char* str_data = new char[9];\n      memcpy(str_data, \"123456789\", 9);\n      Message msg = {str_data, 9};\n      msg.deallocator = DefaultMessageDeleter;\n      EXPECT_EQ(sender.Send(msg, n), ADD_SUCCESS);\n    }\n  }\n  sender.Finalize();\n}\n\nvoid start_server(int id) {\n  sleep(5);\n  
SocketReceiver receiver(kQueueSize, kThreadNum);\n  receiver.Wait(ip_addr[id], kNumSender);\n  for (int i = 0; i < kNumMessage; ++i) {\n    for (int n = 0; n < kNumSender; ++n) {\n      Message msg;\n      EXPECT_EQ(receiver.RecvFrom(&msg, n), REMOVE_SUCCESS);\n      EXPECT_EQ(string(msg.data, msg.size), string(\"123456789\"));\n      msg.deallocator(&msg);\n    }\n  }\n  for (int n = 0; n < kNumSender * kNumMessage; ++n) {\n    Message msg;\n    int recv_id;\n    EXPECT_EQ(receiver.Recv(&msg, &recv_id), REMOVE_SUCCESS);\n    EXPECT_EQ(string(msg.data, msg.size), string(\"123456789\"));\n    msg.deallocator(&msg);\n  }\n  receiver.Finalize();\n}\n\nTEST(SocketCommunicatorTest, TCPSocketBind) {\n  dgl::network::TCPSocket socket;\n  testing::internal::CaptureStderr();\n  EXPECT_EQ(socket.Bind(\"127.0.0\", 50001), false);\n  const std::string stderr = testing::internal::GetCapturedStderr();\n  EXPECT_NE(stderr.find(\"Invalid IP: 127.0.0\"), std::string::npos);\n}\n\n#else\n\n#include <windows.h>\n#include <winsock2.h>\n\n#pragma comment(lib, \"ws2_32.lib\")\n\nvoid sleep(int seconds) { Sleep(seconds * 1000); }\n\nstatic void start_client();\nstatic bool start_server();\n\nDWORD WINAPI _ClientThreadFunc(LPVOID param) {\n  start_client();\n  return 0;\n}\n\nDWORD WINAPI _ServerThreadFunc(LPVOID param) { return start_server() ? 
1 : 0; }\n\nTEST(SocketCommunicatorTest, SendAndRecv) {\n  HANDLE hThreads[2];\n  WSADATA wsaData;\n  DWORD retcode, exitcode;\n\n  srand((unsigned)time(NULL));\n  int port = (rand() % (5000 - 3000 + 1)) + 3000;\n  std::string ip_addr = \"tcp://127.0.0.1:\" + std::to_string(port);\n  std::ofstream out(\"addr.txt\");\n  out << ip_addr;\n  out.close();\n\n  ASSERT_EQ(::WSAStartup(MAKEWORD(2, 2), &wsaData), 0);\n\n  hThreads[0] =\n      ::CreateThread(NULL, 0, _ClientThreadFunc, NULL, 0, NULL);  // client\n  ASSERT_TRUE(hThreads[0] != NULL);\n  hThreads[1] =\n      ::CreateThread(NULL, 0, _ServerThreadFunc, NULL, 0, NULL);  // server\n  ASSERT_TRUE(hThreads[1] != NULL);\n\n  retcode = ::WaitForMultipleObjects(2, hThreads, TRUE, INFINITE);\n  EXPECT_TRUE((retcode <= WAIT_OBJECT_0 + 1) && (retcode >= WAIT_OBJECT_0));\n\n  EXPECT_EQ(::GetExitCodeThread(hThreads[1], &exitcode), TRUE);\n  EXPECT_EQ(exitcode, 1);\n\n  EXPECT_EQ(::CloseHandle(hThreads[0]), TRUE);\n  EXPECT_EQ(::CloseHandle(hThreads[1]), TRUE);\n\n  ::WSACleanup();\n}\n\nstatic void start_client() {\n  std::ifstream t(\"addr.txt\");\n  std::string ip_addr(\n      (std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());\n  t.close();\n  SocketSender sender(kQueueSize, kThreadNum);\n  sender.ConnectReceiver(ip_addr.c_str(), 0);\n  sender.ConnectReceiverFinalize(kMaxTryTimes);\n  char* str_data = new char[9];\n  memcpy(str_data, \"123456789\", 9);\n  Message msg = {str_data, 9};\n  msg.deallocator = DefaultMessageDeleter;\n  sender.Send(msg, 0);\n  sender.Finalize();\n}\n\nstatic bool start_server() {\n  sleep(5);\n  std::ifstream t(\"addr.txt\");\n  std::string ip_addr(\n      (std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());\n  t.close();\n  SocketReceiver receiver(kQueueSize, kThreadNum);\n  receiver.Wait(ip_addr.c_str(), 1);\n  Message msg;\n  EXPECT_EQ(receiver.RecvFrom(&msg, 0), REMOVE_SUCCESS);\n  receiver.Finalize();\n  return string(\"123456789\") == string(msg.data, 
msg.size);\n}\n\n#endif\n"
  },
  {
    "path": "tests/cpp/string_test.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file string_test.cc\n * @brief Test String Common\n */\n#include <gtest/gtest.h>\n\n#include <string>\n#include <vector>\n\n#include \"../src/rpc/network/common.h\"\n\nusing dgl::network::SplitStringUsing;\nusing dgl::network::SStringPrintf;\nusing dgl::network::StringAppendF;\nusing dgl::network::StringPrintf;\n\nTEST(SplitStringTest, SplitStringUsingCompoundDelim) {\n  std::string full(\" apple \\torange \");\n  std::vector<std::string> subs;\n  SplitStringUsing(full, \" \\t\", &subs);\n  EXPECT_EQ(subs.size(), 2);\n  EXPECT_EQ(subs[0], std::string(\"apple\"));\n  EXPECT_EQ(subs[1], std::string(\"orange\"));\n}\n\nTEST(SplitStringTest, testSplitStringUsingSingleDelim) {\n  std::string full(\" apple orange \");\n  std::vector<std::string> subs;\n  SplitStringUsing(full, \" \", &subs);\n  EXPECT_EQ(subs.size(), 2);\n  EXPECT_EQ(subs[0], std::string(\"apple\"));\n  EXPECT_EQ(subs[1], std::string(\"orange\"));\n}\n\nTEST(SplitStringTest, testSplitingNoDelimString) {\n  std::string full(\"apple\");\n  std::vector<std::string> subs;\n  SplitStringUsing(full, \" \", &subs);\n  EXPECT_EQ(subs.size(), 1);\n  EXPECT_EQ(subs[0], std::string(\"apple\"));\n}\n\nTEST(StringPrintf, normal) {\n  using std::string;\n  EXPECT_EQ(StringPrintf(\"%d\", 1), string(\"1\"));\n  string target;\n  SStringPrintf(&target, \"%d\", 1);\n  EXPECT_EQ(target, string(\"1\"));\n  StringAppendF(&target, \"%d\", 2);\n  EXPECT_EQ(target, string(\"12\"));\n}\n"
  },
  {
    "path": "tests/cpp/test_aten.cc",
    "content": "#include <dgl/array.h>\n#include <gtest/gtest.h>\n\n#include \"./common.h\"\n\nusing namespace dgl;\nusing namespace dgl::runtime;\n\nTEST(ArrayTest, TestCreate) {\n  IdArray a = aten::NewIdArray(100, CTX, 32);\n  ASSERT_EQ(a->dtype.bits, 32);\n  ASSERT_EQ(a->shape[0], 100);\n\n  a = aten::NewIdArray(0);\n  ASSERT_EQ(a->shape[0], 0);\n\n  std::vector<int64_t> vec = {2, 94, 232, 30};\n  a = aten::VecToIdArray(vec, 32);\n  ASSERT_EQ(Len(a), vec.size());\n  ASSERT_EQ(a->dtype.bits, 32);\n  for (int i = 0; i < Len(a); ++i) {\n    ASSERT_EQ(Ptr<int32_t>(a)[i], vec[i]);\n  }\n\n  a = aten::VecToIdArray(std::vector<int32_t>());\n  ASSERT_EQ(Len(a), 0);\n};\n\nvoid _TestRange(DGLContext ctx) {\n  IdArray a = aten::Range(10, 10, 64, ctx);\n  ASSERT_EQ(Len(a), 0);\n  a = aten::Range(10, 20, 32, ctx);\n  ASSERT_EQ(Len(a), 10);\n  ASSERT_EQ(a->dtype.bits, 32);\n  a = a.CopyTo(CPU);\n  for (int i = 0; i < 10; ++i) ASSERT_EQ(Ptr<int32_t>(a)[i], i + 10);\n}\n\nTEST(ArrayTest, TestRange) {\n  _TestRange(CPU);\n#ifdef DGL_USE_CUDA\n  _TestRange(GPU);\n#endif\n};\n\nTEST(ArrayTest, TestFull) {\n  IdArray a = aten::Full(-100, 0, 32, CTX);\n  ASSERT_EQ(Len(a), 0);\n  a = aten::Full(-100, 13, 64, CTX);\n  ASSERT_EQ(Len(a), 13);\n  ASSERT_EQ(a->dtype.bits, 64);\n  for (int i = 0; i < 13; ++i) ASSERT_EQ(Ptr<int64_t>(a)[i], -100);\n};\n\nTEST(ArrayTest, TestClone) {\n  IdArray a = aten::NewIdArray(0);\n  IdArray b = aten::Clone(a);\n  ASSERT_EQ(Len(b), 0);\n\n  a = aten::Range(0, 10, 32, CTX);\n  b = aten::Clone(a);\n  for (int i = 0; i < 10; ++i) {\n    ASSERT_EQ(PI32(b)[i], i);\n  }\n  PI32(b)[0] = -1;\n  for (int i = 0; i < 10; ++i) {\n    ASSERT_EQ(PI32(a)[i], i);\n  }\n};\n\nvoid _TestNumBits(DGLContext ctx) {\n  IdArray a = aten::Range(0, 10, 32, ctx);\n  a = aten::AsNumBits(a, 64);\n  ASSERT_EQ(a->dtype.bits, 64);\n  a = a.CopyTo(CPU);\n  for (int i = 0; i < 10; ++i) ASSERT_EQ(PI64(a)[i], i);\n}\n\nTEST(ArrayTest, TestAsNumBits) {\n  _TestNumBits(CPU);\n#ifdef 
DGL_USE_CUDA\n  _TestNumBits(GPU);\n#endif\n};\n\ntemplate <typename IDX>\nvoid _TestArith(DGLContext ctx) {\n  const int N = 100;\n  IdArray a = aten::Full(-10, N, sizeof(IDX) * 8, ctx);\n  IdArray b = aten::Full(7, N, sizeof(IDX) * 8, ctx);\n\n  IdArray c = a + b;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], -3);\n  c = a - b;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], -17);\n  c = a * b;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], -70);\n  c = a / b;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], -1);\n  c = -a;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], 10);\n  c = (-a) % b;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], 3);\n\n  const int val = -3;\n  c = aten::Add(a, val);\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], -13);\n  c = aten::Sub(a, val);\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], -7);\n  c = aten::Mul(a, val);\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], 30);\n  c = aten::Div(a, val);\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], 3);\n  c = b % 3;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], 1);\n\n  c = aten::Add(val, b);\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], 4);\n  c = aten::Sub(val, b);\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], -10);\n  c = aten::Mul(val, b);\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], -21);\n  c = aten::Div(val, b);\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], 0);\n  c = 3 % b;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], 3);\n\n  a = aten::Range(0, N, sizeof(IDX) * 8, ctx);\n  c = 
a < 50;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], (int)(i < 50));\n\n  c = a > 50;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], (int)(i > 50));\n\n  c = a >= 50;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], (int)(i >= 50));\n\n  c = a <= 50;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], (int)(i <= 50));\n\n  c = a == 50;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], (int)(i == 50));\n\n  c = a != 50;\n  c = c.CopyTo(CPU);\n  for (int i = 0; i < N; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], (int)(i != 50));\n}\n\nTEST(ArrayTest, Arith) {\n  _TestArith<int32_t>(CPU);\n  _TestArith<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestArith<int32_t>(GPU);\n  _TestArith<int64_t>(GPU);\n#endif\n};\n\ntemplate <typename IDX>\nvoid _TestHStack(DGLContext ctx) {\n  IdArray a = aten::Range(0, 100, sizeof(IDX) * 8, ctx);\n  IdArray b = aten::Range(100, 200, sizeof(IDX) * 8, ctx);\n  IdArray c = aten::HStack(a, b).CopyTo(aten::CPU);\n  ASSERT_EQ(c->ndim, 1);\n  ASSERT_EQ(c->shape[0], 200);\n  for (int i = 0; i < 200; ++i) ASSERT_EQ(Ptr<IDX>(c)[i], i);\n}\n\nTEST(ArrayTest, HStack) {\n  _TestHStack<int32_t>(CPU);\n  _TestHStack<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestHStack<int32_t>(GPU);\n  _TestHStack<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestIndexSelect(DGLContext ctx) {\n  IdArray a = aten::Range(0, 100, sizeof(IDX) * 8, ctx);\n  ASSERT_EQ(aten::IndexSelect<int>(a, 50), 50);\n  ASSERT_TRUE(ArrayEQ<IDX>(\n      aten::IndexSelect(a, 10, 20), aten::Range(10, 20, sizeof(IDX) * 8, ctx)));\n  IdArray b =\n      aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX) * 8, ctx);\n  IdArray c = aten::IndexSelect(a, b);\n  ASSERT_TRUE(ArrayEQ<IDX>(b, c));\n}\n\nTEST(ArrayTest, TestIndexSelect) {\n  _TestIndexSelect<int32_t>(CPU);\n  _TestIndexSelect<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  
_TestIndexSelect<int32_t>(GPU);\n  _TestIndexSelect<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestRelabel_(DGLContext ctx) {\n  IdArray a =\n      aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX) * 8, ctx);\n  IdArray b =\n      aten::VecToIdArray(std::vector<IDX>({20, 5, 6}), sizeof(IDX) * 8, ctx);\n  IdArray c = aten::Relabel_({a, b});\n\n  IdArray ta =\n      aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);\n  IdArray tb =\n      aten::VecToIdArray(std::vector<IDX>({1, 3, 4}), sizeof(IDX) * 8, ctx);\n  IdArray tc = aten::VecToIdArray(\n      std::vector<IDX>({0, 20, 10, 5, 6}), sizeof(IDX) * 8, ctx);\n\n  ASSERT_TRUE(ArrayEQ<IDX>(a, ta));\n  ASSERT_TRUE(ArrayEQ<IDX>(b, tb));\n  ASSERT_TRUE(ArrayEQ<IDX>(c, tc));\n}\n\nTEST(ArrayTest, TestRelabel_) {\n  _TestRelabel_<int32_t>(CPU);\n  _TestRelabel_<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestRelabel_<int32_t>(GPU);\n  _TestRelabel_<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestConcat(DGLContext ctx) {\n  IdArray a =\n      aten::VecToIdArray(std::vector<IDX>({1, 2, 3}), sizeof(IDX) * 8, CTX);\n  IdArray b =\n      aten::VecToIdArray(std::vector<IDX>({4, 5, 6}), sizeof(IDX) * 8, CTX);\n  IdArray tc = aten::VecToIdArray(\n      std::vector<IDX>({1, 2, 3, 4, 5, 6}), sizeof(IDX) * 8, CTX);\n  IdArray c = aten::Concat(std::vector<IdArray>{a, b});\n  ASSERT_TRUE(ArrayEQ<IDX>(c, tc));\n  IdArray d = aten::Concat(std::vector<IdArray>{a, b, c});\n  IdArray td = aten::VecToIdArray(\n      std::vector<IDX>({1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}), sizeof(IDX) * 8,\n      CTX);\n  ASSERT_TRUE(ArrayEQ<IDX>(d, td));\n}\n\ntemplate <typename IdType>\nvoid _TestToSimpleCsr(DGLContext ctx) {\n  /**\n   * A = [[0, 0, 0, 0],\n   *      [1, 0, 0, 1],\n   *      [1, 1, 1, 1],\n   *      [3, 2, 2, 3],\n   *      [2, 0, 0, 2]]\n   *\n   * B = CSRToSimple(A)\n   * B = [[0, 0, 0, 0],\n   *      [1, 0, 0, 1],\n   *      [1, 1, 1, 1],\n   *      [1, 1, 1, 1],\n   *      [1, 
0, 0, 1]]\n   */\n  IdArray a_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 0, 2, 6, 16, 20}), sizeof(IdType) * 8, CTX);\n  IdArray a_indices = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 3, 0, 1, 2, 3, 0, 0, 0, 1, 1, 2, 2, 3, 3, 3, 0, 0, 3, 3}),\n      sizeof(IdType) * 8, CTX);\n  IdArray b_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 0, 2, 6, 10, 12}), sizeof(IdType) * 8, CTX);\n  IdArray b_indices = aten::VecToIdArray(\n      std::vector<IdType>({0, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 3}),\n      sizeof(IdType) * 8, CTX);\n  IdArray cnt = aten::VecToIdArray(\n      std::vector<IdType>({1, 1, 1, 1, 1, 1, 3, 2, 2, 3, 2, 2}),\n      sizeof(IdType) * 8, CTX);\n  IdArray map = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 1, 2, 3, 4, 5, 6, 6, 6, 7, 7, 8, 8, 9, 9, 9, 10, 10, 11, 11}),\n      sizeof(IdType) * 8, CTX);\n  const aten::CSRMatrix &csr_a =\n      aten::CSRMatrix(5, 4, a_indptr, a_indices, aten::NullArray(), true);\n  auto ret = CSRToSimple(csr_a);\n  aten::CSRMatrix csr_b = std::get<0>(ret);\n  IdArray ecnt = std::get<1>(ret);\n  IdArray emap = std::get<2>(ret);\n  ASSERT_EQ(csr_b.num_rows, 5);\n  ASSERT_EQ(csr_b.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_b.indptr, b_indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_b.indices, b_indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(ecnt, cnt));\n  ASSERT_TRUE(ArrayEQ<IdType>(emap, map));\n  ASSERT_TRUE(csr_b.sorted);\n\n  // a not sorted\n  a_indices = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 3, 0, 1, 2, 3, 3, 0, 0, 1, 1, 2, 2, 3, 3, 0, 0, 3, 0, 3}),\n      sizeof(IdType) * 8, CTX);\n  map = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 1, 2, 3, 4, 5, 9, 6, 6, 7, 7, 8, 8, 9, 9, 6, 10, 11, 10, 11}),\n      sizeof(IdType) * 8, CTX);\n  const aten::CSRMatrix &csr_a2 =\n      aten::CSRMatrix(5, 4, a_indptr, a_indices, aten::NullArray(), false);\n  ret = CSRToSimple(csr_a2);\n  csr_b = std::get<0>(ret);\n  ecnt = 
std::get<1>(ret);\n  emap = std::get<2>(ret);\n  ASSERT_EQ(csr_b.num_rows, 5);\n  ASSERT_EQ(csr_b.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_b.indptr, b_indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_b.indices, b_indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(ecnt, cnt));\n  ASSERT_TRUE(ArrayEQ<IdType>(emap, map));\n  ASSERT_TRUE(csr_b.sorted);\n}\n\nTEST(MatrixTest, TestToSimpleCsr) {\n  _TestToSimpleCsr<int32_t>(CPU);\n  _TestToSimpleCsr<int64_t>(CPU);\n}\n\ntemplate <typename IdType>\nvoid _TestToSimpleCoo(DGLContext ctx) {\n  /**\n   * A = [[0, 0, 0, 0],\n   *      [1, 0, 0, 1],\n   *      [1, 1, 1, 1],\n   *      [3, 2, 2, 3],\n   *      [2, 0, 0, 2]]\n   *\n   * B = CSRToSimple(A)\n   * B = [[0, 0, 0, 0],\n   *      [1, 0, 0, 1],\n   *      [1, 1, 1, 1],\n   *      [1, 1, 1, 1],\n   *      [1, 0, 0, 1]]\n   */\n  IdArray a_row = aten::VecToIdArray(\n      std::vector<IdType>(\n          {1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4}),\n      sizeof(IdType) * 8, CTX);\n  IdArray a_col = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 3, 0, 1, 2, 3, 0, 0, 0, 1, 1, 2, 2, 3, 3, 3, 0, 0, 3, 3}),\n      sizeof(IdType) * 8, CTX);\n  IdArray b_row = aten::VecToIdArray(\n      std::vector<IdType>({1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4}),\n      sizeof(IdType) * 8, CTX);\n  IdArray b_col = aten::VecToIdArray(\n      std::vector<IdType>({0, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 3}),\n      sizeof(IdType) * 8, CTX);\n  IdArray cnt = aten::VecToIdArray(\n      std::vector<IdType>({1, 1, 1, 1, 1, 1, 3, 2, 2, 3, 2, 2}),\n      sizeof(IdType) * 8, CTX);\n  IdArray map = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 1, 2, 3, 4, 5, 6, 6, 6, 7, 7, 8, 8, 9, 9, 9, 10, 10, 11, 11}),\n      sizeof(IdType) * 8, CTX);\n  const aten::COOMatrix &coo_a =\n      aten::COOMatrix(5, 4, a_row, a_col, aten::NullArray(), true, true);\n  auto ret = COOToSimple(coo_a);\n  aten::COOMatrix coo_b = std::get<0>(ret);\n  IdArray ecnt = std::get<1>(ret);\n  IdArray 
emap = std::get<2>(ret);\n  ASSERT_EQ(coo_b.num_rows, 5);\n  ASSERT_EQ(coo_b.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_b.row, b_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_b.col, b_col));\n  ASSERT_TRUE(ArrayEQ<IdType>(ecnt, cnt));\n  ASSERT_TRUE(ArrayEQ<IdType>(emap, map));\n  ASSERT_FALSE(COOHasData(coo_b));\n  ASSERT_TRUE(coo_b.row_sorted);\n  ASSERT_TRUE(coo_b.col_sorted);\n\n  // a not sorted\n  a_row = aten::VecToIdArray(\n      std::vector<IdType>(\n          {1, 2, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4}),\n      sizeof(IdType) * 8, CTX);\n  a_col = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 0, 3, 1, 2, 3, 0, 0, 0, 1, 1, 2, 2, 3, 3, 3, 0, 3, 0, 3}),\n      sizeof(IdType) * 8, CTX);\n  map = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 2, 1, 3, 4, 5, 6, 6, 6, 7, 7, 8, 8, 9, 9, 9, 10, 11, 10, 11}),\n      sizeof(IdType) * 8, CTX);\n  const aten::COOMatrix &coo_a2 =\n      aten::COOMatrix(5, 4, a_row, a_col, aten::NullArray(), false, false);\n  ret = COOToSimple(coo_a2);\n  coo_b = std::get<0>(ret);\n  ecnt = std::get<1>(ret);\n  emap = std::get<2>(ret);\n  ASSERT_EQ(coo_b.num_rows, 5);\n  ASSERT_EQ(coo_b.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_b.row, b_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_b.col, b_col));\n  ASSERT_TRUE(ArrayEQ<IdType>(ecnt, cnt));\n  ASSERT_TRUE(ArrayEQ<IdType>(emap, map));\n  ASSERT_FALSE(COOHasData(coo_b));\n  ASSERT_TRUE(coo_b.row_sorted);\n  ASSERT_TRUE(coo_b.col_sorted);\n}\n\nTEST(MatrixTest, TestToSimpleCoo) {\n  _TestToSimpleCoo<int32_t>(CPU);\n  _TestToSimpleCoo<int64_t>(CPU);\n}\n\ntemplate <typename IdType>\nvoid _TestDisjointUnionPartitionCoo(DGLContext ctx) {\n  /**\n   * A = [[0, 0, 1],\n   *      [1, 0, 1],\n   *      [0, 1, 0]]\n   *\n   * B = [[1, 1, 0],\n   *      [0, 1, 0]]\n   *\n   * C = [[1]]\n   *\n   * AB = [[0, 0, 1, 0, 0, 0],\n   *       [1, 0, 1, 0, 0, 0],\n   *       [0, 1, 0, 0, 0, 0],\n   *       [0, 0, 0, 1, 1, 0],\n   *       [0, 0, 0, 0, 1, 
0]]\n   *\n   * ABC = [[0, 0, 1, 0, 0, 0, 0],\n   *        [1, 0, 1, 0, 0, 0, 0],\n   *        [0, 1, 0, 0, 0, 0, 0],\n   *        [0, 0, 0, 1, 1, 0, 0],\n   *        [0, 0, 0, 0, 1, 0, 0],\n   *        [0, 0, 0, 0, 0, 0, 1]]\n   */\n  IdArray a_row = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 1, 2}), sizeof(IdType) * 8, CTX);\n  IdArray a_col = aten::VecToIdArray(\n      std::vector<IdType>({2, 0, 2, 1}), sizeof(IdType) * 8, CTX);\n  IdArray b_row = aten::VecToIdArray(\n      std::vector<IdType>({0, 0, 1}), sizeof(IdType) * 8, CTX);\n  IdArray b_col = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 1}), sizeof(IdType) * 8, CTX);\n  IdArray b_data = aten::VecToIdArray(\n      std::vector<IdType>({2, 0, 1}), sizeof(IdType) * 8, CTX);\n  IdArray c_row =\n      aten::VecToIdArray(std::vector<IdType>({0}), sizeof(IdType) * 8, CTX);\n  IdArray c_col =\n      aten::VecToIdArray(std::vector<IdType>({0}), sizeof(IdType) * 8, CTX);\n  IdArray ab_row = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 1, 2, 3, 3, 4}), sizeof(IdType) * 8, CTX);\n  IdArray ab_col = aten::VecToIdArray(\n      std::vector<IdType>({2, 0, 2, 1, 3, 4, 4}), sizeof(IdType) * 8, CTX);\n  IdArray ab_data = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 2, 3, 6, 4, 5}), sizeof(IdType) * 8, CTX);\n  IdArray abc_row = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 1, 2, 3, 3, 4, 5}), sizeof(IdType) * 8, CTX);\n  IdArray abc_col = aten::VecToIdArray(\n      std::vector<IdType>({2, 0, 2, 1, 3, 4, 4, 6}), sizeof(IdType) * 8, CTX);\n  IdArray abc_data = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 2, 3, 6, 4, 5, 7}), sizeof(IdType) * 8, CTX);\n  const aten::COOMatrix &coo_a =\n      aten::COOMatrix(3, 3, a_row, a_col, aten::NullArray(), true, false);\n  const aten::COOMatrix &coo_b =\n      aten::COOMatrix(2, 3, b_row, b_col, b_data, true, true);\n  const aten::COOMatrix &coo_c =\n      aten::COOMatrix(1, 1, c_row, c_col, aten::NullArray(), true, true);\n\n  const 
std::vector<aten::COOMatrix> coos_ab({coo_a, coo_b});\n  const aten::COOMatrix &coo_ab = aten::DisjointUnionCoo(coos_ab);\n  ASSERT_EQ(coo_ab.num_rows, 5);\n  ASSERT_EQ(coo_ab.num_cols, 6);\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab.row, ab_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab.col, ab_col));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab.data, ab_data));\n  ASSERT_TRUE(coo_ab.row_sorted);\n  ASSERT_FALSE(coo_ab.col_sorted);\n\n  const std::vector<uint64_t> edge_cumsum({0, 4, 7});\n  const std::vector<uint64_t> src_vertex_cumsum({0, 3, 5});\n  const std::vector<uint64_t> dst_vertex_cumsum({0, 3, 6});\n  const std::vector<aten::COOMatrix> &p_coos =\n      aten::DisjointPartitionCooBySizes(\n          coo_ab, 2, edge_cumsum, src_vertex_cumsum, dst_vertex_cumsum);\n  ASSERT_EQ(p_coos[0].num_rows, coo_a.num_rows);\n  ASSERT_EQ(p_coos[0].num_cols, coo_a.num_cols);\n  ASSERT_EQ(p_coos[1].num_rows, coo_b.num_rows);\n  ASSERT_EQ(p_coos[1].num_cols, coo_b.num_cols);\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos[0].row, coo_a.row));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos[0].col, coo_a.col));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos[1].row, coo_b.row));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos[1].col, coo_b.col));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos[1].data, coo_b.data));\n  ASSERT_TRUE(p_coos[0].row_sorted);\n  ASSERT_FALSE(p_coos[0].col_sorted);\n  ASSERT_TRUE(p_coos[1].row_sorted);\n  ASSERT_FALSE(p_coos[1].col_sorted);\n\n  const std::vector<aten::COOMatrix> coos_abc({coo_a, coo_b, coo_c});\n  const aten::COOMatrix &coo_abc = aten::DisjointUnionCoo(coos_abc);\n  ASSERT_EQ(coo_abc.num_rows, 6);\n  ASSERT_EQ(coo_abc.num_cols, 7);\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_abc.row, abc_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_abc.col, abc_col));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_abc.data, abc_data));\n  ASSERT_TRUE(coo_abc.row_sorted);\n  ASSERT_FALSE(coo_abc.col_sorted);\n\n  const std::vector<uint64_t> edge_cumsum_abc({0, 4, 7, 8});\n  const std::vector<uint64_t> src_vertex_cumsum_abc({0, 3, 5, 
6});\n  const std::vector<uint64_t> dst_vertex_cumsum_abc({0, 3, 6, 7});\n  const std::vector<aten::COOMatrix> &p_coos_abc =\n      aten::DisjointPartitionCooBySizes(\n          coo_abc, 3, edge_cumsum_abc, src_vertex_cumsum_abc,\n          dst_vertex_cumsum_abc);\n  ASSERT_EQ(p_coos_abc[0].num_rows, coo_a.num_rows);\n  ASSERT_EQ(p_coos_abc[0].num_cols, coo_a.num_cols);\n  ASSERT_EQ(p_coos_abc[1].num_rows, coo_b.num_rows);\n  ASSERT_EQ(p_coos_abc[1].num_cols, coo_b.num_cols);\n  ASSERT_EQ(p_coos_abc[2].num_rows, coo_c.num_rows);\n  ASSERT_EQ(p_coos_abc[2].num_cols, coo_c.num_cols);\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos_abc[0].row, coo_a.row));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos_abc[0].col, coo_a.col));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos_abc[1].row, coo_b.row));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos_abc[1].col, coo_b.col));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos_abc[1].data, coo_b.data));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos_abc[2].row, coo_c.row));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_coos_abc[2].col, coo_c.col));\n  ASSERT_TRUE(p_coos_abc[0].row_sorted);\n  ASSERT_FALSE(p_coos_abc[0].col_sorted);\n  ASSERT_TRUE(p_coos_abc[1].row_sorted);\n  ASSERT_FALSE(p_coos_abc[1].col_sorted);\n  ASSERT_TRUE(p_coos_abc[2].row_sorted);\n  ASSERT_FALSE(p_coos_abc[2].col_sorted);\n}\n\nTEST(DisjointUnionTest, TestDisjointUnionPartitionCoo) {\n  _TestDisjointUnionPartitionCoo<int32_t>(CPU);\n  _TestDisjointUnionPartitionCoo<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestDisjointUnionPartitionCoo<int32_t>(GPU);\n  _TestDisjointUnionPartitionCoo<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IdType>\nvoid _TestDisjointUnionPartitionCsr(DGLContext ctx) {\n  /**\n   * A = [[0, 0, 1],\n   *      [1, 0, 1],\n   *      [0, 1, 0]]\n   *\n   * B = [[1, 1, 0],\n   *      [0, 1, 0]]\n   *\n   * C = [[1]]\n   *\n   * BC = [[1, 1, 0, 0],\n   *       [0, 1, 0, 0],\n   *       [0, 0, 0, 1]],\n   *\n   * ABC = [[0, 0, 1, 0, 0, 0, 0],\n   *        [1, 0, 1, 0, 0, 0, 0],\n   *        [0, 1, 0, 0, 0, 
0, 0],\n   *        [0, 0, 0, 1, 1, 0, 0],\n   *        [0, 0, 0, 0, 1, 0, 0],\n   *        [0, 0, 0, 0, 0, 0, 1]]\n   */\n  IdArray a_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 3, 4}), sizeof(IdType) * 8, CTX);\n  IdArray a_indices = aten::VecToIdArray(\n      std::vector<IdType>({2, 0, 2, 1}), sizeof(IdType) * 8, CTX);\n  IdArray b_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 2, 3}), sizeof(IdType) * 8, CTX);\n  IdArray b_indices = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 1}), sizeof(IdType) * 8, CTX);\n  IdArray b_data = aten::VecToIdArray(\n      std::vector<IdType>({2, 0, 1}), sizeof(IdType) * 8, CTX);\n  IdArray c_indptr =\n      aten::VecToIdArray(std::vector<IdType>({0, 1}), sizeof(IdType) * 8, CTX);\n  IdArray c_indices =\n      aten::VecToIdArray(std::vector<IdType>({0}), sizeof(IdType) * 8, CTX);\n  IdArray bc_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 2, 3, 4}), sizeof(IdType) * 8, CTX);\n  IdArray bc_indices = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 1, 3}), sizeof(IdType) * 8, CTX);\n  IdArray bc_data = aten::VecToIdArray(\n      std::vector<IdType>({2, 0, 1, 3}), sizeof(IdType) * 8, CTX);\n  IdArray abc_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 3, 4, 6, 7, 8}), sizeof(IdType) * 8, CTX);\n  IdArray abc_indices = aten::VecToIdArray(\n      std::vector<IdType>({2, 0, 2, 1, 3, 4, 4, 6}), sizeof(IdType) * 8, CTX);\n  IdArray abc_data = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 2, 3, 6, 4, 5, 7}), sizeof(IdType) * 8, CTX);\n  const aten::CSRMatrix &csr_a =\n      aten::CSRMatrix(3, 3, a_indptr, a_indices, aten::NullArray(), false);\n  const aten::CSRMatrix &csr_b =\n      aten::CSRMatrix(2, 3, b_indptr, b_indices, b_data, true);\n  const aten::CSRMatrix &csr_c =\n      aten::CSRMatrix(1, 1, c_indptr, c_indices, aten::NullArray(), true);\n\n  const std::vector<aten::CSRMatrix> csrs_bc({csr_b, csr_c});\n  const aten::CSRMatrix &csr_bc = 
aten::DisjointUnionCsr(csrs_bc);\n  ASSERT_EQ(csr_bc.num_rows, 3);\n  ASSERT_EQ(csr_bc.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_bc.indptr, bc_indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_bc.indices, bc_indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_bc.data, bc_data));\n  ASSERT_TRUE(csr_bc.sorted);\n\n  const std::vector<uint64_t> edge_cumsum({0, 3, 4});\n  const std::vector<uint64_t> src_vertex_cumsum({0, 2, 3});\n  const std::vector<uint64_t> dst_vertex_cumsum({0, 3, 4});\n  const std::vector<aten::CSRMatrix> &p_csrs =\n      aten::DisjointPartitionCsrBySizes(\n          csr_bc, 2, edge_cumsum, src_vertex_cumsum, dst_vertex_cumsum);\n  ASSERT_EQ(p_csrs[0].num_rows, csr_b.num_rows);\n  ASSERT_EQ(p_csrs[0].num_cols, csr_b.num_cols);\n  ASSERT_EQ(p_csrs[1].num_rows, csr_c.num_rows);\n  ASSERT_EQ(p_csrs[1].num_cols, csr_c.num_cols);\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs[0].indptr, csr_b.indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs[0].indices, csr_b.indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs[0].data, csr_b.data));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs[1].indptr, csr_c.indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs[1].indices, csr_c.indices));\n  ASSERT_TRUE(p_csrs[0].sorted);\n  ASSERT_TRUE(p_csrs[1].sorted);\n\n  const std::vector<aten::CSRMatrix> csrs_abc({csr_a, csr_b, csr_c});\n  const aten::CSRMatrix &csr_abc = aten::DisjointUnionCsr(csrs_abc);\n  ASSERT_EQ(csr_abc.num_rows, 6);\n  ASSERT_EQ(csr_abc.num_cols, 7);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_abc.indptr, abc_indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_abc.indices, abc_indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_abc.data, abc_data));\n  ASSERT_FALSE(csr_abc.sorted);\n\n  const std::vector<uint64_t> edge_cumsum_abc({0, 4, 7, 8});\n  const std::vector<uint64_t> src_vertex_cumsum_abc({0, 3, 5, 6});\n  const std::vector<uint64_t> dst_vertex_cumsum_abc({0, 3, 6, 7});\n  const std::vector<aten::CSRMatrix> &p_csrs_abc =\n      aten::DisjointPartitionCsrBySizes(\n          csr_abc, 3, 
edge_cumsum_abc, src_vertex_cumsum_abc,\n          dst_vertex_cumsum_abc);\n  ASSERT_EQ(p_csrs_abc[0].num_rows, csr_a.num_rows);\n  ASSERT_EQ(p_csrs_abc[0].num_cols, csr_a.num_cols);\n  ASSERT_EQ(p_csrs_abc[1].num_rows, csr_b.num_rows);\n  ASSERT_EQ(p_csrs_abc[1].num_cols, csr_b.num_cols);\n  ASSERT_EQ(p_csrs_abc[2].num_rows, csr_c.num_rows);\n  ASSERT_EQ(p_csrs_abc[2].num_cols, csr_c.num_cols);\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs_abc[0].indptr, csr_a.indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs_abc[0].indices, csr_a.indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs_abc[1].indptr, csr_b.indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs_abc[1].indices, csr_b.indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs_abc[1].data, csr_b.data));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs_abc[2].indptr, csr_c.indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(p_csrs_abc[2].indices, csr_c.indices));\n  ASSERT_FALSE(p_csrs_abc[0].sorted);\n  ASSERT_FALSE(p_csrs_abc[1].sorted);\n  ASSERT_FALSE(p_csrs_abc[2].sorted);\n}\n\nTEST(DisjointUnionTest, TestDisjointUnionPartitionCsr) {\n  _TestDisjointUnionPartitionCsr<int32_t>(CPU);\n  _TestDisjointUnionPartitionCsr<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestDisjointUnionPartitionCsr<int32_t>(GPU);\n  _TestDisjointUnionPartitionCsr<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IdType>\nvoid _TestSliceContiguousChunkCoo(DGLContext ctx) {\n  /**\n   * A = [[1, 0, 0, 0],\n   *      [0, 0, 1, 0],\n   *      [0, 0, 0, 0]]\n   *\n   * B = [[1, 0, 0],\n   *      [0, 0, 1]]\n   *\n   * C = [[0]]\n   *\n   */\n  IdArray a_row =\n      aten::VecToIdArray(std::vector<IdType>({0, 1}), sizeof(IdType) * 8, CTX);\n  IdArray a_col =\n      aten::VecToIdArray(std::vector<IdType>({0, 2}), sizeof(IdType) * 8, CTX);\n  const aten::COOMatrix &coo_a =\n      aten::COOMatrix(3, 4, a_row, a_col, aten::NullArray(), true, false);\n\n  IdArray b_row =\n      aten::VecToIdArray(std::vector<IdType>({0, 1}), sizeof(IdType) * 8, CTX);\n  IdArray b_col =\n      
aten::VecToIdArray(std::vector<IdType>({0, 2}), sizeof(IdType) * 8, CTX);\n  const aten::COOMatrix &coo_b_raw =\n      aten::COOMatrix(2, 3, b_row, b_col, aten::NullArray(), true, false);\n\n  const std::vector<uint64_t> edge_range_b({0, 2});\n  const std::vector<uint64_t> src_vertex_range_b({0, 2});\n  const std::vector<uint64_t> dst_vertex_range_b({0, 3});\n  const aten::COOMatrix &coo_b = aten::COOSliceContiguousChunk(\n      coo_a, edge_range_b, src_vertex_range_b, dst_vertex_range_b);\n  ASSERT_EQ(coo_b_raw.num_rows, coo_b.num_rows);\n  ASSERT_EQ(coo_b_raw.num_cols, coo_b.num_cols);\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_b_raw.row, coo_b.row));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_b_raw.col, coo_b.col));\n  ASSERT_TRUE(coo_b.row_sorted);\n  ASSERT_FALSE(coo_b.col_sorted);\n\n  IdArray c_row =\n      aten::VecToIdArray(std::vector<IdType>({}), sizeof(IdType) * 8, CTX);\n  IdArray c_col =\n      aten::VecToIdArray(std::vector<IdType>({}), sizeof(IdType) * 8, CTX);\n  const aten::COOMatrix &coo_c_raw =\n      aten::COOMatrix(1, 1, c_row, c_col, aten::NullArray(), true, false);\n\n  const std::vector<uint64_t> edge_range_c({2, 2});\n  const std::vector<uint64_t> src_vertex_range_c({2, 3});\n  const std::vector<uint64_t> dst_vertex_range_c({3, 4});\n  const aten::COOMatrix &coo_c = aten::COOSliceContiguousChunk(\n      coo_a, edge_range_c, src_vertex_range_c, dst_vertex_range_c);\n  ASSERT_EQ(coo_c_raw.num_rows, coo_c.num_rows);\n  ASSERT_EQ(coo_c_raw.num_cols, coo_c.num_cols);\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_c.row, c_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_c.col, c_col));\n  ASSERT_TRUE(coo_c.row_sorted);\n  ASSERT_FALSE(coo_c.col_sorted);\n}\n\nTEST(SliceContiguousChunk, TestSliceContiguousChunkCoo) {\n  _TestSliceContiguousChunkCoo<int32_t>(CPU);\n  _TestSliceContiguousChunkCoo<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestSliceContiguousChunkCoo<int32_t>(GPU);\n  _TestSliceContiguousChunkCoo<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IdType>\nvoid 
_TestSliceContiguousChunkCsr(DGLContext ctx) {\n  /**\n   * A = [[1, 0, 0, 0],\n   *      [0, 0, 1, 0],\n   *      [0, 0, 0, 0]]\n   *\n   * B = [[1, 0, 0],\n   *      [0, 0, 1]]\n   *\n   * C = [[0]]\n   *\n   */\n  IdArray a_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 2, 2}), sizeof(IdType) * 8, CTX);\n  IdArray a_indices =\n      aten::VecToIdArray(std::vector<IdType>({0, 2}), sizeof(IdType) * 8, CTX);\n  const aten::CSRMatrix &csr_a =\n      aten::CSRMatrix(3, 4, a_indptr, a_indices, aten::NullArray(), false);\n\n  IdArray b_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 2}), sizeof(IdType) * 8, CTX);\n  IdArray b_indices =\n      aten::VecToIdArray(std::vector<IdType>({0, 2}), sizeof(IdType) * 8, CTX);\n  const aten::CSRMatrix &csr_b_raw =\n      aten::CSRMatrix(2, 3, b_indptr, b_indices, aten::NullArray(), false);\n\n  const std::vector<uint64_t> edge_range_b({0, 2});\n  const std::vector<uint64_t> src_vertex_range_b({0, 2});\n  const std::vector<uint64_t> dst_vertex_range_b({0, 3});\n  const aten::CSRMatrix &csr_b = aten::CSRSliceContiguousChunk(\n      csr_a, edge_range_b, src_vertex_range_b, dst_vertex_range_b);\n  ASSERT_EQ(csr_b.num_rows, csr_b_raw.num_rows);\n  ASSERT_EQ(csr_b.num_cols, csr_b_raw.num_cols);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_b.indptr, csr_b_raw.indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_b.indices, csr_b_raw.indices));\n  ASSERT_FALSE(csr_b.sorted);\n\n  const std::vector<uint64_t> edge_range_c({2, 2});\n  const std::vector<uint64_t> src_vertex_range_c({2, 3});\n  const std::vector<uint64_t> dst_vertex_range_c({3, 4});\n  const aten::CSRMatrix &csr_c = aten::CSRSliceContiguousChunk(\n      csr_a, edge_range_c, src_vertex_range_c, dst_vertex_range_c);\n\n  int64_t indptr_len = src_vertex_range_c[1] - src_vertex_range_c[0] + 1;\n  IdArray c_indptr = aten::Full(0, indptr_len, sizeof(IdType) * 8, CTX);\n  IdArray c_indices =\n      aten::VecToIdArray(std::vector<IdType>({}), sizeof(IdType) * 8, CTX);\n  
const aten::CSRMatrix &csr_c_raw =\n      aten::CSRMatrix(1, 1, c_indptr, c_indices, aten::NullArray(), false);\n\n  ASSERT_EQ(csr_c.num_rows, csr_c_raw.num_rows);\n  ASSERT_EQ(csr_c.num_cols, csr_c_raw.num_cols);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_c.indptr, c_indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_c.indices, c_indices));\n  ASSERT_FALSE(csr_c.sorted);\n}\n\nTEST(SliceContiguousChunk, TestSliceContiguousChunkCsr) {\n  _TestSliceContiguousChunkCsr<int32_t>(CPU);\n  _TestSliceContiguousChunkCsr<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestSliceContiguousChunkCsr<int32_t>(GPU);\n  _TestSliceContiguousChunkCsr<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IdType>\nvoid _TestMatrixUnionCsr(DGLContext ctx) {\n  /**\n   * A = [[0, 0, 0, 0],\n   *      [0, 0, 0, 0],\n   *      [0, 1, 0, 0],\n   *      [1, 1, 1, 1],\n   *      [0, 1, 1, 0],\n   *      [1, 0, 0, 1]]\n   *\n   * B = [[0, 0, 0, 0],\n   *      [1, 0, 0, 1],\n   *      [0, 0, 1, 0],\n   *      [1, 0, 0, 1],\n   *      [1, 0, 0, 1]]\n   *      [1, 0, 0, 1]]\n   *\n   * C = UnionCsr({A, B})\n   *\n   * C = [[0, 0, 0, 0],\n   *      [1, 0, 0, 1],\n   *      [0, 1, 1, 0],\n   *      [2, 1, 1, 2],\n   *      [1, 1, 1, 1]]\n   *      [2, 0, 0, 2]]\n   *\n   * D = [[1, 0, 0, 0],\n   *      [0, 0, 0, 0],\n   *      [0, 0, 0, 0],\n   *      [0, 0, 0, 0],\n   *      [0, 0, 0, 0],\n   *      [1, 0, 0, 1]]\n   *\n   * C = UnionCsr({A, B, D})\n   *\n   * C = [[1, 0, 0, 0],\n   *      [1, 0, 0, 1],\n   *      [0, 1, 1, 0],\n   *      [2, 1, 1, 2],\n   *      [1, 1, 1, 1]]\n   *      [3, 0, 0, 3]]\n   */\n  IdArray a_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 0, 0, 1, 5, 7, 9}), sizeof(IdType) * 8, CTX);\n  IdArray a_indices = aten::VecToIdArray(\n      std::vector<IdType>({1, 0, 1, 2, 3, 1, 2, 0, 3}), sizeof(IdType) * 8,\n      CTX);\n  IdArray b_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 0, 2, 3, 5, 7, 9}), sizeof(IdType) * 8, CTX);\n  IdArray b_indices = aten::VecToIdArray(\n      
std::vector<IdType>({0, 3, 2, 0, 3, 0, 3, 0, 3}), sizeof(IdType) * 8,\n      CTX);\n  IdArray c_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 0, 2, 4, 10, 14, 18}), sizeof(IdType) * 8, CTX);\n  IdArray c_indices = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 3, 1, 2, 0, 0, 1, 2, 3, 3, 0, 1, 2, 3, 0, 0, 3, 3}),\n      sizeof(IdType) * 8, CTX);\n  IdArray c_data = aten::VecToIdArray(\n      std::vector<IdType>(\n          {9, 10, 0, 11, 1, 12, 2, 3, 4, 13, 14, 5, 6, 15, 7, 16, 8, 17}),\n      sizeof(IdType) * 8, CTX);\n\n  const aten::CSRMatrix &csr_a =\n      aten::CSRMatrix(6, 4, a_indptr, a_indices, aten::NullArray(), true);\n  const aten::CSRMatrix &csr_b =\n      aten::CSRMatrix(6, 4, b_indptr, b_indices, aten::NullArray(), true);\n\n  const aten::CSRMatrix &csr_aUb = aten::UnionCsr({csr_a, csr_b});\n  ASSERT_EQ(csr_aUb.num_rows, 6);\n  ASSERT_EQ(csr_aUb.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb.indptr, c_indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb.indices, c_indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb.data, c_data));\n  ASSERT_TRUE(csr_aUb.sorted);\n\n  IdArray a_data = aten::VecToIdArray(\n      std::vector<IdType>({8, 7, 6, 5, 4, 3, 2, 1, 0}), sizeof(IdType) * 8,\n      CTX);\n\n  c_data = aten::VecToIdArray(\n      std::vector<IdType>(\n          {9, 10, 8, 11, 7, 12, 6, 5, 4, 13, 14, 3, 2, 15, 1, 16, 0, 17}),\n      sizeof(IdType) * 8, CTX);\n  const aten::CSRMatrix &csr_ad =\n      aten::CSRMatrix(6, 4, a_indptr, a_indices, a_data, true);\n  const aten::CSRMatrix &csr_adUb = aten::UnionCsr({csr_ad, csr_b});\n  ASSERT_EQ(csr_adUb.num_rows, 6);\n  ASSERT_EQ(csr_adUb.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_adUb.indptr, c_indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_adUb.indices, c_indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_adUb.data, c_data));\n  ASSERT_TRUE(csr_adUb.sorted);\n\n  IdArray b_indices2 = aten::VecToIdArray(\n      std::vector<IdType>({0, 3, 2, 0, 3, 3, 0, 0, 3}), sizeof(IdType) * 
8,\n      CTX);\n  c_indices = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 3, 1, 2, 0, 1, 2, 3, 0, 3, 1, 2, 3, 0, 0, 3, 0, 3}),\n      sizeof(IdType) * 8, CTX);\n  c_data = aten::VecToIdArray(\n      std::vector<IdType>(\n          {9, 10, 0, 11, 1, 2, 3, 4, 12, 13, 5, 6, 14, 15, 7, 8, 16, 17}),\n      sizeof(IdType) * 8, CTX);\n  const aten::CSRMatrix &csr_b2 =\n      aten::CSRMatrix(6, 4, b_indptr, b_indices2, aten::NullArray(), false);\n  const aten::CSRMatrix &csr_aUb2 = aten::UnionCsr({csr_a, csr_b2});\n  ASSERT_EQ(csr_aUb2.num_rows, 6);\n  ASSERT_EQ(csr_aUb2.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb2.indptr, c_indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb2.indices, c_indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb2.data, c_data));\n  ASSERT_FALSE(csr_aUb2.sorted);\n\n  IdArray a_indices2 = aten::VecToIdArray(\n      std::vector<IdType>({1, 3, 2, 1, 0, 1, 2, 0, 3}), sizeof(IdType) * 8,\n      CTX);\n  c_indices = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 3, 1, 2, 3, 2, 1, 0, 0, 3, 1, 2, 0, 3, 0, 3, 0, 3}),\n      sizeof(IdType) * 8, CTX);\n  const aten::CSRMatrix &csr_a2 =\n      aten::CSRMatrix(6, 4, a_indptr, a_indices2, aten::NullArray(), false);\n  const aten::CSRMatrix &csr_aUb3 = aten::UnionCsr({csr_a2, csr_b});\n  ASSERT_EQ(csr_aUb3.num_rows, 6);\n  ASSERT_EQ(csr_aUb3.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb3.indptr, c_indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb3.indices, c_indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb3.data, c_data));\n  ASSERT_FALSE(csr_aUb3.sorted);\n\n  c_indices = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 3, 1, 2, 3, 2, 1, 0, 0, 3, 1, 2, 3, 0, 0, 3, 0, 3}),\n      sizeof(IdType) * 8, CTX);\n  const aten::CSRMatrix &csr_aUb4 = aten::UnionCsr({csr_a2, csr_b2});\n  ASSERT_EQ(csr_aUb4.num_rows, 6);\n  ASSERT_EQ(csr_aUb4.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb4.indptr, c_indptr));\n  
ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb4.indices, c_indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb4.data, c_data));\n  ASSERT_FALSE(csr_aUb4.sorted);\n\n  IdArray d_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 1, 1, 1, 1, 3}), sizeof(IdType) * 8, CTX);\n  IdArray d_indices = aten::VecToIdArray(\n      std::vector<IdType>({0, 0, 3}), sizeof(IdType) * 8, CTX);\n  c_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 3, 5, 11, 15, 21}), sizeof(IdType) * 8, CTX);\n  c_indices = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 0, 3, 1, 2, 0, 0, 1, 2, 3, 3, 0, 1, 2, 3, 0, 0, 0, 3, 3, 3}),\n      sizeof(IdType) * 8, CTX);\n  c_data = aten::VecToIdArray(\n      std::vector<IdType>({18, 9, 10, 8,  11, 7,  12, 6, 5,  4, 13,\n                           14, 3, 2,  15, 1,  16, 19, 0, 17, 20}),\n      sizeof(IdType) * 8, CTX);\n  const aten::CSRMatrix &csr_d =\n      aten::CSRMatrix(6, 4, d_indptr, d_indices, aten::NullArray(), true);\n  const aten::CSRMatrix &csr_aUbUd = aten::UnionCsr({csr_ad, csr_b, csr_d});\n  ASSERT_EQ(csr_aUbUd.num_rows, 6);\n  ASSERT_EQ(csr_aUbUd.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd.indptr, c_indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd.indices, c_indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd.data, c_data));\n  ASSERT_TRUE(csr_aUbUd.sorted);\n\n  c_indices = aten::VecToIdArray(\n      std::vector<IdType>(\n          {0, 0, 3, 1, 2, 3, 2, 1, 0, 0, 3, 1, 2, 3, 0, 0, 3, 0, 3, 0, 3}),\n      sizeof(IdType) * 8, CTX);\n  c_data = aten::VecToIdArray(\n      std::vector<IdType>({18, 9, 10, 0,  11, 1, 2,  3,  4,  12, 13,\n                           5,  6, 14, 15, 7,  8, 16, 17, 19, 20}),\n      sizeof(IdType) * 8, CTX);\n\n  const aten::CSRMatrix &csr_aUbUd2 = aten::UnionCsr({csr_a2, csr_b2, csr_d});\n  ASSERT_EQ(csr_aUbUd2.num_rows, 6);\n  ASSERT_EQ(csr_aUbUd2.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd2.indptr, c_indptr));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd2.indices, 
c_indices));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd2.data, c_data));\n  ASSERT_FALSE(csr_aUbUd2.sorted);\n}\n\nTEST(MatrixUnionTest, TestMatrixUnionCsr) {\n  _TestMatrixUnionCsr<int32_t>(CPU);\n  _TestMatrixUnionCsr<int64_t>(CPU);\n}\n\ntemplate <typename IdType>\nvoid _TestMatrixUnionCoo(DGLContext ctx) {\n  /**\n   * A = [[0, 0, 0, 0],\n   *      [0, 0, 0, 0],\n   *      [0, 1, 0, 0],\n   *      [1, 1, 1, 1],\n   *      [0, 1, 1, 0],\n   *      [1, 0, 0, 1]]\n   *\n   * B = [[0, 0, 0, 0],\n   *      [1, 0, 0, 1],\n   *      [0, 0, 1, 0],\n   *      [1, 0, 0, 1],\n   *      [1, 0, 0, 1]]\n   *      [1, 0, 0, 1]]\n   *\n   * C = UnionCsr({A, B})\n   *\n   * C = [[0, 0, 0, 0],\n   *      [1, 0, 0, 1],\n   *      [0, 1, 1, 0],\n   *      [2, 1, 1, 2],\n   *      [1, 1, 1, 1]]\n   *      [2, 0, 0, 2]]\n   *\n   * D = [[1, 0, 0, 0],\n   *      [0, 0, 0, 0],\n   *      [0, 0, 0, 0],\n   *      [0, 0, 0, 0],\n   *      [0, 0, 0, 0],\n   *      [1, 0, 0, 1]]\n   *\n   * C = UnionCsr({A, B, D})\n   *\n   * C = [[1, 0, 0, 0],\n   *      [1, 0, 0, 1],\n   *      [0, 1, 1, 0],\n   *      [2, 1, 1, 2],\n   *      [1, 1, 1, 1]]\n   *      [3, 0, 0, 3]]\n   */\n  IdArray a_row = aten::VecToIdArray(\n      std::vector<IdType>({2, 3, 3, 3, 3, 4, 4, 5, 5}), sizeof(IdType) * 8,\n      CTX);\n  IdArray a_col = aten::VecToIdArray(\n      std::vector<IdType>({1, 0, 1, 2, 3, 1, 2, 0, 3}), sizeof(IdType) * 8,\n      CTX);\n  IdArray b_row = aten::VecToIdArray(\n      std::vector<IdType>({1, 1, 2, 3, 3, 4, 4, 5, 5}), sizeof(IdType) * 8,\n      CTX);\n  IdArray b_col = aten::VecToIdArray(\n      std::vector<IdType>({0, 3, 2, 0, 3, 0, 3, 0, 3}), sizeof(IdType) * 8,\n      CTX);\n  IdArray c_row = aten::VecToIdArray(\n      std::vector<IdType>(\n          {2, 3, 3, 3, 3, 4, 4, 5, 5, 1, 1, 2, 3, 3, 4, 4, 5, 5}),\n      sizeof(IdType) * 8, CTX);\n  IdArray c_col = aten::VecToIdArray(\n      std::vector<IdType>(\n          {1, 0, 1, 2, 3, 1, 2, 0, 3, 0, 3, 2, 0, 3, 0, 3, 0, 3}),\n      
sizeof(IdType) * 8, CTX);\n  const aten::COOMatrix &coo_a =\n      aten::COOMatrix(6, 4, a_row, a_col, aten::NullArray(), true, true);\n  const aten::COOMatrix &coo_b =\n      aten::COOMatrix(6, 4, b_row, b_col, aten::NullArray(), true, true);\n  const std::vector<aten::COOMatrix> coos_ab({coo_a, coo_b});\n  const aten::COOMatrix &coo_ab = aten::UnionCoo(coos_ab);\n  ASSERT_EQ(coo_ab.num_rows, 6);\n  ASSERT_EQ(coo_ab.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab.row, c_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab.col, c_col));\n  ASSERT_FALSE(COOHasData(coo_ab));\n  ASSERT_FALSE(coo_ab.row_sorted);\n  ASSERT_FALSE(coo_ab.col_sorted);\n\n  IdArray a_data = aten::VecToIdArray(\n      std::vector<IdType>({2, 1, 0, 3, 4, 5, 6, 7, 8}), sizeof(IdType) * 8,\n      CTX);\n\n  IdArray c_data = aten::VecToIdArray(\n      std::vector<IdType>(\n          {2, 1, 0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}),\n      sizeof(IdType) * 8, CTX);\n  const aten::COOMatrix &coo_a2 =\n      aten::COOMatrix(6, 4, a_row, a_col, a_data, true, true);\n  const std::vector<aten::COOMatrix> coos_ab2({coo_a2, coo_b});\n  const aten::COOMatrix &coo_ab2 = aten::UnionCoo(coos_ab2);\n  ASSERT_EQ(coo_ab2.num_rows, 6);\n  ASSERT_EQ(coo_ab2.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab2.row, c_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab2.col, c_col));\n  ASSERT_TRUE(COOHasData(coo_ab2));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab2.data, c_data));\n  ASSERT_FALSE(coo_ab2.row_sorted);\n  ASSERT_FALSE(coo_ab2.col_sorted);\n\n  IdArray b_data = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 2, 3, 4, 5, 6, 8, 7}), sizeof(IdType) * 8,\n      CTX);\n  c_data = aten::VecToIdArray(\n      std::vector<IdType>(\n          {2, 1, 0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 16}),\n      sizeof(IdType) * 8, CTX);\n  const aten::COOMatrix &coo_b2 =\n      aten::COOMatrix(6, 4, b_row, b_col, b_data, true, true);\n  const std::vector<aten::COOMatrix> coos_ab3({coo_a2, coo_b2});\n  
const aten::COOMatrix &coo_ab3 = aten::UnionCoo(coos_ab3);\n  ASSERT_EQ(coo_ab3.num_rows, 6);\n  ASSERT_EQ(coo_ab3.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab3.row, c_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab3.col, c_col));\n  ASSERT_TRUE(COOHasData(coo_ab3));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab3.data, c_data));\n  ASSERT_FALSE(coo_ab3.row_sorted);\n  ASSERT_FALSE(coo_ab3.col_sorted);\n\n  c_data = aten::VecToIdArray(\n      std::vector<IdType>(\n          {2, 1, 0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 16}),\n      sizeof(IdType) * 8, CTX);\n\n  const std::vector<aten::COOMatrix> coos_ab4({coo_a2, coo_b2});\n  const aten::COOMatrix &coo_ab4 = aten::UnionCoo(coos_ab4);\n  ASSERT_EQ(coo_ab4.num_rows, 6);\n  ASSERT_EQ(coo_ab4.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab4.row, c_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab4.col, c_col));\n  ASSERT_TRUE(COOHasData(coo_ab4));\n  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab4.data, c_data));\n  ASSERT_FALSE(coo_ab4.row_sorted);\n  ASSERT_FALSE(coo_ab4.col_sorted);\n\n  IdArray d_row = aten::VecToIdArray(\n      std::vector<IdType>({0, 5, 5}), sizeof(IdType) * 8, CTX);\n  IdArray d_col = aten::VecToIdArray(\n      std::vector<IdType>({0, 0, 3}), sizeof(IdType) * 8, CTX);\n  c_row = aten::VecToIdArray(\n      std::vector<IdType>(\n          {2, 3, 3, 3, 3, 4, 4, 5, 5, 1, 1, 2, 3, 3, 4, 4, 5, 5, 0, 5, 5}),\n      sizeof(IdType) * 8, CTX);\n  c_col = aten::VecToIdArray(\n      std::vector<IdType>(\n          {1, 0, 1, 2, 3, 1, 2, 0, 3, 0, 3, 2, 0, 3, 0, 3, 0, 3, 0, 0, 3}),\n      sizeof(IdType) * 8, CTX);\n\n  const aten::COOMatrix &coo_d =\n      aten::COOMatrix(6, 4, d_row, d_col, aten::NullArray(), true, true);\n  const aten::COOMatrix &csr_aUbUd = aten::UnionCoo({coo_a, coo_b, coo_d});\n  ASSERT_EQ(csr_aUbUd.num_rows, 6);\n  ASSERT_EQ(csr_aUbUd.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd.row, c_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd.col, c_col));\n  
ASSERT_FALSE(COOHasData(csr_aUbUd));\n  ASSERT_FALSE(csr_aUbUd.row_sorted);\n  ASSERT_FALSE(csr_aUbUd.col_sorted);\n\n  c_data = aten::VecToIdArray(\n      std::vector<IdType>({2,  1,  0,  3,  4,  5,  6,  7,  8,  9, 10,\n                           11, 12, 13, 14, 15, 17, 16, 18, 19, 20}),\n      sizeof(IdType) * 8, CTX);\n\n  const aten::COOMatrix &csr_aUbUd2 = aten::UnionCoo({coo_a2, coo_b2, coo_d});\n  ASSERT_EQ(csr_aUbUd2.num_rows, 6);\n  ASSERT_EQ(csr_aUbUd2.num_cols, 4);\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd2.row, c_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd2.col, c_col));\n  ASSERT_TRUE(COOHasData(csr_aUbUd2));\n  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd2.data, c_data));\n  ASSERT_FALSE(csr_aUbUd2.row_sorted);\n  ASSERT_FALSE(csr_aUbUd2.col_sorted);\n}\n\nTEST(MatrixUnionTest, TestMatrixUnionCoo) {\n  _TestMatrixUnionCoo<int32_t>(CPU);\n  _TestMatrixUnionCoo<int64_t>(CPU);\n}\n\ntemplate <typename IDX>\nvoid _TestCumSum(DGLContext ctx) {\n  IdArray a = aten::VecToIdArray(\n      std::vector<IDX>({8, 6, 7, 5, 3, 0, 9}), sizeof(IDX) * 8, ctx);\n  {\n    IdArray tb = aten::VecToIdArray(\n        std::vector<IDX>({8, 14, 21, 26, 29, 29, 38}), sizeof(IDX) * 8, ctx);\n    IdArray b = aten::CumSum(a);\n    ASSERT_TRUE(ArrayEQ<IDX>(b, tb));\n  }\n  {\n    IdArray tb = aten::VecToIdArray(\n        std::vector<IDX>({0, 8, 14, 21, 26, 29, 29, 38}), sizeof(IDX) * 8, ctx);\n    IdArray b = aten::CumSum(a, true);\n    ASSERT_TRUE(ArrayEQ<IDX>(b, tb));\n  }\n  a = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n  {\n    IdArray tb = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n    IdArray b = aten::CumSum(a);\n    ASSERT_TRUE(ArrayEQ<IDX>(b, tb));\n  }\n  {\n    IdArray tb = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n    IdArray b = aten::CumSum(a);\n    ASSERT_TRUE(ArrayEQ<IDX>(b, tb));\n  }\n}\n\nTEST(ArrayTest, CumSum) {\n  _TestCumSum<int32_t>(CPU);\n  _TestCumSum<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  
_TestCumSum<int32_t>(GPU);\n  _TestCumSum<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX, typename D>\nvoid _TestScatter_(DGLContext ctx) {\n  IdArray out = aten::Full(1, 10, 8 * sizeof(IDX), ctx);\n  IdArray idx =\n      aten::VecToIdArray(std::vector<IDX>({2, 3, 9}), sizeof(IDX) * 8, ctx);\n  IdArray val =\n      aten::VecToIdArray(std::vector<IDX>({-20, 30, 90}), sizeof(IDX) * 8, ctx);\n  aten::Scatter_(idx, val, out);\n  IdArray tout = aten::VecToIdArray(\n      std::vector<IDX>({1, 1, -20, 30, 1, 1, 1, 1, 1, 90}), sizeof(IDX) * 8,\n      ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(out, tout));\n}\n\nTEST(ArrayTest, Scatter_) {\n  _TestScatter_<int32_t, int32_t>(CPU);\n  _TestScatter_<int64_t, int32_t>(CPU);\n  _TestScatter_<int32_t, int64_t>(CPU);\n  _TestScatter_<int64_t, int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestScatter_<int32_t, int32_t>(GPU);\n  _TestScatter_<int64_t, int32_t>(GPU);\n  _TestScatter_<int32_t, int64_t>(GPU);\n  _TestScatter_<int64_t, int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestNonZero(DGLContext ctx) {\n  auto val = aten::VecToIdArray(\n      std::vector<IDX>({0, 1, 2, 0, -10, 0, 0, 23}), sizeof(IDX) * 8, ctx);\n  auto idx = aten::NonZero(val);\n  auto tidx = aten::VecToIdArray(std::vector<int64_t>({1, 2, 4, 7}), 64, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(idx, tidx));\n\n  val = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n  idx = aten::NonZero(val);\n  tidx = aten::VecToIdArray(std::vector<int64_t>({}), 64, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(idx, tidx));\n\n  val =\n      aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 0}), sizeof(IDX) * 8, ctx);\n  idx = aten::NonZero(val);\n  tidx = aten::VecToIdArray(std::vector<int64_t>({}), 64, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(idx, tidx));\n\n  val = aten::Full(1, 3, sizeof(IDX) * 8, ctx);\n  idx = aten::NonZero(val);\n  tidx = aten::VecToIdArray(std::vector<int64_t>({0, 1, 2}), 64, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(idx, tidx));\n}\n\nTEST(ArrayTest, NonZero) {\n  
_TestNonZero<int32_t>(CPU);\n  _TestNonZero<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestNonZero<int32_t>(GPU);\n  _TestNonZero<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IdType>\nvoid _TestLineGraphCOO(DGLContext ctx) {\n  /**\n   * A = [[0, 0, 1, 0],\n   *      [1, 0, 1, 0],\n   *      [1, 1, 0, 0],\n   *      [0, 0, 0, 1]]\n   * row: 0 1 1 2 2 3\n   * col: 2 0 2 0 1 3\n   * ID:  0 1 2 3 4 5\n   *\n   * B = COOLineGraph(A, backtracking=False)\n   *\n   * B = [[0, 0, 0, 0, 1, 0],\n   *      [1, 0, 0, 0, 0, 0],\n   *      [0, 0, 0, 1, 0, 0],\n   *      [0, 0, 0, 0, 0, 0],\n   *      [0, 1, 0, 0, 0, 0],\n   *      [0, 0, 0, 0, 0, 0]]\n   *\n   * C = COOLineGraph(A, backtracking=True)\n   *\n   * C = [[0, 0, 0, 1, 1, 0],\n   *      [1, 0, 0, 0, 0, 0],\n   *      [0, 0, 0, 1, 1, 0],\n   *      [1, 0, 0, 0, 0, 0],\n   *      [0, 1, 1, 0, 0, 0],\n   *      [0, 0, 0, 0, 0, 0]]\n   */\n  IdArray a_row = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 1, 2, 2, 3}), sizeof(IdType) * 8, ctx);\n  IdArray a_col = aten::VecToIdArray(\n      std::vector<IdType>({2, 0, 2, 0, 1, 3}), sizeof(IdType) * 8, ctx);\n  IdArray b_row = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 2, 4}), sizeof(IdType) * 8, ctx);\n  IdArray b_col = aten::VecToIdArray(\n      std::vector<IdType>({4, 0, 3, 1}), sizeof(IdType) * 8, ctx);\n  IdArray c_row = aten::VecToIdArray(\n      std::vector<IdType>({0, 0, 1, 2, 2, 3, 4, 4}), sizeof(IdType) * 8, ctx);\n  IdArray c_col = aten::VecToIdArray(\n      std::vector<IdType>({3, 4, 0, 3, 4, 0, 1, 2}), sizeof(IdType) * 8, ctx);\n\n  const aten::COOMatrix &coo_a =\n      aten::COOMatrix(4, 4, a_row, a_col, aten::NullArray(), true, false);\n\n  const aten::COOMatrix &l_coo = COOLineGraph(coo_a, false);\n  ASSERT_EQ(l_coo.num_rows, 6);\n  ASSERT_EQ(l_coo.num_cols, 6);\n  ASSERT_TRUE(ArrayEQ<IdType>(l_coo.row, b_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(l_coo.col, b_col));\n  ASSERT_FALSE(l_coo.row_sorted);\n  ASSERT_FALSE(l_coo.col_sorted);\n\n  const 
aten::COOMatrix &l_coo2 = COOLineGraph(coo_a, true);\n  ASSERT_EQ(l_coo2.num_rows, 6);\n  ASSERT_EQ(l_coo2.num_cols, 6);\n  ASSERT_TRUE(ArrayEQ<IdType>(l_coo2.row, c_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(l_coo2.col, c_col));\n  ASSERT_FALSE(l_coo2.row_sorted);\n  ASSERT_FALSE(l_coo2.col_sorted);\n\n  IdArray a_data = aten::VecToIdArray(\n      std::vector<IdType>({4, 5, 0, 1, 2, 3}), sizeof(IdType) * 8, ctx);\n  b_row = aten::VecToIdArray(\n      std::vector<IdType>({4, 5, 0, 2}), sizeof(IdType) * 8, ctx);\n  b_col = aten::VecToIdArray(\n      std::vector<IdType>({2, 4, 1, 5}), sizeof(IdType) * 8, ctx);\n  c_row = aten::VecToIdArray(\n      std::vector<IdType>({4, 4, 5, 0, 0, 1, 2, 2}), sizeof(IdType) * 8, ctx);\n  c_col = aten::VecToIdArray(\n      std::vector<IdType>({1, 2, 4, 1, 2, 4, 5, 0}), sizeof(IdType) * 8, ctx);\n  const aten::COOMatrix &coo_ad =\n      aten::COOMatrix(4, 4, a_row, a_col, a_data, true, false);\n  const aten::COOMatrix &ld_coo = COOLineGraph(coo_ad, false);\n  ASSERT_EQ(ld_coo.num_rows, 6);\n  ASSERT_EQ(ld_coo.num_cols, 6);\n  ASSERT_TRUE(ArrayEQ<IdType>(ld_coo.row, b_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(ld_coo.col, b_col));\n  ASSERT_FALSE(ld_coo.row_sorted);\n  ASSERT_FALSE(ld_coo.col_sorted);\n\n  const aten::COOMatrix &ld_coo2 = COOLineGraph(coo_ad, true);\n  ASSERT_EQ(ld_coo2.num_rows, 6);\n  ASSERT_EQ(ld_coo2.num_cols, 6);\n  ASSERT_TRUE(ArrayEQ<IdType>(ld_coo2.row, c_row));\n  ASSERT_TRUE(ArrayEQ<IdType>(ld_coo2.col, c_col));\n  ASSERT_FALSE(ld_coo2.row_sorted);\n  ASSERT_FALSE(ld_coo2.col_sorted);\n}\n\nTEST(LineGraphTest, LineGraphCOO) {\n  _TestLineGraphCOO<int32_t>(CPU);\n  _TestLineGraphCOO<int64_t>(CPU);\n}\n\ntemplate <typename IDX>\nvoid _TestSort(DGLContext ctx) {\n  // case 1\n  IdArray a = aten::VecToIdArray(\n      std::vector<IDX>({8, 6, 7, 5, 3, 0, 9}), sizeof(IDX) * 8, ctx);\n  IdArray sorted_a = aten::VecToIdArray(\n      std::vector<IDX>({0, 3, 5, 6, 7, 8, 9}), sizeof(IDX) * 8, ctx);\n  IdArray sorted_idx =\n      
aten::VecToIdArray(std::vector<IDX>({5, 4, 3, 1, 2, 0, 6}), 64, ctx);\n\n  IdArray sorted, idx;\n  std::tie(sorted, idx) = aten::Sort(a);\n  ASSERT_TRUE(ArrayEQ<IDX>(sorted, sorted_a));\n  ASSERT_TRUE(ArrayEQ<IDX>(idx, sorted_idx));\n\n  // case 2: empty array\n  a = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n  sorted_a = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n  sorted_idx = aten::VecToIdArray(std::vector<IDX>({}), 64, ctx);\n  std::tie(sorted, idx) = aten::Sort(a);\n  ASSERT_TRUE(ArrayEQ<IDX>(sorted, sorted_a));\n  ASSERT_TRUE(ArrayEQ<IDX>(idx, sorted_idx));\n\n  // case 3: array with one element\n  a = aten::VecToIdArray(std::vector<IDX>({2}), sizeof(IDX) * 8, ctx);\n  sorted_a = aten::VecToIdArray(std::vector<IDX>({2}), sizeof(IDX) * 8, ctx);\n  sorted_idx = aten::VecToIdArray(std::vector<IDX>({0}), 64, ctx);\n  std::tie(sorted, idx) = aten::Sort(a);\n  ASSERT_TRUE(ArrayEQ<IDX>(sorted, sorted_a));\n  ASSERT_TRUE(ArrayEQ<IDX>(idx, sorted_idx));\n}\n\nTEST(ArrayTest, Sort) {\n  _TestSort<int32_t>(CPU);\n  _TestSort<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestSort<int32_t>(GPU);\n  _TestSort<int64_t>(GPU);\n#endif\n}\n\nTEST(ArrayTest, BFloatCast) {\n  for (int i = -100; i < 100; ++i) {\n    float a = i;\n    BFloat16 b = a;\n    float a_casted = b;\n    ASSERT_FLOAT_EQ(a, a_casted);\n  }\n}\n"
  },
  {
    "path": "tests/cpp/test_concurrent_id_hash_map.cc",
    "content": "#include <dgl/array.h>\n#include <dgl/runtime/parallel_for.h>\n#include <gtest/gtest.h>\n\n#include <algorithm>\n#include <set>\n\n#include \"../../src/array/cpu/concurrent_id_hash_map.h\"\n#include \"./common.h\"\n\nusing namespace dgl;\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\nnamespace {\n\ntemplate <typename IdType>\nsize_t ConstructRandomSet(\n    size_t size, IdType range, std::vector<IdType>& id_vec) {\n  id_vec.resize(size);\n  std::srand(std::time(nullptr));\n  for (size_t i = 0; i < size; i++) {\n    id_vec[i] = static_cast<IdType>(std::rand() % range);\n  }\n\n  size_t num_seeds = size / 5 + 1;\n  std::sort(id_vec.begin(), id_vec.begin() + num_seeds);\n  return std::unique(id_vec.begin(), id_vec.begin() + num_seeds) -\n         id_vec.begin();\n}\n\ntemplate <typename IdType, size_t size, IdType range>\nvoid _TestIdMap() {\n  std::vector<IdType> id_vec;\n  auto num_seeds = ConstructRandomSet(size, range, id_vec);\n  std::set<IdType> id_set(id_vec.begin(), id_vec.end());\n  IdArray ids = VecToIdArray(id_vec, sizeof(IdType) * 8, CTX);\n  ConcurrentIdHashMap<IdType> id_map;\n  IdArray unique_ids = id_map.Init(ids, num_seeds);\n  auto unique_num = static_cast<size_t>(unique_ids->shape[0]);\n  IdType* unique_id_data = unique_ids.Ptr<IdType>();\n  EXPECT_EQ(id_set.size(), unique_num);\n\n  parallel_for(0, num_seeds, 64, [&](int64_t s, int64_t e) {\n    for (int64_t i = s; i < e; i++) {\n      EXPECT_EQ(id_vec[i], unique_id_data[i]);\n    }\n  });\n\n  parallel_for(num_seeds, unique_num, 128, [&](int64_t s, int64_t e) {\n    for (int64_t i = s; i < e; i++) {\n      EXPECT_TRUE(id_set.find(unique_id_data[i]) != id_set.end());\n    }\n  });\n\n  IdArray new_ids = id_map.MapIds(unique_ids);\n  EXPECT_TRUE(new_ids.IsContiguous());\n  ids->shape[0] = num_seeds;\n  IdArray new_seed_ids = id_map.MapIds(ids);\n  EXPECT_TRUE(new_seed_ids.IsContiguous());\n  EXPECT_EQ(new_seed_ids.Ptr<IdType>()[0], 
static_cast<IdType>(0));\n}\n\nTEST(ConcurrentIdHashMapTest, TestConcurrentIdHashMap) {\n  _TestIdMap<int32_t, 1, 10>();\n  _TestIdMap<int64_t, 1, 10>();\n  _TestIdMap<int32_t, 1000, 500000>();\n  _TestIdMap<int64_t, 1000, 500000>();\n  _TestIdMap<int32_t, 50000, 1000000>();\n  _TestIdMap<int64_t, 50000, 1000000>();\n  _TestIdMap<int32_t, 100000, 40000000>();\n  _TestIdMap<int64_t, 100000, 40000000>();\n}\n\n};  // namespace\n"
  },
  {
    "path": "tests/cpp/test_csrmm.cc",
    "content": "#include <dgl/array.h>\n#include <dgl/kernel.h>\n#include <gtest/gtest.h>\n\n#include \"../../src/array/cpu/array_utils.h\"  // PairHash\n#include \"./common.h\"\n\nusing namespace dgl;\nusing namespace dgl::runtime;\n\nnamespace {\n\n// Unit tests:\n// CSRMM(A, B) == A_mm_B\n// CSRSum({A, C}) == A_plus_C\n// CSRMask(A, C) = A_mask_C\n\ntemplate <typename IdType, typename DType>\nstd::unordered_map<std::pair<IdType, IdType>, DType, aten::PairHash> COOToMap(\n    aten::COOMatrix coo, NDArray weights) {\n  std::unordered_map<std::pair<IdType, IdType>, DType, aten::PairHash> map;\n\n  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {\n    IdType irow = aten::IndexSelect<IdType>(coo.row, i);\n    IdType icol = aten::IndexSelect<IdType>(coo.col, i);\n    IdType ieid =\n        aten::COOHasData(coo) ? aten::IndexSelect<IdType>(coo.data, i) : i;\n    DType idata = aten::IndexSelect<DType>(weights, ieid);\n    map.insert({{irow, icol}, idata});\n  }\n  return map;\n}\n\ntemplate <typename IdType, typename DType>\nbool CSRIsClose(\n    aten::CSRMatrix A, aten::CSRMatrix B, NDArray A_weights, NDArray B_weights,\n    DType rtol, DType atol) {\n  auto Amap = COOToMap<IdType, DType>(CSRToCOO(A, false), A_weights);\n  auto Bmap = COOToMap<IdType, DType>(CSRToCOO(B, false), B_weights);\n\n  if (Amap.size() != Bmap.size()) return false;\n\n  for (auto itA : Amap) {\n    auto itB = Bmap.find(itA.first);\n    if (itB == Bmap.end()) return false;\n    if (fabs(itA.second - itB->second) >= rtol * fabs(itA.second) + atol)\n      return false;\n  }\n\n  return true;\n}\n\ntemplate <typename IdType, typename DType>\nstd::pair<aten::CSRMatrix, NDArray> CSR_A(DGLContext ctx = CTX) {\n  // matrix([[0. , 0. , 1. , 0.7, 0. ],\n  //         [0. , 0. , 0.5, 0.+, 0. ],\n  //         [0.4, 0.7, 0. , 0.2, 0. ],\n  //         [0. , 0. , 0. , 0. 
, 0.2]])\n  // (0.+ indicates that the entry exists but the value is 0.)\n  auto csr = aten::CSRMatrix(\n      4, 5, NDArray::FromVector(std::vector<IdType>({0, 2, 4, 7, 8}), ctx),\n      NDArray::FromVector(std::vector<IdType>({2, 3, 2, 3, 0, 1, 3, 4}), ctx),\n      NDArray::FromVector(std::vector<IdType>({1, 0, 2, 3, 4, 5, 6, 7}), ctx));\n  auto weights = NDArray::FromVector(\n      std::vector<DType>({0.7, 1.0, 0.5, 0.0, 0.4, 0.7, 0.2, 0.2}), ctx);\n  return {csr, weights};\n}\n\ntemplate <typename IdType, typename DType>\nstd::pair<aten::CSRMatrix, NDArray> CSR_B(DGLContext ctx = CTX) {\n  // matrix([[0. , 0.9, 0. , 0.6, 0. , 0.3],\n  //         [0. , 0. , 0. , 0. , 0. , 0.4],\n  //         [0.+, 0. , 0. , 0. , 0. , 0.9],\n  //         [0.8, 0.2, 0.3, 0.2, 0. , 0. ],\n  //         [0.2, 0.4, 0. , 0. , 0. , 0. ]])\n  // (0.+ indicates that the entry exists but the value is 0.)\n  auto csr = aten::CSRMatrix(\n      5, 6, NDArray::FromVector(std::vector<IdType>({0, 3, 4, 6, 10, 12}), ctx),\n      NDArray::FromVector(\n          std::vector<IdType>({1, 3, 5, 5, 0, 5, 0, 1, 2, 3, 0, 1}), ctx));\n  auto weights = NDArray::FromVector(\n      std::vector<DType>(\n          {0.9, 0.6, 0.3, 0.4, 0.0, 0.9, 0.8, 0.2, 0.3, 0.2, 0.2, 0.4}),\n      ctx);\n  return {csr, weights};\n}\n\ntemplate <typename IdType, typename DType>\nstd::pair<aten::CSRMatrix, NDArray> CSR_C(DGLContext ctx = CTX) {\n  // matrix([[0. , 0. , 0. , 0.2, 0. ],\n  //         [0. , 0. , 0. , 0.5, 0.4],\n  //         [0. , 0.2, 0. , 0.9, 0.2],\n  //         [0. , 1. , 0. , 0.7, 0. 
]])\n  auto csr = aten::CSRMatrix(\n      4, 5, NDArray::FromVector(std::vector<IdType>({0, 1, 3, 6, 8}), ctx),\n      NDArray::FromVector(std::vector<IdType>({3, 3, 4, 1, 3, 4, 1, 3}), ctx));\n  auto weights = NDArray::FromVector(\n      std::vector<DType>({0.2, 0.5, 0.4, 0.2, 0.9, 0.2, 1., 0.7}), ctx);\n  return {csr, weights};\n}\n\ntemplate <typename IdType, typename DType>\nstd::pair<aten::CSRMatrix, NDArray> CSR_A_mm_B(DGLContext ctx = CTX) {\n  // matrix([[0.56, 0.14, 0.21, 0.14, 0.  , 0.9 ],\n  //         [0.+ , 0.+ , 0.+ , 0.+ , 0.  , 0.45],\n  //         [0.16, 0.4 , 0.06, 0.28, 0.  , 0.4 ],\n  //         [0.04, 0.08, 0.  , 0.  , 0.  , 0.  ]])\n  // (0.+ indicates that the entry exists but the value is 0.)\n  auto csr = aten::CSRMatrix(\n      4, 6, NDArray::FromVector(std::vector<IdType>({0, 5, 10, 15, 17}), ctx),\n      NDArray::FromVector(\n          std::vector<IdType>(\n              {0, 1, 2, 3, 5, 0, 1, 2, 3, 5, 0, 1, 2, 3, 5, 0, 1}),\n          ctx));\n  auto weights = NDArray::FromVector(\n      std::vector<DType>(\n          {0.56, 0.14, 0.21, 0.14, 0.9, 0., 0., 0., 0., 0.45, 0.16, 0.4, 0.06,\n           0.28, 0.4, 0.04, 0.08}),\n      ctx);\n  return {csr, weights};\n}\n\ntemplate <typename IdType, typename DType>\nstd::pair<aten::CSRMatrix, NDArray> CSR_A_plus_C(DGLContext ctx = CTX) {\n  auto csr = aten::CSRMatrix(\n      4, 5, NDArray::FromVector(std::vector<IdType>({0, 2, 5, 9, 12}), ctx),\n      NDArray::FromVector(\n          std::vector<IdType>({2, 3, 2, 3, 4, 0, 1, 3, 4, 1, 3, 4}), ctx));\n  auto weights = NDArray::FromVector(\n      std::vector<DType>(\n          {1., 0.9, 0.5, 0.5, 0.4, 0.4, 0.9, 1.1, 0.2, 1., 0.7, 0.2}),\n      ctx);\n  return {csr, weights};\n}\n\ntemplate <typename DType>\nNDArray CSR_A_mask_C(DGLContext ctx = CTX) {\n  return NDArray::FromVector(\n      std::vector<DType>({0.7, 0.0, 0.0, 0.7, 0.2, 0.0, 0.0, 0.0}), ctx);\n}\n\ntemplate <typename IdType, typename DType>\nvoid _TestCsrmm(DGLContext ctx = CTX) {\n  
auto A = CSR_A<IdType, DType>(ctx);\n  auto B = CSR_B<IdType, DType>(ctx);\n  auto A_mm_B = aten::CSRMM(A.first, A.second, B.first, B.second);\n  auto A_mm_B2 = CSR_A_mm_B<IdType, DType>(ctx);\n  bool result = CSRIsClose<IdType, DType>(\n      A_mm_B.first, A_mm_B2.first, A_mm_B.second, A_mm_B2.second, 1e-4, 1e-4);\n  ASSERT_TRUE(result);\n}\n\ntemplate <typename IdType, typename DType>\nvoid _TestCsrsum(DGLContext ctx = CTX) {\n  auto A = CSR_A<IdType, DType>(ctx);\n  auto C = CSR_C<IdType, DType>(ctx);\n  auto A_plus_C = aten::CSRSum({A.first, C.first}, {A.second, C.second});\n  auto A_plus_C2 = CSR_A_plus_C<IdType, DType>(ctx);\n  bool result = CSRIsClose<IdType, DType>(\n      A_plus_C.first, A_plus_C2.first, A_plus_C.second, A_plus_C2.second, 1e-4,\n      1e-4);\n  ASSERT_TRUE(result);\n}\n\ntemplate <typename IdType, typename DType>\nvoid _TestCsrmask(DGLContext ctx = CTX) {\n  auto A = CSR_A<IdType, DType>(ctx);\n  auto C = CSR_C<IdType, DType>(ctx);\n  auto C_coo = CSRToCOO(C.first, false);\n  auto A_mask_C =\n      aten::CSRGetData<DType>(A.first, C_coo.row, C_coo.col, A.second, 0);\n  auto A_mask_C2 = CSR_A_mask_C<DType>(ctx);\n  ASSERT_TRUE(ArrayEQ<DType>(A_mask_C, A_mask_C2));\n}\n\nTEST(CsrmmTest, TestCsrmm) {\n  _TestCsrmm<int32_t, float>(CPU);\n  _TestCsrmm<int32_t, double>(CPU);\n  _TestCsrmm<int64_t, float>(CPU);\n  _TestCsrmm<int64_t, double>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCsrmm<int32_t, float>(GPU);\n  _TestCsrmm<int32_t, double>(GPU);\n  _TestCsrmm<int64_t, float>(GPU);\n  _TestCsrmm<int64_t, double>(GPU);\n#endif\n}\n\nTEST(CsrmmTest, TestCsrsum) {\n  _TestCsrsum<int32_t, float>(CPU);\n  _TestCsrsum<int32_t, double>(CPU);\n  _TestCsrsum<int64_t, float>(CPU);\n  _TestCsrsum<int64_t, double>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCsrsum<int32_t, float>(GPU);\n  _TestCsrsum<int32_t, double>(GPU);\n  _TestCsrsum<int64_t, float>(GPU);\n  _TestCsrsum<int64_t, double>(GPU);\n#endif\n}\n\nTEST(CsrmmTest, TestCsrmask) {\n  _TestCsrmask<int32_t, 
float>(CPU);\n  _TestCsrmask<int32_t, double>(CPU);\n  _TestCsrmask<int64_t, float>(CPU);\n  _TestCsrmask<int64_t, double>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCsrmask<int32_t, float>(GPU);\n  _TestCsrmask<int32_t, double>(GPU);\n  _TestCsrmask<int64_t, float>(GPU);\n  _TestCsrmask<int64_t, double>(GPU);\n#endif\n}\n\n};  // namespace\n"
  },
  {
    "path": "tests/cpp/test_partition.cc",
    "content": "#include <gtest/gtest.h>\n\n#include \"../../src/partition/ndarray_partition.h\"\n#include \"./common.h\"\n\nusing namespace dgl;\nusing namespace dgl::partition;\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid _TestRemainder_GeneratePermutation() {\n  const int64_t size = 160000;\n  const int num_parts = 7;\n  NDArrayPartitionRef part = CreatePartitionRemainderBased(size, num_parts);\n\n  IdArray idxs =\n      aten::Range(0, size / 10, sizeof(IdType) * 8, DGLContext{XPU, 0});\n\n  std::pair<IdArray, IdArray> result = part->GeneratePermutation(idxs);\n\n  // first part of result should be the permutation\n  IdArray perm = result.first.CopyTo(DGLContext{kDGLCPU, 0});\n  ASSERT_TRUE(perm.Ptr<IdType>() != nullptr);\n  ASSERT_EQ(perm->shape[0], idxs->shape[0]);\n  const IdType* const perm_cpu = static_cast<const IdType*>(perm->data);\n\n  // second part of result should be the counts\n  IdArray counts = result.second.CopyTo(DGLContext{kDGLCPU, 0});\n  ASSERT_TRUE(counts.Ptr<int64_t>() != nullptr);\n  ASSERT_EQ(counts->shape[0], num_parts);\n  const int64_t* const counts_cpu = static_cast<const int64_t*>(counts->data);\n\n  std::vector<int64_t> prefix(num_parts + 1, 0);\n  for (int p = 0; p < num_parts; ++p) {\n    prefix[p + 1] = prefix[p] + counts_cpu[p];\n  }\n  ASSERT_EQ(prefix.back(), idxs->shape[0]);\n\n  // copy original indexes to cpu\n  idxs = idxs.CopyTo(DGLContext{kDGLCPU, 0});\n  const IdType* const idxs_cpu = static_cast<const IdType*>(idxs->data);\n\n  for (int p = 0; p < num_parts; ++p) {\n    for (int64_t i = prefix[p]; i < prefix[p + 1]; ++i) {\n      EXPECT_EQ(idxs_cpu[perm_cpu[i]] % num_parts, p);\n    }\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid _TestRemainder_MapToX() {\n  const int64_t size = 160000;\n  const int num_parts = 7;\n  NDArrayPartitionRef part = CreatePartitionRemainderBased(size, num_parts);\n\n  for (int part_id = 0; part_id < num_parts; ++part_id) {\n    IdArray local = aten::Range(\n        
0, part->PartSize(part_id), sizeof(IdType) * 8, DGLContext{XPU, 0});\n    IdArray global = part->MapToGlobal(local, part_id);\n    IdArray act_local = part->MapToLocal(global).CopyTo(CPU);\n\n    // every global index should have the same remainder as the part id\n    ASSERT_EQ(global->shape[0], local->shape[0]);\n    global = global.CopyTo(CPU);\n    for (int64_t i = 0; i < global->shape[0]; ++i) {\n      EXPECT_EQ(Ptr<IdType>(global)[i] % num_parts, part_id)\n          << \"i=\" << i << \", num_parts=\" << num_parts\n          << \", part_id=\" << part_id;\n    }\n\n    // the remapped local indices should match the original\n    local = local.CopyTo(CPU);\n    ASSERT_EQ(local->shape[0], act_local->shape[0]);\n    for (int64_t i = 0; i < act_local->shape[0]; ++i) {\n      EXPECT_EQ(Ptr<IdType>(local)[i], Ptr<IdType>(act_local)[i]);\n    }\n  }\n}\n\nTEST(PartitionTest, TestRemainderPartition) {\n#ifdef DGL_USE_CUDA\n  _TestRemainder_GeneratePermutation<kDGLCUDA, int32_t>();\n  _TestRemainder_GeneratePermutation<kDGLCUDA, int64_t>();\n\n  _TestRemainder_MapToX<kDGLCUDA, int32_t>();\n  _TestRemainder_MapToX<kDGLCUDA, int64_t>();\n#endif\n  // CPU is not implemented\n}\n\ntemplate <typename INDEX, typename RANGE>\nint _FindPart(const INDEX idx, const RANGE* const range, const int num_parts) {\n  for (int i = 0; i < num_parts; ++i) {\n    if (range[i + 1] > idx) {\n      return i;\n    }\n  }\n\n  return -1;\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid _TestRange_GeneratePermutation() {\n  const int64_t size = 160000;\n  const int num_parts = 7;\n  IdArray range = aten::NewIdArray(\n      num_parts + 1, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);\n  for (int i = 0; i < num_parts; ++i) {\n    range.Ptr<IdType>()[i] = (size / num_parts) * i;\n  }\n  range.Ptr<IdType>()[num_parts] = size;\n  NDArrayPartitionRef part = CreatePartitionRangeBased(\n      size, num_parts, range.CopyTo(DGLContext{XPU, 0}));\n\n  IdArray idxs =\n      aten::Range(0, size / 
10, sizeof(IdType) * 8, DGLContext{XPU, 0});\n\n  std::pair<IdArray, IdArray> result = part->GeneratePermutation(idxs);\n\n  // first part of result should be the permutation\n  IdArray perm = result.first.CopyTo(DGLContext{kDGLCPU, 0});\n  ASSERT_TRUE(perm.Ptr<IdType>() != nullptr);\n  ASSERT_EQ(perm->shape[0], idxs->shape[0]);\n  const IdType* const perm_cpu = static_cast<const IdType*>(perm->data);\n\n  // second part of result should be the counts\n  IdArray counts = result.second.CopyTo(DGLContext{kDGLCPU, 0});\n  ASSERT_TRUE(counts.Ptr<int64_t>() != nullptr);\n  ASSERT_EQ(counts->shape[0], num_parts);\n  const int64_t* const counts_cpu = static_cast<const int64_t*>(counts->data);\n\n  std::vector<int64_t> prefix(num_parts + 1, 0);\n  for (int p = 0; p < num_parts; ++p) {\n    prefix[p + 1] = prefix[p] + counts_cpu[p];\n  }\n  ASSERT_EQ(prefix.back(), idxs->shape[0]);\n\n  // copy original indexes to cpu\n  idxs = idxs.CopyTo(DGLContext{kDGLCPU, 0});\n  const IdType* const idxs_cpu = static_cast<const IdType*>(idxs->data);\n\n  for (int p = 0; p < num_parts; ++p) {\n    for (int64_t i = prefix[p]; i < prefix[p + 1]; ++i) {\n      EXPECT_EQ(\n          _FindPart(idxs_cpu[perm_cpu[i]], range.Ptr<IdType>(), num_parts), p);\n    }\n  }\n}\n\ntemplate <DGLDeviceType XPU, typename IdType>\nvoid _TestRange_MapToX() {\n  const int64_t size = 160000;\n  const int num_parts = 7;\n  IdArray range = aten::NewIdArray(\n      num_parts + 1, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);\n  for (int i = 0; i < num_parts; ++i) {\n    Ptr<IdType>(range)[i] = (size / num_parts) * i;\n  }\n  range.Ptr<IdType>()[num_parts] = size;\n  NDArrayPartitionRef part = CreatePartitionRangeBased(\n      size, num_parts, range.CopyTo(DGLContext{XPU, 0}));\n\n  for (int part_id = 0; part_id < num_parts; ++part_id) {\n    IdArray local = aten::Range(\n        0, part->PartSize(part_id), sizeof(IdType) * 8, DGLContext{XPU, 0});\n    IdArray global = part->MapToGlobal(local, part_id);\n    
IdArray act_local = part->MapToLocal(global).CopyTo(CPU);\n\n    ASSERT_EQ(global->shape[0], local->shape[0]);\n    global = global.CopyTo(CPU);\n    for (int64_t i = 0; i < global->shape[0]; ++i) {\n      EXPECT_EQ(\n          _FindPart(Ptr<IdType>(global)[i], Ptr<IdType>(range), num_parts),\n          part_id)\n          << \"i=\" << i << \", num_parts=\" << num_parts << \", part_id=\" << part_id\n          << \", shape=\" << global->shape[0];\n    }\n\n    // the remapped local indices should match the original\n    local = local.CopyTo(CPU);\n    ASSERT_EQ(local->shape[0], act_local->shape[0]);\n    for (int64_t i = 0; i < act_local->shape[0]; ++i) {\n      EXPECT_EQ(Ptr<IdType>(local)[i], Ptr<IdType>(act_local)[i]);\n    }\n  }\n}\n\nTEST(PartitionTest, TestRangePartition) {\n#ifdef DGL_USE_CUDA\n  _TestRange_GeneratePermutation<kDGLCUDA, int32_t>();\n  _TestRange_GeneratePermutation<kDGLCUDA, int64_t>();\n\n  _TestRange_MapToX<kDGLCUDA, int32_t>();\n  _TestRange_MapToX<kDGLCUDA, int64_t>();\n#endif\n  // CPU is not implemented\n}\n"
  },
  {
    "path": "tests/cpp/test_rowwise.cc",
    "content": "#include <dgl/array.h>\n#include <gtest/gtest.h>\n\n#include <set>\n#include <tuple>\n\n#include \"./common.h\"\n\nusing namespace dgl;\nusing namespace dgl::runtime;\nusing namespace dgl::aten;\n\ntemplate <typename Idx>\nusing ETuple = std::tuple<Idx, Idx, Idx>;\n\ntemplate <typename Idx>\nstd::set<ETuple<Idx>> AllEdgeSet(bool has_data) {\n  if (has_data) {\n    std::set<ETuple<Idx>> eset;\n    eset.insert(ETuple<Idx>{0, 0, 2});\n    eset.insert(ETuple<Idx>{0, 1, 3});\n    eset.insert(ETuple<Idx>{1, 1, 0});\n    eset.insert(ETuple<Idx>{3, 2, 1});\n    eset.insert(ETuple<Idx>{3, 3, 4});\n    return eset;\n  } else {\n    std::set<ETuple<Idx>> eset;\n    eset.insert(ETuple<Idx>{0, 0, 0});\n    eset.insert(ETuple<Idx>{0, 1, 1});\n    eset.insert(ETuple<Idx>{1, 1, 2});\n    eset.insert(ETuple<Idx>{3, 2, 3});\n    eset.insert(ETuple<Idx>{3, 3, 4});\n    return eset;\n  }\n}\n\ntemplate <typename Idx>\nstd::set<ETuple<Idx>> AllEdgePerEtypeSet(bool has_data) {\n  if (has_data) {\n    std::set<ETuple<Idx>> eset;\n    eset.insert(ETuple<Idx>{0, 0, 0});\n    eset.insert(ETuple<Idx>{0, 1, 1});\n    eset.insert(ETuple<Idx>{0, 2, 4});\n    eset.insert(ETuple<Idx>{0, 3, 6});\n    eset.insert(ETuple<Idx>{3, 2, 5});\n    eset.insert(ETuple<Idx>{3, 3, 3});\n    return eset;\n  } else {\n    std::set<ETuple<Idx>> eset;\n    eset.insert(ETuple<Idx>{0, 0, 0});\n    eset.insert(ETuple<Idx>{0, 1, 1});\n    eset.insert(ETuple<Idx>{0, 2, 2});\n    eset.insert(ETuple<Idx>{0, 3, 3});\n    eset.insert(ETuple<Idx>{3, 3, 5});\n    eset.insert(ETuple<Idx>{3, 2, 6});\n    return eset;\n  }\n}\n\ntemplate <typename Idx>\nstd::set<ETuple<Idx>> ToEdgeSet(COOMatrix mat) {\n  std::set<ETuple<Idx>> eset;\n  Idx* row = static_cast<Idx*>(mat.row->data);\n  Idx* col = static_cast<Idx*>(mat.col->data);\n  Idx* data = static_cast<Idx*>(mat.data->data);\n  for (int64_t i = 0; i < mat.row->shape[0]; ++i) {\n    // std::cout << row[i] << \" \" << col[i] <<  \" \" << data[i] << std::endl;\n   
 eset.emplace(row[i], col[i], data[i]);\n  }\n  return eset;\n}\n\ntemplate <typename Idx>\nvoid CheckSampledResult(COOMatrix mat, IdArray rows, bool has_data) {\n  ASSERT_EQ(mat.num_rows, 4);\n  ASSERT_EQ(mat.num_cols, 4);\n  Idx* row = static_cast<Idx*>(mat.row->data);\n  Idx* col = static_cast<Idx*>(mat.col->data);\n  Idx* data = static_cast<Idx*>(mat.data->data);\n  const auto& gt = AllEdgeSet<Idx>(has_data);\n  for (int64_t i = 0; i < mat.row->shape[0]; ++i) {\n    ASSERT_TRUE(gt.count(std::make_tuple(row[i], col[i], data[i])));\n    ASSERT_TRUE(IsInArray(rows, row[i]));\n  }\n}\n\ntemplate <typename Idx>\nvoid CheckSampledPerEtypeResult(COOMatrix mat, IdArray rows, bool has_data) {\n  ASSERT_EQ(mat.num_rows, 4);\n  ASSERT_EQ(mat.num_cols, 4);\n  Idx* row = static_cast<Idx*>(mat.row->data);\n  Idx* col = static_cast<Idx*>(mat.col->data);\n  Idx* data = static_cast<Idx*>(mat.data->data);\n  const auto& gt = AllEdgePerEtypeSet<Idx>(has_data);\n  for (int64_t i = 0; i < mat.row->shape[0]; ++i) {\n    int64_t count = gt.count(std::make_tuple(row[i], col[i], data[i]));\n    ASSERT_TRUE(count);\n    ASSERT_TRUE(IsInArray(rows, row[i]));\n  }\n}\n\ntemplate <typename Idx>\nCSRMatrix CSR(bool has_data) {\n  IdArray indptr = NDArray::FromVector(std::vector<Idx>({0, 2, 3, 3, 5}));\n  IdArray indices = NDArray::FromVector(std::vector<Idx>({0, 1, 1, 2, 3}));\n  IdArray data = NDArray::FromVector(std::vector<Idx>({2, 3, 0, 1, 4}));\n  if (has_data)\n    return CSRMatrix(4, 4, indptr, indices, data);\n  else\n    return CSRMatrix(4, 4, indptr, indices);\n}\n\ntemplate <typename Idx>\nCOOMatrix COO(bool has_data) {\n  IdArray row = NDArray::FromVector(std::vector<Idx>({0, 0, 1, 3, 3}));\n  IdArray col = NDArray::FromVector(std::vector<Idx>({0, 1, 1, 2, 3}));\n  IdArray data = NDArray::FromVector(std::vector<Idx>({2, 3, 0, 1, 4}));\n  if (has_data)\n    return COOMatrix(4, 4, row, col, data);\n  else\n    return COOMatrix(4, 4, row, col);\n}\n\ntemplate <typename 
Idx>\nstd::pair<CSRMatrix, std::vector<int64_t>> CSREtypes(bool has_data) {\n  IdArray indptr = NDArray::FromVector(std::vector<Idx>({0, 4, 5, 5, 7}));\n  IdArray indices =\n      NDArray::FromVector(std::vector<Idx>({0, 1, 2, 3, 1, 3, 2}));\n  IdArray data = NDArray::FromVector(std::vector<Idx>({0, 1, 4, 6, 2, 3, 5}));\n  auto eid2etype_offsets = std::vector<int64_t>({0, 4, 5, 6, 7});\n  if (has_data)\n    return {CSRMatrix(4, 4, indptr, indices, data), eid2etype_offsets};\n  else\n    return {CSRMatrix(4, 4, indptr, indices), eid2etype_offsets};\n}\n\ntemplate <typename Idx>\nstd::pair<COOMatrix, std::vector<int64_t>> COOEtypes(bool has_data) {\n  IdArray row = NDArray::FromVector(std::vector<Idx>({0, 0, 0, 0, 1, 3, 3}));\n  IdArray col = NDArray::FromVector(std::vector<Idx>({0, 1, 2, 3, 1, 3, 2}));\n  IdArray data = NDArray::FromVector(std::vector<Idx>({0, 1, 4, 6, 2, 3, 5}));\n  auto eid2etype_offsets = std::vector<int64_t>({0, 4, 5, 6, 7});\n  if (has_data)\n    return {COOMatrix(4, 4, row, col, data), eid2etype_offsets};\n  else\n    return {COOMatrix(4, 4, row, col), eid2etype_offsets};\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCSRSampling(bool has_data) {\n  auto mat = CSR<Idx>(has_data);\n  FloatArray prob =\n      NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5, .5}));\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n  }\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWiseSampling(mat, rows, 2, prob, false);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    ASSERT_EQ(eset.size(), 4);\n    if (has_data) {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));\n      
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    } else {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    }\n  }\n  prob = NDArray::FromVector(std::vector<FloatType>({.0, .5, .5, .0, .5}));\n  for (int k = 0; k < 100; ++k) {\n    auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 1, 3)));\n    } else {\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));\n      ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 3)));\n    }\n  }\n}\n\nTEST(RowwiseTest, TestCSRSampling) {\n  _TestCSRSampling<int32_t, float>(true);\n  _TestCSRSampling<int64_t, float>(true);\n  _TestCSRSampling<int32_t, double>(true);\n  _TestCSRSampling<int64_t, double>(true);\n  _TestCSRSampling<int32_t, float>(false);\n  _TestCSRSampling<int64_t, float>(false);\n  _TestCSRSampling<int32_t, double>(false);\n  _TestCSRSampling<int64_t, double>(false);\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCSRSamplingUniform(bool has_data) {\n  auto mat = CSR<Idx>(has_data);\n  FloatArray prob = aten::NullArray();\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n  }\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWiseSampling(mat, rows, 2, prob, false);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));\n      
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    } else {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    }\n  }\n}\n\nTEST(RowwiseTest, TestCSRSamplingUniform) {\n  _TestCSRSamplingUniform<int32_t, float>(true);\n  _TestCSRSamplingUniform<int64_t, float>(true);\n  _TestCSRSamplingUniform<int32_t, double>(true);\n  _TestCSRSamplingUniform<int64_t, double>(true);\n  _TestCSRSamplingUniform<int32_t, float>(false);\n  _TestCSRSamplingUniform<int64_t, float>(false);\n  _TestCSRSamplingUniform<int32_t, double>(false);\n  _TestCSRSamplingUniform<int64_t, double>(false);\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCSRPerEtypeSampling(bool has_data) {\n  auto pair = CSREtypes<Idx>(has_data);\n  auto mat = pair.first;\n  auto eid2etype_offset = pair.second;\n  std::vector<FloatArray> prob = {\n      NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5}))};\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);\n  }\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      int counts = 0;\n      counts += eset.count(std::make_tuple(0, 0, 0));\n      counts += eset.count(std::make_tuple(0, 1, 1));\n      ASSERT_EQ(counts, 2);\n      counts = 0;\n      counts += 
eset.count(std::make_tuple(0, 2, 4));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(0, 3, 6));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(1, 1, 2));\n      ASSERT_EQ(counts, 0);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 2, 5));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 3, 3));\n      ASSERT_EQ(counts, 1);\n    } else {\n      int counts = 0;\n      counts += eset.count(std::make_tuple(0, 0, 0));\n      counts += eset.count(std::make_tuple(0, 1, 1));\n      counts += eset.count(std::make_tuple(0, 2, 2));\n      counts += eset.count(std::make_tuple(0, 3, 3));\n      ASSERT_EQ(counts, 2);\n      counts = 0;\n      counts += eset.count(std::make_tuple(1, 1, 4));\n      ASSERT_EQ(counts, 0);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 3, 5));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 2, 6));\n      ASSERT_EQ(counts, 1);\n    }\n  }\n\n  prob = {\n      NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5}))};\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));\n    } else {\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 2)));\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 3, 3)));\n    }\n  }\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCSRPerEtypeSamplingSorted() {\n  auto pair = CSREtypes<Idx>(true);\n  
auto mat = pair.first;\n  auto eid2etype_offset = pair.second;\n  std::vector<FloatArray> prob = {\n      NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5}))};\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, true);\n  }\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, true);\n    auto eset = ToEdgeSet<Idx>(rst);\n    int counts = 0;\n    counts += eset.count(std::make_tuple(0, 0, 0));\n    counts += eset.count(std::make_tuple(0, 1, 1));\n    ASSERT_EQ(counts, 2);\n    counts = 0;\n    counts += eset.count(std::make_tuple(0, 2, 4));\n    ASSERT_EQ(counts, 1);\n    counts = 0;\n    counts += eset.count(std::make_tuple(0, 3, 6));\n    ASSERT_EQ(counts, 1);\n    counts = 0;\n    counts += eset.count(std::make_tuple(1, 1, 2));\n    ASSERT_EQ(counts, 0);\n    counts = 0;\n    counts += eset.count(std::make_tuple(3, 2, 5));\n    ASSERT_EQ(counts, 1);\n    counts = 0;\n    counts += eset.count(std::make_tuple(3, 3, 3));\n    ASSERT_EQ(counts, 1);\n  }\n\n  prob = {\n      NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5}))};\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, true);\n    auto eset = 
ToEdgeSet<Idx>(rst);\n    ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));\n  }\n}\n\nTEST(RowwiseTest, TestCSRPerEtypeSampling) {\n  _TestCSRPerEtypeSampling<int32_t, float>(true);\n  _TestCSRPerEtypeSampling<int64_t, float>(true);\n  _TestCSRPerEtypeSampling<int32_t, double>(true);\n  _TestCSRPerEtypeSampling<int64_t, double>(true);\n  _TestCSRPerEtypeSampling<int32_t, float>(false);\n  _TestCSRPerEtypeSampling<int64_t, float>(false);\n  _TestCSRPerEtypeSampling<int32_t, double>(false);\n  _TestCSRPerEtypeSampling<int64_t, double>(false);\n  _TestCSRPerEtypeSamplingSorted<int32_t, float>();\n  _TestCSRPerEtypeSamplingSorted<int64_t, float>();\n  _TestCSRPerEtypeSamplingSorted<int32_t, double>();\n  _TestCSRPerEtypeSamplingSorted<int64_t, double>();\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCSRPerEtypeSamplingUniform(bool has_data) {\n  auto pair = CSREtypes<Idx>(has_data);\n  auto mat = pair.first;\n  auto eid2etype_offset = pair.second;\n  std::vector<FloatArray> prob = {\n      aten::NullArray(), aten::NullArray(), aten::NullArray(),\n      aten::NullArray()};\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);\n  }\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      int counts = 0;\n      counts += eset.count(std::make_tuple(0, 0, 0));\n      counts += eset.count(std::make_tuple(0, 1, 1));\n      ASSERT_EQ(counts, 2);\n      counts = 0;\n      counts += eset.count(std::make_tuple(0, 2, 4));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(0, 3, 6));\n      
ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(1, 1, 2));\n      ASSERT_EQ(counts, 0);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 2, 5));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 3, 3));\n      ASSERT_EQ(counts, 1);\n    } else {\n      int counts = 0;\n      counts += eset.count(std::make_tuple(0, 0, 0));\n      counts += eset.count(std::make_tuple(0, 1, 1));\n      counts += eset.count(std::make_tuple(0, 2, 2));\n      counts += eset.count(std::make_tuple(0, 3, 3));\n      ASSERT_EQ(counts, 2);\n      counts = 0;\n      counts += eset.count(std::make_tuple(1, 1, 4));\n      ASSERT_EQ(counts, 0);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 3, 5));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 2, 6));\n      ASSERT_EQ(counts, 1);\n    }\n  }\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCSRPerEtypeSamplingUniformSorted() {\n  auto pair = CSREtypes<Idx>(true);\n  auto mat = pair.first;\n  auto eid2etype_offset = pair.second;\n  std::vector<FloatArray> prob = {\n      aten::NullArray(), aten::NullArray(), aten::NullArray(),\n      aten::NullArray()};\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, true);\n  }\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, true);\n    auto eset = ToEdgeSet<Idx>(rst);\n    int counts = 0;\n    counts += eset.count(std::make_tuple(0, 0, 0));\n    counts += eset.count(std::make_tuple(0, 1, 1));\n    ASSERT_EQ(counts, 2);\n    counts = 0;\n    counts += eset.count(std::make_tuple(0, 
2, 4));\n    ASSERT_EQ(counts, 1);\n    counts = 0;\n    counts += eset.count(std::make_tuple(0, 3, 6));\n    ASSERT_EQ(counts, 1);\n    counts = 0;\n    counts += eset.count(std::make_tuple(1, 1, 2));\n    ASSERT_EQ(counts, 0);\n    counts = 0;\n    counts += eset.count(std::make_tuple(3, 2, 5));\n    ASSERT_EQ(counts, 1);\n    counts = 0;\n    counts += eset.count(std::make_tuple(3, 3, 3));\n    ASSERT_EQ(counts, 1);\n  }\n}\n\nTEST(RowwiseTest, TestCSRPerEtypeSamplingUniform) {\n  _TestCSRPerEtypeSamplingUniform<int32_t, float>(true);\n  _TestCSRPerEtypeSamplingUniform<int64_t, float>(true);\n  _TestCSRPerEtypeSamplingUniform<int32_t, double>(true);\n  _TestCSRPerEtypeSamplingUniform<int64_t, double>(true);\n  _TestCSRPerEtypeSamplingUniform<int32_t, float>(false);\n  _TestCSRPerEtypeSamplingUniform<int64_t, float>(false);\n  _TestCSRPerEtypeSamplingUniform<int32_t, double>(false);\n  _TestCSRPerEtypeSamplingUniform<int64_t, double>(false);\n  _TestCSRPerEtypeSamplingUniformSorted<int32_t, float>();\n  _TestCSRPerEtypeSamplingUniformSorted<int64_t, float>();\n  _TestCSRPerEtypeSamplingUniformSorted<int32_t, double>();\n  _TestCSRPerEtypeSamplingUniformSorted<int64_t, double>();\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCOOSampling(bool has_data) {\n  auto mat = COO<Idx>(has_data);\n  FloatArray prob =\n      NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5, .5}));\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n  for (int k = 0; k < 10; ++k) {\n    auto rst = COORowWiseSampling(mat, rows, 2, prob, true);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n  }\n  for (int k = 0; k < 10; ++k) {\n    auto rst = COORowWiseSampling(mat, rows, 2, prob, false);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    ASSERT_EQ(eset.size(), 4);\n    if (has_data) {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));\n      
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    } else {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    }\n  }\n  prob = NDArray::FromVector(std::vector<FloatType>({.0, .5, .5, .0, .5}));\n  for (int k = 0; k < 100; ++k) {\n    auto rst = COORowWiseSampling(mat, rows, 2, prob, true);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 1, 3)));\n    } else {\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));\n      ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 3)));\n    }\n  }\n}\n\nTEST(RowwiseTest, TestCOOSampling) {\n  _TestCOOSampling<int32_t, float>(true);\n  _TestCOOSampling<int64_t, float>(true);\n  _TestCOOSampling<int32_t, double>(true);\n  _TestCOOSampling<int64_t, double>(true);\n  _TestCOOSampling<int32_t, float>(false);\n  _TestCOOSampling<int64_t, float>(false);\n  _TestCOOSampling<int32_t, double>(false);\n  _TestCOOSampling<int64_t, double>(false);\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCOOSamplingUniform(bool has_data) {\n  auto mat = COO<Idx>(has_data);\n  FloatArray prob = aten::NullArray();\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n  for (int k = 0; k < 10; ++k) {\n    auto rst = COORowWiseSampling(mat, rows, 2, prob, true);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n  }\n  for (int k = 0; k < 10; ++k) {\n    auto rst = COORowWiseSampling(mat, rows, 2, prob, false);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));\n      
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    } else {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    }\n  }\n}\n\nTEST(RowwiseTest, TestCOOSamplingUniform) {\n  _TestCOOSamplingUniform<int32_t, float>(true);\n  _TestCOOSamplingUniform<int64_t, float>(true);\n  _TestCOOSamplingUniform<int32_t, double>(true);\n  _TestCOOSamplingUniform<int64_t, double>(true);\n  _TestCOOSamplingUniform<int32_t, float>(false);\n  _TestCOOSamplingUniform<int64_t, float>(false);\n  _TestCOOSamplingUniform<int32_t, double>(false);\n  _TestCOOSamplingUniform<int64_t, double>(false);\n}\n\n// COOPerEtypeSampling with rowwise_etype_sorted == true is not meaningful as\n// it's never used in practice.\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCOOPerEtypeSampling(bool has_data) {\n  auto pair = COOEtypes<Idx>(has_data);\n  auto mat = pair.first;\n  auto eid2etype_offset = pair.second;\n  std::vector<FloatArray> prob = {\n      NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5}))};\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n  for (int k = 0; k < 10; ++k) {\n    auto rst = COORowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);\n  }\n  for (int k = 0; k < 10; ++k) {\n    auto rst = COORowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      int counts = 0;\n      counts += 
eset.count(std::make_tuple(0, 0, 0));\n      counts += eset.count(std::make_tuple(0, 1, 1));\n      ASSERT_EQ(counts, 2);\n      counts = 0;\n      counts += eset.count(std::make_tuple(0, 2, 4));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(0, 3, 6));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(1, 1, 2));\n      ASSERT_EQ(counts, 0);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 2, 5));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 3, 3));\n      ASSERT_EQ(counts, 1);\n    } else {\n      int counts = 0;\n      counts += eset.count(std::make_tuple(0, 0, 0));\n      counts += eset.count(std::make_tuple(0, 1, 1));\n      counts += eset.count(std::make_tuple(0, 2, 2));\n      counts += eset.count(std::make_tuple(0, 3, 3));\n      ASSERT_EQ(counts, 2);\n      counts = 0;\n      counts += eset.count(std::make_tuple(1, 1, 4));\n      ASSERT_EQ(counts, 0);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 3, 5));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 2, 6));\n      ASSERT_EQ(counts, 1);\n    }\n  }\n\n  prob = {\n      NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5})),\n      NDArray::FromVector(std::vector<FloatType>({.5}))};\n  for (int k = 0; k < 10; ++k) {\n    auto rst = COORowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));\n    } else {\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 2)));\n      
ASSERT_FALSE(eset.count(std::make_tuple(0, 3, 3)));\n    }\n  }\n}\n\nTEST(RowwiseTest, TestCOOPerEtypeSampling) {\n  _TestCOOPerEtypeSampling<int32_t, float>(true);\n  _TestCOOPerEtypeSampling<int64_t, float>(true);\n  _TestCOOPerEtypeSampling<int32_t, double>(true);\n  _TestCOOPerEtypeSampling<int64_t, double>(true);\n  _TestCOOPerEtypeSampling<int32_t, float>(false);\n  _TestCOOPerEtypeSampling<int64_t, float>(false);\n  _TestCOOPerEtypeSampling<int32_t, double>(false);\n  _TestCOOPerEtypeSampling<int64_t, double>(false);\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCOOPerEtypeSamplingUniform(bool has_data) {\n  auto pair = COOEtypes<Idx>(has_data);\n  auto mat = pair.first;\n  auto eid2etype_offset = pair.second;\n  std::vector<FloatArray> prob = {\n      aten::NullArray(), aten::NullArray(), aten::NullArray(),\n      aten::NullArray()};\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n  for (int k = 0; k < 10; ++k) {\n    auto rst = COORowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);\n  }\n  for (int k = 0; k < 10; ++k) {\n    auto rst = COORowWisePerEtypeSampling(\n        mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);\n    CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      int counts = 0;\n      counts += eset.count(std::make_tuple(0, 0, 0));\n      counts += eset.count(std::make_tuple(0, 1, 1));\n      ASSERT_EQ(counts, 2);\n      counts = 0;\n      counts += eset.count(std::make_tuple(0, 2, 4));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(0, 3, 6));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(1, 1, 2));\n      ASSERT_EQ(counts, 0);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 2, 5));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n  
    counts += eset.count(std::make_tuple(3, 3, 3));\n      ASSERT_EQ(counts, 1);\n    } else {\n      int counts = 0;\n      counts += eset.count(std::make_tuple(0, 0, 0));\n      counts += eset.count(std::make_tuple(0, 1, 1));\n      counts += eset.count(std::make_tuple(0, 2, 2));\n      counts += eset.count(std::make_tuple(0, 3, 3));\n      ASSERT_EQ(counts, 2);\n      counts = 0;\n      counts += eset.count(std::make_tuple(1, 1, 4));\n      ASSERT_EQ(counts, 0);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 3, 5));\n      ASSERT_EQ(counts, 1);\n      counts = 0;\n      counts += eset.count(std::make_tuple(3, 2, 6));\n      ASSERT_EQ(counts, 1);\n    }\n  }\n}\n\nTEST(RowwiseTest, TestCOOPerEtypeSamplingUniform) {\n  _TestCOOPerEtypeSamplingUniform<int32_t, float>(true);\n  _TestCOOPerEtypeSamplingUniform<int64_t, float>(true);\n  _TestCOOPerEtypeSamplingUniform<int32_t, double>(true);\n  _TestCOOPerEtypeSamplingUniform<int64_t, double>(true);\n  _TestCOOPerEtypeSamplingUniform<int32_t, float>(false);\n  _TestCOOPerEtypeSamplingUniform<int64_t, float>(false);\n  _TestCOOPerEtypeSamplingUniform<int32_t, double>(false);\n  _TestCOOPerEtypeSamplingUniform<int64_t, double>(false);\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCSRTopk(bool has_data) {\n  auto mat = CSR<Idx>(has_data);\n  FloatArray weight =\n      NDArray::FromVector(std::vector<FloatType>({.1f, .0f, -.1f, .2f, .5f}));\n  // -.1, .2, .1, .0, .5\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n\n  {\n    auto rst = CSRRowWiseTopk(mat, rows, 1, weight, true);\n    auto eset = ToEdgeSet<Idx>(rst);\n    ASSERT_EQ(eset.size(), 2);\n    if (has_data) {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));\n    } else {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));\n    }\n  }\n\n  {\n    auto rst = CSRRowWiseTopk(mat, rows, 1, 
weight, false);\n    auto eset = ToEdgeSet<Idx>(rst);\n    ASSERT_EQ(eset.size(), 2);\n    if (has_data) {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    } else {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    }\n  }\n}\n\nTEST(RowwiseTest, TestCSRTopk) {\n  _TestCSRTopk<int32_t, float>(true);\n  _TestCSRTopk<int64_t, float>(true);\n  _TestCSRTopk<int32_t, double>(true);\n  _TestCSRTopk<int64_t, double>(true);\n  _TestCSRTopk<int32_t, float>(false);\n  _TestCSRTopk<int64_t, float>(false);\n  _TestCSRTopk<int32_t, double>(false);\n  _TestCSRTopk<int64_t, double>(false);\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCOOTopk(bool has_data) {\n  auto mat = COO<Idx>(has_data);\n  FloatArray weight =\n      NDArray::FromVector(std::vector<FloatType>({.1f, .0f, -.1f, .2f, .5f}));\n  // -.1, .2, .1, .0, .5\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));\n\n  {\n    auto rst = COORowWiseTopk(mat, rows, 1, weight, true);\n    auto eset = ToEdgeSet<Idx>(rst);\n    ASSERT_EQ(eset.size(), 2);\n    if (has_data) {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));\n    } else {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));\n    }\n  }\n\n  {\n    auto rst = COORowWiseTopk(mat, rows, 1, weight, false);\n    auto eset = ToEdgeSet<Idx>(rst);\n    ASSERT_EQ(eset.size(), 2);\n    if (has_data) {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    } else {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    }\n  }\n}\n\nTEST(RowwiseTest, TestCOOTopk) {\n  _TestCOOTopk<int32_t, float>(true);\n  _TestCOOTopk<int64_t, float>(true);\n  
_TestCOOTopk<int32_t, double>(true);\n  _TestCOOTopk<int64_t, double>(true);\n  _TestCOOTopk<int32_t, float>(false);\n  _TestCOOTopk<int64_t, float>(false);\n  _TestCOOTopk<int32_t, double>(false);\n  _TestCOOTopk<int64_t, double>(false);\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestCSRSamplingBiased(bool has_data) {\n  auto mat = CSR<Idx>(has_data);\n  // 0 - 0,1\n  // 1 - 1\n  // 3 - 2,3\n  NDArray tag_offset = NDArray::FromVector(\n      std::vector<Idx>({0, 1, 2, 0, 0, 1, 0, 0, 0, 0, 1, 2}));\n  tag_offset = tag_offset.CreateView({4, 3}, tag_offset->dtype);\n  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 1, 3}));\n  FloatArray bias = NDArray::FromVector(std::vector<FloatType>({0, 0.5}));\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWiseSamplingBiased(mat, rows, 1, tag_offset, bias, false);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));\n      ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 0)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    } else {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));\n      ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 2)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n    }\n  }\n  for (int k = 0; k < 10; ++k) {\n    auto rst = CSRRowWiseSamplingBiased(mat, rows, 3, tag_offset, bias, true);\n    CheckSampledResult<Idx>(rst, rows, has_data);\n    auto eset = ToEdgeSet<Idx>(rst);\n    if (has_data) {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));\n      ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 0)));\n      ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 2)));\n      ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 1)));\n    } else {\n      ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));\n      ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 2)));\n      
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));\n      ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));\n      ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 3)));\n    }\n  }\n}\n\nTEST(RowwiseTest, TestCSRSamplingBiased) {\n  _TestCSRSamplingBiased<int32_t, float>(true);\n  _TestCSRSamplingBiased<int32_t, float>(false);\n  _TestCSRSamplingBiased<int64_t, float>(true);\n  _TestCSRSamplingBiased<int64_t, float>(false);\n  _TestCSRSamplingBiased<int32_t, double>(true);\n  _TestCSRSamplingBiased<int32_t, double>(false);\n  _TestCSRSamplingBiased<int64_t, double>(true);\n  _TestCSRSamplingBiased<int64_t, double>(false);\n}\n"
  },
  {
    "path": "tests/cpp/test_sampler.cc",
    "content": "#include <gtest/gtest.h>\n\n#include <algorithm>\n#include <iostream>\n#include <vector>\n\n#include \"../../src/random/cpu/sample_utils.h\"\n#include \"./common.h\"\n\nusing namespace dgl;\nusing namespace dgl::aten;\n\n// TODO: adapt this to Random::Choice\n\ntemplate <typename Idx, typename DType>\nvoid _TestWithReplacement(RandomEngine* re) {\n  Idx n_categories = 100;\n  Idx n_rolls = 1000000;\n  std::vector<DType> _prob;\n  DType accum = 0.;\n  for (Idx i = 0; i < n_categories; ++i) {\n    _prob.push_back(re->Uniform<DType>());\n    accum += _prob.back();\n  }\n  for (Idx i = 0; i < n_categories; ++i) _prob[i] /= accum;\n  FloatArray prob = NDArray::FromVector(_prob);\n\n  auto _check_given_sampler = [n_categories, n_rolls,\n                               &_prob](utils::BaseSampler<Idx>* s) {\n    std::vector<Idx> counter(n_categories, 0);\n    for (Idx i = 0; i < n_rolls; ++i) {\n      Idx dice = s->Draw();\n      counter[dice]++;\n    }\n    for (Idx i = 0; i < n_categories; ++i)\n      ASSERT_NEAR(static_cast<DType>(counter[i]) / n_rolls, _prob[i], 1e-2);\n  };\n\n  auto _check_random_choice = [n_categories, n_rolls, &_prob, prob]() {\n    std::vector<int64_t> counter(n_categories, 0);\n    for (Idx i = 0; i < n_rolls; ++i) {\n      Idx dice = RandomEngine::ThreadLocal()->Choice<int64_t>(prob);\n      counter[dice]++;\n    }\n    for (Idx i = 0; i < n_categories; ++i)\n      ASSERT_NEAR(static_cast<DType>(counter[i]) / n_rolls, _prob[i], 1e-2);\n  };\n\n  utils::AliasSampler<Idx, DType, true> as(re, prob);\n  utils::CDFSampler<Idx, DType, true> cs(re, prob);\n  utils::TreeSampler<Idx, DType, true> ts(re, prob);\n  _check_given_sampler(&as);\n  _check_given_sampler(&cs);\n  _check_given_sampler(&ts);\n  _check_random_choice();\n}\n\nTEST(SampleUtilsTest, TestWithReplacement) {\n  RandomEngine* re = RandomEngine::ThreadLocal();\n  re->SetSeed(42);\n  _TestWithReplacement<int32_t, float>(re);\n  re->SetSeed(42);\n  
_TestWithReplacement<int32_t, double>(re);\n  re->SetSeed(42);\n  _TestWithReplacement<int64_t, float>(re);\n  re->SetSeed(42);\n  _TestWithReplacement<int64_t, double>(re);\n};\n\ntemplate <typename Idx, typename DType>\nvoid _TestWithoutReplacementOrder(RandomEngine* re) {\n  // TODO(BarclayII): is there a reliable way to do this test?\n  std::vector<DType> _prob = {1e6f, 1e-6f, 1e-2f, 1e2f};\n  FloatArray prob = NDArray::FromVector(_prob);\n  std::vector<Idx> ground_truth = {0, 3, 2, 1};\n\n  auto _check_given_sampler = [&ground_truth](utils::BaseSampler<Idx>* s) {\n    for (size_t i = 0; i < ground_truth.size(); ++i) {\n      Idx dice = s->Draw();\n      ASSERT_EQ(dice, ground_truth[i]);\n    }\n  };\n\n  utils::AliasSampler<Idx, DType, false> as(re, prob);\n  utils::CDFSampler<Idx, DType, false> cs(re, prob);\n  utils::TreeSampler<Idx, DType, false> ts(re, prob);\n  _check_given_sampler(&as);\n  _check_given_sampler(&cs);\n  _check_given_sampler(&ts);\n}\n\nTEST(SampleUtilsTest, TestWithoutReplacementOrder) {\n  RandomEngine* re = RandomEngine::ThreadLocal();\n  re->SetSeed(42);\n  _TestWithoutReplacementOrder<int32_t, float>(re);\n  re->SetSeed(42);\n  _TestWithoutReplacementOrder<int32_t, double>(re);\n  re->SetSeed(42);\n  _TestWithoutReplacementOrder<int64_t, float>(re);\n  re->SetSeed(42);\n  _TestWithoutReplacementOrder<int64_t, double>(re);\n};\n\ntemplate <typename Idx, typename DType>\nvoid _TestWithoutReplacementUnique(RandomEngine* re) {\n  Idx N = 1000000;\n  std::vector<DType> _likelihood;\n  for (Idx i = 0; i < N; ++i) _likelihood.push_back(re->Uniform<DType>());\n  FloatArray likelihood = NDArray::FromVector(_likelihood);\n\n  auto _check_given_sampler = [N](utils::BaseSampler<Idx>* s) {\n    std::vector<int> cnt(N, 0);\n    for (Idx i = 0; i < N; ++i) {\n      Idx dice = s->Draw();\n      cnt[dice]++;\n    }\n    for (Idx i = 0; i < N; ++i) ASSERT_EQ(cnt[i], 1);\n  };\n\n  utils::AliasSampler<Idx, DType, false> as(re, likelihood);\n  
utils::CDFSampler<Idx, DType, false> cs(re, likelihood);\n  utils::TreeSampler<Idx, DType, false> ts(re, likelihood);\n  _check_given_sampler(&as);\n  _check_given_sampler(&cs);\n  _check_given_sampler(&ts);\n}\n\nTEST(SampleUtilsTest, TestWithoutReplacementUnique) {\n  RandomEngine* re = RandomEngine::ThreadLocal();\n  re->SetSeed(42);\n  _TestWithoutReplacementUnique<int32_t, float>(re);\n  re->SetSeed(42);\n  _TestWithoutReplacementUnique<int32_t, double>(re);\n  re->SetSeed(42);\n  _TestWithoutReplacementUnique<int64_t, float>(re);\n  re->SetSeed(42);\n  _TestWithoutReplacementUnique<int64_t, double>(re);\n};\n\ntemplate <typename Idx, typename DType>\nvoid _TestChoice(RandomEngine* re) {\n  re->SetSeed(42);\n  std::vector<DType> prob_vec = {1., 0., 0., 0., 2., 2., 0., 0.};\n  FloatArray prob = FloatArray::FromVector(prob_vec);\n  {\n    for (int k = 0; k < 1000; ++k) {\n      Idx x = re->Choice<Idx>(prob);\n      ASSERT_TRUE(x == 0 || x == 4 || x == 5);\n    }\n  }\n  // num = 0\n  {\n    IdArray rst = re->Choice<Idx, DType>(0, prob, true);\n    ASSERT_EQ(rst->shape[0], 0);\n  }\n  // w/ replacement\n  {\n    IdArray rst = re->Choice<Idx, DType>(1000, prob, true);\n    ASSERT_EQ(rst->shape[0], 1000);\n    for (int64_t i = 0; i < 1000; ++i) {\n      Idx x = static_cast<Idx*>(rst->data)[i];\n      ASSERT_TRUE(x == 0 || x == 4 || x == 5);\n    }\n  }\n  // w/o replacement\n  {\n    IdArray rst = re->Choice<Idx, DType>(3, prob, false);\n    ASSERT_EQ(rst->shape[0], 3);\n    std::set<Idx> idxset;\n    for (int64_t i = 0; i < 3; ++i) {\n      Idx x = static_cast<Idx*>(rst->data)[i];\n      idxset.insert(x);\n    }\n    ASSERT_EQ(idxset.size(), 3);\n    ASSERT_EQ(idxset.count(0), 1);\n    ASSERT_EQ(idxset.count(4), 1);\n    ASSERT_EQ(idxset.count(5), 1);\n  }\n}\n\nTEST(RandomTest, TestChoice) {\n  RandomEngine* re = RandomEngine::ThreadLocal();\n  _TestChoice<int32_t, float>(re);\n  _TestChoice<int64_t, float>(re);\n  _TestChoice<int32_t, double>(re);\n  
_TestChoice<int64_t, double>(re);\n}\n\ntemplate <typename Idx>\nvoid _TestUniformChoice(RandomEngine* re) {\n  re->SetSeed(42);\n  // num == 0\n  {\n    IdArray rst = re->UniformChoice<Idx>(0, 100, true);\n    ASSERT_EQ(rst->shape[0], 0);\n  }\n  // w/ replacement\n  {\n    IdArray rst = re->UniformChoice<Idx>(1000, 100, true);\n    ASSERT_EQ(rst->shape[0], 1000);\n    for (int64_t i = 0; i < 1000; ++i) {\n      Idx x = static_cast<Idx*>(rst->data)[i];\n      ASSERT_TRUE(x >= 0 && x < 100);\n    }\n  }\n  // w/o replacement\n  {\n    IdArray rst = re->UniformChoice<Idx>(99, 100, false);\n    ASSERT_EQ(rst->shape[0], 99);\n    std::set<Idx> idxset;\n    for (int64_t i = 0; i < 99; ++i) {\n      Idx x = static_cast<Idx*>(rst->data)[i];\n      ASSERT_TRUE(x >= 0 && x < 100);\n      idxset.insert(x);\n    }\n    ASSERT_EQ(idxset.size(), 99);\n  }\n}\n\nTEST(RandomTest, TestUniformChoice) {\n  RandomEngine* re = RandomEngine::ThreadLocal();\n  _TestUniformChoice<int32_t>(re);\n  _TestUniformChoice<int64_t>(re);\n  _TestUniformChoice<int32_t>(re);\n  _TestUniformChoice<int64_t>(re);\n}\n\ntemplate <typename Idx, typename FloatType>\nvoid _TestBiasedChoice(RandomEngine* re) {\n  re->SetSeed(42);\n  // num == 0\n  {\n    Idx split[] = {0, 1, 2};\n    FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 3}));\n    IdArray rst = re->BiasedChoice<Idx, FloatType>(0, split, bias, true);\n    ASSERT_EQ(rst->shape[0], 0);\n  }\n  // basic test\n  {\n    Idx sample_num = 100000;\n    Idx population = 1000000;\n    Idx split[] = {0, population / 2, population};\n    FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 3}));\n\n    IdArray rst =\n        re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, true);\n    auto rst_data = static_cast<Idx*>(rst->data);\n    Idx larger = 0;\n    for (Idx i = 0; i < sample_num; ++i)\n      if (rst_data[i] >= population / 2) larger++;\n    ASSERT_LE(fabs((double)larger / sample_num - 0.75), 1e-2);\n  }\n  // 
without replacement\n  {\n    Idx sample_num = 500;\n    Idx population = 1000;\n    Idx split[] = {0, sample_num, population};\n    FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 0}));\n\n    IdArray rst =\n        re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, false);\n    auto rst_data = static_cast<Idx*>(rst->data);\n\n    std::set<Idx> idxset;\n    for (int64_t i = 0; i < sample_num; ++i) {\n      Idx x = rst_data[i];\n      ASSERT_LT(x, sample_num);\n      idxset.insert(x);\n    }\n    ASSERT_EQ(idxset.size(), sample_num);\n  }\n}\n\nTEST(RandomTest, TestBiasedChoice) {\n  RandomEngine* re = RandomEngine::ThreadLocal();\n  _TestBiasedChoice<int32_t, float>(re);\n  _TestBiasedChoice<int64_t, float>(re);\n  _TestBiasedChoice<int32_t, double>(re);\n  _TestBiasedChoice<int64_t, double>(re);\n}\n"
  },
  {
    "path": "tests/cpp/test_serialize.cc",
    "content": "#include <dgl/graph_serializer.h>\n#include <dgl/immutable_graph.h>\n#include <dmlc/memory_io.h>\n#include <gtest/gtest.h>\n\n#include <algorithm>\n#include <iostream>\n#include <memory>\n#include <vector>\n\n#include \"../../src/graph/heterograph.h\"\n#include \"../../src/graph/unit_graph.h\"\n#include \"./common.h\"\n\nusing namespace dgl;\nusing namespace dgl::aten;\nusing namespace dmlc;\n\nTEST(Serialize, UnitGraph_COO) {\n  aten::CSRMatrix csr_matrix;\n  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});\n  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});\n  auto mg = std::dynamic_pointer_cast<UnitGraph>(\n      dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, COO_CODE));\n\n  std::string blob;\n  dmlc::MemoryStringStream ifs(&blob);\n\n  static_cast<dmlc::Stream *>(&ifs)->Write(mg);\n\n  dmlc::MemoryStringStream ofs(&blob);\n  auto ug2 = Serializer::make_shared<UnitGraph>();\n  static_cast<dmlc::Stream *>(&ofs)->Read(&ug2);\n  EXPECT_EQ(ug2->NumVertices(0), 9);\n  EXPECT_EQ(ug2->NumVertices(1), 8);\n  EXPECT_EQ(ug2->NumEdges(0), 4);\n  EXPECT_EQ(ug2->FindEdge(0, 1).first, 2);\n  EXPECT_EQ(ug2->FindEdge(0, 1).second, 6);\n}\n\nTEST(Serialize, UnitGraph_CSR) {\n  aten::CSRMatrix csr_matrix;\n  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});\n  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});\n  auto coo_g = std::dynamic_pointer_cast<UnitGraph>(\n      dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst));\n  auto csr_g =\n      std::dynamic_pointer_cast<UnitGraph>(coo_g->GetGraphInFormat(CSR_CODE));\n\n  std::string blob;\n  dmlc::MemoryStringStream ifs(&blob);\n\n  static_cast<dmlc::Stream *>(&ifs)->Write(csr_g);\n\n  dmlc::MemoryStringStream ofs(&blob);\n  auto ug2 = Serializer::make_shared<UnitGraph>();\n  static_cast<dmlc::Stream *>(&ofs)->Read(&ug2);\n  // Query operation is not supported on CSR, how to check it?\n}\n\nTEST(Serialize, ImmutableGraph) {\n  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});\n  auto dst = VecToIdArray<int64_t>({1, 6, 2, 
6});\n  auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst);\n  std::string blob;\n  dmlc::MemoryStringStream ifs(&blob);\n\n  static_cast<dmlc::Stream *>(&ifs)->Write(gptr);\n\n  dmlc::MemoryStringStream ofs(&blob);\n  auto rptr_read = dgl::Serializer::make_shared<ImmutableGraph>();\n  static_cast<dmlc::Stream *>(&ofs)->Read(&rptr_read);\n  EXPECT_EQ(rptr_read->NumEdges(), 4);\n  EXPECT_EQ(rptr_read->NumVertices(), 10);\n  EXPECT_EQ(rptr_read->FindEdge(2).first, 5);\n  EXPECT_EQ(rptr_read->FindEdge(2).second, 2);\n}\n\nTEST(Serialize, HeteroGraph) {\n  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});\n  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});\n  auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);\n  src = VecToIdArray<int64_t>({6, 2, 5, 1, 8});\n  dst = VecToIdArray<int64_t>({5, 2, 4, 8, 0});\n  auto mg2 = dgl::UnitGraph::CreateFromCOO(1, 9, 9, src, dst);\n  std::vector<HeteroGraphPtr> relgraphs;\n  relgraphs.push_back(mg1);\n  relgraphs.push_back(mg2);\n  src = VecToIdArray<int64_t>({0, 0});\n  dst = VecToIdArray<int64_t>({1, 0});\n  auto meta_gptr = ImmutableGraph::CreateFromCOO(3, src, dst);\n  auto hrptr = std::make_shared<HeteroGraph>(meta_gptr, relgraphs);\n\n  std::string blob;\n  dmlc::MemoryStringStream ifs(&blob);\n  static_cast<dmlc::Stream *>(&ifs)->Write(hrptr);\n\n  dmlc::MemoryStringStream ofs(&blob);\n  auto gptr = dgl::Serializer::make_shared<HeteroGraph>();\n  static_cast<dmlc::Stream *>(&ofs)->Read(&gptr);\n  EXPECT_EQ(gptr->NumVertices(0), 9);\n  EXPECT_EQ(gptr->NumVertices(1), 8);\n}\n"
  },
  {
    "path": "tests/cpp/test_smart_ptr_serialize.cc",
    "content": "#include <dgl/runtime/serializer.h>\n#include <dgl/runtime/smart_ptr_serializer.h>\n#include <dmlc/io.h>\n#include <dmlc/logging.h>\n#include <dmlc/memory_io.h>\n#include <dmlc/parameter.h>\n#include <gtest/gtest.h>\n\n#include <cstring>\n#include <iostream>\n#include <sstream>\n#include <unordered_map>\n\nusing namespace std;\n\nclass MyClass {\n public:\n  MyClass() {}\n  MyClass(std::string data) : data_(data) {}\n  inline void Save(dmlc::Stream *strm) const { strm->Write(this->data_); }\n  inline bool Load(dmlc::Stream *strm) { return strm->Read(&data_); }\n  inline bool operator==(const MyClass &other) const {\n    return data_ == other.data_;\n  }\n\n public:\n  std::string data_;\n};\n// need to declare the traits property of my class to dmlc\nnamespace dmlc {\nDMLC_DECLARE_TRAITS(has_saveload, MyClass, true);\n}\n\ntemplate <typename T>\nclass SmartPtrTest : public ::testing::Test {\n public:\n  typedef T SmartPtr;\n};\n\nusing SmartPtrTypes =\n    ::testing::Types<std::shared_ptr<MyClass>, std::unique_ptr<MyClass>>;\nTYPED_TEST_SUITE(SmartPtrTest, SmartPtrTypes);\n\nTYPED_TEST(SmartPtrTest, Obj_Test) {\n  std::string blob;\n  dmlc::MemoryStringStream fs(&blob);\n  using SmartPtr = typename TestFixture::SmartPtr;\n  auto myc = SmartPtr(new MyClass(\"1111\"));\n  { static_cast<dmlc::Stream *>(&fs)->Write(myc); }\n  fs.Seek(0);\n  auto copy_data = SmartPtr(new MyClass());\n  CHECK(static_cast<dmlc::Stream *>(&fs)->Read(&copy_data));\n\n  EXPECT_EQ(myc->data_, copy_data->data_);\n}\n\nTYPED_TEST(SmartPtrTest, Vector_Test1) {\n  std::string blob;\n  dmlc::MemoryStringStream fs(&blob);\n  using SmartPtr = typename TestFixture::SmartPtr;\n  typedef std::pair<std::string, SmartPtr> Pair;\n\n  std::vector<Pair> myclasses;\n  myclasses.emplace_back(\"a\", SmartPtr(new MyClass(\"@A@B\")));\n  myclasses.emplace_back(\"b\", SmartPtr(new MyClass(\"2222\")));\n  static_cast<dmlc::Stream *>(&fs)->Write<std::vector<Pair>>(myclasses);\n\n  
dmlc::MemoryStringStream ofs(&blob);\n  std::vector<Pair> copy_myclasses;\n  static_cast<dmlc::Stream *>(&ofs)->Read<std::vector<Pair>>(&copy_myclasses);\n\n  EXPECT_TRUE(std::equal(\n      myclasses.begin(), myclasses.end(), copy_myclasses.begin(),\n      [](const Pair &left, const Pair &right) {\n        return (left.second->data_ == right.second->data_) &&\n               (left.first == right.first);\n      }));\n}\n\nTYPED_TEST(SmartPtrTest, Vector_Test2) {\n  std::string blob;\n  dmlc::MemoryStringStream fs(&blob);\n  using SmartPtr = typename TestFixture::SmartPtr;\n\n  std::vector<SmartPtr> myclasses;\n  myclasses.emplace_back(new MyClass(\"@A@\"));\n  myclasses.emplace_back(new MyClass(\"2222\"));\n  static_cast<dmlc::Stream *>(&fs)->Write<std::vector<SmartPtr>>(myclasses);\n\n  dmlc::MemoryStringStream ofs(&blob);\n  std::vector<SmartPtr> copy_myclasses;\n  static_cast<dmlc::Stream *>(&ofs)->Read<std::vector<SmartPtr>>(\n      &copy_myclasses);\n\n  EXPECT_TRUE(std::equal(\n      myclasses.begin(), myclasses.end(), copy_myclasses.begin(),\n      [](const SmartPtr &left, const SmartPtr &right) {\n        return left->data_ == right->data_;\n      }));\n}\n"
  },
  {
    "path": "tests/cpp/test_spmat_coo.cc",
    "content": "#include <dgl/array.h>\n#include <dmlc/omp.h>\n#include <gtest/gtest.h>\n#include <omp.h>\n\n#include <random>\n\n#include \"./common.h\"\n\nusing namespace dgl;\nusing namespace dgl::runtime;\n\nnamespace {\n\ntemplate <typename IDX>\naten::CSRMatrix CSR1(DGLContext ctx = CTX) {\n  // [[0, 1, 1, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [0, 2, 3, 1, 4]\n  return aten::CSRMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 3, 4, 1}), sizeof(IDX) * 8, ctx),\n      false);\n}\n\ntemplate <typename IDX>\naten::CSRMatrix CSR2(DGLContext ctx = CTX) {\n  // has duplicate entries\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [0, 2, 5, 3, 1, 4]\n  return aten::CSRMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),\n      false);\n}\n\ntemplate <typename IDX>\naten::COOMatrix COO1(DGLContext ctx = CTX) {\n  // [[0, 1, 1, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [0, 2, 3, 1, 4]\n  // row : [0, 2, 0, 1, 2]\n  // col : [1, 2, 2, 0, 3]\n  return aten::COOMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 0, 1, 2}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 2, 0, 3}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 3, 1, 2, 4}), sizeof(IDX) * 8, ctx));\n}\n\ntemplate <typename IDX>\naten::COOMatrix COO2(DGLContext ctx = 
CTX) {\n  // has duplicate entries\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [0, 2, 5, 3, 1, 4]\n  // row : [0, 2, 0, 1, 2, 0]\n  // col : [1, 2, 2, 0, 3, 2]\n  return aten::COOMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 2, 0, 3, 2}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 1, 2, 3, 4, 5}), sizeof(IDX) * 8, ctx));\n}\n\ntemplate <typename IDX>\naten::CSRMatrix SR_CSR3(DGLContext ctx) {\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  return aten::CSRMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({2, 1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),\n      false);\n}\n\ntemplate <typename IDX>\naten::CSRMatrix SRC_CSR3(DGLContext ctx) {\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  return aten::CSRMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),\n      false);\n}\n\ntemplate <typename IDX>\naten::COOMatrix COO3(DGLContext ctx) {\n  // has duplicate entries\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // row : [0, 2, 0, 1, 2, 0]\n  // col : [2, 2, 1, 0, 3, 2]\n  return aten::COOMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX) * 8, ctx),\n     
 aten::VecToIdArray(\n          std::vector<IDX>({2, 2, 1, 0, 3, 2}), sizeof(IDX) * 8, ctx));\n}\n\ntemplate <typename IDX>\naten::COOMatrix COORandomized(IDX rows_and_cols, int64_t nnz, int seed) {\n  std::vector<IDX> vec_rows(nnz);\n  std::vector<IDX> vec_cols(nnz);\n  std::vector<IDX> vec_data(nnz);\n\n#pragma omp parallel\n  {\n    const int64_t num_threads = omp_get_num_threads();\n    const int64_t thread_id = omp_get_thread_num();\n    const int64_t chunk = nnz / num_threads;\n    const int64_t size = (thread_id == num_threads - 1)\n                             ? nnz - chunk * (num_threads - 1)\n                             : chunk;\n    auto rows = vec_rows.data() + thread_id * chunk;\n    auto cols = vec_cols.data() + thread_id * chunk;\n    auto data = vec_data.data() + thread_id * chunk;\n\n    std::mt19937_64 gen64(seed + thread_id);\n    std::mt19937 gen32(seed + thread_id);\n\n    for (int64_t i = 0; i < size; ++i) {\n      rows[i] = gen64() % rows_and_cols;\n      cols[i] = gen64() % rows_and_cols;\n      data[i] = gen32() % 90 + 1;\n    }\n  }\n\n  return aten::COOMatrix(\n      rows_and_cols, rows_and_cols,\n      aten::VecToIdArray(vec_rows, sizeof(IDX) * 8, CTX),\n      aten::VecToIdArray(vec_cols, sizeof(IDX) * 8, CTX),\n      aten::VecToIdArray(vec_data, sizeof(IDX) * 8, CTX), false, false);\n}\n\nstruct SparseCOOCSR {\n  static constexpr uint64_t NUM_ROWS = 100;\n  static constexpr uint64_t NUM_COLS = 150;\n  static constexpr uint64_t NUM_NZ = 5;\n  template <typename IDX>\n  static aten::COOMatrix COOSparse(const DGLContext &ctx = CTX) {\n    return aten::COOMatrix(\n        NUM_ROWS, NUM_COLS,\n        aten::VecToIdArray(\n            std::vector<IDX>({0, 1, 2, 3, 4}), sizeof(IDX) * 8, ctx),\n        aten::VecToIdArray(\n            std::vector<IDX>({1, 2, 3, 4, 5}), sizeof(IDX) * 8, ctx));\n  }\n\n  template <typename IDX>\n  static aten::CSRMatrix CSRSparse(const DGLContext &ctx = CTX) {\n    auto &&indptr = std::vector<IDX>(NUM_ROWS + 1, 
NUM_NZ);\n    for (size_t i = 0; i < NUM_NZ; ++i) {\n      indptr[i + 1] = static_cast<IDX>(i + 1);\n    }\n    indptr[0] = 0;\n    return aten::CSRMatrix(\n        NUM_ROWS, NUM_COLS, aten::VecToIdArray(indptr, sizeof(IDX) * 8, ctx),\n        aten::VecToIdArray(\n            std::vector<IDX>({1, 2, 3, 4, 5}), sizeof(IDX) * 8, ctx),\n        aten::VecToIdArray(\n            std::vector<IDX>({1, 1, 1, 1, 1}), sizeof(IDX) * 8, ctx),\n        false);\n  }\n};\n\ntemplate <typename IDX>\naten::COOMatrix RowSorted_NullData_COO(DGLContext ctx = CTX) {\n  // [[0, 1, 1, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // row : [0, 0, 1, 2, 2]\n  // col : [1, 2, 0, 2, 3]\n  return aten::COOMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 0, 1, 2, 2}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),\n      aten::NullArray(), true, false);\n}\n\ntemplate <typename IDX>\naten::CSRMatrix RowSorted_NullData_CSR(DGLContext ctx = CTX) {\n  // [[0, 1, 1, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [0, 1, 2, 3, 4]\n  return aten::CSRMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 1, 2, 3, 4}), sizeof(IDX) * 8, ctx),\n      false);\n}\n}  // namespace\n\ntemplate <typename IDX>\nvoid _TestCOOToCSR(DGLContext ctx) {\n  auto coo = COO1<IDX>(ctx);\n  auto csr = CSR1<IDX>(ctx);\n  auto tcsr = aten::COOToCSR(coo);\n  ASSERT_FALSE(coo.row_sorted);\n  ASSERT_EQ(csr.num_rows, tcsr.num_rows);\n  ASSERT_EQ(csr.num_cols, tcsr.num_cols);\n  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));\n  ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));\n\n  coo = COO2<IDX>(ctx);\n  csr = 
CSR2<IDX>(ctx);\n  tcsr = aten::COOToCSR(coo);\n  ASSERT_EQ(coo.num_rows, csr.num_rows);\n  ASSERT_EQ(coo.num_cols, csr.num_cols);\n  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));\n\n  // Convert from row sorted coo\n  coo = COO1<IDX>(ctx);\n  auto rs_coo = aten::COOSort(coo, false);\n  auto rs_csr = CSR1<IDX>(ctx);\n  auto rs_tcsr = aten::COOToCSR(rs_coo);\n  ASSERT_TRUE(rs_coo.row_sorted);\n  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);\n  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);\n  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));\n  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));\n  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));\n\n  coo = COO3<IDX>(ctx);\n  rs_coo = aten::COOSort(coo, false);\n  rs_csr = SR_CSR3<IDX>(ctx);\n  rs_tcsr = aten::COOToCSR(rs_coo);\n  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);\n  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);\n  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));\n  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));\n  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));\n\n  rs_coo = RowSorted_NullData_COO<IDX>(ctx);\n  ASSERT_TRUE(rs_coo.row_sorted);\n  rs_csr = RowSorted_NullData_CSR<IDX>(ctx);\n  rs_tcsr = aten::COOToCSR(rs_coo);\n  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);\n  ASSERT_EQ(rs_csr.num_rows, rs_tcsr.num_rows);\n  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);\n  ASSERT_EQ(rs_csr.num_cols, rs_tcsr.num_cols);\n  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));\n  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indices, rs_tcsr.indices));\n  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.data, rs_tcsr.data));\n  ASSERT_TRUE(ArrayEQ<IDX>(rs_coo.col, rs_tcsr.indices));\n  ASSERT_FALSE(ArrayEQ<IDX>(rs_coo.data, rs_tcsr.data));\n\n  // Convert from col sorted coo\n  coo = COO1<IDX>(ctx);\n  auto src_coo = aten::COOSort(coo, true);\n  auto src_csr = CSR1<IDX>(ctx);\n  auto src_tcsr = aten::COOToCSR(src_coo);\n  ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);\n  ASSERT_EQ(coo.num_cols, 
src_tcsr.num_cols);\n  ASSERT_TRUE(src_tcsr.sorted);\n  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));\n  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));\n  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));\n\n  coo = COO3<IDX>(ctx);\n  src_coo = aten::COOSort(coo, true);\n  src_csr = SRC_CSR3<IDX>(ctx);\n  src_tcsr = aten::COOToCSR(src_coo);\n  ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);\n  ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);\n  ASSERT_TRUE(src_tcsr.sorted);\n  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));\n  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));\n  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));\n\n  coo = SparseCOOCSR::COOSparse<IDX>(ctx);\n  csr = SparseCOOCSR::CSRSparse<IDX>(ctx);\n  tcsr = aten::COOToCSR(coo);\n  ASSERT_FALSE(coo.row_sorted);\n  ASSERT_EQ(csr.num_rows, tcsr.num_rows);\n  ASSERT_EQ(csr.num_cols, tcsr.num_cols);\n  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));\n  ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));\n}\n\nTEST(SpmatTest, COOToCSR) {\n  _TestCOOToCSR<int32_t>(CPU);\n  _TestCOOToCSR<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCOOToCSR<int32_t>(GPU);\n  _TestCOOToCSR<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCOOHasDuplicate() {\n  auto coo = COO1<IDX>();\n  ASSERT_FALSE(aten::COOHasDuplicate(coo));\n  coo = COO2<IDX>();\n  ASSERT_TRUE(aten::COOHasDuplicate(coo));\n}\n\nTEST(SpmatTest, TestCOOHasDuplicate) {\n  _TestCOOHasDuplicate<int32_t>();\n  _TestCOOHasDuplicate<int64_t>();\n}\n\ntemplate <typename IDX>\nvoid _TestCOOSort(DGLContext ctx) {\n  auto coo = COO3<IDX>(ctx);\n\n  auto sr_coo = COOSort(coo, false);\n  ASSERT_EQ(coo.num_rows, sr_coo.num_rows);\n  ASSERT_EQ(coo.num_cols, sr_coo.num_cols);\n  ASSERT_TRUE(sr_coo.row_sorted);\n  auto flags = COOIsSorted(sr_coo);\n  ASSERT_TRUE(flags.first);\n  flags = COOIsSorted(coo);  // original coo should stay the same\n  ASSERT_FALSE(flags.first);\n  
ASSERT_FALSE(flags.second);\n\n  auto src_coo = COOSort(coo, true);\n  ASSERT_EQ(coo.num_rows, src_coo.num_rows);\n  ASSERT_EQ(coo.num_cols, src_coo.num_cols);\n  ASSERT_TRUE(src_coo.row_sorted);\n  ASSERT_TRUE(src_coo.col_sorted);\n  flags = COOIsSorted(src_coo);\n  ASSERT_TRUE(flags.first);\n  ASSERT_TRUE(flags.second);\n\n  // sort inplace\n  COOSort_(&coo);\n  ASSERT_TRUE(coo.row_sorted);\n  flags = COOIsSorted(coo);\n  ASSERT_TRUE(flags.first);\n  COOSort_(&coo, true);\n  ASSERT_TRUE(coo.row_sorted);\n  ASSERT_TRUE(coo.col_sorted);\n  flags = COOIsSorted(coo);\n  ASSERT_TRUE(flags.first);\n  ASSERT_TRUE(flags.second);\n\n  // COO3\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [0, 1, 2, 3, 4, 5]\n  // row : [0, 2, 0, 1, 2, 0]\n  // col : [2, 2, 1, 0, 3, 2]\n  // Row Sorted\n  // data: [0, 2, 5, 3, 1, 4]\n  // row : [0, 0, 0, 1, 2, 2]\n  // col : [2, 1, 2, 0, 2, 3]\n  // Row Col Sorted\n  // data: [2, 0, 5, 3, 1, 4]\n  // row : [0, 0, 0, 1, 2, 2]\n  // col : [1, 2, 2, 0, 2, 3]\n  auto sort_row = aten::VecToIdArray(\n      std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX) * 8, ctx);\n  auto sort_col = aten::VecToIdArray(\n      std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx);\n  auto sort_col_data = aten::VecToIdArray(\n      std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx);\n\n  ASSERT_TRUE(ArrayEQ<IDX>(sr_coo.row, sort_row));\n  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.row, sort_row));\n  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.col, sort_col));\n  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.data, sort_col_data));\n}\n\nTEST(SpmatTest, COOSort) {\n  _TestCOOSort<int32_t>(CPU);\n  _TestCOOSort<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCOOSort<int32_t>(GPU);\n  _TestCOOSort<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCOOReorder() {\n  auto coo = COO2<IDX>();\n  auto new_row =\n      aten::VecToIdArray(std::vector<IDX>({2, 0, 3, 1}), sizeof(IDX) * 8, CTX);\n  auto new_col = 
aten::VecToIdArray(\n      std::vector<IDX>({2, 0, 4, 3, 1}), sizeof(IDX) * 8, CTX);\n  auto new_coo = COOReorder(coo, new_row, new_col);\n  ASSERT_EQ(new_coo.num_rows, coo.num_rows);\n  ASSERT_EQ(new_coo.num_cols, coo.num_cols);\n}\n\nTEST(SpmatTest, TestCOOReorder) {\n  _TestCOOReorder<int32_t>();\n  _TestCOOReorder<int64_t>();\n}\n\ntemplate <typename IDX>\nvoid _TestCOOGetData(DGLContext ctx) {\n  auto coo = COO2<IDX>(ctx);\n  // test get all data\n  auto x = aten::COOGetAllData(coo, 0, 0);\n  auto tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n  x = aten::COOGetAllData(coo, 0, 2);\n  tx = aten::VecToIdArray(std::vector<IDX>({2, 5}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n\n  // test get data\n  auto r =\n      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);\n  auto c =\n      aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);\n  x = aten::COOGetData(coo, r, c);\n  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n\n  // test get data on sorted\n  coo = aten::COOSort(coo);\n  r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);\n  c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);\n  x = aten::COOGetData(coo, r, c);\n  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n\n  // test get data w/ broadcasting\n  r = aten::VecToIdArray(std::vector<IDX>({0}), sizeof(IDX) * 8, ctx);\n  c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);\n  x = aten::COOGetData(coo, r, c);\n  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n}\n\nTEST(SpmatTest, COOGetData) {\n  _TestCOOGetData<int32_t>(CPU);\n  _TestCOOGetData<int64_t>(CPU);\n  // #ifdef DGL_USE_CUDA\n  //_TestCOOGetData<int32_t>(GPU);\n 
 //_TestCOOGetData<int64_t>(GPU);\n  // #endif\n}\n\ntemplate <typename IDX>\nvoid _TestCOOGetDataAndIndices() {\n  auto coo = COO2<IDX>();\n  auto r =\n      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, CTX);\n  auto c =\n      aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, CTX);\n  auto x = aten::COOGetDataAndIndices(coo, r, c);\n  auto tr =\n      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, CTX);\n  auto tc =\n      aten::VecToIdArray(std::vector<IDX>({1, 2, 2}), sizeof(IDX) * 8, CTX);\n  auto td =\n      aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX) * 8, CTX);\n  ASSERT_TRUE(ArrayEQ<IDX>(x[0], tr));\n  ASSERT_TRUE(ArrayEQ<IDX>(x[1], tc));\n  ASSERT_TRUE(ArrayEQ<IDX>(x[2], td));\n}\n\nTEST(SpmatTest, COOGetDataAndIndices) {\n  _TestCOOGetDataAndIndices<int32_t>();\n  _TestCOOGetDataAndIndices<int64_t>();\n}\n\ntemplate <typename IDX>\nvoid _TestCOOToCSRAlgs() {\n  // Compare results between different CPU COOToCSR implementations.\n  // NNZ is chosen to be bigger than the limit for the \"small\" matrix algorithm.\n  // N is set to lay on border between \"sparse\" and \"dense\" algorithm choice.\n\n  const int64_t num_threads = std::min(256, omp_get_max_threads());\n  const int64_t min_num_threads = 3;\n\n  if (num_threads < min_num_threads) {\n    std::cerr << \"[          ] [ INFO ]\"\n              << \"This test requires at least 3 OMP threads to work properly\"\n              << std::endl;\n    GTEST_SKIP();\n    return;\n  }\n\n  // Select N and NNZ for COO matrix in a way than depending on number of\n  // threads different algorithm will be used.\n  // See WhichCOOToCSR in src/array/cpu/spmat_op_impl_coo.cc for details\n  const int64_t type_scale = sizeof(IDX) >> 1;\n  const int64_t small = 50 * num_threads * type_scale * type_scale;\n  // NNZ should be bigger than limit for small matrix algorithm\n  const int64_t nnz = small + 1234;\n  // N is chosen to lay on sparse/dense border\n  
const int64_t n = type_scale * nnz / num_threads;\n  const IDX rows_nad_cols = n + 1;  // should be bigger than sparse/dense border\n\n  // Note that it will be better to set the seed to a random value when gtest\n  // allows to use --gtest_random_seed without --gtest_shuffle and report this\n  // value for reproduction. This way we can find unforeseen situations and\n  // potential bugs.\n  const auto seed = 123321;\n  auto coo = COORandomized<IDX>(rows_nad_cols, nnz, seed);\n\n  omp_set_num_threads(1);\n  // UnSortedSmallCOOToCSR will be used\n  auto tcsr_small = aten::COOToCSR(coo);\n  ASSERT_EQ(coo.num_rows, tcsr_small.num_rows);\n  ASSERT_EQ(coo.num_cols, tcsr_small.num_cols);\n\n  omp_set_num_threads(num_threads - 1);\n  // UnSortedDenseCOOToCSR will be used\n  auto tcsr_dense = aten::COOToCSR(coo);\n  ASSERT_EQ(tcsr_small.num_rows, tcsr_dense.num_rows);\n  ASSERT_EQ(tcsr_small.num_cols, tcsr_dense.num_cols);\n  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.indptr, tcsr_dense.indptr));\n  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.indices, tcsr_dense.indices));\n  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.data, tcsr_dense.data));\n\n  omp_set_num_threads(num_threads);\n  // UnSortedSparseCOOToCSR will be used\n  auto tcsr_sparse = aten::COOToCSR(coo);\n  ASSERT_EQ(tcsr_small.num_rows, tcsr_sparse.num_rows);\n  ASSERT_EQ(tcsr_small.num_cols, tcsr_sparse.num_cols);\n  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.indptr, tcsr_sparse.indptr));\n  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.indices, tcsr_sparse.indices));\n  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.data, tcsr_sparse.data));\n  return;\n}\n\nTEST(SpmatTest, COOToCSRAlgs) {\n  _TestCOOToCSRAlgs<int32_t>();\n  _TestCOOToCSRAlgs<int64_t>();\n}\n"
  },
  {
    "path": "tests/cpp/test_spmat_csr.cc",
    "content": "#include <dgl/array.h>\n#include <gtest/gtest.h>\n\n#include \"./common.h\"\n\nusing namespace dgl;\nusing namespace dgl::runtime;\n\nnamespace {\n\ntemplate <typename IDX>\naten::CSRMatrix CSR1(DGLContext ctx = CTX) {\n  // [[0, 1, 1, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [0, 2, 3, 1, 4]\n  return aten::CSRMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 0, 3, 2}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 3, 4, 1}), sizeof(IDX) * 8, ctx),\n      false);\n}\n\ntemplate <typename IDX>\naten::CSRMatrix CSR2(DGLContext ctx = CTX) {\n  // has duplicate entries\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [0, 2, 5, 3, 1, 4]\n  return aten::CSRMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),\n      false);\n}\n\ntemplate <typename IDX>\naten::CSRMatrix CSR3(DGLContext ctx = CTX) {\n  // has duplicate entries and the columns are not sorted\n  // [[0, 1, 1, 1, 0, 0],\n  //  [1, 0, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0, 0],\n  //  [0, 0, 0, 0, 0, 0],\n  //  [1, 1, 1, 0, 0, 0],\n  //  [0, 0, 0, 1, 0, 0],\n  //  [0, 0, 0, 0, 0, 0],\n  //  [1, 2, 1, 1, 0, 0],\n  //  [0, 1, 0, 0, 0, 1]],\n  // data: [5, 2, 0, 3, 1, 4, 8, 7, 6, 9, 12, 13, 11, 10, 14, 15, 16]\n  return aten::CSRMatrix(\n      9, 6,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 3, 4, 6, 6, 9, 10, 10, 15, 17}), sizeof(IDX) * 8,\n          ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({3, 2, 1, 0, 2, 3, 1, 2, 0, 3, 1, 2, 1, 3, 0, 
5, 1}),\n          sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>(\n              {0, 2, 5, 3, 1, 4, 6, 8, 7, 9, 13, 10, 11, 14, 12, 16, 15}),\n          sizeof(IDX) * 8, ctx),\n      false);\n}\n\ntemplate <typename IDX>\naten::COOMatrix COO1(DGLContext ctx = CTX) {\n  // [[0, 1, 1, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [0, 2, 3, 1, 4]\n  // row : [0, 2, 0, 1, 2]\n  // col : [1, 2, 2, 0, 3]\n  return aten::COOMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 0, 1, 2}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 2, 0, 3}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 3, 1, 2, 4}), sizeof(IDX) * 8, ctx));\n}\n\ntemplate <typename IDX>\naten::COOMatrix COO2(DGLContext ctx = CTX) {\n  // has duplicate entries\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [0, 2, 5, 3, 1, 4]\n  // row : [0, 2, 0, 1, 2, 0]\n  // col : [1, 2, 2, 0, 3, 2]\n  return aten::COOMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 2, 0, 3, 2}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 1, 2, 3, 4, 5}), sizeof(IDX) * 8, ctx));\n}\n\ntemplate <typename IDX>\naten::CSRMatrix SR_CSR3(DGLContext ctx) {\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  return aten::CSRMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({2, 1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),\n      false);\n}\n\ntemplate <typename 
IDX>\naten::CSRMatrix SRC_CSR3(DGLContext ctx) {\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  return aten::CSRMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),\n      false);\n}\n\ntemplate <typename IDX>\naten::COOMatrix COO3(DGLContext ctx) {\n  // has duplicate entries\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // row : [0, 2, 0, 1, 2, 0]\n  // col : [2, 2, 1, 0, 3, 2]\n  return aten::COOMatrix(\n      4, 5,\n      aten::VecToIdArray(\n          std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX) * 8, ctx),\n      aten::VecToIdArray(\n          std::vector<IDX>({2, 2, 1, 0, 3, 2}), sizeof(IDX) * 8, ctx));\n}\n\n}  // namespace\n\ntemplate <typename IDX>\nvoid _TestCSRIsNonZero1(DGLContext ctx) {\n  auto csr = CSR1<IDX>(ctx);\n  ASSERT_TRUE(aten::CSRIsNonZero(csr, 0, 1));\n  ASSERT_FALSE(aten::CSRIsNonZero(csr, 0, 0));\n  IdArray r =\n      aten::VecToIdArray(std::vector<IDX>({2, 2, 0, 0}), sizeof(IDX) * 8, ctx);\n  IdArray c =\n      aten::VecToIdArray(std::vector<IDX>({1, 1, 1, 3}), sizeof(IDX) * 8, ctx);\n  IdArray x = aten::CSRIsNonZero(csr, r, c);\n  IdArray tx =\n      aten::VecToIdArray(std::vector<IDX>({0, 0, 1, 0}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n}\n\ntemplate <typename IDX>\nvoid _TestCSRIsNonZero2(DGLContext ctx) {\n  auto csr = CSR3<IDX>(ctx);\n  ASSERT_TRUE(aten::CSRIsNonZero(csr, 0, 1));\n  ASSERT_FALSE(aten::CSRIsNonZero(csr, 0, 0));\n  IdArray r = aten::VecToIdArray(\n      std::vector<IDX>({\n          0,\n          0,\n          0,\n          0,\n          0,\n      }),\n      sizeof(IDX) * 8, ctx);\n  IdArray c = aten::VecToIdArray(\n      
std::vector<IDX>({\n          0,\n          1,\n          2,\n          3,\n          4,\n      }),\n      sizeof(IDX) * 8, ctx);\n  IdArray x = aten::CSRIsNonZero(csr, r, c);\n  IdArray tx = aten::VecToIdArray(\n      std::vector<IDX>({0, 1, 1, 1, 0}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx)) << \" x = \" << x << \", tx = \" << tx;\n}\n\nTEST(SpmatTest, TestCSRIsNonZero) {\n  _TestCSRIsNonZero1<int32_t>(CPU);\n  _TestCSRIsNonZero1<int64_t>(CPU);\n  _TestCSRIsNonZero2<int32_t>(CPU);\n  _TestCSRIsNonZero2<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCSRIsNonZero1<int32_t>(GPU);\n  _TestCSRIsNonZero1<int64_t>(GPU);\n  _TestCSRIsNonZero2<int32_t>(GPU);\n  _TestCSRIsNonZero2<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRGetRowNNZ(DGLContext ctx) {\n  auto csr = CSR2<IDX>(ctx);\n  ASSERT_EQ(aten::CSRGetRowNNZ(csr, 0), 3);\n  ASSERT_EQ(aten::CSRGetRowNNZ(csr, 3), 0);\n  IdArray r =\n      aten::VecToIdArray(std::vector<IDX>({0, 3}), sizeof(IDX) * 8, ctx);\n  IdArray x = aten::CSRGetRowNNZ(csr, r);\n  IdArray tx =\n      aten::VecToIdArray(std::vector<IDX>({3, 0}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n}\n\nTEST(SpmatTest, TestCSRGetRowNNZ) {\n  _TestCSRGetRowNNZ<int32_t>(CPU);\n  _TestCSRGetRowNNZ<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCSRGetRowNNZ<int32_t>(GPU);\n  _TestCSRGetRowNNZ<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRGetRowColumnIndices(DGLContext ctx) {\n  auto csr = CSR2<IDX>(ctx);\n  auto x = aten::CSRGetRowColumnIndices(csr, 0);\n  auto tx =\n      aten::VecToIdArray(std::vector<IDX>({1, 2, 2}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n  x = aten::CSRGetRowColumnIndices(csr, 1);\n  tx = aten::VecToIdArray(std::vector<IDX>({0}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n  x = aten::CSRGetRowColumnIndices(csr, 3);\n  tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, 
tx));\n}\n\nTEST(SpmatTest, TestCSRGetRowColumnIndices) {\n  _TestCSRGetRowColumnIndices<int32_t>(CPU);\n  _TestCSRGetRowColumnIndices<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCSRGetRowColumnIndices<int32_t>(GPU);\n  _TestCSRGetRowColumnIndices<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRGetRowData(DGLContext ctx) {\n  auto csr = CSR2<IDX>(ctx);\n  auto x = aten::CSRGetRowData(csr, 0);\n  auto tx =\n      aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n  x = aten::CSRGetRowData(csr, 1);\n  tx = aten::VecToIdArray(std::vector<IDX>({3}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n  x = aten::CSRGetRowData(csr, 3);\n  tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n}\n\nTEST(SpmatTest, TestCSRGetRowData) {\n  _TestCSRGetRowData<int32_t>(CPU);\n  _TestCSRGetRowData<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCSRGetRowData<int32_t>(GPU);\n  _TestCSRGetRowData<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRGetData(DGLContext ctx) {\n  auto csr = CSR2<IDX>(ctx);\n  // test get all data\n  auto x = aten::CSRGetAllData(csr, 0, 0);\n  auto tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n  x = aten::CSRGetAllData(csr, 0, 2);\n  tx = aten::VecToIdArray(std::vector<IDX>({2, 5}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n\n  // test get data\n  auto r =\n      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);\n  auto c =\n      aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);\n  x = aten::CSRGetData(csr, r, c);\n  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n\n  // test get data on sorted\n  csr = aten::CSRSort(csr);\n  r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);\n  c = 
aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);\n  x = aten::CSRGetData(csr, r, c);\n  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n\n  // test get data w/ broadcasting\n  r = aten::VecToIdArray(std::vector<IDX>({0}), sizeof(IDX) * 8, ctx);\n  c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);\n  x = aten::CSRGetData(csr, r, c);\n  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));\n}\n\nTEST(SpmatTest, CSRGetData) {\n  _TestCSRGetData<int32_t>(CPU);\n  _TestCSRGetData<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCSRGetData<int32_t>(GPU);\n  _TestCSRGetData<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRGetDataAndIndices(DGLContext ctx) {\n  auto csr = CSR2<IDX>(ctx);\n  auto r =\n      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);\n  auto c =\n      aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);\n  auto x = aten::CSRGetDataAndIndices(csr, r, c);\n  auto tr =\n      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);\n  auto tc =\n      aten::VecToIdArray(std::vector<IDX>({1, 2, 2}), sizeof(IDX) * 8, ctx);\n  auto td =\n      aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x[0], tr));\n  ASSERT_TRUE(ArrayEQ<IDX>(x[1], tc));\n  ASSERT_TRUE(ArrayEQ<IDX>(x[2], td));\n}\n\nTEST(SpmatTest, CSRGetDataAndIndices) {\n  _TestCSRGetDataAndIndices<int32_t>(CPU);\n  _TestCSRGetDataAndIndices<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCSRGetDataAndIndices<int32_t>(GPU);\n  _TestCSRGetDataAndIndices<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRTranspose(DGLContext ctx) {\n  auto csr = CSR2<IDX>(ctx);\n  auto csr_t = aten::CSRTranspose(csr);\n  // [[0, 1, 0, 0],\n  //  [1, 0, 0, 0],\n  //  [2, 0, 1, 0],\n  //  [0, 0, 1, 0],\n  //  [0, 0, 0, 
0]]\n  // data: [3, 0, 2, 5, 1, 4]\n  ASSERT_EQ(csr_t.num_rows, 5);\n  ASSERT_EQ(csr_t.num_cols, 4);\n  auto tp = aten::VecToIdArray(\n      std::vector<IDX>({0, 1, 2, 5, 6, 6}), sizeof(IDX) * 8, ctx);\n  auto ti = aten::VecToIdArray(\n      std::vector<IDX>({1, 0, 0, 0, 2, 2}), sizeof(IDX) * 8, ctx);\n  auto td = aten::VecToIdArray(\n      std::vector<IDX>({3, 0, 2, 5, 1, 4}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(csr_t.indptr, tp));\n  ASSERT_TRUE(ArrayEQ<IDX>(csr_t.indices, ti));\n  ASSERT_TRUE(ArrayEQ<IDX>(csr_t.data, td));\n}\n\nTEST(SpmatTest, CSRTranspose) {\n  _TestCSRTranspose<int32_t>(CPU);\n  _TestCSRTranspose<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCSRTranspose<int32_t>(GPU);\n  _TestCSRTranspose<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRToCOO(DGLContext ctx) {\n  auto csr = CSR2<IDX>(ctx);\n  {\n    auto coo = CSRToCOO(csr, false);\n    ASSERT_EQ(coo.num_rows, 4);\n    ASSERT_EQ(coo.num_cols, 5);\n    ASSERT_TRUE(coo.row_sorted);\n    auto tr = aten::VecToIdArray(\n        std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX) * 8, ctx);\n    ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));\n    ASSERT_TRUE(ArrayEQ<IDX>(coo.col, csr.indices));\n    ASSERT_TRUE(ArrayEQ<IDX>(coo.data, csr.data));\n\n    // convert from sorted csr\n    auto s_csr = CSRSort(csr);\n    coo = CSRToCOO(s_csr, false);\n    ASSERT_EQ(coo.num_rows, 4);\n    ASSERT_EQ(coo.num_cols, 5);\n    ASSERT_TRUE(coo.row_sorted);\n    ASSERT_TRUE(coo.col_sorted);\n    tr = aten::VecToIdArray(\n        std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX) * 8, ctx);\n    ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));\n    ASSERT_TRUE(ArrayEQ<IDX>(coo.col, s_csr.indices));\n    ASSERT_TRUE(ArrayEQ<IDX>(coo.data, s_csr.data));\n  }\n  {\n    auto coo = CSRToCOO(csr, true);\n    ASSERT_EQ(coo.num_rows, 4);\n    ASSERT_EQ(coo.num_cols, 5);\n    auto tcoo = COO2<IDX>(ctx);\n    ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tcoo.row));\n    ASSERT_TRUE(ArrayEQ<IDX>(coo.col, tcoo.col));\n  
}\n}\n\nTEST(SpmatTest, CSRToCOO) {\n  _TestCSRToCOO<int32_t>(CPU);\n  _TestCSRToCOO<int64_t>(CPU);\n#if DGL_USE_CUDA\n  _TestCSRToCOO<int32_t>(GPU);\n  _TestCSRToCOO<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRSliceRows(DGLContext ctx) {\n  auto csr = CSR2<IDX>(ctx);\n  auto x = aten::CSRSliceRows(csr, 1, 4);\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [3, 1, 4]\n  ASSERT_EQ(x.num_rows, 3);\n  ASSERT_EQ(x.num_cols, 5);\n  auto tp =\n      aten::VecToIdArray(std::vector<IDX>({0, 1, 3, 3}), sizeof(IDX) * 8, ctx);\n  auto ti =\n      aten::VecToIdArray(std::vector<IDX>({0, 2, 3}), sizeof(IDX) * 8, ctx);\n  auto td =\n      aten::VecToIdArray(std::vector<IDX>({3, 1, 4}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n\n  auto r =\n      aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX) * 8, ctx);\n  x = aten::CSRSliceRows(csr, r);\n  // [[0, 1, 2, 0, 0],\n  //  [1, 0, 0, 0, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: [0, 2, 5, 3]\n  tp = aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 4}), sizeof(IDX) * 8, ctx);\n  ti = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0}), sizeof(IDX) * 8, ctx);\n  td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n\n  // Testing non-increasing row id based slicing\n  r = aten::VecToIdArray(std::vector<IDX>({3, 2, 1}), sizeof(IDX) * 8, ctx);\n  x = aten::CSRSliceRows(csr, r);\n  // [[0, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [1, 0, 0, 0, 0]]\n  // data: [1, 4, 3]\n  tp = aten::VecToIdArray(std::vector<IDX>({0, 0, 2, 3}), sizeof(IDX) * 8, ctx);\n  ti = aten::VecToIdArray(std::vector<IDX>({2, 3, 0}), sizeof(IDX) * 8, ctx);\n  td = aten::VecToIdArray(std::vector<IDX>({1, 4, 3}), sizeof(IDX) * 8, ctx);\n  
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n\n  // Testing zero-degree row slicing with different rows\n  r = aten::VecToIdArray(\n      std::vector<IDX>({1, 3, 0, 3, 2}), sizeof(IDX) * 8, ctx);\n  x = aten::CSRSliceRows(csr, r);\n  // [[1, 0, 0, 0, 0],\n  //  [0, 0, 0, 0, 0],\n  //  [0, 1, 2, 0, 0],\n  //  [0, 0, 0, 0, 0],\n  //  [0, 0, 1, 1, 0]]\n  // data: [3, 0, 2, 5, 1, 4]\n  tp = aten::VecToIdArray(\n      std::vector<IDX>({0, 1, 1, 4, 4, 6}), sizeof(IDX) * 8, ctx);\n  ti = aten::VecToIdArray(\n      std::vector<IDX>({0, 1, 2, 2, 2, 3}), sizeof(IDX) * 8, ctx);\n  td = aten::VecToIdArray(\n      std::vector<IDX>({3, 0, 2, 5, 1, 4}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n\n  // Testing empty output (i.e. sliced rows will be zero-degree)\n  r = aten::VecToIdArray(std::vector<IDX>({3, 3, 3}), sizeof(IDX) * 8, ctx);\n  x = aten::CSRSliceRows(csr, r);\n  // [[0, 0, 0, 0, 0],\n  //  [0, 0, 0, 0, 0],\n  //  [0, 0, 0, 0, 0]]\n  // data: []\n  tp = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 0}), sizeof(IDX) * 8, ctx);\n  ti = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n  td = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n\n  // Testing constant output: we pick last row with at least one nnz\n  r = aten::VecToIdArray(std::vector<IDX>({2, 2, 2}), sizeof(IDX) * 8, ctx);\n  x = aten::CSRSliceRows(csr, r);\n  // [[0, 0, 1, 1, 0],\n  //  [0, 0, 1, 1, 0],\n  //  [0, 0, 1, 1, 0]]\n  // data: [1, 4, 1, 4, 1, 4]\n  tp = aten::VecToIdArray(std::vector<IDX>({0, 2, 4, 6}), sizeof(IDX) * 8, ctx);\n  ti = aten::VecToIdArray(\n      std::vector<IDX>({2, 3, 2, 3, 2, 3}), sizeof(IDX) * 8, ctx);\n  td = 
aten::VecToIdArray(\n      std::vector<IDX>({1, 4, 1, 4, 1, 4}), sizeof(IDX) * 8, ctx);\n  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n}\n\nTEST(SpmatTest, TestCSRSliceRows) {\n  _TestCSRSliceRows<int32_t>(CPU);\n  _TestCSRSliceRows<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCSRSliceRows<int32_t>(GPU);\n  _TestCSRSliceRows<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRSliceMatrix1(DGLContext ctx) {\n  auto csr = CSR2<IDX>(ctx);\n  {\n    // square\n    auto r =\n        aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX) * 8, ctx);\n    auto c =\n        aten::VecToIdArray(std::vector<IDX>({1, 2, 3}), sizeof(IDX) * 8, ctx);\n    auto x = aten::CSRSliceMatrix(csr, r, c);\n    // [[1, 2, 0],\n    //  [0, 0, 0],\n    //  [0, 0, 0]]\n    // data: [0, 2, 5]\n    ASSERT_EQ(x.num_rows, 3);\n    ASSERT_EQ(x.num_cols, 3);\n    auto tp = aten::VecToIdArray(\n        std::vector<IDX>({0, 3, 3, 3}), sizeof(IDX) * 8, ctx);\n    auto ti =\n        aten::VecToIdArray(std::vector<IDX>({0, 1, 1}), sizeof(IDX) * 8, ctx);\n    auto td =\n        aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX) * 8, ctx);\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n  }\n  {\n    // non-square\n    auto r =\n        aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);\n    auto c = aten::VecToIdArray(std::vector<IDX>({0, 1}), sizeof(IDX) * 8, ctx);\n    auto x = aten::CSRSliceMatrix(csr, r, c);\n    // [[0, 1],\n    //  [1, 0],\n    //  [0, 0]]\n    // data: [0, 3]\n    ASSERT_EQ(x.num_rows, 3);\n    ASSERT_EQ(x.num_cols, 2);\n    auto tp = aten::VecToIdArray(\n        std::vector<IDX>({0, 1, 2, 2}), sizeof(IDX) * 8, ctx);\n    auto ti =\n        aten::VecToIdArray(std::vector<IDX>({1, 0}), sizeof(IDX) * 8, ctx);\n    auto td =\n        
aten::VecToIdArray(std::vector<IDX>({0, 3}), sizeof(IDX) * 8, ctx);\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n  }\n  {\n    // empty slice\n    auto r = aten::VecToIdArray(std::vector<IDX>({2, 3}), sizeof(IDX) * 8, ctx);\n    auto c = aten::VecToIdArray(std::vector<IDX>({0, 1}), sizeof(IDX) * 8, ctx);\n    auto x = aten::CSRSliceMatrix(csr, r, c);\n    // [[0, 0],\n    //  [0, 0]]\n    // data: []\n    ASSERT_EQ(x.num_rows, 2);\n    ASSERT_EQ(x.num_cols, 2);\n    auto tp =\n        aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);\n    auto ti = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n    auto td = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n  }\n}\n\ntemplate <typename IDX>\nvoid _TestCSRSliceMatrix2(DGLContext ctx) {\n  auto csr = CSR3<IDX>(ctx);\n  {\n    // square\n    auto r =\n        aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX) * 8, ctx);\n    auto c =\n        aten::VecToIdArray(std::vector<IDX>({1, 2, 3}), sizeof(IDX) * 8, ctx);\n    auto x = aten::CSRSliceMatrix(csr, r, c);\n    // [[1, 1, 1],\n    //  [0, 0, 0],\n    //  [0, 0, 0]]\n    // data: [5, 2, 0]\n    ASSERT_EQ(x.num_rows, 3);\n    ASSERT_EQ(x.num_cols, 3);\n    auto tp = aten::VecToIdArray(\n        std::vector<IDX>({0, 3, 3, 3}), sizeof(IDX) * 8, ctx);\n    // indexes are in reverse order in CSR3\n    auto ti =\n        aten::VecToIdArray(std::vector<IDX>({2, 1, 0}), sizeof(IDX) * 8, ctx);\n    auto td =\n        aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX) * 8, ctx);\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n  }\n  {\n    // non-square\n    auto r =\n        
aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);\n    auto c = aten::VecToIdArray(std::vector<IDX>({0, 1}), sizeof(IDX) * 8, ctx);\n    auto x = aten::CSRSliceMatrix(csr, r, c);\n    // [[0, 1],\n    //  [1, 0],\n    //  [0, 0]]\n    // data: [0, 3]\n    ASSERT_EQ(x.num_rows, 3);\n    ASSERT_EQ(x.num_cols, 2);\n    auto tp = aten::VecToIdArray(\n        std::vector<IDX>({0, 1, 2, 2}), sizeof(IDX) * 8, ctx);\n    auto ti =\n        aten::VecToIdArray(std::vector<IDX>({1, 0}), sizeof(IDX) * 8, ctx);\n    auto td =\n        aten::VecToIdArray(std::vector<IDX>({5, 3}), sizeof(IDX) * 8, ctx);\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n  }\n  {\n    // empty slice\n    auto r = aten::VecToIdArray(std::vector<IDX>({2, 3}), sizeof(IDX) * 8, ctx);\n    auto c = aten::VecToIdArray(std::vector<IDX>({0, 1}), sizeof(IDX) * 8, ctx);\n    auto x = aten::CSRSliceMatrix(csr, r, c);\n    // [[0, 0],\n    //  [0, 0]]\n    // data: []\n    ASSERT_EQ(x.num_rows, 2);\n    ASSERT_EQ(x.num_cols, 2);\n    auto tp =\n        aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);\n    auto ti = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n    auto td = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));\n    ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));\n  }\n}\n\nTEST(SpmatTest, CSRSliceMatrix) {\n  _TestCSRSliceMatrix1<int32_t>(CPU);\n  _TestCSRSliceMatrix1<int64_t>(CPU);\n  _TestCSRSliceMatrix2<int32_t>(CPU);\n  _TestCSRSliceMatrix2<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCSRSliceMatrix1<int32_t>(GPU);\n  _TestCSRSliceMatrix1<int64_t>(GPU);\n  _TestCSRSliceMatrix2<int32_t>(GPU);\n  _TestCSRSliceMatrix2<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRHasDuplicate(DGLContext ctx) {\n  auto csr = CSR1<IDX>(ctx);\n  
ASSERT_FALSE(aten::CSRHasDuplicate(csr));\n  csr = CSR2<IDX>(ctx);\n  ASSERT_TRUE(aten::CSRHasDuplicate(csr));\n}\n\nTEST(SpmatTest, CSRHasDuplicate) {\n  _TestCSRHasDuplicate<int32_t>(CPU);\n  _TestCSRHasDuplicate<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCSRHasDuplicate<int32_t>(GPU);\n  _TestCSRHasDuplicate<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRSort(DGLContext ctx) {\n  auto csr = CSR1<IDX>(ctx);\n  ASSERT_FALSE(aten::CSRIsSorted(csr));\n  auto csr1 = aten::CSRSort(csr);\n  ASSERT_FALSE(aten::CSRIsSorted(csr));\n  ASSERT_TRUE(aten::CSRIsSorted(csr1));\n  ASSERT_TRUE(csr1.sorted);\n  aten::CSRSort_(&csr);\n  ASSERT_TRUE(aten::CSRIsSorted(csr));\n  ASSERT_TRUE(csr.sorted);\n  csr = CSR2<IDX>(ctx);\n  ASSERT_TRUE(aten::CSRIsSorted(csr));\n}\n\nTEST(SpmatTest, CSRSort) {\n  _TestCSRSort<int32_t>(CPU);\n  _TestCSRSort<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestCSRSort<int32_t>(GPU);\n  _TestCSRSort<int64_t>(GPU);\n#endif\n}\n\ntemplate <typename IDX>\nvoid _TestCSRReorder() {\n  auto csr = CSR2<IDX>();\n  auto new_row =\n      aten::VecToIdArray(std::vector<IDX>({2, 0, 3, 1}), sizeof(IDX) * 8, CTX);\n  auto new_col = aten::VecToIdArray(\n      std::vector<IDX>({2, 0, 4, 3, 1}), sizeof(IDX) * 8, CTX);\n  auto new_csr = CSRReorder(csr, new_row, new_col);\n  ASSERT_EQ(new_csr.num_rows, csr.num_rows);\n  ASSERT_EQ(new_csr.num_cols, csr.num_cols);\n}\n\nTEST(SpmatTest, TestCSRReorder) {\n  _TestCSRReorder<int32_t>();\n  _TestCSRReorder<int64_t>();\n}\n"
  },
  {
    "path": "tests/cpp/test_spmm.cc",
    "content": "#if !defined(_WIN32)\n#include <../src/array/cpu/spmm.h>\n#include <dgl/array.h>\n#include <gtest/gtest.h>\n#include <time.h>\n\n#include <random>\n\n#include \"./common.h\"\n\nusing namespace dgl;\nusing namespace dgl::runtime;\n\nint sizes[] = {1, 7, 8, 9, 31, 32, 33, 54, 63, 64, 65, 256, 257};\nnamespace ns_op = dgl::aten::cpu::op;\nnamespace {\n\ntemplate <class T>\nvoid GenerateData(T* data, int dim, T mul) {\n  for (int i = 0; i < dim; i++) {\n    data[i] = (i + 1) * mul;\n  }\n}\n\ntemplate <class T>\nvoid GenerateRandomData(T* data, int dim) {\n  std::mt19937 rng(std::random_device{}());\n  std::uniform_int_distribution<> dist(0, 10000);\n  for (int i = 0; i < dim; i++) {\n    data[i] = (dist(rng) / 100);\n  }\n}\n\ntemplate <class T>\nvoid GenerateZeroData(T* data, int dim) {\n  for (int i = 0; i < dim; i++) {\n    data[i] = 0;\n  }\n}\n\ntemplate <class T>\nvoid Copy(T* exp, T* out, T* hs, int dim) {\n  for (int i = 0; i < dim; i++) {\n    exp[i] = out[i] + hs[i];\n  }\n}\n\ntemplate <class T>\nvoid Add(T* exp, T* out, T* lhs, T* rhs, int dim) {\n  for (int i = 0; i < dim; i++) {\n    exp[i] = out[i] + lhs[i] + rhs[i];\n  }\n}\n\ntemplate <class T>\nvoid Sub(T* exp, T* out, T* lhs, T* rhs, int dim) {\n  for (int i = 0; i < dim; i++) {\n    exp[i] = out[i] + lhs[i] - rhs[i];\n  }\n}\n\ntemplate <class T>\nvoid Mul(T* exp, T* out, T* lhs, T* rhs, int dim) {\n  for (int i = 0; i < dim; i++) {\n    exp[i] = (out[i] + (lhs[i] * rhs[i]));\n  }\n}\n\ntemplate <class T>\nvoid Div(T* exp, T* out, T* lhs, T* rhs, int dim) {\n  for (int i = 0; i < dim; i++) {\n    exp[i] = (out[i] + (lhs[i] / rhs[i]));\n  }\n}\n\ntemplate <class T>\nvoid CheckResult(T* exp, T* out, int dim) {\n  for (int i = 0; i < dim; i++) {\n    ASSERT_TRUE(exp[i] == out[i]);\n  }\n}\n\n}  // namespace\n\ntemplate <typename IDX>\nvoid _TestSpmmCopyLhs() {\n  for (size_t i = 0; i < sizeof(sizes) / sizeof(int); i++) {\n    int dim = sizes[i];\n    IDX out[dim], exp[dim], lhs[dim];\n 
   GenerateZeroData(out, dim);\n    GenerateRandomData(lhs, dim);\n\n    // Calculation of expected output - 'exp'\n    Copy(exp, out, lhs, dim);\n\n    // Calculation of output using legacy path - 'out'\n    for (int k = 0; k < dim; k++) {\n      out[k] += ns_op::CopyLhs<IDX>::Call(lhs + k, nullptr);\n    }\n\n    CheckResult(exp, out, dim);\n  }\n}\n\nTEST(SpmmTest, TestSpmmCopyLhs) {\n  _TestSpmmCopyLhs<float>();\n  _TestSpmmCopyLhs<double>();\n  _TestSpmmCopyLhs<BFloat16>();\n}\n\ntemplate <typename IDX>\nvoid _TestSpmmCopyRhs() {\n  for (size_t i = 0; i < sizeof(sizes) / sizeof(int); i++) {\n    int dim = sizes[i];\n    IDX out[dim], exp[dim], rhs[dim];\n    GenerateZeroData(out, dim);\n    GenerateRandomData(rhs, dim);\n\n    // Calculation of expected output - 'exp'\n    Copy(exp, out, rhs, dim);\n\n    // Calculation of output using legacy path - 'out'\n    for (int k = 0; k < dim; k++) {\n      out[k] += ns_op::CopyRhs<IDX>::Call(nullptr, rhs + k);\n    }\n\n    CheckResult(exp, out, dim);\n  }\n}\n\nTEST(SpmmTest, TestSpmmCopyRhs) {\n  _TestSpmmCopyRhs<float>();\n  _TestSpmmCopyRhs<double>();\n  _TestSpmmCopyRhs<BFloat16>();\n}\n\ntemplate <typename IDX>\nvoid _TestSpmmAdd() {\n  for (size_t i = 0; i < sizeof(sizes) / sizeof(int); i++) {\n    int dim = sizes[i];\n    IDX out[dim], exp[dim], lhs[dim], rhs[dim];\n    GenerateZeroData(out, dim);\n    GenerateRandomData(lhs, dim);\n    GenerateRandomData(rhs, dim);\n\n    // Calculation of expected output - 'exp'\n    Add(exp, out, lhs, rhs, dim);\n\n    // Calculation of output using legacy path - 'out'\n    for (int k = 0; k < dim; k++) {\n      out[k] += ns_op::Add<IDX>::Call(lhs + k, rhs + k);\n    }\n\n    CheckResult(exp, out, dim);\n  }\n}\n\nTEST(SpmmTest, TestSpmmAdd) {\n  _TestSpmmAdd<float>();\n  _TestSpmmAdd<double>();\n  _TestSpmmAdd<BFloat16>();\n}\n\ntemplate <typename IDX>\nvoid _TestSpmmSub() {\n  for (size_t i = 0; i < sizeof(sizes) / sizeof(int); i++) {\n    int dim = sizes[i];\n    IDX 
out[dim], exp[dim], lhs[dim], rhs[dim];\n    GenerateZeroData(out, dim);\n    GenerateRandomData(lhs, dim);\n    GenerateRandomData(rhs, dim);\n\n    // Calculation of expected output - 'exp'\n    Sub(exp, out, lhs, rhs, dim);\n\n    // Calculation of output using legacy path - 'out'\n    for (int k = 0; k < dim; k++) {\n      out[k] += ns_op::Sub<IDX>::Call(lhs + k, rhs + k);\n    }\n\n    CheckResult(exp, out, dim);\n  }\n}\n\nTEST(SpmmTest, TestSpmmSub) {\n  _TestSpmmSub<float>();\n  _TestSpmmSub<double>();\n  _TestSpmmSub<BFloat16>();\n}\n\ntemplate <typename IDX>\nvoid _TestSpmmMul() {\n  for (size_t i = 0; i < sizeof(sizes) / sizeof(int); i++) {\n    int dim = sizes[i];\n    IDX out[dim], exp[dim], lhs[dim], rhs[dim];\n    GenerateZeroData(out, dim);\n    GenerateRandomData(lhs, dim);\n    GenerateRandomData(rhs, dim);\n\n    // Calculation of expected output - 'exp'\n    Mul(exp, out, lhs, rhs, dim);\n\n    // Calculation of output using legacy path - 'out'\n    for (int k = 0; k < dim; k++) {\n      out[k] += ns_op::Mul<IDX>::Call(lhs + k, rhs + k);\n    }\n\n    CheckResult(exp, out, dim);\n  }\n}\n\nTEST(SpmmTest, TestSpmmMul) {\n  _TestSpmmMul<float>();\n  _TestSpmmMul<double>();\n  _TestSpmmMul<BFloat16>();\n}\n\ntemplate <typename IDX>\nvoid _TestSpmmDiv() {\n  for (size_t i = 0; i < sizeof(sizes) / sizeof(int); i++) {\n    int dim = sizes[i];\n    IDX out[dim], exp[dim], lhs[dim], rhs[dim];\n    GenerateZeroData(out, dim);\n    GenerateData(lhs, dim, (IDX)15);\n    GenerateData(rhs, dim, (IDX)1);\n\n    // Calculation of expected output - 'exp'\n    Div(exp, out, lhs, rhs, dim);\n\n    // Calculation of output using legacy path - 'out'\n    for (int k = 0; k < dim; k++) {\n      out[k] += ns_op::Div<IDX>::Call(lhs + k, rhs + k);\n    }\n\n    CheckResult(exp, out, dim);\n  }\n}\n\nTEST(SpmmTest, TestSpmmDiv) {\n  _TestSpmmDiv<float>();\n  _TestSpmmDiv<double>();\n  _TestSpmmDiv<BFloat16>();\n}\n#endif  // _WIN32\n"
  },
  {
    "path": "tests/cpp/test_unit_graph.cc",
    "content": "/**\n *  Copyright (c) 2019 by Contributors\n * @file test_unit_graph.cc\n * @brief Test UnitGraph\n */\n#include <dgl/array.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/runtime/device_api.h>\n#include <gtest/gtest.h>\n\n#include <memory>\n#include <vector>\n\n#include \"../../src/graph/unit_graph.h\"\n#include \"./../src/graph/heterograph.h\"\n#include \"./common.h\"\n\nusing namespace dgl;\nusing namespace dgl::runtime;\n\ntemplate <typename IdType>\naten::CSRMatrix CSR1(DGLContext ctx) {\n  /**\n   * G = [[0, 0, 1],\n   *      [1, 0, 1],\n   *      [0, 1, 0],\n   *      [1, 0, 1]]\n   */\n  IdArray g_indptr = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 3, 4, 6}), sizeof(IdType) * 8, CTX);\n  IdArray g_indices = aten::VecToIdArray(\n      std::vector<IdType>({2, 0, 2, 1, 0, 2}), sizeof(IdType) * 8, CTX);\n\n  const aten::CSRMatrix &csr_a =\n      aten::CSRMatrix(4, 3, g_indptr, g_indices, aten::NullArray(), false);\n  return csr_a;\n}\n\ntemplate aten::CSRMatrix CSR1<int32_t>(DGLContext ctx);\ntemplate aten::CSRMatrix CSR1<int64_t>(DGLContext ctx);\n\ntemplate <typename IdType>\naten::COOMatrix COO1(DGLContext ctx) {\n  /**\n   * G = [[1, 1, 0],\n   *      [0, 1, 0]]\n   */\n  IdArray g_row = aten::VecToIdArray(\n      std::vector<IdType>({0, 0, 1}), sizeof(IdType) * 8, CTX);\n  IdArray g_col = aten::VecToIdArray(\n      std::vector<IdType>({0, 1, 1}), sizeof(IdType) * 8, CTX);\n  const aten::COOMatrix &coo =\n      aten::COOMatrix(2, 3, g_row, g_col, aten::NullArray(), true, true);\n\n  return coo;\n}\n\ntemplate aten::COOMatrix COO1<int32_t>(DGLContext ctx);\ntemplate aten::COOMatrix COO1<int64_t>(DGLContext ctx);\n\ntemplate <typename IdType>\nvoid _TestUnitGraph_InOutDegrees(DGLContext ctx) {\n  /**\n  InDegree(s) is available only if COO or CSC formats permitted.\n  OutDegree(s) is available only if COO or CSR formats permitted.\n  */\n\n  // COO\n  {\n    const aten::COOMatrix &coo = COO1<IdType>(ctx);\n    auto &&g = 
CreateFromCOO(2, coo, COO_CODE);\n    ASSERT_EQ(g->InDegree(0, 0), 1);\n    auto &&nids = aten::Range(0, g->NumVertices(0), g->NumBits(), g->Context());\n    ASSERT_TRUE(ArrayEQ<IdType>(\n        g->InDegrees(0, nids),\n        aten::VecToIdArray<IdType>({1, 2}, g->NumBits(), g->Context())));\n    ASSERT_EQ(g->OutDegree(0, 0), 2);\n    ASSERT_TRUE(ArrayEQ<IdType>(\n        g->OutDegrees(0, nids),\n        aten::VecToIdArray<IdType>({2, 1}, g->NumBits(), g->Context())));\n  }\n  // CSC\n  {\n    const aten::CSRMatrix &csr = CSR1<IdType>(ctx);\n    auto &&g = CreateFromCSC(2, csr, CSC_CODE);\n    ASSERT_EQ(g->InDegree(0, 0), 1);\n    auto &&nids = aten::Range(0, g->NumVertices(0), g->NumBits(), g->Context());\n    ASSERT_TRUE(ArrayEQ<IdType>(\n        g->InDegrees(0, nids),\n        aten::VecToIdArray<IdType>({1, 2, 1}, g->NumBits(), g->Context())));\n    EXPECT_ANY_THROW(g->OutDegree(0, 0));\n    EXPECT_ANY_THROW(g->OutDegrees(0, nids));\n  }\n  // CSR\n  {\n    const aten::CSRMatrix &csr = CSR1<IdType>(ctx);\n    auto &&g = CreateFromCSR(2, csr, CSR_CODE);\n    ASSERT_EQ(g->OutDegree(0, 0), 1);\n    auto &&nids = aten::Range(0, g->NumVertices(0), g->NumBits(), g->Context());\n    ASSERT_TRUE(ArrayEQ<IdType>(\n        g->OutDegrees(0, nids),\n        aten::VecToIdArray<IdType>({1, 2, 1, 2}, g->NumBits(), g->Context())));\n    EXPECT_ANY_THROW(g->InDegree(0, 0));\n    EXPECT_ANY_THROW(g->InDegrees(0, nids));\n  }\n}\n\ntemplate <typename IdType>\nvoid _TestUnitGraph(DGLContext ctx) {\n  const aten::CSRMatrix &csr = CSR1<IdType>(ctx);\n  const aten::COOMatrix &coo = COO1<IdType>(ctx);\n\n  auto g = CreateFromCSC(2, csr);\n  ASSERT_EQ(g->GetCreatedFormats(), 4);\n\n  g = CreateFromCSR(2, csr);\n  ASSERT_EQ(g->GetCreatedFormats(), 2);\n\n  g = CreateFromCOO(2, coo);\n  ASSERT_EQ(g->GetCreatedFormats(), 1);\n\n  auto src = aten::VecToIdArray<int64_t>({1, 2, 5, 3});\n  auto dst = aten::VecToIdArray<int64_t>({1, 6, 2, 6});\n  auto mg = dgl::UnitGraph::CreateFromCOO(2, 9, 
8, src, dst, COO_CODE);\n  ASSERT_EQ(mg->GetCreatedFormats(), 1);\n  auto hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, COO_CODE);\n  auto img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());\n  ASSERT_TRUE(img != nullptr);\n  mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, CSR_CODE | COO_CODE);\n  ASSERT_EQ(mg->GetCreatedFormats(), 1);\n  hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, CSR_CODE | COO_CODE);\n  img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());\n  ASSERT_TRUE(img != nullptr);\n  mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, CSC_CODE | COO_CODE);\n  ASSERT_EQ(mg->GetCreatedFormats(), 1);\n  hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, CSC_CODE | COO_CODE);\n  img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());\n  ASSERT_TRUE(img != nullptr);\n\n  g = CreateFromCSC(2, csr);\n  ASSERT_EQ(g->GetCreatedFormats(), 4);\n\n  g = CreateFromCSR(2, csr);\n  ASSERT_EQ(g->GetCreatedFormats(), 2);\n\n  g = CreateFromCOO(2, coo);\n  ASSERT_EQ(g->GetCreatedFormats(), 1);\n}\n\ntemplate <typename IdType>\nvoid _TestUnitGraph_GetInCSR(DGLContext ctx) {\n  const aten::CSRMatrix &csr = CSR1<IdType>(ctx);\n  const aten::COOMatrix &coo = COO1<IdType>(ctx);\n\n  auto g = CreateFromCSC(2, csr);\n  auto in_csr_matrix = g->GetCSCMatrix(0);\n  ASSERT_EQ(in_csr_matrix.num_rows, csr.num_rows);\n  ASSERT_EQ(in_csr_matrix.num_cols, csr.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 4);\n\n  // test out csr\n  g = CreateFromCSR(2, csr);\n  auto g_ptr = g->GetGraphInFormat(CSC_CODE);\n  in_csr_matrix = g_ptr->GetCSCMatrix(0);\n  ASSERT_EQ(in_csr_matrix.num_cols, csr.num_rows);\n  ASSERT_EQ(in_csr_matrix.num_rows, csr.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 2);\n  in_csr_matrix = g->GetCSCMatrix(0);\n  ASSERT_EQ(in_csr_matrix.num_cols, csr.num_rows);\n  ASSERT_EQ(in_csr_matrix.num_rows, csr.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 6);\n\n  // test out coo\n  g = 
CreateFromCOO(2, coo);\n  g_ptr = g->GetGraphInFormat(CSC_CODE);\n  in_csr_matrix = g_ptr->GetCSCMatrix(0);\n  ASSERT_EQ(in_csr_matrix.num_cols, coo.num_rows);\n  ASSERT_EQ(in_csr_matrix.num_rows, coo.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 1);\n\n  in_csr_matrix = g->GetCSCMatrix(0);\n  ASSERT_EQ(in_csr_matrix.num_cols, coo.num_rows);\n  ASSERT_EQ(in_csr_matrix.num_rows, coo.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 5);\n}\n\ntemplate <typename IdType>\nvoid _TestUnitGraph_GetOutCSR(DGLContext ctx) {\n  const aten::CSRMatrix &csr = CSR1<IdType>(ctx);\n  const aten::COOMatrix &coo = COO1<IdType>(ctx);\n\n  auto g = CreateFromCSC(2, csr);\n  auto g_ptr = g->GetGraphInFormat(CSR_CODE);\n  auto out_csr_matrix = g_ptr->GetCSRMatrix(0);\n  ASSERT_EQ(out_csr_matrix.num_cols, csr.num_rows);\n  ASSERT_EQ(out_csr_matrix.num_rows, csr.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 4);\n  out_csr_matrix = g->GetCSRMatrix(0);\n  ASSERT_EQ(out_csr_matrix.num_cols, csr.num_rows);\n  ASSERT_EQ(out_csr_matrix.num_rows, csr.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 6);\n\n  // test out csr\n  g = CreateFromCSR(2, csr);\n  out_csr_matrix = g->GetCSRMatrix(0);\n  ASSERT_EQ(out_csr_matrix.num_rows, csr.num_rows);\n  ASSERT_EQ(out_csr_matrix.num_cols, csr.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 2);\n\n  // test out coo\n  g = CreateFromCOO(2, coo);\n  g_ptr = g->GetGraphInFormat(CSR_CODE);\n  out_csr_matrix = g_ptr->GetCSRMatrix(0);\n  ASSERT_EQ(out_csr_matrix.num_rows, coo.num_rows);\n  ASSERT_EQ(out_csr_matrix.num_cols, coo.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 1);\n\n  out_csr_matrix = g->GetCSRMatrix(0);\n  ASSERT_EQ(out_csr_matrix.num_rows, coo.num_rows);\n  ASSERT_EQ(out_csr_matrix.num_cols, coo.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 3);\n}\n\ntemplate <typename IdType>\nvoid _TestUnitGraph_GetCOO(DGLContext ctx) {\n  const aten::CSRMatrix &csr = CSR1<IdType>(ctx);\n  const aten::COOMatrix &coo = COO1<IdType>(ctx);\n\n  auto g = 
CreateFromCSC(2, csr);\n  auto g_ptr = g->GetGraphInFormat(COO_CODE);\n  auto out_coo_matrix = g_ptr->GetCOOMatrix(0);\n  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_rows);\n  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 4);\n  out_coo_matrix = g->GetCOOMatrix(0);\n  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_rows);\n  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 5);\n\n  // test out csr\n  g = CreateFromCSR(2, csr);\n  g_ptr = g->GetGraphInFormat(COO_CODE);\n  out_coo_matrix = g_ptr->GetCOOMatrix(0);\n  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_rows);\n  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 2);\n  out_coo_matrix = g->GetCOOMatrix(0);\n  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_rows);\n  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 3);\n\n  // test out coo\n  g = CreateFromCOO(2, coo);\n  out_coo_matrix = g->GetCOOMatrix(0);\n  ASSERT_EQ(out_coo_matrix.num_rows, coo.num_rows);\n  ASSERT_EQ(out_coo_matrix.num_cols, coo.num_cols);\n  ASSERT_EQ(g->GetCreatedFormats(), 1);\n}\n\ntemplate <typename IdType>\nvoid _TestUnitGraph_Reserve(DGLContext ctx) {\n  const aten::CSRMatrix &csr = CSR1<IdType>(ctx);\n  const aten::COOMatrix &coo = COO1<IdType>(ctx);\n\n  auto g = CreateFromCSC(2, csr);\n  ASSERT_EQ(g->GetCreatedFormats(), 4);\n  auto r_g =\n      std::dynamic_pointer_cast<UnitGraph>(g->GetRelationGraph(0))->Reverse();\n  ASSERT_EQ(r_g->GetCreatedFormats(), 2);\n  aten::CSRMatrix g_in_csr = g->GetCSCMatrix(0);\n  aten::CSRMatrix r_g_out_csr = r_g->GetCSRMatrix(0);\n  ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);\n  ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);\n  aten::CSRMatrix g_out_csr = g->GetCSRMatrix(0);\n  ASSERT_EQ(g->GetCreatedFormats(), 6);\n  ASSERT_EQ(r_g->GetCreatedFormats(), 6);\n  aten::CSRMatrix r_g_in_csr = r_g->GetCSCMatrix(0);\n  
ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);\n  ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);\n  aten::COOMatrix g_coo = g->GetCOOMatrix(0);\n  ASSERT_EQ(g->GetCreatedFormats(), 7);\n  ASSERT_EQ(r_g->GetCreatedFormats(), 6);\n  aten::COOMatrix r_g_coo = r_g->GetCOOMatrix(0);\n  ASSERT_EQ(r_g->GetCreatedFormats(), 7);\n  ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);\n  ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);\n  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.row, r_g_coo.col));\n  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.col, r_g_coo.row));\n\n  // test out csr\n  g = CreateFromCSR(2, csr);\n  ASSERT_EQ(g->GetCreatedFormats(), 2);\n  r_g = std::dynamic_pointer_cast<UnitGraph>(g->GetRelationGraph(0))->Reverse();\n  ASSERT_EQ(r_g->GetCreatedFormats(), 4);\n  g_out_csr = g->GetCSRMatrix(0);\n  r_g_in_csr = r_g->GetCSCMatrix(0);\n  ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);\n  ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);\n  g_in_csr = g->GetCSCMatrix(0);\n  ASSERT_EQ(g->GetCreatedFormats(), 6);\n  ASSERT_EQ(r_g->GetCreatedFormats(), 6);\n  r_g_out_csr = r_g->GetCSRMatrix(0);\n  ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);\n  ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);\n  g_coo = g->GetCOOMatrix(0);\n  ASSERT_EQ(g->GetCreatedFormats(), 7);\n  ASSERT_EQ(r_g->GetCreatedFormats(), 6);\n  r_g_coo = r_g->GetCOOMatrix(0);\n  ASSERT_EQ(r_g->GetCreatedFormats(), 7);\n  ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);\n  ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);\n  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.row, r_g_coo.col));\n  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.col, r_g_coo.row));\n\n  // test out coo\n  g = CreateFromCOO(2, coo);\n  ASSERT_EQ(g->GetCreatedFormats(), 1);\n  r_g = std::dynamic_pointer_cast<UnitGraph>(g->GetRelationGraph(0))->Reverse();\n  ASSERT_EQ(r_g->GetCreatedFormats(), 1);\n  g_coo = g->GetCOOMatrix(0);\n  r_g_coo = r_g->GetCOOMatrix(0);\n  
ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);\n  ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);\n  ASSERT_TRUE(g_coo.row->data == r_g_coo.col->data);\n  ASSERT_TRUE(g_coo.col->data == r_g_coo.row->data);\n  g_in_csr = g->GetCSCMatrix(0);\n  ASSERT_EQ(g->GetCreatedFormats(), 5);\n  ASSERT_EQ(r_g->GetCreatedFormats(), 3);\n  r_g_out_csr = r_g->GetCSRMatrix(0);\n  ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);\n  ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);\n  g_out_csr = g->GetCSRMatrix(0);\n  ASSERT_EQ(g->GetCreatedFormats(), 7);\n  ASSERT_EQ(r_g->GetCreatedFormats(), 7);\n  r_g_in_csr = r_g->GetCSCMatrix(0);\n  ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);\n  ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);\n}\n\ntemplate <typename IdType>\nvoid _TestUnitGraph_CopyTo(\n    const DGLContext &src_ctx, const DGLContext &dst_ctx) {\n  const aten::CSRMatrix &csr = CSR1<IdType>(src_ctx);\n  const aten::COOMatrix &coo = COO1<IdType>(src_ctx);\n\n  auto device = dgl::runtime::DeviceAPI::Get(dst_ctx);\n  // We don't allow SetStream in DGL for now.\n  auto stream = nullptr;\n\n  auto g = dgl::UnitGraph::CreateFromCSC(2, csr);\n  ASSERT_EQ(g->GetCreatedFormats(), 4);\n  auto cg = dgl::UnitGraph::CopyTo(g, dst_ctx);\n  device->StreamSync(dst_ctx, stream);\n  ASSERT_EQ(cg->GetCreatedFormats(), 4);\n\n  g = dgl::UnitGraph::CreateFromCSR(2, csr);\n  ASSERT_EQ(g->GetCreatedFormats(), 2);\n  cg = dgl::UnitGraph::CopyTo(g, dst_ctx);\n  device->StreamSync(dst_ctx, stream);\n  ASSERT_EQ(cg->GetCreatedFormats(), 2);\n\n  g = dgl::UnitGraph::CreateFromCOO(2, coo);\n  ASSERT_EQ(g->GetCreatedFormats(), 1);\n  cg = dgl::UnitGraph::CopyTo(g, dst_ctx);\n  device->StreamSync(dst_ctx, stream);\n  ASSERT_EQ(cg->GetCreatedFormats(), 1);\n}\n\nTEST(UniGraphTest, TestUnitGraph_CopyTo) {\n  _TestUnitGraph_CopyTo<int32_t>(CPU, CPU);\n  _TestUnitGraph_CopyTo<int64_t>(CPU, CPU);\n#ifdef DGL_USE_CUDA\n  
_TestUnitGraph_CopyTo<int32_t>(CPU, GPU);\n  _TestUnitGraph_CopyTo<int32_t>(GPU, GPU);\n  _TestUnitGraph_CopyTo<int32_t>(GPU, CPU);\n  _TestUnitGraph_CopyTo<int64_t>(CPU, GPU);\n  _TestUnitGraph_CopyTo<int64_t>(GPU, GPU);\n  _TestUnitGraph_CopyTo<int64_t>(GPU, CPU);\n#endif\n}\n\nTEST(UniGraphTest, TestUnitGraph_InOutDegrees) {\n  _TestUnitGraph_InOutDegrees<int32_t>(CPU);\n  _TestUnitGraph_InOutDegrees<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestUnitGraph_InOutDegrees<int32_t>(GPU);\n  _TestUnitGraph_InOutDegrees<int64_t>(GPU);\n#endif\n}\n\nTEST(UniGraphTest, TestUnitGraph_Create) {\n  _TestUnitGraph<int32_t>(CPU);\n  _TestUnitGraph<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestUnitGraph<int32_t>(GPU);\n  _TestUnitGraph<int64_t>(GPU);\n#endif\n}\n\nTEST(UniGraphTest, TestUnitGraph_GetInCSR) {\n  _TestUnitGraph_GetInCSR<int32_t>(CPU);\n  _TestUnitGraph_GetInCSR<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestUnitGraph_GetInCSR<int32_t>(GPU);\n  _TestUnitGraph_GetInCSR<int64_t>(GPU);\n#endif\n}\n\nTEST(UniGraphTest, TestUnitGraph_GetOutCSR) {\n  _TestUnitGraph_GetOutCSR<int32_t>(CPU);\n  _TestUnitGraph_GetOutCSR<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestUnitGraph_GetOutCSR<int32_t>(GPU);\n  _TestUnitGraph_GetOutCSR<int64_t>(GPU);\n#endif\n}\n\nTEST(UniGraphTest, TestUnitGraph_GetCOO) {\n  _TestUnitGraph_GetCOO<int32_t>(CPU);\n  _TestUnitGraph_GetCOO<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestUnitGraph_GetCOO<int32_t>(GPU);\n  _TestUnitGraph_GetCOO<int64_t>(GPU);\n#endif\n}\n\nTEST(UniGraphTest, TestUnitGraph_Reserve) {\n  _TestUnitGraph_Reserve<int32_t>(CPU);\n  _TestUnitGraph_Reserve<int64_t>(CPU);\n#ifdef DGL_USE_CUDA\n  _TestUnitGraph_Reserve<int32_t>(GPU);\n  _TestUnitGraph_Reserve<int64_t>(GPU);\n#endif\n}\n"
  },
  {
    "path": "tests/cpp/test_zerocopy_serialize.cc",
    "content": "#include <dgl/array.h>\n#include <dgl/immutable_graph.h>\n#include <dgl/zerocopy_serializer.h>\n#include <dmlc/memory_io.h>\n#include <gtest/gtest.h>\n\n#include <algorithm>\n#include <iostream>\n#include <vector>\n\n#include \"../../src/graph/heterograph.h\"\n#include \"../../src/graph/unit_graph.h\"\n#include \"./common.h\"\n\n#ifndef _WIN32\n\nusing namespace dgl;\nusing namespace dgl::aten;\nusing namespace dmlc;\n// Function to convert an idarray to string\nstd::string IdArrayToStr(IdArray arr) {\n  arr = arr.CopyTo(DGLContext{kDGLCPU, 0});\n  int64_t len = arr->shape[0];\n  std::ostringstream oss;\n  oss << \"(\" << len << \")[\";\n  if (arr->dtype.bits == 32) {\n    int32_t *data = static_cast<int32_t *>(arr->data);\n    for (int64_t i = 0; i < len; ++i) {\n      oss << data[i] << \" \";\n    }\n  } else {\n    int64_t *data = static_cast<int64_t *>(arr->data);\n    for (int64_t i = 0; i < len; ++i) {\n      oss << data[i] << \" \";\n    }\n  }\n  oss << \"]\";\n  return oss.str();\n}\n\nTEST(ZeroCopySerialize, NDArray) {\n  auto tensor1 = VecToIdArray<int64_t>({1, 2, 5, 3});\n  auto tensor2 = VecToIdArray<int64_t>({6, 6, 5, 7});\n\n  std::string nonzerocopy_blob;\n  dmlc::MemoryStringStream ifs(&nonzerocopy_blob);\n  static_cast<dmlc::Stream *>(&ifs)->Write(tensor1);\n  static_cast<dmlc::Stream *>(&ifs)->Write(tensor2);\n\n  std::string zerocopy_blob;\n  StreamWithBuffer zc_write_strm(&zerocopy_blob, true);\n  zc_write_strm.Write(tensor1);\n  zc_write_strm.Write(tensor2);\n\n  EXPECT_EQ(nonzerocopy_blob.size() - zerocopy_blob.size(), 126)\n      << \"Invalid save\";\n\n  std::vector<void *> new_ptr_list;\n  // Use memcpy to mimic remote machine reconstruction\n  for (auto ptr : zc_write_strm.buffer_list()) {\n    auto new_ptr = malloc(ptr.size);\n    memcpy(new_ptr, ptr.data, ptr.size);\n    new_ptr_list.emplace_back(new_ptr);\n  }\n\n  NDArray loadtensor1, loadtensor2;\n  StreamWithBuffer zc_read_strm(&zerocopy_blob, new_ptr_list);\n  
zc_read_strm.Read(&loadtensor1);\n  zc_read_strm.Read(&loadtensor2);\n}\n\nTEST(ZeroCopySerialize, ZeroShapeNDArray) {\n  auto tensor1 = VecToIdArray<int64_t>({6, 6, 5, 7});\n  auto tensor2 = VecToIdArray<int64_t>({});\n  auto tensor3 = VecToIdArray<int64_t>({6, 6, 2, 7});\n  std::vector<NDArray> ndvec;\n  ndvec.push_back(tensor1);\n  ndvec.push_back(tensor2);\n  ndvec.push_back(tensor3);\n\n  std::string zerocopy_blob;\n  StreamWithBuffer zc_write_strm(&zerocopy_blob, true);\n  zc_write_strm.Write(ndvec);\n\n  std::vector<void *> new_ptr_list;\n  // Use memcpy to mimic remote machine reconstruction\n  for (auto ptr : zc_write_strm.buffer_list()) {\n    auto new_ptr = malloc(ptr.size);\n    memcpy(new_ptr, ptr.data, ptr.size);\n    new_ptr_list.emplace_back(new_ptr);\n  }\n\n  std::vector<NDArray> ndvec_read;\n  StreamWithBuffer zc_read_strm(&zerocopy_blob, new_ptr_list);\n  zc_read_strm.Read(&ndvec_read);\n  EXPECT_EQ(ndvec_read[1]->ndim, 1);\n  EXPECT_EQ(ndvec_read[1]->shape[0], 0);\n}\n\nTEST(ZeroCopySerialize, SharedMem) {\n  auto tensor1 = VecToIdArray<int64_t>({1, 2, 5, 3});\n  DGLDataType dtype = {kDGLInt, 64, 1};\n  std::vector<int64_t> shape{4};\n  DGLContext cpu_ctx = {kDGLCPU, 0};\n  auto shared_tensor =\n      NDArray::EmptyShared(\"test\", shape, dtype, cpu_ctx, true);\n  shared_tensor.CopyFrom(tensor1);\n\n  std::string nonzerocopy_blob;\n  dmlc::MemoryStringStream ifs(&nonzerocopy_blob);\n  static_cast<dmlc::Stream *>(&ifs)->Write(shared_tensor);\n\n  std::string zerocopy_blob;\n  StreamWithBuffer zc_write_strm(&zerocopy_blob, false);\n  zc_write_strm.Write(shared_tensor);\n\n  EXPECT_EQ(nonzerocopy_blob.size() - zerocopy_blob.size(), 51)\n      << \"Invalid save\";\n  NDArray loadtensor1;\n\n  StreamWithBuffer zc_read_strm = StreamWithBuffer(&zerocopy_blob, false);\n  zc_read_strm.Read(&loadtensor1);\n}\n\nTEST(ZeroCopySerialize, HeteroGraph) {\n  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});\n  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});\n  
auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);\n  src = VecToIdArray<int64_t>({6, 2, 5, 1, 8});\n  dst = VecToIdArray<int64_t>({5, 2, 4, 8, 0});\n  auto mg2 = dgl::UnitGraph::CreateFromCOO(1, 9, 9, src, dst);\n  std::vector<HeteroGraphPtr> relgraphs;\n  relgraphs.push_back(mg1);\n  relgraphs.push_back(mg2);\n  src = VecToIdArray<int64_t>({0, 0});\n  dst = VecToIdArray<int64_t>({1, 0});\n  auto meta_gptr = ImmutableGraph::CreateFromCOO(3, src, dst);\n  auto hrptr = std::make_shared<HeteroGraph>(meta_gptr, relgraphs);\n\n  std::string nonzerocopy_blob;\n  dmlc::MemoryStringStream ifs(&nonzerocopy_blob);\n  static_cast<dmlc::Stream *>(&ifs)->Write(hrptr);\n\n  std::string zerocopy_blob;\n  StreamWithBuffer zc_write_strm(&zerocopy_blob, true);\n  zc_write_strm.Write(hrptr);\n\n  EXPECT_EQ(nonzerocopy_blob.size() - zerocopy_blob.size(), 745)\n      << \"Invalid save\";\n\n  std::vector<void *> new_ptr_list;\n  // Use memcpy to mimic remote machine reconstruction\n  for (auto ptr : zc_write_strm.buffer_list()) {\n    auto new_ptr = malloc(ptr.size);\n    memcpy(new_ptr, ptr.data, ptr.size);\n    new_ptr_list.emplace_back(new_ptr);\n  }\n\n  auto gptr = dgl::Serializer::make_shared<HeteroGraph>();\n  StreamWithBuffer zc_read_strm(&zerocopy_blob, new_ptr_list);\n  zc_read_strm.Read(&gptr);\n\n  EXPECT_EQ(gptr->NumVertices(0), 9);\n  EXPECT_EQ(gptr->NumVertices(1), 8);\n}\n\n#endif  // _WIN32"
  },
  {
    "path": "tests/cugraph/cugraph-ops/test_cugraph_gatconv.py",
    "content": "# pylint: disable=too-many-arguments, too-many-locals\nfrom collections import OrderedDict\nfrom itertools import product\n\nimport dgl\nimport pytest\nimport torch\nfrom dgl.nn import CuGraphGATConv, GATConv\n\noptions = OrderedDict(\n    {\n        \"idtype_int\": [False, True],\n        \"max_in_degree\": [None, 8],\n        \"num_heads\": [1, 3],\n        \"to_block\": [False, True],\n    }\n)\n\n\ndef generate_graph():\n    u = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9])\n    v = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0])\n    g = dgl.graph((u, v))\n    return g\n\n\n@pytest.mark.parametrize(\",\".join(options.keys()), product(*options.values()))\ndef test_gatconv_equality(idtype_int, max_in_degree, num_heads, to_block):\n    device = \"cuda:0\"\n    in_feat, out_feat = 10, 2\n    args = (in_feat, out_feat, num_heads)\n    kwargs = {\"bias\": False}\n    g = generate_graph().to(device)\n    if idtype_int:\n        g = g.int()\n    if to_block:\n        g = dgl.to_block(g)\n    feat = torch.rand(g.num_src_nodes(), in_feat).to(device)\n\n    torch.manual_seed(0)\n    conv1 = GATConv(*args, **kwargs, allow_zero_in_degree=True).to(device)\n    out1 = conv1(g, feat)\n\n    torch.manual_seed(0)\n    conv2 = CuGraphGATConv(*args, **kwargs).to(device)\n    dim = num_heads * out_feat\n    with torch.no_grad():\n        conv2.attn_weights.data[:dim] = conv1.attn_l.data.flatten()\n        conv2.attn_weights.data[dim:] = conv1.attn_r.data.flatten()\n        conv2.fc.weight.data[:] = conv1.fc.weight.data\n    out2 = conv2(g, feat, max_in_degree=max_in_degree)\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n    grad_out1 = torch.rand_like(out1)\n    grad_out2 = grad_out1.clone().detach()\n    out1.backward(grad_out1)\n    out2.backward(grad_out2)\n\n    assert torch.allclose(conv1.fc.weight.grad, conv2.fc.weight.grad, atol=1e-6)\n    assert torch.allclose(\n        torch.cat((conv1.attn_l.grad, conv1.attn_r.grad), 
dim=0),\n        conv2.attn_weights.grad.view(2, num_heads, out_feat),\n        atol=1e-6,\n    )\n"
  },
  {
    "path": "tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py",
    "content": "# pylint: disable=too-many-arguments, too-many-locals\nfrom collections import OrderedDict\nfrom itertools import product\n\nimport dgl\nimport pytest\nimport torch\nfrom dgl.nn import CuGraphRelGraphConv, RelGraphConv\n\n# TODO(tingyu66): Re-enable the following tests after updating cuGraph CI image.\noptions = OrderedDict(\n    {\n        \"idtype_int\": [False, True],\n        \"max_in_degree\": [None, 8],\n        \"num_bases\": [1, 2, 5],\n        \"regularizer\": [None, \"basis\"],\n        \"self_loop\": [False, True],\n        \"to_block\": [False, True],\n    }\n)\n\n\ndef generate_graph():\n    u = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9])\n    v = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0])\n    g = dgl.graph((u, v))\n    return g\n\n\n@pytest.mark.parametrize(\",\".join(options.keys()), product(*options.values()))\ndef test_relgraphconv_equality(\n    idtype_int, max_in_degree, num_bases, regularizer, self_loop, to_block\n):\n    device = \"cuda:0\"\n    in_feat, out_feat, num_rels = 10, 2, 3\n    args = (in_feat, out_feat, num_rels)\n    kwargs = {\n        \"num_bases\": num_bases,\n        \"regularizer\": regularizer,\n        \"bias\": False,\n        \"self_loop\": self_loop,\n    }\n    g = generate_graph().to(device)\n    g.edata[dgl.ETYPE] = torch.randint(num_rels, (g.num_edges(),)).to(device)\n    if idtype_int:\n        g = g.int()\n    if to_block:\n        g = dgl.to_block(g)\n    feat = torch.rand(g.num_src_nodes(), in_feat).to(device)\n\n    torch.manual_seed(0)\n    conv1 = RelGraphConv(*args, **kwargs).to(device)\n\n    torch.manual_seed(0)\n    kwargs[\"apply_norm\"] = False\n    conv2 = CuGraphRelGraphConv(*args, **kwargs).to(device)\n\n    out1 = conv1(g, feat, g.edata[dgl.ETYPE])\n    out2 = conv2(g, feat, g.edata[dgl.ETYPE], max_in_degree=max_in_degree)\n    assert torch.allclose(out1, out2, atol=1e-06)\n\n    grad_out = torch.rand_like(out1)\n    out1.backward(grad_out)\n    
out2.backward(grad_out)\n\n    end = -1 if self_loop else None\n    assert torch.allclose(conv1.linear_r.W.grad, conv2.W.grad[:end], atol=1e-6)\n\n    if self_loop:\n        assert torch.allclose(\n            conv1.loop_weight.grad, conv2.W.grad[-1], atol=1e-6\n        )\n\n    if regularizer is not None:\n        assert torch.allclose(\n            conv1.linear_r.coeff.grad, conv2.coeff.grad, atol=1e-6\n        )\n"
  },
  {
    "path": "tests/cugraph/cugraph-ops/test_cugraph_sageconv.py",
    "content": "# pylint: disable=too-many-arguments, too-many-locals\nfrom collections import OrderedDict\nfrom itertools import product\n\nimport dgl\nimport pytest\nimport torch\nfrom dgl.nn import CuGraphSAGEConv, SAGEConv\n\noptions = OrderedDict(\n    {\n        \"idtype_int\": [False, True],\n        \"max_in_degree\": [None, 8],\n        \"to_block\": [False, True],\n    }\n)\n\n\ndef generate_graph():\n    u = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9])\n    v = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0])\n    g = dgl.graph((u, v))\n    return g\n\n\n@pytest.mark.parametrize(\",\".join(options.keys()), product(*options.values()))\ndef test_SAGEConv_equality(idtype_int, max_in_degree, to_block):\n    device = \"cuda:0\"\n    in_feat, out_feat = 5, 2\n    kwargs = {\"aggregator_type\": \"mean\"}\n    g = generate_graph().to(device)\n    if idtype_int:\n        g = g.int()\n    if to_block:\n        g = dgl.to_block(g)\n    feat = torch.rand(g.num_src_nodes(), in_feat).to(device)\n\n    torch.manual_seed(0)\n    conv1 = SAGEConv(in_feat, out_feat, **kwargs).to(device)\n\n    torch.manual_seed(0)\n    conv2 = CuGraphSAGEConv(in_feat, out_feat, **kwargs).to(device)\n\n    with torch.no_grad():\n        conv2.linear.weight.data[:, :in_feat] = conv1.fc_neigh.weight.data\n        conv2.linear.weight.data[:, in_feat:] = conv1.fc_self.weight.data\n        conv2.linear.bias.data[:] = conv1.fc_self.bias.data\n\n    out1 = conv1(g, feat)\n    out2 = conv2(g, feat, max_in_degree=max_in_degree)\n    assert torch.allclose(out1, out2, atol=1e-06)\n\n    grad_out = torch.rand_like(out1)\n    out1.backward(grad_out)\n    out2.backward(grad_out)\n    assert torch.allclose(\n        conv1.fc_neigh.weight.grad,\n        conv2.linear.weight.grad[:, :in_feat],\n        atol=1e-6,\n    )\n    assert torch.allclose(\n        conv1.fc_self.weight.grad,\n        conv2.linear.weight.grad[:, in_feat:],\n        atol=1e-6,\n    )\n    assert 
torch.allclose(\n        conv1.fc_self.bias.grad, conv2.linear.bias.grad, atol=1e-6\n    )\n"
  },
  {
    "path": "tests/cugraph/test_basics.py",
    "content": "# NOTE(vibwu): Currently cugraph must be imported before torch to avoid a resource cleanup issue.\n#    See https://github.com/rapidsai/cugraph/issues/2718\nimport cugraph  # usort: skip\nimport backend as F\n\nimport dgl\n\n\ndef test_dummy():\n    cg = cugraph.Graph()\n    assert cg is not None\n\n\ndef test_to_cugraph_conversion():\n    g = dgl.graph((F.tensor([0, 1, 2, 3]), F.tensor([1, 0, 3, 2]))).to(\"cuda\")\n    cugraph_g = g.to_cugraph()\n\n    assert cugraph_g.number_of_nodes() == g.num_nodes()\n    assert cugraph_g.number_of_edges() == g.num_edges()\n\n    assert cugraph_g.has_edge(0, 1)\n    assert cugraph_g.has_edge(1, 0)\n    assert cugraph_g.has_edge(3, 2)\n\n\ndef test_from_cugraph_conversion():\n    # cudf is a dependency of cugraph\n    import cudf\n\n    # directed graph conversion test\n    cugraph_g = cugraph.Graph(directed=True)\n    df = cudf.DataFrame({\"source\": [0, 1, 2, 3], \"destination\": [1, 2, 3, 2]})\n\n    cugraph_g.from_cudf_edgelist(df)\n\n    g = dgl.from_cugraph(cugraph_g)\n\n    assert g.device.type == \"cuda\"\n    assert g.num_nodes() == cugraph_g.number_of_nodes()\n    assert g.num_edges() == cugraph_g.number_of_edges()\n\n    # assert reverse edges are not present\n    assert g.has_edges_between(0, 1)\n    assert not g.has_edges_between(1, 0)\n    assert g.has_edges_between(1, 2)\n    assert not g.has_edges_between(2, 1)\n    assert g.has_edges_between(2, 3)\n\n    # undirected graph conversion test\n    cugraph_g = cugraph.Graph(directed=False)\n    df = cudf.DataFrame({\"source\": [0, 1, 2, 3], \"destination\": [1, 2, 3, 2]})\n\n    cugraph_g.from_cudf_edgelist(df)\n\n    g = dgl.from_cugraph(cugraph_g)\n\n    assert g.device.type == \"cuda\"\n    assert g.num_nodes() == cugraph_g.number_of_nodes()\n    # assert reverse edges are present\n    assert g.has_edges_between(0, 1)\n    assert g.has_edges_between(1, 0)\n    assert g.has_edges_between(1, 2)\n    assert g.has_edges_between(2, 1)\n    assert 
g.has_edges_between(2, 3)\n"
  },
  {
    "path": "tests/dist/python/rpc_basic.py",
    "content": "import os\n\nimport backend as F\n\nimport dgl\nfrom numpy.testing import assert_array_equal\n\nINTEGER = 2\nSTR = \"hello world!\"\nHELLO_SERVICE_ID = 901231\nTENSOR = F.zeros((1000, 1000), F.int64, F.cpu())\n\n\ndef tensor_func(tensor):\n    return tensor * 2\n\n\nclass HelloResponse(dgl.distributed.Response):\n    def __init__(self, hello_str, integer, tensor):\n        self.hello_str = hello_str\n        self.integer = integer\n        self.tensor = tensor\n\n    def __getstate__(self):\n        return self.hello_str, self.integer, self.tensor\n\n    def __setstate__(self, state):\n        self.hello_str, self.integer, self.tensor = state\n\n\nclass HelloRequest(dgl.distributed.Request):\n    def __init__(self, hello_str, integer, tensor, func):\n        self.hello_str = hello_str\n        self.integer = integer\n        self.tensor = tensor\n        self.func = func\n\n    def __getstate__(self):\n        return self.hello_str, self.integer, self.tensor, self.func\n\n    def __setstate__(self, state):\n        self.hello_str, self.integer, self.tensor, self.func = state\n\n    def process_request(self, server_state):\n        assert self.hello_str == STR\n        assert self.integer == INTEGER\n        new_tensor = self.func(self.tensor)\n        res = HelloResponse(self.hello_str, self.integer, new_tensor)\n        return res\n\n\ndef start_server(server_id, ip_config, num_servers, num_clients, keep_alive):\n    server_state = dgl.distributed.ServerState(\n        None, local_g=None, partition_book=None, keep_alive=keep_alive\n    )\n    dgl.distributed.register_service(\n        HELLO_SERVICE_ID, HelloRequest, HelloResponse\n    )\n    print(\"Start server {}\".format(server_id))\n    dgl.distributed.start_server(\n        server_id=server_id,\n        ip_config=ip_config,\n        num_servers=num_servers,\n        num_clients=num_clients,\n        server_state=server_state,\n    )\n\n\ndef start_client(ip_config, num_servers, group_id):\n    
dgl.distributed.register_service(\n        HELLO_SERVICE_ID, HelloRequest, HelloResponse\n    )\n    dgl.distributed.connect_to_server(\n        ip_config=ip_config,\n        num_servers=num_servers,\n        group_id=group_id,\n    )\n    req = HelloRequest(STR, INTEGER, TENSOR, tensor_func)\n    server_namebook = dgl.distributed.read_ip_config(ip_config, num_servers)\n    for server_id in server_namebook.keys():\n        # test send and recv\n        dgl.distributed.send_request(server_id, req)\n        res = dgl.distributed.recv_response()\n        assert res.hello_str == STR\n        assert res.integer == INTEGER\n        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))\n        # test remote_call\n        target_and_requests = []\n        for i in range(10):\n            target_and_requests.append((server_id, req))\n        res_list = dgl.distributed.remote_call(target_and_requests)\n        for res in res_list:\n            assert res.hello_str == STR\n            assert res.integer == INTEGER\n            assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))\n        # test send_request_to_machine\n        dgl.distributed.send_request_to_machine(server_id, req)\n        res = dgl.distributed.recv_response()\n        assert res.hello_str == STR\n        assert res.integer == INTEGER\n        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))\n        # test remote_call_to_machine\n        target_and_requests = []\n        for i in range(10):\n            target_and_requests.append((server_id, req))\n        res_list = dgl.distributed.remote_call_to_machine(target_and_requests)\n        for res in res_list:\n            assert res.hello_str == STR\n            assert res.integer == INTEGER\n            assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))\n\n\ndef main():\n    ip_config = os.environ.get(\"DIST_DGL_TEST_IP_CONFIG\")\n    num_servers = int(os.environ.get(\"DIST_DGL_TEST_NUM_SERVERS\"))\n    if 
os.environ.get(\"DIST_DGL_TEST_ROLE\", \"server\") == \"server\":\n        server_id = int(os.environ.get(\"DIST_DGL_TEST_SERVER_ID\"))\n        num_clients = int(os.environ.get(\"DIST_DGL_TEST_NUM_CLIENTS\"))\n        keep_alive = \"DIST_DGL_TEST_KEEP_ALIVE\" in os.environ\n        start_server(server_id, ip_config, num_servers, num_clients, keep_alive)\n    else:\n        group_id = int(os.environ.get(\"DIST_DGL_TEST_GROUP_ID\", \"0\"))\n        start_client(ip_config, num_servers, group_id)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/dist/python/run_dist_objects.py",
    "content": "import json\nimport os\nfrom itertools import product\n\nimport dgl\nimport dgl.backend as F\n\nimport numpy as np\nfrom dgl.distributed import edge_split, load_partition_book, node_split\n\nmode = os.environ.get(\"DIST_DGL_TEST_MODE\", \"\")\ngraph_name = os.environ.get(\"DIST_DGL_TEST_GRAPH_NAME\", \"random_test_graph\")\nnum_part = int(os.environ.get(\"DIST_DGL_TEST_NUM_PART\"))\nnum_servers_per_machine = int(os.environ.get(\"DIST_DGL_TEST_NUM_SERVER\"))\nnum_client_per_machine = int(os.environ.get(\"DIST_DGL_TEST_NUM_CLIENT\"))\nshared_workspace = os.environ.get(\"DIST_DGL_TEST_WORKSPACE\")\ngraph_path = os.environ.get(\"DIST_DGL_TEST_GRAPH_PATH\")\npart_id = int(os.environ.get(\"DIST_DGL_TEST_PART_ID\"))\nip_config = os.environ.get(\"DIST_DGL_TEST_IP_CONFIG\", \"ip_config.txt\")\n\nos.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n\n\ndef batched_assert_zero(tensor, size):\n    BATCH_SIZE = 2**16\n    curr_pos = 0\n    while curr_pos < size:\n        end = min(curr_pos + BATCH_SIZE, size)\n        assert F.sum(tensor[F.arange(curr_pos, end)], 0) == 0\n        curr_pos = end\n\n\ndef zeros_init(shape, dtype):\n    return F.zeros(shape, dtype=dtype, ctx=F.cpu())\n\n\ndef rand_init(shape, dtype):\n    return F.tensor((np.random.randint(0, 100, size=shape) > 30), dtype=dtype)\n\n\ndef run_server(\n    graph_name,\n    server_id,\n    server_count,\n    num_clients,\n    shared_mem,\n):\n    # server_count = num_servers_per_machine\n    g = dgl.distributed.DistGraphServer(\n        server_id,\n        ip_config,\n        server_count,\n        num_clients,\n        graph_path + \"/{}.json\".format(graph_name),\n        disable_shared_mem=not shared_mem,\n        graph_format=[\"csc\", \"coo\"],\n    )\n    print(\"start server\", server_id)\n    g.start()\n\n\n##########################################\n############### DistGraph ###############\n##########################################\n\n\ndef node_split_test(g, force_even, ntype=\"_N\"):\n    gpb 
= g.get_partition_book()\n\n    selected_nodes_dist_tensor = dgl.distributed.DistTensor(\n        [g.num_nodes(ntype)], F.uint8, init_func=rand_init\n    )\n\n    nodes = node_split(\n        selected_nodes_dist_tensor, gpb, ntype=ntype, force_even=force_even\n    )\n    g.barrier()\n\n    selected_nodes_dist_tensor[nodes] = F.astype(\n        F.zeros_like(nodes), selected_nodes_dist_tensor.dtype\n    )\n    g.barrier()\n\n    if g.rank() == 0:\n        batched_assert_zero(selected_nodes_dist_tensor, g.num_nodes(ntype))\n\n    g.barrier()\n\n\ndef edge_split_test(g, force_even, etype=\"_E\"):\n    gpb = g.get_partition_book()\n\n    selected_edges_dist_tensor = dgl.distributed.DistTensor(\n        [g.num_edges(etype)], F.uint8, init_func=rand_init\n    )\n\n    edges = edge_split(\n        selected_edges_dist_tensor, gpb, etype=etype, force_even=force_even\n    )\n    g.barrier()\n\n    selected_edges_dist_tensor[edges] = F.astype(\n        F.zeros_like(edges), selected_edges_dist_tensor.dtype\n    )\n    g.barrier()\n\n    if g.rank() == 0:\n        batched_assert_zero(selected_edges_dist_tensor, g.num_edges(etype))\n\n    g.barrier()\n\n\ndef test_dist_graph(g):\n    gpb_path = graph_path + \"/{}.json\".format(graph_name)\n    with open(gpb_path) as conf_f:\n        part_metadata = json.load(conf_f)\n    assert \"num_nodes\" in part_metadata\n    assert \"num_edges\" in part_metadata\n    num_nodes = part_metadata[\"num_nodes\"]\n    num_edges = part_metadata[\"num_edges\"]\n\n    assert g.num_nodes() == num_nodes\n    assert g.num_edges() == num_edges\n\n    num_nodes = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}\n    num_edges = {etype: g.num_edges(etype) for etype in g.etypes}\n\n    for key, n_nodes in num_nodes.items():\n        assert g.num_nodes(key) == n_nodes\n        node_split_test(g, force_even=False, ntype=key)\n        node_split_test(g, force_even=True, ntype=key)\n\n    for key, n_edges in num_edges.items():\n        assert g.num_edges(key) 
== n_edges\n        edge_split_test(g, force_even=False, etype=key)\n        edge_split_test(g, force_even=True, etype=key)\n\n\n##########################################\n########### DistGraphServices ###########\n##########################################\n\n\ndef find_edges_test(g, orig_nid_map):\n    etypes = g.canonical_etypes\n\n    etype_eids_uv_map = dict()\n    for u_type, etype, v_type in etypes:\n        orig_u = g.edges[etype].data[\"edge_u\"]\n        orig_v = g.edges[etype].data[\"edge_v\"]\n        eids = F.tensor(np.random.randint(g.num_edges(etype), size=100))\n        u, v = g.find_edges(eids, etype=etype)\n        assert F.allclose(orig_nid_map[u_type][u], orig_u[eids])\n        assert F.allclose(orig_nid_map[v_type][v], orig_v[eids])\n        etype_eids_uv_map[etype] = (eids, F.cat([u, v], dim=0))\n    return etype_eids_uv_map\n\n\ndef edge_subgraph_test(g, etype_eids_uv_map):\n    etypes = g.canonical_etypes\n    all_eids = dict()\n    for t in etypes:\n        all_eids[t] = etype_eids_uv_map[t[1]][0]\n\n    sg = g.edge_subgraph(all_eids)\n    for t in etypes:\n        assert sg.num_edges(t[1]) == len(all_eids[t])\n        assert F.allclose(sg.edges[t].data[dgl.EID], all_eids[t])\n\n    for u_type, etype, v_type in etypes:\n        uv = etype_eids_uv_map[etype][1]\n        sg_u_nids = sg.nodes[u_type].data[dgl.NID]\n        sg_v_nids = sg.nodes[v_type].data[dgl.NID]\n        sg_uv = F.cat([sg_u_nids, sg_v_nids], dim=0)\n        for node_id in uv:\n            assert node_id in sg_uv\n\n\ndef sample_neighbors_with_args(g, size, fanout):\n    num_nodes = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}\n    etypes = g.canonical_etypes\n\n    sampled_graph = g.sample_neighbors(\n        {\n            ntype: np.random.randint(0, n, size=size)\n            for ntype, n in num_nodes.items()\n        },\n        fanout,\n    )\n\n    for ntype, n in num_nodes.items():\n        assert sampled_graph.num_nodes(ntype) == n\n    for t in etypes:\n       
 src, dst = sampled_graph.edges(etype=t)\n        eids = sampled_graph.edges[t].data[dgl.EID]\n        dist_u, dist_v = g.find_edges(eids, etype=t[1])\n        assert F.allclose(dist_u, src)\n        assert F.allclose(dist_v, dst)\n\n\ndef sample_neighbors_test(g):\n    sample_neighbors_with_args(g, size=1024, fanout=3)\n    sample_neighbors_with_args(g, size=1, fanout=10)\n    sample_neighbors_with_args(g, size=1024, fanout=2)\n    sample_neighbors_with_args(g, size=10, fanout=-1)\n    sample_neighbors_with_args(g, size=2**10, fanout=1)\n    sample_neighbors_with_args(g, size=2**12, fanout=1)\n\n\ndef test_dist_graph_services(g):\n    # in_degrees and out_degrees does not support heterograph\n    if len(g.etypes) == 1:\n        nids = F.arange(0, 128)\n\n        # Test in_degrees\n        orig_in_degrees = g.ndata[\"in_degrees\"]\n        local_in_degrees = g.in_degrees(nids)\n        F.allclose(local_in_degrees, orig_in_degrees[nids])\n\n        # Test out_degrees\n        orig_out_degrees = g.ndata[\"out_degrees\"]\n        local_out_degrees = g.out_degrees(nids)\n        F.allclose(local_out_degrees, orig_out_degrees[nids])\n\n    num_nodes = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}\n\n    orig_nid_map = dict()\n    dtype = g.edges[g.etypes[0]].data[\"edge_u\"].dtype\n    for ntype, _ in num_nodes.items():\n        orig_nid = F.tensor(\n            np.load(graph_path + f\"/orig_nid_array_{ntype}.npy\"), dtype\n        )\n        orig_nid_map[ntype] = orig_nid\n\n    etype_eids_uv_map = find_edges_test(g, orig_nid_map)\n    edge_subgraph_test(g, etype_eids_uv_map)\n    sample_neighbors_test(g)\n\n\n##########################################\n############### DistTensor ###############\n##########################################\n\n\ndef dist_tensor_test_sanity(data_shape, name=None):\n    local_rank = dgl.distributed.get_rank() % num_client_per_machine\n    dist_ten = dgl.distributed.DistTensor(\n        data_shape, F.int32, init_func=zeros_init, 
name=name\n    )\n    # arbitrary value\n    stride = 3\n    pos = (part_id // 2) * num_client_per_machine + local_rank\n    if part_id % 2 == 0:\n        dist_ten[pos * stride : (pos + 1) * stride] = F.ones(\n            (stride, 2), dtype=F.int32, ctx=F.cpu()\n        ) * (pos + 1)\n\n    dgl.distributed.client_barrier()\n    assert F.allclose(\n        dist_ten[pos * stride : (pos + 1) * stride],\n        F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos + 1),\n    )\n\n\ndef dist_tensor_test_destroy_recreate(data_shape, name):\n    dist_ten = dgl.distributed.DistTensor(\n        data_shape, F.float32, name, init_func=zeros_init\n    )\n    del dist_ten\n\n    dgl.distributed.client_barrier()\n\n    new_shape = (data_shape[0], 4)\n    dist_ten = dgl.distributed.DistTensor(\n        new_shape, F.float32, name, init_func=zeros_init\n    )\n\n\ndef dist_tensor_test_persistent(data_shape):\n    dist_ten_name = \"persistent_dist_tensor\"\n    dist_ten = dgl.distributed.DistTensor(\n        data_shape,\n        F.float32,\n        dist_ten_name,\n        init_func=zeros_init,\n        persistent=True,\n    )\n    del dist_ten\n    try:\n        dist_ten = dgl.distributed.DistTensor(\n            data_shape, F.float32, dist_ten_name\n        )\n        raise Exception(\"\")\n    except BaseException:\n        pass\n\n\ndef test_dist_tensor(g):\n    first_type = g.ntypes[0]\n    data_shape = (g.num_nodes(first_type), 2)\n    dist_tensor_test_sanity(data_shape)\n    dist_tensor_test_sanity(data_shape, name=\"DistTensorSanity\")\n    dist_tensor_test_destroy_recreate(data_shape, name=\"DistTensorRecreate\")\n    dist_tensor_test_persistent(data_shape)\n\n\n##########################################\n############# DistEmbedding ##############\n##########################################\n\n\ndef dist_embedding_check_sanity(num_nodes, optimizer, name=None):\n    local_rank = dgl.distributed.get_rank() % num_client_per_machine\n\n    emb = 
dgl.distributed.DistEmbedding(\n        num_nodes, 1, name=name, init_func=zeros_init\n    )\n    lr = 0.001\n    optim = optimizer(params=[emb], lr=lr)\n\n    stride = 3\n\n    pos = (part_id // 2) * num_client_per_machine + local_rank\n    idx = F.arange(pos * stride, (pos + 1) * stride)\n\n    if part_id % 2 == 0:\n        with F.record_grad():\n            value = emb(idx)\n            optim.zero_grad()\n            loss = F.sum(value + 1, 0)\n        loss.backward()\n        optim.step()\n\n    dgl.distributed.client_barrier()\n    value = emb(idx)\n    F.allclose(value, F.ones((len(idx), 1), dtype=F.int32, ctx=F.cpu()) * -lr)\n\n    not_update_idx = F.arange(\n        ((num_part + 1) / 2) * num_client_per_machine * stride, num_nodes\n    )\n    value = emb(not_update_idx)\n    assert np.all(F.asnumpy(value) == np.zeros((len(not_update_idx), 1)))\n\n\ndef dist_embedding_check_existing(num_nodes):\n    dist_emb_name = \"UniqueEmb\"\n    emb = dgl.distributed.DistEmbedding(\n        num_nodes, 1, name=dist_emb_name, init_func=zeros_init\n    )\n    try:\n        emb1 = dgl.distributed.DistEmbedding(\n            num_nodes, 2, name=dist_emb_name, init_func=zeros_init\n        )\n        raise Exception(\"\")\n    except BaseException:\n        pass\n\n\ndef test_dist_embedding(g):\n    num_nodes = g.num_nodes(g.ntypes[0])\n    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdagrad)\n    dist_embedding_check_sanity(\n        num_nodes, dgl.distributed.optim.SparseAdagrad, name=\"SomeEmbedding\"\n    )\n    dist_embedding_check_sanity(\n        num_nodes, dgl.distributed.optim.SparseAdam, name=\"SomeEmbedding\"\n    )\n\n    dist_embedding_check_existing(num_nodes)\n\n\n##########################################\n############# DistOptimizer ##############\n##########################################\n\n\ndef dist_optimizer_check_store(g):\n    num_nodes = g.num_nodes(g.ntypes[0])\n    rank = g.rank()\n    try:\n        emb = 
dgl.distributed.DistEmbedding(\n            num_nodes, 1, name=\"optimizer_test\", init_func=zeros_init\n        )\n        emb2 = dgl.distributed.DistEmbedding(\n            num_nodes, 5, name=\"optimizer_test2\", init_func=zeros_init\n        )\n        emb_optimizer = dgl.distributed.optim.SparseAdam([emb, emb2], lr=0.1)\n        if rank == 0:\n            name_to_state = {}\n            for _, emb_states in emb_optimizer._state.items():\n                for state in emb_states:\n                    name_to_state[state.name] = F.uniform(\n                        state.shape, F.float32, F.cpu(), 0, 1\n                    )\n                    state[\n                        F.arange(0, num_nodes, F.int64, F.cpu())\n                    ] = name_to_state[state.name]\n        emb_optimizer.save(\"emb.pt\")\n        new_emb_optimizer = dgl.distributed.optim.SparseAdam(\n            [emb, emb2], lr=000.1, eps=2e-08, betas=(0.1, 0.222)\n        )\n        new_emb_optimizer.load(\"emb.pt\")\n        if rank == 0:\n            for _, emb_states in new_emb_optimizer._state.items():\n                for new_state in emb_states:\n                    state = name_to_state[new_state.name]\n                    new_state = new_state[\n                        F.arange(0, num_nodes, F.int64, F.cpu())\n                    ]\n                    assert F.allclose(state, new_state, 0.0, 0.0)\n            assert new_emb_optimizer._lr == emb_optimizer._lr\n            assert new_emb_optimizer._eps == emb_optimizer._eps\n            assert new_emb_optimizer._beta1 == emb_optimizer._beta1\n            assert new_emb_optimizer._beta2 == emb_optimizer._beta2\n        g.barrier()\n    finally:\n        file = f\"emb.pt_{rank}\"\n        if os.path.exists(file):\n            os.remove(file)\n\n\ndef test_dist_optimizer(g):\n    dist_optimizer_check_store(g)\n\n\n##########################################\n############# DistDataLoader 
#############\n##########################################\n\n\nclass NeighborSampler(object):\n    def __init__(self, g, fanouts, sample_neighbors):\n        self.g = g\n        self.fanouts = fanouts\n        self.sample_neighbors = sample_neighbors\n\n    def sample_blocks(self, seeds):\n        import torch as th\n\n        seeds = th.LongTensor(np.asarray(seeds))\n        blocks = []\n        for fanout in self.fanouts:\n            # For each seed node, sample ``fanout`` neighbors.\n            frontier = self.sample_neighbors(\n                self.g, seeds, fanout, replace=True\n            )\n            # Then we compact the frontier into a bipartite graph for\n            # message passing.\n            block = dgl.to_block(frontier, seeds)\n            # Obtain the seed nodes for next layer.\n            seeds = block.srcdata[dgl.NID]\n            block.edata[\"original_eids\"] = frontier.edata[dgl.EID]\n\n            blocks.insert(0, block)\n        return blocks\n\n\ndef distdataloader_test(g, batch_size, drop_last, shuffle):\n    # We sample only a subset to minimize the test runtime\n    num_nodes_to_sample = int(g.num_nodes() * 0.05)\n    # To make sure that drop_last is tested\n    if num_nodes_to_sample % batch_size == 0:\n        num_nodes_to_sample -= 1\n\n    orig_nid_map = dict()\n    dtype = g.edges[g.etypes[0]].data[\"edge_u\"].dtype\n    for ntype in g.ntypes:\n        orig_nid = F.tensor(\n            np.load(graph_path + f\"/orig_nid_array_{ntype}.npy\"), dtype\n        )\n        orig_nid_map[ntype] = orig_nid\n\n    orig_uv_map = dict()\n    for etype in g.etypes:\n        orig_uv_map[etype] = (\n            g.edges[etype].data[\"edge_u\"],\n            g.edges[etype].data[\"edge_v\"],\n        )\n\n    if len(g.ntypes) == 1:\n        train_nid = F.arange(0, num_nodes_to_sample)\n    else:\n        train_nid = {g.ntypes[0]: F.arange(0, num_nodes_to_sample)}\n\n    sampler = NeighborSampler(g, [5, 10], 
dgl.distributed.sample_neighbors)\n\n    dataloader = dgl.dataloading.DistDataLoader(\n        dataset=train_nid.numpy(),\n        batch_size=batch_size,\n        collate_fn=sampler.sample_blocks,\n        shuffle=shuffle,\n        drop_last=drop_last,\n    )\n\n    for _ in range(2):\n        max_nid = []\n        for idx, blocks in zip(\n            range(0, num_nodes_to_sample, batch_size), dataloader\n        ):\n            block = blocks[-1]\n            for src_type, etype, dst_type in block.canonical_etypes:\n                orig_u, orig_v = orig_uv_map[etype]\n                o_src, o_dst = block.edges(etype=etype)\n                src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]\n                dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]\n                max_nid.append(np.max(F.asnumpy(dst_nodes_id)))\n\n                src_nodes_id = orig_nid_map[src_type][src_nodes_id]\n                dst_nodes_id = orig_nid_map[dst_type][dst_nodes_id]\n                eids = block.edata[\"original_eids\"]\n                F.allclose(src_nodes_id, orig_u[eids])\n                F.allclose(dst_nodes_id, orig_v[eids])\n        if not shuffle and len(max_nid) > 0:\n            if drop_last:\n                assert (\n                    np.max(max_nid)\n                    == num_nodes_to_sample\n                    - 1\n                    - num_nodes_to_sample % batch_size\n                )\n            else:\n                assert np.max(max_nid) == num_nodes_to_sample - 1\n    del dataloader\n\n\ndef distnodedataloader_test(\n    g, batch_size, drop_last, shuffle, num_workers, orig_nid_map, orig_uv_map\n):\n    # We sample only a subset to minimize the test runtime\n    num_nodes_to_sample = int(g.num_nodes(g.ntypes[-1]) * 0.05)\n    # To make sure that drop_last is tested\n    if num_nodes_to_sample % batch_size == 0:\n        num_nodes_to_sample -= 1\n\n    if len(g.ntypes) == 1:\n        train_nid = F.arange(0, num_nodes_to_sample)\n 
   else:\n        train_nid = {g.ntypes[-1]: F.arange(0, num_nodes_to_sample)}\n\n    if len(g.etypes) > 1:\n        sampler = dgl.dataloading.MultiLayerNeighborSampler(\n            [\n                {etype: 5 for etype in g.etypes},\n                10,\n            ]\n        )\n    else:\n        sampler = dgl.dataloading.MultiLayerNeighborSampler(\n            [\n                5,\n                10,\n            ]\n        )\n\n    dataloader = dgl.dataloading.DistNodeDataLoader(\n        g,\n        train_nid,\n        sampler,\n        batch_size=batch_size,\n        shuffle=shuffle,\n        drop_last=drop_last,\n        num_workers=num_workers,\n    )\n\n    for _ in range(2):\n        for _, (_, _, blocks) in zip(\n            range(0, num_nodes_to_sample, batch_size), dataloader\n        ):\n            block = blocks[-1]\n            for src_type, etype, dst_type in block.canonical_etypes:\n                orig_u, orig_v = orig_uv_map[etype]\n                o_src, o_dst = block.edges(etype=etype)\n                src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]\n                dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]\n                src_nodes_id = orig_nid_map[src_type][src_nodes_id]\n                dst_nodes_id = orig_nid_map[dst_type][dst_nodes_id]\n                eids = block.edges[etype].data[dgl.EID]\n                F.allclose(src_nodes_id, orig_u[eids])\n                F.allclose(dst_nodes_id, orig_v[eids])\n    del dataloader\n\n\ndef distedgedataloader_test(\n    g,\n    batch_size,\n    drop_last,\n    shuffle,\n    num_workers,\n    orig_nid_map,\n    orig_uv_map,\n    num_negs,\n):\n    # We sample only a subset to minimize the test runtime\n    num_edges_to_sample = int(g.num_edges(g.etypes[-1]) * 0.05)\n    # To make sure that drop_last is tested\n    if num_edges_to_sample % batch_size == 0:\n        num_edges_to_sample -= 1\n\n    if len(g.etypes) == 1:\n        train_eid = F.arange(0, 
num_edges_to_sample)\n    else:\n        train_eid = {g.etypes[-1]: F.arange(0, num_edges_to_sample)}\n\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])\n\n    dataloader = dgl.dataloading.DistEdgeDataLoader(\n        g,\n        train_eid,\n        sampler,\n        batch_size=batch_size,\n        negative_sampler=dgl.dataloading.negative_sampler.Uniform(num_negs)\n        if num_negs > 0\n        else None,\n        shuffle=shuffle,\n        drop_last=drop_last,\n        num_workers=num_workers,\n    )\n    for _ in range(2):\n        for _, sampled_data in zip(\n            range(0, num_edges_to_sample, batch_size), dataloader\n        ):\n            blocks = sampled_data[3 if num_negs > 0 else 2]\n            block = blocks[-1]\n            for src_type, etype, dst_type in block.canonical_etypes:\n                orig_u, orig_v = orig_uv_map[etype]\n                o_src, o_dst = block.edges(etype=etype)\n                src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]\n                dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]\n                src_nodes_id = orig_nid_map[src_type][src_nodes_id]\n                dst_nodes_id = orig_nid_map[dst_type][dst_nodes_id]\n                eids = block.edges[etype].data[dgl.EID]\n                F.allclose(src_nodes_id, orig_u[eids])\n                F.allclose(dst_nodes_id, orig_v[eids])\n                if num_negs == 0:\n                    pos_pair_graph = sampled_data[1]\n                    assert np.all(\n                        F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])\n                        == F.asnumpy(\n                            pos_pair_graph.nodes[dst_type].data[dgl.NID]\n                        )\n                    )\n                else:\n                    pos_graph, neg_graph = sampled_data[1:3]\n                    assert np.all(\n                        F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])\n                        == 
F.asnumpy(pos_graph.nodes[dst_type].data[dgl.NID])\n                    )\n                    assert np.all(\n                        F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])\n                        == F.asnumpy(neg_graph.nodes[dst_type].data[dgl.NID])\n                    )\n                    assert (\n                        pos_graph.num_edges() * num_negs\n                        == neg_graph.num_edges()\n                    )\n    del dataloader\n\n\ndef multi_distdataloader_test(g, dataloader_class):\n    total_num_items = (\n        g.num_nodes(g.ntypes[-1])\n        if \"Node\" in dataloader_class.__name__\n        else g.num_edges(g.etypes[-1])\n    )\n\n    num_dataloaders = 4\n    batch_size = 32\n    sampler = dgl.dataloading.NeighborSampler([-1])\n    dataloaders = []\n    dl_iters = []\n\n    # We sample only a subset to minimize the test runtime\n    num_items_to_sample = int(total_num_items * 0.05)\n    # To make sure that drop_last is tested\n    if num_items_to_sample % batch_size == 0:\n        num_items_to_sample -= 1\n\n    if len(g.ntypes) == 1:\n        train_ids = F.arange(0, num_items_to_sample)\n    else:\n        train_ids = {\n            g.ntypes[-1]\n            if \"Node\" in dataloader_class.__name__\n            else g.etypes[-1]: F.arange(0, num_items_to_sample)\n        }\n\n    for _ in range(num_dataloaders):\n        dataloader = dataloader_class(\n            g, train_ids, sampler, batch_size=batch_size\n        )\n        dataloaders.append(dataloader)\n        dl_iters.append(iter(dataloader))\n\n    # iterate on multiple dataloaders randomly\n    while len(dl_iters) > 0:\n        current_dl = np.random.choice(len(dl_iters), 1)[0]\n        try:\n            _ = next(dl_iters[current_dl])\n        except StopIteration:\n            dl_iters.pop(current_dl)\n            del dataloaders[current_dl]\n\n\ndef test_dist_dataloader(g):\n    orig_nid_map = dict()\n    dtype = 
g.edges[g.etypes[0]].data[\"edge_u\"].dtype\n    for ntype in g.ntypes:\n        orig_nid = F.tensor(\n            np.load(graph_path + f\"/orig_nid_array_{ntype}.npy\"), dtype\n        )\n        orig_nid_map[ntype] = orig_nid\n\n    orig_uv_map = dict()\n    for etype in g.etypes:\n        orig_uv_map[etype] = (\n            g.edges[etype].data[\"edge_u\"],\n            g.edges[etype].data[\"edge_v\"],\n        )\n\n    batch_size_l = [64]\n    drop_last_l = [False, True]\n    num_workers_l = [0, 4]\n    shuffle_l = [False, True]\n\n    for batch_size, drop_last, shuffle, num_workers in product(\n        batch_size_l, drop_last_l, shuffle_l, num_workers_l\n    ):\n        if len(g.ntypes) == 1 and num_workers == 0:\n            distdataloader_test(g, batch_size, drop_last, shuffle)\n        distnodedataloader_test(\n            g,\n            batch_size,\n            drop_last,\n            shuffle,\n            num_workers,\n            orig_nid_map,\n            orig_uv_map,\n        )\n        # No negssampling\n        distedgedataloader_test(\n            g,\n            batch_size,\n            drop_last,\n            shuffle,\n            num_workers,\n            orig_nid_map,\n            orig_uv_map,\n            num_negs=0,\n        )\n        # negsampling 15\n        distedgedataloader_test(\n            g,\n            batch_size,\n            drop_last,\n            shuffle,\n            num_workers,\n            orig_nid_map,\n            orig_uv_map,\n            num_negs=15,\n        )\n\n    multi_distdataloader_test(g, dgl.dataloading.DistNodeDataLoader)\n    multi_distdataloader_test(g, dgl.dataloading.DistEdgeDataLoader)\n\n\nif mode == \"server\":\n    shared_mem = bool(int(os.environ.get(\"DIST_DGL_TEST_SHARED_MEM\")))\n    server_id = int(os.environ.get(\"DIST_DGL_TEST_SERVER_ID\"))\n    run_server(\n        graph_name,\n        server_id,\n        server_count=num_servers_per_machine,\n        num_clients=num_part * 
num_client_per_machine,\n        shared_mem=shared_mem,\n    )\nelif mode == \"client\":\n    os.environ[\"DGL_NUM_SERVER\"] = str(num_servers_per_machine)\n    dgl.distributed.initialize(ip_config)\n\n    gpb, graph_name, _, _ = load_partition_book(\n        graph_path + \"/{}.json\".format(graph_name), part_id\n    )\n    g = dgl.distributed.DistGraph(graph_name, gpb=gpb)\n\n    target_func_map = {\n        \"DistGraph\": test_dist_graph,\n        \"DistGraphServices\": test_dist_graph_services,\n        \"DistTensor\": test_dist_tensor,\n        \"DistEmbedding\": test_dist_embedding,\n        \"DistOptimizer\": test_dist_optimizer,\n        \"DistDataLoader\": test_dist_dataloader,\n    }\n\n    targets = os.environ.get(\"DIST_DGL_TEST_OBJECT_TYPE\", \"\")\n    targets = targets.replace(\" \", \"\").split(\",\") if targets else []\n    blacklist = os.environ.get(\"DIST_DGL_TEST_OBJECT_TYPE_BLACKLIST\", \"\")\n    blacklist = blacklist.replace(\" \", \"\").split(\",\") if blacklist else []\n\n    for to_bl in blacklist:\n        target_func_map.pop(to_bl, None)\n\n    if not targets:\n        for test_func in target_func_map.values():\n            test_func(g)\n    else:\n        for target in targets:\n            if target in target_func_map:\n                target_func_map[target](g)\n            else:\n                print(f\"Tests not implemented for target '{target}'\")\n\nelse:\n    exit(1)\n"
  },
  {
    "path": "tests/dist/test_dist_objects.py",
    "content": "import multiprocessing as mp\nimport os\nimport shutil\nimport subprocess\nimport unittest\n\nimport dgl\nimport dgl.backend as F\n\nimport numpy as np\nimport pytest\nimport utils\nfrom dgl.distributed import partition_graph\n\ngraph_name = os.environ.get(\"DIST_DGL_TEST_GRAPH_NAME\", \"random_test_graph\")\ntarget = os.environ.get(\"DIST_DGL_TEST_OBJECT_TYPE\", \"\")\nblacklist = os.environ.get(\"DIST_DGL_TEST_OBJECT_TYPE_BLACKLIST\", \"\")\nshared_workspace = os.environ.get(\n    \"DIST_DGL_TEST_WORKSPACE\", \"/shared_workspace/dgl_dist_tensor_test/\"\n)\n\n\ndef create_graph(num_part, dist_graph_path, hetero):\n    if not hetero:\n        g = dgl.rand_graph(10000, 42000)\n        g.ndata[\"feat\"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)\n        g.edata[\"feat\"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)\n        g.ndata[\"in_degrees\"] = g.in_degrees()\n        g.ndata[\"out_degrees\"] = g.out_degrees()\n\n        etype = g.etypes[0]\n        ntype = g.ntypes[0]\n        edge_u, edge_v = g.find_edges(F.arange(0, g.num_edges(etype)))\n        g.edges[etype].data[\"edge_u\"] = edge_u\n        g.edges[etype].data[\"edge_v\"] = edge_v\n\n        orig_nid, orig_eid = partition_graph(\n            g, graph_name, num_part, dist_graph_path, return_mapping=True\n        )\n\n        orig_nid_f = os.path.join(\n            dist_graph_path, f\"orig_nid_array_{ntype}.npy\"\n        )\n        np.save(orig_nid_f, orig_nid.numpy())\n        orig_eid_f = os.path.join(\n            dist_graph_path, f\"orig_eid_array_{etype}.npy\"\n        )\n        np.save(orig_eid_f, orig_eid.numpy())\n\n    else:\n        from scipy import sparse as spsp\n\n        num_nodes = {\"n1\": 10000, \"n2\": 10010, \"n3\": 10020}\n        etypes = [(\"n1\", \"r1\", \"n2\"), (\"n1\", \"r2\", \"n3\"), (\"n2\", \"r3\", \"n3\")]\n        edges = {}\n        for etype in etypes:\n            src_ntype, _, dst_ntype = etype\n            arr = spsp.random(\n                
num_nodes[src_ntype],\n                num_nodes[dst_ntype],\n                density=0.001,\n                format=\"coo\",\n                random_state=100,\n            )\n            edges[etype] = (arr.row, arr.col)\n        g = dgl.heterograph(edges, num_nodes)\n\n        g.nodes[\"n1\"].data[\"feat\"] = F.unsqueeze(\n            F.arange(0, g.num_nodes(\"n1\")), 1\n        )\n        g.edges[\"r1\"].data[\"feat\"] = F.unsqueeze(\n            F.arange(0, g.num_edges(\"r1\")), 1\n        )\n\n        for _, etype, _ in etypes:\n            edge_u, edge_v = g.find_edges(\n                F.arange(0, g.num_edges(etype)), etype=etype\n            )\n            g.edges[etype].data[\"edge_u\"] = edge_u\n            g.edges[etype].data[\"edge_v\"] = edge_v\n\n        orig_nid, orig_eid = partition_graph(\n            g, graph_name, num_part, dist_graph_path, return_mapping=True\n        )\n\n        for n_type, tensor in orig_nid.items():\n            orig_nid_f = os.path.join(\n                dist_graph_path, f\"orig_nid_array_{n_type}.npy\"\n            )\n            np.save(orig_nid_f, tensor.numpy())\n        for e_type, tensor in orig_eid.items():\n            orig_eid_f = os.path.join(\n                dist_graph_path, f\"orig_eid_array_{e_type}.npy\"\n            )\n            np.save(orig_eid_f, tensor.numpy())\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@pytest.mark.parametrize(\"num_servers\", [1, 4])\n@pytest.mark.parametrize(\"num_clients\", [1, 4])\n@pytest.mark.parametrize(\"hetero\", [False, True])\n@pytest.mark.parametrize(\"shared_mem\", [False, True])\ndef test_dist_objects(num_servers, num_clients, hetero, shared_mem):\n    if not shared_mem and num_servers > 1:\n        pytest.skip(\n            f\"Backup servers are not supported when shared memory is disabled\"\n        )\n    ip_config = os.environ.get(\"DIST_DGL_TEST_IP_CONFIG\", \"ip_config.txt\")\n\n    ips = utils.get_ips(ip_config)\n    num_part 
= len(ips)\n\n    test_bin = os.path.join(\n        os.environ.get(\"DIST_DGL_TEST_PY_BIN_DIR\", \".\"), \"run_dist_objects.py\"\n    )\n\n    dist_graph_path = os.path.join(\n        shared_workspace, \"hetero_dist_graph\" if hetero else \"dist_graph\"\n    )\n    if not os.path.isdir(dist_graph_path):\n        create_graph(num_part, dist_graph_path, hetero)\n\n    base_envs = (\n        f\"DIST_DGL_TEST_WORKSPACE={shared_workspace} \"\n        f\"DIST_DGL_TEST_NUM_PART={num_part} \"\n        f\"DIST_DGL_TEST_NUM_SERVER={num_servers} \"\n        f\"DIST_DGL_TEST_NUM_CLIENT={num_clients} \"\n        f\"DIST_DGL_TEST_GRAPH_PATH={dist_graph_path} \"\n        f\"DIST_DGL_TEST_IP_CONFIG={ip_config} \"\n    )\n\n    procs = []\n    # Start server\n    server_id = 0\n    for part_id, ip in enumerate(ips):\n        for _ in range(num_servers):\n            cmd_envs = (\n                base_envs + f\"DIST_DGL_TEST_SERVER_ID={server_id} \"\n                f\"DIST_DGL_TEST_PART_ID={part_id} \"\n                f\"DIST_DGL_TEST_SHARED_MEM={str(int(shared_mem))} \"\n                f\"DIST_DGL_TEST_MODE=server \"\n            )\n            procs.append(\n                utils.execute_remote(f\"{cmd_envs} python3 {test_bin}\", ip)\n            )\n            server_id += 1\n    # Start client processes\n    for part_id, ip in enumerate(ips):\n        for _ in range(num_clients):\n            cmd_envs = (\n                base_envs + f\"DIST_DGL_TEST_PART_ID={part_id} \"\n                f\"DIST_DGL_TEST_OBJECT_TYPE={target} \"\n                f\"DIST_DGL_TEST_OBJECT_TYPE_BLACKLIST={blacklist} \"\n                f\"DIST_DGL_TEST_MODE=client \"\n            )\n            procs.append(\n                utils.execute_remote(f\"{cmd_envs} python3 {test_bin}\", ip)\n            )\n\n    for p in procs:\n        p.join()\n        assert p.exitcode == 0\n\n\ndef teardown():\n    for name in [\"dist_graph\", \"hetero_dist_graph\"]:\n        path = os.path.join(shared_workspace, 
name)\n        if os.path.exists(path):\n            print(f\"Removing {path}...\")\n            shutil.rmtree(path)\n"
  },
  {
    "path": "tests/dist/test_rpc.py",
    "content": "import multiprocessing as mp\nimport os\nimport unittest\n\nimport pytest\nimport utils\n\ndgl_envs = f\"PYTHONUNBUFFERED=1 DMLC_LOG_DEBUG=1 DGLBACKEND={os.environ.get('DGLBACKEND')} DGL_LIBRARY_PATH={os.environ.get('DGL_LIBRARY_PATH')} PYTHONPATH={os.environ.get('PYTHONPATH')} \"\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\ndef test_rpc():\n    ip_config = os.environ.get(\"DIST_DGL_TEST_IP_CONFIG\", \"ip_config.txt\")\n    num_clients = 1\n    num_servers = 1\n    ips = utils.get_ips(ip_config)\n    num_machines = len(ips)\n    test_bin = os.path.join(\n        os.environ.get(\"DIST_DGL_TEST_PY_BIN_DIR\", \".\"), \"rpc_basic.py\"\n    )\n    base_envs = (\n        dgl_envs\n        + f\" DGL_DIST_MODE=distributed DIST_DGL_TEST_IP_CONFIG={ip_config} DIST_DGL_TEST_NUM_SERVERS={num_servers} \"\n    )\n    procs = []\n    # start server processes\n    server_id = 0\n    for ip in ips:\n        for _ in range(num_servers):\n            server_envs = (\n                base_envs\n                + f\" DIST_DGL_TEST_ROLE=server DIST_DGL_TEST_SERVER_ID={server_id} DIST_DGL_TEST_NUM_CLIENTS={num_clients * num_machines} \"\n            )\n            procs.append(\n                utils.execute_remote(server_envs + \" python3 \" + test_bin, ip)\n            )\n            server_id += 1\n    # start client processes\n    client_envs = (\n        base_envs + \" DIST_DGL_TEST_ROLE=client DIST_DGL_TEST_GROUP_ID=0 \"\n    )\n    for ip in ips:\n        for _ in range(num_clients):\n            procs.append(\n                utils.execute_remote(client_envs + \" python3 \" + test_bin, ip)\n            )\n    for p in procs:\n        p.join()\n        assert p.exitcode == 0\n"
  },
  {
    "path": "tests/dist/utils.py",
    "content": "import multiprocessing as mp\nimport os\nimport subprocess\nfrom typing import Optional\n\n\ndef run(ssh_cmd):\n    subprocess.check_call(ssh_cmd, shell=True)\n\n\ndef execute_remote(\n    cmd: str, ip: str, port: Optional[int] = 22, username: Optional[str] = \"\"\n) -> mp.Process:\n    \"\"\"Execute command line on remote machine via ssh.\n\n    Args:\n        cmd: User-defined command (udf) to execute on the remote host.\n        ip: The ip-address of the host to run the command on.\n        port: Port number that the host is listening on.\n        username: Optional. If given, this will specify a username to use when issuing commands over SSH.\n            Useful when your infra requires you to explicitly specify a username to avoid permission issues.\n\n    Returns:\n        Process: The Process whose run() is to run the `cmd` on the remote host. Returns when the cmd completes\n            on the remote host.\n    \"\"\"\n    ip_prefix = \"\"\n    if username:\n        ip_prefix += \"{username}@\".format(username=username)\n\n    custom_port = os.getenv(\"DIST_DGL_TEST_SSH_PORT\", \"\")\n    if custom_port:\n        port = custom_port\n\n    custom_ssh_key = os.getenv(\"DIST_DGL_TEST_SSH_KEY\", \"\")\n    if custom_ssh_key:\n        custom_ssh_key = os.path.expanduser(custom_ssh_key)\n        custom_ssh_key = \"-i \" + custom_ssh_key\n\n    ssh_setup = os.getenv(\"DIST_DGL_TEST_SSH_SETUP\", \"\")\n    if ssh_setup:\n        cmd = ssh_setup + \";\" + cmd\n    # Construct ssh command that executes `cmd` on the remote host\n    ssh_cmd = \"ssh -o StrictHostKeyChecking=no {ssh_key} -p {port} {ip_prefix}{ip} '{cmd}'\".format(\n        ssh_key=custom_ssh_key,\n        port=str(port),\n        ip_prefix=ip_prefix,\n        ip=ip,\n        cmd=cmd,\n    )\n    ctx = mp.get_context(\"spawn\")\n    proc = ctx.Process(target=run, args=(ssh_cmd,))\n    proc.start()\n    return proc\n\n\ndef get_ips(ip_config):\n    ips = []\n    with open(ip_config) as f:\n 
       for line in f:\n            result = line.strip().split()\n            if len(result) != 1:\n                raise RuntimeError(\n                    \"Invalid format of ip_config:{}\".format(ip_config)\n                )\n            ips.append(result[0])\n    return ips\n"
  },
  {
    "path": "tests/distributed/test_dist_graph_store.py",
    "content": "import os\n\nos.environ[\"OMP_NUM_THREADS\"] = \"1\"\nimport math\nimport multiprocessing as mp\nimport pickle\nimport socket\nimport sys\nimport time\nimport unittest\nfrom multiprocessing import Condition, Manager, Process, Value\n\nimport backend as F\n\nimport dgl\nimport dgl.graphbolt as gb\nimport numpy as np\nimport pytest\nimport torch as th\nfrom dgl.data.utils import load_graphs, save_graphs\nfrom dgl.distributed import (\n    DistEmbedding,\n    DistGraph,\n    DistGraphServer,\n    edge_split,\n    load_partition,\n    load_partition_book,\n    node_split,\n    partition_graph,\n)\nfrom dgl.distributed.optim import SparseAdagrad\nfrom dgl.heterograph_index import create_unitgraph_from_coo\nfrom numpy.testing import assert_almost_equal, assert_array_equal\nfrom scipy import sparse as spsp\nfrom utils import create_random_graph, generate_ip_config, reset_envs\n\nif os.name != \"nt\":\n    import fcntl\n    import struct\n\n\ndef _verify_dist_graph_server_dgl(g):\n    # verify dtype of underlying graph\n    cg = g.client_g\n    for k, dtype in dgl.distributed.dist_graph.RESERVED_FIELD_DTYPE.items():\n        if k in cg.ndata:\n            assert (\n                F.dtype(cg.ndata[k]) == dtype\n            ), \"Data type of {} in ndata should be {}.\".format(k, dtype)\n        if k in cg.edata:\n            assert (\n                F.dtype(cg.edata[k]) == dtype\n            ), \"Data type of {} in edata should be {}.\".format(k, dtype)\n\n\ndef _verify_dist_graph_server_graphbolt(g):\n    graph = g.client_g\n    assert isinstance(graph, gb.FusedCSCSamplingGraph)\n    # [Rui][TODO] verify dtype of underlying graph.\n\n\ndef run_server(\n    graph_name,\n    server_id,\n    server_count,\n    num_clients,\n    shared_mem,\n    use_graphbolt=False,\n):\n    g = DistGraphServer(\n        server_id,\n        \"kv_ip_config.txt\",\n        server_count,\n        num_clients,\n        \"/tmp/dist_graph/{}.json\".format(graph_name),\n        
disable_shared_mem=not shared_mem,\n        graph_format=[\"csc\", \"coo\"],\n        use_graphbolt=use_graphbolt,\n    )\n    print(f\"Starting server[{server_id}] with use_graphbolt={use_graphbolt}\")\n    _verify = (\n        _verify_dist_graph_server_graphbolt\n        if use_graphbolt\n        else _verify_dist_graph_server_dgl\n    )\n    _verify(g)\n    g.start()\n\n\ndef emb_init(shape, dtype):\n    return F.zeros(shape, dtype, F.cpu())\n\n\ndef rand_init(shape, dtype):\n    return F.tensor(np.random.normal(size=shape), F.float32)\n\n\ndef check_dist_graph_empty(g, num_clients, num_nodes, num_edges):\n    # Test API\n    assert g.num_nodes() == num_nodes\n    assert g.num_edges() == num_edges\n\n    # Test init node data\n    new_shape = (g.num_nodes(), 2)\n    g.ndata[\"test1\"] = dgl.distributed.DistTensor(new_shape, F.int32)\n    nids = F.arange(0, int(g.num_nodes() / 2))\n    feats = g.ndata[\"test1\"][nids]\n    assert np.all(F.asnumpy(feats) == 0)\n\n    # create a tensor and destroy a tensor and create it again.\n    test3 = dgl.distributed.DistTensor(\n        new_shape, F.float32, \"test3\", init_func=rand_init\n    )\n    del test3\n    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, \"test3\")\n    del test3\n\n    # Test write data\n    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())\n    g.ndata[\"test1\"][nids] = new_feats\n    feats = g.ndata[\"test1\"][nids]\n    assert np.all(F.asnumpy(feats) == 1)\n\n    # Test metadata operations.\n    assert g.node_attr_schemes()[\"test1\"].dtype == F.int32\n\n    print(\"end\")\n\n\ndef run_client_empty(\n    graph_name,\n    part_id,\n    server_count,\n    num_clients,\n    num_nodes,\n    num_edges,\n    use_graphbolt=False,\n):\n    os.environ[\"DGL_NUM_SERVER\"] = str(server_count)\n    dgl.distributed.initialize(\"kv_ip_config.txt\")\n    gpb, graph_name, _, _ = load_partition_book(\n        \"/tmp/dist_graph/{}.json\".format(graph_name), part_id\n    )\n    g = 
DistGraph(graph_name, gpb=gpb)\n    check_dist_graph_empty(g, num_clients, num_nodes, num_edges)\n\n\ndef check_server_client_empty(\n    shared_mem, num_servers, num_clients, use_graphbolt=False\n):\n    prepare_dist(num_servers)\n    g = create_random_graph(10000)\n\n    # Partition the graph\n    num_parts = 1\n    graph_name = \"dist_graph_test_1\"\n    partition_graph(\n        g, graph_name, num_parts, \"/tmp/dist_graph\", use_graphbolt=use_graphbolt\n    )\n\n    # let's just test on one partition for now.\n    # We cannot run multiple servers and clients on the same machine.\n    serv_ps = []\n    ctx = mp.get_context(\"spawn\")\n    for serv_id in range(num_servers):\n        p = ctx.Process(\n            target=run_server,\n            args=(\n                graph_name,\n                serv_id,\n                num_servers,\n                num_clients,\n                shared_mem,\n                use_graphbolt,\n            ),\n        )\n        serv_ps.append(p)\n        p.start()\n\n    cli_ps = []\n    for cli_id in range(num_clients):\n        print(\"start client\", cli_id)\n        p = ctx.Process(\n            target=run_client_empty,\n            args=(\n                graph_name,\n                0,\n                num_servers,\n                num_clients,\n                g.num_nodes(),\n                g.num_edges(),\n                use_graphbolt,\n            ),\n        )\n        p.start()\n        cli_ps.append(p)\n\n    for p in cli_ps:\n        p.join()\n        assert p.exitcode == 0\n\n    for p in serv_ps:\n        p.join()\n        assert p.exitcode == 0\n\n    print(\"clients have terminated\")\n\n\ndef run_client(\n    graph_name,\n    part_id,\n    server_count,\n    num_clients,\n    num_nodes,\n    num_edges,\n    group_id,\n    use_graphbolt=False,\n):\n    os.environ[\"DGL_NUM_SERVER\"] = str(server_count)\n    os.environ[\"DGL_GROUP_ID\"] = str(group_id)\n    dgl.distributed.initialize(\"kv_ip_config.txt\")\n    gpb, 
graph_name, _, _ = load_partition_book(\n        \"/tmp/dist_graph/{}.json\".format(graph_name), part_id\n    )\n    g = DistGraph(graph_name, gpb=gpb)\n    check_dist_graph(\n        g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt\n    )\n\n\ndef run_emb_client(\n    graph_name,\n    part_id,\n    server_count,\n    num_clients,\n    num_nodes,\n    num_edges,\n    group_id,\n):\n    os.environ[\"DGL_NUM_SERVER\"] = str(server_count)\n    os.environ[\"DGL_GROUP_ID\"] = str(group_id)\n    dgl.distributed.initialize(\"kv_ip_config.txt\")\n    gpb, graph_name, _, _ = load_partition_book(\n        \"/tmp/dist_graph/{}.json\".format(graph_name), part_id\n    )\n    g = DistGraph(graph_name, gpb=gpb)\n    check_dist_emb(g, num_clients, num_nodes, num_edges)\n\n\ndef run_optim_client(\n    graph_name,\n    part_id,\n    server_count,\n    rank,\n    world_size,\n    num_nodes,\n    optimizer_states,\n    save,\n):\n    os.environ[\"DGL_NUM_SERVER\"] = str(server_count)\n    os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n    os.environ[\"MASTER_PORT\"] = \"12355\"\n    dgl.distributed.initialize(\"kv_ip_config.txt\")\n    th.distributed.init_process_group(\n        backend=\"gloo\", rank=rank, world_size=world_size\n    )\n    gpb, graph_name, _, _ = load_partition_book(\n        \"/tmp/dist_graph/{}.json\".format(graph_name), part_id\n    )\n    g = DistGraph(graph_name, gpb=gpb)\n    check_dist_optim_store(rank, num_nodes, optimizer_states, save)\n\n\ndef check_dist_optim_store(rank, num_nodes, optimizer_states, save):\n    try:\n        total_idx = F.arange(0, num_nodes, F.int64, F.cpu())\n        emb = DistEmbedding(num_nodes, 1, name=\"optim_emb1\", init_func=emb_init)\n        emb2 = DistEmbedding(\n            num_nodes, 1, name=\"optim_emb2\", init_func=emb_init\n        )\n        if save:\n            optimizer = SparseAdagrad([emb, emb2], lr=0.1, eps=1e-08)\n            if rank == 0:\n                optimizer._state[\"optim_emb1\"][total_idx] = 
optimizer_states[0]\n                optimizer._state[\"optim_emb2\"][total_idx] = optimizer_states[1]\n            optimizer.save(\"/tmp/dist_graph/emb.pt\")\n        else:\n            optimizer = SparseAdagrad([emb, emb2], lr=0.001, eps=2e-08)\n            optimizer.load(\"/tmp/dist_graph/emb.pt\")\n            if rank == 0:\n                assert F.allclose(\n                    optimizer._state[\"optim_emb1\"][total_idx],\n                    optimizer_states[0],\n                    0.0,\n                    0.0,\n                )\n                assert F.allclose(\n                    optimizer._state[\"optim_emb2\"][total_idx],\n                    optimizer_states[1],\n                    0.0,\n                    0.0,\n                )\n                assert 0.1 == optimizer._lr\n                assert 1e-08 == optimizer._eps\n            th.distributed.barrier()\n    except Exception as e:\n        print(e)\n        sys.exit(-1)\n\n\ndef run_client_hierarchy(\n    graph_name,\n    part_id,\n    server_count,\n    node_mask,\n    edge_mask,\n    return_dict,\n    use_graphbolt=False,\n):\n    os.environ[\"DGL_NUM_SERVER\"] = str(server_count)\n    dgl.distributed.initialize(\"kv_ip_config.txt\")\n    gpb, graph_name, _, _ = load_partition_book(\n        \"/tmp/dist_graph/{}.json\".format(graph_name), part_id\n    )\n    g = DistGraph(graph_name, gpb=gpb)\n    node_mask = F.tensor(node_mask)\n    edge_mask = F.tensor(edge_mask)\n    nodes = node_split(\n        node_mask,\n        g.get_partition_book(),\n        node_trainer_ids=g.ndata[\"trainer_id\"],\n    )\n    edges = edge_split(\n        edge_mask,\n        g.get_partition_book(),\n        edge_trainer_ids=g.edata[\"trainer_id\"],\n    )\n    rank = g.rank()\n    return_dict[rank] = (nodes, edges)\n\n\ndef check_dist_emb(g, num_clients, num_nodes, num_edges):\n    # Test sparse emb\n    try:\n        emb = DistEmbedding(g.num_nodes(), 1, \"emb1\", emb_init)\n        nids = F.arange(0, 
int(g.num_nodes()))\n        lr = 0.001\n        optimizer = SparseAdagrad([emb], lr=lr)\n        with F.record_grad():\n            feats = emb(nids)\n            assert np.all(F.asnumpy(feats) == np.zeros((len(nids), 1)))\n            loss = F.sum(feats + 1, 0)\n        loss.backward()\n        optimizer.step()\n        feats = emb(nids)\n        if num_clients == 1:\n            assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * -lr)\n        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))\n        feats1 = emb(rest)\n        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))\n\n        policy = dgl.distributed.PartitionPolicy(\"node\", g.get_partition_book())\n        grad_sum = dgl.distributed.DistTensor(\n            (g.num_nodes(), 1), F.float32, \"emb1_sum\", policy\n        )\n        if num_clients == 1:\n            assert np.all(\n                F.asnumpy(grad_sum[nids])\n                == np.ones((len(nids), 1)) * num_clients\n            )\n        assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))\n\n        emb = DistEmbedding(g.num_nodes(), 1, \"emb2\", emb_init)\n        with F.no_grad():\n            feats1 = emb(nids)\n        assert np.all(F.asnumpy(feats1) == 0)\n\n        optimizer = SparseAdagrad([emb], lr=lr)\n        with F.record_grad():\n            feats1 = emb(nids)\n            feats2 = emb(nids)\n            feats = F.cat([feats1, feats2], 0)\n            assert np.all(F.asnumpy(feats) == np.zeros((len(nids) * 2, 1)))\n            loss = F.sum(feats + 1, 0)\n        loss.backward()\n        optimizer.step()\n        with F.no_grad():\n            feats = emb(nids)\n        if num_clients == 1:\n            assert_almost_equal(\n                F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr\n            )\n        rest = np.setdiff1d(np.arange(g.num_nodes()), F.asnumpy(nids))\n        feats1 = emb(rest)\n        assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 
1)))\n    except NotImplementedError as e:\n        pass\n    except Exception as e:\n        print(e)\n        sys.exit(-1)\n\n\ndef check_dist_graph(g, num_clients, num_nodes, num_edges, use_graphbolt=False):\n    # Test API\n    assert g.num_nodes() == num_nodes\n    assert g.num_edges() == num_edges\n\n    # Test reading node data\n    nids = F.arange(0, int(g.num_nodes() / 2))\n    feats1 = g.ndata[\"features\"][nids]\n    feats = F.squeeze(feats1, 1)\n    assert np.all(F.asnumpy(feats == nids))\n\n    # Test reading edge data\n    eids = F.arange(0, int(g.num_edges() / 2))\n    feats1 = g.edata[\"features\"][eids]\n    feats = F.squeeze(feats1, 1)\n    assert np.all(F.asnumpy(feats == eids))\n\n    # Test edge_subgraph\n    sg = g.edge_subgraph(eids)\n    assert sg.num_edges() == len(eids)\n    assert F.array_equal(sg.edata[dgl.EID], eids)\n\n    # Test init node data\n    new_shape = (g.num_nodes(), 2)\n    test1 = dgl.distributed.DistTensor(new_shape, F.int32)\n    g.ndata[\"test1\"] = test1\n    feats = g.ndata[\"test1\"][nids]\n    assert np.all(F.asnumpy(feats) == 0)\n    assert test1.count_nonzero() == 0\n\n    # reference to a one that exists\n    test2 = dgl.distributed.DistTensor(\n        new_shape, F.float32, \"test2\", init_func=rand_init\n    )\n    test3 = dgl.distributed.DistTensor(new_shape, F.float32, \"test2\")\n    assert np.all(F.asnumpy(test2[nids]) == F.asnumpy(test3[nids]))\n\n    # create a tensor and destroy a tensor and create it again.\n    test3 = dgl.distributed.DistTensor(\n        new_shape, F.float32, \"test3\", init_func=rand_init\n    )\n    test3_name = test3.kvstore_key\n    assert test3_name in g._client.data_name_list()\n    assert test3_name in g._client.gdata_name_list()\n    del test3\n    assert test3_name not in g._client.data_name_list()\n    assert test3_name not in g._client.gdata_name_list()\n    test3 = dgl.distributed.DistTensor((g.num_nodes(), 3), F.float32, \"test3\")\n    del test3\n\n    # add tests for 
anonymous distributed tensor.\n    test3 = dgl.distributed.DistTensor(\n        new_shape, F.float32, init_func=rand_init\n    )\n    data = test3[0:10]\n    test4 = dgl.distributed.DistTensor(\n        new_shape, F.float32, init_func=rand_init\n    )\n    del test3\n    test5 = dgl.distributed.DistTensor(\n        new_shape, F.float32, init_func=rand_init\n    )\n    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0\n\n    # test a persistent tesnor\n    test4 = dgl.distributed.DistTensor(\n        new_shape, F.float32, \"test4\", init_func=rand_init, persistent=True\n    )\n    del test4\n    try:\n        test4 = dgl.distributed.DistTensor(\n            (g.num_nodes(), 3), F.float32, \"test4\"\n        )\n        raise Exception(\"\")\n    except:\n        pass\n\n    # Test write data\n    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())\n    g.ndata[\"test1\"][nids] = new_feats\n    feats = g.ndata[\"test1\"][nids]\n    assert np.all(F.asnumpy(feats) == 1)\n\n    # Test metadata operations.\n    assert len(g.ndata[\"features\"]) == g.num_nodes()\n    assert g.ndata[\"features\"].shape == (g.num_nodes(), 1)\n    assert g.ndata[\"features\"].dtype == F.int64\n    assert g.node_attr_schemes()[\"features\"].dtype == F.int64\n    assert g.node_attr_schemes()[\"test1\"].dtype == F.int32\n    assert g.node_attr_schemes()[\"features\"].shape == (1,)\n\n    selected_nodes = np.random.randint(0, 100, size=g.num_nodes()) > 30\n    # Test node split\n    nodes = node_split(selected_nodes, g.get_partition_book())\n    nodes = F.asnumpy(nodes)\n    # We only have one partition, so the local nodes are basically all nodes in the graph.\n    local_nids = np.arange(g.num_nodes())\n    for n in nodes:\n        assert n in local_nids\n\n    print(\"end\")\n\n\ndef check_dist_emb_server_client(\n    shared_mem, num_servers, num_clients, num_groups=1\n):\n    prepare_dist(num_servers)\n    g = create_random_graph(10000)\n\n    # Partition the graph\n    num_parts = 1\n    
graph_name = (\n        f\"check_dist_emb_{shared_mem}_{num_servers}_{num_clients}_{num_groups}\"\n    )\n    g.ndata[\"features\"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)\n    g.edata[\"features\"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)\n    partition_graph(g, graph_name, num_parts, \"/tmp/dist_graph\")\n\n    # let's just test on one partition for now.\n    # We cannot run multiple servers and clients on the same machine.\n    serv_ps = []\n    ctx = mp.get_context(\"spawn\")\n    for serv_id in range(num_servers):\n        p = ctx.Process(\n            target=run_server,\n            args=(\n                graph_name,\n                serv_id,\n                num_servers,\n                num_clients,\n                shared_mem,\n            ),\n        )\n        serv_ps.append(p)\n        p.start()\n\n    cli_ps = []\n    for cli_id in range(num_clients):\n        for group_id in range(num_groups):\n            print(\"start client[{}] for group[{}]\".format(cli_id, group_id))\n            p = ctx.Process(\n                target=run_emb_client,\n                args=(\n                    graph_name,\n                    0,\n                    num_servers,\n                    num_clients,\n                    g.num_nodes(),\n                    g.num_edges(),\n                    group_id,\n                ),\n            )\n            p.start()\n            time.sleep(1)  # avoid race condition when instantiating DistGraph\n            cli_ps.append(p)\n\n    for p in cli_ps:\n        p.join()\n        assert p.exitcode == 0\n\n    for p in serv_ps:\n        p.join()\n        assert p.exitcode == 0\n\n    print(\"clients have terminated\")\n\n\ndef check_server_client(\n    shared_mem, num_servers, num_clients, num_groups=1, use_graphbolt=False\n):\n    prepare_dist(num_servers)\n    g = create_random_graph(10000)\n\n    # Partition the graph\n    num_parts = 1\n    graph_name = 
f\"check_server_client_{shared_mem}_{num_servers}_{num_clients}_{num_groups}\"\n    g.ndata[\"features\"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)\n    g.edata[\"features\"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)\n    partition_graph(\n        g, graph_name, num_parts, \"/tmp/dist_graph\", use_graphbolt=use_graphbolt\n    )\n\n    # let's just test on one partition for now.\n    # We cannot run multiple servers and clients on the same machine.\n    serv_ps = []\n    ctx = mp.get_context(\"spawn\")\n    for serv_id in range(num_servers):\n        p = ctx.Process(\n            target=run_server,\n            args=(\n                graph_name,\n                serv_id,\n                num_servers,\n                num_clients,\n                shared_mem,\n                use_graphbolt,\n            ),\n        )\n        serv_ps.append(p)\n        p.start()\n\n    # launch different client groups simultaneously\n    cli_ps = []\n    for cli_id in range(num_clients):\n        for group_id in range(num_groups):\n            print(\"start client[{}] for group[{}]\".format(cli_id, group_id))\n            p = ctx.Process(\n                target=run_client,\n                args=(\n                    graph_name,\n                    0,\n                    num_servers,\n                    num_clients,\n                    g.num_nodes(),\n                    g.num_edges(),\n                    group_id,\n                    use_graphbolt,\n                ),\n            )\n            p.start()\n            time.sleep(1)  # avoid race condition when instantiating DistGraph\n            cli_ps.append(p)\n    for p in cli_ps:\n        p.join()\n        assert p.exitcode == 0\n\n    for p in serv_ps:\n        p.join()\n        assert p.exitcode == 0\n\n    print(\"clients have terminated\")\n\n\ndef check_server_client_hierarchy(\n    shared_mem, num_servers, num_clients, use_graphbolt=False\n):\n    if num_clients == 1:\n        # skip this test if there is 
only one client.\n        return\n    prepare_dist(num_servers)\n    g = create_random_graph(10000)\n\n    # Partition the graph\n    num_parts = 1\n    graph_name = \"dist_graph_test_2\"\n    g.ndata[\"features\"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)\n    g.edata[\"features\"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)\n    partition_graph(\n        g,\n        graph_name,\n        num_parts,\n        \"/tmp/dist_graph\",\n        num_trainers_per_machine=num_clients,\n        use_graphbolt=use_graphbolt,\n    )\n\n    # let's just test on one partition for now.\n    # We cannot run multiple servers and clients on the same machine.\n    serv_ps = []\n    ctx = mp.get_context(\"spawn\")\n    for serv_id in range(num_servers):\n        p = ctx.Process(\n            target=run_server,\n            args=(\n                graph_name,\n                serv_id,\n                num_servers,\n                num_clients,\n                shared_mem,\n                use_graphbolt,\n            ),\n        )\n        serv_ps.append(p)\n        p.start()\n\n    cli_ps = []\n    manager = mp.Manager()\n    return_dict = manager.dict()\n    node_mask = np.zeros((g.num_nodes(),), np.int32)\n    edge_mask = np.zeros((g.num_edges(),), np.int32)\n    nodes = np.random.choice(g.num_nodes(), g.num_nodes() // 10, replace=False)\n    edges = np.random.choice(g.num_edges(), g.num_edges() // 10, replace=False)\n    node_mask[nodes] = 1\n    edge_mask[edges] = 1\n    nodes = np.sort(nodes)\n    edges = np.sort(edges)\n    for cli_id in range(num_clients):\n        print(\"start client\", cli_id)\n        p = ctx.Process(\n            target=run_client_hierarchy,\n            args=(\n                graph_name,\n                0,\n                num_servers,\n                node_mask,\n                edge_mask,\n                return_dict,\n                use_graphbolt,\n            ),\n        )\n        p.start()\n        cli_ps.append(p)\n\n    for p in cli_ps:\n      
  p.join()\n        assert p.exitcode == 0\n    for p in serv_ps:\n        p.join()\n        assert p.exitcode == 0\n    nodes1 = []\n    edges1 = []\n    for n, e in return_dict.values():\n        nodes1.append(n)\n        edges1.append(e)\n    nodes1, _ = F.sort_1d(F.cat(nodes1, 0))\n    edges1, _ = F.sort_1d(F.cat(edges1, 0))\n    assert np.all(F.asnumpy(nodes1) == nodes)\n    assert np.all(F.asnumpy(edges1) == edges)\n\n    print(\"clients have terminated\")\n\n\ndef run_client_hetero(\n    graph_name,\n    part_id,\n    server_count,\n    num_clients,\n    num_nodes,\n    num_edges,\n    use_graphbolt=False,\n):\n    os.environ[\"DGL_NUM_SERVER\"] = str(server_count)\n    dgl.distributed.initialize(\"kv_ip_config.txt\")\n    gpb, graph_name, _, _ = load_partition_book(\n        \"/tmp/dist_graph/{}.json\".format(graph_name), part_id\n    )\n    g = DistGraph(graph_name, gpb=gpb)\n    check_dist_graph_hetero(\n        g, num_clients, num_nodes, num_edges, use_graphbolt=use_graphbolt\n    )\n\n\ndef create_random_hetero():\n    num_nodes = {\"n1\": 10000, \"n2\": 10010, \"n3\": 10020}\n    etypes = [(\"n1\", \"r1\", \"n2\"), (\"n1\", \"r2\", \"n3\"), (\"n2\", \"r3\", \"n3\")]\n    edges = {}\n    for etype in etypes:\n        src_ntype, _, dst_ntype = etype\n        arr = spsp.random(\n            num_nodes[src_ntype],\n            num_nodes[dst_ntype],\n            density=0.001,\n            format=\"coo\",\n            random_state=100,\n        )\n        edges[etype] = (arr.row, arr.col)\n    g = dgl.heterograph(edges, num_nodes)\n    # assign ndata & edata.\n    # data with same name as ntype/etype is assigned on purpose to verify\n    # such same names can be correctly handled in DistGraph. 
See more details\n    # in issue #4887 and #4463 on github.\n    ntype = \"n1\"\n    for name in [\"feat\", ntype]:\n        g.nodes[ntype].data[name] = F.unsqueeze(\n            F.arange(0, g.num_nodes(ntype)), 1\n        )\n    etype = \"r1\"\n    for name in [\"feat\", etype]:\n        g.edges[etype].data[name] = F.unsqueeze(\n            F.arange(0, g.num_edges(etype)), 1\n        )\n    return g\n\n\ndef check_dist_graph_hetero(\n    g, num_clients, num_nodes, num_edges, use_graphbolt=False\n):\n    # Test API\n    for ntype in num_nodes:\n        assert ntype in g.ntypes\n        assert num_nodes[ntype] == g.num_nodes(ntype)\n    for etype in num_edges:\n        assert etype in g.etypes\n        assert num_edges[etype] == g.num_edges(etype)\n    etypes = [(\"n1\", \"r1\", \"n2\"), (\"n1\", \"r2\", \"n3\"), (\"n2\", \"r3\", \"n3\")]\n    for i, etype in enumerate(g.canonical_etypes):\n        assert etype[0] == etypes[i][0]\n        assert etype[1] == etypes[i][1]\n        assert etype[2] == etypes[i][2]\n    assert g.num_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])\n    assert g.num_edges() == sum([num_edges[etype] for etype in num_edges])\n\n    # Test reading node data\n    ntype = \"n1\"\n    nids = F.arange(0, g.num_nodes(ntype) // 2)\n    for name in [\"feat\", ntype]:\n        data = g.nodes[ntype].data[name][nids]\n        data = F.squeeze(data, 1)\n        assert np.all(F.asnumpy(data == nids))\n    assert len(g.nodes[\"n2\"].data) == 0\n    expect_except = False\n    try:\n        g.nodes[\"xxx\"].data[\"x\"]\n    except dgl.DGLError:\n        expect_except = True\n    assert expect_except\n\n    # Test reading edge data\n    etype = \"r1\"\n    eids = F.arange(0, g.num_edges(etype) // 2)\n    for name in [\"feat\", etype]:\n        # access via etype\n        data = g.edges[etype].data[name][eids]\n        data = F.squeeze(data, 1)\n        assert np.all(F.asnumpy(data == eids))\n        # access via canonical etype\n        c_etype = 
g.to_canonical_etype(etype)\n        data = g.edges[c_etype].data[name][eids]\n        data = F.squeeze(data, 1)\n        assert np.all(F.asnumpy(data == eids))\n    assert len(g.edges[\"r2\"].data) == 0\n    expect_except = False\n    try:\n        g.edges[\"xxx\"].data[\"x\"]\n    except dgl.DGLError:\n        expect_except = True\n    assert expect_except\n\n    # Test edge_subgraph\n    sg = g.edge_subgraph({\"r1\": eids})\n    assert sg.num_edges() == len(eids)\n    assert F.array_equal(sg.edata[dgl.EID], eids)\n    sg = g.edge_subgraph({(\"n1\", \"r1\", \"n2\"): eids})\n    assert sg.num_edges() == len(eids)\n    assert F.array_equal(sg.edata[dgl.EID], eids)\n\n    # Test init node data\n    new_shape = (g.num_nodes(\"n1\"), 2)\n    g.nodes[\"n1\"].data[\"test1\"] = dgl.distributed.DistTensor(new_shape, F.int32)\n    feats = g.nodes[\"n1\"].data[\"test1\"][nids]\n    assert np.all(F.asnumpy(feats) == 0)\n\n    # create a tensor and destroy a tensor and create it again.\n    test3 = dgl.distributed.DistTensor(\n        new_shape, F.float32, \"test3\", init_func=rand_init\n    )\n    del test3\n    test3 = dgl.distributed.DistTensor(\n        (g.num_nodes(\"n1\"), 3), F.float32, \"test3\"\n    )\n    del test3\n\n    # add tests for anonymous distributed tensor.\n    test3 = dgl.distributed.DistTensor(\n        new_shape, F.float32, init_func=rand_init\n    )\n    data = test3[0:10]\n    test4 = dgl.distributed.DistTensor(\n        new_shape, F.float32, init_func=rand_init\n    )\n    del test3\n    test5 = dgl.distributed.DistTensor(\n        new_shape, F.float32, init_func=rand_init\n    )\n    assert np.sum(F.asnumpy(test5[0:10] != data)) > 0\n\n    # test a persistent tesnor\n    test4 = dgl.distributed.DistTensor(\n        new_shape, F.float32, \"test4\", init_func=rand_init, persistent=True\n    )\n    del test4\n    try:\n        test4 = dgl.distributed.DistTensor(\n            (g.num_nodes(\"n1\"), 3), F.float32, \"test4\"\n        )\n        raise 
Exception(\"\")\n    except:\n        pass\n\n    # Test write data\n    new_feats = F.ones((len(nids), 2), F.int32, F.cpu())\n    g.nodes[\"n1\"].data[\"test1\"][nids] = new_feats\n    feats = g.nodes[\"n1\"].data[\"test1\"][nids]\n    assert np.all(F.asnumpy(feats) == 1)\n\n    # Test metadata operations.\n    assert len(g.nodes[\"n1\"].data[\"feat\"]) == g.num_nodes(\"n1\")\n    assert g.nodes[\"n1\"].data[\"feat\"].shape == (g.num_nodes(\"n1\"), 1)\n    assert g.nodes[\"n1\"].data[\"feat\"].dtype == F.int64\n\n    selected_nodes = np.random.randint(0, 100, size=g.num_nodes(\"n1\")) > 30\n    # Test node split\n    nodes = node_split(selected_nodes, g.get_partition_book(), ntype=\"n1\")\n    nodes = F.asnumpy(nodes)\n    # We only have one partition, so the local nodes are basically all nodes in the graph.\n    local_nids = np.arange(g.num_nodes(\"n1\"))\n    for n in nodes:\n        assert n in local_nids\n\n    print(\"end\")\n\n\ndef check_server_client_hetero(\n    shared_mem, num_servers, num_clients, use_graphbolt=False\n):\n    prepare_dist(num_servers)\n    g = create_random_hetero()\n\n    # Partition the graph\n    num_parts = 1\n    graph_name = \"dist_graph_test_3\"\n    partition_graph(\n        g, graph_name, num_parts, \"/tmp/dist_graph\", use_graphbolt=use_graphbolt\n    )\n\n    # let's just test on one partition for now.\n    # We cannot run multiple servers and clients on the same machine.\n    serv_ps = []\n    ctx = mp.get_context(\"spawn\")\n    for serv_id in range(num_servers):\n        p = ctx.Process(\n            target=run_server,\n            args=(\n                graph_name,\n                serv_id,\n                num_servers,\n                num_clients,\n                shared_mem,\n                use_graphbolt,\n            ),\n        )\n        serv_ps.append(p)\n        p.start()\n\n    cli_ps = []\n    num_nodes = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}\n    num_edges = {etype: g.num_edges(etype) for etype in 
g.etypes}\n    for cli_id in range(num_clients):\n        print(\"start client\", cli_id)\n        p = ctx.Process(\n            target=run_client_hetero,\n            args=(\n                graph_name,\n                0,\n                num_servers,\n                num_clients,\n                num_nodes,\n                num_edges,\n                use_graphbolt,\n            ),\n        )\n        p.start()\n        cli_ps.append(p)\n\n    for p in cli_ps:\n        p.join()\n        assert p.exitcode == 0\n\n    for p in serv_ps:\n        p.join()\n        assert p.exitcode == 0\n\n    print(\"clients have terminated\")\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support some of operations in DistGraph\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\", reason=\"Turn off Mxnet support\"\n)\n@pytest.mark.parametrize(\"shared_mem\", [True])\n@pytest.mark.parametrize(\"num_servers\", [1])\n@pytest.mark.parametrize(\"num_clients\", [1, 4])\n@pytest.mark.parametrize(\"use_graphbolt\", [True, False])\ndef test_server_client(shared_mem, num_servers, num_clients, use_graphbolt):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    # [Rui]\n    # 1. `disable_shared_mem=False` is not supported yet. Skip it.\n    # 2. `num_servers` > 1 does not work on single machine. 
Skip it.\n    for func in [\n        check_server_client,\n        check_server_client_hetero,\n        check_server_client_empty,\n        check_server_client_hierarchy,\n    ]:\n        func(shared_mem, num_servers, num_clients, use_graphbolt=use_graphbolt)\n\n\n@unittest.skip(reason=\"Skip due to glitch in CI\")\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support distributed DistEmbedding\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\",\n    reason=\"Mxnet doesn't support distributed DistEmbedding\",\n)\ndef test_dist_emb_server_client():\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    check_dist_emb_server_client(True, 1, 1)\n    check_dist_emb_server_client(False, 1, 1)\n    # [TODO][Rhett] Tests for multiple groups may fail sometimes and\n    # root cause is unknown. Let's disable them for now.\n    # check_dist_emb_server_client(True, 2, 2)\n    # check_dist_emb_server_client(True, 1, 1, 2)\n    # check_dist_emb_server_client(False, 1, 1, 2)\n    # check_dist_emb_server_client(True, 2, 2, 2)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support distributed Optimizer\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\",\n    reason=\"Mxnet doesn't support distributed Optimizer\",\n)\ndef test_dist_optim_server_client():\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    optimizer_states = []\n    num_nodes = 10000\n    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))\n    optimizer_states.append(F.uniform((num_nodes, 1), F.float32, F.cpu(), 0, 1))\n    check_dist_optim_server_client(num_nodes, 1, 4, optimizer_states, True)\n    check_dist_optim_server_client(num_nodes, 1, 8, optimizer_states, False)\n    check_dist_optim_server_client(num_nodes, 1, 2, optimizer_states, 
False)\n\n\ndef check_dist_optim_server_client(\n    num_nodes, num_servers, num_clients, optimizer_states, save\n):\n    graph_name = f\"check_dist_optim_{num_servers}_store\"\n    if save:\n        prepare_dist(num_servers)\n        g = create_random_graph(num_nodes)\n\n        # Partition the graph\n        num_parts = 1\n        g.ndata[\"features\"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)\n        g.edata[\"features\"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)\n        partition_graph(g, graph_name, num_parts, \"/tmp/dist_graph\")\n\n    # let's just test on one partition for now.\n    # We cannot run multiple servers and clients on the same machine.\n    serv_ps = []\n    ctx = mp.get_context(\"spawn\")\n    for serv_id in range(num_servers):\n        p = ctx.Process(\n            target=run_server,\n            args=(\n                graph_name,\n                serv_id,\n                num_servers,\n                num_clients,\n                True,\n            ),\n        )\n        serv_ps.append(p)\n        p.start()\n\n    cli_ps = []\n    for cli_id in range(num_clients):\n        print(\"start client[{}] for group[0]\".format(cli_id))\n        p = ctx.Process(\n            target=run_optim_client,\n            args=(\n                graph_name,\n                0,\n                num_servers,\n                cli_id,\n                num_clients,\n                num_nodes,\n                optimizer_states,\n                save,\n            ),\n        )\n        p.start()\n        time.sleep(1)  # avoid race condition when instantiating DistGraph\n        cli_ps.append(p)\n\n    for p in cli_ps:\n        p.join()\n        assert p.exitcode == 0\n\n    for p in serv_ps:\n        p.join()\n        assert p.exitcode == 0\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support some of operations in DistGraph\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\", reason=\"Turn 
off Mxnet support\"\n)\ndef test_standalone():\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n\n    g = create_random_graph(10000)\n    # Partition the graph\n    num_parts = 1\n    graph_name = \"dist_graph_test_3\"\n    g.ndata[\"features\"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)\n    g.edata[\"features\"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)\n    partition_graph(g, graph_name, num_parts, \"/tmp/dist_graph\")\n\n    dgl.distributed.initialize(\"kv_ip_config.txt\")\n    dist_g = DistGraph(\n        graph_name, part_config=\"/tmp/dist_graph/{}.json\".format(graph_name)\n    )\n    check_dist_graph(dist_g, 1, g.num_nodes(), g.num_edges())\n    dgl.distributed.exit_client()  # this is needed since there's two test here in one process\n\n\n@unittest.skip(reason=\"Skip due to glitch in CI\")\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support distributed DistEmbedding\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\",\n    reason=\"Mxnet doesn't support distributed DistEmbedding\",\n)\ndef test_standalone_node_emb():\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n\n    g = create_random_graph(10000)\n    # Partition the graph\n    num_parts = 1\n    graph_name = \"dist_graph_test_3\"\n    g.ndata[\"features\"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)\n    g.edata[\"features\"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)\n    partition_graph(g, graph_name, num_parts, \"/tmp/dist_graph\")\n\n    dgl.distributed.initialize(\"kv_ip_config.txt\")\n    dist_g = DistGraph(\n        graph_name, part_config=\"/tmp/dist_graph/{}.json\".format(graph_name)\n    )\n    check_dist_emb(dist_g, 1, g.num_nodes(), g.num_edges())\n    dgl.distributed.exit_client()  # this is needed since there's two test here in one process\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@pytest.mark.parametrize(\"hetero\", [True, 
False])\n@pytest.mark.parametrize(\"empty_mask\", [True, False])\ndef test_split(hetero, empty_mask):\n    if hetero:\n        g = create_random_hetero()\n        ntype = \"n1\"\n        etype = \"r1\"\n    else:\n        g = create_random_graph(10000)\n        ntype = \"_N\"\n        etype = \"_E\"\n    num_parts = 4\n    num_hops = 2\n    partition_graph(\n        g,\n        \"dist_graph_test\",\n        num_parts,\n        \"/tmp/dist_graph\",\n        num_hops=num_hops,\n        part_method=\"metis\",\n    )\n\n    mask_thd = 100 if empty_mask else 30\n    node_mask = np.random.randint(0, 100, size=g.num_nodes(ntype)) > mask_thd\n    edge_mask = np.random.randint(0, 100, size=g.num_edges(etype)) > mask_thd\n    selected_nodes = np.nonzero(node_mask)[0]\n    selected_edges = np.nonzero(edge_mask)[0]\n\n    # The code now collects the roles of all client processes and use the information\n    # to determine how to split the workloads. Here is to simulate the multi-client\n    # use case.\n    def set_roles(num_clients):\n        dgl.distributed.role.CUR_ROLE = \"default\"\n        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}\n        dgl.distributed.role.PER_ROLE_RANK[\"default\"] = {\n            i: i for i in range(num_clients)\n        }\n\n    for i in range(num_parts):\n        set_roles(num_parts)\n        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(\n            \"/tmp/dist_graph/dist_graph_test.json\", i\n        )\n        local_nids = F.nonzero_1d(part_g.ndata[\"inner_node\"])\n        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)\n        if hetero:\n            ntype_ids, nids = gpb.map_to_per_ntype(local_nids)\n            local_nids = F.asnumpy(nids)[F.asnumpy(ntype_ids) == 0]\n        else:\n            local_nids = F.asnumpy(local_nids)\n        nodes1 = np.intersect1d(selected_nodes, local_nids)\n        nodes2 = node_split(\n            node_mask, gpb, ntype=ntype, rank=i, 
force_even=False\n        )\n        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))\n        for n in F.asnumpy(nodes2):\n            assert n in local_nids\n\n        set_roles(num_parts * 2)\n        nodes3 = node_split(\n            node_mask, gpb, ntype=ntype, rank=i * 2, force_even=False\n        )\n        nodes4 = node_split(\n            node_mask, gpb, ntype=ntype, rank=i * 2 + 1, force_even=False\n        )\n        nodes5 = F.cat([nodes3, nodes4], 0)\n        assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))\n\n        set_roles(num_parts)\n        local_eids = F.nonzero_1d(part_g.edata[\"inner_edge\"])\n        local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)\n        if hetero:\n            etype_ids, eids = gpb.map_to_per_etype(local_eids)\n            local_eids = F.asnumpy(eids)[F.asnumpy(etype_ids) == 0]\n        else:\n            local_eids = F.asnumpy(local_eids)\n        edges1 = np.intersect1d(selected_edges, local_eids)\n        edges2 = edge_split(\n            edge_mask, gpb, etype=etype, rank=i, force_even=False\n        )\n        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))\n        for e in F.asnumpy(edges2):\n            assert e in local_eids\n\n        set_roles(num_parts * 2)\n        edges3 = edge_split(\n            edge_mask, gpb, etype=etype, rank=i * 2, force_even=False\n        )\n        edges4 = edge_split(\n            edge_mask, gpb, etype=etype, rank=i * 2 + 1, force_even=False\n        )\n        edges5 = F.cat([edges3, edges4], 0)\n        assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@pytest.mark.parametrize(\"empty_mask\", [True, False])\ndef test_split_even(empty_mask):\n    g = create_random_graph(10000)\n    num_parts = 4\n    num_hops = 2\n    partition_graph(\n        g,\n        \"dist_graph_test\",\n        num_parts,\n        \"/tmp/dist_graph\",\n        
num_hops=num_hops,\n        part_method=\"metis\",\n    )\n\n    mask_thd = 100 if empty_mask else 30\n    node_mask = np.random.randint(0, 100, size=g.num_nodes()) > mask_thd\n    edge_mask = np.random.randint(0, 100, size=g.num_edges()) > mask_thd\n    all_nodes1 = []\n    all_nodes2 = []\n    all_edges1 = []\n    all_edges2 = []\n\n    # The code now collects the roles of all client processes and use the information\n    # to determine how to split the workloads. Here is to simulate the multi-client\n    # use case.\n    def set_roles(num_clients):\n        dgl.distributed.role.CUR_ROLE = \"default\"\n        dgl.distributed.role.GLOBAL_RANK = {i: i for i in range(num_clients)}\n        dgl.distributed.role.PER_ROLE_RANK[\"default\"] = {\n            i: i for i in range(num_clients)\n        }\n\n    for i in range(num_parts):\n        set_roles(num_parts)\n        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(\n            \"/tmp/dist_graph/dist_graph_test.json\", i\n        )\n        local_nids = F.nonzero_1d(part_g.ndata[\"inner_node\"])\n        local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)\n        nodes = node_split(node_mask, gpb, rank=i, force_even=True)\n        all_nodes1.append(nodes)\n        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(local_nids))\n        print(\n            \"part {} get {} nodes and {} are in the partition\".format(\n                i, len(nodes), len(subset)\n            )\n        )\n\n        set_roles(num_parts * 2)\n        nodes1 = node_split(node_mask, gpb, rank=i * 2, force_even=True)\n        nodes2 = node_split(node_mask, gpb, rank=i * 2 + 1, force_even=True)\n        nodes3, _ = F.sort_1d(F.cat([nodes1, nodes2], 0))\n        all_nodes2.append(nodes3)\n        subset = np.intersect1d(F.asnumpy(nodes), F.asnumpy(nodes3))\n        print(\"intersection has\", len(subset))\n\n        set_roles(num_parts)\n        local_eids = F.nonzero_1d(part_g.edata[\"inner_edge\"])\n        
local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)\n        edges = edge_split(edge_mask, gpb, rank=i, force_even=True)\n        all_edges1.append(edges)\n        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(local_eids))\n        print(\n            \"part {} get {} edges and {} are in the partition\".format(\n                i, len(edges), len(subset)\n            )\n        )\n\n        set_roles(num_parts * 2)\n        edges1 = edge_split(edge_mask, gpb, rank=i * 2, force_even=True)\n        edges2 = edge_split(edge_mask, gpb, rank=i * 2 + 1, force_even=True)\n        edges3, _ = F.sort_1d(F.cat([edges1, edges2], 0))\n        all_edges2.append(edges3)\n        subset = np.intersect1d(F.asnumpy(edges), F.asnumpy(edges3))\n        print(\"intersection has\", len(subset))\n    all_nodes1 = F.cat(all_nodes1, 0)\n    all_edges1 = F.cat(all_edges1, 0)\n    all_nodes2 = F.cat(all_nodes2, 0)\n    all_edges2 = F.cat(all_edges2, 0)\n    all_nodes = np.nonzero(node_mask)[0]\n    all_edges = np.nonzero(edge_mask)[0]\n    assert np.all(all_nodes == F.asnumpy(all_nodes1))\n    assert np.all(all_edges == F.asnumpy(all_edges1))\n    assert np.all(all_nodes == F.asnumpy(all_nodes2))\n    assert np.all(all_edges == F.asnumpy(all_edges2))\n\n\ndef prepare_dist(num_servers=1):\n    generate_ip_config(\"kv_ip_config.txt\", 1, num_servers=num_servers)\n\n\nif __name__ == \"__main__\":\n    os.makedirs(\"/tmp/dist_graph\", exist_ok=True)\n    test_dist_emb_server_client()\n    test_server_client()\n    test_split(True)\n    test_split(False)\n    test_split_even()\n    test_standalone()\n    test_standalone_node_emb()\n"
  },
  {
    "path": "tests/distributed/test_dist_tensor.py",
    "content": "import operator\nimport os\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport pytest\nfrom utils import create_random_graph, generate_ip_config, reset_envs\n\ndist_g = None\n\n\ndef rand_mask(shape, dtype):\n    return F.randn(shape) > 0\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support some of operations in DistGraph\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\", reason=\"Turn off Mxnet support\"\n)\ndef setup_module():\n    global dist_g\n\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n\n    dist_g = create_random_graph(10000)\n    # Partition the graph.\n    num_parts = 1\n    graph_name = \"dist_graph_test_3\"\n    dist_g.ndata[\"features\"] = F.unsqueeze(F.arange(0, dist_g.num_nodes()), 1)\n    dist_g.edata[\"features\"] = F.unsqueeze(F.arange(0, dist_g.num_edges()), 1)\n    dgl.distributed.partition_graph(\n        dist_g, graph_name, num_parts, \"/tmp/dist_graph\"\n    )\n\n    dgl.distributed.initialize(\"kv_ip_config.txt\")\n    dist_g = dgl.distributed.DistGraph(\n        graph_name, part_config=\"/tmp/dist_graph/{}.json\".format(graph_name)\n    )\n    dist_g.edata[\"mask1\"] = dgl.distributed.DistTensor(\n        (dist_g.num_edges(),), F.bool, init_func=rand_mask\n    )\n    dist_g.edata[\"mask2\"] = dgl.distributed.DistTensor(\n        (dist_g.num_edges(),), F.bool, init_func=rand_mask\n    )\n\n\ndef check_binary_op(key1, key2, key3, op):\n    for i in range(0, dist_g.num_edges(), 1000):\n        i_end = min(i + 1000, dist_g.num_edges())\n        assert F.array_equal(\n            dist_g.edata[key3][i:i_end],\n            op(dist_g.edata[key1][i:i_end], dist_g.edata[key2][i:i_end]),\n        )\n        # Test with different index dtypes. 
int32 is not supported.\n        with pytest.raises(\n            dgl.utils.internal.InconsistentDtypeException,\n            match=\"DGL now requires the input tensor to have\",\n        ):\n            _ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int32)]\n        _ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int64)]\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support some of operations in DistGraph\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\", reason=\"Turn off Mxnet support\"\n)\ndef test_op():\n    dist_g.edata[\"mask3\"] = dist_g.edata[\"mask1\"] | dist_g.edata[\"mask2\"]\n    check_binary_op(\"mask1\", \"mask2\", \"mask3\", operator.or_)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support some of operations in DistGraph\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\", reason=\"Turn off Mxnet support\"\n)\ndef teardown_module():\n    # Since there are two tests in one process, this is needed to make sure\n    # the client exits properly.\n    dgl.distributed.exit_client()\n\n\nif __name__ == \"__main__\":\n    setup_module()\n    test_op()\n    teardown_module()\n"
  },
  {
    "path": "tests/distributed/test_distributed_sampling.py",
    "content": "import multiprocessing as mp\nimport os\nimport random\nimport tempfile\nimport time\nimport traceback\nimport unittest\nfrom pathlib import Path\n\nimport dgl\n\nimport dgl.backend as F\nimport numpy as np\nimport pytest\nimport torch\nfrom dgl.data import CitationGraphDataset, WN18Dataset\nfrom dgl.distributed import (\n    DistGraph,\n    DistGraphServer,\n    load_partition,\n    load_partition_book,\n    partition_graph,\n    sample_etype_neighbors,\n    sample_neighbors,\n)\n\nfrom dgl.distributed.graph_partition_book import _etype_tuple_to_str\n\nfrom scipy import sparse as spsp\nfrom utils import generate_ip_config, reset_envs\n\n\ndef start_server(\n    rank,\n    tmpdir,\n    disable_shared_mem,\n    graph_name,\n    graph_format=[\"csc\", \"coo\"],\n    use_graphbolt=False,\n):\n    g = DistGraphServer(\n        rank,\n        \"rpc_ip_config.txt\",\n        1,\n        1,\n        tmpdir / (graph_name + \".json\"),\n        disable_shared_mem=disable_shared_mem,\n        graph_format=graph_format,\n        use_graphbolt=use_graphbolt,\n    )\n    g.start()\n\n\ndef start_sample_client(rank, tmpdir, disable_shared_mem):\n    gpb = None\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(\n            tmpdir / \"test_sampling.json\", rank\n        )\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    dist_graph = DistGraph(\"test_sampling\", gpb=gpb)\n    try:\n        sampled_graph = sample_neighbors(\n            dist_graph,\n            torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype),\n            3,\n        )\n    except Exception as e:\n        print(traceback.format_exc())\n        sampled_graph = None\n    dgl.distributed.exit_client()\n    return sampled_graph\n\n\ndef start_sample_client_shuffle(\n    rank,\n    tmpdir,\n    disable_shared_mem,\n    g,\n    num_servers,\n    group_id,\n    orig_nid,\n    orig_eid,\n    use_graphbolt=False,\n    return_eids=False,\n    
node_id_dtype=None,\n    replace=False,\n):\n    os.environ[\"DGL_GROUP_ID\"] = str(group_id)\n    gpb = None\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(\n            tmpdir / \"test_sampling.json\", rank\n        )\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    dist_graph = DistGraph(\"test_sampling\", gpb=gpb)\n    sampled_graph = sample_neighbors(\n        dist_graph,\n        torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=node_id_dtype),\n        3,\n        replace=replace,\n        use_graphbolt=use_graphbolt,\n    )\n    assert sampled_graph.idtype == dist_graph.idtype\n    assert sampled_graph.idtype == torch.int64\n\n    assert (\n        dgl.ETYPE not in sampled_graph.edata\n    ), \"Etype should not be in homogeneous sampled graph.\"\n    src, dst = sampled_graph.edges()\n    sampled_in_degrees = sampled_graph.in_degrees(dst)\n    src = orig_nid[src]\n    dst = orig_nid[dst]\n    assert sampled_graph.num_nodes() == g.num_nodes()\n    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))\n    if use_graphbolt and not return_eids:\n        assert (\n            dgl.EID not in sampled_graph.edata\n        ), \"EID should not be in sampled graph if use_graphbolt=True.\"\n    else:\n        eids = g.edge_ids(src, dst)\n        eids1 = orig_eid[sampled_graph.edata[dgl.EID]]\n        assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))\n    # Verify replace argument.\n    orig_in_degrees = g.in_degrees(dst)\n    if replace:\n        assert torch.all(\n            (sampled_in_degrees == 3) | (sampled_in_degrees == orig_in_degrees)\n        )\n    else:\n        assert torch.all(sampled_in_degrees <= 3)\n\n\ndef start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):\n    gpb = None\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(\n            tmpdir / \"test_find_edges.json\", rank\n        )\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    dist_graph 
= DistGraph(\"test_find_edges\", gpb=gpb)\n    try:\n        u, v = dist_graph.find_edges(eids, etype=etype)\n    except Exception as e:\n        print(traceback.format_exc())\n        u, v = None, None\n    dgl.distributed.exit_client()\n    return u, v\n\n\ndef start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):\n    gpb = None\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(\n            tmpdir / \"test_get_degrees.json\", rank\n        )\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    dist_graph = DistGraph(\"test_get_degrees\", gpb=gpb)\n    try:\n        in_deg = dist_graph.in_degrees(nids)\n        all_in_deg = dist_graph.in_degrees()\n        out_deg = dist_graph.out_degrees(nids)\n        all_out_deg = dist_graph.out_degrees()\n    except Exception as e:\n        print(traceback.format_exc())\n        in_deg, out_deg, all_in_deg, all_out_deg = None, None, None, None\n    dgl.distributed.exit_client()\n    return in_deg, out_deg, all_in_deg, all_out_deg\n\n\ndef check_rpc_sampling(tmpdir, num_server):\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = CitationGraphDataset(\"cora\")[0]\n    print(g.idtype)\n    num_parts = num_server\n    num_hops = 1\n\n    partition_graph(\n        g,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(i, tmpdir, num_server > 1, \"test_sampling\"),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)\n    print(\"Done sampling\")\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    src, dst = sampled_graph.edges()\n    assert sampled_graph.num_nodes() == 
g.num_nodes()\n    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))\n    eids = g.edge_ids(src, dst)\n    assert np.array_equal(\n        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)\n    )\n\n\ndef check_rpc_find_edges_shuffle(tmpdir, num_server):\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = CitationGraphDataset(\"cora\")[0]\n    num_parts = num_server\n\n    orig_nid, orig_eid = partition_graph(\n        g,\n        \"test_find_edges\",\n        num_parts,\n        tmpdir,\n        num_hops=1,\n        part_method=\"metis\",\n        return_mapping=True,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(i, tmpdir, num_server > 1, \"test_find_edges\", [\"csr\", \"coo\"]),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    eids = F.tensor(np.random.randint(g.num_edges(), size=100))\n    u, v = g.find_edges(orig_eid[eids])\n    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)\n    du = orig_nid[du]\n    dv = orig_nid[dv]\n    assert F.array_equal(u, du)\n    assert F.array_equal(v, dv)\n\n\ndef create_random_hetero(dense=False, empty=False):\n    num_nodes = (\n        {\"n1\": 210, \"n2\": 200, \"n3\": 220}\n        if dense\n        else {\"n1\": 1010, \"n2\": 1000, \"n3\": 1020}\n    )\n    etypes = [(\"n1\", \"r12\", \"n2\"), (\"n1\", \"r13\", \"n3\"), (\"n2\", \"r23\", \"n3\")]\n    edges = {}\n    random.seed(42)\n    for etype in etypes:\n        src_ntype, _, dst_ntype = etype\n        arr = spsp.random(\n            num_nodes[src_ntype] - 10 if empty else num_nodes[src_ntype],\n            num_nodes[dst_ntype] - 10 if empty else num_nodes[dst_ntype],\n            density=0.1 if dense else 0.001,\n            format=\"coo\",\n            random_state=100,\n        )\n        edges[etype] = (arr.row, arr.col)\n    g = 
dgl.heterograph(edges, num_nodes)\n    g.nodes[\"n1\"].data[\"feat\"] = F.ones(\n        (g.num_nodes(\"n1\"), 10), F.float32, F.cpu()\n    )\n    return g\n\n\ndef check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = create_random_hetero()\n    num_parts = num_server\n\n    orig_nid, orig_eid = partition_graph(\n        g,\n        \"test_find_edges\",\n        num_parts,\n        tmpdir,\n        num_hops=1,\n        part_method=\"metis\",\n        return_mapping=True,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(i, tmpdir, num_server > 1, \"test_find_edges\", [\"csr\", \"coo\"]),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    test_etype = g.to_canonical_etype(\"r12\")\n    eids = F.tensor(np.random.randint(g.num_edges(test_etype), size=100))\n    expect_except = False\n    try:\n        _, _ = g.find_edges(orig_eid[test_etype][eids], etype=(\"n1\", \"r12\"))\n    except:\n        expect_except = True\n    assert expect_except\n    u, v = g.find_edges(orig_eid[test_etype][eids], etype=\"r12\")\n    u1, v1 = g.find_edges(orig_eid[test_etype][eids], etype=(\"n1\", \"r12\", \"n2\"))\n    assert F.array_equal(u, u1)\n    assert F.array_equal(v, v1)\n    du, dv = start_find_edges_client(\n        0, tmpdir, num_server > 1, eids, etype=\"r12\"\n    )\n    du = orig_nid[\"n1\"][du]\n    dv = orig_nid[\"n2\"][dv]\n    assert F.array_equal(u, du)\n    assert F.array_equal(v, dv)\n\n\n# Wait non shared memory graph store\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"Not support tensorflow for now\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\", reason=\"Turn off Mxnet 
support\"\n)\n@pytest.mark.parametrize(\"num_server\", [1])\ndef test_rpc_find_edges_shuffle(num_server):\n    reset_envs()\n    import tempfile\n\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), num_server)\n        check_rpc_find_edges_shuffle(Path(tmpdirname), num_server)\n\n\ndef check_rpc_get_degree_shuffle(tmpdir, num_server):\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = CitationGraphDataset(\"cora\")[0]\n    num_parts = num_server\n\n    orig_nid, _ = partition_graph(\n        g,\n        \"test_get_degrees\",\n        num_parts,\n        tmpdir,\n        num_hops=1,\n        part_method=\"metis\",\n        return_mapping=True,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(i, tmpdir, num_server > 1, \"test_get_degrees\"),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    nids = F.tensor(np.random.randint(g.num_nodes(), size=100))\n    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(\n        0, tmpdir, num_server > 1, nids\n    )\n\n    print(\"Done get_degree\")\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    print(\"check results\")\n    assert F.array_equal(g.in_degrees(orig_nid[nids]), in_degs)\n    assert F.array_equal(g.in_degrees(orig_nid), all_in_degs)\n    assert F.array_equal(g.out_degrees(orig_nid[nids]), out_degs)\n    assert F.array_equal(g.out_degrees(orig_nid), all_out_degs)\n\n\n# Wait non shared memory graph store\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"Not support tensorflow for now\",\n)\n@unittest.skipIf(\n    
dgl.backend.backend_name == \"mxnet\", reason=\"Turn off Mxnet support\"\n)\n@pytest.mark.parametrize(\"num_server\", [1])\ndef test_rpc_get_degree_shuffle(num_server):\n    reset_envs()\n    import tempfile\n\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_get_degree_shuffle(Path(tmpdirname), num_server)\n\n\n# @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')\n# @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')\n@unittest.skip(\"Only support partition with shuffle\")\ndef test_rpc_sampling():\n    reset_envs()\n    import tempfile\n\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_sampling(Path(tmpdirname), 1)\n\n\ndef check_rpc_sampling_shuffle(\n    tmpdir,\n    num_server,\n    num_groups=1,\n    use_graphbolt=False,\n    return_eids=False,\n    node_id_dtype=None,\n    replace=False,\n):\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = CitationGraphDataset(\"cora\")[0]\n    num_parts = num_server\n    num_hops = 1\n\n    orig_nids, orig_eids = partition_graph(\n        g,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n        return_mapping=True,\n        use_graphbolt=use_graphbolt,\n        store_eids=return_eids,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(\n                i,\n                tmpdir,\n                num_server > 1,\n                \"test_sampling\",\n                [\"csc\", \"coo\"],\n                use_graphbolt,\n            ),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    pclient_list = []\n    num_clients = 1\n    
for client_id in range(num_clients):\n        for group_id in range(num_groups):\n            p = ctx.Process(\n                target=start_sample_client_shuffle,\n                args=(\n                    client_id,\n                    tmpdir,\n                    num_server > 1,\n                    g,\n                    num_server,\n                    group_id,\n                    orig_nids,\n                    orig_eids,\n                    use_graphbolt,\n                    return_eids,\n                    node_id_dtype,\n                    replace,\n                ),\n            )\n            p.start()\n            time.sleep(1)  # avoid race condition when instantiating DistGraph\n            pclient_list.append(p)\n    for p in pclient_list:\n        p.join()\n        assert p.exitcode == 0\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n\ndef start_hetero_sample_client(\n    rank,\n    tmpdir,\n    disable_shared_mem,\n    nodes,\n    use_graphbolt=False,\n    return_eids=False,\n    replace=False,\n):\n    gpb = None\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(\n            tmpdir / \"test_sampling.json\", rank\n        )\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    dist_graph = DistGraph(\"test_sampling\", gpb=gpb)\n    assert \"feat\" in dist_graph.nodes[\"n1\"].data\n    assert \"feat\" not in dist_graph.nodes[\"n2\"].data\n    assert \"feat\" not in dist_graph.nodes[\"n3\"].data\n    nodes = {\n        k: v.type(dist_graph.idtype).clone().detach() for k, v in nodes.items()\n    }\n    if gpb is None:\n        gpb = dist_graph.get_partition_book()\n    try:\n        # Enable sanity check in distributed sampling.\n        os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n        sampled_graph = sample_neighbors(\n            dist_graph, nodes, 3, replace=replace, use_graphbolt=use_graphbolt\n        )\n        block = dgl.to_block(sampled_graph, nodes)\n        if not 
use_graphbolt or return_eids:\n            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]\n    except Exception as e:\n        print(traceback.format_exc())\n        block = None\n    dgl.distributed.exit_client()\n    return block, gpb\n\n\ndef start_hetero_etype_sample_client(\n    rank,\n    tmpdir,\n    disable_shared_mem,\n    fanout=3,\n    nodes=None,\n    etype_sorted=False,\n    use_graphbolt=False,\n    return_eids=False,\n):\n    gpb = None\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(\n            tmpdir / \"test_sampling.json\", rank\n        )\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    dist_graph = DistGraph(\"test_sampling\", gpb=gpb)\n    assert \"feat\" in dist_graph.nodes[\"n1\"].data\n    assert \"feat\" not in dist_graph.nodes[\"n2\"].data\n    assert \"feat\" not in dist_graph.nodes[\"n3\"].data\n    nodes = {\n        k: v.type(dist_graph.idtype).clone().detach() for k, v in nodes.items()\n    }\n\n    if (not use_graphbolt) and dist_graph.local_partition is not None:\n        # Check whether etypes are sorted in dist_graph\n        local_g = dist_graph.local_partition\n        local_nids = np.arange(local_g.num_nodes())\n        for lnid in local_nids:\n            leids = local_g.in_edges(lnid, form=\"eid\")\n            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])\n            _, idices = np.unique(letids, return_index=True)\n            assert np.all(idices[:-1] <= idices[1:])\n\n    if gpb is None:\n        gpb = dist_graph.get_partition_book()\n    try:\n        # Enable sanity check in distributed sampling.\n        os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n        sampled_graph = sample_etype_neighbors(\n            dist_graph,\n            nodes,\n            fanout,\n            etype_sorted=etype_sorted,\n            use_graphbolt=use_graphbolt,\n        )\n        block = dgl.to_block(sampled_graph, nodes)\n        if sampled_graph.num_edges() > 0:\n            if not 
use_graphbolt or return_eids:\n                block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]\n    except Exception as e:\n        print(traceback.format_exc())\n        block = None\n    dgl.distributed.exit_client()\n    return block, gpb\n\n\ndef check_rpc_hetero_sampling_shuffle(\n    tmpdir, num_server, use_graphbolt=False, return_eids=False, replace=False\n):\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = create_random_hetero()\n    num_parts = num_server\n    num_hops = 1\n\n    orig_nid_map, orig_eid_map = partition_graph(\n        g,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n        return_mapping=True,\n        use_graphbolt=use_graphbolt,\n        store_eids=return_eids,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(\n                i,\n                tmpdir,\n                num_server > 1,\n                \"test_sampling\",\n                [\"csc\", \"coo\"],\n                use_graphbolt,\n            ),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    nodes = {\"n3\": torch.tensor([0, 10, 99, 66, 124, 208], dtype=g.idtype)}\n    block, gpb = start_hetero_sample_client(\n        0,\n        tmpdir,\n        num_server > 1,\n        nodes=nodes,\n        use_graphbolt=use_graphbolt,\n        return_eids=return_eids,\n        replace=replace,\n    )\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    for c_etype in block.canonical_etypes:\n        src_type, etype, dst_type = c_etype\n        src, dst = block.edges(etype=etype)\n        # These are global Ids after shuffling.\n        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)\n        shuffled_dst = 
F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)\n        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))\n        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))\n\n        assert np.all(\n            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))\n        )\n\n        if use_graphbolt and not return_eids:\n            continue\n\n        shuffled_eid = block.edges[etype].data[dgl.EID]\n        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))\n\n        # Check the node Ids and edge Ids.\n        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)\n        assert np.all(F.asnumpy(orig_src1) == orig_src)\n        assert np.all(F.asnumpy(orig_dst1) == orig_dst)\n\n\ndef get_degrees(g, nids, ntype):\n    deg = F.zeros((len(nids),), dtype=F.int64)\n    for srctype, etype, dsttype in g.canonical_etypes:\n        if srctype == ntype:\n            deg += g.out_degrees(u=nids, etype=etype)\n        elif dsttype == ntype:\n            deg += g.in_degrees(v=nids, etype=etype)\n    return deg\n\n\ndef check_rpc_hetero_sampling_empty_shuffle(\n    tmpdir, num_server, use_graphbolt=False, return_eids=False\n):\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = create_random_hetero(empty=True)\n    num_parts = num_server\n    num_hops = 1\n\n    orig_nids, _ = partition_graph(\n        g,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n        return_mapping=True,\n        use_graphbolt=use_graphbolt,\n        store_eids=return_eids,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(\n                i,\n                tmpdir,\n                num_server > 1,\n                \"test_sampling\",\n                [\"csc\", \"coo\"],\n      
          use_graphbolt,\n            ),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    deg = get_degrees(g, orig_nids[\"n3\"], \"n3\")\n    empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)\n    block, gpb = start_hetero_sample_client(\n        0,\n        tmpdir,\n        num_server > 1,\n        nodes={\"n3\": empty_nids},\n        use_graphbolt=use_graphbolt,\n        return_eids=return_eids,\n    )\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    assert block.num_edges() == 0\n    assert len(block.etypes) == len(g.etypes)\n\n\ndef check_rpc_hetero_etype_sampling_shuffle(\n    tmpdir,\n    num_server,\n    graph_formats=None,\n    use_graphbolt=False,\n    return_eids=False,\n):\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = create_random_hetero(dense=True)\n    num_parts = num_server\n    num_hops = 1\n\n    orig_nid_map, orig_eid_map = partition_graph(\n        g,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n        return_mapping=True,\n        graph_formats=graph_formats,\n        use_graphbolt=use_graphbolt,\n        store_eids=return_eids,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(\n                i,\n                tmpdir,\n                num_server > 1,\n                \"test_sampling\",\n                [\"csc\", \"coo\"],\n                use_graphbolt,\n            ),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    fanout = {etype: 3 for etype in g.canonical_etypes}\n    etype_sorted = False\n    if graph_formats is not None:\n        etype_sorted = \"csc\" in graph_formats or \"csr\" in graph_formats\n    nodes = {\"n3\": torch.tensor([0, 10, 99, 66, 124, 208], 
dtype=g.idtype)}\n    block, gpb = start_hetero_etype_sample_client(\n        0,\n        tmpdir,\n        num_server > 1,\n        fanout,\n        nodes=nodes,\n        etype_sorted=etype_sorted,\n        use_graphbolt=use_graphbolt,\n        return_eids=return_eids,\n    )\n    print(\"Done sampling\")\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    src, dst = block.edges(etype=(\"n1\", \"r13\", \"n3\"))\n    assert len(src) == 18\n    src, dst = block.edges(etype=(\"n2\", \"r23\", \"n3\"))\n    assert len(src) == 18\n\n    for c_etype in block.canonical_etypes:\n        src_type, etype, dst_type = c_etype\n        src, dst = block.edges(etype=etype)\n        # These are global Ids after shuffling.\n        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)\n        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)\n        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))\n        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))\n        assert np.all(\n            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))\n        )\n\n        if use_graphbolt and not return_eids:\n            continue\n\n        # Check the node Ids and edge Ids.\n        shuffled_eid = block.edges[etype].data[dgl.EID]\n        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))\n        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)\n        assert np.all(F.asnumpy(orig_src1) == orig_src)\n        assert np.all(F.asnumpy(orig_dst1) == orig_dst)\n\n\ndef check_rpc_hetero_etype_sampling_empty_shuffle(\n    tmpdir, num_server, use_graphbolt=False, return_eids=False\n):\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = create_random_hetero(dense=True, empty=True)\n    num_parts = num_server\n    num_hops = 1\n\n    orig_nids, _ = partition_graph(\n        g,\n        
\"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n        return_mapping=True,\n        use_graphbolt=use_graphbolt,\n        store_eids=return_eids,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(\n                i,\n                tmpdir,\n                num_server > 1,\n                \"test_sampling\",\n                [\"csc\", \"coo\"],\n                use_graphbolt,\n            ),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    fanout = 3\n    deg = get_degrees(g, orig_nids[\"n3\"], \"n3\")\n    empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)\n    block, gpb = start_hetero_etype_sample_client(\n        0,\n        tmpdir,\n        num_server > 1,\n        fanout,\n        nodes={\"n3\": empty_nids},\n        use_graphbolt=use_graphbolt,\n        return_eids=return_eids,\n    )\n    print(\"Done sampling\")\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    assert block.num_edges() == 0\n    assert len(block.etypes) == len(g.etypes)\n\n\ndef create_random_bipartite():\n    g = dgl.rand_bipartite(\"user\", \"buys\", \"game\", 500, 1000, 1000)\n    g.nodes[\"user\"].data[\"feat\"] = F.ones(\n        (g.num_nodes(\"user\"), 10), F.float32, F.cpu()\n    )\n    g.nodes[\"game\"].data[\"feat\"] = F.ones(\n        (g.num_nodes(\"game\"), 10), F.float32, F.cpu()\n    )\n    return g\n\n\ndef start_bipartite_sample_client(\n    rank,\n    tmpdir,\n    disable_shared_mem,\n    nodes,\n    use_graphbolt=False,\n    return_eids=False,\n):\n    gpb = None\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(\n            tmpdir / \"test_sampling.json\", rank\n        )\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    dist_graph = DistGraph(\"test_sampling\", 
gpb=gpb)\n    assert \"feat\" in dist_graph.nodes[\"user\"].data\n    assert \"feat\" in dist_graph.nodes[\"game\"].data\n    nodes = {\n        k: v.type(dist_graph.idtype).clone().detach() for k, v in nodes.items()\n    }\n    if gpb is None:\n        gpb = dist_graph.get_partition_book()\n    # Enable santity check in distributed sampling.\n    os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n    sampled_graph = sample_neighbors(\n        dist_graph, nodes, 3, use_graphbolt=use_graphbolt\n    )\n    block = dgl.to_block(sampled_graph, nodes)\n    if sampled_graph.num_edges() > 0:\n        if not use_graphbolt or return_eids:\n            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]\n    dgl.distributed.exit_client()\n    return block, gpb\n\n\ndef start_bipartite_etype_sample_client(\n    rank,\n    tmpdir,\n    disable_shared_mem,\n    fanout=3,\n    nodes={},\n    use_graphbolt=False,\n    return_eids=False,\n):\n    gpb = None\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(\n            tmpdir / \"test_sampling.json\", rank\n        )\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    dist_graph = DistGraph(\"test_sampling\", gpb=gpb)\n    assert \"feat\" in dist_graph.nodes[\"user\"].data\n    assert \"feat\" in dist_graph.nodes[\"game\"].data\n    nodes = {\n        k: v.type(dist_graph.idtype).clone().detach() for k, v in nodes.items()\n    }\n\n    if not use_graphbolt and dist_graph.local_partition is not None:\n        # Check whether etypes are sorted in dist_graph\n        local_g = dist_graph.local_partition\n        local_nids = np.arange(local_g.num_nodes())\n        for lnid in local_nids:\n            leids = local_g.in_edges(lnid, form=\"eid\")\n            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])\n            _, idices = np.unique(letids, return_index=True)\n            assert np.all(idices[:-1] <= idices[1:])\n\n    if gpb is None:\n        gpb = dist_graph.get_partition_book()\n    sampled_graph 
= sample_etype_neighbors(\n        dist_graph, nodes, fanout, use_graphbolt=use_graphbolt\n    )\n    block = dgl.to_block(sampled_graph, nodes)\n    if sampled_graph.num_edges() > 0:\n        if not use_graphbolt or return_eids:\n            block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]\n    dgl.distributed.exit_client()\n    return block, gpb\n\n\ndef check_rpc_bipartite_sampling_empty(\n    tmpdir, num_server, use_graphbolt=False, return_eids=False\n):\n    \"\"\"sample on bipartite via sample_neighbors() which yields empty sample results\"\"\"\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = create_random_bipartite()\n    num_parts = num_server\n    num_hops = 1\n\n    orig_nids, _ = partition_graph(\n        g,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n        return_mapping=True,\n        use_graphbolt=use_graphbolt,\n        store_eids=return_eids,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(\n                i,\n                tmpdir,\n                num_server > 1,\n                \"test_sampling\",\n                [\"csc\", \"coo\"],\n                use_graphbolt,\n            ),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    deg = get_degrees(g, orig_nids[\"game\"], \"game\")\n    empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)\n    nodes = {\"game\": empty_nids, \"user\": torch.tensor([1], dtype=g.idtype)}\n    block, _ = start_bipartite_sample_client(\n        0,\n        tmpdir,\n        num_server > 1,\n        nodes=nodes,\n        use_graphbolt=use_graphbolt,\n        return_eids=return_eids,\n    )\n\n    print(\"Done sampling\")\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    assert block.num_edges() == 
0\n    assert len(block.etypes) == len(g.etypes)\n\n\ndef check_rpc_bipartite_sampling_shuffle(\n    tmpdir, num_server, use_graphbolt=False, return_eids=False\n):\n    \"\"\"sample on bipartite via sample_neighbors() which yields non-empty sample results\"\"\"\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = create_random_bipartite()\n    num_parts = num_server\n    num_hops = 1\n\n    orig_nid_map, orig_eid_map = partition_graph(\n        g,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n        return_mapping=True,\n        use_graphbolt=use_graphbolt,\n        store_eids=return_eids,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(\n                i,\n                tmpdir,\n                num_server > 1,\n                \"test_sampling\",\n                [\"csc\", \"coo\"],\n                use_graphbolt,\n            ),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    deg = get_degrees(g, orig_nid_map[\"game\"], \"game\")\n    nids = F.nonzero_1d(deg > 0)\n    nodes = {\"game\": nids, \"user\": torch.tensor([0], dtype=g.idtype)}\n    block, gpb = start_bipartite_sample_client(\n        0,\n        tmpdir,\n        num_server > 1,\n        nodes=nodes,\n        use_graphbolt=use_graphbolt,\n        return_eids=return_eids,\n    )\n    print(\"Done sampling\")\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    for c_etype in block.canonical_etypes:\n        src_type, etype, dst_type = c_etype\n        src, dst = block.edges(etype=etype)\n        # These are global Ids after shuffling.\n        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)\n        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], 
dst)\n        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))\n        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))\n        assert np.all(\n            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))\n        )\n\n        if use_graphbolt and not return_eids:\n            continue\n\n        shuffled_eid = block.edges[etype].data[dgl.EID]\n        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))\n\n        # Check the node Ids and edge Ids.\n        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)\n        assert np.all(F.asnumpy(orig_src1) == orig_src)\n        assert np.all(F.asnumpy(orig_dst1) == orig_dst)\n\n\ndef check_rpc_bipartite_etype_sampling_empty(\n    tmpdir, num_server, use_graphbolt=False, return_eids=False\n):\n    \"\"\"sample on bipartite via sample_etype_neighbors() which yields empty sample results\"\"\"\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = create_random_bipartite()\n    num_parts = num_server\n    num_hops = 1\n\n    orig_nids, _ = partition_graph(\n        g,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n        return_mapping=True,\n        use_graphbolt=use_graphbolt,\n        store_eids=return_eids,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(\n                i,\n                tmpdir,\n                num_server > 1,\n                \"test_sampling\",\n                [\"csc\", \"coo\"],\n                use_graphbolt,\n            ),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    deg = get_degrees(g, orig_nids[\"game\"], \"game\")\n    empty_nids = F.nonzero_1d(deg == 0).to(g.idtype)\n    nodes = {\"game\": empty_nids, \"user\": 
torch.tensor([1], dtype=g.idtype)}\n    block, _ = start_bipartite_etype_sample_client(\n        0,\n        tmpdir,\n        num_server > 1,\n        nodes=nodes,\n        use_graphbolt=use_graphbolt,\n        return_eids=return_eids,\n    )\n\n    print(\"Done sampling\")\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    assert block is not None\n    assert block.num_edges() == 0\n    assert len(block.etypes) == len(g.etypes)\n\n\ndef check_rpc_bipartite_etype_sampling_shuffle(\n    tmpdir, num_server, use_graphbolt=False, return_eids=False\n):\n    \"\"\"sample on bipartite via sample_etype_neighbors() which yields non-empty sample results\"\"\"\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = create_random_bipartite()\n    num_parts = num_server\n    num_hops = 1\n\n    orig_nid_map, orig_eid_map = partition_graph(\n        g,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n        return_mapping=True,\n        use_graphbolt=use_graphbolt,\n        store_eids=return_eids,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(\n                i,\n                tmpdir,\n                num_server > 1,\n                \"test_sampling\",\n                [\"csc\", \"coo\"],\n                use_graphbolt,\n            ),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    fanout = 3\n    deg = get_degrees(g, orig_nid_map[\"game\"], \"game\")\n    nids = F.nonzero_1d(deg > 0)\n    nodes = {\"game\": nids, \"user\": torch.tensor([0], dtype=g.idtype)}\n    block, gpb = start_bipartite_etype_sample_client(\n        0,\n        tmpdir,\n        num_server > 1,\n        fanout,\n        nodes=nodes,\n        use_graphbolt=use_graphbolt,\n        
return_eids=return_eids,\n    )\n    print(\"Done sampling\")\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    for c_etype in block.canonical_etypes:\n        src_type, etype, dst_type = c_etype\n        src, dst = block.edges(etype=etype)\n        # These are global Ids after shuffling.\n        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)\n        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)\n        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))\n        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))\n        assert np.all(\n            F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))\n        )\n\n        if use_graphbolt and not return_eids:\n            continue\n\n        # Check the node Ids and edge Ids.\n        shuffled_eid = block.edges[etype].data[dgl.EID]\n        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))\n        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)\n        assert np.all(F.asnumpy(orig_src1) == orig_src)\n        assert np.all(F.asnumpy(orig_dst1) == orig_dst)\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\n@pytest.mark.parametrize(\"node_id_dtype\", [torch.int64])\n@pytest.mark.parametrize(\"replace\", [False, True])\ndef test_rpc_sampling_shuffle(\n    num_server, use_graphbolt, return_eids, node_id_dtype, replace\n):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_sampling_shuffle(\n            Path(tmpdirname),\n            num_server,\n            use_graphbolt=use_graphbolt,\n            return_eids=return_eids,\n            node_id_dtype=node_id_dtype,\n            replace=replace,\n        
)\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"use_graphbolt,\", [False, True])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\n@pytest.mark.parametrize(\"replace\", [False, True])\ndef test_rpc_hetero_sampling_shuffle(\n    num_server, use_graphbolt, return_eids, replace\n):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_hetero_sampling_shuffle(\n            Path(tmpdirname),\n            num_server,\n            use_graphbolt=use_graphbolt,\n            return_eids=return_eids,\n            replace=replace,\n        )\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\ndef test_rpc_hetero_sampling_empty_shuffle(\n    num_server, use_graphbolt, return_eids\n):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_hetero_sampling_empty_shuffle(\n            Path(tmpdirname),\n            num_server,\n            use_graphbolt=use_graphbolt,\n            return_eids=return_eids,\n        )\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\n    \"graph_formats\", [None, [\"csc\"], [\"csr\"], [\"csc\", \"coo\"]]\n)\ndef test_rpc_hetero_etype_sampling_shuffle_dgl(num_server, graph_formats):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_hetero_etype_sampling_shuffle(\n            Path(tmpdirname), num_server, graph_formats=graph_formats\n        )\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\ndef test_rpc_hetero_etype_sampling_shuffle_graphbolt(num_server, return_eids):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = 
\"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_hetero_etype_sampling_shuffle(\n            Path(tmpdirname),\n            num_server,\n            use_graphbolt=True,\n            return_eids=return_eids,\n        )\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\ndef test_rpc_hetero_etype_sampling_empty_shuffle(\n    num_server, use_graphbolt, return_eids\n):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_hetero_etype_sampling_empty_shuffle(\n            Path(tmpdirname),\n            num_server,\n            use_graphbolt=use_graphbolt,\n            return_eids=return_eids,\n        )\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\ndef test_rpc_bipartite_sampling_empty_shuffle(\n    num_server, use_graphbolt, return_eids\n):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_bipartite_sampling_empty(\n            Path(tmpdirname), num_server, use_graphbolt, return_eids\n        )\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\ndef test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_bipartite_sampling_shuffle(\n            Path(tmpdirname), num_server, use_graphbolt, return_eids\n        )\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, 
True])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\ndef test_rpc_bipartite_etype_sampling_empty_shuffle(\n    num_server, use_graphbolt, return_eids\n):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_bipartite_etype_sampling_empty(\n            Path(tmpdirname),\n            num_server,\n            use_graphbolt=use_graphbolt,\n            return_eids=return_eids,\n        )\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\ndef test_rpc_bipartite_etype_sampling_shuffle(\n    num_server, use_graphbolt, return_eids\n):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_bipartite_etype_sampling_shuffle(\n            Path(tmpdirname),\n            num_server,\n            use_graphbolt=use_graphbolt,\n            return_eids=return_eids,\n        )\n\n\ndef check_standalone_sampling(tmpdir):\n    g = CitationGraphDataset(\"cora\")[0]\n    prob = np.maximum(np.random.randn(g.num_edges()), 0)\n    mask = prob > 0\n    g.edata[\"prob\"] = F.tensor(prob)\n    g.edata[\"mask\"] = F.tensor(mask)\n    num_parts = 1\n    num_hops = 1\n    partition_graph(\n        g,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n    )\n\n    os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    dist_graph = DistGraph(\n        \"test_sampling\", part_config=tmpdir / \"test_sampling.json\"\n    )\n    sampled_graph = sample_neighbors(\n        dist_graph,\n        torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype),\n        3,\n    )\n\n    src, dst = sampled_graph.edges()\n    assert sampled_graph.num_nodes() == 
g.num_nodes()\n    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))\n    eids = g.edge_ids(src, dst)\n    assert np.array_equal(\n        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)\n    )\n\n    sampled_graph = sample_neighbors(\n        dist_graph,\n        torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype),\n        3,\n        prob=\"mask\",\n    )\n    eid = F.asnumpy(sampled_graph.edata[dgl.EID])\n    assert mask[eid].all()\n\n    sampled_graph = sample_neighbors(\n        dist_graph,\n        torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype),\n        3,\n        prob=\"prob\",\n    )\n    eid = F.asnumpy(sampled_graph.edata[dgl.EID])\n    assert (prob[eid] > 0).all()\n    dgl.distributed.exit_client()\n\n\ndef check_standalone_etype_sampling(tmpdir):\n    hg = CitationGraphDataset(\"cora\")[0]\n    prob = np.maximum(np.random.randn(hg.num_edges()), 0)\n    mask = prob > 0\n    hg.edata[\"prob\"] = F.tensor(prob)\n    hg.edata[\"mask\"] = F.tensor(mask)\n    num_parts = 1\n    num_hops = 1\n\n    partition_graph(\n        hg,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n    )\n    os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    dist_graph = DistGraph(\n        \"test_sampling\", part_config=tmpdir / \"test_sampling.json\"\n    )\n    sampled_graph = sample_etype_neighbors(\n        dist_graph,\n        torch.tensor([0, 10, 99, 66, 1023], dtype=dist_graph.idtype),\n        3,\n    )\n\n    src, dst = sampled_graph.edges()\n    assert sampled_graph.num_nodes() == hg.num_nodes()\n    assert np.all(F.asnumpy(hg.has_edges_between(src, dst)))\n    eids = hg.edge_ids(src, dst)\n    assert np.array_equal(\n        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)\n    )\n\n    sampled_graph = sample_etype_neighbors(\n        dist_graph,\n        torch.tensor([0, 10, 99, 
66, 1023], dtype=dist_graph.idtype),\n        3,\n        prob=\"mask\",\n    )\n    eid = F.asnumpy(sampled_graph.edata[dgl.EID])\n    assert mask[eid].all()\n\n    sampled_graph = sample_etype_neighbors(\n        dist_graph,\n        torch.tensor([0, 10, 99, 66, 1023], dtype=dist_graph.idtype),\n        3,\n        prob=\"prob\",\n    )\n    eid = F.asnumpy(sampled_graph.edata[dgl.EID])\n    assert (prob[eid] > 0).all()\n    dgl.distributed.exit_client()\n\n\ndef check_standalone_etype_sampling_heterograph(tmpdir):\n    hg = CitationGraphDataset(\"cora\")[0]\n    num_parts = 1\n    num_hops = 1\n    src, dst = hg.edges()\n    new_hg = dgl.heterograph(\n        {\n            (\"paper\", \"cite\", \"paper\"): (src, dst),\n            (\"paper\", \"cite-by\", \"paper\"): (dst, src),\n        },\n        {\"paper\": hg.num_nodes()},\n    )\n    partition_graph(\n        new_hg,\n        \"test_hetero_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=num_hops,\n        part_method=\"metis\",\n    )\n    os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    dist_graph = DistGraph(\n        \"test_hetero_sampling\", part_config=tmpdir / \"test_hetero_sampling.json\"\n    )\n    sampled_graph = sample_etype_neighbors(\n        dist_graph,\n        torch.tensor(\n            [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701],\n            dtype=dist_graph.idtype,\n        ),\n        1,\n    )\n    src, dst = sampled_graph.edges(etype=(\"paper\", \"cite\", \"paper\"))\n    assert len(src) == 10\n    src, dst = sampled_graph.edges(etype=(\"paper\", \"cite-by\", \"paper\"))\n    assert len(src) == 10\n    assert sampled_graph.num_nodes() == new_hg.num_nodes()\n    dgl.distributed.exit_client()\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"Not support tensorflow for now\",\n)\ndef 
test_standalone_sampling():\n    reset_envs()\n    import tempfile\n\n    os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_standalone_sampling(Path(tmpdirname))\n\n\ndef start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):\n    gpb = None\n    dgl.distributed.initialize(\"rpc_ip_config.txt\")\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(\n            tmpdir / \"test_in_subgraph.json\", rank\n        )\n    dist_graph = DistGraph(\"test_in_subgraph\", gpb=gpb)\n    try:\n        sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)\n    except Exception as e:\n        print(traceback.format_exc())\n        sampled_graph = None\n    dgl.distributed.exit_client()\n    return sampled_graph\n\n\ndef check_rpc_in_subgraph_shuffle(tmpdir, num_server):\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = CitationGraphDataset(\"cora\")[0]\n    num_parts = num_server\n\n    orig_nid, orig_eid = partition_graph(\n        g,\n        \"test_in_subgraph\",\n        num_parts,\n        tmpdir,\n        num_hops=1,\n        part_method=\"metis\",\n        return_mapping=True,\n    )\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(i, tmpdir, num_server > 1, \"test_in_subgraph\"),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    nodes = torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=g.idtype)\n    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    src, dst = sampled_graph.edges()\n    src = orig_nid[src]\n    dst = orig_nid[dst]\n    assert sampled_graph.num_nodes() == g.num_nodes()\n    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))\n\n    subg1 = 
dgl.in_subgraph(g, orig_nid[nodes])\n    src1, dst1 = subg1.edges()\n    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))\n    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))\n    eids = g.edge_ids(src, dst)\n    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]\n    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"Not support tensorflow for now\",\n)\ndef test_rpc_in_subgraph():\n    reset_envs()\n    import tempfile\n\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 1)\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"Not support tensorflow for now\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\", reason=\"Turn off Mxnet support\"\n)\ndef test_standalone_etype_sampling():\n    reset_envs()\n    import tempfile\n\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n        check_standalone_etype_sampling_heterograph(Path(tmpdirname))\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n        check_standalone_etype_sampling(Path(tmpdirname))\n\n\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\n@pytest.mark.parametrize(\"use_graphbolt\", [False])\n@pytest.mark.parametrize(\"prob_or_mask\", [\"prob\", \"mask\"])\ndef test_local_sampling_homograph(num_parts, use_graphbolt, prob_or_mask):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        g = CitationGraphDataset(\"cora\")[0]\n        prob = torch.rand(g.num_edges())\n        
mask = prob > 0.2\n        prob[torch.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0\n        g.edata[\"prob\"] = prob\n        g.edata[\"mask\"] = mask\n        graph_name = \"test_local_sampling\"\n\n        _, orig_eids = partition_graph(\n            g,\n            graph_name,\n            num_parts,\n            test_dir,\n            num_hops=1,\n            part_method=\"metis\",\n            return_mapping=True,\n            use_graphbolt=use_graphbolt,\n            store_eids=True,\n            store_inner_node=True,\n            store_inner_edge=True,\n        )\n\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n        for part_id in range(num_parts):\n            local_g, _, edge_feats, gpb, _, _, _ = load_partition(\n                part_config,\n                part_id,\n                load_feats=True,\n                use_graphbolt=use_graphbolt,\n            )\n            inner_global_nids = gpb.partid2nids(part_id)\n            inner_global_eids = gpb.partid2eids(part_id)\n            inner_node_data = (\n                local_g.node_attributes[\"inner_node\"]\n                if use_graphbolt\n                else local_g.ndata[\"inner_node\"]\n            )\n            inner_edge_data = (\n                local_g.edge_attributes[\"inner_edge\"]\n                if use_graphbolt\n                else local_g.edata[\"inner_edge\"]\n            )\n            assert len(inner_global_nids) == inner_node_data.sum()\n            assert len(inner_global_eids) == inner_edge_data.sum()\n\n            c_etype = gpb.canonical_etypes[0]\n            _prob = []\n            prob = edge_feats[_etype_tuple_to_str(c_etype) + \"/\" + prob_or_mask]\n            assert len(prob) == len(inner_global_eids)\n            assert len(prob) <= inner_edge_data.shape[0]\n            _prob.append(prob)\n\n            sampled_g = dgl.distributed.graph_services._sample_neighbors(\n                use_graphbolt,\n                local_g,\n          
      gpb,\n                inner_global_nids,\n                5,\n                prob=_prob,\n            )\n            sampled_homo_eids = sampled_g.global_eids\n            sampled_orig_eids = orig_eids[sampled_homo_eids]\n            assert torch.all(g.edata[prob_or_mask][sampled_orig_eids] > 0)\n\n\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\n@pytest.mark.parametrize(\"use_graphbolt\", [False])\n@pytest.mark.parametrize(\"prob_or_mask\", [\"prob\", \"mask\"])\ndef test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        g = create_random_hetero()\n        for c_etype in g.canonical_etypes:\n            prob = torch.rand(g.num_edges(c_etype))\n            mask = prob > 0.2\n            prob[torch.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0\n            g.edges[c_etype].data[\"prob\"] = prob\n            g.edges[c_etype].data[\"mask\"] = mask\n        graph_name = \"test_local_sampling\"\n\n        _, orig_eids = partition_graph(\n            g,\n            graph_name,\n            num_parts,\n            test_dir,\n            num_hops=1,\n            part_method=\"metis\",\n            return_mapping=True,\n            use_graphbolt=use_graphbolt,\n            store_eids=True,\n            store_inner_node=True,\n            store_inner_edge=True,\n        )\n\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n        for part_id in range(num_parts):\n            local_g, _, edge_feats, gpb, _, _, _ = load_partition(\n                part_config,\n                part_id,\n                load_feats=True,\n                use_graphbolt=use_graphbolt,\n            )\n            inner_global_nids = [\n                gpb.map_to_homo_nid(gpb.partid2nids(part_id, ntype), ntype)\n                for ntype in gpb.ntypes\n            ]\n            inner_global_nids = 
torch.cat(inner_global_nids)\n            inner_global_eids = {\n                c_etype: gpb.partid2eids(part_id, c_etype)\n                for c_etype in gpb.canonical_etypes\n            }\n            inner_node_data = (\n                local_g.node_attributes[\"inner_node\"]\n                if use_graphbolt\n                else local_g.ndata[\"inner_node\"]\n            )\n            inner_edge_data = (\n                local_g.edge_attributes[\"inner_edge\"]\n                if use_graphbolt\n                else local_g.edata[\"inner_edge\"]\n            )\n            assert len(inner_global_nids) == inner_node_data.sum()\n            num_inner_global_eids = sum(\n                [len(eids) for eids in inner_global_eids.values()]\n            )\n            assert num_inner_global_eids == inner_edge_data.sum()\n\n            _prob = []\n            for i, c_etype in enumerate(gpb.canonical_etypes):\n                prob = edge_feats[\n                    _etype_tuple_to_str(c_etype) + \"/\" + prob_or_mask\n                ]\n                assert len(prob) == len(inner_global_eids[c_etype])\n                assert (\n                    len(prob)\n                    == gpb.local_etype_offset[i + 1] - gpb.local_etype_offset[i]\n                )\n                assert len(prob) <= inner_edge_data.shape[0]\n                _prob.append(prob)\n\n            sampled_g = dgl.distributed.graph_services._sample_etype_neighbors(\n                use_graphbolt,\n                local_g,\n                gpb,\n                inner_global_nids,\n                torch.full((len(g.canonical_etypes),), 5, dtype=torch.int64),\n                prob=_prob,\n                etype_offset=gpb.local_etype_offset,\n            )\n            sampled_homo_eids = sampled_g.global_eids\n            sampled_etype_ids, sampled_per_etype_eids = gpb.map_to_per_etype(\n                sampled_homo_eids\n            )\n            for etype_id, c_etype in 
enumerate(gpb.canonical_etypes):\n                indices = torch.nonzero(sampled_etype_ids == etype_id).squeeze()\n                sampled_eids = sampled_per_etype_eids[indices]\n                sampled_orig_eids = orig_eids[c_etype][sampled_eids]\n                assert torch.all(\n                    g.edges[c_etype].data[prob_or_mask][sampled_orig_eids] > 0\n                )\n\n\ndef check_hetero_dist_edge_dataloader_gb(\n    tmpdir, num_server, use_graphbolt=True\n):\n    generate_ip_config(\"rpc_ip_config.txt\", num_server, num_server)\n\n    g = create_random_hetero()\n    eids = torch.randperm(g.num_edges(\"r23\"))[:10]\n    mask = torch.zeros(g.num_edges(\"r23\"), dtype=torch.bool)\n    mask[eids] = True\n\n    num_parts = num_server\n\n    orig_nid_map, orig_eid_map = partition_graph(\n        g,\n        \"test_sampling\",\n        num_parts,\n        tmpdir,\n        num_hops=1,\n        part_method=\"metis\",\n        return_mapping=True,\n        use_graphbolt=use_graphbolt,\n        store_eids=True,\n    )\n\n    part_config = tmpdir / \"test_sampling.json\"\n\n    pserver_list = []\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_server):\n        p = ctx.Process(\n            target=start_server,\n            args=(\n                i,\n                tmpdir,\n                num_server > 1,\n                \"test_sampling\",\n                [\"csc\", \"coo\"],\n                True,\n            ),\n        )\n        p.start()\n        time.sleep(1)\n        pserver_list.append(p)\n\n    dgl.distributed.initialize(\"rpc_ip_config.txt\", use_graphbolt=True)\n    dist_graph = DistGraph(\"test_sampling\", part_config=part_config)\n\n    os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n\n    edges = {(\"n2\", \"r23\", \"n3\"): eids}\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask=\"mask\")\n    loader = dgl.dataloading.DistEdgeDataLoader(\n        dist_graph, edges, sampler, batch_size=64\n    )\n    
dgl.distributed.exit_client()\n    for p in pserver_list:\n        p.join()\n        assert p.exitcode == 0\n\n    block = next(iter(loader))[2][0]\n    assert block.num_src_nodes(\"n1\") > 0\n    assert block.num_edges(\"r12\") > 0\n    assert block.num_edges(\"r13\") > 0\n    assert block.num_edges(\"r23\") > 0\n\n\ndef test_hetero_dist_edge_dataloader_gb(\n    num_server=1,\n):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        check_hetero_dist_edge_dataloader_gb(Path(tmpdirname), num_server)\n\n\nif __name__ == \"__main__\":\n    import tempfile\n\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n        check_standalone_etype_sampling_heterograph(Path(tmpdirname))\n\n    with tempfile.TemporaryDirectory() as tmpdirname:\n        os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n        check_standalone_etype_sampling(Path(tmpdirname))\n        check_standalone_sampling(Path(tmpdirname))\n        os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n        check_rpc_sampling(Path(tmpdirname), 2)\n        check_rpc_sampling(Path(tmpdirname), 1)\n        check_rpc_get_degree_shuffle(Path(tmpdirname), 1)\n        check_rpc_get_degree_shuffle(Path(tmpdirname), 2)\n        check_rpc_find_edges_shuffle(Path(tmpdirname), 2)\n        check_rpc_find_edges_shuffle(Path(tmpdirname), 1)\n        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 1)\n        check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 2)\n        check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)\n        check_rpc_sampling_shuffle(Path(tmpdirname), 1)\n        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 1)\n        check_rpc_hetero_sampling_shuffle(Path(tmpdirname), 2)\n        check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), 1)\n        check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 1)\n        
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), 2)\n        check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), 1)\n"
  },
  {
    "path": "tests/distributed/test_mp_dataloader.py",
    "content": "import multiprocessing as mp\nimport os\nimport tempfile\nimport time\nimport unittest\nimport uuid\n\nimport backend as F\nimport dgl\nimport numpy as np\nimport pytest\nimport torch as th\nfrom dgl.data import CitationGraphDataset\nfrom dgl.distributed import (\n    DistDataLoader,\n    DistGraph,\n    DistGraphServer,\n    load_partition,\n    partition_graph,\n)\nfrom scipy import sparse as spsp\nfrom utils import generate_ip_config, reset_envs\n\n\ndef _unique_rand_graph(num_nodes=1000, num_edges=10 * 1000):\n    edges_set = set()\n    while len(edges_set) < num_edges:\n        src = np.random.randint(0, num_nodes - 1)\n        dst = np.random.randint(0, num_nodes - 1)\n        if (\n            src != dst\n            and (src, dst) not in edges_set\n            and (dst, src) not in edges_set\n        ):\n            edges_set.add((src, dst))\n    src_list, dst_list = zip(*edges_set)\n\n    src = th.tensor(src_list, dtype=th.long)\n    dst = th.tensor(dst_list, dtype=th.long)\n    g = dgl.graph((th.cat([src, dst]), th.cat([dst, src])))\n    E = len(src)\n    reverse_eids = th.cat([th.arange(E, 2 * E), th.arange(0, E)])\n    return g, reverse_eids\n\n\nclass NeighborSampler(object):\n    def __init__(\n        self,\n        g,\n        fanouts,\n        sample_neighbors,\n        use_graphbolt=False,\n        return_eids=False,\n    ):\n        self.g = g\n        self.fanouts = fanouts\n        self.sample_neighbors = sample_neighbors\n        self.use_graphbolt = use_graphbolt\n        self.return_eids = return_eids\n\n    def sample_blocks(self, seeds):\n        import torch as th\n\n        seeds = th.tensor(np.asarray(seeds), dtype=self.g.idtype)\n        blocks = []\n        for fanout in self.fanouts:\n            # For each seed node, sample ``fanout`` neighbors.\n            frontier = self.sample_neighbors(\n                self.g, seeds, fanout, use_graphbolt=self.use_graphbolt\n            )\n            # Then we compact the 
frontier into a bipartite graph for\n            # message passing.\n            block = dgl.to_block(frontier, seeds)\n            # Obtain the seed nodes for next layer.\n            seeds = block.srcdata[dgl.NID]\n            if frontier.num_edges() > 0:\n                if not self.use_graphbolt or self.return_eids:\n                    block.edata[dgl.EID] = frontier.edata[dgl.EID]\n\n            blocks.insert(0, block)\n        return blocks\n\n\ndef start_server(\n    rank,\n    ip_config,\n    part_config,\n    disable_shared_mem,\n    num_clients,\n    use_graphbolt=False,\n):\n    print(\"server: #clients=\" + str(num_clients))\n    g = DistGraphServer(\n        rank,\n        ip_config,\n        1,\n        num_clients,\n        part_config,\n        disable_shared_mem=disable_shared_mem,\n        graph_format=[\"csc\", \"coo\"],\n        use_graphbolt=use_graphbolt,\n    )\n    g.start()\n\n\ndef start_dist_dataloader(\n    rank,\n    ip_config,\n    part_config,\n    num_server,\n    drop_last,\n    orig_nid,\n    orig_eid,\n    use_graphbolt=False,\n    return_eids=False,\n):\n    dgl.distributed.initialize(ip_config)\n    gpb = None\n    disable_shared_mem = num_server > 1\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(part_config, rank)\n    num_nodes_to_sample = 202\n    batch_size = 32\n    train_nid = th.arange(num_nodes_to_sample)\n    graph_name = os.path.splitext(os.path.basename(part_config))[0]\n    dist_graph = DistGraph(\n        graph_name,\n        gpb=gpb,\n        part_config=part_config,\n    )\n\n    # Create sampler\n    sampler = NeighborSampler(\n        dist_graph,\n        [5, 10],\n        dgl.distributed.sample_neighbors,\n        use_graphbolt=use_graphbolt,\n        return_eids=return_eids,\n    )\n\n    # Enable santity check in distributed sampling.\n    os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n\n    # We need to test creating DistDataLoader multiple times.\n    for i in range(2):\n        # 
Create DataLoader for constructing blocks\n        dataloader = DistDataLoader(\n            dataset=train_nid,\n            batch_size=batch_size,\n            collate_fn=sampler.sample_blocks,\n            shuffle=False,\n            drop_last=drop_last,\n        )\n\n        groundtruth_g = CitationGraphDataset(\"cora\")[0]\n        max_nid = []\n\n        for _ in range(2):\n            for idx, blocks in zip(\n                range(0, num_nodes_to_sample, batch_size), dataloader\n            ):\n                block = blocks[-1]\n                o_src, o_dst = block.edges()\n                src_nodes_id = block.srcdata[dgl.NID][o_src]\n                dst_nodes_id = block.dstdata[dgl.NID][o_dst]\n                max_nid.append(np.max(F.asnumpy(dst_nodes_id)))\n\n                src_nodes_id = orig_nid[src_nodes_id]\n                dst_nodes_id = orig_nid[dst_nodes_id]\n                has_edges = groundtruth_g.has_edges_between(\n                    src_nodes_id, dst_nodes_id\n                )\n                assert np.all(F.asnumpy(has_edges))\n\n                if use_graphbolt and not return_eids:\n                    continue\n                eids = orig_eid[block.edata[dgl.EID]]\n                expected_eids = groundtruth_g.edge_ids(\n                    src_nodes_id, dst_nodes_id\n                )\n                assert th.equal(\n                    eids, expected_eids\n                ), f\"{eids} != {expected_eids}\"\n            if drop_last:\n                assert (\n                    np.max(max_nid)\n                    == num_nodes_to_sample\n                    - 1\n                    - num_nodes_to_sample % batch_size\n                )\n            else:\n                assert np.max(max_nid) == num_nodes_to_sample - 1\n    del dataloader\n    # this is needed since there's two test here in one process\n    dgl.distributed.exit_client()\n\n\n@unittest.skip(reason=\"Skip due to glitch in CI\")\ndef test_standalone():\n    
reset_envs()\n    with tempfile.TemporaryDirectory() as test_dir:\n        ip_config = os.path.join(test_dir, \"ip_config.txt\")\n        generate_ip_config(ip_config, 1, 1)\n\n        g = CitationGraphDataset(\"cora\")[0]\n        print(g.idtype)\n        num_parts = 1\n        num_hops = 1\n        graph_name = f\"graph_{uuid.uuid4()}\"\n        orig_nid, orig_eid = partition_graph(\n            g,\n            graph_name,\n            num_parts,\n            test_dir,\n            num_hops=num_hops,\n            part_method=\"metis\",\n            return_mapping=True,\n        )\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n        os.environ[\"DGL_DIST_MODE\"] = \"standalone\"\n        try:\n            start_dist_dataloader(\n                0, ip_config, part_config, 1, True, orig_nid, orig_eid\n            )\n        except Exception as e:\n            print(e)\n\n\ndef start_dist_neg_dataloader(\n    rank,\n    ip_config,\n    part_config,\n    num_server,\n    num_workers,\n    orig_nid,\n    groundtruth_g,\n):\n    import dgl\n    import torch as th\n\n    dgl.distributed.initialize(ip_config)\n    gpb = None\n    disable_shared_mem = num_server > 1\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(part_config, rank)\n    num_edges_to_sample = 202\n    batch_size = 32\n    graph_name = os.path.splitext(os.path.basename(part_config))[0]\n    dist_graph = DistGraph(graph_name, gpb=gpb, part_config=part_config)\n    assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)\n    assert len(dist_graph.etypes) == len(groundtruth_g.etypes)\n    if len(dist_graph.etypes) == 1:\n        train_eid = th.arange(num_edges_to_sample)\n    else:\n        train_eid = {dist_graph.etypes[0]: th.arange(num_edges_to_sample)}\n\n    for i in range(num_server):\n        part, _, _, _, _, _, _ = load_partition(part_config, i)\n\n    num_negs = 5\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])\n    
negative_sampler = dgl.dataloading.negative_sampler.Uniform(num_negs)\n    dataloader = dgl.distributed.DistEdgeDataLoader(\n        dist_graph,\n        train_eid,\n        sampler,\n        batch_size=batch_size,\n        negative_sampler=negative_sampler,\n        shuffle=True,\n        drop_last=False,\n        num_workers=num_workers,\n    )\n    for _ in range(2):\n        for _, (_, pos_graph, neg_graph, blocks) in zip(\n            range(0, num_edges_to_sample, batch_size), dataloader\n        ):\n            block = blocks[-1]\n            for src_type, etype, dst_type in block.canonical_etypes:\n                o_src, o_dst = block.edges(etype=etype)\n                src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]\n                dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]\n                src_nodes_id = orig_nid[src_type][src_nodes_id]\n                dst_nodes_id = orig_nid[dst_type][dst_nodes_id]\n                has_edges = groundtruth_g.has_edges_between(\n                    src_nodes_id, dst_nodes_id, etype=etype\n                )\n                assert np.all(F.asnumpy(has_edges))\n                assert np.all(\n                    F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])\n                    == F.asnumpy(pos_graph.nodes[dst_type].data[dgl.NID])\n                )\n                assert np.all(\n                    F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])\n                    == F.asnumpy(neg_graph.nodes[dst_type].data[dgl.NID])\n                )\n                assert pos_graph.num_edges() * num_negs == neg_graph.num_edges()\n\n    del dataloader\n    # this is needed since there's two test here in one process\n    dgl.distributed.exit_client()\n\n\ndef check_neg_dataloader(g, num_server, num_workers):\n    with tempfile.TemporaryDirectory() as test_dir:\n        ip_config = \"ip_config.txt\"\n        generate_ip_config(ip_config, num_server, num_server)\n\n        num_parts = num_server\n   
     num_hops = 1\n        graph_name = f\"graph_{uuid.uuid4()}\"\n        orig_nid, orig_eid = partition_graph(\n            g,\n            graph_name,\n            num_parts,\n            test_dir,\n            num_hops=num_hops,\n            part_method=\"metis\",\n            return_mapping=True,\n        )\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n        if not isinstance(orig_nid, dict):\n            orig_nid = {g.ntypes[0]: orig_nid}\n        if not isinstance(orig_eid, dict):\n            orig_eid = {g.etypes[0]: orig_eid}\n\n        pserver_list = []\n        ctx = mp.get_context(\"spawn\")\n        for i in range(num_server):\n            p = ctx.Process(\n                target=start_server,\n                args=(\n                    i,\n                    ip_config,\n                    part_config,\n                    num_server > 1,\n                    num_workers + 1,\n                ),\n            )\n            p.start()\n            time.sleep(1)\n            pserver_list.append(p)\n        os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n        os.environ[\"DGL_NUM_SAMPLER\"] = str(num_workers)\n        ptrainer_list = []\n\n        p = ctx.Process(\n            target=start_dist_neg_dataloader,\n            args=(\n                0,\n                ip_config,\n                part_config,\n                num_server,\n                num_workers,\n                orig_nid,\n                g,\n            ),\n        )\n        p.start()\n        ptrainer_list.append(p)\n\n        for p in pserver_list:\n            p.join()\n            assert p.exitcode == 0\n        for p in ptrainer_list:\n            p.join()\n            assert p.exitcode == 0\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"num_workers\", [0, 1])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\ndef test_dist_dataloader(num_server, 
num_workers, use_graphbolt, return_eids):\n    if not use_graphbolt and return_eids:\n        # return_eids is not supported in non-GraphBolt mode.\n        return\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    os.environ[\"DGL_NUM_SAMPLER\"] = str(num_workers)\n    with tempfile.TemporaryDirectory() as test_dir:\n        ip_config = \"ip_config.txt\"\n        generate_ip_config(ip_config, num_server, num_server)\n\n        g = CitationGraphDataset(\"cora\")[0]\n        num_parts = num_server\n        num_hops = 1\n        graph_name = f\"graph_{uuid.uuid4()}\"\n        orig_nid, orig_eid = partition_graph(\n            g,\n            graph_name,\n            num_parts,\n            test_dir,\n            num_hops=num_hops,\n            part_method=\"metis\",\n            return_mapping=True,\n            use_graphbolt=use_graphbolt,\n            store_eids=return_eids,\n        )\n\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n        pserver_list = []\n        ctx = mp.get_context(\"spawn\")\n        for i in range(num_server):\n            p = ctx.Process(\n                target=start_server,\n                args=(\n                    i,\n                    ip_config,\n                    part_config,\n                    num_server > 1,\n                    num_workers + 1,\n                    use_graphbolt,\n                ),\n            )\n            p.start()\n            time.sleep(1)\n            pserver_list.append(p)\n\n        ptrainer_list = []\n        num_trainers = 1\n        for trainer_id in range(num_trainers):\n            p = ctx.Process(\n                target=start_dist_dataloader,\n                args=(\n                    trainer_id,\n                    ip_config,\n                    part_config,\n                    num_server,\n                    False,\n                    orig_nid,\n                    orig_eid,\n                    use_graphbolt,\n                    
return_eids,\n                ),\n            )\n            p.start()\n            time.sleep(1)  # avoid race condition when instantiating DistGraph\n            ptrainer_list.append(p)\n\n        for p in ptrainer_list:\n            p.join()\n            assert p.exitcode == 0\n        for p in pserver_list:\n            p.join()\n            assert p.exitcode == 0\n\n\ndef start_node_dataloader(\n    rank,\n    ip_config,\n    part_config,\n    num_server,\n    num_workers,\n    orig_nid,\n    orig_eid,\n    groundtruth_g,\n    use_graphbolt=False,\n    return_eids=False,\n    prob_or_mask=None,\n    use_deprecated_dataloader=False,\n):\n    dgl.distributed.initialize(ip_config, use_graphbolt=use_graphbolt)\n    gpb = None\n    disable_shared_mem = num_server > 1\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(part_config, rank)\n    num_nodes_to_sample = 202\n    batch_size = 32\n    graph_name = os.path.splitext(os.path.basename(part_config))[0]\n    dist_graph = DistGraph(\n        graph_name,\n        gpb=gpb,\n        part_config=part_config,\n    )\n    assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)\n    assert len(dist_graph.etypes) == len(groundtruth_g.etypes)\n    if len(dist_graph.etypes) == 1:\n        train_nid = th.arange(num_nodes_to_sample, dtype=dist_graph.idtype)\n    else:\n        train_nid = {\n            \"n3\": th.arange(num_nodes_to_sample, dtype=dist_graph.idtype)\n        }\n\n    for i in range(num_server):\n        part, _, _, _, _, _, _ = load_partition(part_config, i)\n\n    # Create sampler\n    _prob = None\n    _mask = None\n    if prob_or_mask is None:\n        pass\n    elif prob_or_mask == \"prob\":\n        _prob = \"prob\"\n    elif prob_or_mask == \"mask\":\n        _mask = \"mask\"\n    else:\n        raise ValueError(f\"Unsupported prob type: {prob_or_mask}\")\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [\n            (\n                # test dict for 
hetero\n                {etype: 5 for etype in dist_graph.etypes}\n                if len(dist_graph.etypes) > 1\n                else 5\n            ),\n            10,\n        ],\n        prob=_prob,\n        mask=_mask,\n    )  # test int for hetero\n\n    # Enable santity check in distributed sampling.\n    os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n\n    # We need to test creating DistDataLoader multiple times.\n    for i in range(2):\n        # Create DataLoader for constructing blocks\n        dataloader_cls = (\n            dgl.dataloading.DistNodeDataLoader\n            if use_deprecated_dataloader\n            else dgl.distributed.DistNodeDataLoader\n        )\n        dataloader = dataloader_cls(\n            dist_graph,\n            train_nid,\n            sampler,\n            batch_size=batch_size,\n            shuffle=True,\n            drop_last=False,\n            num_workers=num_workers,\n        )\n\n        for _ in range(2):\n            for idx, (_, _, blocks) in zip(\n                range(0, num_nodes_to_sample, batch_size), dataloader\n            ):\n                block = blocks[-1]\n                for c_etype in block.canonical_etypes:\n                    src_type, _, dst_type = c_etype\n                    o_src, o_dst = block.edges(etype=c_etype)\n                    src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]\n                    dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]\n                    src_nodes_id = orig_nid[src_type][src_nodes_id]\n                    dst_nodes_id = orig_nid[dst_type][dst_nodes_id]\n                    has_edges = groundtruth_g.has_edges_between(\n                        src_nodes_id, dst_nodes_id, etype=c_etype\n                    )\n                    assert np.all(F.asnumpy(has_edges))\n\n                    if use_graphbolt and not return_eids:\n                        assert dgl.EID not in block.edges[c_etype].data\n                        continue\n                  
  eids = orig_eid[c_etype][block.edges[c_etype].data[dgl.EID]]\n                    expected_eids = groundtruth_g.edge_ids(\n                        src_nodes_id, dst_nodes_id, etype=c_etype\n                    )\n                    assert th.equal(\n                        eids, expected_eids\n                    ), f\"{eids} != {expected_eids}\"\n                    # Verify the prob/mask functionality.\n                    if prob_or_mask is not None:\n                        prob_data = groundtruth_g.edges[c_etype].data[\n                            prob_or_mask\n                        ][eids]\n                        assert th.all(prob_data > 0)\n    del dataloader\n    # this is needed since there's two test here in one process\n    dgl.distributed.exit_client()\n\n\ndef start_edge_dataloader(\n    rank,\n    ip_config,\n    part_config,\n    num_server,\n    num_workers,\n    orig_nid,\n    orig_eid,\n    groundtruth_g,\n    use_graphbolt,\n    exclude,\n    reverse_eids,\n    reverse_etypes,\n    negative,\n    prob_or_mask,\n    use_deprecated_dataloader=False,\n):\n    dgl.distributed.initialize(ip_config, use_graphbolt=use_graphbolt)\n    gpb = None\n    disable_shared_mem = num_server > 1\n    if disable_shared_mem:\n        _, _, _, gpb, _, _, _ = load_partition(part_config, rank)\n    num_edges_to_sample = 202\n    batch_size = 32\n    graph_name = os.path.splitext(os.path.basename(part_config))[0]\n    dist_graph = DistGraph(graph_name, gpb=gpb, part_config=part_config)\n    assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)\n    assert len(dist_graph.etypes) == len(groundtruth_g.etypes)\n    if len(dist_graph.etypes) == 1:\n        train_eid = th.arange(num_edges_to_sample)\n    else:\n        train_eid = {\n            dist_graph.canonical_etypes[0]: th.arange(num_edges_to_sample)\n        }\n\n    for i in range(num_server):\n        part, _, _, _, _, _, _ = load_partition(part_config, i)\n\n    # Create sampler\n    _prob = None\n    
_mask = None\n    if prob_or_mask is None:\n        pass\n    elif prob_or_mask == \"prob\":\n        _prob = \"prob\"\n    elif prob_or_mask == \"mask\":\n        _mask = \"mask\"\n    else:\n        raise ValueError(f\"Unsupported prob type: {prob_or_mask}\")\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [5, -1], prob=_prob, mask=_mask\n    )\n\n    # Negative sampler.\n    negative_sampler = None\n    if negative:\n        negative_sampler = dgl.dataloading.negative_sampler.Uniform(5)\n\n    # We need to test creating DistDataLoader multiple times.\n    for i in range(2):\n        # Create DataLoader for constructing blocks\n        dataloader_cls = (\n            dgl.dataloading.DistEdgeDataLoader\n            if use_deprecated_dataloader\n            else dgl.distributed.DistEdgeDataLoader\n        )\n        dataloader = dataloader_cls(\n            dist_graph,\n            train_eid,\n            sampler,\n            batch_size=batch_size,\n            shuffle=True,\n            drop_last=False,\n            num_workers=num_workers,\n            exclude=exclude,\n            reverse_eids=reverse_eids,\n            reverse_etypes=reverse_etypes,\n            negative_sampler=negative_sampler,\n        )\n\n        for _ in range(2):\n            for _, minibatch in zip(\n                range(0, num_edges_to_sample, batch_size), dataloader\n            ):\n                if negative:\n                    _, pos_pair_graph, neg_pair_graph, blocks = minibatch\n                else:\n                    _, pos_pair_graph, blocks = minibatch\n                block = blocks[-1]\n                for src_type, etype, dst_type in block.canonical_etypes:\n                    o_src, o_dst = block.edges(etype=etype)\n                    src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]\n                    dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]\n                    src_nodes_id = orig_nid[src_type][src_nodes_id]\n  
                  dst_nodes_id = orig_nid[dst_type][dst_nodes_id]\n                    has_edges = groundtruth_g.has_edges_between(\n                        src_nodes_id, dst_nodes_id, etype=etype\n                    )\n                    assert np.all(F.asnumpy(has_edges))\n                    assert np.all(\n                        F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])\n                        == F.asnumpy(\n                            pos_pair_graph.nodes[dst_type].data[dgl.NID]\n                        )\n                    )\n                    if negative:\n                        assert np.all(\n                            F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])\n                            == F.asnumpy(\n                                neg_pair_graph.nodes[dst_type].data[dgl.NID]\n                            )\n                        )\n                    if (\n                        dgl.EID\n                        not in block.edges[(src_type, etype, dst_type)].data\n                    ):\n                        continue\n                    sampled_eids = block.edges[\n                        (src_type, etype, dst_type)\n                    ].data[dgl.EID]\n                    sampled_orig_eids = orig_eid[(src_type, etype, dst_type)][\n                        sampled_eids\n                    ]\n                    raw_src, raw_dst = groundtruth_g.find_edges(\n                        sampled_orig_eids, etype=(src_type, etype, dst_type)\n                    )\n                    sampled_src, sampled_dst = block.edges(\n                        etype=(src_type, etype, dst_type)\n                    )\n                    sampled_orig_src = block.nodes[src_type].data[dgl.NID][\n                        sampled_src\n                    ]\n                    sampled_orig_dst = block.nodes[dst_type].data[dgl.NID][\n                        sampled_dst\n                    ]\n                    assert th.equal(\n                   
     raw_src, orig_nid[src_type][sampled_orig_src]\n                    )\n                    assert th.equal(\n                        raw_dst, orig_nid[dst_type][sampled_orig_dst]\n                    )\n                    # Verify the prob/mask functionality.\n                    if prob_or_mask is not None:\n                        prob_data = groundtruth_g.edges[etype].data[\n                            prob_or_mask\n                        ][sampled_orig_eids]\n                        assert th.all(prob_data > 0)\n                # Verify the exclude functionality.\n                if dgl.EID not in blocks[-1].edata.keys():\n                    continue\n                for (\n                    src_type,\n                    etype,\n                    dst_type,\n                ) in pos_pair_graph.canonical_etypes:\n                    for block in blocks:\n                        if (\n                            src_type,\n                            etype,\n                            dst_type,\n                        ) not in block.canonical_etypes:\n                            continue\n                        current_eids = block.edges[etype].data[dgl.EID]\n                        seed_eids = pos_pair_graph.edges[etype].data[dgl.EID]\n                        if exclude is None:\n                            # seed_eids are not guaranteed to be sampled.\n                            pass\n                        elif exclude == \"self\":\n                            assert not th.any(th.isin(current_eids, seed_eids))\n                        elif exclude == \"reverse_id\":\n                            src, dst = groundtruth_g.find_edges(seed_eids)\n                            reverse_seed_eids = groundtruth_g.edge_ids(dst, src)\n                            assert not th.any(\n                                th.isin(current_eids, reverse_seed_eids)\n                            )\n                            assert not th.any(th.isin(current_eids, 
seed_eids))\n                        elif exclude == \"reverse_types\":\n                            assert not th.any(th.isin(current_eids, seed_eids))\n                            reverse_etype = reverse_etypes[\n                                (src_type, etype, dst_type)\n                            ]\n                            if reverse_etype in block.canonical_etypes:\n                                assert not th.any(\n                                    th.isin(\n                                        block.edges[reverse_etype].data[\n                                            dgl.EID\n                                        ],\n                                        seed_eids,\n                                    )\n                                )\n                        else:\n                            raise ValueError(\n                                f\"Unsupported exclude type: {exclude}\"\n                            )\n    del dataloader\n    dgl.distributed.exit_client()\n\n\ndef check_dataloader(\n    g,\n    num_server,\n    num_workers,\n    dataloader_type,\n    use_graphbolt=False,\n    return_eids=False,\n    exclude=None,\n    reverse_eids=None,\n    reverse_etypes=None,\n    negative=False,\n    prob_or_mask=None,\n    use_deprecated_dataloader=False,\n):\n    with tempfile.TemporaryDirectory() as test_dir:\n        ip_config = \"ip_config.txt\"\n        generate_ip_config(ip_config, num_server, num_server)\n\n        num_parts = num_server\n        num_hops = 1\n        graph_name = f\"graph_{uuid.uuid4()}\"\n        orig_nid, orig_eid = partition_graph(\n            g,\n            graph_name,\n            num_parts,\n            test_dir,\n            num_hops=num_hops,\n            part_method=\"metis\",\n            return_mapping=True,\n            use_graphbolt=use_graphbolt,\n            store_eids=return_eids,\n        )\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n        if not 
isinstance(orig_nid, dict):\n            orig_nid = {g.ntypes[0]: orig_nid}\n        if not isinstance(orig_eid, dict):\n            orig_eid = {g.canonical_etypes[0]: orig_eid}\n\n        pserver_list = []\n        ctx = mp.get_context(\"spawn\")\n        for i in range(num_server):\n            p = ctx.Process(\n                target=start_server,\n                args=(\n                    i,\n                    ip_config,\n                    part_config,\n                    num_server > 1,\n                    num_workers + 1,\n                    use_graphbolt,\n                ),\n            )\n            p.start()\n            time.sleep(1)\n            pserver_list.append(p)\n\n        os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n        os.environ[\"DGL_NUM_SAMPLER\"] = str(num_workers)\n        ptrainer_list = []\n        if dataloader_type == \"node\":\n            p = ctx.Process(\n                target=start_node_dataloader,\n                args=(\n                    0,\n                    ip_config,\n                    part_config,\n                    num_server,\n                    num_workers,\n                    orig_nid,\n                    orig_eid,\n                    g,\n                    use_graphbolt,\n                    return_eids,\n                    prob_or_mask,\n                    use_deprecated_dataloader,\n                ),\n            )\n            p.start()\n            ptrainer_list.append(p)\n        elif dataloader_type == \"edge\":\n            p = ctx.Process(\n                target=start_edge_dataloader,\n                args=(\n                    0,\n                    ip_config,\n                    part_config,\n                    num_server,\n                    num_workers,\n                    orig_nid,\n                    orig_eid,\n                    g,\n                    use_graphbolt,\n                    exclude,\n                    reverse_eids,\n                    
reverse_etypes,\n                    negative,\n                    prob_or_mask,\n                    use_deprecated_dataloader,\n                ),\n            )\n            p.start()\n            ptrainer_list.append(p)\n        for p in pserver_list:\n            p.join()\n            assert p.exitcode == 0\n        for p in ptrainer_list:\n            p.join()\n            assert p.exitcode == 0\n\n\ndef create_random_hetero():\n    num_nodes = {\"n1\": 10000, \"n2\": 10010, \"n3\": 10020}\n    etypes = [(\"n1\", \"r1\", \"n2\"), (\"n1\", \"r2\", \"n3\"), (\"n2\", \"r3\", \"n3\")]\n    edges = {}\n    for etype in etypes:\n        src_ntype, _, dst_ntype = etype\n        arr = spsp.random(\n            num_nodes[src_ntype],\n            num_nodes[dst_ntype],\n            density=0.001,\n            format=\"coo\",\n            random_state=100,\n        )\n        edges[etype] = (arr.row, arr.col)\n    # Add reverse edges.\n    src, dst = edges[(\"n1\", \"r1\", \"n2\")]\n    edges[(\"n2\", \"r21\", \"n1\")] = (dst, src)\n    g = dgl.heterograph(edges, num_nodes)\n    g.nodes[\"n1\"].data[\"feat\"] = F.unsqueeze(F.arange(0, g.num_nodes(\"n1\")), 1)\n    g.edges[\"r1\"].data[\"feat\"] = F.unsqueeze(F.arange(0, g.num_edges(\"r1\")), 1)\n    return g\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"num_workers\", [0, 1])\n@pytest.mark.parametrize(\"dataloader_type\", [\"node\", \"edge\"])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\ndef test_dataloader_homograph(\n    num_server, num_workers, dataloader_type, use_graphbolt, return_eids\n):\n    if not use_graphbolt and return_eids:\n        # return_eids is not supported in non-GraphBolt mode.\n        return\n    reset_envs()\n    g = CitationGraphDataset(\"cora\")[0]\n    check_dataloader(\n        g,\n        num_server,\n        num_workers,\n        dataloader_type,\n        use_graphbolt=use_graphbolt,\n 
       return_eids=return_eids,\n    )\n\n\n@pytest.mark.parametrize(\"num_workers\", [0])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"exclude\", [None, \"self\", \"reverse_id\"])\n@pytest.mark.parametrize(\"negative\", [False, True])\ndef test_edge_dataloader_homograph(\n    num_workers, use_graphbolt, exclude, negative\n):\n    num_server = 1\n    dataloader_type = \"edge\"\n    reset_envs()\n    g, reverse_eids = _unique_rand_graph()\n    check_dataloader(\n        g,\n        num_server,\n        num_workers,\n        dataloader_type,\n        use_graphbolt=use_graphbolt,\n        return_eids=True,\n        exclude=exclude,\n        reverse_eids=reverse_eids,\n        negative=negative,\n    )\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"num_workers\", [1])\n@pytest.mark.parametrize(\"dataloader_type\", [\"node\", \"edge\"])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"prob_or_mask\", [\"prob\", \"mask\"])\ndef test_dataloader_homograph_prob_or_mask(\n    num_server, num_workers, dataloader_type, use_graphbolt, prob_or_mask\n):\n    reset_envs()\n    g = CitationGraphDataset(\"cora\")[0]\n    prob = th.rand(g.num_edges())\n    mask = prob > 0.2\n    g.edata[\"prob\"] = F.tensor(prob)\n    g.edata[\"mask\"] = F.tensor(mask)\n    check_dataloader(\n        g,\n        num_server,\n        num_workers,\n        dataloader_type,\n        use_graphbolt=use_graphbolt,\n        return_eids=True,\n        prob_or_mask=prob_or_mask,\n    )\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"num_workers\", [0, 1])\n@pytest.mark.parametrize(\"dataloader_type\", [\"node\", \"edge\"])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"return_eids\", [False, True])\ndef test_dataloader_heterograph(\n    num_server, num_workers, dataloader_type, use_graphbolt, return_eids\n):\n    if not 
use_graphbolt and return_eids:\n        # return_eids is not supported in non-GraphBolt mode.\n        return\n    reset_envs()\n    g = create_random_hetero()\n    check_dataloader(\n        g,\n        num_server,\n        num_workers,\n        dataloader_type,\n        use_graphbolt=use_graphbolt,\n        return_eids=return_eids,\n    )\n\n\n@pytest.mark.parametrize(\"num_workers\", [0])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"exclude\", [None, \"self\", \"reverse_types\"])\n@pytest.mark.parametrize(\"negative\", [False, True])\ndef test_edge_dataloader_heterograph(\n    num_workers, use_graphbolt, exclude, negative\n):\n    num_server = 1\n    dataloader_type = \"edge\"\n    reset_envs()\n    g = create_random_hetero()\n    reverse_etypes = {(\"n1\", \"r1\", \"n2\"): (\"n2\", \"r21\", \"n1\")}\n    check_dataloader(\n        g,\n        num_server,\n        num_workers,\n        dataloader_type,\n        use_graphbolt=use_graphbolt,\n        return_eids=True,\n        exclude=exclude,\n        reverse_etypes=reverse_etypes,\n        negative=negative,\n    )\n\n\n@pytest.mark.parametrize(\"num_server\", [1])\n@pytest.mark.parametrize(\"num_workers\", [1])\n@pytest.mark.parametrize(\"dataloader_type\", [\"node\", \"edge\"])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\n@pytest.mark.parametrize(\"prob_or_mask\", [\"prob\", \"mask\"])\ndef test_dataloader_heterograph_prob_or_mask(\n    num_server, num_workers, dataloader_type, use_graphbolt, prob_or_mask\n):\n    reset_envs()\n    g = create_random_hetero()\n    for etype in g.canonical_etypes:\n        prob = th.rand(g.num_edges(etype))\n        mask = prob > prob.median()\n        g.edges[etype].data[\"prob\"] = prob\n        g.edges[etype].data[\"mask\"] = mask\n    check_dataloader(\n        g,\n        num_server,\n        num_workers,\n        dataloader_type,\n        use_graphbolt=use_graphbolt,\n        return_eids=True,\n        
prob_or_mask=prob_or_mask,\n    )\n\n\n@unittest.skip(reason=\"Skip due to glitch in CI\")\n@pytest.mark.parametrize(\"num_server\", [3])\n@pytest.mark.parametrize(\"num_workers\", [0, 4])\ndef test_neg_dataloader(num_server, num_workers):\n    reset_envs()\n    g = CitationGraphDataset(\"cora\")[0]\n    check_neg_dataloader(g, num_server, num_workers)\n    g = create_random_hetero()\n    check_neg_dataloader(g, num_server, num_workers)\n\n\ndef start_multiple_dataloaders(\n    ip_config,\n    part_config,\n    graph_name,\n    orig_g,\n    num_dataloaders,\n    dataloader_type,\n    use_graphbolt,\n):\n    dgl.distributed.initialize(ip_config)\n    dist_g = dgl.distributed.DistGraph(graph_name, part_config=part_config)\n    if dataloader_type == \"node\":\n        train_ids = th.arange(orig_g.num_nodes(), dtype=dist_g.idtype)\n        batch_size = orig_g.num_nodes() // 100\n    else:\n        train_ids = th.arange(orig_g.num_edges())\n        batch_size = orig_g.num_edges() // 100\n    sampler = dgl.dataloading.NeighborSampler([-1])\n    dataloaders = []\n    dl_iters = []\n    for _ in range(num_dataloaders):\n        if dataloader_type == \"node\":\n            dataloader = dgl.distributed.DistNodeDataLoader(\n                dist_g, train_ids, sampler, batch_size=batch_size\n            )\n        else:\n            dataloader = dgl.distributed.DistEdgeDataLoader(\n                dist_g, train_ids, sampler, batch_size=batch_size\n            )\n        dataloaders.append(dataloader)\n        dl_iters.append(iter(dataloader))\n\n    # iterate on multiple dataloaders randomly\n    while len(dl_iters) > 0:\n        next_dl = np.random.choice(len(dl_iters), 1)[0]\n        try:\n            _ = next(dl_iters[next_dl])\n        except StopIteration:\n            dl_iters.pop(next_dl)\n            del dataloaders[next_dl]\n\n    dgl.distributed.exit_client()\n\n\n@pytest.mark.parametrize(\"num_dataloaders\", [4])\n@pytest.mark.parametrize(\"num_workers\", 
[0])\n@pytest.mark.parametrize(\"dataloader_type\", [\"node\", \"edge\"])\n@pytest.mark.parametrize(\"use_graphbolt\", [False, True])\ndef test_multiple_dist_dataloaders(\n    num_dataloaders, num_workers, dataloader_type, use_graphbolt\n):\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    os.environ[\"DGL_NUM_SAMPLER\"] = str(num_workers)\n    num_parts = 1\n    num_servers = 1\n    with tempfile.TemporaryDirectory() as test_dir:\n        ip_config = os.path.join(test_dir, \"ip_config.txt\")\n        generate_ip_config(ip_config, num_parts, num_servers)\n\n        orig_g = dgl.rand_graph(1000, 10000)\n        graph_name = f\"graph_{uuid.uuid4()}\"\n        partition_graph(\n            orig_g,\n            graph_name,\n            num_parts,\n            test_dir,\n            use_graphbolt=use_graphbolt,\n        )\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n\n        p_servers = []\n        ctx = mp.get_context(\"spawn\")\n        for i in range(num_servers):\n            p = ctx.Process(\n                target=start_server,\n                args=(\n                    i,\n                    ip_config,\n                    part_config,\n                    num_servers > 1,\n                    num_workers + 1,\n                    use_graphbolt,\n                ),\n            )\n            p.start()\n            time.sleep(1)\n            p_servers.append(p)\n\n        p_client = ctx.Process(\n            target=start_multiple_dataloaders,\n            args=(\n                ip_config,\n                part_config,\n                graph_name,\n                orig_g,\n                num_dataloaders,\n                dataloader_type,\n                use_graphbolt,\n            ),\n        )\n        p_client.start()\n\n        p_client.join()\n        assert p_client.exitcode == 0\n        for p in p_servers:\n            p.join()\n            assert p.exitcode == 0\n    
reset_envs()\n\n\n@pytest.mark.parametrize(\"dataloader_type\", [\"node\", \"edge\"])\ndef test_deprecated_dataloader(dataloader_type):\n    reset_envs()\n    g = CitationGraphDataset(\"cora\")[0]\n    check_dataloader(\n        g,\n        1,\n        0,\n        dataloader_type,\n        use_deprecated_dataloader=True,\n    )\n"
  },
  {
    "path": "tests/distributed/test_new_kvstore.py",
    "content": "import multiprocessing as mp\nimport os\nimport time\nimport unittest\n\nimport backend as F\n\nimport dgl\nfrom numpy.testing import assert_array_equal\nfrom utils import generate_ip_config, reset_envs\n\n\n# Create an one-part Graph\nnode_map = {\"_N\": F.tensor([[0, 6]], F.int64)}\nedge_map = {(\"_N\", \"_E\", \"_N\"): F.tensor([[0, 7]], F.int64)}\nglobal_nid = F.tensor([0, 1, 2, 3, 4, 5], F.int64)\nglobal_eid = F.tensor([0, 1, 2, 3, 4, 5, 6], F.int64)\n\ng = dgl.graph([])\ng.add_nodes(6)\ng.add_edges(0, 1)  # 0\ng.add_edges(0, 2)  # 1\ng.add_edges(0, 3)  # 2\ng.add_edges(2, 3)  # 3\ng.add_edges(1, 1)  # 4\ng.add_edges(0, 4)  # 5\ng.add_edges(2, 5)  # 6\n\ng.ndata[dgl.NID] = global_nid\ng.edata[dgl.EID] = global_eid\n\ngpb = dgl.distributed.graph_partition_book.RangePartitionBook(\n    part_id=0,\n    num_parts=1,\n    node_map=node_map,\n    edge_map=edge_map,\n    ntypes={ntype: i for i, ntype in enumerate(g.ntypes)},\n    etypes={etype: i for i, etype in enumerate(g.canonical_etypes)},\n)\n\nnode_policy = dgl.distributed.PartitionPolicy(\n    policy_str=\"node~_N\", partition_book=gpb\n)\n\nedge_policy = dgl.distributed.PartitionPolicy(\n    policy_str=\"edge~_N:_E:_N\", partition_book=gpb\n)\n\ndata_0 = F.tensor(\n    [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],\n    F.float32,\n)\ndata_0_1 = F.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], F.float32)\ndata_0_2 = F.tensor([1, 2, 3, 4, 5, 6], F.int32)\ndata_0_3 = F.tensor([1, 2, 3, 4, 5, 6], F.int64)\ndata_1 = F.tensor(\n    [\n        [2.0, 2.0],\n        [2.0, 2.0],\n        [2.0, 2.0],\n        [2.0, 2.0],\n        [2.0, 2.0],\n        [2.0, 2.0],\n        [2.0, 2.0],\n    ],\n    F.float32,\n)\ndata_2 = F.tensor(\n    [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],\n    F.float32,\n)\n\n\ndef init_zero_func(shape, dtype):\n    return F.zeros(shape, dtype, F.cpu())\n\n\ndef udf_push(target, name, id_tensor, data_tensor):\n    
target[name][id_tensor] = data_tensor * data_tensor\n\n\ndef add_push(target, name, id_tensor, data_tensor):\n    target[name][id_tensor] += data_tensor\n\n\n@unittest.skipIf(\n    os.name == \"nt\" or os.getenv(\"DGLBACKEND\") == \"tensorflow\",\n    reason=\"Do not support windows and TF yet\",\n)\ndef test_partition_policy():\n    assert node_policy.part_id == 0\n    assert edge_policy.part_id == 0\n    local_nid = node_policy.to_local(F.tensor([0, 1, 2, 3, 4, 5]))\n    local_eid = edge_policy.to_local(F.tensor([0, 1, 2, 3, 4, 5, 6]))\n    assert_array_equal(\n        F.asnumpy(local_nid), F.asnumpy(F.tensor([0, 1, 2, 3, 4, 5], F.int64))\n    )\n    assert_array_equal(\n        F.asnumpy(local_eid),\n        F.asnumpy(F.tensor([0, 1, 2, 3, 4, 5, 6], F.int64)),\n    )\n    nid_partid = node_policy.to_partid(F.tensor([0, 1, 2, 3, 4, 5], F.int64))\n    eid_partid = edge_policy.to_partid(F.tensor([0, 1, 2, 3, 4, 5, 6], F.int64))\n    assert_array_equal(\n        F.asnumpy(nid_partid), F.asnumpy(F.tensor([0, 0, 0, 0, 0, 0], F.int64))\n    )\n    assert_array_equal(\n        F.asnumpy(eid_partid),\n        F.asnumpy(F.tensor([0, 0, 0, 0, 0, 0, 0], F.int64)),\n    )\n    assert node_policy.get_part_size() == len(local_nid)\n    assert edge_policy.get_part_size() == len(local_eid)\n\n\ndef start_server(server_id, num_clients, num_servers):\n    # Init kvserver\n    print(\"Sleep 5 seconds to test client re-connect.\")\n    time.sleep(5)\n    kvserver = dgl.distributed.KVServer(\n        server_id=server_id,\n        ip_config=\"kv_ip_config.txt\",\n        num_servers=num_servers,\n        num_clients=num_clients,\n    )\n    kvserver.add_part_policy(node_policy)\n    kvserver.add_part_policy(edge_policy)\n    if kvserver.is_backup_server():\n        kvserver.init_data(\"data_0\", \"node~_N\")\n        kvserver.init_data(\"data_0_1\", \"node~_N\")\n        kvserver.init_data(\"data_0_2\", \"node~_N\")\n        kvserver.init_data(\"data_0_3\", \"node~_N\")\n    else:\n   
     kvserver.init_data(\"data_0\", \"node~_N\", data_0)\n        kvserver.init_data(\"data_0_1\", \"node~_N\", data_0_1)\n        kvserver.init_data(\"data_0_2\", \"node~_N\", data_0_2)\n        kvserver.init_data(\"data_0_3\", \"node~_N\", data_0_3)\n    # start server\n    server_state = dgl.distributed.ServerState(\n        kv_store=kvserver, local_g=None, partition_book=None\n    )\n    dgl.distributed.start_server(\n        server_id=server_id,\n        ip_config=\"kv_ip_config.txt\",\n        num_servers=num_servers,\n        num_clients=num_clients,\n        server_state=server_state,\n    )\n\n\ndef start_server_mul_role(server_id, num_clients, num_servers):\n    # Init kvserver\n    kvserver = dgl.distributed.KVServer(\n        server_id=server_id,\n        ip_config=\"kv_ip_mul_config.txt\",\n        num_servers=num_servers,\n        num_clients=num_clients,\n    )\n    kvserver.add_part_policy(node_policy)\n    if kvserver.is_backup_server():\n        kvserver.init_data(\"data_0\", \"node~_N\")\n    else:\n        kvserver.init_data(\"data_0\", \"node~_N\", data_0)\n    # start server\n    server_state = dgl.distributed.ServerState(\n        kv_store=kvserver, local_g=None, partition_book=None\n    )\n    dgl.distributed.start_server(\n        server_id=server_id,\n        ip_config=\"kv_ip_mul_config.txt\",\n        num_servers=num_servers,\n        num_clients=num_clients,\n        server_state=server_state,\n    )\n\n\ndef start_client(num_clients, num_servers):\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    # Note: connect to server first !\n    dgl.distributed.initialize(ip_config=\"kv_ip_config.txt\")\n    # Init kvclient\n    kvclient = dgl.distributed.KVClient(\n        ip_config=\"kv_ip_config.txt\", num_servers=num_servers\n    )\n    kvclient.map_shared_data(partition_book=gpb)\n    assert dgl.distributed.get_num_client() == num_clients\n    kvclient.init_data(\n        name=\"data_1\",\n        shape=F.shape(data_1),\n        
dtype=F.dtype(data_1),\n        part_policy=edge_policy,\n        init_func=init_zero_func,\n    )\n    kvclient.init_data(\n        name=\"data_2\",\n        shape=F.shape(data_2),\n        dtype=F.dtype(data_2),\n        part_policy=node_policy,\n        init_func=init_zero_func,\n    )\n\n    # Test data_name_list\n    name_list = kvclient.data_name_list()\n    print(name_list)\n    assert \"data_0\" in name_list\n    assert \"data_0_1\" in name_list\n    assert \"data_0_2\" in name_list\n    assert \"data_0_3\" in name_list\n    assert \"data_1\" in name_list\n    assert \"data_2\" in name_list\n    # Test get_meta_data\n    meta = kvclient.get_data_meta(\"data_0\")\n    dtype, shape, policy = meta\n    assert dtype == F.dtype(data_0)\n    assert shape == F.shape(data_0)\n    assert policy.policy_str == \"node~_N\"\n\n    meta = kvclient.get_data_meta(\"data_0_1\")\n    dtype, shape, policy = meta\n    assert dtype == F.dtype(data_0_1)\n    assert shape == F.shape(data_0_1)\n    assert policy.policy_str == \"node~_N\"\n\n    meta = kvclient.get_data_meta(\"data_0_2\")\n    dtype, shape, policy = meta\n    assert dtype == F.dtype(data_0_2)\n    assert shape == F.shape(data_0_2)\n    assert policy.policy_str == \"node~_N\"\n\n    meta = kvclient.get_data_meta(\"data_0_3\")\n    dtype, shape, policy = meta\n    assert dtype == F.dtype(data_0_3)\n    assert shape == F.shape(data_0_3)\n    assert policy.policy_str == \"node~_N\"\n\n    meta = kvclient.get_data_meta(\"data_1\")\n    dtype, shape, policy = meta\n    assert dtype == F.dtype(data_1)\n    assert shape == F.shape(data_1)\n    assert policy.policy_str == \"edge~_N:_E:_N\"\n\n    meta = kvclient.get_data_meta(\"data_2\")\n    dtype, shape, policy = meta\n    assert dtype == F.dtype(data_2)\n    assert shape == F.shape(data_2)\n    assert policy.policy_str == \"node~_N\"\n\n    # Test push and pull\n    id_tensor = F.tensor([0, 2, 4], F.int64)\n    data_tensor = F.tensor([[6.0, 6.0], [6.0, 6.0], [6.0, 6.0]], 
F.float32)\n    kvclient.push(name=\"data_0\", id_tensor=id_tensor, data_tensor=data_tensor)\n    kvclient.push(name=\"data_1\", id_tensor=id_tensor, data_tensor=data_tensor)\n    kvclient.push(name=\"data_2\", id_tensor=id_tensor, data_tensor=data_tensor)\n    res = kvclient.pull(name=\"data_0\", id_tensor=id_tensor)\n    assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))\n    res = kvclient.pull(name=\"data_1\", id_tensor=id_tensor)\n    assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))\n    res = kvclient.pull(name=\"data_2\", id_tensor=id_tensor)\n    assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))\n    # Register new push handler\n    kvclient.register_push_handler(\"data_0\", udf_push)\n    kvclient.register_push_handler(\"data_1\", udf_push)\n    kvclient.register_push_handler(\"data_2\", udf_push)\n    # Test push and pull\n    kvclient.push(name=\"data_0\", id_tensor=id_tensor, data_tensor=data_tensor)\n    kvclient.push(name=\"data_1\", id_tensor=id_tensor, data_tensor=data_tensor)\n    kvclient.push(name=\"data_2\", id_tensor=id_tensor, data_tensor=data_tensor)\n    kvclient.barrier()\n    data_tensor = data_tensor * data_tensor\n    res = kvclient.pull(name=\"data_0\", id_tensor=id_tensor)\n    assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))\n    res = kvclient.pull(name=\"data_1\", id_tensor=id_tensor)\n    assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))\n    res = kvclient.pull(name=\"data_2\", id_tensor=id_tensor)\n    assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))\n\n    # Test delete data\n    kvclient.delete_data(\"data_0\")\n    kvclient.delete_data(\"data_1\")\n    kvclient.delete_data(\"data_2\")\n\n    # Register new push handler\n    kvclient.init_data(\n        name=\"data_3\",\n        shape=F.shape(data_2),\n        dtype=F.dtype(data_2),\n        part_policy=node_policy,\n        init_func=init_zero_func,\n    )\n    kvclient.register_push_handler(\"data_3\", add_push)\n    
data_tensor = F.tensor([[6.0, 6.0], [6.0, 6.0], [6.0, 6.0]], F.float32)\n    kvclient.barrier()\n    time.sleep(kvclient.client_id + 1)\n    print(\"add...\")\n    kvclient.push(name=\"data_3\", id_tensor=id_tensor, data_tensor=data_tensor)\n    kvclient.barrier()\n    res = kvclient.pull(name=\"data_3\", id_tensor=id_tensor)\n    data_tensor = data_tensor * num_clients\n    assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))\n\n\ndef start_client_mul_role(i):\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    # Initialize creates kvstore !\n    dgl.distributed.initialize(ip_config=\"kv_ip_mul_config.txt\")\n    if i == 0:  # block one trainer\n        time.sleep(5)\n    kvclient = dgl.distributed.kvstore.get_kvstore()\n    kvclient.barrier()\n    print(\"i: %d role: %s\" % (i, kvclient.role))\n\n    assert dgl.distributed.role.get_num_trainers() == 2\n    assert dgl.distributed.role.get_trainer_rank() < 2\n    print(\n        \"trainer rank: %d, global rank: %d\"\n        % (\n            dgl.distributed.role.get_trainer_rank(),\n            dgl.distributed.role.get_global_rank(),\n        )\n    )\n    dgl.distributed.exit_client()\n\n\n@unittest.skipIf(\n    os.name == \"nt\" or os.getenv(\"DGLBACKEND\") == \"tensorflow\",\n    reason=\"Do not support windows and TF yet\",\n)\ndef test_kv_store():\n    reset_envs()\n    num_servers = 2\n    num_clients = 2\n    generate_ip_config(\"kv_ip_config.txt\", 1, num_servers)\n    ctx = mp.get_context(\"spawn\")\n    pserver_list = []\n    pclient_list = []\n    os.environ[\"DGL_NUM_SERVER\"] = str(num_servers)\n    for i in range(num_servers):\n        pserver = ctx.Process(\n            target=start_server, args=(i, num_clients, num_servers)\n        )\n        pserver.start()\n        pserver_list.append(pserver)\n    for i in range(num_clients):\n        pclient = ctx.Process(\n            target=start_client, args=(num_clients, num_servers)\n        )\n        pclient.start()\n        
pclient_list.append(pclient)\n    for i in range(num_clients):\n        pclient_list[i].join()\n    for i in range(num_servers):\n        pserver_list[i].join()\n\n\n@unittest.skipIf(\n    os.name == \"nt\" or os.getenv(\"DGLBACKEND\") == \"tensorflow\",\n    reason=\"Do not support windows and TF yet\",\n)\ndef test_kv_multi_role():\n    reset_envs()\n    num_servers = 2\n    num_trainers = 2\n    num_samplers = 2\n    generate_ip_config(\"kv_ip_mul_config.txt\", 1, num_servers)\n    # There are two trainer processes and each trainer process has two sampler processes.\n    num_clients = num_trainers * (1 + num_samplers)\n    ctx = mp.get_context(\"spawn\")\n    pserver_list = []\n    pclient_list = []\n    os.environ[\"DGL_NUM_SAMPLER\"] = str(num_samplers)\n    os.environ[\"DGL_NUM_SERVER\"] = str(num_servers)\n    for i in range(num_servers):\n        pserver = ctx.Process(\n            target=start_server_mul_role, args=(i, num_clients, num_servers)\n        )\n        pserver.start()\n        pserver_list.append(pserver)\n    for i in range(num_trainers):\n        pclient = ctx.Process(target=start_client_mul_role, args=(i,))\n        pclient.start()\n        pclient_list.append(pclient)\n    for i in range(num_trainers):\n        pclient_list[i].join()\n    for i in range(num_servers):\n        pserver_list[i].join()\n\n\nif __name__ == \"__main__\":\n    test_partition_policy()\n    test_kv_store()\n    test_kv_multi_role()\n"
  },
  {
    "path": "tests/distributed/test_partition.py",
    "content": "import json\nimport os\nimport tempfile\n\nimport dgl\n\nimport dgl.backend as F\nimport dgl.graphbolt as gb\nimport numpy as np\nimport pytest\nimport torch as th\nfrom dgl import function as fn\nfrom dgl.distributed import (\n    dgl_partition_to_graphbolt,\n    load_partition,\n    load_partition_book,\n    load_partition_feats,\n    partition_graph,\n)\nfrom dgl.distributed.graph_partition_book import (\n    _etype_str_to_tuple,\n    _etype_tuple_to_str,\n    DEFAULT_ETYPE,\n    DEFAULT_NTYPE,\n    EdgePartitionPolicy,\n    HeteroDataName,\n    NodePartitionPolicy,\n    RangePartitionBook,\n)\nfrom dgl.distributed.partition import (\n    _get_inner_edge_mask,\n    _get_inner_node_mask,\n    RESERVED_FIELD_DTYPE,\n)\nfrom scipy import sparse as spsp\nfrom utils import reset_envs\n\n\ndef _verify_partition_data_types(part_g):\n    \"\"\"\n    check list:\n        make sure nodes and edges have correct type.\n    \"\"\"\n    ndata = (\n        part_g.node_attributes\n        if isinstance(part_g, gb.FusedCSCSamplingGraph)\n        else part_g.ndata\n    )\n    edata = (\n        part_g.edge_attributes\n        if isinstance(part_g, gb.FusedCSCSamplingGraph)\n        else part_g.edata\n    )\n\n    for k, dtype in RESERVED_FIELD_DTYPE.items():\n        if k in ndata:\n            assert ndata[k].dtype == dtype\n        if k in edata:\n            assert edata[k].dtype == dtype\n\n\ndef _verify_partition_formats(part_g, formats):\n    # verify saved graph formats\n    if formats is None:\n        assert \"coo\" in part_g.formats()[\"created\"]\n    else:\n        for format in formats:\n            assert format in part_g.formats()[\"created\"]\n\n\ndef create_random_graph(n):\n    arr = (\n        spsp.random(n, n, density=0.001, format=\"coo\", random_state=100) != 0\n    ).astype(np.int64)\n    return dgl.from_scipy(arr)\n\n\ndef create_random_hetero():\n    num_nodes = {\"n1\": 1000, \"n2\": 1010, \"n3\": 1020}\n    etypes = [\n        (\"n1\", 
\"r1\", \"n2\"),\n        (\"n2\", \"r1\", \"n1\"),\n        (\"n1\", \"r2\", \"n3\"),\n        (\"n2\", \"r3\", \"n3\"),\n    ]\n    edges = {}\n    for etype in etypes:\n        src_ntype, _, dst_ntype = etype\n        arr = spsp.random(\n            num_nodes[src_ntype],\n            num_nodes[dst_ntype],\n            density=0.001,\n            format=\"coo\",\n            random_state=100,\n        )\n        edges[etype] = (arr.row, arr.col)\n    return dgl.heterograph(edges, num_nodes)\n\n\ndef _verify_graphbolt_attributes(\n    parts, store_inner_node, store_inner_edge, store_eids\n):\n    \"\"\"\n    check list:\n        make sure arguments work.\n    \"\"\"\n    for part in parts:\n        assert store_inner_edge == (\"inner_edge\" in part.edge_attributes)\n        assert store_inner_node == (\"inner_node\" in part.node_attributes)\n        assert store_eids == (dgl.EID in part.edge_attributes)\n\n\ndef _verify_hetero_graph_node_edge_num(\n    g,\n    parts,\n    store_inner_edge,\n    debug_mode,\n):\n    \"\"\"\n    check list:\n        make sure edge type are correct.\n        make sure the number of nodes in each node type are correct.\n        make sure the number of nodes in each node type are correct.\n    \"\"\"\n    num_nodes = {ntype: 0 for ntype in g.ntypes}\n    num_edges = {etype: 0 for etype in g.canonical_etypes}\n    for part in parts:\n        edata = (\n            part.edge_attributes\n            if isinstance(part, gb.FusedCSCSamplingGraph)\n            else part.edata\n        )\n        if dgl.ETYPE in edata:\n            # edata may not contain all edge types.\n            assert len(g.canonical_etypes) >= len(F.unique(edata[dgl.ETYPE]))\n        if debug_mode or isinstance(part, dgl.DGLGraph):\n            for ntype in g.ntypes:\n                ntype_id = g.get_ntype_id(ntype)\n                inner_node_mask = _get_inner_node_mask(part, ntype_id)\n                num_inner_nodes = F.sum(F.astype(inner_node_mask, F.int64), 0)\n   
             num_nodes[ntype] += num_inner_nodes\n        if store_inner_edge or isinstance(part, dgl.DGLGraph):\n            for etype in g.canonical_etypes:\n                etype_id = g.get_etype_id(etype)\n                inner_edge_mask = _get_inner_edge_mask(part, etype_id)\n                num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0)\n                num_edges[etype] += num_inner_edges\n\n    # Verify the number of nodes are correct.\n    if debug_mode or isinstance(part, dgl.DGLGraph):\n        for ntype in g.ntypes:\n            print(\n                \"node {}: {}, {}\".format(\n                    ntype, g.num_nodes(ntype), num_nodes[ntype]\n                )\n            )\n            assert g.num_nodes(ntype) == num_nodes[ntype]\n    # Verify the number of edges are correct.\n    if store_inner_edge or isinstance(part, dgl.DGLGraph):\n        for etype in g.canonical_etypes:\n            print(\n                \"edge {}: {}, {}\".format(\n                    etype, g.num_edges(etype), num_edges[etype]\n                )\n            )\n            assert g.num_edges(etype) == num_edges[etype]\n\n\ndef _verify_edge_id_range_hetero(\n    g,\n    part,\n    eids,\n):\n    \"\"\"\n    check list:\n        make sure inner_eids fall into a range.\n        make sure all edges are included.\n    \"\"\"\n    edata = (\n        part.edge_attributes\n        if isinstance(part, gb.FusedCSCSamplingGraph)\n        else part.edata\n    )\n    etype = (\n        part.type_per_edge\n        if isinstance(part, gb.FusedCSCSamplingGraph)\n        else edata[dgl.ETYPE]\n    )\n    eid = th.arange(len(edata[dgl.EID]))\n    etype_arr = F.gather_row(etype, eid)\n    eid_arr = F.gather_row(edata[dgl.EID], eid)\n    for etype in g.canonical_etypes:\n        etype_id = g.get_etype_id(etype)\n        eids[etype].append(F.boolean_mask(eid_arr, etype_arr == etype_id))\n        # Make sure edge Ids fall into a range.\n        inner_edge_mask = 
_get_inner_edge_mask(part, etype_id)\n        inner_eids = np.sort(\n            F.asnumpy(F.boolean_mask(edata[dgl.EID], inner_edge_mask))\n        )\n        assert np.all(\n            inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1)\n        )\n    return eids\n\n\ndef _verify_node_id_range_hetero(g, part, nids):\n    \"\"\"\n    check list:\n        make sure inner nodes have Ids fall into a range.\n    \"\"\"\n    for ntype in g.ntypes:\n        ntype_id = g.get_ntype_id(ntype)\n        # Make sure inner nodes have Ids fall into a range.\n        inner_node_mask = _get_inner_node_mask(part, ntype_id)\n        inner_nids = F.boolean_mask(\n            part.node_attributes[dgl.NID], inner_node_mask\n        )\n        assert np.all(\n            F.asnumpy(\n                inner_nids\n                == F.arange(\n                    F.as_scalar(inner_nids[0]),\n                    F.as_scalar(inner_nids[-1]) + 1,\n                )\n            )\n        )\n        nids[ntype].append(inner_nids)\n    return nids\n\n\ndef _verify_graph_attributes_hetero(\n    g,\n    parts,\n    store_inner_edge,\n    store_inner_node,\n):\n    \"\"\"\n    check list:\n        make sure edge ids fall into a range.\n        make sure inner nodes have Ids fall into a range.\n        make sure all nodes is included.\n        make sure all edges is included.\n    \"\"\"\n    nids = {ntype: [] for ntype in g.ntypes}\n    eids = {etype: [] for etype in g.canonical_etypes}\n    # check edge id.\n    if store_inner_edge or isinstance(parts[0], dgl.DGLGraph):\n        for part in parts:\n            # collect eids\n            eids = _verify_edge_id_range_hetero(g, part, eids)\n        for etype in eids:\n            eids_type = F.cat(eids[etype], 0)\n            uniq_ids = F.unique(eids_type)\n            # We should get all nodes.\n            assert len(uniq_ids) == g.num_edges(etype)\n\n    # check node id.\n    if store_inner_node or isinstance(parts[0], dgl.DGLGraph):\n 
       for part in parts:\n            nids = _verify_node_id_range_hetero(g, part, nids)\n        for ntype in nids:\n            nids_type = F.cat(nids[ntype], 0)\n            uniq_ids = F.unique(nids_type)\n            # We should get all nodes.\n            assert len(uniq_ids) == g.num_nodes(ntype)\n\n\ndef _verify_hetero_graph(\n    g,\n    parts,\n    store_eids=False,\n    store_inner_edge=False,\n    store_inner_node=False,\n    debug_mode=False,\n):\n    _verify_hetero_graph_node_edge_num(\n        g,\n        parts,\n        store_inner_edge=store_inner_edge,\n        debug_mode=debug_mode,\n    )\n    if store_eids:\n        _verify_graph_attributes_hetero(\n            g,\n            parts,\n            store_inner_edge=store_inner_edge,\n            store_inner_node=store_inner_node,\n        )\n\n\ndef _verify_node_feats(g, part, gpb, orig_nids, node_feats, is_homo=False):\n    for ntype in g.ntypes:\n        ndata = (\n            part.node_attributes\n            if isinstance(part, gb.FusedCSCSamplingGraph)\n            else part.ndata\n        )\n        ntype_id = g.get_ntype_id(ntype)\n        inner_node_mask = _get_inner_node_mask(\n            part,\n            ntype_id,\n            (gpb if isinstance(part, gb.FusedCSCSamplingGraph) else None),\n        )\n        inner_nids = F.boolean_mask(ndata[dgl.NID], inner_node_mask)\n        ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids)\n        partid = gpb.nid2partid(inner_type_nids, ntype)\n        if is_homo:\n            assert np.all(F.asnumpy(ntype_ids) == ntype_id)\n            assert np.all(F.asnumpy(partid) == gpb.partid)\n\n        if is_homo:\n            orig_id = orig_nids[inner_type_nids]\n        else:\n            orig_id = orig_nids[ntype][inner_type_nids]\n        local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype)\n\n        for name in g.nodes[ntype].data:\n            if name in [dgl.NID, \"inner_node\"]:\n                continue\n            
true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id)\n            ndata = F.gather_row(node_feats[ntype + \"/\" + name], local_nids)\n            assert np.all(F.asnumpy(ndata == true_feats))\n\n\ndef _verify_edge_feats(g, part, gpb, orig_eids, edge_feats, is_homo=False):\n    for etype in g.canonical_etypes:\n        edata = (\n            part.edge_attributes\n            if isinstance(part, gb.FusedCSCSamplingGraph)\n            else part.edata\n        )\n        etype_id = g.get_etype_id(etype)\n        inner_edge_mask = _get_inner_edge_mask(part, etype_id)\n        inner_eids = F.boolean_mask(edata[dgl.EID], inner_edge_mask)\n        etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids)\n        partid = gpb.eid2partid(inner_type_eids, etype)\n        assert np.all(F.asnumpy(etype_ids) == etype_id)\n        assert np.all(F.asnumpy(partid) == gpb.partid)\n\n        if is_homo:\n            orig_id = orig_eids[inner_type_eids]\n        else:\n            orig_id = orig_eids[etype][inner_type_eids]\n        local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype)\n\n        for name in g.edges[etype].data:\n            if name in [dgl.EID, \"inner_edge\"]:\n                continue\n            true_feats = F.gather_row(g.edges[etype].data[name], orig_id)\n            edata = F.gather_row(\n                edge_feats[_etype_tuple_to_str(etype) + \"/\" + name],\n                local_eids,\n            )\n            assert np.all(F.asnumpy(edata == true_feats))\n\n\ndef verify_graph_feats_hetero_dgl(\n    g,\n    gpb,\n    part,\n    node_feats,\n    edge_feats,\n    orig_nids,\n    orig_eids,\n):\n    \"\"\"\n    check list:\n        make sure the feats of nodes and edges are correct\n    \"\"\"\n    _verify_node_feats(g, part, gpb, orig_nids, node_feats)\n\n    _verify_edge_feats(g, part, gpb, orig_eids, edge_feats)\n\n\ndef verify_graph_feats_gb(\n    g,\n    gpbs,\n    parts,\n    tot_node_feats,\n    tot_edge_feats,\n    
orig_nids,\n    orig_eids,\n    shuffled_labels,\n    shuffled_edata,\n    test_ntype,\n    test_etype,\n    store_inner_node=False,\n    store_inner_edge=False,\n    store_eids=False,\n    is_homo=False,\n):\n    \"\"\"\n    check list:\n        make sure the feats of nodes and edges are correct\n    \"\"\"\n    for part_id in range(len(parts)):\n        part = parts[part_id]\n        gpb = gpbs[part_id]\n        node_feats = tot_node_feats[part_id]\n        edge_feats = tot_edge_feats[part_id]\n        if store_inner_node:\n            _verify_node_feats(\n                g,\n                part,\n                gpb,\n                orig_nids,\n                node_feats,\n                is_homo=is_homo,\n            )\n        if store_inner_edge and store_eids:\n            _verify_edge_feats(\n                g,\n                part,\n                gpb,\n                orig_eids,\n                edge_feats,\n                is_homo=is_homo,\n            )\n\n    _verify_shuffled_labels_gb(\n        g,\n        shuffled_labels,\n        shuffled_edata,\n        orig_nids,\n        orig_eids,\n        test_ntype,\n        test_etype,\n    )\n\n\ndef check_hetero_partition(\n    hg,\n    part_method,\n    num_parts=4,\n    num_trainers_per_machine=1,\n    load_feats=True,\n    graph_formats=None,\n):\n    test_ntype = \"n1\"\n    test_etype = (\"n1\", \"r1\", \"n2\")\n    hg.nodes[test_ntype].data[\"labels\"] = F.arange(0, hg.num_nodes(test_ntype))\n    hg.nodes[test_ntype].data[\"feats\"] = F.tensor(\n        np.random.randn(hg.num_nodes(test_ntype), 10), F.float32\n    )\n    hg.edges[test_etype].data[\"feats\"] = F.tensor(\n        np.random.randn(hg.num_edges(test_etype), 10), F.float32\n    )\n    hg.edges[test_etype].data[\"labels\"] = F.arange(0, hg.num_edges(test_etype))\n    num_hops = 1\n\n    orig_nids, orig_eids = partition_graph(\n        hg,\n        \"test\",\n        num_parts,\n        \"/tmp/partition\",\n        num_hops=num_hops,\n    
    part_method=part_method,\n        return_mapping=True,\n        num_trainers_per_machine=num_trainers_per_machine,\n        graph_formats=graph_formats,\n    )\n    assert len(orig_nids) == len(hg.ntypes)\n    assert len(orig_eids) == len(hg.canonical_etypes)\n    for ntype in hg.ntypes:\n        assert len(orig_nids[ntype]) == hg.num_nodes(ntype)\n    for etype in hg.canonical_etypes:\n        assert len(orig_eids[etype]) == hg.num_edges(etype)\n    parts = []\n    shuffled_labels = []\n    shuffled_elabels = []\n    for i in range(num_parts):\n        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(\n            \"/tmp/partition/test.json\", i, load_feats=load_feats\n        )\n        _verify_partition_data_types(part_g)\n        _verify_partition_formats(part_g, graph_formats)\n        if not load_feats:\n            assert not node_feats\n            assert not edge_feats\n            node_feats, edge_feats = load_partition_feats(\n                \"/tmp/partition/test.json\", i\n            )\n        if num_trainers_per_machine > 1:\n            for ntype in hg.ntypes:\n                name = ntype + \"/trainer_id\"\n                assert name in node_feats\n                part_ids = F.floor_div(\n                    node_feats[name], num_trainers_per_machine\n                )\n                assert np.all(F.asnumpy(part_ids) == i)\n\n            for etype in hg.canonical_etypes:\n                name = _etype_tuple_to_str(etype) + \"/trainer_id\"\n                assert name in edge_feats\n                part_ids = F.floor_div(\n                    edge_feats[name], num_trainers_per_machine\n                )\n                assert np.all(F.asnumpy(part_ids) == i)\n        # Verify the mapping between the reshuffled IDs and the original IDs.\n        # These are partition-local IDs.\n        part_src_ids, part_dst_ids = part_g.edges()\n        # These are reshuffled global homogeneous IDs.\n        part_src_ids = 
F.gather_row(part_g.ndata[dgl.NID], part_src_ids)\n        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)\n        part_eids = part_g.edata[dgl.EID]\n        # These are reshuffled per-type IDs.\n        src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids)\n        dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids)\n        etype_ids, part_eids = gpb.map_to_per_etype(part_eids)\n        # `IdMap` is in int64 by default.\n        assert src_ntype_ids.dtype == F.int64\n        assert dst_ntype_ids.dtype == F.int64\n        assert etype_ids.dtype == F.int64\n        with pytest.raises(dgl.utils.internal.InconsistentDtypeException):\n            gpb.map_to_per_ntype(F.tensor([0], F.int32))\n        with pytest.raises(dgl.utils.internal.InconsistentDtypeException):\n            gpb.map_to_per_etype(F.tensor([0], F.int32))\n        # These are original per-type IDs.\n        for etype_id, etype in enumerate(hg.canonical_etypes):\n            if F.sum((etype_ids == etype_id), 0) == 0:\n                continue\n            part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id)\n            src_ntype_ids1 = F.boolean_mask(\n                src_ntype_ids, etype_ids == etype_id\n            )\n            part_dst_ids1 = F.boolean_mask(part_dst_ids, etype_ids == etype_id)\n            dst_ntype_ids1 = F.boolean_mask(\n                dst_ntype_ids, etype_ids == etype_id\n            )\n            part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id)\n            assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0]))\n            assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0]))\n            src_ntype = hg.ntypes[F.as_scalar(src_ntype_ids1[0])]\n            dst_ntype = hg.ntypes[F.as_scalar(dst_ntype_ids1[0])]\n            orig_src_ids1 = F.gather_row(orig_nids[src_ntype], part_src_ids1)\n            orig_dst_ids1 = F.gather_row(orig_nids[dst_ntype], part_dst_ids1)\n            orig_eids1 = 
F.gather_row(orig_eids[etype], part_eids1)\n            orig_eids2 = hg.edge_ids(orig_src_ids1, orig_dst_ids1, etype=etype)\n            assert len(orig_eids1) == len(orig_eids2)\n            assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))\n        parts.append(part_g)\n        verify_graph_feats_hetero_dgl(\n            hg, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids\n        )\n\n        shuffled_labels.append(node_feats[test_ntype + \"/labels\"])\n        shuffled_elabels.append(\n            edge_feats[_etype_tuple_to_str(test_etype) + \"/labels\"]\n        )\n    _verify_hetero_graph(hg, parts)\n    shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))\n    shuffled_elabels = F.asnumpy(F.cat(shuffled_elabels, 0))\n    orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype)\n    orig_elabels = np.zeros(\n        shuffled_elabels.shape, dtype=shuffled_elabels.dtype\n    )\n    orig_labels[F.asnumpy(orig_nids[test_ntype])] = shuffled_labels\n    orig_elabels[F.asnumpy(orig_eids[test_etype])] = shuffled_elabels\n    assert np.all(orig_labels == F.asnumpy(hg.nodes[test_ntype].data[\"labels\"]))\n    assert np.all(\n        orig_elabels == F.asnumpy(hg.edges[test_etype].data[\"labels\"])\n    )\n\n\ndef check_partition(\n    g,\n    part_method,\n    num_parts=4,\n    num_trainers_per_machine=1,\n    load_feats=True,\n    graph_formats=None,\n):\n    g.ndata[\"labels\"] = F.arange(0, g.num_nodes())\n    g.ndata[\"feats\"] = F.tensor(np.random.randn(g.num_nodes(), 10), F.float32)\n    g.edata[\"feats\"] = F.tensor(np.random.randn(g.num_edges(), 10), F.float32)\n    g.update_all(fn.copy_u(\"feats\", \"msg\"), fn.sum(\"msg\", \"h\"))\n    g.update_all(fn.copy_e(\"feats\", \"msg\"), fn.sum(\"msg\", \"eh\"))\n    num_hops = 2\n\n    orig_nids, orig_eids = partition_graph(\n        g,\n        \"test\",\n        num_parts,\n        \"/tmp/partition\",\n        num_hops=num_hops,\n        part_method=part_method,\n        
return_mapping=True,\n        num_trainers_per_machine=num_trainers_per_machine,\n        graph_formats=graph_formats,\n    )\n    part_sizes = []\n    shuffled_labels = []\n    shuffled_edata = []\n    for i in range(num_parts):\n        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(\n            \"/tmp/partition/test.json\", i, load_feats=load_feats\n        )\n        _verify_partition_data_types(part_g)\n        _verify_partition_formats(part_g, graph_formats)\n        if not load_feats:\n            assert not node_feats\n            assert not edge_feats\n            node_feats, edge_feats = load_partition_feats(\n                \"/tmp/partition/test.json\", i\n            )\n        if num_trainers_per_machine > 1:\n            for ntype in g.ntypes:\n                name = ntype + \"/trainer_id\"\n                assert name in node_feats\n                part_ids = F.floor_div(\n                    node_feats[name], num_trainers_per_machine\n                )\n                assert np.all(F.asnumpy(part_ids) == i)\n\n            for etype in g.canonical_etypes:\n                name = _etype_tuple_to_str(etype) + \"/trainer_id\"\n                assert name in edge_feats\n                part_ids = F.floor_div(\n                    edge_feats[name], num_trainers_per_machine\n                )\n                assert np.all(F.asnumpy(part_ids) == i)\n\n        # Check the metadata\n        assert gpb._num_nodes() == g.num_nodes()\n        assert gpb._num_edges() == g.num_edges()\n\n        assert gpb.num_partitions() == num_parts\n        gpb_meta = gpb.metadata()\n        assert len(gpb_meta) == num_parts\n        assert len(gpb.partid2nids(i)) == gpb_meta[i][\"num_nodes\"]\n        assert len(gpb.partid2eids(i)) == gpb_meta[i][\"num_edges\"]\n        part_sizes.append((gpb_meta[i][\"num_nodes\"], gpb_meta[i][\"num_edges\"]))\n\n        nid = F.boolean_mask(part_g.ndata[dgl.NID], part_g.ndata[\"inner_node\"])\n        local_nid = 
gpb.nid2localnid(nid, i)\n        assert F.dtype(local_nid) in (F.int64, F.int32)\n        assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid)))\n        eid = F.boolean_mask(part_g.edata[dgl.EID], part_g.edata[\"inner_edge\"])\n        local_eid = gpb.eid2localeid(eid, i)\n        assert F.dtype(local_eid) in (F.int64, F.int32)\n        assert np.all(F.asnumpy(local_eid) == np.arange(0, len(local_eid)))\n\n        # Check the node map.\n        local_nodes = F.boolean_mask(\n            part_g.ndata[dgl.NID], part_g.ndata[\"inner_node\"]\n        )\n        llocal_nodes = F.nonzero_1d(part_g.ndata[\"inner_node\"])\n        local_nodes1 = gpb.partid2nids(i)\n        assert F.dtype(local_nodes1) in (F.int32, F.int64)\n        assert np.all(\n            np.sort(F.asnumpy(local_nodes)) == np.sort(F.asnumpy(local_nodes1))\n        )\n        assert np.all(F.asnumpy(llocal_nodes) == np.arange(len(llocal_nodes)))\n\n        # Check the edge map.\n        local_edges = F.boolean_mask(\n            part_g.edata[dgl.EID], part_g.edata[\"inner_edge\"]\n        )\n        llocal_edges = F.nonzero_1d(part_g.edata[\"inner_edge\"])\n        local_edges1 = gpb.partid2eids(i)\n        assert F.dtype(local_edges1) in (F.int32, F.int64)\n        assert np.all(\n            np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(local_edges1))\n        )\n        assert np.all(F.asnumpy(llocal_edges) == np.arange(len(llocal_edges)))\n\n        # Verify the mapping between the reshuffled IDs and the original IDs.\n        part_src_ids, part_dst_ids = part_g.edges()\n        part_src_ids = F.gather_row(part_g.ndata[dgl.NID], part_src_ids)\n        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)\n        part_eids = part_g.edata[dgl.EID]\n        orig_src_ids = F.gather_row(orig_nids, part_src_ids)\n        orig_dst_ids = F.gather_row(orig_nids, part_dst_ids)\n        orig_eids1 = F.gather_row(orig_eids, part_eids)\n        orig_eids2 = 
g.edge_ids(orig_src_ids, orig_dst_ids)\n        assert F.shape(orig_eids1)[0] == F.shape(orig_eids2)[0]\n        assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))\n\n        local_orig_nids = orig_nids[part_g.ndata[dgl.NID]]\n        local_orig_eids = orig_eids[part_g.edata[dgl.EID]]\n        part_g.ndata[\"feats\"] = F.gather_row(g.ndata[\"feats\"], local_orig_nids)\n        part_g.edata[\"feats\"] = F.gather_row(g.edata[\"feats\"], local_orig_eids)\n        local_nodes = orig_nids[local_nodes]\n        local_edges = orig_eids[local_edges]\n\n        part_g.update_all(fn.copy_u(\"feats\", \"msg\"), fn.sum(\"msg\", \"h\"))\n        part_g.update_all(fn.copy_e(\"feats\", \"msg\"), fn.sum(\"msg\", \"eh\"))\n        assert F.allclose(\n            F.gather_row(g.ndata[\"h\"], local_nodes),\n            F.gather_row(part_g.ndata[\"h\"], llocal_nodes),\n        )\n        assert F.allclose(\n            F.gather_row(g.ndata[\"eh\"], local_nodes),\n            F.gather_row(part_g.ndata[\"eh\"], llocal_nodes),\n        )\n\n        for name in [\"labels\", \"feats\"]:\n            assert \"_N/\" + name in node_feats\n            assert node_feats[\"_N/\" + name].shape[0] == len(local_nodes)\n            true_feats = F.gather_row(g.ndata[name], local_nodes)\n            ndata = F.gather_row(node_feats[\"_N/\" + name], local_nid)\n            assert np.all(F.asnumpy(true_feats) == F.asnumpy(ndata))\n        for name in [\"feats\"]:\n            efeat_name = _etype_tuple_to_str(DEFAULT_ETYPE) + \"/\" + name\n            assert efeat_name in edge_feats\n            assert edge_feats[efeat_name].shape[0] == len(local_edges)\n            true_feats = F.gather_row(g.edata[name], local_edges)\n            edata = F.gather_row(edge_feats[efeat_name], local_eid)\n            assert np.all(F.asnumpy(true_feats) == F.asnumpy(edata))\n\n        # This only works if node/edge IDs are shuffled.\n        shuffled_labels.append(node_feats[\"_N/labels\"])\n        
shuffled_edata.append(edge_feats[\"_N:_E:_N/feats\"])\n\n    # Verify that we can reconstruct node/edge data for original IDs.\n    shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))\n    shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0))\n    orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype)\n    orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype)\n    orig_labels[F.asnumpy(orig_nids)] = shuffled_labels\n    orig_edata[F.asnumpy(orig_eids)] = shuffled_edata\n    assert np.all(orig_labels == F.asnumpy(g.ndata[\"labels\"]))\n    assert np.all(orig_edata == F.asnumpy(g.edata[\"feats\"]))\n\n    node_map = []\n    edge_map = []\n    for i, (num_nodes, num_edges) in enumerate(part_sizes):\n        node_map.append(np.ones(num_nodes) * i)\n        edge_map.append(np.ones(num_edges) * i)\n    node_map = np.concatenate(node_map)\n    edge_map = np.concatenate(edge_map)\n    nid2pid = gpb.nid2partid(F.arange(0, len(node_map)))\n    assert F.dtype(nid2pid) in (F.int32, F.int64)\n    assert np.all(F.asnumpy(nid2pid) == node_map)\n    eid2pid = gpb.eid2partid(F.arange(0, len(edge_map)))\n    assert F.dtype(eid2pid) in (F.int32, F.int64)\n    assert np.all(F.asnumpy(eid2pid) == edge_map)\n\n\n@pytest.mark.parametrize(\"part_method\", [\"metis\", \"random\"])\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\n@pytest.mark.parametrize(\"num_trainers_per_machine\", [1])\n@pytest.mark.parametrize(\"load_feats\", [True, False])\n@pytest.mark.parametrize(\n    \"graph_formats\", [None, [\"csc\"], [\"coo\", \"csc\"], [\"coo\", \"csc\", \"csr\"]]\n)\ndef test_partition(\n    part_method,\n    num_parts,\n    num_trainers_per_machine,\n    load_feats,\n    graph_formats,\n):\n    os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n    if part_method == \"random\" and num_parts > 1:\n        num_trainers_per_machine = 1\n    g = create_random_graph(1000)\n    check_partition(\n        g,\n        part_method,\n        num_parts,\n        
num_trainers_per_machine,\n        load_feats,\n        graph_formats,\n    )\n    hg = create_random_hetero()\n    check_hetero_partition(\n        hg,\n        part_method,\n        num_parts,\n        num_trainers_per_machine,\n        load_feats,\n        graph_formats,\n    )\n    reset_envs()\n\n\n@pytest.mark.parametrize(\"node_map_dtype\", [F.int32, F.int64])\n@pytest.mark.parametrize(\"edge_map_dtype\", [F.int32, F.int64])\ndef test_RangePartitionBook(node_map_dtype, edge_map_dtype):\n    part_id = 1\n    num_parts = 2\n\n    # homogeneous\n    node_map = {\n        DEFAULT_NTYPE: F.tensor([[0, 1000], [1000, 2000]], dtype=node_map_dtype)\n    }\n    edge_map = {\n        DEFAULT_ETYPE: F.tensor(\n            [[0, 5000], [5000, 10000]], dtype=edge_map_dtype\n        )\n    }\n    ntypes = {DEFAULT_NTYPE: 0}\n    etypes = {DEFAULT_ETYPE: 0}\n    gpb = RangePartitionBook(\n        part_id, num_parts, node_map, edge_map, ntypes, etypes\n    )\n    assert gpb.etypes == [DEFAULT_ETYPE[1]]\n    assert gpb.canonical_etypes == [DEFAULT_ETYPE]\n    assert gpb.to_canonical_etype(DEFAULT_ETYPE[1]) == DEFAULT_ETYPE\n    ntype_ids, per_ntype_ids = gpb.map_to_per_ntype(\n        F.tensor([0, 1000], dtype=node_map_dtype)\n    )\n    assert ntype_ids.dtype == node_map_dtype\n    assert per_ntype_ids.dtype == node_map_dtype\n    assert np.all(F.asnumpy(ntype_ids) == 0)\n    assert np.all(F.asnumpy(per_ntype_ids) == [0, 1000])\n\n    etype_ids, per_etype_ids = gpb.map_to_per_etype(\n        F.tensor([0, 5000], dtype=edge_map_dtype)\n    )\n    assert etype_ids.dtype == edge_map_dtype\n    assert per_etype_ids.dtype == edge_map_dtype\n    assert np.all(F.asnumpy(etype_ids) == 0)\n    assert np.all(F.asnumpy(per_etype_ids) == [0, 5000])\n\n    node_policy = NodePartitionPolicy(gpb, DEFAULT_NTYPE)\n    assert node_policy.type_name == DEFAULT_NTYPE\n    edge_policy = EdgePartitionPolicy(gpb, DEFAULT_ETYPE)\n    assert edge_policy.type_name == DEFAULT_ETYPE\n\n    # Init via 
etype is not supported\n    node_map = {\n        \"node1\": F.tensor([[0, 1000], [1000, 2000]], dtype=node_map_dtype),\n        \"node2\": F.tensor([[0, 1000], [1000, 2000]], dtype=node_map_dtype),\n    }\n    edge_map = {\n        \"edge1\": F.tensor([[0, 5000], [5000, 10000]], dtype=edge_map_dtype)\n    }\n    ntypes = {\"node1\": 0, \"node2\": 1}\n    etypes = {\"edge1\": 0}\n    expect_except = False\n    try:\n        RangePartitionBook(\n            part_id, num_parts, node_map, edge_map, ntypes, etypes\n        )\n    except AssertionError:\n        expect_except = True\n    assert expect_except\n    expect_except = False\n    try:\n        EdgePartitionPolicy(gpb, \"edge1\")\n    except AssertionError:\n        expect_except = True\n    assert expect_except\n\n    # heterogeneous, init via canonical etype\n    node_map = {\n        \"node1\": F.tensor([[0, 1000], [1000, 2000]], dtype=node_map_dtype),\n        \"node2\": F.tensor([[0, 1000], [1000, 2000]], dtype=node_map_dtype),\n    }\n    edge_map = {\n        (\"node1\", \"edge1\", \"node2\"): F.tensor(\n            [[0, 5000], [5000, 10000]], dtype=edge_map_dtype\n        )\n    }\n    ntypes = {\"node1\": 0, \"node2\": 1}\n    etypes = {(\"node1\", \"edge1\", \"node2\"): 0}\n    c_etype = list(etypes.keys())[0]\n    gpb = RangePartitionBook(\n        part_id, num_parts, node_map, edge_map, ntypes, etypes\n    )\n    assert gpb.etypes == [\"edge1\"]\n    assert gpb.canonical_etypes == [c_etype]\n    assert gpb.to_canonical_etype(\"edge1\") == c_etype\n    assert gpb.to_canonical_etype(c_etype) == c_etype\n\n    ntype_ids, per_ntype_ids = gpb.map_to_per_ntype(\n        F.tensor([0, 1000], dtype=node_map_dtype)\n    )\n    assert ntype_ids.dtype == node_map_dtype\n    assert per_ntype_ids.dtype == node_map_dtype\n    assert np.all(F.asnumpy(ntype_ids) == 0)\n    assert np.all(F.asnumpy(per_ntype_ids) == [0, 1000])\n\n    etype_ids, per_etype_ids = gpb.map_to_per_etype(\n        F.tensor([0, 5000], 
dtype=edge_map_dtype)\n    )\n    assert etype_ids.dtype == edge_map_dtype\n    assert per_etype_ids.dtype == edge_map_dtype\n    assert np.all(F.asnumpy(etype_ids) == 0)\n    assert np.all(F.asnumpy(per_etype_ids) == [0, 5000])\n\n    expect_except = False\n    try:\n        gpb.to_canonical_etype((\"node1\", \"edge2\", \"node2\"))\n    except BaseException:\n        expect_except = True\n    assert expect_except\n    expect_except = False\n    try:\n        gpb.to_canonical_etype(\"edge2\")\n    except BaseException:\n        expect_except = True\n    assert expect_except\n\n    # NodePartitionPolicy\n    node_policy = NodePartitionPolicy(gpb, \"node1\")\n    assert node_policy.type_name == \"node1\"\n    assert node_policy.policy_str == \"node~node1\"\n    assert node_policy.part_id == part_id\n    assert node_policy.is_node\n    assert node_policy.get_data_name(\"x\").is_node()\n    local_ids = th.arange(0, 1000)\n    global_ids = local_ids + 1000\n    assert th.equal(node_policy.to_local(global_ids), local_ids)\n    assert th.all(node_policy.to_partid(global_ids) == part_id)\n    assert node_policy.get_part_size() == 1000\n    assert node_policy.get_size() == 2000\n\n    # EdgePartitionPolicy\n    edge_policy = EdgePartitionPolicy(gpb, c_etype)\n    assert edge_policy.type_name == c_etype\n    assert edge_policy.policy_str == \"edge~node1:edge1:node2\"\n    assert edge_policy.part_id == part_id\n    assert not edge_policy.is_node\n    assert not edge_policy.get_data_name(\"x\").is_node()\n    local_ids = th.arange(0, 5000)\n    global_ids = local_ids + 5000\n    assert th.equal(edge_policy.to_local(global_ids), local_ids)\n    assert th.all(edge_policy.to_partid(global_ids) == part_id)\n    assert edge_policy.get_part_size() == 5000\n    assert edge_policy.get_size() == 10000\n\n    expect_except = False\n    try:\n        HeteroDataName(False, \"edge1\", \"feat\")\n    except BaseException:\n        expect_except = True\n    assert expect_except\n    
data_name = HeteroDataName(False, c_etype, \"feat\")\n    assert data_name.get_type() == c_etype\n\n\ndef test_UnknownPartitionBook():\n    node_map = {\"_N\": {0: 0, 1: 1, 2: 2}}\n    edge_map = {\"_N:_E:_N\": {0: 0, 1: 1, 2: 2}}\n\n    part_metadata = {\n        \"num_parts\": 1,\n        \"num_nodes\": len(node_map),\n        \"num_edges\": len(edge_map),\n        \"node_map\": node_map,\n        \"edge_map\": edge_map,\n        \"graph_name\": \"test_graph\",\n    }\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        part_config = os.path.join(test_dir, \"test_graph.json\")\n        with open(part_config, \"w\") as file:\n            json.dump(part_metadata, file, indent=4)\n        try:\n            load_partition_book(part_config, 0)\n        except Exception as e:\n            if not isinstance(e, TypeError):\n                raise e\n\n\n@pytest.mark.parametrize(\"part_method\", [\"metis\", \"random\"])\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\n@pytest.mark.parametrize(\"store_eids\", [True, False])\n@pytest.mark.parametrize(\"store_inner_node\", [True, False])\n@pytest.mark.parametrize(\"store_inner_edge\", [True, False])\n@pytest.mark.parametrize(\"debug_mode\", [True, False])\ndef test_dgl_partition_to_graphbolt_homo(\n    part_method,\n    num_parts,\n    store_eids,\n    store_inner_node,\n    store_inner_edge,\n    debug_mode,\n):\n    reset_envs()\n    if debug_mode:\n        os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        g = create_random_graph(1000)\n        graph_name = \"test\"\n        partition_graph(\n            g, graph_name, num_parts, test_dir, part_method=part_method\n        )\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n        dgl_partition_to_graphbolt(\n            part_config,\n            store_eids=store_eids,\n            store_inner_node=store_inner_node,\n            store_inner_edge=store_inner_edge,\n        )\n        for 
part_id in range(num_parts):\n            orig_g = dgl.load_graphs(\n                os.path.join(test_dir, f\"part{part_id}/graph.dgl\")\n            )[0][0]\n            os.remove(os.path.join(test_dir, f\"part{part_id}/graph.dgl\"))\n            new_g = load_partition(\n                part_config, part_id, load_feats=False, use_graphbolt=True\n            )[0]\n            orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()\n            # The original graph is in int64 while the partitioned graph is in\n            # int32 as dtype formatting is applied when converting to graphbolt\n            # format.\n            assert orig_indptr.dtype == th.int64\n            assert orig_indices.dtype == th.int64\n            assert new_g.csc_indptr.dtype == th.int32\n            assert new_g.indices.dtype == th.int32\n            assert th.equal(orig_indptr, new_g.csc_indptr)\n            assert th.equal(orig_indices, new_g.indices)\n            assert new_g.node_type_offset is None\n            assert orig_g.ndata[dgl.NID].dtype == th.int64\n            assert new_g.node_attributes[dgl.NID].dtype == th.int64\n            assert th.equal(\n                orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]\n            )\n            if store_inner_node or debug_mode:\n                assert th.equal(\n                    orig_g.ndata[\"inner_node\"],\n                    new_g.node_attributes[\"inner_node\"],\n                )\n            if store_eids or debug_mode:\n                assert orig_g.edata[dgl.EID].dtype == th.int64\n                assert new_g.edge_attributes[dgl.EID].dtype == th.int64\n                assert th.equal(\n                    orig_g.edata[dgl.EID][orig_eids],\n                    new_g.edge_attributes[dgl.EID],\n                )\n            if store_inner_edge or debug_mode:\n                assert orig_g.edata[\"inner_edge\"].dtype == th.uint8\n                assert new_g.edge_attributes[\"inner_edge\"].dtype == th.uint8\n    
            assert th.equal(\n                    orig_g.edata[\"inner_edge\"][orig_eids],\n                    new_g.edge_attributes[\"inner_edge\"],\n                )\n            assert new_g.type_per_edge is None\n            assert new_g.node_type_to_id is None\n            assert new_g.edge_type_to_id is None\n\n\n@pytest.mark.parametrize(\"part_method\", [\"metis\", \"random\"])\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\n@pytest.mark.parametrize(\"store_eids\", [True, False])\n@pytest.mark.parametrize(\"store_inner_node\", [True, False])\n@pytest.mark.parametrize(\"store_inner_edge\", [True, False])\n@pytest.mark.parametrize(\"debug_mode\", [True, False])\ndef test_dgl_partition_to_graphbolt_hetero(\n    part_method,\n    num_parts,\n    store_eids,\n    store_inner_node,\n    store_inner_edge,\n    debug_mode,\n):\n    reset_envs()\n    if debug_mode:\n        os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        g = create_random_hetero()\n        graph_name = \"test\"\n        partition_graph(\n            g, graph_name, num_parts, test_dir, part_method=part_method\n        )\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n        dgl_partition_to_graphbolt(\n            part_config,\n            store_eids=store_eids,\n            store_inner_node=store_inner_node,\n            store_inner_edge=store_inner_edge,\n        )\n        for part_id in range(num_parts):\n            orig_g = dgl.load_graphs(\n                os.path.join(test_dir, f\"part{part_id}/graph.dgl\")\n            )[0][0]\n            os.remove(os.path.join(test_dir, f\"part{part_id}/graph.dgl\"))\n            new_g = load_partition(\n                part_config, part_id, load_feats=False, use_graphbolt=True\n            )[0]\n            orig_indptr, orig_indices, orig_eids = orig_g.adj().csc()\n\n            # Edges should be sorted in etype for the same dst node.\n            if debug_mode:\n              
  num_inner_edges = orig_g.edata[\"inner_edge\"].sum().item()\n                assert (\n                    num_inner_edges\n                    == orig_g.edata[\"inner_edge\"][th.arange(num_inner_edges)]\n                    .sum()\n                    .item()\n                )\n                assert (\n                    num_inner_edges\n                    == new_g.edge_attributes[\"inner_edge\"][:num_inner_edges]\n                    .sum()\n                    .item()\n                )\n                num_inner_nodes = orig_g.ndata[\"inner_node\"].sum().item()\n                assert (\n                    num_inner_nodes\n                    == orig_g.ndata[\"inner_node\"][th.arange(num_inner_nodes)]\n                    .sum()\n                    .item()\n                )\n                assert (\n                    num_inner_nodes\n                    == new_g.node_attributes[\"inner_node\"][:num_inner_nodes]\n                    .sum()\n                    .item()\n                )\n                for i in range(orig_g.num_nodes()):\n                    if orig_g.in_degrees(i) == 0:\n                        continue\n                    # Verify DGLGraph partitions.\n                    eids = orig_g.in_edges(i, form=\"eid\")\n                    etypes = orig_g.edata[dgl.ETYPE][eids]\n                    assert th.equal(etypes, etypes.sort()[0])\n                    # Verify GraphBolt partitions.\n                    eids_start = new_g.csc_indptr[i]\n                    eids_end = new_g.csc_indptr[i + 1]\n                    etypes = new_g.edge_attributes[dgl.ETYPE][\n                        eids_start:eids_end\n                    ]\n                    assert th.equal(etypes, etypes.sort()[0])\n\n            # The original graph is in int64 while the partitioned graph is in\n            # int32 as dtype formatting is applied when converting to graphbolt\n            # format.\n            assert orig_indptr.dtype == th.int64\n            
assert orig_indices.dtype == th.int64\n            assert new_g.csc_indptr.dtype == th.int32\n            assert new_g.indices.dtype == th.int32\n            assert th.equal(orig_indptr, new_g.csc_indptr)\n            assert th.equal(orig_indices, new_g.indices)\n            assert orig_g.ndata[dgl.NID].dtype == th.int64\n            assert new_g.node_attributes[dgl.NID].dtype == th.int64\n            assert th.equal(\n                orig_g.ndata[dgl.NID], new_g.node_attributes[dgl.NID]\n            )\n            if store_inner_node or debug_mode:\n                assert th.equal(\n                    orig_g.ndata[\"inner_node\"],\n                    new_g.node_attributes[\"inner_node\"],\n                )\n            if debug_mode:\n                assert orig_g.ndata[dgl.NTYPE].dtype == th.int32\n                assert new_g.node_attributes[dgl.NTYPE].dtype == th.int8\n                assert th.equal(\n                    orig_g.ndata[dgl.NTYPE], new_g.node_attributes[dgl.NTYPE]\n                )\n            if store_eids or debug_mode:\n                assert orig_g.edata[dgl.EID].dtype == th.int64\n                assert new_g.edge_attributes[dgl.EID].dtype == th.int64\n                assert th.equal(\n                    orig_g.edata[dgl.EID][orig_eids],\n                    new_g.edge_attributes[dgl.EID],\n                )\n            if store_inner_edge or debug_mode:\n                assert orig_g.edata[\"inner_edge\"].dtype == th.uint8\n                assert new_g.edge_attributes[\"inner_edge\"].dtype == th.uint8\n                assert th.equal(\n                    orig_g.edata[\"inner_edge\"],\n                    new_g.edge_attributes[\"inner_edge\"],\n                )\n            if debug_mode:\n                assert orig_g.edata[dgl.ETYPE].dtype == th.int32\n                assert new_g.edge_attributes[dgl.ETYPE].dtype == th.int8\n                assert th.equal(\n                    orig_g.edata[dgl.ETYPE][orig_eids],\n                 
   new_g.edge_attributes[dgl.ETYPE],\n                )\n            assert th.equal(\n                orig_g.edata[dgl.ETYPE][orig_eids], new_g.type_per_edge\n            )\n\n            for node_type, type_id in new_g.node_type_to_id.items():\n                assert g.get_ntype_id(node_type) == type_id\n            for edge_type, type_id in new_g.edge_type_to_id.items():\n                assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id\n            assert new_g.node_type_offset is None\n\n\ndef test_not_sorted_node_edge_map():\n    # Partition configure file which includes not sorted node/edge map.\n    part_config_str = \"\"\"\n{\n    \"edge_map\": {\n        \"item:likes-rev:user\": [\n            [\n                0,\n                100\n            ],\n            [\n                1000,\n                1500\n            ]\n        ],\n        \"user:follows-rev:user\": [\n            [\n                300,\n                600\n            ],\n            [\n                2100,\n                2800\n            ]\n        ],\n        \"user:follows:user\": [\n            [\n                100,\n                300\n            ],\n            [\n                1500,\n                2100\n            ]\n        ],\n        \"user:likes:item\": [\n            [\n                600,\n                1000\n            ],\n            [\n                2800,\n                3600\n            ]\n        ]\n    },\n    \"etypes\": {\n        \"item:likes-rev:user\": 0,\n        \"user:follows-rev:user\": 2,\n        \"user:follows:user\": 1,\n        \"user:likes:item\": 3\n    },\n    \"graph_name\": \"test_graph\",\n    \"halo_hops\": 1,\n    \"node_map\": {\n        \"user\": [\n            [\n                100,\n                300\n            ],\n            [\n                600,\n                1000\n            ]\n        ],\n        \"item\": [\n            [\n                0,\n                100\n            ],\n     
       [\n                300,\n                600\n            ]\n        ]\n    },\n    \"ntypes\": {\n        \"user\": 1,\n        \"item\": 0\n    },\n    \"num_edges\": 3600,\n    \"num_nodes\": 1000,\n    \"num_parts\": 2,\n    \"part-0\": {\n        \"edge_feats\": \"part0/edge_feat.dgl\",\n        \"node_feats\": \"part0/node_feat.dgl\",\n        \"part_graph\": \"part0/graph.dgl\"\n    },\n    \"part-1\": {\n        \"edge_feats\": \"part1/edge_feat.dgl\",\n        \"node_feats\": \"part1/node_feat.dgl\",\n        \"part_graph\": \"part1/graph.dgl\"\n    },\n    \"part_method\": \"metis\"\n}\n    \"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        part_config = os.path.join(test_dir, \"test_graph.json\")\n        with open(part_config, \"w\") as file:\n            file.write(part_config_str)\n        # Part 0.\n        gpb, _, _, _ = load_partition_book(part_config, 0)\n        assert gpb.local_ntype_offset == [0, 100, 300]\n        assert gpb.local_etype_offset == [0, 100, 300, 600, 1000]\n        # Part 1.\n        gpb, _, _, _ = load_partition_book(part_config, 1)\n        assert gpb.local_ntype_offset == [0, 300, 700]\n        assert gpb.local_etype_offset == [0, 500, 1100, 1800, 2600]\n\n\ndef _get_part_IDs(part_g):\n    # These are partition-local IDs.\n    num_columns = part_g.csc_indptr.diff()\n    part_src_ids = part_g.indices\n    part_dst_ids = th.arange(part_g.total_num_nodes).repeat_interleave(\n        num_columns\n    )\n    # These are reshuffled global homogeneous IDs.\n    part_src_ids = F.gather_row(part_g.node_attributes[dgl.NID], part_src_ids)\n    part_dst_ids = F.gather_row(part_g.node_attributes[dgl.NID], part_dst_ids)\n    return part_src_ids, part_dst_ids\n\n\ndef _verify_orig_edge_IDs_gb(\n    g,\n    orig_nids,\n    orig_eids,\n    part_eids,\n    part_src_ids,\n    part_dst_ids,\n    src_ntype=None,\n    dst_ntype=None,\n    etype=None,\n):\n    \"\"\"\n    check list:\n        make sure orig edge id are 
correct after partitioning.\n    \"\"\"\n    if src_ntype is not None and dst_ntype is not None:\n        orig_src_nid = orig_nids[src_ntype]\n        orig_dst_nid = orig_nids[dst_ntype]\n    else:\n        orig_src_nid = orig_nids\n        orig_dst_nid = orig_nids\n    orig_src_ids = F.gather_row(orig_src_nid, part_src_ids)\n    orig_dst_ids = F.gather_row(orig_dst_nid, part_dst_ids)\n    if etype is not None:\n        orig_eids = orig_eids[etype]\n    orig_eids1 = F.gather_row(orig_eids, part_eids)\n    orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids, etype=etype)\n    assert len(orig_eids1) == len(orig_eids2)\n    assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))\n\n\ndef _verify_metadata_gb(gpb, g, num_parts, part_id, part_sizes):\n    \"\"\"\n    check list:\n        make sure the number of nodes and edges is correct.\n        make sure the number of parts is correct.\n        make sure the number of nodes and edges in each part is correct.\n    \"\"\"\n    assert gpb._num_nodes() == g.num_nodes()\n    assert gpb._num_edges() == g.num_edges()\n\n    assert gpb.num_partitions() == num_parts\n    gpb_meta = gpb.metadata()\n    assert len(gpb_meta) == num_parts\n    assert len(gpb.partid2nids(part_id)) == gpb_meta[part_id][\"num_nodes\"]\n    assert len(gpb.partid2eids(part_id)) == gpb_meta[part_id][\"num_edges\"]\n    part_sizes.append(\n        (gpb_meta[part_id][\"num_nodes\"], gpb_meta[part_id][\"num_edges\"])\n    )\n\n\ndef _verify_local_id_gb(part_g, part_id, gpb):\n    \"\"\"\n    check list:\n        make sure the type of local id is correct.\n        make sure local id have a right order.\n    \"\"\"\n    nid = F.boolean_mask(\n        part_g.node_attributes[dgl.NID],\n        part_g.node_attributes[\"inner_node\"],\n    )\n    local_nid = gpb.nid2localnid(nid, part_id)\n    assert F.dtype(local_nid) in (F.int64, F.int32)\n    assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid)))\n    eid = F.boolean_mask(\n        
part_g.edge_attributes[dgl.EID],\n        part_g.edge_attributes[\"inner_edge\"],\n    )\n    local_eid = gpb.eid2localeid(eid, part_id)\n    assert F.dtype(local_eid) in (F.int64, F.int32)\n    assert np.all(np.sort(F.asnumpy(local_eid)) == np.arange(0, len(local_eid)))\n    return local_nid, local_eid\n\n\ndef _verify_map_gb(\n    part_g,\n    part_id,\n    gpb,\n):\n    \"\"\"\n    check list:\n        make sure the map node and its data type is correct.\n    \"\"\"\n    # Check the node map.\n    local_nodes = F.boolean_mask(\n        part_g.node_attributes[dgl.NID],\n        part_g.node_attributes[\"inner_node\"],\n    )\n    inner_node_index = F.nonzero_1d(part_g.node_attributes[\"inner_node\"])\n    mapping_nodes = gpb.partid2nids(part_id)\n    assert F.dtype(mapping_nodes) in (F.int32, F.int64)\n    assert np.all(\n        np.sort(F.asnumpy(local_nodes)) == np.sort(F.asnumpy(mapping_nodes))\n    )\n    assert np.all(\n        F.asnumpy(inner_node_index) == np.arange(len(inner_node_index))\n    )\n\n    # Check the edge map.\n\n    local_edges = F.boolean_mask(\n        part_g.edge_attributes[dgl.EID],\n        part_g.edge_attributes[\"inner_edge\"],\n    )\n    inner_edge_index = F.nonzero_1d(part_g.edge_attributes[\"inner_edge\"])\n    mapping_edges = gpb.partid2eids(part_id)\n    assert F.dtype(mapping_edges) in (F.int32, F.int64)\n    assert np.all(\n        np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(mapping_edges))\n    )\n    assert np.all(\n        F.asnumpy(inner_edge_index) == np.arange(len(inner_edge_index))\n    )\n    return local_nodes, local_edges\n\n\ndef _verify_local_and_map_id_gb(\n    part_g,\n    part_id,\n    gpb,\n    store_inner_node,\n    store_inner_edge,\n    store_eids,\n):\n    \"\"\"\n    check list:\n        make sure local id are correct.\n        make sure mapping id are correct.\n    \"\"\"\n    if store_inner_node and store_inner_edge and store_eids:\n        _verify_local_id_gb(part_g, part_id, gpb)\n        
_verify_map_gb(part_g, part_id, gpb)\n\n\ndef _verify_orig_IDs_gb(\n    part_g,\n    gpb,\n    g,\n    is_homo=False,\n    part_src_ids=None,\n    part_dst_ids=None,\n    src_ntype_ids=None,\n    dst_ntype_ids=None,\n    orig_nids=None,\n    orig_eids=None,\n):\n    \"\"\"\n    check list:\n        make sure orig edge id are correct.\n        make sure hetero ntype id are correct.\n    \"\"\"\n    part_eids = part_g.edge_attributes[dgl.EID]\n    if is_homo:\n        _verify_orig_edge_IDs_gb(\n            g, orig_nids, orig_eids, part_eids, part_src_ids, part_dst_ids\n        )\n        local_orig_nids = orig_nids[part_g.node_attributes[dgl.NID]]\n        local_orig_eids = orig_eids[part_g.edge_attributes[dgl.EID]]\n        part_g.node_attributes[\"feats\"] = F.gather_row(\n            g.ndata[\"feats\"], local_orig_nids\n        )\n        part_g.edge_attributes[\"feats\"] = F.gather_row(\n            g.edata[\"feats\"], local_orig_eids\n        )\n    else:\n        etype_ids, part_eids = gpb.map_to_per_etype(part_eids)\n        # `IdMap` is in int64 by default.\n        assert etype_ids.dtype == F.int64\n\n        # These are original per-type IDs.\n        for etype_id, etype in enumerate(g.canonical_etypes):\n            part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id)\n            src_ntype_ids1 = F.boolean_mask(\n                src_ntype_ids, etype_ids == etype_id\n            )\n            part_dst_ids1 = F.boolean_mask(part_dst_ids, etype_ids == etype_id)\n            dst_ntype_ids1 = F.boolean_mask(\n                dst_ntype_ids, etype_ids == etype_id\n            )\n            part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id)\n            assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0]))\n            assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0]))\n            src_ntype = g.ntypes[F.as_scalar(src_ntype_ids1[0])]\n            dst_ntype = g.ntypes[F.as_scalar(dst_ntype_ids1[0])]\n\n            
_verify_orig_edge_IDs_gb(\n                g,\n                orig_nids,\n                orig_eids,\n                part_eids1,\n                part_src_ids1,\n                part_dst_ids1,\n                src_ntype,\n                dst_ntype,\n                etype,\n            )\n\n\n@pytest.mark.parametrize(\"part_method\", [\"metis\", \"random\"])\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\n@pytest.mark.parametrize(\"store_eids\", [True, False])\n@pytest.mark.parametrize(\"store_inner_node\", [True, False])\n@pytest.mark.parametrize(\"store_inner_edge\", [True, False])\n@pytest.mark.parametrize(\"debug_mode\", [True, False])\ndef test_partition_graph_graphbolt_homo(\n    part_method,\n    num_parts,\n    store_eids,\n    store_inner_node,\n    store_inner_edge,\n    debug_mode,\n):\n    reset_envs()\n    if debug_mode:\n        os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        g = create_random_graph(1000)\n        graph_name = \"test\"\n        g.ndata[\"labels\"] = F.arange(0, g.num_nodes())\n        g.ndata[\"feats\"] = F.tensor(\n            np.random.randn(g.num_nodes(), 10), F.float32\n        )\n        g.edata[\"feats\"] = F.tensor(\n            np.random.randn(g.num_edges(), 10), F.float32\n        )\n\n        orig_nids, orig_eids = partition_graph(\n            g,\n            graph_name,\n            num_parts,\n            test_dir,\n            part_method=part_method,\n            use_graphbolt=True,\n            store_eids=store_eids,\n            store_inner_node=store_inner_node,\n            store_inner_edge=store_inner_edge,\n            return_mapping=True,\n        )\n\n        if debug_mode:\n            store_eids = store_inner_node = store_inner_edge = True\n\n        _verify_graphbolt_part(\n            g,\n            test_dir,\n            orig_nids,\n            orig_eids,\n            graph_name,\n            num_parts,\n            store_inner_node,\n            
store_inner_edge,\n            store_eids,\n            is_homo=True,\n        )\n\n\ndef _verify_constructed_id_gb(part_sizes, gpb):\n    \"\"\"\n    verify the part id of each node by constructed nids.\n    check list:\n        make sure each node's part id and its type are correct\n    \"\"\"\n    node_map = []\n    edge_map = []\n    for part_i, (num_nodes, num_edges) in enumerate(part_sizes):\n        node_map.append(np.ones(num_nodes) * part_i)\n        edge_map.append(np.ones(num_edges) * part_i)\n    node_map = np.concatenate(node_map)\n    edge_map = np.concatenate(edge_map)\n    nid2pid = gpb.nid2partid(F.arange(0, len(node_map)))\n    assert F.dtype(nid2pid) in (F.int32, F.int64)\n    assert np.all(F.asnumpy(nid2pid) == node_map)\n    eid2pid = gpb.eid2partid(F.arange(0, len(edge_map)))\n    assert F.dtype(eid2pid) in (F.int32, F.int64)\n    assert np.all(F.asnumpy(eid2pid) == edge_map)\n\n\ndef _verify_shuffled_labels_gb(\n    g,\n    shuffled_labels,\n    shuffled_edata,\n    orig_nids,\n    orig_eids,\n    test_ntype=None,\n    test_etype=None,\n):\n    \"\"\"\n    check list:\n        make sure node data are correct.\n        make sure edge data are correct.\n    \"\"\"\n    shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))\n    shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0))\n    orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype)\n    orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype)\n\n    orig_nid = orig_nids if test_ntype is None else orig_nids[test_ntype]\n    orig_eid = orig_eids if test_etype is None else orig_eids[test_etype]\n    nlabel = (\n        g.ndata[\"labels\"]\n        if test_ntype is None\n        else g.nodes[test_ntype].data[\"labels\"]\n    )\n    edata = (\n        g.edata[\"feats\"]\n        if test_etype is None\n        else g.edges[test_etype].data[\"labels\"]\n    )\n\n    orig_labels[F.asnumpy(orig_nid)] = shuffled_labels\n    orig_edata[F.asnumpy(orig_eid)] = 
shuffled_edata\n    assert np.all(orig_labels == F.asnumpy(nlabel))\n    assert np.all(orig_edata == F.asnumpy(edata))\n\n\ndef _verify_node_type_ID_gb(part_g, gpb):\n    \"\"\"\n    check list:\n        make sure ntype id have correct data type\n    \"\"\"\n    part_src_ids, part_dst_ids = _get_part_IDs(part_g)\n    # These are reshuffled per-type IDs.\n    src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids)\n    dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids)\n    # `IdMap` is in int64 by default.\n    assert src_ntype_ids.dtype == F.int64\n    assert dst_ntype_ids.dtype == F.int64\n\n    with pytest.raises(dgl.utils.internal.InconsistentDtypeException):\n        gpb.map_to_per_ntype(F.tensor([0], F.int32))\n    with pytest.raises(dgl.utils.internal.InconsistentDtypeException):\n        gpb.map_to_per_etype(F.tensor([0], F.int32))\n    return (\n        part_src_ids,\n        part_dst_ids,\n        src_ntype_ids,\n        part_src_ids,\n        dst_ntype_ids,\n    )\n\n\ndef _verify_IDs_gb(\n    g,\n    part_g,\n    part_id,\n    gpb,\n    part_sizes,\n    orig_nids,\n    orig_eids,\n    store_inner_node,\n    store_inner_edge,\n    store_eids,\n    is_homo,\n):\n    # verify local id and mapping id\n    _verify_local_and_map_id_gb(\n        part_g,\n        part_id,\n        gpb,\n        store_inner_node,\n        store_inner_edge,\n        store_eids,\n    )\n\n    # Verify the mapping between the reshuffled IDs and the original IDs.\n    (\n        part_src_ids,\n        part_dst_ids,\n        src_ntype_ids,\n        part_src_ids,\n        dst_ntype_ids,\n    ) = _verify_node_type_ID_gb(part_g, gpb)\n\n    if store_eids:\n        _verify_orig_IDs_gb(\n            part_g,\n            gpb,\n            g,\n            part_src_ids=part_src_ids,\n            part_dst_ids=part_dst_ids,\n            src_ntype_ids=src_ntype_ids,\n            dst_ntype_ids=dst_ntype_ids,\n            orig_nids=orig_nids,\n            
orig_eids=orig_eids,\n            is_homo=is_homo,\n        )\n    _verify_constructed_id_gb(part_sizes, gpb)\n\n\ndef _collect_data_gb(\n    parts,\n    part_g,\n    gpbs,\n    gpb,\n    tot_node_feats,\n    node_feats,\n    tot_edge_feats,\n    edge_feats,\n    shuffled_labels,\n    shuffled_edata,\n    test_ntype,\n    test_etype,\n):\n    if test_ntype != None:\n        shuffled_labels.append(node_feats[test_ntype + \"/labels\"])\n        shuffled_edata.append(\n            edge_feats[_etype_tuple_to_str(test_etype) + \"/labels\"]\n        )\n    else:\n        shuffled_labels.append(node_feats[\"_N/labels\"])\n        shuffled_edata.append(edge_feats[\"_N:_E:_N/feats\"])\n    parts.append(part_g)\n    gpbs.append(gpb)\n    tot_node_feats.append(node_feats)\n    tot_edge_feats.append(edge_feats)\n\n\ndef _verify_graphbolt_part(\n    g,\n    test_dir,\n    orig_nids,\n    orig_eids,\n    graph_name,\n    num_parts,\n    store_inner_node,\n    store_inner_edge,\n    store_eids,\n    test_ntype=None,\n    test_etype=None,\n    is_homo=False,\n):\n    \"\"\"\n    check list:\n        _verify_metadata_gb:\n            data type, ID's order and ID's number of edges and nodes\n        _verify_IDs_gb:\n            local id, mapping id,node type id, orig edge, hetero ntype id\n        verify_graph_feats_gb:\n            nodes and edges' feats\n        _verify_graphbolt_attributes:\n            arguments\n    \"\"\"\n    parts = []\n    tot_node_feats = []\n    tot_edge_feats = []\n    shuffled_labels = []\n    shuffled_edata = []\n    part_sizes = []\n    gpbs = []\n    part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n    # test each part\n    for part_id in range(num_parts):\n        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(\n            part_config, part_id, load_feats=True, use_graphbolt=True\n        )\n        # verify metadata\n        _verify_metadata_gb(\n            gpb,\n            g,\n            num_parts,\n            
part_id,\n            part_sizes,\n        )\n\n        # verify eid and nid\n        _verify_IDs_gb(\n            g,\n            part_g,\n            part_id,\n            gpb,\n            part_sizes,\n            orig_nids,\n            orig_eids,\n            store_inner_node,\n            store_inner_edge,\n            store_eids,\n            is_homo,\n        )\n\n        # collect shuffled data and parts\n        _collect_data_gb(\n            parts,\n            part_g,\n            gpbs,\n            gpb,\n            tot_node_feats,\n            node_feats,\n            tot_edge_feats,\n            edge_feats,\n            shuffled_labels,\n            shuffled_edata,\n            test_ntype,\n            test_etype,\n        )\n\n    # verify graph feats\n    verify_graph_feats_gb(\n        g,\n        gpbs,\n        parts,\n        tot_node_feats,\n        tot_edge_feats,\n        orig_nids,\n        orig_eids,\n        shuffled_labels=shuffled_labels,\n        shuffled_edata=shuffled_edata,\n        test_ntype=test_ntype,\n        test_etype=test_etype,\n        store_inner_node=store_inner_node,\n        store_inner_edge=store_inner_edge,\n        store_eids=store_eids,\n        is_homo=is_homo,\n    )\n\n    _verify_graphbolt_attributes(\n        parts, store_inner_node, store_inner_edge, store_eids\n    )\n\n    return parts\n\n\ndef _verify_original_IDs_type_hetero(hg, orig_nids, orig_eids):\n    \"\"\"\n    check list:\n        make sure type of nodes and edges' ids are correct.\n        make sure nodes and edges' number in each type is correct.\n    \"\"\"\n    assert len(orig_nids) == len(hg.ntypes)\n    assert len(orig_eids) == len(hg.canonical_etypes)\n    for ntype in hg.ntypes:\n        assert len(orig_nids[ntype]) == hg.num_nodes(ntype)\n        assert F.dtype(orig_nids[ntype]) in (F.int64, F.int32)\n    for etype in hg.canonical_etypes:\n        assert len(orig_eids[etype]) == hg.num_edges(etype)\n        assert F.dtype(orig_eids[etype]) 
in (F.int64, F.int32)\n\n\n@pytest.mark.parametrize(\"part_method\", [\"metis\", \"random\"])\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\n@pytest.mark.parametrize(\"store_eids\", [True, False])\n@pytest.mark.parametrize(\"store_inner_node\", [True, False])\n@pytest.mark.parametrize(\"store_inner_edge\", [True, False])\n@pytest.mark.parametrize(\"debug_mode\", [True, False])\ndef test_partition_graph_graphbolt_hetero(\n    part_method,\n    num_parts,\n    store_eids,\n    store_inner_node,\n    store_inner_edge,\n    debug_mode,\n    n_jobs=1,\n):\n    test_ntype = \"n1\"\n    test_etype = (\"n1\", \"r1\", \"n2\")\n    reset_envs()\n    if debug_mode:\n        os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        hg = create_random_hetero()\n        graph_name = \"test\"\n        hg.nodes[test_ntype].data[\"labels\"] = F.arange(\n            0, hg.num_nodes(test_ntype)\n        )\n        hg.nodes[test_ntype].data[\"feats\"] = F.tensor(\n            np.random.randn(hg.num_nodes(test_ntype), 10), F.float32\n        )\n        hg.edges[test_etype].data[\"feats\"] = F.tensor(\n            np.random.randn(hg.num_edges(test_etype), 10), F.float32\n        )\n        hg.edges[test_etype].data[\"labels\"] = F.arange(\n            0, hg.num_edges(test_etype)\n        )\n        orig_nids, orig_eids = partition_graph(\n            hg,\n            graph_name,\n            num_parts,\n            test_dir,\n            part_method=part_method,\n            return_mapping=True,\n            num_trainers_per_machine=1,\n            use_graphbolt=True,\n            store_eids=store_eids,\n            store_inner_node=store_inner_node,\n            store_inner_edge=store_inner_edge,\n            n_jobs=n_jobs,\n        )\n\n        _verify_original_IDs_type_hetero(hg, orig_nids, orig_eids)\n        if debug_mode:\n            store_eids = store_inner_node = store_inner_edge = True\n\n        parts = _verify_graphbolt_part(\n   
         hg,\n            test_dir,\n            orig_nids,\n            orig_eids,\n            graph_name,\n            num_parts,\n            store_inner_node,\n            store_inner_edge,\n            store_eids,\n            test_ntype,\n            test_etype,\n            is_homo=False,\n        )\n\n        _verify_hetero_graph(\n            hg,\n            parts,\n            store_eids=store_eids,\n            store_inner_edge=store_inner_edge,\n            debug_mode=debug_mode,\n        )\n\n\n@pytest.mark.parametrize(\"part_method\", [\"metis\", \"random\"])\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\n@pytest.mark.parametrize(\"graph_formats\", [[\"csc\"], [\"coo\"], [\"coo\", \"csc\"]])\ndef test_partition_graph_graphbolt_homo_find_edges(\n    part_method,\n    num_parts,\n    graph_formats,\n    n_jobs=1,\n):\n    reset_envs()\n    os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        g = create_random_graph(1000)\n        g.ndata[\"feat\"] = th.rand(g.num_nodes(), 5)\n        graph_name = \"test\"\n        orig_nids, orig_eids = partition_graph(\n            g,\n            graph_name,\n            num_parts,\n            test_dir,\n            part_method=part_method,\n            graph_formats=graph_formats,\n            return_mapping=True,\n            use_graphbolt=True,\n            store_eids=True,\n            store_inner_node=True,\n            store_inner_edge=True,\n            n_jobs=n_jobs,\n        )\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n        for part_id in range(num_parts):\n            local_g, _, _, gpb, _, _, _ = load_partition(\n                part_config, part_id, load_feats=False, use_graphbolt=True\n            )\n            inner_local_eids = th.nonzero(\n                local_g.edge_attributes[\"inner_edge\"], as_tuple=False\n            ).squeeze()\n            inner_global_eids = local_g.edge_attributes[dgl.EID][\n                
inner_local_eids\n            ]\n            if \"coo\" not in graph_formats:\n                with pytest.raises(\n                    ValueError,\n                    match=\"The edge attributes DGL2GB_EID and GB_DST_ID are \"\n                    \"not found. Please make sure `coo` format is available\"\n                    \" when generating partitions in GraphBolt format.\",\n                ):\n                    dgl.distributed.graph_services._find_edges(\n                        local_g, gpb, inner_global_eids\n                    )\n                continue\n            global_src, global_dst = dgl.distributed.graph_services._find_edges(\n                local_g, gpb, inner_global_eids\n            )\n            orig_global_src = orig_nids[global_src]\n            orig_global_dst = orig_nids[global_dst]\n            assert th.all(g.has_edges_between(orig_global_src, orig_global_dst))\n\n            # dtype check.\n            assert (\n                local_g.edge_attributes[dgl.distributed.DGL2GB_EID].dtype\n                == th.int32\n            )\n            assert (\n                local_g.edge_attributes[dgl.distributed.GB_DST_ID].dtype\n                == th.int32\n            )\n\n            # No need to map local node IDs.\n            inner_local_nids = th.nonzero(\n                local_g.node_attributes[\"inner_node\"], as_tuple=False\n            ).squeeze()\n            inner_global_nids = local_g.node_attributes[dgl.NID][\n                inner_local_nids\n            ]\n            assert th.equal(\n                inner_local_nids, gpb.nid2localnid(inner_global_nids, part_id)\n            )\n\n            # Need to map local edge IDs.\n            DGL_inner_local_eids = gpb.eid2localeid(inner_global_eids, part_id)\n            GB_inner_local_eids = local_g.edge_attributes[\n                dgl.distributed.DGL2GB_EID\n            ][DGL_inner_local_eids]\n            assert th.equal(inner_local_eids, 
GB_inner_local_eids)\n\n\n@pytest.mark.parametrize(\"part_method\", [\"metis\", \"random\"])\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\n@pytest.mark.parametrize(\"graph_formats\", [[\"csc\"], [\"coo\"], [\"coo\", \"csc\"]])\ndef test_partition_graph_graphbolt_hetero_find_edges(\n    part_method,\n    num_parts,\n    graph_formats,\n    n_jobs=1,\n):\n    reset_envs()\n    os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        hg = create_random_hetero()\n        graph_name = \"test\"\n        orig_nids, orig_eids = partition_graph(\n            hg,\n            graph_name,\n            num_parts,\n            test_dir,\n            part_method=part_method,\n            graph_formats=graph_formats,\n            return_mapping=True,\n            use_graphbolt=True,\n            store_eids=True,\n            store_inner_node=True,\n            store_inner_edge=True,\n            n_jobs=n_jobs,\n        )\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n        for part_id in range(num_parts):\n            local_g, _, _, gpb, _, _, _ = load_partition(\n                part_config, part_id, load_feats=False, use_graphbolt=True\n            )\n            inner_local_eids = th.nonzero(\n                local_g.edge_attributes[\"inner_edge\"], as_tuple=False\n            ).squeeze()\n            inner_global_eids = local_g.edge_attributes[dgl.EID][\n                inner_local_eids\n            ]\n            if \"coo\" not in graph_formats:\n                with pytest.raises(\n                    ValueError,\n                    match=\"The edge attributes DGL2GB_EID and GB_DST_ID are \"\n                    \"not found. 
Please make sure `coo` format is available\"\n                    \" when generating partitions in GraphBolt format.\",\n                ):\n                    dgl.distributed.graph_services._find_edges(\n                        local_g, gpb, inner_global_eids\n                    )\n                continue\n            global_src, global_dst = dgl.distributed.graph_services._find_edges(\n                local_g, gpb, inner_global_eids\n            )\n            ntype_ids_src, per_ntype_nids_src = gpb.map_to_per_ntype(global_src)\n            ntype_ids_dst, per_ntype_nids_dst = gpb.map_to_per_ntype(global_dst)\n            etype_ids, per_etype_eids = gpb.map_to_per_etype(inner_global_eids)\n            for src_ntype, etype, dst_ntype in hg.canonical_etypes:\n                etype_id = hg.get_etype_id((src_ntype, etype, dst_ntype))\n                current_etype_indices = th.nonzero(\n                    etype_ids == etype_id, as_tuple=False\n                ).squeeze()\n                assert th.all(\n                    ntype_ids_src[current_etype_indices]\n                    == gpb.ntypes.index(src_ntype)\n                )\n                assert th.all(\n                    ntype_ids_dst[current_etype_indices]\n                    == gpb.ntypes.index(dst_ntype)\n                )\n                current_per_ntype_nids_src = per_ntype_nids_src[\n                    current_etype_indices\n                ]\n                current_per_ntype_nids_dst = per_ntype_nids_dst[\n                    current_etype_indices\n                ]\n                current_orig_global_src = orig_nids[src_ntype][\n                    current_per_ntype_nids_src\n                ]\n                current_orig_global_dst = orig_nids[dst_ntype][\n                    current_per_ntype_nids_dst\n                ]\n                assert th.all(\n                    hg.has_edges_between(\n                        current_orig_global_src,\n                        
current_orig_global_dst,\n                        etype=(src_ntype, etype, dst_ntype),\n                    )\n                )\n                current_orig_global_eids = orig_eids[\n                    (src_ntype, etype, dst_ntype)\n                ][per_etype_eids[current_etype_indices]]\n                orig_src_ids, orig_dst_ids = hg.find_edges(\n                    current_orig_global_eids,\n                    etype=(src_ntype, etype, dst_ntype),\n                )\n                assert th.equal(current_orig_global_src, orig_src_ids)\n                assert th.equal(current_orig_global_dst, orig_dst_ids)\n\n            # dtype check.\n            assert (\n                local_g.edge_attributes[dgl.distributed.DGL2GB_EID].dtype\n                == th.int32\n            )\n            assert (\n                local_g.edge_attributes[dgl.distributed.GB_DST_ID].dtype\n                == th.int32\n            )\n\n            # No need to map local node IDs.\n            inner_local_nids = th.nonzero(\n                local_g.node_attributes[\"inner_node\"], as_tuple=False\n            ).squeeze()\n            inner_global_nids = local_g.node_attributes[dgl.NID][\n                inner_local_nids\n            ]\n            assert th.equal(\n                inner_local_nids, gpb.nid2localnid(inner_global_nids, part_id)\n            )\n\n            # Need to map local edge IDs.\n            DGL_inner_local_eids = gpb.eid2localeid(inner_global_eids, part_id)\n            GB_inner_local_eids = local_g.edge_attributes[\n                dgl.distributed.DGL2GB_EID\n            ][DGL_inner_local_eids]\n            assert th.equal(inner_local_eids, GB_inner_local_eids)\n\n\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\ndef test_partition_graph_graphbolt_hetero_multi(\n    num_parts,\n):\n    reset_envs()\n\n    test_partition_graph_graphbolt_hetero(\n        part_method=\"random\",\n        num_parts=num_parts,\n        n_jobs=4,\n        store_eids=True,\n     
   store_inner_node=True,\n        store_inner_edge=True,\n        debug_mode=False,\n    )\n\n\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\ndef test_partition_graph_graphbolt_homo_find_edges_multi(\n    num_parts,\n):\n    test_partition_graph_graphbolt_homo_find_edges(\n        part_method=\"random\",\n        num_parts=num_parts,\n        graph_formats=\"coo\",\n        n_jobs=4,\n    )\n\n\n@pytest.mark.parametrize(\"num_parts\", [1, 4])\ndef test_partition_graph_graphbolt_hetero_find_edges_multi(\n    num_parts,\n):\n    test_partition_graph_graphbolt_hetero_find_edges(\n        part_method=\"random\",\n        num_parts=num_parts,\n        graph_formats=\"coo\",\n        n_jobs=4,\n    )\n\n\n@pytest.mark.parametrize(\"part_method\", [\"metis\", \"random\"])\n@pytest.mark.parametrize(\"num_parts\", [4])\n@pytest.mark.parametrize(\"num_trainers_per_machine\", [1])\n@pytest.mark.parametrize(\"graph_formats\", [None])\ndef test_partition_hetero_few_edges(\n    part_method,\n    num_parts,\n    num_trainers_per_machine,\n    graph_formats,\n):\n    os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n    if part_method == \"random\" and num_parts > 1:\n        num_trainers_per_machine = 1\n\n    # Create a heterograph with 2 edges for one edge type.\n    hg = create_random_hetero()\n    edges_coo = {\n        c_etype: hg.edges(etype=c_etype) for c_etype in hg.canonical_etypes\n    }\n    edges_coo[(\"n1\", \"a0\", \"n2\")] = (th.tensor([0, 1]), th.tensor([1, 0]))\n    edges_coo[(\"n1\", \"a1\", \"n3\")] = (th.tensor([0, 1]), th.tensor([1, 0]))\n    hg = dgl.heterograph(edges_coo)\n\n    check_hetero_partition(\n        hg,\n        part_method,\n        num_parts,\n        num_trainers_per_machine,\n        load_feats=False,\n        graph_formats=graph_formats,\n    )\n    reset_envs()\n\n\n@pytest.mark.parametrize(\"part_method\", [\"metis\", \"random\"])\n@pytest.mark.parametrize(\"num_parts\", [4])\n@pytest.mark.parametrize(\"num_trainers_per_machine\", 
[1])\n@pytest.mark.parametrize(\"graph_formats\", [None])\ndef test_partition_hetero_few_nodes(\n    part_method,\n    num_parts,\n    num_trainers_per_machine,\n    graph_formats,\n):\n    os.environ[\"DGL_DIST_DEBUG\"] = \"1\"\n    if part_method == \"random\" and num_parts > 1:\n        num_trainers_per_machine = 1\n\n    # Create a heterograph with 2 nodes for one node type.\n    hg = create_random_hetero()\n    edges_coo = {\n        c_etype: hg.edges(etype=c_etype) for c_etype in hg.canonical_etypes\n    }\n    edges_coo[(\"n1\", \"r_few\", \"n_few\")] = (th.tensor([0, 1]), th.tensor([1, 0]))\n    edges_coo[(\"a0\", \"a01\", \"n_1\")] = (th.tensor([0, 1]), th.tensor([1, 0]))\n    hg = dgl.heterograph(edges_coo)\n\n    expected_exception = False\n    try:\n        check_hetero_partition(\n            hg,\n            part_method,\n            num_parts,\n            num_trainers_per_machine,\n            load_feats=False,\n            graph_formats=graph_formats,\n        )\n    except Exception as e:\n        expected_exception = True\n    assert expected_exception == (part_method == \"metis\")\n    reset_envs()\n"
  },
  {
    "path": "tests/distributed/test_rpc.py",
    "content": "import multiprocessing as mp\nimport os\nimport socket\nimport time\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport pytest\nfrom numpy.testing import assert_array_equal\nfrom utils import generate_ip_config, reset_envs\n\nif os.name != \"nt\":\n    import fcntl\n    import struct\n\nINTEGER = 2\nSTR = \"hello world!\"\nHELLO_SERVICE_ID = 901231\nTENSOR = F.zeros((1000, 1000), F.int64, F.cpu())\n\n\ndef foo(x, y):\n    assert x == 123\n    assert y == \"abc\"\n\n\nclass MyRequest(dgl.distributed.Request):\n    def __init__(self):\n        self.x = 123\n        self.y = \"abc\"\n        self.z = F.randn((3, 4))\n        self.foo = foo\n\n    def __getstate__(self):\n        return self.x, self.y, self.z, self.foo\n\n    def __setstate__(self, state):\n        self.x, self.y, self.z, self.foo = state\n\n    def process_request(self, server_state):\n        pass\n\n\nclass MyResponse(dgl.distributed.Response):\n    def __init__(self):\n        self.x = 432\n\n    def __getstate__(self):\n        return self.x\n\n    def __setstate__(self, state):\n        self.x = state\n\n\ndef simple_func(tensor):\n    return tensor\n\n\nclass HelloResponse(dgl.distributed.Response):\n    def __init__(self, hello_str, integer, tensor):\n        self.hello_str = hello_str\n        self.integer = integer\n        self.tensor = tensor\n\n    def __getstate__(self):\n        return self.hello_str, self.integer, self.tensor\n\n    def __setstate__(self, state):\n        self.hello_str, self.integer, self.tensor = state\n\n\nclass HelloRequest(dgl.distributed.Request):\n    def __init__(self, hello_str, integer, tensor, func):\n        self.hello_str = hello_str\n        self.integer = integer\n        self.tensor = tensor\n        self.func = func\n\n    def __getstate__(self):\n        return self.hello_str, self.integer, self.tensor, self.func\n\n    def __setstate__(self, state):\n        self.hello_str, self.integer, self.tensor, self.func = state\n\n    
def process_request(self, server_state):\n        assert self.hello_str == STR\n        assert self.integer == INTEGER\n        new_tensor = self.func(self.tensor)\n        res = HelloResponse(self.hello_str, self.integer, new_tensor)\n        return res\n\n\nTIMEOUT_SERVICE_ID = 123456789\nTIMEOUT_META = \"timeout_test\"\n\n\nclass TimeoutResponse(dgl.distributed.Response):\n    def __init__(self, meta):\n        self.meta = meta\n\n    def __getstate__(self):\n        return self.meta\n\n    def __setstate__(self, state):\n        self.meta = state\n\n\nclass TimeoutRequest(dgl.distributed.Request):\n    def __init__(self, meta, timeout, response=True):\n        self.meta = meta\n        self.timeout = timeout\n        self.response = response\n\n    def __getstate__(self):\n        return self.meta, self.timeout, self.response\n\n    def __setstate__(self, state):\n        self.meta, self.timeout, self.response = state\n\n    def process_request(self, server_state):\n        assert self.meta == TIMEOUT_META\n        # convert from milliseconds to seconds\n        time.sleep(self.timeout / 1000)\n        if not self.response:\n            return None\n        res = TimeoutResponse(self.meta)\n        return res\n\n\ndef start_server(\n    num_clients,\n    ip_config,\n    server_id=0,\n    num_servers=1,\n):\n    print(\"Sleep 1 seconds to test client re-connect.\")\n    time.sleep(1)\n    server_state = dgl.distributed.ServerState(\n        None, local_g=None, partition_book=None\n    )\n    dgl.distributed.register_service(\n        HELLO_SERVICE_ID, HelloRequest, HelloResponse\n    )\n    dgl.distributed.register_service(\n        TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse\n    )\n    print(\"Start server {}\".format(server_id))\n    dgl.distributed.start_server(\n        server_id=server_id,\n        ip_config=ip_config,\n        num_servers=num_servers,\n        num_clients=num_clients,\n        server_state=server_state,\n    )\n\n\ndef 
start_client(ip_config, group_id=0, num_servers=1):\n    dgl.distributed.register_service(\n        HELLO_SERVICE_ID, HelloRequest, HelloResponse\n    )\n    dgl.distributed.connect_to_server(\n        ip_config=ip_config,\n        num_servers=num_servers,\n        group_id=group_id,\n    )\n    req = HelloRequest(STR, INTEGER, TENSOR, simple_func)\n    # test send and recv\n    dgl.distributed.send_request(0, req)\n    res = dgl.distributed.recv_response()\n    assert res.hello_str == STR\n    assert res.integer == INTEGER\n    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))\n    # test remote_call\n    target_and_requests = []\n    for i in range(10):\n        target_and_requests.append((0, req))\n    res_list = dgl.distributed.remote_call(target_and_requests)\n    for res in res_list:\n        assert res.hello_str == STR\n        assert res.integer == INTEGER\n        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))\n    # test send_request_to_machine\n    dgl.distributed.send_request_to_machine(0, req)\n    res = dgl.distributed.recv_response()\n    assert res.hello_str == STR\n    assert res.integer == INTEGER\n    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))\n    # test remote_call_to_machine\n    target_and_requests = []\n    for i in range(10):\n        target_and_requests.append((0, req))\n    res_list = dgl.distributed.remote_call_to_machine(target_and_requests)\n    for res in res_list:\n        assert res.hello_str == STR\n        assert res.integer == INTEGER\n        assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))\n\n\ndef start_client_timeout(ip_config, group_id=0, num_servers=1):\n    dgl.distributed.register_service(\n        TIMEOUT_SERVICE_ID, TimeoutRequest, TimeoutResponse\n    )\n    dgl.distributed.connect_to_server(\n        ip_config=ip_config,\n        num_servers=num_servers,\n        group_id=group_id,\n    )\n    timeout = 1 * 1000  # milliseconds\n    req = 
TimeoutRequest(TIMEOUT_META, timeout)\n    # test send and recv\n    dgl.distributed.send_request(0, req)\n    res = dgl.distributed.recv_response(timeout=int(timeout / 2))\n    assert res is None\n    res = dgl.distributed.recv_response()\n    assert res.meta == TIMEOUT_META\n    # test remote_call\n    req = TimeoutRequest(TIMEOUT_META, timeout, response=False)\n    target_and_requests = []\n    for i in range(3):\n        target_and_requests.append((0, req))\n    expect_except = False\n    try:\n        res_list = dgl.distributed.remote_call(\n            target_and_requests, timeout=int(timeout / 2)\n        )\n    except dgl.DGLError:\n        expect_except = True\n    assert expect_except\n    # test send_request_to_machine\n    req = TimeoutRequest(TIMEOUT_META, timeout)\n    dgl.distributed.send_request_to_machine(0, req)\n    res = dgl.distributed.recv_response(timeout=int(timeout / 2))\n    assert res is None\n    res = dgl.distributed.recv_response()\n    assert res.meta == TIMEOUT_META\n    # test remote_call_to_machine\n    req = TimeoutRequest(TIMEOUT_META, timeout, response=False)\n    target_and_requests = []\n    for i in range(3):\n        target_and_requests.append((0, req))\n    expect_except = False\n    try:\n        res_list = dgl.distributed.remote_call_to_machine(\n            target_and_requests, timeout=int(timeout / 2)\n        )\n    except dgl.DGLError:\n        expect_except = True\n    assert expect_except\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\ndef test_rpc_timeout():\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    ip_config = \"rpc_ip_config.txt\"\n    generate_ip_config(ip_config, 1, 1)\n    ctx = mp.get_context(\"spawn\")\n    pserver = ctx.Process(target=start_server, args=(1, ip_config, 0, 1))\n    pclient = ctx.Process(target=start_client_timeout, args=(ip_config, 0, 1))\n    pserver.start()\n    pclient.start()\n    pserver.join()\n    
pclient.join()\n\n\ndef test_serialize():\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    from dgl.distributed.rpc import (\n        deserialize_from_payload,\n        serialize_to_payload,\n    )\n\n    SERVICE_ID = 12345\n    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)\n    req = MyRequest()\n    data, tensors = serialize_to_payload(req)\n    req1 = deserialize_from_payload(MyRequest, data, tensors)\n    req1.foo(req1.x, req1.y)\n    assert req.x == req1.x\n    assert req.y == req1.y\n    assert F.array_equal(req.z, req1.z)\n\n    res = MyResponse()\n    data, tensors = serialize_to_payload(res)\n    res1 = deserialize_from_payload(MyResponse, data, tensors)\n    assert res.x == res1.x\n\n\ndef test_rpc_msg():\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    from dgl.distributed.rpc import (\n        deserialize_from_payload,\n        RPCMessage,\n        serialize_to_payload,\n    )\n\n    SERVICE_ID = 32452\n    dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)\n    req = MyRequest()\n    data, tensors = serialize_to_payload(req)\n    rpcmsg = RPCMessage(SERVICE_ID, 23, 0, 1, data, tensors)\n    assert rpcmsg.service_id == SERVICE_ID\n    assert rpcmsg.msg_seq == 23\n    assert rpcmsg.client_id == 0\n    assert rpcmsg.server_id == 1\n    assert len(rpcmsg.data) == len(data)\n    assert len(rpcmsg.tensors) == 1\n    assert F.array_equal(rpcmsg.tensors[0], req.z)\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\ndef test_multi_client():\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    ip_config = \"rpc_ip_config_mul_client.txt\"\n    generate_ip_config(ip_config, 1, 1)\n    ctx = mp.get_context(\"spawn\")\n    num_clients = 20\n    pserver = ctx.Process(\n        target=start_server,\n        args=(num_clients, ip_config, 0, 1),\n    )\n    pclient_list = []\n    for i in range(num_clients):\n        pclient = 
ctx.Process(target=start_client, args=(ip_config, 0, 1))\n        pclient_list.append(pclient)\n    pserver.start()\n    for i in range(num_clients):\n        pclient_list[i].start()\n    for i in range(num_clients):\n        pclient_list[i].join()\n    pserver.join()\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\ndef test_multi_thread_rpc():\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    num_servers = 2\n    ip_config = \"rpc_ip_config_multithread.txt\"\n    generate_ip_config(ip_config, num_servers, num_servers)\n    ctx = mp.get_context(\"spawn\")\n    pserver_list = []\n    for i in range(num_servers):\n        pserver = ctx.Process(target=start_server, args=(1, ip_config, i, 1))\n        pserver.start()\n        pserver_list.append(pserver)\n\n    def start_client_multithread(ip_config):\n        import threading\n\n        dgl.distributed.connect_to_server(\n            ip_config=ip_config,\n            num_servers=1,\n        )\n        dgl.distributed.register_service(\n            HELLO_SERVICE_ID, HelloRequest, HelloResponse\n        )\n\n        req = HelloRequest(STR, INTEGER, TENSOR, simple_func)\n        dgl.distributed.send_request(0, req)\n\n        def subthread_call(server_id):\n            req = HelloRequest(STR, INTEGER, TENSOR, simple_func)\n            dgl.distributed.send_request(server_id, req)\n\n        subthread = threading.Thread(target=subthread_call, args=(1,))\n        subthread.start()\n        subthread.join()\n\n        res0 = dgl.distributed.recv_response()\n        res1 = dgl.distributed.recv_response()\n        # Order is not guaranteed\n        assert_array_equal(F.asnumpy(res0.tensor), F.asnumpy(TENSOR))\n        assert_array_equal(F.asnumpy(res1.tensor), F.asnumpy(TENSOR))\n        dgl.distributed.exit_client()\n\n    start_client_multithread(ip_config)\n    pserver.join()\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\ndef 
test_multi_client_connect():\n    reset_envs()\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    ip_config = \"rpc_ip_config_mul_client.txt\"\n    generate_ip_config(ip_config, 1, 1)\n    ctx = mp.get_context(\"spawn\")\n    num_clients = 1\n    pserver = ctx.Process(\n        target=start_server,\n        args=(num_clients, ip_config, 0, 1),\n    )\n\n    # small max try times\n    os.environ[\"DGL_DIST_MAX_TRY_TIMES\"] = \"1\"\n    expect_except = False\n    try:\n        start_client(ip_config, 0, 1)\n    except dgl.distributed.DistConnectError as err:\n        print(\"Expected error: {}\".format(err))\n        expect_except = True\n    assert expect_except\n\n    # large max try times\n    os.environ[\"DGL_DIST_MAX_TRY_TIMES\"] = \"1024\"\n    pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1))\n    pclient.start()\n    pserver.start()\n    pclient.join()\n    pserver.join()\n    reset_envs()\n\n\nif __name__ == \"__main__\":\n    test_serialize()\n    test_rpc_msg()\n    test_multi_client(\"socket\")\n    test_multi_client(\"tesnsorpipe\")\n    test_multi_thread_rpc()\n    test_multi_client_connect(\"socket\")\n"
  },
  {
    "path": "tests/distributed/utils.py",
    "content": "import os\nimport random\nimport socket\n\nimport dgl\n\nimport numpy as np\nimport scipy.sparse as spsp\n\n\ndef generate_ip_config(file_name, num_machines, num_servers):\n    \"\"\"Get local IP and available ports, writes to file.\"\"\"\n    # get available IP in localhost\n    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)\n    try:\n        # doesn't even have to be reachable\n        sock.connect((\"10.255.255.255\", 1))\n        ip = sock.getsockname()[0]\n    except ValueError:\n        ip = \"127.0.0.1\"\n    finally:\n        sock.close()\n\n    # scan available PORT\n    ports = []\n    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n    start = random.randint(10000, 30000)\n    for port in range(start, 65535):\n        try:\n            sock.connect((ip, port))\n            ports = []\n        except:\n            ports.append(port)\n            if len(ports) == num_machines * num_servers:\n                break\n    sock.close()\n    if len(ports) < num_machines * num_servers:\n        raise RuntimeError(\n            \"Failed to get available IP/PORT with required numbers.\"\n        )\n    with open(file_name, \"w\") as f:\n        for i in range(num_machines):\n            f.write(\"{} {}\\n\".format(ip, ports[i * num_servers]))\n\n\ndef reset_envs():\n    \"\"\"Reset common environment variable which are set in tests.\"\"\"\n    for key in [\n        \"DGL_ROLE\",\n        \"DGL_NUM_SAMPLER\",\n        \"DGL_NUM_SERVER\",\n        \"DGL_DIST_MODE\",\n        \"DGL_NUM_CLIENT\",\n        \"DGL_DIST_MAX_TRY_TIMES\",\n        \"DGL_DIST_DEBUG\",\n    ]:\n        if key in os.environ:\n            os.environ.pop(key)\n\n\ndef create_random_graph(n):\n    return dgl.rand_graph(n, int(n * n * 0.001))\n"
  },
  {
    "path": "tests/examples/test_sampling_examples.py",
    "content": "import os\nimport subprocess\nimport sys\nimport unittest\n\nEXAMPLE_ROOT = os.path.join(\n    os.path.dirname(os.path.relpath(__file__)),\n    \"..\",\n    \"..\",\n    \"examples\",\n    \"graphbolt\",\n    \"quickstart\",\n)\n\n\ndef test_node_classification():\n    script = os.path.join(EXAMPLE_ROOT, \"node_classification.py\")\n    out = subprocess.run([\"python\", str(script)], capture_output=True)\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert float(stdout[-5:]) > 0.59\n\n\ndef test_link_prediction():\n    script = os.path.join(EXAMPLE_ROOT, \"link_prediction.py\")\n    out = subprocess.run([\"python\", str(script)], capture_output=True)\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert float(stdout[-5:]) > 0.80\n"
  },
  {
    "path": "tests/examples/test_sparse_examples.py",
    "content": "import os\nimport subprocess\nimport sys\n\nEXAMPLE_ROOT = os.path.join(\n    os.path.dirname(os.path.relpath(__file__)),\n    \"..\",\n    \"..\",\n    \"examples\",\n    \"sparse\",\n)\n\n\ndef test_gcn():\n    script = os.path.join(EXAMPLE_ROOT, \"gcn.py\")\n    out = subprocess.run([\"python\", str(script)], capture_output=True)\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert float(stdout[-5:]) > 0.75\n\n\ndef test_gcnii():\n    script = os.path.join(EXAMPLE_ROOT, \"gcnii.py\")\n    out = subprocess.run([\"python\", str(script)], capture_output=True)\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert float(stdout[-5:]) > 0.75\n\n\ndef test_appnp():\n    script = os.path.join(EXAMPLE_ROOT, \"appnp.py\")\n    out = subprocess.run([\"python\", str(script)], capture_output=True)\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert float(stdout[-5:]) > 0.75\n\n\ndef test_c_and_s():\n    script = os.path.join(EXAMPLE_ROOT, \"c_and_s.py\")\n    out = subprocess.run([\"python\", str(script)], capture_output=True)\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert float(stdout[-5:]) > 0.7\n\n\ndef test_gat():\n    script = os.path.join(EXAMPLE_ROOT, \"gat.py\")\n    out = subprocess.run([\"python\", str(script)], capture_output=True)\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert 
float(stdout[-5:]) > 0.7\n\n\ndef test_hgnn():\n    script = os.path.join(EXAMPLE_ROOT, \"hgnn.py\")\n    out = subprocess.run([\"python\", str(script)], capture_output=True)\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert float(stdout[-5:]) >= 0.65\n\n\ndef test_hypergraphatt():\n    script = os.path.join(EXAMPLE_ROOT, \"hypergraphatt.py\")\n    out = subprocess.run(\n        [\"python\", str(script), \"--epochs=10\"], capture_output=True\n    )\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n\n\ndef test_sgc():\n    script = os.path.join(EXAMPLE_ROOT, \"sgc.py\")\n    out = subprocess.run([\"python\", str(script)], capture_output=True)\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert float(stdout[-5:]) > 0.7\n\n\ndef _test_flaky(test_fn, max_num_success=8, num_tries=10):\n    num_success = 0\n    for i in range(num_tries):\n        try:\n            test_fn()\n            num_success += 1\n        except AssertionError:\n            pass\n        # If it succeeds max_num_success / num_tries of the time.\n        if num_tries * num_success >= max_num_success * (i + 1):\n            return\n        # Early failure if required success rate is impossible now.\n        num_failure = i + 1 - num_success\n        assert num_failure <= num_tries - max_num_success\n\n\ndef _test_sign():\n    script = os.path.join(EXAMPLE_ROOT, \"sign.py\")\n    out = subprocess.run([\"python\", str(script)], capture_output=True)\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert float(stdout[-5:]) > 
0.7\n\n\ndef test_sign():\n    _test_flaky(_test_sign)\n\n\ndef test_twirls():\n    script = os.path.join(EXAMPLE_ROOT, \"twirls.py\")\n\n    out = subprocess.run([\"python\", str(script)], capture_output=True)\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert float(stdout[-5:]) > 0.7\n\n    out = subprocess.run(\n        [\"python\", str(script), \"--attention\"], capture_output=True\n    )\n    assert (\n        out.returncode == 0\n    ), f\"stdout: {out.stdout.decode('utf-8')}\\nstderr: {out.stderr.decode('utf-8')}\"\n    stdout = out.stdout.decode(\"utf-8\")\n    assert float(stdout[-5:]) > 0.65\n"
  },
  {
    "path": "tests/go/test_model.py",
    "content": "import dgl\nimport pytest\nimport torch\nfrom utils.graph_cases import get_cases\nfrom dglgo.model import *\n\n\n@pytest.mark.parametrize(\"g\", get_cases([\"has_scalar_e_feature\"]))\ndef test_gcn(g):\n    data_info = {\"num_nodes\": g.num_nodes(), \"out_size\": 7}\n    node_feat = None\n    edge_feat = g.edata[\"scalar_w\"]\n\n    # node embedding + not use_edge_weight\n    model = GCN(data_info, embed_size=10, use_edge_weight=False)\n    model(g, node_feat)\n\n    # node embedding + use_edge_weight\n    model = GCN(data_info, embed_size=10, use_edge_weight=True)\n    model(g, node_feat, edge_feat)\n\n    data_info[\"in_size\"] = g.ndata[\"h\"].shape[-1]\n    node_feat = g.ndata[\"h\"]\n\n    # node feat + not use_edge_weight\n    model = GCN(data_info, embed_size=-1, use_edge_weight=False)\n    model(g, node_feat)\n\n    # node feat + use_edge_weight\n    model = GCN(data_info, embed_size=-1, use_edge_weight=True)\n    model(g, node_feat, edge_feat)\n\n\n@pytest.mark.parametrize(\"g\", get_cases([\"block-bipartite\"]))\ndef test_gcn_block(g):\n    data_info = {\"in_size\": 10, \"out_size\": 7}\n\n    blocks = [g]\n    node_feat = torch.randn(g.num_src_nodes(), data_info[\"in_size\"])\n    edge_feat = torch.abs(torch.randn(g.num_edges()))\n    # not use_edge_weight\n    model = GCN(data_info, use_edge_weight=False)\n    model.forward_block(blocks, node_feat)\n\n    # use_edge_weight\n    model = GCN(data_info, use_edge_weight=True)\n    model.forward_block(blocks, node_feat, edge_feat)\n\n\n@pytest.mark.parametrize(\"g\", get_cases([\"has_scalar_e_feature\"]))\ndef test_gat(g):\n    data_info = {\"num_nodes\": g.num_nodes(), \"out_size\": 7}\n    node_feat = None\n\n    # node embedding\n    model = GAT(data_info, embed_size=10)\n    model(g, node_feat)\n\n    # node feat\n    data_info[\"in_size\"] = g.ndata[\"h\"].shape[-1]\n    node_feat = g.ndata[\"h\"]\n    model = GAT(data_info, embed_size=-1)\n    model(g, 
node_feat)\n\n\n@pytest.mark.parametrize(\"g\", get_cases([\"block-bipartite\"]))\ndef test_gat_block(g):\n    data_info = {\"in_size\": 10, \"out_size\": 7}\n\n    blocks = [g]\n    node_feat = torch.randn(g.num_src_nodes(), data_info[\"in_size\"])\n    model = GAT(data_info, num_layers=1, heads=[8])\n    model.forward_block(blocks, node_feat)\n\n\n@pytest.mark.parametrize(\"g\", get_cases([\"has_scalar_e_feature\"]))\ndef test_gin(g):\n    data_info = {\"num_nodes\": g.num_nodes(), \"out_size\": 7}\n    node_feat = None\n\n    # node embedding\n    model = GIN(data_info, embed_size=10)\n    model(g, node_feat)\n\n    # node feat\n    data_info[\"in_size\"] = g.ndata[\"h\"].shape[-1]\n    node_feat = g.ndata[\"h\"]\n    model = GIN(data_info, embed_size=-1)\n    model(g, node_feat)\n\n\n@pytest.mark.parametrize(\"g\", get_cases([\"has_scalar_e_feature\"]))\ndef test_sage(g):\n    data_info = {\"num_nodes\": g.num_nodes(), \"out_size\": 7}\n    node_feat = None\n    edge_feat = g.edata[\"scalar_w\"]\n\n    # node embedding\n    model = GraphSAGE(data_info, embed_size=10)\n    model(g, node_feat)\n    model(g, node_feat, edge_feat)\n\n    # node feat\n    data_info[\"in_size\"] = g.ndata[\"h\"].shape[-1]\n    node_feat = g.ndata[\"h\"]\n    model = GraphSAGE(data_info, embed_size=-1)\n    model(g, node_feat)\n    model(g, node_feat, edge_feat)\n\n\n@pytest.mark.parametrize(\"g\", get_cases([\"block-bipartite\"]))\ndef test_sage_block(g):\n    data_info = {\"in_size\": 10, \"out_size\": 7}\n\n    blocks = [g]\n    node_feat = torch.randn(g.num_src_nodes(), data_info[\"in_size\"])\n    edge_feat = torch.abs(torch.randn(g.num_edges()))\n    model = GraphSAGE(data_info, embed_size=-1)\n    model.forward_block(blocks, node_feat)\n    model.forward_block(blocks, node_feat, edge_feat)\n\n\n@pytest.mark.parametrize(\"g\", get_cases([\"has_scalar_e_feature\"]))\ndef test_sgc(g):\n    data_info = {\"num_nodes\": g.num_nodes(), \"out_size\": 7}\n    node_feat = None\n\n    # 
node embedding\n    model = SGC(data_info, embed_size=10)\n    model(g, node_feat)\n\n    # node feat\n    data_info[\"in_size\"] = g.ndata[\"h\"].shape[-1]\n    node_feat = g.ndata[\"h\"]\n    model = SGC(data_info, embed_size=-1)\n    model(g, node_feat)\n\n\ndef test_bilinear():\n    data_info = {\"in_size\": 10, \"out_size\": 1}\n    model = BilinearPredictor(data_info)\n    num_pairs = 10\n    h_src = torch.randn(num_pairs, data_info[\"in_size\"])\n    h_dst = torch.randn(num_pairs, data_info[\"in_size\"])\n    model(h_src, h_dst)\n\n\ndef test_ele():\n    data_info = {\"in_size\": 10, \"out_size\": 1}\n    model = ElementWiseProductPredictor(data_info)\n    num_pairs = 10\n    h_src = torch.randn(num_pairs, data_info[\"in_size\"])\n    h_dst = torch.randn(num_pairs, data_info[\"in_size\"])\n    model(h_src, h_dst)\n\n\n@pytest.mark.parametrize(\"virtual_node\", [True, False])\ndef test_ogbg_gin(virtual_node):\n    # Test for ogbg-mol datasets\n    data_info = {\"name\": \"ogbg-molhiv\", \"out_size\": 1}\n    model = OGBGGIN(\n        data_info, embed_size=10, num_layers=2, virtual_node=virtual_node\n    )\n    num_nodes = 5\n    num_edges = 15\n    g1 = dgl.rand_graph(num_nodes, num_edges)\n    g2 = dgl.rand_graph(num_nodes, num_edges)\n    g = dgl.batch([g1, g2])\n    num_nodes = g.num_nodes()\n    num_edges = g.num_edges()\n    nfeat = torch.zeros(num_nodes, 9).long()\n    efeat = torch.zeros(num_edges, 3).long()\n    model(g, nfeat, efeat)\n\n    # Test for non-ogbg-mol datasets\n    data_info = {\n        \"name\": \"a_dataset\",\n        \"out_size\": 1,\n        \"node_feat_size\": 15,\n        \"edge_feat_size\": 5,\n    }\n    model = OGBGGIN(\n        data_info, embed_size=10, num_layers=2, virtual_node=virtual_node\n    )\n    nfeat = torch.randn(num_nodes, data_info[\"node_feat_size\"])\n    efeat = torch.randn(num_edges, data_info[\"edge_feat_size\"])\n    model(g, nfeat, efeat)\n\n\ndef test_pna():\n    # Test for ogbg-mol datasets\n    data_info 
= {\"name\": \"ogbg-molhiv\", \"delta\": 1, \"out_size\": 1}\n    model = PNA(data_info, embed_size=10, num_layers=2)\n    num_nodes = 5\n    num_edges = 15\n    g = dgl.rand_graph(num_nodes, num_edges)\n    nfeat = torch.zeros(num_nodes, 9).long()\n    model(g, nfeat)\n\n    # Test for non-ogbg-mol datasets\n    data_info = {\n        \"name\": \"a_dataset\",\n        \"node_feat_size\": 15,\n        \"delta\": 1,\n        \"out_size\": 1,\n    }\n    model = PNA(data_info, embed_size=10, num_layers=2)\n    nfeat = torch.randn(num_nodes, data_info[\"node_feat_size\"])\n    model(g, nfeat)\n"
  },
  {
    "path": "tests/go/test_pipeline.py",
    "content": "import os\n\nimport pytest\n\n\n@pytest.mark.parametrize(\n    \"data\",\n    [\n        \"cora\",\n        \"citeseer\",\n        \"pubmed\",\n        \"csv\",\n        \"reddit\",\n        \"co-buy-computer\",\n        \"ogbn-arxiv\",\n        \"ogbn-products\",\n    ],\n)\ndef test_nodepred_data(data):\n    os.system(f\"dgl configure nodepred --data {data} --model gcn\")\n    assert os.path.exists(f\"nodepred_{data}_gcn.yaml\")\n\n    custom_cfg = f\"custom_{data}_gcn.yaml\"\n    os.system(\n        f\"dgl configure nodepred --data {data} --model gcn --cfg {custom_cfg}\"\n    )\n    assert os.path.exists(custom_cfg)\n\n    custom_script = f\"{data}_gcn.py\"\n    os.system(f\"dgl export --cfg {custom_cfg} --output {custom_script}\")\n    assert os.path.exists(custom_script)\n\n\n@pytest.mark.parametrize(\"model\", [\"gcn\", \"gat\", \"sage\", \"sgc\", \"gin\"])\ndef test_nodepred_model(model):\n    os.system(f\"dgl configure nodepred --data cora --model {model}\")\n    assert os.path.exists(f\"nodepred_cora_{model}.yaml\")\n\n    custom_cfg = f\"custom_cora_{model}.yaml\"\n    os.system(\n        f\"dgl configure nodepred --data cora --model {model} --cfg {custom_cfg}\"\n    )\n    assert os.path.exists(custom_cfg)\n\n    custom_script = f\"cora_{model}.py\"\n    os.system(f\"dgl export --cfg {custom_cfg} --output {custom_script}\")\n    assert os.path.exists(custom_script)\n\n\n@pytest.mark.parametrize(\n    \"data\",\n    [\n        \"cora\",\n        \"citeseer\",\n        \"pubmed\",\n        \"csv\",\n        \"reddit\",\n        \"co-buy-computer\",\n        \"ogbn-arxiv\",\n        \"ogbn-products\",\n    ],\n)\ndef test_nodepred_ns_data(data):\n    os.system(f\"dgl configure nodepred-ns --data {data} --model gcn\")\n    assert os.path.exists(f\"nodepred-ns_{data}_gcn.yaml\")\n\n    custom_cfg = f\"ns-custom_{data}_gcn.yaml\"\n    os.system(\n        f\"dgl configure nodepred-ns --data {data} --model gcn --cfg {custom_cfg}\"\n    )\n    
assert os.path.exists(custom_cfg)\n\n    custom_script = f\"ns-{data}_gcn.py\"\n    os.system(f\"dgl export --cfg {custom_cfg} --output {custom_script}\")\n    assert os.path.exists(custom_script)\n\n\n@pytest.mark.parametrize(\"model\", [\"gcn\", \"gat\", \"sage\"])\ndef test_nodepred_ns_model(model):\n    os.system(f\"dgl configure nodepred-ns --data cora --model {model}\")\n    assert os.path.exists(f\"nodepred-ns_cora_{model}.yaml\")\n\n    custom_cfg = f\"ns-custom_cora_{model}.yaml\"\n    os.system(\n        f\"dgl configure nodepred-ns --data cora --model {model} --cfg {custom_cfg}\"\n    )\n    assert os.path.exists(custom_cfg)\n\n    custom_script = f\"ns-cora_{model}.py\"\n    os.system(f\"dgl export --cfg {custom_cfg} --output {custom_script}\")\n    assert os.path.exists(custom_script)\n\n\n@pytest.mark.parametrize(\n    \"data\",\n    [\n        \"cora\",\n        \"citeseer\",\n        \"pubmed\",\n        \"csv\",\n        \"reddit\",\n        \"co-buy-computer\",\n        \"ogbn-arxiv\",\n        \"ogbn-products\",\n        \"ogbl-collab\",\n        \"ogbl-citation2\",\n    ],\n)\ndef test_linkpred_data(data):\n    node_model = \"gcn\"\n    edge_model = \"ele\"\n    neg_sampler = \"global\"\n    custom_cfg = \"_\".join([data, node_model, edge_model, neg_sampler]) + \".yaml\"\n    os.system(\n        \"dgl configure linkpred --data {} --node-model {} --edge-model {} --neg-sampler {} --cfg {}\".format(\n            data, node_model, edge_model, neg_sampler, custom_cfg\n        )\n    )\n    assert os.path.exists(custom_cfg)\n\n    custom_script = (\n        \"_\".join([data, node_model, edge_model, neg_sampler]) + \".py\"\n    )\n    os.system(\n        \"dgl export --cfg {} --output {}\".format(custom_cfg, custom_script)\n    )\n    assert os.path.exists(custom_script)\n\n\n@pytest.mark.parametrize(\"node_model\", [\"gcn\", \"gat\", \"sage\", \"sgc\", \"gin\"])\ndef test_linkpred_node_model(node_model):\n    data = \"cora\"\n    edge_model = 
\"ele\"\n    neg_sampler = \"global\"\n    custom_cfg = \"_\".join([data, node_model, edge_model, neg_sampler]) + \".yaml\"\n    os.system(\n        \"dgl configure linkpred --data {} --node-model {} --edge-model {} --neg-sampler {} --cfg {}\".format(\n            data, node_model, edge_model, neg_sampler, custom_cfg\n        )\n    )\n    assert os.path.exists(custom_cfg)\n\n    custom_script = (\n        \"_\".join([data, node_model, edge_model, neg_sampler]) + \".py\"\n    )\n    os.system(\n        \"dgl export --cfg {} --output {}\".format(custom_cfg, custom_script)\n    )\n    assert os.path.exists(custom_script)\n\n\n@pytest.mark.parametrize(\"edge_model\", [\"ele\", \"bilinear\"])\ndef test_linkpred_edge_model(edge_model):\n    data = \"cora\"\n    node_model = \"gcn\"\n    neg_sampler = \"global\"\n    custom_cfg = \"_\".join([data, node_model, edge_model, neg_sampler]) + \".yaml\"\n    os.system(\n        \"dgl configure linkpred --data {} --node-model {} --edge-model {} --neg-sampler {} --cfg {}\".format(\n            data, node_model, edge_model, neg_sampler, custom_cfg\n        )\n    )\n    assert os.path.exists(custom_cfg)\n\n    custom_script = (\n        \"_\".join([data, node_model, edge_model, neg_sampler]) + \".py\"\n    )\n    os.system(\n        \"dgl export --cfg {} --output {}\".format(custom_cfg, custom_script)\n    )\n    assert os.path.exists(custom_script)\n\n\n@pytest.mark.parametrize(\"neg_sampler\", [\"global\", \"persource\", \"\"])\ndef test_linkpred_neg_sampler(neg_sampler):\n    data = \"cora\"\n    node_model = \"gcn\"\n    edge_model = \"ele\"\n    custom_cfg = f\"{data}_{node_model}_{edge_model}_{neg_sampler}.yaml\"\n    if neg_sampler == \"\":\n        os.system(\n            \"dgl configure linkpred --data {} --node-model {} --edge-model {} --cfg {}\".format(\n                data, node_model, edge_model, custom_cfg\n            )\n        )\n    else:\n        os.system(\n            \"dgl configure linkpred --data {} 
--node-model {} --edge-model {} --neg-sampler {} --cfg {}\".format(\n                data, node_model, edge_model, neg_sampler, custom_cfg\n            )\n        )\n    assert os.path.exists(custom_cfg)\n\n    custom_script = f\"{data}_{node_model}_{edge_model}_{neg_sampler}.py\"\n    os.system(\n        \"dgl export --cfg {} --output {}\".format(custom_cfg, custom_script)\n    )\n    assert os.path.exists(custom_script)\n\n\n@pytest.mark.parametrize(\"data\", [\"csv\", \"ogbg-molhiv\", \"ogbg-molpcba\"])\n@pytest.mark.parametrize(\"model\", [\"gin\", \"pna\"])\ndef test_graphpred(data, model):\n    os.system(\n        \"dgl configure graphpred --data {} --model {}\".format(data, model)\n    )\n    assert os.path.exists(\"graphpred_{}_{}.yaml\".format(data, model))\n\n    custom_cfg = \"custom_{}_{}.yaml\".format(data, model)\n    os.system(\n        \"dgl configure graphpred --data {} --model {} --cfg {}\".format(\n            data, model, custom_cfg\n        )\n    )\n    assert os.path.exists(custom_cfg)\n\n    custom_script = \"_\".join([data, model]) + \".py\"\n    os.system(\n        \"dgl export --cfg {} --output {}\".format(custom_cfg, custom_script)\n    )\n    assert os.path.exists(custom_script)\n\n\n@pytest.mark.parametrize(\n    \"recipe\",\n    [\n        \"graphpred_hiv_gin.yaml\",\n        \"graphpred_hiv_pna.yaml\",\n        \"graphpred_pcba_gin.yaml\",\n        \"linkpred_cora_sage.yaml\",\n        \"linkpred_citation2_sage.yaml\",\n        \"linkpred_collab_sage.yaml\",\n        \"nodepred_citeseer_gat.yaml\",\n        \"nodepred_citeseer_gcn.yaml\",\n        \"nodepred_citeseer_sage.yaml\",\n        \"nodepred_cora_gat.yaml\",\n        \"nodepred_cora_gcn.yaml\",\n        \"nodepred_cora_sage.yaml\",\n        \"nodepred_pubmed_gat.yaml\",\n        \"nodepred_pubmed_gcn.yaml\",\n        \"nodepred_pubmed_sage.yaml\",\n        \"nodepred-ns_arxiv_gcn.yaml\",\n        \"nodepred-ns_product_sage.yaml\",\n    ],\n)\ndef test_recipe(recipe):\n    # 
Remove all generated yaml files\n    current_dir = os.listdir(\"./\")\n    for item in current_dir:\n        if item.endswith(\".yaml\"):\n            os.remove(item)\n\n    os.system(\"dgl recipe get {}\".format(recipe))\n    assert os.path.exists(recipe)\n\n\ndef test_node_cora():\n    os.system(\"dgl configure nodepred --data cora --model gcn\")\n    os.system(\"dgl train --cfg nodepred_cora_gcn.yaml\")\n    assert os.path.exists(\"results\")\n    assert os.path.exists(\"results/run_0.pth\")\n    os.system(\"dgl configure-apply nodepred --cpt results/run_0.pth\")\n    assert os.path.exists(\"apply_nodepred_cora_gcn.yaml\")\n    os.system(\n        \"dgl configure-apply nodepred --data cora --cpt results/run_0.pth --cfg apply.yaml\"\n    )\n    assert os.path.exists(\"apply.yaml\")\n    os.system(\"dgl apply --cfg apply.yaml\")\n    assert os.path.exists(\"apply_results/output.csv\")\n    os.system(\"dgl export --cfg apply.yaml --output apply.py\")\n    assert os.path.exists(\"apply.py\")\n"
  },
  {
    "path": "tests/integration/test_data.py",
    "content": "import gzip\nimport io\nimport os\nimport tarfile\nimport tempfile\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport dgl.data as data\nimport numpy as np\nimport pandas as pd\nimport pytest\nimport yaml\nfrom dgl import DGLError\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_reddit():\n    # RedditDataset\n    g = data.RedditDataset()[0]\n    assert g.num_nodes() == 232965\n    assert g.num_edges() == 114615892\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    g2 = data.RedditDataset(transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_fakenews():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    ds = data.FakeNewsDataset(\"politifact\", \"bert\")\n    assert len(ds) == 314\n    g = ds[0][0]\n    g2 = data.FakeNewsDataset(\"politifact\", \"bert\", transform=transform)[0][0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n    ds = data.FakeNewsDataset(\"gossipcop\", \"profile\")\n    assert len(ds) == 5464\n    g = ds[0][0]\n    g2 = data.FakeNewsDataset(\"gossipcop\", \"profile\", transform=transform)[0][0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_peptides_structural():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    dataset1 = 
data.PeptidesStructuralDataset()\n    g1 = dataset1[0][0]\n    dataset2 = data.PeptidesStructuralDataset(transform=transform)\n    g2 = dataset2[0][0]\n\n    assert g2.num_edges() - g1.num_edges() == g1.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_peptides_functional():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    dataset1 = data.PeptidesFunctionalDataset()\n    g1, label = dataset1[0]\n    dataset2 = data.PeptidesFunctionalDataset(transform=transform)\n    g2, _ = dataset2[0]\n\n    assert g2.num_edges() - g1.num_edges() == g1.num_nodes()\n    assert dataset1.num_classes == label.shape[0]\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_VOC_superpixels():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    dataset1 = data.VOCSuperpixelsDataset()\n    g1 = dataset1[0]\n    dataset2 = data.VOCSuperpixelsDataset(transform=transform)\n    g2 = dataset2[0]\n\n    assert g2.num_edges() - g1.num_edges() == g1.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_COCO_superpixels():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    dataset1 = data.COCOSuperpixelsDataset()\n    g1 = dataset1[0]\n    dataset2 = data.COCOSuperpixelsDataset(transform=transform)\n    g2 = dataset2[0]\n\n    assert g2.num_edges() - g1.num_edges() == g1.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on 
GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_MNIST_SuperPixel():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    dataset1 = data.MNISTSuperPixelDataset()\n    g1, _ = dataset1[0]\n    dataset2 = data.MNISTSuperPixelDataset(transform=transform)\n    g2, _ = dataset2[0]\n\n    assert g2.num_edges() - g1.num_edges() == g1.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_CIFAR10_SuperPixel():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    dataset1 = data.CIFAR10SuperPixelDataset()\n    g1, _ = dataset1[0]\n    dataset2 = data.CIFAR10SuperPixelDataset(transform=transform)\n    g2, _ = dataset2[0]\n\n    assert g2.num_edges() - g1.num_edges() == g1.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_as_graphpred():\n    ds = data.GINDataset(name=\"MUTAG\", self_loop=True)\n    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)\n    assert len(new_ds) == 188\n    assert new_ds.num_tasks == 1\n    assert new_ds.num_classes == 2\n\n    ds = data.FakeNewsDataset(\"politifact\", \"profile\")\n    new_ds = data.AsGraphPredDataset(ds, verbose=True)\n    assert len(new_ds) == 314\n    assert new_ds.num_tasks == 1\n    assert new_ds.num_classes == 2\n\n    ds = data.QM7bDataset()\n    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)\n    assert len(new_ds) == 7211\n    assert new_ds.num_tasks == 14\n    assert new_ds.num_classes is None\n\n    ds = data.QM9Dataset(label_keys=[\"mu\", \"gap\"])\n    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)\n    assert 
len(new_ds) == 130831\n    assert new_ds.num_tasks == 2\n    assert new_ds.num_classes is None\n\n    ds = data.QM9EdgeDataset(label_keys=[\"mu\", \"alpha\"])\n    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)\n    assert len(new_ds) == 130831\n    assert new_ds.num_tasks == 2\n    assert new_ds.num_classes is None\n\n    ds = data.TUDataset(\"DD\")\n    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)\n    assert len(new_ds) == 1178\n    assert new_ds.num_tasks == 1\n    assert new_ds.num_classes == 2\n\n    ds = data.LegacyTUDataset(\"DD\")\n    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)\n    assert len(new_ds) == 1178\n    assert new_ds.num_tasks == 1\n    assert new_ds.num_classes == 2\n\n    ds = data.BA2MotifDataset()\n    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)\n    assert len(new_ds) == 1000\n    assert new_ds.num_tasks == 1\n    assert new_ds.num_classes == 2\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"ogb only supports pytorch\"\n)\ndef test_as_linkpred_ogb():\n    from ogb.linkproppred import DglLinkPropPredDataset\n\n    ds = data.AsLinkPredDataset(\n        DglLinkPropPredDataset(\"ogbl-collab\"), split_ratio=None, verbose=True\n    )\n    # original dataset has 46329 test edges\n    assert ds.test_edges[0][0].shape[0] == 46329\n    # force generate new split\n    ds = data.AsLinkPredDataset(\n        DglLinkPropPredDataset(\"ogbl-collab\"),\n        split_ratio=[0.7, 0.2, 0.1],\n        verbose=True,\n    )\n    assert ds.test_edges[0][0].shape[0] == 235812\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"ogb only supports pytorch\"\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_as_nodepred_ogb():\n    from ogb.nodeproppred import DglNodePropPredDataset\n\n    ds = data.AsNodePredDataset(\n        DglNodePropPredDataset(\"ogbn-arxiv\"), 
split_ratio=None, verbose=True\n    )\n    split = DglNodePropPredDataset(\"ogbn-arxiv\").get_idx_split()\n    train_idx, val_idx, test_idx = split[\"train\"], split[\"valid\"], split[\"test\"]\n    assert F.array_equal(ds.train_idx, F.tensor(train_idx))\n    assert F.array_equal(ds.val_idx, F.tensor(val_idx))\n    assert F.array_equal(ds.test_idx, F.tensor(test_idx))\n    # force generate new split\n    ds = data.AsNodePredDataset(\n        DglNodePropPredDataset(\"ogbn-arxiv\"),\n        split_ratio=[0.7, 0.2, 0.1],\n        verbose=True,\n    )\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"ogb only supports pytorch\"\n)\ndef test_as_graphpred_ogb():\n    from ogb.graphproppred import DglGraphPropPredDataset\n\n    ds = data.AsGraphPredDataset(\n        DglGraphPropPredDataset(\"ogbg-molhiv\"), split_ratio=None, verbose=True\n    )\n    assert len(ds.train_idx) == 32901\n    # force generate new split\n    ds = data.AsGraphPredDataset(\n        DglGraphPropPredDataset(\"ogbg-molhiv\"),\n        split_ratio=[0.6, 0.2, 0.2],\n        verbose=True,\n    )\n    assert len(ds.train_idx) == 24676\n"
  },
  {
    "path": "tests/lint/clangformat_linter.py",
    "content": "\"\"\"Borrowed from github.com/pytorch/pytorch/tools/linter/adapters/clangformat_linter.py\"\"\"\nimport argparse\nimport concurrent.futures\nimport json\nimport logging\nimport os\nimport subprocess\nimport sys\nimport time\nfrom enum import Enum\nfrom pathlib import Path\nfrom typing import Any, List, NamedTuple, Optional\n\n\nIS_WINDOWS: bool = os.name == \"nt\"\n\n\ndef eprint(*args: Any, **kwargs: Any) -> None:\n    print(*args, file=sys.stderr, flush=True, **kwargs)\n\n\nclass LintSeverity(str, Enum):\n    ERROR = \"error\"\n    WARNING = \"warning\"\n    ADVICE = \"advice\"\n    DISABLED = \"disabled\"\n\n\nclass LintMessage(NamedTuple):\n    path: Optional[str]\n    line: Optional[int]\n    char: Optional[int]\n    code: str\n    severity: LintSeverity\n    name: str\n    original: Optional[str]\n    replacement: Optional[str]\n    description: Optional[str]\n\n\ndef as_posix(name: str) -> str:\n    return name.replace(\"\\\\\", \"/\") if IS_WINDOWS else name\n\n\ndef _run_command(\n    args: List[str],\n    *,\n    timeout: int,\n) -> \"subprocess.CompletedProcess[bytes]\":\n    logging.debug(\"$ %s\", \" \".join(args))\n    start_time = time.monotonic()\n    try:\n        return subprocess.run(\n            args,\n            stdout=subprocess.PIPE,\n            stderr=subprocess.PIPE,\n            shell=IS_WINDOWS,  # So batch scripts are found.\n            timeout=timeout,\n            check=True,\n        )\n    finally:\n        end_time = time.monotonic()\n        logging.debug(\"took %dms\", (end_time - start_time) * 1000)\n\n\ndef run_command(\n    args: List[str],\n    *,\n    retries: int,\n    timeout: int,\n) -> \"subprocess.CompletedProcess[bytes]\":\n    remaining_retries = retries\n    while True:\n        try:\n            return _run_command(args, timeout=timeout)\n        except subprocess.TimeoutExpired as err:\n            if remaining_retries == 0:\n                raise err\n            remaining_retries -= 1\n        
    logging.warning(\n                \"(%s/%s) Retrying because command failed with: %r\",\n                retries - remaining_retries,\n                retries,\n                err,\n            )\n            time.sleep(1)\n\n\ndef check_file(\n    filename: str,\n    binary: str,\n    retries: int,\n    timeout: int,\n) -> List[LintMessage]:\n    try:\n        with open(filename, \"rb\") as f:\n            original = f.read()\n        proc = run_command(\n            [binary, filename],\n            retries=retries,\n            timeout=timeout,\n        )\n    except subprocess.TimeoutExpired:\n        return [\n            LintMessage(\n                path=filename,\n                line=None,\n                char=None,\n                code=\"CLANGFORMAT\",\n                severity=LintSeverity.ERROR,\n                name=\"timeout\",\n                original=None,\n                replacement=None,\n                description=(\n                    \"clang-format timed out while trying to process a file. 
\"\n                    \"Please report an issue in pytorch/pytorch with the \"\n                    \"label 'module: lint'\"\n                ),\n            )\n        ]\n    except (OSError, subprocess.CalledProcessError) as err:\n        return [\n            LintMessage(\n                path=filename,\n                line=None,\n                char=None,\n                code=\"CLANGFORMAT\",\n                severity=LintSeverity.ADVICE,\n                name=\"command-failed\",\n                original=None,\n                replacement=None,\n                description=(\n                    f\"Failed due to {err.__class__.__name__}:\\n{err}\"\n                    if not isinstance(err, subprocess.CalledProcessError)\n                    else (\n                        \"COMMAND (exit code {returncode})\\n\"\n                        \"{command}\\n\\n\"\n                        \"STDERR\\n{stderr}\\n\\n\"\n                        \"STDOUT\\n{stdout}\"\n                    ).format(\n                        returncode=err.returncode,\n                        command=\" \".join(as_posix(x) for x in err.cmd),\n                        stderr=err.stderr.decode(\"utf-8\").strip() or \"(empty)\",\n                        stdout=err.stdout.decode(\"utf-8\").strip() or \"(empty)\",\n                    )\n                ),\n            )\n        ]\n\n    replacement = proc.stdout\n    if original == replacement:\n        return []\n\n    line = 0\n    original = original.decode(\"utf-8\")\n    replacement = replacement.decode(\"utf-8\")\n    for line, (i, j) in enumerate(\n        zip(original.split(\"\\n\"), replacement.split(\"\\n\"))\n    ):\n        if i != j:\n            break\n\n    return [\n        LintMessage(\n            path=filename,\n            line=line,\n            char=None,\n            code=\"CLANGFORMAT\",\n            severity=LintSeverity.WARNING,\n            name=\"format\",\n            original=original,\n            
replacement=replacement,\n            description=\"See https://clang.llvm.org/docs/ClangFormat.html.\\nRun `lintrunner -a` to apply this patch.\",\n        )\n    ]\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(\n        description=\"Format files with clang-format.\",\n        fromfile_prefix_chars=\"@\",\n    )\n    parser.add_argument(\n        \"--binary\",\n        required=True,\n        help=\"clang-format binary path\",\n    )\n    parser.add_argument(\n        \"--retries\",\n        default=3,\n        type=int,\n        help=\"times to retry timed out clang-format\",\n    )\n    parser.add_argument(\n        \"--timeout\",\n        default=90,\n        type=int,\n        help=\"seconds to wait for clang-format\",\n    )\n    parser.add_argument(\n        \"--verbose\",\n        action=\"store_true\",\n        help=\"verbose logging\",\n    )\n    parser.add_argument(\n        \"filenames\",\n        nargs=\"+\",\n        help=\"paths to lint\",\n    )\n    args = parser.parse_args()\n\n    logging.basicConfig(\n        format=\"<%(threadName)s:%(levelname)s> %(message)s\",\n        level=logging.NOTSET\n        if args.verbose\n        else logging.DEBUG\n        if len(args.filenames) < 1000\n        else logging.INFO,\n        stream=sys.stderr,\n    )\n\n    with concurrent.futures.ThreadPoolExecutor(\n        max_workers=os.cpu_count(),\n        thread_name_prefix=\"Thread\",\n    ) as executor:\n        futures = {\n            executor.submit(\n                check_file, x, args.binary, args.retries, args.timeout\n            ): x\n            for x in args.filenames\n        }\n        for future in concurrent.futures.as_completed(futures):\n            try:\n                for lint_message in future.result():\n                    print(json.dumps(lint_message._asdict()), flush=True)\n            except Exception:\n                logging.critical('Failed at \"%s\".', futures[future])\n                raise\n\n\nif __name__ == 
\"__main__\":\n    main()\n"
  },
  {
    "path": "tests/lint/lint.py",
    "content": "#!/usr/bin/env python3\n# pylint: disable=protected-access, unused-variable, locally-disabled, len-as-condition\n\"\"\"Lint helper to generate lint summary of source.\n\nCopyright by Contributors.\n\nBorrowed from dmlc-core/scripts/lint.py@939c052\n\"\"\"\nfrom __future__ import print_function\n\nimport argparse\nimport codecs\nimport os\nimport re\nimport sys\n\nimport cpplint\nfrom cpplint import _cpplint_state\nfrom pylint import epylint\n\nCXX_SUFFIX = set([\"cc\", \"c\", \"cpp\", \"h\", \"cu\", \"hpp\", \"cuh\"])\nPYTHON_SUFFIX = set([\"py\"])\n\n\ndef filepath_enumerate(paths):\n    \"\"\"Enumerate the file paths of all subfiles of the list of paths\"\"\"\n    out = []\n    for path in paths:\n        if os.path.isfile(path):\n            out.append(path)\n        else:\n            for root, dirs, files in os.walk(path):\n                for name in files:\n                    out.append(os.path.normpath(os.path.join(root, name)))\n    return out\n\n\n# pylint: disable=useless-object-inheritance\nclass LintHelper(object):\n    \"\"\"Class to help runing the lint and records summary\"\"\"\n\n    @staticmethod\n    def _print_summary_map(strm, result_map, ftype):\n        \"\"\"Print summary of certain result map.\"\"\"\n        if len(result_map) == 0:\n            return 0\n        npass = sum(1 for x in result_map.values() if len(x) == 0)\n        strm.write(\n            f\"====={npass}/{len(result_map)} {ftype} files passed check=====\\n\"\n        )\n        for fname, emap in result_map.items():\n            if len(emap) == 0:\n                continue\n            strm.write(\n                f\"{fname}: {sum(emap.values())} Errors of {len(emap)} Categories map={str(emap)}\\n\"\n            )\n        return len(result_map) - npass\n\n    def __init__(self):\n        self.project_name = None\n        self.cpp_header_map = {}\n        self.cpp_src_map = {}\n        self.python_map = {}\n        pylint_disable = [\n            
\"superfluous-parens\",\n            \"too-many-instance-attributes\",\n            \"too-few-public-methods\",\n        ]\n        # setup pylint\n        self.pylint_opts = [\n            \"--extension-pkg-whitelist=numpy\",\n            \"--disable=\" + \",\".join(pylint_disable),\n        ]\n\n        self.pylint_cats = set([\"error\", \"warning\", \"convention\", \"refactor\"])\n        # setup cpp lint\n        cpplint_args = [\n            \"--quiet\",\n            \"--extensions=\" + (\",\".join(CXX_SUFFIX)),\n            \".\",\n        ]\n        _ = cpplint.ParseArguments(cpplint_args)\n        cpplint._SetFilters(\n            \",\".join(\n                [\n                    \"-build/c++11\",\n                    \"-build/namespaces\",\n                    \"-build/include,\",\n                    \"+build/include_what_you_use\",\n                    \"+build/include_order\",\n                ]\n            )\n        )\n        cpplint._SetCountingStyle(\"toplevel\")\n        cpplint._line_length = 80\n\n    def process_cpp(self, path, suffix):\n        \"\"\"Process a cpp file.\"\"\"\n        _cpplint_state.ResetErrorCounts()\n        cpplint.ProcessFile(str(path), _cpplint_state.verbose_level)\n        _cpplint_state.PrintErrorCounts()\n        errors = _cpplint_state.errors_by_category.copy()\n\n        if suffix == \"h\":\n            self.cpp_header_map[str(path)] = errors\n        else:\n            self.cpp_src_map[str(path)] = errors\n\n    def process_python(self, path):\n        \"\"\"Process a python file.\"\"\"\n        (pylint_stdout, pylint_stderr) = epylint.py_run(\n            \" \".join([str(path)] + self.pylint_opts), return_std=True\n        )\n        emap = {}\n        err = pylint_stderr.read()\n        if len(err):\n            print(err)\n        for line in pylint_stdout:\n            sys.stderr.write(line)\n            key = line.split(\":\")[-1].split(\"(\")[0].strip()\n            if key not in self.pylint_cats:\n         
       continue\n            if key not in emap:\n                emap[key] = 1\n            else:\n                emap[key] += 1\n        self.python_map[str(path)] = emap\n\n    def print_summary(self, strm):\n        \"\"\"Print summary of lint.\"\"\"\n        nerr = 0\n        nerr += LintHelper._print_summary_map(\n            strm, self.cpp_header_map, \"cpp-header\"\n        )\n        nerr += LintHelper._print_summary_map(\n            strm, self.cpp_src_map, \"cpp-source\"\n        )\n        nerr += LintHelper._print_summary_map(strm, self.python_map, \"python\")\n        if nerr == 0:\n            strm.write(\"All passed!\\n\")\n        else:\n            strm.write(f\"{nerr} files failed lint\\n\")\n        return nerr\n\n\n# singleton helper for lint check\n_HELPER = LintHelper()\n\n\ndef get_header_guard_dmlc(filename):\n    \"\"\"Get Header Guard Convention for DMLC Projects.\n\n    For headers in include, directly use the path\n    For headers in src, use project name plus path\n\n    Examples: with project-name = dmlc\n        include/dmlc/timer.h -> DMLC_TIMTER_H_\n        src/io/libsvm_parser.h -> DMLC_IO_LIBSVM_PARSER_H_\n    \"\"\"\n    fileinfo = cpplint.FileInfo(filename)\n    file_path_from_root = fileinfo.RepositoryName()\n    inc_list = [\"include\", \"api\", \"wrapper\", \"contrib\"]\n    if os.name == \"nt\":\n        inc_list.append(\"mshadow\")\n\n    if (\n        file_path_from_root.find(\"src/\") != -1\n        and _HELPER.project_name is not None\n    ):\n        idx = file_path_from_root.find(\"src/\")\n        file_path_from_root = (\n            _HELPER.project_name + file_path_from_root[idx + 3 :]\n        )\n    else:\n        idx = file_path_from_root.find(\"include/\")\n        if idx != -1:\n            file_path_from_root = file_path_from_root[idx + 8 :]\n        for spath in inc_list:\n            prefix = spath + \"/\"\n            if file_path_from_root.startswith(prefix):\n                file_path_from_root = 
re.sub(\n                    \"^\" + prefix, \"\", file_path_from_root\n                )\n                break\n    return re.sub(r\"[-./\\s]\", \"_\", file_path_from_root).upper() + \"_\"\n\n\ncpplint.GetHeaderGuardCPPVariable = get_header_guard_dmlc\n\n\ndef process(fname, allow_type):\n    \"\"\"Process a file.\"\"\"\n    fname = str(fname)\n    arr = fname.rsplit(\".\", 1)\n    if fname.find(\"#\") != -1 or arr[-1] not in allow_type:\n        return\n    if arr[-1] in CXX_SUFFIX:\n        _HELPER.process_cpp(fname, arr[-1])\n    if arr[-1] in PYTHON_SUFFIX:\n        _HELPER.process_python(fname)\n\n\ndef main():\n    \"\"\"Main entry function.\"\"\"\n    parser = argparse.ArgumentParser(description=\"lint source codes\")\n    parser.add_argument(\"project\", help=\"project name\")\n    parser.add_argument(\n        \"filetype\", choices=[\"python\", \"cpp\", \"all\"], help=\"source code type\"\n    )\n    parser.add_argument(\"path\", nargs=\"+\", help=\"path to traverse\")\n    parser.add_argument(\n        \"--exclude_path\",\n        nargs=\"+\",\n        default=[],\n        help=\"exclude this path, and all subfolders if path is a folder\",\n    )\n    parser.add_argument(\n        \"--quiet\", action=\"store_true\", help=\"run cpplint in quiet mode\"\n    )\n    parser.add_argument(\"--pylint-rc\", default=None, help=\"pylint rc file\")\n    args = parser.parse_args()\n\n    _HELPER.project_name = args.project\n    if args.pylint_rc is not None:\n        _HELPER.pylint_opts = [\n            \"--rcfile=\" + args.pylint_rc,\n        ]\n    file_type = args.filetype\n    allow_type = []\n    if file_type in (\"python\", \"all\"):\n        allow_type += PYTHON_SUFFIX\n    if file_type in (\"cpp\", \"all\"):\n        allow_type += CXX_SUFFIX\n    allow_type = set(allow_type)\n    if sys.version_info.major == 2 and os.name != \"nt\":\n        sys.stderr = codecs.StreamReaderWriter(\n            sys.stderr,\n            codecs.getreader(\"utf8\"),\n            
codecs.getwriter(\"utf8\"),\n            \"replace\",\n        )\n    # get excluded files\n    excluded_paths = filepath_enumerate(args.exclude_path)\n    for path in args.path:\n        if os.path.isfile(path):\n            normpath = os.path.normpath(path)\n            if normpath not in excluded_paths:\n                process(path, allow_type)\n        else:\n            for root, dirs, files in os.walk(path):\n                for name in files:\n                    file_path = os.path.normpath(os.path.join(root, name))\n                    if file_path not in excluded_paths:\n                        process(file_path, allow_type)\n    nerr = _HELPER.print_summary(sys.stderr)\n    sys.exit(nerr > 0)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/lint/pip_init.py",
    "content": "\"\"\"\nInitializer script that installs stuff to pip.\n\nBorrowed from github.com/pytorch/pytorch/tools/linter/adapters/pip_init.py\n\"\"\"\nimport argparse\nimport logging\nimport os\nimport subprocess\nimport sys\nimport time\n\nfrom typing import List\n\n\ndef run_command(args: List[str]) -> \"subprocess.CompletedProcess[bytes]\":\n    logging.debug(\"$ %s\", \" \".join(args))\n    start_time = time.monotonic()\n    try:\n        return subprocess.run(args, check=True)\n    finally:\n        end_time = time.monotonic()\n        logging.debug(\"took %dms\", (end_time - start_time) * 1000)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"pip initializer\")\n    parser.add_argument(\n        \"packages\",\n        nargs=\"+\",\n        help=\"pip packages to install\",\n    )\n    parser.add_argument(\n        \"--verbose\",\n        action=\"store_true\",\n        help=\"verbose logging\",\n    )\n    parser.add_argument(\n        \"--dry-run\",\n        help=\"do not install anything, just print what would be done.\",\n    )\n    parser.add_argument(\n        \"--no-black-binary\",\n        help=\"do not use pre-compiled binaries from pip for black.\",\n        action=\"store_true\",\n    )\n\n    args = parser.parse_args()\n\n    logging.basicConfig(\n        format=\"<%(threadName)s:%(levelname)s> %(message)s\",\n        level=logging.NOTSET if args.verbose else logging.DEBUG,\n        stream=sys.stderr,\n    )\n\n    pip_args = [\"pip3\", \"install\"]\n\n    # If we are in a global install, use `--user` to install so that you do not\n    # need root access in order to initialize linters.\n    #\n    # However, `pip install --user` interacts poorly with virtualenvs (see:\n    # https://bit.ly/3vD4kvl) and conda (see: https://bit.ly/3KG7ZfU). 
So in\n    # these cases perform a regular installation.\n    in_conda = os.environ.get(\"CONDA_PREFIX\") is not None\n    in_virtualenv = os.environ.get(\"VIRTUAL_ENV\") is not None\n    if not in_conda and not in_virtualenv:\n        pip_args.append(\"--user\")\n\n    pip_args.extend(args.packages)\n\n    for package in args.packages:\n        package_name, _, version = package.partition(\"=\")\n        if version == \"\":\n            raise RuntimeError(\n                f\"Package {package_name} did not have a version specified. \"\n                \"Please specify a version to produce a consistent linting experience.\"\n            )\n        if args.no_black_binary and \"black\" in package_name:\n            pip_args.append(f\"--no-binary={package_name}\")\n\n    dry_run = args.dry_run == \"1\"\n    if dry_run:\n        print(f\"Would have run: {pip_args}\")\n        sys.exit(0)\n\n    run_command(pip_args)\n"
  },
  {
    "path": "tests/lint/pylintrc",
    "content": "[MASTER]\n\n# A comma-separated list of package or module names from where C extensions may\n# be loaded. Extensions are loading into the active Python interpreter and may\n# run arbitrary code.\nextension-pkg-whitelist=\n\n# Add files or directories to the blacklist. They should be base names, not\n# paths.\nignore=CVS,_cy2,_cy3,backend,data,contrib,_deprecate\n\n# Add files or directories matching the regex patterns to the blacklist. The\n# regex matches against base names, not paths.\nignore-patterns=\n\n# Python code to execute, usually for sys.path manipulation such as\n# pygtk.require().\n#init-hook=\n\n# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the\n# number of processors available to use.\njobs=4\n\n# Control the amount of potential inferred values when inferring a single\n# object. This can help the performance when dealing with large functions or\n# complex, nested conditions.\nlimit-inference-results=100\n\n# List of plugins (as comma separated values of python modules names) to load,\n# usually to register additional checkers.\nload-plugins=\n\n# Pickle collected data for later comparisons.\npersistent=yes\n\n# Specify a configuration file.\n#rcfile=\n\n# When enabled, pylint would attempt to guess common misconfiguration and emit\n# user-friendly hints instead of false-positive error messages.\nsuggestion-mode=yes\n\n# Allow loading of arbitrary C extensions. Extensions are imported into the\n# active Python interpreter and may run arbitrary code.\nunsafe-load-any-extension=no\n\n\n[MESSAGES CONTROL]\n\n# Only show warnings with the listed confidence levels. Leave empty to show\n# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.\nconfidence=\n\n# Disable the message, report, category or checker with the given id(s). 
You\n# can either give multiple identifiers separated by comma (,) or put this\n# option multiple times (only on the command line, not in the configuration\n# file where it should appear only once). You can also use \"--disable=all\" to\n# disable everything first and then reenable specific checks. For example, if\n# you want to run only the similarities checker, you can use \"--disable=all\n# --enable=similarities\". If you want to run only the classes checker, but have\n# no Warning level messages displayed, use \"--disable=all --enable=classes\n# --disable=W\".\ndisable=design,\n        similarities,\n        no-self-use,\n        attribute-defined-outside-init,\n        locally-disabled,\n        star-args,\n        pointless-except,\n        bad-option-value,\n        global-statement,\n        fixme,\n        suppressed-message,\n        useless-suppression,\n        locally-enabled,\n        import-error,\n        unsubscriptable-object,\n        unbalanced-tuple-unpacking,\n        protected-access,\n        useless-object-inheritance,\n        no-else-return,\n        len-as-condition,\n        cyclic-import,          # disabled due to the inevitable dgl.graph -> dgl.subgraph loop\n        undefined-variable,     # disabled due to C extension (should enable)\n        raise-missing-from,     # meh\n        import-outside-toplevel,    # due to inevitable imports within blocks\n        using-constant-test,    # due to in-place object modification in C\n        super-with-arguments,   # 2.3.0->2.6.0, should enable but there's too many...\n        not-callable,           # due to optional callables that can be None\n\n# Enable the message, report, category or checker with the given id(s). You can\n# either give multiple identifier separated by comma (,) or put this option\n# multiple time (only on the command line, not in the configuration file where\n# it should appear only once). 
See also the \"--disable\" option for examples.\nenable=c-extension-no-member\n\n\n[REPORTS]\n\n# Python expression which should return a note less than 10 (10 is the highest\n# note). You have access to the variables errors warning, statement which\n# respectively contain the number of errors / warnings messages and the total\n# number of statements analyzed. This is used by the global evaluation report\n# (RP0004).\nevaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)\n\n# Template used to display messages. This is a python new-style format string\n# used to format the message information. See doc for all details.\n#msg-template=\n\n# Set the output format. Available formats are text, parseable, colorized, json\n# and msvs (visual studio). You can also give a reporter class, e.g.\n# mypackage.mymodule.MyReporterClass.\noutput-format=text\n\n# Tells whether to display a full report or only the messages.\nreports=no\n\n# Activate the evaluation score.\nscore=yes\n\n\n[REFACTORING]\n\n# Maximum number of nested blocks for function / method body\nmax-nested-blocks=5\n\n# Complete name of functions that never returns. When checking for\n# inconsistent-return-statements if a never returning function is called then\n# it will be considered as an explicit return statement and no message will be\n# printed.\nnever-returning-functions=sys.exit\n\n\n[MISCELLANEOUS]\n\n# List of note tags to take in consideration, separated by a comma.\nnotes=FIXME,\n      XXX,\n      TODO\n\n\n[BASIC]\n\n# Naming style matching correct argument names.\nargument-naming-style=snake_case\n\n# Regular expression matching correct argument names. Overrides argument-\n# naming-style.\n#argument-rgx=\n\n# Naming style matching correct attribute names.\nattr-naming-style=snake_case\n\n# Regular expression matching correct attribute names. 
Overrides attr-naming-\n# style.\n#attr-rgx=\n\n# Bad variable names which should always be refused, separated by a comma.\nbad-names=foo,\n          bar,\n          baz,\n          toto,\n          tutu,\n          tata\n\n# Naming style matching correct class attribute names.\nclass-attribute-naming-style=any\n\n# Regular expression matching correct class attribute names. Overrides class-\n# attribute-naming-style.\n#class-attribute-rgx=\n\n# Naming style matching correct class names.\nclass-naming-style=PascalCase\n\n# Regular expression matching correct class names. Overrides class-naming-\n# style.\n#class-rgx=\n\n# Naming style matching correct constant names.\nconst-naming-style=UPPER_CASE\n\n# Regular expression matching correct constant names. Overrides const-naming-\n# style.\n#const-rgx=\n\n# Minimum line length for functions/classes that require docstrings, shorter\n# ones are exempt.\ndocstring-min-length=-1\n\n# Naming style matching correct function names.\nfunction-naming-style=snake_case\n\n# Regular expression matching correct function names. 
Overrides function-\n# naming-style.\n#function-rgx=\n\n# Good variable names which should always be accepted, separated by a comma.\n# f - files\n# i, j, k - loop variables\n# u, v, e - nodes and edges\n# s, d - source and destination\n# t - time\n# r - relation type\n# n, m - general integers representing quantity\n# w, x, y, z - general math variables\n# g, G - graphs\n# hg - heterogeneous graphs\n# sg - subgraphs\n# fn - functions\n# us, vs, es, gs - plural form of u, v, g, e\n# op - operators\n# ty - type\n# A, B, C, W - for tensor operators like matmul\n# dp - DataPipes (see https://pytorch.org/data/0.7/torchdata.datapipes.iter.html)\n# it - iterators\ngood-names=f,i,j,k,u,v,e,n,m,w,x,y,z,s,d,t,r,g,G,hg,sg,fn,ex,Run,_,us,vs,gs,es,op,ty,A,B,C,W,a,b,N,D1,D2,R,dp,it\n\n# Include a hint for the correct naming format with invalid-name.\ninclude-naming-hint=no\n\n# Naming style matching correct inline iteration names.\ninlinevar-naming-style=any\n\n# Regular expression matching correct inline iteration names. Overrides\n# inlinevar-naming-style.\n#inlinevar-rgx=\n\n# Naming style matching correct method names.\nmethod-naming-style=snake_case\n\n# Regular expression matching correct method names. Overrides method-naming-\n# style.\n#method-rgx=\n\n# Naming style matching correct module names.\nmodule-naming-style=snake_case\n\n# Regular expression matching correct module names. Overrides module-naming-\n# style.\n#module-rgx=\n\n# Colon-delimited sets of names that determine each other's naming style when\n# the name regexes allow several styles.\nname-group=\n\n# Regular expression which should only match function or class names that do\n# not require a docstring.\nno-docstring-rgx=^_\n\n# List of decorators that produce properties, such as abc.abstractproperty. 
Add\n# to this list to register other decorators that produce valid properties.\n# These decorators are taken in consideration only for invalid-name.\nproperty-classes=abc.abstractproperty\n\n# Naming style matching correct variable names.\nvariable-naming-style=snake_case\n\n# Regular expression matching correct variable names. Overrides variable-\n# naming-style.\n#variable-rgx=\n\n\n[VARIABLES]\n\n# List of additional names supposed to be defined in builtins. Remember that\n# you should avoid defining new builtins when possible.\nadditional-builtins=\n\n# Tells whether unused global variables should be treated as a violation.\nallow-global-unused-variables=yes\n\n# List of strings which can identify a callback function by name. A callback\n# name must start or end with one of those strings.\ncallbacks=cb_,\n          _cb\n\n# A regular expression matching the name of dummy variables (i.e. expected to\n# not be used).\ndummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_\n\n# Argument names that match this expression will be ignored. Default to name\n# with leading underscore.\nignored-argument-names=_.*|^ignored_|^unused_\n\n# Tells whether we should check for unused import in __init__ files.\ninit-import=no\n\n# List of qualified module names which can have objects that can redefine\n# builtins.\nredefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io\n\n\n[SPELLING]\n\n# Limits count of emitted suggestions for spelling mistakes.\nmax-spelling-suggestions=4\n\n# Spelling dictionary name. Available dictionaries: none. 
To make it working\n# install python-enchant package..\nspelling-dict=\n\n# List of comma separated words that should not be checked.\nspelling-ignore-words=\n\n# A path to a file that contains private dictionary; one word per line.\nspelling-private-dict-file=\n\n# Tells whether to store unknown words to indicated private dictionary in\n# --spelling-private-dict-file option instead of raising a message.\nspelling-store-unknown-words=no\n\n\n[LOGGING]\n\n# Format style used to check logging format string. `old` means using %\n# formatting, while `new` is for `{}` formatting.\nlogging-format-style=old\n\n# Logging modules to check that the string format arguments are in logging\n# function parameter format.\nlogging-modules=logging\n\n\n[FORMAT]\n\n# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.\nexpected-line-ending-format=\n\n# Regexp for a line that is allowed to be longer than the limit.\nignore-long-lines=^\\s*(# )?<?https?://\\S+>?$\n\n# Number of spaces of indent required inside a hanging or continued line.\nindent-after-paren=4\n\n# String used as indentation unit. This is usually \"    \" (4 spaces) or \"\\t\" (1\n# tab).\nindent-string='    '\n\n# Maximum number of characters on a single line.\nmax-line-length=100\n\n# Maximum number of lines in a module.\nmax-module-lines=4000\n\n# List of optional constructs for which whitespace checking is disabled. 
`dict-\n# separator` is used to allow tabulation in dicts, etc.: {1  : 1,\\n222: 2}.\n# `trailing-comma` allows a space between comma and closing bracket: (a, ).\n# `empty-line` allows space-only lines.\nno-space-check=trailing-comma,\n               dict-separator\n\n# Allow the body of a class to be on the same line as the declaration if body\n# contains single statement.\nsingle-line-class-stmt=no\n\n# Allow the body of an if to be on the same line as the test if there is no\n# else.\nsingle-line-if-stmt=no\n\n\n[SIMILARITIES]\n\n# Ignore comments when computing similarities.\nignore-comments=yes\n\n# Ignore docstrings when computing similarities.\nignore-docstrings=yes\n\n# Ignore imports when computing similarities.\nignore-imports=no\n\n# Minimum lines number of a similarity.\nmin-similarity-lines=4\n\n\n[TYPECHECK]\n\n# List of decorators that produce context managers, such as\n# contextlib.contextmanager. Add to this list to register other decorators that\n# produce valid context managers.\ncontextmanager-decorators=contextlib.contextmanager\n\n# List of members which are set dynamically and missed by pylint inference\n# system, and so shouldn't trigger E1101 when accessed. Python regular\n# expressions are accepted.\ngenerated-members=\n\n# Tells whether missing members accessed in mixin class should be ignored. A\n# mixin class is detected if its name ends with \"mixin\" (case insensitive).\nignore-mixin-members=yes\n\n# Tells whether to warn about missing members when the owner of the attribute\n# is inferred to be None.\nignore-none=yes\n\n# This flag controls whether pylint should warn about no-member and similar\n# checks whenever an opaque object is returned when inferring. The inference\n# can return multiple potential results while evaluating a Python object, but\n# some branches might not be evaluated, which results in partial inference. 
In\n# that case, it might be useful to still emit no-member and other checks for\n# the rest of the inferred objects.\nignore-on-opaque-inference=yes\n\n# List of class names for which member attributes should not be checked (useful\n# for classes with dynamically set attributes). This supports the use of\n# qualified names.\nignored-classes=optparse.Values,thread._local,_thread._local\n\n# List of module names for which member attributes should not be checked\n# (useful for modules/projects where namespaces are manipulated during runtime\n# and thus existing member attributes cannot be deduced by static analysis. It\n# supports qualified module names, as well as Unix pattern matching.\nignored-modules=dgl.backend,dgl._api_internal,dgl._deprecate\n\n# Show a hint with possible names when a member name was not found. The aspect\n# of finding the hint is based on edit distance.\nmissing-member-hint=yes\n\n# The minimum edit distance a name should have in order to be considered a\n# similar match for a missing member name.\nmissing-member-hint-distance=1\n\n# The total number of similar names that should be taken in consideration when\n# showing a hint for a missing member.\nmissing-member-max-choices=1\n\n\n[IMPORTS]\n\n# Allow wildcard imports from modules that define __all__.\nallow-wildcard-with-all=yes\n\n# Analyse import fallback blocks. This can be used to support both Python 2 and\n# 3 compatible code, which means that the block might have code that exists\n# only in one or another interpreter, leading to false positives when analysed.\nanalyse-fallback-blocks=no\n\n# Deprecated modules which should not be used, separated by a comma.\ndeprecated-modules=optparse,tkinter.tix\n\n# Create a graph of external dependencies in the given file (report RP0402 must\n# not be disabled).\next-import-graph=\n\n# Create a graph of every (i.e. 
internal and external) dependencies in the\n# given file (report RP0402 must not be disabled).\nimport-graph=\n\n# Create a graph of internal dependencies in the given file (report RP0402 must\n# not be disabled).\nint-import-graph=\n\n# Force import order to recognize a module as part of the standard\n# compatibility libraries.\nknown-standard-library=\n\n# Force import order to recognize a module as part of a third party library.\nknown-third-party=enchant\n\n\n[DESIGN]\n\n# Maximum number of arguments for function / method.\nmax-args=5\n\n# Maximum number of attributes for a class (see R0902).\nmax-attributes=7\n\n# Maximum number of boolean expressions in an if statement.\nmax-bool-expr=5\n\n# Maximum number of branch for function / method body.\nmax-branches=12\n\n# Maximum number of locals for function / method body.\nmax-locals=15\n\n# Maximum number of parents for a class (see R0901).\nmax-parents=7\n\n# Maximum number of public methods for a class (see R0904).\nmax-public-methods=20\n\n# Maximum number of return / yield for function / method body.\nmax-returns=6\n\n# Maximum number of statements in function / method body.\nmax-statements=50\n\n# Minimum number of public methods for a class (see R0903).\nmin-public-methods=2\n\n\n[CLASSES]\n\n# List of method names used to declare (i.e. assign) instance attributes.\ndefining-attr-methods=__init__,\n                      __new__,\n                      setUp\n\n# List of member names, which should be excluded from the protected access\n# warning.\nexclude-protected=_asdict,\n                  _fields,\n                  _replace,\n                  _source,\n                  _make\n\n# List of valid names for the first argument in a class method.\nvalid-classmethod-first-arg=cls\n\n# List of valid names for the first argument in a metaclass class method.\nvalid-metaclass-classmethod-first-arg=cls\n\n\n[EXCEPTIONS]\n\n# Exceptions that will emit a warning when being caught. 
Defaults to\n# \"Exception\".\novergeneral-exceptions=Exception\n"
  },
  {
    "path": "tests/lint/ufmt_linter.py",
    "content": "\"\"\"Borrowed from github.com/pytorch/pytorch/tools/linter/adapters/ufmt_linter.py\"\"\"\nimport argparse\nimport concurrent.futures\nimport json\nimport logging\nimport os\nimport sys\nfrom enum import Enum\nfrom pathlib import Path\nfrom typing import Any, List, NamedTuple, Optional\n\nfrom ufmt.core import make_black_config, ufmt_string\nfrom usort import Config as UsortConfig\n\n\nIS_WINDOWS: bool = os.name == \"nt\"\n\n\ndef eprint(*args: Any, **kwargs: Any) -> None:\n    print(*args, file=sys.stderr, flush=True, **kwargs)\n\n\nclass LintSeverity(str, Enum):\n    ERROR = \"error\"\n    WARNING = \"warning\"\n    ADVICE = \"advice\"\n    DISABLED = \"disabled\"\n\n\nclass LintMessage(NamedTuple):\n    path: Optional[str]\n    line: Optional[int]\n    char: Optional[int]\n    code: str\n    severity: LintSeverity\n    name: str\n    original: Optional[str]\n    replacement: Optional[str]\n    description: Optional[str]\n\n\ndef as_posix(name: str) -> str:\n    return name.replace(\"\\\\\", \"/\") if IS_WINDOWS else name\n\n\ndef format_error_message(filename: str, err: Exception) -> LintMessage:\n    return LintMessage(\n        path=filename,\n        line=None,\n        char=None,\n        code=\"UFMT\",\n        severity=LintSeverity.ADVICE,\n        name=\"command-failed\",\n        original=None,\n        replacement=None,\n        description=(f\"Failed due to {err.__class__.__name__}:\\n{err}\"),\n    )\n\n\ndef check_file(\n    filename: str,\n) -> List[LintMessage]:\n    with open(filename, \"rb\") as f:\n        original = f.read().decode(\"utf-8\")\n\n    try:\n        path = Path(filename)\n\n        usort_config = UsortConfig.find(path)\n        black_config = make_black_config(path)\n\n        # Use UFMT API to call both usort and black\n        replacement = ufmt_string(\n            path=path,\n            content=original,\n            usort_config=usort_config,\n            black_config=black_config,\n        )\n\n        if 
original == replacement:\n            return []\n\n        line = 0\n        for line, (i, j) in enumerate(\n            zip(original.split(\"\\n\"), replacement.split(\"\\n\"))\n        ):\n            if i != j:\n                break\n\n        return [\n            LintMessage(\n                path=filename,\n                line=line,\n                char=None,\n                code=\"UFMT\",\n                severity=LintSeverity.WARNING,\n                name=\"format\",\n                original=original,\n                replacement=replacement,\n                description=\"Run `lintrunner -a` to apply this patch.\",\n            )\n        ]\n    except Exception as err:\n        return [format_error_message(filename, err)]\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(\n        description=\"Format files with ufmt (black + usort).\",\n        fromfile_prefix_chars=\"@\",\n    )\n    parser.add_argument(\n        \"--verbose\",\n        action=\"store_true\",\n        help=\"verbose logging\",\n    )\n    parser.add_argument(\n        \"filenames\",\n        nargs=\"+\",\n        help=\"paths to lint\",\n    )\n    args = parser.parse_args()\n\n    logging.basicConfig(\n        format=\"<%(threadName)s:%(levelname)s> %(message)s\",\n        level=logging.NOTSET\n        if args.verbose\n        else logging.DEBUG\n        if len(args.filenames) < 1000\n        else logging.INFO,\n        stream=sys.stderr,\n    )\n\n    with concurrent.futures.ThreadPoolExecutor(\n        max_workers=os.cpu_count(),\n        thread_name_prefix=\"Thread\",\n    ) as executor:\n        futures = {executor.submit(check_file, x): x for x in args.filenames}\n        for future in concurrent.futures.as_completed(futures):\n            try:\n                for lint_message in future.result():\n                    print(json.dumps(lint_message._asdict()), flush=True)\n            except Exception:\n                logging.critical('Failed at \"%s\".', 
futures[future])\n                raise\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/python/common/backend/test_set_default_backend.py",
    "content": "import os\nimport unittest\n\nimport backend as F\n\n\ndef test_set_default_backend():\n    default_dir = os.path.join(os.path.expanduser(\"~\"), \".dgl_unit_test\")\n    F.set_default_backend(default_dir, \"pytorch\")\n\n    # make sure the config file was created\n    assert os.path.exists(os.path.join(default_dir, \"config.json\"))\n"
  },
  {
    "path": "tests/python/common/backend/test_tensor.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nimport dgl.ndarray as nd\nimport numpy as np\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support inplace update\",\n)\ndef test_dlpack():\n    # test dlpack conversion.\n    def nd2th():\n        ans = np.array(\n            [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]\n        )\n        x = nd.array(np.zeros((3, 4), dtype=np.float32))\n        dl = x.to_dlpack()\n        y = F.zerocopy_from_dlpack(dl)\n        y[0] = 1\n        print(x)\n        print(y)\n        assert np.allclose(x.asnumpy(), ans)\n\n    def th2nd():\n        ans = np.array(\n            [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]\n        )\n        x = F.zeros((3, 4))\n        dl = F.zerocopy_to_dlpack(x)\n        y = nd.from_dlpack(dl)\n        x[0] = 1\n        print(x)\n        print(y)\n        assert np.allclose(y.asnumpy(), ans)\n\n    def th2nd_incontiguous():\n        x = F.astype(F.tensor([[0, 1], [2, 3]]), F.int64)\n        ans = np.array([0, 2])\n        y = x[:2, 0]\n        # Uncomment this line and comment the one below to observe error\n        # dl = dlpack.to_dlpack(y)\n        dl = F.zerocopy_to_dlpack(y)\n        z = nd.from_dlpack(dl)\n        print(x)\n        print(z)\n        assert np.allclose(z.asnumpy(), ans)\n\n    nd2th()\n    th2nd()\n    th2nd_incontiguous()\n"
  },
  {
    "path": "tests/python/common/cuda/test_gpu_cache.py",
    "content": "#\n#   Copyright (c) 2022 by Contributors\n#\n#   Licensed under the Apache License, Version 2.0 (the \"License\");\n#   you may not use this file except in compliance with the License.\n#   You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#   Unless required by applicable law or agreed to in writing, software\n#   distributed under the License is distributed on an \"AS IS\" BASIS,\n#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#   See the License for the specific language governing permissions and\n#   limitations under the License.\n#\n\nimport unittest\n\nimport backend as F\n\nimport dgl\nfrom utils import parametrize_idtype\n\nD = 5\n\n\ndef generate_graph(idtype, grad=False, add_data=True):\n    g = dgl.graph([]).to(F.ctx(), dtype=idtype)\n    g.add_nodes(10)\n    u, v = [], []\n    # create a graph where 0 is the source and 9 is the sink\n    for i in range(1, 9):\n        u.append(0)\n        v.append(i)\n        u.append(i)\n        v.append(9)\n    # add a back flow from 9 to 0\n    u.append(9)\n    v.append(0)\n    g.add_edges(u, v)\n    if add_data:\n        ncol = F.randn((10, D))\n        ecol = F.randn((17, D))\n        if grad:\n            ncol = F.attach_grad(ncol)\n            ecol = F.attach_grad(ecol)\n        g.ndata[\"h\"] = ncol\n        g.edata[\"l\"] = ecol\n    return g\n\n\n@unittest.skipIf(not F.gpu_ctx(), reason=\"only necessary with GPU\")\n@parametrize_idtype\ndef test_gpu_cache(idtype):\n    g = generate_graph(idtype)\n    cache = dgl.cuda.GPUCache(5, D, idtype)\n    h = g.ndata[\"h\"]\n\n    t = 5\n    keys = F.arange(0, t, dtype=idtype)\n    values, m_idx, m_keys = cache.query(keys)\n    m_values = h[F.tensor(m_keys, F.int64)]\n    values[F.tensor(m_idx, F.int64)] = m_values\n    cache.replace(m_keys, m_values)\n\n    keys = F.arange(3, 8, dtype=idtype)\n    values, m_idx, m_keys = cache.query(keys)\n    assert m_keys.shape[0] == 3 
and m_idx.shape[0] == 3\n    m_values = h[F.tensor(m_keys, F.int64)]\n    values[F.tensor(m_idx, F.int64)] = m_values\n    assert (values != h[F.tensor(keys, F.int64)]).sum().item() == 0\n    cache.replace(m_keys, m_values)\n\n\nif __name__ == \"__main__\":\n    test_gpu_cache(F.int64)\n    test_gpu_cache(F.int32)\n"
  },
  {
    "path": "tests/python/common/data/data/test_heterophilous_graphs.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\",\n    reason=\"Only supports PyTorch backend.\",\n)\ndef test_roman_empire():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = dgl.data.RomanEmpireDataset(force_reload=True)[0]\n    assert g.num_nodes() == 22662\n    assert g.num_edges() == 65854\n    g2 = dgl.data.RomanEmpireDataset(force_reload=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\",\n    reason=\"Only supports PyTorch backend.\",\n)\ndef test_amazon_ratings():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = dgl.data.AmazonRatingsDataset(force_reload=True)[0]\n    assert g.num_nodes() == 24492\n    assert g.num_edges() == 186100\n    g2 = dgl.data.AmazonRatingsDataset(force_reload=True, transform=transform)[\n        0\n    ]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\",\n    reason=\"Only supports PyTorch backend.\",\n)\ndef test_minesweeper():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = dgl.data.MinesweeperDataset(force_reload=True)[0]\n    assert g.num_nodes() == 10000\n    assert g.num_edges() == 78804\n    g2 = dgl.data.MinesweeperDataset(force_reload=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on 
GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\",\n    reason=\"Only supports PyTorch backend.\",\n)\ndef test_tolokers():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = dgl.data.TolokersDataset(force_reload=True)[0]\n    assert g.num_nodes() == 11758\n    assert g.num_edges() == 1038000\n    g2 = dgl.data.TolokersDataset(force_reload=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\",\n    reason=\"Only supports PyTorch backend.\",\n)\ndef test_questions():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = dgl.data.QuestionsDataset(force_reload=True)[0]\n    assert g.num_nodes() == 48921\n    assert g.num_edges() == 307080\n    g2 = dgl.data.QuestionsDataset(force_reload=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n"
  },
  {
    "path": "tests/python/common/data/test_actor.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_actor():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = dgl.data.ActorDataset(force_reload=True)[0]\n    assert g.num_nodes() == 7600\n    assert g.num_edges() == 33391\n    g2 = dgl.data.ActorDataset(force_reload=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n"
  },
  {
    "path": "tests/python/common/data/test_data.py",
    "content": "import gzip\nimport io\nimport os\nimport tarfile\nimport tempfile\nimport unittest\nimport warnings\n\nimport backend as F\n\nimport dgl\nimport dgl.data as data\nimport numpy as np\nimport pandas as pd\nimport pytest\nimport yaml\nfrom dgl import DGLError\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_minigc():\n    ds = data.MiniGCDataset(16, 10, 20)\n    g, l = list(zip(*ds))\n    print(g, l)\n    g1 = ds[0][0]\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    ds = data.MiniGCDataset(16, 10, 20, transform=transform)\n    g2 = ds[0][0]\n    assert g2.num_edges() - g1.num_edges() == g1.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_gin():\n    ds_n_graphs = {\n        \"MUTAG\": 188,\n        \"IMDBBINARY\": 1000,\n        \"IMDBMULTI\": 1500,\n        \"PROTEINS\": 1113,\n        \"PTC\": 344,\n    }\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    for name, n_graphs in ds_n_graphs.items():\n        ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False)\n        assert len(ds) == n_graphs, (len(ds), name)\n        g1 = ds[0][0]\n        ds = data.GINDataset(\n            name, self_loop=False, degree_as_nlabel=False, transform=transform\n        )\n        g2 = ds[0][0]\n        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()\n        assert ds.num_classes == ds.gclasses\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_fraud():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = 
data.FraudDataset(\"amazon\")[0]\n    assert g.num_nodes() == 11944\n    num_edges1 = g.num_edges()\n    g2 = data.FraudDataset(\"amazon\", transform=transform)[0]\n    # 3 edge types\n    assert g2.num_edges() - num_edges1 == g.num_nodes() * 3\n\n    g = data.FraudAmazonDataset()[0]\n    assert g.num_nodes() == 11944\n    g2 = data.FraudAmazonDataset(transform=transform)[0]\n    # 3 edge types\n    assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3\n\n    g = data.FraudYelpDataset()[0]\n    assert g.num_nodes() == 45954\n    g2 = data.FraudYelpDataset(transform=transform)[0]\n    # 3 edge types\n    assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_tudataset_regression():\n    ds = data.TUDataset(\"ZINC_test\", force_reload=True)\n    assert ds.num_classes == ds.num_labels\n    assert len(ds) == 5000\n    g = ds[0][0]\n\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    ds = data.TUDataset(\"ZINC_test\", force_reload=True, transform=transform)\n    g2 = ds[0][0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_data_hash():\n    class HashTestDataset(data.DGLDataset):\n        def __init__(self, hash_key=()):\n            super(HashTestDataset, self).__init__(\"hashtest\", hash_key=hash_key)\n\n        def _load(self):\n            pass\n\n    a = HashTestDataset((True, 0, \"1\", (1, 2, 3)))\n    b = HashTestDataset((True, 0, \"1\", (1, 2, 3)))\n    c = HashTestDataset((True, 0, \"1\", (1, 2, 4)))\n    assert a.hash == b.hash\n    assert a.hash != c.hash\n\n\n@unittest.skipIf(\n    F._default_context_str == 
\"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_citation_graph():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    # cora\n    g = data.CoraGraphDataset(force_reload=True, reorder=True)[0]\n    assert g.num_nodes() == 2708\n    assert g.num_edges() == 10556\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n    g2 = data.CoraGraphDataset(transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n    # Citeseer\n    g = data.CiteseerGraphDataset(force_reload=True, reorder=True)[0]\n    assert g.num_nodes() == 3327\n    assert g.num_edges() == 9228\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n    g2 = data.CiteseerGraphDataset(transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n    # Pubmed\n    g = data.PubmedGraphDataset(force_reload=True, reorder=True)[0]\n    assert g.num_nodes() == 19717\n    assert g.num_edges() == 88651\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n    g2 = data.PubmedGraphDataset(transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_gnn_benchmark():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    # AmazonCoBuyComputerDataset\n    g = data.AmazonCoBuyComputerDataset()[0]\n    assert g.num_nodes() == 13752\n    assert g.num_edges() == 491722\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n    g2 = data.AmazonCoBuyComputerDataset(transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n    # AmazonCoBuyPhotoDataset\n    g = 
data.AmazonCoBuyPhotoDataset()[0]\n    assert g.num_nodes() == 7650\n    assert g.num_edges() == 238163\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n    g2 = data.AmazonCoBuyPhotoDataset(transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n    # CoauthorPhysicsDataset\n    g = data.CoauthorPhysicsDataset()[0]\n    assert g.num_nodes() == 34493\n    assert g.num_edges() == 495924\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n    g2 = data.CoauthorPhysicsDataset(transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n    # CoauthorCSDataset\n    g = data.CoauthorCSDataset()[0]\n    assert g.num_nodes() == 18333\n    assert g.num_edges() == 163788\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n    g2 = data.CoauthorCSDataset(transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n    # CoraFullDataset\n    g = data.CoraFullDataset()[0]\n    assert g.num_nodes() == 19793\n    assert g.num_edges() == 126842\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n    g2 = data.CoraFullDataset(transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_explain_syn():\n    dataset = data.BAShapeDataset()\n    assert dataset.num_classes == 4\n    g = dataset[0]\n    assert \"label\" in g.ndata\n    assert \"feat\" in g.ndata\n\n    g1 = data.BAShapeDataset(force_reload=True, seed=0)[0]\n    src1, dst1 = g1.edges()\n    g2 = data.BAShapeDataset(force_reload=True, seed=0)[0]\n    src2, dst2 = g2.edges()\n    assert F.allclose(src1, src2)\n    assert F.allclose(dst1, dst2)\n\n    dataset = data.BACommunityDataset()\n  
  assert dataset.num_classes == 8\n    g = dataset[0]\n    assert \"label\" in g.ndata\n    assert \"feat\" in g.ndata\n\n    g1 = data.BACommunityDataset(force_reload=True, seed=0)[0]\n    src1, dst1 = g1.edges()\n    g2 = data.BACommunityDataset(force_reload=True, seed=0)[0]\n    src2, dst2 = g2.edges()\n    assert F.allclose(src1, src2)\n    assert F.allclose(dst1, dst2)\n\n    dataset = data.TreeCycleDataset()\n    assert dataset.num_classes == 2\n    g = dataset[0]\n    assert \"label\" in g.ndata\n    assert \"feat\" in g.ndata\n\n    g1 = data.TreeCycleDataset(force_reload=True, seed=0)[0]\n    src1, dst1 = g1.edges()\n    g2 = data.TreeCycleDataset(force_reload=True, seed=0)[0]\n    src2, dst2 = g2.edges()\n    assert F.allclose(src1, src2)\n    assert F.allclose(dst1, dst2)\n\n    dataset = data.TreeGridDataset()\n    assert dataset.num_classes == 2\n    g = dataset[0]\n    assert \"label\" in g.ndata\n    assert \"feat\" in g.ndata\n\n    g1 = data.TreeGridDataset(force_reload=True, seed=0)[0]\n    src1, dst1 = g1.edges()\n    g2 = data.TreeGridDataset(force_reload=True, seed=0)[0]\n    src2, dst2 = g2.edges()\n    assert F.allclose(src1, src2)\n    assert F.allclose(dst1, dst2)\n\n    dataset = data.BA2MotifDataset()\n    assert dataset.num_classes == 2\n    g, label = dataset[0]\n    assert \"feat\" in g.ndata\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_wiki_cs():\n    g = data.WikiCSDataset()[0]\n    assert g.num_nodes() == 11701\n    assert g.num_edges() == 431726\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    g2 = data.WikiCSDataset(transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skip(reason=\"Dataset too large to download for the latest 
CI.\")\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_yelp():\n    g = data.YelpDataset(reorder=True)[0]\n    assert g.num_nodes() == 716847\n    assert g.num_edges() == 13954819\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    g2 = data.YelpDataset(reorder=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_flickr():\n    g = data.FlickrDataset(reorder=True)[0]\n    assert g.num_nodes() == 89250\n    assert g.num_edges() == 899756\n    dst = F.asnumpy(g.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    g2 = data.FlickrDataset(reorder=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_pattern():\n    mode_n_graphs = {\n        \"train\": 10000,\n        \"valid\": 2000,\n        \"test\": 2000,\n    }\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    for mode, n_graphs in mode_n_graphs.items():\n        ds = data.PATTERNDataset(mode=mode)\n        assert len(ds) == n_graphs, (len(ds), mode)\n        g1 = ds[0]\n        ds = data.PATTERNDataset(mode=mode, transform=transform)\n        g2 = ds[0]\n        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()\n        assert ds.num_classes == 2\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on 
GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_cluster():\n    mode_n_graphs = {\n        \"train\": 10000,\n        \"valid\": 1000,\n        \"test\": 1000,\n    }\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    for mode, n_graphs in mode_n_graphs.items():\n        ds = data.CLUSTERDataset(mode=mode)\n        assert len(ds) == n_graphs, (len(ds), mode)\n        g1 = ds[0]\n        ds = data.CLUSTERDataset(mode=mode, transform=transform)\n        g2 = ds[0]\n        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()\n        assert ds.num_classes == 6\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_zinc():\n    mode_n_graphs = {\n        \"train\": 10000,\n        \"valid\": 1000,\n        \"test\": 1000,\n    }\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    for mode, n_graphs in mode_n_graphs.items():\n        dataset1 = data.ZINCDataset(mode=mode)\n        g1, label = dataset1[0]\n        dataset2 = data.ZINCDataset(mode=mode, transform=transform)\n        g2, _ = dataset2[0]\n\n        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()\n        # return a scalar tensor\n        assert not label.shape\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_extract_archive():\n    # gzip\n    with tempfile.TemporaryDirectory() as src_dir:\n        gz_file = \"gz_archive\"\n        gz_path = os.path.join(src_dir, gz_file + \".gz\")\n        content = b\"test extract archive gzip\"\n        with gzip.open(gz_path, \"wb\") as f:\n            f.write(content)\n        with tempfile.TemporaryDirectory() as dst_dir:\n            
data.utils.extract_archive(gz_path, dst_dir, overwrite=True)\n            assert os.path.exists(os.path.join(dst_dir, gz_file))\n\n    # tar\n    with tempfile.TemporaryDirectory() as src_dir:\n        tar_file = \"tar_archive\"\n        tar_path = os.path.join(src_dir, tar_file + \".tar\")\n        # default encode to utf8\n        content = \"test extract archive tar\\n\".encode()\n        info = tarfile.TarInfo(name=\"tar_archive\")\n        info.size = len(content)\n        with tarfile.open(tar_path, \"w\") as f:\n            f.addfile(info, io.BytesIO(content))\n        with tempfile.TemporaryDirectory() as dst_dir:\n            data.utils.extract_archive(tar_path, dst_dir, overwrite=True)\n            assert os.path.exists(os.path.join(dst_dir, tar_file))\n\n\ndef _test_construct_graphs_node_ids():\n    from dgl.data.csv_dataset_base import (\n        DGLGraphConstructor,\n        EdgeData,\n        NodeData,\n    )\n\n    num_nodes = 100\n    num_edges = 1000\n\n    # node IDs are required to be unique\n    node_ids = np.random.choice(np.arange(num_nodes / 2), num_nodes)\n    src_ids = np.random.choice(node_ids, size=num_edges)\n    dst_ids = np.random.choice(node_ids, size=num_edges)\n    node_data = NodeData(node_ids, {})\n    edge_data = EdgeData(src_ids, dst_ids, {})\n    expect_except = False\n    try:\n        _, _ = DGLGraphConstructor.construct_graphs(node_data, edge_data)\n    except:\n        expect_except = True\n    assert expect_except\n\n    # node IDs are already labelled from 0~num_nodes-1\n    node_ids = np.arange(num_nodes)\n    np.random.shuffle(node_ids)\n    _, idx = np.unique(node_ids, return_index=True)\n    src_ids = np.random.choice(node_ids, size=num_edges)\n    dst_ids = np.random.choice(node_ids, size=num_edges)\n    node_feat = np.random.rand(num_nodes, 3)\n    node_data = NodeData(node_ids, {\"feat\": node_feat})\n    edge_data = EdgeData(src_ids, dst_ids, {})\n    graphs, data_dict = DGLGraphConstructor.construct_graphs(\n     
   node_data, edge_data\n    )\n    assert len(graphs) == 1\n    assert len(data_dict) == 0\n    g = graphs[0]\n    assert g.is_homogeneous\n    assert g.num_nodes() == len(node_ids)\n    assert g.num_edges() == len(src_ids)\n    assert F.array_equal(\n        F.tensor(node_feat[idx], dtype=F.float32), g.ndata[\"feat\"]\n    )\n\n    # node IDs are mixed with numeric and non-numeric values\n    # homogeneous graph\n    node_ids = [1, 2, 3, \"a\"]\n    src_ids = [1, 2, 3]\n    dst_ids = [\"a\", 1, 2]\n    node_data = NodeData(node_ids, {})\n    edge_data = EdgeData(src_ids, dst_ids, {})\n    graphs, data_dict = DGLGraphConstructor.construct_graphs(\n        node_data, edge_data\n    )\n    assert len(graphs) == 1\n    assert len(data_dict) == 0\n    g = graphs[0]\n    assert g.is_homogeneous\n    assert g.num_nodes() == len(node_ids)\n    assert g.num_edges() == len(src_ids)\n\n    # heterogeneous graph\n    node_ids_user = [1, 2, 3]\n    node_ids_item = [\"a\", \"b\", \"c\"]\n    src_ids = node_ids_user\n    dst_ids = node_ids_item\n    node_data_user = NodeData(node_ids_user, {}, type=\"user\")\n    node_data_item = NodeData(node_ids_item, {}, type=\"item\")\n    edge_data = EdgeData(src_ids, dst_ids, {}, type=(\"user\", \"like\", \"item\"))\n    graphs, data_dict = DGLGraphConstructor.construct_graphs(\n        [node_data_user, node_data_item], edge_data\n    )\n    assert len(graphs) == 1\n    assert len(data_dict) == 0\n    g = graphs[0]\n    assert not g.is_homogeneous\n    assert g.num_nodes(\"user\") == len(node_ids_user)\n    assert g.num_nodes(\"item\") == len(node_ids_item)\n    assert g.num_edges() == len(src_ids)\n\n\ndef _test_construct_graphs_homo():\n    from dgl.data.csv_dataset_base import (\n        DGLGraphConstructor,\n        EdgeData,\n        NodeData,\n    )\n\n    # node_id could be non-sorted, non-numeric.\n    num_nodes = 100\n    num_edges = 1000\n    num_dims = 3\n    node_ids = np.random.choice(\n        np.arange(num_nodes * 2), 
size=num_nodes, replace=False\n    )\n    assert len(node_ids) == num_nodes\n    # to be non-sorted\n    np.random.shuffle(node_ids)\n    # to be non-numeric\n    node_ids = [\"id_{}\".format(id) for id in node_ids]\n    t_ndata = {\n        \"feat\": np.random.rand(num_nodes, num_dims),\n        \"label\": np.random.randint(2, size=num_nodes),\n    }\n    _, u_indices = np.unique(node_ids, return_index=True)\n    ndata = {\n        \"feat\": t_ndata[\"feat\"][u_indices],\n        \"label\": t_ndata[\"label\"][u_indices],\n    }\n    node_data = NodeData(node_ids, t_ndata)\n    src_ids = np.random.choice(node_ids, size=num_edges)\n    dst_ids = np.random.choice(node_ids, size=num_edges)\n    edata = {\n        \"feat\": np.random.rand(num_edges, num_dims),\n        \"label\": np.random.randint(2, size=num_edges),\n    }\n    edge_data = EdgeData(src_ids, dst_ids, edata)\n    graphs, data_dict = DGLGraphConstructor.construct_graphs(\n        node_data, edge_data\n    )\n    assert len(graphs) == 1\n    assert len(data_dict) == 0\n    g = graphs[0]\n    assert g.is_homogeneous\n    assert g.num_nodes() == num_nodes\n    assert g.num_edges() == num_edges\n\n    def assert_data(lhs, rhs):\n        for key, value in lhs.items():\n            assert key in rhs\n            assert F.dtype(rhs[key]) != F.float64\n            assert F.array_equal(\n                F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]\n            )\n\n    assert_data(ndata, g.ndata)\n    assert_data(edata, g.edata)\n\n\ndef _test_construct_graphs_hetero():\n    from dgl.data.csv_dataset_base import (\n        DGLGraphConstructor,\n        EdgeData,\n        NodeData,\n    )\n\n    # node_id/src_id/dst_id could be non-sorted, duplicated, non-numeric.\n    num_nodes = 100\n    num_edges = 1000\n    num_dims = 3\n    ntypes = [\"user\", \"item\"]\n    node_data = []\n    node_ids_dict = {}\n    ndata_dict = {}\n    for ntype in ntypes:\n        node_ids = np.random.choice(\n            
np.arange(num_nodes * 2), size=num_nodes, replace=False\n        )\n        assert len(node_ids) == num_nodes\n        # to be non-sorted\n        np.random.shuffle(node_ids)\n        # to be non-numeric\n        node_ids = [\"id_{}\".format(id) for id in node_ids]\n        t_ndata = {\n            \"feat\": np.random.rand(num_nodes, num_dims),\n            \"label\": np.random.randint(2, size=num_nodes),\n        }\n        _, u_indices = np.unique(node_ids, return_index=True)\n        ndata = {\n            \"feat\": t_ndata[\"feat\"][u_indices],\n            \"label\": t_ndata[\"label\"][u_indices],\n        }\n        node_data.append(NodeData(node_ids, t_ndata, type=ntype))\n        node_ids_dict[ntype] = node_ids\n        ndata_dict[ntype] = ndata\n    etypes = [(\"user\", \"follow\", \"user\"), (\"user\", \"like\", \"item\")]\n    edge_data = []\n    edata_dict = {}\n    for src_type, e_type, dst_type in etypes:\n        src_ids = np.random.choice(node_ids_dict[src_type], size=num_edges)\n        dst_ids = np.random.choice(node_ids_dict[dst_type], size=num_edges)\n        edata = {\n            \"feat\": np.random.rand(num_edges, num_dims),\n            \"label\": np.random.randint(2, size=num_edges),\n        }\n        edge_data.append(\n            EdgeData(src_ids, dst_ids, edata, type=(src_type, e_type, dst_type))\n        )\n        edata_dict[(src_type, e_type, dst_type)] = edata\n    graphs, data_dict = DGLGraphConstructor.construct_graphs(\n        node_data, edge_data\n    )\n    assert len(graphs) == 1\n    assert len(data_dict) == 0\n    g = graphs[0]\n    assert not g.is_homogeneous\n    assert g.num_nodes() == num_nodes * len(ntypes)\n    assert g.num_edges() == num_edges * len(etypes)\n\n    def assert_data(lhs, rhs):\n        for key, value in lhs.items():\n            assert key in rhs\n            assert F.dtype(rhs[key]) != F.float64\n            assert F.array_equal(\n                F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]\n    
        )\n\n    for ntype in g.ntypes:\n        assert g.num_nodes(ntype) == num_nodes\n        assert_data(ndata_dict[ntype], g.nodes[ntype].data)\n    for etype in g.canonical_etypes:\n        assert g.num_edges(etype) == num_edges\n        assert_data(edata_dict[etype], g.edges[etype].data)\n\n\ndef _test_construct_graphs_multiple():\n    from dgl.data.csv_dataset_base import (\n        DGLGraphConstructor,\n        EdgeData,\n        GraphData,\n        NodeData,\n    )\n\n    num_nodes = 100\n    num_edges = 1000\n    num_graphs = 10\n    num_dims = 3\n    node_ids = np.array([], dtype=int)\n    src_ids = np.array([], dtype=int)\n    dst_ids = np.array([], dtype=int)\n    ngraph_ids = np.array([], dtype=int)\n    egraph_ids = np.array([], dtype=int)\n    u_indices = np.array([], dtype=int)\n    for i in range(num_graphs):\n        l_node_ids = np.random.choice(\n            np.arange(num_nodes * 2), size=num_nodes, replace=False\n        )\n        node_ids = np.append(node_ids, l_node_ids)\n        _, l_u_indices = np.unique(l_node_ids, return_index=True)\n        u_indices = np.append(u_indices, l_u_indices)\n        ngraph_ids = np.append(ngraph_ids, np.full(num_nodes, i))\n        src_ids = np.append(\n            src_ids, np.random.choice(l_node_ids, size=num_edges)\n        )\n        dst_ids = np.append(\n            dst_ids, np.random.choice(l_node_ids, size=num_edges)\n        )\n        egraph_ids = np.append(egraph_ids, np.full(num_edges, i))\n    ndata = {\n        \"feat\": np.random.rand(num_nodes * num_graphs, num_dims),\n        \"label\": np.random.randint(2, size=num_nodes * num_graphs),\n    }\n    ngraph_ids = [\"graph_{}\".format(id) for id in ngraph_ids]\n    node_data = NodeData(node_ids, ndata, graph_id=ngraph_ids)\n    egraph_ids = [\"graph_{}\".format(id) for id in egraph_ids]\n    edata = {\n        \"feat\": np.random.rand(num_edges * num_graphs, num_dims),\n        \"label\": np.random.randint(2, size=num_edges * num_graphs),\n    
}\n    edge_data = EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)\n    gdata = {\n        \"feat\": np.random.rand(num_graphs, num_dims),\n        \"label\": np.random.randint(2, size=num_graphs),\n    }\n    graph_ids = [\"graph_{}\".format(id) for id in np.arange(num_graphs)]\n    graph_data = GraphData(graph_ids, gdata)\n    graphs, data_dict = DGLGraphConstructor.construct_graphs(\n        node_data, edge_data, graph_data\n    )\n    assert len(graphs) == num_graphs\n    assert len(data_dict) == len(gdata)\n    for k, v in data_dict.items():\n        assert F.dtype(v) != F.float64\n        assert F.array_equal(\n            F.reshape(F.tensor(gdata[k], dtype=F.dtype(v)), (len(graphs), -1)),\n            v,\n        )\n    for i, g in enumerate(graphs):\n        assert g.is_homogeneous\n        assert g.num_nodes() == num_nodes\n        assert g.num_edges() == num_edges\n\n        def assert_data(lhs, rhs, size, node=False):\n            for key, value in lhs.items():\n                assert key in rhs\n                value = value[i * size : (i + 1) * size]\n                if node:\n                    indices = u_indices[i * size : (i + 1) * size]\n                    value = value[indices]\n                assert F.dtype(rhs[key]) != F.float64\n                assert F.array_equal(\n                    F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]\n                )\n\n        assert_data(ndata, g.ndata, num_nodes, node=True)\n        assert_data(edata, g.edata, num_edges)\n\n    # Graph IDs found in node/edge CSV but not in graph CSV\n    graph_data = GraphData(np.arange(num_graphs - 2), {})\n    expect_except = False\n    try:\n        _, _ = DGLGraphConstructor.construct_graphs(\n            node_data, edge_data, graph_data\n        )\n    except:\n        expect_except = True\n    assert expect_except\n\n\ndef _get_data_table(data_frame, save_index=False):\n    from dgl.data.csv_dataset_base import DefaultDataParser\n\n    with 
tempfile.TemporaryDirectory() as test_dir:\n        csv_path = os.path.join(test_dir, \"nodes.csv\")\n\n        data_frame.to_csv(csv_path, index=save_index)\n        dp = DefaultDataParser()\n        df = pd.read_csv(csv_path)\n\n    # Warning suppression : \"Untitled column found. Ignored...\",\n    # which appears when a CSV file is saved with an index:\n    #    data_frame.to_csv(csv_path, index=True).\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        return dp(df)\n\n\ndef _test_DefaultDataParser():\n    # common csv\n    num_nodes = 5\n    num_labels = 3\n    num_dims = 2\n    node_id = np.arange(num_nodes)\n    label = np.random.randint(num_labels, size=num_nodes)\n    feat = np.random.rand(num_nodes, num_dims)\n    df = pd.DataFrame(\n        {\n            \"node_id\": node_id,\n            \"label\": label,\n            \"feat\": [line.tolist() for line in feat],\n        }\n    )\n\n    dt = _get_data_table(df)\n    assert np.array_equal(node_id, dt[\"node_id\"])\n    assert np.array_equal(label, dt[\"label\"])\n    assert np.array_equal(feat, dt[\"feat\"])\n\n    # string consists of non-numeric values\n    df = pd.DataFrame({\"label\": [\"a\", \"b\", \"c\"]})\n    expect_except = False\n    try:\n        _get_data_table(df)\n    except:\n        expect_except = True\n    assert expect_except\n\n    # csv has index column which is ignored as it's unnamed\n    df = pd.DataFrame({\"label\": [1, 2, 3]})\n    dt = _get_data_table(df, True)\n    assert len(dt) == 1\n\n\ndef _test_load_yaml_with_sanity_check():\n    from dgl.data.csv_dataset_base import load_yaml_with_sanity_check\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        yaml_path = os.path.join(test_dir, \"meta.yaml\")\n        # workable but meaningless usually\n        yaml_data = {\n            \"dataset_name\": \"default\",\n            \"node_data\": [],\n            \"edge_data\": [],\n        }\n        with 
open(yaml_path, \"w\") as f:\n            yaml.dump(yaml_data, f, sort_keys=False)\n        meta = load_yaml_with_sanity_check(yaml_path)\n        assert meta.version == \"1.0.0\"\n        assert meta.dataset_name == \"default\"\n        assert meta.separator == \",\"\n        assert len(meta.node_data) == 0\n        assert len(meta.edge_data) == 0\n        assert meta.graph_data is None\n        # minimum with required fields only\n        yaml_data = {\n            \"version\": \"1.0.0\",\n            \"dataset_name\": \"default\",\n            \"node_data\": [{\"file_name\": \"nodes.csv\"}],\n            \"edge_data\": [{\"file_name\": \"edges.csv\"}],\n        }\n        with open(yaml_path, \"w\") as f:\n            yaml.dump(yaml_data, f, sort_keys=False)\n        meta = load_yaml_with_sanity_check(yaml_path)\n        for ndata in meta.node_data:\n            assert ndata.file_name == \"nodes.csv\"\n            assert ndata.ntype == \"_V\"\n            assert ndata.graph_id_field == \"graph_id\"\n            assert ndata.node_id_field == \"node_id\"\n        for edata in meta.edge_data:\n            assert edata.file_name == \"edges.csv\"\n            assert edata.etype == [\"_V\", \"_E\", \"_V\"]\n            assert edata.graph_id_field == \"graph_id\"\n            assert edata.src_id_field == \"src_id\"\n            assert edata.dst_id_field == \"dst_id\"\n        # optional fields are specified\n        yaml_data = {\n            \"version\": \"1.0.0\",\n            \"dataset_name\": \"default\",\n            \"separator\": \"|\",\n            \"node_data\": [\n                {\n                    \"file_name\": \"nodes.csv\",\n                    \"ntype\": \"user\",\n                    \"graph_id_field\": \"xxx\",\n                    \"node_id_field\": \"xxx\",\n                }\n            ],\n            \"edge_data\": [\n                {\n                    \"file_name\": \"edges.csv\",\n                    \"etype\": [\"user\", \"follow\", 
\"user\"],\n                    \"graph_id_field\": \"xxx\",\n                    \"src_id_field\": \"xxx\",\n                    \"dst_id_field\": \"xxx\",\n                }\n            ],\n            \"graph_data\": {\"file_name\": \"graph.csv\", \"graph_id_field\": \"xxx\"},\n        }\n        with open(yaml_path, \"w\") as f:\n            yaml.dump(yaml_data, f, sort_keys=False)\n        meta = load_yaml_with_sanity_check(yaml_path)\n        assert len(meta.node_data) == 1\n        ndata = meta.node_data[0]\n        assert ndata.ntype == \"user\"\n        assert ndata.graph_id_field == \"xxx\"\n        assert ndata.node_id_field == \"xxx\"\n        assert len(meta.edge_data) == 1\n        edata = meta.edge_data[0]\n        assert edata.etype == [\"user\", \"follow\", \"user\"]\n        assert edata.graph_id_field == \"xxx\"\n        assert edata.src_id_field == \"xxx\"\n        assert edata.dst_id_field == \"xxx\"\n        assert meta.graph_data is not None\n        assert meta.graph_data.file_name == \"graph.csv\"\n        assert meta.graph_data.graph_id_field == \"xxx\"\n        # some required fields are missing\n        yaml_data = {\n            \"dataset_name\": \"default\",\n            \"node_data\": [],\n            \"edge_data\": [],\n        }\n        for field in yaml_data.keys():\n            ydata = {k: v for k, v in yaml_data.items()}\n            ydata.pop(field)\n            with open(yaml_path, \"w\") as f:\n                yaml.dump(ydata, f, sort_keys=False)\n            expect_except = False\n            try:\n                meta = load_yaml_with_sanity_check(yaml_path)\n            except:\n                expect_except = True\n            assert expect_except\n        # inapplicable version\n        yaml_data = {\n            \"version\": \"0.0.0\",\n            \"dataset_name\": \"default\",\n            \"node_data\": [{\"file_name\": \"nodes_0.csv\"}],\n            \"edge_data\": [{\"file_name\": \"edges_0.csv\"}],\n        }\n   
     with open(yaml_path, \"w\") as f:\n            yaml.dump(yaml_data, f, sort_keys=False)\n        expect_except = False\n        try:\n            meta = load_yaml_with_sanity_check(yaml_path)\n        except DGLError:\n            expect_except = True\n        assert expect_except\n        # duplicate node types\n        yaml_data = {\n            \"version\": \"1.0.0\",\n            \"dataset_name\": \"default\",\n            \"node_data\": [\n                {\"file_name\": \"nodes.csv\"},\n                {\"file_name\": \"nodes.csv\"},\n            ],\n            \"edge_data\": [{\"file_name\": \"edges.csv\"}],\n        }\n        with open(yaml_path, \"w\") as f:\n            yaml.dump(yaml_data, f, sort_keys=False)\n        expect_except = False\n        try:\n            meta = load_yaml_with_sanity_check(yaml_path)\n        except DGLError:\n            expect_except = True\n        assert expect_except\n        # duplicate edge types\n        yaml_data = {\n            \"version\": \"1.0.0\",\n            \"dataset_name\": \"default\",\n            \"node_data\": [{\"file_name\": \"nodes.csv\"}],\n            \"edge_data\": [\n                {\"file_name\": \"edges.csv\"},\n                {\"file_name\": \"edges.csv\"},\n            ],\n        }\n        with open(yaml_path, \"w\") as f:\n            yaml.dump(yaml_data, f, sort_keys=False)\n        expect_except = False\n        try:\n            meta = load_yaml_with_sanity_check(yaml_path)\n        except DGLError:\n            expect_except = True\n        assert expect_except\n\n\ndef _test_load_node_data_from_csv():\n    from dgl.data.csv_dataset_base import DefaultDataParser, MetaNode, NodeData\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        num_nodes = 100\n        # minimum\n        df = pd.DataFrame({\"node_id\": np.arange(num_nodes)})\n        csv_path = os.path.join(test_dir, \"nodes.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_node = 
MetaNode(file_name=csv_path)\n        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())\n        assert np.array_equal(df[\"node_id\"], node_data.id)\n        assert len(node_data.data) == 0\n\n        # common case\n        df = pd.DataFrame(\n            {\n                \"node_id\": np.arange(num_nodes),\n                \"label\": np.random.randint(3, size=num_nodes),\n            }\n        )\n        csv_path = os.path.join(test_dir, \"nodes.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_node = MetaNode(file_name=csv_path)\n        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())\n        assert np.array_equal(df[\"node_id\"], node_data.id)\n        assert len(node_data.data) == 1\n        assert np.array_equal(df[\"label\"], node_data.data[\"label\"])\n        assert np.array_equal(np.full(num_nodes, 0), node_data.graph_id)\n        assert node_data.type == \"_V\"\n\n        # add more fields into nodes.csv\n        df = pd.DataFrame(\n            {\n                \"node_id\": np.arange(num_nodes),\n                \"label\": np.random.randint(3, size=num_nodes),\n                \"graph_id\": np.full(num_nodes, 1),\n            }\n        )\n        csv_path = os.path.join(test_dir, \"nodes.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_node = MetaNode(file_name=csv_path)\n        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())\n        assert np.array_equal(df[\"node_id\"], node_data.id)\n        assert len(node_data.data) == 1\n        assert np.array_equal(df[\"label\"], node_data.data[\"label\"])\n        assert np.array_equal(df[\"graph_id\"], node_data.graph_id)\n        assert node_data.type == \"_V\"\n\n        # required header is missing\n        df = pd.DataFrame({\"label\": np.random.randint(3, size=num_nodes)})\n        csv_path = os.path.join(test_dir, \"nodes.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_node = 
MetaNode(file_name=csv_path)\n        expect_except = False\n        try:\n            NodeData.load_from_csv(meta_node, DefaultDataParser())\n        except:\n            expect_except = True\n        assert expect_except\n\n\ndef _test_load_edge_data_from_csv():\n    from dgl.data.csv_dataset_base import DefaultDataParser, EdgeData, MetaEdge\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        num_nodes = 100\n        num_edges = 1000\n        # minimum\n        df = pd.DataFrame(\n            {\n                \"src_id\": np.random.randint(num_nodes, size=num_edges),\n                \"dst_id\": np.random.randint(num_nodes, size=num_edges),\n            }\n        )\n        csv_path = os.path.join(test_dir, \"edges.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_edge = MetaEdge(file_name=csv_path)\n        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())\n        assert np.array_equal(df[\"src_id\"], edge_data.src)\n        assert np.array_equal(df[\"dst_id\"], edge_data.dst)\n        assert len(edge_data.data) == 0\n\n        # common case\n        df = pd.DataFrame(\n            {\n                \"src_id\": np.random.randint(num_nodes, size=num_edges),\n                \"dst_id\": np.random.randint(num_nodes, size=num_edges),\n                \"label\": np.random.randint(3, size=num_edges),\n            }\n        )\n        csv_path = os.path.join(test_dir, \"edges.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_edge = MetaEdge(file_name=csv_path)\n        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())\n        assert np.array_equal(df[\"src_id\"], edge_data.src)\n        assert np.array_equal(df[\"dst_id\"], edge_data.dst)\n        assert len(edge_data.data) == 1\n        assert np.array_equal(df[\"label\"], edge_data.data[\"label\"])\n        assert np.array_equal(np.full(num_edges, 0), edge_data.graph_id)\n        assert edge_data.type == (\"_V\", \"_E\", \"_V\")\n\n        
# add more fields into edges.csv\n        df = pd.DataFrame(\n            {\n                \"src_id\": np.random.randint(num_nodes, size=num_edges),\n                \"dst_id\": np.random.randint(num_nodes, size=num_edges),\n                \"graph_id\": np.arange(num_edges),\n                \"feat\": np.random.randint(3, size=num_edges),\n                \"label\": np.random.randint(3, size=num_edges),\n            }\n        )\n        csv_path = os.path.join(test_dir, \"edges.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_edge = MetaEdge(file_name=csv_path)\n        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())\n        assert np.array_equal(df[\"src_id\"], edge_data.src)\n        assert np.array_equal(df[\"dst_id\"], edge_data.dst)\n        assert len(edge_data.data) == 2\n        assert np.array_equal(df[\"feat\"], edge_data.data[\"feat\"])\n        assert np.array_equal(df[\"label\"], edge_data.data[\"label\"])\n        assert np.array_equal(df[\"graph_id\"], edge_data.graph_id)\n        assert edge_data.type == (\"_V\", \"_E\", \"_V\")\n\n        # required headers are missing\n        df = pd.DataFrame(\n            {\"src_id\": np.random.randint(num_nodes, size=num_edges)}\n        )\n        csv_path = os.path.join(test_dir, \"edges.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_edge = MetaEdge(file_name=csv_path)\n        expect_except = False\n        try:\n            EdgeData.load_from_csv(meta_edge, DefaultDataParser())\n        except DGLError:\n            expect_except = True\n        assert expect_except\n        df = pd.DataFrame(\n            {\"dst_id\": np.random.randint(num_nodes, size=num_edges)}\n        )\n        csv_path = os.path.join(test_dir, \"edges.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_edge = MetaEdge(file_name=csv_path)\n        expect_except = False\n        try:\n            EdgeData.load_from_csv(meta_edge, DefaultDataParser())\n        except 
DGLError:\n            expect_except = True\n        assert expect_except\n\n\ndef _test_load_graph_data_from_csv():\n    from dgl.data.csv_dataset_base import (\n        DefaultDataParser,\n        GraphData,\n        MetaGraph,\n    )\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        num_graphs = 100\n        # minimum\n        df = pd.DataFrame({\"graph_id\": np.arange(num_graphs)})\n        csv_path = os.path.join(test_dir, \"graph.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_graph = MetaGraph(file_name=csv_path)\n        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())\n        assert np.array_equal(df[\"graph_id\"], graph_data.graph_id)\n        assert len(graph_data.data) == 0\n\n        # common case\n        df = pd.DataFrame(\n            {\n                \"graph_id\": np.arange(num_graphs),\n                \"label\": np.random.randint(3, size=num_graphs),\n            }\n        )\n        csv_path = os.path.join(test_dir, \"graph.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_graph = MetaGraph(file_name=csv_path)\n        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())\n        assert np.array_equal(df[\"graph_id\"], graph_data.graph_id)\n        assert len(graph_data.data) == 1\n        assert np.array_equal(df[\"label\"], graph_data.data[\"label\"])\n\n        # add more fields into graph.csv\n        df = pd.DataFrame(\n            {\n                \"graph_id\": np.arange(num_graphs),\n                \"feat\": np.random.randint(3, size=num_graphs),\n                \"label\": np.random.randint(3, size=num_graphs),\n            }\n        )\n        csv_path = os.path.join(test_dir, \"graph.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_graph = MetaGraph(file_name=csv_path)\n        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())\n        assert np.array_equal(df[\"graph_id\"], graph_data.graph_id)\n        assert 
len(graph_data.data) == 2\n        assert np.array_equal(df[\"feat\"], graph_data.data[\"feat\"])\n        assert np.array_equal(df[\"label\"], graph_data.data[\"label\"])\n\n        # required header is missing\n        df = pd.DataFrame({\"label\": np.random.randint(3, size=num_graphs)})\n        csv_path = os.path.join(test_dir, \"graph.csv\")\n        df.to_csv(csv_path, index=False)\n        meta_graph = MetaGraph(file_name=csv_path)\n        expect_except = False\n        try:\n            GraphData.load_from_csv(meta_graph, DefaultDataParser())\n        except DGLError:\n            expect_except = True\n        assert expect_except\n\n\ndef _test_CSVDataset_single():\n    with tempfile.TemporaryDirectory() as test_dir:\n        # generate YAML/CSVs\n        meta_yaml_path = os.path.join(test_dir, \"meta.yaml\")\n        edges_csv_path_0 = os.path.join(test_dir, \"test_edges_0.csv\")\n        edges_csv_path_1 = os.path.join(test_dir, \"test_edges_1.csv\")\n        nodes_csv_path_0 = os.path.join(test_dir, \"test_nodes_0.csv\")\n        nodes_csv_path_1 = os.path.join(test_dir, \"test_nodes_1.csv\")\n        meta_yaml_data = {\n            \"version\": \"1.0.0\",\n            \"dataset_name\": \"default_name\",\n            \"node_data\": [\n                {\n                    \"file_name\": os.path.basename(nodes_csv_path_0),\n                    \"ntype\": \"user\",\n                },\n                {\n                    \"file_name\": os.path.basename(nodes_csv_path_1),\n                    \"ntype\": \"item\",\n                },\n            ],\n            \"edge_data\": [\n                {\n                    \"file_name\": os.path.basename(edges_csv_path_0),\n                    \"etype\": [\"user\", \"follow\", \"user\"],\n                },\n                {\n                    \"file_name\": os.path.basename(edges_csv_path_1),\n                    \"etype\": [\"user\", \"like\", \"item\"],\n                },\n            ],\n        }\n 
       with open(meta_yaml_path, \"w\") as f:\n            yaml.dump(meta_yaml_data, f, sort_keys=False)\n        num_nodes = 100\n        num_edges = 500\n        num_dims = 3\n        feat_ndata = np.random.rand(num_nodes, num_dims)\n        label_ndata = np.random.randint(2, size=num_nodes)\n        df = pd.DataFrame(\n            {\n                \"node_id\": np.arange(num_nodes),\n                \"label\": label_ndata,\n                \"feat\": [line.tolist() for line in feat_ndata],\n            }\n        )\n        df.to_csv(nodes_csv_path_0, index=False)\n        df.to_csv(nodes_csv_path_1, index=False)\n        feat_edata = np.random.rand(num_edges, num_dims)\n        label_edata = np.random.randint(2, size=num_edges)\n        df = pd.DataFrame(\n            {\n                \"src_id\": np.random.randint(num_nodes, size=num_edges),\n                \"dst_id\": np.random.randint(num_nodes, size=num_edges),\n                \"label\": label_edata,\n                \"feat\": [line.tolist() for line in feat_edata],\n            }\n        )\n        df.to_csv(edges_csv_path_0, index=False)\n        df.to_csv(edges_csv_path_1, index=False)\n\n        # load CSVDataset\n        for force_reload in [True, False]:\n            if not force_reload:\n                # remove original node data file to verify reload from cached files\n                os.remove(nodes_csv_path_0)\n                assert not os.path.exists(nodes_csv_path_0)\n            csv_dataset = data.CSVDataset(test_dir, force_reload=force_reload)\n            assert len(csv_dataset) == 1\n            g = csv_dataset[0]\n            assert not g.is_homogeneous\n            assert csv_dataset.has_cache()\n            for ntype in g.ntypes:\n                assert g.num_nodes(ntype) == num_nodes\n                assert F.array_equal(\n                    F.tensor(feat_ndata, dtype=F.float32),\n                    g.nodes[ntype].data[\"feat\"],\n                )\n                assert 
np.array_equal(\n                    label_ndata, F.asnumpy(g.nodes[ntype].data[\"label\"])\n                )\n            for etype in g.etypes:\n                assert g.num_edges(etype) == num_edges\n                assert F.array_equal(\n                    F.tensor(feat_edata, dtype=F.float32),\n                    g.edges[etype].data[\"feat\"],\n                )\n                assert np.array_equal(\n                    label_edata, F.asnumpy(g.edges[etype].data[\"label\"])\n                )\n\n\ndef _test_CSVDataset_multiple():\n    with tempfile.TemporaryDirectory() as test_dir:\n        # generate YAML/CSVs\n        meta_yaml_path = os.path.join(test_dir, \"meta.yaml\")\n        edges_csv_path_0 = os.path.join(test_dir, \"test_edges_0.csv\")\n        edges_csv_path_1 = os.path.join(test_dir, \"test_edges_1.csv\")\n        nodes_csv_path_0 = os.path.join(test_dir, \"test_nodes_0.csv\")\n        nodes_csv_path_1 = os.path.join(test_dir, \"test_nodes_1.csv\")\n        graph_csv_path = os.path.join(test_dir, \"test_graph.csv\")\n        meta_yaml_data = {\n            \"version\": \"1.0.0\",\n            \"dataset_name\": \"default_name\",\n            \"node_data\": [\n                {\n                    \"file_name\": os.path.basename(nodes_csv_path_0),\n                    \"ntype\": \"user\",\n                },\n                {\n                    \"file_name\": os.path.basename(nodes_csv_path_1),\n                    \"ntype\": \"item\",\n                },\n            ],\n            \"edge_data\": [\n                {\n                    \"file_name\": os.path.basename(edges_csv_path_0),\n                    \"etype\": [\"user\", \"follow\", \"user\"],\n                },\n                {\n                    \"file_name\": os.path.basename(edges_csv_path_1),\n                    \"etype\": [\"user\", \"like\", \"item\"],\n                },\n            ],\n            \"graph_data\": {\"file_name\": os.path.basename(graph_csv_path)},\n 
       }\n        with open(meta_yaml_path, \"w\") as f:\n            yaml.dump(meta_yaml_data, f, sort_keys=False)\n        num_nodes = 100\n        num_edges = 500\n        num_graphs = 10\n        num_dims = 3\n        feat_ndata = np.random.rand(num_nodes * num_graphs, num_dims)\n        label_ndata = np.random.randint(2, size=num_nodes * num_graphs)\n        df = pd.DataFrame(\n            {\n                \"node_id\": np.hstack(\n                    [np.arange(num_nodes) for _ in range(num_graphs)]\n                ),\n                \"label\": label_ndata,\n                \"feat\": [line.tolist() for line in feat_ndata],\n                \"graph_id\": np.hstack(\n                    [np.full(num_nodes, i) for i in range(num_graphs)]\n                ),\n            }\n        )\n        df.to_csv(nodes_csv_path_0, index=False)\n        df.to_csv(nodes_csv_path_1, index=False)\n        feat_edata = np.random.rand(num_edges * num_graphs, num_dims)\n        label_edata = np.random.randint(2, size=num_edges * num_graphs)\n        df = pd.DataFrame(\n            {\n                \"src_id\": np.hstack(\n                    [\n                        np.random.randint(num_nodes, size=num_edges)\n                        for _ in range(num_graphs)\n                    ]\n                ),\n                \"dst_id\": np.hstack(\n                    [\n                        np.random.randint(num_nodes, size=num_edges)\n                        for _ in range(num_graphs)\n                    ]\n                ),\n                \"label\": label_edata,\n                \"feat\": [line.tolist() for line in feat_edata],\n                \"graph_id\": np.hstack(\n                    [np.full(num_edges, i) for i in range(num_graphs)]\n                ),\n            }\n        )\n        df.to_csv(edges_csv_path_0, index=False)\n        df.to_csv(edges_csv_path_1, index=False)\n        feat_gdata = np.random.rand(num_graphs, num_dims)\n        label_gdata = 
np.random.randint(2, size=num_graphs)\n        df = pd.DataFrame(\n            {\n                \"label\": label_gdata,\n                \"feat\": [line.tolist() for line in feat_gdata],\n                \"graph_id\": np.arange(num_graphs),\n            }\n        )\n        df.to_csv(graph_csv_path, index=False)\n\n        # load CSVDataset with default node/edge/gdata_parser\n        for force_reload in [True, False]:\n            if not force_reload:\n                # remove original node data file to verify reload from cached files\n                os.remove(nodes_csv_path_0)\n                assert not os.path.exists(nodes_csv_path_0)\n            csv_dataset = data.CSVDataset(test_dir, force_reload=force_reload)\n            assert len(csv_dataset) == num_graphs\n            assert csv_dataset.has_cache()\n            assert len(csv_dataset.data) == 2\n            assert \"feat\" in csv_dataset.data\n            assert \"label\" in csv_dataset.data\n            assert F.array_equal(\n                F.tensor(feat_gdata, dtype=F.float32), csv_dataset.data[\"feat\"]\n            )\n            for i, (g, g_data) in enumerate(csv_dataset):\n                assert not g.is_homogeneous\n                assert F.asnumpy(g_data[\"label\"]) == label_gdata[i]\n                assert F.array_equal(\n                    g_data[\"feat\"], F.tensor(feat_gdata[i], dtype=F.float32)\n                )\n                for ntype in g.ntypes:\n                    assert g.num_nodes(ntype) == num_nodes\n                    assert F.array_equal(\n                        F.tensor(\n                            feat_ndata[i * num_nodes : (i + 1) * num_nodes],\n                            dtype=F.float32,\n                        ),\n                        g.nodes[ntype].data[\"feat\"],\n                    )\n                    assert np.array_equal(\n                        label_ndata[i * num_nodes : (i + 1) * num_nodes],\n                        
F.asnumpy(g.nodes[ntype].data[\"label\"]),\n                    )\n                for etype in g.etypes:\n                    assert g.num_edges(etype) == num_edges\n                    assert F.array_equal(\n                        F.tensor(\n                            feat_edata[i * num_edges : (i + 1) * num_edges],\n                            dtype=F.float32,\n                        ),\n                        g.edges[etype].data[\"feat\"],\n                    )\n                    assert np.array_equal(\n                        label_edata[i * num_edges : (i + 1) * num_edges],\n                        F.asnumpy(g.edges[etype].data[\"label\"]),\n                    )\n\n\ndef _test_CSVDataset_customized_data_parser():\n    with tempfile.TemporaryDirectory() as test_dir:\n        # generate YAML/CSVs\n        meta_yaml_path = os.path.join(test_dir, \"meta.yaml\")\n        edges_csv_path_0 = os.path.join(test_dir, \"test_edges_0.csv\")\n        edges_csv_path_1 = os.path.join(test_dir, \"test_edges_1.csv\")\n        nodes_csv_path_0 = os.path.join(test_dir, \"test_nodes_0.csv\")\n        nodes_csv_path_1 = os.path.join(test_dir, \"test_nodes_1.csv\")\n        graph_csv_path = os.path.join(test_dir, \"test_graph.csv\")\n        meta_yaml_data = {\n            \"dataset_name\": \"default_name\",\n            \"node_data\": [\n                {\n                    \"file_name\": os.path.basename(nodes_csv_path_0),\n                    \"ntype\": \"user\",\n                },\n                {\n                    \"file_name\": os.path.basename(nodes_csv_path_1),\n                    \"ntype\": \"item\",\n                },\n            ],\n            \"edge_data\": [\n                {\n                    \"file_name\": os.path.basename(edges_csv_path_0),\n                    \"etype\": [\"user\", \"follow\", \"user\"],\n                },\n                {\n                    \"file_name\": os.path.basename(edges_csv_path_1),\n                    
\"etype\": [\"user\", \"like\", \"item\"],\n                },\n            ],\n            \"graph_data\": {\"file_name\": os.path.basename(graph_csv_path)},\n        }\n        with open(meta_yaml_path, \"w\") as f:\n            yaml.dump(meta_yaml_data, f, sort_keys=False)\n        num_nodes = 100\n        num_edges = 500\n        num_graphs = 10\n        label_ndata = np.random.randint(2, size=num_nodes * num_graphs)\n        df = pd.DataFrame(\n            {\n                \"node_id\": np.hstack(\n                    [np.arange(num_nodes) for _ in range(num_graphs)]\n                ),\n                \"label\": label_ndata,\n                \"graph_id\": np.hstack(\n                    [np.full(num_nodes, i) for i in range(num_graphs)]\n                ),\n            }\n        )\n        df.to_csv(nodes_csv_path_0, index=False)\n        df.to_csv(nodes_csv_path_1, index=False)\n        label_edata = np.random.randint(2, size=num_edges * num_graphs)\n        df = pd.DataFrame(\n            {\n                \"src_id\": np.hstack(\n                    [\n                        np.random.randint(num_nodes, size=num_edges)\n                        for _ in range(num_graphs)\n                    ]\n                ),\n                \"dst_id\": np.hstack(\n                    [\n                        np.random.randint(num_nodes, size=num_edges)\n                        for _ in range(num_graphs)\n                    ]\n                ),\n                \"label\": label_edata,\n                \"graph_id\": np.hstack(\n                    [np.full(num_edges, i) for i in range(num_graphs)]\n                ),\n            }\n        )\n        df.to_csv(edges_csv_path_0, index=False)\n        df.to_csv(edges_csv_path_1, index=False)\n        label_gdata = np.random.randint(2, size=num_graphs)\n        df = pd.DataFrame(\n            {\"label\": label_gdata, \"graph_id\": np.arange(num_graphs)}\n        )\n        df.to_csv(graph_csv_path, 
index=False)\n\n        class CustDataParser:\n            def __call__(self, df):\n                data = {}\n                for header in df:\n                    dt = df[header].to_numpy().squeeze()\n                    if header == \"label\":\n                        dt += 2\n                    data[header] = dt\n                return data\n\n        # load CSVDataset with customized node/edge/gdata_parser\n        # specify via dict[ntype/etype, callable]\n        csv_dataset = data.CSVDataset(\n            test_dir,\n            force_reload=True,\n            ndata_parser={\"user\": CustDataParser()},\n            edata_parser={(\"user\", \"like\", \"item\"): CustDataParser()},\n            gdata_parser=CustDataParser(),\n        )\n        assert len(csv_dataset) == num_graphs\n        assert len(csv_dataset.data) == 1\n        assert \"label\" in csv_dataset.data\n        for i, (g, g_data) in enumerate(csv_dataset):\n            assert not g.is_homogeneous\n            assert F.asnumpy(g_data) == label_gdata[i] + 2\n            for ntype in g.ntypes:\n                assert g.num_nodes(ntype) == num_nodes\n                offset = 2 if ntype == \"user\" else 0\n                assert np.array_equal(\n                    label_ndata[i * num_nodes : (i + 1) * num_nodes] + offset,\n                    F.asnumpy(g.nodes[ntype].data[\"label\"]),\n                )\n            for etype in g.etypes:\n                assert g.num_edges(etype) == num_edges\n                offset = 2 if etype == \"like\" else 0\n                assert np.array_equal(\n                    label_edata[i * num_edges : (i + 1) * num_edges] + offset,\n                    F.asnumpy(g.edges[etype].data[\"label\"]),\n                )\n        # specify via callable\n        csv_dataset = data.CSVDataset(\n            test_dir,\n            force_reload=True,\n            ndata_parser=CustDataParser(),\n            edata_parser=CustDataParser(),\n            
gdata_parser=CustDataParser(),\n        )\n        assert len(csv_dataset) == num_graphs\n        assert len(csv_dataset.data) == 1\n        assert \"label\" in csv_dataset.data\n        for i, (g, g_data) in enumerate(csv_dataset):\n            assert not g.is_homogeneous\n            assert F.asnumpy(g_data) == label_gdata[i] + 2\n            for ntype in g.ntypes:\n                assert g.num_nodes(ntype) == num_nodes\n                offset = 2\n                assert np.array_equal(\n                    label_ndata[i * num_nodes : (i + 1) * num_nodes] + offset,\n                    F.asnumpy(g.nodes[ntype].data[\"label\"]),\n                )\n            for etype in g.etypes:\n                assert g.num_edges(etype) == num_edges\n                offset = 2\n                assert np.array_equal(\n                    label_edata[i * num_edges : (i + 1) * num_edges] + offset,\n                    F.asnumpy(g.edges[etype].data[\"label\"]),\n                )\n\n\ndef _test_NodeEdgeGraphData():\n    from dgl.data.csv_dataset_base import EdgeData, GraphData, NodeData\n\n    # NodeData basics\n    num_nodes = 100\n    node_ids = np.arange(num_nodes, dtype=float)\n    ndata = NodeData(node_ids, {})\n    assert np.array_equal(ndata.id, node_ids)\n    assert len(ndata.data) == 0\n    assert ndata.type == \"_V\"\n    assert np.array_equal(ndata.graph_id, np.full(num_nodes, 0))\n    # NodeData more\n    data = {\"feat\": np.random.rand(num_nodes, 3)}\n    graph_id = np.arange(num_nodes)\n    ndata = NodeData(node_ids, data, type=\"user\", graph_id=graph_id)\n    assert ndata.type == \"user\"\n    assert np.array_equal(ndata.graph_id, graph_id)\n    assert len(ndata.data) == len(data)\n    for k, v in data.items():\n        assert k in ndata.data\n        assert np.array_equal(ndata.data[k], v)\n    # NodeData except\n    expect_except = False\n    try:\n        NodeData(\n            np.arange(num_nodes),\n            {\"feat\": np.random.rand(num_nodes + 1, 3)},\n  
          graph_id=np.arange(num_nodes - 1),\n        )\n    except:\n        expect_except = True\n    assert expect_except\n\n    # EdgeData basics\n    num_nodes = 100\n    num_edges = 1000\n    src_ids = np.random.randint(num_nodes, size=num_edges)\n    dst_ids = np.random.randint(num_nodes, size=num_edges)\n    edata = EdgeData(src_ids, dst_ids, {})\n    assert np.array_equal(edata.src, src_ids)\n    assert np.array_equal(edata.dst, dst_ids)\n    assert edata.type == (\"_V\", \"_E\", \"_V\")\n    assert len(edata.data) == 0\n    assert np.array_equal(edata.graph_id, np.full(num_edges, 0))\n    # EdgeData more\n    src_ids = np.random.randint(num_nodes, size=num_edges).astype(float)\n    dst_ids = np.random.randint(num_nodes, size=num_edges).astype(float)\n    data = {\"feat\": np.random.rand(num_edges, 3)}\n    etype = (\"user\", \"like\", \"item\")\n    graph_ids = np.arange(num_edges)\n    edata = EdgeData(src_ids, dst_ids, data, type=etype, graph_id=graph_ids)\n    assert np.array_equal(edata.src, src_ids)\n    assert np.array_equal(edata.dst, dst_ids)\n    assert edata.type == etype\n    assert len(edata.data) == len(data)\n    for k, v in data.items():\n        assert k in edata.data\n        assert np.array_equal(edata.data[k], v)\n    assert np.array_equal(edata.graph_id, graph_ids)\n    # EdgeData except\n    expect_except = False\n    try:\n        EdgeData(\n            np.arange(num_edges),\n            np.arange(num_edges + 1),\n            {\"feat\": np.random.rand(num_edges - 1, 3)},\n            graph_id=np.arange(num_edges + 2),\n        )\n    except:\n        expect_except = True\n    assert expect_except\n\n    # GraphData basics\n    num_graphs = 10\n    graph_ids = np.arange(num_graphs)\n    gdata = GraphData(graph_ids, {})\n    assert np.array_equal(gdata.graph_id, graph_ids)\n    assert len(gdata.data) == 0\n    # GraphData more\n    graph_ids = np.arange(num_graphs).astype(float)\n    data = {\"feat\": np.random.rand(num_graphs, 3)}\n  
  gdata = GraphData(graph_ids, data)\n    assert np.array_equal(gdata.graph_id, graph_ids)\n    assert len(gdata.data) == len(data)\n    for k, v in data.items():\n        assert k in gdata.data\n        assert np.array_equal(gdata.data[k], v)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\", reason=\"Skip Tensorflow\"\n)\ndef test_csvdataset():\n    _test_NodeEdgeGraphData()\n    _test_construct_graphs_node_ids()\n    _test_construct_graphs_homo()\n    _test_construct_graphs_hetero()\n    _test_construct_graphs_multiple()\n    _test_DefaultDataParser()\n    _test_load_yaml_with_sanity_check()\n    _test_load_node_data_from_csv()\n    _test_load_edge_data_from_csv()\n    _test_load_graph_data_from_csv()\n    _test_CSVDataset_single()\n    _test_CSVDataset_multiple()\n    _test_CSVDataset_customized_data_parser()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_as_nodepred1():\n    ds = data.AmazonCoBuyComputerDataset()\n    print(\"train_mask\" in ds[0].ndata)\n    new_ds = data.AsNodePredDataset(ds, [0.8, 0.1, 0.1], verbose=True)\n    assert len(new_ds) == 1\n    assert new_ds[0].num_nodes() == ds[0].num_nodes()\n    assert new_ds[0].num_edges() == ds[0].num_edges()\n    assert \"train_mask\" in new_ds[0].ndata\n    assert F.array_equal(\n        new_ds.train_idx, F.nonzero_1d(new_ds[0].ndata[\"train_mask\"])\n    )\n    assert F.array_equal(\n        new_ds.val_idx, F.nonzero_1d(new_ds[0].ndata[\"val_mask\"])\n    )\n    assert F.array_equal(\n        new_ds.test_idx, F.nonzero_1d(new_ds[0].ndata[\"test_mask\"])\n    )\n\n    ds = data.AIFBDataset()\n    print(\"train_mask\" in 
ds[0].nodes[\"Personen\"].data)\n    new_ds = data.AsNodePredDataset(\n        ds, [0.8, 0.1, 0.1], \"Personen\", verbose=True\n    )\n    assert len(new_ds) == 1\n    assert new_ds[0].ntypes == ds[0].ntypes\n    assert new_ds[0].canonical_etypes == ds[0].canonical_etypes\n    assert \"train_mask\" in new_ds[0].nodes[\"Personen\"].data\n    assert F.array_equal(\n        new_ds.train_idx,\n        F.nonzero_1d(new_ds[0].nodes[\"Personen\"].data[\"train_mask\"]),\n    )\n    assert F.array_equal(\n        new_ds.val_idx,\n        F.nonzero_1d(new_ds[0].nodes[\"Personen\"].data[\"val_mask\"]),\n    )\n    assert F.array_equal(\n        new_ds.test_idx,\n        F.nonzero_1d(new_ds[0].nodes[\"Personen\"].data[\"test_mask\"]),\n    )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_as_nodepred2():\n    # test proper reprocessing\n\n    # create\n    ds = data.AsNodePredDataset(\n        data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1]\n    )\n    assert F.sum(F.astype(ds[0].ndata[\"train_mask\"], F.int32), 0) == int(\n        ds[0].num_nodes() * 0.8\n    )\n    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)\n    # read from cache\n    ds = data.AsNodePredDataset(\n        data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1]\n    )\n    assert F.sum(F.astype(ds[0].ndata[\"train_mask\"], F.int32), 0) == int(\n        ds[0].num_nodes() * 0.8\n    )\n    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)\n    # invalid cache, re-read\n    ds = data.AsNodePredDataset(\n        data.AmazonCoBuyComputerDataset(), [0.1, 0.1, 0.8]\n    )\n    assert F.sum(F.astype(ds[0].ndata[\"train_mask\"], F.int32), 0) == int(\n        ds[0].num_nodes() * 0.1\n    )\n    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.1)\n\n    # create\n    ds = data.AsNodePredDataset(\n        data.AIFBDataset(), [0.8, 
0.1, 0.1], \"Personen\", verbose=True\n    )\n    assert F.sum(\n        F.astype(ds[0].nodes[\"Personen\"].data[\"train_mask\"], F.int32), 0\n    ) == int(ds[0].num_nodes(\"Personen\") * 0.8)\n    assert len(ds.train_idx) == int(ds[0].num_nodes(\"Personen\") * 0.8)\n    # read from cache\n    ds = data.AsNodePredDataset(\n        data.AIFBDataset(), [0.8, 0.1, 0.1], \"Personen\", verbose=True\n    )\n    assert F.sum(\n        F.astype(ds[0].nodes[\"Personen\"].data[\"train_mask\"], F.int32), 0\n    ) == int(ds[0].num_nodes(\"Personen\") * 0.8)\n    assert len(ds.train_idx) == int(ds[0].num_nodes(\"Personen\") * 0.8)\n    # invalid cache, re-read\n    ds = data.AsNodePredDataset(\n        data.AIFBDataset(), [0.1, 0.1, 0.8], \"Personen\", verbose=True\n    )\n    assert F.sum(\n        F.astype(ds[0].nodes[\"Personen\"].data[\"train_mask\"], F.int32), 0\n    ) == int(ds[0].num_nodes(\"Personen\") * 0.1)\n    assert len(ds.train_idx) == int(ds[0].num_nodes(\"Personen\") * 0.1)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_as_linkpred():\n    # create\n    ds = data.AsLinkPredDataset(\n        data.CoraGraphDataset(),\n        split_ratio=[0.8, 0.1, 0.1],\n        neg_ratio=1,\n        verbose=True,\n    )\n    # Cora has 10556 edges, 10% test edges can be 1057\n    assert ds.test_edges[0][0].shape[0] == 1057\n    # negative samples, not guaranteed, so the assert is in a relaxed range\n    assert 1000 <= ds.test_edges[1][0].shape[0] <= 1057\n    # read from cache\n    ds = data.AsLinkPredDataset(\n        data.CoraGraphDataset(),\n        split_ratio=[0.7, 0.1, 0.2],\n        neg_ratio=2,\n        verbose=True,\n    )\n    assert ds.test_edges[0][0].shape[0] == 2112\n    # negative samples, not guaranteed to be ratio 2, so the assert is in a relaxed range\n    assert 4000 < 
ds.test_edges[1][0].shape[0] <= 4224\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\", reason=\"Skip Tensorflow\"\n)\ndef test_as_nodepred_csvdataset():\n    with tempfile.TemporaryDirectory() as test_dir:\n        # generate YAML/CSVs\n        meta_yaml_path = os.path.join(test_dir, \"meta.yaml\")\n        edges_csv_path = os.path.join(test_dir, \"test_edges.csv\")\n        nodes_csv_path = os.path.join(test_dir, \"test_nodes.csv\")\n        meta_yaml_data = {\n            \"version\": \"1.0.0\",\n            \"dataset_name\": \"default_name\",\n            \"node_data\": [{\"file_name\": os.path.basename(nodes_csv_path)}],\n            \"edge_data\": [{\"file_name\": os.path.basename(edges_csv_path)}],\n        }\n        with open(meta_yaml_path, \"w\") as f:\n            yaml.dump(meta_yaml_data, f, sort_keys=False)\n        num_nodes = 100\n        num_edges = 500\n        num_dims = 3\n        num_classes = num_nodes\n        feat_ndata = np.random.rand(num_nodes, num_dims)\n        label_ndata = np.arange(num_classes)\n        df = pd.DataFrame(\n            {\n                \"node_id\": np.arange(num_nodes),\n                \"label\": label_ndata,\n                \"feat\": [line.tolist() for line in feat_ndata],\n            }\n        )\n        df.to_csv(nodes_csv_path, index=False)\n        df = pd.DataFrame(\n            {\n                \"src_id\": np.random.randint(num_nodes, size=num_edges),\n                \"dst_id\": np.random.randint(num_nodes, size=num_edges),\n            }\n        )\n        df.to_csv(edges_csv_path, index=False)\n\n        ds = data.CSVDataset(test_dir, force_reload=True)\n        assert \"feat\" in ds[0].ndata\n        assert \"label\" in ds[0].ndata\n        assert \"train_mask\" not in 
ds[0].ndata\n        assert not hasattr(ds[0], \"num_classes\")\n        new_ds = data.AsNodePredDataset(\n            ds, split_ratio=[0.8, 0.1, 0.1], force_reload=True\n        )\n        assert new_ds.num_classes == num_classes\n        assert \"feat\" in new_ds[0].ndata\n        assert \"label\" in new_ds[0].ndata\n        assert \"train_mask\" in new_ds[0].ndata\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_as_graphpred_reprocess():\n    ds = data.AsGraphPredDataset(\n        data.GINDataset(name=\"MUTAG\", self_loop=True), [0.8, 0.1, 0.1]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # read from cache\n    ds = data.AsGraphPredDataset(\n        data.GINDataset(name=\"MUTAG\", self_loop=True), [0.8, 0.1, 0.1]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # invalid cache, re-read\n    ds = data.AsGraphPredDataset(\n        data.GINDataset(name=\"MUTAG\", self_loop=True), [0.1, 0.1, 0.8]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.1)\n\n    ds = data.AsGraphPredDataset(\n        data.FakeNewsDataset(\"politifact\", \"profile\"), [0.8, 0.1, 0.1]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # read from cache\n    ds = data.AsGraphPredDataset(\n        data.FakeNewsDataset(\"politifact\", \"profile\"), [0.8, 0.1, 0.1]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # invalid cache, re-read\n    ds = data.AsGraphPredDataset(\n        data.FakeNewsDataset(\"politifact\", \"profile\"), [0.1, 0.1, 0.8]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.1)\n\n    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.8, 0.1, 0.1])\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # read from cache\n    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.8, 0.1, 0.1])\n    assert len(ds.train_idx) == int(len(ds) * 
0.8)\n    # invalid cache, re-read\n    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.1, 0.1, 0.8])\n    assert len(ds.train_idx) == int(len(ds) * 0.1)\n\n    ds = data.AsGraphPredDataset(\n        data.QM9Dataset(label_keys=[\"mu\", \"gap\"]), [0.8, 0.1, 0.1]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # read from cache\n    ds = data.AsGraphPredDataset(\n        data.QM9Dataset(label_keys=[\"mu\", \"gap\"]), [0.8, 0.1, 0.1]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # invalid cache, re-read\n    ds = data.AsGraphPredDataset(\n        data.QM9Dataset(label_keys=[\"mu\", \"gap\"]), [0.1, 0.1, 0.8]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.1)\n\n    ds = data.AsGraphPredDataset(\n        data.QM9EdgeDataset(label_keys=[\"mu\", \"alpha\"]), [0.8, 0.1, 0.1]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # read from cache\n    ds = data.AsGraphPredDataset(\n        data.QM9EdgeDataset(label_keys=[\"mu\", \"alpha\"]), [0.8, 0.1, 0.1]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # invalid cache, re-read\n    ds = data.AsGraphPredDataset(\n        data.QM9EdgeDataset(label_keys=[\"mu\", \"alpha\"]), [0.1, 0.1, 0.8]\n    )\n    assert len(ds.train_idx) == int(len(ds) * 0.1)\n\n    ds = data.AsGraphPredDataset(data.TUDataset(\"DD\"), [0.8, 0.1, 0.1])\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # read from cache\n    ds = data.AsGraphPredDataset(data.TUDataset(\"DD\"), [0.8, 0.1, 0.1])\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # invalid cache, re-read\n    ds = data.AsGraphPredDataset(data.TUDataset(\"DD\"), [0.1, 0.1, 0.8])\n    assert len(ds.train_idx) == int(len(ds) * 0.1)\n\n    ds = data.AsGraphPredDataset(data.LegacyTUDataset(\"DD\"), [0.8, 0.1, 0.1])\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # read from cache\n    ds = data.AsGraphPredDataset(data.LegacyTUDataset(\"DD\"), [0.8, 0.1, 0.1])\n    assert len(ds.train_idx) == 
int(len(ds) * 0.8)\n    # invalid cache, re-read\n    ds = data.AsGraphPredDataset(data.LegacyTUDataset(\"DD\"), [0.1, 0.1, 0.8])\n    assert len(ds.train_idx) == int(len(ds) * 0.1)\n\n    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.8, 0.1, 0.1])\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # read from cache\n    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.8, 0.1, 0.1])\n    assert len(ds.train_idx) == int(len(ds) * 0.8)\n    # invalid cache, re-read\n    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.1, 0.1, 0.8])\n    assert len(ds.train_idx) == int(len(ds) * 0.1)\n\n\nif __name__ == \"__main__\":\n    test_minigc()\n    test_gin()\n    test_data_hash()\n    test_tudataset_regression()\n    test_fraud()\n    test_fakenews()\n    test_csvdataset()\n    test_as_nodepred1()\n    test_as_nodepred2()\n    test_as_nodepred_csvdataset()\n"
  },
  {
    "path": "tests/python/common/data/test_geom_gcn.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_chameleon():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = dgl.data.ChameleonDataset(force_reload=True)[0]\n    assert g.num_nodes() == 2277\n    assert g.num_edges() == 36101\n    g2 = dgl.data.ChameleonDataset(force_reload=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_squirrel():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = dgl.data.SquirrelDataset(force_reload=True)[0]\n    assert g.num_nodes() == 5201\n    assert g.num_edges() == 217073\n    g2 = dgl.data.SquirrelDataset(force_reload=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_cornell():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = dgl.data.CornellDataset(force_reload=True)[0]\n    assert g.num_nodes() == 183\n    assert g.num_edges() == 298\n    g2 = dgl.data.CornellDataset(force_reload=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef 
test_texas():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = dgl.data.TexasDataset(force_reload=True)[0]\n    assert g.num_nodes() == 183\n    assert g.num_edges() == 325\n    g2 = dgl.data.TexasDataset(force_reload=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_wisconsin():\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n\n    g = dgl.data.WisconsinDataset(force_reload=True)[0]\n    assert g.num_nodes() == 251\n    assert g.num_edges() == 515\n    g2 = dgl.data.WisconsinDataset(force_reload=True, transform=transform)[0]\n    assert g2.num_edges() - g.num_edges() == g.num_nodes()\n"
  },
  {
    "path": "tests/python/common/data/test_movielens.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nfrom dgl.data.movielens import MovieLensDataset\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"only supports pytorch\"\n)\ndef test_movielens():\n    transform = dgl.AddSelfLoop(new_etypes=True)\n\n    movielens = MovieLensDataset(name=\"ml-100k\", valid_ratio=0.2, verbose=True)\n    g = movielens[0]\n    assert g.num_edges(\"user-movie\") == g.num_edges(\"movie-user\") == 100000\n    assert (\n        g.nodes[\"user\"].data[\"feat\"].shape[1]\n        == g.nodes[\"user\"].data[\"feat\"].shape[1]\n        == g.nodes[\"user\"].data[\"feat\"].shape[1]\n        == 23\n    )\n    assert (\n        g.nodes[\"movie\"].data[\"feat\"].shape[1]\n        == g.nodes[\"movie\"].data[\"feat\"].shape[1]\n        == g.nodes[\"movie\"].data[\"feat\"].shape[1]\n        == 320\n    )\n\n    movielens = MovieLensDataset(\n        name=\"ml-100k\", valid_ratio=0.2, transform=transform, verbose=True\n    )\n    g1 = movielens[0]\n    assert g1.num_edges() - g.num_edges() == g.num_nodes()\n    assert g1.num_edges() - g.num_edges() == g.num_nodes()\n    assert g1.num_edges() - g.num_edges() == g.num_nodes()\n\n    movielens = MovieLensDataset(\n        name=\"ml-1m\", valid_ratio=0.2, test_ratio=0.1, verbose=True\n    )\n    g = movielens[0]\n    assert g.num_edges(\"user-movie\") == g.num_edges(\"movie-user\") == 1000209\n\n    movielens = MovieLensDataset(\n        name=\"ml-10m\", valid_ratio=0.2, test_ratio=0.1, verbose=True\n    )\n    g = movielens[0]\n    assert g.num_edges(\"user-movie\") == g.num_edges(\"movie-user\") == 10000054\n"
  },
  {
    "path": "tests/python/common/data/test_serialize.py",
    "content": "import os\nimport tempfile\nimport time\nimport unittest\nimport warnings\n\nimport backend as F\n\nimport dgl\nimport dgl.ndarray as nd\nimport numpy as np\nimport pytest\nimport scipy as sp\nfrom dgl.data.utils import load_labels, load_tensors, save_tensors\n\nnp.random.seed(44)\n\n\ndef generate_rand_graph(n):\n    arr = (sp.sparse.random(n, n, density=0.1, format=\"coo\") != 0).astype(\n        np.int64\n    )\n    return dgl.from_scipy(arr)\n\n\ndef construct_graph(n):\n    g_list = []\n    for _ in range(n):\n        g = generate_rand_graph(30)\n        g.edata[\"e1\"] = F.randn((g.num_edges(), 32))\n        g.edata[\"e2\"] = F.ones((g.num_edges(), 32))\n        g.ndata[\"n1\"] = F.randn((g.num_nodes(), 64))\n        g_list.append(g)\n    return g_list\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_graph_serialize_with_feature():\n    num_graphs = 100\n\n    t0 = time.time()\n\n    g_list = construct_graph(num_graphs)\n\n    t1 = time.time()\n\n    # create a temporary file and immediately release it so DGL can open it.\n    f = tempfile.NamedTemporaryFile(delete=False)\n    path = f.name\n    f.close()\n\n    dgl.save_graphs(path, g_list)\n\n    t2 = time.time()\n    idx_list = np.random.permutation(np.arange(num_graphs)).tolist()\n    loadg_list, _ = dgl.load_graphs(path, idx_list)\n\n    t3 = time.time()\n    idx = idx_list[0]\n    load_g = loadg_list[0]\n    print(\"Save time: {} s\".format(t2 - t1))\n    print(\"Load time: {} s\".format(t3 - t2))\n    print(\"Graph Construction time: {} s\".format(t1 - t0))\n\n    assert F.allclose(load_g.nodes(), g_list[idx].nodes())\n\n    load_edges = load_g.all_edges(\"uv\", \"eid\")\n    g_edges = g_list[idx].all_edges(\"uv\", \"eid\")\n    assert F.allclose(load_edges[0], g_edges[0])\n    assert F.allclose(load_edges[1], g_edges[1])\n    assert F.allclose(load_g.edata[\"e1\"], g_list[idx].edata[\"e1\"])\n    assert F.allclose(load_g.edata[\"e2\"], 
g_list[idx].edata[\"e2\"])\n    assert F.allclose(load_g.ndata[\"n1\"], g_list[idx].ndata[\"n1\"])\n\n    os.unlink(path)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_graph_serialize_without_feature():\n    num_graphs = 100\n    g_list = [generate_rand_graph(30) for _ in range(num_graphs)]\n\n    # create a temporary file and immediately release it so DGL can open it.\n    f = tempfile.NamedTemporaryFile(delete=False)\n    path = f.name\n    f.close()\n\n    dgl.save_graphs(path, g_list)\n\n    idx_list = np.random.permutation(np.arange(num_graphs)).tolist()\n    loadg_list, _ = dgl.load_graphs(path, idx_list)\n\n    idx = idx_list[0]\n    load_g = loadg_list[0]\n\n    assert F.allclose(load_g.nodes(), g_list[idx].nodes())\n\n    load_edges = load_g.all_edges(\"uv\", \"eid\")\n    g_edges = g_list[idx].all_edges(\"uv\", \"eid\")\n    assert F.allclose(load_edges[0], g_edges[0])\n    assert F.allclose(load_edges[1], g_edges[1])\n\n    os.unlink(path)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_graph_serialize_with_labels():\n    num_graphs = 100\n    g_list = [generate_rand_graph(30) for _ in range(num_graphs)]\n    labels = {\"label\": F.zeros((num_graphs, 1))}\n\n    # create a temporary file and immediately release it so DGL can open it.\n    f = tempfile.NamedTemporaryFile(delete=False)\n    path = f.name\n    f.close()\n\n    dgl.save_graphs(path, g_list, labels)\n\n    idx_list = np.random.permutation(np.arange(num_graphs)).tolist()\n    loadg_list, l_labels0 = dgl.load_graphs(path, idx_list)\n    l_labels = load_labels(path)\n    assert F.allclose(l_labels[\"label\"], labels[\"label\"])\n    assert F.allclose(l_labels0[\"label\"], labels[\"label\"])\n\n    idx = idx_list[0]\n    load_g = loadg_list[0]\n\n    assert F.allclose(load_g.nodes(), g_list[idx].nodes())\n\n    load_edges = load_g.all_edges(\"uv\", \"eid\")\n    g_edges = 
g_list[idx].all_edges(\"uv\", \"eid\")\n    assert F.allclose(load_edges[0], g_edges[0])\n    assert F.allclose(load_edges[1], g_edges[1])\n\n    os.unlink(path)\n\n\ndef test_serialize_tensors():\n    # create a temporary file and immediately release it so DGL can open it.\n    f = tempfile.NamedTemporaryFile(delete=False)\n    path = f.name\n    f.close()\n\n    tensor_dict = {\n        \"a\": F.tensor([1, 3, -1, 0], dtype=F.int64),\n        \"1@1\": F.tensor([1.5, 2], dtype=F.float32),\n    }\n\n    save_tensors(path, tensor_dict)\n\n    load_tensor_dict = load_tensors(path)\n\n    for key in tensor_dict:\n        assert key in load_tensor_dict\n        assert np.array_equal(\n            F.asnumpy(load_tensor_dict[key]), F.asnumpy(tensor_dict[key])\n        )\n\n    load_nd_dict = load_tensors(path, return_dgl_ndarray=True)\n\n    for key in tensor_dict:\n        assert key in load_nd_dict\n        assert isinstance(load_nd_dict[key], nd.NDArray)\n        assert np.array_equal(\n            load_nd_dict[key].asnumpy(), F.asnumpy(tensor_dict[key])\n        )\n\n    os.unlink(path)\n\n\ndef test_serialize_empty_dict():\n    # create a temporary file and immediately release it so DGL can open it.\n    f = tempfile.NamedTemporaryFile(delete=False)\n    path = f.name\n    f.close()\n\n    tensor_dict = {}\n\n    save_tensors(path, tensor_dict)\n\n    load_tensor_dict = load_tensors(path)\n    assert isinstance(load_tensor_dict, dict)\n    assert len(load_tensor_dict) == 0\n\n    os.unlink(path)\n\n\ndef load_old_files(files):\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        return dgl.load_graphs(os.path.join(os.path.dirname(__file__), files))\n\n\ndef test_load_old_files1():\n    loadg_list, _ = load_old_files(\"data/1.bin\")\n    idx, num_nodes, edge0, edge1, edata_e1, edata_e2, ndata_n1 = np.load(\n        os.path.join(os.path.dirname(__file__), \"data/1.npy\"), allow_pickle=True\n    )\n\n    load_g = 
loadg_list[idx]\n    load_edges = load_g.all_edges(\"uv\", \"eid\")\n\n    assert np.allclose(F.asnumpy(load_edges[0]), edge0)\n    assert np.allclose(F.asnumpy(load_edges[1]), edge1)\n    assert np.allclose(F.asnumpy(load_g.edata[\"e1\"]), edata_e1)\n    assert np.allclose(F.asnumpy(load_g.edata[\"e2\"]), edata_e2)\n    assert np.allclose(F.asnumpy(load_g.ndata[\"n1\"]), ndata_n1)\n\n\ndef test_load_old_files2():\n    loadg_list, labels0 = load_old_files(\"data/2.bin\")\n    labels1 = load_labels(os.path.join(os.path.dirname(__file__), \"data/2.bin\"))\n    idx, edges0, edges1, np_labels = np.load(\n        os.path.join(os.path.dirname(__file__), \"data/2.npy\"), allow_pickle=True\n    )\n    assert np.allclose(F.asnumpy(labels0[\"label\"]), np_labels)\n    assert np.allclose(F.asnumpy(labels1[\"label\"]), np_labels)\n\n    load_g = loadg_list[idx]\n    print(load_g)\n    load_edges = load_g.all_edges(\"uv\", \"eid\")\n    assert np.allclose(F.asnumpy(load_edges[0]), edges0)\n    assert np.allclose(F.asnumpy(load_edges[1]), edges1)\n\n\ndef create_heterographs(idtype):\n    g_x = dgl.heterograph(\n        {(\"user\", \"follows\", \"user\"): ([0, 1, 2], [1, 2, 3])}, idtype=idtype\n    )\n    g_y = dgl.heterograph(\n        {(\"user\", \"knows\", \"user\"): ([0, 2], [2, 3])}, idtype=idtype\n    ).formats(\"csr\")\n    g_x.ndata[\"h\"] = F.randn((4, 3))\n    g_x.edata[\"w\"] = F.randn((3, 2))\n    g_y.ndata[\"hh\"] = F.ones((4, 5))\n    g_y.edata[\"ww\"] = F.randn((2, 10))\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1, 2], [1, 2, 3]),\n            (\"user\", \"knows\", \"user\"): ([0, 2], [2, 3]),\n        },\n        idtype=idtype,\n    )\n    g.nodes[\"user\"].data[\"h\"] = g_x.ndata[\"h\"]\n    g.nodes[\"user\"].data[\"hh\"] = g_y.ndata[\"hh\"]\n    g.edges[\"follows\"].data[\"w\"] = g_x.edata[\"w\"]\n    g.edges[\"knows\"].data[\"ww\"] = g_y.edata[\"ww\"]\n    return [g, g_x, g_y]\n\n\ndef 
create_heterographs2(idtype):\n    g_x = dgl.heterograph(\n        {(\"user\", \"follows\", \"user\"): ([0, 1, 2], [1, 2, 3])}, idtype=idtype\n    )\n    g_y = dgl.heterograph(\n        {(\"user\", \"knows\", \"user\"): ([0, 2], [2, 3])}, idtype=idtype\n    ).formats(\"csr\")\n    g_z = dgl.heterograph(\n        {(\"user\", \"knows\", \"knowledge\"): ([0, 1, 3], [2, 3, 4])}, idtype=idtype\n    )\n    g_x.ndata[\"h\"] = F.randn((4, 3))\n    g_x.edata[\"w\"] = F.randn((3, 2))\n    g_y.ndata[\"hh\"] = F.ones((4, 5))\n    g_y.edata[\"ww\"] = F.randn((2, 10))\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1, 2], [1, 2, 3]),\n            (\"user\", \"knows\", \"user\"): ([0, 2], [2, 3]),\n            (\"user\", \"knows\", \"knowledge\"): ([0, 1, 3], [2, 3, 4]),\n        },\n        idtype=idtype,\n    )\n    g.nodes[\"user\"].data[\"h\"] = g_x.ndata[\"h\"]\n    g.edges[\"follows\"].data[\"w\"] = g_x.edata[\"w\"]\n    g.nodes[\"user\"].data[\"hh\"] = g_y.ndata[\"hh\"]\n    g.edges[(\"user\", \"knows\", \"user\")].data[\"ww\"] = g_y.edata[\"ww\"]\n    return [g, g_x, g_y, g_z]\n\n\ndef test_deserialize_old_heterograph_file():\n    path = os.path.join(os.path.dirname(__file__), \"data/hetero1.bin\")\n    g_list, label_dict = dgl.load_graphs(path)\n    assert g_list[0].idtype == F.int64\n    assert g_list[3].idtype == F.int32\n    assert np.allclose(\n        F.asnumpy(g_list[2].nodes[\"user\"].data[\"hh\"]), np.ones((4, 5))\n    )\n    assert np.allclose(\n        F.asnumpy(g_list[5].nodes[\"user\"].data[\"hh\"]), np.ones((4, 5))\n    )\n    edges = g_list[0][\"follows\"].edges()\n    assert np.allclose(F.asnumpy(edges[0]), np.array([0, 1, 2]))\n    assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3]))\n    assert F.allclose(label_dict[\"graph_label\"], F.ones(54))\n\n\ndef create_old_heterograph_files():\n    path = os.path.join(os.path.dirname(__file__), \"data/hetero1.bin\")\n    g_list0 = create_heterographs(F.int64) + 
create_heterographs(F.int32)\n    labels_dict = {\"graph_label\": F.ones(54)}\n    dgl.save_graphs(path, g_list0, labels_dict)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_serialize_heterograph():\n    f = tempfile.NamedTemporaryFile(delete=False)\n    path = f.name\n    f.close()\n    g_list0 = create_heterographs2(F.int64) + create_heterographs2(F.int32)\n    dgl.save_graphs(path, g_list0)\n\n    g_list, _ = dgl.load_graphs(path)\n    assert g_list[0].idtype == F.int64\n    assert len(g_list[0].canonical_etypes) == 3\n    for i in range(len(g_list0)):\n        for j, etypes in enumerate(g_list0[i].canonical_etypes):\n            assert g_list[i].canonical_etypes[j] == etypes\n    # assert g_list[1].restrict_format() == 'any'\n    # assert g_list[2].restrict_format() == 'csr'\n\n    assert g_list[4].idtype == F.int32\n    assert np.allclose(\n        F.asnumpy(g_list[2].nodes[\"user\"].data[\"hh\"]), np.ones((4, 5))\n    )\n    assert np.allclose(\n        F.asnumpy(g_list[6].nodes[\"user\"].data[\"hh\"]), np.ones((4, 5))\n    )\n    edges = g_list[0][\"follows\"].edges()\n    assert np.allclose(F.asnumpy(edges[0]), np.array([0, 1, 2]))\n    assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3]))\n    for i in range(len(g_list)):\n        assert g_list[i].ntypes == g_list0[i].ntypes\n        assert g_list[i].etypes == g_list0[i].etypes\n\n    # test set feature after load_graph\n    g_list[3].nodes[\"user\"].data[\"test\"] = F.tensor([0, 1, 2, 4])\n    g_list[3].edata[\"test\"] = F.tensor([0, 1, 2])\n\n    os.unlink(path)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\n@pytest.mark.skip(reason=\"lack of permission on CI\")\ndef test_serialize_heterograph_s3():\n    path = \"s3://dglci-data-test/graph2.bin\"\n    g_list0 = create_heterographs(F.int64) + create_heterographs(F.int32)\n    dgl.save_graphs(path, g_list0)\n\n    g_list = dgl.load_graphs(path, [0, 2, 
5])\n    assert g_list[0].idtype == F.int64\n    # assert g_list[1].restrict_format() == 'csr'\n    assert np.allclose(\n        F.asnumpy(g_list[1].nodes[\"user\"].data[\"hh\"]), np.ones((4, 5))\n    )\n    assert np.allclose(\n        F.asnumpy(g_list[2].nodes[\"user\"].data[\"hh\"]), np.ones((4, 5))\n    )\n    edges = g_list[0][\"follows\"].edges()\n    assert np.allclose(F.asnumpy(edges[0]), np.array([0, 1, 2]))\n    assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3]))\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\n@pytest.mark.parametrize(\n    \"formats\",\n    [\n        \"coo\",\n        \"csr\",\n        \"csc\",\n        [\"coo\", \"csc\"],\n        [\"coo\", \"csr\"],\n        [\"csc\", \"csr\"],\n        [\"coo\", \"csr\", \"csc\"],\n    ],\n)\ndef test_graph_serialize_with_formats(formats):\n    num_graphs = 100\n    g_list = [generate_rand_graph(30) for _ in range(num_graphs)]\n\n    # create a temporary file and immediately release it so DGL can open it.\n    f = tempfile.NamedTemporaryFile(delete=False)\n    path = f.name\n    f.close()\n\n    dgl.save_graphs(path, g_list, formats=formats)\n\n    idx_list = np.random.permutation(np.arange(num_graphs)).tolist()\n    loadg_list, _ = dgl.load_graphs(path, idx_list)\n\n    idx = idx_list[0]\n    load_g = loadg_list[0]\n    g_formats = load_g.formats()\n\n    # verify formats\n    if not isinstance(formats, list):\n        formats = [formats]\n    for fmt in formats:\n        assert fmt in g_formats[\"created\"]\n\n    assert F.allclose(load_g.nodes(), g_list[idx].nodes())\n\n    load_edges = load_g.all_edges(\"uv\", \"eid\")\n    g_edges = g_list[idx].all_edges(\"uv\", \"eid\")\n    assert F.allclose(load_edges[0], g_edges[0])\n    assert F.allclose(load_edges[1], g_edges[1])\n\n    os.unlink(path)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_graph_serialize_with_restricted_formats():\n    g = 
dgl.rand_graph(100, 200)\n    g = g.formats([\"coo\"])\n    g_list = [g]\n\n    # create a temporary file and immediately release it so DGL can open it.\n    f = tempfile.NamedTemporaryFile(delete=False)\n    path = f.name\n    f.close()\n\n    expect_except = False\n    try:\n        dgl.save_graphs(path, g_list, formats=[\"csr\"])\n    except:\n        expect_except = True\n    assert expect_except\n\n    os.unlink(path)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_deserialize_old_graph():\n    num_nodes = 100\n    num_edges = 200\n    path = os.path.join(os.path.dirname(__file__), \"data/graph_0.9a220622.dgl\")\n    g_list, _ = dgl.load_graphs(path)\n    g = g_list[0]\n    assert \"coo\" in g.formats()[\"created\"]\n    assert \"csr\" in g.formats()[\"not created\"]\n    assert \"csc\" in g.formats()[\"not created\"]\n    assert num_nodes == g.num_nodes()\n    assert num_edges == g.num_edges()\n"
  },
  {
    "path": "tests/python/common/data/test_utils.py",
    "content": "import gzip\nimport io\nimport os\nimport tarfile\nimport tempfile\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport dgl.data as data\nimport numpy as np\nimport pandas as pd\nimport pytest\nimport yaml\nfrom dgl import DGLError\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_add_nodepred_split():\n    dataset = data.AmazonCoBuyComputerDataset()\n    print(\"train_mask\" in dataset[0].ndata)\n    data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])\n    assert \"train_mask\" in dataset[0].ndata\n\n    dataset = data.AIFBDataset()\n    print(\"train_mask\" in dataset[0].nodes[\"Publikationen\"].data)\n    data.utils.add_nodepred_split(\n        dataset, [0.8, 0.1, 0.1], ntype=\"Publikationen\"\n    )\n    assert \"train_mask\" in dataset[0].nodes[\"Publikationen\"].data\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_extract_archive():\n    # gzip\n    with tempfile.TemporaryDirectory() as src_dir:\n        gz_file = \"gz_archive\"\n        gz_path = os.path.join(src_dir, gz_file + \".gz\")\n        content = b\"test extract archive gzip\"\n        with gzip.open(gz_path, \"wb\") as f:\n            f.write(content)\n        with tempfile.TemporaryDirectory() as dst_dir:\n            data.utils.extract_archive(gz_path, dst_dir, overwrite=True)\n            assert os.path.exists(os.path.join(dst_dir, gz_file))\n\n    # tar\n    with tempfile.TemporaryDirectory() as src_dir:\n        tar_file = \"tar_archive\"\n        tar_path = os.path.join(src_dir, tar_file + \".tar\")\n        # default encode to utf8\n        content = \"test extract archive tar\\n\".encode()\n        info = 
tarfile.TarInfo(name=\"tar_archive\")\n        info.size = len(content)\n        with tarfile.open(tar_path, \"w\") as f:\n            f.addfile(info, io.BytesIO(content))\n        with tempfile.TemporaryDirectory() as dst_dir:\n            data.utils.extract_archive(tar_path, dst_dir, overwrite=True)\n            assert os.path.exists(os.path.join(dst_dir, tar_file))\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_mask_nodes_by_property():\n    num_nodes = 1000\n    property_values = np.random.uniform(size=num_nodes)\n    part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]\n    split_masks = data.utils.mask_nodes_by_property(\n        property_values, part_ratios\n    )\n    assert \"in_valid_mask\" in split_masks\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Datasets don't need to be tested on GPU.\",\n)\n@unittest.skipIf(dgl.backend.backend_name == \"mxnet\", reason=\"Skip MXNet\")\ndef test_add_node_property_split():\n    dataset = data.AmazonCoBuyComputerDataset()\n    part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]\n    for property_name in [\"popularity\", \"locality\", \"density\"]:\n        data.utils.add_node_property_split(dataset, part_ratios, property_name)\n        assert \"in_valid_mask\" in dataset[0].ndata\n\n\nif __name__ == \"__main__\":\n    test_extract_archive()\n    test_add_nodepred_split()\n    test_mask_nodes_by_property()\n    test_add_node_property_split()\n"
  },
  {
    "path": "tests/python/common/dataloading/test_dataloader.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nfrom dgl.dataloading import (\n    as_edge_prediction_sampler,\n    negative_sampler,\n    NeighborSampler,\n)\nfrom utils import parametrize_idtype\n\n\ndef create_test_graph(idtype):\n    # test heterograph from the docstring, plus a user -- wishes -- game relation\n    # 3 users, 2 games, 2 developers\n    # metagraph:\n    #    ('user', 'follows', 'user'),\n    #    ('user', 'plays', 'game'),\n    #    ('user', 'wishes', 'game'),\n    #    ('developer', 'develops', 'game')])\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): ([0, 2], [1, 0]),\n            (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    return g\n\n\n@parametrize_idtype\ndef test_edge_prediction_sampler(idtype):\n    g = create_test_graph(idtype)\n    sampler = NeighborSampler([10, 10])\n    sampler = as_edge_prediction_sampler(\n        sampler, negative_sampler=negative_sampler.Uniform(1)\n    )\n\n    seeds = F.copy_to(F.arange(0, 2, dtype=idtype), ctx=F.ctx())\n    # just a smoke test to make sure we don't fail internal assertions\n    result = sampler.sample(g, {\"follows\": seeds})\n\n\nif __name__ == \"__main__\":\n    test_edge_prediction_sampler()\n"
  },
  {
    "path": "tests/python/common/function/test_basics.py",
    "content": "import warnings\nfrom collections import defaultdict as ddict\n\nimport backend as F\n\nimport dgl\nimport networkx as nx\nimport numpy as np\nfrom utils import parametrize_idtype\n\nD = 5\nreduce_msg_shapes = set()\n\n\ndef message_func(edges):\n    assert F.ndim(edges.src[\"h\"]) == 2\n    assert F.shape(edges.src[\"h\"])[1] == D\n    return {\"m\": edges.src[\"h\"]}\n\n\ndef reduce_func(nodes):\n    msgs = nodes.mailbox[\"m\"]\n    reduce_msg_shapes.add(tuple(msgs.shape))\n    assert F.ndim(msgs) == 3\n    assert F.shape(msgs)[2] == D\n    return {\"accum\": F.sum(msgs, 1)}\n\n\ndef apply_node_func(nodes):\n    return {\"h\": nodes.data[\"h\"] + nodes.data[\"accum\"]}\n\n\ndef generate_graph_old(grad=False):\n    g = dgl.graph([])\n    g.add_nodes(10)  # 10 nodes\n    # create a graph where 0 is the source and 9 is the sink\n    # 17 edges\n    for i in range(1, 9):\n        g.add_edges(0, i)\n        g.add_edges(i, 9)\n    # add a back flow from 9 to 0\n    g.add_edges(9, 0)\n    g = g.to(F.ctx())\n    ncol = F.randn((10, D))\n    ecol = F.randn((17, D))\n    if grad:\n        ncol = F.attach_grad(ncol)\n        ecol = F.attach_grad(ecol)\n\n    g.ndata[\"h\"] = ncol\n    g.edata[\"w\"] = ecol\n    g.set_n_initializer(dgl.init.zero_initializer)\n    g.set_e_initializer(dgl.init.zero_initializer)\n    return g\n\n\ndef generate_graph(idtype, grad=False):\n    \"\"\"\n    s, d, eid\n    0, 1, 0\n    1, 9, 1\n    0, 2, 2\n    2, 9, 3\n    0, 3, 4\n    3, 9, 5\n    0, 4, 6\n    4, 9, 7\n    0, 5, 8\n    5, 9, 9\n    0, 6, 10\n    6, 9, 11\n    0, 7, 12\n    7, 9, 13\n    0, 8, 14\n    8, 9, 15\n    9, 0, 16\n    \"\"\"\n    u = F.tensor([0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 9])\n    v = F.tensor([1, 9, 2, 9, 3, 9, 4, 9, 5, 9, 6, 9, 7, 9, 8, 9, 0])\n    g = dgl.graph((u, v), idtype=idtype)\n    assert g.device == F.ctx()\n    ncol = F.randn((10, D))\n    ecol = F.randn((17, D))\n    if grad:\n        ncol = F.attach_grad(ncol)\n        
ecol = F.attach_grad(ecol)\n\n    g.ndata[\"h\"] = ncol\n    g.edata[\"w\"] = ecol\n    g.set_n_initializer(dgl.init.zero_initializer)\n    g.set_e_initializer(dgl.init.zero_initializer)\n    return g\n\n\ndef test_compatible():\n    g = generate_graph_old()\n\n\n@parametrize_idtype\ndef test_batch_setter_getter(idtype):\n    def _pfc(x):\n        return list(F.zerocopy_to_numpy(x)[:, 0])\n\n    g = generate_graph(idtype)\n    # set all nodes\n    g.ndata[\"h\"] = F.zeros((10, D))\n    assert F.allclose(g.ndata[\"h\"], F.zeros((10, D)))\n    # pop nodes\n    old_len = len(g.ndata)\n    g.ndata.pop(\"h\")\n    assert len(g.ndata) == old_len - 1\n    g.ndata[\"h\"] = F.zeros((10, D))\n    # set partial nodes\n    u = F.tensor([1, 3, 5], g.idtype)\n    g.nodes[u].data[\"h\"] = F.ones((3, D))\n    assert _pfc(g.ndata[\"h\"]) == [\n        0.0,\n        1.0,\n        0.0,\n        1.0,\n        0.0,\n        1.0,\n        0.0,\n        0.0,\n        0.0,\n        0.0,\n    ]\n    # get partial nodes\n    u = F.tensor([1, 2, 3], g.idtype)\n    assert _pfc(g.nodes[u].data[\"h\"]) == [1.0, 0.0, 1.0]\n\n    \"\"\"\n    s, d, eid\n    0, 1, 0\n    1, 9, 1\n    0, 2, 2\n    2, 9, 3\n    0, 3, 4\n    3, 9, 5\n    0, 4, 6\n    4, 9, 7\n    0, 5, 8\n    5, 9, 9\n    0, 6, 10\n    6, 9, 11\n    0, 7, 12\n    7, 9, 13\n    0, 8, 14\n    8, 9, 15\n    9, 0, 16\n    \"\"\"\n    # set all edges\n    g.edata[\"l\"] = F.zeros((17, D))\n    assert _pfc(g.edata[\"l\"]) == [0.0] * 17\n    # pop edges\n    old_len = len(g.edata)\n    g.edata.pop(\"l\")\n    assert len(g.edata) == old_len - 1\n    g.edata[\"l\"] = F.zeros((17, D))\n    # set partial edges (many-many)\n    u = F.tensor([0, 0, 2, 5, 9], g.idtype)\n    v = F.tensor([1, 3, 9, 9, 0], g.idtype)\n    g.edges[u, v].data[\"l\"] = F.ones((5, D))\n    truth = [0.0] * 17\n    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.0\n    assert _pfc(g.edata[\"l\"]) == truth\n    u = F.tensor([3, 4, 6], g.idtype)\n    v = F.tensor([9, 
9, 9], g.idtype)\n    g.edges[u, v].data[\"l\"] = F.ones((3, D))\n    truth[5] = truth[7] = truth[11] = 1.0\n    assert _pfc(g.edata[\"l\"]) == truth\n    u = F.tensor([0, 0, 0], g.idtype)\n    v = F.tensor([4, 5, 6], g.idtype)\n    g.edges[u, v].data[\"l\"] = F.ones((3, D))\n    truth[6] = truth[8] = truth[10] = 1.0\n    assert _pfc(g.edata[\"l\"]) == truth\n    u = F.tensor([0, 6, 0], g.idtype)\n    v = F.tensor([6, 9, 7], g.idtype)\n    assert _pfc(g.edges[u, v].data[\"l\"]) == [1.0, 1.0, 0.0]\n\n\n@parametrize_idtype\ndef test_batch_setter_autograd(idtype):\n    g = generate_graph(idtype, grad=True)\n    h1 = g.ndata[\"h\"]\n    # partial set\n    v = F.tensor([1, 2, 8], g.idtype)\n    hh = F.attach_grad(F.zeros((len(v), D)))\n    with F.record_grad():\n        g.nodes[v].data[\"h\"] = hh\n        h2 = g.ndata[\"h\"]\n        F.backward(h2, F.ones((10, D)) * 2)\n    assert F.array_equal(\n        F.grad(h1)[:, 0],\n        F.tensor([2.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 2.0]),\n    )\n    assert F.array_equal(F.grad(hh)[:, 0], F.tensor([2.0, 2.0, 2.0]))\n\n\ndef _test_nx_conversion():\n    # check conversion between networkx and DGLGraph\n\n    def _check_nx_feature(nxg, nf, ef):\n        # check node and edge feature of nxg\n        # this is used to check to_networkx\n        num_nodes = len(nxg)\n        num_edges = nxg.size()\n        if num_nodes > 0:\n            node_feat = ddict(list)\n            for nid, attr in nxg.nodes(data=True):\n                assert len(attr) == len(nf)\n                for k in nxg.nodes[nid]:\n                    node_feat[k].append(F.unsqueeze(attr[k], 0))\n            for k in node_feat:\n                feat = F.cat(node_feat[k], 0)\n                assert F.allclose(feat, nf[k])\n        else:\n            assert len(nf) == 0\n        if num_edges > 0:\n            edge_feat = ddict(lambda: [0] * num_edges)\n            for u, v, attr in nxg.edges(data=True):\n                assert len(attr) == len(ef) + 1  # 
extra id\n                eid = attr[\"id\"]\n                for k in ef:\n                    edge_feat[k][eid] = F.unsqueeze(attr[k], 0)\n            for k in edge_feat:\n                feat = F.cat(edge_feat[k], 0)\n                assert F.allclose(feat, ef[k])\n        else:\n            assert len(ef) == 0\n\n    n1 = F.randn((5, 3))\n    n2 = F.randn((5, 10))\n    n3 = F.randn((5, 4))\n    e1 = F.randn((4, 5))\n    e2 = F.randn((4, 7))\n    g = dgl.graph(([0, 1, 3, 4], [2, 4, 0, 3]))\n    g.ndata.update({\"n1\": n1, \"n2\": n2, \"n3\": n3})\n    g.edata.update({\"e1\": e1, \"e2\": e2})\n\n    # convert to networkx\n    nxg = g.to_networkx(node_attrs=[\"n1\", \"n3\"], edge_attrs=[\"e1\", \"e2\"])\n    assert len(nxg) == 5\n    assert nxg.size() == 4\n    _check_nx_feature(nxg, {\"n1\": n1, \"n3\": n3}, {\"e1\": e1, \"e2\": e2})\n\n    # convert to DGLGraph, nx graph has id in edge feature\n    # use id feature to test non-tensor copy\n    g = dgl.from_networkx(nxg, node_attrs=[\"n1\"], edge_attrs=[\"e1\", \"id\"])\n    # check graph size\n    assert g.num_nodes() == 5\n    assert g.num_edges() == 4\n    # check number of features\n    # test with existing dglgraph (so existing features should be cleared)\n    assert len(g.ndata) == 1\n    assert len(g.edata) == 2\n    # check feature values\n    assert F.allclose(g.ndata[\"n1\"], n1)\n    # with id in nx edge feature, e1 should follow original order\n    assert F.allclose(g.edata[\"e1\"], e1)\n    assert F.array_equal(\n        F.astype(g.edata[\"id\"], F.int64), F.copy_to(F.arange(0, 4), F.cpu())\n    )\n\n    # test conversion after modifying DGLGraph\n    g.edata.pop(\"id\")  # pop id so we don't need to provide id when adding edges\n    new_n = F.randn((2, 3))\n    new_e = F.randn((3, 5))\n    g.add_nodes(2, data={\"n1\": new_n})\n    # add three edges, one is a multi-edge\n    g.add_edges([3, 6, 0], [4, 5, 2], data={\"e1\": new_e})\n    n1 = F.cat((n1, new_n), 0)\n    e1 = F.cat((e1, new_e), 0)\n    # 
convert to networkx again\n    nxg = g.to_networkx(node_attrs=[\"n1\"], edge_attrs=[\"e1\"])\n    assert len(nxg) == 7\n    assert nxg.size() == 7\n    _check_nx_feature(nxg, {\"n1\": n1}, {\"e1\": e1})\n\n    # now test convert from networkx without id in edge feature\n    # first pop id in edge feature\n    for _, _, attr in nxg.edges(data=True):\n        attr.pop(\"id\")\n    # test with a new graph\n    g = dgl.from_networkx(nxg, node_attrs=[\"n1\"], edge_attrs=[\"e1\"])\n    # check graph size\n    assert g.num_nodes() == 7\n    assert g.num_edges() == 7\n    # check number of features\n    assert len(g.ndata) == 1\n    assert len(g.edata) == 1\n    # check feature values\n    assert F.allclose(g.ndata[\"n1\"], n1)\n    # edge feature order follows nxg.edges()\n    edge_feat = []\n    for _, _, attr in nxg.edges(data=True):\n        edge_feat.append(F.unsqueeze(attr[\"e1\"], 0))\n    edge_feat = F.cat(edge_feat, 0)\n    assert F.allclose(g.edata[\"e1\"], edge_feat)\n\n    # Test converting from a networkx graph whose nodes are\n    # not labeled with consecutive-integers.\n    nxg = nx.cycle_graph(5)\n    nxg.remove_nodes_from([0, 4])\n    for u in nxg.nodes():\n        nxg.nodes[u][\"h\"] = F.tensor([u])\n    for u, v, d in nxg.edges(data=True):\n        d[\"h\"] = F.tensor([u, v])\n\n    g = dgl.from_networkx(nxg, node_attrs=[\"h\"], edge_attrs=[\"h\"])\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 4\n    assert g.has_edge_between(0, 1)\n    assert g.has_edge_between(1, 2)\n    assert F.allclose(g.ndata[\"h\"], F.tensor([[1.0], [2.0], [3.0]]))\n    assert F.allclose(\n        g.edata[\"h\"], F.tensor([[1.0, 2.0], [1.0, 2.0], [2.0, 3.0], [2.0, 3.0]])\n    )\n\n\n@parametrize_idtype\ndef test_apply_nodes(idtype):\n    def _upd(nodes):\n        return {\"h\": nodes.data[\"h\"] * 2}\n\n    g = generate_graph(idtype)\n    old = g.ndata[\"h\"]\n    g.apply_nodes(_upd)\n    assert F.allclose(old * 2, g.ndata[\"h\"])\n    u = F.tensor([0, 3, 4, 6], 
g.idtype)\n    g.apply_nodes(lambda nodes: {\"h\": nodes.data[\"h\"] * 0.0}, u)\n    assert F.allclose(F.gather_row(g.ndata[\"h\"], u), F.zeros((4, D)))\n\n\n@parametrize_idtype\ndef test_apply_edges(idtype):\n    def _upd(edges):\n        return {\"w\": edges.data[\"w\"] * 2}\n\n    g = generate_graph(idtype)\n    old = g.edata[\"w\"]\n    g.apply_edges(_upd)\n    assert F.allclose(old * 2, g.edata[\"w\"])\n    u = F.tensor([0, 0, 0, 4, 5, 6], g.idtype)\n    v = F.tensor([1, 2, 3, 9, 9, 9], g.idtype)\n    g.apply_edges(lambda edges: {\"w\": edges.data[\"w\"] * 0.0}, (u, v))\n    eid = F.tensor(g.edge_ids(u, v))\n    assert F.allclose(F.gather_row(g.edata[\"w\"], eid), F.zeros((6, D)))\n\n\n@parametrize_idtype\ndef test_update_routines(idtype):\n    g = generate_graph(idtype)\n\n    # send_and_recv\n    reduce_msg_shapes.clear()\n    u = [0, 0, 0, 4, 5, 6]\n    v = [1, 2, 3, 9, 9, 9]\n    g.send_and_recv((u, v), message_func, reduce_func, apply_node_func)\n    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}\n    reduce_msg_shapes.clear()\n    try:\n        g.send_and_recv([u, v])\n        assert False\n    except:\n        pass\n\n    # pull\n    v = F.tensor([1, 2, 3, 9], g.idtype)\n    reduce_msg_shapes.clear()\n    g.pull(v, message_func, reduce_func, apply_node_func)\n    assert reduce_msg_shapes == {(1, 8, D), (3, 1, D)}\n    reduce_msg_shapes.clear()\n\n    # push\n    v = F.tensor([0, 1, 2, 3], g.idtype)\n    reduce_msg_shapes.clear()\n    g.push(v, message_func, reduce_func, apply_node_func)\n    assert reduce_msg_shapes == {(1, 3, D), (8, 1, D)}\n    reduce_msg_shapes.clear()\n\n    # update_all\n    reduce_msg_shapes.clear()\n    g.update_all(message_func, reduce_func, apply_node_func)\n    assert reduce_msg_shapes == {(1, 8, D), (9, 1, D)}\n    reduce_msg_shapes.clear()\n\n\n@parametrize_idtype\ndef test_update_all_0deg(idtype):\n    # test#1\n    g = dgl.graph(([1, 2, 3, 4], [0, 0, 0, 0]), idtype=idtype, device=F.ctx())\n\n    def _message(edges):\n  
      return {\"m\": edges.src[\"h\"]}\n\n    def _reduce(nodes):\n        return {\"x\": nodes.data[\"h\"] + F.sum(nodes.mailbox[\"m\"], 1)}\n\n    def _apply(nodes):\n        return {\"x\": nodes.data[\"x\"] * 2}\n\n    def _init2(shape, dtype, ctx, ids):\n        return 2 + F.zeros(shape, dtype, ctx)\n\n    g.set_n_initializer(_init2, \"x\")\n    old_repr = F.randn((5, 5))\n    g.ndata[\"h\"] = old_repr\n    g.update_all(_message, _reduce, _apply)\n    new_repr = g.ndata[\"x\"]\n    # the first row of the new_repr should be the sum of all the node\n    # features; while the 0-deg nodes should be initialized by the\n    # initializer and applied with UDF.\n    assert F.allclose(new_repr[1:], 2 * (2 + F.zeros((4, 5))))\n    assert F.allclose(new_repr[0], 2 * F.sum(old_repr, 0))\n\n    # test#2: graph with no edge\n    g = dgl.graph(([], []), num_nodes=5, idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = old_repr\n    # Intercepting the warning: The input graph for the user-defined edge\n    # function does not contain valid edges.\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        g.update_all(\n            _message, _reduce, lambda nodes: {\"h\": nodes.data[\"h\"] * 2}\n        )\n\n    new_repr = g.ndata[\"h\"]\n    # should fallback to apply\n    assert F.allclose(new_repr, 2 * old_repr)\n\n\n@parametrize_idtype\ndef test_pull_0deg(idtype):\n    g = dgl.graph(([0], [1]), idtype=idtype, device=F.ctx())\n\n    def _message(edges):\n        return {\"m\": edges.src[\"h\"]}\n\n    def _reduce(nodes):\n        return {\"x\": nodes.data[\"h\"] + F.sum(nodes.mailbox[\"m\"], 1)}\n\n    def _apply(nodes):\n        return {\"x\": nodes.data[\"x\"] * 2}\n\n    def _init2(shape, dtype, ctx, ids):\n        return 2 + F.zeros(shape, dtype, ctx)\n\n    g.set_n_initializer(_init2, \"x\")\n    # test#1: pull both 0deg and non-0deg nodes\n    old = F.randn((2, 5))\n    g.ndata[\"h\"] = old\n    g.pull([0, 1], 
_message, _reduce, _apply)\n    new = g.ndata[\"x\"]\n    # 0deg check: initialized with the func and got applied\n    assert F.allclose(new[0], F.full_1d(5, 4, dtype=F.float32))\n    # non-0deg check\n    assert F.allclose(new[1], F.sum(old, 0) * 2)\n\n    # test#2: pull only 0deg node\n    old = F.randn((2, 5))\n    g.ndata[\"h\"] = old\n    # Intercepting the warning: The input graph for the user-defined edge\n    # function does not contain valid edges\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        g.pull(0, _message, _reduce, lambda nodes: {\"h\": nodes.data[\"h\"] * 2})\n\n    new = g.ndata[\"h\"]\n    # 0deg check: fallback to apply\n    assert F.allclose(new[0], 2 * old[0])\n    # non-0deg check: not touched\n    assert F.allclose(new[1], old[1])\n\n\ndef test_dynamic_addition():\n    N = 3\n    D = 1\n\n    g = dgl.graph([]).to(F.ctx())\n\n    # Test node addition\n    g.add_nodes(N)\n    g.ndata.update({\"h1\": F.randn((N, D)), \"h2\": F.randn((N, D))})\n    g.add_nodes(3)\n    assert g.ndata[\"h1\"].shape[0] == g.ndata[\"h2\"].shape[0] == N + 3\n\n    # Test edge addition\n    g.add_edges(0, 1)\n    g.add_edges(1, 0)\n    g.edata.update({\"h1\": F.randn((2, D)), \"h2\": F.randn((2, D))})\n    assert g.edata[\"h1\"].shape[0] == g.edata[\"h2\"].shape[0] == 2\n\n    g.add_edges([0, 2], [2, 0])\n    g.edata[\"h1\"] = F.randn((4, D))\n    assert g.edata[\"h1\"].shape[0] == g.edata[\"h2\"].shape[0] == 4\n\n    g.add_edges(1, 2)\n    g.edges[4].data[\"h1\"] = F.randn((1, D))\n    assert g.edata[\"h1\"].shape[0] == g.edata[\"h2\"].shape[0] == 5\n\n    # test add edge with part of the features\n    g.add_edges(2, 1, {\"h1\": F.randn((1, D))})\n    assert len(g.edata[\"h1\"]) == len(g.edata[\"h2\"])\n\n\n@parametrize_idtype\ndef test_repr(idtype):\n    g = dgl.graph(\n        ([0, 0, 1], [1, 2, 2]), num_nodes=10, idtype=idtype, device=F.ctx()\n    )\n    repr_string = g.__repr__()\n    
print(repr_string)\n    g.ndata[\"x\"] = F.zeros((10, 5))\n    g.edata[\"y\"] = F.zeros((3, 4))\n    repr_string = g.__repr__()\n    print(repr_string)\n\n\n@parametrize_idtype\ndef test_local_var(idtype):\n    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.zeros((g.num_nodes(), 3))\n    g.edata[\"w\"] = F.zeros((g.num_edges(), 4))\n\n    # test override\n    def foo(g):\n        g = g.local_var()\n        g.ndata[\"h\"] = F.ones((g.num_nodes(), 3))\n        g.edata[\"w\"] = F.ones((g.num_edges(), 4))\n\n    foo(g)\n    assert F.allclose(g.ndata[\"h\"], F.zeros((g.num_nodes(), 3)))\n    assert F.allclose(g.edata[\"w\"], F.zeros((g.num_edges(), 4)))\n\n    # test out-place update\n    def foo(g):\n        g = g.local_var()\n        g.nodes[[2, 3]].data[\"h\"] = F.ones((2, 3))\n        g.edges[[2, 3]].data[\"w\"] = F.ones((2, 4))\n\n    foo(g)\n    assert F.allclose(g.ndata[\"h\"], F.zeros((g.num_nodes(), 3)))\n    assert F.allclose(g.edata[\"w\"], F.zeros((g.num_edges(), 4)))\n\n    # test out-place update 2\n    def foo(g):\n        g = g.local_var()\n        g.apply_nodes(lambda nodes: {\"h\": nodes.data[\"h\"] + 10}, [2, 3])\n        g.apply_edges(lambda edges: {\"w\": edges.data[\"w\"] + 10}, [2, 3])\n\n    foo(g)\n    assert F.allclose(g.ndata[\"h\"], F.zeros((g.num_nodes(), 3)))\n    assert F.allclose(g.edata[\"w\"], F.zeros((g.num_edges(), 4)))\n\n    # test auto-pop\n    def foo(g):\n        g = g.local_var()\n        g.ndata[\"hh\"] = F.ones((g.num_nodes(), 3))\n        g.edata[\"ww\"] = F.ones((g.num_edges(), 4))\n\n    foo(g)\n    assert \"hh\" not in g.ndata\n    assert \"ww\" not in g.edata\n\n    # test initializer1\n    g = dgl.graph(([0, 1], [1, 1]), idtype=idtype, device=F.ctx())\n    g.set_n_initializer(dgl.init.zero_initializer)\n\n    def foo(g):\n        g = g.local_var()\n        g.nodes[0].data[\"h\"] = F.ones((1, 1))\n        assert F.allclose(g.ndata[\"h\"], F.tensor([[1.0], [0.0]]))\n\n    
foo(g)\n\n    # test initializer2\n    def foo_e_initializer(shape, dtype, ctx, id_range):\n        return F.ones(shape)\n\n    g.set_e_initializer(foo_e_initializer, field=\"h\")\n\n    def foo(g):\n        g = g.local_var()\n        g.edges[0, 1].data[\"h\"] = F.ones((1, 1))\n        assert F.allclose(g.edata[\"h\"], F.ones((2, 1)))\n        g.edges[0, 1].data[\"w\"] = F.ones((1, 1))\n        assert F.allclose(g.edata[\"w\"], F.tensor([[1.0], [0.0]]))\n\n    foo(g)\n\n\n@parametrize_idtype\ndef test_local_scope(idtype):\n    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.zeros((g.num_nodes(), 3))\n    g.edata[\"w\"] = F.zeros((g.num_edges(), 4))\n\n    # test override\n    def foo(g):\n        with g.local_scope():\n            g.ndata[\"h\"] = F.ones((g.num_nodes(), 3))\n            g.edata[\"w\"] = F.ones((g.num_edges(), 4))\n\n    foo(g)\n    assert F.allclose(g.ndata[\"h\"], F.zeros((g.num_nodes(), 3)))\n    assert F.allclose(g.edata[\"w\"], F.zeros((g.num_edges(), 4)))\n\n    # test out-place update\n    def foo(g):\n        with g.local_scope():\n            g.nodes[[2, 3]].data[\"h\"] = F.ones((2, 3))\n            g.edges[[2, 3]].data[\"w\"] = F.ones((2, 4))\n\n    foo(g)\n    assert F.allclose(g.ndata[\"h\"], F.zeros((g.num_nodes(), 3)))\n    assert F.allclose(g.edata[\"w\"], F.zeros((g.num_edges(), 4)))\n\n    # test out-place update 2\n    def foo(g):\n        with g.local_scope():\n            g.apply_nodes(lambda nodes: {\"h\": nodes.data[\"h\"] + 10}, [2, 3])\n            g.apply_edges(lambda edges: {\"w\": edges.data[\"w\"] + 10}, [2, 3])\n\n    foo(g)\n    assert F.allclose(g.ndata[\"h\"], F.zeros((g.num_nodes(), 3)))\n    assert F.allclose(g.edata[\"w\"], F.zeros((g.num_edges(), 4)))\n\n    # test auto-pop\n    def foo(g):\n        with g.local_scope():\n            g.ndata[\"hh\"] = F.ones((g.num_nodes(), 3))\n            g.edata[\"ww\"] = F.ones((g.num_edges(), 4))\n\n    foo(g)\n    assert 
\"hh\" not in g.ndata\n    assert \"ww\" not in g.edata\n\n    # test nested scope\n    def foo(g):\n        with g.local_scope():\n            g.ndata[\"hh\"] = F.ones((g.num_nodes(), 3))\n            g.edata[\"ww\"] = F.ones((g.num_edges(), 4))\n            with g.local_scope():\n                g.ndata[\"hhh\"] = F.ones((g.num_nodes(), 3))\n                g.edata[\"www\"] = F.ones((g.num_edges(), 4))\n            assert \"hhh\" not in g.ndata\n            assert \"www\" not in g.edata\n\n    foo(g)\n    assert \"hh\" not in g.ndata\n    assert \"ww\" not in g.edata\n\n    # test initializer1\n    g = dgl.graph(([0, 1], [1, 1]), idtype=idtype, device=F.ctx())\n    g.set_n_initializer(dgl.init.zero_initializer)\n\n    def foo(g):\n        with g.local_scope():\n            g.nodes[0].data[\"h\"] = F.ones((1, 1))\n            assert F.allclose(g.ndata[\"h\"], F.tensor([[1.0], [0.0]]))\n\n    foo(g)\n\n    # test initializer2\n    def foo_e_initializer(shape, dtype, ctx, id_range):\n        return F.ones(shape)\n\n    g.set_e_initializer(foo_e_initializer, field=\"h\")\n\n    def foo(g):\n        with g.local_scope():\n            g.edges[0, 1].data[\"h\"] = F.ones((1, 1))\n            assert F.allclose(g.edata[\"h\"], F.ones((2, 1)))\n            g.edges[0, 1].data[\"w\"] = F.ones((1, 1))\n            assert F.allclose(g.edata[\"w\"], F.tensor([[1.0], [0.0]]))\n\n    foo(g)\n\n    # test exception handling\n    def foo(g):\n        try:\n            with g.local_scope():\n                g.ndata[\"hh\"] = F.ones((g.num_nodes(), 1))\n                # throw TypeError\n                1 + \"1\"\n        except TypeError:\n            pass\n        assert \"hh\" not in g.ndata\n\n    foo(g)\n\n\n@parametrize_idtype\ndef test_isolated_nodes(idtype):\n    g = dgl.graph(([0, 1], [1, 2]), num_nodes=5, idtype=idtype, device=F.ctx())\n    assert g.num_nodes() == 5\n\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 0, 1], [2, 3, 2])},\n        
{\"user\": 5, \"game\": 7},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.num_nodes(\"user\") == 5\n    assert g.num_nodes(\"game\") == 7\n\n    # Test backward compatibility\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 0, 1], [2, 3, 2])},\n        {\"user\": 5, \"game\": 7},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.num_nodes(\"user\") == 5\n    assert g.num_nodes(\"game\") == 7\n\n\n@parametrize_idtype\ndef test_send_multigraph(idtype):\n    g = dgl.graph(([0, 0, 0, 2], [1, 1, 1, 1]), idtype=idtype, device=F.ctx())\n\n    def _message_a(edges):\n        return {\"a\": edges.data[\"a\"]}\n\n    def _message_b(edges):\n        return {\"a\": edges.data[\"a\"] * 3}\n\n    def _reduce(nodes):\n        return {\"a\": F.max(nodes.mailbox[\"a\"], 1)}\n\n    def answer(*args):\n        return F.max(F.stack(args, 0), 0)\n\n    assert g.is_multigraph\n\n    # send by eid\n    old_repr = F.randn((4, 5))\n    # send_and_recv_on\n    g.ndata[\"a\"] = F.zeros((3, 5))\n    g.edata[\"a\"] = old_repr\n    g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)\n    new_repr = g.ndata[\"a\"]\n    assert F.allclose(\n        new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3])\n    )\n    assert F.allclose(new_repr[[0, 2]], F.zeros((2, 5)))\n\n\n@parametrize_idtype\ndef test_issue_1088(idtype):\n    # This test ensures that message passing on a heterograph with one edge type\n    # would not crash (GitHub issue #1088).\n    import dgl.function as fn\n\n    g = dgl.heterograph(\n        {(\"U\", \"E\", \"V\"): ([0, 1, 2], [1, 2, 3])}, idtype=idtype, device=F.ctx()\n    )\n    g.nodes[\"U\"].data[\"x\"] = F.randn((3, 3))\n    g.update_all(fn.copy_u(\"x\", \"m\"), fn.sum(\"m\", \"y\"))\n\n\n@parametrize_idtype\ndef test_degree_bucket_edge_ordering(idtype):\n    import dgl.function as fn\n\n    g = dgl.graph(\n        ([1, 
3, 5, 0, 4, 2, 3, 3, 4, 5], [1, 1, 0, 0, 1, 2, 2, 0, 3, 3]),\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.edata[\"eid\"] = F.copy_to(F.arange(0, 10), F.ctx())\n\n    def reducer(nodes):\n        eid = F.asnumpy(F.copy_to(nodes.mailbox[\"eid\"], F.cpu()))\n        assert np.array_equal(eid, np.sort(eid, 1))\n        return {\"n\": F.sum(nodes.mailbox[\"eid\"], 1)}\n\n    g.update_all(fn.copy_e(\"eid\", \"eid\"), reducer)\n\n\n@parametrize_idtype\ndef test_issue_2484(idtype):\n    import dgl.function as fn\n\n    g = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())\n    x = F.copy_to(F.randn((4,)), F.ctx())\n    g.ndata[\"x\"] = x\n    g.pull([2, 1], fn.u_add_v(\"x\", \"x\", \"m\"), fn.sum(\"m\", \"x\"))\n    y1 = g.ndata[\"x\"]\n\n    g.ndata[\"x\"] = x\n    g.pull([1, 2], fn.u_add_v(\"x\", \"x\", \"m\"), fn.sum(\"m\", \"x\"))\n    y2 = g.ndata[\"x\"]\n\n    assert F.allclose(y1, y2)\n"
  },
  {
    "path": "tests/python/common/ops/test_edge_softmax.py",
    "content": "import itertools\nimport math\nimport unittest\nfrom collections import Counter\n\nimport backend as F\n\nimport dgl\nimport dgl.function as fn\nimport networkx as nx\nimport numpy as np\nimport pytest\nimport scipy.sparse as ssp\nfrom dgl import DGLError\nfrom dgl.ops import edge_softmax\nfrom scipy.sparse import rand\nfrom utils import get_cases, parametrize_idtype\n\nedge_softmax_shapes = [(1,), (1, 3), (3, 4, 5)]\nrfuncs = {\"sum\": fn.sum, \"max\": fn.max, \"min\": fn.min, \"mean\": fn.mean}\nfill_value = {\"sum\": 0, \"max\": float(\"-inf\")}\nfeat_size = 2\n\n\n@pytest.mark.parametrize(\"g\", get_cases([\"clique\"]))\n@pytest.mark.parametrize(\"norm_by\", [\"src\", \"dst\"])\n@pytest.mark.parametrize(\"shp\", edge_softmax_shapes)\n@parametrize_idtype\ndef test_edge_softmax(g, norm_by, shp, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    edata = F.tensor(np.random.rand(g.num_edges(), *shp))\n    e1 = F.attach_grad(F.clone(edata))\n\n    with F.record_grad():\n        score1 = edge_softmax(g, e1, norm_by=norm_by)\n        F.backward(F.reduce_sum(score1))\n        grad_edata = F.grad(e1)\n\n    with F.record_grad():\n        e2 = F.attach_grad(F.clone(edata))\n        e2_2d = F.reshape(\n            e2,\n            (g.number_of_src_nodes(), g.number_of_dst_nodes(), *e2.shape[1:]),\n        )\n        if norm_by == \"src\":\n            score2 = F.softmax(e2_2d, 1)\n            score2 = F.reshape(score2, (-1, *e2.shape[1:]))\n        if norm_by == \"dst\":\n            score2 = F.softmax(e2_2d, 0)\n            score2 = F.reshape(score2, (-1, *e2.shape[1:]))\n        assert F.allclose(score1, score2)\n        print(\"forward passed\")\n\n        F.backward(F.reduce_sum(score2))\n        assert F.allclose(F.grad(e2), grad_edata)\n        print(\"backward passed\")\n\n\ndef create_test_heterograph(idtype):\n    # test heterograph from the docstring, plus a user -- wishes -- game relation\n    # 3 users, 2 games, 2 developers\n    # metagraph:\n 
   #    ('user', 'follows', 'user'),\n    #    ('user', 'plays', 'game'),\n    #    ('user', 'wishes', 'game'),\n    #    ('developer', 'develops', 'game')])\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1, 2, 1, 1], [0, 0, 1, 1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): ([0, 1, 1], [0, 0, 1]),\n            (\"developer\", \"develops\", \"game\"): ([0, 1, 0], [0, 1, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    return g\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\ndef test_edge_softmax_unidirectional():\n    g = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): (\n                [1, 2, 3, 1, 2, 3, 1, 2, 3],\n                [0, 0, 0, 1, 1, 1, 2, 2, 2],\n            ),\n            (\"B\", \"BB\", \"B\"): (\n                [0, 1, 2, 0, 1, 2, 0, 1, 2],\n                [0, 0, 0, 1, 1, 1, 2, 2, 2],\n            ),\n        }\n    )\n    g = g.to(F.ctx())\n    g.edges[\"AB\"].data[\"x\"] = F.ones(9) * 2\n    g.edges[\"BB\"].data[\"x\"] = F.ones(9)\n    result = dgl.ops.edge_softmax(\n        g, {\"AB\": g.edges[\"AB\"].data[\"x\"], \"BB\": g.edges[\"BB\"].data[\"x\"]}\n    )\n\n    ab = result[\"A\", \"AB\", \"B\"]\n    bb = result[\"B\", \"BB\", \"B\"]\n    e2 = F.zeros_like(ab) + math.exp(2) / ((math.exp(2) + math.exp(1)) * 3)\n    e1 = F.zeros_like(bb) + math.exp(1) / ((math.exp(2) + math.exp(1)) * 3)\n    assert F.allclose(ab, e2)\n    assert F.allclose(bb, e1)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@pytest.mark.parametrize(\"g\", get_cases([\"clique\"]))\n@pytest.mark.parametrize(\"norm_by\", [\"src\", \"dst\"])\n# @pytest.mark.parametrize('shp', 
edge_softmax_shapes)\n@parametrize_idtype\ndef test_edge_softmax(g, norm_by, idtype):\n    print(\"params\", norm_by, idtype)\n\n    g = create_test_heterograph(idtype)\n\n    x1 = F.randn((g.num_edges(\"plays\"), feat_size))\n    x2 = F.randn((g.num_edges(\"follows\"), feat_size))\n    x3 = F.randn((g.num_edges(\"develops\"), feat_size))\n    x4 = F.randn((g.num_edges(\"wishes\"), feat_size))\n\n    F.attach_grad(F.clone(x1))\n    F.attach_grad(F.clone(x2))\n    F.attach_grad(F.clone(x3))\n    F.attach_grad(F.clone(x4))\n\n    g[\"plays\"].edata[\"eid\"] = x1\n    g[\"follows\"].edata[\"eid\"] = x2\n    g[\"develops\"].edata[\"eid\"] = x3\n    g[\"wishes\"].edata[\"eid\"] = x4\n\n    #################################################################\n    #  edge_softmax() on homogeneous graph\n    #################################################################\n\n    with F.record_grad():\n        hm_g = dgl.to_homogeneous(g)\n        hm_x = F.cat((x3, x2, x1, x4), 0)\n        hm_e = F.attach_grad(F.clone(hm_x))\n        score_hm = edge_softmax(hm_g, hm_e, norm_by=norm_by)\n        hm_g.edata[\"score\"] = score_hm\n        ht_g = dgl.to_heterogeneous(hm_g, g.ntypes, g.etypes)\n        r1 = ht_g.edata[\"score\"][(\"user\", \"plays\", \"game\")]\n        r2 = ht_g.edata[\"score\"][(\"user\", \"follows\", \"user\")]\n        r3 = ht_g.edata[\"score\"][(\"developer\", \"develops\", \"game\")]\n        r4 = ht_g.edata[\"score\"][(\"user\", \"wishes\", \"game\")]\n        F.backward(F.reduce_sum(r1) + F.reduce_sum(r2))\n        grad_edata_hm = F.grad(hm_e)\n\n    #################################################################\n    #  edge_softmax() on heterogeneous graph\n    #################################################################\n\n    e1 = F.attach_grad(F.clone(x1))\n    e2 = F.attach_grad(F.clone(x2))\n    e3 = F.attach_grad(F.clone(x3))\n    e4 = F.attach_grad(F.clone(x4))\n    e = {\n        (\"user\", \"follows\", \"user\"): e2,\n        (\"user\", 
\"plays\", \"game\"): e1,\n        (\"user\", \"wishes\", \"game\"): e4,\n        (\"developer\", \"develops\", \"game\"): e3,\n    }\n    with F.record_grad():\n        score = edge_softmax(g, e, norm_by=norm_by)\n        r5 = score[(\"user\", \"plays\", \"game\")]\n        r6 = score[(\"user\", \"follows\", \"user\")]\n        r7 = score[(\"developer\", \"develops\", \"game\")]\n        r8 = score[(\"user\", \"wishes\", \"game\")]\n        F.backward(F.reduce_sum(r5) + F.reduce_sum(r6))\n        grad_edata_ht = F.cat(\n            (F.grad(e3), F.grad(e2), F.grad(e1), F.grad(e4)), 0\n        )\n        # correctness check\n        assert F.allclose(r1, r5)\n        assert F.allclose(r2, r6)\n        assert F.allclose(r3, r7)\n        assert F.allclose(r4, r8)\n        assert F.allclose(grad_edata_hm, grad_edata_ht)\n\n\nif __name__ == \"__main__\":\n    test_edge_softmax_unidirectional()\n"
  },
  {
    "path": "tests/python/common/ops/test_ops.py",
    "content": "import random\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport numpy as np\nimport pytest\nimport torch\nfrom dgl.ops import gather_mm, gsddmm, gspmm, segment_reduce\nfrom utils import parametrize_idtype\nfrom utils.graph_cases import get_cases\n\n# Set seeds to make tests fully reproducible.\nSEED = 12345  # random.randint(1, 99999)\nrandom.seed(SEED)\nnp.random.seed(SEED)\ndgl.seed(SEED)\nF.seed(SEED)\n\nudf_msg = {\n    \"add\": lambda edges: {\"m\": edges.src[\"x\"] + edges.data[\"w\"]},\n    \"sub\": lambda edges: {\"m\": edges.src[\"x\"] - edges.data[\"w\"]},\n    \"mul\": lambda edges: {\"m\": edges.src[\"x\"] * edges.data[\"w\"]},\n    \"div\": lambda edges: {\"m\": edges.src[\"x\"] / edges.data[\"w\"]},\n    \"copy_lhs\": lambda edges: {\"m\": edges.src[\"x\"]},\n    \"copy_rhs\": lambda edges: {\"m\": edges.data[\"w\"]},\n}\n\n\ndef select(target, src, edge, dst):\n    if target == \"u\":\n        return src\n    elif target == \"v\":\n        return dst\n    elif target == \"e\":\n        return edge\n\n\ndef binary_op(msg, x, y):\n    if msg == \"add\":\n        return x + y\n    elif msg == \"sub\":\n        return x - y\n    elif msg == \"mul\":\n        return x * y\n    elif msg == \"div\":\n        return x / y\n    elif msg == \"dot\":\n        return F.sum(x * y, -1, keepdims=True)\n    elif msg == \"copy_lhs\":\n        return x\n    elif msg == \"copy_rhs\":\n        return y\n\n\ndef edge_func(lhs_target, rhs_target, msg):\n    def foo(edges):\n        return {\n            \"m\": binary_op(\n                msg,\n                select(lhs_target, edges.src, edges.data, edges.dst)[\"x\"],\n                select(rhs_target, edges.src, edges.data, edges.dst)[\"y\"],\n            )\n        }\n\n    return foo\n\n\nudf_apply_edges = {\n    lhs_target\n    + \"_\"\n    + msg\n    + \"_\"\n    + rhs_target: edge_func(lhs_target, rhs_target, msg)\n    for lhs_target in [\"u\", \"v\", \"e\"]\n    for rhs_target in 
[\"u\", \"v\", \"e\"]\n    for msg in [\"add\", \"sub\", \"mul\", \"div\", \"dot\", \"copy_lhs\", \"copy_rhs\"]\n}\n\nudf_reduce = {\n    \"sum\": lambda nodes: {\"v\": F.sum(nodes.mailbox[\"m\"], 1)},\n    \"min\": lambda nodes: {\"v\": F.min(nodes.mailbox[\"m\"], 1)},\n    \"max\": lambda nodes: {\"v\": F.max(nodes.mailbox[\"m\"], 1)},\n}\n\ngraphs = [\n    #    dgl.rand_graph(30, 0),\n    dgl.rand_graph(30, 100),\n    dgl.rand_bipartite(\"_U\", \"_E\", \"_V\", 30, 40, 300),\n]\n\nspmm_shapes = [\n    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),\n    ((3, 3), (1, 3)),\n    ((1,), (3,)),\n    ((3,), (1,)),\n    ((1,), (1,)),\n    ((), ()),\n]\n\nsddmm_shapes = [\n    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),\n    ((5, 3, 1, 7), (1, 3, 7, 7)),\n    ((1, 3, 3), (4, 1, 3)),\n    ((3,), (3,)),\n    ((1,), (1,)),\n]\n\n\n@pytest.mark.parametrize(\"g\", graphs)\n@pytest.mark.parametrize(\"shp\", spmm_shapes)\n@pytest.mark.parametrize(\n    \"msg\", [\"add\", \"sub\", \"mul\", \"div\", \"copy_lhs\", \"copy_rhs\"]\n)\n@pytest.mark.parametrize(\"reducer\", [\"sum\", \"min\", \"max\"])\n@parametrize_idtype\n@pytest.mark.parametrize(\"dtype\", [np.float32, np.float64])\ndef test_spmm(idtype, dtype, g, shp, msg, reducer):\n    g = g.astype(idtype).to(F.ctx())\n    print(g)\n    print(g.idtype)\n\n    hu = F.tensor(\n        np.random.rand(*((g.number_of_src_nodes(),) + shp[0])).astype(dtype) + 1\n    )\n    he = F.tensor(\n        np.random.rand(*((g.num_edges(),) + shp[1])).astype(dtype) + 1\n    )\n    print(\"u shape: {}, e shape: {}\".format(F.shape(hu), F.shape(he)))\n\n    g.srcdata[\"x\"] = F.attach_grad(F.clone(hu))\n    g.edata[\"w\"] = F.attach_grad(F.clone(he))\n    print(\"SpMM(message func: {}, reduce func: {})\".format(msg, reducer))\n\n    u = F.attach_grad(F.clone(hu))\n    e = F.attach_grad(F.clone(he))\n    with F.record_grad():\n        v = gspmm(g, msg, reducer, u, e)\n        if reducer in [\"max\", \"min\"]:\n            v = F.replace_inf_with_zero(v)\n        if 
g.num_edges() > 0:\n            F.backward(F.reduce_sum(v))\n            if msg != \"copy_rhs\":\n                grad_u = F.grad(u)\n            if msg != \"copy_lhs\":\n                grad_e = F.grad(e)\n\n    with F.record_grad():\n        g.update_all(udf_msg[msg], udf_reduce[reducer])\n        if g.num_edges() > 0:\n            v1 = g.dstdata[\"v\"]\n            assert F.allclose(v, v1)\n            print(\"forward passed\")\n\n            F.backward(F.reduce_sum(v1))\n            if msg != \"copy_rhs\":\n                if reducer in [\n                    \"min\",\n                    \"max\",\n                ]:  # there might be some numerical errors\n                    rate = F.reduce_sum(\n                        F.abs(F.grad(g.srcdata[\"x\"]) - grad_u)\n                    ) / F.reduce_sum(F.abs(grad_u))\n                    assert F.as_scalar(rate) < 1e-2, rate\n                else:\n                    assert F.allclose(F.grad(g.srcdata[\"x\"]), grad_u)\n            if msg != \"copy_lhs\":\n                if reducer in [\"min\", \"max\"]:\n                    rate = F.reduce_sum(\n                        F.abs(F.grad(g.edata[\"w\"]) - grad_e)\n                    ) / F.reduce_sum(F.abs(grad_e))\n                    assert F.as_scalar(rate) < 1e-2, rate\n                else:\n                    assert F.allclose(F.grad(g.edata[\"w\"]), grad_e)\n            print(\"backward passed\")\n\n    g.srcdata.pop(\"x\")\n    g.edata.pop(\"w\")\n    if \"v\" in g.dstdata:\n        g.dstdata.pop(\"v\")\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\",\n    reason=\"Only support PyTorch for now.\",\n)\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"dtype, rtol, atol\",\n    [(torch.float16, 1e-3, 0.5), (torch.bfloat16, 4e-3, 2.0)],\n)\ndef test_half_spmm(idtype, dtype, rtol, atol):\n    if F._default_context_str == \"cpu\" and dtype == torch.float16:\n        pytest.skip(\"float16 is not supported on CPU.\")\n    if (\n        
F._default_context_str == \"gpu\"\n        and dtype == torch.bfloat16\n        and not torch.cuda.is_bf16_supported()\n    ):\n        pytest.skip(\"BF16 is not supported.\")\n\n    # make sure the spmm result is < 512 to match the rtol/atol we set.\n    g = dgl.graph(\n        (torch.arange(900), torch.tensor([0] * 900)),\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    feat_fp32 = torch.rand((g.num_src_nodes(), 32)).to(F.ctx())\n    feat_half = feat_fp32.to(dtype)\n\n    # test SpMMCSR\n    g = g.formats([\"csc\"])\n    res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]\n    res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()\n    assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)\n\n    # test SpMMCOO\n    # TODO(Xin): half-precision SpMMCoo is temporally disabled.\n    # g = g.formats(['coo'])\n    # res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]\n    # res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()\n    # assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"g\", graphs)\n@pytest.mark.parametrize(\"shp\", sddmm_shapes)\n@pytest.mark.parametrize(\"lhs_target\", [\"u\", \"v\", \"e\"])\n@pytest.mark.parametrize(\"rhs_target\", [\"u\", \"v\", \"e\"])\n@pytest.mark.parametrize(\n    \"msg\", [\"add\", \"sub\", \"mul\", \"div\", \"dot\", \"copy_lhs\", \"copy_rhs\"]\n)\n@parametrize_idtype\ndef test_sddmm(g, shp, lhs_target, rhs_target, msg, idtype):\n    if lhs_target == rhs_target:\n        return\n    g = g.astype(idtype).to(F.ctx())\n    if dgl.backend.backend_name == \"mxnet\" and g.num_edges() == 0:\n        pytest.skip()  # mxnet do not support zero shape tensor\n    print(g)\n    print(g.idtype)\n\n    len_lhs = select(\n        lhs_target,\n        g.number_of_src_nodes(),\n        g.num_edges(),\n        g.number_of_dst_nodes(),\n    )\n    lhs_shp = (len_lhs,) + shp[0]\n    len_rhs = select(\n        rhs_target,\n        g.number_of_src_nodes(),\n        g.num_edges(),\n        
g.number_of_dst_nodes(),\n    )\n    rhs_shp = (len_rhs,) + shp[1]\n    feat_lhs = F.tensor(np.random.rand(*lhs_shp) + 1)\n    feat_rhs = F.tensor(np.random.rand(*rhs_shp) + 1)\n    print(\n        \"lhs shape: {}, rhs shape: {}\".format(\n            F.shape(feat_lhs), F.shape(feat_rhs)\n        )\n    )\n\n    lhs_frame = select(lhs_target, g.srcdata, g.edata, g.dstdata)\n    rhs_frame = select(rhs_target, g.srcdata, g.edata, g.dstdata)\n    lhs_frame[\"x\"] = F.attach_grad(F.clone(feat_lhs))\n    rhs_frame[\"y\"] = F.attach_grad(F.clone(feat_rhs))\n    msg_func = lhs_target + \"_\" + msg + \"_\" + rhs_target\n    print(\"SDDMM(message func: {})\".format(msg_func))\n\n    lhs = F.attach_grad(F.clone(feat_lhs))\n    rhs = F.attach_grad(F.clone(feat_rhs))\n    with F.record_grad():\n        e = gsddmm(\n            g, msg, lhs, rhs, lhs_target=lhs_target, rhs_target=rhs_target\n        )\n        F.backward(F.reduce_sum(e))\n        grad_lhs = F.grad(lhs)\n        grad_rhs = F.grad(rhs)\n\n    with F.record_grad():\n        g.apply_edges(udf_apply_edges[msg_func])\n        if g.num_edges() > 0:\n            e1 = g.edata[\"m\"]\n            assert F.allclose(e, e1)\n            print(\"forward passed\")\n\n            F.backward(F.reduce_sum(e1))\n            if msg != \"copy_rhs\":\n                assert F.allclose(F.grad(lhs_frame[\"x\"]), grad_lhs)\n            if msg != \"copy_lhs\":\n                assert F.allclose(F.grad(rhs_frame[\"y\"]), grad_rhs)\n            print(\"backward passed\")\n\n    lhs_frame.pop(\"x\")\n    rhs_frame.pop(\"y\")\n    if \"m\" in g.edata:\n        g.edata.pop(\"m\")\n\n\n@pytest.mark.parametrize(\"reducer\", [\"sum\", \"max\", \"min\", \"mean\"])\ndef test_segment_reduce(reducer):\n    ctx = F.ctx()\n    value = F.tensor(np.random.rand(10, 5))\n    v1 = F.attach_grad(F.clone(value))\n    v2 = F.attach_grad(F.clone(value))\n    seglen = F.tensor([2, 3, 0, 4, 1, 0, 0])\n    u = F.copy_to(F.arange(0, F.shape(value)[0], F.int32), 
ctx)\n    v = F.repeat(\n        F.copy_to(F.arange(0, len(seglen), F.int32), ctx), seglen, dim=0\n    )\n\n    num_nodes = {\"_U\": len(u), \"_V\": len(seglen)}\n    g = dgl.convert.heterograph(\n        {(\"_U\", \"_E\", \"_V\"): (u, v)}, num_nodes_dict=num_nodes\n    )\n    with F.record_grad():\n        rst1 = gspmm(g, \"copy_lhs\", reducer, v1, None)\n        if reducer in [\"max\", \"min\"]:\n            rst1 = F.replace_inf_with_zero(rst1)\n        F.backward(F.reduce_sum(rst1))\n        grad1 = F.grad(v1)\n\n    with F.record_grad():\n        rst2 = segment_reduce(seglen, v2, reducer=reducer)\n        F.backward(F.reduce_sum(rst2))\n        assert F.allclose(rst1, rst2)\n        print(\"forward passed\")\n\n        grad2 = F.grad(v2)\n        assert F.allclose(grad1, grad2)\n        print(\"backward passed\")\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\n@pytest.mark.parametrize(\"feat_size\", [1, 8, 16, 64, 256])\n@pytest.mark.parametrize(\n    \"dtype, tol\",\n    [\n        (torch.float16, 1e-2),\n        (torch.bfloat16, 1e-2),\n        (torch.float32, 3e-3),\n        (torch.float64, 1e-4),\n    ],\n)\ndef test_segment_mm(idtype, feat_size, dtype, tol):\n    if F._default_context_str == \"cpu\" and dtype == torch.float16:\n        pytest.skip(\"float16 is not supported on CPU.\")\n    if (\n        F._default_context_str == \"gpu\"\n        and dtype == torch.bfloat16\n        and not torch.cuda.is_bf16_supported()\n    ):\n        pytest.skip(\"BF16 is not supported.\")\n    dev = F.ctx()\n    # input\n    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)\n    a.requires_grad_()\n    b = (\n        torch.tensor(np.random.rand(10, feat_size, feat_size + 1))\n        .to(dev)\n        .to(dtype)\n    )\n    b.requires_grad_()\n    seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0]).to(idtype)\n    dc = torch.tensor(np.random.rand(100, 
feat_size + 1)).to(dev).to(dtype)\n    # compute\n    c = dgl.ops.segment_mm(a, b, seglen_a)\n    c.backward(dc)\n    da = a.grad.clone()\n    db = b.grad.clone()\n    # ground truth\n    c_t = []\n    off = 0\n    for i, l in enumerate(seglen_a):\n        c_t.append(a[off : off + l] @ b[i])\n        off += l\n    c_t = torch.cat(c_t).to(dtype)\n    a.grad.zero_()\n    b.grad.zero_()\n    c_t.backward(dc)\n    da_t = a.grad\n    db_t = b.grad\n\n    assert torch.allclose(c, c_t, atol=tol, rtol=tol)\n    assert torch.allclose(da, da_t, atol=tol, rtol=tol)\n    assert torch.allclose(db, db_t, atol=tol, rtol=tol)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@pytest.mark.parametrize(\"feat_size\", [1, 8, 16, 64, 256])\n@pytest.mark.parametrize(\n    \"dtype, tol\",\n    [\n        (torch.float16, 1e-2),\n        (torch.bfloat16, 2e-2),\n        (torch.float32, 3e-3),\n        (torch.float64, 1e-4),\n    ],\n)\ndef test_gather_mm_idx_b(feat_size, dtype, tol):\n    if F._default_context_str == \"cpu\" and dtype == torch.float16:\n        pytest.skip(\"float16 is not supported on CPU.\")\n\n    if F._default_context_str == \"gpu\":\n        if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():\n            pytest.skip(\"BF16 is not supported.\")\n\n        if (\n            dtype == torch.float16\n            and torch.cuda.get_device_capability() < (7, 0)\n        ) or (\n            dtype == torch.bfloat16\n            and torch.cuda.get_device_capability() < (8, 0)\n        ):\n            pytest.skip(\n                f\"{dtype} is not supported for atomic operations on GPU with \"\n                f\"cuda capability ({torch.cuda.get_device_capability()}).\"\n            )\n\n    dev = F.ctx()\n    # input\n    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)\n    a.requires_grad_()\n    b = (\n        torch.tensor(np.random.rand(10, feat_size, feat_size + 1))\n        
.to(dev)\n        .to(dtype)\n    )\n    b.requires_grad_()\n    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev).long()\n    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)\n    # compute\n    c = gather_mm(a, b, idx_b=idx)\n    c.backward(dc)\n    da = a.grad.clone()\n    db = b.grad.clone()\n    # ground truth\n    c_t = torch.bmm(a.unsqueeze(1), b[idx]).squeeze(1)\n    a.grad.zero_()\n    b.grad.zero_()\n    c_t.backward(dc)\n    da_t = a.grad\n    db_t = b.grad\n\n    assert torch.allclose(c, c_t, atol=tol, rtol=tol)\n    assert torch.allclose(da, da_t, atol=tol, rtol=tol)\n    assert torch.allclose(db, db_t, atol=tol, rtol=tol)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\n@pytest.mark.parametrize(\"feat_size\", [1, 8, 16, 64, 256])\ndef _test_gather_mm_idx_a(idtype, feat_size):\n    # TODO(minjie): currently disabled due to bugs in the CUDA kernel. Need to fix it later.\n    import torch\n\n    dev = F.ctx()\n    # input\n    a = torch.tensor(np.random.rand(10, feat_size)).to(dev)\n    a.requires_grad_()\n    b = torch.tensor(np.random.rand(100, feat_size, feat_size + 1)).to(dev)\n    b.requires_grad_()\n    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev)\n    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev)\n    # compute\n    c = gather_mm(a, b, idx_a=idx)\n    c.backward(dc)\n    da = a.grad.clone()\n    db = b.grad.clone()\n    # ground truth\n    c_t = torch.bmm(a[idx].unsqueeze(1), b).squeeze(1)\n    a.grad.zero_()\n    b.grad.zero_()\n    c_t.backward(dc)\n    da_t = a.grad\n    db_t = b.grad\n\n    assert torch.allclose(c, c_t, atol=1e-4, rtol=1e-4)\n    assert torch.allclose(da, da_t, atol=1e-4, rtol=1e-4)\n    assert torch.allclose(db, db_t, atol=1e-4, rtol=1e-4)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@unittest.skipIf(\n    
F._default_context_str == \"gpu\", reason=\"Libxsmm only fit in CPU.\"\n)\ndef test_use_libxsmm_switch():\n    import torch\n\n    g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))\n    x = torch.ones(3, 2, requires_grad=True)\n    y = torch.arange(1, 13).float().view(6, 2).requires_grad_()\n\n    dgl.use_libxsmm(False)\n    assert ~dgl.is_libxsmm_enabled()\n    dgl.ops.u_mul_e_sum(g, x, y)\n    dgl.use_libxsmm(True)\n    assert dgl.is_libxsmm_enabled()\n    dgl.ops.u_mul_e_sum(g, x, y)\n"
  },
  {
    "path": "tests/python/common/sampling/test_sampling.py",
    "content": "import unittest\nimport warnings\n\nimport backend as F\n\nimport dgl\nimport numpy as np\nimport pytest\n\nsample_neighbors_fusing_mode = {\n    True: dgl.sampling.sample_neighbors_fused,\n    False: dgl.sampling.sample_neighbors,\n}\n\n\ndef check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None):\n    traces = F.asnumpy(traces)\n    ntypes = F.asnumpy(ntypes)\n    for j in range(traces.shape[1] - 1):\n        assert ntypes[j] == g.get_ntype_id(g.to_canonical_etype(metapath[j])[0])\n        assert ntypes[j + 1] == g.get_ntype_id(\n            g.to_canonical_etype(metapath[j])[2]\n        )\n\n    for i in range(traces.shape[0]):\n        for j in range(traces.shape[1] - 1):\n            assert g.has_edges_between(\n                traces[i, j], traces[i, j + 1], etype=metapath[j]\n            )\n            if prob is not None and prob in g.edges[metapath[j]].data:\n                p = F.asnumpy(g.edges[metapath[j]].data[\"p\"])\n                eids = g.edge_ids(\n                    traces[i, j], traces[i, j + 1], etype=metapath[j]\n                )\n                assert p[eids] != 0\n            if trace_eids is not None:\n                u, v = g.find_edges(trace_eids[i, j], etype=metapath[j])\n                assert (u == traces[i, j]) and (v == traces[i, j + 1])\n\n\n@pytest.mark.parametrize(\"use_uva\", [True, False])\ndef test_non_uniform_random_walk(use_uva):\n    if use_uva:\n        if F.ctx() == F.cpu():\n            pytest.skip(\"UVA biased random walk requires a GPU.\")\n        if dgl.backend.backend_name != \"pytorch\":\n            pytest.skip(\n                \"UVA biased random walk is only supported with PyTorch.\"\n            )\n    g2 = dgl.heterograph(\n        {(\"user\", \"follow\", \"user\"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])}\n    )\n    g4 = dgl.heterograph(\n        {\n            (\"user\", \"follow\", \"user\"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),\n            (\"user\", \"view\", \"item\"): 
([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),\n            (\"item\", \"viewed-by\", \"user\"): (\n                [0, 1, 1, 2, 2, 1],\n                [0, 0, 1, 2, 3, 3],\n            ),\n        }\n    )\n\n    g2.edata[\"p\"] = F.copy_to(\n        F.tensor([3, 0, 3, 3, 3], dtype=F.float32), F.cpu()\n    )\n    g2.edata[\"p2\"] = F.copy_to(\n        F.tensor([[3], [0], [3], [3], [3]], dtype=F.float32), F.cpu()\n    )\n    g4.edges[\"follow\"].data[\"p\"] = F.copy_to(\n        F.tensor([3, 0, 3, 3, 3], dtype=F.float32), F.cpu()\n    )\n    g4.edges[\"viewed-by\"].data[\"p\"] = F.copy_to(\n        F.tensor([1, 1, 1, 1, 1, 1], dtype=F.float32), F.cpu()\n    )\n\n    if use_uva:\n        for g in (g2, g4):\n            g.create_formats_()\n            g.pin_memory_()\n    elif F._default_context_str == \"gpu\":\n        g2 = g2.to(F.ctx())\n        g4 = g4.to(F.ctx())\n\n    try:\n        traces, eids, ntypes = dgl.sampling.random_walk(\n            g2,\n            F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g2.idtype),\n            length=4,\n            prob=\"p\",\n            return_eids=True,\n        )\n        check_random_walk(\n            g2, [\"follow\"] * 4, traces, ntypes, \"p\", trace_eids=eids\n        )\n\n        with pytest.raises(dgl.DGLError):\n            traces, ntypes = dgl.sampling.random_walk(\n                g2,\n                F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g2.idtype),\n                length=4,\n                prob=\"p2\",\n            )\n\n        metapath = [\"follow\", \"view\", \"viewed-by\"] * 2\n        traces, eids, ntypes = dgl.sampling.random_walk(\n            g4,\n            F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),\n            metapath=metapath,\n            prob=\"p\",\n            return_eids=True,\n        )\n        check_random_walk(g4, metapath, traces, ntypes, \"p\", trace_eids=eids)\n        traces, eids, ntypes = dgl.sampling.random_walk(\n            g4,\n            F.tensor([0, 1, 2, 3, 0, 1, 2, 
3], dtype=g4.idtype),\n            metapath=metapath,\n            prob=\"p\",\n            restart_prob=0.0,\n            return_eids=True,\n        )\n        check_random_walk(g4, metapath, traces, ntypes, \"p\", trace_eids=eids)\n        traces, eids, ntypes = dgl.sampling.random_walk(\n            g4,\n            F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),\n            metapath=metapath,\n            prob=\"p\",\n            restart_prob=F.zeros((6,), F.float32, F.ctx()),\n            return_eids=True,\n        )\n        check_random_walk(g4, metapath, traces, ntypes, \"p\", trace_eids=eids)\n        traces, eids, ntypes = dgl.sampling.random_walk(\n            g4,\n            F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),\n            metapath=metapath + [\"follow\"],\n            prob=\"p\",\n            restart_prob=F.tensor([0, 0, 0, 0, 0, 0, 1], F.float32),\n            return_eids=True,\n        )\n        check_random_walk(\n            g4, metapath, traces[:, :7], ntypes[:7], \"p\", trace_eids=eids\n        )\n        assert (F.asnumpy(traces[:, 7]) == -1).all()\n    finally:\n        for g in (g2, g4):\n            g.unpin_memory_()\n\n\n@pytest.mark.parametrize(\"use_uva\", [True, False])\ndef test_uniform_random_walk(use_uva):\n    if use_uva and F.ctx() == F.cpu():\n        pytest.skip(\"UVA random walk requires a GPU.\")\n    g1 = dgl.heterograph({(\"user\", \"follow\", \"user\"): ([0, 1, 2], [1, 2, 0])})\n    g2 = dgl.heterograph(\n        {(\"user\", \"follow\", \"user\"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])}\n    )\n    g3 = dgl.heterograph(\n        {\n            (\"user\", \"follow\", \"user\"): ([0, 1, 2], [1, 2, 0]),\n            (\"user\", \"view\", \"item\"): ([0, 1, 2], [0, 1, 2]),\n            (\"item\", \"viewed-by\", \"user\"): ([0, 1, 2], [0, 1, 2]),\n        }\n    )\n    g4 = dgl.heterograph(\n        {\n            (\"user\", \"follow\", \"user\"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),\n            (\"user\", 
\"view\", \"item\"): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),\n            (\"item\", \"viewed-by\", \"user\"): (\n                [0, 1, 1, 2, 2, 1],\n                [0, 0, 1, 2, 3, 3],\n            ),\n        }\n    )\n\n    if use_uva:\n        for g in (g1, g2, g3, g4):\n            g.create_formats_()\n            g.pin_memory_()\n    elif F._default_context_str == \"gpu\":\n        g1 = g1.to(F.ctx())\n        g2 = g2.to(F.ctx())\n        g3 = g3.to(F.ctx())\n        g4 = g4.to(F.ctx())\n\n    try:\n        traces, eids, ntypes = dgl.sampling.random_walk(\n            g1,\n            F.tensor([0, 1, 2, 0, 1, 2], dtype=g1.idtype),\n            length=4,\n            return_eids=True,\n        )\n        check_random_walk(g1, [\"follow\"] * 4, traces, ntypes, trace_eids=eids)\n        if F._default_context_str == \"cpu\":\n            with pytest.raises(dgl.DGLError):\n                dgl.sampling.random_walk(\n                    g1,\n                    F.tensor([0, 1, 2, 10], dtype=g1.idtype),\n                    length=4,\n                    return_eids=True,\n                )\n        traces, eids, ntypes = dgl.sampling.random_walk(\n            g1,\n            F.tensor([0, 1, 2, 0, 1, 2], dtype=g1.idtype),\n            length=4,\n            restart_prob=0.0,\n            return_eids=True,\n        )\n        check_random_walk(g1, [\"follow\"] * 4, traces, ntypes, trace_eids=eids)\n        traces, ntypes = dgl.sampling.random_walk(\n            g1,\n            F.tensor([0, 1, 2, 0, 1, 2], dtype=g1.idtype),\n            length=4,\n            restart_prob=F.zeros((4,), F.float32),\n        )\n        check_random_walk(g1, [\"follow\"] * 4, traces, ntypes)\n        traces, ntypes = dgl.sampling.random_walk(\n            g1,\n            F.tensor([0, 1, 2, 0, 1, 2], dtype=g1.idtype),\n            length=5,\n            restart_prob=F.tensor([0, 0, 0, 0, 1], dtype=F.float32),\n        )\n        check_random_walk(\n            g1,\n            
[\"follow\"] * 4,\n            F.slice_axis(traces, 1, 0, 5),\n            F.slice_axis(ntypes, 0, 0, 5),\n        )\n        assert (F.asnumpy(traces)[:, 5] == -1).all()\n\n        traces, eids, ntypes = dgl.sampling.random_walk(\n            g2,\n            F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g2.idtype),\n            length=4,\n            return_eids=True,\n        )\n        check_random_walk(g2, [\"follow\"] * 4, traces, ntypes, trace_eids=eids)\n\n        metapath = [\"follow\", \"view\", \"viewed-by\"] * 2\n        traces, eids, ntypes = dgl.sampling.random_walk(\n            g3,\n            F.tensor([0, 1, 2, 0, 1, 2], dtype=g3.idtype),\n            metapath=metapath,\n            return_eids=True,\n        )\n        check_random_walk(g3, metapath, traces, ntypes, trace_eids=eids)\n\n        metapath = [\"follow\", \"view\", \"viewed-by\"] * 2\n        traces, eids, ntypes = dgl.sampling.random_walk(\n            g4,\n            F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),\n            metapath=metapath,\n            return_eids=True,\n        )\n        check_random_walk(g4, metapath, traces, ntypes, trace_eids=eids)\n\n        traces, eids, ntypes = dgl.sampling.random_walk(\n            g4,\n            F.tensor([0, 1, 2, 0, 1, 2], dtype=g4.idtype),\n            metapath=metapath,\n            return_eids=True,\n        )\n        check_random_walk(g4, metapath, traces, ntypes, trace_eids=eids)\n    finally:  # make sure to unpin the graphs even if some test fails\n        for g in (g1, g2, g3, g4):\n            if g.is_pinned():\n                g.unpin_memory_()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"GPU random walk not implemented\"\n)\ndef test_node2vec():\n    g1 = dgl.heterograph({(\"user\", \"follow\", \"user\"): ([0, 1, 2], [1, 2, 0])})\n    g2 = dgl.heterograph(\n        {(\"user\", \"follow\", \"user\"): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])}\n    )\n    g2.edata[\"p\"] = F.tensor([3, 0, 3, 3, 3], 
dtype=F.float32)\n\n    ntypes = F.zeros((5,), dtype=F.int64)\n\n    traces, eids = dgl.sampling.node2vec_random_walk(\n        g1, [0, 1, 2, 0, 1, 2], 1, 1, 4, return_eids=True\n    )\n    check_random_walk(g1, [\"follow\"] * 4, traces, ntypes, trace_eids=eids)\n\n    traces, eids = dgl.sampling.node2vec_random_walk(\n        g2, [0, 1, 2, 3, 0, 1, 2, 3], 1, 1, 4, prob=\"p\", return_eids=True\n    )\n    check_random_walk(g2, [\"follow\"] * 4, traces, ntypes, \"p\", trace_eids=eids)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"GPU pack traces not implemented\"\n)\ndef test_pack_traces():\n    traces, types = (\n        np.array(\n            [[0, 1, -1, -1, -1, -1, -1], [0, 1, 1, 3, 0, 0, 0]], dtype=\"int64\"\n        ),\n        np.array([0, 0, 1, 0, 0, 1, 0], dtype=\"int64\"),\n    )\n    traces = F.zerocopy_from_numpy(traces)\n    types = F.zerocopy_from_numpy(types)\n    result = dgl.sampling.pack_traces(traces, types)\n    assert F.array_equal(\n        result[0], F.tensor([0, 1, 0, 1, 1, 3, 0, 0, 0], dtype=F.int64)\n    )\n    assert F.array_equal(\n        result[1], F.tensor([0, 0, 0, 0, 1, 0, 0, 1, 0], dtype=F.int64)\n    )\n    assert F.array_equal(result[2], F.tensor([2, 7], dtype=F.int64))\n    assert F.array_equal(result[3], F.tensor([0, 2], dtype=F.int64))\n\n\n@pytest.mark.parametrize(\"use_uva\", [True, False])\ndef test_pinsage_sampling(use_uva):\n    if use_uva and F.ctx() == F.cpu():\n        pytest.skip(\"UVA sampling requires a GPU.\")\n\n    def _test_sampler(g, sampler, ntype):\n        seeds = F.copy_to(F.tensor([0, 2], dtype=g.idtype), F.ctx())\n        neighbor_g = sampler(seeds)\n        assert neighbor_g.ntypes == [ntype]\n        u, v = neighbor_g.all_edges(form=\"uv\", order=\"eid\")\n        uv = list(zip(F.asnumpy(u).tolist(), F.asnumpy(v).tolist()))\n        assert (1, 0) in uv or (0, 0) in uv\n        assert (2, 2) in uv or (3, 2) in uv\n\n    g = dgl.heterograph(\n        {\n            (\"item\", 
\"bought-by\", \"user\"): (\n                [0, 0, 1, 1, 2, 2, 3, 3],\n                [0, 1, 0, 1, 2, 3, 2, 3],\n            ),\n            (\"user\", \"bought\", \"item\"): (\n                [0, 1, 0, 1, 2, 3, 2, 3],\n                [0, 0, 1, 1, 2, 2, 3, 3],\n            ),\n        }\n    )\n    if use_uva:\n        g.create_formats_()\n        g.pin_memory_()\n    elif F._default_context_str == \"gpu\":\n        g = g.to(F.ctx())\n    try:\n        sampler = dgl.sampling.PinSAGESampler(g, \"item\", \"user\", 4, 0.5, 3, 2)\n        _test_sampler(g, sampler, \"item\")\n        sampler = dgl.sampling.RandomWalkNeighborSampler(\n            g, 4, 0.5, 3, 2, [\"bought-by\", \"bought\"]\n        )\n        _test_sampler(g, sampler, \"item\")\n        sampler = dgl.sampling.RandomWalkNeighborSampler(\n            g,\n            4,\n            0.5,\n            3,\n            2,\n            [(\"item\", \"bought-by\", \"user\"), (\"user\", \"bought\", \"item\")],\n        )\n        _test_sampler(g, sampler, \"item\")\n    finally:\n        if g.is_pinned():\n            g.unpin_memory_()\n\n    g = dgl.graph(([0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 0, 1, 2, 3, 2, 3]))\n    if use_uva:\n        g.create_formats_()\n        g.pin_memory_()\n    elif F._default_context_str == \"gpu\":\n        g = g.to(F.ctx())\n    try:\n        sampler = dgl.sampling.RandomWalkNeighborSampler(g, 4, 0.5, 3, 2)\n        _test_sampler(g, sampler, g.ntypes[0])\n    finally:\n        if g.is_pinned():\n            g.unpin_memory_()\n\n    g = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): ([0, 2], [1, 3]),\n            (\"B\", \"BC\", \"C\"): ([1, 3], [2, 1]),\n            (\"C\", \"CA\", \"A\"): ([2, 1], [0, 2]),\n        }\n    )\n    if use_uva:\n        g.create_formats_()\n        g.pin_memory_()\n    elif F._default_context_str == \"gpu\":\n        g = g.to(F.ctx())\n    try:\n        sampler = dgl.sampling.RandomWalkNeighborSampler(\n            g, 4, 0.5, 3, 2, 
[\"AB\", \"BC\", \"CA\"]\n        )\n        _test_sampler(g, sampler, \"A\")\n    finally:\n        if g.is_pinned():\n            g.unpin_memory_()\n\n\ndef _gen_neighbor_sampling_test_graph(hypersparse, reverse):\n    if hypersparse:\n        # should crash if allocated a CSR\n        card = 1 << 50\n        num_nodes_dict = {\"user\": card, \"game\": card, \"coin\": card}\n    else:\n        card = None\n        num_nodes_dict = None\n\n    if reverse:\n        g = dgl.heterograph(\n            {\n                (\"user\", \"follow\", \"user\"): (\n                    [0, 0, 0, 1, 1, 1, 2],\n                    [1, 2, 3, 0, 2, 3, 0],\n                )\n            },\n            {\"user\": card if card is not None else 4},\n        )\n        g = g.to(F.ctx())\n        g.edata[\"prob\"] = F.tensor(\n            [0.5, 0.5, 0.0, 0.5, 0.5, 0.0, 1.0], dtype=F.float32\n        )\n        g.edata[\"mask\"] = F.tensor([True, True, False, True, True, False, True])\n        hg = dgl.heterograph(\n            {\n                (\"user\", \"follow\", \"user\"): (\n                    [0, 0, 0, 1, 1, 1, 2],\n                    [1, 2, 3, 0, 2, 3, 0],\n                ),\n                (\"game\", \"play\", \"user\"): ([0, 1, 2, 2], [0, 0, 1, 3]),\n                (\"user\", \"liked-by\", \"game\"): (\n                    [0, 1, 2, 0, 3, 0],\n                    [2, 2, 2, 1, 1, 0],\n                ),\n                (\"coin\", \"flips\", \"user\"): ([0, 0, 0, 0], [0, 1, 2, 3]),\n            },\n            num_nodes_dict,\n        )\n        hg = hg.to(F.ctx())\n    else:\n        g = dgl.heterograph(\n            {\n                (\"user\", \"follow\", \"user\"): (\n                    [1, 2, 3, 0, 2, 3, 0],\n                    [0, 0, 0, 1, 1, 1, 2],\n                )\n            },\n            {\"user\": card if card is not None else 4},\n        )\n        g = g.to(F.ctx())\n        g.edata[\"prob\"] = F.tensor(\n            [0.5, 0.5, 0.0, 0.5, 0.5, 0.0, 
1.0], dtype=F.float32\n        )\n        g.edata[\"mask\"] = F.tensor([True, True, False, True, True, False, True])\n        hg = dgl.heterograph(\n            {\n                (\"user\", \"follow\", \"user\"): (\n                    [1, 2, 3, 0, 2, 3, 0],\n                    [0, 0, 0, 1, 1, 1, 2],\n                ),\n                (\"user\", \"play\", \"game\"): ([0, 0, 1, 3], [0, 1, 2, 2]),\n                (\"game\", \"liked-by\", \"user\"): (\n                    [2, 2, 2, 1, 1, 0],\n                    [0, 1, 2, 0, 3, 0],\n                ),\n                (\"user\", \"flips\", \"coin\"): ([0, 1, 2, 3], [0, 0, 0, 0]),\n            },\n            num_nodes_dict,\n        )\n        hg = hg.to(F.ctx())\n    hg.edges[\"follow\"].data[\"prob\"] = F.tensor(\n        [0.5, 0.5, 0.0, 0.5, 0.5, 0.0, 1.0], dtype=F.float32\n    )\n    hg.edges[\"follow\"].data[\"mask\"] = F.tensor(\n        [True, True, False, True, True, False, True]\n    )\n    hg.edges[\"play\"].data[\"prob\"] = F.tensor(\n        [0.8, 0.5, 0.5, 0.5], dtype=F.float32\n    )\n    # Leave out the mask of play and liked-by since all of them are True anyway.\n    hg.edges[\"liked-by\"].data[\"prob\"] = F.tensor(\n        [0.3, 0.5, 0.2, 0.5, 0.1, 0.1], dtype=F.float32\n    )\n\n    return g, hg\n\n\ndef _gen_neighbor_topk_test_graph(hypersparse, reverse):\n    if hypersparse:\n        # should crash if allocated a CSR\n        card = 1 << 50\n    else:\n        card = None\n\n    if reverse:\n        g = dgl.heterograph(\n            {\n                (\"user\", \"follow\", \"user\"): (\n                    [0, 0, 0, 1, 1, 1, 2],\n                    [1, 2, 3, 0, 2, 3, 0],\n                )\n            }\n        )\n        g.edata[\"weight\"] = F.tensor(\n            [0.5, 0.3, 0.0, -5.0, 22.0, 0.0, 1.0], dtype=F.float32\n        )\n        hg = dgl.heterograph(\n            {\n                (\"user\", \"follow\", \"user\"): (\n                    [0, 0, 0, 1, 1, 1, 2],\n                 
   [1, 2, 3, 0, 2, 3, 0],\n                ),\n                (\"game\", \"play\", \"user\"): ([0, 1, 2, 2], [0, 0, 1, 3]),\n                (\"user\", \"liked-by\", \"game\"): (\n                    [0, 1, 2, 0, 3, 0],\n                    [2, 2, 2, 1, 1, 0],\n                ),\n                (\"coin\", \"flips\", \"user\"): ([0, 0, 0, 0], [0, 1, 2, 3]),\n            }\n        )\n    else:\n        g = dgl.heterograph(\n            {\n                (\"user\", \"follow\", \"user\"): (\n                    [1, 2, 3, 0, 2, 3, 0],\n                    [0, 0, 0, 1, 1, 1, 2],\n                )\n            }\n        )\n        g.edata[\"weight\"] = F.tensor(\n            [0.5, 0.3, 0.0, -5.0, 22.0, 0.0, 1.0], dtype=F.float32\n        )\n        hg = dgl.heterograph(\n            {\n                (\"user\", \"follow\", \"user\"): (\n                    [1, 2, 3, 0, 2, 3, 0],\n                    [0, 0, 0, 1, 1, 1, 2],\n                ),\n                (\"user\", \"play\", \"game\"): ([0, 0, 1, 3], [0, 1, 2, 2]),\n                (\"game\", \"liked-by\", \"user\"): (\n                    [2, 2, 2, 1, 1, 0],\n                    [0, 1, 2, 0, 3, 0],\n                ),\n                (\"user\", \"flips\", \"coin\"): ([0, 1, 2, 3], [0, 0, 0, 0]),\n            }\n        )\n    hg.edges[\"follow\"].data[\"weight\"] = F.tensor(\n        [0.5, 0.3, 0.0, -5.0, 22.0, 0.0, 1.0], dtype=F.float32\n    )\n    hg.edges[\"play\"].data[\"weight\"] = F.tensor(\n        [0.8, 0.5, 0.4, 0.5], dtype=F.float32\n    )\n    hg.edges[\"liked-by\"].data[\"weight\"] = F.tensor(\n        [0.3, 0.5, 0.2, 0.5, 0.1, 0.1], dtype=F.float32\n    )\n    hg.edges[\"flips\"].data[\"weight\"] = F.tensor(\n        [10, 2, 13, -1], dtype=F.float32\n    )\n    return g, hg\n\n\ndef _test_sample_neighbors(hypersparse, prob, fused):\n    g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False)\n\n    def _test1(p, replace):\n        subg = sample_neighbors_fusing_mode[fused](\n            g, 
[0, 1], -1, prob=p, replace=replace\n        )\n        if not fused:\n            assert subg.num_nodes() == g.num_nodes()\n        u, v = subg.edges()\n        if fused:\n            u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]\n        u_ans, v_ans, e_ans = g.in_edges([0, 1], form=\"all\")\n        if p is not None:\n            emask = F.gather_row(g.edata[p], e_ans)\n            if p == \"prob\":\n                emask = emask != 0\n            u_ans = F.boolean_mask(u_ans, emask)\n            v_ans = F.boolean_mask(v_ans, emask)\n        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))\n        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))\n        assert uv == uv_ans\n\n        for i in range(10):\n            subg = sample_neighbors_fusing_mode[fused](\n                g, [0, 1], 2, prob=p, replace=replace\n            )\n            if not fused:\n                assert subg.num_nodes() == g.num_nodes()\n\n            assert subg.num_edges() == 4\n            u, v = subg.edges()\n            if fused:\n                u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]\n\n            assert set(F.asnumpy(F.unique(v))) == {0, 1}\n            assert F.array_equal(\n                F.astype(g.has_edges_between(u, v), F.int64),\n                F.ones((4,), dtype=F.int64),\n            )\n            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])\n            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n            if not replace:\n                # check no duplication\n                assert len(edge_set) == 4\n            if p is not None:\n                assert not (3, 0) in edge_set\n                assert not (3, 1) in edge_set\n\n    _test1(prob, True)  # w/ replacement, uniform\n    _test1(prob, False)  # w/o replacement, uniform\n\n    def _test2(p, replace):  # fanout > #neighbors\n        subg = sample_neighbors_fusing_mode[fused](\n            g, [0, 2], -1, prob=p, replace=replace\n        )\n      
  if not fused:\n            assert subg.num_nodes() == g.num_nodes()\n        u, v = subg.edges()\n        if fused:\n            u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]\n        u_ans, v_ans, e_ans = g.in_edges([0, 2], form=\"all\")\n        if p is not None:\n            emask = F.gather_row(g.edata[p], e_ans)\n            if p == \"prob\":\n                emask = emask != 0\n            u_ans = F.boolean_mask(u_ans, emask)\n            v_ans = F.boolean_mask(v_ans, emask)\n        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))\n        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))\n        assert uv == uv_ans\n\n        for i in range(10):\n            subg = sample_neighbors_fusing_mode[fused](\n                g, [0, 2], 2, prob=p, replace=replace\n            )\n            if not fused:\n                assert subg.num_nodes() == g.num_nodes()\n            num_edges = 4 if replace else 3\n            assert subg.num_edges() == num_edges\n            u, v = subg.edges()\n            if fused:\n                u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]\n            assert set(F.asnumpy(F.unique(v))) == {0, 2}\n            assert F.array_equal(\n                F.astype(g.has_edges_between(u, v), F.int64),\n                F.ones((num_edges,), dtype=F.int64),\n            )\n            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])\n            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n            if not replace:\n                # check no duplication\n                assert len(edge_set) == num_edges\n            if p is not None:\n                assert not (3, 0) in edge_set\n\n    _test2(prob, True)  # w/ replacement, uniform\n    _test2(prob, False)  # w/o replacement, uniform\n\n    def _test3(p, replace):\n        subg = sample_neighbors_fusing_mode[fused](\n            hg, {\"user\": [0, 1], \"game\": 0}, -1, prob=p, replace=replace\n        )\n        if not fused:\n            
assert len(subg.ntypes) == 3\n        assert len(subg.srctypes) == 3\n        assert len(subg.dsttypes) == 3\n        assert len(subg.etypes) == 4\n        assert subg[\"follow\"].num_edges() == 6 if p is None else 4\n        assert subg[\"play\"].num_edges() == 1\n        assert subg[\"liked-by\"].num_edges() == 4\n        assert subg[\"flips\"].num_edges() == 0\n\n        for i in range(10):\n            subg = sample_neighbors_fusing_mode[fused](\n                hg, {\"user\": [0, 1], \"game\": 0}, 2, prob=p, replace=replace\n            )\n            if not fused:\n                assert len(subg.ntypes) == 3\n            assert len(subg.srctypes) == 3\n            assert len(subg.dsttypes) == 3\n            assert len(subg.etypes) == 4\n            assert subg[\"follow\"].num_edges() == 4\n            assert subg[\"play\"].num_edges() == 2 if replace else 1\n            assert subg[\"liked-by\"].num_edges() == 4 if replace else 3\n            assert subg[\"flips\"].num_edges() == 0\n\n    _test3(prob, True)  # w/ replacement, uniform\n    _test3(prob, False)  # w/o replacement, uniform\n\n    # test different fanouts for different relations\n    for i in range(10):\n        subg = sample_neighbors_fusing_mode[fused](\n            hg,\n            {\"user\": [0, 1], \"game\": 0, \"coin\": 0},\n            {\"follow\": 1, \"play\": 2, \"liked-by\": 0, \"flips\": -1},\n            replace=True,\n        )\n        if not fused:\n            assert len(subg.ntypes) == 3\n        assert len(subg.srctypes) == 3\n        assert len(subg.dsttypes) == 3\n        assert len(subg.etypes) == 4\n        assert subg[\"follow\"].num_edges() == 2\n        assert subg[\"play\"].num_edges() == 2\n        assert subg[\"liked-by\"].num_edges() == 0\n        assert subg[\"flips\"].num_edges() == 4\n\n\ndef _test_sample_labors(hypersparse, prob):\n    g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False)\n\n    # test with seed nodes [0, 1]\n    def _test1(p):\n        
subg = dgl.sampling.sample_labors(g, [0, 1], -1, prob=p)[0]\n        assert subg.num_nodes() == g.num_nodes()\n        u, v = subg.edges()\n        u_ans, v_ans, e_ans = g.in_edges([0, 1], form=\"all\")\n        if p is not None:\n            emask = F.gather_row(g.edata[p], e_ans)\n            if p == \"prob\":\n                emask = emask != 0\n            u_ans = F.boolean_mask(u_ans, emask)\n            v_ans = F.boolean_mask(v_ans, emask)\n        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))\n        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))\n        assert uv == uv_ans\n\n        for i in range(10):\n            subg = dgl.sampling.sample_labors(g, [0, 1], 2, prob=p)[0]\n            assert subg.num_nodes() == g.num_nodes()\n            assert subg.num_edges() >= 0\n            u, v = subg.edges()\n            assert set(F.asnumpy(F.unique(v))).issubset({0, 1})\n            assert F.array_equal(\n                F.astype(g.has_edges_between(u, v), F.int64),\n                F.ones((subg.num_edges(),), dtype=F.int64),\n            )\n            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])\n            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n            # check no duplication\n            assert len(edge_set) == subg.num_edges()\n            if p is not None:\n                assert not (3, 0) in edge_set\n                assert not (3, 1) in edge_set\n\n    _test1(prob)\n\n    # test with seed nodes [0, 2]\n    def _test2(p):\n        subg = dgl.sampling.sample_labors(g, [0, 2], -1, prob=p)[0]\n        assert subg.num_nodes() == g.num_nodes()\n        u, v = subg.edges()\n        u_ans, v_ans, e_ans = g.in_edges([0, 2], form=\"all\")\n        if p is not None:\n            emask = F.gather_row(g.edata[p], e_ans)\n            if p == \"prob\":\n                emask = emask != 0\n            u_ans = F.boolean_mask(u_ans, emask)\n            v_ans = F.boolean_mask(v_ans, emask)\n        uv = set(zip(F.asnumpy(u), 
F.asnumpy(v)))\n        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))\n        assert uv == uv_ans\n\n        for i in range(10):\n            subg = dgl.sampling.sample_labors(g, [0, 2], 2, prob=p)[0]\n            assert subg.num_nodes() == g.num_nodes()\n            assert subg.num_edges() >= 0\n            u, v = subg.edges()\n            assert set(F.asnumpy(F.unique(v))).issubset({0, 2})\n            assert F.array_equal(\n                F.astype(g.has_edges_between(u, v), F.int64),\n                F.ones((subg.num_edges(),), dtype=F.int64),\n            )\n            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])\n            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n            # check no duplication\n            assert len(edge_set) == subg.num_edges()\n            if p is not None:\n                assert not (3, 0) in edge_set\n\n    _test2(prob)\n\n    # test with heterogenous seed nodes\n    def _test3(p):\n        subg = dgl.sampling.sample_labors(\n            hg, {\"user\": [0, 1], \"game\": 0}, -1, prob=p\n        )[0]\n        assert len(subg.ntypes) == 3\n        assert len(subg.etypes) == 4\n        assert subg[\"follow\"].num_edges() == 6 if p is None else 4\n        assert subg[\"play\"].num_edges() == 1\n        assert subg[\"liked-by\"].num_edges() == 4\n        assert subg[\"flips\"].num_edges() == 0\n\n        for i in range(10):\n            subg = dgl.sampling.sample_labors(\n                hg, {\"user\": [0, 1], \"game\": 0}, 2, prob=p\n            )[0]\n            assert len(subg.ntypes) == 3\n            assert len(subg.etypes) == 4\n            assert subg[\"follow\"].num_edges() >= 0\n            assert subg[\"play\"].num_edges() >= 0\n            assert subg[\"liked-by\"].num_edges() >= 0\n            assert subg[\"flips\"].num_edges() >= 0\n\n    _test3(prob)\n\n    # test different fanouts for different relations\n    for i in range(10):\n        subg = dgl.sampling.sample_labors(\n     
       hg,\n            {\"user\": [0, 1], \"game\": 0, \"coin\": 0},\n            {\"follow\": 1, \"play\": 2, \"liked-by\": 0, \"flips\": g.num_nodes()},\n        )[0]\n        assert len(subg.ntypes) == 3\n        assert len(subg.etypes) == 4\n        assert subg[\"follow\"].num_edges() >= 0\n        assert subg[\"play\"].num_edges() >= 0\n        assert subg[\"liked-by\"].num_edges() == 0\n        assert subg[\"flips\"].num_edges() == 4\n\n\ndef _test_sample_neighbors_outedge(hypersparse, fused):\n    g, hg = _gen_neighbor_sampling_test_graph(hypersparse, True)\n\n    def _test1(p, replace):\n        subg = sample_neighbors_fusing_mode[fused](\n            g, [0, 1], -1, prob=p, replace=replace, edge_dir=\"out\"\n        )\n        if not fused:\n            assert subg.num_nodes() == g.num_nodes()\n\n        u, v = subg.edges()\n        if fused:\n            u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]\n        u_ans, v_ans, e_ans = g.out_edges([0, 1], form=\"all\")\n        if p is not None:\n            emask = F.gather_row(g.edata[p], e_ans)\n            if p == \"prob\":\n                emask = emask != 0\n            u_ans = F.boolean_mask(u_ans, emask)\n            v_ans = F.boolean_mask(v_ans, emask)\n        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))\n        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))\n        assert uv == uv_ans\n\n        for i in range(10):\n            subg = sample_neighbors_fusing_mode[fused](\n                g, [0, 1], 2, prob=p, replace=replace, edge_dir=\"out\"\n            )\n            if not fused:\n                assert subg.num_nodes() == g.num_nodes()\n            assert subg.num_edges() == 4\n            u, v = subg.edges()\n            if fused:\n                u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]\n            assert set(F.asnumpy(F.unique(u))) == {0, 1}\n            assert F.array_equal(\n                F.astype(g.has_edges_between(u, v), F.int64),\n                
F.ones((4,), dtype=F.int64),\n            )\n            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])\n            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n            if not replace:\n                # check no duplication\n                assert len(edge_set) == 4\n            if p is not None:\n                assert not (0, 3) in edge_set\n                assert not (1, 3) in edge_set\n\n    _test1(None, True)  # w/ replacement, uniform\n    _test1(None, False)  # w/o replacement, uniform\n    _test1(\"prob\", True)  # w/ replacement\n    _test1(\"prob\", False)  # w/o replacement\n\n    def _test2(p, replace):  # fanout > #neighbors\n        subg = sample_neighbors_fusing_mode[fused](\n            g, [0, 2], -1, prob=p, replace=replace, edge_dir=\"out\"\n        )\n        if not fused:\n            assert subg.num_nodes() == g.num_nodes()\n        u, v = subg.edges()\n        if fused:\n            u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]\n        u_ans, v_ans, e_ans = g.out_edges([0, 2], form=\"all\")\n        if p is not None:\n            emask = F.gather_row(g.edata[p], e_ans)\n            if p == \"prob\":\n                emask = emask != 0\n            u_ans = F.boolean_mask(u_ans, emask)\n            v_ans = F.boolean_mask(v_ans, emask)\n        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))\n        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))\n        assert uv == uv_ans\n\n        for i in range(10):\n            subg = sample_neighbors_fusing_mode[fused](\n                g, [0, 2], 2, prob=p, replace=replace, edge_dir=\"out\"\n            )\n            if not fused:\n                assert subg.num_nodes() == g.num_nodes()\n            num_edges = 4 if replace else 3\n            assert subg.num_edges() == num_edges\n            u, v = subg.edges()\n            if fused:\n                u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]\n\n            assert 
set(F.asnumpy(F.unique(u))) == {0, 2}\n            assert F.array_equal(\n                F.astype(g.has_edges_between(u, v), F.int64),\n                F.ones((num_edges,), dtype=F.int64),\n            )\n            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])\n            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n            if not replace:\n                # check no duplication\n                assert len(edge_set) == num_edges\n            if p is not None:\n                assert not (0, 3) in edge_set\n\n    _test2(None, True)  # w/ replacement, uniform\n    _test2(None, False)  # w/o replacement, uniform\n    _test2(\"prob\", True)  # w/ replacement\n    _test2(\"prob\", False)  # w/o replacement\n\n    def _test3(p, replace):\n        subg = sample_neighbors_fusing_mode[fused](\n            hg,\n            {\"user\": [0, 1], \"game\": 0},\n            -1,\n            prob=p,\n            replace=replace,\n            edge_dir=\"out\",\n        )\n\n        if not fused:\n            assert len(subg.ntypes) == 3\n        assert len(subg.srctypes) == 3\n        assert len(subg.dsttypes) == 3\n        assert len(subg.etypes) == 4\n        assert subg[\"follow\"].num_edges() == 6 if p is None else 4\n        assert subg[\"play\"].num_edges() == 1\n        assert subg[\"liked-by\"].num_edges() == 4\n        assert subg[\"flips\"].num_edges() == 0\n\n        for i in range(10):\n            subg = sample_neighbors_fusing_mode[fused](\n                hg,\n                {\"user\": [0, 1], \"game\": 0},\n                2,\n                prob=p,\n                replace=replace,\n                edge_dir=\"out\",\n            )\n            if not fused:\n                assert len(subg.ntypes) == 3\n            assert len(subg.srctypes) == 3\n            assert len(subg.dsttypes) == 3\n            assert len(subg.etypes) == 4\n            assert subg[\"follow\"].num_edges() == 4\n            assert 
subg[\"play\"].num_edges() == 2 if replace else 1\n            assert subg[\"liked-by\"].num_edges() == 4 if replace else 3\n            assert subg[\"flips\"].num_edges() == 0\n\n    _test3(None, True)  # w/ replacement, uniform\n    _test3(None, False)  # w/o replacement, uniform\n    _test3(\"prob\", True)  # w/ replacement\n    _test3(\"prob\", False)  # w/o replacement\n\n\ndef _test_sample_neighbors_topk(hypersparse):\n    g, hg = _gen_neighbor_topk_test_graph(hypersparse, False)\n\n    def _test1():\n        subg = dgl.sampling.select_topk(g, -1, \"weight\", [0, 1])\n        assert subg.num_nodes() == g.num_nodes()\n        u, v = subg.edges()\n        u_ans, v_ans = subg.in_edges([0, 1])\n        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))\n        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))\n        assert uv == uv_ans\n\n        subg = dgl.sampling.select_topk(g, 2, \"weight\", [0, 1])\n        assert subg.num_nodes() == g.num_nodes()\n        assert subg.num_edges() == 4\n        u, v = subg.edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])\n        assert edge_set == {(2, 0), (1, 0), (2, 1), (3, 1)}\n\n    _test1()\n\n    def _test2():  # k > #neighbors\n        subg = dgl.sampling.select_topk(g, -1, \"weight\", [0, 2])\n        assert subg.num_nodes() == g.num_nodes()\n        u, v = subg.edges()\n        u_ans, v_ans = subg.in_edges([0, 2])\n        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))\n        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))\n        assert uv == uv_ans\n\n        subg = dgl.sampling.select_topk(g, 2, \"weight\", [0, 2])\n        assert subg.num_nodes() == g.num_nodes()\n        assert subg.num_edges() == 3\n        u, v = subg.edges()\n        assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert edge_set == {(2, 0), (1, 0), (0, 
2)}\n\n    _test2()\n\n    def _test3():\n        subg = dgl.sampling.select_topk(\n            hg, 2, \"weight\", {\"user\": [0, 1], \"game\": 0}\n        )\n        assert len(subg.ntypes) == 3\n        assert len(subg.etypes) == 4\n        u, v = subg[\"follow\"].edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert F.array_equal(\n            hg[\"follow\"].edge_ids(u, v), subg[\"follow\"].edata[dgl.EID]\n        )\n        assert edge_set == {(2, 0), (1, 0), (2, 1), (3, 1)}\n        u, v = subg[\"play\"].edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert F.array_equal(\n            hg[\"play\"].edge_ids(u, v), subg[\"play\"].edata[dgl.EID]\n        )\n        assert edge_set == {(0, 0)}\n        u, v = subg[\"liked-by\"].edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert F.array_equal(\n            hg[\"liked-by\"].edge_ids(u, v), subg[\"liked-by\"].edata[dgl.EID]\n        )\n        assert edge_set == {(2, 0), (2, 1), (1, 0)}\n        assert subg[\"flips\"].num_edges() == 0\n\n    _test3()\n\n    # test different k for different relations\n    subg = dgl.sampling.select_topk(\n        hg,\n        {\"follow\": 1, \"play\": 2, \"liked-by\": 0, \"flips\": -1},\n        \"weight\",\n        {\"user\": [0, 1], \"game\": 0, \"coin\": 0},\n    )\n    assert len(subg.ntypes) == 3\n    assert len(subg.etypes) == 4\n    assert subg[\"follow\"].num_edges() == 2\n    assert subg[\"play\"].num_edges() == 1\n    assert subg[\"liked-by\"].num_edges() == 0\n    assert subg[\"flips\"].num_edges() == 4\n\n\ndef _test_sample_neighbors_topk_outedge(hypersparse):\n    g, hg = _gen_neighbor_topk_test_graph(hypersparse, True)\n\n    def _test1():\n        subg = dgl.sampling.select_topk(g, -1, \"weight\", [0, 1], edge_dir=\"out\")\n        assert subg.num_nodes() == g.num_nodes()\n        u, v = subg.edges()\n        u_ans, v_ans = subg.out_edges([0, 1])\n     
   uv = set(zip(F.asnumpy(u), F.asnumpy(v)))\n        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))\n        assert uv == uv_ans\n\n        subg = dgl.sampling.select_topk(g, 2, \"weight\", [0, 1], edge_dir=\"out\")\n        assert subg.num_nodes() == g.num_nodes()\n        assert subg.num_edges() == 4\n        u, v = subg.edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])\n        assert edge_set == {(0, 2), (0, 1), (1, 2), (1, 3)}\n\n    _test1()\n\n    def _test2():  # k > #neighbors\n        subg = dgl.sampling.select_topk(g, -1, \"weight\", [0, 2], edge_dir=\"out\")\n        assert subg.num_nodes() == g.num_nodes()\n        u, v = subg.edges()\n        u_ans, v_ans = subg.out_edges([0, 2])\n        uv = set(zip(F.asnumpy(u), F.asnumpy(v)))\n        uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))\n        assert uv == uv_ans\n\n        subg = dgl.sampling.select_topk(g, 2, \"weight\", [0, 2], edge_dir=\"out\")\n        assert subg.num_nodes() == g.num_nodes()\n        assert subg.num_edges() == 3\n        u, v = subg.edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])\n        assert edge_set == {(0, 2), (0, 1), (2, 0)}\n\n    _test2()\n\n    def _test3():\n        subg = dgl.sampling.select_topk(\n            hg, 2, \"weight\", {\"user\": [0, 1], \"game\": 0}, edge_dir=\"out\"\n        )\n        assert len(subg.ntypes) == 3\n        assert len(subg.etypes) == 4\n        u, v = subg[\"follow\"].edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert F.array_equal(\n            hg[\"follow\"].edge_ids(u, v), subg[\"follow\"].edata[dgl.EID]\n        )\n        assert edge_set == {(0, 2), (0, 1), (1, 2), (1, 3)}\n        u, v = subg[\"play\"].edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        
assert F.array_equal(\n            hg[\"play\"].edge_ids(u, v), subg[\"play\"].edata[dgl.EID]\n        )\n        assert edge_set == {(0, 0)}\n        u, v = subg[\"liked-by\"].edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert F.array_equal(\n            hg[\"liked-by\"].edge_ids(u, v), subg[\"liked-by\"].edata[dgl.EID]\n        )\n        assert edge_set == {(0, 2), (1, 2), (0, 1)}\n        assert subg[\"flips\"].num_edges() == 0\n\n    _test3()\n\n\ndef test_sample_neighbors_noprob():\n    _test_sample_neighbors(False, None, False)\n    if F._default_context_str != \"gpu\" and F.backend_name == \"pytorch\":\n        _test_sample_neighbors(False, None, True)\n    # _test_sample_neighbors(True)\n\n\ndef test_sample_labors_noprob():\n    _test_sample_labors(False, None)\n\n\ndef test_sample_neighbors_prob():\n    _test_sample_neighbors(False, \"prob\", False)\n    if F._default_context_str != \"gpu\" and F.backend_name == \"pytorch\":\n        _test_sample_neighbors(False, \"prob\", True)\n    # _test_sample_neighbors(True)\n\n\ndef test_sample_labors_prob():\n    _test_sample_labors(False, \"prob\")\n\n\ndef test_sample_neighbors_outedge():\n    _test_sample_neighbors_outedge(False, False)\n    if F._default_context_str != \"gpu\" and F.backend_name == \"pytorch\":\n        _test_sample_neighbors_outedge(False, True)\n    # _test_sample_neighbors_outedge(True)\n\n\n@unittest.skipIf(\n    F.backend_name == \"mxnet\", reason=\"MXNet has problem converting bool arrays\"\n)\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"GPU sample neighbors with mask not implemented\",\n)\ndef test_sample_neighbors_mask():\n    _test_sample_neighbors(False, \"mask\", False)\n    if F._default_context_str != \"gpu\" and F.backend_name == \"pytorch\":\n        _test_sample_neighbors(False, \"mask\", True)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"GPU sample neighbors not 
implemented\",\n)\ndef test_sample_neighbors_topk():\n    _test_sample_neighbors_topk(False)\n    # _test_sample_neighbors_topk(True)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"GPU sample neighbors not implemented\",\n)\ndef test_sample_neighbors_topk_outedge():\n    _test_sample_neighbors_topk_outedge(False)\n    # _test_sample_neighbors_topk_outedge(True)\n\n\n@pytest.mark.parametrize(\"fused\", [False, True])\ndef test_sample_neighbors_with_0deg(fused):\n    if fused and (\n        F._default_context_str == \"gpu\" or F.backend_name != \"pytorch\"\n    ):\n        pytest.skip(\"Fused sampling support CPU with backend PyTorch.\")\n    g = dgl.graph(([], []), num_nodes=5).to(F.ctx())\n    sg = sample_neighbors_fusing_mode[fused](\n        g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir=\"in\", replace=False\n    )\n    assert sg.num_edges() == 0\n    sg = sample_neighbors_fusing_mode[fused](\n        g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir=\"in\", replace=True\n    )\n    assert sg.num_edges() == 0\n    sg = sample_neighbors_fusing_mode[fused](\n        g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir=\"out\", replace=False\n    )\n    assert sg.num_edges() == 0\n    sg = sample_neighbors_fusing_mode[fused](\n        g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir=\"out\", replace=True\n    )\n    assert sg.num_edges() == 0\n\n\ndef create_test_graph(num_nodes, num_edges_per_node, bipartite=False):\n    src = np.concatenate(\n        [np.array([i] * num_edges_per_node) for i in range(num_nodes)]\n    )\n    dst = np.concatenate(\n        [\n            np.random.choice(num_nodes, num_edges_per_node, replace=False)\n            for i in range(num_nodes)\n        ]\n    )\n    if bipartite:\n        g = dgl.heterograph({(\"u\", \"e\", \"v\"): (src, dst)})\n    else:\n        g = dgl.graph((src, dst))\n    return g\n\n\ndef create_etype_test_graph(num_nodes, num_edges_per_node, rare_cnt):\n    src = np.concatenate(\n        
[\n            np.random.choice(num_nodes, num_edges_per_node, replace=False)\n            for i in range(num_nodes)\n        ]\n    )\n    dst = np.concatenate(\n        [np.array([i] * num_edges_per_node) for i in range(num_nodes)]\n    )\n\n    minor_src = np.concatenate(\n        [\n            np.random.choice(num_nodes, 2, replace=False)\n            for i in range(num_nodes)\n        ]\n    )\n    minor_dst = np.concatenate([np.array([i] * 2) for i in range(num_nodes)])\n\n    most_zero_src = np.concatenate(\n        [\n            np.random.choice(num_nodes, num_edges_per_node, replace=False)\n            for i in range(rare_cnt)\n        ]\n    )\n    most_zero_dst = np.concatenate(\n        [np.array([i] * num_edges_per_node) for i in range(rare_cnt)]\n    )\n\n    g = dgl.heterograph(\n        {\n            (\"v\", \"e_major\", \"u\"): (src, dst),\n            (\"u\", \"e_major_rev\", \"v\"): (dst, src),\n            (\"v2\", \"e_minor\", \"u\"): (minor_src, minor_dst),\n            (\"v2\", \"most_zero\", \"u\"): (most_zero_src, most_zero_dst),\n            (\"u\", \"e_minor_rev\", \"v2\"): (minor_dst, minor_src),\n        }\n    )\n    for etype in g.etypes:\n        prob = np.random.rand(g.num_edges(etype))\n        prob[prob > 0.2] = 0\n        g.edges[etype].data[\"p\"] = F.zerocopy_from_numpy(prob)\n        g.edges[etype].data[\"mask\"] = F.zerocopy_from_numpy(prob != 0)\n\n    return g\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"GPU sample neighbors not implemented\",\n)\ndef test_sample_neighbors_biased_homogeneous():\n    g = create_test_graph(100, 30)\n\n    def check_num(nodes, tag):\n        nodes, tag = F.asnumpy(nodes), F.asnumpy(tag)\n        cnt = [sum(tag[nodes] == i) for i in range(4)]\n        # No tag 0\n        assert cnt[0] == 0\n\n        # very rare tag 1\n        assert cnt[2] > 2 * cnt[1]\n        assert cnt[3] > 2 * cnt[1]\n\n    tag = F.tensor(np.random.choice(4, 100))\n    bias = F.tensor([0, 
0.1, 10, 10], dtype=F.float32)\n    # inedge / without replacement\n    g_sorted = dgl.sort_csc_by_tag(g, tag)\n    for _ in range(5):\n        subg = dgl.sampling.sample_neighbors_biased(\n            g_sorted, g.nodes(), 5, bias, replace=False\n        )\n        check_num(subg.edges()[0], tag)\n        u, v = subg.edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert len(edge_set) == subg.num_edges()\n\n    # inedge / with replacement\n    for _ in range(5):\n        subg = dgl.sampling.sample_neighbors_biased(\n            g_sorted, g.nodes(), 5, bias, replace=True\n        )\n        check_num(subg.edges()[0], tag)\n\n    # outedge / without replacement\n    g_sorted = dgl.sort_csr_by_tag(g, tag)\n    for _ in range(5):\n        subg = dgl.sampling.sample_neighbors_biased(\n            g_sorted, g.nodes(), 5, bias, edge_dir=\"out\", replace=False\n        )\n        check_num(subg.edges()[1], tag)\n        u, v = subg.edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert len(edge_set) == subg.num_edges()\n\n    # outedge / with replacement\n    for _ in range(5):\n        subg = dgl.sampling.sample_neighbors_biased(\n            g_sorted, g.nodes(), 5, bias, edge_dir=\"out\", replace=True\n        )\n        check_num(subg.edges()[1], tag)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"GPU sample neighbors not implemented\",\n)\ndef test_sample_neighbors_biased_bipartite():\n    g = create_test_graph(100, 30, True)\n    num_dst = g.num_dst_nodes()\n    bias = F.tensor([0, 0.01, 10, 10], dtype=F.float32)\n\n    def check_num(nodes, tag):\n        nodes, tag = F.asnumpy(nodes), F.asnumpy(tag)\n        cnt = [sum(tag[nodes] == i) for i in range(4)]\n        # No tag 0\n        assert cnt[0] == 0\n\n        # very rare tag 1\n        assert cnt[2] > 2 * cnt[1]\n        assert cnt[3] > 2 * cnt[1]\n\n    # inedge / without replacement\n    tag = 
F.tensor(np.random.choice(4, 100))\n    g_sorted = dgl.sort_csc_by_tag(g, tag)\n    for _ in range(5):\n        subg = dgl.sampling.sample_neighbors_biased(\n            g_sorted, g.dstnodes(), 5, bias, replace=False\n        )\n        check_num(subg.edges()[0], tag)\n        u, v = subg.edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert len(edge_set) == subg.num_edges()\n\n    # inedge / with replacement\n    for _ in range(5):\n        subg = dgl.sampling.sample_neighbors_biased(\n            g_sorted, g.dstnodes(), 5, bias, replace=True\n        )\n        check_num(subg.edges()[0], tag)\n\n    # outedge / without replacement\n    tag = F.tensor(np.random.choice(4, num_dst))\n    g_sorted = dgl.sort_csr_by_tag(g, tag)\n    for _ in range(5):\n        subg = dgl.sampling.sample_neighbors_biased(\n            g_sorted, g.srcnodes(), 5, bias, edge_dir=\"out\", replace=False\n        )\n        check_num(subg.edges()[1], tag)\n        u, v = subg.edges()\n        edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n        assert len(edge_set) == subg.num_edges()\n\n    # outedge / with replacement\n    for _ in range(5):\n        subg = dgl.sampling.sample_neighbors_biased(\n            g_sorted, g.srcnodes(), 5, bias, edge_dir=\"out\", replace=True\n        )\n        check_num(subg.edges()[1], tag)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"GPU sample neighbors not implemented\",\n)\n@unittest.skipIf(\n    F.backend_name == \"mxnet\", reason=\"MXNet has problem converting bool arrays\"\n)\n@pytest.mark.parametrize(\"format_\", [\"coo\", \"csr\", \"csc\"])\n@pytest.mark.parametrize(\"direction\", [\"in\", \"out\"])\n@pytest.mark.parametrize(\"replace\", [False, True])\ndef test_sample_neighbors_etype_homogeneous(format_, direction, replace):\n    num_nodes = 100\n    rare_cnt = 4\n    g = create_etype_test_graph(100, 30, rare_cnt)\n    h_g = dgl.to_homogeneous(g, edata=[\"p\", 
\"mask\"])\n    h_g_etype = F.asnumpy(h_g.edata[dgl.ETYPE])\n    h_g_offset = np.cumsum(np.insert(np.bincount(h_g_etype), 0, 0)).tolist()\n    sg = g.edge_subgraph(g.edata[\"mask\"], relabel_nodes=False)\n    h_sg = h_g.edge_subgraph(h_g.edata[\"mask\"], relabel_nodes=False)\n    h_sg_etype = F.asnumpy(h_sg.edata[dgl.ETYPE])\n    h_sg_offset = np.cumsum(np.insert(np.bincount(h_sg_etype), 0, 0)).tolist()\n\n    seed_ntype = g.get_ntype_id(\"u\")\n    seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype)\n    fanouts = F.tensor([6, 5, 4, 3, 2], dtype=F.int64)\n\n    def check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction):\n        src, dst = subg.edges()\n        all_etype_array = F.asnumpy(h_g.edata[dgl.ETYPE])\n        num_etypes = all_etype_array.max() + 1\n        etype_array = F.asnumpy(subg.edata[dgl.ETYPE])\n        src = F.asnumpy(src)\n        dst = F.asnumpy(dst)\n        fanouts = F.asnumpy(fanouts)\n\n        all_src = F.asnumpy(all_src)\n        all_dst = F.asnumpy(all_dst)\n\n        src_per_etype = []\n        dst_per_etype = []\n        all_src_per_etype = []\n        all_dst_per_etype = []\n        for etype in range(num_etypes):\n            src_per_etype.append(src[etype_array == etype])\n            dst_per_etype.append(dst[etype_array == etype])\n            all_src_per_etype.append(all_src[all_etype_array == etype])\n            all_dst_per_etype.append(all_dst[all_etype_array == etype])\n\n        if replace:\n            if direction == \"in\":\n                in_degree_per_etype = [np.bincount(d) for d in dst_per_etype]\n                for etype in range(len(fanouts)):\n                    in_degree = in_degree_per_etype[etype]\n                    fanout = fanouts[etype]\n                    ans = np.zeros_like(in_degree)\n                    if len(in_degree) > 0:\n                        ans[all_dst_per_etype[etype]] = fanout\n                    assert np.all(in_degree == ans)\n            else:\n                
out_degree_per_etype = [np.bincount(s) for s in src_per_etype]\n                for etype in range(len(fanouts)):\n                    out_degree = out_degree_per_etype[etype]\n                    fanout = fanouts[etype]\n                    ans = np.zeros_like(out_degree)\n                    if len(out_degree) > 0:\n                        ans[all_src_per_etype[etype]] = fanout\n                    assert np.all(out_degree == ans)\n        else:\n            if direction == \"in\":\n                for v in set(dst):\n                    u = src[dst == v]\n                    et = etype_array[dst == v]\n                    all_u = all_src[all_dst == v]\n                    all_et = all_etype_array[all_dst == v]\n                    for etype in set(et):\n                        u_etype = set(u[et == etype])\n                        all_u_etype = set(all_u[all_et == etype])\n                        assert (len(u_etype) == fanouts[etype]) or (\n                            u_etype == all_u_etype\n                        )\n            else:\n                for u in set(src):\n                    v = dst[src == u]\n                    et = etype_array[src == u]\n                    all_v = all_dst[all_src == u]\n                    all_et = all_etype_array[all_src == u]\n                    for etype in set(et):\n                        v_etype = set(v[et == etype])\n                        all_v_etype = set(all_v[all_et == etype])\n                        assert (len(v_etype) == fanouts[etype]) or (\n                            v_etype == all_v_etype\n                        )\n\n    all_src, all_dst = h_g.edges()\n    all_sub_src, all_sub_dst = h_sg.edges()\n    h_g = h_g.formats(format_)\n    if (direction, format_) in [(\"in\", \"csr\"), (\"out\", \"csc\")]:\n        h_g = h_g.formats([\"csc\", \"csr\", \"coo\"])\n    for _ in range(5):\n        subg = dgl.sampling.sample_etype_neighbors(\n            h_g, seeds, h_g_offset, fanouts, replace=replace, 
edge_dir=direction\n        )\n        check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction)\n\n        p = [g.edges[etype].data[\"p\"] for etype in g.etypes]\n        subg = dgl.sampling.sample_etype_neighbors(\n            h_g,\n            seeds,\n            h_g_offset,\n            fanouts,\n            replace=replace,\n            edge_dir=direction,\n            prob=p,\n        )\n        check_num(\n            h_sg, all_sub_src, all_sub_dst, subg, replace, fanouts, direction\n        )\n\n        p = [g.edges[etype].data[\"mask\"] for etype in g.etypes]\n        subg = dgl.sampling.sample_etype_neighbors(\n            h_g,\n            seeds,\n            h_g_offset,\n            fanouts,\n            replace=replace,\n            edge_dir=direction,\n            prob=p,\n        )\n        check_num(\n            h_sg, all_sub_src, all_sub_dst, subg, replace, fanouts, direction\n        )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"GPU sample neighbors not implemented\",\n)\n@unittest.skipIf(\n    F.backend_name == \"mxnet\", reason=\"MXNet has problem converting bool arrays\"\n)\n@pytest.mark.parametrize(\"format_\", [\"csr\", \"csc\"])\n@pytest.mark.parametrize(\"direction\", [\"in\", \"out\"])\ndef test_sample_neighbors_etype_sorted_homogeneous(format_, direction):\n    rare_cnt = 4\n    g = create_etype_test_graph(100, 30, rare_cnt)\n    h_g = dgl.to_homogeneous(g)\n    seed_ntype = g.get_ntype_id(\"u\")\n    seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype)\n    fanouts = F.tensor([6, 5, -1, 3, 2], dtype=F.int64)\n    h_g = h_g.formats(format_)\n    if (direction, format_) in [(\"in\", \"csr\"), (\"out\", \"csc\")]:\n        h_g = h_g.formats([\"csc\", \"csr\", \"coo\"])\n\n    if direction == \"in\":\n        h_g = dgl.sort_csc_by_tag(h_g, h_g.edata[dgl.ETYPE], tag_type=\"edge\")\n    else:\n        h_g = dgl.sort_csr_by_tag(h_g, h_g.edata[dgl.ETYPE], tag_type=\"edge\")\n    # shuffle\n    
h_g_etype = F.asnumpy(h_g.edata[dgl.ETYPE])\n    h_g_offset = np.cumsum(np.insert(np.bincount(h_g_etype), 0, 0)).tolist()\n    sg = dgl.sampling.sample_etype_neighbors(\n        h_g, seeds, h_g_offset, fanouts, edge_dir=direction, etype_sorted=True\n    )\n\n\n@pytest.mark.parametrize(\"dtype\", [\"int32\", \"int64\"])\n@pytest.mark.parametrize(\"fused\", [False, True])\ndef test_sample_neighbors_exclude_edges_heteroG(dtype, fused):\n    if fused and (\n        F._default_context_str == \"gpu\" or F.backend_name != \"pytorch\"\n    ):\n        pytest.skip(\"Fused sampling support CPU with backend PyTorch.\")\n    d_i_d_u_nodes = F.zerocopy_from_numpy(\n        np.unique(np.random.randint(300, size=100, dtype=dtype))\n    )\n    d_i_d_v_nodes = F.zerocopy_from_numpy(\n        np.random.randint(25, size=d_i_d_u_nodes.shape, dtype=dtype)\n    )\n    d_i_g_u_nodes = F.zerocopy_from_numpy(\n        np.unique(np.random.randint(300, size=100, dtype=dtype))\n    )\n    d_i_g_v_nodes = F.zerocopy_from_numpy(\n        np.random.randint(25, size=d_i_g_u_nodes.shape, dtype=dtype)\n    )\n    d_t_d_u_nodes = F.zerocopy_from_numpy(\n        np.unique(np.random.randint(300, size=100, dtype=dtype))\n    )\n    d_t_d_v_nodes = F.zerocopy_from_numpy(\n        np.random.randint(25, size=d_t_d_u_nodes.shape, dtype=dtype)\n    )\n\n    g = dgl.heterograph(\n        {\n            (\"drug\", \"interacts\", \"drug\"): (d_i_d_u_nodes, d_i_d_v_nodes),\n            (\"drug\", \"interacts\", \"gene\"): (d_i_g_u_nodes, d_i_g_v_nodes),\n            (\"drug\", \"treats\", \"disease\"): (d_t_d_u_nodes, d_t_d_v_nodes),\n        }\n    ).to(F.ctx())\n\n    (U, V, EID) = (0, 1, 2)\n\n    nd_b_idx = np.random.randint(low=1, high=24, dtype=dtype)\n    nd_e_idx = np.random.randint(low=25, high=49, dtype=dtype)\n    did_b_idx = np.random.randint(low=1, high=24, dtype=dtype)\n    did_e_idx = np.random.randint(low=25, high=49, dtype=dtype)\n    sampled_amount = np.random.randint(low=1, high=10, 
dtype=dtype)\n\n    drug_i_drug_edges = g.all_edges(\n        form=\"all\", etype=(\"drug\", \"interacts\", \"drug\")\n    )\n    excluded_d_i_d_edges = drug_i_drug_edges[EID][did_b_idx:did_e_idx]\n    sampled_drug_node = drug_i_drug_edges[V][nd_b_idx:nd_e_idx]\n    did_excluded_nodes_U = drug_i_drug_edges[U][did_b_idx:did_e_idx]\n    did_excluded_nodes_V = drug_i_drug_edges[V][did_b_idx:did_e_idx]\n\n    nd_b_idx = np.random.randint(low=1, high=24, dtype=dtype)\n    nd_e_idx = np.random.randint(low=25, high=49, dtype=dtype)\n    dig_b_idx = np.random.randint(low=1, high=24, dtype=dtype)\n    dig_e_idx = np.random.randint(low=25, high=49, dtype=dtype)\n    drug_i_gene_edges = g.all_edges(\n        form=\"all\", etype=(\"drug\", \"interacts\", \"gene\")\n    )\n    excluded_d_i_g_edges = drug_i_gene_edges[EID][dig_b_idx:dig_e_idx]\n    dig_excluded_nodes_U = drug_i_gene_edges[U][dig_b_idx:dig_e_idx]\n    dig_excluded_nodes_V = drug_i_gene_edges[V][dig_b_idx:dig_e_idx]\n    sampled_gene_node = drug_i_gene_edges[V][nd_b_idx:nd_e_idx]\n\n    nd_b_idx = np.random.randint(low=1, high=24, dtype=dtype)\n    nd_e_idx = np.random.randint(low=25, high=49, dtype=dtype)\n    dtd_b_idx = np.random.randint(low=1, high=24, dtype=dtype)\n    dtd_e_idx = np.random.randint(low=25, high=49, dtype=dtype)\n    drug_t_dis_edges = g.all_edges(\n        form=\"all\", etype=(\"drug\", \"treats\", \"disease\")\n    )\n    excluded_d_t_d_edges = drug_t_dis_edges[EID][dtd_b_idx:dtd_e_idx]\n    dtd_excluded_nodes_U = drug_t_dis_edges[U][dtd_b_idx:dtd_e_idx]\n    dtd_excluded_nodes_V = drug_t_dis_edges[V][dtd_b_idx:dtd_e_idx]\n    sampled_disease_node = drug_t_dis_edges[V][nd_b_idx:nd_e_idx]\n    excluded_edges = {\n        (\"drug\", \"interacts\", \"drug\"): excluded_d_i_d_edges,\n        (\"drug\", \"interacts\", \"gene\"): excluded_d_i_g_edges,\n        (\"drug\", \"treats\", \"disease\"): excluded_d_t_d_edges,\n    }\n\n    sg = sample_neighbors_fusing_mode[fused](\n        g,\n        {\n  
          \"drug\": sampled_drug_node,\n            \"gene\": sampled_gene_node,\n            \"disease\": sampled_disease_node,\n        },\n        sampled_amount,\n        exclude_edges=excluded_edges,\n    )\n\n    if fused:\n\n        def contain_edge(g, sg, etype, u, v):\n            # set of subgraph graph edges deduced from original graph\n            org_edges = set(\n                map(\n                    tuple,\n                    np.stack(\n                        g.find_edges(sg.edges[etype].data[dgl.EID], etype),\n                        axis=1,\n                    ),\n                )\n            )\n            # set of excluded edges\n            excluded_edges = set(map(tuple, np.stack((u, v), axis=1)))\n\n            diff_set = org_edges - excluded_edges\n\n            return len(diff_set) != len(org_edges)\n\n        assert not contain_edge(\n            g,\n            sg,\n            (\"drug\", \"interacts\", \"drug\"),\n            did_excluded_nodes_U,\n            did_excluded_nodes_V,\n        )\n        assert not contain_edge(\n            g,\n            sg,\n            (\"drug\", \"interacts\", \"gene\"),\n            dig_excluded_nodes_U,\n            dig_excluded_nodes_V,\n        )\n        assert not contain_edge(\n            g,\n            sg,\n            (\"drug\", \"treats\", \"disease\"),\n            dtd_excluded_nodes_U,\n            dtd_excluded_nodes_V,\n        )\n    else:\n        assert not np.any(\n            F.asnumpy(\n                sg.has_edges_between(\n                    did_excluded_nodes_U,\n                    did_excluded_nodes_V,\n                    etype=(\"drug\", \"interacts\", \"drug\"),\n                )\n            )\n        )\n        assert not np.any(\n            F.asnumpy(\n                sg.has_edges_between(\n                    dig_excluded_nodes_U,\n                    dig_excluded_nodes_V,\n                    etype=(\"drug\", \"interacts\", \"gene\"),\n                )\n  
          )\n        )\n        assert not np.any(\n            F.asnumpy(\n                sg.has_edges_between(\n                    dtd_excluded_nodes_U,\n                    dtd_excluded_nodes_V,\n                    etype=(\"drug\", \"treats\", \"disease\"),\n                )\n            )\n        )\n\n\n@pytest.mark.parametrize(\"dtype\", [\"int32\", \"int64\"])\n@pytest.mark.parametrize(\"fused\", [False, True])\ndef test_sample_neighbors_exclude_edges_homoG(dtype, fused):\n    if fused and (\n        F._default_context_str == \"gpu\" or F.backend_name != \"pytorch\"\n    ):\n        pytest.skip(\"Fused sampling support CPU with backend PyTorch.\")\n    u_nodes = F.zerocopy_from_numpy(\n        np.unique(np.random.randint(300, size=100, dtype=dtype))\n    )\n    v_nodes = F.zerocopy_from_numpy(\n        np.random.randint(25, size=u_nodes.shape, dtype=dtype)\n    )\n    g = dgl.graph((u_nodes, v_nodes)).to(F.ctx())\n\n    (U, V, EID) = (0, 1, 2)\n\n    nd_b_idx = np.random.randint(low=1, high=24, dtype=dtype)\n    nd_e_idx = np.random.randint(low=25, high=49, dtype=dtype)\n    b_idx = np.random.randint(low=1, high=24, dtype=dtype)\n    e_idx = np.random.randint(low=25, high=49, dtype=dtype)\n    sampled_amount = np.random.randint(low=1, high=10, dtype=dtype)\n\n    g_edges = g.all_edges(form=\"all\")\n    excluded_edges = g_edges[EID][b_idx:e_idx]\n    sampled_node = g_edges[V][nd_b_idx:nd_e_idx]\n    excluded_nodes_U = g_edges[U][b_idx:e_idx]\n    excluded_nodes_V = g_edges[V][b_idx:e_idx]\n\n    sg = sample_neighbors_fusing_mode[fused](\n        g, sampled_node, sampled_amount, exclude_edges=excluded_edges\n    )\n    if fused:\n\n        def contain_edge(g, sg, u, v):\n            # set of subgraph graph edges deduced from original graph\n            org_edges = set(\n                map(\n                    tuple,\n                    np.stack(\n                        g.find_edges(sg.edges[\"_E\"].data[dgl.EID]), axis=1\n                    ),\n      
          )\n            )\n            # set of excluded edges\n            excluded_edges = set(map(tuple, np.stack((u, v), axis=1)))\n\n            diff_set = org_edges - excluded_edges\n\n            return len(diff_set) != len(org_edges)\n\n        assert not contain_edge(g, sg, excluded_nodes_U, excluded_nodes_V)\n    else:\n        assert not np.any(\n            F.asnumpy(sg.has_edges_between(excluded_nodes_U, excluded_nodes_V))\n        )\n\n\n@pytest.mark.parametrize(\"dtype\", [\"int32\", \"int64\"])\ndef test_global_uniform_negative_sampling(dtype):\n    warnings.simplefilter(\"ignore\", np.exceptions.ComplexWarning)\n    g = dgl.graph(([], []), num_nodes=1000).to(F.ctx())\n    src, dst = dgl.sampling.global_uniform_negative_sampling(\n        g, 2000, False, True\n    )\n    assert len(src) == 2000\n    assert len(dst) == 2000\n\n    g = dgl.graph(\n        (np.random.randint(0, 20, (300,)), np.random.randint(0, 20, (300,)))\n    ).to(F.ctx())\n    src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, True)\n    assert not F.asnumpy(g.has_edges_between(src, dst)).any()\n\n    src, dst = dgl.sampling.global_uniform_negative_sampling(\n        g, 20, False, False\n    )\n    assert not F.asnumpy(g.has_edges_between(src, dst)).any()\n    src = F.asnumpy(src)\n    dst = F.asnumpy(dst)\n    s = set(zip(src.tolist(), dst.tolist()))\n    assert len(s) == len(src)\n\n    g = dgl.graph(([0], [1])).to(F.ctx())\n    src, dst = dgl.sampling.global_uniform_negative_sampling(\n        g, 20, True, False, redundancy=10\n    )\n    src = F.asnumpy(src)\n    dst = F.asnumpy(dst)\n    # should have either no element or (1, 0)\n    assert len(src) < 2\n    assert len(dst) < 2\n    if len(src) == 1:\n        assert src[0] == 1\n        assert dst[0] == 0\n\n    g = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): (\n                np.random.randint(0, 20, (300,)),\n                np.random.randint(0, 40, (300,)),\n            ),\n        
    (\"B\", \"BA\", \"A\"): (\n                np.random.randint(0, 40, (200,)),\n                np.random.randint(0, 20, (200,)),\n            ),\n        }\n    ).to(F.ctx())\n    src, dst = dgl.sampling.global_uniform_negative_sampling(\n        g, 20, False, etype=\"AB\"\n    )\n    assert not F.asnumpy(g.has_edges_between(src, dst, etype=\"AB\")).any()\n\n\nif __name__ == \"__main__\":\n    from itertools import product\n\n    test_sample_neighbors_noprob()\n    test_sample_labors_noprob()\n    test_sample_neighbors_prob()\n    test_sample_labors_prob()\n    test_sample_neighbors_mask()\n    for args in product([\"coo\", \"csr\", \"csc\"], [\"in\", \"out\"], [False, True]):\n        test_sample_neighbors_etype_homogeneous(*args)\n    for args in product([\"csr\", \"csc\"], [\"in\", \"out\"]):\n        test_sample_neighbors_etype_sorted_homogeneous(*args)\n    test_non_uniform_random_walk(False)\n    test_uniform_random_walk(False)\n    test_pack_traces()\n    test_pinsage_sampling(False)\n    test_sample_neighbors_outedge()\n    test_sample_neighbors_topk()\n    test_sample_neighbors_topk_outedge()\n    test_sample_neighbors_with_0deg()\n    test_sample_neighbors_biased_homogeneous()\n    test_sample_neighbors_biased_bipartite()\n    test_sample_neighbors_exclude_edges_heteroG(\"int32\")\n    test_sample_neighbors_exclude_edges_homoG(\"int32\")\n    test_global_uniform_negative_sampling(\"int32\")\n    test_global_uniform_negative_sampling(\"int64\")\n"
  },
  {
    "path": "tests/python/common/test_batch-graph.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nimport numpy as np\nfrom utils import parametrize_idtype\n\n\ndef tree1(idtype):\n    \"\"\"Generate a tree\n         0\n        / \\\n       1   2\n      / \\\n     3   4\n    Edges are from leaves to root.\n    \"\"\"\n    g = dgl.graph(([], [])).astype(idtype).to(F.ctx())\n    g.add_nodes(5)\n    g.add_edges(3, 1)\n    g.add_edges(4, 1)\n    g.add_edges(1, 0)\n    g.add_edges(2, 0)\n    g.ndata[\"h\"] = F.tensor([0, 1, 2, 3, 4])\n    g.edata[\"h\"] = F.randn((4, 10))\n    return g\n\n\ndef tree2(idtype):\n    \"\"\"Generate a tree\n         1\n        / \\\n       4   3\n      / \\\n     2   0\n    Edges are from leaves to root.\n    \"\"\"\n    g = dgl.graph(([], [])).astype(idtype).to(F.ctx())\n    g.add_nodes(5)\n    g.add_edges(2, 4)\n    g.add_edges(0, 4)\n    g.add_edges(4, 1)\n    g.add_edges(3, 1)\n    g.ndata[\"h\"] = F.tensor([0, 1, 2, 3, 4])\n    g.edata[\"h\"] = F.randn((4, 10))\n    return g\n\n\n@parametrize_idtype\ndef test_batch_unbatch(idtype):\n    t1 = tree1(idtype)\n    t2 = tree2(idtype)\n\n    bg = dgl.batch([t1, t2])\n    assert bg.num_nodes() == 10\n    assert bg.num_edges() == 8\n    assert bg.batch_size == 2\n    assert F.allclose(bg.batch_num_nodes(), F.tensor([5, 5]))\n    assert F.allclose(bg.batch_num_edges(), F.tensor([4, 4]))\n\n    tt1, tt2 = dgl.unbatch(bg)\n    assert F.allclose(t1.ndata[\"h\"], tt1.ndata[\"h\"])\n    assert F.allclose(t1.edata[\"h\"], tt1.edata[\"h\"])\n    assert F.allclose(t2.ndata[\"h\"], tt2.ndata[\"h\"])\n    assert F.allclose(t2.edata[\"h\"], tt2.edata[\"h\"])\n\n\n@parametrize_idtype\ndef test_batch_unbatch1(idtype):\n    t1 = tree1(idtype)\n    t2 = tree2(idtype)\n    b1 = dgl.batch([t1, t2])\n    b2 = dgl.batch([t2, b1])\n    assert b2.num_nodes() == 15\n    assert b2.num_edges() == 12\n    assert b2.batch_size == 3\n    assert F.allclose(b2.batch_num_nodes(), F.tensor([5, 5, 5]))\n    assert F.allclose(b2.batch_num_edges(), 
F.tensor([4, 4, 4]))\n\n    s1, s2, s3 = dgl.unbatch(b2)\n    assert F.allclose(t2.ndata[\"h\"], s1.ndata[\"h\"])\n    assert F.allclose(t2.edata[\"h\"], s1.edata[\"h\"])\n    assert F.allclose(t1.ndata[\"h\"], s2.ndata[\"h\"])\n    assert F.allclose(t1.edata[\"h\"], s2.edata[\"h\"])\n    assert F.allclose(t2.ndata[\"h\"], s3.ndata[\"h\"])\n    assert F.allclose(t2.edata[\"h\"], s3.edata[\"h\"])\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support inplace update\",\n)\n@parametrize_idtype\ndef test_batch_unbatch_frame(idtype):\n    \"\"\"Test module of node/edge frames of batched/unbatched DGLGraphs.\n    Also address the bug mentioned in https://github.com/dmlc/dgl/issues/1475.\n    \"\"\"\n    t1 = tree1(idtype)\n    t2 = tree2(idtype)\n    N1 = t1.num_nodes()\n    E1 = t1.num_edges()\n    N2 = t2.num_nodes()\n    E2 = t2.num_edges()\n    D = 10\n    t1.ndata[\"h\"] = F.randn((N1, D))\n    t1.edata[\"h\"] = F.randn((E1, D))\n    t2.ndata[\"h\"] = F.randn((N2, D))\n    t2.edata[\"h\"] = F.randn((E2, D))\n\n    b1 = dgl.batch([t1, t2])\n    b2 = dgl.batch([t2])\n    b1.ndata[\"h\"][:N1] = F.zeros((N1, D))\n    b1.edata[\"h\"][:E1] = F.zeros((E1, D))\n    b2.ndata[\"h\"][:N2] = F.zeros((N2, D))\n    b2.edata[\"h\"][:E2] = F.zeros((E2, D))\n    assert not F.allclose(t1.ndata[\"h\"], F.zeros((N1, D)))\n    assert not F.allclose(t1.edata[\"h\"], F.zeros((E1, D)))\n    assert not F.allclose(t2.ndata[\"h\"], F.zeros((N2, D)))\n    assert not F.allclose(t2.edata[\"h\"], F.zeros((E2, D)))\n\n    g1, g2 = dgl.unbatch(b1)\n    (_g2,) = dgl.unbatch(b2)\n    assert F.allclose(g1.ndata[\"h\"], F.zeros((N1, D)))\n    assert F.allclose(g1.edata[\"h\"], F.zeros((E1, D)))\n    assert F.allclose(g2.ndata[\"h\"], t2.ndata[\"h\"])\n    assert F.allclose(g2.edata[\"h\"], t2.edata[\"h\"])\n    assert F.allclose(_g2.ndata[\"h\"], F.zeros((N2, D)))\n    assert F.allclose(_g2.edata[\"h\"], F.zeros((E2, D)))\n\n\n@parametrize_idtype\ndef 
test_batch_unbatch2(idtype):\n    # test setting/getting features after batch\n    a = dgl.graph(([], [])).astype(idtype).to(F.ctx())\n    a.add_nodes(4)\n    a.add_edges(0, [1, 2, 3])\n    b = dgl.graph(([], [])).astype(idtype).to(F.ctx())\n    b.add_nodes(3)\n    b.add_edges(0, [1, 2])\n    c = dgl.batch([a, b])\n    c.ndata[\"h\"] = F.ones((7, 1))\n    c.edata[\"w\"] = F.ones((5, 1))\n    assert F.allclose(c.ndata[\"h\"], F.ones((7, 1)))\n    assert F.allclose(c.edata[\"w\"], F.ones((5, 1)))\n\n\n@parametrize_idtype\ndef test_batch_send_and_recv(idtype):\n    t1 = tree1(idtype)\n    t2 = tree2(idtype)\n\n    bg = dgl.batch([t1, t2])\n    _mfunc = lambda edges: {\"m\": edges.src[\"h\"]}\n    _rfunc = lambda nodes: {\"h\": F.sum(nodes.mailbox[\"m\"], 1)}\n    u = [3, 4, 2 + 5, 0 + 5]\n    v = [1, 1, 4 + 5, 4 + 5]\n\n    bg.send_and_recv((u, v), _mfunc, _rfunc)\n\n    t1, t2 = dgl.unbatch(bg)\n    assert F.asnumpy(t1.ndata[\"h\"][1]) == 7\n    assert F.asnumpy(t2.ndata[\"h\"][4]) == 2\n\n\n@parametrize_idtype\ndef test_batch_propagate(idtype):\n    t1 = tree1(idtype)\n    t2 = tree2(idtype)\n\n    bg = dgl.batch([t1, t2])\n    _mfunc = lambda edges: {\"m\": edges.src[\"h\"]}\n    _rfunc = lambda nodes: {\"h\": F.sum(nodes.mailbox[\"m\"], 1)}\n    # get leaves.\n\n    order = []\n\n    # step 1\n    u = [3, 4, 2 + 5, 0 + 5]\n    v = [1, 1, 4 + 5, 4 + 5]\n    order.append((u, v))\n\n    # step 2\n    u = [1, 2, 4 + 5, 3 + 5]\n    v = [0, 0, 1 + 5, 1 + 5]\n    order.append((u, v))\n\n    bg.prop_edges(order, _mfunc, _rfunc)\n    t1, t2 = dgl.unbatch(bg)\n\n    assert F.asnumpy(t1.ndata[\"h\"][0]) == 9\n    assert F.asnumpy(t2.ndata[\"h\"][1]) == 5\n\n\n@parametrize_idtype\ndef test_batched_edge_ordering(idtype):\n    g1 = dgl.graph(([], [])).astype(idtype).to(F.ctx())\n    g1.add_nodes(6)\n    g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])\n    e1 = F.randn((5, 10))\n    g1.edata[\"h\"] = e1\n    g2 = dgl.graph(([], [])).astype(idtype).to(F.ctx())\n    
g2.add_nodes(6)\n    g2.add_edges([0, 1, 2, 5, 4, 5], [1, 2, 3, 4, 3, 0])\n    e2 = F.randn((6, 10))\n    g2.edata[\"h\"] = e2\n    g = dgl.batch([g1, g2])\n    r1 = g.edata[\"h\"][g.edge_ids(4, 5)]\n    r2 = g1.edata[\"h\"][g1.edge_ids(4, 5)]\n    assert F.array_equal(r1, r2)\n\n\n@parametrize_idtype\ndef test_batch_no_edge(idtype):\n    g1 = dgl.graph(([], [])).astype(idtype).to(F.ctx())\n    g1.add_nodes(6)\n    g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])\n    g2 = dgl.graph(([], [])).astype(idtype).to(F.ctx())\n    g2.add_nodes(6)\n    g2.add_edges([0, 1, 2, 5, 4, 5], [1, 2, 3, 4, 3, 0])\n    g3 = dgl.graph(([], [])).astype(idtype).to(F.ctx())\n    g3.add_nodes(1)  # no edges\n    g = dgl.batch([g1, g3, g2])  # should not throw an error\n\n\n@parametrize_idtype\ndef test_batch_keeps_empty_data(idtype):\n    g1 = dgl.graph(([], [])).astype(idtype).to(F.ctx())\n    g1.ndata[\"nh\"] = F.tensor([])\n    g1.edata[\"eh\"] = F.tensor([])\n    g2 = dgl.graph(([], [])).astype(idtype).to(F.ctx())\n    g2.ndata[\"nh\"] = F.tensor([])\n    g2.edata[\"eh\"] = F.tensor([])\n    g = dgl.batch([g1, g2])\n    assert \"nh\" in g.ndata\n    assert \"eh\" in g.edata\n\n\ndef _get_subgraph_batch_info(keys, induced_indices_arr, batch_num_objs):\n    \"\"\"Internal function to compute batch information for subgraphs.\n    Parameters\n    ----------\n    keys : List[str]\n        The node/edge type keys.\n    induced_indices_arr : List[Tensor]\n        The induced node/edge index tensor for all node/edge types.\n    batch_num_objs : Tensor\n        Number of nodes/edges for each graph in the original batch.\n    Returns\n    -------\n    Mapping[str, Tensor]\n        A dictionary mapping all node/edge type keys to the ``batch_num_objs``\n        array of corresponding graph.\n    \"\"\"\n    bucket_offset = np.expand_dims(\n        np.cumsum(F.asnumpy(batch_num_objs), 0), -1\n    )  # (num_bkts, 1)\n    ret = {}\n    for key, induced_indices in zip(keys, induced_indices_arr):\n     
   # NOTE(Zihao): this implementation is not efficient and we can replace it with\n        # binary search in the future.\n        induced_indices = np.expand_dims(\n            F.asnumpy(induced_indices), 0\n        )  # (1, num_nodes)\n        new_offset = np.sum((induced_indices < bucket_offset), 1)  # (num_bkts,)\n        # start_offset = [0] + [new_offset[i-1] for i in range(1, n_bkts)]\n        start_offset = np.concatenate([np.zeros((1,)), new_offset[:-1]], 0)\n        new_batch_num_objs = new_offset - start_offset\n        ret[key] = F.tensor(new_batch_num_objs, dtype=F.dtype(batch_num_objs))\n    return ret\n\n\n@parametrize_idtype\ndef test_set_batch_info(idtype):\n    ctx = F.ctx()\n\n    g1 = dgl.rand_graph(30, 100).astype(idtype).to(F.ctx())\n    g2 = dgl.rand_graph(40, 200).astype(idtype).to(F.ctx())\n    bg = dgl.batch([g1, g2])\n    batch_num_nodes = F.astype(bg.batch_num_nodes(), idtype)\n    batch_num_edges = F.astype(bg.batch_num_edges(), idtype)\n\n    # test homogeneous node subgraph\n    sg_n = dgl.node_subgraph(bg, list(range(10, 20)) + list(range(50, 60)))\n    induced_nodes = sg_n.ndata[\"_ID\"]\n    induced_edges = sg_n.edata[\"_ID\"]\n    new_batch_num_nodes = _get_subgraph_batch_info(\n        bg.ntypes, [induced_nodes], batch_num_nodes\n    )\n    new_batch_num_edges = _get_subgraph_batch_info(\n        bg.canonical_etypes, [induced_edges], batch_num_edges\n    )\n    sg_n.set_batch_num_nodes(new_batch_num_nodes)\n    sg_n.set_batch_num_edges(new_batch_num_edges)\n    subg_n1, subg_n2 = dgl.unbatch(sg_n)\n    subg1 = dgl.node_subgraph(g1, list(range(10, 20)))\n    subg2 = dgl.node_subgraph(g2, list(range(20, 30)))\n    assert subg_n1.num_edges() == subg1.num_edges()\n    assert subg_n2.num_edges() == subg2.num_edges()\n\n    # test homogeneous edge subgraph\n    sg_e = dgl.edge_subgraph(\n        bg, list(range(40, 70)) + list(range(150, 200)), relabel_nodes=False\n    )\n    induced_nodes = F.arange(0, bg.num_nodes(), idtype)\n    
induced_edges = sg_e.edata[\"_ID\"]\n    new_batch_num_nodes = _get_subgraph_batch_info(\n        bg.ntypes, [induced_nodes], batch_num_nodes\n    )\n    new_batch_num_edges = _get_subgraph_batch_info(\n        bg.canonical_etypes, [induced_edges], batch_num_edges\n    )\n    sg_e.set_batch_num_nodes(new_batch_num_nodes)\n    sg_e.set_batch_num_edges(new_batch_num_edges)\n    subg_e1, subg_e2 = dgl.unbatch(sg_e)\n    subg1 = dgl.edge_subgraph(g1, list(range(40, 70)), relabel_nodes=False)\n    subg2 = dgl.edge_subgraph(g2, list(range(50, 100)), relabel_nodes=False)\n    assert subg_e1.num_nodes() == subg1.num_nodes()\n    assert subg_e2.num_nodes() == subg2.num_nodes()\n\n\nif __name__ == \"__main__\":\n    # test_batch_unbatch()\n    # test_batch_unbatch1()\n    # test_batch_unbatch_frame()\n    # test_batch_unbatch2()\n    # test_batched_edge_ordering()\n    # test_batch_send_then_recv()\n    # test_batch_send_and_recv()\n    # test_batch_propagate()\n    # test_batch_no_edge()\n    test_set_batch_info(F.int32)\n"
  },
  {
    "path": "tests/python/common/test_batch-heterograph.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nimport pytest\nfrom dgl.base import ALL\nfrom utils import check_graph_equal, get_cases, parametrize_idtype\n\n\ndef check_equivalence_between_heterographs(\n    g1, g2, node_attrs=None, edge_attrs=None\n):\n    assert g1.ntypes == g2.ntypes\n    assert g1.etypes == g2.etypes\n    assert g1.canonical_etypes == g2.canonical_etypes\n\n    for nty in g1.ntypes:\n        assert g1.num_nodes(nty) == g2.num_nodes(nty)\n\n    for ety in g1.etypes:\n        if len(g1._etype2canonical[ety]) > 0:\n            assert g1.num_edges(ety) == g2.num_edges(ety)\n\n    for ety in g1.canonical_etypes:\n        assert g1.num_edges(ety) == g2.num_edges(ety)\n        src1, dst1, eid1 = g1.edges(etype=ety, form=\"all\")\n        src2, dst2, eid2 = g2.edges(etype=ety, form=\"all\")\n        assert F.allclose(src1, src2)\n        assert F.allclose(dst1, dst2)\n        assert F.allclose(eid1, eid2)\n\n    if node_attrs is not None:\n        for nty in node_attrs.keys():\n            if g1.num_nodes(nty) == 0:\n                continue\n            for feat_name in node_attrs[nty]:\n                assert F.allclose(\n                    g1.nodes[nty].data[feat_name], g2.nodes[nty].data[feat_name]\n                )\n\n    if edge_attrs is not None:\n        for ety in edge_attrs.keys():\n            if g1.num_edges(ety) == 0:\n                continue\n            for feat_name in edge_attrs[ety]:\n                assert F.allclose(\n                    g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name]\n                )\n\n\n@pytest.mark.parametrize(\"gs\", get_cases([\"two_hetero_batch\"]))\n@parametrize_idtype\ndef test_topology(gs, idtype):\n    \"\"\"Test batching two DGLGraphs where some nodes are isolated in some relations\"\"\"\n    g1, g2 = gs\n    g1 = g1.astype(idtype).to(F.ctx())\n    g2 = g2.astype(idtype).to(F.ctx())\n    bg = dgl.batch([g1, g2])\n\n    assert bg.idtype == idtype\n    assert 
bg.device == F.ctx()\n    assert bg.ntypes == g2.ntypes\n    assert bg.etypes == g2.etypes\n    assert bg.canonical_etypes == g2.canonical_etypes\n    assert bg.batch_size == 2\n\n    # Test number of nodes\n    for ntype in bg.ntypes:\n        print(ntype)\n        assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [\n            g1.num_nodes(ntype),\n            g2.num_nodes(ntype),\n        ]\n        assert bg.num_nodes(ntype) == (\n            g1.num_nodes(ntype) + g2.num_nodes(ntype)\n        )\n\n    # Test number of edges\n    for etype in bg.canonical_etypes:\n        assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [\n            g1.num_edges(etype),\n            g2.num_edges(etype),\n        ]\n        assert bg.num_edges(etype) == (\n            g1.num_edges(etype) + g2.num_edges(etype)\n        )\n\n    # Test relabeled nodes\n    for ntype in bg.ntypes:\n        assert list(F.asnumpy(bg.nodes(ntype))) == list(\n            range(bg.num_nodes(ntype))\n        )\n\n    # Test relabeled edges\n    src, dst = bg.edges(etype=(\"user\", \"follows\", \"user\"))\n    assert list(F.asnumpy(src)) == [0, 1, 4, 5]\n    assert list(F.asnumpy(dst)) == [1, 2, 5, 6]\n    src, dst = bg.edges(etype=(\"user\", \"follows\", \"developer\"))\n    assert list(F.asnumpy(src)) == [0, 1, 4, 5]\n    assert list(F.asnumpy(dst)) == [1, 2, 4, 5]\n    src, dst, eid = bg.edges(etype=\"plays\", form=\"all\")\n    assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6]\n    assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3]\n    assert list(F.asnumpy(eid)) == [0, 1, 2, 3, 4, 5, 6]\n\n    # Test unbatching graphs\n    g3, g4 = dgl.unbatch(bg)\n    check_equivalence_between_heterographs(g1, g3)\n    check_equivalence_between_heterographs(g2, g4)\n\n    # Test dtype cast\n    if idtype == \"int32\":\n        bg_cast = bg.long()\n    else:\n        bg_cast = bg.int()\n    assert bg.batch_size == bg_cast.batch_size\n\n    # Test local var\n    bg_local = bg.local_var()\n    
assert bg.batch_size == bg_local.batch_size\n\n\n@parametrize_idtype\ndef test_batching_batched(idtype):\n    \"\"\"Test batching a DGLGraph and a batched DGLGraph.\"\"\"\n    g1 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1], [0, 0]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1], [0, 0]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    bg1 = dgl.batch([g1, g2])\n    g3 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0], [1]),\n            (\"user\", \"plays\", \"game\"): ([1], [0]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    bg2 = dgl.batch([bg1, g3])\n    assert bg2.idtype == idtype\n    assert bg2.device == F.ctx()\n    assert bg2.ntypes == g3.ntypes\n    assert bg2.etypes == g3.etypes\n    assert bg2.canonical_etypes == g3.canonical_etypes\n    assert bg2.batch_size == 3\n\n    # Test number of nodes\n    for ntype in bg2.ntypes:\n        assert F.asnumpy(bg2.batch_num_nodes(ntype)).tolist() == [\n            g1.num_nodes(ntype),\n            g2.num_nodes(ntype),\n            g3.num_nodes(ntype),\n        ]\n        assert bg2.num_nodes(ntype) == (\n            g1.num_nodes(ntype) + g2.num_nodes(ntype) + g3.num_nodes(ntype)\n        )\n\n    # Test number of edges\n    for etype in bg2.canonical_etypes:\n        assert F.asnumpy(bg2.batch_num_edges(etype)).tolist() == [\n            g1.num_edges(etype),\n            g2.num_edges(etype),\n            g3.num_edges(etype),\n        ]\n        assert bg2.num_edges(etype) == (\n            g1.num_edges(etype) + g2.num_edges(etype) + g3.num_edges(etype)\n        )\n\n    # Test relabeled nodes\n    for ntype in bg2.ntypes:\n        assert 
list(F.asnumpy(bg2.nodes(ntype))) == list(\n            range(bg2.num_nodes(ntype))\n        )\n\n    # Test relabeled edges\n    src, dst = bg2.edges(etype=\"follows\")\n    assert list(F.asnumpy(src)) == [0, 1, 3, 4, 6]\n    assert list(F.asnumpy(dst)) == [1, 2, 4, 5, 7]\n    src, dst = bg2.edges(etype=\"plays\")\n    assert list(F.asnumpy(src)) == [0, 1, 3, 4, 7]\n    assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2]\n\n    # Test unbatching graphs\n    g4, g5, g6 = dgl.unbatch(bg2)\n    check_equivalence_between_heterographs(g1, g4)\n    check_equivalence_between_heterographs(g2, g5)\n    check_equivalence_between_heterographs(g3, g6)\n\n\n@parametrize_idtype\ndef test_features(idtype):\n    \"\"\"Test the features of batched DGLGraphs\"\"\"\n    g1 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1], [0, 0]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g1.nodes[\"user\"].data[\"h1\"] = F.tensor([[0.0], [1.0], [2.0]])\n    g1.nodes[\"user\"].data[\"h2\"] = F.tensor([[3.0], [4.0], [5.0]])\n    g1.nodes[\"game\"].data[\"h1\"] = F.tensor([[0.0]])\n    g1.nodes[\"game\"].data[\"h2\"] = F.tensor([[1.0]])\n    g1.edges[\"follows\"].data[\"h1\"] = F.tensor([[0.0], [1.0]])\n    g1.edges[\"follows\"].data[\"h2\"] = F.tensor([[2.0], [3.0]])\n    g1.edges[\"plays\"].data[\"h1\"] = F.tensor([[0.0], [1.0]])\n\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1], [0, 0]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g2.nodes[\"user\"].data[\"h1\"] = F.tensor([[0.0], [1.0], [2.0]])\n    g2.nodes[\"user\"].data[\"h2\"] = F.tensor([[3.0], [4.0], [5.0]])\n    g2.nodes[\"game\"].data[\"h1\"] = F.tensor([[0.0]])\n    g2.nodes[\"game\"].data[\"h2\"] = F.tensor([[1.0]])\n    g2.edges[\"follows\"].data[\"h1\"] = F.tensor([[0.0], 
[1.0]])\n    g2.edges[\"follows\"].data[\"h2\"] = F.tensor([[2.0], [3.0]])\n    g2.edges[\"plays\"].data[\"h1\"] = F.tensor([[0.0], [1.0]])\n\n    # test default setting\n    bg = dgl.batch([g1, g2])\n    assert F.allclose(\n        bg.nodes[\"user\"].data[\"h1\"],\n        F.cat(\n            [g1.nodes[\"user\"].data[\"h1\"], g2.nodes[\"user\"].data[\"h1\"]], dim=0\n        ),\n    )\n    assert F.allclose(\n        bg.nodes[\"user\"].data[\"h2\"],\n        F.cat(\n            [g1.nodes[\"user\"].data[\"h2\"], g2.nodes[\"user\"].data[\"h2\"]], dim=0\n        ),\n    )\n    assert F.allclose(\n        bg.nodes[\"game\"].data[\"h1\"],\n        F.cat(\n            [g1.nodes[\"game\"].data[\"h1\"], g2.nodes[\"game\"].data[\"h1\"]], dim=0\n        ),\n    )\n    assert F.allclose(\n        bg.nodes[\"game\"].data[\"h2\"],\n        F.cat(\n            [g1.nodes[\"game\"].data[\"h2\"], g2.nodes[\"game\"].data[\"h2\"]], dim=0\n        ),\n    )\n    assert F.allclose(\n        bg.edges[\"follows\"].data[\"h1\"],\n        F.cat(\n            [g1.edges[\"follows\"].data[\"h1\"], g2.edges[\"follows\"].data[\"h1\"]],\n            dim=0,\n        ),\n    )\n    assert F.allclose(\n        bg.edges[\"follows\"].data[\"h2\"],\n        F.cat(\n            [g1.edges[\"follows\"].data[\"h2\"], g2.edges[\"follows\"].data[\"h2\"]],\n            dim=0,\n        ),\n    )\n    assert F.allclose(\n        bg.edges[\"plays\"].data[\"h1\"],\n        F.cat(\n            [g1.edges[\"plays\"].data[\"h1\"], g2.edges[\"plays\"].data[\"h1\"]], dim=0\n        ),\n    )\n\n    # test specifying ndata/edata\n    bg = dgl.batch([g1, g2], ndata=[\"h2\"], edata=[\"h1\"])\n    assert F.allclose(\n        bg.nodes[\"user\"].data[\"h2\"],\n        F.cat(\n            [g1.nodes[\"user\"].data[\"h2\"], g2.nodes[\"user\"].data[\"h2\"]], dim=0\n        ),\n    )\n    assert F.allclose(\n        bg.nodes[\"game\"].data[\"h2\"],\n        F.cat(\n            [g1.nodes[\"game\"].data[\"h2\"], 
g2.nodes[\"game\"].data[\"h2\"]], dim=0\n        ),\n    )\n    assert F.allclose(\n        bg.edges[\"follows\"].data[\"h1\"],\n        F.cat(\n            [g1.edges[\"follows\"].data[\"h1\"], g2.edges[\"follows\"].data[\"h1\"]],\n            dim=0,\n        ),\n    )\n    assert F.allclose(\n        bg.edges[\"plays\"].data[\"h1\"],\n        F.cat(\n            [g1.edges[\"plays\"].data[\"h1\"], g2.edges[\"plays\"].data[\"h1\"]], dim=0\n        ),\n    )\n    assert \"h1\" not in bg.nodes[\"user\"].data\n    assert \"h1\" not in bg.nodes[\"game\"].data\n    assert \"h2\" not in bg.edges[\"follows\"].data\n\n    # Test unbatching graphs\n    g3, g4 = dgl.unbatch(bg)\n    check_equivalence_between_heterographs(\n        g1,\n        g3,\n        node_attrs={\"user\": [\"h2\"], \"game\": [\"h2\"]},\n        edge_attrs={(\"user\", \"follows\", \"user\"): [\"h1\"]},\n    )\n    check_equivalence_between_heterographs(\n        g2,\n        g4,\n        node_attrs={\"user\": [\"h2\"], \"game\": [\"h2\"]},\n        edge_attrs={(\"user\", \"follows\", \"user\"): [\"h1\"]},\n    )\n\n\n@unittest.skipIf(\n    F.backend_name == \"mxnet\",\n    reason=\"MXNet does not support split array with zero-length segment.\",\n)\n@parametrize_idtype\ndef test_empty_relation(idtype):\n    \"\"\"Test the features of batched DGLGraphs\"\"\"\n    g1 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([], []),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g1.nodes[\"user\"].data[\"h1\"] = F.tensor([[0.0], [1.0], [2.0]])\n    g1.nodes[\"user\"].data[\"h2\"] = F.tensor([[3.0], [4.0], [5.0]])\n    g1.edges[\"follows\"].data[\"h1\"] = F.tensor([[0.0], [1.0]])\n    g1.edges[\"follows\"].data[\"h2\"] = F.tensor([[2.0], [3.0]])\n\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 
1], [0, 0]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g2.nodes[\"user\"].data[\"h1\"] = F.tensor([[0.0], [1.0], [2.0]])\n    g2.nodes[\"user\"].data[\"h2\"] = F.tensor([[3.0], [4.0], [5.0]])\n    g2.nodes[\"game\"].data[\"h1\"] = F.tensor([[0.0]])\n    g2.nodes[\"game\"].data[\"h2\"] = F.tensor([[1.0]])\n    g2.edges[\"follows\"].data[\"h1\"] = F.tensor([[0.0], [1.0]])\n    g2.edges[\"follows\"].data[\"h2\"] = F.tensor([[2.0], [3.0]])\n    g2.edges[\"plays\"].data[\"h1\"] = F.tensor([[0.0], [1.0]])\n\n    bg = dgl.batch([g1, g2])\n\n    # Test number of nodes\n    for ntype in bg.ntypes:\n        assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [\n            g1.num_nodes(ntype),\n            g2.num_nodes(ntype),\n        ]\n\n    # Test number of edges\n    for etype in bg.canonical_etypes:\n        assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [\n            g1.num_edges(etype),\n            g2.num_edges(etype),\n        ]\n\n    # Test features\n    assert F.allclose(\n        bg.nodes[\"user\"].data[\"h1\"],\n        F.cat(\n            [g1.nodes[\"user\"].data[\"h1\"], g2.nodes[\"user\"].data[\"h1\"]], dim=0\n        ),\n    )\n    assert F.allclose(\n        bg.nodes[\"user\"].data[\"h2\"],\n        F.cat(\n            [g1.nodes[\"user\"].data[\"h2\"], g2.nodes[\"user\"].data[\"h2\"]], dim=0\n        ),\n    )\n    assert F.allclose(bg.nodes[\"game\"].data[\"h1\"], g2.nodes[\"game\"].data[\"h1\"])\n    assert F.allclose(bg.nodes[\"game\"].data[\"h2\"], g2.nodes[\"game\"].data[\"h2\"])\n    assert F.allclose(\n        bg.edges[\"follows\"].data[\"h1\"],\n        F.cat(\n            [g1.edges[\"follows\"].data[\"h1\"], g2.edges[\"follows\"].data[\"h1\"]],\n            dim=0,\n        ),\n    )\n    assert F.allclose(\n        bg.edges[\"plays\"].data[\"h1\"], g2.edges[\"plays\"].data[\"h1\"]\n    )\n\n    # Test unbatching graphs\n    g3, g4 = dgl.unbatch(bg)\n    check_equivalence_between_heterographs(\n        
g1,\n        g3,\n        node_attrs={\"user\": [\"h1\", \"h2\"], \"game\": [\"h1\", \"h2\"]},\n        edge_attrs={(\"user\", \"follows\", \"user\"): [\"h1\"]},\n    )\n    check_equivalence_between_heterographs(\n        g2,\n        g4,\n        node_attrs={\"user\": [\"h1\", \"h2\"], \"game\": [\"h1\", \"h2\"]},\n        edge_attrs={(\"user\", \"follows\", \"user\"): [\"h1\"]},\n    )\n\n    # Test graphs without edges\n    g1 = dgl.heterograph({(\"u\", \"r\", \"v\"): ([], [])}, {\"u\": 0, \"v\": 4})\n    g2 = dgl.heterograph({(\"u\", \"r\", \"v\"): ([], [])}, {\"u\": 1, \"v\": 5})\n    dgl.batch([g1, g2])\n\n\n@parametrize_idtype\ndef test_unbatch2(idtype):\n    # batch 3 graphs but unbatch to 2\n    g1 = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())\n    g2 = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())\n    g3 = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())\n    bg = dgl.batch([g1, g2, g3])\n    bnn = F.tensor([8, 4])\n    bne = F.tensor([6, 3])\n    f1, f2 = dgl.unbatch(bg, node_split=bnn, edge_split=bne)\n    u, v = f1.edges(order=\"eid\")\n    assert F.allclose(u, F.tensor([0, 1, 2, 4, 5, 6]))\n    assert F.allclose(v, F.tensor([1, 2, 3, 5, 6, 7]))\n    u, v = f2.edges(order=\"eid\")\n    assert F.allclose(u, F.tensor([0, 1, 2]))\n    assert F.allclose(v, F.tensor([1, 2, 3]))\n\n    # batch 2 but unbatch to 3\n    bg = dgl.batch([f1, f2])\n    gg1, gg2, gg3 = dgl.unbatch(bg, F.tensor([4, 4, 4]), F.tensor([3, 3, 3]))\n    check_graph_equal(g1, gg1)\n    check_graph_equal(g2, gg2)\n    check_graph_equal(g3, gg3)\n\n\n@parametrize_idtype\ndef test_slice_batch(idtype):\n    g1 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([], []),\n            (\"user\", \"follows\", \"game\"): ([0, 0], [1, 4]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g2 = dgl.heterograph(\n        {\n       
     (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1], [0, 0]),\n            (\"user\", \"follows\", \"game\"): ([0, 1], [1, 4]),\n        },\n        num_nodes_dict={\"user\": 4, \"game\": 6},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g3 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0], [2]),\n            (\"user\", \"plays\", \"game\"): ([1, 2], [3, 4]),\n            (\"user\", \"follows\", \"game\"): ([], []),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g_list = [g1, g2, g3]\n    bg = dgl.batch(g_list)\n    bg.nodes[\"user\"].data[\"h1\"] = F.randn((bg.num_nodes(\"user\"), 2))\n    bg.nodes[\"user\"].data[\"h2\"] = F.randn((bg.num_nodes(\"user\"), 5))\n    bg.edges[(\"user\", \"follows\", \"user\")].data[\"h1\"] = F.randn(\n        (bg.num_edges((\"user\", \"follows\", \"user\")), 2)\n    )\n    for fmat in [\"coo\", \"csr\", \"csc\"]:\n        bg = bg.formats(fmat)\n        for i in range(len(g_list)):\n            g_i = g_list[i]\n            g_slice = dgl.slice_batch(bg, i)\n            assert g_i.ntypes == g_slice.ntypes\n            assert g_i.canonical_etypes == g_slice.canonical_etypes\n            assert g_i.idtype == g_slice.idtype\n            assert g_i.device == g_slice.device\n            for nty in g_i.ntypes:\n                assert g_i.num_nodes(nty) == g_slice.num_nodes(nty)\n                for feat in g_i.nodes[nty].data:\n                    assert F.allclose(\n                        g_i.nodes[nty].data[feat], g_slice.nodes[nty].data[feat]\n                    )\n\n            for ety in g_i.canonical_etypes:\n                assert g_i.num_edges(ety) == g_slice.num_edges(ety)\n                for feat in g_i.edges[ety].data:\n                    assert F.allclose(\n                        g_i.edges[ety].data[feat], g_slice.edges[ety].data[feat]\n                    )\n\n\n@parametrize_idtype\ndef 
test_batch_keeps_empty_data(idtype):\n    g1 = (\n        dgl.heterograph({(\"a\", \"to\", \"a\"): ([], [])}).astype(idtype).to(F.ctx())\n    )\n    g1.nodes[\"a\"].data[\"nh\"] = F.tensor([])\n    g1.edges[(\"a\", \"to\", \"a\")].data[\"eh\"] = F.tensor([])\n    g2 = (\n        dgl.heterograph({(\"a\", \"to\", \"a\"): ([], [])}).astype(idtype).to(F.ctx())\n    )\n    g2.nodes[\"a\"].data[\"nh\"] = F.tensor([])\n    g2.edges[(\"a\", \"to\", \"a\")].data[\"eh\"] = F.tensor([])\n    g = dgl.batch([g1, g2])\n    assert \"nh\" in g.nodes[\"a\"].data\n    assert \"eh\" in g.edges[(\"a\", \"to\", \"a\")].data\n\n\ndef test_batch_netypes():\n    # Test for https://github.com/dmlc/dgl/issues/2808\n    import networkx as nx\n\n    B = nx.DiGraph()\n    B.add_nodes_from(\n        [1, 2, 3, 4],\n        bipartite=0,\n        some_attr=F.tensor([1, 2, 3, 4], dtype=F.float32),\n    )\n    B.add_nodes_from([\"a\", \"b\", \"c\"], bipartite=1)\n    B.add_edges_from(\n        [(1, \"a\"), (1, \"b\"), (2, \"b\"), (2, \"c\"), (3, \"c\"), (4, \"a\")]\n    )\n\n    g_dict = {\n        0: dgl.bipartite_from_networkx(B, \"A\", \"e\", \"B\"),\n        1: dgl.bipartite_from_networkx(B, \"B\", \"e\", \"A\"),\n        2: dgl.bipartite_from_networkx(B, \"A\", \"e\", \"B\", u_attrs=[\"some_attr\"]),\n        3: dgl.bipartite_from_networkx(B, \"B\", \"e\", \"A\", u_attrs=[\"some_attr\"]),\n    }\n    for _, g in g_dict.items():\n        dgl.batch((g, g, g))\n\n\nif __name__ == \"__main__\":\n    # test_topology('int32')\n    # test_batching_batched('int32')\n    # test_batched_features('int32')\n    # test_empty_relation('int64')\n    # test_to_device('int32')\n    pass\n"
  },
  {
    "path": "tests/python/common/test_convert.py",
    "content": "import unittest\n\nimport backend as F\nimport dgl\n\nfrom utils import parametrize_idtype\n\n\ndef get_nodes_by_ntype(nodes, ntype):\n    return dict((k, v) for k, v in nodes.items() if v[\"ntype\"] == ntype)\n\n\ndef edge_attrs(edge):\n    # Edges in Networkx are in the format (src, dst, attrs)\n    return edge[2]\n\n\ndef get_edges_by_etype(edges, etype):\n    return [e for e in edges if edge_attrs(e)[\"etype\"] == etype]\n\n\ndef check_attrs_for_nodes(nodes, attrs):\n    return all(v.keys() == attrs for v in nodes.values())\n\n\ndef check_attr_values_for_nodes(nodes, attr_name, values):\n    return F.allclose(\n        F.stack([v[attr_name] for v in nodes.values()], 0), values\n    )\n\n\ndef check_attrs_for_edges(edges, attrs):\n    return all(edge_attrs(e).keys() == attrs for e in edges)\n\n\ndef check_attr_values_for_edges(edges, attr_name, values):\n    return F.allclose(\n        F.stack([edge_attrs(e)[attr_name] for e in edges], 0), values\n    )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"`to_networkx` does not support graphs on GPU\",\n)\n@parametrize_idtype\ndef test_to_networkx(idtype):\n    # TODO: adapt and move code from the _test_nx_conversion function in\n    # tests/python/common/function/test_basics.py to here\n    # (pending resolution of https://github.com/dmlc/dgl/issues/5735).\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"follows\", \"topic\"): ([1, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 3], [3, 4]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n\n    n1 = F.randn((5, 3))\n    n2 = F.randn((4, 2))\n    e1 = F.randn((2, 3))\n    e2 = F.randn((2, 2))\n\n    g.nodes[\"game\"].data[\"n\"] = F.copy_to(n1, ctx=F.ctx())\n    g.nodes[\"user\"].data[\"n\"] = F.copy_to(n2, ctx=F.ctx())\n    g.edges[(\"user\", \"follows\", \"user\")].data[\"e\"] = F.copy_to(e1, 
ctx=F.ctx())\n    g.edges[\"plays\"].data[\"e\"] = F.copy_to(e2, ctx=F.ctx())\n\n    nxg = dgl.to_networkx(\n        g,\n        node_attrs=[\"n\"],\n        edge_attrs=[\"e\"],\n    )\n\n    # Test nodes\n    nxg_nodes = dict(nxg.nodes(data=True))\n    assert len(nxg_nodes) == g.num_nodes()\n    assert {v[\"ntype\"] for v in nxg_nodes.values()} == set(g.ntypes)\n\n    nxg_nodes_by_ntype = {}\n    for ntype in g.ntypes:\n        nxg_nodes_by_ntype[ntype] = get_nodes_by_ntype(nxg_nodes, ntype)\n        assert g.num_nodes(ntype) == len(nxg_nodes_by_ntype[ntype])\n\n    assert check_attrs_for_nodes(nxg_nodes_by_ntype[\"game\"], {\"ntype\", \"n\"})\n    assert check_attr_values_for_nodes(nxg_nodes_by_ntype[\"game\"], \"n\", n1)\n    assert check_attrs_for_nodes(nxg_nodes_by_ntype[\"user\"], {\"ntype\", \"n\"})\n    assert check_attr_values_for_nodes(nxg_nodes_by_ntype[\"user\"], \"n\", n2)\n    # Nodes without node attributes\n    assert check_attrs_for_nodes(nxg_nodes_by_ntype[\"topic\"], {\"ntype\"})\n\n    # Test edges\n    nxg_edges = list(nxg.edges(data=True))\n    assert len(nxg_edges) == g.num_edges()\n    assert {edge_attrs(e)[\"etype\"] for e in nxg_edges} == set(\n        g.canonical_etypes\n    )\n\n    nxg_edges_by_etype = {}\n    for etype in g.canonical_etypes:\n        nxg_edges_by_etype[etype] = get_edges_by_etype(nxg_edges, etype)\n        assert g.num_edges(etype) == len(nxg_edges_by_etype[etype])\n\n    assert check_attrs_for_edges(\n        nxg_edges_by_etype[(\"user\", \"follows\", \"user\")],\n        {\"id\", \"etype\", \"e\"},\n    )\n    assert check_attr_values_for_edges(\n        nxg_edges_by_etype[(\"user\", \"follows\", \"user\")], \"e\", e1\n    )\n    assert check_attrs_for_edges(\n        nxg_edges_by_etype[(\"user\", \"plays\", \"game\")], {\"id\", \"etype\", \"e\"}\n    )\n    assert check_attr_values_for_edges(\n        nxg_edges_by_etype[(\"user\", \"plays\", \"game\")], \"e\", e2\n    )\n    # Edges without edge attributes\n    
assert check_attrs_for_edges(\n        nxg_edges_by_etype[(\"user\", \"follows\", \"topic\")], {\"id\", \"etype\"}\n    )\n"
  },
  {
    "path": "tests/python/common/test_ffi.py",
    "content": "import os\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport numpy as np\nimport pytest\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Cython only works on linux\")\ndef test_cython():\n    import dgl._ffi._cy3.core\n\n\n@pytest.mark.parametrize(\"arg\", [1, 2.3])\ndef test_callback(arg):\n    def cb(x):\n        return x + 1\n\n    ret = dgl._api_internal._TestPythonCallback(cb, arg)\n    assert ret == arg + 1\n\n\n@pytest.mark.parametrize(\"dtype\", [F.float32, F.float64, F.int32, F.int64])\ndef _test_callback_array(dtype):\n    def cb(x):\n        return F.to_dgl_nd(F.from_dgl_nd(x) + 1)\n\n    arg = F.copy_to(F.tensor([1, 2, 3], dtype=dtype), F.ctx())\n    ret = F.from_dgl_nd(\n        dgl._api_internal._TestPythonCallback(cb, F.to_dgl_nd(arg))\n    )\n    assert np.allclose(F.asnumpy(ret), F.asnumpy(arg) + 1)\n\n\n@pytest.mark.parametrize(\"arg\", [1, 2.3])\ndef test_callback_thread(arg):\n    def cb(x):\n        return x + 1\n\n    ret = dgl._api_internal._TestPythonCallbackThread(cb, arg)\n    assert ret == arg + 1\n\n\n@pytest.mark.parametrize(\"dtype\", [F.float32, F.float64, F.int32, F.int64])\ndef _test_callback_array_thread(dtype):\n    def cb(x):\n        return F.to_dgl_nd(F.from_dgl_nd(x) + 1)\n\n    arg = F.copy_to(F.tensor([1, 2, 3], dtype=dtype), F.ctx())\n    ret = F.from_dgl_nd(\n        dgl._api_internal._TestPythonCallbackThread(cb, F.to_dgl_nd(arg))\n    )\n    assert np.allclose(F.asnumpy(ret), F.asnumpy(arg) + 1)\n"
  },
  {
    "path": "tests/python/common/test_frame.py",
    "content": "import pickle\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport dgl.ndarray as nd\nimport numpy as np\nfrom dgl.frame import Column\nfrom utils import parametrize_idtype\n\n\ndef test_column_subcolumn():\n    data = F.copy_to(\n        F.tensor(\n            [\n                [1.0, 1.0, 1.0, 1.0],\n                [0.0, 2.0, 9.0, 0.0],\n                [3.0, 2.0, 1.0, 0.0],\n                [1.0, 1.0, 1.0, 1.0],\n                [0.0, 2.0, 4.0, 0.0],\n            ]\n        ),\n        F.ctx(),\n    )\n    original = Column(data)\n\n    # subcolumn from cpu context\n    i1 = F.tensor([0, 2, 1, 3], dtype=F.int64)\n    l1 = original.subcolumn(i1)\n\n    assert len(l1) == i1.shape[0]\n    assert F.array_equal(l1.data, F.gather_row(data, i1))\n\n    # next subcolumn from target context\n    i2 = F.copy_to(F.tensor([0, 2], dtype=F.int64), F.ctx())\n    l2 = l1.subcolumn(i2)\n\n    assert len(l2) == i2.shape[0]\n    i1i2 = F.copy_to(F.gather_row(i1, F.copy_to(i2, F.context(i1))), F.ctx())\n    assert F.array_equal(l2.data, F.gather_row(data, i1i2))\n\n    # next subcolumn also from target context\n    i3 = F.copy_to(F.tensor([1], dtype=F.int64), F.ctx())\n    l3 = l2.subcolumn(i3)\n\n    assert len(l3) == i3.shape[0]\n    i1i2i3 = F.copy_to(\n        F.gather_row(i1i2, F.copy_to(i3, F.context(i1i2))), F.ctx()\n    )\n    assert F.array_equal(l3.data, F.gather_row(data, i1i2i3))\n\n\ndef test_serialize_deserialize_plain():\n    data = F.copy_to(\n        F.tensor(\n            [\n                [1.0, 1.0, 1.0, 1.0],\n                [0.0, 2.0, 9.0, 0.0],\n                [3.0, 2.0, 1.0, 0.0],\n                [1.0, 1.0, 1.0, 1.0],\n                [0.0, 2.0, 4.0, 0.0],\n            ]\n        ),\n        F.ctx(),\n    )\n    original = Column(data)\n\n    serial = pickle.dumps(original)\n    new = pickle.loads(serial)\n    print(\"new = {}\".format(new))\n\n    assert F.array_equal(new.data, original.data)\n\n\ndef 
test_serialize_deserialize_subcolumn():\n    data = F.copy_to(\n        F.tensor(\n            [\n                [1.0, 1.0, 1.0, 1.0],\n                [0.0, 2.0, 9.0, 0.0],\n                [3.0, 2.0, 1.0, 0.0],\n                [1.0, 1.0, 1.0, 1.0],\n                [0.0, 2.0, 4.0, 0.0],\n            ]\n        ),\n        F.ctx(),\n    )\n    original = Column(data)\n\n    # subcolumn from cpu context\n    i1 = F.tensor([0, 2, 1, 3], dtype=F.int64)\n    l1 = original.subcolumn(i1)\n\n    serial = pickle.dumps(l1)\n    new = pickle.loads(serial)\n\n    assert F.array_equal(new.data, l1.data)\n\n\ndef test_serialize_deserialize_dtype():\n    data = F.copy_to(\n        F.tensor(\n            [\n                [1.0, 1.0, 1.0, 1.0],\n                [0.0, 2.0, 9.0, 0.0],\n                [3.0, 2.0, 1.0, 0.0],\n                [1.0, 1.0, 1.0, 1.0],\n                [0.0, 2.0, 4.0, 0.0],\n            ]\n        ),\n        F.ctx(),\n    )\n    original = Column(data)\n    original = original.astype(F.int64)\n\n    serial = pickle.dumps(original)\n    new = pickle.loads(serial)\n\n    assert new.dtype == F.int64\n"
  },
  {
    "path": "tests/python/common/test_generators.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nimport numpy as np\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"GPU random choice not implemented\"\n)\ndef test_rand_graph():\n    g = dgl.rand_graph(10000, 100000)\n    assert g.num_nodes() == 10000\n    assert g.num_edges() == 100000\n    # test random seed\n    dgl.random.seed(42)\n    g1 = dgl.rand_graph(100, 30)\n    dgl.random.seed(42)\n    g2 = dgl.rand_graph(100, 30)\n    u1, v1 = g1.edges()\n    u2, v2 = g2.edges()\n    assert F.array_equal(u1, u2)\n    assert F.array_equal(v1, v2)\n\n\nif __name__ == \"__main__\":\n    test_rand_graph()\n"
  },
  {
    "path": "tests/python/common/test_heterograph-apply-edges.py",
    "content": "import itertools\nimport unittest\nfrom collections import Counter\nfrom itertools import product\n\nimport backend as F\n\nimport dgl\nimport dgl.function as fn\nimport networkx as nx\nimport numpy as np\nimport pytest\nimport scipy.sparse as spsp\nimport torch\n\nfrom dgl import DGLError\nfrom scipy.sparse import rand\nfrom utils import get_cases, parametrize_idtype\n\nrfuncs = {\"sum\": fn.sum, \"max\": fn.max, \"min\": fn.min, \"mean\": fn.mean}\nfill_value = {\"sum\": 0, \"max\": float(\"-inf\")}\nfeat_size = 2\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\ndef create_test_heterograph(idtype):\n    # test heterograph from the docstring, plus a user -- wishes -- game relation\n    # 3 users, 2 games, 2 developers\n    # metagraph:\n    #    ('user', 'follows', 'user'),\n    #    ('user', 'plays', 'game'),\n    #    ('user', 'wishes', 'game'),\n    #    ('developer', 'develops', 'game')])\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): ([0, 1, 1], [0, 0, 1]),\n            (\"developer\", \"develops\", \"game\"): ([0, 1, 0], [0, 1, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    return g\n\n\ndef create_random_hetero_with_single_source_node_type(idtype):\n    num_nodes = {\"n1\": 5, \"n2\": 10, \"n3\": 15}\n    etypes = [(\"n1\", \"r1\", \"n2\"), (\"n1\", \"r2\", \"n3\"), (\"n1\", \"r3\", \"n2\")]\n    edges = {}\n    for etype in etypes:\n        src_ntype, _, dst_ntype = etype\n        arr = spsp.random(\n            num_nodes[src_ntype],\n            num_nodes[dst_ntype],\n            density=1,\n            format=\"coo\",\n            random_state=100,\n        )\n        edges[etype] = (arr.row, 
arr.col)\n    return dgl.heterograph(edges, idtype=idtype, device=F.ctx())\n\n\n@parametrize_idtype\ndef test_unary_copy_u(idtype):\n    def _test(mfunc):\n        g = create_test_heterograph(idtype)\n\n        x1 = F.randn((g.num_nodes(\"user\"), feat_size))\n        x2 = F.randn((g.num_nodes(\"developer\"), feat_size))\n\n        F.attach_grad(x1)\n        F.attach_grad(x2)\n        g.nodes[\"user\"].data[\"h\"] = x1\n        g.nodes[\"developer\"].data[\"h\"] = x2\n\n        #################################################################\n        #  apply_edges() is called on each relation type separately\n        #################################################################\n\n        with F.record_grad():\n            [\n                g.apply_edges(fn.copy_u(\"h\", \"m\"), etype=rel)\n                for rel in g.canonical_etypes\n            ]\n            r1 = g[\"plays\"].edata[\"m\"]\n            F.backward(r1, F.ones(r1.shape))\n            n_grad1 = F.grad(g.ndata[\"h\"][\"user\"])\n        # TODO (Israt): clear not working\n        g.edata[\"m\"].clear()\n\n        #################################################################\n        #  apply_edges() is called on all relation types\n        #################################################################\n\n        g.apply_edges(fn.copy_u(\"h\", \"m\"))\n        r2 = g[\"plays\"].edata[\"m\"]\n        F.backward(r2, F.ones(r2.shape))\n        n_grad2 = F.grad(g.nodes[\"user\"].data[\"h\"])\n\n        # correctness check\n        def _print_error(a, b):\n            for i, (x, y) in enumerate(\n                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())\n            ):\n                if not np.allclose(x, y):\n                    print(\"@{} {} v.s. 
{}\".format(i, x, y))\n\n        if not F.allclose(r1, r2):\n            _print_error(r1, r2)\n        assert F.allclose(r1, r2)\n        if not F.allclose(n_grad1, n_grad2):\n            print(\"node grad\")\n            _print_error(n_grad1, n_grad2)\n        assert F.allclose(n_grad1, n_grad2)\n\n    _test(fn.copy_u)\n\n\n@parametrize_idtype\ndef test_unary_copy_e(idtype):\n    def _test(mfunc):\n        g = create_test_heterograph(idtype)\n        feat_size = 2\n\n        x1 = F.randn((4, feat_size))\n        x2 = F.randn((4, feat_size))\n        x3 = F.randn((3, feat_size))\n        x4 = F.randn((3, feat_size))\n        F.attach_grad(x1)\n        F.attach_grad(x2)\n        F.attach_grad(x3)\n        F.attach_grad(x4)\n        g[\"plays\"].edata[\"eid\"] = x1\n        g[\"follows\"].edata[\"eid\"] = x2\n        g[\"develops\"].edata[\"eid\"] = x3\n        g[\"wishes\"].edata[\"eid\"] = x4\n\n        #################################################################\n        #  apply_edges() is called on each relation type separately\n        #################################################################\n        with F.record_grad():\n            [\n                g.apply_edges(fn.copy_e(\"eid\", \"m\"), etype=rel)\n                for rel in g.canonical_etypes\n            ]\n            r1 = g[\"develops\"].edata[\"m\"]\n            F.backward(r1, F.ones(r1.shape))\n            e_grad1 = F.grad(g[\"develops\"].edata[\"eid\"])\n\n        #################################################################\n        #  apply_edges() is called on all relation types\n        #################################################################\n\n        g.apply_edges(fn.copy_e(\"eid\", \"m\"))\n        r2 = g[\"develops\"].edata[\"m\"]\n        F.backward(r2, F.ones(r2.shape))\n        e_grad2 = F.grad(g[\"develops\"].edata[\"eid\"])\n\n        # # correctness check\n        def _print_error(a, b):\n            for i, (x, y) in enumerate(\n                
zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())\n            ):\n                if not np.allclose(x, y):\n                    print(\"@{} {} v.s. {}\".format(i, x, y))\n\n        if not F.allclose(r1, r2):\n            _print_error(r1, r2)\n        assert F.allclose(r1, r2)\n        if not F.allclose(e_grad1, e_grad2):\n            print(\"edge grad\")\n            _print_error(e_grad1, e_grad2)\n        assert F.allclose(e_grad1, e_grad2)\n\n    _test(fn.copy_e)\n\n\n@parametrize_idtype\ndef test_binary_op(idtype):\n    def _test(lhs, rhs, binary_op):\n        g = create_test_heterograph(idtype)\n\n        n1 = F.randn((g.num_nodes(\"user\"), feat_size))\n        n2 = F.randn((g.num_nodes(\"developer\"), feat_size))\n        n3 = F.randn((g.num_nodes(\"game\"), feat_size))\n\n        x1 = F.randn((g.num_edges(\"plays\"), feat_size))\n        x2 = F.randn((g.num_edges(\"follows\"), feat_size))\n        x3 = F.randn((g.num_edges(\"develops\"), feat_size))\n        x4 = F.randn((g.num_edges(\"wishes\"), feat_size))\n\n        builtin_msg_name = \"{}_{}_{}\".format(lhs, binary_op, rhs)\n        builtin_msg = getattr(fn, builtin_msg_name)\n\n        #################################################################\n        #  apply_edges() is called on each relation type separately\n        #################################################################\n\n        F.attach_grad(n1)\n        F.attach_grad(n2)\n        F.attach_grad(n3)\n        g.nodes[\"user\"].data[\"h\"] = n1\n        g.nodes[\"developer\"].data[\"h\"] = n2\n        g.nodes[\"game\"].data[\"h\"] = n3\n        F.attach_grad(x1)\n        F.attach_grad(x2)\n        F.attach_grad(x3)\n        F.attach_grad(x4)\n        g[\"plays\"].edata[\"h\"] = x1\n        g[\"follows\"].edata[\"h\"] = x2\n        g[\"develops\"].edata[\"h\"] = x3\n        g[\"wishes\"].edata[\"h\"] = x4\n\n        with F.record_grad():\n            [\n                g.apply_edges(builtin_msg(\"h\", \"h\", \"m\"), etype=rel)\n 
               for rel in g.canonical_etypes\n            ]\n            r1 = g[\"plays\"].edata[\"m\"]\n            loss = F.sum(r1.view(-1), 0)\n            F.backward(loss)\n            n_grad1 = F.grad(g.nodes[\"game\"].data[\"h\"])\n\n        #################################################################\n        #  apply_edges() is called on all relation types\n        #################################################################\n\n        F.attach_grad(n1)\n        F.attach_grad(n2)\n        F.attach_grad(n3)\n        g.nodes[\"user\"].data[\"h\"] = n1\n        g.nodes[\"developer\"].data[\"h\"] = n2\n        g.nodes[\"game\"].data[\"h\"] = n3\n        F.attach_grad(x1)\n        F.attach_grad(x2)\n        F.attach_grad(x3)\n        F.attach_grad(x4)\n        g[\"plays\"].edata[\"h\"] = x1\n        g[\"follows\"].edata[\"h\"] = x2\n        g[\"develops\"].edata[\"h\"] = x3\n        g[\"wishes\"].edata[\"h\"] = x4\n\n        with F.record_grad():\n            g.apply_edges(builtin_msg(\"h\", \"h\", \"m\"))\n            r2 = g[\"plays\"].edata[\"m\"]\n            loss = F.sum(r2.view(-1), 0)\n            F.backward(loss)\n            n_grad2 = F.grad(g.nodes[\"game\"].data[\"h\"])\n\n        # correctness check\n        def _print_error(a, b):\n            for i, (x, y) in enumerate(\n                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())\n            ):\n                if not np.allclose(x, y):\n                    print(\"@{} {} v.s. 
{}\".format(i, x, y))\n\n        if not F.allclose(r1, r2):\n            _print_error(r1, r2)\n        assert F.allclose(r1, r2)\n        if n_grad1 is not None or n_grad2 is not None:\n            if not F.allclose(n_grad1, n_grad2):\n                print(\"node grad\")\n                _print_error(n_grad1, n_grad2)\n            assert F.allclose(n_grad1, n_grad2)\n\n    target = [\"u\", \"v\", \"e\"]\n    for lhs, rhs in product(target, target):\n        if lhs == rhs:\n            continue\n        for binary_op in [\"add\", \"sub\", \"mul\", \"div\", \"dot\"]:\n            print(lhs, rhs, binary_op)\n            _test(lhs, rhs, binary_op)\n\n\n# Here we test heterograph with only single source node type because the format\n# of node feature is a tensor.\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_heterograph_with_single_source_node_type_apply_edges(idtype):\n    hg = create_random_hetero_with_single_source_node_type(idtype)\n\n    hg.nodes[\"n1\"].data[\"h\"] = F.randn((hg.num_nodes(\"n1\"), 1))\n    hg.nodes[\"n2\"].data[\"h\"] = F.randn((hg.num_nodes(\"n2\"), 1))\n    hg.nodes[\"n3\"].data[\"h\"] = F.randn((hg.num_nodes(\"n3\"), 1))\n\n    assert type(hg.srcdata[\"h\"]) == torch.Tensor\n    hg.apply_edges(fn.u_add_v(\"h\", \"h\", \"x\"))\n\n\nif __name__ == \"__main__\":\n    test_unary_copy_u()\n    test_unary_copy_e()\n"
  },
  {
    "path": "tests/python/common/test_heterograph-index.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nimport pytest\nfrom dgl import DGLError\nfrom utils import parametrize_idtype\n\n\ndef create_test_heterograph(idtype):\n    # 3 users, 2 games, 2 developers\n    # metagraph:\n    #    ('user', 'follows', 'user'),\n    #    ('user', 'plays', 'game'),\n    #    ('user', 'wishes', 'game'),\n    #    ('developer', 'develops', 'game')])\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): ([0, 2], [1, 0]),\n            (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    return g\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"Need gpu for this test\"\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\",\n    reason=\"Pinning graph outplace only supported for PyTorch\",\n)\n@parametrize_idtype\ndef test_pin_memory(idtype):\n    g = create_test_heterograph(idtype)\n    g.nodes[\"user\"].data[\"h\"] = F.ones((3, 5))\n    g.nodes[\"game\"].data[\"i\"] = F.ones((2, 5))\n    g.edges[\"plays\"].data[\"e\"] = F.ones((4, 4))\n    g = g.to(F.cpu())\n    assert not g.is_pinned()\n\n    # Test pinning a CPU graph.\n    g._graph.pin_memory()\n    assert not g.is_pinned()\n    g._graph = g._graph.pin_memory()\n    assert g.is_pinned()\n    assert g.device == F.cpu()\n\n    # when clone with a new (different) formats, e.g., g.formats(\"csc\")\n    # ensure the new graphs are not pinned\n    assert not g.formats(\"csc\").is_pinned()\n    assert not g.formats(\"csr\").is_pinned()\n    # 'coo' formats is the default and thus not cloned\n    assert g.formats(\"coo\").is_pinned()\n\n    # Test pinning a GPU graph will cause error raised.\n    g1 = g.to(F.cuda())\n    
with pytest.raises(DGLError):\n        g1._graph.pin_memory()\n\n    # Test pinning an empty homograph\n    g2 = dgl.graph(([], []))\n    assert not g2.is_pinned()\n    g2._graph = g2._graph.pin_memory()\n    assert g2.is_pinned()\n\n    # Test pinning heterograph with 0 edge of one relation type\n    g3 = dgl.heterograph(\n        {(\"a\", \"b\", \"c\"): ([0, 1], [1, 2]), (\"c\", \"d\", \"c\"): ([], [])}\n    ).astype(idtype)\n    g3._graph = g3._graph.pin_memory()\n    assert g3.is_pinned()\n\n\nif __name__ == \"__main__\":\n    pass\n"
  },
  {
    "path": "tests/python/common/test_heterograph-kernel.py",
    "content": "from itertools import product\n\nimport backend as F\n\nimport dgl\nimport dgl.function as fn\nimport networkx as nx\nimport numpy as np\nimport pytest\nfrom utils import get_cases, parametrize_idtype\n\n\ndef udf_copy_src(edges):\n    return {\"m\": edges.src[\"u\"]}\n\n\ndef udf_copy_edge(edges):\n    return {\"m\": edges.data[\"e\"]}\n\n\ndef udf_mean(nodes):\n    return {\"r2\": F.mean(nodes.mailbox[\"m\"], 1)}\n\n\ndef udf_sum(nodes):\n    return {\"r2\": F.sum(nodes.mailbox[\"m\"], 1)}\n\n\ndef udf_max(nodes):\n    return {\"r2\": F.max(nodes.mailbox[\"m\"], 1)}\n\n\nD1 = 5\nD2 = 3\nD3 = 4\nD4 = 10  # NOTE(xiang): used to dot feature vector\nbuiltin = {\"sum\": fn.sum, \"max\": fn.max, \"mean\": fn.mean}\nudf_reduce = {\"sum\": udf_sum, \"max\": udf_max, \"mean\": udf_mean}\nfill_value = {\"sum\": 0, \"max\": float(\"-inf\")}\n\n\ndef generate_feature(g, broadcast=\"none\", binary_op=\"none\"):\n    \"\"\"Create graph with src, edge, dst feature. broadcast can be 'u',\n    'e', 'v', 'none'\n    \"\"\"\n    np.random.seed(31)\n    nv = g.num_nodes()\n    ne = g.num_edges()\n    if binary_op == \"dot\":\n        if broadcast == \"e\":\n            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))\n            e = F.tensor(np.random.uniform(-1, 1, (ne, D2, 1, D4)))\n            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))\n        elif broadcast == \"u\":\n            u = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1, D4)))\n            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))\n            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))\n        elif broadcast == \"v\":\n            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))\n            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))\n            v = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1, D4)))\n        else:\n            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))\n            e = 
F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))\n            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))\n    else:\n        if broadcast == \"e\":\n            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))\n            e = F.tensor(np.random.uniform(-1, 1, (ne, D2, 1)))\n            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))\n        elif broadcast == \"u\":\n            u = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1)))\n            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))\n            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))\n        elif broadcast == \"v\":\n            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))\n            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))\n            v = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1)))\n        else:\n            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))\n            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))\n            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))\n    return (\n        F.astype(u, F.float32),\n        F.astype(v, F.float32),\n        F.astype(e, F.float32),\n    )\n\n\ndef test_copy_src_reduce():\n    def _test(red, partial):\n        g = dgl.from_networkx(nx.erdos_renyi_graph(100, 0.1))\n        # NOTE(zihao): add self-loop to avoid zero-degree nodes.\n        # https://github.com/dmlc/dgl/issues/761\n        g.add_edges(g.nodes(), g.nodes())\n        g = g.to(F.ctx())\n        hu, hv, he = generate_feature(g, \"none\", \"none\")\n        if partial:\n            nid = F.tensor(list(range(0, 100, 2)), g.idtype)\n\n        g.ndata[\"u\"] = F.attach_grad(F.clone(hu))\n        g.ndata[\"v\"] = F.attach_grad(F.clone(hv))\n        g.edata[\"e\"] = F.attach_grad(F.clone(he))\n\n        with F.record_grad():\n            if partial:\n                g.pull(\n                    nid,\n                    fn.copy_u(u=\"u\", out=\"m\"),\n                  
  builtin[red](msg=\"m\", out=\"r1\"),\n                )\n            else:\n                g.update_all(\n                    fn.copy_u(u=\"u\", out=\"m\"), builtin[red](msg=\"m\", out=\"r1\")\n                )\n            r1 = g.ndata[\"r1\"]\n            F.backward(F.reduce_sum(r1))\n            n_grad1 = F.grad(g.ndata[\"u\"])\n\n        # reset grad\n        g.ndata[\"u\"] = F.attach_grad(F.clone(hu))\n        g.ndata[\"v\"] = F.attach_grad(F.clone(hv))\n        g.edata[\"e\"] = F.attach_grad(F.clone(he))\n\n        with F.record_grad():\n            if partial:\n                g.pull(nid, udf_copy_src, udf_reduce[red])\n            else:\n                g.update_all(udf_copy_src, udf_reduce[red])\n            r2 = g.ndata[\"r2\"]\n            F.backward(F.reduce_sum(r2))\n            n_grad2 = F.grad(g.ndata[\"u\"])\n\n        def _print_error(a, b):\n            print(\"ERROR: Test copy_src_{} partial: {}\".format(red, partial))\n            for i, (x, y) in enumerate(\n                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())\n            ):\n                if not np.allclose(x, y):\n                    print(\"@{} {} v.s. 
{}\".format(i, x, y))\n\n        if not F.allclose(r1, r2):\n            _print_error(r1, r2)\n        assert F.allclose(r1, r2)\n        if not F.allclose(n_grad1, n_grad2):\n            print(\"node grad\")\n            _print_error(n_grad1, n_grad2)\n        assert F.allclose(n_grad1, n_grad2)\n\n    _test(\"sum\", False)\n    _test(\"max\", False)\n    _test(\"mean\", False)\n    _test(\"sum\", True)\n    _test(\"max\", True)\n    _test(\"mean\", True)\n\n\ndef test_copy_edge_reduce():\n    def _test(red, partial):\n        g = dgl.from_networkx(nx.erdos_renyi_graph(100, 0.1))\n        # NOTE(zihao): add self-loop to avoid zero-degree nodes.\n        g.add_edges(g.nodes(), g.nodes())\n        g = g.to(F.ctx())\n        hu, hv, he = generate_feature(g, \"none\", \"none\")\n        if partial:\n            nid = F.tensor(list(range(0, 100, 2)), g.idtype)\n\n        g.ndata[\"u\"] = F.attach_grad(F.clone(hu))\n        g.ndata[\"v\"] = F.attach_grad(F.clone(hv))\n        g.edata[\"e\"] = F.attach_grad(F.clone(he))\n\n        with F.record_grad():\n            if partial:\n                g.pull(\n                    nid,\n                    fn.copy_e(e=\"e\", out=\"m\"),\n                    builtin[red](msg=\"m\", out=\"r1\"),\n                )\n            else:\n                g.update_all(\n                    fn.copy_e(e=\"e\", out=\"m\"), builtin[red](msg=\"m\", out=\"r1\")\n                )\n            r1 = g.ndata[\"r1\"]\n            F.backward(F.reduce_sum(r1))\n            e_grad1 = F.grad(g.edata[\"e\"])\n\n        # reset grad\n        g.ndata[\"u\"] = F.attach_grad(F.clone(hu))\n        g.ndata[\"v\"] = F.attach_grad(F.clone(hv))\n        g.edata[\"e\"] = F.attach_grad(F.clone(he))\n\n        with F.record_grad():\n            if partial:\n                g.pull(nid, udf_copy_edge, udf_reduce[red])\n            else:\n                g.update_all(udf_copy_edge, udf_reduce[red])\n            r2 = g.ndata[\"r2\"]\n            
F.backward(F.reduce_sum(r2))\n            e_grad2 = F.grad(g.edata[\"e\"])\n\n        def _print_error(a, b):\n            print(\"ERROR: Test copy_edge_{} partial: {}\".format(red, partial))\n            return\n            for i, (x, y) in enumerate(\n                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())\n            ):\n                if not np.allclose(x, y):\n                    print(\"@{} {} v.s. {}\".format(i, x, y))\n\n        if not F.allclose(r1, r2):\n            _print_error(r1, r2)\n        assert F.allclose(r1, r2)\n        if not F.allclose(e_grad1, e_grad2):\n            print(\"edge gradient\")\n            _print_error(e_grad1, e_grad2)\n        assert F.allclose(e_grad1, e_grad2)\n\n    _test(\"sum\", False)\n    _test(\"max\", False)\n    _test(\"mean\", False)\n    _test(\"sum\", True)\n    _test(\"max\", True)\n    _test(\"mean\", True)\n\n\ndef test_all_binary_builtins():\n    def _test(g, lhs, rhs, binary_op, reducer, partial, nid, broadcast=\"none\"):\n        # initialize node/edge features with uniform(-1, 1)\n        hu, hv, he = generate_feature(g, broadcast, binary_op)\n        if binary_op == \"div\":\n            # op = div\n            # lhs range: [-1, 1]\n            # rhs range: [1, 2]\n            # result range: [-1, 1]\n            if rhs == \"u\":\n                hu = (hu + 3) / 2\n            elif rhs == \"v\":\n                hv = (hv + 3) / 2\n            elif rhs == \"e\":\n                he = (he + 3) / 2\n\n        if binary_op == \"add\" or binary_op == \"sub\":\n            # op = add, sub\n            # lhs range: [-1/2, 1/2]\n            # rhs range: [-1/2, 1/2]\n            # result range: [-1, 1]\n            hu = hu / 2\n            hv = hv / 2\n            he = he / 2\n\n        g.ndata[\"u\"] = F.attach_grad(F.clone(hu))\n        g.ndata[\"v\"] = F.attach_grad(F.clone(hv))\n        g.edata[\"e\"] = F.attach_grad(F.clone(he))\n\n        builtin_msg_name = \"{}_{}_{}\".format(lhs, binary_op, 
rhs)\n        builtin_msg = getattr(fn, builtin_msg_name)\n        builtin_red = getattr(fn, reducer)\n\n        def target_feature_switch(g, target):\n            if target == \"u\":\n                return g.ndata[\"u\"]\n            elif target == \"v\":\n                return g.ndata[\"v\"]\n            else:\n                return g.edata[\"e\"]\n\n        with F.record_grad():\n            if partial:\n                g.pull(nid, builtin_msg(lhs, rhs, \"m\"), builtin_red(\"m\", \"r1\"))\n            else:\n                g.update_all(builtin_msg(lhs, rhs, \"m\"), builtin_red(\"m\", \"r1\"))\n            r1 = g.ndata.pop(\"r1\")\n            F.backward(F.reduce_sum(r1))\n            lhs_grad_1 = F.grad(target_feature_switch(g, lhs))\n            rhs_grad_1 = F.grad(target_feature_switch(g, rhs))\n\n        # reset grad\n        g.ndata[\"u\"] = F.attach_grad(F.clone(hu))\n        g.ndata[\"v\"] = F.attach_grad(F.clone(hv))\n        g.edata[\"e\"] = F.attach_grad(F.clone(he))\n\n        def target_switch(edges, target):\n            if target == \"u\":\n                return edges.src\n            elif target == \"v\":\n                return edges.dst\n            elif target == \"e\":\n                return edges.data\n            else:\n                assert 0, \"Unknown target {}\".format(target)\n\n        def mfunc(edges):\n            op = getattr(F, binary_op)\n            lhs_data = target_switch(edges, lhs)[lhs]\n            rhs_data = target_switch(edges, rhs)[rhs]\n            # NOTE(zihao): we need to do batched broadcast\n            # e.g. 
(68, 3, 1) op (68, 5, 3, 4)\n            while F.ndim(lhs_data) < F.ndim(rhs_data):\n                lhs_data = F.unsqueeze(lhs_data, 1)\n            while F.ndim(rhs_data) < F.ndim(lhs_data):\n                rhs_data = F.unsqueeze(rhs_data, 1)\n            return {\"m\": op(lhs_data, rhs_data)}\n\n        def rfunc(nodes):\n            op = getattr(F, reducer)\n            return {\"r2\": op(nodes.mailbox[\"m\"], 1)}\n\n        with F.record_grad():\n            if partial:\n                g.pull(nid, mfunc, rfunc)\n            else:\n                g.update_all(mfunc, rfunc)\n            r2 = g.ndata.pop(\"r2\")\n            F.backward(F.reduce_sum(r2), F.tensor([1.0]))\n            lhs_grad_2 = F.grad(target_feature_switch(g, lhs))\n            rhs_grad_2 = F.grad(target_feature_switch(g, rhs))\n\n        rtol = 1e-4\n        atol = 1e-4\n\n        def _print_error(a, b):\n            print(\n                \"ERROR: Test {}_{}_{}_{} broadcast: {} partial: {}\".format(\n                    lhs, binary_op, rhs, reducer, broadcast, partial\n                )\n            )\n            return\n            if lhs == \"u\":\n                lhs_data = hu\n            elif lhs == \"v\":\n                lhs_data = hv\n            elif lhs == \"e\":\n                lhs_data = he\n\n            if rhs == \"u\":\n                rhs_data = hu\n            elif rhs == \"v\":\n                rhs_data = hv\n            elif rhs == \"e\":\n                rhs_data = he\n            print(\"lhs\", F.asnumpy(lhs_data).tolist())\n            print(\"rhs\", F.asnumpy(rhs_data).tolist())\n            for i, (x, y) in enumerate(\n                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())\n            ):\n                if not np.allclose(x, y, rtol, atol):\n                    print(\"@{} {} v.s. 
{}\".format(i, x, y))\n\n        if not F.allclose(r1, r2, rtol, atol):\n            _print_error(r1, r2)\n        assert F.allclose(r1, r2, rtol, atol)\n\n        if not F.allclose(lhs_grad_1, lhs_grad_2, rtol, atol):\n            print(\"left grad\")\n            _print_error(lhs_grad_1, lhs_grad_2)\n        assert F.allclose(lhs_grad_1, lhs_grad_2, rtol, atol)\n\n        if not F.allclose(rhs_grad_1, rhs_grad_2, rtol, atol):\n            print(\"right grad\")\n            _print_error(rhs_grad_1, rhs_grad_2)\n        assert F.allclose(rhs_grad_1, rhs_grad_2, rtol, atol)\n\n    g = dgl.graph([])\n    g.add_nodes(20)\n    # NOTE(zihao): add self-loop to avoid zero-degree nodes.\n    g.add_edges(g.nodes(), g.nodes())\n    for i in range(2, 18):\n        g.add_edges(0, i)\n        g.add_edges(1, i)\n        g.add_edges(i, 18)\n        g.add_edges(i, 19)\n    g.add_edges(18, 0)\n    g.add_edges(18, 1)\n    g.add_edges(19, 0)\n    g.add_edges(19, 1)\n    g = g.to(F.ctx())\n    nid = F.tensor([0, 1, 4, 5, 7, 12, 14, 15, 18, 19], g.idtype)\n    target = [\"u\", \"v\", \"e\"]\n\n    for lhs, rhs in product(target, target):\n        if lhs == rhs:\n            continue\n        for binary_op in [\"add\", \"sub\", \"mul\", \"div\"]:\n            for reducer in [\"sum\", \"max\", \"min\", \"mean\"]:\n                for broadcast in [\"none\", lhs, rhs]:\n                    for partial in [False, True]:\n                        print(lhs, rhs, binary_op, reducer, broadcast, partial)\n                        _test(\n                            g,\n                            lhs,\n                            rhs,\n                            binary_op,\n                            reducer,\n                            partial,\n                            nid,\n                            broadcast=broadcast,\n                        )\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo-zero-degree\"]))\ndef test_mean_zero_degree(g, idtype):\n    g = 
g.astype(idtype).to(F.ctx())\n    g.ndata[\"h\"] = F.ones((g.num_nodes(), 3))\n    g.update_all(fn.copy_u(\"h\", \"m\"), fn.mean(\"m\", \"x\"))\n    deg = F.asnumpy(g.in_degrees())\n    v = F.tensor(np.where(deg == 0)[0])\n    assert F.allclose(F.gather_row(g.ndata[\"x\"], v), F.zeros((len(v), 3)))\n\n\nif __name__ == \"__main__\":\n    test_copy_src_reduce()\n    test_copy_edge_reduce()\n    test_all_binary_builtins()\n"
  },
  {
    "path": "tests/python/common/test_heterograph-misc.py",
    "content": "import math\nimport numbers\n\nimport backend as F\n\nimport dgl\nimport networkx as nx\nimport numpy as np\nimport pytest\nimport scipy.sparse as sp\nfrom dgl import DGLError\n\n\n# graph generation: a random graph with 10 nodes\n#  and 20 edges.\n#  - has self loop\n#  - no multi edge\ndef edge_pair_input(sort=False):\n    if sort:\n        src = [0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 6, 7, 7, 7, 9]\n        dst = [4, 6, 9, 3, 5, 3, 7, 5, 8, 1, 3, 4, 9, 1, 9, 6, 2, 8, 9, 2]\n        return src, dst\n    else:\n        src = [0, 0, 4, 5, 0, 4, 7, 4, 4, 3, 2, 7, 7, 5, 3, 2, 1, 9, 6, 1]\n        dst = [9, 6, 3, 9, 4, 4, 9, 9, 1, 8, 3, 2, 8, 1, 5, 7, 3, 2, 6, 5]\n        return src, dst\n\n\ndef nx_input():\n    g = nx.DiGraph()\n    src, dst = edge_pair_input()\n    for i, e in enumerate(zip(src, dst)):\n        g.add_edge(*e, id=i)\n    return g\n\n\ndef elist_input():\n    src, dst = edge_pair_input()\n    return list(zip(src, dst))\n\n\ndef scipy_coo_input():\n    src, dst = edge_pair_input()\n    return sp.coo_matrix((np.ones((20,)), (src, dst)), shape=(10, 10))\n\n\ndef scipy_csr_input():\n    src, dst = edge_pair_input()\n    csr = sp.coo_matrix((np.ones((20,)), (src, dst)), shape=(10, 10)).tocsr()\n    csr.sort_indices()\n    # src = [0 0 0 1 1 2 2 3 3 4 4 4 4 5 5 6 7 7 7 9]\n    # dst = [4 6 9 3 5 3 7 5 8 1 3 4 9 1 9 6 2 8 9 2]\n    return csr\n\n\ndef gen_by_mutation():\n    g = dgl.graph([])\n    src, dst = edge_pair_input()\n    g.add_nodes(10)\n    g.add_edges(src, dst)\n    return g\n\n\ndef test_query():\n    def _test_one(g):\n        assert g.num_nodes() == 10\n        assert g.num_edges() == 20\n\n        for i in range(10):\n            assert g.has_nodes(i)\n        assert not g.has_nodes(11)\n        assert F.allclose(g.has_nodes([0, 2, 10, 11]), F.tensor([1, 1, 0, 0]))\n\n        src, dst = edge_pair_input()\n        for u, v in zip(src, dst):\n            assert g.has_edges_between(u, v)\n        assert not 
g.has_edges_between(0, 0)\n        assert F.allclose(\n            g.has_edges_between([0, 0, 3], [0, 9, 8]), F.tensor([0, 1, 1])\n        )\n        assert set(F.asnumpy(g.predecessors(9))) == set([0, 5, 7, 4])\n        assert set(F.asnumpy(g.successors(2))) == set([7, 3])\n\n        assert g.edge_ids(4, 4) == 5\n        assert F.allclose(g.edge_ids([4, 0], [4, 9]), F.tensor([5, 0]))\n\n        src, dst = g.find_edges([3, 6, 5])\n        assert F.allclose(src, F.tensor([5, 7, 4]))\n        assert F.allclose(dst, F.tensor([9, 9, 4]))\n\n        src, dst, eid = g.in_edges(9, form=\"all\")\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set([(0, 9, 0), (5, 9, 3), (7, 9, 6), (4, 9, 7)])\n        src, dst, eid = g.in_edges(\n            [9, 0, 8], form=\"all\"\n        )  # test node#0 has no in edges\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set(\n            [(0, 9, 0), (5, 9, 3), (7, 9, 6), (4, 9, 7), (3, 8, 9), (7, 8, 12)]\n        )\n\n        src, dst, eid = g.out_edges(0, form=\"all\")\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set([(0, 9, 0), (0, 6, 1), (0, 4, 4)])\n        src, dst, eid = g.out_edges(\n            [0, 4, 8], form=\"all\"\n        )  # test node#8 has no out edges\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set(\n            [\n                (0, 9, 0),\n                (0, 6, 1),\n                (0, 4, 4),\n                (4, 3, 2),\n                (4, 4, 5),\n                (4, 9, 7),\n                (4, 1, 8),\n            ]\n        )\n\n        src, dst, eid = g.edges(\"all\", \"eid\")\n        t_src, t_dst = edge_pair_input()\n        t_tup = list(zip(t_src, t_dst, list(range(20))))\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set(t_tup)\n        assert 
list(F.asnumpy(eid)) == list(range(20))\n\n        src, dst, eid = g.edges(\"all\", \"srcdst\")\n        t_src, t_dst = edge_pair_input()\n        t_tup = list(zip(t_src, t_dst, list(range(20))))\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set(t_tup)\n        assert list(F.asnumpy(src)) == sorted(list(F.asnumpy(src)))\n\n        assert g.in_degrees(0) == 0\n        assert g.in_degrees(9) == 4\n        assert F.allclose(g.in_degrees([0, 9]), F.tensor([0, 4]))\n        assert g.out_degrees(8) == 0\n        assert g.out_degrees(9) == 1\n        assert F.allclose(g.out_degrees([8, 9]), F.tensor([0, 1]))\n\n        assert np.array_equal(\n            F.sparse_to_numpy(g.adj_external(transpose=True)),\n            scipy_coo_input().toarray().T,\n        )\n        assert np.array_equal(\n            F.sparse_to_numpy(g.adj_external(transpose=False)),\n            scipy_coo_input().toarray(),\n        )\n\n    def _test(g):\n        # test twice to see whether the cached format works or not\n        _test_one(g)\n        _test_one(g)\n\n    def _test_csr_one(g):\n        assert g.num_nodes() == 10\n        assert g.num_edges() == 20\n\n        for i in range(10):\n            assert g.has_nodes(i)\n        assert not g.has_nodes(11)\n        assert F.allclose(g.has_nodes([0, 2, 10, 11]), F.tensor([1, 1, 0, 0]))\n\n        src, dst = edge_pair_input(sort=True)\n        for u, v in zip(src, dst):\n            assert g.has_edges_between(u, v)\n        assert not g.has_edges_between(0, 0)\n        assert F.allclose(\n            g.has_edges_between([0, 0, 3], [0, 9, 8]), F.tensor([0, 1, 1])\n        )\n        assert set(F.asnumpy(g.predecessors(9))) == set([0, 5, 7, 4])\n        assert set(F.asnumpy(g.successors(2))) == set([7, 3])\n\n        # src = [0 0 0 1 1 2 2 3 3 4 4 4 4 5 5 6 7 7 7 9]\n        # dst = [4 6 9 3 5 3 7 5 8 1 3 4 9 1 9 6 2 8 9 2]\n        # eid = [0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9]\n        
assert g.edge_ids(4, 4) == 11\n        assert F.allclose(g.edge_ids([4, 0], [4, 9]), F.tensor([11, 2]))\n\n        src, dst = g.find_edges([3, 6, 5])\n        assert F.allclose(src, F.tensor([1, 2, 2]))\n        assert F.allclose(dst, F.tensor([3, 7, 3]))\n\n        src, dst, eid = g.in_edges(9, form=\"all\")\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set([(0, 9, 2), (5, 9, 14), (7, 9, 18), (4, 9, 12)])\n        src, dst, eid = g.in_edges(\n            [9, 0, 8], form=\"all\"\n        )  # test node#0 has no in edges\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set(\n            [\n                (0, 9, 2),\n                (5, 9, 14),\n                (7, 9, 18),\n                (4, 9, 12),\n                (3, 8, 8),\n                (7, 8, 17),\n            ]\n        )\n\n        src, dst, eid = g.out_edges(0, form=\"all\")\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set([(0, 9, 2), (0, 6, 1), (0, 4, 0)])\n        src, dst, eid = g.out_edges(\n            [0, 4, 8], form=\"all\"\n        )  # test node#8 has no out edges\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set(\n            [\n                (0, 9, 2),\n                (0, 6, 1),\n                (0, 4, 0),\n                (4, 3, 10),\n                (4, 4, 11),\n                (4, 9, 12),\n                (4, 1, 9),\n            ]\n        )\n\n        src, dst, eid = g.edges(\"all\", \"eid\")\n        t_src, t_dst = edge_pair_input(sort=True)\n        t_tup = list(zip(t_src, t_dst, list(range(20))))\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set(t_tup)\n        assert list(F.asnumpy(eid)) == list(range(20))\n\n        src, dst, eid = g.edges(\"all\", \"srcdst\")\n        t_src, t_dst = edge_pair_input(sort=True)\n  
      t_tup = list(zip(t_src, t_dst, list(range(20))))\n        tup = list(zip(F.asnumpy(src), F.asnumpy(dst), F.asnumpy(eid)))\n        assert set(tup) == set(t_tup)\n        assert list(F.asnumpy(src)) == sorted(list(F.asnumpy(src)))\n\n        assert g.in_degrees(0) == 0\n        assert g.in_degrees(9) == 4\n        assert F.allclose(g.in_degrees([0, 9]), F.tensor([0, 4]))\n        assert g.out_degrees(8) == 0\n        assert g.out_degrees(9) == 1\n        assert F.allclose(g.out_degrees([8, 9]), F.tensor([0, 1]))\n\n        assert np.array_equal(\n            F.sparse_to_numpy(g.adj_external(transpose=True)),\n            scipy_coo_input().toarray().T,\n        )\n        assert np.array_equal(\n            F.sparse_to_numpy(g.adj_external(transpose=False)),\n            scipy_coo_input().toarray(),\n        )\n\n    def _test_csr(g):\n        # test twice to see whether the cached format works or not\n        _test_csr_one(g)\n        _test_csr_one(g)\n\n    def _test_edge_ids():\n        g = gen_by_mutation()\n        eids = g.edge_ids([4, 0], [4, 9])\n        assert eids.shape[0] == 2\n        eid = g.edge_ids(4, 4)\n        assert isinstance(eid, numbers.Number)\n        with pytest.raises(DGLError):\n            eids = g.edge_ids([9, 0], [4, 9])\n\n        with pytest.raises(DGLError):\n            eid = g.edge_ids(4, 5)\n\n        g.add_edges(0, 4)\n        eids = g.edge_ids([0, 0], [4, 9])\n        eid = g.edge_ids(0, 4)\n\n    _test(gen_by_mutation())\n    _test(dgl.graph(elist_input()))\n    _test(dgl.from_scipy(scipy_coo_input()))\n    _test_csr(dgl.from_scipy(scipy_csr_input()))\n    _test_edge_ids()\n\n\ndef test_mutation():\n    g = dgl.graph([])\n    g = g.to(F.ctx())\n    # test add nodes with data\n    g.add_nodes(5)\n    g.add_nodes(5, {\"h\": F.ones((5, 2))})\n    ans = F.cat([F.zeros((5, 2)), F.ones((5, 2))], 0)\n    assert F.allclose(ans, g.ndata[\"h\"])\n    g.ndata[\"w\"] = 2 * F.ones((10, 2))\n    assert F.allclose(2 * F.ones((10, 2)), 
g.ndata[\"w\"])\n    # test add edges with data\n    g.add_edges([2, 3], [3, 4])\n    g.add_edges([0, 1], [1, 2], {\"m\": F.ones((2, 2))})\n    ans = F.cat([F.zeros((2, 2)), F.ones((2, 2))], 0)\n    assert F.allclose(ans, g.edata[\"m\"])\n\n\ndef test_scipy_adjmat():\n    g = dgl.graph([])\n    g.add_nodes(10)\n    g.add_edges(range(9), range(1, 10))\n\n    adj_0 = g.adj_external(scipy_fmt=\"csr\")\n    adj_1 = g.adj_external(scipy_fmt=\"coo\")\n    assert np.array_equal(adj_0.toarray(), adj_1.toarray())\n\n    adj_t0 = g.adj_external(transpose=False, scipy_fmt=\"csr\")\n    adj_t_1 = g.adj_external(transpose=False, scipy_fmt=\"coo\")\n    assert np.array_equal(adj_0.toarray(), adj_1.toarray())\n\n\ndef test_incmat():\n    g = dgl.graph([])\n    g.add_nodes(4)\n    g.add_edges(0, 1)  # 0\n    g.add_edges(0, 2)  # 1\n    g.add_edges(0, 3)  # 2\n    g.add_edges(2, 3)  # 3\n    g.add_edges(1, 1)  # 4\n    inc_in = F.sparse_to_numpy(g.incidence_matrix(\"in\"))\n    inc_out = F.sparse_to_numpy(g.incidence_matrix(\"out\"))\n    inc_both = F.sparse_to_numpy(g.incidence_matrix(\"both\"))\n    print(inc_in)\n    print(inc_out)\n    print(inc_both)\n    assert np.allclose(\n        inc_in,\n        np.array(\n            [\n                [0.0, 0.0, 0.0, 0.0, 0.0],\n                [1.0, 0.0, 0.0, 0.0, 1.0],\n                [0.0, 1.0, 0.0, 0.0, 0.0],\n                [0.0, 0.0, 1.0, 1.0, 0.0],\n            ]\n        ),\n    )\n    assert np.allclose(\n        inc_out,\n        np.array(\n            [\n                [1.0, 1.0, 1.0, 0.0, 0.0],\n                [0.0, 0.0, 0.0, 0.0, 1.0],\n                [0.0, 0.0, 0.0, 1.0, 0.0],\n                [0.0, 0.0, 0.0, 0.0, 0.0],\n            ]\n        ),\n    )\n    assert np.allclose(\n        inc_both,\n        np.array(\n            [\n                [-1.0, -1.0, -1.0, 0.0, 0.0],\n                [1.0, 0.0, 0.0, 0.0, 0.0],\n                [0.0, 1.0, 0.0, -1.0, 0.0],\n                [0.0, 0.0, 1.0, 1.0, 0.0],\n           
 ]\n        ),\n    )\n\n\ndef test_find_edges():\n    g = dgl.graph([])\n    g.add_nodes(10)\n    g.add_edges(range(9), range(1, 10))\n    e = g.find_edges([1, 3, 2, 4])\n    assert (\n        F.asnumpy(e[0][0]) == 1\n        and F.asnumpy(e[0][1]) == 3\n        and F.asnumpy(e[0][2]) == 2\n        and F.asnumpy(e[0][3]) == 4\n    )\n    assert (\n        F.asnumpy(e[1][0]) == 2\n        and F.asnumpy(e[1][1]) == 4\n        and F.asnumpy(e[1][2]) == 3\n        and F.asnumpy(e[1][3]) == 5\n    )\n\n    try:\n        g.find_edges([10])\n        fail = False\n    except DGLError:\n        fail = True\n    finally:\n        assert fail\n\n\ndef test_ismultigraph():\n    g = dgl.graph([])\n    g.add_nodes(10)\n    assert g.is_multigraph == False\n    g.add_edges([0], [0])\n    assert g.is_multigraph == False\n    g.add_edges([1], [2])\n    assert g.is_multigraph == False\n    g.add_edges([0, 2], [0, 3])\n    assert g.is_multigraph == True\n\n\ndef test_hypersparse_query():\n    g = dgl.graph([])\n    g = g.to(F.ctx())\n    g.add_nodes(1000001)\n    g.add_edges([0], [1])\n    for i in range(10):\n        assert g.has_nodes(i)\n    assert not g.has_nodes(1000002)\n    assert g.edge_ids(0, 1) == 0\n    src, dst = g.find_edges([0])\n    src, dst, eid = g.in_edges(1, form=\"all\")\n    src, dst, eid = g.out_edges(0, form=\"all\")\n    src, dst = g.edges()\n    assert g.in_degrees(0) == 0\n    assert g.in_degrees(1) == 1\n    assert g.out_degrees(0) == 1\n    assert g.out_degrees(1) == 0\n\n\ndef test_empty_data_initialized():\n    g = dgl.graph([])\n    g = g.to(F.ctx())\n    g.ndata[\"ha\"] = F.tensor([])\n    g.add_nodes(1, {\"hb\": F.tensor([1])})\n    assert \"ha\" in g.ndata\n    assert len(g.ndata[\"ha\"]) == 1\n\n\ndef test_is_sorted():\n    u_src, u_dst = edge_pair_input(False)\n    s_src, s_dst = edge_pair_input(True)\n\n    u_src = F.tensor(u_src, dtype=F.int32)\n    u_dst = F.tensor(u_dst, dtype=F.int32)\n    s_src = F.tensor(s_src, dtype=F.int32)\n    s_dst = 
F.tensor(s_dst, dtype=F.int32)\n\n    src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(u_src, u_dst)\n    assert src_sorted == False\n    assert dst_sorted == False\n\n    src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(s_src, s_dst)\n    assert src_sorted == True\n    assert dst_sorted == True\n\n    src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(u_src, u_dst)\n    assert src_sorted == False\n    assert dst_sorted == False\n\n    src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(s_src, u_dst)\n    assert src_sorted == True\n    assert dst_sorted == False\n\n\ndef test_default_types():\n    dg = dgl.graph([])\n    g = dgl.graph(([], []))\n    assert dg.ntypes == g.ntypes\n    assert dg.etypes == g.etypes\n\n\ndef test_formats():\n    g = dgl.rand_graph(10, 20)\n    # in_degrees works if coo or csc available\n    # out_degrees works if coo or csr available\n    try:\n        g.in_degrees()\n        g.out_degrees()\n        g.formats(\"coo\").in_degrees()\n        g.formats(\"coo\").out_degrees()\n        g.formats(\"csc\").in_degrees()\n        g.formats(\"csr\").out_degrees()\n        fail = False\n    except DGLError:\n        fail = True\n    finally:\n        assert not fail\n    # in_degrees NOT works if csc available only\n    try:\n        g.formats(\"csc\").out_degrees()\n        fail = True\n    except DGLError:\n        fail = False\n    finally:\n        assert not fail\n    # out_degrees NOT works if csr available only\n    try:\n        g.formats(\"csr\").in_degrees()\n        fail = True\n    except DGLError:\n        fail = False\n    finally:\n        assert not fail\n\n    # If the intersection of created formats and allowed formats is\n    # not empty, then retain the intersection.\n    # Case1: intersection is not empty and intersected is equal to\n    # created formats.\n    g = g.formats([\"coo\", \"csr\"])\n    g.create_formats_()\n    g = g.formats([\"coo\", \"csr\", \"csc\"])\n    assert sorted(g.formats()[\"created\"]) == 
sorted([\"coo\", \"csr\"])\n    assert sorted(g.formats()[\"not created\"]) == sorted([\"csc\"])\n\n    # Case2: intersection is not empty and intersected is not equal\n    # to created formats.\n    g = g.formats([\"coo\", \"csr\"])\n    g.create_formats_()\n    g = g.formats([\"coo\", \"csc\"])\n    assert sorted(g.formats()[\"created\"]) == sorted([\"coo\"])\n    assert sorted(g.formats()[\"not created\"]) == sorted([\"csc\"])\n\n    # If the intersection of created formats and allowed formats is\n    # empty, then create a format in the order of `coo` -> `csr` ->\n    # `csc`.\n    # Case1: intersection is empty and just one format is allowed.\n    g = g.formats([\"coo\", \"csr\"])\n    g.create_formats_()\n    g = g.formats([\"csc\"])\n    assert sorted(g.formats()[\"created\"]) == sorted([\"csc\"])\n    assert sorted(g.formats()[\"not created\"]) == sorted([])\n\n    # Case2: intersection is empty and more than one format is allowed.\n    g = g.formats(\"csc\")\n    g.create_formats_()\n    g = g.formats([\"csr\", \"coo\"])\n    assert sorted(g.formats()[\"created\"]) == sorted([\"coo\"])\n    assert sorted(g.formats()[\"not created\"]) == sorted([\"csr\"])\n\n\nif __name__ == \"__main__\":\n    test_query()\n    test_mutation()\n    test_scipy_adjmat()\n    test_incmat()\n    test_find_edges()\n    test_hypersparse_query()\n    test_is_sorted()\n    test_default_types()\n    test_formats()\n"
  },
  {
    "path": "tests/python/common/test_heterograph-pickle.py",
    "content": "import io\nimport pickle\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport dgl.function as fn\nimport networkx as nx\nimport pytest\nimport scipy.sparse as ssp\nfrom dgl.graph_index import create_graph_index\nfrom dgl.utils import toindex\nfrom utils import (\n    assert_is_identical,\n    assert_is_identical_hetero,\n    check_graph_equal,\n    get_cases,\n    parametrize_idtype,\n)\n\n\ndef _assert_is_identical_nodeflow(nf1, nf2):\n    assert nf1.num_nodes() == nf2.num_nodes()\n    src, dst = nf1.all_edges()\n    src2, dst2 = nf2.all_edges()\n    assert F.array_equal(src, src2)\n    assert F.array_equal(dst, dst2)\n\n    assert nf1.num_layers == nf2.num_layers\n    for i in range(nf1.num_layers):\n        assert nf1.layer_size(i) == nf2.layer_size(i)\n        assert nf1.layers[i].data.keys() == nf2.layers[i].data.keys()\n        for k in nf1.layers[i].data:\n            assert F.allclose(nf1.layers[i].data[k], nf2.layers[i].data[k])\n    assert nf1.num_blocks == nf2.num_blocks\n    for i in range(nf1.num_blocks):\n        assert nf1.block_size(i) == nf2.block_size(i)\n        assert nf1.blocks[i].data.keys() == nf2.blocks[i].data.keys()\n        for k in nf1.blocks[i].data:\n            assert F.allclose(nf1.blocks[i].data[k], nf2.blocks[i].data[k])\n\n\ndef _assert_is_identical_batchedgraph(bg1, bg2):\n    assert_is_identical(bg1, bg2)\n    assert bg1.batch_size == bg2.batch_size\n    assert bg1.batch_num_nodes == bg2.batch_num_nodes\n    assert bg1.batch_num_edges == bg2.batch_num_edges\n\n\ndef _assert_is_identical_batchedhetero(bg1, bg2):\n    assert_is_identical_hetero(bg1, bg2)\n    for ntype in bg1.ntypes:\n        assert bg1.batch_num_nodes(ntype) == bg2.batch_num_nodes(ntype)\n    for canonical_etype in bg1.canonical_etypes:\n        assert bg1.batch_num_edges(canonical_etype) == bg2.batch_num_edges(\n            canonical_etype\n        )\n\n\ndef _assert_is_identical_index(i1, i2):\n    assert i1.slice_data() == 
i2.slice_data()\n    assert F.array_equal(i1.tousertensor(), i2.tousertensor())\n\n\ndef _reconstruct_pickle(obj):\n    f = io.BytesIO()\n    pickle.dump(obj, f)\n    f.seek(0)\n    obj = pickle.load(f)\n    f.close()\n\n    return obj\n\n\ndef test_pickling_index():\n    # normal index\n    i = toindex([1, 2, 3])\n    i.tousertensor()\n    i.todgltensor()  # construct a dgl tensor which is unpicklable\n    i2 = _reconstruct_pickle(i)\n    _assert_is_identical_index(i, i2)\n\n    # slice index\n    i = toindex(slice(5, 10))\n    i2 = _reconstruct_pickle(i)\n    _assert_is_identical_index(i, i2)\n\n\ndef test_pickling_graph_index():\n    gi = create_graph_index(None, False)\n    gi.add_nodes(3)\n    src_idx = toindex([0, 0])\n    dst_idx = toindex([1, 2])\n    gi.add_edges(src_idx, dst_idx)\n\n    gi2 = _reconstruct_pickle(gi)\n\n    assert gi2.num_nodes() == gi.num_nodes()\n    src_idx2, dst_idx2, _ = gi2.edges()\n    assert F.array_equal(src_idx.tousertensor(), src_idx2.tousertensor())\n    assert F.array_equal(dst_idx.tousertensor(), dst_idx2.tousertensor())\n\n\ndef _global_message_func(nodes):\n    return {\"x\": nodes.data[\"x\"]}\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases(exclude=[\"dglgraph\", \"two_hetero_batch\"])\n)\ndef test_pickling_graph(g, idtype):\n    g = g.astype(idtype)\n    new_g = _reconstruct_pickle(g)\n    check_graph_equal(g, new_g, check_feature=True)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_pickling_batched_heterograph():\n    # copied from test_heterograph.create_test_heterograph()\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): ([0, 2], [1, 0]),\n            (\"developer\", \"develops\", 
\"game\"): ([0, 1], [0, 1]),\n        }\n    )\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): ([0, 2], [1, 0]),\n            (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n        }\n    )\n\n    g.nodes[\"user\"].data[\"u_h\"] = F.randn((3, 4))\n    g.nodes[\"game\"].data[\"g_h\"] = F.randn((2, 5))\n    g.edges[\"plays\"].data[\"p_h\"] = F.randn((4, 6))\n    g2.nodes[\"user\"].data[\"u_h\"] = F.randn((3, 4))\n    g2.nodes[\"game\"].data[\"g_h\"] = F.randn((2, 5))\n    g2.edges[\"plays\"].data[\"p_h\"] = F.randn((4, 6))\n\n    bg = dgl.batch([g, g2])\n    new_bg = _reconstruct_pickle(bg)\n    check_graph_equal(bg, new_bg)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"GPU edge_subgraph w/ relabeling not implemented\",\n)\ndef test_pickling_subgraph():\n    f1 = io.BytesIO()\n    f2 = io.BytesIO()\n    g = dgl.rand_graph(10000, 100000)\n    g.ndata[\"x\"] = F.randn((10000, 4))\n    g.edata[\"x\"] = F.randn((100000, 5))\n    pickle.dump(g, f1)\n    sg = g.subgraph([0, 1])\n    sgx = sg.ndata[\"x\"]  # materialize\n    pickle.dump(sg, f2)\n    # TODO(BarclayII): How should I test that the size of the subgraph pickle file should not\n    # be as large as the size of the original pickle file?\n    assert f1.tell() > f2.tell() * 50\n\n    f2.seek(0)\n    f2.truncate()\n    sgx = sg.edata[\"x\"]  # materialize\n    pickle.dump(sg, f2)\n    assert f1.tell() > f2.tell() * 50\n\n    f2.seek(0)\n    f2.truncate()\n    sg = g.edge_subgraph([0])\n    sgx = sg.edata[\"x\"]  # materialize\n    pickle.dump(sg, f2)\n    assert f1.tell() > f2.tell() * 50\n\n    f2.seek(0)\n    f2.truncate()\n    sgx = sg.ndata[\"x\"]  # materialize\n    pickle.dump(sg, f2)\n    assert f1.tell() > f2.tell() * 50\n\n    f1.close()\n    
f2.close()\n\n\n@unittest.skipIf(F._default_context_str != \"gpu\", reason=\"Need GPU for pin\")\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TensorFlow create graph on gpu when unpickle\",\n)\n@parametrize_idtype\ndef test_pickling_is_pinned(idtype):\n    from copy import deepcopy\n\n    g = dgl.rand_graph(10, 20, idtype=idtype, device=F.cpu())\n    hg = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): ([0, 2], [1, 0]),\n            (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n        },\n        idtype=idtype,\n        device=F.cpu(),\n    )\n    for graph in [g, hg]:\n        assert not graph.is_pinned()\n        graph.pin_memory_()\n        assert graph.is_pinned()\n        pg = _reconstruct_pickle(graph)\n        assert pg.is_pinned()\n        pg.unpin_memory_()\n        dg = deepcopy(graph)\n        assert dg.is_pinned()\n        dg.unpin_memory_()\n        graph.unpin_memory_()\n\n\nif __name__ == \"__main__\":\n    test_pickling_index()\n    test_pickling_graph_index()\n    test_pickling_frame()\n    test_pickling_graph()\n    test_pickling_nodeflow()\n    test_pickling_batched_graph()\n    test_pickling_heterograph()\n    test_pickling_batched_heterograph()\n    test_pickling_is_pinned()\n"
  },
  {
    "path": "tests/python/common/test_heterograph-remove.py",
    "content": "import backend as F\n\nimport dgl\nimport numpy as np\nfrom utils import parametrize_idtype\n\n\ndef create_graph(idtype, num_node):\n    g = dgl.graph([])\n    g = g.astype(idtype).to(F.ctx())\n    g.add_nodes(num_node)\n    return g\n\n\n@parametrize_idtype\ndef test_node_removal(idtype):\n    g = create_graph(idtype, 10)\n    g.add_edges(0, 0)\n    assert g.num_nodes() == 10\n    g.ndata[\"id\"] = F.arange(0, 10)\n\n    # remove nodes\n    g.remove_nodes(range(4, 7))\n    assert g.num_nodes() == 7\n    assert F.array_equal(g.ndata[\"id\"], F.tensor([0, 1, 2, 3, 7, 8, 9]))\n    assert dgl.NID not in g.ndata\n    assert dgl.EID not in g.edata\n\n    # add nodes\n    g.add_nodes(3)\n    assert g.num_nodes() == 10\n    assert F.array_equal(\n        g.ndata[\"id\"], F.tensor([0, 1, 2, 3, 7, 8, 9, 0, 0, 0])\n    )\n\n    # remove nodes\n    g.remove_nodes(range(1, 4), store_ids=True)\n    assert g.num_nodes() == 7\n    assert F.array_equal(g.ndata[\"id\"], F.tensor([0, 7, 8, 9, 0, 0, 0]))\n    assert dgl.NID in g.ndata\n    assert dgl.EID in g.edata\n\n\n@parametrize_idtype\ndef test_multigraph_node_removal(idtype):\n    g = create_graph(idtype, 5)\n    for i in range(5):\n        g.add_edges(i, i)\n        g.add_edges(i, i)\n    assert g.num_nodes() == 5\n    assert g.num_edges() == 10\n\n    # remove nodes\n    g.remove_nodes([2, 3])\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 6\n\n    # add nodes\n    g.add_nodes(1)\n    g.add_edges(1, 1)\n    g.add_edges(1, 1)\n    assert g.num_nodes() == 4\n    assert g.num_edges() == 8\n\n    # remove nodes\n    g.remove_nodes([0])\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 6\n\n\n@parametrize_idtype\ndef test_multigraph_edge_removal(idtype):\n    g = create_graph(idtype, 5)\n    for i in range(5):\n        g.add_edges(i, i)\n        g.add_edges(i, i)\n    assert g.num_nodes() == 5\n    assert g.num_edges() == 10\n\n    # remove edges\n    g.remove_edges([2, 3])\n    assert 
g.num_nodes() == 5\n    assert g.num_edges() == 8\n\n    # add edges\n    g.add_edges(1, 1)\n    g.add_edges(1, 1)\n    assert g.num_nodes() == 5\n    assert g.num_edges() == 10\n\n    # remove edges\n    g.remove_edges([0, 1])\n    assert g.num_nodes() == 5\n    assert g.num_edges() == 8\n\n\n@parametrize_idtype\ndef test_edge_removal(idtype):\n    g = create_graph(idtype, 5)\n    for i in range(5):\n        for j in range(5):\n            g.add_edges(i, j)\n    g.edata[\"id\"] = F.arange(0, 25)\n\n    # remove edges\n    g.remove_edges(range(13, 20))\n    assert g.num_nodes() == 5\n    assert g.num_edges() == 18\n    assert F.array_equal(\n        g.edata[\"id\"], F.tensor(list(range(13)) + list(range(20, 25)))\n    )\n    assert dgl.NID not in g.ndata\n    assert dgl.EID not in g.edata\n\n    # add edges\n    g.add_edges(3, 3)\n    assert g.num_nodes() == 5\n    assert g.num_edges() == 19\n    assert F.array_equal(\n        g.edata[\"id\"], F.tensor(list(range(13)) + list(range(20, 25)) + [0])\n    )\n\n    # remove edges\n    g.remove_edges(range(2, 10), store_ids=True)\n    assert g.num_nodes() == 5\n    assert g.num_edges() == 11\n    assert F.array_equal(\n        g.edata[\"id\"], F.tensor([0, 1, 10, 11, 12, 20, 21, 22, 23, 24, 0])\n    )\n    assert dgl.EID in g.edata\n\n\n@parametrize_idtype\ndef test_node_and_edge_removal(idtype):\n    g = create_graph(idtype, 10)\n    for i in range(10):\n        for j in range(10):\n            g.add_edges(i, j)\n    g.edata[\"id\"] = F.arange(0, 100)\n    assert g.num_nodes() == 10\n    assert g.num_edges() == 100\n\n    # remove nodes\n    g.remove_nodes([2, 4])\n    assert g.num_nodes() == 8\n    assert g.num_edges() == 64\n\n    # remove edges\n    g.remove_edges(range(10, 20))\n    assert g.num_nodes() == 8\n    assert g.num_edges() == 54\n\n    # add nodes\n    g.add_nodes(2)\n    assert g.num_nodes() == 10\n    assert g.num_edges() == 54\n\n    # add edges\n    for i in range(8, 10):\n        for j in range(8, 
10):\n            g.add_edges(i, j)\n    assert g.num_nodes() == 10\n    assert g.num_edges() == 58\n\n    # remove edges\n    g.remove_edges(range(10, 20))\n    assert g.num_nodes() == 10\n    assert g.num_edges() == 48\n\n\n@parametrize_idtype\ndef test_node_frame(idtype):\n    g = create_graph(idtype, 10)\n    data = np.random.rand(10, 3)\n    new_data = data.take([0, 1, 2, 7, 8, 9], axis=0)\n    g.ndata[\"h\"] = F.tensor(data)\n\n    # remove nodes\n    g.remove_nodes(range(3, 7))\n    assert F.allclose(g.ndata[\"h\"], F.tensor(new_data))\n\n\n@parametrize_idtype\ndef test_edge_frame(idtype):\n    g = create_graph(idtype, 10)\n    g.add_edges(list(range(10)), list(range(1, 10)) + [0])\n    data = np.random.rand(10, 3)\n    new_data = data.take([0, 1, 2, 7, 8, 9], axis=0)\n    g.edata[\"h\"] = F.tensor(data)\n\n    # remove edges\n    g.remove_edges(range(3, 7))\n    assert F.allclose(g.edata[\"h\"], F.tensor(new_data))\n\n\n@parametrize_idtype\ndef test_issue1287(idtype):\n    # reproduce https://github.com/dmlc/dgl/issues/1287.\n    # setting features after remove nodes\n    g = create_graph(idtype, 5)\n    g.add_edges([0, 2, 3, 1, 1], [1, 0, 3, 1, 0])\n    g.remove_nodes([0, 1])\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 3))\n    g.edata[\"h\"] = F.randn((g.num_edges(), 2))\n\n    # remove edges\n    g = create_graph(idtype, 5)\n    g.add_edges([0, 2, 3, 1, 1], [1, 0, 3, 1, 0])\n    g.remove_edges([0, 1])\n    g = g.to(F.ctx())\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 3))\n    g.edata[\"h\"] = F.randn((g.num_edges(), 2))\n\n\nif __name__ == \"__main__\":\n    test_node_removal()\n    test_edge_removal()\n    test_multigraph_node_removal()\n    test_multigraph_edge_removal()\n    test_node_and_edge_removal()\n    test_node_frame()\n    test_edge_frame()\n    test_frame_size()\n"
  },
  {
    "path": "tests/python/common/test_heterograph-shared-memory.py",
    "content": "import io\nimport multiprocessing as mp\nimport os\nimport pickle\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport dgl.function as fn\nimport networkx as nx\nimport scipy.sparse as ssp\nfrom dgl.graph_index import create_graph_index\nfrom dgl.utils import toindex\nfrom utils import parametrize_idtype\n\n\ndef create_test_graph(idtype):\n    g = dgl.heterograph(\n        (\n            {\n                (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n                (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n                (\"user\", \"wishes\", \"game\"): ([0, 2], [1, 0]),\n                (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n            }\n        ),\n        idtype=idtype,\n    )\n    return g\n\n\ndef _assert_is_identical_hetero(g, g2):\n    assert g.ntypes == g2.ntypes\n    assert g.canonical_etypes == g2.canonical_etypes\n\n    # check if two metagraphs are identical\n    for edges, features in g.metagraph().edges(keys=True).items():\n        assert g2.metagraph().edges(keys=True)[edges] == features\n\n    # check if node ID spaces and feature spaces are equal\n    for ntype in g.ntypes:\n        assert g.num_nodes(ntype) == g2.num_nodes(ntype)\n\n    # check if edge ID spaces and feature spaces are equal\n    for etype in g.canonical_etypes:\n        src, dst = g.all_edges(etype=etype, order=\"eid\")\n        src2, dst2 = g2.all_edges(etype=etype, order=\"eid\")\n        assert F.array_equal(src, src2)\n        assert F.array_equal(dst, dst2)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"Not support tensorflow for now\",\n)\n@parametrize_idtype\ndef test_single_process(idtype):\n    hg = create_test_graph(idtype=idtype)\n    hg_share = hg.shared_memory(\"hg\")\n    hg_rebuild = dgl.hetero_from_shared_memory(\"hg\")\n    hg_save_again = hg_rebuild.shared_memory(\"hg\")\n    _assert_is_identical_hetero(hg, hg_share)\n    
_assert_is_identical_hetero(hg, hg_rebuild)\n    _assert_is_identical_hetero(hg, hg_save_again)\n\n\ndef sub_proc(hg_origin, name):\n    hg_rebuild = dgl.hetero_from_shared_memory(name)\n    hg_save_again = hg_rebuild.shared_memory(name)\n    _assert_is_identical_hetero(hg_origin, hg_rebuild)\n    _assert_is_identical_hetero(hg_origin, hg_save_again)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"Not support tensorflow for now\",\n)\n@parametrize_idtype\ndef test_multi_process(idtype):\n    hg = create_test_graph(idtype=idtype)\n    hg_share = hg.shared_memory(\"hg1\")\n    p = mp.Process(target=sub_proc, args=(hg, \"hg1\"))\n    p.start()\n    p.join()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"Need gpu for this test\"\n)\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"Not support tensorflow for now\",\n)\ndef test_copy_from_gpu():\n    hg = create_test_graph(idtype=F.int32)\n    hg_gpu = hg.to(F.cuda())\n    hg_share = hg_gpu.shared_memory(\"hg_gpu\")\n    p = mp.Process(target=sub_proc, args=(hg, \"hg_gpu\"))\n    p.start()\n    p.join()\n\n\n# TODO: Test calling shared_memory with Blocks (a subclass of HeteroGraph)\nif __name__ == \"__main__\":\n    test_single_process(F.int64)\n    test_multi_process(F.int32)\n    test_copy_from_gpu()\n"
  },
  {
    "path": "tests/python/common/test_heterograph-specialization.py",
    "content": "import backend as F\n\nimport dgl\nimport dgl.function as fn\nimport numpy as np\nimport scipy.sparse as sp\nfrom utils import parametrize_idtype\n\nD = 5\n\n\ndef generate_graph(idtype):\n    g = dgl.graph([])\n    g = g.astype(idtype).to(F.ctx())\n    g.add_nodes(10)\n    # create a graph where 0 is the source and 9 is the sink\n    for i in range(1, 9):\n        g.add_edges(0, i)\n        g.add_edges(i, 9)\n    # add a back flow from 9 to 0\n    g.add_edges(9, 0)\n    g.ndata.update({\"f1\": F.randn((10,)), \"f2\": F.randn((10, D))})\n    weights = F.randn((17,))\n    g.edata.update({\"e1\": weights, \"e2\": F.unsqueeze(weights, 1)})\n    return g\n\n\n@parametrize_idtype\ndef test_v2v_update_all(idtype):\n    def _test(fld):\n        def message_func(edges):\n            return {\"m\": edges.src[fld]}\n\n        def message_func_edge(edges):\n            if len(edges.src[fld].shape) == 1:\n                return {\"m\": edges.src[fld] * edges.data[\"e1\"]}\n            else:\n                return {\"m\": edges.src[fld] * edges.data[\"e2\"]}\n\n        def reduce_func(nodes):\n            return {fld: F.sum(nodes.mailbox[\"m\"], 1)}\n\n        def apply_func(nodes):\n            return {fld: 2 * nodes.data[fld]}\n\n        g = generate_graph(idtype)\n        # update all\n        v1 = g.ndata[fld]\n        g.update_all(\n            fn.copy_u(u=fld, out=\"m\"), fn.sum(msg=\"m\", out=fld), apply_func\n        )\n        v2 = g.ndata[fld]\n        g.ndata.update({fld: v1})\n        g.update_all(message_func, reduce_func, apply_func)\n        v3 = g.ndata[fld]\n        assert F.allclose(v2, v3)\n        # update all with edge weights\n        v1 = g.ndata[fld]\n        g.update_all(\n            fn.u_mul_e(fld, \"e1\", \"m\"), fn.sum(msg=\"m\", out=fld), apply_func\n        )\n        v2 = g.ndata[fld]\n        g.ndata.update({fld: v1})\n        g.update_all(message_func_edge, reduce_func, apply_func)\n        v4 = g.ndata[fld]\n        assert 
F.allclose(v2, v4)\n\n    # test 1d node features\n    _test(\"f1\")\n    # test 2d node features\n    _test(\"f2\")\n\n\n@parametrize_idtype\ndef test_v2v_snr(idtype):\n    u = F.tensor([0, 0, 0, 3, 4, 9], idtype)\n    v = F.tensor([1, 2, 3, 9, 9, 0], idtype)\n\n    def _test(fld):\n        def message_func(edges):\n            return {\"m\": edges.src[fld]}\n\n        def message_func_edge(edges):\n            if len(edges.src[fld].shape) == 1:\n                return {\"m\": edges.src[fld] * edges.data[\"e1\"]}\n            else:\n                return {\"m\": edges.src[fld] * edges.data[\"e2\"]}\n\n        def reduce_func(nodes):\n            return {fld: F.sum(nodes.mailbox[\"m\"], 1)}\n\n        def apply_func(nodes):\n            return {fld: 2 * nodes.data[fld]}\n\n        g = generate_graph(idtype)\n        # send and recv\n        v1 = g.ndata[fld]\n        g.send_and_recv(\n            (u, v),\n            fn.copy_u(u=fld, out=\"m\"),\n            fn.sum(msg=\"m\", out=fld),\n            apply_func,\n        )\n        v2 = g.ndata[fld]\n        g.ndata.update({fld: v1})\n        g.send_and_recv((u, v), message_func, reduce_func, apply_func)\n        v3 = g.ndata[fld]\n        assert F.allclose(v2, v3)\n        # send and recv with edge weights\n        v1 = g.ndata[fld]\n        g.send_and_recv(\n            (u, v),\n            fn.u_mul_e(fld, \"e1\", \"m\"),\n            fn.sum(msg=\"m\", out=fld),\n            apply_func,\n        )\n        v2 = g.ndata[fld]\n        g.ndata.update({fld: v1})\n        g.send_and_recv((u, v), message_func_edge, reduce_func, apply_func)\n        v4 = g.ndata[fld]\n        assert F.allclose(v2, v4)\n\n    # test 1d node features\n    _test(\"f1\")\n    # test 2d node features\n    _test(\"f2\")\n\n\n@parametrize_idtype\ndef test_v2v_pull(idtype):\n    nodes = F.tensor([1, 2, 3, 9], idtype)\n\n    def _test(fld):\n        def message_func(edges):\n            return {\"m\": edges.src[fld]}\n\n        def 
message_func_edge(edges):\n            if len(edges.src[fld].shape) == 1:\n                return {\"m\": edges.src[fld] * edges.data[\"e1\"]}\n            else:\n                return {\"m\": edges.src[fld] * edges.data[\"e2\"]}\n\n        def reduce_func(nodes):\n            return {fld: F.sum(nodes.mailbox[\"m\"], 1)}\n\n        def apply_func(nodes):\n            return {fld: 2 * nodes.data[fld]}\n\n        g = generate_graph(idtype)\n        # send and recv\n        v1 = g.ndata[fld]\n        g.pull(\n            nodes,\n            fn.copy_u(u=fld, out=\"m\"),\n            fn.sum(msg=\"m\", out=fld),\n            apply_func,\n        )\n        v2 = g.ndata[fld]\n        g.ndata[fld] = v1\n        g.pull(nodes, message_func, reduce_func, apply_func)\n        v3 = g.ndata[fld]\n        assert F.allclose(v2, v3)\n        # send and recv with edge weights\n        v1 = g.ndata[fld]\n        g.pull(\n            nodes,\n            fn.u_mul_e(fld, \"e1\", \"m\"),\n            fn.sum(msg=\"m\", out=fld),\n            apply_func,\n        )\n        v2 = g.ndata[fld]\n        g.ndata[fld] = v1\n        g.pull(nodes, message_func_edge, reduce_func, apply_func)\n        v4 = g.ndata[fld]\n        assert F.allclose(v2, v4)\n\n    # test 1d node features\n    _test(\"f1\")\n    # test 2d node features\n    _test(\"f2\")\n\n\n@parametrize_idtype\ndef test_update_all_multi_fallback(idtype):\n    # create a graph with zero in degree nodes\n    g = dgl.graph([])\n    g = g.astype(idtype).to(F.ctx())\n    g.add_nodes(10)\n    for i in range(1, 9):\n        g.add_edges(0, i)\n        g.add_edges(i, 9)\n    g.ndata[\"h\"] = F.randn((10, D))\n    g.edata[\"w1\"] = F.randn((16,))\n    g.edata[\"w2\"] = F.randn((16, D))\n\n    def _mfunc_hxw1(edges):\n        return {\"m1\": edges.src[\"h\"] * F.unsqueeze(edges.data[\"w1\"], 1)}\n\n    def _mfunc_hxw2(edges):\n        return {\"m2\": edges.src[\"h\"] * edges.data[\"w2\"]}\n\n    def _rfunc_m1(nodes):\n        return {\"o1\": 
F.sum(nodes.mailbox[\"m1\"], 1)}\n\n    def _rfunc_m2(nodes):\n        return {\"o2\": F.sum(nodes.mailbox[\"m2\"], 1)}\n\n    def _rfunc_m1max(nodes):\n        return {\"o3\": F.max(nodes.mailbox[\"m1\"], 1)}\n\n    def _afunc(nodes):\n        ret = {}\n        for k, v in nodes.data.items():\n            if k.startswith(\"o\"):\n                ret[k] = 2 * v\n        return ret\n\n    # compute ground truth\n    g.update_all(_mfunc_hxw1, _rfunc_m1, _afunc)\n    o1 = g.ndata.pop(\"o1\")\n    g.update_all(_mfunc_hxw2, _rfunc_m2, _afunc)\n    o2 = g.ndata.pop(\"o2\")\n    g.update_all(_mfunc_hxw1, _rfunc_m1max, _afunc)\n    o3 = g.ndata.pop(\"o3\")\n    # v2v spmv\n    g.update_all(\n        fn.u_mul_e(\"h\", \"w1\", \"m1\"), fn.sum(msg=\"m1\", out=\"o1\"), _afunc\n    )\n    assert F.allclose(o1, g.ndata.pop(\"o1\"))\n    # v2v fallback to e2v\n    g.update_all(\n        fn.u_mul_e(\"h\", \"w2\", \"m2\"), fn.sum(msg=\"m2\", out=\"o2\"), _afunc\n    )\n    assert F.allclose(o2, g.ndata.pop(\"o2\"))\n\n\n@parametrize_idtype\ndef test_pull_multi_fallback(idtype):\n    # create a graph with zero in degree nodes\n    g = dgl.graph([])\n    g = g.astype(idtype).to(F.ctx())\n    g.add_nodes(10)\n    for i in range(1, 9):\n        g.add_edges(0, i)\n        g.add_edges(i, 9)\n    g.ndata[\"h\"] = F.randn((10, D))\n    g.edata[\"w1\"] = F.randn((16,))\n    g.edata[\"w2\"] = F.randn((16, D))\n\n    def _mfunc_hxw1(edges):\n        return {\"m1\": edges.src[\"h\"] * F.unsqueeze(edges.data[\"w1\"], 1)}\n\n    def _mfunc_hxw2(edges):\n        return {\"m2\": edges.src[\"h\"] * edges.data[\"w2\"]}\n\n    def _rfunc_m1(nodes):\n        return {\"o1\": F.sum(nodes.mailbox[\"m1\"], 1)}\n\n    def _rfunc_m2(nodes):\n        return {\"o2\": F.sum(nodes.mailbox[\"m2\"], 1)}\n\n    def _rfunc_m1max(nodes):\n        return {\"o3\": F.max(nodes.mailbox[\"m1\"], 1)}\n\n    def _afunc(nodes):\n        ret = {}\n        for k, v in nodes.data.items():\n            if k.startswith(\"o\"):\n 
               ret[k] = 2 * v\n        return ret\n\n    # nodes to pull\n    def _pull_nodes(nodes):\n        # compute ground truth\n        g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)\n        o1 = g.ndata.pop(\"o1\")\n        g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)\n        o2 = g.ndata.pop(\"o2\")\n        g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)\n        o3 = g.ndata.pop(\"o3\")\n        # v2v spmv\n        g.pull(\n            nodes,\n            fn.u_mul_e(\"h\", \"w1\", \"m1\"),\n            fn.sum(msg=\"m1\", out=\"o1\"),\n            _afunc,\n        )\n        assert F.allclose(o1, g.ndata.pop(\"o1\"))\n        # v2v fallback to e2v\n        g.pull(\n            nodes,\n            fn.u_mul_e(\"h\", \"w2\", \"m2\"),\n            fn.sum(msg=\"m2\", out=\"o2\"),\n            _afunc,\n        )\n        assert F.allclose(o2, g.ndata.pop(\"o2\"))\n\n    # test#1: non-0deg nodes\n    nodes = [1, 2, 9]\n    _pull_nodes(nodes)\n    # test#2: 0deg nodes + non-0deg nodes\n    nodes = [0, 1, 2, 9]\n    _pull_nodes(nodes)\n\n\n@parametrize_idtype\ndef test_spmv_3d_feat(idtype):\n    def src_mul_edge_udf(edges):\n        return {\n            \"sum\": edges.src[\"h\"]\n            * F.unsqueeze(F.unsqueeze(edges.data[\"h\"], 1), 1)\n        }\n\n    def sum_udf(nodes):\n        return {\"h\": F.sum(nodes.mailbox[\"sum\"], 1)}\n\n    n = 100\n    p = 0.1\n    a = sp.random(n, n, p, data_rvs=lambda n: np.ones(n))\n    g = dgl.from_scipy(a)\n    g = g.astype(idtype).to(F.ctx())\n    m = g.num_edges()\n\n    # test#1: v2v with adj data\n    h = F.randn((n, 5, 5))\n    e = F.randn((m,))\n\n    g.ndata[\"h\"] = h\n    g.edata[\"h\"] = e\n    g.update_all(\n        message_func=fn.u_mul_e(\"h\", \"h\", \"sum\"), reduce_func=fn.sum(\"sum\", \"h\")\n    )  # 1\n    ans = g.ndata[\"h\"]\n\n    g.ndata[\"h\"] = h\n    g.edata[\"h\"] = e\n    g.update_all(\n        message_func=src_mul_edge_udf, reduce_func=fn.sum(\"sum\", \"h\")\n    )  # 2\n    assert 
F.allclose(g.ndata[\"h\"], ans)\n\n    g.ndata[\"h\"] = h\n    g.edata[\"h\"] = e\n    g.update_all(message_func=src_mul_edge_udf, reduce_func=sum_udf)  # 3\n    assert F.allclose(g.ndata[\"h\"], ans)\n\n    # test#2: e2v\n    def src_mul_edge_udf(edges):\n        return {\"sum\": edges.src[\"h\"] * edges.data[\"h\"]}\n\n    h = F.randn((n, 5, 5))\n    e = F.randn((m, 5, 5))\n\n    g.ndata[\"h\"] = h\n    g.edata[\"h\"] = e\n    g.update_all(\n        message_func=fn.u_mul_e(\"h\", \"h\", \"sum\"), reduce_func=fn.sum(\"sum\", \"h\")\n    )  # 1\n    ans = g.ndata[\"h\"]\n\n    g.ndata[\"h\"] = h\n    g.edata[\"h\"] = e\n    g.update_all(\n        message_func=src_mul_edge_udf, reduce_func=fn.sum(\"sum\", \"h\")\n    )  # 2\n    assert F.allclose(g.ndata[\"h\"], ans)\n\n    g.ndata[\"h\"] = h\n    g.edata[\"h\"] = e\n    g.update_all(message_func=src_mul_edge_udf, reduce_func=sum_udf)  # 3\n    assert F.allclose(g.ndata[\"h\"], ans)\n\n\nif __name__ == \"__main__\":\n    test_v2v_update_all()\n    test_v2v_snr()\n    test_v2v_pull()\n    test_v2v_update_all_multi_fn()\n    test_v2v_snr_multi_fn()\n    test_e2v_update_all_multi_fn()\n    test_e2v_snr_multi_fn()\n    test_e2v_recv_multi_fn()\n    test_update_all_multi_fallback()\n    test_pull_multi_fallback()\n    test_spmv_3d_feat()\n"
  },
  {
    "path": "tests/python/common/test_heterograph-update-all.py",
    "content": "import itertools\nimport unittest\nfrom collections import Counter\nfrom itertools import product\n\nimport backend as F\n\nimport dgl\nimport dgl.function as fn\nimport networkx as nx\nimport numpy as np\nimport pytest\nimport scipy.sparse as ssp\nfrom dgl import DGLError\nfrom scipy.sparse import rand\nfrom utils import get_cases, parametrize_idtype\n\nrfuncs = {\"sum\": fn.sum, \"max\": fn.max, \"min\": fn.min, \"mean\": fn.mean}\nfeat_size = 2\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\ndef create_test_heterograph(idtype):\n    # test heterograph from the docstring, plus a user -- wishes -- game relation\n    # 3 users, 2 games, 2 developers\n    # metagraph:\n    #    ('user', 'follows', 'user'),\n    #    ('user', 'plays', 'game'),\n    #    ('user', 'wishes', 'game'),\n    #    ('developer', 'develops', 'game')])\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): ([0, 1, 1], [0, 0, 1]),\n            (\"developer\", \"develops\", \"game\"): ([0, 1, 0], [0, 1, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    return g\n\n\ndef create_test_heterograph_2(idtype):\n    src = np.random.randint(0, 50, 25)\n    dst = np.random.randint(0, 50, 25)\n    src1 = np.random.randint(0, 25, 10)\n    dst1 = np.random.randint(0, 25, 10)\n    src2 = np.random.randint(0, 100, 1000)\n    dst2 = np.random.randint(0, 100, 1000)\n    g = dgl.heterograph(\n        {\n            (\"user\", \"becomes\", \"player\"): (src, dst),\n            (\"user\", \"follows\", \"user\"): (src, dst),\n            (\"user\", \"plays\", \"game\"): (src, dst),\n            (\"user\", \"wishes\", \"game\"): (src1, dst1),\n            
(\"developer\", \"develops\", \"game\"): (src2, dst2),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    return g\n\n\ndef create_test_heterograph_large(idtype):\n    src = np.random.randint(0, 50, 2500)\n    dst = np.random.randint(0, 50, 2500)\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (src, dst),\n            (\"user\", \"plays\", \"game\"): (src, dst),\n            (\"user\", \"wishes\", \"game\"): (src, dst),\n            (\"developer\", \"develops\", \"game\"): (src, dst),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    return g\n\n\n@parametrize_idtype\ndef test_unary_copy_u(idtype):\n    def _test(mfunc, rfunc):\n        g = create_test_heterograph_2(idtype)\n        g0 = create_test_heterograph(idtype)\n        g1 = create_test_heterograph_large(idtype)\n        cross_reducer = rfunc.__name__\n        x1 = F.randn((g.num_nodes(\"user\"), feat_size))\n        x2 = F.randn((g.num_nodes(\"developer\"), feat_size))\n        F.attach_grad(x1)\n        F.attach_grad(x2)\n        g.nodes[\"user\"].data[\"h\"] = x1\n        g.nodes[\"developer\"].data[\"h\"] = x2\n\n        #################################################################\n        #  multi_update_all(): call msg_passing separately for each etype\n        #################################################################\n\n        with F.record_grad():\n            g.multi_update_all(\n                {\n                    etype: (mfunc(\"h\", \"m\"), rfunc(\"m\", \"y\"))\n                    for etype in g.canonical_etypes\n                },\n                cross_reducer,\n            )\n            r1 = g.nodes[\"game\"].data[\"y\"].clone()\n            r2 = g.nodes[\"user\"].data[\"y\"].clone()\n            r3 = g.nodes[\"player\"].data[\"y\"].clone()\n            loss = r1.sum() 
+ r2.sum() + r3.sum()\n            F.backward(loss)\n            n_grad1 = F.grad(g.nodes[\"user\"].data[\"h\"]).clone()\n            n_grad2 = F.grad(g.nodes[\"developer\"].data[\"h\"]).clone()\n\n        g.nodes[\"user\"].data.clear()\n        g.nodes[\"developer\"].data.clear()\n        g.nodes[\"game\"].data.clear()\n        g.nodes[\"player\"].data.clear()\n\n        #################################################################\n        #  update_all(): call msg_passing for all etypes\n        #################################################################\n\n        F.attach_grad(x1)\n        F.attach_grad(x2)\n        g.nodes[\"user\"].data[\"h\"] = x1\n        g.nodes[\"developer\"].data[\"h\"] = x2\n\n        with F.record_grad():\n            g.update_all(mfunc(\"h\", \"m\"), rfunc(\"m\", \"y\"))\n            r4 = g.nodes[\"game\"].data[\"y\"]\n            r5 = g.nodes[\"user\"].data[\"y\"]\n            r6 = g.nodes[\"player\"].data[\"y\"]\n            loss = r4.sum() + r5.sum() + r6.sum()\n            F.backward(loss)\n            n_grad3 = F.grad(g.nodes[\"user\"].data[\"h\"])\n            n_grad4 = F.grad(g.nodes[\"developer\"].data[\"h\"])\n\n        assert F.allclose(r1, r4)\n        assert F.allclose(r2, r5)\n        assert F.allclose(r3, r6)\n        assert F.allclose(n_grad1, n_grad3)\n        assert F.allclose(n_grad2, n_grad4)\n\n    _test(fn.copy_u, fn.sum)\n    _test(fn.copy_u, fn.max)\n    _test(fn.copy_u, fn.min)\n    # _test('copy_u', 'mean')\n\n\n@parametrize_idtype\ndef test_unary_copy_e(idtype):\n    def _test(mfunc, rfunc):\n        g = create_test_heterograph_large(idtype)\n        g0 = create_test_heterograph_2(idtype)\n        g1 = create_test_heterograph(idtype)\n        cross_reducer = rfunc.__name__\n        x1 = F.randn((g.num_edges(\"plays\"), feat_size))\n        x2 = F.randn((g.num_edges(\"follows\"), feat_size))\n        x3 = F.randn((g.num_edges(\"develops\"), feat_size))\n        x4 = F.randn((g.num_edges(\"wishes\"), 
feat_size))\n        F.attach_grad(x1)\n        F.attach_grad(x2)\n        F.attach_grad(x3)\n        F.attach_grad(x4)\n        g[\"plays\"].edata[\"eid\"] = x1\n        g[\"follows\"].edata[\"eid\"] = x2\n        g[\"develops\"].edata[\"eid\"] = x3\n        g[\"wishes\"].edata[\"eid\"] = x4\n\n        #################################################################\n        #  multi_update_all(): call msg_passing separately for each etype\n        #################################################################\n\n        with F.record_grad():\n            g.multi_update_all(\n                {\n                    \"plays\": (mfunc(\"eid\", \"m\"), rfunc(\"m\", \"y\")),\n                    \"follows\": (mfunc(\"eid\", \"m\"), rfunc(\"m\", \"y\")),\n                    \"develops\": (mfunc(\"eid\", \"m\"), rfunc(\"m\", \"y\")),\n                    \"wishes\": (mfunc(\"eid\", \"m\"), rfunc(\"m\", \"y\")),\n                },\n                cross_reducer,\n            )\n            r1 = g.nodes[\"game\"].data[\"y\"].clone()\n            r2 = g.nodes[\"user\"].data[\"y\"].clone()\n            loss = r1.sum() + r2.sum()\n            F.backward(loss)\n            e_grad1 = F.grad(g[\"develops\"].edata[\"eid\"]).clone()\n            e_grad2 = F.grad(g[\"plays\"].edata[\"eid\"]).clone()\n            e_grad3 = F.grad(g[\"wishes\"].edata[\"eid\"]).clone()\n            e_grad4 = F.grad(g[\"follows\"].edata[\"eid\"]).clone()\n        {etype: (g[etype].edata.clear()) for _, etype, _ in g.canonical_etypes},\n\n        #################################################################\n        #  update_all(): call msg_passing for all etypes\n        #################################################################\n\n        # TODO(Israt): output type can be None in multi_update and empty\n        F.attach_grad(x1)\n        F.attach_grad(x2)\n        F.attach_grad(x3)\n        F.attach_grad(x4)\n\n        g[\"plays\"].edata[\"eid\"] = x1\n        
g[\"follows\"].edata[\"eid\"] = x2\n        g[\"develops\"].edata[\"eid\"] = x3\n        g[\"wishes\"].edata[\"eid\"] = x4\n\n        with F.record_grad():\n            g.update_all(mfunc(\"eid\", \"m\"), rfunc(\"m\", \"y\"))\n            r3 = g.nodes[\"game\"].data[\"y\"]\n            r4 = g.nodes[\"user\"].data[\"y\"]\n            loss = r3.sum() + r4.sum()\n            F.backward(loss)\n            e_grad5 = F.grad(g[\"develops\"].edata[\"eid\"])\n            e_grad6 = F.grad(g[\"plays\"].edata[\"eid\"])\n            e_grad7 = F.grad(g[\"wishes\"].edata[\"eid\"])\n            e_grad8 = F.grad(g[\"follows\"].edata[\"eid\"])\n\n        # # correctness check\n        def _print_error(a, b):\n            for i, (x, y) in enumerate(\n                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())\n            ):\n                if not np.allclose(x, y):\n                    print(\"@{} {} v.s. {}\".format(i, x, y))\n\n        assert F.allclose(r1, r3)\n        assert F.allclose(r2, r4)\n        assert F.allclose(e_grad1, e_grad5)\n        assert F.allclose(e_grad2, e_grad6)\n        assert F.allclose(e_grad3, e_grad7)\n        assert F.allclose(e_grad4, e_grad8)\n\n    _test(fn.copy_e, fn.sum)\n    _test(fn.copy_e, fn.max)\n    _test(fn.copy_e, fn.min)\n    # _test('copy_e', 'mean')\n\n\n@parametrize_idtype\ndef test_binary_op(idtype):\n    def _test(lhs, rhs, binary_op, reducer):\n        g = create_test_heterograph(idtype)\n\n        x1 = F.randn((g.num_nodes(\"user\"), feat_size))\n        x2 = F.randn((g.num_nodes(\"developer\"), feat_size))\n        x3 = F.randn((g.num_nodes(\"game\"), feat_size))\n\n        F.attach_grad(x1)\n        F.attach_grad(x2)\n        F.attach_grad(x3)\n        g.nodes[\"user\"].data[\"h\"] = x1\n        g.nodes[\"developer\"].data[\"h\"] = x2\n        g.nodes[\"game\"].data[\"h\"] = x3\n\n        x1 = F.randn((4, feat_size))\n        x2 = F.randn((4, feat_size))\n        x3 = F.randn((3, feat_size))\n        x4 = F.randn((3, 
feat_size))\n        F.attach_grad(x1)\n        F.attach_grad(x2)\n        F.attach_grad(x3)\n        F.attach_grad(x4)\n        g[\"plays\"].edata[\"h\"] = x1\n        g[\"follows\"].edata[\"h\"] = x2\n        g[\"develops\"].edata[\"h\"] = x3\n        g[\"wishes\"].edata[\"h\"] = x4\n\n        builtin_msg_name = \"{}_{}_{}\".format(lhs, binary_op, rhs)\n        builtin_msg = getattr(fn, builtin_msg_name)\n        builtin_red = getattr(fn, reducer)\n\n        #################################################################\n        #  multi_update_all(): call msg_passing separately for each etype\n        #################################################################\n\n        with F.record_grad():\n            g.multi_update_all(\n                {\n                    etype: (builtin_msg(\"h\", \"h\", \"m\"), builtin_red(\"m\", \"y\"))\n                    for etype in g.canonical_etypes\n                },\n                \"sum\",\n            )\n            r1 = g.nodes[\"game\"].data[\"y\"]\n            F.backward(r1, F.ones(r1.shape))\n            n_grad1 = F.grad(r1)\n\n        #################################################################\n        #  update_all(): call msg_passing for all etypes\n        #################################################################\n\n        g.update_all(builtin_msg(\"h\", \"h\", \"m\"), builtin_red(\"m\", \"y\"))\n        r2 = g.nodes[\"game\"].data[\"y\"]\n        F.backward(r2, F.ones(r2.shape))\n        n_grad2 = F.grad(r2)\n\n        # correctness check\n        def _print_error(a, b):\n            for i, (x, y) in enumerate(\n                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())\n            ):\n                if not np.allclose(x, y):\n                    print(\"@{} {} v.s. 
{}\".format(i, x, y))\n\n        if not F.allclose(r1, r2):\n            _print_error(r1, r2)\n        assert F.allclose(r1, r2)\n        # TODO (Israt): r1 and r2 have different frad func associated with\n        # if not F.allclose(n_grad1, n_grad2):\n        #     print('node grad')\n        #     _print_error(n_grad1, n_grad2)\n        # assert(F.allclose(n_grad1, n_grad2))\n\n    target = [\"u\", \"v\", \"e\"]\n    for lhs, rhs in product(target, target):\n        if lhs == rhs:\n            continue\n        for binary_op in [\"add\", \"sub\", \"mul\", \"div\"]:\n            # TODO(Israt) :Add support for reduce func \"max\", \"min\", \"mean\"\n            for reducer in [\"sum\"]:\n                print(lhs, rhs, binary_op, reducer)\n                _test(lhs, rhs, binary_op, reducer)\n\n\n# Issue #5873\ndef test_multi_update_all_minmax_reduce_with_isolated_nodes():\n    g = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): ([0, 1, 2, 3], [0, 0, 1, 1]),\n            (\"C\", \"CB\", \"B\"): ([0, 1, 2, 3], [2, 2, 3, 3]),\n        },\n        device=F.ctx(),\n    )\n    g.nodes[\"A\"].data[\"x\"] = F.randn((4, 16))\n    g.nodes[\"C\"].data[\"x\"] = F.randn((4, 16))\n    g.multi_update_all(\n        {\n            \"AB\": (dgl.function.copy_u(\"x\", \"m\"), dgl.function.min(\"m\", \"a1\")),\n            \"CB\": (dgl.function.copy_u(\"x\", \"m\"), dgl.function.min(\"m\", \"a2\")),\n        },\n        cross_reducer=\"min\",\n    )\n    assert not np.isinf(F.asnumpy(g.nodes[\"B\"].data[\"a1\"])).any()\n    assert not np.isinf(F.asnumpy(g.nodes[\"B\"].data[\"a2\"])).any()\n\n    g.multi_update_all(\n        {\n            \"AB\": (dgl.function.copy_u(\"x\", \"m\"), dgl.function.max(\"m\", \"a1\")),\n            \"CB\": (dgl.function.copy_u(\"x\", \"m\"), dgl.function.max(\"m\", \"a2\")),\n        },\n        cross_reducer=\"max\",\n    )\n    assert not np.isinf(F.asnumpy(g.nodes[\"B\"].data[\"a1\"])).any()\n    assert not 
np.isinf(F.asnumpy(g.nodes[\"B\"].data[\"a2\"])).any()\n\n\nif __name__ == \"__main__\":\n    test_unary_copy_u()\n    test_unary_copy_e()\n    test_binary_op()\n"
  },
  {
    "path": "tests/python/common/test_heterograph.py",
    "content": "import itertools\nimport multiprocessing as mp\nimport unittest\nfrom collections import Counter\n\nimport backend as F\n\nimport dgl\nimport dgl.function as fn\nimport networkx as nx\nimport numpy as np\nimport pytest\nimport scipy.sparse as ssp\nfrom dgl import DGLError\nfrom scipy.sparse import rand\nfrom utils import (\n    assert_is_identical_hetero,\n    check_graph_equal,\n    get_cases,\n    parametrize_idtype,\n)\n\n\ndef create_test_heterograph(idtype):\n    # test heterograph from the docstring, plus a user -- wishes -- game relation\n    # 3 users, 2 games, 2 developers\n    # metagraph:\n    #    ('user', 'follows', 'user'),\n    #    ('user', 'plays', 'game'),\n    #    ('user', 'wishes', 'game'),\n    #    ('developer', 'develops', 'game')])\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): ([0, 2], [1, 0]),\n            (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    return g\n\n\ndef create_test_heterograph1(idtype):\n    edges = []\n    edges.extend([(0, 1), (1, 2)])  # follows\n    edges.extend([(0, 3), (1, 3), (2, 4), (1, 4)])  # plays\n    edges.extend([(0, 4), (2, 3)])  # wishes\n    edges.extend([(5, 3), (6, 4)])  # develops\n    edges = tuple(zip(*edges))\n    ntypes = F.tensor([0, 0, 0, 1, 1, 2, 2])\n    etypes = F.tensor([0, 0, 1, 1, 1, 1, 2, 2, 3, 3])\n    g0 = dgl.graph(edges, idtype=idtype, device=F.ctx())\n    g0.ndata[dgl.NTYPE] = ntypes\n    g0.edata[dgl.ETYPE] = etypes\n    return dgl.to_heterogeneous(\n        g0,\n        [\"user\", \"game\", \"developer\"],\n        [\"follows\", \"plays\", \"wishes\", \"develops\"],\n    )\n\n\ndef create_test_heterograph2(idtype):\n    g = dgl.heterograph(\n   
     {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): (\"csr\", ([0, 1, 1, 2], [1, 0], [])),\n            (\"developer\", \"develops\", \"game\"): (\n                \"csc\",\n                ([0, 1, 2], [0, 1], [0, 1]),\n            ),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    return g\n\n\ndef create_test_heterograph3(idtype):\n    g = dgl.heterograph(\n        {\n            (\"user\", \"plays\", \"game\"): (\n                F.tensor([0, 1, 1, 2], dtype=idtype),\n                F.tensor([0, 0, 1, 1], dtype=idtype),\n            ),\n            (\"developer\", \"develops\", \"game\"): (\n                F.tensor([0, 1], dtype=idtype),\n                F.tensor([0, 1], dtype=idtype),\n            ),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n\n    g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"developer\"].data[\"h\"] = F.copy_to(\n        F.tensor([3, 3], dtype=idtype), ctx=F.ctx()\n    )\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1, 1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    return g\n\n\ndef create_test_heterograph4(idtype):\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (\n                F.tensor([0, 1, 1, 2, 2, 2], dtype=idtype),\n                F.tensor([0, 0, 1, 1, 2, 2], dtype=idtype),\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                F.tensor([0, 1], dtype=idtype),\n                F.tensor([0, 1], dtype=idtype),\n            ),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    
g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.edges[\"follows\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2, 3, 4, 5, 6], dtype=idtype), ctx=F.ctx()\n    )\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2], dtype=idtype), ctx=F.ctx()\n    )\n    return g\n\n\ndef create_test_heterograph5(idtype):\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (\n                F.tensor([1, 2], dtype=idtype),\n                F.tensor([0, 1], dtype=idtype),\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                F.tensor([0, 1], dtype=idtype),\n                F.tensor([0, 1], dtype=idtype),\n            ),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.edges[\"follows\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2], dtype=idtype), ctx=F.ctx()\n    )\n    return g\n\n\ndef get_redfn(name):\n    return getattr(F, name)\n\n\n@parametrize_idtype\ndef test_create(idtype):\n    device = F.ctx()\n    g0 = create_test_heterograph(idtype)\n    g1 = create_test_heterograph1(idtype)\n    g2 = create_test_heterograph2(idtype)\n    assert set(g0.ntypes) == set(g1.ntypes) == set(g2.ntypes)\n    assert (\n        set(g0.canonical_etypes)\n        == set(g1.canonical_etypes)\n        == set(g2.canonical_etypes)\n    )\n\n    # Create a bipartite graph from a SciPy matrix\n    src_ids = np.array([2, 3, 4])\n    dst_ids = np.array([1, 2, 3])\n    eweight = np.array([0.2, 0.3, 
0.5])\n    sp_mat = ssp.coo_matrix((eweight, (src_ids, dst_ids)))\n    g = dgl.bipartite_from_scipy(\n        sp_mat,\n        utype=\"user\",\n        etype=\"plays\",\n        vtype=\"game\",\n        idtype=idtype,\n        device=device,\n    )\n    assert g.idtype == idtype\n    assert g.device == device\n    assert g.num_src_nodes() == 5\n    assert g.num_dst_nodes() == 4\n    assert g.num_edges() == 3\n    src, dst = g.edges()\n    assert F.allclose(src, F.tensor([2, 3, 4], dtype=idtype))\n    assert F.allclose(dst, F.tensor([1, 2, 3], dtype=idtype))\n    g = dgl.bipartite_from_scipy(\n        sp_mat,\n        utype=\"_U\",\n        etype=\"_E\",\n        vtype=\"_V\",\n        eweight_name=\"w\",\n        idtype=idtype,\n        device=device,\n    )\n    assert F.allclose(g.edata[\"w\"], F.tensor(eweight))\n\n    # Create a bipartite graph from a NetworkX graph\n    nx_g = nx.DiGraph()\n    nx_g.add_nodes_from(\n        [1, 3], bipartite=0, feat1=np.zeros((2)), feat2=np.ones((2))\n    )\n    nx_g.add_nodes_from([2, 4, 5], bipartite=1, feat3=np.zeros((3)))\n    nx_g.add_edge(1, 4, weight=np.ones((1)), eid=np.array([1]))\n    nx_g.add_edge(3, 5, weight=np.ones((1)), eid=np.array([0]))\n    g = dgl.bipartite_from_networkx(\n        nx_g,\n        utype=\"user\",\n        etype=\"plays\",\n        vtype=\"game\",\n        idtype=idtype,\n        device=device,\n    )\n    assert g.idtype == idtype\n    assert g.device == device\n    assert g.num_src_nodes() == 2\n    assert g.num_dst_nodes() == 3\n    assert g.num_edges() == 2\n    src, dst = g.edges()\n    assert F.allclose(src, F.tensor([0, 1], dtype=idtype))\n    assert F.allclose(dst, F.tensor([1, 2], dtype=idtype))\n    g = dgl.bipartite_from_networkx(\n        nx_g,\n        utype=\"_U\",\n        etype=\"_E\",\n        vtype=\"V\",\n        u_attrs=[\"feat1\", \"feat2\"],\n        e_attrs=[\"weight\"],\n        v_attrs=[\"feat3\"],\n    )\n    assert F.allclose(g.srcdata[\"feat1\"], 
F.tensor(np.zeros((2, 2))))\n    assert F.allclose(g.srcdata[\"feat2\"], F.tensor(np.ones((2, 2))))\n    assert F.allclose(g.dstdata[\"feat3\"], F.tensor(np.zeros((3, 3))))\n    assert F.allclose(g.edata[\"weight\"], F.tensor(np.ones((2, 1))))\n    g = dgl.bipartite_from_networkx(\n        nx_g,\n        utype=\"_U\",\n        etype=\"_E\",\n        vtype=\"V\",\n        edge_id_attr_name=\"eid\",\n        idtype=idtype,\n        device=device,\n    )\n    src, dst = g.edges()\n    assert F.allclose(src, F.tensor([1, 0], dtype=idtype))\n    assert F.allclose(dst, F.tensor([2, 1], dtype=idtype))\n\n    # create from scipy\n    spmat = ssp.coo_matrix(([1, 1, 1], ([0, 0, 1], [2, 3, 2])), shape=(4, 4))\n    g = dgl.from_scipy(spmat, idtype=idtype, device=device)\n    assert g.num_nodes() == 4\n    assert g.num_edges() == 3\n    assert g.idtype == idtype\n    assert g.device == device\n\n    # test inferring number of nodes for heterograph\n    g = dgl.heterograph(\n        {\n            (\"l0\", \"e0\", \"l1\"): ([0, 0], [1, 2]),\n            (\"l0\", \"e1\", \"l2\"): ([2], [2]),\n            (\"l2\", \"e2\", \"l2\"): ([1, 3], [1, 3]),\n        },\n        idtype=idtype,\n        device=device,\n    )\n    assert g.num_nodes(\"l0\") == 3\n    assert g.num_nodes(\"l1\") == 3\n    assert g.num_nodes(\"l2\") == 4\n    assert g.idtype == idtype\n    assert g.device == device\n\n    # test if validate flag works\n    # homo graph\n    with pytest.raises(DGLError):\n        g = dgl.graph(\n            ([0, 0, 0, 1, 1, 2], [0, 1, 2, 0, 1, 2]),\n            num_nodes=2,\n            idtype=idtype,\n            device=device,\n        )\n\n    # bipartite graph\n    def _test_validate_bipartite(card):\n        with pytest.raises(DGLError):\n            g = dgl.heterograph(\n                {(\"_U\", \"_E\", \"_V\"): ([0, 0, 1, 1, 2], [1, 1, 2, 2, 3])},\n                {\"_U\": card[0], \"_V\": card[1]},\n                idtype=idtype,\n                device=device,\n         
   )\n\n    _test_validate_bipartite((3, 3))\n    _test_validate_bipartite((2, 4))\n\n    # test from_scipy\n    num_nodes = 10\n    density = 0.25\n    for fmt in [\"csr\", \"coo\", \"csc\"]:\n        adj = rand(num_nodes, num_nodes, density=density, format=fmt)\n        g = dgl.from_scipy(adj, eweight_name=\"w\", idtype=idtype)\n        assert g.idtype == idtype\n        assert g.device == F.cpu()\n        assert F.array_equal(\n            g.edata[\"w\"], F.copy_to(F.tensor(adj.data), F.cpu())\n        )\n\n\ndef test_create2():\n    mat = ssp.random(20, 30, 0.1)\n\n    # coo\n    mat = mat.tocoo()\n    row = F.tensor(mat.row, dtype=F.int64)\n    col = F.tensor(mat.col, dtype=F.int64)\n    g = dgl.heterograph(\n        {(\"A\", \"AB\", \"B\"): (\"coo\", (row, col))},\n        num_nodes_dict={\"A\": 20, \"B\": 30},\n    )\n\n    # csr\n    mat = mat.tocsr()\n    indptr = F.tensor(mat.indptr, dtype=F.int64)\n    indices = F.tensor(mat.indices, dtype=F.int64)\n    data = F.tensor([], dtype=F.int64)\n    g = dgl.heterograph(\n        {(\"A\", \"AB\", \"B\"): (\"csr\", (indptr, indices, data))},\n        num_nodes_dict={\"A\": 20, \"B\": 30},\n    )\n\n    # csc\n    mat = mat.tocsc()\n    indptr = F.tensor(mat.indptr, dtype=F.int64)\n    indices = F.tensor(mat.indices, dtype=F.int64)\n    data = F.tensor([], dtype=F.int64)\n    g = dgl.heterograph(\n        {(\"A\", \"AB\", \"B\"): (\"csc\", (indptr, indices, data))},\n        num_nodes_dict={\"A\": 20, \"B\": 30},\n    )\n\n\n@parametrize_idtype\ndef test_query(idtype):\n    g = create_test_heterograph(idtype)\n\n    ntypes = [\"user\", \"game\", \"developer\"]\n    canonical_etypes = [\n        (\"user\", \"follows\", \"user\"),\n        (\"user\", \"plays\", \"game\"),\n        (\"user\", \"wishes\", \"game\"),\n        (\"developer\", \"develops\", \"game\"),\n    ]\n    etypes = [\"follows\", \"plays\", \"wishes\", \"develops\"]\n\n    # node & edge types\n    assert set(ntypes) == set(g.ntypes)\n    assert 
set(etypes) == set(g.etypes)\n    assert set(canonical_etypes) == set(g.canonical_etypes)\n\n    # metagraph\n    mg = g.metagraph()\n    assert set(g.ntypes) == set(mg.nodes)\n    etype_triplets = [(u, v, e) for u, v, e in mg.edges(keys=True)]\n    assert set(\n        [\n            (\"user\", \"user\", \"follows\"),\n            (\"user\", \"game\", \"plays\"),\n            (\"user\", \"game\", \"wishes\"),\n            (\"developer\", \"game\", \"develops\"),\n        ]\n    ) == set(etype_triplets)\n    for i in range(len(etypes)):\n        assert g.to_canonical_etype(etypes[i]) == canonical_etypes[i]\n\n    def _test(g):\n        # number of nodes\n        assert [g.num_nodes(ntype) for ntype in ntypes] == [3, 2, 2]\n\n        # number of edges\n        assert [g.num_edges(etype) for etype in etypes] == [2, 4, 2, 2]\n\n        # has_nodes\n        for ntype in ntypes:\n            n = g.num_nodes(ntype)\n            for i in range(n):\n                assert g.has_nodes(i, ntype)\n            assert not g.has_nodes(n, ntype)\n            assert np.array_equal(\n                F.asnumpy(g.has_nodes([0, n], ntype)).astype(\"int32\"), [1, 0]\n            )\n\n        assert not g.is_multigraph\n\n        for etype in etypes:\n            srcs, dsts = edges[etype]\n            for src, dst in zip(srcs, dsts):\n                assert g.has_edges_between(src, dst, etype)\n            assert F.asnumpy(g.has_edges_between(srcs, dsts, etype)).all()\n\n            srcs, dsts = negative_edges[etype]\n            for src, dst in zip(srcs, dsts):\n                assert not g.has_edges_between(src, dst, etype)\n            assert not F.asnumpy(g.has_edges_between(srcs, dsts, etype)).any()\n\n            srcs, dsts = edges[etype]\n            n_edges = len(srcs)\n\n            # predecessors & in_edges & in_degree\n            pred = [s for s, d in zip(srcs, dsts) if d == 0]\n            assert set(F.asnumpy(g.predecessors(0, etype)).tolist()) == set(\n                
pred\n            )\n            u, v = g.in_edges([0], etype=etype)\n            assert F.asnumpy(v).tolist() == [0] * len(pred)\n            assert set(F.asnumpy(u).tolist()) == set(pred)\n            assert g.in_degrees(0, etype) == len(pred)\n\n            # successors & out_edges & out_degree\n            succ = [d for s, d in zip(srcs, dsts) if s == 0]\n            assert set(F.asnumpy(g.successors(0, etype)).tolist()) == set(succ)\n            u, v = g.out_edges([0], etype=etype)\n            assert F.asnumpy(u).tolist() == [0] * len(succ)\n            assert set(F.asnumpy(v).tolist()) == set(succ)\n            assert g.out_degrees(0, etype) == len(succ)\n\n            # edge_ids\n            for i, (src, dst) in enumerate(zip(srcs, dsts)):\n                assert g.edge_ids(src, dst, etype=etype) == i\n                _, _, eid = g.edge_ids(src, dst, etype=etype, return_uv=True)\n                assert eid == i\n            assert F.asnumpy(\n                g.edge_ids(srcs, dsts, etype=etype)\n            ).tolist() == list(range(n_edges))\n            u, v, e = g.edge_ids(srcs, dsts, etype=etype, return_uv=True)\n            u, v, e = F.asnumpy(u), F.asnumpy(v), F.asnumpy(e)\n            assert u[e].tolist() == srcs\n            assert v[e].tolist() == dsts\n\n            # find_edges\n            for eid in [\n                list(range(n_edges)),\n                np.arange(n_edges),\n                F.astype(F.arange(0, n_edges), g.idtype),\n            ]:\n                u, v = g.find_edges(eid, etype)\n                assert F.asnumpy(u).tolist() == srcs\n                assert F.asnumpy(v).tolist() == dsts\n\n            # all_edges.\n            for order in [\"eid\"]:\n                u, v, e = g.edges(\"all\", order, etype)\n                assert F.asnumpy(u).tolist() == srcs\n                assert F.asnumpy(v).tolist() == dsts\n                assert F.asnumpy(e).tolist() == list(range(n_edges))\n\n            # in_degrees & out_degrees\n      
      in_degrees = F.asnumpy(g.in_degrees(etype=etype))\n            out_degrees = F.asnumpy(g.out_degrees(etype=etype))\n            src_count = Counter(srcs)\n            dst_count = Counter(dsts)\n            utype, _, vtype = g.to_canonical_etype(etype)\n            for i in range(g.num_nodes(utype)):\n                assert out_degrees[i] == src_count[i]\n            for i in range(g.num_nodes(vtype)):\n                assert in_degrees[i] == dst_count[i]\n\n    edges = {\n        \"follows\": ([0, 1], [1, 2]),\n        \"plays\": ([0, 1, 2, 1], [0, 0, 1, 1]),\n        \"wishes\": ([0, 2], [1, 0]),\n        \"develops\": ([0, 1], [0, 1]),\n    }\n    # edges that does not exist in the graph\n    negative_edges = {\n        \"follows\": ([0, 1], [0, 1]),\n        \"plays\": ([0, 2], [1, 0]),\n        \"wishes\": ([0, 1], [0, 1]),\n        \"develops\": ([0, 1], [1, 0]),\n    }\n    g = create_test_heterograph(idtype)\n    _test(g)\n    g = create_test_heterograph1(idtype)\n    _test(g)\n    if F._default_context_str != \"gpu\":\n        # XXX: CUDA COO operators have not been live yet.\n        g = create_test_heterograph2(idtype)\n        _test(g)\n\n    etypes = canonical_etypes\n    edges = {\n        (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n        (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n        (\"user\", \"wishes\", \"game\"): ([0, 2], [1, 0]),\n        (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n    }\n    # edges that does not exist in the graph\n    negative_edges = {\n        (\"user\", \"follows\", \"user\"): ([0, 1], [0, 1]),\n        (\"user\", \"plays\", \"game\"): ([0, 2], [1, 0]),\n        (\"user\", \"wishes\", \"game\"): ([0, 1], [0, 1]),\n        (\"developer\", \"develops\", \"game\"): ([0, 1], [1, 0]),\n    }\n    g = create_test_heterograph(idtype)\n    _test(g)\n    g = create_test_heterograph1(idtype)\n    _test(g)\n    if F._default_context_str != \"gpu\":\n        # XXX: CUDA COO 
operators have not been live yet.\n        g = create_test_heterograph2(idtype)\n        _test(g)\n\n    # test repr\n    print(g)\n\n\n@parametrize_idtype\ndef test_empty_query(idtype):\n    g = dgl.graph(([1, 2, 3], [0, 4, 5]), idtype=idtype, device=F.ctx())\n    g.add_nodes(0)\n    g.add_edges([], [])\n    g.remove_edges([])\n    g.remove_nodes([])\n    assert F.shape(g.has_nodes([])) == (0,)\n    assert F.shape(g.has_edges_between([], [])) == (0,)\n    g.edge_ids([], [])\n    g.edge_ids([], [], return_uv=True)\n    g.find_edges([])\n\n    assert F.shape(g.in_edges([], form=\"eid\")) == (0,)\n    u, v = g.in_edges([], form=\"uv\")\n    assert F.shape(u) == (0,)\n    assert F.shape(v) == (0,)\n    u, v, e = g.in_edges([], form=\"all\")\n    assert F.shape(u) == (0,)\n    assert F.shape(v) == (0,)\n    assert F.shape(e) == (0,)\n\n    assert F.shape(g.out_edges([], form=\"eid\")) == (0,)\n    u, v = g.out_edges([], form=\"uv\")\n    assert F.shape(u) == (0,)\n    assert F.shape(v) == (0,)\n    u, v, e = g.out_edges([], form=\"all\")\n    assert F.shape(u) == (0,)\n    assert F.shape(v) == (0,)\n    assert F.shape(e) == (0,)\n\n    assert F.shape(g.in_degrees([])) == (0,)\n    assert F.shape(g.out_degrees([])) == (0,)\n\n    g = dgl.graph(([], []), idtype=idtype, device=F.ctx())\n    error_thrown = True\n    try:\n        g.in_degrees([0])\n        fail = False\n    except:\n        pass\n    assert error_thrown\n    error_thrown = True\n    try:\n        g.out_degrees([0])\n        fail = False\n    except:\n        pass\n    assert error_thrown\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"GPU does not have COO impl.\"\n)\ndef _test_hypersparse():\n    N1 = 1 << 50  # should crash if allocated a CSR\n    N2 = 1 << 48\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (\n                F.tensor([0], F.int64),\n                F.tensor([1], F.int64),\n            ),\n            (\"user\", \"plays\", 
\"game\"): (\n                F.tensor([0], F.int64),\n                F.tensor([N2], F.int64),\n            ),\n        },\n        {\"user\": N1, \"game\": N1},\n        device=F.ctx(),\n    )\n    assert g.num_nodes(\"user\") == N1\n    assert g.num_nodes(\"game\") == N1\n    assert g.num_edges(\"follows\") == 1\n    assert g.num_edges(\"plays\") == 1\n\n    assert g.has_edges_between(0, 1, \"follows\")\n    assert not g.has_edges_between(0, 0, \"follows\")\n    mask = F.asnumpy(g.has_edges_between([0, 0], [0, 1], \"follows\")).tolist()\n    assert mask == [0, 1]\n\n    assert g.has_edges_between(0, N2, \"plays\")\n    assert not g.has_edges_between(0, 0, \"plays\")\n    mask = F.asnumpy(g.has_edges_between([0, 0], [0, N2], \"plays\")).tolist()\n    assert mask == [0, 1]\n\n    assert F.asnumpy(g.predecessors(0, \"follows\")).tolist() == []\n    assert F.asnumpy(g.successors(0, \"follows\")).tolist() == [1]\n    assert F.asnumpy(g.predecessors(1, \"follows\")).tolist() == [0]\n    assert F.asnumpy(g.successors(1, \"follows\")).tolist() == []\n\n    assert F.asnumpy(g.predecessors(0, \"plays\")).tolist() == []\n    assert F.asnumpy(g.successors(0, \"plays\")).tolist() == [N2]\n    assert F.asnumpy(g.predecessors(N2, \"plays\")).tolist() == [0]\n    assert F.asnumpy(g.successors(N2, \"plays\")).tolist() == []\n\n    assert g.edge_ids(0, 1, etype=\"follows\") == 0\n    assert g.edge_ids(0, N2, etype=\"plays\") == 0\n\n    u, v = g.find_edges([0], \"follows\")\n    assert F.asnumpy(u).tolist() == [0]\n    assert F.asnumpy(v).tolist() == [1]\n    u, v = g.find_edges([0], \"plays\")\n    assert F.asnumpy(u).tolist() == [0]\n    assert F.asnumpy(v).tolist() == [N2]\n    u, v, e = g.all_edges(\"all\", \"eid\", \"follows\")\n    assert F.asnumpy(u).tolist() == [0]\n    assert F.asnumpy(v).tolist() == [1]\n    assert F.asnumpy(e).tolist() == [0]\n    u, v, e = g.all_edges(\"all\", \"eid\", \"plays\")\n    assert F.asnumpy(u).tolist() == [0]\n    assert 
F.asnumpy(v).tolist() == [N2]\n    assert F.asnumpy(e).tolist() == [0]\n\n    assert g.in_degrees(0, \"follows\") == 0\n    assert g.in_degrees(1, \"follows\") == 1\n    assert F.asnumpy(g.in_degrees([0, 1], \"follows\")).tolist() == [0, 1]\n    assert g.in_degrees(0, \"plays\") == 0\n    assert g.in_degrees(N2, \"plays\") == 1\n    assert F.asnumpy(g.in_degrees([0, N2], \"plays\")).tolist() == [0, 1]\n    assert g.out_degrees(0, \"follows\") == 1\n    assert g.out_degrees(1, \"follows\") == 0\n    assert F.asnumpy(g.out_degrees([0, 1], \"follows\")).tolist() == [1, 0]\n    assert g.out_degrees(0, \"plays\") == 1\n    assert g.out_degrees(N2, \"plays\") == 0\n    assert F.asnumpy(g.out_degrees([0, N2], \"plays\")).tolist() == [1, 0]\n\n\ndef _test_edge_ids():\n    N1 = 1 << 50  # should crash if allocated a CSR\n    N2 = 1 << 48\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (\n                F.tensor([0], F.int64),\n                F.tensor([1], F.int64),\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                F.tensor([0], F.int64),\n                F.tensor([N2], F.int64),\n            ),\n        },\n        {\"user\": N1, \"game\": N1},\n    )\n    with pytest.raises(DGLError):\n        eid = g.edge_ids(0, 0, etype=\"follows\")\n\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (\n                F.tensor([0, 0], F.int64),\n                F.tensor([1, 1], F.int64),\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                F.tensor([0], F.int64),\n                F.tensor([N2], F.int64),\n            ),\n        },\n        {\"user\": N1, \"game\": N1},\n        device=F.cpu(),\n    )\n\n    eid = g2.edge_ids(0, 1, etype=\"follows\")\n    assert eid == 0\n\n\n@pytest.mark.skipif(\n    F.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_adj(idtype):\n    g = 
create_test_heterograph(idtype)\n    adj = g.adj(\"follows\")\n    assert F.asnumpy(adj.indices()).tolist() == [[0, 1], [1, 2]]\n    assert np.allclose(F.asnumpy(adj.val), np.array([1, 1]))\n    g.edata[\"h\"] = {(\"user\", \"plays\", \"game\"): F.tensor([1, 2, 3, 4])}\n    print(g.edata[\"h\"])\n    adj = g.adj(\"plays\", \"h\")\n    assert F.asnumpy(adj.indices()).tolist() == [[0, 1, 2, 1], [0, 0, 1, 1]]\n    assert np.allclose(F.asnumpy(adj.val), np.array([1, 2, 3, 4]))\n\n\n@parametrize_idtype\ndef test_adj_external(idtype):\n    g = create_test_heterograph(idtype)\n    adj = F.sparse_to_numpy(g.adj_external(transpose=True, etype=\"follows\"))\n    assert np.allclose(\n        adj, np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])\n    )\n    adj = F.sparse_to_numpy(g.adj_external(transpose=False, etype=\"follows\"))\n    assert np.allclose(\n        adj, np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]])\n    )\n    adj = F.sparse_to_numpy(g.adj_external(transpose=True, etype=\"plays\"))\n    assert np.allclose(adj, np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 1.0]]))\n    adj = F.sparse_to_numpy(g.adj_external(transpose=False, etype=\"plays\"))\n    assert np.allclose(adj, np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]))\n\n    adj = g.adj_external(transpose=True, scipy_fmt=\"csr\", etype=\"follows\")\n    assert np.allclose(\n        adj.todense(),\n        np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),\n    )\n    adj = g.adj_external(transpose=True, scipy_fmt=\"coo\", etype=\"follows\")\n    assert np.allclose(\n        adj.todense(),\n        np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),\n    )\n    adj = g.adj_external(transpose=True, scipy_fmt=\"csr\", etype=\"plays\")\n    assert np.allclose(\n        adj.todense(), np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 1.0]])\n    )\n    adj = g.adj_external(transpose=True, scipy_fmt=\"coo\", etype=\"plays\")\n    assert np.allclose(\n        adj.todense(), 
np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 1.0]])\n    )\n    adj = F.sparse_to_numpy(g[\"follows\"].adj_external(transpose=True))\n    assert np.allclose(\n        adj, np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])\n    )\n\n\n@parametrize_idtype\ndef test_inc(idtype):\n    g = create_test_heterograph(idtype)\n    adj = F.sparse_to_numpy(g[\"follows\"].inc(\"in\"))\n    assert np.allclose(adj, np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]))\n    adj = F.sparse_to_numpy(g[\"follows\"].inc(\"out\"))\n    assert np.allclose(adj, np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]))\n    adj = F.sparse_to_numpy(g[\"follows\"].inc(\"both\"))\n    assert np.allclose(adj, np.array([[-1.0, 0.0], [1.0, -1.0], [0.0, 1.0]]))\n    adj = F.sparse_to_numpy(g.inc(\"in\", etype=\"plays\"))\n    assert np.allclose(\n        adj, np.array([[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]])\n    )\n    adj = F.sparse_to_numpy(g.inc(\"out\", etype=\"plays\"))\n    assert np.allclose(\n        adj,\n        np.array(\n            [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0]]\n        ),\n    )\n    adj = F.sparse_to_numpy(g.inc(\"both\", etype=\"follows\"))\n    assert np.allclose(adj, np.array([[-1.0, 0.0], [1.0, -1.0], [0.0, 1.0]]))\n\n\n@parametrize_idtype\ndef test_view(idtype):\n    # test single node type\n    g = dgl.heterograph(\n        {(\"user\", \"follows\", \"user\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    f1 = F.randn((3, 6))\n    g.ndata[\"h\"] = f1\n    f2 = g.nodes[\"user\"].data[\"h\"]\n    assert F.array_equal(f1, f2)\n    fail = False\n    try:\n        g.ndata[\"h\"] = {\"user\": f1}\n    except Exception:\n        fail = True\n    assert fail\n\n    # test single edge type\n    f3 = F.randn((2, 4))\n    g.edata[\"h\"] = f3\n    f4 = g.edges[\"follows\"].data[\"h\"]\n    assert F.array_equal(f3, f4)\n    fail = False\n    try:\n        g.edata[\"h\"] = {\"follows\": f3}\n    except Exception:\n        
fail = True\n    assert fail\n\n    # test data view\n    g = create_test_heterograph(idtype)\n\n    f1 = F.randn((3, 6))\n    g.nodes[\"user\"].data[\"h\"] = f1  # ok\n    f2 = g.nodes[\"user\"].data[\"h\"]\n    assert F.array_equal(f1, f2)\n    assert F.array_equal(g.nodes(\"user\"), F.arange(0, 3, idtype))\n    g.nodes[\"user\"].data.pop(\"h\")\n\n    # multi type ndata\n    f1 = F.randn((3, 6))\n    f2 = F.randn((2, 6))\n    fail = False\n    try:\n        g.ndata[\"h\"] = f1\n    except Exception:\n        fail = True\n    assert fail\n\n    f3 = F.randn((2, 4))\n    g.edges[\"user\", \"follows\", \"user\"].data[\"h\"] = f3\n    f4 = g.edges[\"user\", \"follows\", \"user\"].data[\"h\"]\n    f5 = g.edges[\"follows\"].data[\"h\"]\n    assert F.array_equal(f3, f4)\n    assert F.array_equal(f3, f5)\n    assert F.array_equal(\n        g.edges(etype=\"follows\", form=\"eid\"), F.arange(0, 2, idtype)\n    )\n    g.edges[\"follows\"].data.pop(\"h\")\n\n    f3 = F.randn((2, 4))\n    fail = False\n    try:\n        g.edata[\"h\"] = f3\n    except Exception:\n        fail = True\n    assert fail\n\n    # test srcdata\n    f1 = F.randn((3, 6))\n    g.srcnodes[\"user\"].data[\"h\"] = f1  # ok\n    f2 = g.srcnodes[\"user\"].data[\"h\"]\n    assert F.array_equal(f1, f2)\n    assert F.array_equal(g.srcnodes(\"user\"), F.arange(0, 3, idtype))\n    g.srcnodes[\"user\"].data.pop(\"h\")\n\n    # test dstdata\n    f1 = F.randn((3, 6))\n    g.dstnodes[\"user\"].data[\"h\"] = f1  # ok\n    f2 = g.dstnodes[\"user\"].data[\"h\"]\n    assert F.array_equal(f1, f2)\n    assert F.array_equal(g.dstnodes(\"user\"), F.arange(0, 3, idtype))\n    g.dstnodes[\"user\"].data.pop(\"h\")\n\n\n@parametrize_idtype\ndef test_view1(idtype):\n    # test relation view\n    HG = create_test_heterograph(idtype)\n    ntypes = [\"user\", \"game\", \"developer\"]\n    canonical_etypes = [\n        (\"user\", \"follows\", \"user\"),\n        (\"user\", \"plays\", \"game\"),\n        (\"user\", \"wishes\", 
\"game\"),\n        (\"developer\", \"develops\", \"game\"),\n    ]\n    etypes = [\"follows\", \"plays\", \"wishes\", \"develops\"]\n\n    def _test_query():\n        for etype in etypes:\n            utype, _, vtype = HG.to_canonical_etype(etype)\n            g = HG[etype]\n            srcs, dsts = edges[etype]\n            for src, dst in zip(srcs, dsts):\n                assert g.has_edges_between(src, dst)\n            assert F.asnumpy(g.has_edges_between(srcs, dsts)).all()\n\n            srcs, dsts = negative_edges[etype]\n            for src, dst in zip(srcs, dsts):\n                assert not g.has_edges_between(src, dst)\n            assert not F.asnumpy(g.has_edges_between(srcs, dsts)).any()\n\n            srcs, dsts = edges[etype]\n            n_edges = len(srcs)\n\n            # predecessors & in_edges & in_degree\n            pred = [s for s, d in zip(srcs, dsts) if d == 0]\n            assert set(F.asnumpy(g.predecessors(0)).tolist()) == set(pred)\n            u, v = g.in_edges([0])\n            assert F.asnumpy(v).tolist() == [0] * len(pred)\n            assert set(F.asnumpy(u).tolist()) == set(pred)\n            assert g.in_degrees(0) == len(pred)\n\n            # successors & out_edges & out_degree\n            succ = [d for s, d in zip(srcs, dsts) if s == 0]\n            assert set(F.asnumpy(g.successors(0)).tolist()) == set(succ)\n            u, v = g.out_edges([0])\n            assert F.asnumpy(u).tolist() == [0] * len(succ)\n            assert set(F.asnumpy(v).tolist()) == set(succ)\n            assert g.out_degrees(0) == len(succ)\n\n            # edge_ids\n            for i, (src, dst) in enumerate(zip(srcs, dsts)):\n                assert g.edge_ids(src, dst, etype=etype) == i\n                _, _, eid = g.edge_ids(src, dst, etype=etype, return_uv=True)\n                assert eid == i\n            assert F.asnumpy(g.edge_ids(srcs, dsts)).tolist() == list(\n                range(n_edges)\n            )\n            u, v, e = 
g.edge_ids(srcs, dsts, return_uv=True)\n            u, v, e = F.asnumpy(u), F.asnumpy(v), F.asnumpy(e)\n            assert u[e].tolist() == srcs\n            assert v[e].tolist() == dsts\n\n            # find_edges\n            u, v = g.find_edges(list(range(n_edges)))\n            assert F.asnumpy(u).tolist() == srcs\n            assert F.asnumpy(v).tolist() == dsts\n\n            # all_edges.\n            for order in [\"eid\"]:\n                u, v, e = g.all_edges(form=\"all\", order=order)\n                assert F.asnumpy(u).tolist() == srcs\n                assert F.asnumpy(v).tolist() == dsts\n                assert F.asnumpy(e).tolist() == list(range(n_edges))\n\n            # in_degrees & out_degrees\n            in_degrees = F.asnumpy(g.in_degrees())\n            out_degrees = F.asnumpy(g.out_degrees())\n            src_count = Counter(srcs)\n            dst_count = Counter(dsts)\n            for i in range(g.num_nodes(utype)):\n                assert out_degrees[i] == src_count[i]\n            for i in range(g.num_nodes(vtype)):\n                assert in_degrees[i] == dst_count[i]\n\n    edges = {\n        \"follows\": ([0, 1], [1, 2]),\n        \"plays\": ([0, 1, 2, 1], [0, 0, 1, 1]),\n        \"wishes\": ([0, 2], [1, 0]),\n        \"develops\": ([0, 1], [0, 1]),\n    }\n    # edges that does not exist in the graph\n    negative_edges = {\n        \"follows\": ([0, 1], [0, 1]),\n        \"plays\": ([0, 2], [1, 0]),\n        \"wishes\": ([0, 1], [0, 1]),\n        \"develops\": ([0, 1], [1, 0]),\n    }\n    _test_query()\n    etypes = canonical_etypes\n    edges = {\n        (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n        (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n        (\"user\", \"wishes\", \"game\"): ([0, 2], [1, 0]),\n        (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n    }\n    # edges that does not exist in the graph\n    negative_edges = {\n        (\"user\", \"follows\", \"user\"): ([0, 1], 
[0, 1]),\n        (\"user\", \"plays\", \"game\"): ([0, 2], [1, 0]),\n        (\"user\", \"wishes\", \"game\"): ([0, 1], [0, 1]),\n        (\"developer\", \"develops\", \"game\"): ([0, 1], [1, 0]),\n    }\n    _test_query()\n\n    # test features\n    HG.nodes[\"user\"].data[\"h\"] = F.ones((HG.num_nodes(\"user\"), 5))\n    HG.nodes[\"game\"].data[\"m\"] = F.ones((HG.num_nodes(\"game\"), 3)) * 2\n\n    # test only one node type\n    g = HG[\"follows\"]\n    assert g.num_nodes() == 3\n\n    # test ndata and edata\n    f1 = F.randn((3, 6))\n    g.ndata[\"h\"] = f1  # ok\n    f2 = HG.nodes[\"user\"].data[\"h\"]\n    assert F.array_equal(f1, f2)\n    assert F.array_equal(g.nodes(), F.arange(0, 3, g.idtype))\n\n    f3 = F.randn((2, 4))\n    g.edata[\"h\"] = f3\n    f4 = HG.edges[\"follows\"].data[\"h\"]\n    assert F.array_equal(f3, f4)\n    assert F.array_equal(g.edges(form=\"eid\"), F.arange(0, 2, g.idtype))\n\n\n@parametrize_idtype\ndef test_flatten(idtype):\n    def check_mapping(g, fg):\n        if len(fg.ntypes) == 1:\n            SRC = DST = fg.ntypes[0]\n        else:\n            SRC = fg.ntypes[0]\n            DST = fg.ntypes[1]\n\n        etypes = F.asnumpy(fg.edata[dgl.ETYPE]).tolist()\n        eids = F.asnumpy(fg.edata[dgl.EID]).tolist()\n\n        for i, (etype, eid) in enumerate(zip(etypes, eids)):\n            src_g, dst_g = g.find_edges([eid], g.canonical_etypes[etype])\n            src_fg, dst_fg = fg.find_edges([i])\n            # TODO(gq): I feel this code is quite redundant; can we just add new members (like\n            # \"induced_srcid\") to returned heterograph object and not store them as features?\n            assert F.asnumpy(src_g) == F.asnumpy(\n                F.gather_row(fg.nodes[SRC].data[dgl.NID], src_fg)[0]\n            )\n            tid = F.asnumpy(\n                F.gather_row(fg.nodes[SRC].data[dgl.NTYPE], src_fg)\n            ).item()\n            assert g.canonical_etypes[etype][0] == g.ntypes[tid]\n            assert 
F.asnumpy(dst_g) == F.asnumpy(\n                F.gather_row(fg.nodes[DST].data[dgl.NID], dst_fg)[0]\n            )\n            tid = F.asnumpy(\n                F.gather_row(fg.nodes[DST].data[dgl.NTYPE], dst_fg)\n            ).item()\n            assert g.canonical_etypes[etype][2] == g.ntypes[tid]\n\n    # check for wildcard slices\n    g = create_test_heterograph(idtype)\n    g.nodes[\"user\"].data[\"h\"] = F.ones((3, 5))\n    g.nodes[\"game\"].data[\"i\"] = F.ones((2, 5))\n    g.edges[\"plays\"].data[\"e\"] = F.ones((4, 4))\n    g.edges[\"wishes\"].data[\"e\"] = F.ones((2, 4))\n    g.edges[\"wishes\"].data[\"f\"] = F.ones((2, 4))\n\n    fg = g[\"user\", :, \"game\"]  # user--plays->game and user--wishes->game\n    assert len(fg.ntypes) == 2\n    assert fg.ntypes == [\"user\", \"game\"]\n    assert fg.etypes == [\"plays+wishes\"]\n    assert fg.idtype == g.idtype\n    assert fg.device == g.device\n    etype = fg.etypes[0]\n    assert fg[etype] is not None  # Issue #2166\n\n    assert F.array_equal(fg.nodes[\"user\"].data[\"h\"], F.ones((3, 5)))\n    assert F.array_equal(fg.nodes[\"game\"].data[\"i\"], F.ones((2, 5)))\n    assert F.array_equal(fg.edata[\"e\"], F.ones((6, 4)))\n    assert \"f\" not in fg.edata\n\n    etypes = F.asnumpy(fg.edata[dgl.ETYPE]).tolist()\n    eids = F.asnumpy(fg.edata[dgl.EID]).tolist()\n    assert set(zip(etypes, eids)) == set(\n        [(3, 0), (3, 1), (2, 1), (2, 0), (2, 3), (2, 2)]\n    )\n\n    check_mapping(g, fg)\n\n    fg = g[\"user\", :, \"user\"]\n    assert fg.idtype == g.idtype\n    assert fg.device == g.device\n    # NOTE(gq): The node/edge types from the parent graph is returned if there is only one\n    # node/edge type.  
This differs from the behavior above.\n    assert fg.ntypes == [\"user\"]\n    assert fg.etypes == [\"follows\"]\n    u1, v1 = g.edges(etype=\"follows\", order=\"eid\")\n    u2, v2 = fg.edges(etype=\"follows\", order=\"eid\")\n    assert F.array_equal(u1, u2)\n    assert F.array_equal(v1, v2)\n\n    fg = g[\"developer\", :, \"game\"]\n    assert fg.idtype == g.idtype\n    assert fg.device == g.device\n    assert fg.ntypes == [\"developer\", \"game\"]\n    assert fg.etypes == [\"develops\"]\n    u1, v1 = g.edges(etype=\"develops\", order=\"eid\")\n    u2, v2 = fg.edges(etype=\"develops\", order=\"eid\")\n    assert F.array_equal(u1, u2)\n    assert F.array_equal(v1, v2)\n\n    fg = g[:, :, :]\n    assert fg.idtype == g.idtype\n    assert fg.device == g.device\n    assert fg.ntypes == [\"developer+user\", \"game+user\"]\n    assert fg.etypes == [\"develops+follows+plays+wishes\"]\n    check_mapping(g, fg)\n\n    # Test another heterograph\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1, 2], [1, 2, 3]),\n            (\"user\", \"knows\", \"user\"): ([0, 2], [2, 3]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.randn((4, 3))\n    g.edges[\"follows\"].data[\"w\"] = F.randn((3, 2))\n    g.nodes[\"user\"].data[\"hh\"] = F.randn((4, 5))\n    g.edges[\"knows\"].data[\"ww\"] = F.randn((2, 10))\n\n    fg = g[\"user\", :, \"user\"]\n    assert fg.idtype == g.idtype\n    assert fg.device == g.device\n    assert fg.ntypes == [\"user\"]\n    assert fg.etypes == [\"follows+knows\"]\n    check_mapping(g, fg)\n\n    fg = g[\"user\", :, :]\n    assert fg.idtype == g.idtype\n    assert fg.device == g.device\n    assert fg.ntypes == [\"user\"]\n    assert fg.etypes == [\"follows+knows\"]\n    check_mapping(g, fg)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"Need gpu for this test\"\n)\n@parametrize_idtype\ndef test_to_device(idtype):\n    # TODO: rewrite 
this test case to accept different graphs so we\n    #  can test reverse graph and batched graph\n    g = create_test_heterograph(idtype)\n    g.nodes[\"user\"].data[\"h\"] = F.ones((3, 5))\n    g.nodes[\"game\"].data[\"i\"] = F.ones((2, 5))\n    g.edges[\"plays\"].data[\"e\"] = F.ones((4, 4))\n    assert g.device == F.ctx()\n    g = g.to(F.cpu())\n    assert g.device == F.cpu()\n    assert F.context(g.nodes[\"user\"].data[\"h\"]) == F.cpu()\n    assert F.context(g.nodes[\"game\"].data[\"i\"]) == F.cpu()\n    assert F.context(g.edges[\"plays\"].data[\"e\"]) == F.cpu()\n    for ntype in g.ntypes:\n        assert F.context(g.batch_num_nodes(ntype)) == F.cpu()\n    for etype in g.canonical_etypes:\n        assert F.context(g.batch_num_edges(etype)) == F.cpu()\n\n    if F.is_cuda_available():\n        g1 = g.to(F.cuda())\n        assert g1.device == F.cuda()\n        assert F.context(g1.nodes[\"user\"].data[\"h\"]) == F.cuda()\n        assert F.context(g1.nodes[\"game\"].data[\"i\"]) == F.cuda()\n        assert F.context(g1.edges[\"plays\"].data[\"e\"]) == F.cuda()\n        for ntype in g1.ntypes:\n            assert F.context(g1.batch_num_nodes(ntype)) == F.cuda()\n        for etype in g1.canonical_etypes:\n            assert F.context(g1.batch_num_edges(etype)) == F.cuda()\n        assert F.context(g.nodes[\"user\"].data[\"h\"]) == F.cpu()\n        assert F.context(g.nodes[\"game\"].data[\"i\"]) == F.cpu()\n        assert F.context(g.edges[\"plays\"].data[\"e\"]) == F.cpu()\n        for ntype in g.ntypes:\n            assert F.context(g.batch_num_nodes(ntype)) == F.cpu()\n        for etype in g.canonical_etypes:\n            assert F.context(g.batch_num_edges(etype)) == F.cpu()\n        with pytest.raises(DGLError):\n            g1.nodes[\"user\"].data[\"h\"] = F.copy_to(F.ones((3, 5)), F.cpu())\n        with pytest.raises(DGLError):\n            g1.edges[\"plays\"].data[\"e\"] = F.copy_to(F.ones((4, 4)), F.cpu())\n\n\n@unittest.skipIf(\n    F._default_context_str == 
\"cpu\", reason=\"Need gpu for this test\"\n)\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"block\"]))\ndef test_to_device2(g, idtype):\n    g = g.astype(idtype)\n    g = g.to(F.cpu())\n    assert g.device == F.cpu()\n    if F.is_cuda_available():\n        g1 = g.to(F.cuda())\n        assert g1.device == F.cuda()\n        assert g1.ntypes == g.ntypes\n        assert g1.etypes == g.etypes\n        assert g1.canonical_etypes == g.canonical_etypes\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"Need gpu for this test\"\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\",\n    reason=\"Pinning graph inplace only supported for PyTorch\",\n)\n@parametrize_idtype\ndef test_pin_memory_(idtype):\n    # TODO: rewrite this test case to accept different graphs so we\n    #  can test reverse graph and batched graph\n    g = create_test_heterograph(idtype)\n    g.nodes[\"user\"].data[\"h\"] = F.ones((3, 5))\n    g.nodes[\"game\"].data[\"i\"] = F.ones((2, 5))\n    g.edges[\"plays\"].data[\"e\"] = F.ones((4, 4))\n    g = g.to(F.cpu())\n    assert not g.is_pinned()\n\n    # unpin an unpinned CPU graph, directly return\n    g.unpin_memory_()\n    assert not g.is_pinned()\n    assert g.device == F.cpu()\n\n    # pin a CPU graph\n    g.pin_memory_()\n    assert g.is_pinned()\n    assert g.device == F.cpu()\n    assert g.nodes[\"user\"].data[\"h\"].is_pinned()\n    assert g.nodes[\"game\"].data[\"i\"].is_pinned()\n    assert g.edges[\"plays\"].data[\"e\"].is_pinned()\n    assert F.context(g.nodes[\"user\"].data[\"h\"]) == F.cpu()\n    assert F.context(g.nodes[\"game\"].data[\"i\"]) == F.cpu()\n    assert F.context(g.edges[\"plays\"].data[\"e\"]) == F.cpu()\n    for ntype in g.ntypes:\n        assert F.context(g.batch_num_nodes(ntype)) == F.cpu()\n    for etype in g.canonical_etypes:\n        assert F.context(g.batch_num_edges(etype)) == F.cpu()\n\n    # it's fine to clone with new formats, but new graphs are not pinned\n    # 
>>> g.formats()\n    # {'created': ['coo'], 'not created': ['csr', 'csc']}\n    assert not g.formats(\"csc\").is_pinned()\n    assert not g.formats(\"csr\").is_pinned()\n    # 'coo' formats is already created and thus not cloned\n    assert g.formats(\"coo\").is_pinned()\n\n    # pin a pinned graph, directly return\n    g.pin_memory_()\n    assert g.is_pinned()\n    assert g.device == F.cpu()\n\n    # unpin a pinned graph\n    g.unpin_memory_()\n    assert not g.is_pinned()\n    assert g.device == F.cpu()\n\n    g1 = g.to(F.cuda())\n\n    # unpin an unpinned GPU graph, directly return\n    g1.unpin_memory_()\n    assert not g1.is_pinned()\n    assert g1.device == F.cuda()\n\n    # error pinning a GPU graph\n    with pytest.raises(DGLError):\n        g1.pin_memory_()\n\n    # test pin empty homograph\n    g2 = dgl.graph(([], []))\n    assert not g2.is_pinned()\n    g2.pin_memory_()\n    assert g2.is_pinned()\n    g2.unpin_memory_()\n    assert not g2.is_pinned()\n\n    # test pin heterograph with 0 edge of one relation type\n    g3 = dgl.heterograph(\n        {(\"a\", \"b\", \"c\"): ([0, 1], [1, 2]), (\"c\", \"d\", \"c\"): ([], [])}\n    ).astype(idtype)\n    g3.pin_memory_()\n    assert g3.is_pinned()\n    g3.unpin_memory_()\n    assert not g3.is_pinned()\n\n\n@parametrize_idtype\ndef test_convert_bound(idtype):\n    def _test_bipartite_bound(data, card):\n        with pytest.raises(DGLError):\n            dgl.heterograph(\n                {(\"_U\", \"_E\", \"_V\"): data},\n                {\"_U\": card[0], \"_V\": card[1]},\n                idtype=idtype,\n                device=F.ctx(),\n            )\n\n    def _test_graph_bound(data, card):\n        with pytest.raises(DGLError):\n            dgl.graph(data, num_nodes=card, idtype=idtype, device=F.ctx())\n\n    _test_bipartite_bound(([1, 2], [1, 2]), (2, 3))\n    _test_bipartite_bound(([0, 1], [1, 4]), (2, 3))\n    _test_graph_bound(([1, 3], [1, 2]), 3)\n    _test_graph_bound(([0, 1], [1, 3]), 
3)\n\n\n@parametrize_idtype\ndef test_convert(idtype):\n    hg = create_test_heterograph(idtype)\n    hs = []\n    for ntype in hg.ntypes:\n        h = F.randn((hg.num_nodes(ntype), 5))\n        hg.nodes[ntype].data[\"h\"] = h\n        hs.append(h)\n    hg.nodes[\"user\"].data[\"x\"] = F.randn((3, 3))\n    ws = []\n    for etype in hg.canonical_etypes:\n        w = F.randn((hg.num_edges(etype), 5))\n        hg.edges[etype].data[\"w\"] = w\n        ws.append(w)\n    hg.edges[\"plays\"].data[\"x\"] = F.randn((4, 3))\n\n    g = dgl.to_homogeneous(hg, ndata=[\"h\"], edata=[\"w\"])\n    assert g.idtype == idtype\n    assert g.device == hg.device\n    assert F.array_equal(F.cat(hs, dim=0), g.ndata[\"h\"])\n    assert \"x\" not in g.ndata\n    assert F.array_equal(F.cat(ws, dim=0), g.edata[\"w\"])\n    assert \"x\" not in g.edata\n\n    src, dst = g.all_edges(order=\"eid\")\n    src = F.asnumpy(src)\n    dst = F.asnumpy(dst)\n    etype_id, eid = F.asnumpy(g.edata[dgl.ETYPE]), F.asnumpy(g.edata[dgl.EID])\n    ntype_id, nid = F.asnumpy(g.ndata[dgl.NTYPE]), F.asnumpy(g.ndata[dgl.NID])\n    for i in range(g.num_edges()):\n        srctype = hg.ntypes[ntype_id[src[i]]]\n        dsttype = hg.ntypes[ntype_id[dst[i]]]\n        etype = hg.etypes[etype_id[i]]\n        src_i, dst_i = hg.find_edges([eid[i]], (srctype, etype, dsttype))\n        assert F.asnumpy(src_i).item() == nid[src[i]]\n        assert F.asnumpy(dst_i).item() == nid[dst[i]]\n\n    mg = nx.MultiDiGraph(\n        [\n            (\"user\", \"user\", \"follows\"),\n            (\"user\", \"game\", \"plays\"),\n            (\"user\", \"game\", \"wishes\"),\n            (\"developer\", \"game\", \"develops\"),\n        ]\n    )\n\n    for _mg in [None, mg]:\n        hg2 = dgl.to_heterogeneous(\n            g,\n            hg.ntypes,\n            hg.etypes,\n            ntype_field=dgl.NTYPE,\n            etype_field=dgl.ETYPE,\n            metagraph=_mg,\n        )\n        assert hg2.idtype == hg.idtype\n        assert 
hg2.device == hg.device\n        assert set(hg.ntypes) == set(hg2.ntypes)\n        assert set(hg.canonical_etypes) == set(hg2.canonical_etypes)\n        for ntype in hg.ntypes:\n            assert hg.num_nodes(ntype) == hg2.num_nodes(ntype)\n            assert F.array_equal(\n                hg.nodes[ntype].data[\"h\"], hg2.nodes[ntype].data[\"h\"]\n            )\n        for canonical_etype in hg.canonical_etypes:\n            src, dst = hg.all_edges(etype=canonical_etype, order=\"eid\")\n            src2, dst2 = hg2.all_edges(etype=canonical_etype, order=\"eid\")\n            assert F.array_equal(src, src2)\n            assert F.array_equal(dst, dst2)\n            assert F.array_equal(\n                hg.edges[canonical_etype].data[\"w\"],\n                hg2.edges[canonical_etype].data[\"w\"],\n            )\n\n    # hetero_from_homo test case 2\n    g = dgl.graph(([0, 1, 2, 0], [2, 2, 3, 3]), idtype=idtype, device=F.ctx())\n    g.ndata[dgl.NTYPE] = F.tensor([0, 0, 1, 2])\n    g.edata[dgl.ETYPE] = F.tensor([0, 0, 1, 2])\n    hg = dgl.to_heterogeneous(g, [\"l0\", \"l1\", \"l2\"], [\"e0\", \"e1\", \"e2\"])\n    assert hg.idtype == idtype\n    assert hg.device == g.device\n    assert set(hg.canonical_etypes) == set(\n        [(\"l0\", \"e0\", \"l1\"), (\"l1\", \"e1\", \"l2\"), (\"l0\", \"e2\", \"l2\")]\n    )\n    assert hg.num_nodes(\"l0\") == 2\n    assert hg.num_nodes(\"l1\") == 1\n    assert hg.num_nodes(\"l2\") == 1\n    assert hg.num_edges(\"e0\") == 2\n    assert hg.num_edges(\"e1\") == 1\n    assert hg.num_edges(\"e2\") == 1\n    assert F.array_equal(hg.ndata[dgl.NID][\"l0\"], F.tensor([0, 1], F.int64))\n    assert F.array_equal(hg.ndata[dgl.NID][\"l1\"], F.tensor([2], F.int64))\n    assert F.array_equal(hg.ndata[dgl.NID][\"l2\"], F.tensor([3], F.int64))\n    assert F.array_equal(\n        hg.edata[dgl.EID][(\"l0\", \"e0\", \"l1\")], F.tensor([0, 1], F.int64)\n    )\n    assert F.array_equal(\n        hg.edata[dgl.EID][(\"l0\", \"e2\", \"l2\")], 
F.tensor([3], F.int64)\n    )\n    assert F.array_equal(\n        hg.edata[dgl.EID][(\"l1\", \"e1\", \"l2\")], F.tensor([2], F.int64)\n    )\n\n    # hetero_from_homo test case 3\n    mg = nx.MultiDiGraph(\n        [(\"user\", \"movie\", \"watches\"), (\"user\", \"TV\", \"watches\")]\n    )\n    g = dgl.graph(((0, 0), (1, 2)), idtype=idtype, device=F.ctx())\n    g.ndata[dgl.NTYPE] = F.tensor([0, 1, 2])\n    g.edata[dgl.ETYPE] = F.tensor([0, 0])\n    for _mg in [None, mg]:\n        hg = dgl.to_heterogeneous(\n            g, [\"user\", \"TV\", \"movie\"], [\"watches\"], metagraph=_mg\n        )\n        assert hg.idtype == g.idtype\n        assert hg.device == g.device\n        assert set(hg.canonical_etypes) == set(\n            [(\"user\", \"watches\", \"movie\"), (\"user\", \"watches\", \"TV\")]\n        )\n        assert hg.num_nodes(\"user\") == 1\n        assert hg.num_nodes(\"TV\") == 1\n        assert hg.num_nodes(\"movie\") == 1\n        assert hg.num_edges((\"user\", \"watches\", \"TV\")) == 1\n        assert hg.num_edges((\"user\", \"watches\", \"movie\")) == 1\n        assert len(hg.etypes) == 2\n\n    # hetero_to_homo test case 2\n    hg = dgl.heterograph(\n        {(\"_U\", \"_E\", \"_V\"): ([0, 1], [0, 1])},\n        {\"_U\": 2, \"_V\": 3},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g = dgl.to_homogeneous(hg)\n    assert hg.idtype == g.idtype\n    assert hg.device == g.device\n    assert g.num_nodes() == 5\n\n    # hetero_to_subgraph_to_homo\n    hg = dgl.heterograph(\n        {\n            (\"user\", \"plays\", \"game\"): ([0, 1, 1, 2], [0, 0, 2, 1]),\n            (\"user\", \"follows\", \"user\"): ([0, 1, 1], [1, 2, 2]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    hg.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([[1, 0], [0, 1], [1, 1]], dtype=idtype), ctx=F.ctx()\n    )\n    sg = dgl.node_subgraph(hg, {\"user\": [1, 2]})\n    assert len(sg.ntypes) == 2\n    assert len(sg.etypes) == 2\n    
assert sg.num_nodes(\"user\") == 2\n    assert sg.num_nodes(\"game\") == 0\n    g = dgl.to_homogeneous(sg, ndata=[\"h\"])\n    assert \"h\" in g.ndata.keys()\n    assert g.num_nodes() == 2\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"Test on cpu is enough\"\n)\n@parametrize_idtype\ndef test_to_homo_zero_nodes(idtype):\n    # Fix GitHub issue #2870\n    g = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): (\n                np.random.randint(0, 200, (1000,)),\n                np.random.randint(0, 200, (1000,)),\n            ),\n            (\"B\", \"BA\", \"A\"): (\n                np.random.randint(0, 200, (1000,)),\n                np.random.randint(0, 200, (1000,)),\n            ),\n        },\n        num_nodes_dict={\"A\": 200, \"B\": 200, \"C\": 0},\n        idtype=idtype,\n    )\n    g.nodes[\"A\"].data[\"x\"] = F.randn((200, 3))\n    g.nodes[\"B\"].data[\"x\"] = F.randn((200, 3))\n    gg = dgl.to_homogeneous(g, [\"x\"])\n    assert \"x\" in gg.ndata\n\n\n@parametrize_idtype\ndef test_to_homo2(idtype):\n    # test the result homogeneous graph has nodes and edges sorted by their types\n    hg = create_test_heterograph(idtype)\n    g = dgl.to_homogeneous(hg)\n    ntypes = F.asnumpy(g.ndata[dgl.NTYPE])\n    etypes = F.asnumpy(g.edata[dgl.ETYPE])\n    p = 0\n    for tid, ntype in enumerate(hg.ntypes):\n        num_nodes = hg.num_nodes(ntype)\n        for i in range(p, p + num_nodes):\n            assert ntypes[i] == tid\n        p += num_nodes\n    p = 0\n    for tid, etype in enumerate(hg.canonical_etypes):\n        num_edges = hg.num_edges(etype)\n        for i in range(p, p + num_edges):\n            assert etypes[i] == tid\n        p += num_edges\n    # test store_type=False\n    g = dgl.to_homogeneous(hg, store_type=False)\n    assert dgl.NTYPE not in g.ndata\n    assert dgl.ETYPE not in g.edata\n    # test return_count=True\n    g, ntype_count, etype_count = dgl.to_homogeneous(hg, return_count=True)\n    for i, 
count in enumerate(ntype_count):\n        assert count == hg.num_nodes(hg.ntypes[i])\n    for i, count in enumerate(etype_count):\n        assert count == hg.num_edges(hg.canonical_etypes[i])\n\n\n@parametrize_idtype\ndef test_invertible_conversion(idtype):\n    # Test whether to_homogeneous and to_heterogeneous are invertible\n    hg = create_test_heterograph(idtype)\n    g = dgl.to_homogeneous(hg)\n    hg2 = dgl.to_heterogeneous(g, hg.ntypes, hg.etypes)\n    assert_is_identical_hetero(hg, hg2, True)\n\n\n@parametrize_idtype\ndef test_metagraph_reachable(idtype):\n    g = create_test_heterograph(idtype)\n    x = F.randn((3, 5))\n    g.nodes[\"user\"].data[\"h\"] = x\n\n    new_g = dgl.metapath_reachable_graph(g, [\"follows\", \"plays\"])\n    assert new_g.idtype == idtype\n    assert new_g.ntypes == [\"game\", \"user\"]\n    assert new_g.num_edges() == 3\n    assert F.asnumpy(new_g.has_edges_between([0, 0, 1], [0, 1, 1])).all()\n\n    new_g = dgl.metapath_reachable_graph(g, [\"follows\"])\n    assert new_g.idtype == idtype\n    assert new_g.ntypes == [\"user\"]\n    assert new_g.num_edges() == 2\n    assert F.asnumpy(new_g.has_edges_between([0, 1], [1, 2])).all()\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\",\n    reason=\"MXNet doesn't support bool tensor\",\n)\n@parametrize_idtype\ndef test_subgraph_mask(idtype):\n    g = create_test_heterograph(idtype)\n    g_graph = g[\"follows\"]\n    g_bipartite = g[\"plays\"]\n\n    x = F.randn((3, 5))\n    y = F.randn((2, 4))\n    g.nodes[\"user\"].data[\"h\"] = x\n    g.edges[\"follows\"].data[\"h\"] = y\n\n    def _check_subgraph(g, sg):\n        assert sg.idtype == g.idtype\n        assert sg.device == g.device\n        assert sg.ntypes == g.ntypes\n        assert sg.etypes == g.etypes\n        assert sg.canonical_etypes == g.canonical_etypes\n        assert F.array_equal(\n            F.tensor(sg.nodes[\"user\"].data[dgl.NID]), F.tensor([1, 2], idtype)\n        )\n        assert F.array_equal(\n     
       F.tensor(sg.nodes[\"game\"].data[dgl.NID]), F.tensor([0], idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"follows\"].data[dgl.EID]), F.tensor([1], idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"plays\"].data[dgl.EID]), F.tensor([1], idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"wishes\"].data[dgl.EID]), F.tensor([1], idtype)\n        )\n        assert sg.num_nodes(\"developer\") == 0\n        assert sg.num_edges(\"develops\") == 0\n        assert F.array_equal(\n            sg.nodes[\"user\"].data[\"h\"], g.nodes[\"user\"].data[\"h\"][1:3]\n        )\n        assert F.array_equal(\n            sg.edges[\"follows\"].data[\"h\"], g.edges[\"follows\"].data[\"h\"][1:2]\n        )\n\n    sg1 = g.subgraph(\n        {\n            \"user\": F.tensor([False, True, True], dtype=F.bool),\n            \"game\": F.tensor([True, False, False, False], dtype=F.bool),\n        }\n    )\n    _check_subgraph(g, sg1)\n    if F._default_context_str != \"gpu\":\n        # TODO(minjie): enable this later\n        sg2 = g.edge_subgraph(\n            {\n                \"follows\": F.tensor([False, True], dtype=F.bool),\n                \"plays\": F.tensor([False, True, False, False], dtype=F.bool),\n                \"wishes\": F.tensor([False, True], dtype=F.bool),\n            }\n        )\n        _check_subgraph(g, sg2)\n\n\n@parametrize_idtype\ndef test_subgraph(idtype):\n    g = create_test_heterograph(idtype)\n    g_graph = g[\"follows\"]\n    g_bipartite = g[\"plays\"]\n\n    x = F.randn((3, 5))\n    y = F.randn((2, 4))\n    g.nodes[\"user\"].data[\"h\"] = x\n    g.edges[\"follows\"].data[\"h\"] = y\n\n    def _check_subgraph(g, sg):\n        assert sg.idtype == g.idtype\n        assert sg.device == g.device\n        assert sg.ntypes == g.ntypes\n        assert sg.etypes == g.etypes\n        assert sg.canonical_etypes == g.canonical_etypes\n        assert 
F.array_equal(\n            F.tensor(sg.nodes[\"user\"].data[dgl.NID]), F.tensor([1, 2], g.idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.nodes[\"game\"].data[dgl.NID]), F.tensor([0], g.idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"follows\"].data[dgl.EID]), F.tensor([1], g.idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"plays\"].data[dgl.EID]), F.tensor([1], g.idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"wishes\"].data[dgl.EID]), F.tensor([1], g.idtype)\n        )\n        assert sg.num_nodes(\"developer\") == 0\n        assert sg.num_edges(\"develops\") == 0\n        assert F.array_equal(\n            sg.nodes[\"user\"].data[\"h\"], g.nodes[\"user\"].data[\"h\"][1:3]\n        )\n        assert F.array_equal(\n            sg.edges[\"follows\"].data[\"h\"], g.edges[\"follows\"].data[\"h\"][1:2]\n        )\n\n    sg1 = g.subgraph({\"user\": [1, 2], \"game\": [0]})\n    _check_subgraph(g, sg1)\n    if F._default_context_str != \"gpu\":\n        # TODO(minjie): enable this later\n        sg2 = g.edge_subgraph({\"follows\": [1], \"plays\": [1], \"wishes\": [1]})\n        _check_subgraph(g, sg2)\n\n    # backend tensor input\n    sg1 = g.subgraph(\n        {\n            \"user\": F.tensor([1, 2], dtype=idtype),\n            \"game\": F.tensor([0], dtype=idtype),\n        }\n    )\n    _check_subgraph(g, sg1)\n    if F._default_context_str != \"gpu\":\n        # TODO(minjie): enable this later\n        sg2 = g.edge_subgraph(\n            {\n                \"follows\": F.tensor([1], dtype=idtype),\n                \"plays\": F.tensor([1], dtype=idtype),\n                \"wishes\": F.tensor([1], dtype=idtype),\n            }\n        )\n        _check_subgraph(g, sg2)\n\n    # numpy input\n    sg1 = g.subgraph({\"user\": np.array([1, 2]), \"game\": np.array([0])})\n    _check_subgraph(g, sg1)\n    if F._default_context_str != 
\"gpu\":\n        # TODO(minjie): enable this later\n        sg2 = g.edge_subgraph(\n            {\n                \"follows\": np.array([1]),\n                \"plays\": np.array([1]),\n                \"wishes\": np.array([1]),\n            }\n        )\n        _check_subgraph(g, sg2)\n\n    def _check_subgraph_single_ntype(g, sg, preserve_nodes=False):\n        assert sg.idtype == g.idtype\n        assert sg.device == g.device\n        assert sg.ntypes == g.ntypes\n        assert sg.etypes == g.etypes\n        assert sg.canonical_etypes == g.canonical_etypes\n\n        if not preserve_nodes:\n            assert F.array_equal(\n                F.tensor(sg.nodes[\"user\"].data[dgl.NID]),\n                F.tensor([1, 2], g.idtype),\n            )\n        else:\n            for ntype in sg.ntypes:\n                assert g.num_nodes(ntype) == sg.num_nodes(ntype)\n\n        assert F.array_equal(\n            F.tensor(sg.edges[\"follows\"].data[dgl.EID]), F.tensor([1], g.idtype)\n        )\n\n        if not preserve_nodes:\n            assert F.array_equal(\n                sg.nodes[\"user\"].data[\"h\"], g.nodes[\"user\"].data[\"h\"][1:3]\n            )\n        assert F.array_equal(\n            sg.edges[\"follows\"].data[\"h\"], g.edges[\"follows\"].data[\"h\"][1:2]\n        )\n\n    def _check_subgraph_single_etype(g, sg, preserve_nodes=False):\n        assert sg.ntypes == g.ntypes\n        assert sg.etypes == g.etypes\n        assert sg.canonical_etypes == g.canonical_etypes\n\n        if not preserve_nodes:\n            assert F.array_equal(\n                F.tensor(sg.nodes[\"user\"].data[dgl.NID]),\n                F.tensor([0, 1], g.idtype),\n            )\n            assert F.array_equal(\n                F.tensor(sg.nodes[\"game\"].data[dgl.NID]),\n                F.tensor([0], g.idtype),\n            )\n        else:\n            for ntype in sg.ntypes:\n                assert g.num_nodes(ntype) == sg.num_nodes(ntype)\n\n        assert 
F.array_equal(\n            F.tensor(sg.edges[\"plays\"].data[dgl.EID]),\n            F.tensor([0, 1], g.idtype),\n        )\n\n    sg1_graph = g_graph.subgraph([1, 2])\n    _check_subgraph_single_ntype(g_graph, sg1_graph)\n    if F._default_context_str != \"gpu\":\n        # TODO(minjie): enable this later\n        sg1_graph = g_graph.edge_subgraph([1])\n        _check_subgraph_single_ntype(g_graph, sg1_graph)\n        sg1_graph = g_graph.edge_subgraph([1], relabel_nodes=False)\n        _check_subgraph_single_ntype(g_graph, sg1_graph, True)\n        sg2_bipartite = g_bipartite.edge_subgraph([0, 1])\n        _check_subgraph_single_etype(g_bipartite, sg2_bipartite)\n        sg2_bipartite = g_bipartite.edge_subgraph([0, 1], relabel_nodes=False)\n        _check_subgraph_single_etype(g_bipartite, sg2_bipartite, True)\n\n    def _check_typed_subgraph1(g, sg):\n        assert g.idtype == sg.idtype\n        assert g.device == sg.device\n        assert set(sg.ntypes) == {\"user\", \"game\"}\n        assert set(sg.etypes) == {\"follows\", \"plays\", \"wishes\"}\n        for ntype in sg.ntypes:\n            assert sg.num_nodes(ntype) == g.num_nodes(ntype)\n        for etype in sg.etypes:\n            src_sg, dst_sg = sg.all_edges(etype=etype, order=\"eid\")\n            src_g, dst_g = g.all_edges(etype=etype, order=\"eid\")\n            assert F.array_equal(src_sg, src_g)\n            assert F.array_equal(dst_sg, dst_g)\n        assert F.array_equal(\n            sg.nodes[\"user\"].data[\"h\"], g.nodes[\"user\"].data[\"h\"]\n        )\n        assert F.array_equal(\n            sg.edges[\"follows\"].data[\"h\"], g.edges[\"follows\"].data[\"h\"]\n        )\n        g.nodes[\"user\"].data[\"h\"] = F.scatter_row(\n            g.nodes[\"user\"].data[\"h\"], F.tensor([2]), F.randn((1, 5))\n        )\n        g.edges[\"follows\"].data[\"h\"] = F.scatter_row(\n            g.edges[\"follows\"].data[\"h\"], F.tensor([1]), F.randn((1, 4))\n        )\n        assert F.array_equal(\n    
        sg.nodes[\"user\"].data[\"h\"], g.nodes[\"user\"].data[\"h\"]\n        )\n        assert F.array_equal(\n            sg.edges[\"follows\"].data[\"h\"], g.edges[\"follows\"].data[\"h\"]\n        )\n\n    def _check_typed_subgraph2(g, sg):\n        assert set(sg.ntypes) == {\"developer\", \"game\"}\n        assert set(sg.etypes) == {\"develops\"}\n        for ntype in sg.ntypes:\n            assert sg.num_nodes(ntype) == g.num_nodes(ntype)\n        for etype in sg.etypes:\n            src_sg, dst_sg = sg.all_edges(etype=etype, order=\"eid\")\n            src_g, dst_g = g.all_edges(etype=etype, order=\"eid\")\n            assert F.array_equal(src_sg, src_g)\n            assert F.array_equal(dst_sg, dst_g)\n\n    sg3 = g.node_type_subgraph([\"user\", \"game\"])\n    _check_typed_subgraph1(g, sg3)\n    sg4 = g.edge_type_subgraph([\"develops\"])\n    _check_typed_subgraph2(g, sg4)\n    sg5 = g.edge_type_subgraph([\"follows\", \"plays\", \"wishes\"])\n    _check_typed_subgraph1(g, sg5)\n\n\n@parametrize_idtype\ndef test_apply(idtype):\n    def node_udf(nodes):\n        return {\"h\": nodes.data[\"h\"] * 2}\n\n    def node_udf2(nodes):\n        return {\"h\": F.sum(nodes.data[\"h\"], dim=1, keepdims=True)}\n\n    def edge_udf(edges):\n        return {\"h\": edges.data[\"h\"] * 2 + edges.src[\"h\"]}\n\n    g = create_test_heterograph(idtype)\n    g.nodes[\"user\"].data[\"h\"] = F.ones((3, 5))\n    g.apply_nodes(node_udf, ntype=\"user\")\n    assert F.array_equal(g.nodes[\"user\"].data[\"h\"], F.ones((3, 5)) * 2)\n\n    g[\"plays\"].edata[\"h\"] = F.ones((4, 5))\n    g.apply_edges(edge_udf, etype=(\"user\", \"plays\", \"game\"))\n    assert F.array_equal(g[\"plays\"].edata[\"h\"], F.ones((4, 5)) * 4)\n\n    # test apply on graph with only one type\n    g[\"follows\"].apply_nodes(node_udf)\n    assert F.array_equal(g.nodes[\"user\"].data[\"h\"], F.ones((3, 5)) * 4)\n\n    g[\"plays\"].apply_edges(edge_udf)\n    assert F.array_equal(g[\"plays\"].edata[\"h\"], 
F.ones((4, 5)) * 12)\n\n    # Test the case that feature size changes\n    g.nodes[\"user\"].data[\"h\"] = F.ones((3, 5))\n    g.apply_nodes(node_udf2, ntype=\"user\")\n    assert F.array_equal(g.nodes[\"user\"].data[\"h\"], F.ones((3, 1)) * 5)\n\n    # test fail case\n    # fail due to multiple types\n    with pytest.raises(DGLError):\n        g.apply_nodes(node_udf)\n\n    with pytest.raises(DGLError):\n        g.apply_edges(edge_udf)\n\n\n@parametrize_idtype\ndef test_level2(idtype):\n    # edges = {\n    #    'follows': ([0, 1], [1, 2]),\n    #    'plays': ([0, 1, 2, 1], [0, 0, 1, 1]),\n    #    'wishes': ([0, 2], [1, 0]),\n    #    'develops': ([0, 1], [0, 1]),\n    # }\n    g = create_test_heterograph(idtype)\n\n    def rfunc(nodes):\n        return {\"y\": F.sum(nodes.mailbox[\"m\"], 1)}\n\n    def rfunc2(nodes):\n        return {\"y\": F.max(nodes.mailbox[\"m\"], 1)}\n\n    def mfunc(edges):\n        return {\"m\": edges.src[\"h\"]}\n\n    def afunc(nodes):\n        return {\"y\": nodes.data[\"y\"] + 1}\n\n    #############################################################\n    #  send_and_recv\n    #############################################################\n\n    g.nodes[\"user\"].data[\"h\"] = F.ones((3, 2))\n    g.send_and_recv([2, 3], mfunc, rfunc, etype=\"plays\")\n    y = g.nodes[\"game\"].data[\"y\"]\n    assert F.array_equal(y, F.tensor([[0.0, 0.0], [2.0, 2.0]]))\n\n    # only one type\n    g[\"plays\"].send_and_recv([2, 3], mfunc, rfunc)\n    y = g.nodes[\"game\"].data[\"y\"]\n    assert F.array_equal(y, F.tensor([[0.0, 0.0], [2.0, 2.0]]))\n\n    # test fail case\n    # fail due to multiple types\n    with pytest.raises(DGLError):\n        g.send_and_recv([2, 3], mfunc, rfunc)\n\n    g.nodes[\"game\"].data.clear()\n\n    #############################################################\n    #  pull\n    #############################################################\n\n    g.nodes[\"user\"].data[\"h\"] = F.ones((3, 2))\n    g.pull(1, mfunc, rfunc, 
etype=\"plays\")\n    y = g.nodes[\"game\"].data[\"y\"]\n    assert F.array_equal(y, F.tensor([[0.0, 0.0], [2.0, 2.0]]))\n\n    # only one type\n    g[\"plays\"].pull(1, mfunc, rfunc)\n    y = g.nodes[\"game\"].data[\"y\"]\n    assert F.array_equal(y, F.tensor([[0.0, 0.0], [2.0, 2.0]]))\n\n    # test fail case\n    with pytest.raises(DGLError):\n        g.pull(1, mfunc, rfunc)\n\n    g.nodes[\"game\"].data.clear()\n\n    #############################################################\n    #  update_all\n    #############################################################\n\n    g.nodes[\"user\"].data[\"h\"] = F.ones((3, 2))\n    g.update_all(mfunc, rfunc, etype=\"plays\")\n    y = g.nodes[\"game\"].data[\"y\"]\n    assert F.array_equal(y, F.tensor([[2.0, 2.0], [2.0, 2.0]]))\n\n    # only one type\n    g[\"plays\"].update_all(mfunc, rfunc)\n    y = g.nodes[\"game\"].data[\"y\"]\n    assert F.array_equal(y, F.tensor([[2.0, 2.0], [2.0, 2.0]]))\n\n    # test fail case\n    # fail due to multiple types\n    with pytest.raises(DGLError):\n        g.update_all(mfunc, rfunc)\n\n    # test multi\n    g.multi_update_all(\n        {\"plays\": (mfunc, rfunc), (\"user\", \"wishes\", \"game\"): (mfunc, rfunc2)},\n        \"sum\",\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"y\"], F.tensor([[3.0, 3.0], [3.0, 3.0]])\n    )\n\n    # test multi\n    g.multi_update_all(\n        {\n            \"plays\": (mfunc, rfunc, afunc),\n            (\"user\", \"wishes\", \"game\"): (mfunc, rfunc2),\n        },\n        \"sum\",\n        afunc,\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"y\"], F.tensor([[5.0, 5.0], [5.0, 5.0]])\n    )\n\n    # test cross reducer\n    g.nodes[\"user\"].data[\"h\"] = F.randn((3, 2))\n    for cred in [\"sum\", \"max\", \"min\", \"mean\", \"stack\"]:\n        g.multi_update_all(\n            {\"plays\": (mfunc, rfunc, afunc), \"wishes\": (mfunc, rfunc2)},\n            cred,\n            afunc,\n        )\n        y = 
g.nodes[\"game\"].data[\"y\"]\n        g[\"plays\"].update_all(mfunc, rfunc, afunc)\n        y1 = g.nodes[\"game\"].data[\"y\"]\n        g[\"wishes\"].update_all(mfunc, rfunc2)\n        y2 = g.nodes[\"game\"].data[\"y\"]\n        if cred == \"stack\":\n            # stack has an internal order by edge type id\n            yy = F.stack([y1, y2], 1)\n            yy = yy + 1  # final afunc\n            assert F.array_equal(y, yy)\n        else:\n            yy = get_redfn(cred)(F.stack([y1, y2], 0), 0)\n            yy = yy + 1  # final afunc\n            assert F.array_equal(y, yy)\n\n    # test fail case\n    # fail because cannot infer ntype\n    with pytest.raises(DGLError):\n        g.update_all(\n            {\"plays\": (mfunc, rfunc), \"follows\": (mfunc, rfunc2)}, \"sum\"\n        )\n\n    g.nodes[\"game\"].data.clear()\n\n\n@parametrize_idtype\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"Need gpu for this test\"\n)\ndef test_more_nnz(idtype):\n    g = dgl.graph(\n        ([0, 0, 0, 0, 0], [1, 1, 1, 1, 1]), idtype=idtype, device=F.ctx()\n    )\n    g.ndata[\"x\"] = F.copy_to(F.ones((2, 5)), ctx=F.ctx())\n    g.update_all(fn.copy_u(\"x\", \"m\"), fn.sum(\"m\", \"y\"))\n    y = g.ndata[\"y\"]\n    ans = np.zeros((2, 5))\n    ans[1] = 5\n    ans = F.copy_to(F.tensor(ans, dtype=F.dtype(y)), ctx=F.ctx())\n    assert F.array_equal(y, ans)\n\n\n@parametrize_idtype\ndef test_updates(idtype):\n    def msg_func(edges):\n        return {\"m\": edges.src[\"h\"]}\n\n    def reduce_func(nodes):\n        return {\"y\": F.sum(nodes.mailbox[\"m\"], 1)}\n\n    def apply_func(nodes):\n        return {\"y\": nodes.data[\"y\"] * 2}\n\n    g = create_test_heterograph(idtype)\n    x = F.randn((3, 5))\n    g.nodes[\"user\"].data[\"h\"] = x\n\n    for msg, red, apply in itertools.product(\n        [fn.copy_u(\"h\", \"m\"), msg_func],\n        [fn.sum(\"m\", \"y\"), reduce_func],\n        [None, apply_func],\n    ):\n        multiplier = 1 if apply is None else 
2\n\n        g[\"user\", \"plays\", \"game\"].update_all(msg, red, apply)\n        y = g.nodes[\"game\"].data[\"y\"]\n        assert F.array_equal(y[0], (x[0] + x[1]) * multiplier)\n        assert F.array_equal(y[1], (x[1] + x[2]) * multiplier)\n        del g.nodes[\"game\"].data[\"y\"]\n\n        g[\"user\", \"plays\", \"game\"].send_and_recv(\n            ([0, 1, 2], [0, 1, 1]), msg, red, apply\n        )\n        y = g.nodes[\"game\"].data[\"y\"]\n        assert F.array_equal(y[0], x[0] * multiplier)\n        assert F.array_equal(y[1], (x[1] + x[2]) * multiplier)\n        del g.nodes[\"game\"].data[\"y\"]\n\n        # pulls from destination (game) node 0\n        g[\"user\", \"plays\", \"game\"].pull(0, msg, red, apply)\n        y = g.nodes[\"game\"].data[\"y\"]\n        assert F.array_equal(y[0], (x[0] + x[1]) * multiplier)\n        del g.nodes[\"game\"].data[\"y\"]\n\n        # pushes from source (user) node 0\n        g[\"user\", \"plays\", \"game\"].push(0, msg, red, apply)\n        y = g.nodes[\"game\"].data[\"y\"]\n        assert F.array_equal(y[0], x[0] * multiplier)\n        del g.nodes[\"game\"].data[\"y\"]\n\n\n@parametrize_idtype\ndef test_backward(idtype):\n    g = create_test_heterograph(idtype)\n    x = F.randn((3, 5))\n    F.attach_grad(x)\n    g.nodes[\"user\"].data[\"h\"] = x\n    with F.record_grad():\n        g.multi_update_all(\n            {\n                \"plays\": (fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"y\")),\n                \"wishes\": (fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"y\")),\n            },\n            \"sum\",\n        )\n        y = g.nodes[\"game\"].data[\"y\"]\n        F.backward(y, F.ones(y.shape))\n    print(F.grad(x))\n    assert F.array_equal(\n        F.grad(x),\n        F.tensor(\n            [\n                [2.0, 2.0, 2.0, 2.0, 2.0],\n                [2.0, 2.0, 2.0, 2.0, 2.0],\n                [2.0, 2.0, 2.0, 2.0, 2.0],\n            ]\n        ),\n    )\n\n\n@parametrize_idtype\ndef 
test_empty_heterograph(idtype):\n    def assert_empty(g):\n        assert g.num_nodes(\"user\") == 0\n        assert g.num_edges(\"plays\") == 0\n        assert g.num_nodes(\"game\") == 0\n\n    # empty src-dst pair\n    assert_empty(dgl.heterograph({(\"user\", \"plays\", \"game\"): ([], [])}))\n\n    g = dgl.heterograph(\n        {(\"user\", \"follows\", \"user\"): ([], [])}, idtype=idtype, device=F.ctx()\n    )\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    assert g.num_nodes(\"user\") == 0\n    assert g.num_edges(\"follows\") == 0\n\n    # empty relation graph with others\n    g = dgl.heterograph(\n        {\n            (\"user\", \"plays\", \"game\"): ([], []),\n            (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    assert g.num_nodes(\"user\") == 0\n    assert g.num_edges(\"plays\") == 0\n    assert g.num_nodes(\"game\") == 2\n    assert g.num_edges(\"develops\") == 2\n    assert g.num_nodes(\"developer\") == 2\n\n\n@parametrize_idtype\ndef test_types_in_function(idtype):\n    def mfunc1(edges):\n        assert edges.canonical_etype == (\"user\", \"follow\", \"user\")\n        return {}\n\n    def rfunc1(nodes):\n        assert nodes.ntype == \"user\"\n        return {}\n\n    def filter_nodes1(nodes):\n        assert nodes.ntype == \"user\"\n        return F.zeros((3,))\n\n    def filter_edges1(edges):\n        assert edges.canonical_etype == (\"user\", \"follow\", \"user\")\n        return F.zeros((2,))\n\n    def mfunc2(edges):\n        assert edges.canonical_etype == (\"user\", \"plays\", \"game\")\n        return {}\n\n    def rfunc2(nodes):\n        assert nodes.ntype == \"game\"\n        return {}\n\n    def filter_nodes2(nodes):\n        assert nodes.ntype == \"game\"\n        return F.zeros((3,))\n\n    def filter_edges2(edges):\n        assert edges.canonical_etype == (\"user\", 
\"plays\", \"game\")\n        return F.zeros((2,))\n\n    g = dgl.heterograph(\n        {(\"user\", \"follow\", \"user\"): ((0, 1), (1, 2))},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.apply_nodes(rfunc1)\n    g.apply_edges(mfunc1)\n    g.update_all(mfunc1, rfunc1)\n    g.send_and_recv([0, 1], mfunc1, rfunc1)\n    g.push([0], mfunc1, rfunc1)\n    g.pull([1], mfunc1, rfunc1)\n    g.filter_nodes(filter_nodes1)\n    g.filter_edges(filter_edges1)\n\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.apply_nodes(rfunc2, ntype=\"game\")\n    g.apply_edges(mfunc2)\n    g.update_all(mfunc2, rfunc2)\n    g.send_and_recv([0, 1], mfunc2, rfunc2)\n    g.push([0], mfunc2, rfunc2)\n    g.pull([1], mfunc2, rfunc2)\n    g.filter_nodes(filter_nodes2, ntype=\"game\")\n    g.filter_edges(filter_edges2)\n\n\n@parametrize_idtype\ndef test_stack_reduce(idtype):\n    # edges = {\n    #    'follows': ([0, 1], [1, 2]),\n    #    'plays': ([0, 1, 2, 1], [0, 0, 1, 1]),\n    #    'wishes': ([0, 2], [1, 0]),\n    #    'develops': ([0, 1], [0, 1]),\n    # }\n    g = create_test_heterograph(idtype)\n    g.nodes[\"user\"].data[\"h\"] = F.randn((3, 200))\n\n    def rfunc(nodes):\n        return {\"y\": F.sum(nodes.mailbox[\"m\"], 1)}\n\n    def rfunc2(nodes):\n        return {\"y\": F.max(nodes.mailbox[\"m\"], 1)}\n\n    def mfunc(edges):\n        return {\"m\": edges.src[\"h\"]}\n\n    g.multi_update_all(\n        {\"plays\": (mfunc, rfunc), \"wishes\": (mfunc, rfunc2)}, \"stack\"\n    )\n    assert g.nodes[\"game\"].data[\"y\"].shape == (\n        g.num_nodes(\"game\"),\n        2,\n        200,\n    )\n    # only one type-wise update_all, stack still adds one dimension\n    g.multi_update_all({\"plays\": (mfunc, rfunc)}, \"stack\")\n    assert g.nodes[\"game\"].data[\"y\"].shape == (\n        g.num_nodes(\"game\"),\n        1,\n        200,\n    )\n\n\n@parametrize_idtype\ndef 
test_isolated_ntype(idtype):\n    g = dgl.heterograph(\n        {(\"A\", \"AB\", \"B\"): ([0, 1, 2], [1, 2, 3])},\n        num_nodes_dict={\"A\": 3, \"B\": 4, \"C\": 4},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.num_nodes(\"A\") == 3\n    assert g.num_nodes(\"B\") == 4\n    assert g.num_nodes(\"C\") == 4\n\n    g = dgl.heterograph(\n        {(\"A\", \"AC\", \"C\"): ([0, 1, 2], [1, 2, 3])},\n        num_nodes_dict={\"A\": 3, \"B\": 4, \"C\": 4},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.num_nodes(\"A\") == 3\n    assert g.num_nodes(\"B\") == 4\n    assert g.num_nodes(\"C\") == 4\n\n    G = dgl.graph(\n        ([0, 1, 2], [4, 5, 6]), num_nodes=11, idtype=idtype, device=F.ctx()\n    )\n    G.ndata[dgl.NTYPE] = F.tensor(\n        [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2], dtype=F.int64\n    )\n    G.edata[dgl.ETYPE] = F.tensor([0, 0, 0], dtype=F.int64)\n    g = dgl.to_heterogeneous(G, [\"A\", \"B\", \"C\"], [\"AB\"])\n    assert g.num_nodes(\"A\") == 3\n    assert g.num_nodes(\"B\") == 4\n    assert g.num_nodes(\"C\") == 4\n\n\n@parametrize_idtype\ndef test_ismultigraph(idtype):\n    g1 = dgl.heterograph(\n        {(\"A\", \"AB\", \"B\"): ([0, 0, 1, 2], [1, 2, 5, 5])},\n        {\"A\": 6, \"B\": 6},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g1.is_multigraph == False\n    g2 = dgl.heterograph(\n        {(\"A\", \"AC\", \"C\"): ([0, 0, 0, 1], [1, 1, 2, 5])},\n        {\"A\": 6, \"C\": 6},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g2.is_multigraph == True\n    g3 = dgl.graph(((0, 1), (1, 2)), num_nodes=6, idtype=idtype, device=F.ctx())\n    assert g3.is_multigraph == False\n    g4 = dgl.graph(\n        ([0, 0, 1], [1, 1, 2]), num_nodes=6, idtype=idtype, device=F.ctx()\n    )\n    assert g4.is_multigraph == True\n    g = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): ([0, 0, 1, 2], [1, 2, 5, 5]),\n            (\"A\", \"AA\", \"A\"): ([0, 1], [1, 2]),\n      
  },\n        {\"A\": 6, \"B\": 6},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.is_multigraph == False\n    g = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): ([0, 0, 1, 2], [1, 2, 5, 5]),\n            (\"A\", \"AC\", \"C\"): ([0, 0, 0, 1], [1, 1, 2, 5]),\n        },\n        {\"A\": 6, \"B\": 6, \"C\": 6},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.is_multigraph == True\n    g = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): ([0, 0, 1, 2], [1, 2, 5, 5]),\n            (\"A\", \"AA\", \"A\"): ([0, 0, 1], [1, 1, 2]),\n        },\n        {\"A\": 6, \"B\": 6},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.is_multigraph == True\n    g = dgl.heterograph(\n        {\n            (\"A\", \"AC\", \"C\"): ([0, 0, 0, 1], [1, 1, 2, 5]),\n            (\"A\", \"AA\", \"A\"): ([0, 1], [1, 2]),\n        },\n        {\"A\": 6, \"C\": 6},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g.is_multigraph == True\n\n\n@parametrize_idtype\ndef test_graph_index_is_unibipartite(idtype):\n    g1 = dgl.heterograph(\n        {(\"A\", \"AB\", \"B\"): ([0, 0, 1], [1, 2, 5])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g1._graph.is_metagraph_unibipartite()\n\n    # more complicated bipartite\n    g2 = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): ([0, 0, 1], [1, 2, 5]),\n            (\"A\", \"AC\", \"C\"): ([1, 0], [0, 0]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g2._graph.is_metagraph_unibipartite()\n\n    g3 = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): ([0, 0, 1], [1, 2, 5]),\n            (\"A\", \"AC\", \"C\"): ([1, 0], [0, 0]),\n            (\"A\", \"AA\", \"A\"): ([0, 1], [0, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert not g3._graph.is_metagraph_unibipartite()\n\n    g4 = dgl.heterograph(\n        {\n            
(\"A\", \"AB\", \"B\"): ([0, 0, 1], [1, 2, 5]),\n            (\"C\", \"CA\", \"A\"): ([1, 0], [0, 0]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n\n    assert not g4._graph.is_metagraph_unibipartite()\n\n\n@parametrize_idtype\ndef test_bipartite(idtype):\n    g1 = dgl.heterograph(\n        {(\"A\", \"AB\", \"B\"): ([0, 0, 1], [1, 2, 5])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert g1.is_unibipartite\n    assert len(g1.ntypes) == 2\n    assert g1.etypes == [\"AB\"]\n    assert g1.srctypes == [\"A\"]\n    assert g1.dsttypes == [\"B\"]\n    assert g1.num_nodes(\"A\") == 2\n    assert g1.num_nodes(\"B\") == 6\n    assert g1.number_of_src_nodes(\"A\") == 2\n    assert g1.number_of_src_nodes() == 2\n    assert g1.number_of_dst_nodes(\"B\") == 6\n    assert g1.number_of_dst_nodes() == 6\n    assert g1.num_edges() == 3\n    g1.srcdata[\"h\"] = F.randn((2, 5))\n    assert F.array_equal(g1.srcnodes[\"A\"].data[\"h\"], g1.srcdata[\"h\"])\n    assert F.array_equal(g1.nodes[\"A\"].data[\"h\"], g1.srcdata[\"h\"])\n    assert F.array_equal(g1.nodes[\"SRC/A\"].data[\"h\"], g1.srcdata[\"h\"])\n    g1.dstdata[\"h\"] = F.randn((6, 3))\n    assert F.array_equal(g1.dstnodes[\"B\"].data[\"h\"], g1.dstdata[\"h\"])\n    assert F.array_equal(g1.nodes[\"B\"].data[\"h\"], g1.dstdata[\"h\"])\n    assert F.array_equal(g1.nodes[\"DST/B\"].data[\"h\"], g1.dstdata[\"h\"])\n\n    # more complicated bipartite\n    g2 = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): ([0, 0, 1], [1, 2, 5]),\n            (\"A\", \"AC\", \"C\"): ([1, 0], [0, 0]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n\n    assert g2.is_unibipartite\n    assert g2.srctypes == [\"A\"]\n    assert set(g2.dsttypes) == {\"B\", \"C\"}\n    assert g2.num_nodes(\"A\") == 2\n    assert g2.num_nodes(\"B\") == 6\n    assert g2.num_nodes(\"C\") == 1\n    assert g2.number_of_src_nodes(\"A\") == 2\n    assert g2.number_of_src_nodes() == 2\n    assert 
g2.number_of_dst_nodes(\"B\") == 6\n    assert g2.number_of_dst_nodes(\"C\") == 1\n    g2.srcdata[\"h\"] = F.randn((2, 5))\n    assert F.array_equal(g2.srcnodes[\"A\"].data[\"h\"], g2.srcdata[\"h\"])\n    assert F.array_equal(g2.nodes[\"A\"].data[\"h\"], g2.srcdata[\"h\"])\n    assert F.array_equal(g2.nodes[\"SRC/A\"].data[\"h\"], g2.srcdata[\"h\"])\n\n    g3 = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): ([0, 0, 1], [1, 2, 5]),\n            (\"A\", \"AC\", \"C\"): ([1, 0], [0, 0]),\n            (\"A\", \"AA\", \"A\"): ([0, 1], [0, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert not g3.is_unibipartite\n\n    g4 = dgl.heterograph(\n        {\n            (\"A\", \"AB\", \"B\"): ([0, 0, 1], [1, 2, 5]),\n            (\"C\", \"CA\", \"A\"): ([1, 0], [0, 0]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n\n    assert not g4.is_unibipartite\n\n\n@parametrize_idtype\ndef test_dtype_cast(idtype):\n    g = dgl.graph(([0, 1, 0, 2], [0, 1, 1, 0]), idtype=idtype, device=F.ctx())\n    assert g.idtype == idtype\n    g.ndata[\"feat\"] = F.tensor([3, 4, 5])\n    g.edata[\"h\"] = F.tensor([3, 4, 5, 6])\n    if idtype == \"int32\":\n        g_cast = g.long()\n        assert g_cast.idtype == F.int64\n    else:\n        g_cast = g.int()\n        assert g_cast.idtype == F.int32\n    check_graph_equal(g, g_cast, check_idtype=False)\n\n\ndef test_float_cast():\n    for t in [F.bfloat16, F.float16, F.float32, F.float64]:\n        idtype = F.int32\n        g = dgl.heterograph(\n            {\n                (\"user\", \"follows\", \"user\"): (\n                    F.tensor([0, 1, 1, 2, 2, 3], dtype=idtype),\n                    F.tensor([0, 0, 1, 1, 2, 2], dtype=idtype),\n                ),\n                (\"user\", \"plays\", \"game\"): (\n                    F.tensor([0, 1, 1], dtype=idtype),\n                    F.tensor([0, 0, 1], dtype=idtype),\n                ),\n            },\n            
idtype=idtype,\n            device=F.ctx(),\n        )\n        uvalues = [1, 2, 3, 4]\n        gvalues = [5, 6]\n        fvalues = [7, 8, 9, 10, 11, 12]\n        pvalues = [13, 14, 15]\n        dataNamesTypes = [\n            (\"a\", F.float16),\n            (\"b\", F.float32),\n            (\"c\", F.float64),\n            (\"d\", F.int32),\n            (\"e\", F.int64),\n            (\"f\", F.bfloat16),\n        ]\n        for name, type in dataNamesTypes:\n            g.nodes[\"user\"].data[name] = F.copy_to(\n                F.tensor(uvalues, dtype=type), ctx=F.ctx()\n            )\n        for name, type in dataNamesTypes:\n            g.nodes[\"game\"].data[name] = F.copy_to(\n                F.tensor(gvalues, dtype=type), ctx=F.ctx()\n            )\n        for name, type in dataNamesTypes:\n            g.edges[\"follows\"].data[name] = F.copy_to(\n                F.tensor(fvalues, dtype=type), ctx=F.ctx()\n            )\n        for name, type in dataNamesTypes:\n            g.edges[\"plays\"].data[name] = F.copy_to(\n                F.tensor(pvalues, dtype=type), ctx=F.ctx()\n            )\n\n        if t == F.bfloat16:\n            g = dgl.transforms.functional.to_bfloat16(g)\n        if t == F.float16:\n            g = dgl.transforms.functional.to_half(g)\n        if t == F.float32:\n            g = dgl.transforms.functional.to_float(g)\n        if t == F.float64:\n            g = dgl.transforms.functional.to_double(g)\n\n        for name, origType in dataNamesTypes:\n            # integer tensors shouldn't be converted\n            reqType = (\n                t\n                if (origType in [F.bfloat16, F.float16, F.float32, F.float64])\n                else origType\n            )\n\n            values = g.nodes[\"user\"].data[name]\n            assert values.dtype == reqType\n            assert len(values) == len(uvalues)\n            assert F.allclose(values, F.tensor(uvalues), 0, 0)\n\n            values = g.nodes[\"game\"].data[name]\n          
  assert values.dtype == reqType\n            assert len(values) == len(gvalues)\n            assert F.allclose(values, F.tensor(gvalues), 0, 0)\n\n            values = g.edges[\"follows\"].data[name]\n            assert values.dtype == reqType\n            assert len(values) == len(fvalues)\n            assert F.allclose(values, F.tensor(fvalues), 0, 0)\n\n            values = g.edges[\"plays\"].data[name]\n            assert values.dtype == reqType\n            assert len(values) == len(pvalues)\n            assert F.allclose(values, F.tensor(pvalues), 0, 0)\n\n\n@parametrize_idtype\ndef test_format(idtype):\n    # single relation\n    g = dgl.graph(([0, 1, 0, 2], [0, 1, 1, 0]), idtype=idtype, device=F.ctx())\n    assert g.formats()[\"created\"] == [\"coo\"]\n    g1 = g.formats([\"coo\", \"csr\", \"csc\"])\n    assert len(g1.formats()[\"created\"]) + len(g1.formats()[\"not created\"]) == 3\n    g1.create_formats_()\n    assert len(g1.formats()[\"created\"]) == 3\n    assert g.formats()[\"created\"] == [\"coo\"]\n\n    # multiple relation\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 1, 2], [0, 0, 1, 1]),\n            (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    user_feat = F.randn((g[\"follows\"].number_of_src_nodes(), 5))\n    g[\"follows\"].srcdata[\"h\"] = user_feat\n    g1 = g.formats(\"csc\")\n    # test frame\n    assert F.array_equal(g1[\"follows\"].srcdata[\"h\"], user_feat)\n    # test each relation graph\n    assert g1.formats()[\"created\"] == [\"csc\"]\n    assert len(g1.formats()[\"not created\"]) == 0\n\n    # in_degrees\n    g = dgl.rand_graph(100, 2340).to(F.ctx())\n    ind_arr = []\n    for vid in range(0, 100):\n        ind_arr.append(g.in_degrees(vid))\n    in_degrees = g.in_degrees()\n    g = g.formats(\"coo\")\n    for vid in range(0, 100):\n       
 assert g.in_degrees(vid) == ind_arr[vid]\n    assert F.array_equal(in_degrees, g.in_degrees())\n\n\n@parametrize_idtype\ndef test_edges_order(idtype):\n    # (0, 2), (1, 2), (0, 1), (0, 1), (2, 1)\n    g = dgl.graph(\n        (np.array([0, 1, 0, 0, 2]), np.array([2, 2, 1, 1, 1])),\n        idtype=idtype,\n        device=F.ctx(),\n    )\n\n    print(g.formats())\n    src, dst = g.all_edges(order=\"srcdst\")\n    assert F.array_equal(src, F.tensor([0, 0, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(dst, F.tensor([1, 1, 2, 2, 1], dtype=idtype))\n\n\n@parametrize_idtype\ndef test_reverse(idtype):\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (\n                [0, 1, 2, 4, 3, 1, 3],\n                [1, 2, 3, 2, 0, 0, 1],\n            )\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    gidx = g._graph\n    r_gidx = gidx.reverse()\n\n    assert gidx.num_nodes(0) == r_gidx.num_nodes(0)\n    assert gidx.num_edges(0) == r_gidx.num_edges(0)\n    g_s, g_d, _ = gidx.edges(0)\n    rg_s, rg_d, _ = r_gidx.edges(0)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n\n    # force to start with 'csr'\n    gidx = gidx.formats(\"csr\")\n    gidx = gidx.formats([\"coo\", \"csr\", \"csc\"])\n    r_gidx = gidx.reverse()\n    assert \"csr\" in gidx.formats()[\"created\"]\n    assert \"csc\" in r_gidx.formats()[\"created\"]\n    assert gidx.num_nodes(0) == r_gidx.num_nodes(0)\n    assert gidx.num_edges(0) == r_gidx.num_edges(0)\n    g_s, g_d, _ = gidx.edges(0)\n    rg_s, rg_d, _ = r_gidx.edges(0)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n\n    # force to start with 'csc'\n    gidx = gidx.formats(\"csc\")\n    gidx = gidx.formats([\"coo\", \"csr\", \"csc\"])\n    r_gidx = gidx.reverse()\n    assert \"csc\" in gidx.formats()[\"created\"]\n    assert \"csr\" in r_gidx.formats()[\"created\"]\n    assert gidx.num_nodes(0) == r_gidx.num_nodes(0)\n    assert 
gidx.num_edges(0) == r_gidx.num_edges(0)\n    g_s, g_d, _ = gidx.edges(0)\n    rg_s, rg_d, _ = r_gidx.edges(0)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (\n                [0, 1, 2, 4, 3, 1, 3],\n                [1, 2, 3, 2, 0, 0, 1],\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                [0, 0, 2, 3, 3, 4, 1],\n                [1, 0, 1, 0, 1, 0, 0],\n            ),\n            (\"developer\", \"develops\", \"game\"): ([0, 1, 1, 2], [0, 0, 1, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    gidx = g._graph\n    r_gidx = gidx.reverse()\n\n    # metagraph\n    mg = gidx.metagraph\n    r_mg = r_gidx.metagraph\n    for etype in range(3):\n        assert mg.find_edge(etype) == r_mg.find_edge(etype)[::-1]\n\n    # three node types and three edge types\n    assert gidx.num_nodes(0) == r_gidx.num_nodes(0)\n    assert gidx.num_nodes(1) == r_gidx.num_nodes(1)\n    assert gidx.num_nodes(2) == r_gidx.num_nodes(2)\n    assert gidx.num_edges(0) == r_gidx.num_edges(0)\n    assert gidx.num_edges(1) == r_gidx.num_edges(1)\n    assert gidx.num_edges(2) == r_gidx.num_edges(2)\n    g_s, g_d, _ = gidx.edges(0)\n    rg_s, rg_d, _ = r_gidx.edges(0)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n    g_s, g_d, _ = gidx.edges(1)\n    rg_s, rg_d, _ = r_gidx.edges(1)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n    g_s, g_d, _ = gidx.edges(2)\n    rg_s, rg_d, _ = r_gidx.edges(2)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n\n    # force to start with 'csr'\n    gidx = gidx.formats(\"csr\")\n    gidx = gidx.formats([\"coo\", \"csr\", \"csc\"])\n    r_gidx = gidx.reverse()\n    # three node types and three edge types\n    assert \"csr\" in gidx.formats()[\"created\"]\n    assert \"csc\" in r_gidx.formats()[\"created\"]\n    assert 
gidx.num_nodes(0) == r_gidx.num_nodes(0)\n    assert gidx.num_nodes(1) == r_gidx.num_nodes(1)\n    assert gidx.num_nodes(2) == r_gidx.num_nodes(2)\n    assert gidx.num_edges(0) == r_gidx.num_edges(0)\n    assert gidx.num_edges(1) == r_gidx.num_edges(1)\n    assert gidx.num_edges(2) == r_gidx.num_edges(2)\n    g_s, g_d, _ = gidx.edges(0)\n    rg_s, rg_d, _ = r_gidx.edges(0)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n    g_s, g_d, _ = gidx.edges(1)\n    rg_s, rg_d, _ = r_gidx.edges(1)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n    g_s, g_d, _ = gidx.edges(2)\n    rg_s, rg_d, _ = r_gidx.edges(2)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n\n    # force to start with 'csc'\n    gidx = gidx.formats(\"csc\")\n    gidx = gidx.formats([\"coo\", \"csr\", \"csc\"])\n    r_gidx = gidx.reverse()\n    # three node types and three edge types\n    assert \"csc\" in gidx.formats()[\"created\"]\n    assert \"csr\" in r_gidx.formats()[\"created\"]\n    assert gidx.num_nodes(0) == r_gidx.num_nodes(0)\n    assert gidx.num_nodes(1) == r_gidx.num_nodes(1)\n    assert gidx.num_nodes(2) == r_gidx.num_nodes(2)\n    assert gidx.num_edges(0) == r_gidx.num_edges(0)\n    assert gidx.num_edges(1) == r_gidx.num_edges(1)\n    assert gidx.num_edges(2) == r_gidx.num_edges(2)\n    g_s, g_d, _ = gidx.edges(0)\n    rg_s, rg_d, _ = r_gidx.edges(0)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n    g_s, g_d, _ = gidx.edges(1)\n    rg_s, rg_d, _ = r_gidx.edges(1)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n    g_s, g_d, _ = gidx.edges(2)\n    rg_s, rg_d, _ = r_gidx.edges(2)\n    assert F.array_equal(g_s, rg_d)\n    assert F.array_equal(g_d, rg_s)\n\n\n@parametrize_idtype\ndef test_clone(idtype):\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.copy_to(F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx())\n    
g.edata[\"h\"] = F.copy_to(F.tensor([1, 1], dtype=idtype), ctx=F.ctx())\n\n    new_g = g.clone()\n    assert g.num_nodes() == new_g.num_nodes()\n    assert g.num_edges() == new_g.num_edges()\n    assert g.device == new_g.device\n    assert g.idtype == new_g.idtype\n    assert F.array_equal(g.ndata[\"h\"], new_g.ndata[\"h\"])\n    assert F.array_equal(g.edata[\"h\"], new_g.edata[\"h\"])\n    # data change\n    new_g.ndata[\"h\"] = F.copy_to(F.tensor([2, 2, 2], dtype=idtype), ctx=F.ctx())\n    assert F.array_equal(g.ndata[\"h\"], new_g.ndata[\"h\"]) == False\n    g.edata[\"h\"] = F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())\n    assert F.array_equal(g.edata[\"h\"], new_g.edata[\"h\"]) == False\n    # graph structure change\n    g.add_nodes(1)\n    assert g.num_nodes() != new_g.num_nodes()\n    new_g.add_edges(1, 1)\n    assert g.num_edges() != new_g.num_edges()\n\n    # zero data graph\n    g = dgl.graph(([], []), num_nodes=0, idtype=idtype, device=F.ctx())\n    new_g = g.clone()\n    assert g.num_nodes() == new_g.num_nodes()\n    assert g.num_edges() == new_g.num_edges()\n\n    # heterograph\n    g = create_test_heterograph3(idtype)\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2, 3, 4], dtype=idtype), ctx=F.ctx()\n    )\n    new_g = g.clone()\n    assert g.num_nodes(\"user\") == new_g.num_nodes(\"user\")\n    assert g.num_nodes(\"game\") == new_g.num_nodes(\"game\")\n    assert g.num_nodes(\"developer\") == new_g.num_nodes(\"developer\")\n    assert g.num_edges(\"plays\") == new_g.num_edges(\"plays\")\n    assert g.num_edges(\"develops\") == new_g.num_edges(\"develops\")\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], new_g.nodes[\"user\"].data[\"h\"]\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], new_g.nodes[\"game\"].data[\"h\"]\n    )\n    assert F.array_equal(\n        g.edges[\"plays\"].data[\"h\"], new_g.edges[\"plays\"].data[\"h\"]\n    )\n    assert g.device == new_g.device\n   
 assert g.idtype == new_g.idtype\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"plays\")\n    nu, nv = new_g.edges(form=\"uv\", order=\"eid\", etype=\"plays\")\n    assert F.array_equal(u, nu)\n    assert F.array_equal(v, nv)\n    # graph structure change\n    u = F.tensor([0, 4], dtype=idtype)\n    v = F.tensor([2, 6], dtype=idtype)\n    g.add_edges(u, v, etype=\"plays\")\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"plays\")\n    assert u.shape[0] != nu.shape[0]\n    assert v.shape[0] != nv.shape[0]\n    assert (\n        g.nodes[\"user\"].data[\"h\"].shape[0]\n        != new_g.nodes[\"user\"].data[\"h\"].shape[0]\n    )\n    assert (\n        g.nodes[\"game\"].data[\"h\"].shape[0]\n        != new_g.nodes[\"game\"].data[\"h\"].shape[0]\n    )\n    assert (\n        g.edges[\"plays\"].data[\"h\"].shape[0]\n        != new_g.edges[\"plays\"].data[\"h\"].shape[0]\n    )\n\n\n@parametrize_idtype\ndef test_add_edges(idtype):\n    # homogeneous graph\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    u = 0\n    v = 1\n    g.add_edges(u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 3\n    u = [0]\n    v = [1]\n    g.add_edges(u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 4\n    u = F.tensor(u, dtype=idtype)\n    v = F.tensor(v, dtype=idtype)\n    g.add_edges(u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 5\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 0, 0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 1, 1, 1], dtype=idtype))\n\n    # node id larger than current max node id\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    u = F.tensor([0, 1], dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    g.add_edges(u, v)\n    assert g.num_nodes() == 4\n    assert g.num_edges() == 4\n    u, v = 
g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 2, 3], dtype=idtype))\n\n    # has data\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.copy_to(F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx())\n    g.edata[\"h\"] = F.copy_to(F.tensor([1, 1], dtype=idtype), ctx=F.ctx())\n    u = F.tensor([0, 1], dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    e_feat = {\n        \"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n        \"hh\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n    }\n    g.add_edges(u, v, e_feat)\n    assert g.num_nodes() == 4\n    assert g.num_edges() == 4\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 2, 3], dtype=idtype))\n    assert F.array_equal(g.ndata[\"h\"], F.tensor([1, 1, 1, 0], dtype=idtype))\n    assert F.array_equal(g.edata[\"h\"], F.tensor([1, 1, 2, 2], dtype=idtype))\n    assert F.array_equal(g.edata[\"hh\"], F.tensor([0, 0, 2, 2], dtype=idtype))\n\n    # zero data graph\n    g = dgl.graph(([], []), num_nodes=0, idtype=idtype, device=F.ctx())\n    u = F.tensor([0, 1], dtype=idtype)\n    v = F.tensor([2, 2], dtype=idtype)\n    e_feat = {\n        \"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n        \"hh\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n    }\n    g.add_edges(u, v, e_feat)\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 2\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2, 2], dtype=idtype))\n    assert F.array_equal(g.edata[\"h\"], F.tensor([2, 2], dtype=idtype))\n    assert F.array_equal(g.edata[\"hh\"], F.tensor([2, 2], dtype=idtype))\n\n    # bipartite graph\n    g = dgl.heterograph(\n        
{(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    u = 0\n    v = 1\n    g.add_edges(u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes(\"user\") == 2\n    assert g.num_nodes(\"game\") == 3\n    assert g.num_edges() == 3\n    u = [0]\n    v = [1]\n    g.add_edges(u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes(\"user\") == 2\n    assert g.num_nodes(\"game\") == 3\n    assert g.num_edges() == 4\n    u = F.tensor(u, dtype=idtype)\n    v = F.tensor(v, dtype=idtype)\n    g.add_edges(u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes(\"user\") == 2\n    assert g.num_nodes(\"game\") == 3\n    assert g.num_edges() == 5\n    u, v = g.edges(form=\"uv\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 0, 0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 1, 1, 1], dtype=idtype))\n\n    # node id larger than current max node id\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    u = F.tensor([0, 2], dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    g.add_edges(u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 4\n    assert g.num_edges() == 4\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 2, 3], dtype=idtype))\n\n    # has data\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.edata[\"h\"] = F.copy_to(F.tensor([1, 1], dtype=idtype), ctx=F.ctx())\n    u = F.tensor([0, 2], 
dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    e_feat = {\n        \"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n        \"hh\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n    }\n    g.add_edges(u, v, e_feat)\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 4\n    assert g.num_edges() == 4\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 2, 3], dtype=idtype))\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 1, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2, 2, 0], dtype=idtype)\n    )\n    assert F.array_equal(g.edata[\"h\"], F.tensor([1, 1, 2, 2], dtype=idtype))\n    assert F.array_equal(g.edata[\"hh\"], F.tensor([0, 0, 2, 2], dtype=idtype))\n\n    # heterogeneous graph\n    g = create_test_heterograph3(idtype)\n    u = F.tensor([0, 2], dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    g.add_edges(u, v, etype=\"plays\")\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 4\n    assert g.num_nodes(\"developer\") == 2\n    assert g.num_edges(\"plays\") == 6\n    assert g.num_edges(\"develops\") == 2\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"plays\")\n    assert F.array_equal(u, F.tensor([0, 1, 1, 2, 0, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 0, 1, 1, 2, 3], dtype=idtype))\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2, 0, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.edges[\"plays\"].data[\"h\"], F.tensor([1, 1, 1, 1, 0, 0], dtype=idtype)\n    )\n\n    # add with feature\n    e_feat = {\"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())}\n    u = F.tensor([0, 2], 
dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2, 1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.add_edges(u, v, data=e_feat, etype=\"develops\")\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 4\n    assert g.num_nodes(\"developer\") == 3\n    assert g.num_edges(\"plays\") == 6\n    assert g.num_edges(\"develops\") == 4\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"develops\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 1, 2, 3], dtype=idtype))\n    assert F.array_equal(\n        g.nodes[\"developer\"].data[\"h\"], F.tensor([3, 3, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.edges[\"develops\"].data[\"h\"], F.tensor([0, 0, 2, 2], dtype=idtype)\n    )\n\n\n@parametrize_idtype\ndef test_add_nodes(idtype):\n    # homogeneous Graphs\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.copy_to(F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx())\n    g.add_nodes(1)\n    assert g.num_nodes() == 4\n    assert F.array_equal(g.ndata[\"h\"], F.tensor([1, 1, 1, 0], dtype=idtype))\n\n    # zero node graph\n    g = dgl.graph(([], []), num_nodes=3, idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.copy_to(F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx())\n    g.add_nodes(\n        1, data={\"h\": F.copy_to(F.tensor([2], dtype=idtype), ctx=F.ctx())}\n    )\n    assert g.num_nodes() == 4\n    assert F.array_equal(g.ndata[\"h\"], F.tensor([1, 1, 1, 2], dtype=idtype))\n\n    # bipartite graph\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.add_nodes(\n        2,\n        data={\"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), 
ctx=F.ctx())},\n        ntype=\"user\",\n    )\n    assert g.num_nodes(\"user\") == 4\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([0, 0, 2, 2], dtype=idtype)\n    )\n    g.add_nodes(2, ntype=\"game\")\n    assert g.num_nodes(\"game\") == 5\n\n    # heterogeneous graph\n    g = create_test_heterograph3(idtype)\n    g.add_nodes(1, ntype=\"user\")\n    g.add_nodes(\n        2,\n        data={\"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())},\n        ntype=\"game\",\n    )\n    g.add_nodes(0, ntype=\"developer\")\n    assert g.num_nodes(\"user\") == 4\n    assert g.num_nodes(\"game\") == 4\n    assert g.num_nodes(\"developer\") == 2\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 1, 1, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2, 2, 2], dtype=idtype)\n    )\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\",\n    reason=\"MXNet has error with (0,) shape tensor.\",\n)\n@parametrize_idtype\ndef test_remove_edges(idtype):\n    # homogeneous Graphs\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    e = 0\n    g.remove_edges(e)\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2], dtype=idtype))\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    e = [0]\n    g.remove_edges(e)\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2], dtype=idtype))\n    e = F.tensor([0], dtype=idtype)\n    g.remove_edges(e)\n    assert g.num_edges() == 0\n\n    # has node data\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx())\n    g.remove_edges(1)\n    
assert g.num_edges() == 1\n    assert F.array_equal(g.ndata[\"h\"], F.tensor([1, 2, 3], dtype=idtype))\n\n    # has edge data\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.edata[\"h\"] = F.copy_to(F.tensor([1, 2], dtype=idtype), ctx=F.ctx())\n    g.remove_edges(0)\n    assert g.num_edges() == 1\n    assert F.array_equal(g.edata[\"h\"], F.tensor([2], dtype=idtype))\n\n    # invalid eid\n    assert_fail = False\n    try:\n        g.remove_edges(1)\n    except:\n        assert_fail = True\n    assert assert_fail\n\n    # bipartite graph\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    e = 0\n    g.remove_edges(e)\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2], dtype=idtype))\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    e = [0]\n    g.remove_edges(e)\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2], dtype=idtype))\n    e = F.tensor([0], dtype=idtype)\n    g.remove_edges(e)\n    assert g.num_edges() == 0\n\n    # has data\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.edata[\"h\"] = F.copy_to(F.tensor([1, 2], dtype=idtype), ctx=F.ctx())\n    g.remove_edges(1)\n    assert g.num_edges() == 1\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 
1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2, 2], dtype=idtype)\n    )\n    assert F.array_equal(g.edata[\"h\"], F.tensor([1], dtype=idtype))\n\n    # heterogeneous graph\n    g = create_test_heterograph3(idtype)\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2, 3, 4], dtype=idtype), ctx=F.ctx()\n    )\n    g.remove_edges(1, etype=\"plays\")\n    assert g.num_edges(\"plays\") == 3\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"plays\")\n    assert F.array_equal(u, F.tensor([0, 1, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 1, 1], dtype=idtype))\n    assert F.array_equal(\n        g.edges[\"plays\"].data[\"h\"], F.tensor([1, 3, 4], dtype=idtype)\n    )\n    # remove all edges of 'develops'\n    g.remove_edges([0, 1], etype=\"develops\")\n    assert g.num_edges(\"develops\") == 0\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"developer\"].data[\"h\"], F.tensor([3, 3], dtype=idtype)\n    )\n\n\n@parametrize_idtype\ndef test_remove_nodes(idtype):\n    # homogeneous Graphs\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    n = 0\n    g.remove_nodes(n)\n    assert g.num_nodes() == 2\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1], dtype=idtype))\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    n = [1]\n    g.remove_nodes(n)\n    assert g.num_nodes() == 2\n    assert g.num_edges() == 0\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    n = F.tensor([2], dtype=idtype)\n    g.remove_nodes(n)\n    assert g.num_nodes() == 2\n    assert g.num_edges() == 1\n    u, 
v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1], dtype=idtype))\n\n    # invalid nid\n    assert_fail = False\n    try:\n        g.remove_nodes(3)\n    except:\n        assert_fail = True\n    assert assert_fail\n\n    # has node and edge data\n    g = dgl.graph(([0, 0, 2], [0, 1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"hv\"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx())\n    g.edata[\"he\"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx())\n    g.remove_nodes(F.tensor([0], dtype=idtype))\n    assert g.num_nodes() == 2\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1], dtype=idtype))\n    assert F.array_equal(g.ndata[\"hv\"], F.tensor([2, 3], dtype=idtype))\n    assert F.array_equal(g.edata[\"he\"], F.tensor([3], dtype=idtype))\n\n    # node id larger than current max node id\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    n = 0\n    g.remove_nodes(n, ntype=\"user\")\n    assert g.num_nodes(\"user\") == 1\n    assert g.num_nodes(\"game\") == 3\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2], dtype=idtype))\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    n = [1]\n    g.remove_nodes(n, ntype=\"user\")\n    assert g.num_nodes(\"user\") == 1\n    assert g.num_nodes(\"game\") == 3\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1], dtype=idtype))\n    g = 
dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    n = F.tensor([0], dtype=idtype)\n    g.remove_nodes(n, ntype=\"game\")\n    assert g.num_nodes(\"user\") == 2\n    assert g.num_nodes(\"game\") == 2\n    assert g.num_edges() == 2\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 1], dtype=idtype))\n\n    # heterogeneous graph\n    g = create_test_heterograph3(idtype)\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2, 3, 4], dtype=idtype), ctx=F.ctx()\n    )\n    g.remove_nodes(0, ntype=\"game\")\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 1\n    assert g.num_nodes(\"developer\") == 2\n    assert g.num_edges(\"plays\") == 2\n    assert g.num_edges(\"develops\") == 1\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(g.nodes[\"game\"].data[\"h\"], F.tensor([2], dtype=idtype))\n    assert F.array_equal(\n        g.nodes[\"developer\"].data[\"h\"], F.tensor([3, 3], dtype=idtype)\n    )\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"plays\")\n    assert F.array_equal(u, F.tensor([1, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 0], dtype=idtype))\n    assert F.array_equal(\n        g.edges[\"plays\"].data[\"h\"], F.tensor([3, 4], dtype=idtype)\n    )\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"develops\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0], dtype=idtype))\n\n\n@parametrize_idtype\ndef test_frame(idtype):\n    g = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.copy_to(F.tensor([0, 1, 2, 3], dtype=idtype), ctx=F.ctx())\n    g.edata[\"h\"] = F.copy_to(F.tensor([0, 1, 2], dtype=idtype), ctx=F.ctx())\n\n    # 
remove nodes\n    sg = dgl.remove_nodes(g, [3])\n    # check for lazy update\n    assert F.array_equal(sg._node_frames[0]._columns[\"h\"].storage, g.ndata[\"h\"])\n    assert F.array_equal(sg._edge_frames[0]._columns[\"h\"].storage, g.edata[\"h\"])\n    assert sg.ndata[\"h\"].shape[0] == 3\n    assert sg.edata[\"h\"].shape[0] == 2\n    # update after read\n    assert F.array_equal(\n        sg._node_frames[0]._columns[\"h\"].storage,\n        F.tensor([0, 1, 2], dtype=idtype),\n    )\n    assert F.array_equal(\n        sg._edge_frames[0]._columns[\"h\"].storage, F.tensor([0, 1], dtype=idtype)\n    )\n\n    ng = dgl.add_nodes(sg, 1)\n    assert ng.ndata[\"h\"].shape[0] == 4\n    assert F.array_equal(\n        ng._node_frames[0]._columns[\"h\"].storage,\n        F.tensor([0, 1, 2, 0], dtype=idtype),\n    )\n    ng = dgl.add_edges(ng, [3], [1])\n    assert ng.edata[\"h\"].shape[0] == 3\n    assert F.array_equal(\n        ng._edge_frames[0]._columns[\"h\"].storage,\n        F.tensor([0, 1, 0], dtype=idtype),\n    )\n\n    # multi level lazy update\n    sg = dgl.remove_nodes(g, [3])\n    assert F.array_equal(sg._node_frames[0]._columns[\"h\"].storage, g.ndata[\"h\"])\n    assert F.array_equal(sg._edge_frames[0]._columns[\"h\"].storage, g.edata[\"h\"])\n    ssg = dgl.remove_nodes(sg, [1])\n    assert F.array_equal(\n        ssg._node_frames[0]._columns[\"h\"].storage, g.ndata[\"h\"]\n    )\n    assert F.array_equal(\n        ssg._edge_frames[0]._columns[\"h\"].storage, g.edata[\"h\"]\n    )\n    # ssg is changed\n    assert ssg.ndata[\"h\"].shape[0] == 2\n    assert ssg.edata[\"h\"].shape[0] == 0\n    assert F.array_equal(\n        ssg._node_frames[0]._columns[\"h\"].storage,\n        F.tensor([0, 2], dtype=idtype),\n    )\n    # sg still in lazy model\n    assert F.array_equal(sg._node_frames[0]._columns[\"h\"].storage, g.ndata[\"h\"])\n    assert F.array_equal(sg._edge_frames[0]._columns[\"h\"].storage, g.edata[\"h\"])\n\n\n@unittest.skipIf(\n    
dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TensorFlow always create a new tensor\",\n)\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"cpu do not have context change problem\",\n)\n@parametrize_idtype\ndef test_frame_device(idtype):\n    g = dgl.graph(([0, 1, 2], [2, 3, 1]))\n    g.ndata[\"h\"] = F.copy_to(F.tensor([1, 1, 1, 2], dtype=idtype), ctx=F.cpu())\n    g.ndata[\"hh\"] = F.copy_to(F.ones((4, 3), dtype=idtype), ctx=F.cpu())\n    g.edata[\"h\"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.cpu())\n\n    g = g.to(F.ctx())\n    # lazy device copy\n    assert F.context(g._node_frames[0]._columns[\"h\"].storage) == F.cpu()\n    assert F.context(g._node_frames[0]._columns[\"hh\"].storage) == F.cpu()\n    print(g.ndata[\"h\"])\n    assert F.context(g._node_frames[0]._columns[\"h\"].storage) == F.ctx()\n    assert F.context(g._node_frames[0]._columns[\"hh\"].storage) == F.cpu()\n    assert F.context(g._edge_frames[0]._columns[\"h\"].storage) == F.cpu()\n\n    # lazy device copy in subgraph\n    sg = dgl.node_subgraph(g, [0, 1, 2])\n    assert F.context(sg._node_frames[0]._columns[\"h\"].storage) == F.ctx()\n    assert F.context(sg._node_frames[0]._columns[\"hh\"].storage) == F.cpu()\n    assert F.context(sg._edge_frames[0]._columns[\"h\"].storage) == F.cpu()\n    print(sg.ndata[\"hh\"])\n    assert F.context(sg._node_frames[0]._columns[\"hh\"].storage) == F.ctx()\n    assert F.context(sg._edge_frames[0]._columns[\"h\"].storage) == F.cpu()\n\n    # back to cpu\n    sg = sg.to(F.cpu())\n    assert F.context(sg._node_frames[0]._columns[\"h\"].storage) == F.ctx()\n    assert F.context(sg._node_frames[0]._columns[\"hh\"].storage) == F.ctx()\n    assert F.context(sg._edge_frames[0]._columns[\"h\"].storage) == F.cpu()\n    print(sg.ndata[\"h\"])\n    print(sg.ndata[\"hh\"])\n    print(sg.edata[\"h\"])\n    assert F.context(sg._node_frames[0]._columns[\"h\"].storage) == F.cpu()\n    assert 
F.context(sg._node_frames[0]._columns[\"hh\"].storage) == F.cpu()\n    assert F.context(sg._edge_frames[0]._columns[\"h\"].storage) == F.cpu()\n\n    # set some field\n    sg = sg.to(F.ctx())\n    assert F.context(sg._node_frames[0]._columns[\"h\"].storage) == F.cpu()\n    sg.ndata[\"h\"][0] = 5\n    assert F.context(sg._node_frames[0]._columns[\"h\"].storage) == F.ctx()\n    assert F.context(sg._node_frames[0]._columns[\"hh\"].storage) == F.cpu()\n    assert F.context(sg._edge_frames[0]._columns[\"h\"].storage) == F.cpu()\n\n    # add nodes\n    ng = dgl.add_nodes(sg, 3)\n    assert F.context(ng._node_frames[0]._columns[\"h\"].storage) == F.ctx()\n    assert F.context(ng._node_frames[0]._columns[\"hh\"].storage) == F.ctx()\n    assert F.context(ng._edge_frames[0]._columns[\"h\"].storage) == F.cpu()\n\n\n@parametrize_idtype\ndef test_create_block(idtype):\n    block = dgl.create_block(\n        ([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx()\n    )\n    assert block.num_src_nodes() == 3\n    assert block.num_dst_nodes() == 4\n    assert block.num_edges() == 3\n\n    block = dgl.create_block(([], []), idtype=idtype, device=F.ctx())\n    assert block.num_src_nodes() == 0\n    assert block.num_dst_nodes() == 0\n    assert block.num_edges() == 0\n\n    block = dgl.create_block(([], []), 3, 4, idtype=idtype, device=F.ctx())\n    assert block.num_src_nodes() == 3\n    assert block.num_dst_nodes() == 4\n    assert block.num_edges() == 0\n\n    block = dgl.create_block(\n        ([0, 1, 2], [1, 2, 3]), 4, 5, idtype=idtype, device=F.ctx()\n    )\n    assert block.num_src_nodes() == 4\n    assert block.num_dst_nodes() == 5\n    assert block.num_edges() == 3\n\n    sx = F.randn((4, 5))\n    dx = F.randn((5, 6))\n    ex = F.randn((3, 4))\n    block.srcdata[\"x\"] = sx\n    block.dstdata[\"x\"] = dx\n    block.edata[\"x\"] = ex\n\n    g = dgl.block_to_graph(block)\n    assert g.num_src_nodes() == 4\n    assert g.num_dst_nodes() == 5\n    assert g.num_edges() == 3\n    
assert g.srcdata[\"x\"] is sx\n    assert g.dstdata[\"x\"] is dx\n    assert g.edata[\"x\"] is ex\n\n    block = dgl.create_block(\n        {\n            (\"A\", \"AB\", \"B\"): ([1, 2, 3], [2, 1, 0]),\n            (\"B\", \"BA\", \"A\"): ([2, 3], [3, 4]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert block.num_src_nodes(\"A\") == 4\n    assert block.num_src_nodes(\"B\") == 4\n    assert block.num_dst_nodes(\"B\") == 3\n    assert block.num_dst_nodes(\"A\") == 5\n    assert block.num_edges(\"AB\") == 3\n    assert block.num_edges(\"BA\") == 2\n\n    block = dgl.create_block(\n        {(\"A\", \"AB\", \"B\"): ([], []), (\"B\", \"BA\", \"A\"): ([], [])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert block.num_src_nodes(\"A\") == 0\n    assert block.num_src_nodes(\"B\") == 0\n    assert block.num_dst_nodes(\"B\") == 0\n    assert block.num_dst_nodes(\"A\") == 0\n    assert block.num_edges(\"AB\") == 0\n    assert block.num_edges(\"BA\") == 0\n\n    block = dgl.create_block(\n        {(\"A\", \"AB\", \"B\"): ([], []), (\"B\", \"BA\", \"A\"): ([], [])},\n        num_src_nodes={\"A\": 5, \"B\": 5},\n        num_dst_nodes={\"A\": 6, \"B\": 4},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert block.num_src_nodes(\"A\") == 5\n    assert block.num_src_nodes(\"B\") == 5\n    assert block.num_dst_nodes(\"B\") == 4\n    assert block.num_dst_nodes(\"A\") == 6\n    assert block.num_edges(\"AB\") == 0\n    assert block.num_edges(\"BA\") == 0\n\n    block = dgl.create_block(\n        {\n            (\"A\", \"AB\", \"B\"): ([1, 2, 3], [2, 1, 0]),\n            (\"B\", \"BA\", \"A\"): ([2, 3], [3, 4]),\n        },\n        num_src_nodes={\"A\": 5, \"B\": 5},\n        num_dst_nodes={\"A\": 6, \"B\": 4},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    assert block.num_src_nodes(\"A\") == 5\n    assert block.num_src_nodes(\"B\") == 5\n    assert block.num_dst_nodes(\"B\") == 4\n    assert 
block.num_dst_nodes(\"A\") == 6\n    assert block.num_edges((\"A\", \"AB\", \"B\")) == 3\n    assert block.num_edges((\"B\", \"BA\", \"A\")) == 2\n\n    sax = F.randn((5, 3))\n    sbx = F.randn((5, 4))\n    dax = F.randn((6, 5))\n    dbx = F.randn((4, 6))\n    eabx = F.randn((3, 7))\n    ebax = F.randn((2, 8))\n    block.srcnodes[\"A\"].data[\"x\"] = sax\n    block.srcnodes[\"B\"].data[\"x\"] = sbx\n    block.dstnodes[\"A\"].data[\"x\"] = dax\n    block.dstnodes[\"B\"].data[\"x\"] = dbx\n    block.edges[\"AB\"].data[\"x\"] = eabx\n    block.edges[\"BA\"].data[\"x\"] = ebax\n\n    hg = dgl.block_to_graph(block)\n    assert hg.num_nodes(\"A_src\") == 5\n    assert hg.num_nodes(\"B_src\") == 5\n    assert hg.num_nodes(\"A_dst\") == 6\n    assert hg.num_nodes(\"B_dst\") == 4\n    assert hg.num_edges((\"A_src\", \"AB\", \"B_dst\")) == 3\n    assert hg.num_edges((\"B_src\", \"BA\", \"A_dst\")) == 2\n    assert hg.nodes[\"A_src\"].data[\"x\"] is sax\n    assert hg.nodes[\"B_src\"].data[\"x\"] is sbx\n    assert hg.nodes[\"A_dst\"].data[\"x\"] is dax\n    assert hg.nodes[\"B_dst\"].data[\"x\"] is dbx\n    assert hg.edges[\"AB\"].data[\"x\"] is eabx\n    assert hg.edges[\"BA\"].data[\"x\"] is ebax\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"fmt\", [\"coo\", \"csr\", \"csc\"])\ndef test_adj_tensors(idtype, fmt):\n    if fmt == \"coo\":\n        A = ssp.random(10, 10, 0.2).tocoo()\n        A.data = np.arange(20)\n        row = F.tensor(A.row, idtype)\n        col = F.tensor(A.col, idtype)\n        g = dgl.graph((row, col))\n    elif fmt == \"csr\":\n        A = ssp.random(10, 10, 0.2).tocsr()\n        A.data = np.arange(20)\n        indptr = F.tensor(A.indptr, idtype)\n        indices = F.tensor(A.indices, idtype)\n        g = dgl.graph((\"csr\", (indptr, indices, [])))\n        with pytest.raises(DGLError):\n            g2 = dgl.graph((\"csr\", (indptr[:-1], indices, [])), num_nodes=10)\n    elif fmt == \"csc\":\n        A = ssp.random(10, 10, 0.2).tocsc()\n        
A.data = np.arange(20)\n        indptr = F.tensor(A.indptr, idtype)\n        indices = F.tensor(A.indices, idtype)\n        g = dgl.graph((\"csc\", (indptr, indices, [])))\n        with pytest.raises(DGLError):\n            g2 = dgl.graph((\"csr\", (indptr[:-1], indices, [])), num_nodes=10)\n\n    A_coo = A.tocoo()\n    A_csr = A.tocsr()\n    A_csc = A.tocsc()\n    row, col = g.adj_tensors(\"coo\")\n    assert np.array_equal(F.asnumpy(row), A_coo.row)\n    assert np.array_equal(F.asnumpy(col), A_coo.col)\n\n    indptr, indices, eids = g.adj_tensors(\"csr\")\n    assert np.array_equal(F.asnumpy(indptr), A_csr.indptr)\n    if fmt == \"csr\":\n        assert len(eids) == 0\n        assert np.array_equal(F.asnumpy(indices), A_csr.indices)\n    else:\n        indices_sorted = F.zeros(len(indices), idtype)\n        indices_sorted = F.scatter_row(indices_sorted, eids, indices)\n        indices_sorted_np = np.zeros(len(indices), dtype=A_csr.indices.dtype)\n        indices_sorted_np[A_csr.data] = A_csr.indices\n        assert np.array_equal(F.asnumpy(indices_sorted), indices_sorted_np)\n\n    indptr, indices, eids = g.adj_tensors(\"csc\")\n    assert np.array_equal(F.asnumpy(indptr), A_csc.indptr)\n    if fmt == \"csc\":\n        assert len(eids) == 0\n        assert np.array_equal(F.asnumpy(indices), A_csc.indices)\n    else:\n        indices_sorted = F.zeros(len(indices), idtype)\n        indices_sorted = F.scatter_row(indices_sorted, eids, indices)\n        indices_sorted_np = np.zeros(len(indices), dtype=A_csc.indices.dtype)\n        indices_sorted_np[A_csc.data] = A_csc.indices\n        assert np.array_equal(F.asnumpy(indices_sorted), indices_sorted_np)\n\n\ndef _test_forking_pickler_entry(g, q):\n    q.put(g.formats())\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\", reason=\"MXNet doesn't support spawning\"\n)\ndef test_forking_pickler():\n    ctx = mp.get_context(\"spawn\")\n    g = dgl.graph(([0, 1, 2], [1, 2, 3]))\n    g.create_formats_()\n    q 
= ctx.Queue(1)\n    proc = ctx.Process(target=_test_forking_pickler_entry, args=(g, q))\n    proc.start()\n    fmt = q.get()[\"created\"]\n    proc.join()\n    assert \"coo\" in fmt\n    assert \"csr\" in fmt\n    assert \"csc\" in fmt\n\n\nif __name__ == \"__main__\":\n    # test_create()\n    # test_query()\n    # test_hypersparse()\n    # test_adj(\"int32\")\n    # test_inc()\n    # test_view(\"int32\")\n    # test_view1(\"int32\")\n    # test_flatten(F.int32)\n    # test_convert_bound()\n    # test_convert()\n    # test_to_device(\"int32\")\n    # test_transform(\"int32\")\n    # test_subgraph(\"int32\")\n    # test_subgraph_mask(\"int32\")\n    # test_apply()\n    # test_level1()\n    # test_level2()\n    # test_updates()\n    # test_backward()\n    # test_empty_heterograph('int32')\n    # test_types_in_function()\n    # test_stack_reduce()\n    # test_isolated_ntype()\n    # test_bipartite()\n    # test_dtype_cast()\n    # test_float_cast()\n    # test_reverse(\"int32\")\n    # test_format()\n    # test_add_edges(F.int32)\n    # test_add_nodes(F.int32)\n    # test_remove_edges(F.int32)\n    # test_remove_nodes(F.int32)\n    # test_clone(F.int32)\n    # test_frame(F.int32)\n    # test_frame_device(F.int32)\n    # test_empty_query(F.int32)\n    # test_create_block(F.int32)\n    pass\n"
  },
  {
    "path": "tests/python/common/test_homophily.py",
    "content": "import math\nimport unittest\n\nimport backend as F\n\nimport dgl\nfrom utils import parametrize_idtype\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_node_homophily(idtype):\n    # IfChangeThenChange: python/dgl/homophily.py\n    # Update the docstring example.\n    device = F.ctx()\n    graph = dgl.graph(\n        ([1, 2, 0, 4], [0, 1, 2, 3]), idtype=idtype, device=device\n    )\n    y = F.tensor([0, 0, 0, 0, 1])\n    assert math.isclose(dgl.node_homophily(graph, y), 0.6000000238418579)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_edge_homophily(idtype):\n    # IfChangeThenChange: python/dgl/homophily.py\n    # Update the docstring example.\n    device = F.ctx()\n    graph = dgl.graph(\n        ([1, 2, 0, 4], [0, 1, 2, 3]), idtype=idtype, device=device\n    )\n    y = F.tensor([0, 0, 0, 0, 1])\n    assert math.isclose(dgl.edge_homophily(graph, y), 0.75)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_linkx_homophily(idtype):\n    # IfChangeThenChange: python/dgl/homophily.py\n    # Update the docstring example.\n    device = F.ctx()\n    graph = dgl.graph(([0, 1, 2, 3], [1, 2, 0, 4]), device=device)\n    y = F.tensor([0, 0, 0, 0, 1])\n    assert math.isclose(dgl.linkx_homophily(graph, y), 0.19999998807907104)\n\n    y = F.tensor([0, 1, 2, 3, 4])\n    assert math.isclose(dgl.linkx_homophily(graph, y), 0.0000000000000000)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_adjusted_homophily(idtype):\n    # IfChangeThenChange: python/dgl/homophily.py\n    # Update the docstring example.\n    device = F.ctx()\n    graph = dgl.graph(\n        ([1, 2, 0, 4], [0, 1, 2, 3]), 
idtype=idtype, device=device\n    )\n    y = F.tensor([0, 0, 0, 0, 1])\n    assert math.isclose(dgl.adjusted_homophily(graph, y), -0.1428571492433548)\n"
  },
  {
    "path": "tests/python/common/test_label_informativeness.py",
    "content": "import math\nimport unittest\n\nimport backend as F\n\nimport dgl\nfrom utils import parametrize_idtype\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_edge_label_informativeness(idtype):\n    # IfChangeThenChange: python/dgl/label_informativeness.py\n    # Update the docstring example.\n    device = F.ctx()\n    graph = dgl.graph(\n        ([0, 1, 2, 2, 3, 4], [1, 2, 0, 3, 4, 5]), idtype=idtype, device=device\n    )\n    y = F.tensor([0, 0, 0, 0, 1, 1])\n    assert math.isclose(\n        dgl.edge_label_informativeness(graph, y),\n        0.25177597999572754,\n        abs_tol=1e-6,\n    )\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_node_label_informativeness(idtype):\n    # IfChangeThenChange: python/dgl/label_informativeness.py\n    # Update the docstring example.\n    device = F.ctx()\n    graph = dgl.graph(\n        ([0, 1, 2, 2, 3, 4], [1, 2, 0, 3, 4, 5]), idtype=idtype, device=device\n    )\n    y = F.tensor([0, 0, 0, 0, 1, 1])\n    assert math.isclose(\n        dgl.node_label_informativeness(graph, y),\n        0.3381872773170471,\n        abs_tol=1e-6,\n    )\n"
  },
  {
    "path": "tests/python/common/test_merge.py",
    "content": "import backend as F\n\nimport dgl\nfrom utils import parametrize_idtype\n\n\n@parametrize_idtype\ndef test_heterograph_merge(idtype):\n    g1 = (\n        dgl.heterograph({(\"a\", \"to\", \"b\"): ([0, 1], [1, 0])})\n        .astype(idtype)\n        .to(F.ctx())\n    )\n    g1_n_edges = g1.num_edges(etype=\"to\")\n    g1.nodes[\"a\"].data[\"nh\"] = F.randn((2, 3))\n    g1.nodes[\"b\"].data[\"nh\"] = F.randn((2, 3))\n    g1.edges[\"to\"].data[\"eh\"] = F.randn((2, 3))\n\n    g2 = (\n        dgl.heterograph({(\"a\", \"to\", \"b\"): ([1, 2, 3], [2, 3, 5])})\n        .astype(idtype)\n        .to(F.ctx())\n    )\n    g2.nodes[\"a\"].data[\"nh\"] = F.randn((4, 3))\n    g2.nodes[\"b\"].data[\"nh\"] = F.randn((6, 3))\n    g2.edges[\"to\"].data[\"eh\"] = F.randn((3, 3))\n    g2.add_nodes(3, ntype=\"a\")\n    g2.add_nodes(3, ntype=\"b\")\n\n    m = dgl.merge([g1, g2])\n\n    # Check g2's edges and nodes were added to g1's in m.\n    m_us = F.asnumpy(m.edges()[0][g1_n_edges:])\n    g2_us = F.asnumpy(g2.edges()[0])\n    assert all(m_us == g2_us)\n    m_vs = F.asnumpy(m.edges()[1][g1_n_edges:])\n    g2_vs = F.asnumpy(g2.edges()[1])\n    assert all(m_vs == g2_vs)\n    for ntype in m.ntypes:\n        assert m.num_nodes(ntype=ntype) == max(\n            g1.num_nodes(ntype=ntype), g2.num_nodes(ntype=ntype)\n        )\n\n        # Check g1's node data was updated with g2's in m.\n        for key in m.nodes[ntype].data:\n            g2_n_nodes = g2.num_nodes(ntype=ntype)\n            updated_g1_ndata = F.asnumpy(m.nodes[ntype].data[key][:g2_n_nodes])\n            g2_ndata = F.asnumpy(g2.nodes[ntype].data[key])\n            assert all((updated_g1_ndata == g2_ndata).flatten())\n\n    # Check g1's edge data was updated with g2's in m.\n    for key in m.edges[\"to\"].data:\n        updated_g1_edata = F.asnumpy(m.edges[\"to\"].data[key][g1_n_edges:])\n        g2_edata = F.asnumpy(g2.edges[\"to\"].data[key])\n        assert all((updated_g1_edata == g2_edata).flatten())\n"
  },
  {
    "path": "tests/python/common/test_partition.py",
    "content": "import unittest\n\nimport backend as F\n\nfrom dgl.distributed import graph_partition_book as gpb\nfrom dgl.partition import NDArrayPartition\nfrom utils import parametrize_idtype\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"NDArrayPartition only works on GPU.\",\n)\n@parametrize_idtype\ndef test_get_node_partition_from_book(idtype):\n    node_map = {\"_N\": F.tensor([[0, 3], [4, 5], [6, 10]], dtype=idtype)}\n    edge_map = {\n        (\"_N\", \"_E\", \"_N\"): F.tensor([[0, 9], [10, 15], [16, 25]], dtype=idtype)\n    }\n    ntypes = {ntype: i for i, ntype in enumerate(node_map)}\n    etypes = {etype: i for i, etype in enumerate(edge_map)}\n    book = gpb.RangePartitionBook(0, 3, node_map, edge_map, ntypes, etypes)\n    partition = gpb.get_node_partition_from_book(book, F.ctx())\n    assert partition.num_parts() == 3\n    assert partition.array_size() == 11\n\n    # Test map_to_local\n    test_ids = F.copy_to(F.tensor([0, 2, 6, 7, 10], dtype=idtype), F.ctx())\n    act_ids = partition.map_to_local(test_ids)\n    exp_ids = F.copy_to(F.tensor([0, 2, 0, 1, 4], dtype=idtype), F.ctx())\n    assert F.array_equal(act_ids, exp_ids)\n\n    # Test map_to_global\n    test_ids = F.copy_to(F.tensor([0, 2], dtype=idtype), F.ctx())\n    act_ids = partition.map_to_global(test_ids, 0)\n    exp_ids = F.copy_to(F.tensor([0, 2], dtype=idtype), F.ctx())\n    assert F.array_equal(act_ids, exp_ids)\n\n    test_ids = F.copy_to(F.tensor([0, 1], dtype=idtype), F.ctx())\n    act_ids = partition.map_to_global(test_ids, 1)\n    exp_ids = F.copy_to(F.tensor([4, 5], dtype=idtype), F.ctx())\n    assert F.array_equal(act_ids, exp_ids)\n\n    test_ids = F.copy_to(F.tensor([0, 1, 4], dtype=idtype), F.ctx())\n    act_ids = partition.map_to_global(test_ids, 2)\n    exp_ids = F.copy_to(F.tensor([6, 7, 10], dtype=idtype), F.ctx())\n    assert F.array_equal(act_ids, exp_ids)\n\n    # Test generate_permutation\n    test_ids = F.copy_to(F.tensor([6, 0, 7, 2, 
10], dtype=idtype), F.ctx())\n    perm, split_sum = partition.generate_permutation(test_ids)\n    exp_perm = F.copy_to(F.tensor([1, 3, 0, 2, 4], dtype=idtype), F.ctx())\n    exp_sum = F.copy_to(F.tensor([2, 0, 3]), F.ctx())\n    assert F.array_equal(perm, exp_perm)\n    assert F.array_equal(split_sum, exp_sum)\n"
  },
  {
    "path": "tests/python/common/test_propagate.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nimport networkx as nx\nfrom utils import check_fail, parametrize_idtype\n\n\ndef create_graph(idtype):\n    g = dgl.from_networkx(nx.path_graph(5), idtype=idtype, device=F.ctx())\n    return g\n\n\ndef mfunc(edges):\n    return {\"m\": edges.src[\"x\"]}\n\n\ndef rfunc(nodes):\n    msg = F.sum(nodes.mailbox[\"m\"], 1)\n    return {\"x\": nodes.data[\"x\"] + msg}\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\n@parametrize_idtype\ndef test_prop_nodes_bfs(idtype):\n    g = create_graph(idtype)\n    g.ndata[\"x\"] = F.ones((5, 2))\n    dgl.prop_nodes_bfs(\n        g, 0, message_func=mfunc, reduce_func=rfunc, apply_node_func=None\n    )\n    # pull nodes using bfs order will result in a cumsum[i] + data[i] + data[i+1]\n    assert F.allclose(\n        g.ndata[\"x\"],\n        F.tensor([[2.0, 2.0], [4.0, 4.0], [6.0, 6.0], [8.0, 8.0], [9.0, 9.0]]),\n    )\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\n@parametrize_idtype\ndef test_prop_edges_dfs(idtype):\n    g = create_graph(idtype)\n    g.ndata[\"x\"] = F.ones((5, 2))\n    dgl.prop_edges_dfs(\n        g, 0, message_func=mfunc, reduce_func=rfunc, apply_node_func=None\n    )\n    # snr using dfs results in a cumsum\n    assert F.allclose(\n        g.ndata[\"x\"],\n        F.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0]]),\n    )\n\n    g.ndata[\"x\"] = F.ones((5, 2))\n    dgl.prop_edges_dfs(\n        g,\n        0,\n        has_reverse_edge=True,\n        message_func=mfunc,\n        reduce_func=rfunc,\n        apply_node_func=None,\n    )\n    # result is cumsum[i] + cumsum[i-1]\n    assert F.allclose(\n        g.ndata[\"x\"],\n        F.tensor([[1.0, 1.0], [3.0, 3.0], [5.0, 5.0], [7.0, 7.0], [9.0, 9.0]]),\n    )\n\n    g.ndata[\"x\"] = F.ones((5, 2))\n    dgl.prop_edges_dfs(\n        g,\n        0,\n        has_nontree_edge=True,\n        
message_func=mfunc,\n        reduce_func=rfunc,\n        apply_node_func=None,\n    )\n    # result is cumsum[i] + cumsum[i+1]\n    assert F.allclose(\n        g.ndata[\"x\"],\n        F.tensor([[3.0, 3.0], [5.0, 5.0], [7.0, 7.0], [9.0, 9.0], [5.0, 5.0]]),\n    )\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\n@parametrize_idtype\ndef test_prop_nodes_topo(idtype):\n    # bi-directional chain\n    g = create_graph(idtype)\n    assert check_fail(dgl.prop_nodes_topo, g)  # has loop\n\n    # tree\n    tree = dgl.graph([])\n    tree.add_nodes(5)\n    tree.add_edges(1, 0)\n    tree.add_edges(2, 0)\n    tree.add_edges(3, 2)\n    tree.add_edges(4, 2)\n    tree = dgl.graph(tree.edges())\n    # init node feature data\n    tree.ndata[\"x\"] = F.zeros((5, 2))\n    # set all leaf nodes to be ones\n    tree.nodes[[1, 3, 4]].data[\"x\"] = F.ones((3, 2))\n\n    # Filtering DGLWarning:\n    #    The input graph for the user-defined edge\n    #    function does not contain valid edges\n    import warnings\n\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        dgl.prop_nodes_topo(\n            tree, message_func=mfunc, reduce_func=rfunc, apply_node_func=None\n        )\n    # root node get the sum\n    assert F.allclose(tree.nodes[0].data[\"x\"], F.tensor([[3.0, 3.0]]))\n\n\nif __name__ == \"__main__\":\n    test_prop_nodes_bfs()\n    test_prop_edges_dfs()\n    test_prop_nodes_topo()\n"
  },
  {
    "path": "tests/python/common/test_random.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nimport numpy as np\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"GPU random choice not implemented\"\n)\ndef test_random_choice():\n    # test 1\n    a = F.arange(0, 100)\n    x = dgl.random.choice(a, 10, replace=True, prob=None)\n    assert len(x) == 10\n    for i in range(len(x)):\n        assert F.asnumpy(x[i]) >= 0 and F.asnumpy(x[i]) < 100\n    # test 2, replace=False, small num\n    a = F.arange(0, 100)\n    x = dgl.random.choice(a, 10, replace=False, prob=None)\n    assert len(x) == 10\n    for i in range(len(x)):\n        assert F.asnumpy(x[i]) >= 0 and F.asnumpy(x[i]) < 100\n    # test 3, replace=False, large num\n    a = F.arange(0, 100)\n    x = dgl.random.choice(a, 100, replace=False, prob=None)\n    assert len(x) == 100\n    assert np.array_equal(np.sort(F.asnumpy(x)), F.asnumpy(a))\n    # test 4, first arg is integer\n    x = dgl.random.choice(100, 100, replace=False, prob=None)\n    assert len(x) == 100\n    assert np.array_equal(np.sort(F.asnumpy(x)), F.asnumpy(a))\n    # test 5, with prob\n    prob = np.ones((100,))\n    prob[37:40] = 0.0\n    prob -= prob.min()\n    prob /= prob.sum()\n    prob = F.tensor(prob)\n    x = dgl.random.choice(100, 97, replace=False, prob=prob)\n    assert len(x) == 97\n    for i in range(len(x)):\n        assert F.asnumpy(x[i]) < 37 or F.asnumpy(x[i]) >= 40\n\n\nif __name__ == \"__main__\":\n    test_random_choice()\n"
  },
  {
    "path": "tests/python/common/test_readout.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nimport networkx as nx\nimport numpy as np\nimport pytest\nfrom utils import parametrize_idtype\nfrom utils.graph_cases import get_cases\n\n\n@parametrize_idtype\ndef test_sum_case1(idtype):\n    # NOTE: If you want to update this test case, remember to update the docstring\n    #  example too!!!\n    g1 = dgl.graph(([0, 1], [1, 0]), idtype=idtype, device=F.ctx())\n    g1.ndata[\"h\"] = F.tensor([1.0, 2.0])\n    g2 = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g2.ndata[\"h\"] = F.tensor([1.0, 2.0, 3.0])\n    bg = dgl.batch([g1, g2])\n    bg.ndata[\"w\"] = F.tensor([0.1, 0.2, 0.1, 0.5, 0.2])\n    assert F.allclose(F.tensor([3.0]), dgl.sum_nodes(g1, \"h\"))\n    assert F.allclose(F.tensor([3.0, 6.0]), dgl.sum_nodes(bg, \"h\"))\n    assert F.allclose(F.tensor([0.5, 1.7]), dgl.sum_nodes(bg, \"h\", \"w\"))\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"dglgraph\"]))\n@pytest.mark.parametrize(\"reducer\", [\"sum\", \"max\", \"mean\"])\ndef test_reduce_readout(g, idtype, reducer):\n    g = g.astype(idtype).to(F.ctx())\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 3))\n    g.edata[\"h\"] = F.randn((g.num_edges(), 2))\n\n    # Test.1: node readout\n    x = dgl.readout_nodes(g, \"h\", op=reducer)\n    # check correctness\n    subg = dgl.unbatch(g)\n    subx = []\n    for sg in subg:\n        sx = dgl.readout_nodes(sg, \"h\", op=reducer)\n        subx.append(sx)\n    assert F.allclose(x, F.cat(subx, dim=0))\n\n    x = getattr(dgl, \"{}_nodes\".format(reducer))(g, \"h\")\n    # check correctness\n    subg = dgl.unbatch(g)\n    subx = []\n    for sg in subg:\n        sx = getattr(dgl, \"{}_nodes\".format(reducer))(sg, \"h\")\n        subx.append(sx)\n    assert F.allclose(x, F.cat(subx, dim=0))\n\n    # Test.2: edge readout\n    x = dgl.readout_edges(g, \"h\", op=reducer)\n    # check correctness\n    subg = dgl.unbatch(g)\n    subx = []\n    for 
sg in subg:\n        sx = dgl.readout_edges(sg, \"h\", op=reducer)\n        subx.append(sx)\n    assert F.allclose(x, F.cat(subx, dim=0))\n\n    x = getattr(dgl, \"{}_edges\".format(reducer))(g, \"h\")\n    # check correctness\n    subg = dgl.unbatch(g)\n    subx = []\n    for sg in subg:\n        sx = getattr(dgl, \"{}_edges\".format(reducer))(sg, \"h\")\n        subx.append(sx)\n    assert F.allclose(x, F.cat(subx, dim=0))\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"dglgraph\"]))\n@pytest.mark.parametrize(\"reducer\", [\"sum\", \"max\", \"mean\"])\ndef test_weighted_reduce_readout(g, idtype, reducer):\n    g = g.astype(idtype).to(F.ctx())\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 3))\n    g.ndata[\"w\"] = F.randn((g.num_nodes(), 1))\n    g.edata[\"h\"] = F.randn((g.num_edges(), 2))\n    g.edata[\"w\"] = F.randn((g.num_edges(), 1))\n\n    # Test.1: node readout\n    x = dgl.readout_nodes(g, \"h\", \"w\", op=reducer)\n    # check correctness\n    subg = dgl.unbatch(g)\n    subx = []\n    for sg in subg:\n        sx = dgl.readout_nodes(sg, \"h\", \"w\", op=reducer)\n        subx.append(sx)\n    assert F.allclose(x, F.cat(subx, dim=0))\n\n    x = getattr(dgl, \"{}_nodes\".format(reducer))(g, \"h\", \"w\")\n    # check correctness\n    subg = dgl.unbatch(g)\n    subx = []\n    for sg in subg:\n        sx = getattr(dgl, \"{}_nodes\".format(reducer))(sg, \"h\", \"w\")\n        subx.append(sx)\n    assert F.allclose(x, F.cat(subx, dim=0))\n\n    # Test.2: edge readout\n    x = dgl.readout_edges(g, \"h\", \"w\", op=reducer)\n    # check correctness\n    subg = dgl.unbatch(g)\n    subx = []\n    for sg in subg:\n        sx = dgl.readout_edges(sg, \"h\", \"w\", op=reducer)\n        subx.append(sx)\n    assert F.allclose(x, F.cat(subx, dim=0))\n\n    x = getattr(dgl, \"{}_edges\".format(reducer))(g, \"h\", \"w\")\n    # check correctness\n    subg = dgl.unbatch(g)\n    subx = []\n    for sg in subg:\n        sx = 
getattr(dgl, \"{}_edges\".format(reducer))(sg, \"h\", \"w\")\n        subx.append(sx)\n    assert F.allclose(x, F.cat(subx, dim=0))\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"dglgraph\"]))\n@pytest.mark.parametrize(\"descending\", [True, False])\ndef test_topk(g, idtype, descending):\n    g = g.astype(idtype).to(F.ctx())\n    g.ndata[\"x\"] = F.randn((g.num_nodes(), 3))\n\n    # Test.1: to test the case where k > number of nodes.\n    dgl.topk_nodes(g, \"x\", 100, sortby=-1)\n\n    # Test.2: test correctness\n    min_nnodes = F.asnumpy(g.batch_num_nodes()).min()\n    if min_nnodes <= 1:\n        return\n    k = min_nnodes - 1\n    val, indices = dgl.topk_nodes(g, \"x\", k, descending=descending, sortby=-1)\n    print(k)\n    print(g.ndata[\"x\"])\n    print(\"val\", val)\n    print(\"indices\", indices)\n    subg = dgl.unbatch(g)\n    subval, subidx = [], []\n    for sg in subg:\n        subx = F.asnumpy(sg.ndata[\"x\"])\n        ai = np.argsort(subx[:, -1:].flatten())\n        if descending:\n            ai = np.ascontiguousarray(ai[::-1])\n        subx = np.expand_dims(subx[ai[:k]], 0)\n        subval.append(F.tensor(subx))\n        subidx.append(F.tensor(np.expand_dims(ai[:k], 0)))\n    print(F.cat(subval, dim=0))\n    assert F.allclose(val, F.cat(subval, dim=0))\n    assert F.allclose(indices, F.cat(subidx, dim=0))\n\n    # Test.3: sorby=None\n    dgl.topk_nodes(g, \"x\", k, sortby=None)\n\n    g.edata[\"x\"] = F.randn((g.num_edges(), 3))\n\n    # Test.4: topk edges where k > number of edges.\n    dgl.topk_edges(g, \"x\", 100, sortby=-1)\n\n    # Test.5: topk edges test correctness\n    min_nedges = F.asnumpy(g.batch_num_edges()).min()\n    if min_nedges <= 1:\n        return\n    k = min_nedges - 1\n    val, indices = dgl.topk_edges(g, \"x\", k, descending=descending, sortby=-1)\n    print(k)\n    print(g.edata[\"x\"])\n    print(\"val\", val)\n    print(\"indices\", indices)\n    subg = dgl.unbatch(g)\n    
subval, subidx = [], []\n    for sg in subg:\n        subx = F.asnumpy(sg.edata[\"x\"])\n        ai = np.argsort(subx[:, -1:].flatten())\n        if descending:\n            ai = np.ascontiguousarray(ai[::-1])\n        subx = np.expand_dims(subx[ai[:k]], 0)\n        subval.append(F.tensor(subx))\n        subidx.append(F.tensor(np.expand_dims(ai[:k], 0)))\n    print(F.cat(subval, dim=0))\n    assert F.allclose(val, F.cat(subval, dim=0))\n    assert F.allclose(indices, F.cat(subidx, dim=0))\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"dglgraph\"]))\ndef test_softmax(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 3))\n    g.edata[\"h\"] = F.randn((g.num_edges(), 2))\n\n    # Test.1: node readout\n    x = dgl.softmax_nodes(g, \"h\")\n    subg = dgl.unbatch(g)\n    subx = []\n    for sg in subg:\n        subx.append(F.softmax(sg.ndata[\"h\"], dim=0))\n    assert F.allclose(x, F.cat(subx, dim=0))\n\n    # Test.2: edge readout\n    x = dgl.softmax_edges(g, \"h\")\n    subg = dgl.unbatch(g)\n    subx = []\n    for sg in subg:\n        subx.append(F.softmax(sg.edata[\"h\"], dim=0))\n    assert F.allclose(x, F.cat(subx, dim=0))\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"dglgraph\"]))\ndef test_broadcast(idtype, g):\n    g = g.astype(idtype).to(F.ctx())\n    gfeat = F.randn((g.batch_size, 3))\n\n    # Test.0: broadcast_nodes\n    g.ndata[\"h\"] = dgl.broadcast_nodes(g, gfeat)\n    subg = dgl.unbatch(g)\n    for i, sg in enumerate(subg):\n        assert F.allclose(\n            sg.ndata[\"h\"],\n            F.repeat(F.reshape(gfeat[i], (1, 3)), sg.num_nodes(), dim=0),\n        )\n\n    # Test.1: broadcast_edges\n    g.edata[\"h\"] = dgl.broadcast_edges(g, gfeat)\n    subg = dgl.unbatch(g)\n    for i, sg in enumerate(subg):\n        assert F.allclose(\n            sg.edata[\"h\"],\n            F.repeat(F.reshape(gfeat[i], (1, 3)), 
sg.num_edges(), dim=0),\n        )\n"
  },
  {
    "path": "tests/python/common/test_sparse_ops-csr.py",
    "content": "import backend as F\n\nimport dgl\nimport numpy as np\nimport pytest\nimport scipy.sparse as ssp\nfrom utils import parametrize_idtype\n\nif F.backend_name == \"pytorch\":\n    import torch\n\n    torch.backends.cuda.matmul.allow_tf32 = False\n\n\ndef _random_simple_graph(\n    idtype, dtype, ctx, M, N, max_nnz, srctype, dsttype, etype\n):\n    src = np.random.randint(0, M, (max_nnz,))\n    dst = np.random.randint(0, N, (max_nnz,))\n    val = np.random.randn(max_nnz)\n    a = ssp.csr_matrix((val, (src, dst)), shape=(M, N))\n    a.sum_duplicates()\n    a = a.tocoo()\n    # shuffle edges\n    perm = np.random.permutation(a.nnz)\n    row = a.row[perm]\n    col = a.col[perm]\n    val = a.data[perm]\n    a = ssp.csr_matrix((val, (row, col)), shape=(M, N))\n\n    A = dgl.heterograph(\n        {\n            (srctype, etype, dsttype): (\n                F.copy_to(F.tensor(row, dtype=idtype), ctx),\n                F.copy_to(F.tensor(col, dtype=idtype), ctx),\n            )\n        },\n        num_nodes_dict={srctype: a.shape[0], dsttype: a.shape[1]},\n    )\n    A.edata[\"w\"] = F.copy_to(F.tensor(val, dtype=dtype), ctx)\n    return a, A\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"dtype\", [F.float32, F.float64])\n@pytest.mark.parametrize(\"return_edge_ids\", [True, False])\ndef test_csrmm(idtype, dtype, return_edge_ids):\n    a, A = _random_simple_graph(\n        idtype, dtype, F.ctx(), 500, 600, 9000, \"A\", \"B\", \"AB\"\n    )\n    b, B = _random_simple_graph(\n        idtype, dtype, F.ctx(), 600, 700, 9000, \"B\", \"C\", \"BC\"\n    )\n    C, C_weights = dgl._sparse_ops._csrmm(\n        A._graph, A.edata[\"w\"], B._graph, B.edata[\"w\"], 2\n    )\n    C_adj = C.adjacency_matrix_scipy(0, False, \"csr\", return_edge_ids)\n    C_adj.data = F.asnumpy(C_weights)\n    C_adj = F.tensor(C_adj.todense(), dtype=dtype)\n    c = F.tensor((a * b).todense(), dtype=dtype)\n    assert F.allclose(C_adj, 
c)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"dtype\", [F.float32, F.float64])\n@pytest.mark.parametrize(\"num_vtypes\", [1, 2])\ndef test_csrmm_backward(idtype, dtype, num_vtypes):\n    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, \"A\", \"B\", \"AB\")\n    b, B = _random_simple_graph(\n        idtype,\n        dtype,\n        F.ctx(),\n        4,\n        3,\n        6,\n        \"B\",\n        \"A\" if num_vtypes == 1 else \"C\",\n        \"BA\",\n    )\n    A_row, A_col = A.edges(order=\"eid\")\n    B_row, B_col = B.edges(order=\"eid\")\n    A_row = F.asnumpy(A_row)\n    A_col = F.asnumpy(A_col)\n    B_row = F.asnumpy(B_row)\n    B_col = F.asnumpy(B_col)\n    a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype))\n    b_dense = F.attach_grad(F.tensor(b.todense(), dtype=dtype))\n\n    A.edata[\"w\"] = F.attach_grad(A.edata[\"w\"])\n    B.edata[\"w\"] = F.attach_grad(B.edata[\"w\"])\n\n    with F.record_grad():\n        C = dgl.adj_product_graph(A, B, \"w\")\n        assert len(C.ntypes) == num_vtypes\n        assert len(C.etypes) == 1\n        C_dense = np.zeros((3, 3))\n        C_row, C_col = C.edges(order=\"eid\")\n        C_row = F.asnumpy(C_row)\n        C_col = F.asnumpy(C_col)\n        C_dense[C_row, C_col] = F.asnumpy(C.edata[\"w\"])\n        c_dense = F.matmul(a_dense, b_dense)\n        assert np.allclose(C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4)\n\n        F.backward(F.reduce_sum(C.edata[\"w\"]) + F.reduce_sum(c_dense))\n        a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]\n        b_dense_grad = F.asnumpy(F.grad(b_dense))[B_row, B_col]\n        A_spspmm_grad = F.asnumpy(F.grad(A.edata[\"w\"]))\n        B_spspmm_grad = F.asnumpy(F.grad(B.edata[\"w\"]))\n        assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4)\n        assert np.allclose(b_dense_grad, B_spspmm_grad, rtol=1e-4, atol=1e-4)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"dtype\", [F.float32, 
F.float64])\n@pytest.mark.parametrize(\"return_edge_ids\", [True, False])\ndef test_csrsum(idtype, dtype, return_edge_ids):\n    a, A = _random_simple_graph(\n        idtype, dtype, F.ctx(), 500, 600, 9000, \"A\", \"B\", \"AB\"\n    )\n    b, B = _random_simple_graph(\n        idtype, dtype, F.ctx(), 500, 600, 9000, \"A\", \"B\", \"AB\"\n    )\n    C, C_weights = dgl._sparse_ops._csrsum(\n        [A._graph, B._graph], [A.edata[\"w\"], B.edata[\"w\"]]\n    )\n    C_adj = C.adjacency_matrix_scipy(0, False, \"csr\", return_edge_ids)\n    C_adj.data = F.asnumpy(C_weights)\n    C_adj = F.tensor(C_adj.todense(), dtype=dtype)\n    c = F.tensor((a + b).todense(), dtype=dtype)\n    assert F.allclose(C_adj, c)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"dtype\", [F.float32, F.float64])\n@pytest.mark.parametrize(\"nelems\", [1, 2])\ndef test_csrsum_backward(idtype, dtype, nelems):\n    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, \"A\", \"B\", \"AB\")\n    b, B = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, \"A\", \"B\", \"AB\")\n    A_row, A_col = A.edges(order=\"eid\")\n    B_row, B_col = B.edges(order=\"eid\")\n    A_row = F.asnumpy(A_row)\n    A_col = F.asnumpy(A_col)\n    B_row = F.asnumpy(B_row)\n    B_col = F.asnumpy(B_col)\n    a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype))\n    b_dense = F.attach_grad(F.tensor(b.todense(), dtype=dtype))\n\n    A.edata[\"w\"] = F.attach_grad(A.edata[\"w\"])\n    B.edata[\"w\"] = F.attach_grad(B.edata[\"w\"])\n\n    with F.record_grad():\n        if nelems == 2:\n            # Test for two element case\n            C = dgl.adj_sum_graph([A, B], \"w\")\n            assert C.canonical_etypes == A.canonical_etypes\n            C_dense = np.zeros((3, 4))\n            C_row, C_col = C.edges(order=\"eid\")\n            C_row = F.asnumpy(C_row)\n            C_col = F.asnumpy(C_col)\n            C_dense[C_row, C_col] = F.asnumpy(C.edata[\"w\"])\n            c_dense = a_dense + b_dense\n            
assert np.allclose(\n                C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4\n            )\n\n            F.backward(F.reduce_sum(C.edata[\"w\"]) + F.reduce_sum(c_dense))\n            a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]\n            b_dense_grad = F.asnumpy(F.grad(b_dense))[B_row, B_col]\n            A_spspmm_grad = F.asnumpy(F.grad(A.edata[\"w\"]))\n            B_spspmm_grad = F.asnumpy(F.grad(B.edata[\"w\"]))\n            assert np.allclose(\n                a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4\n            )\n            assert np.allclose(\n                b_dense_grad, B_spspmm_grad, rtol=1e-4, atol=1e-4\n            )\n        elif nelems == 1:\n            # Test for single element case\n            C = dgl.adj_sum_graph([A], \"w\")\n            assert C.canonical_etypes == A.canonical_etypes\n            C_dense = np.zeros((3, 4))\n            C_row, C_col = C.edges(order=\"eid\")\n            C_row = F.asnumpy(C_row)\n            C_col = F.asnumpy(C_col)\n            C_dense[C_row, C_col] = F.asnumpy(C.edata[\"w\"])\n            c_dense = a_dense\n            assert np.allclose(\n                C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4\n            )\n\n            F.backward(F.reduce_sum(C.edata[\"w\"]) + F.reduce_sum(c_dense))\n            a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]\n            A_spspmm_grad = F.asnumpy(F.grad(A.edata[\"w\"]))\n            assert np.allclose(\n                a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4\n            )\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"dtype\", [F.float32, F.float64])\n@pytest.mark.parametrize(\"A_nnz\", [9000, 0])\n@pytest.mark.parametrize(\"B_nnz\", [9000, 0])\ndef test_csrmask(idtype, dtype, A_nnz, B_nnz):\n    a, A = _random_simple_graph(\n        idtype, dtype, F.ctx(), 500, 600, A_nnz, \"A\", \"B\", \"AB\"\n    )\n    b, B = _random_simple_graph(\n        idtype, dtype, F.ctx(), 500, 600, B_nnz, \"A\", \"B\", 
\"AB\"\n    )\n    C = dgl._sparse_ops._csrmask(A._graph, A.edata[\"w\"], B._graph)\n    B_row, B_col = B.edges(order=\"eid\")\n    B_row = F.asnumpy(B_row)\n    B_col = F.asnumpy(B_col)\n    c = F.tensor(a.todense()[B_row, B_col], dtype)\n    assert F.allclose(C, c)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"dtype\", [F.float32, F.float64])\ndef test_csrmask_backward(idtype, dtype):\n    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, \"A\", \"B\", \"AB\")\n    b, B = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, \"A\", \"B\", \"AB\")\n    A_row, A_col = A.edges(order=\"eid\")\n    B_row, B_col = B.edges(order=\"eid\")\n    A_row = F.asnumpy(A_row)\n    A_col = F.asnumpy(A_col)\n    B_row = F.asnumpy(B_row)\n    B_col = F.asnumpy(B_col)\n    a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype))\n\n    A.edata[\"w\"] = F.attach_grad(A.edata[\"w\"])\n\n    with F.record_grad():\n        # Test for two element case\n        C1 = F.csrmask(A._graph, A.edata[\"w\"], B._graph)\n        if dgl.backend.backend_name == \"tensorflow\":\n            import tensorflow as tf\n\n            C2 = tf.gather_nd(a_dense, tf.stack([B_row, B_col], 1))\n        else:\n            C2 = a_dense[B_row, B_col]\n        assert F.allclose(C1, C2, rtol=1e-4, atol=1e-4)\n\n        F.backward(F.reduce_sum(C1) + F.reduce_sum(C2))\n        a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]\n        A_spspmm_grad = F.asnumpy(F.grad(A.edata[\"w\"]))\n        assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    test_csrmm(F.int32, F.float32)\n    test_csrmm(F.int64, F.float32)\n    test_csrsum(F.int32, F.float32)\n    test_csrsum(F.int64, F.float32)\n    test_csrmask(F.int32, F.float32, 9000, 9000)\n    test_csrmask(F.int64, F.float32, 9000, 0)\n    test_csrmask(F.int32, F.float32, 0, 9000)\n    test_csrmask(F.int64, F.float32, 0, 0)\n    test_csrmm_backward(F.int32, F.float32, 1)\n    
test_csrmm_backward(F.int64, F.float32, 1)\n    test_csrmm_backward(F.int32, F.float32, 2)\n    test_csrmm_backward(F.int64, F.float32, 2)\n    test_csrsum_backward(F.int32, F.float32, 1)\n    test_csrsum_backward(F.int64, F.float32, 1)\n    test_csrsum_backward(F.int32, F.float32, 2)\n    test_csrsum_backward(F.int64, F.float32, 2)\n    test_csrmask_backward(F.int32, F.float32)\n    test_csrmask_backward(F.int64, F.float32)\n"
  },
  {
    "path": "tests/python/common/test_subgraph.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nimport networkx as nx\nimport numpy as np\nimport pytest\nimport scipy.sparse as ssp\nfrom utils import parametrize_idtype\n\nD = 5\n\n\ndef generate_graph(grad=False, add_data=True):\n    g = dgl.graph([]).to(F.ctx())\n    g.add_nodes(10)\n    # create a graph where 0 is the source and 9 is the sink\n    for i in range(1, 9):\n        g.add_edges(0, i)\n        g.add_edges(i, 9)\n    # add a back flow from 9 to 0\n    g.add_edges(9, 0)\n    if add_data:\n        ncol = F.randn((10, D))\n        ecol = F.randn((17, D))\n        if grad:\n            ncol = F.attach_grad(ncol)\n            ecol = F.attach_grad(ecol)\n        g.ndata[\"h\"] = ncol\n        g.edata[\"l\"] = ecol\n    return g\n\n\ndef test_edge_subgraph():\n    # Test when the graph has no node data and edge data.\n    g = generate_graph(add_data=False)\n    eid = [0, 2, 3, 6, 7, 9]\n\n    # relabel=True\n    sg = g.edge_subgraph(eid)\n    assert F.array_equal(\n        sg.ndata[dgl.NID], F.tensor([0, 2, 4, 5, 1, 9], g.idtype)\n    )\n    assert F.array_equal(sg.edata[dgl.EID], F.tensor(eid, g.idtype))\n    sg.ndata[\"h\"] = F.arange(0, sg.num_nodes())\n    sg.edata[\"h\"] = F.arange(0, sg.num_edges())\n\n    # relabel=False\n    sg = g.edge_subgraph(eid, relabel_nodes=False)\n    assert g.num_nodes() == sg.num_nodes()\n    assert F.array_equal(sg.edata[dgl.EID], F.tensor(eid, g.idtype))\n    sg.ndata[\"h\"] = F.arange(0, sg.num_nodes())\n    sg.edata[\"h\"] = F.arange(0, sg.num_edges())\n\n\n@pytest.mark.parametrize(\"relabel_nodes\", [True, False])\ndef test_subgraph_relabel_nodes(relabel_nodes):\n    g = generate_graph()\n    h = g.ndata[\"h\"]\n    l = g.edata[\"l\"]\n    nid = [0, 2, 3, 6, 7, 9]\n    sg = g.subgraph(nid, relabel_nodes=relabel_nodes)\n    eid = {2, 3, 4, 5, 10, 11, 12, 13, 16}\n    assert set(F.asnumpy(sg.edata[dgl.EID])) == eid\n    eid = sg.edata[dgl.EID]\n    # the subgraph is empty initially except for EID 
field\n    # the subgraph is empty initially except for NID field if relabel_nodes\n    if relabel_nodes:\n        assert len(sg.ndata) == 2\n    assert len(sg.edata) == 2\n    sh = sg.ndata[\"h\"]\n    # The node number is not reduced if relabel_node=False.\n    # The subgraph keeps the same node information as the original graph.\n    if relabel_nodes:\n        assert F.allclose(F.gather_row(h, F.tensor(nid)), sh)\n    else:\n        assert F.allclose(\n            F.gather_row(h, F.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])), sh\n        )\n    # The s,d,eid means the source node, destination node and edge id of the subgraph.\n    # The edges labeled 1 are those selected by the subgraph.\n    \"\"\"\n    s, d, eid\n    0, 1, 0\n    1, 9, 1\n    0, 2, 2    1\n    2, 9, 3    1\n    0, 3, 4    1\n    3, 9, 5    1\n    0, 4, 6\n    4, 9, 7\n    0, 5, 8\n    5, 9, 9       3\n    0, 6, 10   1\n    6, 9, 11   1  3\n    0, 7, 12   1\n    7, 9, 13   1  3\n    0, 8, 14\n    8, 9, 15      3\n    9, 0, 16   1\n    \"\"\"\n    assert F.allclose(F.gather_row(l, eid), sg.edata[\"l\"])\n    # update the node/edge features on the subgraph should NOT\n    # reflect to the parent graph.\n    if relabel_nodes:\n        sg.ndata[\"h\"] = F.zeros((6, D))\n    else:\n        sg.ndata[\"h\"] = F.zeros((10, D))\n    assert F.allclose(h, g.ndata[\"h\"])\n\n\ndef _test_map_to_subgraph():\n    g = dgl.graph([])\n    g.add_nodes(10)\n    g.add_edges(F.arange(0, 9), F.arange(1, 10))\n    h = g.subgraph([0, 1, 2, 5, 8])\n    v = h.map_to_subgraph_nid([0, 8, 2])\n    assert np.array_equal(F.asnumpy(v), np.array([0, 4, 2]))\n\n\ndef create_test_heterograph(idtype):\n    # test heterograph from the docstring, plus a user -- wishes -- game relation\n    # 3 users, 2 games, 2 developers\n    # metagraph:\n    #    ('user', 'follows', 'user'),\n    #    ('user', 'plays', 'game'),\n    #    ('user', 'wishes', 'game'),\n    #    ('developer', 'develops', 'game')])\n\n    g = dgl.heterograph(\n        {\n 
           (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): ([0, 2], [1, 0]),\n            (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    for etype in g.etypes:\n        g.edges[etype].data[\"weight\"] = F.randn((g.num_edges(etype),))\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    return g\n\n\ndef create_test_heterograph2(idtype):\n    \"\"\"test heterograph from the docstring, with an empty relation\"\"\"\n    # 3 users, 2 games, 2 developers\n    # metagraph:\n    #    ('user', 'follows', 'user'),\n    #    ('user', 'plays', 'game'),\n    #    ('user', 'wishes', 'game'),\n    #    ('developer', 'develops', 'game')\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 1], [0, 0, 1, 1]),\n            (\"user\", \"wishes\", \"game\"): ([0, 2], [1, 0]),\n            (\"developer\", \"develops\", \"game\"): ([], []),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    for etype in g.etypes:\n        g.edges[etype].data[\"weight\"] = F.randn((g.num_edges(etype),))\n    assert g.idtype == idtype\n    assert g.device == F.ctx()\n    return g\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"mxnet\",\n    reason=\"MXNet doesn't support bool tensor\",\n)\n@parametrize_idtype\ndef test_subgraph_mask(idtype):\n    g = create_test_heterograph(idtype)\n    g_graph = g[\"follows\"]\n    g_bipartite = g[\"plays\"]\n\n    x = F.randn((3, 5))\n    y = F.randn((2, 4))\n    g.nodes[\"user\"].data[\"h\"] = x\n    g.edges[\"follows\"].data[\"h\"] = y\n\n    def _check_subgraph(g, sg):\n        assert sg.idtype == g.idtype\n        assert sg.device == g.device\n        assert sg.ntypes == g.ntypes\n        assert 
sg.etypes == g.etypes\n        assert sg.canonical_etypes == g.canonical_etypes\n        assert F.array_equal(\n            F.tensor(sg.nodes[\"user\"].data[dgl.NID]), F.tensor([1, 2], idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.nodes[\"game\"].data[dgl.NID]), F.tensor([0], idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"follows\"].data[dgl.EID]), F.tensor([1], idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"plays\"].data[dgl.EID]), F.tensor([1], idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"wishes\"].data[dgl.EID]), F.tensor([1], idtype)\n        )\n        assert sg.num_nodes(\"developer\") == 0\n        assert sg.num_edges(\"develops\") == 0\n        assert F.array_equal(\n            sg.nodes[\"user\"].data[\"h\"], g.nodes[\"user\"].data[\"h\"][1:3]\n        )\n        assert F.array_equal(\n            sg.edges[\"follows\"].data[\"h\"], g.edges[\"follows\"].data[\"h\"][1:2]\n        )\n\n    sg1 = g.subgraph(\n        {\n            \"user\": F.tensor([False, True, True], dtype=F.bool),\n            \"game\": F.tensor([True, False, False, False], dtype=F.bool),\n        }\n    )\n    _check_subgraph(g, sg1)\n    sg2 = g.edge_subgraph(\n        {\n            \"follows\": F.tensor([False, True], dtype=F.bool),\n            \"plays\": F.tensor([False, True, False, False], dtype=F.bool),\n            \"wishes\": F.tensor([False, True], dtype=F.bool),\n        }\n    )\n    _check_subgraph(g, sg2)\n\n\n@parametrize_idtype\ndef test_subgraph1(idtype):\n    g = create_test_heterograph(idtype)\n    g_graph = g[\"follows\"]\n    g_bipartite = g[\"plays\"]\n\n    x = F.randn((3, 5))\n    y = F.randn((2, 4))\n    g.nodes[\"user\"].data[\"h\"] = x\n    g.edges[\"follows\"].data[\"h\"] = y\n\n    def _check_subgraph(g, sg):\n        assert sg.idtype == g.idtype\n        assert sg.device == g.device\n        assert sg.ntypes == 
g.ntypes\n        assert sg.etypes == g.etypes\n        assert sg.canonical_etypes == g.canonical_etypes\n        assert F.array_equal(\n            F.tensor(sg.nodes[\"user\"].data[dgl.NID]), F.tensor([1, 2], g.idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.nodes[\"game\"].data[dgl.NID]), F.tensor([0], g.idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"follows\"].data[dgl.EID]), F.tensor([1], g.idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"plays\"].data[dgl.EID]), F.tensor([1], g.idtype)\n        )\n        assert F.array_equal(\n            F.tensor(sg.edges[\"wishes\"].data[dgl.EID]), F.tensor([1], g.idtype)\n        )\n        assert sg.num_nodes(\"developer\") == 0\n        assert sg.num_edges(\"develops\") == 0\n        assert F.array_equal(\n            sg.nodes[\"user\"].data[\"h\"], g.nodes[\"user\"].data[\"h\"][1:3]\n        )\n        assert F.array_equal(\n            sg.edges[\"follows\"].data[\"h\"], g.edges[\"follows\"].data[\"h\"][1:2]\n        )\n\n    sg1 = g.subgraph({\"user\": [1, 2], \"game\": [0]})\n    _check_subgraph(g, sg1)\n    sg2 = g.edge_subgraph({\"follows\": [1], \"plays\": [1], \"wishes\": [1]})\n    _check_subgraph(g, sg2)\n\n    # backend tensor input\n    sg1 = g.subgraph(\n        {\n            \"user\": F.tensor([1, 2], dtype=idtype),\n            \"game\": F.tensor([0], dtype=idtype),\n        }\n    )\n    _check_subgraph(g, sg1)\n    sg2 = g.edge_subgraph(\n        {\n            \"follows\": F.tensor([1], dtype=idtype),\n            \"plays\": F.tensor([1], dtype=idtype),\n            \"wishes\": F.tensor([1], dtype=idtype),\n        }\n    )\n    _check_subgraph(g, sg2)\n\n    # numpy input\n    sg1 = g.subgraph({\"user\": np.array([1, 2]), \"game\": np.array([0])})\n    _check_subgraph(g, sg1)\n    sg2 = g.edge_subgraph(\n        {\n            \"follows\": np.array([1]),\n            \"plays\": np.array([1]),\n            
\"wishes\": np.array([1]),\n        }\n    )\n    _check_subgraph(g, sg2)\n\n    def _check_subgraph_single_ntype(g, sg, preserve_nodes=False):\n        assert sg.idtype == g.idtype\n        assert sg.device == g.device\n        assert sg.ntypes == g.ntypes\n        assert sg.etypes == g.etypes\n        assert sg.canonical_etypes == g.canonical_etypes\n\n        if not preserve_nodes:\n            assert F.array_equal(\n                F.tensor(sg.nodes[\"user\"].data[dgl.NID]),\n                F.tensor([1, 2], g.idtype),\n            )\n        else:\n            for ntype in sg.ntypes:\n                assert g.num_nodes(ntype) == sg.num_nodes(ntype)\n\n        assert F.array_equal(\n            F.tensor(sg.edges[\"follows\"].data[dgl.EID]), F.tensor([1], g.idtype)\n        )\n\n        if not preserve_nodes:\n            assert F.array_equal(\n                sg.nodes[\"user\"].data[\"h\"], g.nodes[\"user\"].data[\"h\"][1:3]\n            )\n        assert F.array_equal(\n            sg.edges[\"follows\"].data[\"h\"], g.edges[\"follows\"].data[\"h\"][1:2]\n        )\n\n    def _check_subgraph_single_etype(g, sg, preserve_nodes=False):\n        assert sg.ntypes == g.ntypes\n        assert sg.etypes == g.etypes\n        assert sg.canonical_etypes == g.canonical_etypes\n\n        if not preserve_nodes:\n            assert F.array_equal(\n                F.tensor(sg.nodes[\"user\"].data[dgl.NID]),\n                F.tensor([0, 1], g.idtype),\n            )\n            assert F.array_equal(\n                F.tensor(sg.nodes[\"game\"].data[dgl.NID]),\n                F.tensor([0], g.idtype),\n            )\n        else:\n            for ntype in sg.ntypes:\n                assert g.num_nodes(ntype) == sg.num_nodes(ntype)\n\n        assert F.array_equal(\n            F.tensor(sg.edges[\"plays\"].data[dgl.EID]),\n            F.tensor([0, 1], g.idtype),\n        )\n\n    sg1_graph = g_graph.subgraph([1, 2])\n    _check_subgraph_single_ntype(g_graph, sg1_graph)\n    
sg1_graph = g_graph.edge_subgraph([1])\n    _check_subgraph_single_ntype(g_graph, sg1_graph)\n    sg1_graph = g_graph.edge_subgraph([1], relabel_nodes=False)\n    _check_subgraph_single_ntype(g_graph, sg1_graph, True)\n    sg2_bipartite = g_bipartite.edge_subgraph([0, 1])\n    _check_subgraph_single_etype(g_bipartite, sg2_bipartite)\n    sg2_bipartite = g_bipartite.edge_subgraph([0, 1], relabel_nodes=False)\n    _check_subgraph_single_etype(g_bipartite, sg2_bipartite, True)\n\n    def _check_typed_subgraph1(g, sg):\n        assert g.idtype == sg.idtype\n        assert g.device == sg.device\n        assert set(sg.ntypes) == {\"user\", \"game\"}\n        assert set(sg.etypes) == {\"follows\", \"plays\", \"wishes\"}\n        for ntype in sg.ntypes:\n            assert sg.num_nodes(ntype) == g.num_nodes(ntype)\n        for etype in sg.etypes:\n            src_sg, dst_sg = sg.all_edges(etype=etype, order=\"eid\")\n            src_g, dst_g = g.all_edges(etype=etype, order=\"eid\")\n            assert F.array_equal(src_sg, src_g)\n            assert F.array_equal(dst_sg, dst_g)\n        assert F.array_equal(\n            sg.nodes[\"user\"].data[\"h\"], g.nodes[\"user\"].data[\"h\"]\n        )\n        assert F.array_equal(\n            sg.edges[\"follows\"].data[\"h\"], g.edges[\"follows\"].data[\"h\"]\n        )\n        g.nodes[\"user\"].data[\"h\"] = F.scatter_row(\n            g.nodes[\"user\"].data[\"h\"], F.tensor([2]), F.randn((1, 5))\n        )\n        g.edges[\"follows\"].data[\"h\"] = F.scatter_row(\n            g.edges[\"follows\"].data[\"h\"], F.tensor([1]), F.randn((1, 4))\n        )\n        assert F.array_equal(\n            sg.nodes[\"user\"].data[\"h\"], g.nodes[\"user\"].data[\"h\"]\n        )\n        assert F.array_equal(\n            sg.edges[\"follows\"].data[\"h\"], g.edges[\"follows\"].data[\"h\"]\n        )\n\n    def _check_typed_subgraph2(g, sg):\n        assert set(sg.ntypes) == {\"developer\", \"game\"}\n        assert set(sg.etypes) == 
{\"develops\"}\n        for ntype in sg.ntypes:\n            assert sg.num_nodes(ntype) == g.num_nodes(ntype)\n        for etype in sg.etypes:\n            src_sg, dst_sg = sg.all_edges(etype=etype, order=\"eid\")\n            src_g, dst_g = g.all_edges(etype=etype, order=\"eid\")\n            assert F.array_equal(src_sg, src_g)\n            assert F.array_equal(dst_sg, dst_g)\n\n    sg3 = g.node_type_subgraph([\"user\", \"game\"])\n    _check_typed_subgraph1(g, sg3)\n    sg4 = g.edge_type_subgraph([\"develops\"])\n    _check_typed_subgraph2(g, sg4)\n    sg5 = g.edge_type_subgraph([\"follows\", \"plays\", \"wishes\"])\n    _check_typed_subgraph1(g, sg5)\n\n    # Test for restricted format\n    for fmt in [\"csr\", \"csc\", \"coo\"]:\n        g = dgl.graph(([0, 1], [1, 2])).formats(fmt)\n        sg = g.subgraph({g.ntypes[0]: [1, 0]})\n        nids = F.asnumpy(sg.ndata[dgl.NID])\n        assert np.array_equal(nids, np.array([1, 0]))\n        src, dst = sg.edges(order=\"eid\")\n        src = F.asnumpy(src)\n        dst = F.asnumpy(dst)\n        assert np.array_equal(src, np.array([1]))\n\n\n@parametrize_idtype\ndef test_in_subgraph(idtype):\n    hg = dgl.heterograph(\n        {\n            (\"user\", \"follow\", \"user\"): (\n                [1, 2, 3, 0, 2, 3, 0],\n                [0, 0, 0, 1, 1, 1, 2],\n            ),\n            (\"user\", \"play\", \"game\"): ([0, 0, 1, 3], [0, 1, 2, 2]),\n            (\"game\", \"liked-by\", \"user\"): (\n                [2, 2, 2, 1, 1, 0],\n                [0, 1, 2, 0, 3, 0],\n            ),\n            (\"user\", \"flips\", \"coin\"): ([0, 1, 2, 3], [0, 0, 0, 0]),\n        },\n        idtype=idtype,\n        num_nodes_dict={\"user\": 5, \"game\": 10, \"coin\": 8},\n    ).to(F.ctx())\n    subg = dgl.in_subgraph(hg, {\"user\": [0, 1], \"game\": 0})\n    assert subg.idtype == idtype\n    assert len(subg.ntypes) == 3\n    assert len(subg.etypes) == 4\n    u, v = subg[\"follow\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), 
list(F.asnumpy(v))))\n    assert F.array_equal(\n        hg[\"follow\"].edge_ids(u, v), subg[\"follow\"].edata[dgl.EID]\n    )\n    assert edge_set == {(1, 0), (2, 0), (3, 0), (0, 1), (2, 1), (3, 1)}\n    u, v = subg[\"play\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert F.array_equal(hg[\"play\"].edge_ids(u, v), subg[\"play\"].edata[dgl.EID])\n    assert edge_set == {(0, 0)}\n    u, v = subg[\"liked-by\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert F.array_equal(\n        hg[\"liked-by\"].edge_ids(u, v), subg[\"liked-by\"].edata[dgl.EID]\n    )\n    assert edge_set == {(2, 0), (2, 1), (1, 0), (0, 0)}\n    assert subg[\"flips\"].num_edges() == 0\n    for ntype in subg.ntypes:\n        assert dgl.NID not in subg.nodes[ntype].data\n\n    # Test store_ids\n    subg = dgl.in_subgraph(hg, {\"user\": [0, 1], \"game\": 0}, store_ids=False)\n    for etype in [\"follow\", \"play\", \"liked-by\"]:\n        assert dgl.EID not in subg.edges[etype].data\n    for ntype in subg.ntypes:\n        assert dgl.NID not in subg.nodes[ntype].data\n\n    # Test relabel nodes\n    subg = dgl.in_subgraph(hg, {\"user\": [0, 1], \"game\": 0}, relabel_nodes=True)\n    assert subg.idtype == idtype\n    assert len(subg.ntypes) == 3\n    assert len(subg.etypes) == 4\n\n    u, v = subg[\"follow\"].edges()\n    old_u = F.gather_row(subg.nodes[\"user\"].data[dgl.NID], u)\n    old_v = F.gather_row(subg.nodes[\"user\"].data[dgl.NID], v)\n    assert F.array_equal(\n        hg[\"follow\"].edge_ids(old_u, old_v), subg[\"follow\"].edata[dgl.EID]\n    )\n    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))\n    assert edge_set == {(1, 0), (2, 0), (3, 0), (0, 1), (2, 1), (3, 1)}\n\n    u, v = subg[\"play\"].edges()\n    old_u = F.gather_row(subg.nodes[\"user\"].data[dgl.NID], u)\n    old_v = F.gather_row(subg.nodes[\"game\"].data[dgl.NID], v)\n    assert F.array_equal(\n        hg[\"play\"].edge_ids(old_u, 
old_v), subg[\"play\"].edata[dgl.EID]\n    )\n    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))\n    assert edge_set == {(0, 0)}\n\n    u, v = subg[\"liked-by\"].edges()\n    old_u = F.gather_row(subg.nodes[\"game\"].data[dgl.NID], u)\n    old_v = F.gather_row(subg.nodes[\"user\"].data[dgl.NID], v)\n    assert F.array_equal(\n        hg[\"liked-by\"].edge_ids(old_u, old_v), subg[\"liked-by\"].edata[dgl.EID]\n    )\n    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))\n    assert edge_set == {(2, 0), (2, 1), (1, 0), (0, 0)}\n\n    assert subg.num_nodes(\"user\") == 4\n    assert subg.num_nodes(\"game\") == 3\n    assert subg.num_nodes(\"coin\") == 0\n    assert subg.num_edges(\"flips\") == 0\n\n\n@parametrize_idtype\ndef test_out_subgraph(idtype):\n    hg = dgl.heterograph(\n        {\n            (\"user\", \"follow\", \"user\"): (\n                [1, 2, 3, 0, 2, 3, 0],\n                [0, 0, 0, 1, 1, 1, 2],\n            ),\n            (\"user\", \"play\", \"game\"): ([0, 0, 1, 3], [0, 1, 2, 2]),\n            (\"game\", \"liked-by\", \"user\"): (\n                [2, 2, 2, 1, 1, 0],\n                [0, 1, 2, 0, 3, 0],\n            ),\n            (\"user\", \"flips\", \"coin\"): ([0, 1, 2, 3], [0, 0, 0, 0]),\n        },\n        idtype=idtype,\n    ).to(F.ctx())\n    subg = dgl.out_subgraph(hg, {\"user\": [0, 1], \"game\": 0})\n    assert subg.idtype == idtype\n    assert len(subg.ntypes) == 3\n    assert len(subg.etypes) == 4\n    u, v = subg[\"follow\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(1, 0), (0, 1), (0, 2)}\n    assert F.array_equal(\n        hg[\"follow\"].edge_ids(u, v), subg[\"follow\"].edata[dgl.EID]\n    )\n    u, v = subg[\"play\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(0, 0), (0, 1), (1, 2)}\n    assert F.array_equal(hg[\"play\"].edge_ids(u, v), subg[\"play\"].edata[dgl.EID])\n    u, v = 
subg[\"liked-by\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(0, 0)}\n    assert F.array_equal(\n        hg[\"liked-by\"].edge_ids(u, v), subg[\"liked-by\"].edata[dgl.EID]\n    )\n    u, v = subg[\"flips\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(0, 0), (1, 0)}\n    assert F.array_equal(\n        hg[\"flips\"].edge_ids(u, v), subg[\"flips\"].edata[dgl.EID]\n    )\n    for ntype in subg.ntypes:\n        assert dgl.NID not in subg.nodes[ntype].data\n\n    # Test store_ids\n    subg = dgl.out_subgraph(hg, {\"user\": [0, 1], \"game\": 0}, store_ids=False)\n    for etype in subg.canonical_etypes:\n        assert dgl.EID not in subg.edges[etype].data\n    for ntype in subg.ntypes:\n        assert dgl.NID not in subg.nodes[ntype].data\n\n    # Test relabel nodes\n    subg = dgl.out_subgraph(hg, {\"user\": [1], \"game\": 0}, relabel_nodes=True)\n    assert subg.idtype == idtype\n    assert len(subg.ntypes) == 3\n    assert len(subg.etypes) == 4\n\n    u, v = subg[\"follow\"].edges()\n    old_u = F.gather_row(subg.nodes[\"user\"].data[dgl.NID], u)\n    old_v = F.gather_row(subg.nodes[\"user\"].data[dgl.NID], v)\n    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))\n    assert edge_set == {(1, 0)}\n    assert F.array_equal(\n        hg[\"follow\"].edge_ids(old_u, old_v), subg[\"follow\"].edata[dgl.EID]\n    )\n\n    u, v = subg[\"play\"].edges()\n    old_u = F.gather_row(subg.nodes[\"user\"].data[dgl.NID], u)\n    old_v = F.gather_row(subg.nodes[\"game\"].data[dgl.NID], v)\n    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))\n    assert edge_set == {(1, 2)}\n    assert F.array_equal(\n        hg[\"play\"].edge_ids(old_u, old_v), subg[\"play\"].edata[dgl.EID]\n    )\n\n    u, v = subg[\"liked-by\"].edges()\n    old_u = F.gather_row(subg.nodes[\"game\"].data[dgl.NID], u)\n    old_v = 
F.gather_row(subg.nodes[\"user\"].data[dgl.NID], v)\n    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))\n    assert edge_set == {(0, 0)}\n    assert F.array_equal(\n        hg[\"liked-by\"].edge_ids(old_u, old_v), subg[\"liked-by\"].edata[dgl.EID]\n    )\n\n    u, v = subg[\"flips\"].edges()\n    old_u = F.gather_row(subg.nodes[\"user\"].data[dgl.NID], u)\n    old_v = F.gather_row(subg.nodes[\"coin\"].data[dgl.NID], v)\n    edge_set = set(zip(list(F.asnumpy(old_u)), list(F.asnumpy(old_v))))\n    assert edge_set == {(1, 0)}\n    assert F.array_equal(\n        hg[\"flips\"].edge_ids(old_u, old_v), subg[\"flips\"].edata[dgl.EID]\n    )\n    assert subg.num_nodes(\"user\") == 2\n    assert subg.num_nodes(\"game\") == 2\n    assert subg.num_nodes(\"coin\") == 1\n\n\ndef test_subgraph_message_passing():\n    # Unit test for PR #2055\n    g = dgl.graph(([0, 1, 2], [2, 3, 4])).to(F.cpu())\n    g.ndata[\"x\"] = F.copy_to(F.randn((5, 6)), F.cpu())\n    sg = g.subgraph([1, 2, 3]).to(F.ctx())\n    sg.update_all(\n        lambda edges: {\"x\": edges.src[\"x\"]},\n        lambda nodes: {\"y\": F.sum(nodes.mailbox[\"x\"], 1)},\n    )\n\n\n@parametrize_idtype\ndef test_khop_in_subgraph(idtype):\n    g = dgl.graph(\n        ([1, 1, 2, 3, 4], [0, 2, 0, 4, 2]), idtype=idtype, device=F.ctx()\n    )\n    g.edata[\"w\"] = F.tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])\n    sg, inv = dgl.khop_in_subgraph(g, 0, k=2)\n    assert sg.idtype == g.idtype\n    u, v = sg.edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(1, 0), (1, 2), (2, 0), (3, 2)}\n    assert F.array_equal(\n        sg.edata[dgl.EID], F.tensor([0, 1, 2, 4], dtype=idtype)\n    )\n    assert F.array_equal(\n        sg.edata[\"w\"], F.tensor([[0, 1], [2, 3], [4, 5], [8, 9]])\n    )\n    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))\n\n    # Test multiple nodes\n    sg, inv = dgl.khop_in_subgraph(g, [0, 2], k=1)\n    assert 
sg.num_edges() == 4\n\n    sg, inv = dgl.khop_in_subgraph(g, F.tensor([0, 2], idtype), k=1)\n    assert sg.num_edges() == 4\n\n    # Test isolated node\n    sg, inv = dgl.khop_in_subgraph(g, 1, k=2)\n    assert sg.idtype == g.idtype\n    assert sg.num_nodes() == 1\n    assert sg.num_edges() == 0\n    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"plays\", \"game\"): ([0, 1, 1, 2], [0, 0, 2, 1]),\n            (\"user\", \"follows\", \"user\"): ([0, 1, 1], [1, 2, 2]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    sg, inv = dgl.khop_in_subgraph(g, {\"game\": 0}, k=2)\n    assert sg.idtype == idtype\n    assert sg.num_nodes(\"game\") == 1\n    assert sg.num_nodes(\"user\") == 2\n    assert len(sg.ntypes) == 2\n    assert len(sg.etypes) == 2\n    u, v = sg[\"follows\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(0, 1)}\n    u, v = sg[\"plays\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(0, 0), (1, 0)}\n    assert F.array_equal(F.astype(inv[\"game\"], idtype), F.tensor([0], idtype))\n\n    # Test isolated node\n    sg, inv = dgl.khop_in_subgraph(g, {\"user\": 0}, k=2)\n    assert sg.idtype == idtype\n    assert sg.num_nodes(\"game\") == 0\n    assert sg.num_nodes(\"user\") == 1\n    assert sg.num_edges(\"follows\") == 0\n    assert sg.num_edges(\"plays\") == 0\n    assert F.array_equal(F.astype(inv[\"user\"], idtype), F.tensor([0], idtype))\n\n    # Test multiple nodes\n    sg, inv = dgl.khop_in_subgraph(\n        g, {\"user\": F.tensor([0, 1], idtype), \"game\": 0}, k=1\n    )\n    u, v = sg[\"follows\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(0, 1)}\n    u, v = sg[\"plays\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(0, 0), (1, 0)}\n    
assert F.array_equal(\n        F.astype(inv[\"user\"], idtype), F.tensor([0, 1], idtype)\n    )\n    assert F.array_equal(F.astype(inv[\"game\"], idtype), F.tensor([0], idtype))\n\n\n@parametrize_idtype\ndef test_khop_out_subgraph(idtype):\n    g = dgl.graph(\n        ([0, 2, 0, 4, 2], [1, 1, 2, 3, 4]), idtype=idtype, device=F.ctx()\n    )\n    g.edata[\"w\"] = F.tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])\n    sg, inv = dgl.khop_out_subgraph(g, 0, k=2)\n    assert sg.idtype == g.idtype\n    u, v = sg.edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(0, 1), (2, 1), (0, 2), (2, 3)}\n    assert F.array_equal(\n        sg.edata[dgl.EID], F.tensor([0, 2, 1, 4], dtype=idtype)\n    )\n    assert F.array_equal(\n        sg.edata[\"w\"], F.tensor([[0, 1], [4, 5], [2, 3], [8, 9]])\n    )\n    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))\n\n    # Test multiple nodes\n    sg, inv = dgl.khop_out_subgraph(g, [0, 2], k=1)\n    assert sg.num_edges() == 4\n\n    sg, inv = dgl.khop_out_subgraph(g, F.tensor([0, 2], idtype), k=1)\n    assert sg.num_edges() == 4\n\n    # Test isolated node\n    sg, inv = dgl.khop_out_subgraph(g, 1, k=2)\n    assert sg.idtype == g.idtype\n    assert sg.num_nodes() == 1\n    assert sg.num_edges() == 0\n    assert F.array_equal(F.astype(inv, idtype), F.tensor([0], idtype))\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"plays\", \"game\"): ([0, 1, 1, 2], [0, 0, 2, 1]),\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 3]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    sg, inv = dgl.khop_out_subgraph(g, {\"user\": 0}, k=2)\n    assert sg.idtype == idtype\n    assert sg.num_nodes(\"game\") == 2\n    assert sg.num_nodes(\"user\") == 3\n    assert len(sg.ntypes) == 2\n    assert len(sg.etypes) == 2\n    u, v = sg[\"follows\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(0, 1), 
(1, 2)}\n    u, v = sg[\"plays\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(0, 0), (1, 0), (1, 1)}\n    assert F.array_equal(F.astype(inv[\"user\"], idtype), F.tensor([0], idtype))\n\n    # Test isolated node\n    sg, inv = dgl.khop_out_subgraph(g, {\"user\": 3}, k=2)\n    assert sg.idtype == idtype\n    assert sg.num_nodes(\"game\") == 0\n    assert sg.num_nodes(\"user\") == 1\n    assert sg.num_edges(\"follows\") == 0\n    assert sg.num_edges(\"plays\") == 0\n    assert F.array_equal(F.astype(inv[\"user\"], idtype), F.tensor([0], idtype))\n\n    # Test multiple nodes\n    sg, inv = dgl.khop_out_subgraph(\n        g, {\"user\": F.tensor([2], idtype), \"game\": 0}, k=1\n    )\n    assert sg.num_edges(\"follows\") == 0\n    u, v = sg[\"plays\"].edges()\n    edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))\n    assert edge_set == {(0, 1)}\n    assert F.array_equal(F.astype(inv[\"user\"], idtype), F.tensor([0], idtype))\n    assert F.array_equal(F.astype(inv[\"game\"], idtype), F.tensor([0], idtype))\n\n\n@unittest.skipIf(not F.gpu_ctx(), \"only necessary with GPU\")\n@pytest.mark.parametrize(\n    \"parent_idx_device\",\n    [(\"cpu\", F.cpu()), (\"cuda\", F.cuda()), (\"uva\", F.cpu()), (\"uva\", F.cuda())],\n)\n@pytest.mark.parametrize(\"child_device\", [F.cpu(), F.cuda()])\ndef test_subframes(parent_idx_device, child_device):\n    parent_device, idx_device = parent_idx_device\n    g = dgl.graph(\n        (F.tensor([1, 2, 3], dtype=F.int64), F.tensor([2, 3, 4], dtype=F.int64))\n    )\n    print(g.device)\n    g.ndata[\"x\"] = F.randn((5, 4))\n    g.edata[\"a\"] = F.randn((3, 6))\n    idx = F.tensor([1, 2], dtype=F.int64)\n    if parent_device == \"cuda\":\n        g = g.to(F.cuda())\n    elif parent_device == \"uva\":\n        if F.backend_name != \"pytorch\":\n            pytest.skip(\"UVA only supported for PyTorch\")\n        g = g.to(F.cpu())\n        g.create_formats_()\n        
g.pin_memory_()\n    elif parent_device == \"cpu\":\n        g = g.to(F.cpu())\n    idx = F.copy_to(idx, idx_device)\n    sg = g.sample_neighbors(idx, 2).to(child_device)\n    assert sg.device == F.context(sg.ndata[\"x\"])\n    assert sg.device == F.context(sg.edata[\"a\"])\n    assert sg.device == child_device\n    if parent_device != \"uva\":\n        sg = g.to(child_device).sample_neighbors(\n            F.copy_to(idx, child_device), 2\n        )\n        assert sg.device == F.context(sg.ndata[\"x\"])\n        assert sg.device == F.context(sg.edata[\"a\"])\n        assert sg.device == child_device\n    if parent_device == \"uva\":\n        g.unpin_memory_()\n\n\n@unittest.skipIf(\n    F._default_context_str != \"gpu\", reason=\"UVA only available on GPU\"\n)\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\",\n    reason=\"UVA only supported for PyTorch\",\n)\n@pytest.mark.parametrize(\"device\", [F.cpu(), F.cuda()])\n@parametrize_idtype\ndef test_uva_subgraph(idtype, device):\n    g = create_test_heterograph2(idtype)\n    g = g.to(F.cpu())\n    g.create_formats_()\n    g.pin_memory_()\n    indices = {\"user\": F.copy_to(F.tensor([0], idtype), device)}\n    edge_indices = {\"follows\": F.copy_to(F.tensor([0], idtype), device)}\n    assert g.subgraph(indices).device == device\n    assert g.edge_subgraph(edge_indices).device == device\n    assert g.in_subgraph(indices).device == device\n    assert g.out_subgraph(indices).device == device\n    assert g.khop_in_subgraph(indices, 1)[0].device == device\n    assert g.khop_out_subgraph(indices, 1)[0].device == device\n    assert g.sample_neighbors(indices, 1).device == device\n    g.unpin_memory_()\n\n\nif __name__ == \"__main__\":\n    test_edge_subgraph()\n    test_uva_subgraph(F.int64, F.cpu())\n    test_uva_subgraph(F.int64, F.cuda())\n"
  },
  {
    "path": "tests/python/common/test_traversal.py",
    "content": "import itertools\nimport random\nimport sys\nimport time\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport networkx as nx\nimport numpy as np\nimport scipy.sparse as sp\nfrom utils import parametrize_idtype\n\nnp.random.seed(42)\n\n\ndef toset(x):\n    # F.zerocopy_to_numpy may return a int\n    return set(F.zerocopy_to_numpy(x).tolist())\n\n\n@parametrize_idtype\ndef test_bfs(idtype, n=100):\n    def _bfs_nx(g_nx, src):\n        edges = nx.bfs_edges(g_nx, src)\n        layers_nx = [set([src])]\n        edges_nx = []\n        frontier = set()\n        edge_frontier = set()\n        for u, v in edges:\n            if u in layers_nx[-1]:\n                frontier.add(v)\n                edge_frontier.add(g.edge_ids(int(u), int(v)))\n            else:\n                layers_nx.append(frontier)\n                edges_nx.append(edge_frontier)\n                frontier = set([v])\n                edge_frontier = set([g.edge_ids(u, v)])\n        # avoids empty successors\n        if len(frontier) > 0 and len(edge_frontier) > 0:\n            layers_nx.append(frontier)\n            edges_nx.append(edge_frontier)\n        return layers_nx, edges_nx\n\n    a = sp.random(n, n, 3 / n, data_rvs=lambda n: np.ones(n))\n    g = dgl.from_scipy(a).astype(idtype)\n\n    g_nx = g.to_networkx()\n    src = random.choice(range(n))\n    layers_nx, _ = _bfs_nx(g_nx, src)\n    layers_dgl = dgl.bfs_nodes_generator(g, src)\n    assert len(layers_dgl) == len(layers_nx)\n    assert all(toset(x) == y for x, y in zip(layers_dgl, layers_nx))\n\n    g_nx = nx.random_labeled_tree(n, seed=42)\n    g = dgl.from_networkx(g_nx).astype(idtype)\n    src = 0\n    _, edges_nx = _bfs_nx(g_nx, src)\n    edges_dgl = dgl.bfs_edges_generator(g, src)\n    assert len(edges_dgl) == len(edges_nx)\n    assert all(toset(x) == y for x, y in zip(edges_dgl, edges_nx))\n\n\n@parametrize_idtype\ndef test_topological_nodes(idtype, n=100):\n    a = sp.random(n, n, 3 / n, data_rvs=lambda n: 
np.ones(n))\n    b = sp.tril(a, -1).tocoo()\n    g = dgl.from_scipy(b).astype(idtype)\n\n    layers_dgl = dgl.topological_nodes_generator(g)\n\n    adjmat = g.adj_external(transpose=True)\n\n    def tensor_topo_traverse():\n        n = g.num_nodes()\n        mask = F.copy_to(F.ones((n, 1)), F.cpu())\n        degree = F.spmm(adjmat, mask)\n        while F.reduce_sum(mask) != 0.0:\n            v = F.astype((degree == 0.0), F.float32)\n            v = v * mask\n            mask = mask - v\n            frontier = F.copy_to(F.nonzero_1d(F.squeeze(v, 1)), F.cpu())\n            yield frontier\n            degree -= F.spmm(adjmat, v)\n\n    layers_spmv = list(tensor_topo_traverse())\n\n    assert len(layers_dgl) == len(layers_spmv)\n    assert all(toset(x) == toset(y) for x, y in zip(layers_dgl, layers_spmv))\n\n\nDFS_LABEL_NAMES = [\"forward\", \"reverse\", \"nontree\"]\n\n\n@parametrize_idtype\ndef test_dfs_labeled_edges(idtype, example=False):\n    dgl_g = dgl.graph([]).astype(idtype)\n    dgl_g.add_nodes(6)\n    dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5])\n    dgl_edges, dgl_labels = dgl.dfs_labeled_edges_generator(\n        dgl_g, [0, 3], has_reverse_edge=True, has_nontree_edge=True\n    )\n    dgl_edges = [toset(t) for t in dgl_edges]\n    dgl_labels = [toset(t) for t in dgl_labels]\n    g1_solutions = [\n        # edges           labels\n        [[0, 1, 1, 0, 2], [0, 0, 1, 1, 2]],\n        [[2, 2, 0, 1, 0], [0, 1, 0, 2, 1]],\n    ]\n    g2_solutions = [\n        # edges        labels\n        [[3, 3, 4, 4], [0, 1, 0, 1]],\n        [[4, 4, 3, 3], [0, 1, 0, 1]],\n    ]\n\n    def combine_frontiers(sol):\n        es, ls = zip(*sol)\n        es = [\n            set(i for i in t if i is not None)\n            for t in itertools.zip_longest(*es)\n        ]\n        ls = [\n            set(i for i in t if i is not None)\n            for t in itertools.zip_longest(*ls)\n        ]\n        return es, ls\n\n    for sol_set in itertools.product(g1_solutions, 
g2_solutions):\n        es, ls = combine_frontiers(sol_set)\n        if es == dgl_edges and ls == dgl_labels:\n            break\n    else:\n        assert False\n\n\nif __name__ == \"__main__\":\n    test_bfs(idtype=\"int32\")\n    test_topological_nodes(idtype=\"int32\")\n    test_dfs_labeled_edges(idtype=\"int32\")\n"
  },
  {
    "path": "tests/python/common/transforms/test_functional-sort.py",
    "content": "import itertools\nimport unittest\nfrom collections import Counter\n\nimport backend as F\n\nimport dgl\nimport dgl.function as fn\nimport networkx as nx\nimport numpy as np\nimport pytest\nimport scipy.sparse as ssp\nfrom dgl import DGLError\nfrom utils import parametrize_idtype\n\n\ndef create_test_heterograph(num_nodes, num_adj, idtype):\n    if isinstance(num_adj, int):\n        num_adj = [num_adj, num_adj + 1]\n    num_adj_list = list(\n        np.random.choice(np.arange(num_adj[0], num_adj[1]), num_nodes)\n    )\n    src = np.concatenate([[i] * num_adj_list[i] for i in range(num_nodes)])\n    dst = [\n        np.random.choice(num_nodes, nadj, replace=False)\n        for nadj in num_adj_list\n    ]\n    dst = np.concatenate(dst)\n    return dgl.graph((src, dst), idtype=idtype)\n\n\ndef check_sort(spm, tag_arr=None, tag_pos=None):\n    if tag_arr is None:\n        tag_arr = np.arange(spm.shape[0])\n    else:\n        tag_arr = F.asnumpy(tag_arr)\n    if tag_pos is not None:\n        tag_pos = F.asnumpy(tag_pos)\n    for i in range(spm.shape[0]):\n        row = spm.getrow(i)\n        dst = row.nonzero()[1]\n        if tag_pos is not None:\n            tag_pos_row = tag_pos[i]\n            tag_pos_ptr = tag_arr[dst[0]] if len(dst) > 0 else 0\n        for j in range(len(dst) - 1):\n            if tag_pos is not None and tag_arr[dst[j]] != tag_pos_ptr:\n                # `tag_pos_ptr` is the expected tag value. 
Here we check whether the\n                # tag value is equal to `tag_pos_ptr`\n                return False\n            if tag_arr[dst[j]] > tag_arr[dst[j + 1]]:\n                # The tag should be in ascending order after sorting\n                return False\n            if tag_pos is not None and tag_arr[dst[j]] < tag_arr[dst[j + 1]]:\n                if j + 1 != int(tag_pos_row[tag_pos_ptr + 1]):\n                    # The boundary of tag should be consistent with `tag_pos`\n                    return False\n                tag_pos_ptr = tag_arr[dst[j + 1]]\n    return True\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"GPU sorting by tag not implemented\"\n)\n@parametrize_idtype\ndef test_sort_with_tag(idtype):\n    num_nodes, num_adj, num_tags = 200, [20, 50], 5\n    g = create_test_heterograph(num_nodes, num_adj, idtype=idtype)\n    tag = F.tensor(np.random.choice(num_tags, g.num_nodes()))\n    src, dst = g.edges()\n    edge_tag_dst = F.gather_row(tag, F.tensor(dst))\n    edge_tag_src = F.gather_row(tag, F.tensor(src))\n\n    for tag_type in [\"node\", \"edge\"]:\n        new_g = dgl.sort_csr_by_tag(\n            g, tag if tag_type == \"node\" else edge_tag_dst, tag_type=tag_type\n        )\n        old_csr = g.adj_external(scipy_fmt=\"csr\")\n        new_csr = new_g.adj_external(scipy_fmt=\"csr\")\n        assert check_sort(new_csr, tag, new_g.dstdata[\"_TAG_OFFSET\"])\n        assert not check_sort(\n            old_csr, tag\n        )  # Check the original csr is not modified.\n\n    for tag_type in [\"node\", \"edge\"]:\n        new_g = dgl.sort_csc_by_tag(\n            g, tag if tag_type == \"node\" else edge_tag_src, tag_type=tag_type\n        )\n        old_csc = g.adj_external(transpose=True, scipy_fmt=\"csr\")\n        new_csc = new_g.adj_external(transpose=True, scipy_fmt=\"csr\")\n        assert check_sort(new_csc, tag, new_g.srcdata[\"_TAG_OFFSET\"])\n        assert not check_sort(old_csc, 
tag)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"GPU sorting by tag not implemented\"\n)\n@parametrize_idtype\ndef test_sort_with_tag_bipartite(idtype):\n    num_nodes, num_adj, num_tags = 200, [20, 50], 5\n    g = create_test_heterograph(num_nodes, num_adj, idtype=idtype)\n    g = dgl.heterograph({(\"_U\", \"_E\", \"_V\"): g.edges()})\n    utag = F.tensor(np.random.choice(num_tags, g.num_nodes(\"_U\")))\n    vtag = F.tensor(np.random.choice(num_tags, g.num_nodes(\"_V\")))\n\n    new_g = dgl.sort_csr_by_tag(g, vtag)\n    old_csr = g.adj_external(scipy_fmt=\"csr\")\n    new_csr = new_g.adj_external(scipy_fmt=\"csr\")\n    assert check_sort(new_csr, vtag, new_g.nodes[\"_U\"].data[\"_TAG_OFFSET\"])\n    assert not check_sort(old_csr, vtag)\n\n    new_g = dgl.sort_csc_by_tag(g, utag)\n    old_csc = g.adj_external(transpose=True, scipy_fmt=\"csr\")\n    new_csc = new_g.adj_external(transpose=True, scipy_fmt=\"csr\")\n    assert check_sort(new_csc, utag, new_g.nodes[\"_V\"].data[\"_TAG_OFFSET\"])\n    assert not check_sort(old_csc, utag)\n\n\nif __name__ == \"__main__\":\n    test_sort_with_tag(F.int32)\n    test_sort_with_tag_bipartite(F.int32)\n"
  },
  {
    "path": "tests/python/common/transforms/test_to_block.py",
    "content": "##\n#   Copyright 2019-2021 Contributors\n#\n#   Licensed under the Apache License, Version 2.0 (the \"License\");\n#   you may not use this file except in compliance with the License.\n#   You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#   Unless required by applicable law or agreed to in writing, software\n#   distributed under the License is distributed on an \"AS IS\" BASIS,\n#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#   See the License for the specific language governing permissions and\n#   limitations under the License.\n#\n\nimport backend as F\n\nimport dgl\nimport dgl.partition\nfrom utils import parametrize_idtype\n\n\n@parametrize_idtype\ndef test_to_block(idtype):\n    def check(g, bg, ntype, etype, dst_nodes, include_dst_in_src=True):\n        if dst_nodes is not None:\n            assert F.array_equal(bg.dstnodes[ntype].data[dgl.NID], dst_nodes)\n        n_dst_nodes = bg.num_nodes(\"DST/\" + ntype)\n        if include_dst_in_src:\n            assert F.array_equal(\n                bg.srcnodes[ntype].data[dgl.NID][:n_dst_nodes],\n                bg.dstnodes[ntype].data[dgl.NID],\n            )\n\n        g = g[etype]\n        bg = bg[etype]\n        induced_src = bg.srcdata[dgl.NID]\n        induced_dst = bg.dstdata[dgl.NID]\n        induced_eid = bg.edata[dgl.EID]\n\n        bg_src, bg_dst = bg.all_edges(order=\"eid\")\n        src_ans, dst_ans = g.all_edges(order=\"eid\")\n\n        induced_src_bg = F.gather_row(induced_src, bg_src)\n        induced_dst_bg = F.gather_row(induced_dst, bg_dst)\n        induced_src_ans = F.gather_row(src_ans, induced_eid)\n        induced_dst_ans = F.gather_row(dst_ans, induced_eid)\n\n        assert F.array_equal(induced_src_bg, induced_src_ans)\n        assert F.array_equal(induced_dst_bg, induced_dst_ans)\n\n    def checkall(g, bg, dst_nodes, include_dst_in_src=True):\n        for etype in g.etypes:\n            
ntype = g.to_canonical_etype(etype)[2]\n            if dst_nodes is not None and ntype in dst_nodes:\n                check(g, bg, ntype, etype, dst_nodes[ntype], include_dst_in_src)\n            else:\n                check(g, bg, ntype, etype, None, include_dst_in_src)\n\n    # homogeneous graph\n    g = dgl.graph(\n        (F.tensor([1, 2], dtype=idtype), F.tensor([2, 3], dtype=idtype))\n    )\n    dst_nodes = F.tensor([3, 2], dtype=idtype)\n    bg = dgl.to_block(g, dst_nodes=dst_nodes)\n    check(g, bg, \"_N\", \"_E\", dst_nodes)\n\n    src_nodes = bg.srcnodes[\"_N\"].data[dgl.NID]\n    bg = dgl.to_block(g, dst_nodes=dst_nodes, src_nodes=src_nodes)\n    check(g, bg, \"_N\", \"_E\", dst_nodes)\n\n    # heterogeneous graph\n    g = dgl.heterograph(\n        {\n            (\"A\", \"AA\", \"A\"): ([0, 2, 1, 3], [1, 3, 2, 4]),\n            (\"A\", \"AB\", \"B\"): ([0, 1, 3, 1], [1, 3, 5, 6]),\n            (\"B\", \"BA\", \"A\"): ([2, 3], [3, 2]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"A\"].data[\"x\"] = F.randn((5, 10))\n    g.nodes[\"B\"].data[\"x\"] = F.randn((7, 5))\n    g.edges[\"AA\"].data[\"x\"] = F.randn((4, 3))\n    g.edges[\"AB\"].data[\"x\"] = F.randn((4, 3))\n    g.edges[\"BA\"].data[\"x\"] = F.randn((2, 3))\n    g_a = g[\"AA\"]\n\n    def check_features(g, bg):\n        for ntype in bg.srctypes:\n            for key in g.nodes[ntype].data:\n                assert F.array_equal(\n                    bg.srcnodes[ntype].data[key],\n                    F.gather_row(\n                        g.nodes[ntype].data[key],\n                        bg.srcnodes[ntype].data[dgl.NID],\n                    ),\n                )\n        for ntype in bg.dsttypes:\n            for key in g.nodes[ntype].data:\n                assert F.array_equal(\n                    bg.dstnodes[ntype].data[key],\n                    F.gather_row(\n                        g.nodes[ntype].data[key],\n                        
bg.dstnodes[ntype].data[dgl.NID],\n                    ),\n                )\n        for etype in bg.canonical_etypes:\n            for key in g.edges[etype].data:\n                assert F.array_equal(\n                    bg.edges[etype].data[key],\n                    F.gather_row(\n                        g.edges[etype].data[key], bg.edges[etype].data[dgl.EID]\n                    ),\n                )\n\n    bg = dgl.to_block(g_a)\n    check(g_a, bg, \"A\", \"AA\", None)\n    check_features(g_a, bg)\n    assert bg.number_of_src_nodes() == 5\n    assert bg.number_of_dst_nodes() == 4\n\n    bg = dgl.to_block(g_a, include_dst_in_src=False)\n    check(g_a, bg, \"A\", \"AA\", None, False)\n    check_features(g_a, bg)\n    assert bg.number_of_src_nodes() == 4\n    assert bg.number_of_dst_nodes() == 4\n\n    dst_nodes = F.tensor([4, 3, 2, 1], dtype=idtype)\n    bg = dgl.to_block(g_a, dst_nodes)\n    check(g_a, bg, \"A\", \"AA\", dst_nodes)\n    check_features(g_a, bg)\n\n    g_ab = g[\"AB\"]\n\n    bg = dgl.to_block(g_ab)\n    assert bg.idtype == idtype\n    assert bg.num_nodes(\"SRC/B\") == 4\n    assert F.array_equal(\n        bg.srcnodes[\"B\"].data[dgl.NID], bg.dstnodes[\"B\"].data[dgl.NID]\n    )\n    assert bg.num_nodes(\"DST/A\") == 0\n    checkall(g_ab, bg, None)\n    check_features(g_ab, bg)\n\n    dst_nodes = {\"B\": F.tensor([5, 6, 3, 1], dtype=idtype)}\n    bg = dgl.to_block(g, dst_nodes)\n    assert bg.num_nodes(\"SRC/B\") == 4\n    assert F.array_equal(\n        bg.srcnodes[\"B\"].data[dgl.NID], bg.dstnodes[\"B\"].data[dgl.NID]\n    )\n    assert bg.num_nodes(\"DST/A\") == 0\n    checkall(g, bg, dst_nodes)\n    check_features(g, bg)\n\n    dst_nodes = {\n        \"A\": F.tensor([4, 3, 2, 1], dtype=idtype),\n        \"B\": F.tensor([3, 5, 6, 1], dtype=idtype),\n    }\n    bg = dgl.to_block(g, dst_nodes=dst_nodes)\n    checkall(g, bg, dst_nodes)\n    check_features(g, bg)\n\n    # test specifying lhs_nodes with include_dst_in_src\n    src_nodes = {}\n    
for ntype in dst_nodes.keys():\n        # use the previous run to get the list of source nodes\n        src_nodes[ntype] = bg.srcnodes[ntype].data[dgl.NID]\n    bg = dgl.to_block(g, dst_nodes=dst_nodes, src_nodes=src_nodes)\n    checkall(g, bg, dst_nodes)\n    check_features(g, bg)\n\n    # test without include_dst_in_src\n    dst_nodes = {\n        \"A\": F.tensor([4, 3, 2, 1], dtype=idtype),\n        \"B\": F.tensor([3, 5, 6, 1], dtype=idtype),\n    }\n    bg = dgl.to_block(g, dst_nodes=dst_nodes, include_dst_in_src=False)\n    checkall(g, bg, dst_nodes, False)\n    check_features(g, bg)\n\n    # test specifying lhs_nodes without include_dst_in_src\n    src_nodes = {}\n    for ntype in dst_nodes.keys():\n        # use the previous run to get the list of source nodes\n        src_nodes[ntype] = bg.srcnodes[ntype].data[dgl.NID]\n    bg = dgl.to_block(\n        g, dst_nodes=dst_nodes, include_dst_in_src=False, src_nodes=src_nodes\n    )\n    checkall(g, bg, dst_nodes, False)\n    check_features(g, bg)\n"
  },
  {
    "path": "tests/python/common/transforms/test_transform.py",
    "content": "##\n#   Copyright 2019-2021 Contributors\n#\n#   Licensed under the Apache License, Version 2.0 (the \"License\");\n#   you may not use this file except in compliance with the License.\n#   You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n#   Unless required by applicable law or agreed to in writing, software\n#   distributed under the License is distributed on an \"AS IS\" BASIS,\n#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#   See the License for the specific language governing permissions and\n#   limitations under the License.\n#\n\nimport math\nimport os\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport dgl.function as fn\nimport dgl.partition\nimport networkx as nx\nimport numpy as np\nimport pytest\nfrom scipy import sparse as spsp\nfrom utils import parametrize_idtype\nfrom utils.graph_cases import get_cases\n\nD = 5\n\n\ndef create_test_heterograph3(idtype):\n    g = dgl.heterograph(\n        {\n            (\"user\", \"plays\", \"game\"): (\n                F.tensor([0, 1, 1, 2], dtype=idtype),\n                F.tensor([0, 0, 1, 1], dtype=idtype),\n            ),\n            (\"developer\", \"develops\", \"game\"): (\n                F.tensor([0, 1], dtype=idtype),\n                F.tensor([0, 1], dtype=idtype),\n            ),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n\n    g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"developer\"].data[\"h\"] = F.copy_to(\n        F.tensor([3, 3], dtype=idtype), ctx=F.ctx()\n    )\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1, 1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    return g\n\n\ndef create_test_heterograph4(idtype):\n    g = dgl.heterograph(\n        {\n            
(\"user\", \"follows\", \"user\"): (\n                F.tensor([0, 1, 1, 2, 2, 2], dtype=idtype),\n                F.tensor([0, 0, 1, 1, 2, 2], dtype=idtype),\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                F.tensor([0, 1], dtype=idtype),\n                F.tensor([0, 1], dtype=idtype),\n            ),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.edges[\"follows\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2, 3, 4, 5, 6], dtype=idtype), ctx=F.ctx()\n    )\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2], dtype=idtype), ctx=F.ctx()\n    )\n    return g\n\n\ndef create_test_heterograph5(idtype):\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (\n                F.tensor([1, 2], dtype=idtype),\n                F.tensor([0, 1], dtype=idtype),\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                F.tensor([0, 1], dtype=idtype),\n                F.tensor([0, 1], dtype=idtype),\n            ),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.edges[\"follows\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2], dtype=idtype), ctx=F.ctx()\n    )\n    return g\n\n\n# line graph related\n\n\ndef test_line_graph1():\n    N = 5\n    G = dgl.from_networkx(nx.star_graph(N)).to(F.ctx())\n    G.edata[\"h\"] = F.randn((2 * N, D))\n    L = G.line_graph(shared=True)\n  
  assert L.num_nodes() == 2 * N\n    assert F.allclose(L.ndata[\"h\"], G.edata[\"h\"])\n    assert G.device == F.ctx()\n\n\n@parametrize_idtype\ndef test_line_graph2(idtype):\n    g = dgl.heterograph(\n        {(\"user\", \"follows\", \"user\"): ([0, 1, 1, 2, 2], [2, 0, 2, 0, 1])},\n        idtype=idtype,\n    )\n    lg = dgl.line_graph(g)\n    assert lg.num_nodes() == 5\n    assert lg.num_edges() == 8\n    row, col = lg.edges()\n    assert np.array_equal(F.asnumpy(row), np.array([0, 0, 1, 2, 2, 3, 4, 4]))\n    assert np.array_equal(F.asnumpy(col), np.array([3, 4, 0, 3, 4, 0, 1, 2]))\n\n    lg = dgl.line_graph(g, backtracking=False)\n    assert lg.num_nodes() == 5\n    assert lg.num_edges() == 4\n    row, col = lg.edges()\n    assert np.array_equal(F.asnumpy(row), np.array([0, 1, 2, 4]))\n    assert np.array_equal(F.asnumpy(col), np.array([4, 0, 3, 1]))\n    g = dgl.heterograph(\n        {(\"user\", \"follows\", \"user\"): ([0, 1, 1, 2, 2], [2, 0, 2, 0, 1])},\n        idtype=idtype,\n    ).formats(\"csr\")\n    lg = dgl.line_graph(g)\n    assert lg.num_nodes() == 5\n    assert lg.num_edges() == 8\n    row, col = lg.edges()\n    assert np.array_equal(F.asnumpy(row), np.array([0, 0, 1, 2, 2, 3, 4, 4]))\n    assert np.array_equal(F.asnumpy(col), np.array([3, 4, 0, 3, 4, 0, 1, 2]))\n\n    g = dgl.heterograph(\n        {(\"user\", \"follows\", \"user\"): ([0, 1, 1, 2, 2], [2, 0, 2, 0, 1])},\n        idtype=idtype,\n    ).formats(\"csc\")\n    lg = dgl.line_graph(g)\n    assert lg.num_nodes() == 5\n    assert lg.num_edges() == 8\n    row, col, eid = lg.edges(\"all\")\n    row = F.asnumpy(row)\n    col = F.asnumpy(col)\n    eid = F.asnumpy(eid).astype(int)\n    order = np.argsort(eid)\n    assert np.array_equal(row[order], np.array([0, 0, 1, 2, 2, 3, 4, 4]))\n    assert np.array_equal(col[order], np.array([3, 4, 0, 3, 4, 0, 1, 2]))\n\n\ndef test_no_backtracking():\n    N = 5\n    G = dgl.from_networkx(nx.star_graph(N))\n    L = G.line_graph(backtracking=False)\n    assert 
L.num_nodes() == 2 * N\n    for i in range(1, N):\n        e1 = G.edge_ids(0, i)\n        e2 = G.edge_ids(i, 0)\n        assert not L.has_edges_between(e1, e2)\n        assert not L.has_edges_between(e2, e1)\n\n\n# reverse graph related\n@parametrize_idtype\ndef test_reverse(idtype):\n    g = dgl.graph([])\n    g = g.astype(idtype).to(F.ctx())\n    g.add_nodes(5)\n    # The graph need not to be completely connected.\n    g.add_edges([0, 1, 2], [1, 2, 1])\n    g.ndata[\"h\"] = F.tensor([[0.0], [1.0], [2.0], [3.0], [4.0]])\n    g.edata[\"h\"] = F.tensor([[5.0], [6.0], [7.0]])\n    rg = g.reverse()\n\n    assert g.is_multigraph == rg.is_multigraph\n\n    assert g.num_nodes() == rg.num_nodes()\n    assert g.num_edges() == rg.num_edges()\n    assert F.allclose(\n        F.astype(rg.has_edges_between([1, 2, 1], [0, 1, 2]), F.float32),\n        F.ones((3,)),\n    )\n    assert g.edge_ids(0, 1) == rg.edge_ids(1, 0)\n    assert g.edge_ids(1, 2) == rg.edge_ids(2, 1)\n    assert g.edge_ids(2, 1) == rg.edge_ids(1, 2)\n\n    # test dgl.reverse\n    # test homogeneous graph\n    g = dgl.graph((F.tensor([0, 1, 2]), F.tensor([1, 2, 0])))\n    g.ndata[\"h\"] = F.tensor([[0.0], [1.0], [2.0]])\n    g.edata[\"h\"] = F.tensor([[3.0], [4.0], [5.0]])\n    g_r = dgl.reverse(g)\n    assert g.num_nodes() == g_r.num_nodes()\n    assert g.num_edges() == g_r.num_edges()\n    u_g, v_g, eids_g = g.all_edges(form=\"all\")\n    u_rg, v_rg, eids_rg = g_r.all_edges(form=\"all\")\n    assert F.array_equal(u_g, v_rg)\n    assert F.array_equal(v_g, u_rg)\n    assert F.array_equal(eids_g, eids_rg)\n    assert F.array_equal(g.ndata[\"h\"], g_r.ndata[\"h\"])\n    assert len(g_r.edata) == 0\n\n    # without share ndata\n    g_r = dgl.reverse(g, copy_ndata=False)\n    assert g.num_nodes() == g_r.num_nodes()\n    assert g.num_edges() == g_r.num_edges()\n    assert len(g_r.ndata) == 0\n    assert len(g_r.edata) == 0\n\n    # with share ndata and edata\n    g_r = dgl.reverse(g, copy_ndata=True, 
copy_edata=True)\n    assert g.num_nodes() == g_r.num_nodes()\n    assert g.num_edges() == g_r.num_edges()\n    assert F.array_equal(g.ndata[\"h\"], g_r.ndata[\"h\"])\n    assert F.array_equal(g.edata[\"h\"], g_r.edata[\"h\"])\n\n    # add new node feature to g_r\n    g_r.ndata[\"hh\"] = F.tensor([0, 1, 2])\n    assert (\"hh\" in g.ndata) is False\n    assert (\"hh\" in g_r.ndata) is True\n\n    # add new edge feature to g_r\n    g_r.edata[\"hh\"] = F.tensor([0, 1, 2])\n    assert (\"hh\" in g.edata) is False\n    assert (\"hh\" in g_r.edata) is True\n\n    # test heterogeneous graph\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (\n                [0, 1, 2, 4, 3, 1, 3],\n                [1, 2, 3, 2, 0, 0, 1],\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                [0, 0, 2, 3, 3, 4, 1],\n                [1, 0, 1, 0, 1, 0, 0],\n            ),\n            (\"developer\", \"develops\", \"game\"): ([0, 1, 1, 2], [0, 0, 1, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.tensor([0, 1, 2, 3, 4])\n    g.nodes[\"user\"].data[\"hh\"] = F.tensor([1, 1, 1, 1, 1])\n    g.nodes[\"game\"].data[\"h\"] = F.tensor([0, 1])\n    g.edges[\"follows\"].data[\"h\"] = F.tensor([0, 1, 2, 4, 3, 1, 3])\n    g.edges[\"follows\"].data[\"hh\"] = F.tensor([1, 2, 3, 2, 0, 0, 1])\n    g_r = dgl.reverse(g)\n\n    for etype_g, etype_gr in zip(g.canonical_etypes, g_r.canonical_etypes):\n        assert etype_g[0] == etype_gr[2]\n        assert etype_g[1] == etype_gr[1]\n        assert etype_g[2] == etype_gr[0]\n        assert g.num_edges(etype_g) == g_r.num_edges(etype_gr)\n    for ntype in g.ntypes:\n        assert g.num_nodes(ntype) == g_r.num_nodes(ntype)\n    assert F.array_equal(g.nodes[\"user\"].data[\"h\"], g_r.nodes[\"user\"].data[\"h\"])\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"hh\"], g_r.nodes[\"user\"].data[\"hh\"]\n    )\n    assert 
F.array_equal(g.nodes[\"game\"].data[\"h\"], g_r.nodes[\"game\"].data[\"h\"])\n    assert len(g_r.edges[\"follows\"].data) == 0\n    u_g, v_g, eids_g = g.all_edges(\n        form=\"all\", etype=(\"user\", \"follows\", \"user\")\n    )\n    u_rg, v_rg, eids_rg = g_r.all_edges(\n        form=\"all\", etype=(\"user\", \"follows\", \"user\")\n    )\n    assert F.array_equal(u_g, v_rg)\n    assert F.array_equal(v_g, u_rg)\n    assert F.array_equal(eids_g, eids_rg)\n    u_g, v_g, eids_g = g.all_edges(form=\"all\", etype=(\"user\", \"plays\", \"game\"))\n    u_rg, v_rg, eids_rg = g_r.all_edges(\n        form=\"all\", etype=(\"game\", \"plays\", \"user\")\n    )\n    assert F.array_equal(u_g, v_rg)\n    assert F.array_equal(v_g, u_rg)\n    assert F.array_equal(eids_g, eids_rg)\n    u_g, v_g, eids_g = g.all_edges(\n        form=\"all\", etype=(\"developer\", \"develops\", \"game\")\n    )\n    u_rg, v_rg, eids_rg = g_r.all_edges(\n        form=\"all\", etype=(\"game\", \"develops\", \"developer\")\n    )\n    assert F.array_equal(u_g, v_rg)\n    assert F.array_equal(v_g, u_rg)\n    assert F.array_equal(eids_g, eids_rg)\n\n    # withour share ndata\n    g_r = dgl.reverse(g, copy_ndata=False)\n    for etype_g, etype_gr in zip(g.canonical_etypes, g_r.canonical_etypes):\n        assert etype_g[0] == etype_gr[2]\n        assert etype_g[1] == etype_gr[1]\n        assert etype_g[2] == etype_gr[0]\n        assert g.num_edges(etype_g) == g_r.num_edges(etype_gr)\n    for ntype in g.ntypes:\n        assert g.num_nodes(ntype) == g_r.num_nodes(ntype)\n    assert len(g_r.nodes[\"user\"].data) == 0\n    assert len(g_r.nodes[\"game\"].data) == 0\n\n    g_r = dgl.reverse(g, copy_ndata=True, copy_edata=True)\n    print(g_r)\n    for etype_g, etype_gr in zip(g.canonical_etypes, g_r.canonical_etypes):\n        assert etype_g[0] == etype_gr[2]\n        assert etype_g[1] == etype_gr[1]\n        assert etype_g[2] == etype_gr[0]\n        assert g.num_edges(etype_g) == g_r.num_edges(etype_gr)\n    
assert F.array_equal(\n        g.edges[\"follows\"].data[\"h\"], g_r.edges[\"follows\"].data[\"h\"]\n    )\n    assert F.array_equal(\n        g.edges[\"follows\"].data[\"hh\"], g_r.edges[\"follows\"].data[\"hh\"]\n    )\n\n    # add new node feature to g_r\n    g_r.nodes[\"user\"].data[\"hhh\"] = F.tensor([0, 1, 2, 3, 4])\n    assert (\"hhh\" in g.nodes[\"user\"].data) is False\n    assert (\"hhh\" in g_r.nodes[\"user\"].data) is True\n\n    # add new edge feature to g_r\n    g_r.edges[\"follows\"].data[\"hhh\"] = F.tensor([1, 2, 3, 2, 0, 0, 1])\n    assert (\"hhh\" in g.edges[\"follows\"].data) is False\n    assert (\"hhh\" in g_r.edges[\"follows\"].data) is True\n\n\n@parametrize_idtype\ndef test_reverse_shared_frames(idtype):\n    g = dgl.graph([])\n    g = g.astype(idtype).to(F.ctx())\n    g.add_nodes(3)\n    g.add_edges([0, 1, 2], [1, 2, 1])\n    g.ndata[\"h\"] = F.tensor([[0.0], [1.0], [2.0]])\n    g.edata[\"h\"] = F.tensor([[3.0], [4.0], [5.0]])\n\n    rg = g.reverse(copy_ndata=True, copy_edata=True)\n    assert F.allclose(g.ndata[\"h\"], rg.ndata[\"h\"])\n    assert F.allclose(g.edata[\"h\"], rg.edata[\"h\"])\n    assert F.allclose(\n        g.edges[[0, 2], [1, 1]].data[\"h\"], rg.edges[[1, 1], [0, 2]].data[\"h\"]\n    )\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_to_bidirected():\n    # homogeneous graph\n    elist = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 1), (2, 2)]\n    num_edges = 7\n    g = dgl.graph(tuple(zip(*elist)))\n    elist.append((1, 2))\n    elist = set(elist)\n    big = dgl.to_bidirected(g)\n    assert big.num_edges() == num_edges\n    src, dst = big.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == set(elist)\n\n    # heterogeneous graph\n    elist1 = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 1), (2, 2)]\n    elist2 = [(0, 0), (0, 1)]\n    g = dgl.heterograph(\n        {\n            (\"user\", \"wins\", \"user\"): tuple(zip(*elist1)),\n            
(\"user\", \"follows\", \"user\"): tuple(zip(*elist2)),\n        }\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.ones((3, 1))\n    elist1.append((1, 2))\n    elist1 = set(elist1)\n    elist2.append((1, 0))\n    elist2 = set(elist2)\n    big = dgl.to_bidirected(g)\n    assert big.num_edges(\"wins\") == 7\n    assert big.num_edges(\"follows\") == 3\n    src, dst = big.edges(etype=\"wins\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == set(elist1)\n    src, dst = big.edges(etype=\"follows\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == set(elist2)\n\n    big = dgl.to_bidirected(g, copy_ndata=True)\n    assert F.array_equal(g.nodes[\"user\"].data[\"h\"], big.nodes[\"user\"].data[\"h\"])\n\n\ndef test_add_reverse_edges():\n    # homogeneous graph\n    g = dgl.graph((F.tensor([0, 1, 3, 1]), F.tensor([1, 2, 0, 2])))\n    g.ndata[\"h\"] = F.tensor([[0.0], [1.0], [2.0], [1.0]])\n    g.edata[\"h\"] = F.tensor([[3.0], [4.0], [5.0], [6.0]])\n    bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True)\n    u, v = g.edges()\n    ub, vb = bg.edges()\n    assert F.array_equal(F.cat([u, v], dim=0), ub)\n    assert F.array_equal(F.cat([v, u], dim=0), vb)\n    assert F.array_equal(g.ndata[\"h\"], bg.ndata[\"h\"])\n    assert F.array_equal(\n        F.cat([g.edata[\"h\"], g.edata[\"h\"]], dim=0), bg.edata[\"h\"]\n    )\n    bg.ndata[\"hh\"] = F.tensor([[0.0], [1.0], [2.0], [1.0]])\n    assert (\"hh\" in g.ndata) is False\n    bg.edata[\"hh\"] = F.tensor(\n        [[0.0], [1.0], [2.0], [1.0], [0.0], [1.0], [2.0], [1.0]]\n    )\n    assert (\"hh\" in g.edata) is False\n\n    # donot share ndata and edata\n    bg = dgl.add_reverse_edges(g, copy_ndata=False, copy_edata=False)\n    ub, vb = bg.edges()\n    assert F.array_equal(F.cat([u, v], dim=0), ub)\n    assert F.array_equal(F.cat([v, u], dim=0), vb)\n    assert (\"h\" in bg.ndata) is False\n    assert (\"h\" in bg.edata) is False\n\n    # zero edge 
graph\n    g = dgl.graph(([], []))\n    bg = dgl.add_reverse_edges(\n        g, copy_ndata=True, copy_edata=True, exclude_self=False\n    )\n\n    # heterogeneous graph\n    g = dgl.heterograph(\n        {\n            (\"user\", \"wins\", \"user\"): (\n                F.tensor([0, 2, 0, 2, 2]),\n                F.tensor([1, 1, 2, 1, 0]),\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                F.tensor([1, 2, 1]),\n                F.tensor([2, 1, 1]),\n            ),\n            (\"user\", \"follows\", \"user\"): (\n                F.tensor([1, 2, 1]),\n                F.tensor([0, 0, 0]),\n            ),\n        }\n    )\n    g.nodes[\"game\"].data[\"hv\"] = F.ones((3, 1))\n    g.nodes[\"user\"].data[\"hv\"] = F.ones((3, 1))\n    g.edges[\"wins\"].data[\"h\"] = F.tensor([0, 1, 2, 3, 4])\n    bg = dgl.add_reverse_edges(\n        g, copy_ndata=True, copy_edata=True, ignore_bipartite=True\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"hv\"], bg.nodes[\"game\"].data[\"hv\"]\n    )\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"hv\"], bg.nodes[\"user\"].data[\"hv\"]\n    )\n    u, v = g.all_edges(order=\"eid\", etype=(\"user\", \"wins\", \"user\"))\n    ub, vb = bg.all_edges(order=\"eid\", etype=(\"user\", \"wins\", \"user\"))\n    assert F.array_equal(F.cat([u, v], dim=0), ub)\n    assert F.array_equal(F.cat([v, u], dim=0), vb)\n    assert F.array_equal(\n        F.cat([g.edges[\"wins\"].data[\"h\"], g.edges[\"wins\"].data[\"h\"]], dim=0),\n        bg.edges[\"wins\"].data[\"h\"],\n    )\n    u, v = g.all_edges(order=\"eid\", etype=(\"user\", \"follows\", \"user\"))\n    ub, vb = bg.all_edges(order=\"eid\", etype=(\"user\", \"follows\", \"user\"))\n    assert F.array_equal(F.cat([u, v], dim=0), ub)\n    assert F.array_equal(F.cat([v, u], dim=0), vb)\n    u, v = g.all_edges(order=\"eid\", etype=(\"user\", \"plays\", \"game\"))\n    ub, vb = bg.all_edges(order=\"eid\", etype=(\"user\", \"plays\", \"game\"))\n  
  assert F.array_equal(u, ub)\n    assert F.array_equal(v, vb)\n    assert set(bg.edges[\"plays\"].data.keys()) == {dgl.EID}\n    assert set(bg.edges[\"follows\"].data.keys()) == {dgl.EID}\n\n    # donot share ndata and edata\n    bg = dgl.add_reverse_edges(\n        g, copy_ndata=False, copy_edata=False, ignore_bipartite=True\n    )\n    assert len(bg.edges[\"wins\"].data) == 0\n    assert len(bg.edges[\"plays\"].data) == 0\n    assert len(bg.edges[\"follows\"].data) == 0\n    assert len(bg.nodes[\"game\"].data) == 0\n    assert len(bg.nodes[\"user\"].data) == 0\n    u, v = g.all_edges(order=\"eid\", etype=(\"user\", \"wins\", \"user\"))\n    ub, vb = bg.all_edges(order=\"eid\", etype=(\"user\", \"wins\", \"user\"))\n    assert F.array_equal(F.cat([u, v], dim=0), ub)\n    assert F.array_equal(F.cat([v, u], dim=0), vb)\n    u, v = g.all_edges(order=\"eid\", etype=(\"user\", \"follows\", \"user\"))\n    ub, vb = bg.all_edges(order=\"eid\", etype=(\"user\", \"follows\", \"user\"))\n    assert F.array_equal(F.cat([u, v], dim=0), ub)\n    assert F.array_equal(F.cat([v, u], dim=0), vb)\n    u, v = g.all_edges(order=\"eid\", etype=(\"user\", \"plays\", \"game\"))\n    ub, vb = bg.all_edges(order=\"eid\", etype=(\"user\", \"plays\", \"game\"))\n    assert F.array_equal(u, ub)\n    assert F.array_equal(v, vb)\n\n    # test the case when some nodes have zero degree\n    # homogeneous graph\n    g = dgl.graph((F.tensor([0, 1, 3, 1]), F.tensor([1, 2, 0, 2])), num_nodes=6)\n    g.ndata[\"h\"] = F.tensor([[0.0], [1.0], [2.0], [1.0], [1.0], [1.0]])\n    g.edata[\"h\"] = F.tensor([[3.0], [4.0], [5.0], [6.0]])\n    bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True)\n    assert g.num_nodes() == bg.num_nodes()\n    assert F.array_equal(g.ndata[\"h\"], bg.ndata[\"h\"])\n    assert F.array_equal(\n        F.cat([g.edata[\"h\"], g.edata[\"h\"]], dim=0), bg.edata[\"h\"]\n    )\n\n    # heterogeneous graph\n    g = dgl.heterograph(\n        {\n            (\"user\", 
\"wins\", \"user\"): (\n                F.tensor([0, 2, 0, 2, 2]),\n                F.tensor([1, 1, 2, 1, 0]),\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                F.tensor([1, 2, 1]),\n                F.tensor([2, 1, 1]),\n            ),\n            (\"user\", \"follows\", \"user\"): (\n                F.tensor([1, 2, 1]),\n                F.tensor([0, 0, 0]),\n            ),\n        },\n        num_nodes_dict={\"user\": 5, \"game\": 3},\n    )\n    g.nodes[\"game\"].data[\"hv\"] = F.ones((3, 1))\n    g.nodes[\"user\"].data[\"hv\"] = F.ones((5, 1))\n    g.edges[\"wins\"].data[\"h\"] = F.tensor([0, 1, 2, 3, 4])\n    bg = dgl.add_reverse_edges(\n        g, copy_ndata=True, copy_edata=True, ignore_bipartite=True\n    )\n    assert g.num_nodes(\"user\") == bg.num_nodes(\"user\")\n    assert g.num_nodes(\"game\") == bg.num_nodes(\"game\")\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"hv\"], bg.nodes[\"game\"].data[\"hv\"]\n    )\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"hv\"], bg.nodes[\"user\"].data[\"hv\"]\n    )\n    assert F.array_equal(\n        F.cat([g.edges[\"wins\"].data[\"h\"], g.edges[\"wins\"].data[\"h\"]], dim=0),\n        bg.edges[\"wins\"].data[\"h\"],\n    )\n\n    # test exclude_self\n    g = dgl.heterograph(\n        {\n            (\"A\", \"r1\", \"A\"): (F.tensor([0, 0, 1, 1]), F.tensor([0, 1, 1, 2])),\n            (\"A\", \"r2\", \"A\"): (F.tensor([0, 1]), F.tensor([1, 2])),\n        }\n    )\n    g.edges[\"r1\"].data[\"h\"] = F.tensor([0, 1, 2, 3])\n    rg = dgl.add_reverse_edges(g, copy_edata=True, exclude_self=True)\n    assert rg.num_edges(\"r1\") == 6\n    assert rg.num_edges(\"r2\") == 4\n    assert F.array_equal(rg.edges[\"r1\"].data[\"h\"], F.tensor([0, 1, 2, 3, 1, 3]))\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_simple_graph():\n    elist = [(0, 1), (0, 2), (1, 2), (0, 1)]\n    g = dgl.graph(elist)\n    assert 
g.is_multigraph\n    sg = dgl.to_simple(g)\n    assert not sg.is_multigraph\n    assert sg.num_edges() == 3\n    src, dst = sg.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == set(elist)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef _test_bidirected_graph():\n    def _test(in_readonly, out_readonly):\n        elist = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 1), (2, 2)]\n        num_edges = 7\n        g = dgl.graph(elist)\n        elist.append((1, 2))\n        elist = set(elist)\n        big = dgl.to_bidirected_stale(g, out_readonly)\n        assert big.num_edges() == num_edges\n        src, dst = big.edges()\n        eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n        assert eset == set(elist)\n\n    _test(True, True)\n    _test(True, False)\n    _test(False, True)\n    _test(False, False)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_khop_graph():\n    N = 20\n    feat = F.randn((N, 5))\n\n    def _test(g):\n        for k in range(4):\n            g_k = dgl.khop_graph(g, k)\n            # use original graph to do message passing for k times.\n            g.ndata[\"h\"] = feat\n            for _ in range(k):\n                g.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n            h_0 = g.ndata.pop(\"h\")\n            # use k-hop graph to do message passing for one time.\n            g_k.ndata[\"h\"] = feat\n            g_k.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n            h_1 = g_k.ndata.pop(\"h\")\n            assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3)\n\n    # Test for random undirected graphs\n    g = dgl.from_networkx(nx.erdos_renyi_graph(N, 0.3))\n    _test(g)\n    # Test for random directed graphs\n    g = dgl.from_networkx(nx.erdos_renyi_graph(N, 0.3, directed=True))\n    _test(g)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not 
implemented\")\ndef test_khop_adj():\n    N = 20\n    feat = F.randn((N, 5))\n    g = dgl.from_networkx(nx.erdos_renyi_graph(N, 0.3, directed=True))\n    for k in range(3):\n        adj = F.tensor(F.swapaxes(dgl.khop_adj(g, k), 0, 1))\n        # use original graph to do message passing for k times.\n        g.ndata[\"h\"] = feat\n        for _ in range(k):\n            g.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n        h_0 = g.ndata.pop(\"h\")\n        # use k-hop adj to do message passing for one time.\n        h_1 = F.matmul(adj, feat)\n        assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_laplacian_lambda_max():\n    N = 20\n    eps = 1e-6\n    # test DGLGraph\n    g = dgl.from_networkx(nx.erdos_renyi_graph(N, 0.3))\n    l_max = dgl.laplacian_lambda_max(g)\n    assert l_max[0] < 2 + eps\n    # test batched DGLGraph\n    \"\"\"\n    N_arr = [20, 30, 10, 12]\n    bg = dgl.batch([\n        dgl.from_networkx(nx.erdos_renyi_graph(N, 0.3))\n        for N in N_arr\n    ])\n    l_max_arr = dgl.laplacian_lambda_max(bg)\n    assert len(l_max_arr) == len(N_arr)\n    for l_max in l_max_arr:\n        assert l_max < 2 + eps\n    \"\"\"\n\n\ndef create_large_graph(num_nodes, idtype=F.int64):\n    row = np.random.choice(num_nodes, num_nodes * 10)\n    col = np.random.choice(num_nodes, num_nodes * 10)\n    spm = spsp.coo_matrix((np.ones(len(row)), (row, col)))\n    spm.sum_duplicates()\n\n    return dgl.from_scipy(spm, idtype=idtype)\n\n\n# Disabled since everything will be on heterogeneous graphs\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\ndef test_partition_with_halo():\n    g = create_large_graph(1000)\n    node_part = np.random.choice(4, g.num_nodes())\n    subgs, _, _ = dgl.transforms.partition_graph_with_halo(\n        g, node_part, 2, reshuffle=True\n    )\n    for part_id, subg in subgs.items():\n        
node_ids = np.nonzero(node_part == part_id)[0]\n        lnode_ids = np.nonzero(F.asnumpy(subg.ndata[\"inner_node\"]))[0]\n        orig_nids = F.asnumpy(subg.ndata[\"orig_id\"])[lnode_ids]\n        assert np.all(np.sort(orig_nids) == node_ids)\n        assert np.all(\n            F.asnumpy(subg.in_degrees(lnode_ids))\n            == F.asnumpy(g.in_degrees(orig_nids))\n        )\n        assert np.all(\n            F.asnumpy(subg.out_degrees(lnode_ids))\n            == F.asnumpy(g.out_degrees(orig_nids))\n        )\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"METIS doesn't support GPU\"\n)\n@parametrize_idtype\ndef test_metis_partition(idtype):\n    # TODO(zhengda) Metis fails to partition a small graph.\n    g = create_large_graph(1000, idtype=idtype)\n    if idtype == F.int64:\n        check_metis_partition(g, 0)\n        check_metis_partition(g, 1)\n        check_metis_partition(g, 2)\n        check_metis_partition_with_constraint(g)\n    else:\n        assert_fail = False\n        try:\n            check_metis_partition(g, 1)\n        except:\n            assert_fail = True\n        assert assert_fail\n\n\ndef check_metis_partition_with_constraint(g):\n    ntypes = np.zeros((g.num_nodes(),), dtype=np.int32)\n    ntypes[0 : int(g.num_nodes() / 4)] = 1\n    ntypes[int(g.num_nodes() * 3 / 4) :] = 2\n    subgs = dgl.transforms.metis_partition(\n        g, 4, extra_cached_hops=1, balance_ntypes=ntypes\n    )\n    if subgs is not None:\n        for i in subgs:\n            subg = subgs[i]\n            parent_nids = F.asnumpy(subg.ndata[dgl.NID])\n            sub_ntypes = ntypes[parent_nids]\n            print(\"type0:\", np.sum(sub_ntypes == 0))\n            print(\"type1:\", np.sum(sub_ntypes == 1))\n            print(\"type2:\", np.sum(sub_ntypes == 2))\n    subgs = dgl.transforms.metis_partition(\n        g, 4, extra_cached_hops=1, balance_ntypes=ntypes, 
balance_edges=True\n    )\n    if subgs is not None:\n        for i in subgs:\n            subg = subgs[i]\n            parent_nids = F.asnumpy(subg.ndata[dgl.NID])\n            sub_ntypes = ntypes[parent_nids]\n            print(\"type0:\", np.sum(sub_ntypes == 0))\n            print(\"type1:\", np.sum(sub_ntypes == 1))\n            print(\"type2:\", np.sum(sub_ntypes == 2))\n\n\ndef check_metis_partition(g, extra_hops):\n    subgs = dgl.transforms.metis_partition(g, 4, extra_cached_hops=extra_hops)\n    num_inner_nodes = 0\n    num_inner_edges = 0\n    if subgs is not None:\n        for part_id, subg in subgs.items():\n            lnode_ids = np.nonzero(F.asnumpy(subg.ndata[\"inner_node\"]))[0]\n            ledge_ids = np.nonzero(F.asnumpy(subg.edata[\"inner_edge\"]))[0]\n            num_inner_nodes += len(lnode_ids)\n            num_inner_edges += len(ledge_ids)\n            assert np.sum(F.asnumpy(subg.ndata[\"part_id\"]) == part_id) == len(\n                lnode_ids\n            )\n        assert num_inner_nodes == g.num_nodes()\n        print(g.num_edges() - num_inner_edges)\n\n    if extra_hops == 0:\n        return\n\n    # partitions with node reshuffling\n    subgs = dgl.transforms.metis_partition(\n        g, 4, extra_cached_hops=extra_hops, reshuffle=True\n    )\n    num_inner_nodes = 0\n    num_inner_edges = 0\n    edge_cnts = np.zeros((g.num_edges(),))\n    if subgs is not None:\n        for part_id, subg in subgs.items():\n            lnode_ids = np.nonzero(F.asnumpy(subg.ndata[\"inner_node\"]))[0]\n            ledge_ids = np.nonzero(F.asnumpy(subg.edata[\"inner_edge\"]))[0]\n            num_inner_nodes += len(lnode_ids)\n            num_inner_edges += len(ledge_ids)\n            assert np.sum(F.asnumpy(subg.ndata[\"part_id\"]) == part_id) == len(\n                lnode_ids\n            )\n            nids = F.asnumpy(subg.ndata[dgl.NID])\n\n            # ensure the local node Ids are contiguous.\n            parent_ids = 
F.asnumpy(subg.ndata[dgl.NID])\n            parent_ids = parent_ids[: len(lnode_ids)]\n            assert np.all(\n                parent_ids == np.arange(parent_ids[0], parent_ids[-1] + 1)\n            )\n\n            # count the local edges.\n            parent_ids = F.asnumpy(subg.edata[dgl.EID])[ledge_ids]\n            edge_cnts[parent_ids] += 1\n\n            orig_ids = subg.ndata[\"orig_id\"]\n            inner_node = F.asnumpy(subg.ndata[\"inner_node\"])\n            for nid in range(subg.num_nodes()):\n                neighs = subg.predecessors(nid)\n                old_neighs1 = F.gather_row(orig_ids, neighs)\n                old_nid = F.asnumpy(orig_ids[nid])\n                old_neighs2 = g.predecessors(old_nid)\n                # If this is an inner node, it should have the full neighborhood.\n                if inner_node[nid]:\n                    assert np.all(\n                        np.sort(F.asnumpy(old_neighs1))\n                        == np.sort(F.asnumpy(old_neighs2))\n                    )\n        # Normally, local edges are only counted once.\n        assert np.all(edge_cnts == 1)\n\n        assert num_inner_nodes == g.num_nodes()\n        print(g.num_edges() - num_inner_edges)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"It doesn't support GPU\"\n)\ndef test_reorder_nodes():\n    g = create_large_graph(1000)\n    new_nids = np.random.permutation(g.num_nodes())\n    # TODO(zhengda) we need to test both CSR and COO.\n    new_g = dgl.partition.reorder_nodes(g, new_nids)\n    new_in_deg = new_g.in_degrees()\n    new_out_deg = new_g.out_degrees()\n    in_deg = g.in_degrees()\n    out_deg = g.out_degrees()\n    new_in_deg1 = F.scatter_row(in_deg, F.tensor(new_nids), in_deg)\n    new_out_deg1 = F.scatter_row(out_deg, F.tensor(new_nids), out_deg)\n    assert np.all(F.asnumpy(new_in_deg == new_in_deg1))\n    assert np.all(F.asnumpy(new_out_deg == new_out_deg1))\n    orig_ids = F.asnumpy(new_g.ndata[\"orig_id\"])\n    
for nid in range(g.num_nodes()):\n        neighs = F.asnumpy(g.successors(nid))\n        new_neighs1 = new_nids[neighs]\n        new_nid = new_nids[nid]\n        new_neighs2 = new_g.successors(new_nid)\n        assert np.all(np.sort(new_neighs1) == np.sort(F.asnumpy(new_neighs2)))\n\n    for nid in range(new_g.num_nodes()):\n        neighs = F.asnumpy(new_g.successors(nid))\n        old_neighs1 = orig_ids[neighs]\n        old_nid = orig_ids[nid]\n        old_neighs2 = g.successors(old_nid)\n        assert np.all(np.sort(old_neighs1) == np.sort(F.asnumpy(old_neighs2)))\n\n        neighs = F.asnumpy(new_g.predecessors(nid))\n        old_neighs1 = orig_ids[neighs]\n        old_nid = orig_ids[nid]\n        old_neighs2 = g.predecessors(old_nid)\n        assert np.all(np.sort(old_neighs1) == np.sort(F.asnumpy(old_neighs2)))\n\n\n@parametrize_idtype\ndef test_compact(idtype):\n    g1 = dgl.heterograph(\n        {\n            (\"user\", \"follow\", \"user\"): ([1, 3], [3, 5]),\n            (\"user\", \"plays\", \"game\"): ([2, 3, 2], [4, 4, 5]),\n            (\"game\", \"wished-by\", \"user\"): ([6, 5], [7, 7]),\n        },\n        {\"user\": 20, \"game\": 10},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n\n    g2 = dgl.heterograph(\n        {\n            (\"game\", \"clicked-by\", \"user\"): ([3], [1]),\n            (\"user\", \"likes\", \"user\"): ([1, 8], [8, 9]),\n        },\n        {\"user\": 20, \"game\": 10},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n\n    g3 = dgl.heterograph(\n        {(\"user\", \"_E\", \"user\"): ((0, 1), (1, 2))},\n        {\"user\": 10},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g4 = dgl.heterograph(\n        {(\"user\", \"_E\", \"user\"): ((1, 3), (3, 5))},\n        {\"user\": 10},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n\n    def _check(g, new_g, induced_nodes):\n        assert g.ntypes == new_g.ntypes\n        assert g.canonical_etypes == new_g.canonical_etypes\n\n        
for ntype in g.ntypes:\n            assert -1 not in induced_nodes[ntype]\n\n        for etype in g.canonical_etypes:\n            g_src, g_dst = g.all_edges(order=\"eid\", etype=etype)\n            g_src = F.asnumpy(g_src)\n            g_dst = F.asnumpy(g_dst)\n            new_g_src, new_g_dst = new_g.all_edges(order=\"eid\", etype=etype)\n            new_g_src_mapped = induced_nodes[etype[0]][F.asnumpy(new_g_src)]\n            new_g_dst_mapped = induced_nodes[etype[2]][F.asnumpy(new_g_dst)]\n            assert (g_src == new_g_src_mapped).all()\n            assert (g_dst == new_g_dst_mapped).all()\n\n    # Test default\n    new_g1 = dgl.compact_graphs(g1)\n    induced_nodes = {\n        ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in new_g1.ntypes\n    }\n    induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()}\n    assert new_g1.idtype == idtype\n    assert set(induced_nodes[\"user\"]) == set([1, 3, 5, 2, 7])\n    assert set(induced_nodes[\"game\"]) == set([4, 5, 6])\n    _check(g1, new_g1, induced_nodes)\n\n    # Test with always_preserve given a dict\n    new_g1 = dgl.compact_graphs(\n        g1, always_preserve={\"game\": F.tensor([4, 7], idtype)}\n    )\n    assert new_g1.idtype == idtype\n    induced_nodes = {\n        ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in new_g1.ntypes\n    }\n    induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()}\n    assert set(induced_nodes[\"user\"]) == set([1, 3, 5, 2, 7])\n    assert set(induced_nodes[\"game\"]) == set([4, 5, 6, 7])\n    _check(g1, new_g1, induced_nodes)\n\n    # Test with always_preserve given a tensor\n    new_g3 = dgl.compact_graphs(g3, always_preserve=F.tensor([1, 7], idtype))\n    induced_nodes = {\n        ntype: new_g3.nodes[ntype].data[dgl.NID] for ntype in new_g3.ntypes\n    }\n    induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()}\n\n    assert new_g3.idtype == idtype\n    assert set(induced_nodes[\"user\"]) == set([0, 1, 2, 7])\n    
_check(g3, new_g3, induced_nodes)\n\n    # Test multiple graphs\n    new_g1, new_g2 = dgl.compact_graphs([g1, g2])\n    induced_nodes = {\n        ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in new_g1.ntypes\n    }\n    induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()}\n    assert new_g1.idtype == idtype\n    assert new_g2.idtype == idtype\n    assert set(induced_nodes[\"user\"]) == set([1, 3, 5, 2, 7, 8, 9])\n    assert set(induced_nodes[\"game\"]) == set([3, 4, 5, 6])\n    _check(g1, new_g1, induced_nodes)\n    _check(g2, new_g2, induced_nodes)\n\n    # Test multiple graphs with always_preserve given a dict\n    new_g1, new_g2 = dgl.compact_graphs(\n        [g1, g2], always_preserve={\"game\": F.tensor([4, 7], dtype=idtype)}\n    )\n    induced_nodes = {\n        ntype: new_g1.nodes[ntype].data[dgl.NID] for ntype in new_g1.ntypes\n    }\n    induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()}\n    assert new_g1.idtype == idtype\n    assert new_g2.idtype == idtype\n    assert set(induced_nodes[\"user\"]) == set([1, 3, 5, 2, 7, 8, 9])\n    assert set(induced_nodes[\"game\"]) == set([3, 4, 5, 6, 7])\n    _check(g1, new_g1, induced_nodes)\n    _check(g2, new_g2, induced_nodes)\n\n    # Test multiple graphs with always_preserve given a tensor\n    new_g3, new_g4 = dgl.compact_graphs(\n        [g3, g4], always_preserve=F.tensor([1, 7], dtype=idtype)\n    )\n    induced_nodes = {\n        ntype: new_g3.nodes[ntype].data[dgl.NID] for ntype in new_g3.ntypes\n    }\n    induced_nodes = {k: F.asnumpy(v) for k, v in induced_nodes.items()}\n\n    assert new_g3.idtype == idtype\n    assert new_g4.idtype == idtype\n\n    assert set(induced_nodes[\"user\"]) == set([0, 1, 2, 3, 5, 7])\n    _check(g3, new_g3, induced_nodes)\n    _check(g4, new_g4, induced_nodes)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"GPU to simple not implemented\"\n)\n@parametrize_idtype\ndef test_to_simple(idtype):\n    # homogeneous 
graph\n    g = dgl.graph((F.tensor([0, 1, 2, 1]), F.tensor([1, 2, 0, 2])))\n    g.ndata[\"h\"] = F.tensor([[0.0], [1.0], [2.0]])\n    g.edata[\"h\"] = F.tensor([[3.0], [4.0], [5.0], [6.0]])\n    sg, wb = dgl.to_simple(g, writeback_mapping=True)\n    u, v = g.all_edges(form=\"uv\", order=\"eid\")\n    u = F.asnumpy(u).tolist()\n    v = F.asnumpy(v).tolist()\n    uv = list(zip(u, v))\n    eid_map = F.asnumpy(wb)\n\n    su, sv = sg.all_edges(form=\"uv\", order=\"eid\")\n    su = F.asnumpy(su).tolist()\n    sv = F.asnumpy(sv).tolist()\n    suv = list(zip(su, sv))\n    sc = F.asnumpy(sg.edata[\"count\"])\n    assert set(uv) == set(suv)\n    for i, e in enumerate(suv):\n        assert sc[i] == sum(e == _e for _e in uv)\n    for i, e in enumerate(uv):\n        assert eid_map[i] == suv.index(e)\n    # shared ndata\n    assert F.array_equal(sg.ndata[\"h\"], g.ndata[\"h\"])\n    assert \"h\" not in sg.edata\n    # new ndata to sg\n    sg.ndata[\"hh\"] = F.tensor([[0.0], [1.0], [2.0]])\n    assert \"hh\" not in g.ndata\n\n    sg = dgl.to_simple(g, writeback_mapping=False, copy_ndata=False)\n    assert \"h\" not in sg.ndata\n    assert \"h\" not in sg.edata\n\n    # test coalesce edge feature\n    sg = dgl.to_simple(g, copy_edata=True, aggregator=\"arbitrary\")\n    assert F.allclose(sg.edata[\"h\"][1], F.tensor([4.0]))\n    sg = dgl.to_simple(g, copy_edata=True, aggregator=\"sum\")\n    assert F.allclose(sg.edata[\"h\"][1], F.tensor([10.0]))\n    sg = dgl.to_simple(g, copy_edata=True, aggregator=\"mean\")\n    assert F.allclose(sg.edata[\"h\"][1], F.tensor([5.0]))\n\n    # heterogeneous graph\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follow\", \"user\"): (\n                [0, 1, 2, 1, 1, 1],\n                [1, 3, 2, 3, 4, 4],\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                [3, 2, 1, 1, 3, 2, 2],\n                [5, 3, 4, 4, 5, 3, 3],\n            ),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n  
  g.nodes[\"user\"].data[\"h\"] = F.tensor([0, 1, 2, 3, 4])\n    g.nodes[\"user\"].data[\"hh\"] = F.tensor([0, 1, 2, 3, 4])\n    g.edges[\"follow\"].data[\"h\"] = F.tensor([0, 1, 2, 3, 4, 5])\n    sg, wb = dgl.to_simple(\n        g, return_counts=\"weights\", writeback_mapping=True, copy_edata=True\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.tensor([0, 1, 2, 3, 4, 5])\n\n    for etype in g.canonical_etypes:\n        u, v = g.all_edges(form=\"uv\", order=\"eid\", etype=etype)\n        u = F.asnumpy(u).tolist()\n        v = F.asnumpy(v).tolist()\n        uv = list(zip(u, v))\n        eid_map = F.asnumpy(wb[etype])\n\n        su, sv = sg.all_edges(form=\"uv\", order=\"eid\", etype=etype)\n        su = F.asnumpy(su).tolist()\n        sv = F.asnumpy(sv).tolist()\n        suv = list(zip(su, sv))\n        sw = F.asnumpy(sg.edges[etype].data[\"weights\"])\n\n        assert set(uv) == set(suv)\n        for i, e in enumerate(suv):\n            assert sw[i] == sum(e == _e for _e in uv)\n        for i, e in enumerate(uv):\n            assert eid_map[i] == suv.index(e)\n    # shared ndata\n    assert F.array_equal(sg.nodes[\"user\"].data[\"h\"], g.nodes[\"user\"].data[\"h\"])\n    assert F.array_equal(\n        sg.nodes[\"user\"].data[\"hh\"], g.nodes[\"user\"].data[\"hh\"]\n    )\n    assert \"h\" not in sg.nodes[\"game\"].data\n    # new ndata to sg\n    sg.nodes[\"user\"].data[\"hhh\"] = F.tensor([0, 1, 2, 3, 4])\n    assert \"hhh\" not in g.nodes[\"user\"].data\n    # share edata\n    feat_idx = F.asnumpy(wb[(\"user\", \"follow\", \"user\")])\n    _, indices = np.unique(feat_idx, return_index=True)\n    assert np.array_equal(\n        F.asnumpy(sg.edges[\"follow\"].data[\"h\"]),\n        F.asnumpy(g.edges[\"follow\"].data[\"h\"])[indices],\n    )\n\n    sg = dgl.to_simple(g, writeback_mapping=False, copy_ndata=False)\n    for ntype in g.ntypes:\n        assert g.num_nodes(ntype) == sg.num_nodes(ntype)\n    assert \"h\" not in sg.nodes[\"user\"].data\n    assert \"hh\" not 
in sg.nodes[\"user\"].data\n\n    # verify DGLGraph.edge_ids() after dgl.to_simple()\n    # in case ids are not initialized in underlying coo2csr()\n    u = F.tensor([0, 1, 2])\n    v = F.tensor([1, 2, 3])\n    eids = F.tensor([0, 1, 2])\n    g = dgl.graph((u, v))\n    assert F.array_equal(g.edge_ids(u, v), eids)\n    sg = dgl.to_simple(g)\n    assert F.array_equal(sg.edge_ids(u, v), eids)\n\n\n@unittest.skipIf(F._default_context_str == \"gpu\", reason=\"GPU not implemented\")\n@parametrize_idtype\ndef test_remove_edges(idtype):\n    def check(g1, etype, g, edges_removed):\n        src, dst, eid = g.edges(etype=etype, form=\"all\")\n        src1, dst1 = g1.edges(etype=etype, order=\"eid\")\n        if etype is not None:\n            eid1 = g1.edges[etype].data[dgl.EID]\n        else:\n            eid1 = g1.edata[dgl.EID]\n        src1 = F.asnumpy(src1)\n        dst1 = F.asnumpy(dst1)\n        eid1 = F.asnumpy(eid1)\n        src = F.asnumpy(src)\n        dst = F.asnumpy(dst)\n        eid = F.asnumpy(eid)\n        sde_set = set(zip(src, dst, eid))\n\n        for s, d, e in zip(src1, dst1, eid1):\n            assert (s, d, e) in sde_set\n        assert not np.isin(edges_removed, eid1).any()\n        assert g1.idtype == g.idtype\n\n    for fmt in [\"coo\", \"csr\", \"csc\"]:\n        for edges_to_remove in [[2], [2, 2], [3, 2], [1, 3, 1, 2]]:\n            g = dgl.graph(([0, 2, 1, 3], [1, 3, 2, 4]), idtype=idtype).formats(\n                fmt\n            )\n            g1 = dgl.remove_edges(g, F.tensor(edges_to_remove, idtype))\n            check(g1, None, g, edges_to_remove)\n\n            g = dgl.from_scipy(\n                spsp.csr_matrix(\n                    ([1, 1, 1, 1], ([0, 2, 1, 3], [1, 3, 2, 4])), shape=(5, 5)\n                ),\n                idtype=idtype,\n            ).formats(fmt)\n            g1 = dgl.remove_edges(g, F.tensor(edges_to_remove, idtype))\n            check(g1, None, g, edges_to_remove)\n\n    g = dgl.heterograph(\n        {\n         
   (\"A\", \"AA\", \"A\"): ([0, 2, 1, 3], [1, 3, 2, 4]),\n            (\"A\", \"AB\", \"B\"): ([0, 1, 3, 1], [1, 3, 5, 6]),\n            (\"B\", \"BA\", \"A\"): ([2, 3], [3, 2]),\n        },\n        idtype=idtype,\n    )\n    g2 = dgl.remove_edges(\n        g,\n        {\n            \"AA\": F.tensor([2], idtype),\n            \"AB\": F.tensor([3], idtype),\n            \"BA\": F.tensor([1], idtype),\n        },\n    )\n    check(g2, \"AA\", g, [2])\n    check(g2, \"AB\", g, [3])\n    check(g2, \"BA\", g, [1])\n\n    g3 = dgl.remove_edges(\n        g,\n        {\n            \"AA\": F.tensor([], idtype),\n            \"AB\": F.tensor([3], idtype),\n            \"BA\": F.tensor([1], idtype),\n        },\n    )\n    check(g3, \"AA\", g, [])\n    check(g3, \"AB\", g, [3])\n    check(g3, \"BA\", g, [1])\n\n    g4 = dgl.remove_edges(g, {\"AB\": F.tensor([3, 1, 2, 0], idtype)})\n    check(g4, \"AA\", g, [])\n    check(g4, \"AB\", g, [3, 1, 2, 0])\n    check(g4, \"BA\", g, [])\n\n\n@parametrize_idtype\ndef test_add_edges(idtype):\n    # homogeneous graph\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    u = 0\n    v = 1\n    g = dgl.add_edges(g, u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 3\n    u = [0]\n    v = [1]\n    g = dgl.add_edges(g, u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 4\n    u = F.tensor(u, dtype=idtype)\n    v = F.tensor(v, dtype=idtype)\n    g = dgl.add_edges(g, u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 5\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 0, 0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 1, 1, 1], dtype=idtype))\n    g = dgl.add_edges(g, [], [])\n    g = dgl.add_edges(g, 0, [])\n    g = dgl.add_edges(g, [], 0)\n    assert g.device == F.ctx()\n    assert g.num_nodes() == 3\n    assert 
g.num_edges() == 5\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 0, 0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 1, 1, 1], dtype=idtype))\n\n    # node id larger than current max node id\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    u = F.tensor([0, 1], dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    g = dgl.add_edges(g, u, v)\n    assert g.num_nodes() == 4\n    assert g.num_edges() == 4\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 2, 3], dtype=idtype))\n\n    # has data\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.copy_to(F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx())\n    g.edata[\"h\"] = F.copy_to(F.tensor([1, 1], dtype=idtype), ctx=F.ctx())\n    u = F.tensor([0, 1], dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    e_feat = {\n        \"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n        \"hh\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n    }\n    g = dgl.add_edges(g, u, v, e_feat)\n    assert g.num_nodes() == 4\n    assert g.num_edges() == 4\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 2, 3], dtype=idtype))\n    assert F.array_equal(g.ndata[\"h\"], F.tensor([1, 1, 1, 0], dtype=idtype))\n    assert F.array_equal(g.edata[\"h\"], F.tensor([1, 1, 2, 2], dtype=idtype))\n    assert F.array_equal(g.edata[\"hh\"], F.tensor([0, 0, 2, 2], dtype=idtype))\n\n    # zero data graph\n    g = dgl.graph(([], []), num_nodes=0, idtype=idtype, device=F.ctx())\n    u = F.tensor([0, 1], dtype=idtype)\n    v = F.tensor([2, 2], dtype=idtype)\n    e_feat = {\n        \"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n        \"hh\": 
F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n    }\n    g = dgl.add_edges(g, u, v, e_feat)\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 2\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2, 2], dtype=idtype))\n    assert F.array_equal(g.edata[\"h\"], F.tensor([2, 2], dtype=idtype))\n    assert F.array_equal(g.edata[\"hh\"], F.tensor([2, 2], dtype=idtype))\n\n    # bipartite graph\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    u = 0\n    v = 1\n    g = dgl.add_edges(g, u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes(\"user\") == 2\n    assert g.num_nodes(\"game\") == 3\n    assert g.num_edges() == 3\n    u = [0]\n    v = [1]\n    g = dgl.add_edges(g, u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes(\"user\") == 2\n    assert g.num_nodes(\"game\") == 3\n    assert g.num_edges() == 4\n    u = F.tensor(u, dtype=idtype)\n    v = F.tensor(v, dtype=idtype)\n    g = dgl.add_edges(g, u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes(\"user\") == 2\n    assert g.num_nodes(\"game\") == 3\n    assert g.num_edges() == 5\n    u, v = g.edges(form=\"uv\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 0, 0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 1, 1, 1], dtype=idtype))\n\n    # node id larger than current max node id\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    u = F.tensor([0, 2], dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    g = dgl.add_edges(g, u, v)\n    assert g.device == F.ctx()\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 4\n    assert g.num_edges() == 4\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 
1, 0, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 2, 3], dtype=idtype))\n\n    # has data\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.edata[\"h\"] = F.copy_to(F.tensor([1, 1], dtype=idtype), ctx=F.ctx())\n    u = F.tensor([0, 2], dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    e_feat = {\n        \"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n        \"hh\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx()),\n    }\n    g = dgl.add_edges(g, u, v, e_feat)\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 4\n    assert g.num_edges() == 4\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1, 2, 2, 3], dtype=idtype))\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 1, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2, 2, 0], dtype=idtype)\n    )\n    assert F.array_equal(g.edata[\"h\"], F.tensor([1, 1, 2, 2], dtype=idtype))\n    assert F.array_equal(g.edata[\"hh\"], F.tensor([0, 0, 2, 2], dtype=idtype))\n\n    # heterogeneous graph\n    g = create_test_heterograph3(idtype)\n    u = F.tensor([0, 2], dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    g = dgl.add_edges(g, u, v, etype=\"plays\")\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 4\n    assert g.num_nodes(\"developer\") == 2\n    assert g.num_edges(\"plays\") == 6\n    assert g.num_edges(\"develops\") == 2\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"plays\")\n    assert 
F.array_equal(u, F.tensor([0, 1, 1, 2, 0, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 0, 1, 1, 2, 3], dtype=idtype))\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2, 0, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.edges[\"plays\"].data[\"h\"], F.tensor([1, 1, 1, 1, 0, 0], dtype=idtype)\n    )\n\n    # add with feature\n    e_feat = {\"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())}\n    u = F.tensor([0, 2], dtype=idtype)\n    v = F.tensor([2, 3], dtype=idtype)\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2, 1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g = dgl.add_edges(g, u, v, data=e_feat, etype=\"develops\")\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 4\n    assert g.num_nodes(\"developer\") == 3\n    assert g.num_edges(\"plays\") == 6\n    assert g.num_edges(\"develops\") == 4\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"develops\")\n    assert F.array_equal(u, F.tensor([0, 1, 0, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 1, 2, 3], dtype=idtype))\n    assert F.array_equal(\n        g.nodes[\"developer\"].data[\"h\"], F.tensor([3, 3, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.edges[\"develops\"].data[\"h\"], F.tensor([0, 0, 2, 2], dtype=idtype)\n    )\n\n\n@parametrize_idtype\ndef test_add_nodes(idtype):\n    # homogeneous Graphs\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.copy_to(F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx())\n    new_g = dgl.add_nodes(g, 1)\n    assert g.num_nodes() == 3\n    assert new_g.num_nodes() == 4\n    assert F.array_equal(new_g.ndata[\"h\"], F.tensor([1, 1, 1, 0], dtype=idtype))\n\n    # 
zero node graph\n    g = dgl.graph(([], []), num_nodes=3, idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.copy_to(F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx())\n    g = dgl.add_nodes(\n        g, 1, data={\"h\": F.copy_to(F.tensor([2], dtype=idtype), ctx=F.ctx())}\n    )\n    assert g.num_nodes() == 4\n    assert F.array_equal(g.ndata[\"h\"], F.tensor([1, 1, 1, 2], dtype=idtype))\n\n    # bipartite graph\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g = dgl.add_nodes(\n        g,\n        2,\n        data={\"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())},\n        ntype=\"user\",\n    )\n    assert g.num_nodes(\"user\") == 4\n    assert g.num_nodes(\"game\") == 3\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([0, 0, 2, 2], dtype=idtype)\n    )\n    g = dgl.add_nodes(g, 2, ntype=\"game\")\n    assert g.num_nodes(\"user\") == 4\n    assert g.num_nodes(\"game\") == 5\n\n    # heterogeneous graph\n    g = create_test_heterograph3(idtype)\n    g = dgl.add_nodes(g, 1, ntype=\"user\")\n    g = dgl.add_nodes(\n        g,\n        2,\n        data={\"h\": F.copy_to(F.tensor([2, 2], dtype=idtype), ctx=F.ctx())},\n        ntype=\"game\",\n    )\n    assert g.num_nodes(\"user\") == 4\n    assert g.num_nodes(\"game\") == 4\n    assert g.num_nodes(\"developer\") == 2\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 1, 1, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2, 2, 2], dtype=idtype)\n    )\n\n\n@parametrize_idtype\ndef test_remove_edges(idtype):\n    # homogeneous Graphs\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    e = 0\n    g = dgl.remove_edges(g, e)\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert 
F.array_equal(v, F.tensor([2], dtype=idtype))\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    e = [0]\n    g = dgl.remove_edges(g, e)\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2], dtype=idtype))\n    e = F.tensor([0], dtype=idtype)\n    g = dgl.remove_edges(g, e)\n    assert g.num_edges() == 0\n\n    # has node data\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx())\n    g = dgl.remove_edges(g, 1)\n    assert g.num_edges() == 1\n    assert F.array_equal(g.ndata[\"h\"], F.tensor([1, 2, 3], dtype=idtype))\n\n    # has edge data\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.edata[\"h\"] = F.copy_to(F.tensor([1, 2], dtype=idtype), ctx=F.ctx())\n    g = dgl.remove_edges(g, 0)\n    assert g.num_edges() == 1\n    assert F.array_equal(g.edata[\"h\"], F.tensor([2], dtype=idtype))\n\n    # invalid eid\n    assert_fail = False\n    try:\n        g = dgl.remove_edges(g, 1)\n    except:\n        assert_fail = True\n    assert assert_fail\n\n    # bipartite graph\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    e = 0\n    g = dgl.remove_edges(g, e)\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2], dtype=idtype))\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    e = [0]\n    g = dgl.remove_edges(g, e)\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2], 
dtype=idtype))\n    e = F.tensor([0], dtype=idtype)\n    g = dgl.remove_edges(g, e)\n    assert g.num_edges() == 0\n\n    # has data\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.edata[\"h\"] = F.copy_to(F.tensor([1, 2], dtype=idtype), ctx=F.ctx())\n    g = dgl.remove_edges(g, 1)\n    assert g.num_edges() == 1\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2, 2], dtype=idtype)\n    )\n    assert F.array_equal(g.edata[\"h\"], F.tensor([1], dtype=idtype))\n\n    # heterogeneous graph\n    g = create_test_heterograph3(idtype)\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2, 3, 4], dtype=idtype), ctx=F.ctx()\n    )\n    g = dgl.remove_edges(g, 1, etype=\"plays\")\n    assert g.num_edges(\"plays\") == 3\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"plays\")\n    assert F.array_equal(u, F.tensor([0, 1, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 1, 1], dtype=idtype))\n    assert F.array_equal(\n        g.edges[\"plays\"].data[\"h\"], F.tensor([1, 3, 4], dtype=idtype)\n    )\n    # remove all edges of 'develops'\n    g = dgl.remove_edges(g, [0, 1], etype=\"develops\")\n    assert g.num_edges(\"develops\") == 0\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"game\"].data[\"h\"], F.tensor([2, 2], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.nodes[\"developer\"].data[\"h\"], F.tensor([3, 3], dtype=idtype)\n    )\n\n    # batched graph\n    ctx = 
F.ctx()\n    g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=5, idtype=idtype, device=ctx)\n    g2 = dgl.graph(([], []), idtype=idtype, device=ctx)\n    g3 = dgl.graph(([2, 3, 4], [3, 2, 1]), idtype=idtype, device=ctx)\n    bg = dgl.batch([g1, g2, g3])\n    bg_r = dgl.remove_edges(bg, 2)\n    assert bg.batch_size == bg_r.batch_size\n    assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())\n    assert F.array_equal(\n        bg_r.batch_num_edges(), F.tensor([2, 0, 2], dtype=idtype)\n    )\n\n    bg_r = dgl.remove_edges(bg, [0, 2])\n    assert bg.batch_size == bg_r.batch_size\n    assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())\n    assert F.array_equal(\n        bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=idtype)\n    )\n\n    bg_r = dgl.remove_edges(bg, F.tensor([0, 2], dtype=idtype))\n    assert bg.batch_size == bg_r.batch_size\n    assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())\n    assert F.array_equal(\n        bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=idtype)\n    )\n\n    # batched heterogeneous graph\n    g1 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([1, 3], [0, 1]),\n        },\n        num_nodes_dict={\"user\": 4, \"game\": 3},\n        idtype=idtype,\n        device=ctx,\n    )\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 2], [3, 4]),\n            (\"user\", \"plays\", \"game\"): ([], []),\n        },\n        num_nodes_dict={\"user\": 6, \"game\": 2},\n        idtype=idtype,\n        device=ctx,\n    )\n    g3 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([], []),\n            (\"user\", \"plays\", \"game\"): ([1, 2], [1, 2]),\n        },\n        idtype=idtype,\n        device=ctx,\n    )\n    bg = dgl.batch([g1, g2, g3])\n    bg_r = dgl.remove_edges(bg, 1, etype=\"follows\")\n    assert bg.batch_size == 
bg_r.batch_size\n    ntypes = bg.ntypes\n    for nty in ntypes:\n        assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"follows\"), F.tensor([1, 2, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"plays\"), bg.batch_num_edges(\"plays\")\n    )\n\n    bg_r = dgl.remove_edges(bg, 2, etype=\"plays\")\n    assert bg.batch_size == bg_r.batch_size\n    for nty in ntypes:\n        assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))\n    assert F.array_equal(\n        bg.batch_num_edges(\"follows\"), bg_r.batch_num_edges(\"follows\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"plays\"), F.tensor([2, 0, 1], dtype=idtype)\n    )\n\n    bg_r = dgl.remove_edges(bg, [0, 1, 3], etype=\"follows\")\n    assert bg.batch_size == bg_r.batch_size\n    for nty in ntypes:\n        assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"follows\"), F.tensor([0, 1, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg.batch_num_edges(\"plays\"), bg_r.batch_num_edges(\"plays\")\n    )\n\n    bg_r = dgl.remove_edges(bg, [1, 2], etype=\"plays\")\n    assert bg.batch_size == bg_r.batch_size\n    for nty in ntypes:\n        assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))\n    assert F.array_equal(\n        bg.batch_num_edges(\"follows\"), bg_r.batch_num_edges(\"follows\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"plays\"), F.tensor([1, 0, 1], dtype=idtype)\n    )\n\n    bg_r = dgl.remove_edges(\n        bg, F.tensor([0, 1, 3], dtype=idtype), etype=\"follows\"\n    )\n    assert bg.batch_size == bg_r.batch_size\n    for nty in ntypes:\n        assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"follows\"), F.tensor([0, 1, 
0], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg.batch_num_edges(\"plays\"), bg_r.batch_num_edges(\"plays\")\n    )\n\n    bg_r = dgl.remove_edges(bg, F.tensor([1, 2], dtype=idtype), etype=\"plays\")\n    assert bg.batch_size == bg_r.batch_size\n    for nty in ntypes:\n        assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))\n    assert F.array_equal(\n        bg.batch_num_edges(\"follows\"), bg_r.batch_num_edges(\"follows\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"plays\"), F.tensor([1, 0, 1], dtype=idtype)\n    )\n\n\n@parametrize_idtype\ndef test_remove_nodes(idtype):\n    # homogeneous Graphs\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    n = 0\n    g = dgl.remove_nodes(g, n)\n    assert g.num_nodes() == 2\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1], dtype=idtype))\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    n = [1]\n    g = dgl.remove_nodes(g, n)\n    assert g.num_nodes() == 2\n    assert g.num_edges() == 0\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    n = F.tensor([2], dtype=idtype)\n    g = dgl.remove_nodes(g, n)\n    assert g.num_nodes() == 2\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1], dtype=idtype))\n\n    # invalid nid\n    assert_fail = False\n    try:\n        g.remove_nodes(3)\n    except:\n        assert_fail = True\n    assert assert_fail\n\n    # has node and edge data\n    g = dgl.graph(([0, 0, 2], [0, 1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"hv\"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx())\n    g.edata[\"he\"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx())\n    g = dgl.remove_nodes(g, 
F.tensor([0], dtype=idtype))\n    assert g.num_nodes() == 2\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1], dtype=idtype))\n    assert F.array_equal(g.ndata[\"hv\"], F.tensor([2, 3], dtype=idtype))\n    assert F.array_equal(g.edata[\"he\"], F.tensor([3], dtype=idtype))\n\n    # node id larger than current max node id\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    n = 0\n    g = dgl.remove_nodes(g, n, ntype=\"user\")\n    assert g.num_nodes(\"user\") == 1\n    assert g.num_nodes(\"game\") == 3\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2], dtype=idtype))\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    n = [1]\n    g = dgl.remove_nodes(g, n, ntype=\"user\")\n    assert g.num_nodes(\"user\") == 1\n    assert g.num_nodes(\"game\") == 3\n    assert g.num_edges() == 1\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0], dtype=idtype))\n    assert F.array_equal(v, F.tensor([1], dtype=idtype))\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1], [1, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    n = F.tensor([0], dtype=idtype)\n    g = dgl.remove_nodes(g, n, ntype=\"game\")\n    assert g.num_nodes(\"user\") == 2\n    assert g.num_nodes(\"game\") == 2\n    assert g.num_edges() == 2\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 1], dtype=idtype))\n\n    # heterogeneous graph\n    g = create_test_heterograph3(idtype)\n    
g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 2, 3, 4], dtype=idtype), ctx=F.ctx()\n    )\n    g = dgl.remove_nodes(g, 0, ntype=\"game\")\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 1\n    assert g.num_nodes(\"developer\") == 2\n    assert g.num_edges(\"plays\") == 2\n    assert g.num_edges(\"develops\") == 1\n    assert F.array_equal(\n        g.nodes[\"user\"].data[\"h\"], F.tensor([1, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(g.nodes[\"game\"].data[\"h\"], F.tensor([2], dtype=idtype))\n    assert F.array_equal(\n        g.nodes[\"developer\"].data[\"h\"], F.tensor([3, 3], dtype=idtype)\n    )\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"plays\")\n    assert F.array_equal(u, F.tensor([1, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 0], dtype=idtype))\n    assert F.array_equal(\n        g.edges[\"plays\"].data[\"h\"], F.tensor([3, 4], dtype=idtype)\n    )\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"develops\")\n    assert F.array_equal(u, F.tensor([1], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0], dtype=idtype))\n\n    # batched graph\n    ctx = F.ctx()\n    g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=5, idtype=idtype, device=ctx)\n    g2 = dgl.graph(([], []), idtype=idtype, device=ctx)\n    g3 = dgl.graph(([2, 3, 4], [3, 2, 1]), idtype=idtype, device=ctx)\n    bg = dgl.batch([g1, g2, g3])\n    bg_r = dgl.remove_nodes(bg, 1)\n    assert bg_r.batch_size == bg.batch_size\n    assert F.array_equal(\n        bg_r.batch_num_nodes(), F.tensor([4, 0, 5], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(), F.tensor([0, 0, 3], dtype=idtype)\n    )\n\n    bg_r = dgl.remove_nodes(bg, [1, 7])\n    assert bg_r.batch_size == bg.batch_size\n    assert F.array_equal(\n        bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=idtype)\n    
)\n\n    bg_r = dgl.remove_nodes(bg, F.tensor([1, 7], dtype=idtype))\n    assert bg_r.batch_size == bg.batch_size\n    assert F.array_equal(\n        bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=idtype)\n    )\n\n    # batched heterogeneous graph\n    g1 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([1, 3], [0, 1]),\n        },\n        num_nodes_dict={\"user\": 4, \"game\": 3},\n        idtype=idtype,\n        device=ctx,\n    )\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 2], [3, 4]),\n            (\"user\", \"plays\", \"game\"): ([], []),\n        },\n        num_nodes_dict={\"user\": 6, \"game\": 2},\n        idtype=idtype,\n        device=ctx,\n    )\n    g3 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([], []),\n            (\"user\", \"plays\", \"game\"): ([1, 2], [1, 2]),\n        },\n        idtype=idtype,\n        device=ctx,\n    )\n    bg = dgl.batch([g1, g2, g3])\n    bg_r = dgl.remove_nodes(bg, 1, ntype=\"user\")\n    assert bg_r.batch_size == bg.batch_size\n    assert F.array_equal(\n        bg_r.batch_num_nodes(\"user\"), F.tensor([3, 6, 3], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg.batch_num_nodes(\"game\"), bg_r.batch_num_nodes(\"game\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"follows\"), F.tensor([0, 2, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"plays\"), F.tensor([1, 0, 2], dtype=idtype)\n    )\n\n    bg_r = dgl.remove_nodes(bg, 6, ntype=\"game\")\n    assert bg_r.batch_size == bg.batch_size\n    assert F.array_equal(\n        bg.batch_num_nodes(\"user\"), bg_r.batch_num_nodes(\"user\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_nodes(\"game\"), F.tensor([3, 2, 2], 
dtype=idtype)\n    )\n    assert F.array_equal(\n        bg.batch_num_edges(\"follows\"), bg_r.batch_num_edges(\"follows\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"plays\"), F.tensor([2, 0, 1], dtype=idtype)\n    )\n\n    bg_r = dgl.remove_nodes(bg, [1, 5, 6, 11], ntype=\"user\")\n    assert bg_r.batch_size == bg.batch_size\n    assert F.array_equal(\n        bg_r.batch_num_nodes(\"user\"), F.tensor([3, 4, 2], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg.batch_num_nodes(\"game\"), bg_r.batch_num_nodes(\"game\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"follows\"), F.tensor([0, 1, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"plays\"), F.tensor([1, 0, 1], dtype=idtype)\n    )\n\n    bg_r = dgl.remove_nodes(bg, [0, 3, 4, 7], ntype=\"game\")\n    assert bg_r.batch_size == bg.batch_size\n    assert F.array_equal(\n        bg.batch_num_nodes(\"user\"), bg_r.batch_num_nodes(\"user\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_nodes(\"game\"), F.tensor([2, 0, 2], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg.batch_num_edges(\"follows\"), bg_r.batch_num_edges(\"follows\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"plays\"), F.tensor([1, 0, 1], dtype=idtype)\n    )\n\n    bg_r = dgl.remove_nodes(\n        bg, F.tensor([1, 5, 6, 11], dtype=idtype), ntype=\"user\"\n    )\n    assert bg_r.batch_size == bg.batch_size\n    assert F.array_equal(\n        bg_r.batch_num_nodes(\"user\"), F.tensor([3, 4, 2], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg.batch_num_nodes(\"game\"), bg_r.batch_num_nodes(\"game\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"follows\"), F.tensor([0, 1, 0], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"plays\"), F.tensor([1, 0, 1], dtype=idtype)\n    )\n\n    bg_r = dgl.remove_nodes(\n        bg, F.tensor([0, 3, 4, 7], dtype=idtype), 
ntype=\"game\"\n    )\n    assert bg_r.batch_size == bg.batch_size\n    assert F.array_equal(\n        bg.batch_num_nodes(\"user\"), bg_r.batch_num_nodes(\"user\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_nodes(\"game\"), F.tensor([2, 0, 2], dtype=idtype)\n    )\n    assert F.array_equal(\n        bg.batch_num_edges(\"follows\"), bg_r.batch_num_edges(\"follows\")\n    )\n    assert F.array_equal(\n        bg_r.batch_num_edges(\"plays\"), F.tensor([1, 0, 1], dtype=idtype)\n    )\n\n\n@parametrize_idtype\ndef test_add_selfloop(idtype):\n    # homogeneous graph\n\n    # test for fill_data is float\n    g = dgl.graph(([0, 0, 2], [2, 1, 0]), idtype=idtype, device=F.ctx())\n    g.edata[\"he\"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx())\n    g.edata[\"he1\"] = F.copy_to(\n        F.tensor([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]), ctx=F.ctx()\n    )\n    g.ndata[\"hn\"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx())\n    g = dgl.add_self_loop(g)\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 6\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 0, 2, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2, 1, 0, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(\n        g.edata[\"he\"], F.tensor([1, 2, 3, 1, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.edata[\"he1\"],\n        F.tensor(\n            [\n                [0.0, 1.0],\n                [2.0, 3.0],\n                [4.0, 5.0],\n                [1.0, 1.0],\n                [1.0, 1.0],\n                [1.0, 1.0],\n            ]\n        ),\n    )\n\n    # test for fill_data is int\n    g = dgl.graph(([0, 0, 2], [2, 1, 0]), idtype=idtype, device=F.ctx())\n    g.edata[\"he\"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx())\n    g.edata[\"he1\"] = F.copy_to(\n        F.tensor([[0, 1], [2, 3], [4, 5]], dtype=idtype), ctx=F.ctx()\n    )\n    g.ndata[\"hn\"] = F.copy_to(F.tensor([1, 
2, 3], dtype=idtype), ctx=F.ctx())\n    g = dgl.add_self_loop(g, fill_data=1)\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 6\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 0, 2, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2, 1, 0, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(\n        g.edata[\"he\"], F.tensor([1, 2, 3, 1, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.edata[\"he1\"],\n        F.tensor(\n            [[0, 1], [2, 3], [4, 5], [1, 1], [1, 1], [1, 1]], dtype=idtype\n        ),\n    )\n\n    # test for fill_data is str\n    g = dgl.graph(([0, 0, 2], [2, 1, 0]), idtype=idtype, device=F.ctx())\n    g.edata[\"he\"] = F.copy_to(F.tensor([1.0, 2.0, 3.0]), ctx=F.ctx())\n    g.edata[\"he1\"] = F.copy_to(\n        F.tensor([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]), ctx=F.ctx()\n    )\n    g.ndata[\"hn\"] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.ctx())\n    g = dgl.add_self_loop(g, fill_data=\"sum\")\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 6\n    u, v = g.edges(form=\"uv\", order=\"eid\")\n    assert F.array_equal(u, F.tensor([0, 0, 2, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([2, 1, 0, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(\n        g.edata[\"he\"], F.tensor([1.0, 2.0, 3.0, 3.0, 2.0, 1.0])\n    )\n    assert F.array_equal(\n        g.edata[\"he1\"],\n        F.tensor(\n            [\n                [0.0, 1.0],\n                [2.0, 3.0],\n                [4.0, 5.0],\n                [4.0, 5.0],\n                [2.0, 3.0],\n                [0.0, 1.0],\n            ]\n        ),\n    )\n\n    # bipartite graph\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1, 2], [1, 2, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    # expected to raise an error\n    raise_error = False\n    try:\n        g = dgl.add_self_loop(g)\n    except:\n        raise_error = True\n    
assert raise_error\n\n    # test for fill_data is float\n    g = create_test_heterograph5(idtype)\n    g.edges[\"follows\"].data[\"h1\"] = F.copy_to(\n        F.tensor([[0.0, 1.0], [1.0, 2.0]]), ctx=F.ctx()\n    )\n    g = dgl.add_self_loop(g, etype=\"follows\")\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 2\n    assert g.num_edges(\"follows\") == 5\n    assert g.num_edges(\"plays\") == 2\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"follows\")\n    assert F.array_equal(u, F.tensor([1, 2, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 1, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(\n        g.edges[\"follows\"].data[\"h\"], F.tensor([1, 2, 1, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.edges[\"follows\"].data[\"h1\"],\n        F.tensor([[0.0, 1.0], [1.0, 2.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),\n    )\n    assert F.array_equal(\n        g.edges[\"plays\"].data[\"h\"], F.tensor([1, 2], dtype=idtype)\n    )\n\n    # test for fill_data is int\n    g = create_test_heterograph5(idtype)\n    g.edges[\"follows\"].data[\"h1\"] = F.copy_to(\n        F.tensor([[0, 1], [1, 2]], dtype=idtype), ctx=F.ctx()\n    )\n    g = dgl.add_self_loop(g, fill_data=1, etype=\"follows\")\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 2\n    assert g.num_edges(\"follows\") == 5\n    assert g.num_edges(\"plays\") == 2\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"follows\")\n    assert F.array_equal(u, F.tensor([1, 2, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 1, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(\n        g.edges[\"follows\"].data[\"h\"], F.tensor([1, 2, 1, 1, 1], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.edges[\"follows\"].data[\"h1\"],\n        F.tensor([[0, 1], [1, 2], [1, 1], [1, 1], [1, 1]], dtype=idtype),\n    )\n    assert F.array_equal(\n        g.edges[\"plays\"].data[\"h\"], F.tensor([1, 2], 
dtype=idtype)\n    )\n\n    # test for fill_data is str\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (\n                F.tensor([1, 2], dtype=idtype),\n                F.tensor([0, 1], dtype=idtype),\n            ),\n            (\"user\", \"plays\", \"game\"): (\n                F.tensor([0, 1], dtype=idtype),\n                F.tensor([0, 1], dtype=idtype),\n            ),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.tensor([1, 1, 1], dtype=idtype), ctx=F.ctx()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.tensor([2, 2], dtype=idtype), ctx=F.ctx()\n    )\n    g.edges[\"follows\"].data[\"h\"] = F.copy_to(F.tensor([1.0, 2.0]), ctx=F.ctx())\n    g.edges[\"follows\"].data[\"h1\"] = F.copy_to(\n        F.tensor([[0.0, 1.0], [1.0, 2.0]]), ctx=F.ctx()\n    )\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(F.tensor([1.0, 2.0]), ctx=F.ctx())\n    g = dgl.add_self_loop(g, fill_data=\"mean\", etype=\"follows\")\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 2\n    assert g.num_edges(\"follows\") == 5\n    assert g.num_edges(\"plays\") == 2\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"follows\")\n    assert F.array_equal(u, F.tensor([1, 2, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 1, 0, 1, 2], dtype=idtype))\n    assert F.array_equal(\n        g.edges[\"follows\"].data[\"h\"], F.tensor([1.0, 2.0, 1.0, 2.0, 0.0])\n    )\n    assert F.array_equal(\n        g.edges[\"follows\"].data[\"h1\"],\n        F.tensor([[0.0, 1.0], [1.0, 2.0], [0.0, 1.0], [1.0, 2.0], [0.0, 0.0]]),\n    )\n    assert F.array_equal(g.edges[\"plays\"].data[\"h\"], F.tensor([1.0, 2.0]))\n\n    raise_error = False\n    try:\n        g = dgl.add_self_loop(g, etype=\"plays\")\n    except:\n        raise_error = True\n    assert raise_error\n\n\n@parametrize_idtype\ndef test_remove_selfloop(idtype):\n    # 
homogeneous graph\n    g = dgl.graph(([0, 0, 0, 1], [1, 0, 0, 2]), idtype=idtype, device=F.ctx())\n    g.edata[\"he\"] = F.copy_to(F.tensor([1, 2, 3, 4], dtype=idtype), ctx=F.ctx())\n    g = dgl.remove_self_loop(g)\n    assert g.num_nodes() == 3\n    assert g.num_edges() == 2\n    assert F.array_equal(g.edata[\"he\"], F.tensor([1, 4], dtype=idtype))\n\n    # bipartite graph\n    g = dgl.heterograph(\n        {(\"user\", \"plays\", \"game\"): ([0, 1, 2], [1, 2, 2])},\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    # expected to raise an error\n    raise_error = False\n    try:\n        g = dgl.remove_self_loop(g, etype=\"plays\")\n    except:\n        raise_error = True\n    assert raise_error\n\n    g = create_test_heterograph4(idtype)\n    g = dgl.remove_self_loop(g, etype=\"follows\")\n    assert g.num_nodes(\"user\") == 3\n    assert g.num_nodes(\"game\") == 2\n    assert g.num_edges(\"follows\") == 2\n    assert g.num_edges(\"plays\") == 2\n    u, v = g.edges(form=\"uv\", order=\"eid\", etype=\"follows\")\n    assert F.array_equal(u, F.tensor([1, 2], dtype=idtype))\n    assert F.array_equal(v, F.tensor([0, 1], dtype=idtype))\n    assert F.array_equal(\n        g.edges[\"follows\"].data[\"h\"], F.tensor([2, 4], dtype=idtype)\n    )\n    assert F.array_equal(\n        g.edges[\"plays\"].data[\"h\"], F.tensor([1, 2], dtype=idtype)\n    )\n\n    raise_error = False\n    try:\n        g = dgl.remove_self_loop(g, etype=\"plays\")\n    except:\n        raise_error = True\n    assert raise_error\n\n    # batch information\n    g = dgl.graph(\n        ([0, 0, 0, 1, 3, 3, 4], [1, 0, 0, 2, 3, 4, 4]),\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.set_batch_num_nodes([3, 2])\n    g.set_batch_num_edges([4, 3])\n    g = dgl.remove_self_loop(g)\n    assert g.num_nodes() == 5\n    assert g.num_edges() == 3\n    assert F.array_equal(g.batch_num_nodes(), F.tensor([3, 2], dtype=idtype))\n    assert F.array_equal(g.batch_num_edges(), F.tensor([2, 1], 
dtype=idtype))\n\n\n@parametrize_idtype\ndef test_reorder_graph(idtype):\n    g = dgl.graph(\n        ([0, 1, 2, 3, 4], [2, 2, 3, 2, 3]), idtype=idtype, device=F.ctx()\n    )\n    g.ndata[\"h\"] = F.copy_to(F.randn((g.num_nodes(), 3)), ctx=F.ctx())\n    g.edata[\"w\"] = F.copy_to(F.randn((g.num_edges(), 2)), ctx=F.ctx())\n\n    # call with default: node_permute_algo=None, edge_permute_algo='src'\n    rg = dgl.reorder_graph(g)\n    assert dgl.EID in rg.edata.keys()\n    src = F.asnumpy(rg.edges()[0])\n    assert np.array_equal(src, np.sort(src))\n\n    # call with 'rcmk' node_permute_algo\n    rg = dgl.reorder_graph(g, node_permute_algo=\"rcmk\")\n    assert dgl.NID in rg.ndata.keys()\n    assert dgl.EID in rg.edata.keys()\n    src = F.asnumpy(rg.edges()[0])\n    assert np.array_equal(src, np.sort(src))\n\n    # call with 'dst' edge_permute_algo\n    rg = dgl.reorder_graph(g, edge_permute_algo=\"dst\")\n    dst = F.asnumpy(rg.edges()[1])\n    assert np.array_equal(dst, np.sort(dst))\n\n    # call with unknown edge_permute_algo\n    raise_error = False\n    try:\n        dgl.reorder_graph(g, edge_permute_algo=\"none\")\n    except:\n        raise_error = True\n    assert raise_error\n\n    # reorder back to original according to stored ids\n    rg = dgl.reorder_graph(g, node_permute_algo=\"rcmk\")\n    rg2 = dgl.reorder_graph(\n        rg,\n        \"custom\",\n        permute_config={\"nodes_perm\": np.argsort(F.asnumpy(rg.ndata[dgl.NID]))},\n    )\n    assert F.array_equal(g.ndata[\"h\"], rg2.ndata[\"h\"])\n    assert F.array_equal(g.edata[\"w\"], rg2.edata[\"w\"])\n\n    # do not store ids\n    rg = dgl.reorder_graph(g, store_ids=False)\n    assert not dgl.NID in rg.ndata.keys()\n    assert not dgl.EID in rg.edata.keys()\n\n    # metis does not work on windows.\n    if os.name == \"nt\":\n        pass\n    else:\n        # metis_partition may fail for small graph.\n        mg = create_large_graph(1000).to(F.ctx())\n\n        # call with metis strategy, but k is 
not specified\n        raise_error = False\n        try:\n            dgl.reorder_graph(mg, node_permute_algo=\"metis\")\n        except:\n            raise_error = True\n        assert raise_error\n\n        # call with metis strategy, k is specified\n        raise_error = False\n        try:\n            dgl.reorder_graph(\n                mg, node_permute_algo=\"metis\", permute_config={\"k\": 2}\n            )\n        except:\n            raise_error = True\n        assert not raise_error\n\n    # call with qualified nodes_perm specified\n    nodes_perm = np.random.permutation(g.num_nodes())\n    raise_error = False\n    try:\n        dgl.reorder_graph(\n            g,\n            node_permute_algo=\"custom\",\n            permute_config={\"nodes_perm\": nodes_perm},\n        )\n    except:\n        raise_error = True\n    assert not raise_error\n\n    # call with unqualified nodes_perm specified\n    raise_error = False\n    try:\n        dgl.reorder_graph(\n            g,\n            node_permute_algo=\"custom\",\n            permute_config={\"nodes_perm\": nodes_perm[: g.num_nodes() - 1]},\n        )\n    except:\n        raise_error = True\n    assert raise_error\n\n    # call with unsupported strategy\n    raise_error = False\n    try:\n        dgl.reorder_graph(g, node_permute_algo=\"cmk\")\n    except:\n        raise_error = True\n    assert raise_error\n\n    # heterograph: not supported\n    raise_error = False\n    try:\n        hg = dgl.heterogrpah(\n            {(\"user\", \"follow\", \"user\"): ([0, 1], [1, 2])},\n            idtype=idtype,\n            device=F.ctx(),\n        )\n        dgl.reorder_graph(hg)\n    except:\n        raise_error = True\n    assert raise_error\n\n    # TODO: shall we fix them?\n    # add 'csc' format if needed\n    # fg = g.formats('csr')\n    # assert 'csc' not in sum(fg.formats().values(), [])\n    # rfg = dgl.reorder_graph(fg)\n    # assert 'csc' in sum(rfg.formats().values(), [])\n\n\n@unittest.skipIf(\n    
dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support a slicing operation\",\n)\n@parametrize_idtype\ndef test_norm_by_dst(idtype):\n    # Case1: A homogeneous graph\n    g = dgl.graph(([0, 1, 1], [1, 1, 2]), idtype=idtype, device=F.ctx())\n    eweight = dgl.norm_by_dst(g)\n    assert F.allclose(eweight, F.tensor([0.5, 0.5, 1.0]))\n\n    # Case2: A heterogeneous graph\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 1], [1, 1, 2]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    eweight = dgl.norm_by_dst(g, etype=(\"user\", \"plays\", \"game\"))\n    assert F.allclose(eweight, F.tensor([0.5, 0.5, 1.0]))\n\n\n@parametrize_idtype\ndef test_module_add_self_loop(idtype):\n    g = dgl.graph(([1, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 2))\n    g.edata[\"w\"] = F.randn((g.num_edges(), 3))\n\n    # Case1: add self-loops with the default setting (no duplicate self-loops)\n    transform = dgl.AddSelfLoop()\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.num_nodes() == g.num_nodes()\n    assert new_g.num_edges() == 4\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 0), (1, 1), (1, 2), (2, 2)}\n    assert \"h\" in new_g.ndata\n    assert \"w\" in new_g.edata\n\n    # Case2: allow duplicate self-loops\n    transform = dgl.AddSelfLoop(allow_duplicate=True)\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.num_nodes() == g.num_nodes()\n    assert new_g.num_edges() == 5\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 0), (1, 1), (1, 2), (2, 2)}\n    assert \"h\" in new_g.ndata\n    assert \"w\" in 
new_g.edata\n\n    # Case3: add self-loops for a homogeneous graph (the example in doc)\n    transform = dgl.AddSelfLoop(fill_data=\"sum\")\n    g = dgl.graph(([0, 0, 2], [2, 1, 0]), idtype=idtype, device=F.ctx())\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.num_nodes() == g.num_nodes()\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 2), (0, 1), (2, 0), (0, 0), (1, 1), (2, 2)}\n\n    # Create a heterogeneous graph\n    g = dgl.heterograph(\n        {\n            (\"user\", \"plays\", \"game\"): ([0], [1]),\n            (\"user\", \"follows\", \"user\"): ([1], [3]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h1\"] = F.randn((4, 2))\n    g.edges[\"plays\"].data[\"w1\"] = F.randn((1, 3))\n    g.nodes[\"game\"].data[\"h2\"] = F.randn((2, 4))\n    g.edges[\"follows\"].data[\"w2\"] = F.randn((1, 5))\n\n    # Case4: add self-loops for a heterogeneous graph\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.ntypes == g.ntypes\n    assert new_g.canonical_etypes == g.canonical_etypes\n    for nty in new_g.ntypes:\n        assert new_g.num_nodes(nty) == g.num_nodes(nty)\n    assert new_g.num_edges(\"plays\") == 1\n    assert new_g.num_edges(\"follows\") == 5\n    assert \"h1\" in new_g.nodes[\"user\"].data\n    assert \"h2\" in new_g.nodes[\"game\"].data\n    assert \"w1\" in new_g.edges[\"plays\"].data\n    assert \"w2\" in new_g.edges[\"follows\"].data\n\n    # Case5: add self-etypes for a heterogeneous graph\n    transform = dgl.AddSelfLoop(new_etypes=True)\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.ntypes == g.ntypes\n    assert set(new_g.canonical_etypes) == {\n        (\"user\", \"plays\", \"game\"),\n        (\"user\", \"follows\", 
\"user\"),\n        (\"user\", \"self\", \"user\"),\n        (\"game\", \"self\", \"game\"),\n    }\n    for nty in new_g.ntypes:\n        assert new_g.num_nodes(nty) == g.num_nodes(nty)\n    assert new_g.num_edges(\"plays\") == 1\n    assert new_g.num_edges(\"follows\") == 5\n    assert new_g.num_edges((\"user\", \"self\", \"user\")) == 4\n    assert new_g.num_edges((\"game\", \"self\", \"game\")) == 2\n    assert \"h1\" in new_g.nodes[\"user\"].data\n    assert \"h2\" in new_g.nodes[\"game\"].data\n    assert \"w1\" in new_g.edges[\"plays\"].data\n    assert \"w2\" in new_g.edges[\"follows\"].data\n\n\n@parametrize_idtype\ndef test_module_remove_self_loop(idtype):\n    transform = dgl.RemoveSelfLoop()\n\n    # Case1: homogeneous graph\n    g = dgl.graph(([1, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 2))\n    g.edata[\"w\"] = F.randn((g.num_edges(), 3))\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.num_nodes() == g.num_nodes()\n    assert new_g.num_edges() == 1\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(1, 2)}\n    assert \"h\" in new_g.ndata\n    assert \"w\" in new_g.edata\n\n    # Case2: heterogeneous graph\n    g = dgl.heterograph(\n        {\n            (\"user\", \"plays\", \"game\"): ([0, 1], [1, 1]),\n            (\"user\", \"follows\", \"user\"): ([1, 2], [2, 2]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"user\"].data[\"h1\"] = F.randn((3, 2))\n    g.edges[\"plays\"].data[\"w1\"] = F.randn((2, 3))\n    g.nodes[\"game\"].data[\"h2\"] = F.randn((2, 4))\n    g.edges[\"follows\"].data[\"w2\"] = F.randn((2, 5))\n\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.ntypes == g.ntypes\n    assert new_g.canonical_etypes == g.canonical_etypes\n    for nty in 
new_g.ntypes:\n        assert new_g.num_nodes(nty) == g.num_nodes(nty)\n    assert new_g.num_edges(\"plays\") == 2\n    assert new_g.num_edges(\"follows\") == 1\n    assert \"h1\" in new_g.nodes[\"user\"].data\n    assert \"h2\" in new_g.nodes[\"game\"].data\n    assert \"w1\" in new_g.edges[\"plays\"].data\n    assert \"w2\" in new_g.edges[\"follows\"].data\n\n\n@parametrize_idtype\ndef test_module_add_reverse(idtype):\n    transform = dgl.AddReverse()\n\n    # Case1: Add reverse edges for a homogeneous graph\n    g = dgl.graph(([0], [1]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 3))\n    g.edata[\"w\"] = F.randn((g.num_edges(), 2))\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert g.num_nodes() == new_g.num_nodes()\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (1, 0)}\n    assert F.allclose(g.ndata[\"h\"], new_g.ndata[\"h\"])\n    assert F.allclose(g.edata[\"w\"], F.narrow_row(new_g.edata[\"w\"], 0, 1))\n    assert F.allclose(\n        F.narrow_row(new_g.edata[\"w\"], 1, 2),\n        F.zeros((1, 2), F.float32, F.ctx()),\n    )\n\n    # Case2: Add reverse edges for a homogeneous graph and copy edata\n    transform = dgl.AddReverse(copy_edata=True)\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert g.num_nodes() == new_g.num_nodes()\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (1, 0)}\n    assert F.allclose(g.ndata[\"h\"], new_g.ndata[\"h\"])\n    assert F.allclose(g.edata[\"w\"], F.narrow_row(new_g.edata[\"w\"], 0, 1))\n    assert F.allclose(g.edata[\"w\"], F.narrow_row(new_g.edata[\"w\"], 1, 2))\n\n    # Case3: Add reverse edges for a heterogeneous graph\n    g = dgl.heterograph(\n        {\n            (\"user\", \"plays\", \"game\"): ([0, 1], [1, 
1]),\n            (\"user\", \"follows\", \"user\"): ([1, 2], [2, 2]),\n        },\n        device=F.ctx(),\n    )\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert g.ntypes == new_g.ntypes\n    assert set(new_g.canonical_etypes) == {\n        (\"user\", \"plays\", \"game\"),\n        (\"user\", \"follows\", \"user\"),\n        (\"game\", \"rev_plays\", \"user\"),\n    }\n    for nty in g.ntypes:\n        assert g.num_nodes(nty) == new_g.num_nodes(nty)\n\n    src, dst = new_g.edges(etype=\"plays\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (1, 1)}\n\n    src, dst = new_g.edges(etype=\"follows\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(1, 2), (2, 2), (2, 1)}\n\n    src, dst = new_g.edges(etype=\"rev_plays\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(1, 1), (1, 0)}\n\n    # Case4: Enforce reverse edge types for symmetric canonical edge types\n    transform = dgl.AddReverse(sym_new_etype=True)\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert g.ntypes == new_g.ntypes\n    assert set(new_g.canonical_etypes) == {\n        (\"user\", \"plays\", \"game\"),\n        (\"user\", \"follows\", \"user\"),\n        (\"game\", \"rev_plays\", \"user\"),\n        (\"user\", \"rev_follows\", \"user\"),\n    }\n    for nty in g.ntypes:\n        assert g.num_nodes(nty) == new_g.num_nodes(nty)\n\n    src, dst = new_g.edges(etype=\"plays\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (1, 1)}\n\n    src, dst = new_g.edges(etype=\"follows\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(1, 2), (2, 2)}\n\n    src, dst = new_g.edges(etype=\"rev_plays\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == 
{(1, 1), (1, 0)}\n\n    src, dst = new_g.edges(etype=\"rev_follows\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(2, 1), (2, 2)}\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\", reason=\"GPU not supported for to_simple\"\n)\n@parametrize_idtype\ndef test_module_to_simple(idtype):\n    transform = dgl.ToSimple()\n    g = dgl.graph(([0, 1, 1], [1, 2, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 2))\n    g.edata[\"w\"] = F.tensor([[0.1], [0.2], [0.3]])\n    sg = transform(g)\n    assert sg.device == g.device\n    assert sg.idtype == g.idtype\n    assert sg.num_nodes() == g.num_nodes()\n    assert sg.num_edges() == 2\n    src, dst = sg.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (1, 2)}\n    assert F.allclose(sg.edata[\"count\"], F.tensor([1, 2]))\n    assert F.allclose(sg.ndata[\"h\"], g.ndata[\"h\"])\n\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1, 1], [1, 2, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 0], [1, 1, 1]),\n        }\n    )\n    sg = transform(g)\n    assert sg.device == g.device\n    assert sg.idtype == g.idtype\n    assert sg.ntypes == g.ntypes\n    assert sg.canonical_etypes == g.canonical_etypes\n    for nty in sg.ntypes:\n        assert sg.num_nodes(nty) == g.num_nodes(nty)\n    for ety in sg.canonical_etypes:\n        assert sg.num_edges(ety) == 2\n\n    src, dst = sg.edges(etype=\"follows\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (1, 2)}\n\n    src, dst = sg.edges(etype=\"plays\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (1, 1)}\n\n\n@parametrize_idtype\ndef test_module_line_graph(idtype):\n    transform = dgl.LineGraph()\n    g = dgl.graph(([0, 1, 1], [1, 0, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.tensor([[0.0], 
[1.0], [2.0]])\n    g.edata[\"w\"] = F.tensor([[0.0], [0.1], [0.2]])\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.num_nodes() == g.num_edges()\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (0, 2), (1, 0)}\n\n    transform = dgl.LineGraph(backtracking=False)\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.num_nodes() == g.num_edges()\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 2)}\n\n\n@parametrize_idtype\ndef test_module_khop_graph(idtype):\n    transform = dgl.KHopGraph(2)\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 2))\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.num_nodes() == g.num_nodes()\n    assert F.allclose(g.ndata[\"h\"], new_g.ndata[\"h\"])\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 2)}\n\n\n@parametrize_idtype\ndef test_module_add_metapaths(idtype):\n    g = dgl.heterograph(\n        {\n            (\"person\", \"author\", \"paper\"): ([0, 0, 1], [1, 2, 2]),\n            (\"paper\", \"accepted\", \"venue\"): ([1], [0]),\n            (\"paper\", \"rejected\", \"venue\"): ([2], [1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.nodes[\"venue\"].data[\"h\"] = F.randn((g.num_nodes(\"venue\"), 2))\n    g.edges[\"author\"].data[\"h\"] = F.randn((g.num_edges(\"author\"), 3))\n\n    # Case1: keep_orig_edges is True\n    metapaths = {\n        \"accepted\": [\n            (\"person\", \"author\", \"paper\"),\n            (\"paper\", \"accepted\", \"venue\"),\n        ],\n        \"rejected\": [\n            (\"person\", 
\"author\", \"paper\"),\n            (\"paper\", \"rejected\", \"venue\"),\n        ],\n    }\n    transform = dgl.AddMetaPaths(metapaths)\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.ntypes == g.ntypes\n    assert set(new_g.canonical_etypes) == {\n        (\"person\", \"author\", \"paper\"),\n        (\"paper\", \"accepted\", \"venue\"),\n        (\"paper\", \"rejected\", \"venue\"),\n        (\"person\", \"accepted\", \"venue\"),\n        (\"person\", \"rejected\", \"venue\"),\n    }\n    for nty in new_g.ntypes:\n        assert new_g.num_nodes(nty) == g.num_nodes(nty)\n    for ety in g.canonical_etypes:\n        assert new_g.num_edges(ety) == g.num_edges(ety)\n    assert F.allclose(\n        g.nodes[\"venue\"].data[\"h\"], new_g.nodes[\"venue\"].data[\"h\"]\n    )\n    assert F.allclose(\n        g.edges[\"author\"].data[\"h\"], new_g.edges[\"author\"].data[\"h\"]\n    )\n\n    src, dst = new_g.edges(etype=(\"person\", \"accepted\", \"venue\"))\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 0)}\n\n    src, dst = new_g.edges(etype=(\"person\", \"rejected\", \"venue\"))\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (1, 1)}\n\n    # Case2: keep_orig_edges is False\n    transform = dgl.AddMetaPaths(metapaths, keep_orig_edges=False)\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.ntypes == g.ntypes\n    assert len(new_g.canonical_etypes) == 2\n    for nty in new_g.ntypes:\n        assert new_g.num_nodes(nty) == g.num_nodes(nty)\n    assert F.allclose(\n        g.nodes[\"venue\"].data[\"h\"], new_g.nodes[\"venue\"].data[\"h\"]\n    )\n\n    src, dst = new_g.edges(etype=(\"person\", \"accepted\", \"venue\"))\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 0)}\n\n    src, dst = 
new_g.edges(etype=(\"person\", \"rejected\", \"venue\"))\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (1, 1)}\n\n\n@parametrize_idtype\ndef test_module_compose(idtype):\n    g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())\n    transform = dgl.Compose([dgl.AddReverse(), dgl.AddSelfLoop()])\n    new_g = transform(g)\n    assert new_g.device == g.device\n    assert new_g.idtype == g.idtype\n    assert new_g.num_edges() == 7\n\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (1, 2), (1, 0), (2, 1), (0, 0), (1, 1), (2, 2)}\n\n\n@parametrize_idtype\ndef test_module_gcnnorm(idtype):\n    g = dgl.heterograph(\n        {\n            (\"A\", \"r1\", \"A\"): ([0, 1, 2], [0, 0, 1]),\n            (\"A\", \"r2\", \"B\"): ([0, 0], [1, 1]),\n            (\"B\", \"r3\", \"B\"): ([0, 1, 2], [0, 0, 1]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.edges[\"r3\"].data[\"w\"] = F.tensor([0.1, 0.2, 0.3])\n    transform = dgl.GCNNorm()\n    new_g = transform(g)\n    assert \"w\" not in new_g.edges[(\"A\", \"r2\", \"B\")].data\n    assert F.allclose(\n        new_g.edges[(\"A\", \"r1\", \"A\")].data[\"w\"],\n        F.tensor([1.0 / 2, 1.0 / math.sqrt(2), 0.0]),\n    )\n    assert F.allclose(\n        new_g.edges[(\"B\", \"r3\", \"B\")].data[\"w\"],\n        F.tensor([1.0 / 3, 2.0 / 3, 0.0]),\n    )\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_module_ppr(idtype):\n    g = dgl.graph(\n        ([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx()\n    )\n    g.ndata[\"h\"] = F.randn((6, 2))\n    transform = dgl.PPR(avg_degree=2)\n    new_g = transform(g)\n    assert new_g.idtype == g.idtype\n    assert new_g.device == g.device\n    assert new_g.num_nodes() == g.num_nodes()\n    src, dst = new_g.edges()\n    eset = 
set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {\n        (0, 0),\n        (0, 2),\n        (0, 4),\n        (1, 1),\n        (1, 3),\n        (1, 5),\n        (2, 2),\n        (2, 3),\n        (2, 4),\n        (3, 3),\n        (3, 5),\n        (4, 3),\n        (4, 4),\n        (4, 5),\n        (5, 5),\n    }\n    assert F.allclose(g.ndata[\"h\"], new_g.ndata[\"h\"])\n    assert \"w\" in new_g.edata\n\n    # Prior edge weights\n    g.edata[\"w\"] = F.tensor([0.1, 0.2, 0.3, 0.4, 0.5])\n    new_g = transform(g)\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {\n        (0, 0),\n        (1, 1),\n        (1, 3),\n        (2, 2),\n        (2, 3),\n        (2, 4),\n        (3, 3),\n        (3, 5),\n        (4, 3),\n        (4, 4),\n        (4, 5),\n        (5, 5),\n    }\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_module_heat_kernel(idtype):\n    # Case1: directed graph\n    g = dgl.graph(\n        ([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx()\n    )\n    g.ndata[\"h\"] = F.randn((6, 2))\n    transform = dgl.HeatKernel(avg_degree=1)\n    new_g = transform(g)\n    assert new_g.idtype == g.idtype\n    assert new_g.device == g.device\n    assert new_g.num_nodes() == g.num_nodes()\n    assert F.allclose(g.ndata[\"h\"], new_g.ndata[\"h\"])\n    assert \"w\" in new_g.edata\n\n    # Case2: weighted undirected graph\n    g = dgl.graph(([0, 1, 2, 3], [1, 0, 3, 2]), idtype=idtype, device=F.ctx())\n    g.edata[\"w\"] = F.tensor([0.1, 0.2, 0.3, 0.4])\n    new_g = transform(g)\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 0), (1, 1), (2, 2), (3, 3)}\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef 
test_module_gdc(idtype):\n    transform = dgl.GDC([0.1, 0.2, 0.1], avg_degree=1)\n    g = dgl.graph(\n        ([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx()\n    )\n    g.ndata[\"h\"] = F.randn((6, 2))\n    new_g = transform(g)\n    assert new_g.idtype == g.idtype\n    assert new_g.device == g.device\n    assert new_g.num_nodes() == g.num_nodes()\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {\n        (0, 0),\n        (0, 2),\n        (0, 4),\n        (1, 1),\n        (1, 3),\n        (1, 5),\n        (2, 2),\n        (2, 3),\n        (2, 4),\n        (3, 3),\n        (3, 5),\n        (4, 3),\n        (4, 4),\n        (4, 5),\n        (5, 5),\n    }\n    assert F.allclose(g.ndata[\"h\"], new_g.ndata[\"h\"])\n    assert \"w\" in new_g.edata\n\n    # Prior edge weights\n    g.edata[\"w\"] = F.tensor([0.1, 0.2, 0.3, 0.4, 0.5])\n    new_g = transform(g)\n    src, dst = new_g.edges()\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 0), (1, 1), (2, 2), (3, 3), (4, 3), (4, 4), (5, 5)}\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name == \"tensorflow\",\n    reason=\"TF doesn't support a slicing operation\",\n)\n@parametrize_idtype\ndef test_module_node_shuffle(idtype):\n    transform = dgl.NodeShuffle()\n    g = dgl.heterograph(\n        {(\"A\", \"r\", \"B\"): ([0, 1], [1, 2])}, idtype=idtype, device=F.ctx()\n    )\n    g.nodes[\"B\"].data[\"h\"] = F.randn((g.num_nodes(\"B\"), 2))\n    old_nfeat = g.nodes[\"B\"].data[\"h\"]\n    new_g = transform(g)\n    new_nfeat = g.nodes[\"B\"].data[\"h\"]\n    assert F.allclose(old_nfeat, new_nfeat)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_module_drop_node(idtype):\n    transform = dgl.DropNode()\n    g = dgl.heterograph(\n        {(\"A\", \"r\", \"B\"): ([0, 1], [1, 2])}, idtype=idtype, 
device=F.ctx()\n    )\n    num_nodes_old = g.num_nodes()\n    new_g = transform(g)\n    assert new_g.idtype == g.idtype\n    assert new_g.device == g.device\n    assert new_g.ntypes == g.ntypes\n    assert new_g.canonical_etypes == g.canonical_etypes\n    num_nodes_new = g.num_nodes()\n    # Ensure that the original graph is not corrupted\n    assert num_nodes_old == num_nodes_new\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_module_drop_edge(idtype):\n    transform = dgl.DropEdge()\n    g = dgl.heterograph(\n        {\n            (\"A\", \"r1\", \"B\"): ([0, 1], [1, 2]),\n            (\"C\", \"r2\", \"C\"): ([3, 4, 5], [6, 7, 8]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    num_edges_old = g.num_edges()\n    new_g = transform(g)\n    assert new_g.idtype == g.idtype\n    assert new_g.device == g.device\n    assert new_g.ntypes == g.ntypes\n    assert new_g.canonical_etypes == g.canonical_etypes\n    num_edges_new = g.num_edges()\n    # Ensure that the original graph is not corrupted\n    assert num_edges_old == num_edges_new\n\n\n@parametrize_idtype\ndef test_module_add_edge(idtype):\n    transform = dgl.AddEdge()\n    g = dgl.heterograph(\n        {\n            (\"A\", \"r1\", \"B\"): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5]),\n            (\"C\", \"r2\", \"C\"): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    num_edges_old = g.num_edges()\n    new_g = transform(g)\n    assert new_g.num_edges((\"A\", \"r1\", \"B\")) == 6\n    assert new_g.num_edges((\"C\", \"r2\", \"C\")) == 6\n    assert new_g.idtype == g.idtype\n    assert new_g.device == g.device\n    assert new_g.ntypes == g.ntypes\n    assert new_g.canonical_etypes == g.canonical_etypes\n    num_edges_new = g.num_edges()\n    # Ensure that the original graph is not corrupted\n    assert num_edges_old == 
num_edges_new\n\n\n@parametrize_idtype\ndef test_module_random_walk_pe(idtype):\n    transform = dgl.RandomWalkPE(2, \"rwpe\")\n    g = dgl.graph(([0, 1, 1], [1, 1, 0]), idtype=idtype, device=F.ctx())\n    new_g = transform(g)\n    tgt = F.copy_to(F.tensor([[0.0, 0.5], [0.5, 0.75]]), g.device)\n    assert F.allclose(new_g.ndata[\"rwpe\"], tgt)\n\n\n@parametrize_idtype\ndef test_module_lap_pe(idtype):\n    g = dgl.graph(\n        ([2, 1, 0, 3, 1, 1], [3, 1, 1, 2, 1, 0]), idtype=idtype, device=F.ctx()\n    )\n    tgt_eigval = F.copy_to(\n        F.repeat(\n            F.tensor([[1.1534e-17, 1.3333e00, 2.0, np.nan, np.nan]]),\n            g.num_nodes(),\n            dim=0,\n        ),\n        g.device,\n    )\n    tgt_pe = F.copy_to(\n        F.tensor(\n            [\n                [0.5, 0.86602539, 0.0, 0.0, 0.0],\n                [0.86602539, 0.5, 0.0, 0.0, 0.0],\n                [0.0, 0.0, 0.70710677, 0.0, 0.0],\n                [0.0, 0.0, 0.70710677, 0.0, 0.0],\n            ]\n        ),\n        g.device,\n    )\n\n    # without padding (k<n)\n    transform = dgl.LapPE(2, feat_name=\"lappe\")\n    new_g = transform(g)\n    # tensorflow has no abs() api\n    if dgl.backend.backend_name == \"tensorflow\":\n        assert F.allclose(new_g.ndata[\"lappe\"].__abs__(), tgt_pe[:, :2])\n    # pytorch & mxnet\n    else:\n        assert F.allclose(new_g.ndata[\"lappe\"].abs(), tgt_pe[:, :2])\n\n    # with padding (k>=n)\n    transform = dgl.LapPE(5, feat_name=\"lappe\", padding=True)\n    new_g = transform(g)\n    # tensorflow has no abs() api\n    if dgl.backend.backend_name == \"tensorflow\":\n        assert F.allclose(new_g.ndata[\"lappe\"].__abs__(), tgt_pe)\n    # pytorch & mxnet\n    else:\n        assert F.allclose(new_g.ndata[\"lappe\"].abs(), tgt_pe)\n\n    # with eigenvalues\n    transform = dgl.LapPE(\n        5, feat_name=\"lappe\", eigval_name=\"eigval\", padding=True\n    )\n    new_g = transform(g)\n    # tensorflow has no abs() api\n    if 
dgl.backend.backend_name == \"tensorflow\":\n        assert F.allclose(new_g.ndata[\"eigval\"][:, :3], tgt_eigval[:, :3])\n        assert F.allclose(new_g.ndata[\"lappe\"].__abs__(), tgt_pe)\n    # pytorch & mxnet\n    else:\n        assert F.allclose(new_g.ndata[\"eigval\"][:, :3], tgt_eigval[:, :3])\n        assert F.allclose(new_g.ndata[\"lappe\"].abs(), tgt_pe)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@pytest.mark.parametrize(\"g\", get_cases([\"has_scalar_e_feature\"]))\ndef test_module_sign(g):\n    import torch\n\n    atol = 1e-06\n\n    ctx = F.ctx()\n    g = g.to(ctx)\n    adj = g.adj_external(transpose=True, scipy_fmt=\"coo\").todense()\n    adj = torch.tensor(adj).float().to(ctx)\n\n    weight_adj = (\n        g.adj_external(transpose=True, scipy_fmt=\"coo\").astype(float).todense()\n    )\n    weight_adj = torch.tensor(weight_adj).float().to(ctx)\n    src, dst = g.edges()\n    src, dst = src.long(), dst.long()\n    weight_adj[dst, src] = g.edata[\"scalar_w\"]\n\n    # raw\n    transform = dgl.SIGNDiffusion(k=1, in_feat_name=\"h\", diffuse_op=\"raw\")\n    g = transform(g)\n    target = torch.matmul(adj, g.ndata[\"h\"])\n    assert torch.allclose(g.ndata[\"out_feat_1\"], target, atol=atol)\n\n    transform = dgl.SIGNDiffusion(\n        k=1, in_feat_name=\"h\", eweight_name=\"scalar_w\", diffuse_op=\"raw\"\n    )\n    g = transform(g)\n    target = torch.matmul(weight_adj, g.ndata[\"h\"])\n    assert torch.allclose(g.ndata[\"out_feat_1\"], target, atol=atol)\n\n    # rw\n    adj_rw = torch.matmul(torch.diag(1 / adj.sum(dim=1)), adj)\n    transform = dgl.SIGNDiffusion(k=1, in_feat_name=\"h\", diffuse_op=\"rw\")\n    g = transform(g)\n    target = torch.matmul(adj_rw, g.ndata[\"h\"])\n    assert torch.allclose(g.ndata[\"out_feat_1\"], target, atol=atol)\n\n    weight_adj_rw = torch.matmul(\n        torch.diag(1 / weight_adj.sum(dim=1)), weight_adj\n    )\n    transform = 
dgl.SIGNDiffusion(\n        k=1, in_feat_name=\"h\", eweight_name=\"scalar_w\", diffuse_op=\"rw\"\n    )\n    g = transform(g)\n    target = torch.matmul(weight_adj_rw, g.ndata[\"h\"])\n    assert torch.allclose(g.ndata[\"out_feat_1\"], target, atol=atol)\n\n    # gcn\n    raw_eweight = g.edata[\"scalar_w\"]\n    gcn_norm = dgl.GCNNorm()\n    g = gcn_norm(g)\n    adj_gcn = adj.clone()\n    adj_gcn[dst, src] = g.edata.pop(\"w\")\n    transform = dgl.SIGNDiffusion(k=1, in_feat_name=\"h\", diffuse_op=\"gcn\")\n    g = transform(g)\n    target = torch.matmul(adj_gcn, g.ndata[\"h\"])\n    assert torch.allclose(g.ndata[\"out_feat_1\"], target, atol=atol)\n\n    gcn_norm = dgl.GCNNorm(\"scalar_w\")\n    g = gcn_norm(g)\n    weight_adj_gcn = weight_adj.clone()\n    weight_adj_gcn[dst, src] = g.edata[\"scalar_w\"]\n    g.edata[\"scalar_w\"] = raw_eweight\n    transform = dgl.SIGNDiffusion(\n        k=1, in_feat_name=\"h\", eweight_name=\"scalar_w\", diffuse_op=\"gcn\"\n    )\n    g = transform(g)\n    target = torch.matmul(weight_adj_gcn, g.ndata[\"h\"])\n    assert torch.allclose(g.ndata[\"out_feat_1\"], target, atol=atol)\n\n    # ppr\n    alpha = 0.2\n    transform = dgl.SIGNDiffusion(\n        k=1, in_feat_name=\"h\", diffuse_op=\"ppr\", alpha=alpha\n    )\n    g = transform(g)\n    target = (1 - alpha) * torch.matmul(\n        adj_gcn, g.ndata[\"h\"]\n    ) + alpha * g.ndata[\"h\"]\n    assert torch.allclose(g.ndata[\"out_feat_1\"], target, atol=atol)\n\n    transform = dgl.SIGNDiffusion(\n        k=1,\n        in_feat_name=\"h\",\n        eweight_name=\"scalar_w\",\n        diffuse_op=\"ppr\",\n        alpha=alpha,\n    )\n    g = transform(g)\n    target = (1 - alpha) * torch.matmul(\n        weight_adj_gcn, g.ndata[\"h\"]\n    ) + alpha * g.ndata[\"h\"]\n    assert torch.allclose(g.ndata[\"out_feat_1\"], target, atol=atol)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef 
test_module_row_feat_normalizer(idtype):\n    # Case1: Normalize features of a homogeneous graph.\n    transform = dgl.RowFeatNormalizer(\n        subtract_min=True, node_feat_names=[\"h\"], edge_feat_names=[\"w\"]\n    )\n    g = dgl.rand_graph(5, 5, idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 128))\n    g.edata[\"w\"] = F.randn((g.num_edges(), 128))\n    g = transform(g)\n    assert g.ndata[\"h\"].shape == (g.num_nodes(), 128)\n    assert g.edata[\"w\"].shape == (g.num_edges(), 128)\n    assert F.allclose(g.ndata[\"h\"].sum(1), F.tensor([1.0, 1.0, 1.0, 1.0, 1.0]))\n    assert F.allclose(g.edata[\"w\"].sum(1), F.tensor([1.0, 1.0, 1.0, 1.0, 1.0]))\n\n    # Case2: Normalize features of a heterogeneous graph.\n    transform = dgl.RowFeatNormalizer(\n        subtract_min=True, node_feat_names=[\"h\", \"h2\"], edge_feat_names=[\"w\"]\n    )\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (F.tensor([1, 2]), F.tensor([3, 4])),\n            (\"player\", \"plays\", \"game\"): (F.tensor([2, 2]), F.tensor([1, 1])),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.ndata[\"h\"] = {\"game\": F.randn((2, 128)), \"player\": F.randn((3, 128))}\n    g.ndata[\"h2\"] = {\"user\": F.randn((5, 128))}\n    g.edata[\"w\"] = {\n        (\"user\", \"follows\", \"user\"): F.randn((2, 128)),\n        (\"player\", \"plays\", \"game\"): F.randn((2, 128)),\n    }\n    g = transform(g)\n    assert g.ndata[\"h\"][\"game\"].shape == (2, 128)\n    assert g.ndata[\"h\"][\"player\"].shape == (3, 128)\n    assert g.ndata[\"h2\"][\"user\"].shape == (5, 128)\n    assert g.edata[\"w\"][(\"user\", \"follows\", \"user\")].shape == (2, 128)\n    assert g.edata[\"w\"][(\"player\", \"plays\", \"game\")].shape == (2, 128)\n    assert F.allclose(g.ndata[\"h\"][\"game\"].sum(1), F.tensor([1.0, 1.0]))\n    assert F.allclose(g.ndata[\"h\"][\"player\"].sum(1), F.tensor([1.0, 1.0, 1.0]))\n    assert F.allclose(\n        
g.ndata[\"h2\"][\"user\"].sum(1), F.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n    )\n    assert F.allclose(\n        g.edata[\"w\"][(\"user\", \"follows\", \"user\")].sum(1), F.tensor([1.0, 1.0])\n    )\n    assert F.allclose(\n        g.edata[\"w\"][(\"player\", \"plays\", \"game\")].sum(1), F.tensor([1.0, 1.0])\n    )\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\", reason=\"Only support PyTorch for now\"\n)\n@parametrize_idtype\ndef test_module_feat_mask(idtype):\n    # Case1: Mask node and edge feature tensors of a homogeneous graph.\n    transform = dgl.FeatMask(node_feat_names=[\"h\"], edge_feat_names=[\"w\"])\n    g = dgl.rand_graph(5, 20, idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.ones((g.num_nodes(), 10))\n    g.edata[\"w\"] = F.ones((g.num_edges(), 20))\n    g = transform(g)\n    assert g.device == g.device\n    assert g.idtype == g.idtype\n    assert g.ndata[\"h\"].shape == (g.num_nodes(), 10)\n    assert g.edata[\"w\"].shape == (g.num_edges(), 20)\n\n    # Case2: Mask node and edge feature tensors of a heterogeneous graph.\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): (F.tensor([1, 2]), F.tensor([3, 4])),\n            (\"player\", \"plays\", \"game\"): (F.tensor([2, 2]), F.tensor([1, 1])),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    g.ndata[\"h\"] = {\"game\": F.randn((2, 5)), \"player\": F.randn((3, 5))}\n    g.edata[\"w\"] = {\n        (\"user\", \"follows\", \"user\"): F.randn((2, 5)),\n        (\"player\", \"plays\", \"game\"): F.randn((2, 5)),\n    }\n    g = transform(g)\n    assert g.device == g.device\n    assert g.idtype == g.idtype\n    assert g.ndata[\"h\"][\"game\"].shape == (2, 5)\n    assert g.ndata[\"h\"][\"player\"].shape == (3, 5)\n    assert g.edata[\"w\"][(\"user\", \"follows\", \"user\")].shape == (2, 5)\n    assert g.edata[\"w\"][(\"player\", \"plays\", \"game\")].shape == (2, 5)\n\n\n@parametrize_idtype\ndef test_shortest_dist(idtype):\n    
g = dgl.graph(([0, 1, 1, 2], [2, 0, 3, 3]), idtype=idtype, device=F.ctx())\n\n    # case 1: directed single source\n    dist = dgl.shortest_dist(g, root=0)\n    tgt = F.copy_to(F.tensor([0, -1, 1, 2], dtype=F.int64), g.device)\n    assert F.array_equal(dist, tgt)\n\n    # case 2: undirected all pairs\n    dist, paths = dgl.shortest_dist(g, root=None, return_paths=True)\n    tgt_dist = F.copy_to(\n        F.tensor(\n            [[0, -1, 1, 2], [1, 0, 2, 1], [-1, -1, 0, 1], [-1, -1, -1, 0]],\n            dtype=F.int64,\n        ),\n        g.device,\n    )\n    tgt_paths = F.copy_to(\n        F.tensor(\n            [\n                [[-1, -1], [-1, -1], [0, -1], [0, 3]],\n                [[1, -1], [-1, -1], [1, 0], [2, -1]],\n                [[-1, -1], [-1, -1], [-1, -1], [3, -1]],\n                [[-1, -1], [-1, -1], [-1, -1], [-1, -1]],\n            ],\n            dtype=F.int64,\n        ),\n        g.device,\n    )\n    assert F.array_equal(dist, tgt_dist)\n    assert F.array_equal(paths, tgt_paths)\n\n\n@parametrize_idtype\ndef test_module_to_levi(idtype):\n    transform = dgl.ToLevi()\n    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]), idtype=idtype, device=F.ctx())\n    g.ndata[\"h\"] = F.randn((g.num_nodes(), 2))\n    g.edata[\"w\"] = F.randn((g.num_edges(), 2))\n    lg = transform(g)\n    assert lg.device == g.device\n    assert lg.idtype == g.idtype\n    assert lg.ntypes == [\"edge\", \"node\"]\n    assert lg.canonical_etypes == [\n        (\"edge\", \"e2n\", \"node\"),\n        (\"node\", \"n2e\", \"edge\"),\n    ]\n    assert lg.num_nodes(\"node\") == g.num_nodes()\n    assert lg.num_nodes(\"edge\") == g.num_edges()\n    assert lg.num_edges(\"n2e\") == g.num_edges()\n    assert lg.num_edges(\"e2n\") == g.num_edges()\n\n    src, dst = lg.edges(etype=\"n2e\")\n    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))\n    assert eset == {(0, 0), (1, 1), (2, 2), (3, 3)}\n\n    src, dst = lg.edges(etype=\"e2n\")\n    eset = set(zip(list(F.asnumpy(src)), 
list(F.asnumpy(dst))))\n    assert eset == {(0, 1), (1, 2), (2, 3), (3, 0)}\n\n    assert F.allclose(lg.nodes[\"node\"].data[\"h\"], g.ndata[\"h\"])\n    assert F.allclose(lg.nodes[\"edge\"].data[\"w\"], g.edata[\"w\"])\n\n\n@parametrize_idtype\ndef test_module_svd_pe(idtype):\n    g = dgl.graph(\n        (\n            [0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 4, 4],\n            [2, 3, 0, 2, 0, 2, 3, 4, 3, 4, 0, 1],\n        ),\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    # without padding\n    tgt_pe = F.copy_to(\n        F.tensor(\n            [\n                [0.6669, 0.3068, 0.7979, 0.8477],\n                [0.6311, 0.6101, 0.1248, 0.5137],\n                [1.1993, 0.0665, 0.9183, 0.1455],\n                [0.5682, 0.6766, 0.8952, 0.6449],\n                [0.3393, 0.8363, 0.6500, 0.4564],\n            ]\n        ),\n        g.device,\n    )\n    transform_1 = dgl.SVDPE(k=2, feat_name=\"svd_pe\")\n    g1 = transform_1(g)\n    if dgl.backend.backend_name == \"tensorflow\":\n        assert F.allclose(g1.ndata[\"svd_pe\"].__abs__(), tgt_pe)\n    else:\n        assert F.allclose(g1.ndata[\"svd_pe\"].abs(), tgt_pe)\n\n    # with padding\n    transform_2 = dgl.SVDPE(k=6, feat_name=\"svd_pe\", padding=True)\n    g2 = transform_2(g)\n    assert F.shape(g2.ndata[\"svd_pe\"]) == (5, 12)\n\n\nif __name__ == \"__main__\":\n    test_partition_with_halo()\n    test_module_heat_kernel(F.int32)\n"
  },
  {
    "path": "tests/python/common/utils/test_filter.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nimport numpy as np\nfrom dgl.utils import Filter\nfrom utils import parametrize_idtype\n\n\ndef test_graph_filter():\n    g = dgl.graph([]).to(F.ctx())\n    g.add_nodes(4)\n    g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])\n\n    n_repr = np.zeros((4, 5))\n    e_repr = np.zeros((4, 5))\n    n_repr[[1, 3]] = 1\n    e_repr[[1, 3]] = 1\n    n_repr = F.copy_to(F.zerocopy_from_numpy(n_repr), F.ctx())\n    e_repr = F.copy_to(F.zerocopy_from_numpy(e_repr), F.ctx())\n\n    g.ndata[\"a\"] = n_repr\n    g.edata[\"a\"] = e_repr\n\n    def predicate(r):\n        return F.max(r.data[\"a\"], 1) > 0\n\n    # full node filter\n    n_idx = g.filter_nodes(predicate)\n    assert set(F.zerocopy_to_numpy(n_idx)) == {1, 3}\n\n    # partial node filter\n    n_idx = g.filter_nodes(predicate, [0, 1])\n    assert set(F.zerocopy_to_numpy(n_idx)) == {1}\n\n    # full edge filter\n    e_idx = g.filter_edges(predicate)\n    assert set(F.zerocopy_to_numpy(e_idx)) == {1, 3}\n\n    # partial edge filter\n    e_idx = g.filter_edges(predicate, [0, 1])\n    assert set(F.zerocopy_to_numpy(e_idx)) == {1}\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"CPU not yet supported\"\n)\n@parametrize_idtype\ndef test_array_filter(idtype):\n    f = Filter(\n        F.copy_to(F.tensor([0, 1, 9, 4, 6, 5, 7], dtype=idtype), F.ctx())\n    )\n    x = F.copy_to(F.tensor([0, 3, 9, 11], dtype=idtype), F.ctx())\n    y = F.copy_to(\n        F.tensor([0, 19, 0, 28, 3, 9, 11, 4, 5], dtype=idtype), F.ctx()\n    )\n\n    xi_act = f.find_included_indices(x)\n    xi_exp = F.copy_to(F.tensor([0, 2], dtype=idtype), F.ctx())\n    assert F.array_equal(xi_act, xi_exp)\n    xe_act = f.find_excluded_indices(x)\n    xe_exp = F.copy_to(F.tensor([1, 3], dtype=idtype), F.ctx())\n    assert F.array_equal(xe_act, xe_exp)\n\n    yi_act = f.find_included_indices(y)\n    yi_exp = F.copy_to(F.tensor([0, 2, 5, 7, 8], dtype=idtype), F.ctx())\n    assert 
F.array_equal(yi_act, yi_exp)\n    ye_act = f.find_excluded_indices(y)\n    ye_exp = F.copy_to(F.tensor([1, 3, 4, 6], dtype=idtype), F.ctx())\n    assert F.array_equal(ye_act, ye_exp)\n\n\n@unittest.skipIf(\n    dgl.backend.backend_name != \"pytorch\",\n    reason=\"Multiple streams are only supported by pytorch backend\",\n)\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"CPU not yet supported\"\n)\n@parametrize_idtype\ndef test_filter_multistream(idtype):\n    # this is a smoke test to ensure we do not trip any internal assertions\n    import torch\n\n    s = torch.cuda.Stream(device=F.ctx())\n    with torch.cuda.stream(s):\n        # we must do multiple runs such that the stream is busy as we launch\n        # work\n        for i in range(10):\n            f = Filter(F.arange(1000, 4000, dtype=idtype, ctx=F.ctx()))\n            x = F.randint([30000], dtype=idtype, ctx=F.ctx(), low=0, high=50000)\n            xi = f.find_included_indices(x)\n\n\nif __name__ == \"__main__\":\n    test_graph_filter()\n    test_array_filter()\n"
  },
  {
    "path": "tests/python/common/utils/test_pin_memory.py",
    "content": "import backend as F\n\nimport dgl\nimport pytest\n\n\n@pytest.mark.skipif(\n    F._default_context_str == \"cpu\", reason=\"Need gpu for this test\"\n)\ndef test_pin_unpin():\n    t = F.arange(0, 100, dtype=F.int64, ctx=F.cpu())\n\n    assert not F.is_pinned(t)\n\n    if F.backend_name == \"pytorch\":\n        nd = dgl.utils.pin_memory_inplace(t)\n        assert F.is_pinned(t)\n        nd.unpin_memory_()\n        assert not F.is_pinned(t)\n        del nd\n\n        # tensor will be unpinned immediately if the returned ndarray is not saved\n        dgl.utils.pin_memory_inplace(t)\n        assert not F.is_pinned(t)\n\n        t_pin = t.pin_memory()\n        # cannot unpin a tensor that is pinned outside of DGL\n        with pytest.raises(dgl.DGLError):\n            F.to_dgl_nd(t_pin).unpin_memory_()\n    else:\n        with pytest.raises(dgl.DGLError):\n            # tensorflow and mxnet should throw an error\n            dgl.utils.pin_memory_inplace(t)\n\n\nif __name__ == \"__main__\":\n    test_pin_unpin()\n"
  },
  {
    "path": "tests/python/mxnet/ip_config.txt",
    "content": "0 127.0.0.1 50050\n1 127.0.0.1 50051\n2 127.0.0.1 50052\n3 127.0.0.1 50053"
  },
  {
    "path": "tests/python/mxnet/test_geometry.py",
    "content": "import backend as F\nimport mxnet as mx\nimport numpy as np\n\nfrom dgl.geometry import farthest_point_sampler\n\n\ndef test_fps():\n    N = 1000\n    batch_size = 5\n    sample_points = 10\n    x = mx.nd.array(\n        np.random.uniform(size=(batch_size, int(N / batch_size), 3))\n    )\n    ctx = F.ctx()\n    if F.gpu_ctx():\n        x = x.as_in_context(ctx)\n    res = farthest_point_sampler(x, sample_points)\n    assert res.shape[0] == batch_size\n    assert res.shape[1] == sample_points\n    assert res.sum() > 0\n\n\nif __name__ == \"__main__\":\n    test_fps()\n"
  },
  {
    "path": "tests/python/mxnet/test_nn.py",
    "content": "import backend as F\n\nimport dgl\nimport dgl.function as fn\nimport dgl.nn.mxnet as nn\nimport mxnet as mx\nimport networkx as nx\nimport numpy as np\nimport pytest\nimport scipy as sp\nfrom mxnet import autograd, gluon, nd\nfrom utils import parametrize_idtype\nfrom utils.graph_cases import (\n    get_cases,\n    random_bipartite,\n    random_dglgraph,\n    random_graph,\n)\n\n\ndef check_close(a, b):\n    assert np.allclose(a.asnumpy(), b.asnumpy(), rtol=1e-4, atol=1e-4)\n\n\ndef _AXWb(A, X, W, b):\n    X = mx.nd.dot(X, W.data(X.context))\n    Y = mx.nd.dot(A, X.reshape(X.shape[0], -1)).reshape(X.shape)\n    return Y + b.data(X.context)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_graph_conv(idtype, out_dim):\n    g = dgl.from_networkx(nx.path_graph(3))\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    adj = g.adj_external(transpose=True, ctx=ctx)\n\n    conv = nn.GraphConv(5, out_dim, norm=\"none\", bias=True)\n    conv.initialize(ctx=ctx)\n    # test#1: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))\n    # test#2: more-dim\n    h0 = F.ones((3, 5, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))\n\n    conv = nn.GraphConv(5, out_dim)\n    conv.initialize(ctx=ctx)\n\n    # test#3: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    # test#4: basic\n    h0 = F.ones((3, 5, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n\n    conv = nn.GraphConv(5, out_dim)\n    conv.initialize(ctx=ctx)\n\n    with autograd.train_mode():\n        # test#3: basic\n        h0 = F.ones((3, 5))\n        h1 = conv(g, h0)\n        assert len(g.ndata) == 0\n        assert len(g.edata) == 0\n        # 
test#4: basic\n        h0 = F.ones((3, 5, 5))\n        h1 = conv(g, h0)\n        assert len(g.ndata) == 0\n        assert len(g.edata) == 0\n\n    # test not override features\n    g.ndata[\"h\"] = 2 * F.ones((3, 1))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 1\n    assert len(g.edata) == 0\n    assert \"h\" in g.ndata\n    check_close(g.ndata[\"h\"], 2 * F.ones((3, 1)))\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\",\n    get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\", \"dglgraph\"]),\n)\n@pytest.mark.parametrize(\"norm\", [\"none\", \"both\", \"right\", \"left\"])\n@pytest.mark.parametrize(\"weight\", [True, False])\n@pytest.mark.parametrize(\"bias\", [False])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_graph_conv2(idtype, g, norm, weight, bias, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)\n    conv.initialize(ctx=F.ctx())\n    ext_w = F.randn((5, out_dim)).as_in_context(F.ctx())\n    nsrc = g.number_of_src_nodes()\n    ndst = g.number_of_dst_nodes()\n    h = F.randn((nsrc, 5)).as_in_context(F.ctx())\n    if weight:\n        h_out = conv(g, h)\n    else:\n        h_out = conv(g, h, ext_w)\n    assert h_out.shape == (ndst, out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\", \"dglgraph\"])\n)\n@pytest.mark.parametrize(\"norm\", [\"none\", \"both\", \"right\"])\n@pytest.mark.parametrize(\"weight\", [True, False])\n@pytest.mark.parametrize(\"bias\", [False])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)\n    conv.initialize(ctx=F.ctx())\n    ext_w = F.randn((5, out_dim)).as_in_context(F.ctx())\n    nsrc = g.number_of_src_nodes()\n    ndst = g.number_of_dst_nodes()\n    h = F.randn((nsrc, 
5)).as_in_context(F.ctx())\n    h_dst = F.randn((ndst, out_dim)).as_in_context(F.ctx())\n    if weight:\n        h_out = conv(g, (h, h_dst))\n    else:\n        h_out = conv(g, (h, h_dst), ext_w)\n    assert h_out.shape == (ndst, out_dim)\n\n\ndef _S2AXWb(A, N, X, W, b):\n    X1 = X * N\n    X1 = mx.nd.dot(A, X1.reshape(X1.shape[0], -1))\n    X1 = X1 * N\n    X2 = X1 * N\n    X2 = mx.nd.dot(A, X2.reshape(X2.shape[0], -1))\n    X2 = X2 * N\n    X = mx.nd.concat(X, X1, X2, dim=-1)\n    Y = mx.nd.dot(X, W)\n\n    return Y + b\n\n\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_tagconv(out_dim):\n    g = dgl.from_networkx(nx.path_graph(3)).to(F.ctx())\n    ctx = F.ctx()\n    adj = g.adj_external(transpose=True, ctx=ctx)\n    norm = mx.nd.power(g.in_degrees().astype(\"float32\"), -0.5)\n\n    conv = nn.TAGConv(5, out_dim, bias=True)\n    conv.initialize(ctx=ctx)\n    print(conv)\n\n    # test#1: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    shp = norm.shape + (1,) * (h0.ndim - 1)\n    norm = norm.reshape(shp).as_in_context(h0.context)\n\n    assert F.allclose(\n        h1, _S2AXWb(adj, norm, h0, conv.lin.data(ctx), conv.h_bias.data(ctx))\n    )\n\n    conv = nn.TAGConv(5, out_dim)\n    conv.initialize(ctx=ctx)\n\n    # test#2: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, h0)\n    assert h1.shape[-1] == out_dim\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 20])\n@pytest.mark.parametrize(\"num_heads\", [1, 5])\ndef test_gat_conv(g, idtype, out_dim, num_heads):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gat = nn.GATConv(10, out_dim, num_heads)  # n_heads = 5\n    gat.initialize(ctx=ctx)\n    print(gat)\n    feat = F.randn((g.number_of_src_nodes(), 10))\n    h = gat(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, 
out_dim)\n    _, a = gat(g, feat, True)\n    assert a.shape == (g.num_edges(), num_heads, 1)\n\n    # test residual connection\n    gat = nn.GATConv(10, out_dim, num_heads, residual=True)\n    gat.initialize(ctx=ctx)\n    h = gat(g, feat)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_gat_conv_bi(g, idtype, out_dim, num_heads):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gat = nn.GATConv(5, out_dim, num_heads)\n    gat.initialize(ctx=ctx)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 5)),\n        F.randn((g.number_of_dst_nodes(), 5)),\n    )\n    h = gat(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)\n    _, a = gat(g, feat, True)\n    assert a.shape == (g.num_edges(), num_heads, 1)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\", \"block-bipartite\"]))\n@pytest.mark.parametrize(\"aggre_type\", [\"mean\", \"pool\", \"gcn\"])\n@pytest.mark.parametrize(\"out_dim\", [1, 10])\ndef test_sage_conv(idtype, g, aggre_type, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    sage = nn.SAGEConv(5, out_dim, aggre_type)\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    sage.initialize(ctx=ctx)\n    h = sage(g, feat)\n    assert h.shape[-1] == out_dim\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"]))\n@pytest.mark.parametrize(\"aggre_type\", [\"mean\", \"pool\", \"gcn\"])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_sage_conv_bi(idtype, g, aggre_type, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    dst_dim = 5 if aggre_type != \"gcn\" else 10\n    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 10)),\n        F.randn((g.number_of_dst_nodes(), dst_dim)),\n    )\n  
  sage.initialize(ctx=ctx)\n    h = sage(g, feat)\n    assert h.shape[-1] == out_dim\n    assert h.shape[0] == g.number_of_dst_nodes()\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"aggre_type\", [\"mean\", \"pool\", \"gcn\"])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_sage_conv_bi2(idtype, aggre_type, out_dim):\n    # Test the case for graphs without edges\n    g = dgl.heterograph({(\"_U\", \"_E\", \"_V\"): ([], [])}, {\"_U\": 5, \"_V\": 3})\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    sage = nn.SAGEConv((3, 3), out_dim, \"gcn\")\n    feat = (F.randn((5, 3)), F.randn((3, 3)))\n    sage.initialize(ctx=ctx)\n    h = sage(g, feat)\n    assert h.shape[-1] == out_dim\n    assert h.shape[0] == 3\n    for aggre_type in [\"mean\", \"pool\"]:\n        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)\n        feat = (F.randn((5, 3)), F.randn((3, 1)))\n        sage.initialize(ctx=ctx)\n        h = sage(g, feat)\n        assert h.shape[-1] == out_dim\n        assert h.shape[0] == 3\n\n\ndef test_gg_conv():\n    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())\n    ctx = F.ctx()\n\n    gg_conv = nn.GatedGraphConv(10, 20, 3, 4)  # n_step = 3, n_etypes = 4\n    gg_conv.initialize(ctx=ctx)\n    print(gg_conv)\n\n    # test#1: basic\n    h0 = F.randn((20, 10))\n    etypes = nd.random.randint(0, 4, g.num_edges()).as_in_context(ctx)\n    h1 = gg_conv(g, h0, etypes)\n    assert h1.shape == (20, 20)\n\n\n@pytest.mark.parametrize(\"out_dim\", [1, 20])\ndef test_cheb_conv(out_dim):\n    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())\n    ctx = F.ctx()\n\n    cheb = nn.ChebConv(10, out_dim, 3)  # k = 3\n    cheb.initialize(ctx=ctx)\n    print(cheb)\n\n    # test#1: basic\n    h0 = F.randn((20, 10))\n    h1 = cheb(g, h0)\n    assert h1.shape == (20, out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\ndef test_agnn_conv(g, 
idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    agnn_conv = nn.AGNNConv(0.1, True)\n    agnn_conv.initialize(ctx=ctx)\n    print(agnn_conv)\n    feat = F.randn((g.number_of_src_nodes(), 10))\n    h = agnn_conv(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), 10)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\ndef test_agnn_conv_bi(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    agnn_conv = nn.AGNNConv(0.1, True)\n    agnn_conv.initialize(ctx=ctx)\n    print(agnn_conv)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 5)),\n        F.randn((g.number_of_dst_nodes(), 5)),\n    )\n    h = agnn_conv(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), 5)\n\n\ndef test_appnp_conv():\n    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())\n    ctx = F.ctx()\n\n    appnp_conv = nn.APPNPConv(3, 0.1, 0)\n    appnp_conv.initialize(ctx=ctx)\n    print(appnp_conv)\n\n    # test#1: basic\n    h0 = F.randn((20, 10))\n    h1 = appnp_conv(g, h0)\n    assert h1.shape == (20, 10)\n\n\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_dense_cheb_conv(out_dim):\n    for k in range(1, 4):\n        ctx = F.ctx()\n        g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.3)).to(F.ctx())\n        adj = g.adj_external(transpose=True, ctx=ctx).tostype(\"default\")\n        cheb = nn.ChebConv(5, out_dim, k)\n        dense_cheb = nn.DenseChebConv(5, out_dim, k)\n        cheb.initialize(ctx=ctx)\n        dense_cheb.initialize(ctx=ctx)\n\n        for i in range(len(cheb.fc)):\n            dense_cheb.fc[i].weight.set_data(cheb.fc[i].weight.data())\n            if cheb.bias is not None:\n                dense_cheb.bias.set_data(cheb.bias.data())\n\n        feat = F.randn((100, 5))\n        out_cheb = cheb(g, feat, [2.0])\n        out_dense_cheb = dense_cheb(adj, feat, 2.0)\n        assert F.allclose(out_cheb, 
out_dense_cheb)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"norm_type\", [\"both\", \"right\", \"none\"])\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_dense_graph_conv(idtype, g, norm_type, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    adj = g.adj_external(transpose=True, ctx=ctx).tostype(\"default\")\n    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)\n    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)\n    conv.initialize(ctx=ctx)\n    dense_conv.initialize(ctx=ctx)\n    dense_conv.weight.set_data(conv.weight.data())\n    dense_conv.bias.set_data(conv.bias.data())\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    out_conv = conv(g, feat)\n    out_dense_conv = dense_conv(adj, feat)\n    assert F.allclose(out_conv, out_dense_conv)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"bipartite\", \"block-bipartite\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_dense_sage_conv(idtype, g, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    adj = g.adj_external(transpose=True, ctx=ctx).tostype(\"default\")\n    sage = nn.SAGEConv(5, out_dim, \"gcn\")\n    dense_sage = nn.DenseSAGEConv(5, out_dim)\n    sage.initialize(ctx=ctx)\n    dense_sage.initialize(ctx=ctx)\n    dense_sage.fc.weight.set_data(sage.fc_neigh.weight.data())\n    dense_sage.fc.bias.set_data(sage.fc_neigh.bias.data())\n    if len(g.ntypes) == 2:\n        feat = (\n            F.randn((g.number_of_src_nodes(), 5)),\n            F.randn((g.number_of_dst_nodes(), 5)),\n        )\n    else:\n        feat = F.randn((g.num_nodes(), 5))\n\n    out_sage = sage(g, feat)\n    out_dense_sage = dense_sage(adj, feat)\n    assert F.allclose(out_sage, out_dense_sage)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", 
\"block-bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_edge_conv(g, idtype, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    edge_conv = nn.EdgeConv(5, out_dim)\n    edge_conv.initialize(ctx=ctx)\n    print(edge_conv)\n    # test #1: basic\n    h0 = F.randn((g.number_of_src_nodes(), 5))\n    h1 = edge_conv(g, h0)\n    assert h1.shape == (g.number_of_dst_nodes(), out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_edge_conv_bi(g, idtype, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    edge_conv = nn.EdgeConv(5, out_dim)\n    edge_conv.initialize(ctx=ctx)\n    print(edge_conv)\n    # test #1: basic\n    h0 = F.randn((g.number_of_src_nodes(), 5))\n    x0 = F.randn((g.number_of_dst_nodes(), 5))\n    h1 = edge_conv(g, (h0, x0))\n    assert h1.shape == (g.number_of_dst_nodes(), out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\", \"block-bipartite\"]))\n@pytest.mark.parametrize(\"aggregator_type\", [\"mean\", \"max\", \"sum\"])\ndef test_gin_conv(g, idtype, aggregator_type):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n\n    gin_conv = nn.GINConv(lambda x: x, aggregator_type, 0.1)\n    gin_conv.initialize(ctx=ctx)\n    print(gin_conv)\n\n    # test #1: basic\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    h = gin_conv(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), 5)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"]))\n@pytest.mark.parametrize(\"aggregator_type\", [\"mean\", \"max\", \"sum\"])\ndef test_gin_conv_bi(g, idtype, aggregator_type):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n\n    gin_conv = nn.GINConv(lambda x: x, aggregator_type, 0.1)\n    gin_conv.initialize(ctx=ctx)\n    print(gin_conv)\n\n    # test #2: bipartite\n    
feat = (\n        F.randn((g.number_of_src_nodes(), 5)),\n        F.randn((g.number_of_dst_nodes(), 5)),\n    )\n    h = gin_conv(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), 5)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\ndef test_gmm_conv(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gmm_conv = nn.GMMConv(5, 2, 5, 3, \"max\")\n    gmm_conv.initialize(ctx=ctx)\n    h0 = F.randn((g.number_of_src_nodes(), 5))\n    pseudo = F.randn((g.num_edges(), 5))\n    h1 = gmm_conv(g, h0, pseudo)\n    assert h1.shape == (g.number_of_dst_nodes(), 2)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\ndef test_gmm_conv_bi(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gmm_conv = nn.GMMConv((5, 4), 2, 5, 3, \"max\")\n    gmm_conv.initialize(ctx=ctx)\n    # test #1: basic\n    h0 = F.randn((g.number_of_src_nodes(), 5))\n    hd = F.randn((g.number_of_dst_nodes(), 4))\n    pseudo = F.randn((g.num_edges(), 5))\n    h1 = gmm_conv(g, (h0, hd), pseudo)\n    assert h1.shape == (g.number_of_dst_nodes(), 2)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\", \"block-bipartite\"]))\ndef test_nn_conv(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), \"max\")\n    nn_conv.initialize(ctx=ctx)\n    # test #1: basic\n    h0 = F.randn((g.number_of_src_nodes(), 5))\n    etypes = nd.random.randint(0, 4, g.num_edges()).as_in_context(ctx)\n    h1 = nn_conv(g, h0, etypes)\n    assert h1.shape == (g.number_of_dst_nodes(), 2)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"]))\ndef test_nn_conv_bi(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    nn_conv = nn.NNConv((5, 4), 2, gluon.nn.Embedding(3, 5 * 2), \"max\")\n    
nn_conv.initialize(ctx=ctx)\n    # test #1: basic\n    h0 = F.randn((g.number_of_src_nodes(), 5))\n    hd = F.randn((g.number_of_dst_nodes(), 4))\n    etypes = nd.random.randint(0, 4, g.num_edges()).as_in_context(ctx)\n    h1 = nn_conv(g, (h0, hd), etypes)\n    assert h1.shape == (g.number_of_dst_nodes(), 2)\n\n\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_sg_conv(out_dim):\n    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())\n    g = dgl.add_self_loop(g)\n    ctx = F.ctx()\n\n    sgc = nn.SGConv(5, out_dim, 2)\n    sgc.initialize(ctx=ctx)\n    print(sgc)\n\n    # test #1: basic\n    h0 = F.randn((g.num_nodes(), 5))\n    h1 = sgc(g, h0)\n    assert h1.shape == (g.num_nodes(), out_dim)\n\n\ndef test_set2set():\n    g = dgl.from_networkx(nx.path_graph(10)).to(F.ctx())\n    ctx = F.ctx()\n\n    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers\n    s2s.initialize(ctx=ctx)\n    print(s2s)\n\n    # test#1: basic\n    h0 = F.randn((g.num_nodes(), 5))\n    h1 = s2s(g, h0)\n    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2\n\n    # test#2: batched graph\n    bg = dgl.batch([g, g, g])\n    h0 = F.randn((bg.num_nodes(), 5))\n    h1 = s2s(bg, h0)\n    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.ndim == 2\n\n\ndef test_glob_att_pool():\n    g = dgl.from_networkx(nx.path_graph(10)).to(F.ctx())\n    ctx = F.ctx()\n\n    gap = nn.GlobalAttentionPooling(gluon.nn.Dense(1), gluon.nn.Dense(10))\n    gap.initialize(ctx=ctx)\n    print(gap)\n    # test#1: basic\n    h0 = F.randn((g.num_nodes(), 5))\n    h1 = gap(g, h0)\n    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2\n\n    # test#2: batched graph\n    bg = dgl.batch([g, g, g, g])\n    h0 = F.randn((bg.num_nodes(), 5))\n    h1 = gap(bg, h0)\n    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2\n\n\ndef test_simple_pool():\n    g = dgl.from_networkx(nx.path_graph(15)).to(F.ctx())\n\n    sum_pool = nn.SumPooling()\n    avg_pool = 
nn.AvgPooling()\n    max_pool = nn.MaxPooling()\n    sort_pool = nn.SortPooling(10)  # k = 10\n    print(sum_pool, avg_pool, max_pool, sort_pool)\n\n    # test#1: basic\n    h0 = F.randn((g.num_nodes(), 5))\n    h1 = sum_pool(g, h0)\n    check_close(F.squeeze(h1, 0), F.sum(h0, 0))\n    h1 = avg_pool(g, h0)\n    check_close(F.squeeze(h1, 0), F.mean(h0, 0))\n    h1 = max_pool(g, h0)\n    check_close(F.squeeze(h1, 0), F.max(h0, 0))\n    h1 = sort_pool(g, h0)\n    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2\n\n    # test#2: batched graph\n    g_ = dgl.from_networkx(nx.path_graph(5)).to(F.ctx())\n    bg = dgl.batch([g, g_, g, g_, g])\n    h0 = F.randn((bg.num_nodes(), 5))\n    h1 = sum_pool(bg, h0)\n    truth = mx.nd.stack(\n        F.sum(h0[:15], 0),\n        F.sum(h0[15:20], 0),\n        F.sum(h0[20:35], 0),\n        F.sum(h0[35:40], 0),\n        F.sum(h0[40:55], 0),\n        axis=0,\n    )\n    check_close(h1, truth)\n\n    h1 = avg_pool(bg, h0)\n    truth = mx.nd.stack(\n        F.mean(h0[:15], 0),\n        F.mean(h0[15:20], 0),\n        F.mean(h0[20:35], 0),\n        F.mean(h0[35:40], 0),\n        F.mean(h0[40:55], 0),\n        axis=0,\n    )\n    check_close(h1, truth)\n\n    h1 = max_pool(bg, h0)\n    truth = mx.nd.stack(\n        F.max(h0[:15], 0),\n        F.max(h0[15:20], 0),\n        F.max(h0[20:35], 0),\n        F.max(h0[35:40], 0),\n        F.max(h0[40:55], 0),\n        axis=0,\n    )\n    check_close(h1, truth)\n\n    h1 = sort_pool(bg, h0)\n    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2\n\n\n@pytest.mark.parametrize(\"O\", [1, 2, 8])\ndef test_rgcn(O):\n    ctx = F.ctx()\n    etype = []\n    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1)).to(F.ctx())\n    # 5 etypes\n    R = 5\n    for i in range(g.num_edges()):\n        etype.append(i % 5)\n    B = 2\n    I = 10\n\n    rgc_basis = nn.RelGraphConv(I, O, R, \"basis\", B)\n    rgc_basis.initialize(ctx=ctx)\n    h = nd.random.randn(100, I, ctx=ctx)\n 
   r = nd.array(etype, ctx=ctx)\n    h_new = rgc_basis(g, h, r)\n    assert list(h_new.shape) == [100, O]\n\n    if O % B == 0:\n        rgc_bdd = nn.RelGraphConv(I, O, R, \"bdd\", B)\n        rgc_bdd.initialize(ctx=ctx)\n        h = nd.random.randn(100, I, ctx=ctx)\n        r = nd.array(etype, ctx=ctx)\n        h_new = rgc_bdd(g, h, r)\n        assert list(h_new.shape) == [100, O]\n\n    # with norm\n    norm = nd.zeros((g.num_edges(), 1), ctx=ctx)\n\n    rgc_basis = nn.RelGraphConv(I, O, R, \"basis\", B)\n    rgc_basis.initialize(ctx=ctx)\n    h = nd.random.randn(100, I, ctx=ctx)\n    r = nd.array(etype, ctx=ctx)\n    h_new = rgc_basis(g, h, r, norm)\n    assert list(h_new.shape) == [100, O]\n\n    if O % B == 0:\n        rgc_bdd = nn.RelGraphConv(I, O, R, \"bdd\", B)\n        rgc_bdd.initialize(ctx=ctx)\n        h = nd.random.randn(100, I, ctx=ctx)\n        r = nd.array(etype, ctx=ctx)\n        h_new = rgc_bdd(g, h, r, norm)\n        assert list(h_new.shape) == [100, O]\n\n    # id input\n    rgc_basis = nn.RelGraphConv(I, O, R, \"basis\", B)\n    rgc_basis.initialize(ctx=ctx)\n    h = nd.random.randint(0, I, (100,), ctx=ctx)\n    r = nd.array(etype, ctx=ctx)\n    h_new = rgc_basis(g, h, r)\n    assert list(h_new.shape) == [100, O]\n\n\ndef test_sequential():\n    ctx = F.ctx()\n\n    # test single graph\n    class ExampleLayer(gluon.nn.Block):\n        def __init__(self, **kwargs):\n            super().__init__(**kwargs)\n\n        def forward(self, graph, n_feat, e_feat):\n            graph = graph.local_var()\n            graph.ndata[\"h\"] = n_feat\n            graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n            n_feat += graph.ndata[\"h\"]\n            graph.apply_edges(fn.u_add_v(\"h\", \"h\", \"e\"))\n            e_feat += graph.edata[\"e\"]\n            return n_feat, e_feat\n\n    g = dgl.graph(([], [])).to(F.ctx())\n    g.add_nodes(3)\n    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])\n    net = 
nn.Sequential()\n    net.add(ExampleLayer())\n    net.add(ExampleLayer())\n    net.add(ExampleLayer())\n    net.initialize(ctx=ctx)\n    n_feat = F.randn((3, 4))\n    e_feat = F.randn((9, 4))\n    n_feat, e_feat = net(g, n_feat, e_feat)\n    assert n_feat.shape == (3, 4)\n    assert e_feat.shape == (9, 4)\n\n    # test multiple graphs\n    class ExampleLayer(gluon.nn.Block):\n        def __init__(self, **kwargs):\n            super().__init__(**kwargs)\n\n        def forward(self, graph, n_feat):\n            graph = graph.local_var()\n            graph.ndata[\"h\"] = n_feat\n            graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n            n_feat += graph.ndata[\"h\"]\n            return n_feat.reshape(graph.num_nodes() // 2, 2, -1).sum(1)\n\n    g1 = dgl.from_networkx(nx.erdos_renyi_graph(32, 0.05)).to(F.ctx())\n    g2 = dgl.from_networkx(nx.erdos_renyi_graph(16, 0.2)).to(F.ctx())\n    g3 = dgl.from_networkx(nx.erdos_renyi_graph(8, 0.8)).to(F.ctx())\n\n    net = nn.Sequential()\n    net.add(ExampleLayer())\n    net.add(ExampleLayer())\n    net.add(ExampleLayer())\n    net.initialize(ctx=ctx)\n    n_feat = F.randn((32, 4))\n    n_feat = net([g1, g2, g3], n_feat)\n    assert n_feat.shape == (4, 4)\n\n\ndef myagg(alist, dsttype):\n    rst = alist[0]\n    for i in range(1, len(alist)):\n        rst = rst + (i + 1) * alist[i]\n    return rst\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"agg\", [\"sum\", \"max\", \"min\", \"mean\", \"stack\", myagg])\ndef test_hetero_conv(agg, idtype):\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 0, 2, 1], [1, 2, 1, 3]),\n            (\"user\", \"plays\", \"game\"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),\n            (\"store\", \"sells\", \"game\"): ([0, 0, 1, 1], [0, 3, 1, 2]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    conv = nn.HeteroGraphConv(\n        {\n            \"follows\": nn.GraphConv(2, 3, allow_zero_in_degree=True),\n    
        \"plays\": nn.GraphConv(2, 4, allow_zero_in_degree=True),\n            \"sells\": nn.GraphConv(3, 4, allow_zero_in_degree=True),\n        },\n        agg,\n    )\n    conv.initialize(ctx=F.ctx())\n    print(conv)\n    uf = F.randn((4, 2))\n    gf = F.randn((4, 4))\n    sf = F.randn((2, 3))\n\n    h = conv(g, {\"user\": uf, \"store\": sf, \"game\": gf})\n    assert set(h.keys()) == {\"user\", \"game\"}\n    if agg != \"stack\":\n        assert h[\"user\"].shape == (4, 3)\n        assert h[\"game\"].shape == (4, 4)\n    else:\n        assert h[\"user\"].shape == (4, 1, 3)\n        assert h[\"game\"].shape == (4, 2, 4)\n\n    block = dgl.to_block(\n        g.to(F.cpu()), {\"user\": [0, 1, 2, 3], \"game\": [0, 1, 2, 3], \"store\": []}\n    ).to(F.ctx())\n    h = conv(\n        block,\n        (\n            {\"user\": uf, \"game\": gf, \"store\": sf},\n            {\"user\": uf, \"game\": gf, \"store\": sf[0:0]},\n        ),\n    )\n    assert set(h.keys()) == {\"user\", \"game\"}\n    if agg != \"stack\":\n        assert h[\"user\"].shape == (4, 3)\n        assert h[\"game\"].shape == (4, 4)\n    else:\n        assert h[\"user\"].shape == (4, 1, 3)\n        assert h[\"game\"].shape == (4, 2, 4)\n\n    h = conv(block, {\"user\": uf, \"game\": gf, \"store\": sf})\n    assert set(h.keys()) == {\"user\", \"game\"}\n    if agg != \"stack\":\n        assert h[\"user\"].shape == (4, 3)\n        assert h[\"game\"].shape == (4, 4)\n    else:\n        assert h[\"user\"].shape == (4, 1, 3)\n        assert h[\"game\"].shape == (4, 2, 4)\n\n    # test with mod args\n    class MyMod(mx.gluon.nn.Block):\n        def __init__(self, s1, s2):\n            super(MyMod, self).__init__()\n            self.carg1 = 0\n            self.s1 = s1\n            self.s2 = s2\n\n        def forward(self, g, h, arg1=None):  # mxnet does not support kwargs\n            if arg1 is not None:\n                self.carg1 += 1\n            return F.zeros((g.number_of_dst_nodes(), self.s2))\n\n    
mod1 = MyMod(2, 3)\n    mod2 = MyMod(2, 4)\n    mod3 = MyMod(3, 4)\n    conv = nn.HeteroGraphConv(\n        {\"follows\": mod1, \"plays\": mod2, \"sells\": mod3}, agg\n    )\n    conv.initialize(ctx=F.ctx())\n    mod_args = {\"follows\": (1,), \"plays\": (1,)}\n    h = conv(g, {\"user\": uf, \"store\": sf, \"game\": gf}, mod_args)\n    assert mod1.carg1 == 1\n    assert mod2.carg1 == 1\n    assert mod3.carg1 == 0\n\n    # conv on graph without any edges\n    for etype in g.etypes:\n        g = dgl.remove_edges(g, g.edges(form=\"eid\", etype=etype), etype=etype)\n    assert g.num_edges() == 0\n    h = conv(g, {\"user\": uf, \"game\": gf, \"store\": sf})\n    assert set(h.keys()) == {\"user\", \"game\"}\n\n    block = dgl.to_block(\n        g.to(F.cpu()), {\"user\": [0, 1, 2, 3], \"game\": [0, 1, 2, 3], \"store\": []}\n    ).to(F.ctx())\n    h = conv(\n        block,\n        (\n            {\"user\": uf, \"game\": gf, \"store\": sf},\n            {\"user\": uf, \"game\": gf, \"store\": sf[0:0]},\n        ),\n    )\n    assert set(h.keys()) == {\"user\", \"game\"}\n\n\nif __name__ == \"__main__\":\n    test_graph_conv()\n    test_gat_conv()\n    test_sage_conv()\n    test_gg_conv()\n    test_cheb_conv()\n    test_agnn_conv()\n    test_appnp_conv()\n    test_dense_cheb_conv()\n    test_dense_graph_conv()\n    test_dense_sage_conv()\n    test_edge_conv()\n    test_gin_conv()\n    test_gmm_conv()\n    test_nn_conv()\n    test_sg_conv()\n    test_set2set()\n    test_glob_att_pool()\n    test_simple_pool()\n    test_rgcn()\n    test_sequential()\n    test_hetero_conv()\n"
  },
  {
    "path": "tests/python/pytorch/cuda/test_nccl.py",
    "content": "import unittest\n\nimport backend as F\nimport torch\nimport torch.distributed as dist\n\nfrom dgl.cuda import nccl\nfrom dgl.partition import NDArrayPartition\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"NCCL only runs on GPU.\"\n)\ndef test_nccl_sparse_push_single_remainder():\n    torch.cuda.set_device(\"cuda:0\")\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=\"tcp://127.0.0.1:12345\",\n        world_size=1,\n        rank=0,\n    )\n\n    index = F.randint([10000], F.int32, F.ctx(), 0, 10000)\n    value = F.uniform([10000, 100], F.float32, F.ctx(), -1.0, 1.0)\n\n    part = NDArrayPartition(10000, 1, \"remainder\")\n\n    ri, rv = nccl.sparse_all_to_all_push(index, value, part)\n    assert F.array_equal(ri, index)\n    assert F.array_equal(rv, value)\n\n    dist.destroy_process_group()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"NCCL only runs on GPU.\"\n)\ndef test_nccl_sparse_pull_single_remainder():\n    torch.cuda.set_device(\"cuda:0\")\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=\"tcp://127.0.0.1:12345\",\n        world_size=1,\n        rank=0,\n    )\n\n    req_index = F.randint([10000], F.int64, F.ctx(), 0, 100000)\n    value = F.uniform([100000, 100], F.float32, F.ctx(), -1.0, 1.0)\n\n    part = NDArrayPartition(100000, 1, \"remainder\")\n\n    rv = nccl.sparse_all_to_all_pull(req_index, value, part)\n    exp_rv = F.gather_row(value, req_index)\n    assert F.array_equal(rv, exp_rv)\n\n    dist.destroy_process_group()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"NCCL only runs on GPU.\"\n)\ndef test_nccl_sparse_push_single_range():\n    torch.cuda.set_device(\"cuda:0\")\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=\"tcp://127.0.0.1:12345\",\n        world_size=1,\n        rank=0,\n    )\n\n    index = F.randint([10000], F.int32, F.ctx(), 0, 10000)\n    value = 
F.uniform([10000, 100], F.float32, F.ctx(), -1.0, 1.0)\n\n    part_ranges = F.copy_to(\n        F.tensor([0, value.shape[0]], dtype=F.int64), F.ctx()\n    )\n    part = NDArrayPartition(10000, 1, \"range\", part_ranges=part_ranges)\n\n    ri, rv = nccl.sparse_all_to_all_push(index, value, part)\n    assert F.array_equal(ri, index)\n    assert F.array_equal(rv, value)\n\n    dist.destroy_process_group()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"NCCL only runs on GPU.\"\n)\ndef test_nccl_sparse_pull_single_range():\n    torch.cuda.set_device(\"cuda:0\")\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=\"tcp://127.0.0.1:12345\",\n        world_size=1,\n        rank=0,\n    )\n\n    req_index = F.randint([10000], F.int64, F.ctx(), 0, 100000)\n    value = F.uniform([100000, 100], F.float32, F.ctx(), -1.0, 1.0)\n\n    part_ranges = F.copy_to(\n        F.tensor([0, value.shape[0]], dtype=F.int64), F.ctx()\n    )\n    part = NDArrayPartition(100000, 1, \"range\", part_ranges=part_ranges)\n\n    rv = nccl.sparse_all_to_all_pull(req_index, value, part)\n    exp_rv = F.gather_row(value, req_index)\n    assert F.array_equal(rv, exp_rv)\n\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    test_nccl_sparse_push_single_remainder()\n    test_nccl_sparse_pull_single_remainder()\n    test_nccl_sparse_push_single_range()\n    test_nccl_sparse_pull_single_range()\n"
  },
  {
    "path": "tests/python/pytorch/dataloading/test_dataloader.py",
    "content": "import os\nimport unittest\nfrom collections.abc import Iterator, Mapping\nfrom functools import partial\n\nimport backend as F\n\nimport dgl\nimport dgl.ops as OPS\nimport numpy as np\nimport pytest\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nfrom utils import parametrize_idtype\n\n\n@pytest.mark.parametrize(\"batch_size\", [None, 16])\ndef test_graph_dataloader(batch_size):\n    num_batches = 2\n    num_samples = num_batches * (batch_size if batch_size is not None else 1)\n    minigc_dataset = dgl.data.MiniGCDataset(num_samples, 10, 20)\n    data_loader = dgl.dataloading.GraphDataLoader(\n        minigc_dataset, batch_size=batch_size, shuffle=True\n    )\n    assert isinstance(iter(data_loader), Iterator)\n    for graph, label in data_loader:\n        assert isinstance(graph, dgl.DGLGraph)\n        if batch_size is not None:\n            assert F.asnumpy(label).shape[0] == batch_size\n        else:\n            # If batch size is None, the label element will be a single scalar following\n            # PyTorch's practice.\n            assert F.asnumpy(label).ndim == 0\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@pytest.mark.parametrize(\"num_workers\", [0, 4])\ndef test_cluster_gcn(num_workers):\n    dataset = dgl.data.CoraFullDataset()\n    g = dataset[0]\n    sampler = dgl.dataloading.ClusterGCNSampler(g, 100)\n    dataloader = dgl.dataloading.DataLoader(\n        g, torch.arange(100), sampler, batch_size=4, num_workers=num_workers\n    )\n    assert len(dataloader) == 25\n    for i, sg in enumerate(dataloader):\n        pass\n\n\n@pytest.mark.parametrize(\"num_workers\", [0, 4])\ndef test_shadow(num_workers):\n    g = dgl.data.CoraFullDataset()[0]\n    sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        torch.arange(g.num_nodes()),\n        sampler,\n        batch_size=5,\n        shuffle=True,\n   
     drop_last=False,\n        num_workers=num_workers,\n    )\n    for i, (input_nodes, output_nodes, subgraph) in enumerate(dataloader):\n        assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])\n        assert torch.equal(input_nodes[: output_nodes.shape[0]], output_nodes)\n        assert torch.equal(\n            subgraph.ndata[\"label\"], g.ndata[\"label\"][input_nodes]\n        )\n        assert torch.equal(subgraph.ndata[\"feat\"], g.ndata[\"feat\"][input_nodes])\n        if i == 5:\n            break\n\n\n@pytest.mark.parametrize(\"num_workers\", [0, 4])\n@pytest.mark.parametrize(\"mode\", [\"node\", \"edge\", \"walk\"])\ndef test_saint(num_workers, mode):\n    g = dgl.data.CoraFullDataset()[0]\n\n    if mode == \"node\":\n        budget = 100\n    elif mode == \"edge\":\n        budget = 200\n    elif mode == \"walk\":\n        budget = (3, 2)\n\n    sampler = dgl.dataloading.SAINTSampler(mode, budget)\n    dataloader = dgl.dataloading.DataLoader(\n        g, torch.arange(100), sampler, num_workers=num_workers\n    )\n    assert len(dataloader) == 100\n    for sg in dataloader:\n        pass\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"mode\", [\"cpu\", \"uva_cuda_indices\", \"uva_cpu_indices\", \"pure_gpu\"]\n)\n@pytest.mark.parametrize(\"use_ddp\", [False, True])\n@pytest.mark.parametrize(\"use_mask\", [False, True])\ndef test_neighbor_nonuniform(idtype, mode, use_ddp, use_mask):\n    if mode != \"cpu\" and F.ctx() == F.cpu():\n        pytest.skip(\"UVA and GPU sampling require a GPU.\")\n    if mode != \"cpu\" and use_mask:\n        pytest.skip(\"Masked sampling only works on CPU.\")\n    if use_ddp:\n        if os.name == \"nt\":\n            pytest.skip(\"PyTorch 1.13.0+ has problems in Windows DDP...\")\n        dist.init_process_group(\n            \"gloo\" if F.ctx() == F.cpu() else \"nccl\",\n            \"tcp://127.0.0.1:12347\",\n            world_size=1,\n            rank=0,\n        )\n    g = dgl.graph(([1, 2, 3, 4, 5, 
6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(\n        idtype\n    )\n    g.edata[\"p\"] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])\n    g.edata[\"mask\"] = g.edata[\"p\"] != 0\n    if mode in (\"cpu\", \"uva_cpu_indices\"):\n        indices = F.copy_to(F.tensor([0, 1], idtype), F.cpu())\n    else:\n        indices = F.copy_to(F.tensor([0, 1], idtype), F.cuda())\n    if mode == \"pure_gpu\":\n        g = g.to(F.cuda())\n    use_uva = mode.startswith(\"uva\")\n\n    if use_mask:\n        prob, mask = None, \"mask\"\n    else:\n        prob, mask = \"p\", None\n\n    sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        [2], prob=prob, mask=mask\n    )\n    for num_workers in [0, 1, 2] if mode == \"cpu\" else [0]:\n        dataloader = dgl.dataloading.DataLoader(\n            g,\n            indices,\n            sampler,\n            batch_size=1,\n            device=F.ctx(),\n            num_workers=num_workers,\n            use_uva=use_uva,\n            use_ddp=use_ddp,\n        )\n        for input_nodes, output_nodes, blocks in dataloader:\n            seed = output_nodes.item()\n            neighbors = set(input_nodes[1:].cpu().numpy())\n            if seed == 1:\n                assert neighbors == {5, 6}\n            elif seed == 0:\n                assert neighbors == {1, 2}\n\n    g = dgl.heterograph(\n        {\n            (\"B\", \"BA\", \"A\"): (\n                [1, 2, 3, 4, 5, 6, 7, 8],\n                [0, 0, 0, 0, 1, 1, 1, 1],\n            ),\n            (\"C\", \"CA\", \"A\"): (\n                [1, 2, 3, 4, 5, 6, 7, 8],\n                [0, 0, 0, 0, 1, 1, 1, 1],\n            ),\n        }\n    ).astype(idtype)\n    g.edges[\"BA\"].data[\"p\"] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])\n    g.edges[\"BA\"].data[\"mask\"] = g.edges[\"BA\"].data[\"p\"] != 0\n    g.edges[\"CA\"].data[\"p\"] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])\n    g.edges[\"CA\"].data[\"mask\"] = g.edges[\"CA\"].data[\"p\"] != 0\n    if mode == \"pure_gpu\":\n  
      g = g.to(F.cuda())\n    for num_workers in [0, 1, 2] if mode == \"cpu\" else [0]:\n        dataloader = dgl.dataloading.DataLoader(\n            g,\n            {\"A\": indices},\n            sampler,\n            batch_size=1,\n            device=F.ctx(),\n            num_workers=num_workers,\n            use_uva=use_uva,\n            use_ddp=use_ddp,\n        )\n        for input_nodes, output_nodes, blocks in dataloader:\n            seed = output_nodes[\"A\"].item()\n            # Seed and neighbors are of different node types so slicing is not necessary here.\n            neighbors = set(input_nodes[\"B\"].cpu().numpy())\n            if seed == 1:\n                assert neighbors == {5, 6}\n            elif seed == 0:\n                assert neighbors == {1, 2}\n\n            neighbors = set(input_nodes[\"C\"].cpu().numpy())\n            if seed == 1:\n                assert neighbors == {7, 8}\n            elif seed == 0:\n                assert neighbors == {3, 4}\n\n    if use_ddp:\n        dist.destroy_process_group()\n\n\ndef _check_dtype(data, dtype, attr_name):\n    if isinstance(data, dict):\n        for k, v in data.items():\n            assert getattr(v, attr_name) == dtype\n    elif isinstance(data, list):\n        for v in data:\n            assert getattr(v, attr_name) == dtype\n    else:\n        assert getattr(data, attr_name) == dtype\n\n\ndef _check_device(data):\n    if isinstance(data, dict):\n        for k, v in data.items():\n            assert v.device == F.ctx()\n    elif isinstance(data, list):\n        for v in data:\n            assert v.device == F.ctx()\n    else:\n        assert data.device == F.ctx()\n\n\n@pytest.mark.parametrize(\"sampler_name\", [\"full\", \"neighbor\"])\n@pytest.mark.parametrize(\n    \"mode\", [\"cpu\", \"uva_cuda_indices\", \"uva_cpu_indices\", \"pure_gpu\"]\n)\n@pytest.mark.parametrize(\"nprocs\", [1, 4])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef 
test_ddp_dataloader_decompose_dataset(\n    sampler_name, mode, nprocs, drop_last\n):\n    if torch.cuda.device_count() < nprocs and mode != \"cpu\":\n        pytest.skip(\n            \"DDP dataloader needs sufficient GPUs for UVA and GPU sampling.\"\n        )\n    if mode != \"cpu\" and F.ctx() == F.cpu():\n        pytest.skip(\"UVA and GPU sampling require a GPU.\")\n\n    if os.name == \"nt\":\n        pytest.skip(\"PyTorch 1.13.0+ has problems in Windows DDP...\")\n    g, _, _, _ = _create_homogeneous()\n    g = g.to(F.cpu())\n\n    sampler = {\n        \"full\": dgl.dataloading.MultiLayerFullNeighborSampler(2),\n        \"neighbor\": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),\n    }[sampler_name]\n    indices = F.copy_to(F.arange(0, g.num_nodes()), F.cpu())\n    data = indices, sampler\n    arguments = mode, drop_last\n    g.create_formats_()\n    os.environ[\"OMP_NUM_THREADS\"] = str(mp.cpu_count() // 2 // nprocs)\n    mp.spawn(_ddp_runner, args=(nprocs, g, data, arguments), nprocs=nprocs)\n\n\ndef _ddp_runner(proc_id, nprocs, g, data, args):\n    mode, drop_last = args\n    indices, sampler = data\n    if mode == \"cpu\":\n        device = torch.device(\"cpu\")\n    else:\n        device = torch.device(proc_id)\n        torch.cuda.set_device(device)\n    if mode == \"pure_gpu\":\n        g = g.to(F.cuda())\n    if mode in (\"cpu\", \"uva_cpu_indices\"):\n        indices = indices.cpu()\n    else:\n        indices = indices.cuda()\n\n    dist.init_process_group(\n        \"nccl\" if mode != \"cpu\" else \"gloo\",\n        \"tcp://127.0.0.1:12347\",\n        world_size=nprocs,\n        rank=proc_id,\n    )\n    use_uva = mode.startswith(\"uva\")\n    batch_size = g.num_nodes()\n    shuffle = False\n    for num_workers in [1, 4] if mode == \"cpu\" else [0]:\n        dataloader = dgl.dataloading.DataLoader(\n            g,\n            indices,\n            sampler,\n            device=device,\n            batch_size=batch_size,  # g1.num_nodes(),\n    
        num_workers=num_workers,\n            use_uva=use_uva,\n            use_ddp=True,\n            drop_last=drop_last,\n            shuffle=shuffle,\n        )\n        max_nid = [0]\n        for i, (input_nodes, output_nodes, blocks) in enumerate(dataloader):\n            block = blocks[-1]\n            o_src, o_dst = block.edges()\n            src_nodes_id = block.srcdata[dgl.NID][o_src]\n            dst_nodes_id = block.dstdata[dgl.NID][o_dst]\n            max_nid.append(np.max(dst_nodes_id.cpu().numpy()))\n\n        local_max = torch.tensor(np.max(max_nid))\n        if torch.distributed.get_backend() == \"nccl\":\n            local_max = local_max.cuda()\n        dist.reduce(local_max, 0, op=dist.ReduceOp.MAX)\n        if proc_id == 0:\n            if drop_last and not shuffle and local_max > 0:\n                assert (\n                    local_max.item()\n                    == len(indices)\n                    - len(indices) % nprocs\n                    - 1\n                    - (len(indices) // nprocs) % batch_size\n                )\n            elif not drop_last:\n                assert local_max == len(indices) - 1\n    dist.destroy_process_group()\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"sampler_name\", [\"full\", \"neighbor\", \"neighbor2\", \"labor\"]\n)\n@pytest.mark.parametrize(\n    \"mode\", [\"cpu\", \"uva_cuda_indices\", \"uva_cpu_indices\", \"pure_gpu\"]\n)\n@pytest.mark.parametrize(\"use_ddp\", [False, True])\ndef test_node_dataloader(idtype, sampler_name, mode, use_ddp):\n    if mode != \"cpu\" and F.ctx() == F.cpu():\n        pytest.skip(\"UVA and GPU sampling require a GPU.\")\n    if use_ddp:\n        if os.name == \"nt\":\n            pytest.skip(\"PyTorch 1.13.0+ has problems in Windows DDP...\")\n        dist.init_process_group(\n            \"gloo\" if F.ctx() == F.cpu() else \"nccl\",\n            \"tcp://127.0.0.1:12347\",\n            world_size=1,\n            rank=0,\n        )\n    g1 = dgl.graph(([0, 
0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)\n    g1.ndata[\"feat\"] = F.copy_to(F.randn((5, 8)), F.cpu())\n    g1.ndata[\"label\"] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())\n    if mode in (\"cpu\", \"uva_cpu_indices\"):\n        indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cpu())\n    else:\n        indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cuda())\n    if mode == \"pure_gpu\":\n        g1 = g1.to(F.cuda())\n\n    use_uva = mode.startswith(\"uva\")\n\n    sampler = {\n        \"full\": dgl.dataloading.MultiLayerFullNeighborSampler(2),\n        \"neighbor\": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),\n        \"neighbor2\": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),\n        \"labor\": dgl.dataloading.LaborSampler([3, 3]),\n    }[sampler_name]\n    for num_workers in [0, 1, 2] if mode == \"cpu\" else [0]:\n        dataloader = dgl.dataloading.DataLoader(\n            g1,\n            indices,\n            sampler,\n            device=F.ctx(),\n            batch_size=g1.num_nodes(),\n            num_workers=num_workers,\n            use_uva=use_uva,\n            use_ddp=use_ddp,\n        )\n        for input_nodes, output_nodes, blocks in dataloader:\n            _check_device(input_nodes)\n            _check_device(output_nodes)\n            _check_device(blocks)\n            _check_dtype(input_nodes, idtype, \"dtype\")\n            _check_dtype(output_nodes, idtype, \"dtype\")\n            _check_dtype(blocks, idtype, \"idtype\")\n\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follow\", \"user\"): (\n                [0, 0, 0, 1, 1, 1, 2],\n                [1, 2, 3, 0, 2, 3, 0],\n            ),\n            (\"user\", \"followed-by\", \"user\"): (\n                [1, 2, 3, 0, 2, 3, 0],\n                [0, 0, 0, 1, 1, 1, 2],\n            ),\n            (\"user\", \"play\", \"game\"): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),\n            (\"game\", \"played-by\", \"user\"): ([0, 1, 2, 0, 2], [0, 
1, 1, 3, 5]),\n        }\n    ).astype(idtype)\n    for ntype in g2.ntypes:\n        g2.nodes[ntype].data[\"feat\"] = F.copy_to(\n            F.randn((g2.num_nodes(ntype), 8)), F.cpu()\n        )\n    if mode in (\"cpu\", \"uva_cpu_indices\"):\n        indices = {nty: F.copy_to(g2.nodes(nty), F.cpu()) for nty in g2.ntypes}\n    else:\n        indices = {nty: F.copy_to(g2.nodes(nty), F.cuda()) for nty in g2.ntypes}\n    if mode == \"pure_gpu\":\n        g2 = g2.to(F.cuda())\n\n    batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)\n    sampler = {\n        \"full\": dgl.dataloading.MultiLayerFullNeighborSampler(2),\n        \"neighbor\": dgl.dataloading.MultiLayerNeighborSampler(\n            [{etype: 3 for etype in g2.etypes}] * 2\n        ),\n        \"neighbor2\": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),\n        \"labor\": dgl.dataloading.LaborSampler([3, 3]),\n    }[sampler_name]\n    for num_workers in [0, 1, 2] if mode == \"cpu\" else [0]:\n        dataloader = dgl.dataloading.DataLoader(\n            g2,\n            indices,\n            sampler,\n            device=F.ctx(),\n            batch_size=batch_size,\n            num_workers=num_workers,\n            use_uva=use_uva,\n            use_ddp=use_ddp,\n        )\n        assert isinstance(iter(dataloader), Iterator)\n        for input_nodes, output_nodes, blocks in dataloader:\n            _check_device(input_nodes)\n            _check_device(output_nodes)\n            _check_device(blocks)\n            _check_dtype(input_nodes, idtype, \"dtype\")\n            _check_dtype(output_nodes, idtype, \"dtype\")\n            _check_dtype(blocks, idtype, \"idtype\")\n\n    if use_ddp:\n        dist.destroy_process_group()\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"sampler_name\", [\"full\", \"neighbor\"])\n@pytest.mark.parametrize(\n    \"neg_sampler\",\n    [\n        dgl.dataloading.negative_sampler.Uniform(2),\n        dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3),\n 
       dgl.dataloading.negative_sampler.GlobalUniform(15, True, 3),\n    ],\n)\n@pytest.mark.parametrize(\"mode\", [\"cpu\", \"uva\", \"pure_gpu\"])\n@pytest.mark.parametrize(\"use_ddp\", [False, True])\ndef test_edge_dataloader(idtype, sampler_name, neg_sampler, mode, use_ddp):\n    if mode != \"cpu\" and F.ctx() == F.cpu():\n        pytest.skip(\"UVA and GPU sampling require a GPU.\")\n    if mode == \"uva\" and isinstance(\n        neg_sampler, dgl.dataloading.negative_sampler.GlobalUniform\n    ):\n        pytest.skip(\"GlobalUniform don't support UVA yet.\")\n    if use_ddp:\n        if os.name == \"nt\":\n            pytest.skip(\"PyTorch 1.13.0+ has problems in Windows DDP...\")\n        dist.init_process_group(\n            \"gloo\" if F.ctx() == F.cpu() else \"nccl\",\n            \"tcp://127.0.0.1:12347\",\n            world_size=1,\n            rank=0,\n        )\n    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)\n    g1.ndata[\"feat\"] = F.copy_to(F.randn((5, 8)), F.cpu())\n    if mode == \"pure_gpu\":\n        g1 = g1.to(F.cuda())\n\n    sampler = {\n        \"full\": dgl.dataloading.MultiLayerFullNeighborSampler(2),\n        \"neighbor\": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),\n    }[sampler_name]\n\n    # no negative sampler\n    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)\n    dataloader = dgl.dataloading.DataLoader(\n        g1,\n        g1.edges(form=\"eid\"),\n        edge_sampler,\n        device=F.ctx(),\n        batch_size=g1.num_edges(),\n        use_uva=(mode == \"uva\"),\n        use_ddp=use_ddp,\n    )\n    for input_nodes, pos_pair_graph, blocks in dataloader:\n        _check_device(input_nodes)\n        _check_device(pos_pair_graph)\n        _check_device(blocks)\n\n    # negative sampler\n    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(\n        sampler, negative_sampler=neg_sampler\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        g1,\n        
g1.edges(form=\"eid\"),\n        edge_sampler,\n        device=F.ctx(),\n        batch_size=g1.num_edges(),\n        use_uva=(mode == \"uva\"),\n        use_ddp=use_ddp,\n    )\n    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:\n        _check_device(input_nodes)\n        _check_device(pos_pair_graph)\n        _check_device(neg_pair_graph)\n        _check_device(blocks)\n\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follow\", \"user\"): (\n                [0, 0, 0, 1, 1, 1, 2],\n                [1, 2, 3, 0, 2, 3, 0],\n            ),\n            (\"user\", \"followed-by\", \"user\"): (\n                [1, 2, 3, 0, 2, 3, 0],\n                [0, 0, 0, 1, 1, 1, 2],\n            ),\n            (\"user\", \"play\", \"game\"): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),\n            (\"game\", \"played-by\", \"user\"): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5]),\n        }\n    ).astype(idtype)\n    for ntype in g2.ntypes:\n        g2.nodes[ntype].data[\"feat\"] = F.copy_to(\n            F.randn((g2.num_nodes(ntype), 8)), F.cpu()\n        )\n    if mode == \"pure_gpu\":\n        g2 = g2.to(F.cuda())\n\n    batch_size = max(g2.num_edges(ety) for ety in g2.canonical_etypes)\n    sampler = {\n        \"full\": dgl.dataloading.MultiLayerFullNeighborSampler(2),\n        \"neighbor\": dgl.dataloading.MultiLayerNeighborSampler(\n            [{etype: 3 for etype in g2.etypes}] * 2\n        ),\n    }[sampler_name]\n\n    # no negative sampler\n    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)\n    dataloader = dgl.dataloading.DataLoader(\n        g2,\n        {ety: g2.edges(form=\"eid\", etype=ety) for ety in g2.canonical_etypes},\n        edge_sampler,\n        device=F.ctx(),\n        batch_size=batch_size,\n        use_uva=(mode == \"uva\"),\n        use_ddp=use_ddp,\n    )\n    for input_nodes, pos_pair_graph, blocks in dataloader:\n        _check_device(input_nodes)\n        _check_device(pos_pair_graph)\n        
_check_device(blocks)\n\n    # negative sampler\n    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(\n        sampler, negative_sampler=neg_sampler\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        g2,\n        {ety: g2.edges(form=\"eid\", etype=ety) for ety in g2.canonical_etypes},\n        edge_sampler,\n        device=F.ctx(),\n        batch_size=batch_size,\n        use_uva=(mode == \"uva\"),\n        use_ddp=use_ddp,\n    )\n\n    assert isinstance(iter(dataloader), Iterator)\n    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:\n        _check_device(input_nodes)\n        _check_device(pos_pair_graph)\n        _check_device(neg_pair_graph)\n        _check_device(blocks)\n\n    if use_ddp:\n        dist.destroy_process_group()\n\n\ndef _create_homogeneous():\n    s = torch.randint(0, 200, (1000,), device=F.ctx())\n    d = torch.randint(0, 200, (1000,), device=F.ctx())\n    src = torch.cat([s, d])\n    dst = torch.cat([d, s])\n    g = dgl.graph((s, d), num_nodes=200)\n    reverse_eids = torch.cat(\n        [torch.arange(1000, 2000), torch.arange(0, 1000)]\n    ).to(F.ctx())\n    always_exclude = torch.randint(0, 1000, (50,), device=F.ctx())\n    seed_edges = torch.arange(0, 1000, device=F.ctx())\n    return g, reverse_eids, always_exclude, seed_edges\n\n\ndef _create_heterogeneous():\n    edges = {}\n    for utype, etype, vtype in [(\"A\", \"AA\", \"A\"), (\"A\", \"AB\", \"B\")]:\n        s = torch.randint(0, 200, (1000,), device=F.ctx())\n        d = torch.randint(0, 200, (1000,), device=F.ctx())\n        edges[utype, etype, vtype] = (s, d)\n        edges[vtype, \"rev-\" + etype, utype] = (d, s)\n    g = dgl.heterograph(edges, num_nodes_dict={\"A\": 200, \"B\": 200})\n    reverse_etypes = {\n        \"AA\": \"rev-AA\",\n        \"AB\": \"rev-AB\",\n        \"rev-AA\": \"AA\",\n        \"rev-AB\": \"AB\",\n    }\n    always_exclude = {\n        \"AA\": torch.randint(0, 1000, (50,), device=F.ctx()),\n        \"AB\": 
torch.randint(0, 1000, (50,), device=F.ctx()),\n    }\n    seed_edges = {\n        \"AA\": torch.arange(0, 1000, device=F.ctx()),\n        \"AB\": torch.arange(0, 1000, device=F.ctx()),\n    }\n    return g, reverse_etypes, always_exclude, seed_edges\n\n\ndef _remove_duplicates(s, d):\n    s, d = list(zip(*list(set(zip(s.tolist(), d.tolist())))))\n    return torch.tensor(s, device=F.ctx()), torch.tensor(d, device=F.ctx())\n\n\ndef _find_edges_to_exclude(g, exclude, always_exclude, pair_eids):\n    if exclude == None:\n        return always_exclude\n    elif exclude == \"self\":\n        return (\n            torch.cat([pair_eids, always_exclude])\n            if always_exclude is not None\n            else pair_eids\n        )\n    elif exclude == \"reverse_id\":\n        pair_eids = torch.cat([pair_eids, pair_eids + 1000])\n        return (\n            torch.cat([pair_eids, always_exclude])\n            if always_exclude is not None\n            else pair_eids\n        )\n    elif exclude == \"reverse_types\":\n        pair_eids = {g.to_canonical_etype(k): v for k, v in pair_eids.items()}\n        if (\"A\", \"AA\", \"A\") in pair_eids:\n            pair_eids[(\"A\", \"rev-AA\", \"A\")] = pair_eids[(\"A\", \"AA\", \"A\")]\n        if (\"A\", \"AB\", \"B\") in pair_eids:\n            pair_eids[(\"B\", \"rev-AB\", \"A\")] = pair_eids[(\"A\", \"AB\", \"B\")]\n        if always_exclude is not None:\n            always_exclude = {\n                g.to_canonical_etype(k): v for k, v in always_exclude.items()\n            }\n            for k in always_exclude.keys():\n                if k in pair_eids:\n                    pair_eids[k] = torch.cat([pair_eids[k], always_exclude[k]])\n                else:\n                    pair_eids[k] = always_exclude[k]\n        return pair_eids\n\n\n@pytest.mark.parametrize(\"always_exclude_flag\", [False, True])\n@pytest.mark.parametrize(\n    \"exclude\", [None, \"self\", \"reverse_id\", 
\"reverse_types\"]\n)\n@pytest.mark.parametrize(\n    \"sampler\",\n    [\n        dgl.dataloading.MultiLayerFullNeighborSampler(1),\n        dgl.dataloading.ShaDowKHopSampler([5]),\n    ],\n)\n@pytest.mark.parametrize(\"batch_size\", [1, 50])\ndef test_edge_dataloader_excludes(\n    exclude, always_exclude_flag, batch_size, sampler\n):\n    if exclude == \"reverse_types\":\n        g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous()\n    else:\n        g, reverse_eids, always_exclude, seed_edges = _create_homogeneous()\n    g = g.to(F.ctx())\n    if not always_exclude_flag:\n        always_exclude = None\n\n    kwargs = {}\n    kwargs[\"exclude\"] = (\n        partial(_find_edges_to_exclude, g, exclude, always_exclude)\n        if always_exclude_flag\n        else exclude\n    )\n    kwargs[\"reverse_eids\"] = reverse_eids if exclude == \"reverse_id\" else None\n    kwargs[\"reverse_etypes\"] = (\n        reverse_etypes if exclude == \"reverse_types\" else None\n    )\n    sampler = dgl.dataloading.as_edge_prediction_sampler(sampler, **kwargs)\n\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        seed_edges,\n        sampler,\n        batch_size=batch_size,\n        device=F.ctx(),\n        use_prefetch_thread=False,\n    )\n    for i, (input_nodes, pair_graph, blocks) in enumerate(dataloader):\n        if isinstance(blocks, list):\n            subg = blocks[0]\n        else:\n            subg = blocks\n        pair_eids = pair_graph.edata[dgl.EID]\n        block_eids = subg.edata[dgl.EID]\n\n        edges_to_exclude = _find_edges_to_exclude(\n            g, exclude, always_exclude, pair_eids\n        )\n        if edges_to_exclude is None:\n            continue\n        edges_to_exclude = dgl.utils.recursive_apply(\n            edges_to_exclude, lambda x: x.cpu().numpy()\n        )\n        block_eids = dgl.utils.recursive_apply(\n            block_eids, lambda x: x.cpu().numpy()\n        )\n\n        if 
isinstance(edges_to_exclude, Mapping):\n            for k in edges_to_exclude.keys():\n                assert not np.isin(edges_to_exclude[k], block_eids[k]).any()\n        else:\n            assert not np.isin(edges_to_exclude, block_eids).any()\n\n        if i == 10:\n            break\n\n\ndef test_edge_dataloader_exclusion_with_reverse_seed_nodes():\n    utype, etype, vtype = (\"A\", \"AB\", \"B\")\n    s = torch.randint(0, 20, (500,), device=F.ctx())\n    d = torch.randint(0, 20, (500,), device=F.ctx())\n    s, d = _remove_duplicates(s, d)\n    g = dgl.heterograph({(\"A\", \"AB\", \"B\"): (s, d), (\"B\", \"BA\", \"A\"): (d, s)})\n    sampler = dgl.dataloading.as_edge_prediction_sampler(\n        dgl.dataloading.NeighborSampler(fanouts=[2, 2, 2]),\n        exclude=\"reverse_types\",\n        reverse_etypes={\"AB\": \"BA\", \"BA\": \"AB\"},\n    )\n    seed_edges = {\n        \"AB\": torch.arange(g.number_of_edges(\"AB\"), device=F.ctx()),\n        \"BA\": torch.arange(g.number_of_edges(\"BA\"), device=F.ctx()),\n    }\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        seed_edges,\n        sampler,\n        batch_size=2,\n        device=F.ctx(),\n        shuffle=True,\n        drop_last=False,\n    )\n    for _, pos_graph, mfgs in dataloader:\n        s, d = pos_graph[\"AB\"].edges()\n        AB_pos = list(zip(s.tolist(), d.tolist()))\n        s, d = pos_graph[\"BA\"].edges()\n        BA_pos = list(zip(s.tolist(), d.tolist()))\n\n        s, d = mfgs[-1][\"AB\"].edges()\n        AB_mfg = list(zip(s.tolist(), d.tolist()))\n        s, d = mfgs[-1][\"BA\"].edges()\n        BA_mfg = list(zip(s.tolist(), d.tolist()))\n\n        assert all(edge not in AB_mfg for edge in AB_pos)\n        assert all(edge not in BA_mfg for edge in BA_pos)\n\n\ndef test_edge_dataloader_exclusion_without_all_reverses():\n    data_dict = {\n        (\"A\", \"AB\", \"B\"): (torch.tensor([0, 1]), torch.tensor([0, 1])),\n        (\"B\", \"BA\", \"A\"): (torch.tensor([0, 1]), 
torch.tensor([0, 1])),\n        (\"B\", \"BC\", \"C\"): (torch.tensor([0]), torch.tensor([0])),\n        (\"C\", \"CA\", \"A\"): (torch.tensor([0, 1]), torch.tensor([0, 1])),\n    }\n    g = dgl.heterograph(data_dict=data_dict)\n    block_sampler = dgl.dataloading.MultiLayerNeighborSampler(\n        fanouts=[1], replace=True\n    )\n    block_sampler = dgl.dataloading.as_edge_prediction_sampler(\n        block_sampler,\n        exclude=\"reverse_types\",\n        reverse_etypes={\"AB\": \"BA\"},\n    )\n    d = dgl.dataloading.DataLoader(\n        graph=g,\n        indices={\n            \"AB\": torch.tensor([0]),\n            \"BC\": torch.tensor([0]),\n        },\n        graph_sampler=block_sampler,\n        batch_size=2,\n        shuffle=True,\n        drop_last=False,\n        num_workers=0,\n        device=F.ctx(),\n        use_ddp=False,\n    )\n\n    next(iter(d))\n\n\ndef dummy_worker_init_fn(worker_id):\n    pass\n\n\ndef test_dataloader_worker_init_fn():\n    dataset = dgl.data.CoraFullDataset()\n    g = dataset[0]\n    sampler = dgl.dataloading.MultiLayerNeighborSampler([2])\n    dataloader = dgl.dataloading.DataLoader(\n        g,\n        torch.arange(100),\n        sampler,\n        batch_size=4,\n        num_workers=4,\n        worker_init_fn=dummy_worker_init_fn,\n    )\n    for _ in dataloader:\n        pass\n\n\ndef test_distributed_dataloaders():\n    # Test distributed dataloaders could be successfully imported.\n    try:\n        from dgl.dataloading import (\n            DistDataLoader,\n            DistEdgeDataLoader,\n            DistNodeDataLoader,\n            EdgeCollator,\n            NodeCollator,\n        )\n    except ImportError:\n        pytest.fail(\"Distributed DataLoader from dataloading import failed\")\n\n    try:\n        from dgl.distributed import (\n            DistDataLoader,\n            DistEdgeDataLoader,\n            DistNodeDataLoader,\n            EdgeCollator,\n            NodeCollator,\n        )\n    except 
ImportError:\n        pytest.fail(\"Distributed DataLoader from distributed import failed\")\n\n\nif __name__ == \"__main__\":\n    # test_node_dataloader(F.int32, 'neighbor', None)\n    test_edge_dataloader_excludes(\n        \"reverse_types\", False, 1, dgl.dataloading.ShaDowKHopSampler([5])\n    )\n    test_edge_dataloader_exclusion_without_all_reverses()\n"
  },
  {
    "path": "tests/python/pytorch/dataloading/test_spot_target.py",
    "content": "from collections.abc import Mapping\n\nimport dgl\nimport numpy as np\nimport pytest\nimport torch\n\n\ndef _create_homogeneous():\n    s = torch.randint(0, 200, (1000,))\n    d = torch.randint(0, 200, (1000,))\n    g = dgl.graph((s, d), num_nodes=200)\n    reverse_eids = torch.cat([torch.arange(1000, 2000), torch.arange(0, 1000)])\n    seed_edges = torch.arange(0, 1000)\n    return g, reverse_eids, seed_edges\n\n\ndef _find_edges_to_exclude(g, pair_eids, degree_threshold):\n    src, dst = g.find_edges(pair_eids)\n    head_degree = g.in_degrees(src)\n    tail_degree = g.in_degrees(dst)\n    degree = torch.min(head_degree, tail_degree)\n    degree_mask = degree < degree_threshold\n    low_degree_pair_eids = pair_eids[degree_mask]\n    low_degree_pair_eids = torch.cat(\n        [low_degree_pair_eids, low_degree_pair_eids + 1000]\n    )\n    return low_degree_pair_eids\n\n\n@pytest.mark.parametrize(\"degree_threshold\", [1, 2, 3, 4, 5])\n@pytest.mark.parametrize(\"batch_size\", [1, 10, 50])\ndef test_spot_target_excludes(degree_threshold, batch_size):\n    g, reverse_eids, seed_edges = _create_homogeneous()\n    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)\n    low_degree_excluder = dgl.dataloading.SpotTarget(\n        g,\n        exclude=\"reverse_id\",\n        degree_threshold=degree_threshold,\n        reverse_eids=reverse_eids,\n    )\n    sampler = dgl.dataloading.as_edge_prediction_sampler(\n        sampler,\n        exclude=low_degree_excluder,\n        negative_sampler=dgl.dataloading.negative_sampler.Uniform(1),\n    )\n    dataloader = dgl.dataloading.DataLoader(\n        g, seed_edges, sampler, batch_size=batch_size\n    )\n\n    for i, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(\n        dataloader\n    ):\n        if isinstance(blocks, list):\n            subg = blocks[0]\n        else:\n            subg = blocks\n        pair_eids = pair_graph.edata[dgl.EID]\n        block_eids = subg.edata[dgl.EID]\n    
    edges_to_exclude = _find_edges_to_exclude(\n            g, pair_eids, degree_threshold\n        )\n        if edges_to_exclude is None:\n            continue\n        edges_to_exclude = dgl.utils.recursive_apply(\n            edges_to_exclude, lambda x: x.cpu().numpy()\n        )\n        block_eids = dgl.utils.recursive_apply(\n            block_eids, lambda x: x.cpu().numpy()\n        )\n\n        if isinstance(edges_to_exclude, Mapping):\n            for k in edges_to_exclude.keys():\n                assert not np.isin(edges_to_exclude[k], block_eids[k]).any()\n        else:\n            assert not np.isin(edges_to_exclude, block_eids).any()\n\n        if i == 10:\n            break\n\n\nif __name__ == \"__main__\":\n    test_spot_target_excludes(degree_threshold=2, batch_size=10)\n"
  },
  {
    "path": "tests/python/pytorch/distributed/optim/test_dist_optim.py",
    "content": "import os\n\nos.environ[\"OMP_NUM_THREADS\"] = \"1\"\nimport multiprocessing as mp\nimport pickle\nimport random\nimport socket\nimport sys\nimport time\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport numpy as np\nimport torch as th\nfrom dgl import function as fn\nfrom dgl.distributed import (\n    DistEmbedding,\n    DistGraph,\n    DistGraphServer,\n    load_partition_book,\n    partition_graph,\n)\nfrom dgl.distributed.optim import SparseAdagrad, SparseAdam\nfrom scipy import sparse as spsp\n\n# Set seeds to make tests fully reproducible.\nSEED = 12345  # random.randint(1, 99999)\nF.seed(SEED)\n\n\ndef create_random_graph(n):\n    arr = (\n        spsp.random(n, n, density=0.001, format=\"coo\", random_state=100) != 0\n    ).astype(np.int64)\n    return dgl.from_scipy(arr)\n\n\ndef get_local_usable_addr():\n    \"\"\"Get local usable IP and port\n\n    Returns\n    -------\n    str\n        IP address, e.g., '192.168.8.12:50051'\n    \"\"\"\n    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)\n    try:\n        # doesn't even have to be reachable\n        sock.connect((\"10.255.255.255\", 1))\n        ip_addr = sock.getsockname()[0]\n    except ValueError:\n        ip_addr = \"127.0.0.1\"\n    finally:\n        sock.close()\n    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n    sock.bind((\"\", 0))\n    sock.listen(1)\n    port = sock.getsockname()[1]\n    sock.close()\n\n    return ip_addr + \" \" + str(port)\n\n\ndef prepare_dist():\n    ip_config = open(\"optim_ip_config.txt\", \"w\")\n    ip_addr = get_local_usable_addr()\n    ip_config.write(\"{}\\n\".format(ip_addr))\n    ip_config.close()\n\n\ndef run_server(graph_name, server_id, server_count, num_clients, shared_mem):\n    g = DistGraphServer(\n        server_id,\n        \"optim_ip_config.txt\",\n        num_clients,\n        server_count,\n        \"/tmp/dist_graph/{}.json\".format(graph_name),\n        disable_shared_mem=not shared_mem,\n    )\n    
print(\"start server\", server_id)\n    g.start()\n\n\ndef initializer(shape, dtype):\n    arr = th.zeros(shape, dtype=dtype)\n    th.manual_seed(0)\n    th.nn.init.uniform_(arr, 0, 1.0)\n    return arr\n\n\ndef run_client(graph_name, cli_id, part_id, server_count):\n    device = F.ctx()\n    time.sleep(5)\n    os.environ[\"DGL_NUM_SERVER\"] = str(server_count)\n    dgl.distributed.initialize(\"optim_ip_config.txt\")\n    gpb, graph_name, _, _ = load_partition_book(\n        \"/tmp/dist_graph/{}.json\".format(graph_name), part_id\n    )\n    g = DistGraph(graph_name, gpb=gpb)\n    policy = dgl.distributed.PartitionPolicy(\"node\", g.get_partition_book())\n    num_nodes = g.num_nodes()\n    emb_dim = 4\n    dgl_emb = DistEmbedding(\n        num_nodes,\n        emb_dim,\n        name=\"optim\",\n        init_func=initializer,\n        part_policy=policy,\n    )\n    dgl_emb_zero = DistEmbedding(\n        num_nodes,\n        emb_dim,\n        name=\"optim-zero\",\n        init_func=initializer,\n        part_policy=policy,\n    )\n    dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)\n    dgl_adam._world_size = 1\n    dgl_adam._rank = 0\n\n    torch_emb = th.nn.Embedding(num_nodes, emb_dim, sparse=True)\n    torch_emb_zero = th.nn.Embedding(num_nodes, emb_dim, sparse=True)\n    th.manual_seed(0)\n    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)\n    th.manual_seed(0)\n    th.nn.init.uniform_(torch_emb_zero.weight, 0, 1.0)\n    torch_adam = th.optim.SparseAdam(\n        list(torch_emb.parameters()) + list(torch_emb_zero.parameters()),\n        lr=0.01,\n    )\n\n    labels = th.ones((4,)).long()\n    idx = th.randint(0, num_nodes, size=(4,))\n    dgl_value = dgl_emb(idx, device).to(th.device(\"cpu\"))\n    torch_value = torch_emb(idx)\n    torch_adam.zero_grad()\n    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)\n    torch_loss.backward()\n    torch_adam.step()\n\n    dgl_adam.zero_grad()\n    dgl_loss = 
th.nn.functional.cross_entropy(dgl_value, labels)\n    dgl_loss.backward()\n    dgl_adam.step()\n\n    assert F.allclose(\n        dgl_emb.weight[0 : num_nodes // 2], torch_emb.weight[0 : num_nodes // 2]\n    )\n\n\ndef check_sparse_adam(num_trainer=1, shared_mem=True):\n    prepare_dist()\n    g = create_random_graph(2000)\n    num_servers = num_trainer\n    num_clients = num_trainer\n    num_parts = 1\n\n    graph_name = \"dist_graph_test\"\n    partition_graph(g, graph_name, num_parts, \"/tmp/dist_graph\")\n\n    # let's just test on one partition for now.\n    # We cannot run multiple servers and clients on the same machine.\n    serv_ps = []\n    ctx = mp.get_context(\"spawn\")\n    for serv_id in range(num_servers):\n        p = ctx.Process(\n            target=run_server,\n            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),\n        )\n        serv_ps.append(p)\n        p.start()\n\n    cli_ps = []\n    for cli_id in range(num_clients):\n        print(\"start client\", cli_id)\n        p = ctx.Process(\n            target=run_client, args=(graph_name, cli_id, 0, num_servers)\n        )\n        p.start()\n        cli_ps.append(p)\n\n    for p in cli_ps:\n        p.join()\n\n    for p in serv_ps:\n        p.join()\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\ndef test_sparse_opt():\n    os.environ[\"DGL_DIST_MODE\"] = \"distributed\"\n    check_sparse_adam(1, True)\n    check_sparse_adam(1, False)\n\n\nif __name__ == \"__main__\":\n    os.makedirs(\"/tmp/dist_graph\", exist_ok=True)\n    test_sparse_opt()\n"
  },
  {
    "path": "tests/python/pytorch/geometry/test_geometry.py",
    "content": "import backend as F\n\nimport dgl\nimport dgl.nn\nimport numpy as np\nimport pytest\nimport torch as th\nfrom dgl import DGLError\nfrom dgl.base import DGLWarning\nfrom dgl.geometry import farthest_point_sampler, neighbor_matching\nfrom utils import parametrize_idtype\nfrom utils.graph_cases import get_cases\n\n\ndef test_fps():\n    N = 1000\n    batch_size = 5\n    sample_points = 10\n    x = th.tensor(np.random.uniform(size=(batch_size, int(N / batch_size), 3)))\n    ctx = F.ctx()\n    if F.gpu_ctx():\n        x = x.to(ctx)\n    res = farthest_point_sampler(x, sample_points)\n    assert res.shape[0] == batch_size\n    assert res.shape[1] == sample_points\n    assert res.sum() > 0\n\n\ndef test_fps_start_idx():\n    N = 1000\n    batch_size = 5\n    sample_points = 10\n    x = th.tensor(np.random.uniform(size=(batch_size, int(N / batch_size), 3)))\n    ctx = F.ctx()\n    if F.gpu_ctx():\n        x = x.to(ctx)\n    res = farthest_point_sampler(x, sample_points, start_idx=0)\n    assert th.any(res[:, 0] == 0)\n\n\ndef _test_knn_common(device, algorithm, dist, exclude_self):\n    x = th.randn(8, 3).to(device)\n    kg = dgl.nn.KNNGraph(3)\n    if dist == \"euclidean\":\n        d = th.cdist(x, x).to(F.cpu())\n    else:\n        x = x + th.randn(1).item()\n        tmp_x = x / (1e-5 + F.sqrt(F.sum(x * x, dim=1, keepdims=True)))\n        d = 1 - F.matmul(tmp_x, tmp_x.T).to(F.cpu())\n\n    def check_knn(g, x, start, end, k, exclude_self, check_indices=True):\n        assert g.device == x.device\n        g = g.to(F.cpu())\n        for v in range(start, end):\n            src, _ = g.in_edges(v)\n            src = set(src.numpy())\n            assert len(src) == k\n            if check_indices:\n                i = v - start\n                src_ans = set(\n                    th.topk(\n                        d[start:end, start:end][i],\n                        k + (1 if exclude_self else 0),\n                        largest=False,\n                    
)[1].numpy()\n                    + start\n                )\n                if exclude_self:\n                    # remove self\n                    src_ans.remove(v)\n                assert src == src_ans\n\n    def check_batch(g, k, expected_batch_info):\n        assert F.array_equal(g.batch_num_nodes(), F.tensor(expected_batch_info))\n        assert F.array_equal(\n            g.batch_num_edges(), k * F.tensor(expected_batch_info)\n        )\n\n    # check knn with 2d input\n    g = kg(x, algorithm, dist, exclude_self)\n    check_knn(g, x, 0, 8, 3, exclude_self)\n    check_batch(g, 3, [8])\n\n    # check knn with 3d input\n    g = kg(x.view(2, 4, 3), algorithm, dist, exclude_self)\n    check_knn(g, x, 0, 4, 3, exclude_self)\n    check_knn(g, x, 4, 8, 3, exclude_self)\n    check_batch(g, 3, [4, 4])\n\n    # check segmented knn\n    # there are only 2 edges per node possible when exclude_self with 3 nodes in the segment\n    # and this test case isn't supposed to warn, so limit it when exclude_self is True\n    adjusted_k = 3 - (1 if exclude_self else 0)\n    kg = dgl.nn.SegmentedKNNGraph(adjusted_k)\n    g = kg(x, [3, 5], algorithm, dist, exclude_self)\n    check_knn(g, x, 0, 3, adjusted_k, exclude_self)\n    check_knn(g, x, 3, 8, adjusted_k, exclude_self)\n    check_batch(g, adjusted_k, [3, 5])\n\n    # check k > num_points\n    kg = dgl.nn.KNNGraph(10)\n    with pytest.warns(DGLWarning):\n        g = kg(x, algorithm, dist, exclude_self)\n    # there are only 7 edges per node possible when exclude_self with 8 nodes total\n    adjusted_k = 8 - (1 if exclude_self else 0)\n    check_knn(g, x, 0, 8, adjusted_k, exclude_self)\n    check_batch(g, adjusted_k, [8])\n\n    with pytest.warns(DGLWarning):\n        g = kg(x.view(2, 4, 3), algorithm, dist, exclude_self)\n    # there are only 3 edges per node possible when exclude_self with 4 nodes per segment\n    adjusted_k = 4 - (1 if exclude_self else 0)\n    check_knn(g, x, 0, 4, adjusted_k, exclude_self)\n    
check_knn(g, x, 4, 8, adjusted_k, exclude_self)\n    check_batch(g, adjusted_k, [4, 4])\n\n    kg = dgl.nn.SegmentedKNNGraph(5)\n    with pytest.warns(DGLWarning):\n        g = kg(x, [3, 5], algorithm, dist, exclude_self)\n    # there are only 2 edges per node possible when exclude_self in the segment with\n    # only 3 nodes, and the current implementation reduces k for all segments\n    # in that case\n    adjusted_k = 3 - (1 if exclude_self else 0)\n    check_knn(g, x, 0, 3, adjusted_k, exclude_self)\n    check_knn(g, x, 3, 8, adjusted_k, exclude_self)\n    check_batch(g, adjusted_k, [3, 5])\n\n    # check k == 0\n    # that's valid for exclude_self, but -1 is not, so check -1 instead for exclude_self\n    adjusted_k = 0 - (1 if exclude_self else 0)\n    kg = dgl.nn.KNNGraph(adjusted_k)\n    with pytest.raises(DGLError):\n        g = kg(x, algorithm, dist, exclude_self)\n    kg = dgl.nn.SegmentedKNNGraph(adjusted_k)\n    with pytest.raises(DGLError):\n        g = kg(x, [3, 5], algorithm, dist, exclude_self)\n\n    # check empty\n    x_empty = th.tensor([])\n    kg = dgl.nn.KNNGraph(3)\n    with pytest.raises(DGLError):\n        g = kg(x_empty, algorithm, dist, exclude_self)\n    kg = dgl.nn.SegmentedKNNGraph(3)\n    with pytest.raises(DGLError):\n        g = kg(x_empty, [3, 5], algorithm, dist, exclude_self)\n\n    # check all coincident points\n    x = th.zeros((20, 3)).to(device)\n    kg = dgl.nn.KNNGraph(3)\n    g = kg(x, algorithm, dist, exclude_self)\n    # different algorithms may break the tie differently, so don't check the indices\n    check_knn(g, x, 0, 20, 3, exclude_self, False)\n    check_batch(g, 3, [20])\n\n    # check all coincident points\n    kg = dgl.nn.SegmentedKNNGraph(3)\n    g = kg(x, [4, 7, 5, 4], algorithm, dist, exclude_self)\n    # different algorithms may break the tie differently, so don't check the indices\n    check_knn(g, x, 0, 4, 3, exclude_self, False)\n    check_knn(g, x, 4, 11, 3, exclude_self, False)\n    check_knn(g, x, 11, 
16, 3, exclude_self, False)\n    check_knn(g, x, 16, 20, 3, exclude_self, False)\n    check_batch(g, 3, [4, 7, 5, 4])\n\n\n@pytest.mark.parametrize(\n    \"algorithm\", [\"bruteforce-blas\", \"bruteforce\", \"kd-tree\"]\n)\n@pytest.mark.parametrize(\"dist\", [\"euclidean\", \"cosine\"])\n@pytest.mark.parametrize(\"exclude_self\", [False, True])\ndef test_knn_cpu(algorithm, dist, exclude_self):\n    _test_knn_common(F.cpu(), algorithm, dist, exclude_self)\n\n\n@pytest.mark.parametrize(\n    \"algorithm\", [\"bruteforce-blas\", \"bruteforce\", \"bruteforce-sharemem\"]\n)\n@pytest.mark.parametrize(\"dist\", [\"euclidean\", \"cosine\"])\n@pytest.mark.parametrize(\"exclude_self\", [False, True])\ndef test_knn_cuda(algorithm, dist, exclude_self):\n    if not th.cuda.is_available():\n        return\n    _test_knn_common(F.cuda(), algorithm, dist, exclude_self)\n\n\n@pytest.mark.parametrize(\"num_points\", [8, 64, 256, 1024])\ndef test_knn_sharedmem_large(num_points):\n    if not th.cuda.is_available():\n        return\n    x = th.randn(num_points, 5, device=\"cuda\")\n    y = th.randn(num_points, 5, device=\"cuda\")\n    k = 4\n\n    def ground_truth(x, y, k):\n        dist = (\n            th.sum(x * x, dim=1)\n            + th.sum(y * y, dim=1).unsqueeze(-1)\n            - 2 * th.mm(y, x.T)\n        )\n        ret = th.topk(dist, k, dim=-1, largest=False)[1]\n        return th.sort(ret, dim=-1)[0]\n\n    gt = ground_truth(x, y, k)\n    actual = th.sort(\n        dgl.functional.knn(\n            k, x, [num_points], y, [num_points], algorithm=\"bruteforce-sharemem\"\n        )[1].reshape(-1, k),\n        -1,\n    )[0]\n    assert th.all(actual == gt).item()\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"dglgraph\"]))\n@pytest.mark.parametrize(\"weight\", [True, False])\n@pytest.mark.parametrize(\"relabel\", [True, False])\ndef test_edge_coarsening(idtype, g, weight, relabel):\n    num_nodes = g.num_nodes()\n    g = 
dgl.to_bidirected(g)\n    g = g.astype(idtype).to(F.ctx())\n    edge_weight = None\n    if weight:\n        edge_weight = F.abs(F.randn((g.num_edges(),))).to(F.ctx())\n    node_labels = neighbor_matching(g, edge_weight, relabel_idx=relabel)\n    unique_ids, counts = th.unique(node_labels, return_counts=True)\n    num_result_ids = unique_ids.size(0)\n\n    # shape correct\n    assert node_labels.shape == (g.num_nodes(),)\n\n    # all nodes marked\n    assert F.reduce_sum(node_labels < 0).item() == 0\n\n    # number of unique node ids correct.\n    assert num_result_ids >= num_nodes // 2 and num_result_ids <= num_nodes\n\n    # each unique id has <= 2 nodes\n    assert F.reduce_sum(counts > 2).item() == 0\n\n    # if two nodes have the same id, they must be neighbors\n    idxs = F.arange(0, num_nodes, idtype)\n    for l in unique_ids:\n        l = l.item()\n        idx = idxs[(node_labels == l)]\n        if idx.size(0) == 2:\n            u, v = idx[0].item(), idx[1].item()\n            assert g.has_edges_between(u, v)\n\n\nif __name__ == \"__main__\":\n    test_fps()\n    test_fps_start_idx()\n    test_knn()\n    test_knn_sharedmem_large()\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/__init__.py",
    "content": "\"\"\" DGL graphbolt API tests\"\"\"\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/gb_test_utils.py",
    "content": "import os\n\nimport dgl\nimport dgl.graphbolt as gb\n\nimport numpy as np\nimport pandas as pd\nimport scipy.sparse as sp\nimport torch\n\n\ndef rand_csc_graph(N, density, bidirection_edge=False):\n    adj = sp.random(N, N, density)\n    if bidirection_edge:\n        adj = adj + adj.T\n    adj = adj.tocsc()\n\n    indptr = torch.LongTensor(adj.indptr)\n    indices = torch.LongTensor(adj.indices)\n\n    graph = gb.fused_csc_sampling_graph(indptr, indices)\n\n    return graph\n\n\ndef random_homo_graph(num_nodes, num_edges):\n    csc_indptr = torch.randint(0, num_edges, (num_nodes + 1,))\n    csc_indptr = torch.sort(csc_indptr)[0]\n    csc_indptr[0] = 0\n    csc_indptr[-1] = num_edges\n    indices = torch.randint(0, num_nodes, (num_edges,))\n    return csc_indptr, indices\n\n\ndef get_type_to_id(num_ntypes, num_etypes):\n    ntypes = {f\"n{i}\": i for i in range(num_ntypes)}\n    etypes = {}\n    count = 0\n    for n1 in range(num_ntypes):\n        for n2 in range(n1, num_ntypes):\n            if count >= num_etypes:\n                break\n            etypes.update({f\"n{n1}:e{count}:n{n2}\": count})\n            count += 1\n    return ntypes, etypes\n\n\ndef get_ntypes_and_etypes(num_nodes, num_ntypes, num_etypes):\n    ntypes = {f\"n{i}\": num_nodes // num_ntypes for i in range(num_ntypes)}\n    if num_nodes % num_ntypes != 0:\n        ntypes[\"n0\"] += num_nodes % num_ntypes\n    etypes = []\n    count = 0\n    while count < num_etypes:\n        for n1 in range(num_ntypes):\n            for n2 in range(num_ntypes):\n                if count >= num_etypes:\n                    break\n                etypes.append((f\"n{n1}\", f\"e{count}\", f\"n{n2}\"))\n                count += 1\n    return ntypes, etypes\n\n\ndef random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):\n    ntypes, etypes = get_ntypes_and_etypes(num_nodes, num_ntypes, num_etypes)\n    edges = {}\n    for step, etype in enumerate(etypes):\n        src_ntype, _, 
dst_ntype = etype\n        num_e = num_edges // num_etypes + (\n            0 if step != 0 else num_edges % num_etypes\n        )\n        if ntypes[src_ntype] == 0 or ntypes[dst_ntype] == 0:\n            continue\n        src = torch.randint(0, ntypes[src_ntype], (num_e,))\n        dst = torch.randint(0, ntypes[dst_ntype], (num_e,))\n\n        edges[etype] = (src, dst)\n\n    gb_g = gb.from_dglgraph(dgl.heterograph(edges, ntypes))\n    return (\n        gb_g.csc_indptr,\n        gb_g.indices,\n        gb_g.node_type_offset,\n        gb_g.type_per_edge,\n        gb_g.node_type_to_id,\n        gb_g.edge_type_to_id,\n    )\n\n\ndef random_homo_graphbolt_graph(\n    test_dir, dataset_name, num_nodes, num_edges, num_classes, edge_fmt=\"csv\"\n):\n    \"\"\"Generate random graphbolt version homograph\"\"\"\n    # Generate random edges.\n    nodes = np.repeat(np.arange(num_nodes, dtype=np.int64), 5)\n    neighbors = np.random.randint(\n        0, num_nodes, size=(num_edges), dtype=np.int64\n    )\n    edges = np.stack([nodes, neighbors], axis=1)\n    os.makedirs(os.path.join(test_dir, \"edges\"), exist_ok=True)\n    assert edge_fmt in [\n        \"numpy\",\n        \"csv\",\n    ], \"Only numpy and csv are supported for edges.\"\n    if edge_fmt == \"csv\":\n        # Write into edges/edge.csv\n        edges_DataFrame = pd.DataFrame(edges, columns=[\"src\", \"dst\"])\n        edge_path = os.path.join(\"edges\", \"edge.csv\")\n        edges_DataFrame.to_csv(\n            os.path.join(test_dir, edge_path),\n            index=False,\n            header=False,\n        )\n    else:\n        # Write into edges/edge.npy\n        edges = edges.T\n        edge_path = os.path.join(\"edges\", \"edge.npy\")\n        np.save(os.path.join(test_dir, edge_path), edges)\n\n    # Generate random graph edge-feats.\n    edge_feats = np.random.rand(num_edges, num_classes)\n    os.makedirs(os.path.join(test_dir, \"data\"), exist_ok=True)\n    edge_feat_path = os.path.join(\"data\", 
\"edge-feat.npy\")\n    np.save(os.path.join(test_dir, edge_feat_path), edge_feats)\n\n    # Generate random node-feats.\n    if num_classes == 1:\n        node_feats = np.random.rand(num_nodes)\n    else:\n        node_feats = np.random.rand(num_nodes, num_classes)\n    node_feat_path = os.path.join(\"data\", \"node-feat.npy\")\n    np.save(os.path.join(test_dir, node_feat_path), node_feats)\n\n    # Generate train/test/valid set.\n    assert num_nodes % 4 == 0, \"num_nodes must be divisible by 4\"\n    each_set_size = num_nodes // 4\n    os.makedirs(os.path.join(test_dir, \"set\"), exist_ok=True)\n    train_pairs = (\n        np.arange(each_set_size),\n        np.arange(each_set_size, 2 * each_set_size),\n    )\n    train_data = np.vstack(train_pairs).T.astype(edges.dtype)\n    train_path = os.path.join(\"set\", \"train.npy\")\n    np.save(os.path.join(test_dir, train_path), train_data)\n\n    validation_pairs = (\n        np.arange(each_set_size, 2 * each_set_size),\n        np.arange(2 * each_set_size, 3 * each_set_size),\n    )\n    validation_data = np.vstack(validation_pairs).T.astype(edges.dtype)\n    validation_path = os.path.join(\"set\", \"validation.npy\")\n    np.save(os.path.join(test_dir, validation_path), validation_data)\n\n    test_pairs = (\n        np.arange(2 * each_set_size, 3 * each_set_size),\n        np.arange(3 * each_set_size, 4 * each_set_size),\n    )\n    test_data = np.vstack(test_pairs).T.astype(edges.dtype)\n    test_path = os.path.join(\"set\", \"test.npy\")\n    np.save(os.path.join(test_dir, test_path), test_data)\n\n    yaml_content = f\"\"\"\n        dataset_name: {dataset_name}\n        graph: # Graph structure and required attributes.\n            nodes:\n                - num: {num_nodes}\n            edges:\n                - format: {edge_fmt}\n                  path: {edge_path}\n            feature_data:\n                - domain: node\n                  type: null\n                  name: feat\n                  format: 
numpy\n                  in_memory: true\n                  path: {node_feat_path}\n                - domain: edge\n                  type: null\n                  name: feat\n                  format: numpy\n                  in_memory: true\n                  path: {edge_feat_path}\n        feature_data:\n            - domain: node\n              type: null\n              name: feat\n              format: numpy\n              in_memory: true\n              path: {node_feat_path}\n            - domain: edge\n              type: null\n              name: feat\n              format: numpy\n              path: {edge_feat_path}\n        tasks:\n          - name: link_prediction\n            num_classes: {num_classes}\n            train_set:\n              - type: null\n                data:\n                  - name: seeds\n                    format: numpy\n                    in_memory: true\n                    path: {train_path}\n            validation_set:\n              - type: null\n                data:\n                  - name: seeds\n                    format: numpy\n                    in_memory: true\n                    path: {validation_path}\n            test_set:\n              - type: null\n                data:\n                  - name: seeds\n                    format: numpy\n                    in_memory: true\n                    path: {test_path}\n    \"\"\"\n    return yaml_content\n\n\ndef generate_raw_data_for_hetero_dataset(\n    test_dir, dataset_name, num_nodes, num_edges, num_classes, edge_fmt=\"csv\"\n):\n    # Generate edges.\n    edges_path = {}\n    for etype, num_edge in num_edges.items():\n        src_ntype, etype_str, dst_ntype = etype\n        src = torch.randint(0, num_nodes[src_ntype], (num_edge,))\n        dst = torch.randint(0, num_nodes[dst_ntype], (num_edge,))\n        os.makedirs(os.path.join(test_dir, \"edges\"), exist_ok=True)\n        assert edge_fmt in [\n            \"numpy\",\n            \"csv\",\n        ], 
\"Only numpy and csv are supported for edges.\"\n        if edge_fmt == \"csv\":\n            # Write into edges/edge.csv\n            edges = pd.DataFrame(\n                np.stack([src, dst], axis=1), columns=[\"src\", \"dst\"]\n            )\n            edge_path = os.path.join(\"edges\", f\"{etype_str}.csv\")\n            edges.to_csv(\n                os.path.join(test_dir, edge_path),\n                index=False,\n                header=False,\n            )\n        else:\n            edges = np.stack([src, dst], axis=1).T\n            edge_path = os.path.join(\"edges\", f\"{etype_str}.npy\")\n            np.save(os.path.join(test_dir, edge_path), edges)\n        edges_path[etype_str] = edge_path\n\n    # Generate node features.\n    node_feats_path = {}\n    os.makedirs(os.path.join(test_dir, \"data\"), exist_ok=True)\n    for ntype, num_node in num_nodes.items():\n        node_feat_path = os.path.join(\"data\", f\"{ntype}-feat.npy\")\n        node_feats = np.random.rand(num_node, num_classes)\n        np.save(os.path.join(test_dir, node_feat_path), node_feats)\n        node_feats_path[ntype] = node_feat_path\n\n    # Generate edge features.\n    edge_feats_path = {}\n    os.makedirs(os.path.join(test_dir, \"data\"), exist_ok=True)\n    for etype, num_edge in num_edges.items():\n        src_ntype, etype_str, dst_ntype = etype\n        edge_feat_path = os.path.join(\"data\", f\"{etype_str}-feat.npy\")\n        edge_feats = np.random.rand(num_edge, num_classes)\n        np.save(os.path.join(test_dir, edge_feat_path), edge_feats)\n        edge_feats_path[etype_str] = edge_feat_path\n\n    # Generate train/test/valid set.\n    os.makedirs(os.path.join(test_dir, \"set\"), exist_ok=True)\n    user_ids = torch.arange(num_nodes[\"user\"])\n    np.random.shuffle(user_ids.numpy())\n    num_train = int(num_nodes[\"user\"] * 0.6)\n    num_validation = int(num_nodes[\"user\"] * 0.2)\n    num_test = num_nodes[\"user\"] - num_train - num_validation\n    train_path = 
os.path.join(\"set\", \"train.npy\")\n    np.save(os.path.join(test_dir, train_path), user_ids[:num_train])\n    validation_path = os.path.join(\"set\", \"validation.npy\")\n    np.save(\n        os.path.join(test_dir, validation_path),\n        user_ids[num_train : num_train + num_validation],\n    )\n    test_path = os.path.join(\"set\", \"test.npy\")\n    np.save(\n        os.path.join(test_dir, test_path),\n        user_ids[num_train + num_validation :],\n    )\n\n    yaml_content = f\"\"\"\n        dataset_name: {dataset_name}\n        graph: # Graph structure and required attributes.\n          nodes:\n            - type: user\n              num: {num_nodes[\"user\"]}\n            - type: item\n              num: {num_nodes[\"item\"]}\n          edges:\n            - type: \"user:follow:user\"\n              format: {edge_fmt}\n              path: {edges_path[\"follow\"]}\n            - type: \"user:click:item\"\n              format: {edge_fmt}\n              path: {edges_path[\"click\"]}\n          feature_data:\n            - domain: node\n              type: user\n              name: feat\n              format: numpy\n              in_memory: true\n              path: {node_feats_path[\"user\"]}\n            - domain: node\n              type: item\n              name: feat\n              format: numpy\n              in_memory: true\n              path: {node_feats_path[\"item\"]}\n            - domain: edge\n              type: \"user:follow:user\"\n              name: feat\n              format: numpy\n              in_memory: true\n              path: {edge_feats_path[\"follow\"]}\n            - domain: edge\n              type: \"user:click:item\"\n              name: feat\n              format: numpy\n              in_memory: true\n              path: {edge_feats_path[\"click\"]}\n        feature_data:\n          - domain: node\n            type: user\n            name: feat\n            format: numpy\n            in_memory: true\n            path: 
{node_feats_path[\"user\"]}\n          - domain: node\n            type: item\n            name: feat\n            format: numpy\n            in_memory: true\n            path: {node_feats_path[\"item\"]}\n        tasks:\n          - name: node_classification\n            num_classes: {num_classes}\n            train_set:\n              - type: user\n                data:\n                  - name: seeds\n                    format: numpy\n                    in_memory: true\n                    path: {train_path}\n            validation_set:\n              - type: user\n                data:\n                  - name: seeds\n                    format: numpy\n                    in_memory: true\n                    path: {validation_path}\n            test_set:\n              - type: user\n                data:\n                  - name: seeds\n                    format: numpy\n                    in_memory: true\n                    path: {test_path}\n    \"\"\"\n\n    yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n    with open(yaml_file, \"w\") as f:\n        f.write(yaml_content)\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/__init__.py",
    "content": "\"\"\" DGL graphbolt/impl tests\"\"\"\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_basic_feature_store.py",
    "content": "import pytest\nimport torch\n\nfrom dgl import graphbolt as gb\n\n\ndef test_basic_feature_store_homo():\n    a = torch.tensor([[1, 2, 4], [2, 5, 3]])\n    b = torch.tensor([[[1, 2], [3, 4]], [[2, 5], [4, 3]]])\n    metadata = {\"max_value\": 3}\n\n    features = {}\n    features[(\"node\", None, \"a\")] = gb.TorchBasedFeature(a, metadata=metadata)\n    features[(\"node\", None, \"b\")] = gb.TorchBasedFeature(b)\n\n    feature_store = gb.BasicFeatureStore(features)\n\n    # Test __getitem__ to access the stored Feature.\n    feature = feature_store[(\"node\", None, \"a\")]\n    assert isinstance(feature, gb.Feature)\n    assert torch.equal(\n        feature.read(),\n        torch.tensor([[1, 2, 4], [2, 5, 3]]),\n    )\n\n    # Test read the entire feature.\n    assert torch.equal(\n        feature_store.read(\"node\", None, \"a\"),\n        torch.tensor([[1, 2, 4], [2, 5, 3]]),\n    )\n    assert torch.equal(\n        feature_store.read(\"node\", None, \"b\"),\n        torch.tensor([[[1, 2], [3, 4]], [[2, 5], [4, 3]]]),\n    )\n\n    # Test read with ids.\n    assert torch.equal(\n        feature_store.read(\"node\", None, \"a\", torch.tensor([0])),\n        torch.tensor([[1, 2, 4]]),\n    )\n    assert torch.equal(\n        feature_store.read(\"node\", None, \"b\", torch.tensor([0])),\n        torch.tensor([[[1, 2], [3, 4]]]),\n    )\n\n    # Test get the size and count of the entire feature.\n    assert feature_store.size(\"node\", None, \"a\") == torch.Size([3])\n    assert feature_store.size(\"node\", None, \"b\") == torch.Size([2, 2])\n    assert feature_store.count(\"node\", None, \"a\") == a.size(0)\n    assert feature_store.count(\"node\", None, \"b\") == b.size(0)\n\n    # Test get metadata of the feature.\n    assert feature_store.metadata(\"node\", None, \"a\") == metadata\n    assert feature_store.metadata(\"node\", None, \"b\") == {}\n\n    # Test __setitem__ and __contains__ of FeatureStore.\n    assert (\"node\", None, \"c\") not in 
feature_store\n    feature_store[(\"node\", None, \"c\")] = feature_store[(\"node\", None, \"a\")]\n    assert (\"node\", None, \"c\") in feature_store\n\n    # Test get keys of the features.\n    assert feature_store.keys() == [\n        (\"node\", None, \"a\"),\n        (\"node\", None, \"b\"),\n        (\"node\", None, \"c\"),\n    ]\n\n\ndef test_basic_feature_store_hetero():\n    a = torch.tensor([[1, 2, 4], [2, 5, 3]])\n    b = torch.tensor([[[6], [8]], [[8], [9]]])\n    metadata = {\"max_value\": 3}\n\n    features = {}\n    features[(\"node\", \"author\", \"a\")] = gb.TorchBasedFeature(\n        a, metadata=metadata\n    )\n    features[(\"edge\", \"paper:cites\", \"b\")] = gb.TorchBasedFeature(b)\n\n    feature_store = gb.BasicFeatureStore(features)\n\n    # Test __getitem__ to access the stored Feature.\n    feature = feature_store[(\"node\", \"author\", \"a\")]\n    assert isinstance(feature, gb.Feature)\n    assert torch.equal(\n        feature.read(),\n        torch.tensor([[1, 2, 4], [2, 5, 3]]),\n    )\n\n    # Test read the entire feature.\n    assert torch.equal(\n        feature_store.read(\"node\", \"author\", \"a\"),\n        torch.tensor([[1, 2, 4], [2, 5, 3]]),\n    )\n    assert torch.equal(\n        feature_store.read(\"edge\", \"paper:cites\", \"b\"),\n        torch.tensor([[[6], [8]], [[8], [9]]]),\n    )\n\n    # Test read with ids.\n    assert torch.equal(\n        feature_store.read(\"node\", \"author\", \"a\", torch.tensor([0])),\n        torch.tensor([[1, 2, 4]]),\n    )\n\n    # Test get the size of the entire feature.\n    assert feature_store.size(\"node\", \"author\", \"a\") == torch.Size([3])\n    assert feature_store.size(\"edge\", \"paper:cites\", \"b\") == torch.Size([2, 1])\n\n    # Test get metadata of the feature.\n    assert feature_store.metadata(\"node\", \"author\", \"a\") == metadata\n    assert feature_store.metadata(\"edge\", \"paper:cites\", \"b\") == {}\n\n    # Test __setitem__ and __contains__ of FeatureStore.\n  
  assert (\"node\", \"author\", \"c\") not in feature_store\n    feature_store[(\"node\", \"author\", \"c\")] = feature_store[\n        (\"node\", \"author\", \"a\")\n    ]\n    assert (\"node\", \"author\", \"c\") in feature_store\n\n    # Test get keys of the features.\n    assert feature_store.keys() == [\n        (\"node\", \"author\", \"a\"),\n        (\"edge\", \"paper:cites\", \"b\"),\n        (\"node\", \"author\", \"c\"),\n    ]\n\n\ndef test_basic_feature_store_errors():\n    a = torch.tensor([3, 2, 1])\n    b = torch.tensor([[1, 2, 4], [2, 5, 3]])\n\n    features = {}\n    # Test error when dimension of the value is illegal.\n    with pytest.raises(\n        AssertionError,\n        match=rf\"dimension of torch_feature in TorchBasedFeature must be \"\n        rf\"greater than 1, but got {a.dim()} dimension.\",\n    ):\n        features[(\"node\", \"paper\", \"a\")] = gb.TorchBasedFeature(a)\n    features[(\"node\", \"author\", \"b\")] = gb.TorchBasedFeature(b)\n\n    feature_store = gb.BasicFeatureStore(features)\n\n    # Test error when key does not exist.\n    with pytest.raises(KeyError):\n        feature_store.read(\"node\", \"paper\", \"b\")\n\n    # Test error when at least one id is out of bound.\n    with pytest.raises(IndexError):\n        feature_store.read(\"node\", \"author\", \"b\", torch.tensor([0, 3]))\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py",
    "content": "import unittest\n\nfrom functools import partial\n\nimport backend as F\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\n\nWORLD_SIZE = 7\n\nassert_equal = partial(torch.testing.assert_close, rtol=0, atol=0)\n\n\n@unittest.skipIf(\n    F._default_context_str != \"gpu\",\n    reason=\"This test requires an NVIDIA GPU.\",\n)\n@pytest.mark.parametrize(\"dtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"rank\", list(range(WORLD_SIZE)))\ndef test_rank_sort_and_unique_and_compact(dtype, rank):\n    torch.manual_seed(7)\n    nodes_list1 = [\n        torch.randint(0, 2111111111, [777], dtype=dtype, device=F.ctx())\n        for _ in range(10)\n    ]\n    nodes_list2 = [nodes.sort()[0] for nodes in nodes_list1]\n\n    res1 = torch.ops.graphbolt.rank_sort(nodes_list1, rank, WORLD_SIZE)\n    res2 = torch.ops.graphbolt.rank_sort(nodes_list2, rank, WORLD_SIZE)\n\n    for i, ((nodes1, idx1, offsets1), (nodes2, idx2, offsets2)) in enumerate(\n        zip(res1, res2)\n    ):\n        assert_equal(nodes_list1[i], nodes1[idx1])\n        assert_equal(nodes_list2[i], nodes2[idx2])\n        assert_equal(offsets1, offsets2)\n        assert offsets1.is_pinned() and offsets2.is_pinned()\n\n    res3 = torch.ops.graphbolt.rank_sort(nodes_list1, rank, WORLD_SIZE)\n\n    # This function is deterministic. 
Call with identical arguments and check.\n    for (nodes1, idx1, offsets1), (nodes3, idx3, offsets3) in zip(res1, res3):\n        assert_equal(nodes1, nodes3)\n        assert_equal(idx1, idx3)\n        assert_equal(offsets1, offsets3)\n\n    # The dependency on the rank argument is simply a permutation.\n    res4 = torch.ops.graphbolt.rank_sort(nodes_list1, 0, WORLD_SIZE)\n    for (nodes1, idx1, offsets1), (nodes4, idx4, offsets4) in zip(res1, res4):\n        off1 = offsets1.tolist()\n        off4 = offsets4.tolist()\n        assert_equal(nodes1[idx1], nodes4[idx4])\n        for i in range(WORLD_SIZE):\n            j = (i - rank + WORLD_SIZE) % WORLD_SIZE\n            assert_equal(\n                nodes1[off1[j] : off1[j + 1]], nodes4[off4[i] : off4[i + 1]]\n            )\n\n    unique, compacted, offsets = gb.unique_and_compact(\n        nodes_list1[:1], rank, WORLD_SIZE\n    )\n\n    nodes1, idx1, offsets1 = res1[0]\n\n    assert_equal(unique, nodes1)\n    assert_equal(compacted[0], idx1)\n    assert_equal(offsets, offsets1)\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_cpu_cached_feature.py",
    "content": "import os\nimport tempfile\nimport unittest\n\nimport backend as F\nimport numpy as np\nimport pytest\nimport torch\n\nfrom dgl import graphbolt as gb\n\n\ndef to_on_disk_numpy(test_dir, name, t):\n    path = os.path.join(test_dir, name + \".npy\")\n    np.save(path, t.numpy())\n    return path\n\n\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.bool,\n        torch.uint8,\n        torch.int8,\n        torch.int16,\n        torch.int32,\n        torch.int64,\n        torch.float16,\n        torch.bfloat16,\n        torch.float32,\n        torch.float64,\n    ],\n)\n@pytest.mark.parametrize(\"policy\", [\"s3-fifo\", \"sieve\", \"lru\", \"clock\"])\ndef test_cpu_cached_feature(dtype, policy):\n    cache_size_a = 32\n    cache_size_b = 64\n    a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype)\n    b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype)\n\n    pin_memory = F._default_context_str == \"gpu\"\n\n    cache_size_a *= a[:1].nbytes\n    cache_size_b *= b[:1].nbytes\n\n    feat_store_a = gb.cpu_cached_feature(\n        gb.TorchBasedFeature(a), cache_size_a, policy, pin_memory\n    )\n    feat_store_b = gb.cpu_cached_feature(\n        gb.TorchBasedFeature(b), cache_size_b, policy, pin_memory\n    )\n\n    # Test read the entire feature.\n    assert torch.equal(feat_store_a.read(), a)\n    assert torch.equal(feat_store_b.read(), b)\n\n    # Test read with ids.\n    assert torch.equal(\n        # Test read when ids are on a different device.\n        feat_store_a.read(torch.tensor([0], device=F.ctx())),\n        torch.tensor([[1, 2, 3]], dtype=dtype, device=F.ctx()),\n    )\n    assert torch.equal(\n        feat_store_b.read(torch.tensor([1, 1])),\n        torch.tensor([[[4, 5], [6, 7]], [[4, 5], [6, 7]]], dtype=dtype),\n    )\n    assert torch.equal(\n        feat_store_a.read(torch.tensor([1, 1])),\n        torch.tensor([[4, 5, 6], [4, 5, 6]], dtype=dtype),\n    )\n    assert torch.equal(\n        
feat_store_b.read(torch.tensor([0])),\n        torch.tensor([[[1, 2], [3, 4]]], dtype=dtype),\n    )\n    # The cache should be full now for the large cache sizes, 100% hit expected.\n    total_miss = feat_store_a._feature.total_miss\n    feat_store_a.read(torch.tensor([0, 1]))\n    assert total_miss == feat_store_a._feature.total_miss\n    total_miss = feat_store_b._feature.total_miss\n    feat_store_b.read(torch.tensor([0, 1]))\n    assert total_miss == feat_store_b._feature.total_miss\n    assert feat_store_a._feature.miss_rate == feat_store_a.miss_rate\n\n    # Test get the size and count of the entire feature.\n    assert feat_store_a.size() == torch.Size([3])\n    assert feat_store_b.size() == torch.Size([2, 2])\n    assert feat_store_a.count() == a.size(0)\n    assert feat_store_b.count() == b.size(0)\n\n    # Test update the entire feature.\n    feat_store_a.update(torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype))\n    assert torch.equal(\n        feat_store_a.read(),\n        torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype),\n    )\n\n    # Test update with ids.\n    feat_store_a.update(\n        torch.tensor([[2, 0, 1]], dtype=dtype),\n        torch.tensor([0]),\n    )\n    assert torch.equal(\n        feat_store_a.read(),\n        torch.tensor([[2, 0, 1], [3, 5, 2]], dtype=dtype),\n    )\n\n    # Test with different dimensionality\n    feat_store_a.update(b)\n    assert torch.equal(feat_store_a.read(), b)\n\n\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.bool,\n        torch.uint8,\n        torch.int8,\n        torch.int16,\n        torch.int32,\n        torch.int64,\n        torch.float16,\n        torch.bfloat16,\n        torch.float32,\n        torch.float64,\n    ],\n)\ndef test_cpu_cached_feature_read_async(dtype):\n    a = torch.randint(0, 2, [1000, 13], dtype=dtype)\n\n    cache_size = 256 * a[:1].nbytes\n\n    feat_store = gb.cpu_cached_feature(gb.TorchBasedFeature(a), cache_size)\n\n    # Test read with ids.\n    ids1 = 
torch.tensor([0, 15, 71, 101])\n    ids2 = torch.tensor([71, 101, 202, 303])\n    for ids in [ids1, ids2]:\n        reader = feat_store.read_async(ids)\n        for _ in range(feat_store.read_async_num_stages(ids.device)):\n            values = next(reader)\n        assert torch.equal(values.wait(), a[ids])\n\n\n@unittest.skipIf(\n    not torch.ops.graphbolt.detect_io_uring(),\n    reason=\"DiskBasedFeature is not available on this system.\",\n)\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.bool,\n        torch.uint8,\n        torch.int8,\n        torch.int16,\n        torch.int32,\n        torch.int64,\n        torch.float16,\n        torch.float32,\n        torch.float64,\n    ],\n)\ndef test_cpu_cached_disk_feature_read_async(dtype):\n    a = torch.randint(0, 2, [1000, 13], dtype=dtype)\n\n    cache_size = 256 * a[:1].nbytes\n\n    ids1 = torch.tensor([0, 15, 71, 101])\n    ids2 = torch.tensor([71, 101, 202, 303])\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        path = to_on_disk_numpy(test_dir, \"tensor\", a)\n\n        feat_store = gb.cpu_cached_feature(\n            gb.DiskBasedFeature(path=path), cache_size\n        )\n\n        # Test read feature.\n        for ids in [ids1, ids2]:\n            reader = feat_store.read_async(ids)\n            for _ in range(feat_store.read_async_num_stages(ids.device)):\n                values = next(reader)\n            assert torch.equal(values.wait(), a[ids])\n\n        feat_store = None\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_disk_based_feature_store.py",
    "content": "import os\nimport tempfile\nimport unittest\nfrom functools import partial\n\nimport backend as F\n\nimport numpy as np\nimport pytest\nimport torch\n\nfrom dgl import graphbolt as gb\n\n\ndef to_on_disk_numpy(test_dir, name, t):\n    path = os.path.join(test_dir, name + \".npy\")\n    t = t.numpy()\n    np.save(path, t)\n    return path\n\n\nassert_equal = partial(torch.testing.assert_close, rtol=0, atol=0)\n\n\n@unittest.skipIf(\n    not torch.ops.graphbolt.detect_io_uring(),\n    reason=\"DiskBasedFeature is not available on this system.\",\n)\ndef test_disk_based_feature():\n    with tempfile.TemporaryDirectory() as test_dir:\n        a = torch.tensor([[1, 2, 3], [4, 5, 6]])\n        b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]])\n        c = torch.randn([4111, 47])\n        metadata = {\"max_value\": 3}\n        path_a = to_on_disk_numpy(test_dir, \"a\", a)\n        path_b = to_on_disk_numpy(test_dir, \"b\", b)\n        path_c = to_on_disk_numpy(test_dir, \"c\", c)\n\n        feature_a = gb.DiskBasedFeature(path=path_a, metadata=metadata)\n        feature_b = gb.DiskBasedFeature(path=path_b)\n        feature_c = gb.DiskBasedFeature(path=path_c)\n\n        # Read the entire feature.\n        assert_equal(feature_a.read(), torch.tensor([[1, 2, 3], [4, 5, 6]]))\n\n        assert_equal(\n            feature_b.read(), torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]])\n        )\n\n        # Test read the feature with ids.\n        assert_equal(\n            feature_a.read(torch.tensor([0])),\n            torch.tensor([[1, 2, 3]]),\n        )\n        assert_equal(\n            feature_b.read(torch.tensor([1])),\n            torch.tensor([[[4, 5], [6, 7]]]),\n        )\n\n        # Test reading into pin_memory\n        if F._default_context_str == \"gpu\":\n            res = feature_a.read(torch.tensor([0], pin_memory=True))\n            assert res.is_pinned()\n\n        # Test when the index tensor is large.\n        torch_based_feature_a = 
gb.TorchBasedFeature(a)\n        ind_a = torch.randint(low=0, high=a.size(0), size=(4111,))\n        assert_equal(\n            feature_a.read(ind_a),\n            torch_based_feature_a.read(ind_a),\n        )\n\n        # Test converting to torch_based_feature with read_into_memory()\n        torch_based_feature_b = feature_b.read_into_memory()\n        ind_b = torch.randint(low=0, high=b.size(0), size=(4111,))\n        assert_equal(\n            feature_b.read(ind_b),\n            torch_based_feature_b.read(ind_b),\n        )\n\n        # Test with larger stored feature tensor\n        ind_c = torch.randint(low=0, high=c.size(0), size=(4111,))\n        assert_equal(feature_c.read(ind_c), c[ind_c])\n\n        # Test get the size and count of the entire feature.\n        assert feature_a.size() == torch.Size([3])\n        assert feature_b.size() == torch.Size([2, 2])\n        assert feature_a.count() == a.size(0)\n        assert feature_b.count() == b.size(0)\n\n        # Test get metadata of the feature.\n        assert feature_a.metadata() == metadata\n        assert feature_b.metadata() == {}\n\n        with pytest.raises(IndexError):\n            feature_a.read(torch.tensor([0, 1, 2, 3]))\n\n        # Test loading a Fortran contiguous ndarray.\n        a_T = np.asfortranarray(a)\n        path_a_T = test_dir + \"a_T.npy\"\n        np.save(path_a_T, a_T)\n        with pytest.raises(\n            AssertionError,\n            match=\"DiskBasedFeature only supports C_CONTIGUOUS array.\",\n        ):\n            gb.DiskBasedFeature(path=path_a_T, metadata=metadata)\n\n        # For windows, the file is locked by the numpy.load. 
We need to delete\n        # it before closing the temporary directory.\n        a = b = c = None\n        feature_a = feature_b = feature_c = None\n\n\n@unittest.skipIf(\n    not torch.ops.graphbolt.detect_io_uring(),\n    reason=\"DiskBasedFeature is not available on this system.\",\n)\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.float32,\n        torch.float64,\n        torch.int32,\n        torch.int64,\n        torch.int8,\n        torch.float16,\n        torch.complex128,\n    ],\n)\n@pytest.mark.parametrize(\"idtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\n    \"shape\", [(10, 20), (20, 10), (20, 25, 10), (137, 50, 30)]\n)\n@pytest.mark.parametrize(\"index\", [[0], [1, 2, 3], [0, 6, 2, 8]])\ndef test_more_disk_based_feature(dtype, idtype, shape, index):\n    if dtype == torch.complex128:\n        tensor = torch.complex(\n            torch.randint(0, 127, shape, dtype=torch.float64),\n            torch.randint(0, 127, shape, dtype=torch.float64),\n        )\n    else:\n        tensor = torch.randint(0, 127, shape, dtype=dtype)\n    test_tensor = tensor.clone()\n    idx = torch.tensor(index, dtype=idtype)\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        path = to_on_disk_numpy(test_dir, \"tensor\", tensor)\n\n        feature = gb.DiskBasedFeature(path=path)\n\n        # Test read feature.\n        assert_equal(feature.read(idx), test_tensor[idx.long()])\n\n\n@unittest.skipIf(\n    not torch.ops.graphbolt.detect_io_uring(),\n    reason=\"DiskBasedFeature is not available on this system.\",\n)\ndef test_disk_based_feature_repr():\n    with tempfile.TemporaryDirectory() as test_dir:\n        a = torch.tensor([[1, 2, 3], [4, 5, 6]])\n        b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]])\n        metadata = {\"max_value\": 3}\n\n        path_a = to_on_disk_numpy(test_dir, \"a\", a)\n        path_b = to_on_disk_numpy(test_dir, \"b\", b)\n\n        feature_a = gb.DiskBasedFeature(path=path_a, 
metadata=metadata)\n        feature_b = gb.DiskBasedFeature(path=path_b)\n\n        expected_str_feature_a = str(\n            \"DiskBasedFeature(\\n\"\n            \"    feature=tensor([[1, 2, 3],\\n\"\n            \"                    [4, 5, 6]]),\\n\"\n            \"    metadata={'max_value': 3},\\n\"\n            \")\"\n        )\n        expected_str_feature_b = str(\n            \"DiskBasedFeature(\\n\"\n            \"    feature=tensor([[[1, 2],\\n\"\n            \"                     [3, 4]],\\n\"\n            \"\\n\"\n            \"                    [[4, 5],\\n\"\n            \"                     [6, 7]]]),\\n\"\n            \"    metadata={},\\n\"\n            \")\"\n        )\n        assert str(feature_a) == expected_str_feature_a\n        assert str(feature_b) == expected_str_feature_b\n        a = b = metadata = None\n        feature_a = feature_b = None\n        expected_str_feature_a = expected_str_feature_b = None\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_feature_cache.py",
    "content": "import backend as F\n\nimport pytest\nimport torch\n\nfrom dgl import graphbolt as gb\n\n\ndef _test_query_and_replace(policy1, policy2, keys, offset):\n    # Testing query_and_replace equivalence to query and then replace.\n    (\n        _,\n        index,\n        pointers,\n        missing_keys,\n        found_offsets,\n        missing_offsets,\n    ) = policy1.query_and_replace(keys, offset)\n    found_cnt = keys.size(0) - missing_keys.size(0)\n    found_pointers = pointers[:found_cnt]\n    policy1.reading_completed(found_pointers, found_offsets)\n    missing_pointers = pointers[found_cnt:]\n    policy1.writing_completed(missing_pointers, missing_offsets)\n\n    (\n        _,\n        index2,\n        missing_keys2,\n        found_pointers2,\n        found_offsets2,\n        missing_offsets2,\n    ) = policy2.query(keys + offset, 0)\n    policy2.reading_completed(found_pointers2, found_offsets2)\n    (_, missing_pointers2, missing_offsets2) = policy2.replace(\n        missing_keys2, missing_offsets2, 0\n    )\n    policy2.writing_completed(missing_pointers2, missing_offsets2)\n\n    assert torch.equal(index, index2)\n    assert torch.equal(missing_keys, missing_keys2 - offset)\n\n\n@pytest.mark.parametrize(\"offsets\", [False, True])\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.bool,\n        torch.uint8,\n        torch.int8,\n        torch.int16,\n        torch.int32,\n        torch.int64,\n        torch.float16,\n        torch.bfloat16,\n        torch.float32,\n        torch.float64,\n    ],\n)\n@pytest.mark.parametrize(\"feature_size\", [2, 16])\n@pytest.mark.parametrize(\"num_parts\", [1, 2, None])\n@pytest.mark.parametrize(\"policy\", [\"s3-fifo\", \"sieve\", \"lru\", \"clock\"])\n@pytest.mark.parametrize(\"offset\", [0, 1111111])\ndef test_feature_cache(offsets, dtype, feature_size, num_parts, policy, offset):\n    cache_size = 32 * (\n        torch.get_num_threads() if num_parts is None else num_parts\n    )\n    a = 
torch.randint(0, 2, [1024, feature_size], dtype=dtype)\n    cache = gb.impl.CPUFeatureCache(\n        (cache_size,) + a.shape[1:], a.dtype, policy, num_parts\n    )\n    cache2 = gb.impl.CPUFeatureCache(\n        (cache_size,) + a.shape[1:], a.dtype, policy, num_parts\n    )\n    policy1 = gb.impl.CPUFeatureCache(\n        (cache_size,) + a.shape[1:], a.dtype, policy, num_parts\n    )._policy\n    policy2 = gb.impl.CPUFeatureCache(\n        (cache_size,) + a.shape[1:], a.dtype, policy, num_parts\n    )._policy\n    reader_fn = lambda keys: a[keys]\n\n    keys = torch.tensor([0, 1])\n    values, missing_index, missing_keys, missing_offsets = cache.query(\n        keys, offset\n    )\n    if not offsets:\n        missing_offsets = None\n    assert torch.equal(\n        missing_keys.flip([0]) if num_parts == 1 else missing_keys.sort()[0],\n        keys,\n    )\n\n    missing_values = a[missing_keys]\n    cache.replace(missing_keys, missing_values, missing_offsets, offset)\n    values[missing_index] = missing_values\n    assert torch.equal(values, a[keys])\n    assert torch.equal(\n        cache2.query_and_replace(keys, reader_fn, offset), a[keys]\n    )\n\n    _test_query_and_replace(policy1, policy2, keys, offset)\n\n    pin_memory = F._default_context_str == \"gpu\"\n\n    keys = torch.arange(1, 33, pin_memory=pin_memory)\n    values, missing_index, missing_keys, missing_offsets = cache.query(\n        keys, offset\n    )\n    if not offsets:\n        missing_offsets = None\n    assert torch.equal(\n        missing_keys.flip([0]) if num_parts == 1 else missing_keys.sort()[0],\n        torch.arange(2, 33),\n    )\n    assert not pin_memory or values.is_pinned()\n\n    missing_values = a[missing_keys]\n    cache.replace(missing_keys, missing_values, missing_offsets, offset)\n    values[missing_index] = missing_values\n    assert torch.equal(values, a[keys])\n    assert torch.equal(\n        cache2.query_and_replace(keys, reader_fn, offset), a[keys]\n    )\n\n    
_test_query_and_replace(policy1, policy2, keys, offset)\n\n    values, missing_index, missing_keys, missing_offsets = cache.query(\n        keys, offset\n    )\n    if not offsets:\n        missing_offsets = None\n    assert torch.equal(missing_keys.flip([0]), torch.tensor([]))\n\n    missing_values = a[missing_keys]\n    cache.replace(missing_keys, missing_values, missing_offsets, offset)\n    values[missing_index] = missing_values\n    assert torch.equal(values, a[keys])\n    assert torch.equal(\n        cache2.query_and_replace(keys, reader_fn, offset), a[keys]\n    )\n\n    _test_query_and_replace(policy1, policy2, keys, offset)\n\n    values, missing_index, missing_keys, missing_offsets = cache.query(\n        keys, offset\n    )\n    if not offsets:\n        missing_offsets = None\n    assert torch.equal(missing_keys.flip([0]), torch.tensor([]))\n\n    missing_values = a[missing_keys]\n    cache.replace(missing_keys, missing_values, missing_offsets, offset)\n    values[missing_index] = missing_values\n    assert torch.equal(values, a[keys])\n    assert torch.equal(\n        cache2.query_and_replace(keys, reader_fn, offset), a[keys]\n    )\n\n    _test_query_and_replace(policy1, policy2, keys, offset)\n\n    assert cache.miss_rate == cache2.miss_rate\n\n    raw_feature_cache = torch.ops.graphbolt.feature_cache(\n        (cache_size,) + a.shape[1:], a.dtype, pin_memory\n    )\n    idx = torch.tensor([0, 1, 2])\n    raw_feature_cache.replace(idx, a[idx])\n    val = raw_feature_cache.index_select(idx)\n    assert torch.equal(val, a[idx])\n    if pin_memory:\n        val = raw_feature_cache.index_select(idx.to(F.ctx()))\n        assert torch.equal(val, a[idx].to(F.ctx()))\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py",
    "content": "import os\n\nimport pickle\nimport re\nimport tempfile\nimport unittest\n\nimport backend as F\n\nimport dgl\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\nimport torch.multiprocessing as mp\n\nfrom dgl.graphbolt.base import etype_str_to_tuple\nfrom scipy import sparse as spsp\n\nfrom .. import gb_test_utils as gbt\n\ntorch.manual_seed(3407)\nmp.set_sharing_strategy(\"file_system\")\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\"total_num_nodes\", [0, 1, 10, 100, 1000])\ndef test_empty_graph(total_num_nodes):\n    csc_indptr = torch.zeros((total_num_nodes + 1,), dtype=int)\n    indices = torch.tensor([])\n    graph = gb.fused_csc_sampling_graph(csc_indptr, indices)\n    assert graph.total_num_edges == 0\n    assert graph.total_num_nodes == total_num_nodes\n    assert torch.equal(graph.csc_indptr, csc_indptr)\n    assert torch.equal(graph.indices, indices)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\"total_num_nodes\", [0, 1, 10, 100, 1000])\ndef test_hetero_empty_graph(total_num_nodes):\n    csc_indptr = torch.zeros((total_num_nodes + 1,), dtype=int)\n    indices = torch.tensor([])\n    node_type_to_id, edge_type_to_id = gbt.get_type_to_id(\n        num_ntypes=3, num_etypes=5\n    )\n    # Some node types have no nodes.\n    if total_num_nodes == 0:\n        node_type_offset = torch.zeros((4,), dtype=int)\n    else:\n        node_type_offset = torch.sort(torch.randint(0, total_num_nodes, (4,)))[\n            0\n        ]\n        node_type_offset[0] = 0\n        node_type_offset[-1] = total_num_nodes\n    type_per_edge = torch.tensor([])\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=node_type_to_id,\n        
edge_type_to_id=edge_type_to_id,\n        edge_attributes=None,\n    )\n    assert graph.total_num_edges == 0\n    assert graph.total_num_nodes == total_num_nodes\n    assert torch.equal(graph.csc_indptr, csc_indptr)\n    assert torch.equal(graph.indices, indices)\n    assert graph.node_type_to_id == node_type_to_id\n    assert graph.edge_type_to_id == edge_type_to_id\n    assert torch.equal(graph.node_type_offset, node_type_offset)\n    assert torch.equal(graph.type_per_edge, type_per_edge)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\n    \"ntypes\", [{\"n1\": 1, \"n2\": 1}, {5: 1, \"n2\": 2}, {\"n1\": 1.5, \"n2\": 2.0}]\n)\ndef test_type_to_id_with_ntype_exception(ntypes):\n    with pytest.raises(AssertionError):\n        gb.fused_csc_sampling_graph(\n            None, None, node_type_to_id=ntypes, edge_type_to_id={\"e1\": 1}\n        )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\n    \"etypes\",\n    [\n        {(\"n1\", 5, \"n12\"): 1},\n        {\"e1\": 1},\n        {(\"n1\", \"e1\"): 1},\n        {(\"n1\", \"e1\", 10): 1},\n        {\"n1:e1:n2\": 1, (\"n1\", \"e2\", \"n3\"): 1},\n        {(\"n1\", \"e1\", \"n10\"): 1},\n        {\"n1:e1:n2\": 1.5},\n    ],\n)\ndef test_type_to_id_with_etype_exception(etypes):\n    with pytest.raises(Exception):\n        gb.fused_csc_sampling_graph(\n            None,\n            None,\n            node_type_to_id={\"n1\": 0, \"n2\": 1, \"n3\": 2},\n            edge_type_to_id=etypes,\n        )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\n    \"total_num_nodes, total_num_edges\",\n    [(1, 1), (100, 1), (10, 50), (1000, 50000)],\n)\ndef test_homo_graph(total_num_nodes, total_num_edges):\n    csc_indptr, indices = gbt.random_homo_graph(\n      
  total_num_nodes, total_num_edges\n    )\n    node_attributes = {\n        \"A1\": torch.arange(total_num_nodes),\n        \"A2\": torch.arange(total_num_nodes),\n    }\n    edge_attributes = {\n        \"A1\": torch.randn(total_num_edges),\n        \"A2\": torch.randn(total_num_edges),\n    }\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n    )\n\n    assert graph.total_num_nodes == total_num_nodes\n    assert graph.total_num_edges == total_num_edges\n\n    assert torch.equal(csc_indptr, graph.csc_indptr)\n    assert torch.equal(indices, graph.indices)\n\n    assert graph.node_attributes == node_attributes\n    assert graph.edge_attributes == edge_attributes\n    assert graph.node_type_offset is None\n    assert graph.type_per_edge is None\n    assert graph.node_type_to_id is None\n    assert graph.edge_type_to_id is None\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\n    \"total_num_nodes, total_num_edges\",\n    [(1, 1), (100, 1), (10, 50), (1000, 50000)],\n)\n@pytest.mark.parametrize(\"num_ntypes, num_etypes\", [(1, 1), (3, 5), (100, 1)])\ndef test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):\n    (\n        csc_indptr,\n        indices,\n        node_type_offset,\n        type_per_edge,\n        node_type_to_id,\n        edge_type_to_id,\n    ) = gbt.random_hetero_graph(\n        total_num_nodes, total_num_edges, num_ntypes, num_etypes\n    )\n    node_attributes = {\n        \"A1\": torch.arange(total_num_nodes),\n        \"A2\": torch.arange(total_num_nodes),\n    }\n    edge_attributes = {\n        \"A1\": torch.randn(total_num_edges),\n        \"A2\": torch.randn(total_num_edges),\n    }\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        
type_per_edge=type_per_edge,\n        node_type_to_id=node_type_to_id,\n        edge_type_to_id=edge_type_to_id,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n    )\n\n    assert graph.total_num_nodes == total_num_nodes\n    assert graph.total_num_edges == total_num_edges\n\n    assert torch.equal(csc_indptr, graph.csc_indptr)\n    assert torch.equal(indices, graph.indices)\n    assert torch.equal(node_type_offset, graph.node_type_offset)\n    assert torch.equal(type_per_edge, graph.type_per_edge)\n    assert graph.node_attributes == node_attributes\n    assert graph.edge_attributes == edge_attributes\n    assert node_type_to_id == graph.node_type_to_id\n    assert edge_type_to_id == graph.edge_type_to_id\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\n    \"total_num_nodes, total_num_edges\",\n    [(1, 1), (100, 1), (10, 50), (1000, 50000)],\n)\ndef test_num_nodes_edges_homo(total_num_nodes, total_num_edges):\n    csc_indptr, indices = gbt.random_homo_graph(\n        total_num_nodes, total_num_edges\n    )\n    edge_attributes = {\n        \"A1\": torch.randn(total_num_edges),\n        \"A2\": torch.randn(total_num_edges),\n    }\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr, indices, edge_attributes=edge_attributes\n    )\n\n    assert graph.num_nodes == total_num_nodes\n    assert graph.num_edges == total_num_edges\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\ndef test_num_nodes_hetero():\n    \"\"\"Original graph in COO:\n    1   0   1   0   1\n    1   0   1   1   0\n    0   1   0   1   0\n    0   1   0   0   1\n    1   0   0   0   1\n\n    node_type_0: [0, 1]\n    node_type_1: [2, 3, 4]\n    edge_type_0: node_type_0 -> node_type_0\n    edge_type_1: node_type_0 -> node_type_1\n    edge_type_2: node_type_1 -> node_type_0\n    edge_type_3: node_type_1 
-> node_type_1\n    \"\"\"\n    # Initialize data.\n    total_num_nodes = 5\n    total_num_edges = 12\n    ntypes = {\n        \"N0\": 0,\n        \"N1\": 1,\n    }\n    etypes = {\n        \"N0:R0:N0\": 0,\n        \"N0:R1:N1\": 1,\n        \"N1:R2:N0\": 2,\n        \"N1:R3:N1\": 3,\n        \"N1:R4:N0\": 4,\n    }\n    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])\n    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])\n    node_type_offset = torch.LongTensor([0, 2, 5])\n    type_per_edge = torch.LongTensor([0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n    assert node_type_offset[-1] == total_num_nodes\n    assert all(type_per_edge < len(etypes))\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    )\n\n    # Verify nodes number per node types.\n    assert graph.num_nodes == {\n        \"N0\": 2,\n        \"N1\": 3,\n    }\n    assert sum(graph.num_nodes.values()) == total_num_nodes\n    # Verify edges number per edge types.\n    assert graph.num_edges == {\n        \"N0:R0:N0\": 2,\n        \"N0:R1:N1\": 4,\n        \"N1:R2:N0\": 3,\n        \"N1:R3:N1\": 3,\n        \"N1:R4:N0\": 0,\n    }\n    assert sum(graph.num_edges.values()) == total_num_edges\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\n    \"node_type_offset\",\n    [\n        torch.tensor([0, 1]),\n        torch.tensor([0, 1, 5, 6, 10]),\n        torch.tensor([0, 1, 10]),\n    ],\n)\ndef test_node_type_offset_wrong_legnth(node_type_offset):\n    num_ntypes = 3\n    (\n        csc_indptr,\n        indices,\n        _,\n        type_per_edge,\n        node_type_to_id,\n        edge_type_to_id,\n    ) = 
gbt.random_hetero_graph(10, 50, num_ntypes, 5)\n    with pytest.raises(Exception):\n        gb.fused_csc_sampling_graph(\n            csc_indptr,\n            indices,\n            node_type_offset=node_type_offset,\n            type_per_edge=type_per_edge,\n            node_type_to_id=node_type_to_id,\n            edge_type_to_id=edge_type_to_id,\n        )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\n    \"total_num_nodes, total_num_edges\",\n    [(1, 1), (100, 1), (10, 50), (1000, 50000)],\n)\n@pytest.mark.parametrize(\"has_node_attrs\", [True, False])\n@pytest.mark.parametrize(\"has_edge_attrs\", [True, False])\ndef test_load_save_homo_graph(\n    total_num_nodes, total_num_edges, has_node_attrs, has_edge_attrs\n):\n    csc_indptr, indices = gbt.random_homo_graph(\n        total_num_nodes, total_num_edges\n    )\n    node_attributes = None\n    if has_node_attrs:\n        node_attributes = {\n            \"A\": torch.arange(total_num_nodes),\n            \"B\": torch.arange(total_num_nodes),\n        }\n    edge_attributes = None\n    if has_edge_attrs:\n        edge_attributes = {\n            \"A\": torch.arange(total_num_edges),\n            \"B\": torch.arange(total_num_edges),\n        }\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n    )\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        filename = os.path.join(test_dir, \"fused_csc_sampling_graph.pt\")\n        torch.save(graph, filename)\n        graph2 = torch.load(filename, weights_only=False)\n\n    assert graph.total_num_nodes == graph2.total_num_nodes\n    assert graph.total_num_edges == graph2.total_num_edges\n\n    assert torch.equal(graph.csc_indptr, graph2.csc_indptr)\n    assert torch.equal(graph.indices, graph2.indices)\n\n    assert graph.node_type_offset is None and 
graph2.node_type_offset is None\n    assert graph.type_per_edge is None and graph2.type_per_edge is None\n    assert graph.node_type_to_id is None and graph2.node_type_to_id is None\n    assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None\n    if has_node_attrs:\n        assert graph.node_attributes.keys() == graph2.node_attributes.keys()\n        for key in graph.node_attributes.keys():\n            assert torch.equal(\n                graph.node_attributes[key], graph2.node_attributes[key]\n            )\n    else:\n        assert graph.node_attributes is None and graph2.node_attributes is None\n    if has_edge_attrs:\n        assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()\n        for key in graph.edge_attributes.keys():\n            assert torch.equal(\n                graph.edge_attributes[key], graph2.edge_attributes[key]\n            )\n    else:\n        assert graph.edge_attributes is None and graph2.edge_attributes is None\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\n    \"total_num_nodes, total_num_edges\",\n    [(1, 1), (100, 1), (10, 50), (1000, 50000)],\n)\n@pytest.mark.parametrize(\"num_ntypes, num_etypes\", [(1, 1), (3, 5), (100, 1)])\n@pytest.mark.parametrize(\"has_node_attrs\", [True, False])\n@pytest.mark.parametrize(\"has_edge_attrs\", [True, False])\ndef test_load_save_hetero_graph(\n    total_num_nodes,\n    total_num_edges,\n    num_ntypes,\n    num_etypes,\n    has_node_attrs,\n    has_edge_attrs,\n):\n    (\n        csc_indptr,\n        indices,\n        node_type_offset,\n        type_per_edge,\n        node_type_to_id,\n        edge_type_to_id,\n    ) = gbt.random_hetero_graph(\n        total_num_nodes, total_num_edges, num_ntypes, num_etypes\n    )\n    node_attributes = None\n    if has_node_attrs:\n        node_attributes = {\n            \"A\": torch.arange(total_num_nodes),\n            \"B\": 
torch.arange(total_num_nodes),\n        }\n    edge_attributes = None\n    if has_edge_attrs:\n        edge_attributes = {\n            \"A\": torch.arange(total_num_edges),\n            \"B\": torch.arange(total_num_edges),\n        }\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=node_type_to_id,\n        edge_type_to_id=edge_type_to_id,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n    )\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        filename = os.path.join(test_dir, \"fused_csc_sampling_graph.pt\")\n        torch.save(graph, filename)\n        graph2 = torch.load(filename, weights_only=False)\n\n    assert graph.total_num_nodes == graph2.total_num_nodes\n    assert graph.total_num_edges == graph2.total_num_edges\n\n    assert torch.equal(graph.csc_indptr, graph2.csc_indptr)\n    assert torch.equal(graph.indices, graph2.indices)\n    assert torch.equal(graph.node_type_offset, graph2.node_type_offset)\n    assert torch.equal(graph.type_per_edge, graph2.type_per_edge)\n    assert graph.node_type_to_id == graph2.node_type_to_id\n    assert graph.edge_type_to_id == graph2.edge_type_to_id\n    if has_node_attrs:\n        assert graph.node_attributes.keys() == graph2.node_attributes.keys()\n        for key in graph.node_attributes.keys():\n            assert torch.equal(\n                graph.node_attributes[key], graph2.node_attributes[key]\n            )\n    else:\n        assert graph.node_attributes is None and graph2.node_attributes is None\n    if has_edge_attrs:\n        assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()\n        for key in graph.edge_attributes.keys():\n            assert torch.equal(\n                graph.edge_attributes[key], graph2.edge_attributes[key]\n            )\n    else:\n        assert graph.edge_attributes is None and 
graph2.edge_attributes is None\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\n    \"total_num_nodes, total_num_edges\",\n    [(1, 1), (100, 1), (10, 50), (1000, 50000)],\n)\n@pytest.mark.parametrize(\"has_node_attrs\", [True, False])\n@pytest.mark.parametrize(\"has_edge_attrs\", [True, False])\ndef test_pickle_homo_graph(\n    total_num_nodes, total_num_edges, has_node_attrs, has_edge_attrs\n):\n    csc_indptr, indices = gbt.random_homo_graph(\n        total_num_nodes, total_num_edges\n    )\n    node_attributes = None\n    if has_node_attrs:\n        node_attributes = {\n            \"A\": torch.arange(total_num_nodes),\n            \"B\": torch.arange(total_num_nodes),\n        }\n    edge_attributes = None\n    if has_edge_attrs:\n        edge_attributes = {\n            \"A\": torch.arange(total_num_edges),\n            \"B\": torch.arange(total_num_edges),\n        }\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n    )\n\n    serialized = pickle.dumps(graph)\n    graph2 = pickle.loads(serialized)\n\n    assert graph.total_num_nodes == graph2.total_num_nodes\n    assert graph.total_num_edges == graph2.total_num_edges\n\n    assert torch.equal(graph.csc_indptr, graph2.csc_indptr)\n    assert torch.equal(graph.indices, graph2.indices)\n\n    assert graph.node_type_offset is None and graph2.node_type_offset is None\n    assert graph.type_per_edge is None and graph2.type_per_edge is None\n    assert graph.node_type_to_id is None and graph2.node_type_to_id is None\n    assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None\n    if has_node_attrs:\n        assert graph.node_attributes.keys() == graph2.node_attributes.keys()\n        for key in graph.node_attributes.keys():\n            assert torch.equal(\n                
graph.node_attributes[key], graph2.node_attributes[key]\n            )\n    else:\n        assert graph.node_attributes is None and graph2.node_attributes is None\n    if has_edge_attrs:\n        assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()\n        for key in graph.edge_attributes.keys():\n            assert torch.equal(\n                graph.edge_attributes[key], graph2.edge_attributes[key]\n            )\n    else:\n        assert graph.edge_attributes is None and graph2.edge_attributes is None\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\n@pytest.mark.parametrize(\n    \"total_num_nodes, total_num_edges\",\n    [(1, 1), (100, 1), (10, 50), (1000, 50000)],\n)\n@pytest.mark.parametrize(\"num_ntypes, num_etypes\", [(1, 1), (3, 5), (100, 1)])\n@pytest.mark.parametrize(\"has_node_attrs\", [True, False])\n@pytest.mark.parametrize(\"has_edge_attrs\", [True, False])\ndef test_pickle_hetero_graph(\n    total_num_nodes,\n    total_num_edges,\n    num_ntypes,\n    num_etypes,\n    has_node_attrs,\n    has_edge_attrs,\n):\n    (\n        csc_indptr,\n        indices,\n        node_type_offset,\n        type_per_edge,\n        node_type_to_id,\n        edge_type_to_id,\n    ) = gbt.random_hetero_graph(\n        total_num_nodes, total_num_edges, num_ntypes, num_etypes\n    )\n    node_attributes = None\n    if has_node_attrs:\n        node_attributes = {\n            \"A\": torch.arange(total_num_nodes),\n            \"B\": torch.arange(total_num_nodes),\n        }\n    edge_attributes = None\n    if has_edge_attrs:\n        edge_attributes = {\n            \"A\": torch.arange(total_num_edges),\n            \"B\": torch.arange(total_num_edges),\n        }\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=node_type_to_id,\n        
edge_type_to_id=edge_type_to_id,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n    )\n\n    serialized = pickle.dumps(graph)\n    graph2 = pickle.loads(serialized)\n\n    assert graph.total_num_nodes == graph2.total_num_nodes\n    assert graph.total_num_edges == graph2.total_num_edges\n\n    assert torch.equal(graph.csc_indptr, graph2.csc_indptr)\n    assert torch.equal(graph.indices, graph2.indices)\n    assert torch.equal(graph.node_type_offset, graph2.node_type_offset)\n    assert torch.equal(graph.type_per_edge, graph2.type_per_edge)\n    assert graph.node_type_to_id.keys() == graph2.node_type_to_id.keys()\n    for i in graph.node_type_to_id.keys():\n        assert graph.node_type_to_id[i] == graph2.node_type_to_id[i]\n    assert graph.edge_type_to_id.keys() == graph2.edge_type_to_id.keys()\n    for i in graph.edge_type_to_id.keys():\n        assert graph.edge_type_to_id[i] == graph2.edge_type_to_id[i]\n    if has_node_attrs:\n        assert graph.node_attributes.keys() == graph2.node_attributes.keys()\n        for key in graph.node_attributes.keys():\n            assert torch.equal(\n                graph.node_attributes[key], graph2.node_attributes[key]\n            )\n    else:\n        assert graph.node_attributes is None and graph2.node_attributes is None\n    if has_edge_attrs:\n        assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()\n        for key in graph.edge_attributes.keys():\n            assert torch.equal(\n                graph.edge_attributes[key], graph2.edge_attributes[key]\n            )\n    else:\n        assert graph.edge_attributes is None and graph2.edge_attributes is None\n\n\ndef process_csc_sampling_graph_multiprocessing(graph):\n    return graph.total_num_nodes\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\ndef test_multiprocessing():\n    total_num_nodes = 5\n    total_num_edges = 10\n    num_ntypes = 2\n   
 num_etypes = 3\n    (\n        csc_indptr,\n        indices,\n        node_type_offset,\n        type_per_edge,\n        node_type_to_id,\n        edge_type_to_id,\n    ) = gbt.random_hetero_graph(\n        total_num_nodes, total_num_edges, num_ntypes, num_etypes\n    )\n    edge_attributes = {\n        \"a\": torch.randn((total_num_edges,)),\n    }\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=node_type_to_id,\n        edge_type_to_id=edge_type_to_id,\n        edge_attributes=edge_attributes,\n    )\n\n    p = mp.Process(\n        target=process_csc_sampling_graph_multiprocessing, args=(graph,)\n    )\n    p.start()\n    p.join()\n\n\ndef test_in_subgraph_homo():\n    \"\"\"Original graph in COO:\n    1   0   1   0   1\n    1   0   1   1   0\n    0   1   0   1   0\n    0   1   0   0   1\n    1   0   0   0   1\n    \"\"\"\n    # Initialize data.\n    total_num_nodes = 5\n    total_num_edges = 12\n    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])\n    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx())\n\n    # Extract in subgraph.\n    nodes = torch.tensor([4, 1, 3], device=F.ctx())\n    in_subgraph = graph.in_subgraph(nodes)\n\n    # Verify in subgraph.\n    assert torch.equal(\n        in_subgraph.sampled_csc.indices,\n        torch.tensor([0, 3, 4, 2, 3, 1, 2], device=F.ctx()),\n    )\n    assert torch.equal(\n        in_subgraph.sampled_csc.indptr,\n        torch.tensor([0, 3, 5, 7], device=F.ctx()),\n    )\n    assert in_subgraph.original_column_node_ids is None\n    assert in_subgraph.original_row_node_ids is None\n    assert torch.equal(\n        in_subgraph.original_edge_ids,\n        torch.tensor([9, 10, 11, 3, 
4, 7, 8], device=F.ctx()),\n    )\n\n\ndef test_in_subgraph_hetero():\n    \"\"\"Original graph in COO:\n    1   0   1   0   1\n    1   0   1   1   0\n    0   1   0   1   0\n    0   1   0   0   1\n    1   0   0   0   1\n\n    node_type_0: [0, 1]\n    node_type_1: [2, 3, 4]\n    edge_type_0: node_type_0 -> node_type_0\n    edge_type_1: node_type_0 -> node_type_1\n    edge_type_2: node_type_1 -> node_type_0\n    edge_type_3: node_type_1 -> node_type_1\n    \"\"\"\n    # Initialize data.\n    total_num_nodes = 5\n    total_num_edges = 12\n    ntypes = {\n        \"N0\": 0,\n        \"N1\": 1,\n    }\n    etypes = {\n        \"N0:R0:N0\": 0,\n        \"N0:R1:N1\": 1,\n        \"N1:R2:N0\": 2,\n        \"N1:R3:N1\": 3,\n    }\n    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])\n    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])\n    node_type_offset = torch.LongTensor([0, 2, 5])\n    type_per_edge = torch.LongTensor([0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n    assert node_type_offset[-1] == total_num_nodes\n    assert all(type_per_edge < len(etypes))\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    ).to(F.ctx())\n\n    # Extract in subgraph.\n    nodes = {\n        \"N0\": torch.tensor([1], device=F.ctx()),\n        \"N1\": torch.tensor([2, 1], device=F.ctx()),\n    }\n    in_subgraph = graph.in_subgraph(nodes)\n\n    # Verify in subgraph.\n    assert torch.equal(\n        in_subgraph.sampled_csc[\"N0:R0:N0\"].indices,\n        torch.tensor([], device=F.ctx()),\n    )\n    assert torch.equal(\n        in_subgraph.sampled_csc[\"N0:R0:N0\"].indptr,\n        torch.tensor([0, 0], device=F.ctx()),\n    )\n    assert torch.equal(\n        
in_subgraph.sampled_csc[\"N0:R1:N1\"].indices,\n        torch.tensor([0, 1], device=F.ctx()),\n    )\n    assert torch.equal(\n        in_subgraph.sampled_csc[\"N0:R1:N1\"].indptr,\n        torch.tensor([0, 1, 2], device=F.ctx()),\n    )\n    assert torch.equal(\n        in_subgraph.sampled_csc[\"N1:R2:N0\"].indices,\n        torch.tensor([0, 1], device=F.ctx()),\n    )\n    assert torch.equal(\n        in_subgraph.sampled_csc[\"N1:R2:N0\"].indptr,\n        torch.tensor([0, 2], device=F.ctx()),\n    )\n    assert torch.equal(\n        in_subgraph.sampled_csc[\"N1:R3:N1\"].indices,\n        torch.tensor([1, 2, 0], device=F.ctx()),\n    )\n    assert torch.equal(\n        in_subgraph.sampled_csc[\"N1:R3:N1\"].indptr,\n        torch.tensor([0, 2, 3], device=F.ctx()),\n    )\n    assert in_subgraph.original_column_node_ids is None\n    assert in_subgraph.original_row_node_ids is None\n    assert torch.equal(\n        in_subgraph.original_edge_ids[\"N0:R0:N0\"],\n        torch.tensor([], device=F.ctx()),\n    )\n    assert torch.equal(\n        in_subgraph.original_edge_ids[\"N0:R1:N1\"],\n        torch.tensor([9, 7], device=F.ctx()),\n    )\n    assert torch.equal(\n        in_subgraph.original_edge_ids[\"N1:R2:N0\"],\n        torch.tensor([3, 4], device=F.ctx()),\n    )\n    assert torch.equal(\n        in_subgraph.original_edge_ids[\"N1:R3:N1\"],\n        torch.tensor([10, 11, 8], device=F.ctx()),\n    )\n\n\n@pytest.mark.parametrize(\"indptr_dtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"indices_dtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"replace\", [False, True])\n@pytest.mark.parametrize(\"labor\", [False, True])\n@pytest.mark.parametrize(\"use_node_timestamp\", [False, True])\n@pytest.mark.parametrize(\"use_edge_timestamp\", [False, True])\ndef test_temporal_sample_neighbors_homo(\n    indptr_dtype,\n    indices_dtype,\n    replace,\n    labor,\n    use_node_timestamp,\n    use_edge_timestamp,\n):\n    if replace and 
F._default_context_str == \"gpu\":\n        pytest.skip(\"Sampling with replacement not yet implemented on the GPU.\")\n    \"\"\"Original graph in COO:\n    1   0   1   0   1\n    1   0   1   1   0\n    0   1   0   1   0\n    0   1   0   0   1\n    1   0   0   0   1\n    \"\"\"\n    # Initialize data.\n    total_num_nodes = 5\n    total_num_edges = 12\n    indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)\n    indices = torch.tensor(\n        [0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype\n    )\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n    assert len(indptr) == total_num_nodes + 1\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx())\n\n    # Generate subgraph via sample neighbors.\n    fanouts = torch.LongTensor([2])\n    sampler = (\n        graph.temporal_sample_layer_neighbors\n        if labor\n        else graph.temporal_sample_neighbors\n    )\n\n    seed_list = [1, 3, 4]\n    seed_timestamp = torch.randint(\n        0, 100, (len(seed_list),), dtype=torch.int64, device=F.ctx()\n    )\n    if use_node_timestamp:\n        node_timestamp = torch.randint(\n            0, 100, (total_num_nodes,), dtype=torch.int64, device=F.ctx()\n        )\n        graph.node_attributes = {\"timestamp\": node_timestamp}\n    if use_edge_timestamp:\n        edge_timestamp = torch.randint(\n            0, 100, (total_num_edges,), dtype=torch.int64, device=F.ctx()\n        )\n        graph.edge_attributes = {\"timestamp\": edge_timestamp}\n\n    # Sample with nodes in mismatched dtype with graph's indices.\n    nodes = torch.tensor(\n        seed_list,\n        dtype=(torch.int64 if indices_dtype == torch.int32 else torch.int32),\n    )\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\n            \"Data type of nodes must be consistent with indices.dtype\"\n        ),\n    ):\n        _ = sampler(\n            nodes,\n            
seed_timestamp,\n            fanouts,\n            replace=replace,\n            node_timestamp_attr_name=(\n                \"timestamp\" if use_node_timestamp else None\n            ),\n            edge_timestamp_attr_name=(\n                \"timestamp\" if use_edge_timestamp else None\n            ),\n        )\n\n    def _get_available_neighbors():\n        available_neighbors = []\n        for i, seed in enumerate(seed_list):\n            neighbors = []\n            start = indptr[seed].item()\n            end = indptr[seed + 1].item()\n            for j in range(start, end):\n                neighbor = indices[j].item()\n                if (\n                    use_node_timestamp\n                    and (node_timestamp[neighbor] >= seed_timestamp[i]).item()\n                ):\n                    continue\n                if (\n                    use_edge_timestamp\n                    and (edge_timestamp[j] >= seed_timestamp[i]).item()\n                ):\n                    continue\n                neighbors.append(neighbor)\n            available_neighbors.append(neighbors)\n        return available_neighbors\n\n    nodes = torch.tensor(seed_list, dtype=indices_dtype, device=F.ctx())\n    subgraph = sampler(\n        nodes,\n        seed_timestamp,\n        fanouts,\n        replace=replace,\n        node_timestamp_attr_name=\"timestamp\" if use_node_timestamp else None,\n        edge_timestamp_attr_name=\"timestamp\" if use_edge_timestamp else None,\n    )\n    sampled_count = torch.diff(subgraph.sampled_csc.indptr).tolist()\n    available_neighbors = _get_available_neighbors()\n    assert len(available_neighbors) == len(sampled_count)\n    for i, count in enumerate(sampled_count):\n        if not replace:\n            expect_count = min(fanouts[0], len(available_neighbors[i]))\n        else:\n            expect_count = fanouts[0] if len(available_neighbors[i]) > 0 else 0\n        assert count == expect_count\n    sampled_neighbors = 
torch.split(subgraph.sampled_csc.indices, sampled_count)\n    for i, neighbors in enumerate(sampled_neighbors):\n        assert set(neighbors.tolist()).issubset(set(available_neighbors[i]))\n\n\n@pytest.mark.parametrize(\"indptr_dtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"indices_dtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"replace\", [False, True])\n@pytest.mark.parametrize(\"labor\", [False, True])\n@pytest.mark.parametrize(\"use_node_timestamp\", [False, True])\n@pytest.mark.parametrize(\"use_edge_timestamp\", [False, True])\ndef test_temporal_sample_neighbors_hetero(\n    indptr_dtype,\n    indices_dtype,\n    replace,\n    labor,\n    use_node_timestamp,\n    use_edge_timestamp,\n):\n    if replace and F._default_context_str == \"gpu\":\n        pytest.skip(\"Sampling with replacement not yet implemented on the GPU.\")\n    \"\"\"Original graph in COO:\n    \"n1:e1:n2\":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]\n    \"n2:e2:n1\":[0, 0, 1, 2], [0, 1, 1 ,0]\n    0   0   1   0   1\n    0   0   1   1   1\n    1   1   0   0   0\n    0   1   0   0   0\n    1   0   0   0   0\n    \"\"\"\n    # Initialize data.\n    ntypes = {\"n1\": 0, \"n2\": 1}\n    etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n    ntypes_to_offset = {\"n1\": 0, \"n2\": 2}\n    total_num_nodes = 5\n    total_num_edges = 9\n    indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)\n    indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)\n    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])\n    node_type_offset = torch.LongTensor([0, 2, 5])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    ).to(F.ctx())\n\n    # Generate subgraph 
via sample neighbors.\n    fanouts = torch.LongTensor([-1, -1])\n    sampler = (\n        graph.temporal_sample_layer_neighbors\n        if labor\n        else graph.temporal_sample_neighbors\n    )\n\n    seeds = {\n        \"n1\": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),\n        \"n2\": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),\n    }\n    per_etype_destination_nodes = {\n        \"n1:e1:n2\": torch.tensor([1], dtype=indices_dtype),\n        \"n2:e2:n1\": torch.tensor([0], dtype=indices_dtype),\n    }\n\n    seed_timestamp = {\n        \"n1\": torch.randint(0, 100, (1,), dtype=torch.int64, device=F.ctx()),\n        \"n2\": torch.randint(0, 100, (1,), dtype=torch.int64, device=F.ctx()),\n    }\n    if use_node_timestamp:\n        node_timestamp = torch.randint(\n            0, 100, (total_num_nodes,), dtype=torch.int64, device=F.ctx()\n        )\n        graph.node_attributes = {\"timestamp\": node_timestamp}\n    if use_edge_timestamp:\n        edge_timestamp = torch.randint(\n            0, 100, (total_num_edges,), dtype=torch.int64, device=F.ctx()\n        )\n        graph.edge_attributes = {\"timestamp\": edge_timestamp}\n\n    subgraph = sampler(\n        seeds,\n        seed_timestamp,\n        fanouts,\n        replace=replace,\n        node_timestamp_attr_name=\"timestamp\" if use_node_timestamp else None,\n        edge_timestamp_attr_name=\"timestamp\" if use_edge_timestamp else None,\n    )\n\n    def _to_homo():\n        ret_seeds, ret_timestamps = [], []\n        for ntype, nodes in seeds.items():\n            ntype_id = ntypes[ntype]\n            offset = node_type_offset[ntype_id]\n            ret_seeds.append(nodes + offset)\n            ret_timestamps.append(seed_timestamp[ntype])\n        return torch.cat(ret_seeds), torch.cat(ret_timestamps)\n\n    homo_seeds, homo_seed_timestamp = _to_homo()\n\n    def _get_available_neighbors():\n        available_neighbors = []\n        for i, seed in enumerate(homo_seeds):\n      
      neighbors = []\n            start = indptr[seed].item()\n            end = indptr[seed + 1].item()\n            for j in range(start, end):\n                neighbor = indices[j].item()\n                if (\n                    use_node_timestamp\n                    and (\n                        node_timestamp[neighbor] >= homo_seed_timestamp[i]\n                    ).item()\n                ):\n                    continue\n                if (\n                    use_edge_timestamp\n                    and (edge_timestamp[j] >= homo_seed_timestamp[i]).item()\n                ):\n                    continue\n                neighbors.append(neighbor)\n            available_neighbors.append(neighbors)\n        return available_neighbors\n\n    available_neighbors = _get_available_neighbors()\n    sampled_count = [0] * homo_seeds.numel()\n    sampled_neighbors = [[] for _ in range(homo_seeds.numel())]\n    for etype, csc in subgraph.sampled_csc.items():\n        stype, _, _ = etype_str_to_tuple(etype)\n        ntype_offset = ntypes_to_offset[stype]\n        dest_nodes = per_etype_destination_nodes[etype]\n        for i in range(dest_nodes.numel()):\n            l = csc.indptr[i]\n            r = csc.indptr[i + 1]\n            seed_offset = dest_nodes[i].item()\n            sampled_neighbors[seed_offset].extend(\n                (csc.indices[l:r] + ntype_offset).tolist()\n            )\n            sampled_count[seed_offset] += r - l\n\n    for i, count in enumerate(sampled_count):\n        assert count == len(available_neighbors[i])\n        assert set(sampled_neighbors[i]).issubset(set(available_neighbors[i]))\n\n\ndef check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):\n    \"\"\"Check if two tensors are on the same shared memory.\n\n    This function copies a random tensor value to `t1` and checks whether `t2`\n    holds the same random value and checks whether t2 is a distinct tensor\n    object from `t1`. 
Their equality confirms that they are separate tensors\n    that rely on the shared memory for their tensor value.\n    \"\"\"\n    assert t1.data_ptr() != t2.data_ptr()\n    old_t1 = t1.clone()\n    v = torch.randint_like(t1, 100)\n    t1[:] = v\n    assert torch.equal(t1, t2)\n    t1[:] = old_t1\n\n\ndef check_node_edge_attributes(graph1, graph2, attributes, attr_name):\n    for name, attr in attributes.items():\n        edge_attributes_1 = getattr(graph1, attr_name)\n        edge_attributes_2 = getattr(graph2, attr_name)\n        assert name in edge_attributes_1\n        assert name in edge_attributes_2\n        assert torch.equal(edge_attributes_1[name], attr)\n        check_tensors_on_the_same_shared_memory(\n            edge_attributes_1[name], edge_attributes_2[name]\n        )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"FusedCSCSamplingGraph is only supported on CPU.\",\n)\n@pytest.mark.parametrize(\n    \"total_num_nodes, total_num_edges\",\n    [(1, 1), (100, 1), (10, 50), (1000, 50000)],\n)\n@pytest.mark.parametrize(\"test_node_attrs\", [True, False])\n@pytest.mark.parametrize(\"test_edge_attrs\", [True, False])\ndef test_homo_graph_on_shared_memory(\n    total_num_nodes, total_num_edges, test_node_attrs, test_edge_attrs\n):\n    csc_indptr, indices = gbt.random_homo_graph(\n        total_num_nodes, total_num_edges\n    )\n    node_attributes = None\n    if test_node_attrs:\n        node_attributes = {\n            \"A1\": torch.arange(total_num_nodes),\n            \"A2\": torch.arange(total_num_nodes),\n        }\n    edge_attributes = None\n    if test_edge_attrs:\n        edge_attributes = {\n            \"A1\": torch.randn(total_num_edges),\n            \"A2\": torch.randn(total_num_edges),\n        }\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n    )\n\n    shm_name = \"test_homo_g\"\n    graph1 = 
graph.copy_to_shared_memory(shm_name)\n    graph2 = gb.load_from_shared_memory(shm_name)\n\n    assert graph1.total_num_nodes == total_num_nodes\n    assert graph1.total_num_nodes == total_num_nodes\n    assert graph2.total_num_edges == total_num_edges\n    assert graph2.total_num_edges == total_num_edges\n\n    # Test the value of graph1 is correct\n    assert torch.equal(graph1.csc_indptr, csc_indptr)\n    assert torch.equal(graph1.indices, indices)\n\n    # Test the value of graph2 is correct\n    assert torch.equal(graph2.csc_indptr, csc_indptr)\n    assert torch.equal(graph2.indices, indices)\n\n    # Test the memory of graph1 and graph2 is on shared memory\n    check_tensors_on_the_same_shared_memory(\n        graph1.csc_indptr, graph2.csc_indptr\n    )\n    check_tensors_on_the_same_shared_memory(graph1.indices, graph2.indices)\n\n    if test_node_attrs:\n        check_node_edge_attributes(\n            graph1, graph2, node_attributes, \"node_attributes\"\n        )\n    if test_edge_attrs:\n        check_node_edge_attributes(\n            graph1, graph2, edge_attributes, \"edge_attributes\"\n        )\n\n    assert graph1.node_type_offset is None and graph2.node_type_offset is None\n    assert graph1.type_per_edge is None and graph2.type_per_edge is None\n    assert graph1.node_type_to_id is None and graph2.node_type_to_id is None\n    assert graph1.edge_type_to_id is None and graph2.edge_type_to_id is None\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"FusedCSCSamplingGraph is only supported on CPU.\",\n)\n@pytest.mark.parametrize(\n    \"total_num_nodes, total_num_edges\",\n    [(1, 1), (100, 1), (10, 50), (1000, 50 * 1000), (10 * 1000, 100 * 1000)],\n)\n@pytest.mark.parametrize(\n    \"num_ntypes, num_etypes\", [(1, 1), (3, 5), (100, 1), (1000, 1000)]\n)\n@pytest.mark.parametrize(\"test_node_attrs\", [True, False])\n@pytest.mark.parametrize(\"test_edge_attrs\", [True, False])\ndef test_hetero_graph_on_shared_memory(\n    
total_num_nodes,\n    total_num_edges,\n    num_ntypes,\n    num_etypes,\n    test_node_attrs,\n    test_edge_attrs,\n):\n    (\n        csc_indptr,\n        indices,\n        node_type_offset,\n        type_per_edge,\n        node_type_to_id,\n        edge_type_to_id,\n    ) = gbt.random_hetero_graph(\n        total_num_nodes, total_num_edges, num_ntypes, num_etypes\n    )\n\n    node_attributes = None\n    if test_node_attrs:\n        node_attributes = {\n            \"A1\": torch.arange(total_num_nodes),\n            \"A2\": torch.arange(total_num_nodes),\n        }\n\n    edge_attributes = None\n    if test_edge_attrs:\n        edge_attributes = {\n            \"A1\": torch.randn(total_num_edges),\n            \"A2\": torch.randn(total_num_edges),\n        }\n\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=node_type_to_id,\n        edge_type_to_id=edge_type_to_id,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n    )\n\n    shm_name = \"test_hetero_g\"\n    graph1 = graph.copy_to_shared_memory(shm_name)\n    graph2 = gb.load_from_shared_memory(shm_name)\n\n    assert graph1.total_num_nodes == total_num_nodes\n    assert graph1.total_num_nodes == total_num_nodes\n    assert graph2.total_num_edges == total_num_edges\n    assert graph2.total_num_edges == total_num_edges\n\n    # Test the value of graph1 is correct\n    assert torch.equal(graph1.csc_indptr, csc_indptr)\n    assert torch.equal(graph1.indices, indices)\n    assert torch.equal(graph1.node_type_offset, node_type_offset)\n    assert torch.equal(graph1.type_per_edge, type_per_edge)\n\n    # Test the value of graph2 is correct\n    assert torch.equal(graph2.csc_indptr, csc_indptr)\n    assert torch.equal(graph2.indices, indices)\n    assert torch.equal(graph2.node_type_offset, node_type_offset)\n    assert 
torch.equal(graph2.type_per_edge, type_per_edge)\n\n    # Test the memory of graph1 and graph2 is on shared memory\n    check_tensors_on_the_same_shared_memory(\n        graph1.csc_indptr, graph2.csc_indptr\n    )\n    check_tensors_on_the_same_shared_memory(graph1.indices, graph2.indices)\n    check_tensors_on_the_same_shared_memory(\n        graph1.node_type_offset, graph2.node_type_offset\n    )\n    check_tensors_on_the_same_shared_memory(\n        graph1.type_per_edge, graph2.type_per_edge\n    )\n\n    if test_node_attrs:\n        check_node_edge_attributes(\n            graph1, graph2, node_attributes, \"node_attributes\"\n        )\n    if test_edge_attrs:\n        check_node_edge_attributes(\n            graph1, graph2, edge_attributes, \"edge_attributes\"\n        )\n\n    assert node_type_to_id == graph1.node_type_to_id\n    assert edge_type_to_id == graph1.edge_type_to_id\n    assert node_type_to_id == graph2.node_type_to_id\n    assert edge_type_to_id == graph2.edge_type_to_id\n\n\ndef process_csc_sampling_graph_on_shared_memory(graph, data_queue, flag_queue):\n    # Backup the attributes.\n    csc_indptr = graph.csc_indptr.clone()\n    indices = graph.indices.clone()\n    node_type_offset = graph.node_type_offset.clone()\n    type_per_edge = graph.type_per_edge.clone()\n\n    # Change the value to random integers. 
Send the new value to the main\n    # process.\n    v = torch.randint_like(graph.csc_indptr, 100)\n    graph.csc_indptr[:] = v\n    data_queue.put(v.clone())\n\n    v = torch.randint_like(graph.indices, 100)\n    graph.indices[:] = v\n    data_queue.put(v.clone())\n\n    v = torch.randint_like(graph.node_type_offset, 100)\n    graph.node_type_offset[:] = v\n    data_queue.put(v.clone())\n\n    v = torch.randint_like(graph.type_per_edge, 100)\n    graph.type_per_edge[:] = v\n    data_queue.put(v.clone())\n\n    # Wait for the main process to finish.\n    flag_queue.get()\n\n    graph.csc_indptr[:] = csc_indptr\n    graph.indices[:] = indices\n    graph.node_type_offset[:] = node_type_offset\n    graph.type_per_edge[:] = type_per_edge\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\ndef test_multiprocessing_with_shared_memory():\n    \"\"\"Test if two CSCSamplingGraphs are on the same shared memory after\n    spawning.\n\n    For now this code only works when the sharing strategy of\n    torch.multiprocessing is set to `file_system` at the beginning.\n    The cause is still yet to be found.\n    \"\"\"\n\n    total_num_nodes = 5\n    total_num_edges = 10\n    num_ntypes = 2\n    num_etypes = 3\n    (\n        csc_indptr,\n        indices,\n        node_type_offset,\n        type_per_edge,\n        node_type_to_id,\n        edge_type_to_id,\n    ) = gbt.random_hetero_graph(\n        total_num_nodes, total_num_edges, num_ntypes, num_etypes\n    )\n\n    csc_indptr.share_memory_()\n    indices.share_memory_()\n    node_type_offset.share_memory_()\n    type_per_edge.share_memory_()\n\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=node_type_to_id,\n        edge_type_to_id=edge_type_to_id,\n        edge_attributes=None,\n    )\n\n    ctx = mp.get_context(\"spawn\")  # 
Use spawn method.\n\n    data_queue = ctx.Queue()  # Used for sending graph.\n    flag_queue = ctx.Queue()  # Used for sending finish signal.\n\n    p = ctx.Process(\n        target=process_csc_sampling_graph_on_shared_memory,\n        args=(graph, data_queue, flag_queue),\n    )\n    p.start()\n    try:\n        # Get data from the other process. Then check if the tensors here have\n        # the same data.\n        csc_indptr2 = data_queue.get()\n        assert torch.equal(graph.csc_indptr, csc_indptr2)\n        indices2 = data_queue.get()\n        assert torch.equal(graph.indices, indices2)\n        node_type_offset2 = data_queue.get()\n        assert torch.equal(graph.node_type_offset, node_type_offset2)\n        type_per_edge2 = data_queue.get()\n        assert torch.equal(graph.type_per_edge, type_per_edge2)\n    except:\n        raise\n    finally:\n        # Send a finish signal to end sub-process.\n        flag_queue.put(None)\n    p.join()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph on GPU is not supported yet.\",\n)\ndef test_from_dglgraph_homogeneous():\n    dgl_g = dgl.rand_graph(1000, 10 * 1000)\n\n    # Check if the original edge id exist in edge attributes when the\n    # original_edge_id is set to False.\n    gb_g = gb.from_dglgraph(\n        dgl_g, is_homogeneous=False, include_original_edge_id=False\n    )\n    assert (\n        gb_g.edge_attributes is None\n        or gb.ORIGINAL_EDGE_ID not in gb_g.edge_attributes\n    )\n\n    gb_g = gb.from_dglgraph(\n        dgl_g, is_homogeneous=True, include_original_edge_id=True\n    )\n    # Get the COO representation of the FusedCSCSamplingGraph.\n    num_columns = gb_g.csc_indptr.diff()\n    rows = gb_g.indices\n    columns = torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)\n\n    original_edge_ids = gb_g.edge_attributes[gb.ORIGINAL_EDGE_ID]\n    assert torch.all(dgl_g.edges()[0][original_edge_ids] == rows)\n    assert 
torch.all(dgl_g.edges()[1][original_edge_ids] == columns)\n\n    assert gb_g.total_num_nodes == dgl_g.num_nodes()\n    assert gb_g.total_num_edges == dgl_g.num_edges()\n    assert gb_g.node_type_offset is None\n    assert gb_g.type_per_edge is None\n    assert gb_g.node_type_to_id is None\n    assert gb_g.edge_type_to_id is None\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph on GPU is not supported yet.\",\n)\ndef test_from_dglgraph_heterogeneous():\n    dgl_g = dgl.heterograph(\n        {\n            (\"author\", \"writes\", \"paper\"): (\n                [1, 2, 3, 4, 5, 2],\n                [1, 2, 3, 4, 5, 4],\n            ),\n            (\"author\", \"affiliated_with\", \"institution\"): (\n                [1, 2, 3, 4, 5],\n                [1, 2, 3, 4, 5],\n            ),\n            (\"paper\", \"has_topic\", \"field\"): ([1, 2, 3, 4, 5], [1, 2, 3, 4, 5]),\n            (\"paper\", \"cites\", \"paper\"): (\n                [2, 3, 4, 5, 6, 1],\n                [1, 2, 3, 4, 5, 4],\n            ),\n        }\n    )\n    # Check if the original edge id exist in edge attributes when the\n    # original_edge_id is set to False.\n    gb_g = gb.from_dglgraph(\n        dgl_g, is_homogeneous=False, include_original_edge_id=False\n    )\n    assert (\n        gb_g.edge_attributes is None\n        or gb.ORIGINAL_EDGE_ID not in gb_g.edge_attributes\n    )\n\n    gb_g = gb.from_dglgraph(\n        dgl_g, is_homogeneous=False, include_original_edge_id=True\n    )\n\n    # `reverse_node_id` is used to map the node id in FusedCSCSamplingGraph to the\n    # node id in Hetero-DGLGraph.\n    num_ntypes = gb_g.node_type_offset.diff()\n    reverse_node_id = torch.cat([torch.arange(num) for num in num_ntypes])\n\n    # Get the COO representation of the FusedCSCSamplingGraph.\n    num_columns = gb_g.csc_indptr.diff()\n    rows = reverse_node_id[gb_g.indices]\n    columns = reverse_node_id[\n        
torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)\n    ]\n\n    # Check the order of etypes in DGLGraph is the same as FusedCSCSamplingGraph.\n    assert (\n        # Since the etypes in FusedCSCSamplingGraph is \"srctype:etype:dsttype\",\n        # we need to split the string and get the middle part.\n        list(\n            map(\n                lambda ss: ss.split(\":\")[1],\n                gb_g.edge_type_to_id.keys(),\n            )\n        )\n        == dgl_g.etypes\n    )\n\n    # Use ORIGINAL_EDGE_ID to check if the edge mapping is correct.\n    for edge_idx in range(gb_g.total_num_edges):\n        hetero_graph_idx = gb_g.type_per_edge[edge_idx]\n        original_edge_id = gb_g.edge_attributes[gb.ORIGINAL_EDGE_ID][edge_idx]\n        edge_type = dgl_g.etypes[hetero_graph_idx]\n        dgl_edge_pairs = dgl_g.edges(etype=edge_type)\n        assert dgl_edge_pairs[0][original_edge_id] == rows[edge_idx]\n        assert dgl_edge_pairs[1][original_edge_id] == columns[edge_idx]\n\n    assert gb_g.total_num_nodes == dgl_g.num_nodes()\n    assert gb_g.total_num_edges == dgl_g.num_edges()\n    assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 6, 12, 18, 25]))\n    assert torch.equal(\n        gb_g.type_per_edge,\n        torch.tensor(\n            [3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2]\n        ),\n    )\n    assert gb_g.node_type_to_id == {\n        \"author\": 0,\n        \"field\": 1,\n        \"institution\": 2,\n        \"paper\": 3,\n    }\n    assert gb_g.edge_type_to_id == {\n        \"author:affiliated_with:institution\": 0,\n        \"author:writes:paper\": 1,\n        \"paper:cites:paper\": 2,\n        \"paper:has_topic:field\": 3,\n    }\n\n\ndef create_fused_csc_sampling_graph():\n    # Initialize data.\n    total_num_nodes = 10\n    total_num_edges = 9\n    ntypes = {\"N0\": 0, \"N1\": 1, \"N2\": 2, \"N3\": 3}\n    etypes = {\n        \"N0:R0:N1\": 0,\n        \"N0:R1:N2\": 1,\n        \"N0:R2:N3\": 
2,\n    }\n    indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])\n    indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])\n    node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])\n    type_per_edge = torch.LongTensor([0, 0, 0, 1, 1, 1, 2, 2, 2])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n    assert node_type_offset[-1] == total_num_nodes\n    assert all(type_per_edge < len(etypes))\n\n    edge_attributes = {\n        \"mask\": torch.BoolTensor([1, 1, 0, 1, 1, 1, 0, 0, 0]),\n        \"all\": torch.BoolTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),\n        \"zero\": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),\n    }\n\n    # Construct FusedCSCSamplingGraph.\n    return gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        edge_attributes=edge_attributes,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    )\n\n\ndef is_graph_on_device_type(graph, device_type):\n    assert graph.csc_indptr.device.type == device_type\n    assert graph.indices.device.type == device_type\n    assert graph.node_type_offset.device.type == device_type\n    assert graph.type_per_edge.device.type == device_type\n    assert graph.csc_indptr.device.type == device_type\n    for key in graph.edge_attributes:\n        assert graph.edge_attributes[key].device.type == device_type\n\n\ndef is_graph_pinned(graph):\n    assert graph.csc_indptr.is_pinned()\n    assert graph.indices.is_pinned()\n    assert graph.node_type_offset.is_pinned()\n    assert graph.type_per_edge.is_pinned()\n    assert graph.csc_indptr.is_pinned()\n    for key in graph.edge_attributes:\n        assert graph.edge_attributes[key].is_pinned()\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"`to` function needs GPU to test.\",\n)\n@pytest.mark.parametrize(\"device\", [\"pinned\", \"cuda\"])\ndef 
test_csc_sampling_graph_to_device(device):\n    # Construct FusedCSCSamplingGraph.\n    graph = create_fused_csc_sampling_graph()\n\n    # Copy to device.\n    graph2 = graph.to(device)\n\n    if device == \"cuda\":\n        is_graph_on_device_type(graph2, \"cuda\")\n    elif device == \"pinned\":\n        is_graph_on_device_type(graph2, \"cpu\")\n        is_graph_pinned(graph2)\n\n    # The original variable should be untouched.\n    is_graph_on_device_type(graph, \"cpu\")\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"Tests for pinned memory are only meaningful on GPU.\",\n)\n@unittest.skipIf(\n    gb.is_wsl(), reason=\"In place pinning is not supported on WSL.\"\n)\ndef test_csc_sampling_graph_to_pinned_memory():\n    # Construct FusedCSCSamplingGraph.\n    graph = create_fused_csc_sampling_graph()\n    ptr = graph.csc_indptr.data_ptr()\n\n    # Copy to pinned_memory in-place.\n    graph.pin_memory_()\n\n    # Check if pinning is truly in-place.\n    assert graph.csc_indptr.data_ptr() == ptr\n\n    is_graph_on_device_type(graph, \"cpu\")\n    is_graph_pinned(graph)\n\n\n@pytest.mark.parametrize(\"indptr_dtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"indices_dtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"labor\", [False, True])\n@pytest.mark.parametrize(\"is_pinned\", [False, True])\n@pytest.mark.parametrize(\"nodes\", [None, True])\ndef test_sample_neighbors_homo(\n    indptr_dtype, indices_dtype, labor, is_pinned, nodes\n):\n    if is_pinned and nodes is None:\n        pytest.skip(\"Optional nodes and is_pinned is not supported together.\")\n    \"\"\"Original graph in COO:\n    1   0   1   0   1\n    1   0   1   1   0\n    0   1   0   1   0\n    0   1   0   0   1\n    1   0   0   0   1\n    \"\"\"\n    if F._default_context_str == \"cpu\" and is_pinned:\n        pytest.skip(\"Pinning is not meaningful without a GPU.\")\n    # Initialize data.\n    total_num_edges = 12\n    indptr = 
torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)\n    indices = torch.tensor(\n        [0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype\n    )\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(indptr, indices).to(\n        \"pinned\" if is_pinned else F.ctx()\n    )\n\n    # Generate subgraph via sample neighbors.\n    if nodes:\n        nodes = torch.tensor([1, 3, 4], dtype=indices_dtype).to(F.ctx())\n    elif F._default_context_str != \"gpu\":\n        pytest.skip(\"Optional nodes is supported only for the GPU.\")\n    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors\n    subgraph = sampler(nodes, fanouts=torch.LongTensor([2]))\n\n    # Verify in subgraph.\n    sampled_indptr_num = subgraph.sampled_csc.indptr.size(0)\n    sampled_num = subgraph.sampled_csc.indices.size(0)\n    assert sampled_num == len(subgraph.original_edge_ids)\n    if nodes is None:\n        assert sampled_indptr_num == indptr.shape[0]\n        assert sampled_num == 10\n    else:\n        assert sampled_indptr_num == 4\n        assert sampled_num == 6\n    assert subgraph.original_column_node_ids is None\n    assert subgraph.original_row_node_ids is None\n\n\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_sample_neighbors_hetero_single_fanout(labor):\n    u, i = torch.randint(20, size=(1000,)), torch.randint(10, size=(1000,))\n    graph = dgl.heterograph({(\"u\", \"w\", \"i\"): (u, i), (\"i\", \"b\", \"u\"): (i, u)})\n\n    graph = gb.from_dglgraph(graph).to(F.ctx())\n\n    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors\n\n    for i in range(11):\n        nodes = {\"u\": torch.randint(10, (100,), device=F.ctx())}\n        sampler(nodes, fanouts=torch.tensor([-1]))\n    # Should reach here without crashing.\n\n\n@pytest.mark.parametrize(\"indptr_dtype\", [torch.int32, 
torch.int64])\n@pytest.mark.parametrize(\"indices_dtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_sample_neighbors_hetero(indptr_dtype, indices_dtype, labor):\n    \"\"\"Original graph in COO:\n    \"n1:e1:n2\":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]\n    \"n2:e2:n1\":[0, 0, 1, 2], [0, 1, 1 ,0]\n    0   0   1   0   1\n    0   0   1   1   1\n    1   1   0   0   0\n    0   1   0   0   0\n    1   0   0   0   0\n    \"\"\"\n    # Initialize data.\n    ntypes = {\"n1\": 0, \"n2\": 1}\n    etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n    total_num_edges = 9\n    indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)\n    indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)\n    type_per_edge = torch.tensor(\n        [1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=indices_dtype\n    )\n    node_type_offset = torch.tensor([0, 2, 5], dtype=indices_dtype)\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    ).to(F.ctx())\n\n    # Sample on both node types.\n    nodes = {\n        \"n1\": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),\n        \"n2\": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),\n    }\n    fanouts = torch.tensor([-1, -1])\n    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors\n    subgraph = sampler(nodes, fanouts)\n\n    # Verify in subgraph.\n    expected_sampled_csc = {\n        \"n1:e1:n2\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 2], device=F.ctx()),\n            indices=torch.tensor([0, 1], device=F.ctx()),\n        ),\n        \"n2:e2:n1\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 2], device=F.ctx()),\n            
indices=torch.tensor([0, 2], device=F.ctx()),\n        ),\n    }\n    assert len(subgraph.sampled_csc) == 2\n    for etype, pairs in expected_sampled_csc.items():\n        assert torch.equal(subgraph.sampled_csc[etype].indptr, pairs.indptr)\n        assert torch.equal(\n            subgraph.sampled_csc[etype].indices.sort()[0], pairs.indices\n        )\n        assert len(pairs.indices) == len(subgraph.original_edge_ids[etype])\n    assert subgraph.original_column_node_ids is None\n    assert subgraph.original_row_node_ids is None\n\n    # Sample on single node type.\n    nodes = {\"n1\": torch.tensor([0], dtype=indices_dtype, device=F.ctx())}\n    fanouts = torch.tensor([-1, -1])\n    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors\n    subgraph = sampler(nodes, fanouts)\n\n    # Verify in subgraph.\n    expected_sampled_csc = {\n        \"n1:e1:n2\": gb.CSCFormatBase(\n            indptr=torch.tensor([0], device=F.ctx()),\n            indices=torch.tensor([], device=F.ctx()),\n        ),\n        \"n2:e2:n1\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 2], device=F.ctx()),\n            indices=torch.tensor([0, 2], device=F.ctx()),\n        ),\n    }\n    assert len(subgraph.sampled_csc) == 2\n    for etype, pairs in expected_sampled_csc.items():\n        assert torch.equal(subgraph.sampled_csc[etype].indptr, pairs.indptr)\n        assert torch.equal(\n            subgraph.sampled_csc[etype].indices.sort()[0], pairs.indices\n        )\n        assert len(pairs.indices) == len(subgraph.original_edge_ids[etype])\n    assert subgraph.original_column_node_ids is None\n    assert subgraph.original_row_node_ids is None\n\n\n@pytest.mark.parametrize(\n    \"fanouts, expected_sampled_num1, expected_sampled_num2\",\n    [\n        ([0], 0, 0),\n        ([1], 1, 1),\n        ([2], 2, 2),\n        ([4], 2, 2),\n        ([-1], 2, 2),\n        ([0, 0], 0, 0),\n        ([1, 0], 1, 0),\n        ([0, 1], 0, 1),\n        ([1, 1], 1, 1),\n    
    ([2, 1], 2, 1),\n        ([-1, -1], 2, 2),\n    ],\n)\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_sample_neighbors_fanouts(\n    fanouts, expected_sampled_num1, expected_sampled_num2, labor\n):\n    \"\"\"Original graph in COO:\n    \"n1:e1:n2\":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]\n    \"n2:e2:n1\":[0, 0, 1, 2], [0, 1, 1 ,0]\n    0   0   1   0   1\n    0   0   1   1   1\n    1   1   0   0   0\n    0   1   0   0   0\n    1   0   0   0   0\n    \"\"\"\n    # Initialize data.\n    ntypes = {\"n1\": 0, \"n2\": 1}\n    etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n    total_num_edges = 9\n    indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])\n    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])\n    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])\n    node_type_offset = torch.LongTensor([0, 2, 5])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    ).to(F.ctx())\n\n    nodes = {\n        \"n1\": torch.tensor([0], device=F.ctx()),\n        \"n2\": torch.tensor([0], device=F.ctx()),\n    }\n    fanouts = torch.LongTensor(fanouts)\n    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors\n    subgraph = sampler(nodes, fanouts)\n\n    # Verify in subgraph.\n    assert (\n        expected_sampled_num1 == 0\n        or subgraph.sampled_csc[\"n1:e1:n2\"].indices.numel()\n        == expected_sampled_num1\n    )\n    assert subgraph.sampled_csc[\"n1:e1:n2\"].indptr.size(0) == 2\n    assert (\n        expected_sampled_num2 == 0\n        or subgraph.sampled_csc[\"n2:e2:n1\"].indices.numel()\n        == expected_sampled_num2\n    )\n    assert subgraph.sampled_csc[\"n2:e2:n1\"].indptr.size(0) == 2\n\n\n@pytest.mark.parametrize(\n 
   \"replace, expected_sampled_num1, expected_sampled_num2\",\n    [(False, 2, 2), (True, 4, 4)],\n)\ndef test_sample_neighbors_replace(\n    replace, expected_sampled_num1, expected_sampled_num2\n):\n    if F._default_context_str == \"gpu\" and replace == True:\n        pytest.skip(\"Sampling with replacement not yet supported on GPU.\")\n    \"\"\"Original graph in COO:\n    \"n1:e1:n2\":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]\n    \"n2:e2:n1\":[0, 0, 1, 2], [0, 1, 1 ,0]\n    0   0   1   0   1\n    0   0   1   1   1\n    1   1   0   0   0\n    0   1   0   0   0\n    1   0   0   0   0\n    \"\"\"\n    # Initialize data.\n    ntypes = {\"n1\": 0, \"n2\": 1}\n    etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n    total_num_edges = 9\n    indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])\n    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])\n    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])\n    node_type_offset = torch.LongTensor([0, 2, 5])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    )\n\n    nodes = {\"n1\": torch.LongTensor([0]), \"n2\": torch.LongTensor([0])}\n    subgraph = graph.sample_neighbors(\n        nodes, torch.LongTensor([4]), replace=replace\n    )\n\n    # Verify in subgraph.\n    assert (\n        subgraph.sampled_csc[\"n1:e1:n2\"].indices.numel()\n        == expected_sampled_num1\n    )\n    assert subgraph.sampled_csc[\"n1:e1:n2\"].indptr.size(0) == 2\n    assert (\n        subgraph.sampled_csc[\"n2:e2:n1\"].indices.numel()\n        == expected_sampled_num2\n    )\n    assert subgraph.sampled_csc[\"n2:e2:n1\"].indptr.size(0) == 2\n\n\n@pytest.mark.parametrize(\"labor\", [False, True])\n@pytest.mark.parametrize(\"is_pinned\", [False, 
True])\ndef test_sample_neighbors_return_eids_homo(labor, is_pinned):\n    \"\"\"Original graph in COO:\n    1   0   1   0   1\n    1   0   1   1   0\n    0   1   0   1   0\n    0   1   0   0   1\n    1   0   0   0   1\n    \"\"\"\n    if F._default_context_str == \"cpu\" and is_pinned:\n        pytest.skip(\"Pinning is not meaningful without a GPU.\")\n    # Initialize data.\n    total_num_edges = 12\n    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])\n    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n\n    # Add edge id mapping from CSC graph -> original graph.\n    edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(total_num_edges)}\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr, indices, edge_attributes=edge_attributes\n    ).to(\"pinned\" if is_pinned else F.ctx())\n\n    # Generate subgraph via sample neighbors.\n    nodes = torch.LongTensor([1, 3, 4]).to(F.ctx())\n    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors\n    subgraph = sampler(nodes, fanouts=torch.LongTensor([-1]))\n\n    # Verify in subgraph.\n    expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][\n        torch.tensor([3, 4, 7, 8, 9, 10, 11])\n    ].to(F.ctx())\n    assert torch.equal(\n        torch.sort(expected_reverse_edge_ids)[0],\n        torch.sort(subgraph.original_edge_ids)[0],\n    )\n    assert subgraph.original_column_node_ids is None\n    assert subgraph.original_row_node_ids is None\n\n\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_sample_neighbors_return_eids_hetero(labor):\n    \"\"\"\n    Original graph in COO:\n    \"n1:e1:n2\":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]\n    \"n2:e2:n1\":[0, 0, 1, 2], [0, 1, 1 ,0]\n    0   0   1   0   1\n    0   0   1   1   1\n    1   1   0   0   0\n    0   1   0   0   0\n    1   0   0   0   0\n    \"\"\"\n    # Initialize data.\n    ntypes 
= {\"n1\": 0, \"n2\": 1}\n    etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n    total_num_edges = 9\n    indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])\n    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])\n    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])\n    node_type_offset = torch.LongTensor([0, 2, 5])\n    edge_attributes = {\n        gb.ORIGINAL_EDGE_ID: torch.cat([torch.randperm(4), torch.randperm(5)])\n    }\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        edge_attributes=edge_attributes,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    ).to(F.ctx())\n\n    # Sample on both node types.\n    nodes = {\n        \"n1\": torch.LongTensor([0]).to(F.ctx()),\n        \"n2\": torch.LongTensor([0]).to(F.ctx()),\n    }\n    fanouts = torch.tensor([-1, -1])\n    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors\n    subgraph = sampler(nodes, fanouts)\n\n    expected_reverse_edge_ids = {\n        \"n2:e2:n1\": graph.edge_attributes[gb.ORIGINAL_EDGE_ID][\n            torch.tensor([0, 1], device=F.ctx())\n        ],\n        \"n1:e1:n2\": graph.edge_attributes[gb.ORIGINAL_EDGE_ID][\n            torch.tensor([4, 5], device=F.ctx())\n        ],\n    }\n    assert subgraph.original_column_node_ids is None\n    assert subgraph.original_row_node_ids is None\n    for etype in etypes.keys():\n        assert torch.equal(\n            subgraph.original_edge_ids[etype].sort()[0],\n            expected_reverse_edge_ids[etype].sort()[0],\n        )\n\n\n@pytest.mark.parametrize(\"replace\", [True, False])\n@pytest.mark.parametrize(\"labor\", [False, True])\n@pytest.mark.parametrize(\"probs_name\", [\"weight\", \"mask\"])\ndef test_sample_neighbors_probs(replace, labor, 
probs_name):\n    if F._default_context_str == \"gpu\" and replace == True:\n        pytest.skip(\"Sampling with replacement not yet supported on GPU.\")\n    \"\"\"Original graph in COO:\n    1   0   1   0   1\n    1   0   1   1   0\n    0   1   0   1   0\n    0   1   0   0   1\n    1   0   0   0   1\n    \"\"\"\n    # Initialize data.\n    total_num_edges = 12\n    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])\n    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n\n    edge_attributes = {\n        \"weight\": torch.FloatTensor(\n            [2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5, 0.4, 1.2]\n        ),\n        \"mask\": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1]),\n    }\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr, indices, edge_attributes=edge_attributes\n    )\n\n    # Generate subgraph via sample neighbors.\n    nodes = torch.LongTensor([1, 3, 4])\n\n    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors\n    subgraph = sampler(\n        nodes,\n        fanouts=torch.tensor([2]),\n        replace=replace,\n        probs_name=probs_name,\n    )\n\n    # Verify in subgraph.\n    sampled_num = subgraph.sampled_csc.indices.size(0)\n    assert subgraph.sampled_csc.indptr.size(0) == 4\n    if replace:\n        assert sampled_num == 6\n    else:\n        assert sampled_num == 4\n\n\n@pytest.mark.parametrize(\"replace\", [True, False])\n@pytest.mark.parametrize(\"labor\", [False, True])\n@pytest.mark.parametrize(\n    \"probs_or_mask\",\n    [\n        torch.zeros(12, dtype=torch.float32),\n        torch.zeros(12, dtype=torch.bool),\n    ],\n)\ndef test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):\n    if F._default_context_str == \"gpu\" and replace == True:\n        pytest.skip(\"Sampling with replacement not yet supported on GPU.\")\n    # Initialize data.\n 
   total_num_nodes = 5\n    total_num_edges = 12\n    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])\n    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n\n    edge_attributes = {\"probs_or_mask\": probs_or_mask}\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr, indices, edge_attributes=edge_attributes\n    )\n\n    # Generate subgraph via sample neighbors.\n    nodes = torch.LongTensor([1, 3, 4])\n    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors\n    subgraph = sampler(\n        nodes,\n        fanouts=torch.tensor([5]),\n        replace=replace,\n        probs_name=\"probs_or_mask\",\n    )\n\n    # Verify in subgraph.\n    sampled_num = subgraph.sampled_csc.indices.size(0)\n    assert subgraph.sampled_csc.indptr.size(0) == 4\n    assert sampled_num == 0\n\n\n@pytest.mark.parametrize(\"replace\", [False, True])\n@pytest.mark.parametrize(\"labor\", [False, True])\n@pytest.mark.parametrize(\n    \"fanouts, probs_name\",\n    [\n        ([2], \"mask\"),\n        ([3], \"mask\"),\n        ([4], \"mask\"),\n        ([-1], \"mask\"),\n        ([7], \"mask\"),\n        ([3], \"all\"),\n        ([-1], \"all\"),\n        ([7], \"all\"),\n        ([3], \"zero\"),\n        ([-1], \"zero\"),\n        ([3], \"none\"),\n        ([-1], \"none\"),\n    ],\n)\ndef test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):\n    if F._default_context_str == \"gpu\" and replace == True:\n        pytest.skip(\"Sampling with replacement not yet supported on GPU.\")\n    \"\"\"Original graph in COO:\n    1   1   1   1   1   1\n    0   0   0   0   0   0\n    0   0   0   0   0   0\n    0   0   0   0   0   0\n    0   0   0   0   0   0\n    0   0   0   0   0   0\n    \"\"\"\n    # Initialize data.\n    total_num_edges = 6\n    indptr = torch.LongTensor([0, 6, 6, 6, 6, 6, 6])\n    indices = 
torch.LongTensor([0, 1, 2, 3, 4, 5])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n\n    edge_attributes = {\n        \"mask\": torch.BoolTensor([1, 0, 0, 1, 0, 1]),\n        \"all\": torch.BoolTensor([1, 1, 1, 1, 1, 1]),\n        \"zero\": torch.BoolTensor([0, 0, 0, 0, 0, 0]),\n    }\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr, indices, edge_attributes=edge_attributes\n    )\n\n    # Generate subgraph via sample neighbors.\n    nodes = torch.LongTensor([0, 1])\n\n    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors\n\n    # Make sure no exception will be thrown.\n    subgraph = sampler(\n        nodes,\n        fanouts=torch.LongTensor(fanouts),\n        replace=replace,\n        probs_name=probs_name if probs_name != \"none\" else None,\n    )\n    sampled_num = subgraph.sampled_csc.indices.size(0)\n    assert subgraph.sampled_csc.indptr.size(0) == 3\n    # Verify in subgraph.\n    if probs_name == \"mask\":\n        if fanouts[0] == -1:\n            assert sampled_num == 3\n        else:\n            if replace:\n                assert sampled_num == fanouts[0]\n            else:\n                assert sampled_num == min(fanouts[0], 3)\n    elif probs_name == \"zero\":\n        assert sampled_num == 0\n    else:\n        if fanouts[0] == -1:\n            assert sampled_num == 6\n        else:\n            if replace:\n                assert sampled_num == fanouts[0]\n            else:\n                assert sampled_num == min(fanouts[0], 6)\n\n\n@pytest.mark.parametrize(\"replace\", [False, True])\n@pytest.mark.parametrize(\"labor\", [False, True])\n@pytest.mark.parametrize(\n    \"fanouts, probs_name\",\n    [\n        ([-1, -1, -1], \"mask\"),\n        ([1, 1, 1], \"mask\"),\n        ([2, 2, 2], \"mask\"),\n        ([3, 3, 3], \"mask\"),\n        ([4, 4, 4], \"mask\"),\n        ([-1, 1, 3], \"none\"),\n        ([2, -1, 4], \"none\"),\n    
],\n)\ndef test_sample_neighbors_hetero_pick_number(\n    fanouts, replace, labor, probs_name\n):\n    if F._default_context_str == \"gpu\" and replace == True:\n        pytest.skip(\"Sampling with replacement not yet supported on GPU.\")\n    # Initialize data.\n    total_num_nodes = 10\n    total_num_edges = 9\n    ntypes = {\"N0\": 0, \"N1\": 1, \"N2\": 2, \"N3\": 3}\n    etypes = {\n        \"N1:R0:N0\": 0,\n        \"N2:R1:N0\": 1,\n        \"N3:R2:N0\": 2,\n    }\n    indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])\n    indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])\n    node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])\n    type_per_edge = torch.LongTensor([0, 0, 0, 1, 1, 1, 2, 2, 2])\n    assert indptr[-1] == total_num_edges\n    assert indptr[-1] == len(indices)\n    assert node_type_offset[-1] == total_num_nodes\n    assert all(type_per_edge < len(etypes))\n\n    edge_attributes = {\n        \"mask\": torch.BoolTensor([1, 1, 0, 1, 1, 1, 0, 0, 0]),\n        \"all\": torch.BoolTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),\n        \"zero\": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),\n    }\n\n    # Construct FusedCSCSamplingGraph.\n    graph = gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        edge_attributes=edge_attributes,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    ).to(F.ctx())\n\n    # Generate subgraph via sample neighbors.\n    nodes = {\n        \"N0\": torch.LongTensor([0]).to(F.ctx()),\n        \"N1\": torch.LongTensor([1]).to(F.ctx()),\n    }\n\n    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors\n\n    # Make sure no exception will be thrown.\n    subgraph = sampler(\n        nodes,\n        fanouts=torch.LongTensor(fanouts),\n        replace=replace,\n        probs_name=probs_name if probs_name != \"none\" else None,\n    )\n    print(subgraph)\n    if probs_name == 
\"none\":\n        for etype, pairs in subgraph.sampled_csc.items():\n            assert pairs.indptr.size(0) == 2\n            sampled_num = pairs.indices.size(0)\n            fanout = fanouts[etypes[etype]]\n            if fanout == -1:\n                assert sampled_num == 3\n            else:\n                if replace:\n                    assert sampled_num == fanout\n                else:\n                    assert sampled_num == min(fanout, 3)\n    else:\n        fanout = fanouts[0]  # Here fanout is the same for all etypes.\n        for etype, pairs in subgraph.sampled_csc.items():\n            assert pairs.indptr.size(0) == 2\n            sampled_num = pairs.indices.size(0)\n            if etypes[etype] == 0:\n                # Etype 0: 2 valid neighbors.\n                if fanout == -1:\n                    assert sampled_num == 2\n                else:\n                    if replace:\n                        assert sampled_num == fanout\n                    else:\n                        assert sampled_num == min(fanout, 2)\n            elif etypes[etype] == 1:\n                # Etype 1: 3 valid neighbors.\n                if fanout == -1:\n                    assert sampled_num == 3\n                else:\n                    if replace:\n                        assert sampled_num == fanout\n                    else:\n                        assert sampled_num == min(fanout, 3)\n            else:\n                # Etype 2: 0 valid neighbors.\n                assert sampled_num == 0\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Graph is CPU only at present.\",\n)\ndef test_graph_attributes():\n    num_nodes = 1000\n    num_edges = 10 * 1000\n    csc_indptr, indices = gbt.random_homo_graph(num_nodes, num_edges)\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_attributes=None,\n        edge_attributes=None,\n    )\n\n    # Case 1: default is None.\n    assert 
graph.node_attributes is None\n    assert graph.edge_attributes is None\n\n    # Case 2: Assign the whole node/edge attributes.\n    node_attributes = {\n        \"A\": torch.rand(num_nodes, 2),\n        \"B\": torch.rand(num_nodes, 2),\n    }\n    edge_attributes = {\n        \"A\": torch.rand(num_nodes, 2),\n        \"B\": torch.rand(num_nodes, 2),\n    }\n    graph.node_attributes = node_attributes\n    graph.edge_attributes = edge_attributes\n    for k, v in node_attributes.items():\n        assert torch.equal(v, graph.node_attributes[k])\n        assert torch.equal(v, graph.node_attribute(k))\n    for k, v in edge_attributes.items():\n        assert torch.equal(v, graph.edge_attributes[k])\n        assert torch.equal(v, graph.edge_attribute(k))\n    assert \"C\" not in graph.node_attributes\n    assert \"C\" not in graph.edge_attributes\n    with pytest.raises(RuntimeError, match=\"Node attribute C does not exist.\"):\n        graph.node_attribute(\"C\")\n    with pytest.raises(RuntimeError, match=\"Edge attribute C does not exist.\"):\n        graph.edge_attribute(\"C\")\n\n    # Case 3: Assign/overwrite more node/edge attributes into existing ones.\n    for key in [\"B\", \"C\"]:\n        node_attributes[key] = torch.rand(num_nodes, 2)\n        edge_attributes[key] = torch.rand(num_edges, 2)\n        graph.add_node_attribute(key, node_attributes[key])\n        graph.add_edge_attribute(key, edge_attributes[key])\n    for k, v in node_attributes.items():\n        assert torch.equal(v, graph.node_attributes[k])\n        assert torch.equal(v, graph.node_attribute(k))\n    for k, v in edge_attributes.items():\n        assert torch.equal(v, graph.edge_attributes[k])\n        assert torch.equal(v, graph.edge_attribute(k))\n\n    # Case 4: Assign more node/edge attributes which were None previously.\n    graph.node_attributes = None\n    graph.edge_attributes = None\n    graph.add_node_attribute(\"C\", node_attributes[\"C\"])\n    graph.add_edge_attribute(\"C\", 
edge_attributes[\"C\"])\n    assert torch.equal(node_attributes[\"C\"], graph.node_attribute(\"C\"))\n    assert torch.equal(node_attributes[\"C\"], graph.node_attributes[\"C\"])\n    assert torch.equal(edge_attributes[\"C\"], graph.edge_attribute(\"C\"))\n    assert torch.equal(edge_attributes[\"C\"], graph.edge_attributes[\"C\"])\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py",
    "content": "import os\nimport tempfile\nimport unittest\n\nimport backend as F\n\nimport numpy as np\nimport pytest\nimport torch\n\nfrom dgl import graphbolt as gb\n\n\ndef to_on_disk_numpy(test_dir, name, t):\n    path = os.path.join(test_dir, name + \".npy\")\n    np.save(path, t.cpu().numpy())\n    return path\n\n\ndef _skip_condition_cached_feature():\n    return (F._default_context_str != \"gpu\") or (\n        torch.cuda.get_device_capability()[0] < 7\n    )\n\n\ndef _reason_to_skip_cached_feature():\n    if F._default_context_str != \"gpu\":\n        return \"GPUCachedFeature tests are available only when testing the GPU backend.\"\n\n    return \"GPUCachedFeature requires a Volta or later generation NVIDIA GPU.\"\n\n\n@unittest.skipIf(\n    _skip_condition_cached_feature(),\n    reason=_reason_to_skip_cached_feature(),\n)\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.bool,\n        torch.uint8,\n        torch.int8,\n        torch.int16,\n        torch.int32,\n        torch.int64,\n        torch.float16,\n        torch.bfloat16,\n        torch.float32,\n        torch.float64,\n    ],\n)\n@pytest.mark.parametrize(\"cache_size_a\", [1, 1024])\n@pytest.mark.parametrize(\"cache_size_b\", [1, 1024])\ndef test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):\n    a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype, pin_memory=True)\n    b = torch.tensor(\n        [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True\n    )\n\n    cache_size_a *= a[:1].element_size() * a[:1].numel()\n    cache_size_b *= b[:1].element_size() * b[:1].numel()\n\n    feat_store_a = gb.gpu_cached_feature(gb.TorchBasedFeature(a), cache_size_a)\n    feat_store_b = gb.gpu_cached_feature(gb.TorchBasedFeature(b), cache_size_b)\n\n    # Test read the entire feature.\n    assert torch.equal(feat_store_a.read(), a.to(\"cuda\"))\n    assert torch.equal(feat_store_b.read(), b.to(\"cuda\"))\n\n    # Test read with ids.\n    assert torch.equal(\n        
feat_store_a.read(torch.tensor([0]).to(\"cuda\")),\n        torch.tensor([[1, 2, 3]], dtype=dtype).to(\"cuda\"),\n    )\n    assert torch.equal(\n        feat_store_b.read(torch.tensor([1, 1]).to(\"cuda\")),\n        torch.tensor([[[4, 5], [6, 7]], [[4, 5], [6, 7]]], dtype=dtype).to(\n            \"cuda\"\n        ),\n    )\n    assert torch.equal(\n        feat_store_a.read(torch.tensor([1, 1]).to(\"cuda\")),\n        torch.tensor([[4, 5, 6], [4, 5, 6]], dtype=dtype).to(\"cuda\"),\n    )\n    assert torch.equal(\n        feat_store_b.read(torch.tensor([0]).to(\"cuda\")),\n        torch.tensor([[[1, 2], [3, 4]]], dtype=dtype).to(\"cuda\"),\n    )\n    # The cache should be full now for the large cache sizes, %100 hit expected.\n    if cache_size_a >= 1024:\n        total_miss = feat_store_a._feature.total_miss\n        feat_store_a.read(torch.tensor([0, 1]).to(\"cuda\"))\n        assert total_miss == feat_store_a._feature.total_miss\n    if cache_size_b >= 1024:\n        total_miss = feat_store_b._feature.total_miss\n        feat_store_b.read(torch.tensor([0, 1]).to(\"cuda\"))\n        assert total_miss == feat_store_b._feature.total_miss\n    assert feat_store_a._feature.miss_rate == feat_store_a.miss_rate\n\n    # Test get the size and count of the entire feature.\n    assert feat_store_a.size() == torch.Size([3])\n    assert feat_store_b.size() == torch.Size([2, 2])\n    assert feat_store_a.count() == a.size(0)\n    assert feat_store_b.count() == b.size(0)\n\n    # Test update the entire feature.\n    feat_store_a.update(\n        torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype).to(\"cuda\")\n    )\n    assert torch.equal(\n        feat_store_a.read(),\n        torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype).to(\"cuda\"),\n    )\n\n    # Test update with ids.\n    feat_store_a.update(\n        torch.tensor([[2, 0, 1]], dtype=dtype).to(\"cuda\"),\n        torch.tensor([0]).to(\"cuda\"),\n    )\n    assert torch.equal(\n        feat_store_a.read(),\n        
torch.tensor([[2, 0, 1], [3, 5, 2]], dtype=dtype).to(\"cuda\"),\n    )\n\n    # Test with different dimensionality\n    feat_store_a.update(b)\n    assert torch.equal(feat_store_a.read(), b.to(\"cuda\"))\n\n\n@unittest.skipIf(\n    _skip_condition_cached_feature(),\n    reason=_reason_to_skip_cached_feature(),\n)\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.bool,\n        torch.uint8,\n        torch.int8,\n        torch.int16,\n        torch.int32,\n        torch.int64,\n        torch.float16,\n        torch.bfloat16,\n        torch.float32,\n        torch.float64,\n    ],\n)\n@pytest.mark.parametrize(\"pin_memory\", [False, True])\ndef test_gpu_cached_feature_read_async(dtype, pin_memory):\n    a = torch.randint(0, 2, [1000, 13], dtype=dtype, pin_memory=pin_memory)\n    a_cuda = a.to(F.ctx())\n\n    cache_size = 256 * a[:1].nbytes\n\n    feat_store = gb.gpu_cached_feature(gb.TorchBasedFeature(a), cache_size)\n\n    # Test read with ids.\n    ids1 = torch.tensor([0, 15, 71, 101], device=F.ctx())\n    ids2 = torch.tensor([71, 101, 202, 303], device=F.ctx())\n    for ids in [ids1, ids2]:\n        reader = feat_store.read_async(ids)\n        for _ in range(feat_store.read_async_num_stages(ids.device)):\n            values = next(reader)\n        assert torch.equal(values.wait(), a_cuda[ids])\n\n\n@unittest.skipIf(\n    _skip_condition_cached_feature(),\n    reason=_reason_to_skip_cached_feature(),\n)\n@unittest.skipIf(\n    not torch.ops.graphbolt.detect_io_uring(),\n    reason=\"DiskBasedFeature is not available on this system.\",\n)\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.bool,\n        torch.uint8,\n        torch.int8,\n        torch.int16,\n        torch.int32,\n        torch.int64,\n        torch.float16,\n        torch.float32,\n        torch.float64,\n    ],\n)\ndef test_gpu_cached_nested_feature_async(dtype):\n    a = torch.randint(0, 2, [1000, 13], dtype=dtype, device=F.ctx())\n\n    cache_size = 256 * 
a[:1].nbytes\n\n    ids1 = torch.tensor([0, 15, 71, 101], device=F.ctx())\n    ids2 = torch.tensor([71, 101, 202, 303], device=F.ctx())\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        path = to_on_disk_numpy(test_dir, \"tensor\", a)\n\n        disk_store = gb.DiskBasedFeature(path=path)\n        feat_store1 = gb.gpu_cached_feature(disk_store, cache_size)\n        feat_store2 = gb.gpu_cached_feature(\n            gb.cpu_cached_feature(disk_store, cache_size * 2), cache_size\n        )\n        feat_store3 = gb.gpu_cached_feature(\n            gb.cpu_cached_feature(disk_store, cache_size * 2, pin_memory=True),\n            cache_size,\n        )\n\n        # Test read feature.\n        for feat_store in [feat_store1, feat_store2, feat_store3]:\n            for ids in [ids1, ids2]:\n                reader = feat_store.read_async(ids)\n                for _ in range(feat_store.read_async_num_stages(ids.device)):\n                    values = next(reader)\n                assert torch.equal(values.wait(), a[ids])\n\n        feat_store1 = feat_store2 = feat_store3 = disk_store = None\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_gpu_graph_cache.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl.graphbolt as gb\n\nimport pytest\nimport torch\n\n\n@unittest.skipIf(\n    F._default_context_str != \"gpu\"\n    or torch.cuda.get_device_capability()[0] < 7,\n    reason=\"GPUCachedFeature tests are available only when testing the GPU backend.\"\n    if F._default_context_str != \"gpu\"\n    else \"GPUCachedFeature requires a Volta or later generation NVIDIA GPU.\",\n)\n@pytest.mark.parametrize(\n    \"indptr_dtype\",\n    [\n        torch.int32,\n        torch.int64,\n    ],\n)\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.bool,\n        torch.uint8,\n        torch.int8,\n        torch.int16,\n        torch.int32,\n        torch.int64,\n        torch.float16,\n        torch.bfloat16,\n        torch.float32,\n        torch.float64,\n    ],\n)\n@pytest.mark.parametrize(\"cache_size\", [4, 9, 11])\n@pytest.mark.parametrize(\"with_edge_ids\", [True, False])\ndef test_gpu_graph_cache(indptr_dtype, dtype, cache_size, with_edge_ids):\n    indices_dtype = torch.int32\n    indptr = torch.tensor([0, 3, 6, 10], dtype=indptr_dtype, pin_memory=True)\n    indices = torch.arange(0, indptr[-1], dtype=indices_dtype, pin_memory=True)\n    probs_or_mask = indices.to(dtype).pin_memory()\n    edge_tensors = [indices, probs_or_mask]\n\n    g = gb.GPUGraphCache(\n        cache_size,\n        2,\n        indptr.dtype,\n        [e.dtype for e in edge_tensors],\n        not with_edge_ids,\n    )\n\n    for i in range(10):\n        keys = (\n            torch.arange(2, dtype=indices_dtype, device=F.ctx()) + i * 2\n        ) % (indptr.size(0) - 1)\n        missing_keys, replace = g.query(keys)\n        (\n            missing_indptr,\n            missing_edge_tensors,\n        ) = torch.ops.graphbolt.index_select_csc_batched(\n            indptr, edge_tensors, missing_keys, with_edge_ids, None\n        )\n        output_indptr, output_edge_tensors = replace(\n            missing_indptr, 
missing_edge_tensors\n        )\n\n        (\n            reference_indptr,\n            reference_edge_tensors,\n        ) = torch.ops.graphbolt.index_select_csc_batched(\n            indptr, edge_tensors, keys, with_edge_ids, None\n        )\n\n        assert torch.equal(output_indptr, reference_indptr)\n        assert len(output_edge_tensors) == len(reference_edge_tensors)\n        for e, ref in zip(output_edge_tensors, reference_edge_tensors):\n            assert torch.equal(e, ref)\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_hetero_cached_feature.py",
    "content": "import backend as F\n\nimport pytest\nimport torch\n\nfrom dgl import graphbolt as gb\n\n\n@pytest.mark.parametrize(\n    \"cached_feature_type\", [gb.cpu_cached_feature, gb.gpu_cached_feature]\n)\ndef test_hetero_cached_feature(cached_feature_type):\n    if cached_feature_type == gb.gpu_cached_feature and (\n        F._default_context_str != \"gpu\"\n        or torch.cuda.get_device_capability()[0] < 7\n    ):\n        pytest.skip(\n            \"GPUCachedFeature tests are available only when testing the GPU backend.\"\n            if F._default_context_str != \"gpu\"\n            else \"GPUCachedFeature requires a Volta or later generation NVIDIA GPU.\"\n        )\n    device = F.ctx() if cached_feature_type == gb.gpu_cached_feature else None\n    pin_memory = cached_feature_type == gb.gpu_cached_feature\n\n    a = {\n        (\"node\", str(i), \"feat\"): gb.TorchBasedFeature(\n            torch.randn([(i + 1) * 10, 5], pin_memory=pin_memory)\n        )\n        for i in range(75)\n    }\n    cached_a = cached_feature_type(a, 2**18)\n\n    for i in range(1024):\n        etype = i % len(a)\n        ids = torch.randint(\n            0, (etype + 1) * 10 - 1, ((etype + 1) * 4,), device=device\n        )\n        feature_key = (\"node\", str(etype), \"feat\")\n        ref = a[feature_key].read(ids)\n        val = cached_a[feature_key].read(ids)\n        torch.testing.assert_close(ref, val, rtol=0, atol=0)\n    assert cached_a[feature_key].miss_rate < 0.69\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py",
    "content": "import unittest\n\nimport backend as F\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\n\nfrom .. import gb_test_utils\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"Tests for pinned memory are only meaningful on GPU.\",\n)\n@pytest.mark.parametrize(\n    \"indptr_dtype\",\n    [torch.int32, torch.int64],\n)\n@pytest.mark.parametrize(\n    \"indices_dtype\",\n    [\n        torch.int8,\n        torch.uint8,\n        torch.int16,\n        torch.int32,\n        torch.int64,\n        torch.float32,\n        torch.float64,\n    ],\n)\n@pytest.mark.parametrize(\"idtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"is_pinned\", [False, True])\n@pytest.mark.parametrize(\"with_edge_ids\", [False, True])\n@pytest.mark.parametrize(\"output_size\", [None, True])\ndef test_index_select_csc(\n    indptr_dtype, indices_dtype, idtype, is_pinned, with_edge_ids, output_size\n):\n    \"\"\"Original graph in COO:\n    1   0   1   0   1   0\n    1   0   0   1   0   1\n    0   1   0   1   0   0\n    0   1   0   0   1   0\n    1   0   0   0   0   1\n    0   0   1   0   1   0\n    \"\"\"\n    indptr = torch.tensor([0, 3, 5, 7, 9, 12, 14], dtype=indptr_dtype)\n    indices = torch.tensor(\n        [0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4], dtype=indices_dtype\n    )\n    index = torch.tensor([0, 5, 3], dtype=idtype)\n\n    cpu_indptr, cpu_indices = torch.ops.graphbolt.index_select_csc(\n        indptr, indices, index, None\n    )\n    if is_pinned:\n        indptr = indptr.pin_memory()\n        indices = indices.pin_memory()\n    else:\n        indptr = indptr.cuda()\n        indices = indices.cuda()\n    index = index.cuda()\n    edge_ids = torch.tensor(\n        [0, 1, 2, 12, 13, 7, 8], dtype=indptr_dtype, device=index.device\n    )\n\n    if output_size:\n        output_size = len(cpu_indices)\n\n    gpu_indptr, gpu_indices = torch.ops.graphbolt.index_select_csc(\n        indptr, indices, index, output_size\n    )\n    
assert not cpu_indptr.is_cuda\n    assert not cpu_indices.is_cuda\n\n    assert gpu_indptr.is_cuda\n    assert gpu_indices.is_cuda\n\n    assert torch.equal(cpu_indptr, gpu_indptr.cpu())\n    assert torch.equal(cpu_indices, gpu_indices.cpu())\n\n    for output_size_selection in [None, output_size]:\n        indices_list = [\n            indices,\n            indices.int().pin_memory() if is_pinned else indices.int(),\n        ]\n        (\n            gpu_indptr2,\n            gpu_indices_list,\n        ) = torch.ops.graphbolt.index_select_csc_batched(\n            indptr, indices_list, index, with_edge_ids, output_size_selection\n        )\n\n        assert torch.equal(gpu_indptr, gpu_indptr2)\n        assert torch.equal(gpu_indices_list[0], gpu_indices)\n        assert torch.equal(gpu_indices_list[1], gpu_indices.int())\n        if with_edge_ids:\n            assert torch.equal(gpu_indices_list[2], edge_ids)\n\n\ndef test_InSubgraphSampler_homo():\n    \"\"\"Original graph in COO:\n    1   0   1   0   1   0\n    1   0   0   1   0   1\n    0   1   0   1   0   0\n    0   1   0   0   1   0\n    1   0   0   0   0   1\n    0   0   1   0   1   0\n    \"\"\"\n    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])\n    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])\n    graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx())\n\n    seed_nodes = torch.LongTensor([0, 5, 3])\n    item_set = gb.ItemSet(seed_nodes, names=\"seeds\")\n    batch_size = 1\n    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(\n        F.ctx()\n    )\n\n    in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)\n\n    it = iter(in_subgraph_sampler)\n\n    def original_indices(minibatch):\n        sampled_subgraph = minibatch.sampled_subgraphs[0]\n        _indices = sampled_subgraph.original_row_node_ids[\n            sampled_subgraph.sampled_csc.indices\n        ]\n        return _indices\n\n    mn = next(it)\n    assert 
torch.equal(mn.seeds, torch.LongTensor([0]).to(F.ctx()))\n    assert torch.equal(\n        mn.sampled_subgraphs[0].sampled_csc.indptr,\n        torch.tensor([0, 3]).to(F.ctx()),\n    )\n\n    mn = next(it)\n    assert torch.equal(mn.seeds, torch.LongTensor([5]).to(F.ctx()))\n    assert torch.equal(\n        mn.sampled_subgraphs[0].sampled_csc.indptr,\n        torch.tensor([0, 2]).to(F.ctx()),\n    )\n    assert torch.equal(original_indices(mn), torch.tensor([1, 4]).to(F.ctx()))\n\n    mn = next(it)\n    assert torch.equal(mn.seeds, torch.LongTensor([3]).to(F.ctx()))\n    assert torch.equal(\n        mn.sampled_subgraphs[0].sampled_csc.indptr,\n        torch.tensor([0, 2]).to(F.ctx()),\n    )\n    assert torch.equal(original_indices(mn), torch.tensor([1, 2]).to(F.ctx()))\n\n\ndef test_InSubgraphSampler_hetero():\n    \"\"\"Original graph in COO:\n    1   0   1   0   1   0\n    1   0   0   1   0   1\n    0   1   0   1   0   0\n    0   1   0   0   1   0\n    1   0   0   0   0   1\n    0   0   1   0   1   0\n    node_type_0: [0, 1, 2]\n    node_type_1: [3, 4, 5]\n    edge_type_0: node_type_0 -> node_type_0\n    edge_type_1: node_type_0 -> node_type_1\n    edge_type_2: node_type_1 -> node_type_0\n    edge_type_3: node_type_1 -> node_type_1\n    \"\"\"\n    ntypes = {\n        \"N0\": 0,\n        \"N1\": 1,\n    }\n    etypes = {\n        \"N0:R0:N0\": 0,\n        \"N0:R1:N1\": 1,\n        \"N1:R2:N0\": 2,\n        \"N1:R3:N1\": 3,\n    }\n    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])\n    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])\n    node_type_offset = torch.LongTensor([0, 3, 6])\n    type_per_edge = torch.LongTensor([0, 0, 2, 0, 2, 0, 2, 1, 1, 1, 3, 3, 1, 3])\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr=indptr,\n        indices=indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    ).to(F.ctx())\n\n    
item_set = gb.HeteroItemSet(\n        {\n            \"N0\": gb.ItemSet(torch.LongTensor([1, 0, 2]), names=\"seeds\"),\n            \"N1\": gb.ItemSet(torch.LongTensor([0, 2, 1]), names=\"seeds\"),\n        }\n    )\n    batch_size = 2\n    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(\n        F.ctx()\n    )\n\n    in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)\n\n    it = iter(in_subgraph_sampler)\n\n    mn = next(it)\n    assert torch.equal(mn.seeds[\"N0\"], torch.LongTensor([1, 0]).to(F.ctx()))\n    expected_sampled_csc = {\n        \"N0:R0:N0\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0, 1, 3]),\n            indices=torch.LongTensor([2, 1, 0]),\n        ),\n        \"N0:R1:N1\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0]), indices=torch.LongTensor([])\n        ),\n        \"N1:R2:N0\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0, 1, 2]), indices=torch.LongTensor([0, 1])\n        ),\n        \"N1:R3:N1\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0]), indices=torch.LongTensor([])\n        ),\n    }\n    for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():\n        assert torch.equal(\n            pairs.indices, expected_sampled_csc[etype].indices.to(F.ctx())\n        )\n        assert torch.equal(\n            pairs.indptr, expected_sampled_csc[etype].indptr.to(F.ctx())\n        )\n\n    mn = next(it)\n    assert mn.seeds == {\n        \"N0\": torch.LongTensor([2]).to(F.ctx()),\n        \"N1\": torch.LongTensor([0]).to(F.ctx()),\n    }\n    expected_sampled_csc = {\n        \"N0:R0:N0\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0, 1]), indices=torch.LongTensor([1])\n        ),\n        \"N0:R1:N1\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0, 2]), indices=torch.LongTensor([2, 0])\n        ),\n        \"N1:R2:N0\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0, 1]), indices=torch.LongTensor([1])\n        
),\n        \"N1:R3:N1\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0, 0]), indices=torch.LongTensor([])\n        ),\n    }\n    for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():\n        assert torch.equal(\n            pairs.indices, expected_sampled_csc[etype].indices.to(F.ctx())\n        )\n        assert torch.equal(\n            pairs.indptr, expected_sampled_csc[etype].indptr.to(F.ctx())\n        )\n\n    mn = next(it)\n    assert torch.equal(mn.seeds[\"N1\"], torch.LongTensor([2, 1]).to(F.ctx()))\n    expected_sampled_csc = {\n        \"N0:R0:N0\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0]), indices=torch.LongTensor([])\n        ),\n        \"N0:R1:N1\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0, 1, 2]), indices=torch.LongTensor([0, 1])\n        ),\n        \"N1:R2:N0\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0]), indices=torch.LongTensor([])\n        ),\n        \"N1:R3:N1\": gb.CSCFormatBase(\n            indptr=torch.LongTensor([0, 1, 3]),\n            indices=torch.LongTensor([1, 2, 0]),\n        ),\n    }\n    if graph.csc_indptr.is_cuda and torch.cuda.get_device_capability()[0] < 7:\n        expected_sampled_csc[\"N0:R1:N1\"] = gb.CSCFormatBase(\n            indptr=torch.LongTensor([0, 1, 2]), indices=torch.LongTensor([1, 0])\n        )\n    for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():\n        assert torch.equal(\n            pairs.indices, expected_sampled_csc[etype].indices.to(F.ctx())\n        )\n        assert torch.equal(\n            pairs.indptr, expected_sampled_csc[etype].indptr.to(F.ctx())\n        )\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_legacy_dataset.py",
    "content": "import dgl.graphbolt as gb\nimport pytest\nimport torch\nfrom dgl import AddSelfLoop\nfrom dgl.data import AsNodePredDataset, CoraGraphDataset\n\n\ndef test_LegacyDataset_homo_node_pred():\n    cora = CoraGraphDataset(transform=AddSelfLoop())\n    dataset = gb.LegacyDataset(cora)\n\n    # Check tasks.\n    assert len(dataset.tasks) == 1\n    task = dataset.tasks[0]\n    assert task.train_set.names == (\"seeds\", \"labels\")\n    assert len(task.train_set) == 140\n    assert task.validation_set.names == (\"seeds\", \"labels\")\n    assert len(task.validation_set) == 500\n    assert task.test_set.names == (\"seeds\", \"labels\")\n    assert len(task.test_set) == 1000\n    assert task.metadata[\"num_classes\"] == 7\n\n    num_nodes = 2708\n    assert dataset.graph.num_nodes == num_nodes\n    assert len(dataset.all_nodes_set) == num_nodes\n    assert dataset.feature.size(\"node\", None, \"feat\") == torch.Size([1433])\n    assert (\n        dataset.feature.read(\n            \"node\", None, \"feat\", torch.tensor([num_nodes - 1])\n        ).size(dim=0)\n        == 1\n    )\n    # Out of bound indexing results in segmentation fault instead of exception\n    # in CI. This may be related to docker env. Skip it for now.\n    # with pytest.raises(IndexError):\n    #    dataset.feature.read(\"node\", None, \"feat\", torch.Tensor([num_nodes]))\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_negative_sampler.py",
    "content": "import re\n\nimport backend as F\n\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\n\nfrom .. import gb_test_utils\n\n\ndef test_NegativeSampler_invoke():\n    # Instantiate graph and required datapipes.\n    num_seeds = 30\n    item_set = gb.ItemSet(\n        torch.arange(0, 2 * num_seeds).reshape(-1, 2), names=\"seeds\"\n    )\n    batch_size = 10\n    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(\n        F.ctx()\n    )\n    negative_ratio = 2\n\n    # Invoke NegativeSampler via class constructor.\n    negative_sampler = gb.NegativeSampler(\n        item_sampler,\n        negative_ratio,\n    )\n    with pytest.raises(NotImplementedError):\n        next(iter(negative_sampler))\n\n    # Invoke NegativeSampler via functional form.\n    negative_sampler = item_sampler.sample_negative(\n        negative_ratio,\n    )\n    with pytest.raises(NotImplementedError):\n        next(iter(negative_sampler))\n\n\ndef test_UniformNegativeSampler_invoke():\n    # Instantiate graph and required datapipes.\n    graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(\n        F.ctx()\n    )\n    num_seeds = 30\n    item_set = gb.ItemSet(\n        torch.arange(0, 2 * num_seeds).reshape(-1, 2), names=\"seeds\"\n    )\n    batch_size = 10\n    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(\n        F.ctx()\n    )\n    negative_ratio = 2\n\n    def _verify(negative_sampler):\n        for data in negative_sampler:\n            # Assertation\n            seeds_len = batch_size + batch_size * negative_ratio\n            assert data.seeds.size(0) == seeds_len\n            assert data.labels.size(0) == seeds_len\n            assert data.indexes.size(0) == seeds_len\n\n    # Invoke UniformNegativeSampler via class constructor.\n    negative_sampler = gb.UniformNegativeSampler(\n        item_sampler,\n        graph,\n        negative_ratio,\n    )\n    _verify(negative_sampler)\n\n    # Invoke 
UniformNegativeSampler via functional form.\n    negative_sampler = item_sampler.sample_uniform_negative(\n        graph,\n        negative_ratio,\n    )\n    _verify(negative_sampler)\n\n\n@pytest.mark.parametrize(\"negative_ratio\", [1, 5, 10, 20])\ndef test_Uniform_NegativeSampler(negative_ratio):\n    # Construct FusedCSCSamplingGraph.\n    graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(\n        F.ctx()\n    )\n    num_seeds = 30\n    item_set = gb.ItemSet(\n        torch.arange(0, num_seeds * 2).reshape(-1, 2), names=\"seeds\"\n    )\n    batch_size = 10\n    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(\n        F.ctx()\n    )\n    # Construct NegativeSampler.\n    negative_sampler = gb.UniformNegativeSampler(\n        item_sampler,\n        graph,\n        negative_ratio,\n    )\n    # Perform Negative sampling.\n    for data in negative_sampler:\n        seeds_len = batch_size + batch_size * negative_ratio\n        # Assertation\n        assert data.seeds.size(0) == seeds_len\n        assert data.labels.size(0) == seeds_len\n        assert data.indexes.size(0) == seeds_len\n        # Check negative seeds value.\n        pos_src = data.seeds[:batch_size, 0]\n        neg_src = data.seeds[batch_size:, 0]\n        assert torch.equal(pos_src.repeat_interleave(negative_ratio), neg_src)\n        # Check labels.\n        assert torch.equal(\n            data.labels[:batch_size], torch.ones(batch_size).to(F.ctx())\n        )\n        assert torch.equal(\n            data.labels[batch_size:],\n            torch.zeros(batch_size * negative_ratio).to(F.ctx()),\n        )\n        # Check indexes.\n        pos_indexes = torch.arange(0, batch_size).to(F.ctx())\n        neg_indexes = pos_indexes.repeat_interleave(negative_ratio)\n        expected_indexes = torch.cat((pos_indexes, neg_indexes))\n        assert torch.equal(data.indexes, expected_indexes)\n\n\ndef test_Uniform_NegativeSampler_error_shape():\n    # 1. 
seeds with shape N*3.\n    # Construct FusedCSCSamplingGraph.\n    graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(\n        F.ctx()\n    )\n    num_seeds = 30\n    item_set = gb.ItemSet(\n        torch.arange(0, num_seeds * 3).reshape(-1, 3), names=\"seeds\"\n    )\n    batch_size = 10\n    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(\n        F.ctx()\n    )\n    negative_ratio = 2\n    # Construct NegativeSampler.\n    negative_sampler = gb.UniformNegativeSampler(\n        item_sampler,\n        graph,\n        negative_ratio,\n    )\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\n            \"Only tensor with shape N*2 is \"\n            + \"supported for negative sampling, but got torch.Size([10, 3]).\"\n        ),\n    ):\n        next(iter(negative_sampler))\n\n    # 2. seeds with shape N*2*1.\n    # Construct FusedCSCSamplingGraph.\n    item_set = gb.ItemSet(\n        torch.arange(0, num_seeds * 2).reshape(-1, 2, 1), names=\"seeds\"\n    )\n    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(\n        F.ctx()\n    )\n    # Construct NegativeSampler.\n    negative_sampler = gb.UniformNegativeSampler(\n        item_sampler,\n        graph,\n        negative_ratio,\n    )\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\n            \"Only tensor with shape N*2 is \"\n            + \"supported for negative sampling, but got torch.Size([10, 2, 1]).\"\n        ),\n    ):\n        next(iter(negative_sampler))\n\n    # 3. 
seeds with shape N.\n    # Construct FusedCSCSamplingGraph.\n    item_set = gb.ItemSet(torch.arange(0, num_seeds), names=\"seeds\")\n    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(\n        F.ctx()\n    )\n    # Construct NegativeSampler.\n    negative_sampler = gb.UniformNegativeSampler(\n        item_sampler,\n        graph,\n        negative_ratio,\n    )\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\n            \"Only tensor with shape N*2 is \"\n            + \"supported for negative sampling, but got torch.Size([10]).\"\n        ),\n    ):\n        next(iter(negative_sampler))\n\n\ndef get_hetero_graph():\n    # COO graph:\n    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]\n    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]\n    # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.\n    # num_nodes = 5, num_n1 = 2, num_n2 = 3\n    ntypes = {\"n1\": 0, \"n2\": 1}\n    etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])\n    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])\n    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])\n    node_type_offset = torch.LongTensor([0, 2, 5])\n    return gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    )\n\n\ndef test_NegativeSampler_Hetero_Data():\n    graph = get_hetero_graph().to(F.ctx())\n    itemset = gb.HeteroItemSet(\n        {\n            \"n1:e1:n2\": gb.ItemSet(\n                torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,\n                names=\"seeds\",\n            ),\n            \"n2:e2:n1\": gb.ItemSet(\n                torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,\n                names=\"seeds\",\n            ),\n        }\n    )\n    batch_size = 2\n    negative_ratio = 1\n    item_sampler = gb.ItemSampler(itemset, 
batch_size=batch_size).copy_to(\n        F.ctx()\n    )\n    negative_dp = gb.UniformNegativeSampler(item_sampler, graph, negative_ratio)\n    assert len(list(negative_dp)) == 5\n    # Perform negative sampling.\n    expected_neg_src = [\n        {\"n1:e1:n2\": torch.tensor([0, 0])},\n        {\"n1:e1:n2\": torch.tensor([1, 1])},\n        {\"n2:e2:n1\": torch.tensor([0, 0])},\n        {\"n2:e2:n1\": torch.tensor([1, 1])},\n        {\"n2:e2:n1\": torch.tensor([2, 2])},\n    ]\n    for i, data in enumerate(negative_dp):\n        # Check negative seeds value.\n        for etype, seeds_data in data.seeds.items():\n            neg_src = seeds_data[batch_size:, 0]\n            neg_dst = seeds_data[batch_size:, 1]\n            assert torch.equal(expected_neg_src[i][etype].to(F.ctx()), neg_src)\n            assert (neg_dst < 3).all(), neg_dst\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py",
    "content": "import unittest\nfrom functools import partial\n\nimport backend as F\n\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\n\n\ndef get_hetero_graph(include_original_edge_ids):\n    # COO graph:\n    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]\n    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]\n    # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.\n    # num_nodes = 5, num_n1 = 2, num_n2 = 3\n    ntypes = {\"n1\": 0, \"n2\": 1, \"n3\": 2}\n    etypes = {\"n2:e1:n3\": 0, \"n3:e2:n2\": 1}\n    indptr = torch.LongTensor([0, 0, 2, 4, 6, 8, 10])\n    indices = torch.LongTensor([3, 5, 3, 4, 1, 2, 2, 1, 1, 2])\n    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])\n    edge_attributes = {\n        \"weight\": torch.FloatTensor(\n            [2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5]\n        ),\n        \"mask\": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1]),\n    }\n    if include_original_edge_ids:\n        edge_attributes[gb.ORIGINAL_EDGE_ID] = (\n            torch.arange(indices.size(0), 0, -1) - 1\n        )\n    node_type_offset = torch.LongTensor([0, 1, 3, 6])\n    return gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n        edge_attributes=edge_attributes,\n    )\n\n\n@unittest.skipIf(F._default_context_str != \"gpu\", reason=\"Enabled only on GPU.\")\n@pytest.mark.parametrize(\"hetero\", [False, True])\n@pytest.mark.parametrize(\"prob_name\", [None, \"weight\", \"mask\"])\n@pytest.mark.parametrize(\"sorted\", [False, True])\n@pytest.mark.parametrize(\"num_cached_edges\", [0, 10])\n@pytest.mark.parametrize(\"is_pinned\", [False, True])\n@pytest.mark.parametrize(\"has_orig_edge_ids\", [False, True])\ndef test_NeighborSampler_GraphFetch(\n    hetero, prob_name, sorted, num_cached_edges, is_pinned, has_orig_edge_ids\n):\n    if sorted:\n        items = torch.arange(3)\n    else:\n        
items = torch.tensor([2, 0, 1])\n    names = \"seeds\"\n    itemset = gb.ItemSet(items, names=names)\n    graph = get_hetero_graph(has_orig_edge_ids)\n    graph = graph.pin_memory_() if is_pinned else graph.to(F.ctx())\n    if hetero:\n        itemset = gb.HeteroItemSet({\"n3\": itemset})\n    else:\n        graph.type_per_edge = None\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    fanout = torch.LongTensor([2])\n    preprocess_fn = partial(\n        gb.SubgraphSampler._preprocess, cooperative=False, async_op=False\n    )\n    datapipe = item_sampler.map(preprocess_fn)\n    datapipe = datapipe.map(\n        partial(gb.NeighborSampler._prepare, graph.node_type_to_id)\n    )\n    sample_per_layer = gb.SamplePerLayer(\n        datapipe, graph.sample_neighbors, fanout, False, prob_name, False\n    )\n    compact_per_layer = sample_per_layer.compact_per_layer(True)\n    gb.seed(123)\n    expected_results = list(compact_per_layer)\n    if num_cached_edges > 0:\n        graph._initialize_gpu_graph_cache(num_cached_edges, 1, prob_name)\n    datapipe = datapipe.sample_per_layer(\n        graph.sample_neighbors, fanout, False, prob_name, True\n    )\n    datapipe = datapipe.compact_per_layer(True)\n    gb.seed(123)\n    new_results = list(datapipe)\n    assert len(expected_results) == len(new_results)\n    for a, b in zip(expected_results, new_results):\n        assert repr(a) == repr(b)\n\n    def remove_input_nodes(minibatch):\n        minibatch.input_nodes = None\n        return minibatch\n\n    datapipe = item_sampler.sample_neighbor(\n        graph, [fanout], False, prob_name=prob_name, overlap_fetch=True\n    )\n    datapipe = datapipe.transform(remove_input_nodes)\n    dataloader = gb.DataLoader(datapipe)\n    gb.seed(123)\n    new_results = list(dataloader)\n    assert len(expected_results) == len(new_results)\n    for a, b in zip(expected_results, new_results):\n        assert repr(a) == 
repr(b)\n\n\n@pytest.mark.parametrize(\"layer_dependency\", [False, True])\n@pytest.mark.parametrize(\"overlap_graph_fetch\", [False, True])\ndef test_labor_dependent_minibatching(layer_dependency, overlap_graph_fetch):\n    if F._default_context_str != \"gpu\" and overlap_graph_fetch:\n        pytest.skip(\"overlap_graph_fetch is only available for GPU.\")\n    num_edges = 200\n    csc_indptr = torch.cat(\n        (\n            torch.zeros(1, dtype=torch.int64),\n            torch.ones(num_edges + 1, dtype=torch.int64) * num_edges,\n        )\n    )\n    indices = torch.arange(1, num_edges + 1)\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr.int(),\n        indices.int(),\n    ).to(F.ctx())\n    torch.random.set_rng_state(torch.manual_seed(123).get_state())\n    batch_dependency = 100\n    itemset = gb.ItemSet(torch.zeros(batch_dependency + 1).int(), names=\"seeds\")\n    datapipe = gb.ItemSampler(itemset, batch_size=1).copy_to(F.ctx())\n    fanouts = [5, 5]\n    datapipe = datapipe.sample_layer_neighbor(\n        graph,\n        fanouts,\n        overlap_fetch=overlap_graph_fetch,\n        layer_dependency=layer_dependency,\n        batch_dependency=batch_dependency,\n    )\n    dataloader = gb.DataLoader(datapipe)\n    res = list(dataloader)\n    assert len(res) == batch_dependency + 1\n    if layer_dependency:\n        assert torch.equal(\n            res[0].input_nodes,\n            res[0].sampled_subgraphs[1].original_row_node_ids,\n        )\n    else:\n        assert res[0].input_nodes.size(0) > res[0].sampled_subgraphs[\n            1\n        ].original_row_node_ids.size(0)\n    delta = 0\n    for i in range(batch_dependency):\n        res_current = (\n            res[i].sampled_subgraphs[-1].original_row_node_ids.tolist()\n        )\n        res_next = (\n            res[i + 1].sampled_subgraphs[-1].original_row_node_ids.tolist()\n        )\n        intersect_len = len(set(res_current).intersection(set(res_next)))\n        assert 
intersect_len >= fanouts[-1]\n        delta += 1 + fanouts[-1] - intersect_len\n    assert delta >= fanouts[-1]\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py",
    "content": "import os\nimport pickle\nimport random\nimport re\nimport tempfile\nimport unittest\nimport warnings\n\nimport numpy as np\nimport pandas as pd\nimport pydantic\nimport pytest\nimport torch\nimport yaml\nfrom dgl import graphbolt as gb\nfrom dgl.graphbolt import GBWarning\n\nfrom .. import gb_test_utils as gbt\n\n\ndef write_yaml_file(yaml_content, dir):\n    os.makedirs(os.path.join(dir, \"preprocessed\"), exist_ok=True)\n    yaml_file = os.path.join(dir, \"preprocessed/metadata.yaml\")\n    with open(yaml_file, \"w\") as f:\n        f.write(yaml_content)\n\n\ndef load_dataset(dataset):\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        return dataset.load()\n\n\ndef write_yaml_and_load_dataset(yaml_content, dir, force_preprocess=False):\n    write_yaml_file(yaml_content, dir)\n    return load_dataset(\n        gb.OnDiskDataset(dir, force_preprocess=force_preprocess)\n    )\n\n\ndef load_sampling_graph(test_dir, processed_dataset):\n    return torch.load(\n        os.path.join(test_dir, processed_dataset[\"graph_topology\"][\"path\"]),\n        weights_only=False,\n    )\n\n\ndef test_OnDiskDataset_TVTSet_exceptions():\n    \"\"\"Test excpetions thrown when parsing TVTSet.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # Case 1: ``format`` is invalid.\n        yaml_content = \"\"\"\n        tasks:\n          - name: node_classification\n            train_set:\n              - type: paper\n                data:\n                  - format: torch_invalid\n                    path: set/paper-train.pt\n        \"\"\"\n        write_yaml_file(yaml_content, test_dir)\n        with pytest.raises(pydantic.ValidationError):\n            _ = gb.OnDiskDataset(test_dir, force_preprocess=False).load()\n\n        # Case 2: ``type`` is not specified while multiple TVT sets are\n        # specified.\n        yaml_content = \"\"\"\n            tasks:\n              - name: 
node_classification\n                train_set:\n                - type: null\n                  data:\n                    - format: numpy\n                      path: set/train.npy\n                - type: null\n                  data:\n                    - format: numpy\n                      path: set/train.npy\n        \"\"\"\n        write_yaml_file(yaml_content, test_dir)\n        with pytest.raises(\n            AssertionError,\n            match=r\"Only one TVT set is allowed if type is not specified.\",\n        ):\n            _ = gb.OnDiskDataset(test_dir, force_preprocess=False).load()\n\n\ndef test_OnDiskDataset_multiple_tasks():\n    \"\"\"Teset multiple tasks are supported.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        train_ids = np.arange(1000)\n        train_ids_path = os.path.join(test_dir, \"train_ids.npy\")\n        np.save(train_ids_path, train_ids)\n        train_labels = np.random.randint(0, 10, size=1000)\n        train_labels_path = os.path.join(test_dir, \"train_labels.npy\")\n        np.save(train_labels_path, train_labels)\n\n        yaml_content = f\"\"\"\n            tasks:\n              - name: node_classification_1\n                num_classes: 10\n                train_set:\n                  - type: null\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {train_ids_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                        path: {train_labels_path}\n                      - format: numpy\n                        in_memory: true\n                        path: {train_labels_path}\n              - name: node_classification_2\n                num_classes: 10\n                train_set:\n                  - type: null\n                    data:\n                      - name: seeds\n                     
   format: numpy\n                        in_memory: true\n                        path: {train_ids_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                        path: {train_labels_path}\n                      - format: numpy\n                        in_memory: true\n                        path: {train_labels_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n        assert len(dataset.tasks) == 2\n\n        for task_id in range(2):\n            assert (\n                dataset.tasks[task_id].metadata[\"name\"]\n                == f\"node_classification_{task_id + 1}\"\n            )\n            assert dataset.tasks[task_id].metadata[\"num_classes\"] == 10\n            # Verify train set.\n            train_set = dataset.tasks[task_id].train_set\n            assert len(train_set) == 1000\n            assert isinstance(train_set, gb.ItemSet)\n            for i, (id, label, _) in enumerate(train_set):\n                assert id == train_ids[i]\n                assert label == train_labels[i]\n            assert train_set.names == (\"seeds\", \"labels\", None)\n            train_set = None\n        dataset = None\n\n\ndef test_OnDiskDataset_TVTSet_ItemSet_names():\n    \"\"\"Test TVTSet which returns ItemSet with IDs, labels and corresponding names.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        train_ids = np.arange(1000)\n        train_ids_path = os.path.join(test_dir, \"train_ids.npy\")\n        np.save(train_ids_path, train_ids)\n        train_labels = np.random.randint(0, 10, size=1000)\n        train_labels_path = os.path.join(test_dir, \"train_labels.npy\")\n        np.save(train_labels_path, train_labels)\n\n        yaml_content = f\"\"\"\n            tasks:\n              - name: node_classification\n                num_classes: 10\n                train_set:\n                  - type: null\n                
    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {train_ids_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                        path: {train_labels_path}\n                      - format: numpy\n                        in_memory: true\n                        path: {train_labels_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n\n        # Verify train set.\n        train_set = dataset.tasks[0].train_set\n        assert len(train_set) == 1000\n        assert isinstance(train_set, gb.ItemSet)\n        for i, (id, label, _) in enumerate(train_set):\n            assert id == train_ids[i]\n            assert label == train_labels[i]\n        assert train_set.names == (\"seeds\", \"labels\", None)\n        train_set = None\n\n\ndef test_OnDiskDataset_TVTSet_HeteroItemSet_names():\n    \"\"\"Test TVTSet which returns ItemSet with IDs, labels and corresponding names.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        train_ids = np.arange(1000)\n        train_ids_path = os.path.join(test_dir, \"train_ids.npy\")\n        np.save(train_ids_path, train_ids)\n        train_labels = np.random.randint(0, 10, size=1000)\n        train_labels_path = os.path.join(test_dir, \"train_labels.npy\")\n        np.save(train_labels_path, train_labels)\n\n        yaml_content = f\"\"\"\n            tasks:\n              - name: node_classification\n                num_classes: 10\n                train_set:\n                  - type: \"author:writes:paper\"\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {train_ids_path}\n                      - name: labels\n                        format: numpy\n                       
 in_memory: true\n                        path: {train_labels_path}\n                      - format: numpy\n                        in_memory: true\n                        path: {train_labels_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n\n        # Verify train set.\n        train_set = dataset.tasks[0].train_set\n        assert len(train_set) == 1000\n        assert isinstance(train_set, gb.HeteroItemSet)\n        for i, item in enumerate(train_set):\n            assert isinstance(item, dict)\n            assert \"author:writes:paper\" in item\n            id, label, _ = item[\"author:writes:paper\"]\n            assert id == train_ids[i]\n            assert label == train_labels[i]\n        assert train_set.names == (\"seeds\", \"labels\", None)\n        train_set = None\n\n\ndef test_OnDiskDataset_TVTSet_ItemSet_id_label():\n    \"\"\"Test TVTSet which returns ItemSet with IDs and labels.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        train_ids = np.arange(1000)\n        train_ids_path = os.path.join(test_dir, \"train_ids.npy\")\n        np.save(train_ids_path, train_ids)\n        train_labels = np.random.randint(0, 10, size=1000)\n        train_labels_path = os.path.join(test_dir, \"train_labels.npy\")\n        np.save(train_labels_path, train_labels)\n\n        validation_ids = np.arange(1000, 2000)\n        validation_ids_path = os.path.join(test_dir, \"validation_ids.npy\")\n        np.save(validation_ids_path, validation_ids)\n        validation_labels = np.random.randint(0, 10, size=1000)\n        validation_labels_path = os.path.join(test_dir, \"validation_labels.npy\")\n        np.save(validation_labels_path, validation_labels)\n\n        test_ids = np.arange(2000, 3000)\n        test_ids_path = os.path.join(test_dir, \"test_ids.npy\")\n        np.save(test_ids_path, test_ids)\n        test_labels = np.random.randint(0, 10, size=1000)\n        test_labels_path = os.path.join(test_dir, 
\"test_labels.npy\")\n        np.save(test_labels_path, test_labels)\n\n        # Case 1:\n        #   all TVT sets are specified.\n        #   ``type`` is not specified or specified as ``null``.\n        #   ``in_memory`` could be ``true`` and ``false``.\n        yaml_content = f\"\"\"\n            tasks:\n              - name: node_classification\n                num_classes: 10\n                train_set:\n                  - type: null\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {train_ids_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                        path: {train_labels_path}\n                validation_set:\n                  - data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {validation_ids_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                        path: {validation_labels_path}\n                test_set:\n                  - type: null\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {test_ids_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                        path: {test_labels_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n\n        # Verify tasks.\n        assert len(dataset.tasks) == 1\n        assert dataset.tasks[0].metadata[\"name\"] == \"node_classification\"\n        assert dataset.tasks[0].metadata[\"num_classes\"] == 10\n\n        # Verify train set.\n        train_set = 
dataset.tasks[0].train_set\n        assert len(train_set) == 1000\n        assert isinstance(train_set, gb.ItemSet)\n        for i, (id, label) in enumerate(train_set):\n            assert id == train_ids[i]\n            assert label == train_labels[i]\n        assert train_set.names == (\"seeds\", \"labels\")\n        train_set = None\n\n        # Verify validation set.\n        validation_set = dataset.tasks[0].validation_set\n        assert len(validation_set) == 1000\n        assert isinstance(validation_set, gb.ItemSet)\n        for i, (id, label) in enumerate(validation_set):\n            assert id == validation_ids[i]\n            assert label == validation_labels[i]\n        assert validation_set.names == (\"seeds\", \"labels\")\n        validation_set = None\n\n        # Verify test set.\n        test_set = dataset.tasks[0].test_set\n        assert len(test_set) == 1000\n        assert isinstance(test_set, gb.ItemSet)\n        for i, (id, label) in enumerate(test_set):\n            assert id == test_ids[i]\n            assert label == test_labels[i]\n        assert test_set.names == (\"seeds\", \"labels\")\n        test_set = None\n        dataset = None\n\n        # Case 2: Some TVT sets are None.\n        yaml_content = f\"\"\"\n            tasks:\n              - name: node_classification\n                train_set:\n                  - type: null\n                    data:\n                      - format: numpy\n                        path: {train_ids_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n        assert dataset.tasks[0].train_set is not None\n        assert dataset.tasks[0].validation_set is None\n        assert dataset.tasks[0].test_set is None\n        dataset = None\n\n\ndef test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels():\n    \"\"\"Test TVTSet which returns ItemSet with node pairs and labels.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        train_seeds = 
np.arange(2000).reshape(1000, 2)\n        train_seeds_path = os.path.join(test_dir, \"train_seeds.npy\")\n        np.save(train_seeds_path, train_seeds)\n        train_labels = np.random.randint(0, 10, size=1000)\n        train_labels_path = os.path.join(test_dir, \"train_labels.npy\")\n        np.save(train_labels_path, train_labels)\n\n        validation_seeds = np.arange(2000, 4000).reshape(1000, 2)\n        validation_seeds_path = os.path.join(test_dir, \"validation_seeds.npy\")\n        np.save(validation_seeds_path, validation_seeds)\n        validation_labels = np.random.randint(0, 10, size=1000)\n        validation_labels_path = os.path.join(test_dir, \"validation_labels.npy\")\n        np.save(validation_labels_path, validation_labels)\n\n        test_seeds = np.arange(4000, 6000).reshape(1000, 2)\n        test_seeds_path = os.path.join(test_dir, \"test_seeds.npy\")\n        np.save(test_seeds_path, test_seeds)\n        test_labels = np.random.randint(0, 10, size=1000)\n        test_labels_path = os.path.join(test_dir, \"test_labels.npy\")\n        np.save(test_labels_path, test_labels)\n\n        yaml_content = f\"\"\"\n            tasks:\n              - name: link_prediction\n                train_set:\n                  - type: null\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {train_seeds_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                        path: {train_labels_path}\n                validation_set:\n                  - data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {validation_seeds_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                    
    path: {validation_labels_path}\n                test_set:\n                  - type: null\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {test_seeds_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                        path: {test_labels_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n\n        # Verify train set.\n        train_set = dataset.tasks[0].train_set\n        assert len(train_set) == 1000\n        assert isinstance(train_set, gb.ItemSet)\n        for i, (node_pair, label) in enumerate(train_set):\n            assert node_pair[0] == train_seeds[i][0]\n            assert node_pair[1] == train_seeds[i][1]\n            assert label == train_labels[i]\n        assert train_set.names == (\"seeds\", \"labels\")\n        train_set = None\n\n        # Verify validation set.\n        validation_set = dataset.tasks[0].validation_set\n        assert len(validation_set) == 1000\n        assert isinstance(validation_set, gb.ItemSet)\n        for i, (node_pair, label) in enumerate(validation_set):\n            assert node_pair[0] == validation_seeds[i][0]\n            assert node_pair[1] == validation_seeds[i][1]\n            assert label == validation_labels[i]\n        assert validation_set.names == (\"seeds\", \"labels\")\n        validation_set = None\n\n        # Verify test set.\n        test_set = dataset.tasks[0].test_set\n        assert len(test_set) == 1000\n        assert isinstance(test_set, gb.ItemSet)\n        for i, (node_pair, label) in enumerate(test_set):\n            assert node_pair[0] == test_seeds[i][0]\n            assert node_pair[1] == test_seeds[i][1]\n            assert label == test_labels[i]\n        assert test_set.names == (\"seeds\", \"labels\")\n        test_set = None\n        
dataset = None\n\n\ndef test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels_indexes():\n    \"\"\"Test TVTSet which returns ItemSet with node pairs and negative ones.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        train_seeds = np.arange(2000).reshape(1000, 2)\n        train_neg_dst = np.random.choice(1000 * 10, size=1000 * 10)\n        train_neg_src = train_seeds[:, 0].repeat(10)\n        train_neg_seeds = (\n            np.concatenate((train_neg_dst, train_neg_src)).reshape(2, -1).T\n        )\n        train_seeds = np.concatenate((train_seeds, train_neg_seeds))\n        train_seeds_path = os.path.join(test_dir, \"train_seeds.npy\")\n        np.save(train_seeds_path, train_seeds)\n\n        train_labels = torch.empty(1000 * 11)\n        train_labels[:1000] = 1\n        train_labels[1000:] = 0\n        train_labels_path = os.path.join(test_dir, \"train_labels.pt\")\n        torch.save(train_labels, train_labels_path)\n\n        train_indexes = torch.arange(0, 1000)\n        train_indexes = np.concatenate(\n            (train_indexes, train_indexes.repeat_interleave(10))\n        )\n        train_indexes_path = os.path.join(test_dir, \"train_indexes.pt\")\n        torch.save(train_indexes, train_indexes_path)\n\n        validation_seeds = np.arange(2000, 4000).reshape(1000, 2)\n        validation_neg_seeds = train_neg_seeds + 1\n        validation_seeds = np.concatenate(\n            (validation_seeds, validation_neg_seeds)\n        )\n        validation_seeds_path = os.path.join(test_dir, \"validation_seeds.npy\")\n        np.save(validation_seeds_path, validation_seeds)\n        validation_labels = train_labels\n        validation_labels_path = os.path.join(test_dir, \"validation_labels.pt\")\n        torch.save(validation_labels, validation_labels_path)\n\n        validation_indexes = train_indexes\n        validation_indexes_path = os.path.join(\n            test_dir, \"validation_indexes.pt\"\n        )\n        
torch.save(validation_indexes, validation_indexes_path)\n\n        test_seeds = np.arange(4000, 6000).reshape(1000, 2)\n        test_neg_seeds = train_neg_seeds + 2\n        test_seeds = np.concatenate((test_seeds, test_neg_seeds))\n        test_seeds_path = os.path.join(test_dir, \"test_seeds.npy\")\n        np.save(test_seeds_path, test_seeds)\n        test_labels = train_labels\n        test_labels_path = os.path.join(test_dir, \"test_labels.pt\")\n        torch.save(test_labels, test_labels_path)\n\n        test_indexes = train_indexes\n        test_indexes_path = os.path.join(test_dir, \"test_indexes.pt\")\n        torch.save(test_indexes, test_indexes_path)\n\n        yaml_content = f\"\"\"\n            tasks:\n              - name: link_prediction\n                train_set:\n                  - type: null\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {train_seeds_path}\n                      - name: labels\n                        format: torch\n                        in_memory: true\n                        path: {train_labels_path}\n                      - name: indexes\n                        format: torch\n                        in_memory: true\n                        path: {train_indexes_path}\n                validation_set:\n                  - data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {validation_seeds_path}\n                      - name: labels\n                        format: torch\n                        in_memory: true\n                        path: {validation_labels_path}\n                      - name: indexes\n                        format: torch\n                        in_memory: true\n                        path: {validation_indexes_path}\n                test_set:\n                  - type: 
null\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {test_seeds_path}\n                      - name: labels\n                        format: torch\n                        in_memory: true\n                        path: {test_labels_path}\n                      - name: indexes\n                        format: torch\n                        in_memory: true\n                        path: {test_indexes_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n\n        # Verify train set.\n        train_set = dataset.tasks[0].train_set\n        assert len(train_set) == 1000 * 11\n        assert isinstance(train_set, gb.ItemSet)\n        for i, (node_pair, label, index) in enumerate(train_set):\n            assert node_pair[0] == train_seeds[i][0]\n            assert node_pair[1] == train_seeds[i][1]\n            assert label == train_labels[i]\n            assert index == train_indexes[i]\n        assert train_set.names == (\"seeds\", \"labels\", \"indexes\")\n        train_set = None\n\n        # Verify validation set.\n        validation_set = dataset.tasks[0].validation_set\n        assert len(validation_set) == 1000 * 11\n        assert isinstance(validation_set, gb.ItemSet)\n        for i, (node_pair, label, index) in enumerate(validation_set):\n            assert node_pair[0] == validation_seeds[i][0]\n            assert node_pair[1] == validation_seeds[i][1]\n            assert label == validation_labels[i]\n            assert index == validation_indexes[i]\n        assert validation_set.names == (\"seeds\", \"labels\", \"indexes\")\n        validation_set = None\n\n        # Verify test set.\n        test_set = dataset.tasks[0].test_set\n        assert len(test_set) == 1000 * 11\n        assert isinstance(test_set, gb.ItemSet)\n        for i, (node_pair, label, index) in enumerate(test_set):\n        
    assert node_pair[0] == test_seeds[i][0]\n            assert label == test_labels[i]\n            assert index == test_indexes[i]\n        assert test_set.names == (\"seeds\", \"labels\", \"indexes\")\n        test_set = None\n        dataset = None\n\n\ndef test_OnDiskDataset_TVTSet_HeteroItemSet_id_label():\n    \"\"\"Test TVTSet which returns HeteroItemSet with IDs and labels.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        train_ids = np.arange(1000)\n        train_labels = np.random.randint(0, 10, size=1000)\n        train_data = np.vstack([train_ids, train_labels]).T\n        train_path = os.path.join(test_dir, \"train.npy\")\n        np.save(train_path, train_data)\n\n        validation_ids = np.arange(1000, 2000)\n        validation_labels = np.random.randint(0, 10, size=1000)\n        validation_data = np.vstack([validation_ids, validation_labels]).T\n        validation_path = os.path.join(test_dir, \"validation.npy\")\n        np.save(validation_path, validation_data)\n\n        test_ids = np.arange(2000, 3000)\n        test_labels = np.random.randint(0, 10, size=1000)\n        test_data = np.vstack([test_ids, test_labels]).T\n        test_path = os.path.join(test_dir, \"test.npy\")\n        np.save(test_path, test_data)\n\n        yaml_content = f\"\"\"\n            tasks:\n              - name: node_classification\n                train_set:\n                  - type: paper\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {train_path}\n                  - type: author\n                    data:\n                      - name: seeds\n                        format: numpy\n                        path: {train_path}\n                validation_set:\n                  - type: paper\n                    data:\n                      - name: seeds\n                        format: numpy\n                        
path: {validation_path}\n                  - type: author\n                    data:\n                      - name: seeds\n                        format: numpy\n                        path: {validation_path}\n                test_set:\n                  - type: paper\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: false\n                        path: {test_path}\n                  - type: author\n                    data:\n                      - name: seeds\n                        format: numpy\n                        path: {test_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n\n        # Verify train set.\n        train_set = dataset.tasks[0].train_set\n        assert len(train_set) == 2000\n        assert isinstance(train_set, gb.HeteroItemSet)\n        for i, item in enumerate(train_set):\n            assert isinstance(item, dict)\n            assert len(item) == 1\n            key = list(item.keys())[0]\n            assert key in [\"paper\", \"author\"]\n            id, label = item[key]\n            assert id == train_ids[i % 1000]\n            assert label == train_labels[i % 1000]\n        assert train_set.names == (\"seeds\",)\n        train_set = None\n\n        # Verify validation set.\n        validation_set = dataset.tasks[0].validation_set\n        assert len(validation_set) == 2000\n        assert isinstance(validation_set, gb.HeteroItemSet)\n        for i, item in enumerate(validation_set):\n            assert isinstance(item, dict)\n            assert len(item) == 1\n            key = list(item.keys())[0]\n            assert key in [\"paper\", \"author\"]\n            id, label = item[key]\n            assert id == validation_ids[i % 1000]\n            assert label == validation_labels[i % 1000]\n        assert validation_set.names == (\"seeds\",)\n        validation_set = None\n\n        # Verify test set.\n  
      test_set = dataset.tasks[0].test_set\n        assert len(test_set) == 2000\n        assert isinstance(test_set, gb.HeteroItemSet)\n        for i, item in enumerate(test_set):\n            assert isinstance(item, dict)\n            assert len(item) == 1\n            key = list(item.keys())[0]\n            assert key in [\"paper\", \"author\"]\n            id, label = item[key]\n            assert id == test_ids[i % 1000]\n            assert label == test_labels[i % 1000]\n        assert test_set.names == (\"seeds\",)\n        test_set = None\n        dataset = None\n\n\ndef test_OnDiskDataset_TVTSet_HeteroItemSet_node_pairs_labels():\n    \"\"\"Test TVTSet which returns HeteroItemSet with node pairs and labels.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        train_seeds = np.arange(2000).reshape(1000, 2)\n        train_seeds_path = os.path.join(test_dir, \"train_seeds.npy\")\n        np.save(train_seeds_path, train_seeds)\n        train_labels = np.random.randint(0, 10, size=1000)\n        train_labels_path = os.path.join(test_dir, \"train_labels.npy\")\n        np.save(train_labels_path, train_labels)\n\n        validation_seeds = np.arange(2000, 4000).reshape(1000, 2)\n        validation_seeds_path = os.path.join(test_dir, \"validation_seeds.npy\")\n        np.save(validation_seeds_path, validation_seeds)\n        validation_labels = np.random.randint(0, 10, size=1000)\n        validation_labels_path = os.path.join(test_dir, \"validation_labels.npy\")\n        np.save(validation_labels_path, validation_labels)\n\n        test_seeds = np.arange(4000, 6000).reshape(1000, 2)\n        test_seeds_path = os.path.join(test_dir, \"test_seeds.npy\")\n        np.save(test_seeds_path, test_seeds)\n        test_labels = np.random.randint(0, 10, size=1000)\n        test_labels_path = os.path.join(test_dir, \"test_labels.npy\")\n        np.save(test_labels_path, test_labels)\n\n        yaml_content = f\"\"\"\n            tasks:\n              - name: 
edge_classification\n                train_set:\n                  - type: paper:cites:paper\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {train_seeds_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                        path: {train_labels_path}\n                  - type: author:writes:paper\n                    data:\n                      - name: seeds\n                        format: numpy\n                        path: {train_seeds_path}\n                      - name: labels\n                        format: numpy\n                        path: {train_labels_path}\n                validation_set:\n                  - type: paper:cites:paper\n                    data:\n                      - name: seeds\n                        format: numpy\n                        path: {validation_seeds_path}\n                      - name: labels\n                        format: numpy\n                        path: {validation_labels_path}\n                  - type: author:writes:paper\n                    data:\n                      - name: seeds\n                        format: numpy\n                        path: {validation_seeds_path}\n                      - name: labels\n                        format: numpy\n                        path: {validation_labels_path}\n                test_set:\n                  - type: paper:cites:paper\n                    data:\n                      - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {test_seeds_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                        path: {test_labels_path}\n                  - type: author:writes:paper\n                    data:\n 
                     - name: seeds\n                        format: numpy\n                        in_memory: true\n                        path: {test_seeds_path}\n                      - name: labels\n                        format: numpy\n                        in_memory: true\n                        path: {test_labels_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n\n        # Verify train set.\n        train_set = dataset.tasks[0].train_set\n        assert len(train_set) == 2000\n        assert isinstance(train_set, gb.HeteroItemSet)\n        for i, item in enumerate(train_set):\n            assert isinstance(item, dict)\n            assert len(item) == 1\n            key = list(item.keys())[0]\n            assert key in [\"paper:cites:paper\", \"author:writes:paper\"]\n            node_pair, label = item[key]\n            assert node_pair[0] == train_seeds[i % 1000][0]\n            assert node_pair[1] == train_seeds[i % 1000][1]\n            assert label == train_labels[i % 1000]\n        assert train_set.names == (\"seeds\", \"labels\")\n        train_set = None\n\n        # Verify validation set.\n        validation_set = dataset.tasks[0].validation_set\n        assert len(validation_set) == 2000\n        assert isinstance(validation_set, gb.HeteroItemSet)\n        for i, item in enumerate(validation_set):\n            assert isinstance(item, dict)\n            assert len(item) == 1\n            key = list(item.keys())[0]\n            assert key in [\"paper:cites:paper\", \"author:writes:paper\"]\n            node_pair, label = item[key]\n            assert node_pair[0] == validation_seeds[i % 1000][0]\n            assert node_pair[1] == validation_seeds[i % 1000][1]\n            assert label == validation_labels[i % 1000]\n        assert validation_set.names == (\"seeds\", \"labels\")\n        validation_set = None\n\n        # Verify test set.\n        test_set = dataset.tasks[0].test_set\n        assert 
len(test_set) == 2000\n        assert isinstance(test_set, gb.HeteroItemSet)\n        for i, item in enumerate(test_set):\n            assert isinstance(item, dict)\n            assert len(item) == 1\n            key = list(item.keys())[0]\n            assert key in [\"paper:cites:paper\", \"author:writes:paper\"]\n            node_pair, label = item[key]\n            assert node_pair[0] == test_seeds[i % 1000][0]\n            assert node_pair[1] == test_seeds[i % 1000][1]\n            assert label == test_labels[i % 1000]\n        assert test_set.names == (\"seeds\", \"labels\")\n        test_set = None\n        dataset = None\n\n\ndef test_OnDiskDataset_Feature_heterograph():\n    \"\"\"Test Feature storage.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # Generate node data.\n        node_data_paper = np.random.rand(1000, 10)\n        node_data_paper_path = os.path.join(test_dir, \"node_data_paper.npy\")\n        np.save(node_data_paper_path, node_data_paper)\n        node_data_label = torch.tensor(\n            [[random.randint(0, 10)] for _ in range(1000)]\n        )\n        node_data_label_path = os.path.join(test_dir, \"node_data_label.npy\")\n        np.save(node_data_label_path, node_data_label)\n\n        # Generate edge data.\n        edge_data_writes = np.random.rand(1000, 10)\n        edge_data_writes_path = os.path.join(test_dir, \"edge_writes_paper.npy\")\n        np.save(edge_data_writes_path, edge_data_writes)\n        edge_data_label = torch.tensor(\n            [[random.randint(0, 10)] for _ in range(1000)]\n        )\n        edge_data_label_path = os.path.join(test_dir, \"edge_data_label.npy\")\n        np.save(edge_data_label_path, edge_data_label)\n\n        # Generate YAML.\n        yaml_content = f\"\"\"\n            feature_data:\n              - domain: node\n                type: paper\n                name: feat\n                format: numpy\n                in_memory: false\n                path: 
{node_data_paper_path}\n                num_categories: 10\n              - domain: node\n                type: paper\n                name: labels\n                format: numpy\n                in_memory: true\n                path: {node_data_label_path}\n              - domain: edge\n                type: \"author:writes:paper\"\n                name: feat\n                format: numpy\n                in_memory: false\n                path: {edge_data_writes_path}\n                num_categories: 10\n              - domain: edge\n                type: \"author:writes:paper\"\n                name: labels\n                format: numpy\n                in_memory: true\n                path: {edge_data_label_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n\n        # Verify feature data storage.\n        feature_data = dataset.feature\n        assert len(feature_data) == 4\n\n        # Verify node feature data.\n        assert torch.equal(\n            feature_data.read(\"node\", \"paper\", \"feat\"),\n            torch.tensor(node_data_paper),\n        )\n        assert (\n            feature_data.metadata(\"node\", \"paper\", \"feat\")[\"num_categories\"]\n            == 10\n        )\n        assert torch.equal(\n            feature_data.read(\"node\", \"paper\", \"labels\"),\n            node_data_label.clone().detach(),\n        )\n        assert len(feature_data.metadata(\"node\", \"paper\", \"labels\")) == 0\n\n        # Verify edge feature data.\n        assert torch.equal(\n            feature_data.read(\"edge\", \"author:writes:paper\", \"feat\"),\n            torch.tensor(edge_data_writes),\n        )\n        assert (\n            feature_data.metadata(\"edge\", \"author:writes:paper\", \"feat\")[\n                \"num_categories\"\n            ]\n            == 10\n        )\n        assert torch.equal(\n            feature_data.read(\"edge\", \"author:writes:paper\", \"labels\"),\n            
edge_data_label.clone().detach(),\n        )\n        assert (\n            len(feature_data.metadata(\"edge\", \"author:writes:paper\", \"labels\"))\n            == 0\n        )\n\n        feature_data = None\n        dataset = None\n\n\ndef test_OnDiskDataset_Feature_homograph():\n    \"\"\"Test Feature storage.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # Generate node data.\n        node_data_feat = np.random.rand(1000, 10)\n        node_data_feat_path = os.path.join(test_dir, \"node_data_feat.npy\")\n        np.save(node_data_feat_path, node_data_feat)\n        node_data_label = torch.tensor(\n            [[random.randint(0, 10)] for _ in range(1000)]\n        )\n        node_data_label_path = os.path.join(test_dir, \"node_data_label.npy\")\n        np.save(node_data_label_path, node_data_label)\n\n        # Generate edge data.\n        edge_data_feat = np.random.rand(1000, 10)\n        edge_data_feat_path = os.path.join(test_dir, \"edge_data_feat.npy\")\n        np.save(edge_data_feat_path, edge_data_feat)\n        edge_data_label = torch.tensor(\n            [[random.randint(0, 10)] for _ in range(1000)]\n        )\n        edge_data_label_path = os.path.join(test_dir, \"edge_data_label.npy\")\n        np.save(edge_data_label_path, edge_data_label)\n\n        # Generate YAML.\n        # ``type`` is not specified in the YAML.\n        yaml_content = f\"\"\"\n            feature_data:\n              - domain: node\n                name: feat\n                format: numpy\n                in_memory: false\n                path: {node_data_feat_path}\n                num_categories: 10\n              - domain: node\n                name: labels\n                format: numpy\n                in_memory: true\n                path: {node_data_label_path}\n              - domain: edge\n                name: feat\n                format: numpy\n                in_memory: false\n                path: {edge_data_feat_path}\n                
num_categories: 10\n              - domain: edge\n                name: labels\n                format: numpy\n                in_memory: true\n                path: {edge_data_label_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n\n        # Verify feature data storage.\n        feature_data = dataset.feature\n        assert len(feature_data) == 4\n\n        # Verify node feature data.\n        assert torch.equal(\n            feature_data.read(\"node\", None, \"feat\"),\n            torch.tensor(node_data_feat),\n        )\n        assert (\n            feature_data.metadata(\"node\", None, \"feat\")[\"num_categories\"] == 10\n        )\n        assert torch.equal(\n            feature_data.read(\"node\", None, \"labels\"),\n            node_data_label.clone().detach(),\n        )\n        assert len(feature_data.metadata(\"node\", None, \"labels\")) == 0\n\n        # Verify edge feature data.\n        assert torch.equal(\n            feature_data.read(\"edge\", None, \"feat\"),\n            torch.tensor(edge_data_feat),\n        )\n        assert (\n            feature_data.metadata(\"edge\", None, \"feat\")[\"num_categories\"] == 10\n        )\n        assert torch.equal(\n            feature_data.read(\"edge\", None, \"labels\"),\n            edge_data_label.clone().detach(),\n        )\n        assert len(feature_data.metadata(\"edge\", None, \"labels\")) == 0\n\n        feature_data = None\n        dataset = None\n\n\ndef test_OnDiskDataset_Graph_Exceptions():\n    \"\"\"Test exceptions in parsing graph topology.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # Invalid graph type.\n        yaml_content = \"\"\"\n            graph_topology:\n              type: CSRSamplingGraph\n              path: /path/to/graph\n        \"\"\"\n        write_yaml_file(yaml_content, test_dir)\n\n        with pytest.raises(\n            pydantic.ValidationError,\n            match=\"1 validation error for 
OnDiskMetaData\",\n        ):\n            _ = gb.OnDiskDataset(test_dir, force_preprocess=False).load()\n\n\ndef test_OnDiskDataset_Graph_homogeneous():\n    \"\"\"Test homogeneous graph topology.\"\"\"\n    csc_indptr, indices = gbt.random_homo_graph(1000, 10 * 1000)\n    graph = gb.fused_csc_sampling_graph(csc_indptr, indices)\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        graph_path = os.path.join(test_dir, \"fused_csc_sampling_graph.pt\")\n        torch.save(graph, graph_path)\n\n        yaml_content = f\"\"\"\n            graph_topology:\n              type: FusedCSCSamplingGraph\n              path: {graph_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n        graph2 = dataset.graph\n\n        assert graph.total_num_nodes == graph2.total_num_nodes\n        assert graph.total_num_edges == graph2.total_num_edges\n\n        assert torch.equal(graph.csc_indptr, graph2.csc_indptr)\n        assert torch.equal(graph.indices, graph2.indices)\n\n        assert (\n            graph.node_type_offset is None and graph2.node_type_offset is None\n        )\n        assert graph.type_per_edge is None and graph2.type_per_edge is None\n        assert graph.node_type_to_id is None and graph2.node_type_to_id is None\n        assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None\n\n\ndef test_OnDiskDataset_Graph_heterogeneous():\n    \"\"\"Test heterogeneous graph topology.\"\"\"\n    (\n        csc_indptr,\n        indices,\n        node_type_offset,\n        type_per_edge,\n        node_type_to_id,\n        edge_type_to_id,\n    ) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=node_type_to_id,\n        edge_type_to_id=edge_type_to_id,\n    )\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        
graph_path = os.path.join(test_dir, \"fused_csc_sampling_graph.pt\")\n        torch.save(graph, graph_path)\n\n        yaml_content = f\"\"\"\n            graph_topology:\n              type: FusedCSCSamplingGraph\n              path: {graph_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n        graph2 = dataset.graph\n\n        assert graph.total_num_nodes == graph2.total_num_nodes\n        assert graph.total_num_edges == graph2.total_num_edges\n\n        assert torch.equal(graph.csc_indptr, graph2.csc_indptr)\n        assert torch.equal(graph.indices, graph2.indices)\n        assert torch.equal(graph.node_type_offset, graph2.node_type_offset)\n        assert torch.equal(graph.type_per_edge, graph2.type_per_edge)\n        assert graph.node_type_to_id == graph2.node_type_to_id\n        assert graph.edge_type_to_id == graph2.edge_type_to_id\n\n\ndef test_OnDiskDataset_Metadata():\n    \"\"\"Test metadata of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        yaml_content = f\"\"\"\n            dataset_name: {dataset_name}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n        assert dataset.dataset_name == dataset_name\n\n        # Only dataset_name is specified.\n        yaml_content = f\"\"\"\n            dataset_name: {dataset_name}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n        assert dataset.dataset_name == dataset_name\n\n\n@pytest.mark.parametrize(\"edge_fmt\", [\"csv\", \"numpy\"])\ndef test_OnDiskDataset_preprocess_homogeneous(edge_fmt):\n    \"\"\"Test preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 
10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n            edge_fmt=edge_fmt,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n        output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(\n            test_dir, include_original_edge_id=False\n        )\n\n        with open(output_file, \"rb\") as f:\n            processed_dataset = yaml.load(f, Loader=yaml.Loader)\n\n        assert processed_dataset[\"dataset_name\"] == dataset_name\n        assert processed_dataset[\"tasks\"][0][\"num_classes\"] == num_classes\n        assert \"graph\" not in processed_dataset\n        assert \"graph_topology\" in processed_dataset\n\n        fused_csc_sampling_graph = load_sampling_graph(\n            test_dir, processed_dataset\n        )\n        assert fused_csc_sampling_graph.total_num_nodes == num_nodes\n        assert fused_csc_sampling_graph.total_num_edges == num_edges\n        assert (\n            fused_csc_sampling_graph.node_attributes is not None\n            and \"feat\" in fused_csc_sampling_graph.node_attributes\n        )\n        assert (\n            fused_csc_sampling_graph.edge_attributes is not None\n            and gb.ORIGINAL_EDGE_ID\n            not in fused_csc_sampling_graph.edge_attributes\n            and \"feat\" in fused_csc_sampling_graph.edge_attributes\n        )\n\n        num_samples = 100\n        fanout = 1\n        subgraph = fused_csc_sampling_graph.sample_neighbors(\n            torch.arange(\n                0,\n                num_samples,\n                dtype=fused_csc_sampling_graph.indices.dtype,\n            ),\n            torch.tensor([fanout]),\n        )\n        assert len(subgraph.sampled_csc.indices) <= num_samples\n\n    with 
tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n            edge_fmt=edge_fmt,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n        # Test generating original_edge_id.\n        output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(\n            test_dir, include_original_edge_id=True\n        )\n        with open(output_file, \"rb\") as f:\n            processed_dataset = yaml.load(f, Loader=yaml.Loader)\n        fused_csc_sampling_graph = load_sampling_graph(\n            test_dir, processed_dataset\n        )\n        assert (\n            fused_csc_sampling_graph.edge_attributes is not None\n            and gb.ORIGINAL_EDGE_ID in fused_csc_sampling_graph.edge_attributes\n        )\n        fused_csc_sampling_graph = None\n\n\n@pytest.mark.parametrize(\"auto_cast\", [False, True])\ndef test_OnDiskDataset_preprocess_homogeneous_hardcode(\n    auto_cast, edge_fmt=\"numpy\"\n):\n    \"\"\"Test preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        \"\"\"Original graph in COO:\n        0   1   1   0   0\n        0   0   1   1   0\n        0   0   0   1   1\n        1   0   0   0   1\n        1   1   0   0   0\n\n        node_feats: [0.0, 1.9, 2.8, 3.7, 4.6]\n        edge_feats: [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]\n        \"\"\"\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 5\n        num_edges = 10\n        num_classes = 1\n\n        # Generate edges.\n        edges = np.array(\n            [[0, 0, 1, 1, 2, 
2, 3, 3, 4, 4], [1, 2, 2, 3, 3, 4, 4, 0, 0, 1]],\n            dtype=np.int64,\n        ).T\n        os.makedirs(os.path.join(test_dir, \"edges\"), exist_ok=True)\n        edges = edges.T\n        edge_path = os.path.join(\"edges\", \"edge.npy\")\n        np.save(os.path.join(test_dir, edge_path), edges)\n\n        # Generate graph edge-feats.\n        edge_feats = np.array(\n            [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9],\n            dtype=np.float64,\n        )\n        os.makedirs(os.path.join(test_dir, \"data\"), exist_ok=True)\n        edge_feat_path = os.path.join(\"data\", \"edge-feat.npy\")\n        np.save(os.path.join(test_dir, edge_feat_path), edge_feats)\n\n        # Generate node-feats.\n        node_feats = np.array(\n            [0.0, 1.9, 2.8, 3.7, 4.6],\n            dtype=np.float64,\n        )\n        node_feat_path = os.path.join(\"data\", \"node-feat.npy\")\n        np.save(os.path.join(test_dir, node_feat_path), node_feats)\n\n        # Generate train/test/valid set.\n        os.makedirs(os.path.join(test_dir, \"set\"), exist_ok=True)\n        train_data = np.array([0, 1, 2, 3, 4])\n        train_path = os.path.join(\"set\", \"train.npy\")\n        np.save(os.path.join(test_dir, train_path), train_data)\n        valid_data = np.array([0, 1, 2, 3, 4])\n        valid_path = os.path.join(\"set\", \"valid.npy\")\n        np.save(os.path.join(test_dir, valid_path), valid_data)\n        test_data = np.array([0, 1, 2, 3, 4])\n        test_path = os.path.join(\"set\", \"test.npy\")\n        np.save(os.path.join(test_dir, test_path), test_data)\n\n        yaml_content = (\n            f\"dataset_name: {dataset_name}\\n\"\n            f\"graph:\\n\"\n            f\"  nodes:\\n\"\n            f\"    - num: {num_nodes}\\n\"\n            f\"  edges:\\n\"\n            f\"    - format: {edge_fmt}\\n\"\n            f\"      path: {edge_path}\\n\"\n            f\"  feature_data:\\n\"\n            f\"    - domain: node\\n\"\n            f\"      
type: null\\n\"\n            f\"      name: feat\\n\"\n            f\"      format: numpy\\n\"\n            f\"      in_memory: true\\n\"\n            f\"      path: {node_feat_path}\\n\"\n            f\"    - domain: edge\\n\"\n            f\"      type: null\\n\"\n            f\"      name: feat\\n\"\n            f\"      format: numpy\\n\"\n            f\"      in_memory: true\\n\"\n            f\"      path: {edge_feat_path}\\n\"\n            f\"feature_data:\\n\"\n            f\"  - domain: node\\n\"\n            f\"    type: null\\n\"\n            f\"    name: feat\\n\"\n            f\"    format: numpy\\n\"\n            f\"    in_memory: true\\n\"\n            f\"    path: {node_feat_path}\\n\"\n            f\"  - domain: edge\\n\"\n            f\"    type: null\\n\"\n            f\"    name: feat\\n\"\n            f\"    format: numpy\\n\"\n            f\"    path: {edge_feat_path}\\n\"\n            f\"tasks:\\n\"\n            f\"  - name: node_classification\\n\"\n            f\"    num_classes: {num_classes}\\n\"\n            f\"    train_set:\\n\"\n            f\"      - type: null\\n\"\n            f\"        data:\\n\"\n            f\"          - name: seeds\\n\"\n            f\"            format: numpy\\n\"\n            f\"            in_memory: true\\n\"\n            f\"            path: {train_path}\\n\"\n            f\"    validation_set:\\n\"\n            f\"      - type: null\\n\"\n            f\"        data:\\n\"\n            f\"          - name: seeds\\n\"\n            f\"            format: numpy\\n\"\n            f\"            in_memory: true\\n\"\n            f\"            path: {valid_path}\\n\"\n            f\"    test_set:\\n\"\n            f\"      - type: null\\n\"\n            f\"        data:\\n\"\n            f\"          - name: seeds\\n\"\n            f\"            format: numpy\\n\"\n            f\"            in_memory: true\\n\"\n            f\"            path: {test_path}\\n\"\n        )\n        yaml_file = 
os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n        output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(\n            test_dir,\n            include_original_edge_id=True,\n            auto_cast_to_optimal_dtype=auto_cast,\n        )\n\n        with open(output_file, \"rb\") as f:\n            processed_dataset = yaml.load(f, Loader=yaml.Loader)\n\n        assert processed_dataset[\"dataset_name\"] == dataset_name\n        assert processed_dataset[\"tasks\"][0][\"num_classes\"] == num_classes\n        assert \"graph\" not in processed_dataset\n        assert \"graph_topology\" in processed_dataset\n\n        fused_csc_sampling_graph = load_sampling_graph(\n            test_dir, processed_dataset\n        )\n        assert fused_csc_sampling_graph.total_num_nodes == num_nodes\n        assert fused_csc_sampling_graph.total_num_edges == num_edges\n        assert torch.equal(\n            fused_csc_sampling_graph.csc_indptr,\n            torch.tensor([0, 2, 4, 6, 8, 10]),\n        )\n        assert torch.equal(\n            fused_csc_sampling_graph.indices,\n            torch.tensor([3, 4, 0, 4, 0, 1, 1, 2, 2, 3]),\n        )\n        assert torch.equal(\n            fused_csc_sampling_graph.node_attributes[\"feat\"],\n            torch.tensor([0.0, 1.9, 2.8, 3.7, 4.6], dtype=torch.float64),\n        )\n        assert torch.equal(\n            fused_csc_sampling_graph.edge_attributes[\"feat\"],\n            torch.tensor(\n                [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9],\n                dtype=torch.float64,\n            ),\n        )\n        assert torch.equal(\n            fused_csc_sampling_graph.edge_attributes[gb.ORIGINAL_EDGE_ID],\n            torch.tensor([7, 8, 0, 9, 1, 2, 3, 4, 5, 6]),\n        )\n\n        expected_dtype = torch.int32 if auto_cast else torch.int64\n        assert fused_csc_sampling_graph.csc_indptr.dtype == expected_dtype\n        assert 
fused_csc_sampling_graph.indices.dtype == expected_dtype\n        assert (\n            fused_csc_sampling_graph.edge_attributes[gb.ORIGINAL_EDGE_ID].dtype\n            == expected_dtype\n        )\n\n        num_samples = 5\n        fanout = 1\n        subgraph = fused_csc_sampling_graph.sample_neighbors(\n            torch.arange(\n                0,\n                num_samples,\n                dtype=fused_csc_sampling_graph.indices.dtype,\n            ),\n            torch.tensor([fanout]),\n        )\n        assert len(subgraph.sampled_csc.indices) <= num_samples\n\n\n@pytest.mark.parametrize(\"auto_cast\", [False, True])\ndef test_OnDiskDataset_preprocess_heterogeneous_hardcode(\n    auto_cast, edge_fmt=\"numpy\"\n):\n    \"\"\"Test preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        \"\"\"Original graph in COO:\n        0   1   1   0   0\n        0   0   1   1   0\n        0   0   0   1   1\n        1   0   0   0   1\n        1   1   0   0   0\n\n        node_type_0: [0, 1]\n        node_type_1: [2, 3, 4]\n        edge_type_0: node_type_0 -> node_type_0\n        edge_type_1: node_type_0 -> node_type_1\n        edge_type_2: node_type_1 -> node_type_1\n        edge_type_3: node_type_1 -> node_type_0\n\n        node_feats: [0.0, 1.9, 2.8, 3.7, 4.6]\n        edge_feats: [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]\n        \"\"\"\n        dataset_name = \"graphbolt_test\"\n        num_nodes = {\n            \"A\": 2,\n            \"B\": 3,\n        }\n        num_edges = {\n            (\"A\", \"a_a\", \"A\"): 1,\n            (\"A\", \"a_b\", \"B\"): 3,\n            (\"B\", \"b_b\", \"A\"): 3,\n            (\"B\", \"b_a\", \"B\"): 3,\n        }\n        num_classes = 1\n\n        # Generate edges.\n        os.makedirs(os.path.join(test_dir, \"edges\"), exist_ok=True)\n        np.save(\n            os.path.join(test_dir, \"edges\", \"a_a.npy\"),\n            np.array([[0], [1]], dtype=np.int64),\n        )\n    
    np.save(\n            os.path.join(test_dir, \"edges\", \"a_b.npy\"),\n            np.array([[0, 1, 1], [0, 0, 1]], dtype=np.int64),\n        )\n        np.save(\n            os.path.join(test_dir, \"edges\", \"b_b.npy\"),\n            np.array([[0, 0, 1], [1, 2, 2]], dtype=np.int64),\n        )\n        np.save(\n            os.path.join(test_dir, \"edges\", \"b_a.npy\"),\n            np.array([[1, 2, 2], [0, 0, 1]], dtype=np.int64),\n        )\n\n        # Generate node features.\n        os.makedirs(os.path.join(test_dir, \"data\"), exist_ok=True)\n        np.save(\n            os.path.join(test_dir, \"data\", \"A-feat.npy\"),\n            np.array([0.0, 1.9], dtype=np.float64),\n        )\n        np.save(\n            os.path.join(test_dir, \"data\", \"B-feat.npy\"),\n            np.array([2.8, 3.7, 4.6], dtype=np.float64),\n        )\n\n        # Generate edge features.\n        os.makedirs(os.path.join(test_dir, \"data\"), exist_ok=True)\n        np.save(\n            os.path.join(test_dir, \"data\", \"a_a-feat.npy\"),\n            np.array([0.0], dtype=np.float64),\n        )\n        np.save(\n            os.path.join(test_dir, \"data\", \"a_b-feat.npy\"),\n            np.array([1.1, 2.2, 3.3], dtype=np.float64),\n        )\n        np.save(\n            os.path.join(test_dir, \"data\", \"b_b-feat.npy\"),\n            np.array([4.4, 5.5, 6.6], dtype=np.float64),\n        )\n        np.save(\n            os.path.join(test_dir, \"data\", \"b_a-feat.npy\"),\n            np.array([7.7, 8.8, 9.9], dtype=np.float64),\n        )\n\n        yaml_content = (\n            f\"dataset_name: {dataset_name}\\n\"\n            f\"graph:\\n\"\n            f\"  nodes:\\n\"\n            f\"    - type: A\\n\"\n            f\"      num: 2\\n\"\n            f\"    - type: B\\n\"\n            f\"      num: 3\\n\"\n            f\"  edges:\\n\"\n            f\"    - type: A:a_a:A\\n\"\n            f\"      format: {edge_fmt}\\n\"\n            f\"      path: 
{os.path.join('edges', 'a_a.npy')}\\n\"\n            f\"    - type: A:a_b:B\\n\"\n            f\"      format: {edge_fmt}\\n\"\n            f\"      path: {os.path.join('edges', 'a_b.npy')}\\n\"\n            f\"    - type: B:b_b:B\\n\"\n            f\"      format: {edge_fmt}\\n\"\n            f\"      path: {os.path.join('edges', 'b_b.npy')}\\n\"\n            f\"    - type: B:b_a:A\\n\"\n            f\"      format: {edge_fmt}\\n\"\n            f\"      path: {os.path.join('edges', 'b_a.npy')}\\n\"\n            f\"  feature_data:\\n\"\n            f\"    - domain: node\\n\"\n            f\"      type: A\\n\"\n            f\"      name: feat\\n\"\n            f\"      format: numpy\\n\"\n            f\"      in_memory: true\\n\"\n            f\"      path: {os.path.join(test_dir, 'data', 'A-feat.npy')}\\n\"\n            f\"    - domain: node\\n\"\n            f\"      type: B\\n\"\n            f\"      name: feat\\n\"\n            f\"      format: numpy\\n\"\n            f\"      in_memory: true\\n\"\n            f\"      path: {os.path.join(test_dir, 'data', 'B-feat.npy')}\\n\"\n            f\"    - domain: edge\\n\"\n            f\"      type: A:a_a:A\\n\"\n            f\"      name: feat\\n\"\n            f\"      format: numpy\\n\"\n            f\"      in_memory: true\\n\"\n            f\"      path: {os.path.join(test_dir, 'data', 'a_a-feat.npy')}\\n\"\n            f\"    - domain: edge\\n\"\n            f\"      type: A:a_b:B\\n\"\n            f\"      name: feat\\n\"\n            f\"      format: numpy\\n\"\n            f\"      in_memory: true\\n\"\n            f\"      path: {os.path.join(test_dir, 'data', 'a_b-feat.npy')}\\n\"\n            f\"    - domain: edge\\n\"\n            f\"      type: B:b_b:B\\n\"\n            f\"      name: feat\\n\"\n            f\"      format: numpy\\n\"\n            f\"      in_memory: true\\n\"\n            f\"      path: {os.path.join(test_dir, 'data', 'b_b-feat.npy')}\\n\"\n            f\"    - domain: edge\\n\"\n        
    f\"      type: B:b_a:A\\n\"\n            f\"      name: feat\\n\"\n            f\"      format: numpy\\n\"\n            f\"      in_memory: true\\n\"\n            f\"      path: {os.path.join(test_dir, 'data', 'b_a-feat.npy')}\\n\"\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n        output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(\n            test_dir,\n            include_original_edge_id=True,\n            auto_cast_to_optimal_dtype=auto_cast,\n        )\n\n        with open(output_file, \"rb\") as f:\n            processed_dataset = yaml.load(f, Loader=yaml.Loader)\n\n        assert processed_dataset[\"dataset_name\"] == dataset_name\n        assert \"graph\" not in processed_dataset\n        assert \"graph_topology\" in processed_dataset\n\n        fused_csc_sampling_graph = load_sampling_graph(\n            test_dir, processed_dataset\n        )\n        assert fused_csc_sampling_graph.total_num_nodes == 5\n        assert fused_csc_sampling_graph.total_num_edges == 10\n        assert torch.equal(\n            fused_csc_sampling_graph.csc_indptr,\n            torch.tensor([0, 2, 4, 6, 8, 10]),\n        )\n        assert torch.equal(\n            fused_csc_sampling_graph.indices,\n            torch.tensor([3, 4, 0, 4, 0, 1, 1, 2, 2, 3]),\n        )\n        assert torch.equal(\n            fused_csc_sampling_graph.node_attributes[\"feat\"],\n            torch.tensor([0.0, 1.9, 2.8, 3.7, 4.6], dtype=torch.float64),\n        )\n        assert torch.equal(\n            fused_csc_sampling_graph.edge_attributes[\"feat\"],\n            torch.tensor(\n                [0.0, 1.1, 2.2, 3.3, 7.7, 8.8, 9.9, 4.4, 5.5, 6.6],\n                dtype=torch.float64,\n            ),\n        )\n        assert torch.equal(\n            fused_csc_sampling_graph.type_per_edge,\n            torch.tensor([2, 2, 0, 2, 1, 1, 1, 3, 3, 3]),\n        )\n        assert 
torch.equal(\n            fused_csc_sampling_graph.edge_attributes[gb.ORIGINAL_EDGE_ID],\n            torch.tensor([0, 1, 0, 2, 0, 1, 2, 0, 1, 2]),\n        )\n        expected_dtype = torch.int32 if auto_cast else torch.int64\n        assert fused_csc_sampling_graph.csc_indptr.dtype == expected_dtype\n        assert fused_csc_sampling_graph.indices.dtype == expected_dtype\n        assert (\n            fused_csc_sampling_graph.edge_attributes[gb.ORIGINAL_EDGE_ID].dtype\n            == expected_dtype\n        )\n        assert fused_csc_sampling_graph.node_type_offset.dtype == expected_dtype\n        expected_etype_dtype = torch.uint8 if auto_cast else torch.int64\n        assert (\n            fused_csc_sampling_graph.type_per_edge.dtype == expected_etype_dtype\n        )\n\n\ndef test_OnDiskDataset_preprocess_path():\n    \"\"\"Test if the preprocess function can catch the path error.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n\n        yaml_content = f\"\"\"\n            dataset_name: {dataset_name}\n        \"\"\"\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        # Case1. Test the passed in is the yaml file path.\n        with pytest.raises(\n            RuntimeError,\n            match=\"The dataset must be a directory. \"\n            rf\"But got {re.escape(yaml_file)}\",\n        ):\n            _ = gb.OnDiskDataset(yaml_file)\n\n        # Case2. Test the passed in is a fake directory.\n        fake_dir = os.path.join(test_dir, \"fake_dir\")\n        with pytest.raises(\n            RuntimeError,\n            match=rf\"Invalid dataset path: {re.escape(fake_dir)}\",\n        ):\n            _ = gb.OnDiskDataset(fake_dir)\n\n        # Case3. 
Test the passed in is the dataset directory.\n        # But the metadata.yaml is not in the directory.\n        os.makedirs(os.path.join(test_dir, \"fake_dir\"), exist_ok=True)\n        with pytest.raises(\n            RuntimeError,\n            match=r\"metadata.yaml does not exist.\",\n        ):\n            _ = gb.OnDiskDataset(fake_dir)\n\n\ndef test_OnDiskDataset_preprocess_yaml_content():\n    \"\"\"Test if the preprocessed metadata.yaml is correct.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random edges.\n        nodes = np.repeat(np.arange(num_nodes), 5)\n        neighbors = np.random.randint(0, num_nodes, size=(num_edges))\n        edges = np.stack([nodes, neighbors], axis=1)\n        # Write into edges/edge.csv\n        os.makedirs(os.path.join(test_dir, \"edges\"), exist_ok=True)\n        edges = pd.DataFrame(edges, columns=[\"src\", \"dst\"])\n        edge_path = os.path.join(\"edges\", \"edge.csv\")\n        edges.to_csv(\n            os.path.join(test_dir, edge_path),\n            index=False,\n            header=False,\n        )\n\n        # Generate random graph edge-feats.\n        edge_feats = np.random.rand(num_edges, 5)\n        os.makedirs(os.path.join(test_dir, \"data\"), exist_ok=True)\n        feature_edge = os.path.join(\"data\", \"edge-feat.npy\")\n        np.save(os.path.join(test_dir, feature_edge), edge_feats)\n\n        # Generate random node-feats.\n        node_feats = np.random.rand(num_nodes, 10)\n        feature_node = os.path.join(\"data\", \"node-feat.npy\")\n        np.save(os.path.join(test_dir, feature_node), node_feats)\n\n        # Generate train/test/valid set.\n        os.makedirs(os.path.join(test_dir, \"set\"), exist_ok=True)\n        train_pairs = (np.arange(1000), np.arange(1000, 2000))\n        train_labels = 
np.random.randint(0, 10, size=1000)\n        train_data = np.vstack([train_pairs, train_labels]).T\n        train_path = os.path.join(\"set\", \"train.npy\")\n        np.save(os.path.join(test_dir, train_path), train_data)\n\n        validation_pairs = (np.arange(1000, 2000), np.arange(2000, 3000))\n        validation_labels = np.random.randint(0, 10, size=1000)\n        validation_data = np.vstack([validation_pairs, validation_labels]).T\n        validation_path = os.path.join(\"set\", \"validation.npy\")\n        np.save(os.path.join(test_dir, validation_path), validation_data)\n\n        test_pairs = (np.arange(2000, 3000), np.arange(3000, 4000))\n        test_labels = np.random.randint(0, 10, size=1000)\n        test_data = np.vstack([test_pairs, test_labels]).T\n        test_path = os.path.join(\"set\", \"test.npy\")\n        np.save(os.path.join(test_dir, test_path), test_data)\n\n        yaml_content = f\"\"\"\n            dataset_name: {dataset_name}\n            graph: # graph structure and required attributes.\n                nodes:\n                    - num: {num_nodes}\n                edges:\n                    - format: csv\n                      path: {edge_path}\n                feature_data:\n                    - domain: edge\n                      type: null\n                      name: feat\n                      format: numpy\n                      in_memory: true\n                      path: {feature_edge}\n            feature_data:\n                - domain: node\n                  type: null\n                  name: feat\n                  format: numpy\n                  in_memory: false\n                  path: {feature_node}\n            tasks:\n              - name: node_classification\n                num_classes: {num_classes}\n                train_set:\n                  - type: null\n                    data:\n                      - format: numpy\n                        path: {train_path}\n                validation_set:\n      
            - type: null\n                    data:\n                      - format: numpy\n                        path: {validation_path}\n                test_set:\n                  - type: null\n                    data:\n                      - format: numpy\n                        path: {test_path}\n        \"\"\"\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        preprocessed_metadata_path = gb.preprocess_ondisk_dataset(test_dir)\n        with open(preprocessed_metadata_path, \"r\") as f:\n            yaml_data = yaml.safe_load(f)\n\n        topo_path = os.path.join(\"preprocessed\", \"fused_csc_sampling_graph.pt\")\n        target_yaml_content = f\"\"\"\n            dataset_name: {dataset_name}\n            graph_topology:\n              type: FusedCSCSamplingGraph\n              path: {topo_path}\n            feature_data:\n              - domain: node\n                type: null\n                name: feat\n                format: numpy\n                in_memory: false\n                path: {os.path.join(\"preprocessed\", feature_node)}\n            tasks:\n              - name: node_classification\n                num_classes: {num_classes}\n                train_set:\n                  - type: null\n                    data:\n                      - format: numpy\n                        path: {os.path.join(\"preprocessed\", train_path)}\n                validation_set:\n                  - type: null\n                    data:\n                      - format: numpy\n                        path: {os.path.join(\"preprocessed\", validation_path)}\n                test_set:\n                  - type: null\n                    data:\n                      - format: numpy\n                        path: {os.path.join(\"preprocessed\", test_path)}\n            include_original_edge_id: False\n        \"\"\"\n        target_yaml_data = 
yaml.safe_load(target_yaml_content)\n        # Check yaml content.\n        assert (\n            yaml_data == target_yaml_data\n        ), \"The preprocessed metadata.yaml is not correct.\"\n\n        # Check file existence.\n        assert os.path.exists(\n            os.path.join(test_dir, yaml_data[\"graph_topology\"][\"path\"])\n        )\n        assert os.path.exists(\n            os.path.join(test_dir, yaml_data[\"feature_data\"][0][\"path\"])\n        )\n        for set_name in [\"train_set\", \"validation_set\", \"test_set\"]:\n            assert os.path.exists(\n                os.path.join(\n                    test_dir,\n                    yaml_data[\"tasks\"][0][set_name][0][\"data\"][0][\"path\"],\n                )\n            )\n\n\ndef test_OnDiskDataset_preprocess_force_preprocess(capsys):\n    \"\"\"Test force preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        # First preprocess on-disk dataset.\n        preprocessed_metadata_path = (\n            gb.ondisk_dataset.preprocess_ondisk_dataset(\n                test_dir, include_original_edge_id=False, force_preprocess=False\n            )\n        )\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the on-disk dataset.\",\n            \"\",\n        ]\n        with 
open(preprocessed_metadata_path, \"r\") as f:\n            target_yaml_data = yaml.safe_load(f)\n        assert target_yaml_data[\"tasks\"][0][\"name\"] == \"link_prediction\"\n\n        # Change yaml_data, but do not force preprocess on-disk dataset.\n        with open(yaml_file, \"r\") as f:\n            yaml_data = yaml.safe_load(f)\n        yaml_data[\"tasks\"][0][\"name\"] = \"fake_name\"\n        with open(yaml_file, \"w\") as f:\n            yaml.dump(yaml_data, f)\n        preprocessed_metadata_path = (\n            gb.ondisk_dataset.preprocess_ondisk_dataset(\n                test_dir, include_original_edge_id=False, force_preprocess=False\n            )\n        )\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\"The dataset is already preprocessed.\", \"\"]\n        with open(preprocessed_metadata_path, \"r\") as f:\n            target_yaml_data = yaml.safe_load(f)\n        assert target_yaml_data[\"tasks\"][0][\"name\"] == \"link_prediction\"\n\n        # Force preprocess on-disk dataset.\n        preprocessed_metadata_path = (\n            gb.ondisk_dataset.preprocess_ondisk_dataset(\n                test_dir, include_original_edge_id=False, force_preprocess=True\n            )\n        )\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"The on-disk dataset is re-preprocessing, so the existing \"\n            + \"preprocessed dataset has been removed.\",\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the on-disk dataset.\",\n            \"\",\n        ]\n        with open(preprocessed_metadata_path, \"r\") as f:\n            target_yaml_data = yaml.safe_load(f)\n        assert target_yaml_data[\"tasks\"][0][\"name\"] == \"fake_name\"\n\n\ndef test_OnDiskDataset_preprocess_auto_force_preprocess(capsys):\n    \"\"\"Test force preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n  
      # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        # First preprocess on-disk dataset.\n        preprocessed_metadata_path = (\n            gb.ondisk_dataset.preprocess_ondisk_dataset(\n                test_dir, include_original_edge_id=False\n            )\n        )\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the on-disk dataset.\",\n            \"\",\n        ]\n        with open(preprocessed_metadata_path, \"r\") as f:\n            target_yaml_data = yaml.safe_load(f)\n        assert target_yaml_data[\"tasks\"][0][\"name\"] == \"link_prediction\"\n\n        # 1. 
Change yaml_data.\n        with open(yaml_file, \"r\") as f:\n            yaml_data = yaml.safe_load(f)\n        yaml_data[\"tasks\"][0][\"name\"] = \"fake_name\"\n        with open(yaml_file, \"w\") as f:\n            yaml.dump(yaml_data, f)\n        preprocessed_metadata_path = (\n            gb.ondisk_dataset.preprocess_ondisk_dataset(\n                test_dir, include_original_edge_id=False\n            )\n        )\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"The on-disk dataset is re-preprocessing, so the existing \"\n            + \"preprocessed dataset has been removed.\",\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the on-disk dataset.\",\n            \"\",\n        ]\n        with open(preprocessed_metadata_path, \"r\") as f:\n            target_yaml_data = yaml.safe_load(f)\n        assert target_yaml_data[\"tasks\"][0][\"name\"] == \"fake_name\"\n\n        # 2. Change edge feature.\n        edge_feats = np.random.rand(num_edges, num_classes)\n        edge_feat_path = os.path.join(\"data\", \"edge-feat.npy\")\n        np.save(os.path.join(test_dir, edge_feat_path), edge_feats)\n        preprocessed_metadata_path = (\n            gb.ondisk_dataset.preprocess_ondisk_dataset(\n                test_dir, include_original_edge_id=False\n            )\n        )\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"The on-disk dataset is re-preprocessing, so the existing \"\n            + \"preprocessed dataset has been removed.\",\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the on-disk dataset.\",\n            \"\",\n        ]\n        preprocessed_edge_feat = np.load(\n            os.path.join(test_dir, \"preprocessed\", edge_feat_path)\n        )\n        assert preprocessed_edge_feat.all() == edge_feats.all()\n        with 
open(preprocessed_metadata_path, \"r\") as f:\n            target_yaml_data = yaml.safe_load(f)\n        assert target_yaml_data[\"include_original_edge_id\"] == False\n\n        # 3. Change include_original_edge_id.\n        preprocessed_metadata_path = (\n            gb.ondisk_dataset.preprocess_ondisk_dataset(\n                test_dir, include_original_edge_id=True\n            )\n        )\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"The on-disk dataset is re-preprocessing, so the existing \"\n            + \"preprocessed dataset has been removed.\",\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the on-disk dataset.\",\n            \"\",\n        ]\n        with open(preprocessed_metadata_path, \"r\") as f:\n            target_yaml_data = yaml.safe_load(f)\n        assert target_yaml_data[\"include_original_edge_id\"] == True\n\n        # 4. Change nothing.\n        preprocessed_metadata_path = (\n            gb.ondisk_dataset.preprocess_ondisk_dataset(\n                test_dir, include_original_edge_id=True\n            )\n        )\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\"The dataset is already preprocessed.\", \"\"]\n\n\ndef test_OnDiskDataset_preprocess_not_include_eids():\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        with pytest.warns(\n            
GBWarning,\n            match=\"Edge feature is stored, but edge IDs are not saved.\",\n        ):\n            gb.ondisk_dataset.preprocess_ondisk_dataset(\n                test_dir, include_original_edge_id=False\n            )\n\n\n@pytest.mark.parametrize(\"edge_fmt\", [\"csv\", \"numpy\"])\ndef test_OnDiskDataset_load_name(edge_fmt):\n    \"\"\"Test preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n            edge_fmt=edge_fmt,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        # Check modify `dataset_name` field.\n        dataset = gb.OnDiskDataset(test_dir)\n        dataset.yaml_data[\"dataset_name\"] = \"fake_name\"\n        dataset.load()\n        assert dataset.dataset_name == \"fake_name\"\n        dataset = None\n\n\n@pytest.mark.parametrize(\"edge_fmt\", [\"csv\", \"numpy\"])\ndef test_OnDiskDataset_load_feature(edge_fmt):\n    \"\"\"Test preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n            edge_fmt=edge_fmt,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with 
open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        # Case1. Test modify the `in_memory` field.\n        dataset = gb.OnDiskDataset(test_dir).load()\n        original_feature_data = dataset.feature\n        dataset.yaml_data[\"feature_data\"][0][\"in_memory\"] = True\n        load_dataset(dataset)\n        modify_feature_data = dataset.feature\n        # After modify the `in_memory` field, the feature data should be\n        # equal.\n        assert torch.equal(\n            original_feature_data.read(\"node\", None, \"feat\"),\n            modify_feature_data.read(\"node\", None, \"feat\"),\n        )\n\n        # Case2. Test modify the `format` field.\n        dataset = gb.OnDiskDataset(test_dir)\n        # If `format` is torch and `in_memory` is False, it will\n        # raise an AssertionError.\n        dataset.yaml_data[\"feature_data\"][0][\"in_memory\"] = False\n        dataset.yaml_data[\"feature_data\"][0][\"format\"] = \"torch\"\n        with pytest.raises(\n            AssertionError,\n            match=\"^Pytorch tensor can only be loaded in memory,\",\n        ):\n            load_dataset(dataset)\n\n        dataset = gb.OnDiskDataset(test_dir)\n        dataset.yaml_data[\"feature_data\"][0][\"in_memory\"] = True\n        dataset.yaml_data[\"feature_data\"][0][\"format\"] = \"torch\"\n        # If `format` is torch and `in_memory` is True, it will\n        # raise an UnpicklingError.\n        with pytest.raises(pickle.UnpicklingError):\n            load_dataset(dataset)\n\n        # Case3. 
Test modify the `path` field.\n        dataset = gb.OnDiskDataset(test_dir)\n        # Use invalid path will raise an FileNotFoundError.\n        dataset.yaml_data[\"feature_data\"][0][\"path\"] = \"fake_path\"\n        with pytest.raises(\n            FileNotFoundError,\n            match=r\"\\[Errno 2\\] No such file or directory:\",\n        ):\n            load_dataset(dataset)\n        # Modifying the `path` field to an absolute path should work.\n        # In os.path.join, if a segment is an absolute path (which\n        # on Windows requires both a drive and a root), then all\n        # previous segments are ignored and joining continues from\n        # the absolute path segment.\n        dataset = load_dataset(gb.OnDiskDataset(test_dir))\n        original_feature_data = dataset.feature\n        dataset.yaml_data[\"feature_data\"][0][\"path\"] = os.path.join(\n            test_dir, dataset.yaml_data[\"feature_data\"][0][\"path\"]\n        )\n        load_dataset(dataset)\n        modify_feature_data = dataset.feature\n        assert torch.equal(\n            original_feature_data.read(\"node\", None, \"feat\"),\n            modify_feature_data.read(\"node\", None, \"feat\"),\n        )\n        original_feature_data = None\n        modify_feature_data = None\n        dataset = None\n\n\n@pytest.mark.parametrize(\"edge_fmt\", [\"csv\", \"numpy\"])\ndef test_OnDiskDataset_load_graph(edge_fmt):\n    \"\"\"Test preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n            edge_fmt=edge_fmt,\n        )\n        yaml_file = os.path.join(test_dir, 
\"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        # Check the different original_edge_id option to load edge_attributes.\n        dataset = gb.OnDiskDataset(\n            test_dir, include_original_edge_id=True\n        ).load()\n        assert (\n            dataset.graph.edge_attributes is not None\n            and gb.ORIGINAL_EDGE_ID in dataset.graph.edge_attributes\n        )\n\n        # Case1. Test modify the `type` field.\n        dataset = gb.OnDiskDataset(test_dir)\n        dataset.yaml_data[\"graph_topology\"][\"type\"] = \"fake_type\"\n        with pytest.raises(\n            pydantic.ValidationError,\n            # As error message diffs in pydantic 1.x and 2.x, we just match\n            # keyword only.\n            match=\"'FusedCSCSamplingGraph'\",\n        ):\n            dataset.load()\n\n        # Case2. Test modify the `path` field.\n        dataset = gb.OnDiskDataset(test_dir)\n        dataset.yaml_data[\"graph_topology\"][\"path\"] = \"fake_path\"\n        with pytest.raises(\n            FileNotFoundError,\n            match=r\"\\[Errno 2\\] No such file or directory:\",\n        ):\n            dataset.load()\n        # Modifying the `path` field to an absolute path should work.\n        # In os.path.join, if a segment is an absolute path (which\n        # on Windows requires both a drive and a root), then all\n        # previous segments are ignored and joining continues from\n        # the absolute path segment.\n        dataset = gb.OnDiskDataset(test_dir).load()\n        original_graph = dataset.graph\n        dataset.yaml_data[\"graph_topology\"][\"path\"] = os.path.join(\n            test_dir, dataset.yaml_data[\"graph_topology\"][\"path\"]\n        )\n        dataset.load()\n        modify_graph = dataset.graph\n        assert torch.equal(\n            original_graph.csc_indptr,\n            modify_graph.csc_indptr,\n        )\n        original_graph = None\n        modify_graph 
= None\n        dataset = None\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n            edge_fmt=edge_fmt,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        # Test do not generate original_edge_id.\n        dataset = gb.OnDiskDataset(\n            test_dir, include_original_edge_id=False\n        ).load()\n        assert (\n            dataset.graph.edge_attributes is None\n            or gb.ORIGINAL_EDGE_ID not in dataset.graph.edge_attributes\n        )\n        dataset = None\n\n\n@pytest.mark.parametrize(\"edge_fmt\", [\"csv\", \"numpy\"])\ndef test_OnDiskDataset_load_tasks(edge_fmt):\n    \"\"\"Test preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n            edge_fmt=edge_fmt,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        # Case1. 
Test modify the `name` field.\n        dataset = gb.OnDiskDataset(test_dir)\n        dataset.yaml_data[\"tasks\"][0][\"name\"] = \"fake_name\"\n        dataset.load()\n        assert dataset.tasks[0].metadata[\"name\"] == \"fake_name\"\n\n        # Case2. Test modify the `num_classes` field.\n        dataset = gb.OnDiskDataset(test_dir)\n        dataset.yaml_data[\"tasks\"][0][\"num_classes\"] = 100\n        dataset.load()\n        assert dataset.tasks[0].metadata[\"num_classes\"] == 100\n\n        # Case3. Test modify the `format` field.\n        dataset = gb.OnDiskDataset(test_dir)\n        # Change the `format` field to torch.\n        dataset.yaml_data[\"tasks\"][0][\"train_set\"][0][\"data\"][0][\n            \"format\"\n        ] = \"torch\"\n        with pytest.raises(pickle.UnpicklingError):\n            dataset.load()\n\n        dataset = gb.OnDiskDataset(test_dir)\n        dataset.yaml_data[\"tasks\"][0][\"train_set\"][0][\"data\"][0][\n            \"format\"\n        ] = \"torch\"\n        # Change the `in_memory` field to False will also raise an\n        # UnpicklingError. Unlike the case of testing `feature_data`.\n        dataset.yaml_data[\"tasks\"][0][\"train_set\"][0][\"data\"][0][\n            \"in_memory\"\n        ] = False\n        with pytest.raises(pickle.UnpicklingError):\n            dataset.load()\n\n        # Case4. 
Test modify the `path` field.\n        dataset = gb.OnDiskDataset(test_dir)\n        # Use invalid path will raise an FileNotFoundError.\n        dataset.yaml_data[\"tasks\"][0][\"train_set\"][0][\"data\"][0][\n            \"path\"\n        ] = \"fake_path\"\n        with pytest.raises(\n            FileNotFoundError,\n            match=r\"\\[Errno 2\\] No such file or directory:\",\n        ):\n            dataset.load()\n\n        # Modifying the `path` field to an absolute path should work.\n        # In os.path.join, if a segment is an absolute path (which\n        # on Windows requires both a drive and a root), then all\n        # previous segments are ignored and joining continues from\n        # the absolute path segment.\n        dataset = gb.OnDiskDataset(test_dir).load()\n        original_train_set = dataset.tasks[0].train_set._items\n        dataset.yaml_data[\"tasks\"][0][\"train_set\"][0][\"data\"][0][\n            \"path\"\n        ] = os.path.join(\n            test_dir,\n            dataset.yaml_data[\"tasks\"][0][\"train_set\"][0][\"data\"][0][\"path\"],\n        )\n        dataset.load()\n        modify_train_set = dataset.tasks[0].train_set._items\n        assert torch.equal(\n            original_train_set[0],\n            modify_train_set[0],\n        )\n        original_train_set = None\n        modify_train_set = None\n        dataset = None\n\n\ndef test_OnDiskDataset_all_nodes_set_homo():\n    \"\"\"Test homograph's all nodes set of OnDiskDataset.\"\"\"\n    csc_indptr, indices = gbt.random_homo_graph(1000, 10 * 1000)\n    graph = gb.fused_csc_sampling_graph(csc_indptr, indices)\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        graph_path = os.path.join(test_dir, \"fused_csc_sampling_graph.pt\")\n        torch.save(graph, graph_path)\n\n        yaml_content = f\"\"\"\n            graph_topology:\n              type: FusedCSCSamplingGraph\n              path: {graph_path}\n        \"\"\"\n        dataset = 
write_yaml_and_load_dataset(yaml_content, test_dir)\n        all_nodes_set = dataset.all_nodes_set\n        assert isinstance(all_nodes_set, gb.ItemSet)\n        assert all_nodes_set.names == (\"seeds\",)\n        for i, item in enumerate(all_nodes_set):\n            assert i == item\n\n        dataset = None\n\n\ndef test_OnDiskDataset_all_nodes_set_hetero():\n    \"\"\"Test heterograph's all nodes set of OnDiskDataset.\"\"\"\n    (\n        csc_indptr,\n        indices,\n        node_type_offset,\n        type_per_edge,\n        node_type_to_id,\n        edge_type_to_id,\n    ) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=node_type_to_id,\n        edge_type_to_id=edge_type_to_id,\n        edge_attributes=None,\n    )\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        graph_path = os.path.join(test_dir, \"fused_csc_sampling_graph.pt\")\n        torch.save(graph, graph_path)\n\n        yaml_content = f\"\"\"\n            graph_topology:\n              type: FusedCSCSamplingGraph\n              path: {graph_path}\n        \"\"\"\n        dataset = write_yaml_and_load_dataset(yaml_content, test_dir)\n        all_nodes_set = dataset.all_nodes_set\n        assert isinstance(all_nodes_set, gb.HeteroItemSet)\n        assert all_nodes_set.names == (\"seeds\",)\n        for i, item in enumerate(all_nodes_set):\n            assert len(item) == 1\n            assert isinstance(item, dict)\n\n        dataset = None\n\n\n@pytest.mark.parametrize(\"fmt\", [\"numpy\", \"torch\"])\ndef test_OnDiskDataset_load_1D_feature(fmt):\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4\n        num_edges = 20\n        num_classes = 1\n\n        type_name = \"npy\" if fmt == 
\"numpy\" else \"pt\"\n        # Generate random edges.\n        nodes = np.repeat(np.arange(num_nodes), 5)\n        neighbors = np.random.randint(0, num_nodes, size=(num_edges))\n        edges = np.stack([nodes, neighbors], axis=1)\n        # Write into edges/edge.csv\n        os.makedirs(os.path.join(test_dir, \"edges\"), exist_ok=True)\n        edges = pd.DataFrame(edges, columns=[\"src\", \"dst\"])\n        edge_path = os.path.join(\"edges\", \"edge.csv\")\n        edges.to_csv(\n            os.path.join(test_dir, edge_path),\n            index=False,\n            header=False,\n        )\n\n        # Generate random graph edge-feats.\n        edge_feats = np.random.rand(num_edges, 5)\n        os.makedirs(os.path.join(test_dir, \"data\"), exist_ok=True)\n        edge_feat_path = os.path.join(\"data\", f\"edge-feat.{type_name}\")\n\n        # Generate random 1-D node-feats.\n        node_feats = np.random.rand(num_nodes)\n        node_feat_path = os.path.join(\"data\", f\"node-feat.{type_name}\")\n        assert node_feats.ndim == 1\n\n        # Generate 1-D train set.\n        os.makedirs(os.path.join(test_dir, \"set\"), exist_ok=True)\n        train_path = os.path.join(\"set\", f\"train.{type_name}\")\n\n        if fmt == \"numpy\":\n            np.save(os.path.join(test_dir, edge_feat_path), edge_feats)\n            np.save(os.path.join(test_dir, node_feat_path), node_feats)\n            np.save(os.path.join(test_dir, train_path), np.array([0, 1, 0]))\n        else:\n            torch.save(\n                torch.from_numpy(edge_feats),\n                os.path.join(test_dir, edge_feat_path),\n            )\n            torch.save(\n                torch.from_numpy(node_feats),\n                os.path.join(test_dir, node_feat_path),\n            )\n            torch.save(\n                torch.tensor([0, 1, 0]), os.path.join(test_dir, train_path)\n            )\n\n        yaml_content = f\"\"\"\n            dataset_name: {dataset_name}\n            graph: # 
graph structure and required attributes.\n              nodes:\n                - num: {num_nodes}\n              edges:\n                - format: csv\n                  path: {edge_path}\n              feature_data:\n                  - domain: edge\n                    type: null\n                    name: feat\n                    format: {fmt}\n                    in_memory: true\n                    path: {edge_feat_path}\n            feature_data:\n              - domain: node\n                type: null\n                name: feat\n                format: {fmt}\n                in_memory: false\n                path: {node_feat_path}\n            tasks:\n                - name: node_classification\n                  num_classes: {num_classes}\n                  train_set:\n                    - type: null\n                      data:\n                        - format: {fmt}\n                          path: {train_path}\n        \"\"\"\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        dataset = gb.OnDiskDataset(test_dir).load()\n        feature = dataset.feature.read(\"node\", None, \"feat\")\n        # Test whether feature has changed.\n        assert torch.equal(torch.from_numpy(node_feats.reshape(-1, 1)), feature)\n        # Test whether itemsets keep same.\n        assert torch.equal(\n            dataset.tasks[0].train_set._items[0], torch.tensor([0, 1, 0])\n        )\n        dataset = None\n        node_feats = None\n        feature = None\n\n\ndef test_BuiltinDataset():\n    \"\"\"Test BuiltinDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # Case 1: download from DGL S3 storage.\n        dataset_name = \"test-dataset-231207\"\n        # Add dataset to the builtin dataset list for testing only. 
Due to we\n        # add `seeds` suffix to datasets when downloading, so we append\n        # dataset name with `-seeds` suffix here.\n        gb.BuiltinDataset._all_datasets.append(dataset_name + \"-seeds\")\n        dataset = gb.BuiltinDataset(name=dataset_name, root=test_dir).load()\n        assert dataset.graph is not None\n        assert dataset.feature is not None\n        assert dataset.tasks is not None\n        assert dataset.dataset_name == dataset_name\n\n        # Case 2: dataset is already downloaded.\n        dataset = gb.BuiltinDataset(name=dataset_name, root=test_dir).load()\n        assert dataset.graph is not None\n        assert dataset.feature is not None\n        assert dataset.tasks is not None\n        assert dataset.dataset_name == dataset_name\n\n        dataset = None\n\n        # Case 3: dataset is not available.\n        dataset_name = \"fake_name-seeds\"\n        with pytest.raises(\n            RuntimeError,\n            match=rf\"Dataset {dataset_name} is not available.*\",\n        ):\n            _ = gb.BuiltinDataset(name=dataset_name, root=test_dir).load()\n\n\n@pytest.mark.parametrize(\"auto_cast\", [True, False])\n@pytest.mark.parametrize(\"include_original_edge_id\", [True, False])\n@pytest.mark.parametrize(\"edge_fmt\", [\"csv\", \"numpy\"])\ndef test_OnDiskDataset_homogeneous(\n    auto_cast, include_original_edge_id, edge_fmt\n):\n    \"\"\"Preprocess and instantiate OnDiskDataset for homogeneous graph.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n            edge_fmt=edge_fmt,\n        )\n        yaml_file = os.path.join(test_dir, 
\"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        dataset = gb.OnDiskDataset(\n            test_dir,\n            include_original_edge_id=include_original_edge_id,\n            auto_cast_to_optimal_dtype=auto_cast,\n        ).load()\n\n        assert dataset.dataset_name == dataset_name\n\n        graph = dataset.graph\n        assert isinstance(graph, gb.FusedCSCSamplingGraph)\n        assert graph.total_num_nodes == num_nodes\n        assert graph.total_num_edges == num_edges\n        assert (\n            graph.node_attributes is not None\n            and \"feat\" in graph.node_attributes\n        )\n        assert (\n            graph.edge_attributes is not None\n            and \"feat\" in graph.edge_attributes\n        )\n        assert (\n            not include_original_edge_id\n        ) or gb.ORIGINAL_EDGE_ID in graph.edge_attributes\n\n        tasks = dataset.tasks\n        assert len(tasks) == 1\n        assert isinstance(tasks[0].train_set, gb.ItemSet)\n        assert isinstance(tasks[0].validation_set, gb.ItemSet)\n        assert isinstance(tasks[0].test_set, gb.ItemSet)\n        assert tasks[0].train_set._items[0].dtype == graph.indices.dtype\n        assert tasks[0].validation_set._items[0].dtype == graph.indices.dtype\n        assert tasks[0].test_set._items[0].dtype == graph.indices.dtype\n        assert dataset.all_nodes_set._items.dtype == graph.indices.dtype\n        assert tasks[0].metadata[\"num_classes\"] == num_classes\n        assert tasks[0].metadata[\"name\"] == \"link_prediction\"\n\n        assert dataset.feature.size(\"node\", None, \"feat\")[0] == num_classes\n        assert dataset.feature.size(\"edge\", None, \"feat\")[0] == num_classes\n\n        for itemset in [\n            tasks[0].train_set,\n            tasks[0].validation_set,\n            tasks[0].test_set,\n            dataset.all_nodes_set,\n        ]:\n            datapipe = gb.ItemSampler(itemset, 
batch_size=10)\n            datapipe = datapipe.sample_neighbor(graph, [-1])\n            datapipe = datapipe.fetch_feature(\n                dataset.feature, node_feature_keys=[\"feat\"]\n            )\n            dataloader = gb.DataLoader(datapipe)\n            for _ in dataloader:\n                pass\n\n        graph = None\n        tasks = None\n        dataset = None\n\n\n@pytest.mark.parametrize(\"auto_cast\", [True, False])\n@pytest.mark.parametrize(\"include_original_edge_id\", [True, False])\n@pytest.mark.parametrize(\"edge_fmt\", [\"csv\", \"numpy\"])\ndef test_OnDiskDataset_heterogeneous(\n    auto_cast, include_original_edge_id, edge_fmt\n):\n    \"\"\"Preprocess and instantiate OnDiskDataset for heterogeneous graph.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        dataset_name = \"OnDiskDataset_hetero\"\n        num_nodes = {\n            \"user\": 1000,\n            \"item\": 2000,\n        }\n        num_edges = {\n            (\"user\", \"follow\", \"user\"): 10000,\n            (\"user\", \"click\", \"item\"): 20000,\n        }\n        num_classes = 10\n        gbt.generate_raw_data_for_hetero_dataset(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n            edge_fmt=edge_fmt,\n        )\n\n        dataset = gb.OnDiskDataset(\n            test_dir,\n            include_original_edge_id=include_original_edge_id,\n            auto_cast_to_optimal_dtype=auto_cast,\n        ).load()\n\n        assert dataset.dataset_name == dataset_name\n\n        graph = dataset.graph\n        assert isinstance(graph, gb.FusedCSCSamplingGraph)\n        assert graph.total_num_nodes == sum(\n            num_nodes for num_nodes in num_nodes.values()\n        )\n        assert graph.total_num_edges == sum(\n            num_edge for num_edge in num_edges.values()\n        )\n        expected_dtype = torch.int32 if auto_cast else torch.int64\n        assert 
graph.indices.dtype == expected_dtype\n        assert (\n            graph.node_attributes is not None\n            and \"feat\" in graph.node_attributes\n        )\n        assert (\n            graph.edge_attributes is not None\n            and \"feat\" in graph.edge_attributes\n        )\n        assert (\n            not include_original_edge_id\n        ) or gb.ORIGINAL_EDGE_ID in graph.edge_attributes\n\n        tasks = dataset.tasks\n        assert len(tasks) == 1\n        assert isinstance(tasks[0].train_set, gb.HeteroItemSet)\n        assert isinstance(tasks[0].validation_set, gb.HeteroItemSet)\n        assert isinstance(tasks[0].test_set, gb.HeteroItemSet)\n        assert tasks[0].metadata[\"num_classes\"] == num_classes\n        assert tasks[0].metadata[\"name\"] == \"node_classification\"\n\n        assert dataset.feature.size(\"node\", \"user\", \"feat\")[0] == num_classes\n        assert dataset.feature.size(\"node\", \"item\", \"feat\")[0] == num_classes\n\n        for itemset in [\n            tasks[0].train_set,\n            tasks[0].validation_set,\n            tasks[0].test_set,\n            dataset.all_nodes_set,\n        ]:\n            datapipe = gb.ItemSampler(itemset, batch_size=10)\n            datapipe = datapipe.sample_neighbor(graph, [-1])\n            datapipe = datapipe.fetch_feature(\n                dataset.feature, node_feature_keys={\"user\": [\"feat\"]}\n            )\n            dataloader = gb.DataLoader(datapipe)\n            for _ in dataloader:\n                pass\n\n        graph = None\n        tasks = None\n        dataset = None\n\n\ndef test_OnDiskDataset_force_preprocess(capsys):\n    \"\"\"Test force preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = 
gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        # First preprocess on-disk dataset.\n        dataset = gb.OnDiskDataset(\n            test_dir, include_original_edge_id=False, force_preprocess=False\n        ).load()\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the on-disk dataset.\",\n            \"\",\n        ]\n        tasks = dataset.tasks\n        assert tasks[0].metadata[\"name\"] == \"link_prediction\"\n\n        # Change yaml_data, but do not force preprocess on-disk dataset.\n        with open(yaml_file, \"r\") as f:\n            yaml_data = yaml.safe_load(f)\n        yaml_data[\"tasks\"][0][\"name\"] = \"fake_name\"\n        with open(yaml_file, \"w\") as f:\n            yaml.dump(yaml_data, f)\n        dataset = gb.OnDiskDataset(\n            test_dir, include_original_edge_id=False, force_preprocess=False\n        ).load()\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\"The dataset is already preprocessed.\", \"\"]\n        tasks = dataset.tasks\n        assert tasks[0].metadata[\"name\"] == \"link_prediction\"\n\n        # Force preprocess on-disk dataset.\n        dataset = gb.OnDiskDataset(\n            test_dir, include_original_edge_id=False, force_preprocess=True\n        ).load()\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"The on-disk dataset is re-preprocessing, so the existing \"\n            + \"preprocessed dataset has been removed.\",\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the 
on-disk dataset.\",\n            \"\",\n        ]\n        tasks = dataset.tasks\n        assert tasks[0].metadata[\"name\"] == \"fake_name\"\n\n        tasks = None\n        dataset = None\n\n\ndef test_OnDiskDataset_auto_force_preprocess(capsys):\n    \"\"\"Test force preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        # First preprocess on-disk dataset.\n        dataset = gb.OnDiskDataset(\n            test_dir, include_original_edge_id=False\n        ).load()\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the on-disk dataset.\",\n            \"\",\n        ]\n        tasks = dataset.tasks\n        assert tasks[0].metadata[\"name\"] == \"link_prediction\"\n\n        # 1. 
Change yaml_data.\n        with open(yaml_file, \"r\") as f:\n            yaml_data = yaml.safe_load(f)\n        yaml_data[\"tasks\"][0][\"name\"] = \"fake_name\"\n        with open(yaml_file, \"w\") as f:\n            yaml.dump(yaml_data, f)\n        dataset = gb.OnDiskDataset(\n            test_dir, include_original_edge_id=False\n        ).load()\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"The on-disk dataset is re-preprocessing, so the existing \"\n            + \"preprocessed dataset has been removed.\",\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the on-disk dataset.\",\n            \"\",\n        ]\n        tasks = dataset.tasks\n        assert tasks[0].metadata[\"name\"] == \"fake_name\"\n\n        # 2. Change edge feature.\n        edge_feats = np.random.rand(num_edges, num_classes)\n        edge_feat_path = os.path.join(\"data\", \"edge-feat.npy\")\n        np.save(os.path.join(test_dir, edge_feat_path), edge_feats)\n        dataset = gb.OnDiskDataset(\n            test_dir, include_original_edge_id=False\n        ).load()\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"The on-disk dataset is re-preprocessing, so the existing \"\n            + \"preprocessed dataset has been removed.\",\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the on-disk dataset.\",\n            \"\",\n        ]\n        assert torch.equal(\n            dataset.feature.read(\"edge\", None, \"feat\"),\n            torch.from_numpy(edge_feats),\n        )\n        graph = dataset.graph\n        assert gb.ORIGINAL_EDGE_ID not in graph.edge_attributes\n\n        # 3. 
Change include_original_edge_id.\n        dataset = gb.OnDiskDataset(\n            test_dir, include_original_edge_id=True\n        ).load()\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\n            \"The on-disk dataset is re-preprocessing, so the existing \"\n            + \"preprocessed dataset has been removed.\",\n            \"Start to preprocess the on-disk dataset.\",\n            \"Finish preprocessing the on-disk dataset.\",\n            \"\",\n        ]\n        graph = dataset.graph\n        assert gb.ORIGINAL_EDGE_ID in graph.edge_attributes\n\n        # 4. Change Nothing.\n        dataset = gb.OnDiskDataset(\n            test_dir, include_original_edge_id=True\n        ).load()\n        captured = capsys.readouterr().out.split(\"\\n\")\n        assert captured == [\"The dataset is already preprocessed.\", \"\"]\n\n        graph = None\n        tasks = None\n        dataset = None\n\n\ndef test_OnDiskTask_repr_homogeneous():\n    item_set = gb.ItemSet(\n        (torch.arange(0, 5), torch.arange(5, 10)),\n        names=(\"seeds\", \"labels\"),\n    )\n    metadata = {\"name\": \"node_classification\"}\n    task = gb.OnDiskTask(metadata, item_set, item_set, item_set)\n    expected_str = (\n        \"OnDiskTask(validation_set=ItemSet(\\n\"\n        \"               items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\\n\"\n        \"               names=('seeds', 'labels'),\\n\"\n        \"           ),\\n\"\n        \"           train_set=ItemSet(\\n\"\n        \"               items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\\n\"\n        \"               names=('seeds', 'labels'),\\n\"\n        \"           ),\\n\"\n        \"           test_set=ItemSet(\\n\"\n        \"               items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\\n\"\n        \"               names=('seeds', 'labels'),\\n\"\n        \"           ),\\n\"\n        \"           metadata={'name': 
'node_classification'},)\"\n    )\n    assert repr(task) == expected_str, task\n\n\ndef test_OnDiskDataset_not_include_eids():\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n        )\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        with pytest.warns(\n            GBWarning,\n            match=\"Edge feature is stored, but edge IDs are not saved.\",\n        ):\n            gb.OnDiskDataset(test_dir, include_original_edge_id=False)\n\n\ndef test_OnDiskTask_repr_heterogeneous():\n    item_set = gb.HeteroItemSet(\n        {\n            \"user\": gb.ItemSet(torch.arange(0, 5), names=\"seeds\"),\n            \"item\": gb.ItemSet(torch.arange(5, 10), names=\"seeds\"),\n        }\n    )\n    metadata = {\"name\": \"node_classification\"}\n    task = gb.OnDiskTask(metadata, item_set, item_set, item_set)\n    expected_str = (\n        \"OnDiskTask(validation_set=HeteroItemSet(\\n\"\n        \"               itemsets={'user': ItemSet(\\n\"\n        \"                            items=(tensor([0, 1, 2, 3, 4]),),\\n\"\n        \"                            names=('seeds',),\\n\"\n        \"                        ), 'item': ItemSet(\\n\"\n        \"                            items=(tensor([5, 6, 7, 8, 9]),),\\n\"\n        \"                            names=('seeds',),\\n\"\n        \"                        )},\\n\"\n        \"               names=('seeds',),\\n\"\n        \"           ),\\n\"\n        \"           train_set=HeteroItemSet(\\n\"\n        \"               itemsets={'user': 
ItemSet(\\n\"\n        \"                            items=(tensor([0, 1, 2, 3, 4]),),\\n\"\n        \"                            names=('seeds',),\\n\"\n        \"                        ), 'item': ItemSet(\\n\"\n        \"                            items=(tensor([5, 6, 7, 8, 9]),),\\n\"\n        \"                            names=('seeds',),\\n\"\n        \"                        )},\\n\"\n        \"               names=('seeds',),\\n\"\n        \"           ),\\n\"\n        \"           test_set=HeteroItemSet(\\n\"\n        \"               itemsets={'user': ItemSet(\\n\"\n        \"                            items=(tensor([0, 1, 2, 3, 4]),),\\n\"\n        \"                            names=('seeds',),\\n\"\n        \"                        ), 'item': ItemSet(\\n\"\n        \"                            items=(tensor([5, 6, 7, 8, 9]),),\\n\"\n        \"                            names=('seeds',),\\n\"\n        \"                        )},\\n\"\n        \"               names=('seeds',),\\n\"\n        \"           ),\\n\"\n        \"           metadata={'name': 'node_classification'},)\"\n    )\n    assert repr(task) == expected_str, task\n\n\ndef test_OnDiskDataset_load_tasks_selectively():\n    \"\"\"Test preprocess of OnDiskDataset.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n        num_classes = 10\n\n        # Generate random graph.\n        yaml_content = gbt.random_homo_graphbolt_graph(\n            test_dir,\n            dataset_name,\n            num_nodes,\n            num_edges,\n            num_classes,\n        )\n        train_path = os.path.join(\"set\", \"train.npy\")\n\n        yaml_content += f\"\"\"      - name: node_classification\n            num_classes: {num_classes}\n            train_set:\n              - type: null\n                data:\n                  - format: 
numpy\n                    path: {train_path}\n        \"\"\"\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        # Case1. Test load all tasks.\n        dataset = gb.OnDiskDataset(test_dir).load()\n        assert len(dataset.tasks) == 2\n\n        # Case2. Test load tasks selectively.\n        dataset = gb.OnDiskDataset(test_dir).load(tasks=\"link_prediction\")\n        assert len(dataset.tasks) == 1\n        assert dataset.tasks[0].metadata[\"name\"] == \"link_prediction\"\n        dataset = gb.OnDiskDataset(test_dir).load(tasks=[\"link_prediction\"])\n        assert len(dataset.tasks) == 1\n        assert dataset.tasks[0].metadata[\"name\"] == \"link_prediction\"\n\n        # Case3. Test load tasks with non-existent task name.\n        with pytest.warns(\n            GBWarning,\n            match=\"Below tasks are not found in YAML: {'fake-name'}. Skipped.\",\n        ):\n            dataset = gb.OnDiskDataset(test_dir).load(tasks=[\"fake-name\"])\n            assert len(dataset.tasks) == 0\n\n        # Case4. 
Test load tasks selectively with incorrect task type.\n        with pytest.raises(TypeError):\n            dataset = gb.OnDiskDataset(test_dir).load(tasks=2)\n\n        dataset = None\n\n\ndef test_OnDiskDataset_preprocess_graph_with_single_type():\n    \"\"\"Test for graph with single node/edge type.\"\"\"\n    with tempfile.TemporaryDirectory() as test_dir:\n        # All metadata fields are specified.\n        dataset_name = \"graphbolt_test\"\n        num_nodes = 4000\n        num_edges = 20000\n\n        # Generate random edges.\n        nodes = np.repeat(np.arange(num_nodes), 5)\n        neighbors = np.random.randint(0, num_nodes, size=(num_edges))\n        edges = np.stack([nodes, neighbors], axis=1)\n        # Write into edges/edge.csv\n        os.makedirs(os.path.join(test_dir, \"edges/\"), exist_ok=True)\n        edges = pd.DataFrame(edges, columns=[\"src\", \"dst\"])\n        edges.to_csv(\n            os.path.join(test_dir, \"edges/edge.csv\"),\n            index=False,\n            header=False,\n        )\n\n        # Generate random graph edge-feats.\n        edge_feats = np.random.rand(num_edges, 5)\n        os.makedirs(os.path.join(test_dir, \"data/\"), exist_ok=True)\n        np.save(os.path.join(test_dir, \"data/edge-feat.npy\"), edge_feats)\n\n        # Generate random node-feats.\n        node_feats = np.random.rand(num_nodes, 10)\n        np.save(os.path.join(test_dir, \"data/node-feat.npy\"), node_feats)\n\n        yaml_content = f\"\"\"\n            dataset_name: {dataset_name}\n            graph: # graph structure and required attributes.\n                nodes:\n                    - num: {num_nodes}\n                      type: author\n                edges:\n                    - type: author:collab:author\n                      format: csv\n                      path: edges/edge.csv\n                feature_data:\n                    - domain: edge\n                      type: author:collab:author\n                      name: feat\n     
                 format: numpy\n                      path: data/edge-feat.npy\n                    - domain: node\n                      type: author\n                      name: feat\n                      format: numpy\n                      path: data/node-feat.npy\n        \"\"\"\n        yaml_file = os.path.join(test_dir, \"metadata.yaml\")\n        with open(yaml_file, \"w\") as f:\n            f.write(yaml_content)\n\n        dataset = gb.OnDiskDataset(test_dir).load()\n        assert dataset.dataset_name == dataset_name\n\n        graph = dataset.graph\n        assert isinstance(graph, gb.FusedCSCSamplingGraph)\n        assert graph.total_num_nodes == num_nodes\n        assert graph.total_num_edges == num_edges\n        assert (\n            graph.node_attributes is not None\n            and \"feat\" in graph.node_attributes\n        )\n        assert (\n            graph.edge_attributes is not None\n            and \"feat\" in graph.edge_attributes\n        )\n        assert torch.equal(graph.node_type_offset, torch.tensor([0, num_nodes]))\n        assert torch.equal(\n            graph.type_per_edge,\n            torch.zeros(num_edges),\n        )\n        assert graph.edge_type_to_id == {\"author:collab:author\": 0}\n        assert graph.node_type_to_id == {\"author\": 0}\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py",
    "content": "import unittest\n\nimport backend as F\n\nimport dgl\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\nfrom dgl.graphbolt.impl.sampled_subgraph_impl import SampledSubgraphImpl\n\n\ndef _assert_container_equal(lhs, rhs):\n    if isinstance(lhs, torch.Tensor):\n        assert isinstance(rhs, torch.Tensor)\n        assert torch.equal(lhs, rhs)\n    elif isinstance(lhs, tuple):\n        assert isinstance(rhs, tuple)\n        assert len(lhs) == len(rhs)\n        for l, r in zip(lhs, rhs):\n            _assert_container_equal(l, r)\n    elif isinstance(lhs, gb.CSCFormatBase):\n        assert isinstance(rhs, gb.CSCFormatBase)\n        assert len(lhs.indptr) == len(rhs.indptr)\n        assert len(lhs.indices) == len(rhs.indices)\n        _assert_container_equal(lhs.indptr, rhs.indptr)\n        _assert_container_equal(lhs.indices, rhs.indices)\n    elif isinstance(lhs, dict):\n        assert isinstance(rhs, dict)\n        assert len(lhs) == len(rhs)\n        for key, value in lhs.items():\n            assert key in rhs\n            _assert_container_equal(value, rhs[key])\n\n\n@pytest.mark.parametrize(\"reverse_row\", [True, False])\n@pytest.mark.parametrize(\"reverse_column\", [True, False])\ndef test_exclude_edges_homo_deduplicated(reverse_row, reverse_column):\n    csc_formats = gb.CSCFormatBase(\n        indptr=torch.tensor([0, 0, 1, 2, 2, 3]), indices=torch.tensor([0, 3, 2])\n    )\n    if reverse_row:\n        original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])\n        src_to_exclude = torch.tensor([11])\n    else:\n        original_row_node_ids = None\n        src_to_exclude = torch.tensor([2])\n\n    if reverse_column:\n        original_column_node_ids = torch.tensor([10, 15, 11, 24, 9])\n        dst_to_exclude = torch.tensor([9])\n    else:\n        original_column_node_ids = None\n        dst_to_exclude = torch.tensor([4])\n    original_edge_ids = torch.Tensor([5, 9, 10])\n    subgraph = SampledSubgraphImpl(\n        csc_formats,\n   
     original_column_node_ids,\n        original_row_node_ids,\n        original_edge_ids,\n    )\n    edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(2, -1).T\n    result = subgraph.exclude_edges(edges_to_exclude)\n    expected_csc_formats = gb.CSCFormatBase(\n        indptr=torch.tensor([0, 0, 1, 2, 2, 2]), indices=torch.tensor([0, 3])\n    )\n    if reverse_row:\n        expected_row_node_ids = torch.tensor([10, 15, 11, 24, 9])\n    else:\n        expected_row_node_ids = None\n    if reverse_column:\n        expected_column_node_ids = torch.tensor([10, 15, 11, 24, 9])\n    else:\n        expected_column_node_ids = None\n    expected_edge_ids = torch.Tensor([5, 9])\n\n    _assert_container_equal(result.sampled_csc, expected_csc_formats)\n    _assert_container_equal(\n        result.original_column_node_ids, expected_column_node_ids\n    )\n    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)\n    _assert_container_equal(result.original_edge_ids, expected_edge_ids)\n\n\n@pytest.mark.parametrize(\"reverse_row\", [True, False])\n@pytest.mark.parametrize(\"reverse_column\", [True, False])\ndef test_exclude_edges_homo_duplicated(reverse_row, reverse_column):\n    csc_formats = gb.CSCFormatBase(\n        indptr=torch.tensor([0, 0, 1, 3, 3, 5]),\n        indices=torch.tensor([0, 3, 3, 2, 2]),\n    )\n    if reverse_row:\n        original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])\n        src_to_exclude = torch.tensor([24])\n    else:\n        original_row_node_ids = None\n        src_to_exclude = torch.tensor([3])\n\n    if reverse_column:\n        original_column_node_ids = torch.tensor([10, 15, 11, 24, 9])\n        dst_to_exclude = torch.tensor([11])\n    else:\n        original_column_node_ids = None\n        dst_to_exclude = torch.tensor([2])\n    original_edge_ids = torch.Tensor([5, 9, 9, 10, 10])\n    subgraph = SampledSubgraphImpl(\n        csc_formats,\n        original_column_node_ids,\n        
original_row_node_ids,\n        original_edge_ids,\n    )\n    edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(2, -1).T\n    result = subgraph.exclude_edges(edges_to_exclude)\n    expected_csc_formats = gb.CSCFormatBase(\n        indptr=torch.tensor([0, 0, 1, 1, 1, 3]), indices=torch.tensor([0, 2, 2])\n    )\n    if reverse_row:\n        expected_row_node_ids = torch.tensor([10, 15, 11, 24, 9])\n    else:\n        expected_row_node_ids = None\n    if reverse_column:\n        expected_column_node_ids = torch.tensor([10, 15, 11, 24, 9])\n    else:\n        expected_column_node_ids = None\n    expected_edge_ids = torch.Tensor([5, 10, 10])\n    _assert_container_equal(result.sampled_csc, expected_csc_formats)\n    _assert_container_equal(\n        result.original_column_node_ids, expected_column_node_ids\n    )\n    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)\n    _assert_container_equal(result.original_edge_ids, expected_edge_ids)\n\n\n@pytest.mark.parametrize(\"reverse_row\", [True, False])\n@pytest.mark.parametrize(\"reverse_column\", [True, False])\ndef test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column):\n    csc_formats = {\n        \"A:relation:B\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 1, 2, 3]),\n            indices=torch.tensor([2, 1, 0]),\n        )\n    }\n    if reverse_row:\n        original_row_node_ids = {\n            \"A\": torch.tensor([13, 14, 15]),\n        }\n        src_to_exclude = torch.tensor([15, 13])\n    else:\n        original_row_node_ids = None\n        src_to_exclude = torch.tensor([2, 0])\n    if reverse_column:\n        original_column_node_ids = {\n            \"B\": torch.tensor([10, 11, 12]),\n        }\n        dst_to_exclude = torch.tensor([10, 12])\n    else:\n        original_column_node_ids = None\n        dst_to_exclude = torch.tensor([0, 2])\n    original_edge_ids = {\"A:relation:B\": torch.tensor([19, 20, 21])}\n    subgraph = 
SampledSubgraphImpl(\n        sampled_csc=csc_formats,\n        original_column_node_ids=original_column_node_ids,\n        original_row_node_ids=original_row_node_ids,\n        original_edge_ids=original_edge_ids,\n    )\n\n    edges_to_exclude = {\n        \"A:relation:B\": torch.cat(\n            (\n                src_to_exclude,\n                dst_to_exclude,\n            )\n        )\n        .view(2, -1)\n        .T\n    }\n    result = subgraph.exclude_edges(edges_to_exclude)\n    expected_csc_formats = {\n        \"A:relation:B\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 0, 1, 1]),\n            indices=torch.tensor([1]),\n        )\n    }\n    if reverse_row:\n        expected_row_node_ids = {\n            \"A\": torch.tensor([13, 14, 15]),\n        }\n    else:\n        expected_row_node_ids = None\n    if reverse_column:\n        expected_column_node_ids = {\n            \"B\": torch.tensor([10, 11, 12]),\n        }\n    else:\n        expected_column_node_ids = None\n    expected_edge_ids = {\"A:relation:B\": torch.tensor([20])}\n\n    _assert_container_equal(result.sampled_csc, expected_csc_formats)\n    _assert_container_equal(\n        result.original_column_node_ids, expected_column_node_ids\n    )\n    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)\n    _assert_container_equal(result.original_edge_ids, expected_edge_ids)\n\n\n@pytest.mark.parametrize(\"reverse_row\", [True, False])\n@pytest.mark.parametrize(\"reverse_column\", [True, False])\ndef test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):\n    csc_formats = {\n        \"A:relation:B\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 2, 4, 5]),\n            indices=torch.tensor([2, 2, 1, 1, 0]),\n        )\n    }\n    if reverse_row:\n        original_row_node_ids = {\n            \"A\": torch.tensor([13, 14, 15]),\n        }\n        src_to_exclude = torch.tensor([15, 13])\n    else:\n        original_row_node_ids = None\n 
       src_to_exclude = torch.tensor([2, 0])\n    if reverse_column:\n        original_column_node_ids = {\n            \"B\": torch.tensor([10, 11, 12]),\n        }\n        dst_to_exclude = torch.tensor([10, 12])\n    else:\n        original_column_node_ids = None\n        dst_to_exclude = torch.tensor([0, 2])\n    original_edge_ids = {\"A:relation:B\": torch.tensor([19, 19, 20, 20, 21])}\n    subgraph = SampledSubgraphImpl(\n        sampled_csc=csc_formats,\n        original_column_node_ids=original_column_node_ids,\n        original_row_node_ids=original_row_node_ids,\n        original_edge_ids=original_edge_ids,\n    )\n\n    edges_to_exclude = {\n        \"A:relation:B\": torch.cat(\n            (\n                src_to_exclude,\n                dst_to_exclude,\n            )\n        )\n        .view(2, -1)\n        .T\n    }\n    result = subgraph.exclude_edges(edges_to_exclude)\n    expected_csc_formats = {\n        \"A:relation:B\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 0, 2, 2]),\n            indices=torch.tensor([1, 1]),\n        )\n    }\n    if reverse_row:\n        expected_row_node_ids = {\n            \"A\": torch.tensor([13, 14, 15]),\n        }\n    else:\n        expected_row_node_ids = None\n    if reverse_column:\n        expected_column_node_ids = {\n            \"B\": torch.tensor([10, 11, 12]),\n        }\n    else:\n        expected_column_node_ids = None\n    expected_edge_ids = {\"A:relation:B\": torch.tensor([20, 20])}\n\n    _assert_container_equal(result.sampled_csc, expected_csc_formats)\n    _assert_container_equal(\n        result.original_column_node_ids, expected_column_node_ids\n    )\n    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)\n    _assert_container_equal(result.original_edge_ids, expected_edge_ids)\n\n\n@pytest.mark.parametrize(\"reverse_row\", [True, False])\n@pytest.mark.parametrize(\"reverse_column\", [True, False])\ndef 
test_exclude_edges_homo_deduplicated_tensor(reverse_row, reverse_column):\n    csc_formats = gb.CSCFormatBase(\n        indptr=torch.tensor([0, 0, 1, 2, 2, 3]), indices=torch.tensor([0, 3, 2])\n    )\n    if reverse_row:\n        original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])\n        src_to_exclude = torch.tensor([11])\n    else:\n        original_row_node_ids = None\n        src_to_exclude = torch.tensor([2])\n\n    if reverse_column:\n        original_column_node_ids = torch.tensor([10, 15, 11, 24, 9])\n        dst_to_exclude = torch.tensor([9])\n    else:\n        original_column_node_ids = None\n        dst_to_exclude = torch.tensor([4])\n    original_edge_ids = torch.Tensor([5, 9, 10])\n    subgraph = SampledSubgraphImpl(\n        csc_formats,\n        original_column_node_ids,\n        original_row_node_ids,\n        original_edge_ids,\n    )\n    edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(1, -1)\n    result = subgraph.exclude_edges(edges_to_exclude)\n    expected_csc_formats = gb.CSCFormatBase(\n        indptr=torch.tensor([0, 0, 1, 2, 2, 2]), indices=torch.tensor([0, 3])\n    )\n    if reverse_row:\n        expected_row_node_ids = torch.tensor([10, 15, 11, 24, 9])\n    else:\n        expected_row_node_ids = None\n    if reverse_column:\n        expected_column_node_ids = torch.tensor([10, 15, 11, 24, 9])\n    else:\n        expected_column_node_ids = None\n    expected_edge_ids = torch.Tensor([5, 9])\n\n    _assert_container_equal(result.sampled_csc, expected_csc_formats)\n    _assert_container_equal(\n        result.original_column_node_ids, expected_column_node_ids\n    )\n    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)\n    _assert_container_equal(result.original_edge_ids, expected_edge_ids)\n\n\n@pytest.mark.parametrize(\"reverse_row\", [True, False])\n@pytest.mark.parametrize(\"reverse_column\", [True, False])\ndef test_exclude_edges_homo_duplicated_tensor(reverse_row, 
reverse_column):\n    csc_formats = gb.CSCFormatBase(\n        indptr=torch.tensor([0, 0, 1, 3, 3, 5]),\n        indices=torch.tensor([0, 3, 3, 2, 2]),\n    )\n    if reverse_row:\n        original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])\n        src_to_exclude = torch.tensor([24])\n    else:\n        original_row_node_ids = None\n        src_to_exclude = torch.tensor([3])\n\n    if reverse_column:\n        original_column_node_ids = torch.tensor([10, 15, 11, 24, 9])\n        dst_to_exclude = torch.tensor([11])\n    else:\n        original_column_node_ids = None\n        dst_to_exclude = torch.tensor([2])\n    original_edge_ids = torch.Tensor([5, 9, 9, 10, 10])\n    subgraph = SampledSubgraphImpl(\n        csc_formats,\n        original_column_node_ids,\n        original_row_node_ids,\n        original_edge_ids,\n    )\n    edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(1, -1)\n    result = subgraph.exclude_edges(edges_to_exclude)\n    expected_csc_formats = gb.CSCFormatBase(\n        indptr=torch.tensor([0, 0, 1, 1, 1, 3]), indices=torch.tensor([0, 2, 2])\n    )\n    if reverse_row:\n        expected_row_node_ids = torch.tensor([10, 15, 11, 24, 9])\n    else:\n        expected_row_node_ids = None\n    if reverse_column:\n        expected_column_node_ids = torch.tensor([10, 15, 11, 24, 9])\n    else:\n        expected_column_node_ids = None\n    expected_edge_ids = torch.Tensor([5, 10, 10])\n    _assert_container_equal(result.sampled_csc, expected_csc_formats)\n    _assert_container_equal(\n        result.original_column_node_ids, expected_column_node_ids\n    )\n    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)\n    _assert_container_equal(result.original_edge_ids, expected_edge_ids)\n\n\n@pytest.mark.parametrize(\"reverse_row\", [True, False])\n@pytest.mark.parametrize(\"reverse_column\", [True, False])\ndef test_exclude_edges_hetero_deduplicated_tensor(reverse_row, reverse_column):\n    csc_formats = 
{\n        \"A:relation:B\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 1, 2, 3]),\n            indices=torch.tensor([2, 1, 0]),\n        )\n    }\n    if reverse_row:\n        original_row_node_ids = {\n            \"A\": torch.tensor([13, 14, 15]),\n        }\n        src_to_exclude = torch.tensor([15, 13])\n    else:\n        original_row_node_ids = None\n        src_to_exclude = torch.tensor([2, 0])\n    if reverse_column:\n        original_column_node_ids = {\n            \"B\": torch.tensor([10, 11, 12]),\n        }\n        dst_to_exclude = torch.tensor([10, 12])\n    else:\n        original_column_node_ids = None\n        dst_to_exclude = torch.tensor([0, 2])\n    original_edge_ids = {\"A:relation:B\": torch.tensor([19, 20, 21])}\n    subgraph = SampledSubgraphImpl(\n        sampled_csc=csc_formats,\n        original_column_node_ids=original_column_node_ids,\n        original_row_node_ids=original_row_node_ids,\n        original_edge_ids=original_edge_ids,\n    )\n\n    edges_to_exclude = {\n        \"A:relation:B\": torch.cat((src_to_exclude, dst_to_exclude))\n        .view(2, -1)\n        .T\n    }\n    result = subgraph.exclude_edges(edges_to_exclude)\n    expected_csc_formats = {\n        \"A:relation:B\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 0, 1, 1]),\n            indices=torch.tensor([1]),\n        )\n    }\n    if reverse_row:\n        expected_row_node_ids = {\n            \"A\": torch.tensor([13, 14, 15]),\n        }\n    else:\n        expected_row_node_ids = None\n    if reverse_column:\n        expected_column_node_ids = {\n            \"B\": torch.tensor([10, 11, 12]),\n        }\n    else:\n        expected_column_node_ids = None\n    expected_edge_ids = {\"A:relation:B\": torch.tensor([20])}\n\n    _assert_container_equal(result.sampled_csc, expected_csc_formats)\n    _assert_container_equal(\n        result.original_column_node_ids, expected_column_node_ids\n    )\n    
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)\n    _assert_container_equal(result.original_edge_ids, expected_edge_ids)\n\n\n@pytest.mark.parametrize(\"reverse_row\", [True, False])\n@pytest.mark.parametrize(\"reverse_column\", [True, False])\ndef test_exclude_edges_hetero_duplicated_tensor(reverse_row, reverse_column):\n    csc_formats = {\n        \"A:relation:B\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 2, 4, 5]),\n            indices=torch.tensor([2, 2, 1, 1, 0]),\n        )\n    }\n    if reverse_row:\n        original_row_node_ids = {\n            \"A\": torch.tensor([13, 14, 15]),\n        }\n        src_to_exclude = torch.tensor([15, 13])\n    else:\n        original_row_node_ids = None\n        src_to_exclude = torch.tensor([2, 0])\n    if reverse_column:\n        original_column_node_ids = {\n            \"B\": torch.tensor([10, 11, 12]),\n        }\n        dst_to_exclude = torch.tensor([10, 12])\n    else:\n        original_column_node_ids = None\n        dst_to_exclude = torch.tensor([0, 2])\n    original_edge_ids = {\"A:relation:B\": torch.tensor([19, 19, 20, 20, 21])}\n    subgraph = SampledSubgraphImpl(\n        sampled_csc=csc_formats,\n        original_column_node_ids=original_column_node_ids,\n        original_row_node_ids=original_row_node_ids,\n        original_edge_ids=original_edge_ids,\n    )\n\n    edges_to_exclude = {\n        \"A:relation:B\": torch.cat((src_to_exclude, dst_to_exclude))\n        .view(2, -1)\n        .T\n    }\n    result = subgraph.exclude_edges(edges_to_exclude)\n    expected_csc_formats = {\n        \"A:relation:B\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 0, 2, 2]),\n            indices=torch.tensor([1, 1]),\n        )\n    }\n    if reverse_row:\n        expected_row_node_ids = {\n            \"A\": torch.tensor([13, 14, 15]),\n        }\n    else:\n        expected_row_node_ids = None\n    if reverse_column:\n        expected_column_node_ids = {\n          
  \"B\": torch.tensor([10, 11, 12]),\n        }\n    else:\n        expected_column_node_ids = None\n    expected_edge_ids = {\"A:relation:B\": torch.tensor([20, 20])}\n\n    _assert_container_equal(result.sampled_csc, expected_csc_formats)\n    _assert_container_equal(\n        result.original_column_node_ids, expected_column_node_ids\n    )\n    _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)\n    _assert_container_equal(result.original_edge_ids, expected_edge_ids)\n\n\ndef test_to_pyg_homo():\n    graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))\n    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())\n    items = torch.LongTensor([[0, 3], [4, 4]])\n    names = \"seeds\"\n    itemset = gb.ItemSet(items, names=names)\n    datapipe = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]\n    sampler = gb.NeighborSampler\n    datapipe = sampler(\n        datapipe,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n    for minibatch in datapipe:\n        x = torch.randn((minibatch.node_ids().size(0), 2), dtype=torch.float32)\n        for subgraph in minibatch.sampled_subgraphs:\n            (x_src, x_dst), edge_index, sizes = subgraph.to_pyg(x)\n            assert torch.equal(x_src, x)\n            dst_size = subgraph.original_column_node_ids.size(0)\n            assert torch.equal(x_dst, x[:dst_size])\n            src_size = subgraph.original_row_node_ids.size(0)\n            assert dst_size == sizes[1]\n            assert src_size == sizes[0]\n            assert torch.equal(edge_index[0], subgraph.sampled_csc.indices)\n            assert torch.equal(\n                edge_index[1],\n                gb.expand_indptr(\n                    subgraph.sampled_csc.indptr,\n                    subgraph.sampled_csc.indices.dtype,\n                ),\n            )\n            x = x_dst\n\n\ndef test_to_pyg_hetero():\n 
   # COO graph:\n    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]\n    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]\n    # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.\n    # num_nodes = 5, num_n1 = 2, num_n2 = 3\n    ntypes = {\"n1\": 0, \"n2\": 1}\n    etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])\n    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])\n    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])\n    node_type_offset = torch.LongTensor([0, 2, 5])\n    graph = gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    ).to(F.ctx())\n    itemset = gb.HeteroItemSet(\n        {\"n1:e1:n2\": gb.ItemSet(torch.tensor([[0, 1]]), names=\"seeds\")}\n    )\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    Sampler = gb.NeighborSampler\n    datapipe = Sampler(\n        item_sampler,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n    for minibatch in datapipe:\n        x = {}\n        for key, ids in minibatch.node_ids().items():\n            x[key] = torch.randn((ids.size(0), 2), dtype=torch.float32)\n        for subgraph in minibatch.sampled_subgraphs:\n            (x_src, x_dst), edge_index, sizes = subgraph.to_pyg(x)\n            assert x_src == x\n            for ntype in x:\n                dst_size = subgraph.original_column_node_ids[ntype].size(0)\n                assert torch.equal(x_dst[ntype], x[ntype][:dst_size])\n            for etype in subgraph.sampled_csc:\n                src_ntype, _, dst_ntype = gb.etype_str_to_tuple(etype)\n                src_size = subgraph.original_row_node_ids[src_ntype].size(0)\n                dst_size = subgraph.original_column_node_ids[dst_ntype].size(0)\n                assert dst_size 
== sizes[etype][1]\n                assert src_size == sizes[etype][0]\n                assert torch.equal(\n                    edge_index[etype][0], subgraph.sampled_csc[etype].indices\n                )\n                assert torch.equal(\n                    edge_index[etype][1],\n                    gb.expand_indptr(\n                        subgraph.sampled_csc[etype].indptr,\n                        subgraph.sampled_csc[etype].indices.dtype,\n                    ),\n                )\n            x = x_dst\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"`to` function needs GPU to test.\",\n)\ndef test_sampled_subgraph_to_device():\n    # Initialize data.\n    csc_format = {\n        \"A:relation:B\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 1, 2, 3]),\n            indices=torch.tensor([0, 1, 2]),\n        )\n    }\n    original_row_node_ids = {\n        \"A\": torch.tensor([13, 14, 15]),\n    }\n    src_to_exclude = torch.tensor([15, 13])\n    original_column_node_ids = {\n        \"B\": torch.tensor([10, 11, 12]),\n    }\n    dst_to_exclude = torch.tensor([10, 12])\n    original_edge_ids = {\"A:relation:B\": torch.tensor([19, 20, 21])}\n    subgraph = SampledSubgraphImpl(\n        sampled_csc=csc_format,\n        original_column_node_ids=original_column_node_ids,\n        original_row_node_ids=original_row_node_ids,\n        original_edge_ids=original_edge_ids,\n    )\n    edges_to_exclude = {\n        \"A:relation:B\": torch.cat(\n            (\n                src_to_exclude,\n                dst_to_exclude,\n            )\n        )\n        .view(2, -1)\n        .T\n    }\n    graph = subgraph.exclude_edges(edges_to_exclude)\n\n    # Copy to device.\n    graph = graph.to(\"cuda\")\n\n    # Check.\n    for key in graph.sampled_csc:\n        assert graph.sampled_csc[key].indices.device.type == \"cuda\"\n        assert graph.sampled_csc[key].indptr.device.type == \"cuda\"\n    for key in 
graph.original_column_node_ids:\n        assert graph.original_column_node_ids[key].device.type == \"cuda\"\n    for key in graph.original_row_node_ids:\n        assert graph.original_row_node_ids[key].device.type == \"cuda\"\n    for key in graph.original_edge_ids:\n        assert graph.original_edge_ids[key].device.type == \"cuda\"\n\n\ndef test_sampled_subgraph_impl_representation_homo():\n    sampled_subgraph_impl = SampledSubgraphImpl(\n        sampled_csc=gb.CSCFormatBase(\n            indptr=torch.arange(0, 101, 10),\n            indices=torch.arange(10, 110),\n        ),\n        original_column_node_ids=torch.arange(0, 10),\n        original_row_node_ids=torch.arange(0, 110),\n        original_edge_ids=None,\n    )\n    expected_result = str(\n        \"\"\"SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([  0,  10,  20,  30,  40,  50,  60,  70,  80,  90, 100]),\n                                             indices=tensor([ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,\n                                                              24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,\n                                                              38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,\n                                                              52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,\n                                                              66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,\n                                                              80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,\n                                                              94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107,\n                                                             108, 109]),\n                               ),\n                   original_row_node_ids=tensor([  0,   1,   2,   3,   
4,   5,   6,   7,   8,   9,  10,  11,  12,  13,\n                                                  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,\n                                                  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,\n                                                  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,\n                                                  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,\n                                                  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,\n                                                  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,\n                                                  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109]),\n                   original_edge_ids=None,\n                   original_column_node_ids=tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),\n)\"\"\"\n    )\n    assert str(sampled_subgraph_impl) == expected_result, print(\n        sampled_subgraph_impl\n    )\n\n\ndef test_sampled_subgraph_impl_representation_hetero():\n    sampled_subgraph_impl = SampledSubgraphImpl(\n        sampled_csc={\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4]),\n                indices=torch.tensor([4, 5, 6, 7]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4, 6, 8]),\n                indices=torch.tensor([2, 3, 4, 5, 6, 7, 8, 9]),\n            ),\n        },\n        original_column_node_ids={\n            \"n1\": torch.tensor([1, 0, 0, 1]),\n            \"n2\": torch.tensor([1, 2]),\n        },\n        original_row_node_ids={\n            \"n1\": torch.tensor([1, 0, 0, 1, 1, 0, 0, 1]),\n            \"n2\": torch.tensor([1, 2, 0, 1, 0, 2, 0, 2, 0, 1]),\n        },\n        original_edge_ids=None,\n    )\n    expected_result = 
str(\n        \"\"\"SampledSubgraphImpl(sampled_csc={'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 2, 4]),\n                                             indices=tensor([4, 5, 6, 7]),\n                               ), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 2, 4, 6, 8]),\n                                             indices=tensor([2, 3, 4, 5, 6, 7, 8, 9]),\n                               )},\n                   original_row_node_ids={'n1': tensor([1, 0, 0, 1, 1, 0, 0, 1]), 'n2': tensor([1, 2, 0, 1, 0, 2, 0, 2, 0, 1])},\n                   original_edge_ids=None,\n                   original_column_node_ids={'n1': tensor([1, 0, 0, 1]), 'n2': tensor([1, 2])},\n)\"\"\"\n    )\n    assert str(sampled_subgraph_impl) == expected_result, print(\n        sampled_subgraph_impl\n    )\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py",
    "content": "import os\nimport tempfile\nimport unittest\n\nimport backend as F\n\nimport numpy as np\nimport pydantic\nimport pytest\nimport torch\n\nfrom dgl import graphbolt as gb\n\n\ndef to_on_disk_tensor(test_dir, name, t):\n    path = os.path.join(test_dir, name + \".npy\")\n    t = t.numpy()\n    np.save(path, t)\n    # The Pytorch tensor is a view of the numpy array on disk, which does not\n    # consume memory.\n    t = torch.as_tensor(np.load(path, mmap_mode=\"r+\"))\n    return t\n\n\n@pytest.mark.parametrize(\"in_memory\", [True, False])\ndef test_torch_based_feature(in_memory):\n    with tempfile.TemporaryDirectory() as test_dir:\n        a = torch.tensor([[1, 2, 3], [4, 5, 6]])\n        b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]])\n        metadata = {\"max_value\": 3}\n        if not in_memory:\n            a = to_on_disk_tensor(test_dir, \"a\", a)\n            b = to_on_disk_tensor(test_dir, \"b\", b)\n\n        feature_a = gb.TorchBasedFeature(a, metadata=metadata)\n        feature_b = gb.TorchBasedFeature(b)\n\n        # Read the entire feature.\n        assert torch.equal(\n            feature_a.read(), torch.tensor([[1, 2, 3], [4, 5, 6]])\n        )\n\n        # Test read the feature with ids.\n        assert torch.equal(\n            feature_b.read(), torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]])\n        )\n        # Read the feature with ids.\n        assert torch.equal(\n            feature_a.read(torch.tensor([0])),\n            torch.tensor([[1, 2, 3]]),\n        )\n        assert torch.equal(\n            feature_b.read(torch.tensor([1])),\n            torch.tensor([[[4, 5], [6, 7]]]),\n        )\n        # Update the feature with ids.\n        feature_a.update(torch.tensor([[0, 1, 2]]), torch.tensor([0]))\n        assert torch.equal(\n            feature_a.read(), torch.tensor([[0, 1, 2], [4, 5, 6]])\n        )\n        feature_b.update(torch.tensor([[[1, 2], [3, 4]]]), torch.tensor([1]))\n        assert torch.equal(\n 
           feature_b.read(), torch.tensor([[[1, 2], [3, 4]], [[1, 2], [3, 4]]])\n        )\n\n        # Test update the feature.\n        feature_a.update(torch.tensor([[5, 1, 3]]))\n        assert torch.equal(\n            feature_a.read(),\n            torch.tensor([[5, 1, 3]]),\n        ), print(feature_a.read())\n        feature_b.update(\n            torch.tensor([[[1, 3], [5, 7]], [[2, 4], [6, 8]], [[2, 4], [6, 8]]])\n        )\n        assert torch.equal(\n            feature_b.read(),\n            torch.tensor(\n                [[[1, 3], [5, 7]], [[2, 4], [6, 8]], [[2, 4], [6, 8]]]\n            ),\n        )\n\n        # Test get the size and count of the entire feature.\n        assert feature_a.size() == torch.Size([3])\n        assert feature_b.size() == torch.Size([2, 2])\n        assert feature_a.count() == 1\n        assert feature_b.count() == 3\n\n        # Test get metadata of the feature.\n        assert feature_a.metadata() == metadata\n        assert feature_b.metadata() == {}\n\n        with pytest.raises(IndexError):\n            feature_a.read(torch.tensor([0, 1, 2, 3]))\n\n        # For windows, the file is locked by the numpy.load. 
We need to delete\n        # it before closing the temporary directory.\n        a = b = None\n        feature_a = feature_b = None\n\n        # Test loaded tensors' contiguity from C/Fortran contiguous ndarray.\n        contiguous_numpy = np.array([[1, 2, 3], [4, 5, 6]], order=\"C\")\n        non_contiguous_numpy = np.array([[1, 2, 3], [4, 5, 6]], order=\"F\")\n        assert contiguous_numpy.flags[\"C_CONTIGUOUS\"]\n        assert non_contiguous_numpy.flags[\"F_CONTIGUOUS\"]\n        np.save(\n            os.path.join(test_dir, \"contiguous_numpy.npy\"), contiguous_numpy\n        )\n        np.save(\n            os.path.join(test_dir, \"non_contiguous_numpy.npy\"),\n            non_contiguous_numpy,\n        )\n\n        cur_mmap_mode = None\n        if not in_memory:\n            cur_mmap_mode = \"r+\"\n        feature_a = gb.TorchBasedFeature(\n            torch.from_numpy(\n                np.load(\n                    os.path.join(test_dir, \"contiguous_numpy.npy\"),\n                    mmap_mode=cur_mmap_mode,\n                )\n            )\n        )\n        feature_b = gb.TorchBasedFeature(\n            torch.from_numpy(\n                np.load(\n                    os.path.join(test_dir, \"non_contiguous_numpy.npy\"),\n                    mmap_mode=cur_mmap_mode,\n                )\n            )\n        )\n        assert feature_a._tensor.is_contiguous()\n        assert feature_b._tensor.is_contiguous()\n\n        contiguous_numpy = non_contiguous_numpy = None\n        feature_a = feature_b = None\n\n\ndef is_feature_store_on_cuda(store):\n    for feature in store._features.values():\n        assert feature._tensor.is_cuda\n\n\ndef is_feature_store_on_cpu(store):\n    for feature in store._features.values():\n        assert not feature._tensor.is_cuda\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"Tests for pinned memory are only meaningful on GPU.\",\n)\n@pytest.mark.parametrize(\"device\", [\"pinned\", 
\"cuda\"])\ndef test_feature_store_to_device(device):\n    with tempfile.TemporaryDirectory() as test_dir:\n        a = torch.tensor([[1, 2, 4], [2, 5, 3]])\n        b = torch.tensor([[[1, 2], [3, 4]], [[2, 5], [3, 4]]])\n        write_tensor_to_disk(test_dir, \"a\", a, fmt=\"torch\")\n        write_tensor_to_disk(test_dir, \"b\", b, fmt=\"numpy\")\n        feature_data = [\n            gb.OnDiskFeatureData(\n                domain=\"node\",\n                type=\"paper\",\n                name=\"a\",\n                format=\"torch\",\n                path=os.path.join(test_dir, \"a.pt\"),\n            ),\n            gb.OnDiskFeatureData(\n                domain=\"edge\",\n                type=\"paper:cites:paper\",\n                name=\"b\",\n                format=\"numpy\",\n                path=os.path.join(test_dir, \"b.npy\"),\n            ),\n        ]\n        feature_store = gb.TorchBasedFeatureStore(feature_data)\n        feature_store2 = feature_store.to(device)\n        if device == \"pinned\":\n            assert feature_store2.is_pinned()\n        elif device == \"cuda\":\n            is_feature_store_on_cuda(feature_store2)\n\n        # The original variable should be untouched.\n        is_feature_store_on_cpu(feature_store)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"Tests for pinned memory are only meaningful on GPU.\",\n)\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.float32,\n        torch.float64,\n        torch.int32,\n        torch.int64,\n        torch.int8,\n        torch.float16,\n        torch.complex128,\n    ],\n)\n@pytest.mark.parametrize(\"idtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"shape\", [(2, 1), (2, 3), (2, 2, 2), (137, 13, 3)])\n@pytest.mark.parametrize(\"in_place\", [False, True])\ndef test_torch_based_pinned_feature(dtype, idtype, shape, in_place):\n    if dtype == torch.complex128:\n        tensor = torch.complex(\n            torch.randint(0, 13, 
shape, dtype=torch.float64),\n            torch.randint(0, 13, shape, dtype=torch.float64),\n        )\n    else:\n        tensor = torch.randint(0, 13, shape, dtype=dtype)\n    test_tensor = tensor.clone().detach()\n    test_tensor_cuda = test_tensor.cuda()\n\n    feature = gb.TorchBasedFeature(tensor)\n    if in_place:\n        if gb.is_wsl():\n            pytest.skip(\"In place pinning is not supported on WSL.\")\n        feature.pin_memory_()\n\n        # Check if pinning is truly in-place.\n        assert feature._tensor.data_ptr() == tensor.data_ptr()\n    else:\n        feature = feature.to(\"pinned\")\n\n    assert feature.is_pinned()\n\n    # Test read entire pinned feature, the result should be on cuda.\n    assert torch.equal(feature.read(), test_tensor_cuda)\n    assert feature.read().is_cuda\n    assert torch.equal(\n        feature.read(torch.tensor([0], dtype=idtype).cuda()),\n        test_tensor_cuda[[0]],\n    )\n\n    # Test read pinned feature with idx on cuda, the result should be on cuda.\n    assert feature.read(torch.tensor([0], dtype=idtype).cuda()).is_cuda\n\n    # Test read pinned feature with idx on cpu, the result should be on cpu.\n    assert torch.equal(\n        feature.read(torch.tensor([0], dtype=idtype)), test_tensor[[0]]\n    )\n    assert not feature.read(torch.tensor([0], dtype=idtype)).is_cuda\n\n\ndef write_tensor_to_disk(dir, name, t, fmt=\"torch\"):\n    if fmt == \"torch\":\n        torch.save(t, os.path.join(dir, name + \".pt\"))\n    elif fmt == \"numpy\":\n        t = t.numpy()\n        np.save(os.path.join(dir, name + \".npy\"), t)\n    else:\n        raise ValueError(f\"Unsupported format: {fmt}\")\n\n\n@pytest.mark.parametrize(\"in_memory\", [True, False])\ndef test_torch_based_feature_store(in_memory):\n    with tempfile.TemporaryDirectory() as test_dir:\n        a = torch.tensor([[1, 2, 4], [2, 5, 3]])\n        b = torch.tensor([[[1, 2], [3, 4]], [[2, 5], [3, 4]]])\n        write_tensor_to_disk(test_dir, \"a\", a, 
fmt=\"torch\")\n        write_tensor_to_disk(test_dir, \"b\", b, fmt=\"numpy\")\n        feature_data = [\n            gb.OnDiskFeatureData(\n                domain=\"node\",\n                type=\"paper\",\n                name=\"a\",\n                format=\"torch\",\n                path=os.path.join(test_dir, \"a.pt\"),\n                in_memory=True,\n            ),\n            gb.OnDiskFeatureData(\n                domain=\"edge\",\n                type=\"paper:cites:paper\",\n                name=\"b\",\n                format=\"numpy\",\n                path=os.path.join(test_dir, \"b.npy\"),\n                in_memory=in_memory,\n            ),\n        ]\n        feature_store = gb.TorchBasedFeatureStore(feature_data)\n\n        assert isinstance(\n            feature_store[(\"node\", \"paper\", \"a\")], gb.TorchBasedFeature\n        )\n        assert isinstance(\n            feature_store[(\"edge\", \"paper:cites:paper\", \"b\")],\n            gb.TorchBasedFeature if in_memory else gb.DiskBasedFeature,\n        )\n\n        # Test read the entire feature.\n        assert torch.equal(\n            feature_store.read(\"node\", \"paper\", \"a\"),\n            torch.tensor([[1, 2, 4], [2, 5, 3]]),\n        )\n        assert torch.equal(\n            feature_store.read(\"edge\", \"paper:cites:paper\", \"b\"),\n            torch.tensor([[[1, 2], [3, 4]], [[2, 5], [3, 4]]]),\n        )\n\n        # Test get the size of the entire feature.\n        assert feature_store.size(\"node\", \"paper\", \"a\") == torch.Size([3])\n        assert feature_store.size(\n            \"edge\", \"paper:cites:paper\", \"b\"\n        ) == torch.Size([2, 2])\n\n        # Test get the keys of the features.\n        assert feature_store.keys() == [\n            (\"node\", \"paper\", \"a\"),\n            (\"edge\", \"paper:cites:paper\", \"b\"),\n        ]\n\n        # For windows, the file is locked by the numpy.load. 
We need to delete\n        # it before closing the temporary directory.\n        a = b = None\n        feature_store = None\n\n        # ``domain`` should be enum.\n        with pytest.raises(pydantic.ValidationError):\n            _ = gb.OnDiskFeatureData(\n                domain=\"invalid\",\n                type=\"paper\",\n                name=\"a\",\n                format=\"torch\",\n                path=os.path.join(test_dir, \"a.pt\"),\n                in_memory=True,\n            )\n\n        # ``type`` could be null.\n        feature_data = [\n            gb.OnDiskFeatureData(\n                domain=\"node\",\n                name=\"a\",\n                format=\"torch\",\n                path=os.path.join(test_dir, \"a.pt\"),\n                in_memory=True,\n            ),\n        ]\n        feature_store = gb.TorchBasedFeatureStore(feature_data)\n        # Test read the entire feature.\n        assert torch.equal(\n            feature_store.read(\"node\", None, \"a\"),\n            torch.tensor([[1, 2, 4], [2, 5, 3]]),\n        )\n        # Test get the size of the entire feature.\n        assert feature_store.size(\"node\", None, \"a\") == torch.Size([3])\n\n        feature_store = None\n\n\n@pytest.mark.parametrize(\"in_memory\", [True, False])\ndef test_torch_based_feature_repr(in_memory):\n    with tempfile.TemporaryDirectory() as test_dir:\n        a = torch.tensor([[1, 2, 3], [4, 5, 6]])\n        b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]])\n        metadata = {\"max_value\": 3}\n        if not in_memory:\n            a = to_on_disk_tensor(test_dir, \"a\", a)\n            b = to_on_disk_tensor(test_dir, \"b\", b)\n\n        feature_a = gb.TorchBasedFeature(a, metadata=metadata)\n        feature_b = gb.TorchBasedFeature(b)\n\n        expected_str_feature_a = (\n            \"TorchBasedFeature(\\n\"\n            \"    feature=tensor([[1, 2, 3],\\n\"\n            \"                    [4, 5, 6]]),\\n\"\n            \"    
metadata={'max_value': 3},\\n\"\n            \")\"\n        )\n        expected_str_feature_b = (\n            \"TorchBasedFeature(\\n\"\n            \"    feature=tensor([[[1, 2],\\n\"\n            \"                     [3, 4]],\\n\"\n            \"\\n\"\n            \"                    [[4, 5],\\n\"\n            \"                     [6, 7]]]),\\n\"\n            \"    metadata={},\\n\"\n            \")\"\n        )\n\n        assert repr(feature_a) == expected_str_feature_a, feature_a\n        assert repr(feature_b) == expected_str_feature_b, feature_b\n\n        a = b = metadata = None\n        feature_a = feature_b = None\n        expected_str_feature_a = expected_str_feature_b = None\n\n\n@pytest.mark.parametrize(\"in_memory\", [True, False])\ndef test_torch_based_feature_store_repr(in_memory):\n    with tempfile.TemporaryDirectory() as test_dir:\n        a = torch.tensor([[1, 2, 4], [2, 5, 3]])\n        b = torch.tensor([[[1, 2], [3, 4]], [[2, 5], [3, 4]]])\n        write_tensor_to_disk(test_dir, \"a\", a, fmt=\"torch\")\n        write_tensor_to_disk(test_dir, \"b\", b, fmt=\"numpy\")\n        feature_data = [\n            gb.OnDiskFeatureData(\n                domain=\"node\",\n                type=\"paper\",\n                name=\"a\",\n                format=\"torch\",\n                path=os.path.join(test_dir, \"a.pt\"),\n                in_memory=True,\n            ),\n            gb.OnDiskFeatureData(\n                domain=\"edge\",\n                type=\"paper:cites:paper\",\n                name=\"b\",\n                format=\"numpy\",\n                path=os.path.join(test_dir, \"b.npy\"),\n                in_memory=in_memory,\n            ),\n        ]\n        feature_store = gb.TorchBasedFeatureStore(feature_data)\n\n        expected_feature_store_str = (\n            (\n                \"TorchBasedFeatureStore(\\n\"\n                \"    {(<OnDiskFeatureDataDomain.NODE: 'node'>, 'paper', 'a'): TorchBasedFeature(\\n\"\n                
\"        feature=tensor([[1, 2, 4],\\n\"\n                \"                        [2, 5, 3]]),\\n\"\n                \"        metadata={},\\n\"\n                \"    ), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, 'paper:cites:paper', 'b'): TorchBasedFeature(\\n\"\n                \"        feature=tensor([[[1, 2],\\n\"\n                \"                         [3, 4]],\\n\"\n                \"\\n\"\n                \"                        [[2, 5],\\n\"\n                \"                         [3, 4]]]),\\n\"\n                \"        metadata={},\\n\"\n                \"    )}\\n\"\n                \")\"\n            )\n            if in_memory\n            else (\n                \"TorchBasedFeatureStore(\\n\"\n                \"    {(<OnDiskFeatureDataDomain.NODE: 'node'>, 'paper', 'a'): TorchBasedFeature(\\n\"\n                \"        feature=tensor([[1, 2, 4],\\n\"\n                \"                        [2, 5, 3]]),\\n\"\n                \"        metadata={},\\n\"\n                \"    ), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, 'paper:cites:paper', 'b'): DiskBasedFeature(\\n\"\n                \"        feature=tensor([[[1, 2],\\n\"\n                \"                         [3, 4]],\\n\"\n                \"\\n\"\n                \"                        [[2, 5],\\n\"\n                \"                         [3, 4]]]),\\n\"\n                \"        metadata={},\\n\"\n                \"    )}\\n\"\n                \")\"\n            )\n        )\n\n        assert repr(feature_store) == expected_feature_store_str, feature_store\n\n        a = b = feature_data = None\n        feature_store = expected_feature_store_str = None\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/internal/test_sample_utils.py",
    "content": "import backend as F\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\n\n\ndef test_unique_and_compact_hetero():\n    N1 = torch.tensor(\n        [0, 5, 2, 7, 12, 7, 9, 5, 6, 2, 3, 4, 1, 0, 9], device=F.ctx()\n    )\n    N2 = torch.tensor([0, 3, 3, 5, 2, 7, 2, 8, 4, 9, 2, 3], device=F.ctx())\n    N3 = torch.tensor([1, 2, 6, 6, 1, 8, 3, 6, 3, 2], device=F.ctx())\n    expected_unique = {\n        \"n1\": torch.tensor([0, 5, 2, 7, 12, 9, 6, 3, 4, 1], device=F.ctx()),\n        \"n2\": torch.tensor([0, 3, 5, 2, 7, 8, 4, 9], device=F.ctx()),\n        \"n3\": torch.tensor([1, 2, 6, 8, 3], device=F.ctx()),\n    }\n    if N1.is_cuda and torch.cuda.get_device_capability()[0] < 7:\n        expected_reverse_id = {\n            k: v.sort()[1] for k, v in expected_unique.items()\n        }\n        expected_unique = {k: v.sort()[0] for k, v in expected_unique.items()}\n    else:\n        expected_reverse_id = {\n            k: torch.arange(0, v.shape[0], device=F.ctx())\n            for k, v in expected_unique.items()\n        }\n    nodes_dict = {\n        \"n1\": N1.split(5),\n        \"n2\": N2.split(4),\n        \"n3\": N3.split(2),\n    }\n    expected_nodes_dict = {\n        \"n1\": [\n            torch.tensor([0, 1, 2, 3, 4], device=F.ctx()),\n            torch.tensor([3, 5, 1, 6, 2], device=F.ctx()),\n            torch.tensor([7, 8, 9, 0, 5], device=F.ctx()),\n        ],\n        \"n2\": [\n            torch.tensor([0, 1, 1, 2], device=F.ctx()),\n            torch.tensor([3, 4, 3, 5], device=F.ctx()),\n            torch.tensor([6, 7, 3, 1], device=F.ctx()),\n        ],\n        \"n3\": [\n            torch.tensor([0, 1], device=F.ctx()),\n            torch.tensor([2, 2], device=F.ctx()),\n            torch.tensor([0, 3], device=F.ctx()),\n            torch.tensor([4, 2], device=F.ctx()),\n            torch.tensor([4, 1], device=F.ctx()),\n        ],\n    }\n\n    unique, compacted, _ = gb.unique_and_compact(nodes_dict)\n    for ntype, nodes in 
unique.items():\n        expected_nodes = expected_unique[ntype]\n        assert torch.equal(nodes, expected_nodes)\n\n    for ntype, nodes in compacted.items():\n        expected_nodes = expected_nodes_dict[ntype]\n        assert isinstance(nodes, list)\n        for expected_node, node in zip(expected_nodes, nodes):\n            node = expected_reverse_id[ntype][node]\n            assert torch.equal(expected_node, node)\n\n\ndef test_unique_and_compact_homo():\n    N = torch.tensor(\n        [0, 5, 2, 7, 12, 7, 9, 5, 6, 2, 3, 4, 1, 0, 9], device=F.ctx()\n    )\n    expected_unique_N = torch.tensor(\n        [0, 5, 2, 7, 12, 9, 6, 3, 4, 1], device=F.ctx()\n    )\n    if N.is_cuda and torch.cuda.get_device_capability()[0] < 7:\n        expected_reverse_id_N = expected_unique_N.sort()[1]\n        expected_unique_N = expected_unique_N.sort()[0]\n    else:\n        expected_reverse_id_N = torch.arange(\n            0, expected_unique_N.shape[0], device=F.ctx()\n        )\n    nodes_list = N.split(5)\n    expected_nodes_list = [\n        torch.tensor([0, 1, 2, 3, 4], device=F.ctx()),\n        torch.tensor([3, 5, 1, 6, 2], device=F.ctx()),\n        torch.tensor([7, 8, 9, 0, 5], device=F.ctx()),\n    ]\n\n    unique, compacted, _ = gb.unique_and_compact(nodes_list)\n\n    assert torch.equal(unique, expected_unique_N)\n    assert isinstance(compacted, list)\n    for expected_node, node in zip(expected_nodes_list, compacted):\n        node = expected_reverse_id_N[node]\n        assert torch.equal(expected_node, node)\n\n\ndef test_unique_and_compact_csc_formats_hetero():\n    dst_nodes = {\n        \"n2\": torch.tensor([2, 4, 1, 3]),\n        \"n3\": torch.tensor([1, 3, 2, 7]),\n    }\n    csc_formats = {\n        \"n1:e1:n2\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 3, 4, 7, 10]),\n            indices=torch.tensor([1, 3, 4, 6, 2, 7, 9, 4, 2, 6]),\n        ),\n        \"n1:e2:n3\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 1, 4, 7, 10]),\n        
    indices=torch.tensor([5, 2, 6, 4, 7, 2, 8, 1, 3, 0]),\n        ),\n        \"n2:e3:n3\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 2, 4, 6, 8]),\n            indices=torch.tensor([2, 5, 4, 1, 4, 3, 6, 0]),\n        ),\n    }\n\n    expected_unique_nodes = {\n        \"n1\": torch.tensor([1, 3, 4, 6, 2, 7, 9, 5, 8, 0]),\n        \"n2\": torch.tensor([2, 4, 1, 3, 5, 6, 0]),\n        \"n3\": torch.tensor([1, 3, 2, 7]),\n    }\n    expected_csc_formats = {\n        \"n1:e1:n2\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 3, 4, 7, 10]),\n            indices=torch.tensor([0, 1, 2, 3, 4, 5, 6, 2, 4, 3]),\n        ),\n        \"n1:e2:n3\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 1, 4, 7, 10]),\n            indices=torch.tensor([7, 4, 3, 2, 5, 4, 8, 0, 1, 9]),\n        ),\n        \"n2:e3:n3\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 2, 4, 6, 8]),\n            indices=torch.tensor([0, 4, 1, 2, 1, 3, 5, 6]),\n        ),\n    }\n\n    unique_nodes, compacted_csc_formats, _ = gb.unique_and_compact_csc_formats(\n        csc_formats, dst_nodes\n    )\n\n    for ntype, nodes in unique_nodes.items():\n        expected_nodes = expected_unique_nodes[ntype]\n        assert torch.equal(nodes, expected_nodes)\n    for etype, pair in compacted_csc_formats.items():\n        indices = pair.indices\n        indptr = pair.indptr\n        expected_indices = expected_csc_formats[etype].indices\n        expected_indptr = expected_csc_formats[etype].indptr\n        assert torch.equal(indices, expected_indices)\n        assert torch.equal(indptr, expected_indptr)\n\n\ndef test_unique_and_compact_csc_formats_homo():\n    seeds = torch.tensor([1, 3, 5, 2, 6])\n    indptr = torch.tensor([0, 2, 4, 6, 7, 11])\n    indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6])\n    csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)\n\n    expected_unique_nodes = torch.tensor([1, 3, 5, 2, 6, 4])\n    expected_indptr = indptr\n    
expected_indices = torch.tensor([3, 1, 0, 5, 2, 3, 2, 0, 5, 5, 4])\n\n    unique_nodes, compacted_csc_formats, _ = gb.unique_and_compact_csc_formats(\n        csc_formats, seeds\n    )\n\n    indptr = compacted_csc_formats.indptr\n    indices = compacted_csc_formats.indices\n    assert torch.equal(indptr, expected_indptr)\n    assert torch.equal(indices, expected_indices)\n    assert torch.equal(unique_nodes, expected_unique_nodes)\n\n\ndef test_unique_and_compact_incorrect_indptr():\n    seeds = torch.tensor([1, 3, 5, 2, 6, 7])\n    indptr = torch.tensor([0, 2, 4, 6, 7, 11])\n    indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6])\n    csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)\n\n    # The number of seeds is not corresponding to indptr.\n    with pytest.raises(AssertionError):\n        gb.unique_and_compact_csc_formats(csc_formats, seeds)\n\n\ndef test_compact_csc_format_hetero():\n    dst_nodes = {\n        \"n2\": torch.tensor([2, 4, 1, 3]),\n        \"n3\": torch.tensor([1, 3, 2, 7]),\n    }\n    csc_formats = {\n        \"n1:e1:n2\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 3, 4, 7, 10]),\n            indices=torch.tensor([1, 3, 4, 6, 2, 7, 9, 4, 2, 6]),\n        ),\n        \"n1:e2:n3\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 1, 4, 7, 10]),\n            indices=torch.tensor([5, 2, 6, 4, 7, 2, 8, 1, 3, 0]),\n        ),\n        \"n2:e3:n3\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 2, 4, 6, 8]),\n            indices=torch.tensor([2, 5, 4, 1, 4, 3, 6, 0]),\n        ),\n    }\n\n    expected_original_row_ids = {\n        \"n1\": torch.tensor(\n            [1, 3, 4, 6, 2, 7, 9, 4, 2, 6, 5, 2, 6, 4, 7, 2, 8, 1, 3, 0]\n        ),\n        \"n2\": torch.tensor([2, 4, 1, 3, 2, 5, 4, 1, 4, 3, 6, 0]),\n        \"n3\": torch.tensor([1, 3, 2, 7]),\n    }\n    expected_csc_formats = {\n        \"n1:e1:n2\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 3, 4, 7, 10]),\n            
indices=torch.arange(0, 10),\n        ),\n        \"n1:e2:n3\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 1, 4, 7, 10]),\n            indices=torch.arange(0, 10) + 10,\n        ),\n        \"n2:e3:n3\": gb.CSCFormatBase(\n            indptr=torch.tensor([0, 2, 4, 6, 8]),\n            indices=torch.arange(0, 8) + 4,\n        ),\n    }\n    original_row_ids, compacted_csc_formats = gb.compact_csc_format(\n        csc_formats, dst_nodes\n    )\n\n    for ntype, nodes in original_row_ids.items():\n        expected_nodes = expected_original_row_ids[ntype]\n        assert torch.equal(nodes, expected_nodes)\n    for etype, csc_format in compacted_csc_formats.items():\n        indptr = csc_format.indptr\n        indices = csc_format.indices\n        expected_indptr = expected_csc_formats[etype].indptr\n        expected_indices = expected_csc_formats[etype].indices\n        assert torch.equal(indptr, expected_indptr)\n        assert torch.equal(indices, expected_indices)\n\n\ndef test_compact_csc_format_homo():\n    seeds = torch.tensor([1, 3, 5, 2, 6])\n    indptr = torch.tensor([0, 2, 4, 6, 7, 11])\n    indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6])\n    csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)\n\n    expected_original_row_ids = torch.tensor(\n        [1, 3, 5, 2, 6, 2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6]\n    )\n    expected_indptr = indptr\n    expected_indices = torch.arange(0, len(indices)) + 5\n\n    original_row_ids, compacted_csc_formats = gb.compact_csc_format(\n        csc_formats, seeds\n    )\n\n    indptr = compacted_csc_formats.indptr\n    indices = compacted_csc_formats.indices\n\n    assert torch.equal(indptr, expected_indptr)\n    assert torch.equal(indices, expected_indices)\n    assert torch.equal(original_row_ids, expected_original_row_ids)\n\n\ndef test_compact_incorrect_indptr():\n    seeds = torch.tensor([1, 3, 5, 2, 6, 7])\n    indptr = torch.tensor([0, 2, 4, 6, 7, 11])\n    indices = torch.tensor([2, 3, 1, 4, 
5, 2, 5, 1, 4, 4, 6])\n    csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)\n\n    # The number of seeds is not corresponding to indptr.\n    with pytest.raises(AssertionError):\n        gb.compact_csc_format(csc_formats, seeds)\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/internal/test_utils.py",
    "content": "import json\nimport os\nimport re\nimport tempfile\nfrom functools import partial\n\nimport dgl.graphbolt as gb\n\nimport dgl.graphbolt.internal as internal\nimport numpy as np\nimport pandas as pd\nimport pytest\nimport torch\n\n\ndef test_read_torch_data():\n    with tempfile.TemporaryDirectory() as test_dir:\n        save_tensor = torch.tensor([[1, 2, 4], [2, 5, 3]])\n        file_name = os.path.join(test_dir, \"save_tensor.pt\")\n        torch.save(save_tensor, file_name)\n        read_tensor = internal.utils._read_torch_data(file_name)\n        assert torch.equal(save_tensor, read_tensor)\n        save_tensor = read_tensor = None\n\n\n@pytest.mark.parametrize(\"in_memory\", [True, False])\ndef test_read_numpy_data(in_memory):\n    with tempfile.TemporaryDirectory() as test_dir:\n        save_numpy = np.array([[1, 2, 4], [2, 5, 3]])\n        file_name = os.path.join(test_dir, \"save_numpy.npy\")\n        np.save(file_name, save_numpy)\n        read_tensor = internal.utils._read_numpy_data(file_name, in_memory)\n        assert torch.equal(torch.from_numpy(save_numpy), read_tensor)\n        save_numpy = read_tensor = None\n\n\n@pytest.mark.parametrize(\"fmt\", [\"torch\", \"numpy\"])\ndef test_read_data(fmt):\n    with tempfile.TemporaryDirectory() as test_dir:\n        data = np.array([[1, 2, 4], [2, 5, 3]])\n        type_name = \"pt\" if fmt == \"torch\" else \"npy\"\n        file_name = os.path.join(test_dir, f\"save_data.{type_name}\")\n        if fmt == \"numpy\":\n            np.save(file_name, data)\n        elif fmt == \"torch\":\n            torch.save(torch.from_numpy(data), file_name)\n        read_tensor = internal.read_data(file_name, fmt)\n        assert torch.equal(torch.from_numpy(data), read_tensor)\n\n\n@pytest.mark.parametrize(\n    \"data_fmt, save_fmt, contiguous\",\n    [\n        (\"torch\", \"torch\", True),\n        (\"torch\", \"torch\", False),\n        (\"torch\", \"numpy\", True),\n        (\"torch\", \"numpy\", 
False),\n        (\"numpy\", \"torch\", True),\n        (\"numpy\", \"torch\", False),\n        (\"numpy\", \"numpy\", True),\n        (\"numpy\", \"numpy\", False),\n    ],\n)\ndef test_save_data(data_fmt, save_fmt, contiguous):\n    with tempfile.TemporaryDirectory() as test_dir:\n        data = np.array([[1, 2, 4], [2, 5, 3]])\n        if not contiguous:\n            data = np.asfortranarray(data)\n        tensor_data = torch.from_numpy(data)\n        type_name = \"pt\" if save_fmt == \"torch\" else \"npy\"\n        save_file_name = os.path.join(test_dir, f\"save_data.{type_name}\")\n        # Step1. Save the data.\n        if data_fmt == \"torch\":\n            internal.save_data(tensor_data, save_file_name, save_fmt)\n        elif data_fmt == \"numpy\":\n            internal.save_data(data, save_file_name, save_fmt)\n\n        # Step2. Load the data.\n        if save_fmt == \"torch\":\n            loaded_data = torch.load(save_file_name, weights_only=False)\n            assert loaded_data.is_contiguous()\n            assert torch.equal(tensor_data, loaded_data)\n        elif save_fmt == \"numpy\":\n            loaded_data = np.load(save_file_name)\n            # Checks if the loaded data is C-contiguous.\n            assert loaded_data.flags[\"C_CONTIGUOUS\"]\n            assert np.array_equal(tensor_data.numpy(), loaded_data)\n\n        data = tensor_data = loaded_data = None\n\n\n@pytest.mark.parametrize(\"fmt\", [\"torch\", \"numpy\"])\ndef test_get_npy_dim(fmt):\n    with tempfile.TemporaryDirectory() as test_dir:\n        data = np.array([[1, 2, 4], [2, 5, 3]])\n        type_name = \"pt\" if fmt == \"torch\" else \"npy\"\n        file_name = os.path.join(test_dir, f\"save_data.{type_name}\")\n        if fmt == \"numpy\":\n            np.save(file_name, data)\n            assert internal.get_npy_dim(file_name) == 2\n        elif fmt == \"torch\":\n            torch.save(torch.from_numpy(data), file_name)\n            with pytest.raises(ValueError):\n       
         internal.get_npy_dim(file_name)\n        data = None\n\n\n@pytest.mark.parametrize(\"data_fmt\", [\"numpy\", \"torch\"])\n@pytest.mark.parametrize(\"save_fmt\", [\"numpy\", \"torch\"])\n@pytest.mark.parametrize(\"is_feature\", [True, False])\ndef test_copy_or_convert_data(data_fmt, save_fmt, is_feature):\n    with tempfile.TemporaryDirectory() as test_dir:\n        data = np.arange(10)\n        tensor_data = torch.from_numpy(data)\n        in_type_name = \"npy\" if data_fmt == \"numpy\" else \"pt\"\n        input_path = os.path.join(test_dir, f\"data.{in_type_name}\")\n        out_type_name = \"npy\" if save_fmt == \"numpy\" else \"pt\"\n        output_path = os.path.join(test_dir, f\"out_data.{out_type_name}\")\n        if data_fmt == \"numpy\":\n            np.save(input_path, data)\n        else:\n            torch.save(tensor_data, input_path)\n        if save_fmt == \"torch\":\n            with pytest.raises(AssertionError):\n                internal.copy_or_convert_data(\n                    input_path,\n                    output_path,\n                    data_fmt,\n                    save_fmt,\n                    is_feature=is_feature,\n                )\n        else:\n            internal.copy_or_convert_data(\n                input_path,\n                output_path,\n                data_fmt,\n                save_fmt,\n                is_feature=is_feature,\n            )\n        if is_feature:\n            data = data.reshape(-1, 1)\n            tensor_data = tensor_data.reshape(-1, 1)\n        if save_fmt == \"numpy\":\n            out_data = np.load(output_path)\n            assert (data == out_data).all()\n\n        data = None\n        tensor_data = None\n        out_data = None\n\n\n@pytest.mark.parametrize(\"edge_fmt\", [\"csv\", \"numpy\"])\ndef test_read_edges(edge_fmt):\n    with tempfile.TemporaryDirectory() as test_dir:\n        num_nodes = 40\n        num_edges = 200\n        nodes = np.repeat(np.arange(num_nodes), 5)\n        
neighbors = np.random.randint(0, num_nodes, size=(num_edges))\n        edges = np.stack([nodes, neighbors], axis=1)\n        os.makedirs(os.path.join(test_dir, \"edges\"), exist_ok=True)\n        if edge_fmt == \"csv\":\n            # Wrtie into edges/edge.csv\n            edges = pd.DataFrame(edges, columns=[\"src\", \"dst\"])\n            edge_path = os.path.join(\"edges\", \"edge.csv\")\n            edges.to_csv(\n                os.path.join(test_dir, edge_path),\n                index=False,\n                header=False,\n            )\n        else:\n            # Wrtie into edges/edge.npy\n            edges = edges.T\n            edge_path = os.path.join(\"edges\", \"edge.npy\")\n            np.save(os.path.join(test_dir, edge_path), edges)\n        src, dst = internal.read_edges(test_dir, edge_fmt, edge_path)\n        assert src.all() == nodes.all()\n        assert dst.all() == neighbors.all()\n\n\ndef test_read_edges_error():\n    # 1. Unsupported file format.\n    with pytest.raises(\n        AssertionError,\n        match=\"`numpy` or `csv` is expected when reading edges but got `fake-type`.\",\n    ):\n        internal.read_edges(\"test_dir\", \"fake-type\", \"edge_path\")\n\n    # 2. 
Unexpected shape of numpy array\n    with tempfile.TemporaryDirectory() as test_dir:\n        num_nodes = 40\n        num_edges = 200\n        nodes = np.repeat(np.arange(num_nodes), 5)\n        neighbors = np.random.randint(0, num_nodes, size=(num_edges))\n        edges = np.stack([nodes, neighbors, nodes], axis=1)\n        os.makedirs(os.path.join(test_dir, \"edges\"), exist_ok=True)\n        # Wrtie into edges/edge.npy\n        edges = edges.T\n        edge_path = os.path.join(\"edges\", \"edge.npy\")\n        np.save(os.path.join(test_dir, edge_path), edges)\n        with pytest.raises(\n            AssertionError,\n            match=re.escape(\n                \"The shape of edges should be (2, N), but got torch.Size([3, 200]).\"\n            ),\n        ):\n            internal.read_edges(test_dir, \"numpy\", edge_path)\n\n\ndef test_calculate_file_hash():\n    with tempfile.TemporaryDirectory() as test_dir:\n        test_file_path = os.path.join(test_dir, \"test.txt\")\n        with open(test_file_path, \"w\") as file:\n            file.write(\"test content\")\n        hash_value = internal.calculate_file_hash(\n            test_file_path, hash_algo=\"md5\"\n        )\n        expected_hash_value = \"9473fdd0d880a43c21b7778d34872157\"\n        assert expected_hash_value == hash_value\n        with pytest.raises(\n            ValueError,\n            match=re.escape(\n                \"Hash algorithm must be one of: ['md5', 'sha1', 'sha224', \"\n                + \"'sha256', 'sha384', 'sha512'], but got `fake`.\"\n            ),\n        ):\n            hash_value = internal.calculate_file_hash(\n                test_file_path, hash_algo=\"fake\"\n            )\n\n\ndef test_calculate_dir_hash():\n    with tempfile.TemporaryDirectory() as test_dir:\n        test_file_path_1 = os.path.join(test_dir, \"test_1.txt\")\n        test_file_path_2 = os.path.join(test_dir, \"test_2.txt\")\n        with open(test_file_path_1, \"w\") as file:\n            
file.write(\"test content\")\n        with open(test_file_path_2, \"w\") as file:\n            file.write(\"test contents of directory\")\n        hash_value = internal.calculate_dir_hash(test_dir, hash_algo=\"md5\")\n        expected_hash_value = [\n            \"56e708a2bdf92887d4a7f25cbc13c555\",\n            \"9473fdd0d880a43c21b7778d34872157\",\n        ]\n        assert len(hash_value) == 2\n        for val in hash_value.values():\n            assert val in expected_hash_value\n\n\ndef test_check_dataset_change():\n    with tempfile.TemporaryDirectory() as test_dir:\n        # Generate directory and record its hash value.\n        test_file_path_1 = os.path.join(test_dir, \"test_1.txt\")\n        test_file_path_2 = os.path.join(test_dir, \"test_2.txt\")\n        with open(test_file_path_1, \"w\") as file:\n            file.write(\"test content\")\n        with open(test_file_path_2, \"w\") as file:\n            file.write(\"test contents of directory\")\n        hash_value = internal.calculate_dir_hash(test_dir, hash_algo=\"md5\")\n        hash_value_file = \"dataset_hash_value.txt\"\n        hash_value_file_paht = os.path.join(\n            test_dir, \"preprocessed\", hash_value_file\n        )\n        os.makedirs(os.path.join(test_dir, \"preprocessed\"), exist_ok=True)\n        with open(hash_value_file_paht, \"w\") as file:\n            file.write(json.dumps(hash_value, indent=4))\n\n        # Modify the content of a file.\n        with open(test_file_path_2, \"w\") as file:\n            file.write(\"test contents of directory changed\")\n\n        assert internal.check_dataset_change(test_dir, \"preprocessed\")\n\n\ndef test_numpy_save_aligned():\n    assert_equal = partial(torch.testing.assert_close, rtol=0, atol=0)\n    a = torch.randn(1024, dtype=torch.float32)  # 4096 bytes\n    with tempfile.TemporaryDirectory() as test_dir:\n        aligned_path = os.path.join(test_dir, \"aligned.npy\")\n        gb.numpy_save_aligned(aligned_path, a.numpy())\n\n    
    nonaligned_path = os.path.join(test_dir, \"nonaligned.npy\")\n        np.save(nonaligned_path, a.numpy())\n\n        assert_equal(np.load(aligned_path), np.load(nonaligned_path))\n        # The size of the file should be 4K (aligned header) + 4K (tensor).\n        assert os.path.getsize(aligned_path) == 4096 * 2\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/test_base.py",
    "content": "import os\nimport re\nimport unittest\nfrom collections.abc import Iterable, Mapping\n\nimport backend as F\n\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\nfrom torch.torch_version import TorchVersion\n\nfrom . import gb_test_utils\n\n\ndef test_pytorch_cuda_allocator_conf():\n    env = os.getenv(\"PYTORCH_CUDA_ALLOC_CONF\")\n    assert env is not None\n    config_list = env.split(\",\")\n    assert \"expandable_segments:True\" in config_list\n\n\n@unittest.skipIf(F._default_context_str != \"gpu\", \"CopyTo needs GPU to test\")\n@pytest.mark.parametrize(\"non_blocking\", [False, True])\ndef test_CopyTo(non_blocking):\n    item_sampler = gb.ItemSampler(\n        gb.ItemSet(torch.arange(20), names=\"seeds\"), 4\n    )\n    if non_blocking:\n        item_sampler = item_sampler.transform(lambda x: x.pin_memory())\n\n    # Invoke CopyTo via class constructor.\n    dp = gb.CopyTo(item_sampler, \"cuda\")\n    for data in dp:\n        assert data.seeds.device.type == \"cuda\"\n\n    dp = gb.CopyTo(item_sampler, \"cuda\", non_blocking)\n    for data in dp:\n        assert data.seeds.device.type == \"cuda\"\n\n    # Invoke CopyTo via functional form.\n    dp = item_sampler.copy_to(\"cuda\", non_blocking)\n    for data in dp:\n        assert data.seeds.device.type == \"cuda\"\n\n\n@pytest.mark.parametrize(\n    \"task\",\n    [\n        \"node_classification\",\n        \"node_inference\",\n        \"link_prediction\",\n        \"edge_classification\",\n    ],\n)\n@unittest.skipIf(F._default_context_str == \"cpu\", \"CopyTo needs GPU to test\")\ndef test_CopyToWithMiniBatches(task):\n    N = 16\n    B = 2\n    if task == \"node_classification\":\n        itemset = gb.ItemSet(\n            (torch.arange(N), torch.arange(N)), names=(\"seeds\", \"labels\")\n        )\n    elif task == \"node_inference\":\n        itemset = gb.ItemSet(torch.arange(N), names=\"seeds\")\n    elif task == \"link_prediction\":\n        itemset = gb.ItemSet(\n            
(\n                torch.arange(2 * N).reshape(-1, 2),\n                torch.arange(N),\n            ),\n            names=(\"seeds\", \"labels\"),\n        )\n    elif task == \"edge_classification\":\n        itemset = gb.ItemSet(\n            (torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),\n            names=(\"seeds\", \"labels\"),\n        )\n    graph = gb_test_utils.rand_csc_graph(100, 0.15, bidirection_edge=True)\n\n    features = {}\n    keys = [(\"node\", None, \"a\"), (\"node\", None, \"b\")]\n    features[keys[0]] = gb.TorchBasedFeature(torch.randn(200, 4))\n    features[keys[1]] = gb.TorchBasedFeature(torch.randn(200, 4))\n    feature_store = gb.BasicFeatureStore(features)\n\n    datapipe = gb.ItemSampler(itemset, batch_size=B)\n    datapipe = gb.NeighborSampler(\n        datapipe,\n        graph,\n        fanouts=[torch.LongTensor([2]) for _ in range(2)],\n    )\n    if task != \"node_inference\":\n        datapipe = gb.FeatureFetcher(\n            datapipe,\n            feature_store,\n            [\"a\"],\n        )\n\n    copied_attrs = [\n        \"labels\",\n        \"compacted_seeds\",\n        \"sampled_subgraphs\",\n        \"indexes\",\n        \"node_features\",\n        \"edge_features\",\n        \"blocks\",\n        \"seeds\",\n        \"input_nodes\",\n    ]\n\n    def test_data_device(datapipe):\n        for data in datapipe:\n            for attr in dir(data):\n                var = getattr(data, attr)\n                if isinstance(var, Mapping):\n                    var = var[next(iter(var))]\n                elif isinstance(var, Iterable):\n                    var = next(iter(var))\n                if (\n                    not callable(var)\n                    and not attr.startswith(\"__\")\n                    and hasattr(var, \"device\")\n                    and var is not None\n                ):\n                    if attr in copied_attrs:\n                        assert var.device.type == \"cuda\", attr\n            
        else:\n                        assert var.device.type == \"cpu\", attr\n\n    # Invoke CopyTo via class constructor.\n    test_data_device(gb.CopyTo(datapipe, \"cuda\"))\n\n    # Invoke CopyTo via functional form.\n    test_data_device(datapipe.copy_to(\"cuda\"))\n\n\ndef test_etype_tuple_to_str():\n    \"\"\"Convert etype from tuple to string.\"\"\"\n    # Test for expected input.\n    c_etype = (\"user\", \"like\", \"item\")\n    c_etype_str = gb.etype_tuple_to_str(c_etype)\n    assert c_etype_str == \"user:like:item\"\n\n    # Test for unexpected input: not a tuple.\n    c_etype = \"user:like:item\"\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\n            \"Passed-in canonical etype should be in format of (str, str, str). \"\n            \"But got user:like:item.\"\n        ),\n    ):\n        _ = gb.etype_tuple_to_str(c_etype)\n\n    # Test for unexpected input: tuple with wrong length.\n    c_etype = (\"user\", \"like\")\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\n            \"Passed-in canonical etype should be in format of (str, str, str). \"\n            \"But got ('user', 'like').\"\n        ),\n    ):\n        _ = gb.etype_tuple_to_str(c_etype)\n\n\ndef test_etype_str_to_tuple():\n    \"\"\"Convert etype from string to tuple.\"\"\"\n    # Test for expected input.\n    c_etype_str = \"user:like:item\"\n    c_etype = gb.etype_str_to_tuple(c_etype_str)\n    assert c_etype == (\"user\", \"like\", \"item\")\n\n    # Test for unexpected input: string with wrong format.\n    c_etype_str = \"user:like\"\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\n            \"Passed-in canonical etype should be in format of 'str:str:str'. 
\"\n            \"But got user:like.\"\n        ),\n    ):\n        _ = gb.etype_str_to_tuple(c_etype_str)\n\n\ndef test_seed_type_str_to_ntypes():\n    \"\"\"Convert etype from string to tuple.\"\"\"\n    # Test for node pairs.\n    seed_type_str = \"user:like:item\"\n    seed_size = 2\n    node_type = gb.seed_type_str_to_ntypes(seed_type_str, seed_size)\n    assert node_type == [\"user\", \"item\"]\n\n    # Test for node pairs.\n    seed_type_str = \"user:item:user\"\n    seed_size = 3\n    node_type = gb.seed_type_str_to_ntypes(seed_type_str, seed_size)\n    assert node_type == [\"user\", \"item\", \"user\"]\n\n    # Test for unexpected input: list.\n    seed_type_str = [\"user\", \"item\"]\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\n            \"Passed-in seed type should be string, but got <class 'list'>\"\n        ),\n    ):\n        _ = gb.seed_type_str_to_ntypes(seed_type_str, 2)\n\n\ndef test_isin():\n    elements = torch.tensor([2, 3, 5, 5, 20, 13, 11], device=F.ctx())\n    test_elements = torch.tensor([2, 5], device=F.ctx())\n    res = gb.isin(elements, test_elements)\n    expected = torch.tensor(\n        [True, False, True, True, False, False, False], device=F.ctx()\n    )\n    assert torch.equal(res, expected)\n\n\ndef test_isin_big_data():\n    elements = torch.randint(0, 10000, (10000000,), device=F.ctx())\n    test_elements = torch.randint(0, 10000, (500000,), device=F.ctx())\n    res = gb.isin(elements, test_elements)\n    expected = torch.isin(elements, test_elements)\n    assert torch.equal(res, expected)\n\n\ndef test_isin_non_1D_dim():\n    elements = torch.tensor([[2, 3], [5, 5], [20, 13]], device=F.ctx())\n    test_elements = torch.tensor([2, 5], device=F.ctx())\n    with pytest.raises(Exception):\n        gb.isin(elements, test_elements)\n    elements = torch.tensor([2, 3, 5, 5, 20, 13], device=F.ctx())\n    test_elements = torch.tensor([[2, 5]], device=F.ctx())\n    with pytest.raises(Exception):\n        
gb.isin(elements, test_elements)\n\n\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.bool,\n        torch.uint8,\n        torch.int8,\n        torch.int16,\n        torch.int32,\n        torch.int64,\n        torch.float16,\n        torch.bfloat16,\n        torch.float32,\n        torch.float64,\n    ],\n)\n@pytest.mark.parametrize(\"idtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"pinned\", [False, True])\ndef test_index_select(dtype, idtype, pinned):\n    if F._default_context_str != \"gpu\" and pinned:\n        pytest.skip(\"Pinned tests are available only on GPU.\")\n    tensor = torch.tensor([[2, 3], [5, 5], [20, 13]], dtype=dtype)\n    tensor = tensor.pin_memory() if pinned else tensor.to(F.ctx())\n    index = torch.tensor([0, 2], dtype=idtype, device=F.ctx())\n    gb_result = gb.index_select(tensor, index)\n    torch_result = tensor.to(F.ctx())[index.long()]\n    assert torch.equal(torch_result, gb_result)\n    if pinned:\n        gb_result = gb.index_select(tensor.cpu(), index.cpu().pin_memory())\n        assert torch.equal(torch_result.cpu(), gb_result)\n        assert gb_result.is_pinned()\n\n    # Test the internal async API\n    future = torch.ops.graphbolt.index_select_async(tensor.cpu(), index.cpu())\n    assert torch.equal(torch_result.cpu(), future.wait())\n\n\n@pytest.mark.parametrize(\n    \"dtype\",\n    [\n        torch.bool,\n        torch.uint8,\n        torch.int8,\n        torch.int16,\n        torch.int32,\n        torch.int64,\n        torch.float16,\n        torch.bfloat16,\n        torch.float32,\n        torch.float64,\n    ],\n)\n@pytest.mark.parametrize(\"idtype\", [torch.int32, torch.int64])\ndef test_scatter_async(dtype, idtype):\n    input = torch.tensor([[2, 3], [5, 5], [20, 13]], dtype=dtype)\n    index = torch.ones([1], dtype=idtype)\n    res = torch.ops.graphbolt.scatter_async(input, index, input[2:3])\n    assert torch.equal(\n        torch.tensor([[2, 3], [20, 13], [20, 13]], dtype=dtype), 
res.wait()\n    )\n\n\ndef torch_expand_indptr(indptr, dtype, nodes=None):\n    if nodes is None:\n        nodes = torch.arange(len(indptr) - 1, dtype=dtype, device=indptr.device)\n    return nodes.to(dtype).repeat_interleave(indptr.diff())\n\n\n@pytest.mark.parametrize(\"nodes\", [None, True])\n@pytest.mark.parametrize(\"dtype\", [torch.int32, torch.int64])\ndef test_expand_indptr(nodes, dtype):\n    if nodes:\n        nodes = torch.tensor([1, 7, 3, 4, 5, 8], dtype=dtype, device=F.ctx())\n    indptr = torch.tensor([0, 2, 2, 7, 10, 12, 20], device=F.ctx())\n    torch_result = torch_expand_indptr(indptr, dtype, nodes)\n    gb_result = gb.expand_indptr(indptr, dtype, nodes)\n    assert torch.equal(torch_result, gb_result)\n    gb_result = gb.expand_indptr(indptr, dtype, nodes, indptr[-1].item())\n    assert torch.equal(torch_result, gb_result)\n\n    if TorchVersion(torch.__version__) >= TorchVersion(\"2.2.0a0\"):\n        import torch._dynamo as dynamo\n        from torch.testing._internal.optests import opcheck\n\n        # Tests torch.compile compatibility\n        for output_size in [None, indptr[-1].item()]:\n            kwargs = {\"node_ids\": nodes, \"output_size\": output_size}\n            opcheck(\n                torch.ops.graphbolt.expand_indptr,\n                (indptr, dtype),\n                kwargs,\n                test_utils=[\n                    \"test_schema\",\n                    \"test_autograd_registration\",\n                    \"test_faketensor\",\n                    \"test_aot_dispatch_dynamic\",\n                ],\n                raise_exception=True,\n            )\n\n            explanation = dynamo.explain(gb.expand_indptr)(\n                indptr, dtype, nodes, output_size\n            )\n            expected_breaks = -1 if output_size is None else 0\n            assert explanation.graph_break_count == expected_breaks\n\n\n@unittest.skipIf(\n    F._default_context_str != \"gpu\", \"Only GPU implementation is 
available.\"\n)\n@pytest.mark.parametrize(\"offset\", [None, True])\n@pytest.mark.parametrize(\"dtype\", [torch.int32, torch.int64])\ndef test_indptr_edge_ids(offset, dtype):\n    indptr = torch.tensor([0, 2, 2, 7, 10, 12], device=F.ctx())\n    if offset:\n        offset = indptr[:-1]\n        ref_result = torch.arange(\n            0, indptr[-1].item(), dtype=dtype, device=F.ctx()\n        )\n    else:\n        ref_result = torch.tensor(\n            [0, 1, 0, 1, 2, 3, 4, 0, 1, 2, 0, 1], dtype=dtype, device=F.ctx()\n        )\n    gb_result = gb.indptr_edge_ids(indptr, dtype, offset)\n    assert torch.equal(ref_result, gb_result)\n    gb_result = gb.indptr_edge_ids(indptr, dtype, offset, indptr[-1].item())\n    assert torch.equal(ref_result, gb_result)\n\n    if TorchVersion(torch.__version__) >= TorchVersion(\"2.2.0a0\"):\n        import torch._dynamo as dynamo\n        from torch.testing._internal.optests import opcheck\n\n        # Tests torch.compile compatibility\n        for output_size in [None, indptr[-1].item()]:\n            kwargs = {\"offset\": offset, \"output_size\": output_size}\n            opcheck(\n                torch.ops.graphbolt.indptr_edge_ids,\n                (indptr, dtype),\n                kwargs,\n                test_utils=[\n                    \"test_schema\",\n                    \"test_autograd_registration\",\n                    \"test_faketensor\",\n                    \"test_aot_dispatch_dynamic\",\n                ],\n                raise_exception=True,\n            )\n\n            explanation = dynamo.explain(gb.indptr_edge_ids)(\n                indptr, dtype, offset, output_size\n            )\n            expected_breaks = -1 if output_size is None else 0\n            assert explanation.graph_break_count == expected_breaks\n\n\ndef test_csc_format_base_representation():\n    csc_format_base = gb.CSCFormatBase(\n        indptr=torch.tensor([0, 2, 4]),\n        indices=torch.tensor([4, 5, 6, 7]),\n    )\n    
expected_result = str(\n        \"\"\"CSCFormatBase(indptr=tensor([0, 2, 4]),\n              indices=tensor([4, 5, 6, 7]),\n)\"\"\"\n    )\n    assert str(csc_format_base) == expected_result, print(csc_format_base)\n\n\ndef test_csc_format_base_incorrect_indptr():\n    indptr = torch.tensor([0, 2, 4, 6, 7, 11])\n    indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4])\n    with pytest.raises(AssertionError):\n        # The last element of indptr (11) exceeds the length of indices (10).\n        csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/test_dataloader.py",
    "content": "import os\nimport unittest\nfrom sys import platform\n\nimport backend as F\n\nimport dgl\nimport dgl.graphbolt\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\nimport torch.distributed as thd\n\nfrom dgl.graphbolt.datapipes import find_dps, traverse_dps\n\nfrom . import gb_test_utils\n\n\n@pytest.mark.parametrize(\"overlap_feature_fetch\", [False, True])\ndef test_DataLoader(overlap_feature_fetch):\n    N = 40\n    B = 4\n    itemset = dgl.graphbolt.ItemSet(torch.arange(N), names=\"seeds\")\n    graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)\n    features = {}\n    keys = [(\"node\", None, \"a\"), (\"node\", None, \"b\"), (\"edge\", None, \"c\")]\n    features[keys[0]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))\n    features[keys[1]] = dgl.graphbolt.TorchBasedFeature(torch.randn(200, 4))\n    M = graph.total_num_edges\n    features[keys[2]] = dgl.graphbolt.TorchBasedFeature(torch.randn(M, 1))\n    feature_store = dgl.graphbolt.BasicFeatureStore(features)\n\n    item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)\n    subgraph_sampler = dgl.graphbolt.NeighborSampler(\n        item_sampler,\n        graph,\n        fanouts=[torch.LongTensor([2]) for _ in range(2)],\n    )\n    feature_fetcher = dgl.graphbolt.FeatureFetcher(\n        subgraph_sampler,\n        feature_store,\n        [\"a\", \"b\"],\n        [\"c\"],\n        overlap_fetch=overlap_feature_fetch,\n    )\n    device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())\n\n    dataloader = dgl.graphbolt.DataLoader(\n        device_transferrer,\n        num_workers=4,\n    )\n    for i, minibatch in enumerate(dataloader):\n        assert \"a\" in minibatch.node_features\n        assert \"b\" in minibatch.node_features\n        for layer_id in range(minibatch.num_layers()):\n            assert \"c\" in minibatch.edge_features[layer_id]\n    assert i + 1 == N // B\n\n\n@unittest.skipIf(\n    F._default_context_str != \"gpu\",\n 
   reason=\"This test requires the GPU.\",\n)\n@pytest.mark.parametrize(\n    \"sampler_name\", [\"NeighborSampler\", \"LayerNeighborSampler\"]\n)\n@pytest.mark.parametrize(\"enable_feature_fetch\", [True, False])\n@pytest.mark.parametrize(\"overlap_feature_fetch\", [True, False])\n@pytest.mark.parametrize(\"overlap_graph_fetch\", [True, False])\n@pytest.mark.parametrize(\"cooperative\", [True, False])\n@pytest.mark.parametrize(\"asynchronous\", [True, False])\n@pytest.mark.parametrize(\"num_gpu_cached_edges\", [0, 1024])\n@pytest.mark.parametrize(\"gpu_cache_threshold\", [1, 3])\ndef test_gpu_sampling_DataLoader(\n    sampler_name,\n    enable_feature_fetch,\n    overlap_feature_fetch,\n    overlap_graph_fetch,\n    cooperative,\n    asynchronous,\n    num_gpu_cached_edges,\n    gpu_cache_threshold,\n):\n    if cooperative and not thd.is_initialized():\n        # On Windows, the init method can only be file.\n        init_method = (\n            f\"file:///{os.path.join(os.getcwd(), 'dis_tempfile')}\"\n            if platform == \"win32\"\n            else \"tcp://127.0.0.1:12345\"\n        )\n        thd.init_process_group(\n            init_method=init_method,\n            world_size=1,\n            rank=0,\n        )\n    N = 40\n    B = 4\n    num_layers = 2\n    itemset = dgl.graphbolt.ItemSet(torch.arange(N), names=\"seeds\")\n    graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)\n    graph = graph.pin_memory_() if overlap_graph_fetch else graph.to(F.ctx())\n    features = {}\n    keys = [\n        (\"node\", None, \"a\"),\n        (\"node\", None, \"b\"),\n        (\"node\", None, \"c\"),\n        (\"edge\", None, \"d\"),\n    ]\n    features[keys[0]] = dgl.graphbolt.TorchBasedFeature(\n        torch.randn(200, 4, pin_memory=True)\n    )\n    features[keys[1]] = dgl.graphbolt.TorchBasedFeature(\n        torch.randn(200, 4, pin_memory=True)\n    )\n    features[keys[2]] = dgl.graphbolt.TorchBasedFeature(\n        torch.randn(200, 4, 
device=F.ctx())\n    )\n    features[keys[3]] = dgl.graphbolt.TorchBasedFeature(\n        torch.randn(graph.total_num_edges, 1, device=F.ctx())\n    )\n    feature_store = dgl.graphbolt.BasicFeatureStore(features)\n\n    dataloaders = []\n    for i in range(2):\n        datapipe = dgl.graphbolt.ItemSampler(itemset, batch_size=B)\n        datapipe = datapipe.copy_to(F.ctx())\n        kwargs = {\n            \"overlap_fetch\": overlap_graph_fetch,\n            \"num_gpu_cached_edges\": num_gpu_cached_edges,\n            \"gpu_cache_threshold\": gpu_cache_threshold,\n            \"cooperative\": cooperative,\n            \"asynchronous\": asynchronous,\n        }\n        if i != 0:\n            kwargs = {}\n        datapipe = getattr(dgl.graphbolt, sampler_name)(\n            datapipe,\n            graph,\n            fanouts=[torch.LongTensor([2]) for _ in range(num_layers)],\n            **kwargs,\n        )\n        if enable_feature_fetch:\n            datapipe = dgl.graphbolt.FeatureFetcher(\n                datapipe,\n                feature_store,\n                [\"a\", \"b\", \"c\"],\n                [\"d\"],\n                overlap_fetch=overlap_feature_fetch and i == 0,\n                cooperative=asynchronous and cooperative and i == 0,\n            )\n        dataloaders.append(dgl.graphbolt.DataLoader(datapipe))\n    dataloader, dataloader2 = dataloaders\n\n    bufferer_cnt = int(enable_feature_fetch and overlap_feature_fetch)\n    if overlap_graph_fetch:\n        bufferer_cnt += num_layers\n        if num_gpu_cached_edges > 0:\n            bufferer_cnt += 2 * num_layers\n    if asynchronous:\n        bufferer_cnt += 2 * num_layers + 1  # _preprocess stage has 1.\n        if cooperative:\n            bufferer_cnt += 3 * num_layers\n            if enable_feature_fetch:\n                bufferer_cnt += 1  # feature fetch has 1.\n    if cooperative:\n        # _preprocess stage.\n        bufferer_cnt += 4\n    datapipe_graph = traverse_dps(dataloader)\n 
   bufferers = find_dps(\n        datapipe_graph,\n        dgl.graphbolt.Bufferer,\n    )\n    assert len(bufferers) == bufferer_cnt\n    # Fixes the randomness of LayerNeighborSampler\n    torch.manual_seed(1)\n    minibatches = list(dataloader)\n    assert len(minibatches) == N // B\n\n    for i, _ in enumerate(dataloader):\n        if i >= 1:\n            break\n\n    torch.manual_seed(1)\n\n    for minibatch, minibatch2 in zip(minibatches, dataloader2):\n        if enable_feature_fetch:\n            assert \"a\" in minibatch.node_features\n            assert \"b\" in minibatch.node_features\n            assert \"c\" in minibatch.node_features\n            if sampler_name == \"LayerNeighborSampler\":\n                assert torch.equal(\n                    minibatch.node_features[\"a\"], minibatch2.node_features[\"a\"]\n                )\n            for layer_id in range(minibatch.num_layers()):\n                assert \"d\" in minibatch.edge_features[layer_id]\n                edge_feature = minibatch.edge_features[layer_id][\"d\"]\n                edge_feature_ref = minibatch2.edge_features[layer_id][\"d\"]\n                if sampler_name == \"LayerNeighborSampler\":\n                    assert torch.equal(edge_feature, edge_feature_ref)\n    assert len(list(dataloader)) == N // B\n\n    if asynchronous and cooperative:\n        for minibatch in minibatches:\n            x = torch.ones((minibatch.node_ids().size(0), 1), device=F.ctx())\n            for subgraph in minibatch.sampled_subgraphs:\n                x = gb.CooperativeConvFunction.apply(subgraph, x)\n                x, edge_index, size = subgraph.to_pyg(x)\n                x = x[0]\n                one = torch.ones(\n                    edge_index.shape[1], dtype=x.dtype, device=x.device\n                )\n                coo = torch.sparse_coo_tensor(\n                    edge_index.flipud(), one, size=(size[1], size[0])\n                )\n                x = torch.sparse.mm(coo, x)\n            
assert x.shape[0] == minibatch.seeds.shape[0]\n            assert x.shape[1] == 1\n\n    if thd.is_initialized():\n        thd.destroy_process_group()\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/test_dataset.py",
    "content": "import pytest\n\nfrom dgl import graphbolt as gb\n\n\ndef test_Dataset():\n    dataset = gb.Dataset()\n    with pytest.raises(NotImplementedError):\n        _ = dataset.tasks\n    with pytest.raises(NotImplementedError):\n        _ = dataset.graph\n    with pytest.raises(NotImplementedError):\n        _ = dataset.feature\n    with pytest.raises(NotImplementedError):\n        _ = dataset.dataset_name\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/test_feature_fetcher.py",
    "content": "import random\nfrom functools import partial\n\nimport dgl.graphbolt as gb\nimport torch\nfrom torch.utils.data.datapipes.iter import Mapper\n\nfrom . import gb_test_utils\n\n\ndef test_FeatureFetcher_invoke():\n    # Prepare graph and required datapipes.\n    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)\n    a = torch.tensor(\n        [[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]\n    )\n    b = torch.tensor(\n        [[random.randint(0, 10)] for _ in range(graph.total_num_edges)]\n    )\n\n    features = {}\n    keys = [(\"node\", None, \"a\"), (\"edge\", None, \"b\")]\n    features[keys[0]] = gb.TorchBasedFeature(a)\n    features[keys[1]] = gb.TorchBasedFeature(b)\n    feature_store = gb.BasicFeatureStore(features)\n\n    itemset = gb.ItemSet(torch.arange(10), names=\"seeds\")\n    item_sampler = gb.ItemSampler(itemset, batch_size=2)\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n\n    # Invoke FeatureFetcher via class constructor.\n    datapipe = gb.NeighborSampler(item_sampler, graph, fanouts)\n\n    datapipe = gb.FeatureFetcher(datapipe, feature_store, [\"a\"], [\"b\"])\n    assert len(list(datapipe)) == 5\n\n    # Invoke FeatureFetcher via functional form.\n    datapipe = item_sampler.sample_neighbor(graph, fanouts).fetch_feature(\n        feature_store, [\"a\"], [\"b\"]\n    )\n    assert len(list(datapipe)) == 5\n\n\ndef test_FeatureFetcher_homo():\n    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)\n    a = torch.tensor(\n        [[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]\n    )\n    b = torch.tensor(\n        [[random.randint(0, 10)] for _ in range(graph.total_num_edges)]\n    )\n\n    features = {}\n    keys = [(\"node\", None, \"a\"), (\"edge\", None, \"b\")]\n    features[keys[0]] = gb.TorchBasedFeature(a)\n    features[keys[1]] = gb.TorchBasedFeature(b)\n    feature_store = gb.BasicFeatureStore(features)\n\n    
itemset = gb.ItemSet(torch.arange(10), names=\"seeds\")\n    item_sampler = gb.ItemSampler(itemset, batch_size=2)\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)\n    fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, [\"a\"], [\"b\"])\n\n    assert len(list(fetcher_dp)) == 5\n\n\ndef _func(fn, minibatch):\n    return fn(minibatch)\n\n\ndef test_FeatureFetcher_with_edges_homo():\n    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)\n    a = torch.tensor(\n        [[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]\n    )\n    b = torch.tensor(\n        [[random.randint(0, 10)] for _ in range(graph.total_num_edges)]\n    )\n\n    def add_node_and_edge_ids(minibatch):\n        seeds = minibatch.seeds\n        subgraphs = []\n        for _ in range(3):\n            sampled_csc = gb.CSCFormatBase(\n                indptr=torch.arange(11),\n                indices=torch.arange(10),\n            )\n            subgraphs.append(\n                gb.SampledSubgraphImpl(\n                    sampled_csc=sampled_csc,\n                    original_column_node_ids=torch.arange(10),\n                    original_row_node_ids=torch.arange(10),\n                    original_edge_ids=torch.randint(\n                        0, graph.total_num_edges, (10,)\n                    ),\n                )\n            )\n        data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)\n        return data\n\n    features = {}\n    keys = [(\"node\", None, \"a\"), (\"edge\", None, \"b\")]\n    features[keys[0]] = gb.TorchBasedFeature(a)\n    features[keys[1]] = gb.TorchBasedFeature(b)\n    feature_store = gb.BasicFeatureStore(features)\n\n    itemset = gb.ItemSet(torch.arange(10), names=\"seeds\")\n    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)\n    fn = partial(_func, add_node_and_edge_ids)\n    converter_dp = 
Mapper(item_sampler_dp, fn)\n    fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, [\"a\"], [\"b\"])\n\n    assert len(list(fetcher_dp)) == 5\n    for data in fetcher_dp:\n        assert data.node_features[\"a\"].size(0) == 2\n        assert len(data.edge_features) == 3\n        for edge_feature in data.edge_features:\n            assert edge_feature[\"b\"].size(0) == 10\n\n\ndef get_hetero_graph():\n    # COO graph:\n    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]\n    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]\n    # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> edge type.\n    # num_nodes = 5, num_n1 = 2, num_n2 = 3\n    ntypes = {\"n1\": 0, \"n2\": 1}\n    etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])\n    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])\n    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])\n    node_type_offset = torch.LongTensor([0, 2, 5])\n    return gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    )\n\n\ndef test_FeatureFetcher_hetero():\n    graph = get_hetero_graph()\n    a = torch.tensor([[random.randint(0, 10)] for _ in range(2)])\n    b = torch.tensor([[random.randint(0, 10)] for _ in range(3)])\n\n    features = {}\n    keys = [(\"node\", \"n1\", \"a\"), (\"node\", \"n2\", \"a\")]\n    features[keys[0]] = gb.TorchBasedFeature(a)\n    features[keys[1]] = gb.TorchBasedFeature(b)\n    feature_store = gb.BasicFeatureStore(features)\n\n    itemset = gb.HeteroItemSet(\n        {\n            \"n1\": gb.ItemSet(torch.LongTensor([0, 1]), names=\"seeds\"),\n            \"n2\": gb.ItemSet(torch.LongTensor([0, 1, 2]), names=\"seeds\"),\n        }\n    )\n    item_sampler = gb.ItemSampler(itemset, batch_size=2)\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler_dp = 
gb.NeighborSampler(item_sampler, graph, fanouts)\n    # \"n3\" is not in the sampled input nodes.\n    node_feature_keys = {\"n1\": [\"a\"], \"n2\": [\"a\"], \"n3\": [\"a\"]}\n    fetcher_dp = gb.FeatureFetcher(\n        sampler_dp, feature_store, node_feature_keys=node_feature_keys\n    )\n    assert len(list(fetcher_dp)) == 3\n\n    # Do not fetch feature for \"n1\".\n    node_feature_keys = {\"n2\": [\"a\"]}\n    fetcher_dp = gb.FeatureFetcher(\n        sampler_dp, feature_store, node_feature_keys=node_feature_keys\n    )\n    for mini_batch in fetcher_dp:\n        assert (\"n1\", \"a\") not in mini_batch.node_features\n\n\ndef test_FeatureFetcher_with_edges_hetero():\n    a = torch.tensor([[random.randint(0, 10)] for _ in range(20)])\n    b = torch.tensor([[random.randint(0, 10)] for _ in range(50)])\n\n    def add_node_and_edge_ids(minibatch):\n        seeds = minibatch.seeds\n        subgraphs = []\n        original_edge_ids = {\n            \"n1:e1:n2\": torch.randint(0, 50, (10,)),\n            \"n2:e2:n1\": torch.randint(0, 50, (10,)),\n        }\n        original_column_node_ids = {\n            \"n1\": torch.randint(0, 20, (10,)),\n            \"n2\": torch.randint(0, 20, (10,)),\n        }\n        original_row_node_ids = {\n            \"n1\": torch.randint(0, 20, (10,)),\n            \"n2\": torch.randint(0, 20, (10,)),\n        }\n        for _ in range(3):\n            subgraphs.append(\n                gb.SampledSubgraphImpl(\n                    sampled_csc={\n                        \"n1:e1:n2\": gb.CSCFormatBase(\n                            indptr=torch.arange(11),\n                            indices=torch.arange(10),\n                        ),\n                        \"n2:e2:n1\": gb.CSCFormatBase(\n                            indptr=torch.arange(11),\n                            indices=torch.arange(10),\n                        ),\n                    },\n                    original_column_node_ids=original_column_node_ids,\n             
       original_row_node_ids=original_row_node_ids,\n                    original_edge_ids=original_edge_ids,\n                )\n            )\n        data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)\n        return data\n\n    features = {}\n    keys = [\n        (\"node\", \"n1\", \"a\"),\n        (\"edge\", \"n1:e1:n2\", \"a\"),\n        (\"edge\", \"n2:e2:n1\", \"a\"),\n    ]\n    features[keys[0]] = gb.TorchBasedFeature(a)\n    features[keys[1]] = gb.TorchBasedFeature(b)\n    feature_store = gb.BasicFeatureStore(features)\n\n    itemset = gb.HeteroItemSet(\n        {\n            \"n1\": gb.ItemSet(torch.randint(0, 20, (10,)), names=\"seeds\"),\n        }\n    )\n    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)\n    fn = partial(_func, add_node_and_edge_ids)\n    converter_dp = Mapper(item_sampler_dp, fn)\n    # \"n3:e3:n3\" is not in the sampled edges.\n    # Do not fetch feature for \"n2:e2:n1\".\n    node_feature_keys = {\"n1\": [\"a\"]}\n    edge_feature_keys = {\"n1:e1:n2\": [\"a\"], \"n3:e3:n3\": [\"a\"]}\n    fetcher_dp = gb.FeatureFetcher(\n        converter_dp,\n        feature_store,\n        node_feature_keys=node_feature_keys,\n        edge_feature_keys=edge_feature_keys,\n    )\n\n    assert len(list(fetcher_dp)) == 5\n    for data in fetcher_dp:\n        assert data.node_features[(\"n1\", \"a\")].size(0) == 2\n        assert len(data.edge_features) == 3\n        for edge_feature in data.edge_features:\n            assert edge_feature[(\"n1:e1:n2\", \"a\")].size(0) == 10\n            assert (\"n2:e2:n1\", \"a\") not in edge_feature\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/test_graphbolt_utils.py",
    "content": "import backend as F\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\n\n\ndef test_find_reverse_edges_homo():\n    edges = torch.tensor([[1, 3, 5], [2, 4, 5]]).T\n    edges = gb.add_reverse_edges(edges)\n    expected_edges = torch.tensor([[1, 3, 5, 2, 4, 5], [2, 4, 5, 1, 3, 5]]).T\n    assert torch.equal(edges, expected_edges)\n    assert torch.equal(edges[1], expected_edges[1])\n\n\ndef test_find_reverse_edges_hetero():\n    edges = {\n        \"A:r:B\": torch.tensor([[1, 5], [2, 5]]).T,\n        \"B:rr:A\": torch.tensor([[3], [3]]).T,\n    }\n    edges = gb.add_reverse_edges(edges, {\"A:r:B\": \"B:rr:A\"})\n    expected_edges = {\n        \"A:r:B\": torch.tensor([[1, 5], [2, 5]]).T,\n        \"B:rr:A\": torch.tensor([[3, 2, 5], [3, 1, 5]]).T,\n    }\n    assert torch.equal(edges[\"A:r:B\"], expected_edges[\"A:r:B\"])\n    assert torch.equal(edges[\"B:rr:A\"], expected_edges[\"B:rr:A\"])\n\n\ndef test_find_reverse_edges_bi_reverse_types():\n    edges = {\n        \"A:r:B\": torch.tensor([[1, 5], [2, 5]]).T,\n        \"B:rr:A\": torch.tensor([[3], [3]]).T,\n    }\n    edges = gb.add_reverse_edges(edges, {\"A:r:B\": \"B:rr:A\", \"B:rr:A\": \"A:r:B\"})\n    expected_edges = {\n        \"A:r:B\": torch.tensor([[1, 5, 3], [2, 5, 3]]).T,\n        \"B:rr:A\": torch.tensor([[3, 2, 5], [3, 1, 5]]).T,\n    }\n    assert torch.equal(edges[\"A:r:B\"], expected_edges[\"A:r:B\"])\n    assert torch.equal(edges[\"B:rr:A\"], expected_edges[\"B:rr:A\"])\n\n\ndef test_find_reverse_edges_circual_reverse_types():\n    edges = {\n        \"A:r1:B\": torch.tensor([[1, 1]]),\n        \"B:r2:C\": torch.tensor([[2, 2]]),\n        \"C:r3:A\": torch.tensor([[3, 3]]),\n    }\n    edges = gb.add_reverse_edges(\n        edges, {\"A:r1:B\": \"B:r2:C\", \"B:r2:C\": \"C:r3:A\", \"C:r3:A\": \"A:r1:B\"}\n    )\n    expected_edges = {\n        \"A:r1:B\": torch.tensor([[1, 3], [1, 3]]).T,\n        \"B:r2:C\": torch.tensor([[2, 1], [2, 1]]).T,\n        \"C:r3:A\": 
torch.tensor([[3, 2], [3, 2]]).T,\n    }\n    assert torch.equal(edges[\"A:r1:B\"], expected_edges[\"A:r1:B\"])\n    assert torch.equal(edges[\"B:r2:C\"], expected_edges[\"B:r2:C\"])\n    assert torch.equal(edges[\"A:r1:B\"], expected_edges[\"A:r1:B\"])\n    assert torch.equal(edges[\"C:r3:A\"], expected_edges[\"C:r3:A\"])\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/test_integration.py",
    "content": "import dgl\nimport dgl.graphbolt as gb\nimport dgl.sparse as dglsp\nimport torch\n\n\ndef test_integration_link_prediction():\n    torch.manual_seed(926)\n\n    indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])\n    indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])\n\n    matrix_a = dglsp.from_csc(indptr, indices)\n    seeds = torch.t(torch.stack(matrix_a.coo()))\n    node_feature_data = torch.tensor(\n        [\n            [0.9634, 0.2294],\n            [0.6172, 0.7865],\n            [0.2109, 0.1089],\n            [0.8672, 0.2276],\n            [0.5503, 0.8223],\n            [0.5160, 0.2486],\n        ]\n    )\n    edge_feature_data = torch.tensor(\n        [\n            [0.5123, 0.1709, 0.6150],\n            [0.1476, 0.1902, 0.1314],\n            [0.2582, 0.5203, 0.6228],\n            [0.3708, 0.7631, 0.2683],\n            [0.2126, 0.7878, 0.7225],\n            [0.7885, 0.3414, 0.5485],\n            [0.4088, 0.8200, 0.1851],\n            [0.0056, 0.9469, 0.4432],\n            [0.8972, 0.7511, 0.3617],\n            [0.5773, 0.2199, 0.3366],\n        ]\n    )\n\n    item_set = gb.ItemSet(seeds, names=\"seeds\")\n    graph = gb.fused_csc_sampling_graph(indptr, indices)\n\n    node_feature = gb.TorchBasedFeature(node_feature_data)\n    edge_feature = gb.TorchBasedFeature(edge_feature_data)\n    features = {\n        (\"node\", None, \"feat\"): node_feature,\n        (\"edge\", None, \"feat\"): edge_feature,\n    }\n    feature_store = gb.BasicFeatureStore(features)\n    datapipe = gb.ItemSampler(item_set, batch_size=4)\n    datapipe = datapipe.sample_uniform_negative(graph, 2)\n    fanouts = torch.LongTensor([1])\n    datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)\n    datapipe = datapipe.transform(gb.exclude_seed_edges)\n    datapipe = datapipe.fetch_feature(\n        feature_store, node_feature_keys=[\"feat\"], edge_feature_keys=[\"feat\"]\n    )\n    dataloader = gb.DataLoader(\n        datapipe,\n    )\n    
expected = [\n        str(\n            \"\"\"MiniBatch(seeds=tensor([[5, 1],\n                        [3, 2],\n                        [3, 2],\n                        [3, 3],\n                        [5, 2],\n                        [5, 1],\n                        [3, 4],\n                        [3, 3],\n                        [3, 5],\n                        [3, 2],\n                        [3, 0],\n                        [3, 4]]),\n          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 2, 2], dtype=torch.int32),\n                                                                         indices=tensor([4, 5], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([5, 1, 3, 2, 4, 0]),\n                                               original_edge_ids=tensor([9, 7]),\n                                               original_column_node_ids=tensor([5, 1, 3, 2, 4, 0]),\n                            ),\n                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 2, 2], dtype=torch.int32),\n                                                                         indices=tensor([0, 5], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([5, 1, 3, 2, 4, 0]),\n                                               original_edge_ids=tensor([8, 7]),\n                                               original_column_node_ids=tensor([5, 1, 3, 2, 4, 0]),\n                            )],\n          node_features={'feat': tensor([[0.5160, 0.2486],\n                                [0.6172, 0.7865],\n                                [0.8672, 0.2276],\n                                [0.2109, 0.1089],\n                                [0.5503, 0.8223],\n                                
[0.9634, 0.2294]])},\n          labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),\n          input_nodes=tensor([5, 1, 3, 2, 4, 0]),\n          indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),\n          edge_features=[{'feat': tensor([[0.5773, 0.2199, 0.3366],\n                                [0.0056, 0.9469, 0.4432]])},\n                        {'feat': tensor([[0.8972, 0.7511, 0.3617],\n                                [0.0056, 0.9469, 0.4432]])}],\n          compacted_seeds=tensor([[0, 1],\n                                  [2, 3],\n                                  [2, 3],\n                                  [2, 2],\n                                  [0, 3],\n                                  [0, 1],\n                                  [2, 4],\n                                  [2, 2],\n                                  [2, 0],\n                                  [2, 3],\n                                  [2, 5],\n                                  [2, 4]]),\n          blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),\n                 Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)],\n       )\"\"\"\n        ),\n        str(\n            \"\"\"MiniBatch(seeds=tensor([[3, 3],\n                        [4, 3],\n                        [4, 4],\n                        [0, 4],\n                        [3, 4],\n                        [3, 5],\n                        [4, 1],\n                        [4, 4],\n                        [4, 4],\n                        [4, 5],\n                        [0, 1],\n                        [0, 3]]),\n          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 0, 1], dtype=torch.int32),\n                                                                         indices=tensor([3], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([3, 4, 0, 5, 
1]),\n                                               original_edge_ids=tensor([0]),\n                                               original_column_node_ids=tensor([3, 4, 0, 5, 1]),\n                            ),\n                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2], dtype=torch.int32),\n                                                                         indices=tensor([3, 3], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([3, 4, 0, 5, 1]),\n                                               original_edge_ids=tensor([8, 0]),\n                                               original_column_node_ids=tensor([3, 4, 0, 5, 1]),\n                            )],\n          node_features={'feat': tensor([[0.8672, 0.2276],\n                                [0.5503, 0.8223],\n                                [0.9634, 0.2294],\n                                [0.5160, 0.2486],\n                                [0.6172, 0.7865]])},\n          labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),\n          input_nodes=tensor([3, 4, 0, 5, 1]),\n          indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),\n          edge_features=[{'feat': tensor([[0.5123, 0.1709, 0.6150]])},\n                        {'feat': tensor([[0.8972, 0.7511, 0.3617],\n                                [0.5123, 0.1709, 0.6150]])}],\n          compacted_seeds=tensor([[0, 0],\n                                  [1, 0],\n                                  [1, 1],\n                                  [2, 1],\n                                  [0, 1],\n                                  [0, 3],\n                                  [1, 4],\n                                  [1, 1],\n                                  [1, 1],\n                                  [1, 3],\n                                  [2, 4],\n                                  
[2, 0]]),\n          blocks=[Block(num_src_nodes=5, num_dst_nodes=5, num_edges=1),\n                 Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2)],\n       )\"\"\"\n        ),\n        str(\n            \"\"\"MiniBatch(seeds=tensor([[5, 5],\n                        [4, 5],\n                        [5, 5],\n                        [5, 5],\n                        [4, 0],\n                        [4, 0]]),\n          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1], dtype=torch.int32),\n                                                                         indices=tensor([1], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([5, 4, 0]),\n                                               original_edge_ids=tensor([6]),\n                                               original_column_node_ids=tensor([5, 4, 0]),\n                            ),\n                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1], dtype=torch.int32),\n                                                                         indices=tensor([2], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([5, 4, 0]),\n                                               original_edge_ids=tensor([7]),\n                                               original_column_node_ids=tensor([5, 4, 0]),\n                            )],\n          node_features={'feat': tensor([[0.5160, 0.2486],\n                                [0.5503, 0.8223],\n                                [0.9634, 0.2294]])},\n          labels=tensor([1., 1., 0., 0., 0., 0.]),\n          input_nodes=tensor([5, 4, 0]),\n          indexes=tensor([0, 1, 0, 0, 1, 1]),\n          edge_features=[{'feat': tensor([[0.4088, 0.8200, 0.1851]])},\n                   
     {'feat': tensor([[0.0056, 0.9469, 0.4432]])}],\n          compacted_seeds=tensor([[0, 0],\n                                  [1, 0],\n                                  [0, 0],\n                                  [0, 0],\n                                  [1, 2],\n                                  [1, 2]]),\n          blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=1),\n                 Block(num_src_nodes=3, num_dst_nodes=3, num_edges=1)],\n       )\"\"\"\n        ),\n    ]\n    for step, data in enumerate(dataloader):\n        assert expected[step] == str(data), print(step, data)\n\n\ndef test_integration_node_classification():\n    torch.manual_seed(926)\n\n    indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])\n    indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])\n\n    seeds = torch.tensor([5, 1, 2, 4, 3, 0])\n    node_feature_data = torch.tensor(\n        [\n            [0.9634, 0.2294],\n            [0.6172, 0.7865],\n            [0.2109, 0.1089],\n            [0.8672, 0.2276],\n            [0.5503, 0.8223],\n            [0.5160, 0.2486],\n        ]\n    )\n    edge_feature_data = torch.tensor(\n        [\n            [0.5123, 0.1709, 0.6150],\n            [0.1476, 0.1902, 0.1314],\n            [0.2582, 0.5203, 0.6228],\n            [0.3708, 0.7631, 0.2683],\n            [0.2126, 0.7878, 0.7225],\n            [0.7885, 0.3414, 0.5485],\n            [0.4088, 0.8200, 0.1851],\n            [0.0056, 0.9469, 0.4432],\n            [0.8972, 0.7511, 0.3617],\n            [0.5773, 0.2199, 0.3366],\n        ]\n    )\n\n    item_set = gb.ItemSet(seeds, names=\"seeds\")\n    graph = gb.fused_csc_sampling_graph(indptr, indices)\n\n    node_feature = gb.TorchBasedFeature(node_feature_data)\n    edge_feature = gb.TorchBasedFeature(edge_feature_data)\n    features = {\n        (\"node\", None, \"feat\"): node_feature,\n        (\"edge\", None, \"feat\"): edge_feature,\n    }\n    feature_store = gb.BasicFeatureStore(features)\n    datapipe = 
gb.ItemSampler(item_set, batch_size=2)\n    fanouts = torch.LongTensor([1])\n    datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)\n    datapipe = datapipe.fetch_feature(\n        feature_store, node_feature_keys=[\"feat\"], edge_feature_keys=[\"feat\"]\n    )\n    dataloader = gb.DataLoader(\n        datapipe,\n    )\n    expected = [\n        str(\n            \"\"\"MiniBatch(seeds=tensor([5, 1]),\n          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),\n                                                                         indices=tensor([0, 0], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([5, 1]),\n                                               original_edge_ids=tensor([8, 0]),\n                                               original_column_node_ids=tensor([5, 1]),\n                            ),\n                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),\n                                                                         indices=tensor([0, 0], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([5, 1]),\n                                               original_edge_ids=tensor([8, 0]),\n                                               original_column_node_ids=tensor([5, 1]),\n                            )],\n          node_features={'feat': tensor([[0.5160, 0.2486],\n                                [0.6172, 0.7865]])},\n          labels=None,\n          input_nodes=tensor([5, 1]),\n          indexes=None,\n          edge_features=[{'feat': tensor([[0.8972, 0.7511, 0.3617],\n                                [0.5123, 0.1709, 0.6150]])},\n                        {'feat': tensor([[0.8972, 
0.7511, 0.3617],\n                                [0.5123, 0.1709, 0.6150]])}],\n          compacted_seeds=None,\n          blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2),\n                 Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],\n       )\"\"\"\n        ),\n        str(\n            \"\"\"MiniBatch(seeds=tensor([2, 4]),\n          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),\n                                                                         indices=tensor([2, 1, 2], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([2, 4, 3]),\n                                               original_edge_ids=tensor([1, 6, 3]),\n                                               original_column_node_ids=tensor([2, 4, 3]),\n                            ),\n                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),\n                                                                         indices=tensor([2, 1], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([2, 4, 3]),\n                                               original_edge_ids=tensor([2, 6]),\n                                               original_column_node_ids=tensor([2, 4]),\n                            )],\n          node_features={'feat': tensor([[0.2109, 0.1089],\n                                [0.5503, 0.8223],\n                                [0.8672, 0.2276]])},\n          labels=None,\n          input_nodes=tensor([2, 4, 3]),\n          indexes=None,\n          edge_features=[{'feat': tensor([[0.1476, 0.1902, 0.1314],\n                                [0.4088, 0.8200, 0.1851],\n                                [0.3708, 0.7631, 0.2683]])},\n   
                     {'feat': tensor([[0.2582, 0.5203, 0.6228],\n                                [0.4088, 0.8200, 0.1851]])}],\n          compacted_seeds=None,\n          blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=3),\n                 Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2)],\n       )\"\"\"\n        ),\n        str(\n            \"\"\"MiniBatch(seeds=tensor([3, 0]),\n          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1], dtype=torch.int32),\n                                                                         indices=tensor([0], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([3, 0]),\n                                               original_edge_ids=tensor([3]),\n                                               original_column_node_ids=tensor([3, 0]),\n                            ),\n                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1], dtype=torch.int32),\n                                                                         indices=tensor([0], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([3, 0]),\n                                               original_edge_ids=tensor([3]),\n                                               original_column_node_ids=tensor([3, 0]),\n                            )],\n          node_features={'feat': tensor([[0.8672, 0.2276],\n                                [0.9634, 0.2294]])},\n          labels=None,\n          input_nodes=tensor([3, 0]),\n          indexes=None,\n          edge_features=[{'feat': tensor([[0.3708, 0.7631, 0.2683]])},\n                        {'feat': tensor([[0.3708, 0.7631, 0.2683]])}],\n          compacted_seeds=None,\n          blocks=[Block(num_src_nodes=2, 
num_dst_nodes=2, num_edges=1),\n                 Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)],\n       )\"\"\"\n        ),\n    ]\n    for step, data in enumerate(dataloader):\n        assert expected[step] == str(data), print(step, data)\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/test_item_sampler.py",
    "content": "import os\nimport re\nimport unittest\nfrom collections import defaultdict\nfrom sys import platform\n\nimport backend as F\n\nimport dgl\nimport pytest\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nfrom dgl import graphbolt as gb\n\n\ndef test_ItemSampler_minibatcher():\n    # Default minibatcher is used if not specified.\n    # Warning message is raised if names are not specified.\n    item_set = gb.ItemSet(torch.arange(0, 10))\n    item_sampler = gb.ItemSampler(item_set, batch_size=4)\n    with pytest.warns(\n        UserWarning,\n        match=re.escape(\n            \"Failed to map item list to `MiniBatch` as the names of items are \"\n            \"not provided. Please provide a customized `MiniBatcher`. The \"\n            \"item list is returned as is.\"\n        ),\n    ):\n        minibatch = next(iter(item_sampler))\n        assert not isinstance(minibatch, gb.MiniBatch)\n\n    # Default minibatcher is used if not specified.\n    # Warning message is raised if unrecognized names are specified.\n    item_set = gb.ItemSet(torch.arange(0, 10), names=\"unknown_name\")\n    item_sampler = gb.ItemSampler(item_set, batch_size=4)\n    with pytest.warns(\n        UserWarning,\n        match=re.escape(\n            \"Unknown item name 'unknown_name' is detected and added into \"\n            \"`MiniBatch`. 
You probably need to provide a customized \"\n            \"`MiniBatcher`.\"\n        ),\n    ):\n        minibatch = next(iter(item_sampler))\n        assert isinstance(minibatch, gb.MiniBatch)\n        assert minibatch.unknown_name is not None\n\n    # Default minibatcher is used if not specified.\n    # `MiniBatch` is returned if expected names are specified.\n    item_set = gb.ItemSet(torch.arange(0, 10), names=\"seeds\")\n    item_sampler = gb.ItemSampler(item_set, batch_size=4)\n    minibatch = next(iter(item_sampler))\n    assert isinstance(minibatch, gb.MiniBatch)\n    assert minibatch.seeds is not None\n    assert len(minibatch.seeds) == 4\n\n    # Customized minibatcher is used if specified.\n    def minibatcher(batch, names):\n        return gb.MiniBatch(seeds=batch)\n\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=4, minibatcher=minibatcher\n    )\n    minibatch = next(iter(item_sampler))\n    assert isinstance(minibatch, gb.MiniBatch)\n    assert minibatch.seeds is not None\n    assert len(minibatch.seeds) == 4\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_ItemSet_integer(batch_size, shuffle, drop_last):\n    # Node IDs.\n    num_ids = 103\n    item_set = gb.ItemSet(num_ids, names=\"seeds\")\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    minibatch_ids = []\n    for i, minibatch in enumerate(item_sampler):\n        assert isinstance(minibatch, gb.MiniBatch)\n        assert minibatch.seeds is not None\n        assert minibatch.labels is None\n        is_last = (i + 1) * batch_size >= num_ids\n        if not is_last or num_ids % batch_size == 0:\n            assert len(minibatch.seeds) == batch_size\n        else:\n            if not drop_last:\n                assert len(minibatch.seeds) == num_ids % batch_size\n            else:\n 
               assert False\n        minibatch_ids.append(minibatch.seeds)\n    minibatch_ids = torch.cat(minibatch_ids)\n    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):\n    # Node IDs.\n    num_ids = 103\n    seed_nodes = torch.arange(0, num_ids)\n    item_set = gb.ItemSet(seed_nodes, names=\"seeds\")\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    minibatch_ids = []\n    for i, minibatch in enumerate(item_sampler):\n        assert isinstance(minibatch, gb.MiniBatch)\n        assert minibatch.seeds is not None\n        assert minibatch.labels is None\n        is_last = (i + 1) * batch_size >= num_ids\n        if not is_last or num_ids % batch_size == 0:\n            assert len(minibatch.seeds) == batch_size\n        else:\n            if not drop_last:\n                assert len(minibatch.seeds) == num_ids % batch_size\n            else:\n                assert False\n        minibatch_ids.append(minibatch.seeds)\n    minibatch_ids = torch.cat(minibatch_ids)\n    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):\n    # Node IDs.\n    num_ids = 103\n    seed_nodes = torch.arange(0, num_ids)\n    labels = torch.arange(0, num_ids)\n    item_set = gb.ItemSet((seed_nodes, labels), names=(\"seeds\", \"labels\"))\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    minibatch_ids = []\n    minibatch_labels = []\n    
for i, minibatch in enumerate(item_sampler):\n        assert isinstance(minibatch, gb.MiniBatch)\n        assert minibatch.seeds is not None\n        assert minibatch.labels is not None\n        assert len(minibatch.seeds) == len(minibatch.labels)\n        is_last = (i + 1) * batch_size >= num_ids\n        if not is_last or num_ids % batch_size == 0:\n            assert len(minibatch.seeds) == batch_size\n        else:\n            if not drop_last:\n                assert len(minibatch.seeds) == num_ids % batch_size\n            else:\n                assert False\n        minibatch_ids.append(minibatch.seeds)\n        minibatch_labels.append(minibatch.labels)\n    minibatch_ids = torch.cat(minibatch_ids)\n    minibatch_labels = torch.cat(minibatch_labels)\n    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle\n    assert (\n        torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]) is not shuffle\n    )\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_ItemSet_node_pairs(batch_size, shuffle, drop_last):\n    # Node pairs.\n    num_ids = 103\n    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)\n    item_set = gb.ItemSet(node_pairs, names=\"seeds\")\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    src_ids = []\n    dst_ids = []\n    for i, minibatch in enumerate(item_sampler):\n        assert minibatch.seeds is not None\n        assert isinstance(minibatch.seeds, torch.Tensor)\n        assert minibatch.labels is None\n        src, dst = minibatch.seeds.T\n        is_last = (i + 1) * batch_size >= num_ids\n        if not is_last or num_ids % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = num_ids % batch_size\n            else:\n    
            assert False\n        assert len(src) == expected_batch_size\n        assert len(dst) == expected_batch_size\n        # Verify src and dst IDs match.\n        assert torch.equal(src + 1, dst)\n        # Archive batch.\n        src_ids.append(src)\n        dst_ids.append(dst)\n    src_ids = torch.cat(src_ids)\n    dst_ids = torch.cat(dst_ids)\n    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle\n    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):\n    # Node pairs and labels\n    num_ids = 103\n    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)\n    labels = node_pairs[:, 0]\n    item_set = gb.ItemSet((node_pairs, labels), names=(\"seeds\", \"labels\"))\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    src_ids = []\n    dst_ids = []\n    labels = []\n    for i, minibatch in enumerate(item_sampler):\n        assert minibatch.seeds is not None\n        assert isinstance(minibatch.seeds, torch.Tensor)\n        assert minibatch.labels is not None\n        src, dst = minibatch.seeds.T\n        label = minibatch.labels\n        assert len(src) == len(dst)\n        assert len(src) == len(label)\n        is_last = (i + 1) * batch_size >= num_ids\n        if not is_last or num_ids % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = num_ids % batch_size\n            else:\n                assert False\n        assert len(src) == expected_batch_size\n        assert len(dst) == expected_batch_size\n        assert len(label) == expected_batch_size\n        # Verify src/dst IDs and labels match.\n        assert torch.equal(src + 1, 
dst)\n        assert torch.equal(src, label)\n        # Archive batch.\n        src_ids.append(src)\n        dst_ids.append(dst)\n        labels.append(label)\n    src_ids = torch.cat(src_ids)\n    dst_ids = torch.cat(dst_ids)\n    labels = torch.cat(labels)\n    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle\n    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle\n    assert torch.all(labels[:-1] <= labels[1:]) is not shuffle\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_ItemSet_node_pairs_labels_indexes(batch_size, shuffle, drop_last):\n    # Node pairs and negative destinations.\n    num_ids = 103\n    num_negs = 2\n    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)\n    neg_srcs = node_pairs[:, 0].repeat_interleave(num_negs)\n    neg_dsts = torch.arange(2 * num_ids, 2 * num_ids + num_ids * num_negs)\n    neg_node_pairs = torch.cat((neg_srcs, neg_dsts)).reshape(2, -1).T\n    labels = torch.empty(num_ids * 3)\n    labels[:num_ids] = 1\n    labels[num_ids:] = 0\n    indexes = torch.cat(\n        (\n            torch.arange(0, num_ids),\n            torch.arange(0, num_ids).repeat_interleave(num_negs),\n        )\n    )\n    node_pairs = torch.cat((node_pairs, neg_node_pairs))\n    item_set = gb.ItemSet(\n        (node_pairs, labels, indexes), names=(\"seeds\", \"labels\", \"indexes\")\n    )\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    src_ids = []\n    dst_ids = []\n    negs_ids = []\n    final_labels = []\n    final_indexes = []\n    for i, minibatch in enumerate(item_sampler):\n        assert minibatch.seeds is not None\n        assert isinstance(minibatch.seeds, torch.Tensor)\n        assert minibatch.labels is not None\n        assert minibatch.indexes is not None\n        src, dst = minibatch.seeds.T\n        negs_src 
= src[~minibatch.labels.to(bool)]\n        negs_dst = dst[~minibatch.labels.to(bool)]\n        is_last = (i + 1) * batch_size >= num_ids * 3\n        if not is_last or num_ids * 3 % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = num_ids * 3 % batch_size\n            else:\n                assert False\n        assert len(src) == expected_batch_size\n        assert len(dst) == expected_batch_size\n        assert negs_src.dim() == 1\n        assert negs_dst.dim() == 1\n        assert torch.equal((negs_dst - 2 * num_ids) // 2 * 2, negs_src)\n        # Archive batch.\n        src_ids.append(src)\n        dst_ids.append(dst)\n        negs_ids.append(negs_dst)\n        final_labels.append(minibatch.labels)\n        final_indexes.append(minibatch.indexes)\n    src_ids = torch.cat(src_ids)\n    dst_ids = torch.cat(dst_ids)\n    negs_ids = torch.cat(negs_ids)\n    final_labels = torch.cat(final_labels)\n    final_indexes = torch.cat(final_indexes)\n    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle\n    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle\n    assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle\n    assert torch.all(final_labels[:-1] >= final_labels[1:]) is not shuffle\n    if not drop_last:\n        assert final_labels.sum() == num_ids\n        assert torch.equal(final_indexes, indexes) is not shuffle\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_ItemSet_hyperlink(batch_size, shuffle, drop_last):\n    # Node pairs.\n    num_ids = 103\n    seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3)\n    item_set = gb.ItemSet(seeds, names=\"seeds\")\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    seeds_ids = []\n    for i, 
minibatch in enumerate(item_sampler):\n        assert minibatch.seeds is not None\n        assert isinstance(minibatch.seeds, torch.Tensor)\n        assert minibatch.labels is None\n        is_last = (i + 1) * batch_size >= num_ids\n        if not is_last or num_ids % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = num_ids % batch_size\n            else:\n                assert False\n        assert minibatch.seeds.shape == (expected_batch_size, 3)\n        # Verify seeds match.\n        assert torch.equal(minibatch.seeds[:, 0] + 1, minibatch.seeds[:, 1])\n        assert torch.equal(minibatch.seeds[:, 1] + 1, minibatch.seeds[:, 2])\n        # Archive batch.\n        seeds_ids.append(minibatch.seeds)\n    seeds_ids = torch.cat(seeds_ids)\n    assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle\n    assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle\n    assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_ItemSet_seeds_labels(batch_size, shuffle, drop_last):\n    # Node pairs and labels\n    num_ids = 103\n    seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3)\n    labels = seeds[:, 0]\n    item_set = gb.ItemSet((seeds, labels), names=(\"seeds\", \"labels\"))\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    seeds_ids = []\n    labels = []\n    for i, minibatch in enumerate(item_sampler):\n        assert minibatch.seeds is not None\n        assert isinstance(minibatch.seeds, torch.Tensor)\n        assert minibatch.labels is not None\n        label = minibatch.labels\n        assert len(minibatch.seeds) == len(label)\n        is_last = (i + 1) * batch_size >= 
num_ids\n        if not is_last or num_ids % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = num_ids % batch_size\n            else:\n                assert False\n        assert minibatch.seeds.shape == (expected_batch_size, 3)\n        assert len(label) == expected_batch_size\n        # Verify seeds and labels match.\n        assert torch.equal(minibatch.seeds[:, 0] + 1, minibatch.seeds[:, 1])\n        assert torch.equal(minibatch.seeds[:, 1] + 1, minibatch.seeds[:, 2])\n        # Archive batch.\n        seeds_ids.append(minibatch.seeds)\n        labels.append(label)\n    seeds_ids = torch.cat(seeds_ids)\n    labels = torch.cat(labels)\n    assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle\n    assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle\n    assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle\n    assert torch.all(labels[:-1] <= labels[1:]) is not shuffle\n\n\ndef test_append_with_other_datapipes():\n    num_ids = 100\n    batch_size = 4\n    item_set = gb.ItemSet(torch.arange(0, num_ids), names=\"seeds\")\n    data_pipe = gb.ItemSampler(item_set, batch_size)\n    for i, data in enumerate(data_pipe):\n        expected = torch.full((batch_size,), i * batch_size)\n        expected = expected + torch.tensor([0, 1, 2, 3])\n        assert torch.equal(data.seeds, expected)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_HeteroItemSet_seed_nodes(batch_size, shuffle, drop_last):\n    # Node IDs.\n    num_ids = 205\n    ids = {\n        \"user\": gb.ItemSet(torch.arange(0, 99), names=\"seeds\"),\n        \"item\": gb.ItemSet(torch.arange(99, num_ids), names=\"seeds\"),\n    }\n    chained_ids = []\n    for key, value in ids.items():\n        chained_ids += [(key, v) for v in value]\n 
   item_set = gb.HeteroItemSet(ids)\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    minibatch_ids = []\n    for i, minibatch in enumerate(item_sampler):\n        is_last = (i + 1) * batch_size >= num_ids\n        if not is_last or num_ids % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = num_ids % batch_size\n            else:\n                assert False\n        assert isinstance(minibatch, gb.MiniBatch)\n        assert minibatch.seeds is not None\n        ids = []\n        for _, v in minibatch.seeds.items():\n            ids.append(v)\n        ids = torch.cat(ids)\n        assert len(ids) == expected_batch_size\n        minibatch_ids.append(ids)\n    minibatch_ids = torch.cat(minibatch_ids)\n    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_HeteroItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):\n    # Node IDs.\n    num_ids = 205\n    ids = {\n        \"user\": gb.ItemSet(\n            (torch.arange(0, 99), torch.arange(0, 99)),\n            names=(\"seeds\", \"labels\"),\n        ),\n        \"item\": gb.ItemSet(\n            (torch.arange(99, num_ids), torch.arange(99, num_ids)),\n            names=(\"seeds\", \"labels\"),\n        ),\n    }\n    chained_ids = []\n    for key, value in ids.items():\n        chained_ids += [(key, v) for v in value]\n    item_set = gb.HeteroItemSet(ids)\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    minibatch_ids = []\n    minibatch_labels = []\n    for i, minibatch in enumerate(item_sampler):\n        assert isinstance(minibatch, gb.MiniBatch)\n        assert 
minibatch.seeds is not None\n        assert minibatch.labels is not None\n        is_last = (i + 1) * batch_size >= num_ids\n        if not is_last or num_ids % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = num_ids % batch_size\n            else:\n                assert False\n        ids = []\n        for _, v in minibatch.seeds.items():\n            ids.append(v)\n        ids = torch.cat(ids)\n        assert len(ids) == expected_batch_size\n        minibatch_ids.append(ids)\n        labels = []\n        for _, v in minibatch.labels.items():\n            labels.append(v)\n        labels = torch.cat(labels)\n        assert len(labels) == expected_batch_size\n        minibatch_labels.append(labels)\n    minibatch_ids = torch.cat(minibatch_ids)\n    minibatch_labels = torch.cat(minibatch_labels)\n    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle\n    assert (\n        torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]) is not shuffle\n    )\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_HeteroItemSet_node_pairs(batch_size, shuffle, drop_last):\n    # Node pairs.\n    num_ids = 103\n    total_pairs = 2 * num_ids\n    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)\n    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)\n    node_pairs_dict = {\n        \"user:like:item\": gb.ItemSet(node_pairs_like, names=\"seeds\"),\n        \"user:follow:user\": gb.ItemSet(node_pairs_follow, names=\"seeds\"),\n    }\n    item_set = gb.HeteroItemSet(node_pairs_dict)\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    src_ids = []\n    dst_ids = []\n    for i, minibatch in enumerate(item_sampler):\n        assert 
isinstance(minibatch, gb.MiniBatch)\n        assert minibatch.seeds is not None\n        assert minibatch.labels is None\n        is_last = (i + 1) * batch_size >= total_pairs\n        if not is_last or total_pairs % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = total_pairs % batch_size\n            else:\n                assert False\n        src = []\n        dst = []\n        for _, (seeds) in minibatch.seeds.items():\n            assert isinstance(seeds, torch.Tensor)\n            src.append(seeds[:, 0])\n            dst.append(seeds[:, 1])\n        src = torch.cat(src)\n        dst = torch.cat(dst)\n        assert len(src) == expected_batch_size\n        assert len(dst) == expected_batch_size\n        src_ids.append(src)\n        dst_ids.append(dst)\n        assert torch.equal(src + 1, dst)\n    src_ids = torch.cat(src_ids)\n    dst_ids = torch.cat(dst_ids)\n    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle\n    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_HeteroItemSet_node_pairs_labels(batch_size, shuffle, drop_last):\n    # Node pairs and labels\n    num_ids = 103\n    total_ids = 2 * num_ids\n    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)\n    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)\n    labels = torch.arange(0, num_ids)\n    node_pairs_dict = {\n        \"user:like:item\": gb.ItemSet(\n            (node_pairs_like, node_pairs_like[:, 0]),\n            names=(\"seeds\", \"labels\"),\n        ),\n        \"user:follow:user\": gb.ItemSet(\n            (node_pairs_follow, node_pairs_follow[:, 0]),\n            names=(\"seeds\", \"labels\"),\n        ),\n    }\n    item_set = 
gb.HeteroItemSet(node_pairs_dict)\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    src_ids = []\n    dst_ids = []\n    labels = []\n    for i, minibatch in enumerate(item_sampler):\n        assert isinstance(minibatch, gb.MiniBatch)\n        assert minibatch.seeds is not None\n        assert minibatch.labels is not None\n        is_last = (i + 1) * batch_size >= total_ids\n        if not is_last or total_ids % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = total_ids % batch_size\n            else:\n                assert False\n        src = []\n        dst = []\n        label = []\n        for _, seeds in minibatch.seeds.items():\n            assert isinstance(seeds, torch.Tensor)\n            src.append(seeds[:, 0])\n            dst.append(seeds[:, 1])\n        for _, v_label in minibatch.labels.items():\n            label.append(v_label)\n        src = torch.cat(src)\n        dst = torch.cat(dst)\n        label = torch.cat(label)\n        assert len(src) == expected_batch_size\n        assert len(dst) == expected_batch_size\n        assert len(label) == expected_batch_size\n        src_ids.append(src)\n        dst_ids.append(dst)\n        labels.append(label)\n        assert torch.equal(src + 1, dst)\n        assert torch.equal(src, label)\n    src_ids = torch.cat(src_ids)\n    dst_ids = torch.cat(dst_ids)\n    labels = torch.cat(labels)\n    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle\n    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle\n    assert torch.all(labels[:-1] <= labels[1:]) is not shuffle\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_HeteroItemSet_node_pairs_labels_indexes(\n    batch_size, shuffle, drop_last\n):\n    # 
Head, tail and negative tails.\n    num_ids = 103\n    total_ids = 6 * num_ids\n    num_negs = 2\n    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)\n    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)\n    neg_dsts_like = torch.arange(num_ids * 4, num_ids * 4 + num_ids * num_negs)\n    neg_node_pairs_like = (\n        torch.cat(\n            (node_pairs_like[:, 0].repeat_interleave(num_negs), neg_dsts_like)\n        )\n        .view(2, -1)\n        .T\n    )\n    all_node_pairs_like = torch.cat((node_pairs_like, neg_node_pairs_like))\n    labels_like = torch.empty(num_ids * 3)\n    labels_like[:num_ids] = 1\n    labels_like[num_ids:] = 0\n    indexes_like = torch.cat(\n        (\n            torch.arange(0, num_ids),\n            torch.arange(0, num_ids).repeat_interleave(num_negs),\n        )\n    )\n    neg_dsts_follow = torch.arange(\n        num_ids * 4 + num_ids * num_negs, num_ids * 4 + num_ids * num_negs * 2\n    )\n    neg_node_pairs_follow = (\n        torch.cat(\n            (\n                node_pairs_follow[:, 0].repeat_interleave(num_negs),\n                neg_dsts_follow,\n            )\n        )\n        .view(2, -1)\n        .T\n    )\n    all_node_pairs_follow = torch.cat(\n        (node_pairs_follow, neg_node_pairs_follow)\n    )\n    labels_follow = torch.empty(num_ids * 3)\n    labels_follow[:num_ids] = 1\n    labels_follow[num_ids:] = 0\n    indexes_follow = torch.cat(\n        (\n            torch.arange(0, num_ids),\n            torch.arange(0, num_ids).repeat_interleave(num_negs),\n        )\n    )\n    data_dict = {\n        \"user:like:item\": gb.ItemSet(\n            (all_node_pairs_like, labels_like, indexes_like),\n            names=(\"seeds\", \"labels\", \"indexes\"),\n        ),\n        \"user:follow:user\": gb.ItemSet(\n            (all_node_pairs_follow, labels_follow, indexes_follow),\n            names=(\"seeds\", \"labels\", \"indexes\"),\n        ),\n    }\n    item_set = 
gb.HeteroItemSet(data_dict)\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    src_ids = []\n    dst_ids = []\n    negs_ids = []\n    final_labels = defaultdict(list)\n    final_indexes = defaultdict(list)\n    for i, minibatch in enumerate(item_sampler):\n        assert isinstance(minibatch, gb.MiniBatch)\n        assert minibatch.seeds is not None\n        assert minibatch.labels is not None\n        assert minibatch.indexes is not None\n        is_last = (i + 1) * batch_size >= total_ids\n        if not is_last or total_ids % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = total_ids % batch_size\n            else:\n                assert False\n        src = []\n        dst = []\n        negs_src = []\n        negs_dst = []\n        for etype, seeds in minibatch.seeds.items():\n            assert isinstance(seeds, torch.Tensor)\n            src_etype = seeds[:, 0]\n            dst_etype = seeds[:, 1]\n            src.append(src_etype)\n            dst.append(dst_etype)\n            negs_src.append(src_etype[~minibatch.labels[etype].to(bool)])\n            negs_dst.append(dst_etype[~minibatch.labels[etype].to(bool)])\n            final_labels[etype].append(minibatch.labels[etype])\n            final_indexes[etype].append(minibatch.indexes[etype])\n        src = torch.cat(src)\n        dst = torch.cat(dst)\n        negs_src = torch.cat(negs_src)\n        negs_dst = torch.cat(negs_dst)\n        assert len(src) == expected_batch_size\n        assert len(dst) == expected_batch_size\n        src_ids.append(src)\n        dst_ids.append(dst)\n        negs_ids.append(negs_dst)\n        assert negs_src.dim() == 1\n        assert negs_dst.dim() == 1\n        assert torch.equal(negs_src, (negs_dst - num_ids * 4) // 2 * 2)\n    src_ids = torch.cat(src_ids)\n    dst_ids = torch.cat(dst_ids)\n    
negs_ids = torch.cat(negs_ids)\n    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle\n    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle\n    assert torch.all(negs_ids <= negs_ids) is not shuffle\n    for etype in data_dict.keys():\n        final_labels_etype = torch.cat(final_labels[etype])\n        final_indexes_etype = torch.cat(final_indexes[etype])\n        assert (\n            torch.all(final_labels_etype[:-1] >= final_labels_etype[1:])\n            is not shuffle\n        )\n        if not drop_last:\n            assert final_labels_etype.sum() == num_ids\n            assert (\n                torch.equal(final_indexes_etype, indexes_follow) is not shuffle\n            )\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_HeteroItemSet_hyperlink(batch_size, shuffle, drop_last):\n    # Node pairs.\n    num_ids = 103\n    total_pairs = 2 * num_ids\n    seeds_like = torch.arange(0, num_ids * 3).reshape(-1, 3)\n    seeds_follow = torch.arange(num_ids * 3, num_ids * 6).reshape(-1, 3)\n    seeds_dict = {\n        \"user:like:item\": gb.ItemSet(seeds_like, names=\"seeds\"),\n        \"user:follow:user\": gb.ItemSet(seeds_follow, names=\"seeds\"),\n    }\n    item_set = gb.HeteroItemSet(seeds_dict)\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    seeds_ids = []\n    for i, minibatch in enumerate(item_sampler):\n        assert isinstance(minibatch, gb.MiniBatch)\n        assert minibatch.seeds is not None\n        assert minibatch.labels is None\n        assert minibatch.indexes is None\n        is_last = (i + 1) * batch_size >= total_pairs\n        if not is_last or total_pairs % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = total_pairs % 
batch_size\n            else:\n                assert False\n        seeds_lst = []\n        for _, (seeds) in minibatch.seeds.items():\n            assert isinstance(seeds, torch.Tensor)\n            seeds_lst.append(seeds)\n        seeds_lst = torch.cat(seeds_lst)\n        assert seeds_lst.shape == (expected_batch_size, 3)\n        seeds_ids.append(seeds_lst)\n        assert torch.equal(seeds_lst[:, 0] + 1, seeds_lst[:, 1])\n        assert torch.equal(seeds_lst[:, 1] + 1, seeds_lst[:, 2])\n    seeds_ids = torch.cat(seeds_ids)\n    assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle\n    assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle\n    assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 4])\n@pytest.mark.parametrize(\"shuffle\", [True, False])\n@pytest.mark.parametrize(\"drop_last\", [True, False])\ndef test_HeteroItemSet_hyperlink_labels(batch_size, shuffle, drop_last):\n    # Node pairs and labels\n    num_ids = 103\n    total_ids = 2 * num_ids\n    seeds_like = torch.arange(0, num_ids * 3).reshape(-1, 3)\n    seeds_follow = torch.arange(num_ids * 3, num_ids * 6).reshape(-1, 3)\n    seeds_dict = {\n        \"user:like:item\": gb.ItemSet(\n            (seeds_like, seeds_like[:, 0]),\n            names=(\"seeds\", \"labels\"),\n        ),\n        \"user:follow:user\": gb.ItemSet(\n            (seeds_follow, seeds_follow[:, 0]),\n            names=(\"seeds\", \"labels\"),\n        ),\n    }\n    item_set = gb.HeteroItemSet(seeds_dict)\n    item_sampler = gb.ItemSampler(\n        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n    )\n    seeds_ids = []\n    labels = []\n    for i, minibatch in enumerate(item_sampler):\n        assert isinstance(minibatch, gb.MiniBatch)\n        assert minibatch.seeds is not None\n        assert minibatch.labels is not None\n        assert minibatch.indexes is None\n        is_last = (i + 1) * 
batch_size >= total_ids\n        if not is_last or total_ids % batch_size == 0:\n            expected_batch_size = batch_size\n        else:\n            if not drop_last:\n                expected_batch_size = total_ids % batch_size\n            else:\n                assert False\n        seeds_lst = []\n        label = []\n        for _, seeds in minibatch.seeds.items():\n            assert isinstance(seeds, torch.Tensor)\n            seeds_lst.append(seeds)\n        for _, v_label in minibatch.labels.items():\n            label.append(v_label)\n        seeds_lst = torch.cat(seeds_lst)\n        label = torch.cat(label)\n        assert seeds_lst.shape == (expected_batch_size, 3)\n        assert len(label) == expected_batch_size\n        seeds_ids.append(seeds_lst)\n        labels.append(label)\n        assert torch.equal(seeds_lst[:, 0] + 1, seeds_lst[:, 1])\n        assert torch.equal(seeds_lst[:, 1] + 1, seeds_lst[:, 2])\n        assert torch.equal(seeds_lst[:, 0], label)\n    seeds_ids = torch.cat(seeds_ids)\n    labels = torch.cat(labels)\n    assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle\n    assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle\n    assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle\n    assert torch.all(labels[:-1] <= labels[1:]) is not shuffle\n\n\ndef distributed_item_sampler_subprocess(\n    proc_id,\n    nprocs,\n    item_set,\n    num_ids,\n    num_workers,\n    batch_size,\n    drop_last,\n    drop_uneven_inputs,\n):\n    # On Windows, the init method can only be file.\n    init_method = (\n        f\"file:///{os.path.join(os.getcwd(), 'dis_tempfile')}\"\n        if platform == \"win32\"\n        else \"tcp://127.0.0.1:12345\"\n    )\n    dist.init_process_group(\n        backend=\"gloo\",  # Use Gloo backend for CPU multiprocessing\n        init_method=init_method,\n        world_size=nprocs,\n        rank=proc_id,\n    )\n\n    # Create a DistributedItemSampler.\n    
item_sampler = gb.DistributedItemSampler(\n        item_set,\n        batch_size=batch_size,\n        shuffle=True,\n        drop_last=drop_last,\n        drop_uneven_inputs=drop_uneven_inputs,\n    )\n    feature_fetcher = gb.FeatureFetcher(\n        item_sampler,\n        gb.BasicFeatureStore({}),\n        [],\n    )\n    data_loader = gb.DataLoader(feature_fetcher, num_workers=num_workers)\n\n    # Count the numbers of items and batches.\n    num_items = 0\n    sampled_count = torch.zeros(num_ids, dtype=torch.int32)\n    for i in data_loader:\n        # Count how many times each item is sampled.\n        sampled_count[i.seeds] += 1\n        if drop_last:\n            assert i.seeds.size(0) == batch_size\n        num_items += i.seeds.size(0)\n    num_batches = len(list(item_sampler))\n\n    if drop_uneven_inputs:\n        num_batches_tensor = torch.tensor(num_batches)\n        dist.broadcast(num_batches_tensor, 0)\n        # Test if the number of batches are the same for all processes.\n        assert num_batches_tensor == num_batches\n\n    # Add up results from all processes.\n    dist.reduce(sampled_count, 0)\n\n    try:\n        # Make sure no item is sampled more than once.\n        assert sampled_count.max() <= 1\n    finally:\n        dist.destroy_process_group()\n\n\n@pytest.mark.parametrize(\n    \"params\",\n    [\n        ((24, 4, 0, 4, False, False), [(8, 8), (8, 8), (4, 4), (4, 4)]),\n        ((30, 4, 0, 4, False, False), [(8, 8), (8, 8), (8, 8), (6, 6)]),\n        ((30, 4, 0, 4, True, False), [(8, 8), (8, 8), (8, 8), (6, 4)]),\n        ((30, 4, 0, 4, False, True), [(8, 8), (8, 8), (8, 8), (6, 6)]),\n        ((30, 4, 0, 4, True, True), [(8, 4), (8, 4), (8, 4), (6, 4)]),\n        (\n            (53, 4, 2, 4, False, False),\n            [(8, 8), (8, 8), (8, 8), (5, 5), (8, 8), (4, 4), (8, 8), (4, 4)],\n        ),\n        (\n            (53, 4, 2, 4, True, False),\n            [(8, 8), (8, 8), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],\n        
),\n        (\n            (53, 4, 2, 4, False, True),\n            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],\n        ),\n        (\n            (53, 4, 2, 4, True, True),\n            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],\n        ),\n        (\n            (63, 4, 2, 4, False, False),\n            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],\n        ),\n        (\n            (63, 4, 2, 4, True, False),\n            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (10, 8), (5, 4)],\n        ),\n        (\n            (63, 4, 2, 4, False, True),\n            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],\n        ),\n        (\n            (63, 4, 2, 4, True, True),\n            [\n                (10, 8),\n                (6, 4),\n                (10, 8),\n                (6, 4),\n                (10, 8),\n                (6, 4),\n                (10, 8),\n                (5, 4),\n            ],\n        ),\n        (\n            (65, 4, 2, 4, False, False),\n            [(9, 9), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],\n        ),\n        (\n            (65, 4, 2, 4, True, True),\n            [(9, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],\n        ),\n    ],\n)\ndef test_RangeCalculation(params):\n    (\n        (\n            total,\n            num_replicas,\n            num_workers,\n            batch_size,\n            drop_last,\n            drop_uneven_inputs,\n        ),\n        key,\n    ) = params\n    answer = []\n    sum = 0\n    for rank in range(num_replicas):\n        for worker_id in range(max(num_workers, 1)):\n            result = gb.internal.calculate_range(\n                True,\n                total,\n                num_replicas,\n                rank,\n                num_workers,\n                worker_id,\n                batch_size,\n                drop_last,\n                drop_uneven_inputs,\n      
      )\n            assert sum == result[0]\n            sum += result[1]\n            answer.append((result[1], result[2]))\n    assert key == answer\n\n\n@unittest.skipIf(F._default_context_str != \"cpu\", reason=\"GPU not required.\")\n@pytest.mark.parametrize(\"num_ids\", [24, 30, 32, 34, 36])\n@pytest.mark.parametrize(\"num_workers\", [0, 2])\n@pytest.mark.parametrize(\"drop_last\", [False, True])\n@pytest.mark.parametrize(\"drop_uneven_inputs\", [False, True])\ndef test_DistributedItemSampler(\n    num_ids, num_workers, drop_last, drop_uneven_inputs\n):\n    nprocs = 4\n    batch_size = 4\n    item_set = gb.ItemSet(torch.arange(0, num_ids), names=\"seeds\")\n\n    # On Windows, if the process group initialization file already exists,\n    # the program may hang. So we need to delete it if it exists.\n    if platform == \"win32\":\n        try:\n            os.remove(os.path.join(os.getcwd(), \"dis_tempfile\"))\n        except FileNotFoundError:\n            pass\n\n    mp.spawn(\n        distributed_item_sampler_subprocess,\n        args=(\n            nprocs,\n            item_set,\n            num_ids,\n            num_workers,\n            batch_size,\n            drop_last,\n            drop_uneven_inputs,\n        ),\n        nprocs=nprocs,\n        join=True,\n    )\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/test_itemset.py",
    "content": "import re\n\nimport dgl\nimport pytest\nimport torch\nfrom dgl import graphbolt as gb\n\n\ndef test_ItemSet_names():\n    # ItemSet with single name.\n    item_set = gb.ItemSet(torch.arange(0, 5), names=\"seeds\")\n    assert item_set.names == (\"seeds\",)\n\n    # ItemSet with multiple names.\n    item_set = gb.ItemSet(\n        (torch.arange(0, 5), torch.arange(5, 10)),\n        names=(\"seeds\", \"labels\"),\n    )\n    assert item_set.names == (\"seeds\", \"labels\")\n\n    # ItemSet without name.\n    item_set = gb.ItemSet(torch.arange(0, 5))\n    assert item_set.names is None\n\n    # Integer-initiated ItemSet with excessive names.\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\"Number of items (1) and names (2) don't match.\"),\n    ):\n        _ = gb.ItemSet(5, names=(\"seeds\", \"labels\"))\n\n    # ItemSet with mismatched items and names.\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\"Number of items (1) and names (2) don't match.\"),\n    ):\n        _ = gb.ItemSet(torch.arange(0, 5), names=(\"seeds\", \"labels\"))\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.int32, torch.int64])\ndef test_ItemSet_scalar_dtype(dtype):\n    item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names=\"seeds\")\n    for i, item in enumerate(item_set):\n        assert i == item\n        assert item.dtype == dtype\n    assert item_set[2] == torch.tensor(2, dtype=dtype)\n    assert torch.equal(\n        item_set[slice(1, 4, 2)], torch.arange(1, 4, 2, dtype=dtype)\n    )\n\n\ndef test_ItemSet_length():\n    # Integer with valid length\n    num = 10\n    item_set = gb.ItemSet(num)\n    assert len(item_set) == 10\n    # Test __iter__() method. 
Same as below.\n    for i, item in enumerate(item_set):\n        assert i == item\n\n    # Single iterable with valid length.\n    ids = torch.arange(0, 5)\n    item_set = gb.ItemSet(ids)\n    assert len(item_set) == 5\n    for i, item in enumerate(item_set):\n        assert i == item.item()\n\n    # Tuple of iterables with valid length.\n    item_set = gb.ItemSet((torch.arange(0, 5), torch.arange(5, 10)))\n    assert len(item_set) == 5\n    for i, (item1, item2) in enumerate(item_set):\n        assert i == item1.item()\n        assert i + 5 == item2.item()\n\n    class InvalidLength:\n        def __iter__(self):\n            return iter([0, 1, 2])\n\n    # Single iterable with invalid length.\n    with pytest.raises(\n        TypeError, match=\"object of type 'InvalidLength' has no len()\"\n    ):\n        item_set = gb.ItemSet(InvalidLength())\n\n    # Tuple of iterables with invalid length.\n    with pytest.raises(\n        TypeError, match=\"object of type 'InvalidLength' has no len()\"\n    ):\n        item_set = gb.ItemSet((InvalidLength(), InvalidLength()))\n\n\ndef test_ItemSet_seed_nodes():\n    # Node IDs with tensor.\n    item_set = gb.ItemSet(torch.arange(0, 5), names=\"seeds\")\n    assert item_set.names == (\"seeds\",)\n    # Iterating over ItemSet and indexing one by one.\n    for i, item in enumerate(item_set):\n        assert i == item.item()\n        assert i == item_set[i]\n    # Indexing with a slice.\n    assert torch.equal(item_set[::2], torch.tensor([0, 2, 4]))\n    # Indexing with an Iterable.\n    assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5))\n\n    # Node IDs with single integer.\n    item_set = gb.ItemSet(5, names=\"seeds\")\n    assert item_set.names == (\"seeds\",)\n    # Iterating over ItemSet and indexing one by one.\n    for i, item in enumerate(item_set):\n        assert i == item.item()\n        assert i == item_set[i]\n    # Indexing with a slice.\n    assert torch.equal(item_set[::2], torch.tensor([0, 2, 
4]))\n    assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5))\n    # Indexing with an integer.\n    assert item_set[0] == 0\n    assert item_set[-1] == 4\n    # Indexing that is out of range.\n    with pytest.raises(IndexError, match=\"ItemSet index out of range.\"):\n        _ = item_set[5]\n    with pytest.raises(IndexError, match=\"ItemSet index out of range.\"):\n        _ = item_set[-10]\n    # Indexing with invalid input type.\n    with pytest.raises(\n        TypeError,\n        match=\"ItemSet indices must be int, slice, or torch.Tensor, not <class 'float'>.\",\n    ):\n        _ = item_set[1.5]\n\n\ndef test_ItemSet_seed_nodes_labels():\n    # Node IDs and labels.\n    seed_nodes = torch.arange(0, 5)\n    labels = torch.randint(0, 3, (5,))\n    item_set = gb.ItemSet((seed_nodes, labels), names=(\"seeds\", \"labels\"))\n    assert item_set.names == (\"seeds\", \"labels\")\n    # Iterating over ItemSet and indexing one by one.\n    for i, (seed_node, label) in enumerate(item_set):\n        assert seed_node == seed_nodes[i]\n        assert label == labels[i]\n        assert seed_node == item_set[i][0]\n        assert label == item_set[i][1]\n    # Indexing with a slice.\n    assert torch.equal(item_set[:][0], seed_nodes)\n    assert torch.equal(item_set[:][1], labels)\n    # Indexing with an Iterable.\n    assert torch.equal(item_set[torch.arange(0, 5)][0], seed_nodes)\n    assert torch.equal(item_set[torch.arange(0, 5)][1], labels)\n\n\ndef test_ItemSet_node_pairs():\n    # Node pairs.\n    node_pairs = torch.arange(0, 10).reshape(-1, 2)\n    item_set = gb.ItemSet(node_pairs, names=\"seeds\")\n    assert item_set.names == (\"seeds\",)\n    # Iterating over ItemSet and indexing one by one.\n    for i, (src, dst) in enumerate(item_set):\n        assert node_pairs[i][0] == src\n        assert node_pairs[i][1] == dst\n        assert node_pairs[i][0] == item_set[i][0]\n        assert node_pairs[i][1] == item_set[i][1]\n    # Indexing with a 
slice.\n    assert torch.equal(item_set[:], node_pairs)\n    # Indexing with an Iterable.\n    assert torch.equal(item_set[torch.arange(0, 5)], node_pairs)\n\n\ndef test_ItemSet_node_pairs_labels():\n    # Node pairs and labels\n    node_pairs = torch.arange(0, 10).reshape(-1, 2)\n    labels = torch.randint(0, 3, (5,))\n    item_set = gb.ItemSet((node_pairs, labels), names=(\"seeds\", \"labels\"))\n    assert item_set.names == (\"seeds\", \"labels\")\n    # Iterating over ItemSet and indexing one by one.\n    for i, (node_pair, label) in enumerate(item_set):\n        assert torch.equal(node_pairs[i], node_pair)\n        assert labels[i] == label\n        assert torch.equal(node_pairs[i], item_set[i][0])\n        assert labels[i] == item_set[i][1]\n    # Indexing with a slice.\n    assert torch.equal(item_set[:][0], node_pairs)\n    assert torch.equal(item_set[:][1], labels)\n    # Indexing with an Iterable.\n    assert torch.equal(item_set[torch.arange(0, 5)][0], node_pairs)\n    assert torch.equal(item_set[torch.arange(0, 5)][1], labels)\n\n\ndef test_ItemSet_node_pairs_labels_indexes():\n    # Node pairs and negative destinations.\n    node_pairs = torch.arange(0, 10).reshape(-1, 2)\n    labels = torch.tensor([1, 1, 0, 0, 0])\n    indexes = torch.tensor([0, 1, 0, 0, 1])\n    item_set = gb.ItemSet(\n        (node_pairs, labels, indexes), names=(\"seeds\", \"labels\", \"indexes\")\n    )\n    assert item_set.names == (\"seeds\", \"labels\", \"indexes\")\n    # Iterating over ItemSet and indexing one by one.\n    for i, (node_pair, label, index) in enumerate(item_set):\n        assert torch.equal(node_pairs[i], node_pair)\n        assert torch.equal(labels[i], label)\n        assert torch.equal(indexes[i], index)\n        assert torch.equal(node_pairs[i], item_set[i][0])\n        assert torch.equal(labels[i], item_set[i][1])\n        assert torch.equal(indexes[i], item_set[i][2])\n    # Indexing with a slice.\n    assert torch.equal(item_set[:][0], node_pairs)\n    
assert torch.equal(item_set[:][1], labels)\n    assert torch.equal(item_set[:][2], indexes)\n    # Indexing with an Iterable.\n    assert torch.equal(item_set[torch.arange(0, 5)][0], node_pairs)\n    assert torch.equal(item_set[torch.arange(0, 5)][1], labels)\n    assert torch.equal(item_set[torch.arange(0, 5)][2], indexes)\n\n\ndef test_ItemSet_graphs():\n    # Graphs.\n    graphs = [dgl.rand_graph(10, 20) for _ in range(5)]\n    item_set = gb.ItemSet(graphs)\n    assert item_set.names is None\n    # Iterating over ItemSet and indexing one by one.\n    for i, item in enumerate(item_set):\n        assert graphs[i] == item\n        assert graphs[i] == item_set[i]\n    # Indexing with a slice.\n    assert item_set[:] == graphs\n\n\ndef test_HeteroItemSet_names():\n    # HeteroItemSet with single name.\n    item_set = gb.HeteroItemSet(\n        {\n            \"user\": gb.ItemSet(torch.arange(0, 5), names=\"seeds\"),\n            \"item\": gb.ItemSet(torch.arange(5, 10), names=\"seeds\"),\n        }\n    )\n    assert item_set.names == (\"seeds\",)\n\n    # HeteroItemSet with multiple names.\n    item_set = gb.HeteroItemSet(\n        {\n            \"user\": gb.ItemSet(\n                (torch.arange(0, 5), torch.arange(5, 10)),\n                names=(\"seeds\", \"labels\"),\n            ),\n            \"item\": gb.ItemSet(\n                (torch.arange(5, 10), torch.arange(10, 15)),\n                names=(\"seeds\", \"labels\"),\n            ),\n        }\n    )\n    assert item_set.names == (\"seeds\", \"labels\")\n\n    # HeteroItemSet with no name.\n    item_set = gb.HeteroItemSet(\n        {\n            \"user\": gb.ItemSet(torch.arange(0, 5)),\n            \"item\": gb.ItemSet(torch.arange(5, 10)),\n        }\n    )\n    assert item_set.names is None\n\n    # HeteroItemSet with mismatched items and names.\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\"All itemsets must have the same names.\"),\n    ):\n        _ = 
gb.HeteroItemSet(\n            {\n                \"user\": gb.ItemSet(\n                    (torch.arange(0, 5), torch.arange(5, 10)),\n                    names=(\"seeds\", \"labels\"),\n                ),\n                \"item\": gb.ItemSet((torch.arange(5, 10),), names=(\"seeds\",)),\n            }\n        )\n\n\ndef test_HeteroItemSet_length():\n    # Single iterable with valid length.\n    user_ids = torch.arange(0, 5)\n    item_ids = torch.arange(0, 5)\n    item_set = gb.HeteroItemSet(\n        {\n            \"user\": gb.ItemSet(user_ids),\n            \"item\": gb.ItemSet(item_ids),\n        }\n    )\n    assert len(item_set) == len(user_ids) + len(item_ids)\n\n    # Tuple of iterables with valid length.\n    node_pairs_like = torch.arange(0, 10).reshape(-1, 2)\n    neg_dsts_like = torch.arange(10, 20).reshape(-1, 2)\n    node_pairs_follow = torch.arange(0, 10).reshape(-1, 2)\n    neg_dsts_follow = torch.arange(10, 20).reshape(-1, 2)\n    item_set = gb.HeteroItemSet(\n        {\n            \"user:like:item\": gb.ItemSet((node_pairs_like, neg_dsts_like)),\n            \"user:follow:user\": gb.ItemSet(\n                (node_pairs_follow, neg_dsts_follow)\n            ),\n        }\n    )\n    assert len(item_set) == node_pairs_like.size(0) + node_pairs_follow.size(0)\n\n    class InvalidLength:\n        def __iter__(self):\n            return iter([0, 1, 2])\n\n    # Single iterable with invalid length.\n    with pytest.raises(\n        TypeError, match=\"object of type 'InvalidLength' has no len()\"\n    ):\n        item_set = gb.HeteroItemSet(\n            {\n                \"user\": gb.ItemSet(InvalidLength()),\n                \"item\": gb.ItemSet(InvalidLength()),\n            }\n        )\n\n    # Tuple of iterables with invalid length.\n    with pytest.raises(\n        TypeError, match=\"object of type 'InvalidLength' has no len()\"\n    ):\n        item_set = gb.HeteroItemSet(\n            {\n                \"user:like:item\": gb.ItemSet(\n    
                (InvalidLength(), InvalidLength())\n                ),\n                \"user:follow:user\": gb.ItemSet(\n                    (InvalidLength(), InvalidLength())\n                ),\n            }\n        )\n\n\ndef test_HeteroItemSet_iteration_seed_nodes():\n    # Node IDs.\n    user_ids = torch.arange(0, 5)\n    item_ids = torch.arange(5, 10)\n    ids = {\n        \"user\": gb.ItemSet(user_ids, names=\"seeds\"),\n        \"item\": gb.ItemSet(item_ids, names=\"seeds\"),\n    }\n    chained_ids = []\n    for key, value in ids.items():\n        chained_ids += [(key, v) for v in value]\n    item_set = gb.HeteroItemSet(ids)\n    assert item_set.names == (\"seeds\",)\n    # Iterating over HeteroItemSet and indexing one by one.\n    for i, item in enumerate(item_set):\n        assert len(item) == 1\n        assert isinstance(item, dict)\n        assert chained_ids[i][0] in item\n        assert item[chained_ids[i][0]] == chained_ids[i][1]\n        assert item_set[i] == item\n        assert item_set[i - len(item_set)] == item\n    # Indexing all with a slice.\n    assert torch.equal(item_set[:][\"user\"], user_ids)\n    assert torch.equal(item_set[:][\"item\"], item_ids)\n    # Indexing partial with a slice.\n    partial_data = item_set[:3]\n    assert len(list(partial_data.keys())) == 1\n    assert torch.equal(partial_data[\"user\"], user_ids[:3])\n    partial_data = item_set[7:]\n    assert len(list(partial_data.keys())) == 1\n    assert torch.equal(partial_data[\"item\"], item_ids[2:])\n    partial_data = item_set[3:8:2]\n    assert len(list(partial_data.keys())) == 2\n    assert torch.equal(partial_data[\"user\"], user_ids[3:-1:2])\n    assert torch.equal(partial_data[\"item\"], item_ids[0:3:2])\n    # Indexing with an iterable of int.\n    partial_data = item_set[torch.tensor([1, 0, 4])]\n    assert len(list(partial_data.keys())) == 1\n    assert torch.equal(partial_data[\"user\"], torch.tensor([1, 0, 4]))\n    partial_data = 
item_set[torch.tensor([9, 8, 5])]\n    assert len(list(partial_data.keys())) == 1\n    assert torch.equal(partial_data[\"item\"], torch.tensor([9, 8, 5]))\n    partial_data = item_set[torch.tensor([8, 1, 0, 9, 7, 5])]\n    assert len(list(partial_data.keys())) == 2\n    assert torch.equal(partial_data[\"user\"], torch.tensor([1, 0]))\n    assert torch.equal(partial_data[\"item\"], torch.tensor([8, 9, 7, 5]))\n\n    # Exception cases.\n    with pytest.raises(\n        AssertionError, match=\"Start must be smaller than stop.\"\n    ):\n        _ = item_set[5:3]\n    with pytest.raises(\n        AssertionError, match=\"Start must be smaller than stop.\"\n    ):\n        _ = item_set[-1:3]\n    with pytest.raises(IndexError, match=\"HeteroItemSet index out of range.\"):\n        _ = item_set[20]\n    with pytest.raises(IndexError, match=\"HeteroItemSet index out of range.\"):\n        _ = item_set[-20]\n    with pytest.raises(\n        TypeError,\n        match=\"HeteroItemSet indices must be int, slice, or iterable of int, not <class 'float'>.\",\n    ):\n        _ = item_set[1.5]\n\n\ndef test_HeteroItemSet_iteration_seed_nodes_labels():\n    # Node IDs and labels.\n    user_ids = torch.arange(0, 5)\n    user_labels = torch.randint(0, 3, (5,))\n    item_ids = torch.arange(5, 10)\n    item_labels = torch.randint(0, 3, (5,))\n    ids_labels = {\n        \"user\": gb.ItemSet((user_ids, user_labels), names=(\"seeds\", \"labels\")),\n        \"item\": gb.ItemSet((item_ids, item_labels), names=(\"seeds\", \"labels\")),\n    }\n    chained_ids = []\n    for key, value in ids_labels.items():\n        chained_ids += [(key, v) for v in value]\n    item_set = gb.HeteroItemSet(ids_labels)\n    assert item_set.names == (\"seeds\", \"labels\")\n    # Iterating over HeteroItemSet and indexing one by one.\n    for i, item in enumerate(item_set):\n        assert len(item) == 1\n        assert isinstance(item, dict)\n        assert chained_ids[i][0] in item\n        assert 
item[chained_ids[i][0]] == chained_ids[i][1]\n        assert item_set[i] == item\n    # Indexing with a slice.\n    assert torch.equal(item_set[:][\"user\"][0], user_ids)\n    assert torch.equal(item_set[:][\"user\"][1], user_labels)\n    assert torch.equal(item_set[:][\"item\"][0], item_ids)\n    assert torch.equal(item_set[:][\"item\"][1], item_labels)\n\n\ndef test_HeteroItemSet_iteration_node_pairs():\n    # Node pairs.\n    node_pairs = torch.arange(0, 10).reshape(-1, 2)\n    node_pairs_dict = {\n        \"user:like:item\": gb.ItemSet(node_pairs, names=\"seeds\"),\n        \"user:follow:user\": gb.ItemSet(node_pairs, names=\"seeds\"),\n    }\n    expected_data = []\n    for key, value in node_pairs_dict.items():\n        expected_data += [(key, v) for v in value]\n    item_set = gb.HeteroItemSet(node_pairs_dict)\n    assert item_set.names == (\"seeds\",)\n    # Iterating over HeteroItemSet and indexing one by one.\n    for i, item in enumerate(item_set):\n        assert len(item) == 1\n        assert isinstance(item, dict)\n        assert expected_data[i][0] in item\n        assert torch.equal(item[expected_data[i][0]], expected_data[i][1])\n        assert item_set[i].keys() == item.keys()\n        key = list(item.keys())[0]\n        assert torch.equal(item_set[i][key], item[key])\n    # Indexing with a slice.\n    assert torch.equal(item_set[:][\"user:like:item\"], node_pairs)\n    assert torch.equal(item_set[:][\"user:follow:user\"], node_pairs)\n\n\ndef test_HeteroItemSet_iteration_node_pairs_labels():\n    # Node pairs and labels\n    node_pairs = torch.arange(0, 10).reshape(-1, 2)\n    labels = torch.randint(0, 3, (5,))\n    node_pairs_labels = {\n        \"user:like:item\": gb.ItemSet(\n            (node_pairs, labels), names=(\"seeds\", \"labels\")\n        ),\n        \"user:follow:user\": gb.ItemSet(\n            (node_pairs, labels), names=(\"seeds\", \"labels\")\n        ),\n    }\n    expected_data = []\n    for key, value in 
node_pairs_labels.items():\n        expected_data += [(key, v) for v in value]\n    item_set = gb.HeteroItemSet(node_pairs_labels)\n    assert item_set.names == (\"seeds\", \"labels\")\n    # Iterating over HeteroItemSet and indexing one by one.\n    for i, item in enumerate(item_set):\n        assert len(item) == 1\n        assert isinstance(item, dict)\n        key, value = expected_data[i]\n        assert key in item\n        assert torch.equal(item[key][0], value[0])\n        assert item[key][1] == value[1]\n        assert item_set[i].keys() == item.keys()\n        key = list(item.keys())[0]\n        assert torch.equal(item_set[i][key][0], item[key][0])\n        assert torch.equal(item_set[i][key][1], item[key][1])\n    # Indexing with a slice.\n    assert torch.equal(item_set[:][\"user:like:item\"][0], node_pairs)\n    assert torch.equal(item_set[:][\"user:like:item\"][1], labels)\n    assert torch.equal(item_set[:][\"user:follow:user\"][0], node_pairs)\n    assert torch.equal(item_set[:][\"user:follow:user\"][1], labels)\n\n\ndef test_HeteroItemSet_iteration_node_pairs_labels_indexes():\n    # Node pairs and negative destinations.\n    node_pairs = torch.arange(0, 10).reshape(-1, 2)\n    labels = torch.tensor([1, 1, 0, 0, 0])\n    indexes = torch.tensor([0, 1, 0, 0, 1])\n    node_pairs_neg_dsts = {\n        \"user:like:item\": gb.ItemSet(\n            (node_pairs, labels, indexes), names=(\"seeds\", \"labels\", \"indexes\")\n        ),\n        \"user:follow:user\": gb.ItemSet(\n            (node_pairs, labels, indexes), names=(\"seeds\", \"labels\", \"indexes\")\n        ),\n    }\n    expected_data = []\n    for key, value in node_pairs_neg_dsts.items():\n        expected_data += [(key, v) for v in value]\n    item_set = gb.HeteroItemSet(node_pairs_neg_dsts)\n    assert item_set.names == (\"seeds\", \"labels\", \"indexes\")\n    # Iterating over HeteroItemSet and indexing one by one.\n    for i, item in enumerate(item_set):\n        assert len(item) == 1\n  
      assert isinstance(item, dict)\n        key, value = expected_data[i]\n        assert key in item\n        assert torch.equal(item[key][0], value[0])\n        assert torch.equal(item[key][1], value[1])\n        assert torch.equal(item[key][2], value[2])\n        assert item_set[i].keys() == item.keys()\n        key = list(item.keys())[0]\n        assert torch.equal(item_set[i][key][0], item[key][0])\n        assert torch.equal(item_set[i][key][1], item[key][1])\n        assert torch.equal(item_set[i][key][2], item[key][2])\n    # Indexing with a slice.\n    assert torch.equal(item_set[:][\"user:like:item\"][0], node_pairs)\n    assert torch.equal(item_set[:][\"user:like:item\"][1], labels)\n    assert torch.equal(item_set[:][\"user:like:item\"][2], indexes)\n    assert torch.equal(item_set[:][\"user:follow:user\"][0], node_pairs)\n    assert torch.equal(item_set[:][\"user:follow:user\"][1], labels)\n    assert torch.equal(item_set[:][\"user:follow:user\"][2], indexes)\n\n\ndef test_ItemSet_repr():\n    # ItemSet with single name.\n    item_set = gb.ItemSet(torch.arange(0, 5), names=\"seeds\")\n    expected_str = (\n        \"ItemSet(\\n\"\n        \"    items=(tensor([0, 1, 2, 3, 4]),),\\n\"\n        \"    names=('seeds',),\\n\"\n        \")\"\n    )\n\n    assert str(item_set) == expected_str, item_set\n\n    # ItemSet with multiple names.\n    item_set = gb.ItemSet(\n        (torch.arange(0, 5), torch.arange(5, 10)),\n        names=(\"seeds\", \"labels\"),\n    )\n    expected_str = (\n        \"ItemSet(\\n\"\n        \"    items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\\n\"\n        \"    names=('seeds', 'labels'),\\n\"\n        \")\"\n    )\n    assert str(item_set) == expected_str, item_set\n\n\ndef test_HeteroItemSet_repr():\n    # HeteroItemSet with single name.\n    item_set = gb.HeteroItemSet(\n        {\n            \"user\": gb.ItemSet(torch.arange(0, 5), names=\"seeds\"),\n            \"item\": gb.ItemSet(torch.arange(5, 10), 
names=\"seeds\"),\n        }\n    )\n    expected_str = (\n        \"HeteroItemSet(\\n\"\n        \"    itemsets={'user': ItemSet(\\n\"\n        \"                 items=(tensor([0, 1, 2, 3, 4]),),\\n\"\n        \"                 names=('seeds',),\\n\"\n        \"             ), 'item': ItemSet(\\n\"\n        \"                 items=(tensor([5, 6, 7, 8, 9]),),\\n\"\n        \"                 names=('seeds',),\\n\"\n        \"             )},\\n\"\n        \"    names=('seeds',),\\n\"\n        \")\"\n    )\n    assert str(item_set) == expected_str, item_set\n\n    # HeteroItemSet with multiple names.\n    item_set = gb.HeteroItemSet(\n        {\n            \"user\": gb.ItemSet(\n                (torch.arange(0, 5), torch.arange(5, 10)),\n                names=(\"seeds\", \"labels\"),\n            ),\n            \"item\": gb.ItemSet(\n                (torch.arange(5, 10), torch.arange(10, 15)),\n                names=(\"seeds\", \"labels\"),\n            ),\n        }\n    )\n    expected_str = (\n        \"HeteroItemSet(\\n\"\n        \"    itemsets={'user': ItemSet(\\n\"\n        \"                 items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\\n\"\n        \"                 names=('seeds', 'labels'),\\n\"\n        \"             ), 'item': ItemSet(\\n\"\n        \"                 items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),\\n\"\n        \"                 names=('seeds', 'labels'),\\n\"\n        \"             )},\\n\"\n        \"    names=('seeds', 'labels'),\\n\"\n        \")\"\n    )\n    assert str(item_set) == expected_str, item_set\n\n\ndef test_deprecation_alias():\n    \"\"\"Test `ItemSetDict` as the alias for `HeteroItemSet`.\"\"\"\n\n    user_ids = torch.arange(0, 5)\n    item_ids = torch.arange(5, 10)\n    ids = {\n        \"user\": gb.ItemSet(user_ids, names=\"seeds\"),\n        \"item\": gb.ItemSet(item_ids, names=\"seeds\"),\n    }\n    with pytest.warns(\n        DeprecationWarning,\n        match=\"ItemSetDict 
is deprecated and will be removed in the future. Please use HeteroItemSet instead.\",\n    ):\n        item_set_dict = gb.ItemSetDict(ids)\n    hetero_item_set = gb.HeteroItemSet(ids)\n    assert len(item_set_dict) == len(hetero_item_set)\n    assert item_set_dict.names == hetero_item_set.names\n    assert item_set_dict._keys == hetero_item_set._keys\n    assert torch.equal(item_set_dict._offsets, hetero_item_set._offsets)\n    assert (\n        repr(item_set_dict)[len(\"ItemSetDict\") :]\n        == repr(hetero_item_set)[len(\"HeteroItemSet\") :]\n    )\n    # Indexing all with a slice.\n    assert torch.equal(item_set_dict[:][\"user\"], hetero_item_set[:][\"user\"])\n    assert torch.equal(item_set_dict[:][\"item\"], hetero_item_set[:][\"item\"])\n    # Indexing partial with a slice.\n    partial_data = item_set_dict[:3]\n    assert len(list(partial_data.keys())) == 1\n    assert torch.equal(partial_data[\"user\"], hetero_item_set[:3][\"user\"])\n    partial_data = item_set_dict[7:]\n    assert len(list(partial_data.keys())) == 1\n    assert torch.equal(partial_data[\"item\"], hetero_item_set[7:][\"item\"])\n    partial_data = item_set_dict[3:8:2]\n    assert len(list(partial_data.keys())) == 2\n    assert torch.equal(partial_data[\"user\"], hetero_item_set[3:8:2][\"user\"])\n    assert torch.equal(partial_data[\"item\"], hetero_item_set[3:8:2][\"item\"])\n    # Indexing with an iterable of int.\n    partial_data = item_set_dict[torch.tensor([1, 0, 4])]\n    assert len(list(partial_data.keys())) == 1\n    assert torch.equal(partial_data[\"user\"], hetero_item_set[1, 0, 4][\"user\"])\n    partial_data = item_set_dict[torch.tensor([9, 8, 5])]\n    assert len(list(partial_data.keys())) == 1\n    assert torch.equal(partial_data[\"item\"], hetero_item_set[9, 8, 5][\"item\"])\n    partial_data = item_set_dict[torch.tensor([8, 1, 0, 9, 7, 5])]\n    assert len(list(partial_data.keys())) == 2\n    assert torch.equal(partial_data[\"user\"], hetero_item_set[1, 
0][\"user\"])\n    assert torch.equal(\n        partial_data[\"item\"], hetero_item_set[8, 9, 7, 5][\"item\"]\n    )\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/test_minibatch.py",
    "content": "import dgl\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\n\n\nrelation = \"A:r:B\"\nreverse_relation = \"B:rr:A\"\n\n\n@pytest.mark.parametrize(\"indptr_dtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"indices_dtype\", [torch.int32, torch.int64])\ndef test_minibatch_representation_homo(indptr_dtype, indices_dtype):\n    seeds = torch.tensor([10, 11])\n    csc_formats = [\n        gb.CSCFormatBase(\n            indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),\n            indices=torch.tensor([0, 1, 2, 2, 1, 2], dtype=indices_dtype),\n        ),\n        gb.CSCFormatBase(\n            indptr=torch.tensor([0, 2, 3], dtype=indptr_dtype),\n            indices=torch.tensor([1, 2, 0], dtype=indices_dtype),\n        ),\n    ]\n    original_column_node_ids = [\n        torch.tensor([10, 11, 12, 13]),\n        torch.tensor([10, 11]),\n    ]\n    original_row_node_ids = [\n        torch.tensor([10, 11, 12, 13]),\n        torch.tensor([10, 11, 12]),\n    ]\n    original_edge_ids = [\n        torch.tensor([19, 20, 21, 22, 25, 30]),\n        torch.tensor([10, 15, 17]),\n    ]\n    node_features = {\"x\": torch.tensor([5, 0, 2, 1])}\n    edge_features = [\n        {\"x\": torch.tensor([9, 0, 1, 1, 7, 4])},\n        {\"x\": torch.tensor([0, 2, 2])},\n    ]\n    subgraphs = []\n    for i in range(2):\n        subgraphs.append(\n            gb.SampledSubgraphImpl(\n                sampled_csc=csc_formats[i],\n                original_column_node_ids=original_column_node_ids[i],\n                original_row_node_ids=original_row_node_ids[i],\n                original_edge_ids=original_edge_ids[i],\n            )\n        )\n    input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4])\n    compacted_seeds = torch.tensor([0, 1])\n    labels = torch.tensor([1.0, 2.0])\n    # Test minibatch without data.\n    minibatch = gb.MiniBatch()\n    expect_result = str(\n        \"\"\"MiniBatch(seeds=None,\n          sampled_subgraphs=None,\n      
    node_features=None,\n          labels=None,\n          input_nodes=None,\n          indexes=None,\n          edge_features=None,\n          compacted_seeds=None,\n          blocks=None,\n       )\"\"\"\n    )\n    result = str(minibatch)\n    assert result == expect_result, print(expect_result, result)\n    # Test minibatch with all attributes.\n    minibatch = gb.MiniBatch(\n        seeds=seeds,\n        sampled_subgraphs=subgraphs,\n        labels=labels,\n        node_features=node_features,\n        edge_features=edge_features,\n        compacted_seeds=compacted_seeds,\n        input_nodes=input_nodes,\n    )\n    expect_result = str(\n        \"\"\"MiniBatch(seeds=tensor([10, 11]),\n          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32),\n                                                                         indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([10, 11, 12, 13]),\n                                               original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),\n                                               original_column_node_ids=tensor([10, 11, 12, 13]),\n                            ),\n                            SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 2, 3], dtype=torch.int32),\n                                                                         indices=tensor([1, 2, 0], dtype=torch.int32),\n                                                           ),\n                                               original_row_node_ids=tensor([10, 11, 12]),\n                                               original_edge_ids=tensor([10, 15, 17]),\n                                               original_column_node_ids=tensor([10, 11]),\n                            )],\n          node_features={'x': tensor([5, 0, 
2, 1])},\n          labels=tensor([1., 2.]),\n          input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]),\n          indexes=None,\n          edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])},\n                        {'x': tensor([0, 2, 2])}],\n          compacted_seeds=tensor([0, 1]),\n          blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6),\n                 Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)],\n       )\"\"\"\n    )\n    result = str(minibatch)\n    assert result == expect_result, print(expect_result, result)\n\n\n@pytest.mark.parametrize(\"indptr_dtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"indices_dtype\", [torch.int32, torch.int64])\ndef test_minibatch_representation_hetero(indptr_dtype, indices_dtype):\n    seeds = {relation: torch.tensor([10, 11])}\n    csc_formats = [\n        {\n            relation: gb.CSCFormatBase(\n                indptr=torch.tensor([0, 1, 2, 3], dtype=indptr_dtype),\n                indices=torch.tensor([0, 1, 1], dtype=indices_dtype),\n            ),\n            reverse_relation: gb.CSCFormatBase(\n                indptr=torch.tensor([0, 0, 0, 1, 2], dtype=indptr_dtype),\n                indices=torch.tensor([1, 0], dtype=indices_dtype),\n            ),\n        },\n        {\n            relation: gb.CSCFormatBase(\n                indptr=torch.tensor([0, 1, 2], dtype=indptr_dtype),\n                indices=torch.tensor([1, 0], dtype=indices_dtype),\n            ),\n            reverse_relation: gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2], dtype=indptr_dtype),\n                indices=torch.tensor([1, 0], dtype=indices_dtype),\n            ),\n        },\n    ]\n    original_column_node_ids = [\n        {\"B\": torch.tensor([10, 11, 12]), \"A\": torch.tensor([5, 7, 9, 11])},\n        {\"B\": torch.tensor([10, 11]), \"A\": torch.tensor([5])},\n    ]\n    original_row_node_ids = [\n        {\n            \"A\": torch.tensor([5, 7, 9, 11]),\n            \"B\": 
torch.tensor([10, 11, 12]),\n        },\n        {\n            \"A\": torch.tensor([5, 7]),\n            \"B\": torch.tensor([10, 11]),\n        },\n    ]\n    original_edge_ids = [\n        {\n            relation: torch.tensor([19, 20, 21]),\n            reverse_relation: torch.tensor([23, 26]),\n        },\n        {relation: torch.tensor([10, 12])},\n    ]\n    node_features = {\n        (\"A\", \"x\"): torch.tensor([6, 4, 0, 1]),\n    }\n    edge_features = [\n        {(relation, \"x\"): torch.tensor([4, 2, 4])},\n        {(relation, \"x\"): torch.tensor([0, 6])},\n    ]\n    subgraphs = []\n    for i in range(2):\n        subgraphs.append(\n            gb.SampledSubgraphImpl(\n                sampled_csc=csc_formats[i],\n                original_column_node_ids=original_column_node_ids[i],\n                original_row_node_ids=original_row_node_ids[i],\n                original_edge_ids=original_edge_ids[i],\n            )\n        )\n    compacted_seeds = {relation: torch.tensor([0, 1])}\n    # Test minibatch with all attributes.\n    minibatch = gb.MiniBatch(\n        seeds=seeds,\n        sampled_subgraphs=subgraphs,\n        node_features=node_features,\n        edge_features=edge_features,\n        labels={\"B\": torch.tensor([2, 5])},\n        compacted_seeds=compacted_seeds,\n        input_nodes={\n            \"A\": torch.tensor([5, 7, 9, 11]),\n            \"B\": torch.tensor([10, 11, 12]),\n        },\n    )\n    expect_result = str(\n        \"\"\"MiniBatch(seeds={'A:r:B': tensor([10, 11])},\n          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),\n                                                                         indices=tensor([0, 1, 1], dtype=torch.int32),\n                                                           ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32),\n                                                                       
  indices=tensor([1, 0], dtype=torch.int32),\n                                                           )},\n                                               original_row_node_ids={'A': tensor([ 5,  7,  9, 11]), 'B': tensor([10, 11, 12])},\n                                               original_edge_ids={'A:r:B': tensor([19, 20, 21]), 'B:rr:A': tensor([23, 26])},\n                                               original_column_node_ids={'B': tensor([10, 11, 12]), 'A': tensor([ 5,  7,  9, 11])},\n                            ),\n                            SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),\n                                                                         indices=tensor([1, 0], dtype=torch.int32),\n                                                           ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 2], dtype=torch.int32),\n                                                                         indices=tensor([1, 0], dtype=torch.int32),\n                                                           )},\n                                               original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])},\n                                               original_edge_ids={'A:r:B': tensor([10, 12])},\n                                               original_column_node_ids={'B': tensor([10, 11]), 'A': tensor([5])},\n                            )],\n          node_features={('A', 'x'): tensor([6, 4, 0, 1])},\n          labels={'B': tensor([2, 5])},\n          input_nodes={'A': tensor([ 5,  7,  9, 11]), 'B': tensor([10, 11, 12])},\n          indexes=None,\n          edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])},\n                        {('A:r:B', 'x'): tensor([0, 6])}],\n          compacted_seeds={'A:r:B': tensor([0, 1])},\n          blocks=[Block(num_src_nodes={'A': 4, 'B': 3},\n                       num_dst_nodes={'A': 4, 'B': 3},\n                       num_edges={('A', 'r', 'B'): 
3, ('B', 'rr', 'A'): 2},\n                       metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')]),\n                 Block(num_src_nodes={'A': 2, 'B': 2},\n                       num_dst_nodes={'A': 1, 'B': 2},\n                       num_edges={('A', 'r', 'B'): 2, ('B', 'rr', 'A'): 2},\n                       metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')])],\n       )\"\"\"\n    )\n    result = str(minibatch)\n    assert result == expect_result, print(result)\n\n\n@pytest.mark.parametrize(\"indptr_dtype\", [torch.int32, torch.int64])\n@pytest.mark.parametrize(\"indices_dtype\", [torch.int32, torch.int64])\ndef test_get_dgl_blocks_homo(indptr_dtype, indices_dtype):\n    csc_formats = [\n        gb.CSCFormatBase(\n            indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),\n            indices=torch.tensor([0, 1, 2, 2, 1, 2], dtype=indices_dtype),\n        ),\n        gb.CSCFormatBase(\n            indptr=torch.tensor([0, 1, 3], dtype=indptr_dtype),\n            indices=torch.tensor([0, 1, 2], dtype=indices_dtype),\n        ),\n    ]\n    original_column_node_ids = [\n        torch.tensor([10, 11, 12, 13]),\n        torch.tensor([10, 11]),\n    ]\n    original_row_node_ids = [\n        torch.tensor([10, 11, 12, 13]),\n        torch.tensor([10, 11, 12]),\n    ]\n    original_edge_ids = [\n        torch.tensor([19, 20, 21, 22, 25, 30]),\n        torch.tensor([10, 15, 17]),\n    ]\n    subgraphs = []\n    for i in range(2):\n        subgraphs.append(\n            gb.SampledSubgraphImpl(\n                sampled_csc=csc_formats[i],\n                original_column_node_ids=original_column_node_ids[i],\n                original_row_node_ids=original_row_node_ids[i],\n                original_edge_ids=original_edge_ids[i],\n            )\n        )\n    # Test minibatch with all attributes.\n    minibatch = gb.MiniBatch(\n        sampled_subgraphs=subgraphs,\n    )\n    dgl_blocks = minibatch.blocks\n    expect_result = str(\n        \"\"\"[Block(num_src_nodes=4, 
num_dst_nodes=4, num_edges=6), Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)]\"\"\"\n    )\n    result = str(dgl_blocks)\n    assert result == expect_result\n\n\ndef test_get_dgl_blocks_hetero():\n    csc_formats = [\n        {\n            relation: gb.CSCFormatBase(\n                indptr=torch.tensor([0, 1, 2, 3]),\n                indices=torch.tensor([0, 1, 1]),\n            ),\n            reverse_relation: gb.CSCFormatBase(\n                indptr=torch.tensor([0, 0, 0, 1, 2]),\n                indices=torch.tensor([1, 0]),\n            ),\n        },\n        {\n            relation: gb.CSCFormatBase(\n                indptr=torch.tensor([0, 1, 2]), indices=torch.tensor([1, 0])\n            ),\n            reverse_relation: gb.CSCFormatBase(\n                indptr=torch.tensor([0, 1]),\n                indices=torch.tensor([1]),\n            ),\n        },\n    ]\n    original_column_node_ids = [\n        {\"B\": torch.tensor([10, 11, 12]), \"A\": torch.tensor([5, 7, 9, 11])},\n        {\"B\": torch.tensor([10, 11]), \"A\": torch.tensor([5])},\n    ]\n    original_row_node_ids = [\n        {\n            \"A\": torch.tensor([5, 7, 9, 11]),\n            \"B\": torch.tensor([10, 11, 12]),\n        },\n        {\n            \"A\": torch.tensor([5, 7]),\n            \"B\": torch.tensor([10, 11]),\n        },\n    ]\n    original_edge_ids = [\n        {\n            relation: torch.tensor([19, 20, 21]),\n            reverse_relation: torch.tensor([23, 26]),\n        },\n        {relation: torch.tensor([10, 12])},\n    ]\n    subgraphs = []\n    for i in range(2):\n        subgraphs.append(\n            gb.SampledSubgraphImpl(\n                sampled_csc=csc_formats[i],\n                original_column_node_ids=original_column_node_ids[i],\n                original_row_node_ids=original_row_node_ids[i],\n                original_edge_ids=original_edge_ids[i],\n            )\n        )\n    # Test minibatch with all attributes.\n    minibatch = 
gb.MiniBatch(\n        sampled_subgraphs=subgraphs,\n    )\n    dgl_blocks = minibatch.blocks\n    expect_result = str(\n        \"\"\"[Block(num_src_nodes={'A': 4, 'B': 3},\n      num_dst_nodes={'A': 4, 'B': 3},\n      num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2},\n      metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')]), Block(num_src_nodes={'A': 2, 'B': 2},\n      num_dst_nodes={'A': 1, 'B': 2},\n      num_edges={('A', 'r', 'B'): 2, ('B', 'rr', 'A'): 1},\n      metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')])]\"\"\"\n    )\n    result = str(dgl_blocks)\n    assert result == expect_result\n\n\ndef test_get_dgl_blocks_hetero_partial_empty_edges():\n    hg = dgl.heterograph(\n        {\n            (\"n1\", \"e1\", \"n1\"): ([0, 1, 1], [1, 2, 0]),\n            (\"n1\", \"e2\", \"n2\"): ([0, 1, 2], [1, 0, 2]),\n        }\n    )\n\n    gb_g = gb.from_dglgraph(hg, is_homogeneous=False)\n\n    train_set = gb.HeteroItemSet(\n        {\"n1:e2:n2\": gb.ItemSet(torch.LongTensor([[0, 1]]), names=\"seeds\")}\n    )\n    datapipe = gb.ItemSampler(train_set, batch_size=1)\n    datapipe = datapipe.sample_neighbor(gb_g, fanouts=[-1, -1])\n    dataloader = gb.DataLoader(datapipe)\n    blocks_str = str(next(iter(dataloader)).blocks)\n    expected_str = \"\"\"[Block(num_src_nodes={'n1': 2, 'n2': 0},\n      num_dst_nodes={'n1': 2, 'n2': 0},\n      num_edges={('n1', 'e1', 'n1'): 2, ('n1', 'e2', 'n2'): 0},\n      metagraph=[('n1', 'n1', 'e1'), ('n1', 'n2', 'e2')]), Block(num_src_nodes={'n1': 2, 'n2': 0},\n      num_dst_nodes={'n1': 1, 'n2': 1},\n      num_edges={('n1', 'e1', 'n1'): 1, ('n1', 'e2', 'n2'): 1},\n      metagraph=[('n1', 'n1', 'e1'), ('n1', 'n2', 'e2')])]\"\"\"\n    assert expected_str == blocks_str\n\n\ndef test_get_dgl_blocks_hetero_empty_edges():\n    hg = dgl.heterograph(\n        {\n            (\"n3\", \"e1\", \"n1\"): ([0, 1, 1], [1, 2, 0]),\n            (\"n3\", \"e2\", \"n2\"): ([0, 1, 2], [1, 0, 2]),\n        }\n    )\n\n    gb_g = gb.from_dglgraph(hg, 
is_homogeneous=False)\n\n    train_set = gb.HeteroItemSet(\n        {\"n3:e1:n1\": gb.ItemSet(torch.LongTensor([[2, 1]]), names=\"seeds\")}\n    )\n    datapipe = gb.ItemSampler(train_set, batch_size=1)\n    datapipe = datapipe.sample_neighbor(gb_g, fanouts=[-1, -1])\n    dataloader = gb.DataLoader(datapipe)\n    blocks_str = str(next(iter(dataloader)).blocks)\n    expected_str = \"\"\"[Block(num_src_nodes={'n1': 0, 'n2': 0, 'n3': 2},\n      num_dst_nodes={'n1': 0, 'n2': 0, 'n3': 2},\n      num_edges={('n3', 'e1', 'n1'): 0, ('n3', 'e2', 'n2'): 0},\n      metagraph=[('n3', 'n1', 'e1'), ('n3', 'n2', 'e2')]), Block(num_src_nodes={'n1': 0, 'n2': 0, 'n3': 2},\n      num_dst_nodes={'n1': 1, 'n2': 0, 'n3': 1},\n      num_edges={('n3', 'e1', 'n1'): 1, ('n3', 'e2', 'n2'): 0},\n      metagraph=[('n3', 'n1', 'e1'), ('n3', 'n2', 'e2')])]\"\"\"\n    assert expected_str == blocks_str\n\n\ndef test_get_dgl_blocks_homo_empty_edges():\n    g = dgl.graph(([2, 3, 4], [3, 4, 5]))\n\n    gb_g = gb.from_dglgraph(g, is_homogeneous=True)\n    train_set = gb.ItemSet(torch.LongTensor([[0, 1]]), names=\"seeds\")\n    datapipe = gb.ItemSampler(train_set, batch_size=1)\n    datapipe = datapipe.sample_neighbor(gb_g, fanouts=[-1, -1])\n    dataloader = gb.DataLoader(datapipe)\n    blocks_str = str(next(iter(dataloader)).blocks)\n    expected_str = \"[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=0), Block(num_src_nodes=2, num_dst_nodes=2, num_edges=0)]\"\n    assert expected_str == blocks_str\n\n\ndef test_seeds_ntype_being_passed():\n    hg = dgl.heterograph({(\"n1\", \"e1\", \"n2\"): ([0, 1, 2], [2, 0, 1])})\n\n    gb_g = gb.from_dglgraph(hg, is_homogeneous=False)\n    train_set = gb.HeteroItemSet(\n        {\"n2\": gb.ItemSet(torch.LongTensor([0, 1]), names=\"seeds\")}\n    )\n    datapipe = gb.ItemSampler(train_set, batch_size=2)\n    datapipe = datapipe.sample_neighbor(gb_g, [-1, -1, -1])\n    dataloader = gb.DataLoader(datapipe)\n    blocks = next(iter(dataloader)).blocks\n    for 
block in blocks:\n        assert \"n2\" in block.srctypes\n\n\ndef create_homo_minibatch():\n    csc_formats = [\n        gb.CSCFormatBase(\n            indptr=torch.tensor([0, 1, 3, 5, 6]),\n            indices=torch.tensor([0, 1, 2, 2, 1, 2]),\n        ),\n        gb.CSCFormatBase(\n            indptr=torch.tensor([0, 2, 3]),\n            indices=torch.tensor([1, 2, 0]),\n        ),\n    ]\n    original_column_node_ids = [\n        torch.tensor([10, 11, 12, 13]),\n        torch.tensor([10, 11]),\n    ]\n    original_row_node_ids = [\n        torch.tensor([10, 11, 12, 13]),\n        torch.tensor([10, 11, 12]),\n    ]\n    original_edge_ids = [\n        torch.tensor([19, 20, 21, 22, 25, 30]),\n        torch.tensor([10, 15, 17]),\n    ]\n    node_features = {\"x\": torch.randint(0, 10, (4,))}\n    edge_features = [\n        {\"x\": torch.randint(0, 10, (6,))},\n        {\"x\": torch.randint(0, 10, (3,))},\n    ]\n    subgraphs = []\n    for i in range(2):\n        subgraphs.append(\n            gb.SampledSubgraphImpl(\n                sampled_csc=csc_formats[i],\n                original_column_node_ids=original_column_node_ids[i],\n                original_row_node_ids=original_row_node_ids[i],\n                original_edge_ids=original_edge_ids[i],\n            )\n        )\n    return gb.MiniBatch(\n        sampled_subgraphs=subgraphs,\n        node_features=node_features,\n        edge_features=edge_features,\n        input_nodes=torch.tensor([10, 11, 12, 13]),\n    )\n\n\ndef create_hetero_minibatch():\n    sampled_csc = [\n        {\n            relation: gb.CSCFormatBase(\n                indptr=torch.tensor([0, 1, 2, 3]),\n                indices=torch.tensor([0, 1, 1]),\n            ),\n            reverse_relation: gb.CSCFormatBase(\n                indptr=torch.tensor([0, 0, 0, 1, 2]),\n                indices=torch.tensor([1, 0]),\n            ),\n        },\n        {\n            relation: gb.CSCFormatBase(\n                indptr=torch.tensor([0, 1, 
2]), indices=torch.tensor([1, 0])\n            )\n        },\n    ]\n    original_column_node_ids = [\n        {\"B\": torch.tensor([10, 11, 12]), \"A\": torch.tensor([5, 7, 9, 11])},\n        {\"B\": torch.tensor([10, 11])},\n    ]\n    original_row_node_ids = [\n        {\n            \"A\": torch.tensor([5, 7, 9, 11]),\n            \"B\": torch.tensor([10, 11, 12]),\n        },\n        {\n            \"A\": torch.tensor([5, 7]),\n            \"B\": torch.tensor([10, 11]),\n        },\n    ]\n    original_edge_ids = [\n        {\n            relation: torch.tensor([19, 20, 21]),\n            reverse_relation: torch.tensor([23, 26]),\n        },\n        {relation: torch.tensor([10, 12])},\n    ]\n    node_features = {\n        (\"A\", \"x\"): torch.randint(0, 10, (4,)),\n    }\n    edge_features = [\n        {(relation, \"x\"): torch.randint(0, 10, (3,))},\n        {(relation, \"x\"): torch.randint(0, 10, (2,))},\n    ]\n    subgraphs = []\n    for i in range(2):\n        subgraphs.append(\n            gb.SampledSubgraphImpl(\n                sampled_csc=sampled_csc[i],\n                original_column_node_ids=original_column_node_ids[i],\n                original_row_node_ids=original_row_node_ids[i],\n                original_edge_ids=original_edge_ids[i],\n            )\n        )\n    return gb.MiniBatch(\n        sampled_subgraphs=subgraphs,\n        node_features=node_features,\n        edge_features=edge_features,\n        input_nodes={\n            \"A\": torch.tensor([5, 7, 9, 11]),\n            \"B\": torch.tensor([10, 11, 12]),\n        },\n    )\n\n\ndef check_dgl_blocks_hetero(minibatch, blocks):\n    etype = gb.etype_str_to_tuple(relation)\n    sampled_csc = [\n        subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs\n    ]\n    original_edge_ids = [\n        subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs\n    ]\n    original_row_node_ids = [\n        subgraph.original_row_node_ids\n        for subgraph in 
minibatch.sampled_subgraphs\n    ]\n\n    for i, block in enumerate(blocks):\n        edges = block.edges(etype=etype)\n        dst_ndoes = torch.arange(\n            0, len(sampled_csc[i][relation].indptr) - 1\n        ).repeat_interleave(sampled_csc[i][relation].indptr.diff())\n        assert torch.equal(edges[0], sampled_csc[i][relation].indices)\n        assert torch.equal(edges[1], dst_ndoes)\n        assert torch.equal(\n            block.edges[etype].data[dgl.EID], original_edge_ids[i][relation]\n        )\n    edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation))\n    dst_ndoes = torch.arange(\n        0, len(sampled_csc[0][reverse_relation].indptr) - 1\n    ).repeat_interleave(sampled_csc[0][reverse_relation].indptr.diff())\n    assert torch.equal(edges[0], sampled_csc[0][reverse_relation].indices)\n    assert torch.equal(edges[1], dst_ndoes)\n    assert torch.equal(\n        blocks[0].srcdata[dgl.NID][\"A\"], original_row_node_ids[0][\"A\"]\n    )\n    assert torch.equal(\n        blocks[0].srcdata[dgl.NID][\"B\"], original_row_node_ids[0][\"B\"]\n    )\n\n\ndef check_dgl_blocks_homo(minibatch, blocks):\n    sampled_csc = [\n        subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs\n    ]\n    original_edge_ids = [\n        subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs\n    ]\n    original_row_node_ids = [\n        subgraph.original_row_node_ids\n        for subgraph in minibatch.sampled_subgraphs\n    ]\n    for i, block in enumerate(blocks):\n        dst_ndoes = torch.arange(\n            0, len(sampled_csc[i].indptr) - 1\n        ).repeat_interleave(sampled_csc[i].indptr.diff())\n        assert torch.equal(block.edges()[0], sampled_csc[i].indices)\n        assert torch.equal(block.edges()[1], dst_ndoes)\n        assert torch.equal(block.edata[dgl.EID], original_edge_ids[i])\n    assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0])\n\n\ndef 
test_dgl_node_classification_without_feature():\n    # Arrange\n    minibatch = create_homo_minibatch()\n    minibatch.node_features = None\n    minibatch.labels = None\n    minibatch.seeds = torch.tensor([10, 15])\n    # Act\n    dgl_blocks = minibatch.blocks\n\n    # Assert\n    assert len(dgl_blocks) == 2\n    assert minibatch.node_features is None\n    assert minibatch.labels is None\n    check_dgl_blocks_homo(minibatch, dgl_blocks)\n\n\ndef test_dgl_node_classification_homo():\n    # Arrange\n    minibatch = create_homo_minibatch()\n    minibatch.seeds = torch.tensor([10, 15])\n    minibatch.labels = torch.tensor([2, 5])\n    # Act\n    dgl_blocks = minibatch.blocks\n\n    # Assert\n    assert len(dgl_blocks) == 2\n    check_dgl_blocks_homo(minibatch, dgl_blocks)\n\n\ndef test_dgl_node_classification_hetero():\n    minibatch = create_hetero_minibatch()\n    minibatch.labels = {\"B\": torch.tensor([2, 5])}\n    minibatch.seeds = {\"B\": torch.tensor([10, 15])}\n    # Act\n    dgl_blocks = minibatch.blocks\n\n    # Assert\n    assert len(dgl_blocks) == 2\n    check_dgl_blocks_hetero(minibatch, dgl_blocks)\n\n\ndef test_dgl_link_predication_homo():\n    # Arrange\n    minibatch = create_homo_minibatch()\n    minibatch.compacted_seeds = (\n        torch.tensor([[0, 1, 0, 0, 1, 1], [1, 0, 1, 1, 0, 0]]).T,\n    )\n    minibatch.labels = torch.tensor([1, 1, 0, 0, 0, 0])\n    # Act\n    dgl_blocks = minibatch.blocks\n\n    # Assert\n    assert len(dgl_blocks) == 2\n    check_dgl_blocks_homo(minibatch, dgl_blocks)\n\n\ndef test_dgl_link_predication_hetero():\n    # Arrange\n    minibatch = create_hetero_minibatch()\n    minibatch.compacted_seeds = {\n        relation: (torch.tensor([[1, 1, 2, 0, 1, 2], [1, 0, 1, 1, 0, 0]]).T,),\n        reverse_relation: (\n            torch.tensor([[0, 1, 1, 2, 0, 2], [1, 0, 1, 1, 0, 0]]).T,\n        ),\n    }\n    minibatch.labels = {\n        relation: (torch.tensor([1, 1, 0, 0, 0, 0]),),\n        reverse_relation: (torch.tensor([1, 
1, 0, 0, 0, 0]),),\n    }\n    # Act\n    dgl_blocks = minibatch.blocks\n\n    # Assert\n    assert len(dgl_blocks) == 2\n    check_dgl_blocks_hetero(minibatch, dgl_blocks)\n\n\ndef test_to_pyg_data():\n    test_minibatch = create_homo_minibatch()\n    test_minibatch.seeds = torch.tensor([0, 1])\n    test_minibatch.labels = torch.tensor([7, 8])\n\n    expected_edge_index = torch.tensor(\n        [[0, 0, 1, 1, 1, 2, 2, 2, 2], [0, 1, 0, 1, 2, 0, 1, 2, 3]]\n    )\n    expected_node_features = next(iter(test_minibatch.node_features.values()))\n    expected_labels = torch.tensor([7, 8])\n    expected_batch_size = 2\n    expected_n_id = torch.tensor([10, 11, 12, 13])\n\n    pyg_data = test_minibatch.to_pyg_data()\n    pyg_data.validate()\n    assert torch.equal(pyg_data.edge_index, expected_edge_index)\n    assert torch.equal(pyg_data.x, expected_node_features)\n    assert torch.equal(pyg_data.y, expected_labels)\n    assert pyg_data.batch_size == expected_batch_size\n    assert torch.equal(pyg_data.n_id, expected_n_id)\n\n    test_minibatch.seeds = torch.tensor([[0, 1], [2, 3]])\n    assert pyg_data.batch_size == expected_batch_size\n\n    test_minibatch.seeds = {\"A\": torch.tensor([0, 1])}\n    assert pyg_data.batch_size == expected_batch_size\n\n    test_minibatch.seeds = {\"A\": torch.tensor([[0, 1], [2, 3]])}\n    assert pyg_data.batch_size == expected_batch_size\n\n    subgraph = test_minibatch.sampled_subgraphs[0]\n    # Test with sampled_csc as None.\n    test_minibatch = gb.MiniBatch(\n        sampled_subgraphs=None,\n        node_features={\"feat\": expected_node_features},\n        labels=expected_labels,\n    )\n    pyg_data = test_minibatch.to_pyg_data()\n    assert pyg_data.edge_index is None, \"Edge index should be none.\"\n\n    # Test with node_features as None.\n    test_minibatch = gb.MiniBatch(\n        sampled_subgraphs=[subgraph],\n        node_features=None,\n        labels=expected_labels,\n    )\n    pyg_data = test_minibatch.to_pyg_data()\n    
assert pyg_data.x is None, \"Node features should be None.\"\n\n    # Test with labels as None.\n    test_minibatch = gb.MiniBatch(\n        sampled_subgraphs=[subgraph],\n        node_features={\"feat\": expected_node_features},\n        labels=None,\n    )\n    pyg_data = test_minibatch.to_pyg_data()\n    assert pyg_data.y is None, \"Labels should be None.\"\n\n    # Test with multiple features.\n    test_minibatch = gb.MiniBatch(\n        sampled_subgraphs=[subgraph],\n        node_features={\n            \"feat\": expected_node_features,\n            \"extra_feat\": torch.tensor([[3], [4]]),\n        },\n        labels=expected_labels,\n    )\n    try:\n        pyg_data = test_minibatch.to_pyg_data()\n        assert (\n            pyg_data.x is None\n        ), \"Multiple features case should raise an error.\"\n    except AssertionError as e:\n        assert (\n            str(e)\n            == \"`to_pyg_data` only supports single feature homogeneous graph.\"\n        )\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/test_subgraph_sampler.py",
    "content": "import unittest\nimport warnings\n\nfrom enum import Enum\nfrom functools import partial\n\nimport backend as F\n\nimport dgl\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\n\nfrom . import gb_test_utils\n\n\ndef _check_sampler_len(sampler, lenExp):\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        assert len(list(sampler)) == lenExp\n\n\nclass SamplerType(Enum):\n    Normal = 0\n    Layer = 1\n    Temporal = 2\n    TemporalLayer = 3\n\n\ndef _get_sampler(sampler_type):\n    if sampler_type == SamplerType.Normal:\n        return gb.NeighborSampler\n    if sampler_type == SamplerType.Layer:\n        return gb.LayerNeighborSampler\n    if sampler_type == SamplerType.Temporal:\n        return partial(\n            gb.TemporalNeighborSampler,\n            node_timestamp_attr_name=\"timestamp\",\n            edge_timestamp_attr_name=\"timestamp\",\n        )\n    else:\n        return partial(\n            gb.TemporalLayerNeighborSampler,\n            node_timestamp_attr_name=\"timestamp\",\n            edge_timestamp_attr_name=\"timestamp\",\n        )\n\n\ndef _is_temporal(sampler_type):\n    return sampler_type in [SamplerType.Temporal, SamplerType.TemporalLayer]\n\n\ndef get_hetero_graph():\n    # COO graph:\n    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]\n    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]\n    # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.\n    # num_nodes = 5, num_n1 = 2, num_n2 = 3\n    ntypes = {\"n1\": 0, \"n2\": 1}\n    etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])\n    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])\n    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])\n    node_type_offset = torch.LongTensor([0, 2, 5])\n    return gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        
node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    )\n\n\ndef _assert_hetero_values(\n    datapipe, original_row_node_ids, original_column_node_ids, csc_formats\n):\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            for ntype in [\"n1\", \"n2\"]:\n                assert torch.equal(\n                    sampled_subgraph.original_row_node_ids[ntype],\n                    original_row_node_ids[step][ntype].to(F.ctx()),\n                )\n                assert torch.equal(\n                    sampled_subgraph.original_column_node_ids[ntype],\n                    original_column_node_ids[step][ntype].to(F.ctx()),\n                )\n            for etype in [\"n1:e1:n2\", \"n2:e2:n1\"]:\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indices,\n                    csc_formats[step][etype].indices.to(F.ctx()),\n                )\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indptr,\n                    csc_formats[step][etype].indptr.to(F.ctx()),\n                )\n\n\ndef _assert_homo_values(\n    datapipe, original_row_node_ids, compacted_indices, indptr, seeds\n):\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            assert torch.equal(\n                sampled_subgraph.original_row_node_ids,\n                original_row_node_ids[step],\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indices, compacted_indices[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indptr, indptr[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.original_column_node_ids, seeds[step]\n            )\n\n\ndef test_SubgraphSampler_invoke():\n    itemset = gb.ItemSet(torch.arange(10), names=\"seeds\")\n    item_sampler = 
gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n\n    # Invoke via class constructor.\n    datapipe = gb.SubgraphSampler(item_sampler)\n    with pytest.raises(NotImplementedError):\n        next(iter(datapipe))\n\n    # Invokde via functional form.\n    datapipe = item_sampler.sample_subgraph()\n    with pytest.raises(NotImplementedError):\n        next(iter(datapipe))\n\n\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_NeighborSampler_invoke(labor):\n    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(\n        F.ctx()\n    )\n    itemset = gb.ItemSet(torch.arange(10), names=\"seeds\")\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n\n    # Invoke via class constructor.\n    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler\n    datapipe = Sampler(item_sampler, graph, fanouts)\n    assert len(list(datapipe)) == 5\n\n    # Invokde via functional form.\n    if labor:\n        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)\n    else:\n        datapipe = item_sampler.sample_neighbor(graph, fanouts)\n    assert len(list(datapipe)) == 5\n\n\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_NeighborSampler_fanouts(labor):\n    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(\n        F.ctx()\n    )\n    itemset = gb.ItemSet(torch.arange(10), names=\"seeds\")\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n\n    # `fanouts` is a list of tensors.\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    if labor:\n        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)\n    else:\n        datapipe = item_sampler.sample_neighbor(graph, fanouts)\n    assert len(list(datapipe)) == 5\n\n    # `fanouts` is a list of integers.\n    fanouts = [2 for _ in range(num_layer)]\n    if 
labor:\n        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)\n    else:\n        datapipe = item_sampler.sample_neighbor(graph, fanouts)\n    assert len(list(datapipe)) == 5\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_Node(sampler_type):\n    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(\n        F.ctx()\n    )\n    items = torch.arange(10)\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\"timestamp\": torch.arange(20).to(F.ctx())}\n        graph.edge_attributes = {\n            \"timestamp\": torch.arange(len(graph.indices)).to(F.ctx())\n        }\n        items = (items, torch.arange(10))\n        names = (names, \"timestamp\")\n    itemset = gb.ItemSet(items, names=names)\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n    sampler_dp = sampler(item_sampler, graph, fanouts)\n    _check_sampler_len(sampler_dp, 5)\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_Link(sampler_type):\n    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(\n        F.ctx()\n    )\n    items = torch.arange(20).reshape(-1, 2)\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\"timestamp\": torch.arange(20).to(F.ctx())}\n        graph.edge_attributes = {\n            \"timestamp\": torch.arange(len(graph.indices)).to(F.ctx())\n        }\n        items = (items, torch.arange(10))\n        names = (names, \"timestamp\")\n    itemset = 
gb.ItemSet(items, names=names)\n    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n    datapipe = sampler(datapipe, graph, fanouts)\n    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))\n    _check_sampler_len(datapipe, 5)\n    for data in datapipe:\n        assert torch.equal(\n            data.compacted_seeds, torch.tensor([[0, 1], [2, 3]]).to(F.ctx())\n        )\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_Link_With_Negative(sampler_type):\n    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(\n        F.ctx()\n    )\n    items = torch.arange(20).reshape(-1, 2)\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\"timestamp\": torch.arange(20).to(F.ctx())}\n        graph.edge_attributes = {\n            \"timestamp\": torch.arange(len(graph.indices)).to(F.ctx())\n        }\n        items = (items, torch.arange(10))\n        names = (names, \"timestamp\")\n    itemset = gb.ItemSet(items, names=names)\n    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)\n    sampler = _get_sampler(sampler_type)\n    datapipe = sampler(datapipe, graph, fanouts)\n    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))\n    _check_sampler_len(datapipe, 5)\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_HyperLink(sampler_type):\n    graph = 
gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(\n        F.ctx()\n    )\n    items = torch.arange(20).reshape(-1, 5)\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\"timestamp\": torch.arange(20).to(F.ctx())}\n        graph.edge_attributes = {\n            \"timestamp\": torch.arange(len(graph.indices)).to(F.ctx())\n        }\n        items = (items, torch.arange(4))\n        names = (names, \"timestamp\")\n    itemset = gb.ItemSet(items, names=names)\n    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n    datapipe = sampler(datapipe, graph, fanouts)\n    _check_sampler_len(datapipe, 2)\n    for data in datapipe:\n        assert torch.equal(\n            data.compacted_seeds,\n            torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]).to(F.ctx()),\n        )\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_Node_Hetero(sampler_type):\n    graph = get_hetero_graph().to(F.ctx())\n    items = torch.arange(3)\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.arange(graph.indices.numel()).to(F.ctx())\n        }\n        items = (items, torch.randint(0, 10, (3,)))\n        names = (names, \"timestamp\")\n    itemset = gb.HeteroItemSet({\"n2\": gb.ItemSet(items, names=names)})\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n    sampler_dp = 
sampler(item_sampler, graph, fanouts)\n    _check_sampler_len(sampler_dp, 2)\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        for minibatch in sampler_dp:\n            assert len(minibatch.sampled_subgraphs) == num_layer\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_Link_Hetero(sampler_type):\n    graph = get_hetero_graph().to(F.ctx())\n    first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T\n    first_names = \"seeds\"\n    second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T\n    second_names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.arange(graph.indices.numel()).to(F.ctx())\n        }\n        first_items = (first_items, torch.randint(0, 10, (4,)))\n        first_names = (first_names, \"timestamp\")\n        second_items = (second_items, torch.randint(0, 10, (6,)))\n        second_names = (second_names, \"timestamp\")\n    itemset = gb.HeteroItemSet(\n        {\n            \"n1:e1:n2\": gb.ItemSet(\n                first_items,\n                names=first_names,\n            ),\n            \"n2:e2:n1\": gb.ItemSet(\n                second_items,\n                names=second_names,\n            ),\n        }\n    )\n\n    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n    datapipe = sampler(datapipe, graph, fanouts)\n    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))\n    _check_sampler_len(datapipe, 5)\n    for data in datapipe:\n    
    for compacted_seeds in data.compacted_seeds.values():\n            if _is_temporal(sampler_type):\n                assert torch.equal(\n                    compacted_seeds, torch.tensor([[0, 0], [1, 1]]).to(F.ctx())\n                )\n            else:\n                assert torch.equal(\n                    torch.sort(compacted_seeds.T, dim=1)[0].T,\n                    torch.tensor([[0, 0], [0, 1]]).to(F.ctx()),\n                )\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):\n    graph = get_hetero_graph().to(F.ctx())\n    first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T\n    first_names = \"seeds\"\n    second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T\n    second_names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.arange(graph.indices.numel()).to(F.ctx())\n        }\n        first_items = (first_items, torch.randint(0, 10, (4,)))\n        first_names = (first_names, \"timestamp\")\n        second_items = (second_items, torch.randint(0, 10, (6,)))\n        second_names = (second_names, \"timestamp\")\n    itemset = gb.HeteroItemSet(\n        {\n            \"n1:e1:n2\": gb.ItemSet(\n                first_items,\n                names=first_names,\n            ),\n            \"n2:e2:n1\": gb.ItemSet(\n                second_items,\n                names=second_names,\n            ),\n        }\n    )\n\n    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    datapipe = gb.UniformNegativeSampler(datapipe, 
graph, 1)\n    sampler = _get_sampler(sampler_type)\n    datapipe = sampler(datapipe, graph, fanouts)\n    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))\n    _check_sampler_len(datapipe, 5)\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):\n    graph = get_hetero_graph().to(F.ctx())\n    first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T\n    first_names = \"seeds\"\n    second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T\n    second_names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.arange(graph.indices.numel()).to(F.ctx())\n        }\n        first_items = (first_items, torch.randint(0, 10, (4,)))\n        first_names = (first_names, \"timestamp\")\n        second_items = (second_items, torch.randint(0, 10, (6,)))\n        second_names = (second_names, \"timestamp\")\n    # \"e11\" and \"e22\" are not valid edge types.\n    itemset = gb.HeteroItemSet(\n        {\n            \"n1:e11:n2\": gb.ItemSet(\n                first_items,\n                names=first_names,\n            ),\n            \"n2:e22:n1\": gb.ItemSet(\n                second_items,\n                names=second_names,\n            ),\n        }\n    )\n\n    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n    datapipe = sampler(datapipe, graph, fanouts)\n    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))\n    _check_sampler_len(datapipe, 5)\n\n\n@pytest.mark.parametrize(\n    
\"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):\n    graph = get_hetero_graph().to(F.ctx())\n    first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T\n    first_names = \"seeds\"\n    second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T\n    second_names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.arange(graph.indices.numel()).to(F.ctx())\n        }\n        first_items = (first_items, torch.randint(0, 10, (4,)))\n        first_names = (first_names, \"timestamp\")\n        second_items = (second_items, torch.randint(0, 10, (6,)))\n        second_names = (second_names, \"timestamp\")\n    # \"e11\" and \"e22\" are not valid edge types.\n    itemset = gb.HeteroItemSet(\n        {\n            \"n1:e11:n2\": gb.ItemSet(\n                first_items,\n                names=first_names,\n            ),\n            \"n2:e22:n1\": gb.ItemSet(\n                second_items,\n                names=second_names,\n            ),\n        }\n    )\n\n    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)\n    sampler = _get_sampler(sampler_type)\n    datapipe = sampler(datapipe, graph, fanouts)\n    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))\n    _check_sampler_len(datapipe, 5)\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef 
test_SubgraphSampler_HyperLink_Hetero(sampler_type):\n    graph = get_hetero_graph().to(F.ctx())\n    items = torch.LongTensor([[2, 0, 1, 1, 2], [0, 1, 1, 0, 0]])\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.arange(graph.indices.numel()).to(F.ctx())\n        }\n        items = (items, torch.randint(0, 10, (2,)))\n        names = (names, \"timestamp\")\n    itemset = gb.HeteroItemSet(\n        {\n            \"n2:n1:n2:n1:n2\": gb.ItemSet(\n                items,\n                names=names,\n            ),\n        }\n    )\n\n    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n    datapipe = sampler(datapipe, graph, fanouts)\n    _check_sampler_len(datapipe, 1)\n    for data in datapipe:\n        for compacted_seeds in data.compacted_seeds.values():\n            if _is_temporal(sampler_type):\n                assert torch.equal(\n                    compacted_seeds,\n                    torch.tensor([[0, 0, 2, 2, 4], [1, 1, 3, 3, 5]]).to(\n                        F.ctx()\n                    ),\n                )\n            else:\n                assert torch.equal(\n                    compacted_seeds,\n                    torch.tensor([[0, 0, 2, 1, 0], [1, 1, 2, 0, 1]]).to(\n                        F.ctx()\n                    ),\n                )\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\n@pytest.mark.parametrize(\n    \"replace\",\n    [False, True],\n)\ndef test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):\n    if F._default_context_str == 
\"gpu\" and replace == True:\n        pytest.skip(\"Sampling with replacement not yet supported on GPU.\")\n    num_nodes = 5\n    num_edges = 9\n    num_ntypes = 3\n    num_etypes = 3\n    (\n        csc_indptr,\n        indices,\n        node_type_offset,\n        type_per_edge,\n        node_type_to_id,\n        edge_type_to_id,\n    ) = gb_test_utils.random_hetero_graph(\n        num_nodes, num_edges, num_ntypes, num_etypes\n    )\n    node_attributes = {}\n    edge_attributes = {\n        \"A1\": torch.randn(num_edges),\n        \"A2\": torch.randn(num_edges),\n    }\n    if _is_temporal(sampler_type):\n        node_attributes[\"timestamp\"] = torch.randint(0, 10, (num_nodes,))\n        edge_attributes[\"timestamp\"] = torch.randint(0, 10, (num_edges,))\n    graph = gb.fused_csc_sampling_graph(\n        csc_indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n        node_type_to_id=node_type_to_id,\n        edge_type_to_id=edge_type_to_id,\n        node_attributes=node_attributes,\n        edge_attributes=edge_attributes,\n    ).to(F.ctx())\n    first_items = torch.tensor([0])\n    first_names = \"seeds\"\n    second_items = torch.tensor([0])\n    second_names = \"seeds\"\n    if _is_temporal(sampler_type):\n        first_items = (first_items, torch.randint(0, 10, (1,)))\n        first_names = (first_names, \"timestamp\")\n        second_items = (second_items, torch.randint(0, 10, (1,)))\n        second_names = (second_names, \"timestamp\")\n    itemset = gb.HeteroItemSet(\n        {\n            \"n2\": gb.ItemSet(first_items, names=first_names),\n            \"n1\": gb.ItemSet(second_items, names=second_names),\n        }\n    )\n\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n\n    sampler_dp = sampler(item_sampler, graph, fanouts, 
replace=replace)\n\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        for data in sampler_dp:\n            for sampledsubgraph in data.sampled_subgraphs:\n                for _, value in sampledsubgraph.sampled_csc.items():\n                    assert torch.equal(\n                        torch.ge(\n                            value.indices,\n                            torch.zeros(len(value.indices)).to(F.ctx()),\n                        ),\n                        torch.ones(len(value.indices)).to(F.ctx()),\n                    )\n                    assert torch.equal(\n                        torch.ge(\n                            value.indptr,\n                            torch.zeros(len(value.indptr)).to(F.ctx()),\n                        ),\n                        torch.ones(len(value.indptr)).to(F.ctx()),\n                    )\n                for (\n                    _,\n                    value,\n                ) in sampledsubgraph.original_column_node_ids.items():\n                    assert torch.equal(\n                        torch.ge(value, torch.zeros(len(value)).to(F.ctx())),\n                        torch.ones(len(value)).to(F.ctx()),\n                    )\n                for _, value in sampledsubgraph.original_row_node_ids.items():\n                    assert torch.equal(\n                        torch.ge(value, torch.zeros(len(value)).to(F.ctx())),\n                        torch.ones(len(value)).to(F.ctx()),\n                    )\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_without_deduplication_Homo_Node(sampler_type):\n    graph = dgl.graph(\n        ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])\n    )\n    graph = gb.from_dglgraph(graph, True).to(F.ctx())\n    seed_nodes = 
torch.LongTensor([0, 3, 4])\n    items = seed_nodes\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.zeros(\n                graph.csc_indptr.numel() - 1, dtype=torch.int64\n            ).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.zeros(\n                graph.indices.numel(), dtype=torch.int64\n            ).to(F.ctx())\n        }\n        items = (items, torch.randint(1, 10, (3,)))\n        names = (names, \"timestamp\")\n\n    itemset = gb.ItemSet(items, names=names)\n    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(\n        F.ctx()\n    )\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n\n    sampler = _get_sampler(sampler_type)\n    if _is_temporal(sampler_type):\n        datapipe = sampler(item_sampler, graph, fanouts)\n    else:\n        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)\n\n    length = [17, 7]\n    compacted_indices = [\n        (torch.arange(0, 10) + 7).to(F.ctx()),\n        (torch.arange(0, 4) + 3).to(F.ctx()),\n    ]\n    indptr = [\n        torch.tensor([0, 1, 2, 4, 4, 6, 8, 10]).to(F.ctx()),\n        torch.tensor([0, 1, 2, 4]).to(F.ctx()),\n    ]\n    seeds = [\n        torch.tensor([0, 2, 2, 3, 4, 4, 5]).to(F.ctx()),\n        torch.tensor([0, 3, 4]).to(F.ctx()),\n    ]\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        for data in datapipe:\n            for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n                assert (\n                    len(sampled_subgraph.original_row_node_ids) == length[step]\n                )\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc.indices,\n                    compacted_indices[step],\n                )\n                assert torch.equal(\n                    
sampled_subgraph.sampled_csc.indptr, indptr[step]\n                )\n                assert torch.equal(\n                    torch.sort(sampled_subgraph.original_column_node_ids)[0],\n                    seeds[step],\n                )\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_without_deduplication_Hetero_Node(sampler_type):\n    graph = get_hetero_graph().to(F.ctx())\n    items = torch.arange(2)\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.zeros(\n                graph.csc_indptr.numel() - 1, dtype=torch.int64, device=F.ctx()\n            )\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.zeros(\n                graph.indices.numel(), dtype=torch.int64, device=F.ctx()\n            )\n        }\n        items = (items, torch.randint(1, 10, (2,)))\n        names = (names, \"timestamp\")\n    itemset = gb.HeteroItemSet({\"n2\": gb.ItemSet(items, names=names)})\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n    if _is_temporal(sampler_type):\n        datapipe = sampler(item_sampler, graph, fanouts)\n    else:\n        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)\n    csc_formats = [\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4]),\n                indices=torch.tensor([4, 5, 6, 7]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4, 6, 8]),\n                indices=torch.tensor([2, 3, 4, 5, 6, 7, 8, 9]),\n            ),\n        },\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n     
           indptr=torch.tensor([0, 2, 4]),\n                indices=torch.tensor([0, 1, 2, 3]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0]),\n                indices=torch.tensor([], dtype=torch.int64),\n            ),\n        },\n    ]\n    original_column_node_ids = [\n        {\n            \"n1\": torch.tensor([0, 1, 1, 0]),\n            \"n2\": torch.tensor([0, 1]),\n        },\n        {\n            \"n1\": torch.tensor([], dtype=torch.int64),\n            \"n2\": torch.tensor([0, 1]),\n        },\n    ]\n    original_row_node_ids = [\n        {\n            \"n1\": torch.tensor([0, 1, 1, 0, 0, 1, 1, 0]),\n            \"n2\": torch.tensor([0, 1, 0, 2, 0, 1, 0, 1, 0, 2]),\n        },\n        {\n            \"n1\": torch.tensor([0, 1, 1, 0]),\n            \"n2\": torch.tensor([0, 1]),\n        },\n    ]\n\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        _assert_hetero_values(\n            datapipe,\n            original_row_node_ids,\n            original_column_node_ids,\n            csc_formats,\n        )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Fails due to different result on the GPU.\",\n)\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_SubgraphSampler_unique_csc_format_Homo_Node_cpu(labor):\n    torch.manual_seed(1205)\n    graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))\n    graph = gb.from_dglgraph(graph, True).to(F.ctx())\n    seed_nodes = torch.LongTensor([0, 3, 4])\n\n    itemset = gb.ItemSet(seed_nodes, names=\"seeds\")\n    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(\n        F.ctx()\n    )\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n\n    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler\n    datapipe = Sampler(\n        item_sampler,\n        graph,\n        fanouts,\n 
       deduplicate=True,\n    )\n\n    original_row_node_ids = [\n        torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),\n        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n    ]\n    compacted_indices = [\n        torch.tensor([3, 4, 4, 2, 5, 6]).to(F.ctx()),\n        torch.tensor([3, 4, 4, 2]).to(F.ctx()),\n    ]\n    indptr = [\n        torch.tensor([0, 1, 2, 4, 4, 6]).to(F.ctx()),\n        torch.tensor([0, 1, 2, 4]).to(F.ctx()),\n    ]\n    seeds = [\n        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n        torch.tensor([0, 3, 4]).to(F.ctx()),\n    ]\n    _assert_homo_values(\n        datapipe, original_row_node_ids, compacted_indices, indptr, seeds\n    )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"Fails due to different result on the CPU.\",\n)\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_SubgraphSampler_unique_csc_format_Homo_Node_gpu(labor):\n    torch.manual_seed(1205)\n    graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))\n    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())\n    seed_nodes = torch.LongTensor([0, 3, 4])\n\n    itemset = gb.ItemSet(seed_nodes, names=\"seeds\")\n    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(\n        F.ctx()\n    )\n    num_layer = 2\n    fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]\n\n    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler\n    datapipe = Sampler(\n        item_sampler,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n\n    if torch.cuda.get_device_capability()[0] < 7:\n        original_row_node_ids = [\n            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),\n            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),\n        ]\n        compacted_indices = [\n            torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),\n            torch.tensor([4, 3, 2]).to(F.ctx()),\n        ]\n        indptr = [\n            torch.tensor([0, 1, 2, 3, 5, 
5]).to(F.ctx()),\n            torch.tensor([0, 1, 2, 3]).to(F.ctx()),\n        ]\n        seeds = [\n            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),\n            torch.tensor([0, 3, 4]).to(F.ctx()),\n        ]\n    else:\n        original_row_node_ids = [\n            torch.tensor([0, 3, 4, 5, 2, 7]).to(F.ctx()),\n            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n        ]\n        compacted_indices = [\n            torch.tensor([3, 4, 2, 5, 5]).to(F.ctx()),\n            torch.tensor([3, 4, 2]).to(F.ctx()),\n        ]\n        indptr = [\n            torch.tensor([0, 1, 2, 3, 3, 5]).to(F.ctx()),\n            torch.tensor([0, 1, 2, 3]).to(F.ctx()),\n        ]\n        seeds = [\n            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n            torch.tensor([0, 3, 4]).to(F.ctx()),\n        ]\n\n    _assert_homo_values(\n        datapipe, original_row_node_ids, compacted_indices, indptr, seeds\n    )\n\n\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_SubgraphSampler_unique_csc_format_Hetero_Node(labor):\n    graph = get_hetero_graph().to(F.ctx())\n    itemset = gb.HeteroItemSet(\n        {\"n2\": gb.ItemSet(torch.arange(2), names=\"seeds\")}\n    )\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler\n    datapipe = Sampler(\n        item_sampler,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n    csc_formats = [\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4]),\n                indices=torch.tensor([0, 1, 1, 0]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4]),\n                indices=torch.tensor([0, 2, 0, 1]),\n            ),\n        },\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                
indptr=torch.tensor([0, 2, 4]),\n                indices=torch.tensor([0, 1, 1, 0]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0]),\n                indices=torch.tensor([], dtype=torch.int64),\n            ),\n        },\n    ]\n    original_column_node_ids = [\n        {\n            \"n1\": torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1]),\n        },\n        {\n            \"n1\": torch.tensor([], dtype=torch.int64),\n            \"n2\": torch.tensor([0, 1]),\n        },\n    ]\n    original_row_node_ids = [\n        {\n            \"n1\": torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1, 2]),\n        },\n        {\n            \"n1\": torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1]),\n        },\n    ]\n\n    _assert_hetero_values(\n        datapipe, original_row_node_ids, original_column_node_ids, csc_formats\n    )\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):\n    graph = get_hetero_graph().to(F.ctx())\n    items_n1 = torch.tensor([0])\n    items_n2 = torch.tensor([1])\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.arange(graph.indices.numel()).to(F.ctx())\n        }\n        # All edges can be sampled.\n        items_n1 = (items_n1, torch.tensor([10]))\n        items_n2 = (items_n2, torch.tensor([10]))\n        names = (names, \"timestamp\")\n    itemset = gb.HeteroItemSet(\n        {\n            \"n1\": gb.ItemSet(items=items_n1, names=names),\n            \"n2\": gb.ItemSet(items=items_n2, names=names),\n        }\n    )\n    
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    # The number of edges to be sampled for each edge types of each node.\n    fanouts = [torch.LongTensor([2, 1]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n    sampler_dp = sampler(item_sampler, graph, fanouts)\n    if _is_temporal(sampler_type):\n        indices_len = [\n            {\n                \"n1:e1:n2\": 4,\n                \"n2:e2:n1\": 3,\n            },\n            {\n                \"n1:e1:n2\": 2,\n                \"n2:e2:n1\": 1,\n            },\n        ]\n    else:\n        indices_len = [\n            {\n                \"n1:e1:n2\": 4,\n                \"n2:e2:n1\": 2,\n            },\n            {\n                \"n1:e1:n2\": 2,\n                \"n2:e2:n1\": 1,\n            },\n        ]\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        for minibatch in sampler_dp:\n            for step, sampled_subgraph in enumerate(\n                minibatch.sampled_subgraphs\n            ):\n                assert (\n                    len(sampled_subgraph.sampled_csc[\"n1:e1:n2\"].indices)\n                    == indices_len[step][\"n1:e1:n2\"]\n                )\n                assert (\n                    len(sampled_subgraph.sampled_csc[\"n2:e2:n1\"].indices)\n                    == indices_len[step][\"n2:e2:n1\"]\n                )\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_without_deduplication_Homo_Link(sampler_type):\n    graph = dgl.graph(\n        ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])\n    )\n    graph = gb.from_dglgraph(graph, True).to(F.ctx())\n    seed_nodes = torch.LongTensor([[0, 1], [3, 5]])\n    items = seed_nodes\n    names = \"seeds\"\n    if 
_is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.zeros(\n                graph.csc_indptr.numel() - 1, dtype=torch.int64\n            ).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.zeros(\n                graph.indices.numel(), dtype=torch.int64\n            ).to(F.ctx())\n        }\n        items = (items, torch.randint(1, 10, (2,)))\n        names = (names, \"timestamp\")\n\n    itemset = gb.ItemSet(items, names=names)\n    item_sampler = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n\n    sampler = _get_sampler(sampler_type)\n    if _is_temporal(sampler_type):\n        datapipe = sampler(item_sampler, graph, fanouts)\n    else:\n        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)\n\n    length = [13, 7]\n    compacted_indices = [\n        (torch.arange(0, 6) + 7).to(F.ctx()),\n        (torch.arange(0, 3) + 4).to(F.ctx()),\n    ]\n    indptr = [\n        torch.tensor([0, 1, 2, 3, 3, 3, 4, 6]).to(F.ctx()),\n        torch.tensor([0, 1, 2, 3, 3]).to(F.ctx()),\n    ]\n    seeds = [\n        torch.tensor([0, 0, 1, 2, 3, 5, 5]).to(F.ctx()),\n        torch.tensor([0, 1, 3, 5]).to(F.ctx()),\n    ]\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            assert len(sampled_subgraph.original_row_node_ids) == length[step]\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indices, compacted_indices[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indptr, indptr[step]\n            )\n            assert torch.equal(\n                torch.sort(sampled_subgraph.original_column_node_ids)[0],\n                seeds[step],\n            )\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        
SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_without_deduplication_Hetero_Link(sampler_type):\n    graph = get_hetero_graph().to(F.ctx())\n    items = torch.arange(2).view(1, 2)\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.zeros(\n                graph.csc_indptr.numel() - 1, dtype=torch.int64\n            ).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.zeros(\n                graph.indices.numel(), dtype=torch.int64\n            ).to(F.ctx())\n        }\n        items = (items, torch.randint(1, 10, (1,)))\n        names = (names, \"timestamp\")\n    itemset = gb.HeteroItemSet({\"n1:e1:n2\": gb.ItemSet(items, names=names)})\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n    if _is_temporal(sampler_type):\n        datapipe = sampler(item_sampler, graph, fanouts)\n    else:\n        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)\n    csc_formats = [\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4, 6]),\n                indices=torch.tensor([3, 4, 5, 6, 7, 8]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4, 6]),\n                indices=torch.tensor([3, 4, 5, 6, 7, 8]),\n            ),\n        },\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2]),\n                indices=torch.tensor([1, 2]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2]),\n                indices=torch.tensor([1, 2], dtype=torch.int64),\n            ),\n        },\n    ]\n    original_column_node_ids = 
[\n        {\n            \"n1\": torch.tensor([0, 1, 0]),\n            \"n2\": torch.tensor([1, 0, 2]),\n        },\n        {\n            \"n1\": torch.tensor([0]),\n            \"n2\": torch.tensor([1]),\n        },\n    ]\n    original_row_node_ids = [\n        {\n            \"n1\": torch.tensor([0, 1, 0, 1, 0, 0, 1, 0, 1]),\n            \"n2\": torch.tensor([1, 0, 2, 0, 2, 0, 1, 0, 2]),\n        },\n        {\n            \"n1\": torch.tensor([0, 1, 0]),\n            \"n2\": torch.tensor([1, 0, 2]),\n        },\n    ]\n\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            for ntype in [\"n1\", \"n2\"]:\n                assert torch.equal(\n                    sampled_subgraph.original_row_node_ids[ntype],\n                    original_row_node_ids[step][ntype].to(F.ctx()),\n                )\n                assert torch.equal(\n                    sampled_subgraph.original_column_node_ids[ntype],\n                    original_column_node_ids[step][ntype].to(F.ctx()),\n                )\n            for etype in [\"n1:e1:n2\", \"n2:e2:n1\"]:\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indices,\n                    csc_formats[step][etype].indices.to(F.ctx()),\n                )\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indptr,\n                    csc_formats[step][etype].indptr.to(F.ctx()),\n                )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Fails due to different result on the GPU.\",\n)\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_SubgraphSampler_unique_csc_format_Homo_Link_cpu(labor):\n    torch.manual_seed(1205)\n    graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))\n    graph = gb.from_dglgraph(graph, True).to(F.ctx())\n    seed_nodes = torch.LongTensor([[0, 3], [4, 4]])\n\n    itemset = gb.ItemSet(seed_nodes, 
names=\"seeds\")\n    item_sampler = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n\n    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler\n    datapipe = Sampler(\n        item_sampler,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n\n    original_row_node_ids = [\n        torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),\n        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n    ]\n    compacted_indices = [\n        torch.tensor([3, 4, 4, 2, 5, 6]).to(F.ctx()),\n        torch.tensor([3, 4, 4, 2]).to(F.ctx()),\n    ]\n    indptr = [\n        torch.tensor([0, 1, 2, 4, 4, 6]).to(F.ctx()),\n        torch.tensor([0, 1, 2, 4]).to(F.ctx()),\n    ]\n    seeds = [\n        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n        torch.tensor([0, 3, 4]).to(F.ctx()),\n    ]\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            assert torch.equal(\n                sampled_subgraph.original_row_node_ids,\n                original_row_node_ids[step],\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indices, compacted_indices[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indptr, indptr[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.original_column_node_ids, seeds[step]\n            )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"Fails due to different result on the CPU.\",\n)\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_SubgraphSampler_unique_csc_format_Homo_Link_gpu(labor):\n    torch.manual_seed(1205)\n    graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))\n    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())\n    seed_nodes = torch.LongTensor([[0, 3], [4, 4]])\n\n    itemset = 
gb.ItemSet(seed_nodes, names=\"seeds\")\n    item_sampler = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]\n\n    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler\n    datapipe = Sampler(\n        item_sampler,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n\n    if torch.cuda.get_device_capability()[0] < 7:\n        original_row_node_ids = [\n            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),\n            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),\n        ]\n        compacted_indices = [\n            torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),\n            torch.tensor([4, 3, 2]).to(F.ctx()),\n        ]\n        indptr = [\n            torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),\n            torch.tensor([0, 1, 2, 3]).to(F.ctx()),\n        ]\n        seeds = [\n            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),\n            torch.tensor([0, 3, 4]).to(F.ctx()),\n        ]\n    else:\n        original_row_node_ids = [\n            torch.tensor([0, 3, 4, 5, 2, 7]).to(F.ctx()),\n            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n        ]\n        compacted_indices = [\n            torch.tensor([3, 4, 2, 5, 5]).to(F.ctx()),\n            torch.tensor([3, 4, 2]).to(F.ctx()),\n        ]\n        indptr = [\n            torch.tensor([0, 1, 2, 3, 3, 5]).to(F.ctx()),\n            torch.tensor([0, 1, 2, 3]).to(F.ctx()),\n        ]\n        seeds = [\n            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n            torch.tensor([0, 3, 4]).to(F.ctx()),\n        ]\n\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            assert torch.equal(\n                sampled_subgraph.original_row_node_ids,\n                original_row_node_ids[step],\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indices, compacted_indices[step]\n   
         )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indptr, indptr[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.original_column_node_ids, seeds[step]\n            )\n\n\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_SubgraphSampler_unique_csc_format_Hetero_Link(labor):\n    graph = get_hetero_graph().to(F.ctx())\n    itemset = gb.HeteroItemSet(\n        {\"n1:e1:n2\": gb.ItemSet(torch.tensor([[0, 1]]), names=\"seeds\")}\n    )\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler\n    datapipe = Sampler(\n        item_sampler,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n    csc_formats = [\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4, 6]),\n                indices=torch.tensor([1, 0, 0, 1, 0, 1]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4]),\n                indices=torch.tensor([1, 2, 1, 0]),\n            ),\n        },\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2]),\n                indices=torch.tensor([1, 0]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2]),\n                indices=torch.tensor([1, 2], dtype=torch.int64),\n            ),\n        },\n    ]\n    original_column_node_ids = [\n        {\n            \"n1\": torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1, 2]),\n        },\n        {\n            \"n1\": torch.tensor([0]),\n            \"n2\": torch.tensor([1]),\n        },\n    ]\n    original_row_node_ids = [\n        {\n            \"n1\": torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1, 2]),\n  
      },\n        {\n            \"n1\": torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1, 2]),\n        },\n    ]\n\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            for ntype in [\"n1\", \"n2\"]:\n                assert torch.equal(\n                    torch.sort(sampled_subgraph.original_row_node_ids[ntype])[\n                        0\n                    ],\n                    original_row_node_ids[step][ntype].to(F.ctx()),\n                )\n                assert torch.equal(\n                    torch.sort(\n                        sampled_subgraph.original_column_node_ids[ntype]\n                    )[0],\n                    original_column_node_ids[step][ntype].to(F.ctx()),\n                )\n            for etype in [\"n1:e1:n2\", \"n2:e2:n1\"]:\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indices,\n                    csc_formats[step][etype].indices.to(F.ctx()),\n                )\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indptr,\n                    csc_formats[step][etype].indptr.to(F.ctx()),\n                )\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_without_deduplication_Homo_HyperLink(sampler_type):\n    graph = dgl.graph(\n        ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])\n    )\n    graph = gb.from_dglgraph(graph, True).to(F.ctx())\n    items = torch.LongTensor([[0, 1, 4], [3, 5, 6]])\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.zeros(\n                graph.csc_indptr.numel() - 1, dtype=torch.int64\n            ).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            
\"timestamp\": torch.zeros(\n                graph.indices.numel(), dtype=torch.int64\n            ).to(F.ctx())\n        }\n        items = (items, torch.randint(1, 10, (2,)))\n        names = (names, \"timestamp\")\n\n    itemset = gb.ItemSet(items, names=names)\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n\n    sampler = _get_sampler(sampler_type)\n    if _is_temporal(sampler_type):\n        datapipe = sampler(item_sampler, graph, fanouts)\n    else:\n        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)\n\n    length = [23, 11]\n    compacted_indices = [\n        (torch.arange(0, 12) + 11).to(F.ctx()),\n        (torch.arange(0, 5) + 6).to(F.ctx()),\n    ]\n    indptr = [\n        torch.tensor([0, 1, 2, 4, 5, 5, 5, 5, 6, 8, 10, 12]).to(F.ctx()),\n        torch.tensor([0, 1, 2, 4, 5, 5, 5]).to(F.ctx()),\n    ]\n    seeds = [\n        torch.tensor([0, 0, 1, 2, 2, 3, 4, 4, 5, 5, 6]).to(F.ctx()),\n        torch.tensor([0, 1, 3, 4, 5, 6]).to(F.ctx()),\n    ]\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            assert len(sampled_subgraph.original_row_node_ids) == length[step]\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indices, compacted_indices[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indptr, indptr[step]\n            )\n            assert torch.equal(\n                torch.sort(sampled_subgraph.original_column_node_ids)[0],\n                seeds[step],\n            )\n\n\n@pytest.mark.parametrize(\n    \"sampler_type\",\n    [\n        SamplerType.Normal,\n        SamplerType.Layer,\n        SamplerType.Temporal,\n        SamplerType.TemporalLayer,\n    ],\n)\ndef test_SubgraphSampler_without_deduplication_Hetero_HyperLink(sampler_type):\n    graph = 
get_hetero_graph().to(F.ctx())\n    items = torch.arange(3).view(1, 3)\n    names = \"seeds\"\n    if _is_temporal(sampler_type):\n        graph.node_attributes = {\n            \"timestamp\": torch.zeros(\n                graph.csc_indptr.numel() - 1, dtype=torch.int64\n            ).to(F.ctx())\n        }\n        graph.edge_attributes = {\n            \"timestamp\": torch.zeros(\n                graph.indices.numel(), dtype=torch.int64\n            ).to(F.ctx())\n        }\n        items = (items, torch.randint(1, 10, (1,)))\n        names = (names, \"timestamp\")\n    itemset = gb.HeteroItemSet({\"n2:n1:n2\": gb.ItemSet(items, names=names)})\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = _get_sampler(sampler_type)\n    if _is_temporal(sampler_type):\n        datapipe = sampler(item_sampler, graph, fanouts)\n    else:\n        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)\n    csc_formats = [\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4, 6, 8]),\n                indices=torch.tensor([5, 6, 7, 8, 9, 10, 11, 12]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4, 6, 8, 10]),\n                indices=torch.tensor([4, 5, 6, 7, 8, 9, 10, 11, 12, 13]),\n            ),\n        },\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4]),\n                indices=torch.tensor([1, 2, 3, 4]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2]),\n                indices=torch.tensor([2, 3], dtype=torch.int64),\n            ),\n        },\n    ]\n    original_column_node_ids = [\n        {\n            \"n1\": torch.tensor([1, 0, 1, 0, 1]),\n            \"n2\": torch.tensor([0, 2, 0, 1]),\n        },\n        {\n    
        \"n1\": torch.tensor([1]),\n            \"n2\": torch.tensor([0, 2]),\n        },\n    ]\n    original_row_node_ids = [\n        {\n            \"n1\": torch.tensor([1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0]),\n            \"n2\": torch.tensor([0, 2, 0, 1, 0, 1, 0, 2, 0, 1, 0, 2, 0, 1]),\n        },\n        {\n            \"n1\": torch.tensor([1, 0, 1, 0, 1]),\n            \"n2\": torch.tensor([0, 2, 0, 1]),\n        },\n    ]\n\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            for ntype in [\"n1\", \"n2\"]:\n                assert torch.equal(\n                    sampled_subgraph.original_row_node_ids[ntype],\n                    original_row_node_ids[step][ntype].to(F.ctx()),\n                )\n                assert torch.equal(\n                    sampled_subgraph.original_column_node_ids[ntype],\n                    original_column_node_ids[step][ntype].to(F.ctx()),\n                )\n            for etype in [\"n1:e1:n2\", \"n2:e2:n1\"]:\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indices,\n                    csc_formats[step][etype].indices.to(F.ctx()),\n                )\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indptr,\n                    csc_formats[step][etype].indptr.to(F.ctx()),\n                )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Fails due to different result on the GPU.\",\n)\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_SubgraphSampler_unique_csc_format_Homo_HyperLink_cpu(labor):\n    torch.manual_seed(1205)\n    graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))\n    graph = gb.from_dglgraph(graph, True).to(F.ctx())\n    seed_nodes = torch.LongTensor([[0, 3, 3], [4, 4, 4]])\n\n    itemset = gb.ItemSet(seed_nodes, names=\"seeds\")\n    item_sampler = gb.ItemSampler(itemset, 
batch_size=4).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n\n    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler\n    datapipe = Sampler(\n        item_sampler,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n\n    original_row_node_ids = [\n        torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),\n        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n    ]\n    compacted_indices = [\n        torch.tensor([3, 4, 4, 2, 5, 6]).to(F.ctx()),\n        torch.tensor([3, 4, 4, 2]).to(F.ctx()),\n    ]\n    indptr = [\n        torch.tensor([0, 1, 2, 4, 4, 6]).to(F.ctx()),\n        torch.tensor([0, 1, 2, 4]).to(F.ctx()),\n    ]\n    seeds = [\n        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n        torch.tensor([0, 3, 4]).to(F.ctx()),\n    ]\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            assert torch.equal(\n                sampled_subgraph.original_row_node_ids,\n                original_row_node_ids[step],\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indices, compacted_indices[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indptr, indptr[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.original_column_node_ids, seeds[step]\n            )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"Fails due to different result on the CPU.\",\n)\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_SubgraphSampler_unique_csc_format_Homo_HyperLink_gpu(labor):\n    torch.manual_seed(1205)\n    graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))\n    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())\n    seed_nodes = torch.LongTensor([[0, 3, 4], [4, 4, 3]])\n\n    itemset = gb.ItemSet(seed_nodes, names=\"seeds\")\n    item_sampler 
= gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]\n\n    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler\n    datapipe = Sampler(\n        item_sampler,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n\n    if torch.cuda.get_device_capability()[0] < 7:\n        original_row_node_ids = [\n            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),\n            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),\n        ]\n        compacted_indices = [\n            torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),\n            torch.tensor([4, 3, 2]).to(F.ctx()),\n        ]\n        indptr = [\n            torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),\n            torch.tensor([0, 1, 2, 3]).to(F.ctx()),\n        ]\n        seeds = [\n            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),\n            torch.tensor([0, 3, 4]).to(F.ctx()),\n        ]\n    else:\n        original_row_node_ids = [\n            torch.tensor([0, 3, 4, 5, 2, 7]).to(F.ctx()),\n            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n        ]\n        compacted_indices = [\n            torch.tensor([3, 4, 2, 5, 5]).to(F.ctx()),\n            torch.tensor([3, 4, 2]).to(F.ctx()),\n        ]\n        indptr = [\n            torch.tensor([0, 1, 2, 3, 3, 5]).to(F.ctx()),\n            torch.tensor([0, 1, 2, 3]).to(F.ctx()),\n        ]\n        seeds = [\n            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n            torch.tensor([0, 3, 4]).to(F.ctx()),\n        ]\n\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            assert torch.equal(\n                sampled_subgraph.original_row_node_ids,\n                original_row_node_ids[step],\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indices, compacted_indices[step]\n            )\n            assert torch.equal(\n             
   sampled_subgraph.sampled_csc.indptr, indptr[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.original_column_node_ids, seeds[step]\n            )\n\n\n@pytest.mark.parametrize(\"labor\", [False, True])\ndef test_SubgraphSampler_unique_csc_format_Hetero_HyperLink(labor):\n    graph = get_hetero_graph().to(F.ctx())\n    itemset = gb.HeteroItemSet(\n        {\"n1:n2:n1\": gb.ItemSet(torch.tensor([[0, 1, 0]]), names=\"seeds\")}\n    )\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler\n    datapipe = Sampler(\n        item_sampler,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n    csc_formats = [\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4, 6]),\n                indices=torch.tensor([1, 0, 0, 1, 0, 1]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4]),\n                indices=torch.tensor([1, 2, 1, 0]),\n            ),\n        },\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2]),\n                indices=torch.tensor([1, 0]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2]),\n                indices=torch.tensor([1, 2], dtype=torch.int64),\n            ),\n        },\n    ]\n    original_column_node_ids = [\n        {\n            \"n1\": torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1, 2]),\n        },\n        {\n            \"n1\": torch.tensor([0]),\n            \"n2\": torch.tensor([1]),\n        },\n    ]\n    original_row_node_ids = [\n        {\n            \"n1\": torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1, 2]),\n        },\n        {\n            \"n1\": 
torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1, 2]),\n        },\n    ]\n\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            for ntype in [\"n1\", \"n2\"]:\n                assert torch.equal(\n                    torch.sort(sampled_subgraph.original_row_node_ids[ntype])[\n                        0\n                    ],\n                    original_row_node_ids[step][ntype].to(F.ctx()),\n                )\n                assert torch.equal(\n                    torch.sort(\n                        sampled_subgraph.original_column_node_ids[ntype]\n                    )[0],\n                    original_column_node_ids[step][ntype].to(F.ctx()),\n                )\n            for etype in [\"n1:e1:n2\", \"n2:e2:n1\"]:\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indices,\n                    csc_formats[step][etype].indices.to(F.ctx()),\n                )\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indptr,\n                    csc_formats[step][etype].indptr.to(F.ctx()),\n                )\n"
  },
  {
    "path": "tests/python/pytorch/graphbolt/test_utils.py",
    "content": "import re\nimport unittest\n\nfrom functools import partial\n\nimport backend as F\n\nimport dgl\nimport dgl.graphbolt as gb\nimport pytest\nimport torch\n\n\ndef test_add_reverse_edges_homo():\n    edges = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).T\n    combined_edges = gb.add_reverse_edges(edges)\n    assert torch.equal(\n        combined_edges,\n        torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7], [4, 5, 6, 7, 0, 1, 2, 3]]).T,\n    )\n    # Tensor with uncorrect dimensions.\n    edges = torch.tensor([0, 1, 2, 3])\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\n            \"Only tensor with shape N*2 is supported now, but got torch.Size([4]).\"\n        ),\n    ):\n        gb.add_reverse_edges(edges)\n\n\ndef test_add_reverse_edges_hetero():\n    # reverse_etype doesn't exist in original etypes.\n    edges = {\"n1:e1:n2\": torch.tensor([[0, 1, 2], [4, 5, 6]]).T}\n    reverse_etype_mapping = {\"n1:e1:n2\": \"n2:e2:n1\"}\n    combined_edges = gb.add_reverse_edges(edges, reverse_etype_mapping)\n    assert torch.equal(\n        combined_edges[\"n1:e1:n2\"], torch.tensor([[0, 1, 2], [4, 5, 6]]).T\n    )\n    assert torch.equal(\n        combined_edges[\"n2:e2:n1\"], torch.tensor([[4, 5, 6], [0, 1, 2]]).T\n    )\n    # reverse_etype exists in original etypes.\n    edges = {\n        \"n1:e1:n2\": torch.tensor([[0, 1, 2], [4, 5, 6]]).T,\n        \"n2:e2:n1\": torch.tensor([[7, 8, 9], [10, 11, 12]]).T,\n    }\n    reverse_etype_mapping = {\"n1:e1:n2\": \"n2:e2:n1\"}\n    combined_edges = gb.add_reverse_edges(edges, reverse_etype_mapping)\n    assert torch.equal(\n        combined_edges[\"n1:e1:n2\"], torch.tensor([[0, 1, 2], [4, 5, 6]]).T\n    )\n    assert torch.equal(\n        combined_edges[\"n2:e2:n1\"],\n        torch.tensor([[7, 8, 9, 4, 5, 6], [10, 11, 12, 0, 1, 2]]).T,\n    )\n    # Tensor with uncorrect dimensions.\n    edges = {\n        \"n1:e1:n2\": torch.tensor([0, 1, 2]),\n        \"n2:e2:n1\": torch.tensor([7, 8, 
9]),\n    }\n    with pytest.raises(\n        AssertionError,\n        match=re.escape(\n            \"Only tensor with shape N*2 is supported now, but got torch.Size([3]).\"\n        ),\n    ):\n        gb.add_reverse_edges(edges, reverse_etype_mapping)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"gpu\",\n    reason=\"Fails due to different result on the GPU.\",\n)\n@pytest.mark.parametrize(\"use_datapipe\", [False, True])\ndef test_exclude_seed_edges_homo_cpu(use_datapipe):\n    graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))\n    graph = gb.from_dglgraph(graph, True).to(F.ctx())\n    items = torch.LongTensor([[0, 3], [4, 4]])\n    names = \"seeds\"\n    itemset = gb.ItemSet(items, names=names)\n    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    sampler = gb.NeighborSampler\n    datapipe = sampler(datapipe, graph, fanouts)\n    if use_datapipe:\n        datapipe = datapipe.exclude_seed_edges()\n    else:\n        datapipe = datapipe.transform(partial(gb.exclude_seed_edges))\n    original_row_node_ids = [\n        torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),\n        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n    ]\n    compacted_indices = [\n        torch.tensor([3, 4, 4, 5, 6]).to(F.ctx()),\n        torch.tensor([3, 4, 4]).to(F.ctx()),\n    ]\n    indptr = [\n        torch.tensor([0, 1, 2, 3, 3, 5]).to(F.ctx()),\n        torch.tensor([0, 1, 2, 3]).to(F.ctx()),\n    ]\n    seeds = [\n        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n        torch.tensor([0, 3, 4]).to(F.ctx()),\n    ]\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            assert torch.equal(\n                sampled_subgraph.original_row_node_ids,\n                original_row_node_ids[step],\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indices, 
compacted_indices[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indptr, indptr[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.original_column_node_ids, seeds[step]\n            )\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"Fails due to different result on the CPU.\",\n)\n@pytest.mark.parametrize(\"use_datapipe\", [False, True])\n@pytest.mark.parametrize(\"async_op\", [False, True])\ndef test_exclude_seed_edges_gpu(use_datapipe, async_op):\n    graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))\n    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())\n    items = torch.LongTensor([[0, 3], [4, 4]])\n    names = \"seeds\"\n    itemset = gb.ItemSet(items, names=names)\n    datapipe = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]\n    sampler = gb.NeighborSampler\n    datapipe = sampler(\n        datapipe,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n    if use_datapipe:\n        datapipe = datapipe.exclude_seed_edges(asynchronous=async_op)\n    else:\n        datapipe = datapipe.transform(\n            partial(gb.exclude_seed_edges, async_op=async_op)\n        )\n    if torch.cuda.get_device_capability()[0] < 7:\n        original_row_node_ids = [\n            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),\n            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),\n        ]\n        compacted_indices = [\n            torch.tensor([4, 3, 5, 5]).to(F.ctx()),\n            torch.tensor([4, 3]).to(F.ctx()),\n        ]\n        indptr = [\n            torch.tensor([0, 1, 2, 2, 5, 5]).to(F.ctx()),\n            torch.tensor([0, 1, 2, 2]).to(F.ctx()),\n        ]\n        seeds = [\n            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),\n            torch.tensor([0, 3, 4]).to(F.ctx()),\n        ]\n    else:\n        
original_row_node_ids = [\n            torch.tensor([0, 3, 4, 5, 2, 7]).to(F.ctx()),\n            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n        ]\n        compacted_indices = [\n            torch.tensor([3, 4, 5, 5]).to(F.ctx()),\n            torch.tensor([3, 4]).to(F.ctx()),\n        ]\n        indptr = [\n            torch.tensor([0, 1, 2, 2, 2, 4]).to(F.ctx()),\n            torch.tensor([0, 1, 2, 2]).to(F.ctx()),\n        ]\n        seeds = [\n            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),\n            torch.tensor([0, 3, 4]).to(F.ctx()),\n        ]\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n            if async_op and not use_datapipe:\n                sampled_subgraph = sampled_subgraph.wait()\n            assert torch.equal(\n                sampled_subgraph.original_row_node_ids,\n                original_row_node_ids[step],\n            )\n            assert torch.equal(\n                (sampled_subgraph.sampled_csc.indices), compacted_indices[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.sampled_csc.indptr, indptr[step]\n            )\n            assert torch.equal(\n                sampled_subgraph.original_column_node_ids, seeds[step]\n            )\n\n\ndef get_hetero_graph():\n    # COO graph:\n    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]\n    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]\n    # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.\n    # num_nodes = 5, num_n1 = 2, num_n2 = 3\n    ntypes = {\"n1\": 0, \"n2\": 1}\n    etypes = {\"n1:e1:n2\": 0, \"n2:e2:n1\": 1}\n    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])\n    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])\n    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])\n    node_type_offset = torch.LongTensor([0, 2, 5])\n    return gb.fused_csc_sampling_graph(\n        indptr,\n        indices,\n        node_type_offset=node_type_offset,\n        type_per_edge=type_per_edge,\n   
     node_type_to_id=ntypes,\n        edge_type_to_id=etypes,\n    )\n\n\ndef test_exclude_seed_edges_hetero():\n    graph = get_hetero_graph().to(F.ctx())\n    itemset = gb.HeteroItemSet(\n        {\"n1:e1:n2\": gb.ItemSet(torch.tensor([[0, 1]]), names=\"seeds\")}\n    )\n    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())\n    num_layer = 2\n    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]\n    Sampler = gb.NeighborSampler\n    datapipe = Sampler(\n        item_sampler,\n        graph,\n        fanouts,\n        deduplicate=True,\n    )\n    datapipe = datapipe.transform(partial(gb.exclude_seed_edges))\n    csc_formats = [\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 1, 3, 5]),\n                indices=torch.tensor([1, 0, 1, 0, 1]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2, 4]),\n                indices=torch.tensor([1, 2, 1, 0]),\n            ),\n        },\n        {\n            \"n1:e1:n2\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 1]),\n                indices=torch.tensor([1]),\n            ),\n            \"n2:e2:n1\": gb.CSCFormatBase(\n                indptr=torch.tensor([0, 2]),\n                indices=torch.tensor([1, 2], dtype=torch.int64),\n            ),\n        },\n    ]\n    original_column_node_ids = [\n        {\n            \"n1\": torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1, 2]),\n        },\n        {\n            \"n1\": torch.tensor([0]),\n            \"n2\": torch.tensor([1]),\n        },\n    ]\n    original_row_node_ids = [\n        {\n            \"n1\": torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1, 2]),\n        },\n        {\n            \"n1\": torch.tensor([0, 1]),\n            \"n2\": torch.tensor([0, 1, 2]),\n        },\n    ]\n    for data in datapipe:\n        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):\n 
           for ntype in [\"n1\", \"n2\"]:\n                assert torch.equal(\n                    torch.sort(sampled_subgraph.original_row_node_ids[ntype])[\n                        0\n                    ],\n                    original_row_node_ids[step][ntype].to(F.ctx()),\n                )\n                assert torch.equal(\n                    torch.sort(\n                        sampled_subgraph.original_column_node_ids[ntype]\n                    )[0],\n                    original_column_node_ids[step][ntype].to(F.ctx()),\n                )\n            for etype in [\"n1:e1:n2\", \"n2:e2:n1\"]:\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indices,\n                    csc_formats[step][etype].indices.to(F.ctx()),\n                )\n                assert torch.equal(\n                    sampled_subgraph.sampled_csc[etype].indptr,\n                    csc_formats[step][etype].indptr.to(F.ctx()),\n                )\n"
  },
  {
    "path": "tests/python/pytorch/ip_config.txt",
    "content": "0 127.0.0.1 40050\n1 127.0.0.1 40051\n2 127.0.0.1 40052\n3 127.0.0.1 40053"
  },
  {
    "path": "tests/python/pytorch/mpops/test_edgewise.py",
    "content": "import random\n\nimport backend as F\n\nimport dgl\nimport numpy as np\nimport pytest\nimport torch\nfrom utils import parametrize_idtype\n\nrandom.seed(42)\nnp.random.seed(42)\ndgl.seed(42)\ntorch.random.manual_seed(42)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"feat_size\", [(5,), ()])\ndef test_copy_u(idtype, feat_size):\n    ctx = F.ctx()\n    g = dgl.rand_graph(30, 100)\n    g = g.astype(idtype).to(ctx)\n    x = torch.randn(\n        (g.num_nodes(),) + feat_size, requires_grad=True, device=ctx\n    )\n\n    y = dgl.copy_u(g, x)\n    y.sum().backward()\n    x_grad = x.grad\n\n    x.grad.zero_()\n    u, v = g.edges()\n    y_true = x[u.long()]\n    y_true.sum().backward()\n    x_grad_true = x.grad\n\n    assert torch.allclose(y, y_true)\n    assert torch.allclose(x_grad, x_grad_true)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"feat_size\", [(5,), ()])\ndef test_copy_u_hetero(idtype, feat_size):\n    ctx = F.ctx()\n    hg = dgl.heterograph(\n        {\n            (\"user\", \"follow\", \"user\"): ([0, 1, 2], [2, 3, 4]),\n            (\"user\", \"like\", \"movie\"): ([3, 3, 1, 2], [0, 0, 1, 1]),\n        }\n    )\n\n    hg = hg.astype(idtype).to(ctx)\n    x = torch.randn(\n        (hg.num_nodes(\"user\"),) + feat_size, requires_grad=True, device=ctx\n    )\n\n    y = dgl.copy_u(hg, x, etype=\"like\")\n    y.sum().backward()\n    x_grad = x.grad\n\n    x.grad.zero_()\n    u, v = hg.edges(etype=\"like\")\n    y_true = x[u.long()]\n    y_true.sum().backward()\n    x_grad_true = x.grad\n\n    assert torch.allclose(y, y_true)\n    assert torch.allclose(x_grad, x_grad_true)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"feat_size\", [(5,), ()])\ndef test_copy_v(idtype, feat_size):\n    ctx = F.ctx()\n    g = dgl.rand_graph(30, 100)\n    g = g.astype(idtype).to(ctx)\n    x = torch.randn(\n        (g.num_nodes(),) + feat_size, requires_grad=True, device=ctx\n    )\n\n    y = dgl.copy_v(g, x)\n    y.sum().backward()\n    x_grad = 
x.grad\n\n    x.grad.zero_()\n    u, v = g.edges()\n    y_true = x[v.long()]\n    y_true.sum().backward()\n    x_grad_true = x.grad\n\n    assert torch.allclose(y, y_true)\n    assert torch.allclose(x_grad, x_grad_true)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"feat_size\", [(5,), ()])\ndef test_copy_v_hetero(idtype, feat_size):\n    ctx = F.ctx()\n    hg = dgl.heterograph(\n        {\n            (\"user\", \"follow\", \"user\"): ([0, 1, 2], [2, 3, 4]),\n            (\"user\", \"like\", \"movie\"): ([3, 3, 1, 2], [0, 0, 1, 1]),\n        }\n    )\n\n    hg = hg.astype(idtype).to(ctx)\n    x = torch.randn(\n        (hg.num_nodes(\"movie\"),) + feat_size, requires_grad=True, device=ctx\n    )\n\n    y = dgl.copy_v(hg, x, etype=\"like\")\n    y.sum().backward()\n    x_grad = x.grad\n\n    x.grad.zero_()\n    u, v = hg.edges(etype=\"like\")\n    y_true = x[v.long()]\n    y_true.sum().backward()\n    x_grad_true = x.grad\n\n    assert torch.allclose(y, y_true)\n    assert torch.allclose(x_grad, x_grad_true)\n\n\nbinary_arg_sizes = [\n    ((5,), (5,)),\n    ((5,), ()),\n    ((), (5,)),\n    ((1, 3, 3), (4, 1, 3)),\n    ((3, 3), (4, 1, 3)),\n    ((4, 1, 3), (3, 3)),\n]\n\ndot_arg_sizes = [\n    ((5,), (5,)),\n    ((1, 3, 3), (4, 1, 3)),\n    ((3, 3), (4, 1, 3)),\n    ((4, 1, 3), (3, 3)),\n]\n\nops = [\"add\", \"sub\", \"mul\", \"div\"]\n\n\ndef pad_shape(x, y, x_size, y_size):\n    xy_size = torch.broadcast_shapes(x_size, y_size)\n    new_x_size = (1,) * (len(xy_size) - len(x_size)) + x_size\n    new_y_size = (1,) * (len(xy_size) - len(y_size)) + y_size\n    new_x = x.view(-1, *new_x_size)\n    new_y = y.view(-1, *new_y_size)\n    return new_x, new_y\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"op\", ops)\n@pytest.mark.parametrize(\"x_size,y_size\", binary_arg_sizes)\ndef test_u_op_v(idtype, op, x_size, y_size):\n    ctx = F.ctx()\n    g = dgl.rand_graph(30, 100)\n    g = g.astype(idtype).to(ctx)\n    x = torch.randn((g.num_nodes(),) + x_size, 
requires_grad=True, device=ctx)\n    y = torch.randn((g.num_nodes(),) + y_size, requires_grad=True, device=ctx)\n\n    f_dgl = getattr(dgl, f\"u_{op}_v\")\n    z = f_dgl(g, x, y)\n    z.sum().backward()\n    x_grad = x.grad\n    y_grad = y.grad\n\n    x_grad.zero_()\n    y_grad.zero_()\n    u, v = g.edges()\n    f_torch = getattr(torch, op)\n    x_u, y_v = pad_shape(x[u.long()], y[v.long()], x_size, y_size)\n    z_true = f_torch(x_u, y_v)\n    z_true.sum().backward()\n    x_grad_true = x.grad\n    y_grad_true = y.grad\n\n    assert torch.allclose(z, z_true)\n    assert torch.allclose(x_grad, x_grad_true)\n    assert torch.allclose(y_grad, y_grad_true)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"x_size,y_size\", dot_arg_sizes)\ndef test_u_dot_v(idtype, x_size, y_size):\n    ctx = F.ctx()\n    g = dgl.rand_graph(30, 100)\n    g = g.astype(idtype).to(ctx)\n    x = torch.randn((g.num_nodes(),) + x_size, requires_grad=True, device=ctx)\n    y = torch.randn((g.num_nodes(),) + y_size, requires_grad=True, device=ctx)\n\n    z = dgl.u_dot_v(g, x, y)\n    z.sum().backward()\n    x_grad = x.grad\n    y_grad = y.grad\n\n    x_grad.zero_()\n    y_grad.zero_()\n    u, v = g.edges()\n    x_u, y_v = pad_shape(x[u.long()], y[v.long()], x_size, y_size)\n    z_true = (x_u * y_v).sum(-1).unsqueeze(-1)\n    z_true.sum().backward()\n    x_grad_true = x.grad\n    y_grad_true = y.grad\n\n    assert torch.allclose(z, z_true, atol=1e-4, rtol=1e-4)\n    assert torch.allclose(x_grad, x_grad_true)\n    assert torch.allclose(y_grad, y_grad_true)\n"
  },
  {
    "path": "tests/python/pytorch/nn/conv/test_gatedgcnconv.py",
    "content": "import io\n\nimport backend as F\n\nimport dgl.nn.pytorch as nn\nimport pytest\nfrom utils import parametrize_idtype\nfrom utils.graph_cases import get_cases\n\ntmp_buffer = io.BytesIO()\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\ndef test_gatedgcn_conv(g, idtype):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    gatedgcnconv = nn.GatedGCNConv(10, 10, 5)\n    feat = F.randn((g.num_nodes(), 10))\n    efeat = F.randn((g.num_edges(), 10))\n    gatedgcnconv = gatedgcnconv.to(ctx)\n\n    h, edge_h = gatedgcnconv(g, feat, efeat)\n    # current we only do shape check\n    assert h.shape == (g.number_of_dst_nodes(), 5)\n    assert edge_h.shape == (g.number_of_edges(), 5)\n"
  },
  {
    "path": "tests/python/pytorch/nn/test_nn.py",
    "content": "import io\nimport pickle\nimport random\nimport re\nfrom copy import deepcopy\n\nimport backend as F\n\nimport dgl\nimport dgl.function as fn\nimport dgl.nn.pytorch as nn\nimport networkx as nx\nimport numpy as np  # For setting seed for scipy\nimport pytest\nimport scipy as sp\nimport torch\nimport torch as th\nfrom dgl import shortest_dist\nfrom torch.nn.utils.rnn import pad_sequence\nfrom torch.optim import Adam, SparseAdam\nfrom torch.utils.data import DataLoader\nfrom utils import parametrize_idtype\nfrom utils.graph_cases import (\n    get_cases,\n    random_bipartite,\n    random_dglgraph,\n    random_graph,\n)\n\n# Set seeds to make tests fully reproducible.\nSEED = 12345  # random.randint(1, 99999)\nrandom.seed(SEED)  # For networkx\nnp.random.seed(SEED)  # For scipy\ndgl.seed(SEED)\nF.seed(SEED)\n\ntmp_buffer = io.BytesIO()\n\n\ndef _AXWb(A, X, W, b):\n    X = th.matmul(X, W)\n    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)\n    return Y + b\n\n\ndef graph_with_nodes(num_nodes, ctx=None):\n    g = dgl.from_networkx(nx.path_graph(num_nodes))\n    return g.to(ctx) if ctx else g\n\n\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_graph_conv0(out_dim):\n    ctx = F.ctx()\n    g = graph_with_nodes(3, ctx)\n    adj = g.adj_external(transpose=True, ctx=ctx)\n\n    conv = nn.GraphConv(5, out_dim, norm=\"none\", bias=True)\n    conv = conv.to(ctx)\n    print(conv)\n\n    # test pickle\n    th.save(conv, tmp_buffer)\n\n    # test#1: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))\n    # test#2: more-dim\n    h0 = F.ones((3, 5, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))\n\n    conv = nn.GraphConv(5, out_dim)\n    conv = conv.to(ctx)\n    # test#3: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, 
h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    # test#4: basic\n    h0 = F.ones((3, 5, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n\n    conv = nn.GraphConv(5, out_dim)\n    conv = conv.to(ctx)\n    # test#3: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    # test#4: basic\n    h0 = F.ones((3, 5, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n\n    # test rest_parameters\n    old_weight = deepcopy(conv.weight.data)\n    conv.reset_parameters()\n    new_weight = conv.weight.data\n    assert not F.allclose(old_weight, new_weight)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"bipartite\"], exclude=[\"zero-degree\", \"dglgraph\"])\n)\n@pytest.mark.parametrize(\"norm\", [\"none\", \"both\", \"right\", \"left\"])\n@pytest.mark.parametrize(\"weight\", [True, False])\n@pytest.mark.parametrize(\"bias\", [True, False])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_graph_conv(idtype, g, norm, weight, bias, out_dim):\n    # Test one tensor input\n    g = g.astype(idtype).to(F.ctx())\n    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(\n        F.ctx()\n    )\n    ext_w = F.randn((5, out_dim)).to(F.ctx())\n    nsrc = g.number_of_src_nodes()\n    ndst = g.number_of_dst_nodes()\n    h = F.randn((nsrc, 5)).to(F.ctx())\n    if weight:\n        h_out = conv(g, h)\n    else:\n        h_out = conv(g, h, weight=ext_w)\n    assert h_out.shape == (ndst, out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\",\n    get_cases([\"has_scalar_e_feature\"], exclude=[\"zero-degree\", \"dglgraph\"]),\n)\n@pytest.mark.parametrize(\"norm\", [\"none\", \"both\", \"right\"])\n@pytest.mark.parametrize(\"weight\", [True, False])\n@pytest.mark.parametrize(\"bias\", [True, False])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef 
test_graph_conv_e_weight(idtype, g, norm, weight, bias, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(\n        F.ctx()\n    )\n    ext_w = F.randn((5, out_dim)).to(F.ctx())\n    nsrc = g.number_of_src_nodes()\n    ndst = g.number_of_dst_nodes()\n    h = F.randn((nsrc, 5)).to(F.ctx())\n    e_w = g.edata[\"scalar_w\"]\n    if weight:\n        h_out = conv(g, h, edge_weight=e_w)\n    else:\n        h_out = conv(g, h, weight=ext_w, edge_weight=e_w)\n    assert h_out.shape == (ndst, out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\",\n    get_cases([\"has_scalar_e_feature\"], exclude=[\"zero-degree\", \"dglgraph\"]),\n)\n@pytest.mark.parametrize(\"norm\", [\"none\", \"both\", \"right\"])\n@pytest.mark.parametrize(\"weight\", [True, False])\n@pytest.mark.parametrize(\"bias\", [True, False])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(\n        F.ctx()\n    )\n\n    # test pickle\n    th.save(conv, tmp_buffer)\n\n    ext_w = F.randn((5, out_dim)).to(F.ctx())\n    nsrc = g.number_of_src_nodes()\n    ndst = g.number_of_dst_nodes()\n    h = F.randn((nsrc, 5)).to(F.ctx())\n    edgenorm = nn.EdgeWeightNorm(norm=norm)\n    norm_weight = edgenorm(g, g.edata[\"scalar_w\"])\n    if weight:\n        h_out = conv(g, h, edge_weight=norm_weight)\n    else:\n        h_out = conv(g, h, weight=ext_w, edge_weight=norm_weight)\n    assert h_out.shape == (ndst, out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\", \"dglgraph\"])\n)\n@pytest.mark.parametrize(\"norm\", [\"none\", \"both\", \"right\"])\n@pytest.mark.parametrize(\"weight\", [True, False])\n@pytest.mark.parametrize(\"bias\", [True, 
False])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_graph_conv_bi(idtype, g, norm, weight, bias, out_dim):\n    # Test a pair of tensor inputs\n    g = g.astype(idtype).to(F.ctx())\n    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias).to(\n        F.ctx()\n    )\n\n    # test pickle\n    th.save(conv, tmp_buffer)\n\n    ext_w = F.randn((5, out_dim)).to(F.ctx())\n    nsrc = g.number_of_src_nodes()\n    ndst = g.number_of_dst_nodes()\n    h = F.randn((nsrc, 5)).to(F.ctx())\n    h_dst = F.randn((ndst, out_dim)).to(F.ctx())\n    if weight:\n        h_out = conv(g, (h, h_dst))\n    else:\n        h_out = conv(g, (h, h_dst), weight=ext_w)\n    assert h_out.shape == (ndst, out_dim)\n\n\ndef _S2AXWb(A, N, X, W, b):\n    X1 = X * N\n    X1 = th.matmul(A, X1.view(X1.shape[0], -1))\n    X1 = X1 * N\n    X2 = X1 * N\n    X2 = th.matmul(A, X2.view(X2.shape[0], -1))\n    X2 = X2 * N\n    X = th.cat([X, X1, X2], dim=-1)\n    Y = th.matmul(X, W.rot90())\n\n    return Y + b\n\n\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_tagconv(out_dim):\n    ctx = F.ctx()\n    g = graph_with_nodes(3, ctx)\n    adj = g.adj_external(transpose=True, ctx=ctx)\n    norm = th.pow(g.in_degrees().float(), -0.5)\n\n    conv = nn.TAGConv(5, out_dim, bias=True)\n    conv = conv.to(ctx)\n    print(conv)\n\n    # test pickle\n    th.save(conv, tmp_buffer)\n\n    # test#1: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    shp = norm.shape + (1,) * (h0.dim() - 1)\n    norm = th.reshape(norm, shp).to(ctx)\n\n    assert F.allclose(\n        h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias)\n    )\n\n    conv = nn.TAGConv(5, out_dim)\n    conv = conv.to(ctx)\n\n    # test#2: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, h0)\n    assert h1.shape[-1] == out_dim\n\n    # test reset_parameters\n    old_weight = deepcopy(conv.lin.weight.data)\n    conv.reset_parameters()\n    new_weight = 
conv.lin.weight.data\n    assert not F.allclose(old_weight, new_weight)\n\n\ndef test_set2set():\n    ctx = F.ctx()\n    g = graph_with_nodes(10, ctx)\n\n    s2s = nn.Set2Set(5, 3, 3)  # hidden size 5, 3 iters, 3 layers\n    s2s = s2s.to(ctx)\n    print(s2s)\n\n    # test#1: basic\n    h0 = F.randn((g.num_nodes(), 5))\n    h1 = s2s(g, h0)\n    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2\n\n    # test#2: batched graph\n    g1 = graph_with_nodes(11, ctx)\n    g2 = graph_with_nodes(5, ctx)\n    bg = dgl.batch([g, g1, g2])\n    h0 = F.randn((bg.num_nodes(), 5))\n    h1 = s2s(bg, h0)\n    assert h1.shape[0] == 3 and h1.shape[1] == 10 and h1.dim() == 2\n\n\ndef test_glob_att_pool():\n    ctx = F.ctx()\n    g = graph_with_nodes(10, ctx)\n\n    gap = nn.GlobalAttentionPooling(th.nn.Linear(5, 1), th.nn.Linear(5, 10))\n    gap = gap.to(ctx)\n    print(gap)\n\n    # test pickle\n    th.save(gap, tmp_buffer)\n\n    # test#1: basic\n    h0 = F.randn((g.num_nodes(), 5))\n    h1 = gap(g, h0)\n    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.dim() == 2\n\n    # test#2: batched graph\n    bg = dgl.batch([g, g, g, g])\n    h0 = F.randn((bg.num_nodes(), 5))\n    h1 = gap(bg, h0)\n    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.dim() == 2\n\n\ndef test_simple_pool():\n    ctx = F.ctx()\n    g = graph_with_nodes(15, ctx)\n\n    sum_pool = nn.SumPooling()\n    avg_pool = nn.AvgPooling()\n    max_pool = nn.MaxPooling()\n    sort_pool = nn.SortPooling(10)  # k = 10\n    print(sum_pool, avg_pool, max_pool, sort_pool)\n\n    # test#1: basic\n    h0 = F.randn((g.num_nodes(), 5))\n    sum_pool = sum_pool.to(ctx)\n    avg_pool = avg_pool.to(ctx)\n    max_pool = max_pool.to(ctx)\n    sort_pool = sort_pool.to(ctx)\n    h1 = sum_pool(g, h0)\n    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))\n    h1 = avg_pool(g, h0)\n    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))\n    h1 = max_pool(g, h0)\n    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))\n    
h1 = sort_pool(g, h0)\n    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2\n\n    # test#2: batched graph\n    g_ = graph_with_nodes(5, ctx)\n    bg = dgl.batch([g, g_, g, g_, g])\n    h0 = F.randn((bg.num_nodes(), 5))\n    h1 = sum_pool(bg, h0)\n    truth = th.stack(\n        [\n            F.sum(h0[:15], 0),\n            F.sum(h0[15:20], 0),\n            F.sum(h0[20:35], 0),\n            F.sum(h0[35:40], 0),\n            F.sum(h0[40:55], 0),\n        ],\n        0,\n    )\n    assert F.allclose(h1, truth)\n\n    h1 = avg_pool(bg, h0)\n    truth = th.stack(\n        [\n            F.mean(h0[:15], 0),\n            F.mean(h0[15:20], 0),\n            F.mean(h0[20:35], 0),\n            F.mean(h0[35:40], 0),\n            F.mean(h0[40:55], 0),\n        ],\n        0,\n    )\n    assert F.allclose(h1, truth)\n\n    h1 = max_pool(bg, h0)\n    truth = th.stack(\n        [\n            F.max(h0[:15], 0),\n            F.max(h0[15:20], 0),\n            F.max(h0[20:35], 0),\n            F.max(h0[35:40], 0),\n            F.max(h0[40:55], 0),\n        ],\n        0,\n    )\n    assert F.allclose(h1, truth)\n\n    h1 = sort_pool(bg, h0)\n    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2\n\n\ndef test_set_trans():\n    ctx = F.ctx()\n    g = graph_with_nodes(15)\n\n    st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, \"sab\")\n    st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, \"isab\", 3)\n    st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)\n    st_enc_0 = st_enc_0.to(ctx)\n    st_enc_1 = st_enc_1.to(ctx)\n    st_dec = st_dec.to(ctx)\n    print(st_enc_0, st_enc_1, st_dec)\n\n    # test#1: basic\n    h0 = F.randn((g.num_nodes(), 50))\n    h1 = st_enc_0(g, h0)\n    assert h1.shape == h0.shape\n    h1 = st_enc_1(g, h0)\n    assert h1.shape == h0.shape\n    h2 = st_dec(g, h1)\n    assert h2.shape[0] == 1 and h2.shape[1] == 200 and h2.dim() == 2\n\n    # test#2: batched graph\n    g1 = graph_with_nodes(5)\n    g2 = 
graph_with_nodes(10)\n    bg = dgl.batch([g, g1, g2])\n    h0 = F.randn((bg.num_nodes(), 50))\n    h1 = st_enc_0(bg, h0)\n    assert h1.shape == h0.shape\n    h1 = st_enc_1(bg, h0)\n    assert h1.shape == h0.shape\n\n    h2 = st_dec(bg, h1)\n    assert h2.shape[0] == 3 and h2.shape[1] == 200 and h2.dim() == 2\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"O\", [1, 8, 32])\ndef test_rgcn(idtype, O):\n    ctx = F.ctx()\n    etype = []\n    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))\n    g = g.astype(idtype).to(F.ctx())\n    # 5 etypes\n    R = 5\n    for i in range(g.num_edges()):\n        etype.append(i % 5)\n    B = 2\n    I = 10\n\n    h = th.randn((100, I)).to(ctx)\n    r = th.tensor(etype).to(ctx)\n    norm = th.rand((g.num_edges(), 1)).to(ctx)\n    sorted_r, idx = th.sort(r)\n    sorted_g = dgl.reorder_graph(\n        g,\n        edge_permute_algo=\"custom\",\n        permute_config={\"edges_perm\": idx.to(idtype)},\n    )\n    sorted_norm = norm[idx]\n\n    rgc = nn.RelGraphConv(I, O, R).to(ctx)\n    th.save(rgc, tmp_buffer)  # test pickle\n    rgc_basis = nn.RelGraphConv(I, O, R, \"basis\", B).to(ctx)\n    th.save(rgc_basis, tmp_buffer)  # test pickle\n    if O % B == 0:\n        rgc_bdd = nn.RelGraphConv(I, O, R, \"bdd\", B).to(ctx)\n        th.save(rgc_bdd, tmp_buffer)  # test pickle\n\n    # basic usage\n    h_new = rgc(g, h, r)\n    assert h_new.shape == (100, O)\n    h_new_basis = rgc_basis(g, h, r)\n    assert h_new_basis.shape == (100, O)\n    if O % B == 0:\n        h_new_bdd = rgc_bdd(g, h, r)\n        assert h_new_bdd.shape == (100, O)\n\n    # sorted input\n    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)\n    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)\n    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)\n    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)\n    if O % B == 0:\n        h_new_bdd_sorted = rgc_bdd(sorted_g, h, sorted_r, 
presorted=True)\n        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)\n\n    # norm input\n    h_new = rgc(g, h, r, norm)\n    assert h_new.shape == (100, O)\n    h_new = rgc_basis(g, h, r, norm)\n    assert h_new.shape == (100, O)\n    if O % B == 0:\n        h_new = rgc_bdd(g, h, r, norm)\n        assert h_new.shape == (100, O)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"O\", [1, 10, 40])\ndef test_rgcn_default_nbasis(idtype, O):\n    ctx = F.ctx()\n    etype = []\n    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))\n    g = g.astype(idtype).to(F.ctx())\n    # 5 etypes\n    R = 5\n    for i in range(g.num_edges()):\n        etype.append(i % 5)\n    I = 10\n\n    h = th.randn((100, I)).to(ctx)\n    r = th.tensor(etype).to(ctx)\n    norm = th.rand((g.num_edges(), 1)).to(ctx)\n    sorted_r, idx = th.sort(r)\n    sorted_g = dgl.reorder_graph(\n        g,\n        edge_permute_algo=\"custom\",\n        permute_config={\"edges_perm\": idx.to(idtype)},\n    )\n    sorted_norm = norm[idx]\n\n    rgc = nn.RelGraphConv(I, O, R).to(ctx)\n    th.save(rgc, tmp_buffer)  # test pickle\n    rgc_basis = nn.RelGraphConv(I, O, R, \"basis\").to(ctx)\n    th.save(rgc_basis, tmp_buffer)  # test pickle\n    if O % R == 0:\n        rgc_bdd = nn.RelGraphConv(I, O, R, \"bdd\").to(ctx)\n        th.save(rgc_bdd, tmp_buffer)  # test pickle\n\n    # basic usage\n    h_new = rgc(g, h, r)\n    assert h_new.shape == (100, O)\n    h_new_basis = rgc_basis(g, h, r)\n    assert h_new_basis.shape == (100, O)\n    if O % R == 0:\n        h_new_bdd = rgc_bdd(g, h, r)\n        assert h_new_bdd.shape == (100, O)\n\n    # sorted input\n    h_new_sorted = rgc(sorted_g, h, sorted_r, presorted=True)\n    assert th.allclose(h_new, h_new_sorted, atol=1e-4, rtol=1e-4)\n    h_new_basis_sorted = rgc_basis(sorted_g, h, sorted_r, presorted=True)\n    assert th.allclose(h_new_basis, h_new_basis_sorted, atol=1e-4, rtol=1e-4)\n    if O % R == 0:\n        h_new_bdd_sorted = 
rgc_bdd(sorted_g, h, sorted_r, presorted=True)\n        assert th.allclose(h_new_bdd, h_new_bdd_sorted, atol=1e-4, rtol=1e-4)\n\n    # norm input\n    h_new = rgc(g, h, r, norm)\n    assert h_new.shape == (100, O)\n    h_new = rgc_basis(g, h, r, norm)\n    assert h_new.shape == (100, O)\n    if O % R == 0:\n        h_new = rgc_bdd(g, h, r, norm)\n        assert h_new.shape == (100, O)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 5])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_gat_conv(g, idtype, out_dim, num_heads):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    gat = nn.GATConv(5, out_dim, num_heads)\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    gat = gat.to(ctx)\n    h = gat(g, feat)\n\n    # test pickle\n    th.save(gat, tmp_buffer)\n\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)\n    _, a = gat(g, feat, get_attention=True)\n    assert a.shape == (g.num_edges(), num_heads, 1)\n\n    # test residual connection\n    gat = nn.GATConv(5, out_dim, num_heads, residual=True)\n    gat = gat.to(ctx)\n    h = gat(g, feat)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_gat_conv_bi(g, idtype, out_dim, num_heads):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    gat = nn.GATConv(5, out_dim, num_heads)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 5)),\n        F.randn((g.number_of_dst_nodes(), 5)),\n    )\n    gat = gat.to(ctx)\n    h = gat(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)\n    _, a = gat(g, feat, get_attention=True)\n    assert a.shape == (g.num_edges(), num_heads, 1)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], 
exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_gat_conv_edge_weight(g, idtype, out_dim, num_heads):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    gat = nn.GATConv(5, out_dim, num_heads)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 5)),\n        F.randn((g.number_of_dst_nodes(), 5)),\n    )\n    gat = gat.to(ctx)\n    ew = F.randn((g.num_edges(),))\n    h = gat(g, feat, edge_weight=ew)\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)\n    _, a = gat(g, feat, get_attention=True)\n    assert a.shape[0] == ew.shape[0]\n    assert a.shape == (g.num_edges(), num_heads, 1)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 5])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_gatv2_conv(g, idtype, out_dim, num_heads):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gat = nn.GATv2Conv(5, out_dim, num_heads)\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    gat = gat.to(ctx)\n    h = gat(g, feat)\n\n    # test pickle\n    th.save(gat, tmp_buffer)\n\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)\n    _, a = gat(g, feat, get_attention=True)\n    assert a.shape == (g.num_edges(), num_heads, 1)\n\n    # test residual connection\n    gat = nn.GATConv(5, out_dim, num_heads, residual=True)\n    gat = gat.to(ctx)\n    h = gat(g, feat)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_gatv2_conv_bi(g, idtype, out_dim, num_heads):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gat = nn.GATv2Conv(5, out_dim, num_heads)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 5)),\n        
F.randn((g.number_of_dst_nodes(), 5)),\n    )\n    gat = gat.to(ctx)\n    h = gat(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)\n    _, a = gat(g, feat, get_attention=True)\n    assert a.shape == (g.num_edges(), num_heads, 1)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_node_feats\", [1, 5])\n@pytest.mark.parametrize(\"out_edge_feats\", [1, 5])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    egat = nn.EGATConv(\n        in_node_feats=10,\n        in_edge_feats=5,\n        out_node_feats=out_node_feats,\n        out_edge_feats=out_edge_feats,\n        num_heads=num_heads,\n    )\n    nfeat = F.randn((g.num_nodes(), 10))\n    efeat = F.randn((g.num_edges(), 5))\n    egat = egat.to(ctx)\n    h, f = egat(g, nfeat, efeat)\n\n    th.save(egat, tmp_buffer)\n\n    assert h.shape == (g.num_nodes(), num_heads, out_node_feats)\n    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)\n    _, _, attn = egat(g, nfeat, efeat, get_attention=True)\n    assert attn.shape == (g.num_edges(), num_heads, 1)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_node_feats\", [1, 5])\n@pytest.mark.parametrize(\"out_edge_feats\", [1, 5])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    egat = nn.EGATConv(\n        in_node_feats=(10, 15),\n        in_edge_feats=7,\n        out_node_feats=out_node_feats,\n        out_edge_feats=out_edge_feats,\n        num_heads=num_heads,\n    )\n    nfeat = (\n        F.randn((g.number_of_src_nodes(), 10)),\n        F.randn((g.number_of_dst_nodes(), 15)),\n    )\n    
efeat = F.randn((g.num_edges(), 7))\n    egat = egat.to(ctx)\n    h, f = egat(g, nfeat, efeat)\n\n    th.save(egat, tmp_buffer)\n\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)\n    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)\n    _, _, attn = egat(g, nfeat, efeat, get_attention=True)\n    assert attn.shape == (g.num_edges(), num_heads, 1)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_node_feats\", [1, 5])\n@pytest.mark.parametrize(\"out_edge_feats\", [1, 5])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_egat_conv_edge_weight(\n    g, idtype, out_node_feats, out_edge_feats, num_heads\n):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    egat = nn.EGATConv(\n        in_node_feats=10,\n        in_edge_feats=5,\n        out_node_feats=out_node_feats,\n        out_edge_feats=out_edge_feats,\n        num_heads=num_heads,\n    )\n    egat = egat.to(ctx)\n    nfeat = F.randn((g.num_nodes(), 10))\n    efeat = F.randn((g.num_edges(), 5))\n    ew = F.randn((g.num_edges(),))\n\n    h, f, attn = egat(g, nfeat, efeat, edge_weight=ew, get_attention=True)\n\n    assert h.shape == (g.num_nodes(), num_heads, out_node_feats)\n    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)\n    assert attn.shape == (g.num_edges(), num_heads, 1)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_feats\", [1, 5])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_edgegat_conv(g, idtype, out_feats, num_heads):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    edgegat = nn.EdgeGATConv(\n        in_feats=10, edge_feats=5, out_feats=out_feats, num_heads=num_heads\n    )\n    nfeat = F.randn((g.number_of_nodes(), 10))\n    efeat = F.randn((g.number_of_edges(), 5))\n    edgegat = edgegat.to(ctx)\n    h = edgegat(g, nfeat, 
efeat)\n\n    th.save(edgegat, tmp_buffer)\n\n    assert h.shape == (g.number_of_nodes(), num_heads, out_feats)\n    _, attn = edgegat(g, nfeat, efeat, True)\n    assert attn.shape == (g.number_of_edges(), num_heads, 1)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_feats\", [1, 5])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_edgegat_conv_bi(g, idtype, out_feats, num_heads):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    edgegat = nn.EdgeGATConv(\n        in_feats=(10, 15),\n        edge_feats=7,\n        out_feats=out_feats,\n        num_heads=num_heads,\n    )\n    nfeat = (\n        F.randn((g.number_of_src_nodes(), 10)),\n        F.randn((g.number_of_dst_nodes(), 15)),\n    )\n    efeat = F.randn((g.number_of_edges(), 7))\n    edgegat = edgegat.to(ctx)\n    h = edgegat(g, nfeat, efeat)\n\n    th.save(edgegat, tmp_buffer)\n\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_feats)\n    _, attn = edgegat(g, nfeat, efeat, True)\n    assert attn.shape == (g.number_of_edges(), num_heads, 1)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\", \"block-bipartite\"]))\n@pytest.mark.parametrize(\"aggre_type\", [\"mean\", \"pool\", \"gcn\", \"lstm\"])\ndef test_sage_conv(idtype, g, aggre_type):\n    g = g.astype(idtype).to(F.ctx())\n    sage = nn.SAGEConv(5, 10, aggre_type)\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    sage = sage.to(F.ctx())\n    # test pickle\n    th.save(sage, tmp_buffer)\n    h = sage(g, feat)\n    assert h.shape[-1] == 10\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"]))\n@pytest.mark.parametrize(\"aggre_type\", [\"mean\", \"pool\", \"gcn\", \"lstm\"])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_sage_conv_bi(idtype, g, aggre_type, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    dst_dim = 5 if aggre_type != \"gcn\" else 10\n    
sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 10)),\n        F.randn((g.number_of_dst_nodes(), dst_dim)),\n    )\n    sage = sage.to(F.ctx())\n    h = sage(g, feat)\n    assert h.shape[-1] == out_dim\n    assert h.shape[0] == g.number_of_dst_nodes()\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_sage_conv2(idtype, out_dim):\n    # TODO: add test for blocks\n    # Test the case for graphs without edges\n    g = dgl.heterograph({(\"_U\", \"_E\", \"_V\"): ([], [])}, {\"_U\": 5, \"_V\": 3})\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    sage = nn.SAGEConv((3, 3), out_dim, \"gcn\")\n    feat = (F.randn((5, 3)), F.randn((3, 3)))\n    sage = sage.to(ctx)\n    h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))\n    assert h.shape[-1] == out_dim\n    assert h.shape[0] == 3\n    for aggre_type in [\"mean\", \"pool\", \"lstm\"]:\n        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)\n        feat = (F.randn((5, 3)), F.randn((3, 1)))\n        sage = sage.to(ctx)\n        h = sage(g, feat)\n        assert h.shape[-1] == out_dim\n        assert h.shape[0] == 3\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_sgc_conv(g, idtype, out_dim):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    # not cached\n    sgc = nn.SGConv(5, out_dim, 3)\n\n    # test pickle\n    th.save(sgc, tmp_buffer)\n\n    feat = F.randn((g.num_nodes(), 5))\n    sgc = sgc.to(ctx)\n\n    h = sgc(g, feat)\n    assert h.shape[-1] == out_dim\n\n    # cached\n    sgc = nn.SGConv(5, out_dim, 3, True)\n    sgc = sgc.to(ctx)\n    h_0 = sgc(g, feat)\n    h_1 = sgc(g, feat + 1)\n    assert F.allclose(h_0, h_1)\n    assert h_0.shape[-1] == out_dim\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\ndef 
test_appnp_conv(g, idtype):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    appnp = nn.APPNPConv(10, 0.1)\n    feat = F.randn((g.num_nodes(), 5))\n    appnp = appnp.to(ctx)\n\n    # test pickle\n    th.save(appnp, tmp_buffer)\n\n    h = appnp(g, feat)\n    assert h.shape[-1] == 5\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\ndef test_appnp_conv_e_weight(g, idtype):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    appnp = nn.APPNPConv(10, 0.1)\n    feat = F.randn((g.num_nodes(), 5))\n    eweight = F.ones((g.num_edges(),))\n    appnp = appnp.to(ctx)\n\n    h = appnp(g, feat, edge_weight=eweight)\n    assert h.shape[-1] == 5\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"bias\", [True, False])\ndef test_gcn2conv_e_weight(g, idtype, bias):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    gcn2conv = nn.GCN2Conv(\n        5, layer=2, alpha=0.5, bias=bias, project_initial_features=True\n    )\n    feat = F.randn((g.num_nodes(), 5))\n    eweight = F.ones((g.num_edges(),))\n    gcn2conv = gcn2conv.to(ctx)\n    res = feat\n    h = gcn2conv(g, res, feat, edge_weight=eweight)\n    assert h.shape[-1] == 5\n    assert re.match(\n        re.compile(\".*GCN2Conv.*in=.*, alpha=.*, beta=.*\"), str(gcn2conv)\n    )\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\ndef test_sgconv_e_weight(g, idtype):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    sgconv = nn.SGConv(5, 5, 3)\n    feat = F.randn((g.num_nodes(), 5))\n    eweight = F.ones((g.num_edges(),))\n    sgconv = sgconv.to(ctx)\n    h = sgconv(g, feat, edge_weight=eweight)\n    assert h.shape[-1] == 5\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\ndef test_tagconv_e_weight(g, idtype):\n    ctx = F.ctx()\n    g = 
g.astype(idtype).to(ctx)\n    conv = nn.TAGConv(5, 5, bias=True)\n    conv = conv.to(ctx)\n    feat = F.randn((g.num_nodes(), 5))\n    eweight = F.ones((g.num_edges(),))\n    conv = conv.to(ctx)\n    h = conv(g, feat, edge_weight=eweight)\n    assert h.shape[-1] == 5\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"aggregator_type\", [\"mean\", \"max\", \"sum\"])\ndef test_gin_conv(g, idtype, aggregator_type):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)\n    th.save(gin, tmp_buffer)\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    gin = gin.to(ctx)\n    h = gin(g, feat)\n\n    # test pickle\n    th.save(gin, tmp_buffer)\n\n    assert h.shape == (g.number_of_dst_nodes(), 12)\n\n    gin = nn.GINConv(None, aggregator_type)\n    th.save(gin, tmp_buffer)\n    gin = gin.to(ctx)\n    h = gin(g, feat)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\", \"block-bipartite\"]))\ndef test_gine_conv(g, idtype):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    gine = nn.GINEConv(th.nn.Linear(5, 12))\n    th.save(gine, tmp_buffer)\n    nfeat = F.randn((g.number_of_src_nodes(), 5))\n    efeat = F.randn((g.num_edges(), 5))\n    gine = gine.to(ctx)\n    h = gine(g, nfeat, efeat)\n\n    # test pickle\n    th.save(gine, tmp_buffer)\n    assert h.shape == (g.number_of_dst_nodes(), 12)\n\n    gine = nn.GINEConv(None)\n    th.save(gine, tmp_buffer)\n    gine = gine.to(ctx)\n    h = gine(g, nfeat, efeat)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"aggregator_type\", [\"mean\", \"max\", \"sum\"])\ndef test_gin_conv_bi(g, idtype, aggregator_type):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gin = nn.GINConv(th.nn.Linear(5, 12), aggregator_type)\n    
feat = (\n        F.randn((g.number_of_src_nodes(), 5)),\n        F.randn((g.number_of_dst_nodes(), 5)),\n    )\n    gin = gin.to(ctx)\n    h = gin(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), 12)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\ndef test_agnn_conv(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    agnn = nn.AGNNConv(1)\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    agnn = agnn.to(ctx)\n    h = agnn(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), 5)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\ndef test_agnn_conv_bi(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    agnn = nn.AGNNConv(1)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 5)),\n        F.randn((g.number_of_dst_nodes(), 5)),\n    )\n    agnn = agnn.to(ctx)\n    h = agnn(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), 5)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\ndef test_gated_graph_conv(g, idtype):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    ggconv = nn.GatedGraphConv(5, 10, 5, 3)\n    etypes = th.arange(g.num_edges()) % 3\n    feat = F.randn((g.num_nodes(), 5))\n    ggconv = ggconv.to(ctx)\n    etypes = etypes.to(ctx)\n\n    h = ggconv(g, feat, etypes)\n    # current we only do shape check\n    assert h.shape[-1] == 10\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\ndef test_gated_graph_conv_one_etype(g, idtype):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    ggconv = nn.GatedGraphConv(5, 10, 5, 1)\n    etypes = th.zeros(g.num_edges())\n    feat = F.randn((g.num_nodes(), 5))\n    ggconv = ggconv.to(ctx)\n    etypes = etypes.to(ctx)\n\n    h = ggconv(g, feat, etypes)\n    h2 = ggconv(g, 
feat)\n    # current we only do shape check\n    assert F.allclose(h, h2)\n    assert h.shape[-1] == 10\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\ndef test_nn_conv(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    edge_func = th.nn.Linear(4, 5 * 10)\n    nnconv = nn.NNConv(5, 10, edge_func, \"mean\")\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    efeat = F.randn((g.num_edges(), 4))\n    nnconv = nnconv.to(ctx)\n    h = nnconv(g, feat, efeat)\n    # currently we only do shape check\n    assert h.shape[-1] == 10\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\ndef test_nn_conv_bi(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    edge_func = th.nn.Linear(4, 5 * 10)\n    nnconv = nn.NNConv((5, 2), 10, edge_func, \"mean\")\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    feat_dst = F.randn((g.number_of_dst_nodes(), 2))\n    efeat = F.randn((g.num_edges(), 4))\n    nnconv = nnconv.to(ctx)\n    h = nnconv(g, (feat, feat_dst), efeat)\n    # currently we only do shape check\n    assert h.shape[-1] == 10\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\ndef test_gmm_conv(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gmmconv = nn.GMMConv(5, 10, 3, 4, \"mean\")\n    feat = F.randn((g.num_nodes(), 5))\n    pseudo = F.randn((g.num_edges(), 3))\n    gmmconv = gmmconv.to(ctx)\n    h = gmmconv(g, feat, pseudo)\n    # currently we only do shape check\n    assert h.shape[-1] == 10\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"bipartite\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\ndef test_gmm_conv_bi(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gmmconv = nn.GMMConv((5, 2), 10, 3, 4, \"mean\")\n    feat = 
F.randn((g.number_of_src_nodes(), 5))\n    feat_dst = F.randn((g.number_of_dst_nodes(), 2))\n    pseudo = F.randn((g.num_edges(), 3))\n    gmmconv = gmmconv.to(ctx)\n    h = gmmconv(g, (feat, feat_dst), pseudo)\n    # currently we only do shape check\n    assert h.shape[-1] == 10\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"norm_type\", [\"both\", \"right\", \"none\"])\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_dense_graph_conv(norm_type, g, idtype, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    # TODO(minjie): enable the following option after #1385\n    adj = g.adj_external(transpose=True, ctx=ctx).to_dense()\n    conv = nn.GraphConv(5, out_dim, norm=norm_type, bias=True)\n    dense_conv = nn.DenseGraphConv(5, out_dim, norm=norm_type, bias=True)\n    dense_conv.weight.data = conv.weight.data\n    dense_conv.bias.data = conv.bias.data\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    conv = conv.to(ctx)\n    dense_conv = dense_conv.to(ctx)\n    out_conv = conv(g, feat)\n    out_dense_conv = dense_conv(adj, feat)\n    assert F.allclose(out_conv, out_dense_conv)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\", \"bipartite\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_dense_sage_conv(g, idtype, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    adj = g.adj_external(transpose=True, ctx=ctx).to_dense()\n    sage = nn.SAGEConv(5, out_dim, \"gcn\")\n    dense_sage = nn.DenseSAGEConv(5, out_dim)\n    dense_sage.fc.weight.data = sage.fc_neigh.weight.data\n    dense_sage.fc.bias.data = sage.bias.data\n    if len(g.ntypes) == 2:\n        feat = (\n            F.randn((g.number_of_src_nodes(), 5)),\n            F.randn((g.number_of_dst_nodes(), 5)),\n        )\n    else:\n        feat = F.randn((g.num_nodes(), 5))\n    sage = sage.to(ctx)\n    dense_sage 
= dense_sage.to(ctx)\n    out_sage = sage(g, feat)\n    out_dense_sage = dense_sage(adj, feat)\n    assert F.allclose(out_sage, out_dense_sage), g\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_edge_conv(g, idtype, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)\n    print(edge_conv)\n\n    # test pickle\n    th.save(edge_conv, tmp_buffer)\n\n    h0 = F.randn((g.number_of_src_nodes(), 5))\n    h1 = edge_conv(g, h0)\n    assert h1.shape == (g.number_of_dst_nodes(), out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_edge_conv_bi(g, idtype, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    edge_conv = nn.EdgeConv(5, out_dim).to(ctx)\n    print(edge_conv)\n    h0 = F.randn((g.number_of_src_nodes(), 5))\n    x0 = F.randn((g.number_of_dst_nodes(), 5))\n    h1 = edge_conv(g, (h0, x0))\n    assert h1.shape == (g.number_of_dst_nodes(), out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_dotgat_conv(g, idtype, out_dim, num_heads):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    dotgat = nn.DotGatConv(5, out_dim, num_heads)\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    dotgat = dotgat.to(ctx)\n\n    # test pickle\n    th.save(dotgat, tmp_buffer)\n\n    h = dotgat(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)\n    _, a = dotgat(g, feat, get_attention=True)\n    assert a.shape == (g.num_edges(), num_heads, 1)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", 
get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_dotgat_conv_bi(g, idtype, out_dim, num_heads):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    dotgat = nn.DotGatConv((5, 5), out_dim, num_heads)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 5)),\n        F.randn((g.number_of_dst_nodes(), 5)),\n    )\n    dotgat = dotgat.to(ctx)\n    h = dotgat(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)\n    _, a = dotgat(g, feat, get_attention=True)\n    assert a.shape == (g.num_edges(), num_heads, 1)\n\n\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_dense_cheb_conv(out_dim):\n    for k in range(1, 4):\n        ctx = F.ctx()\n        g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.1))\n        g = g.to(F.ctx())\n        adj = g.adj_external(transpose=True, ctx=ctx).to_dense()\n        cheb = nn.ChebConv(5, out_dim, k, None)\n        dense_cheb = nn.DenseChebConv(5, out_dim, k)\n        # for i in range(len(cheb.fc)):\n        #    dense_cheb.W.data[i] = cheb.fc[i].weight.data.t()\n        dense_cheb.W.data = cheb.linear.weight.data.transpose(-1, -2).view(\n            k, 5, out_dim\n        )\n        if cheb.linear.bias is not None:\n            dense_cheb.bias.data = cheb.linear.bias.data\n        feat = F.randn((100, 5))\n        cheb = cheb.to(ctx)\n        dense_cheb = dense_cheb.to(ctx)\n        out_cheb = cheb(g, feat, [2.0])\n        out_dense_cheb = dense_cheb(adj, feat, 2.0)\n        print(k, out_cheb, out_dense_cheb)\n        assert F.allclose(out_cheb, out_dense_cheb)\n\n\ndef test_sequential():\n    ctx = F.ctx()\n\n    # Test single graph\n    class ExampleLayer(th.nn.Module):\n        def __init__(self):\n            super().__init__()\n\n        def forward(self, graph, n_feat, e_feat):\n            graph = graph.local_var()\n            graph.ndata[\"h\"] = n_feat\n            
graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n            n_feat += graph.ndata[\"h\"]\n            graph.apply_edges(fn.u_add_v(\"h\", \"h\", \"e\"))\n            e_feat += graph.edata[\"e\"]\n            return n_feat, e_feat\n\n    g = dgl.graph([])\n    g.add_nodes(3)\n    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])\n    g = g.to(F.ctx())\n    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())\n    n_feat = F.randn((3, 4))\n    e_feat = F.randn((9, 4))\n    net = net.to(ctx)\n    n_feat, e_feat = net(g, n_feat, e_feat)\n    assert n_feat.shape == (3, 4)\n    assert e_feat.shape == (9, 4)\n\n    # Test multiple graph\n    class ExampleLayer(th.nn.Module):\n        def __init__(self):\n            super().__init__()\n\n        def forward(self, graph, n_feat):\n            graph = graph.local_var()\n            graph.ndata[\"h\"] = n_feat\n            graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n            n_feat += graph.ndata[\"h\"]\n            return n_feat.view(graph.num_nodes() // 2, 2, -1).sum(1)\n\n    g1 = dgl.from_networkx(nx.erdos_renyi_graph(32, 0.05)).to(ctx)\n    g2 = dgl.from_networkx(nx.erdos_renyi_graph(16, 0.2)).to(ctx)\n    g3 = dgl.from_networkx(nx.erdos_renyi_graph(8, 0.8)).to(ctx)\n    net = nn.Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())\n    net = net.to(ctx)\n    n_feat = F.randn((32, 4))\n    n_feat = net([g1, g2, g3], n_feat)\n    assert n_feat.shape == (4, 4)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\ndef test_atomic_conv(g, idtype):\n    g = g.astype(idtype).to(F.ctx())\n    aconv = nn.AtomicConv(\n        interaction_cutoffs=F.tensor([12.0, 12.0]),\n        rbf_kernel_means=F.tensor([0.0, 2.0]),\n        rbf_kernel_scaling=F.tensor([4.0, 4.0]),\n        features_to_use=F.tensor([6.0, 8.0]),\n    )\n\n    ctx = F.ctx()\n    if F.gpu_ctx():\n        aconv = aconv.to(ctx)\n\n  
  feat = F.randn((g.num_nodes(), 1))\n    dist = F.randn((g.num_edges(), 1))\n\n    h = aconv(g, feat, dist)\n\n    # currently we only do shape check\n    assert h.shape[-1] == 4\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 3])\ndef test_cf_conv(g, idtype, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    cfconv = nn.CFConv(\n        node_in_feats=2, edge_in_feats=3, hidden_feats=2, out_feats=out_dim\n    )\n\n    ctx = F.ctx()\n    if F.gpu_ctx():\n        cfconv = cfconv.to(ctx)\n\n    src_feats = F.randn((g.number_of_src_nodes(), 2))\n    edge_feats = F.randn((g.num_edges(), 3))\n    h = cfconv(g, src_feats, edge_feats)\n    # currently we only do shape check\n    assert h.shape[-1] == out_dim\n\n    # case for bipartite graphs\n    dst_feats = F.randn((g.number_of_dst_nodes(), 3))\n    h = cfconv(g, (src_feats, dst_feats), edge_feats)\n    # currently we only do shape check\n    assert h.shape[-1] == out_dim\n\n\ndef myagg(alist, dsttype):\n    rst = alist[0]\n    for i in range(1, len(alist)):\n        rst = rst + (i + 1) * alist[i]\n    return rst\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"agg\", [\"sum\", \"max\", \"min\", \"mean\", \"stack\", myagg])\n@pytest.mark.parametrize(\"canonical_keys\", [False, True])\ndef test_hetero_conv(agg, idtype, canonical_keys):\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 0, 2, 1], [1, 2, 1, 3]),\n            (\"user\", \"plays\", \"game\"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),\n            (\"store\", \"sells\", \"game\"): ([0, 0, 1, 1], [0, 3, 1, 2]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    if not canonical_keys:\n        conv = nn.HeteroGraphConv(\n            {\n                \"follows\": nn.GraphConv(2, 3, allow_zero_in_degree=True),\n                \"plays\": nn.GraphConv(2, 4, 
allow_zero_in_degree=True),\n                \"sells\": nn.GraphConv(3, 4, allow_zero_in_degree=True),\n            },\n            agg,\n        )\n    else:\n        conv = nn.HeteroGraphConv(\n            {\n                (\"user\", \"follows\", \"user\"): nn.GraphConv(\n                    2, 3, allow_zero_in_degree=True\n                ),\n                (\"user\", \"plays\", \"game\"): nn.GraphConv(\n                    2, 4, allow_zero_in_degree=True\n                ),\n                (\"store\", \"sells\", \"game\"): nn.GraphConv(\n                    3, 4, allow_zero_in_degree=True\n                ),\n            },\n            agg,\n        )\n\n    conv = conv.to(F.ctx())\n\n    # test pickle\n    th.save(conv, tmp_buffer)\n\n    uf = F.randn((4, 2))\n    gf = F.randn((4, 4))\n    sf = F.randn((2, 3))\n\n    h = conv(g, {\"user\": uf, \"game\": gf, \"store\": sf})\n    assert set(h.keys()) == {\"user\", \"game\"}\n    if agg != \"stack\":\n        assert h[\"user\"].shape == (4, 3)\n        assert h[\"game\"].shape == (4, 4)\n    else:\n        assert h[\"user\"].shape == (4, 1, 3)\n        assert h[\"game\"].shape == (4, 2, 4)\n\n    block = dgl.to_block(\n        g.to(F.cpu()), {\"user\": [0, 1, 2, 3], \"game\": [0, 1, 2, 3], \"store\": []}\n    ).to(F.ctx())\n    h = conv(\n        block,\n        (\n            {\"user\": uf, \"game\": gf, \"store\": sf},\n            {\"user\": uf, \"game\": gf, \"store\": sf[0:0]},\n        ),\n    )\n    assert set(h.keys()) == {\"user\", \"game\"}\n    if agg != \"stack\":\n        assert h[\"user\"].shape == (4, 3)\n        assert h[\"game\"].shape == (4, 4)\n    else:\n        assert h[\"user\"].shape == (4, 1, 3)\n        assert h[\"game\"].shape == (4, 2, 4)\n\n    h = conv(block, {\"user\": uf, \"game\": gf, \"store\": sf})\n    assert set(h.keys()) == {\"user\", \"game\"}\n    if agg != \"stack\":\n        assert h[\"user\"].shape == (4, 3)\n        assert h[\"game\"].shape == (4, 4)\n    else:\n    
    assert h[\"user\"].shape == (4, 1, 3)\n        assert h[\"game\"].shape == (4, 2, 4)\n\n    # test with mod args\n    class MyMod(th.nn.Module):\n        def __init__(self, s1, s2):\n            super(MyMod, self).__init__()\n            self.carg1 = 0\n            self.carg2 = 0\n            self.s1 = s1\n            self.s2 = s2\n\n        def forward(self, g, h, arg1=None, *, arg2=None):\n            if arg1 is not None:\n                self.carg1 += 1\n            if arg2 is not None:\n                self.carg2 += 1\n            return th.zeros((g.number_of_dst_nodes(), self.s2))\n\n    mod1 = MyMod(2, 3)\n    mod2 = MyMod(2, 4)\n    mod3 = MyMod(3, 4)\n    conv = nn.HeteroGraphConv(\n        {\"follows\": mod1, \"plays\": mod2, \"sells\": mod3}, agg\n    )\n    conv = conv.to(F.ctx())\n    mod_args = {\"follows\": (1,), \"plays\": (1,)}\n    mod_kwargs = {\"sells\": {\"arg2\": \"abc\"}}\n    h = conv(\n        g,\n        {\"user\": uf, \"game\": gf, \"store\": sf},\n        mod_args=mod_args,\n        mod_kwargs=mod_kwargs,\n    )\n    assert mod1.carg1 == 1\n    assert mod1.carg2 == 0\n    assert mod2.carg1 == 1\n    assert mod2.carg2 == 0\n    assert mod3.carg1 == 0\n    assert mod3.carg2 == 1\n\n    # conv on graph without any edges\n    for etype in g.etypes:\n        g = dgl.remove_edges(g, g.edges(form=\"eid\", etype=etype), etype=etype)\n    assert g.num_edges() == 0\n    h = conv(g, {\"user\": uf, \"game\": gf, \"store\": sf})\n    assert set(h.keys()) == {\"user\", \"game\"}\n\n    block = dgl.to_block(\n        g.to(F.cpu()), {\"user\": [0, 1, 2, 3], \"game\": [0, 1, 2, 3], \"store\": []}\n    ).to(F.ctx())\n    h = conv(\n        block,\n        (\n            {\"user\": uf, \"game\": gf, \"store\": sf},\n            {\"user\": uf, \"game\": gf, \"store\": sf[0:0]},\n        ),\n    )\n    assert set(h.keys()) == {\"user\", \"game\"}\n\n\n@pytest.mark.parametrize(\"out_dim\", [1, 2, 100])\ndef test_hetero_linear(out_dim):\n    in_feats = {\n  
      \"user\": F.randn((2, 1)),\n        (\"user\", \"follows\", \"user\"): F.randn((3, 2)),\n    }\n\n    layer = nn.HeteroLinear(\n        {\"user\": 1, (\"user\", \"follows\", \"user\"): 2}, out_dim\n    )\n    layer = layer.to(F.ctx())\n    out_feats = layer(in_feats)\n    assert out_feats[\"user\"].shape == (2, out_dim)\n    assert out_feats[(\"user\", \"follows\", \"user\")].shape == (3, out_dim)\n\n\n@pytest.mark.parametrize(\"out_dim\", [1, 2, 100])\ndef test_hetero_embedding(out_dim):\n    layer = nn.HeteroEmbedding(\n        {\"user\": 2, (\"user\", \"follows\", \"user\"): 3}, out_dim\n    )\n    layer = layer.to(F.ctx())\n\n    embeds = layer.weight\n    assert embeds[\"user\"].shape == (2, out_dim)\n    assert embeds[(\"user\", \"follows\", \"user\")].shape == (3, out_dim)\n\n    layer.reset_parameters()\n    embeds = layer.weight\n    assert embeds[\"user\"].shape == (2, out_dim)\n    assert embeds[(\"user\", \"follows\", \"user\")].shape == (3, out_dim)\n\n    embeds = layer(\n        {\n            \"user\": F.tensor([0], dtype=F.int64),\n            (\"user\", \"follows\", \"user\"): F.tensor([0, 2], dtype=F.int64),\n        }\n    )\n    assert embeds[\"user\"].shape == (1, out_dim)\n    assert embeds[(\"user\", \"follows\", \"user\")].shape == (2, out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_gnnexplainer(g, idtype, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    feat = F.randn((g.num_nodes(), 5))\n\n    class Model(th.nn.Module):\n        def __init__(self, in_feats, out_feats, graph=False):\n            super(Model, self).__init__()\n            self.linear = th.nn.Linear(in_feats, out_feats)\n            if graph:\n                self.pool = nn.AvgPooling()\n            else:\n                self.pool = None\n\n        def forward(self, graph, feat, eweight=None):\n            with graph.local_scope():\n        
        feat = self.linear(feat)\n                graph.ndata[\"h\"] = feat\n                if eweight is None:\n                    graph.update_all(fn.copy_u(\"h\", \"m\"), fn.sum(\"m\", \"h\"))\n                else:\n                    graph.edata[\"w\"] = eweight\n                    graph.update_all(\n                        fn.u_mul_e(\"h\", \"w\", \"m\"), fn.sum(\"m\", \"h\")\n                    )\n\n                if self.pool:\n                    return self.pool(graph, graph.ndata[\"h\"])\n                else:\n                    return graph.ndata[\"h\"]\n\n    # Explain node prediction\n    model = Model(5, out_dim)\n    model = model.to(F.ctx())\n    explainer = nn.GNNExplainer(model, num_hops=1)\n    new_center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)\n\n    # Explain graph prediction\n    model = Model(5, out_dim, graph=True)\n    model = model.to(F.ctx())\n    explainer = nn.GNNExplainer(model, num_hops=1)\n    feat_mask, edge_mask = explainer.explain_graph(g, feat)\n\n\n@pytest.mark.parametrize(\"g\", get_cases([\"hetero\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"idtype\", [F.int64])\n@pytest.mark.parametrize(\"input_dim\", [5])\n@pytest.mark.parametrize(\"output_dim\", [1, 2])\ndef test_heterognnexplainer(g, idtype, input_dim, output_dim):\n    g = g.astype(idtype).to(F.ctx())\n    device = g.device\n\n    # add self-loop and reverse edges\n    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)\n    g = transform1(g)\n    transform2 = dgl.transforms.AddReverse(copy_edata=True)\n    g = transform2(g)\n\n    feat = {\n        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)\n        for ntype in g.ntypes\n    }\n\n    class Model(th.nn.Module):\n        def __init__(self, in_dim, num_classes, canonical_etypes, graph=False):\n            super(Model, self).__init__()\n            self.graph = graph\n            self.etype_weights = th.nn.ModuleDict(\n                {\n               
     \"_\".join(c_etype): th.nn.Linear(in_dim, num_classes)\n                    for c_etype in canonical_etypes\n                }\n            )\n\n        def forward(self, graph, feat, eweight=None):\n            with graph.local_scope():\n                c_etype_func_dict = {}\n                for c_etype in graph.canonical_etypes:\n                    src_type, etype, dst_type = c_etype\n                    wh = self.etype_weights[\"_\".join(c_etype)](feat[src_type])\n                    graph.nodes[src_type].data[f\"h_{c_etype}\"] = wh\n                    if eweight is None:\n                        c_etype_func_dict[c_etype] = (\n                            fn.copy_u(f\"h_{c_etype}\", \"m\"),\n                            fn.mean(\"m\", \"h\"),\n                        )\n                    else:\n                        graph.edges[c_etype].data[\"w\"] = eweight[c_etype]\n                        c_etype_func_dict[c_etype] = (\n                            fn.u_mul_e(f\"h_{c_etype}\", \"w\", \"m\"),\n                            fn.mean(\"m\", \"h\"),\n                        )\n                graph.multi_update_all(c_etype_func_dict, \"sum\")\n                if self.graph:\n                    hg = 0\n                    for ntype in graph.ntypes:\n                        if graph.num_nodes(ntype):\n                            hg = hg + dgl.mean_nodes(graph, \"h\", ntype=ntype)\n\n                    return hg\n                else:\n                    return graph.ndata[\"h\"]\n\n    # Explain node prediction\n    model = Model(input_dim, output_dim, g.canonical_etypes)\n    model = model.to(F.ctx())\n    ntype = g.ntypes[0]\n    explainer = nn.explain.HeteroGNNExplainer(model, num_hops=1)\n    new_center, sg, feat_mask, edge_mask = explainer.explain_node(\n        ntype, 0, g, feat\n    )\n\n    # Explain graph prediction\n    model = Model(input_dim, output_dim, g.canonical_etypes, graph=True)\n    model = model.to(F.ctx())\n    explainer = 
nn.explain.HeteroGNNExplainer(model, num_hops=1)\n    feat_mask, edge_mask = explainer.explain_graph(g, feat)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\",\n    get_cases(\n        [\"homo\"],\n        exclude=[\n            \"zero-degree\",\n            \"homo-zero-degree\",\n            \"has_feature\",\n            \"has_scalar_e_feature\",\n            \"row_sorted\",\n            \"col_sorted\",\n            \"batched\",\n        ],\n    ),\n)\n@pytest.mark.parametrize(\"n_classes\", [2])\ndef test_subgraphx(g, idtype, n_classes):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    feat = F.randn((g.num_nodes(), 5))\n\n    class Model(th.nn.Module):\n        def __init__(self, in_dim, n_classes):\n            super().__init__()\n            self.conv = nn.GraphConv(in_dim, n_classes)\n            self.pool = nn.AvgPooling()\n\n        def forward(self, g, h):\n            h = th.nn.functional.relu(self.conv(g, h))\n            return self.pool(g, h)\n\n    model = Model(feat.shape[1], n_classes)\n    model = model.to(ctx)\n    explainer = nn.SubgraphX(\n        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0\n    )\n    explainer.explain_graph(g, feat, target_class=0)\n\n\n@pytest.mark.parametrize(\"g\", get_cases([\"hetero\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"idtype\", [F.int64])\n@pytest.mark.parametrize(\"input_dim\", [5])\n@pytest.mark.parametrize(\"n_classes\", [2])\ndef test_heterosubgraphx(g, idtype, input_dim, n_classes):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    device = g.device\n\n    # add self-loop and reverse edges\n    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)\n    g = transform1(g)\n    transform2 = dgl.transforms.AddReverse(copy_edata=True)\n    g = transform2(g)\n\n    feat = {\n        ntype: th.zeros((g.num_nodes(ntype), input_dim), device=device)\n        for ntype in g.ntypes\n    }\n\n    class Model(th.nn.Module):\n        def __init__(self, 
in_dim, n_classes, canonical_etypes):\n            super(Model, self).__init__()\n            self.etype_weights = th.nn.ModuleDict(\n                {\n                    \"_\".join(c_etype): th.nn.Linear(in_dim, n_classes)\n                    for c_etype in canonical_etypes\n                }\n            )\n\n        def forward(self, graph, feat):\n            with graph.local_scope():\n                c_etype_func_dict = {}\n                for c_etype in graph.canonical_etypes:\n                    src_type, etype, dst_type = c_etype\n                    wh = self.etype_weights[\"_\".join(c_etype)](feat[src_type])\n                    graph.nodes[src_type].data[f\"h_{c_etype}\"] = wh\n                    c_etype_func_dict[c_etype] = (\n                        fn.copy_u(f\"h_{c_etype}\", \"m\"),\n                        fn.mean(\"m\", \"h\"),\n                    )\n                graph.multi_update_all(c_etype_func_dict, \"sum\")\n                hg = 0\n                for ntype in graph.ntypes:\n                    if graph.num_nodes(ntype):\n                        hg = hg + dgl.mean_nodes(graph, \"h\", ntype=ntype)\n\n                return hg\n\n    model = Model(input_dim, n_classes, g.canonical_etypes)\n    model = model.to(ctx)\n    explainer = nn.HeteroSubgraphX(\n        model, num_hops=1, shapley_steps=20, num_rollouts=5, coef=2.0\n    )\n    explainer.explain_graph(g, feat, target_class=0)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\",\n    get_cases(\n        [\"homo\"],\n        exclude=[\n            \"zero-degree\",\n            \"homo-zero-degree\",\n            \"has_feature\",\n            \"has_scalar_e_feature\",\n            \"row_sorted\",\n            \"col_sorted\",\n        ],\n    ),\n)\n@pytest.mark.parametrize(\"n_classes\", [2])\ndef test_pgexplainer(g, idtype, n_classes):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    feat = F.randn((g.num_nodes(), 5))\n    g.ndata[\"attr\"] = feat\n\n    # add 
reverse edges\n    transform = dgl.transforms.AddReverse(copy_edata=True)\n    g = transform(g)\n\n    class Model(th.nn.Module):\n        def __init__(self, in_feats, out_feats, graph=False):\n            super(Model, self).__init__()\n            self.graph = graph\n            self.conv = nn.GraphConv(in_feats, out_feats)\n            self.fc = th.nn.Linear(out_feats, out_feats)\n            th.nn.init.xavier_uniform_(self.fc.weight)\n\n        def forward(self, g, h, embed=False, edge_weight=None):\n            h = self.conv(g, h, edge_weight=edge_weight)\n\n            if not self.graph or embed:\n                return h\n\n            with g.local_scope():\n                g.ndata[\"h\"] = h\n                hg = dgl.mean_nodes(g, \"h\")\n                return self.fc(hg)\n\n    # graph explainer\n    model = Model(feat.shape[1], n_classes, graph=True)\n    model = model.to(ctx)\n    explainer = nn.PGExplainer(model, n_classes)\n    explainer.train_step(g, g.ndata[\"attr\"], 5.0)\n\n    probs, edge_weight = explainer.explain_graph(g, feat)\n\n    # node explainer\n    model = Model(feat.shape[1], n_classes, graph=False)\n    model = model.to(ctx)\n    explainer = nn.PGExplainer(\n        model, n_classes, num_hops=1, explain_graph=False\n    )\n    explainer.train_step_node(0, g, g.ndata[\"attr\"], 5.0)\n    explainer.train_step_node([0, 1], g, g.ndata[\"attr\"], 5.0)\n    explainer.train_step_node(th.tensor(0), g, g.ndata[\"attr\"], 5.0)\n    explainer.train_step_node(th.tensor([0, 1]), g, g.ndata[\"attr\"], 5.0)\n\n    probs, edge_weight, bg, inverse_indices = explainer.explain_node(0, g, feat)\n    probs, edge_weight, bg, inverse_indices = explainer.explain_node(\n        [0, 1], g, feat\n    )\n    probs, edge_weight, bg, inverse_indices = explainer.explain_node(\n        th.tensor(0), g, feat\n    )\n    probs, edge_weight, bg, inverse_indices = explainer.explain_node(\n        th.tensor([0, 1]), g, feat\n    )\n\n\n@pytest.mark.parametrize(\"g\", 
get_cases([\"hetero\"]))\n@pytest.mark.parametrize(\"idtype\", [F.int64])\n@pytest.mark.parametrize(\"input_dim\", [5])\n@pytest.mark.parametrize(\"n_classes\", [2])\ndef test_heteropgexplainer(g, idtype, input_dim, n_classes):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    feat = {\n        ntype: F.randn((g.num_nodes(ntype), input_dim)) for ntype in g.ntypes\n    }\n\n    # add self-loop and reverse edges\n    transform1 = dgl.transforms.AddSelfLoop(new_etypes=True)\n    g = transform1(g)\n    transform2 = dgl.transforms.AddReverse(copy_edata=True)\n    g = transform2(g)\n\n    class Model(th.nn.Module):\n        def __init__(\n            self, in_feats, embed_dim, out_feats, canonical_etypes, graph=True\n        ):\n            super(Model, self).__init__()\n            self.graph = graph\n            self.conv = nn.HeteroGraphConv(\n                {\n                    c_etype: nn.GraphConv(in_feats, embed_dim)\n                    for c_etype in canonical_etypes\n                }\n            )\n            self.fc = th.nn.Linear(embed_dim, out_feats)\n\n        def forward(self, g, h, embed=False, edge_weight=None):\n            if edge_weight is not None:\n                mod_kwargs = {\n                    etype: {\"edge_weight\": mask}\n                    for etype, mask in edge_weight.items()\n                }\n                h = self.conv(g, h, mod_kwargs=mod_kwargs)\n            else:\n                h = self.conv(g, h)\n\n            if not self.graph or embed:\n                return h\n\n            with g.local_scope():\n                g.ndata[\"h\"] = h\n                hg = 0\n                for ntype in g.ntypes:\n                    hg = hg + dgl.mean_nodes(g, \"h\", ntype=ntype)\n                return self.fc(hg)\n\n    embed_dim = input_dim\n\n    # graph explainer\n    model = Model(\n        input_dim, embed_dim, n_classes, g.canonical_etypes, graph=True\n    )\n    model = model.to(ctx)\n    explainer = 
nn.HeteroPGExplainer(model, embed_dim)\n    explainer.train_step(g, feat, 5.0)\n\n    probs, edge_weight = explainer.explain_graph(g, feat)\n\n    # node explainer\n    model = Model(\n        input_dim, embed_dim, n_classes, g.canonical_etypes, graph=False\n    )\n    model = model.to(ctx)\n    explainer = nn.HeteroPGExplainer(\n        model, embed_dim, num_hops=1, explain_graph=False\n    )\n    explainer.train_step_node({g.ntypes[0]: [0]}, g, feat, 5.0)\n    explainer.train_step_node({g.ntypes[0]: th.tensor([0, 1])}, g, feat, 5.0)\n\n    probs, edge_weight, bg, inverse_indices = explainer.explain_node(\n        {g.ntypes[0]: [0]}, g, feat\n    )\n    probs, edge_weight, bg, inverse_indices = explainer.explain_node(\n        {g.ntypes[0]: th.tensor([0, 1])}, g, feat\n    )\n\n\ndef test_jumping_knowledge():\n    ctx = F.ctx()\n    num_layers = 2\n    num_nodes = 3\n    num_feats = 4\n\n    feat_list = [\n        th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)\n    ]\n\n    model = nn.JumpingKnowledge(\"cat\").to(ctx)\n    model.reset_parameters()\n    assert model(feat_list).shape == (num_nodes, num_layers * num_feats)\n\n    model = nn.JumpingKnowledge(\"max\").to(ctx)\n    model.reset_parameters()\n    assert model(feat_list).shape == (num_nodes, num_feats)\n\n    model = nn.JumpingKnowledge(\"lstm\", num_feats, num_layers).to(ctx)\n    model.reset_parameters()\n    assert model(feat_list).shape == (num_nodes, num_feats)\n\n\n@pytest.mark.parametrize(\"op\", [\"dot\", \"cos\", \"ele\", \"cat\"])\ndef test_edge_predictor(op):\n    ctx = F.ctx()\n    num_pairs = 3\n    in_feats = 4\n    out_feats = 5\n    h_src = th.randn((num_pairs, in_feats)).to(ctx)\n    h_dst = th.randn((num_pairs, in_feats)).to(ctx)\n\n    pred = nn.EdgePredictor(op)\n    if op in [\"dot\", \"cos\"]:\n        assert pred(h_src, h_dst).shape == (num_pairs, 1)\n    elif op == \"ele\":\n        assert pred(h_src, h_dst).shape == (num_pairs, in_feats)\n    else:\n        
assert pred(h_src, h_dst).shape == (num_pairs, 2 * in_feats)\n    pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)\n    assert pred(h_src, h_dst).shape == (num_pairs, out_feats)\n\n\ndef test_ke_score_funcs():\n    ctx = F.ctx()\n    num_edges = 30\n    num_rels = 3\n    nfeats = 4\n\n    h_src = th.randn((num_edges, nfeats)).to(ctx)\n    h_dst = th.randn((num_edges, nfeats)).to(ctx)\n    rels = th.randint(low=0, high=num_rels, size=(num_edges,)).to(ctx)\n\n    score_func = nn.TransE(num_rels=num_rels, feats=nfeats).to(ctx)\n    score_func.reset_parameters()\n    score_func(h_src, h_dst, rels).shape == (num_edges)\n\n    score_func = nn.TransR(\n        num_rels=num_rels, rfeats=nfeats - 1, nfeats=nfeats\n    ).to(ctx)\n    score_func.reset_parameters()\n    score_func(h_src, h_dst, rels).shape == (num_edges)\n\n\ndef test_twirls():\n    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))\n    feat = th.ones(6, 10)\n    conv = nn.TWIRLSConv(10, 2, 128, prop_step=64)\n    res = conv(g, feat)\n    assert res.size() == (6, 2)\n\n\n@pytest.mark.parametrize(\"feat_size\", [4, 32])\n@pytest.mark.parametrize(\n    \"regularizer,num_bases\", [(None, None), (\"basis\", 4), (\"bdd\", 4)]\n)\ndef test_typed_linear(feat_size, regularizer, num_bases):\n    dev = F.ctx()\n    num_types = 5\n    lin = nn.TypedLinear(\n        feat_size,\n        feat_size * 2,\n        5,\n        regularizer=regularizer,\n        num_bases=num_bases,\n    ).to(dev)\n    print(lin)\n    x = th.randn(100, feat_size).to(dev)\n    x_type = th.randint(0, 5, (100,)).to(dev)\n    x_type_sorted, idx = th.sort(x_type)\n    _, rev_idx = th.sort(idx)\n    x_sorted = x[idx]\n\n    # test unsorted\n    y = lin(x, x_type)\n    assert y.shape == (100, feat_size * 2)\n    # test sorted\n    y_sorted = lin(x_sorted, x_type_sorted, sorted_by_type=True)\n    assert y_sorted.shape == (100, feat_size * 2)\n\n    assert th.allclose(y, y_sorted[rev_idx], atol=1e-4, 
rtol=1e-4)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"in_size\", [4])\n@pytest.mark.parametrize(\"num_heads\", [1])\ndef test_hgt(idtype, in_size, num_heads):\n    dev = F.ctx()\n    num_etypes = 5\n    num_ntypes = 2\n    head_size = in_size // num_heads\n\n    g = dgl.from_scipy(sp.sparse.random(100, 100, density=0.01))\n    g = g.astype(idtype).to(dev)\n    etype = th.tensor([i % num_etypes for i in range(g.num_edges())]).to(dev)\n    ntype = th.tensor([i % num_ntypes for i in range(g.num_nodes())]).to(dev)\n    x = th.randn(g.num_nodes(), in_size).to(dev)\n\n    m = nn.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes).to(\n        dev\n    )\n\n    y = m(g, x, ntype, etype)\n    assert y.shape == (g.num_nodes(), head_size * num_heads)\n    # presorted\n    sorted_ntype, idx_nt = th.sort(ntype)\n    sorted_etype, idx_et = th.sort(etype)\n    _, rev_idx = th.sort(idx_nt)\n    g.ndata[\"t\"] = ntype\n    g.ndata[\"x\"] = x\n    g.edata[\"t\"] = etype\n    sorted_g = dgl.reorder_graph(\n        g,\n        node_permute_algo=\"custom\",\n        edge_permute_algo=\"custom\",\n        permute_config={\n            \"nodes_perm\": idx_nt.to(idtype),\n            \"edges_perm\": idx_et.to(idtype),\n        },\n    )\n    print(sorted_g.ndata[\"t\"])\n    print(sorted_g.edata[\"t\"])\n    sorted_x = sorted_g.ndata[\"x\"]\n    sorted_y = m(\n        sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=False\n    )\n    assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)\n    # mini-batch\n    train_idx = th.randperm(100, dtype=idtype)[:10]\n    sampler = dgl.dataloading.NeighborSampler([-1])\n    train_loader = dgl.dataloading.DataLoader(\n        g, train_idx.to(dev), sampler, batch_size=8, device=dev, shuffle=True\n    )\n    (input_nodes, output_nodes, block) = next(iter(train_loader))\n    block = block[0]\n    x = x[input_nodes.to(th.long)]\n    ntype = ntype[input_nodes.to(th.long)]\n    edge = block.edata[dgl.EID]\n    etype 
= etype[edge.to(th.long)]\n    y = m(block, x, ntype, etype)\n    assert y.shape == (block.number_of_dst_nodes(), head_size * num_heads)\n    # TODO(minjie): enable the following check\n    # assert th.allclose(y, sorted_y[rev_idx], atol=1e-4, rtol=1e-4)\n\n\n@pytest.mark.parametrize(\"self_loop\", [True, False])\n@pytest.mark.parametrize(\"get_distances\", [True, False])\ndef test_radius_graph(self_loop, get_distances):\n    pos = th.tensor(\n        [\n            [0.1, 0.3, 0.4],\n            [0.5, 0.2, 0.1],\n            [0.7, 0.9, 0.5],\n            [0.3, 0.2, 0.5],\n            [0.2, 0.8, 0.2],\n            [0.9, 0.2, 0.1],\n            [0.7, 0.4, 0.4],\n            [0.2, 0.1, 0.6],\n            [0.5, 0.3, 0.5],\n            [0.4, 0.2, 0.6],\n        ]\n    )\n\n    rg = nn.RadiusGraph(0.3, self_loop=self_loop)\n\n    if get_distances:\n        g, dists = rg(pos, get_distances=get_distances)\n    else:\n        g = rg(pos)\n\n    if self_loop:\n        src_target = th.tensor(\n            [\n                0,\n                0,\n                1,\n                2,\n                3,\n                3,\n                3,\n                3,\n                3,\n                4,\n                5,\n                6,\n                6,\n                7,\n                7,\n                7,\n                8,\n                8,\n                8,\n                8,\n                9,\n                9,\n                9,\n                9,\n            ]\n        )\n        dst_target = th.tensor(\n            [\n                0,\n                3,\n                1,\n                2,\n                0,\n                3,\n                7,\n                8,\n                9,\n                4,\n                5,\n                6,\n                8,\n                3,\n                7,\n                9,\n                3,\n                6,\n                8,\n                9,\n                3,\n             
   7,\n                8,\n                9,\n            ]\n        )\n\n        if get_distances:\n            dists_target = th.tensor(\n                [\n                    [0.0000],\n                    [0.2449],\n                    [0.0000],\n                    [0.0000],\n                    [0.2449],\n                    [0.0000],\n                    [0.1732],\n                    [0.2236],\n                    [0.1414],\n                    [0.0000],\n                    [0.0000],\n                    [0.0000],\n                    [0.2449],\n                    [0.1732],\n                    [0.0000],\n                    [0.2236],\n                    [0.2236],\n                    [0.2449],\n                    [0.0000],\n                    [0.1732],\n                    [0.1414],\n                    [0.2236],\n                    [0.1732],\n                    [0.0000],\n                ]\n            )\n    else:\n        src_target = th.tensor([0, 3, 3, 3, 3, 6, 7, 7, 8, 8, 8, 9, 9, 9])\n        dst_target = th.tensor([3, 0, 7, 8, 9, 8, 3, 9, 3, 6, 9, 3, 7, 8])\n\n        if get_distances:\n            dists_target = th.tensor(\n                [\n                    [0.2449],\n                    [0.2449],\n                    [0.1732],\n                    [0.2236],\n                    [0.1414],\n                    [0.2449],\n                    [0.1732],\n                    [0.2236],\n                    [0.2236],\n                    [0.2449],\n                    [0.1732],\n                    [0.1414],\n                    [0.2236],\n                    [0.1732],\n                ]\n            )\n\n    src, dst = g.edges()\n\n    assert th.equal(src, src_target)\n    assert th.equal(dst, dst_target)\n\n    if get_distances:\n        assert th.allclose(dists, dists_target, rtol=1e-03)\n\n\n@parametrize_idtype\ndef test_group_rev_res(idtype):\n    dev = F.ctx()\n\n    num_nodes = 5\n    num_edges = 20\n    feats = 32\n    groups = 2\n  
  g = dgl.rand_graph(num_nodes, num_edges).to(dev)\n    h = th.randn(num_nodes, feats).to(dev)\n    conv = nn.GraphConv(feats // groups, feats // groups)\n    model = nn.GroupRevRes(conv, groups).to(dev)\n    result = model(g, h)\n    result.sum().backward()\n\n\n@pytest.mark.parametrize(\"in_size\", [16, 32])\n@pytest.mark.parametrize(\"hidden_size\", [16, 32])\n@pytest.mark.parametrize(\"out_size\", [16, 32])\n@pytest.mark.parametrize(\"edge_feat_size\", [16, 10, 0])\ndef test_egnn_conv(in_size, hidden_size, out_size, edge_feat_size):\n    dev = F.ctx()\n    num_nodes = 5\n    num_edges = 20\n    g = dgl.rand_graph(num_nodes, num_edges).to(dev)\n    h = th.randn(num_nodes, in_size).to(dev)\n    x = th.randn(num_nodes, 3).to(dev)\n    e = th.randn(num_edges, edge_feat_size).to(dev)\n    model = nn.EGNNConv(in_size, hidden_size, out_size, edge_feat_size).to(dev)\n    model(g, h, x, e)\n\n\n@pytest.mark.parametrize(\"in_size\", [16, 32])\n@pytest.mark.parametrize(\"out_size\", [16, 32])\n@pytest.mark.parametrize(\n    \"aggregators\",\n    [\n        [\"mean\", \"max\", \"sum\"],\n        [\"min\", \"std\", \"var\"],\n        [\"moment3\", \"moment4\", \"moment5\"],\n    ],\n)\n@pytest.mark.parametrize(\n    \"scalers\", [[\"identity\"], [\"amplification\", \"attenuation\"]]\n)\n@pytest.mark.parametrize(\"delta\", [2.5, 7.4])\n@pytest.mark.parametrize(\"dropout\", [0.0, 0.1])\n@pytest.mark.parametrize(\"num_towers\", [1, 4])\n@pytest.mark.parametrize(\"edge_feat_size\", [16, 0])\n@pytest.mark.parametrize(\"residual\", [True, False])\ndef test_pna_conv(\n    in_size,\n    out_size,\n    aggregators,\n    scalers,\n    delta,\n    dropout,\n    num_towers,\n    edge_feat_size,\n    residual,\n):\n    dev = F.ctx()\n    num_nodes = 5\n    num_edges = 20\n    g = dgl.rand_graph(num_nodes, num_edges).to(dev)\n    h = th.randn(num_nodes, in_size).to(dev)\n    e = th.randn(num_edges, edge_feat_size).to(dev)\n    model = nn.PNAConv(\n        in_size,\n        out_size,\n    
    aggregators,\n        scalers,\n        delta,\n        dropout,\n        num_towers,\n        edge_feat_size,\n        residual,\n    ).to(dev)\n    model(g, h, edge_feat=e)\n\n\n@pytest.mark.parametrize(\"k\", [3, 5])\n@pytest.mark.parametrize(\"alpha\", [0.0, 0.5, 1.0])\n@pytest.mark.parametrize(\"norm_type\", [\"sym\", \"row\"])\n@pytest.mark.parametrize(\"clamp\", [True, False])\n@pytest.mark.parametrize(\"normalize\", [True, False])\n@pytest.mark.parametrize(\"reset\", [True, False])\ndef test_label_prop(k, alpha, norm_type, clamp, normalize, reset):\n    dev = F.ctx()\n    num_nodes = 5\n    num_edges = 20\n    num_classes = 4\n    g = dgl.rand_graph(num_nodes, num_edges).to(dev)\n    labels = th.tensor([0, 2, 1, 3, 0]).long().to(dev)\n    ml_labels = th.rand(num_nodes, num_classes).to(dev) > 0.7\n    mask = th.tensor([0, 1, 1, 1, 0]).bool().to(dev)\n    model = nn.LabelPropagation(k, alpha, norm_type, clamp, normalize, reset)\n    model(g, labels, mask)\n    # multi-label case\n    model(g, ml_labels, mask)\n\n\n@pytest.mark.parametrize(\"in_size\", [16])\n@pytest.mark.parametrize(\"out_size\", [16, 32])\n@pytest.mark.parametrize(\n    \"aggregators\", [[\"mean\", \"max\", \"dir2-av\"], [\"min\", \"std\", \"dir1-dx\"]]\n)\n@pytest.mark.parametrize(\"scalers\", [[\"amplification\", \"attenuation\"]])\n@pytest.mark.parametrize(\"delta\", [2.5])\n@pytest.mark.parametrize(\"edge_feat_size\", [16, 0])\ndef test_dgn_conv(\n    in_size, out_size, aggregators, scalers, delta, edge_feat_size\n):\n    dev = F.ctx()\n    num_nodes = 5\n    num_edges = 20\n    g = dgl.rand_graph(num_nodes, num_edges).to(dev)\n    h = th.randn(num_nodes, in_size).to(dev)\n    e = th.randn(num_edges, edge_feat_size).to(dev)\n    transform = dgl.LapPE(k=3, feat_name=\"eig\")\n    g = transform(g)\n    eig = g.ndata[\"eig\"]\n    model = nn.DGNConv(\n        in_size,\n        out_size,\n        aggregators,\n        scalers,\n        delta,\n        edge_feat_size=edge_feat_size,\n    
).to(dev)\n    model(g, h, edge_feat=e, eig_vec=eig)\n\n    aggregators_non_eig = [\n        aggr for aggr in aggregators if not aggr.startswith(\"dir\")\n    ]\n    model = nn.DGNConv(\n        in_size,\n        out_size,\n        aggregators_non_eig,\n        scalers,\n        delta,\n        edge_feat_size=edge_feat_size,\n    ).to(dev)\n    model(g, h, edge_feat=e)\n\n\ndef test_DeepWalk():\n    dev = F.ctx()\n    g = dgl.graph(([0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]))\n    model = nn.DeepWalk(\n        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=True, sparse=True\n    )\n    model = model.to(dev)\n    dataloader = DataLoader(\n        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample\n    )\n    optim = SparseAdam(model.parameters(), lr=0.01)\n    walk = next(iter(dataloader)).to(dev)\n    loss = model(walk)\n    loss.backward()\n    optim.step()\n\n    model = nn.DeepWalk(\n        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=False, sparse=False\n    )\n    model = model.to(dev)\n    dataloader = DataLoader(\n        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample\n    )\n    optim = Adam(model.parameters(), lr=0.01)\n    walk = next(iter(dataloader)).to(dev)\n    loss = model(walk)\n    loss.backward()\n    optim.step()\n\n\n@pytest.mark.parametrize(\"max_degree\", [2, 6])\n@pytest.mark.parametrize(\"embedding_dim\", [8, 16])\n@pytest.mark.parametrize(\"direction\", [\"in\", \"out\", \"both\"])\ndef test_degree_encoder(max_degree, embedding_dim, direction):\n    g1 = dgl.graph(\n        (\n            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),\n            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),\n        )\n    )\n    g2 = dgl.graph(\n        (\n            th.tensor([0, 1]),\n            th.tensor([1, 0]),\n        )\n    )\n    in_degree = pad_sequence(\n        [g1.in_degrees(), g2.in_degrees()], batch_first=True\n    )\n    out_degree = pad_sequence(\n        [g1.out_degrees(), g2.out_degrees()], 
batch_first=True\n    )\n    model = nn.DegreeEncoder(max_degree, embedding_dim, direction=direction)\n    if direction == \"in\":\n        de_g = model(in_degree)\n    elif direction == \"out\":\n        de_g = model(out_degree)\n    elif direction == \"both\":\n        de_g = model(th.stack((in_degree, out_degree)))\n    assert de_g.shape == (2, 4, embedding_dim)\n\n\n@parametrize_idtype\ndef test_MetaPath2Vec(idtype):\n    dev = F.ctx()\n    g = dgl.heterograph(\n        {\n            (\"user\", \"uc\", \"company\"): ([0, 0, 2, 1, 3], [1, 2, 1, 3, 0]),\n            (\"company\", \"cp\", \"product\"): (\n                [0, 0, 0, 1, 2, 3],\n                [0, 2, 3, 0, 2, 1],\n            ),\n            (\"company\", \"cu\", \"user\"): ([1, 2, 1, 3, 0], [0, 0, 2, 1, 3]),\n            (\"product\", \"pc\", \"company\"): (\n                [0, 2, 3, 0, 2, 1],\n                [0, 0, 0, 1, 2, 3],\n            ),\n        },\n        idtype=idtype,\n        device=dev,\n    )\n    model = nn.MetaPath2Vec(g, [\"uc\", \"cu\"], window_size=1)\n    model = model.to(dev)\n    embeds = model.node_embed.weight\n    assert embeds.shape[0] == g.num_nodes()\n\n\n@pytest.mark.parametrize(\"num_layer\", [1, 4])\n@pytest.mark.parametrize(\"k\", [3, 5])\n@pytest.mark.parametrize(\"lpe_dim\", [4, 16])\n@pytest.mark.parametrize(\"n_head\", [2, 4])\n@pytest.mark.parametrize(\"batch_norm\", [True, False])\n@pytest.mark.parametrize(\"num_post_layer\", [0, 1, 2])\ndef test_LapPosEncoder(\n    num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer\n):\n    ctx = F.ctx()\n    num_nodes = 4\n\n    EigVals = th.randn((num_nodes, k)).to(ctx)\n    EigVecs = th.randn((num_nodes, k)).to(ctx)\n\n    model = nn.LapPosEncoder(\n        \"Transformer\", num_layer, k, lpe_dim, n_head, batch_norm, num_post_layer\n    ).to(ctx)\n    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)\n\n    model = nn.LapPosEncoder(\n        \"DeepSet\",\n        num_layer,\n        k,\n        
lpe_dim,\n        batch_norm=batch_norm,\n        num_post_layer=num_post_layer,\n    ).to(ctx)\n    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)\n\n\n@pytest.mark.parametrize(\"feat_size\", [128, 512])\n@pytest.mark.parametrize(\"num_heads\", [8, 16])\n@pytest.mark.parametrize(\"bias\", [True, False])\n@pytest.mark.parametrize(\"attn_bias_type\", [\"add\", \"mul\"])\n@pytest.mark.parametrize(\"attn_drop\", [0.1, 0.5])\ndef test_BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop):\n    ndata = th.rand(16, 100, feat_size)\n    attn_bias = th.rand(16, 100, 100, num_heads)\n    attn_mask = th.rand(16, 100, 100) < 0.5\n\n    net = nn.BiasedMHA(feat_size, num_heads, bias, attn_bias_type, attn_drop)\n    out = net(ndata, attn_bias, attn_mask)\n\n    assert out.shape == (16, 100, feat_size)\n\n\n@pytest.mark.parametrize(\"edge_update\", [True, False])\ndef test_EGTLayer(edge_update):\n    batch_size = 16\n    num_nodes = 100\n    feat_size, edge_feat_size = 128, 32\n    nfeat = th.rand(batch_size, num_nodes, feat_size)\n    efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)\n    mask = (th.rand(batch_size, num_nodes, num_nodes) < 0.5) * -1e9\n\n    net = nn.EGTLayer(\n        feat_size=feat_size,\n        edge_feat_size=edge_feat_size,\n        num_heads=8,\n        num_virtual_nodes=4,\n        edge_update=edge_update,\n    )\n\n    if edge_update:\n        out_nfeat, out_efeat = net(nfeat, efeat, mask)\n        assert out_nfeat.shape == (batch_size, num_nodes, feat_size)\n        assert out_efeat.shape == (\n            batch_size,\n            num_nodes,\n            num_nodes,\n            edge_feat_size,\n        )\n    else:\n        out_nfeat = net(nfeat, efeat, mask)\n        assert out_nfeat.shape == (batch_size, num_nodes, feat_size)\n\n\n@pytest.mark.parametrize(\"attn_bias_type\", [\"add\", \"mul\"])\n@pytest.mark.parametrize(\"norm_first\", [True, False])\ndef test_GraphormerLayer(attn_bias_type, norm_first):\n    
batch_size = 16\n    num_nodes = 100\n    feat_size = 512\n    num_heads = 8\n    nfeat = th.rand(batch_size, num_nodes, feat_size)\n    attn_bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)\n    attn_mask = th.rand(batch_size, num_nodes, num_nodes) < 0.5\n\n    net = nn.GraphormerLayer(\n        feat_size=feat_size,\n        hidden_size=2048,\n        num_heads=num_heads,\n        attn_bias_type=attn_bias_type,\n        norm_first=norm_first,\n        dropout=0.1,\n        attn_dropout=0.1,\n        activation=th.nn.ReLU(),\n    )\n    out = net(nfeat, attn_bias, attn_mask)\n\n    assert out.shape == (batch_size, num_nodes, feat_size)\n\n\n@pytest.mark.parametrize(\"max_len\", [1, 2])\n@pytest.mark.parametrize(\"feat_dim\", [16])\n@pytest.mark.parametrize(\"num_heads\", [1, 8])\ndef test_PathEncoder(max_len, feat_dim, num_heads):\n    dev = F.ctx()\n    g = dgl.graph(\n        (\n            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),\n            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),\n        )\n    ).to(dev)\n    edge_feat = th.rand(g.num_edges(), feat_dim).to(dev)\n    edge_feat = th.cat((edge_feat, th.zeros(1, 16).to(dev)), dim=0)\n    dist, path = shortest_dist(g, root=None, return_paths=True)\n    path_data = edge_feat[path[:, :, :max_len]]\n    model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)\n    bias = model(dist.unsqueeze(0), path_data.unsqueeze(0))\n    assert bias.shape == (1, 4, 4, num_heads)\n\n\n@pytest.mark.parametrize(\"max_dist\", [1, 4])\n@pytest.mark.parametrize(\"num_kernels\", [4, 16])\n@pytest.mark.parametrize(\"num_heads\", [1, 8])\ndef test_SpatialEncoder(max_dist, num_kernels, num_heads):\n    dev = F.ctx()\n    # single graph encoding 3d\n    num_nodes = 4\n    coord = th.rand(1, num_nodes, 3).to(dev)\n    node_type = th.tensor([[1, 0, 2, 1]]).to(dev)\n    spatial_encoder = nn.SpatialEncoder3d(\n        num_kernels=num_kernels, num_heads=num_heads, max_node_type=3\n    ).to(dev)\n    out = spatial_encoder(coord, 
node_type=node_type)\n    assert out.shape == (1, num_nodes, num_nodes, num_heads)\n\n    # encoding on a batch of graphs\n    g1 = dgl.graph(\n        (\n            th.tensor([0, 0, 0, 1, 1, 2, 3, 3]),\n            th.tensor([1, 2, 3, 0, 3, 0, 0, 1]),\n        )\n    ).to(dev)\n    g2 = dgl.graph(\n        (th.tensor([0, 1, 2, 3, 2, 5]), th.tensor([1, 2, 3, 4, 0, 3]))\n    ).to(dev)\n    bsz, max_num_nodes = 2, 6\n    # 2d encoding\n    dist = -th.ones((bsz, max_num_nodes, max_num_nodes), dtype=th.long).to(dev)\n    dist[0, :4, :4] = shortest_dist(g1, root=None, return_paths=False)\n    dist[1, :6, :6] = shortest_dist(g2, root=None, return_paths=False)\n    model_1 = nn.SpatialEncoder(max_dist, num_heads=num_heads).to(dev)\n    encoding = model_1(dist)\n    assert encoding.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)\n    # 3d encoding\n    coord = th.rand(bsz, max_num_nodes, 3).to(dev)\n    node_type = th.randint(\n        0,\n        512,\n        (\n            bsz,\n            max_num_nodes,\n        ),\n    ).to(dev)\n    model_2 = nn.SpatialEncoder3d(num_kernels, num_heads=num_heads).to(dev)\n    model_3 = nn.SpatialEncoder3d(\n        num_kernels, num_heads=num_heads, max_node_type=512\n    ).to(dev)\n    encoding3d_1 = model_2(coord)\n    encoding3d_2 = model_3(coord, node_type)\n    assert encoding3d_1.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)\n    assert encoding3d_2.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)\n\n\n@pytest.mark.parametrize(\"residual\", [True, False])\ndef test_conv_with_zero_nodes_bugfix_7894(residual):\n    \"\"\"Test for PR #7894 in DGL where HeteroGraphConv with zero nodes in a\n    specific node type would cause an error due to empty tensors.\n    This test ensures that GATConv, GATv2Conv, and EdgeGATConv can handle\n    such cases without raising errors.\n    \"\"\"\n    # Create a heterogeneous graph with zero nodes in the \"tag\" type\n    user_item_src = torch.tensor([0, 1, 2])\n    
user_item_dst = torch.tensor([4, 5, 6])\n\n    user_tag_src = torch.tensor([], dtype=torch.int64)\n    user_tag_dst = torch.tensor([], dtype=torch.int64)\n\n    num_nodes_dict = {\n        \"user\": 5,\n        \"item\": 10,\n        \"tag\": 0,\n    }\n\n    data_dict = {\n        (\"user\", \"buys\", \"item\"): (user_item_src, user_item_dst),\n        (\"user\", \"likes\", \"tag\"): (user_tag_src, user_tag_dst),\n    }\n\n    g = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)\n\n    feat_dim = 16\n    node_features = {\n        \"user\": torch.randn(num_nodes_dict[\"user\"], feat_dim),\n        \"item\": torch.randn(num_nodes_dict[\"item\"], feat_dim),\n        \"tag\": torch.randn(num_nodes_dict[\"tag\"], feat_dim), \n    }\n    edge_features = {\n        (\"user\", \"buys\", \"item\"): torch.randn(g.num_edges((\"user\", \"buys\", \"item\")), feat_dim),\n        (\"user\", \"likes\", \"tag\"): torch.randn(g.num_edges((\"user\", \"likes\", \"tag\")), feat_dim),\n    }\n\n    # Test GATConv with zero nodes in \"tag\" type\n    conv = nn.HeteroGraphConv({\n        (\"user\", \"buys\", \"item\"): nn.GATConv(16, 2, num_heads=2, residual=residual),\n        (\"user\", \"likes\", \"tag\"): nn.GATConv(16, 2, num_heads=2, residual=residual),\n    }, aggregate=\"sum\")\n    out = conv(g, node_features)\n    assert out[\"item\"].shape == (10, 2, 2)\n    assert out[\"tag\"].shape == (0, 2, 2)\n    assert \"user\" not in out\n\n    # Test GATv2Conv with zero nodes in \"tag\" type\n    conv_v2 = nn.HeteroGraphConv({\n        (\"user\", \"buys\", \"item\"): nn.GATv2Conv(16, 2, num_heads=2, residual=residual),\n        (\"user\", \"likes\", \"tag\"): nn.GATv2Conv(16, 2, num_heads=2, residual=residual),\n    }, aggregate=\"sum\")\n    out_v2 = conv_v2(g, node_features)\n    assert out_v2[\"item\"].shape == (10, 2, 2)\n    assert out_v2[\"tag\"].shape == (0, 2, 2)\n    assert \"user\" not in out_v2\n\n    # Test EdgeGATConv with zero nodes in \"tag\" type\n    
edge_conv = nn.HeteroGraphConv({\n        (\"user\", \"buys\", \"item\"): nn.EdgeGATConv(16, 16, 2, num_heads=2, residual=residual),\n        (\"user\", \"likes\", \"tag\"): nn.EdgeGATConv(16, 16, 2, num_heads=2, residual=residual),\n    }, aggregate=\"sum\")\n    mod_kwargs = {\n        \"buys\": {\"edge_feat\": edge_features[(\"user\", \"buys\", \"item\")]},\n        \"likes\": {\"edge_feat\": edge_features[(\"user\", \"likes\", \"tag\")]},\n    }\n    out_edge = edge_conv(g, node_features, mod_kwargs=mod_kwargs)\n    assert out_edge[\"item\"].shape == (10, 2, 2)\n    assert out_edge[\"tag\"].shape == (0, 2, 2)\n    assert \"user\" not in out_edge\n"
  },
  {
    "path": "tests/python/pytorch/nn/test_sparse_emb.py",
    "content": "import multiprocessing as mp\nimport os\nimport unittest\n\nimport backend as F\nimport pytest\nimport torch as th\n\nfrom dgl.nn import NodeEmbedding\nfrom dgl.optim import SparseAdam\n\n\ndef initializer(emb):\n    th.manual_seed(0)\n    emb.uniform_(-1.0, 1.0)\n    return emb\n\n\ndef check_all_set_all_get_emb(device, init_emb):\n    num_embs = init_emb.shape[0]\n    emb_dim = init_emb.shape[1]\n    dgl_emb = NodeEmbedding(num_embs, emb_dim, \"test\", device=device)\n    dgl_emb.all_set_embedding(init_emb)\n\n    out_emb = dgl_emb.all_get_embedding()\n    assert F.allclose(init_emb, out_emb)\n\n\ndef check_all_set_all_get_optm_state(\n    device, state_step, state_mem, state_power\n):\n    num_embs = state_mem.shape[0]\n    emb_dim = state_mem.shape[1]\n    dgl_emb = NodeEmbedding(num_embs, emb_dim, \"test\", device=device)\n    optm = SparseAdam(params=[dgl_emb], lr=0.01)\n\n    dgl_emb._all_set_optm_state((state_step, state_mem, state_power))\n\n    out_step, out_mem, out_power = dgl_emb._all_get_optm_state()\n\n    assert F.allclose(state_step, out_step)\n    assert F.allclose(state_mem, out_mem)\n    assert F.allclose(state_power, out_power)\n\n\ndef start_sparse_worker(rank, world_size, test, args):\n    print(\"start sparse worker {}\".format(rank))\n    dist_init_method = \"tcp://{master_ip}:{master_port}\".format(\n        master_ip=\"127.0.0.1\", master_port=\"12345\"\n    )\n    backend = \"gloo\"\n    device = F.ctx()\n    if device.type == \"cuda\":\n        device = th.device(rank)\n        th.cuda.set_device(device)\n    th.distributed.init_process_group(\n        backend=backend,\n        init_method=dist_init_method,\n        world_size=world_size,\n        rank=rank,\n    )\n\n    test(device, *args)\n    th.distributed.barrier()\n    th.distributed.destroy_process_group()\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@pytest.mark.parametrize(\"num_workers\", [1, 2, 3])\ndef 
test_multiprocess_sparse_emb_get_set(num_workers):\n    if F.ctx().type == \"cuda\" and th.cuda.device_count() < num_workers:\n        pytest.skip(\"Not enough GPUs to run test.\")\n\n    worker_list = []\n\n    init_emb = th.rand([1000, 8])\n\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_workers):\n        p = ctx.Process(\n            target=start_sparse_worker,\n            args=(i, num_workers, check_all_set_all_get_emb, (init_emb,)),\n        )\n        p.start()\n        worker_list.append(p)\n\n    for p in worker_list:\n        p.join()\n    for p in worker_list:\n        assert p.exitcode == 0\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@pytest.mark.parametrize(\"num_workers\", [1, 2, 3])\ndef test_multiprocess_sparse_emb_get_set_optm_state(num_workers):\n    if F.ctx().type == \"cuda\" and th.cuda.device_count() < num_workers:\n        pytest.skip(\"Not enough GPUs to run test.\")\n\n    worker_list = []\n\n    num_embs, emb_dim = 1000, 8\n    state_step = th.randint(1000, (num_embs,))\n    state_mem = th.rand((num_embs, emb_dim))\n    state_power = th.rand((num_embs, emb_dim))\n\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_workers):\n        p = ctx.Process(\n            target=start_sparse_worker,\n            args=(\n                i,\n                num_workers,\n                check_all_set_all_get_optm_state,\n                (state_step, state_mem, state_power),\n            ),\n        )\n        p.start()\n        worker_list.append(p)\n\n    for p in worker_list:\n        p.join()\n    for p in worker_list:\n        assert p.exitcode == 0\n\n\nif __name__ == \"__main__\":\n    # test_multiprocess_sparse_emb_get_set(1)\n    # test_multiprocess_sparse_emb_get_set(2)\n    # test_multiprocess_sparse_emb_get_set(3)\n\n    test_multiprocess_sparse_emb_get_set_optm_state(1)\n    # test_multiprocess_sparse_emb_get_set_optm_state(2)\n    # 
test_multiprocess_sparse_emb_get_set_optm_state(3)\n"
  },
  {
    "path": "tests/python/pytorch/optim/test_optim.py",
    "content": "import os\nimport unittest\n\nimport backend as F\nimport pytest\nimport torch as th\nimport torch.multiprocessing as mp\n\nfrom dgl.nn import NodeEmbedding\nfrom dgl.optim import SparseAdagrad, SparseAdam\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@pytest.mark.parametrize(\"emb_dim\", [1, 4, 101, 1024])\ndef test_sparse_adam(emb_dim):\n    num_embs = 10\n    device = F.ctx()\n    dgl_emb = NodeEmbedding(num_embs, emb_dim, \"test\")\n    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)\n    th.manual_seed(0)\n    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)\n    th.manual_seed(0)\n    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)\n\n    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)\n    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)\n\n    # first step\n    idx = th.randint(0, num_embs, size=(4,))\n    dgl_value = dgl_emb(idx, device).to(th.device(\"cpu\"))\n    torch_value = torch_emb(idx)\n    labels = th.zeros((4,)).long()\n    print(\"dgl_value = {}\".format(dgl_value))\n    print(\"labels = {}\".format(labels))\n\n    dgl_adam.zero_grad()\n    torch_adam.zero_grad()\n    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)\n    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)\n    dgl_loss.backward()\n    torch_loss.backward()\n\n    dgl_adam.step()\n    torch_adam.step()\n    assert F.allclose(dgl_emb.weight, torch_emb.weight)\n\n    # Can not test second step\n    # Pytorch sparseAdam maintains a global step\n    # DGL sparseAdam use a per embedding step\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@pytest.mark.parametrize(\"use_uva\", [False, True, None])\n@pytest.mark.parametrize(\"emb_dim\", [1, 4, 101, 1024])\ndef test_sparse_adam_uva(use_uva, emb_dim):\n    if F.ctx().type == \"cpu\" and use_uva == True:\n        # we want to only test values of False and None when not using GPU\n        
pytest.skip(\"UVA cannot be used without GPUs.\")\n\n    num_embs = 10\n    device = F.ctx()\n    dgl_emb = NodeEmbedding(num_embs, emb_dim, \"test_uva{}\".format(use_uva))\n    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)\n    th.manual_seed(0)\n    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)\n    th.manual_seed(0)\n    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)\n\n    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01, use_uva=use_uva)\n    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)\n\n    # first step\n    idx = th.randint(0, num_embs, size=(4,))\n    dgl_value = dgl_emb(idx, device).to(th.device(\"cpu\"))\n    torch_value = torch_emb(idx)\n    labels = th.zeros((4,)).long()\n\n    dgl_adam.zero_grad()\n    torch_adam.zero_grad()\n    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)\n    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)\n    dgl_loss.backward()\n    torch_loss.backward()\n\n    dgl_adam.step()\n    torch_adam.step()\n    assert F.allclose(dgl_emb.weight, torch_emb.weight)\n\n    # Can not test second step\n    # Pytorch sparseAdam maintains a global step\n    # DGL sparseAdam use a per embedding step\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@pytest.mark.parametrize(\"dtype\", [th.float32, th.float16])\n@pytest.mark.parametrize(\"emb_dim\", [1, 4, 101, 1024])\ndef test_sparse_adam_dtype(dtype, emb_dim):\n    num_embs = 10\n    device = F.ctx()\n    dgl_emb = NodeEmbedding(num_embs, emb_dim, \"test_dtype{}\".format(dtype))\n    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)\n    th.manual_seed(0)\n    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)\n    th.manual_seed(0)\n    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)\n\n    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01, dtype=dtype)\n    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)\n\n    # first step\n    idx = th.randint(0, num_embs, 
size=(4,))\n    dgl_value = dgl_emb(idx, device).to(th.device(\"cpu\"))\n    torch_value = torch_emb(idx)\n    labels = th.zeros((4,)).long()\n\n    dgl_adam.zero_grad()\n    torch_adam.zero_grad()\n    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)\n    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)\n    dgl_loss.backward()\n    torch_loss.backward()\n\n    dgl_adam.step()\n    torch_adam.step()\n    assert F.allclose(dgl_emb.weight, torch_emb.weight)\n\n    # Can not test second step\n    # Pytorch sparseAdam maintains a global step\n    # DGL sparseAdam use a per embedding step\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\ndef test_sparse_adam_zero_step():\n    num_embs = 10\n    emb_dim = 4\n    device = F.ctx()\n    dgl_emb = NodeEmbedding(num_embs, emb_dim, \"test\")\n    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)\n    dgl_emb_zero = NodeEmbedding(num_embs, emb_dim, \"test2\")\n    torch_emb_zero = th.nn.Embedding(num_embs, emb_dim, sparse=True)\n    th.manual_seed(0)\n    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)\n    th.nn.init.uniform_(torch_emb_zero.weight, 0, 1.0)\n    th.manual_seed(0)\n    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)\n    th.nn.init.uniform_(dgl_emb_zero.weight, 0, 1.0)\n\n    dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)\n    torch_adam = th.optim.SparseAdam(\n        list(torch_emb.parameters()) + list(torch_emb_zero.parameters()),\n        lr=0.01,\n    )\n\n    # first step\n    idx = th.randint(0, num_embs, size=(4,))\n    dgl_value = dgl_emb(idx, device).to(th.device(\"cpu\"))\n    torch_value = torch_emb(idx)\n    labels = th.ones((4,)).long()\n\n    dgl_adam.zero_grad()\n    torch_adam.zero_grad()\n    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)\n    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)\n    dgl_loss.backward()\n    torch_loss.backward()\n\n    dgl_adam.step()\n    
torch_adam.step()\n    assert F.allclose(dgl_emb.weight, torch_emb.weight)\n\n\ndef initializer(emb):\n    th.manual_seed(0)\n    emb.uniform_(-1.0, 1.0)\n    return emb\n\n\ndef start_sparse_adam_worker(\n    rank,\n    device,\n    world_size,\n    weight,\n    tensor_dev=\"cpu\",\n    has_zero_grad=False,\n    backend=\"gloo\",\n    num_embs=128,\n    emb_dim=10,\n    zero_comm=True,\n):\n    print(\"start sparse worker for adam {}\".format(rank))\n    dist_init_method = \"tcp://{master_ip}:{master_port}\".format(\n        master_ip=\"127.0.0.1\", master_port=\"12345\"\n    )\n\n    if device.type == \"cuda\":\n        th.cuda.set_device(device)\n\n    th.distributed.init_process_group(\n        backend=backend,\n        init_method=dist_init_method,\n        world_size=world_size,\n        rank=rank,\n    )\n\n    init_weight = th.empty((num_embs, emb_dim))\n    th.manual_seed(0)\n    th.nn.init.uniform_(init_weight, -1.0, 1.0)\n    dgl_emb = NodeEmbedding(\n        num_embs, emb_dim, \"test\", init_func=initializer, device=tensor_dev\n    )\n    dgl_emb.all_set_embedding(init_weight)\n\n    if has_zero_grad:\n        dgl_emb_zero = NodeEmbedding(\n            num_embs, emb_dim, \"zero\", init_func=initializer, device=tensor_dev\n        )\n        dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)\n    else:\n        dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)\n\n    th.manual_seed(rank)\n    if zero_comm:\n        start = (num_embs // world_size) * rank\n        end = (num_embs // world_size) * (rank + 1)\n        idx = th.randint(start, end, size=(4,)).to(tensor_dev)\n    else:\n        idx = th.randint(0, num_embs, size=(4,)).to(tensor_dev)\n    dgl_value = dgl_emb(idx, device)\n    labels = th.ones((4,)).long().to(device)\n    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)\n    dgl_adam.zero_grad()\n    dgl_loss.backward()\n    dgl_adam.step()\n    th.distributed.barrier()\n    dgl_weight = 
dgl_emb.all_get_embedding().detach()\n    after_step = dgl_emb(idx, device).cpu()\n\n    if rank == 0:\n        dgl_value = dgl_value.detach().cpu()\n        assert F.allclose(dgl_value, after_step) is False\n        weight[:] = dgl_weight[:]\n    th.distributed.barrier()\n\n\ndef start_torch_adam_worker(\n    rank,\n    world_size,\n    weight,\n    has_zero_grad=False,\n    num_embs=128,\n    emb_dim=10,\n    zero_comm=True,\n):\n    print(\"start sparse worker for adam {}\".format(rank))\n    dist_init_method = \"tcp://{master_ip}:{master_port}\".format(\n        master_ip=\"127.0.0.1\", master_port=\"12345\"\n    )\n    backend = \"gloo\"\n\n    th.distributed.init_process_group(\n        backend=backend,\n        init_method=dist_init_method,\n        world_size=world_size,\n        rank=rank,\n    )\n\n    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)\n    th.manual_seed(0)\n    th.nn.init.uniform_(torch_emb.weight, -1.0, 1.0)\n    torch_emb = th.nn.parallel.DistributedDataParallel(torch_emb)\n    if has_zero_grad:\n        torch_emb_zero = th.nn.Embedding(num_embs, emb_dim, sparse=True)\n        torch_emb_zero = torch_emb_zero.to(tensor_dev)\n        th.manual_seed(0)\n        th.nn.init.uniform_(torch_emb_zero.weight, -1.0, 1.0)\n        torch_emb_zero = th.nn.parallel.DistributedDataParallel(torch_emb_zero)\n        torch_adam = th.optim.SparseAdam(\n            list(torch_emb.module.parameters())\n            + list(torch_emb_zero.module.parameters()),\n            lr=0.01,\n        )\n    else:\n        torch_adam = th.optim.SparseAdam(\n            list(torch_emb.module.parameters()), lr=0.01\n        )\n\n    th.manual_seed(rank)\n    if zero_comm:\n        start = (num_embs // world_size) * rank\n        end = (num_embs // world_size) * (rank + 1)\n        idx = th.randint(start, end, size=(4,))\n    else:\n        idx = th.randint(0, num_embs, size=(4,))\n    labels = th.ones((4,)).long()\n    torch_value = torch_emb(idx)\n    
torch_loss = th.nn.functional.cross_entropy(torch_value, labels)\n    torch_adam.zero_grad()\n    torch_loss.backward()\n    torch_adam.step()\n    th.distributed.barrier()\n\n    if rank == 0:\n        weight[:] = torch_emb.module.weight.cpu()[:]\n    th.distributed.barrier()\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(F.ctx().type != \"cpu\", reason=\"cpu only test\")\n@pytest.mark.parametrize(\"num_workers\", [2, 4])\ndef test_multiprocess_cpu_sparse_adam(num_workers):\n    backend = \"gloo\"\n    worker_list = []\n    num_embs = 128\n    emb_dim = 10\n    dgl_weight = th.empty((num_embs, emb_dim))\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_workers):\n        device = F.ctx()\n        p = ctx.Process(\n            target=start_sparse_adam_worker,\n            args=(\n                i,\n                device,\n                num_workers,\n                dgl_weight,\n                th.device(\"cpu\"),\n                True,\n                backend,\n            ),\n        )\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    worker_list = []\n    torch_weight = th.empty((num_embs, emb_dim))\n    for i in range(num_workers):\n        p = ctx.Process(\n            target=start_torch_adam_worker,\n            args=(i, num_workers, torch_weight, False),\n        )\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    assert F.allclose(dgl_weight, torch_weight)\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(F.ctx().type == \"cpu\", reason=\"gpu only test\")\n@pytest.mark.parametrize(\"num_workers\", [2, 4, 8])\n@pytest.mark.parametrize(\"backend\", [\"nccl\", \"gloo\"])\n@pytest.mark.parametrize(\"zero_comm\", [True, False])\ndef test_multiprocess_sparse_adam(num_workers, backend, zero_comm):\n    if F.ctx().type == \"cuda\" and 
th.cuda.device_count() < num_workers:\n        pytest.skip(\"Not enough GPUs to run test.\")\n\n    worker_list = []\n    num_embs = 128\n    emb_dim = 10\n    dgl_weight = th.empty((num_embs, emb_dim))\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_workers):\n        device = F.ctx()\n        if device.type == \"cuda\":\n            # make sure each process has a unique GPU\n            device = th.device(i)\n        p = ctx.Process(\n            target=start_sparse_adam_worker,\n            args=(\n                i,\n                device,\n                num_workers,\n                dgl_weight,\n                th.device(\"cpu\"),\n                True,\n                backend,\n                num_embs,\n                emb_dim,\n                zero_comm,\n            ),\n        )\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    worker_list = []\n    torch_weight = th.empty((num_embs, emb_dim))\n    for i in range(num_workers):\n        p = ctx.Process(\n            target=start_torch_adam_worker,\n            args=(\n                i,\n                num_workers,\n                torch_weight,\n                False,\n                num_embs,\n                emb_dim,\n                zero_comm,\n            ),\n        )\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    assert F.allclose(dgl_weight, torch_weight)\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(\n    F.ctx().type == \"cpu\", reason=\"cuda tensor is not supported for cpu\"\n)\n@pytest.mark.parametrize(\"num_workers\", [2, 4, 8])\ndef test_multiprocess_sparse_adam_cuda_tensor(num_workers):\n    if F.ctx().type == \"cpu\":\n        pytest.skip(\"Do not test CPU\")\n    if F.ctx().type == \"cuda\" and th.cuda.device_count() < num_workers:\n        pytest.skip(\"Not enough GPUs to run test.\")\n\n    backend = \"nccl\"\n   
 worker_list = []\n    num_embs = 128\n    emb_dim = 10\n    dgl_weight = th.empty((num_embs, emb_dim))\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_workers):\n        device = th.device(i)\n        p = ctx.Process(\n            target=start_sparse_adam_worker,\n            args=(i, device, num_workers, dgl_weight, device, False, backend),\n        )\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    worker_list = []\n    torch_weight = th.empty((num_embs, emb_dim))\n    for i in range(num_workers):\n        p = ctx.Process(\n            target=start_torch_adam_worker,\n            args=(i, num_workers, torch_weight, False),\n        )\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    assert F.allclose(dgl_weight, torch_weight)\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(F.ctx().type != \"cpu\", reason=\"cpu only test\")\n@pytest.mark.parametrize(\"num_workers\", [2, 4])\ndef test_multiprocess_sparse_adam_cpu_zero_step(num_workers):\n    backend = \"gloo\"\n\n    worker_list = []\n    num_embs = 128\n    emb_dim = 10\n    dgl_weight = th.empty((num_embs, emb_dim))\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_workers):\n        device = F.ctx()\n        p = ctx.Process(\n            target=start_sparse_adam_worker,\n            args=(\n                i,\n                device,\n                num_workers,\n                dgl_weight,\n                th.device(\"cpu\"),\n                True,\n                backend,\n            ),\n        )\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    worker_list = []\n    torch_weight = th.empty((num_embs, emb_dim))\n    for i in range(num_workers):\n        p = ctx.Process(\n            target=start_torch_adam_worker,\n            args=(i, num_workers, torch_weight, False),\n        
)\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    assert F.allclose(dgl_weight, torch_weight)\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(F.ctx().type == \"cpu\", reason=\"gpu only test\")\n@pytest.mark.parametrize(\"num_workers\", [2, 4, 8])\n@pytest.mark.parametrize(\"backend\", [\"nccl\", \"gloo\"])\ndef test_multiprocess_sparse_adam_zero_step(num_workers, backend):\n    if F.ctx().type == \"cuda\" and th.cuda.device_count() < num_workers:\n        pytest.skip(\"Not enough GPUs to run test.\")\n\n    worker_list = []\n    num_embs = 128\n    emb_dim = 10\n    dgl_weight = th.empty((num_embs, emb_dim))\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_workers):\n        device = F.ctx()\n        if device.type == \"cuda\":\n            # make sure each process has a unique GPU\n            device = th.device(i)\n        p = ctx.Process(\n            target=start_sparse_adam_worker,\n            args=(\n                i,\n                device,\n                num_workers,\n                dgl_weight,\n                th.device(\"cpu\"),\n                True,\n                backend,\n            ),\n        )\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    worker_list = []\n    torch_weight = th.empty((num_embs, emb_dim))\n    for i in range(num_workers):\n        p = ctx.Process(\n            target=start_torch_adam_worker,\n            args=(i, num_workers, torch_weight, False),\n        )\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    assert F.allclose(dgl_weight, torch_weight)\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(\n    F.ctx().type == \"cpu\", reason=\"cuda tensor is not supported for cpu\"\n)\n@pytest.mark.parametrize(\"num_workers\", [2, 4, 8])\ndef 
test_multiprocess_sparse_adam_zero_step_cuda_tensor(num_workers):\n    if F.ctx().type == \"cuda\" and th.cuda.device_count() < num_workers:\n        pytest.skip(\"Not enough GPUs to run test.\")\n\n    backend = \"nccl\"\n    worker_list = []\n    num_embs = 128\n    emb_dim = 10\n    dgl_weight = th.empty((num_embs, emb_dim))\n    ctx = mp.get_context(\"spawn\")\n    for i in range(num_workers):\n        device = th.device(i)\n        p = ctx.Process(\n            target=start_sparse_adam_worker,\n            args=(i, device, num_workers, dgl_weight, device, True, backend),\n        )\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    worker_list = []\n    torch_weight = th.empty((num_embs, emb_dim))\n    for i in range(num_workers):\n        p = ctx.Process(\n            target=start_torch_adam_worker,\n            args=(i, num_workers, torch_weight, False),\n        )\n        p.start()\n        worker_list.append(p)\n    for p in worker_list:\n        p.join()\n\n    assert F.allclose(dgl_weight, torch_weight)\n\n\ndef start_sparse_adam_state_dict_worker(\n    rank,\n    world_size,\n    init_weight,\n    backend,\n    num_embs,\n    emb_dim,\n):\n    print(\"start sparse worker for adam {}\".format(rank))\n    dist_init_method = \"tcp://{master_ip}:{master_port}\".format(\n        master_ip=\"127.0.0.1\", master_port=\"12345\"\n    )\n\n    device = th.device(f\"cuda:{rank}\")\n    th.cuda.set_device(device)\n    tensor_dev = device if backend == \"nccl\" else th.device(\"cpu\")\n\n    th.distributed.init_process_group(\n        backend=backend,\n        init_method=dist_init_method,\n        world_size=world_size,\n        rank=rank,\n    )\n\n    th.manual_seed(0)\n    dgl_emb = NodeEmbedding(\n        num_embs, emb_dim, \"test\", init_func=initializer, device=tensor_dev\n    )\n    dgl_emb.all_set_embedding(init_weight)\n\n    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)\n\n    start = (num_embs // 
world_size) * rank\n    end = (num_embs // world_size) * (rank + 1)\n    th.manual_seed(rank)\n    idx = th.randint(start, end, size=(4,)).to(tensor_dev)\n    dgl_value = dgl_emb(idx, device)\n    labels = th.ones((4,)).long().to(device)\n    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)\n    dgl_adam.zero_grad()\n    dgl_loss.backward()\n    dgl_adam.step()\n    th.distributed.barrier()\n\n    worker_state_dict = [t.detach().clone() for t in dgl_emb.optm_state]\n    state_dict = dgl_adam.state_dict()\n    for t in dgl_emb.optm_state:\n        t.zero_()\n    dgl_adam.load_state_dict(state_dict)\n\n    for i, j in zip(worker_state_dict, dgl_emb.optm_state):\n        F.allclose(i, j)\n\n    th.distributed.barrier()\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\n@unittest.skipIf(F.ctx().type == \"cpu\", reason=\"gpu only test\")\n@pytest.mark.parametrize(\"num_workers\", [1, 2, 4, 8])\n@pytest.mark.parametrize(\"backend\", [\"nccl\", \"gloo\"])\ndef test_multiprocess_sparse_adam_state_dict(num_workers, backend):\n    if F.ctx().type == \"cuda\" and th.cuda.device_count() < num_workers:\n        pytest.skip(\"Not enough GPUs to run test.\")\n\n    num_embs = 128\n    emb_dim = 10\n    init_weight = th.rand((num_embs, emb_dim))\n    mp.spawn(\n        start_sparse_adam_state_dict_worker,\n        (\n            num_workers,\n            init_weight,\n            backend,\n            num_embs,\n            emb_dim,\n        ),\n        nprocs=num_workers,\n    )\n\n\nif __name__ == \"__main__\":\n    test_sparse_adam(1)\n    test_sparse_adam(4)\n    test_sparse_adam(101)\n    test_sparse_adam(1024)\n    test_sparse_adam_zero_step()\n\n    test_multiprocess_cpu_sparse_adam(2)\n    test_multiprocess_cpu_sparse_adam(4)\n    test_multiprocess_cpu_sparse_adam(8)\n    test_multiprocess_sparse_adam_cpu_zero_step(2)\n\n    test_multiprocess_sparse_adam(2, backend=\"gloo\")\n    test_multiprocess_sparse_adam(4, backend=\"gloo\")\n 
   test_multiprocess_sparse_adam(8, backend=\"gloo\")\n    test_multiprocess_sparse_adam(2, backend=\"nccl\")\n    test_multiprocess_sparse_adam(4, backend=\"nccl\")\n    test_multiprocess_sparse_adam(8, backend=\"nccl\")\n\n    test_multiprocess_sparse_adam_zero_step(2, backend=\"gloo\")\n    test_multiprocess_sparse_adam_zero_step(4, backend=\"nccl\")\n\n    test_multiprocess_sparse_adam_cuda_tensor(2)\n    test_multiprocess_sparse_adam_zero_step_cuda_tensor(4)\n\n    test_multiprocess_sparse_adam_state_dict(2, \"nccl\")\n    test_multiprocess_sparse_adam_state_dict(2, \"gloo\")\n"
  },
  {
    "path": "tests/python/pytorch/sparse/__init__.py",
    "content": "\"\"\" DGL sparse tests\"\"\"\n"
  },
  {
    "path": "tests/python/pytorch/sparse/test_broadcast.py",
    "content": "import operator\n\nimport backend as F\nimport pytest\nimport torch\n\nfrom dgl.sparse import sp_broadcast_v\n\nfrom .utils import rand_coo\n\n\n@pytest.mark.parametrize(\"shape\", [(3, 4), (1, 5), (5, 1)])\n@pytest.mark.parametrize(\"nnz\", [1, 4])\n@pytest.mark.parametrize(\"nz_dim\", [None, 2])\n@pytest.mark.parametrize(\"op\", [\"add\", \"sub\", \"mul\", \"truediv\"])\ndef test_sp_broadcast_v(shape, nnz, nz_dim, op):\n    dev = F.ctx()\n    A = rand_coo(shape, nnz, dev, nz_dim)\n\n    v = torch.randn(A.shape[1], device=dev)\n    res1 = sp_broadcast_v(A, v, op)\n    if A.val.dim() == 1:\n        rhs = v[A.col]\n    else:\n        rhs = v[A.col].view(-1, 1)\n    res2 = getattr(operator, op)(A.val, rhs)\n    assert torch.allclose(res1.val, res2)\n\n    v = torch.randn(1, A.shape[1], device=dev)\n    res1 = sp_broadcast_v(A, v, op)\n    if A.val.dim() == 1:\n        rhs = v.view(-1)[A.col]\n    else:\n        rhs = v.view(-1)[A.col].view(-1, 1)\n    res2 = getattr(operator, op)(A.val, rhs)\n    assert torch.allclose(res1.val, res2)\n\n    v = torch.randn(A.shape[0], 1, device=dev)\n    res1 = sp_broadcast_v(A, v, op)\n    if A.val.dim() == 1:\n        rhs = v.view(-1)[A.row]\n    else:\n        rhs = v.view(-1)[A.row].view(-1, 1)\n    res2 = getattr(operator, op)(A.val, rhs)\n    assert torch.allclose(res1.val, res2)\n"
  },
  {
    "path": "tests/python/pytorch/sparse/test_elementwise_op.py",
    "content": "import operator\n\nimport backend as F\n\nimport dgl.sparse as dglsp\nimport pytest\nimport torch\n\nfrom dgl.sparse import diag, power\n\n\n@pytest.mark.parametrize(\"opname\", [\"add\", \"sub\", \"mul\", \"truediv\"])\ndef test_diag_op_diag(opname):\n    op = getattr(operator, opname)\n    ctx = F.ctx()\n    shape = (3, 4)\n    D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)\n    D2 = diag(torch.arange(10, 13).to(ctx), shape=shape)\n    result = op(D1, D2)\n    assert torch.allclose(result.val, op(D1.val, D2.val), rtol=1e-4, atol=1e-4)\n    assert result.shape == D1.shape\n\n\n@pytest.mark.parametrize(\n    \"v_scalar\", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]\n)\ndef test_diag_op_scalar(v_scalar):\n    ctx = F.ctx()\n    shape = (3, 4)\n    D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)\n\n    # D * v\n    D2 = D1 * v_scalar\n    assert torch.allclose(D1.val * v_scalar, D2.val, rtol=1e-4, atol=1e-4)\n    assert D1.shape == D2.shape\n\n    # v * D\n    D2 = v_scalar * D1\n    assert torch.allclose(v_scalar * D1.val, D2.val, rtol=1e-4, atol=1e-4)\n    assert D1.shape == D2.shape\n\n    # D / v\n    D2 = D1 / v_scalar\n    assert torch.allclose(D1.val / v_scalar, D2.val, rtol=1e-4, atol=1e-4)\n    assert D1.shape == D2.shape\n\n    # D ^ v\n    D1 = diag(torch.arange(1, 4).to(ctx))\n    D2 = D1**v_scalar\n    assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)\n    assert D1.shape == D2.shape\n\n    # pow(D, v)\n    D2 = power(D1, v_scalar)\n    assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)\n    assert D1.shape == D2.shape\n\n    with pytest.raises(TypeError):\n        D1 + v_scalar\n    with pytest.raises(TypeError):\n        v_scalar + D1\n\n    with pytest.raises(TypeError):\n        D1 - v_scalar\n    with pytest.raises(TypeError):\n        v_scalar - D1\n\n\n@pytest.mark.parametrize(\"val_shape\", [(), (2,)])\n@pytest.mark.parametrize(\"opname\", [\"add\", \"sub\"])\ndef 
test_addsub_coo(val_shape, opname):\n    op = getattr(operator, opname)\n    func = getattr(dglsp, opname)\n    ctx = F.ctx()\n    row = torch.tensor([1, 0, 2]).to(ctx)\n    col = torch.tensor([0, 3, 2]).to(ctx)\n    val = torch.randn(row.shape + val_shape).to(ctx)\n    A = dglsp.from_coo(row, col, val)\n\n    row = torch.tensor([1, 0]).to(ctx)\n    col = torch.tensor([0, 2]).to(ctx)\n    val = torch.randn(row.shape + val_shape).to(ctx)\n    B = dglsp.from_coo(row, col, val, shape=A.shape)\n\n    C1 = op(A, B).to_dense()\n    C2 = func(A, B).to_dense()\n    dense_C = op(A.to_dense(), B.to_dense())\n\n    assert torch.allclose(dense_C, C1)\n    assert torch.allclose(dense_C, C2)\n\n    with pytest.raises(TypeError):\n        op(A, 2)\n    with pytest.raises(TypeError):\n        op(2, A)\n\n\n@pytest.mark.parametrize(\"val_shape\", [(), (2,)])\n@pytest.mark.parametrize(\"opname\", [\"add\", \"sub\"])\ndef test_addsub_csr(val_shape, opname):\n    op = getattr(operator, opname)\n    func = getattr(dglsp, opname)\n    ctx = F.ctx()\n    indptr = torch.tensor([0, 1, 2, 3]).to(ctx)\n    indices = torch.tensor([3, 0, 2]).to(ctx)\n    val = torch.randn(indices.shape + val_shape).to(ctx)\n    A = dglsp.from_csr(indptr, indices, val)\n\n    indptr = torch.tensor([0, 1, 2, 2]).to(ctx)\n    indices = torch.tensor([2, 0]).to(ctx)\n    val = torch.randn(indices.shape + val_shape).to(ctx)\n    B = dglsp.from_csr(indptr, indices, val, shape=A.shape)\n\n    C1 = op(A, B).to_dense()\n    C2 = func(A, B).to_dense()\n    dense_C = op(A.to_dense(), B.to_dense())\n\n    assert torch.allclose(dense_C, C1)\n    assert torch.allclose(dense_C, C2)\n\n    with pytest.raises(TypeError):\n        op(A, 2)\n    with pytest.raises(TypeError):\n        op(2, A)\n\n\n@pytest.mark.parametrize(\"val_shape\", [(), (2,)])\n@pytest.mark.parametrize(\"opname\", [\"add\", \"sub\"])\ndef test_addsub_csc(val_shape, opname):\n    op = getattr(operator, opname)\n    func = getattr(dglsp, opname)\n    ctx = 
F.ctx()\n    indptr = torch.tensor([0, 1, 1, 2, 3]).to(ctx)\n    indices = torch.tensor([1, 2, 0]).to(ctx)\n    val = torch.randn(indices.shape + val_shape).to(ctx)\n    A = dglsp.from_csc(indptr, indices, val)\n\n    indptr = torch.tensor([0, 1, 1, 2, 2]).to(ctx)\n    indices = torch.tensor([1, 0]).to(ctx)\n    val = torch.randn(indices.shape + val_shape).to(ctx)\n    B = dglsp.from_csc(indptr, indices, val, shape=A.shape)\n\n    C1 = op(A, B).to_dense()\n    C2 = func(A, B).to_dense()\n    dense_C = op(A.to_dense(), B.to_dense())\n\n    assert torch.allclose(dense_C, C1)\n    assert torch.allclose(dense_C, C2)\n\n    with pytest.raises(TypeError):\n        op(A, 2)\n    with pytest.raises(TypeError):\n        op(2, A)\n\n\n@pytest.mark.parametrize(\"val_shape\", [(), (2,)])\n@pytest.mark.parametrize(\"opname\", [\"add\", \"sub\"])\ndef test_addsub_diag(val_shape, opname):\n    op = getattr(operator, opname)\n    func = getattr(dglsp, opname)\n    ctx = F.ctx()\n    shape = (3, 4)\n    val_shape = (shape[0],) + val_shape\n    D1 = dglsp.diag(torch.randn(val_shape).to(ctx), shape=shape)\n    D2 = dglsp.diag(torch.randn(val_shape).to(ctx), shape=shape)\n\n    C1 = op(D1, D2).to_dense()\n    C2 = func(D1, D2).to_dense()\n    dense_C = op(D1.to_dense(), D2.to_dense())\n\n    assert torch.allclose(dense_C, C1)\n    assert torch.allclose(dense_C, C2)\n\n    with pytest.raises(TypeError):\n        op(D1, 2)\n    with pytest.raises(TypeError):\n        op(2, D1)\n\n\n@pytest.mark.parametrize(\"val_shape\", [(), (2,)])\ndef test_add_sparse_diag(val_shape):\n    ctx = F.ctx()\n    row = torch.tensor([1, 0, 2]).to(ctx)\n    col = torch.tensor([0, 3, 2]).to(ctx)\n    val = torch.randn(row.shape + val_shape).to(ctx)\n    A = dglsp.from_coo(row, col, val)\n\n    shape = (3, 4)\n    val_shape = (shape[0],) + val_shape\n    D = dglsp.diag(torch.randn(val_shape).to(ctx), shape=shape)\n\n    sum1 = (A + D).to_dense()\n    sum2 = (D + A).to_dense()\n    sum3 = dglsp.add(A, 
D).to_dense()\n    sum4 = dglsp.add(D, A).to_dense()\n    dense_sum = A.to_dense() + D.to_dense()\n\n    assert torch.allclose(dense_sum, sum1)\n    assert torch.allclose(dense_sum, sum2)\n    assert torch.allclose(dense_sum, sum3)\n    assert torch.allclose(dense_sum, sum4)\n\n\n@pytest.mark.parametrize(\"val_shape\", [(), (2,)])\ndef test_sub_sparse_diag(val_shape):\n    ctx = F.ctx()\n    row = torch.tensor([1, 0, 2]).to(ctx)\n    col = torch.tensor([0, 3, 2]).to(ctx)\n    val = torch.randn(row.shape + val_shape).to(ctx)\n    A = dglsp.from_coo(row, col, val)\n\n    shape = (3, 4)\n    val_shape = (shape[0],) + val_shape\n    D = dglsp.diag(torch.randn(val_shape).to(ctx), shape=shape)\n\n    diff1 = (A - D).to_dense()\n    diff2 = (D - A).to_dense()\n    diff3 = dglsp.sub(A, D).to_dense()\n    diff4 = dglsp.sub(D, A).to_dense()\n    dense_diff = A.to_dense() - D.to_dense()\n\n    assert torch.allclose(dense_diff, diff1)\n    assert torch.allclose(dense_diff, -diff2)\n    assert torch.allclose(dense_diff, diff3)\n    assert torch.allclose(dense_diff, -diff4)\n\n\n@pytest.mark.parametrize(\"op\", [\"pow\"])\ndef test_error_op_sparse_diag(op):\n    ctx = F.ctx()\n    row = torch.tensor([1, 0, 2]).to(ctx)\n    col = torch.tensor([0, 3, 2]).to(ctx)\n    val = torch.randn(row.shape).to(ctx)\n    A = dglsp.from_coo(row, col, val)\n\n    shape = (3, 4)\n    D = dglsp.diag(torch.randn(row.shape[0]).to(ctx), shape=shape)\n\n    with pytest.raises(TypeError):\n        getattr(operator, op)(A, D)\n    with pytest.raises(TypeError):\n        getattr(operator, op)(D, A)\n"
  },
  {
    "path": "tests/python/pytorch/sparse/test_elementwise_op_sp.py",
    "content": "import sys\n\nimport backend as F\nimport pytest\nimport torch\n\nfrom dgl.sparse import div, from_coo, mul, power, spmatrix, val_like\n\nfrom .utils import (\n    rand_coo,\n    rand_csc,\n    rand_csr,\n    rand_diag,\n    sparse_matrix_to_dense,\n)\n\n\ndef all_close_sparse(A, row, col, val, shape):\n    rowA, colA = A.coo()\n    valA = A.val\n    assert torch.allclose(rowA, row)\n    assert torch.allclose(colA, col)\n    assert torch.allclose(valA, val)\n    assert A.shape == shape\n\n\n@pytest.mark.parametrize(\n    \"v_scalar\", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]\n)\ndef test_muldiv_scalar(v_scalar):\n    ctx = F.ctx()\n    row = torch.tensor([1, 0, 2]).to(ctx)\n    col = torch.tensor([0, 3, 2]).to(ctx)\n    val = torch.randn(len(row)).to(ctx)\n    A1 = from_coo(row, col, val, shape=(3, 4))\n\n    # A * v\n    A2 = A1 * v_scalar\n    assert torch.allclose(A1.val * v_scalar, A2.val, rtol=1e-4, atol=1e-4)\n    assert A1.shape == A2.shape\n\n    # v * A\n    A2 = v_scalar * A1\n    assert torch.allclose(A1.val * v_scalar, A2.val, rtol=1e-4, atol=1e-4)\n    assert A1.shape == A2.shape\n\n    # A / v\n    A2 = A1 / v_scalar\n    assert torch.allclose(A1.val / v_scalar, A2.val, rtol=1e-4, atol=1e-4)\n    assert A1.shape == A2.shape\n\n    # v / A\n    with pytest.raises(TypeError):\n        v_scalar / A1\n\n\n@pytest.mark.parametrize(\"val_shape\", [(3,), (3, 2)])\ndef test_pow(val_shape):\n    # A ** v\n    ctx = F.ctx()\n    row = torch.tensor([1, 0, 2]).to(ctx)\n    col = torch.tensor([0, 3, 2]).to(ctx)\n    val = torch.randn(val_shape).to(ctx)\n    A = from_coo(row, col, val, shape=(3, 4))\n    exponent = 2\n    A_new = A**exponent\n    assert torch.allclose(A_new.val, val**exponent)\n    assert A_new.shape == A.shape\n    new_row, new_col = A_new.coo()\n    assert torch.allclose(new_row, row)\n    assert torch.allclose(new_col, col)\n\n    # power(A, v)\n    A_new = power(A, exponent)\n    assert torch.allclose(A_new.val, 
val**exponent)\n    assert A_new.shape == A.shape\n    new_row, new_col = A_new.coo()\n    assert torch.allclose(new_row, row)\n    assert torch.allclose(new_col, col)\n\n\n@pytest.mark.parametrize(\"op\", [\"add\", \"sub\"])\n@pytest.mark.parametrize(\n    \"v_scalar\", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]\n)\ndef test_error_op_scalar(op, v_scalar):\n    ctx = F.ctx()\n    row = torch.tensor([1, 0, 2]).to(ctx)\n    col = torch.tensor([0, 3, 2]).to(ctx)\n    val = torch.randn(len(row)).to(ctx)\n    A = from_coo(row, col, val, shape=(3, 4))\n\n    with pytest.raises(TypeError):\n        A + v_scalar\n    with pytest.raises(TypeError):\n        v_scalar + A\n\n    with pytest.raises(TypeError):\n        A - v_scalar\n    with pytest.raises(TypeError):\n        v_scalar - A\n\n\n@pytest.mark.parametrize(\n    \"create_func1\", [rand_coo, rand_csr, rand_csc, rand_diag]\n)\n@pytest.mark.parametrize(\n    \"create_func2\", [rand_coo, rand_csr, rand_csc, rand_diag]\n)\n@pytest.mark.parametrize(\"shape\", [(5, 5), (5, 3)])\n@pytest.mark.parametrize(\"nnz1\", [5, 15])\n@pytest.mark.parametrize(\"nnz2\", [1, 14])\n@pytest.mark.parametrize(\"nz_dim\", [None, 3])\ndef test_spspmul(create_func1, create_func2, shape, nnz1, nnz2, nz_dim):\n    dev = F.ctx()\n    A = create_func1(shape, nnz1, dev, nz_dim)\n    B = create_func2(shape, nnz2, dev, nz_dim)\n    C = mul(A, B)\n    assert not C.has_duplicate()\n\n    DA = sparse_matrix_to_dense(A)\n    DB = sparse_matrix_to_dense(B)\n    DC = DA * DB\n\n    grad = torch.rand_like(C.val)\n    C.val.backward(grad)\n    DC_grad = sparse_matrix_to_dense(val_like(C, grad))\n    DC.backward(DC_grad)\n\n    assert torch.allclose(sparse_matrix_to_dense(C), DC, atol=1e-05)\n    assert torch.allclose(\n        val_like(A, A.val.grad).to_dense(), DA.grad, atol=1e-05\n    )\n    assert torch.allclose(\n        val_like(B, B.val.grad).to_dense(), DB.grad, atol=1e-05\n    )\n\n\n@pytest.mark.parametrize(\n    \"create_func\", [rand_coo, 
rand_csr, rand_csc, rand_diag]\n)\n@pytest.mark.parametrize(\"shape\", [(5, 5), (5, 3)])\n@pytest.mark.parametrize(\"nnz\", [1, 14])\n@pytest.mark.parametrize(\"nz_dim\", [None, 3])\ndef test_spspdiv(create_func, nnz, shape, nz_dim):\n    dev = F.ctx()\n    A = create_func(shape, nnz, dev, nz_dim)\n\n    perm = torch.randperm(A.nnz, device=dev)\n    rperm = torch.argsort(perm)\n    B = spmatrix(A.indices()[:, perm], A.val[perm], A.shape)\n    C = div(A, B)\n    assert not C.has_duplicate()\n    assert torch.allclose(C.val, A.val / B.val[rperm], atol=1e-05)\n    assert torch.allclose(C.indices(), A.indices(), atol=1e-05)\n\n    # No need to test backward here, since it is handled by Pytorch\n"
  },
  {
    "path": "tests/python/pytorch/sparse/test_matmul.py",
    "content": "import warnings\n\nimport backend as F\nimport pytest\nimport torch\n\nfrom dgl.sparse import bspmm, diag, from_coo, val_like\nfrom dgl.sparse.matmul import matmul\n\nfrom .utils import (\n    clone_detach_and_grad,\n    dense_mask,\n    rand_coo,\n    rand_csc,\n    rand_csr,\n    rand_stride,\n    sparse_matrix_to_dense,\n    sparse_matrix_to_torch_sparse,\n)\n\n\ndef _torch_sparse_mm(torch_A1, torch_A2):\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        return torch.sparse.mm(torch_A1, torch_A2)\n\n\n@pytest.mark.parametrize(\"create_func\", [rand_coo, rand_csr, rand_csc])\n@pytest.mark.parametrize(\"shape\", [(2, 7), (5, 2)])\n@pytest.mark.parametrize(\"nnz\", [1, 10])\n@pytest.mark.parametrize(\"out_dim\", [None, 10])\ndef test_spmm(create_func, shape, nnz, out_dim):\n    dev = F.ctx()\n    A = create_func(shape, nnz, dev)\n    if out_dim is not None:\n        X = torch.randn(shape[1], out_dim, requires_grad=True, device=dev)\n    else:\n        X = torch.randn(shape[1], requires_grad=True, device=dev)\n\n    X = rand_stride(X)\n    sparse_result = matmul(A, X)\n    grad = torch.randn_like(sparse_result)\n    sparse_result.backward(grad)\n\n    adj = sparse_matrix_to_dense(A)\n    XX = clone_detach_and_grad(X)\n    dense_result = torch.matmul(adj, XX)\n    if out_dim is None:\n        dense_result = dense_result.view(-1)\n    dense_result.backward(grad)\n    assert torch.allclose(sparse_result, dense_result, atol=1e-05)\n    assert torch.allclose(X.grad, XX.grad, atol=1e-05)\n    assert torch.allclose(\n        dense_mask(adj.grad, A),\n        sparse_matrix_to_dense(val_like(A, A.val.grad)),\n        atol=1e-05,\n    )\n\n\n@pytest.mark.parametrize(\"create_func\", [rand_coo, rand_csr, rand_csc])\n@pytest.mark.parametrize(\"shape\", [(2, 7), (5, 2)])\n@pytest.mark.parametrize(\"nnz\", [1, 10])\ndef test_bspmm(create_func, shape, nnz):\n    dev = F.ctx()\n    A = create_func(shape, 
nnz, dev, 2)\n    X = torch.randn(shape[1], 10, 2, requires_grad=True, device=dev)\n    X = rand_stride(X)\n\n    sparse_result = matmul(A, X)\n    grad = torch.randn_like(sparse_result)\n    sparse_result.backward(grad)\n\n    XX = clone_detach_and_grad(X)\n    torch_A = A.to_dense().clone().detach().requires_grad_()\n    torch_result = torch_A.permute(2, 0, 1) @ XX.permute(2, 0, 1)\n\n    torch_result.backward(grad.permute(2, 0, 1))\n    assert torch.allclose(\n        sparse_result.permute(2, 0, 1), torch_result, atol=1e-05\n    )\n    assert torch.allclose(X.grad, XX.grad, atol=1e-05)\n    assert torch.allclose(\n        dense_mask(torch_A.grad, A),\n        sparse_matrix_to_dense(val_like(A, A.val.grad)),\n        atol=1e-05,\n    )\n\n\n@pytest.mark.parametrize(\"create_func1\", [rand_coo, rand_csr, rand_csc])\n@pytest.mark.parametrize(\"create_func2\", [rand_coo, rand_csr, rand_csc])\n@pytest.mark.parametrize(\"shape_n_m\", [(5, 5), (5, 6)])\n@pytest.mark.parametrize(\"shape_k\", [3, 4])\n@pytest.mark.parametrize(\"nnz1\", [1, 10])\n@pytest.mark.parametrize(\"nnz2\", [1, 10])\ndef test_spspmm(create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2):\n    dev = F.ctx()\n    shape1 = shape_n_m\n    shape2 = (shape_n_m[1], shape_k)\n    A1 = create_func1(shape1, nnz1, dev)\n    A2 = create_func2(shape2, nnz2, dev)\n    A3 = matmul(A1, A2)\n    grad = torch.randn_like(A3.val)\n    A3.val.backward(grad)\n\n    torch_A1 = sparse_matrix_to_torch_sparse(A1)\n    torch_A2 = sparse_matrix_to_torch_sparse(A2)\n    torch_A3 = _torch_sparse_mm(torch_A1, torch_A2)\n    torch_A3_grad = sparse_matrix_to_torch_sparse(A3, grad)\n    torch_A3.backward(torch_A3_grad)\n\n    with torch.no_grad():\n        assert torch.allclose(A3.to_dense(), torch_A3.to_dense(), atol=1e-05)\n        assert torch.allclose(\n            val_like(A1, A1.val.grad).to_dense(),\n            torch_A1.grad.to_dense(),\n            atol=1e-05,\n        )\n        assert torch.allclose(\n            
val_like(A2, A2.val.grad).to_dense(),\n            torch_A2.grad.to_dense(),\n            atol=1e-05,\n        )\n\n\ndef test_spspmm_duplicate():\n    dev = F.ctx()\n\n    row = torch.tensor([1, 0, 0, 0, 1]).to(dev)\n    col = torch.tensor([1, 1, 1, 2, 2]).to(dev)\n    val = torch.randn(len(row)).to(dev)\n    shape = (4, 4)\n    A1 = from_coo(row, col, val, shape)\n\n    row = torch.tensor([1, 0, 0, 1]).to(dev)\n    col = torch.tensor([1, 1, 2, 2]).to(dev)\n    val = torch.randn(len(row)).to(dev)\n    shape = (4, 4)\n    A2 = from_coo(row, col, val, shape)\n\n    try:\n        matmul(A1, A2)\n    except:\n        pass\n    else:\n        assert False, \"Should raise error.\"\n\n    try:\n        matmul(A2, A1)\n    except:\n        pass\n    else:\n        assert False, \"Should raise error.\"\n\n\n@pytest.mark.parametrize(\"create_func\", [rand_coo, rand_csr, rand_csc])\n@pytest.mark.parametrize(\"sparse_shape\", [(5, 5), (5, 6)])\n@pytest.mark.parametrize(\"nnz\", [1, 10])\ndef test_sparse_diag_mm(create_func, sparse_shape, nnz):\n    dev = F.ctx()\n    diag_shape = sparse_shape[1], sparse_shape[1]\n    A = create_func(sparse_shape, nnz, dev)\n    diag_val = torch.randn(sparse_shape[1], device=dev, requires_grad=True)\n    D = diag(diag_val, diag_shape)\n    B = matmul(A, D)\n    grad = torch.randn_like(B.val)\n    B.val.backward(grad)\n\n    torch_A = sparse_matrix_to_torch_sparse(A)\n    torch_D = sparse_matrix_to_torch_sparse(D)\n    torch_B = _torch_sparse_mm(torch_A, torch_D)\n    torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)\n    torch_B.backward(torch_B_grad)\n\n    with torch.no_grad():\n        assert torch.allclose(B.to_dense(), torch_B.to_dense(), atol=1e-05)\n        assert torch.allclose(\n            val_like(A, A.val.grad).to_dense(),\n            torch_A.grad.to_dense(),\n            atol=1e-05,\n        )\n        assert torch.allclose(\n            diag(D.val.grad, D.shape).to_dense(),\n            torch_D.grad.to_dense(),\n            
atol=1e-05,\n        )\n\n\n@pytest.mark.parametrize(\"create_func\", [rand_coo, rand_csr, rand_csc])\n@pytest.mark.parametrize(\"sparse_shape\", [(5, 5), (5, 6)])\n@pytest.mark.parametrize(\"nnz\", [1, 10])\ndef test_diag_sparse_mm(create_func, sparse_shape, nnz):\n    dev = F.ctx()\n    diag_shape = sparse_shape[0], sparse_shape[0]\n    A = create_func(sparse_shape, nnz, dev)\n    diag_val = torch.randn(sparse_shape[0], device=dev, requires_grad=True)\n    D = diag(diag_val, diag_shape)\n    B = matmul(D, A)\n    grad = torch.randn_like(B.val)\n    B.val.backward(grad)\n\n    torch_A = sparse_matrix_to_torch_sparse(A)\n    torch_D = sparse_matrix_to_torch_sparse(D)\n    torch_B = _torch_sparse_mm(torch_D, torch_A)\n    torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)\n    torch_B.backward(torch_B_grad)\n\n    with torch.no_grad():\n        assert torch.allclose(B.to_dense(), torch_B.to_dense(), atol=1e-05)\n        assert torch.allclose(\n            val_like(A, A.val.grad).to_dense(),\n            torch_A.grad.to_dense(),\n            atol=1e-05,\n        )\n        assert torch.allclose(\n            diag(D.val.grad, D.shape).to_dense(),\n            torch_D.grad.to_dense(),\n            atol=1e-05,\n        )\n"
  },
  {
    "path": "tests/python/pytorch/sparse/test_matrix_op.py",
    "content": "import backend as F\nimport pytest\nimport torch\n\nfrom .utils import (\n    rand_coo,\n    rand_csc,\n    rand_csr,\n    rand_diag,\n    sparse_matrix_to_dense,\n)\n\n\n@pytest.mark.parametrize(\n    \"create_func\", [rand_diag, rand_csr, rand_csc, rand_coo]\n)\n@pytest.mark.parametrize(\"dim\", [0, 1])\n@pytest.mark.parametrize(\"index\", [None, (1, 3), (4, 0, 2)])\ndef test_compact(create_func, dim, index):\n    ctx = F.ctx()\n    shape = (5, 5)\n    ans_idx = []\n    if index is not None:\n        ans_idx = list(dict.fromkeys(index))\n        index = torch.tensor(index).to(ctx)\n\n    A = create_func(shape, 8, ctx)\n\n    A_compact, ret_id = A.compact(dim, index)\n    A_compact_dense = sparse_matrix_to_dense(A_compact)\n\n    A_dense = sparse_matrix_to_dense(A)\n\n    for i in range(shape[dim]):\n        if dim == 0:\n            row = list(A_dense[i, :].nonzero().reshape(-1))\n        else:\n            row = list(A_dense[:, i].nonzero().reshape(-1))\n        if (i not in list(ans_idx)) and len(row) > 0:\n            ans_idx.append(i)\n    if len(ans_idx):\n        ans_idx = torch.tensor(ans_idx).to(ctx)\n    A_dense_select = sparse_matrix_to_dense(A.index_select(dim, ans_idx))\n\n    assert A_compact_dense.shape == A_dense_select.shape\n    assert torch.allclose(A_compact_dense, A_dense_select)\n    assert torch.allclose(ans_idx, ret_id)\n"
  },
  {
    "path": "tests/python/pytorch/sparse/test_reduction.py",
    "content": "import doctest\nimport operator\nimport sys\n\nimport backend as F\n\nimport dgl.sparse as dglsp\nimport pytest\nimport torch\n\ndgl_op_map = {\n    \"sum\": \"sum\",\n    \"amin\": \"smin\",\n    \"amax\": \"smax\",\n    \"mean\": \"smean\",\n    \"prod\": \"sprod\",\n}\ndefault_entry = {\n    \"sum\": 0,\n    \"amin\": float(\"inf\"),\n    \"amax\": float(\"-inf\"),\n    \"mean\": 0,\n    \"prod\": 1,\n}\nbinary_op_map = {\n    \"sum\": operator.add,\n    \"amin\": torch.min,\n    \"amax\": torch.max,\n    \"mean\": operator.add,\n    \"prod\": operator.mul,\n}\n\nNUM_ROWS = 10\nNUM_COLS = 15\n\n\ndef _coalesce_dense(row, col, val, nrows, ncols, op):\n    # Sparse matrix coalescing on a dense matrix.\n    #\n    # It is done by stacking every non-zero entry on an individual slice\n    # of an (nrows x ncols x nnz), that is, construct a tensor A with\n    # shape (nrows, ncols, len(val)) where\n    #\n    #     A[row[i], col[i], i] = val[i]\n    #\n    # and then reducing on the third \"nnz\" dimension.\n    #\n    # The mask matrix M has the same sparsity pattern as A with 1 being\n    # the non-zero entries.  
This is used for division if the reduce\n    # operator is mean.\n    M = torch.zeros(NUM_ROWS, NUM_COLS, device=F.ctx())\n    A = torch.full(\n        (NUM_ROWS, NUM_COLS, 20) + val.shape[1:],\n        default_entry[op],\n        device=F.ctx(),\n        dtype=val.dtype,\n    )\n    A = torch.index_put(A, (row, col, torch.arange(20)), val)\n    for i in range(20):\n        M[row[i], col[i]] += 1\n    if op == \"mean\":\n        A = A.sum(2)\n    else:\n        A = getattr(A, op)(2)\n    M = M.view(NUM_ROWS, NUM_COLS, *([1] * (val.dim() - 1)))\n    return A, M\n\n\n# Add docstring tests of dglsp.reduction to unit tests\n@pytest.mark.parametrize(\n    \"func\", [\"reduce\", \"sum\", \"smin\", \"smax\", \"sprod\", \"smean\"]\n)\ndef test_docstring(func):\n    globs = {\"torch\": torch, \"dglsp\": dglsp}\n    runner = doctest.DebugRunner()\n    finder = doctest.DocTestFinder()\n    obj = getattr(dglsp, func)\n    for test in finder.find(obj, func, globs=globs):\n        runner.run(test)\n\n\n@pytest.mark.parametrize(\"shape\", [(20,), (20, 20)])\n@pytest.mark.parametrize(\"op\", [\"sum\", \"amin\", \"amax\", \"mean\", \"prod\"])\n@pytest.mark.parametrize(\"use_reduce\", [False, True])\ndef test_reduce_all(shape, op, use_reduce):\n    row = torch.randint(0, NUM_ROWS, (20,), device=F.ctx())\n    col = torch.randint(0, NUM_COLS, (20,), device=F.ctx())\n    val = torch.randn(*shape, device=F.ctx())\n    val2 = val.clone()\n    val = val.requires_grad_()\n    val2 = val2.requires_grad_()\n    A = dglsp.from_coo(row, col, val, shape=(NUM_ROWS, NUM_COLS))\n\n    A2, M = _coalesce_dense(row, col, val2, NUM_ROWS, NUM_COLS, op)\n\n    if not use_reduce:\n        output = getattr(A, dgl_op_map[op])()\n    else:\n        output = A.reduce(rtype=dgl_op_map[op])\n\n    if op == \"mean\":\n        output2 = A2.sum((0, 1)) / M.sum()\n    elif op == \"prod\":\n        output2 = A2.prod(0).prod(0)  # prod() does not support tuple of dims\n    else:\n        output2 = getattr(A2, 
op)((0, 1))\n    assert (output - output2).abs().max() < 1e-4\n\n    head = torch.randn(*output.shape).to(val) if output.dim() > 0 else None\n    output.backward(head)\n    output2.backward(head)\n    assert (val.grad - val2.grad).abs().max() < 1e-4\n\n\n@pytest.mark.parametrize(\"shape\", [(20,), (20, 20)])\n@pytest.mark.parametrize(\"dim\", [0, 1])\n@pytest.mark.parametrize(\"empty_nnz\", [False, True])\n@pytest.mark.parametrize(\"op\", [\"sum\", \"amin\", \"amax\", \"mean\", \"prod\"])\n@pytest.mark.parametrize(\"use_reduce\", [False, True])\ndef test_reduce_along(shape, dim, empty_nnz, op, use_reduce):\n    row = torch.randint(0, NUM_ROWS, (20,), device=F.ctx())\n    col = torch.randint(0, NUM_COLS, (20,), device=F.ctx())\n    if dim == 0:\n        mask = torch.bincount(col, minlength=NUM_COLS) == 0\n    else:\n        mask = torch.bincount(row, minlength=NUM_ROWS) == 0\n    val = torch.randn(*shape, device=F.ctx())\n    val2 = val.clone()\n    val = val.requires_grad_()\n    val2 = val2.requires_grad_()\n\n    # empty_nnz controls whether at least one column or one row has no\n    # non-zero entry.\n    if empty_nnz:\n        row[row == 0] = 1\n        col[col == 0] = 1\n\n    A = dglsp.from_coo(row, col, val, shape=(NUM_ROWS, NUM_COLS))\n\n    A2, M = _coalesce_dense(row, col, val2, NUM_ROWS, NUM_COLS, op)\n\n    if not use_reduce:\n        output = getattr(A, dgl_op_map[op])(dim)\n    else:\n        output = A.reduce(dim=dim, rtype=dgl_op_map[op])\n\n    if op == \"mean\":\n        output2 = A2.sum(dim) / M.sum(dim)\n    else:\n        output2 = getattr(A2, op)(dim)\n    zero_entry_idx = (M.sum(dim) != 0).nonzero(as_tuple=True)[0]\n    output3 = torch.index_put(\n        torch.zeros_like(output2), (zero_entry_idx,), output2[zero_entry_idx]\n    )\n    assert (output - output3).abs().max() < 1e-4\n\n    head = torch.randn(*output.shape).to(val) if output.dim() > 0 else None\n    output.backward(head)\n    output3.backward(head)\n    assert (val.grad - 
val2.grad).abs().max() < 1e-4\n"
  },
  {
    "path": "tests/python/pytorch/sparse/test_sddmm.py",
    "content": "import sys\n\nimport backend as F\nimport pytest\nimport torch\n\nfrom dgl.sparse import bsddmm, sddmm\n\nfrom .utils import (\n    clone_detach_and_grad,\n    rand_coo,\n    rand_csc,\n    rand_csr,\n    rand_stride,\n)\n\n\n@pytest.mark.parametrize(\"create_func\", [rand_coo, rand_csr, rand_csc])\n@pytest.mark.parametrize(\"shape\", [(5, 5), (5, 4)])\n@pytest.mark.parametrize(\"nnz\", [2, 10])\n@pytest.mark.parametrize(\"hidden\", [1, 5])\ndef test_sddmm(create_func, shape, nnz, hidden):\n    dev = F.ctx()\n    A = create_func(shape, nnz, dev)\n    if hidden > 1:\n        B = torch.rand(shape[0], hidden, requires_grad=True, device=dev)\n        C = torch.rand(hidden, shape[1], requires_grad=True, device=dev)\n    else:\n        B = torch.rand(shape[0], requires_grad=True, device=dev)\n        C = torch.rand(shape[1], requires_grad=True, device=dev)\n\n    B = rand_stride(B)\n    C = rand_stride(C)\n\n    A_val_clone = clone_detach_and_grad(A.val)\n    dense_B = clone_detach_and_grad(B)\n    dense_C = clone_detach_and_grad(C)\n\n    sparse_result = sddmm(A, B, C)\n\n    grad = torch.rand_like(sparse_result.val)\n    sparse_result.val.backward(grad)\n\n    if hidden == 1:\n        dense_result = dense_B.view(-1, 1) @ dense_C.view(1, -1)\n    else:\n        dense_result = dense_B @ dense_C\n\n    row, col = A.coo()\n    dense_val = dense_result[row, col] * A_val_clone\n    dense_val.backward(grad)\n\n    assert torch.allclose(dense_val, sparse_result.val, atol=1e-05)\n    assert torch.allclose(dense_C.grad, C.grad, atol=1e-05)\n    assert torch.allclose(dense_B.grad, B.grad, atol=1e-05)\n    assert torch.allclose(A_val_clone.grad, A.val.grad, atol=1e-05)\n\n\n@pytest.mark.parametrize(\"create_func\", [rand_coo, rand_csr, rand_csc])\n@pytest.mark.parametrize(\"shape\", [(5, 5), (5, 4)])\n@pytest.mark.parametrize(\"nnz\", [2, 10])\n@pytest.mark.parametrize(\"nz_dim\", [2, 10])\ndef test_bsddmm(create_func, shape, nnz, nz_dim):\n    dev = F.ctx()\n    
hidden = 2\n    A = create_func(shape, nnz, dev, nz_dim)\n    B = torch.rand(shape[0], hidden, nz_dim, requires_grad=True, device=dev)\n    C = torch.rand(hidden, shape[1], nz_dim, requires_grad=True, device=dev)\n\n    B = rand_stride(B)\n    C = rand_stride(C)\n\n    A_val_clone = clone_detach_and_grad(A.val)\n    dense_B = clone_detach_and_grad(B)\n    dense_C = clone_detach_and_grad(C)\n\n    sparse_result = bsddmm(A, B, C)\n\n    grad = torch.rand_like(sparse_result.val)\n    sparse_result.val.backward(grad)\n\n    dense_result = dense_B.permute(2, 0, 1) @ dense_C.permute(2, 0, 1)\n    dense_result = dense_result.permute(1, 2, 0)\n\n    row, col = A.coo()\n    dense_val = dense_result[row, col] * A_val_clone\n    dense_val.backward(grad)\n\n    assert torch.allclose(dense_val, sparse_result.val, atol=1e-05)\n    assert torch.allclose(dense_C.grad, C.grad, atol=1e-05)\n    assert torch.allclose(dense_B.grad, B.grad, atol=1e-05)\n    assert torch.allclose(A_val_clone.grad, A.val.grad, atol=1e-05)\n"
  },
  {
    "path": "tests/python/pytorch/sparse/test_softmax.py",
    "content": "import sys\n\nimport backend as F\n\nimport dgl\nimport pytest\nimport torch\nfrom dgl.sparse import from_coo, softmax\n\n\n@pytest.mark.parametrize(\"val_D\", [None, 2])\n@pytest.mark.parametrize(\"csr\", [True, False])\n@pytest.mark.parametrize(\"dim\", [0, 1])\ndef test_softmax(val_D, csr, dim):\n    dev = F.ctx()\n    row = torch.tensor([0, 0, 1, 1]).to(dev)\n    col = torch.tensor([0, 2, 1, 2]).to(dev)\n    nnz = len(row)\n    if val_D is None:\n        val = torch.randn(nnz).to(dev)\n    else:\n        val = torch.randn(nnz, val_D).to(dev)\n\n    val_sparse = val.clone().requires_grad_()\n    A = from_coo(row, col, val_sparse)\n\n    if csr:\n        # Test CSR\n        A.csr()\n\n    A_max = softmax(A, dim)\n    if dim == 1:\n        g = dgl.graph((col, row), num_nodes=max(A.shape))\n    else:\n        g = dgl.graph((row, col), num_nodes=max(A.shape))\n    val_g = val.clone().requires_grad_()\n    score = dgl.nn.functional.edge_softmax(g, val_g)\n    assert torch.allclose(A_max.val, score, atol=1e-05)\n\n    grad = torch.randn_like(score).to(dev)\n    A_max.val.backward(grad)\n    score.backward(grad)\n    assert torch.allclose(A.val.grad, val_g.grad, atol=1e-05)\n"
  },
  {
    "path": "tests/python/pytorch/sparse/test_sparse_matrix.py",
    "content": "import unittest\nimport warnings\n\nimport backend as F\nimport pytest\nimport torch\n\nfrom dgl.sparse import (\n    diag,\n    from_coo,\n    from_csc,\n    from_csr,\n    from_torch_sparse,\n    identity,\n    to_torch_sparse_coo,\n    to_torch_sparse_csc,\n    to_torch_sparse_csr,\n    val_like,\n)\n\nfrom .utils import (\n    rand_coo,\n    rand_csc,\n    rand_csr,\n    rand_diag,\n    sparse_matrix_to_dense,\n)\n\n\ndef _torch_sparse_csr_tensor(indptr, indices, val, torch_sparse_shape):\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\", category=UserWarning)\n        return torch.sparse_csr_tensor(indptr, indices, val, torch_sparse_shape)\n\n\n@pytest.mark.parametrize(\"dense_dim\", [None, 4])\n@pytest.mark.parametrize(\"row\", [(0, 0, 1, 2), (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"col\", [(0, 1, 2, 2), (1, 3, 3, 4)])\n@pytest.mark.parametrize(\"shape\", [None, (5, 5), (5, 6)])\ndef test_from_coo(dense_dim, row, col, shape):\n    val_shape = (len(row),)\n    if dense_dim is not None:\n        val_shape += (dense_dim,)\n    ctx = F.ctx()\n    val = torch.randn(val_shape).to(ctx)\n    row = torch.tensor(row).to(ctx)\n    col = torch.tensor(col).to(ctx)\n    mat = from_coo(row, col, val, shape)\n\n    if shape is None:\n        shape = (torch.max(row).item() + 1, torch.max(col).item() + 1)\n\n    mat_row, mat_col = mat.coo()\n    mat_val = mat.val\n\n    assert mat.shape == shape\n    assert mat.nnz == row.numel()\n    assert mat.dtype == val.dtype\n    assert torch.allclose(mat_val, val)\n    assert torch.allclose(mat_row, row)\n    assert torch.allclose(mat_col, col)\n\n\n@pytest.mark.parametrize(\"dense_dim\", [None, 4])\n@pytest.mark.parametrize(\"indptr\", [(0, 0, 1, 4), (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"indices\", [(0, 1, 2, 3), (1, 2, 3, 4)])\n@pytest.mark.parametrize(\"shape\", [None, (3, 5)])\ndef test_from_csr(dense_dim, indptr, indices, shape):\n    val_shape = (len(indices),)\n    if dense_dim 
is not None:\n        val_shape += (dense_dim,)\n    ctx = F.ctx()\n    val = torch.randn(val_shape).to(ctx)\n    indptr = torch.tensor(indptr).to(ctx)\n    indices = torch.tensor(indices).to(ctx)\n    mat = from_csr(indptr, indices, val, shape)\n\n    if shape is None:\n        shape = (indptr.numel() - 1, torch.max(indices).item() + 1)\n\n    assert mat.device == val.device\n    assert mat.shape == shape\n    assert mat.nnz == indices.numel()\n    assert mat.dtype == val.dtype\n    mat_indptr, mat_indices, value_indices = mat.csr()\n    mat_val = mat.val if value_indices is None else mat.val[value_indices]\n    assert torch.allclose(mat_indptr, indptr)\n    assert torch.allclose(mat_indices, indices)\n    assert torch.allclose(mat_val, val)\n\n\n@pytest.mark.parametrize(\"dense_dim\", [None, 4])\n@pytest.mark.parametrize(\"indptr\", [(0, 0, 1, 4), (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"indices\", [(0, 1, 2, 3), (1, 2, 3, 4)])\n@pytest.mark.parametrize(\"shape\", [None, (5, 3)])\ndef test_from_csc(dense_dim, indptr, indices, shape):\n    val_shape = (len(indices),)\n    if dense_dim is not None:\n        val_shape += (dense_dim,)\n    ctx = F.ctx()\n    val = torch.randn(val_shape).to(ctx)\n    indptr = torch.tensor(indptr).to(ctx)\n    indices = torch.tensor(indices).to(ctx)\n    mat = from_csc(indptr, indices, val, shape)\n\n    if shape is None:\n        shape = (torch.max(indices).item() + 1, indptr.numel() - 1)\n\n    assert mat.device == val.device\n    assert mat.shape == shape\n    assert mat.nnz == indices.numel()\n    assert mat.dtype == val.dtype\n    mat_indptr, mat_indices, value_indices = mat.csc()\n    mat_val = mat.val if value_indices is None else mat.val[value_indices]\n    assert torch.allclose(mat_indptr, indptr)\n    assert torch.allclose(mat_indices, indices)\n    assert torch.allclose(mat_val, val)\n\n\n@pytest.mark.parametrize(\"val_shape\", [(3), (3, 2)])\ndef test_dense(val_shape):\n    ctx = F.ctx()\n\n    row = torch.tensor([1, 1, 
2]).to(ctx)\n    col = torch.tensor([2, 4, 3]).to(ctx)\n    val = torch.randn(val_shape).to(ctx)\n    A = from_coo(row, col, val)\n    A_dense = A.to_dense()\n\n    shape = A.shape + val.shape[1:]\n    mat = torch.zeros(shape, device=ctx)\n    mat[row, col] = val\n    assert torch.allclose(A_dense, mat)\n\n\n@pytest.mark.parametrize(\"dense_dim\", [None, 4])\n@pytest.mark.parametrize(\"indptr\", [(0, 0, 1, 4), (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"indices\", [(0, 1, 2, 3), (1, 4, 3, 2)])\n@pytest.mark.parametrize(\"shape\", [None, (3, 5)])\ndef test_csr_to_coo(dense_dim, indptr, indices, shape):\n    ctx = F.ctx()\n    val_shape = (len(indices),)\n    if dense_dim is not None:\n        val_shape += (dense_dim,)\n    val = torch.randn(val_shape).to(ctx)\n    indptr = torch.tensor(indptr).to(ctx)\n    indices = torch.tensor(indices).to(ctx)\n    mat = from_csr(indptr, indices, val, shape)\n\n    if shape is None:\n        shape = (indptr.numel() - 1, torch.max(indices).item() + 1)\n\n    row = (\n        torch.arange(0, indptr.shape[0] - 1)\n        .to(ctx)\n        .repeat_interleave(torch.diff(indptr))\n    )\n    col = indices\n    mat_row, mat_col = mat.coo()\n    mat_val = mat.val\n\n    assert mat.shape == shape\n    assert mat.nnz == row.numel()\n    assert mat.device == row.device\n    assert mat.dtype == val.dtype\n    assert torch.allclose(mat_val, val)\n    assert torch.allclose(mat_row, row)\n    assert torch.allclose(mat_col, col)\n\n\n@pytest.mark.parametrize(\"dense_dim\", [None, 4])\n@pytest.mark.parametrize(\"indptr\", [(0, 0, 1, 4), (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"indices\", [(0, 1, 2, 3), (1, 4, 3, 2)])\n@pytest.mark.parametrize(\"shape\", [None, (5, 3)])\ndef test_csc_to_coo(dense_dim, indptr, indices, shape):\n    ctx = F.ctx()\n    val_shape = (len(indices),)\n    if dense_dim is not None:\n        val_shape += (dense_dim,)\n    val = torch.randn(val_shape).to(ctx)\n    indptr = torch.tensor(indptr).to(ctx)\n    indices = 
torch.tensor(indices).to(ctx)\n    mat = from_csc(indptr, indices, val, shape)\n\n    if shape is None:\n        shape = (torch.max(indices).item() + 1, indptr.numel() - 1)\n\n    col = (\n        torch.arange(0, indptr.shape[0] - 1)\n        .to(ctx)\n        .repeat_interleave(torch.diff(indptr))\n    )\n    row = indices\n    mat_row, mat_col = mat.coo()\n    mat_val = mat.val\n\n    assert mat.shape == shape\n    assert mat.nnz == row.numel()\n    assert mat.device == row.device\n    assert mat.dtype == val.dtype\n    assert torch.allclose(mat_val, val)\n    assert torch.allclose(mat_row, row)\n    assert torch.allclose(mat_col, col)\n\n\ndef _scatter_add(a, index, v=1):\n    index = index.tolist()\n    for i in index:\n        a[i] += v\n    return a\n\n\n@pytest.mark.parametrize(\"dense_dim\", [None, 4])\n@pytest.mark.parametrize(\"row\", [(0, 0, 1, 2), (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"col\", [(0, 1, 2, 2), (1, 3, 3, 4)])\n@pytest.mark.parametrize(\"shape\", [None, (5, 5), (5, 6)])\ndef test_coo_to_csr(dense_dim, row, col, shape):\n    val_shape = (len(row),)\n    if dense_dim is not None:\n        val_shape += (dense_dim,)\n    ctx = F.ctx()\n    val = torch.randn(val_shape).to(ctx)\n    row = torch.tensor(row).to(ctx)\n    col = torch.tensor(col).to(ctx)\n    mat = from_coo(row, col, val, shape)\n\n    if shape is None:\n        shape = (torch.max(row).item() + 1, torch.max(col).item() + 1)\n\n    mat_indptr, mat_indices, value_indices = mat.csr()\n    mat_val = mat.val if value_indices is None else mat.val[value_indices]\n    indptr = torch.zeros(shape[0] + 1).to(ctx)\n    indptr = _scatter_add(indptr, row + 1)\n    indptr = torch.cumsum(indptr, 0).long()\n    indices = col\n\n    assert mat.shape == shape\n    assert mat.nnz == row.numel()\n    assert mat.dtype == val.dtype\n    assert torch.allclose(mat_val, val)\n    assert torch.allclose(mat_indptr, indptr)\n    assert torch.allclose(mat_indices, 
indices)\n\n\n@pytest.mark.parametrize(\"dense_dim\", [None, 4])\n@pytest.mark.parametrize(\"indptr\", [(0, 0, 1, 4), (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"indices\", [(0, 1, 2, 3), (1, 4, 3, 2)])\n@pytest.mark.parametrize(\"shape\", [None, (5, 3)])\ndef test_csc_to_csr(dense_dim, indptr, indices, shape):\n    ctx = F.ctx()\n    val_shape = (len(indices),)\n    if dense_dim is not None:\n        val_shape += (dense_dim,)\n    val = torch.randn(val_shape).to(ctx)\n    indptr = torch.tensor(indptr).to(ctx)\n    indices = torch.tensor(indices).to(ctx)\n    mat = from_csc(indptr, indices, val, shape)\n    mat_indptr, mat_indices, value_indices = mat.csr()\n    mat_val = mat.val if value_indices is None else mat.val[value_indices]\n\n    if shape is None:\n        shape = (torch.max(indices).item() + 1, indptr.numel() - 1)\n\n    col = (\n        torch.arange(0, indptr.shape[0] - 1)\n        .to(ctx)\n        .repeat_interleave(torch.diff(indptr))\n    )\n    row = indices\n    row, sort_index = row.sort(stable=True)\n    col = col[sort_index]\n    val = val[sort_index]\n    indptr = torch.zeros(shape[0] + 1).to(ctx)\n    indptr = _scatter_add(indptr, row + 1)\n    indptr = torch.cumsum(indptr, 0).long()\n    indices = col\n\n    assert mat.shape == shape\n    assert mat.nnz == row.numel()\n    assert mat.device == row.device\n    assert mat.dtype == val.dtype\n    assert torch.allclose(mat_val, val)\n    assert torch.allclose(mat_indptr, indptr)\n    assert torch.allclose(mat_indices, indices)\n\n\n@pytest.mark.parametrize(\"dense_dim\", [None, 4])\n@pytest.mark.parametrize(\"row\", [(0, 0, 1, 2), (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"col\", [(0, 1, 2, 2), (1, 3, 3, 4)])\n@pytest.mark.parametrize(\"shape\", [None, (5, 5), (5, 6)])\ndef test_coo_to_csc(dense_dim, row, col, shape):\n    val_shape = (len(row),)\n    if dense_dim is not None:\n        val_shape += (dense_dim,)\n    ctx = F.ctx()\n    val = torch.randn(val_shape).to(ctx)\n    row = 
torch.tensor(row).to(ctx)\n    col = torch.tensor(col).to(ctx)\n    mat = from_coo(row, col, val, shape)\n\n    if shape is None:\n        shape = (torch.max(row).item() + 1, torch.max(col).item() + 1)\n\n    mat_indptr, mat_indices, value_indices = mat.csc()\n    mat_val = mat.val if value_indices is None else mat.val[value_indices]\n    indptr = torch.zeros(shape[1] + 1).to(ctx)\n    _scatter_add(indptr, col + 1)\n    indptr = torch.cumsum(indptr, 0).long()\n    indices = row\n\n    assert mat.shape == shape\n    assert mat.nnz == row.numel()\n    assert mat.dtype == val.dtype\n    assert torch.allclose(mat_val, val)\n    assert torch.allclose(mat_indptr, indptr)\n    assert torch.allclose(mat_indices, indices)\n\n\n@pytest.mark.parametrize(\"dense_dim\", [None, 4])\n@pytest.mark.parametrize(\"indptr\", [(0, 0, 1, 4), (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"indices\", [(0, 1, 2, 3), (1, 2, 3, 4)])\n@pytest.mark.parametrize(\"shape\", [None, (3, 5)])\ndef test_csr_to_csc(dense_dim, indptr, indices, shape):\n    val_shape = (len(indices),)\n    if dense_dim is not None:\n        val_shape += (dense_dim,)\n    ctx = F.ctx()\n    val = torch.randn(val_shape).to(ctx)\n    indptr = torch.tensor(indptr).to(ctx)\n    indices = torch.tensor(indices).to(ctx)\n    mat = from_csr(indptr, indices, val, shape)\n    mat_indptr, mat_indices, value_indices = mat.csc()\n    mat_val = mat.val if value_indices is None else mat.val[value_indices]\n\n    if shape is None:\n        shape = (indptr.numel() - 1, torch.max(indices).item() + 1)\n\n    row = (\n        torch.arange(0, indptr.shape[0] - 1)\n        .to(ctx)\n        .repeat_interleave(torch.diff(indptr))\n    )\n\n    col = indices\n    col, sort_index = col.sort(stable=True)\n    row = row[sort_index]\n    val = val[sort_index]\n    indptr = torch.zeros(shape[1] + 1).to(ctx)\n    indptr = _scatter_add(indptr, col + 1)\n    indptr = torch.cumsum(indptr, 0).long()\n    indices = row\n\n    assert mat.shape == shape\n    
assert mat.nnz == row.numel()\n    assert mat.device == row.device\n    assert mat.dtype == val.dtype\n    assert torch.allclose(mat_val, val)\n    assert torch.allclose(mat_indptr, indptr)\n    assert torch.allclose(mat_indices, indices)\n\n\n@pytest.mark.parametrize(\"shape\", [(3, 5), (5, 5), (5, 4)])\ndef test_diag_conversions(shape):\n    n_rows, n_cols = shape\n    nnz = min(shape)\n    ctx = F.ctx()\n    val = torch.randn(nnz).to(ctx)\n    D = diag(val, shape)\n    row, col = D.coo()\n    assert torch.allclose(row, torch.arange(nnz).to(ctx))\n    assert torch.allclose(col, torch.arange(nnz).to(ctx))\n\n    indptr, indices, _ = D.csr()\n    exp_indptr = list(range(0, nnz + 1)) + [nnz] * (n_rows - nnz)\n    assert torch.allclose(indptr, torch.tensor(exp_indptr).to(ctx))\n    assert torch.allclose(indices, torch.arange(nnz).to(ctx))\n\n    indptr, indices, _ = D.csc()\n    exp_indptr = list(range(0, nnz + 1)) + [nnz] * (n_cols - nnz)\n    assert torch.allclose(indptr, torch.tensor(exp_indptr).to(ctx))\n    assert torch.allclose(indices, torch.arange(nnz).to(ctx))\n\n\n@pytest.mark.parametrize(\"val_shape\", [(3), (3, 2)])\n@pytest.mark.parametrize(\"shape\", [(3, 5), (5, 5)])\ndef test_val_like(val_shape, shape):\n    def check_val_like(A, B):\n        assert A.shape == B.shape\n        assert A.nnz == B.nnz\n        assert torch.allclose(torch.stack(A.coo()), torch.stack(B.coo()))\n        assert A.val.device == B.val.device\n\n    ctx = F.ctx()\n\n    # COO\n    row = torch.tensor([1, 1, 2]).to(ctx)\n    col = torch.tensor([2, 4, 3]).to(ctx)\n    val = torch.randn(3).to(ctx)\n    coo_A = from_coo(row, col, val, shape)\n    new_val = torch.randn(val_shape).to(ctx)\n    coo_B = val_like(coo_A, new_val)\n    check_val_like(coo_A, coo_B)\n\n    # CSR\n    indptr, indices, _ = coo_A.csr()\n    csr_A = from_csr(indptr, indices, val, shape)\n    csr_B = val_like(csr_A, new_val)\n    check_val_like(csr_A, csr_B)\n\n    # CSC\n    indptr, indices, _ = coo_A.csc()\n    
csc_A = from_csc(indptr, indices, val, shape)\n    csc_B = val_like(csc_A, new_val)\n    check_val_like(csc_A, csc_B)\n\n\ndef test_coalesce():\n    ctx = F.ctx()\n\n    row = torch.tensor([1, 0, 0, 0, 1]).to(ctx)\n    col = torch.tensor([1, 1, 1, 2, 2]).to(ctx)\n    val = torch.arange(len(row)).to(ctx)\n    A = from_coo(row, col, val, (4, 4))\n\n    assert A.has_duplicate()\n\n    A_coalesced = A.coalesce()\n\n    assert A_coalesced.nnz == 4\n    assert A_coalesced.shape == (4, 4)\n    assert list(A_coalesced.row) == [0, 0, 1, 1]\n    assert list(A_coalesced.col) == [1, 2, 1, 2]\n    # Values of duplicate indices are added together.\n    assert list(A_coalesced.val) == [3, 3, 0, 4]\n    assert not A_coalesced.has_duplicate()\n\n\ndef test_has_duplicate():\n    ctx = F.ctx()\n\n    row = torch.tensor([1, 0, 0, 0, 1]).to(ctx)\n    col = torch.tensor([1, 1, 1, 2, 2]).to(ctx)\n    val = torch.arange(len(row)).to(ctx)\n    shape = (4, 4)\n\n    # COO\n    coo_A = from_coo(row, col, val, shape)\n    assert coo_A.has_duplicate()\n\n    # CSR\n    indptr, indices, _ = coo_A.csr()\n    csr_A = from_csr(indptr, indices, val, shape)\n    assert csr_A.has_duplicate()\n\n    # CSC\n    indptr, indices, _ = coo_A.csc()\n    csc_A = from_csc(indptr, indices, val, shape)\n    assert csc_A.has_duplicate()\n\n\n@pytest.mark.parametrize(\n    \"create_func\", [rand_diag, rand_csr, rand_csc, rand_coo]\n)\n@pytest.mark.parametrize(\"shape\", [(5, 5), (6, 4)])\n@pytest.mark.parametrize(\"dense_dim\", [None, 4])\n@pytest.mark.parametrize(\"select_dim\", [0, 1])\n@pytest.mark.parametrize(\"index\", [(0, 1, 3), (1, 2)])\ndef test_index_select(create_func, shape, dense_dim, select_dim, index):\n    ctx = F.ctx()\n    A = create_func(shape, 20, ctx, dense_dim)\n    index = torch.tensor(index).to(ctx)\n    A_select = A.index_select(select_dim, index)\n\n    dense = sparse_matrix_to_dense(A)\n    dense_select = torch.index_select(dense, select_dim, index)\n\n    A_select_to_dense = 
sparse_matrix_to_dense(A_select)\n\n    assert A_select_to_dense.shape == dense_select.shape\n    assert torch.allclose(A_select_to_dense, dense_select)\n\n\n@pytest.mark.parametrize(\n    \"create_func\", [rand_diag, rand_csr, rand_csc, rand_coo]\n)\n@pytest.mark.parametrize(\"shape\", [(5, 5), (6, 4)])\n@pytest.mark.parametrize(\"dense_dim\", [None, 4])\n@pytest.mark.parametrize(\"select_dim\", [0, 1])\n@pytest.mark.parametrize(\"rang\", [slice(0, 2), slice(1, 3)])\ndef test_range_select(create_func, shape, dense_dim, select_dim, rang):\n    ctx = F.ctx()\n    A = create_func(shape, 20, ctx, dense_dim)\n    A_select = A.range_select(select_dim, rang)\n\n    dense = sparse_matrix_to_dense(A)\n    if select_dim == 0:\n        dense_select = dense[rang, :]\n    else:\n        dense_select = dense[:, rang]\n\n    A_select_to_dense = sparse_matrix_to_dense(A_select)\n\n    assert A_select_to_dense.shape == dense_select.shape\n    assert torch.allclose(A_select_to_dense, dense_select)\n\n\n@pytest.mark.parametrize(\n    \"create_func\", [rand_diag, rand_csr, rand_csc, rand_coo]\n)\n@pytest.mark.parametrize(\"index\", [(0, 1, 2, 3, 4), (0, 1, 3), (1, 1, 2)])\n@pytest.mark.parametrize(\"replace\", [False, True])\n@pytest.mark.parametrize(\"bias\", [False, True])\ndef test_sample_rowwise(create_func, index, replace, bias):\n    ctx = F.ctx()\n    shape = (5, 5)\n    sample_dim = 0\n    sample_num = 3\n    A = create_func(shape, 10, ctx)\n    A = val_like(A, torch.abs(A.val))\n\n    index = torch.tensor(index).to(ctx)\n\n    A_sample = A.sample(sample_dim, sample_num, index, replace, bias)\n    A_dense = sparse_matrix_to_dense(A)\n    A_sample_to_dense = sparse_matrix_to_dense(A_sample)\n\n    ans_shape = (index.size(0), shape[1])\n    # Verify sample elements in origin rows\n    for i, row in enumerate(list(index)):\n        ans_ele = list(A_dense[row, :].nonzero().reshape(-1))\n        ret_ele = list(A_sample_to_dense[i, :].nonzero().reshape(-1))\n        for e in 
ret_ele:\n            assert e in ans_ele\n        if replace:\n            # The number of sample elements in one row should be equal to\n            # 'sample_num' if the row is not empty otherwise should be\n            # equal to 0.\n            assert list(A_sample.row).count(torch.tensor(i)) == (\n                sample_num if len(ans_ele) != 0 else 0\n            )\n        else:\n            assert len(ret_ele) == min(sample_num, len(ans_ele))\n\n    assert A_sample.shape == ans_shape\n    if not replace:\n        assert not A_sample.has_duplicate()\n\n\n@pytest.mark.parametrize(\n    \"create_func\", [rand_diag, rand_csr, rand_csc, rand_coo]\n)\n@pytest.mark.parametrize(\"index\", [(0, 1, 2, 3, 4), (0, 1, 3), (1, 1, 2)])\n@pytest.mark.parametrize(\"replace\", [False, True])\n@pytest.mark.parametrize(\"bias\", [False, True])\ndef test_sample_columnwise(create_func, index, replace, bias):\n    ctx = F.ctx()\n    shape = (5, 5)\n    sample_dim = 1\n    sample_num = 3\n    A = create_func(shape, 10, ctx)\n    A = val_like(A, torch.abs(A.val))\n\n    index = torch.tensor(index).to(ctx)\n\n    A_sample = A.sample(sample_dim, sample_num, index, replace, bias)\n    A_dense = sparse_matrix_to_dense(A)\n    A_sample_to_dense = sparse_matrix_to_dense(A_sample)\n\n    ans_shape = (shape[0], index.size(0))\n    # Verify sample elements in origin columns\n    for i, col in enumerate(list(index)):\n        ans_ele = list(A_dense[:, col].nonzero().reshape(-1))\n        ret_ele = list(A_sample_to_dense[:, i].nonzero().reshape(-1))\n        for e in ret_ele:\n            assert e in ans_ele\n        if replace:\n            # The number of sample elements in one column should be equal to\n            # 'sample_num' if the column is not empty otherwise should be\n            # equal to 0.\n            assert list(A_sample.col).count(torch.tensor(i)) == (\n                sample_num if len(ans_ele) != 0 else 0\n            )\n        else:\n            assert len(ret_ele) == 
min(sample_num, len(ans_ele))\n\n    assert A_sample.shape == ans_shape\n    if not replace:\n        assert not A_sample.has_duplicate()\n\n\ndef test_print():\n    ctx = F.ctx()\n\n    # basic\n    row = torch.tensor([1, 1, 3]).to(ctx)\n    col = torch.tensor([2, 1, 3]).to(ctx)\n    val = torch.tensor([1.0, 1.0, 2.0]).to(ctx)\n    A = from_coo(row, col, val)\n    expected = (\n        str(\n            \"\"\"SparseMatrix(indices=tensor([[1, 1, 3],\n                             [2, 1, 3]]),\n             values=tensor([1., 1., 2.]),\n             shape=(4, 4), nnz=3)\"\"\"\n        )\n        if str(ctx) == \"cpu\"\n        else str(\n            \"\"\"SparseMatrix(indices=tensor([[1, 1, 3],\n                             [2, 1, 3]], device='cuda:0'),\n             values=tensor([1., 1., 2.], device='cuda:0'),\n             shape=(4, 4), nnz=3)\"\"\"\n        )\n    )\n    assert str(A) == expected, print(A, expected)\n\n    # vector-shape non zero\n    row = torch.tensor([1, 1, 3]).to(ctx)\n    col = torch.tensor([2, 1, 3]).to(ctx)\n    val = torch.tensor(\n        [[1.3080, 1.5984], [-0.4126, 0.7250], [-0.5416, -0.7022]]\n    ).to(ctx)\n    A = from_coo(row, col, val)\n    expected = (\n        str(\n            \"\"\"SparseMatrix(indices=tensor([[1, 1, 3],\n                             [2, 1, 3]]),\n             values=tensor([[ 1.3080,  1.5984],\n                            [-0.4126,  0.7250],\n                            [-0.5416, -0.7022]]),\n             shape=(4, 4), nnz=3, val_size=(2,))\"\"\"\n        )\n        if str(ctx) == \"cpu\"\n        else str(\n            \"\"\"SparseMatrix(indices=tensor([[1, 1, 3],\n                             [2, 1, 3]], device='cuda:0'),\n             values=tensor([[ 1.3080,  1.5984],\n                            [-0.4126,  0.7250],\n                            [-0.5416, -0.7022]], device='cuda:0'),\n             shape=(4, 4), nnz=3, val_size=(2,))\"\"\"\n        )\n    )\n    assert str(A) == expected, print(A, 
expected)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\",\n    reason=\"Device conversions don't need to be tested on CPU.\",\n)\n@pytest.mark.parametrize(\"device\", [\"cpu\", \"cuda\"])\ndef test_to_device(device):\n    row = torch.tensor([1, 1, 2])\n    col = torch.tensor([1, 2, 0])\n    mat = from_coo(row, col, shape=(3, 4))\n\n    target_row = row.to(device)\n    target_col = col.to(device)\n    target_val = mat.val.to(device)\n\n    mat2 = mat.to(device=device)\n    assert mat2.shape == mat.shape\n    assert torch.allclose(mat2.row, target_row)\n    assert torch.allclose(mat2.col, target_col)\n    assert torch.allclose(mat2.val, target_val)\n\n    mat2 = getattr(mat, device)()\n    assert mat2.shape == mat.shape\n    assert torch.allclose(mat2.row, target_row)\n    assert torch.allclose(mat2.col, target_col)\n    assert torch.allclose(mat2.val, target_val)\n\n\n@pytest.mark.parametrize(\n    \"dtype\", [torch.float, torch.double, torch.int, torch.long]\n)\ndef test_to_dtype(dtype):\n    row = torch.tensor([1, 1, 2])\n    col = torch.tensor([1, 2, 0])\n    mat = from_coo(row, col, shape=(3, 4))\n\n    target_val = mat.val.to(dtype=dtype)\n\n    mat2 = mat.to(dtype=dtype)\n    assert mat2.shape == mat.shape\n    assert torch.allclose(mat2.val, target_val)\n\n    func_name = {\n        torch.float: \"float\",\n        torch.double: \"double\",\n        torch.int: \"int\",\n        torch.long: \"long\",\n    }\n    mat2 = getattr(mat, func_name[dtype])()\n    assert mat2.shape == mat.shape\n    assert torch.allclose(mat2.val, target_val)\n\n\n@pytest.mark.parametrize(\"dense_dim\", [None, 2])\n@pytest.mark.parametrize(\"row\", [[0, 0, 1, 2], (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"col\", [(0, 1, 2, 2), (1, 3, 3, 4)])\n@pytest.mark.parametrize(\"extra_shape\", [(0, 1), (2, 1)])\ndef test_sparse_matrix_transpose(dense_dim, row, col, extra_shape):\n    mat_shape = (max(row) + 1 + extra_shape[0], max(col) + 1 + extra_shape[1])\n    val_shape = 
(len(row),)\n    if dense_dim is not None:\n        val_shape += (dense_dim,)\n    ctx = F.ctx()\n    val = torch.randn(val_shape).to(ctx)\n    row = torch.tensor(row).to(ctx)\n    col = torch.tensor(col).to(ctx)\n    mat = from_coo(row, col, val, mat_shape).transpose()\n    mat_row, mat_col = mat.coo()\n    mat_val = mat.val\n\n    assert mat.shape == mat_shape[::-1]\n    assert torch.allclose(mat_val, val)\n    assert torch.allclose(mat_row, col)\n    assert torch.allclose(mat_col, row)\n\n\n@pytest.mark.parametrize(\"row\", [[0, 0, 1, 2], (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"col\", [(0, 1, 2, 2), (1, 3, 3, 4)])\n@pytest.mark.parametrize(\"nz_dim\", [None, 2])\n@pytest.mark.parametrize(\"shape\", [(5, 5), (6, 7)])\ndef test_torch_sparse_coo_conversion(row, col, nz_dim, shape):\n    dev = F.ctx()\n    row = torch.tensor(row).to(dev)\n    col = torch.tensor(col).to(dev)\n    indices = torch.stack([row, col])\n    torch_sparse_shape = shape\n    val_shape = (row.shape[0],)\n    if nz_dim is not None:\n        torch_sparse_shape += (nz_dim,)\n        val_shape += (nz_dim,)\n    val = torch.randn(val_shape).to(dev)\n    torch_sparse_coo = torch.sparse_coo_tensor(indices, val, torch_sparse_shape)\n    spmat = from_torch_sparse(torch_sparse_coo)\n\n    def _assert_spmat_equal_to_torch_sparse_coo(spmat, torch_sparse_coo):\n        assert torch_sparse_coo.layout == torch.sparse_coo\n        # Use .data_ptr() to check whether indices and values are on the same\n        # memory address\n        assert (\n            spmat.indices().data_ptr() == torch_sparse_coo._indices().data_ptr()\n        )\n        assert spmat.val.data_ptr() == torch_sparse_coo._values().data_ptr()\n        assert spmat.shape == torch_sparse_coo.shape[:2]\n\n    _assert_spmat_equal_to_torch_sparse_coo(spmat, torch_sparse_coo)\n    torch_sparse_coo = to_torch_sparse_coo(spmat)\n    _assert_spmat_equal_to_torch_sparse_coo(spmat, torch_sparse_coo)\n\n\n@pytest.mark.parametrize(\"indptr\", [(0, 0, 
1, 4), (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"indices\", [(0, 1, 2, 3), (1, 2, 3, 4)])\n@pytest.mark.parametrize(\"shape\", [(3, 5), (3, 7)])\ndef test_torch_sparse_csr_conversion(indptr, indices, shape):\n    dev = F.ctx()\n    indptr = torch.tensor(indptr).to(dev)\n    indices = torch.tensor(indices).to(dev)\n    torch_sparse_shape = shape\n    val_shape = (indices.shape[0],)\n    val = torch.randn(val_shape).to(dev)\n    torch_sparse_csr = _torch_sparse_csr_tensor(\n        indptr, indices, val, torch_sparse_shape\n    )\n    spmat = from_torch_sparse(torch_sparse_csr)\n\n    def _assert_spmat_equal_to_torch_sparse_csr(spmat, torch_sparse_csr):\n        indptr, indices, value_indices = spmat.csr()\n        assert torch_sparse_csr.layout == torch.sparse_csr\n        assert value_indices is None\n        # Use .data_ptr() to check whether indices and values are on the same\n        # memory address\n        assert indptr.data_ptr() == torch_sparse_csr.crow_indices().data_ptr()\n        assert indices.data_ptr() == torch_sparse_csr.col_indices().data_ptr()\n        assert spmat.val.data_ptr() == torch_sparse_csr.values().data_ptr()\n        assert spmat.shape == torch_sparse_csr.shape[:2]\n\n    _assert_spmat_equal_to_torch_sparse_csr(spmat, torch_sparse_csr)\n    torch_sparse_csr = to_torch_sparse_csr(spmat)\n    _assert_spmat_equal_to_torch_sparse_csr(spmat, torch_sparse_csr)\n\n\n@pytest.mark.parametrize(\"indptr\", [(0, 0, 1, 4), (0, 1, 2, 4)])\n@pytest.mark.parametrize(\"indices\", [(0, 1, 2, 3), (1, 2, 3, 4)])\n@pytest.mark.parametrize(\"shape\", [(8, 3), (5, 3)])\ndef test_torch_sparse_csc_conversion(indptr, indices, shape):\n    dev = F.ctx()\n    indptr = torch.tensor(indptr).to(dev)\n    indices = torch.tensor(indices).to(dev)\n    torch_sparse_shape = shape\n    val_shape = (indices.shape[0],)\n    val = torch.randn(val_shape).to(dev)\n    torch_sparse_csc = torch.sparse_csc_tensor(\n        indptr, indices, val, torch_sparse_shape\n    )\n    spmat 
= from_torch_sparse(torch_sparse_csc)\n\n    def _assert_spmat_equal_to_torch_sparse_csc(spmat, torch_sparse_csc):\n        indptr, indices, value_indices = spmat.csc()\n        assert torch_sparse_csc.layout == torch.sparse_csc\n        assert value_indices is None\n        # Use .data_ptr() to check whether indices and values are on the same\n        # memory address\n        assert indptr.data_ptr() == torch_sparse_csc.ccol_indices().data_ptr()\n        assert indices.data_ptr() == torch_sparse_csc.row_indices().data_ptr()\n        assert spmat.val.data_ptr() == torch_sparse_csc.values().data_ptr()\n        assert spmat.shape == torch_sparse_csc.shape[:2]\n\n    _assert_spmat_equal_to_torch_sparse_csc(spmat, torch_sparse_csc)\n    torch_sparse_csc = to_torch_sparse_csc(spmat)\n    _assert_spmat_equal_to_torch_sparse_csc(spmat, torch_sparse_csc)\n\n\n### Diag foramt related tests ###\n\n\n@pytest.mark.parametrize(\"val_shape\", [(3,), (3, 2)])\n@pytest.mark.parametrize(\"mat_shape\", [None, (3, 5), (5, 3)])\ndef test_diag(val_shape, mat_shape):\n    ctx = F.ctx()\n    # creation\n    val = torch.randn(val_shape).to(ctx)\n    mat = diag(val, mat_shape)\n\n    # val, shape attributes\n    assert torch.allclose(mat.val, val)\n    if mat_shape is None:\n        mat_shape = (val_shape[0], val_shape[0])\n    assert mat.shape == mat_shape\n\n    val = torch.randn(val_shape).to(ctx)\n\n    # nnz\n    assert mat.nnz == val.shape[0]\n    # dtype\n    assert mat.dtype == val.dtype\n    # device\n    assert mat.device == val.device\n\n    # row, col, val\n    edge_index = torch.arange(len(val)).to(mat.device)\n    row, col = mat.coo()\n    val = mat.val\n    assert torch.allclose(row, edge_index)\n    assert torch.allclose(col, edge_index)\n    assert torch.allclose(val, val)\n\n\n@pytest.mark.parametrize(\"shape\", [(3, 3), (3, 5), (5, 3)])\n@pytest.mark.parametrize(\"d\", [None, 2])\ndef test_identity(shape, d):\n    ctx = F.ctx()\n    # creation\n    mat = identity(shape, 
d)\n    # shape\n    assert mat.shape == shape\n    # val\n    len_val = min(shape)\n    if d is None:\n        val_shape = len_val\n    else:\n        val_shape = (len_val, d)\n    val = torch.ones(val_shape)\n    assert torch.allclose(val, mat.val)\n\n\n@pytest.mark.parametrize(\"val_shape\", [(3,), (3, 2)])\n@pytest.mark.parametrize(\"mat_shape\", [None, (3, 5), (5, 3)])\ndef test_diag_matrix_transpose(val_shape, mat_shape):\n    ctx = F.ctx()\n    val = torch.randn(val_shape).to(ctx)\n    mat = diag(val, mat_shape).transpose()\n\n    assert torch.allclose(mat.val, val)\n    if mat_shape is None:\n        mat_shape = (val_shape[0], val_shape[0])\n    assert mat.shape == mat_shape[::-1]\n"
  },
  {
    "path": "tests/python/pytorch/sparse/test_unary_op.py",
    "content": "import sys\n\nimport backend as F\nimport torch\n\nfrom dgl.sparse import diag, spmatrix\n\n\ndef test_neg():\n    ctx = F.ctx()\n    row = torch.tensor([1, 1, 3]).to(ctx)\n    col = torch.tensor([1, 2, 3]).to(ctx)\n    val = torch.tensor([1.0, 1.0, 2.0]).to(ctx)\n    A = spmatrix(torch.stack([row, col]), val)\n    neg_A = -A\n    assert A.shape == neg_A.shape\n    assert A.nnz == neg_A.nnz\n    assert torch.allclose(-A.val, neg_A.val)\n    assert torch.allclose(torch.stack(A.coo()), torch.stack(neg_A.coo()))\n    assert A.val.device == neg_A.val.device\n\n\ndef test_diag_neg():\n    ctx = F.ctx()\n    val = torch.arange(3).float().to(ctx)\n    D = diag(val)\n    neg_D = -D\n    assert D.shape == neg_D.shape\n    assert torch.allclose(-D.val, neg_D.val)\n    assert D.val.device == neg_D.val.device\n\n\ndef test_diag_inv():\n    ctx = F.ctx()\n    val = torch.arange(1, 4).float().to(ctx)\n    D = diag(val)\n    inv_D = D.inv()\n    assert D.shape == inv_D.shape\n    assert torch.allclose(1.0 / D.val, inv_D.val)\n    assert D.val.device == inv_D.val.device\n"
  },
  {
    "path": "tests/python/pytorch/sparse/utils.py",
    "content": "import numpy as np\nimport torch\n\nfrom dgl.sparse import diag, from_csc, from_csr, SparseMatrix, spmatrix\n\nnp.random.seed(42)\ntorch.random.manual_seed(42)\n\n\ndef clone_detach_and_grad(t):\n    t = t.clone().detach()\n    t.requires_grad_()\n    return t\n\n\ndef rand_stride(t):\n    \"\"\"Add stride to the last dimension of a tensor.\"\"\"\n    stride = np.random.randint(2, 4)\n    ret = torch.stack([t] * stride, dim=-1)[..., 0]\n    ret = ret.detach()\n    if torch.is_floating_point(t):\n        ret.requires_grad_()\n    return ret\n\n\ndef rand_coo(shape, nnz, dev, nz_dim=None):\n    # Create a sparse matrix without duplicate entries.\n    nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)\n    nnzid = torch.tensor(nnzid, device=dev).long()\n    row = torch.div(nnzid, shape[1], rounding_mode=\"floor\")\n    col = nnzid % shape[1]\n    if nz_dim is None:\n        val = torch.randn(nnz, device=dev, requires_grad=True)\n    else:\n        val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)\n    indices = torch.stack([row, col])\n    indices = rand_stride(indices)\n    val = rand_stride(val)\n    return spmatrix(indices, val, shape)\n\n\ndef rand_csr(shape, nnz, dev, nz_dim=None):\n    # Create a sparse matrix without duplicate entries.\n    nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)\n    nnzid = torch.tensor(nnzid, device=dev).long()\n    row = torch.div(nnzid, shape[1], rounding_mode=\"floor\")\n    col = nnzid % shape[1]\n    if nz_dim is None:\n        val = torch.randn(nnz, device=dev, requires_grad=True)\n    else:\n        val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)\n    indptr = torch.zeros(shape[0] + 1, device=dev, dtype=torch.int64)\n    for r in row.tolist():\n        indptr[r + 1] += 1\n    indptr = torch.cumsum(indptr, 0)\n    row_sorted, row_sorted_idx = torch.sort(row)\n    indices = col[row_sorted_idx]\n    indptr = rand_stride(indptr)\n    indices = 
rand_stride(indices)\n    val = rand_stride(val)\n    return from_csr(indptr, indices, val, shape=shape)\n\n\ndef rand_csc(shape, nnz, dev, nz_dim=None):\n    # Create a sparse matrix without duplicate entries.\n    nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)\n    nnzid = torch.tensor(nnzid, device=dev).long()\n    row = torch.div(nnzid, shape[1], rounding_mode=\"floor\")\n    col = nnzid % shape[1]\n    if nz_dim is None:\n        val = torch.randn(nnz, device=dev, requires_grad=True)\n    else:\n        val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)\n    indptr = torch.zeros(shape[1] + 1, device=dev, dtype=torch.int64)\n    for c in col.tolist():\n        indptr[c + 1] += 1\n    indptr = torch.cumsum(indptr, 0)\n    col_sorted, col_sorted_idx = torch.sort(col)\n    indices = row[col_sorted_idx]\n    indptr = rand_stride(indptr)\n    indices = rand_stride(indices)\n    val = rand_stride(val)\n    return from_csc(indptr, indices, val, shape=shape)\n\n\ndef rand_diag(shape, nnz, dev, nz_dim=None):\n    nnz = min(shape)\n    if nz_dim is None:\n        val = torch.randn(nnz, device=dev, requires_grad=True)\n    else:\n        val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)\n    return diag(val, shape)\n\n\ndef rand_coo_uncoalesced(shape, nnz, dev):\n    # Create a sparse matrix with possible duplicate entries.\n    row = torch.randint(shape[0], (nnz,), device=dev)\n    col = torch.randint(shape[1], (nnz,), device=dev)\n    val = torch.randn(nnz, device=dev, requires_grad=True)\n    indices = torch.stack([row, col])\n    indices = rand_stride(indices)\n    return spmatrix(indices, val, shape)\n\n\ndef rand_csr_uncoalesced(shape, nnz, dev):\n    # Create a sparse matrix with possible duplicate entries.\n    row = torch.randint(shape[0], (nnz,), device=dev)\n    col = torch.randint(shape[1], (nnz,), device=dev)\n    val = torch.randn(nnz, device=dev, requires_grad=True)\n    indptr = torch.zeros(shape[0] + 1, 
device=dev, dtype=torch.int64)\n    for r in row.tolist():\n        indptr[r + 1] += 1\n    indptr = torch.cumsum(indptr, 0)\n    row_sorted, row_sorted_idx = torch.sort(row)\n    indices = col[row_sorted_idx]\n    indptr = rand_stride(indptr)\n    indices = rand_stride(indices)\n    val = rand_stride(val)\n    return from_csr(indptr, indices, val, shape=shape)\n\n\ndef rand_csc_uncoalesced(shape, nnz, dev):\n    # Create a sparse matrix with possible duplicate entries.\n    row = torch.randint(shape[0], (nnz,), device=dev)\n    col = torch.randint(shape[1], (nnz,), device=dev)\n    val = torch.randn(nnz, device=dev, requires_grad=True)\n    indptr = torch.zeros(shape[1] + 1, device=dev, dtype=torch.int64)\n    for c in col.tolist():\n        indptr[c + 1] += 1\n    indptr = torch.cumsum(indptr, 0)\n    col_sorted, col_sorted_idx = torch.sort(col)\n    indices = row[col_sorted_idx]\n    indptr = rand_stride(indptr)\n    indices = rand_stride(indices)\n    val = rand_stride(val)\n    return from_csc(indptr, indices, val, shape=shape)\n\n\ndef sparse_matrix_to_dense(A: SparseMatrix):\n    dense = A.to_dense()\n    return clone_detach_and_grad(dense)\n\n\ndef sparse_matrix_to_torch_sparse(A: SparseMatrix, val=None):\n    row, col = A.coo()\n    edge_index = torch.cat((row.unsqueeze(0), col.unsqueeze(0)), 0)\n    shape = A.shape\n    if val is None:\n        val = A.val\n    val = val.clone().detach()\n    if len(A.val.shape) > 1:\n        shape += (A.val.shape[-1],)\n    ret = torch.sparse_coo_tensor(edge_index, val, shape).coalesce()\n    ret.requires_grad_()\n    return ret\n\n\ndef dense_mask(dense, sparse):\n    ret = torch.zeros_like(dense)\n    row, col = sparse.coo()\n    for r, c in zip(row, col):\n        ret[r, c] = dense[r, c]\n    return ret\n"
  },
  {
    "path": "tests/python/pytorch/test_ffi-stream.py",
    "content": "import unittest\nfrom statistics import mean\n\nimport backend as F\n\nimport dgl\nimport dgl.ndarray as nd\nimport dgl.ops as OPS\nimport numpy as np\nimport torch\nfrom dgl import rand_graph\nfrom dgl._ffi.streams import _dgl_get_stream, to_dgl_stream_handle\nfrom dgl.utils import to_dgl_context\n\n\n# borrowed from PyTorch, torch/testing/_internal/common_utils.py\ndef _get_cycles_per_ms() -> float:\n    \"\"\"Measure and return approximate number of cycles per millisecond for torch.cuda._sleep\"\"\"\n\n    def measure() -> float:\n        start = torch.cuda.Event(enable_timing=True)\n        end = torch.cuda.Event(enable_timing=True)\n        start.record()\n        torch.cuda._sleep(1000000)\n        end.record()\n        end.synchronize()\n        cycles_per_ms = 1000000 / start.elapsed_time(end)\n        return cycles_per_ms\n\n    # Get 10 values and remove the 2 max and 2 min and return the avg.\n    # This is to avoid system disturbance that skew the results, e.g.\n    # the very first cuda call likely does a bunch of init, which takes\n    # much longer than subsequent calls.\n    num = 10\n    vals = []\n    for _ in range(num):\n        vals.append(measure())\n    vals = sorted(vals)\n    return mean(vals[2 : num - 2])\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"stream only runs on GPU.\"\n)\ndef test_basics():\n    g = rand_graph(10, 20, device=F.cpu())\n    x = torch.ones(g.num_nodes(), 10)\n    result = OPS.copy_u_sum(g, x).to(F.ctx())\n\n    # launch on default stream used in DGL\n    xx = x.to(device=F.ctx())\n    gg = g.to(device=F.ctx())\n    OPS.copy_u_sum(gg, xx)\n    assert torch.equal(OPS.copy_u_sum(gg, xx), result)\n\n    # launch on new stream created via torch.cuda\n    s = torch.cuda.Stream(device=F.ctx())\n    with torch.cuda.stream(s):\n        xx = x.to(device=F.ctx(), non_blocking=True)\n        gg = g.to(device=F.ctx())\n        OPS.copy_u_sum(gg, xx)\n    s.synchronize()\n    assert 
torch.equal(OPS.copy_u_sum(gg, xx), result)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"stream only runs on GPU.\"\n)\ndef test_set_get_stream():\n    current_stream = torch.cuda.current_stream()\n    # test setting another stream\n    s = torch.cuda.Stream(device=F.ctx())\n    torch.cuda.set_stream(s)\n    assert (\n        to_dgl_stream_handle(s).value\n        == _dgl_get_stream(to_dgl_context(F.ctx())).value\n    )\n    # revert to default stream\n    torch.cuda.set_stream(current_stream)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"stream only runs on GPU.\"\n)\n# borrowed from PyTorch, test/test_cuda.py: test_record_stream()\ndef test_record_stream_ndarray():\n    cycles_per_ms = _get_cycles_per_ms()\n\n    t = nd.array(np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), ctx=nd.cpu())\n    t.pin_memory_()\n    result = nd.empty([4], ctx=nd.gpu(0))\n    stream = torch.cuda.Stream()\n    ptr = [None]\n\n    # Performs the CPU->GPU copy in a background stream\n    def perform_copy():\n        with torch.cuda.stream(stream):\n            tmp = t.copyto(nd.gpu(0))\n            ptr[0] = F.from_dgl_nd(tmp).data_ptr()\n        torch.cuda.current_stream().wait_stream(stream)\n        tmp.record_stream(to_dgl_stream_handle(torch.cuda.current_stream()))\n        torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the copy\n        result.copyfrom(tmp)\n\n    perform_copy()\n    with torch.cuda.stream(stream):\n        tmp2 = nd.empty([4], ctx=nd.gpu(0))\n        assert (\n            F.from_dgl_nd(tmp2).data_ptr() != ptr[0]\n        ), \"allocation re-used too soon\"\n\n    assert torch.equal(\n        F.from_dgl_nd(result).cpu(), torch.tensor([1.0, 2.0, 3.0, 4.0])\n    )\n\n    # Check that the block will be re-used after the main stream finishes\n    torch.cuda.current_stream().synchronize()\n    with torch.cuda.stream(stream):\n        tmp3 = nd.empty([4], ctx=nd.gpu(0))\n        assert (\n            
F.from_dgl_nd(tmp3).data_ptr() == ptr[0]\n        ), \"allocation not re-used\"\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"stream only runs on GPU.\"\n)\ndef test_record_stream_graph_positive():\n    cycles_per_ms = _get_cycles_per_ms()\n\n    g = rand_graph(10, 20, device=F.cpu())\n    g.create_formats_()\n    x = torch.ones(g.num_nodes(), 10).to(F.ctx())\n    g1 = g.to(F.ctx())\n    # this is necessary to initialize the cusparse handle\n    result = OPS.copy_u_sum(g1, x)\n    torch.cuda.current_stream().synchronize()\n\n    stream = torch.cuda.Stream()\n    results2 = torch.zeros_like(result)\n\n    # Performs the computing in a background stream\n    def perform_computing():\n        with torch.cuda.stream(stream):\n            g2 = g.to(F.ctx())\n        torch.cuda.current_stream().wait_stream(stream)\n        g2.record_stream(torch.cuda.current_stream())\n        torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the computing\n        results2.copy_(OPS.copy_u_sum(g2, x))\n\n    perform_computing()\n    with torch.cuda.stream(stream):\n        # since we have called record stream for g2, g3 won't reuse its memory\n        g3 = rand_graph(10, 20, device=F.ctx())\n        g3.create_formats_()\n    torch.cuda.current_stream().synchronize()\n    assert torch.equal(result, results2)\n\n\n@unittest.skipIf(\n    F._default_context_str == \"cpu\", reason=\"stream only runs on GPU.\"\n)\ndef test_record_stream_graph_negative():\n    cycles_per_ms = _get_cycles_per_ms()\n\n    g = rand_graph(10, 20, device=F.cpu())\n    g.create_formats_()\n    x = torch.ones(g.num_nodes(), 10).to(F.ctx())\n    g1 = g.to(F.ctx())\n    # this is necessary to initialize the cusparse handle\n    result = OPS.copy_u_sum(g1, x)\n    torch.cuda.current_stream().synchronize()\n\n    stream = torch.cuda.Stream()\n    results2 = torch.zeros_like(result)\n\n    # Performs the computing in a background stream\n    def perform_computing():\n        with 
torch.cuda.stream(stream):\n            g2 = g.to(F.ctx())\n        torch.cuda.current_stream().wait_stream(stream)\n        # omit record_stream will produce a wrong result\n        # g2.record_stream(torch.cuda.current_stream())\n        torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the computing\n        results2.copy_(OPS.copy_u_sum(g2, x))\n\n    perform_computing()\n    with torch.cuda.stream(stream):\n        # g3 will reuse g2's memory block, resulting a wrong result\n        g3 = rand_graph(10, 20, device=F.ctx())\n        g3.create_formats_()\n    torch.cuda.current_stream().synchronize()\n    assert not torch.equal(result, results2)\n\n\nif __name__ == \"__main__\":\n    test_basics()\n    test_set_get_stream()\n    test_record_stream_ndarray()\n    test_record_stream_graph_positive()\n    test_record_stream_graph_negative()\n"
  },
  {
    "path": "tests/python/pytorch/test_heterograph-pickle.py",
    "content": "import io\nimport pickle\n\nimport dgl\n\nimport networkx as nx\nimport torch\n\n\ndef _reconstruct_pickle(obj):\n    f = io.BytesIO()\n    pickle.dump(obj, f)\n    f.seek(0)\n    obj = pickle.load(f)\n    f.close()\n    return obj\n\n\ndef test_pickling_batched_graph():\n    # NOTE: this is a test for a wierd bug mentioned in\n    #   https://github.com/dmlc/dgl/issues/438\n    glist = [nx.path_graph(i + 5) for i in range(5)]\n    glist = [dgl.from_networkx(g) for g in glist]\n    bg = dgl.batch(glist)\n    bg.ndata[\"x\"] = torch.randn((35, 5))\n    bg.edata[\"y\"] = torch.randn((60, 3))\n    new_bg = _reconstruct_pickle(bg)\n\n\nif __name__ == \"__main__\":\n    test_pickling_batched_graph()\n"
  },
  {
    "path": "tests/python/pytorch/test_multiprocessing-ipc.py",
    "content": "import os\nimport unittest\n\nimport dgl\n\nimport torch as th\nimport torch.multiprocessing as mp\n\n\ndef sub_ipc(g):\n    print(g)\n    return g\n\n\n@unittest.skipIf(os.name == \"nt\", reason=\"Do not support windows yet\")\ndef test_torch_ipc():\n    g = dgl.graph(([0, 1, 2], [1, 2, 3]))\n    ctx = mp.get_context(\"spawn\")\n    p = ctx.Process(target=sub_ipc, args=(g,))\n\n    p.start()\n    p.join()\n\n\nif __name__ == \"__main__\":\n    test_torch_ipc()\n"
  },
  {
    "path": "tests/python/pytorch/utils/test_pin_memory.py",
    "content": "import backend as F\n\nimport dgl\nimport pytest\nimport torch\n\n\n@pytest.mark.skipif(\n    F._default_context_str == \"cpu\", reason=\"Need gpu for this test.\"\n)\ndef test_pin_noncontiguous():\n    t = torch.empty([10, 100]).transpose(0, 1)\n\n    assert not t.is_contiguous()\n    assert not F.is_pinned(t)\n\n    with pytest.raises(dgl.DGLError):\n        dgl.utils.pin_memory_inplace(t)\n\n\n@pytest.mark.skipif(\n    F._default_context_str == \"cpu\", reason=\"Need gpu for this test.\"\n)\ndef test_pin_view():\n    t = torch.empty([100, 10])\n    v = t[10:20]\n\n    assert v.is_contiguous()\n    assert not F.is_pinned(t)\n\n    with pytest.raises(dgl.DGLError):\n        dgl.utils.pin_memory_inplace(v)\n\n    # make sure an empty view does not generate an error\n    u = t[10:10]\n    u = dgl.utils.pin_memory_inplace(u)\n\n\n@pytest.mark.skipif(\n    F._default_context_str == \"cpu\", reason=\"Need gpu for this test.\"\n)\ndef test_unpin_automatically():\n    # run a sufficient number of iterations such that the memory pool should be\n    # re-used\n    for j in range(10):\n        t = torch.ones(10000, 10)\n        assert not F.is_pinned(t)\n        nd = dgl.utils.pin_memory_inplace(t)\n        assert F.is_pinned(t)\n        del nd\n        # dgl.ndarray will unpin its data upon destruction\n        assert not F.is_pinned(t)\n        del t\n\n\n@pytest.mark.skipif(\n    F._default_context_str == \"cpu\", reason=\"Need gpu for this test.\"\n)\ndef test_pin_unpin_column():\n    g = dgl.graph(([1, 2, 3, 4], [0, 0, 0, 0]))\n\n    g.ndata[\"x\"] = torch.randn(g.num_nodes())\n    g.pin_memory_()\n    assert g.is_pinned()\n    assert g.ndata[\"x\"].is_pinned()\n    for col in g._node_frames[0].values():\n        assert col.pinned_by_dgl\n        assert col._data_nd is not None\n\n    g.ndata[\"x\"] = torch.randn(g.num_nodes())  # unpin the old ndata['x']\n    assert g.is_pinned()\n    for col in g._node_frames[0].values():\n        assert not 
col.pinned_by_dgl\n        assert col._data_nd is None\n    assert not g.ndata[\"x\"].is_pinned()\n\n\n@pytest.mark.skipif(\n    F._default_context_str == \"cpu\", reason=\"Need gpu for this test.\"\n)\ndef test_pin_empty():\n    t = torch.tensor([])\n    assert not t.is_pinned()\n\n    # Empty tensors will not be pinned or unpinned. It's a no-op.\n    # This is also the default behavior in PyTorch.\n    # We just check that it won't raise an error.\n    nd = dgl.utils.pin_memory_inplace(t)\n    assert not t.is_pinned()\n\n\nif __name__ == \"__main__\":\n    test_pin_noncontiguous()\n    test_pin_view()\n    test_unpin_automatically()\n    test_pin_unpin_column()\n"
  },
  {
    "path": "tests/python/tensorflow/test_basic.py",
    "content": "def test():\n    pass\n\n\nif __name__ == \"__main__\":\n    test()\n"
  },
  {
    "path": "tests/python/tensorflow/test_nn.py",
    "content": "from copy import deepcopy\n\nimport backend as F\n\nimport dgl\nimport dgl.function as fn\nimport dgl.nn.tensorflow as nn\nimport networkx as nx\nimport numpy as np\nimport pytest\nimport scipy as sp\nimport tensorflow as tf\nfrom tensorflow.keras import layers\nfrom utils import parametrize_idtype\nfrom utils.graph_cases import (\n    get_cases,\n    random_bipartite,\n    random_dglgraph,\n    random_graph,\n)\n\n\ndef _AXWb(A, X, W, b):\n    X = tf.matmul(X, W)\n    Y = tf.reshape(tf.matmul(A, tf.reshape(X, (X.shape[0], -1))), X.shape)\n    return Y + b\n\n\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_graph_conv(out_dim):\n    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())\n    ctx = F.ctx()\n    adj = tf.sparse.to_dense(\n        tf.sparse.reorder(g.adj_external(transpose=True, ctx=ctx))\n    )\n\n    conv = nn.GraphConv(5, out_dim, norm=\"none\", bias=True)\n    # conv = conv\n    print(conv)\n    # test#1: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))\n    # test#2: more-dim\n    h0 = F.ones((3, 5, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))\n\n    conv = nn.GraphConv(5, out_dim)\n    # conv = conv\n    # test#3: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    # test#4: basic\n    h0 = F.ones((3, 5, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n\n    conv = nn.GraphConv(5, out_dim)\n    # conv = conv\n    # test#3: basic\n    h0 = F.ones((3, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n    # test#4: basic\n    h0 = F.ones((3, 5, 5))\n    h1 = conv(g, h0)\n    assert len(g.ndata) == 0\n    assert len(g.edata) == 0\n\n    # test 
rest_parameters\n    # old_weight = deepcopy(conv.weight.data)\n    # conv.reset_parameters()\n    # new_weight = conv.weight.data\n    # assert not F.allclose(old_weight, new_weight)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\",\n    get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\", \"dglgraph\"]),\n)\n@pytest.mark.parametrize(\"norm\", [\"none\", \"both\", \"right\", \"left\"])\n@pytest.mark.parametrize(\"weight\", [True, False])\n@pytest.mark.parametrize(\"bias\", [True, False])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_graph_conv2(idtype, g, norm, weight, bias, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)\n    ext_w = F.randn((5, out_dim))\n    nsrc = g.number_of_src_nodes()\n    ndst = g.number_of_dst_nodes()\n    h = F.randn((nsrc, 5))\n    h_dst = F.randn((ndst, out_dim))\n    if weight:\n        h_out = conv(g, h)\n    else:\n        h_out = conv(g, h, weight=ext_w)\n    assert h_out.shape == (ndst, out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\", \"dglgraph\"])\n)\n@pytest.mark.parametrize(\"norm\", [\"none\", \"both\", \"right\"])\n@pytest.mark.parametrize(\"weight\", [True, False])\n@pytest.mark.parametrize(\"bias\", [True, False])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)\n    ext_w = F.randn((5, out_dim))\n    nsrc = g.number_of_src_nodes()\n    ndst = g.number_of_dst_nodes()\n    h = F.randn((nsrc, 5))\n    h_dst = F.randn((ndst, out_dim))\n    if weight:\n        h_out = conv(g, (h, h_dst))\n    else:\n        h_out = conv(g, (h, h_dst), weight=ext_w)\n    assert h_out.shape == (ndst, out_dim)\n\n\ndef test_simple_pool():\n    ctx = F.ctx()\n    g = 
dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())\n\n    sum_pool = nn.SumPooling()\n    avg_pool = nn.AvgPooling()\n    max_pool = nn.MaxPooling()\n    sort_pool = nn.SortPooling(10)  # k = 10\n    print(sum_pool, avg_pool, max_pool, sort_pool)\n\n    # test#1: basic\n    h0 = F.randn((g.num_nodes(), 5))\n    h1 = sum_pool(g, h0)\n    assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))\n    h1 = avg_pool(g, h0)\n    assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))\n    h1 = max_pool(g, h0)\n    assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))\n    h1 = sort_pool(g, h0)\n    assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2\n\n    # test#2: batched graph\n    g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())\n    bg = dgl.batch([g, g_, g, g_, g])\n    h0 = F.randn((bg.num_nodes(), 5))\n    h1 = sum_pool(bg, h0)\n    truth = tf.stack(\n        [\n            F.sum(h0[:15], 0),\n            F.sum(h0[15:20], 0),\n            F.sum(h0[20:35], 0),\n            F.sum(h0[35:40], 0),\n            F.sum(h0[40:55], 0),\n        ],\n        0,\n    )\n    assert F.allclose(h1, truth)\n\n    h1 = avg_pool(bg, h0)\n    truth = tf.stack(\n        [\n            F.mean(h0[:15], 0),\n            F.mean(h0[15:20], 0),\n            F.mean(h0[20:35], 0),\n            F.mean(h0[35:40], 0),\n            F.mean(h0[40:55], 0),\n        ],\n        0,\n    )\n    assert F.allclose(h1, truth)\n\n    h1 = max_pool(bg, h0)\n    truth = tf.stack(\n        [\n            F.max(h0[:15], 0),\n            F.max(h0[15:20], 0),\n            F.max(h0[20:35], 0),\n            F.max(h0[35:40], 0),\n            F.max(h0[40:55], 0),\n        ],\n        0,\n    )\n    assert F.allclose(h1, truth)\n\n    h1 = sort_pool(bg, h0)\n    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2\n\n\ndef test_glob_att_pool():\n    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())\n\n    gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))\n    print(gap)\n\n    # test#1: basic\n   
 h0 = F.randn((g.num_nodes(), 5))\n    h1 = gap(g, h0)\n    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2\n\n    # test#2: batched graph\n    bg = dgl.batch([g, g, g, g])\n    h0 = F.randn((bg.num_nodes(), 5))\n    h1 = gap(bg, h0)\n    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2\n\n\n@pytest.mark.parametrize(\"O\", [1, 2, 8])\ndef test_rgcn(O):\n    etype = []\n    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(\n        F.ctx()\n    )\n    # 5 etypes\n    R = 5\n    for i in range(g.num_edges()):\n        etype.append(i % 5)\n    B = 2\n    I = 10\n\n    rgc_basis = nn.RelGraphConv(I, O, R, \"basis\", B)\n    rgc_basis_low = nn.RelGraphConv(I, O, R, \"basis\", B, low_mem=True)\n    rgc_basis_low.weight = rgc_basis.weight\n    rgc_basis_low.w_comp = rgc_basis.w_comp\n    rgc_basis_low.loop_weight = rgc_basis.loop_weight\n    h = tf.random.normal((100, I))\n    r = tf.constant(etype)\n    h_new = rgc_basis(g, h, r)\n    h_new_low = rgc_basis_low(g, h, r)\n    assert list(h_new.shape) == [100, O]\n    assert list(h_new_low.shape) == [100, O]\n    assert F.allclose(h_new, h_new_low)\n\n    if O % B == 0:\n        rgc_bdd = nn.RelGraphConv(I, O, R, \"bdd\", B)\n        rgc_bdd_low = nn.RelGraphConv(I, O, R, \"bdd\", B, low_mem=True)\n        rgc_bdd_low.weight = rgc_bdd.weight\n        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight\n        h = tf.random.normal((100, I))\n        r = tf.constant(etype)\n        h_new = rgc_bdd(g, h, r)\n        h_new_low = rgc_bdd_low(g, h, r)\n        assert list(h_new.shape) == [100, O]\n        assert list(h_new_low.shape) == [100, O]\n        assert F.allclose(h_new, h_new_low)\n\n    # with norm\n    norm = tf.zeros((g.num_edges(), 1))\n\n    rgc_basis = nn.RelGraphConv(I, O, R, \"basis\", B)\n    rgc_basis_low = nn.RelGraphConv(I, O, R, \"basis\", B, low_mem=True)\n    rgc_basis_low.weight = rgc_basis.weight\n    rgc_basis_low.w_comp = rgc_basis.w_comp\n    
rgc_basis_low.loop_weight = rgc_basis.loop_weight\n    h = tf.random.normal((100, I))\n    r = tf.constant(etype)\n    h_new = rgc_basis(g, h, r, norm)\n    h_new_low = rgc_basis_low(g, h, r, norm)\n    assert list(h_new.shape) == [100, O]\n    assert list(h_new_low.shape) == [100, O]\n    assert F.allclose(h_new, h_new_low)\n\n    if O % B == 0:\n        rgc_bdd = nn.RelGraphConv(I, O, R, \"bdd\", B)\n        rgc_bdd_low = nn.RelGraphConv(I, O, R, \"bdd\", B, low_mem=True)\n        rgc_bdd_low.weight = rgc_bdd.weight\n        rgc_bdd_low.loop_weight = rgc_bdd.loop_weight\n        h = tf.random.normal((100, I))\n        r = tf.constant(etype)\n        h_new = rgc_bdd(g, h, r, norm)\n        h_new_low = rgc_bdd_low(g, h, r, norm)\n        assert list(h_new.shape) == [100, O]\n        assert list(h_new_low.shape) == [100, O]\n        assert F.allclose(h_new, h_new_low)\n\n    # id input\n    rgc_basis = nn.RelGraphConv(I, O, R, \"basis\", B)\n    rgc_basis_low = nn.RelGraphConv(I, O, R, \"basis\", B, low_mem=True)\n    rgc_basis_low.weight = rgc_basis.weight\n    rgc_basis_low.w_comp = rgc_basis.w_comp\n    rgc_basis_low.loop_weight = rgc_basis.loop_weight\n    h = tf.constant(np.random.randint(0, I, (100,))) * 1\n    r = tf.constant(etype) * 1\n    h_new = rgc_basis(g, h, r)\n    h_new_low = rgc_basis_low(g, h, r)\n    assert list(h_new.shape) == [100, O]\n    assert list(h_new_low.shape) == [100, O]\n    assert F.allclose(h_new, h_new_low)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_gat_conv(g, idtype, out_dim, num_heads):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gat = nn.GATConv(5, out_dim, num_heads)\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    h = gat(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)\n    _, 
a = gat(g, feat, get_attention=True)\n    assert a.shape == (g.num_edges(), num_heads, 1)\n\n    # test residual connection\n    gat = nn.GATConv(5, out_dim, num_heads, residual=True)\n    h = gat(g, feat)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\n@pytest.mark.parametrize(\"num_heads\", [1, 4])\ndef test_gat_conv_bi(g, idtype, out_dim, num_heads):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gat = nn.GATConv(5, out_dim, num_heads)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 5)),\n        F.randn((g.number_of_dst_nodes(), 5)),\n    )\n    h = gat(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)\n    _, a = gat(g, feat, get_attention=True)\n    assert a.shape == (g.num_edges(), num_heads, 1)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\", \"block-bipartite\"]))\n@pytest.mark.parametrize(\"aggre_type\", [\"mean\", \"pool\", \"gcn\"])\n@pytest.mark.parametrize(\"out_dim\", [1, 10])\ndef test_sage_conv(idtype, g, aggre_type, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    sage = nn.SAGEConv(5, out_dim, aggre_type)\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    h = sage(g, feat)\n    assert h.shape[-1] == out_dim\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"]))\n@pytest.mark.parametrize(\"aggre_type\", [\"mean\", \"pool\", \"gcn\"])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_sage_conv_bi(idtype, g, aggre_type, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    dst_dim = 5 if aggre_type != \"gcn\" else 10\n    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 10)),\n        F.randn((g.number_of_dst_nodes(), dst_dim)),\n    )\n    h = sage(g, feat)\n    assert h.shape[-1] == out_dim\n    assert h.shape[0] == 
g.number_of_dst_nodes()\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"aggre_type\", [\"mean\", \"pool\", \"gcn\"])\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_sage_conv_bi_empty(idtype, aggre_type, out_dim):\n    # Test the case for graphs without edges\n    g = dgl.heterograph({(\"_U\", \"_E\", \"_V\"): ([], [])}, {\"_U\": 5, \"_V\": 3}).to(\n        F.ctx()\n    )\n    g = g.astype(idtype).to(F.ctx())\n    sage = nn.SAGEConv((3, 3), out_dim, \"gcn\")\n    feat = (F.randn((5, 3)), F.randn((3, 3)))\n    h = sage(g, feat)\n    assert h.shape[-1] == out_dim\n    assert h.shape[0] == 3\n    for aggre_type in [\"mean\", \"pool\", \"lstm\"]:\n        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)\n        feat = (F.randn((5, 3)), F.randn((3, 1)))\n        h = sage(g, feat)\n        assert h.shape[-1] == out_dim\n        assert h.shape[0] == 3\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_sgc_conv(g, idtype, out_dim):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    # not cached\n    sgc = nn.SGConv(5, out_dim, 3)\n    feat = F.randn((g.num_nodes(), 5))\n\n    h = sgc(g, feat)\n    assert h.shape[-1] == out_dim\n\n    # cached\n    sgc = nn.SGConv(5, out_dim, 3, True)\n    h_0 = sgc(g, feat)\n    h_1 = sgc(g, feat + 1)\n    assert F.allclose(h_0, h_1)\n    assert h_0.shape[-1] == out_dim\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\"], exclude=[\"zero-degree\"]))\ndef test_appnp_conv(g, idtype):\n    ctx = F.ctx()\n    g = g.astype(idtype).to(ctx)\n    appnp = nn.APPNPConv(10, 0.1)\n    feat = F.randn((g.num_nodes(), 5))\n\n    h = appnp(g, feat)\n    assert h.shape[-1] == 5\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"homo\", \"block-bipartite\"]))\n@pytest.mark.parametrize(\"aggregator_type\", [\"mean\", \"max\", \"sum\"])\ndef test_gin_conv(g, idtype, 
aggregator_type):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    gin = nn.GINConv(tf.keras.layers.Dense(12), aggregator_type)\n    feat = F.randn((g.number_of_src_nodes(), 5))\n    h = gin(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), 12)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"]))\n@pytest.mark.parametrize(\"aggregator_type\", [\"mean\", \"max\", \"sum\"])\ndef test_gin_conv_bi(g, idtype, aggregator_type):\n    g = g.astype(idtype).to(F.ctx())\n    gin = nn.GINConv(tf.keras.layers.Dense(12), aggregator_type)\n    feat = (\n        F.randn((g.number_of_src_nodes(), 5)),\n        F.randn((g.number_of_dst_nodes(), 5)),\n    )\n    h = gin(g, feat)\n    assert h.shape == (g.number_of_dst_nodes(), 12)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\n    \"g\", get_cases([\"homo\", \"block-bipartite\"], exclude=[\"zero-degree\"])\n)\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_edge_conv(g, idtype, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    edge_conv = nn.EdgeConv(out_dim)\n\n    h0 = F.randn((g.number_of_src_nodes(), 5))\n    h1 = edge_conv(g, h0)\n    assert h1.shape == (g.number_of_dst_nodes(), out_dim)\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"g\", get_cases([\"bipartite\"], exclude=[\"zero-degree\"]))\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_edge_conv_bi(g, idtype, out_dim):\n    g = g.astype(idtype).to(F.ctx())\n    ctx = F.ctx()\n    edge_conv = nn.EdgeConv(out_dim)\n\n    h0 = F.randn((g.number_of_src_nodes(), 5))\n    x0 = F.randn((g.number_of_dst_nodes(), 5))\n    h1 = edge_conv(g, (h0, x0))\n    assert h1.shape == (g.number_of_dst_nodes(), out_dim)\n\n\ndef myagg(alist, dsttype):\n    rst = alist[0]\n    for i in range(1, len(alist)):\n        rst = rst + (i + 1) * alist[i]\n    return rst\n\n\n@parametrize_idtype\n@pytest.mark.parametrize(\"agg\", [\"sum\", \"max\", \"min\", \"mean\", \"stack\", myagg])\ndef test_hetero_conv(agg, 
idtype):\n    g = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 0, 2, 1], [1, 2, 1, 3]),\n            (\"user\", \"plays\", \"game\"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),\n            (\"store\", \"sells\", \"game\"): ([0, 0, 1, 1], [0, 3, 1, 2]),\n        },\n        idtype=idtype,\n        device=F.ctx(),\n    )\n    conv = nn.HeteroGraphConv(\n        {\n            \"follows\": nn.GraphConv(2, 3, allow_zero_in_degree=True),\n            \"plays\": nn.GraphConv(2, 4, allow_zero_in_degree=True),\n            \"sells\": nn.GraphConv(3, 4, allow_zero_in_degree=True),\n        },\n        agg,\n    )\n    uf = F.randn((4, 2))\n    gf = F.randn((4, 4))\n    sf = F.randn((2, 3))\n\n    h = conv(g, {\"user\": uf, \"store\": sf, \"game\": gf})\n    assert set(h.keys()) == {\"user\", \"game\"}\n    if agg != \"stack\":\n        assert h[\"user\"].shape == (4, 3)\n        assert h[\"game\"].shape == (4, 4)\n    else:\n        assert h[\"user\"].shape == (4, 1, 3)\n        assert h[\"game\"].shape == (4, 2, 4)\n\n    block = dgl.to_block(\n        g.to(F.cpu()), {\"user\": [0, 1, 2, 3], \"game\": [0, 1, 2, 3], \"store\": []}\n    ).to(F.ctx())\n    h = conv(\n        block,\n        (\n            {\"user\": uf, \"game\": gf, \"store\": sf},\n            {\"user\": uf, \"game\": gf, \"store\": sf[0:0]},\n        ),\n    )\n    assert set(h.keys()) == {\"user\", \"game\"}\n    if agg != \"stack\":\n        assert h[\"user\"].shape == (4, 3)\n        assert h[\"game\"].shape == (4, 4)\n    else:\n        assert h[\"user\"].shape == (4, 1, 3)\n        assert h[\"game\"].shape == (4, 2, 4)\n\n    h = conv(block, {\"user\": uf, \"game\": gf, \"store\": sf})\n    assert set(h.keys()) == {\"user\", \"game\"}\n    if agg != \"stack\":\n        assert h[\"user\"].shape == (4, 3)\n        assert h[\"game\"].shape == (4, 4)\n    else:\n        assert h[\"user\"].shape == (4, 1, 3)\n        assert h[\"game\"].shape == (4, 2, 4)\n\n    # test with mod 
args\n    class MyMod(tf.keras.layers.Layer):\n        def __init__(self, s1, s2):\n            super(MyMod, self).__init__()\n            self.carg1 = 0\n            self.carg2 = 0\n            self.s1 = s1\n            self.s2 = s2\n\n        def call(self, g, h, arg1=None, *, arg2=None):\n            if arg1 is not None:\n                self.carg1 += 1\n            if arg2 is not None:\n                self.carg2 += 1\n            return tf.zeros((g.number_of_dst_nodes(), self.s2))\n\n    mod1 = MyMod(2, 3)\n    mod2 = MyMod(2, 4)\n    mod3 = MyMod(3, 4)\n    conv = nn.HeteroGraphConv(\n        {\"follows\": mod1, \"plays\": mod2, \"sells\": mod3}, agg\n    )\n    mod_args = {\"follows\": (1,), \"plays\": (1,)}\n    mod_kwargs = {\"sells\": {\"arg2\": \"abc\"}}\n    h = conv(\n        g,\n        {\"user\": uf, \"game\": gf, \"store\": sf},\n        mod_args=mod_args,\n        mod_kwargs=mod_kwargs,\n    )\n    assert mod1.carg1 == 1\n    assert mod1.carg2 == 0\n    assert mod2.carg1 == 1\n    assert mod2.carg2 == 0\n    assert mod3.carg1 == 0\n    assert mod3.carg2 == 1\n\n    # conv on graph without any edges\n    for etype in g.etypes:\n        g = dgl.remove_edges(g, g.edges(form=\"eid\", etype=etype), etype=etype)\n    assert g.num_edges() == 0\n    h = conv(g, {\"user\": uf, \"game\": gf, \"store\": sf})\n    assert set(h.keys()) == {\"user\", \"game\"}\n\n    block = dgl.to_block(\n        g.to(F.cpu()), {\"user\": [0, 1, 2, 3], \"game\": [0, 1, 2, 3], \"store\": []}\n    ).to(F.ctx())\n    h = conv(\n        block,\n        (\n            {\"user\": uf, \"game\": gf, \"store\": sf},\n            {\"user\": uf, \"game\": gf, \"store\": sf[0:0]},\n        ),\n    )\n    assert set(h.keys()) == {\"user\", \"game\"}\n\n\n@pytest.mark.parametrize(\"out_dim\", [1, 2])\ndef test_dense_cheb_conv(out_dim):\n    for k in range(3, 4):\n        ctx = F.ctx()\n        g = dgl.DGLGraph(\n            sp.sparse.random(100, 100, density=0.1, random_state=42)\n        
)\n        g = g.to(ctx)\n\n        adj = tf.sparse.to_dense(\n            tf.sparse.reorder(g.adj_external(transpose=True, ctx=ctx))\n        )\n        cheb = nn.ChebConv(5, out_dim, k, None, bias=True)\n        dense_cheb = nn.DenseChebConv(5, out_dim, k, bias=True)\n\n        # init cheb modules\n        feat = F.ones((100, 5))\n        out_cheb = cheb(g, feat, [2.0])\n\n        dense_cheb.W = tf.reshape(cheb.linear.weights[0], (k, 5, out_dim))\n        if cheb.linear.bias is not None:\n            dense_cheb.bias = cheb.linear.bias\n\n        out_dense_cheb = dense_cheb(adj, feat, 2.0)\n        print(out_cheb - out_dense_cheb)\n        assert F.allclose(out_cheb, out_dense_cheb)\n\n\nif __name__ == \"__main__\":\n    test_graph_conv()\n    # test_set2set()\n    test_glob_att_pool()\n    test_simple_pool()\n    # test_set_trans()\n    test_rgcn()\n    # test_tagconv()\n    test_gat_conv()\n    test_sage_conv()\n    test_sgc_conv()\n    test_appnp_conv()\n    test_gin_conv()\n    test_edge_conv()\n    # test_agnn_conv()\n    # test_gated_graph_conv()\n    # test_nn_conv()\n    # test_gmm_conv()\n    # test_dense_graph_conv()\n    # test_dense_sage_conv()\n    test_dense_cheb_conv()\n    # test_sequential()\n    test_hetero_conv()\n"
  },
  {
    "path": "tests/python/test_dgl_import.py",
    "content": "import sys\n\n\ndef test_graphbolt_is_not_imported():\n    assert (\n        \"dgl.graphbolt\" not in sys.modules\n    ), \"dgl.graphbolt is already imported\"\n    import dgl\n\n    assert \"dgl.graphbolt\" not in sys.modules, \"dgl.graphbolt is imported\"\n"
  },
  {
    "path": "tests/scripts/build_dgl.bat",
    "content": "@ECHO OFF\nSETLOCAL EnableDelayedExpansion\n\nECHO \"Current user: %USERNAME%\"\n\npython --version\n\nCALL \"C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\BuildTools\\VC\\Auxiliary\\Build\\vcvars64.bat\"\nCALL mkvirtualenv --system-site-packages %BUILD_TAG%\nDEL /S /Q build\nDEL /S /Q _download\nMD build\n\nSET _MSPDBSRV_ENDPOINT_=%BUILD_TAG%\nSET TMP=%WORKSPACE%\\tmp\nSET TEMP=%WORKSPACE%\\tmp\nSET TMPDIR=%WORKSPACE%\\tmp\n\nPUSHD build\ncmake -DCMAKE_CXX_FLAGS=\"/DDGL_EXPORTS\" -Dgtest_force_shared_crt=ON -DDMLC_FORCE_SHARED_CRT=ON -DCMAKE_CONFIGURATION_TYPES=\"Release\" -DTORCH_PYTHON_INTERPS=python .. -G \"Visual Studio 16 2019\" || EXIT /B 1\nmsbuild dgl.sln /m /nr:false || EXIT /B 1\nCOPY /Y Release\\runUnitTests.exe .\nPOPD\n\nCALL workon %BUILD_TAG%\n\nPUSHD python\nDEL /S /Q build *.egg-info dist\npip install -e . || EXIT /B 1\nPOPD\n\nENDLOCAL\nEXIT /B\n"
  },
  {
    "path": "tests/scripts/build_dgl.sh",
    "content": "#!/bin/bash\nset -e\n. /opt/conda/etc/profile.d/conda.sh\n\nif [ $# -ne 1 ]; then\n    echo \"Device argument required, can be cpu, gpu or cugraph\"\n    exit -1\nfi\n\nif [[ $1 != \"cpu\" ]]; then\n    # CI is now running on g4dn instance. Specify target arch to avoid below\n    # error: Unknown CUDA Architecture Name 9.0a in CUDA_SELECT_NVCC_ARCH_FLAGS\n    export TORCH_CUDA_ARCH_LIST=7.5 # For dgl_sparse and tensoradaptor.\n    CMAKE_VARS=\"$CMAKE_VARS -DUSE_CUDA=ON -DCUDA_ARCH_NAME=Turing\" # For graphbolt.\nfi\n\n# This is a semicolon-separated list of Python interpreters containing PyTorch.\n# The value here is for CI.  Replace it with your own or comment this whole\n# statement for default Python interpreter.\nif [ \"$1\" != \"cugraph\" ]; then\n    # We do not build pytorch for cugraph because currently building\n    # pytorch against all the supported cugraph versions is not supported\n    # See issue: https://github.com/rapidsai/cudf/issues/8510\n    CMAKE_VARS=\"$CMAKE_VARS -DTORCH_PYTHON_INTERPS=/opt/conda/envs/pytorch-ci/bin/python\"\nelse\n    # Disable sparse build as cugraph docker image lacks cuDNN.\n    CMAKE_VARS=\"$CMAKE_VARS -DBUILD_TORCH=OFF -DBUILD_SPARSE=OFF\"\nfi\n\nif [ -d build ]; then\n    rm -rf build\nfi\nmkdir build\n\nrm -rf _download\n\npushd build\ncmake $CMAKE_VARS ..\nmake -j\npopd\n\npushd python\nif [[ $1 == \"cugraph\" ]]; then\n    rm -rf build *.egg-info dist\n    pip uninstall -y dgl\n    # test install\n    python3 setup.py install\n    # test inplace build (for cython)\n    python3 setup.py build_ext --inplace\nelse\n    for backend in pytorch mxnet tensorflow\n    do\n    conda activate \"${backend}-ci\"\n    rm -rf build *.egg-info dist\n    pip uninstall -y dgl\n    # test install\n    DGLBACKEND=${backend} python3 setup.py install\n    # test inplace build (for cython)\n    DGLBACKEND=${backend} python3 setup.py build_ext --inplace\n    done\nfi\npopd\n"
  },
  {
    "path": "tests/scripts/ci_report/report.py",
    "content": "import enum\nimport json\nimport os\nimport tempfile\nfrom pathlib import Path\nfrom urllib.parse import urljoin, urlparse\n\nimport pytest\nimport requests\n\n\nclass JobStatus(enum.Enum):\n    SUCCESS = 0\n    FAIL = 1\n    SKIP = 2\n\n\nJENKINS_STATUS_MAPPING = {\n    \"SUCCESS\": JobStatus.SUCCESS,\n    \"ABORTED\": JobStatus.FAIL,\n    \"FAILED\": JobStatus.FAIL,\n    \"IN_PROGRESS\": JobStatus.FAIL,\n    \"NOT_EXECUTED\": JobStatus.SKIP,\n    \"PAUSED_PENDING_INPUT\": JobStatus.SKIP,\n    \"QUEUED\": JobStatus.SKIP,\n    \"UNSTABLE\": JobStatus.FAIL,\n}\n\nassert \"BUILD_URL\" in os.environ, \"Are you in the Jenkins environment?\"\njob_link = os.environ[\"BUILD_URL\"]\nresponse = requests.get(\"{}wfapi\".format(job_link), verify=False).json()\ndomain = \"{uri.scheme}://{uri.netloc}/\".format(uri=urlparse(job_link))\nstages = response[\"stages\"]\n\nfinal_dict = {}\nfailed_nodes = []\nnodes_dict = {}\n\n\ndef get_jenkins_json(path):\n    return requests.get(urljoin(domain, path), verify=False).json()\n\n\nfor stage in stages:\n    link = stage[\"_links\"][\"self\"][\"href\"]\n    stage_name = stage[\"name\"]\n    res = requests.get(urljoin(domain, link), verify=False).json()\n    nodes = res[\"stageFlowNodes\"]\n    for node in nodes:\n        nodes_dict[node[\"id\"]] = node\n        nodes_dict[node[\"id\"]][\"stageName\"] = stage_name\n\n\ndef get_node_full_name(node, node_dict):\n    name = \"\"\n    while \"parentNodes\" in node:\n        name = name + \"/\" + node[\"name\"]\n        id = node[\"parentNodes\"][0]\n        if id in nodes_dict:\n            node = node_dict[id]\n        else:\n            break\n    return name\n\n\nfor key, node in nodes_dict.items():\n    logs = get_jenkins_json(node[\"_links\"][\"log\"][\"href\"]).get(\"text\", \"\")\n    node_name = node[\"name\"]\n    if \"Post Actions\" in node[\"stageName\"]:\n        continue\n    node_status = node[\"status\"]\n    id = node[\"id\"]\n    full_name = 
get_node_full_name(node, nodes_dict)\n    final_dict[\"{}_{}/{}\".format(id, node[\"stageName\"], full_name)] = {\n        \"status\": JENKINS_STATUS_MAPPING[node_status],\n        \"logs\": logs,\n    }\n\nJOB_NAME = os.getenv(\"JOB_NAME\")\nBUILD_NUMBER = os.getenv(\"BUILD_NUMBER\")\nBUILD_ID = os.getenv(\"BUILD_ID\")\n\nprefix = f\"https://dgl-ci-result.s3.us-west-2.amazonaws.com/{JOB_NAME}/{BUILD_NUMBER}/{BUILD_ID}/logs/logs_dir/\"\n\n\n@pytest.mark.parametrize(\"test_name\", final_dict)\ndef test_generate_report(test_name):\n    os.makedirs(\"./logs_dir/\", exist_ok=True)\n    tmp = tempfile.NamedTemporaryFile(\n        mode=\"w\", delete=False, suffix=\".log\", dir=\"./logs_dir/\"\n    )\n    tmp.write(final_dict[test_name][\"logs\"])\n    filename = Path(tmp.name).name\n    # print(final_dict[test_name][\"logs\"])\n    print(\"Log path: {}\".format(prefix + filename))\n\n    if final_dict[test_name][\"status\"] == JobStatus.FAIL:\n        pytest.fail(\n            \"Test failed. Please see the log at {}\".format(prefix + filename)\n        )\n    elif final_dict[test_name][\"status\"] == JobStatus.SKIP:\n        pytest.skip(\n            \"Test skipped. Please see the log at {}\".format(prefix + filename)\n        )\n"
  },
  {
    "path": "tests/scripts/ci_report/status.py",
    "content": "import argparse\nimport os\n\nimport requests\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--result\",\n    type=str,\n    default=\"FAILURE\",\n)\nargs = parser.parse_args()\n\nJOB_NAME = os.getenv(\"JOB_NAME\")\nBUILD_NUMBER = os.getenv(\"BUILD_NUMBER\")\nBUILD_ID = os.getenv(\"BUILD_ID\")\nCOMMIT = os.getenv(\"GIT_COMMIT\")\n\n# List of status of entire job.\n# https://javadoc.jenkins.io/hudson/model/Result.html\nif args.result == \"SUCCESS\":\n    status_output = \"✅ CI test succeeded.\"\nelif args.result == \"NOT_BUILT\":\n    status_output = \"⚪️ CI test cancelled due to overrun.\"\nelif args.result in [\"FAILURE\", \"ABORTED\"]:\n    status_output = \"❌ CI test failed.\"\n    JOB_LINK = os.environ[\"BUILD_URL\"]\n    response = requests.get(\"{}wfapi\".format(JOB_LINK), verify=False).json()\n    for stage in response[\"stages\"]:\n        # List of status of individual stage.\n        # https://javadoc.jenkins.io/plugin/pipeline-graph-analysis/org/jenkinsci/plugins/workflow/pipelinegraphanalysis/GenericStatus.html\n        if stage[\"status\"] in [\"FAILED\", \"ABORTED\"]:\n            stage_name = stage[\"name\"]\n            status_output = f\"❌ CI test failed in Stage [{stage_name}].\"\n            break\nelse:\n    status_output = f\"[Debug Only] CI test with result [{args.result}].\"\n\n\ncomment = f\"\"\"\nCommit ID: {COMMIT}\\n\nBuild ID: {BUILD_ID}\\n\nStatus: {status_output}\\n\nReport path: [link](https://dgl-ci-result.s3.us-west-2.amazonaws.com/{JOB_NAME}/{BUILD_NUMBER}/{BUILD_ID}/logs/report.html)\\n\nFull logs path: [link](https://dgl-ci-result.s3.us-west-2.amazonaws.com/{JOB_NAME}/{BUILD_NUMBER}/{BUILD_ID}/logs/cireport.log)\n\"\"\"\n\nprint(comment)\n"
  },
  {
    "path": "tests/scripts/cugraph_unit_test.sh",
    "content": "#!/bin/bash\n\n. /opt/conda/etc/profile.d/conda.sh\n\nfunction fail {\n    echo FAIL: $@\n    exit -1\n}\n\nexport DGLBACKEND=$1\nexport DGLTESTDEV=gpu\nexport DGL_LIBRARY_PATH=${PWD}/build\nexport PYTHONPATH=tests:${PWD}/python:$PYTHONPATH\nexport DGL_DOWNLOAD_DIR=${PWD}/_download\nexport TF_FORCE_GPU_ALLOW_GROWTH=true\n\nexport CUDA_VISIBLE_DEVICES=0\n\npython3 -m pip install pytest psutil pyyaml pydantic pandas rdflib ogb torchdata || fail \"pip install\"\n\npython3 -m pytest -v --junitxml=pytest_cugraph.xml --durations=20 tests/cugraph || fail \"cugraph\"\n"
  },
  {
    "path": "tests/scripts/task_cpp_unit_test.bat",
    "content": "@ECHO OFF\nSETLOCAL EnableDelayedExpansion\n\nPUSHD build\nrunUnitTests.exe || EXIT /B 1\nPOPD\n"
  },
  {
    "path": "tests/scripts/task_cpp_unit_test.sh",
    "content": "#!/bin/bash\nfunction fail {\n    echo FAIL: $@\n    exit -1\n}\necho $PWD\npushd build\nls -lh\nexport LD_LIBRARY_PATH=$PWD:$LD_LIBRARY_PATH\n./runUnitTests || fail \"CPP unit test\"\npopd\n"
  },
  {
    "path": "tests/scripts/task_dist_test.sh",
    "content": "#!/bin/bash\nfunction fail {\n    echo FAIL: $@\n    exit -1\n}\n\necho $PWD\nexport DGLBACKEND=pytorch\nexport DGL_LIBRARY_PATH=${PWD}/build\nexport PYTHONPATH=${PWD}/tests:${PWD}/python:$PYTHONPATH\nexport LD_LIBRARY_PATH=${PWD}/build:$LD_LIBRARY_PATH\nexport DIST_DGL_TEST_CPP_BIN_DIR=${PWD}/build\nexport DIST_DGL_TEST_IP_CONFIG=/home/ubuntu/workspace/ip_config.txt\nexport DIST_DGL_TEST_PY_BIN_DIR=${PWD}/tests/dist/python\n\nif [[ -v DIST_DGL_TEST_SSH_PORT ]]; then\n    SSH_PORT_LINE=\"-p $DIST_DGL_TEST_SSH_PORT\";\nfi\n\nif [[ -v DIST_DGL_TEST_SSH_KEY ]]; then\n    SSH_KEY_LINE=\"-i $DIST_DGL_TEST_SSH_KEY\";\nfi\n\nif [[ -v DIST_DGL_TEST_SSH_SETUP ]]; then\n    SSH_SETUP_LINE=\"$DIST_DGL_TEST_SSH_SETUP;\";\nfi\n\n\nwhile IFS= read line\ndo\n    for pkg in 'pytest' 'psutil' 'torch'\n    do\n        ret_pkg=$(ssh -o StrictHostKeyChecking=no ${line} ${SSH_PORT_LINE} ${SSH_KEY_LINE} \"${SSH_SETUP_LINE}python3 -m pip list | grep -i ${pkg} \") || fail \"${pkg} not installed in ${line}\"\n    done\ndone < ${DIST_DGL_TEST_IP_CONFIG}\n\npython3 -m pytest -v --capture=tee-sys --junitxml=pytest_dist.xml --durations=100 tests/dist/test_*.py || fail \"dist across machines\"\n"
  },
  {
    "path": "tests/scripts/task_distributed_test.sh",
    "content": "#!/bin/bash\n\n. /opt/conda/etc/profile.d/conda.sh\n\nfunction fail {\n    echo FAIL: $@\n    exit -1\n}\n\nfunction usage {\n    echo \"Usage: $0 backend device\"\n}\n\nif [ $# -ne 2 ]; then\n    usage\n    fail \"Error: must specify backend and device\"\nfi\n\n[ $1 == \"pytorch\" ] || fail \"Distrbuted tests run on pytorch backend only.\"\n[ $2 == \"cpu\" ] || fail \"Distrbuted tests run on cpu only.\"\n\nexport DGLBACKEND=$1\nexport DGLTESTDEV=$2\nexport DGL_LIBRARY_PATH=${PWD}/build\nexport PYTHONPATH=tests:${PWD}/python:$PYTHONPATH\nexport DGL_DOWNLOAD_DIR=${PWD}/_download\nunset TORCH_ALLOW_TF32_CUBLAS_OVERRIDE\n\nexport CUDA_VISIBLE_DEVICES=-1\n\nconda activate ${DGLBACKEND}-ci\n\nexport PYTHONUNBUFFERED=1\nexport OMP_NUM_THREADS=1\nexport DMLC_LOG_DEBUG=1\n\n# Tests for distributed except test_partition.py are skipped due to glitch @2024.06.27.\npython3 -m pytest -v --capture=tee-sys --junitxml=pytest_distributed.xml --durations=100 tests/distributed/test_partition.py || fail \"distributed\"\n\n# Tests for tools are skipped due to glitch.\n#PYTHONPATH=tools:tools/distpartitioning:$PYTHONPATH python3 -m pytest -v --capture=tee-sys --junitxml=pytest_tools.xml --durations=100 tests/tools/*.py || fail \"tools\"\n"
  },
  {
    "path": "tests/scripts/task_example_test.bat",
    "content": "@ECHO OFF\nSETLOCAL EnableDelayedExpansion\n\nSET GCN_EXAMPLE_DIR=.\\examples\\pytorch\n\nIF x%1x==xx (\n\tECHO Must supply CPU or GPU\n\tGOTO :FAIL\n) ELSE IF x%1x==xcpux (\n\tSET DEV=-1\n) ELSE IF x%1x==xgpux (\n\tSET DEV=0\n\tSET CUDA_VISIBLE_DEVICES=0\n) ELSE (\n\tECHO Must supply CPU or GPU\n\tGOTO :FAIL\n)\nCALL workon %BUILD_TAG%\n\nSET DGLBACKEND=pytorch\nSET DGL_LIBRARY_PATH=!CD!\\build\nSET PYTHONPATH=!CD!\\python;!PYTHONPATH!\nSET DGL_DOWNLOAD_DIR=!CD!\\_download\n\npython -m pytest -v --junitxml=pytest_backend.xml --durations=100 tests\\examples || GOTO :FAIL\n\nPUSHD !GCN_EXAMPLE_DIR!\npython pagerank.py || GOTO :FAIL\npython gcn\\train.py --dataset cora || GOTO :FAIL\nPOPD\nENDLOCAL\nEXIT /B\n\n:FAIL\nECHO Example test failed\nENDLOCAL\nEXIT /B 1\n"
  },
  {
    "path": "tests/scripts/task_example_test.sh",
    "content": "#!/bin/bash\n\n. /opt/conda/etc/profile.d/conda.sh\nconda activate pytorch-ci\nGCN_EXAMPLE_DIR=\"./examples/pytorch/\"\n\nfunction fail {\n    echo FAIL: $@\n    exit -1\n}\n\nfunction usage {\n    echo \"Usage: $0 [cpu|gpu]\"\n}\n\n# check arguments\nif [ $# -ne 1 ]; then\n    usage\n    fail \"Error: must specify device\"\nfi\n\nif [ \"$1\" == \"cpu\" ]; then\n    dev=-1\nelif [ \"$1\" == \"gpu\" ]; then\n    export CUDA_VISIBLE_DEVICES=0\n    dev=0\nelse\n    usage\n    fail \"Unknown device $1\"\nfi\n\nexport DGLBACKEND=pytorch\nexport DGL_LIBRARY_PATH=${PWD}/build\nexport PYTHONPATH=${PWD}/python:$PYTHONPATH\nexport DGL_DOWNLOAD_DIR=${PWD}/_download\n\n# test\n\npython3 -m pytest -v --junitxml=pytest_backend.xml --durations=100 tests/examples || fail \"sparse examples on $1\"\n\npushd $GCN_EXAMPLE_DIR> /dev/null\n\npython3 pagerank.py || fail \"run pagerank.py on $1\"\npython3 gcn/train.py --dataset cora || fail \"run gcn/train.py on $1\"\npython3 lda/lda_model.py || fail \"run lda/lda_model.py on $1\"\n\npopd > /dev/null\n"
  },
  {
    "path": "tests/scripts/task_go_test.sh",
    "content": "#!/bin/bash\n\n. /opt/conda/etc/profile.d/conda.sh\n\nfunction fail {\n    echo FAIL: $@\n    exit -1\n}\n\nexport DGLBACKEND=pytorch\nexport DGL_LIBRARY_PATH=${PWD}/build\nexport PYTHONPATH=tests:${PWD}/python:$PYTHONPATH\nexport DGL_DOWNLOAD_DIR=${PWD}/_download\n\nconda activate pytorch-ci\n\npushd dglgo\nrm -rf build *.egg-info dist\npip uninstall -y dglgo\npython3 setup.py install\npopd\n\nexport LC_ALL=C.UTF-8\nexport LANG=C.UTF-8\n\n# Skip go tests due to ImportError: cannot import name 'cached_property' from 'functools' in python3.7\n#python3 -m pytest -v --junitxml=pytest_go.xml --durations=100 tests/go/test_model.py || fail \"go\"\n"
  },
  {
    "path": "tests/scripts/task_lint.sh",
    "content": "#!/bin/bash\n\n# cpplint\necho 'Checking code style of C++ codes...'\npython3 tests/lint/lint.py dgl cpp include src || exit 1\npython3 tests/lint/lint.py dgl_sparse cpp dgl_sparse/include dgl_sparse/src || exit 1\n\n# pylint\necho 'Checking code style of python codes...'\npython3 -m pylint --reports=y -v --rcfile=tests/lint/pylintrc python/dgl || exit 1\n"
  },
  {
    "path": "tests/scripts/task_pytorch_tutorial_test.sh",
    "content": "#!/bin/bash\n# The working directory for this script will be \"tests/scripts\"\n\n. /opt/conda/etc/profile.d/conda.sh\nconda activate pytorch-ci\nTUTORIAL_ROOT=\"./tutorials\"\n\nfunction fail {\n    echo FAIL: $@\n    exit -1\n}\n\nexport MPLBACKEND=Agg\nexport DGLBACKEND=pytorch\nexport DGL_LIBRARY_PATH=${PWD}/build\nexport PYTHONPATH=${PWD}/python:$PYTHONPATH\nexport DGL_DOWNLOAD_DIR=${PWD}/_download\n\npushd ${TUTORIAL_ROOT} > /dev/null\n# Install requirements\npip install -r requirements.txt || fail \"installing requirements\"\n\n# Test\nfor f in $(find . -path ./dist -prune -false -o -name \"*.py\" ! -name \"*_mx.py\")\ndo\n    echo \"Running tutorial ${f} ...\"\n    python3 $f || fail \"run ${f}\"\ndone\n\npopd > /dev/null\n"
  },
  {
    "path": "tests/scripts/task_unit_test.bat",
    "content": "@ECHO OFF\nSETLOCAL EnableDelayedExpansion\n\nIF x%1x==xx (\n\tECHO Specify backend\n\tEXIT /B 1\n) ELSE (\n\tSET BACKEND=%1\n)\nCALL workon %BUILD_TAG%\n\nSET PYTHONPATH=tests;!CD!\\python;!PYTHONPATH!\nSET DGLBACKEND=!BACKEND!\nSET DGL_LIBRARY_PATH=!CD!\\build\nSET DGL_DOWNLOAD_DIR=!CD!\\_download\n\npython -m pip install pytest psutil pandas pyyaml pydantic rdflib torchmetrics expecttest || EXIT /B 1\npython -m pytest -v --junitxml=pytest_backend.xml --durations=100 tests\\python\\!DGLBACKEND! || EXIT /B 1\npython -m pytest -v --junitxml=pytest_common.xml --durations=100 tests\\python\\common || EXIT /B 1\nENDLOCAL\nEXIT /B\n"
  },
  {
    "path": "tests/scripts/task_unit_test.sh",
    "content": "#!/bin/bash\n\n. /opt/conda/etc/profile.d/conda.sh\n\nfunction fail {\n    echo FAIL: $@\n    exit -1\n}\n\nfunction usage {\n    echo \"Usage: $0 backend device\"\n}\n\nif [ $# -ne 2 ]; then\n    usage\n    fail \"Error: must specify backend and device\"\nfi\n\nexport DGLBACKEND=$1\nexport DGLTESTDEV=$2\nexport DGL_LIBRARY_PATH=${PWD}/build\nexport PYTHONPATH=tests:${PWD}/python:$PYTHONPATH\nexport DGL_DOWNLOAD_DIR=${PWD}/_download\nexport TF_FORCE_GPU_ALLOW_GROWTH=true\nunset TORCH_ALLOW_TF32_CUBLAS_OVERRIDE\n\nif [ $2 == \"gpu\" ] \nthen\n  export CUDA_VISIBLE_DEVICES=0\nelse\n  export CUDA_VISIBLE_DEVICES=-1\nfi\n\nconda activate ${DGLBACKEND}-ci\n\npython3 -m pip install expecttest\n\nif [ $DGLBACKEND == \"mxnet\" ]\nthen\n  python3 -m pytest -v --junitxml=pytest_compute.xml --durations=100 --ignore=tests/python/common/test_ffi.py tests/python/common || fail \"common\"\nelse\n  python3 -m pytest -v --junitxml=pytest_dgl_import.xml tests/python/test_dgl_import.py || fail \"dgl_import\"\n  python3 -m pytest -v --junitxml=pytest_common.xml --durations=100 tests/python/common || fail \"common\"\nfi\npython3 -m pytest -v --junitxml=pytest_backend.xml --durations=100 tests/python/$DGLBACKEND || fail \"backend-specific\"\n"
  },
  {
    "path": "tests/tools/pytest_utils.py",
    "content": "import json\nimport logging\nimport os\n\nimport dgl\nimport numpy as np\nimport torch\nfrom distpartitioning import array_readwriter\nfrom distpartitioning.array_readwriter.parquet import ParquetArrayParser\nfrom files import setdir\n\n\ndef _chunk_numpy_array(arr, fmt_meta, chunk_sizes, path_fmt, vector_rows=False):\n    paths = []\n    offset = 0\n\n    for j, n in enumerate(chunk_sizes):\n        path = os.path.abspath(path_fmt % j)\n        arr_chunk = arr[offset : offset + n]\n        shape = arr_chunk.shape\n        logging.info(\"Chunking %d-%d\" % (offset, offset + n))\n        # If requested we write multi-column arrays as single-column vector Parquet files\n        array_parser = array_readwriter.get_array_parser(**fmt_meta)\n        if (\n            isinstance(array_parser, ParquetArrayParser)\n            and len(shape) > 1\n            and shape[1] > 1\n        ):\n            array_parser.write(path, arr_chunk, vector_rows=vector_rows)\n        else:\n            array_parser.write(path, arr_chunk)\n        offset += n\n        paths.append(path)\n\n    return paths\n\n\ndef _initialize_num_chunks(g, num_chunks, kwargs=None):\n    \"\"\"Initialize num_chunks for each node/edge.\n\n    Parameters\n    ----------\n    g: DGLGraph\n        Graph to be chunked.\n    num_chunks: int\n        Default number of chunks to be applied onto node/edge data.\n    kwargs: dict\n        Key word arguments to specify details for each node/edge data.\n\n    Returns\n    -------\n    num_chunks_data: dict\n        Detailed number of chunks for each node/edge.\n    \"\"\"\n\n    def _init(g, num_chunks, key, kwargs=None):\n        chunks_data = kwargs.get(key, None)\n        is_node = \"_node\" in key\n        data_types = g.ntypes if is_node else g.canonical_etypes\n        if isinstance(chunks_data, int):\n            chunks_data = {data_type: chunks_data for data_type in data_types}\n        elif isinstance(chunks_data, dict):\n            for 
data_type in data_types:\n                if data_type not in chunks_data:\n                    chunks_data[data_type] = num_chunks\n        else:\n            chunks_data = {data_type: num_chunks for data_type in data_types}\n        for _, data in chunks_data.items():\n            if isinstance(data, dict):\n                n_chunks = list(data.values())\n            else:\n                n_chunks = [data]\n            assert all(\n                isinstance(v, int) for v in n_chunks\n            ), \"num_chunks for each data type should be int.\"\n        return chunks_data\n\n    num_chunks_data = {}\n    for key in [\n        \"num_chunks_nodes\",\n        \"num_chunks_edges\",\n        \"num_chunks_node_data\",\n        \"num_chunks_edge_data\",\n    ]:\n        num_chunks_data[key] = _init(g, num_chunks, key, kwargs=kwargs)\n    return num_chunks_data\n\n\ndef _chunk_graph(\n    g,\n    name,\n    ndata_paths,\n    edata_paths,\n    num_chunks,\n    data_fmt,\n    edges_format,\n    vector_rows=False,\n    **kwargs,\n):\n    # First deal with ndata and edata that are homogeneous\n    # (i.e. 
not a dict-of-dict)\n    if len(g.ntypes) == 1 and not isinstance(\n        next(iter(ndata_paths.values())), dict\n    ):\n        ndata_paths = {g.ntypes[0]: ndata_paths}\n    if len(g.etypes) == 1 and not isinstance(\n        next(iter(edata_paths.values())), dict\n    ):\n        edata_paths = {g.etypes[0]: ndata_paths}\n    # Then convert all edge types to canonical edge types\n    etypestrs = {etype: \":\".join(etype) for etype in g.canonical_etypes}\n    edata_paths = {\n        \":\".join(g.to_canonical_etype(k)): v for k, v in edata_paths.items()\n    }\n\n    metadata = {}\n\n    metadata[\"graph_name\"] = name\n    metadata[\"node_type\"] = g.ntypes\n\n    # add node_type_counts\n    metadata[\"num_nodes_per_type\"] = [g.num_nodes(ntype) for ntype in g.ntypes]\n\n    # Initialize num_chunks for each node/edge.\n    num_chunks_details = _initialize_num_chunks(g, num_chunks, kwargs=kwargs)\n\n    # Compute the number of nodes per chunk per node type\n    metadata[\"num_nodes_per_chunk\"] = num_nodes_per_chunk = []\n    num_chunks_nodes = num_chunks_details[\"num_chunks_nodes\"]\n    for ntype in g.ntypes:\n        num_nodes = g.num_nodes(ntype)\n        num_nodes_list = []\n        n_chunks = num_chunks_nodes[ntype]\n        for i in range(n_chunks):\n            n = num_nodes // n_chunks + (i < num_nodes % n_chunks)\n            num_nodes_list.append(n)\n        num_nodes_per_chunk.append(num_nodes_list)\n\n    metadata[\"edge_type\"] = [etypestrs[etype] for etype in g.canonical_etypes]\n    metadata[\"num_edges_per_type\"] = [\n        g.num_edges(etype) for etype in g.canonical_etypes\n    ]\n\n    # Compute the number of edges per chunk per edge type\n    metadata[\"num_edges_per_chunk\"] = num_edges_per_chunk = []\n    num_chunks_edges = num_chunks_details[\"num_chunks_edges\"]\n    for etype in g.canonical_etypes:\n        num_edges = g.num_edges(etype)\n        num_edges_list = []\n        n_chunks = num_chunks_edges[etype]\n        for i in 
range(n_chunks):\n            n = num_edges // n_chunks + (i < num_edges % n_chunks)\n            num_edges_list.append(n)\n        num_edges_per_chunk.append(num_edges_list)\n    num_edges_per_chunk_dict = {\n        k: v for k, v in zip(g.canonical_etypes, num_edges_per_chunk)\n    }\n\n    idxes_etypestr = {\n        idx: (etype, etypestrs[etype])\n        for idx, etype in enumerate(g.canonical_etypes)\n    }\n    idxes = np.arange(len(idxes_etypestr))\n\n    # Split edge index\n    metadata[\"edges\"] = {}\n    with setdir(\"edge_index\"):\n        np.random.shuffle(idxes)\n        for idx in idxes:\n            etype = idxes_etypestr[idx][0]\n            etypestr = idxes_etypestr[idx][1]\n            logging.info(\"Chunking edge index for %s\" % etypestr)\n            edges_meta = {}\n            if edges_format == \"csv\":\n                fmt_meta = {\"name\": edges_format, \"delimiter\": \" \"}\n            elif edges_format == \"parquet\":\n                fmt_meta = {\"name\": edges_format}\n            else:\n                raise RuntimeError(f\"Invalid edges_fmt: {edges_format}\")\n            edges_meta[\"format\"] = fmt_meta\n\n            srcdst = torch.stack(g.edges(etype=etype), 1)\n            edges_meta[\"data\"] = _chunk_numpy_array(\n                srcdst.numpy(),\n                fmt_meta,\n                num_edges_per_chunk_dict[etype],\n                etypestr + \"%d.txt\",\n            )\n            metadata[\"edges\"][etypestr] = edges_meta\n\n    # Chunk node data\n    reader_fmt_meta, writer_fmt_meta = {\"name\": \"numpy\"}, {\"name\": data_fmt}\n    file_suffix = \"npy\" if data_fmt == \"numpy\" else \"parquet\"\n    metadata[\"node_data\"] = {}\n    num_chunks_node_data = num_chunks_details[\"num_chunks_node_data\"]\n    with setdir(\"node_data\"):\n        for ntype, ndata_per_type in ndata_paths.items():\n            ndata_meta = {}\n            with setdir(ntype):\n                for key, path in ndata_per_type.items():\n     
               logging.info(\n                        \"Chunking node data for type %s key %s\" % (ntype, key)\n                    )\n                    chunk_sizes = []\n                    num_nodes = g.num_nodes(ntype)\n                    n_chunks = num_chunks_node_data[ntype]\n                    if isinstance(n_chunks, dict):\n                        n_chunks = n_chunks.get(key, num_chunks)\n                    assert isinstance(n_chunks, int), (\n                        f\"num_chunks for {ntype}/{key} should be int while \"\n                        f\"{type(n_chunks)} is got.\"\n                    )\n                    for i in range(n_chunks):\n                        n = num_nodes // n_chunks + (i < num_nodes % n_chunks)\n                        chunk_sizes.append(n)\n                    ndata_key_meta = {}\n                    arr = array_readwriter.get_array_parser(\n                        **reader_fmt_meta\n                    ).read(path)\n                    ndata_key_meta[\"format\"] = writer_fmt_meta\n                    ndata_key_meta[\"data\"] = _chunk_numpy_array(\n                        arr,\n                        writer_fmt_meta,\n                        chunk_sizes,\n                        key + \"-%d.\" + file_suffix,\n                        vector_rows=vector_rows,\n                    )\n                    ndata_meta[key] = ndata_key_meta\n\n            metadata[\"node_data\"][ntype] = ndata_meta\n\n    # Chunk edge data\n    metadata[\"edge_data\"] = {}\n    num_chunks_edge_data = num_chunks_details[\"num_chunks_edge_data\"]\n    with setdir(\"edge_data\"):\n        for etypestr, edata_per_type in edata_paths.items():\n            edata_meta = {}\n            etype = tuple(etypestr.split(\":\"))\n            with setdir(etypestr):\n                for key, path in edata_per_type.items():\n                    logging.info(\n                        \"Chunking edge data for type %s key %s\"\n                        % (etypestr, 
key)\n                    )\n                    chunk_sizes = []\n                    num_edges = g.num_edges(etype)\n                    n_chunks = num_chunks_edge_data[etype]\n                    if isinstance(n_chunks, dict):\n                        n_chunks = n_chunks.get(key, num_chunks)\n                    assert isinstance(n_chunks, int), (\n                        f\"num_chunks for {etype}/{key} should be int while \"\n                        f\"{type(n_chunks)} is got.\"\n                    )\n                    for i in range(n_chunks):\n                        n = num_edges // n_chunks + (i < num_edges % n_chunks)\n                        chunk_sizes.append(n)\n                    edata_key_meta = {}\n                    arr = array_readwriter.get_array_parser(\n                        **reader_fmt_meta\n                    ).read(path)\n                    edata_key_meta[\"format\"] = writer_fmt_meta\n                    edata_key_meta[\"data\"] = _chunk_numpy_array(\n                        arr,\n                        writer_fmt_meta,\n                        chunk_sizes,\n                        key + \"-%d.\" + file_suffix,\n                        vector_rows=vector_rows,\n                    )\n                    edata_meta[key] = edata_key_meta\n\n            metadata[\"edge_data\"][etypestr] = edata_meta\n\n    metadata_path = \"metadata.json\"\n    with open(metadata_path, \"w\") as f:\n        json.dump(metadata, f, sort_keys=True, indent=4)\n    logging.info(\"Saved metadata in %s\" % os.path.abspath(metadata_path))\n\n\ndef chunk_graph(\n    g,\n    name,\n    ndata_paths,\n    edata_paths,\n    num_chunks,\n    output_path,\n    data_fmt=\"numpy\",\n    edges_fmt=\"csv\",\n    vector_rows=False,\n    **kwargs,\n):\n    \"\"\"\n    Split the graph into multiple chunks.\n\n    A directory will be created at :attr:`output_path` with the metadata and\n    chunked edge list as well as the node/edge data.\n\n    Parameters\n    
----------\n    g : DGLGraph\n        The graph.\n    name : str\n        The name of the graph, to be used later in DistDGL training.\n    ndata_paths : dict[str, pathlike] or dict[ntype, dict[str, pathlike]]\n        The dictionary of paths pointing to the corresponding numpy array file\n        for each node data key.\n    edata_paths : dict[etype, pathlike] or dict[etype, dict[str, pathlike]]\n        The dictionary of paths pointing to the corresponding numpy array file\n        for each edge data key. ``etype`` could be canonical or non-canonical.\n    num_chunks : int\n        The number of chunks\n    output_path : pathlike\n        The output directory saving the chunked graph.\n    data_fmt : str\n        Format of node/edge data: 'numpy' or 'parquet'.\n    edges_fmt : str\n        Format of edges files: 'csv' or 'parquet'.\n    vector_rows : str\n        When true will write parquet files as single-column vector row files.\n    kwargs : dict\n        Key word arguments to control chunk details.\n    \"\"\"\n    for ntype, ndata in ndata_paths.items():\n        for key in ndata.keys():\n            ndata[key] = os.path.abspath(ndata[key])\n    for etype, edata in edata_paths.items():\n        for key in edata.keys():\n            edata[key] = os.path.abspath(edata[key])\n    with setdir(output_path):\n        _chunk_graph(\n            g,\n            name,\n            ndata_paths,\n            edata_paths,\n            num_chunks,\n            data_fmt,\n            edges_fmt,\n            vector_rows,\n            **kwargs,\n        )\n\n\ndef create_chunked_dataset(\n    root_dir,\n    num_chunks,\n    data_fmt=\"numpy\",\n    edges_fmt=\"csv\",\n    vector_rows=False,\n    **kwargs,\n):\n    \"\"\"\n    This function creates a sample dataset, based on MAG240 dataset.\n\n    Parameters:\n    -----------\n    root_dir : string\n        directory in which all the files for the chunked dataset will be stored.\n    \"\"\"\n    # Step0: prepare chunked 
graph data format.\n    # A synthetic mini MAG240.\n    num_institutions = 1200\n    num_authors = 1200\n    num_papers = 1200\n\n    def rand_edges(num_src, num_dst, num_edges):\n        eids = np.random.choice(num_src * num_dst, num_edges, replace=False)\n        src = torch.from_numpy(eids // num_dst)\n        dst = torch.from_numpy(eids % num_dst)\n\n        return src, dst\n\n    num_cite_edges = 24 * 1000\n    num_write_edges = 12 * 1000\n    num_affiliate_edges = 2400\n\n    # Structure.\n    data_dict = {\n        (\"paper\", \"cites\", \"paper\"): rand_edges(\n            num_papers, num_papers, num_cite_edges\n        ),\n        (\"author\", \"writes\", \"paper\"): rand_edges(\n            num_authors, num_papers, num_write_edges\n        ),\n        (\"author\", \"affiliated_with\", \"institution\"): rand_edges(\n            num_authors, num_institutions, num_affiliate_edges\n        ),\n        (\"institution\", \"writes\", \"paper\"): rand_edges(\n            num_institutions, num_papers, num_write_edges\n        ),\n    }\n    src, dst = data_dict[(\"author\", \"writes\", \"paper\")]\n    data_dict[(\"paper\", \"rev_writes\", \"author\")] = (dst, src)\n    g = dgl.heterograph(data_dict)\n\n    # paper feat, label, year\n    num_paper_feats = 3\n    paper_feat = np.random.randn(num_papers, num_paper_feats)\n    num_classes = 4\n    paper_label = np.random.choice(num_classes, num_papers)\n    paper_year = np.random.choice(2022, num_papers)\n    paper_orig_ids = np.arange(0, num_papers)\n    writes_orig_ids = np.arange(0, num_write_edges)\n\n    # masks.\n    paper_train_mask = np.random.choice([True, False], num_papers)\n    paper_test_mask = np.random.choice([True, False], num_papers)\n    paper_val_mask = np.random.choice([True, False], num_papers)\n\n    author_train_mask = np.random.choice([True, False], num_authors)\n    author_test_mask = np.random.choice([True, False], num_authors)\n    author_val_mask = np.random.choice([True, False], 
num_authors)\n\n    inst_train_mask = np.random.choice([True, False], num_institutions)\n    inst_test_mask = np.random.choice([True, False], num_institutions)\n    inst_val_mask = np.random.choice([True, False], num_institutions)\n\n    write_train_mask = np.random.choice([True, False], num_write_edges)\n    write_test_mask = np.random.choice([True, False], num_write_edges)\n    write_val_mask = np.random.choice([True, False], num_write_edges)\n\n    # Edge features.\n    cite_count = np.random.choice(10, num_cite_edges)\n    write_year = np.random.choice(2022, num_write_edges)\n    write2_year = np.random.choice(2022, num_write_edges)\n\n    # Save features.\n    input_dir = os.path.join(root_dir, \"data_test\")\n    os.makedirs(input_dir)\n    for sub_d in [\"paper\", \"cites\", \"writes\", \"writes2\"]:\n        os.makedirs(os.path.join(input_dir, sub_d))\n\n    paper_feat_path = os.path.join(input_dir, \"paper/feat.npy\")\n    with open(paper_feat_path, \"wb\") as f:\n        np.save(f, paper_feat)\n    g.nodes[\"paper\"].data[\"feat\"] = torch.from_numpy(paper_feat)\n\n    paper_label_path = os.path.join(input_dir, \"paper/label.npy\")\n    with open(paper_label_path, \"wb\") as f:\n        np.save(f, paper_label)\n    g.nodes[\"paper\"].data[\"label\"] = torch.from_numpy(paper_label)\n\n    paper_year_path = os.path.join(input_dir, \"paper/year.npy\")\n    with open(paper_year_path, \"wb\") as f:\n        np.save(f, paper_year)\n    g.nodes[\"paper\"].data[\"year\"] = torch.from_numpy(paper_year)\n\n    paper_orig_ids_path = os.path.join(input_dir, \"paper/orig_ids.npy\")\n    with open(paper_orig_ids_path, \"wb\") as f:\n        np.save(f, paper_orig_ids)\n    g.nodes[\"paper\"].data[\"orig_ids\"] = torch.from_numpy(paper_orig_ids)\n\n    cite_count_path = os.path.join(input_dir, \"cites/count.npy\")\n    with open(cite_count_path, \"wb\") as f:\n        np.save(f, cite_count)\n    g.edges[\"cites\"].data[\"count\"] = torch.from_numpy(cite_count)\n\n    
write_year_path = os.path.join(input_dir, \"writes/year.npy\")\n    with open(write_year_path, \"wb\") as f:\n        np.save(f, write_year)\n    g.edges[(\"author\", \"writes\", \"paper\")].data[\"year\"] = torch.from_numpy(\n        write_year\n    )\n    g.edges[\"rev_writes\"].data[\"year\"] = torch.from_numpy(write_year)\n\n    writes_orig_ids_path = os.path.join(input_dir, \"writes/orig_ids.npy\")\n    with open(writes_orig_ids_path, \"wb\") as f:\n        np.save(f, writes_orig_ids)\n    g.edges[(\"author\", \"writes\", \"paper\")].data[\"orig_ids\"] = torch.from_numpy(\n        writes_orig_ids\n    )\n\n    write2_year_path = os.path.join(input_dir, \"writes2/year.npy\")\n    with open(write2_year_path, \"wb\") as f:\n        np.save(f, write2_year)\n    g.edges[(\"institution\", \"writes\", \"paper\")].data[\"year\"] = torch.from_numpy(\n        write2_year\n    )\n\n    etype = (\"author\", \"writes\", \"paper\")\n    write_train_mask_path = os.path.join(input_dir, \"writes/train_mask.npy\")\n    with open(write_train_mask_path, \"wb\") as f:\n        np.save(f, write_train_mask)\n    g.edges[etype].data[\"train_mask\"] = torch.from_numpy(write_train_mask)\n\n    write_test_mask_path = os.path.join(input_dir, \"writes/test_mask.npy\")\n    with open(write_test_mask_path, \"wb\") as f:\n        np.save(f, write_test_mask)\n    g.edges[etype].data[\"test_mask\"] = torch.from_numpy(write_test_mask)\n\n    write_val_mask_path = os.path.join(input_dir, \"writes/val_mask.npy\")\n    with open(write_val_mask_path, \"wb\") as f:\n        np.save(f, write_val_mask)\n    g.edges[etype].data[\"val_mask\"] = torch.from_numpy(write_val_mask)\n\n    for sub_d in [\"author\", \"institution\"]:\n        os.makedirs(os.path.join(input_dir, sub_d))\n    paper_train_mask_path = os.path.join(input_dir, \"paper/train_mask.npy\")\n    with open(paper_train_mask_path, \"wb\") as f:\n        np.save(f, paper_train_mask)\n    g.nodes[\"paper\"].data[\"train_mask\"] = 
torch.from_numpy(paper_train_mask)\n\n    paper_test_mask_path = os.path.join(input_dir, \"paper/test_mask.npy\")\n    with open(paper_test_mask_path, \"wb\") as f:\n        np.save(f, paper_test_mask)\n    g.nodes[\"paper\"].data[\"test_mask\"] = torch.from_numpy(paper_test_mask)\n\n    paper_val_mask_path = os.path.join(input_dir, \"paper/val_mask.npy\")\n    with open(paper_val_mask_path, \"wb\") as f:\n        np.save(f, paper_val_mask)\n    g.nodes[\"paper\"].data[\"val_mask\"] = torch.from_numpy(paper_val_mask)\n\n    author_train_mask_path = os.path.join(input_dir, \"author/train_mask.npy\")\n    with open(author_train_mask_path, \"wb\") as f:\n        np.save(f, author_train_mask)\n    g.nodes[\"author\"].data[\"train_mask\"] = torch.from_numpy(author_train_mask)\n\n    author_test_mask_path = os.path.join(input_dir, \"author/test_mask.npy\")\n    with open(author_test_mask_path, \"wb\") as f:\n        np.save(f, author_test_mask)\n    g.nodes[\"author\"].data[\"test_mask\"] = torch.from_numpy(author_test_mask)\n\n    author_val_mask_path = os.path.join(input_dir, \"author/val_mask.npy\")\n    with open(author_val_mask_path, \"wb\") as f:\n        np.save(f, author_val_mask)\n    g.nodes[\"author\"].data[\"val_mask\"] = torch.from_numpy(author_val_mask)\n\n    inst_train_mask_path = os.path.join(input_dir, \"institution/train_mask.npy\")\n    with open(inst_train_mask_path, \"wb\") as f:\n        np.save(f, inst_train_mask)\n    g.nodes[\"institution\"].data[\"train_mask\"] = torch.from_numpy(\n        inst_train_mask\n    )\n\n    inst_test_mask_path = os.path.join(input_dir, \"institution/test_mask.npy\")\n    with open(inst_test_mask_path, \"wb\") as f:\n        np.save(f, inst_test_mask)\n    g.nodes[\"institution\"].data[\"test_mask\"] = torch.from_numpy(inst_test_mask)\n\n    inst_val_mask_path = os.path.join(input_dir, \"institution/val_mask.npy\")\n    with open(inst_val_mask_path, \"wb\") as f:\n        np.save(f, inst_val_mask)\n    
g.nodes[\"institution\"].data[\"val_mask\"] = torch.from_numpy(inst_val_mask)\n\n    node_data = {\n        \"paper\": {\n            \"feat\": paper_feat_path,\n            \"train_mask\": paper_train_mask_path,\n            \"test_mask\": paper_test_mask_path,\n            \"val_mask\": paper_val_mask_path,\n            \"label\": paper_label_path,\n            \"year\": paper_year_path,\n            \"orig_ids\": paper_orig_ids_path,\n        },\n        \"author\": {\n            \"train_mask\": author_train_mask_path,\n            \"test_mask\": author_test_mask_path,\n            \"val_mask\": author_val_mask_path,\n        },\n        \"institution\": {\n            \"train_mask\": inst_train_mask_path,\n            \"test_mask\": inst_test_mask_path,\n            \"val_mask\": inst_val_mask_path,\n        },\n    }\n\n    edge_data = {\n        \"cites\": {\"count\": cite_count_path},\n        (\"author\", \"writes\", \"paper\"): {\n            \"year\": write_year_path,\n            \"orig_ids\": writes_orig_ids_path,\n            \"train_mask\": write_train_mask_path,\n            \"test_mask\": write_test_mask_path,\n            \"val_mask\": write_val_mask_path,\n        },\n        \"rev_writes\": {\"year\": write_year_path},\n        (\"institution\", \"writes\", \"paper\"): {\"year\": write2_year_path},\n    }\n\n    output_dir = os.path.join(root_dir, \"chunked-data\")\n    chunk_graph(\n        g,\n        \"mag240m\",\n        node_data,\n        edge_data,\n        num_chunks=num_chunks,\n        output_path=output_dir,\n        data_fmt=data_fmt,\n        edges_fmt=edges_fmt,\n        vector_rows=vector_rows,\n        **kwargs,\n    )\n    logging.debug(\"Done with creating chunked graph\")\n\n    return g\n"
  },
  {
    "path": "tests/tools/test_array_readwriter.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\nimport pytest\nfrom distpartitioning import array_readwriter\n\n\n@pytest.mark.parametrize(\n    \"shape\", [[500], [300, 10], [200, 5, 5], [100, 5, 5, 5]]\n)\n@pytest.mark.parametrize(\"format\", [\"numpy\", \"parquet\"])\ndef test_array_readwriter(format, shape):\n    original_array = np.random.rand(*shape)\n    fmt_meta = {\"name\": format}\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        path = os.path.join(test_dir, f\"nodes.{format}\")\n        array_readwriter.get_array_parser(**fmt_meta).write(\n            path, original_array\n        )\n        array = array_readwriter.get_array_parser(**fmt_meta).read(path)\n\n        assert original_array.shape == array.shape\n        assert np.array_equal(original_array, array)\n"
  },
  {
    "path": "tests/tools/test_change_etype_to_canonical_etype.py",
    "content": "import json\nimport os\nimport tempfile\nimport unittest\nfrom collections import Counter\n\nimport dgl\n\nimport pytest\nfrom change_etype_to_canonical_etype import convert_conf, is_old_version\nfrom dgl.distributed import partition_graph\nfrom scipy import sparse as spsp\n\n\ndef create_random_hetero(type_n, node_n):\n    num_nodes = {}\n    for i in range(1, type_n + 1):\n        num_nodes[f\"n{i}\"] = node_n\n    c_etypes = []\n    count = 0\n    for i in range(1, type_n):\n        for j in range(i + 1, type_n + 1):\n            count += 1\n            c_etypes.append((f\"n{i}\", f\"r{count}\", f\"n{j}\"))\n    edges = {}\n    for etype in c_etypes:\n        src_ntype, _, dst_ntype = etype\n        arr = spsp.random(\n            num_nodes[src_ntype],\n            num_nodes[dst_ntype],\n            density=0.001,\n            format=\"coo\",\n            random_state=100,\n        )\n        edges[etype] = (arr.row, arr.col)\n    return dgl.heterograph(edges, num_nodes), [\n        \":\".join(c_etype) for c_etype in c_etypes\n    ]\n\n\n@unittest.skip(reason=\"Skip due to glitch in CI\")\n@pytest.mark.parametrize(\n    \"type_n, node_n, num_parts\", [[3, 100, 2], [10, 500, 4], [10, 1000, 8]]\n)\ndef test_hetero_graph(type_n, node_n, num_parts):\n    g, expected_c_etypes = create_random_hetero(type_n, node_n)\n    do_convert_and_check(g, \"convert_conf_test\", num_parts, expected_c_etypes)\n\n\n@unittest.skip(reason=\"Skip due to glitch in CI\")\n@pytest.mark.parametrize(\"node_n, num_parts\", [[100, 2], [500, 4]])\ndef test_homo_graph(node_n, num_parts):\n    g = dgl.rand_graph(node_n, node_n // 10)\n    do_convert_and_check(g, \"convert_conf_test\", num_parts, [\"_N:_E:_N\"])\n\n\ndef do_convert_and_check(g, graph_name, num_parts, expected_c_etypes):\n    with tempfile.TemporaryDirectory() as root_dir:\n        partition_graph(g, graph_name, num_parts, root_dir)\n        part_config = os.path.join(root_dir, graph_name + \".json\")\n        
old_config = _get_old_config(part_config)\n        # Call convert function\n        convert_conf(part_config)\n        with open(part_config, \"r\") as config_f:\n            config = json.load(config_f)\n            # Check we get all canonical etypes\n            assert Counter(expected_c_etypes) == Counter(\n                config[\"etypes\"].keys()\n            )\n            # Check the id is match after transform from etypes -> canonical\n            assert old_config[\"etypes\"] == _extract_etypes(config[\"etypes\"])\n\n\ndef _get_old_config(part_config):\n    with open(part_config, \"r+\") as config_f:\n        config = json.load(config_f)\n        if not is_old_version(config):\n            config[\"etypes\"] = _extract_etypes(config[\"etypes\"])\n            config[\"edge_map\"] = _extract_edge_map(config[\"edge_map\"])\n            config_f.seek(0)\n            json.dump(config, config_f, indent=4)\n            config_f.truncate()\n        return config\n\n\ndef _extract_etypes(c_etypes):\n    etypes = {}\n    for c_etype, eid in c_etypes.items():\n        etype = c_etype.split(\":\")[1]\n        etypes[etype] = eid\n    return etypes\n\n\ndef _extract_edge_map(c_edge_map):\n    edge_map = {}\n    for c_etype, emap in c_edge_map.items():\n        etype = c_etype.split(\":\")[1]\n        edge_map[etype] = emap\n    return edge_map\n"
  },
  {
    "path": "tests/tools/test_convert_partition.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\nimport pytest\n\nimport utils\n\nfrom convert_partition import _get_unique_invidx\n\n\n@pytest.mark.parametrize(\n    \"num_nodes, num_edges, nid_begin, nid_end\",\n    [\n        [4000, 40000, 0, 1000],\n        [4000, 40000, 1000, 2000],\n        [4000, 40000, 2000, 3000],\n        [4000, 40000, 3000, 4000],\n        [4000, 100, 0, 1000],\n        [4000, 100, 1000, 2000],\n        [4000, 100, 2000, 3000],\n        [4000, 100, 3000, 4000],\n        [1, 1, 0, 1],\n    ],\n)\ndef test_get_unique_invidx_with_numpy(num_nodes, num_edges, nid_begin, nid_end):\n    # prepare data for the function\n    # generate synthetic edges\n    if num_edges > 0:\n        srcids = np.random.randint(0, num_nodes, (num_edges,))  # exclusive\n        dstids = np.random.randint(\n            nid_begin, nid_end, (num_edges,)\n        )  # exclusive\n    else:\n        srcids = np.array([])\n        dstids = np.array([])\n\n    assert nid_begin <= nid_end\n\n    # generate unique node-ids for any\n    # partition. 
This list should be sorted.\n    # This is equivilant to shuffle_nids in a partition\n    unique_nids = np.arange(nid_begin, nid_end)  # exclusive\n\n    # test with numpy unique here\n    orig_srcids = srcids.copy()\n    orig_dstids = dstids.copy()\n    input_arr = np.concatenate([srcids, dstids, unique_nids])\n\n    # test\n    uniques, idxes, srcids, dstids = _get_unique_invidx(\n        srcids, dstids, unique_nids\n    )\n\n    assert len(uniques) == len(idxes)\n    assert np.all(srcids < len(uniques))\n    assert np.all(dstids < len(uniques))\n    assert np.all(uniques[srcids].sort() == orig_srcids.sort())\n    assert np.all(uniques[dstids] == orig_dstids)\n\n    assert np.all(uniques == input_arr[idxes])\n\n    # numpy\n    np_uniques, np_idxes, np_inv_idxes = np.unique(\n        np.concatenate([orig_srcids, orig_dstids, unique_nids]),\n        return_index=True,\n        return_inverse=True,\n    )\n\n    # test uniques\n    assert np.all(np_uniques == uniques)\n\n    # test idxes array\n    assert np.all(input_arr[idxes].sort() == input_arr[np_idxes].sort())\n\n    # test srcids, inv_indices\n    assert np.all(\n        uniques[srcids].sort()\n        == np_uniques[np_inv_idxes[0 : len(srcids)]].sort()\n    )\n\n    # test dstids, inv_indices\n    assert np.all(\n        uniques[dstids].sort() == np_uniques[np_inv_idxes[len(srcids) :]].sort()\n    )\n\n\n@pytest.mark.parametrize(\n    \"num_nodes, num_edges, nid_begin, nid_end\",\n    [\n        # dense networks, no. of edges more than no. of nodes\n        [4000, 40000, 0, 1000],\n        [4000, 40000, 1000, 2000],\n        [4000, 40000, 2000, 3000],\n        [4000, 40000, 3000, 4000],\n        # sparse networks, no. of edges smaller than no. 
of nodes\n        [4000, 100, 0, 1000],\n        [4000, 100, 1000, 2000],\n        [4000, 100, 2000, 3000],\n        [4000, 100, 3000, 4000],\n        # corner case\n        [1, 1, 0, 1],\n    ],\n)\ndef test_get_unique_invidx(num_nodes, num_edges, nid_begin, nid_end):\n    # prepare data for the function\n    # generate synthetic edges\n    if num_edges > 0:\n        srcids = np.random.randint(0, num_nodes, (num_edges,))\n        dstids = np.random.randint(nid_begin, nid_end, (num_edges,))\n    else:\n        srcids = np.array([])\n        dstids = np.array([])\n\n    assert nid_begin <= nid_end\n\n    # generate unique node-ids for any\n    # partition. This list should be sorted.\n    # This is equivilant to shuffle_nids in a partition\n    unique_nids = np.arange(nid_begin, nid_end)\n\n    # invoke the test target\n    uniques, idxes, src_ids, dst_ids = _get_unique_invidx(\n        srcids, dstids, unique_nids\n    )\n\n    # validate the outputs of this function\n    # array uniques should be sorted list of integers.\n    assert np.all(\n        np.diff(uniques) >= 0\n    ), f\"Output parameter uniques assert failing.\"\n\n    # idxes are list of integers\n    # these are indices in the concatenated list (srcids, dstids, unique_nids)\n    max_idx = len(src_ids) + len(dst_ids) + len(unique_nids)\n    assert np.all(idxes >= 0), f\"Output parameter idxes has negative values.\"\n    assert np.all(\n        idxes < max_idx\n    ), f\"Output parameter idxes has invalid maximum value.\"\n\n    # srcids and dstids will be inverse indices in the uniques list\n    min_src = np.amin(src_ids)\n    max_src = np.amax(src_ids)\n\n    min_dst = np.amin(dst_ids)\n    max_dst = np.amax(dst_ids)\n\n    assert (\n        len(uniques) > max_src\n    ), f\"Inverse idx, src_ids, has invalid max value.\"\n    assert min_src >= 0, f\"Inverse idx, src_ids has negative values.\"\n\n    assert len(uniques) > max_dst, f\"Inverse idx, dst_ids, invalid max value.\"\n    assert max_dst >= 0, 
f\"Inverse idx, dst_ids has negative values.\"\n\n\ndef test_get_unique_invidx_low_mem():\n    srcids = np.array([14, 0, 3, 3, 0, 3, 9, 5, 14, 12])\n    dstids = np.array([10, 16, 12, 13, 10, 17, 16, 13, 14, 16])\n    unique_nids = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])\n\n    uniques, idxes, srcids, dstids = _get_unique_invidx(\n        srcids,\n        dstids,\n        unique_nids,\n        low_mem=True,\n    )\n    expected_unqiues = np.array(\n        [0, 3, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]\n    )\n    expected_idxes = np.array(\n        [1, 2, 7, 6, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]\n    )\n    expected_srcids = np.array([8, 0, 1, 1, 0, 1, 3, 2, 8, 6])\n    expected_dstids = np.array([4, 10, 6, 7, 4, 11, 10, 7, 8, 10])\n    assert np.all(\n        uniques == expected_unqiues\n    ), f\"unique is not expected. {uniques} != {expected_unqiues}\"\n    assert np.all(\n        idxes == expected_idxes\n    ), f\"indices is not expected. {idxes} != {expected_idxes}\"\n    assert np.all(\n        srcids == expected_srcids\n    ), f\"srcids is not expected. {srcids} != {expected_srcids}\"\n    assert np.all(\n        dstids == expected_dstids\n    ), f\"dstdis is not expected. 
{dstids} != {expected_dstids}\"\n\n\ndef test_get_unique_invidx_high_mem():\n    srcids = np.array([14, 0, 3, 3, 0, 3, 9, 5, 14, 12])\n    dstids = np.array([10, 16, 12, 13, 10, 17, 16, 13, 14, 16])\n    unique_nids = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])\n\n    uniques, idxes, srcids, dstids = _get_unique_invidx(\n        srcids,\n        dstids,\n        unique_nids,\n        low_mem=False,\n    )\n    expected_unqiues = np.array(\n        [0, 3, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]\n    )\n    expected_idxes = np.array(\n        [1, 2, 7, 6, 10, 21, 9, 13, 0, 25, 11, 15, 28, 29]\n    )\n    expected_srcids = np.array([8, 0, 1, 1, 0, 1, 3, 2, 8, 6])\n    expected_dstids = np.array([4, 10, 6, 7, 4, 11, 10, 7, 8, 10])\n    assert np.all(\n        uniques == expected_unqiues\n    ), f\"unique is not expected. {uniques} != {expected_unqiues}\"\n    assert np.all(\n        idxes == expected_idxes\n    ), f\"indices is not expected. {idxes} != {expected_idxes}\"\n    assert np.all(\n        srcids == expected_srcids\n    ), f\"srcids is not expected. {srcids} != {expected_srcids}\"\n    assert np.all(\n        dstids == expected_dstids\n    ), f\"dstdis is not expected. {dstids} != {expected_dstids}\"\n\n\ndef test_get_unique_invidx_low_high_mem():\n    srcids = np.array([14, 0, 3, 3, 0, 3, 9, 5, 14, 12])\n    dstids = np.array([10, 16, 12, 13, 10, 17, 16, 13, 14, 16])\n    unique_nids = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])\n\n    uniques_low, idxes_low, srcids_low, dstids_low = _get_unique_invidx(\n        srcids,\n        dstids,\n        unique_nids,\n        low_mem=True,\n    )\n    uniques_high, idxes_high, srcids_high, dstids_high = _get_unique_invidx(\n        srcids,\n        dstids,\n        unique_nids,\n        low_mem=False,\n    )\n    assert np.all(\n        uniques_low == uniques_high\n    ), f\"unique is not expected. 
{uniques_low} != {uniques_high}\"\n    assert not np.all(\n        idxes_low == idxes_high\n    ), f\"indices is not expected. {idxes_low} == {idxes_high}\"\n    assert np.all(\n        srcids_low == srcids_high\n    ), f\"srcids is not expected. {srcids_low} != {srcids_high}\"\n    assert np.all(\n        dstids_low == dstids_high\n    ), f\"dstdis is not expected. {dstids_low} != {dstids_high}\"\n"
  },
  {
    "path": "tests/tools/test_dist_lookup.py",
    "content": "import logging\nimport os\nimport platform\nimport tempfile\nfrom datetime import timedelta\n\nimport dgl\n\nimport numpy as np\nimport pyarrow\nimport pytest\n\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nfrom pytest_utils import create_chunked_dataset\nfrom tools.distpartitioning import constants, dist_lookup\nfrom tools.distpartitioning.gloo_wrapper import allgather_sizes\nfrom tools.distpartitioning.utils import (\n    get_idranges,\n    get_ntype_counts_map,\n    read_json,\n)\n\ntry:\n    mp.set_start_method(\"spawn\", force=True)\nexcept RuntimeError:\n    pass\n\n\ndef _init_process_group(rank, world_size):\n    # init the gloo process group here.\n    dist.init_process_group(\n        backend=\"gloo\",\n        rank=rank,\n        world_size=world_size,\n        timeout=timedelta(seconds=180),\n    )\n    print(f\"[Rank: {rank}] Done with process group initialization...\")\n\n\ndef _create_lookup_service(\n    partitions_dir, ntypes, id_map, rank, world_size, num_parts\n):\n    id_lookup = dist_lookup.DistLookupService(\n        partitions_dir, ntypes, rank, world_size, num_parts\n    )\n    id_lookup.set_idMap(id_map)\n\n    # invoke the main function here.\n    print(f\"[Rank: {rank}] Done with Dist Lookup Service initialization...\")\n\n    return id_lookup\n\n\ndef _run(\n    port_num,\n    rank,\n    num_parts,\n    world_size,\n    partitions_dir,\n    ntypes,\n    id_map,\n    test_data,\n):\n    os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n    os.environ[\"MASTER_PORT\"] = str(port_num)\n\n    _init_process_group(rank, world_size)\n    lookup = _create_lookup_service(\n        partitions_dir, ntypes, id_map, rank, world_size, num_parts\n    )\n\n    tests_exec = 0\n    for worker, data in test_data.items():\n        if f\"rank-{rank}\" == worker:\n            for item in data:\n                method = item[0]\n                request = item[1]\n                response = item[2]\n\n                if 
method == \"getpartitionids\":\n                    ret_val = lookup.get_partition_ids(request)\n                    tests_exec += 1\n                    assert np.all(ret_val == response)\n                else:\n                    assert False\n\n    # ensure all the tests are executed.\n    rank_counts = allgather_sizes([tests_exec], world_size, num_parts, True)\n    assert np.sum(rank_counts) == len(test_data)\n\n\ndef _single_machine_run(\n    num_parts, world_size, partitions_dir, ntypes, id_map, test_data\n):\n    port_num = np.random.randint(10000, 20000, size=(1,), dtype=int)[0]\n    ctx = mp.get_context(\"spawn\")\n    processes = []\n    for rank in range(world_size):\n        p = ctx.Process(\n            target=_run,\n            args=(\n                port_num,\n                rank,\n                num_parts,\n                world_size,\n                partitions_dir,\n                ntypes,\n                id_map,\n                test_data,\n            ),\n        )\n        p.start()\n        processes.append(p)\n\n    for p in processes:\n        p.join()\n        p.close()\n\n\ndef _prepare_test_data(partitions_dir, ntypes, gid_ranges, world_size):\n    # read node-id to partition-id mappings from disk\n    ntype_partids = []\n    for ntype_id, ntype in enumerate(ntypes):\n        filename = f\"{ntype}.txt\"\n        assert os.path.isfile(os.path.join(partitions_dir, filename))\n\n        read_options = pyarrow.csv.ReadOptions(\n            use_threads=True,\n            block_size=4096,\n            autogenerate_column_names=True,\n        )\n        parse_options = pyarrow.csv.ParseOptions(delimiter=\" \")\n\n        with pyarrow.csv.open_csv(\n            os.path.join(partitions_dir, \"{}.txt\".format(ntype)),\n            read_options=read_options,\n            parse_options=parse_options,\n        ) as reader:\n            for next_chunk in reader:\n                if next_chunk is None:\n                    break\n                
next_table = pyarrow.Table.from_batches([next_chunk])\n                ntype_partids.append(next_table[\"f0\"].to_numpy())\n\n    # prepare test data for each rank here\n    # key = f'rank-{rank}'\n    # value is a list of tuple [(method-name, request, response)]\n    test_data = {}\n    for rank in range(world_size):\n        ntype_id = np.random.randint(0, len(ntypes) - 1)\n        ntype = ntypes[ntype_id]\n        request = (\n            np.arange(len(ntype_partids[ntype_id]))\n            + gid_ranges[ntypes[ntype_id]][0, 0]\n        )\n        response = ntype_partids[ntype_id]\n\n        test_data[f\"rank-{rank}\"] = [(\"getpartitionids\", request, response)]\n\n    # randomly shuffle the global-nids and retrieve their partition-ids.\n    for rank in range(world_size):\n        ntype_id = np.random.randint(0, len(ntypes) - 1)\n        ntype = ntypes[ntype_id]\n        idx = np.arange(len(ntype_partids[ntype_id]))\n        request = idx + gid_ranges[ntypes[ntype_id]][0, 0]\n\n        np.random.shuffle(idx)\n        request = request[idx]\n        response = ntype_partids[ntype_id][idx]\n\n        test_data[f\"rank-{rank}\"] = [(\"getpartitionids\", request, response)]\n\n    # one final test\n    # mix all the ntypes and shuffle randomly\n    request = []\n    response = []\n    for idx in range(len(ntype_partids)):\n        request.append(\n            np.arange(len(ntype_partids[idx])) + gid_ranges[ntypes[idx]][0, 0]\n        )\n        response.append(ntype_partids[idx])\n\n    request = np.concatenate(request)\n    response = np.concatenate(response)\n\n    idx = np.arange(len(request))\n    np.random.shuffle(idx)\n    request = request[idx]\n    response = response[idx]\n    for idx in range(world_size):\n        test_data[f\"rank-{idx}\"] = [(\"getpartitionids\", request, response)]\n\n    return test_data\n\n\n@pytest.mark.parametrize(\n    \"num_chunks, num_parts, world_size\",\n    [[4, 4, 4], [8, 4, 2], [8, 4, 4], [9, 6, 3], [11, 11, 1], [11, 4, 
1]],\n)\ndef test_lookup_service(\n    num_chunks,\n    num_parts,\n    world_size,\n    num_chunks_nodes=None,\n    num_chunks_edges=None,\n    num_chunks_node_data=None,\n    num_chunks_edge_data=None,\n):\n\n    with tempfile.TemporaryDirectory() as root_dir:\n        g = create_chunked_dataset(\n            root_dir,\n            num_chunks,\n            data_fmt=\"numpy\",\n            num_chunks_nodes=num_chunks_nodes,\n            num_chunks_edges=num_chunks_edges,\n            num_chunks_node_data=num_chunks_node_data,\n            num_chunks_edge_data=num_chunks_edge_data,\n        )\n\n        # Step1: graph partition\n        in_dir = os.path.join(root_dir, \"chunked-data\")\n        output_dir = os.path.join(root_dir, \"parted_data\")\n        os.system(\n            \"python3 tools/partition_algo/random_partition.py \"\n            \"--in_dir {} --out_dir {} --num_partitions {}\".format(\n                in_dir, output_dir, num_parts\n            )\n        )\n\n        # metadata for original graph\n        orig_config = os.path.join(in_dir, \"metadata.json\")\n        orig_schema = read_json(orig_config)\n        ntypes = orig_schema[constants.STR_NODE_TYPE]\n\n        _, global_nid_ranges = get_idranges(\n            orig_schema[constants.STR_NODE_TYPE],\n            get_ntype_counts_map(\n                orig_schema[constants.STR_NODE_TYPE],\n                orig_schema[constants.STR_NUM_NODES_PER_TYPE],\n            ),\n            num_chunks=num_parts,\n        )\n\n        id_map = dgl.distributed.id_map.IdMap(global_nid_ranges)\n\n        # run the test\n        _single_machine_run(\n            num_parts,\n            world_size,\n            output_dir,\n            ntypes,\n            id_map,\n            _prepare_test_data(\n                output_dir, ntypes, global_nid_ranges, world_size\n            ),\n        )\n"
  },
  {
    "path": "tests/tools/test_dist_part.py",
    "content": "import json\nimport os\nimport tempfile\n\nimport dgl\nimport dgl.backend as F\n\nimport numpy as np\nimport pyarrow.parquet as pq\nimport pytest\nimport torch\nfrom dgl.data.utils import load_graphs, load_tensors\nfrom dgl.distributed.partition import (\n    _etype_tuple_to_str,\n    _get_inner_edge_mask,\n    _get_inner_node_mask,\n    load_partition,\n    RESERVED_FIELD_DTYPE,\n)\n\nfrom distpartitioning import array_readwriter\nfrom distpartitioning.utils import generate_read_list\nfrom pytest_utils import chunk_graph, create_chunked_dataset\nfrom scipy import sparse as spsp\n\nfrom tools.verification_utils import (\n    verify_graph_feats,\n    verify_partition_data_types,\n    verify_partition_formats,\n)\n\n\ndef _test_chunk_graph(\n    num_chunks,\n    data_fmt=\"numpy\",\n    edges_fmt=\"csv\",\n    vector_rows=False,\n    num_chunks_nodes=None,\n    num_chunks_edges=None,\n    num_chunks_node_data=None,\n    num_chunks_edge_data=None,\n):\n    with tempfile.TemporaryDirectory() as root_dir:\n        g = create_chunked_dataset(\n            root_dir,\n            num_chunks,\n            data_fmt=data_fmt,\n            edges_fmt=edges_fmt,\n            vector_rows=vector_rows,\n            num_chunks_nodes=num_chunks_nodes,\n            num_chunks_edges=num_chunks_edges,\n            num_chunks_node_data=num_chunks_node_data,\n            num_chunks_edge_data=num_chunks_edge_data,\n        )\n\n        # check metadata.json\n        output_dir = os.path.join(root_dir, \"chunked-data\")\n        json_file = os.path.join(output_dir, \"metadata.json\")\n        assert os.path.isfile(json_file)\n        with open(json_file, \"rb\") as f:\n            meta_data = json.load(f)\n        assert meta_data[\"graph_name\"] == \"mag240m\"\n        assert len(meta_data[\"num_nodes_per_chunk\"][0]) == num_chunks\n\n        # check edge_index\n        output_edge_index_dir = os.path.join(output_dir, \"edge_index\")\n        for c_etype in 
g.canonical_etypes:\n            c_etype_str = _etype_tuple_to_str(c_etype)\n            if num_chunks_edges is None:\n                n_chunks = num_chunks\n            else:\n                n_chunks = num_chunks_edges\n            for i in range(n_chunks):\n                fname = os.path.join(\n                    output_edge_index_dir, f\"{c_etype_str}{i}.txt\"\n                )\n                assert os.path.isfile(fname)\n                if edges_fmt == \"csv\":\n                    with open(fname, \"r\") as f:\n                        header = f.readline()\n                        num1, num2 = header.rstrip().split(\" \")\n                        assert isinstance(int(num1), int)\n                        assert isinstance(int(num2), int)\n                elif edges_fmt == \"parquet\":\n                    metadata = pq.read_metadata(fname)\n                    assert metadata.num_columns == 2\n                else:\n                    assert False, f\"Invalid edges_fmt: {edges_fmt}\"\n\n        # check node/edge_data\n        suffix = \"npy\" if data_fmt == \"numpy\" else \"parquet\"\n        reader_fmt_meta = {\"name\": data_fmt}\n\n        def test_data(sub_dir, feat, expected_data, expected_shape, num_chunks):\n            data = []\n            for i in range(num_chunks):\n                fname = os.path.join(sub_dir, f\"{feat}-{i}.{suffix}\")\n                assert os.path.isfile(fname), f\"{fname} cannot be found.\"\n                feat_array = array_readwriter.get_array_parser(\n                    **reader_fmt_meta\n                ).read(fname)\n                assert feat_array.shape[0] == expected_shape\n                data.append(feat_array)\n            data = np.concatenate(data, 0)\n            assert torch.equal(torch.from_numpy(data), expected_data)\n\n        output_node_data_dir = os.path.join(output_dir, \"node_data\")\n        for ntype in g.ntypes:\n            sub_dir = os.path.join(output_node_data_dir, ntype)\n            if 
isinstance(num_chunks_node_data, int):\n                chunks_data = num_chunks_node_data\n            elif isinstance(num_chunks_node_data, dict):\n                chunks_data = num_chunks_node_data.get(ntype, num_chunks)\n            else:\n                chunks_data = num_chunks\n            for feat, data in g.nodes[ntype].data.items():\n                if isinstance(chunks_data, dict):\n                    n_chunks = chunks_data.get(feat, num_chunks)\n                else:\n                    n_chunks = chunks_data\n                test_data(\n                    sub_dir,\n                    feat,\n                    data,\n                    g.num_nodes(ntype) // n_chunks,\n                    n_chunks,\n                )\n\n        output_edge_data_dir = os.path.join(output_dir, \"edge_data\")\n        for c_etype in g.canonical_etypes:\n            c_etype_str = _etype_tuple_to_str(c_etype)\n            sub_dir = os.path.join(output_edge_data_dir, c_etype_str)\n            if isinstance(num_chunks_edge_data, int):\n                chunks_data = num_chunks_edge_data\n            elif isinstance(num_chunks_edge_data, dict):\n                chunks_data = num_chunks_edge_data.get(c_etype, num_chunks)\n            else:\n                chunks_data = num_chunks\n            for feat, data in g.edges[c_etype].data.items():\n                if isinstance(chunks_data, dict):\n                    n_chunks = chunks_data.get(feat, num_chunks)\n                else:\n                    n_chunks = chunks_data\n                test_data(\n                    sub_dir,\n                    feat,\n                    data,\n                    g.num_edges(c_etype) // n_chunks,\n                    n_chunks,\n                )\n\n\n@pytest.mark.parametrize(\"num_chunks\", [1, 8])\n@pytest.mark.parametrize(\"data_fmt\", [\"numpy\", \"parquet\"])\n@pytest.mark.parametrize(\"edges_fmt\", [\"csv\", \"parquet\"])\ndef test_chunk_graph_basics(num_chunks, data_fmt, 
edges_fmt):\n    _test_chunk_graph(num_chunks, data_fmt=data_fmt, edges_fmt=edges_fmt)\n\n\n@pytest.mark.parametrize(\"num_chunks\", [1, 8])\n@pytest.mark.parametrize(\"vector_rows\", [True, False])\ndef test_chunk_graph_vector_rows(num_chunks, vector_rows):\n    _test_chunk_graph(\n        num_chunks,\n        data_fmt=\"parquet\",\n        edges_fmt=\"parquet\",\n        vector_rows=vector_rows,\n    )\n\n\n@pytest.mark.parametrize(\n    \"num_chunks, \"\n    \"num_chunks_nodes, \"\n    \"num_chunks_edges, \"\n    \"num_chunks_node_data, \"\n    \"num_chunks_edge_data\",\n    [\n        [1, None, None, None, None],\n        [8, None, None, None, None],\n        [4, 4, 4, 8, 12],\n        [4, 4, 4, {\"paper\": 10}, {(\"author\", \"writes\", \"paper\"): 24}],\n        [\n            4,\n            4,\n            4,\n            {\"paper\": {\"feat\": 10}},\n            {(\"author\", \"writes\", \"paper\"): {\"year\": 24}},\n        ],\n    ],\n)\ndef test_chunk_graph_arbitrary_chunks(\n    num_chunks,\n    num_chunks_nodes,\n    num_chunks_edges,\n    num_chunks_node_data,\n    num_chunks_edge_data,\n):\n    _test_chunk_graph(\n        num_chunks,\n        num_chunks_nodes=num_chunks_nodes,\n        num_chunks_edges=num_chunks_edges,\n        num_chunks_node_data=num_chunks_node_data,\n        num_chunks_edge_data=num_chunks_edge_data,\n    )\n\n\ndef create_mini_chunked_dataset(\n    root_dir,\n    num_chunks,\n    data_fmt,\n    edges_fmt,\n    vector_rows,\n    few_entity=\"node\",\n    **kwargs,\n):\n    num_nodes = {\"n1\": 1000, \"n2\": 1010, \"n3\": 1020}\n    etypes = [\n        (\"n1\", \"r1\", \"n2\"),\n        (\"n2\", \"r1\", \"n1\"),\n        (\"n1\", \"r2\", \"n3\"),\n        (\"n2\", \"r3\", \"n3\"),\n    ]\n    node_items = [\"n1\", \"n2\", \"n3\"]\n    edges_coo = {}\n    for etype in etypes:\n        src_ntype, _, dst_ntype = etype\n        arr = spsp.random(\n            num_nodes[src_ntype],\n            num_nodes[dst_ntype],\n            
density=0.001,\n            format=\"coo\",\n            random_state=100,\n        )\n        edges_coo[etype] = (arr.row, arr.col)\n    edge_items = []\n    if few_entity == \"edge\":\n        edges_coo[(\"n1\", \"a0\", \"n2\")] = (\n            torch.tensor([0, 1]),\n            torch.tensor([1, 0]),\n        )\n        edges_coo[(\"n1\", \"a1\", \"n3\")] = (\n            torch.tensor([0, 1]),\n            torch.tensor([1, 0]),\n        )\n        edge_items.append((\"n1\", \"a0\", \"n2\"))\n        edge_items.append((\"n1\", \"a1\", \"n3\"))\n    elif few_entity == \"node\":\n        edges_coo[(\"n1\", \"r_few\", \"n_few\")] = (\n            torch.tensor([0, 1]),\n            torch.tensor([1, 0]),\n        )\n        edges_coo[(\"a0\", \"a01\", \"n_1\")] = (\n            torch.tensor([0, 1]),\n            torch.tensor([1, 0]),\n        )\n        edge_items.append((\"n1\", \"r_few\", \"n_few\"))\n        edge_items.append((\"a0\", \"a01\", \"n_1\"))\n        node_items.append(\"n_few\")\n        node_items.append(\"n_1\")\n        num_nodes[\"n_few\"] = 2\n        num_nodes[\"n_1\"] = 2\n    g = dgl.heterograph(edges_coo)\n\n    node_data = {}\n    edge_data = {}\n    # save feature\n    input_dir = os.path.join(root_dir, \"data_test\")\n\n    for ntype in node_items:\n        os.makedirs(os.path.join(input_dir, ntype))\n        feat = np.random.randn(num_nodes[ntype], 3)\n        feat_path = os.path.join(input_dir, f\"{ntype}/feat.npy\")\n        with open(feat_path, \"wb\") as f:\n            np.save(f, feat)\n        g.nodes[ntype].data[\"feat\"] = torch.from_numpy(feat)\n        node_data[ntype] = {\"feat\": feat_path}\n\n    for etype in set(edge_items):\n        os.makedirs(os.path.join(input_dir, etype[1]))\n        num_edge = len(edges_coo[etype][0])\n        feat = np.random.randn(num_edge, 4)\n        feat_path = os.path.join(input_dir, f\"{etype[1]}/feat.npy\")\n        with open(feat_path, \"wb\") as f:\n            np.save(f, feat)\n        
g.edges[etype].data[\"feat\"] = torch.from_numpy(feat)\n        edge_data[etype] = {\"feat\": feat_path}\n\n    output_dir = os.path.join(root_dir, \"chunked-data\")\n    chunk_graph(\n        g,\n        \"mag240m\",\n        node_data,\n        edge_data,\n        num_chunks=num_chunks,\n        output_path=output_dir,\n        data_fmt=data_fmt,\n        edges_fmt=edges_fmt,\n        vector_rows=vector_rows,\n        **kwargs,\n    )\n    return g\n\n\ndef _test_pipeline(\n    num_chunks,\n    num_parts,\n    world_size,\n    graph_formats=None,\n    data_fmt=\"numpy\",\n    num_chunks_nodes=None,\n    num_chunks_edges=None,\n    num_chunks_node_data=None,\n    num_chunks_edge_data=None,\n    use_verify_partitions=False,\n):\n\n    if num_parts % world_size != 0:\n        # num_parts should be a multiple of world_size\n        return\n\n    with tempfile.TemporaryDirectory() as root_dir:\n        g = create_chunked_dataset(\n            root_dir,\n            num_chunks,\n            data_fmt=data_fmt,\n            num_chunks_nodes=num_chunks_nodes,\n            num_chunks_edges=num_chunks_edges,\n            num_chunks_node_data=num_chunks_node_data,\n            num_chunks_edge_data=num_chunks_edge_data,\n        )\n\n        # Step1: graph partition\n        in_dir = os.path.join(root_dir, \"chunked-data\")\n        output_dir = os.path.join(root_dir, \"parted_data\")\n        os.system(\n            \"python3 tools/partition_algo/random_partition.py \"\n            \"--in_dir {} --out_dir {} --num_partitions {}\".format(\n                in_dir, output_dir, num_parts\n            )\n        )\n        for ntype in [\"author\", \"institution\", \"paper\"]:\n            fname = os.path.join(output_dir, \"{}.txt\".format(ntype))\n            with open(fname, \"r\") as f:\n                header = f.readline().rstrip()\n                assert isinstance(int(header), int)\n\n        # Step2: data dispatch\n        partition_dir = os.path.join(root_dir, 
\"parted_data\")\n        out_dir = os.path.join(root_dir, \"partitioned\")\n        ip_config = os.path.join(root_dir, \"ip_config.txt\")\n        with open(ip_config, \"w\") as f:\n            for i in range(world_size):\n                f.write(f\"127.0.0.{i + 1}\\n\")\n\n        cmd = \"python3 tools/dispatch_data.py\"\n        cmd += f\" --in-dir {in_dir}\"\n        cmd += f\" --partitions-dir {partition_dir}\"\n        cmd += f\" --out-dir {out_dir}\"\n        cmd += f\" --ip-config {ip_config}\"\n        cmd += \" --ssh-port 22\"\n        cmd += \" --process-group-timeout 60\"\n        cmd += \" --save-orig-nids\"\n        cmd += \" --save-orig-eids\"\n        cmd += f\" --graph-formats {graph_formats}\" if graph_formats else \"\"\n        os.system(cmd)\n\n        # check if verify_partitions.py is used for validation.\n        if use_verify_partitions:\n            cmd = \"python3 tools/verify_partitions.py \"\n            cmd += f\" --orig-dataset-dir {in_dir}\"\n            cmd += f\" --part-graph {out_dir}\"\n            cmd += f\" --partitions-dir {output_dir}\"\n            os.system(cmd)\n            return\n\n        # read original node/edge IDs\n        def read_orig_ids(fname):\n            orig_ids = {}\n            for i in range(num_parts):\n                ids_path = os.path.join(out_dir, f\"part{i}\", fname)\n                part_ids = load_tensors(ids_path)\n                for type, data in part_ids.items():\n                    if type not in orig_ids:\n                        orig_ids[type] = data\n                    else:\n                        orig_ids[type] = torch.cat((orig_ids[type], data))\n            return orig_ids\n\n        orig_nids = read_orig_ids(\"orig_nids.dgl\")\n        orig_eids = read_orig_ids(\"orig_eids.dgl\")\n\n        # load partitions and verify\n        part_config = os.path.join(out_dir, \"metadata.json\")\n        for i in range(num_parts):\n            part_g, node_feats, edge_feats, gpb, _, _, _ = 
load_partition(\n                part_config, i\n            )\n            verify_partition_data_types(part_g)\n            verify_partition_formats(part_g, graph_formats)\n            verify_graph_feats(\n                g, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids\n            )\n\n\n@pytest.mark.parametrize(\n    \"num_chunks, num_parts, world_size\",\n    [[4, 4, 4], [8, 4, 2], [8, 4, 4], [9, 6, 3], [11, 11, 1], [11, 4, 1]],\n)\ndef test_pipeline_basics(num_chunks, num_parts, world_size):\n    _test_pipeline(num_chunks, num_parts, world_size)\n    _test_pipeline(\n        num_chunks, num_parts, world_size, use_verify_partitions=False\n    )\n\n\n@pytest.mark.parametrize(\n    \"graph_formats\", [None, \"csc\", \"coo,csc\", \"coo,csc,csr\"]\n)\ndef test_pipeline_formats(graph_formats):\n    _test_pipeline(4, 4, 4, graph_formats)\n\n\n@pytest.mark.parametrize(\n    \"num_chunks, \"\n    \"num_parts, \"\n    \"world_size, \"\n    \"num_chunks_node_data, \"\n    \"num_chunks_edge_data\",\n    [\n        # Test cases where no. of chunks more than\n        # no. of partitions\n        [8, 4, 4, 8, 8],\n        [8, 4, 2, 8, 8],\n        [9, 7, 5, 9, 9],\n        [8, 8, 4, 8, 8],\n        # Test cases where no. of chunks smaller\n        # than no. 
of partitions\n        [7, 8, 4, 7, 7],\n        [1, 8, 4, 1, 1],\n        [1, 4, 4, 1, 1],\n        [3, 4, 4, 3, 3],\n        [1, 4, 2, 1, 1],\n        [3, 4, 2, 3, 3],\n        [1, 5, 3, 1, 1],\n    ],\n)\ndef test_pipeline_arbitrary_chunks(\n    num_chunks,\n    num_parts,\n    world_size,\n    num_chunks_node_data,\n    num_chunks_edge_data,\n):\n    _test_pipeline(\n        num_chunks,\n        num_parts,\n        world_size,\n        num_chunks_node_data=num_chunks_node_data,\n        num_chunks_edge_data=num_chunks_edge_data,\n    )\n\n\n@pytest.mark.parametrize(\n    \"graph_formats\", [None, \"csc\", \"coo,csc\", \"coo,csc,csr\"]\n)\ndef test_pipeline_formats(graph_formats):\n    _test_pipeline(4, 4, 4, graph_formats)\n\n\n@pytest.mark.parametrize(\"data_fmt\", [\"numpy\", \"parquet\"])\ndef test_pipeline_feature_format(data_fmt):\n    _test_pipeline(4, 4, 4, data_fmt=data_fmt)\n\n\n@pytest.mark.parametrize(\n    \"num_chunks, num_parts, world_size\",\n    [[4, 4, 4], [8, 4, 2], [8, 4, 4], [9, 6, 3], [11, 11, 1], [11, 4, 1]],\n)\n@pytest.mark.parametrize(\"few_entity\", [\"node\", \"edge\"])\ndef test_partition_hetero_few_entity(\n    num_chunks,\n    num_parts,\n    world_size,\n    few_entity,\n    graph_formats=None,\n    data_fmt=\"numpy\",\n    edges_fmt=\"csv\",\n    vector_rows=False,\n    num_chunks_nodes=None,\n    num_chunks_edges=None,\n    num_chunks_node_data=None,\n    num_chunks_edge_data=None,\n):\n    with tempfile.TemporaryDirectory() as root_dir:\n        g = create_mini_chunked_dataset(\n            root_dir,\n            num_chunks,\n            few_entity=few_entity,\n            data_fmt=data_fmt,\n            edges_fmt=edges_fmt,\n            vector_rows=vector_rows,\n            num_chunks_nodes=num_chunks_nodes,\n            num_chunks_edges=num_chunks_edges,\n            num_chunks_node_data=num_chunks_node_data,\n            num_chunks_edge_data=num_chunks_edge_data,\n        )\n\n        # Step1: graph partition\n        in_dir 
= os.path.join(root_dir, \"chunked-data\")\n        output_dir = os.path.join(root_dir, \"parted_data\")\n        os.system(\n            \"python3 tools/partition_algo/random_partition.py \"\n            \"--in_dir {} --out_dir {} --num_partitions {}\".format(\n                in_dir, output_dir, num_parts\n            )\n        )\n\n        # Step2: data dispatch\n        partition_dir = os.path.join(root_dir, \"parted_data\")\n        out_dir = os.path.join(root_dir, \"partitioned\")\n        ip_config = os.path.join(root_dir, \"ip_config.txt\")\n        with open(ip_config, \"w\") as f:\n            for i in range(world_size):\n                f.write(f\"127.0.0.{i + 1}\\n\")\n\n        cmd = \"python3 tools/dispatch_data.py\"\n        cmd += f\" --in-dir {in_dir}\"\n        cmd += f\" --partitions-dir {partition_dir}\"\n        cmd += f\" --out-dir {out_dir}\"\n        cmd += f\" --ip-config {ip_config}\"\n        cmd += \" --ssh-port 22\"\n        cmd += \" --process-group-timeout 60\"\n        cmd += \" --save-orig-nids\"\n        cmd += \" --save-orig-eids\"\n        cmd += f\" --graph-formats {graph_formats}\" if graph_formats else \"\"\n        os.system(cmd)\n\n        # read original node/edge IDs\n        def read_orig_ids(fname):\n            orig_ids = {}\n            for i in range(num_parts):\n                ids_path = os.path.join(out_dir, f\"part{i}\", fname)\n                part_ids = load_tensors(ids_path)\n                for type, data in part_ids.items():\n                    if type not in orig_ids:\n                        orig_ids[type] = data\n                    else:\n                        orig_ids[type] = torch.cat((orig_ids[type], data))\n            return orig_ids\n\n        orig_nids = read_orig_ids(\"orig_nids.dgl\")\n        orig_eids = read_orig_ids(\"orig_eids.dgl\")\n\n        # load partitions and verify\n        part_config = os.path.join(out_dir, \"metadata.json\")\n        for i in range(num_parts):\n            
part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(\n                part_config, i\n            )\n            verify_partition_data_types(part_g)\n            verify_partition_formats(part_g, graph_formats)\n            verify_graph_feats(\n                g, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids\n            )\n\n\ndef test_utils_generate_read_list():\n    read_list = generate_read_list(10, 4)\n    assert np.array_equal(read_list[0], np.array([0, 1, 2]))\n    assert np.array_equal(read_list[1], np.array([3, 4, 5]))\n    assert np.array_equal(read_list[2], np.array([6, 7]))\n    assert np.array_equal(read_list[3], np.array([8, 9]))\n"
  },
  {
    "path": "tests/tools/test_dist_partition_graphbolt.py",
    "content": "import json\nimport os\nimport tempfile\n\nimport dgl\nimport dgl.backend as F\nimport dgl.graphbolt as gb\n\nimport numpy as np\nimport pyarrow.parquet as pq\nimport pytest\nimport torch\nfrom dgl.data.utils import load_graphs, load_tensors\nfrom dgl.distributed.partition import (\n    _etype_str_to_tuple,\n    _etype_tuple_to_str,\n    _get_inner_edge_mask,\n    _get_inner_node_mask,\n    load_partition,\n    RESERVED_FIELD_DTYPE,\n)\n\nfrom distpartitioning import array_readwriter\nfrom distpartitioning.utils import generate_read_list\nfrom pytest_utils import create_chunked_dataset\n\n\ndef _verify_metadata_gb(gpb, g, num_parts, part_id, part_sizes):\n    \"\"\"\n    check list:\n        make sure the number of nodes and edges is correct.\n        make sure the number of parts is correct.\n        make sure the number of nodes and edges in each part is corrcet.\n    \"\"\"\n    assert gpb._num_nodes() == g.num_nodes()\n    assert gpb._num_edges() == g.num_edges()\n\n    assert gpb.num_partitions() == num_parts\n    gpb_meta = gpb.metadata()\n    assert len(gpb_meta) == num_parts\n    assert len(gpb.partid2nids(part_id)) == gpb_meta[part_id][\"num_nodes\"]\n    assert len(gpb.partid2eids(part_id)) == gpb_meta[part_id][\"num_edges\"]\n    part_sizes.append(\n        (gpb_meta[part_id][\"num_nodes\"], gpb_meta[part_id][\"num_edges\"])\n    )\n\n\ndef _verify_local_id_gb(part_g, part_id, gpb):\n    \"\"\"\n    check list:\n        make sure the type of local id is correct.\n        make sure local id have a right order.\n    \"\"\"\n    nid = F.boolean_mask(\n        part_g.node_attributes[dgl.NID],\n        part_g.node_attributes[\"inner_node\"],\n    )\n    local_nid = gpb.nid2localnid(nid, part_id)\n    assert F.dtype(local_nid) in (F.int64, F.int32)\n    assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid)))\n    eid = F.boolean_mask(\n        part_g.edge_attributes[dgl.EID],\n        part_g.edge_attributes[\"inner_edge\"],\n    )\n 
   local_eid = gpb.eid2localeid(eid, part_id)\n    assert F.dtype(local_eid) in (F.int64, F.int32)\n    assert np.all(np.sort(F.asnumpy(local_eid)) == np.arange(0, len(local_eid)))\n    return local_nid, local_eid\n\n\ndef _verify_map_gb(\n    part_g,\n    part_id,\n    gpb,\n):\n    \"\"\"\n    check list:\n        make sure the map node and its data type is correct.\n    \"\"\"\n    # Check the node map.\n    local_nodes = F.boolean_mask(\n        part_g.node_attributes[dgl.NID],\n        part_g.node_attributes[\"inner_node\"],\n    )\n    inner_node_index = F.nonzero_1d(part_g.node_attributes[\"inner_node\"])\n    mapping_nodes = gpb.partid2nids(part_id)\n    assert F.dtype(mapping_nodes) in (F.int32, F.int64)\n    assert np.all(\n        np.sort(F.asnumpy(local_nodes)) == np.sort(F.asnumpy(mapping_nodes))\n    )\n    assert np.all(\n        F.asnumpy(inner_node_index) == np.arange(len(inner_node_index))\n    )\n\n    # Check the edge map.\n\n    local_edges = F.boolean_mask(\n        part_g.edge_attributes[dgl.EID],\n        part_g.edge_attributes[\"inner_edge\"],\n    )\n    inner_edge_index = F.nonzero_1d(part_g.edge_attributes[\"inner_edge\"])\n    mapping_edges = gpb.partid2eids(part_id)\n    assert F.dtype(mapping_edges) in (F.int32, F.int64)\n    assert np.all(\n        np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(mapping_edges))\n    )\n    assert np.all(\n        F.asnumpy(inner_edge_index) == np.arange(len(inner_edge_index))\n    )\n    return local_nodes, local_edges\n\n\ndef _verify_local_and_map_id_gb(\n    part_g,\n    part_id,\n    gpb,\n    store_inner_node,\n    store_inner_edge,\n    store_eids,\n):\n    \"\"\"\n    check list:\n        make sure local id are correct.\n        make sure mapping id are correct.\n    \"\"\"\n    if store_inner_node and store_inner_edge and store_eids:\n        _verify_local_id_gb(part_g, part_id, gpb)\n        _verify_map_gb(part_g, part_id, gpb)\n\n\ndef _get_part_IDs(part_g):\n    # These are 
partition-local IDs.\n    num_columns = part_g.csc_indptr.diff()\n    part_src_ids = part_g.indices\n    part_dst_ids = torch.arange(part_g.total_num_nodes).repeat_interleave(\n        num_columns\n    )\n    # These are reshuffled global homogeneous IDs.\n    part_src_ids = F.gather_row(part_g.node_attributes[dgl.NID], part_src_ids)\n    part_dst_ids = F.gather_row(part_g.node_attributes[dgl.NID], part_dst_ids)\n    return part_src_ids, part_dst_ids\n\n\ndef _verify_node_type_ID_gb(part_g, gpb):\n    \"\"\"\n    check list:\n        make sure ntype id have correct data type\n    \"\"\"\n    part_src_ids, part_dst_ids = _get_part_IDs(part_g)\n    # These are reshuffled per-type IDs.\n    src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids)\n    dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids)\n    # `IdMap` is in int64 by default.\n    assert src_ntype_ids.dtype == F.int64\n    assert dst_ntype_ids.dtype == F.int64\n\n    with pytest.raises(dgl.utils.internal.InconsistentDtypeException):\n        gpb.map_to_per_ntype(F.tensor([0], F.int32))\n    with pytest.raises(dgl.utils.internal.InconsistentDtypeException):\n        gpb.map_to_per_etype(F.tensor([0], F.int32))\n    return (\n        part_src_ids,\n        part_dst_ids,\n        src_ntype_ids,\n        part_src_ids,\n        dst_ntype_ids,\n    )\n\n\ndef _verify_orig_edge_IDs_gb(\n    g,\n    orig_nids,\n    orig_eids,\n    part_eids,\n    part_src_ids,\n    part_dst_ids,\n    src_ntype=None,\n    dst_ntype=None,\n    etype=None,\n):\n    \"\"\"\n    check list:\n        make sure orig edge id are correct after\n    \"\"\"\n    if src_ntype is not None and dst_ntype is not None:\n        orig_src_nid = orig_nids[src_ntype]\n        orig_dst_nid = orig_nids[dst_ntype]\n    else:\n        orig_src_nid = orig_nids\n        orig_dst_nid = orig_nids\n    orig_src_ids = F.gather_row(orig_src_nid, part_src_ids)\n    orig_dst_ids = F.gather_row(orig_dst_nid, part_dst_ids)\n    if etype is 
not None:\n        orig_eids = orig_eids[etype]\n    orig_eids1 = F.gather_row(orig_eids, part_eids)\n    orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids, etype=etype)\n    assert len(orig_eids1) == len(orig_eids2)\n    assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))\n\n\ndef _verify_orig_IDs_gb(\n    part_g,\n    gpb,\n    g,\n    is_homo=False,\n    part_src_ids=None,\n    part_dst_ids=None,\n    src_ntype_ids=None,\n    dst_ntype_ids=None,\n    orig_nids=None,\n    orig_eids=None,\n):\n    \"\"\"\n    check list:\n        make sure orig edge id are correct.\n        make sure hetero ntype id are correct.\n    \"\"\"\n    part_eids = part_g.edge_attributes[dgl.EID]\n    if is_homo:\n        _verify_orig_edge_IDs_gb(\n            g, orig_nids, orig_eids, part_eids, part_src_ids, part_dst_ids\n        )\n        local_orig_nids = orig_nids[part_g.node_attributes[dgl.NID]]\n        local_orig_eids = orig_eids[part_g.edge_attributes[dgl.EID]]\n        part_g.node_attributes[\"feats\"] = F.gather_row(\n            g.ndata[\"feats\"], local_orig_nids\n        )\n        part_g.edge_attributes[\"feats\"] = F.gather_row(\n            g.edata[\"feats\"], local_orig_eids\n        )\n    else:\n        etype_ids, part_eids = gpb.map_to_per_etype(part_eids)\n        # `IdMap` is in int64 by default.\n        assert etype_ids.dtype == F.int64\n\n        # These are original per-type IDs.\n        for etype_id, etype in enumerate(g.canonical_etypes):\n            part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id)\n            src_ntype_ids1 = F.boolean_mask(\n                src_ntype_ids, etype_ids == etype_id\n            )\n            part_dst_ids1 = F.boolean_mask(part_dst_ids, etype_ids == etype_id)\n            dst_ntype_ids1 = F.boolean_mask(\n                dst_ntype_ids, etype_ids == etype_id\n            )\n            part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id)\n            assert np.all(F.asnumpy(src_ntype_ids1 
== src_ntype_ids1[0]))\n            assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0]))\n            src_ntype = g.ntypes[F.as_scalar(src_ntype_ids1[0])]\n            dst_ntype = g.ntypes[F.as_scalar(dst_ntype_ids1[0])]\n\n            _verify_orig_edge_IDs_gb(\n                g,\n                orig_nids,\n                orig_eids,\n                part_eids1,\n                part_src_ids1,\n                part_dst_ids1,\n                src_ntype,\n                dst_ntype,\n                etype,\n            )\n\n\ndef _verify_constructed_id_gb(part_sizes, gpb):\n    \"\"\"\n    verify the part id of each node by constructed nids.\n    check list:\n        make sure each node' part id and its type are corect\n    \"\"\"\n    node_map = []\n    edge_map = []\n    for part_i, (num_nodes, num_edges) in enumerate(part_sizes):\n        node_map.append(np.ones(num_nodes) * part_i)\n        edge_map.append(np.ones(num_edges) * part_i)\n    node_map = np.concatenate(node_map)\n    edge_map = np.concatenate(edge_map)\n    nid2pid = gpb.nid2partid(F.arange(0, len(node_map)))\n    assert F.dtype(nid2pid) in (F.int32, F.int64)\n    assert np.all(F.asnumpy(nid2pid) == node_map)\n    eid2pid = gpb.eid2partid(F.arange(0, len(edge_map)))\n    assert F.dtype(eid2pid) in (F.int32, F.int64)\n    assert np.all(F.asnumpy(eid2pid) == edge_map)\n\n\ndef _verify_IDs_gb(\n    g,\n    part_g,\n    part_id,\n    gpb,\n    part_sizes,\n    orig_nids,\n    orig_eids,\n    store_inner_node,\n    store_inner_edge,\n    store_eids,\n    is_homo,\n):\n    # verify local id and mapping id\n    _verify_local_and_map_id_gb(\n        part_g,\n        part_id,\n        gpb,\n        store_inner_node,\n        store_inner_edge,\n        store_eids,\n    )\n\n    # Verify the mapping between the reshuffled IDs and the original IDs.\n    (\n        part_src_ids,\n        part_dst_ids,\n        src_ntype_ids,\n        part_src_ids,\n        dst_ntype_ids,\n    ) = 
_verify_node_type_ID_gb(part_g, gpb)\n\n    if store_eids:\n        _verify_orig_IDs_gb(\n            part_g,\n            gpb,\n            g,\n            part_src_ids=part_src_ids,\n            part_dst_ids=part_dst_ids,\n            src_ntype_ids=src_ntype_ids,\n            dst_ntype_ids=dst_ntype_ids,\n            orig_nids=orig_nids,\n            orig_eids=orig_eids,\n            is_homo=is_homo,\n        )\n    _verify_constructed_id_gb(part_sizes, gpb)\n\n\ndef _collect_data_gb(\n    parts,\n    part_g,\n    gpbs,\n    gpb,\n    tot_node_feats,\n    node_feats,\n    tot_edge_feats,\n    edge_feats,\n    shuffled_labels,\n    shuffled_edata,\n    test_ntype,\n    test_etype,\n):\n    if test_ntype != None:\n        shuffled_labels.append(node_feats[test_ntype + \"/label\"])\n        shuffled_edata.append(\n            edge_feats[_etype_tuple_to_str(test_etype) + \"/count\"]\n        )\n    else:\n        shuffled_labels.append(node_feats[\"_N/labels\"])\n        shuffled_edata.append(edge_feats[\"_N:_E:_N/feats\"])\n    parts.append(part_g)\n    gpbs.append(gpb)\n    tot_node_feats.append(node_feats)\n    tot_edge_feats.append(edge_feats)\n\n\ndef _verify_node_feats(g, part, gpb, orig_nids, node_feats, is_homo=False):\n    for ntype in g.ntypes:\n        ndata = (\n            part.node_attributes\n            if isinstance(part, gb.FusedCSCSamplingGraph)\n            else part.ndata\n        )\n        ntype_id = g.get_ntype_id(ntype)\n        inner_node_mask = _get_inner_node_mask(\n            part,\n            ntype_id,\n            (gpb if isinstance(part, gb.FusedCSCSamplingGraph) else None),\n        )\n        inner_nids = F.boolean_mask(ndata[dgl.NID], inner_node_mask)\n        ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids)\n        partid = gpb.nid2partid(inner_type_nids, ntype)\n        if is_homo:\n            assert np.all(F.asnumpy(ntype_ids) == ntype_id)\n            assert np.all(F.asnumpy(partid) == gpb.partid)\n\n        if 
is_homo:\n            orig_id = orig_nids[inner_type_nids]\n        else:\n            orig_id = orig_nids[ntype][inner_type_nids]\n        local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype)\n\n        for name in g.nodes[ntype].data:\n            if name in [dgl.NID, \"inner_node\"]:\n                continue\n            true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id)\n            ndata = F.gather_row(node_feats[ntype + \"/\" + name], local_nids)\n            assert np.all(F.asnumpy(ndata == true_feats))\n\n\ndef _verify_edge_feats(g, part, gpb, orig_eids, edge_feats, is_homo=False):\n    for etype in g.canonical_etypes:\n        edata = (\n            part.edge_attributes\n            if isinstance(part, gb.FusedCSCSamplingGraph)\n            else part.edata\n        )\n        etype_id = g.get_etype_id(etype)\n        inner_edge_mask = _get_inner_edge_mask(part, etype_id)\n        inner_eids = F.boolean_mask(edata[dgl.EID], inner_edge_mask)\n        etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids)\n        partid = gpb.eid2partid(inner_type_eids, etype)\n        assert np.all(F.asnumpy(etype_ids) == etype_id)\n        assert np.all(F.asnumpy(partid) == gpb.partid)\n\n        if is_homo:\n            orig_id = orig_eids[inner_type_eids]\n        else:\n            orig_id = orig_eids[etype][inner_type_eids]\n        local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype)\n\n        for name in g.edges[etype].data:\n            if name in [dgl.EID, \"inner_edge\"]:\n                continue\n            true_feats = F.gather_row(g.edges[etype].data[name], orig_id)\n            edata = F.gather_row(\n                edge_feats[_etype_tuple_to_str(etype) + \"/\" + name],\n                local_eids,\n            )\n            assert np.all(F.asnumpy(edata == true_feats))\n\n\ndef _verify_shuffled_labels_gb(\n    g,\n    shuffled_labels,\n    shuffled_edata,\n    orig_nids,\n    orig_eids,\n    test_ntype=None,\n  
  test_etype=None,\n):\n    \"\"\"\n    check list:\n        make sure node data are correct.\n        make sure edge data are correct.\n    \"\"\"\n    shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0))\n    shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0))\n    orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype)\n    orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype)\n\n    orig_nid = orig_nids if test_ntype is None else orig_nids[test_ntype]\n    orig_eid = orig_eids if test_etype is None else orig_eids[test_etype]\n    nlabel = (\n        g.ndata[\"labels\"]\n        if test_ntype is None\n        else g.nodes[test_ntype].data[\"label\"]\n    )\n    edata = (\n        g.edata[\"feats\"]\n        if test_etype is None\n        else g.edges[test_etype].data[\"count\"]\n    )\n\n    orig_labels[F.asnumpy(orig_nid)] = shuffled_labels\n    orig_edata[F.asnumpy(orig_eid)] = shuffled_edata\n    assert np.all(orig_labels == F.asnumpy(nlabel))\n    assert np.all(orig_edata == F.asnumpy(edata))\n\n\ndef verify_graph_feats_gb(\n    g,\n    gpbs,\n    parts,\n    tot_node_feats,\n    tot_edge_feats,\n    orig_nids,\n    orig_eids,\n    shuffled_labels,\n    shuffled_edata,\n    test_ntype,\n    test_etype,\n    store_inner_node=False,\n    store_inner_edge=False,\n    store_eids=False,\n    is_homo=False,\n):\n    \"\"\"\n    check list:\n        make sure the feats of nodes and edges are correct\n    \"\"\"\n    for part_id in range(len(parts)):\n        part = parts[part_id]\n        gpb = gpbs[part_id]\n        node_feats = tot_node_feats[part_id]\n        edge_feats = tot_edge_feats[part_id]\n        if store_inner_node:\n            _verify_node_feats(\n                g,\n                part,\n                gpb,\n                orig_nids,\n                node_feats,\n                is_homo=is_homo,\n            )\n        if store_inner_edge and store_eids:\n            _verify_edge_feats(\n                g,\n 
               part,\n                gpb,\n                orig_eids,\n                edge_feats,\n                is_homo=is_homo,\n            )\n\n    _verify_shuffled_labels_gb(\n        g,\n        shuffled_labels,\n        shuffled_edata,\n        orig_nids,\n        orig_eids,\n        test_ntype,\n        test_etype,\n    )\n\n\ndef _verify_graphbolt_attributes(\n    parts, store_inner_node, store_inner_edge, store_eids\n):\n    \"\"\"\n    check list:\n        make sure arguments work.\n    \"\"\"\n    for part in parts:\n        assert store_inner_edge == (\"inner_edge\" in part.edge_attributes)\n        assert store_inner_node == (\"inner_node\" in part.node_attributes)\n        assert store_eids == (dgl.EID in part.edge_attributes)\n\n\ndef _verify_graphbolt_part(\n    g,\n    test_dir,\n    orig_nids,\n    orig_eids,\n    graph_name,\n    num_parts,\n    store_inner_node,\n    store_inner_edge,\n    store_eids,\n    part_config=None,\n    test_ntype=None,\n    test_etype=None,\n    is_homo=False,\n):\n    \"\"\"\n    check list:\n        _verify_metadata_gb:\n            data type, ID's order and ID's number of edges and nodes\n        _verify_IDs_gb:\n            local id, mapping id,node type id, orig edge, hetero ntype id\n        verify_graph_feats_gb:\n            nodes and edges' feats\n        _verify_graphbolt_attributes:\n            arguments\n    \"\"\"\n    parts = []\n    tot_node_feats = []\n    tot_edge_feats = []\n    shuffled_labels = []\n    shuffled_edata = []\n    part_sizes = []\n    gpbs = []\n    if part_config is None:\n        part_config = os.path.join(test_dir, f\"{graph_name}.json\")\n    # test each part\n    for part_id in range(num_parts):\n        part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(\n            part_config, part_id, load_feats=True, use_graphbolt=True\n        )\n        # verify metadata\n        _verify_metadata_gb(\n            gpb,\n            g,\n            num_parts,\n            
part_id,\n            part_sizes,\n        )\n\n        # verify eid and nid\n        _verify_IDs_gb(\n            g,\n            part_g,\n            part_id,\n            gpb,\n            part_sizes,\n            orig_nids,\n            orig_eids,\n            store_inner_node,\n            store_inner_edge,\n            store_eids,\n            is_homo,\n        )\n\n        # collect shuffled data and parts\n        _collect_data_gb(\n            parts,\n            part_g,\n            gpbs,\n            gpb,\n            tot_node_feats,\n            node_feats,\n            tot_edge_feats,\n            edge_feats,\n            shuffled_labels,\n            shuffled_edata,\n            test_ntype,\n            test_etype,\n        )\n\n    # verify graph feats\n    verify_graph_feats_gb(\n        g,\n        gpbs,\n        parts,\n        tot_node_feats,\n        tot_edge_feats,\n        orig_nids,\n        orig_eids,\n        shuffled_labels=shuffled_labels,\n        shuffled_edata=shuffled_edata,\n        test_ntype=test_ntype,\n        test_etype=test_etype,\n        store_inner_node=store_inner_node,\n        store_inner_edge=store_inner_edge,\n        store_eids=store_eids,\n        is_homo=is_homo,\n    )\n\n    _verify_graphbolt_attributes(\n        parts, store_inner_node, store_inner_edge, store_eids\n    )\n\n    return parts\n\n\ndef _verify_hetero_graph_node_edge_num(\n    g,\n    parts,\n    store_inner_edge,\n    debug_mode,\n):\n    \"\"\"\n    check list:\n        make sure edge type are correct.\n        make sure the number of nodes in each node type are correct.\n        make sure the number of nodes in each node type are correct.\n    \"\"\"\n    num_nodes = {ntype: 0 for ntype in g.ntypes}\n    num_edges = {etype: 0 for etype in g.canonical_etypes}\n    for part in parts:\n        edata = (\n            part.edge_attributes\n            if isinstance(part, gb.FusedCSCSamplingGraph)\n            else part.edata\n        )\n        if 
dgl.ETYPE in edata:\n            assert len(g.canonical_etypes) == len(F.unique(edata[dgl.ETYPE]))\n        if debug_mode or isinstance(part, dgl.DGLGraph):\n            for ntype in g.ntypes:\n                ntype_id = g.get_ntype_id(ntype)\n                inner_node_mask = _get_inner_node_mask(part, ntype_id)\n                num_inner_nodes = F.sum(F.astype(inner_node_mask, F.int64), 0)\n                num_nodes[ntype] += num_inner_nodes\n        if store_inner_edge or isinstance(part, dgl.DGLGraph):\n            for etype in g.canonical_etypes:\n                etype_id = g.get_etype_id(etype)\n                inner_edge_mask = _get_inner_edge_mask(part, etype_id)\n                num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0)\n                num_edges[etype] += num_inner_edges\n\n    # Verify the number of nodes are correct.\n    if debug_mode or isinstance(part, dgl.DGLGraph):\n        for ntype in g.ntypes:\n            print(\n                \"node {}: {}, {}\".format(\n                    ntype, g.num_nodes(ntype), num_nodes[ntype]\n                )\n            )\n            assert g.num_nodes(ntype) == num_nodes[ntype]\n    # Verify the number of edges are correct.\n    if store_inner_edge or isinstance(part, dgl.DGLGraph):\n        for etype in g.canonical_etypes:\n            print(\n                \"edge {}: {}, {}\".format(\n                    etype, g.num_edges(etype), num_edges[etype]\n                )\n            )\n            assert g.num_edges(etype) == num_edges[etype]\n\n\ndef _verify_edge_id_range_hetero(\n    g,\n    part,\n    eids,\n):\n    \"\"\"\n    check list:\n        make sure inner_eids fall into a range.\n        make sure all edges are included.\n    \"\"\"\n    edata = (\n        part.edge_attributes\n        if isinstance(part, gb.FusedCSCSamplingGraph)\n        else part.edata\n    )\n    etype = (\n        part.type_per_edge\n        if isinstance(part, gb.FusedCSCSamplingGraph)\n        else 
edata[dgl.ETYPE]\n    )\n    eid = torch.arange(len(edata[dgl.EID]))\n    etype_arr = F.gather_row(etype, eid)\n    eid_arr = F.gather_row(edata[dgl.EID], eid)\n    for etype in g.canonical_etypes:\n        etype_id = g.get_etype_id(etype)\n        eids[etype].append(F.boolean_mask(eid_arr, etype_arr == etype_id))\n        # Make sure edge Ids fall into a range.\n        inner_edge_mask = _get_inner_edge_mask(part, etype_id)\n        inner_eids = np.sort(\n            F.asnumpy(F.boolean_mask(edata[dgl.EID], inner_edge_mask))\n        )\n        assert np.all(\n            inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1)\n        )\n    return eids\n\n\ndef _verify_node_id_range_hetero(g, part, nids):\n    \"\"\"\n    check list:\n        make sure inner nodes have Ids fall into a range.\n    \"\"\"\n    for ntype in g.ntypes:\n        ntype_id = g.get_ntype_id(ntype)\n        # Make sure inner nodes have Ids fall into a range.\n        inner_node_mask = _get_inner_node_mask(part, ntype_id)\n        inner_nids = F.boolean_mask(\n            part.node_attributes[dgl.NID], inner_node_mask\n        )\n        assert np.all(\n            F.asnumpy(\n                inner_nids\n                == F.arange(\n                    F.as_scalar(inner_nids[0]),\n                    F.as_scalar(inner_nids[-1]) + 1,\n                )\n            )\n        )\n        nids[ntype].append(inner_nids)\n    return nids\n\n\ndef _verify_graph_attributes_hetero(\n    g,\n    parts,\n    store_inner_edge,\n    store_inner_node,\n):\n    \"\"\"\n    check list:\n        make sure edge ids fall into a range.\n        make sure inner nodes have Ids fall into a range.\n        make sure all nodes is included.\n        make sure all edges is included.\n    \"\"\"\n    nids = {ntype: [] for ntype in g.ntypes}\n    eids = {etype: [] for etype in g.canonical_etypes}\n    # check edge id.\n    if store_inner_edge or isinstance(parts[0], dgl.DGLGraph):\n        for part in parts:\n    
        # collect eids\n            eids = _verify_edge_id_range_hetero(g, part, eids)\n        for etype in eids:\n            eids_type = F.cat(eids[etype], 0)\n            uniq_ids = F.unique(eids_type)\n            # We should get all nodes.\n            assert len(uniq_ids) == g.num_edges(etype)\n\n    # check node id.\n    if store_inner_node or isinstance(parts[0], dgl.DGLGraph):\n        for part in parts:\n            nids = _verify_node_id_range_hetero(g, part, nids)\n        for ntype in nids:\n            nids_type = F.cat(nids[ntype], 0)\n            uniq_ids = F.unique(nids_type)\n            # We should get all nodes.\n            assert len(uniq_ids) == g.num_nodes(ntype)\n\n\ndef _verify_hetero_graph(\n    g,\n    parts,\n    store_eids=False,\n    store_inner_edge=False,\n    store_inner_node=False,\n    debug_mode=False,\n):\n    _verify_hetero_graph_node_edge_num(\n        g,\n        parts,\n        store_inner_edge=store_inner_edge,\n        debug_mode=debug_mode,\n    )\n    if store_eids:\n        _verify_graph_attributes_hetero(\n            g,\n            parts,\n            store_inner_edge=store_inner_edge,\n            store_inner_node=store_inner_node,\n        )\n\n\ndef _test_pipeline_graphbolt(\n    num_chunks,\n    num_parts,\n    world_size,\n    graph_formats=None,\n    data_fmt=\"numpy\",\n    num_chunks_nodes=None,\n    num_chunks_edges=None,\n    num_chunks_node_data=None,\n    num_chunks_edge_data=None,\n    use_verify_partitions=False,\n    store_eids=True,\n    store_inner_edge=True,\n    store_inner_node=True,\n):\n    if num_parts % world_size != 0:\n        # num_parts should be a multiple of world_size\n        return\n\n    with tempfile.TemporaryDirectory() as root_dir:\n        g = create_chunked_dataset(\n            root_dir,\n            num_chunks,\n            data_fmt=data_fmt,\n            num_chunks_nodes=num_chunks_nodes,\n            num_chunks_edges=num_chunks_edges,\n            
num_chunks_node_data=num_chunks_node_data,\n            num_chunks_edge_data=num_chunks_edge_data,\n        )\n        graph_name = \"test\"\n        test_ntype = \"paper\"\n        test_etype = (\"paper\", \"cites\", \"paper\")\n\n        # Step1: graph partition\n        in_dir = os.path.join(root_dir, \"chunked-data\")\n        output_dir = os.path.join(root_dir, \"parted_data\")\n        os.system(\n            \"python3 tools/partition_algo/random_partition.py \"\n            \"--in_dir {} --out_dir {} --num_partitions {}\".format(\n                in_dir, output_dir, num_parts\n            )\n        )\n        for ntype in [\"author\", \"institution\", \"paper\"]:\n            fname = os.path.join(output_dir, \"{}.txt\".format(ntype))\n            with open(fname, \"r\") as f:\n                header = f.readline().rstrip()\n                assert isinstance(int(header), int)\n\n        # Step2: data dispatch\n        partition_dir = os.path.join(root_dir, \"parted_data\")\n        out_dir = os.path.join(root_dir, \"partitioned\")\n        ip_config = os.path.join(root_dir, \"ip_config.txt\")\n        with open(ip_config, \"w\") as f:\n            for i in range(world_size):\n                f.write(f\"127.0.0.{i + 1}\\n\")\n\n        cmd = \"python3 tools/dispatch_data.py \"\n        cmd += f\" --in-dir {in_dir} \"\n        cmd += f\" --partitions-dir {partition_dir} \"\n        cmd += f\" --out-dir {out_dir} \"\n        cmd += f\" --ip-config {ip_config} \"\n        cmd += \" --ssh-port 22 \"\n        cmd += \" --process-group-timeout 60 \"\n        cmd += \" --save-orig-nids \"\n        cmd += \" --save-orig-eids \"\n        cmd += \" --use-graphbolt \"\n        cmd += f\" --graph-formats {graph_formats} \" if graph_formats else \"\"\n\n        if store_eids:\n            cmd += \" --store-eids \"\n        if store_inner_edge:\n            cmd += \" --store-inner-edge \"\n        if store_inner_node:\n            cmd += \" --store-inner-node \"\n        
os.system(cmd)\n\n        # check if verify_partitions.py is used for validation.\n        if use_verify_partitions:\n            cmd = \"python3 tools/verify_partitions.py \"\n            cmd += f\" --orig-dataset-dir {in_dir}\"\n            cmd += f\" --part-graph {out_dir}\"\n            cmd += f\" --partitions-dir {output_dir}\"\n            os.system(cmd)\n            return\n\n        # read original node/edge IDs\n        def read_orig_ids(fname):\n            orig_ids = {}\n            for i in range(num_parts):\n                ids_path = os.path.join(out_dir, f\"part{i}\", fname)\n                part_ids = load_tensors(ids_path)\n                for type, data in part_ids.items():\n                    if type not in orig_ids:\n                        orig_ids[type] = data\n                    else:\n                        orig_ids[type] = torch.cat((orig_ids[type], data))\n            return orig_ids\n\n        orig_nids, orig_eids = None, None\n        orig_nids = read_orig_ids(\"orig_nids.dgl\")\n\n        orig_eids_str = read_orig_ids(\"orig_eids.dgl\")\n\n        orig_eids = {}\n        # transmit etype from string to tuple.\n        for etype, eids in orig_eids_str.items():\n            orig_eids[_etype_str_to_tuple(etype)] = eids\n\n        # load partitions and verify\n        part_config = os.path.join(out_dir, \"metadata.json\")\n        parts = _verify_graphbolt_part(\n            g,\n            root_dir,\n            orig_nids,\n            orig_eids,\n            graph_name,\n            num_parts,\n            store_inner_node,\n            store_inner_edge,\n            store_eids,\n            test_ntype=test_ntype,\n            test_etype=test_etype,\n            part_config=part_config,\n            is_homo=False,\n        )\n        _verify_hetero_graph(\n            g,\n            parts,\n            store_eids=store_eids,\n            store_inner_edge=store_inner_edge,\n        )\n\n\n@pytest.mark.parametrize(\n    \"num_chunks, 
num_parts, world_size\",\n    [[4, 4, 4], [8, 4, 2], [8, 4, 4], [9, 6, 3], [11, 11, 1], [11, 4, 1]],\n)\ndef test_pipeline_basics(num_chunks, num_parts, world_size):\n    _test_pipeline_graphbolt(\n        num_chunks,\n        num_parts,\n        world_size,\n    )\n    _test_pipeline_graphbolt(\n        num_chunks, num_parts, world_size, use_verify_partitions=False\n    )\n\n\n@pytest.mark.parametrize(\"store_inner_node\", [True, False])\n@pytest.mark.parametrize(\"store_inner_edge\", [True, False])\n@pytest.mark.parametrize(\"store_eids\", [True, False])\ndef test_pipeline_attributes(store_inner_node, store_inner_edge, store_eids):\n    _test_pipeline_graphbolt(\n        4,\n        4,\n        4,\n        store_inner_node=store_inner_node,\n        store_inner_edge=store_inner_edge,\n        store_eids=store_eids,\n    )\n\n\n@pytest.mark.parametrize(\n    \"num_chunks, \"\n    \"num_parts, \"\n    \"world_size, \"\n    \"num_chunks_node_data, \"\n    \"num_chunks_edge_data\",\n    [\n        # Test cases where no. of chunks more than\n        # no. of partitions\n        [8, 4, 4, 8, 8],\n        [8, 4, 2, 8, 8],\n        [9, 7, 5, 9, 9],\n        [8, 8, 4, 8, 8],\n        # Test cases where no. of chunks smaller\n        # than no. of partitions\n        [7, 8, 4, 7, 7],\n        [1, 8, 4, 1, 1],\n        [1, 4, 4, 1, 1],\n        [3, 4, 4, 3, 3],\n        [1, 4, 2, 1, 1],\n        [3, 4, 2, 3, 3],\n        [1, 5, 3, 1, 1],\n    ],\n)\ndef test_pipeline_arbitrary_chunks(\n    num_chunks,\n    num_parts,\n    world_size,\n    num_chunks_node_data,\n    num_chunks_edge_data,\n):\n\n    _test_pipeline_graphbolt(\n        num_chunks,\n        num_parts,\n        world_size,\n        num_chunks_node_data=num_chunks_node_data,\n        num_chunks_edge_data=num_chunks_edge_data,\n    )\n\n\n@pytest.mark.parametrize(\"data_fmt\", [\"numpy\", \"parquet\"])\ndef test_pipeline_feature_format(data_fmt):\n    _test_pipeline_graphbolt(4, 4, 4, data_fmt=data_fmt)\n"
  },
  {
    "path": "tests/tools/test_launch.py",
    "content": "import json\nimport os\nimport tempfile\nimport unittest\n\nfrom launch import *\n\n\nclass TestWrapUdfInTorchDistLauncher(unittest.TestCase):\n    \"\"\"wrap_udf_in_torch_dist_launcher()\"\"\"\n\n    def test_simple(self):\n        # test that a simple udf_command is correctly wrapped\n        udf_command = \"python3.7 path/to/some/trainer.py arg1 arg2\"\n        wrapped_udf_command = wrap_udf_in_torch_dist_launcher(\n            udf_command=udf_command,\n            num_trainers=2,\n            num_nodes=2,\n            node_rank=1,\n            master_addr=\"127.0.0.1\",\n            master_port=1234,\n        )\n        expected = (\n            \"python3.7 -m torch.distributed.run \"\n            \"--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 \"\n            \"--master_port=1234 path/to/some/trainer.py arg1 arg2\"\n        )\n        self.assertEqual(wrapped_udf_command, expected)\n\n    def test_chained_udf(self):\n        # test that a chained udf_command is properly handled\n        udf_command = (\n            \"cd path/to && python3.7 path/to/some/trainer.py arg1 arg2\"\n        )\n        wrapped_udf_command = wrap_udf_in_torch_dist_launcher(\n            udf_command=udf_command,\n            num_trainers=2,\n            num_nodes=2,\n            node_rank=1,\n            master_addr=\"127.0.0.1\",\n            master_port=1234,\n        )\n        expected = (\n            \"cd path/to && python3.7 -m torch.distributed.run \"\n            \"--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 \"\n            \"--master_port=1234 path/to/some/trainer.py arg1 arg2\"\n        )\n        self.assertEqual(wrapped_udf_command, expected)\n\n    def test_py_versions(self):\n        # test that this correctly handles different py versions/binaries\n        py_binaries = (\n            \"python3.7\",\n            \"python3.8\",\n            \"python3.9\",\n            \"python3\",\n            \"python\",\n    
    )\n        udf_command = \"{python_bin} path/to/some/trainer.py arg1 arg2\"\n\n        for py_bin in py_binaries:\n            wrapped_udf_command = wrap_udf_in_torch_dist_launcher(\n                udf_command=udf_command.format(python_bin=py_bin),\n                num_trainers=2,\n                num_nodes=2,\n                node_rank=1,\n                master_addr=\"127.0.0.1\",\n                master_port=1234,\n            )\n            expected = (\n                \"{python_bin} -m torch.distributed.run \".format(\n                    python_bin=py_bin\n                )\n                + \"--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 \"\n                \"--master_port=1234 path/to/some/trainer.py arg1 arg2\"\n            )\n            self.assertEqual(wrapped_udf_command, expected)\n\n\nclass TestWrapCmdWithLocalEnvvars(unittest.TestCase):\n    \"\"\"wrap_cmd_with_local_envvars()\"\"\"\n\n    def test_simple(self):\n        self.assertEqual(\n            wrap_cmd_with_local_envvars(\"ls && pwd\", \"VAR1=value1 VAR2=value2\"),\n            \"(export VAR1=value1 VAR2=value2; ls && pwd)\",\n        )\n\n\nclass TestConstructDglServerEnvVars(unittest.TestCase):\n    \"\"\"construct_dgl_server_env_vars()\"\"\"\n\n    def test_simple(self):\n        self.assertEqual(\n            construct_dgl_server_env_vars(\n                num_samplers=2,\n                num_server_threads=3,\n                tot_num_clients=4,\n                part_config=\"path/to/part.config\",\n                ip_config=\"path/to/ip.config\",\n                num_servers=5,\n                graph_format=\"csc\",\n            ),\n            (\n                \"DGL_ROLE=server \"\n                \"DGL_NUM_SAMPLER=2 \"\n                \"OMP_NUM_THREADS=3 \"\n                \"DGL_NUM_CLIENT=4 \"\n                \"DGL_CONF_PATH=path/to/part.config \"\n                \"DGL_IP_CONFIG=path/to/ip.config \"\n                \"DGL_NUM_SERVER=5 \"\n           
     \"DGL_GRAPH_FORMAT=csc \"\n            ),\n        )\n\n\nclass TestConstructDglClientEnvVars(unittest.TestCase):\n    \"\"\"construct_dgl_client_env_vars()\"\"\"\n\n    def test_simple(self):\n        # with pythonpath\n        self.assertEqual(\n            construct_dgl_client_env_vars(\n                num_samplers=1,\n                tot_num_clients=2,\n                part_config=\"path/to/part.config\",\n                ip_config=\"path/to/ip.config\",\n                num_servers=3,\n                graph_format=\"csc\",\n                num_omp_threads=4,\n                group_id=0,\n                pythonpath=\"some/pythonpath/\",\n            ),\n            (\n                \"DGL_DIST_MODE=distributed \"\n                \"DGL_ROLE=client \"\n                \"DGL_NUM_SAMPLER=1 \"\n                \"DGL_NUM_CLIENT=2 \"\n                \"DGL_CONF_PATH=path/to/part.config \"\n                \"DGL_IP_CONFIG=path/to/ip.config \"\n                \"DGL_NUM_SERVER=3 \"\n                \"DGL_GRAPH_FORMAT=csc \"\n                \"OMP_NUM_THREADS=4 \"\n                \"DGL_GROUP_ID=0 \"\n                \"PYTHONPATH=some/pythonpath/ \"\n            ),\n        )\n        # without pythonpath\n        self.assertEqual(\n            construct_dgl_client_env_vars(\n                num_samplers=1,\n                tot_num_clients=2,\n                part_config=\"path/to/part.config\",\n                ip_config=\"path/to/ip.config\",\n                num_servers=3,\n                graph_format=\"csc\",\n                num_omp_threads=4,\n                group_id=0,\n            ),\n            (\n                \"DGL_DIST_MODE=distributed \"\n                \"DGL_ROLE=client \"\n                \"DGL_NUM_SAMPLER=1 \"\n                \"DGL_NUM_CLIENT=2 \"\n                \"DGL_CONF_PATH=path/to/part.config \"\n                \"DGL_IP_CONFIG=path/to/ip.config \"\n                \"DGL_NUM_SERVER=3 \"\n                \"DGL_GRAPH_FORMAT=csc \"\n    
            \"OMP_NUM_THREADS=4 \"\n                \"DGL_GROUP_ID=0 \"\n            ),\n        )\n\n\ndef test_submit_jobs():\n    class Args:\n        pass\n\n    args = Args()\n\n    with tempfile.TemporaryDirectory() as test_dir:\n        num_machines = 8\n        ip_config = os.path.join(test_dir, \"ip_config.txt\")\n        with open(ip_config, \"w\") as f:\n            for i in range(num_machines):\n                f.write(\"{} {}\\n\".format(\"127.0.0.\" + str(i), 30050))\n        part_config = os.path.join(test_dir, \"ogb-products.json\")\n        with open(part_config, \"w\") as f:\n            json.dump({\"num_parts\": num_machines}, f)\n        args.num_trainers = 8\n        args.num_samplers = 1\n        args.num_servers = 4\n        args.workspace = test_dir\n        args.part_config = \"ogb-products.json\"\n        args.ip_config = \"ip_config.txt\"\n        args.num_server_threads = 1\n        args.graph_format = \"csc\"\n        args.extra_envs = [\"NCCL_DEBUG=INFO\"]\n        args.num_omp_threads = 1\n        udf_command = \"python3 train_dist.py --num_epochs 10\"\n        clients_cmd, servers_cmd = submit_jobs(args, udf_command, dry_run=True)\n\n        def common_checks():\n            assert \"cd \" + test_dir in cmd\n            assert \"export \" + args.extra_envs[0] in cmd\n            assert f\"DGL_NUM_SAMPLER={args.num_samplers}\" in cmd\n            assert (\n                f\"DGL_NUM_CLIENT={args.num_trainers*(args.num_samplers+1)*num_machines}\"\n                in cmd\n            )\n            assert f\"DGL_CONF_PATH={args.part_config}\" in cmd\n            assert f\"DGL_IP_CONFIG={args.ip_config}\" in cmd\n            assert f\"DGL_NUM_SERVER={args.num_servers}\" in cmd\n            assert f\"DGL_GRAPH_FORMAT={args.graph_format}\" in cmd\n            assert f\"OMP_NUM_THREADS={args.num_omp_threads}\" in cmd\n            assert udf_command[len(\"python3 \") :] in cmd\n\n        for cmd in clients_cmd:\n            common_checks()\n 
           assert \"DGL_DIST_MODE=distributed\" in cmd\n            assert \"DGL_ROLE=client\" in cmd\n            assert \"DGL_GROUP_ID=0\" in cmd\n            assert (\n                f\"python3 -m torch.distributed.run --nproc_per_node={args.num_trainers} --nnodes={num_machines}\"\n                in cmd\n            )\n            assert \"--master_addr=127.0.0\" in cmd\n            assert \"--master_port=1234\" in cmd\n        for cmd in servers_cmd:\n            common_checks()\n            assert \"DGL_ROLE=server\" in cmd\n            assert \"DGL_SERVER_ID=\" in cmd\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/tools/test_parmetis.py",
    "content": "import argparse\nimport json\nimport os\nimport sys\nimport tempfile\nimport unittest\n\nimport dgl\nimport numpy as np\nimport torch\nfrom dgl.data.utils import load_graphs, load_tensors\nfrom partition_algo.base import load_partition_meta\n\nfrom pytest_utils import create_chunked_dataset\n\n\"\"\"\nTODO: skipping this test case since the dependency, mpirun, is\nnot yet configured in the CI framework.\n\"\"\"\n\n\n@unittest.skipIf(True, reason=\"mpi is not available in CI test framework.\")\ndef test_parmetis_preprocessing():\n    with tempfile.TemporaryDirectory() as root_dir:\n        num_chunks = 2\n        g = create_chunked_dataset(root_dir, num_chunks)\n\n        # Trigger ParMETIS pre-processing here.\n        input_dir = os.path.join(root_dir, \"chunked-data\")\n        results_dir = os.path.join(root_dir, \"parmetis-data\")\n        os.system(\n            f\"mpirun -np {num_chunks} python3 tools/distpartitioning/parmetis_preprocess.py \"\n            f\"--schema {metadata.json} \"\n            f\"--input_dir {input_dir} \"\n            f\"--output_dir {results_dir} \"\n            f\"--num_parts {num_chunks}\"\n        )\n\n        # Now add all the tests and check whether the test has passed or failed.\n        # Read parmetis_nfiles and ensure all files are present.\n        parmetis_data_dir = os.path.join(root_dir, \"parmetis-data\")\n        assert os.path.isdir(parmetis_data_dir)\n        parmetis_nodes_file = os.path.join(\n            parmetis_data_dir, \"parmetis_nfiles.txt\"\n        )\n        assert os.path.isfile(parmetis_nodes_file)\n\n        # `parmetis_nfiles.txt` should have each line in the following format.\n        # <filename> <global_id_start> <global_id_end>\n        with open(parmetis_nodes_file, \"r\") as nodes_metafile:\n            lines = nodes_metafile.readlines()\n            total_node_count = 0\n            for line in lines:\n                tokens = line.split(\" \")\n                assert len(tokens) 
== 3\n                assert os.path.isfile(tokens[0])\n                assert int(tokens[1]) == total_node_count\n\n                # check contents of each of the nodes files here\n                with open(tokens[0], \"r\") as nodes_file:\n                    node_lines = nodes_file.readlines()\n                    for line in node_lines:\n                        val = line.split(\" \")\n                        # <ntype_id> <weight_list> <mask_list> <type_node_id>\n                        assert len(val) == 8\n                    node_count = len(node_lines)\n                    total_node_count += node_count\n                assert int(tokens[2]) == total_node_count\n\n        # Meta_data object.\n        output_dir = os.path.join(root_dir, \"chunked-data\")\n        json_file = os.path.join(output_dir, \"metadata.json\")\n        assert os.path.isfile(json_file)\n        with open(json_file, \"rb\") as f:\n            meta_data = json.load(f)\n\n        # Count the total no. of nodes.\n        true_node_count = 0\n        num_nodes_per_chunk = meta_data[\"num_nodes_per_chunk\"]\n        for i in range(len(num_nodes_per_chunk)):\n            node_per_part = num_nodes_per_chunk[i]\n            for j in range(len(node_per_part)):\n                true_node_count += node_per_part[j]\n        assert total_node_count == true_node_count\n\n        # Read parmetis_efiles and ensure all files are present.\n        # This file contains a list of filenames.\n        parmetis_edges_file = os.path.join(\n            parmetis_data_dir, \"parmetis_efiles.txt\"\n        )\n        assert os.path.isfile(parmetis_edges_file)\n\n        with open(parmetis_edges_file, \"r\") as edges_metafile:\n            lines = edges_metafile.readlines()\n            total_edge_count = 0\n            for line in lines:\n                edges_filename = line.strip()\n                assert os.path.isfile(edges_filename)\n\n                with open(edges_filename, \"r\") as edges_file:\n        
            edge_lines = edges_file.readlines()\n                    total_edge_count += len(edge_lines)\n                    for line in edge_lines:\n                        val = line.split(\" \")\n                        assert len(val) == 2\n\n        # Count the total no. of edges\n        true_edge_count = 0\n        num_edges_per_chunk = meta_data[\"num_edges_per_chunk\"]\n        for i in range(len(num_edges_per_chunk)):\n            edges_per_part = num_edges_per_chunk[i]\n            for j in range(len(edges_per_part)):\n                true_edge_count += edges_per_part[j]\n        assert true_edge_count == total_edge_count\n\n\ndef test_parmetis_postprocessing():\n    with tempfile.TemporaryDirectory() as root_dir:\n        num_chunks = 2\n        g = create_chunked_dataset(root_dir, num_chunks)\n\n        num_nodes = g.num_nodes()\n        num_institutions = g.num_nodes(\"institution\")\n        num_authors = g.num_nodes(\"author\")\n        num_papers = g.num_nodes(\"paper\")\n\n        # Generate random parmetis partition ids for the nodes in the graph.\n        # Replace this code with actual ParMETIS executable when it is ready\n        output_dir = os.path.join(root_dir, \"chunked-data\")\n        assert os.path.isdir(output_dir)\n\n        parmetis_file = os.path.join(output_dir, \"parmetis_output.txt\")\n        node_ids = np.arange(num_nodes)\n        partition_ids = np.random.randint(0, 2, (num_nodes,))\n        parmetis_output = np.column_stack([node_ids, partition_ids])\n\n        # Create parmetis output, this is mimicking running actual parmetis.\n        with open(parmetis_file, \"w\") as f:\n            np.savetxt(f, parmetis_output)\n        assert os.path.isfile(parmetis_file)\n\n        # Check the post processing script here.\n        results_dir = os.path.join(output_dir, \"partitions_dir\")\n        json_file = os.path.join(output_dir, \"metadata.json\")\n        print(json_file)\n        print(results_dir)\n        
print(parmetis_file)\n        os.system(\n            f\"python3 tools/distpartitioning/parmetis_postprocess.py \"\n            f\"--postproc_input_dir {output_dir} \"\n            f\"--schema_file metadata.json \"\n            f\"--parmetis_output_file {parmetis_file} \"\n            f\"--partitions_dir {results_dir}\"\n        )\n\n        ntype_count = {\n            \"author\": num_authors,\n            \"paper\": num_papers,\n            \"institution\": num_institutions,\n        }\n        for ntype_name in [\"author\", \"paper\", \"institution\"]:\n            fname = os.path.join(results_dir, f\"{ntype_name}.txt\")\n            print(fname)\n            assert os.path.isfile(fname)\n\n            # Load and check the partition ids in this file.\n            part_ids = np.loadtxt(fname)\n            assert part_ids.shape[0] == ntype_count[ntype_name]\n            assert np.min(part_ids) == 0\n            assert np.max(part_ids) == 1\n\n        # check partition meta file\n        part_meta_file = os.path.join(results_dir, \"partition_meta.json\")\n        assert os.path.isfile(part_meta_file)\n        part_meta = load_partition_meta(part_meta_file)\n        assert part_meta.num_parts == 2\n        assert part_meta.algo_name == \"metis\"\n\n\n\"\"\"\nTODO: skipping this test case since it depends on the dependency, mpi,\nwhich is not yet configured in the CI framework.\n\"\"\"\n\n\n@unittest.skipIf(True, reason=\"mpi is not available in CI test framework.\")\ndef test_parmetis_wrapper():\n    with tempfile.TemporaryDirectory() as root_dir:\n        num_chunks = 2\n        graph_name = \"mag240m\"\n        g = create_chunked_dataset(root_dir, num_chunks)\n        all_ntypes = g.ntypes\n        all_etypes = g.etypes\n        num_constraints = len(all_ntypes) + 3\n        num_institutions = g.num_nodes(\"institution\")\n        num_authors = g.num_nodes(\"author\")\n        num_papers = g.num_nodes(\"paper\")\n\n        # Trigger ParMETIS.\n        schema_file 
= os.path.join(root_dir, \"chunked-data/metadata.json\")\n        preproc_input_dir = os.path.join(root_dir, \"chunked-data\")\n        preproc_output_dir = os.path.join(\n            root_dir, \"chunked-data/preproc_output_dir\"\n        )\n        parmetis_output_file = os.path.join(\n            os.getcwd(), f\"{graph_name}_part.{num_chunks}\"\n        )\n        partitions_dir = os.path.join(root_dir, \"chunked-data/partitions_dir\")\n        hostfile = os.path.join(root_dir, \"ip_config.txt\")\n        with open(hostfile, \"w\") as f:\n            f.write(\"127.0.0.1\\n\")\n            f.write(\"127.0.0.1\\n\")\n\n        num_nodes = g.num_nodes()\n        num_edges = g.num_edges()\n        stats_file = f\"{graph_name}_stats.txt\"\n        with open(stats_file, \"w\") as f:\n            f.write(f\"{num_nodes} {num_edges} {num_constraints}\")\n\n        os.system(\n            f\"python3 tools/distpartitioning/parmetis_wrapper.py \"\n            f\"--schema_file {schema_file} \"\n            f\"--preproc_input_dir {preproc_input_dir} \"\n            f\"--preproc_output_dir {preproc_output_dir} \"\n            f\"--hostfile {hostfile} \"\n            f\"--num_parts {num_chunks} \"\n            f\"--parmetis_output_file {parmetis_output_file} \"\n            f\"--partitions_dir {partitions_dir} \"\n        )\n        print(\"Executing Done.\")\n\n        ntype_count = {\n            \"author\": num_authors,\n            \"paper\": num_papers,\n            \"institution\": num_institutions,\n        }\n        for ntype_name in [\"author\", \"paper\", \"institution\"]:\n            fname = os.path.join(partitions_dir, f\"{ntype_name}.txt\")\n            print(fname)\n            assert os.path.isfile(fname)\n\n            # Load and check the partition ids in this file.\n            part_ids = np.loadtxt(fname)\n            assert part_ids.shape[0] == ntype_count[ntype_name]\n            assert np.min(part_ids) == 0\n            assert np.max(part_ids) == 
(num_chunks - 1)\n"
  },
  {
    "path": "tests/tools/test_parmetis_preproc.py",
    "content": "import os\nimport tempfile\nfrom collections import namedtuple\n\nimport numpy as np\nimport pytest\nfrom distpartitioning import array_readwriter, constants\nfrom distpartitioning.parmetis_preprocess import gen_edge_files\nfrom distpartitioning.utils import generate_roundrobin_read_list\nfrom numpy.testing import assert_array_equal\n\nNODE_TYPE = \"n1\"\nEDGE_TYPE = f\"{NODE_TYPE}:e1:{NODE_TYPE}\"\n\n\ndef _read_file(fname, fmt_name, fmt_delimiter):\n    \"\"\"Read a file\n\n    Parameters:\n    -----------\n    fname : string\n        filename of the input file to read\n    fmt_name : string\n        specifying whether it is a csv or a parquet file\n    fmt_delimiter : string\n        string specifying the delimiter used in the input file\n    \"\"\"\n    reader_fmt_meta = {\n        \"name\": fmt_name,\n    }\n    if fmt_name == constants.STR_CSV:\n        reader_fmt_meta[\"delimiter\"] = fmt_delimiter\n    data_df = array_readwriter.get_array_parser(**reader_fmt_meta).read(fname)\n    return data_df\n\n\ndef _get_test_data(edges_dir, num_chunks, edge_fmt, edge_fmt_del):\n    \"\"\"Creates unit test input which are a set of edge files\n    in the following format \"src_node_id<delimiter>dst_node_id\"\n\n    Parameters:\n    -----------\n    edges_dir : str\n        folder where edge files are stored\n    num_chunks : int\n        no. 
of files to create for each edge type\n    edge_fmt : str, optional\n        to specify whether this file is csv or parquet\n    edge_fmt_del : str, optional\n        delimiter to use in the edges file\n\n    Returns:\n    --------\n    dict :\n        dictionary created which represents the schema used for\n        creating the input dataset\n    \"\"\"\n    schema = {}\n    schema[\"num_nodes_per_type\"] = [10]\n    schema[\"edge_type\"] = [EDGE_TYPE]\n    schema[\"node_type\"] = [NODE_TYPE]\n\n    edges = {}\n    edges[EDGE_TYPE] = {}\n    edges[EDGE_TYPE][\"format\"] = {}\n    edges[EDGE_TYPE][\"format\"][\"name\"] = edge_fmt\n    edges[EDGE_TYPE][\"format\"][\"delimiter\"] = edge_fmt_del\n\n    os.makedirs(edges_dir, exist_ok=True)\n    fmt_meta = {\"name\": edge_fmt}\n    if edge_fmt == \"csv\":\n        fmt_meta[\"delimiter\"] = edge_fmt_del\n\n    edge_files = []\n    for idx in range(num_chunks):\n        path = os.path.join(edges_dir, f\"test_file_{idx}.{fmt_meta['name']}\")\n        array_parser = array_readwriter.get_array_parser(**fmt_meta)\n        edge_data = (\n            np.array([np.arange(10), np.arange(10)]).reshape(10, 2) + 10 * idx\n        )\n        array_parser.write(path, edge_data)\n\n        edge_files.append(path)\n\n    edges[EDGE_TYPE][\"data\"] = edge_files\n    schema[\"edges\"] = edges\n\n    return schema\n\n\n@pytest.mark.parametrize(\"num_chunks, num_parts\", [[4, 1], [4, 2], [4, 4]])\n@pytest.mark.parametrize(\"edges_fmt\", [\"csv\", \"parquet\"])\n@pytest.mark.parametrize(\"edges_delimiter\", [\" \", \",\"])\ndef test_gen_edge_files(num_chunks, num_parts, edges_fmt, edges_delimiter):\n    \"\"\"Unit test case for the function\n    tools/distpartitioning/parmetis_preprocess.py::gen_edge_files\n\n    Parameters:\n    -----------\n    num_chunks : int\n        no. of chunks the input graph needs to be partitioned into\n    num_parts : int\n        no. 
of partitions\n    edges_fmt : string\n        specifying the storage format for the edge files\n    edges_delimiter : string\n        specifying the delimiter used in the edge files\n    \"\"\"\n    # Create the input dataset\n    with tempfile.TemporaryDirectory() as root_dir:\n\n        # Create expected environment for test\n        input_dir = os.path.join(root_dir, \"chunked-data\")\n        output_dir = os.path.join(root_dir, \"preproc_dir\")\n\n        # Mock a parser object\n        fn_params = namedtuple(\"fn_params\", \"input_dir output_dir num_parts\")\n        fn_params.input_dir = input_dir\n        fn_params.output_dir = output_dir\n        fn_params.num_parts = num_parts\n\n        # Create test files and get corresponding file schema\n        schema_map = _get_test_data(\n            input_dir, num_chunks, edges_fmt, edges_delimiter\n        )\n        edges_file_list = schema_map[\"edges\"][EDGE_TYPE][\"data\"]\n        # This is breaking encapsulation, but no other good way to get file list\n        rank_assignments = generate_roundrobin_read_list(\n            len(edges_file_list), num_parts\n        )\n\n        # Get the global node id offsets for each node type\n        # There is only one node-type in the test graph\n        # which range from 0 thru 9.\n        ntype_gnid_offset = {}\n        ntype_gnid_offset[NODE_TYPE] = np.array([0, 10 * num_chunks]).reshape(\n            1, 2\n        )\n\n        # Iterate over no. of partitions\n        for rank in range(num_parts):\n            actual_results = gen_edge_files(rank, schema_map, fn_params)\n\n            # Get the original files\n            original_files = [\n                edges_file_list[file_idx] for file_idx in rank_assignments[rank]\n            ]\n\n            # Validate the results with the baseline results\n            # Test 1. no. 
of files should have the same count per rank\n            assert len(original_files) == len(actual_results)\n            assert len(actual_results) > 0\n\n            # Test 2. Check the contents of each file and verify the\n            # file contents match with the expected results.\n            for actual_fname, original_fname in zip(\n                actual_results, original_files\n            ):\n                # Check the actual file exists\n                assert os.path.isfile(actual_fname)\n                # Read both files and compare the edges\n                # Here note that the src and dst end points are global_node_ids\n                actual_data = _read_file(actual_fname, \"csv\", \" \")\n                expected_data = _read_file(\n                    original_fname, edges_fmt, edges_delimiter\n                )\n\n                # Subtract the global node id offsets, so that we get type node ids\n                # In the current unit test case, the graph has only one node-type,\n                # and this means that type-node-ids are same as the global-node-ids.\n                # The two lines below will take effect when the graphs have\n                # more than one node type.\n                actual_data[:, 0] -= ntype_gnid_offset[NODE_TYPE][0, 0]\n                actual_data[:, 1] -= ntype_gnid_offset[NODE_TYPE][0, 0]\n\n                # Verify that the contents are equal\n                assert_array_equal(expected_data, actual_data)\n"
  },
  {
    "path": "tests/utils/__init__.py",
    "content": "import backend as F\nimport pytest\n\nparametrize_idtype = pytest.mark.parametrize(\"idtype\", [F.int32, F.int64])\n\nfrom .checks import *\nfrom .graph_cases import get_cases\n"
  },
  {
    "path": "tests/utils/checks.py",
    "content": "import backend as F\n\nimport dgl\nimport pytest\nfrom dgl.base import is_internal_column\n\n__all__ = [\n    \"check_fail\",\n    \"assert_is_identical\",\n    \"assert_is_identical_hetero\",\n    \"check_graph_equal\",\n]\n\n\ndef check_fail(fn, *args, **kwargs):\n    try:\n        fn(*args, **kwargs)\n        return False\n    except:\n        return True\n\n\ndef assert_is_identical(g, g2):\n    assert g.num_nodes() == g2.num_nodes()\n    src, dst = g.all_edges(order=\"eid\")\n    src2, dst2 = g2.all_edges(order=\"eid\")\n    assert F.array_equal(src, src2)\n    assert F.array_equal(dst, dst2)\n\n    assert len(g.ndata) == len(g2.ndata)\n    assert len(g.edata) == len(g2.edata)\n    for k in g.ndata:\n        assert F.allclose(g.ndata[k], g2.ndata[k])\n    for k in g.edata:\n        assert F.allclose(g.edata[k], g2.edata[k])\n\n\ndef assert_is_identical_hetero(g, g2, ignore_internal_data=False):\n    assert g.ntypes == g2.ntypes\n    assert g.canonical_etypes == g2.canonical_etypes\n\n    # check if two metagraphs are identical\n    for edges, features in g.metagraph().edges(keys=True).items():\n        assert g2.metagraph().edges(keys=True)[edges] == features\n\n    # check if node ID spaces and feature spaces are equal\n    for ntype in g.ntypes:\n        assert g.num_nodes(ntype) == g2.num_nodes(ntype)\n        if ignore_internal_data:\n            for k in list(g.nodes[ntype].data.keys()):\n                if is_internal_column(k):\n                    del g.nodes[ntype].data[k]\n            for k in list(g2.nodes[ntype].data.keys()):\n                if is_internal_column(k):\n                    del g2.nodes[ntype].data[k]\n        assert len(g.nodes[ntype].data) == len(g2.nodes[ntype].data)\n        for k in g.nodes[ntype].data:\n            assert F.allclose(g.nodes[ntype].data[k], g2.nodes[ntype].data[k])\n\n    # check if edge ID spaces and feature spaces are equal\n    for etype in g.canonical_etypes:\n        src, dst = 
g.all_edges(etype=etype, order=\"eid\")\n        src2, dst2 = g2.all_edges(etype=etype, order=\"eid\")\n        assert F.array_equal(src, src2)\n        assert F.array_equal(dst, dst2)\n        if ignore_internal_data:\n            for k in list(g.edges[etype].data.keys()):\n                if is_internal_column(k):\n                    del g.edges[etype].data[k]\n            for k in list(g2.edges[etype].data.keys()):\n                if is_internal_column(k):\n                    del g2.edges[etype].data[k]\n        assert len(g.edges[etype].data) == len(g2.edges[etype].data)\n        for k in g.edges[etype].data:\n            assert F.allclose(g.edges[etype].data[k], g2.edges[etype].data[k])\n\n\ndef check_graph_equal(g1, g2, *, check_idtype=True, check_feature=True):\n    assert g1.device == g2.device\n    if check_idtype:\n        assert g1.idtype == g2.idtype\n    assert g1.ntypes == g2.ntypes\n    assert g1.etypes == g2.etypes\n    assert g1.srctypes == g2.srctypes\n    assert g1.dsttypes == g2.dsttypes\n    assert g1.canonical_etypes == g2.canonical_etypes\n    assert g1.batch_size == g2.batch_size\n\n    # check if two metagraphs are identical\n    for edges, features in g1.metagraph().edges(keys=True).items():\n        assert g2.metagraph().edges(keys=True)[edges] == features\n\n    for nty in g1.ntypes:\n        assert g1.num_nodes(nty) == g2.num_nodes(nty)\n        assert F.allclose(g1.batch_num_nodes(nty), g2.batch_num_nodes(nty))\n    for ety in g1.canonical_etypes:\n        assert g1.num_edges(ety) == g2.num_edges(ety)\n        assert F.allclose(g1.batch_num_edges(ety), g2.batch_num_edges(ety))\n        src1, dst1, eid1 = g1.edges(etype=ety, form=\"all\")\n        src2, dst2, eid2 = g2.edges(etype=ety, form=\"all\")\n        if check_idtype:\n            assert F.allclose(src1, src2)\n            assert F.allclose(dst1, dst2)\n            assert F.allclose(eid1, eid2)\n        else:\n            assert F.allclose(src1, F.astype(src2, g1.idtype))\n    
        assert F.allclose(dst1, F.astype(dst2, g1.idtype))\n            assert F.allclose(eid1, F.astype(eid2, g1.idtype))\n\n    if check_feature:\n        for nty in g1.ntypes:\n            if g1.num_nodes(nty) == 0:\n                continue\n            for feat_name in g1.nodes[nty].data.keys():\n                assert F.allclose(\n                    g1.nodes[nty].data[feat_name], g2.nodes[nty].data[feat_name]\n                )\n        for ety in g1.canonical_etypes:\n            if g1.num_edges(ety) == 0:\n                continue\n            for feat_name in g2.edges[ety].data.keys():\n                assert F.allclose(\n                    g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name]\n                )\n"
  },
  {
    "path": "tests/utils/graph_cases.py",
    "content": "from collections import defaultdict\n\nimport backend as F\nimport dgl\nimport networkx as nx\nimport numpy as np\nimport scipy.sparse as ssp\n\ncase_registry = defaultdict(list)\n\n\ndef register_case(labels):\n    def wrapper(fn):\n        for lbl in labels:\n            case_registry[lbl].append(fn)\n        fn.__labels__ = labels\n        return fn\n\n    return wrapper\n\n\ndef get_cases(labels=None, exclude=[]):\n    \"\"\"Get all graph instances of the given labels.\"\"\"\n    cases = set()\n    if labels is None:\n        # get all the cases\n        labels = case_registry.keys()\n    for lbl in labels:\n        for case in case_registry[lbl]:\n            if not any([l in exclude for l in case.__labels__]):\n                cases.add(case)\n    return [fn() for fn in cases]\n\n\n@register_case([\"bipartite\", \"zero-degree\"])\ndef bipartite1():\n    return dgl.heterograph(\n        {(\"_U\", \"_E\", \"_V\"): ([0, 0, 0, 2, 2, 3], [0, 1, 4, 1, 4, 3])}\n    )\n\n\n@register_case([\"bipartite\"])\ndef bipartite_full():\n    return dgl.heterograph(\n        {\n            (\"_U\", \"_E\", \"_V\"): (\n                [0, 0, 0, 0, 1, 1, 1, 1],\n                [0, 1, 2, 3, 0, 1, 2, 3],\n            )\n        }\n    )\n\n\n@register_case([\"homo\"])\ndef graph0():\n    return dgl.graph(\n        (\n            [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],\n            [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5],\n        )\n    )\n\n\n@register_case([\"homo\", \"zero-degree\", \"homo-zero-degree\"])\ndef bipartite1():\n    return dgl.graph(([0, 0, 0, 2, 2, 3], [0, 1, 4, 1, 4, 3]))\n\n\n@register_case([\"homo\", \"has_feature\"])\ndef graph1():\n    g = dgl.graph(\n        (\n            [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],\n            [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5],\n        ),\n        device=F.cpu(),\n    )\n    g.ndata[\"h\"] = F.copy_to(F.randn((g.num_nodes(), 2)), F.cpu())\n    
g.edata[\"w\"] = F.copy_to(F.randn((g.num_edges(), 3)), F.cpu())\n    return g\n\n\n@register_case([\"homo\", \"has_scalar_e_feature\"])\ndef graph1():\n    g = dgl.graph(\n        (\n            [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],\n            [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5],\n        ),\n        device=F.cpu(),\n    )\n    g.ndata[\"h\"] = F.copy_to(F.randn((g.num_nodes(), 2)), F.cpu())\n    g.edata[\"scalar_w\"] = F.copy_to(F.abs(F.randn((g.num_edges(),))), F.cpu())\n    return g\n\n\n@register_case([\"homo\", \"row_sorted\"])\ndef graph2():\n    return dgl.graph(\n        (\n            [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],\n            [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5],\n        ),\n        row_sorted=True,\n    )\n\n\n@register_case([\"homo\", \"row_sorted\", \"col_sorted\"])\ndef graph3():\n    return dgl.graph(\n        (\n            [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],\n            [1, 4, 5, 2, 4, 7, 8, 9, 1, 4, 6, 0, 0, 1, 2, 3, 5],\n        ),\n        row_sorted=True,\n        col_sorted=True,\n    )\n\n\n@register_case([\"hetero\", \"has_feature\"])\ndef heterograph0():\n    g = dgl.heterograph(\n        {\n            (\"user\", \"plays\", \"game\"): ([0, 1, 1, 2], [0, 0, 1, 1]),\n            (\"developer\", \"develops\", \"game\"): ([0, 1], [0, 1]),\n        },\n        device=F.cpu(),\n    )\n    g.nodes[\"user\"].data[\"h\"] = F.copy_to(\n        F.randn((g.num_nodes(\"user\"), 3)), F.cpu()\n    )\n    g.nodes[\"game\"].data[\"h\"] = F.copy_to(\n        F.randn((g.num_nodes(\"game\"), 2)), F.cpu()\n    )\n    g.nodes[\"developer\"].data[\"h\"] = F.copy_to(\n        F.randn((g.num_nodes(\"developer\"), 3)), F.cpu()\n    )\n    g.edges[\"plays\"].data[\"h\"] = F.copy_to(\n        F.randn((g.num_edges(\"plays\"), 1)), F.cpu()\n    )\n    g.edges[\"develops\"].data[\"h\"] = F.copy_to(\n        F.randn((g.num_edges(\"develops\"), 5)), F.cpu()\n    )\n    return 
g\n\n\n@register_case([\"batched\", \"homo\"])\ndef batched_graph0():\n    g1 = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 3])))\n    g2 = dgl.add_self_loop(dgl.graph(([1, 1], [2, 0])))\n    g3 = dgl.add_self_loop(dgl.graph(([0], [1])))\n    return dgl.batch([g1, g2, g3])\n\n\n@register_case([\"block\", \"bipartite\", \"block-bipartite\"])\ndef block_graph0():\n    g = dgl.graph(([2, 3, 4], [5, 6, 7]), num_nodes=100)\n    g = g.to(F.cpu())\n    return dgl.to_block(g)\n\n\n@register_case([\"block\"])\ndef block_graph1():\n    g = dgl.heterograph(\n        {\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2], [1, 1, 0]),\n            (\"user\", \"likes\", \"game\"): ([1, 2, 3], [0, 0, 2]),\n            (\"store\", \"sells\", \"game\"): ([0, 1, 1], [0, 1, 2]),\n        },\n        device=F.cpu(),\n    )\n    return dgl.to_block(g)\n\n\n@register_case([\"clique\"])\ndef clique():\n    g = dgl.graph(([0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]))\n    return g\n\n\ndef random_dglgraph(size):\n    return dgl.DGLGraph(nx.erdos_renyi_graph(size, 0.3))\n\n\ndef random_graph(size):\n    return dgl.from_networkx(nx.erdos_renyi_graph(size, 0.3))\n\n\ndef random_bipartite(size_src, size_dst):\n    return dgl.bipartite_from_scipy(\n        ssp.random(size_src, size_dst, 0.1),\n        utype=\"_U\",\n        etype=\"_E\",\n        vtype=\"V\",\n    )\n\n\ndef random_block(size):\n    g = dgl.from_networkx(nx.erdos_renyi_graph(size, 0.1))\n    return dgl.to_block(g, np.unique(F.zerocopy_to_numpy(g.edges()[1])))\n\n\n@register_case([\"two_hetero_batch\"])\ndef two_hetero_batch():\n    g1 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"follows\", \"developer\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 3], [0, 0, 1, 1]),\n        }\n    )\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            
(\"user\", \"follows\", \"developer\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2], [0, 0, 1]),\n        }\n    )\n    return [g1, g2]\n\n\n@register_case([\"two_hetero_batch\"])\ndef two_hetero_batch_with_isolated_ntypes():\n    g1 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"follows\", \"developer\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 3], [0, 0, 1, 1]),\n        },\n        num_nodes_dict={\"user\": 4, \"game\": 2, \"developer\": 3, \"platform\": 2},\n    )\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"follows\", \"developer\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2], [0, 0, 1]),\n        },\n        num_nodes_dict={\"user\": 3, \"game\": 2, \"developer\": 3, \"platform\": 3},\n    )\n    return [g1, g2]\n\n\n@register_case([\"batched\", \"hetero\"])\ndef batched_heterograph0():\n    g1 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"follows\", \"developer\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2, 3], [0, 0, 1, 1]),\n        }\n    )\n    g2 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([0, 1], [1, 2]),\n            (\"user\", \"follows\", \"developer\"): ([0, 1], [1, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1, 2], [0, 0, 1]),\n        }\n    )\n    g3 = dgl.heterograph(\n        {\n            (\"user\", \"follows\", \"user\"): ([1], [2]),\n            (\"user\", \"follows\", \"developer\"): ([0, 1, 2], [0, 2, 2]),\n            (\"user\", \"plays\", \"game\"): ([0, 1], [0, 0]),\n        }\n    )\n    return dgl.batch([g1, g2, g3])\n"
  },
  {
    "path": "third_party/HugeCTR/gpu_cache/ReadMe.md",
    "content": "# GPU Embedding Cache\n\nThis project implements an embedding cache on GPU memory that is designed for CTR inference and training workload.\n\nThe cache stores the hot pairs, (embedding id, embedding vectors), on GPU memory.\nStoring the data on GPU memory reduces the traffic to the parameter server when performing embedding table lookup.\n\nThe cache is designed for CTR inference and training, it has following features and restrictions:\n\n* All the backup memory-side operations are performed by the parameter server.\n  These operations include prefetching, latency hiding, and so on.\n* This is a single-GPU design.\n  Each cache belongs to one GPU.\n* The cache is thread-safe: multiple workers, CPU threads, can concurrently call the API of a single cache object with well-defined behavior.\n* The cache implements a least recently used (LRU) replacement algorithm so that it caches the most recently queried embeddings.\n* The embeddings stored inside the cache are unique: there are no duplicated embedding IDs in the cache.\n\n## Project Structure\n\nThis project is a stand-alone module in HugeCTR project.\nThe root folder of this project is the `gpu_cache` folder under the HugeCTR root directory.\n\nThe `include` folder contains the headers for the cache library and the `src` folder contains the implementations and Makefile for the cache library.\nThe `test` folder contains a test that tests the correctness and performance of the GPU embedding cache.\nThe test also acts as sample code that shows how to use the cache.\n\nThe `nv_gpu_cache.hpp` file contains the definition of the main class, `gpu_cache`, that implements the GPU embedding cache.\nThe `nv_gpu_cache.cu` file contains the implementation.\n\nAs a module of HugeCTR, this project is built with and used by the HugeCTR project.\n\n## Supported Data Types\n\n* The cache supports 32 and 64-bit scalar integer types for the key (embedding ID) type.\n  For example, the data type declarations 
`unsigned int` and `long long` match these integer types.\n* The cache supports a vector of floats for the value (embedding vector) type.\n* You need to specify an empty key to indicate the empty bucket.\n  Do not use an empty key to represent any real key.\n* Refer to the instantiation code at the end of the `nv_gpu_cache.cu` file for template parameters.\n\n## Requirements\n\n* NVIDIA GPU >= Volta (SM 70).\n* CUDA environment >= 11.0.\n* (Optional) libcu++ library >= 1.1.0.\n  The CUDA Toolkit 11.0 (Early Access) and above meets the required library version.\n  Using the libcu++ library provides better performance and more precisely-defined behavior.\n  You can enable libcu++ library by defining the `LIBCUDACXX_VERSION` macro when you compile.\n  Otherwise, the libcu++ library is not enabled.\n* The default building option for HugeCTR is to disable the libcu++ library.\n\n## Usage Overview\n\n```c++\ntemplate<typename key_type,\n         typename ref_counter_type,\n         key_type empty_key,\n         int set_associativity,\n         int warp_size,\n         typename set_hasher = MurmurHash3_32<key_type>,\n         typename slab_hasher = Mod_Hash<key_type, size_t>>\nclass gpu_cache{\npublic:\n    //Ctor\n    gpu_cache(const size_t capacity_in_set, const size_t embedding_vec_size);\n\n    //Dtor\n    ~gpu_cache();\n\n    // Query API, i.e. A single read from the cache\n    void Query(const key_type* d_keys,\n               const size_t len,\n               float* d_values,\n               uint64_t* d_missing_index,\n               key_type* d_missing_keys,\n               size_t* d_missing_len,\n               cudaStream_t stream,\n               const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO);\n\n    // Replace API, i.e. 
Follow the Query API to update the content of the cache to Most Recent\n    void Replace(const key_type* d_keys,\n                 const size_t len,\n                 const float* d_values,\n                 cudaStream_t stream,\n                 const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO);\n\n    // Update API, i.e. update the embeddings which exist in the cache\n    void Update(const key_type* d_keys,\n                const size_t len,\n                const float* d_values,\n                cudaStream_t stream,\n                const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO);\n\n    // Dump API, i.e. dump some slabsets' keys from the cache\n    void Dump(key_type* d_keys,\n              size_t* d_dump_counter,\n              const size_t start_set_index,\n              const size_t end_set_index,\n              cudaStream_t stream);\n\n};\n```\n\n## API\n\n`Constructor`\n\nTo create a new embedding cache, you need to provide the following:\n\n* Template parameters:\n    + key_type: the data type of embedding ID.\n    + ref_counter_type: the data type of the internal counter. This data type should be 64bit unsigned integer(i.e. uint64_t), 32bit integer has the risk of overflow.\n    + empty_key: the key value indicate for empty bucket(i.e. The empty key), user should never use empty key value to represent any real keys.\n    + set_associativity: the hyper-parameter indicates how many slabs per cache set.(See `Performance hint` session below)\n    + warp_size: the hyper-parameter indicates how many [key, value] pairs per slab. Acceptable value includes 1/2/4/8/16/32.(See `Performance hint` session below)\n    + For other template parameters just use the default value.\n* Parameters:\n    + capacity_in_set: # of cache set in the embedding cache. 
So the total capacity of the embedding cache is `warp_size * set_associativity * capacity_in_set` [key, value] pairs.\n    + embedding_vec_size: # of float per a embedding vector.\n* The host thread will wait for the GPU kernels to complete before returning from the API, thus this API is synchronous with CPU thread. When returned, the initialization process of the cache is already done.\n* The embedding cache will be created on the GPU where user call the constructor. Thus, user should set the host thread to the target CUDA device before creating the embedding cache. All resources(i.e. device-side buffers, CUDA streams) used later for this embedding cache should be allocated on the same CUDA device as the embedding cache.\n* The constructor can be called only once, thus is not thread-safe.\n\n`Destructor`\n\n* The destructor clean up the embedding cache. This API should be called only once when user need to delete the embedding cache object, thus is not thread-safe.\n\n`Query`\n\n* Search `len` elements from device-side buffers `d_keys` in the cache and return the result in device-side buffer `d_values` if a key is hit in the cache.\n* If a key is missing, the missing key and its index in the `d_keys` buffer will be returned in device-side buffers `d_missing_keys` and `d_missing_index`. The # of missing key will be return in device-side buffer `d_missing_len`. For simplicity, these buffers should have the same length as `d_keys` to avoid out-of-bound access.\n* The GPU kernels will be launched in `stream` CUDA stream.\n* The host thread will return from the API immediately after the kernels are launched, thus this API is Asynchronous with CPU thread.\n* The keys to be queried in the `d_keys` buffer can have duplication. 
In this case, user will get duplicated returned values or missing information.\n* This API is thread-safe and can be called concurrently with other APIs.\n* For hyper-parameter `task_per_warp_tile`, see `Performance hint` session below.\n\n`Replace`\n\n* The API will replace `len` [key, value] pairs listed in `d_keys` and `d_values` into the embedding cache using the LRU replacement algorithm.\n* The GPU kernels will be launched in `stream` CUDA stream.\n* The host thread will return from the API immediately after the kernels are launched, thus this API is Asynchronous with CPU thread.\n* The keys to be replaced in the `d_keys` buffer can have duplication and can be already stored inside the cache. In these cases, the cache will detect any possible duplication and maintain the uniqueness of all the [key ,value] pairs stored in the cache.\n* This API is thread-safe and can be called concurrently with other APIs.\n* This API will first try to insert the [key, value] pairs into the cache if there is any empty slot. If the cache is full, it will do the replacement.\n* For hyper-parameter `task_per_warp_tile`, see `Performance hint` session below.\n\n`Update`\n\n* The API will search for `len` keys listed in `d_keys` buffer within the cache. If a key is found in the cache, this API will update the value associated with the key to the corresponding values provided in `d_values` buffer. If a key is not found in the cache, this API will do nothing to this key.\n* The GPU kernels will be launched in `stream` CUDA stream.\n* The host thread will return from the API immediately after the kernels are launched, thus this API is Asynchronous with CPU thread.\n* If the keys to be updated in the `d_keys` buffer have duplication, all values associated with this key in the `d_values` buffer will be updated to the cache atomically. 
The final result depends on the order of updating the value.\n* This API is thread-safe and can be called concurrently with other APIs.\n* For hyper-parameter `task_per_warp_tile`, see `Performance hint` session below.\n\n`Dump`\n\n* The API will dump all the keys stored in [`start_set_index`, `end_set_index`) cache sets to `d_keys` buffer as a linear array(the key order is not guaranteed). The total # of keys dumped will be reported in `d_dump_counter` variable.\n* The GPU kernels will be launched in `stream` CUDA stream.\n* The host thread will return from the API immediately after the kernels are launched, thus this API is Asynchronous with CPU thread.\n* This API is thread-safe and can be called concurrently with other APIs.\n\n## More Information\n\n* The detailed introduction of the GPU embedding cache data structure is presented at GTC China 2020: https://on-demand-gtc.gputechconf.com/gtcnew/sessionview.php?sessionName=cns20626-%e4%bd%bf%e7%94%a8+gpu+embedding+cache+%e5%8a%a0%e9%80%9f+ctr+%e6%8e%a8%e7%90%86%e8%bf%87%e7%a8%8b\n* The `test` folder contains a example of using the GPU embedding cache.\n* This project is used by `embedding_cache` class in `HugeCTR/include/inference/embedding_cache.hpp` which can be used as an example.\n\n## Performance Hint\n\n* The hyper-parameter `warp_size` should be keep as 32 by default. When the length for Query or Replace operations is small(~1-50k), user can choose smaller warp_size and increase the total # of cache set(while maintaining the same cache size) to increase the parallelism and improve the performance.\n* The hyper-parameter `set_associativity` is critical to performance:\n    + If set too small, may cause load imbalance between different cache sets(lower down the effective capacity of the cache, lower down the hit rate). To prevent this, the embedding cache uses a very random hash function to hash the keys to different cache set, thus will achieve load balance statistically. 
However, larger cache set will tends to have better load balance.\n    + If set too large, the searching space for a single key will be very large. The performance of the embedding cache API will drop dramatically. Also, each set will be accessed exclusively, thus the more cache sets the higher parallelism can be achieved.\n    + Recommend setting `set_associativity` to 2 or 4.\n* The runtime hyper-parameter `task_per_warp_tile` is set to 1 as default parameter, thus users don't need to change their code to accommodate this interface change. This hyper-parameter determines how many keys are been queried/replaced/updated by a single warp tile. The acceptable value is between [1, `warp_size`]. For small to medium size operations to the cache, less task per warp tile can increase the total # of warp tiles running concurrently on the GPU chip, thus can bring significant performance improvement. For large size operations to the cache, the increased # of warp tile will not bring any performance improvement(even a little regression on the performance, ~5%). User can choose the value for this parameter based on the value of `len` parameter.\n* The GPU is designed for optimizing throughput. Always try to batch up the inference task and try to have larger `query_size`.\n* As the APIs of the embedding cache is asynchronous with host threads. Try to optimize the E2E inference pipeline by overlapping asynchronous tasks on GPU or between CPU and GPU. For example, after retrieving the missing values from the parameter server, user can combine the missing values with the hit values and do the rest of inference pipeline at the same time with the `Replace` API. Replacement is not necessarily happens together with Query all the time, user can do query multiple times then do a replacement if the hit rate is acceptable.\n* Try different cache capacity and evaluate the hit rate. 
If the capacity of the embedding cache is larger than the actual embedding footprint, the hit rate can be as high as 99%+.\n"
  },
  {
    "path": "third_party/HugeCTR/gpu_cache/include/gpu_cache_api.hpp",
    "content": "/*\n * Copyright (c) 2021, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n#include <nv_util.h>\n\n#define TASK_PER_WARP_TILE_MACRO 1\n\nnamespace gpu_cache {\n\n///////////////////////////////////////////////////////////////////////////////////////////////////\n\n// GPU Cache API\ntemplate <typename key_type>\nclass gpu_cache_api {\n public:\n  virtual ~gpu_cache_api() noexcept(false) {}\n  // Query API, i.e. A single read from the cache\n  virtual void Query(const key_type* d_keys, const size_t len, float* d_values,\n                     uint64_t* d_missing_index, key_type* d_missing_keys, size_t* d_missing_len,\n                     cudaStream_t stream,\n                     const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) = 0;\n\n  // Replace API, i.e. Follow the Query API to update the content of the cache to Most Recent\n  virtual void Replace(const key_type* d_keys, const size_t len, const float* d_values,\n                       cudaStream_t stream,\n                       const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) = 0;\n\n  // Update API, i.e. update the embeddings which exist in the cache\n  virtual void Update(const key_type* d_keys, const size_t len, const float* d_values,\n                      cudaStream_t stream,\n                      const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) = 0;\n\n  // Dump API, i.e. 
dump some slabsets' keys from the cache\n  virtual void Dump(key_type* d_keys, size_t* d_dump_counter, const size_t start_set_index,\n                    const size_t end_set_index, cudaStream_t stream) = 0;\n\n  // Record all the lookup stream of a specific cache for Update/Replace sync\n  virtual void Record(cudaStream_t stream) = 0;\n};\n\n}  // namespace gpu_cache\n"
  },
  {
    "path": "third_party/HugeCTR/gpu_cache/include/hash_functions.cuh",
    "content": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#pragma once\n\n#include <cstdint>\n\n// MurmurHash3_32 implementation from\n// https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp\n//-----------------------------------------------------------------------------\n// MurmurHash3 was written by Austin Appleby, and is placed in the public\n// domain. The author hereby disclaims copyright to this source code.\n// Note - The x86 and x64 versions do _not_ produce the same results, as the\n// algorithms are optimized for their respective platforms. 
You can still\n// compile and run any of them on any platform, but your performance with the\n// non-native version will be less than optimal.\ntemplate <typename Key, uint32_t m_seed = 0>\nstruct MurmurHash3_32 {\n  using argument_type = Key;\n  using result_type = uint32_t;\n\n  /*__forceinline__\n  __host__ __device__\n  MurmurHash3_32() : m_seed( 0 ) {}*/\n\n  __forceinline__ __host__ __device__ static uint32_t rotl32(uint32_t x, int8_t r) {\n    return (x << r) | (x >> (32 - r));\n  }\n\n  __forceinline__ __host__ __device__ static uint32_t fmix32(uint32_t h) {\n    h ^= h >> 16;\n    h *= 0x85ebca6b;\n    h ^= h >> 13;\n    h *= 0xc2b2ae35;\n    h ^= h >> 16;\n    return h;\n  }\n\n  /* --------------------------------------------------------------------------*/\n  /**\n   * @Synopsis  Combines two hash values into a new single hash value. Called\n   * repeatedly to create a hash value from several variables.\n   * Taken from the Boost hash_combine function\n   * https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html\n   *\n   * @Param lhs The first hash value to combine\n   * @Param rhs The second hash value to combine\n   *\n   * @Returns A hash value that intelligently combines the lhs and rhs hash values\n   */\n  /* ----------------------------------------------------------------------------*/\n  __host__ __device__ static result_type hash_combine(result_type lhs, result_type rhs) {\n    result_type combined{lhs};\n\n    combined ^= rhs + 0x9e3779b9 + (combined << 6) + (combined >> 2);\n\n    return combined;\n  }\n\n  __forceinline__ __host__ __device__ static result_type hash(const Key& key) {\n    constexpr int len = sizeof(argument_type);\n    const uint8_t* const data = (const uint8_t*)&key;\n    constexpr int nblocks = len / 4;\n    uint32_t h1 = m_seed;\n    constexpr uint32_t c1 = 0xcc9e2d51;\n    constexpr uint32_t c2 = 0x1b873593;\n    //----------\n    // body\n    const uint32_t* const blocks = (const uint32_t*)(data 
+ nblocks * 4);\n    for (int i = -nblocks; i; i++) {\n      uint32_t k1 = blocks[i];  // getblock32(blocks,i);\n      k1 *= c1;\n      k1 = rotl32(k1, 15);\n      k1 *= c2;\n      h1 ^= k1;\n      h1 = rotl32(h1, 13);\n      h1 = h1 * 5 + 0xe6546b64;\n    }\n    //----------\n    // tail\n    const uint8_t* tail = (const uint8_t*)(data + nblocks * 4);\n    uint32_t k1 = 0;\n    switch (len & 3) {\n      case 3:\n        k1 ^= tail[2] << 16;\n      case 2:\n        k1 ^= tail[1] << 8;\n      case 1:\n        k1 ^= tail[0];\n        k1 *= c1;\n        k1 = rotl32(k1, 15);\n        k1 *= c2;\n        h1 ^= k1;\n    };\n    //----------\n    // finalization\n    h1 ^= len;\n    h1 = fmix32(h1);\n    return h1;\n  }\n\n  __host__ __device__ __forceinline__ result_type operator()(const Key& key) const {\n    return this->hash(key);\n  }\n};\n\ntemplate <typename key_type, typename index_type, index_type result>\nstruct Fix_Hash {\n  using result_type = index_type;\n\n  __forceinline__ __host__ __device__ static index_type hash(const key_type& key) { return result; }\n};\n\ntemplate <typename key_type, typename result_type>\nstruct Mod_Hash {\n  __forceinline__ __host__ __device__ static result_type hash(const key_type& key) {\n    return (result_type)key;\n  }\n};\n"
  },
  {
    "path": "third_party/HugeCTR/gpu_cache/include/nv_gpu_cache.hpp",
    "content": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#pragma once\n\n#include <nv_util.h>\n\n#include <cstdio>\n#include <hash_functions.cuh>\n#include <limits>\n\n#include \"gpu_cache_api.hpp\"\n#ifdef LIBCUDACXX_VERSION\n#include <cuda/atomic>\n#include <cuda/semaphore>\n#endif\n\n#define SET_ASSOCIATIVITY 2\n#define SLAB_SIZE 32\n#define TASK_PER_WARP_TILE_MACRO 1\n\nnamespace gpu_cache {\n\n// slab for static slab list\ntemplate <typename key_type, int warp_size>\nstruct static_slab {\n  key_type slab_[warp_size];\n};\n\n// Static slablist(slabset) for GPU Cache\ntemplate <int set_associativity, typename key_type, int warp_size>\nstruct slab_set {\n  static_slab<key_type, warp_size> set_[set_associativity];\n};\n\n///////////////////////////////////////////////////////////////////////////////////////////////////\n\n// GPU Cache\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher = MurmurHash3_32<key_type>,\n          typename slab_hasher = Mod_Hash<key_type, size_t>>\nclass gpu_cache : public gpu_cache_api<key_type> {\n public:\n  // Ctor\n  gpu_cache(const size_t capacity_in_set, const size_t embedding_vec_size);\n\n  // Dtor\n  ~gpu_cache();\n\n  // Query API, i.e. 
A single read from the cache\n  void Query(const key_type* d_keys, const size_t len, float* d_values, uint64_t* d_missing_index,\n             key_type* d_missing_keys, size_t* d_missing_len, cudaStream_t stream,\n             const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) override;\n\n  // Replace API, i.e. Follow the Query API to update the content of the cache to Most Recent\n  void Replace(const key_type* d_keys, const size_t len, const float* d_values, cudaStream_t stream,\n               const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) override;\n\n  // Update API, i.e. update the embeddings which exist in the cache\n  void Update(const key_type* d_keys, const size_t len, const float* d_values, cudaStream_t stream,\n              const size_t task_per_warp_tile = TASK_PER_WARP_TILE_MACRO) override;\n\n  // Dump API, i.e. dump some slabsets' keys from the cache\n  void Dump(key_type* d_keys, size_t* d_dump_counter, const size_t start_set_index,\n            const size_t end_set_index, cudaStream_t stream) override;\n\n  void Record(cudaStream_t stream) override {}\n\n public:\n  using slabset = slab_set<set_associativity, key_type, warp_size>;\n#ifdef LIBCUDACXX_VERSION\n  using atomic_ref_counter_type = cuda::atomic<ref_counter_type, cuda::thread_scope_device>;\n  using mutex = cuda::binary_semaphore<cuda::thread_scope_device>;\n#endif\n\n private:\n  static const size_t BLOCK_SIZE_ = 64;\n\n  // Cache data\n  slabset* keys_;\n  float* vals_;\n  ref_counter_type* slot_counter_;\n\n  // Global counter\n#ifdef LIBCUDACXX_VERSION\n  atomic_ref_counter_type* global_counter_;\n#else\n  ref_counter_type* global_counter_;\n#endif\n  // CUDA device\n  int dev_;\n\n  // Cache capacity\n  size_t capacity_in_set_;\n  size_t num_slot_;\n\n  // Embedding vector size\n  size_t embedding_vec_size_;\n\n#ifdef LIBCUDACXX_VERSION\n  // Array of mutex to protect (sub-)warp-level data structure, each mutex protect 1 slab set\n  mutex* set_mutex_;\n#else\n 
 // Array of flag to protect (sub-)warp-level data structure, each flag act as a mutex and protect\n  // 1 slab set 1 for unlock, 0 for lock\n  int* set_mutex_;\n#endif\n};\n\n}  // namespace gpu_cache\n"
  },
  {
    "path": "third_party/HugeCTR/gpu_cache/include/nv_util.h",
    "content": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#pragma once\n\n#include <cuda_fp16.h>\n#include <cuda_fp8.h>\n#include <cuda_runtime_api.h>\n\n#include <stdexcept>\n#include <string>\n\n#define CUDA_CHECK(val) \\\n  { nv::cuda_check_((val), __FILE__, __LINE__); }\n\nnamespace nv {\n\ntemplate <typename T>\nstruct is_fp8 : std::false_type {};\n\ntemplate <>\nstruct is_fp8<__nv_fp8_e4m3> : std::true_type {};\n\ntemplate <>\nstruct is_fp8<__nv_fp8_e5m2> : std::true_type {};\n\nclass CudaException : public std::runtime_error {\n public:\n  CudaException(const std::string& what) : runtime_error(what) {}\n};\n\ninline void cuda_check_(cudaError_t val, const char* file, int line) {\n  if (val != cudaSuccess) {\n    throw CudaException(std::string(file) + \":\" + std::to_string(line) + \": CUDA error \" +\n                        std::to_string(val) + \": \" + cudaGetErrorString(val));\n  }\n}\n\nclass CudaDeviceRestorer {\n public:\n  CudaDeviceRestorer() { CUDA_CHECK(cudaGetDevice(&dev_)); }\n  ~CudaDeviceRestorer() { CUDA_CHECK(cudaSetDevice(dev_)); }\n  void check_device(int device) const {\n    if (device != dev_) {\n      throw std::runtime_error(\n          std::string(__FILE__) + \":\" + std::to_string(__LINE__) +\n          \": Runtime Error: The device id in the context is not consistent with configuration\");\n    }\n  }\n\n private:\n  int dev_;\n};\n\ninline int get_dev(const 
void* ptr) {\n  cudaPointerAttributes attr;\n  CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));\n  int dev = -1;\n\n#if CUDART_VERSION >= 10000\n  if (attr.type == cudaMemoryTypeDevice)\n#else\n  if (attr.memoryType == cudaMemoryTypeDevice)\n#endif\n  {\n    dev = attr.device;\n  }\n  return dev;\n}\n\ninline void switch_to_dev(const void* ptr) {\n  int dev = get_dev(ptr);\n  if (dev >= 0) {\n    CUDA_CHECK(cudaSetDevice(dev));\n  }\n}\n\n}  // namespace nv\n"
  },
  {
    "path": "third_party/HugeCTR/gpu_cache/src/nv_gpu_cache.cu",
    "content": "/*\n * Copyright (c) 2023, NVIDIA CORPORATION.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <cooperative_groups.h>\n\n#include <nv_gpu_cache.hpp>\n\nnamespace cg = cooperative_groups;\n\n// Overload CUDA atomic for other 64bit unsigned/signed integer type\n__forceinline__ __device__ long atomicAdd(long* address, long val) {\n  return (long)atomicAdd((unsigned long long*)address, (unsigned long long)val);\n}\n\n__forceinline__ __device__ long long atomicAdd(long long* address, long long val) {\n  return (long long)atomicAdd((unsigned long long*)address, (unsigned long long)val);\n}\n\n__forceinline__ __device__ unsigned long atomicAdd(unsigned long* address, unsigned long val) {\n  return (unsigned long)atomicAdd((unsigned long long*)address, (unsigned long long)val);\n}\n\nnamespace gpu_cache {\n\n#ifdef LIBCUDACXX_VERSION\ntemplate <int warp_size>\n__forceinline__ __device__ void warp_tile_copy(const size_t lane_idx,\n                                               const size_t emb_vec_size_in_float, float* d_dst,\n                                               const float* d_src) {\n#pragma unroll\n  for (size_t i = lane_idx; i < emb_vec_size_in_float; i += warp_size) {\n    d_dst[i] = d_src[i];\n  }\n}\n#else\ntemplate <int warp_size>\n__forceinline__ __device__ void warp_tile_copy(const size_t lane_idx,\n                                               const size_t emb_vec_size_in_float,\n                         
                      volatile float* d_dst, volatile float* d_src) {\n\n#pragma unroll\n  for (size_t i = lane_idx; i < emb_vec_size_in_float; i += warp_size) {\n    d_dst[i] = d_src[i];\n  }\n}\n#endif\n\n#ifdef LIBCUDACXX_VERSION\n// Will be called by multiple thread_block_tile((sub-)warp) on the same mutex\n// Expect only one thread_block_tile return to execute critical section at any time\ntemplate <typename mutex, int warp_size>\n__forceinline__ __device__ void warp_lock_mutex(const cg::thread_block_tile<warp_size>& warp_tile,\n                                                mutex& set_mutex) {\n  // The first thread of this (sub-)warp to acquire the lock\n  if (warp_tile.thread_rank() == 0) {\n    set_mutex.acquire();\n  }\n  warp_tile.sync();  // Synchronize the threads in the (sub-)warp. Execution barrier + memory fence\n}\n\n// The (sub-)warp holding the mutex will unlock the mutex after finishing the critical section on a\n// set Expect any following (sub-)warp that acquire the mutex can see its modification done in the\n// critical section\ntemplate <typename mutex, int warp_size>\n__forceinline__ __device__ void warp_unlock_mutex(const cg::thread_block_tile<warp_size>& warp_tile,\n                                                  mutex& set_mutex) {\n  warp_tile.sync();  // Synchronize the threads in the (sub-)warp. 
Execution barrier + memory fence\n  // The first thread of this (sub-)warp to release the lock\n  if (warp_tile.thread_rank() == 0) {\n    set_mutex.release();\n  }\n}\n#else\n// Will be called by multiple thread_block_tile((sub-)warp) on the same mutex\n// Expect only one thread_block_tile return to execute critical section at any time\ntemplate <int warp_size>\n__forceinline__ __device__ void warp_lock_mutex(const cg::thread_block_tile<warp_size>& warp_tile,\n                                                volatile int& set_mutex) {\n  // The first thread of this (sub-)warp to acquire the lock\n  if (warp_tile.thread_rank() == 0) {\n    while (0 == atomicCAS((int*)&set_mutex, 1, 0))\n      ;\n  }\n  __threadfence();\n  warp_tile.sync();  // Synchronize the threads in the (sub-)warp. Execution barrier + memory fence\n}\n\n// The (sub-)warp holding the mutex will unlock the mutex after finishing the critical section on a\n// set Expect any following (sub-)warp that acquire the mutex can see its modification done in the\n// critical section\ntemplate <int warp_size>\n__forceinline__ __device__ void warp_unlock_mutex(const cg::thread_block_tile<warp_size>& warp_tile,\n                                                  volatile int& set_mutex) {\n  __threadfence();\n  warp_tile.sync();  // Synchronize the threads in the (sub-)warp. 
Execution barrier + memory fence\n  // The first thread of this (sub-)warp to release the lock\n  if (warp_tile.thread_rank() == 0) {\n    atomicExch((int*)&set_mutex, 1);\n  }\n}\n#endif\n\n// The (sub-)warp doing all reduction to find the slot with min slot_counter\n// The slot with min slot_counter is the LR slot.\ntemplate <typename ref_counter_type, int warp_size>\n__forceinline__ __device__ void warp_min_reduction(\n    const cg::thread_block_tile<warp_size>& warp_tile, ref_counter_type& min_slot_counter_val,\n    size_t& slab_distance, size_t& slot_distance) {\n  const size_t lane_idx = warp_tile.thread_rank();\n  slot_distance = lane_idx;\n\n  for (size_t i = (warp_tile.size() >> 1); i > 0; i = i >> 1) {\n    ref_counter_type input_slot_counter_val = warp_tile.shfl_xor(min_slot_counter_val, (int)i);\n    size_t input_slab_distance = warp_tile.shfl_xor(slab_distance, (int)i);\n    size_t input_slot_distance = warp_tile.shfl_xor(slot_distance, (int)i);\n\n    if (input_slot_counter_val == min_slot_counter_val) {\n      if (input_slab_distance == slab_distance) {\n        if (input_slot_distance < slot_distance) {\n          slot_distance = input_slot_distance;\n        }\n      } else if (input_slab_distance < slab_distance) {\n        slab_distance = input_slab_distance;\n        slot_distance = input_slot_distance;\n      }\n    } else if (input_slot_counter_val < min_slot_counter_val) {\n      min_slot_counter_val = input_slot_counter_val;\n      slab_distance = input_slab_distance;\n      slot_distance = input_slot_distance;\n    }\n  }\n}\n\n///////////////////////////////////////////////////////////////////////////////////////////////////\n\n#ifdef LIBCUDACXX_VERSION\n// Kernel to initialize the GPU cache\n// Init every entry of the cache with <unused_key, value> pair\ntemplate <typename slabset, typename ref_counter_type, typename atomic_ref_counter_type,\n          typename key_type, typename mutex>\n__global__ void init_cache(slabset* keys, 
ref_counter_type* slot_counter,\n                           atomic_ref_counter_type* global_counter, const size_t num_slot,\n                           const key_type empty_key, mutex* set_mutex,\n                           const size_t capacity_in_set) {\n  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx < num_slot) {\n    // Set the key of this slot to unused key\n    // Flatten the cache\n    key_type* key_slot = (key_type*)keys;\n    key_slot[idx] = empty_key;\n\n    // Clear the counter for this slot\n    slot_counter[idx] = 0;\n  }\n  // First CUDA thread clear the global counter\n  if (idx == 0) {\n    new (global_counter) atomic_ref_counter_type(0);\n  }\n\n  // First capacity_in_set CUDA thread initialize mutex\n  if (idx < capacity_in_set) {\n    new (set_mutex + idx) mutex(1);\n  }\n}\n\ntemplate <typename atomic_ref_counter_type, typename mutex>\n__global__ void destruct_kernel(atomic_ref_counter_type* global_counter, mutex* set_mutex,\n                                const size_t capacity_in_set) {\n  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;\n  // First CUDA thread destruct the global_counter\n  if (idx == 0) {\n    global_counter->~atomic_ref_counter_type();\n  }\n  // First capacity_in_set CUDA thread destruct the set mutex\n  if (idx < capacity_in_set) {\n    (set_mutex + idx)->~mutex();\n  }\n}\n#else\n// Kernel to initialize the GPU cache\n// Init every entry of the cache with <unused_key, value> pair\ntemplate <typename slabset, typename ref_counter_type, typename key_type>\n__global__ void init_cache(slabset* keys, ref_counter_type* slot_counter,\n                           ref_counter_type* global_counter, const size_t num_slot,\n                           const key_type empty_key, int* set_mutex, const size_t capacity_in_set) {\n  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx < num_slot) {\n    // Set the key of this slot to unused key\n    // Flatten the cache\n    key_type* 
key_slot = (key_type*)keys;\n    key_slot[idx] = empty_key;\n\n    // Clear the counter for this slot\n    slot_counter[idx] = 0;\n  }\n  // First CUDA thread clear the global counter\n  if (idx == 0) {\n    global_counter[idx] = 0;\n  }\n\n  // First capacity_in_set CUDA thread initialize mutex\n  if (idx < capacity_in_set) {\n    set_mutex[idx] = 1;\n  }\n}\n#endif\n\n// Kernel to update global counter\n// Resolve distance overflow issue as well\n#ifdef LIBCUDACXX_VERSION\ntemplate <typename atomic_ref_counter_type>\n__global__ void update_kernel_overflow_ignore(atomic_ref_counter_type* global_counter,\n                                              size_t* d_missing_len) {\n  // Update global counter\n  global_counter->fetch_add(1, cuda::std::memory_order_relaxed);\n  *d_missing_len = 0;\n}\n#else\ntemplate <typename ref_counter_type>\n__global__ void update_kernel_overflow_ignore(ref_counter_type* global_counter,\n                                              size_t* d_missing_len) {\n  // Update global counter\n  atomicAdd(global_counter, 1);\n  *d_missing_len = 0;\n}\n#endif\n\n#ifdef LIBCUDACXX_VERSION\n// Kernel to read from cache\n// Also update locality information for touched slot\ntemplate <typename key_type, typename ref_counter_type, typename atomic_ref_counter_type,\n          typename slabset, typename set_hasher, typename slab_hasher, typename mutex,\n          key_type empty_key, int set_associativity, int warp_size>\n__global__ void get_kernel(const key_type* d_keys, const size_t len, float* d_values,\n                           const size_t embedding_vec_size, uint64_t* d_missing_index,\n                           key_type* d_missing_keys, size_t* d_missing_len,\n                           const atomic_ref_counter_type* global_counter,\n                           ref_counter_type* slot_counter, const size_t capacity_in_set,\n                           const slabset* keys, const float* vals, mutex* set_mutex,\n                           const 
size_t task_per_warp_tile) {\n  // Lane(thread) ID within a warp_tile\n  cg::thread_block_tile<warp_size> warp_tile =\n      cg::tiled_partition<warp_size>(cg::this_thread_block());\n  const size_t lane_idx = warp_tile.thread_rank();\n  // Warp tile global ID\n  const size_t warp_tile_global_idx =\n      (blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();\n  // The index of key for this thread\n  const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;\n  // The assigned key for this lane(thread)\n  key_type key;\n  // The dst slabset and the dst slab inside this set\n  size_t src_set;\n  size_t src_slab;\n  // The variable that contains the missing key\n  key_type missing_key;\n  // The variable that contains the index for the missing key\n  uint64_t missing_index;\n  // The counter for counting the missing key in this warp\n  uint8_t warp_missing_counter = 0;\n  // Active flag: whether current lane(thread) has unfinished task\n  bool active = false;\n  if (lane_idx < task_per_warp_tile) {\n    if (key_idx < len) {\n      active = true;\n      key = d_keys[key_idx];\n      src_set = set_hasher::hash(key) % capacity_in_set;\n      src_slab = slab_hasher::hash(key) % set_associativity;\n    }\n  }\n\n  // Lane participate in warp_tile ballot to produce warp-level work queue\n  unsigned active_mask = warp_tile.ballot(active);\n\n  // The warp-level outer loop: finish all the tasks within the work queue\n  while (active_mask != 0) {\n    // Next task in the work quere, start from lower index lane(thread)\n    int next_lane = __ffs(active_mask) - 1;\n    // Broadcast the task and the global index to all lane in the warp_tile\n    key_type next_key = warp_tile.shfl(key, next_lane);\n    size_t next_idx = warp_tile.shfl(key_idx, next_lane);\n    size_t next_set = warp_tile.shfl(src_set, next_lane);\n    size_t next_slab = warp_tile.shfl(src_slab, next_lane);\n\n    // Counter to record how many slab have been searched\n    size_t 
counter = 0;\n\n    // Working queue before task started\n    const unsigned old_active_mask = active_mask;\n\n    // Lock the slabset before operating the slabset\n    warp_lock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);\n\n    // The warp-level inner loop: finish a single task in the work queue\n    while (active_mask == old_active_mask) {\n      // When all the slabs inside a slabset have been searched, mark missing task, task is\n      // completed\n      if (counter >= set_associativity) {\n        if (lane_idx == warp_missing_counter) {\n          missing_key = next_key;\n          missing_index = next_idx;\n        }\n\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        warp_missing_counter++;\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // The warp_tile read out the slab\n      key_type read_key = keys[next_set].set_[next_slab].slab_[lane_idx];\n\n      // Compare the slab data with the target key\n      int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;\n\n      // If found, mark hit task, copy the founded data, the task is completed\n      if (found_lane >= 0) {\n        size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;\n        if (lane_idx == (size_t)next_lane) {\n          slot_counter[found_offset] = global_counter->load(cuda::std::memory_order_relaxed);\n          active = false;\n        }\n\n        warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,\n                                  d_values + next_idx * embedding_vec_size,\n                                  vals + found_offset * embedding_vec_size);\n\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // Compare the slab data with empty key, if found empty key, mark missing task, task is\n      // completed\n      if (warp_tile.ballot(read_key == empty_key) != 0) {\n        if (lane_idx == 
warp_missing_counter) {\n          missing_key = next_key;\n          missing_index = next_idx;\n        }\n\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        warp_missing_counter++;\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // Not found in this slab, the task is not completed, goto searching next slab\n      counter++;\n      next_slab = (next_slab + 1) % set_associativity;\n    }\n\n    // Unlock the slabset after operating the slabset\n    warp_unlock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);\n  }\n\n  // After warp_tile complete the working queue, save the result for output\n  // First thread of the warp_tile accumulate the missing length to global variable\n  size_t warp_position;\n  if (lane_idx == 0) {\n    warp_position = atomicAdd(d_missing_len, (size_t)warp_missing_counter);\n  }\n  warp_position = warp_tile.shfl(warp_position, 0);\n\n  if (lane_idx < warp_missing_counter) {\n    d_missing_keys[warp_position + lane_idx] = missing_key;\n    d_missing_index[warp_position + lane_idx] = missing_index;\n  }\n}\n#else\n// Kernel to read from cache\n// Also update locality information for touched slot\ntemplate <typename key_type, typename ref_counter_type, typename slabset, typename set_hasher,\n          typename slab_hasher, key_type empty_key, int set_associativity, int warp_size>\n__global__ void get_kernel(const key_type* d_keys, const size_t len, float* d_values,\n                           const size_t embedding_vec_size, uint64_t* d_missing_index,\n                           key_type* d_missing_keys, size_t* d_missing_len,\n                           ref_counter_type* global_counter,\n                           volatile ref_counter_type* slot_counter, const size_t capacity_in_set,\n                           volatile slabset* keys, volatile float* vals, volatile int* set_mutex,\n                           const size_t task_per_warp_tile) {\n  // 
Lane(thread) ID within a warp_tile\n  cg::thread_block_tile<warp_size> warp_tile =\n      cg::tiled_partition<warp_size>(cg::this_thread_block());\n  const size_t lane_idx = warp_tile.thread_rank();\n  // Warp tile global ID\n  const size_t warp_tile_global_idx =\n      (blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();\n  // The index of key for this thread\n  const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;\n  // The assigned key for this lane(thread)\n  key_type key;\n  // The dst slabset and the dst slab inside this set\n  size_t src_set;\n  size_t src_slab;\n  // The variable that contains the missing key\n  key_type missing_key;\n  // The variable that contains the index for the missing key\n  uint64_t missing_index;\n  // The counter for counting the missing key in this warp\n  uint8_t warp_missing_counter = 0;\n  // Active flag: whether current lane(thread) has unfinished task\n  bool active = false;\n  if (lane_idx < task_per_warp_tile) {\n    if (key_idx < len) {\n      active = true;\n      key = d_keys[key_idx];\n      src_set = set_hasher::hash(key) % capacity_in_set;\n      src_slab = slab_hasher::hash(key) % set_associativity;\n    }\n  }\n\n  // Lane participate in warp_tile ballot to produce warp-level work queue\n  unsigned active_mask = warp_tile.ballot(active);\n\n  // The warp-level outer loop: finish all the tasks within the work queue\n  while (active_mask != 0) {\n    // Next task in the work quere, start from lower index lane(thread)\n    int next_lane = __ffs(active_mask) - 1;\n    // Broadcast the task and the global index to all lane in the warp_tile\n    key_type next_key = warp_tile.shfl(key, next_lane);\n    size_t next_idx = warp_tile.shfl(key_idx, next_lane);\n    size_t next_set = warp_tile.shfl(src_set, next_lane);\n    size_t next_slab = warp_tile.shfl(src_slab, next_lane);\n\n    // Counter to record how many slab have been searched\n    size_t counter = 0;\n\n    // Working 
queue before task started\n    const unsigned old_active_mask = active_mask;\n\n    // Lock the slabset before operating the slabset\n    warp_lock_mutex<warp_size>(warp_tile, set_mutex[next_set]);\n\n    // The warp-level inner loop: finish a single task in the work queue\n    while (active_mask == old_active_mask) {\n      // When all the slabs inside a slabset have been searched, mark missing task, task is\n      // completed\n      if (counter >= set_associativity) {\n        if (lane_idx == warp_missing_counter) {\n          missing_key = next_key;\n          missing_index = next_idx;\n        }\n\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        warp_missing_counter++;\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // The warp_tile read out the slab\n      key_type read_key = ((volatile key_type*)(keys[next_set].set_[next_slab].slab_))[lane_idx];\n\n      // Compare the slab data with the target key\n      int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;\n\n      // If found, mark hit task, copy the founded data, the task is completed\n      if (found_lane >= 0) {\n        size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;\n        if (lane_idx == (size_t)next_lane) {\n          slot_counter[found_offset] = atomicAdd(global_counter, 0);\n          active = false;\n        }\n\n        warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,\n                                  (volatile float*)(d_values + next_idx * embedding_vec_size),\n                                  (volatile float*)(vals + found_offset * embedding_vec_size));\n\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // Compare the slab data with empty key, if found empty key, mark missing task, task is\n      // completed\n      if (warp_tile.ballot(read_key == empty_key) != 0) {\n        if (lane_idx == warp_missing_counter) 
{\n          missing_key = next_key;\n          missing_index = next_idx;\n        }\n\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        warp_missing_counter++;\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // Not found in this slab, the task is not completed, goto searching next slab\n      counter++;\n      next_slab = (next_slab + 1) % set_associativity;\n    }\n\n    // Unlock the slabset after operating the slabset\n    warp_unlock_mutex<warp_size>(warp_tile, set_mutex[next_set]);\n  }\n\n  // After warp_tile complete the working queue, save the result for output\n  // First thread of the warp_tile accumulate the missing length to global variable\n  size_t warp_position;\n  if (lane_idx == 0) {\n    warp_position = atomicAdd(d_missing_len, (size_t)warp_missing_counter);\n  }\n  warp_position = warp_tile.shfl(warp_position, 0);\n\n  if (lane_idx < warp_missing_counter) {\n    d_missing_keys[warp_position + lane_idx] = missing_key;\n    d_missing_index[warp_position + lane_idx] = missing_index;\n  }\n}\n#endif\n\n#ifdef LIBCUDACXX_VERSION\n// Kernel to insert or replace the <k,v> pairs into the cache\ntemplate <typename key_type, typename slabset, typename ref_counter_type, typename mutex,\n          typename atomic_ref_counter_type, typename set_hasher, typename slab_hasher,\n          key_type empty_key, int set_associativity, int warp_size,\n          ref_counter_type max_ref_counter_type = std::numeric_limits<ref_counter_type>::max(),\n          size_t max_slab_distance = std::numeric_limits<size_t>::max()>\n__global__ void insert_replace_kernel(const key_type* d_keys, const float* d_values,\n                                      const size_t embedding_vec_size, const size_t len,\n                                      slabset* keys, float* vals, ref_counter_type* slot_counter,\n                                      mutex* set_mutex,\n                                      const 
atomic_ref_counter_type* global_counter,\n                                      const size_t capacity_in_set,\n                                      const size_t task_per_warp_tile) {\n  // Lane(thread) ID within a warp_tile\n  cg::thread_block_tile<warp_size> warp_tile =\n      cg::tiled_partition<warp_size>(cg::this_thread_block());\n  const size_t lane_idx = warp_tile.thread_rank();\n  // Warp tile global ID\n  const size_t warp_tile_global_idx =\n      (blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();\n  // The index of key for this thread\n  const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;\n  // The assigned key for this lane(thread)\n  key_type key;\n  // The dst slabset and the dst slab inside this set\n  size_t src_set;\n  size_t src_slab;\n  // Active flag: whether current lane(thread) has unfinished task\n  bool active = false;\n  if (lane_idx < task_per_warp_tile) {\n    if (key_idx < len) {\n      active = true;\n      key = d_keys[key_idx];\n      src_set = set_hasher::hash(key) % capacity_in_set;\n      src_slab = slab_hasher::hash(key) % set_associativity;\n    }\n  }\n\n  // Lane participate in warp_tile ballot to produce warp-level work queue\n  unsigned active_mask = warp_tile.ballot(active);\n\n  // The warp-level outer loop: finish all the tasks within the work queue\n  while (active_mask != 0) {\n    // Next task in the work quere, start from lower index lane(thread)\n    int next_lane = __ffs(active_mask) - 1;\n    // Broadcast the task, the global index and the src slabset and slab to all lane in a warp_tile\n    key_type next_key = warp_tile.shfl(key, next_lane);\n    size_t next_idx = warp_tile.shfl(key_idx, next_lane);\n    size_t next_set = warp_tile.shfl(src_set, next_lane);\n    size_t next_slab = warp_tile.shfl(src_slab, next_lane);\n    size_t first_slab = next_slab;\n\n    // Counter to record how many slab have been searched\n    size_t counter = 0;\n\n    // Variable to keep the 
min slot counter during the probing\n    ref_counter_type min_slot_counter_val = max_ref_counter_type;\n    // Variable to keep the slab distance for slot with min counter\n    size_t slab_distance = max_slab_distance;\n    // Variable to keep the slot distance for slot with min counter within the slab\n    size_t slot_distance;\n    // Working queue before task started\n    const unsigned old_active_mask = active_mask;\n\n    // Lock the slabset before operating the slabset\n    warp_lock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);\n\n    // The warp-level inner loop: finish a single task in the work queue\n    while (active_mask == old_active_mask) {\n      // When all the slabs inside a slabset have been searched\n      // and no empty slots or target slots are found. Replace with LRU\n      if (counter >= set_associativity) {\n        // (sub)Warp all-reduction, the reduction result store in all threads\n        warp_min_reduction<ref_counter_type, warp_size>(warp_tile, min_slot_counter_val,\n                                                        slab_distance, slot_distance);\n\n        // Calculate the position of LR slot\n        size_t target_slab = (first_slab + slab_distance) % set_associativity;\n        size_t slot_index =\n            (next_set * set_associativity + target_slab) * warp_size + slot_distance;\n\n        // Replace the LR slot\n        if (lane_idx == (size_t)next_lane) {\n          keys[next_set].set_[target_slab].slab_[slot_distance] = key;\n          slot_counter[slot_index] = global_counter->load(cuda::std::memory_order_relaxed);\n        }\n\n        warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,\n                                  vals + slot_index * embedding_vec_size,\n                                  d_values + next_idx * embedding_vec_size);\n\n        // Replace complete, mark this task completed\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        active_mask = 
warp_tile.ballot(active);\n        break;\n      }\n\n      // The warp_tile read out the slab\n      key_type read_key = keys[next_set].set_[next_slab].slab_[lane_idx];\n\n      // Compare the slab data with the target key\n      int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;\n\n      // If found target key, the insertion/replace is no longer needed.\n      // Refresh the slot, the task is completed\n      if (found_lane >= 0) {\n        size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;\n        if (lane_idx == (size_t)next_lane) {\n          slot_counter[found_offset] = global_counter->load(cuda::std::memory_order_relaxed);\n          active = false;\n        }\n\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // Compare the slab data with empty key.\n      // If found empty key, do insertion,the task is complete\n      found_lane = __ffs(warp_tile.ballot(read_key == empty_key)) - 1;\n      if (found_lane >= 0) {\n        size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;\n\n        if (lane_idx == (size_t)next_lane) {\n          keys[next_set].set_[next_slab].slab_[found_lane] = key;\n          slot_counter[found_offset] = global_counter->load(cuda::std::memory_order_relaxed);\n        }\n\n        warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,\n                                  vals + found_offset * embedding_vec_size,\n                                  d_values + next_idx * embedding_vec_size);\n\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // If no target or unused slot found in this slab,\n      // Refresh LR info, continue probing\n      ref_counter_type read_slot_counter =\n          slot_counter[(next_set * set_associativity + next_slab) * warp_size + lane_idx];\n      if (read_slot_counter 
< min_slot_counter_val) {\n        min_slot_counter_val = read_slot_counter;\n        slab_distance = counter;\n      }\n\n      counter++;\n      next_slab = (next_slab + 1) % set_associativity;\n    }\n\n    // Unlock the slabset after operating the slabset\n    warp_unlock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);\n  }\n}\n#else\n// Kernel to insert or replace the <k,v> pairs into the cache\ntemplate <typename key_type, typename slabset, typename ref_counter_type, typename set_hasher,\n          typename slab_hasher, key_type empty_key, int set_associativity, int warp_size,\n          ref_counter_type max_ref_counter_type = std::numeric_limits<ref_counter_type>::max(),\n          size_t max_slab_distance = std::numeric_limits<size_t>::max()>\n__global__ void insert_replace_kernel(const key_type* d_keys, const float* d_values,\n                                      const size_t embedding_vec_size, const size_t len,\n                                      volatile slabset* keys, volatile float* vals,\n                                      volatile ref_counter_type* slot_counter,\n                                      volatile int* set_mutex, ref_counter_type* global_counter,\n                                      const size_t capacity_in_set,\n                                      const size_t task_per_warp_tile) {\n  // Lane(thread) ID within a warp_tile\n  cg::thread_block_tile<warp_size> warp_tile =\n      cg::tiled_partition<warp_size>(cg::this_thread_block());\n  const size_t lane_idx = warp_tile.thread_rank();\n  // Warp tile global ID\n  const size_t warp_tile_global_idx =\n      (blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();\n  // The index of key for this thread\n  const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;\n  // The assigned key for this lane(thread)\n  key_type key;\n  // The dst slabset and the dst slab inside this set\n  size_t src_set;\n  size_t src_slab;\n  // Active flag: 
whether current lane(thread) has unfinished task\n  bool active = false;\n  if (lane_idx < task_per_warp_tile) {\n    if (key_idx < len) {\n      active = true;\n      key = d_keys[key_idx];\n      src_set = set_hasher::hash(key) % capacity_in_set;\n      src_slab = slab_hasher::hash(key) % set_associativity;\n    }\n  }\n\n  // Lane participate in warp_tile ballot to produce warp-level work queue\n  unsigned active_mask = warp_tile.ballot(active);\n\n  // The warp-level outer loop: finish all the tasks within the work queue\n  while (active_mask != 0) {\n    // Next task in the work quere, start from lower index lane(thread)\n    int next_lane = __ffs(active_mask) - 1;\n    // Broadcast the task, the global index and the src slabset and slab to all lane in a warp_tile\n    key_type next_key = warp_tile.shfl(key, next_lane);\n    size_t next_idx = warp_tile.shfl(key_idx, next_lane);\n    size_t next_set = warp_tile.shfl(src_set, next_lane);\n    size_t next_slab = warp_tile.shfl(src_slab, next_lane);\n    size_t first_slab = next_slab;\n\n    // Counter to record how many slab have been searched\n    size_t counter = 0;\n\n    // Variable to keep the min slot counter during the probing\n    ref_counter_type min_slot_counter_val = max_ref_counter_type;\n    // Variable to keep the slab distance for slot with min counter\n    size_t slab_distance = max_slab_distance;\n    // Variable to keep the slot distance for slot with min counter within the slab\n    size_t slot_distance;\n    // Working queue before task started\n    const unsigned old_active_mask = active_mask;\n\n    // Lock the slabset before operating the slabset\n    warp_lock_mutex<warp_size>(warp_tile, set_mutex[next_set]);\n\n    // The warp-level inner loop: finish a single task in the work queue\n    while (active_mask == old_active_mask) {\n      // When all the slabs inside a slabset have been searched\n      // and no empty slots or target slots are found. 
Replace with LRU\n      if (counter >= set_associativity) {\n        // (sub)Warp all-reduction, the reduction result store in all threads\n        warp_min_reduction<ref_counter_type, warp_size>(warp_tile, min_slot_counter_val,\n                                                        slab_distance, slot_distance);\n\n        // Calculate the position of LR slot\n        size_t target_slab = (first_slab + slab_distance) % set_associativity;\n        size_t slot_index =\n            (next_set * set_associativity + target_slab) * warp_size + slot_distance;\n\n        // Replace the LR slot\n        if (lane_idx == (size_t)next_lane) {\n          ((volatile key_type*)(keys[next_set].set_[target_slab].slab_))[slot_distance] = key;\n          slot_counter[slot_index] = atomicAdd(global_counter, 0);\n        }\n\n        warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,\n                                  (volatile float*)(vals + slot_index * embedding_vec_size),\n                                  (volatile float*)(d_values + next_idx * embedding_vec_size));\n\n        // Replace complete, mark this task completed\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // The warp_tile read out the slab\n      key_type read_key = ((volatile key_type*)(keys[next_set].set_[next_slab].slab_))[lane_idx];\n\n      // Compare the slab data with the target key\n      int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;\n\n      // If found target key, the insertion/replace is no longer needed.\n      // Refresh the slot, the task is completed\n      if (found_lane >= 0) {\n        size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;\n        if (lane_idx == (size_t)next_lane) {\n          slot_counter[found_offset] = atomicAdd(global_counter, 0);\n          active = false;\n        }\n\n        active_mask = 
warp_tile.ballot(active);\n        break;\n      }\n\n      // Compare the slab data with empty key.\n      // If found empty key, do insertion,the task is complete\n      found_lane = __ffs(warp_tile.ballot(read_key == empty_key)) - 1;\n      if (found_lane >= 0) {\n        size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;\n\n        if (lane_idx == (size_t)next_lane) {\n          ((volatile key_type*)(keys[next_set].set_[next_slab].slab_))[found_lane] = key;\n          slot_counter[found_offset] = atomicAdd(global_counter, 0);\n        }\n\n        warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,\n                                  (volatile float*)(vals + found_offset * embedding_vec_size),\n                                  (volatile float*)(d_values + next_idx * embedding_vec_size));\n\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // If no target or unused slot found in this slab,\n      // Refresh LR info, continue probing\n      ref_counter_type read_slot_counter =\n          slot_counter[(next_set * set_associativity + next_slab) * warp_size + lane_idx];\n      if (read_slot_counter < min_slot_counter_val) {\n        min_slot_counter_val = read_slot_counter;\n        slab_distance = counter;\n      }\n\n      counter++;\n      next_slab = (next_slab + 1) % set_associativity;\n    }\n\n    // Unlock the slabset after operating the slabset\n    warp_unlock_mutex<warp_size>(warp_tile, set_mutex[next_set]);\n  }\n}\n#endif\n\n#ifdef LIBCUDACXX_VERSION\n// Kernel to update the existing keys in the cache\n// Will not change the locality information\ntemplate <typename key_type, typename slabset, typename set_hasher, typename slab_hasher,\n          typename mutex, key_type empty_key, int set_associativity, int warp_size>\n__global__ void update_kernel(const key_type* d_keys, const size_t len, 
const float* d_values,\n                              const size_t embedding_vec_size, const size_t capacity_in_set,\n                              const slabset* keys, float* vals, mutex* set_mutex,\n                              const size_t task_per_warp_tile) {\n  // Lane(thread) ID within a warp_tile\n  cg::thread_block_tile<warp_size> warp_tile =\n      cg::tiled_partition<warp_size>(cg::this_thread_block());\n  const size_t lane_idx = warp_tile.thread_rank();\n  // Warp tile global ID\n  const size_t warp_tile_global_idx =\n      (blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();\n  // The index of key for this thread\n  const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;\n  // The assigned key for this lane(thread)\n  key_type key;\n  // The dst slabset and the dst slab inside this set\n  size_t src_set;\n  size_t src_slab;\n  // Active flag: whether current lane(thread) has unfinished task\n  bool active = false;\n  if (lane_idx < task_per_warp_tile) {\n    if (key_idx < len) {\n      active = true;\n      key = d_keys[key_idx];\n      src_set = set_hasher::hash(key) % capacity_in_set;\n      src_slab = slab_hasher::hash(key) % set_associativity;\n    }\n  }\n\n  // Lane participate in warp_tile ballot to produce warp-level work queue\n  unsigned active_mask = warp_tile.ballot(active);\n\n  // The warp-level outer loop: finish all the tasks within the work queue\n  while (active_mask != 0) {\n    // Next task in the work quere, start from lower index lane(thread)\n    int next_lane = __ffs(active_mask) - 1;\n    // Broadcast the task and the global index to all lane in the warp_tile\n    key_type next_key = warp_tile.shfl(key, next_lane);\n    size_t next_idx = warp_tile.shfl(key_idx, next_lane);\n    size_t next_set = warp_tile.shfl(src_set, next_lane);\n    size_t next_slab = warp_tile.shfl(src_slab, next_lane);\n\n    // Counter to record how many slab have been searched\n    size_t counter = 0;\n\n    // 
Working queue before task started\n    const unsigned old_active_mask = active_mask;\n\n    // Lock the slabset before operating the slabset\n    warp_lock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);\n\n    // The warp-level inner loop: finish a single task in the work queue\n    while (active_mask == old_active_mask) {\n      // When all the slabs inside a slabset have been searched, mark missing task, do nothing, task\n      // complete\n      if (counter >= set_associativity) {\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // The warp_tile read out the slab\n      key_type read_key = keys[next_set].set_[next_slab].slab_[lane_idx];\n\n      // Compare the slab data with the target key\n      int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;\n\n      // If found, mark hit task, update the value, the task is completed\n      if (found_lane >= 0) {\n        size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,\n                                  vals + found_offset * embedding_vec_size,\n                                  d_values + next_idx * embedding_vec_size);\n\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // Compare the slab data with empty key, if found empty key, mark missing task, do nothing,\n      // task is completed\n      if (warp_tile.ballot(read_key == empty_key) != 0) {\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // Not found in this slab, the task is not completed, goto searching next slab\n      counter++;\n      next_slab = (next_slab + 1) % 
set_associativity;\n    }\n\n    // Unlock the slabset after operating the slabset\n    warp_unlock_mutex<mutex, warp_size>(warp_tile, set_mutex[next_set]);\n  }\n}\n#else\n// Kernel to update the existing keys in the cache\n// Will not change the locality information\ntemplate <typename key_type, typename slabset, typename set_hasher, typename slab_hasher,\n          key_type empty_key, int set_associativity, int warp_size>\n__global__ void update_kernel(const key_type* d_keys, const size_t len, const float* d_values,\n                              const size_t embedding_vec_size, const size_t capacity_in_set,\n                              volatile slabset* keys, volatile float* vals, volatile int* set_mutex,\n                              const size_t task_per_warp_tile) {\n  // Lane(thread) ID within a warp_tile\n  cg::thread_block_tile<warp_size> warp_tile =\n      cg::tiled_partition<warp_size>(cg::this_thread_block());\n  const size_t lane_idx = warp_tile.thread_rank();\n  // Warp tile global ID\n  const size_t warp_tile_global_idx =\n      (blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank();\n  // The index of key for this thread\n  const size_t key_idx = (warp_tile_global_idx * task_per_warp_tile) + lane_idx;\n  // The assigned key for this lane(thread)\n  key_type key;\n  // The dst slabset and the dst slab inside this set\n  size_t src_set;\n  size_t src_slab;\n  // Active flag: whether current lane(thread) has unfinished task\n  bool active = false;\n  if (lane_idx < task_per_warp_tile) {\n    if (key_idx < len) {\n      active = true;\n      key = d_keys[key_idx];\n      src_set = set_hasher::hash(key) % capacity_in_set;\n      src_slab = slab_hasher::hash(key) % set_associativity;\n    }\n  }\n\n  // Lane participate in warp_tile ballot to produce warp-level work queue\n  unsigned active_mask = warp_tile.ballot(active);\n\n  // The warp-level outer loop: finish all the tasks within the work queue\n  while (active_mask != 0) {\n    // 
Next task in the work quere, start from lower index lane(thread)\n    int next_lane = __ffs(active_mask) - 1;\n    // Broadcast the task and the global index to all lane in the warp_tile\n    key_type next_key = warp_tile.shfl(key, next_lane);\n    size_t next_idx = warp_tile.shfl(key_idx, next_lane);\n    size_t next_set = warp_tile.shfl(src_set, next_lane);\n    size_t next_slab = warp_tile.shfl(src_slab, next_lane);\n\n    // Counter to record how many slab have been searched\n    size_t counter = 0;\n\n    // Working queue before task started\n    const unsigned old_active_mask = active_mask;\n\n    // Lock the slabset before operating the slabset\n    warp_lock_mutex<warp_size>(warp_tile, set_mutex[next_set]);\n\n    // The warp-level inner loop: finish a single task in the work queue\n    while (active_mask == old_active_mask) {\n      // When all the slabs inside a slabset have been searched, mark missing task, do nothing, task\n      // complete\n      if (counter >= set_associativity) {\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // The warp_tile read out the slab\n      key_type read_key = ((volatile key_type*)(keys[next_set].set_[next_slab].slab_))[lane_idx];\n\n      // Compare the slab data with the target key\n      int found_lane = __ffs(warp_tile.ballot(read_key == next_key)) - 1;\n\n      // If found, mark hit task, update the value, the task is completed\n      if (found_lane >= 0) {\n        size_t found_offset = (next_set * set_associativity + next_slab) * warp_size + found_lane;\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        warp_tile_copy<warp_size>(lane_idx, embedding_vec_size,\n                                  (volatile float*)(vals + found_offset * embedding_vec_size),\n                                  (volatile float*)(d_values + next_idx * embedding_vec_size));\n\n   
     active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // Compare the slab data with empty key, if found empty key, mark missing task, do nothing,\n      // task is completed\n      if (warp_tile.ballot(read_key == empty_key) != 0) {\n        if (lane_idx == (size_t)next_lane) {\n          active = false;\n        }\n\n        active_mask = warp_tile.ballot(active);\n        break;\n      }\n\n      // Not found in this slab, the task is not completed, goto searching next slab\n      counter++;\n      next_slab = (next_slab + 1) % set_associativity;\n    }\n\n    // Unlock the slabset after operating the slabset\n    warp_unlock_mutex<warp_size>(warp_tile, set_mutex[next_set]);\n  }\n}\n#endif\n\n#ifdef LIBCUDACXX_VERSION\ntemplate <typename key_type, typename slabset, typename mutex, key_type empty_key,\n          int set_associativity, int warp_size>\n__global__ void dump_kernel(key_type* d_keys, size_t* d_dump_counter, const slabset* keys,\n                            mutex* set_mutex, const size_t start_set_index,\n                            const size_t end_set_index) {\n  // Block-level counter used by all warp tiles within a block\n  __shared__ uint32_t block_acc;\n  // Initialize block-level counter\n  if (threadIdx.x == 0) {\n    block_acc = 0;\n  }\n  __syncthreads();\n  // Lane(thread) ID within a warp tile\n  cg::thread_block_tile<warp_size> warp_tile =\n      cg::tiled_partition<warp_size>(cg::this_thread_block());\n  const size_t lane_idx = warp_tile.thread_rank();\n  // Warp tile target slabset id\n  const size_t set_idx =\n      ((blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank()) + start_set_index;\n  // Keys dump from cache\n  key_type read_key[set_associativity];\n  // Lane(thread) offset for storing each key\n  uint32_t thread_key_offset[set_associativity];\n  // Warp offset for storing each key\n  uint32_t warp_key_offset;\n  // Block offset for storing each key\n  __shared__ size_t 
block_key_offset;\n\n  // Warp tile dump target slabset\n  if (set_idx < end_set_index) {\n    // Lock the slabset before operating the slabset\n    warp_lock_mutex<mutex, warp_size>(warp_tile, set_mutex[set_idx]);\n\n    // The warp tile read out the slabset\n    for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {\n      // The warp tile read out a slab\n      read_key[slab_id] = keys[set_idx].set_[slab_id].slab_[lane_idx];\n    }\n\n    // Finish dumping the slabset, unlock the slabset\n    warp_unlock_mutex<mutex, warp_size>(warp_tile, set_mutex[set_idx]);\n\n    // Each lane(thread) within the warp tile calculate the offset to store its keys\n    uint32_t warp_tile_total_keys = 0;\n    for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {\n      unsigned valid_mask = warp_tile.ballot(read_key[slab_id] != empty_key);\n      thread_key_offset[slab_id] =\n          __popc(valid_mask & ((1U << lane_idx) - 1U)) + warp_tile_total_keys;\n      warp_tile_total_keys = warp_tile_total_keys + __popc(valid_mask);\n    }\n\n    // Each warp tile request a unique place from the block-level counter\n    if (lane_idx == 0) {\n      warp_key_offset = atomicAdd(&block_acc, warp_tile_total_keys);\n    }\n    warp_key_offset = warp_tile.shfl(warp_key_offset, 0);\n  }\n\n  // Each block request a unique place in global memory output buffer\n  __syncthreads();\n  if (threadIdx.x == 0) {\n    block_key_offset = atomicAdd(d_dump_counter, (size_t)block_acc);\n  }\n  __syncthreads();\n\n  // Warp tile store the (non-empty)keys back to output buffer\n  if (set_idx < end_set_index) {\n    for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {\n      if (read_key[slab_id] != empty_key) {\n        d_keys[block_key_offset + warp_key_offset + thread_key_offset[slab_id]] = read_key[slab_id];\n      }\n    }\n  }\n}\n#else\ntemplate <typename key_type, typename slabset, key_type empty_key, int set_associativity,\n          int warp_size>\n__global__ 
void dump_kernel(key_type* d_keys, size_t* d_dump_counter, volatile slabset* keys,\n                            volatile int* set_mutex, const size_t start_set_index,\n                            const size_t end_set_index) {\n  // Block-level counter used by all warp tiles within a block\n  __shared__ uint32_t block_acc;\n  // Initialize block-level counter\n  if (threadIdx.x == 0) {\n    block_acc = 0;\n  }\n  __syncthreads();\n  // Lane(thread) ID within a warp tile\n  cg::thread_block_tile<warp_size> warp_tile =\n      cg::tiled_partition<warp_size>(cg::this_thread_block());\n  const size_t lane_idx = warp_tile.thread_rank();\n  // Warp tile target slabset id\n  const size_t set_idx =\n      ((blockIdx.x * (blockDim.x / warp_size)) + warp_tile.meta_group_rank()) + start_set_index;\n  // Keys dump from cache\n  key_type read_key[set_associativity];\n  // Lane(thread) offset for storing each key\n  uint32_t thread_key_offset[set_associativity];\n  // Warp offset for storing each key\n  uint32_t warp_key_offset;\n  // Block offset for storing each key\n  __shared__ size_t block_key_offset;\n\n  // Warp tile dump target slabset\n  if (set_idx < end_set_index) {\n    // Lock the slabset before operating the slabset\n    warp_lock_mutex<warp_size>(warp_tile, set_mutex[set_idx]);\n\n    // The warp tile read out the slabset\n    for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {\n      // The warp tile read out a slab\n      read_key[slab_id] = ((volatile key_type*)(keys[set_idx].set_[slab_id].slab_))[lane_idx];\n    }\n\n    // Finish dumping the slabset, unlock the slabset\n    warp_unlock_mutex<warp_size>(warp_tile, set_mutex[set_idx]);\n\n    // Each lane(thread) within the warp tile calculate the offset to store its keys\n    uint32_t warp_tile_total_keys = 0;\n    for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {\n      unsigned valid_mask = warp_tile.ballot(read_key[slab_id] != empty_key);\n      thread_key_offset[slab_id] 
=\n          __popc(valid_mask & ((1U << lane_idx) - 1U)) + warp_tile_total_keys;\n      warp_tile_total_keys = warp_tile_total_keys + __popc(valid_mask);\n    }\n\n    // Each warp tile request a unique place from the block-level counter\n    if (lane_idx == 0) {\n      warp_key_offset = atomicAdd(&block_acc, warp_tile_total_keys);\n    }\n    warp_key_offset = warp_tile.shfl(warp_key_offset, 0);\n  }\n\n  // Each block request a unique place in global memory output buffer\n  __syncthreads();\n  if (threadIdx.x == 0) {\n    block_key_offset = atomicAdd(d_dump_counter, (size_t)block_acc);\n  }\n  __syncthreads();\n\n  // Warp tile store the (non-empty)keys back to output buffer\n  if (set_idx < end_set_index) {\n    for (unsigned slab_id = 0; slab_id < set_associativity; slab_id++) {\n      if (read_key[slab_id] != empty_key) {\n        d_keys[block_key_offset + warp_key_offset + thread_key_offset[slab_id]] = read_key[slab_id];\n      }\n    }\n  }\n}\n#endif\n///////////////////////////////////////////////////////////////////////////////////////////////////\n\n#ifdef LIBCUDACXX_VERSION\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\ngpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n          slab_hasher>::gpu_cache(const size_t capacity_in_set, const size_t embedding_vec_size)\n    : capacity_in_set_(capacity_in_set), embedding_vec_size_(embedding_vec_size) {\n  // Check parameter\n  if (capacity_in_set_ == 0) {\n    printf(\"Error: Invalid value for capacity_in_set.\\n\");\n    return;\n  }\n  if (embedding_vec_size_ == 0) {\n    printf(\"Error: Invalid value for embedding_vec_size.\\n\");\n    return;\n  }\n  if (set_associativity <= 0) {\n    printf(\"Error: Invalid value for set_associativity.\\n\");\n    return;\n  }\n  if (warp_size != 1 && warp_size != 2 && warp_size != 4 && warp_size != 8 && 
warp_size != 16 &&\n      warp_size != 32) {\n    printf(\"Error: Invalid value for warp_size.\\n\");\n    return;\n  }\n\n  // Get the current CUDA dev\n  CUDA_CHECK(cudaGetDevice(&dev_));\n\n  // Calculate # of slot\n  num_slot_ = capacity_in_set_ * set_associativity * warp_size;\n\n  // Allocate GPU memory for cache\n  CUDA_CHECK(cudaMalloc((void**)&keys_, sizeof(slabset) * capacity_in_set_));\n  CUDA_CHECK(cudaMalloc((void**)&vals_, sizeof(float) * embedding_vec_size_ * num_slot_));\n  CUDA_CHECK(cudaMalloc((void**)&slot_counter_, sizeof(ref_counter_type) * num_slot_));\n  CUDA_CHECK(cudaMalloc((void**)&global_counter_, sizeof(atomic_ref_counter_type)));\n\n  // Allocate GPU memory for set mutex\n  CUDA_CHECK(cudaMalloc((void**)&set_mutex_, sizeof(mutex) * capacity_in_set_));\n\n  // Initialize the cache, set all entry to unused <K,V>\n  init_cache<<<((num_slot_ - 1) / BLOCK_SIZE_) + 1, BLOCK_SIZE_>>>(\n      keys_, slot_counter_, global_counter_, num_slot_, empty_key, set_mutex_, capacity_in_set_);\n\n  // Wait for initialization to finish\n  CUDA_CHECK(cudaStreamSynchronize(0));\n  CUDA_CHECK(cudaGetLastError());\n}\n#else\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\ngpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n          slab_hasher>::gpu_cache(const size_t capacity_in_set, const size_t embedding_vec_size)\n    : capacity_in_set_(capacity_in_set), embedding_vec_size_(embedding_vec_size) {\n  // Check parameter\n  if (capacity_in_set_ == 0) {\n    printf(\"Error: Invalid value for capacity_in_set.\\n\");\n    return;\n  }\n  if (embedding_vec_size_ == 0) {\n    printf(\"Error: Invalid value for embedding_vec_size.\\n\");\n    return;\n  }\n  if (set_associativity <= 0) {\n    printf(\"Error: Invalid value for set_associativity.\\n\");\n    return;\n  }\n  if (warp_size != 1 && warp_size != 2 
&& warp_size != 4 && warp_size != 8 && warp_size != 16 &&\n      warp_size != 32) {\n    printf(\"Error: Invalid value for warp_size.\\n\");\n    return;\n  }\n\n  // Get the current CUDA dev\n  CUDA_CHECK(cudaGetDevice(&dev_));\n\n  // Calculate # of slot\n  num_slot_ = capacity_in_set_ * set_associativity * warp_size;\n\n  // Allocate GPU memory for cache\n  CUDA_CHECK(cudaMalloc((void**)&keys_, sizeof(slabset) * capacity_in_set_));\n  CUDA_CHECK(cudaMalloc((void**)&vals_, sizeof(float) * embedding_vec_size_ * num_slot_));\n  CUDA_CHECK(cudaMalloc((void**)&slot_counter_, sizeof(ref_counter_type) * num_slot_));\n  CUDA_CHECK(cudaMalloc((void**)&global_counter_, sizeof(ref_counter_type)));\n\n  // Allocate GPU memory for set mutex\n  CUDA_CHECK(cudaMalloc((void**)&set_mutex_, sizeof(int) * capacity_in_set_));\n\n  // Initialize the cache, set all entry to unused <K,V>\n  init_cache<<<((num_slot_ - 1) / BLOCK_SIZE_) + 1, BLOCK_SIZE_>>>(\n      keys_, slot_counter_, global_counter_, num_slot_, empty_key, set_mutex_, capacity_in_set_);\n\n  // Wait for initialization to finish\n  CUDA_CHECK(cudaStreamSynchronize(0));\n  CUDA_CHECK(cudaGetLastError());\n}\n#endif\n\n#ifdef LIBCUDACXX_VERSION\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\ngpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n          slab_hasher>::~gpu_cache() {\n  // Device Restorer\n  nv::CudaDeviceRestorer dev_restorer;\n\n  // Check device\n  dev_restorer.check_device(dev_);\n\n  // Destruct CUDA std object\n  destruct_kernel<<<((capacity_in_set_ - 1) / BLOCK_SIZE_) + 1, BLOCK_SIZE_>>>(\n      global_counter_, set_mutex_, capacity_in_set_);\n  // Wait for destruction to finish\n  CUDA_CHECK(cudaStreamSynchronize(0));\n\n  // Free GPU memory for cache\n  CUDA_CHECK(cudaFree(keys_));\n  CUDA_CHECK(cudaFree(vals_));\n  
CUDA_CHECK(cudaFree(slot_counter_));\n  CUDA_CHECK(cudaFree(global_counter_));\n  // Free GPU memory for set mutex\n  CUDA_CHECK(cudaFree(set_mutex_));\n}\n#else\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\ngpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n          slab_hasher>::~gpu_cache() noexcept(false) {\n  // Device Restorer\n  nv::CudaDeviceRestorer dev_restorer;\n\n  // Check device\n  dev_restorer.check_device(dev_);\n\n  // Free GPU memory for cache\n  CUDA_CHECK(cudaFree(keys_));\n  CUDA_CHECK(cudaFree(vals_));\n  CUDA_CHECK(cudaFree(slot_counter_));\n  CUDA_CHECK(cudaFree(global_counter_));\n  // Free GPU memory for set mutex\n  CUDA_CHECK(cudaFree(set_mutex_));\n}\n#endif\n\n#ifdef LIBCUDACXX_VERSION\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\nvoid gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n               slab_hasher>::Query(const key_type* d_keys, const size_t len, float* d_values,\n                                   uint64_t* d_missing_index, key_type* d_missing_keys,\n                                   size_t* d_missing_len, cudaStream_t stream,\n                                   const size_t task_per_warp_tile) {\n  // Device Restorer\n  nv::CudaDeviceRestorer dev_restorer;\n  // Check device\n  dev_restorer.check_device(dev_);\n\n  // Check if it is a valid query\n  if (len == 0) {\n    // Set the d_missing_len to 0 before return\n    CUDA_CHECK(cudaMemsetAsync(d_missing_len, 0, sizeof(size_t), stream));\n    return;\n  }\n\n  // Update the global counter as user perform a new(most recent) read operation to the cache\n  // Resolve distance overflow issue as well.\n  
update_kernel_overflow_ignore<atomic_ref_counter_type>\n      <<<1, 1, 0, stream>>>(global_counter_, d_missing_len);\n\n  // Read from the cache\n  // Touch and refresh the hitting slot\n  const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;\n  const size_t grid_size = ((len - 1) / keys_per_block) + 1;\n  get_kernel<key_type, ref_counter_type, atomic_ref_counter_type, slabset, set_hasher, slab_hasher,\n             mutex, empty_key, set_associativity, warp_size><<<grid_size, BLOCK_SIZE_, 0, stream>>>(\n      d_keys, len, d_values, embedding_vec_size_, d_missing_index, d_missing_keys, d_missing_len,\n      global_counter_, slot_counter_, capacity_in_set_, keys_, vals_, set_mutex_,\n      task_per_warp_tile);\n\n  // Check for GPU error before return\n  CUDA_CHECK(cudaGetLastError());\n}\n#else\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\nvoid gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n               slab_hasher>::Query(const key_type* d_keys, const size_t len, float* d_values,\n                                   uint64_t* d_missing_index, key_type* d_missing_keys,\n                                   size_t* d_missing_len, cudaStream_t stream,\n                                   const size_t task_per_warp_tile) {\n  // Device Restorer\n  nv::CudaDeviceRestorer dev_restorer;\n  // Check device\n  dev_restorer.check_device(dev_);\n\n  // Check if it is a valid query\n  if (len == 0) {\n    // Set the d_missing_len to 0 before return\n    CUDA_CHECK(cudaMemsetAsync(d_missing_len, 0, sizeof(size_t), stream));\n    return;\n  }\n\n  // Update the global counter as user perform a new(most recent) read operation to the cache\n  // Resolve distance overflow issue as well.\n  update_kernel_overflow_ignore<ref_counter_type>\n      <<<1, 1, 0, stream>>>(global_counter_, 
d_missing_len);\n\n  // Read from the cache\n  // Touch and refresh the hitting slot\n  const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;\n  const size_t grid_size = ((len - 1) / keys_per_block) + 1;\n  get_kernel<key_type, ref_counter_type, slabset, set_hasher, slab_hasher, empty_key,\n             set_associativity, warp_size><<<grid_size, BLOCK_SIZE_, 0, stream>>>(\n      d_keys, len, d_values, embedding_vec_size_, d_missing_index, d_missing_keys, d_missing_len,\n      global_counter_, slot_counter_, capacity_in_set_, keys_, vals_, set_mutex_,\n      task_per_warp_tile);\n\n  // Check for GPU error before return\n  CUDA_CHECK(cudaGetLastError());\n}\n#endif\n\n#ifdef LIBCUDACXX_VERSION\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\nvoid gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n               slab_hasher>::Replace(const key_type* d_keys, const size_t len,\n                                     const float* d_values, cudaStream_t stream,\n                                     const size_t task_per_warp_tile) {\n  // Check if it is a valid replacement\n  if (len == 0) {\n    return;\n  }\n\n  // Device Restorer\n  nv::CudaDeviceRestorer dev_restorer;\n  // Check device\n  dev_restorer.check_device(dev_);\n\n  // Try to insert the <k,v> paris into the cache as long as there are unused slot\n  // Then replace the <k,v> pairs into the cache\n  const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;\n  const size_t grid_size = ((len - 1) / keys_per_block) + 1;\n  insert_replace_kernel<key_type, slabset, ref_counter_type, mutex, atomic_ref_counter_type,\n                        set_hasher, slab_hasher, empty_key, set_associativity, warp_size>\n      <<<grid_size, BLOCK_SIZE_, 0, stream>>>(d_keys, d_values, embedding_vec_size_, len, keys_,\n             
                                 vals_, slot_counter_, set_mutex_, global_counter_,\n                                              capacity_in_set_, task_per_warp_tile);\n\n  // Check for GPU error before return\n  CUDA_CHECK(cudaGetLastError());\n}\n#else\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\nvoid gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n               slab_hasher>::Replace(const key_type* d_keys, const size_t len,\n                                     const float* d_values, cudaStream_t stream,\n                                     const size_t task_per_warp_tile) {\n  // Check if it is a valid replacement\n  if (len == 0) {\n    return;\n  }\n\n  // Device Restorer\n  nv::CudaDeviceRestorer dev_restorer;\n  // Check device\n  dev_restorer.check_device(dev_);\n\n  // Try to insert the <k,v> paris into the cache as long as there are unused slot\n  // Then replace the <k,v> pairs into the cache\n  const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;\n  const size_t grid_size = ((len - 1) / keys_per_block) + 1;\n  insert_replace_kernel<key_type, slabset, ref_counter_type, set_hasher, slab_hasher, empty_key,\n                        set_associativity, warp_size><<<grid_size, BLOCK_SIZE_, 0, stream>>>(\n      d_keys, d_values, embedding_vec_size_, len, keys_, vals_, slot_counter_, set_mutex_,\n      global_counter_, capacity_in_set_, task_per_warp_tile);\n\n  // Check for GPU error before return\n  CUDA_CHECK(cudaGetLastError());\n}\n#endif\n\n#ifdef LIBCUDACXX_VERSION\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\nvoid gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n               
slab_hasher>::Update(const key_type* d_keys, const size_t len, const float* d_values,\n                                    cudaStream_t stream, const size_t task_per_warp_tile) {\n  // Check if it is a valid update request\n  if (len == 0) {\n    return;\n  }\n\n  // Device Restorer\n  nv::CudaDeviceRestorer dev_restorer;\n  // Check device\n  dev_restorer.check_device(dev_);\n\n  // Update the value of input keys that are existed in the cache\n  const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;\n  const size_t grid_size = ((len - 1) / keys_per_block) + 1;\n  update_kernel<key_type, slabset, set_hasher, slab_hasher, mutex, empty_key, set_associativity,\n                warp_size><<<grid_size, BLOCK_SIZE_, 0, stream>>>(\n      d_keys, len, d_values, embedding_vec_size_, capacity_in_set_, keys_, vals_, set_mutex_,\n      task_per_warp_tile);\n\n  // Check for GPU error before return\n  CUDA_CHECK(cudaGetLastError());\n}\n#else\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\nvoid gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n               slab_hasher>::Update(const key_type* d_keys, const size_t len, const float* d_values,\n                                    cudaStream_t stream, const size_t task_per_warp_tile) {\n  // Check if it is a valid update request\n  if (len == 0) {\n    return;\n  }\n\n  // Device Restorer\n  nv::CudaDeviceRestorer dev_restorer;\n  // Check device\n  dev_restorer.check_device(dev_);\n\n  // Update the value of input keys that are existed in the cache\n  const size_t keys_per_block = (BLOCK_SIZE_ / warp_size) * task_per_warp_tile;\n  const size_t grid_size = ((len - 1) / keys_per_block) + 1;\n  update_kernel<key_type, slabset, set_hasher, slab_hasher, empty_key, set_associativity, warp_size>\n      <<<grid_size, BLOCK_SIZE_, 0, stream>>>(d_keys, 
len, d_values, embedding_vec_size_,\n                                              capacity_in_set_, keys_, vals_, set_mutex_,\n                                              task_per_warp_tile);\n\n  // Check for GPU error before return\n  CUDA_CHECK(cudaGetLastError());\n}\n#endif\n\n#ifdef LIBCUDACXX_VERSION\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\nvoid gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n               slab_hasher>::Dump(key_type* d_keys, size_t* d_dump_counter,\n                                  const size_t start_set_index, const size_t end_set_index,\n                                  cudaStream_t stream) {\n  // Check if it is a valid dump request\n  if (start_set_index >= capacity_in_set_) {\n    printf(\"Error: Invalid value for start_set_index. Nothing dumped.\\n\");\n    return;\n  }\n  if (end_set_index <= start_set_index || end_set_index > capacity_in_set_) {\n    printf(\"Error: Invalid value for end_set_index. 
Nothing dumped.\\n\");\n    return;\n  }\n\n  // Device Restorer\n  nv::CudaDeviceRestorer dev_restorer;\n  // Check device\n  dev_restorer.check_device(dev_);\n\n  // Set the global counter to 0 first\n  CUDA_CHECK(cudaMemsetAsync(d_dump_counter, 0, sizeof(size_t), stream));\n\n  // Dump keys from the cache\n  const size_t grid_size =\n      (((end_set_index - start_set_index) - 1) / (BLOCK_SIZE_ / warp_size)) + 1;\n  dump_kernel<key_type, slabset, mutex, empty_key, set_associativity, warp_size>\n      <<<grid_size, BLOCK_SIZE_, 0, stream>>>(d_keys, d_dump_counter, keys_, set_mutex_,\n                                              start_set_index, end_set_index);\n\n  // Check for GPU error before return\n  CUDA_CHECK(cudaGetLastError());\n}\n#else\ntemplate <typename key_type, typename ref_counter_type, key_type empty_key, int set_associativity,\n          int warp_size, typename set_hasher, typename slab_hasher>\nvoid gpu_cache<key_type, ref_counter_type, empty_key, set_associativity, warp_size, set_hasher,\n               slab_hasher>::Dump(key_type* d_keys, size_t* d_dump_counter,\n                                  const size_t start_set_index, const size_t end_set_index,\n                                  cudaStream_t stream) {\n  // Check if it is a valid dump request\n  if (start_set_index >= capacity_in_set_) {\n    printf(\"Error: Invalid value for start_set_index. Nothing dumped.\\n\");\n    return;\n  }\n  if (end_set_index <= start_set_index || end_set_index > capacity_in_set_) {\n    printf(\"Error: Invalid value for end_set_index. 
Nothing dumped.\\n\");\n    return;\n  }\n\n  // Device Restorer\n  nv::CudaDeviceRestorer dev_restorer;\n  // Check device\n  dev_restorer.check_device(dev_);\n\n  // Set the global counter to 0 first\n  CUDA_CHECK(cudaMemsetAsync(d_dump_counter, 0, sizeof(size_t), stream));\n\n  // Dump keys from the cache\n  const size_t grid_size =\n      (((end_set_index - start_set_index) - 1) / (BLOCK_SIZE_ / warp_size)) + 1;\n  dump_kernel<key_type, slabset, empty_key, set_associativity, warp_size>\n      <<<grid_size, BLOCK_SIZE_, 0, stream>>>(d_keys, d_dump_counter, keys_, set_mutex_,\n                                              start_set_index, end_set_index);\n\n  // Check for GPU error before return\n  CUDA_CHECK(cudaGetLastError());\n}\n#endif\n\ntemplate class gpu_cache<unsigned int, uint64_t, std::numeric_limits<unsigned int>::max(),\n                         SET_ASSOCIATIVITY, SLAB_SIZE>;\ntemplate class gpu_cache<long long, uint64_t, std::numeric_limits<long long>::max(),\n                         SET_ASSOCIATIVITY, SLAB_SIZE>;\n}  // namespace gpu_cache\n"
  },
  {
    "path": "tools/README.md",
    "content": "# DGL Utility Scripts\n\nThis folder contains the utilities that do not belong to DGL core package as standalone executable\nscripts.\n\n## Graph Chunking\n\n`chunk_graph.py` provides an example of chunking an existing DGLGraph object into the on-disk\n[chunked graph format](http://13.231.216.217/guide/distributed-preprocessing.html#chunked-graph-format).\n\n<!-- TODO: change the link of documentation once it's merged to master -->\n\nAn example of chunking the OGB MAG240M dataset:\n\n```python\nimport ogb.lsc\n\ndataset = ogb.lsc.MAG240MDataset('.')\netypes = [\n    ('paper', 'cites', 'paper'),\n    ('author', 'writes', 'paper'),\n    ('author', 'affiliated_with', 'institution')]\ng = dgl.heterograph({k: tuple(dataset.edge_index(*k)) for k in etypes})\nchunk_graph(\n    g,\n    'mag240m',\n    {'paper': {\n        'feat': 'mag240m_kddcup2021/processed/paper/node_feat.npy',\n        'label': 'mag240m_kddcup2021/processed/paper/node_label.npy',\n        'year': 'mag240m_kddcup2021/processed/paper/node_year.npy'}},\n    {},\n    4,\n    'output')\n```\n\nThe output chunked graph metadata will go as follows (assuming the current directory as\n`/home/user`:\n\n```json\n{\n    \"graph_name\": \"mag240m\",\n    \"node_type\": [\n        \"author\",\n        \"institution\",\n        \"paper\"\n    ],\n    \"num_nodes_per_chunk\": [\n        [\n            30595778,\n            30595778,\n            30595778,\n            30595778\n        ],\n        [\n            6431,\n            6430,\n            6430,\n            6430\n        ],\n        [\n            30437917,\n            30437917,\n            30437916,\n            30437916\n        ]\n    ],\n    \"edge_type\": [\n        \"author:affiliated_with:institution\",\n        \"author:writes:paper\",\n        \"paper:cites:paper\"\n    ],\n    \"num_edges_per_chunk\": [\n        [\n            11148147,\n            11148147,\n            11148146,\n            11148146\n        ],\n        [\n 
           96505680,\n            96505680,\n            96505680,\n            96505680\n        ],\n        [\n            324437232,\n            324437232,\n            324437231,\n            324437231\n        ]\n    ],\n    \"edges\": {\n        \"author:affiliated_with:institution\": {\n            \"format\": {\n                \"name\": \"csv\",\n                \"delimiter\": \" \"\n            },\n            \"data\": [\n                \"/home/user/output/edge_index/author:affiliated_with:institution0.txt\",\n                \"/home/user/output/edge_index/author:affiliated_with:institution1.txt\",\n                \"/home/user/output/edge_index/author:affiliated_with:institution2.txt\",\n                \"/home/user/output/edge_index/author:affiliated_with:institution3.txt\"\n            ]\n        },\n        \"author:writes:paper\": {\n            \"format\": {\n                \"name\": \"csv\",\n                \"delimiter\": \" \"\n            },\n            \"data\": [\n                \"/home/user/output/edge_index/author:writes:paper0.txt\",\n                \"/home/user/output/edge_index/author:writes:paper1.txt\",\n                \"/home/user/output/edge_index/author:writes:paper2.txt\",\n                \"/home/user/output/edge_index/author:writes:paper3.txt\"\n            ]\n        },\n        \"paper:cites:paper\": {\n            \"format\": {\n                \"name\": \"csv\",\n                \"delimiter\": \" \"\n            },\n            \"data\": [\n                \"/home/user/output/edge_index/paper:cites:paper0.txt\",\n                \"/home/user/output/edge_index/paper:cites:paper1.txt\",\n                \"/home/user/output/edge_index/paper:cites:paper2.txt\",\n                \"/home/user/output/edge_index/paper:cites:paper3.txt\"\n            ]\n        }\n    },\n    \"node_data\": {\n        \"paper\": {\n            \"feat\": {\n                \"format\": {\n                    \"name\": \"numpy\"\n                
},\n                \"data\": [\n                    \"/home/user/output/node_data/paper/feat-0.npy\",\n                    \"/home/user/output/node_data/paper/feat-1.npy\",\n                    \"/home/user/output/node_data/paper/feat-2.npy\",\n                    \"/home/user/output/node_data/paper/feat-3.npy\"\n                ]\n            },\n            \"label\": {\n                \"format\": {\n                    \"name\": \"numpy\"\n                },\n                \"data\": [\n                    \"/home/user/output/node_data/paper/label-0.npy\",\n                    \"/home/user/output/node_data/paper/label-1.npy\",\n                    \"/home/user/output/node_data/paper/label-2.npy\",\n                    \"/home/user/output/node_data/paper/label-3.npy\"\n                ]\n            },\n            \"year\": {\n                \"format\": {\n                    \"name\": \"numpy\"\n                },\n                \"data\": [\n                    \"/home/user/output/node_data/paper/year-0.npy\",\n                    \"/home/user/output/node_data/paper/year-1.npy\",\n                    \"/home/user/output/node_data/paper/year-2.npy\",\n                    \"/home/user/output/node_data/paper/year-3.npy\"\n                ]\n            }\n        }\n    },\n    \"edge_data\": {}\n}\n```\n\n## Change edge type to canonical edge type for partition configuration json\n\nIn the upcoming DGL v1.0, we will require the partition configuration file to contain only canonical edge type. This tool is designed to help migrating existing configuration files from old style to new one.\n\n### Sample Usage\n\n```\npython tools/change_etype_to_canonical_etype.py --part_config \"{configuration file path}\"\n```\n\n### Requirement\n\nPartition algorithms produce one configuration file and multiple data folders, and each data folder corresponds to a partition. 
**This tool needs to read from the partition configuration file (specified by the command-line argument) *and* the graph structure data (stored in `graph.dgl` under the data folder) of the first partition.** They can be local files or files shared over a network. If you follow this [official tutorial](https://docs.dgl.ai/en/latest/tutorials/dist/1_node_classification.html#sphx-glr-tutorials-dist-1-node-classification-py) for distributed training, you don't need to care about this, as all files are shared by every participant through NFS.\n\n**For example, below is a typical data folder expected by this tool:**\n```\ndata_root_dir/\n|-- graph_name.json    # specified by part_config\n|-- part0/\n    ...\n    |-- graph.dgl\n...\n```\n\nFor more information about the partition algorithm, see https://docs.dgl.ai/en/latest/generated/dgl.distributed.partition.partition_graph.html.\n\n### Input arguments\n\n1. *part_config*: The path of the partition JSON file. (**Required**)\n\n### Result\n\nThis tool changes the keys of ``etypes`` and ``edge_map`` from format ``str`` to ``str:str:str``, and it overwrites the original file instead of creating a new one.\n\nE.g. **File content before running the script**\n```json\n{\n    \"edge_map\": {\n        \"r1\": [ [ 0, 6 ], [ 16, 20 ] ],\n        \"r2\": [ [ 6, 11 ], [ 20, 25 ] ],\n        \"r3\": [ [ 11, 16 ], [ 25, 30 ] ]\n    },\n    \"etypes\": {\n        \"r1\": 0,\n        \"r2\": 1,\n        \"r3\": 2\n    },\n    ...\n}\n```\n\n**After running**\n```json\n{\n    \"edge_map\": {\n        \"n1:r1:n2\": [ [ 0, 6 ], [ 16, 20 ] ],\n        \"n1:r2:n3\": [ [ 6, 11 ], [ 20, 25 ] ],\n        \"n2:r3:n3\": [ [ 11, 16 ], [ 25, 30 ] ]\n    },\n    \"etypes\": {\n        \"n1:r1:n2\": 0,\n        \"n1:r2:n3\": 1,\n        \"n2:r3:n3\": 2\n    },\n    ...\n}\n```\n"
  },
  {
    "path": "tools/change_etype_to_canonical_etype.py",
    "content": "import argparse\nimport json\nimport logging\nimport os\nimport time\n\nimport dgl\n\nimport torch\nfrom dgl._ffi.base import DGLError\nfrom dgl.data.utils import load_graphs\nfrom dgl.utils import toindex\n\nETYPES_KEY = \"etypes\"\nEDGE_MAP_KEY = \"edge_map\"\nNTYPES_KEY = \"ntypes\"\nNUM_PARTS_KEY = \"num_parts\"\nCANONICAL_ETYPE_DELIMITER = \":\"\n\n\ndef convert_conf(part_config):\n    with open(part_config, \"r+\", encoding=\"utf-8\") as f:\n        config = json.load(f)\n        logging.info(\"Checking if the provided json file need to be changed.\")\n        if is_old_version(config):\n            logging.info(\"Changing the partition configuration file.\")\n            canonical_etypes = {}\n            if len(config[NTYPES_KEY]) == 1:\n                ntype = list(config[NTYPES_KEY].keys())[0]\n                canonical_etypes = {\n                    CANONICAL_ETYPE_DELIMITER.join((ntype, etype, ntype)): eid\n                    for etype, eid in config[ETYPES_KEY].items()\n                }\n            else:\n                canonical_etypes = etype2canonical_etype(part_config, config)\n            reverse_c_etypes = {v: k for k, v in canonical_etypes.items()}\n            # Convert edge_map keys from etype -> c_etype.\n            new_edge_map = {}\n            for e_type, range in config[EDGE_MAP_KEY].items():\n                eid = config[ETYPES_KEY][e_type]\n                c_etype = reverse_c_etypes[eid]\n                new_edge_map[c_etype] = range\n            config[EDGE_MAP_KEY] = new_edge_map\n            config[ETYPES_KEY] = canonical_etypes\n            logging.info(\"Dumping the content to disk.\")\n            f.seek(0)\n            json.dump(config, f, indent=4)\n            f.truncate()\n\n\ndef etype2canonical_etype(part_config, config):\n    num_parts = config[NUM_PARTS_KEY]\n    edge_map = config[EDGE_MAP_KEY]\n    etypes = list(edge_map.keys())\n    # Get part id of each seed edge.\n    partition_ids = []\n    for 
_, bound in edge_map.items():\n        for i in range(num_parts):\n            if bound[i][1] > bound[i][0]:\n                partition_ids.append(i)\n                break\n    partition_ids = torch.tensor(partition_ids)\n\n    # Get starting index of each partition.\n    shifts = []\n    for i in range(num_parts):\n        shifts.append(edge_map[etypes[0]][i][0])\n    shifts = torch.tensor(shifts)\n\n    canonical_etypes = {}\n    part_ids = [\n        part_id for part_id in range(num_parts) if part_id in partition_ids\n    ]\n    for part_id in part_ids:\n        seed_etypes = [\n            etypes[i] for i in range(len(etypes)) if partition_ids[i] == part_id\n        ]\n        c_etype = _find_c_etypes_in_partition(\n            part_id,\n            seed_etypes,\n            config[ETYPES_KEY],\n            config[NTYPES_KEY],\n            edge_map,\n            shifts,\n            part_config,\n        )\n        canonical_etypes.update(c_etype)\n    return canonical_etypes\n\n\ndef _find_c_etypes_in_partition(\n    part_id, seed_etypes, etypes, ntypes, edge_map, shifts, config_path\n):\n    try:\n        folder = os.path.dirname(os.path.realpath(config_path))\n        local_g = load_graphs(f\"{folder}/part{part_id}/graph.dgl\")[0][0]\n        local_eids = [\n            edge_map[etype][part_id][0] - shifts[part_id]\n            for etype in seed_etypes\n        ]\n        local_eids = toindex(torch.tensor(local_eids))\n        local_eids = local_eids.tousertensor()\n        local_src, local_dst = local_g.find_edges(local_eids)\n        src_ntids, dst_ntids = (\n            local_g.ndata[dgl.NTYPE][local_src],\n            local_g.ndata[dgl.NTYPE][local_dst],\n        )\n        ntypes = {v: k for k, v in ntypes.items()}\n        src_ntypes = [ntypes[ntid.item()] for ntid in src_ntids]\n        dst_ntypes = [ntypes[ntid.item()] for ntid in dst_ntids]\n        c_etypes = list(zip(src_ntypes, seed_etypes, dst_ntypes))\n        c_etypes = [\n            
CANONICAL_ETYPE_DELIMITER.join(c_etype) for c_etype in c_etypes\n        ]\n        return {k: etypes[v] for (k, v) in zip(c_etypes, seed_etypes)}\n    except DGLError as e:\n        print(e)\n        logging.fatal(\n            f\"Graph data of partition {part_id} is requested but not found.\"\n        )\n\n\ndef is_old_version(config):\n    first_etype = list(config[ETYPES_KEY].keys())[0]\n    etype_tuple = first_etype.split(CANONICAL_ETYPE_DELIMITER)\n    return len(etype_tuple) == 1\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Change edge type in config file from format (str)\"\n        \" to (str,str,str), the original file will be overwritten\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n\n    parser.add_argument(\n        \"--part_config\", type=str, help=\"The file of the partition config\"\n    )\n    args, _ = parser.parse_known_args()\n    assert (\n        args.part_config is not None\n    ), \"A user has to specify a partition config file with --part_config.\"\n\n    start = time.time()\n    convert_conf(args.part_config)\n    end = time.time()\n    logging.info(f\"elplased time in seconds: {end - start}\")\n"
  },
  {
    "path": "tools/chunk_graph.py",
    "content": "# See the __main__ block for usage of chunk_graph().\nimport json\nimport logging\nimport os\nimport pathlib\nfrom contextlib import contextmanager\n\nimport dgl\n\nimport torch\nfrom distpartitioning import array_readwriter\nfrom files import setdir\n\n\ndef chunk_numpy_array(arr, fmt_meta, chunk_sizes, path_fmt):\n    paths = []\n    offset = 0\n\n    for j, n in enumerate(chunk_sizes):\n        path = os.path.abspath(path_fmt % j)\n        arr_chunk = arr[offset : offset + n]\n        logging.info(\"Chunking %d-%d\" % (offset, offset + n))\n        array_readwriter.get_array_parser(**fmt_meta).write(path, arr_chunk)\n        offset += n\n        paths.append(path)\n\n    return paths\n\n\ndef _chunk_graph(\n    g, name, ndata_paths, edata_paths, num_chunks, output_path, data_fmt\n):\n    # First deal with ndata and edata that are homogeneous (i.e. not a dict-of-dict)\n    if len(g.ntypes) == 1 and not isinstance(\n        next(iter(ndata_paths.values())), dict\n    ):\n        ndata_paths = {g.ntypes[0]: ndata_paths}\n    if len(g.etypes) == 1 and not isinstance(\n        next(iter(edata_paths.values())), dict\n    ):\n        edata_paths = {g.etypes[0]: ndata_paths}\n    # Then convert all edge types to canonical edge types\n    etypestrs = {etype: \":\".join(etype) for etype in g.canonical_etypes}\n    edata_paths = {\n        \":\".join(g.to_canonical_etype(k)): v for k, v in edata_paths.items()\n    }\n\n    metadata = {}\n\n    metadata[\"graph_name\"] = name\n    metadata[\"node_type\"] = g.ntypes\n\n    # Compute the number of nodes per chunk per node type\n    metadata[\"num_nodes_per_chunk\"] = num_nodes_per_chunk = []\n    for ntype in g.ntypes:\n        num_nodes = g.num_nodes(ntype)\n        num_nodes_list = []\n        for i in range(num_chunks):\n            n = num_nodes // num_chunks + (i < num_nodes % num_chunks)\n            num_nodes_list.append(n)\n        num_nodes_per_chunk.append(num_nodes_list)\n    
num_nodes_per_chunk_dict = {\n        k: v for k, v in zip(g.ntypes, num_nodes_per_chunk)\n    }\n\n    metadata[\"edge_type\"] = [etypestrs[etype] for etype in g.canonical_etypes]\n\n    # Compute the number of edges per chunk per edge type\n    metadata[\"num_edges_per_chunk\"] = num_edges_per_chunk = []\n    for etype in g.canonical_etypes:\n        num_edges = g.num_edges(etype)\n        num_edges_list = []\n        for i in range(num_chunks):\n            n = num_edges // num_chunks + (i < num_edges % num_chunks)\n            num_edges_list.append(n)\n        num_edges_per_chunk.append(num_edges_list)\n    num_edges_per_chunk_dict = {\n        k: v for k, v in zip(g.canonical_etypes, num_edges_per_chunk)\n    }\n\n    # Split edge index\n    metadata[\"edges\"] = {}\n    with setdir(\"edge_index\"):\n        for etype in g.canonical_etypes:\n            etypestr = etypestrs[etype]\n            logging.info(\"Chunking edge index for %s\" % etypestr)\n            edges_meta = {}\n            fmt_meta = {\"name\": \"csv\", \"delimiter\": \" \"}\n            edges_meta[\"format\"] = fmt_meta\n\n            srcdst = torch.stack(g.edges(etype=etype), 1)\n            edges_meta[\"data\"] = chunk_numpy_array(\n                srcdst.numpy(),\n                fmt_meta,\n                num_edges_per_chunk_dict[etype],\n                etypestr + \"%d.txt\",\n            )\n            metadata[\"edges\"][etypestr] = edges_meta\n\n    # Chunk node data\n    reader_fmt_meta, writer_fmt_meta = {\"name\": \"numpy\"}, {\"name\": data_fmt}\n    file_suffix = \"npy\" if data_fmt == \"numpy\" else \"parquet\"\n    metadata[\"node_data\"] = {}\n    with setdir(\"node_data\"):\n        for ntype, ndata_per_type in ndata_paths.items():\n            ndata_meta = {}\n            with setdir(ntype):\n                for key, path in ndata_per_type.items():\n                    logging.info(\n                        \"Chunking node data for type %s key %s\" % (ntype, key)\n           
         )\n                    ndata_key_meta = {}\n                    arr = array_readwriter.get_array_parser(\n                        **reader_fmt_meta\n                    ).read(path)\n                    ndata_key_meta[\"format\"] = writer_fmt_meta\n                    ndata_key_meta[\"data\"] = chunk_numpy_array(\n                        arr,\n                        writer_fmt_meta,\n                        num_nodes_per_chunk_dict[ntype],\n                        key + \"-%d.\" + file_suffix,\n                    )\n                    ndata_meta[key] = ndata_key_meta\n\n            metadata[\"node_data\"][ntype] = ndata_meta\n\n    # Chunk edge data\n    metadata[\"edge_data\"] = {}\n    with setdir(\"edge_data\"):\n        for etypestr, edata_per_type in edata_paths.items():\n            edata_meta = {}\n            with setdir(etypestr):\n                for key, path in edata_per_type.items():\n                    logging.info(\n                        \"Chunking edge data for type %s key %s\"\n                        % (etypestr, key)\n                    )\n                    edata_key_meta = {}\n                    arr = array_readwriter.get_array_parser(\n                        **reader_fmt_meta\n                    ).read(path)\n                    edata_key_meta[\"format\"] = writer_fmt_meta\n                    etype = tuple(etypestr.split(\":\"))\n                    edata_key_meta[\"data\"] = chunk_numpy_array(\n                        arr,\n                        writer_fmt_meta,\n                        num_edges_per_chunk_dict[etype],\n                        key + \"-%d.\" + file_suffix,\n                    )\n                    edata_meta[key] = edata_key_meta\n\n            metadata[\"edge_data\"][etypestr] = edata_meta\n\n    metadata_path = \"metadata.json\"\n    with open(metadata_path, \"w\") as f:\n        json.dump(metadata, f, sort_keys=True, indent=4)\n    logging.info(\"Saved metadata in %s\" % 
os.path.abspath(metadata_path))\n\n\ndef chunk_graph(\n    g, name, ndata_paths, edata_paths, num_chunks, output_path, data_fmt=\"numpy\"\n):\n    \"\"\"\n    Split the graph into multiple chunks.\n\n    A directory will be created at :attr:`output_path` with the metadata and chunked\n    edge list as well as the node/edge data.\n\n    Parameters\n    ----------\n    g : DGLGraph\n        The graph.\n    name : str\n        The name of the graph, to be used later in DistDGL training.\n    ndata_paths : dict[str, pathlike] or dict[ntype, dict[str, pathlike]]\n        The dictionary of paths pointing to the corresponding numpy array file for each\n        node data key.\n    edata_paths : dict[etype, pathlike] or dict[etype, dict[str, pathlike]]\n        The dictionary of paths pointing to the corresponding numpy array file for each\n        edge data key. ``etype`` could be canonical or non-canonical.\n    num_chunks : int\n        The number of chunks\n    output_path : pathlike\n        The output directory saving the chunked graph.\n    \"\"\"\n    for ntype, ndata in ndata_paths.items():\n        for key in ndata.keys():\n            ndata[key] = os.path.abspath(ndata[key])\n    for etype, edata in edata_paths.items():\n        for key in edata.keys():\n            edata[key] = os.path.abspath(edata[key])\n    with setdir(output_path):\n        _chunk_graph(\n            g, name, ndata_paths, edata_paths, num_chunks, output_path, data_fmt\n        )\n\n\nif __name__ == \"__main__\":\n    logging.basicConfig(level=\"INFO\")\n    input_dir = \"/data\"\n    output_dir = \"/chunked-data\"\n    (g,), _ = dgl.load_graphs(os.path.join(input_dir, \"graph.dgl\"))\n    chunk_graph(\n        g,\n        \"mag240m\",\n        {\n            \"paper\": {\n                \"feat\": os.path.join(input_dir, \"paper/feat.npy\"),\n                \"label\": os.path.join(input_dir, \"paper/label.npy\"),\n                \"year\": os.path.join(input_dir, \"paper/year.npy\"),\n      
      }\n        },\n        {\n            \"cites\": {\"count\": os.path.join(input_dir, \"cites/count.npy\")},\n            \"writes\": {\"year\": os.path.join(input_dir, \"writes/year.npy\")},\n            # you can put the same data file if they indeed share the features.\n            \"rev_writes\": {\"year\": os.path.join(input_dir, \"writes/year.npy\")},\n        },\n        4,\n        output_dir,\n    )\n# The generated metadata goes as in tools/sample-config/mag240m-metadata.json.\n"
  },
  {
    "path": "tools/copy_files.py",
    "content": "\"\"\"Copy the partitions to a cluster of machines.\"\"\"\nimport argparse\nimport copy\nimport json\nimport logging\nimport os\nimport signal\nimport stat\nimport subprocess\nimport sys\n\n\ndef copy_file(file_name, ip, workspace, param=\"\"):\n    print(\"copy {} to {}\".format(file_name, ip + \":\" + workspace + \"/\"))\n    cmd = \"scp \" + param + \" \" + file_name + \" \" + ip + \":\" + workspace + \"/\"\n    subprocess.check_call(cmd, shell=True)\n\n\ndef exec_cmd(ip, cmd):\n    cmd = \"ssh -o StrictHostKeyChecking=no \" + ip + \" '\" + cmd + \"'\"\n    subprocess.check_call(cmd, shell=True)\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Copy data to the servers.\")\n    parser.add_argument(\n        \"--workspace\",\n        type=str,\n        required=True,\n        help=\"Path of user directory of distributed tasks. \\\n                        This is used to specify a destination location where \\\n                        data are copied to on remote machines.\",\n    )\n    parser.add_argument(\n        \"--rel_data_path\",\n        type=str,\n        required=True,\n        help=\"Relative path in workspace to store the partition data.\",\n    )\n    parser.add_argument(\n        \"--part_config\",\n        type=str,\n        required=True,\n        help=\"The partition config file. The path is on the local machine.\",\n    )\n    parser.add_argument(\n        \"--script_folder\",\n        type=str,\n        required=True,\n        help=\"The folder contains all the user code scripts.\",\n    )\n    parser.add_argument(\n        \"--ip_config\",\n        type=str,\n        required=True,\n        help=\"The file of IP configuration for servers. 
\\\n                        The path is on the local machine.\",\n    )\n    args = parser.parse_args()\n\n    hosts = []\n    with open(args.ip_config) as f:\n        for line in f:\n            res = line.strip().split(\" \")\n            ip = res[0]\n            hosts.append(ip)\n\n    # We need to update the partition config file so that the paths are relative to\n    # the workspace in the remote machines.\n    with open(args.part_config) as conf_f:\n        part_metadata = json.load(conf_f)\n        tmp_part_metadata = copy.deepcopy(part_metadata)\n        num_parts = part_metadata[\"num_parts\"]\n        assert num_parts == len(\n            hosts\n        ), \"The number of partitions needs to be the same as the number of hosts.\"\n        graph_name = part_metadata[\"graph_name\"]\n        node_map = part_metadata[\"node_map\"]\n        edge_map = part_metadata[\"edge_map\"]\n        if not isinstance(node_map, dict):\n            assert (\n                node_map[-4:] == \".npy\"\n            ), \"node map should be stored in a NumPy array.\"\n            tmp_part_metadata[\"node_map\"] = \"{}/{}/node_map.npy\".format(\n                args.workspace, args.rel_data_path\n            )\n        if not isinstance(edge_map, dict):\n            assert (\n                edge_map[-4:] == \".npy\"\n            ), \"edge map should be stored in a NumPy array.\"\n            tmp_part_metadata[\"edge_map\"] = \"{}/{}/edge_map.npy\".format(\n                args.workspace, args.rel_data_path\n            )\n\n        for part_id in range(num_parts):\n            part_files = tmp_part_metadata[\"part-{}\".format(part_id)]\n            part_files[\"edge_feats\"] = \"{}/part{}/edge_feat.dgl\".format(\n                args.rel_data_path, part_id\n            )\n            part_files[\"node_feats\"] = \"{}/part{}/node_feat.dgl\".format(\n                args.rel_data_path, part_id\n            )\n            part_files[\"part_graph\"] = 
\"{}/part{}/graph.dgl\".format(\n                args.rel_data_path, part_id\n            )\n    tmp_part_config = \"/tmp/{}.json\".format(graph_name)\n    with open(tmp_part_config, \"w\") as outfile:\n        json.dump(tmp_part_metadata, outfile, sort_keys=True, indent=4)\n\n    # Copy ip config.\n    for part_id, ip in enumerate(hosts):\n        remote_path = \"{}/{}\".format(args.workspace, args.rel_data_path)\n        exec_cmd(ip, \"mkdir -p {}\".format(remote_path))\n\n        copy_file(args.ip_config, ip, args.workspace)\n        copy_file(\n            tmp_part_config,\n            ip,\n            \"{}/{}\".format(args.workspace, args.rel_data_path),\n        )\n        node_map = part_metadata[\"node_map\"]\n        edge_map = part_metadata[\"edge_map\"]\n        if not isinstance(node_map, dict):\n            copy_file(node_map, ip, tmp_part_metadata[\"node_map\"])\n        if not isinstance(edge_map, dict):\n            copy_file(edge_map, ip, tmp_part_metadata[\"edge_map\"])\n        remote_path = \"{}/{}/part{}\".format(\n            args.workspace, args.rel_data_path, part_id\n        )\n        exec_cmd(ip, \"mkdir -p {}\".format(remote_path))\n\n        part_files = part_metadata[\"part-{}\".format(part_id)]\n        copy_file(part_files[\"node_feats\"], ip, remote_path)\n        copy_file(part_files[\"edge_feats\"], ip, remote_path)\n        copy_file(part_files[\"part_graph\"], ip, remote_path)\n        # copy script folder\n        copy_file(args.script_folder, ip, args.workspace, \"-r\")\n\n\ndef signal_handler(signal, frame):\n    logging.info(\"Stop copying\")\n    sys.exit(0)\n\n\nif __name__ == \"__main__\":\n    fmt = \"%(asctime)s %(levelname)s %(message)s\"\n    logging.basicConfig(format=fmt, level=logging.INFO)\n    signal.signal(signal.SIGINT, signal_handler)\n    main()\n"
  },
  {
    "path": "tools/dispatch_data.py",
    "content": "\"\"\"Launching distributed graph partitioning pipeline \"\"\"\nimport argparse\nimport json\nimport logging\nimport os\nimport sys\n\nfrom partition_algo.base import load_partition_meta\n\nINSTALL_DIR = os.path.abspath(os.path.join(__file__, \"..\"))\nLAUNCH_SCRIPT = \"distgraphlaunch.py\"\nPIPELINE_SCRIPT = \"distpartitioning/data_proc_pipeline.py\"\n\nUDF_WORLD_SIZE = \"world-size\"\nUDF_PART_DIR = \"partitions-dir\"\nUDF_INPUT_DIR = \"input-dir\"\nUDF_GRAPH_NAME = \"graph-name\"\nUDF_SCHEMA = \"schema\"\nUDF_NUM_PARTS = \"num-parts\"\nUDF_OUT_DIR = \"output\"\n\nLARG_PROCS_MACHINE = \"num_proc_per_machine\"\nLARG_IPCONF = \"ip_config\"\nLARG_MASTER_PORT = \"master_port\"\nLARG_SSH_PORT = \"ssh_port\"\n\n\ndef get_launch_cmd(args) -> str:\n    cmd = sys.executable + \" \" + os.path.join(INSTALL_DIR, LAUNCH_SCRIPT)\n    cmd = f\"{cmd} --{LARG_SSH_PORT} {args.ssh_port} \"\n    cmd = f\"{cmd} --{LARG_PROCS_MACHINE} 1 \"\n    cmd = f\"{cmd} --{LARG_IPCONF} {args.ip_config} \"\n    cmd = f\"{cmd} --{LARG_MASTER_PORT} {args.master_port} \"\n\n    return cmd\n\n\ndef submit_jobs(args) -> str:\n    # read the json file and get the remaining argument here.\n    schema_path = args.metadata_filename\n    with open(os.path.join(args.in_dir, schema_path)) as schema:\n        schema_map = json.load(schema)\n\n    graph_name = schema_map[\"graph_name\"]\n\n    # retrieve num_parts\n    num_parts = 0\n    partition_path = os.path.join(args.partitions_dir, \"partition_meta.json\")\n    if os.path.isfile(partition_path):\n        part_meta = load_partition_meta(partition_path)\n        num_parts = part_meta.num_parts\n\n    assert (\n        num_parts != 0\n    ), f\"Invalid value for no. of partitions. 
Please check partition_meta.json file.\"\n\n    # verify ip_config\n    with open(args.ip_config, \"r\") as f:\n        num_ips = len(f.readlines())\n        assert (\n            num_parts % num_ips == 0\n        ), f\"The num_parts[{args.num_parts}] should be a multiple of number of lines(ip addresses)[{args.ip_config}].\"\n\n    argslist = \"\"\n    argslist += \"--world-size {} \".format(num_ips)\n    argslist += \"--partitions-dir {} \".format(\n        os.path.abspath(args.partitions_dir)\n    )\n    argslist += \"--input-dir {} \".format(os.path.abspath(args.in_dir))\n    argslist += \"--graph-name {} \".format(graph_name)\n    argslist += \"--schema {} \".format(schema_path)\n    argslist += \"--num-parts {} \".format(num_parts)\n    argslist += \"--output {} \".format(os.path.abspath(args.out_dir))\n    argslist += \"--process-group-timeout {} \".format(args.process_group_timeout)\n    argslist += \"--log-level {} \".format(args.log_level)\n    argslist += \"--save-orig-nids \" if args.save_orig_nids else \"\"\n    argslist += \"--save-orig-eids \" if args.save_orig_eids else \"\"\n    argslist += \"--use-graphbolt \" if args.use_graphbolt else \"\"\n    argslist += \"--store-eids \" if args.store_eids else \"\"\n    argslist += \"--store-inner-node \" if args.store_inner_node else \"\"\n    argslist += \"--store-inner-edge \" if args.store_inner_edge else \"\"\n    argslist += (\n        f\"--graph-formats {args.graph_formats} \" if args.graph_formats else \"\"\n    )\n\n    # (BarclayII) Is it safe to assume all the workers have the Python executable at the same path?\n    pipeline_cmd = os.path.join(INSTALL_DIR, PIPELINE_SCRIPT)\n    udf_cmd = f\"{args.python_path} {pipeline_cmd} {argslist}\"\n\n    launch_cmd = get_launch_cmd(args)\n    launch_cmd += '\"' + udf_cmd + '\"'\n\n    print(launch_cmd)\n    os.system(launch_cmd)\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Dispatch edge index and data to partitions\",\n      
  formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n\n    parser.add_argument(\n        \"--in-dir\",\n        type=str,\n        help=\"Location of the input directory where the dataset is located\",\n    )\n    parser.add_argument(\n        \"--metadata-filename\",\n        type=str,\n        default=\"metadata.json\",\n        help=\"Filename for the metadata JSON file that describes the dataset to be dispatched.\",\n    )\n    parser.add_argument(\n        \"--partitions-dir\",\n        type=str,\n        help=\"Location of the partition-id mapping files which define node-ids and their respective partition-ids, relative to the input directory\",\n    )\n    parser.add_argument(\n        \"--out-dir\",\n        type=str,\n        help=\"Location of the output directory where the graph partitions will be created by this pipeline\",\n    )\n    parser.add_argument(\n        \"--ip-config\",\n        type=str,\n        help=\"File location of IP configuration for server processes\",\n    )\n    parser.add_argument(\n        \"--master-port\",\n        type=int,\n        default=12345,\n        help=\"port used by gloo group to create randezvous point\",\n    )\n    parser.add_argument(\n        \"--log-level\",\n        required=False,\n        type=str,\n        help=\"Log level to use for execution.\",\n        default=\"INFO\",\n        choices=[\"DEBUG\", \"INFO\", \"WARNING\", \"ERROR\", \"CRITICAL\"],\n    )\n    parser.add_argument(\n        \"--python-path\",\n        type=str,\n        default=sys.executable,\n        help=\"Path to the Python executable on all workers\",\n    )\n    parser.add_argument(\"--ssh-port\", type=int, default=22, help=\"SSH Port.\")\n    parser.add_argument(\n        \"--process-group-timeout\",\n        type=int,\n        default=1800,\n        help=\"timeout[seconds] for operations executed against the process group\",\n    )\n    parser.add_argument(\n        \"--save-orig-nids\",\n        
action=\"store_true\",\n        help=\"Save original node IDs into files\",\n    )\n    parser.add_argument(\n        \"--save-orig-eids\",\n        action=\"store_true\",\n        help=\"Save original edge IDs into files\",\n    )\n    parser.add_argument(\n        \"--use-graphbolt\",\n        action=\"store_true\",\n        help=\"Use GraphBolt for distributed partition.\",\n    )\n    parser.add_argument(\n        \"--store-inner-node\",\n        action=\"store_true\",\n        default=False,\n        help=\"Store inner nodes.\",\n    )\n\n    parser.add_argument(\n        \"--store-inner-edge\",\n        action=\"store_true\",\n        default=False,\n        help=\"Store inner edges.\",\n    )\n    parser.add_argument(\n        \"--store-eids\",\n        action=\"store_true\",\n        default=False,\n        help=\"Store edge IDs.\",\n    )\n    parser.add_argument(\n        \"--graph-formats\",\n        type=str,\n        default=None,\n        help=\"Save partitions in specified formats. It could be any combination(joined with ``,``) \"\n        \"of ``coo``, ``csc`` and ``csr``. If not specified, save one format only according to \"\n        \"what format is available. If multiple formats are available, selection priority \"\n        \"from high to low is ``coo``, ``csc``, ``csr``.\",\n    )\n\n    args, _ = parser.parse_known_args()\n\n    fmt = \"%(asctime)s %(levelname)s %(message)s\"\n    logging.basicConfig(\n        format=fmt,\n        level=getattr(logging, args.log_level, None),\n    )\n\n    assert os.path.isdir(args.in_dir)\n    assert os.path.isdir(args.partitions_dir)\n    assert os.path.isfile(args.ip_config)\n    assert isinstance(args.master_port, int)\n\n    submit_jobs(args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tools/distgraphlaunch.py",
    "content": "\"\"\"Launching tool for DGL distributed training\"\"\"\nimport argparse\nimport json\nimport logging\nimport multiprocessing\nimport os\nimport re\nimport signal\nimport stat\nimport subprocess\nimport sys\nimport time\nfrom functools import partial\nfrom threading import Thread\nfrom typing import Optional\n\nDEFAULT_PORT = 30050\n\n\ndef cleanup_proc(get_all_remote_pids, conn):\n    \"\"\"This process tries to clean up the remote training tasks.\"\"\"\n    print(\"cleanupu process runs\")\n    # This process should not handle SIGINT.\n    signal.signal(signal.SIGINT, signal.SIG_IGN)\n\n    data = conn.recv()\n    # If the launch process exits normally, this process doesn't need to do anything.\n    if data == \"exit\":\n        sys.exit(0)\n    else:\n        remote_pids = get_all_remote_pids()\n        # Otherwise, we need to ssh to each machine and kill the training jobs.\n        for (ip, port), pids in remote_pids.items():\n            kill_process(ip, port, pids)\n    print(\"cleanup process exits\")\n\n\ndef kill_process(ip, port, pids):\n    \"\"\"ssh to a remote machine and kill the specified processes.\"\"\"\n    curr_pid = os.getpid()\n    killed_pids = []\n    # If we kill child processes first, the parent process may create more again. This happens\n    # to Python's process pool. After sorting, we always kill parent processes first.\n    pids.sort()\n    for pid in pids:\n        assert curr_pid != pid\n        print(\"kill process {} on {}:{}\".format(pid, ip, port), flush=True)\n        kill_cmd = (\n            \"ssh -o StrictHostKeyChecking=no -p \"\n            + str(port)\n            + \" \"\n            + ip\n            + \" 'kill {}'\".format(pid)\n        )\n        subprocess.run(kill_cmd, shell=True)\n        killed_pids.append(pid)\n    # It's possible that some of the processes are not killed. 
Let's try again.\n    for i in range(3):\n        killed_pids = get_killed_pids(ip, port, killed_pids)\n        if len(killed_pids) == 0:\n            break\n        else:\n            killed_pids.sort()\n            for pid in killed_pids:\n                print(\n                    \"kill process {} on {}:{}\".format(pid, ip, port), flush=True\n                )\n                kill_cmd = (\n                    \"ssh -o StrictHostKeyChecking=no -p \"\n                    + str(port)\n                    + \" \"\n                    + ip\n                    + \" 'kill -9 {}'\".format(pid)\n                )\n                subprocess.run(kill_cmd, shell=True)\n\n\ndef get_killed_pids(ip, port, killed_pids):\n    \"\"\"Get the process IDs that we want to kill but are still alive.\"\"\"\n    killed_pids = [str(pid) for pid in killed_pids]\n    killed_pids = \",\".join(killed_pids)\n    ps_cmd = (\n        \"ssh -o StrictHostKeyChecking=no -p \"\n        + str(port)\n        + \" \"\n        + ip\n        + \" 'ps -p {} -h'\".format(killed_pids)\n    )\n    res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)\n    pids = []\n    for p in res.stdout.decode(\"utf-8\").split(\"\\n\"):\n        l = p.split()\n        if len(l) > 0:\n            pids.append(int(l[0]))\n    return pids\n\n\ndef execute_remote(\n    cmd: str, ip: str, port: int, username: Optional[str] = \"\"\n) -> Thread:\n    \"\"\"Execute command line on remote machine via ssh.\n\n    Args:\n        cmd: User-defined command (udf) to execute on the remote host.\n        ip: The ip-address of the host to run the command on.\n        port: Port number that the host is listening on.\n        thread_list:\n        username: Optional. 
If given, this will specify a username to use when issuing commands over SSH.\n            Useful when your infra requires you to explicitly specify a username to avoid permission issues.\n\n    Returns:\n        thread: The Thread whose run() is to run the `cmd` on the remote host. Returns when the cmd completes\n            on the remote host.\n    \"\"\"\n    ip_prefix = \"\"\n    if username:\n        ip_prefix += \"{username}@\".format(username=username)\n\n    # Construct ssh command that executes `cmd` on the remote host\n    ssh_cmd = \"ssh -o StrictHostKeyChecking=no -p {port} {ip_prefix}{ip} '{cmd}'\".format(\n        port=str(port),\n        ip_prefix=ip_prefix,\n        ip=ip,\n        cmd=cmd,\n    )\n\n    # thread func to run the job\n    def run(ssh_cmd):\n        subprocess.check_call(ssh_cmd, shell=True)\n\n    thread = Thread(target=run, args=(ssh_cmd,))\n    thread.setDaemon(True)\n    thread.start()\n    return thread\n\n\ndef get_remote_pids(ip, port, cmd_regex):\n    \"\"\"Get the process IDs that run the command in the remote machine.\"\"\"\n    pids = []\n    curr_pid = os.getpid()\n    # Here we want to get the python processes. 
We may get some ssh processes, so we should filter them out.\n    ps_cmd = (\n        \"ssh -o StrictHostKeyChecking=no -p \"\n        + str(port)\n        + \" \"\n        + ip\n        + \" 'ps -aux | grep python | grep -v StrictHostKeyChecking'\"\n    )\n    res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)\n    for p in res.stdout.decode(\"utf-8\").split(\"\\n\"):\n        l = p.split()\n        if len(l) < 2:\n            continue\n        # We only get the processes that run the specified command.\n        res = re.search(cmd_regex, p)\n        if res is not None and int(l[1]) != curr_pid:\n            pids.append(l[1])\n\n    pid_str = \",\".join([str(pid) for pid in pids])\n    ps_cmd = (\n        \"ssh -o StrictHostKeyChecking=no -p \"\n        + str(port)\n        + \" \"\n        + ip\n        + \" 'pgrep -P {}'\".format(pid_str)\n    )\n    res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)\n    pids1 = res.stdout.decode(\"utf-8\").split(\"\\n\")\n    all_pids = []\n    for pid in set(pids + pids1):\n        if pid == \"\" or int(pid) == curr_pid:\n            continue\n        all_pids.append(int(pid))\n    all_pids.sort()\n    return all_pids\n\n\ndef get_all_remote_pids(hosts, ssh_port, udf_command):\n    \"\"\"Get all remote processes.\"\"\"\n    remote_pids = {}\n    for node_id, host in enumerate(hosts):\n        ip, _ = host\n        # When creating training processes in remote machines, we may insert some arguments\n        # in the commands. 
We need to use regular expressions to match the modified command.\n        cmds = udf_command.split()\n        new_udf_command = \" .*\".join(cmds)\n        pids = get_remote_pids(ip, ssh_port, new_udf_command)\n        remote_pids[(ip, ssh_port)] = pids\n    return remote_pids\n\n\ndef construct_torch_dist_launcher_cmd(\n    num_trainers: int,\n    num_nodes: int,\n    node_rank: int,\n    master_addr: str,\n    master_port: int,\n) -> str:\n    \"\"\"Constructs the torch distributed launcher command.\n    Helper function.\n\n    Args:\n        num_trainers:\n        num_nodes:\n        node_rank:\n        master_addr:\n        master_port:\n\n    Returns:\n        cmd_str.\n    \"\"\"\n    torch_cmd_template = (\n        \"-m torch.distributed.launch \"\n        \"--nproc_per_node={nproc_per_node} \"\n        \"--nnodes={nnodes} \"\n        \"--node_rank={node_rank} \"\n        \"--master_addr={master_addr} \"\n        \"--master_port={master_port}\"\n    )\n    return torch_cmd_template.format(\n        nproc_per_node=num_trainers,\n        nnodes=num_nodes,\n        node_rank=node_rank,\n        master_addr=master_addr,\n        master_port=master_port,\n    )\n\n\ndef wrap_udf_in_torch_dist_launcher(\n    udf_command: str,\n    num_trainers: int,\n    num_nodes: int,\n    node_rank: int,\n    master_addr: str,\n    master_port: int,\n) -> str:\n    \"\"\"Wraps the user-defined function (udf_command) with the torch.distributed.launch module.\n\n     Example: if udf_command is \"python3 run/some/trainer.py arg1 arg2\", then new_df_command becomes:\n         \"python3 -m torch.distributed.launch <TORCH DIST ARGS> run/some/trainer.py arg1 arg2\n\n    udf_command is assumed to consist of pre-commands (optional) followed by the python launcher script (required):\n    Examples:\n        # simple\n        python3.7 path/to/some/trainer.py arg1 arg2\n\n        # multi-commands\n        (cd some/dir && python3.7 path/to/some/trainer.py arg1 arg2)\n\n    IMPORTANT: If 
udf_command consists of multiple python commands, then this will result in undefined behavior.\n\n    Args:\n        udf_command:\n        num_trainers:\n        num_nodes:\n        node_rank:\n        master_addr:\n        master_port:\n\n    Returns:\n\n    \"\"\"\n    torch_dist_cmd = construct_torch_dist_launcher_cmd(\n        num_trainers=num_trainers,\n        num_nodes=num_nodes,\n        node_rank=node_rank,\n        master_addr=master_addr,\n        master_port=master_port,\n    )\n    # Auto-detect the python binary that kicks off the distributed trainer code.\n    # Note: This allowlist order matters, this will match with the FIRST matching entry. Thus, please add names to this\n    #       from most-specific to least-specific order eg:\n    #           (python3.7, python3.8) -> (python3)\n    # The allowed python versions are from this: https://www.dgl.ai/pages/start.html\n    python_bin_allowlist = (\n        \"python3.6\",\n        \"python3.7\",\n        \"python3.8\",\n        \"python3.9\",\n        \"python3\",\n        # for backwards compatibility, accept python2 but technically DGL is a py3 library, so this is not recommended\n        \"python2.7\",\n        \"python2\",\n    )\n    # If none of the candidate python bins match, then we go with the default `python`\n    python_bin = \"python\"\n    for candidate_python_bin in python_bin_allowlist:\n        if candidate_python_bin in udf_command:\n            python_bin = candidate_python_bin\n            break\n\n    # transforms the udf_command from:\n    #     python path/to/dist_trainer.py arg0 arg1\n    # to:\n    #     python -m torch.distributed.launch [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1\n    # Note: if there are multiple python commands in `udf_command`, this may do the Wrong Thing, eg launch each\n    #       python command within the torch distributed launcher.\n    new_udf_command = udf_command.replace(\n        python_bin, f\"{python_bin} {torch_dist_cmd}\"\n    )\n\n  
  return new_udf_command\n\n\ndef construct_dgl_server_env_vars(\n    ip_config: str,\n    num_proc_per_machine: int,\n    pythonpath: Optional[str] = \"\",\n) -> str:\n    \"\"\"Constructs the DGL server-specific env vars string that are required for DGL code to behave in the correct\n    server role.\n    Convenience function.\n\n    Args:\n        ip_config: IP config file containing IP addresses of cluster hosts.\n            Relative path to workspace.\n        num_proc_per_machine:\n        pythonpath: Optional. If given, this will pass this as PYTHONPATH.\n\n    Returns:\n        server_env_vars: The server-specific env-vars in a string format, friendly for CLI execution.\n\n    \"\"\"\n    server_env_vars_template = (\n        \"DGL_IP_CONFIG={DGL_IP_CONFIG} \"\n        \"DGL_NUM_SERVER={DGL_NUM_SERVER} \"\n        \"{suffix_optional_envvars}\"\n    )\n    suffix_optional_envvars = \"\"\n    if pythonpath:\n        suffix_optional_envvars += f\"PYTHONPATH={pythonpath} \"\n    return server_env_vars_template.format(\n        DGL_IP_CONFIG=ip_config,\n        DGL_NUM_SERVER=num_proc_per_machine,\n        suffix_optional_envvars=suffix_optional_envvars,\n    )\n\n\ndef wrap_cmd_with_local_envvars(cmd: str, env_vars: str) -> str:\n    \"\"\"Wraps a CLI command with desired env vars with the following properties:\n    (1) env vars persist for the entire `cmd`, even if it consists of multiple \"chained\" commands like:\n        cmd = \"ls && pwd && python run/something.py\"\n    (2) env vars don't pollute the environment after `cmd` completes.\n\n    Example:\n        >>> cmd = \"ls && pwd\"\n        >>> env_vars = \"VAR1=value1 VAR2=value2\"\n        >>> wrap_cmd_with_local_envvars(cmd, env_vars)\n        \"(export VAR1=value1 VAR2=value2; ls && pwd)\"\n\n    Args:\n        cmd:\n        env_vars: A string containing env vars, eg \"VAR1=val1 VAR2=val2\"\n\n    Returns:\n        cmd_with_env_vars:\n\n    \"\"\"\n    # use `export` to persist env vars for entire 
cmd block. required if udf_command is a chain of commands\n    # also: wrap in parens to not pollute env:\n    #     https://stackoverflow.com/a/45993803\n    return f\"(export {env_vars}; {cmd})\"\n\n\ndef wrap_cmd_with_extra_envvars(cmd: str, env_vars: list) -> str:\n    \"\"\"Wraps a CLI command with extra env vars\n\n    Example:\n        >>> cmd = \"ls && pwd\"\n        >>> env_vars = [\"VAR1=value1\", \"VAR2=value2\"]\n        >>> wrap_cmd_with_extra_envvars(cmd, env_vars)\n        \"(export VAR1=value1 VAR2=value2; ls && pwd)\"\n\n    Args:\n        cmd:\n        env_vars: A list of strings containing env vars, e.g., [\"VAR1=value1\", \"VAR2=value2\"]\n\n    Returns:\n        cmd_with_env_vars:\n    \"\"\"\n    env_vars = \" \".join(env_vars)\n    return wrap_cmd_with_local_envvars(cmd, env_vars)\n\n\ndef submit_jobs(args, udf_command):\n    \"\"\"Submit distributed jobs (server and client processes) via ssh\"\"\"\n    hosts = []\n    thread_list = []\n    server_count_per_machine = 0\n\n    # Get the IP addresses of the cluster.\n    # ip_config = os.path.join(args.workspace, args.ip_config)\n    ip_config = args.ip_config\n    with open(ip_config) as f:\n        for line in f:\n            result = line.strip().split()\n            if len(result) == 2:\n                ip = result[0]\n                port = int(result[1])\n                hosts.append((ip, port))\n            elif len(result) == 1:\n                ip = result[0]\n                port = DEFAULT_PORT\n                hosts.append((ip, port))\n            else:\n                raise RuntimeError(\"Format error of ip_config.\")\n            server_count_per_machine = args.num_proc_per_machine\n\n    # launch server tasks\n    server_env_vars = construct_dgl_server_env_vars(\n        ip_config=args.ip_config,\n        num_proc_per_machine=args.num_proc_per_machine,\n        pythonpath=os.environ.get(\"PYTHONPATH\", \"\"),\n    )\n    for i in range(len(hosts) * server_count_per_machine):\n    
    ip, _ = hosts[int(i / server_count_per_machine)]\n        server_env_vars_cur = f\"{server_env_vars} RANK={i} MASTER_ADDR={hosts[0][0]} MASTER_PORT={args.master_port}\"\n        cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur)\n        print(cmd)\n        thread_list.append(\n            execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username)\n        )\n\n    # Start a cleanup process dedicated for cleaning up remote training jobs.\n    conn1, conn2 = multiprocessing.Pipe()\n    func = partial(get_all_remote_pids, hosts, args.ssh_port, udf_command)\n    process = multiprocessing.Process(target=cleanup_proc, args=(func, conn1))\n    process.start()\n\n    def signal_handler(signal, frame):\n        logging.info(\"Stop launcher\")\n        # We need to tell the cleanup process to kill remote training jobs.\n        conn2.send(\"cleanup\")\n        sys.exit(0)\n\n    signal.signal(signal.SIGINT, signal_handler)\n\n    for thread in thread_list:\n        thread.join()\n    # The training processes complete. We should tell the cleanup process to exit.\n    conn2.send(\"exit\")\n    process.join()\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Launch a distributed job\")\n    parser.add_argument(\"--ssh_port\", type=int, default=22, help=\"SSH Port.\")\n    parser.add_argument(\n        \"--ssh_username\",\n        default=\"\",\n        help=\"Optional. When issuing commands (via ssh) to cluster, use the provided username in the ssh cmd. 
\"\n        \"Example: If you provide --ssh_username=bob, then the ssh command will be like: 'ssh bob@1.2.3.4 CMD' \"\n        \"instead of 'ssh 1.2.3.4 CMD'\",\n    )\n    parser.add_argument(\n        \"--num_proc_per_machine\",\n        type=int,\n        help=\"The number of server processes per machine\",\n    )\n    parser.add_argument(\n        \"--master_port\",\n        type=int,\n        help=\"This port is used to form gloo group (randevouz server)\",\n    )\n    parser.add_argument(\n        \"--ip_config\",\n        type=str,\n        help=\"The file (in workspace) of IP configuration for server processes\",\n    )\n\n    args, udf_command = parser.parse_known_args()\n    assert len(udf_command) == 1, \"Please provide user command line.\"\n    assert (\n        args.num_proc_per_machine is not None and args.num_proc_per_machine > 0\n    ), \"--num_proc_per_machine must be a positive number.\"\n    assert (\n        args.ip_config is not None\n    ), \"A user has to specify an IP configuration file with --ip_config.\"\n\n    udf_command = str(udf_command[0])\n    if \"python\" not in udf_command:\n        raise RuntimeError(\n            \"DGL launching script can only support Python executable file.\"\n        )\n\n    submit_jobs(args, udf_command)\n\n\nif __name__ == \"__main__\":\n    fmt = \"%(asctime)s %(levelname)s %(message)s\"\n    logging.basicConfig(format=fmt, level=logging.INFO)\n    main()\n"
  },
  {
    "path": "tools/distpartitioning/README.md",
    "content": "### xxx_nodes.txt format\nThis file is used to provide node information to this framework. Following is the format for each line in this file:\n```\n<node_type> <weight1> <weight2> <weight3> <weight4> <global_type_node_id> <attributes>\n```\nwhere node_type is the type id of this node, weights can be any number of columns as determined by the user, global_type_node_id are the contiguous ids starting from `0` for a particular node_type. And attributes can be any number of columns at the end of each line. \n\n### xxx___edges.txt format\nThis file is used to provide edge information to this framework. Following is the format for each line in this file:\n```\n<global_src_id> <global_dst_id> <global_type_edge_id> <edge_type> <attributes>\n```\nwhere global_src_id and global_dst_id are two end points of an edge, global_type_edge_id is the unique id assigned to each edge type and are contiguous, and starting from 0, for each edge_type. Attributes can be any number of columns at the end of each line. \n\n### Naming convention \n`global_` prefix (for any node or edge ids) indicate that these ids are read from graph input files. These ids are allocated to nodes and edges before `data shuffling`. These ids are globally unique across all partitions.\n`shuffle_global_` prefix (for any node or edge ids) indicate that these ids are assigned after the `data shuffling` is completed. These ids are globally unique across all partitions.\n`part_local_` prefix (for any node or edge ids) indicate that these ids are assigned after the `data shuffling` and are unique within a given partition.\nFor instance, if a variable is named as `global_src_id` it means that this id is read from the graph input file and is assumed to be globally unique across all partitions. 
Similarly, if a variable is named `part_local_node_id`, then it means that this node_id is assigned after the data shuffling is complete and is unique within a given partition.\n\n### High level description of the algorithm\n#### Single file format for graph input files\nHere we assume that all the nodes' related data is present in one single file and similarly all the edges are in one single file. \nIn this case, the following steps are executed to write dgl objects for each partition, as assigned by any partitioning algorithm, for example METIS. \n##### Step 1 (Data Loading):\nRank-0 process reads in all the graph files which are xxx_nodes.txt, xxx_edges.txt, node_feats.dgl, edge_feats.dgl and xxx_removed_edges.txt.\nRank-0 process determines the ownership of nodes by using the output of partitioning algorithm (here, we expect the output of partitioning step is a mapping between a node and its partition id for the entire graph). Edge ownership is determined by the `destination` node-id for that edge. Each edge belongs to the partition-id of the destination node-id of each edge. \n##### Step 2 (Data Shuffling):\nRank-0 process will send node-data, edge-data, node-features, edge-features to their respective processes by using the ownership rules described in Step-1. Non-Rank-0 processes will receive their own nodes, edges, node-features and edge-features and store them in local data-structures. Upon completion of sending information Rank-0 process will delete nodes, edges, node-features and edge-features which are not owned by rank-0. \n##### Step 3 (ID assignment and resolution): \nAt this time all the ranks will have their own local information in their respective data structures. Then each process will perform the following steps: a) Assign shuffle_global_xxx (here xxx is node_ids and edge_ids) for nodes and edges by performing prefix sum on all ranks. 
b) Assign part_local_xxx (xxx means node_ids and edge_ids) to nodes and edges so that they can be used to index into the node and edge features, and c) Retrieve shuffle_global_node_ids by using global_node_ids to determine the ownership of any given node. This step is done for the node_ids (present locally on any given rank) for which shuffle_global_node_ids were assigned by a process with a different rank.\n##### Step 4 (Serialization): \nAfter every rank has global-ids, shuffle_global-ids, part_local-ids for all the nodes and edges present locally, it proceeds with DGL object creation. Finally, the Rank-0 process will aggregate graph-level metadata and create a json file with graph-level information. \n\n### How to use this tool\nTo run this code on a single machine using multiple processes, use the following command:\n```\npython3 data_proc_pipeline.py --world-size 2 --nodes-file mag_nodes.txt --edges-file mag_edges.txt --node-feats-file node_feat.dgl --metis-partitions mag_part.2 --input-dir /home/ubuntu/data --graph-name mag --schema mag.json --num-parts 2 --num-node-weights 4 --workspace /home/ubuntu/data --node-attr-dtype float --output /home/ubuntu/data/outputs --removed-edges mag_removed_edges.txt\n```\nThe above command assumes that there are `2` partitions and that the number of node weights is `4`. All other command line arguments are self-explanatory.\n"
  },
  {
    "path": "tools/distpartitioning/array_readwriter/__init__.py",
    "content": "from . import csv, numpy_array, parquet\nfrom .registry import get_array_parser, register_array_parser\n"
  },
  {
    "path": "tools/distpartitioning/array_readwriter/csv.py",
    "content": "import logging\n\nimport pandas as pd\nimport pyarrow\nimport pyarrow.csv\n\nfrom .registry import register_array_parser\n\n\n@register_array_parser(\"csv\")\nclass CSVArrayParser(object):\n    def __init__(self, delimiter=\",\"):\n        self.delimiter = delimiter\n\n    def read(self, path):\n        logging.debug(\n            \"Reading from %s using CSV format with configuration %s\"\n            % (path, self.__dict__)\n        )\n        # do not read the first line as header\n        read_options = pyarrow.csv.ReadOptions(autogenerate_column_names=True)\n        parse_options = pyarrow.csv.ParseOptions(delimiter=self.delimiter)\n        arr = pyarrow.csv.read_csv(\n            path, read_options=read_options, parse_options=parse_options\n        )\n        logging.debug(\"Done reading from %s\" % path)\n        return arr.to_pandas().to_numpy()\n\n    def write(self, path, arr):\n        logging.debug(\n            \"Writing to %s using CSV format with configuration %s\"\n            % (path, self.__dict__)\n        )\n        write_options = pyarrow.csv.WriteOptions(\n            include_header=False, delimiter=self.delimiter\n        )\n        arr = pyarrow.Table.from_pandas(pd.DataFrame(arr))\n        pyarrow.csv.write_csv(arr, path, write_options=write_options)\n        logging.debug(\"Done writing to %s\" % path)\n"
  },
  {
    "path": "tools/distpartitioning/array_readwriter/numpy_array.py",
    "content": "import logging\n\nimport numpy as np\nfrom numpy.lib.format import open_memmap\n\nfrom .registry import register_array_parser\n\n\n@register_array_parser(\"numpy\")\nclass NumpyArrayParser(object):\n    def __init__(self):\n        pass\n\n    def read(self, path):\n        logging.debug(\"Reading from %s using numpy format\" % path)\n        arr = np.load(path, mmap_mode=\"r\")\n        logging.debug(\"Done reading from %s\" % path)\n        return arr\n\n    def write(self, path, arr):\n        logging.debug(\"Writing to %s using numpy format\" % path)\n        # np.save would load the entire memmap array up into CPU.  So we manually open\n        # an empty npy file with memmap mode and manually flush it instead.\n        new_arr = open_memmap(path, mode=\"w+\", dtype=arr.dtype, shape=arr.shape)\n        new_arr[:] = arr[:]\n        logging.debug(\"Done writing to %s\" % path)\n"
  },
  {
    "path": "tools/distpartitioning/array_readwriter/parquet.py",
    "content": "import logging\n\nimport numpy as np\nimport pandas as pd\nimport pyarrow\nimport pyarrow.parquet\n\nfrom .registry import register_array_parser\n\n\n@register_array_parser(\"parquet\")\nclass ParquetArrayParser(object):\n    def __init__(self):\n        pass\n\n    def read(self, path):\n        logging.debug(\"Reading from %s using parquet format\" % path)\n        metadata = pyarrow.parquet.read_metadata(path)\n        metadata = metadata.schema.to_arrow_schema().metadata\n\n        # As parquet data are tabularized, we assume the dim of ndarray is 2.\n        # If not, it should be explictly specified in the file as metadata.\n        if metadata:\n            shape = metadata.get(b\"shape\", None)\n        else:\n            shape = None\n        table = pyarrow.parquet.read_table(path, memory_map=True)\n\n        data_types = table.schema.types\n        # Spark ML feature processing produces single-column parquet files where each row is a vector object\n        if len(data_types) == 1 and isinstance(data_types[0], pyarrow.ListType):\n            arr = np.array(table.to_pandas().iloc[:, 0].to_list())\n            logging.debug(\n                f\"Parquet data under {path} converted from single vector per row to ndarray\"\n            )\n        else:\n            arr = table.to_pandas().to_numpy()\n        if not shape:\n            logging.debug(\n                \"Shape information not found in the metadata, read the data as \"\n                \"a 2 dim array.\"\n            )\n        logging.debug(\"Done reading from %s\" % path)\n        shape = tuple(eval(shape.decode())) if shape else arr.shape\n        return arr.reshape(shape)\n\n    def write(self, path, array, vector_rows=False):\n        logging.debug(\"Writing to %s using parquet format\" % path)\n        shape = array.shape\n        if len(shape) > 2:\n            array = array.reshape(shape[0], -1)\n        if vector_rows:\n            table = pyarrow.table(\n                
[pyarrow.array(array.tolist())], names=[\"vector\"]\n            )\n            logging.debug(\"Writing to %s using single-vector rows...\" % path)\n        else:\n            table = pyarrow.Table.from_pandas(pd.DataFrame(array))\n            table = table.replace_schema_metadata({\"shape\": str(shape)})\n\n        pyarrow.parquet.write_table(table, path)\n        logging.debug(\"Done writing to %s\" % path)\n"
  },
  {
    "path": "tools/distpartitioning/array_readwriter/registry.py",
    "content": "REGISTRY = {}\n\n\ndef register_array_parser(name):\n    def _deco(cls):\n        REGISTRY[name] = cls\n        return cls\n\n    return _deco\n\n\ndef get_array_parser(**fmt_meta):\n    cls = REGISTRY[fmt_meta.pop(\"name\")]\n    return cls(**fmt_meta)\n"
  },
  {
    "path": "tools/distpartitioning/constants.py",
    "content": "GLOBAL_NID = \"global_node_id\"\nGLOBAL_EID = \"global_edge_id\"\n\nSHUFFLE_GLOBAL_NID = \"shuffle_global_node_id\"\nSHUFFLE_GLOBAL_EID = \"shuffle_global_edge_id\"\n\nNTYPE_ID = \"node_type_id\"\nETYPE_ID = \"edge_type_id\"\n\nGLOBAL_TYPE_NID = \"global_type_node_id\"\nGLOBAL_TYPE_EID = \"global_type_edge_id\"\n\nGLOBAL_SRC_ID = \"global_src_id\"\nGLOBAL_DST_ID = \"global_dst_id\"\nSHUFFLE_GLOBAL_SRC_ID = \"shuffle_global_src_id\"\nSHUFFLE_GLOBAL_DST_ID = \"shuffle_global_dst_id\"\n\nOWNER_PROCESS = \"owner_proc_id\"\n\nPART_LOCAL_NID = \"part_local_nid\"\n\nSTR_NODE_TYPE = \"node_type\"\nSTR_EDGE_TYPE = \"edge_type\"\nSTR_EDGES = \"edges\"\nSTR_FORMAT = \"format\"\nSTR_FORMAT_DELIMITER = \"delimiter\"\nSTR_DATA = \"data\"\nSTR_NODE_DATA = \"node_data\"\nSTR_EDGE_DATA = \"edge_data\"\n\nSTR_NUMPY = \"numpy\"\nSTR_PARQUET = \"parquet\"\nSTR_CSV = \"csv\"\nSTR_NAME = \"name\"\n\nSTR_GRAPH_NAME = \"graph_name\"\nSTR_NODE_FEATURES = \"node_features\"\nSTR_EDGE_FEATURES = \"edge_features\"\n\nSTR_NUM_NODES_PER_TYPE = \"num_nodes_per_type\"\nSTR_NUM_EDGES_PER_TYPE = \"num_edges_per_type\"\n\nSTR_NTYPES = \"ntypes\"\n"
  },
  {
    "path": "tools/distpartitioning/convert_partition.py",
    "content": "import copy\nimport gc\nimport logging\nimport os\n\nimport constants\nimport dgl\nimport dgl.backend as F\nimport dgl.graphbolt as gb\nimport numpy as np\nimport torch as th\nimport torch.distributed as dist\nfrom dgl import EID, ETYPE, NID, NTYPE\n\nfrom dgl.distributed.constants import DGL2GB_EID, GB_DST_ID\nfrom dgl.distributed.partition import (\n    _cast_to_minimum_dtype,\n    _etype_str_to_tuple,\n    _etype_tuple_to_str,\n    cast_various_to_minimum_dtype_gb,\n    RESERVED_FIELD_DTYPE,\n)\nfrom utils import get_idranges, memory_snapshot\n\n\ndef _get_unique_invidx(srcids, dstids, nids, low_mem=True):\n    \"\"\"This function is used to compute a list of unique elements,\n    and their indices in the input list, which is the concatenation\n    of srcids, dstids and uniq_nids. In addition, this function will also\n    compute inverse indices, in the list of unique elements, for the\n    elements in srcids, dstids and nids arrays. srcids, dstids will be\n    over-written to contain the inverse indices. Basically, this function\n    is mimicing the functionality of numpy's unique function call.\n    The problem with numpy's unique function call is its high memory\n    requirement. For an input list of 3 billion edges it consumes about\n    550GB of systems memory, which is limiting the capability of the\n    partitioning pipeline.\n\n    Note: This function is a workaround solution for the high memory requirement\n    of numpy's unique function call. This function is not a general purpose\n    function and is only used in the context of the partitioning pipeline.\n    What's more, this function does not behave exactly the same as numpy's\n    unique function call. Namely, this function does not return the exact same\n    inverse indices as numpy's unique function call. However, for the current\n    use case, this function is sufficient.\n\n    Current numpy uniques function returns 3 return parameters, which are\n        . 
list of unique elements\n        . list of indices, in the input argument list, which are first\n            occurrence of the corresponding element in the uniques list\n        . list of inverse indices, which are indices from the uniques list\n            and can be used to rebuild the original input array\n    Compared to the above numpy's return parameters, this work around\n    solution returns 4 values\n        . list of unique elements,\n        . list of indices, which may not be the first occurrence of the\n            corresponding element from the uniques\n        . list of inverse indices, here we only build the inverse indices\n            for srcids and dstids input arguments. For the current use case,\n            only these two inverse indices are needed.\n\n    Parameters:\n    -----------\n    srcids : numpy array\n        a list of numbers, which are the src-ids of the edges\n    dstids : numpy array\n        a list of numbers, which are the dst-ids of the edges\n    nids : numpy array\n        a list of numbers, a list of unique shuffle-global-nids.\n        This list is guaranteed to be a list of sorted consecutive unique\n        list of numbers. Also, this list will be a `super set` for the\n        list of dstids. Current implementation of the pipeline guarantees\n        this assumption and is used to simplify the current implementation\n        of the workaround solution.\n    low_mem : bool, optional\n        Indicates whether to use the low memory version of the function. If\n        ``False``, the function will use numpy's native ``unique`` function.\n        Otherwise, the function will use the low memory version of the\n        function.\n\n    Returns:\n    --------\n    numpy array :\n        a list of unique, sorted elements, computed from the input arguments\n    numpy array :\n        a list of integers. 
These are indices in the concatenated list\n        [srcids, dstids, uniq_nids], which are the input arguments to this function\n    numpy array :\n        a list of integers. These are inverse indices, which will be indices\n        from the unique elements list specifying the elements from the\n        input array, srcids\n    numpy array :\n        a list of integers. These are inverse indices, which will be indices\n        from the unique elements list specifying the elements from the\n        input array, dstids\n    \"\"\"\n    assert len(srcids) == len(\n        dstids\n    ), f\"Please provide the correct input parameters\"\n    assert len(srcids) != 0, f\"Please provide a non-empty edge-list.\"\n\n    if not low_mem:\n        logging.warning(\n            \"Calling numpy's native function unique. This functions memory \"\n            \"overhead will limit size of the partitioned graph objects \"\n            \"processed by each node in the cluster.\"\n        )\n        uniques, idxes, inv_idxes = np.unique(\n            np.concatenate([srcids, dstids, nids]),\n            return_index=True,\n            return_inverse=True,\n        )\n        src_len = len(srcids)\n        dst_len = len(dstids)\n        return (\n            uniques,\n            idxes,\n            inv_idxes[:src_len],\n            inv_idxes[src_len : (src_len + dst_len)],\n        )\n\n    # find uniqes which appear only in the srcids list\n    mask = np.isin(srcids, nids, invert=True, kind=\"table\")\n    srcids_only = srcids[mask]\n    srcids_idxes = np.where(mask == 1)[0]\n\n    # sort\n    uniques, unique_srcids_idx = np.unique(srcids_only, return_index=True)\n    idxes = srcids_idxes[unique_srcids_idx]\n\n    # build uniques and idxes, first and second return parameters\n    uniques = np.concatenate([uniques, nids])\n    idxes = np.concatenate(\n        [idxes, len(srcids) + len(dstids) + np.arange(len(nids))]\n    )\n\n    # sort and idxes\n    sort_idx = np.argsort(uniques)\n   
 uniques = uniques[sort_idx]\n    idxes = idxes[sort_idx]\n\n    # uniques and idxes are built\n    assert len(uniques) == len(idxes), f\"Error building the idxes array.\"\n\n    srcids = np.searchsorted(uniques, srcids, side=\"left\")\n\n    # process dstids now.\n    # dstids is guaranteed to be a subset of the `nids` list\n    # here we are computing index in the list of uniqes for\n    # each element in the list of dstids, in a two step process\n    # 1. locate the position of first element from nids in the\n    #       list of uniques - dstids cannot appear to the left\n    #       of this number, they are guaranteed to be on the right\n    #       side of this number.\n    # 2. dstids = dstids - nids[0]\n    #       By subtracting nids[0] from the list of dstids will make\n    #       the list of dstids to be in the range of [0, max(nids)-1]\n    # 3. dstids = dstids - nids[0] + offset\n    #       Now we move the list of dstids by `offset` which will be\n    #       the starting position of the nids[0] element. 
Note that\n    #       nids will ALWAYS be a SUPERSET of dstids.\n    offset = np.searchsorted(uniques, nids[0], side=\"left\")\n    dstids = dstids - nids[0] + offset\n\n    # return the values\n    return uniques, idxes, srcids, dstids\n\n\n# Utility functions.\ndef _is_homogeneous(ntypes, etypes):\n    \"\"\"Checks if the provided ntypes and etypes form a homogeneous graph.\"\"\"\n    return len(ntypes) == 1 and len(etypes) == 1\n\n\ndef _coo2csc(src_ids, dst_ids):\n    src_ids, dst_ids = th.tensor(src_ids, dtype=th.int64), th.tensor(\n        dst_ids, dtype=th.int64\n    )\n    num_nodes = th.max(th.stack([src_ids, dst_ids], dim=0)).item() + 1\n    dst, idx = dst_ids.sort()\n    indptr = th.searchsorted(dst, th.arange(num_nodes + 1))\n    indices = src_ids[idx]\n    return indptr, indices, idx\n\n\ndef _create_edge_data(edgeid_offset, etype_ids, num_edges):\n    eid = th.arange(\n        edgeid_offset,\n        edgeid_offset + num_edges,\n        dtype=RESERVED_FIELD_DTYPE[dgl.EID],\n    )\n    etype = th.as_tensor(etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE])\n    inner_edge = th.ones(num_edges, dtype=RESERVED_FIELD_DTYPE[\"inner_edge\"])\n    return eid, etype, inner_edge\n\n\ndef _create_node_data(ntype, uniq_ids, reshuffle_nodes, inner_nodes):\n    node_type = th.as_tensor(ntype, dtype=RESERVED_FIELD_DTYPE[dgl.NTYPE])\n    node_id = th.as_tensor(uniq_ids[reshuffle_nodes])\n    inner_node = th.as_tensor(\n        inner_nodes[reshuffle_nodes],\n        dtype=RESERVED_FIELD_DTYPE[\"inner_node\"],\n    )\n    return node_type, node_id, inner_node\n\n\ndef _compute_node_ntype(\n    global_src_id, global_dst_id, global_homo_nid, idx, reshuffle_nodes, id_map\n):\n    global_ids = np.concatenate([global_src_id, global_dst_id, global_homo_nid])\n    part_global_ids = global_ids[idx]\n    part_global_ids = part_global_ids[reshuffle_nodes]\n    ntype, per_type_ids = id_map(part_global_ids)\n    return ntype, per_type_ids\n\n\ndef _graph_orig_ids(\n    
return_orig_nids,\n    return_orig_eids,\n    ntypes_map,\n    etypes_map,\n    node_attr,\n    edge_attr,\n    per_type_ids,\n    type_per_edge,\n    global_edge_id,\n):\n    orig_nids = None\n    orig_eids = None\n    if return_orig_nids:\n        orig_nids = {}\n        for ntype, ntype_id in ntypes_map.items():\n            mask = th.logical_and(\n                node_attr[dgl.NTYPE] == ntype_id,\n                node_attr[\"inner_node\"],\n            )\n            orig_nids[ntype] = th.as_tensor(per_type_ids[mask])\n    if return_orig_eids:\n        orig_eids = {}\n        for etype, etype_id in etypes_map.items():\n            mask = th.logical_and(\n                type_per_edge == etype_id,\n                edge_attr[\"inner_edge\"],\n            )\n            orig_eids[_etype_tuple_to_str(etype)] = th.as_tensor(\n                global_edge_id[mask]\n            )\n    return orig_nids, orig_eids\n\n\ndef _create_edge_attr_gb(\n    part_local_dst_id, edgeid_offset, etype_ids, ntypes, etypes, etypes_map\n):\n    edge_attr = {}\n    # create edge data in graph.\n    num_edges = len(part_local_dst_id)\n    (\n        edge_attr[dgl.EID],\n        type_per_edge,\n        edge_attr[\"inner_edge\"],\n    ) = _create_edge_data(edgeid_offset, etype_ids, num_edges)\n    assert \"inner_edge\" in edge_attr\n\n    is_homo = _is_homogeneous(ntypes, etypes)\n\n    edge_type_to_id = (\n        {gb.etype_tuple_to_str((\"_N\", \"_E\", \"_N\")): 0}\n        if is_homo\n        else {\n            gb.etype_tuple_to_str(etype): etid\n            for etype, etid in etypes_map.items()\n        }\n    )\n    return edge_attr, type_per_edge, edge_type_to_id\n\n\ndef _create_node_attr(\n    idx,\n    global_src_id,\n    global_dst_id,\n    global_homo_nid,\n    uniq_ids,\n    reshuffle_nodes,\n    id_map,\n    inner_nodes,\n):\n    # compute per_type_ids and ntype for all the nodes in the graph.\n    ntype, per_type_ids = _compute_node_ntype(\n        global_src_id,\n        
global_dst_id,\n        global_homo_nid,\n        idx,\n        reshuffle_nodes,\n        id_map,\n    )\n\n    # create node data in graph.\n    node_attr = {}\n    (\n        node_attr[dgl.NTYPE],\n        node_attr[dgl.NID],\n        node_attr[\"inner_node\"],\n    ) = _create_node_data(ntype, uniq_ids, reshuffle_nodes, inner_nodes)\n    return node_attr, per_type_ids\n\n\ndef remove_attr_gb(\n    edge_attr, node_attr, store_inner_node, store_inner_edge, store_eids\n):\n    edata, ndata = copy.deepcopy(edge_attr), copy.deepcopy(node_attr)\n    if not store_inner_edge:\n        assert \"inner_edge\" in edata\n        edata.pop(\"inner_edge\")\n\n    if not store_eids:\n        assert dgl.EID in edata\n        edata.pop(dgl.EID)\n\n    if not store_inner_node:\n        assert \"inner_node\" in ndata\n        ndata.pop(\"inner_node\")\n    return edata, ndata\n\n\ndef _process_partition_gb(\n    node_attr,\n    edge_attr,\n    type_per_edge,\n    src_ids,\n    dst_ids,\n    sort_etypes,\n):\n    \"\"\"Preprocess partitions before saving:\n    1. format data types.\n    2. 
sort csc/csr by tag.\n    \"\"\"\n    for k, dtype in RESERVED_FIELD_DTYPE.items():\n        if k in node_attr:\n            node_attr[k] = F.astype(node_attr[k], dtype)\n        if k in edge_attr:\n            edge_attr[k] = F.astype(edge_attr[k], dtype)\n\n    indptr, indices, edge_ids = _coo2csc(src_ids, dst_ids)\n    if sort_etypes:\n        split_size = th.diff(indptr)\n        split_indices = th.split(type_per_edge, tuple(split_size), dim=0)\n        sorted_idxs = []\n        for split_indice in split_indices:\n            sorted_idxs.append(split_indice.sort()[1])\n\n        sorted_idx = th.cat(sorted_idxs, dim=0)\n        sorted_idx = (\n            th.repeat_interleave(indptr[:-1], split_size, dim=0) + sorted_idx\n        )\n\n    return indptr, indices[sorted_idx], edge_ids[sorted_idx]\n\n\ndef _update_node_map(node_map_val, end_ids_per_rank, id_ntypes, prev_last_id):\n    \"\"\"this function is modified from the function '_update_node_edge_map' in dgl.distributed.partition\"\"\"\n    # Update the node_map_val to be contiguous.\n    rank = dist.get_rank()\n    prev_end_id = (\n        end_ids_per_rank[rank - 1].item() if rank > 0 else prev_last_id\n    )\n    ntype_ids = {ntype: ntype_id for ntype_id, ntype in enumerate(id_ntypes)}\n    for ntype_id in list(ntype_ids.values()):\n        ntype = id_ntypes[ntype_id]\n        start_id = node_map_val[ntype][0][0]\n        end_id = node_map_val[ntype][0][1]\n        if not (start_id == -1 and end_id == -1):\n            continue\n        prev_ntype_id = (\n            ntype_ids[ntype] - 1\n            if ntype_ids[ntype] > 0\n            else max(ntype_ids.values())\n        )\n        prev_ntype = id_ntypes[prev_ntype_id]\n        if ntype_ids[ntype] == 0:\n            node_map_val[ntype][0][0] = prev_end_id\n        else:\n            node_map_val[ntype][0][0] = node_map_val[prev_ntype][0][1]\n        node_map_val[ntype][0][1] = node_map_val[ntype][0][0]\n    return node_map_val[ntype][0][-1]\n\n\ndef 
create_graph_object(\n    tot_node_count,\n    tot_edge_count,\n    node_count,\n    edge_count,\n    num_parts,\n    schema,\n    part_id,\n    node_data,\n    edge_data,\n    edgeid_offset,\n    node_typecounts,\n    edge_typecounts,\n    last_ids={},\n    return_orig_nids=False,\n    return_orig_eids=False,\n    use_graphbolt=False,\n    **kwargs,\n):\n    \"\"\"\n    This function creates dgl objects for a given graph partition, as in function\n    arguments.\n\n    The \"schema\" argument is a dictionary, which contains the metadata related to node ids\n    and edge ids. It contains two keys: \"nid\" and \"eid\", whose value is also a dictionary\n    with the following structure.\n\n    1. The key-value pairs in the \"nid\" dictionary has the following format.\n       \"ntype-name\" is the user assigned name to this node type. \"format\" describes the\n       format of the contents of the files. and \"data\" is a list of lists, each list has\n       3 elements: file-name, start_id and end_id. File-name can be either absolute or\n       relative path to this file and starting and ending ids are type ids of the nodes\n       which are contained in this file. These type ids are later used to compute global ids\n       of these nodes which are used throughout the processing of this pipeline.\n        \"ntype-name\" : {\n            \"format\" : \"csv\",\n            \"data\" : [\n                    [ <path-to-file>/ntype0-name-0.csv, start_id0, end_id0],\n                    [ <path-to-file>/ntype0-name-1.csv, start_id1, end_id1],\n                    ...\n                    [ <path-to-file>/ntype0-name-<p-1>.csv, start_id<p-1>, end_id<p-1>],\n            ]\n        }\n\n    2. 
The key-value pairs in the \"eid\" dictionary has the following format.\n       As described for the \"nid\" dictionary the \"eid\" dictionary is similarly structured\n       except that these entries are for edges.\n        \"etype-name\" : {\n            \"format\" : \"csv\",\n            \"data\" : [\n                    [ <path-to-file>/etype0-name-0, start_id0, end_id0],\n                    [ <path-to-file>/etype0-name-1, start_id1, end_id1],\n                    ...\n                    [ <path-to-file>/etype0-name-<p-1>, start_id<p-1>, end_id<p-1>]\n            ]\n        }\n\n    In \"nid\" dictionary, the type_nids are specified that\n    should be assigned to nodes which are read from the corresponding nodes file.\n    Along the same lines dictionary for the key \"eid\" is used for edges in the\n    input graph.\n\n    These type ids, for nodes and edges, are used to compute global ids for nodes\n    and edges which are stored in the graph object.\n\n    Parameters:\n    -----------\n    tot_node_count : int\n        the number of all nodes\n    tot_edge_count : int\n        the number of all edges\n    node_count : int\n        the number of nodes in partition\n    edge_count : int\n        the number of edges in partition\n    graph_formats : str\n        the format of graph\n    num_parts : int\n        the number of parts\n    schema : json object\n        json object created by reading the graph metadata json file\n    part_id : int\n        partition id of the graph partition for which dgl object is to be created\n    node_data : numpy ndarray\n        node_data, where each row is of the following format:\n        <global_nid> <ntype_id> <global_type_nid>\n    edge_data : numpy ndarray\n        edge_data, where each row is of the following format:\n        <global_src_id> <global_dst_id> <etype_id> <global_type_eid>\n    edgeid_offset : int\n        offset to be used when assigning edge global ids in the current partition\n    return_orig_ids : bool, 
optional\n        Indicates whether to return original node/edge IDs.\n\n    Returns:\n    --------\n    dgl object\n        dgl object created for the current graph partition\n    dictionary\n        map between node types and the range of global node ids used\n    dictionary\n        map between edge types and the range of global edge ids used\n    dictionary\n        map between node type(string)  and node_type_id(int)\n    dictionary\n        map between edge type(string)  and edge_type_id(int)\n    dict of tensors\n        If `return_orig_nids=True`, return a dict of 1D tensors whose key is the node type\n        and value is a 1D tensor mapping between shuffled node IDs and the original node\n        IDs for each node type. Otherwise, ``None`` is returned.\n    dict of tensors\n        If `return_orig_eids=True`, return a dict of 1D tensors whose key is the edge type\n        and value is a 1D tensor mapping between shuffled edge IDs and the original edge\n        IDs for each edge type. 
Otherwise, ``None`` is returned.\n    \"\"\"\n    # create auxiliary data structures from the schema object\n    memory_snapshot(\"CreateDGLObj_Begin\", part_id)\n    _, global_nid_ranges = get_idranges(\n        schema[constants.STR_NODE_TYPE], node_typecounts\n    )\n    _, global_eid_ranges = get_idranges(\n        schema[constants.STR_EDGE_TYPE], edge_typecounts\n    )\n\n    id_map = dgl.distributed.id_map.IdMap(global_nid_ranges)\n\n    ntypes = [(key, global_nid_ranges[key][0, 0]) for key in global_nid_ranges]\n    ntypes.sort(key=lambda e: e[1])\n    ntype_offset_np = np.array([e[1] for e in ntypes])\n    ntypes = [e[0] for e in ntypes]\n    ntypes_map = {e: i for i, e in enumerate(ntypes)}\n    etypes = [(key, global_eid_ranges[key][0, 0]) for key in global_eid_ranges]\n    etypes.sort(key=lambda e: e[1])\n    etypes = [e[0] for e in etypes]\n    etypes_map = {_etype_str_to_tuple(e): i for i, e in enumerate(etypes)}\n\n    node_map_val = {ntype: [] for ntype in ntypes}\n    edge_map_val = {_etype_str_to_tuple(etype): [] for etype in etypes}\n\n    memory_snapshot(\"CreateDGLObj_AssignNodeData\", part_id)\n    shuffle_global_nids = node_data[constants.SHUFFLE_GLOBAL_NID]\n    node_data.pop(constants.SHUFFLE_GLOBAL_NID)\n    gc.collect()\n\n    ntype_ids = node_data[constants.NTYPE_ID]\n    node_data.pop(constants.NTYPE_ID)\n    gc.collect()\n\n    global_type_nid = node_data[constants.GLOBAL_TYPE_NID]\n    node_data.pop(constants.GLOBAL_TYPE_NID)\n    node_data = None\n    gc.collect()\n\n    global_homo_nid = ntype_offset_np[ntype_ids] + global_type_nid\n    assert np.all(shuffle_global_nids[1:] - shuffle_global_nids[:-1] == 1)\n    shuffle_global_nid_range = (shuffle_global_nids[0], shuffle_global_nids[-1])\n\n    # Determine the node ID ranges of different node types.\n    prev_last_id = last_ids.get(part_id - 1, 0)\n    for ntype_name in global_nid_ranges:\n        ntype_id = ntypes_map[ntype_name]\n        type_nids = shuffle_global_nids[ntype_ids == 
ntype_id]\n        if len(type_nids) == 0:\n            node_map_val[ntype_name].append([-1, -1])\n        else:\n            node_map_val[ntype_name].append(\n                [int(type_nids[0]), int(type_nids[-1]) + 1]\n            )\n            last_id = th.tensor(\n                [max(prev_last_id, int(type_nids[-1]) + 1)], dtype=th.int64\n            )\n    id_ntypes = list(global_nid_ranges.keys())\n\n    gather_last_ids = [\n        th.zeros(1, dtype=th.int64) for _ in range(dist.get_world_size())\n    ]\n\n    dist.all_gather(gather_last_ids, last_id)\n    prev_last_id = _update_node_map(\n        node_map_val, gather_last_ids, id_ntypes, prev_last_id\n    )\n    last_ids[part_id] = prev_last_id\n\n    # process edges\n    memory_snapshot(\"CreateDGLObj_AssignEdgeData: \", part_id)\n    shuffle_global_src_id = edge_data[constants.SHUFFLE_GLOBAL_SRC_ID]\n    edge_data.pop(constants.SHUFFLE_GLOBAL_SRC_ID)\n    gc.collect()\n\n    shuffle_global_dst_id = edge_data[constants.SHUFFLE_GLOBAL_DST_ID]\n    edge_data.pop(constants.SHUFFLE_GLOBAL_DST_ID)\n    gc.collect()\n\n    global_src_id = edge_data[constants.GLOBAL_SRC_ID]\n    edge_data.pop(constants.GLOBAL_SRC_ID)\n    gc.collect()\n\n    global_dst_id = edge_data[constants.GLOBAL_DST_ID]\n    edge_data.pop(constants.GLOBAL_DST_ID)\n    gc.collect()\n\n    global_edge_id = edge_data[constants.GLOBAL_TYPE_EID]\n    edge_data.pop(constants.GLOBAL_TYPE_EID)\n    gc.collect()\n\n    etype_ids = edge_data[constants.ETYPE_ID]\n    edge_data.pop(constants.ETYPE_ID)\n    edge_data = None\n    gc.collect()\n    logging.info(\n        f\"There are {len(shuffle_global_src_id)} edges in partition {part_id}\"\n    )\n\n    # It's not guaranteed that the edges are sorted based on edge type.\n    # Let's sort edges and all attributes on the edges.\n    if not np.all(np.diff(etype_ids) >= 0):\n        sort_idx = np.argsort(etype_ids)\n        (\n            shuffle_global_src_id,\n            shuffle_global_dst_id,\n        
    global_src_id,\n            global_dst_id,\n            global_edge_id,\n            etype_ids,\n        ) = (\n            shuffle_global_src_id[sort_idx],\n            shuffle_global_dst_id[sort_idx],\n            global_src_id[sort_idx],\n            global_dst_id[sort_idx],\n            global_edge_id[sort_idx],\n            etype_ids[sort_idx],\n        )\n        assert np.all(np.diff(etype_ids) >= 0)\n    else:\n        print(f\"[Rank: {part_id} Edge data is already sorted !!!\")\n\n    # Determine the edge ID range of different edge types.\n    edge_id_start = edgeid_offset\n    for etype_name in global_eid_ranges:\n        etype = _etype_str_to_tuple(etype_name)\n        assert len(etype) == 3\n        etype_id = etypes_map[etype]\n        edge_map_val[etype].append(\n            [edge_id_start, edge_id_start + np.sum(etype_ids == etype_id)]\n        )\n        edge_id_start += np.sum(etype_ids == etype_id)\n    memory_snapshot(\"CreateDGLObj_UniqueNodeIds: \", part_id)\n\n    # get the edge list in some order and then reshuffle.\n    # Here the order of nodes is defined by the sorted order.\n    uniq_ids, idx, part_local_src_id, part_local_dst_id = _get_unique_invidx(\n        shuffle_global_src_id,\n        shuffle_global_dst_id,\n        np.arange(shuffle_global_nid_range[0], shuffle_global_nid_range[1] + 1),\n    )\n\n    inner_nodes = th.as_tensor(\n        np.logical_and(\n            uniq_ids >= shuffle_global_nid_range[0],\n            uniq_ids <= shuffle_global_nid_range[1],\n        )\n    )\n\n    # get the list of indices, from inner_nodes, which will sort inner_nodes as [True, True, ...., False, False, ...]\n    # essentially local nodes will be placed before non-local nodes.\n    reshuffle_nodes = th.arange(len(uniq_ids))\n    reshuffle_nodes = th.cat(\n        [reshuffle_nodes[inner_nodes.bool()], reshuffle_nodes[inner_nodes == 0]]\n    )\n\n    \"\"\"\n    Following procedure is used to map the part_local_src_id, part_local_dst_id to 
account for\n    reshuffling of nodes (to order locally owned nodes prior to non-local nodes in a partition)\n    1. Form a node_map, in this case a numpy array, which will be used to map old node-ids (pre-reshuffling)\n    to post-reshuffling ids.\n    2. Once the map is created, use this map to map all the node-ids in the part_local_src_id \n    and part_local_dst_id list to their appropriate `new` node-ids (post-reshuffle order).\n    3. Since only the node's order is changed, we will have to re-order nodes related information when\n    creating dgl object: this includes dgl.NTYPE, dgl.NID and inner_node.\n    4. Edge's order is not changed. At this point in the execution path edges are still ordered by their etype-ids.\n    5. Create the dgl object appropriately and return the dgl object.\n    \n    Here is a simple example to understand the above flow better.\n\n    part_local_nids = [0, 1, 2, 3, 4, 5]\n    part_local_src_ids = [0, 0, 0, 0, 2, 3, 4]\n    part_local_dst_ids = [1, 2, 3, 4, 4, 4, 5]\n\n    Assume that nodes {1, 5} are halo-nodes, which are not owned by this partition.\n\n    reshuffle_nodes = [0, 2, 3, 4, 1, 5]\n\n    A node_map, which maps node-ids from old to reshuffled order is as follows:\n    node_map = np.zeros((len(reshuffle_nodes,)))\n    node_map[reshuffle_nodes] = np.arange(len(reshuffle_nodes))\n\n    Using the above map, we have mapped part_local_src_ids and part_local_dst_ids as follows:\n    part_local_src_ids = [0, 0, 0, 0, 1, 2, 3]\n    part_local_dst_ids = [4, 1, 2, 3, 3, 3, 5]\n\n    In this graph above, note that nodes {0, 1, 2, 3} are inner_nodes and {4, 5} are NON-inner-nodes\n\n    Since the edges are not re-ordered in any way, there is no reordering required for edge related data\n    during the DGL object creation.\n    \"\"\"\n    # create the mappings to generate mapped part_local_src_id and part_local_dst_id\n    # This map will map from unshuffled node-ids to reshuffled-node-ids (which are ordered to prioritize\n    # 
locally owned nodes).\n    nid_map = np.zeros(\n        (\n            len(\n                reshuffle_nodes,\n            )\n        )\n    )\n    nid_map[reshuffle_nodes] = np.arange(len(reshuffle_nodes))\n\n    # Now map the edge end points to reshuffled_values.\n    part_local_src_id, part_local_dst_id = (\n        nid_map[part_local_src_id],\n        nid_map[part_local_dst_id],\n    )\n\n    \"\"\"\n    Creating attributes for graphbolt and DGLGraph is as follows.\n\n    node attributes:    \n        this part is implemented in _create_node_attr.\n        compute the ntype and per type ids for each node with global node type id.\n        create ntype, nid and inner node with orig ntype and inner nodes\n    this part is shared by graphbolt and DGLGraph.\n\n    the attributes created for graphbolt are as follows:\n\n    edge attributes:\n        this part is implemented in _create_edge_attr_gb.\n        create eid, type per edge and inner edge with edgeid_offset.\n        create edge_type_to_id with etypes_map.\n    \n    The process to remove extra attribute is implemented in  remove_attr_gb.\n    the unused attributes like inner_node, inner_edge, eids will be removed following the arguments in kwargs.\n    edge_attr, node_attr are the variable that have removed extra attributes to construct csc_graph.\n    edata, ndata are the variable that reserve extra attributes to be used to generate orig_nid and orig_eid. 
\n    \n    the src_ids and dst_ids will be transformed into indptr and indices in _coo2csc.\n\n    all variable mentioned above will be casted to minimum data type in cast_various_to_minimum_dtype_gb.\n\n    orig_nids and orig_eids will be generated in _graph_orig_ids with ndata and edata.\n    \"\"\"\n    # create the graph here now.\n    ndata, per_type_ids = _create_node_attr(\n        idx,\n        global_src_id,\n        global_dst_id,\n        global_homo_nid,\n        uniq_ids,\n        reshuffle_nodes,\n        id_map,\n        inner_nodes,\n    )\n    if use_graphbolt:\n        edata, type_per_edge, edge_type_to_id = _create_edge_attr_gb(\n            part_local_dst_id,\n            edgeid_offset,\n            etype_ids,\n            ntypes,\n            etypes,\n            etypes_map,\n        )\n\n        assert edata is not None\n        assert ndata is not None\n\n        sort_etypes = len(etypes_map) > 1\n        indptr, indices, csc_edge_ids = _process_partition_gb(\n            ndata,\n            edata,\n            type_per_edge,\n            part_local_src_id,\n            part_local_dst_id,\n            sort_etypes,\n        )\n        edge_attr, node_attr = remove_attr_gb(\n            edge_attr=edata, node_attr=ndata, **kwargs\n        )\n        edge_attr = {\n            attr: edge_attr[attr][csc_edge_ids] for attr in edge_attr.keys()\n        }\n        cast_various_to_minimum_dtype_gb(\n            node_count=node_count,\n            edge_count=edge_count,\n            tot_node_count=tot_node_count,\n            tot_edge_count=tot_edge_count,\n            num_parts=num_parts,\n            indptr=indptr,\n            indices=indices,\n            type_per_edge=type_per_edge,\n            etypes=etypes,\n            ntypes=ntypes,\n            node_attributes=node_attr,\n            edge_attributes=edge_attr,\n        )\n        part_graph = gb.fused_csc_sampling_graph(\n            csc_indptr=indptr,\n            indices=indices,\n        
    node_type_offset=None,\n            type_per_edge=type_per_edge[csc_edge_ids],\n            node_attributes=node_attr,\n            edge_attributes=edge_attr,\n            node_type_to_id=ntypes_map,\n            edge_type_to_id=edge_type_to_id,\n        )\n    else:\n        num_edges = len(part_local_dst_id)\n        part_graph = dgl.graph(\n            data=(part_local_src_id, part_local_dst_id), num_nodes=len(uniq_ids)\n        )\n        # create edge data in graph.\n        (\n            part_graph.edata[dgl.EID],\n            part_graph.edata[dgl.ETYPE],\n            part_graph.edata[\"inner_edge\"],\n        ) = _create_edge_data(edgeid_offset, etype_ids, num_edges)\n\n        ndata, per_type_ids = _create_node_attr(\n            idx,\n            global_src_id,\n            global_dst_id,\n            global_homo_nid,\n            uniq_ids,\n            reshuffle_nodes,\n            id_map,\n            inner_nodes,\n        )\n        for attr_name, node_attributes in ndata.items():\n            part_graph.ndata[attr_name] = node_attributes\n        type_per_edge = part_graph.edata[dgl.ETYPE]\n        ndata, edata = part_graph.ndata, part_graph.edata\n    # get the original node ids and edge ids from original graph.\n    orig_nids, orig_eids = _graph_orig_ids(\n        return_orig_nids,\n        return_orig_eids,\n        ntypes_map,\n        etypes_map,\n        ndata,\n        edata,\n        per_type_ids,\n        type_per_edge,\n        global_edge_id,\n    )\n    return (\n        part_graph,\n        node_map_val,\n        edge_map_val,\n        ntypes_map,\n        etypes_map,\n        orig_nids,\n        orig_eids,\n    )\n\n\ndef create_metadata_json(\n    graph_name,\n    num_nodes,\n    num_edges,\n    part_id,\n    num_parts,\n    node_map_val,\n    edge_map_val,\n    ntypes_map,\n    etypes_map,\n    output_dir,\n    use_graphbolt,\n):\n    \"\"\"\n    Auxiliary function to create json file for the graph partition metadata\n\n    
Parameters:\n    -----------\n    graph_name : string\n        name of the graph\n    num_nodes : int\n        no. of nodes in the graph partition\n    num_edges : int\n        no. of edges in the graph partition\n    part_id : int\n       integer indicating the partition id\n    num_parts : int\n        total no. of partitions of the original graph\n    node_map_val : dictionary\n        map between node types and the range of global node ids used\n    edge_map_val : dictionary\n        map between edge types and the range of global edge ids used\n    ntypes_map : dictionary\n        map between node type(string)  and node_type_id(int)\n    etypes_map : dictionary\n        map between edge type(string)  and edge_type_id(int)\n    output_dir : string\n        directory where the output files are to be stored\n    use_graphbolt : bool\n        whether to use graphbolt or not\n\n    Returns:\n    --------\n    dictionary\n        map describing the graph information\n\n    \"\"\"\n    part_metadata = {\n        \"graph_name\": graph_name,\n        \"num_nodes\": num_nodes,\n        \"num_edges\": num_edges,\n        \"part_method\": \"metis\",\n        \"num_parts\": num_parts,\n        \"halo_hops\": 1,\n        \"node_map\": node_map_val,\n        \"edge_map\": edge_map_val,\n        \"ntypes\": ntypes_map,\n        \"etypes\": etypes_map,\n    }\n\n    part_dir = \"part\" + str(part_id)\n    node_feat_file = os.path.join(part_dir, \"node_feat.dgl\")\n    edge_feat_file = os.path.join(part_dir, \"edge_feat.dgl\")\n    if use_graphbolt:\n        part_graph_file = os.path.join(part_dir, \"fused_csc_sampling_graph.pt\")\n    else:\n        part_graph_file = os.path.join(part_dir, \"graph.dgl\")\n    part_graph_type = \"part_graph_graphbolt\" if use_graphbolt else \"part_graph\"\n    part_metadata[\"part-{}\".format(part_id)] = {\n        \"node_feats\": node_feat_file,\n        \"edge_feats\": edge_feat_file,\n        part_graph_type: part_graph_file,\n    }\n    
return part_metadata\n"
  },
  {
    "path": "tools/distpartitioning/data_proc_pipeline.py",
    "content": "import argparse\nimport logging\nimport os\nimport platform\n\nimport numpy as np\nimport torch.multiprocessing as mp\n\nfrom data_shuffle import multi_machine_run, single_machine_run\n\n\ndef log_params(params):\n    \"\"\"Print all the command line arguments for debugging purposes.\n\n    Parameters:\n    -----------\n    params: argparse object\n        Argument Parser structure listing all the pre-defined parameters\n    \"\"\"\n    print(\"Input Dir: \", params.input_dir)\n    print(\"Graph Name: \", params.graph_name)\n    print(\"Schema File: \", params.schema)\n    print(\"No. partitions: \", params.num_parts)\n    print(\"Output Dir: \", params.output)\n    print(\"WorldSize: \", params.world_size)\n    print(\"Metis partitions: \", params.partitions_dir)\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    Start of execution from this point.\n    Invoke the appropriate function to begin execution\n    \"\"\"\n    # arguments which are already needed by the existing implementation of convert_partition.py\n    parser = argparse.ArgumentParser(description=\"Construct graph partitions\")\n    parser.add_argument(\n        \"--input-dir\",\n        required=True,\n        type=str,\n        help=\"The directory path that contains the partition results.\",\n    )\n    parser.add_argument(\n        \"--graph-name\", required=True, type=str, help=\"The graph name\"\n    )\n    parser.add_argument(\n        \"--schema\", required=True, type=str, help=\"The schema of the graph\"\n    )\n    parser.add_argument(\n        \"--num-parts\", required=True, type=int, help=\"The number of partitions\"\n    )\n    parser.add_argument(\n        \"--output\",\n        required=True,\n        type=str,\n        help=\"The output directory of the partitioned results\",\n    )\n    parser.add_argument(\n        \"--partitions-dir\",\n        help=\"directory of the partition-ids for each node type\",\n        default=None,\n        type=str,\n    )\n    
parser.add_argument(\n        \"--log-level\",\n        type=str,\n        default=\"info\",\n        help=\"To enable log level for debugging purposes. Available options: \\\n\t\t\t  (Critical, Error, Warning, Info, Debug, Notset), default value \\\n\t\t\t  is: Info\",\n    )\n\n    # arguments needed for the distributed implementation\n    parser.add_argument(\n        \"--world-size\",\n        help=\"no. of processes to spawn\",\n        default=1,\n        type=int,\n        required=True,\n    )\n    parser.add_argument(\n        \"--process-group-timeout\",\n        required=True,\n        type=int,\n        help=\"timeout[seconds] for operations executed against the process group \"\n        \"(see torch.distributed.init_process_group)\",\n    )\n    parser.add_argument(\n        \"--save-orig-nids\",\n        action=\"store_true\",\n        help=\"Save original node IDs into files\",\n    )\n    parser.add_argument(\n        \"--save-orig-eids\",\n        action=\"store_true\",\n        help=\"Save original edge IDs into files\",\n    )\n    parser.add_argument(\n        \"--use-graphbolt\",\n        action=\"store_true\",\n        help=\"Use GraphBolt for distributed partition.\",\n    )\n    parser.add_argument(\n        \"--store-inner-node\",\n        action=\"store_true\",\n        default=False,\n        help=\"Store inner nodes.\",\n    )\n\n    parser.add_argument(\n        \"--store-inner-edge\",\n        action=\"store_true\",\n        default=False,\n        help=\"Store inner edges.\",\n    )\n    parser.add_argument(\n        \"--store-eids\",\n        action=\"store_true\",\n        default=False,\n        help=\"Store edge IDs.\",\n    )\n    parser.add_argument(\n        \"--graph-formats\",\n        default=None,\n        type=str,\n        help=\"Save partitions in specified formats.\",\n    )\n    params = parser.parse_args()\n    # invoke the pipeline function\n    numeric_level = getattr(logging, params.log_level.upper(), None)\n    
logging.basicConfig(\n        level=numeric_level,\n        format=f\"[{platform.node()} %(levelname)s %(asctime)s PID:%(process)d] %(message)s\",\n    )\n    multi_machine_run(params)\n"
  },
  {
    "path": "tools/distpartitioning/data_shuffle.py",
    "content": "import gc\nimport logging\nimport math\nimport os\nimport sys\nfrom datetime import timedelta\nfrom timeit import default_timer as timer\n\nimport constants\n\nimport dgl\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nfrom convert_partition import create_graph_object, create_metadata_json\nfrom dataset_utils import get_dataset\nfrom dist_lookup import DistLookupService\nfrom globalids import (\n    assign_shuffle_global_nids_edges,\n    assign_shuffle_global_nids_nodes,\n    lookup_shuffle_global_nids_edges,\n)\nfrom gloo_wrapper import allgather_sizes, alltoallv_cpu, gather_metadata_json\nfrom utils import (\n    augment_edge_data,\n    DATA_TYPE_ID,\n    get_edge_types,\n    get_etype_featnames,\n    get_gid_offsets,\n    get_gnid_range_map,\n    get_idranges,\n    get_node_types,\n    get_ntype_counts_map,\n    get_ntype_featnames,\n    map_partid_rank,\n    memory_snapshot,\n    read_json,\n    read_ntype_partition_files,\n    REV_DATA_TYPE_ID,\n    write_dgl_objects,\n    write_metadata_json,\n)\n\n\ndef gen_node_data(\n    rank, world_size, num_parts, id_lookup, ntid_ntype_map, schema_map\n):\n    \"\"\"\n    For this data processing pipeline, reading node files is not needed. All the needed information about\n    the nodes can be found in the metadata json file. This function generates the nodes owned by a given\n    process, using metis partitions.\n\n    Parameters:\n    -----------\n    rank : int\n        rank of the process\n    world_size : int\n        total no. of processes\n    num_parts : int\n        total no. 
of partitions\n    id_lookup : instance of class DistLookupService\n       Distributed lookup service used to map global-nids to respective partition-ids and\n       shuffle-global-nids\n    ntid_ntype_map :\n        a dictionary where keys are node_type ids(integers) and values are node_type names(strings).\n    schema_map:\n        dictionary formed by reading the input metadata json file for the input dataset.\n\n        Please note that, it is assumed that for the input graph files, the nodes of a particular node-type are\n        split into `p` files (because of `p` partitions to be generated). On a similar note, edges of a particular\n        edge-type are split into `p` files as well.\n\n        #assuming m nodetypes present in the input graph\n        \"num_nodes_per_chunk\" : [\n            [a0, a1, a2, ... a<p-1>],\n            [b0, b1, b2, ... b<p-1>],\n            ...\n            [m0, m1, m2, ... m<p-1>]\n        ]\n        Here, each sub-list, corresponding to a nodetype in the input graph, has `p` elements. For instance [a0, a1, ... a<p-1>]\n        where each element represents the number of nodes which are to be processed by a process during distributed partitioning.\n\n        In addition to the above key-value pair for the nodes in the graph, the node-features are captured in the\n        \"node_data\" key-value pair. In this dictionary the keys will be nodetype names and value will be a dictionary which\n        is used to capture all the features present for that particular node-type. 
This is shown in the following example:\n\n        \"node_data\" : {\n            \"paper\": {       # node type\n                \"feat\": {   # feature key\n                    \"format\": {\"name\": \"numpy\"},\n                    \"data\": [\"node_data/paper-feat-part1.npy\", \"node_data/paper-feat-part2.npy\"]\n                },\n                \"label\": {   # feature key\n                    \"format\": {\"name\": \"numpy\"},\n                    \"data\": [\"node_data/paper-label-part1.npy\", \"node_data/paper-label-part2.npy\"]\n                },\n                \"year\": {   # feature key\n                    \"format\": {\"name\": \"numpy\"},\n                    \"data\": [\"node_data/paper-year-part1.npy\", \"node_data/paper-year-part2.npy\"]\n                }\n            }\n        }\n        In the above textual description we have a node-type, which is paper, and it has 3 features namely feat, label and year.\n        Each feature has `p` files whose location in the filesystem is the list for the key \"data\" and \"format\" is used to\n        describe storage format.\n\n    Returns:\n    --------\n    dictionary :\n        dictionary where keys are column names and values are numpy arrays, these arrays are generated by\n        using information present in the metadata json file\n\n    \"\"\"\n    local_node_data = {}\n    for local_part_id in range(num_parts // world_size):\n        local_node_data[constants.GLOBAL_NID + \"/\" + str(local_part_id)] = []\n        local_node_data[constants.NTYPE_ID + \"/\" + str(local_part_id)] = []\n        local_node_data[\n            constants.GLOBAL_TYPE_NID + \"/\" + str(local_part_id)\n        ] = []\n\n    # Note that `get_idranges` always returns two dictionaries. 
Keys in these\n    # dictionaries are type names for nodes and edges and values are\n    # `num_parts` number of tuples indicating the range of type-ids in first\n    # dictionary and range of global-nids in the second dictionary.\n    type_nid_dict, global_nid_dict = get_idranges(\n        schema_map[constants.STR_NODE_TYPE],\n        get_ntype_counts_map(\n            schema_map[constants.STR_NODE_TYPE],\n            schema_map[constants.STR_NUM_NODES_PER_TYPE],\n        ),\n        num_chunks=num_parts,\n    )\n\n    for ntype_id, ntype_name in ntid_ntype_map.items():\n        # No. of nodes in each process can differ significantly in lopsided distributions\n        # Synchronize on a per ntype basis\n        dist.barrier()\n\n        type_start, type_end = (\n            type_nid_dict[ntype_name][0][0],\n            type_nid_dict[ntype_name][-1][1],\n        )\n        gnid_start, gnid_end = (\n            global_nid_dict[ntype_name][0, 0],\n            global_nid_dict[ntype_name][0, 1],\n        )\n\n        node_partid_slice = id_lookup.get_partition_ids(\n            np.arange(gnid_start, gnid_end, dtype=np.int64)\n        )  # exclusive\n\n        for local_part_id in range(num_parts // world_size):\n            cond = node_partid_slice == (rank + local_part_id * world_size)\n            own_gnids = np.arange(gnid_start, gnid_end, dtype=np.int64)\n            own_gnids = own_gnids[cond]\n\n            own_tnids = np.arange(type_start, type_end, dtype=np.int64)\n            own_tnids = own_tnids[cond]\n\n            local_node_data[\n                constants.NTYPE_ID + \"/\" + str(local_part_id)\n            ].append(np.ones(own_gnids.shape, dtype=np.int64) * ntype_id)\n            local_node_data[\n                constants.GLOBAL_NID + \"/\" + str(local_part_id)\n            ].append(own_gnids)\n            local_node_data[\n                constants.GLOBAL_TYPE_NID + \"/\" + str(local_part_id)\n            ].append(own_tnids)\n\n    for k in 
local_node_data.keys():\n        local_node_data[k] = np.concatenate(local_node_data[k])\n\n    return local_node_data\n\n\ndef exchange_edge_data(rank, world_size, num_parts, edge_data, id_lookup):\n    \"\"\"\n    Exchange edge_data among processes in the world.\n    Prepare list of sliced data targeting each process and trigger\n    alltoallv_cpu to trigger messaging api\n\n    Parameters:\n    -----------\n    rank : int\n        rank of the process\n    world_size : int\n        total no. of processes\n    edge_data : dictionary\n        edge information, as a dicitonary which stores column names as keys and values\n        as column data. This information is read from the edges.txt file.\n    id_lookup : DistLookupService instance\n        this object will be used to retrieve ownership information of nodes\n\n    Returns:\n    --------\n    dictionary :\n        the input argument, edge_data, is updated with the edge data received by other processes\n        in the world.\n    \"\"\"\n\n    # Synchronize at the beginning of this function\n    dist.barrier()\n\n    # Prepare data for each rank in the cluster.\n    timer_start = timer()\n\n    CHUNK_SIZE = 100 * 1000 * 1000  # 100 * 8 * 5 = 1 * 4 = 8 GB/message/node\n    num_edges = edge_data[constants.GLOBAL_SRC_ID].shape[0]\n    all_counts = allgather_sizes(\n        [num_edges], world_size, num_parts, return_sizes=True\n    )\n    max_edges = np.amax(all_counts)\n    all_edges = np.sum(all_counts)\n    num_chunks = (max_edges // CHUNK_SIZE) + (\n        0 if (max_edges % CHUNK_SIZE == 0) else 1\n    )\n    LOCAL_CHUNK_SIZE = (num_edges // num_chunks) + (\n        0 if (num_edges % num_chunks == 0) else 1\n    )\n    logging.debug(\n        f\"[Rank: {rank} Edge Data Shuffle - max_edges: {max_edges}, \\\n                        local_edges: {num_edges} and num_chunks: {num_chunks} \\\n                        Total edges: {all_edges} Local_CHUNK_SIZE: {LOCAL_CHUNK_SIZE}\"\n    )\n\n    for local_part_id in 
range(num_parts // world_size):\n        local_src_ids = []\n        local_dst_ids = []\n        local_type_eids = []\n        local_etype_ids = []\n        local_eids = []\n\n        for chunk in range(num_chunks):\n            chunk_start = chunk * LOCAL_CHUNK_SIZE\n            chunk_end = (chunk + 1) * LOCAL_CHUNK_SIZE\n\n            logging.debug(\n                f\"[Rank: {rank}] EdgeData Shuffle: processing \\\n                    local_part_id: {local_part_id} and chunkid: {chunk}\"\n            )\n            cur_src_id = edge_data[constants.GLOBAL_SRC_ID][\n                chunk_start:chunk_end\n            ]\n            cur_dst_id = edge_data[constants.GLOBAL_DST_ID][\n                chunk_start:chunk_end\n            ]\n            cur_type_eid = edge_data[constants.GLOBAL_TYPE_EID][\n                chunk_start:chunk_end\n            ]\n            cur_etype_id = edge_data[constants.ETYPE_ID][chunk_start:chunk_end]\n            cur_eid = edge_data[constants.GLOBAL_EID][chunk_start:chunk_end]\n\n            input_list = []\n            owner_ids = id_lookup.get_partition_ids(cur_dst_id)\n            for idx in range(world_size):\n                send_idx = owner_ids == (idx + local_part_id * world_size)\n                send_idx = send_idx.reshape(cur_src_id.shape[0])\n                filt_data = np.column_stack(\n                    (\n                        cur_src_id[send_idx == 1],\n                        cur_dst_id[send_idx == 1],\n                        cur_type_eid[send_idx == 1],\n                        cur_etype_id[send_idx == 1],\n                        cur_eid[send_idx == 1],\n                    )\n                )\n                if filt_data.shape[0] <= 0:\n                    input_list.append(torch.empty((0, 5), dtype=torch.int64))\n                else:\n                    input_list.append(torch.from_numpy(filt_data))\n\n            # Now send newly formed chunk to others.\n            dist.barrier()\n            output_list 
= alltoallv_cpu(\n                rank, world_size, input_list, retain_nones=False\n            )\n\n            # Replace the values of the edge_data, with the received data from all the other processes.\n            rcvd_edge_data = torch.cat(output_list).numpy()\n            local_src_ids.append(rcvd_edge_data[:, 0])\n            local_dst_ids.append(rcvd_edge_data[:, 1])\n            local_type_eids.append(rcvd_edge_data[:, 2])\n            local_etype_ids.append(rcvd_edge_data[:, 3])\n            local_eids.append(rcvd_edge_data[:, 4])\n\n        edge_data[\n            constants.GLOBAL_SRC_ID + \"/\" + str(local_part_id)\n        ] = np.concatenate(local_src_ids)\n        edge_data[\n            constants.GLOBAL_DST_ID + \"/\" + str(local_part_id)\n        ] = np.concatenate(local_dst_ids)\n        edge_data[\n            constants.GLOBAL_TYPE_EID + \"/\" + str(local_part_id)\n        ] = np.concatenate(local_type_eids)\n        edge_data[\n            constants.ETYPE_ID + \"/\" + str(local_part_id)\n        ] = np.concatenate(local_etype_ids)\n        edge_data[\n            constants.GLOBAL_EID + \"/\" + str(local_part_id)\n        ] = np.concatenate(local_eids)\n\n    # Check if the data was exchanged correctly\n    local_edge_count = 0\n    for local_part_id in range(num_parts // world_size):\n        local_edge_count += edge_data[\n            constants.GLOBAL_SRC_ID + \"/\" + str(local_part_id)\n        ].shape[0]\n    shuffle_edge_counts = allgather_sizes(\n        [local_edge_count], world_size, num_parts, return_sizes=True\n    )\n    shuffle_edge_total = np.sum(shuffle_edge_counts)\n    assert shuffle_edge_total == all_edges\n\n    timer_end = timer()\n    logging.info(\n        f\"[Rank: {rank}] Time to send/rcv edge data: {timedelta(seconds=timer_end-timer_start)}\"\n    )\n\n    # Clean up.\n    edge_data.pop(constants.GLOBAL_SRC_ID)\n    edge_data.pop(constants.GLOBAL_DST_ID)\n    edge_data.pop(constants.GLOBAL_TYPE_EID)\n    
edge_data.pop(constants.ETYPE_ID)\n    edge_data.pop(constants.GLOBAL_EID)\n\n    return edge_data\n\n\ndef exchange_feature(\n    rank,\n    data,\n    id_lookup,\n    feat_type,\n    feat_key,\n    featdata_key,\n    gid_start,\n    gid_end,\n    type_id_start,\n    type_id_end,\n    local_part_id,\n    world_size,\n    num_parts,\n    cur_features,\n    cur_global_ids,\n):\n    \"\"\"This function is used to send/receive one feature for either nodes or\n    edges of the input graph dataset.\n\n    Parameters:\n    -----------\n    rank : int\n        integer, unique id assigned to the current process\n    data: dictionary\n        dictionary in which node or edge features are stored and this information\n        is read from the appropriate node features file which belongs to the\n        current process\n    id_lookup : instance of DistLookupService\n        instance of an implementation of dist. lookup service to retrieve values\n        for keys\n    feat_type : string\n        this is used to distinguish which features are being exchanged. 
Please\n        note that for nodes ownership is clearly defined and for edges it is\n        always assumed that destination end point of the edge defines the\n        ownership of that particular edge\n    feat_key : string\n        this string is used as a key in the dictionary to store features, as\n        tensors, in local dictionaries\n    featdata_key : numpy array\n        features associated with this feature key being processed\n    gid_start : int\n        starting global_id, of either node or edge, for the feature data\n    gid_end : int\n        ending global_if, of either node or edge, for the feature data\n    type_id_start : int\n        starting type_id for the feature data\n    type_id_end : int\n        ending type_id for the feature data\n    local_part_id : int\n        integers used to the identify the local partition id used to locate\n        data belonging to this partition\n    world_size : int\n        total number of processes created\n    num_parts : int\n        total number of partitions\n    cur_features : dictionary\n        dictionary to store the feature data which belongs to the current\n        process\n    cur_global_ids : dictionary\n        dictionary to store global ids, of either nodes or edges, for which\n        the features stored in the cur_features dictionary\n\n    Returns:\n    -------\n    dictionary :\n        a dictionary is returned where keys are type names and\n        feature data are the values\n    list :\n        a dictionary of global_ids either nodes or edges whose features are\n        received during the data shuffle process\n    \"\"\"\n    # type_ids for this feature subset on the current rank\n    gids_feat = np.arange(gid_start, gid_end)\n    local_idx = np.arange(0, type_id_end - type_id_start)\n\n    feats_per_rank = []\n    global_id_per_rank = []\n\n    tokens = feat_key.split(\"/\")\n    assert len(tokens) == 3\n    local_feat_key = \"/\".join(tokens[:-1]) + \"/\" + str(local_part_id)\n\n    
logging.debug(\n        f\"[Rank: {rank} feature: {feat_key}, gid_start - {gid_start} and gid_end - {gid_end}\"\n    )\n\n    # Get the partition ids for the range of global nids.\n    if feat_type == constants.STR_NODE_FEATURES:\n        # Retrieve the partition ids for the node features.\n        # Each partition id will be in the range [0, num_parts).\n        partid_slice = id_lookup.get_partition_ids(\n            np.arange(gid_start, gid_end, dtype=np.int64)\n        )\n    else:\n        # Edge data case.\n        # Ownership is determined by the destination node.\n        assert data is not None\n        global_eids = np.arange(gid_start, gid_end, dtype=np.int64)\n        if data[constants.GLOBAL_EID].shape[0] > 0:\n            logging.debug(\n                f\"[Rank: {rank} disk read global eids - min - {np.amin(data[constants.GLOBAL_EID])}, max - {np.amax(data[constants.GLOBAL_EID])}, count - {data[constants.GLOBAL_EID].shape}\"\n            )\n\n        # Now use `data` to extract destination nodes' global id\n        # and use that to get the ownership\n        common, idx1, idx2 = np.intersect1d(\n            data[constants.GLOBAL_EID], global_eids, return_indices=True\n        )\n        assert (\n            common.shape[0] == idx2.shape[0]\n        ), f\"Rank {rank}: {common.shape[0]} != {idx2.shape[0]}\"\n        assert (\n            common.shape[0] == global_eids.shape[0]\n        ), f\"Rank {rank}: {common.shape[0]} != {global_eids.shape[0]}\"\n\n        global_dst_nids = data[constants.GLOBAL_DST_ID][idx1]\n        assert np.all(global_eids == data[constants.GLOBAL_EID][idx1])\n        partid_slice = id_lookup.get_partition_ids(global_dst_nids)\n\n    # determine the shape of the feature-data\n    # this is needed to so that ranks where feature-data is not present\n    # should use the correct shape for sending the padded vector.\n    # exchange length here.\n    feat_dim_len = 0\n    if featdata_key is not None:\n        feat_dim_len = 
len(featdata_key.shape)\n    all_lens = allgather_sizes(\n        [feat_dim_len], world_size, num_parts, return_sizes=True\n    )\n    if all_lens[0] <= 0:\n        logging.debug(\n            f\"[Rank: {rank} No process has any feature data to shuffle for {local_feat_key}\"\n        )\n        return cur_features, cur_global_ids\n\n    rank0_shape_len = all_lens[0]\n    for idx in range(1, world_size):\n        assert (all_lens[idx] == 0) or (all_lens[idx] == rank0_shape_len), (\n            f\"feature: {local_feat_key} shapes does not match \"\n            f\"at rank - {idx} and rank - 0\"\n        )\n\n    # exchange actual data here.\n    if featdata_key is not None:\n        logging.debug(f\"Rank: {rank} {featdata_key.shape=}\")\n        feat_dims_dtype = list(featdata_key.shape)\n        assert (\n            len(featdata_key.shape) == 2 or len(featdata_key.shape) == 1\n        ), f\"We expect 1D or 2D tensors for features, got shape {featdata_key.shape}\"\n        # When a feature is 2-dim, the shape should match the feature dimension.\n        if len(featdata_key.shape) == 2:\n            feature_dimension = feat_dims_dtype[1]\n        else:\n            feature_dimension = 0\n        feat_dims_dtype.append(DATA_TYPE_ID[featdata_key.dtype])\n    else:\n        feat_dims_dtype = list(np.zeros((rank0_shape_len), dtype=np.int64))\n        feat_dims_dtype.append(DATA_TYPE_ID[torch.float32])\n        feature_dimension = 0\n\n    feature_dimension_tensor = torch.tensor([feature_dimension])\n    dist.all_reduce(feature_dimension_tensor, op=dist.ReduceOp.MAX)\n    feature_dimension = feature_dimension_tensor.item()\n\n    logging.debug(f\"Sending the feature shape information - {feat_dims_dtype}\")\n    all_dims_dtype = allgather_sizes(\n        feat_dims_dtype, world_size, num_parts, return_sizes=True\n    )\n\n    for idx in range(world_size):\n        cond = partid_slice == (idx + local_part_id * world_size)\n        gids_per_partid = gids_feat[cond]\n        
local_idx_partid = local_idx[cond]\n\n        if gids_per_partid.shape[0] == 0:\n            assert len(all_dims_dtype) % world_size == 0\n            dim_len = int(len(all_dims_dtype) / world_size)\n            rank0_shape = list(np.zeros((dim_len - 1), dtype=np.int32))\n            assert (\n                len(rank0_shape) == 2 or len(rank0_shape) == 1\n            ), f\"We expect 1D or 2D tensors for features, got shape {rank0_shape}\"\n            # When a feature is 2-dim, the shape[1] (number of columns) should match the feature dimension.\n            if len(rank0_shape) == 2:\n                rank0_shape[1] = feature_dimension\n            rank0_dtype = REV_DATA_TYPE_ID[\n                all_dims_dtype[(dim_len - 1) : (dim_len)][0]\n            ]\n            data = torch.empty(rank0_shape, dtype=rank0_dtype)\n            feats_per_rank.append(data)\n            global_id_per_rank.append(torch.empty((0,), dtype=torch.int64))\n        else:\n            feats_per_rank.append(featdata_key[local_idx_partid])\n            global_id_per_rank.append(\n                torch.from_numpy(gids_per_partid).type(torch.int64)\n            )\n    for idx, tt in enumerate(feats_per_rank):\n        logging.debug(\n            f\"[Rank: {rank} features shape - {tt.shape} and ids - {global_id_per_rank[idx].shape}\"\n        )\n\n    # features (and global nids) per rank to be sent out are ready\n    # for transmission, perform alltoallv here.\n    output_feat_list = alltoallv_cpu(\n        rank, world_size, feats_per_rank, retain_nones=False\n    )\n    output_id_list = alltoallv_cpu(\n        rank, world_size, global_id_per_rank, retain_nones=False\n    )\n    logging.debug(\n        f\"[Rank : {rank} feats - {output_feat_list}, ids - {output_id_list}\"\n    )\n    assert len(output_feat_list) == len(output_id_list), (\n        \"Length of feature list and id list are expected to be equal while \"\n        f\"got {len(output_feat_list)} and {len(output_id_list)}.\"\n    
)\n\n    # stitch node_features together to form one large feature tensor\n    if len(output_feat_list) > 0:\n        output_feat_list = torch.cat(output_feat_list)\n        output_id_list = torch.cat(output_id_list)\n        if local_feat_key in cur_features:\n            temp = cur_features[local_feat_key]\n            cur_features[local_feat_key] = torch.cat([temp, output_feat_list])\n            temp = cur_global_ids[local_feat_key]\n            cur_global_ids[local_feat_key] = torch.cat([temp, output_id_list])\n        else:\n            cur_features[local_feat_key] = output_feat_list\n            cur_global_ids[local_feat_key] = output_id_list\n    else:\n        cur_features[local_feat_key] = torch.empty(\n            (0, feature_dimension), dtype=torch.float32\n        )\n        cur_global_ids[local_feat_key] = torch.empty((0,), dtype=torch.int64)\n    return cur_features, cur_global_ids\n\n\ndef exchange_features(\n    rank,\n    world_size,\n    num_parts,\n    feature_tids,\n    type_id_map,\n    id_lookup,\n    feature_data,\n    feat_type,\n    data,\n):\n    \"\"\"\n    This function is used to shuffle node features so that each process will receive\n    all the node features whose corresponding nodes are owned by the same process.\n    The mapping procedure to identify the owner process is not straight forward. The\n    following steps are used to identify the owner processes for the locally read node-\n    features.\n    a. Compute the global_nids for the locally read node features. Here metadata json file\n        is used to identify the corresponding global_nids. Please note that initial graph input\n        nodes.txt files are sorted based on node_types.\n    b. Using global_nids and metis partitions owner processes can be easily identified.\n    c. Now each process sends the global_nids for which shuffle_global_nids are needed to be\n        retrieved.\n    d. 
After receiving the corresponding shuffle_global_nids these ids are added to the\n        node_data and edge_data dictionaries\n\n    This pipeline assumes all the input data in numpy format, except node/edge features which\n    are maintained as tensors throughout the various stages of the pipeline execution.\n\n    Parameters:\n    -----------\n    rank : int\n        rank of the current process\n    world_size : int\n        total no. of participating processes.\n    feature_tids : dictionary\n        dictionary with keys as node-type names with suffixes as feature names\n        and value is a dictionary. This dictionary contains information about\n        node-features associated with a given node-type and value is a list.\n        This list contains a of indexes, like [starting-idx, ending-idx) which\n        can be used to index into the node feature tensors read from\n        corresponding input files.\n    type_id_map : dictionary\n        mapping between type names and global_ids, of either nodes or edges,\n        which belong to the keys in this dictionary\n    id_lookup : instance of class DistLookupService\n       Distributed lookup service used to map global-nids to respective\n       partition-ids and shuffle-global-nids\n    feat_type : string\n        this is used to distinguish which features are being exchanged. 
Please\n        note that for nodes ownership is clearly defined and for edges it is\n        always assumed that destination end point of the edge defines the\n        ownership of that particular edge\n    data: dictionary\n        dictionary in which node or edge features are stored and this information\n        is read from the appropriate node features file which belongs to the\n        current process\n\n    Returns:\n    --------\n    dictionary :\n        a dictionary is returned where keys are type names and\n        feature data are the values\n    list :\n        a dictionary of global_ids either nodes or edges whose features are\n        received during the data shuffle process\n    \"\"\"\n    start = timer()\n    own_features = {}\n    own_global_ids = {}\n\n    # To iterate over the node_types and associated node_features\n    for feat_key, type_info in feature_tids.items():\n        # To iterate over the feature data, of a given (node or edge )type\n        # type_info is a list of 3 elements (as shown below):\n        #   [feature-name, starting-idx, ending-idx]\n        #       feature-name is the name given to the feature-data,\n        #       read from the input metadata file\n        #       [starting-idx, ending-idx) specifies the range of indexes\n        #        associated with the features data\n        # Determine the owner process for these features.\n        # Note that the keys in the node features (and similarly edge features)\n        # dictionary is of the following format:\n        #   `node_type/feature_name/local_part_id`:\n        #    where node_type and feature_name are self-explanatory and\n        #    local_part_id denotes the partition-id, in the local process,\n        #    which will be used as a suffix to store all the information of a\n        #    given partition which is processed by the current process. Its\n        #    values start from 0 onwards, for instance 0, 1, 2 ... 
etc.\n        #    local_part_id can be easily mapped to global partition id very\n        #    easily, using cyclic ordering. All local_part_ids = 0 from all\n        #    processes will form global partition-ids between 0 and world_size-1.\n        #    Similarly all local_part_ids = 1 from all processes will form\n        #    global partition ids in the range [world_size, 2*world_size-1] and\n        #    so on.\n        tokens = feat_key.split(\"/\")\n        assert len(tokens) == 3\n        type_name = tokens[0]\n        feat_name = tokens[1]\n        logging.debug(f\"[Rank: {rank}] processing feature: {feat_key}\")\n\n        for feat_info in type_info:\n            # Compute the global_id range for this feature data\n            type_id_start = int(feat_info[0])\n            type_id_end = int(feat_info[1])\n            begin_global_id = type_id_map[type_name][0]\n            gid_start = begin_global_id + type_id_start\n            gid_end = begin_global_id + type_id_end\n\n            # Check if features exist for this type_name + feat_name.\n            # This check should always pass, because feature_tids are built\n            # by reading the input metadata json file for existing features.\n            assert feat_key in feature_data\n\n            for local_part_id in range(num_parts // world_size):\n                featdata_key = feature_data[feat_key]\n\n                # Synchronize for each feature\n                dist.barrier()\n                own_features, own_global_ids = exchange_feature(\n                    rank,\n                    data,\n                    id_lookup,\n                    feat_type,\n                    feat_key,\n                    featdata_key,\n                    gid_start,\n                    gid_end,\n                    type_id_start,\n                    type_id_end,\n                    local_part_id,\n                    world_size,\n                    num_parts,\n                    own_features,\n          
          own_global_ids,\n                )\n\n    end = timer()\n    logging.info(\n        f\"[Rank: {rank}] Total time for feature exchange: {timedelta(seconds = end - start)}\"\n    )\n    for k, v in own_features.items():\n        logging.debug(f\"Rank: {rank}] Key - {k} Value - {v.shape}\")\n    return own_features, own_global_ids\n\n\ndef exchange_graph_data(\n    rank,\n    world_size,\n    num_parts,\n    node_features,\n    edge_features,\n    node_feat_tids,\n    edge_feat_tids,\n    edge_data,\n    id_lookup,\n    ntypes_ntypeid_map,\n    ntypes_gnid_range_map,\n    etypes_geid_range_map,\n    ntid_ntype_map,\n    schema_map,\n):\n    \"\"\"\n    Wrapper function which is used to shuffle graph data on all the processes.\n\n    Parameters:\n    -----------\n    rank : int\n        rank of the current process\n    world_size : int\n        total no. of participating processes.\n    num_parts : int\n        total no. of graph partitions.\n    node_features : dictionary\n        dictionary where node_features are stored and this information is read from the appropriate\n        node features file which belongs to the current process\n    edge_features : dictionary\n        dictionary where edge_features are stored. This information is read from the appropriate\n        edge feature files whose ownership is assigned to the current process\n    node_feat_tids: dictionary\n        in which keys are node-type names and values are triplets. Each triplet has node-feature name\n        and the starting and ending type ids of the node-feature data read from the corresponding\n        node feature data file read by current process. Each node type may have several features and\n        hence each key may have several triplets.\n    edge_feat_tids : dictionary\n        a dictionary in which keys are edge-type names and values are triplets of the format\n        <feat-name, start-per-type-idx, end-per-type-idx>. 
This triplet is used to identify\n        the chunk of feature data for which current process is responsible for\n    edge_data : dictionary\n        dictionary which is used to store edge information as read from appropriate files assigned\n        to each process.\n    id_lookup : instance of class DistLookupService\n       Distributed lookup service used to map global-nids to respective partition-ids and\n       shuffle-global-nids\n    ntypes_ntypeid_map : dictionary\n        mappings between node type names and node type ids\n    ntypes_gnid_range_map : dictionary\n        mapping between node type names and global_nids which belong to the keys in this dictionary\n    etypes_geid_range_map : dictionary\n        mapping between edge type names and global_eids which are assigned to the edges of this\n        edge_type\n    ntid_ntype_map : dictionary\n        mapping between node type id and no of nodes which belong to each node_type_id\n    schema_map : dictionary\n        is the data structure read from the metadata json file for the input graph\n\n    Returns:\n    --------\n    dictionary :\n        the input argument, node_data dictionary, is updated with the node data received from other processes\n        in the world. The node data is received by each rank in the process of data shuffling.\n    dictionary :\n        node features dictionary which has node features for the nodes which are owned by the current\n        process\n    dictionary :\n        list of global_nids for the nodes whose node features are received when node features shuffling was\n        performed in the `exchange_features` function call\n    dictionary :\n        the input argument, edge_data dictionary, is updated with the edge data received from other processes\n        in the world. The edge data is received by each rank in the process of data shuffling.\n    dictionary :\n        edge features dictionary which has edge features. 
These destination end points of these edges\n        are owned by the current process\n    dictionary :\n        list of global_eids for the edges whose edge features are received when edge features shuffling\n        was performed in the `exchange_features` function call\n    \"\"\"\n    memory_snapshot(\"ShuffleNodeFeaturesBegin: \", rank)\n    logging.debug(f\"[Rank: {rank} - node_feat_tids - {node_feat_tids}\")\n    rcvd_node_features, rcvd_global_nids = exchange_features(\n        rank,\n        world_size,\n        num_parts,\n        node_feat_tids,\n        ntypes_gnid_range_map,\n        id_lookup,\n        node_features,\n        constants.STR_NODE_FEATURES,\n        None,\n    )\n    dist.barrier()\n    memory_snapshot(\"ShuffleNodeFeaturesComplete: \", rank)\n    logging.debug(f\"[Rank: {rank}] Done with node features exchange.\")\n\n    rcvd_edge_features, rcvd_global_eids = exchange_features(\n        rank,\n        world_size,\n        num_parts,\n        edge_feat_tids,\n        etypes_geid_range_map,\n        id_lookup,\n        edge_features,\n        constants.STR_EDGE_FEATURES,\n        edge_data,\n    )\n    dist.barrier()\n    logging.debug(f\"[Rank: {rank}] Done with edge features exchange.\")\n\n    node_data = gen_node_data(\n        rank, world_size, num_parts, id_lookup, ntid_ntype_map, schema_map\n    )\n    dist.barrier()\n    memory_snapshot(\"NodeDataGenerationComplete: \", rank)\n\n    edge_data = exchange_edge_data(\n        rank, world_size, num_parts, edge_data, id_lookup\n    )\n    dist.barrier()\n    memory_snapshot(\"ShuffleEdgeDataComplete: \", rank)\n    return (\n        node_data,\n        rcvd_node_features,\n        rcvd_global_nids,\n        edge_data,\n        rcvd_edge_features,\n        rcvd_global_eids,\n    )\n\n\ndef read_dataset(rank, world_size, id_lookup, params, schema_map, ntype_counts):\n    \"\"\"\n    This function gets the dataset and performs post-processing on the data which is read from files.\n    
Additional information(columns) are added to nodes metadata like owner_process, global_nid which\n    are later used in processing this information. For edge data, which is now a dictionary, we add new columns\n    like global_edge_id and owner_process. Augmenting these data structures helps in processing these data structures\n    when data shuffling is performed.\n\n    Parameters:\n    -----------\n    rank : int\n        rank of the current process\n    world_size : int\n        total no. of processes instantiated\n    id_lookup : instance of class DistLookupService\n       Distributed lookup service used to map global-nids to respective partition-ids and\n       shuffle-global-nids\n    params : argparser object\n        argument parser object to access command line arguments\n    schema_map : dictionary\n        dictionary created by reading the input graph metadata json file\n\n    Returns :\n    ---------\n    dictionary\n        in which keys are node-type names and values are tuples representing the range of ids\n        for nodes to be read by the current process\n    dictionary\n        node features which is a dictionary where keys are feature names and values are feature\n        data as multi-dimensional tensors\n    dictionary\n        in which keys are node-type names and values are triplets. Each triplet has node-feature name\n        and the starting and ending type ids of the node-feature data read from the corresponding\n        node feature data file read by current process. 
Each node type may have several features and\n        hence each key may have several triplets.\n    dictionary\n        edge data information is read from edges.txt and additional columns are added such as\n        owner process for each edge.\n    dictionary\n        edge features which is also a dictionary, similar to node features dictionary\n    dictionary\n        a dictionary in which keys are edge-type names and values are tuples indicating the range of ids\n        for edges read by the current process.\n    dictionary\n        a dictionary in which keys are edge-type names and values are triplets,\n        (edge-feature-name, start_type_id, end_type_id). These type_ids are indices in the edge-features\n        read by the current process. Note that each edge-type may have several edge-features.\n    \"\"\"\n    edge_features = {}\n    (\n        node_features,\n        node_feat_tids,\n        edge_data,\n        edge_typecounts,\n        edge_tids,\n        edge_features,\n        edge_feat_tids,\n    ) = get_dataset(\n        params.input_dir,\n        params.graph_name,\n        rank,\n        world_size,\n        params.num_parts,\n        schema_map,\n        ntype_counts,\n    )\n    # Synchronize so that everybody completes reading dataset from disk\n    dist.barrier()\n    logging.info(f\"[Rank: {rank}] Done reading dataset {params.input_dir}\")\n\n    edge_data = augment_edge_data(\n        edge_data, id_lookup, edge_tids, rank, world_size, params.num_parts\n    )\n    dist.barrier()  # SYNCH\n    logging.debug(\n        f\"[Rank: {rank}] Done augmenting edge_data: {len(edge_data)}, {edge_data[constants.GLOBAL_SRC_ID].shape}\"\n    )\n\n    return (\n        node_features,\n        node_feat_tids,\n        edge_data,\n        edge_typecounts,\n        edge_features,\n        edge_feat_tids,\n    )\n\n\ndef reorder_data(num_parts, world_size, data, key):\n    \"\"\"\n    Auxiliary function used to sort node and edge data for the input graph.\n\n   
 Parameters:\n    -----------\n    num_parts : int\n        total no. of partitions\n    world_size : int\n        total number of nodes used in this execution\n    data : dictionary\n        which is used to store the node and edge data for the input graph\n    key : string\n        specifies the column which is used to determine the sort order for\n        the remaining columns\n\n    Returns:\n    --------\n    dictionary\n        same as the input dictionary, but with reordered columns (values in\n        the dictionary), as per the np.argsort results on the column specified\n        by the ``key`` column\n    \"\"\"\n    for local_part_id in range(num_parts // world_size):\n        sorted_idx = data[key + \"/\" + str(local_part_id)].argsort()\n        for k, v in data.items():\n            tokens = k.split(\"/\")\n            assert len(tokens) == 2\n            if tokens[1] == str(local_part_id):\n                data[k] = v[sorted_idx]\n        sorted_idx = None\n    gc.collect()\n    return data\n\n\ndef gen_dist_partitions(rank, world_size, params):\n    \"\"\"\n    Function which will be executed by all Gloo processes to begin execution of the pipeline.\n    This function expects the input dataset to be split across multiple files.\n\n    Input dataset and its file structure is described in metadata json file which is also part of the\n    input dataset. On a high-level, this metadata json file contains information about the following items\n    a) Nodes metadata, It is assumed that nodes which belong to each node-type are split into p files\n       (where `p` is no. 
of partitions).\n    b) Similarly edge metadata contains information about edges which are split into p-files.\n    c) Node and Edge features, it is also assumed that each node (and edge) feature, if present, is also\n       split into `p` files.\n\n    For example, a sample metadata json file might be as follows: :\n    (In this toy example, we assume that we have \"m\" node-types, \"k\" edge types, and for node_type = ntype0-name\n     we have two features namely feat0-name and feat1-name. Please note that the node-features are also split into\n     `p` files. This will help in load-balancing during data-shuffling phase).\n\n    Terminology used to identify any particular \"id\" assigned to nodes, edges or node features. Prefix \"global\" is\n    used to indicate that this information is either read from the input dataset or autogenerated based on the information\n    read from input dataset files. Prefix \"type\" is used to indicate a unique id assigned to either nodes or edges.\n    For instance, type_node_id means that a unique id, with a given node type,  assigned to a node. And prefix \"shuffle\"\n    will be used to indicate a unique id, across entire graph, assigned to either a node or an edge. For instance,\n    SHUFFLE_GLOBAL_NID means a unique id which is assigned to a node after the data shuffle is completed.\n\n    Some high-level notes on the structure of the metadata json file.\n    1. path(s) mentioned in the entries for nodes, edges and node-features files can be either absolute or relative.\n       if these paths are relative, then it is assumed that they are relative to the folder from which the execution is\n       launched.\n    2. The id_startx and id_endx represent the type_node_id and type_edge_id respectively for nodes and edge data. This\n       means that these ids should match the no. of nodes/edges read from any given file. 
Since these are type_ids for\n       the nodes and edges in any given file, their global_ids can be easily computed as well.\n\n    {\n        \"graph_name\" : xyz,\n        \"node_type\" : [\"ntype0-name\", \"ntype1-name\", ....], #m node types\n        \"num_nodes_per_chunk\" : [\n            [a0, a1, ...a<p-1>], #p partitions\n            [b0, b1, ... b<p-1>],\n            ....\n            [c0, c1, ..., c<p-1>] #no, of node types\n        ],\n        \"edge_type\" : [\"src_ntype:edge_type:dst_ntype\", ....], #k edge types\n        \"num_edges_per_chunk\" : [\n            [a0, a1, ...a<p-1>], #p partitions\n            [b0, b1, ... b<p-1>],\n            ....\n            [c0, c1, ..., c<p-1>] #no, of edge types\n        ],\n        \"node_data\" : {\n            \"ntype0-name\" : {\n                \"feat0-name\" : {\n                    \"format\" : {\"name\": \"numpy\"},\n                    \"data\" :   [ #list of lists\n                        [\"<path>/feat-0.npy\", 0, id_end0],\n                        [\"<path>/feat-1.npy\", id_start1, id_end1],\n                        ....\n                        [\"<path>/feat-<p-1>.npy\", id_start<p-1>, id_end<p-1>]\n                    ]\n                },\n                \"feat1-name\" : {\n                    \"format\" : {\"name\": \"numpy\"},\n                    \"data\" : [ #list of lists\n                        [\"<path>/feat-0.npy\", 0, id_end0],\n                        [\"<path>/feat-1.npy\", id_start1, id_end1],\n                        ....\n                        [\"<path>/feat-<p-1>.npy\", id_start<p-1>, id_end<p-1>]\n                    ]\n                }\n            }\n        },\n        \"edges\": { #k edge types\n            \"src_ntype:etype0-name:dst_ntype\" : {\n                \"format\": {\"name\" : \"csv\", \"delimiter\" : \" \"},\n                \"data\" : [\n                    [\"<path>/etype0-name-0.txt\", 0, id_end0], #These are type_edge_ids for edges of this type\n             
       [\"<path>/etype0-name-1.txt\", id_start1, id_end1],\n                    ...,\n                    [\"<path>/etype0-name-<p-1>.txt\", id_start<p-1>, id_end<p-1>]\n                ]\n            },\n            ...,\n            \"src_ntype:etype<k-1>-name:dst_ntype\" : {\n                \"format\": {\"name\" : \"csv\", \"delimiter\" : \" \"},\n                \"data\" : [\n                    [\"<path>/etype<k-1>-name-0.txt\", 0, id_end0],\n                    [\"<path>/etype<k-1>-name-1.txt\", id_start1, id_end1],\n                    ...,\n                    [\"<path>/etype<k-1>-name-<p-1>.txt\", id_start<p-1>, id_end<p-1>]\n                ]\n            },\n        },\n    }\n\n    The function performs the following steps:\n    1. Reads the metis partitions to identify the owner process of all the nodes in the entire graph.\n    2. Reads the input data set, each participating process will map to a single file for the edges,\n        node-features and edge-features for each node-type and edge-types respectively. Using nodes metadata\n        information, nodes which are owned by a given process are generated to optimize communication to some\n        extent.\n    3. Now each process shuffles the data by identifying the respective owner processes using metis\n        partitions.\n        a. To identify owner processes for nodes, metis partitions will be used.\n        b. For edges, the owner process of the destination node will be the owner of the edge as well.\n        c. For node and edge features, identifying the owner process is a little bit involved.\n            For this purpose, graph metadata json file is used to first map the locally read node features\n            to their global_nids. Now owner process is identified using metis partitions for these global_nids\n            to retrieve shuffle_global_nids. A similar process is used for edge_features as well.\n        d. 
After all the data shuffling is done, the order of node-features may be different when compared to\n            their global_type_nids. Node- and edge-data are ordered by node-type and edge-type respectively.\n            And now node features and edge features are re-ordered to match the order of their node- and edge-types.\n    4. Last step is to create the DGL objects with the data present on each of the processes.\n        a. DGL objects for nodes, edges, node- and edge- features.\n        b. Metadata is gathered from each process to create the global metadata json file, by process rank = 0.\n\n    Parameters:\n    ----------\n    rank : int\n        integer representing the rank of the current process in a typical distributed implementation\n    world_size : int\n        integer representing the total no. of participating processes in a typical distributed implementation\n    params : argparser object\n        this object, key value pairs, provides access to the command line arguments from the runtime environment\n    \"\"\"\n    global_start = timer()\n    logging.info(\n        f\"[Rank: {rank}] Starting distributed data processing pipeline...\"\n    )\n    memory_snapshot(\"Pipeline Begin: \", rank)\n\n    # init processing\n    schema_map = read_json(os.path.join(params.input_dir, params.schema))\n\n    # The resources, which are node-id to partition-id mappings, are split\n    # into `world_size` number of parts, where each part can be mapped to\n    # each physical node.\n    id_lookup = DistLookupService(\n        os.path.join(params.input_dir, params.partitions_dir),\n        schema_map[constants.STR_NODE_TYPE],\n        rank,\n        world_size,\n        params.num_parts,\n    )\n\n    # get the id to name mappings here.\n    ntypes_ntypeid_map, ntypes, ntypeid_ntypes_map = get_node_types(schema_map)\n    etypes_etypeid_map, etypes, etypeid_etypes_map = get_edge_types(schema_map)\n    logging.info(\n        f\"[Rank: {rank}] Initialized metis 
partitions and node_types map...\"\n    )\n\n    # Initialize distributed lookup service for partition-id and shuffle-global-nids mappings\n    # for global-nids\n    _, global_nid_ranges = get_idranges(\n        schema_map[constants.STR_NODE_TYPE],\n        get_ntype_counts_map(\n            schema_map[constants.STR_NODE_TYPE],\n            schema_map[constants.STR_NUM_NODES_PER_TYPE],\n        ),\n    )\n    id_map = dgl.distributed.id_map.IdMap(global_nid_ranges)\n    id_lookup.set_idMap(id_map)\n    # read input graph files and augment these datastructures with\n    # appropriate information (global_nid and owner process) for node and edge data\n    (\n        node_features,\n        node_feat_tids,\n        edge_data,\n        edge_typecounts,\n        edge_features,\n        edge_feat_tids,\n    ) = read_dataset(\n        rank,\n        world_size,\n        id_lookup,\n        params,\n        schema_map,\n        get_ntype_counts_map(\n            schema_map[constants.STR_NODE_TYPE],\n            schema_map[constants.STR_NUM_NODES_PER_TYPE],\n        ),\n    )\n    logging.info(\n        f\"[Rank: {rank}] Done augmenting file input data with auxilary columns\"\n    )\n    memory_snapshot(\"DatasetReadComplete: \", rank)\n\n    # send out node and edge data --- and appropriate features.\n    # this function will also stitch the data recvd from other processes\n    # and return the aggregated data\n    # ntypes_gnid_range_map = get_gnid_range_map(node_tids)\n    # etypes_geid_range_map = get_gnid_range_map(edge_tids)\n    ntypes_gnid_range_map = get_gid_offsets(\n        schema_map[constants.STR_NODE_TYPE],\n        get_ntype_counts_map(\n            schema_map[constants.STR_NODE_TYPE],\n            schema_map[constants.STR_NUM_NODES_PER_TYPE],\n        ),\n    )\n    etypes_geid_range_map = get_gid_offsets(\n        schema_map[constants.STR_EDGE_TYPE], edge_typecounts\n    )\n\n    (\n        node_data,\n        rcvd_node_features,\n        
rcvd_global_nids,\n        edge_data,\n        rcvd_edge_features,\n        rcvd_global_eids,\n    ) = exchange_graph_data(\n        rank,\n        world_size,\n        params.num_parts,\n        node_features,\n        edge_features,\n        node_feat_tids,\n        edge_feat_tids,\n        edge_data,\n        id_lookup,\n        ntypes_ntypeid_map,\n        ntypes_gnid_range_map,\n        etypes_geid_range_map,\n        ntypeid_ntypes_map,\n        schema_map,\n    )\n    gc.collect()\n    logging.debug(f\"[Rank: {rank}] Done with data shuffling...\")\n    memory_snapshot(\"DataShuffleComplete: \", rank)\n\n    # sort node_data by ntype\n    node_data = reorder_data(\n        params.num_parts, world_size, node_data, constants.NTYPE_ID\n    )\n    logging.debug(f\"[Rank: {rank}] Sorted node_data by node_type\")\n    memory_snapshot(\"NodeDataSortComplete: \", rank)\n\n    # resolve global_ids for nodes\n    # Synchronize before assigning shuffle-global-ids to nodes\n    dist.barrier()\n    assign_shuffle_global_nids_nodes(\n        rank, world_size, params.num_parts, node_data\n    )\n    logging.debug(f\"[Rank: {rank}] Done assigning global-ids to nodes...\")\n    memory_snapshot(\"ShuffleGlobalID_Nodes_Complete: \", rank)\n\n    # shuffle node feature according to the node order on each rank.\n    for ntype_name in ntypes:\n        featnames = get_ntype_featnames(ntype_name, schema_map)\n        for featname in featnames:\n            # if a feature name exists for a node-type, then it should also have\n            # feature data as well. 
Hence using the assert statement.\n            for local_part_id in range(params.num_parts // world_size):\n                feature_key = (\n                    ntype_name + \"/\" + featname + \"/\" + str(local_part_id)\n                )\n                assert feature_key in rcvd_global_nids\n                global_nids = rcvd_global_nids[feature_key]\n\n                _, idx1, _ = np.intersect1d(\n                    node_data[constants.GLOBAL_NID + \"/\" + str(local_part_id)],\n                    global_nids,\n                    return_indices=True,\n                )\n                shuffle_global_ids = node_data[\n                    constants.SHUFFLE_GLOBAL_NID + \"/\" + str(local_part_id)\n                ][idx1]\n                feature_idx = shuffle_global_ids.argsort()\n\n                rcvd_node_features[feature_key] = rcvd_node_features[\n                    feature_key\n                ][feature_idx]\n    memory_snapshot(\"ReorderNodeFeaturesComplete: \", rank)\n\n    # Sort edge_data by etype\n    edge_data = reorder_data(\n        params.num_parts, world_size, edge_data, constants.ETYPE_ID\n    )\n    logging.debug(f\"[Rank: {rank}] Sorted edge_data by edge_type\")\n    memory_snapshot(\"EdgeDataSortComplete: \", rank)\n\n    # Synchronize before assigning shuffle-global-nids for edges end points.\n    dist.barrier()\n    shuffle_global_eid_offsets = assign_shuffle_global_nids_edges(\n        rank, world_size, params.num_parts, edge_data\n    )\n    logging.debug(f\"[Rank: {rank}] Done assigning global_ids to edges ...\")\n\n    memory_snapshot(\"ShuffleGlobalID_Edges_Complete: \", rank)\n\n    # Shuffle edge features according to the edge order on each rank.\n    for etype_name in etypes:\n        featnames = get_etype_featnames(etype_name, schema_map)\n        for featname in featnames:\n            for local_part_id in range(params.num_parts // world_size):\n                feature_key = (\n                    etype_name + \"/\" + featname + 
\"/\" + str(local_part_id)\n                )\n                assert feature_key in rcvd_global_eids\n                global_eids = rcvd_global_eids[feature_key]\n\n                _, idx1, _ = np.intersect1d(\n                    edge_data[constants.GLOBAL_EID + \"/\" + str(local_part_id)],\n                    global_eids,\n                    return_indices=True,\n                )\n                shuffle_global_ids = edge_data[\n                    constants.SHUFFLE_GLOBAL_EID + \"/\" + str(local_part_id)\n                ][idx1]\n                feature_idx = shuffle_global_ids.argsort()\n\n                rcvd_edge_features[feature_key] = rcvd_edge_features[\n                    feature_key\n                ][feature_idx]\n\n    # determine global-ids for edge end-points\n    # Synchronize before retrieving shuffle-global-nids for edges end points.\n    dist.barrier()\n    edge_data = lookup_shuffle_global_nids_edges(\n        rank, world_size, params.num_parts, edge_data, id_lookup, node_data\n    )\n    logging.debug(\n        f\"[Rank: {rank}] Done resolving orig_node_id for local node_ids...\"\n    )\n    memory_snapshot(\"ShuffleGlobalID_Lookup_Complete: \", rank)\n\n    def prepare_local_data(src_data, local_part_id):\n        local_data = {}\n        for k, v in src_data.items():\n            tokens = k.split(\"/\")\n            if tokens[len(tokens) - 1] == str(local_part_id):\n                local_data[\"/\".join(tokens[:-1])] = v\n        return local_data\n\n    # create dgl objects here\n    output_meta_json = {}\n    start = timer()\n\n    graph_formats = None\n    if params.graph_formats:\n        graph_formats = params.graph_formats.split(\",\")\n\n    prev_last_ids = {}\n    for local_part_id in range(params.num_parts // world_size):\n        # Synchronize for each local partition of the graph object.\n        dist.barrier()\n\n        num_edges = shuffle_global_eid_offsets[local_part_id]\n        node_count = len(\n            
node_data[constants.NTYPE_ID + \"/\" + str(local_part_id)]\n        )\n        edge_count = len(\n            edge_data[constants.ETYPE_ID + \"/\" + str(local_part_id)]\n        )\n        local_node_data = prepare_local_data(node_data, local_part_id)\n        local_edge_data = prepare_local_data(edge_data, local_part_id)\n        tot_node_count = sum(schema_map[\"num_nodes_per_type\"])\n        tot_edge_count = sum(schema_map[\"num_edges_per_type\"])\n        (\n            graph_obj,\n            ntypes_map_val,\n            etypes_map_val,\n            ntypes_map,\n            etypes_map,\n            orig_nids,\n            orig_eids,\n        ) = create_graph_object(\n            tot_node_count,\n            tot_edge_count,\n            node_count,\n            edge_count,\n            params.num_parts,\n            schema_map,\n            rank + local_part_id * world_size,\n            local_node_data,\n            local_edge_data,\n            num_edges,\n            get_ntype_counts_map(\n                schema_map[constants.STR_NODE_TYPE],\n                schema_map[constants.STR_NUM_NODES_PER_TYPE],\n            ),\n            edge_typecounts,\n            prev_last_ids,\n            return_orig_nids=params.save_orig_nids,\n            return_orig_eids=params.save_orig_eids,\n            use_graphbolt=params.use_graphbolt,\n            store_inner_node=params.store_inner_node,\n            store_inner_edge=params.store_inner_edge,\n            store_eids=params.store_eids,\n        )\n        sort_etypes = len(etypes_map) > 1\n        local_node_features = prepare_local_data(\n            rcvd_node_features, local_part_id\n        )\n        local_edge_features = prepare_local_data(\n            rcvd_edge_features, local_part_id\n        )\n        write_dgl_objects(\n            graph_obj,\n            local_node_features,\n            local_edge_features,\n            params.output,\n            rank + (local_part_id * world_size),\n            
orig_nids,\n            orig_eids,\n            graph_formats,\n            sort_etypes,\n            params.use_graphbolt,\n        )\n        if params.use_graphbolt:\n            memory_snapshot(\"DiskWriteGrapgboltObjectsComplete: \", rank)\n        else:\n            memory_snapshot(\"DiskWriteDGLObjectsComplete: \", rank)\n\n        # get the meta-data\n        json_metadata = create_metadata_json(\n            params.graph_name,\n            node_count,\n            edge_count,\n            local_part_id * world_size + rank,\n            params.num_parts,\n            ntypes_map_val,\n            etypes_map_val,\n            ntypes_map,\n            etypes_map,\n            params.output,\n            params.use_graphbolt,\n        )\n        output_meta_json[\n            \"local-part-id-\" + str(local_part_id * world_size + rank)\n        ] = json_metadata\n        memory_snapshot(\"MetadataCreateComplete: \", rank)\n\n        last_id_tensor = torch.tensor(\n            [prev_last_ids[rank + (local_part_id * world_size)]],\n            dtype=torch.int64,\n        )\n        gather_list = [\n            torch.zeros(1, dtype=torch.int64) for _ in range(world_size)\n        ]\n        dist.all_gather(gather_list, last_id_tensor)\n        for rank_id, last_id in enumerate(gather_list):\n            prev_last_ids[\n                rank_id + (local_part_id * world_size)\n            ] = last_id.item()\n\n    if rank == 0:\n        # get meta-data from all partitions and merge them on rank-0\n        metadata_list = gather_metadata_json(output_meta_json, rank, world_size)\n        metadata_list[0] = output_meta_json\n        write_metadata_json(\n            metadata_list,\n            params.output,\n            params.graph_name,\n            world_size,\n            params.num_parts,\n        )\n    else:\n        # send meta-data to Rank-0 process\n        gather_metadata_json(output_meta_json, rank, world_size)\n    end = timer()\n    logging.info(\n        
f\"[Rank: {rank}] Time to create dgl objects: {timedelta(seconds = end - start)}\"\n    )\n    memory_snapshot(\"MetadataWriteComplete: \", rank)\n\n    global_end = timer()\n    logging.info(\n        f\"[Rank: {rank}] Total execution time of the program: {timedelta(seconds = global_end - global_start)}\"\n    )\n    memory_snapshot(\"PipelineComplete: \", rank)\n\n\ndef single_machine_run(params):\n    \"\"\"Main function for distributed implementation on a single machine\n\n    Parameters:\n    -----------\n    params : argparser object\n        Argument Parser structure with pre-determined arguments as defined\n        at the bottom of this file.\n    \"\"\"\n    processes = []\n    mp.set_start_method(\"spawn\")\n\n    # Invoke `target` function from each of the spawned process for distributed\n    # implementation\n    for rank in range(params.world_size):\n        p = mp.Process(\n            target=run,\n            args=(rank, params.world_size, gen_dist_partitions, params),\n        )\n        p.start()\n        processes.append(p)\n\n    for p in processes:\n        p.join()\n\n\ndef run(rank, world_size, func_exec, params, backend=\"gloo\"):\n    \"\"\"\n    Init. 
function which is run by each process in the Gloo ProcessGroup\n\n    Parameters:\n    -----------\n    rank : integer\n        rank of the process\n    world_size : integer\n        number of processes configured in the Process Group\n    func_exec : function\n        function to be invoked, which has the logic for each process in the group\n    params : argparser object\n        argument parser object to access the command line arguments\n    backend : string\n        string specifying the type of backend to use for communication\n    \"\"\"\n    os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n    os.environ[\"MASTER_PORT\"] = \"29500\"\n\n    # create Gloo Process Group\n    dist.init_process_group(\n        backend,\n        rank=rank,\n        world_size=world_size,\n        timeout=timedelta(seconds=5 * 60),\n    )\n\n    # Invoke the main function to kick-off each process\n    func_exec(rank, world_size, params)\n\n\ndef multi_machine_run(params):\n    \"\"\"\n    Function to be invoked when executing data loading pipeline on multiple machines\n\n    Parameters:\n    -----------\n    params : argparser object\n        argparser object providing access to command line arguments.\n    \"\"\"\n    rank = int(os.environ[\"RANK\"])\n\n    # init the gloo process group here.\n    dist.init_process_group(\n        backend=\"gloo\",\n        rank=rank,\n        world_size=params.world_size,\n        timeout=timedelta(seconds=params.process_group_timeout),\n    )\n    logging.info(f\"[Rank: {rank}] Done with process group initialization...\")\n\n    # invoke the main function here.\n    gen_dist_partitions(rank, params.world_size, params)\n    logging.info(\n        f\"[Rank: {rank}] Done with Distributed data processing pipeline processing.\"\n    )\n"
  },
  {
    "path": "tools/distpartitioning/dataset_utils.py",
    "content": "import gc\nimport logging\nimport os\n\nimport array_readwriter\nimport constants\n\nimport numpy as np\nimport pyarrow\nimport pyarrow.parquet as pq\nimport torch\nimport torch.distributed as dist\nfrom gloo_wrapper import alltoallv_cpu\nfrom utils import (\n    DATA_TYPE_ID,\n    generate_read_list,\n    get_gid_offsets,\n    get_idranges,\n    map_partid_rank,\n    REV_DATA_TYPE_ID,\n)\n\n\ndef _broadcast_shape(\n    data, rank, world_size, num_parts, is_feat_data, feat_name\n):\n    \"\"\"Auxiliary function to broadcast the shape of a feature data.\n    This information is used to figure out the type-ids for the\n    local features.\n\n    Parameters:\n    -----------\n    data : numpy array\n        which is the feature data read from the disk\n    rank : integer\n        which represents the id of the process in the process group\n    world_size : integer\n        represents the total no. of process in the process group\n    num_parts : integer\n        specifying the no. 
of partitions\n    is_feat_data : bool\n        flag used to seperate feature data and edge data\n    feat_name : string\n        name of the feature\n\n    Returns:\n    -------\n    list of tuples :\n        which represents the range of type-ids for the data array.\n    \"\"\"\n    assert len(data.shape) in [\n        1,\n        2,\n    ], f\"Data is expected to be 1-D or 2-D but got {data.shape}.\"\n    data_shape = list(data.shape)\n\n    if len(data_shape) == 1:\n        data_shape.append(1)\n\n    if is_feat_data:\n        data_shape.append(DATA_TYPE_ID[data.dtype])\n\n    data_shape = torch.tensor(data_shape, dtype=torch.int64)\n    data_shape_output = [\n        torch.zeros_like(data_shape) for _ in range(world_size)\n    ]\n    dist.all_gather(data_shape_output, data_shape)\n    logging.debug(\n        f\"[Rank: {rank} Received shapes from all ranks: {data_shape_output}\"\n    )\n    shapes = [x.numpy() for x in data_shape_output if x[0] != 0]\n    shapes = np.vstack(shapes)\n\n    if is_feat_data:\n        logging.debug(\n            f\"shapes: {shapes}, condition: {all(shapes[0,2] == s for s in shapes[:,2])}\"\n        )\n        assert all(\n            shapes[0, 2] == s for s in shapes[:, 2]\n        ), f\"dtypes for {feat_name} does not match on all ranks\"\n\n    # compute tids here.\n    type_counts = list(shapes[:, 0])\n    tid_start = np.cumsum([0] + type_counts[:-1])\n    tid_end = np.cumsum(type_counts)\n    tid_ranges = list(zip(tid_start, tid_end))\n    logging.debug(f\"starts -> {tid_start} ... 
end -> {tid_end}\")\n\n    return tid_ranges\n\n\ndef get_dataset(\n    input_dir, graph_name, rank, world_size, num_parts, schema_map, ntype_counts\n):\n    \"\"\"\n    Function to read the multiple file formatted dataset.\n\n    Parameters:\n    -----------\n    input_dir : string\n        root directory where dataset is located.\n    graph_name : string\n        graph name string\n    rank : int\n        rank of the current process\n    world_size : int\n        total number of process in the current execution\n    num_parts : int\n        total number of output graph partitions\n    schema_map : dictionary\n        this is the dictionary created by reading the graph metadata json file\n        for the input graph dataset\n\n    Return:\n    -------\n    dictionary\n        where keys are node-type names and values are tuples. Each tuple represents the\n        range of type ids read from a file by the current process. Please note that node\n        data for each node type is split into \"p\" files and each one of these \"p\" files are\n        read a process in the distributed graph partitioning pipeline\n    dictionary\n        Data read from numpy files for all the node features in this dataset. Dictionary built\n        using this data has keys as node feature names and values as tensor data representing\n        node features\n    dictionary\n        in which keys are node-type and values are a triplet. This triplet has node-feature name,\n        and range of tids for the node feature data read from files by the current process. Each\n        node-type may have mutiple feature(s) and associated tensor data.\n    dictionary\n        Data read from edges.txt file and used to build a dictionary with keys as column names\n        and values as columns in the csv file.\n    dictionary\n        in which keys are edge-type names and values are triplets. 
This triplet has edge-feature name,\n        and range of tids for the edge feature data read from the files by the current process. Each\n        edge-type may have several edge features and associated tensor data.\n    dictionary\n        Data read from numpy files for all the edge features in this dataset. This dictionary's keys\n        are feature names and values are tensor data representing edge feature data.\n    dictionary\n        This dictionary is used for identifying the global-id range for the associated edge features\n        present in the previous return value. The keys are edge-type names and values are triplets.\n        Each triplet consists of edge-feature name and starting and ending points of the range of\n        tids representing the corresponding edge features.\n    \"\"\"\n\n    # node features dictionary\n    # TODO: With the new file format, it is guaranteed that the input dataset will have\n    # no. of nodes with features (node-features) files and nodes metadata will always be the same.\n    # This means the dimension indicating the no. of nodes in any node-feature files and the no. of\n    # nodes in the corresponding nodes metadata file will always be the same. With this guarantee,\n    # we can eliminate the `node_feature_tids` dictionary since the same information is also populated\n    # in the `node_tids` dictionary. 
This will be removed in the next iteration of code changes.\n    node_features = {}\n    node_feature_tids = {}\n\n    \"\"\"\n    The structure of the node_data is as follows, which is present in the input metadata json file.\n       \"node_data\" : {\n            \"ntype0-name\" : {\n                \"feat0-name\" : {\n                    \"format\" : {\"name\": \"numpy\"},\n                    \"data\" :   [ #list\n                        \"<path>/feat-0.npy\",\n                        \"<path>/feat-1.npy\",\n                        ....\n                        \"<path>/feat-<p-1>.npy\"\n                    ]\n                },\n                \"feat1-name\" : {\n                    \"format\" : {\"name\": \"numpy\"},\n                    \"data\" : [ #list\n                        \"<path>/feat-0.npy\",\n                        \"<path>/feat-1.npy\",\n                        ....\n                        \"<path>/feat-<p-1>.npy\"\n                    ]\n                }\n            }\n       }\n\n    As shown above, the value for the key \"node_data\" is a dictionary object, which is\n    used to describe the feature data for each of the node-type names. Keys in this top-level\n    dictionary are node-type names and value is a dictionary which captures all the features\n    for the current node-type. Feature data is captured with keys being the feature-names and\n    value is a dictionary object which has 2 keys namely format and data. 
Format entry is used\n    to mention the format of the storage used by the node features themselves and \"data\" is used\n    to mention all the files present for this given node feature.\n\n    Data read from each of the node features file is a multi-dimensional tensor data and is read\n    in numpy or parquet format, which is also the storage format of node features on the permanent storage.\n\n        \"node_type\" : [\"ntype0-name\", \"ntype1-name\", ....], #m node types\n        \"num_nodes_per_chunk\" : [\n            [a0, a1, ...a<p-1>], #p partitions\n            [b0, b1, ... b<p-1>],\n            ....\n            [c0, c1, ..., c<p-1>] #no, of node types\n        ],\n\n    The \"node_type\" points to a list of all the node names present in the graph\n    And \"num_nodes_per_chunk\" is used to mention no. of nodes present in each of the\n    input nodes files. These node counters are used to compute the type_node_ids as\n    well as global node-ids by using a simple cumulative summation and maintaining an\n    offset counter to store the end of the current.\n\n    Since nodes are NOT actually associated with any additional metadata, w.r.t to the processing\n    involved in this pipeline this information is not needed to be stored in files. This optimization\n    saves a considerable amount of time when loading massively large datasets for partitioning.\n    As opposed to reading from files and performing shuffling process each process/rank generates nodes\n    which are owned by that particular rank. 
And using the \"num_nodes_per_chunk\" information each\n    process can easily compute any nodes per-type node_id and global node_id.\n    The node-ids are treated as int64's in order to support billions of nodes in the input graph.\n    \"\"\"\n\n    # read my nodes for each node type\n    \"\"\"\n    node_tids, ntype_gnid_offset = get_idranges(\n        schema_map[constants.STR_NODE_TYPE],\n        schema_map[constants.STR_NUM_NODES_PER_CHUNK],\n        num_chunks=num_parts,\n    )\n    \"\"\"\n    logging.debug(f\"[Rank: {rank} ntype_counts: {ntype_counts}\")\n    ntype_gnid_offset = get_gid_offsets(\n        schema_map[constants.STR_NODE_TYPE], ntype_counts\n    )\n    logging.debug(f\"[Rank: {rank} - ntype_gnid_offset = {ntype_gnid_offset}\")\n\n    # iterate over the \"node_data\" dictionary in the schema_map\n    # read the node features if exists\n    # also keep track of the type_nids for which the node_features are read.\n    dataset_features = schema_map[constants.STR_NODE_DATA]\n    if (dataset_features is not None) and (len(dataset_features) > 0):\n        for ntype_name, ntype_feature_data in dataset_features.items():\n            for feat_name, feat_data in ntype_feature_data.items():\n                assert feat_data[constants.STR_FORMAT][constants.STR_NAME] in [\n                    constants.STR_NUMPY,\n                    constants.STR_PARQUET,\n                ]\n\n                # It is guaranteed that num_chunks is always greater\n                # than num_partitions.\n                node_data = []\n                num_files = len(feat_data[constants.STR_DATA])\n                if num_files == 0:\n                    continue\n                reader_fmt_meta = {\n                    \"name\": feat_data[constants.STR_FORMAT][constants.STR_NAME]\n                }\n                read_list = generate_read_list(num_files, world_size)\n                for idx in read_list[rank]:\n                    data_file = 
feat_data[constants.STR_DATA][idx]\n                    if not os.path.isabs(data_file):\n                        data_file = os.path.join(input_dir, data_file)\n                    node_data.append(\n                        array_readwriter.get_array_parser(\n                            **reader_fmt_meta\n                        ).read(data_file)\n                    )\n                if len(node_data) > 0:\n                    node_data = np.concatenate(node_data)\n                else:\n                    node_data = np.array([])\n                node_data = torch.from_numpy(node_data)\n                cur_tids = _broadcast_shape(\n                    node_data,\n                    rank,\n                    world_size,\n                    num_parts,\n                    True,\n                    f\"{ntype_name}/{feat_name}\",\n                )\n                logging.debug(f\"[Rank: {rank} - cur_tids: {cur_tids}\")\n\n                # collect data on current rank.\n                for local_part_id in range(num_parts):\n                    data_key = (\n                        f\"{ntype_name}/{feat_name}/{local_part_id//world_size}\"\n                    )\n                    if map_partid_rank(local_part_id, world_size) == rank:\n                        if len(cur_tids) > local_part_id:\n                            start, end = cur_tids[local_part_id]\n                            assert node_data.shape[0] == (\n                                end - start\n                            ), f\"Node feature data, {data_key}, shape = {node_data.shape} does not match with tids = ({start},{end})\"\n                            node_features[data_key] = node_data\n                            node_feature_tids[data_key] = [(start, end)]\n                        else:\n                            node_features[data_key] = None\n                            node_feature_tids[data_key] = [(0, 0)]\n\n    # done building node_features locally.\n    if 
len(node_features) <= 0:\n        logging.debug(\n            f\"[Rank: {rank}] This dataset does not have any node features\"\n        )\n    else:\n        assert len(node_features) == len(node_feature_tids)\n\n        # Note that the keys in the node_features dictionary are as follows:\n        # `ntype_name/feat_name/local_part_id`.\n        #   where ntype_name and feat_name are self-explanatory, and\n        #   local_part_id indicates the partition-id, in the context of current\n        #   process which take the values 0, 1, 2, ....\n        for feat_name, feat_info in node_features.items():\n            if feat_info == None:\n                continue\n\n            logging.debug(\n                f\"[Rank: {rank}] node feature name: {feat_name}, feature data shape: {feat_info.size()}\"\n            )\n            tokens = feat_name.split(\"/\")\n            assert len(tokens) == 3\n\n            # Get the range of type ids which are mapped to the current node.\n            tids = node_feature_tids[feat_name]\n\n            # Iterate over the range of type ids for the current node feature\n            # and count the number of features for this feature name.\n            count = tids[0][1] - tids[0][0]\n            assert (\n                count == feat_info.size()[0]\n            ), f\"{feat_name}, {count} vs {feat_info.size()[0]}.\"\n\n    \"\"\"\n    Reading edge features now.\n    The structure of the edge_data is as follows, which is present in the input metadata json file.\n       \"edge_data\" : {\n            \"etype0-name\" : {\n                \"feat0-name\" : {\n                    \"format\" : {\"name\": \"numpy\"},\n                    \"data\" :   [ #list\n                        \"<path>/feat-0.npy\",\n                        \"<path>/feat-1.npy\",\n                        ....\n                        \"<path>/feat-<p-1>.npy\"\n                    ]\n                },\n                \"feat1-name\" : {\n                    \"format\" : 
{\"name\": \"numpy\"},\n                    \"data\" : [ #list\n                        \"<path>/feat-0.npy\",\n                        \"<path>/feat-1.npy\",\n                        ....\n                        \"<path>/feat-<p-1>.npy\"\n                    ]\n                }\n            }\n       }\n\n    As shown above, the value for the key \"edge_data\" is a dictionary object, which is\n    used to describe the feature data for each of the edge-type names. Keys in this top-level\n    dictionary are edge-type names and value is a dictionary which captures all the features\n    for the current edge-type. Feature data is captured with keys being the feature-names and\n    value is a dictionary object which has 2 keys namely `format` and `data`. Format entry is used\n    to mention the format of the storage used by the node features themselves and \"data\" is used\n    to mention all the files present for this given node feature.\n\n    Data read from each of the node features file is a multi-dimensional tensor data and is read\n    in numpy format, which is also the storage format of node features on the permanent storage.\n    \"\"\"\n    edge_features = {}\n    edge_feature_tids = {}\n\n    # Iterate over the \"edge_data\" dictionary in the schema_map.\n    # Read the edge features if exists.\n    # Also keep track of the type_eids for which the edge_features are read.\n    dataset_features = schema_map[constants.STR_EDGE_DATA]\n    if dataset_features and (len(dataset_features) > 0):\n        for etype_name, etype_feature_data in dataset_features.items():\n            for feat_name, feat_data in etype_feature_data.items():\n                assert feat_data[constants.STR_FORMAT][constants.STR_NAME] in [\n                    constants.STR_NUMPY,\n                    constants.STR_PARQUET,\n                ]\n\n                edge_data = []\n                num_files = len(feat_data[constants.STR_DATA])\n                if num_files == 0:\n                  
  continue\n                reader_fmt_meta = {\n                    \"name\": feat_data[constants.STR_FORMAT][constants.STR_NAME]\n                }\n                read_list = generate_read_list(num_files, world_size)\n                for idx in read_list[rank]:\n                    data_file = feat_data[constants.STR_DATA][idx]\n                    if not os.path.isabs(data_file):\n                        data_file = os.path.join(input_dir, data_file)\n                    logging.debug(\n                        f\"[Rank: {rank}] Loading edges-feats of {etype_name}[{feat_name}] from {data_file}\"\n                    )\n                    edge_data.append(\n                        array_readwriter.get_array_parser(\n                            **reader_fmt_meta\n                        ).read(data_file)\n                    )\n                if len(edge_data) > 0:\n                    edge_data = np.concatenate(edge_data)\n                else:\n                    edge_data = np.array([])\n                edge_data = torch.from_numpy(edge_data)\n\n                # exchange the amount of data read from the disk.\n                edge_tids = _broadcast_shape(\n                    edge_data,\n                    rank,\n                    world_size,\n                    num_parts,\n                    True,\n                    f\"{etype_name}/{feat_name}\",\n                )\n\n                # collect data on current rank.\n                for local_part_id in range(num_parts):\n                    data_key = (\n                        f\"{etype_name}/{feat_name}/{local_part_id//world_size}\"\n                    )\n                    if map_partid_rank(local_part_id, world_size) == rank:\n                        if len(edge_tids) > local_part_id:\n                            start, end = edge_tids[local_part_id]\n                            assert edge_data.shape[0] == (\n                                end - start\n                            ), f\"Edge 
Feature data, for {data_key}, of shape = {edge_data.shape} does not match with tids = ({start}, {end})\"\n                            edge_features[data_key] = edge_data\n                            edge_feature_tids[data_key] = [(start, end)]\n                        else:\n                            edge_features[data_key] = None\n                            edge_feature_tids[data_key] = [(0, 0)]\n\n    # Done with building node_features locally.\n    if len(edge_features) <= 0:\n        logging.debug(\n            f\"[Rank: {rank}] This dataset does not have any edge features\"\n        )\n    else:\n        assert len(edge_features) == len(edge_feature_tids)\n\n        for k, v in edge_features.items():\n            if v == None:\n                continue\n            logging.debug(\n                f\"[Rank: {rank}] edge feature name: {k}, feature data shape: {v.shape}\"\n            )\n            tids = edge_feature_tids[k]\n            count = tids[0][1] - tids[0][0]\n            assert count == v.size()[0]\n\n    \"\"\"\n    Code below is used to read edges from the input dataset with the help of the metadata json file\n    for the input graph dataset.\n    In the metadata json file, we expect the following key-value pairs to help read the edges of the\n    input graph.\n\n    \"edge_type\" : [ # a total of n edge types\n        canonical_etype_0,\n        canonical_etype_1,\n        ...,\n        canonical_etype_n-1\n    ]\n\n    The value for the key is a list of strings, each string is associated with an edgetype in the input graph.\n    Note that these strings are in canonical edgetypes format. This means, these edge type strings follow the\n    following naming convention: src_ntype:etype:dst_ntype. 
src_ntype and dst_ntype are node type names of the\n    src and dst end points of this edge type, and etype is the relation name between src and dst ntypes.\n\n    The files in which edges are present and their storage format are present in the following key-value pair:\n\n    \"edges\" : {\n        \"canonical_etype_0\" : {\n            \"format\" : { \"name\" : \"csv\", \"delimiter\" : \" \" },\n            \"data\" : [\n                filename_0,\n                filename_1,\n                filename_2,\n                ....\n                filename_<p-1>\n            ]\n        },\n    }\n\n    As shown above the \"edges\" dictionary value has canonical edgetypes as keys and for each canonical edgetype\n    we have \"format\" and \"data\" which describe the storage format of the edge files and actual filenames respectively.\n    Please note that each edgetype data is split in to `p` files, where p is the no. of partitions to be made of\n    the input graph.\n\n    Each edge file contains two columns representing the source per-type node_ids and destination per-type node_ids\n    of any given edge. 
Since these are node-ids as well they are read in as int64's.\n    \"\"\"\n\n    # read my edges for each edge type\n    etype_names = schema_map[constants.STR_EDGE_TYPE]\n    etype_name_idmap = {e: idx for idx, e in enumerate(etype_names)}\n\n    edge_tids = {}\n    edge_typecounts = {}\n    edge_datadict = {}\n    edge_data = schema_map[constants.STR_EDGES]\n\n    # read the edges files and store this data in memory.\n    for col in [\n        constants.GLOBAL_SRC_ID,\n        constants.GLOBAL_DST_ID,\n        constants.GLOBAL_TYPE_EID,\n        constants.ETYPE_ID,\n    ]:\n        edge_datadict[col] = []\n\n    for etype_name, etype_id in etype_name_idmap.items():\n        etype_info = edge_data[etype_name]\n        edge_info = etype_info[constants.STR_DATA]\n\n        # edgetype strings are in canonical format, src_node_type:edge_type:dst_node_type\n        tokens = etype_name.split(\":\")\n        assert len(tokens) == 3\n\n        src_ntype_name = tokens[0]\n        dst_ntype_name = tokens[2]\n\n        num_chunks = len(edge_info)\n        read_list = generate_read_list(num_chunks, world_size)\n        src_ids = []\n        dst_ids = []\n\n        \"\"\"\n        curr_partids = []\n        for part_id in range(num_parts):\n            if map_partid_rank(part_id, world_size) == rank:\n                curr_partids.append(read_list[part_id])\n\n        for idx in np.concatenate(curr_partids):\n        \"\"\"\n        for idx in read_list[rank]:\n            edge_file = edge_info[idx]\n            if not os.path.isabs(edge_file):\n                edge_file = os.path.join(input_dir, edge_file)\n            logging.debug(\n                f\"[Rank: {rank}] Loading edges of etype[{etype_name}] from {edge_file}\"\n            )\n\n            if (\n                etype_info[constants.STR_FORMAT][constants.STR_NAME]\n                == constants.STR_CSV\n            ):\n                read_options = pyarrow.csv.ReadOptions(\n                    use_threads=True,\n   
                 block_size=4096,\n                    autogenerate_column_names=True,\n                )\n                parse_options = pyarrow.csv.ParseOptions(delimiter=\" \")\n\n                if os.path.getsize(edge_file) == 0:\n                    # if getsize() == 0, the file is empty, indicating that the partition doesn't have this attribute.\n                    # The src_ids and dst_ids should remain empty.\n                    continue\n                with pyarrow.csv.open_csv(\n                    edge_file,\n                    read_options=read_options,\n                    parse_options=parse_options,\n                ) as reader:\n                    for next_chunk in reader:\n                        if next_chunk is None:\n                            break\n\n                        next_table = pyarrow.Table.from_batches([next_chunk])\n                        src_ids.append(next_table[\"f0\"].to_numpy())\n                        dst_ids.append(next_table[\"f1\"].to_numpy())\n            elif (\n                etype_info[constants.STR_FORMAT][constants.STR_NAME]\n                == constants.STR_PARQUET\n            ):\n                data_df = pq.read_table(edge_file)\n                data_df = data_df.rename_columns([\"f0\", \"f1\"])\n                src_ids.append(data_df[\"f0\"].to_numpy())\n                dst_ids.append(data_df[\"f1\"].to_numpy())\n            else:\n                raise ValueError(\n                    f\"Unknown edge format {etype_info[constants.STR_FORMAT][constants.STR_NAME]} for edge type {etype_name}\"\n                )\n\n        if len(src_ids) > 0:\n            src_ids = np.concatenate(src_ids)\n            dst_ids = np.concatenate(dst_ids)\n\n            # currently these are just type_edge_ids... 
which will be converted to global ids\n            edge_datadict[constants.GLOBAL_SRC_ID].append(\n                src_ids + ntype_gnid_offset[src_ntype_name][0]\n            )\n            edge_datadict[constants.GLOBAL_DST_ID].append(\n                dst_ids + ntype_gnid_offset[dst_ntype_name][0]\n            )\n            edge_datadict[constants.ETYPE_ID].append(\n                etype_name_idmap[etype_name]\n                * np.ones(shape=(src_ids.shape), dtype=np.int64)\n            )\n        else:\n            src_ids = np.array([])\n\n        # broadcast shape to compute the etype_id, and global_eid's later.\n        cur_tids = _broadcast_shape(\n            src_ids, rank, world_size, num_parts, False, None\n        )\n        edge_typecounts[etype_name] = cur_tids[-1][1]\n        edge_tids[etype_name] = cur_tids\n\n        for local_part_id in range(num_parts):\n            if map_partid_rank(local_part_id, world_size) == rank:\n                if len(cur_tids) > local_part_id:\n                    edge_datadict[constants.GLOBAL_TYPE_EID].append(\n                        np.arange(\n                            cur_tids[local_part_id][0],\n                            cur_tids[local_part_id][1],\n                            dtype=np.int64,\n                        )\n                    )\n                    # edge_tids[etype_name] = [(cur_tids[local_part_id][0], cur_tids[local_part_id][1])]\n                    assert len(edge_datadict[constants.GLOBAL_SRC_ID]) == len(\n                        edge_datadict[constants.GLOBAL_TYPE_EID]\n                    ), f\"Error while reading edges from the disk, local_part_id = {local_part_id}, num_parts = {num_parts}, world_size = {world_size} cur_tids = {cur_tids}\"\n\n    # stitch together to create the final data on the local machine\n    for col in [\n        constants.GLOBAL_SRC_ID,\n        constants.GLOBAL_DST_ID,\n        constants.GLOBAL_TYPE_EID,\n        constants.ETYPE_ID,\n    ]:\n        if 
len(edge_datadict[col]) > 0:\n            edge_datadict[col] = np.concatenate(edge_datadict[col])\n\n    if len(edge_datadict[constants.GLOBAL_SRC_ID]) > 0:\n        assert (\n            edge_datadict[constants.GLOBAL_SRC_ID].shape\n            == edge_datadict[constants.GLOBAL_DST_ID].shape\n        )\n        assert (\n            edge_datadict[constants.GLOBAL_DST_ID].shape\n            == edge_datadict[constants.GLOBAL_TYPE_EID].shape\n        )\n        assert (\n            edge_datadict[constants.GLOBAL_TYPE_EID].shape\n            == edge_datadict[constants.ETYPE_ID].shape\n        )\n        logging.debug(\n            f\"[Rank: {rank}] Done reading edge_file: {len(edge_datadict)}, {edge_datadict[constants.GLOBAL_SRC_ID].shape}\"\n        )\n    else:\n        assert edge_datadict[constants.GLOBAL_SRC_ID] == []\n        assert edge_datadict[constants.GLOBAL_DST_ID] == []\n        assert edge_datadict[constants.GLOBAL_TYPE_EID] == []\n\n        edge_datadict[constants.GLOBAL_SRC_ID] = np.array([], dtype=np.int64)\n        edge_datadict[constants.GLOBAL_DST_ID] = np.array([], dtype=np.int64)\n        edge_datadict[constants.GLOBAL_TYPE_EID] = np.array([], dtype=np.int64)\n        edge_datadict[constants.ETYPE_ID] = np.array([], dtype=np.int64)\n\n    logging.debug(f\"Rank: {rank} edge_feat_tids: {edge_feature_tids}\")\n\n    return (\n        node_features,\n        node_feature_tids,\n        edge_datadict,\n        edge_typecounts,\n        edge_tids,\n        edge_features,\n        edge_feature_tids,\n    )\n"
  },
  {
    "path": "tools/distpartitioning/dist_lookup.py",
    "content": "import copy\nimport logging\nimport os\n\nimport numpy as np\nimport pyarrow\nimport torch\nfrom gloo_wrapper import allgather_sizes, alltoallv_cpu\nfrom pyarrow import csv\nfrom utils import map_partid_rank\n\n\nclass DistLookupService:\n    \"\"\"\n    This is an implementation of a Distributed Lookup Service to provide the following\n    services to its users. Map 1) global node-ids to partition-ids, and 2) global node-ids\n    to shuffle global node-ids (contiguous, within each node for a give node_type and across\n    all the partitions)\n\n    This services initializes itself with the node-id to partition-id mappings, which are inputs\n    to this service. The node-id to partition-id  mappings are assumed to be in one file for each\n    node type. These node-id-to-partition-id mappings are split within the service processes so that\n    each process ends up with a contiguous chunk. It first divides the no of mappings (node-id to\n    partition-id) for each node type into equal chunks across all the service processes. So each\n    service process will be thse owner of a set of node-id-to-partition-id mappings. 
This class\n    has two functions which are as follows:\n\n    1) `get_partition_ids` function which returns the node-id to partition-id mappings to the user\n    2) `get_shuffle_nids` function which returns the node-id to shuffle-node-id mapping to the user\n\n    Parameters:\n    -----------\n    input_dir : string\n        string representing the input directory where the node-type partition-id\n        files are located\n    ntype_names : list of strings\n        list of strings which are used to read files located within the input_dir\n        directory and these files contents are partition-id's for the node-ids which\n        are of a particular node type\n    id_map : dgl.distributed.id_map instance\n        this id_map is used to retrieve ntype-ids, node type ids, and type_nids, per type\n        node ids, for any given global node id\n    rank : integer\n        integer indicating the rank of a given process\n    world_size : integer\n        integer indicating the total no. of processes\n    num_parts : integer\n        integer representing the no. 
of partitions\n    \"\"\"\n\n    def __init__(self, input_dir, ntype_names, rank, world_size, num_parts):\n        assert os.path.isdir(input_dir)\n        assert ntype_names is not None\n        assert len(ntype_names) > 0\n\n        # These lists are indexed by ntype_ids.\n        type_nid_begin = []\n        type_nid_end = []\n        partid_list = []\n        ntype_count = []\n        ntypes = []\n\n        # Iterate over the node types and extract the partition id mappings.\n        for ntype in ntype_names:\n\n            filename = f\"{ntype}.txt\"\n            logging.debug(\n                f\"[Rank: {rank}] Reading file: {os.path.join(input_dir, filename)}\"\n            )\n\n            read_options = pyarrow.csv.ReadOptions(\n                use_threads=True,\n                block_size=4096,\n                autogenerate_column_names=True,\n            )\n            parse_options = pyarrow.csv.ParseOptions(delimiter=\" \")\n            ntype_partids = []\n            with pyarrow.csv.open_csv(\n                os.path.join(input_dir, \"{}.txt\".format(ntype)),\n                read_options=read_options,\n                parse_options=parse_options,\n            ) as reader:\n                for next_chunk in reader:\n                    if next_chunk is None:\n                        break\n                    next_table = pyarrow.Table.from_batches([next_chunk])\n                    ntype_partids.append(next_table[\"f0\"].to_numpy())\n\n            ntype_partids = np.concatenate(ntype_partids)\n            count = len(ntype_partids)\n            ntype_count.append(count)\n            ntypes.append(ntype)\n\n            # Each rank assumes a contiguous set of partition-ids which are equally split\n            # across all the processes.\n            split_size = np.ceil(count / np.int64(world_size)).astype(np.int64)\n            start, end = (\n                np.int64(rank) * split_size,\n                np.int64(rank + 1) * split_size,\n            
)\n            if rank == (world_size - 1):\n                end = count\n            type_nid_begin.append(start)\n            type_nid_end.append(end)\n\n            # Slice the partition-ids which belong to the current instance.\n            partid_list.append(copy.deepcopy(ntype_partids[start:end]))\n\n            # Explicitly release the array read from the file.\n            del ntype_partids\n\n        logging.debug(\n            f\"[Rank: {rank}] ntypeid begin - {type_nid_begin} - {type_nid_end}\"\n        )\n\n        # Store all the information in the object instance variable.\n        self.type_nid_begin = np.array(type_nid_begin, dtype=np.int64)\n        self.type_nid_end = np.array(type_nid_end, dtype=np.int64)\n        self.partid_list = partid_list\n        self.ntype_count = np.array(ntype_count, dtype=np.int64)\n        self.ntypes = ntypes\n        self.rank = rank\n        self.world_size = world_size\n        self.num_parts = num_parts\n\n    def set_idMap(self, id_map):\n        self.id_map = id_map\n\n    def get_partition_ids(self, agg_global_nids):\n        \"\"\"\n        This function is used to get the partition-ids for a given set of global node ids\n\n        global_nids <-> partition-ids mappings are deterministically  distributed across\n        all the participating processes, within the service. A contiguous global-nids\n        (ntype-ids, per-type-nids) are stored within each process and this is determined\n        by the total no. of nodes of a given ntype-id and the rank of the process.\n\n        Process, where the global_nid <-> partition-id mapping is stored can be easily computed\n        as described above. Once this is determined we perform an alltoallv to send the request.\n        On the receiving side, each process receives a set of global_nids and retrieves corresponding\n        partition-ids using locally stored lookup tables. 
It builds responses to all the other\n        processes and performs alltoallv.\n\n        Once the response, partition-ids, is received, they are re-ordered corresponding to the\n        incoming global-nids order and returns to the caller.\n\n        Parameters:\n        -----------\n        self : instance of this class\n            instance of this class, which is passed by the runtime implicitly\n\n        agg_global_nids : numpy array\n            an array of aggregated global node-ids for which partition-ids are\n            to be retrieved by the distributed lookup service.\n\n        Returns:\n        --------\n        list of integers :\n            list of integers, which are the partition-ids of the global-node-ids (which is the\n            function argument)\n        \"\"\"\n        CHUNK_SIZE = 200 * 1000 * 1000\n        # Determine the no. of times each process has to send alltoall messages.\n        local_rows = agg_global_nids.shape[0]\n        all_sizes = allgather_sizes(\n            [local_rows], self.world_size, self.num_parts, return_sizes=True\n        )\n        max_count = np.amax(all_sizes)\n\n        if max_count <= 0:\n            logging.debug(\n                f\"[Rank: {self.rank}] No process has global_nids to process !!!\"\n            )\n            return\n\n        num_splits = np.ceil(max_count / CHUNK_SIZE).astype(np.uint16)\n        LOCAL_CHUNK_SIZE = np.ceil(local_rows / num_splits).astype(np.int64)\n        agg_partition_ids = []\n\n        logging.debug(\n            f\"[Rank: {self.rank}] BatchSize: {CHUNK_SIZE}, \\\n                            max_count: {max_count}, \\\n                            splits: {num_splits}, \\\n                            rows: {agg_global_nids.shape}, \\\n                            local batch_size: {LOCAL_CHUNK_SIZE}\"\n        )\n\n        for split in range(num_splits):\n            # Compute the global_nids for this iteration\n            global_nids = agg_global_nids[\n                
split * LOCAL_CHUNK_SIZE : (split + 1) * LOCAL_CHUNK_SIZE\n            ]\n\n            # Find the process where global_nid --> partition-id(owner) is stored.\n            if len(global_nids) > 0:\n                ntype_ids, type_nids = self.id_map(global_nids)\n                ntype_ids, type_nids = ntype_ids.numpy(), type_nids.numpy()\n            else:\n                ntype_ids = np.array([], dtype=np.int64)\n                type_nids = np.array([], dtype=np.int64)\n\n            assert len(ntype_ids) == len(global_nids)\n\n            # For each node-type, the per-type-node-id <-> partition-id mappings are\n            # stored as contiguous chunks by this lookup service.\n            # The no. of these mappings stored by each process, in the lookup service, are\n            # equally split among all the processes in the lookup service, deterministically.\n            typeid_counts = self.ntype_count[ntype_ids]\n            chunk_sizes = np.ceil(typeid_counts / self.world_size).astype(\n                np.int64\n            )\n            service_owners = np.floor_divide(type_nids, chunk_sizes).astype(\n                np.int64\n            )\n\n            # Now `service_owners` is a list of ranks (process-ids) which own the corresponding\n            # global-nid <-> partition-id mapping.\n\n            # Split the input global_nids into a list of lists where each list will be\n            # sent to the respective rank/process\n            # We also need to store the indices, in the indices_list, so that we can re-order\n            # the final result (partition-ids) in the same order as the global-nids (function argument)\n            send_list = []\n            indices_list = []\n            for idx in range(self.world_size):\n                idxes = np.where(service_owners == idx)\n                ll = global_nids[idxes[0]]\n                send_list.append(torch.from_numpy(ll))\n                indices_list.append(idxes[0])\n            assert 
len(np.concatenate(indices_list)) == len(global_nids)\n            assert np.all(\n                np.sort(np.concatenate(indices_list))\n                == np.arange(len(global_nids))\n            )\n\n            # Send the request to everyone else.\n            # As a result of this operation, the current process also receives a list of lists\n            # from all the other processes.\n            # These lists are global-node-ids whose global-node-ids <-> partition-id mappings\n            # are owned/stored by the current process\n            owner_req_list = alltoallv_cpu(\n                self.rank, self.world_size, send_list\n            )\n\n            # Create the response list here for each of the request list received in the previous\n            # step. Populate the respective partition-ids in this response lists appropriately\n            out_list = []\n            for idx in range(self.world_size):\n                if owner_req_list[idx] is None:\n                    out_list.append(torch.empty((0,), dtype=torch.int64))\n                    continue\n                # Get the node_type_ids and per_type_nids for the incoming global_nids.\n                ntype_ids, type_nids = self.id_map(owner_req_list[idx].numpy())\n                ntype_ids, type_nids = ntype_ids.numpy(), type_nids.numpy()\n\n                # Lists to store partition-ids for the incoming global-nids.\n                type_id_lookups = []\n                local_order_idx = []\n\n                # Now iterate over all the node_types and accumulate all the partition-ids\n                # since all the partition-ids are based on the node_type order... 
they\n                # must be re-ordered as per the order of the input, which may be different.\n                for tid in range(len(self.partid_list)):\n                    cond = ntype_ids == tid\n                    local_order_idx.append(np.where(cond)[0])\n                    global_type_nids = type_nids[cond]\n                    if len(global_type_nids) <= 0:\n                        continue\n\n                    local_type_nids = (\n                        global_type_nids - self.type_nid_begin[tid]\n                    )\n\n                    assert np.all(local_type_nids >= 0)\n                    assert np.all(\n                        local_type_nids\n                        <= (\n                            self.type_nid_end[tid]\n                            + 1\n                            - self.type_nid_begin[tid]\n                        )\n                    )\n\n                    cur_owners = self.partid_list[tid][local_type_nids]\n                    type_id_lookups.append(cur_owners)\n\n                # Reorder the partition-ids, so that it agrees with the input order --\n                # which is the order in which the incoming message is received.\n                if len(type_id_lookups) <= 0:\n                    out_list.append(torch.empty((0,), dtype=torch.int64))\n                else:\n                    # Now reorder results for each request.\n                    sort_order_idx = np.argsort(np.concatenate(local_order_idx))\n                    lookups = np.concatenate(type_id_lookups)[sort_order_idx]\n                    out_list.append(torch.from_numpy(lookups))\n\n            # Send the partition-ids to their respective requesting processes.\n            owner_resp_list = alltoallv_cpu(\n                self.rank, self.world_size, out_list\n            )\n\n            # Owner_resp_list, is a list of lists of numpy arrays where each list\n            # is a list of partition-ids which the current process requested\n        
    # Now we need to re-order so that the partition-ids correspond to the\n            # global_nids which are passed into this function.\n\n            # Order according to the requesting order.\n            # Owner_resp_list is the list of owner-ids for global_nids (function argument).\n            owner_ids = [x for x in owner_resp_list if x is not None]\n            if len(owner_ids) > 0:\n                owner_ids = torch.cat(owner_ids).numpy()\n            else:\n                owner_ids = np.array([], dtype=np.int64)\n            assert len(owner_ids) == len(global_nids)\n\n            global_nids_order = np.concatenate(indices_list)\n            sort_order_idx = np.argsort(global_nids_order)\n            owner_ids = owner_ids[sort_order_idx]\n            global_nids_order = global_nids_order[sort_order_idx]\n            assert np.all(np.arange(len(global_nids)) == global_nids_order)\n\n            if len(owner_ids) > 0:\n                # Store the partition-ids for the current split\n                agg_partition_ids.append(owner_ids)\n\n        # Stitch the list of partition-ids and return to the caller\n        if len(agg_partition_ids) > 0:\n            agg_partition_ids = np.concatenate(agg_partition_ids)\n        else:\n            agg_partition_ids = np.array([], dtype=np.int64)\n        assert agg_global_nids.shape[0] == agg_partition_ids.shape[0]\n\n        # Now the owner_ids (partition-ids) which correspond to the global_nids.\n        return agg_partition_ids\n\n    def get_shuffle_nids(\n        self, global_nids, my_global_nids, my_shuffle_global_nids, world_size\n    ):\n        \"\"\"\n        This function is used to retrieve shuffle_global_nids for a given set of incoming\n        global_nids. 
Note that global_nids are of random order and will contain duplicates\n\n        This function first retrieves the partition-ids of the incoming global_nids.\n        These partition-ids which are also the ranks of processes which own the respective\n        global-nids as well as shuffle-global-nids. alltoallv is performed to send the\n        global-nids to respective ranks/partition-ids where the mapping\n        global-nids <-> shuffle-global-nid is located.\n\n        On the receiving side, once the global-nids are received associated shuffle-global-nids\n        are retrieved and an alltoallv is performed to send the responses to all the other\n        processes.\n\n        Once the responses, shuffle-global-nids, are received, they are re-ordered according\n        to the incoming global-nids order and returns to the caller.\n\n        Parameters:\n        -----------\n        self : instance of this class\n            instance of this class, which is passed by the runtime implicitly\n        global_nids : numpy array\n            an array of global node-ids for which partition-ids are to be retrieved by\n            the distributed lookup service.\n        my_global_nids: numpy ndarray\n            array of global_nids which are owned by the current partition/rank/process\n            This process has the node <-> partition id mapping\n        my_shuffle_global_nids : numpy ndarray\n            array of shuffle_global_nids which are assigned by the current process/rank\n        world_size : int\n            total no. 
of processes in the MPI_WORLD\n\n        Returns:\n        --------\n        list of integers:\n            list of shuffle_global_nids which correspond to the incoming node-ids in the\n            global_nids.\n        \"\"\"\n\n        # Get the owner_ids (partition-ids or rank).\n        owner_ids = self.get_partition_ids(global_nids)\n\n        # These owner_ids, which are also partition ids of the nodes in the\n        # input graph, are in the range 0 - (num_partitions - 1).\n        # These ids are generated using some kind of graph partitioning method.\n        # Distributed lookup service, as used by the graph partitioning\n        # pipeline, is used to store ntype-ids (also type_nids) and their\n        # mapping to the associated partition-id.\n        # These ids are split into `num_process` chunks and processes in the\n        # dist. lookup service are assigned the ownership of these chunks.\n        # The pipeline also enforces the following constraint among the\n        # pipeline input parameters: num_partitions, num_processes\n        #   num_partitions is an integer multiple of num_processes\n        #   which means each individual node in the cluster will be running\n        #   equal number of processes.\n        owner_ids = map_partid_rank(owner_ids, world_size)\n\n        # Ask these owners to supply for the shuffle_global_nids.\n        send_list = []\n        id_list = []\n        for idx in range(self.world_size):\n            cond = owner_ids == idx\n            idxes = np.where(cond)\n            ll = global_nids[idxes[0]]\n            send_list.append(torch.from_numpy(ll))\n            id_list.append(idxes[0])\n\n        assert len(np.concatenate(id_list)) == len(global_nids)\n        cur_global_nids = alltoallv_cpu(self.rank, self.world_size, send_list)\n\n        # At this point, current process received a list of lists each containing\n        # a list of global-nids whose corresponding shuffle_global_nids are located\n        # in 
the current process.\n        shuffle_nids_list = []\n        for idx in range(self.world_size):\n            if cur_global_nids[idx] is None:\n                shuffle_nids_list.append(torch.empty((0,), dtype=torch.int64))\n                continue\n\n            uniq_ids, inverse_idx = np.unique(\n                cur_global_nids[idx], return_inverse=True\n            )\n            common, idx1, idx2 = np.intersect1d(\n                uniq_ids,\n                my_global_nids,\n                assume_unique=True,\n                return_indices=True,\n            )\n            assert len(common) == len(uniq_ids)\n\n            req_shuffle_global_nids = my_shuffle_global_nids[idx2][inverse_idx]\n            assert len(req_shuffle_global_nids) == len(cur_global_nids[idx])\n            shuffle_nids_list.append(torch.from_numpy(req_shuffle_global_nids))\n\n        # Send the shuffle-global-nids to their respective ranks.\n        mapped_global_nids = alltoallv_cpu(\n            self.rank, self.world_size, shuffle_nids_list\n        )\n        for idx in range(len(mapped_global_nids)):\n            if mapped_global_nids[idx] == None:\n                mapped_global_nids[idx] = torch.empty((0,), dtype=torch.int64)\n\n        # Reorder to match global_nids (function parameter).\n        global_nids_order = np.concatenate(id_list)\n        shuffle_global_nids = torch.cat(mapped_global_nids).numpy()\n        assert len(shuffle_global_nids) == len(global_nids)\n\n        sorted_idx = np.argsort(global_nids_order)\n        shuffle_global_nids = shuffle_global_nids[sorted_idx]\n        global_nids_ordered = global_nids_order[sorted_idx]\n        assert np.all(global_nids_ordered == np.arange(len(global_nids)))\n\n        return shuffle_global_nids\n"
  },
  {
    "path": "tools/distpartitioning/globalids.py",
    "content": "import itertools\nimport operator\n\nimport constants\n\nimport numpy as np\nimport torch\nfrom dist_lookup import DistLookupService\nfrom gloo_wrapper import allgather_sizes, alltoallv_cpu\nfrom utils import memory_snapshot\n\n\ndef get_shuffle_global_nids(rank, world_size, global_nids_ranks, node_data):\n    \"\"\"\n    For nodes which are not owned by the current rank, whose global_nid <-> shuffle_global-nid mapping\n    is not present at the current rank, this function retrieves their shuffle_global_ids from the owner rank\n\n    Parameters:\n    -----------\n    rank : integer\n        rank of the process\n    world_size : integer\n        total no. of ranks configured\n    global_nids_ranks : list\n        list of numpy arrays (of global_nids), index of the list is the rank of the process\n                    where global_nid <-> shuffle_global_nid mapping is located.\n    node_data : dictionary\n        node_data is a dictionary with keys as column names and values as numpy arrays\n\n    Returns:\n    --------\n    numpy ndarray\n        where the column-0 are global_nids and column-1 are shuffle_global_nids which are retrieved\n        from other processes.\n    \"\"\"\n    # build a list of sizes (lengths of lists)\n    global_nids_ranks = [torch.from_numpy(x) for x in global_nids_ranks]\n    recv_nodes = alltoallv_cpu(rank, world_size, global_nids_ranks)\n\n    # Use node_data to lookup global id to send over.\n    send_nodes = []\n    for proc_i_nodes in recv_nodes:\n        # list of node-ids to lookup\n        if proc_i_nodes is not None:\n            global_nids = proc_i_nodes.numpy()\n            if len(global_nids) != 0:\n                common, ind1, ind2 = np.intersect1d(\n                    node_data[constants.GLOBAL_NID],\n                    global_nids,\n                    return_indices=True,\n                )\n                shuffle_global_nids = node_data[constants.SHUFFLE_GLOBAL_NID][\n                    ind1\n         
       ]\n                send_nodes.append(\n                    torch.from_numpy(shuffle_global_nids).type(\n                        dtype=torch.int64\n                    )\n                )\n            else:\n                send_nodes.append(torch.empty((0), dtype=torch.int64))\n        else:\n            send_nodes.append(torch.empty((0), dtype=torch.int64))\n\n    # send receive global-ids\n    recv_shuffle_global_nids = alltoallv_cpu(rank, world_size, send_nodes)\n    shuffle_global_nids = np.concatenate(\n        [x.numpy() if x is not None else [] for x in recv_shuffle_global_nids]\n    )\n    global_nids = np.concatenate([x for x in global_nids_ranks])\n    ret_val = np.column_stack([global_nids, shuffle_global_nids])\n    return ret_val\n\n\ndef lookup_shuffle_global_nids_edges(\n    rank, world_size, num_parts, edge_data, id_lookup, node_data\n):\n    \"\"\"\n    This function is a helper function used to lookup shuffle-global-nids for a given set of\n    global-nids using a distributed lookup service.\n\n    Parameters:\n    -----------\n    rank : integer\n        rank of the process\n    world_size : integer\n        total number of processes used in the process group\n    num_parts : integer\n        total number of output graph partitions\n    edge_data : dictionary\n        edge_data is a dictionary with keys as column names and values as numpy arrays representing\n        all the edges present in the current graph partition\n    id_lookup : instance of DistLookupService class\n        instance of a distributed lookup service class which is used to retrieve partition-ids and\n        shuffle-global-nids for any given set of global-nids\n    node_data : dictionary\n        node_data is a dictionary with keys as column names and values as numpy arrays 
representing all the\n        edges present in the current graph partition\n    \"\"\"\n    # Make sure that the outgoing message size does not exceed 2GB in size.\n    # Even though gloo can handle upto 10GB size of data in the outgoing messages,\n    # it needs additional memory to store temporary information into the buffers which will increase\n    # the memory needs of the process.\n    MILLION = 1000 * 1000\n    BATCH_SIZE = 250 * MILLION\n    memory_snapshot(\"GlobalToShuffleIDMapBegin: \", rank)\n\n    local_nids = []\n    local_shuffle_nids = []\n    for local_part_id in range(num_parts // world_size):\n        local_nids.append(\n            node_data[constants.GLOBAL_NID + \"/\" + str(local_part_id)]\n        )\n        local_shuffle_nids.append(\n            node_data[constants.SHUFFLE_GLOBAL_NID + \"/\" + str(local_part_id)]\n        )\n\n    local_nids = np.concatenate(local_nids)\n    local_shuffle_nids = np.concatenate(local_shuffle_nids)\n\n    for local_part_id in range(num_parts // world_size):\n        node_list = edge_data[\n            constants.GLOBAL_SRC_ID + \"/\" + str(local_part_id)\n        ]\n\n        # Determine the no. 
of times each process has to send alltoall messages.\n        all_sizes = allgather_sizes(\n            [node_list.shape[0]], world_size, num_parts, return_sizes=True\n        )\n        max_count = np.amax(all_sizes)\n        num_splits = max_count // BATCH_SIZE + 1\n\n        # Split the message into batches and send.\n        splits = np.array_split(node_list, num_splits)\n        shuffle_mappings = []\n        for item in splits:\n            shuffle_ids = id_lookup.get_shuffle_nids(\n                item, local_nids, local_shuffle_nids, world_size\n            )\n            shuffle_mappings.append(shuffle_ids)\n\n        shuffle_ids = np.concatenate(shuffle_mappings)\n        assert shuffle_ids.shape[0] == node_list.shape[0]\n        edge_data[\n            constants.SHUFFLE_GLOBAL_SRC_ID + \"/\" + str(local_part_id)\n        ] = shuffle_ids\n\n        # Destination end points of edges are owned by the current node and therefore\n        # should have corresponding SHUFFLE_GLOBAL_NODE_IDs.\n        # Here retrieve SHUFFLE_GLOBAL_NODE_IDs for the destination end points of local edges.\n        uniq_ids, inverse_idx = np.unique(\n            edge_data[constants.GLOBAL_DST_ID + \"/\" + str(local_part_id)],\n            return_inverse=True,\n        )\n        common, idx1, idx2 = np.intersect1d(\n            uniq_ids,\n            node_data[constants.GLOBAL_NID + \"/\" + str(local_part_id)],\n            assume_unique=True,\n            return_indices=True,\n        )\n        assert len(common) == len(uniq_ids)\n\n        edge_data[\n            constants.SHUFFLE_GLOBAL_DST_ID + \"/\" + str(local_part_id)\n        ] = node_data[constants.SHUFFLE_GLOBAL_NID + \"/\" + str(local_part_id)][\n            idx2\n        ][\n            inverse_idx\n        ]\n        assert len(\n            edge_data[\n                constants.SHUFFLE_GLOBAL_DST_ID + \"/\" + str(local_part_id)\n            ]\n        ) == len(edge_data[constants.GLOBAL_DST_ID + \"/\" + 
str(local_part_id)])\n\n    memory_snapshot(\"GlobalToShuffleIDMap_AfterLookupServiceCalls: \", rank)\n    return edge_data\n\n\ndef assign_shuffle_global_nids_nodes(rank, world_size, num_parts, node_data):\n    \"\"\"\n    Utility function to assign shuffle global ids to nodes at a given rank\n    node_data gets converted from [ntype, global_type_nid, global_nid]\n    to [shuffle_global_nid, ntype, global_type_nid, global_nid, part_local_type_nid]\n    where shuffle_global_nid : global id of the node after data shuffle\n            ntype : node-type as read from xxx_nodes.txt\n            global_type_nid : node-type-id as read from xxx_nodes.txt\n            global_nid : node-id as read from xxx_nodes.txt, implicitly\n                            this is the line no. in the file\n            part_local_type_nid : type_nid assigned by the current rank within its scope\n\n    Parameters:\n    -----------\n    rank : integer\n        rank of the process\n    world_size : integer\n        total number of processes used in the process group\n    num_parts : integer\n        total number of output graph partitions\n    node_data : dictionary\n        node_data is a dictionary with keys as column names and values as numpy arrays\n    \"\"\"\n    # Compute prefix sum to determine node-id offsets\n    local_row_counts = []\n    for local_part_id in range(num_parts // world_size):\n        local_row_counts.append(\n            node_data[constants.GLOBAL_NID + \"/\" + str(local_part_id)].shape[0]\n        )\n\n    # Perform allgather to compute the local offsets.\n    prefix_sum_nodes = allgather_sizes(local_row_counts, world_size, num_parts)\n\n    for local_part_id in range(num_parts // world_size):\n        shuffle_global_nid_start = prefix_sum_nodes[\n            rank + (local_part_id * world_size)\n        ]\n        shuffle_global_nid_end = prefix_sum_nodes[\n            rank + 1 + (local_part_id * world_size)\n        ]\n        shuffle_global_nids = np.arange(\n       
     shuffle_global_nid_start, shuffle_global_nid_end, dtype=np.int64\n        )\n        node_data[\n            constants.SHUFFLE_GLOBAL_NID + \"/\" + str(local_part_id)\n        ] = shuffle_global_nids\n\n\ndef assign_shuffle_global_nids_edges(rank, world_size, num_parts, edge_data):\n    \"\"\"\n    Utility function to assign shuffle_global_eids to edges\n    edge_data gets converted from [global_src_nid, global_dst_nid, global_type_eid, etype]\n    to [shuffle_global_src_nid, shuffle_global_dst_nid, global_src_nid, global_dst_nid, global_type_eid, etype]\n\n    Parameters:\n    -----------\n    rank : integer\n        rank of the current process\n    world_size : integer\n        total count of processes in execution\n    num_parts : integer\n        total number of output graph partitions\n    edge_data : numpy ndarray\n        edge data as read from xxx_edges.txt file\n\n    Returns:\n    --------\n    integer\n        shuffle_global_eid_start, which indicates the starting value from which shuffle_global-ids are assigned to edges\n        on this rank\n    \"\"\"\n    # get prefix sum of edge counts per rank to locate the starting point\n    # from which global-ids to edges are assigned in the current rank\n    local_row_counts = []\n    for local_part_id in range(num_parts // world_size):\n        local_row_counts.append(\n            edge_data[constants.GLOBAL_SRC_ID + \"/\" + str(local_part_id)].shape[\n                0\n            ]\n        )\n\n    shuffle_global_eid_offset = []\n    prefix_sum_edges = allgather_sizes(local_row_counts, world_size, num_parts)\n    for local_part_id in range(num_parts // world_size):\n        shuffle_global_eid_start = prefix_sum_edges[\n            rank + (local_part_id * world_size)\n        ]\n        shuffle_global_eid_end = prefix_sum_edges[\n            rank + 1 + (local_part_id * world_size)\n        ]\n        shuffle_global_eids = np.arange(\n            shuffle_global_eid_start, shuffle_global_eid_end, 
dtype=np.int64\n        )\n        edge_data[\n            constants.SHUFFLE_GLOBAL_EID + \"/\" + str(local_part_id)\n        ] = shuffle_global_eids\n        shuffle_global_eid_offset.append(shuffle_global_eid_start)\n\n    return shuffle_global_eid_offset\n"
  },
  {
    "path": "tools/distpartitioning/gloo_wrapper.py",
    "content": "import numpy as np\nimport torch\nimport torch.distributed as dist\n\n\ndef allgather_sizes(send_data, world_size, num_parts, return_sizes=False):\n    \"\"\"\n    Perform all gather on list lengths, used to compute prefix sums\n    to determine the offsets on each ranks. This is used to allocate\n    global ids for edges/nodes on each ranks.\n\n    Parameters\n    ----------\n    send_data : numpy array\n        Data on which allgather is performed.\n    world_size : integer\n        No. of processes configured for execution\n    num_parts : integer\n        No. of output graph partitions\n    return_sizes : bool\n        Boolean flag to indicate whether to return raw sizes from each process\n        or perform prefix sum on the raw sizes.\n\n    Returns :\n    ---------\n        numpy array\n            array with the prefix sum\n    \"\"\"\n\n    # Assert on the world_size, num_parts\n    assert (num_parts % world_size) == 0\n\n    # compute the length of the local data\n    send_length = len(send_data)\n    out_tensor = torch.as_tensor(send_data, dtype=torch.int64)\n    in_tensor = [\n        torch.zeros(send_length, dtype=torch.int64) for _ in range(world_size)\n    ]\n\n    # all_gather message\n    dist.all_gather(in_tensor, out_tensor)\n\n    # Return on the raw sizes from each process\n    if return_sizes:\n        return torch.cat(in_tensor).numpy()\n\n    # gather sizes in on array to return to the invoking function\n    rank_sizes = np.zeros(num_parts + 1, dtype=np.int64)\n    part_counts = torch.cat(in_tensor).numpy()\n\n    count = rank_sizes[0]\n    idx = 1\n    for local_part_id in range(num_parts // world_size):\n        for r in range(world_size):\n            count += part_counts[r * (num_parts // world_size) + local_part_id]\n            rank_sizes[idx] = count\n            idx += 1\n\n    return rank_sizes\n\n\ndef __alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):\n    \"\"\"\n    Each process scatters 
list of input tensors to all processes in a cluster\n    and return gathered list of tensors in output list. The tensors should have the same shape.\n\n    Parameters\n    ----------\n    rank : int\n        The rank of current worker\n    world_size : int\n        The size of the entire\n    output_tensor_list : List of tensor\n        The received tensors\n    input_tensor_list : List of tensor\n        The tensors to exchange\n    \"\"\"\n    input_tensor_list = [\n        tensor.to(torch.device(\"cpu\")) for tensor in input_tensor_list\n    ]\n    # TODO(#5002): As Boolean data is not supported in\n    # ``torch.distributed.scatter()``, we convert boolean into uint8 before\n    # scatter and convert it back afterwards.\n    dtypes = [t.dtype for t in input_tensor_list]\n    for i, dtype in enumerate(dtypes):\n        if dtype == torch.bool:\n            input_tensor_list[i] = input_tensor_list[i].to(torch.int8)\n            output_tensor_list[i] = output_tensor_list[i].to(torch.int8)\n    for i in range(world_size):\n        dist.scatter(\n            output_tensor_list[i], input_tensor_list if i == rank else [], src=i\n        )\n    # Convert back to original dtype\n    for i, dtype in enumerate(dtypes):\n        if dtype == torch.bool:\n            input_tensor_list[i] = input_tensor_list[i].to(dtype)\n            output_tensor_list[i] = output_tensor_list[i].to(dtype)\n\n\ndef alltoallv_cpu(rank, world_size, input_tensor_list, retain_nones=True):\n    \"\"\"\n    Wrapper function to providing the alltoallv functionality by using underlying alltoall\n    messaging primitive. This function, in its current implementation, supports exchanging\n    messages of arbitrary dimensions and is not tied to the user of this function.\n\n    This function pads all input tensors, except one, so that all the messages are of the same\n    size. 
Once the messages are padded, It first sends a vector whose first two elements are\n    1) actual message size along first dimension, and 2) Message size along first dimension\n    which is used for communication. The rest of the dimensions are assumed to be same across\n    all the input tensors. After receiving the message sizes, the receiving end will create buffers\n    of appropriate sizes. And then slices the received messages to remove the added padding, if any,\n    and returns to the caller.\n\n    Parameters:\n    -----------\n    rank : int\n        The rank of current worker\n    world_size : int\n        The size of the entire\n    input_tensor_list : List of tensor\n        The tensors to exchange\n    retain_nones : bool\n        Indicates whether to retain ``None`` data in returned value.\n\n    Returns:\n    --------\n    list :\n        list of tensors received from other processes during alltoall message\n\n    \"\"\"\n    # ensure len of input_tensor_list is same as the world_size.\n    assert input_tensor_list != None\n    assert len(input_tensor_list) == world_size\n\n    # ensure that all the tensors in the input_tensor_list are of same size.\n    sizes = [list(x.size()) for x in input_tensor_list]\n    for idx in range(1, len(sizes)):\n        assert len(sizes[idx - 1]) == len(\n            sizes[idx]\n        )  # no. 
of dimensions should be same\n        assert (\n            input_tensor_list[idx - 1].dtype == input_tensor_list[idx].dtype\n        )  # dtype should be same\n        assert (\n            sizes[idx - 1][1:] == sizes[idx][1:]\n        )  # except first dimension remaining dimensions should all be the same\n\n    # decide how much to pad.\n    # always use the first-dimension for padding.\n    ll = [x[0] for x in sizes]\n\n    # dims of the padding needed, if any\n    # these dims are used for padding purposes.\n    diff_dims = [[np.amax(ll) - l[0]] + l[1:] for l in sizes]\n\n    # pad the actual message\n    input_tensor_list = [\n        torch.cat((x, torch.zeros(diff_dims[idx]).type(x.dtype)))\n        for idx, x in enumerate(input_tensor_list)\n    ]\n\n    # send useful message sizes to all\n    send_counts = []\n    recv_counts = []\n    for idx in range(world_size):\n        # send a vector, of atleast 3 elements, [a, b, ....] where\n        # a = useful message dim, b = actual message outgoing message size along the first dimension\n        # and remaining elements are the remaining dimensions of the tensor\n        send_counts.append(\n            torch.from_numpy(\n                np.array([sizes[idx][0]] + [np.amax(ll)] + sizes[idx][1:])\n            ).type(torch.int64)\n        )\n        recv_counts.append(\n            torch.zeros((1 + len(sizes[idx])), dtype=torch.int64)\n        )\n    __alltoall_cpu(rank, world_size, recv_counts, send_counts)\n\n    # allocate buffers for receiving message\n    output_tensor_list = []\n    recv_counts = [tsize.numpy() for tsize in recv_counts]\n    for idx, tsize in enumerate(recv_counts):\n        output_tensor_list.append(\n            torch.zeros(tuple(tsize[1:])).type(input_tensor_list[idx].dtype)\n        )\n\n    # send actual message itself.\n    __alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list)\n\n    # extract un-padded message from the output_tensor_list and return it\n    
return_vals = []\n    for s, t in zip(recv_counts, output_tensor_list):\n        if s[0] == 0:\n            if retain_nones:\n                return_vals.append(None)\n        else:\n            return_vals.append(t[0 : s[0]])\n    return return_vals\n\n\ndef gather_metadata_json(metadata, rank, world_size):\n    \"\"\"\n    Gather an object (json schema on `rank`)\n    Parameters:\n    -----------\n    metadata : json dictionary object\n        json schema formed on each rank with graph level data.\n        This will be used as input to the distributed training in the later steps.\n    Returns:\n    --------\n    list : list of json dictionary objects\n        The result of the gather operation, which is the list of json dicitonary\n        objects from each rank in the world\n    \"\"\"\n\n    # Populate input obj and output obj list on rank-0 and non-rank-0 machines\n    input_obj = None if rank == 0 else metadata\n    output_objs = [None for _ in range(world_size)] if rank == 0 else None\n\n    # invoke the gloo method to perform gather on rank-0\n    dist.gather_object(input_obj, output_objs, dst=0)\n    return output_objs\n"
  },
  {
    "path": "tools/distpartitioning/parmetis_postprocess.py",
    "content": "import argparse\nimport logging\nimport os\nimport platform\nimport sys\nfrom pathlib import Path\n\nimport constants\n\nimport numpy as np\nimport pyarrow\nimport pyarrow.csv as csv\nfrom partition_algo.base import dump_partition_meta, PartitionMeta\nfrom utils import get_idranges, get_node_types, read_json\n\n\ndef post_process(params):\n    \"\"\"Auxiliary function to read the parmetis output file and generate\n    metis partition-id files, sorted, per node-type. These files are used\n    by the dist. graph partitioning pipeline for further processing.\n\n    Parameters:\n    -----------\n    params : argparser object\n        argparser object to capture command line options passed to the\n        executable\n    \"\"\"\n    logging.info(\"Starting to process parmetis output.\")\n\n    logging.info(params.postproc_input_dir)\n    logging.info(params.schema_file)\n    logging.info(params.parmetis_output_file)\n    assert os.path.isfile(\n        os.path.join(params.postproc_input_dir, params.schema_file)\n    )\n    assert os.path.isfile(params.parmetis_output_file)\n    schema = read_json(\n        os.path.join(params.postproc_input_dir, params.schema_file)\n    )\n\n    metis_df = csv.read_csv(\n        params.parmetis_output_file,\n        read_options=pyarrow.csv.ReadOptions(autogenerate_column_names=True),\n        parse_options=pyarrow.csv.ParseOptions(delimiter=\" \"),\n    )\n    global_nids = metis_df[\"f0\"].to_numpy()\n    partition_ids = metis_df[\"f1\"].to_numpy()\n    num_parts = np.unique(partition_ids).size\n\n    sort_idx = np.argsort(global_nids)\n    global_nids = global_nids[sort_idx]\n    partition_ids = partition_ids[sort_idx]\n\n    ntypes_ntypeid_map, ntypes, ntid_ntype_map = get_node_types(schema)\n    type_nid_dict, ntype_gnid_offset = get_idranges(\n        schema[constants.STR_NODE_TYPE],\n        dict(\n            zip(\n                schema[constants.STR_NODE_TYPE],\n                
schema[constants.STR_NUM_NODES_PER_TYPE],\n            )\n        ),\n    )\n\n    outdir = Path(params.partitions_dir)\n    os.makedirs(outdir, exist_ok=True)\n    for ntype_id, ntype_name in ntid_ntype_map.items():\n        start = ntype_gnid_offset[ntype_name][0, 0]\n        end = ntype_gnid_offset[ntype_name][0, 1]\n        out_data = partition_ids[start:end]\n\n        out_file = os.path.join(outdir, f\"{ntype_name}.txt\")\n        options = csv.WriteOptions(include_header=False, delimiter=\" \")\n\n        csv.write_csv(\n            pyarrow.Table.from_arrays([out_data], names=[\"partition-ids\"]),\n            out_file,\n            options,\n        )\n        logging.info(f\"Generated {out_file}\")\n\n    # generate partition meta file.\n    part_meta = PartitionMeta(\n        version=\"1.0.0\", num_parts=num_parts, algo_name=\"metis\"\n    )\n    dump_partition_meta(part_meta, os.path.join(outdir, \"partition_meta.json\"))\n\n    logging.info(\"Done processing parmetis output\")\n\n\nif __name__ == \"__main__\":\n    \"\"\"Main function to convert the output of parmetis into metis partitions\n    which are accepted by graph partitioning pipeline.\n\n    ParMETIS currently generates one output file, which is in the following format:\n    <global-node-id> <partition-id>\n\n    Graph partitioing pipeline, per the new dataset file format rules expects the\n    metis partitions to be in the following format:\n    No. of files will be equal to the no. 
of node-types in the graph\n    Each file will have one-number/line which is <partition-id>.\n\n    Example usage:\n    --------------\n    python parmetis_postprocess.py\n        --input_file <metis-partitions-file>\n        --output-dir <directory where the output files are stored>\n        --schema <schema-file-path>\n    \"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"PostProcessing the ParMETIS\\\n        output for partitioning pipeline\"\n    )\n    parser.add_argument(\n        \"--postproc_input_dir\",\n        required=True,\n        type=str,\n        help=\"Base directory for post processing step.\",\n    )\n    parser.add_argument(\n        \"--schema_file\",\n        required=True,\n        type=str,\n        help=\"The schema of the input graph\",\n    )\n    parser.add_argument(\n        \"--parmetis_output_file\",\n        required=True,\n        type=str,\n        help=\"ParMETIS output file\",\n    )\n    parser.add_argument(\n        \"--partitions_dir\",\n        required=True,\n        type=str,\n        help=\"The output\\\n        will be files (with metis partition ids) and each file corresponds to\\\n        a node-type in the input graph dataset.\",\n    )\n    params = parser.parse_args()\n\n    # Configure logging.\n    logging.basicConfig(\n        level=\"INFO\",\n        format=f\"[{platform.node()} \\\n        %(levelname)s %(asctime)s PID:%(process)d] %(message)s\",\n    )\n\n    # Invoke the function for post processing\n    post_process(params)\n"
  },
  {
    "path": "tools/distpartitioning/parmetis_preprocess.py",
    "content": "import argparse\nimport logging\nimport os\nimport platform\nfrom pathlib import Path\n\nimport array_readwriter\n\nimport constants\n\nimport numpy as np\nimport pyarrow\nimport pyarrow.csv as csv\nfrom utils import (\n    generate_read_list,\n    generate_roundrobin_read_list,\n    get_idranges,\n    get_node_types,\n    read_json,\n)\n\n\ndef get_proc_info():\n    \"\"\"Helper function to get the rank from the\n    environment when `mpirun` is used to run this python program.\n\n    Please note that for mpi(openmpi) installation the rank is retrieved from the\n    environment using OMPI_COMM_WORLD_RANK. For mpich it is\n    retrieved from the environment using PMI_RANK.\n\n    Returns:\n    --------\n    integer :\n        Rank of the current process.\n    \"\"\"\n    env_variables = dict(os.environ)\n    # mpich\n    if \"PMI_RANK\" in env_variables:\n        return int(env_variables[\"PMI_RANK\"])\n    # openmpi\n    elif \"OMPI_COMM_WORLD_RANK\" in env_variables:\n        return int(env_variables[\"OMPI_COMM_WORLD_RANK\"])\n    else:\n        return 0\n\n\ndef get_world_size():\n    \"\"\"Helper function to get the world size from the\n    environment when `mpirun` is used to run this python program.\n\n    Returns:\n    --------\n    integer :\n        Numer of processes created by the executor that created this process.\n    \"\"\"\n    env_variables = dict(os.environ)\n    # mpich\n    if \"PMI_SIZE\" in env_variables:\n        return int(env_variables[\"PMI_SIZE\"])\n    # openmpi\n    elif \"OMPI_COMM_WORLD_SIZE\" in env_variables:\n        return int(env_variables[\"OMPI_COMM_WORLD_SIZE\"])\n    else:\n        return 1\n\n\ndef gen_edge_files(rank, schema_map, params):\n    \"\"\"Function to create edges files to be consumed by ParMETIS\n    for partitioning purposes.\n\n    This function creates the edge files and each of these will have the\n    following format (meaning each line of these file is of the following format)\n    
<global_src_id> <global_dst_id>\n\n    Here ``global`` prefix means that globally unique identifier assigned each node\n    in the input graph. In this context globally unique means unique across all the\n    nodes in the input graph.\n\n    Parameters:\n    -----------\n    rank : int\n        rank of the current process\n    schema_map : json dictionary\n        Dictionary created by reading the metadata.json file for the input dataset.\n    output : string\n        Location of storing the node-weights and edge files for ParMETIS.\n    \"\"\"\n    _, ntype_gnid_offset = get_idranges(\n        schema_map[constants.STR_NODE_TYPE],\n        dict(\n            zip(\n                schema_map[constants.STR_NODE_TYPE],\n                schema_map[constants.STR_NUM_NODES_PER_TYPE],\n            )\n        ),\n    )\n\n    # Regenerate edge files here.\n    edge_data = schema_map[constants.STR_EDGES]\n\n    outdir = Path(params.output_dir)\n    os.makedirs(outdir, exist_ok=True)\n\n    def process_and_write_back(data_df, idx):\n        data_f0 = data_df[:, 0]\n        data_f1 = data_df[:, 1]\n\n        global_src_id = data_f0 + ntype_gnid_offset[src_ntype_name][0, 0]\n        global_dst_id = data_f1 + ntype_gnid_offset[dst_ntype_name][0, 0]\n        cols = [global_src_id, global_dst_id]\n        col_names = [\"global_src_id\", \"global_dst_id\"]\n\n        out_file_name = Path(edge_data_files[idx]).stem.split(\".\")[0]\n        out_file = os.path.join(\n            outdir, etype_name, f\"edges_{out_file_name}.csv\"\n        )\n        os.makedirs(os.path.dirname(out_file), exist_ok=True)\n\n        options = csv.WriteOptions(include_header=False, delimiter=\" \")\n        csv.write_csv(\n            pyarrow.Table.from_arrays(cols, names=col_names),\n            out_file,\n            options,\n        )\n        return out_file\n\n    edge_files = []\n    for etype_name, etype_info in edge_data.items():\n        edge_data_files = etype_info[constants.STR_DATA]\n\n       
 # ``edgetype`` strings are in canonical format, src_node_type:edge_type:dst_node_type\n        tokens = etype_name.split(\":\")\n        assert len(tokens) == 3\n\n        src_ntype_name = tokens[0]\n\n        dst_ntype_name = tokens[2]\n\n        rank_assignments = generate_roundrobin_read_list(\n            len(edge_data_files), params.num_parts\n        )\n        for file_idx in rank_assignments[rank]:\n            reader_fmt_meta = {\n                \"name\": etype_info[constants.STR_FORMAT][constants.STR_NAME],\n            }\n            if reader_fmt_meta[\"name\"] == constants.STR_CSV:\n                reader_fmt_meta[\"delimiter\"] = etype_info[constants.STR_FORMAT][\n                    constants.STR_FORMAT_DELIMITER\n                ]\n            data_df = array_readwriter.get_array_parser(**reader_fmt_meta).read(\n                os.path.join(params.input_dir, edge_data_files[file_idx])\n            )\n            out_file = process_and_write_back(data_df, file_idx)\n            edge_files.append(out_file)\n\n    return edge_files\n\n\ndef gen_node_weights_files(schema_map, params):\n    \"\"\"Function to create node weight files for ParMETIS along with the edge files.\n\n    This function generates node-data files, which will be read by the ParMETIS\n    executable for partitioning purposes. Each line in these files will be of the\n    following format:\n        <node_type_id> <node_weight_list> <type_wise_node_id>\n    node_type_id -  is id assigned to the node-type to which a given particular\n        node belongs to\n    weight_list - this is a one-hot vector in which the number in the location of\n        the current nodes' node-type will be set to `1` and other will be `0`\n    type_node_id - this is the id assigned to the node (in the context of the current\n        nodes` node-type). 
Meaning this id is unique across all the nodes which belong to\n        the current nodes` node-type.\n\n    Parameters:\n    -----------\n    schema_map : json dictionary\n        Dictionary created by reading the metadata.json file for the input dataset.\n    output : string\n        Location of storing the node-weights and edge files for ParMETIS.\n\n    Returns:\n    --------\n    list :\n        List of filenames for nodes of the input graph.\n    list :\n        List o ffilenames for edges of the input graph.\n    \"\"\"\n    rank = get_proc_info()\n    ntypes_ntypeid_map, ntypes, ntid_ntype_map = get_node_types(schema_map)\n    type_nid_dict, ntype_gnid_offset = get_idranges(\n        schema_map[constants.STR_NODE_TYPE],\n        dict(\n            zip(\n                schema_map[constants.STR_NODE_TYPE],\n                schema_map[constants.STR_NUM_NODES_PER_TYPE],\n            )\n        ),\n    )\n\n    node_files = []\n    outdir = Path(params.output_dir)\n    os.makedirs(outdir, exist_ok=True)\n\n    for ntype_id, ntype_name in ntid_ntype_map.items():\n\n        # This ntype does not have any train/test/val masks...\n        # Each rank will generate equal no. 
of rows for this node type.\n        total_count = schema_map[constants.STR_NUM_NODES_PER_TYPE][ntype_id]\n        per_rank_range = np.ones((params.num_parts,), dtype=np.int64) * (\n            total_count // params.num_parts\n        )\n        for i in range(total_count % params.num_parts):\n            per_rank_range[i] += 1\n\n        tid_start = np.cumsum([0] + list(per_rank_range[:-1]))\n        tid_end = np.cumsum(list(per_rank_range))\n        local_tid_start = tid_start[rank]\n        local_tid_end = tid_end[rank]\n        sz = local_tid_end - local_tid_start\n\n        cols = []\n        col_names = []\n\n        # ntype-id\n        cols.append(\n            pyarrow.array(np.ones(sz, dtype=np.int64) * np.int64(ntype_id))\n        )\n        col_names.append(\"ntype\")\n\n        # one-hot vector for ntype-id here.\n        for i in range(len(ntypes)):\n            if i == ntype_id:\n                cols.append(pyarrow.array(np.ones(sz, dtype=np.int64)))\n            else:\n                cols.append(pyarrow.array(np.zeros(sz, dtype=np.int64)))\n            col_names.append(\"w{}\".format(i))\n\n        # `type_nid` should be the very last column in the node weights files.\n        cols.append(\n            pyarrow.array(\n                np.arange(local_tid_start, local_tid_end, dtype=np.int64)\n            )\n        )\n        col_names.append(\"type_nid\")\n\n        out_file = os.path.join(\n            outdir, \"node_weights_{}_{}.txt\".format(ntype_name, rank)\n        )\n        options = csv.WriteOptions(include_header=False, delimiter=\" \")\n        options.delimiter = \" \"\n\n        csv.write_csv(\n            pyarrow.Table.from_arrays(cols, names=col_names), out_file, options\n        )\n        node_files.append(\n            (\n                ntype_gnid_offset[ntype_name][0, 0] + local_tid_start,\n                ntype_gnid_offset[ntype_name][0, 0] + local_tid_end,\n                out_file,\n            )\n        )\n\n    return 
node_files\n\n\ndef gen_parmetis_input_args(params, schema_map):\n    \"\"\"Function to create two input arguments which will be passed to the parmetis.\n    first argument is a text file which has a list of node-weights files,\n    namely parmetis-nfiles.txt, and second argument is a text file which has a\n    list of edge files, namely parmetis_efiles.txt.\n    ParMETIS uses these two files to read/load the graph and partition the graph\n    With regards to the file format, parmetis_nfiles.txt uses the following format\n    for each line in that file:\n        <filename> <global_node_id_start> <global_node_id_end>(exclusive)\n    While parmetis_efiles.txt just has <filename> in each line.\n\n    Parameters:\n    -----------\n    params : argparser instance\n        Instance of ArgParser class, which has all the input arguments passed to\n        run this program.\n    schema_map : json dictionary\n        Dictionary object created after reading the graph metadata.json file.\n    \"\"\"\n\n    # TODO: This makes the assumption that all node files have the same number of chunks\n    ntypes_ntypeid_map, ntypes, ntid_ntype_map = get_node_types(schema_map)\n    type_nid_dict, ntype_gnid_offset = get_idranges(\n        schema_map[constants.STR_NODE_TYPE],\n        dict(\n            zip(\n                schema_map[constants.STR_NODE_TYPE],\n                schema_map[constants.STR_NUM_NODES_PER_TYPE],\n            )\n        ),\n    )\n\n    # Check if <graph-name>_stats.txt exists, if not create one using metadata.\n    # Here stats file will be created in the current directory.\n    # No. of constraints, third column in the stats file is computed as follows:\n    #   num_constraints = no. 
of node types + train_mask + test_mask + val_mask\n    #   Here, (train/test/val) masks will be set to 1 if these masks exist for\n    #   all the node types in the graph, otherwise these flags will be set to 0\n    assert (\n        constants.STR_GRAPH_NAME in schema_map\n    ), \"Graph name is not present in the json file\"\n    graph_name = schema_map[constants.STR_GRAPH_NAME]\n    if not os.path.isfile(\n        os.path.join(params.input_dir, f\"{graph_name}_stats.txt\")\n    ):\n        num_nodes = np.sum(schema_map[constants.STR_NUM_NODES_PER_TYPE])\n        num_edges = np.sum(schema_map[constants.STR_NUM_EDGES_PER_TYPE])\n        num_ntypes = len(schema_map[constants.STR_NODE_TYPE])\n\n        num_constraints = num_ntypes\n\n        with open(\n            os.path.join(params.input_dir, f\"{graph_name}_stats.txt\"), \"w\"\n        ) as sf:\n            sf.write(f\"{num_nodes} {num_edges} {num_constraints}\")\n\n    node_files = []\n    outdir = Path(params.output_dir)\n    os.makedirs(outdir, exist_ok=True)\n    for ntype_id, ntype_name in ntid_ntype_map.items():\n        global_nid_offset = ntype_gnid_offset[ntype_name][0, 0]\n        total_count = schema_map[constants.STR_NUM_NODES_PER_TYPE][ntype_id]\n        per_rank_range = np.ones((params.num_parts,), dtype=np.int64) * (\n            total_count // params.num_parts\n        )\n        for i in range(total_count % params.num_parts):\n            per_rank_range[i] += 1\n        tid_start = np.cumsum([0] + list(per_rank_range[:-1]))\n        tid_end = np.cumsum(per_rank_range)\n        logging.info(f\" tid-start = {tid_start}, tid-end = {tid_end}\")\n        logging.info(f\" per_rank_range - {per_rank_range}\")\n\n        for part_idx in range(params.num_parts):\n            local_tid_start = tid_start[part_idx]\n            local_tid_end = tid_end[part_idx]\n            out_file = os.path.join(\n                outdir, \"node_weights_{}_{}.txt\".format(ntype_name, part_idx)\n            )\n            
node_files.append(\n                (\n                    out_file,\n                    global_nid_offset + local_tid_start,\n                    global_nid_offset + local_tid_end,\n                )\n            )\n\n    with open(\n        os.path.join(params.output_dir, \"parmetis_nfiles.txt\"), \"w\"\n    ) as parmetis_nf:\n        for node_file in node_files:\n            # format: filename global_node_id_start global_node_id_end(exclusive)\n            parmetis_nf.write(\n                \"{} {} {}\\n\".format(node_file[0], node_file[1], node_file[2])\n            )\n\n    # Regenerate edge files here.\n    # NOTE: The file names need to match the ones generated by gen_edge_files function\n    edge_data = schema_map[constants.STR_EDGES]\n    edge_files = []\n    for etype_name, etype_info in edge_data.items():\n        edge_data_files = etype_info[constants.STR_DATA]\n        for edge_file_path in edge_data_files:\n            out_file_name = Path(edge_file_path).stem.split(\".\")[0]\n            out_file = os.path.join(\n                outdir, etype_name, \"edges_{}.csv\".format(out_file_name)\n            )\n            edge_files.append(out_file)\n\n    with open(\n        os.path.join(params.output_dir, \"parmetis_efiles.txt\"), \"w\"\n    ) as parmetis_efile:\n        for edge_file in edge_files:\n            parmetis_efile.write(\"{}\\n\".format(edge_file))\n\n\ndef run_preprocess_data(params):\n    \"\"\"Main function which will help create graph files for ParMETIS processing\n\n    Parameters:\n    -----------\n    params : argparser object\n        An instance of argparser class which stores command line arguments.\n    \"\"\"\n    logging.info(\"Starting to generate ParMETIS files...\")\n    rank = get_proc_info()\n\n    assert os.path.isdir(\n        params.input_dir\n    ), f\"Please check `input_dir` argument: {params.input_dit}.\"\n\n    schema_map = read_json(os.path.join(params.input_dir, params.schema_file))\n    
gen_node_weights_files(schema_map, params)\n    logging.info(\"Done with node weights....\")\n\n    gen_edge_files(rank, schema_map, params)\n    logging.info(\"Done with edge weights...\")\n\n    if rank == 0:\n        gen_parmetis_input_args(params, schema_map)\n    logging.info(\"Done generating files for ParMETIS run ..\")\n\n\nif __name__ == \"__main__\":\n    \"\"\"Main function used to generate temporary files needed for ParMETIS execution.\n    This function generates node-weight files and edges files which are consumed by ParMETIS.\n\n    Example usage:\n    --------------\n    mpirun -np 4 python3 parmetis_preprocess.py --schema <file> --output <target-output-dir>\n    \"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Generate ParMETIS files for input dataset\"\n    )\n    parser.add_argument(\n        \"--schema_file\",\n        required=True,\n        type=str,\n        help=\"The schema of the input graph\",\n    )\n    parser.add_argument(\n        \"--input_dir\",\n        required=True,\n        type=str,\n        help=\"This directory will be used as the relative directory to locate files, if absolute paths are not used\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        required=True,\n        type=str,\n        help=\"The output directory for the node weights files and auxiliary files for ParMETIS.\",\n    )\n    parser.add_argument(\n        \"--num_parts\",\n        required=True,\n        type=int,\n        help=\"Total no. 
of output graph partitions.\",\n    )\n    parser.add_argument(\n        \"--log_level\",\n        required=False,\n        type=str,\n        help=\"Log level to use for execution.\",\n        default=\"INFO\",\n        choices=[\"DEBUG\", \"INFO\", \"WARNING\", \"ERROR\", \"CRITICAL\"],\n    )\n    params = parser.parse_args()\n\n    # Configure logging.\n    logging.basicConfig(\n        level=getattr(logging, params.log_level, None),\n        format=f\"[{platform.node()} \\\n        %(levelname)s %(asctime)s PID:%(process)d] %(message)s\",\n    )\n\n    # Invoke the function to generate files for parmetis\n    run_preprocess_data(params)\n"
  },
  {
    "path": "tools/distpartitioning/parmetis_wrapper.py",
    "content": "import argparse\nimport logging\nimport os\nimport platform\nimport sys\nfrom pathlib import Path\n\nimport constants\nfrom utils import read_json\n\n\ndef check_dependencies():\n    \"\"\"Check if all the dependencies needed for the execution of this file\n    are installed.\n    \"\"\"\n\n    exec_path = os.get_exec_path()\n    mpi_install = False\n    for x in exec_path:\n        if os.path.isfile(os.path.join(x, \"mpirun\")):\n            mpi_install = True\n            break\n    assert (\n        mpi_install\n    ), \"Could not locate the following dependency: MPI. Please install it and try again.\"\n\n    dgl_path = os.environ.get(\"DGL_HOME\", \"\")\n    assert os.path.isdir(\n        dgl_path\n    ), \"Environment variable DGL_HOME not found. Please define the DGL installation path\"\n\n\ndef run_parmetis_wrapper(params):\n    \"\"\"Function to execute all the steps needed to run ParMETIS\n\n    Parameters:\n    -----------\n    params : argparser object\n        an instance of argparser class to capture command-line arguments\n    \"\"\"\n    schema = read_json(\n        os.path.join(params.preproc_input_dir, params.schema_file)\n    )\n    graph_name = schema[constants.STR_GRAPH_NAME]\n    num_partitions = params.num_parts\n\n    # Check if parmetis_preprocess.py exists.\n    assert os.path.isfile(\n        os.path.join(\n            os.path.dirname(os.path.abspath(__file__)), \"parmetis_preprocess.py\"\n        )\n    ), \"Please check DGL Installation, parmetis_preprocess.py file does not exist.\"\n\n    # Trigger pre-processing step to generate input files for ParMETIS.\n    preproc_cmd = (\n        f\"mpirun -np {num_partitions} -hostfile {params.hostfile} \"\n        f\"python3 $DGL_HOME/tools/distpartitioning/parmetis_preprocess.py \"\n        f\"--schema_file {params.schema_file} \"\n        f\"--input_dir {params.preproc_input_dir} \"\n        f\"--output_dir {params.preproc_output_dir} \"\n        f\"--num_parts 
{num_partitions}\"\n    )\n    logging.info(f\"Executing Preprocessing Step: {preproc_cmd}\")\n    os.system(preproc_cmd)\n    logging.info(f\"Done Preprocessing Step\")\n\n    # Trigger ParMETIS for creating metis partitions for the input graph.\n    parmetis_install_path = \"pm_dglpart3\"\n    if params.parmetis_install_path is not None:\n        parmetis_install_path = os.path.join(\n            params.parmetis_install_path, parmetis_install_path\n        )\n    parmetis_nfiles = os.path.join(\n        params.preproc_output_dir, \"parmetis_nfiles.txt\"\n    )\n    parmetis_efiles = os.path.join(\n        params.preproc_output_dir, \"parmetis_efiles.txt\"\n    )\n    parmetis_cmd = (\n        f\"mpirun -np {num_partitions} -hostfile {params.hostfile} \"\n        f\"{parmetis_install_path} {graph_name} {num_partitions} \"\n        f\"{parmetis_nfiles} {parmetis_efiles}\"\n    )\n    logging.info(f\"Executing ParMETIS: {parmetis_cmd}\")\n    os.system(parmetis_cmd)\n    logging.info(f\"Done ParMETIS execution step\")\n\n    # Trigger post-processing step to convert parmetis output to the form\n    # acceptable by dist. 
graph partitioning pipeline.\n    parmetis_output_file = os.path.join(\n        os.getcwd(), f\"{graph_name}_part.{num_partitions}\"\n    )\n    postproc_cmd = (\n        f\"python3 $DGL_HOME/tools/distpartitioning/parmetis_postprocess.py \"\n        f\"--postproc_input_dir {params.preproc_input_dir} \"\n        f\"--schema_file {params.schema_file} \"\n        f\"--parmetis_output_file {parmetis_output_file} \"\n        f\"--partitions_dir {params.partitions_dir}\"\n    )\n    logging.info(f\"Executing PostProcessing: {postproc_cmd}\")\n    os.system(postproc_cmd)\n    logging.info(\"Done Executing ParMETIS...\")\n\n\nif __name__ == \"__main__\":\n    \"\"\"Main function to invoke the parmetis wrapper function\"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Run ParMETIS as part of the graph partitioning pipeline\"\n    )\n    # Preprocessing step.\n    parser.add_argument(\n        \"--schema_file\",\n        required=True,\n        type=str,\n        help=\"The schema of the input graph\",\n    )\n    parser.add_argument(\n        \"--preproc_input_dir\",\n        type=str,\n        help=\"The input directory for preprocess where the dataset is located\",\n    )\n    parser.add_argument(\n        \"--preproc_output_dir\",\n        required=True,\n        type=str,\n        help=\"The output directory for the node weights files and auxiliary\\\n              files for ParMETIS.\",\n    )\n    parser.add_argument(\n        \"--hostfile\",\n        required=True,\n        type=str,\n        help=\"A text file with a list of ip addresses.\",\n    )\n    parser.add_argument(\n        \"--num_parts\",\n        required=True,\n        type=int,\n        help=\"integer representing no. 
of partitions.\",\n    )\n\n    # ParMETIS step.\n    parser.add_argument(\n        \"--parmetis_install_path\",\n        required=False,\n        type=str,\n        help=\"The directory where ParMETIS is installed\",\n    )\n\n    # Postprocessing step.\n    parser.add_argument(\n        \"--parmetis_output_file\",\n        required=True,\n        type=str,\n        help=\"ParMETIS output file (global_node_id to partition_id mappings)\",\n    )\n    parser.add_argument(\n        \"--partitions_dir\",\n        required=True,\n        type=str,\n        help=\"The directory where the files (with metis partition ids) grouped \\\n              by node_types\",\n    )\n    params = parser.parse_args()\n\n    # Configure logging.\n    logging.basicConfig(\n        level=\"INFO\",\n        format=f\"[{platform.node()} \\\n        %(levelname)s %(asctime)s PID:%(process)d] %(message)s\",\n    )\n\n    check_dependencies()\n    run_parmetis_wrapper(params)\n"
  },
  {
    "path": "tools/distpartitioning/utils.py",
    "content": "import json\nimport logging\nimport os\nfrom itertools import cycle\n\nimport constants\n\nimport dgl\nimport numpy as np\nimport psutil\nimport pyarrow\n\nimport torch\nfrom dgl.distributed.partition import _dump_part_config\nfrom pyarrow import csv\n\nDATA_TYPE_ID = {\n    data_type: id\n    for id, data_type in enumerate(\n        [\n            torch.float32,\n            torch.float64,\n            torch.float16,\n            torch.uint8,\n            torch.int8,\n            torch.int16,\n            torch.int32,\n            torch.int64,\n            torch.bool,\n        ]\n    )\n}\n\nREV_DATA_TYPE_ID = {id: data_type for data_type, id in DATA_TYPE_ID.items()}\n\n\ndef read_ntype_partition_files(schema_map, input_dir):\n    \"\"\"\n    Utility method to read the partition id mapping for each node.\n    For each node type, there will be an file, in the input directory argument\n    containing the partition id mapping for a given nodeid.\n\n    Parameters:\n    -----------\n    schema_map : dictionary\n        dictionary created by reading the input metadata json file\n    input_dir : string\n        directory in which the node-id to partition-id mappings files are\n        located for each of the node types in the input graph\n\n    Returns:\n    --------\n    numpy array :\n        array of integers representing mapped partition-ids for a given node-id.\n        The line number, in these files, are used as the type_node_id in each of\n        the files. The index into this array will be the homogenized node-id and\n        value will be the partition-id for that node-id (index). 
Please note that\n        the partition-ids of each node-type are stacked together vertically and\n        in this way heterogeneous node-ids are converted to homogeneous node-ids.\n    \"\"\"\n    assert os.path.isdir(input_dir)\n\n    # iterate over the node types and extract the partition id mappings\n    part_ids = []\n    ntype_names = schema_map[constants.STR_NODE_TYPE]\n    for ntype in ntype_names:\n        df = csv.read_csv(\n            os.path.join(input_dir, \"{}.txt\".format(ntype)),\n            read_options=pyarrow.csv.ReadOptions(\n                autogenerate_column_names=True\n            ),\n            parse_options=pyarrow.csv.ParseOptions(delimiter=\" \"),\n        )\n        ntype_partids = df[\"f0\"].to_numpy()\n        part_ids.append(ntype_partids)\n    return np.concatenate(part_ids)\n\n\ndef read_json(json_file):\n    \"\"\"\n    Utility method to read a json file schema\n\n    Parameters:\n    -----------\n    json_file : string\n        file name for the json schema\n\n    Returns:\n    --------\n    dictionary, as serialized in the json_file\n    \"\"\"\n    with open(json_file) as schema:\n        val = json.load(schema)\n\n    return val\n\n\ndef get_etype_featnames(etype_name, schema_map):\n    \"\"\"Retrieves edge feature names for a given edge_type\n\n    Parameters:\n    -----------\n    etype_name : string\n        a string specifying an edge_type name\n\n    schema_map : dictionary\n        metadata json object as a dictionary, which is read from the input\n        metadata file from the input dataset\n\n    Returns:\n    --------\n    list :\n        a list of feature names for a given edge_type\n    \"\"\"\n    edge_data = schema_map[constants.STR_EDGE_DATA]\n    feats = edge_data.get(etype_name, {})\n    return [feat for feat in feats]\n\n\ndef get_ntype_featnames(ntype_name, schema_map):\n    \"\"\"\n    Retrieves node feature names for a given node_type\n\n    Parameters:\n    -----------\n    ntype_name : string\n        a 
string specifying a node_type name\n\n    schema : dictionary\n        metadata json object as a dictionary, which is read from the input\n        metadata file from the input dataset\n\n    Returns:\n    --------\n    list :\n        a list of feature names for a given node_type\n    \"\"\"\n    node_data = schema_map[constants.STR_NODE_DATA]\n    feats = node_data.get(ntype_name, {})\n    return [feat for feat in feats]\n\n\ndef get_edge_types(schema_map):\n    \"\"\"Utility method to extract edge_typename -> edge_type mappings\n    as defined by the input schema\n\n    Parameters:\n    -----------\n    schema_map : dictionary\n        Input schema from which the edge_typename -> edge_typeid\n        dictionary is created.\n\n    Returns:\n    --------\n    dictionary\n        with keys as edge type names and values as ids (integers)\n    list\n        list of etype name strings\n    dictionary\n        with keys as etype ids (integers) and values as edge type names\n    \"\"\"\n    etypes = schema_map[constants.STR_EDGE_TYPE]\n    etype_etypeid_map = {e: i for i, e in enumerate(etypes)}\n    etypeid_etype_map = {i: e for i, e in enumerate(etypes)}\n    return etype_etypeid_map, etypes, etypeid_etype_map\n\n\ndef get_node_types(schema_map):\n    \"\"\"\n    Utility method to extract node_typename -> node_type mappings\n    as defined by the input schema\n\n    Parameters:\n    -----------\n    schema_map : dictionary\n        Input schema from which the node_typename -> node_type\n        dictionary is created.\n\n    Returns:\n    --------\n    dictionary\n        with keys as node type names and values as ids (integers)\n    list\n        list of ntype name strings\n    dictionary\n        with keys as ntype ids (integers) and values as node type names\n    \"\"\"\n    ntypes = schema_map[constants.STR_NODE_TYPE]\n    ntype_ntypeid_map = {e: i for i, e in enumerate(ntypes)}\n    ntypeid_ntype_map = {i: e for i, e in enumerate(ntypes)}\n    return 
ntype_ntypeid_map, ntypes, ntypeid_ntype_map\n\n\ndef get_gid_offsets(typenames, typecounts):\n    \"\"\"\n    Builds a map where the key-value pairs are typnames and respective\n    global-id offsets.\n\n    Parameters:\n    -----------\n    typenames : list of strings\n        a list of strings which can be either node typenames or edge typenames\n    typecounts : list of integers\n        a list of integers indicating the total number of nodes/edges for its\n        typeid which is the index in this list\n\n    Returns:\n    --------\n    dictionary :\n        a dictionary where keys are node_type names and values are\n        global_nid range, which is a tuple.\n\n    \"\"\"\n    assert len(typenames) == len(\n        typecounts\n    ), f\"No. of typenames does not match with its type counts names = {typenames}, counts = {typecounts}\"\n\n    counts = []\n    for name in typenames:\n        counts.append(typecounts[name])\n    starts = np.cumsum([0] + counts[:-1])\n    ends = np.cumsum(counts)\n\n    gid_offsets = {}\n    for idx, name in enumerate(typenames):\n        gid_offsets[name] = [starts[idx], ends[idx]]\n    return gid_offsets\n\n    \"\"\"\n    starts = np.cumsum([0] + type_counts[:-1])\n    ends = np.cumsum(type_counts)\n    gid_offsets = {}\n    for idx, name in enumerate(typenames):\n        gid_offsets[name] = [start[idx], ends[idx]]\n\n    return gid_offsets\n    \"\"\"\n\n\ndef get_gnid_range_map(node_tids):\n    \"\"\"\n    Retrieves auxiliary dictionaries from the metadata json object\n\n    Parameters:\n    -----------\n    node_tids: dictionary\n        This dictionary contains the information about nodes for each node_type.\n        Typically this information contains p-entries, where each entry has a file-name,\n        starting and ending type_node_ids for the nodes in this file. Keys in this dictionary\n        are the node_type and value is a list of lists. 
Each individual entry in this list has\n        three items: file-name, starting type_nid and ending type_nid\n\n    Returns:\n    --------\n    dictionary :\n        a dictionary where keys are node_type names and values are global_nid range, which is a tuple.\n\n    \"\"\"\n    ntypes_gid_range = {}\n    offset = 0\n    for k, v in node_tids.items():\n        ntypes_gid_range[k] = [offset + int(v[0][0]), offset + int(v[-1][1])]\n        offset += int(v[-1][1])\n\n    return ntypes_gid_range\n\n\ndef write_metadata_json(\n    input_list, output_dir, graph_name, world_size, num_parts\n):\n    \"\"\"\n    Merge json schema's from each of the rank's on rank-0.\n    This utility function, to be used on rank-0, to create aggregated json file.\n\n    Parameters:\n    -----------\n    metadata_list : list of json (dictionaries)\n        a list of json dictionaries to merge on rank-0\n    output_dir : string\n        output directory path in which results are stored (as a json file)\n    graph-name : string\n        a string specifying the graph name\n    \"\"\"\n    # Preprocess the input_list, a list of dictionaries\n    # each dictionary will contain num_parts/world_size metadata json\n    # which correspond to local partitions on the respective ranks.\n    metadata_list = []\n    for local_part_id in range(num_parts // world_size):\n        for idx in range(world_size):\n            metadata_list.append(\n                input_list[idx][\n                    \"local-part-id-\" + str(local_part_id * world_size + idx)\n                ]\n            )\n\n    # Initialize global metadata\n    graph_metadata = {}\n\n    # Merge global_edge_ids from each json object in the input list\n    edge_map = {}\n    x = metadata_list[0][\"edge_map\"]\n    for k in x:\n        edge_map[k] = []\n        for idx in range(len(metadata_list)):\n            edge_map[k].append(\n                [\n                    int(metadata_list[idx][\"edge_map\"][k][0][0]),\n                    
int(metadata_list[idx][\"edge_map\"][k][0][1]),\n                ]\n            )\n    graph_metadata[\"edge_map\"] = edge_map\n\n    graph_metadata[\"etypes\"] = metadata_list[0][\"etypes\"]\n    graph_metadata[\"graph_name\"] = metadata_list[0][\"graph_name\"]\n    graph_metadata[\"halo_hops\"] = metadata_list[0][\"halo_hops\"]\n\n    # Merge global_nodeids from each of json object in the input list\n    node_map = {}\n    x = metadata_list[0][\"node_map\"]\n    for k in x:\n        node_map[k] = []\n        for idx in range(len(metadata_list)):\n            node_map[k].append(\n                [\n                    int(metadata_list[idx][\"node_map\"][k][0][0]),\n                    int(metadata_list[idx][\"node_map\"][k][0][1]),\n                ]\n            )\n    graph_metadata[\"node_map\"] = node_map\n\n    graph_metadata[\"ntypes\"] = metadata_list[0][\"ntypes\"]\n    graph_metadata[\"num_edges\"] = int(\n        sum([metadata_list[i][\"num_edges\"] for i in range(len(metadata_list))])\n    )\n    graph_metadata[\"num_nodes\"] = int(\n        sum([metadata_list[i][\"num_nodes\"] for i in range(len(metadata_list))])\n    )\n    graph_metadata[\"num_parts\"] = metadata_list[0][\"num_parts\"]\n    graph_metadata[\"part_method\"] = metadata_list[0][\"part_method\"]\n\n    for i in range(len(metadata_list)):\n        graph_metadata[\"part-{}\".format(i)] = metadata_list[i][\n            \"part-{}\".format(i)\n        ]\n\n    _dump_part_config(f\"{output_dir}/metadata.json\", graph_metadata)\n\n\ndef augment_edge_data(\n    edge_data, lookup_service, edge_tids, rank, world_size, num_parts\n):\n    \"\"\"\n    Add partition-id (rank which owns an edge) column to the edge_data.\n\n    Parameters:\n    -----------\n    edge_data : numpy ndarray\n        Edge information as read from the xxx_edges.txt file\n    lookup_service : instance of class DistLookupService\n       Distributed lookup service used to map global-nids to respective partition-ids and\n       
shuffle-global-nids\n    edge_tids: dictionary\n        dictionary where keys are canonical edge types and values are list of tuples\n        which indicate the range of edges assigned to each of the partitions\n    rank : integer\n        rank of the current process\n    world_size : integer\n        total no. of process participating in the communication primitives\n    num_parts : integer\n        total no. of partitions requested for the input graph\n\n    Returns:\n    --------\n    dictionary :\n        dictionary with keys as column names and values as numpy arrays and this information is\n        loaded from input dataset files. In addition to this we include additional columns which\n        aid this pipelines computation, like constants.OWNER_PROCESS\n    \"\"\"\n    # add global_nids to the node_data\n    etype_offset = {}\n    offset = 0\n    for etype_name, tid_range in edge_tids.items():\n        etype_offset[etype_name] = offset + int(tid_range[0][0])\n        offset += int(tid_range[-1][1])\n\n    global_eids = []\n    for etype_name, tid_range in edge_tids.items():\n        for idx in range(num_parts):\n            if map_partid_rank(idx, world_size) == rank:\n                if len(tid_range) > idx:\n                    global_eid_start = etype_offset[etype_name]\n                    begin = global_eid_start + int(tid_range[idx][0])\n                    end = global_eid_start + int(tid_range[idx][1])\n                    global_eids.append(np.arange(begin, end, dtype=np.int64))\n\n    global_eids = (\n        np.concatenate(global_eids)\n        if len(global_eids) > 0\n        else np.array([], dtype=np.int64)\n    )\n    assert global_eids.shape[0] == edge_data[constants.ETYPE_ID].shape[0]\n    edge_data[constants.GLOBAL_EID] = global_eids\n    return edge_data\n\n\ndef read_edges_file(edge_file, edge_data_dict):\n    \"\"\"\n    Utility function to read xxx_edges.txt file\n\n    Parameters:\n    -----------\n    edge_file : string\n        
Graph file for edges in the input graph\n\n    Returns:\n    --------\n    dictionary\n        edge data as read from xxx_edges.txt file and columns are stored\n        in a dictionary with key-value pairs as column-names and column-data.\n    \"\"\"\n    if edge_file == \"\" or edge_file == None:\n        return None\n\n    # Read the file from here.\n    # <global_src_id> <global_dst_id> <type_eid> <etype> <attributes>\n    # global_src_id -- global idx for the source node ... line # in the graph_nodes.txt\n    # global_dst_id -- global idx for the destination id node ... line # in the graph_nodes.txt\n\n    edge_data_df = csv.read_csv(\n        edge_file,\n        read_options=pyarrow.csv.ReadOptions(autogenerate_column_names=True),\n        parse_options=pyarrow.csv.ParseOptions(delimiter=\" \"),\n    )\n    edge_data_dict = {}\n    edge_data_dict[constants.GLOBAL_SRC_ID] = edge_data_df[\"f0\"].to_numpy()\n    edge_data_dict[constants.GLOBAL_DST_ID] = edge_data_df[\"f1\"].to_numpy()\n    edge_data_dict[constants.GLOBAL_TYPE_EID] = edge_data_df[\"f2\"].to_numpy()\n    edge_data_dict[constants.ETYPE_ID] = edge_data_df[\"f3\"].to_numpy()\n    return edge_data_dict\n\n\ndef read_node_features_file(nodes_features_file):\n    \"\"\"\n    Utility function to load tensors from a file\n\n    Parameters:\n    -----------\n    nodes_features_file : string\n        Features file for nodes in the graph\n\n    Returns:\n    --------\n    dictionary\n        mappings between ntype and list of features\n    \"\"\"\n\n    node_features = dgl.data.utils.load_tensors(nodes_features_file, False)\n    return node_features\n\n\ndef read_edge_features_file(edge_features_file):\n    \"\"\"\n    Utility function to load tensors from a file\n\n    Parameters:\n    -----------\n    edge_features_file : string\n        Features file for edges in the graph\n\n    Returns:\n    --------\n    dictionary\n        mappings between etype and list of features\n    \"\"\"\n    edge_features = 
dgl.data.utils.load_tensors(edge_features_file, True)\n    return edge_features\n\n\ndef write_node_features(node_features, node_file):\n    \"\"\"\n    Utility function to serialize node_features in node_file file\n\n    Parameters:\n    -----------\n    node_features : dictionary\n        dictionary storing ntype <-> list of features\n    node_file     : string\n        File in which the node information is serialized\n    \"\"\"\n    dgl.data.utils.save_tensors(node_file, node_features)\n\n\ndef write_edge_features(edge_features, edge_file):\n    \"\"\"\n    Utility function to serialize edge_features in edge_file file\n\n    Parameters:\n    -----------\n    edge_features : dictionary\n        dictionary storing etype <-> list of features\n    edge_file     : string\n        File in which the edge information is serialized\n    \"\"\"\n    dgl.data.utils.save_tensors(edge_file, edge_features)\n\n\ndef write_graph_graghbolt(graph_file, graph_obj):\n    \"\"\"\n    Utility function to serialize FusedCSCSamplingGraph\n\n    Parameters:\n    -----------\n    graph_obj : FusedCSCSamplingGraph\n        FusedCSCSamplingGraph, as created in convert_partition.py, which is to be serialized\n    graph_file : string\n        File name in which graph object is serialized\n    \"\"\"\n    torch.save(graph_obj, graph_file)\n\n\ndef write_graph_dgl(graph_file, graph_obj, formats, sort_etypes):\n    \"\"\"\n    Utility function to serialize graph dgl objects\n\n    Parameters:\n    -----------\n    graph_obj : dgl graph object\n        graph dgl object, as created in convert_partition.py, which is to be serialized\n    graph_file : string\n        File name in which graph object is serialized\n    formats : str or list[str]\n        Save graph in specified formats.\n    sort_etypes : bool\n        Whether to sort etypes in csc/csr.\n    \"\"\"\n    dgl.distributed.partition.process_partitions(\n        graph_obj, formats, sort_etypes\n    )\n    dgl.save_graphs(graph_file, 
[graph_obj], formats=formats)\n\n\ndef _write_graph(\n    part_dir, graph_obj, formats=None, sort_etypes=None, use_graphbolt=False\n):\n    if use_graphbolt:\n        write_graph_graghbolt(\n            os.path.join(part_dir, \"fused_csc_sampling_graph.pt\"), graph_obj\n        )\n    else:\n        write_graph_dgl(\n            os.path.join(part_dir, \"graph.dgl\"), graph_obj, formats, sort_etypes\n        )\n\n\ndef write_dgl_objects(\n    graph_obj,\n    node_features,\n    edge_features,\n    output_dir,\n    part_id,\n    orig_nids,\n    orig_eids,\n    formats,\n    sort_etypes,\n    use_graphbolt,\n):\n    \"\"\"\n    Wrapper function to write graph, node/edge feature, original node/edge IDs.\n\n    Parameters:\n    -----------\n    graph_obj : dgl object\n        graph dgl object as created in convert_partition.py file\n    node_features : dgl object\n        Tensor data for node features\n    edge_features : dgl object\n        Tensor data for edge features\n    output_dir : string\n        location where the output files will be located\n    part_id : int\n        integer indicating the partition-id\n    orig_nids : dict\n        original node IDs\n    orig_eids : dict\n        original edge IDs\n    formats : str or list[str]\n        Save graph in formats.\n    sort_etypes : bool\n        Whether to sort etypes in csc/csr.\n    use_graphbolt : bool\n        Whether to use graphbolt or not.\n    \"\"\"\n    part_dir = output_dir + \"/part\" + str(part_id)\n    os.makedirs(part_dir, exist_ok=True)\n    _write_graph(\n        part_dir,\n        graph_obj,\n        formats=formats,\n        sort_etypes=sort_etypes,\n        use_graphbolt=use_graphbolt,\n    )\n    if node_features != None:\n        write_node_features(\n            node_features, os.path.join(part_dir, \"node_feat.dgl\")\n        )\n\n    if edge_features != None:\n        write_edge_features(\n            edge_features, os.path.join(part_dir, \"edge_feat.dgl\")\n        )\n\n    if 
orig_nids is not None:\n        orig_nids_file = os.path.join(part_dir, \"orig_nids.dgl\")\n        dgl.data.utils.save_tensors(orig_nids_file, orig_nids)\n    if orig_eids is not None:\n        orig_eids_file = os.path.join(part_dir, \"orig_eids.dgl\")\n        dgl.data.utils.save_tensors(orig_eids_file, orig_eids)\n\n\ndef get_idranges(names, counts, num_chunks=None):\n    \"\"\"\n    counts will be a list of numbers of a dictionary.\n    Length is less than or equal to the num_parts variable.\n\n    Parameters:\n    -----------\n    names : list of strings\n        which are either node-types or edge-types\n    counts : list of integers\n        which are total no. of nodes or edges for a give node\n        or edge type\n    num_chunks : int, optional\n        specifying the no. of chunks\n\n    Returns:\n    --------\n    dictionary\n        dictionary where the keys are node-/edge-type names and values are\n        list of tuples where each tuple indicates the range of values for\n        corresponding type-ids.\n    dictionary\n        dictionary where the keys are node-/edge-type names and value is a tuple.\n        This tuple indicates the global-ids for the associated node-/edge-type.\n    \"\"\"\n    gnid_start = 0\n    gnid_end = gnid_start\n    tid_dict = {}\n    gid_dict = {}\n\n    for idx, typename in enumerate(names):\n        gnid_end += counts[typename]\n        tid_dict[typename] = [[0, counts[typename]]]\n        gid_dict[typename] = np.array([gnid_start, gnid_end]).reshape([1, 2])\n        gnid_start = gnid_end\n\n    return tid_dict, gid_dict\n\n\ndef get_ntype_counts_map(ntypes, ntype_counts):\n    \"\"\"\n    Return a dictionary with key, value pairs as node type names and no. of\n    nodes of a particular type in the input graph.\n\n    Parameters:\n    -----------\n    ntypes : list of strings\n        where each string is a node-type name\n    ntype_counts : list of integers\n        where each integer is the total no. 
of nodes for that, idx, node type\n\n    Returns:\n    --------\n    dictionary :\n        a dictionary where node-type names are keys and values are total no.\n        of nodes for a given node-type name (which is also the key)\n    \"\"\"\n    return dict(zip(ntypes, ntype_counts))\n\n\ndef memory_snapshot(tag, rank):\n    \"\"\"\n    Utility function to take a snapshot of the usage of system resources\n    at a given point of time.\n\n    Parameters:\n    -----------\n    tag : string\n        string provided by the user for bookmarking purposes\n    rank : integer\n        process id of the participating process\n    \"\"\"\n    GB = 1024 * 1024 * 1024\n    MB = 1024 * 1024\n    KB = 1024\n\n    peak = dgl.partition.get_peak_mem() * KB\n    mem = psutil.virtual_memory()\n    avail = mem.available / MB\n    used = mem.used / MB\n    total = mem.total / MB\n\n    mem_string = f\"{total:.0f} (MB) total, {peak:.0f} (MB) peak, {used:.0f} (MB) used, {avail:.0f} (MB) avail\"\n    logging.debug(f\"[Rank: {rank} MEMORY_SNAPSHOT] {mem_string} - {tag}\")\n\n\ndef map_partid_rank(partid, world_size):\n    \"\"\"Auxiliary function to map a given partition id to one of the ranks in the\n    MPI_WORLD processes. The range of partition ids is assumed to equal or a\n    multiple of the total size of MPI_WORLD. 
In this implementation, we use\n    a cyclical mapping procedure to convert partition ids to ranks.\n\n    Parameters:\n    -----------\n    partid : int\n        partition id, as read from node id to partition id mappings.\n\n    Returns:\n    --------\n    int :\n        rank of the process, which will be responsible for the given partition\n        id.\n    \"\"\"\n    return partid % world_size\n\n\ndef generate_read_list(num_files, world_size):\n    \"\"\"\n    Generate the file IDs to read for each rank\n    using sequential assignment.\n\n\n    Parameters:\n    -----------\n    num_files : int\n        Total number of files.\n    world_size : int\n        World size of group.\n\n    Returns:\n    --------\n    read_list : np.array\n        Array of target file IDs to read. Each worker is expected\n        to read the list of file indexes in its rank's index in the list.\n        e.g. rank 0 reads the files indexed in read_list[0], rank 1 the\n        ones in read_list[1] etc.\n\n\n    Examples\n    --------\n    >>> tools.distpartitioning.utils.generate_read_list(10, 4)\n    [array([0, 1, 2]), array([3, 4, 5]), array([6, 7]), array([8, 9])]\n    \"\"\"\n    return np.array_split(np.arange(num_files), world_size)\n\n\ndef generate_roundrobin_read_list(num_files, world_size):\n    \"\"\"\n    Generate the file IDs to read for each rank\n    using round robin assignment.\n\n    Parameters:\n    -----------\n    num_files : int\n        Total number of files.\n    world_size : int\n        World size of group.\n\n    Returns:\n    --------\n    read_list : np.array\n        Array of target file IDs to read. Each worker is expected\n        to read the list of file indexes in its rank's index in the list.\n        e.g. 
rank 0 reads the files indexed in read_list[0], rank 1 the\n        ones in read_list[1] etc.\n\n    Examples\n    --------\n    >>> tools.distpartitioning.utils.generate_roundrobin_read_list(10, 4)\n    [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]\n    \"\"\"\n    assignment_lists = [[] for _ in range(world_size)]\n    for rank, part_idx in zip(cycle(range(world_size)), range(num_files)):\n        assignment_lists[rank].append(part_idx)\n\n    return assignment_lists\n"
  },
  {
    "path": "tools/files.py",
    "content": "import logging\nimport os\nfrom contextlib import contextmanager\n\nfrom numpy.lib.format import open_memmap\n\n\n@contextmanager\ndef setdir(path):\n    try:\n        os.makedirs(path, exist_ok=True)\n        cwd = os.getcwd()\n        logging.info(\"Changing directory to %s\" % path)\n        logging.info(\"Previously: %s\" % cwd)\n        os.chdir(path)\n        yield\n    finally:\n        logging.info(\"Restoring directory to %s\" % cwd)\n        os.chdir(cwd)\n"
  },
  {
    "path": "tools/launch.py",
    "content": "\"\"\"Launching tool for DGL distributed training\"\"\"\nimport argparse\nimport json\nimport logging\nimport multiprocessing\nimport os\nimport queue\nimport re\nimport signal\nimport subprocess\nimport sys\nimport time\nfrom functools import partial\nfrom threading import Thread\nfrom typing import Optional\n\n\ndef cleanup_proc(get_all_remote_pids, conn):\n    \"\"\"This process tries to clean up the remote training tasks.\"\"\"\n    print(\"cleanup process runs\")\n    # This process should not handle SIGINT.\n    signal.signal(signal.SIGINT, signal.SIG_IGN)\n\n    data = conn.recv()\n    # If the launch process exits normally, this process doesn't need to do anything.\n    if data == \"exit\":\n        sys.exit(0)\n    else:\n        remote_pids = get_all_remote_pids()\n        # Otherwise, we need to ssh to each machine and kill the training jobs.\n        for (ip, port), pids in remote_pids.items():\n            kill_process(ip, port, pids)\n    print(\"cleanup process exits\")\n\n\ndef kill_process(ip, port, pids):\n    \"\"\"ssh to a remote machine and kill the specified processes.\"\"\"\n    curr_pid = os.getpid()\n    killed_pids = []\n    # If we kill child processes first, the parent process may create more again. This happens\n    # to Python's process pool. After sorting, we always kill parent processes first.\n    pids.sort()\n    for pid in pids:\n        assert curr_pid != pid\n        print(\"kill process {} on {}:{}\".format(pid, ip, port), flush=True)\n        kill_cmd = (\n            \"ssh -o StrictHostKeyChecking=no -p \"\n            + str(port)\n            + \" \"\n            + ip\n            + \" 'kill {}'\".format(pid)\n        )\n        subprocess.run(kill_cmd, shell=True)\n        killed_pids.append(pid)\n    # It's possible that some of the processes are not killed. 
Let's try again.\n    for i in range(3):\n        killed_pids = get_killed_pids(ip, port, killed_pids)\n        if len(killed_pids) == 0:\n            break\n        else:\n            killed_pids.sort()\n            for pid in killed_pids:\n                print(\n                    \"kill process {} on {}:{}\".format(pid, ip, port), flush=True\n                )\n                kill_cmd = (\n                    \"ssh -o StrictHostKeyChecking=no -p \"\n                    + str(port)\n                    + \" \"\n                    + ip\n                    + \" 'kill -9 {}'\".format(pid)\n                )\n                subprocess.run(kill_cmd, shell=True)\n\n\ndef get_killed_pids(ip, port, killed_pids):\n    \"\"\"Get the process IDs that we want to kill but are still alive.\"\"\"\n    killed_pids = [str(pid) for pid in killed_pids]\n    killed_pids = \",\".join(killed_pids)\n    ps_cmd = (\n        \"ssh -o StrictHostKeyChecking=no -p \"\n        + str(port)\n        + \" \"\n        + ip\n        + \" 'ps -p {} -h'\".format(killed_pids)\n    )\n    res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)\n    pids = []\n    for p in res.stdout.decode(\"utf-8\").split(\"\\n\"):\n        l = p.split()\n        if len(l) > 0:\n            pids.append(int(l[0]))\n    return pids\n\n\ndef execute_remote(\n    cmd: str,\n    state_q: queue.Queue,\n    ip: str,\n    port: int,\n    username: Optional[str] = \"\",\n) -> Thread:\n    \"\"\"Execute command line on remote machine via ssh.\n\n    Args:\n        cmd: User-defined command (udf) to execute on the remote host.\n        state_q: A queue collecting Thread exit states.\n        ip: The ip-address of the host to run the command on.\n        port: Port number that the host is listening on.\n        thread_list:\n        username: Optional. 
If given, this will specify a username to use when issuing commands over SSH.\n            Useful when your infra requires you to explicitly specify a username to avoid permission issues.\n\n    Returns:\n        thread: The Thread whose run() is to run the `cmd` on the remote host. Returns when the cmd completes\n            on the remote host.\n    \"\"\"\n    ip_prefix = \"\"\n    if username:\n        ip_prefix += \"{username}@\".format(username=username)\n\n    # Construct ssh command that executes `cmd` on the remote host\n    ssh_cmd = \"ssh -o StrictHostKeyChecking=no -p {port} {ip_prefix}{ip} '{cmd}'\".format(\n        port=str(port),\n        ip_prefix=ip_prefix,\n        ip=ip,\n        cmd=cmd,\n    )\n\n    # thread func to run the job\n    def run(ssh_cmd, state_q):\n        try:\n            subprocess.check_call(ssh_cmd, shell=True)\n            state_q.put(0)\n        except subprocess.CalledProcessError as err:\n            print(f\"Called process error {err}\")\n            state_q.put(err.returncode)\n        except Exception:\n            state_q.put(-1)\n\n    thread = Thread(\n        target=run,\n        args=(\n            ssh_cmd,\n            state_q,\n        ),\n    )\n    thread.setDaemon(True)\n    thread.start()\n    # sleep for a while in case of ssh is rejected by peer due to busy connection\n    time.sleep(0.2)\n    return thread\n\n\ndef get_remote_pids(ip, port, cmd_regex):\n    \"\"\"Get the process IDs that run the command in the remote machine.\"\"\"\n    pids = []\n    curr_pid = os.getpid()\n    # Here we want to get the python processes. 
We may get some ssh processes, so we should filter them out.\n    ps_cmd = (\n        \"ssh -o StrictHostKeyChecking=no -p \"\n        + str(port)\n        + \" \"\n        + ip\n        + \" 'ps -aux | grep python | grep -v StrictHostKeyChecking'\"\n    )\n    res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)\n    for p in res.stdout.decode(\"utf-8\").split(\"\\n\"):\n        l = p.split()\n        if len(l) < 2:\n            continue\n        # We only get the processes that run the specified command.\n        res = re.search(cmd_regex, p)\n        if res is not None and int(l[1]) != curr_pid:\n            pids.append(l[1])\n\n    pid_str = \",\".join([str(pid) for pid in pids])\n    ps_cmd = (\n        \"ssh -o StrictHostKeyChecking=no -p \"\n        + str(port)\n        + \" \"\n        + ip\n        + \" 'pgrep -P {}'\".format(pid_str)\n    )\n    res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)\n    pids1 = res.stdout.decode(\"utf-8\").split(\"\\n\")\n    all_pids = []\n    for pid in set(pids + pids1):\n        if pid == \"\" or int(pid) == curr_pid:\n            continue\n        all_pids.append(int(pid))\n    all_pids.sort()\n    return all_pids\n\n\ndef get_all_remote_pids(hosts, ssh_port, udf_command):\n    \"\"\"Get all remote processes.\"\"\"\n    remote_pids = {}\n    for node_id, host in enumerate(hosts):\n        ip, _ = host\n        # When creating training processes in remote machines, we may insert some arguments\n        # in the commands. 
We need to use regular expressions to match the modified command.\n        cmds = udf_command.split()\n        new_udf_command = \" .*\".join(cmds)\n        pids = get_remote_pids(ip, ssh_port, new_udf_command)\n        remote_pids[(ip, ssh_port)] = pids\n    return remote_pids\n\n\ndef construct_torch_dist_launcher_cmd(\n    num_trainers: int,\n    num_nodes: int,\n    node_rank: int,\n    master_addr: str,\n    master_port: int,\n) -> str:\n    \"\"\"Constructs the torch distributed launcher command.\n    Helper function.\n\n    Args:\n        num_trainers:\n        num_nodes:\n        node_rank:\n        master_addr:\n        master_port:\n\n    Returns:\n        cmd_str.\n    \"\"\"\n    torch_cmd_template = (\n        \"-m torch.distributed.run \"\n        \"--nproc_per_node={nproc_per_node} \"\n        \"--nnodes={nnodes} \"\n        \"--node_rank={node_rank} \"\n        \"--master_addr={master_addr} \"\n        \"--master_port={master_port}\"\n    )\n    return torch_cmd_template.format(\n        nproc_per_node=num_trainers,\n        nnodes=num_nodes,\n        node_rank=node_rank,\n        master_addr=master_addr,\n        master_port=master_port,\n    )\n\n\ndef wrap_udf_in_torch_dist_launcher(\n    udf_command: str,\n    num_trainers: int,\n    num_nodes: int,\n    node_rank: int,\n    master_addr: str,\n    master_port: int,\n) -> str:\n    \"\"\"Wraps the user-defined function (udf_command) with the torch.distributed.run module.\n\n     Example: if udf_command is \"python3 run/some/trainer.py arg1 arg2\", then new_df_command becomes:\n         \"python3 -m torch.distributed.run <TORCH DIST ARGS> run/some/trainer.py arg1 arg2\n\n    udf_command is assumed to consist of pre-commands (optional) followed by the python launcher script (required):\n    Examples:\n        # simple\n        python3.7 path/to/some/trainer.py arg1 arg2\n\n        # multi-commands\n        (cd some/dir && python3.7 path/to/some/trainer.py arg1 arg2)\n\n    IMPORTANT: If udf_command 
consists of multiple python commands, then this will result in undefined behavior.\n\n    Args:\n        udf_command:\n        num_trainers:\n        num_nodes:\n        node_rank:\n        master_addr:\n        master_port:\n\n    Returns:\n\n    \"\"\"\n    torch_dist_cmd = construct_torch_dist_launcher_cmd(\n        num_trainers=num_trainers,\n        num_nodes=num_nodes,\n        node_rank=node_rank,\n        master_addr=master_addr,\n        master_port=master_port,\n    )\n    # Auto-detect the python binary that kicks off the distributed trainer code.\n    # Note: This allowlist order matters, this will match with the FIRST matching entry. Thus, please add names to this\n    #       from most-specific to least-specific order eg:\n    #           (python3.7, python3.8) -> (python3)\n    # The allowed python versions are from this: https://www.dgl.ai/pages/start.html\n    python_bin_allowlist = (\n        \"python3.6\",\n        \"python3.7\",\n        \"python3.8\",\n        \"python3.9\",\n        \"python3\",\n        # for backwards compatibility, accept python2 but technically DGL is a py3 library, so this is not recommended\n        \"python2.7\",\n        \"python2\",\n    )\n    # If none of the candidate python bins match, then we go with the default `python`\n    python_bin = \"python\"\n    for candidate_python_bin in python_bin_allowlist:\n        if candidate_python_bin in udf_command:\n            python_bin = candidate_python_bin\n            break\n\n    # transforms the udf_command from:\n    #     python path/to/dist_trainer.py arg0 arg1\n    # to:\n    #     python -m torch.distributed.run [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1\n    # Note: if there are multiple python commands in `udf_command`, this may do the Wrong Thing, eg launch each\n    #       python command within the torch distributed launcher.\n    new_udf_command = udf_command.replace(\n        python_bin, f\"{python_bin} {torch_dist_cmd}\"\n    )\n\n    return 
new_udf_command\n\n\ndef construct_dgl_server_env_vars(\n    num_samplers: int,\n    num_server_threads: int,\n    tot_num_clients: int,\n    part_config: str,\n    ip_config: str,\n    num_servers: int,\n    graph_format: str,\n    pythonpath: Optional[str] = \"\",\n) -> str:\n    \"\"\"Constructs the DGL server-specific env vars string that are required for DGL code to behave in the correct\n    server role.\n    Convenience function.\n\n    Args:\n        num_samplers:\n        num_server_threads:\n        tot_num_clients:\n        part_config: Partition config.\n            Relative path to workspace.\n        ip_config: IP config file containing IP addresses of cluster hosts.\n            Relative path to workspace.\n        num_servers:\n        graph_format:\n        pythonpath: Optional. If given, this will pass this as PYTHONPATH.\n\n    Returns:\n        server_env_vars: The server-specific env-vars in a string format, friendly for CLI execution.\n\n    \"\"\"\n    server_env_vars_template = (\n        \"DGL_ROLE={DGL_ROLE} \"\n        \"DGL_NUM_SAMPLER={DGL_NUM_SAMPLER} \"\n        \"OMP_NUM_THREADS={OMP_NUM_THREADS} \"\n        \"DGL_NUM_CLIENT={DGL_NUM_CLIENT} \"\n        \"DGL_CONF_PATH={DGL_CONF_PATH} \"\n        \"DGL_IP_CONFIG={DGL_IP_CONFIG} \"\n        \"DGL_NUM_SERVER={DGL_NUM_SERVER} \"\n        \"DGL_GRAPH_FORMAT={DGL_GRAPH_FORMAT} \"\n        \"{suffix_optional_envvars}\"\n    )\n    suffix_optional_envvars = \"\"\n    if pythonpath:\n        suffix_optional_envvars += f\"PYTHONPATH={pythonpath} \"\n    return server_env_vars_template.format(\n        DGL_ROLE=\"server\",\n        DGL_NUM_SAMPLER=num_samplers,\n        OMP_NUM_THREADS=num_server_threads,\n        DGL_NUM_CLIENT=tot_num_clients,\n        DGL_CONF_PATH=part_config,\n        DGL_IP_CONFIG=ip_config,\n        DGL_NUM_SERVER=num_servers,\n        DGL_GRAPH_FORMAT=graph_format,\n        suffix_optional_envvars=suffix_optional_envvars,\n    )\n\n\ndef 
construct_dgl_client_env_vars(\n    num_samplers: int,\n    tot_num_clients: int,\n    part_config: str,\n    ip_config: str,\n    num_servers: int,\n    graph_format: str,\n    num_omp_threads: int,\n    group_id: int,\n    pythonpath: Optional[str] = \"\",\n) -> str:\n    \"\"\"Constructs the DGL client-specific env vars string that are required for DGL code to behave in the correct\n    client role.\n    Convenience function.\n\n    Args:\n        num_samplers:\n        tot_num_clients:\n        part_config: Partition config.\n            Relative path to workspace.\n        ip_config: IP config file containing IP addresses of cluster hosts.\n            Relative path to workspace.\n        num_servers:\n        graph_format:\n        num_omp_threads:\n        group_id:\n            Used in client processes to indicate which group it belongs to.\n        pythonpath: Optional. If given, this will pass this as PYTHONPATH.\n\n    Returns:\n        client_env_vars: The client-specific env-vars in a string format, friendly for CLI execution.\n\n    \"\"\"\n    client_env_vars_template = (\n        \"DGL_DIST_MODE={DGL_DIST_MODE} \"\n        \"DGL_ROLE={DGL_ROLE} \"\n        \"DGL_NUM_SAMPLER={DGL_NUM_SAMPLER} \"\n        \"DGL_NUM_CLIENT={DGL_NUM_CLIENT} \"\n        \"DGL_CONF_PATH={DGL_CONF_PATH} \"\n        \"DGL_IP_CONFIG={DGL_IP_CONFIG} \"\n        \"DGL_NUM_SERVER={DGL_NUM_SERVER} \"\n        \"DGL_GRAPH_FORMAT={DGL_GRAPH_FORMAT} \"\n        \"OMP_NUM_THREADS={OMP_NUM_THREADS} \"\n        \"DGL_GROUP_ID={DGL_GROUP_ID} \"\n        \"{suffix_optional_envvars}\"\n    )\n    # append optional additional env-vars\n    suffix_optional_envvars = \"\"\n    if pythonpath:\n        suffix_optional_envvars += f\"PYTHONPATH={pythonpath} \"\n    return client_env_vars_template.format(\n        DGL_DIST_MODE=\"distributed\",\n        DGL_ROLE=\"client\",\n        DGL_NUM_SAMPLER=num_samplers,\n        DGL_NUM_CLIENT=tot_num_clients,\n        DGL_CONF_PATH=part_config,\n       
 DGL_IP_CONFIG=ip_config,\n        DGL_NUM_SERVER=num_servers,\n        DGL_GRAPH_FORMAT=graph_format,\n        OMP_NUM_THREADS=num_omp_threads,\n        DGL_GROUP_ID=group_id,\n        suffix_optional_envvars=suffix_optional_envvars,\n    )\n\n\ndef wrap_cmd_with_local_envvars(cmd: str, env_vars: str) -> str:\n    \"\"\"Wraps a CLI command with desired env vars with the following properties:\n    (1) env vars persist for the entire `cmd`, even if it consists of multiple \"chained\" commands like:\n        cmd = \"ls && pwd && python run/something.py\"\n    (2) env vars don't pollute the environment after `cmd` completes.\n\n    Example:\n        >>> cmd = \"ls && pwd\"\n        >>> env_vars = \"VAR1=value1 VAR2=value2\"\n        >>> wrap_cmd_with_local_envvars(cmd, env_vars)\n        \"(export VAR1=value1 VAR2=value2; ls && pwd)\"\n\n    Args:\n        cmd:\n        env_vars: A string containing env vars, eg \"VAR1=val1 VAR2=val2\"\n\n    Returns:\n        cmd_with_env_vars:\n\n    \"\"\"\n    # use `export` to persist env vars for entire cmd block. 
required if udf_command is a chain of commands\n    # also: wrap in parens to not pollute env:\n    #     https://stackoverflow.com/a/45993803\n    return f\"(export {env_vars}; {cmd})\"\n\n\ndef wrap_cmd_with_extra_envvars(cmd: str, env_vars: list) -> str:\n    \"\"\"Wraps a CLI command with extra env vars\n\n    Example:\n        >>> cmd = \"ls && pwd\"\n        >>> env_vars = [\"VAR1=value1\", \"VAR2=value2\"]\n        >>> wrap_cmd_with_extra_envvars(cmd, env_vars)\n        \"(export VAR1=value1 VAR2=value2; ls && pwd)\"\n\n    Args:\n        cmd:\n        env_vars: A list of strings containing env vars, e.g., [\"VAR1=value1\", \"VAR2=value2\"]\n\n    Returns:\n        cmd_with_env_vars:\n    \"\"\"\n    env_vars = \" \".join(env_vars)\n    return wrap_cmd_with_local_envvars(cmd, env_vars)\n\n\ndef get_available_port(ip):\n    \"\"\"Get available port with specified ip.\"\"\"\n    import socket\n\n    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n    for port in range(1234, 65535):\n        try:\n            sock.connect((ip, port))\n        except:\n            return port\n    raise RuntimeError(\"Failed to get available port for ip~{}\".format(ip))\n\n\ndef submit_jobs(args, udf_command, dry_run=False):\n    \"\"\"Submit distributed jobs (server and client processes) via ssh\"\"\"\n    if dry_run:\n        print(\n            \"Currently it's in dry run mode which means no jobs will be launched.\"\n        )\n    servers_cmd = []\n    clients_cmd = []\n    hosts = []\n    thread_list = []\n    server_count_per_machine = 0\n\n    # Get the IP addresses of the cluster.\n    ip_config = os.path.join(args.workspace, args.ip_config)\n    with open(ip_config) as f:\n        for line in f:\n            result = line.strip().split()\n            if len(result) == 2:\n                ip = result[0]\n                port = int(result[1])\n                hosts.append((ip, port))\n            elif len(result) == 1:\n                ip = result[0]\n           
     port = get_available_port(ip)\n                hosts.append((ip, port))\n            else:\n                raise RuntimeError(\"Format error of ip_config.\")\n            server_count_per_machine = args.num_servers\n    # Get partition info of the graph data\n    part_config = os.path.join(args.workspace, args.part_config)\n    with open(part_config) as conf_f:\n        part_metadata = json.load(conf_f)\n    assert \"num_parts\" in part_metadata, \"num_parts does not exist.\"\n    # The number of partitions must match the number of machines in the cluster.\n    assert part_metadata[\"num_parts\"] == len(\n        hosts\n    ), \"The number of graph partitions has to match the number of machines in the cluster.\"\n\n    state_q = queue.Queue()\n    tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts)\n    # launch server tasks\n    server_env_vars = construct_dgl_server_env_vars(\n        num_samplers=args.num_samplers,\n        num_server_threads=args.num_server_threads,\n        tot_num_clients=tot_num_clients,\n        part_config=args.part_config,\n        ip_config=args.ip_config,\n        num_servers=args.num_servers,\n        graph_format=args.graph_format,\n        pythonpath=os.environ.get(\"PYTHONPATH\", \"\"),\n    )\n    for i in range(len(hosts) * server_count_per_machine):\n        ip, _ = hosts[int(i / server_count_per_machine)]\n        server_env_vars_cur = f\"{server_env_vars} DGL_SERVER_ID={i}\"\n        cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur)\n        cmd = (\n            wrap_cmd_with_extra_envvars(cmd, args.extra_envs)\n            if len(args.extra_envs) > 0\n            else cmd\n        )\n        cmd = \"cd \" + str(args.workspace) + \"; \" + cmd\n        servers_cmd.append(cmd)\n        if not dry_run:\n            thread_list.append(\n                execute_remote(\n                    cmd,\n                    state_q,\n                    ip,\n                    args.ssh_port,\n  
                  username=args.ssh_username,\n                )\n            )\n\n    # launch client tasks\n    client_env_vars = construct_dgl_client_env_vars(\n        num_samplers=args.num_samplers,\n        tot_num_clients=tot_num_clients,\n        part_config=args.part_config,\n        ip_config=args.ip_config,\n        num_servers=args.num_servers,\n        graph_format=args.graph_format,\n        num_omp_threads=os.environ.get(\n            \"OMP_NUM_THREADS\", str(args.num_omp_threads)\n        ),\n        group_id=0,\n        pythonpath=os.environ.get(\"PYTHONPATH\", \"\"),\n    )\n\n    master_addr = hosts[0][0]\n    master_port = get_available_port(master_addr)\n    for node_id, host in enumerate(hosts):\n        ip, _ = host\n        # Transform udf_command to follow torch's dist launcher format: `PYTHON_BIN -m torch.distributed.run ... UDF`\n        torch_dist_udf_command = wrap_udf_in_torch_dist_launcher(\n            udf_command=udf_command,\n            num_trainers=args.num_trainers,\n            num_nodes=len(hosts),\n            node_rank=node_id,\n            master_addr=master_addr,\n            master_port=master_port,\n        )\n        cmd = wrap_cmd_with_local_envvars(\n            torch_dist_udf_command, client_env_vars\n        )\n        cmd = (\n            wrap_cmd_with_extra_envvars(cmd, args.extra_envs)\n            if len(args.extra_envs) > 0\n            else cmd\n        )\n        cmd = \"cd \" + str(args.workspace) + \"; \" + cmd\n        clients_cmd.append(cmd)\n        if not dry_run:\n            thread_list.append(\n                execute_remote(\n                    cmd, state_q, ip, args.ssh_port, username=args.ssh_username\n                )\n            )\n\n    # return commands of clients/servers directly if in dry run mode\n    if dry_run:\n        return clients_cmd, servers_cmd\n\n    # Start a cleanup process dedicated for cleaning up remote training jobs.\n    conn1, conn2 = multiprocessing.Pipe()\n    func = 
partial(get_all_remote_pids, hosts, args.ssh_port, udf_command)\n    process = multiprocessing.Process(target=cleanup_proc, args=(func, conn1))\n    process.start()\n\n    def signal_handler(signal, frame):\n        logging.info(\"Stop launcher\")\n        # We need to tell the cleanup process to kill remote training jobs.\n        conn2.send(\"cleanup\")\n        sys.exit(0)\n\n    signal.signal(signal.SIGINT, signal_handler)\n\n    err = 0\n    for thread in thread_list:\n        thread.join()\n        err_code = state_q.get()\n        if err_code != 0:\n            # Record err_code\n            # We record one of the error if there are multiple\n            err = err_code\n\n    # The training processes complete. We should tell the cleanup process to exit.\n    conn2.send(\"exit\")\n    process.join()\n    if err != 0:\n        print(\"Task failed\")\n        sys.exit(-1)\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Launch a distributed job\")\n    parser.add_argument(\"--ssh_port\", type=int, default=22, help=\"SSH Port.\")\n    parser.add_argument(\n        \"--ssh_username\",\n        default=\"\",\n        help=\"Optional. When issuing commands (via ssh) to cluster, use the provided username in the ssh cmd. \"\n        \"Example: If you provide --ssh_username=bob, then the ssh command will be like: 'ssh bob@1.2.3.4 CMD' \"\n        \"instead of 'ssh 1.2.3.4 CMD'\",\n    )\n    parser.add_argument(\n        \"--workspace\",\n        type=str,\n        help=\"Path of user directory of distributed tasks. 
\\\n                        This is used to specify a destination location where \\\n                        the contents of current directory will be rsyncd\",\n    )\n    parser.add_argument(\n        \"--num_trainers\",\n        type=int,\n        help=\"The number of trainer processes per machine\",\n    )\n    parser.add_argument(\n        \"--num_omp_threads\",\n        type=int,\n        help=\"The number of OMP threads per trainer\",\n    )\n    parser.add_argument(\n        \"--num_samplers\",\n        type=int,\n        default=0,\n        help=\"The number of sampler processes per trainer process\",\n    )\n    parser.add_argument(\n        \"--num_servers\",\n        type=int,\n        help=\"The number of server processes per machine\",\n    )\n    parser.add_argument(\n        \"--part_config\",\n        type=str,\n        help=\"The file (in workspace) of the partition config\",\n    )\n    parser.add_argument(\n        \"--ip_config\",\n        type=str,\n        help=\"The file (in workspace) of IP configuration for server processes\",\n    )\n    parser.add_argument(\n        \"--num_server_threads\",\n        type=int,\n        default=1,\n        help=\"The number of OMP threads in the server process. \\\n                        It should be small if server processes and trainer processes run on \\\n                        the same machine. By default, it is 1.\",\n    )\n    parser.add_argument(\n        \"--graph_format\",\n        type=str,\n        default=\"csc\",\n        help='The format of the graph structure of each partition. \\\n                        The allowed formats are csr, csc and coo. A user can specify multiple \\\n                        formats, separated by \",\". For example, the graph format is \"csr,csc\".',\n    )\n    parser.add_argument(\n        \"--extra_envs\",\n        nargs=\"+\",\n        type=str,\n        default=[],\n        help=\"Extra environment parameters need to be set. 
For example, \\\n                        you can set the LD_LIBRARY_PATH and NCCL_DEBUG by adding: \\\n                        --extra_envs LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH NCCL_DEBUG=INFO \",\n    )\n    args, udf_command = parser.parse_known_args()\n    assert len(udf_command) == 1, \"Please provide user command line.\"\n    assert (\n        args.num_trainers is not None and args.num_trainers > 0\n    ), \"--num_trainers must be a positive number.\"\n    assert (\n        args.num_samplers is not None and args.num_samplers >= 0\n    ), \"--num_samplers must be a non-negative number.\"\n    assert (\n        args.num_servers is not None and args.num_servers > 0\n    ), \"--num_servers must be a positive number.\"\n    assert (\n        args.num_server_threads > 0\n    ), \"--num_server_threads must be a positive number.\"\n    assert (\n        args.workspace is not None\n    ), \"A user has to specify a workspace with --workspace.\"\n    assert (\n        args.part_config is not None\n    ), \"A user has to specify a partition configuration file with --part_config.\"\n    assert (\n        args.ip_config is not None\n    ), \"A user has to specify an IP configuration file with --ip_config.\"\n    if args.num_omp_threads is None:\n        # Here we assume all machines have the same number of CPU cores as the machine\n        # where the launch script runs.\n        args.num_omp_threads = max(\n            multiprocessing.cpu_count() // 2 // args.num_trainers, 1\n        )\n        print(\n            \"The number of OMP threads per trainer is set to\",\n            args.num_omp_threads,\n        )\n\n    udf_command = str(udf_command[0])\n    if \"python\" not in udf_command:\n        raise RuntimeError(\n            \"DGL launching script can only support Python executable file.\"\n        )\n    submit_jobs(args, udf_command)\n\n\nif __name__ == \"__main__\":\n    fmt = \"%(asctime)s %(levelname)s %(message)s\"\n    
logging.basicConfig(format=fmt, level=logging.INFO)\n    main()\n"
  },
  {
    "path": "tools/partition_algo/base.py",
    "content": "import json\nfrom typing import Optional\n\nimport pydantic as dt\nfrom dgl import DGLError\n\n\nclass PartitionMeta(dt.BaseModel):\n    \"\"\"Metadata that describes the partition assignment results.\n\n    Regardless of the choice of partitioning algorithm, a metadata JSON file\n    will be created in the output directory which includes the meta information\n    of the partition algorithm.\n\n    To generate a metadata JSON:\n\n    >>> part_meta = PartitionMeta(version='1.0.0', num_parts=4, algo_name='random')\n    >>> with open('metadata.json', 'w') as f:\n    ...     json.dump(part_meta.dict(), f)\n\n    To read a metadata JSON:\n\n    >>> with open('metadata.json') as f:\n    ...     part_meta = PartitionMeta(**(json.load(f)))\n\n    \"\"\"\n\n    # version of metadata JSON.\n    version: Optional[str] = \"1.0.0\"\n    # number of partitions.\n    num_parts: int\n    # name of partition algorithm.\n    algo_name: str\n\n\ndef dump_partition_meta(part_meta, meta_file):\n    \"\"\"Dump partition metadata into json file.\n\n    Parameters\n    ----------\n    part_meta : PartitionMeta\n        The partition metadata.\n    meta_file : str\n        The target file to save data.\n    \"\"\"\n    with open(meta_file, \"w\") as f:\n        json.dump(part_meta.dict(), f, sort_keys=True, indent=4)\n\n\ndef load_partition_meta(meta_file):\n    \"\"\"Load partition metadata and do sanity check.\n\n    Parameters\n    ----------\n    meta_file : str\n        The path of the partition metadata file.\n\n    Returns\n    -------\n    PartitionMeta\n        The partition metadata.\n    \"\"\"\n    with open(meta_file) as f:\n        try:\n            part_meta = PartitionMeta(**(json.load(f)))\n        except dt.ValidationError as e:\n            raise DGLError(\n                f\"Invalid partition metadata JSON. 
Error details: {e.json()}\"\n            )\n        if part_meta.version != \"1.0.0\":\n            raise DGLError(\n                f\"Invalid version[{part_meta.version}]. Supported versions: '1.0.0'\"\n            )\n        if part_meta.num_parts <= 0:\n            raise DGLError(\n                f\"num_parts[{part_meta.num_parts}] should be greater than 0.\"\n            )\n        if part_meta.algo_name not in [\"random\", \"metis\"]:\n            raise DGLError(\n                f\"algo_name[{part_meta.algo_name}] is not supported.\"\n            )\n        return part_meta\n"
  },
  {
    "path": "tools/partition_algo/random_partition.py",
    "content": "# Requires setting PYTHONPATH=${GITROOT}/tools\nimport argparse\nimport json\nimport logging\nimport os\n\nimport numpy as np\nfrom base import dump_partition_meta, PartitionMeta\nfrom distpartitioning import array_readwriter\nfrom files import setdir\n\n\ndef _random_partition(metadata, num_parts):\n    num_nodes_per_type = metadata[\"num_nodes_per_type\"]\n    ntypes = metadata[\"node_type\"]\n    for ntype, n in zip(ntypes, num_nodes_per_type):\n        logging.info(\"Generating partition for node type %s\" % ntype)\n        parts = np.random.randint(0, num_parts, (n,))\n        array_readwriter.get_array_parser(name=\"csv\").write(\n            ntype + \".txt\", parts\n        )\n\n\ndef random_partition(metadata, num_parts, output_path):\n    \"\"\"\n    Randomly partition the graph described in metadata and generate partition ID mapping\n    in :attr:`output_path`.\n\n    A directory will be created at :attr:`output_path` containing the partition ID\n    mapping files named \"<node-type>.txt\" (e.g. \"author.txt\", \"paper.txt\" and\n    \"institution.txt\" for OGB-MAG240M).  
Each file contains one line per node representing\n    the partition ID the node belongs to.\n    In addition, metadata which includes version, number of partitions is dumped.\n    \"\"\"\n    with setdir(output_path):\n        _random_partition(metadata, num_parts)\n        part_meta = PartitionMeta(\n            version=\"1.0.0\", num_parts=num_parts, algo_name=\"random\"\n        )\n        dump_partition_meta(part_meta, \"partition_meta.json\")\n\n\n# Run with PYTHONPATH=${GIT_ROOT_DIR}/tools\n# where ${GIT_ROOT_DIR} is the directory to the DGL git repository.\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--in_dir\",\n        type=str,\n        help=\"input directory that contains the metadata file\",\n    )\n    parser.add_argument(\"--out_dir\", type=str, help=\"output directory\")\n    parser.add_argument(\n        \"--num_partitions\", type=int, help=\"number of partitions\"\n    )\n    logging.basicConfig(level=\"INFO\")\n    args = parser.parse_args()\n    with open(os.path.join(args.in_dir, \"metadata.json\")) as f:\n        metadata = json.load(f)\n    num_parts = args.num_partitions\n    random_partition(metadata, num_parts, args.out_dir)\n"
  },
  {
    "path": "tools/verification_utils.py",
    "content": "import json\nimport os\n\nimport constants\n\nimport dgl\n\nimport numpy as np\nimport pyarrow\nimport pyarrow.parquet as pq\nimport pytest\nimport torch\nfrom dgl.data.utils import load_tensors\nfrom dgl.distributed.partition import (\n    _etype_str_to_tuple,\n    _etype_tuple_to_str,\n    _get_inner_edge_mask,\n    _get_inner_node_mask,\n    RESERVED_FIELD_DTYPE,\n)\nfrom distpartitioning.utils import get_idranges\n\n\ndef read_file(fname, ftype):\n    \"\"\"Read a file from disk\n    Parameters:\n    -----------\n    fname : string\n        specifying the absolute path to the file to read\n    ftype : string\n        supported formats are `numpy`, `parquet', `csv`\n\n    Returns:\n    --------\n    numpy ndarray :\n        file contents are returned as numpy array\n    \"\"\"\n    reader_fmt_meta = {\"name\": ftype}\n    array_readwriter.get_array_parser(**reader_fmt_meta).read(fname)\n\n    return data\n\n\ndef verify_partition_data_types(part_g):\n    \"\"\"Validate the dtypes in the partitioned graphs are valid\n\n    Parameters:\n    -----------\n    part_g : DGL Graph object\n        created for the partitioned graphs\n    \"\"\"\n    for k, dtype in RESERVED_FIELD_DTYPE.items():\n        if k in part_g.ndata:\n            assert part_g.ndata[k].dtype == dtype\n        if k in part_g.edata:\n            assert part_g.edata[k].dtype == dtype\n\n\ndef verify_partition_formats(part_g, formats):\n    \"\"\"Validate the partitioned graphs with supported formats\n\n    Parameters:\n    -----------\n    part_g : DGL Graph object\n        created for the partitioned graphs\n    formats : string\n        formats(csc, coo, csr) supported formats and multiple\n        values can be seperated by comma\n    \"\"\"\n    # Verify saved graph formats\n    if formats is None:\n        assert \"coo\" in part_g.formats()[\"created\"]\n    else:\n        formats = formats.split(\",\")\n        for format in formats:\n            assert format in 
part_g.formats()[\"created\"]\n\n\ndef verify_graph_feats(\n    g, gpb, part, node_feats, edge_feats, orig_nids, orig_eids\n):\n    \"\"\"Verify the node/edge features of the partitioned graph with\n    the original graph\n\n    Parameters:\n    -----------\n    g : DGL Graph Object\n        of the original graph\n    gpb : global partition book\n        created for the partitioned graph object\n    node_feats : dictionary\n        with key, value pairs as node-types and features as numpy arrays\n    edge_feats : dictionary\n        with key, value pairs as edge-types and features as numpy arrays\n    orig_nids : dictionary\n        with key, value pairs as node-types and (global) nids from the\n        original graph\n    orig_eids : dictionary\n        with key, value pairs as edge-types and (global) eids from the\n        original graph\n    \"\"\"\n    for ntype in g.ntypes:\n        ntype_id = g.get_ntype_id(ntype)\n        inner_node_mask = _get_inner_node_mask(part, ntype_id)\n        inner_nids = part.ndata[dgl.NID][inner_node_mask]\n        ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids)\n        partid = gpb.nid2partid(inner_type_nids, ntype)\n        assert np.all(ntype_ids.numpy() == ntype_id)\n        assert np.all(partid.numpy() == gpb.partid)\n\n        orig_id = orig_nids[ntype][inner_type_nids]\n        local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype)\n\n        for name in g.nodes[ntype].data:\n            if name in [dgl.NID, \"inner_node\"]:\n                continue\n            true_feats = g.nodes[ntype].data[name][orig_id]\n            ndata = node_feats[ntype + \"/\" + name][local_nids]\n            assert np.array_equal(ndata.numpy(), true_feats.numpy())\n\n    for etype in g.canonical_etypes:\n        etype_id = g.get_etype_id(etype)\n        inner_edge_mask = _get_inner_edge_mask(part, etype_id)\n        inner_eids = part.edata[dgl.EID][inner_edge_mask]\n        etype_ids, inner_type_eids = 
gpb.map_to_per_etype(inner_eids)\n        partid = gpb.eid2partid(inner_type_eids, etype)\n        assert np.all(etype_ids.numpy() == etype_id)\n        assert np.all(partid.numpy() == gpb.partid)\n\n        orig_id = orig_eids[_etype_tuple_to_str(etype)][inner_type_eids]\n        local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype)\n\n        for name in g.edges[etype].data:\n            if name in [dgl.EID, \"inner_edge\"]:\n                continue\n            true_feats = g.edges[etype].data[name][orig_id]\n            edata = edge_feats[_etype_tuple_to_str(etype) + \"/\" + name][\n                local_eids\n            ]\n            assert np.array_equal(edata.numpy(), true_feats.numpy())\n\n\ndef verify_metadata_counts(part_schema, part_g, graph_schema, g, partid):\n    \"\"\"Verify the partitioned graph objects with the metadata\n\n    Parameters:\n    -----------\n    part_schema : json object\n        which is created by reading the metadata.json file for the\n        partitioned graph\n    part_g : DGL graph object\n        of a graph partition\n    graph_schema : json object\n        which is created by reading the metadata.json file for the\n        original graph\n    g : DGL Graph object\n        created by reading the original graph from the disk.\n    partid : integer\n        specifying the partition id of the graph object, part_g\n    \"\"\"\n    for ntype in part_schema[constants.STR_NTYPES]:\n        ntype_data = part_schema[constants.STR_NODE_MAP][ntype]\n        meta_ntype_count = ntype_data[partid][1] - ntype_data[partid][0]\n        inner_node_mask = _get_inner_node_mask(part_g, g.get_ntype_id(ntype))\n        graph_ntype_count = len(part_g.ndata[dgl.NID][inner_node_mask])\n        assert (\n            meta_ntype_count == graph_ntype_count\n        ), f\"Metadata ntypecount = {meta_ntype_count} and graph_ntype_count = {graph_ntype_count}\"\n\n    for etype in part_schema[constants.STR_ETYPES]:\n        etype_data = 
part_schema[constants.STR_EDGE_MAP][etype]\n        meta_etype_count = etype_data[partid][1] - etype_data[partid][0]\n        mask = _get_inner_edge_mask(\n            part_g, g.get_etype_id(_etype_str_to_tuple(etype))\n        )\n        graph_etype_count = len(part_g.edata[dgl.EID][mask])\n        assert (\n            meta_etype_count == graph_etype_count\n        ), f\"Metadata etypecount = {meta_etype_count} does not match part graph etypecount = {graph_etype_count}\"\n\n\ndef get_node_partids(partitions_dir, graph_schema):\n    \"\"\"load the node partition ids from the disk\n\n    Parameters:\n    ----------\n    partitions_dir : string\n        directory path where metis/random partitions are located\n    graph_schema : json object\n        which is created by reading the metadata.json file for the\n        original graph\n\n    Returns:\n    --------\n    dictionary :\n        where keys are node-types and value is a list of partition-ids for all the\n        nodes of that particular node-type.\n    \"\"\"\n    assert os.path.isdir(\n        partitions_dir\n    ), f\"Please provide a valid directory to read nodes to partition-id mappings.\"\n    _, gid_dict = get_idranges(\n        graph_schema[constants.STR_NODE_TYPE],\n        dict(\n            zip(\n                graph_schema[constants.STR_NODE_TYPE],\n                graph_schema[constants.STR_NODE_TYPE_COUNTS],\n            )\n        ),\n    )\n    node_partids = {}\n    for ntype_id, ntype in enumerate(graph_schema[constants.STR_NODE_TYPE]):\n        node_partids[ntype] = read_file(\n            os.path.join(partitions_dir, f\"{ntype}.txt\"), constants.STR_CSV\n        )\n        assert (\n            len(node_partids[ntype])\n            == graph_schema[constants.STR_NODE_TYPE_COUNTS][ntype_id]\n        ), f\"Node count for {ntype} = {len(node_partids[ntype])} in the partitions_dir while it should be {graph_schema[constants.STR_NTYPE_COUNTS][ntype_id]} (from graph schema).\"\n\n    return 
node_partids\n\n\ndef verify_node_partitionids(\n    node_partids, part_g, g, gpb, graph_schema, orig_nids, partition_id\n):\n    \"\"\"Verify the node to partition-id assignments of a partitioned graph object\n\n    Parameters:\n    -----------\n    node_partids : dictionary\n        where keys are node-types and values are the partition-ids of the\n        nodes of that node-type\n    part_g : DGL graph object\n        of a graph partition\n    g : DGL Graph object\n        created by reading the original graph from disk\n    gpb : object\n        providing map_to_per_ntype, nid2partid and partid for this\n        partition (presumably the graph partition book)\n    graph_schema : json object\n        created by reading the metadata.json file for the original graph\n    orig_nids : dictionary\n        which contains the original (global) node-ids\n    partition_id : integer\n        partition id of the partitioned graph, part_g\n    \"\"\"\n    # read part graphs and verify the counts\n    # inner node masks, should give the node counts in each part-g and get the corresponding orig-ids to map to the original graph node-ids\n    for ntype_id, ntype in enumerate(graph_schema[constants.STR_NODE_TYPE]):\n        mask = _get_inner_node_mask(part_g, g.get_ntype_id(ntype))\n\n        # map these to orig-nids.\n        inner_nids = part_g.ndata[dgl.NID][mask]\n        ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids)\n        partid = gpb.nid2partid(inner_type_nids, ntype)\n\n        assert np.all(ntype_ids.numpy() == ntype_id)\n        assert np.all(partid.numpy() == gpb.partid)\n\n        idxes = orig_nids[ntype][inner_type_nids]\n        assert np.all(idxes >= 0)\n\n        # get the partition-ids for these nodes.\n        assert np.all(\n            node_partids[ntype][idxes] == partition_id\n        ), f\"All the nodes in the partition = {partid} does not their nodeid to partition-id maps are defined by the partitioning algorithm. 
Node-type = {ntype}\"\n\n\ndef read_orig_ids(out_dir, fname, num_parts):\n    \"\"\"Read original id files for the partitioned graph objects\n\n    Parameters:\n    -----------\n    out_dir : string\n        specifying the directory where the files are located\n    fname : string\n        file name to read from\n    num_parts : integer\n        no. of partitions\n\n    Returns:\n    --------\n    dictionary :\n        where keys are node/edge types and values are original node\n        or edge ids from the original graph\n    \"\"\"\n    orig_ids = {}\n    for i in range(num_parts):\n        ids_path = os.path.join(out_dir, f\"part{i}\", fname)\n        part_ids = load_tensors(ids_path)\n        for type, data in part_ids.items():\n            if type not in orig_ids:\n                orig_ids[type] = data.numpy()\n            else:\n                orig_ids[type] = np.concatenate((orig_ids[type], data))\n    return orig_ids\n"
  },
  {
    "path": "tutorials/blitz/.gitignore",
    "content": "*.dgl\n*.csv\n"
  },
  {
    "path": "tutorials/blitz/1_introduction.py",
    "content": "\"\"\"\nNode Classification with DGL\n============================\n\nGNNs are powerful tools for many machine learning tasks on graphs. In\nthis introductory tutorial, you will learn the basic workflow of using\nGNNs for node classification, i.e. predicting the category of a node in\na graph.\n\nBy completing this tutorial, you will be able to\n\n-  Load a DGL-provided dataset.\n-  Build a GNN model with DGL-provided neural network modules.\n-  Train and evaluate a GNN model for node classification on either CPU\n   or GPU.\n\nThis tutorial assumes that you have experience in building neural\nnetworks with PyTorch.\n\n(Time estimate: 13 minutes)\n\n\"\"\"\n\nimport os\n\nos.environ[\"DGLBACKEND\"] = \"pytorch\"\nimport dgl\nimport dgl.data\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n######################################################################\n# Overview of Node Classification with GNN\n# ----------------------------------------\n#\n# One of the most popular and widely adopted tasks on graph data is node\n# classification, where a model needs to predict the ground truth category\n# of each node. Before graph neural networks, many proposed methods are\n# using either connectivity alone (such as DeepWalk or node2vec), or simple\n# combinations of connectivity and the node's own features.  GNNs, by\n# contrast, offers an opportunity to obtain node representations by\n# combining the connectivity and features of a *local neighborhood*.\n#\n# `Kipf et\n# al., <https://arxiv.org/abs/1609.02907>`__ is an example that formulates\n# the node classification problem as a semi-supervised node classification\n# task. 
With the help of only a small portion of labeled nodes, a graph\n# neural network (GNN) can accurately predict the node category of the\n# others.\n#\n# This tutorial will show how to build such a GNN for semi-supervised node\n# classification with only a small number of labels on the Cora\n# dataset,\n# a citation network with papers as nodes and citations as edges. The task\n# is to predict the category of a given paper. Each paper node contains a\n# word count vector as its features, normalized so that they sum up to one,\n# as described in Section 5.2 of\n# `the paper <https://arxiv.org/abs/1609.02907>`__.\n#\n# Loading Cora Dataset\n# --------------------\n#\n\n\ndataset = dgl.data.CoraGraphDataset()\nprint(f\"Number of categories: {dataset.num_classes}\")\n\n\n######################################################################\n# A DGL Dataset object may contain one or multiple graphs. The Cora\n# dataset used in this tutorial only consists of one single graph.\n#\n\ng = dataset[0]\n\n\n######################################################################\n# A DGL graph can store node features and edge features in two\n# dictionary-like attributes called ``ndata`` and ``edata``.\n# In the DGL Cora dataset, the graph contains the following node features:\n#\n# - ``train_mask``: A boolean tensor indicating whether the node is in the\n#   training set.\n#\n# - ``val_mask``: A boolean tensor indicating whether the node is in the\n#   validation set.\n#\n# - ``test_mask``: A boolean tensor indicating whether the node is in the\n#   test set.\n#\n# - ``label``: The ground truth node category.\n#\n# -  ``feat``: The node features.\n#\n\nprint(\"Node features\")\nprint(g.ndata)\nprint(\"Edge features\")\nprint(g.edata)\n\n\n######################################################################\n# Defining a Graph Convolutional Network (GCN)\n# --------------------------------------------\n#\n# This tutorial will build a two-layer `Graph Convolutional Network\n# 
(GCN) <http://tkipf.github.io/graph-convolutional-networks/>`__. Each\n# layer computes new node representations by aggregating neighbor\n# information.\n#\n# To build a multi-layer GCN you can simply stack ``dgl.nn.GraphConv``\n# modules, which inherit ``torch.nn.Module``.\n#\n\nfrom dgl.nn import GraphConv\n\n\nclass GCN(nn.Module):\n    def __init__(self, in_feats, h_feats, num_classes):\n        super(GCN, self).__init__()\n        self.conv1 = GraphConv(in_feats, h_feats)\n        self.conv2 = GraphConv(h_feats, num_classes)\n\n    def forward(self, g, in_feat):\n        h = self.conv1(g, in_feat)\n        h = F.relu(h)\n        h = self.conv2(g, h)\n        return h\n\n\n# Create the model with given dimensions\nmodel = GCN(g.ndata[\"feat\"].shape[1], 16, dataset.num_classes)\n\n\n######################################################################\n# DGL provides implementation of many popular neighbor aggregation\n# modules. You can easily invoke them with one line of code.\n#\n\n\n######################################################################\n# Training the GCN\n# ----------------\n#\n# Training this GCN is similar to training other PyTorch neural networks.\n#\n\n\ndef train(g, model):\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n    best_val_acc = 0\n    best_test_acc = 0\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    for e in range(100):\n        # Forward\n        logits = model(g, features)\n\n        # Compute prediction\n        pred = logits.argmax(1)\n\n        # Compute loss\n        # Note that you should only compute the losses of the nodes in the training set.\n        loss = F.cross_entropy(logits[train_mask], labels[train_mask])\n\n        # Compute accuracy on training/validation/test\n        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()\n       
 val_acc = (pred[val_mask] == labels[val_mask]).float().mean()\n        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()\n\n        # Save the best validation accuracy and the corresponding test accuracy.\n        if best_val_acc < val_acc:\n            best_val_acc = val_acc\n            best_test_acc = test_acc\n\n        # Backward\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if e % 5 == 0:\n            print(\n                f\"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})\"\n            )\n\n\nmodel = GCN(g.ndata[\"feat\"].shape[1], 16, dataset.num_classes)\ntrain(g, model)\n\n\n######################################################################\n# Training on GPU\n# ---------------\n#\n# Training on GPU requires to put both the model and the graph onto GPU\n# with the ``to`` method, similar to what you will do in PyTorch.\n#\n# .. code:: python\n#\n#    g = g.to('cuda')\n#    model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')\n#    train(g, model)\n#\n\n\n######################################################################\n# What’s next?\n# ------------\n#\n# -  :doc:`How does DGL represent a graph <2_dglgraph>`?\n# -  :doc:`Write your own GNN module <3_message_passing>`.\n# -  :doc:`Link prediction (predicting existence of edges) on full\n#    graph <4_link_predict>`.\n# -  :doc:`Graph classification <5_graph_classification>`.\n# -  :doc:`Make your own dataset <6_load_data>`.\n# -  :ref:`The list of supported graph convolution\n#    modules <apinn-pytorch>`.\n# -  :ref:`The list of datasets provided by DGL <apidata>`.\n#\n\n\n# Thumbnail credits: Stanford CS224W Notes\n# sphinx_gallery_thumbnail_path = '_static/blitz_1_introduction.png'\n"
  },
  {
    "path": "tutorials/blitz/2_dglgraph.py",
    "content": "\"\"\"\nHow Does DGL Represent A Graph?\n===============================\n\nBy the end of this tutorial you will be able to:\n\n-  Construct a graph in DGL from scratch.\n-  Assign node and edge features to a graph.\n-  Query properties of a DGL graph such as node degrees and\n   connectivity.\n-  Transform a DGL graph into another graph.\n-  Load and save DGL graphs.\n\n(Time estimate: 16 minutes)\n\n\"\"\"\n\n\n######################################################################\n# DGL Graph Construction\n# ----------------------\n#\n# DGL represents a directed graph as a ``DGLGraph`` object. You can\n# construct a graph by specifying the number of nodes in the graph as well\n# as the list of source and destination nodes.  Nodes in the graph have\n# consecutive IDs starting from 0.\n#\n# For instance, the following code constructs a directed star graph with 5\n# leaves. The center node's ID is 0. The edges go from the\n# center node to the leaves.\n#\n\nimport os\n\nos.environ[\"DGLBACKEND\"] = \"pytorch\"\nimport dgl\nimport numpy as np\nimport torch\n\ng = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]), num_nodes=6)\n# Equivalently, PyTorch LongTensors also work.\ng = dgl.graph(\n    (torch.LongTensor([0, 0, 0, 0, 0]), torch.LongTensor([1, 2, 3, 4, 5])),\n    num_nodes=6,\n)\n\n# You can omit the number of nodes argument if you can tell the number of nodes from the edge list alone.\ng = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]))\n\n\n######################################################################\n# Edges in the graph have consecutive IDs starting from 0, and are\n# in the same order as the list of source and destination nodes during\n# creation.\n#\n\n# Print the source and destination nodes of every edge.\nprint(g.edges())\n\n\n######################################################################\n# .. 
note::\n#\n#    ``DGLGraph``'s are always directed to best fit the computation\n#    pattern of graph neural networks, where the messages sent\n#    from one node to the other are often different between both\n#    directions. If you want to handle undirected graphs, you may consider\n#    treating it as a bidirectional graph. See `Graph\n#    Transformations`_ for an example of making\n#    a bidirectional graph.\n#\n\n\n######################################################################\n# Assigning Node and Edge Features to Graph\n# -----------------------------------------\n#\n# Many graph data contain attributes on nodes and edges.\n# Although the types of node and edge attributes can be arbitrary in real\n# world, ``DGLGraph`` only accepts attributes stored in tensors (with\n# numerical contents). Consequently, an attribute of all the nodes or\n# edges must have the same shape. In the context of deep learning, those\n# attributes are often called *features*.\n#\n# You can assign and retrieve node and edge features via ``ndata`` and\n# ``edata`` interface.\n#\n\n# Assign a 3-dimensional node feature vector for each node.\ng.ndata[\"x\"] = torch.randn(6, 3)\n# Assign a 4-dimensional edge feature vector for each edge.\ng.edata[\"a\"] = torch.randn(5, 4)\n# Assign a 5x4 node feature matrix for each node.  Node and edge features in DGL can be multi-dimensional.\ng.ndata[\"y\"] = torch.randn(6, 5, 4)\n\nprint(g.edata[\"a\"])\n\n\n######################################################################\n# .. note::\n#\n#    The vast development of deep learning has provided us many\n#    ways to encode various types of attributes into numerical features.\n#    Here are some general suggestions:\n#\n#    -  For categorical attributes (e.g. gender, occupation), consider\n#       converting them to integers or one-hot encoding.\n#    -  For variable length string contents (e.g. 
news article, quote),\n#       consider applying a language model.\n#    -  For images, consider applying a vision model such as CNNs.\n#\n#    You can find plenty of materials on how to encode such attributes\n#    into a tensor in the `PyTorch Deep Learning\n#    Tutorials <https://pytorch.org/tutorials/>`__.\n#\n\n\n######################################################################\n# Querying Graph Structures\n# -------------------------\n#\n# ``DGLGraph`` object provides various methods to query a graph structure.\n#\n\nprint(g.num_nodes())\nprint(g.num_edges())\n# Out degrees of the center node\nprint(g.out_degrees(0))\n# In degrees of the center node - note that the graph is directed so the in degree should be 0.\nprint(g.in_degrees(0))\n\n\n######################################################################\n# Graph Transformations\n# ---------------------\n#\n\n\n######################################################################\n# DGL provides many APIs to transform a graph to another such as\n# extracting a subgraph:\n#\n\n# Induce a subgraph from node 0, node 1 and node 3 from the original graph.\nsg1 = g.subgraph([0, 1, 3])\n# Induce a subgraph from edge 0, edge 1 and edge 3 from the original graph.\nsg2 = g.edge_subgraph([0, 1, 3])\n\n\n######################################################################\n# You can obtain the node/edge mapping from the subgraph to the original\n# graph by looking into the node feature ``dgl.NID`` or edge feature\n# ``dgl.EID`` in the new graph.\n#\n\n# The original IDs of each node in sg1\nprint(sg1.ndata[dgl.NID])\n# The original IDs of each edge in sg1\nprint(sg1.edata[dgl.EID])\n# The original IDs of each node in sg2\nprint(sg2.ndata[dgl.NID])\n# The original IDs of each edge in sg2\nprint(sg2.edata[dgl.EID])\n\n\n######################################################################\n# ``subgraph`` and ``edge_subgraph`` also copies the original features\n# to the subgraph:\n#\n\n# The original node 
feature of each node in sg1\nprint(sg1.ndata[\"x\"])\n# The original edge feature of each edge in sg1\nprint(sg1.edata[\"a\"])\n# The original node feature of each node in sg2\nprint(sg2.ndata[\"x\"])\n# The original edge feature of each edge in sg2\nprint(sg2.edata[\"a\"])\n\n\n######################################################################\n# Another common transformation is to add a reverse edge for each edge in\n# the original graph with ``dgl.add_reverse_edges``.\n#\n# .. note::\n#\n#    If you have an undirected graph, it is better to convert it\n#    into a bidirectional graph first via adding reverse edges.\n#\n\nnewg = dgl.add_reverse_edges(g)\nprint(newg.edges())\n\n\n######################################################################\n# Loading and Saving Graphs\n# -------------------------\n#\n# You can save a graph or a list of graphs via ``dgl.save_graphs`` and\n# load them back with ``dgl.load_graphs``.\n#\n\n# Save graphs\ndgl.save_graphs(\"graph.dgl\", g)\ndgl.save_graphs(\"graphs.dgl\", [g, sg1, sg2])\n\n# Load graphs\n(g,), _ = dgl.load_graphs(\"graph.dgl\")\nprint(g)\n(g, sg1, sg2), _ = dgl.load_graphs(\"graphs.dgl\")\nprint(g)\nprint(sg1)\nprint(sg2)\n\n\n######################################################################\n# What’s next?\n# ------------\n#\n# -  See\n#    :ref:`here <apigraph-querying-graph-structure>`\n#    for a list of graph structure query APIs.\n# -  See\n#    :ref:`here <api-subgraph-extraction>`\n#    for a list of subgraph extraction routines.\n# -  See\n#    :ref:`here <api-transform>`\n#    for a list of graph transformation routines.\n# -  API reference of :func:`dgl.save_graphs`\n#    and\n#    :func:`dgl.load_graphs`\n#\n\n\n# Thumbnail credits: Wikipedia\n# sphinx_gallery_thumbnail_path = '_static/blitz_2_dglgraph.png'\n"
  },
  {
    "path": "tutorials/blitz/3_message_passing.py",
    "content": "\"\"\"\nWrite your own GNN module\n=========================\n\nSometimes, your model goes beyond simply stacking existing GNN modules.\nFor example, you would like to invent a new way of aggregating neighbor\ninformation by considering node importance or edge weights.\n\nBy the end of this tutorial you will be able to\n\n-  Understand DGL’s message passing APIs.\n-  Implement GraphSAGE convolution module by your own.\n\nThis tutorial assumes that you already know :doc:`the basics of training a\nGNN for node classification <1_introduction>`.\n\n(Time estimate: 10 minutes)\n\n\"\"\"\n\nimport os\n\nos.environ[\"DGLBACKEND\"] = \"pytorch\"\nimport dgl\nimport dgl.function as fn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n######################################################################\n# Message passing and GNNs\n# ------------------------\n#\n# DGL follows the *message passing paradigm* inspired by the Message\n# Passing Neural Network proposed by `Gilmer et\n# al. <https://arxiv.org/abs/1704.01212>`__ Essentially, they found many\n# GNN models can fit into the following framework:\n#\n# .. math::\n#\n#\n#    m_{u\\to v}^{(l)} = M^{(l)}\\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\\to v}^{(l-1)}\\right)\n#\n# .. math::\n#\n#\n#    m_{v}^{(l)} = \\sum_{u\\in\\mathcal{N}(v)}m_{u\\to v}^{(l)}\n#\n# .. math::\n#\n#\n#    h_v^{(l)} = U^{(l)}\\left(h_v^{(l-1)}, m_v^{(l)}\\right)\n#\n# where DGL calls :math:`M^{(l)}` the *message function*, :math:`\\sum` the\n# *reduce function* and :math:`U^{(l)}` the *update function*. Note that\n# :math:`\\sum` here can represent any function and is not necessarily a\n# summation.\n#\n\n\n######################################################################\n# For example, the `GraphSAGE convolution (Hamilton et al.,\n# 2017) <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__\n# takes the following mathematical form:\n#\n# .. 
math::\n#\n#\n#    h_{\\mathcal{N}(v)}^k\\leftarrow \\text{Average}\\{h_u^{k-1},\\forall u\\in\\mathcal{N}(v)\\}\n#\n# .. math::\n#\n#\n#    h_v^k\\leftarrow \\text{ReLU}\\left(W^k\\cdot \\text{CONCAT}(h_v^{k-1}, h_{\\mathcal{N}(v)}^k) \\right)\n#\n# You can see that message passing is directional: the message sent from\n# one node :math:`u` to other node :math:`v` is not necessarily the same\n# as the other message sent from node :math:`v` to node :math:`u` in the\n# opposite direction.\n#\n# Although DGL has builtin support of GraphSAGE via\n# :class:`dgl.nn.SAGEConv <dgl.nn.pytorch.SAGEConv>`,\n# here is how you can implement GraphSAGE convolution in DGL by your own.\n#\n\n\nclass SAGEConv(nn.Module):\n    \"\"\"Graph convolution module used by the GraphSAGE model.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size.\n    out_feat : int\n        Output feature size.\n    \"\"\"\n\n    def __init__(self, in_feat, out_feat):\n        super(SAGEConv, self).__init__()\n        # A linear submodule for projecting the input and neighbor feature to the output.\n        self.linear = nn.Linear(in_feat * 2, out_feat)\n\n    def forward(self, g, h):\n        \"\"\"Forward computation\n\n        Parameters\n        ----------\n        g : Graph\n            The input graph.\n        h : Tensor\n            The input node feature.\n        \"\"\"\n        with g.local_scope():\n            g.ndata[\"h\"] = h\n            # update_all is a message passing API.\n            g.update_all(\n                message_func=fn.copy_u(\"h\", \"m\"),\n                reduce_func=fn.mean(\"m\", \"h_N\"),\n            )\n            h_N = g.ndata[\"h_N\"]\n            h_total = torch.cat([h, h_N], dim=1)\n            return self.linear(h_total)\n\n\n######################################################################\n# The central piece in this code is the\n# :func:`g.update_all <dgl.DGLGraph.update_all>`\n# function, which gathers and averages the 
neighbor features. There are\n# three concepts here:\n#\n# * Message function ``fn.copy_u('h', 'm')`` that\n#   copies the node feature under name ``'h'`` as *messages* with name\n#   ``'m'`` sent to neighbors.\n#\n# * Reduce function ``fn.mean('m', 'h_N')`` that averages\n#   all the received messages under name ``'m'`` and saves the result as a\n#   new node feature ``'h_N'``.\n#\n# * ``update_all`` tells DGL to trigger the\n#   message and reduce functions for all the nodes and edges.\n#\n\n\n######################################################################\n# Afterwards, you can stack your own GraphSAGE convolution layers to form\n# a multi-layer GraphSAGE network.\n#\n\n\nclass Model(nn.Module):\n    def __init__(self, in_feats, h_feats, num_classes):\n        super(Model, self).__init__()\n        self.conv1 = SAGEConv(in_feats, h_feats)\n        self.conv2 = SAGEConv(h_feats, num_classes)\n\n    def forward(self, g, in_feat):\n        h = self.conv1(g, in_feat)\n        h = F.relu(h)\n        h = self.conv2(g, h)\n        return h\n\n\n######################################################################\n# Training loop\n# ~~~~~~~~~~~~~\n# The following code for data loading and training loop is directly copied\n# from the introduction tutorial.\n#\n\nimport dgl.data\n\ndataset = dgl.data.CoraGraphDataset()\ng = dataset[0]\n\n\ndef train(g, model):\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n    all_logits = []\n    best_val_acc = 0\n    best_test_acc = 0\n\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    val_mask = g.ndata[\"val_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    for e in range(200):\n        # Forward\n        logits = model(g, features)\n\n        # Compute prediction\n        pred = logits.argmax(1)\n\n        # Compute loss\n        # Note that we should only compute the losses of the nodes in the training set,\n        # i.e. 
with train_mask 1.\n        loss = F.cross_entropy(logits[train_mask], labels[train_mask])\n\n        # Compute accuracy on training/validation/test\n        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()\n        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()\n        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()\n\n        # Save the best validation accuracy and the corresponding test accuracy.\n        if best_val_acc < val_acc:\n            best_val_acc = val_acc\n            best_test_acc = test_acc\n\n        # Backward\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        all_logits.append(logits.detach())\n\n        if e % 5 == 0:\n            print(\n                \"In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})\".format(\n                    e, loss, val_acc, best_val_acc, test_acc, best_test_acc\n                )\n            )\n\n\nmodel = Model(g.ndata[\"feat\"].shape[1], 16, dataset.num_classes)\ntrain(g, model)\n\n\n######################################################################\n# More customization\n# ------------------\n#\n# In DGL, we provide many built-in message and reduce functions under the\n# ``dgl.function`` package. You can find more details in :ref:`the API\n# doc <apifunction>`.\n#\n\n\n######################################################################\n# These APIs allow one to quickly implement new graph convolution modules.\n# For example, the following implements a new ``SAGEConv`` that aggregates\n# neighbor representations using a weighted average. 
Note that ``edata``\n# member can hold edge features which can also take part in message\n# passing.\n#\n\n\nclass WeightedSAGEConv(nn.Module):\n    \"\"\"Graph convolution module used by the GraphSAGE model with edge weights.\n\n    Parameters\n    ----------\n    in_feat : int\n        Input feature size.\n    out_feat : int\n        Output feature size.\n    \"\"\"\n\n    def __init__(self, in_feat, out_feat):\n        super(WeightedSAGEConv, self).__init__()\n        # A linear submodule for projecting the input and neighbor feature to the output.\n        self.linear = nn.Linear(in_feat * 2, out_feat)\n\n    def forward(self, g, h, w):\n        \"\"\"Forward computation\n\n        Parameters\n        ----------\n        g : Graph\n            The input graph.\n        h : Tensor\n            The input node feature.\n        w : Tensor\n            The edge weight.\n        \"\"\"\n        with g.local_scope():\n            g.ndata[\"h\"] = h\n            g.edata[\"w\"] = w\n            g.update_all(\n                message_func=fn.u_mul_e(\"h\", \"w\", \"m\"),\n                reduce_func=fn.mean(\"m\", \"h_N\"),\n            )\n            h_N = g.ndata[\"h_N\"]\n            h_total = torch.cat([h, h_N], dim=1)\n            return self.linear(h_total)\n\n\n######################################################################\n# Because the graph in this dataset does not have edge weights, we\n# manually assign all edge weights to one in the ``forward()`` function of\n# the model. 
You can replace it with your own edge weights.\n#\n\n\nclass Model(nn.Module):\n    def __init__(self, in_feats, h_feats, num_classes):\n        super(Model, self).__init__()\n        self.conv1 = WeightedSAGEConv(in_feats, h_feats)\n        self.conv2 = WeightedSAGEConv(h_feats, num_classes)\n\n    def forward(self, g, in_feat):\n        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))\n        h = F.relu(h)\n        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))\n        return h\n\n\nmodel = Model(g.ndata[\"feat\"].shape[1], 16, dataset.num_classes)\ntrain(g, model)\n\n\n######################################################################\n# Even more customization by user-defined function\n# ------------------------------------------------\n#\n# DGL allows user-defined message and reduce function for the maximal\n# expressiveness. Here is a user-defined message function that is\n# equivalent to ``fn.u_mul_e('h', 'w', 'm')``.\n#\n\n\ndef u_mul_e_udf(edges):\n    return {\"m\": edges.src[\"h\"] * edges.data[\"w\"]}\n\n\n######################################################################\n# ``edges`` has three members: ``src``, ``data`` and ``dst``, representing\n# the source node feature, edge feature, and destination node feature for\n# all edges.\n#\n\n\n######################################################################\n# You can also write your own reduce function. For example, the following\n# is equivalent to the builtin ``fn.mean('m', 'h_N')`` function that averages\n# the incoming messages:\n#\n\n\ndef mean_udf(nodes):\n    return {\"h_N\": nodes.mailbox[\"m\"].mean(1)}\n\n\n######################################################################\n# In short, DGL will group the nodes by their in-degrees, and for each\n# group DGL stacks the incoming messages along the second dimension. 
You\n# can then perform a reduction along the second dimension to aggregate\n# messages.\n#\n# For more details on customizing message and reduce function with\n# user-defined function, please refer to the :ref:`API\n# reference <apiudf>`.\n#\n\n\n######################################################################\n# Best practice of writing custom GNN modules\n# -------------------------------------------\n#\n# DGL recommends the following practice ranked by preference:\n#\n# -  Use ``dgl.nn`` modules.\n# -  Use ``dgl.nn.functional`` functions which contain lower-level complex\n#    operations such as computing a softmax for each node over incoming\n#    edges.\n# -  Use ``update_all`` with builtin message and reduce functions.\n# -  Use user-defined message or reduce functions.\n#\n\n\n######################################################################\n# What’s next?\n# ------------\n#\n# -  :ref:`Writing Efficient Message Passing\n#    Code <guide-message-passing-efficient>`.\n#\n\n\n# Thumbnail credits: Representation Learning on Networks, Jure Leskovec, WWW 2018\n# sphinx_gallery_thumbnail_path = '_static/blitz_3_message_passing.png'\n"
  },
  {
    "path": "tutorials/blitz/4_link_predict.py",
    "content": "\"\"\"\nLink Prediction using Graph Neural Networks\n===========================================\n\nIn the :doc:`introduction <1_introduction>`, you have already learned\nthe basic workflow of using GNNs for node classification,\ni.e. predicting the category of a node in a graph. This tutorial will\nteach you how to train a GNN for link prediction, i.e. predicting the\nexistence of an edge between two arbitrary nodes in a graph.\n\nBy the end of this tutorial you will be able to\n\n-  Build a GNN-based link prediction model.\n-  Train and evaluate the model on a small DGL-provided dataset.\n\n(Time estimate: 28 minutes)\n\n\"\"\"\n\nimport itertools\nimport os\n\nos.environ[\"DGLBACKEND\"] = \"pytorch\"\n\nimport dgl\nimport dgl.data\nimport numpy as np\nimport scipy.sparse as sp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n######################################################################\n# Overview of Link Prediction with GNN\n# ------------------------------------\n#\n# Many applications such as social recommendation, item recommendation,\n# knowledge graph completion, etc., can be formulated as link prediction,\n# which predicts whether an edge exists between two particular nodes. This\n# tutorial shows an example of predicting whether a citation relationship,\n# either citing or being cited, between two papers exists in a citation\n# network.\n#\n# This tutorial formulates the link prediction problem as a binary classification\n# problem as follows:\n#\n# -  Treat the edges in the graph as *positive examples*.\n# -  Sample a number of non-existent edges (i.e. node pairs with no edges\n#    between them) as *negative* examples.\n# -  Divide the positive examples and negative examples into a training\n#    set and a test set.\n# -  Evaluate the model with any binary classification metric such as Area\n#    Under Curve (AUC).\n#\n# .. 
note::\n#\n#    The practice comes from\n#    `SEAL <https://papers.nips.cc/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf>`__,\n#    although the model here does not use their idea of node labeling.\n#\n# In some domains such as large-scale recommender systems or information\n# retrieval, you may favor metrics that emphasize good performance of\n# top-K predictions. In these cases you may want to consider other metrics\n# such as mean average precision, and use other negative sampling methods,\n# which are beyond the scope of this tutorial.\n#\n# Loading graph and features\n# --------------------------\n#\n# Following the :doc:`introduction <1_introduction>`, this tutorial\n# first loads the Cora dataset.\n#\n\n\ndataset = dgl.data.CoraGraphDataset()\ng = dataset[0]\n\n\n######################################################################\n# Prepare training and testing sets\n# ---------------------------------\n#\n# This tutorial randomly picks 10% of the edges for positive examples in\n# the test set, and leaves the rest for the training set. 
It then samples\n# the same number of edges for negative examples in both sets.\n#\n\n# Split edge set for training and testing\nu, v = g.edges()\n\neids = np.arange(g.num_edges())\neids = np.random.permutation(eids)\ntest_size = int(len(eids) * 0.1)\ntrain_size = g.num_edges() - test_size\ntest_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]\ntrain_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]\n\n# Find all negative edges and split them for training and testing\nadj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))\nadj_neg = 1 - adj.todense() - np.eye(g.num_nodes())\nneg_u, neg_v = np.where(adj_neg != 0)\n\nneg_eids = np.random.choice(len(neg_u), g.num_edges())\ntest_neg_u, test_neg_v = (\n    neg_u[neg_eids[:test_size]],\n    neg_v[neg_eids[:test_size]],\n)\ntrain_neg_u, train_neg_v = (\n    neg_u[neg_eids[test_size:]],\n    neg_v[neg_eids[test_size:]],\n)\n\n\n######################################################################\n# When training, you will need to remove the edges in the test set from\n# the original graph. You can do this via ``dgl.remove_edges``.\n#\n# .. note::\n#\n#    ``dgl.remove_edges`` works by creating a subgraph from the\n#    original graph, resulting in a copy and therefore could be slow for\n#    large graphs. If so, you could save the training and test graph to\n#    disk, as you would do for preprocessing.\n#\n\ntrain_g = dgl.remove_edges(g, eids[:test_size])\n\n\n######################################################################\n# Define a GraphSAGE model\n# ------------------------\n#\n# This tutorial builds a model consisting of two\n# `GraphSAGE <https://arxiv.org/abs/1706.02216>`__ layers, each computes\n# new node representations by averaging neighbor information. DGL provides\n# ``dgl.nn.SAGEConv`` that conveniently creates a GraphSAGE layer.\n#\n\nfrom dgl.nn import SAGEConv\n\n\n# ----------- 2. 
create model -------------- #\n# build a two-layer GraphSAGE model\nclass GraphSAGE(nn.Module):\n    def __init__(self, in_feats, h_feats):\n        super(GraphSAGE, self).__init__()\n        self.conv1 = SAGEConv(in_feats, h_feats, \"mean\")\n        self.conv2 = SAGEConv(h_feats, h_feats, \"mean\")\n\n    def forward(self, g, in_feat):\n        h = self.conv1(g, in_feat)\n        h = F.relu(h)\n        h = self.conv2(g, h)\n        return h\n\n\n######################################################################\n# The model then predicts the probability of existence of an edge by\n# computing a score between the representations of both incident nodes\n# with a function (e.g. an MLP or a dot product), which you will see in\n# the next section.\n#\n# .. math::\n#\n#\n#    \\hat{y}_{u\\sim v} = f(h_u, h_v)\n#\n\n\n######################################################################\n# Positive graph, negative graph, and ``apply_edges``\n# ---------------------------------------------------\n#\n# In previous tutorials you have learned how to compute node\n# representations with a GNN. However, link prediction requires you to\n# compute representation of *pairs of nodes*.\n#\n# DGL recommends you to treat the pairs of nodes as another graph, since\n# you can describe a pair of nodes with an edge. In link prediction, you\n# will have a *positive graph* consisting of all the positive examples as\n# edges, and a *negative graph* consisting of all the negative examples.\n# The *positive graph* and the *negative graph* will contain the same set\n# of nodes as the original graph.  This makes it easier to pass node\n# features among multiple graphs for computation.  
As you will see later,\n# you can directly feed the node representations computed on the entire\n# graph to the positive and the negative graphs for computing pair-wise\n# scores.\n#\n# The following code constructs the positive graph and the negative graph\n# for the training set and the test set respectively.\n#\n\ntrain_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.num_nodes())\ntrain_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.num_nodes())\n\ntest_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.num_nodes())\ntest_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.num_nodes())\n\n\n######################################################################\n# The benefit of treating the pairs of nodes as a graph is that you can\n# use the ``DGLGraph.apply_edges`` method, which conveniently computes new\n# edge features based on the incident nodes’ features and the original\n# edge features (if applicable).\n#\n# DGL provides a set of optimized builtin functions to compute new\n# edge features based on the original node/edge features. 
For example,\n# ``dgl.function.u_dot_v`` computes a dot product of the incident nodes’\n# representations for each edge.\n#\n\nimport dgl.function as fn\n\n\nclass DotPredictor(nn.Module):\n    def forward(self, g, h):\n        with g.local_scope():\n            g.ndata[\"h\"] = h\n            # Compute a new edge feature named 'score' by a dot-product between the\n            # source node feature 'h' and destination node feature 'h'.\n            g.apply_edges(fn.u_dot_v(\"h\", \"h\", \"score\"))\n            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.\n            return g.edata[\"score\"][:, 0]\n\n\n######################################################################\n# You can also write your own function if it is complex.\n# For instance, the following module produces a scalar score on each edge\n# by concatenating the incident nodes’ features and passing it to an MLP.\n#\n\n\nclass MLPPredictor(nn.Module):\n    def __init__(self, h_feats):\n        super().__init__()\n        self.W1 = nn.Linear(h_feats * 2, h_feats)\n        self.W2 = nn.Linear(h_feats, 1)\n\n    def apply_edges(self, edges):\n        \"\"\"\n        Computes a scalar score for each edge of the given graph.\n\n        Parameters\n        ----------\n        edges :\n            Has three members ``src``, ``dst`` and ``data``, each of\n            which is a dictionary representing the features of the\n            source nodes, the destination nodes, and the edges\n            themselves.\n\n        Returns\n        -------\n        dict\n            A dictionary of new edge features.\n        \"\"\"\n        h = torch.cat([edges.src[\"h\"], edges.dst[\"h\"]], 1)\n        return {\"score\": self.W2(F.relu(self.W1(h))).squeeze(1)}\n\n    def forward(self, g, h):\n        with g.local_scope():\n            g.ndata[\"h\"] = h\n            g.apply_edges(self.apply_edges)\n            return 
g.edata[\"score\"]\n\n\n######################################################################\n# .. note::\n#\n#    The builtin functions are optimized for both speed and memory.\n#    We recommend using builtin functions whenever possible.\n#\n# .. note::\n#\n#    If you have read the :doc:`message passing\n#    tutorial <3_message_passing>`, you will notice that the\n#    argument ``apply_edges`` takes has exactly the same form as a message\n#    function in ``update_all``.\n#\n\n\n######################################################################\n# Training loop\n# -------------\n#\n# After you have defined the node representation computation and the edge score\n# computation, you can go ahead and define the overall model, loss\n# function, and evaluation metric.\n#\n# The loss function is simply binary cross entropy loss.\n#\n# .. math::\n#\n#\n#    \\mathcal{L} = -\\sum_{u\\sim v\\in \\mathcal{D}}\\left( y_{u\\sim v}\\log(\\hat{y}_{u\\sim v}) + (1-y_{u\\sim v})\\log(1-\\hat{y}_{u\\sim v}) \\right)\n#\n# The evaluation metric in this tutorial is AUC.\n#\n\nmodel = GraphSAGE(train_g.ndata[\"feat\"].shape[1], 16)\n# You can replace DotPredictor with MLPPredictor.\n# pred = MLPPredictor(16)\npred = DotPredictor()\n\n\ndef compute_loss(pos_score, neg_score):\n    scores = torch.cat([pos_score, neg_score])\n    labels = torch.cat(\n        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]\n    )\n    return F.binary_cross_entropy_with_logits(scores, labels)\n\n\ndef compute_auc(pos_score, neg_score):\n    scores = torch.cat([pos_score, neg_score]).numpy()\n    labels = torch.cat(\n        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]\n    ).numpy()\n    return roc_auc_score(labels, scores)\n\n\n######################################################################\n# The training loop goes as follows:\n#\n# .. note::\n#\n#    This tutorial does not include evaluation on a validation\n#    set. 
In practice you should save and evaluate the best model based on\n#    performance on the validation set.\n#\n\n# ----------- 3. set up loss and optimizer -------------- #\n# in this case, the loss is computed inside the training loop\noptimizer = torch.optim.Adam(\n    itertools.chain(model.parameters(), pred.parameters()), lr=0.01\n)\n\n# ----------- 4. training -------------------------------- #\nall_logits = []\nfor e in range(100):\n    # forward\n    h = model(train_g, train_g.ndata[\"feat\"])\n    pos_score = pred(train_pos_g, h)\n    neg_score = pred(train_neg_g, h)\n    loss = compute_loss(pos_score, neg_score)\n\n    # backward\n    optimizer.zero_grad()\n    loss.backward()\n    optimizer.step()\n\n    if e % 5 == 0:\n        print(\"In epoch {}, loss: {}\".format(e, loss))\n\n# ----------- 5. check results ------------------------ #\nfrom sklearn.metrics import roc_auc_score\n\nwith torch.no_grad():\n    pos_score = pred(test_pos_g, h)\n    neg_score = pred(test_neg_g, h)\n    print(\"AUC\", compute_auc(pos_score, neg_score))\n\n\n# Thumbnail credits: Link Prediction with Neo4j, Mark Needham\n# sphinx_gallery_thumbnail_path = '_static/blitz_4_link_predict.png'\n"
  },
  {
    "path": "tutorials/blitz/5_graph_classification.py",
    "content": "\"\"\"\nTraining a GNN for Graph Classification\n=======================================\n\nBy the end of this tutorial, you will be able to\n\n-  Load a DGL-provided graph classification dataset.\n-  Understand what *readout* function does.\n-  Understand how to create and use a minibatch of graphs.\n-  Build a GNN-based graph classification model.\n-  Train and evaluate the model on a DGL-provided dataset.\n\n(Time estimate: 18 minutes)\n\"\"\"\n\nimport os\n\nos.environ[\"DGLBACKEND\"] = \"pytorch\"\nimport dgl\nimport dgl.data\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n######################################################################\n# Overview of Graph Classification with GNN\n# -----------------------------------------\n#\n# Graph classification or regression requires a model to predict certain\n# graph-level properties of a single graph given its node and edge\n# features.  Molecular property prediction is one particular application.\n#\n# This tutorial shows how to train a graph classification model for a\n# small dataset from the paper `How Powerful Are Graph Neural\n# Networks <https://arxiv.org/abs/1810.00826>`__.\n#\n# Loading Data\n# ------------\n#\n\n\n# Generate a synthetic dataset with 10000 graphs, ranging from 10 to 500 nodes.\ndataset = dgl.data.GINDataset(\"PROTEINS\", self_loop=True)\n\n\n######################################################################\n# The dataset is a set of graphs, each with node features and a single\n# label. 
One can see the node feature dimensionality and the number of\n# possible graph categories of ``GINDataset`` objects in ``dim_nfeats``\n# and ``gclasses`` attributes.\n#\n\nprint(\"Node feature dimensionality:\", dataset.dim_nfeats)\nprint(\"Number of graph categories:\", dataset.gclasses)\n\n\nfrom dgl.dataloading import GraphDataLoader\n\n######################################################################\n# Defining Data Loader\n# --------------------\n#\n# A graph classification dataset usually contains two types of elements: a\n# set of graphs, and their graph-level labels. Similar to an image\n# classification task, when the dataset is large enough, we need to train\n# with mini-batches. When you train a model for image classification or\n# language modeling, you will use a ``DataLoader`` to iterate over the\n# dataset. In DGL, you can use the ``GraphDataLoader``.\n#\n# You can also use various dataset samplers provided in\n# `torch.utils.data.sampler <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`__.\n# For example, this tutorial creates a training ``GraphDataLoader`` and\n# test ``GraphDataLoader``, using ``SubsetRandomSampler`` to tell PyTorch\n# to sample from only a subset of the dataset.\n#\n\nfrom torch.utils.data.sampler import SubsetRandomSampler\n\nnum_examples = len(dataset)\nnum_train = int(num_examples * 0.8)\n\ntrain_sampler = SubsetRandomSampler(torch.arange(num_train))\ntest_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))\n\ntrain_dataloader = GraphDataLoader(\n    dataset, sampler=train_sampler, batch_size=5, drop_last=False\n)\ntest_dataloader = GraphDataLoader(\n    dataset, sampler=test_sampler, batch_size=5, drop_last=False\n)\n\n\n######################################################################\n# You can try to iterate over the created ``GraphDataLoader`` and see what it\n# gives:\n#\n\nit = iter(train_dataloader)\nbatch = 
next(it)\nprint(batch)\n\n\n######################################################################\n# As each element in ``dataset`` has a graph and a label, the\n# ``GraphDataLoader`` will return two objects for each iteration. The\n# first element is the batched graph, and the second element is simply a\n# label vector representing the category of each graph in the mini-batch.\n# Next, we’ll talk about the batched graph.\n#\n# A Batched Graph in DGL\n# ----------------------\n#\n# In each mini-batch, the sampled graphs are combined into a single bigger\n# batched graph via ``dgl.batch``. The single bigger batched graph merges\n# all original graphs as separately connected components, with the node\n# and edge features concatenated. This bigger graph is also a ``DGLGraph``\n# instance (so you can\n# still treat it as a normal ``DGLGraph`` object as in\n# `here <2_dglgraph.ipynb>`__). It however contains the information\n# necessary for recovering the original graphs, such as the number of\n# nodes and edges of each graph element.\n#\n\nbatched_graph, labels = batch\nprint(\n    \"Number of nodes for each graph element in the batch:\",\n    batched_graph.batch_num_nodes(),\n)\nprint(\n    \"Number of edges for each graph element in the batch:\",\n    batched_graph.batch_num_edges(),\n)\n\n# Recover the original graph elements from the minibatch\ngraphs = dgl.unbatch(batched_graph)\nprint(\"The original graphs in the minibatch:\")\nprint(graphs)\n\n\n######################################################################\n# Define Model\n# ------------\n#\n# This tutorial will build a two-layer `Graph Convolutional Network\n# (GCN) <http://tkipf.github.io/graph-convolutional-networks/>`__. Each of\n# its layers computes new node representations by aggregating neighbor\n# information. 
If you have gone through the\n# :doc:`introduction <1_introduction>`, you will notice two\n# differences:\n#\n# -  Since the task is to predict a single category for the *entire graph*\n#    instead of for every node, you will need to aggregate the\n#    representations of all the nodes and potentially the edges to form a\n#    graph-level representation. Such process is more commonly referred as\n#    a *readout*. A simple choice is to average the node features of a\n#    graph with ``dgl.mean_nodes()``.\n#\n# -  The input graph to the model will be a batched graph yielded by the\n#    ``GraphDataLoader``. The readout functions provided by DGL can handle\n#    batched graphs so that they will return one representation for each\n#    minibatch element.\n#\n\nfrom dgl.nn import GraphConv\n\n\nclass GCN(nn.Module):\n    def __init__(self, in_feats, h_feats, num_classes):\n        super(GCN, self).__init__()\n        self.conv1 = GraphConv(in_feats, h_feats)\n        self.conv2 = GraphConv(h_feats, num_classes)\n\n    def forward(self, g, in_feat):\n        h = self.conv1(g, in_feat)\n        h = F.relu(h)\n        h = self.conv2(g, h)\n        g.ndata[\"h\"] = h\n        return dgl.mean_nodes(g, \"h\")\n\n\n######################################################################\n# Training Loop\n# -------------\n#\n# The training loop iterates over the training set with the\n# ``GraphDataLoader`` object and computes the gradients, just like\n# image classification or language modeling.\n#\n\n# Create the model with given dimensions\nmodel = GCN(dataset.dim_nfeats, 16, dataset.gclasses)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\nfor epoch in range(20):\n    for batched_graph, labels in train_dataloader:\n        pred = model(batched_graph, batched_graph.ndata[\"attr\"].float())\n        loss = F.cross_entropy(pred, labels)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\nnum_correct = 0\nnum_tests = 0\nfor 
batched_graph, labels in test_dataloader:\n    pred = model(batched_graph, batched_graph.ndata[\"attr\"].float())\n    num_correct += (pred.argmax(1) == labels).sum().item()\n    num_tests += len(labels)\n\nprint(\"Test accuracy:\", num_correct / num_tests)\n\n\n######################################################################\n# What’s next\n# -----------\n#\n# -  See `GIN\n#    example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin>`__\n#    for an end-to-end graph classification model.\n#\n\n\n# Thumbnail credits: DGL\n# sphinx_gallery_thumbnail_path = '_static/blitz_5_graph_classification.png'\n"
  },
  {
    "path": "tutorials/blitz/README.txt",
    "content": "A Blitz Introduction to DGL\n===========================\n"
  },
  {
    "path": "tutorials/cpu/README.txt",
    "content": "Training on CPUs\n=========================\n"
  },
  {
    "path": "tutorials/cpu/argo_tutorial.py",
    "content": "\"\"\"\nImprove Scalability on Multi-Core CPUs\n=====================================================\n\nGraph Neural Network (GNN) training suffers from low scalability on multi-core CPUs. \nSpecificially, the performance often caps at 16 cores, and no improvement is observed when applying more than 16 cores [#f1]_.\nARGO is a runtime system that offers scalable performance.  \nWith ARGO enabled, we are able to scale over 64 cores, allowing ARGO to speedup GNN training (in terms of epoch time) by up to 4.30x and 3.32x on a Xeon 8380H and a Xeon 6430L, respectively [#f2]_.\nThis chapter focus on how to setup ARGO to unleash the potential of multi-core CPUs to speedup GNN training.\n\nInstallation\n`````````````````````````````\n\nARGO utilizes the scikit-optimize library for auto-tuning. Please install scikit-optimize to run ARGO:\n.. code-block:: shell\n    conda install -c conda-forge \"scikit-optimize>=0.9.0\" \nor\n.. code-block:: shell\n    pip install scikit-optimize>=0.9\n\nEnabling ARGO on your own GNN program\n```````````````````````````````````````````\nIn this section, we provide a step-by-step tutorial on how to enable ARGO on a DGL program. \nWe use the *ogb_example.py* [#f3]_ as an example.\n.. note::\n    We also provide the complete example file *ogb_example_ARGO.py* [#f4]_ \n    which followed the steps below to enable ARGO on *ogb_example.py*.\n\nStep 1\n---------------------------\nFirst, include all necessary packages on top of the file. Please place your file and *argo.py* [#f5]_ in the same directory.\n\n.. code-block:: python\n    import os\n    import torch.distributed as dist\n    from torch.nn.parallel import DistributedDataParallel\n    import torch.multiprocessing as mp\n    from argo import ARGO\n\nStep 2\n---------------------------\nSetup PyTorch Distributed Data Parallel (DDP)\n\n2.1. Add the initialization function on top of the training program, and wrap the ```model``` with the DDP wrapper\n.. 
code-block:: python\n    def train(...):\n        dist.init_process_group('gloo', rank=rank, world_size=world_size) # newly added\n        model = SAGE(...) # original code\n        model = DistributedDataParallel(model) # newly added\n        ...\n     \n2.2. In the main program, add the following before launching the training function\n.. code-block:: python\n    ...\n    os.environ['MASTER_ADDR'] = '127.0.0.1'\n    os.environ['MASTER_PORT'] = '29501'\n    mp.set_start_method('fork', force=True)\n    train(args, device, data) # original code for launching the training function\n\nStep 3\n---------------------------\nEnable ARGO by initializing the runtime system, and wrapping the training function\n.. code-block:: python\n    runtime = ARGO(n_search = 15, epoch = args.num_epochs, batch_size = args.batch_size) # initialization\n    runtime.run(train, args=(args, device, data)) # wrap the training function\n\n.. note::\n    ARGO takes three input parameters: number of searches *n_search*, number of epochs, and the mini-batch size. \n    Increasing *n_search* potentially leads to a better configuration with less epoch time; \n    however, searching itself also causes extra overhead. We recommend setting *n_search* from 15 to 45 for an optimal overall performance. \n\nStep 4\n---------------------------\nModify the input of the training function, by directly adding ARGO parameters after the original inputs.\n   \nThis is the original function:\n.. code-block:: python\n   def train(args, device, data):\n   \nAdd the following variables: *rank, world_size, comp_core, load_core, counter, b_size, ep*\n.. code-block:: python\n    def train(args, device, data, rank, world_size, comp_core, load_core, counter, b_size, ep):\n\nStep 5\n---------------------------\nModify the *dataloader* function in the training function\n.. 
code-block:: python\n    dataloader = dgl.dataloading.DataLoader(\n            g,\n            train_nid,\n            sampler,\n            batch_size=b_size, # modified\n            shuffle=True,\n            drop_last=False,\n            num_workers=len(load_core), # modified\n            use_ddp = True) # newly added\n\nStep 6\n---------------------------\nEnable core-binding by adding *enable_cpu_affinity()* before the training for-loop, and also change the number of epochs into the variable *ep*: \n.. code-block:: python\n    with dataloader.enable_cpu_affinity(loader_cores=load_core, compute_cores=comp_core): \n        for epoch in range(ep): # change num_epochs to ep\n   \nStep 7\n---------------------------\nLast step! Load the model before training and save it afterward.  \n\nOriginal Program:\n.. code-block:: python\n    with dataloader.enable_cpu_affinity(loader_cores=load_core, compute_cores=comp_core): \n        for epoch in range(ep): \n        ... # training operations\n   \nModified:\n.. code-block:: python\n    PATH = \"model.pt\"\n    if counter[0] != 0:\n        checkpoint = th.load(PATH)\n        model.load_state_dict(checkpoint['model_state_dict'])\n        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n        epoch = checkpoint['epoch']\n        loss = checkpoint['loss']\n   \n    with dataloader.enable_cpu_affinity(loader_cores=load_core, compute_cores=comp_core): \n        for epoch in range(ep): \n        ... # training operations\n   \n    dist.barrier()\n    if rank == 0:\n        th.save({'epoch': counter[0],\n                    'model_state_dict': model.state_dict(),\n                    'optimizer_state_dict': optimizer.state_dict(),\n                    'loss': loss,\n                    }, PATH)\n   \nStep 8\n---------------------------\nDone! You can now run your GNN program with ARGO enabled.\n.. code-block:: shell\n    python <your_code>.py\n\n    \n.. rubric:: Footnotes\n\n.. 
[#f1] https://github.com/dmlc/dgl/blob/master/examples/pytorch/argo/argo_scale.png\n.. [#f2] https://arxiv.org/abs/2402.03671\n.. [#f3] https://github.com/dmlc/dgl/blob/master/examples/pytorch/argo/ogb_example.py\n.. [#f4] https://github.com/dmlc/dgl/blob/master/examples/pytorch/argo/ogb_example_ARGO.py\n.. [#f5] https://github.com/dmlc/dgl/blob/master/examples/pytorch/argo/argo.py\n\"\"\"\n"
  },
  {
    "path": "tutorials/models/1_gnn/1_gcn.py",
    "content": "\"\"\"\r\n.. _model-gcn:\r\n\r\nGraph Convolutional Network\r\n====================================\r\n\r\n**Author:** `Qi Huang <https://github.com/HQ01>`_, `Minjie Wang  <https://jermainewang.github.io/>`_,\r\nYu Gai, Quan Gan, Zheng Zhang\r\n\r\n.. warning::\r\n\r\n    The tutorial aims at gaining insights into the paper, with code as a mean\r\n    of explanation. The implementation thus is NOT optimized for running\r\n    efficiency. For recommended implementation, please refer to the `official\r\n    examples <https://github.com/dmlc/dgl/tree/master/examples>`_.\r\n\r\nThis is a gentle introduction of using DGL to implement Graph Convolutional\r\nNetworks (Kipf & Welling et al., `Semi-Supervised Classification with Graph\r\nConvolutional Networks <https://arxiv.org/pdf/1609.02907.pdf>`_). We explain\r\nwhat is under the hood of the :class:`~dgl.nn.GraphConv` module.\r\nThe reader is expected to learn how to define a new GNN layer using DGL's\r\nmessage passing APIs.\r\n\"\"\"\n\n###############################################################################\n# Model Overview\n# ------------------------------------------\n# GCN from the perspective of message passing\n# ```````````````````````````````````````````````\n# We describe a layer of graph convolutional neural network from a message\n# passing perspective; the math can be found `here <math_>`_.\n# It boils down to the following step, for each node :math:`u`:\n#\n# 1) Aggregate neighbors' representations :math:`h_{v}` to produce an\n# intermediate representation :math:`\\hat{h}_u`.  2) Transform the aggregated\n# representation :math:`\\hat{h}_{u}` with a linear projection followed by a\n# non-linearity: :math:`h_{u} = f(W_{u} \\hat{h}_u)`.\n#\n# We will implement step 1 with DGL message passing, and step 2 by\n# PyTorch ``nn.Module``.\n#\n# GCN implementation with DGL\n# ``````````````````````````````````````````\n# We first define the message and reduce function as usual.  
Since the\n# aggregation on a node :math:`u` only involves summing over the neighbors'\n# representations :math:`h_v`, we can simply use builtin functions:\n\nimport os\n\nos.environ[\"DGLBACKEND\"] = \"pytorch\"\nimport dgl\nimport dgl.function as fn\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dgl import DGLGraph\n\ngcn_msg = fn.copy_u(u=\"h\", out=\"m\")\ngcn_reduce = fn.sum(msg=\"m\", out=\"h\")\n\n###############################################################################\n# We then proceed to define the GCNLayer module. A GCNLayer essentially performs\n# message passing on all the nodes then applies a fully-connected layer.\n#\n# .. note::\n#\n#    This is showing how to implement a GCN from scratch.  DGL provides a more\n#    efficient :class:`builtin GCN layer module <dgl.nn.pytorch.conv.GraphConv>`.\n#\n\n\nclass GCNLayer(nn.Module):\n    def __init__(self, in_feats, out_feats):\n        super(GCNLayer, self).__init__()\n        self.linear = nn.Linear(in_feats, out_feats)\n\n    def forward(self, g, feature):\n        # Creating a local scope so that all the stored ndata and edata\n        # (such as the `'h'` ndata below) are automatically popped out\n        # when the scope exits.\n        with g.local_scope():\n            g.ndata[\"h\"] = feature\n            g.update_all(gcn_msg, gcn_reduce)\n            h = g.ndata[\"h\"]\n            return self.linear(h)\n\n\n###############################################################################\n# The forward function is essentially the same as any other commonly seen NNs\n# model in PyTorch.  We can initialize GCN like any ``nn.Module``. For example,\n# let's define a simple neural network consisting of two GCN layers. Suppose we\n# are training the classifier for the cora dataset (the input feature size is\n# 1433 and the number of classes is 7). 
The last GCN layer computes node embeddings,\n# so the last layer in general does not apply activation.\n\n\nclass Net(nn.Module):\n    def __init__(self):\n        super(Net, self).__init__()\n        self.layer1 = GCNLayer(1433, 16)\n        self.layer2 = GCNLayer(16, 7)\n\n    def forward(self, g, features):\n        x = F.relu(self.layer1(g, features))\n        x = self.layer2(g, x)\n        return x\n\n\nnet = Net()\nprint(net)\n\n###############################################################################\n# We load the cora dataset using DGL's built-in data module.\n\nfrom dgl.data import CoraGraphDataset\n\n\ndef load_cora_data():\n    dataset = CoraGraphDataset()\n    g = dataset[0]\n    features = g.ndata[\"feat\"]\n    labels = g.ndata[\"label\"]\n    train_mask = g.ndata[\"train_mask\"]\n    test_mask = g.ndata[\"test_mask\"]\n    return g, features, labels, train_mask, test_mask\n\n\n###############################################################################\n# When a model is trained, we can use the following method to evaluate\n# the performance of the model on the test dataset:\n\n\ndef evaluate(model, g, features, labels, mask):\n    model.eval()\n    with th.no_grad():\n        logits = model(g, features)\n        logits = logits[mask]\n        labels = labels[mask]\n        _, indices = th.max(logits, dim=1)\n        correct = th.sum(indices == labels)\n        return correct.item() * 1.0 / len(labels)\n\n\n###############################################################################\n# We then train the network as follows:\n\nimport time\n\nimport numpy as np\n\ng, features, labels, train_mask, test_mask = load_cora_data()\n# Add edges between each node and itself to preserve old node representations\ng.add_edges(g.nodes(), g.nodes())\noptimizer = th.optim.Adam(net.parameters(), lr=1e-2)\ndur = []\nfor epoch in range(50):\n    if epoch >= 3:\n        t0 = time.time()\n    net.train()\n    logits = net(g, features)\n    logp = 
F.log_softmax(logits, 1)\n    loss = F.nll_loss(logp[train_mask], labels[train_mask])\n\n    optimizer.zero_grad()\n    loss.backward()\n    optimizer.step()\n\n    if epoch >= 3:\n        dur.append(time.time() - t0)\n    acc = evaluate(net, g, features, labels, test_mask)\n    print(\n        \"Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}\".format(\n            epoch, loss.item(), acc, np.mean(dur)\n        )\n    )\n###############################################################################\n# .. _math:\n#\n# GCN in one formula\n# ------------------\n# Mathematically, the GCN model follows this formula:\n#\n# :math:`H^{(l+1)} = \\sigma(\\tilde{D}^{-\\frac{1}{2}}\\tilde{A}\\tilde{D}^{-\\frac{1}{2}}H^{(l)}W^{(l)})`\n#\n# Here, :math:`H^{(l)}` denotes the :math:`l^{th}` layer in the network,\n# :math:`\\sigma` is the non-linearity, and :math:`W` is the weight matrix for\n# this layer. :math:`\\tilde{D}` and :math:`\\tilde{A}` are separately the degree\n# and adjacency matrices for the graph. With the superscript ~, we are referring\n# to the variant where we add additional edges between each node and itself to\n# preserve its old representation in graph convolutions. The shape of the input\n# :math:`H^{(0)}` is :math:`N \\times D`, where :math:`N` is the number of nodes\n# and :math:`D` is the number of input features. We can chain up multiple\n# layers as such to produce a node-level representation output with shape\n# :math:`N \\times F`, where :math:`F` is the dimension of the output node\n# feature vector.\n#\n# The equation can be efficiently implemented using sparse matrix\n# multiplication kernels (such as Kipf's\n# `pygcn <https://github.com/tkipf/pygcn>`_ code). 
The above DGL implementation\n# in fact has already used this trick due to the use of builtin functions.\n#\n# Note that the tutorial code implements a simplified version of GCN where we\n# replace :math:`\\tilde{D}^{-\\frac{1}{2}}\\tilde{A}\\tilde{D}^{-\\frac{1}{2}}` with\n# :math:`\\tilde{A}`. For a full implementation, see our example\n# `here  <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcn>`_.\n"
  },
  {
    "path": "tutorials/models/2_small_graph/README.txt",
    "content": ".. _tutorials2-index:\n\nBatching many small graphs\n-------------------------------\n\n* **Tree-LSTM** `[paper] <https://arxiv.org/abs/1503.00075>`__ `[tutorial]\n  <2_small_graph/3_tree-lstm.html>`__ `[PyTorch code]\n  <https://github.com/dmlc/dgl/blob/master/examples/pytorch/tree_lstm>`__:\n  Sentences have inherent structures that are thrown\n  away by treating them simply as sequences. Tree-LSTM is a powerful model\n  that learns the representation by using prior syntactic structures such as a parse-tree.\n  The challenge in training is that simply by padding\n  a sentence to the maximum length no longer works. Trees of different\n  sentences have different sizes and topologies. DGL solves this problem by\n  adding the trees to a bigger container graph, and then using message-passing\n  to explore maximum parallelism. Batching is a key API for this.\n"
  },
  {
    "path": "tutorials/models/3_generative_model/5_dgmg.py",
    "content": "\"\"\"\r\n.. _model-dgmg:\r\n\r\nGenerative Models of Graphs\r\n===========================================\r\n\r\n**Author**: `Mufei Li <https://github.com/mufeili>`_,\r\n`Lingfan Yu <https://github.com/ylfdq1118>`_, Zheng Zhang\r\n\r\n.. warning::\r\n\r\n    The tutorial aims at gaining insights into the paper, with code as a mean\r\n    of explanation. The implementation thus is NOT optimized for running\r\n    efficiency. For recommended implementation, please refer to the `official\r\n    examples <https://github.com/dmlc/dgl/tree/master/examples>`_.\r\n\r\n\"\"\"\n\n##############################################################################\n#\n# In this tutorial, you learn how to train and generate one graph at\n# a time. You also explore parallelism within the graph embedding operation, which is an\n# essential building block. The tutorial ends with a simple optimization that\n# delivers double the speed by batching across graphs.\n#\n# Earlier tutorials showed how embedding a graph or\n# a node enables you to work on tasks such as `semi-supervised classification for nodes\n# <http://docs.dgl.ai/tutorials/models/1_gcn.html#sphx-glr-tutorials-models-1-gcn-py>`__\n# or `sentiment analysis\n# <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__.\n# Wouldn't it be interesting to predict the future evolution of the graph and\n# perform the analysis iteratively?\n#\n# To address the evolution of the graphs, you generate a variety of graph samples. In other words, you need\n# **generative models** of graphs. In-addition to learning\n# node and edge features, you would need to model the distribution of arbitrary graphs.\n# While general generative models can model the density function explicitly and\n# implicitly and generate samples at once or sequentially, you only focus\n# on explicit generative models for sequential generation here. 
Typical applications\n# include drug or materials discovery, chemical processes, or proteomics.\n#\n# Introduction\n# --------------------\n# The primitive actions of mutating a graph in Deep Graph Library (DGL) are nothing more than ``add_nodes``\n# and ``add_edges``. That is, if you were to draw a circle of three nodes,\n#\n# .. figure:: https://user-images.githubusercontent.com/19576924/48313438-78baf000-e5f7-11e8-931e-cd00ab34fa50.gif\n#    :alt:\n#\n# you can write the code as follows.\n#\n\nimport os\n\nos.environ[\"DGLBACKEND\"] = \"pytorch\"\nimport dgl\n\ng = dgl.DGLGraph()\ng.add_nodes(1)  # Add node 0\ng.add_nodes(1)  # Add node 1\n\n# Edges in DGLGraph are directed by default.\n# For undirected edges, add edges for both directions.\ng.add_edges([1, 0], [0, 1])  # Add edges (1, 0), (0, 1)\ng.add_nodes(1)  # Add node 2\ng.add_edges([2, 1], [1, 2])  # Add edges (2, 1), (1, 2)\ng.add_edges([2, 0], [0, 2])  # Add edges (2, 0), (0, 2)\n\n#######################################################################################\n# Real-world graphs are much more complex. There are many families of graphs,\n# with different sizes, topologies, node types, edge types, and the possibility\n# of multigraphs. Besides, the same graph can be generated in many different\n# orders. Regardless, the generative process entails a few steps.\n#\n# - Encode a changing graph.\n# - Perform actions stochastically.\n# - If you are training, collect error signals and optimize the model parameters.\n#\n# When it comes to implementation, another important aspect is speed. How do you\n# parallelize the computation, given that generating a graph is fundamentally a\n# sequential process?\n#\n# .. note::\n#\n#    To be sure, this is not necessarily a hard constraint. Subgraphs can be\n#    built in parallel and then get assembled. 
But we\n#    will restrict ourselves to the sequential processes for this tutorial.\n#\n#\n# DGMG: The main flow\n# --------------------\n# For this tutorial, you use\n# `Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__\n# (DGMG) to implement a graph generative model using DGL. Its algorithmic\n# framework is general but also challenging to parallelize.\n#\n# .. note::\n#\n#    While it's possible for DGMG to handle complex graphs with typed nodes,\n#    typed edges, and multigraphs, here you use a simplified version of it\n#    for generating graph topologies.\n#\n# DGMG generates a graph by following a state machine, which is basically a\n# two-level loop. Generate one node at a time and connect it to a subset of\n# the existing nodes, one at a time. This is similar to language modeling. The\n# generative process is an iterative one that emits one word or character or sentence\n# at a time, conditioned on the sequence generated so far.\n#\n# At each time step, you either:\n#      - Add a new node to the graph\n#      - Select two existing nodes and add an edge between them\n#\n# .. figure:: https://user-images.githubusercontent.com/19576924/48605003-7f11e900-e9b6-11e8-8880-87362348e154.png\n#    :alt:\n#\n# The Python code will look as follows. 
In fact, this is *exactly* how inference\n# with DGMG is implemented in DGL.\n#\n\n\ndef forward_inference(self):\n    stop = self.add_node_and_update()\n    while (not stop) and (self.g.num_nodes() < self.v_max + 1):\n        num_trials = 0\n        to_add_edge = self.add_edge_or_not()\n        while to_add_edge and (num_trials < self.g.num_nodes() - 1):\n            self.choose_dest_and_update()\n            num_trials += 1\n            to_add_edge = self.add_edge_or_not()\n        stop = self.add_node_and_update()\n    return self.g\n\n\n#######################################################################################\n# Assume you have a pre-trained model for generating cycles of nodes 10-20.\n# How does it generate a cycle on-the-fly during inference? Use the code below\n# to create an animation with your own model.\n#\n# ::\n#\n#     import torch\n#     import matplotlib.animation as animation\n#     import matplotlib.pyplot as plt\n#     import networkx as nx\n#     from copy import deepcopy\n#\n#     if __name__ == '__main__':\n#         # pre-trained model saved with path ./model.pth\n#         model = torch.load('./model.pth')\n#         model.eval()\n#         g = model()\n#\n#         src_list = g.edges()[1]\n#         dest_list = g.edges()[0]\n#\n#         evolution = []\n#\n#         nx_g = nx.Graph()\n#         evolution.append(deepcopy(nx_g))\n#\n#         for i in range(0, len(src_list), 2):\n#             src = src_list[i].item()\n#             dest = dest_list[i].item()\n#             if src not in nx_g.nodes():\n#                 nx_g.add_node(src)\n#                 evolution.append(deepcopy(nx_g))\n#             if dest not in nx_g.nodes():\n#                 nx_g.add_node(dest)\n#                 evolution.append(deepcopy(nx_g))\n#             nx_g.add_edges_from([(src, dest), (dest, src)])\n#             evolution.append(deepcopy(nx_g))\n#\n#         def animate(i):\n#             ax.cla()\n#             g_t = evolution[i]\n#           
  nx.draw_circular(g_t, with_labels=True, ax=ax,\n#                              node_color=['#FEBD69'] * g_t.num_nodes())\n#\n#         fig, ax = plt.subplots()\n#         ani = animation.FuncAnimation(fig, animate,\n#                                       frames=len(evolution),\n#                                       interval=600)\n#\n# .. figure:: https://user-images.githubusercontent.com/19576924/48928548-2644d200-ef1b-11e8-8591-da93345382ad.gif\n#    :alt:\n#\n# DGMG: Optimization objective\n# ------------------------------\n# Similar to language modeling, DGMG trains the model with *behavior cloning*,\n# or *teacher forcing*. Assume for each graph there exists a sequence of\n# *oracle actions* :math:`a_{1},\\cdots,a_{T}` that generates it. What the model\n# does is to follow these actions, compute the joint probabilities of such\n# action sequences, and maximize them.\n#\n# By chain rule, the probability of taking :math:`a_{1},\\cdots,a_{T}` is:\n#\n# .. math::\n#\n#    p(a_{1},\\cdots, a_{T}) = p(a_{1})p(a_{2}|a_{1})\\cdots p(a_{T}|a_{1},\\cdots,a_{T-1}).\\\\\n#\n# The optimization objective is then simply the typical MLE loss:\n#\n# .. 
math::\n#\n#    -\\log p(a_{1},\\cdots,a_{T})=-\\sum_{t=1}^{T}\\log p(a_{t}|a_{1},\\cdots, a_{t-1}).\\\\\n#\n\n\ndef forward_train(self, actions):\n    \"\"\"\n    - actions: list\n        - Contains a_1, ..., a_T described above\n    - self.prepare_for_train()\n        - Initializes self.action_step to be 0, which will get\n          incremented by 1 every time it is called.\n        - Initializes objects recording log p(a_t|a_1,...a_{t-1})\n\n    Returns\n    -------\n    - self.get_log_prob(): log p(a_1, ..., a_T)\n    \"\"\"\n    self.prepare_for_train()\n\n    stop = self.add_node_and_update(a=actions[self.action_step])\n    while not stop:\n        to_add_edge = self.add_edge_or_not(a=actions[self.action_step])\n        while to_add_edge:\n            self.choose_dest_and_update(a=actions[self.action_step])\n            to_add_edge = self.add_edge_or_not(a=actions[self.action_step])\n        stop = self.add_node_and_update(a=actions[self.action_step])\n    return self.get_log_prob()\n\n\n#######################################################################################\n# The key difference between ``forward_train`` and ``forward_inference`` is\n# that the training process takes oracle actions as input and returns log\n# probabilities for evaluating the loss.\n#\n# DGMG: The implementation\n# --------------------------\n# The ``DGMG`` class\n# ``````````````````````````\n# Below you can find the skeleton code for the model. 
You gradually\n# fill in the details for each function.\n#\n\nimport torch.nn as nn\n\n\nclass DGMGSkeleton(nn.Module):\n    def __init__(self, v_max):\n        \"\"\"\n        Parameters\n        ----------\n        v_max: int\n            Max number of nodes considered\n        \"\"\"\n        super(DGMGSkeleton, self).__init__()\n\n        # Graph configuration\n        self.v_max = v_max\n\n    def add_node_and_update(self, a=None):\n        \"\"\"Decide if to add a new node.\n        If a new node should be added, update the graph.\"\"\"\n        return NotImplementedError\n\n    def add_edge_or_not(self, a=None):\n        \"\"\"Decide if a new edge should be added.\"\"\"\n        return NotImplementedError\n\n    def choose_dest_and_update(self, a=None):\n        \"\"\"Choose destination and connect it to the latest node.\n        Add edges for both directions and update the graph.\"\"\"\n        return NotImplementedError\n\n    def forward_train(self, actions):\n        \"\"\"Forward at training time. 
It records the probability\n        of generating a ground truth graph following the actions.\"\"\"\n        return NotImplementedError\n\n    def forward_inference(self):\n        \"\"\"Forward at inference time.\n        It generates graphs on the fly.\"\"\"\n        return NotImplementedError\n\n    def forward(self, actions=None):\n        # The graph you will work on\n        self.g = dgl.DGLGraph()\n\n        # If there are some features for nodes and edges,\n        # zero tensors will be set for those of new nodes and edges.\n        self.g.set_n_initializer(dgl.frame.zero_initializer)\n        self.g.set_e_initializer(dgl.frame.zero_initializer)\n\n        if self.training:\n            return self.forward_train(actions=actions)\n        else:\n            return self.forward_inference()\n\n\n#######################################################################################\n# Encoding a dynamic graph\n# ``````````````````````````\n# All the actions generating a graph are sampled from probability\n# distributions. In order to do that, you project the structured data,\n# namely the graph, onto an Euclidean space. The challenge is that such\n# process, called *embedding*, needs to be repeated as the graphs mutate.\n#\n# Graph embedding\n# ''''''''''''''''''''''''''\n# Let :math:`G=(V,E)` be an arbitrary graph. Each node :math:`v` has an\n# embedding vector :math:`\\textbf{h}_{v} \\in \\mathbb{R}^{n}`. Similarly,\n# the graph has an embedding vector :math:`\\textbf{h}_{G} \\in \\mathbb{R}^{k}`.\n# Typically, :math:`k > n` since a graph contains more information than\n# an individual node.\n#\n# The graph embedding is a weighted sum of node embeddings under a linear\n# transformation:\n#\n# .. 
math::\n#\n#    \\textbf{h}_{G} =\\sum_{v\\in V}\\text{Sigmoid}(g_m(\\textbf{h}_{v}))f_{m}(\\textbf{h}_{v}),\\\\\n#\n# The first term, :math:`\\text{Sigmoid}(g_m(\\textbf{h}_{v}))`, computes a\n# gating function and can be thought of as how much the overall graph embedding\n# attends on each node. The second term :math:`f_{m}:\\mathbb{R}^{n}\\rightarrow\\mathbb{R}^{k}`\n# maps the node embeddings to the space of graph embeddings.\n#\n# Implement graph embedding as a ``GraphEmbed`` class.\n#\n\nimport torch\n\n\nclass GraphEmbed(nn.Module):\n    def __init__(self, node_hidden_size):\n        super(GraphEmbed, self).__init__()\n\n        # Setting from the paper\n        self.graph_hidden_size = 2 * node_hidden_size\n\n        # Embed graphs\n        self.node_gating = nn.Sequential(\n            nn.Linear(node_hidden_size, 1), nn.Sigmoid()\n        )\n        self.node_to_graph = nn.Linear(node_hidden_size, self.graph_hidden_size)\n\n    def forward(self, g):\n        if g.num_nodes() == 0:\n            return torch.zeros(1, self.graph_hidden_size)\n        else:\n            # Node features are stored as hv in ndata.\n            hvs = g.ndata[\"hv\"]\n            return (self.node_gating(hvs) * self.node_to_graph(hvs)).sum(\n                0, keepdim=True\n            )\n\n\n#######################################################################################\n# Update node embeddings via graph propagation\n# '''''''''''''''''''''''''''''''''''''''''''''\n#\n# The mechanism of updating node embeddings in DGMG is similar to that for\n# graph convolutional networks. For a node :math:`v` in the graph, its\n# neighbor :math:`u` sends a message to it with\n#\n# .. 
math::\n#\n#    \\textbf{m}_{u\\rightarrow v}=\\textbf{W}_{m}\\text{concat}([\\textbf{h}_{v}, \\textbf{h}_{u}, \\textbf{x}_{u, v}]) + \\textbf{b}_{m},\\\\\n#\n# where :math:`\\textbf{x}_{u,v}` is the embedding of the edge between\n# :math:`u` and :math:`v`.\n#\n# After receiving messages from all its neighbors, :math:`v` summarizes them\n# with a node activation vector\n#\n# .. math::\n#\n#    \\textbf{a}_{v} = \\sum_{u: (u, v)\\in E}\\textbf{m}_{u\\rightarrow v}\\\\\n#\n# and use this information to update its own feature:\n#\n# .. math::\n#\n#    \\textbf{h}'_{v} = \\textbf{GRU}(\\textbf{h}_{v}, \\textbf{a}_{v}).\\\\\n#\n# Performing all the operations above once for all nodes synchronously is\n# called one round of graph propagation. The more rounds of graph propagation\n# you perform, the longer distance messages travel throughout the graph.\n#\n# With DGL, you implement graph propagation with ``g.update_all``.\n# The message notation here can be a bit confusing. Researchers can refer\n# to :math:`\\textbf{m}_{u\\rightarrow v}` as messages, however the message function\n# below only passes :math:`\\text{concat}([\\textbf{h}_{u}, \\textbf{x}_{u, v}])`.\n# The operation :math:`\\textbf{W}_{m}\\text{concat}([\\textbf{h}_{v}, \\textbf{h}_{u}, \\textbf{x}_{u, v}]) + \\textbf{b}_{m}`\n# is then performed across all edges at once for efficiency consideration.\n#\n\nfrom functools import partial\n\n\nclass GraphProp(nn.Module):\n    def __init__(self, num_prop_rounds, node_hidden_size):\n        super(GraphProp, self).__init__()\n\n        self.num_prop_rounds = num_prop_rounds\n\n        # Setting from the paper\n        self.node_activation_hidden_size = 2 * node_hidden_size\n\n        message_funcs = []\n        node_update_funcs = []\n        self.reduce_funcs = []\n\n        for t in range(num_prop_rounds):\n            # input being [hv, hu, xuv]\n            message_funcs.append(\n                nn.Linear(\n                    2 * node_hidden_size + 1, 
self.node_activation_hidden_size\n                )\n            )\n\n            self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))\n            node_update_funcs.append(\n                nn.GRUCell(self.node_activation_hidden_size, node_hidden_size)\n            )\n        self.message_funcs = nn.ModuleList(message_funcs)\n        self.node_update_funcs = nn.ModuleList(node_update_funcs)\n\n    def dgmg_msg(self, edges):\n        \"\"\"For an edge u->v, return concat([h_u, x_uv])\"\"\"\n        return {\"m\": torch.cat([edges.src[\"hv\"], edges.data[\"he\"]], dim=1)}\n\n    def dgmg_reduce(self, nodes, round):\n        hv_old = nodes.data[\"hv\"]\n        m = nodes.mailbox[\"m\"]\n        message = torch.cat(\n            [hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2\n        )\n        node_activation = (self.message_funcs[round](message)).sum(1)\n\n        return {\"a\": node_activation}\n\n    def forward(self, g):\n        if g.num_edges() > 0:\n            for t in range(self.num_prop_rounds):\n                g.update_all(\n                    message_func=self.dgmg_msg, reduce_func=self.reduce_funcs[t]\n                )\n                g.ndata[\"hv\"] = self.node_update_funcs[t](\n                    g.ndata[\"a\"], g.ndata[\"hv\"]\n                )\n\n\n#######################################################################################\n# Actions\n# ``````````````````````````\n# All actions are sampled from distributions parameterized using neural networks\n# and here they are in turn.\n#\n# Action 1: Add nodes\n# ''''''''''''''''''''''''''\n#\n# Given the graph embedding vector :math:`\\textbf{h}_{G}`, evaluate\n#\n# .. math::\n#\n#    \\text{Sigmoid}(\\textbf{W}_{\\text{add node}}\\textbf{h}_{G}+b_{\\text{add node}}),\\\\\n#\n# which is then used to parametrize a Bernoulli distribution for deciding whether\n# to add a new node.\n#\n# If a new node is to be added, initialize its feature with\n#\n# .. 
math::\n#\n#    \\textbf{W}_{\\text{init}}\\text{concat}([\\textbf{h}_{\\text{init}} , \\textbf{h}_{G}])+\\textbf{b}_{\\text{init}},\\\\\n#\n# where :math:`\\textbf{h}_{\\text{init}}` is a learnable embedding module for\n# untyped nodes.\n#\n\nimport torch.nn.functional as F\nfrom torch.distributions import Bernoulli\n\n\ndef bernoulli_action_log_prob(logit, action):\n    \"\"\"Calculate the log p of an action with respect to a Bernoulli\n    distribution. Use logit rather than prob for numerical stability.\"\"\"\n    if action == 0:\n        return F.logsigmoid(-logit)\n    else:\n        return F.logsigmoid(logit)\n\n\nclass AddNode(nn.Module):\n    def __init__(self, graph_embed_func, node_hidden_size):\n        super(AddNode, self).__init__()\n\n        self.graph_op = {\"embed\": graph_embed_func}\n\n        self.stop = 1\n        self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)\n\n        # If to add a node, initialize its hv\n        self.node_type_embed = nn.Embedding(1, node_hidden_size)\n        self.initialize_hv = nn.Linear(\n            node_hidden_size + graph_embed_func.graph_hidden_size,\n            node_hidden_size,\n        )\n\n        self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)\n\n    def _initialize_node_repr(self, g, node_type, graph_embed):\n        \"\"\"Whenver a node is added, initialize its representation.\"\"\"\n        num_nodes = g.num_nodes()\n        hv_init = self.initialize_hv(\n            torch.cat(\n                [\n                    self.node_type_embed(torch.LongTensor([node_type])),\n                    graph_embed,\n                ],\n                dim=1,\n            )\n        )\n        g.nodes[num_nodes - 1].data[\"hv\"] = hv_init\n        g.nodes[num_nodes - 1].data[\"a\"] = self.init_node_activation\n\n    def prepare_training(self):\n        self.log_prob = []\n\n    def forward(self, g, action=None):\n        graph_embed = self.graph_op[\"embed\"](g)\n\n        logit = 
self.add_node(graph_embed)\n        prob = torch.sigmoid(logit)\n\n        if not self.training:\n            action = Bernoulli(prob).sample().item()\n        stop = bool(action == self.stop)\n\n        if not stop:\n            g.add_nodes(1)\n            self._initialize_node_repr(g, action, graph_embed)\n        if self.training:\n            sample_log_prob = bernoulli_action_log_prob(logit, action)\n\n            self.log_prob.append(sample_log_prob)\n        return stop\n\n\n#######################################################################################\n# Action 2: Add edges\n# ''''''''''''''''''''''''''\n#\n# Given the graph embedding vector :math:`\\textbf{h}_{G}` and the node\n# embedding vector :math:`\\textbf{h}_{v}` for the latest node :math:`v`,\n# you evaluate\n#\n# .. math::\n#\n#    \\text{Sigmoid}(\\textbf{W}_{\\text{add edge}}\\text{concat}([\\textbf{h}_{G}, \\textbf{h}_{v}])+b_{\\text{add edge}}),\\\\\n#\n# which is then used to parametrize a Bernoulli distribution for deciding\n# whether to add a new edge starting from :math:`v`.\n#\n\n\nclass AddEdge(nn.Module):\n    def __init__(self, graph_embed_func, node_hidden_size):\n        super(AddEdge, self).__init__()\n\n        self.graph_op = {\"embed\": graph_embed_func}\n        self.add_edge = nn.Linear(\n            graph_embed_func.graph_hidden_size + node_hidden_size, 1\n        )\n\n    def prepare_training(self):\n        self.log_prob = []\n\n    def forward(self, g, action=None):\n        graph_embed = self.graph_op[\"embed\"](g)\n        src_embed = g.nodes[g.num_nodes() - 1].data[\"hv\"]\n\n        logit = self.add_edge(torch.cat([graph_embed, src_embed], dim=1))\n        prob = torch.sigmoid(logit)\n\n        if self.training:\n            sample_log_prob = bernoulli_action_log_prob(logit, action)\n            self.log_prob.append(sample_log_prob)\n        else:\n            action = Bernoulli(prob).sample().item()\n        to_add_edge = bool(action == 0)\n        return 
to_add_edge\n\n\n#######################################################################################\n# Action 3: Choose a destination\n# '''''''''''''''''''''''''''''''''\n#\n# When action 2 returns `True`, choose a destination for the\n# latest node :math:`v`.\n#\n# For each possible destination :math:`u\\in\\{0, \\cdots, v-1\\}`, the\n# probability of choosing it is given by\n#\n# .. math::\n#\n#    \\frac{\\text{exp}(\\textbf{W}_{\\text{dest}}\\text{concat}([\\textbf{h}_{u}, \\textbf{h}_{v}])+\\textbf{b}_{\\text{dest}})}{\\sum_{i=0}^{v-1}\\text{exp}(\\textbf{W}_{\\text{dest}}\\text{concat}([\\textbf{h}_{i}, \\textbf{h}_{v}])+\\textbf{b}_{\\text{dest}})}\\\\\n#\n\nfrom torch.distributions import Categorical\n\n\nclass ChooseDestAndUpdate(nn.Module):\n    def __init__(self, graph_prop_func, node_hidden_size):\n        super(ChooseDestAndUpdate, self).__init__()\n\n        self.graph_op = {\"prop\": graph_prop_func}\n        self.choose_dest = nn.Linear(2 * node_hidden_size, 1)\n\n    def _initialize_edge_repr(self, g, src_list, dest_list):\n        # For untyped edges, only add 1 to indicate its existence.\n        # For multiple edge types, use a one-hot representation\n        # or an embedding module.\n        edge_repr = torch.ones(len(src_list), 1)\n        g.edges[src_list, dest_list].data[\"he\"] = edge_repr\n\n    def prepare_training(self):\n        self.log_prob = []\n\n    def forward(self, g, dest):\n        src = g.num_nodes() - 1\n        possible_dests = range(src)\n\n        src_embed_expand = g.nodes[src].data[\"hv\"].expand(src, -1)\n        possible_dests_embed = g.nodes[possible_dests].data[\"hv\"]\n\n        dests_scores = self.choose_dest(\n            torch.cat([possible_dests_embed, src_embed_expand], dim=1)\n        ).view(1, -1)\n        dests_probs = F.softmax(dests_scores, dim=1)\n\n        if not self.training:\n            dest = Categorical(dests_probs).sample().item()\n        if not g.has_edges_between(src, dest):\n            
# For undirected graphs, add edges for both directions\n            # so that you can perform graph propagation.\n            src_list = [src, dest]\n            dest_list = [dest, src]\n\n            g.add_edges(src_list, dest_list)\n            self._initialize_edge_repr(g, src_list, dest_list)\n\n            self.graph_op[\"prop\"](g)\n        if self.training:\n            if dests_probs.nelement() > 1:\n                self.log_prob.append(\n                    F.log_softmax(dests_scores, dim=1)[:, dest : dest + 1]\n                )\n\n\n#######################################################################################\n# Putting it together\n# ``````````````````````````\n#\n# You are now ready to have a complete implementation of the model class.\n#\n\n\nclass DGMG(DGMGSkeleton):\n    def __init__(self, v_max, node_hidden_size, num_prop_rounds):\n        super(DGMG, self).__init__(v_max)\n\n        # Graph embedding module\n        self.graph_embed = GraphEmbed(node_hidden_size)\n\n        # Graph propagation module\n        self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size)\n\n        # Actions\n        self.add_node_agent = AddNode(self.graph_embed, node_hidden_size)\n        self.add_edge_agent = AddEdge(self.graph_embed, node_hidden_size)\n        self.choose_dest_agent = ChooseDestAndUpdate(\n            self.graph_prop, node_hidden_size\n        )\n\n        # Forward functions\n        self.forward_train = partial(forward_train, self=self)\n        self.forward_inference = partial(forward_inference, self=self)\n\n    @property\n    def action_step(self):\n        old_step_count = self.step_count\n        self.step_count += 1\n\n        return old_step_count\n\n    def prepare_for_train(self):\n        self.step_count = 0\n\n        self.add_node_agent.prepare_training()\n        self.add_edge_agent.prepare_training()\n        self.choose_dest_agent.prepare_training()\n\n    def add_node_and_update(self, a=None):\n        \"\"\"Decide 
if to add a new node.\n        If a new node should be added, update the graph.\"\"\"\n\n        return self.add_node_agent(self.g, a)\n\n    def add_edge_or_not(self, a=None):\n        \"\"\"Decide if a new edge should be added.\"\"\"\n\n        return self.add_edge_agent(self.g, a)\n\n    def choose_dest_and_update(self, a=None):\n        \"\"\"Choose destination and connect it to the latest node.\n        Add edges for both directions and update the graph.\"\"\"\n\n        self.choose_dest_agent(self.g, a)\n\n    def get_log_prob(self):\n        add_node_log_p = torch.cat(self.add_node_agent.log_prob).sum()\n        add_edge_log_p = torch.cat(self.add_edge_agent.log_prob).sum()\n        choose_dest_log_p = torch.cat(self.choose_dest_agent.log_prob).sum()\n        return add_node_log_p + add_edge_log_p + choose_dest_log_p\n\n\n#######################################################################################\n# Below is an animation where a graph is generated on the fly\n# after every 10 batches of training for the first 400 batches. You\n# can see how the model improves over time and begins generating cycles.\n#\n# .. 
figure:: https://user-images.githubusercontent.com/19576924/48929291-60fe3880-ef22-11e8-832a-fbe56656559a.gif\n#    :alt:\n#\n# For generative models, you can evaluate performance by checking the percentage\n# of valid graphs among the graphs it generates on the fly.\n\nimport torch.utils.model_zoo as model_zoo\n\n# Download a pre-trained model state dict for generating cycles with 10-20 nodes.\nstate_dict = model_zoo.load_url(\n    \"https://data.dgl.ai/model/dgmg_cycles-5a0c40be.pth\"\n)\nmodel = DGMG(v_max=20, node_hidden_size=16, num_prop_rounds=2)\nmodel.load_state_dict(state_dict)\nmodel.eval()\n\n\ndef is_valid(g):\n    # Check if g is a cycle having 10-20 nodes.\n    def _get_previous(i, v_max):\n        if i == 0:\n            return v_max\n        else:\n            return i - 1\n\n    def _get_next(i, v_max):\n        if i == v_max:\n            return 0\n        else:\n            return i + 1\n\n    size = g.num_nodes()\n\n    if size < 10 or size > 20:\n        return False\n    for node in range(size):\n        neighbors = g.successors(node)\n\n        if len(neighbors) != 2:\n            return False\n        if _get_previous(node, size - 1) not in neighbors:\n            return False\n        if _get_next(node, size - 1) not in neighbors:\n            return False\n    return True\n\n\nnum_valid = 0\nfor i in range(100):\n    g = model()\n    num_valid += is_valid(g)\ndel model\nprint(\"Among 100 graphs generated, {}% are valid.\".format(num_valid))\n\n#######################################################################################\n# For the complete implementation, see the `DGL DGMG example\n# <https://github.com/dmlc/dgl/tree/master/examples/pytorch/dgmg>`__.\n#\n"
  },
  {
    "path": "tutorials/models/4_old_wines/7_transformer.py",
    "content": "\"\"\"\n.. _model-transformer:\n\nTransformer as a Graph Neural Network\n======================================\n\n**Author**: Zihao Ye, Jinjing Zhou, Qipeng Guo, Quan Gan, Zheng Zhang\n\n.. warning::\n\n    The tutorial aims at gaining insights into the paper, with code as a mean\n    of explanation. The implementation thus is NOT optimized for running\n    efficiency. For recommended implementation, please refer to the `official\n    examples <https://github.com/dmlc/dgl/tree/master/examples>`_.\n\n\"\"\"\n################################################################################################\n# In this tutorial, you learn about a simplified implementation of the Transformer model.\n# You can see highlights of the most important design points. For instance, there is\n# only single-head attention. The complete code can be found\n# `here <https://github.com/dmlc/dgl/tree/master/examples/pytorch/transformer>`__.\n#\n# The overall structure is similar to the one from the research papaer `Annotated\n# Transformer <http://nlp.seas.harvard.edu/2018/04/03/attention.html>`__.\n#\n# The Transformer model, as a replacement of CNN/RNN architecture for\n# sequence modeling, was introduced in the research paper: `Attention is All\n# You Need <https://arxiv.org/pdf/1706.03762.pdf>`__. It improved the\n# state of the art for machine translation as well as natural language\n# inference task\n# (`GPT <https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf>`__).\n# Recent work on pre-training Transformer with large scale corpus\n# (`BERT <https://arxiv.org/pdf/1810.04805.pdf>`__) supports that it is\n# capable of learning high-quality semantic representation.\n#\n# The interesting part of Transformer is its extensive employment of\n# attention. 
The classic use of attention comes from machine translation\n# model, where the output token attends to all input tokens.\n#\n# Transformer additionally applies *self-attention* in both decoder and\n# encoder. This process forces words relate to each other to combine\n# together, irrespective of their positions in the sequence. This is\n# different from RNN-based model, where words (in the source sentence) are\n# combined along the chain, which is thought to be too constrained.\n#\n# Attention layer of Transformer\n# ------------------------------\n#\n# In the attention layer of Transformer, for each node the module learns to\n# assign weights on its in-coming edges. For node pair :math:`(i, j)`\n# (from :math:`i` to :math:`j`) with node\n# :math:`x_i, x_j \\in \\mathbb{R}^n`, the score of their connection is\n# defined as follows:\n#\n# .. math::\n#\n#\n#    q_j = W_q\\cdot x_j \\\\\n#    k_i = W_k\\cdot x_i\\\\\n#    v_i = W_v\\cdot x_i\\\\\n#    \\textrm{score} = q_j^T k_i\n#\n# where :math:`W_q, W_k, W_v \\in \\mathbb{R}^{n\\times d_k}` map the\n# representations :math:`x` to “query”, “key”, and “value” space\n# respectively.\n#\n# There are other possibilities to implement the score function. The dot\n# product measures the similarity of a given query :math:`q_j` and a key\n# :math:`k_i`: if :math:`j` needs the information stored in :math:`i`, the\n# query vector at position :math:`j` (:math:`q_j`) is supposed to be close\n# to key vector at position :math:`i` (:math:`k_i`).\n#\n# The score is then used to compute the sum of the incoming values,\n# normalized over the weights of edges, stored in :math:`\\textrm{wv}`.\n# Then apply an affine layer to :math:`\\textrm{wv}` to get the output\n# :math:`o`:\n#\n# .. 
math::\n#\n#\n#    w_{ji} = \\frac{\\exp\\{\\textrm{score}_{ji} \\}}{\\sum\\limits_{(k, i)\\in E}\\exp\\{\\textrm{score}_{ki} \\}} \\\\\n#    \\textrm{wv}_i = \\sum_{(k, i)\\in E} w_{ki} v_k \\\\\n#    o = W_o\\cdot \\textrm{wv} \\\\\n#\n# Multi-head attention layer\n# ~~~~~~~~~~~~~~~~~~~~~~~~~~\n#\n# In Transformer, attention is *multi-headed*. A head is very much like a\n# channel in a convolutional network. The multi-head attention consists of\n# multiple attention heads, in which each head refers to a single\n# attention module. :math:`\\textrm{wv}^{(i)}` for all the heads are\n# concatenated and mapped to output :math:`o` with an affine layer:\n#\n# .. math::\n#\n#\n#    o = W_o \\cdot \\textrm{concat}\\left(\\left[\\textrm{wv}^{(0)}, \\textrm{wv}^{(1)}, \\cdots, \\textrm{wv}^{(h)}\\right]\\right)\n#\n# The code below wraps necessary components for multi-head attention, and\n# provides two interfaces.\n#\n# -  ``get`` maps state ‘x’, to query, key and value, which is required by\n#    following steps(\\ ``propagate_attention``).\n# -  ``get_o`` maps the updated value after attention to the output\n#    :math:`o` for post-processing.\n#\n# .. 
code::\n#\n#    class MultiHeadAttention(nn.Module):\n#        \"Multi-Head Attention\"\n#        def __init__(self, h, dim_model):\n#            \"h: number of heads; dim_model: hidden dimension\"\n#            super(MultiHeadAttention, self).__init__()\n#            self.d_k = dim_model // h\n#            self.h = h\n#            # W_q, W_k, W_v, W_o\n#            self.linears = clones(nn.Linear(dim_model, dim_model), 4)\n#\n#        def get(self, x, fields='qkv'):\n#            \"Return a dict of queries / keys / values.\"\n#            batch_size = x.shape[0]\n#            ret = {}\n#            if 'q' in fields:\n#                ret['q'] = self.linears[0](x).view(batch_size, self.h, self.d_k)\n#            if 'k' in fields:\n#                ret['k'] = self.linears[1](x).view(batch_size, self.h, self.d_k)\n#            if 'v' in fields:\n#                ret['v'] = self.linears[2](x).view(batch_size, self.h, self.d_k)\n#            return ret\n#\n#        def get_o(self, x):\n#            \"get output of the multi-head attention\"\n#            batch_size = x.shape[0]\n#            return self.linears[3](x.view(batch_size, -1))\n#\n#\n# How DGL implements Transformer with a graph neural network\n# ----------------------------------------------------------\n#\n# You get a different perspective of Transformer by treating the\n# attention as edges in a graph and adopt message passing on the edges to\n# induce the appropriate processing.\n#\n# Graph structure\n# ~~~~~~~~~~~~~~~\n#\n# Construct the graph by mapping tokens of the source and target\n# sentence to nodes. The complete Transformer graph is made up of three\n# subgraphs:\n#\n# **Source language graph**. This is a complete graph, each\n# token :math:`s_i` can attend to any other token :math:`s_j` (including\n# self-loops). |image0|\n# **Target language graph**. 
The graph is\n# half-complete, in that :math:`t_i` attends only to :math:`t_j` if\n# :math:`i > j` (an output token can not depend on future words). |image1|\n# **Cross-language graph**. This is a bipartite graph, where there is\n# an edge from every source token :math:`s_i` to every target token\n# :math:`t_j`, meaning every target token can attend on source tokens.\n# |image2|\n#\n# The full picture looks like this: |image3|\n#\n# Pre-build the graphs in dataset preparation stage.\n#\n# Message passing\n# ~~~~~~~~~~~~~~~\n#\n# Once you define the graph structure, move on to defining the\n# computation for message passing.\n#\n# Assume that you have already computed all the queries :math:`q_i`, keys\n# :math:`k_i` and values :math:`v_i`. For each node :math:`i` (no matter\n# whether it is a source token or target token), you can decompose the\n# attention computation into two steps:\n#\n# 1. **Message computation:** Compute attention score\n#    :math:`\\mathrm{score}_{ij}` between :math:`i` and all nodes :math:`j`\n#    to be attended over, by taking the scaled-dot product between\n#    :math:`q_i` and :math:`k_j`. The message sent from :math:`j` to\n#    :math:`i` will consist of the score :math:`\\mathrm{score}_{ij}` and\n#    the value :math:`v_j`.\n# 2. **Message aggregation:** Aggregate the values :math:`v_j` from all\n#    :math:`j` according to the scores :math:`\\mathrm{score}_{ij}`.\n#\n# Simple implementation\n# ^^^^^^^^^^^^^^^^^^^^^\n#\n# Message computation\n# '''''''''''''''''''\n#\n# Compute ``score`` and send source node’s ``v`` to destination’s mailbox\n#\n# .. code::\n#\n#    def message_func(edges):\n#        return {'score': ((edges.src['k'] * edges.dst['q'])\n#                          .sum(-1, keepdim=True)),\n#                'v': edges.src['v']}\n#\n# Message aggregation\n# '''''''''''''''''''\n#\n# Normalize over all in-edges and weighted sum to get output\n#\n# .. 
code::\n#\n#    import torch as th\n#    import torch.nn.functional as F\n#\n#    def reduce_func(nodes, d_k=64):\n#        v = nodes.mailbox['v']\n#        att = F.softmax(nodes.mailbox['score'] / th.sqrt(d_k), 1)\n#        return {'dx': (att * v).sum(1)}\n#\n# Execute on specific edges\n# '''''''''''''''''''''''''\n#\n# .. code::\n#\n#    from functools import partial\n#    def naive_propagate_attention(self, g, eids):\n#        g.send_and_recv(eids, message_func, partial(reduce_func, d_k=self.d_k))\n#\n# Speeding up with built-in functions\n# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n#\n# To speed up the message passing process, use DGL’s built-in\n# functions, including:\n#\n# - ``fn.u_mul_e(src_field, edges_field, out_field)`` multiplies\n#   source’s attribute and edge’s attribute, and sends the result to the\n#   destination node’s mailbox keyed by ``out_field``.\n# - ``fn.copy_e(edges_field, out_field)`` copies edge’s attribute to\n#   destination node’s mailbox.\n# - ``fn.sum(edges_field, out_field)`` sums up\n#   edge’s attribute and sends aggregation to destination node’s mailbox.\n#\n# Here, you assemble those built-in functions into ``propagate_attention``,\n# which is also the main graph operation function in the final\n# implementation. To accelerate it, break the ``softmax`` operation into\n# the following steps. Recall that for each head there are two phases.\n#\n# 1. Compute attention score by multiplying src node’s ``k`` and dst node’s\n#    ``q``\n#\n#    -  ``g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)``\n#\n# 2. Scaled Softmax over all dst nodes’ in-coming edges\n#\n#    -  Step 1: Exponentiate the score with the scale normalization constant\n#\n#       -  ``g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))``\n#\n#          .. 
math:: \\textrm{score}_{ij}\\leftarrow\\exp{\\left(\\frac{\\textrm{score}_{ij}}{ \\sqrt{d_k}}\\right)}\n#\n#    -  Step 2: Get the “values” on associated nodes weighted by “scores”\n#       on in-coming edges of each node; get the sum of “scores” on\n#       in-coming edges of each node for normalization. Note that here\n#       :math:`\\textrm{wv}` is not normalized.\n#\n#       -  ``msg: fn.u_mul_e('v', 'score', 'v'), reduce: fn.sum('v', 'wv')``\n#\n#          .. math:: \\textrm{wv}_j=\\sum_{i=1}^{N} \\textrm{score}_{ij} \\cdot v_i\n#\n#       -  ``msg: fn.copy_e('score', 'score'), reduce: fn.sum('score', 'z')``\n#\n#          .. math:: \\textrm{z}_j=\\sum_{i=1}^{N} \\textrm{score}_{ij}\n#\n# The normalization of :math:`\\textrm{wv}` is left to post processing.\n#\n# .. code::\n#\n#    def src_dot_dst(src_field, dst_field, out_field):\n#        def func(edges):\n#            return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}\n#\n#        return func\n#\n#    def scaled_exp(field, scale_constant):\n#        def func(edges):\n#            # clamp for softmax numerical stability\n#            return {field: th.exp((edges.data[field] / scale_constant).clamp(-5, 5))}\n#\n#        return func\n#\n#\n#    def propagate_attention(self, g, eids):\n#        # Compute attention score\n#        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)\n#        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))\n#        # Update node state\n#        g.send_and_recv(eids,\n#                        [fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],\n#                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])\n#\n# Preprocessing and postprocessing\n# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n#\n# In Transformer, data needs to be pre- and post-processed before and\n# after the ``propagate_attention`` function.\n#\n# **Preprocessing** The preprocessing function ``pre_func`` first\n# normalizes the node representations and then map 
them to a set of\n# queries, keys and values, using self-attention as an example:\n#\n# .. math::\n#\n#\n#    x \\leftarrow \\textrm{LayerNorm}(x) \\\\\n#    [q, k, v] \\leftarrow [W_q, W_k, W_v ]\\cdot x\n#\n# **Postprocessing** The postprocessing function ``post_funcs`` completes\n# the whole computation correspond to one layer of the transformer: 1.\n# Normalize :math:`\\textrm{wv}` and get the output of Multi-Head Attention\n# Layer :math:`o`.\n#\n# .. math::\n#\n#\n#    \\textrm{wv} \\leftarrow \\frac{\\textrm{wv}}{z} \\\\\n#    o \\leftarrow W_o\\cdot \\textrm{wv} + b_o\n#\n# add residual connection:\n#\n# .. math::\n#\n#\n#    x \\leftarrow x + o\n#\n# 2. Applying a two layer position-wise feed forward layer on :math:`x`\n#    then add residual connection:\n#\n#    .. math::\n#\n#\n#       x \\leftarrow x + \\textrm{LayerNorm}(\\textrm{FFN}(x))\n#\n#    where :math:`\\textrm{FFN}` refers to the feed forward function.\n#\n# .. code::\n#\n#    class Encoder(nn.Module):\n#        def __init__(self, layer, N):\n#            super(Encoder, self).__init__()\n#            self.N = N\n#            self.layers = clones(layer, N)\n#            self.norm = LayerNorm(layer.size)\n#\n#        def pre_func(self, i, fields='qkv'):\n#            layer = self.layers[i]\n#            def func(nodes):\n#                x = nodes.data['x']\n#                norm_x = layer.sublayer[0].norm(x)\n#                return layer.self_attn.get(norm_x, fields=fields)\n#            return func\n#\n#        def post_func(self, i):\n#            layer = self.layers[i]\n#            def func(nodes):\n#                x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']\n#                o = layer.self_attn.get_o(wv / z)\n#                x = x + layer.sublayer[0].dropout(o)\n#                x = layer.sublayer[1](x, layer.feed_forward)\n#                return {'x': x if i < self.N - 1 else self.norm(x)}\n#            return func\n#\n#    class Decoder(nn.Module):\n#        def 
__init__(self, layer, N):\n#            super(Decoder, self).__init__()\n#            self.N = N\n#            self.layers = clones(layer, N)\n#            self.norm = LayerNorm(layer.size)\n#\n#        def pre_func(self, i, fields='qkv', l=0):\n#            layer = self.layers[i]\n#            def func(nodes):\n#                x = nodes.data['x']\n#                if fields == 'kv':\n#                    norm_x = x # In enc-dec attention, x has already been normalized.\n#                else:\n#                    norm_x = layer.sublayer[l].norm(x)\n#                return layer.self_attn.get(norm_x, fields)\n#            return func\n#\n#        def post_func(self, i, l=0):\n#            layer = self.layers[i]\n#            def func(nodes):\n#                x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']\n#                o = layer.self_attn.get_o(wv / z)\n#                x = x + layer.sublayer[l].dropout(o)\n#                if l == 1:\n#                    x = layer.sublayer[2](x, layer.feed_forward)\n#                return {'x': x if i < self.N - 1 else self.norm(x)}\n#            return func\n#\n# This completes all procedures of one layer of encoder and decoder in\n# Transformer.\n#\n# .. note::\n#\n#    The sublayer connection part is little bit different from the\n#    original paper. However, this implementation is the same as `The Annotated\n#    Transformer <http://nlp.seas.harvard.edu/2018/04/03/attention.html>`__\n#    and\n#    `OpenNMT <https://github.com/OpenNMT/OpenNMT-py/blob/cd29c1dbfb35f4a2701ff52a1bf4e5bdcf02802e/onmt/encoders/transformer.py>`__.\n#\n# Main class of Transformer graph\n# -------------------------------\n#\n# The processing flow of Transformer can be seen as a 2-stage\n# message-passing within the complete graph (adding pre- and post-\n# processing appropriately): 1) self-attention in encoder, 2)\n# self-attention in decoder followed by cross-attention between encoder\n# and decoder, as shown below. 
|image4|\n#\n# .. code:: python\n#\n#    class Transformer(nn.Module):\n#        def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k):\n#            super(Transformer, self).__init__()\n#            self.encoder, self.decoder = encoder, decoder\n#            self.src_embed, self.tgt_embed = src_embed, tgt_embed\n#            self.pos_enc = pos_enc\n#            self.generator = generator\n#            self.h, self.d_k = h, d_k\n#\n#        def propagate_attention(self, g, eids):\n#            # Compute attention score\n#            g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)\n#            g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))\n#            # Send weighted values to target nodes\n#            g.send_and_recv(eids,\n#                            [fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],\n#                            [fn.sum('v', 'wv'), fn.sum('score', 'z')])\n#\n#        def update_graph(self, g, eids, pre_pairs, post_pairs):\n#            \"Update the node states and edge states of the graph.\"\n#\n#            # Pre-compute queries and key-value pairs.\n#            for pre_func, nids in pre_pairs:\n#                g.apply_nodes(pre_func, nids)\n#            self.propagate_attention(g, eids)\n#            # Further calculation after attention mechanism\n#            for post_func, nids in post_pairs:\n#                g.apply_nodes(post_func, nids)\n#\n#        def forward(self, graph):\n#            g = graph.g\n#            nids, eids = graph.nids, graph.eids\n#\n#            # Word Embedding and Position Embedding\n#            src_embed, src_pos = self.src_embed(graph.src[0]), self.pos_enc(graph.src[1])\n#            tgt_embed, tgt_pos = self.tgt_embed(graph.tgt[0]), self.pos_enc(graph.tgt[1])\n#            g.nodes[nids['enc']].data['x'] = self.pos_enc.dropout(src_embed + src_pos)\n#            g.nodes[nids['dec']].data['x'] = self.pos_enc.dropout(tgt_embed + tgt_pos)\n#\n#            
for i in range(self.encoder.N):\n#                # Step 1: Encoder Self-attention\n#                pre_func = self.encoder.pre_func(i, 'qkv')\n#                post_func = self.encoder.post_func(i)\n#                nodes, edges = nids['enc'], eids['ee']\n#                self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])\n#\n#            for i in range(self.decoder.N):\n#                # Step 2: Decoder Self-attention\n#                pre_func = self.decoder.pre_func(i, 'qkv')\n#                post_func = self.decoder.post_func(i)\n#                nodes, edges = nids['dec'], eids['dd']\n#                self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])\n#                # Step 3: Encoder-Decoder attention\n#                pre_q = self.decoder.pre_func(i, 'q', 1)\n#                pre_kv = self.decoder.pre_func(i, 'kv', 1)\n#                post_func = self.decoder.post_func(i, 1)\n#                nodes_e, nodes_d, edges = nids['enc'], nids['dec'], eids['ed']\n#                self.update_graph(g, edges, [(pre_q, nodes_d), (pre_kv, nodes_e)], [(post_func, nodes_d)])\n#\n#            return self.generator(g.ndata['x'][nids['dec']])\n#\n#\n# .. note::\n#\n#    By calling ``update_graph`` function, you can create your own\n#    Transformer on any subgraphs with nearly the same code. This\n#    flexibility enables us to discover new, sparse structures (cf. local attention\n#    mentioned `here <https://arxiv.org/pdf/1508.04025.pdf>`__). Note in this\n#    implementation you don't use mask or padding, which makes the logic\n#    clearer and saves memory. The trade-off is that the implementation is\n#    slower.\n#\n# Training\n# --------\n#\n# This tutorial does not cover several other techniques such as Label\n# Smoothing and Noam Optimizations mentioned in the original paper. 
For\n# detailed description about these modules, read `The\n# Annotated\n# Transformer <http://nlp.seas.harvard.edu/2018/04/03/attention.html>`__\n# written by Harvard NLP team.\n#\n# Task and the dataset\n# ~~~~~~~~~~~~~~~~~~~~\n#\n# The Transformer is a general framework for a variety of NLP tasks. This tutorial focuses\n# on the sequence to sequence learning: it’s a typical case to illustrate how it works.\n#\n# As for the dataset, there are two example tasks: copy and sort, together\n# with two real-world translation tasks: multi30k en-de task and wmt14\n# en-de task.\n#\n# -  **copy dataset**: copy input sequences to output. (train/valid/test:\n#    9000, 1000, 1000)\n# -  **sort dataset**: sort input sequences as output. (train/valid/test:\n#    9000, 1000, 1000)\n# -  **Multi30k en-de**, translate sentences from En to De.\n#    (train/valid/test: 29000, 1000, 1000)\n# -  **WMT14 en-de**, translate sentences from En to De.\n#    (Train/Valid/Test: 4500966/3000/3003)\n#\n# .. note::\n#    Training with wmt14 requires multi-GPU support and is not available. Contributions are welcome!\n#\n# Graph building\n# ~~~~~~~~~~~~~~\n#\n# **Batching** This is similar to the way you handle Tree-LSTM. Build a graph pool in\n# advance, including all possible combination of input lengths and output\n# lengths. Then for each sample in a batch, call ``dgl.batch`` to batch\n# graphs of their sizes together in to a single large graph.\n#\n# You can wrap the process of creating graph pool and building\n# BatchedGraph in ``dataset.GraphPool`` and\n# ``dataset.TranslationDataset``.\n#\n# .. 
code:: python\n#\n#    graph_pool = GraphPool()\n#\n#    data_iter = dataset(graph_pool, mode='train', batch_size=1, devices=devices)\n#    for graph in data_iter:\n#        print(graph.nids['enc']) # encoder node ids\n#        print(graph.nids['dec']) # decoder node ids\n#        print(graph.eids['ee']) # encoder-encoder edge ids\n#        print(graph.eids['ed']) # encoder-decoder edge ids\n#        print(graph.eids['dd']) # decoder-decoder edge ids\n#        print(graph.src[0]) # Input word index list\n#        print(graph.src[1]) # Input positions\n#        print(graph.tgt[0]) # Output word index list\n#        print(graph.tgt[1]) # Ouptut positions\n#        break\n#\n# Output:\n#\n# .. code::\n#\n#    tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0')\n#    tensor([ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], device='cuda:0')\n#    tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n#            18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n#            36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n#            54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,\n#            72, 73, 74, 75, 76, 77, 78, 79, 80], device='cuda:0')\n#    tensor([ 81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,\n#             95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108,\n#            109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,\n#            123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136,\n#            137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150,\n#            151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164,\n#            165, 166, 167, 168, 169, 170], device='cuda:0')\n#    tensor([171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184,\n#            185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198,\n#           
 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212,\n#            213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225],\n#           device='cuda:0')\n#    tensor([28, 25,  7, 26,  6,  4,  5,  9, 18], device='cuda:0')\n#    tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0')\n#    tensor([ 0, 28, 25,  7, 26,  6,  4,  5,  9, 18], device='cuda:0')\n#    tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')\n#\n# Put it all together\n# -------------------\n#\n# Train a one-head transformer with one layer, 128 dimension on copy\n# task. Set other parameters to the default.\n#\n# Inference module is not included in this tutorial. It\n# requires beam search. For a full implementation, see the `GitHub\n# repo <https://github.com/dmlc/dgl/tree/master/examples/pytorch/transformer>`__.\n#\n# .. code:: python\n#\n#    from tqdm.auto import tqdm\n#    import torch as th\n#    import numpy as np\n#\n#    from loss import LabelSmoothing, SimpleLossCompute\n#    from modules import make_model\n#    from optims import NoamOpt\n#    from dgl.contrib.transformer import get_dataset, GraphPool\n#\n#    def run_epoch(data_iter, model, loss_compute, is_train=True):\n#        for i, g in tqdm(enumerate(data_iter)):\n#            with th.set_grad_enabled(is_train):\n#                output = model(g)\n#                loss = loss_compute(output, g.tgt_y, g.n_tokens)\n#        print('average loss: {}'.format(loss_compute.avg_loss))\n#        print('accuracy: {}'.format(loss_compute.accuracy))\n#\n#    N = 1\n#    batch_size = 128\n#    devices = ['cuda' if th.cuda.is_available() else 'cpu']\n#\n#    dataset = get_dataset(\"copy\")\n#    V = dataset.vocab_size\n#    criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)\n#    dim_model = 128\n#\n#    # Create model\n#    model = make_model(V, V, N=N, dim_model=128, dim_ff=128, h=1)\n#\n#    # Sharing weights between Encoder & Decoder\n#    model.src_embed.lut.weight = 
model.tgt_embed.lut.weight\n#    model.generator.proj.weight = model.tgt_embed.lut.weight\n#\n#    model, criterion = model.to(devices[0]), criterion.to(devices[0])\n#    model_opt = NoamOpt(dim_model, 1, 400,\n#                        th.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9))\n#    loss_compute = SimpleLossCompute\n#\n#    att_maps = []\n#    for epoch in range(4):\n#        train_iter = dataset(graph_pool, mode='train', batch_size=batch_size, devices=devices)\n#        valid_iter = dataset(graph_pool, mode='valid', batch_size=batch_size, devices=devices)\n#        print('Epoch: {} Training...'.format(epoch))\n#        model.train(True)\n#        run_epoch(train_iter, model,\n#                  loss_compute(criterion, model_opt), is_train=True)\n#        print('Epoch: {} Evaluating...'.format(epoch))\n#        model.att_weight_map = None\n#        model.eval()\n#        run_epoch(valid_iter, model,\n#                  loss_compute(criterion, None), is_train=False)\n#        att_maps.append(model.att_weight_map)\n#\n# Visualization\n# -------------\n#\n# After training, you can visualize the attention that the Transformer generates\n# on copy task.\n#\n# .. code:: python\n#\n#    src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')\n#    tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]\n#    # visualize head 0 of encoder-decoder attention\n#    att_animation(att_maps, 'e2d', src_seq, tgt_seq, 0)\n#\n# |image5| from the figure you see the decoder nodes gradually learns to\n# attend to corresponding nodes in input sequence, which is the expected\n# behavior.\n#\n# Multi-head attention\n# ~~~~~~~~~~~~~~~~~~~~\n#\n# Besides the attention of a one-head attention trained on toy task. 
We\n# also visualize the attention scores of Encoder’s Self Attention,\n# Decoder’s Self Attention and the Encoder-Decoder attention of a\n# one-layer Transformer network trained on multi-30k dataset.\n#\n# From the visualization you see the diversity of different heads, which is what you would\n# expect. Different heads learn different relations between word pairs.\n#\n# -  **Encoder Self-Attention** |image6|\n#\n# -  **Encoder-Decoder Attention** Most words in target sequence attend on\n#    their related words in source sequence, for example: when generating\n#    “See” (in De), several heads attend on “lake”; when generating\n#    “Eisfischerhütte”, several heads attend on “ice”. |image7|\n#\n# -  **Decoder Self-Attention** Most words attend on their previous few\n#    words. |image8|\n#\n# Adaptive Universal Transformer\n# ------------------------------\n#\n# A recent research paper by Google, `Universal\n# Transformer <https://arxiv.org/pdf/1807.03819.pdf>`__, is an example to\n# show how ``update_graph`` adapts to more complex updating rules.\n#\n# The Universal Transformer was proposed to address the problem that\n# vanilla Transformer is not computationally universal by introducing\n# recurrence in Transformer:\n#\n# -  The basic idea of Universal Transformer is to repeatedly revise its\n#    representations of all symbols in the sequence with each recurrent\n#    step by applying a Transformer layer on the representations.\n# -  Compared to vanilla Transformer, Universal Transformer shares weights\n#    among its layers, and it does not fix the recurrence time (which\n#    means the number of layers in Transformer).\n#\n# A further optimization employs an `adaptive computation time\n# (ACT) <https://arxiv.org/pdf/1603.08983.pdf>`__ mechanism to allow the\n# model to dynamically adjust the number of times the representation of\n# each position in a sequence is revised (referred to as **step**\n# hereafter). 
This model is also known as the Adaptive Universal\n# Transformer (AUT).\n#\n# In AUT, you maintain an active nodes list. In each step :math:`t`, you\n# compute a halting probability: :math:`h (0<h<1)` for all nodes in this\n# list by:\n#\n# .. math::  h^t_i = \\sigma(W_h x^t_i + b_h)\n#\n# then dynamically decide which nodes are still active. A node is halted\n# at time :math:`T` if and only if\n# :math:`\\sum_{t=1}^{T-1} h_i^t < 1 - \\varepsilon \\leq \\sum_{t=1}^{T}h_i^t`.\n# Halted nodes are removed from the list. The procedure proceeds until the\n# list is empty or a pre-defined maximum step is reached. From DGL’s\n# perspective, this means that the “active” graph becomes sparser over\n# time.\n#\n# The final state of a node :math:`s_i` is a weighted average of\n# :math:`x_i^t` by :math:`h_i^t`:\n#\n# .. math::  s_i = \\sum_{t=1}^{T} h_i^t\\cdot x_i^t\n#\n# In DGL, implement this algorithm by calling\n# ``update_graph`` on nodes that are still active and edges associated\n# with these nodes. The following code shows the Universal Transformer\n# class in DGL:\n#\n# .. 
code::\n#\n#    class UTransformer(nn.Module):\n#        \"Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf).\"\n#        MAX_DEPTH = 8\n#        thres = 0.99\n#        act_loss_weight = 0.01\n#        def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, time_enc, generator, h, d_k):\n#            super(UTransformer, self).__init__()\n#            self.encoder,  self.decoder = encoder, decoder\n#            self.src_embed, self.tgt_embed = src_embed, tgt_embed\n#            self.pos_enc, self.time_enc = pos_enc, time_enc\n#            self.halt_enc = HaltingUnit(h * d_k)\n#            self.halt_dec = HaltingUnit(h * d_k)\n#            self.generator = generator\n#            self.h, self.d_k = h, d_k\n#\n#        def step_forward(self, nodes):\n#            # add positional encoding and time encoding, increment step by one\n#            x = nodes.data['x']\n#            step = nodes.data['step']\n#            pos = nodes.data['pos']\n#            return {'x': self.pos_enc.dropout(x + self.pos_enc(pos.view(-1)) + self.time_enc(step.view(-1))),\n#                    'step': step + 1}\n#\n#        def halt_and_accum(self, name, end=False):\n#            \"field: 'enc' or 'dec'\"\n#            halt = self.halt_enc if name == 'enc' else self.halt_dec\n#            thres = self.thres\n#            def func(nodes):\n#                p = halt(nodes.data['x'])\n#                sum_p = nodes.data['sum_p'] + p\n#                active = (sum_p < thres) & (1 - end)\n#                _continue = active.float()\n#                r = nodes.data['r'] * (1 - _continue) + (1 - sum_p) * _continue\n#                s = nodes.data['s'] + ((1 - _continue) * r + _continue * p) * nodes.data['x']\n#                return {'p': p, 'sum_p': sum_p, 'r': r, 's': s, 'active': active}\n#            return func\n#\n#        def propagate_attention(self, g, eids):\n#            # Compute attention score\n#            
g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)\n#            g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)\n#            # Send weighted values to target nodes\n#            g.send_and_recv(eids,\n#                            [fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],\n#                            [fn.sum('v', 'wv'), fn.sum('score', 'z')])\n#\n#        def update_graph(self, g, eids, pre_pairs, post_pairs):\n#            \"Update the node states and edge states of the graph.\"\n#            # Pre-compute queries and key-value pairs.\n#            for pre_func, nids in pre_pairs:\n#                g.apply_nodes(pre_func, nids)\n#            self.propagate_attention(g, eids)\n#            # Further calculation after attention mechanism\n#            for post_func, nids in post_pairs:\n#                g.apply_nodes(post_func, nids)\n#\n#        def forward(self, graph):\n#            g = graph.g\n#            N, E = graph.n_nodes, graph.n_edges\n#            nids, eids = graph.nids, graph.eids\n#\n#            # embed & pos\n#            g.nodes[nids['enc']].data['x'] = self.src_embed(graph.src[0])\n#            g.nodes[nids['dec']].data['x'] = self.tgt_embed(graph.tgt[0])\n#            g.nodes[nids['enc']].data['pos'] = graph.src[1]\n#            g.nodes[nids['dec']].data['pos'] = graph.tgt[1]\n#\n#            # init step\n#            device = next(self.parameters()).device\n#            g.ndata['s'] = th.zeros(N, self.h * self.d_k, dtype=th.float, device=device)    # accumulated state\n#            g.ndata['p'] = th.zeros(N, 1, dtype=th.float, device=device)                    # halting prob\n#            g.ndata['r'] = th.ones(N, 1, dtype=th.float, device=device)                     # remainder\n#            g.ndata['sum_p'] = th.zeros(N, 1, dtype=th.float, device=device)                # sum of pondering values\n#            g.ndata['step'] = th.zeros(N, 1, dtype=th.long, device=device)                  # step\n#            
g.ndata['active'] = th.ones(N, 1, dtype=th.uint8, device=device)                # active\n#\n#            for step in range(self.MAX_DEPTH):\n#                pre_func = self.encoder.pre_func('qkv')\n#                post_func = self.encoder.post_func()\n#                nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['enc'])\n#                if len(nodes) == 0: break\n#                edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ee'])\n#                end = step == self.MAX_DEPTH - 1\n#                self.update_graph(g, edges,\n#                                  [(self.step_forward, nodes), (pre_func, nodes)],\n#                                  [(post_func, nodes), (self.halt_and_accum('enc', end), nodes)])\n#\n#            g.nodes[nids['enc']].data['x'] = self.encoder.norm(g.nodes[nids['enc']].data['s'])\n#\n#            for step in range(self.MAX_DEPTH):\n#                pre_func = self.decoder.pre_func('qkv')\n#                post_func = self.decoder.post_func()\n#                nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['dec'])\n#                if len(nodes) == 0: break\n#                edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['dd'])\n#                self.update_graph(g, edges,\n#                                  [(self.step_forward, nodes), (pre_func, nodes)],\n#                                  [(post_func, nodes)])\n#\n#                pre_q = self.decoder.pre_func('q', 1)\n#                pre_kv = self.decoder.pre_func('kv', 1)\n#                post_func = self.decoder.post_func(1)\n#                nodes_e = nids['enc']\n#                edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ed'])\n#                end = step == self.MAX_DEPTH - 1\n#                self.update_graph(g, edges,\n#                                  [(pre_q, nodes), (pre_kv, nodes_e)],\n#                                  [(post_func, nodes), (self.halt_and_accum('dec', 
end), nodes)])\n#\n#            g.nodes[nids['dec']].data['x'] = self.decoder.norm(g.nodes[nids['dec']].data['s'])\n#            act_loss = th.mean(g.ndata['r']) # ACT loss\n#\n#            return self.generator(g.ndata['x'][nids['dec']]), act_loss * self.act_loss_weight\n#\n# Call ``filter_nodes`` and ``filter_edges`` to find nodes/edges\n# that are still active:\n#\n# .. note::\n#\n#    - :func:`~dgl.DGLGraph.filter_nodes` takes a predicate and a node\n#      ID list/tensor as input, then returns a tensor of node IDs that satisfy\n#      the given predicate.\n#    - :func:`~dgl.DGLGraph.filter_edges` takes a predicate\n#      and an edge ID list/tensor as input, then returns a tensor of edge IDs\n#      that satisfy the given predicate.\n#\n# For the full implementation, see the `GitHub\n# repo <https://github.com/dmlc/dgl/tree/master/examples/pytorch/transformer/modules/act.py>`__.\n#\n# The figure below shows the effect of Adaptive Computation\n# Time. Different positions of a sentence were revised a different number of times.\n#\n# |image9|\n#\n# You can also visualize the dynamics of step distribution on nodes during the\n# training of AUT on the sort task (reaching 99.7% accuracy), which demonstrates\n# how AUT learns to reduce recurrence steps during training. |image10|\n#\n# .. |image0| image:: https://i.imgur.com/zV5LmTX.png\n# .. |image1| image:: https://i.imgur.com/dETQMMx.png\n# .. |image2| image:: https://i.imgur.com/hnGP229.png\n# .. |image3| image:: https://i.imgur.com/Hj2rRGT.png\n# .. |image4| image:: https://i.imgur.com/zlUpJ41.png\n# .. |image5| image:: https://s1.ax1x.com/2018/12/06/F126xI.gif\n# .. |image6| image:: https://i.imgur.com/HjYb7F2.png\n# .. |image7| image:: https://i.imgur.com/383J5O5.png\n# .. |image8| image:: https://i.imgur.com/c0UWB1V.png\n# .. |image9| image:: https://s1.ax1x.com/2018/12/06/F1sGod.png\n# .. |image10| image:: https://s1.ax1x.com/2018/12/06/F1r8Cq.gif\n#\n# .. 
note::\n#     The notebook itself is not executable because of its many dependencies.\n#     Download `7_transformer.py <https://data.dgl.ai/tutorial/7_transformer.py>`__,\n#     and copy the Python script to the directory ``examples/pytorch/transformer``,\n#     then run ``python 7_transformer.py`` to see how it works.\n"
  }
]